from dataclasses import dataclass
from typing import Literal, Optional, Union, ClassVar
import re
from pyspark.sql import DataFrame, Column, functions as F, Window as W, types as T
from etl_toolkit import A, E
from etl_toolkit.exceptions import InvalidInputException
@dataclass
class coverage_metric:
    """
    Defines a metric for calculating coverage ratios between observed and reported values.

    :param name: Name of the metric (must be alphanumeric/lowercase/underscore)
    :param observed_column_name: Column containing the observed values to measure
    :param reported_column_name: Column containing the reported values to compare against
    :param filter: Optional filter condition to apply before aggregation

    Examples
    --------
    .. code-block:: python

        from etl_toolkit import F, A

        # Basic metric definition
        revenue_metric = A.coverage_metric(
            name='revenue_total',
            observed_column_name='adj_trans_amount',
            reported_column_name='reported_total'
        )

        # Metric with filter condition
        online_revenue = A.coverage_metric(
            name='revenue_online',
            observed_column_name='adj_trans_amount',
            reported_column_name='reported_online',
            filter=F.col('channel')=='ONLINE'
        )
    """

    # Lowercase/underscore identifier; embedded in every derived column name below.
    name: str
    # Column (name or expression) holding the observed values to aggregate.
    observed_column_name: Union[str, Column]
    # Column (name or expression) holding the company-reported comparison values.
    reported_column_name: Union[str, Column]
    # Optional row-level predicate applied before aggregation.
    filter: Optional[Column] = None

    # Compiled once at class creation instead of per-instantiation; validates
    # `name` in __post_init__ (also the reason ClassVar is imported above).
    _NAME_PATTERN: ClassVar[re.Pattern] = re.compile(r"^[a-z0-9_]+$")

    def __post_init__(self):
        # `name` is spliced into Spark column names, so restrict it to a safe
        # lowercase/digit/underscore alphabet.
        if not self._NAME_PATTERN.match(self.name):
            raise InvalidInputException(
                "Metric name must be alphanumeric lowercase with underscores"
            )

    @property
    def coverage_column(self) -> str:
        """Base coverage-ratio column name (observed / reported)."""
        return f"coverage_{self.name}"

    @property
    def last_quarter_column(self) -> str:
        """Column name for the last-quarter fallback estimate."""
        return f"{self.coverage_column}_last_quarter"

    @property
    def qoq_column(self) -> str:
        """Column name for the quarter-over-quarter extrapolated estimate."""
        return f"{self.coverage_column}_qoq"

    @property
    def seasonal_column(self) -> str:
        """Column name for the seasonally adjusted estimate (quarterly data)."""
        return f"{self.coverage_column}_seasonal"

    @property
    def recent_column(self) -> str:
        """Column name for the recent-coverage fallback (daily data)."""
        return f"{self.coverage_column}_recent"

    @property
    def trend_column(self) -> str:
        """Column name for the short-term trend estimate (daily data)."""
        return f"{self.coverage_column}_trend"
def _get_period_config(periodicity: Literal["QUARTER", "DAY"]) -> dict:
"""Helper function to get period-specific configuration."""
match periodicity.upper():
case "QUARTER":
return {
"grouping_columns": ["quarter", "start_date", "end_date"],
"ordering_column": "start_date",
"comparison_periods": [1, 2, 3, 4, 5],
}
case "DAY":
return {
"grouping_columns": ["date"],
"ordering_column": "date",
"comparison_periods": [1, 7, 14, 30],
}
case _:
raise InvalidInputException("periodicity must be either 'QUARTER' or 'DAY'")
def _calculate_lag_metrics(
    df: DataFrame, metric: coverage_metric, time_window: W, periodicity: str
) -> DataFrame:
    """Attach lagged copies of a metric's coverage column for each comparison period.

    For every lag N configured for the periodicity, adds a column named
    ``<coverage_column>_period_N_<periodicity>`` containing the coverage value
    N periods back within the given window.
    """
    suffix = periodicity.lower()
    lag_periods = _get_period_config(periodicity)["comparison_periods"]
    lagged_columns = {}
    for lag_n in lag_periods:
        column_name = f"{metric.coverage_column}_period_{lag_n}_{suffix}"
        lagged_columns[column_name] = F.lag(
            F.col(metric.coverage_column), lag_n
        ).over(time_window)
    return df.withColumns(lagged_columns)
def _correct_coverage_coalesce(
    column_name: str,
    reported_column_name: str,
    total_column_name: str = "reported_total",
) -> Column:
    """Null out a coverage estimate when its reported figure is missing but the total is present.

    Handles the reporting edge case where a company reports its total metric
    but not a sub-metric: in that period the sub-metric's estimate is set to
    NULL rather than carried forward.

    :param column_name: Estimate column to correct
    :param reported_column_name: Reported column backing the estimate
    :param total_column_name: Reported-total column used as the "company did
        report this period" signal. Previously hard-coded to "reported_total";
        now parameterized (same default) so other totals can be used.
    """
    return F.when(
        E.all(
            [
                F.col(total_column_name).isNotNull(),
                F.col(reported_column_name).isNull(),
            ]
        ),
        F.lit(None),
    ).otherwise(F.col(column_name))
def card_coverage_metrics(
    dataframe: DataFrame,
    metric_definitions: list[coverage_metric],
    periodicity: Literal["QUARTER", "DAY"] = "QUARTER",
    time_period_grouping_columns: Optional[list[Union[str, Column]]] = None,
    additional_grouping_columns: Optional[list[Union[str, Column]]] = None,
    transaction_date_column: Union[str, Column] = "trans_date",
    time_period_ordering_column: Optional[Union[str, Column]] = None,
    normalize_by_days: bool = False,
) -> DataFrame:
    """
    Calculates coverage ratios by comparing observed values to reported metrics across time periods.

    Coverage metrics help assess data completeness by measuring the difference between transactions
    captured in card data versus those reported in company filings.

    The function adds coverage calculations and lag-based estimates:

    - Basic coverage ratios (observed/reported)
    - Period-over-period coverage changes
    - Lag-adjusted metrics accounting for historical patterns
    - Seasonal adjustments (for quarterly data)
    - Recent trend indicators (for daily data)

    :param dataframe: DataFrame containing transaction and reported metric data
    :param metric_definitions: List of coverage_metric objects defining metrics to calculate
    :param periodicity: Analysis granularity, either "QUARTER" or "DAY". Defaults to "QUARTER"
    :param time_period_grouping_columns: Columns to group timestamps. Defaults to ["quarter", "start_date", "end_date"] for quarters, ["date"] for days
    :param additional_grouping_columns: Additional grouping dimensions (e.g. ticker). Defaults to ["ticker"] when None
    :param transaction_date_column: Column containing transaction dates. Defaults to "trans_date"
    :param time_period_ordering_column: Column for ordering periods. Defaults to "start_date" or "date" based on periodicity
    :param normalize_by_days: Whether to adjust for partial data periods. Defaults to False

    Examples
    --------
    .. code-block:: python
        :caption: Calculate quarterly coverage metrics with base metrics

        from etl_toolkit import F, A

        # Define metrics to analyze
        metrics = [
            A.coverage_metric(
                name='revenue_total',
                observed_column_name='adj_trans_amount',
                reported_column_name='reported_total'
            )
        ]
        quarterly_coverage = A.card_coverage_metrics(
            transactions_df,
            metric_definitions=metrics,
            periodicity='QUARTER'
        )
        display(quarterly_coverage)

        +--------+------------+-------------+----------------+----------------------------+---------------------------+
        |ticker  |quarter     |revenue_total|reported_total  |coverage_revenue_total      |coverage_revenue_seasonal  |
        +--------+------------+-------------+----------------+----------------------------+---------------------------+
        |AAPL    |2024-01-01  |      1500000|        2000000 |                       0.75 |                      0.77 |
        +--------+------------+-------------+----------------+----------------------------+---------------------------+
        |AAPL    |2023-10-01  |      1400000|        2000000 |                       0.70 |                      0.72 |
        +--------+------------+-------------+----------------+----------------------------+---------------------------+

    .. code-block:: python
        :caption: Calculate daily coverage with normalization and channel splits

        from etl_toolkit import F, A

        metrics = [
            A.coverage_metric(
                name='revenue_total',
                observed_column_name='adj_trans_amount',
                reported_column_name='reported_total'
            ),
            A.coverage_metric(
                name='revenue_online',
                observed_column_name='adj_trans_amount',
                reported_column_name='reported_online',
                filter=F.col('channel')=='ONLINE'
            )
        ]
        daily_coverage = A.card_coverage_metrics(
            transactions_df,
            metric_definitions=metrics,
            periodicity='DAY',
            normalize_by_days=True
        )
        display(daily_coverage)

        +--------+----------+-------------+--------------+---------------+----------------------+------------------------+
        |ticker  |date      |revenue_total|revenue_online|reported_total |coverage_revenue_total|coverage_revenue_online |
        +--------+----------+-------------+--------------+---------------+----------------------+------------------------+
        |AAPL    |2024-01-01|       150000|         50000|         200000|                 0.75 |                   0.71 |
        +--------+----------+-------------+--------------+---------------+----------------------+------------------------+
        |AAPL    |2024-01-02|       140000|         45000|         200000|                 0.70 |                   0.68 |
        +--------+----------+-------------+--------------+---------------+----------------------+------------------------+
    """
    periodicity = periodicity.upper()
    period_config = _get_period_config(periodicity)
    # Resolve periodicity-dependent defaults.
    time_period_grouping_columns = (
        time_period_grouping_columns or period_config["grouping_columns"]
    )
    time_period_ordering_column = (
        time_period_ordering_column or period_config["ordering_column"]
    )
    # None sentinel instead of a mutable default argument; None means the
    # conventional ["ticker"] grouping (same behavior as before).
    if additional_grouping_columns is None:
        additional_grouping_columns = ["ticker"]

    # Add period columns first
    if periodicity == "DAY":
        if transaction_date_column != "date":
            # Truncate timestamps to midnight so each row falls into a calendar day.
            dataframe = dataframe.withColumn(
                "date", F.date_trunc("DAY", transaction_date_column)
            )

    # Handle empty additional_grouping_columns: a constant column puts the
    # whole frame into a single window partition/group.
    if len(additional_grouping_columns) == 0:
        additional_grouping_columns = [F.lit(1)]

    all_grouping_columns = time_period_grouping_columns + additional_grouping_columns
    time_window = W.partitionBy(additional_grouping_columns).orderBy(
        time_period_ordering_column
    )

    # Track reported metrics to avoid duplication (several metrics may share
    # the same reported column).
    unique_reported_metrics = []
    for metric in metric_definitions:
        if metric.reported_column_name not in unique_reported_metrics:
            unique_reported_metrics.append(metric.reported_column_name)

    # Filter for valid periods
    filtered_df = dataframe.filter(
        E.any([F.col(x).isNotNull() for x in time_period_grouping_columns])
    )

    # Build aggregations for each metric and reported metric. A metric's
    # filter is applied multiplicatively: rows failing the predicate
    # contribute 0 to the sum.
    aggregation_expressions = [
        F.sum(
            F.col(metric.observed_column_name)
            if metric.filter is None
            else F.col(metric.observed_column_name) * metric.filter.cast("int")
        ).alias(metric.name)
        for metric in metric_definitions
    ] + [F.first(F.col(metric)).alias(metric) for metric in unique_reported_metrics]

    # Add min date aggregation if needed for partial-period normalization.
    if normalize_by_days and periodicity == "QUARTER":
        aggregation_expressions.append(
            F.min(transaction_date_column).alias("min_transaction_date")
        )

    # Group and aggregate data
    observed_df = filtered_df.groupBy(*all_grouping_columns).agg(
        *aggregation_expressions
    )

    # Handle day normalization: scale the reported denominator down when only
    # part of the quarter has observed data.
    if normalize_by_days and periodicity == "QUARTER":
        observed_df = observed_df.withColumns(
            {
                "days_in_quarter": F.datediff("end_date", "start_date") + 1,
                "days_with_data": F.datediff("end_date", "min_transaction_date") + 1,
            }
        )
        day_normalization_factor = F.col("days_with_data") / F.col("days_in_quarter")
    else:
        day_normalization_factor = F.lit(1)

    # Calculate base coverage ratios (observed / normalized reported).
    base_coverage_columns = {
        metric.coverage_column: F.col(metric.name)
        / (F.col(metric.reported_column_name) * day_normalization_factor)
        for metric in metric_definitions
    }
    coverage_metrics = observed_df.withColumns(base_coverage_columns)

    # Calculate lag-based estimates for each metric.
    for metric in metric_definitions:
        coverage_metrics = _calculate_lag_metrics(
            coverage_metrics, metric, time_window, periodicity
        )
        lower_periodicity = periodicity.lower()
        if periodicity == "QUARTER":
            # Seasonal estimate: last quarter's coverage scaled by the drift
            # between the same lags one year earlier (lag 4 vs lag 5).
            seasonal_col = F.col(
                f"{metric.coverage_column}_period_1_{lower_periodicity}"
            ) * (
                F.col(f"{metric.coverage_column}_period_4_{lower_periodicity}")
                / F.col(f"{metric.coverage_column}_period_5_{lower_periodicity}")
            )
            coverage_metrics = coverage_metrics.withColumns(
                {
                    # Fall back to the most recent available quarter.
                    metric.last_quarter_column: F.coalesce(
                        F.col(metric.coverage_column),
                        F.col(f"{metric.coverage_column}_period_1_{lower_periodicity}"),
                        F.col(f"{metric.coverage_column}_period_2_{lower_periodicity}"),
                    ),
                    # Quarter-over-quarter extrapolation: project the recent
                    # growth rate forward (lag1^2/lag2, else lag2^3/lag3^2).
                    metric.qoq_column: F.coalesce(
                        F.col(metric.coverage_column),
                        F.pow(
                            F.col(
                                f"{metric.coverage_column}_period_1_{lower_periodicity}"
                            ),
                            2,
                        )
                        / F.col(
                            f"{metric.coverage_column}_period_2_{lower_periodicity}"
                        ),
                        F.pow(
                            F.col(
                                f"{metric.coverage_column}_period_2_{lower_periodicity}"
                            ),
                            3,
                        )
                        / F.pow(
                            F.col(
                                f"{metric.coverage_column}_period_3_{lower_periodicity}"
                            ),
                            2,
                        ),
                    ),
                    # Seasonal estimate, falling back to the prior period's
                    # seasonal value re-scaled by the year-over-year drift.
                    metric.seasonal_column: F.coalesce(
                        F.col(metric.coverage_column),
                        seasonal_col,
                        F.lag(seasonal_col, 1).over(time_window)
                        * (
                            F.col(
                                f"{metric.coverage_column}_period_4_{lower_periodicity}"
                            )
                            / F.col(
                                f"{metric.coverage_column}_period_5_{lower_periodicity}"
                            )
                        ),
                    ),
                }
            )
        else:
            coverage_metrics = coverage_metrics.withColumns(
                {
                    # Most recent coverage: yesterday, else one week back.
                    metric.recent_column: F.coalesce(
                        F.col(metric.coverage_column),
                        F.col(f"{metric.coverage_column}_period_1_{lower_periodicity}"),
                        F.col(f"{metric.coverage_column}_period_7_{lower_periodicity}"),
                    ),
                    # Short-term trend: week-ago vs two-weeks-ago ratio.
                    metric.trend_column: F.coalesce(
                        F.col(metric.coverage_column),
                        F.col(f"{metric.coverage_column}_period_7_{lower_periodicity}")
                        / F.col(
                            f"{metric.coverage_column}_period_14_{lower_periodicity}"
                        ),
                    ),
                }
            )

    # Handle reporting edge cases for quarterly data: null out estimates for
    # sub-metrics whose reported figure is missing while the total is present.
    if periodicity == "QUARTER":
        for metric in metric_definitions:
            if metric.name not in ["revenue_total"]:
                coverage_metrics = coverage_metrics.withColumns(
                    {
                        metric.last_quarter_column: _correct_coverage_coalesce(
                            metric.last_quarter_column, metric.reported_column_name
                        ),
                        metric.qoq_column: _correct_coverage_coalesce(
                            metric.qoq_column, metric.reported_column_name
                        ),
                        metric.seasonal_column: _correct_coverage_coalesce(
                            metric.seasonal_column, metric.reported_column_name
                        ),
                    }
                )

    # Drop intermediate columns (those prefixed with "_").
    # NOTE(review): the lag helper columns are named
    # "<coverage>_period_N_<period>" and do NOT start with "_", so this filter
    # keeps them in the output — confirm whether that is intentional.
    output_columns = [
        col for col in coverage_metrics.columns if not col.startswith("_")
    ]
    final_coverage = coverage_metrics.select(*output_columns)
    return final_coverage