Source code for etl_toolkit.analyses.index

from datetime import datetime, date
from typing import Literal

from pyspark.sql import functions as F, types as T, Column, DataFrame, Window as W
from pyspark.sql.connect.dataframe import DataFrame as ConnectDataFrame
from pyspark.sql.connect.column import Column as ConnectColumn
from pyspark.errors.exceptions.connect import (
    AnalysisException as AnalysisConnectException,
)
from pyspark.errors import AnalysisException
from pyspark.sql import SparkSession
from yipit_databricks_utils.helpers.telemetry import track_usage

from etl_toolkit import expressions as E
from etl_toolkit.exceptions import InvalidColumnTypeException, InvalidInputException
from etl_toolkit.analyses.scalar import get_column_type, get_aggregates
from etl_toolkit.analyses.ordering import shift_columns
from etl_toolkit.analyses.calculation import add_percent_of_total_columns
from etl_toolkit.analyses.time import periods


RANGE_JOIN_BIN_SIZE = {
    "WEEK": 7,
    "MONTH": 30,
    "QUARTER": 90,
}
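
# The bin sizes above approximate the number of days in each panel periodicity and
# are passed to Spark's "range_join" hint when event dates are joined against
# period start/end ranges below.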


@track_usage
def index_from_rolling_panel(
    input_df: DataFrame,
    start_date: datetime | date,
    date_column: Column | str,
    metrics: dict[str, dict],
    slices: list[str | Column] = None,
    panel_periodicity: Literal["WEEK", "MONTH", "QUARTER"] = "MONTH",
    index_periodicity: Literal["DAY", "WEEK", "MONTH", "QUARTER"] = "DAY",
    panel_end_date_column: Column | str = "panel_end_date",
    user_create_date_column: Column | str = "create_time",
    spark: SparkSession = None,
) -> DataFrame:
    """
    Returns a dataframe that maps the input dataframe to a series of indices, one for each
    metric provided. The indices are calculated through an aggregate function, a base value,
    and a rolling panel methodology, where the panel is fixed for 2 consecutive periods.
    The period length can be configured using ``panel_periodicity``.

    Indices can be broken down across various ``slices`` to visualize trends on key segments
    of the dataset. Multiple metrics can be specified; each appears as a separate set of rows
    identified by the ``metric_name`` column in the output dataframe.

    .. note::
        This function is expected to work with Edison (e-receipt) derived dataframes.
        The rolling panel methodology is designed around the fluctuating behavior of the
        Edison user panel.

    :param input_df: The input dataframe to use to generate a series of indices.
    :param start_date: The minimum start date of the ``input_df``; all metrics will be
        filtered on or after this date.
    :param date_column: A Column or string that indicates the date column of the ``input_df``
        to use in paneling and index logic. If a string is used, it is referenced as a Column.
    :param metrics: A dict of metric configurations to generate indices. An index will be
        created for each metric defined in this dictionary. Each metric config must have a
        ``calculation`` Column and a ``start_value`` defined for the index.
    :param slices: Optional list of strings or Columns that define the slices within the
        dataframe. If slices are used, the indices generated will be for each slice within
        the ``input_df``. If strings are used in the list, they are resolved as Columns.
    :param panel_periodicity: The unit of length for the interval of each paneling window.
    :param index_periodicity: The unit of length for the interval of each indexing window.
        A row will be generated for each period based on this argument.
    :param panel_end_date_column: The Column to use when determining the exit date of a user
        from the panel in the rolling panel methodology.
    :param user_create_date_column: The Column to use when determining the start date of a user
        in the panel in the rolling panel methodology.
    :param spark: Optional argument to pass in a spark session to this function. Normally this
        is not needed and a spark session is created automatically; it is typically only used
        by library developers.

    Examples
    ^^^^^^^^^^^^^^^^

    .. code-block:: python
        :caption: Example of a daily index creation for Lyft. This creates two indices,
            ``rides`` and ``total_charges``, based on different aggregations. The paneling
            is a rolling panel based on a weekly interval.

        from etl_toolkit import E, A, F

        source_df = spark.table(f"{DESTINATION_CATALOG}.{SILVER_DATABASE}.deduped_receipts")
        index_df = A.index_from_rolling_panel(
            source_df,
            start_date=date(2017, 1, 2),
            date_column="adjusted_order_date",
            metrics={
                "rides": {
                    "calculation": F.sum("geo_weight"),
                    "start_value": 13.471719625795382,
                },
                "total_charges": {
                    "calculation": F.sum(total_charges),
                    "start_value": 13.471719625795382,
                },
            },
            panel_periodicity="WEEK",
            index_periodicity="DAY",
        )
    """
    _validate_inputs(input_df, start_date, panel_periodicity)

    # Normalize date/timestamp specific columns into a date type after validation
    date_column = E.normalize_column(date_column)
    user_create_date_column = E.normalize_column(user_create_date_column)
    panel_end_date_column = E.normalize_column(panel_end_date_column)
    _validate_date_column(input_df, date_column)
    date_column_name = input_df.select(date_column).columns[0]
    date_column = date_column.cast("date").alias(date_column_name)

    # Validate metric columns, should raise analysis exception if not valid
    _validate_metrics(input_df, metrics)
    index_metrics = [
        config["calculation"].alias(metric_name)
        for metric_name, config in metrics.items()
    ]

    # Validate slices, if used
    if slices is None:
        slices = []
    else:
        slices = [E.normalize_column(slice_column) for slice_column in slices]
        _validate_slices(input_df, slices)

    # Normalize input data by truncating to start_date and enriching with period columns for paneling logic
    max_data_through = get_aggregates(input_df, date_column, ["max"])["max"]
    max_panel_end = get_aggregates(
        input_df, panel_end_date_column.cast("date"), ["max"]
    )["max"]
    periods_df = _get_current_and_future_periods(
        start_date, max_data_through, periodicity=panel_periodicity, spark=spark
    )
    normalized_input = input_df.filter(date_column >= start_date).join(
        periods_df.hint("range_join", RANGE_JOIN_BIN_SIZE[panel_periodicity]),
        E.between(date_column, "period_start", "period_end"),
        how="left",
    )

    # We want to generate 2 aggregations to create a rolling panel index
    # - Aggregation #1 includes panel members that are at least through the current period end
    # - Aggregation #2 includes panel members that are at least through the prior period end
    # - Then we join #1 to #2 on a lag to determine period-over-period growth rates and back into an index value
    current_period_aggregates = (
        normalized_input.where(
            user_create_date_column
            < F.coalesce(F.col("prev_period_start"), F.lit(start_date))
        )
        .where(
            E.any(
                [
                    panel_end_date_column.cast("date") > F.col("period_end"),
                    panel_end_date_column.cast("date") == max_panel_end,
                ]
            )
        )
        .groupBy("period_start", "period_end")
        .agg(*index_metrics)
        .unpivot(
            ["period_start", "period_end"],
            [x for x in metrics],
            "metric_name",
            "value",
        )
    )
    prev_period_aggregates = (
        _filter_for_panel(
            normalized_input,
            max_panel_end,
            user_create_date_column,
            panel_end_date_column,
        )
        .groupBy("period_end", "next_period_start")
        .agg(*index_metrics)
        .unpivot(
            ["period_end", "next_period_start"],
            [metric_name for metric_name in metrics],
            "metric_name",
            "value",
        )
    )
    period_over_period_growth = F.coalesce(
        F.try_divide(current_period_aggregates.value, prev_period_aggregates.value) - 1,
        F.lit(0),
    )
    generate_index_window = W.partitionBy(
        current_period_aggregates.metric_name
    ).orderBy(current_period_aggregates.period_end.asc())
    generate_index_from_growth_rates = F.exp(
        F.sum(F.log(1 + period_over_period_growth)).over(generate_index_window)
    ) * F.col("start_value")
    index_df = (
        current_period_aggregates.withColumn(
            "start_value",
            E.chain_assigns(
                [
                    E.assign(
                        F.lit(config["start_value"]),
                        F.col("metric_name") == metric_name,
                    )
                    for metric_name, config in metrics.items()
                ]
            ),
        )
        .join(
            prev_period_aggregates,
            E.all(
                [
                    current_period_aggregates.period_start
                    == prev_period_aggregates.next_period_start,
                    current_period_aggregates.metric_name
                    == prev_period_aggregates.metric_name,
                ]
            ),
            how="left",
        )
        .select(
            current_period_aggregates.period_start,
            current_period_aggregates.period_end,
            current_period_aggregates.metric_name,
            period_over_period_growth.alias("period_growth"),
            generate_index_from_growth_rates.alias("index_value"),
        )
    )

    # If slices are specified, re-aggregate values using the slices + day, otherwise just day,
    # and then join to the index_df to calculate the percentage of the daily index to allocate
    # to each day (+ slice). Index values need to be calculated across slices first to account
    # for distribution bias or incompleteness within each slice.
    group_by_columns = [date_column, "period_start", "period_end"]
    if len(slices):
        group_by_columns = group_by_columns + slices

    # Aggregate metrics by day and optionally slices,
    # then aggregate metrics again by period and get the percent of the day (+ slice);
    # this will be used to scale the daily index in the final output
    daily_df = (
        _filter_for_panel(
            normalized_input,
            max_panel_end,
            user_create_date_column,
            panel_end_date_column,
        )
        .groupBy(group_by_columns)
        .agg(*index_metrics)
    )
    daily_mix_df = add_percent_of_total_columns(
        daily_df,
        value_columns=[metric_name for metric_name in metrics],
        total_grouping_columns=["period_end"],
    ).unpivot(
        group_by_columns,
        [F.col(f"{metric_name}_percent").alias(metric_name) for metric_name in metrics],
        "metric_name",
        "percent_of_total",
    )
    daily_index_df = daily_mix_df.join(
        index_df.hint("range_join", RANGE_JOIN_BIN_SIZE[panel_periodicity]),
        E.all(
            [
                E.between(date_column_name, index_df.period_start, index_df.period_end),
                index_df.metric_name == daily_mix_df.metric_name,
            ]
        ),
        how="inner",
    ).select(
        F.col(date_column_name).alias("date"),
        index_df.period_start.alias("period_start"),
        index_df.period_end.alias("period_end"),
        _metric_name_with_slices(
            daily_mix_df.metric_name,
            slices,
        ).alias("metric_name"),
        (F.col("index_value") * F.col("percent_of_total")).alias("index_value"),
        index_df.metric_name.alias("base_metric_name"),
        *slices,
    )
    final_df = shift_columns(
        daily_index_df,
        [
            "date",
            "period_start",
            "period_end",
            "metric_name",
            E.growth_rate_by_lag(
                "index_value",
                W.partitionBy("metric_name").orderBy("date"),
                default=0,
            ).alias("daily_growth"),
            "index_value",
            "base_metric_name",
        ],
    )

    # Gross up to higher periodicities if index periodicity is not daily
    if index_periodicity in ["WEEK", "MONTH", "QUARTER"]:
        index_date_df = periods(
            start_date,
            max_data_through,
            steps=1,
            step_unit=index_periodicity,
            spark=spark,
        )
        final_df = (
            final_df.join(
                index_date_df.hint(
                    "range_join", RANGE_JOIN_BIN_SIZE[index_periodicity]
                ),
                E.between(
                    final_df.date, index_date_df.period_start, index_date_df.period_end
                ),
                how="left",
            )
            .groupBy(
                index_date_df.period_start,
                index_date_df.period_end,
                "metric_name",
                "base_metric_name",
                *slices,
            )
            .agg(F.sum("index_value").alias("index_value"))
            .select(
                F.col("period_start").alias("date"),
                "period_start",
                "period_end",
                "metric_name",
                E.growth_rate_by_lag(
                    "index_value",
                    W.partitionBy("metric_name").orderBy("period_start"),
                    default=0,
                ).alias("period_growth"),
                "index_value",
                "base_metric_name",
                *slices,
            )
        )
    return final_df
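
# Illustrative sketch (not part of the library source): how the chained
# period-over-period growth rates above back into an index value. Assuming a
# hypothetical start_value of 100 and growth rates of +10%, -5%, and +2% across
# three periods, index_value = start_value * exp(sum(log(1 + g))) over the ordered
# growth rates, which is equivalent to cumulatively multiplying the (1 + g) terms:
#
#     import math
#
#     start_value = 100.0                 # hypothetical base value
#     growth_rates = [0.10, -0.05, 0.02]  # hypothetical period-over-period growth
#     index_values = [
#         start_value * math.exp(sum(math.log(1 + g) for g in growth_rates[: i + 1]))
#         for i in range(len(growth_rates))
#     ]
#     # index_values ≈ [110.0, 104.5, 106.59]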


def _validate_inputs(
    input_df: DataFrame, start_date: date | datetime, panel_periodicity: str
):
    if panel_periodicity not in ["WEEK", "MONTH", "QUARTER"]:
        raise InvalidInputException(
            f'Invalid panel_periodicity {panel_periodicity}, must be one of "WEEK", "MONTH", "QUARTER"'
        )
    if not isinstance(start_date, (date, datetime)):
        raise InvalidInputException(
            f"Invalid start_date, {start_date}, must be a datetime or date python object"
        )
    if not isinstance(input_df, (DataFrame, ConnectDataFrame)):
        raise InvalidInputException(
            f"Invalid input_df, {input_df}, must be a dataframe"
        )


def _validate_date_column(input_df: DataFrame, date_column: Column):
    _check_column_on_df(
        input_df,
        date_column,
        f"Provided date_column {date_column} does not exist or cannot be queried on the input_df, change the column",
    )
    if get_column_type(input_df, date_column) not in ["date", "timestamp"]:
        raise InvalidColumnTypeException(
            f"Invalid date_column {date_column}, must be a date or timestamp type"
        )


def _validate_slices(input_df: DataFrame, slice_columns: list[Column]):
    # Validate slice columns exist on dataframe
    for slice_column in slice_columns:
        _check_column_on_df(
            input_df,
            slice_column,
            f"Provided slice {slice_column} does not exist or cannot be queried on the input_df, change the column",
        )


def _validate_metrics(input_df: DataFrame, metrics: dict):
    if not isinstance(metrics, dict):
        raise InvalidInputException(
            "Provided metrics must be a dict, with keys representing the metric names and the value as the metric configuration"
        )
    for metric_name, configuration in metrics.items():
        if "calculation" not in configuration:
            raise InvalidInputException(
                f"Provided metric {metric_name} must have a calculation that is a pyspark aggregate expression"
            )
        if not isinstance(configuration["calculation"], (Column, ConnectColumn)):
            raise InvalidInputException(
                f"Provided metric {metric_name} must have a calculation that is a pyspark aggregate expression"
            )
        _check_column_on_df(
            input_df,
            configuration["calculation"],
            f'Provided metric calculation {configuration["calculation"]} does not exist or cannot be queried on the input_df, change the column',
        )
        if "start_value" not in configuration:
            raise InvalidInputException(
                f"Provided metric {metric_name} must have a start value that is an int or float"
            )
        if not isinstance(configuration["start_value"], (int, float)):
            raise InvalidInputException(
                f"Provided metric {metric_name} must have a start value that is an int or float"
            )


def _get_current_and_future_periods(
    start_date: datetime | date,
    end_date: datetime | date = None,
    periodicity: str = "DAY",
    spark: SparkSession = None,
) -> DataFrame:
    end_date = end_date or datetime.utcnow()
    date_df = periods(start_date, end_date, steps=1, step_unit=periodicity, spark=spark)
    window = W.orderBy(F.col("period_start").asc())
    next_period_start = F.lead(F.col("period_start")).over(window)
    next_period_end = F.lead(F.col("period_end")).over(window)
    prev_period_start = F.lag(F.col("period_start")).over(window)
    prev_period_end = F.lag(F.col("period_end")).over(window)
    next_next_period_start = F.lead(F.col("period_start"), 2).over(window)
    next_next_period_end = F.lead(F.col("period_end"), 2).over(window)
    interval = F.expr(f"INTERVAL 1 {periodicity}")
    cutoff_interval = interval - F.expr(f"INTERVAL 1 DAY")
    df = (
        date_df.select(
            F.col("period_start").cast("date").alias("period_start"),
            F.col("period_end").cast("date").alias("period_end"),
            next_period_start.cast("date").alias("next_period_start"),
            next_period_end.cast("date").alias("next_period_end"),
            prev_period_start.cast("date").alias("prev_period_start"),
            prev_period_end.cast("date").alias("prev_period_end"),
            next_next_period_start.cast("date").alias("next_next_period_start"),
            next_next_period_end.cast("date").alias("next_next_period_end"),
        )
        .withColumns(
            {
                "prev_period_start": prev_period_start,
                "prev_period_end": prev_period_end,
            }
        )
        .where(F.col("period_start") >= start_date)
    )
    return df


def _filter_for_panel(
    input_df: DataFrame,
    panel_cutoff: date,
    user_create_date_column: Column,
    panel_end_date_column: Column,
) -> DataFrame:
    return input_df.where(user_create_date_column < F.col("period_start")).where(
        E.any(
            [
                panel_end_date_column.cast("date") > F.col("next_period_end"),
                panel_end_date_column.cast("date") == panel_cutoff,
            ]
        )
    )


def _metric_name_with_slices(
    metric_name_column: Column, slice_columns: list[Column]
) -> Column:
    # (revenue, []) -> "revenue"
    # (revenue, [country, state]) -> "revenue - country, state"
    if len(slice_columns) == 0:
        return metric_name_column
    separator = F.lit(", ")
    concat_columns = [metric_name_column, F.lit(" - ")]
    for idx, slice_column in enumerate(slice_columns):
        concat_columns.append(slice_column)
        if idx != (len(slice_columns) - 1):
            concat_columns.append(separator)
    metric_name_with_slice = F.concat(*concat_columns)
    return metric_name_with_slice


def _check_column_on_df(
    df: DataFrame,
    column: Column,
    exception_message: str = None,
):
    # Workaround to surface AnalysisException for invalid columns on a lazy dataframe in spark connect
    try:
        df.select(column).limit(0).first()
    except (AnalysisException, AnalysisConnectException):
        raise InvalidInputException(
            exception_message or f"{column} does not exist on dataframe"
        )
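

# Illustrative usage sketch (not part of the library source): calling
# index_from_rolling_panel with slices so that a separate index series is produced
# per slice value. The catalog, table, and column names below are hypothetical
# placeholders.
#
#     from datetime import date
#     from pyspark.sql import functions as F
#     from etl_toolkit import A
#
#     receipts_df = spark.table("example_catalog.example_db.deduped_receipts")
#     sliced_index_df = A.index_from_rolling_panel(
#         receipts_df,
#         start_date=date(2020, 1, 6),
#         date_column="adjusted_order_date",
#         metrics={
#             "orders": {"calculation": F.count("*"), "start_value": 100.0},
#         },
#         slices=["region"],
#         panel_periodicity="MONTH",
#         index_periodicity="WEEK",
#     )
#
#     # Each output row carries a metric_name such as "orders - <region value>",
#     # giving one weekly index series per region.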