from datetime import datetime, date
import re
from typing import Literal
from pyspark.sql import functions as F, DataFrame, Window as W, Column
from yipit_databricks_utils.helpers.telemetry import track_usage
from etl_toolkit import expressions as E
from etl_toolkit.exceptions import InvalidColumnTypeException, InvalidInputException
@track_usage
def add_lag_columns(
df: DataFrame,
value_columns: list[str | Column],
date_column: str | Column,
slice_columns: list[str | Column] | None = None,
steps: int = 1,
step_unit: Literal[
"DAY", "WEEK", "MONTH", "YEAR", "HOUR", "MINUTE", "SECOND"
] = "DAY",
) -> DataFrame:
"""
Adds columns to the input dataframe that are lagged versions of the specified ``value_columns``.
The lag is calculated based on ``date_column`` and the specified interval (``steps`` * ``step_unit``).
Lags can be performed within each slice if ``slice_columns`` are specified.
The added lag columns are named in a standard way based on the interval, ex: ("revenue", 1, "DAY") -> "revenue_lag_1_day".
.. caution:: The lag is performed via a self-join that matches each date with the corresponding (date - interval) within any slice(s).
It's important to account for missing dates in the data to ensure accurate calculations;
the ``A.fill_periods`` function can be useful in these scenarios. See the final example below for how a gap in the dates affects the result.
:param df: The input Dataframe to add lag calculations to
:param value_columns: A list of Columns or strings to base the lag calculations on. Each column will have a corresponding lag column added. If strings are provided, they are resolved as Columns.
:param date_column: A Column or str that should be of date or timestamp type. This column is used to determine the lagged period for lag calculations. If a string is provided, it is resolved as a Column.
:param slice_columns: An optional list of Columns or strings to define the slices of the dataframe. Within each slice a lag calculation will be generated based on the date_column + interval. If strings are provided, they are resolved as Columns.
:param steps: The number of ``step_unit`` periods that defines the lag interval.
:param step_unit: The unit of the lag period; combined with ``steps``, it defines the lag interval.
Examples
^^^^^^^^^^
.. code-block:: python
:caption: Basic example of a 1-day lag of the value column based on the date column.
from etl_toolkit import E, F, A
from datetime import date
df = spark.createDataFrame([
{"value": 100, "color": "red", "date": date(2024, 1, 2)},
{"value": 50, "color": "red", "date": date(2024, 1, 1)},
])
display(
A.add_lag_columns(
df,
value_columns=["value"],
date_column="date",
step_unit="DAY",
steps=1,
)
)
+--------------+--------------+--------------+----------------+
|color |date |value |value_lag_1_day |
+--------------+--------------+--------------+----------------+
| red| 2024-01-02| 100| 50|
+--------------+--------------+--------------+----------------+
| red| 2024-01-01| 50| null|
+--------------+--------------+--------------+----------------+
.. code-block:: python
:caption: Example of a 2-day lag of the value column based on the date column.
Notice the lag column name changed to reflect the interval.
from etl_toolkit import E, F, A
from datetime import date
df = spark.createDataFrame([
{"value": 250, "color": "red", "date": date(2024, 1, 5)},
{"value": 200, "color": "red", "date": date(2024, 1, 4)},
{"value": 150, "color": "red", "date": date(2024, 1, 3)},
{"value": 100, "color": "red", "date": date(2024, 1, 2)},
{"value": 50, "color": "red", "date": date(2024, 1, 1)},
])
display(
A.add_lag_columns(
df,
value_columns=["value"],
date_column="date",
step_unit="DAY",
steps=2,
)
)
+--------------+--------------+--------------+---------------+
|color |date |value |value_lag_2_day|
+--------------+--------------+--------------+---------------+
| red| 2024-01-05| 250| 150|
+--------------+--------------+--------------+---------------+
| red| 2024-01-04| 200| 100|
+--------------+--------------+--------------+---------------+
| red| 2024-01-03| 150| 50|
+--------------+--------------+--------------+---------------+
| red| 2024-01-02| 100| null|
+--------------+--------------+--------------+---------------+
| red| 2024-01-01| 50| null|
+--------------+--------------+--------------+---------------+
.. code-block:: python
:caption: Example of a 2-week lag of the value column based on the date column.
Notice how the dataframe has a weekly periodicity, and the data is still lagged correctly
because a self-join is performed in the operation.
from etl_toolkit import E, F, A
from datetime import date
df = spark.createDataFrame([
{"value": 250, "color": "red", "date": date(2024, 1, 29)},
{"value": 200, "color": "red", "date": date(2024, 1, 22)},
{"value": 150, "color": "red", "date": date(2024, 1, 15)},
{"value": 100, "color": "red", "date": date(2024, 1, 8)},
{"value": 50, "color": "red", "date": date(2024, 1, 1)},
])
display(
A.add_lag_columns(
df,
value_columns=["value"],
date_column="date",
step_unit="WEEK",
steps=2,
)
)
+--------------+--------------+--------------+----------------+
|color |date |value |value_lag_2_week|
+--------------+--------------+--------------+----------------+
| red| 2024-01-29| 250| 150|
+--------------+--------------+--------------+----------------+
| red| 2024-01-22| 200| 100|
+--------------+--------------+--------------+----------------+
| red| 2024-01-15| 150| 50|
+--------------+--------------+--------------+----------------+
| red| 2024-01-08| 100| null|
+--------------+--------------+--------------+----------------+
| red| 2024-01-01| 50| null|
+--------------+--------------+--------------+----------------+
.. code-block:: python
:caption: Example of a 1-day lag of the value column using the color column for slices.
Notice that the lag is applied within each unique slice of the dataframe.
from etl_toolkit import E, F, A
from datetime import date
df = spark.createDataFrame([
{"value": 100, "color": "red", "date": date(2024, 1, 2)},
{"value": 50, "color": "red", "date": date(2024, 1, 1)},
{"value": 150, "color": "blue", "date": date(2024, 1, 2)},
{"value": 75, "color": "blue", "date": date(2024, 1, 1)},
])
display(
A.add_lag_columns(
df,
value_columns=["value"],
date_column="date",
slice_columns=["color"],
step_unit="DAY",
steps=1,
)
)
+--------------+--------------+--------------+---------------+
|color |date |value |value_lag_1_day|
+--------------+--------------+--------------+---------------+
| red| 2024-01-02| 100| 50|
+--------------+--------------+--------------+---------------+
| red| 2024-01-01| 50| null|
+--------------+--------------+--------------+---------------+
| blue| 2024-01-02| 150| 75|
+--------------+--------------+--------------+---------------+
| blue| 2024-01-01| 75| null|
+--------------+--------------+--------------+---------------+
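.. code-block:: python
:caption: Illustrative sketch of the caution above: when the date column has a gap
(2024-01-02 is missing), the 1-day lag for 2024-01-03 is null because the self-join
finds no row for the prior date. The expected output shown is illustrative.
from etl_toolkit import E, F, A
from datetime import date
df = spark.createDataFrame([
{"value": 100, "color": "red", "date": date(2024, 1, 3)},
{"value": 50, "color": "red", "date": date(2024, 1, 1)},
])
display(
A.add_lag_columns(
df,
value_columns=["value"],
date_column="date",
step_unit="DAY",
steps=1,
)
)
+--------------+--------------+--------------+----------------+
|color |date |value |value_lag_1_day |
+--------------+--------------+--------------+----------------+
| red| 2024-01-03| 100| null|
+--------------+--------------+--------------+----------------+
| red| 2024-01-01| 50| null|
+--------------+--------------+--------------+----------------+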
"""
slice_columns = slice_columns or []
slice_columns = [E.normalize_column(slice_column) for slice_column in slice_columns]
slice_column_names = df.select(slice_columns).columns if len(slice_columns) else []
# The lag interval (e.g. INTERVAL 1 DAY) is subtracted from each row's date to find its lagged counterpart
lag_interval = F.expr(f"INTERVAL {steps} {step_unit}")
value_columns = [E.normalize_column(value_column) for value_column in value_columns]
value_column_names = df.select(value_columns).columns
date_column = E.normalize_column(date_column)
date_column_name = df.select(date_column).columns[0]
return (
df.alias("a")
# Self-join: match each row in "a" to the row in "b" (same slice) whose date equals (a.date - lag_interval)
.join(
df.select([date_column, *value_columns, *slice_columns]).alias("b"),
E.all(
[
(F.col(f"a.{date_column_name}") - lag_interval)
== F.col(f"b.{date_column_name}"),
*[
F.col(f"a.{slice_column}") == F.col(f"b.{slice_column}")
for slice_column in slice_column_names
],
]
),
how="left",
)
.select(
"a.*",
*[
F.col(f"b.{value_column}").alias(
f"{value_column}_lag_{steps}_{step_unit.lower()}"
)
for value_column in value_column_names
],
)
)
@track_usage
def add_percent_of_total_columns(
df: DataFrame,
value_columns: list[str | Column],
total_grouping_columns: list[str | Column],
suffix: str = "percent",
) -> DataFrame:
"""
Adds column(s) to the dataframe equal to each row's percent of the total, where the total is the sum of the ``value_columns`` across the groups defined by ``total_grouping_columns``.
Each percent column added follows a standard naming convention, ``<value_column>_<suffix>`` (ex: "gmv" -> "gmv_percent"). The default ``suffix`` is "percent",
but this can be adjusted.
.. tip:: It is recommended to use this function whenever a percent of total operation is needed.
The implementation uses a group by + join under the hood, which is more performant than a
window expression (see the final example below for the equivalent window expression it replaces).
:param df: The dataframe to add percent of total columns to
:param value_columns: A list of Columns or strings to generate percent of total columns for. Each column specified will have a percent of total column added to the output dataframe. If strings are provided, they will be resolved as Columns.
:param total_grouping_columns: A list of Columns or strings to define the slices of the dataframe. The sum of the value column for each unique combination of values in this list will be used as the denominator for the percentage column. If strings are provided, they will be resolved as Columns.
:param suffix: The suffix that is added to each percent column generated for the output dataframe.
Examples
^^^^^^^^^^
.. code-block:: python
:caption: Example of generating percent of total columns based on the color column for groupings.
from etl_toolkit import E, F, A
from datetime import date
df = spark.createDataFrame([
{"value": 100, "color": "red", "date": date(2024, 1, 2)},
{"value": 50, "color": "red", "date": date(2024, 1, 1)},
])
display(
A.add_percent_of_total_columns(
df,
value_columns=["value"],
total_grouping_columns=["color"],
)
)
+--------------+--------------+--------------+----------------+
|color |date |value |value_percent |
+--------------+--------------+--------------+----------------+
| red| 2024-01-02| 100| 0.666666|
+--------------+--------------+--------------+----------------+
| red| 2024-01-01| 50| 0.333333|
+--------------+--------------+--------------+----------------+
.. code-block:: python
:caption: Example of modifying the suffix argument to adjust the new column names.
from etl_toolkit import E, F, A
from datetime import date
df = spark.createDataFrame([
{"value": 100, "color": "red", "date": date(2024, 1, 2)},
{"value": 50, "color": "red", "date": date(2024, 1, 1)},
])
display(
A.add_percent_of_total_columns(
df,
value_columns=["value"],
total_grouping_columns=["color"],
suffix="percent_of_total",
)
)
+--------------+--------------+--------------+-------------------------+
|color |date |value |value_percent_of_total |
+--------------+--------------+--------------+-------------------------+
| red| 2024-01-02| 100| 0.666666|
+--------------+--------------+--------------+-------------------------+
| red| 2024-01-01| 50| 0.333333|
+--------------+--------------+--------------+-------------------------+
.. code-block:: python
:caption: Example of generating multiple percent of total columns based on the date column for groupings.
from etl_toolkit import E, F, A
from datetime import date
df = spark.createDataFrame([
{"value": 100, "color": "red", "date": date(2024, 1, 1), "count": 5},
{"value": 50, "color": "blue", "date": date(2024, 1, 1), "count": 10},
])
display(
A.add_percent_of_total_columns(
df,
value_columns=["value", "count"],
total_grouping_columns=["date"],
)
)
+--------------+--------------+--------------+----------------+--------------+----------------+
|color |date |value |count |count_percent |value_percent |
+--------------+--------------+--------------+----------------+--------------+----------------+
| red| 2024-01-01| 100| 5| 0.333333| 0.666666|
+--------------+--------------+--------------+----------------+--------------+----------------+
| blue| 2024-01-01| 50| 10| 0.666666| 0.333333|
+--------------+--------------+--------------+----------------+--------------+----------------+
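.. code-block:: python
:caption: Illustrative sketch (not produced by this function): the equivalent window
expression that the group by + join implementation avoids, shown only to clarify the
performance tip above. The output matches the first example.
from etl_toolkit import E, F, A
from pyspark.sql import Window as W
from datetime import date
df = spark.createDataFrame([
{"value": 100, "color": "red", "date": date(2024, 1, 2)},
{"value": 50, "color": "red", "date": date(2024, 1, 1)},
])
display(
df.withColumn(
"value_percent",
F.try_divide(F.col("value"), F.sum("value").over(W.partitionBy("color"))),
)
)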
"""
if not isinstance(total_grouping_columns, list):
raise InvalidColumnTypeException(
"total_grouping_columns must be a list of column names (str) or pyspark columns"
)
if not isinstance(value_columns, list):
raise InvalidColumnTypeException(
"value_columns must be a list of column names (str) or pyspark columns"
)
if not re.fullmatch(r"[a-z0-9_]+", suffix):
raise InvalidInputException(
"suffix must contain only lower-case alphanumeric characters and underscores"
)
original_column_names = df.columns
total_grouping_columns = [E.normalize_column(col) for col in total_grouping_columns]
total_grouping_column_names = df.select(total_grouping_columns).columns
value_columns = [E.normalize_column(col) for col in value_columns]
value_column_names = df.select(value_columns).columns
# Use a group by + join as that is more performant than a window expression
group_df = df.groupBy(*total_grouping_columns).agg(
*[
F.sum(value_column).alias(f"{value_column_names[idx]}_total")
for idx, value_column in enumerate(value_columns)
]
)
# Ensure all group columns are columns on the original DF for null-safe join condition
grouping_columns_to_add = {
group_column_name: total_grouping_columns[idx]
for idx, group_column_name in enumerate(total_grouping_column_names)
if group_column_name not in original_column_names
}
if len(grouping_columns_to_add):
df = df.withColumns(grouping_columns_to_add)
# Only include the original df columns and the pct of total columns
enriched_df = (
df.alias("original")
.join(
group_df.alias("agg"),
E.all(
[
F.col(f"original.{grouping_column_name}").eqNullSafe(
F.col(f"agg.{grouping_column_name}")
)
for grouping_column_name in total_grouping_column_names
]
),
how="left",
)
.select(
*[
F.col(f"original.{column_name}")
for column_name in original_column_names
],
*[
F.try_divide(
value_column, F.col(f"{value_column_names[idx]}_total")
).alias(f"{value_column_names[idx]}_{suffix}")
for idx, value_column in enumerate(value_columns)
],
)
)
return enriched_df
@track_usage
def add_percent_fill_columns(
df: DataFrame,
value_columns: list[str | Column],
total_grouping_columns: list[str | Column],
order_columns: list[str | Column],
suffix: str = "percent_fill",
) -> DataFrame:
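"""
Adds column(s) to the dataframe equal to the running (cumulative) total of each ``value_column``,
ordered by the ``order_columns``, divided by the overall total within the groups defined by
``total_grouping_columns``. Each column added is named ``<value_column>_<suffix>`` (ex: "gmv" -> "gmv_percent_fill").
:param df: The dataframe to add percent fill columns to
:param value_columns: A list of Columns or strings to generate percent fill columns for. If strings are provided, they will be resolved as Columns.
:param total_grouping_columns: A list of Columns or strings to define the slices of the dataframe. The sum of each value column within a slice is used as the denominator. If strings are provided, they will be resolved as Columns.
:param order_columns: A list of Columns or strings that define the ordering of the cumulative sum within each slice. If strings are provided, they will be resolved as Columns.
:param suffix: The suffix that is added to each percent fill column generated for the output dataframe.
Examples
^^^^^^^^^^
.. code-block:: python
:caption: Illustrative sketch of a percent fill calculation sliced by the color column and
ordered by the date column (assuming this function is exposed on the ``A`` namespace like
the documented functions above). The expected output shown is illustrative.
from etl_toolkit import E, F, A
from datetime import date
df = spark.createDataFrame([
{"value": 100, "color": "red", "date": date(2024, 1, 2)},
{"value": 50, "color": "red", "date": date(2024, 1, 1)},
])
display(
A.add_percent_fill_columns(
df,
value_columns=["value"],
total_grouping_columns=["color"],
order_columns=["date"],
)
)
+--------------+--------------+--------------+-------------------+
|color |date |value |value_percent_fill |
+--------------+--------------+--------------+-------------------+
| red| 2024-01-01| 50| 0.333333|
+--------------+--------------+--------------+-------------------+
| red| 2024-01-02| 100| 1.0|
+--------------+--------------+--------------+-------------------+
"""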
if not isinstance(total_grouping_columns, list):
raise InvalidColumnTypeException(
"total_grouping_columns must be a list of column names (str) or pyspark columns"
)
if not isinstance(value_columns, list):
raise InvalidColumnTypeException(
"value_columns must be a list of column names (str) or pyspark columns"
)
if not isinstance(order_columns, list):
raise InvalidColumnTypeException(
"order_columns must be a list of column names (str) or pyspark columns"
)
if not re.fullmatch(r"[a-z0-9_]+", suffix):
raise InvalidInputException(
"suffix must contain only lower-case alphanumeric characters and underscores"
)
original_column_names = df.columns
total_grouping_columns = [E.normalize_column(col) for col in total_grouping_columns]
total_grouping_column_names = df.select(total_grouping_columns).columns
order_columns = [E.normalize_column(col) for col in order_columns]
order_column_names = df.select(order_columns).columns
value_columns = [E.normalize_column(col) for col in value_columns]
value_column_names = df.select(value_columns).columns
# Use a group by + join to avoid as many window expressions as possible
group_df = df.groupBy(*total_grouping_columns).agg(
*[
F.sum(value_column).alias(f"{value_column_names[idx]}_total")
for idx, value_column in enumerate(value_columns)
]
)
# Ensure all group columns are columns on the original DF for null-safe join condition
grouping_columns_to_add = {
group_column_name: total_grouping_columns[idx]
for idx, group_column_name in enumerate(total_grouping_column_names)
if group_column_name not in original_column_names
}
if len(grouping_columns_to_add):
df = df.withColumns(grouping_columns_to_add)
# Only include the original df columns and the percent_fill columns
# percent_fill is the running or "cumulative" total / the absolute total for each grouping
cumulative_window = (
W.partitionBy(
*[
F.col(f"original.{grouping_column_name}")
for grouping_column_name in total_grouping_column_names
]
)
.orderBy(
*[
F.col(f"original.{order_column_name}")
for order_column_name in order_column_names
]
)
.rowsBetween(W.unboundedPreceding, W.currentRow)
)
enriched_df = (
df.alias("original")
.join(
group_df.alias("agg"),
E.all(
[
F.col(f"original.{grouping_column_name}").eqNullSafe(
F.col(f"agg.{grouping_column_name}")
)
for grouping_column_name in total_grouping_column_names
]
),
how="left",
)
.select(
*[
F.col(f"original.{column_name}")
for column_name in original_column_names
],
*[
(
F.sum(value_column).over(cumulative_window)
/ F.col(f"{value_column_names[idx]}_total")
).alias(f"{value_column_names[idx]}_{suffix}")
for idx, value_column in enumerate(value_columns)
],
)
)
return enriched_df