from typing import Literal, Optional, List
from datetime import timedelta, datetime, date
from pyspark.sql import DataFrame, Column, functions as F, Window as W, types as T
from yipit_databricks_utils.helpers.pyspark_utils import get_spark_session
from yipit_databricks_utils.helpers.telemetry import track_usage
from etl_toolkit import expressions as E
from etl_toolkit.exceptions import InvalidInputException
from etl_toolkit.analyses.card.core import (
BaseCardTableConfiguration,
get_table_configuration_for_dataset,
)
from etl_toolkit.analyses.scalar import get_aggregates
from etl_toolkit.analyses.calculation import (
add_percent_of_total_columns,
add_percent_fill_columns,
)
MINIMUM_THRESHOLD = 63
MAXIMUM_THRESHOLD = 126
PLACEHOLDER = F.lit("UNKNOWN")
MAX_PERCENT = F.lit(1.0)
DECIMAL_PRECISION = T.DecimalType(precision=14, scale=2)
@track_usage
def add_card_day_of_week_lag_adjustment(
df: DataFrame,
threshold: int = 84,
dataset: Literal["skywalker", "yoda", "mando"] = "yoda",
transaction_type: Literal["debit", "credit"] = "debit",
panel_ids: Optional[list[str]] = None,
group_column_names: Optional[list[str]] = None,
max_adjustment: float = 5.0,
use_central_adjustment: bool = False,
enable_paneling: bool = False,
) -> DataFrame:
"""
Add additional column(s) to an input dataframe, ``df``, of card data that calculate adjustments
to account for the merchant-specific lag in when transactions (txns) are processed.
The lag addressed in this function is due to the day-of-week ("DOW") a transaction falls on.
The additional columns added to the input ``df`` for this adjustment include:
* ``txns``: The number of transactions found within each grouping of the ``group_column_names``, the dataset's card type, the DOW of the transaction, and the number of days (lag) between the transaction and the max file date of the df. Int column type.
* ``txns_percent_fill``: Percentage of ``txns`` that have been observed by the above grouping, ordered by the lag calculation. Double column type.
* ``lag_adjustment``: Equal to the inverse of ``txns_percent_fill``, this is the adjustment factor that should be multiplied with the transaction amount to account for the DOW lag behavior. Defaults to 1 if no lag is observed. Double column type.
For further context, certain merchants may have lag patterns that are different from the lag pattern observed when looking at the whole card dataset.
For these merchants, using this lag adjustment function allows you to calculate a more accurate lag adjustment.
* For example, in Skywalker we have noticed that transactions from merchants like Toast and Starbucks arrive faster than typical Skywalker transactions, which causes any central lag adjustment to consistently overestimate for these merchants.
* This trend is not as prominent in Yoda; however, a merchant-specific lag adjustment is still preferred.
:param df: Input dataframe to add adjustment columns to through this function. It should be generated from the txns_all table of Skywalker or Yoda and parsed for specific merchant(s).
:param threshold: An integer indicating the number of days prior to the maximum file date to calculate the lag adjustment for each cobrand, card type, and any other grouping columns. This value must be between 63 and 126 days.
:param dataset: The source dataset that the input ``df`` is derived from. Can be either ``yoda`` or ``skywalker``. Defaults to ``yoda``.
:param transaction_type: The type of transactions that should be used to calculate the adjustment. Can be either ``credit`` or ``debit`` and is only applied when ``dataset="skywalker"``. Defaults to ``debit``.
:param panel_ids: A list of strings indicating the Panel ID(s) (panel table names) that this adjustment should account for. Defaults to panel ``fixed_201901_100`` for Yoda and ``fixed_201901_111_ex`` for Skywalker.
:param group_column_names: An optional list of grouping columns to generate lag adjustments within. This can be used if the input dataframe spans multiple slices where the lag behavior can be different. Defaults to ``["cobrid"]`` for Yoda and ``["yipit_cobrand_id"]`` for Skywalker.
:param max_adjustment: An upper bound value to apply for the final ``lag_adjustment`` column. Defaults to ``5.0``, which indicates the adjustment value cannot be a factor greater than 5.
:param use_central_adjustment: Deprecated feature. An optional boolean to use the central lag adjustment value from the source dataframe. Using merchant-specific DOW adjustments is preferred. Defaults to ``False``.
:param enable_paneling: An optional boolean to restrict the lag calculation to transactions from users in the configured panel(s) (see ``panel_ids``). Defaults to ``False``.
.. tip:: A couple of notes on the ``threshold`` argument:
* Lag behavior can change over time; if your merchant's lag behavior has changed in the last 126 days, you may want to limit ``threshold`` to a lower number.
* It may be beneficial to alter the threshold; if you do, test several different thresholds and QA the results, as in the illustrative sketch below.
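A minimal sketch of comparing a few candidate thresholds (the table name follows the Examples below; the candidate values are arbitrary):
.. code-block:: python
:caption: Illustrative sketch of comparing several ``threshold`` values before settling on one.
from etl_toolkit import E, F, A
tost_txns = spark.table("yd_production.tost_silver.txns_paneled")
adjusted_short = A.add_card_day_of_week_lag_adjustment(tost_txns, dataset="yoda", threshold=63)
adjusted_long = A.add_card_day_of_week_lag_adjustment(tost_txns, dataset="yoda", threshold=126)
# Compare the resulting adjustment distributions as part of QA
display(adjusted_short.select("lag_adjustment").summary())
display(adjusted_long.select("lag_adjustment").summary())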
Examples
^^^^^^^^^^
.. code-block:: python
:caption: Example of applying day of week adjustments to a Yoda dataset.
from etl_toolkit import E, F, A
tost_txns = spark.table("yd_production.tost_silver.txns_paneled")
display(
A.add_card_day_of_week_lag_adjustment(
tost_txns,
dataset="yoda",
)
)
+--------------+--------------+--------------+----------------+----------------+-------------------+------------------+
|cobrid |cardtype |trans_dow |... |txns |txns_percent_fill |lag_adjustment |
+--------------+--------------+--------------+----------------+----------------+-------------------+------------------+
| 175| DEBIT| 4| ...| 1154536| 0.890645002553501|1.1227818009790382|
+--------------+--------------+--------------+----------------+----------------+-------------------+------------------+
| 174| DEBIT| 5| ...| null| null| 1|
+--------------+--------------+--------------+----------------+----------------+-------------------+------------------+
.. code-block:: python
:caption: Example of applying day of week adjustments to a Yoda dataset while setting a ``max_adjustment``. Note how it establishes a ceiling for the ``lag_adjustment`` column.
from etl_toolkit import E, F, A
tost_txns = spark.table("yd_production.tost_silver.txns_paneled")
display(
A.add_card_day_of_week_lag_adjustment(
tost_txns,
dataset="yoda",
max_adjustment=1.1,
)
)
+--------------+--------------+--------------+----------------+----------------+-------------------+------------------+
|cobrid |cardtype |trans_dow |... |txns |txns_percent_fill |lag_adjustment |
+--------------+--------------+--------------+----------------+----------------+-------------------+------------------+
| 175| DEBIT| 4| ...| 1154536| 0.890645002553501| 1.1|
+--------------+--------------+--------------+----------------+----------------+-------------------+------------------+
| 174| DEBIT| 5| ...| null| null| 1|
+--------------+--------------+--------------+----------------+----------------+-------------------+------------------+
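The following sketch assumes the same Yoda table as above and uses the documented default Yoda panel name:
.. code-block:: python
:caption: Illustrative sketch of restricting the lag calculation to a specific panel via ``panel_ids`` and ``enable_paneling``.
from etl_toolkit import E, F, A
tost_txns = spark.table("yd_production.tost_silver.txns_paneled")
display(
A.add_card_day_of_week_lag_adjustment(
tost_txns,
dataset="yoda",
panel_ids=["fixed_201901_100"],
enable_paneling=True,
)
)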
"""
spark = get_spark_session()
_validate_threshold(threshold)
table_configuration = get_table_configuration_for_dataset(
dataset,
panel_ids=panel_ids,
group_column_names=group_column_names,
)
max_date = get_aggregates(df, table_configuration.file_date_column_name, ["max"])[
"max"
]
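# Build the lag training window: it starts `threshold` days before the max file date and ends 28 days
# before it, so the most recent (potentially still-filling) dates do not skew the fill percentages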
start_date = max_date - timedelta(days=threshold)
end_date = max_date - timedelta(days=28)
if dataset == "skywalker":
df = df.filter(F.col("txn_type") == transaction_type)
lag_data_base = df.filter(
E.between(table_configuration.date_column_name, start_date, end_date)
)
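# Optionally restrict the lag calculation to txns from users present in the configured panel(s)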
if enable_paneling:
lag_data_base = lag_data_base.join(
table_configuration.panel_df, ["user_id"], how="left_semi"
)
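# Count txns per grouping column, card type, transaction day-of-week, and lag
# (days between the txn date and the file date) within the training window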
lag_data = (
lag_data_base.select(
F.datediff(
E.normalize_column(table_configuration.file_date_column_name),
E.normalize_column(table_configuration.date_column_name),
).alias("lag"),
F.coalesce(
E.normalize_column(table_configuration.card_type_column_name),
PLACEHOLDER,
).alias(table_configuration.card_type_column_name),
*[
F.coalesce(E.normalize_column(group_column_name), PLACEHOLDER).alias(
group_column_name
)
for group_column_name in table_configuration.group_column_names
],
E.normalize_column(table_configuration.day_of_week_column_name),
F.lit(1).alias("txns"),
)
.groupBy(*table_configuration.grouping_columns_with_lag)
.agg(F.count("txns").alias("txns"))
)
distinct_groupings = _get_distinct_groupings_with_lag(
lag_data,
table_configuration,
)
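# Join the observed counts onto the full grid of grouping x day-of-week x lag values so unobserved lags
# contribute zero txns, then compute the cumulative percent of txns observed by lag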
lag_adjustment = add_percent_fill_columns(
distinct_groupings.join(
lag_data, table_configuration.grouping_columns_with_lag, how="left"
).select(
F.col("lag").alias("current_lag_calculated"),
*table_configuration.group_column_names,
table_configuration.card_type_column_name,
table_configuration.day_of_week_column_name,
F.coalesce(F.col("txns"), F.lit(0)).alias("txns"),
),
value_columns=["txns"],
total_grouping_columns=table_configuration.grouping_columns_with_day_of_week,
order_columns=["current_lag_calculated"],
)
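# Either invert the dataset's central DOW lag adjustment column, or invert the merchant-specific fill percent
# (floored at 1 / max_adjustment so the final adjustment factor is capped at max_adjustment)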
if use_central_adjustment:
lag_adjustment_expression = F.pow(F.col("central_lag_adjustment_dow"), -1)
else:
lag_adjustment_expression = F.pow(
F.greatest(
F.coalesce(F.col("txns_percent_fill"), MAX_PERCENT),
F.try_divide(MAX_PERCENT, F.lit(max_adjustment)),
),
-1,
)
lag_added = (
df.withColumns(
{
"current_lag_calculated": F.datediff(
F.lit(max_date),
E.normalize_column(table_configuration.date_column_name),
)
}
)
.join(
lag_adjustment,
table_configuration.grouping_columns_with_current_lag,
how="left",
)
.withColumns(
{
# Coalescing percent_fill again (in lag_adjustment_expression) in case there is a join key that doesn't match;
# the fill percent is floored at 1 / max_adjustment, which caps the adjustment factor at max_adjustment
"lag_adjustment": lag_adjustment_expression
}
)
)
return lag_added
def _validate_threshold(threshold: int):
if threshold < MINIMUM_THRESHOLD:
raise InvalidInputException(
f"threshold input is under the minimum value allowed, set to {MINIMUM_THRESHOLD} or higher"
)
if threshold > MAXIMUM_THRESHOLD:
raise InvalidInputException(
f"threshold input is over the maximum value allowed, set to {MAXIMUM_THRESHOLD} or lower"
)
def _get_distinct_groupings_with_lag(
lag_data: DataFrame,
table_configuration: BaseCardTableConfiguration,
lag_periods: int = 29,
) -> DataFrame:
distinct_groupings = lag_data.select(
table_configuration.grouping_columns
).distinct()
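# Expand every distinct grouping with all lag values (0 through lag_periods - 1) and day-of-week values
# (1 through 7) so downstream fill-percent calculations see explicit rows for unobserved combinations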
return distinct_groupings.withColumns(
{
"lag": F.explode(F.array([F.lit(idx) for idx in range(lag_periods)])),
table_configuration.day_of_week_column_name: F.explode(
F.array([F.lit(idx) for idx in range(1, 8)])
),
}
)
@track_usage
def add_card_date_adjustment(
df: DataFrame,
date_column_name: str = "date",
group_column_names: Optional[list[str]] = None,
adjustment_type: Literal["revenue", "count"] = "revenue",
correct_jan_24: bool = True,
dataset: Literal["skywalker", "mando"] = "skywalker",
source_country_list: List[str] = ["uk_de_at", "fr"],
amount_column: str = "amount_usd",
) -> DataFrame:
"""
Modifies the ``date_column_name`` specified for the input dataframe (``df``) to account for delays in processing transactions.
For Skywalker data, this is based on the ``yipit_cobrand_id`` and card ``source``. For Mando data, this is based on the ``source_country``
and ``card_type``. The date adjustment is determined by the ``adjustment_type``, which indicates either the ``count`` of transactions
or total ``revenue`` for the transactions observed during each period. This adjustment is necessary given delays in processing transactions
due to holidays or non-business days that can distort the actual date a transaction occurred.
The additional or modified columns added to the input ``df`` for this adjustment include:
* ``<date_column_name>``: This column will be modified to reflect the corrected date based on the expected lag. Date type.
* ``<date_column_name>_raw``: A new column added that indicates the original, non-adjusted date for the transaction. Date type.
:param df: Input dataframe to add adjustment columns to through this function. It should be generated from the txns_all table from either Skywalker or Mando and parsed for specific merchant(s).
:param date_column_name: The date column name of the ``df`` to adjust. This function will correct the dates of transactions within this column. Defaults to ``date``.
:param group_column_names: An optional list of grouping columns to generate lag adjustments within. This can be used if the input dataframe spans multiple slices where the date behavior can be different. Defaults to None, which indicates there are no slices to group by for this adjustment.
:param adjustment_type: Indicates whether the date adjustment should be calculated based on ``revenue`` or transaction ``count``. Defaults to ``revenue``.
:param correct_jan_24: Optional flag to control if a subset of 2024-01-23 transactions should be assigned to 2024-01-22 via a deterministic sample. Defaults to ``True``.
:param dataset: The source dataset that the input ``df`` is derived from. Can be either ``skywalker`` or ``mando``. Defaults to ``skywalker``.
:param source_country_list: For Mando data only. List of countries to process. Must be one or both of ``uk_de_at`` and ``fr``. Defaults to ``["uk_de_at", "fr"]``.
:param amount_column: For Mando data only. The name of the amount column to use for revenue adjustments. Defaults to ``amount_usd``.
Examples
^^^^^^^^^^
.. code-block:: python
:caption: Example of applying the date adjustment to a Skywalker dataset. Note how the date column values are changed and a new date_raw column is added with the original values.
from etl_toolkit import E, F, A
cmg_txns = spark.table("yd_production.cmg_silver.txns_paneled")
display(
A.add_card_date_adjustment(cmg_txns)
)
+--------------+----------------+--------------+----------------+----------------+
|source |yipit_cobrand_id|date_raw |... |date |
+--------------+----------------+--------------+----------------+----------------+
| bank| 1| 2019-02-19| ...| 2019-02-16|
+--------------+----------------+--------------+----------------+----------------+
| bank| 1| 2019-02-19| ...| 2019-02-16|
+--------------+----------------+--------------+----------------+----------------+
.. code-block:: python
:caption: Example of applying the date adjustment to a Mando dataset for specific countries.
from etl_toolkit import E, F, A
eu_txns = spark.table("yd_production.eu_silver.txns_paneled")
display(
A.add_card_date_adjustment(
eu_txns,
dataset="mando",
source_country_list=["uk_de_at"],
amount_column="amount_eur"
)
)
+------------+--------------+--------------+----------------+----------------+
|card_type |source_country|date_raw |... |date |
+------------+--------------+--------------+----------------+----------------+
| DEBIT | uk_de_at | 2019-02-19| ...| 2019-02-16|
+------------+--------------+--------------+----------------+----------------+
| CREDIT | uk_de_at | 2019-02-19| ...| 2019-02-19|
+------------+--------------+--------------+----------------+----------------+
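A brief illustrative sketch of using transaction counts, rather than revenue, to drive the adjustment (same Skywalker table as the first example):
.. code-block:: python
:caption: Illustrative sketch of basing the date adjustment on transaction counts instead of revenue.
from etl_toolkit import E, F, A
cmg_txns = spark.table("yd_production.cmg_silver.txns_paneled")
display(
A.add_card_date_adjustment(
cmg_txns,
adjustment_type="count",
)
)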
"""
spark = get_spark_session()
if adjustment_type not in ("count", "revenue"):
raise InvalidInputException(
"'adjustment_type' is invalid, should be either 'revenue' or 'count'"
)
if dataset not in ("skywalker", "mando"):
raise InvalidInputException(
"'dataset' is invalid, must be 'skywalker' or 'mando'"
)
for country in source_country_list:
if country not in ["uk_de_at", "fr"]:
raise InvalidInputException(
"For source_country_list, input a list containing 'uk_de_at' and/or 'fr'. All other inputs will cause a failure."
)
table_configuration = get_table_configuration_for_dataset(
dataset,
group_column_names=group_column_names,
)
primary_dataset_segment = table_configuration.primary_dataset_segment
cardtype_column_name = table_configuration.card_type_column_name
credit_value = table_configuration.credit_value
debit_value = table_configuration.debit_value
if dataset == "mando":
correct_jan_24 = False
elif dataset == "skywalker":
amount_column = "amount"
if correct_jan_24:
if dataset == "skywalker":
# Take a partial sample of txns on 2024-01-23 and re-assign to 2024-01-22
# The sample is meant to be consistent between runs by hashing the txn ID and using that to determine sampling
jan_24_adjustment_condition = E.all(
F.col(date_column_name) == date(2024, 1, 23),
F.col("source") == "bank",
F.col("yipit_cobrand_id") == "1",
)
# Generate a 32-bit hash integer from the txn and then calculate its percentage in the available space of 32-bit numbers
# That percentage is then used to see if the txn falls within the sample
# Spark uses positive and negative (signed) integers for F.hash, so we need to offset by 2^31 to get a range of positive integers
# that can be mapped to the range (0, 1)
hashed_txn_id_sample_float = F.try_divide(
F.hash(F.col("txn_id")) + F.pow(2, 31).cast("bigint"),
((F.pow(2, 32).cast("bigint")) - 1),
)
df = df.select(
[
F.when(
E.all(
jan_24_adjustment_condition,
hashed_txn_id_sample_float <= F.lit(0.5067),
),
F.lit(date(2024, 1, 22)),
)
.otherwise(F.col(date_column_name))
.alias(date_column_name)
if col == date_column_name
else col
for col in df.columns
]
)
else:
raise InvalidInputException(
"Cannot use correct_jan_24 with non-skywalker datasets"
)
if group_column_names is None:
using_groups = False
group_column_names = ["group_for_adjustment"]
group_columns = [F.lit("all").alias("group_for_adjustment")]
else:
using_groups = True
group_columns = [F.col(group_column) for group_column in group_column_names]
# The time.days_filled table indicates how many non-business days have elapsed between two days.
# For example, if July 4 falls on a Thursday, the following Friday represents 2 days of txns; if July 4 falls on a Monday, the following Tuesday represents 4 days.
# This is important in determining the potential duration of the lag when applying the adjustment.
if dataset == "mando":
df_time = (spark.table("yd_1p_central.time.european_days_filled")).alias(
"df_time"
) ## Need to use European fill days for Mando
df_excluded_countries = df.filter(
~F.col(primary_dataset_segment).isin(source_country_list)
).withColumn(f"{date_column_name}_raw", F.col(f"{date_column_name}"))
df = (
df.filter(F.col(primary_dataset_segment).isin(source_country_list))
.withColumn(amount_column, F.col(amount_column).cast(DECIMAL_PRECISION))
.alias("df")
)
if dataset == "skywalker":
df_time = spark.table("yd_1p_central.time.days_filled").alias("df_time")
df_fill_days = (
df.alias("df_input")
.join(df_time, df[date_column_name] == df_time.day, how="left")
.select(
"df_input.*",
df_time.fill_status,
df_time.fill_day,
df_time.fill_count,
*group_columns,
F.hash(F.col("txn_id")).alias("hash_id"),
)
).alias("df_fill_days")
running_window = (  # For the Mando running window there is no yipit_cobrand_id, so we can't partition by that column
W.partitionBy(
F.col("fill_day"),
F.col(
primary_dataset_segment
), # This is unique to Mando; it's the closest equivalent to cobrand, but it refers to which country/country table the data comes from
*group_columns,
F.col(
cardtype_column_name
), # This is the same thing as source for Skywalker
)
.orderBy(F.col("hash_id"))
.rowsBetween(W.unboundedPreceding, W.currentRow)
)
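# Running totals within each fill_day / segment / group / card-type partition, ordered by the txn hash
# so the cumulative revenue and count assigned to each txn are deterministic across runs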
df_panel = df_fill_days.withColumns(
{
"running_revenue": F.sum(amount_column)
.over(running_window)
.cast(DECIMAL_PRECISION),
"running_count": F.count(F.lit(1))
.over(running_window)
.cast(DECIMAL_PRECISION),
}
).alias("df_panel")
df_time_filtered = (
df_time.withColumn(date_column_name, F.col("day"))
.filter(
E.between(
F.col("day"),
date(2016, 1, 1),
F.current_date(),
include_lower_bound=True,
include_upper_bound=False,
)
)
.alias("df_time_filtered")
)
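# Cross the calendar built above (2016-01-01 up to, but excluding, today) with every observed
# segment / card type / group combination so dates with no observed txns are still represented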
df_dates = (
df_time_filtered.crossJoin(
df_panel.select(
primary_dataset_segment,
cardtype_column_name,
*group_columns,
).distinct()
).select(
df_time_filtered[date_column_name],
df_time_filtered.fill_day,
df_time_filtered.fill_count,
df_panel[primary_dataset_segment],
df_panel[cardtype_column_name],
*[df_panel[group_column] for group_column in group_column_names],
)
).alias("df_dates")
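# Credit txns keep their original dates; their share of revenue/txns across the dates within each
# fill period is used as the reference distribution for spreading out debit (bank) activity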
df_card_summary_prep = (
df_dates.filter(F.col(cardtype_column_name) == credit_value)
.join(
df_panel,
E.all(
df_dates[date_column_name] == df_panel[date_column_name],
df_dates[cardtype_column_name] == df_panel[cardtype_column_name],
df_dates[primary_dataset_segment] == df_panel[primary_dataset_segment],
*[
df_dates[group_column] == df_panel[group_column]
for group_column in group_column_names
],
),
how="left",
)
.groupby(
df_dates[date_column_name],
df_dates.fill_day,
df_dates.fill_count,
df_dates[cardtype_column_name],
df_dates[primary_dataset_segment],
*[df_dates[group_column] for group_column in group_column_names],
)
.agg(
F.sum(F.coalesce(amount_column, F.lit(0))).alias("revenue"),
F.count(amount_column).alias("txns"),
)
.alias("df_card_summary_prep")
)
fill_percent_window_columns = [
"fill_day",
cardtype_column_name,
primary_dataset_segment,
*group_column_names,
]
df_card_summary_fill_count = df_card_summary_prep.groupBy(
fill_percent_window_columns
).agg(F.sum("txns").alias("fill_day_count"))
df_card_summary_total_prep = add_percent_of_total_columns(
df_card_summary_prep,
value_columns=["revenue", "txns"],
total_grouping_columns=fill_percent_window_columns,
).withColumnsRenamed(
{
"revenue_percent": "prelim_revenue_fill_day_pct",
"txns_percent": "prelim_txns_fill_day_pct",
}
)
df_fill_pct = df_card_summary_total_prep.join(
df_card_summary_fill_count,
fill_percent_window_columns,
how="left",
).alias("df_fill_pct")
default_percent = F.try_divide(F.lit(1), F.col("fill_count"))
df_card_summary = df_fill_pct.withColumns(
{
"revenue_fill_day_pct": F.coalesce(
F.col("prelim_revenue_fill_day_pct"), default_percent
),
"txns_fill_day_pct": F.coalesce(
F.col("prelim_txns_fill_day_pct"), default_percent
),
}
).alias("df_card_summary")
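# Aggregate debit (bank) totals per fill_day; these totals are redistributed across the dates of each
# fill period using the credit-derived fill-day percentages computed above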
df_bank_summary = (
df_dates.where(df_dates[cardtype_column_name] == debit_value)
.join(
df_panel,
E.all(
df_dates[date_column_name] == df_panel[date_column_name],
df_dates[primary_dataset_segment] == df_panel[primary_dataset_segment],
df_dates[cardtype_column_name] == df_panel[cardtype_column_name],
*[
df_dates[group_column] == df_panel[group_column]
for group_column in group_column_names
],
),
how="left",
)
.groupby(
df_dates.fill_day,
df_dates[cardtype_column_name],
df_dates[primary_dataset_segment],
*[df_dates[group_column] for group_column in group_column_names],
)
.agg(
F.sum(F.coalesce(amount_column, F.lit(0))).alias("revenue"),
F.count(amount_column).alias("txns"),
)
.alias("df_bank_summary")
)
df_sum_prep = (
df_card_summary.join(
df_bank_summary,
[primary_dataset_segment, "fill_day"] + group_column_names,
how="left",
)
.withColumns(
{
"bank_revenue": df_card_summary.revenue_fill_day_pct
* df_bank_summary.revenue,
"bank_count": df_card_summary.txns_fill_day_pct * df_bank_summary.txns,
}
)
.select(
df_card_summary[date_column_name],
df_card_summary.fill_day,
df_card_summary[primary_dataset_segment],
*[df_card_summary[group_column] for group_column in group_column_names],
F.col("bank_revenue"),
F.col("bank_count"),
)
.alias("df_sum_prep")
)
sum_prep_window = (
W.partitionBy(
F.col("fill_day"),
F.col(f"df_sum_prep.{primary_dataset_segment}"),
*group_columns,
)
.orderBy(date_column_name)
.rowsBetween(W.unboundedPreceding, -1)
)
revenue_to_date = F.coalesce(
F.sum("bank_revenue").over(sum_prep_window), F.lit(0)
).cast(DECIMAL_PRECISION)
count_to_date = F.coalesce(
F.sum("bank_count").over(sum_prep_window), F.lit(0)
).cast(DECIMAL_PRECISION)
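# For each date in a fill period, the cumulative bank targets define a running-total range
# (exclusive lower bound, inclusive upper bound) used below to assign debit txns to dates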
df_sum = (
df_sum_prep.where(df_sum_prep[primary_dataset_segment].isNotNull())
.withColumns(
{
"fill_revenue_td": revenue_to_date,
"running_revenue_cap": (revenue_to_date + F.col("bank_revenue")).cast(
DECIMAL_PRECISION
),
"fill_count_td": count_to_date,
"running_count_cap": (count_to_date + F.col("bank_count")).cast(
DECIMAL_PRECISION
),
}
)
.alias("df_sum")
)
adjusted_date = E.chain_cases(
[
E.case(
F.coalesce(F.col(f"a.{cardtype_column_name}"), F.lit("UNKNOWN"))
!= debit_value,
df_panel[date_column_name],
), ### I believe this change is necessary for Mando. I don't think it would impact SW, but I'm not positive
E.case(F.col("a.fill_status") == "full_day", df_panel[date_column_name]),
E.case(
F.col(f"a.{primary_dataset_segment}").isNull(),
df_panel[date_column_name],
),
],
otherwise=df_sum[date_column_name],
)
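# Assign each debit txn to the date whose target range contains its running total; non-debit txns,
# full days, and rows with a null segment keep their original date via adjusted_date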
df_final = (
df_panel.alias("a")
.join(
df_sum.alias("b"),
E.all(
df_panel.fill_day == df_sum.fill_day,
E.between(
df_panel[f"running_{adjustment_type}"],
df_sum[f"fill_{adjustment_type}_td"],
df_sum[f"running_{adjustment_type}_cap"],
include_lower_bound=False,
include_upper_bound=True,
),
*[
F.col(f"a.{group_column}") == F.col(f"b.{group_column}")
for group_column in group_column_names
],
F.col(f"a.{primary_dataset_segment}")
== F.col(f"b.{primary_dataset_segment}"),
F.col(f"a.{cardtype_column_name}") == debit_value,
),
how="left",
)
.select(
"a.*",
adjusted_date.alias("adjusted_date"),
)
.drop(
*[
"fill_status",
"fill_day",
"fill_count",
"hash_id",
"running_revenue",
"running_count",
]
)
)
if not using_groups:
df_final = df_final.drop(*group_column_names)
# Normalize schema one more time to handle column renames and avoid
# introducing column duplicates in output df schema
final_columns = []
visited = set()
for col in df_final.columns:
if col == "adjusted_date":
col_to_add = F.col("adjusted_date").alias(date_column_name)
if date_column_name not in visited:
final_columns.append(col_to_add)
visited.add(date_column_name)
# Also fix quarter/month to be derived from the adjusted date value
elif col == "quarter":
col_to_add = E.date_trunc("QUARTER", F.col("adjusted_date")).alias(
"quarter"
)
if col not in visited:
visited.add(col)
final_columns.append(col_to_add)
elif col == "month":
col_to_add = E.date_trunc("MONTH", F.col("adjusted_date")).alias("month")
if col not in visited:
visited.add(col)
final_columns.append(col_to_add)
elif col == date_column_name:
col_to_add = F.col(date_column_name).alias(f"{date_column_name}_raw")
if f"{date_column_name}_raw" not in visited:
final_columns.append(col_to_add)
visited.add(f"{date_column_name}_raw")
else:
col_to_add = col
if col_to_add not in visited:
final_columns.append(col_to_add)
visited.add(col_to_add)
df_final = df_final.select(final_columns)
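# For Mando, re-attach txns from countries outside source_country_list; their dates were not adjusted
# and their date_raw column was set to the original date earlier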
if dataset == "mando":
df_final = df_final.unionByName(df_excluded_countries)
return df_final