# Source code for etl_toolkit.analyses.ereceipts.source

from datetime import date
from enum import Enum
from typing import List, Literal, Optional

from pyspark.sql import DataFrame, functions as F, Column, Window as W
from pyspark.sql.types import StringType
from yipit_databricks_utils.helpers.pyspark_utils import get_spark_session
from yipit_databricks_utils.helpers.telemetry import track_usage

from etl_toolkit import expressions as E
from etl_toolkit.exceptions import InvalidInputException
from etl_toolkit.analyses.ereceipts.core import get_ereceipts_table_configuration
from etl_toolkit.analyses.ereceipts.constants import (
    EDISON_RIDES_ALLOWED_VENDORS,
    EDISON_WALMART_ALLOWED_VENDORS,
    PRE_JOINED_USER_COLUMNS,
    USER_ENRICHMENT_COLUMNS,
    USER_DEMO_COLUMNS,
    local_email_time_logic,
)


def _validate_vendor_filter(vendor_list: list[[str]]) -> None:
    if vendor_list is not None and not isinstance(vendor_list, list):
        raise InvalidInputException(
            f"Vendor list must be a list of string(s) with vendor name"
        )
    if isinstance(vendor_list, list):
        for vendor in vendor_list:
            if not isinstance(vendor, str):
                raise InvalidInputException(
                    f"Vendor list must contain string(s) only, but got {vendor}"
                )
    if vendor_list is None or vendor_list == []:
        print(
            f"Warning: Accessing the entire table without filtering for vendors may cause performance issues. Please provide a list of vendors instead."
        )


def _validate_receipt_vendors(receipt_type: str, vendor_list: list[str]) -> None:
    allowed_vendor_list = []
    match receipt_type:
        case "rides":
            allowed_vendor_list = EDISON_RIDES_ALLOWED_VENDORS
        case "walmart":
            allowed_vendor_list = EDISON_WALMART_ALLOWED_VENDORS
        case _:
            raise InvalidInputException(
                f"Allowed vendors not defined for receipt type '{receipt_type}'"
            )
    for vendor in vendor_list:
        if vendor.lower() not in [x.lower() for x in allowed_vendor_list]:
            raise InvalidInputException(
                f"Vendor '{vendor}' is not a valid vendor for {receipt_type} receipts. Valid vendors are: {allowed_vendor_list}"
            )


def _validate_country_filter(country) -> None:
    if country is not None and not isinstance(country, str):
        raise InvalidInputException(
            f"Expected string for country, got {type(country).__name__}"
        )


def _validate_include_duplicates_flag(config, include_duplicates) -> None:
    if include_duplicates is not None:
        if not isinstance(include_duplicates, bool):
            raise InvalidInputException(
                f"Expected bool for include_duplicates, got {type(include_duplicates).__name__}"
            )
        if (
            include_duplicates
            and "with_dupes" not in list(config.source_tables.keys())
            and "dupes" not in list(config.source_tables.keys())
        ):
            raise InvalidInputException(
                f"Receipt type '{config.receipt_type}' does not support duplicate tables. Please remove this argument or set it to False"
            )


def _validate_granularity(config, granularity) -> None:
    if granularity is not None:
        if not isinstance(granularity, str):
            raise InvalidInputException(
                f"Expected string for granularity, got {type(granularity).__name__}"
            )
        if granularity not in ["orders", "items"]:
            raise InvalidInputException(
                f"granularity must be one of 'orders' or 'items'"
            )
        if config.receipt_type == "clean_headers":
            raise InvalidInputException(
                f"Receipt type '{config.receipt_type}' does not support granularity. Please remove this argument"
            )
        if granularity == "items" and config.unique_identifier != "item_checksum":
            raise InvalidInputException(
                f"Receipt type '{config.receipt_type}' does not support 'items' granularity. Please remove this argument or set it to 'orders'"
            )
        if granularity == "orders" and config.unique_identifier not in [
            "checksum",
            "item_checksum",
        ]:
            raise InvalidInputException(
                f"Receipt type '{config.receipt_type}' does not support 'orders' granularity. Please remove this argument or set it to 'items'"
            )


def _validate_add_app_id_flag(config, add_app_id) -> None:
    if add_app_id is not None:
        if not isinstance(add_app_id, bool):
            raise InvalidInputException(
                f"Expected bool for add_app_id, got {type(add_app_id).__name__}"
            )
        if add_app_id and config.unique_identifier != "item_checksum":
            raise InvalidInputException(
                f"App ID joining is only supported for tables with item_checksum, but unique identifier is {config.unique_identifier}"
            )


def _validate_add_demo_columns_flag(config, add_demo_columns) -> None:
    if add_demo_columns is not None:
        if not isinstance(add_demo_columns, bool):
            raise InvalidInputException(
                f"Expected bool for add_demo_columns, got {type(add_demo_columns).__name__}"
            )


def _validate_source_ereceipts_inputs(
    config: dict,
    receipt_type: str,
    vendor_list: list[str],
    country: str,
    include_duplicates: bool,
    granularity: str,
    add_app_id: bool,
    add_demo_columns: bool,
):
    """Run every input validation for :func:`source_ereceipts`.

    ``config`` is the object returned by ``get_ereceipts_table_configuration``.
    NOTE(review): it is annotated ``dict`` here but the helper validators
    access it by attribute (``config.receipt_type`` etc.) — confirm the
    real type.

    Raises ``InvalidInputException`` (from the individual validators, or
    directly for the duplicates/granularity combination) on the first
    failed check; returns ``None`` when everything is valid.
    """
    # Validate vendor_list
    _validate_vendor_filter(vendor_list)
    # rides/walmart additionally restrict vendors to an allow-list
    if receipt_type in ["rides", "walmart"]:
        _validate_receipt_vendors(receipt_type, vendor_list)

    # Validate country filter
    _validate_country_filter(country)

    # Validate include_duplicates
    _validate_include_duplicates_flag(config, include_duplicates)

    # Validate granularity
    _validate_granularity(config, granularity)

    # Validate include_duplicates and granularity
    if include_duplicates and granularity == "orders":
        raise InvalidInputException(
            f"Order-level deduping is not supported for duplicate tables"
        )

    # Validate add_app_id
    _validate_add_app_id_flag(config, add_app_id)

    # Validate add_demo_columns
    _validate_add_demo_columns_flag(config, add_demo_columns)


def _alert_if_last_processed_date_mismatch(
    df_source_main: DataFrame, df_source_dupes: DataFrame, config
) -> None:
    """Print a warning when the main and duplicates tables were last
    processed on different dates (max ``email_timestamp`` as a date).

    Fixed annotation: ``config`` was annotated ``dict`` but is accessed by
    attribute (``config.receipt_type``) — it is the receipt configuration
    object. Note: this runs two Spark actions (one per table).
    """

    def _max_email_date(df: DataFrame):
        # Latest email_timestamp in the table, truncated to a date.
        return df.select(F.to_date(F.max("email_timestamp")).alias("d")).first()["d"]

    main_max_dt = _max_email_date(df_source_main)
    dupe_max_dt = _max_email_date(df_source_dupes)
    if main_max_dt != dupe_max_dt:
        print(
            f"Main {config.receipt_type} table and duplicates have different max dates. Main {config.receipt_type} table: {main_max_dt}, Duplicates {config.receipt_type} table: {dupe_max_dt}"
        )


def _filter_for_vendor_list_if_present(
    df: DataFrame,
    vendor_column_name: str,
    vendor_list: list[str],
) -> DataFrame:
    """Filter ``df`` to the given vendors, case-insensitively.

    ``None`` or an empty list means "no filter" and returns ``df`` untouched.
    """
    if not vendor_list:
        return df
    lowered_vendors = [vendor.lower() for vendor in vendor_list]
    return df.filter(F.lower(vendor_column_name).isin(lowered_vendors))


@track_usage
[docs] def source_ereceipts( receipt_type: Literal[ "clean_headers", "purchase", "flights", "hotels", "rental_cars", "rides", "subscriptions", "walmart", ], vendor_list: list[str] = None, country: Optional[ Literal["us", "intl", "ca", "all", "US", "INTL", "CA", "ALL"] ] = None, include_duplicates: Optional[bool] = False, granularity: Literal["orders", "items"] = None, add_app_id: Optional[bool] = False, add_demo_columns: Optional[bool] = False, ) -> DataFrame: """ Fetch a DataFrame of Edison e-receipts with filtering and transformations. This function standardizes e-receipts access, ensuring consistency across modules. .. tip:: Use this function instead of directly accessing source tables for consistent filtering and schema handling. :param receipt_type: Type of receipt data to retrieve. Supported types: - `clean_headers` - `purchase` - `subscriptions` - `rental_cars` - `hotels` - `flights` - `rides` - `walmart` :param vendor_list: List of vendors to filter. Must be non-empty. :param country: Optional. Valid values: `ALL`, `US`, `CA`, `INTL`. Default varies by type. :param include_duplicates: Include duplicate entries. Default: `False`. Some types disallow duplicates. :param granularity: Specify granularity level. Valid: `items`, `orders`. Default depends on receipt type. :param add_app_id: Include application ID. Default: `False`. Supported only for certain receipt types. :param add_demo_columns: Add demographic columns. Default: `False`. :return: Filtered e-receipts DataFrame. :raises InvalidInputException: On validation failure. ### Examples ^^^^^^^^^^^^ #### Clean Headers Retrieve receipts for specific vendors in the US: .. code-block:: python clean_headers_df = A.source_ereceipts( receipt_type="clean_headers", vendor_list=["Uber", "Walmart"], # at least 1 vendor required country="ALL" # defaults to ALL for clean headers ) #### Including Duplicates Allow duplicates (if supported): .. 
code-block:: python purchase_df = A.source_ereceipts( receipt_type="purchase", vendor_list=["Walmart"], country="US" # defaults to US for purchase include_duplicates=True # defaults to False for all receipt types ) #### Adjusting Granularity Retrieve order-level data: .. code-block:: python subscriptions_df = A.source_ereceipts( receipt_type="purchase", vendor_list=["UberEats"], granularity="orders" # defaults to "items" for purchase ) #### Adding Extra Columns Include app IDs and user demographic columns: .. code-block:: python purchase_with_details = A.source_ereceipts( receipt_type="purchase", vendor_list=["Walmart"], add_app_id=True, add_demo_columns=True ) Validation Rules ================ .. list-table:: Validation Rules :header-rows: 1 * - Rule - Description * - **Receipt Type** - Must be valid. Raises `InvalidInputException` for invalid types. * - **Vendor List** - Non-empty list of strings. Errors on empty or invalid input. * - **Country** - Valid values: `ALL`, `US`, `CA`, `INTL`. - Defaults to `US` for `purchase` * - **Include Duplicates** - Boolean. Unsupported types error if `True`. * - **Granularity** - `items` or `orders`. Unsupported types raise errors. * - **Add App ID** - Supported for specific types only. Raises error if unsupported. * - **Add Demo Columns** - Boolean. Supported for all types. ### Receipt Type-Specific Behavior .. 
code-block:: text +-------------------+-----------------+------------+-----------------+--------+--------------+ | Receipt Type | country | duplicates | granularity | app ID | demo columns | +-------------------+-----------------+------------+-----------------+--------+--------------+ | clean_headers | ALL | No | N/A | No | Yes | +-------------------+-----------------+------------+-----------------+--------+--------------+ | purchase | US | Yes | items, orders | Yes | Yes | +-------------------+-----------------+------------+-----------------+--------+--------------+ | subscriptions | ALL | No | items, orders | Yes | Yes | +-------------------+-----------------+------------+-----------------+--------+--------------+ | rental_cars | ALL | Yes | orders | No | Yes | +-------------------+-----------------+------------+-----------------+--------+--------------+ | hotels | ALL | Yes | orders | No | Yes | +-------------------+-----------------+------------+-----------------+--------+--------------+ | flights | ALL | Yes | items | No | Yes | +-------------------+-----------------+------------+-----------------+--------+--------------+ | rides | ALL | No | items, orders | No | Yes | +-------------------+-----------------+------------+-----------------+--------+--------------+ | walmart | ALL | Yes | items, orders | No | Yes | +-------------------+-----------------+------------+-----------------+--------+--------------+ """ spark = get_spark_session() # Get receipt configurations config = get_ereceipts_table_configuration( receipt_type=receipt_type, ) # Validate inputs _validate_source_ereceipts_inputs( config, receipt_type, vendor_list, country, include_duplicates, granularity, add_app_id, add_demo_columns, ) # List source data tables df_main = _filter_for_vendor_list_if_present( spark.table(config.source_tables["main"]), config.vendor_column_name, vendor_list, ) if include_duplicates: if "with_dupes" in config.source_tables.keys(): df_with_dupes = 
_filter_for_vendor_list_if_present( spark.table(config.source_tables["with_dupes"]), config.vendor_column_name, vendor_list, ) elif "dupes" in config.source_tables.keys(): df_dupes = _filter_for_vendor_list_if_present( spark.table(config.source_tables["dupes"]), config.vendor_column_name, vendor_list, ) else: # Initial validation should already handle this, but just in case raise InvalidInputException( f"Duplicate tables are not supported for receipt type '{receipt_type}'" ) # Define "df" to transform. This is the ultimate return value if include_duplicates: if "with_dupes" in config.source_tables.keys(): df = df_with_dupes.withColumn( # matching product schema config.duplicate_flag_column_name, F.when( F.col("in_duplicate_table"), F.lit("1").cast(StringType()), ).when( ~F.col("in_duplicate_table"), F.lit(None).cast(StringType()), ), ) elif "dupes" in config.source_tables.keys(): _alert_if_last_processed_date_mismatch(df_main, df_dupes, config) df = ( df_main.withColumns( { "in_duplicate_table": F.lit(False), config.duplicate_flag_column_name: F.lit("1").cast( StringType() ), } ) ).unionByName( df_dupes.withColumns( { "in_duplicate_table": F.lit(True), config.duplicate_flag_column_name: F.lit(None).cast( StringType() ), } ) ) else: df = df_main # Dedupe source data if orders granularity = granularity or config.granularity if granularity == "orders" and receipt_type == "purchase": df = _filter_for_vendor_list_if_present( spark.table(config.source_tables["orders"]), config.vendor_column_name, vendor_list, ) elif granularity == "orders": df = ( df.withColumn( "rn", F.row_number().over( W.partitionBy(["checksum"]).orderBy(F.desc("update_timestamp")) ), ) .filter(F.col("rn") == 1) .drop("rn") ) # Filter on country country = country or config.default_country match country.upper(): case "ALL": df = df case "INTL": df = df.filter( E.any( F.col("derived_user_country") != F.lit("US"), F.col("derived_user_country").isNull(), ) ) case "US": df = 
df.filter(F.col("derived_user_country") == "US") case "CA": df = df.filter(F.col("derived_user_country") == "CA") case _: raise InvalidInputException( f"Provided country ({country}) is not supported in this function. Country must be one of 'US', 'ALL', 'CA', or 'INTL'" ) # Apply product transformation if config.transformation_fn: df = config.transformation_fn(df, include_duplicates, config) # Join app ID if add_app_id: app_df = spark.table(config.app_id_mapping_table_name).select( "item_checksum", "app_id", "app_release_date", "app_name", F.col("category_name").alias("app_category"), ) df = df.join( app_df, on=["item_checksum"], how="left", ) # For clean headers, user joining is done upstream # Re-join user if adding demo cols (less frequently used) if receipt_type == "clean_headers" and not (add_demo_columns): return df # Join user ID & optionally remove demo columns # Note that DST and geo weighting columns are dervied from user demo columns df = df.drop( *PRE_JOINED_USER_COLUMNS, *USER_ENRICHMENT_COLUMNS, *USER_DEMO_COLUMNS, ) df_user = spark.table(config.user_table_name).select( "user_id", *USER_ENRICHMENT_COLUMNS, *USER_DEMO_COLUMNS, ) df = df.alias("a").join( df_user.alias("b"), on=["user_id"], how="inner", ) # Add DST column, derived from user demo if receipt_type in [ "purchase", "rides", "walmart", "subscriptions", ]: df = df.select( "*", local_email_time_logic.alias("local_email_time"), ) # Add geo weighting columns, dervied from user demo if receipt_type in ["purchase"]: metro_weights = spark.table(config.geo_weight_table_name).alias("weights") df = ( df.alias("a") .join( metro_weights.alias("b"), E.all( F.col("a.order_date") == F.col("b.order_date"), F.col("a.metro_name") == F.col("b.metro_name"), ), how="left", ) .select( "a.*", "b.central_geo_weight_metro_weekly", "b.central_geo_weight_metro_monthly", "b.central_geo_weight_metro_quarterly", ) ) # Optionally remove user demo columns once dst & geo weighting columns are added if not add_demo_columns: df 
= df.drop(*USER_DEMO_COLUMNS) return df