from datetime import date
from enum import Enum
from typing import List, Literal, Optional
from yipit_databricks_utils.helpers.pyspark_utils import get_spark_session
from yipit_databricks_utils.helpers.telemetry import track_usage
from pyspark.sql import DataFrame, functions as F
from etl_toolkit.exceptions import InvalidInputException
from etl_toolkit.analyses.card.core import get_table_configuration_for_dataset
class CardTransactionType(str, Enum):
"""Enum defining the types of card transactions available across datasets."""
TRANSACTIONS = "transactions"
DEPOSITS = "deposits"
ALL = "all"
def _validate_and_normalize_transaction_types(
transaction_types: Optional[List[str | CardTransactionType]],
) -> List[CardTransactionType]:
"""
Validates and normalizes a list of transaction types.
    :param transaction_types: List of transaction types to validate; ``None`` defaults to ``[CardTransactionType.TRANSACTIONS]``
    :return: List of normalized CardTransactionType enums; if ``CardTransactionType.ALL`` is present, it supersedes all other types
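
    A brief sketch of the normalization rules described above:

    .. code-block:: python

        _validate_and_normalize_transaction_types(None)
        # -> [CardTransactionType.TRANSACTIONS]

        _validate_and_normalize_transaction_types(["deposits", "all"])
        # -> [CardTransactionType.ALL]  (ALL supersedes the other types)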
"""
if transaction_types is None:
return [CardTransactionType.TRANSACTIONS]
normalized_types = []
for t in transaction_types:
if isinstance(t, str):
try:
t = CardTransactionType(t.lower())
except ValueError:
raise InvalidInputException(
f"Invalid transaction type {t}, must be one of {[t.value for t in CardTransactionType]}"
)
elif not isinstance(t, CardTransactionType):
raise InvalidInputException(
f"Transaction type must be string or CardTransactionType enum, got {type(t)}"
)
normalized_types.append(t)
# If ALL is specified, ignore other types
if CardTransactionType.ALL in normalized_types:
return [CardTransactionType.ALL]
return normalized_types
def _validate_date_range(start_date: Optional[date], end_date: Optional[date]) -> None:
"""Validates that a date range is logical."""
if start_date and end_date and start_date > end_date:
raise InvalidInputException(
f"Invalid date range: start_date {start_date} must be before or equal to end_date {end_date}"
)
@track_usage
def source_card_transactions(
dataset: Literal["skywalker", "mando", "yoda"],
start_date: Optional[date] = None,
end_date: Optional[date] = None,
transaction_types: Optional[List[str | CardTransactionType]] = None,
group_column_names: Optional[List[str]] = None,
) -> DataFrame:
"""
Returns a DataFrame of card transactions from the specified source dataset, with optional filtering for date ranges
and transaction types. This function provides a standardized way to access card transaction data across different
datasets while handling their unique schemas and storage patterns.
The function supports both single-table datasets (like Mando) where transaction types are distinguished by a column value,
and multi-table datasets (like Skywalker/Yoda) where different transaction types are stored in separate tables.
.. tip:: When working with card transaction data, it's recommended to use this function rather than directly
accessing the source tables to ensure consistent filtering and proper handling of dataset-specific schemas.
:param dataset: The source dataset to retrieve transactions from. Must be one of "skywalker", "mando", or "yoda".
:param start_date: Optional start date to filter transactions. If provided, only includes transactions on or after this date.
:param end_date: Optional end date to filter transactions. If provided, only includes transactions on or before this date.
:param transaction_types: Optional list of transaction types to include. Can use either string values or CardTransactionType enum.
Valid values are "transactions", "deposits", or "all". Defaults to ["transactions"] if not specified.
If "all" is included in the list, other types are ignored.
    :param group_column_names: Optional list of grouping column names to use for this dataset, overriding the dataset's default grouping columns
:return: A DataFrame containing the filtered card transactions from the specified dataset.
:raises InvalidInputException: If any validation checks fail
Examples
^^^^^^^^
.. code-block:: python
:caption: Basic usage to get recent Mando transactions
from datetime import date
from etl_toolkit import A
# Get Mando transactions for January 2024
transactions_df = A.source_card_transactions(
dataset="mando",
start_date=date(2024, 1, 1),
end_date=date(2024, 1, 31)
)
.. code-block:: python
:caption: Get both transactions and deposits from Skywalker
from etl_toolkit import A
from etl_toolkit.analyses.card.source import CardTransactionType
# Get all Skywalker transaction types using enum
all_activity_df = A.source_card_transactions(
dataset="skywalker",
transaction_types=[
CardTransactionType.TRANSACTIONS,
CardTransactionType.DEPOSITS
]
)
# Alternative using string values
all_activity_df = A.source_card_transactions(
dataset="skywalker",
transaction_types=["transactions", "deposits"]
)
.. code-block:: python
:caption: Get all transaction types using the ALL type
from etl_toolkit import A
from etl_toolkit.analyses.card.source import CardTransactionType
# Get all transaction types from Mando
all_df = A.source_card_transactions(
dataset="mando",
transaction_types=[CardTransactionType.ALL]
)
# Alternative using string value
all_df = A.source_card_transactions(
dataset="mando",
transaction_types=["all"]
)
.. code-block:: python
:caption: Get only deposits from a specific date range
from datetime import date
from etl_toolkit import A
# Get Yoda deposits for Q1 2024
deposits_df = A.source_card_transactions(
dataset="yoda",
start_date=date(2024, 1, 1),
end_date=date(2024, 3, 31),
transaction_types=["deposits"]
)
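
    The ``group_column_names`` parameter overrides the dataset's default grouping columns.
    The column names in the sketch below are illustrative placeholders, not guaranteed to
    exist in any particular dataset's schema.

    .. code-block:: python
        :caption: Override the default grouping columns (illustrative column names)

        from etl_toolkit import A

        grouped_df = A.source_card_transactions(
            dataset="mando",
            group_column_names=["merchant", "channel"],
        )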
"""
spark = get_spark_session()
# Validate inputs
_validate_date_range(start_date, end_date)
normalized_types = _validate_and_normalize_transaction_types(transaction_types)
# Get dataset configuration and data
config = get_table_configuration_for_dataset(
dataset=dataset, group_column_names=group_column_names
)
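    # Combined-table datasets (and ALL requests) read from a single table and
    # filter on transaction type; other datasets read one table per requested
    # type and union the results below.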
if config.uses_combined_table or CardTransactionType.ALL in normalized_types:
# Use single table with optional type filtering
df = spark.table(config.get_source_table(normalized_types))
type_filter = config.get_transaction_filter(normalized_types)
if type_filter is not None:
df = df.where(type_filter)
else:
# Union separate type tables
dfs = []
for txn_type in normalized_types:
table_name = config.get_source_table([txn_type])
df = spark.table(table_name)
dfs.append(df)
        # Combine the per-type DataFrames, matching columns by name; unionByName
        # accepts a single DataFrame at a time, so fold the list pairwise.
        df = dfs[0]
        for other_df in dfs[1:]:
            df = df.unionByName(other_df)
# Apply date filters if specified
if start_date:
df = df.where(F.col(config.date_column_name) >= F.lit(start_date))
if end_date:
df = df.where(F.col(config.date_column_name) <= F.lit(end_date))
return df