Source code for etl_toolkit.analyses.standard_metrics.unified_kpi

import json
from datetime import date
from dateutil.relativedelta import relativedelta
from typing import Any, Dict, List, Literal, Optional

from pyspark.sql import functions as F, Window as W, Column, DataFrame

from yipit_databricks_utils.helpers.telemetry import track_usage
from yipit_databricks_utils.helpers.pyspark_utils import get_spark_session

from etl_toolkit import E
from etl_toolkit.analyses.calculation import (
    add_lag_columns,
    add_percent_of_total_columns,
)
from etl_toolkit.analyses.scalar import get_aggregates
from etl_toolkit.exceptions import InvalidInputException
from etl_toolkit.analyses.standard_metrics.config import (
    entity_configuration,
    standard_metric_metadata,
    standard_metric_configuration,
)
from etl_toolkit.analyses.standard_metrics.helpers import (
    SLICE_COLUMN_COUNT,
    SLICE_NAME_VALUE_COLUMNS,
    analysis_name_clean,
    STANDARD_DECIMAL_TYPE,
    GRANULARITY_ORDER_MAPPING,
)


# Conversion factors expressing how many periods of each granularity fit into one week.
# Used with a 52-week year to translate a year offset into a lag measured in periods.
calendar_week_factor = {
    "DAY": 7,  # 7 days per week
    "WEEK": 1,  # Base unit
    "MONTH": (1 / 4),  # ~1/4 of a month per week
    "QUARTER": (1 / 13),  # ~1/13 of a quarter per week
    "HALF_YEAR": (1 / 26),  # ~1/26 of a half year per week
    "YEAR": (1 / 52),  # 1/52 of a year per week
}
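# Illustrative arithmetic (not part of the original module): in the 52-week branch of
# _get_growth_rate_calculation_expression below, the prior-period lag is
#   lag_period = int(year * 52 * calendar_week_factor[analysis_periodicity])
# e.g. a 2-year lag at QUARTER periodicity is int(2 * 52 * (1 / 13)) = 8 periods,
# and a 1-year lag at WEEK periodicity is int(1 * 52 * 1) = 52 periods.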

# Mapping of standard calendar column names to period-specific column names
# Used to maintain consistency in column naming across different calendar periods
calendar_metadata_columns_map = {
    "week_period_start": "week_start_period",
    "week_period_end": "week_end_period",
    "month_period_start": "month_start_period",
    "month_period_end": "month_end_period",
    "quarter_period_start": "quarter_start_period",
    "quarter_period_end": "quarter_end_period",
    "year_period_start": "year_start_period",
    "year_period_end": "year_end_period",
    "year_label": "year_label_period",
    "quarter_label": "quarter_label_period",
}
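# Illustrative example (see _tall_format below): each mapped name is suffixed with "_start"
# or "_end" depending on whether the calendar is joined on the period start or the
# period-through date, e.g. "week_period_start" maps to "week_start_period" and surfaces
# as the output columns "week_start_period_start" and "week_start_period_end".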


def _validate_min_max_granularity(
    minimum_granularity: str,
    maximum_granularity: str,
    display_period_granularity: str,
    report_period_granularity: str,
    source_table_granularity: str,
):
    """
    Validate input granularities.
    """
    if report_period_granularity == "HALF_YEAR" and source_table_granularity != "DAY":
        raise ValueError(
            "Invalid granularity configuration. 'HALF_YEAR' report_period_granularity is only supported for 'DAY' source_table_granularity."
        )

    maximum_granularity_order = GRANULARITY_ORDER_MAPPING[maximum_granularity]
    minimum_granularity_order = GRANULARITY_ORDER_MAPPING[minimum_granularity]
    report_granularity_order = GRANULARITY_ORDER_MAPPING[report_period_granularity]
    display_granularity_order = GRANULARITY_ORDER_MAPPING[display_period_granularity]

    if maximum_granularity_order < minimum_granularity_order:
        raise ValueError(
            f"Invalid granularity configuration. Maximum granularity '{maximum_granularity}' must be the same as or more granular than minimum granularity '{minimum_granularity}'."
        )

    if maximum_granularity_order < report_granularity_order:
        raise ValueError(
            f"Invalid granularity configuration. Maximum granularity '{maximum_granularity}' must be the same as or more granular than report period granularity '{report_period_granularity}'."
        )

    if maximum_granularity_order < display_granularity_order:
        raise ValueError(
            f"Invalid granularity configuration. Maximum granularity '{maximum_granularity}' must be the same as or more granular than display period granularity '{display_period_granularity}'."
        )


def _get_simple_aggregate_configs(
    periodicity,
    adjust_days,
    aggregate_type,
    ptd_date_over_date,
    source_table_granularity,
    slices=None,
    trailing_params=None,
) -> Dict[str, Any]:
    """
    Generate a configuration dictionary for simple aggregate calculations such as SUM, AVG, etc. across different time periods.
    """
    # Build the basic configuration for simple aggregation
    agg = {
        "name": "SIMPLE_AGGREGATE",
        "aggregate_function": aggregate_type,
        "adjust_to_days_in_current_period": adjust_days,
        "ptd_date_over_date": adjust_days
        and ptd_date_over_date,  # Only enable ptd comparison if we're adjusting for days in period
        "slices": slices or [],  # Default to empty list if no slices provided
        "analysis_periodicity": periodicity,  # Determines the time buckets for aggregation
        "source_table_granularity": source_table_granularity,  # Needed to properly roll up source data
    }

    # Add trailing period parameters if provided (for T7D, etc.)
    if trailing_params:
        agg.update(trailing_params)
    return agg
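
# Illustrative sketch (example inputs are hypothetical): a period-to-date quarterly SUM
# over daily source data would be configured as
#   _get_simple_aggregate_configs("QUARTER", True, "SUM", True, "DAY")
# which returns
#   {"name": "SIMPLE_AGGREGATE", "aggregate_function": "SUM",
#    "adjust_to_days_in_current_period": True, "ptd_date_over_date": True,
#    "slices": [], "analysis_periodicity": "QUARTER", "source_table_granularity": "DAY"}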


def _get_growth_rate_configs(
    periodicity,
    adjust_days,
    calendar_type,
    aggregate_type,
    max_relevant_years,
    growth_rate_type,
    source_table_granularity,
    ptd_date_over_date,
    slices=None,
    trailing_params=None,
    current_period_limit_days=None,
) -> Dict[str, Any]:
    """
    Generate a configuration dictionary for growth rate calculations, which compare metrics across different time periods to measure change.
    """
    # For monthly periodicity, we always use EXACT_N_YEARS calendar type
    # This ensures consistent month-over-month comparisons
    calendar_type_for_growth_rate = (
        "EXACT_N_YEARS" if periodicity == "MONTH" else calendar_type
    )

    # For period-to-date-minus-7-days comparisons, we only need to look back 1 year
    override_max_relevant_years = 1 if (current_period_limit_days == 7) else None

    # Build the basic configuration for growth rate calculation
    growth = {
        "name": "GROWTH_RATE",
        "aggregate_function": aggregate_type,
        "max_relevant_years": override_max_relevant_years
        or max_relevant_years,  # Use override if set, otherwise use configured value
        "calendar_type": calendar_type_for_growth_rate,  # Determines how periods are compared (52-week year vs exact calendar year)
        "growth_rate_type": growth_rate_type,  # CAGR or Simple
        "slices": slices or [],  # Default to empty list if no slices provided
        "analysis_periodicity": periodicity,  # Time bucket for growth rate calculation
        "source_table_granularity": source_table_granularity,
        "adjust_to_days_in_current_period": adjust_days,
        "ptd_date_over_date": adjust_days
        and ptd_date_over_date,  # Only enable ptd comparison if we're adjusting for days
    }

    # Add trailing period parameters if provided (for T7D, etc.)
    if trailing_params:
        growth.update(trailing_params)

    # Current period limit days is used to reduce the number of days considered in the current period
    # For example, current_period_limit_days = 7 will exclude only the last 7 days of the current period
    if current_period_limit_days:
        growth.update({"current_period_limit_days": current_period_limit_days})
        growth[
            "ptd_date_over_date"
        ] = False  # Disable ptd comparison when using period limit
    return growth
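
# Illustrative sketch (example inputs are hypothetical): passing current_period_limit_days=7
# caps the lookback at one year and disables the period-to-date comparison, e.g.
#   cfg = _get_growth_rate_configs(
#       periodicity="QUARTER", adjust_days=True, calendar_type="EXACT_N_YEARS",
#       aggregate_type="SUM", max_relevant_years=3, growth_rate_type="SIMPLE",
#       source_table_granularity="DAY", ptd_date_over_date=True,
#       current_period_limit_days=7,
#   )
# yields cfg["max_relevant_years"] == 1 and cfg["ptd_date_over_date"] is False.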


def _get_analysis_function_list(
    standard_metric_metadata: standard_metric_metadata,
    standard_metric_configuration: standard_metric_configuration,
) -> List[Dict[str, Any]]:
    """
    Generate the complete set of analysis configurations based on
    the metric metadata and configuration parameters. This determines which
    calculations to perform and how to structure them.
    """

    # Extract configuration parameters for easier access
    report_period_granularity = standard_metric_metadata.report_period_granularity
    display_period_granularity = standard_metric_metadata.display_period_granularity
    source_table_granularity = standard_metric_configuration.source_table_granularity
    aggregate_type = standard_metric_configuration.aggregate_function
    growth_rate_type = standard_metric_configuration.growth_rate_type
    max_relevant_years = standard_metric_configuration.max_relevant_years
    calendar_type = standard_metric_configuration.calendar_type
    slice_columns = standard_metric_configuration.slice_columns
    trailing_period_length = standard_metric_configuration.trailing_period_length
    trailing_period_aggregate_function = (
        standard_metric_configuration.trailing_period_aggregate_function
    )
    custom_function = standard_metric_configuration.custom_function
    minimum_granularity = (
        standard_metric_configuration.minimum_granularity or report_period_granularity
    )
    maximum_granularity = (
        standard_metric_configuration.maximum_granularity or source_table_granularity
    )
    exclude_month_to_date = standard_metric_configuration.exclude_month_to_date
    exclude_quarter_to_date = standard_metric_configuration.exclude_quarter_to_date

    _validate_min_max_granularity(
        minimum_granularity,
        maximum_granularity,
        display_period_granularity,
        report_period_granularity,
        source_table_granularity,
    )

    # For weekly display, always use 7 days as the trailing period
    # This ensures consistency in weekly reporting
    if display_period_granularity == "WEEK":
        trailing_period_length = 7

    # Default to include quarterly and monthly analyses
    periodicity_list = ["QUARTER", "MONTH"]

    # Add half-year analysis if configured for half-year reporting
    if standard_metric_metadata.report_period_granularity == "HALF_YEAR" and (
        not minimum_granularity or minimum_granularity == "HALF_YEAR"
    ):
        periodicity_list.append("HALF_YEAR")

    if source_table_granularity == "QUARTER":
        periodicity_list = [
            "QUARTER"
        ]  # Remove monthly if source granularity is quarterly
        # half_year is only supported for daily source granularities
    elif minimum_granularity and minimum_granularity == "MONTH":
        periodicity_list = [
            "MONTH"
        ]  # Remove quarterly if minimum granularity is monthly
    elif minimum_granularity and minimum_granularity == "DAY":
        periodicity_list = []

    # For EXACT_N_YEARS calendar type, enable period-to-date date-over-date comparison
    # This allows comparing partial periods (e.g., QTD this year vs QTD last year)
    ptd_date_over_date = calendar_type == "EXACT_N_YEARS"

    # If custom_function is PCT_OF_TOTAL or EXCLUDE_TOTAL, exclude the total (non-sliced)
    # analyses. By default, we include non-sliced analyses
    include_total = True
    if custom_function and custom_function in ["PCT_OF_TOTAL", "EXCLUDE_TOTAL"]:
        include_total = False

    if custom_function and custom_function == "PCT_OF_TOTAL":
        # Set growth_rate_type and trailing_period_length to None to avoid calculating the analyses
        # before nominal values are converted to a percentage for PCT_OF_TOTAL metrics
        growth_rate_type = None
        trailing_period_length = None

    analysis_function_list = []
    if aggregate_type:
        for periodicity in periodicity_list:
            # For each non-daily periodicity, add simple aggregate and growth rate analysis
            # configurations to analysis_function_list. By default, this includes full-period non-sliced,
            # period-to-date non-sliced, full-period sliced, and period-to-date sliced analyses.
            # If exclude_ptd_analyses = True, period-to-date analyses are excluded.
            # If include_total is False, non-sliced analyses are excluded.
            exclude_ptd_analyses = False
            if periodicity == source_table_granularity:
                # If the periodicity for this analysis group matches the source granularity, we exclude
                # period-to-date analyses
                exclude_ptd_analyses = True
            elif exclude_month_to_date and periodicity == "MONTH":
                exclude_ptd_analyses = True
            elif exclude_quarter_to_date and periodicity == "QUARTER":
                exclude_ptd_analyses = True

            ptd_simple_agg_config = _get_simple_aggregate_configs(
                periodicity=periodicity,
                adjust_days=True,
                aggregate_type=aggregate_type,
                ptd_date_over_date=ptd_date_over_date,
                source_table_granularity=source_table_granularity,
            )

            full_period_simple_agg_config = _get_simple_aggregate_configs(
                periodicity=periodicity,
                adjust_days=False,
                aggregate_type=aggregate_type,
                ptd_date_over_date=ptd_date_over_date,
                source_table_granularity=source_table_granularity,
            )

            if include_total:
                analysis_function_list.append(full_period_simple_agg_config)

                if not exclude_ptd_analyses:
                    analysis_function_list.append(ptd_simple_agg_config)

            if slice_columns:
                sliced_full_period_simple_agg_config = (
                    full_period_simple_agg_config.copy()
                )
                sliced_full_period_simple_agg_config["slices"] = slice_columns
                analysis_function_list.append(sliced_full_period_simple_agg_config)

                if not exclude_ptd_analyses:
                    sliced_ptd_simple_agg_config = ptd_simple_agg_config.copy()
                    sliced_ptd_simple_agg_config["slices"] = slice_columns
                    analysis_function_list.append(sliced_ptd_simple_agg_config)

            if growth_rate_type:
                ptd_growth_rate_config = _get_growth_rate_configs(
                    periodicity=periodicity,
                    adjust_days=True,
                    calendar_type=calendar_type,
                    aggregate_type=aggregate_type,
                    max_relevant_years=max_relevant_years,
                    growth_rate_type=growth_rate_type,
                    source_table_granularity=source_table_granularity,
                    ptd_date_over_date=ptd_date_over_date,
                )

                full_period_growth_rate_config = _get_growth_rate_configs(
                    periodicity=periodicity,
                    adjust_days=False,
                    calendar_type=calendar_type,
                    aggregate_type=aggregate_type,
                    max_relevant_years=max_relevant_years,
                    growth_rate_type=growth_rate_type,
                    source_table_granularity=source_table_granularity,
                    ptd_date_over_date=ptd_date_over_date,
                )

                if include_total:
                    analysis_function_list.append(full_period_growth_rate_config)

                    if not exclude_ptd_analyses:
                        analysis_function_list.append(ptd_growth_rate_config)

                if slice_columns:
                    sliced_full_period_growth_rate_config = (
                        full_period_growth_rate_config.copy()
                    )
                    sliced_full_period_growth_rate_config["slices"] = slice_columns
                    analysis_function_list.append(sliced_full_period_growth_rate_config)

                    if not exclude_ptd_analyses:
                        sliced_ptd_growth_rate_config = ptd_growth_rate_config.copy()
                        sliced_ptd_growth_rate_config["slices"] = slice_columns
                        analysis_function_list.append(sliced_ptd_growth_rate_config)

    if source_table_granularity == "DAY":
        # If the source granularity is DAY, we also add daily simple aggregate and growth rate analyses.
        # By default, this includes trailing & sliced, trailing & non-sliced, non-trailing & sliced,
        # and non-trailing & non-sliced analyses.
        # If include_total is False, non-sliced analyses are excluded.
        if growth_rate_type:
            if include_total:
                analysis_function_list.append(
                    _get_growth_rate_configs(
                        periodicity="QUARTER",
                        adjust_days=True,
                        calendar_type=calendar_type,
                        aggregate_type=aggregate_type,
                        max_relevant_years=max_relevant_years,
                        growth_rate_type=growth_rate_type,
                        source_table_granularity=source_table_granularity,
                        ptd_date_over_date=ptd_date_over_date,
                        current_period_limit_days=7,
                    )
                )

            if trailing_period_length:
                trailing_params = {
                    "trailing_period_length": trailing_period_length,
                    "trailing_period_aggregate_function": trailing_period_aggregate_function,
                }

                trailing_day_growth_rate_config = _get_growth_rate_configs(
                    periodicity=source_table_granularity,
                    adjust_days=False,
                    calendar_type=calendar_type,
                    aggregate_type=aggregate_type,
                    max_relevant_years=max_relevant_years,
                    growth_rate_type=growth_rate_type,
                    source_table_granularity=source_table_granularity,
                    ptd_date_over_date=ptd_date_over_date,
                    trailing_params=trailing_params,
                )
                if include_total:
                    analysis_function_list.append(trailing_day_growth_rate_config)

                if slice_columns:
                    trailing_sliced_day_growth_rate_config = (
                        trailing_day_growth_rate_config.copy()
                    )
                    trailing_sliced_day_growth_rate_config["slices"] = slice_columns
                    analysis_function_list.append(
                        trailing_sliced_day_growth_rate_config
                    )

        if include_total:
            aggregate_list = ["SUM"]
            if aggregate_type and aggregate_type != "SUM":
                aggregate_list.append(aggregate_type)
            for operator in aggregate_list:
                analysis_function_list.append(
                    _get_simple_aggregate_configs(
                        periodicity=source_table_granularity,
                        adjust_days=False,
                        aggregate_type=operator,
                        ptd_date_over_date=ptd_date_over_date,
                        source_table_granularity=source_table_granularity,
                    )
                )

        if slice_columns:
            aggregate_type_for_slice = aggregate_type or "SUM"

            sliced_day_simple_agg_config = _get_simple_aggregate_configs(
                periodicity=source_table_granularity,
                adjust_days=False,
                aggregate_type=aggregate_type_for_slice,
                ptd_date_over_date=ptd_date_over_date,
                source_table_granularity=source_table_granularity,
                slices=slice_columns,
            )
            analysis_function_list.append(sliced_day_simple_agg_config)

            if trailing_period_length:
                trailing_params = {
                    "trailing_period_length": trailing_period_length,
                    "trailing_period_aggregate_function": trailing_period_aggregate_function,
                }
                trailing_sliced_day_simple_agg_config = (
                    sliced_day_simple_agg_config.copy()
                )
                trailing_sliced_day_simple_agg_config.update(trailing_params)
                analysis_function_list.append(trailing_sliced_day_simple_agg_config)

        if trailing_period_length:
            trailing_aggregation_types = ["SUM", "AVG"]

            if include_total:
                for operator in trailing_aggregation_types:
                    trailing_params = {
                        "trailing_period_length": trailing_period_length,
                        "trailing_period_aggregate_function": operator,
                    }
                    analysis_function_list.append(
                        _get_simple_aggregate_configs(
                            periodicity=source_table_granularity,
                            adjust_days=False,
                            aggregate_type=aggregate_type,
                            ptd_date_over_date=ptd_date_over_date,
                            source_table_granularity=source_table_granularity,
                            trailing_params=trailing_params,
                        )
                    )
    return analysis_function_list


def _get_analysis_function_groups(
    standard_metric_metadata: standard_metric_metadata,
    standard_metric_configuration: standard_metric_configuration,
) -> List[Dict[str, Any]]:
    """
    Group analyses that can be calculated together. An analysis group is defined by its periodicity, whether the analyses are period-to-date, and whether they are sliced.
    """
    analysis_function_list = _get_analysis_function_list(
        standard_metric_metadata, standard_metric_configuration
    )

    grouped_functions = {}

    for function in analysis_function_list:
        analysis_periodicity = function.get("analysis_periodicity")
        ptd_date_over_date = function.get("ptd_date_over_date")
        slices = tuple(function.get("slices", []))

        group_key = (analysis_periodicity, ptd_date_over_date, slices)
        if group_key not in grouped_functions:
            grouped_functions[group_key] = {
                "analysis_periodicity": analysis_periodicity,
                "ptd_date_over_date": ptd_date_over_date,
                "slices": list(slices),
                "functions": [],
            }

        grouped_functions[group_key]["functions"].append(
            {
                k: v
                for k, v in function.items()
                if k not in {"analysis_periodicity", "ptd_date_over_date", "slices"}
            }
        )

    analysis_function_groups = list(grouped_functions.values())

    return analysis_function_groups
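
# Illustrative sketch of a resulting group (contents abbreviated): configurations sharing
# (analysis_periodicity="QUARTER", ptd_date_over_date=False, slices=()) collapse into
#   {"analysis_periodicity": "QUARTER", "ptd_date_over_date": False, "slices": [],
#    "functions": [{"name": "SIMPLE_AGGREGATE", ...}, {"name": "GROWTH_RATE", ...}]}
# so all analyses in the group can be computed over a single aggregation pass.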


def _get_calendar_col_for_analysis_group(
    analysis_periodicity: str,
    ptd_date_over_date: bool,
    half_year_columns_exist: bool,
) -> List[str]:
    """
    Get the calendar column expressions needed for an analysis group's configuration.
    """

    half_year_columns = [
        "half_year_period_start",
        "half_year_period_end",
        "half_year_label",
        "half_year_length_in_days",
        "days_in_half_year",
    ]

    period_start_end_columns = [
        F.col("day").alias("period_start"),
        F.col("day").alias("period_end"),
    ]

    days_in_columns = [
        "days_in_week",
        "days_in_month",
        "days_in_quarter",
        "days_in_year",
    ]

    if not half_year_columns_exist:
        half_year_columns = [
            F.lit(None).cast("date").alias("half_year_period_start"),
            F.lit(None).cast("date").alias("half_year_period_end"),
            F.lit(None).cast("string").alias("half_year_label"),
            F.lit(None).cast("int").alias("half_year_length_in_days"),
            F.lit(None).cast("int").alias("days_in_half_year"),
        ]

    if analysis_periodicity != "DAY":
        analysis_periodicity_lower = analysis_periodicity.lower()
        period_start_end_columns = [
            F.col(f"{analysis_periodicity_lower}_period_start").alias("period_start"),
            F.col(f"{analysis_periodicity_lower}_period_end").alias("period_end"),
        ]

    if ptd_date_over_date:
        days_in_columns = [
            "days_in_week",
            F.col("days_in_month_lp_adjusted").alias("days_in_month"),
            F.col("days_in_quarter_lp_adjusted").alias("days_in_quarter"),
            F.col("days_in_year_lp_adjusted").alias("days_in_year"),
        ]

    return [*half_year_columns, *period_start_end_columns, *days_in_columns]


def _get_affected_by_leap_year_expression(analysis_periodicity: str):
    """Generate expression to flag if a time interval is affected by a leap year"""
    match analysis_periodicity:
        case "DAY":
            # Daily granularity is not affected by leap years in our calculations
            return F.lit(False)
        case "MONTH":
            # Only the leap day (Feb 29) is flagged, marking February of leap years as affected
            return E.all(
                F.col("leap_year") == 1,
                F.month(F.col("day")) == 2,
                F.day(F.col("day")) == 29,
            )
        case "QUARTER":
            return (
                F.max(F.col("leap_day")).over(
                    W.partitionBy("quarter_period_start")
                    .orderBy("day")
                    .rowsBetween(W.unboundedPreceding, W.unboundedFollowing)
                )
                == 1
            )
        case "HALF_YEAR":
            # First half of a leap year is affected by leap day and subsequent months
            return E.any(
                E.all(
                    F.col("leap_year") == 1,
                    F.month(F.col("day")) == 2,
                    F.day(F.col("day")) == 29,
                ),
                E.all(
                    F.col("leap_year") == 1,
                    F.month(F.col("day")) >= 3,
                    F.month(F.col("day")) < 7,
                ),
            )


def _get_base_calendar(
    calendar_df: DataFrame,
    date_column: str,
    analysis_periodicity: str,
    ptd_date_over_date: bool,
) -> DataFrame:
    """
    Generate base calendar dataframe for analysis group with all necessary time period columns.

    This function takes a calendar DataFrame and enriches it with additional columns needed
    for time-based analysis, including period starts/ends, labels, and leap year handling.
    """
    # Check if half-year columns exist in the source calendar
    calendar_col = calendar_df.columns
    half_year_columns_exist = "half_year_period_start" in calendar_col

    # Get additional calendar columns needed for this analysis
    analysis_calendar_columns = _get_calendar_col_for_analysis_group(
        analysis_periodicity, ptd_date_over_date, half_year_columns_exist
    )

    # Calculate which periods are affected by leap years
    affected_by_leap_year_expression = _get_affected_by_leap_year_expression(
        analysis_periodicity
    )

    # Create the enriched calendar DataFrame with all necessary columns
    base_calendar_df = calendar_df.select(
        # Core date column
        F.col("day").alias(date_column),
        # Period start/end dates for all granularities
        "week_period_start",
        "week_period_end",
        "month_period_start",
        "month_period_end",
        "quarter_period_start",
        "quarter_period_end",
        "year_period_start",
        "year_period_end",
        # Period labels
        "year_label",
        "quarter_label",
        # Leap year indicator
        "leap_year",
        # Calculate months elapsed in current quarter/year
        F.months_between(F.add_months(F.col("day"), 1), F.col("quarter_period_start"))
        .cast("int")
        .alias("months_in_quarter"),
        F.months_between(F.add_months(F.col("day"), 1), F.col("year_period_start"))
        .cast("int")
        .alias("months_in_year"),
        # Flag periods affected by leap years
        F.when(affected_by_leap_year_expression, F.lit(1))
        .otherwise(F.lit(0))
        .alias("affected_by_leap_year"),
        # Additional analysis-specific calendar columns
        *analysis_calendar_columns,
    )

    return base_calendar_df


def _aggregate_base_calendar(
    calendar_df: DataFrame,
    date_column: str,
    analysis_periodicity: str,
    ptd_date_over_date: bool,
    source_table_granularity: str,
    input_table_bounds: dict[str, date],
) -> DataFrame:
    """
    Aggregate the base calendar dataframe by source table granularity and filter to relevant dates.

    This function takes the enriched calendar DataFrame and:
    1. Filters it to only include dates within the input data range
    2. Aggregates it to match the source data granularity
    3. Preserves all necessary calendar metadata for analysis
    """

    min_date_col = E.normalize_literal(input_table_bounds["min"])
    max_date_col = E.normalize_literal(input_table_bounds["max"])

    # Get the enriched base calendar
    base_calendar_df = _get_base_calendar(
        calendar_df,
        date_column,
        analysis_periodicity,
        ptd_date_over_date,
    )

    if source_table_granularity in ["MONTH", "QUARTER"]:
        source_granularity_lower = source_table_granularity.lower()
        group_by_columns = [
            "period_start",
            "period_end",
            f"{source_granularity_lower}_period_start",
            f"{source_granularity_lower}_period_end",
        ]
        if source_table_granularity == "MONTH":
            agg_period_start = "quarter_period_start"
            agg_period_end = "quarter_period_end"
        else:
            agg_period_start = "month_period_start"
            agg_period_end = "month_period_end"

        base_calendar_df = base_calendar_df.groupBy(*group_by_columns).agg(
            F.min(date_column).alias(date_column),
            F.min("week_period_start").alias("week_period_start"),
            F.max("week_period_end").alias("week_period_end"),
            F.min(agg_period_start).alias(agg_period_start),
            F.max(agg_period_end).alias(agg_period_end),
            F.min("year_period_start").alias("year_period_start"),
            F.max("year_period_end").alias("year_period_end"),
            F.first("year_label").alias("year_label"),
            F.first("quarter_label").alias("quarter_label"),
            F.first("leap_year").alias("leap_year"),
            F.max("days_in_week").alias("days_in_week"),
            F.max("days_in_month").alias("days_in_month"),
            F.max("days_in_quarter").alias("days_in_quarter"),
            F.max("days_in_year").alias("days_in_year"),
            F.max("days_in_half_year").alias("days_in_half_year"),
            F.max("months_in_quarter").alias("months_in_quarter"),
            F.max("months_in_year").alias("months_in_year"),
            F.max("affected_by_leap_year").alias("affected_by_leap_year"),
        )

    base_calendar_agg_df = base_calendar_df.where(
        E.between(date_column, min_date_col, max_date_col)
    )

    return base_calendar_agg_df


def _get_calendar_bounds(
    calendar_df: DataFrame,
    date_column: str,
) -> Dict[str, Column]:
    """
    Calculate the number of days (and months) elapsed in each periodicity grain as of the latest calendar date, returned as a dictionary of literal columns. Used for period-to-date calculations.

    :param calendar_df: Calendar DataFrame.
    :param date_column: Source table date column.
    """

    window = (
        W.partitionBy()
        .orderBy("period_start", date_column)
        .rowsBetween(W.unboundedPreceding, W.unboundedFollowing)
    )

    calendar_bounds = calendar_df.select(
        F.last("days_in_week").over(window).alias("current_days_in_week"),
        F.last("days_in_month").over(window).alias("current_days_in_month"),
        F.last("days_in_quarter").over(window).alias("current_days_in_quarter"),
        F.last("days_in_half_year").over(window).alias("current_days_in_half_year"),
        F.last("days_in_year").over(window).alias("current_days_in_year"),
        F.last("months_in_quarter").over(window).alias("current_months_in_quarter"),
        F.last("months_in_year").over(window).alias("current_months_in_year"),
    )

    calendar_bounds_dict = calendar_bounds.first().asDict()
    for key in calendar_bounds_dict.keys():
        val = calendar_bounds_dict[key]
        calendar_bounds_dict[key] = E.normalize_literal(val)

    return calendar_bounds_dict


def _get_pre_aggregate_expression(
    date_column: str,
    value_column: str,
    slice_columns: List[str],
    trailing_period_length: int = None,
    trailing_period_aggregate_function: str = None,
) -> Column:
    """
    Generate the column expression used for pre-aggregation in an analysis. For trailing analyses, applies the trailing period configuration to the input column based on the slices and trailing periodicity. A sequential index is added to the column names to support downstream aggregations on the same column.
    """
    if trailing_period_length:
        window = (
            W.partitionBy(*slice_columns)
            .orderBy(date_column)
            .rowsBetween(
                -(trailing_period_length - 1),
                W.currentRow,
            )
        )
        match trailing_period_aggregate_function:
            case "SUM":
                return F.sum(value_column).over(window)
            case "AVG":
                return F.avg(value_column).over(window)

    return F.col(value_column)
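
# Illustrative sketch: for a trailing 7-day SUM (trailing_period_length=7,
# trailing_period_aggregate_function="SUM"), each row's pre-aggregated value is
#   F.sum(value_column).over(
#       W.partitionBy(*slice_columns).orderBy(date_column).rowsBetween(-6, W.currentRow)
#   )
# i.e. the current day plus the six preceding days within the same slice.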


def _get_period_through_expression(
    source_table_granularity: str,
    analysis_periodicity: str,
    ptd_date_over_date: bool,
    adjust_to_days_for_analysis: bool,
    current_period_limit_days: int,
) -> Column:
    """
    Generate the column expression for the period-through date. For period-to-date analyses, the period start is advanced by the number of days (or months) elapsed in the current period; otherwise, the period end date is used.
    """
    if adjust_to_days_for_analysis:
        if source_table_granularity == "DAY":
            analysis_col_name = f"current_days_in_{analysis_periodicity}".lower()
            analysis_col = F.col(analysis_col_name)

            if ptd_date_over_date:
                period_through_expression = F.when(
                    F.max(F.col("affected_by_leap_year")) == 1,
                    F.date_add(
                        F.col("period_start"),
                        F.first(analysis_col) - current_period_limit_days + 1,
                    ),
                ).otherwise(
                    F.date_add(
                        F.col("period_start"),
                        F.first(analysis_col) - current_period_limit_days,
                    )
                )
            else:
                period_through_expression = F.date_add(
                    F.col("period_start"),
                    F.first(analysis_col) - current_period_limit_days,
                )
        else:  # source_table_granularity == "MONTH"
            period_through_expression = F.date_add(
                F.add_months(
                    F.col("period_start"),
                    F.first(F.col("current_months_in_quarter")),
                ),
                -1,
            )
    else:
        period_through_expression = F.col("period_end")

    return period_through_expression


def _get_aggregate_expression(
    source_table_granularity: str,
    analysis_periodicity: str,
    aggregate_function: str,
    adjust_to_days_for_analysis: bool,
    current_period_limit_days: int,
    pre_agg_column_name: str,
) -> Column:
    """
    Generate the aggregation column expression for an analysis. Applies the aggregation operator to the pre-aggregated expression for the input column.
    """
    if adjust_to_days_for_analysis:
        analysis_col_name = (
            f"{source_table_granularity}s_in_{analysis_periodicity}".lower()
        )
        analysis_col = F.col(analysis_col_name)
        current_analysis_col = F.col(f"current_{analysis_col_name}")

        match aggregate_function:
            case "EOP":
                return E.sum_if(
                    analysis_col == current_analysis_col,
                    F.col(pre_agg_column_name),
                )
            case "SUM":
                return E.sum_if(
                    analysis_col <= (current_analysis_col - current_period_limit_days),
                    F.col(pre_agg_column_name),
                )
            case "AVG":
                return E.avg_if(
                    analysis_col <= (current_analysis_col - current_period_limit_days),
                    F.col(pre_agg_column_name),
                )
    match aggregate_function:
        case "EOP":
            return F.last(F.col(pre_agg_column_name))
        case "SUM":
            return F.sum(F.col(pre_agg_column_name))
        case "AVG":
            return F.avg(F.col(pre_agg_column_name))
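
# Illustrative sketch (hypothetical values, assuming days_in_quarter counts days elapsed
# within the quarter): for a period-to-date SUM at QUARTER periodicity over DAY source
# data with current_days_in_quarter = 45 and no current_period_limit_days, the expression
# reduces to E.sum_if(F.col("days_in_quarter") <= 45, F.col(pre_agg_column_name)),
# i.e. only the first 45 days of each historical quarter are summed.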


def _get_aggregate_expression_dict(
    value_column: str,
    date_column: str,
    source_table_granularity: str,
    analysis_periodicity: str,
    ptd_date_over_date: bool,
    slice_columns: List[str],
    analysis_function_list: List[Dict[str, Any]],
) -> Dict[str, List[Column]]:
    """
    Loops through the analyses in an analysis group and generates a dictionary containing lists of pre-aggregate, period-through, and aggregate expressions to be applied to the analysis group.
    """
    pre_aggregate_expression_list = []
    period_through_expression_list = []
    aggregate_expression_list = []

    for idx, analysis_function in enumerate(analysis_function_list):
        aggregate_function = analysis_function.get("aggregate_function")
        trailing_period_length = analysis_function.get("trailing_period_length")
        trailing_period_aggregate_function = analysis_function.get(
            "trailing_period_aggregate_function"
        )
        adjust_to_days_in_current_period = analysis_function.get(
            "adjust_to_days_in_current_period"
        )
        current_period_limit_days = analysis_function.get(
            "current_period_limit_days", 0
        )
        adjust_to_days_for_analysis = (
            adjust_to_days_in_current_period
            and analysis_periodicity != source_table_granularity
        )

        pre_agg_expression = _get_pre_aggregate_expression(
            date_column,
            value_column,
            slice_columns,
            trailing_period_length,
            trailing_period_aggregate_function,
        ).alias(f"pre_agg_{idx}")

        pre_aggregate_expression_list.append(pre_agg_expression)

        period_through_expression = _get_period_through_expression(
            source_table_granularity,
            analysis_periodicity,
            ptd_date_over_date,
            adjust_to_days_for_analysis,
            current_period_limit_days,
        ).alias(f"period_through_{idx}")

        period_through_expression_list.append(period_through_expression)

        aggregate_expression = _get_aggregate_expression(
            source_table_granularity,
            analysis_periodicity,
            aggregate_function,
            adjust_to_days_for_analysis,
            current_period_limit_days,
            pre_agg_column_name=f"pre_agg_{idx}",
        ).alias(f"agg_{idx}")

        aggregate_expression_list.append(aggregate_expression)

    aggregate_expression_dict = {
        "pre_aggregate_expression_list": pre_aggregate_expression_list,
        "period_through_expression_list": period_through_expression_list,
        "aggregate_expression_list": aggregate_expression_list,
    }

    return aggregate_expression_dict


def _aggregate_values(
    base_agg_df: DataFrame,
    calendar_df: DataFrame,
    value_column: str,
    date_column: str,
    source_table_granularity: str,
    input_table_bounds: dict[str, date],
    analysis_periodicity: str,
    ptd_date_over_date: bool,
    slice_columns: List[str],
    analysis_function_list: List[Dict[str, Any]],
) -> DataFrame:
    """
    Generates dataframe of aggregated values for each analysis in an analysis group.
    """

    base_calendar_agg_df = _aggregate_base_calendar(
        calendar_df,
        date_column,
        analysis_periodicity,
        ptd_date_over_date,
        source_table_granularity,
        input_table_bounds,
    )

    calendar_bounds_dict = _get_calendar_bounds(base_calendar_agg_df, date_column)

    calendar_bounds_col = calendar_bounds_dict.keys()

    base_agg_calendar_join_df = base_agg_df.join(
        base_calendar_agg_df, date_column, "outer"
    ).withColumns(calendar_bounds_dict)

    aggregate_expression_dict = _get_aggregate_expression_dict(
        value_column,
        date_column,
        source_table_granularity,
        analysis_periodicity,
        ptd_date_over_date,
        slice_columns,
        analysis_function_list,
    )

    pre_aggregate_expression_list = aggregate_expression_dict[
        "pre_aggregate_expression_list"
    ]

    pre_agg_df = base_agg_calendar_join_df.select(
        date_column,
        "period_start",
        "period_end",
        "week_period_start",
        "week_period_end",
        "month_period_start",
        "month_period_end",
        "quarter_period_start",
        "quarter_period_end",
        "year_period_start",
        "year_period_end",
        "days_in_week",
        "days_in_month",
        "days_in_quarter",
        "days_in_year",
        "months_in_quarter",
        "months_in_year",
        "leap_year",
        "days_in_half_year",
        "affected_by_leap_year",
        *pre_aggregate_expression_list,
        *slice_columns,
        *calendar_bounds_col,
    )

    period_index_window = W.partitionBy(*slice_columns).orderBy(
        F.col("period_start").desc(),
    )
    period_index_expression = (F.row_number().over(period_index_window) - 1).alias(
        "period_index"
    )
    period_through_expression_list = aggregate_expression_dict[
        "period_through_expression_list"
    ]
    aggregate_expression_list = aggregate_expression_dict["aggregate_expression_list"]

    agg_df = pre_agg_df.groupBy(
        "period_start",
        "period_end",
        *slice_columns,
    ).agg(
        period_index_expression,
        *period_through_expression_list,
        *aggregate_expression_list,
    )

    return agg_df


def _get_growth_rate_calculation_expression(
    analysis_periodicity: str,
    slice_columns: List[str],
    growth_rate_type: str,
    calendar_type: str,
    year: int,
    idx: int,
) -> Column:
    """
    Generate column expression for growth rate calculations.
    """
    if calendar_type == "EXACT_N_YEARS":
        prior_period_expression = F.col(f"agg_{idx}_lag_{year}_year")
    else:  # calendar_type == "52_WEEK"
        window = W.partitionBy(*slice_columns).orderBy("period_start")

        lag_period = int(year * 52 * calendar_week_factor[analysis_periodicity])

        prior_period_expression = F.lag(F.col(f"agg_{idx}"), lag_period).over(window)

    if growth_rate_type == "CAGR":
        calculation_expression = (
            F.pow(
                F.try_divide(F.col(f"agg_{idx}"), prior_period_expression),
                (1 / year),
            )
            - 1
        )
    else:  # growth_rate_type == "SIMPLE"
        calculation_expression = (
            F.try_divide(F.col(f"agg_{idx}"), prior_period_expression) - 1
        )

    return calculation_expression
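
# Illustrative arithmetic: with a current aggregate of 120 and a comparable prior-period
# aggregate of 100, SIMPLE growth is 120 / 100 - 1 = 0.20, while CAGR over year = 2 is
# (120 / 100) ** (1 / 2) - 1 ≈ 0.0954.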


def _add_growth_rate_calculations(
    base_agg_df: DataFrame,
    calendar_df: DataFrame,
    value_column: str,
    date_column: str,
    source_table_granularity: str,
    input_table_bounds: dict[str, date],
    analysis_periodicity: str,
    ptd_date_over_date: bool,
    slice_columns: List[str],
    analysis_function_list: List[Dict[str, Any]],
) -> DataFrame:
    """
    Generates a dataframe with growth rate calculations applied to the aggregated columns for the growth rate analyses in the analysis group.
    """
    agg_df = _aggregate_values(
        base_agg_df,
        calendar_df,
        value_column,
        date_column,
        source_table_granularity,
        input_table_bounds,
        analysis_periodicity,
        ptd_date_over_date,
        slice_columns,
        analysis_function_list,
    )

    period_through_column_list = []
    calculation_expression_list = []

    for idx, analysis_function in enumerate(analysis_function_list):
        function_name = analysis_function.get("name")
        growth_rate_type = analysis_function.get("growth_rate_type")
        calendar_type = analysis_function.get("calendar_type")
        max_relevant_years = analysis_function.get("max_relevant_years")

        if function_name == "GROWTH_RATE" and calendar_type == "EXACT_N_YEARS":
            for year in range(1, max_relevant_years + 1):
                agg_df = add_lag_columns(
                    agg_df,
                    value_columns=[f"agg_{idx}"],
                    date_column="period_start",
                    slice_columns=slice_columns,
                    step_unit="YEAR",
                    steps=year,
                )

        period_through_column_list.append(f"period_through_{idx}")

        if function_name == "GROWTH_RATE":
            for year in range(1, max_relevant_years + 1):
                calculation_expression = _get_growth_rate_calculation_expression(
                    analysis_periodicity,
                    slice_columns,
                    growth_rate_type,
                    calendar_type,
                    year,
                    idx,
                ).alias(f"value_{idx}_{year}y_growth")

                calculation_expression_list.append(calculation_expression)
        else:
            calculation_expression_list.append(
                F.col(f"agg_{idx}").alias(f"value_{idx}")
            )

    calc_df = agg_df.select(
        "period_start",
        "period_end",
        "period_index",
        *period_through_column_list,
        *calculation_expression_list,
        *slice_columns,
    )

    return calc_df


def _get_analysis_columns(
    analysis_periodicity: str,
    function_name: str,
    adjust_to_days_in_current_period: bool,
    slice_columns: List[str],
    aggregate_function: str,
    calendar_type: str = None,
    year: int = None,
    current_period_limit_days: int = None,
    trailing_period_length: int = None,
    trailing_period_aggregate_function: str = None,
    growth_rate_type: str = None,
) -> Dict[str, Any]:
    """
    Generates a dictionary containing metadata columns for each analysis.
    """
    naming_components = [analysis_periodicity, function_name]
    if year:
        naming_components.insert(1, f"{year}y")
    if current_period_limit_days:
        naming_components.append(
            f"period_to_date_minus_{current_period_limit_days}_days"
        )
    elif adjust_to_days_in_current_period:
        naming_components.append("period_to_date")
    if trailing_period_length:
        naming_components.append("trailing_day")
    if slice_columns:
        naming_components.append("sliced_data")

    internal_dashboard_analysis_name = ("_".join(naming_components)).lower()

    aggregate_function_for_analysis = (
        trailing_period_aggregate_function or aggregate_function
    )

    trailing_period_granularity = "DAY" if trailing_period_length else None

    analysis_options = {
        "duration": year,
        "growth_rate_type": growth_rate_type,
    }

    analysis_columns = {
        "internal_dashboard_analysis_name": F.lit(internal_dashboard_analysis_name),
        "calculation_type": F.lit(function_name),
        "calendar_type": F.lit(calendar_type).cast("string"),
        "trailing_period": F.lit(trailing_period_length).cast("int"),
        "trailing_granularity": F.lit(trailing_period_granularity).cast("string"),
        "aggregation_type": F.lit(aggregate_function_for_analysis),
        "periodicity": F.lit(analysis_periodicity),
        "analysis_options": F.lit(json.dumps(analysis_options)),
    }

    return analysis_columns
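
# Illustrative example: a sliced, period-to-date quarterly growth rate over a 2-year
# duration produces
#   internal_dashboard_analysis_name == "quarter_2y_growth_rate_period_to_date_sliced_data"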


def _unpivot_analysis_group(
    base_agg_df: DataFrame,
    calendar_df: DataFrame,
    value_column: str,
    date_column: str,
    source_table_granularity: str,
    input_table_bounds: dict[str, date],
    analysis_periodicity: str,
    ptd_date_over_date: bool,
    slice_columns: List[str],
    analysis_function_list: List[Dict[str, Any]],
) -> DataFrame:
    """
    Transpose analysis group dataframe into tall format and add analysis metadata columns.
    """

    calc_df = _add_growth_rate_calculations(
        base_agg_df,
        calendar_df,
        value_column,
        date_column,
        source_table_granularity,
        input_table_bounds,
        analysis_periodicity,
        ptd_date_over_date,
        slice_columns,
        analysis_function_list,
    )
    min_date = input_table_bounds["min"]
    union_df = None

    for idx, analysis_function in enumerate(analysis_function_list):
        function_name = analysis_function.get("name")
        aggregate_function = analysis_function.get("aggregate_function")
        adjust_to_days_in_current_period = analysis_function.get(
            "adjust_to_days_in_current_period"
        )
        current_period_limit_days = analysis_function.get("current_period_limit_days")
        calendar_type = analysis_function.get("calendar_type")
        trailing_period_length = analysis_function.get("trailing_period_length")
        trailing_period_aggregate_function = analysis_function.get(
            "trailing_period_aggregate_function"
        )
        growth_rate_type = analysis_function.get("growth_rate_type")

        # limit_date_expression is used to exclude analyses that fall outside of the input table bounds
        if trailing_period_length:
            # Exclude trailing analyses if the full trailing period does not exist in the source table
            # E.g., if the data starts on 2020-01-01, T7D analyses should not be added for dates
            # before 2020-01-07
            limit_date_expression = F.date_add(
                F.col("period_start"), 1 - trailing_period_length
            )
        else:
            limit_date_expression = F.col("period_start")

        if function_name == "GROWTH_RATE":
            max_relevant_years = analysis_function.get("max_relevant_years")
            for year in range(1, max_relevant_years + 1):
                # Exclude growth rate analyses if the value is not available for the relevant year
                # E.g., if the data starts on 2020-01-01, 1Y growth rates should not be added for dates
                # before 2021-01-01
                limit_growth_rate_date_expression = limit_date_expression - F.expr(
                    f"INTERVAL {year} YEAR"
                )
                analysis_columns = _get_analysis_columns(
                    analysis_periodicity,
                    function_name,
                    adjust_to_days_in_current_period,
                    slice_columns,
                    aggregate_function,
                    calendar_type=calendar_type,
                    year=year,
                    current_period_limit_days=current_period_limit_days,
                    trailing_period_length=trailing_period_length,
                    trailing_period_aggregate_function=trailing_period_aggregate_function,
                    growth_rate_type=growth_rate_type,
                )
                df = calc_df.select(
                    "period_start",
                    "period_end",
                    "period_index",
                    F.col(f"period_through_{idx}").alias("period_through"),
                    F.col(f"value_{idx}_{year}y_growth").alias("value"),
                    *slice_columns,
                    limit_growth_rate_date_expression.alias("analysis_min_date"),
                ).withColumns(analysis_columns)

                if union_df is None:
                    union_df = df
                else:
                    union_df = union_df.unionByName(df)
        else:
            analysis_columns = _get_analysis_columns(
                analysis_periodicity,
                function_name,
                adjust_to_days_in_current_period,
                slice_columns,
                aggregate_function,
                trailing_period_length=trailing_period_length,
                trailing_period_aggregate_function=trailing_period_aggregate_function,
                growth_rate_type=growth_rate_type,
            )
            df = calc_df.select(
                "period_start",
                "period_end",
                "period_index",
                F.col(f"period_through_{idx}").alias("period_through"),
                F.col(f"value_{idx}").alias("value"),
                *slice_columns,
                limit_date_expression.alias("analysis_min_date"),
            ).withColumns(analysis_columns)

            if union_df is None:
                union_df = df
            else:
                union_df = union_df.unionByName(df)

    slice_name_list = []
    slice_value_list = []
    for idx, slice_column in enumerate(slice_columns, 1):
        slice_name_list.append(
            F.lit(slice_column).cast("string").alias(f"slice_name_{idx}")
        )
        slice_value_list.append(
            F.col(slice_column).cast("string").alias(f"slice_value_{idx}")
        )

    slice_list_len = len(slice_name_list)
    for i in range(slice_list_len + 1, SLICE_COLUMN_COUNT + 1):
        slice_name_list.append(F.lit(None).cast("string").alias(f"slice_name_{i}"))
        slice_value_list.append(F.lit(None).cast("string").alias(f"slice_value_{i}"))

    union_col = union_df.columns
    unpivot_df = (
        union_df.where(F.col("analysis_min_date") >= min_date)
        .select(*union_col, *slice_name_list, *slice_value_list)
        .drop(*slice_columns)
    )

    return unpivot_df


def _tall_format(
    base_agg_df: DataFrame,
    calendar_df: DataFrame,
    value_column: str,
    date_column: str,
    source_table_granularity: str,
    input_table_bounds: dict[str, date],
    analysis_periodicity: str,
    ptd_date_over_date: bool,
    slice_columns: List[str],
    analysis_function_list: List[Dict[str, Any]],
) -> DataFrame:
    """
    Add calendar metadata to analysis group dataframe and normalize slice columns.
    """
    unpivot_df = _unpivot_analysis_group(
        base_agg_df,
        calendar_df,
        value_column,
        date_column,
        source_table_granularity,
        input_table_bounds,
        analysis_periodicity,
        ptd_date_over_date,
        slice_columns,
        analysis_function_list,
    )

    base_calendar_df = _get_base_calendar(
        calendar_df,
        date_column,
        analysis_periodicity,
        ptd_date_over_date,
    )

    start_select_list = []
    end_select_list = []
    tall_format_select_list = []

    for column in calendar_metadata_columns_map.keys():
        mapped_column = calendar_metadata_columns_map[column]
        start_column_name = f"{mapped_column}_start"
        end_column_name = f"{mapped_column}_end"

        tall_format_select_list.append(start_column_name)
        tall_format_select_list.append(end_column_name)
        start_select_list.append(F.col(column).alias(start_column_name))
        end_select_list.append(F.col(column).alias(end_column_name))

    start_join_df = base_calendar_df.select(
        F.col(date_column).alias("period_start"),  # join on "period_start"
        *start_select_list,
        "half_year_period_start",
        "half_year_period_end",
        "half_year_label",
        "half_year_length_in_days",
        "days_in_half_year",
    )

    end_join_df = base_calendar_df.select(
        F.col(date_column).alias("period_through"),  # join on "period_through"
        *end_select_list,
    )

    tall_format_df = (
        unpivot_df.join(start_join_df, "period_start", "left")
        .join(end_join_df, "period_through", "left")
        .select(
            "period_start",
            F.col("period_through").alias("period_end"),
            "period_index",
            "half_year_period_start",
            "half_year_period_end",
            "half_year_label",
            "half_year_length_in_days",
            "days_in_half_year",
            "internal_dashboard_analysis_name",
            "calculation_type",
            "trailing_period",
            "trailing_granularity",
            "aggregation_type",
            "periodicity",
            F.col("value").cast(STANDARD_DECIMAL_TYPE).alias("value"),
            "calendar_type",
            *tall_format_select_list,
            *[f"slice_name_{i}" for i in range(1, SLICE_COLUMN_COUNT + 1)],
            *[f"slice_value_{j}" for j in range(1, SLICE_COLUMN_COUNT + 1)],
            "analysis_options",
        )
    )

    return tall_format_df


def _add_entity_configuration(
    df: DataFrame,
    entity_configuration: entity_configuration,
) -> DataFrame:
    top_level_entity_name = entity_configuration.top_level_entity_name
    top_level_entity_ticker = entity_configuration.top_level_entity_ticker
    exchange = entity_configuration.exchange
    entity_name = entity_configuration.entity_name
    figi = entity_configuration.figi
    if top_level_entity_ticker:
        top_level_entity_ticker = F.concat_ws(
            ":", F.lit(top_level_entity_ticker), F.lit(exchange)
        )
    else:
        top_level_entity_ticker = F.lit(None).cast("string")

    entity_df = df.withColumns(
        {
            "top_level_entity_name": F.lit(top_level_entity_name),
            "top_level_entity_ticker": top_level_entity_ticker,
            "exchange": F.lit(exchange).cast("string"),
            "entity_name": F.lit(entity_name).cast("string"),
            "figi": F.lit(figi).cast("string"),
        }
    )

    return entity_df


def _add_metric_metadata(
    df: DataFrame,
    standard_metric_metadata: standard_metric_metadata,
) -> DataFrame:
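    """
    Add metric-level metadata columns (name, granularities, currency, divisor, etc.) as literals.
    """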
    metric_name = standard_metric_metadata.metric_name
    company_comparable_kpi = standard_metric_metadata.company_comparable_kpi
    uses_va_for_actuals = standard_metric_metadata.uses_va_for_actuals
    display_period_granularity = standard_metric_metadata.display_period_granularity
    report_period_granularity = standard_metric_metadata.report_period_granularity
    currency = standard_metric_metadata.currency
    value_divisor = standard_metric_metadata.value_divisor
    visible_alpha_id = standard_metric_metadata.visible_alpha_id

    metadata_df = df.withColumns(
        {
            "metric_name": F.lit(metric_name),
            "company_comparable_kpi": F.lit(company_comparable_kpi),
            "uses_va_for_actuals": F.lit(uses_va_for_actuals),
            "display_period_granularity": F.lit(display_period_granularity),
            "report_period_granularity": F.lit(report_period_granularity),
            "currency": F.lit(currency).cast("string"),
            "value_divisor": F.lit(value_divisor).cast("int"),
            "visible_alpha_id": F.lit(visible_alpha_id).cast("int"),
        }
    )

    return metadata_df


def _add_metric_config(
    df: DataFrame,
    standard_metric_configuration: standard_metric_configuration,
) -> DataFrame:
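    """
    Serialize the metric configuration to JSON and record any source table filter
    conditions as string columns on the dataframe.
    """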
    metric_configuration = {
        "source_table_granularity": standard_metric_configuration.source_table_granularity,
        "max_relevant_years": standard_metric_configuration.max_relevant_years,
        "aggregate_function": standard_metric_configuration.aggregate_function,
        "growth_rate_type": standard_metric_configuration.growth_rate_type,
        "calendar_type": standard_metric_configuration.calendar_type,
        "slice_columns": standard_metric_configuration.slice_columns,
        "trailing_period": {
            "aggregate_sql_operator": standard_metric_configuration.trailing_period_aggregate_function,
            "grain": "DAY",
            "length": standard_metric_configuration.trailing_period_length,
        },
        "maximum_granularity": standard_metric_configuration.maximum_granularity,
    }
    source_table_filter_conditions = (
        standard_metric_configuration.source_table_filter_conditions
    )
    if source_table_filter_conditions:
        filter_values = []

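        # Derive a compact, snake_cased label from each filter condition's string
        # representation (the text between the first ", " and the closing parenthesis),
        # then join the labels with " - ".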
        for condition in source_table_filter_conditions:
            value = (
                (str(condition).split(", ", 1)[1].split(")")[0])
                .strip()
                .lower()
                .replace(" ", "_")
            )
            filter_values.append(value)

        filter_conditions_string = " - ".join(filter_values)
    else:
        filter_conditions_string = None

    config_df = df.withColumns(
        {
            "metric_options": F.lit(json.dumps(metric_configuration)),
            "metric_filtering_conditions": F.lit(str(source_table_filter_conditions)),
            "filter_conditions_string": F.lit(filter_conditions_string),
        }
    )

    return config_df


@track_usage
def standard_metric_unified_kpi(
    df: DataFrame,
    entity_configuration: entity_configuration,
    standard_metric_metadata: standard_metric_metadata,
    standard_metric_configuration: standard_metric_configuration,
    calendar_df: Optional[DataFrame] = None,
) -> DataFrame:
    """
    Generates a comprehensive dataframe containing all unified KPI analyses for a metric.

    This function performs a series of transformations to analyze a metric across multiple
    time periods and calculation types. The process includes:

    1. Data Preparation:
        - Validates and extracts configuration parameters
        - Sets up calendar information for time-based analysis
        - Applies any source table filters

    2. Analysis Generation:
        - Creates analysis groups based on periodicity (QUARTER, MONTH, etc.)
        - Calculates simple aggregates (SUM, AVG) for each period
        - Computes growth rates (YoY, QoQ) if configured
        - Handles period-to-date comparisons
        - Processes data slices (e.g., by product, region)

    3. Data Enrichment:
        - Adds entity information (company name, ticker)
        - Includes metric metadata (currency, divisor)
        - Attaches calendar metadata (period labels, dates)

    4. Special Handling:
        - Supports custom functions (e.g., PCT_OF_TOTAL)
        - Handles leap year adjustments
        - Manages trailing period calculations
        - Normalizes values across different period lengths

    :param df: Source DataFrame containing the metric data
    :param entity_configuration: A.entity_configuration for configuring entity details
    :param standard_metric_metadata: A.standard_metric_metadata for configuring metric metadata
    :param standard_metric_configuration: A.standard_metric_configuration for configuring metric calculations
    :param calendar_df: Optional calendar DataFrame (uses standard calendar if not provided)
    :return: DataFrame containing comprehensive metric unified KPI analyses

    Examples
    ^^^^^^^^

    .. code-block:: python
        :caption: Generating unified KPI analyses for a metric. Example output is a subset of the actual output.

        from etl_toolkit import A

        input_df = spark.table("yd_production.afrm_live_reported.afrm_gmv_sliced")
        calendar_df = spark.table("yd_fp_investor_audit.afrm_xnas_deliverable_gold.custom_calendar__dmv__000")

        entity_configuration = A.entity_configuration(
            top_level_entity_name="Affirm",
            top_level_entity_ticker="AFRM:XNAS",
        )
        standard_metric_metadata = A.standard_metric_metadata(
            metric_name="GMV",
            company_comparable_kpi=True,
            display_period_granularity="DAY",
            report_period_granularity="QUARTER",
            currency="USD",
            value_divisor=1000000,
            visible_alpha_id=14081329,
        )
        standard_metric_configuration = A.standard_metric_configuration(
            source_input_column="gmv",
            source_input_date_column="date",
            aggregate_function="SUM",
            growth_rate_type="CAGR",
            max_relevant_years=4,
            calendar_type="EXACT_N_YEARS",
            slice_columns=["merchant"],
            trailing_period_length=7,
            trailing_period_aggregate_function="AVG",
        )

        df = A.standard_metric_unified_kpi(
            input_df,
            entity_configuration,
            standard_metric_metadata,
            standard_metric_configuration,
            calendar_df,
        )
        display(df)

        +------------+---+-------------+---+-----------------+---+-----------------+-----------------+----------------+------------+---+-------------+
        |metric_name |...|slice_name_1 |...|value            |...|aggregation_type |calculation_type |trailing_period |periodicity |...|period_start |
        +------------+---+-------------+---+-----------------+---+-----------------+-----------------+----------------+------------+---+-------------+
        |GMV         |...|null         |...|80536050.596995  |...|SUM              |SIMPLE_AGGREGATE |null            |DAY         |...|2024-10-01   |
        +------------+---+-------------+---+-----------------+---+-----------------+-----------------+----------------+------------+---+-------------+
        |GMV         |...|merchant     |...|35380357.162450  |...|SUM              |SIMPLE_AGGREGATE |null            |DAY         |...|2024-10-01   |
        +------------+---+-------------+---+-----------------+---+-----------------+-----------------+----------------+------------+---+-------------+
        |GMV         |...|null         |...|77256191.770087  |...|AVG              |SIMPLE_AGGREGATE |7               |DAY         |...|2024-10-01   |
        +------------+---+-------------+---+-----------------+---+-----------------+-----------------+----------------+------------+---+-------------+
        |GMV         |...|merchant     |...|41700742.409985  |...|AVG              |SIMPLE_AGGREGATE |7               |DAY         |...|2024-10-01   |
        +------------+---+-------------+---+-----------------+---+-----------------+-----------------+----------------+------------+---+-------------+
    """
    # Extract configuration parameters
    value_column = standard_metric_configuration.source_input_column
    date_column = standard_metric_configuration.source_input_date_column
    source_table_granularity = standard_metric_configuration.source_table_granularity
    custom_function = standard_metric_configuration.custom_function

    # Use standard calendar if none provided
    if not calendar_df:
        spark = get_spark_session()
        calendar_df = spark.table(
            "yd_fp_investor_audit.calendar_gold.standard_calendar__dmv__000"
        )

    # Apply source table filters if configured
    if standard_metric_configuration.source_table_filter_conditions:
        df = df.where(standard_metric_configuration.source_table_filter)

    # Get the date range of the input data
    input_table_bounds = get_aggregates(
        df, value_column=date_column, aggregate_functions=["min", "max"]
    )

    # Generate analysis configurations based on metric metadata and configuration
    analysis_function_groups = _get_analysis_function_groups(
        standard_metric_metadata, standard_metric_configuration
    )

    pre_final_df = None

    # Process each analysis group (different time periods and calculation types)
    for analysis_group in analysis_function_groups:
        analysis_periodicity = analysis_group[
            "analysis_periodicity"
        ]  # e.g., QUARTER, MONTH
        ptd_date_over_date = analysis_group[
            "ptd_date_over_date"
        ]  # Period-to-date comparison flag
        slice_list_for_group = analysis_group["slices"]  # Dimensions to analyze by
        analysis_function_list = analysis_group["functions"]  # Calculations to perform

        # Aggregate the base data by date and slices
        base_agg_df = df.groupBy(date_column, *slice_list_for_group).agg(
            F.sum(value_column).alias(value_column)
        )

        # Transform the data into the tall format with all analyses
        tall_format_df = _tall_format(
            base_agg_df,
            calendar_df,
            value_column,
            date_column,
            source_table_granularity,
            input_table_bounds,
            analysis_periodicity,
            ptd_date_over_date,
            slice_list_for_group,
            analysis_function_list,
        )

        # Combine all analysis groups into a single DataFrame
        if pre_final_df:
            pre_final_df = pre_final_df.unionByName(tall_format_df)
        else:
            pre_final_df = tall_format_df

    if custom_function and custom_function == "PCT_OF_TOTAL":
        # For PCT_OF_TOTAL metrics, transform nominal values into percentages of total for each
        # slice value within each period and analysis type. This is useful for metrics like
        # market share or composition analysis.
        pct_of_total_df = (
            add_percent_of_total_columns(
                pre_final_df,
                value_columns=["value"],
                total_grouping_columns=[
                    "period_start",
                    "period_end",
                    "internal_dashboard_analysis_name",
                ],
            )
            .withColumns(
                {
                    "denominator": F.try_divide(F.col("value"), F.col("value_percent")),
                    "derived_metric": F.lit(True),
                    "operator": F.lit("DIVISION"),
                }
            )
            .withColumnRenamed("value", "numerator")
            .withColumnRenamed("value_percent", "value")
        )

        # Apply trailing and growth rate analyses after nominal values are converted to percentages.
        # This ensures growth rates reflect changes in percentage points rather than absolute values.
        pre_config_df = _add_derived_analyses(
            pct_of_total_df, standard_metric_configuration
        )
    else:
        pre_config_df = pre_final_df.withColumns(
            {
                "numerator": F.lit(None),
                "denominator": F.lit(None),
                "derived_metric": F.lit(False),
                "operator": F.lit(None),
            }
        )

    # Add entity information (company name, ticker, exchange, etc.)
    entity_df = _add_entity_configuration(pre_config_df, entity_configuration)

    # Add metric metadata (currency, divisor, display granularity, etc.)
    metadata_df = _add_metric_metadata(entity_df, standard_metric_metadata)

    # Add configuration details used to generate the analyses
    config_df = _add_metric_config(metadata_df, standard_metric_configuration)

    # Create the final DataFrame with a standardized column structure.
    # This ensures consistency across all metrics and makes the data
    # easier to consume by downstream processes.
    final_df = config_df.select(
        F.lit(None).cast("string").alias("value_id"),
        F.lit(None).cast("string").alias("analysis_id"),
        F.lit(analysis_name_clean).alias("analysis_name"),
        "metric_name",
        "company_comparable_kpi",
        F.lit(1).cast("int").alias("methodology_version"),
        "top_level_entity_name",
        "top_level_entity_ticker",
        "entity_name",
        "currency",
        *SLICE_NAME_VALUE_COLUMNS,
        F.col("value").cast(STANDARD_DECIMAL_TYPE).alias("value"),
        "value_divisor",
        F.lit("actual").alias("value_type"),
        "aggregation_type",
        "calculation_type",
        "trailing_period",
        "trailing_granularity",
        "periodicity",
        "analysis_options",
        "metric_options",
        F.lit(True).alias("key_metrics_experience_analysis"),
        "period_start",
        "period_end",
        "period_index",
        "calendar_type",
        "year_label_period_start",
        "year_label_period_end",
        "year_start_period_start",
        "year_end_period_start",
        "year_start_period_end",
        "year_end_period_end",
        "quarter_label_period_start",
        "quarter_label_period_end",
        "quarter_start_period_start",
        "quarter_end_period_start",
        "quarter_start_period_end",
        "quarter_end_period_end",
        "month_start_period_start",
        "month_end_period_start",
        "month_start_period_end",
        "month_end_period_end",
        "week_start_period_start",
        "week_end_period_start",
        "week_start_period_end",
        "week_end_period_end",
        "internal_dashboard_analysis_name",
        F.current_timestamp().alias("publication_timestamp"),
        "metric_filtering_conditions",
        F.lit(None).cast("int").alias("internal_metric_id"),
        "half_year_period_start",
        "half_year_period_end",
        "half_year_label",
        "half_year_length_in_days",
        "days_in_half_year",
        "figi",
        "uses_va_for_actuals",
        "display_period_granularity",
        "report_period_granularity",
        "visible_alpha_id",
        # numerator, denominator, operator, and derived_metric flag are added to support aggregating
        # PCT_OF_TOTAL and derived metrics at different slice combinations in the data download
        "derived_metric",
        F.when(F.col("calculation_type") == "SIMPLE_AGGREGATE", F.col("numerator"))
        .otherwise(F.lit(None))
        .cast(STANDARD_DECIMAL_TYPE)
        .alias("numerator"),
        F.when(F.col("calculation_type") == "SIMPLE_AGGREGATE", F.col("denominator"))
        .otherwise(F.lit(None))
        .cast(STANDARD_DECIMAL_TYPE)
        .alias("denominator"),
        F.when(F.col("calculation_type") == "SIMPLE_AGGREGATE", F.col("operator"))
        .otherwise(F.lit(None))
        .cast("string")
        .alias("operator"),
    )

    return final_df
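

# The helpers below support standard_metric_unified_kpi_derived: they combine two
# unified KPI dataframes with a mathematical operator (e.g., AOV = net sales / orders)
# and re-apply trailing and growth rate analyses to the derived values.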
def _get_derived_comparison_expression(
    operator: Literal[
        "DIVISION", "ADDITION", "MULTIPLICATION", "SUBTRACTION"
    ] = "DIVISION",
) -> Column:
    """
    Returns a comparison expression for a derived metric.
    """
    match operator:
        case "DIVISION":
            return F.try_divide(F.col("value_1"), F.col("value_2"))
        case "ADDITION":
            return F.try_add(F.col("value_1"), F.col("value_2"))
        case "MULTIPLICATION":
            return F.try_multiply(F.col("value_1"), F.col("value_2"))
        case "SUBTRACTION":
            return F.try_subtract(F.col("value_1"), F.col("value_2"))


def _validate_derived_metric_configuration(
    df: DataFrame,
    standard_metric_configuration: standard_metric_configuration,
):
    slice_columns = standard_metric_configuration.slice_columns
    trailing_period_length = standard_metric_configuration.trailing_period_length

    if slice_columns or trailing_period_length:
        source_configs = (
            df.select(
                F.get_json_object(F.col("metric_options_1"), "$.slice_columns").alias(
                    "slice_columns_1"
                ),
                F.get_json_object(F.col("metric_options_2"), "$.slice_columns").alias(
                    "slice_columns_2"
                ),
                F.get_json_object(
                    F.col("metric_options_1"), "$.source_table_granularity"
                ).alias("source_table_granularity_1"),
                F.get_json_object(
                    F.col("metric_options_2"), "$.source_table_granularity"
                ).alias("source_table_granularity_2"),
            )
            .first()
            .asDict()
        )

        # Convert slice_columns to JSON format to avoid returning an error caused by
        # the metric configuration using single quotes
        if slice_columns:
            slice_columns_str = json.dumps(slice_columns).replace(" ", "")
            slice_columns_1_str = source_configs["slice_columns_1"].replace(" ", "")
            slice_columns_2_str = source_configs["slice_columns_2"].replace(" ", "")

            if (
                slice_columns_str != slice_columns_1_str
                or slice_columns_str != slice_columns_2_str
            ):
                raise InvalidInputException(
                    "slice_columns must match the slice_columns from both input metric configurations."
                )

        if trailing_period_length and (
            source_configs["source_table_granularity_1"] != "DAY"
            or source_configs["source_table_granularity_2"] != "DAY"
        ):
            raise InvalidInputException(
                "Trailing analyses are only supported for daily periodicities."
            )


def _add_derived_trailing_analyses(
    df: DataFrame,
    trailing_period_length: int,
    trailing_period_aggregate_function: str,
    slice_columns: List[str] = None,
) -> DataFrame:
    """Add trailing period analyses to derived metric calculations."""
    # Get the base daily aggregates to calculate trailing values from
    day_simple_aggregate_df = df.where(
        F.col("internal_dashboard_analysis_name") == "day_simple_aggregate"
    )

    min_date = (
        df.select(F.min("period_start").alias("min_date")).first().asDict()["min_date"]
    )

    # Exclude trailing analyses if the full trailing period does not exist in the source table
    limit_date_expression = F.date_add(
        F.col("period_start"), 1 - trailing_period_length
    )

    union_df = None

    # Calculate both SUM and AVG trailing values for the overall metric
    for aggregate_function in ["SUM", "AVG"]:
        # Create the trailing period expression based on the window
        trailing_expression = _get_pre_aggregate_expression(
            date_column="period_start",
            value_column="value",
            slice_columns=SLICE_NAME_VALUE_COLUMNS,
            trailing_period_length=trailing_period_length,
            trailing_period_aggregate_function=aggregate_function,
        )

        # Apply the trailing calculation and update metadata
        trailing_value_df = day_simple_aggregate_df.withColumn(
            "trailing_value", trailing_expression
        )
        trailing_df = (
            trailing_value_df.drop("value")
            .withColumnRenamed("trailing_value", "value")
            .withColumns(
                {
                    "internal_dashboard_analysis_name": F.lit(
                        "day_simple_aggregate_trailing_day"
                    ),
                    "aggregation_type": F.lit(aggregate_function),
                    "trailing_period": F.lit(trailing_period_length),
                    "trailing_granularity": F.lit("DAY"),
                    "analysis_min_date": limit_date_expression,
                }
            )
            .where(F.col("analysis_min_date") >= min_date)
            .drop("analysis_min_date")
        )

        union_df = union_df.unionByName(trailing_df) if union_df else trailing_df

    # If slices are configured, calculate trailing values for each slice
    if slice_columns:
        sliced_df = df.where(
            F.col("internal_dashboard_analysis_name")
            == "day_simple_aggregate_sliced_data"
        )

        # Create trailing expression for sliced data
        trailing_expression = _get_pre_aggregate_expression(
            date_column="period_start",
            value_column="value",
            slice_columns=SLICE_NAME_VALUE_COLUMNS,
            trailing_period_length=trailing_period_length,
            trailing_period_aggregate_function=trailing_period_aggregate_function,
        )

        # Apply the trailing calculation to sliced data
        trailing_value_df = sliced_df.withColumn("trailing_value", trailing_expression)
        trailing_df = (
            trailing_value_df.drop("value")
            .withColumnRenamed("trailing_value", "value")
            .withColumns(
                {
                    "internal_dashboard_analysis_name": F.lit(
                        "day_simple_aggregate_trailing_day_sliced_data"
                    ),
                    "aggregation_type": F.lit(trailing_period_aggregate_function),
                    "trailing_period": F.lit(trailing_period_length),
                    "trailing_granularity": F.lit("DAY"),
                    "analysis_min_date": limit_date_expression,
                }
            )
            .where(F.col("analysis_min_date") >= min_date)
            .drop("analysis_min_date")
        )

        # Add sliced trailing calculations to the result set
        union_df = union_df.unionByName(trailing_df) if union_df else trailing_df

    return union_df


def _add_derived_growth_rate_analyses(
    df: DataFrame,
    growth_rate_type: str,
    calendar_type: str,
    max_relevant_years: int,
    slice_columns: List[str] = None,
) -> DataFrame:
    """Add growth rate analyses to derived metric calculations."""
    # Exclude day simple aggregate analyses as they don't need growth rates
    analysis_df = (
        df.where(F.col("internal_dashboard_analysis_name") != "day_simple_aggregate")
        .select("internal_dashboard_analysis_name")
        .distinct()
    )
    analysis_list = [
        row.internal_dashboard_analysis_name for row in analysis_df.collect()
    ]

    min_date = (
        df.select(F.min("period_start").alias("min_date")).first().asDict()["min_date"]
    )

    growth_rate_df = None
    column_list = df.columns

    # Process each analysis type
    for analysis in analysis_list:
        # Prepare the data for growth rate calculation
        analysis_df = df.where(
            F.col("internal_dashboard_analysis_name") == analysis
        ).withColumn("agg_0", F.col("value"))

        analysis_periodicity = analysis.split("_")[0].upper()
        calendar_type_for_growth_rate = (
            "EXACT_N_YEARS" if analysis_periodicity == "MONTH" else calendar_type
        )

        if calendar_type_for_growth_rate == "EXACT_N_YEARS":
            if "slice" in analysis:
                SLICE_COLUMNS_LENGTH = len(slice_columns)
                SLICE_COLUMNS_FOR_LAG = [
                    f"slice_name_{i}" for i in range(1, SLICE_COLUMNS_LENGTH + 1)
                ]
                SLICE_VALUE_LIST = [
                    f"slice_value_{j}" for j in range(1, SLICE_COLUMNS_LENGTH + 1)
                ]
                SLICE_COLUMNS_FOR_LAG.extend(SLICE_VALUE_LIST)
            else:
                SLICE_COLUMNS_FOR_LAG = []

            for year in range(1, max_relevant_years + 1):
                analysis_df = add_lag_columns(
                    analysis_df,
                    value_columns=["agg_0"],
                    date_column="period_start",
                    slice_columns=SLICE_COLUMNS_FOR_LAG,
                    step_unit="YEAR",
                    steps=year,
                )

        for year in range(1, max_relevant_years + 1):
            new_internal_dashboard_analysis_name = analysis.replace(
                "simple_aggregate", f"{year}y_growth_rate"
            )
            calculation_expression = _get_growth_rate_calculation_expression(
                analysis_periodicity,
                SLICE_NAME_VALUE_COLUMNS,
                growth_rate_type,
                calendar_type,
                year,
                0,
            )

            # Exclude growth rate analyses if the value is not available for the relevant year
            limit_date_expression = F.col("period_start") - F.expr(
                f"INTERVAL {year} YEAR"
            )

            growth_rate_value_df = (
                analysis_df.withColumn("growth_rate_value", calculation_expression)
                .drop("value")
                .withColumnRenamed("growth_rate_value", "value")
                .select(*column_list)
                .withColumns(
                    {
                        "internal_dashboard_analysis_name": F.lit(
                            new_internal_dashboard_analysis_name
                        ),
                        "calculation_type": F.lit("GROWTH_RATE"),
                        "calendar_type": F.lit(calendar_type_for_growth_rate),
                        "analysis_min_date": limit_date_expression,
                    }
                )
                .where(F.col("analysis_min_date") >= min_date)
                .drop("analysis_min_date")
            )

            growth_rate_df = (
                growth_rate_df.unionByName(growth_rate_value_df)
                if growth_rate_df
                else growth_rate_value_df
            )

    return growth_rate_df


def _add_derived_analyses(
    df: DataFrame,
    standard_metric_configuration: standard_metric_configuration,
) -> DataFrame:
    trailing_period_length = standard_metric_configuration.trailing_period_length
    growth_rate_type = standard_metric_configuration.growth_rate_type
    slice_columns = standard_metric_configuration.slice_columns

    if trailing_period_length:
        trailing_period_aggregate_function = (
            standard_metric_configuration.trailing_period_aggregate_function
        )
        trailing_df = _add_derived_trailing_analyses(
            df,
            trailing_period_length,
            trailing_period_aggregate_function,
            slice_columns,
        )
        df = df.unionByName(trailing_df)

    if growth_rate_type:
        calendar_type = standard_metric_configuration.calendar_type
        max_relevant_years = standard_metric_configuration.max_relevant_years
        growth_rate_df = _add_derived_growth_rate_analyses(
            df,
            growth_rate_type,
            calendar_type,
            max_relevant_years,
            slice_columns,
        )
        df = df.unionByName(growth_rate_df)

    return df


@track_usage
def standard_metric_unified_kpi_derived(
    unified_kpi_df_1: DataFrame,
    unified_kpi_df_2: DataFrame,
    standard_metric_metadata: standard_metric_metadata,
    standard_metric_configuration: standard_metric_configuration,
    operator: Literal[
        "DIVISION", "ADDITION", "MULTIPLICATION", "SUBTRACTION"
    ] = "DIVISION",
) -> DataFrame:
    """
    Generates a dataframe containing unified KPI analyses for a derived standard metric,
    based on the unified KPI dataframes of two input metrics.

    .. note::
        The mathematical operation preserves the order of the input metrics' unified KPI tables.
        If the operator is DIVISION, for example, the metric from unified_kpi_df_1 will be used as
        the numerator and the metric from unified_kpi_df_2 will be used as the denominator.

    :param unified_kpi_df_1: A dataframe containing unified KPI analyses of one metric
    :param unified_kpi_df_2: A dataframe containing unified KPI analyses of another metric
    :param standard_metric_metadata: A.standard_metric_metadata configurations
    :param standard_metric_configuration: A.standard_metric_configuration configurations
    :param operator: Mathematical operation between the two standard metrics used to get the derived metric
    :return: A dataframe containing unified KPI analyses of the derived metric

    Examples
    ^^^^^^^^

    .. code-block:: python
        :caption: Generating unified KPI analyses for a derived metric. Example output is a subset of the actual output.

        from etl_toolkit import A

        entity_configuration = A.entity_configuration(
            top_level_entity_name="Chewy",
            top_level_entity_ticker="CHWY:XNYS",
            figi="BBG00P19DLQ4",
        )
        calendar_df = spark.table(
            "yd_fp_investor_audit.chwy_xnys_deliverable_gold.custom_calendar__dmv__000"
        )

        standard_metric_metadata_1 = A.standard_metric_metadata(
            metric_name="Net Sales - Order Date",
            company_comparable_kpi=False,
            currency="USD",
            value_divisor=1000000,
        )
        standard_metric_configuration_1 = A.standard_metric_configuration(
            source_input_column="net_sales_order_date",
            source_input_date_column="date",
            calendar_type="52_WEEK",
            trailing_period_aggregate_function="SUM",
        )
        input_df_1 = spark.table("yd_production.chwy_live_reported.chwy_net_sales_order_date")
        unified_kpi_df_1 = A.standard_metric_unified_kpi(
            input_df_1,
            entity_configuration,
            standard_metric_metadata_1,
            standard_metric_configuration_1,
            calendar_df,
        )

        standard_metric_metadata_2 = A.standard_metric_metadata(
            metric_name="Orders - Order Date",
            company_comparable_kpi=False,
        )
        standard_metric_configuration_2 = A.standard_metric_configuration(
            source_input_column="order_date_orders_index",
            source_input_date_column="date",
            calendar_type="52_WEEK",
            trailing_period_length=None,
            trailing_period_aggregate_function=None,
        )
        input_df_2 = spark.table("yd_production.chwy_reported.edison_daily_sales")
        unified_kpi_df_2 = A.standard_metric_unified_kpi(
            input_df_2,
            entity_configuration,
            standard_metric_metadata_2,
            standard_metric_configuration_2,
            calendar_df,
        )

        derived_standard_metric_metadata = A.standard_metric_metadata(
            metric_name="AOV - Order Date",
            company_comparable_kpi=False,
        )
        derived_standard_metric_configuration = A.standard_metric_configuration(
            max_relevant_years=2,
            growth_rate_type="CAGR",
            calendar_type="52_WEEK",
            trailing_period_length=7,
            trailing_period_aggregate_function="AVG",
        )

        df = A.standard_metric_unified_kpi_derived(
            unified_kpi_df_1,
            unified_kpi_df_2,
            derived_standard_metric_metadata,
            derived_standard_metric_configuration,
            operator="DIVISION",
        )
        display(df)

        +-----------------+---+-------------+---+--------------+---+-----------------+-----------------+----------------+------------+---+-------------+
        |metric_name      |...|slice_name_1 |...|value         |...|aggregation_type |calculation_type |trailing_period |periodicity |...|period_start |
        +-----------------+---+-------------+---+--------------+---+-----------------+-----------------+----------------+------------+---+-------------+
        |AOV - Order Date |...|null         |...|34136.196960  |...|SUM              |SIMPLE_AGGREGATE |null            |QUARTER     |...|2021-02-01   |
        +-----------------+---+-------------+---+--------------+---+-----------------+-----------------+----------------+------------+---+-------------+
        |AOV - Order Date |...|null         |...|40176.041644  |...|SUM              |SIMPLE_AGGREGATE |null            |QUARTER     |...|2023-01-30   |
        +-----------------+---+-------------+---+--------------+---+-----------------+-----------------+----------------+------------+---+-------------+
        |AOV - Order Date |...|null         |...|29545.027237  |...|SUM              |SIMPLE_AGGREGATE |null            |QUARTER     |...|2019-11-04   |
        +-----------------+---+-------------+---+--------------+---+-----------------+-----------------+----------------+------------+---+-------------+
        |AOV - Order Date |...|null         |...|37124.360009  |...|SUM              |SIMPLE_AGGREGATE |null            |QUARTER     |...|2022-05-02   |
        +-----------------+---+-------------+---+--------------+---+-----------------+-----------------+----------------+------------+---+-------------+
    """
    operator_normalized = operator.upper()
    comparison_expression = _get_derived_comparison_expression(operator_normalized)

    df_1 = (
        unified_kpi_df_1.where(F.col("calculation_type") == "SIMPLE_AGGREGATE")
        .where(F.col("trailing_period").isNull())
        .withColumnRenamed("value", "value_1")
        .withColumnRenamed("metric_options", "metric_options_1")
        .drop(
            "metric_name",
            "company_comparable_kpi",
            "uses_va_for_actuals",
            "display_period_granularity",
            "report_period_granularity",
            "currency",
            "value_divisor",
            "visible_alpha_id",
            "aggregation_type",
        )
    )
    df_2 = (
        unified_kpi_df_2.where(F.col("calculation_type") == "SIMPLE_AGGREGATE")
        .where(F.col("trailing_period").isNull())
        .select(
            "period_start",
            "period_end",
            "internal_dashboard_analysis_name",
            *SLICE_NAME_VALUE_COLUMNS,
            F.col("value").alias("value_2"),
            F.col("metric_options").alias("metric_options_2"),
        )
    )

    join_col = [
        "period_start",
        "period_end",
        "internal_dashboard_analysis_name",
        *SLICE_NAME_VALUE_COLUMNS,
    ]
    combined_df = (
        df_1.join(
            df_2,
            [df_1[c].eqNullSafe(df_2[c]) for c in join_col],
            "inner",
        )
        .drop(*[df_2[c] for c in join_col])
        .withColumns(
            {
                "value": comparison_expression,
                "aggregation_type": F.lit("SUM"),
            }
        )
    )

    _validate_derived_metric_configuration(
        combined_df,
        standard_metric_configuration,
    )

    if (
        standard_metric_configuration.trailing_period_length
        or standard_metric_configuration.growth_rate_type
    ):
        derived_analyses_df = combined_df.drop("metric_options_1", "metric_options_2")
        pre_final_df = _add_derived_analyses(
            derived_analyses_df,
            standard_metric_configuration,
        )
    else:
        pre_final_df = combined_df.drop("metric_options_1", "metric_options_2")

    metadata_df = _add_metric_metadata(pre_final_df, standard_metric_metadata)
    config_df = _add_metric_config(metadata_df, standard_metric_configuration)

    final_df = config_df.select(
        F.lit(None).cast("string").alias("value_id"),
        F.lit(None).cast("string").alias("analysis_id"),
        F.lit(analysis_name_clean).alias("analysis_name"),
        "metric_name",
        "company_comparable_kpi",
        F.lit(1).cast("int").alias("methodology_version"),
        "top_level_entity_name",
        "top_level_entity_ticker",
        "entity_name",
        "currency",
        *SLICE_NAME_VALUE_COLUMNS,
        F.col("value").cast(STANDARD_DECIMAL_TYPE).alias("value"),
        "value_divisor",
        F.lit("actual").alias("value_type"),
        "aggregation_type",
        "calculation_type",
        "trailing_period",
        "trailing_granularity",
        "periodicity",
        "analysis_options",
        "metric_options",
        F.lit(True).alias("key_metrics_experience_analysis"),
        "period_start",
        "period_end",
        "period_index",
        "calendar_type",
        "year_label_period_start",
        "year_label_period_end",
        "year_start_period_start",
        "year_end_period_start",
        "year_start_period_end",
        "year_end_period_end",
        "quarter_label_period_start",
        "quarter_label_period_end",
        "quarter_start_period_start",
        "quarter_end_period_start",
        "quarter_start_period_end",
        "quarter_end_period_end",
        "month_start_period_start",
        "month_end_period_start",
        "month_start_period_end",
        "month_end_period_end",
        "week_start_period_start",
        "week_end_period_start",
        "week_start_period_end",
        "week_end_period_end",
        "internal_dashboard_analysis_name",
        F.current_timestamp().alias("publication_timestamp"),
        "metric_filtering_conditions",
        F.lit(None).cast("int").alias("internal_metric_id"),
        "half_year_period_start",
        "half_year_period_end",
        "half_year_label",
        "half_year_length_in_days",
        "days_in_half_year",
        "figi",
        "uses_va_for_actuals",
        "display_period_granularity",
        "report_period_granularity",
        "visible_alpha_id",
        # numerator, denominator, operator, and derived_metric flag are added to support aggregating
        # PCT_OF_TOTAL and derived metrics at different slice combinations in the data download
        F.lit(True).alias("derived_metric"),
        F.when(F.col("calculation_type") == "SIMPLE_AGGREGATE", F.col("value_1"))
        .otherwise(F.lit(None))
        .cast(STANDARD_DECIMAL_TYPE)
        .alias("numerator"),
        F.when(F.col("calculation_type") == "SIMPLE_AGGREGATE", F.col("value_2"))
        .otherwise(F.lit(None))
        .cast(STANDARD_DECIMAL_TYPE)
        .alias("denominator"),
        F.when(
            F.col("calculation_type") == "SIMPLE_AGGREGATE", F.lit(operator_normalized)
        )
        .otherwise(F.lit(None))
        .cast("string")
        .alias("operator"),
    )

    return final_df