from typing import Literal, Optional
from datetime import date
from pyspark.sql import DataFrame, Column, functions as F, Window as W
from yipit_databricks_utils.helpers.pyspark_utils import get_spark_session
from yipit_databricks_utils.helpers.telemetry import track_usage
from etl_toolkit import expressions as E
from etl_toolkit.analyses.card.core import (
BaseCardTableConfiguration,
get_table_configuration_for_dataset,
)
from etl_toolkit.exceptions import InvalidInputException
from etl_toolkit.analyses.calculation import add_percent_of_total_columns
@track_usage
[docs]
def add_card_paneling(
    df: DataFrame,
    dataset: Literal["skywalker", "yoda", "mando"] = "yoda",
    panel_ids: Optional[list[str]] = None,
    add_geo_weights: bool = True,
    add_income_weights: bool = True,
    add_card_type_weights: bool = True,
    qa: bool = False,
) -> DataFrame:
    """
    :bdg-primary:`QA Mode support`

    Applies card paneling on the input dataframe (``df``) using the specified ``panel_ids`` (panel table names) and ``dataset`` (Skywalker, Yoda, or Mando).
    Transactions where the user is not a member of the panel will be filtered out unless ``qa=True``. In addition, various weight columns will be added
    based on the panel to account for geography, income, and card type biases in the dataset. Currently only one panel ID can be provided.

    .. note:: For Mando datasets, only basic user paneling is supported. Income weights, card type weights, and geography weights are not yet supported
        for Mando data. The corresponding ``add_*_weights`` flags are ignored when ``dataset="mando"``.

    A fully adjusted transaction amount should multiply geography, income, card type weights, and any other adjustments (ex: lag_adjustment) to the transaction amount.

    The additional columns added to the input ``df`` for this adjustment include:

    * ``is_in_panel``: This column indicates if the transaction's user is a member of any of the Panel ID(s) specified. Boolean type.
    * ``geo_weight``: This column includes the geo weight value for a given transaction given its panelist and panel ID. Double type.
      Only available for Skywalker and Yoda datasets.
    * ``income_weight``: This column includes the income weight value for a given transaction given its panelist and panel ID. Double type.
      Only available for Skywalker and Yoda datasets.
    * ``card_type_weight``: This column includes the card type weight value for a given transaction given its panelist and panel ID. Double type.
      Only available for Skywalker and Yoda datasets.

    :param df: Input dataframe to add adjustment columns to through this function. It should be generated from the txns_all table from Skywalker, Yoda or Mando table and parsed for specific merchant(s).
    :param dataset: The source dataset that the input ``df`` is derived from. Can be either ``yoda``, ``skywalker``, or ``mando``. Defaults to ``yoda``.
    :param panel_ids: A list of strings indicating the Panel ID(s) (panel table names) that this adjustment should account for.
        Defaults to panel ``fixed_201901_100`` for Yoda, ``fixed_201901_111_ex`` for Skywalker, and appropriate geographic panels for Mando.
    :param add_geo_weights: Boolean flag to control if a ``geo_weight`` column should be added to the dataset. Only applicable for Skywalker/Yoda. Defaults to ``True``.
    :param add_income_weights: Boolean flag to control if an ``income_weight`` column should be added to the dataset. Only applicable for Skywalker/Yoda. Defaults to ``True``.
    :param add_card_type_weights: Boolean flag to control if a ``card_type_weight`` column should be added to the dataset. Only applicable for Skywalker/Yoda. Defaults to ``True``.
    :param qa: Boolean flag to control QA mode for this function. When ``True``, then all rows are preserved and the ``is_in_panel`` column indicates if the transaction was in any of the Panel IDs specified. Defaults to ``False``, where non-panel transactions are filtered out.
    :returns: The input dataframe joined to the panel data, with the ``is_in_panel`` column and any requested weight columns added.

    Examples
    ^^^^^^^^^^

    .. code-block:: python
        :caption: Example of applying card paneling to a Yoda dataset. Note the default Yoda panel is used here.

        from etl_toolkit import E, F, A

        tost_txns = spark.table("yd_production.tost_silver.txns_parsed")

        display(
            A.add_card_paneling(
                tost_txns,
                dataset="yoda",
            )
        )

    +--------------+--------------+--------------+----------------+----------------+-------------------+------------------+------------------+
    |cobrid        |cardtype      |trans_dow     |...             |is_in_panel     |geo_weight         |income_weight     |card_type_weight  |
    +--------------+--------------+--------------+----------------+----------------+-------------------+------------------+------------------+
    |           175|         DEBIT|             4|             ...|            True| 0.6403014580204711|0.6190908069398377|2.9107223403194861|
    +--------------+--------------+--------------+----------------+----------------+-------------------+------------------+------------------+
    |           174|         DEBIT|             5|             ...|            True| 0.4155116960839209|0.6190908069398377|2.9107223403194861|
    +--------------+--------------+--------------+----------------+----------------+-------------------+------------------+------------------+

    .. code-block:: python
        :caption: Example of applying card paneling to a Yoda dataset when QA mode is enabled. Note that out of panel transactions are included.

        from etl_toolkit import E, F, A

        tost_txns = spark.table("yd_production.tost_silver.txns_parsed")

        display(
            A.add_card_paneling(
                tost_txns,
                dataset="yoda",
                qa=True,
            )
        )

    +--------------+--------------+--------------+----------------+----------------+-------------------+------------------+------------------+
    |cobrid        |cardtype      |trans_dow     |...             |is_in_panel     |geo_weight         |income_weight     |card_type_weight  |
    +--------------+--------------+--------------+----------------+----------------+-------------------+------------------+------------------+
    |           175|         DEBIT|             4|             ...|            True| 0.6403014580204711|0.6190908069398377|2.9107223403194861|
    +--------------+--------------+--------------+----------------+----------------+-------------------+------------------+------------------+
    |           174|         DEBIT|             5|             ...|            True| 0.4155116960839209|0.6190908069398377|2.9107223403194861|
    +--------------+--------------+--------------+----------------+----------------+-------------------+------------------+------------------+
    |           174|         DEBIT|             5|             ...|           False| 0.4155116960839209|0.6190908069398377|2.9107223403194861|
    +--------------+--------------+--------------+----------------+----------------+-------------------+------------------+------------------+

    .. code-block:: python
        :caption: Example of applying card paneling to a Yoda dataset with a specific Panel ID and not including card type weights.

        from etl_toolkit import E, F, A

        tost_txns = spark.table("yd_production.tost_silver.txns_parsed")

        display(
            A.add_card_paneling(
                tost_txns,
                dataset="yoda",
                panel_ids=["fixed_201901_333_cbs_green_red_teal"],
                add_card_type_weights=False,
            )
        )

    +--------------+--------------+--------------+----------------+----------------+-------------------+------------------+
    |cobrid        |cardtype      |trans_dow     |...             |is_in_panel     |geo_weight         |income_weight     |
    +--------------+--------------+--------------+----------------+----------------+-------------------+------------------+
    |           175|         DEBIT|             4|             ...|            True| 0.6403014580204711|0.6190908069398377|
    +--------------+--------------+--------------+----------------+----------------+-------------------+------------------+
    |           174|         DEBIT|             5|             ...|            True| 0.4155116960839209|0.6190908069398377|
    +--------------+--------------+--------------+----------------+----------------+-------------------+------------------+

    .. code-block:: python
        :caption: Example of applying card paneling to a Mando dataset. Note how only the is_in_panel column is added.

        from etl_toolkit import E, F, A

        mando_txns = spark.table("yd_3p_mando.mando_gold.txns_parsed")

        display(
            A.add_card_paneling(
                mando_txns,
                dataset="mando",
            )
        )
    """
    table_configuration = get_table_configuration_for_dataset(
        dataset,
        panel_ids=panel_ids,
    )

    # Mando does not yet support the additional weighting features, so the
    # weight flags are forced off regardless of what the caller passed.
    if dataset == "mando":
        add_geo_weights = False
        add_income_weights = False
        add_card_type_weights = False

    # Left-join the panel data on user_id so that transactions without a panel
    # match are still present for QA mode. The ``panel_ids`` column used below
    # is presumably supplied by ``panel_df`` -- confirm against the table
    # configuration.
    is_in_panel_column_name = "is_in_panel"
    txns_with_panels = df.join(
        table_configuration.panel_df, ["user_id"], how="left"
    )

    # One membership flag per panel when several panels are configured,
    # otherwise a single unsuffixed ``is_in_panel`` column.
    if table_configuration.is_multi_panel:
        panel_flag_columns = {
            f"{is_in_panel_column_name}_{panel_id}": _is_in_panel(panel_id)
            for panel_id in table_configuration.panel_ids
        }
    else:
        panel_flag_columns = {
            is_in_panel_column_name: _is_in_panel(table_configuration.panel_ids[0]),
        }
    txns_paneled = txns_with_panels.withColumns(panel_flag_columns)

    # When not using QA mode, records not part of any panel are dropped
    if not qa:
        is_in_at_least_one_panel = F.size(F.col("panel_ids")) > 0
        txns_paneled = txns_paneled.where(is_in_at_least_one_panel)

    # Add relevant weight adjustments as new columns based on the panels specified
    if add_geo_weights:
        txns_paneled = add_card_geo_weights(
            txns_paneled,
            dataset=dataset,
            panel_ids=panel_ids,
        )

    if add_income_weights:
        txns_paneled = add_card_income_weights(
            txns_paneled,
            dataset=dataset,
            panel_ids=panel_ids,
        )

    if add_card_type_weights:
        txns_paneled = _add_card_type_weights(
            txns_paneled,
            dataset=dataset,
            panel_ids=panel_ids,
        )

    return txns_paneled
@track_usage
def add_card_geo_weights(
    df: DataFrame,
    dataset: Literal["skywalker", "yoda"] = "yoda",
    panel_ids: Optional[list[str]] = None,
    weight_column_name: str = "geo_weight",
) -> DataFrame:
    """
    Adds a geo weight column to ``df`` for every panel resolved from ``dataset`` and ``panel_ids``.

    For each panel ID in the resolved table configuration, the ``<panel_id>_geo_weights`` table in the
    configuration's panel catalog and database is left-joined to ``df`` on ``user_id`` and ``month``.
    Rows with no matching weight receive a weight of ``1``.

    :param df: Input dataframe. It is joined on ``user_id`` and ``month``, so both columns must be present.
    :param dataset: The source dataset used to resolve the table configuration. Defaults to ``yoda``.
    :param panel_ids: Panel ID(s) passed to the table configuration lookup. Defaults to ``None``.
    :param weight_column_name: Name of the weight column to add. When the configuration is multi-panel,
        each panel gets its own column named ``<weight_column_name>_<panel_id>``. Defaults to ``geo_weight``.
    :returns: ``df`` with the geo weight column(s) added.
    """
    spark = get_spark_session()
    table_configuration = get_table_configuration_for_dataset(
        dataset,
        panel_ids=panel_ids,
    )

    result_df = df
    for panel_id in table_configuration.panel_ids:
        weights_table = (
            f"{table_configuration.panel_catalog}"
            f".{table_configuration.panel_database}"
            f".{panel_id}_geo_weights"
        )
        # Alias the weight to a temporary name so it can be coalesced into the
        # final column and then dropped.
        panel_weights = spark.table(weights_table).select(
            "user_id", "month", F.col("weight").alias("geo_weight_raw")
        )
        joined_df = result_df.join(panel_weights, ["user_id", "month"], how="left")

        if table_configuration.is_multi_panel:
            output_column = f"{weight_column_name}_{panel_id}"
        else:
            output_column = weight_column_name

        # Unmatched rows fall back to a neutral weight of 1.
        neutral_default = F.coalesce(F.col("geo_weight_raw"), F.lit(1))
        result_df = joined_df.withColumn(output_column, neutral_default).drop(
            "geo_weight_raw"
        )

    return result_df
def _add_card_type_weights(
    df: DataFrame,
    dataset: Literal["skywalker", "yoda"] = "yoda",
    panel_ids: Optional[list[str]] = None,
    weight_column_name: str = "card_type_weight",
) -> DataFrame:
    """
    Adds a card type weight column to ``df`` for every panel resolved from ``dataset`` and ``panel_ids``.

    For each panel ID in the resolved table configuration, the ``<panel_id>_card_weights`` table in the
    configuration's panel catalog and database is left-joined to ``df``. The join key is the configuration's
    ``card_type_column_name``; the weights table's ``type`` column is renamed to that name before joining.
    Rows with no matching weight receive a weight of ``1``.

    :param df: Input dataframe. It must contain the column named by the configuration's ``card_type_column_name``.
    :param dataset: The source dataset used to resolve the table configuration. Defaults to ``yoda``.
    :param panel_ids: Panel ID(s) passed to the table configuration lookup. Defaults to ``None``.
    :param weight_column_name: Name of the weight column to add. When the configuration is multi-panel,
        each panel gets its own column named ``<weight_column_name>_<panel_id>``. Defaults to ``card_type_weight``.
    :returns: ``df`` with the card type weight column(s) added.
    """
    spark = get_spark_session()
    table_configuration = get_table_configuration_for_dataset(
        dataset,
        panel_ids=panel_ids,
    )

    result_df = df
    for panel_id in table_configuration.panel_ids:
        weights_table = (
            f"{table_configuration.panel_catalog}"
            f".{table_configuration.panel_database}"
            f".{panel_id}_card_weights"
        )
        # Rename ``type`` to the dataset's card type column so the join can be
        # expressed as a simple same-name key join.
        panel_weights = spark.table(weights_table).select(
            F.col("type").alias(table_configuration.card_type_column_name),
            F.col("card_type_weight").alias("card_type_weight_raw"),
        )
        joined_df = result_df.join(
            panel_weights,
            [table_configuration.card_type_column_name],
            how="left",
        )

        if table_configuration.is_multi_panel:
            output_column = f"{weight_column_name}_{panel_id}"
        else:
            output_column = weight_column_name

        # Unmatched rows fall back to a neutral weight of 1.
        neutral_default = F.coalesce(F.col("card_type_weight_raw"), F.lit(1))
        result_df = joined_df.withColumn(output_column, neutral_default).drop(
            "card_type_weight_raw"
        )

    return result_df
@track_usage
def add_card_income_weights(
    df: DataFrame,
    dataset: Literal["skywalker", "yoda"] = "yoda",
    panel_ids: Optional[list[str]] = None,
    weight_column_name: str = "income_weight",
) -> DataFrame:
    """
    Adds an income weight column to ``df`` for every panel resolved from ``dataset`` and ``panel_ids``.

    For each panel ID in the resolved table configuration, the ``<panel_id>_income_weights`` table in the
    configuration's panel catalog and database is left-joined to ``df`` on ``user_id``.
    Rows with no matching weight receive a weight of ``1``.

    :param df: Input dataframe. It is joined on ``user_id``, so that column must be present.
    :param dataset: The source dataset used to resolve the table configuration. Defaults to ``yoda``.
    :param panel_ids: Panel ID(s) passed to the table configuration lookup. Defaults to ``None``.
    :param weight_column_name: Name of the weight column to add. When the configuration is multi-panel,
        each panel gets its own column named ``<weight_column_name>_<panel_id>``. Defaults to ``income_weight``.
    :returns: ``df`` with the income weight column(s) added.
    """
    spark = get_spark_session()
    table_configuration = get_table_configuration_for_dataset(
        dataset,
        panel_ids=panel_ids,
    )

    result_df = df
    for panel_id in table_configuration.panel_ids:
        weights_table = (
            f"{table_configuration.panel_catalog}"
            f".{table_configuration.panel_database}"
            f".{panel_id}_income_weights"
        )
        # Alias the weight to a temporary name so it can be coalesced into the
        # final column and then dropped.
        panel_weights = spark.table(weights_table).select(
            "user_id", F.col("weight").alias("income_weight_raw")
        )
        joined_df = result_df.join(panel_weights, ["user_id"], how="left")

        if table_configuration.is_multi_panel:
            output_column = f"{weight_column_name}_{panel_id}"
        else:
            output_column = weight_column_name

        # Unmatched rows fall back to a neutral weight of 1.
        neutral_default = F.coalesce(F.col("income_weight_raw"), F.lit(1))
        result_df = joined_df.withColumn(output_column, neutral_default).drop(
            "income_weight_raw"
        )

    return result_df
@track_usage
[docs]
def get_card_panels(
    dataset: Literal["skywalker", "yoda"] = "yoda",
) -> DataFrame:
    """
    Utility function to list all available Panel IDs for a given ``dataset``.
    The Panel IDs can be used in the ``panel_ids`` argument for ETL toolkit card analyses.

    The listing is every table name found in ``system.information_schema.tables`` for the dataset's
    panel catalog and database, sorted by name. As the examples show, this includes the weight tables
    that accompany each panel.

    :param dataset: The dataset to retrieve available Panel IDs from. Can be either ``yoda`` or ``skywalker``. Defaults to ``yoda``.
    :returns: A dataframe with a single ``panel_id`` column, ordered by ``panel_id``.

    Examples
    ^^^^^^^^^^

    .. code-block:: python
        :caption: Example of retrieving all Yoda Panel IDs

        from etl_toolkit import E, F, A

        display(
            A.get_card_panels(
                dataset="yoda",
            )
        )

    +------------------------------+
    |panel_id                      |
    +------------------------------+
    |fixed_201701_100              |
    +------------------------------+
    |fixed_201701_100_card_weights |
    +------------------------------+
    |...                           |
    +------------------------------+

    .. code-block:: python
        :caption: Example of retrieving all Skywalker Panel IDs

        from etl_toolkit import E, F, A

        display(
            A.get_card_panels(
                dataset="skywalker",
            )
        )

    +---------------------------------+
    |panel_id                         |
    +---------------------------------+
    |fixed_201701_333_ex              |
    +---------------------------------+
    |fixed_201701_333_ex_card_weights |
    +---------------------------------+
    |...                              |
    +---------------------------------+
    """
    spark = get_spark_session()
    table_configuration = get_table_configuration_for_dataset(
        dataset,
    )

    all_tables = spark.table("system.information_schema.tables")

    # Restrict the metastore listing to the schema that holds this dataset's panels.
    in_panel_catalog = F.col("table_catalog") == table_configuration.panel_catalog
    panel_tables = all_tables.where(in_panel_catalog)
    in_panel_database = F.col("table_schema") == table_configuration.panel_database
    panel_tables = panel_tables.where(in_panel_database)

    panel_listing = panel_tables.select(F.col("table_name").alias("panel_id"))
    return panel_listing.orderBy("panel_id")
def _is_in_panel(panel_id: str) -> Column:
    """
    Builds a boolean column expression that is true when ``panel_id`` appears in the row's ``panel_ids`` array.

    The membership check is wrapped in a coalesce so that a null result is reported as ``False``
    rather than null.

    :param panel_id: The Panel ID to look for in the ``panel_ids`` column.
    :returns: A boolean column expression.
    """
    membership = F.array_contains(F.col("panel_ids"), F.lit(panel_id))
    return F.coalesce(membership, F.lit(False))
@track_usage
[docs]
def add_card_paneling_reweighted(
    df: DataFrame,
    dataset: Literal["skywalker", "mando", "yoda"] = "yoda",
    base_panel_id: Optional[str] = None,
    weighting_panel_id: Optional[str] = None,
    panel_overlap_start_date: date = date(2021, 1, 1),
    panel_overlap_end_date: date = date(2021, 12, 31),
    reweight_min: float = 0.5,
    reweight_max: float = 2.0,
) -> DataFrame:
    """
    Applies a base panel using A.add_card_paneling and then uses a weighted panel to generate weights
    for shifting merchant volumes by merchant. This combines two panels - a base panel for the main paneling
    and weights, and a weighting panel used to determine merchant reweighting factors.

    The reweight factor for a merchant (``yd_tag_merchant``) is the weighting panel's weighted GMV divided by
    the base panel's weighted GMV over the overlap window, clamped to ``[reweight_min, reweight_max]``.
    Merchants without a usable ratio get a factor of ``1.0``.

    The returned dataframe is the base-paneled data with these changes:

    * ``reweight_factor``: the per-merchant adjustment described above.
    * ``amount_raw`` (Skywalker) / ``trans_amount_raw`` (otherwise): a copy of the original amount column.
    * ``amount`` (Skywalker) / ``trans_amount`` (otherwise): overwritten with the original amount multiplied
      by the available panel weights and ``reweight_factor``.

    Geo and card type weights are requested from the paneling step; income weights are not. Only the weight
    columns actually present on the paneled data are multiplied in. For Mando, paneling adds no weight
    columns, so only ``reweight_factor`` is applied.

    :param df: Input dataframe of transactions data. It must contain a ``yd_tag_merchant`` column.
    :param dataset: The source dataset the df belongs to
    :param base_panel_id: The base panel ID for the dataset. If not provided the default panel is used based on the dataset.
    :param weighting_panel_id: The weighting panel ID for the dataset. If not provided, no panel ID is passed to the paneling step,
        exactly as for the base panel, so the dataset's default panel resolution applies.
    :param panel_overlap_start_date: The minimum date the provided panel IDs overlap, used to generate a fixed window to determine reweighting
    :param panel_overlap_end_date: The maximum date the provided panel IDs overlap, used to generate a fixed window to determine reweighting
    :param reweight_min: Overall minimum value for the reweight column
    :param reweight_max: Overall maximum value for the reweight column
    :returns: A dataframe of transactions for the given dataset, with paneling applied and a reweight column added to indicate the adjustment.

    Examples
    ^^^^^^^^^^

    .. code-block:: python
        :caption: Example of applying card paneling reweighting with default panels

        from etl_toolkit import A

        df = spark.table("yd_production.tost_silver.txns_parsed")

        panel_df = A.add_card_paneling_reweighted(
            df,
            dataset="yoda",
        )

    .. code-block:: python
        :caption: Example of applying card paneling reweighting with custom panel IDs

        from etl_toolkit import A
        from datetime import date

        df = spark.table("yd_production.tost_silver.txns_parsed")

        panel_df = A.add_card_paneling_reweighted(
            df,
            dataset="yoda",
            base_panel_id="fixed_201901_202203_100",
            weighting_panel_id="fixed_202101_202403_100",
            panel_overlap_start_date=date(2021,1,1),
            panel_overlap_end_date=date(2021,12,31)
        )
    """

    def _apply_panel_weights(paneled_df: DataFrame, amount: Column) -> Column:
        # Multiply ``amount`` by whichever panel weight columns exist on
        # ``paneled_df``. add_card_paneling adds no weight columns for Mando,
        # so referencing them unconditionally would fail to resolve there.
        weighted = amount
        if "geo_weight" in paneled_df.columns:
            weighted = weighted * F.col("geo_weight")
        if "card_type_weight" in paneled_df.columns:
            weighted = weighted * F.coalesce(F.col("card_type_weight"), F.lit(1.0))
        return weighted

    base_panel_ids = [base_panel_id] if base_panel_id else None
    weighting_panel_ids = [weighting_panel_id] if weighting_panel_id else None

    # Both configurations are resolved before any paneling work is started.
    # The results are not needed here; the calls are kept so that any problem
    # the lookup reports still surfaces up front -- presumably it validates
    # the dataset/panel combination, TODO confirm.
    get_table_configuration_for_dataset(
        dataset,
        panel_ids=base_panel_ids,
    )
    get_table_configuration_for_dataset(
        dataset,
        panel_ids=weighting_panel_ids,
    )

    # Calculate base paneled df
    base_paneled_df = add_card_paneling(
        df,
        dataset=dataset,
        panel_ids=base_panel_ids,
        add_geo_weights=True,
        add_income_weights=False,
        add_card_type_weights=True,
        qa=False,
    )

    # Calculate weighting paneled df
    weighting_paneled_df = add_card_paneling(
        df,
        dataset=dataset,
        panel_ids=weighting_panel_ids,
        add_geo_weights=True,
        add_income_weights=False,
        add_card_type_weights=True,
        qa=False,
    )

    # Column names differ between Skywalker and the other datasets
    date_col = "date" if dataset == "skywalker" else "trans_date"
    amount_col = "amount" if dataset == "skywalker" else "trans_amount"
    raw_amount_col = "amount_raw" if dataset == "skywalker" else "trans_amount_raw"

    # Only the overlap window is used to determine reweighting
    overlap_period_filter = E.between(
        F.col(date_col), panel_overlap_start_date, panel_overlap_end_date
    )

    # Calculate GMV for base panel during overlap period with panel weights applied
    base_panel_gmv = (
        base_paneled_df.withColumn(
            "adjusted_transaction_amount",
            _apply_panel_weights(base_paneled_df, F.col(amount_col)),
        )
        .filter(overlap_period_filter)
        .groupBy(F.col("yd_tag_merchant"))
        .agg(F.sum("adjusted_transaction_amount").alias("gmv_base"))
        .alias("base")
    )

    # Calculate GMV for weighting panel during overlap period with panel weights applied
    weighting_panel_gmv = (
        weighting_paneled_df.withColumn(
            "adjusted_transaction_amount",
            _apply_panel_weights(weighting_paneled_df, F.col(amount_col)),
        )
        .filter(overlap_period_filter)
        .groupBy(F.col("yd_tag_merchant"))
        .agg(F.sum("adjusted_transaction_amount").alias("gmv_reweight"))
        .alias("weight")
    )

    # Ratio of weighting-panel GMV to base-panel GMV; a null ratio (e.g. the
    # merchant is missing from the weighting panel) defaults to 1.0
    reweight_factor_raw = F.coalesce(
        F.col("gmv_reweight") / F.col("gmv_base"), F.lit(1.0)
    )
    # Clamp the ratio into [reweight_min, reweight_max]
    reweight_factor = E.chain_cases(
        [
            E.case(reweight_factor_raw > reweight_max, reweight_max),
            E.case(reweight_factor_raw < reweight_min, reweight_min),
        ],
        otherwise=reweight_factor_raw,
    )

    # Left join keeps every merchant seen in the base panel's overlap window
    reweight_factors = (
        base_panel_gmv.join(
            weighting_panel_gmv,
            E.all(
                [
                    F.col("base.yd_tag_merchant") == F.col("weight.yd_tag_merchant"),
                ]
            ),
            "left",
        )
        .select(
            F.col("base.yd_tag_merchant").alias("factor_yd_tag_merchant"),
            reweight_factor.alias("reweight_factor"),
        )
        .alias("factors")
    )

    # Store original amount and apply all weights. Merchants with no factor
    # (no base-panel transactions in the overlap window) get 1.0.
    final_df = (
        base_paneled_df.alias("base_df")
        .join(
            reweight_factors,
            E.all(
                [
                    F.col("base_df.yd_tag_merchant") == F.col("factor_yd_tag_merchant"),
                ]
            ),
            how="left",
        )
        .withColumn("reweight_factor", F.coalesce(F.col("reweight_factor"), F.lit(1.0)))
        .withColumn(raw_amount_col, F.col(amount_col))
        .withColumn(
            amount_col,
            _apply_panel_weights(base_paneled_df, F.col(raw_amount_col))
            * F.col("reweight_factor"),
        )
        .drop("factor_yd_tag_merchant")
    )

    return final_df