from typing import Any, Literal
from pyspark.sql import DataFrame, Column, functions as F
from yipit_databricks_utils.helpers.telemetry import track_usage
from etl_toolkit import expressions as E
from etl_toolkit.exceptions import InvalidInputException
@track_usage
def get_aggregates(
df: DataFrame,
value_column: str | Column,
aggregate_functions: list[
Literal["min", "max", "count", "avg", "sum", "count_distinct"]
],
) -> dict[str, Any]:
"""
Returns a dictionary of aggregate values based on the ``value_column``, ``df``, and ``aggregate_functions`` specified.
This is used to be able to extract the min, max, etc. values from a pyspark dataframe into python and use in subsequent code.
The output is a dictionary with key/value pairs for each aggregate calculation specified.
.. tip:: Certain pipeline logic can benefit from calculating these simple aggregates in a dedicated step and then referenced in subsequent
transformations as literals - this simplifies the spark query plan by cutting down on window expressions and/or joins.
A common example of this is calculating the percentage of a column across the total for the entire dataset. This would otherwise be an expensive
window operation without the use of this function.
:param df: Dataframe used to calculate aggregate values
:param value_column: Column or string to use when calculating aggregate values. If a string is provided, it is referenced as a Column.
:param aggregate_functions: A list of aggregate function names. For each function name, a separate aggregate calculation will be returned in the output dict.

    Examples
    ^^^^^^^^^^^^^^^^

    .. code-block:: python
        :caption: Example of generating aggregate values. Notice the output is a dictionary with keys representing the aggregate functions specified.

        from etl_toolkit import E, F, A
        from datetime import date

        df = spark.createDataFrame([
            {"value": 100, "color": "red", "date": date(2024, 1, 1)},
            {"value": 50, "color": "blue", "date": date(2024, 1, 1)},
        ])

        print(
            A.get_aggregates(
                df,
                value_column="value",
                aggregate_functions=["min", "max", "sum"],
            )
        )
        # out: {'min': 50, 'max': 100, 'sum': 150}

    .. code-block:: python
        :caption: Example of using aggregate values to avoid an expensive window operation.
            Notice how the aggregates are python values that can be used as literals in subsequent spark operations.

        from etl_toolkit import E, F, A, W
        from datetime import date

        df = spark.createDataFrame([
            {"value": 100, "color": "red", "date": date(2024, 1, 1)},
            {"value": 50, "color": "blue", "date": date(2024, 1, 1)},
        ])

        # This costly window operation can be reimplemented using get_aggregates
        enriched_df = df.withColumn(
            "value_percent",
            F.col("value") / F.sum("value").over(W.partitionBy(F.lit(1)))
        )

        # Using get_aggregates, we can accomplish the same logic:
        aggregates = A.get_aggregates(
            df,
            value_column="value",
            aggregate_functions=["sum"],
        )

        display(
            df
            .withColumn("value_percent", F.col("value") / F.lit(aggregates["sum"]))
        )
        +-----+----------+-----+-------------+
        |color|      date|value|value_percent|
        +-----+----------+-----+-------------+
        |  red|2024-01-01|  100|     0.666666|
        | blue|2024-01-01|   50|     0.333333|
        +-----+----------+-----+-------------+
"""
value_column = E.normalize_column(value_column)
if len(aggregate_functions) == 0:
raise InvalidInputException("Must provide at least one aggregate_function")
# Build agg functions based on what is specified
select_columns = []
for aggregate_function in aggregate_functions:
if aggregate_function == "min":
select_columns.append(F.min(value_column).alias(aggregate_function))
elif aggregate_function == "max":
select_columns.append(F.max(value_column).alias(aggregate_function))
elif aggregate_function == "count":
select_columns.append(F.count(value_column).alias(aggregate_function))
elif aggregate_function == "avg":
select_columns.append(F.avg(value_column).alias(aggregate_function))
elif aggregate_function == "sum":
select_columns.append(F.sum(value_column).alias(aggregate_function))
elif aggregate_function == "count_distinct":
select_columns.append(
F.countDistinct(value_column).alias(aggregate_function)
)
else:
raise InvalidInputException(
f"Invalid aggregate_function supplied: {aggregate_function}. "
'Must be one of "min", "max", "count", "avg", "sum", "count_distinct"'
)
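    # Compute all requested aggregates in a single pass and collect the one-row result to the driver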
    return df.select(select_columns).first().asDict()
@track_usage
def get_column_type(df: DataFrame, column: str | Column) -> str:
"""
    Return the spark column type as a string (e.g. ``"bigint"`` or ``"date"``) for the provided column on the dataframe.

    :param df: Dataframe containing the column
    :param column: Column or string to inspect. If a string is provided, it is referenced as a Column.
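
    Examples
    ^^^^^^^^^^^^^^^^

    A minimal usage sketch, assuming ``get_column_type`` is exposed under the same ``A`` namespace used in the
    :func:`get_aggregates` examples and that the sample dataframe below is representative.

    .. code-block:: python
        :caption: Checking a column's type before applying type-specific logic.

        from etl_toolkit import A
        from datetime import date

        df = spark.createDataFrame([
            {"value": 100, "color": "red", "date": date(2024, 1, 1)},
            {"value": 50, "color": "blue", "date": date(2024, 1, 1)},
        ])

        print(A.get_column_type(df, "value"))
        # out: bigint

        print(A.get_column_type(df, "date"))
        # out: date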
"""
return df.select(E.normalize_column(column)).dtypes[0][1]