Source code for etl_toolkit.analyses.card.source

from datetime import date
from enum import Enum
from typing import List, Literal, Optional

from yipit_databricks_utils.helpers.pyspark_utils import get_spark_session
from yipit_databricks_utils.helpers.telemetry import track_usage
from pyspark.sql import DataFrame, functions as F

from etl_toolkit.exceptions import InvalidInputException
from etl_toolkit.analyses.card.core import get_table_configuration_for_dataset


class CardTransactionType(str, Enum):
    """Enum defining the types of card transactions available across datasets."""

    TRANSACTIONS = "transactions"
    DEPOSITS = "deposits"
    ALL = "all"


def _validate_and_normalize_transaction_types(
    transaction_types: Optional[List[str | CardTransactionType]],
) -> List[CardTransactionType]:
    """
    Validates and normalizes a list of transaction types.

    :param transaction_types: List of transaction types to validate
    :return: List of normalized CardTransactionType enums
    """
    if transaction_types is None:
        return [CardTransactionType.TRANSACTIONS]

    normalized_types = []
    for t in transaction_types:
        if isinstance(t, str):
            try:
                t = CardTransactionType(t.lower())
            except ValueError:
                raise InvalidInputException(
                    f"Invalid transaction type {t!r}, must be one of "
                    f"{[member.value for member in CardTransactionType]}"
                )
        elif not isinstance(t, CardTransactionType):
            raise InvalidInputException(
                f"Transaction type must be string or CardTransactionType enum, got {type(t)}"
            )
        normalized_types.append(t)

    # If ALL is specified, ignore other types
    if CardTransactionType.ALL in normalized_types:
        return [CardTransactionType.ALL]

    return normalized_types
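
# A hedged illustration of the normalization contract (doctest-style, for
# reference only; output reprs assume the standard Enum repr format): mixed
# string/enum inputs are coerced to enums, and the presence of ALL collapses
# the result to just [CardTransactionType.ALL].
#
#     >>> _validate_and_normalize_transaction_types(None)
#     [<CardTransactionType.TRANSACTIONS: 'transactions'>]
#     >>> _validate_and_normalize_transaction_types(["deposits", CardTransactionType.ALL])
#     [<CardTransactionType.ALL: 'all'>]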


def _validate_date_range(start_date: Optional[date], end_date: Optional[date]) -> None:
    """Validates that a date range is logical."""
    if start_date and end_date and start_date > end_date:
        raise InvalidInputException(
            f"Invalid date range: start_date {start_date} must be before or equal to end_date {end_date}"
        )
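
# Failure-mode sketch (for reference only): an inverted range raises, while
# open-ended ranges pass silently.
#
#     >>> from datetime import date
#     >>> _validate_date_range(date(2024, 2, 1), date(2024, 1, 1))
#     Traceback (most recent call last):
#         ...
#     etl_toolkit.exceptions.InvalidInputException: Invalid date range: ...
#     >>> _validate_date_range(date(2024, 1, 1), None)  # no exception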


@track_usage
def source_card_transactions(
    dataset: Literal["skywalker", "mando", "yoda"],
    start_date: Optional[date] = None,
    end_date: Optional[date] = None,
    transaction_types: Optional[List[str | CardTransactionType]] = None,
    group_column_names: Optional[List[str]] = None,
) -> DataFrame:
    """
    Returns a DataFrame of card transactions from the specified source dataset,
    with optional filtering by date range and transaction type.

    This function provides a standardized way to access card transaction data
    across different datasets while handling their unique schemas and storage
    patterns. It supports both single-table datasets (like Mando), where
    transaction types are distinguished by a column value, and multi-table
    datasets (like Skywalker/Yoda), where different transaction types are
    stored in separate tables.

    .. tip::
        When working with card transaction data, it's recommended to use this
        function rather than directly accessing the source tables, to ensure
        consistent filtering and proper handling of dataset-specific schemas.

    :param dataset: The source dataset to retrieve transactions from. Must be
        one of "skywalker", "mando", or "yoda".
    :param start_date: Optional start date filter. If provided, only includes
        transactions on or after this date.
    :param end_date: Optional end date filter. If provided, only includes
        transactions on or before this date.
    :param transaction_types: Optional list of transaction types to include,
        as string values or CardTransactionType enums. Valid values are
        "transactions", "deposits", or "all". Defaults to ["transactions"] if
        not specified. If "all" is included in the list, other types are
        ignored.
    :param group_column_names: Optional list of grouping columns to use for
        this dataset, overriding the dataset defaults.
    :return: A DataFrame containing the filtered card transactions from the
        specified dataset.
    :raises InvalidInputException: If any validation check fails.

    Examples
    ^^^^^^^^

    .. code-block:: python
        :caption: Basic usage to get recent Mando transactions

        from datetime import date
        from etl_toolkit import A

        # Get Mando transactions for January 2024
        transactions_df = A.source_card_transactions(
            dataset="mando",
            start_date=date(2024, 1, 1),
            end_date=date(2024, 1, 31),
        )

    .. code-block:: python
        :caption: Get both transactions and deposits from Skywalker

        from etl_toolkit import A
        from etl_toolkit.analyses.card.source import CardTransactionType

        # Get both Skywalker transaction types using the enum
        all_activity_df = A.source_card_transactions(
            dataset="skywalker",
            transaction_types=[
                CardTransactionType.TRANSACTIONS,
                CardTransactionType.DEPOSITS,
            ],
        )

        # Alternative using string values
        all_activity_df = A.source_card_transactions(
            dataset="skywalker",
            transaction_types=["transactions", "deposits"],
        )

    .. code-block:: python
        :caption: Get all transaction types using the ALL type

        from etl_toolkit import A
        from etl_toolkit.analyses.card.source import CardTransactionType

        # Get all transaction types from Mando
        all_df = A.source_card_transactions(
            dataset="mando",
            transaction_types=[CardTransactionType.ALL],
        )

        # Alternative using the string value
        all_df = A.source_card_transactions(
            dataset="mando",
            transaction_types=["all"],
        )

    .. code-block:: python
        :caption: Get only deposits from a specific date range

        from datetime import date
        from etl_toolkit import A

        # Get Yoda deposits for Q1 2024
        deposits_df = A.source_card_transactions(
            dataset="yoda",
            start_date=date(2024, 1, 1),
            end_date=date(2024, 3, 31),
            transaction_types=["deposits"],
        )
    """
    spark = get_spark_session()

    # Validate inputs
    _validate_date_range(start_date, end_date)
    normalized_types = _validate_and_normalize_transaction_types(transaction_types)

    # Get dataset configuration and data
    config = get_table_configuration_for_dataset(
        dataset=dataset, group_column_names=group_column_names
    )

    if config.uses_combined_table or CardTransactionType.ALL in normalized_types:
        # Use the single combined table with optional type filtering
        df = spark.table(config.get_source_table(normalized_types))
        type_filter = config.get_transaction_filter(normalized_types)
        if type_filter is not None:
            df = df.where(type_filter)
    else:
        # Union the separate per-type tables. unionByName accepts a single
        # DataFrame at a time, so fold over the list pairwise.
        dfs = []
        for txn_type in normalized_types:
            table_name = config.get_source_table([txn_type])
            dfs.append(spark.table(table_name))
        df = dfs[0]
        for other_df in dfs[1:]:
            df = df.unionByName(other_df)

    # Apply date filters if specified
    if start_date:
        df = df.where(F.col(config.date_column_name) >= F.lit(start_date))
    if end_date:
        df = df.where(F.col(config.date_column_name) <= F.lit(end_date))

    return df
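
# Follow-on usage sketch: the return value is a plain pyspark DataFrame, so
# downstream transformations chain as usual. The "transaction_date" column
# name below is an assumption for illustration only; the real date column is
# dataset-specific (config.date_column_name).
#
#     from pyspark.sql import functions as F
#
#     df = source_card_transactions(dataset="mando", transaction_types=["all"])
#     weekly_counts = df.groupBy(
#         F.date_trunc("week", F.col("transaction_date")).alias("week")
#     ).count()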