from datetime import date
from enum import Enum
from typing import List, Literal, Optional
from pyspark.sql import DataFrame, functions as F, Column, Window as W
from pyspark.sql.types import StringType
from yipit_databricks_utils.helpers.pyspark_utils import get_spark_session
from yipit_databricks_utils.helpers.telemetry import track_usage
from etl_toolkit import expressions as E
from etl_toolkit.exceptions import InvalidInputException
from etl_toolkit.analyses.ereceipts.core import get_ereceipts_table_configuration
from etl_toolkit.analyses.ereceipts.constants import (
EDISON_RIDES_ALLOWED_VENDORS,
EDISON_WALMART_ALLOWED_VENDORS,
PRE_JOINED_USER_COLUMNS,
USER_ENRICHMENT_COLUMNS,
USER_DEMO_COLUMNS,
local_email_time_logic,
)
def _validate_vendor_filter(vendor_list: list[[str]]) -> None:
if vendor_list is not None and not isinstance(vendor_list, list):
raise InvalidInputException(
f"Vendor list must be a list of string(s) with vendor name"
)
if isinstance(vendor_list, list):
for vendor in vendor_list:
if not isinstance(vendor, str):
raise InvalidInputException(
f"Vendor list must contain string(s) only, but got {vendor}"
)
if vendor_list is None or vendor_list == []:
print(
f"Warning: Accessing the entire table without filtering for vendors may cause performance issues. Please provide a list of vendors instead."
)
def _validate_receipt_vendors(receipt_type: str, vendor_list: list[str]) -> None:
allowed_vendor_list = []
match receipt_type:
case "rides":
allowed_vendor_list = EDISON_RIDES_ALLOWED_VENDORS
case "walmart":
allowed_vendor_list = EDISON_WALMART_ALLOWED_VENDORS
case _:
raise InvalidInputException(
f"Allowed vendors not defined for receipt type '{receipt_type}'"
)
for vendor in vendor_list:
if vendor.lower() not in [x.lower() for x in allowed_vendor_list]:
raise InvalidInputException(
f"Vendor '{vendor}' is not a valid vendor for {receipt_type} receipts. Valid vendors are: {allowed_vendor_list}"
)
def _validate_country_filter(country) -> None:
if country is not None and not isinstance(country, str):
raise InvalidInputException(
f"Expected string for country, got {type(country).__name__}"
)
def _validate_include_duplicates_flag(config, include_duplicates) -> None:
if include_duplicates is not None:
if not isinstance(include_duplicates, bool):
raise InvalidInputException(
f"Expected bool for include_duplicates, got {type(include_duplicates).__name__}"
)
if (
include_duplicates
and "with_dupes" not in list(config.source_tables.keys())
and "dupes" not in list(config.source_tables.keys())
):
raise InvalidInputException(
f"Receipt type '{config.receipt_type}' does not support duplicate tables. Please remove this argument or set it to False"
)
def _validate_granularity(config, granularity) -> None:
if granularity is not None:
if not isinstance(granularity, str):
raise InvalidInputException(
f"Expected string for granularity, got {type(granularity).__name__}"
)
if granularity not in ["orders", "items"]:
raise InvalidInputException(
f"granularity must be one of 'orders' or 'items'"
)
if config.receipt_type == "clean_headers":
raise InvalidInputException(
f"Receipt type '{config.receipt_type}' does not support granularity. Please remove this argument"
)
if granularity == "items" and config.unique_identifier != "item_checksum":
raise InvalidInputException(
f"Receipt type '{config.receipt_type}' does not support 'items' granularity. Please remove this argument or set it to 'orders'"
)
if granularity == "orders" and config.unique_identifier not in [
"checksum",
"item_checksum",
]:
raise InvalidInputException(
f"Receipt type '{config.receipt_type}' does not support 'orders' granularity. Please remove this argument or set it to 'items'"
)
def _validate_add_app_id_flag(config, add_app_id) -> None:
if add_app_id is not None:
if not isinstance(add_app_id, bool):
raise InvalidInputException(
f"Expected bool for add_app_id, got {type(add_app_id).__name__}"
)
if add_app_id and config.unique_identifier != "item_checksum":
raise InvalidInputException(
f"App ID joining is only supported for tables with item_checksum, but unique identifier is {config.unique_identifier}"
)
def _validate_add_demo_columns_flag(config, add_demo_columns) -> None:
if add_demo_columns is not None:
if not isinstance(add_demo_columns, bool):
raise InvalidInputException(
f"Expected bool for add_demo_columns, got {type(add_demo_columns).__name__}"
)
def _validate_source_ereceipts_inputs(
    config,
    receipt_type: str,
    vendor_list: Optional[list[str]],
    country: Optional[str],
    include_duplicates: Optional[bool],
    granularity: Optional[str],
    add_app_id: Optional[bool],
    add_demo_columns: Optional[bool],
):
    """Run all input validations for :func:`source_ereceipts`.

    ``config`` is the receipt-type configuration object (not a dict — the
    previous ``config: dict`` annotation was wrong: attribute access like
    ``config.receipt_type`` is used throughout).

    :raises InvalidInputException: if any argument fails validation.
    """
    # Validate vendor_list
    _validate_vendor_filter(vendor_list)
    # Only vendor-restricted receipt types have an allow-list, and a None/empty
    # vendor_list has nothing to check — previously a None vendor_list crashed
    # inside _validate_receipt_vendors with a TypeError.
    if vendor_list and receipt_type in ["rides", "walmart"]:
        _validate_receipt_vendors(receipt_type, vendor_list)
    # Validate country filter
    _validate_country_filter(country)
    # Validate include_duplicates
    _validate_include_duplicates_flag(config, include_duplicates)
    # Validate granularity
    _validate_granularity(config, granularity)
    # Duplicate rows cannot be deduped to order level, regardless of receipt type.
    if include_duplicates and granularity == "orders":
        raise InvalidInputException(
            f"Order-level deduping is not supported for duplicate tables"
        )
    # Validate add_app_id
    _validate_add_app_id_flag(config, add_app_id)
    # Validate add_demo_columns
    _validate_add_demo_columns_flag(config, add_demo_columns)
def _alert_if_last_processed_date_mismatch(
    df_source_main: DataFrame, df_source_dupes: DataFrame, config: dict
) -> None:
    """Warn (via print) when the main and duplicates tables have different max email dates."""

    def _latest_email_date(df: DataFrame, alias: str):
        # Max email_timestamp truncated to a date; collect() pulls the single aggregate row.
        return df.select(
            F.to_date(F.max("email_timestamp")).alias(alias)
        ).collect()[0][alias]

    main_max_dt = _latest_email_date(df_source_main, "main_max_dt")
    dupe_max_dt = _latest_email_date(df_source_dupes, "dupes_max_dt")
    if main_max_dt != dupe_max_dt:
        print(
            f"Main {config.receipt_type} table and duplicates have different max dates. Main {config.receipt_type} table: {main_max_dt}, Duplicates {config.receipt_type} table: {dupe_max_dt}"
        )
def _filter_for_vendor_list_if_present(
    df: DataFrame,
    vendor_column_name: str,
    vendor_list: list[str],
) -> DataFrame:
    """Case-insensitively restrict ``df`` to rows whose vendor column matches ``vendor_list``.

    Returns ``df`` unchanged when ``vendor_list`` is ``None`` or empty.
    """
    if vendor_list is None or vendor_list == []:
        return df
    lowered_vendors = [v.lower() for v in vendor_list]
    return df.filter(F.lower(vendor_column_name).isin(lowered_vendors))
@track_usage
[docs]
def source_ereceipts(
receipt_type: Literal[
"clean_headers",
"purchase",
"flights",
"hotels",
"rental_cars",
"rides",
"subscriptions",
"walmart",
],
vendor_list: list[str] = None,
country: Optional[
Literal["us", "intl", "ca", "all", "US", "INTL", "CA", "ALL"]
] = None,
include_duplicates: Optional[bool] = False,
granularity: Literal["orders", "items"] = None,
add_app_id: Optional[bool] = False,
add_demo_columns: Optional[bool] = False,
) -> DataFrame:
"""
Fetch a DataFrame of Edison e-receipts with filtering and transformations. This function standardizes e-receipts access, ensuring consistency across modules.
.. tip:: Use this function instead of directly accessing source tables for consistent filtering and schema handling.
:param receipt_type: Type of receipt data to retrieve. Supported types:
- `clean_headers`
- `purchase`
- `subscriptions`
- `rental_cars`
- `hotels`
- `flights`
- `rides`
- `walmart`
:param vendor_list: List of vendors to filter. Must be non-empty.
:param country: Optional. Valid values: `ALL`, `US`, `CA`, `INTL`. Default varies by type.
:param include_duplicates: Include duplicate entries. Default: `False`. Some types disallow duplicates.
:param granularity: Specify granularity level. Valid: `items`, `orders`. Default depends on receipt type.
:param add_app_id: Include application ID. Default: `False`. Supported only for certain receipt types.
:param add_demo_columns: Add demographic columns. Default: `False`.
:return: Filtered e-receipts DataFrame.
:raises InvalidInputException: On validation failure.
### Examples
^^^^^^^^^^^^
#### Clean Headers
Retrieve receipts for specific vendors in the US:
.. code-block:: python
clean_headers_df = A.source_ereceipts(
receipt_type="clean_headers",
vendor_list=["Uber", "Walmart"], # at least 1 vendor required
country="ALL" # defaults to ALL for clean headers
)
#### Including Duplicates
Allow duplicates (if supported):
.. code-block:: python
purchase_df = A.source_ereceipts(
receipt_type="purchase",
vendor_list=["Walmart"],
country="US" # defaults to US for purchase
include_duplicates=True # defaults to False for all receipt types
)
#### Adjusting Granularity
Retrieve order-level data:
.. code-block:: python
subscriptions_df = A.source_ereceipts(
receipt_type="purchase",
vendor_list=["UberEats"],
granularity="orders" # defaults to "items" for purchase
)
#### Adding Extra Columns
Include app IDs and user demographic columns:
.. code-block:: python
purchase_with_details = A.source_ereceipts(
receipt_type="purchase",
vendor_list=["Walmart"],
add_app_id=True,
add_demo_columns=True
)
Validation Rules
================
.. list-table:: Validation Rules
:header-rows: 1
* - Rule
- Description
* - **Receipt Type**
- Must be valid. Raises `InvalidInputException` for invalid types.
* - **Vendor List**
- Non-empty list of strings. Errors on empty or invalid input.
* - **Country**
- Valid values: `ALL`, `US`, `CA`, `INTL`.
- Defaults to `US` for `purchase`
* - **Include Duplicates**
- Boolean. Unsupported types error if `True`.
* - **Granularity**
- `items` or `orders`. Unsupported types raise errors.
* - **Add App ID**
- Supported for specific types only. Raises error if unsupported.
* - **Add Demo Columns**
- Boolean. Supported for all types.
### Receipt Type-Specific Behavior
.. code-block:: text
+-------------------+-----------------+------------+-----------------+--------+--------------+
| Receipt Type | country | duplicates | granularity | app ID | demo columns |
+-------------------+-----------------+------------+-----------------+--------+--------------+
| clean_headers | ALL | No | N/A | No | Yes |
+-------------------+-----------------+------------+-----------------+--------+--------------+
| purchase | US | Yes | items, orders | Yes | Yes |
+-------------------+-----------------+------------+-----------------+--------+--------------+
| subscriptions | ALL | No | items, orders | Yes | Yes |
+-------------------+-----------------+------------+-----------------+--------+--------------+
| rental_cars | ALL | Yes | orders | No | Yes |
+-------------------+-----------------+------------+-----------------+--------+--------------+
| hotels | ALL | Yes | orders | No | Yes |
+-------------------+-----------------+------------+-----------------+--------+--------------+
| flights | ALL | Yes | items | No | Yes |
+-------------------+-----------------+------------+-----------------+--------+--------------+
| rides | ALL | No | items, orders | No | Yes |
+-------------------+-----------------+------------+-----------------+--------+--------------+
| walmart | ALL | Yes | items, orders | No | Yes |
+-------------------+-----------------+------------+-----------------+--------+--------------+
"""
spark = get_spark_session()
# Get receipt configurations
config = get_ereceipts_table_configuration(
receipt_type=receipt_type,
)
# Validate inputs
_validate_source_ereceipts_inputs(
config,
receipt_type,
vendor_list,
country,
include_duplicates,
granularity,
add_app_id,
add_demo_columns,
)
# List source data tables
df_main = _filter_for_vendor_list_if_present(
spark.table(config.source_tables["main"]),
config.vendor_column_name,
vendor_list,
)
if include_duplicates:
if "with_dupes" in config.source_tables.keys():
df_with_dupes = _filter_for_vendor_list_if_present(
spark.table(config.source_tables["with_dupes"]),
config.vendor_column_name,
vendor_list,
)
elif "dupes" in config.source_tables.keys():
df_dupes = _filter_for_vendor_list_if_present(
spark.table(config.source_tables["dupes"]),
config.vendor_column_name,
vendor_list,
)
else: # Initial validation should already handle this, but just in case
raise InvalidInputException(
f"Duplicate tables are not supported for receipt type '{receipt_type}'"
)
# Define "df" to transform. This is the ultimate return value
if include_duplicates:
if "with_dupes" in config.source_tables.keys():
df = df_with_dupes.withColumn( # matching product schema
config.duplicate_flag_column_name,
F.when(
F.col("in_duplicate_table"),
F.lit("1").cast(StringType()),
).when(
~F.col("in_duplicate_table"),
F.lit(None).cast(StringType()),
),
)
elif "dupes" in config.source_tables.keys():
_alert_if_last_processed_date_mismatch(df_main, df_dupes, config)
df = (
df_main.withColumns(
{
"in_duplicate_table": F.lit(False),
config.duplicate_flag_column_name: F.lit("1").cast(
StringType()
),
}
)
).unionByName(
df_dupes.withColumns(
{
"in_duplicate_table": F.lit(True),
config.duplicate_flag_column_name: F.lit(None).cast(
StringType()
),
}
)
)
else:
df = df_main
# Dedupe source data if orders
granularity = granularity or config.granularity
if granularity == "orders" and receipt_type == "purchase":
df = _filter_for_vendor_list_if_present(
spark.table(config.source_tables["orders"]),
config.vendor_column_name,
vendor_list,
)
elif granularity == "orders":
df = (
df.withColumn(
"rn",
F.row_number().over(
W.partitionBy(["checksum"]).orderBy(F.desc("update_timestamp"))
),
)
.filter(F.col("rn") == 1)
.drop("rn")
)
# Filter on country
country = country or config.default_country
match country.upper():
case "ALL":
df = df
case "INTL":
df = df.filter(
E.any(
F.col("derived_user_country") != F.lit("US"),
F.col("derived_user_country").isNull(),
)
)
case "US":
df = df.filter(F.col("derived_user_country") == "US")
case "CA":
df = df.filter(F.col("derived_user_country") == "CA")
case _:
raise InvalidInputException(
f"Provided country ({country}) is not supported in this function. Country must be one of 'US', 'ALL', 'CA', or 'INTL'"
)
# Apply product transformation
if config.transformation_fn:
df = config.transformation_fn(df, include_duplicates, config)
# Join app ID
if add_app_id:
app_df = spark.table(config.app_id_mapping_table_name).select(
"item_checksum",
"app_id",
"app_release_date",
"app_name",
F.col("category_name").alias("app_category"),
)
df = df.join(
app_df,
on=["item_checksum"],
how="left",
)
# For clean headers, user joining is done upstream
# Re-join user if adding demo cols (less frequently used)
if receipt_type == "clean_headers" and not (add_demo_columns):
return df
# Join user ID & optionally remove demo columns
# Note that DST and geo weighting columns are dervied from user demo columns
df = df.drop(
*PRE_JOINED_USER_COLUMNS,
*USER_ENRICHMENT_COLUMNS,
*USER_DEMO_COLUMNS,
)
df_user = spark.table(config.user_table_name).select(
"user_id",
*USER_ENRICHMENT_COLUMNS,
*USER_DEMO_COLUMNS,
)
df = df.alias("a").join(
df_user.alias("b"),
on=["user_id"],
how="inner",
)
# Add DST column, derived from user demo
if receipt_type in [
"purchase",
"rides",
"walmart",
"subscriptions",
]:
df = df.select(
"*",
local_email_time_logic.alias("local_email_time"),
)
# Add geo weighting columns, dervied from user demo
if receipt_type in ["purchase"]:
metro_weights = spark.table(config.geo_weight_table_name).alias("weights")
df = (
df.alias("a")
.join(
metro_weights.alias("b"),
E.all(
F.col("a.order_date") == F.col("b.order_date"),
F.col("a.metro_name") == F.col("b.metro_name"),
),
how="left",
)
.select(
"a.*",
"b.central_geo_weight_metro_weekly",
"b.central_geo_weight_metro_monthly",
"b.central_geo_weight_metro_quarterly",
)
)
# Optionally remove user demo columns once dst & geo weighting columns are added
if not add_demo_columns:
df = df.drop(*USER_DEMO_COLUMNS)
return df