from datetime import datetime, date
from typing import Literal
from pyspark.sql import functions as F, types as T, Column, DataFrame, Window as W
from pyspark.sql.connect.dataframe import DataFrame as ConnectDataFrame
from pyspark.sql.connect.column import Column as ConnectColumn
from pyspark.errors.exceptions.connect import (
AnalysisException as AnalysisConnectException,
)
from pyspark.errors import AnalysisException
from pyspark.sql import SparkSession
from yipit_databricks_utils.helpers.telemetry import track_usage
from etl_toolkit import expressions as E
from etl_toolkit.exceptions import InvalidColumnTypeException, InvalidInputException
from etl_toolkit.analyses.scalar import get_column_type, get_aggregates
from etl_toolkit.analyses.ordering import shift_columns
from etl_toolkit.analyses.calculation import add_percent_of_total_columns
from etl_toolkit.analyses.time import periods
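# Approximate period lengths in days, used as the bin size for Spark range-join hints
# when joining rows to their enclosing periods.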
RANGE_JOIN_BIN_SIZE = {
"WEEK": 7,
"MONTH": 30,
"QUARTER": 90,
}
@track_usage
def index_from_rolling_panel(
input_df: DataFrame,
start_date: datetime | date,
date_column: Column | str,
metrics: dict[str, dict],
    slices: list[str | Column] | None = None,
panel_periodicity: Literal["WEEK", "MONTH", "QUARTER"] = "MONTH",
index_periodicity: Literal["DAY", "WEEK", "MONTH", "QUARTER"] = "DAY",
panel_end_date_column: Column | str = "panel_end_date",
user_create_date_column: Column | str = "create_time",
    spark: SparkSession | None = None,
) -> DataFrame:
"""
Returns a dataframe that maps the input dataframe to a series of indices for each metric provided.
The indices are calculated through an aggregate function, a base value, and a rolling panel methodology, where the panel is fixed for 2 consecutive periods.
The period length can be configured using ``panel_periodicity``.
Indices can be broken down across various ``slices`` to visualize trends on key segments of the dataset.
    Multiple metrics can be specified; each metric produces its own series of rows, identified by the ``metric_name`` column in the output dataframe.
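    For each period, the growth rate of the fixed panel is compounded into the index, i.e. ``index_t = index_{t-1} * (1 + growth_t)``, seeded by each metric's ``start_value``.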
.. note:: This function is expected to work with Edison (e-receipt) derived dataframes. The rolling panel methodology is
designed around the fluctuating behavior of the Edison user panel.
:param input_df: The input dataframe to use to generate a series of indices.
    :param start_date: The minimum start date of the ``input_df``; all metrics will be filtered to dates on or after this date.
:param date_column: A Column or string that indicates the date column of the ``input_df`` to use in paneling and index logic. If a string is used, it is referenced as a Column.
    :param metrics: A dict of metric configurations used to generate indices. An index will be created for each metric defined in this dictionary. Each metric config must have a ``calculation`` Column (a pyspark aggregate expression) and a ``start_value`` defined for the index.
    :param slices: Optional list of strings or Columns that define the slices within the dataframe. If slices are used, the indices generated will be for each slice within the ``input_df``. Strings in the list are resolved as Columns.
    :param panel_periodicity: The unit of length for the interval of each paneling window.
    :param index_periodicity: The unit of length for the interval of each indexing window. A row will be generated for each period based on this argument.
    :param panel_end_date_column: The Column to use when determining the exit date of a user from the panel in the rolling panel methodology.
    :param user_create_date_column: The Column to use when determining the start date of a user in the panel in the rolling panel methodology.
    :param spark: Optional spark session to pass to this function. Normally this is not needed because a spark session is resolved automatically; primarily intended for library developers.
Examples
^^^^^^^^^^^^^^^^
.. code-block:: python
        :caption: Example of daily index creation for Lyft. This creates two indices, ``rides`` and ``total_charges``, based on different aggregations. The paneling is a rolling panel based on a weekly interval.

from etl_toolkit import E, A, F
source_df = spark.table(f"{DESTINATION_CATALOG}.{SILVER_DATABASE}.deduped_receipts")
index_df = A.index_from_rolling_panel(
source_df,
start_date=date(2017, 1, 2),
date_column="adjusted_order_date",
            metrics={
"rides": {
"calculation": F.sum("geo_weight"),
"start_value": 13.471719625795382,
},
"total_charges": {
"calculation": F.sum(total_charges),
"start_value": 13.471719625795382,
},
},
panel_periodicity="WEEK",
index_periodicity="DAY",
)
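
    .. code-block:: python
        :caption: Illustrative sketch of a weekly index broken down by slices. The ``market_segment`` slice column is hypothetical; substitute a column that exists on the source dataframe.

        from etl_toolkit import E, A, F
        source_df = spark.table(f"{DESTINATION_CATALOG}.{SILVER_DATABASE}.deduped_receipts")
        sliced_index_df = A.index_from_rolling_panel(
            source_df,
            start_date=date(2017, 1, 2),
            date_column="adjusted_order_date",
            metrics={
                "rides": {
                    "calculation": F.sum("geo_weight"),
                    "start_value": 13.471719625795382,
                },
            },
            # hypothetical slice column, used for illustration only
            slices=["market_segment"],
            panel_periodicity="WEEK",
            index_periodicity="WEEK",
        )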
"""
_validate_inputs(input_df, start_date, panel_periodicity)
# Normalize date/timestamp specific columns into a date type after validation
date_column = E.normalize_column(date_column)
user_create_date_column = E.normalize_column(user_create_date_column)
panel_end_date_column = E.normalize_column(panel_end_date_column)
_validate_date_column(input_df, date_column)
date_column_name = input_df.select(date_column).columns[0]
date_column = date_column.cast("date").alias(date_column_name)
# Validate metric columns, should raise analysis exception if not valid
_validate_metrics(input_df, metrics)
index_metrics = [
config["calculation"].alias(metric_name)
for metric_name, config in metrics.items()
]
# Validate slices, if used
if slices is None:
slices = []
else:
slices = [E.normalize_column(slice_column) for slice_column in slices]
_validate_slices(input_df, slices)
# Normalize input data by truncating to start_date and enriching with period columns for paneling logic
max_data_through = get_aggregates(input_df, date_column, ["max"])["max"]
max_panel_end = get_aggregates(
input_df, panel_end_date_column.cast("date"), ["max"]
)["max"]
periods_df = _get_current_and_future_periods(
start_date, max_data_through, periodicity=panel_periodicity, spark=spark
)
normalized_input = input_df.filter(date_column >= start_date).join(
periods_df.hint("range_join", RANGE_JOIN_BIN_SIZE[panel_periodicity]),
E.between(date_column, "period_start", "period_end"),
how="left",
)
# We want to generate 2 aggregations to create a rolling panel index
# - Aggregation #1 includes panel members that are at least through the current period end
# - Aggregation #2 includes panel members that are at least through the prior period end
# - Then we join #1 to #2 on a lag to determine period-over-period growth rates and back into an index value
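    # Both aggregations cover the same user cohort (joined before the prior period began,
    # still on the panel through the current period's end), so the ratio (#1 / #2) - 1
    # gives a like-for-like growth rate that can be compounded into an index.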
current_period_aggregates = (
normalized_input.where(
user_create_date_column
< F.coalesce(F.col("prev_period_start"), F.lit(start_date))
)
.where(
E.any(
[
panel_end_date_column.cast("date") > F.col("period_end"),
panel_end_date_column.cast("date") == max_panel_end,
]
)
)
.groupBy("period_start", "period_end")
.agg(*index_metrics)
.unpivot(
["period_start", "period_end"], [x for x in metrics], "metric_name", "value"
)
)
prev_period_aggregates = (
_filter_for_panel(
normalized_input,
max_panel_end,
user_create_date_column,
panel_end_date_column,
)
.groupBy("period_end", "next_period_start")
.agg(*index_metrics)
.unpivot(
["period_end", "next_period_start"],
            list(metrics),
"metric_name",
"value",
)
)
period_over_period_growth = F.coalesce(
F.try_divide(current_period_aggregates.value, prev_period_aggregates.value) - 1,
F.lit(0),
)
generate_index_window = W.partitionBy(
current_period_aggregates.metric_name
).orderBy(current_period_aggregates.period_end.asc())
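    # exp(sum(log(1 + growth))) over the window is a numerically stable cumulative
    # product of (1 + growth) per metric; scaled by start_value it yields
    # index_t = index_{t-1} * (1 + growth_t), seeded at start_value.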
generate_index_from_growth_rates = F.exp(
F.sum(F.log(1 + period_over_period_growth)).over(generate_index_window)
) * F.col("start_value")
index_df = (
current_period_aggregates.withColumn(
"start_value",
E.chain_assigns(
[
E.assign(
F.lit(config["start_value"]),
F.col("metric_name") == metric_name,
)
for metric_name, config in metrics.items()
]
),
)
.join(
prev_period_aggregates,
E.all(
[
current_period_aggregates.period_start
== prev_period_aggregates.next_period_start,
current_period_aggregates.metric_name
== prev_period_aggregates.metric_name,
]
),
how="left",
)
.select(
current_period_aggregates.period_start,
current_period_aggregates.period_end,
current_period_aggregates.metric_name,
period_over_period_growth.alias("period_growth"),
generate_index_from_growth_rates.alias("index_value"),
)
)
# If slices are specified, re-aggregate values using the slices + day, otherwise just day
# and then join to the index_df to calculate percentage of daily index to allocate to each day( + slice).
# Index values need to be calculated across slices first to account for distribution bias
# or incompleteness within each slice
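    # For example, a day (+ slice) that accounts for 3% of its period's metric total
    # receives 3% of that period's index value.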
group_by_columns = [date_column, "period_start", "period_end"]
if len(slices):
group_by_columns = group_by_columns + slices
# Aggregate metrics by day and optionally slices
# Then aggregate metrics again by period and get the percent of the day(+slice)
# this will be used to scale the daily index in the final output
daily_df = (
_filter_for_panel(
normalized_input,
max_panel_end,
user_create_date_column,
panel_end_date_column,
)
.groupBy(group_by_columns)
.agg(*index_metrics)
)
daily_mix_df = add_percent_of_total_columns(
daily_df,
        value_columns=list(metrics),
total_grouping_columns=["period_end"],
).unpivot(
group_by_columns,
[F.col(f"{metric_name}_percent").alias(metric_name) for metric_name in metrics],
"metric_name",
"percent_of_total",
)
daily_index_df = daily_mix_df.join(
index_df.hint("range_join", RANGE_JOIN_BIN_SIZE[panel_periodicity]),
E.all(
[
E.between(date_column_name, index_df.period_start, index_df.period_end),
index_df.metric_name == daily_mix_df.metric_name,
]
),
how="inner",
).select(
F.col(date_column_name).alias("date"),
index_df.period_start.alias("period_start"),
index_df.period_end.alias("period_end"),
_metric_name_with_slices(
daily_mix_df.metric_name,
slices,
).alias("metric_name"),
(F.col("index_value") * F.col("percent_of_total")).alias("index_value"),
index_df.metric_name.alias("base_metric_name"),
*slices,
)
final_df = shift_columns(
daily_index_df,
[
"date",
"period_start",
"period_end",
"metric_name",
E.growth_rate_by_lag(
"index_value",
W.partitionBy("metric_name").orderBy("date"),
default=0,
).alias("daily_growth"),
"index_value",
"base_metric_name",
],
)
# Gross up to higher periodicities if index periodicity is not daily
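    # Daily index values are summed within each coarser index period, then
    # period-over-period growth is recomputed on the aggregated index values.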
if index_periodicity in ["WEEK", "MONTH", "QUARTER"]:
index_date_df = periods(
start_date,
max_data_through,
steps=1,
step_unit=index_periodicity,
spark=spark,
)
final_df = (
final_df.join(
index_date_df.hint(
"range_join", RANGE_JOIN_BIN_SIZE[index_periodicity]
),
E.between(
final_df.date, index_date_df.period_start, index_date_df.period_end
),
how="left",
)
.groupBy(
index_date_df.period_start,
index_date_df.period_end,
"metric_name",
"base_metric_name",
*slices,
)
.agg(F.sum("index_value").alias("index_value"))
.select(
F.col("period_start").alias("date"),
"period_start",
"period_end",
"metric_name",
E.growth_rate_by_lag(
"index_value",
W.partitionBy("metric_name").orderBy("period_start"),
default=0,
).alias("period_growth"),
"index_value",
"base_metric_name",
*slices,
)
)
return final_df
def _validate_inputs(
input_df: DataFrame, start_date: date | datetime, panel_periodicity: str
):
if panel_periodicity not in ["WEEK", "MONTH", "QUARTER"]:
raise InvalidInputException(
f'Invalid panel_periodicity {panel_periodicity}, must be one of "WEEK", "MONTH", "QUARTER"'
)
if not isinstance(start_date, (date, datetime)):
raise InvalidInputException(
f"Invalid start_date, {start_date}, must be a datetime or date python object"
)
if not isinstance(input_df, (DataFrame, ConnectDataFrame)):
raise InvalidInputException(
f"Invalid input_df, {input_df}, must be a dataframe"
)
def _validate_date_column(input_df: DataFrame, date_column: Column):
_check_column_on_df(
input_df,
date_column,
f"Provided date_column {date_column} does not exist or cannot be queried on the input_df, change the column",
)
if get_column_type(input_df, date_column) not in ["date", "timestamp"]:
raise InvalidColumnTypeException(
f"Invalid date_column {date_column}, must be a date or timestamp type"
)
def _validate_slices(input_df: DataFrame, slice_columns: list[Column]):
# Validate slice columns exist on dataframe
for slice_column in slice_columns:
_check_column_on_df(
input_df,
slice_column,
f"Provided slice {slice_column} does not exist or cannot be queried on the input_df, change the column",
)
def _validate_metrics(input_df: DataFrame, metrics: dict):
if not isinstance(metrics, dict):
raise InvalidInputException(
"Provided metrics must be a dict, with keys representing the metric names and the value as the metric configuration"
)
for metric_name, configuration in metrics.items():
if "calculation" not in configuration:
raise InvalidInputException(
f"Provided metric {metric_name} must have a calculation that is a pyspark aggregate expression"
)
if not isinstance(configuration["calculation"], (Column, ConnectColumn)):
raise InvalidInputException(
f"Provided metric {metric_name} must have a calculation that is a pyspark aggregate expression"
)
_check_column_on_df(
input_df,
configuration["calculation"],
f'Provided metric calculation {configuration["calculation"]} does not exist or cannot be queried on the input_df, change the column',
)
if "start_value" not in configuration:
raise InvalidInputException(
f"Provided metric {metric_name} must have a start value that is an int or float"
)
if not isinstance(configuration["start_value"], (int, float)):
raise InvalidInputException(
f"Provided metric {metric_name} must have a start value that is an int or float"
)
def _get_current_and_future_periods(
start_date: datetime | date,
    end_date: datetime | date | None = None,
    periodicity: str = "DAY",
    spark: SparkSession | None = None,
) -> DataFrame:
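    # Build one row per period from start_date through end_date, attaching neighboring
    # period boundaries (previous, next, next-next) via window functions so the
    # paneling joins can reference them directly.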
end_date = end_date or datetime.utcnow()
date_df = periods(start_date, end_date, steps=1, step_unit=periodicity, spark=spark)
window = W.orderBy(F.col("period_start").asc())
next_period_start = F.lead(F.col("period_start")).over(window)
next_period_end = F.lead(F.col("period_end")).over(window)
prev_period_start = F.lag(F.col("period_start")).over(window)
prev_period_end = F.lag(F.col("period_end")).over(window)
next_next_period_start = F.lead(F.col("period_start"), 2).over(window)
next_next_period_end = F.lead(F.col("period_end"), 2).over(window)
interval = F.expr(f"INTERVAL 1 {periodicity}")
    cutoff_interval = interval - F.expr("INTERVAL 1 DAY")
df = (
date_df.select(
F.col("period_start").cast("date").alias("period_start"),
F.col("period_end").cast("date").alias("period_end"),
next_period_start.cast("date").alias("next_period_start"),
next_period_end.cast("date").alias("next_period_end"),
prev_period_start.cast("date").alias("prev_period_start"),
prev_period_end.cast("date").alias("prev_period_end"),
next_next_period_start.cast("date").alias("next_next_period_start"),
next_next_period_end.cast("date").alias("next_next_period_end"),
)
.withColumns(
{
"prev_period_start": prev_period_start,
"prev_period_end": prev_period_end,
}
)
.where(F.col("period_start") >= start_date)
)
return df
def _filter_for_panel(
input_df: DataFrame,
panel_cutoff: date,
user_create_date_column: Column,
panel_end_date_column: Column,
) -> DataFrame:
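    # Keep rows for users who joined before the current period started and whose panel
    # membership lasts through the end of the following period (or whose panel end date
    # equals the latest observed panel cutoff).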
return input_df.where(user_create_date_column < F.col("period_start")).where(
E.any(
[
panel_end_date_column.cast("date") > F.col("next_period_end"),
panel_end_date_column.cast("date") == panel_cutoff,
]
)
)
def _metric_name_with_slices(
metric_name_column: Column, slice_columns: list[Column]
) -> Column:
# (revenue, []) -> "revenue"
# (revenue, [country, state]) -> "revenue - country, state"
if len(slice_columns) == 0:
return metric_name_column
separator = F.lit(", ")
concat_columns = [metric_name_column, F.lit(" - ")]
for idx, slice_column in enumerate(slice_columns):
concat_columns.append(slice_column)
if idx != (len(slice_columns) - 1):
concat_columns.append(separator)
metric_name_with_slice = F.concat(*concat_columns)
return metric_name_with_slice
def _check_column_on_df(
df: DataFrame,
column: Column,
    exception_message: str | None = None,
):
# Workaround to surface AnalysisException for invalid columns on a lazy dataframe in spark connect
try:
df.select(column).limit(0).first()
except (AnalysisException, AnalysisConnectException):
raise InvalidInputException(
exception_message or f"{column} does not exist on dataframe"
)