from pyspark.sql import Column, Window as W, functions as F
from yipit_databricks_utils.helpers.telemetry import track_usage
from etl_toolkit.expressions.core import normalize_column, normalize_literal
@track_usage
def growth_rate_by_lag(
    value_column: str | Column,
    window: W,
    num_periods: int = 1,
    default: int | float | Column | None = None,
    base_value_column: str | Column | None = None,
) -> Column:
    """
    Generate a simple growth rate calculation using a provided window expression and value column.

    The window is used to lag the base column (``base_value_column`` if given, otherwise
    ``value_column``) to produce the base period for the growth rate.
    By default, the lag is 1 row but it can be increased using ``num_periods``.
    The growth rate is calculated as (x1 - x0) / x0, implemented as ``x1 / x0 - 1``.
    The division uses ``F.try_divide``, so the growth rate is NULL whenever the lagged base value
    x0 is NULL (e.g. the first ``num_periods`` rows of each window partition) or zero.

    .. caution:: The lag is performed via a row-based window expression and can lead to incorrect
        results if there are gaps in the date ranges.
        It's important to adjust for missing dates in the data to ensure accurate calculations.
        The ``A.fill_periods`` function can be useful in these scenarios.

    :param value_column: The column to use as the numerator for the growth rate. A numeric type
        column should be used. If a string is passed, it is treated as a Column.
    :param window: A pyspark window specification (e.g. ``W.partitionBy(...).orderBy(...)``) used
        for the lag operation. The window must have an ``.orderBy`` clause defined.
    :param num_periods: The number of rows to lag the base column by to obtain the denominator
        (passed as the offset to ``F.lag``). Default is ``1``.
    :param default: An optional value to use if the growth rate column is NULL. This is useful to
        fill in the initial values of the growth rate column. A number passed in will be treated
        as a literal.
    :param base_value_column: An optional column to use as the denominator for the growth rate,
        lagged by ``num_periods``. A numeric type column should be used. If the ``value_column``
        is desired for the denominator, leave this as ``None``. Default is ``None`` (i.e. the
        ``value_column`` is used for the denominator).
    :return: A Column expression containing the growth rate, coalesced with ``default`` when a
        default is provided.

    Examples
    -----------
    .. code-block:: python
        :caption: Using E.growth_rate_by_lag to calculate a y/y growth rate.

        from datetime import date
        from etl_toolkit import E, F, W

        df = spark.createDataFrame([
            {"date": date(2020, 1, 1), "input": 0},
            {"date": date(2021, 1, 2), "input": 10},
            {"date": date(2022, 1, 3), "input": 20},
            {"date": date(2023, 1, 4), "input": 25},
            {"date": date(2024, 1, 5), "input": 50},
        ])

        display(
            df
            .withColumn(
                "output",
                E.growth_rate_by_lag(
                    "input",
                    W.partitionBy(F.lit(1)).orderBy("date"),
                )
            )
        )

        +--------------+--------------+--------------+
        |date          |input         |output        |
        +--------------+--------------+--------------+
        |2020-01-01    |0             |NULL          |
        +--------------+--------------+--------------+
        |2021-01-02    |10            |NULL          |
        +--------------+--------------+--------------+
        |2022-01-03    |20            |1.0           |
        +--------------+--------------+--------------+
        |2023-01-04    |25            |0.25          |
        +--------------+--------------+--------------+
        |2024-01-05    |50            |1.0           |
        +--------------+--------------+--------------+

    .. code-block:: python
        :caption: Use the default value to coalesce null values in the calculation

        from datetime import date
        from etl_toolkit import E, F, W

        df = spark.createDataFrame([
            {"date": date(2020, 1, 1), "input": 0},
            {"date": date(2021, 1, 2), "input": 10},
            {"date": date(2022, 1, 3), "input": 20},
            {"date": date(2023, 1, 4), "input": 25},
            {"date": date(2024, 1, 5), "input": 50},
        ])

        display(
            df
            .withColumn(
                "output",
                E.growth_rate_by_lag(
                    "input",
                    W.partitionBy(F.lit(1)).orderBy("date"),
                    default=0,
                )
            )
        )

        +--------------+--------------+--------------+
        |date          |input         |output        |
        +--------------+--------------+--------------+
        |2020-01-01    |0             |0.0           |
        +--------------+--------------+--------------+
        |2021-01-02    |10            |0.0           |
        +--------------+--------------+--------------+
        |2022-01-03    |20            |1.0           |
        +--------------+--------------+--------------+
        |2023-01-04    |25            |0.25          |
        +--------------+--------------+--------------+
        |2024-01-05    |50            |1.0           |
        +--------------+--------------+--------------+

    .. code-block:: python
        :caption: Use num_periods to generate a 2-yr simple growth rate.

        from datetime import date
        from etl_toolkit import E, F, W

        df = spark.createDataFrame([
            {"date": date(2020, 1, 1), "input": 0},
            {"date": date(2021, 1, 2), "input": 10},
            {"date": date(2022, 1, 3), "input": 20},
            {"date": date(2023, 1, 4), "input": 25},
            {"date": date(2024, 1, 5), "input": 50},
        ])

        display(
            df
            .withColumn(
                "output",
                E.growth_rate_by_lag(
                    "input",
                    W.partitionBy(F.lit(1)).orderBy("date"),
                    num_periods=2,
                )
            )
        )

        +--------------+--------------+--------------+
        |date          |input         |output        |
        +--------------+--------------+--------------+
        |2020-01-01    |0             |NULL          |
        +--------------+--------------+--------------+
        |2021-01-02    |10            |NULL          |
        +--------------+--------------+--------------+
        |2022-01-03    |20            |NULL          |
        +--------------+--------------+--------------+
        |2023-01-04    |25            |1.5           |
        +--------------+--------------+--------------+
        |2024-01-05    |50            |1.5           |
        +--------------+--------------+--------------+

    .. code-block:: python
        :caption: Use a different ``base_value_column`` to calculate a y/y growth rate. Notice how
            the lag to determine the denominator is now on the ``base`` column instead of the
            ``input`` column.

        from datetime import date
        from etl_toolkit import E, F, W

        df = spark.createDataFrame([
            {"date": date(2020, 1, 1), "input": 0, "base": 5},
            {"date": date(2021, 1, 2), "input": 10, "base": 16},
            {"date": date(2022, 1, 3), "input": 20, "base": 20},
            {"date": date(2023, 1, 4), "input": 25, "base": 25},
            {"date": date(2024, 1, 5), "input": 50, "base": 30},
        ])

        display(
            df
            .withColumn(
                "output",
                E.growth_rate_by_lag(
                    "input",
                    W.partitionBy(F.lit(1)).orderBy("date"),
                    base_value_column="base",
                )
            )
        )

        +--------------+--------------+--------------+--------------+
        |date          |input         |base          |output        |
        +--------------+--------------+--------------+--------------+
        |2020-01-01    |0             |5             |NULL          |
        +--------------+--------------+--------------+--------------+
        |2021-01-02    |10            |16            |1.0           |
        +--------------+--------------+--------------+--------------+
        |2022-01-03    |20            |20            |0.25          |
        +--------------+--------------+--------------+--------------+
        |2023-01-04    |25            |25            |0.25          |
        +--------------+--------------+--------------+--------------+
        |2024-01-05    |50            |30            |1.0           |
        +--------------+--------------+--------------+--------------+
    """
    value_column = normalize_column(value_column)
    # Without an explicit base column, the denominator is the lagged numerator itself.
    if base_value_column is None:
        base_value_column = value_column
    else:
        base_value_column = normalize_column(base_value_column)

    # F.lag yields NULL when there is no row `num_periods` back in the partition, and
    # F.try_divide yields NULL when the divisor is 0; NULL then propagates through "- 1".
    lagged_base = F.lag(base_value_column, num_periods).over(window)
    growth_rate_column = F.try_divide(value_column, lagged_base) - 1

    if default is not None:
        # Replace NULL growth rates (e.g. the initial rows) with the caller-provided fallback.
        return F.coalesce(growth_rate_column, normalize_literal(default))
    return growth_rate_column