Source code for etl_toolkit.analyses.scalar

from typing import Any, Literal

from pyspark.sql import DataFrame, Column, functions as F
from yipit_databricks_utils.helpers.telemetry import track_usage

from etl_toolkit import expressions as E
from etl_toolkit.exceptions import InvalidInputException


@track_usage
def get_aggregates(
    df: DataFrame,
    value_column: str | Column,
    aggregate_functions: list[
        Literal["min", "max", "count", "avg", "sum", "count_distinct"]
    ],
) -> dict[str, Any]:
    """
    Returns a dictionary of aggregate values based on the ``value_column``, ``df``,
    and ``aggregate_functions`` specified.

    This is useful for extracting the min, max, etc. values from a pyspark dataframe
    into python for use in subsequent code. The output is a dictionary with key/value
    pairs for each aggregate calculation specified.

    .. tip::
        Certain pipeline logic can benefit from calculating these simple aggregates
        in a dedicated step and then referencing them in subsequent transformations
        as literals; this simplifies the spark query plan by cutting down on window
        expressions and/or joins.

        A common example of this is calculating the percentage of a column across
        the total for the entire dataset. This would otherwise be an expensive
        window operation without the use of this function.

    :param df: Dataframe used to calculate aggregate values
    :param value_column: Column or string to use when calculating aggregate values.
        If a string is provided, it is referenced as a Column.
    :param aggregate_functions: A list of aggregate function names. For each function
        name, a separate aggregate calculation will be returned in the output dict.

    Examples
    ^^^^^^^^^^^^^^^^

    .. code-block:: python
        :caption: Example of generating aggregate values. Notice the output is a dictionary with keys representing the aggregate functions specified.

        from etl_toolkit import E, F, A
        from datetime import date

        df = spark.createDataFrame([
            {"value": 100, "color": "red", "date": date(2024, 1, 1)},
            {"value": 50, "color": "blue", "date": date(2024, 1, 1)},
        ])

        print(
            A.get_aggregates(
                df,
                value_column="value",
                aggregate_functions=["min", "max", "sum"],
            )
        )
        # out: {'min': 50, 'max': 100, 'sum': 150}

    .. code-block:: python
        :caption: Example of using aggregate values to avoid an expensive window operation. Notice how the aggregates are python values that can be used as literals in subsequent spark operations.

        from etl_toolkit import E, F, A, W
        from datetime import date

        df = spark.createDataFrame([
            {"value": 100, "color": "red", "date": date(2024, 1, 1)},
            {"value": 50, "color": "blue", "date": date(2024, 1, 1)},
        ])

        # This costly window operation can be reimplemented using get_aggregates
        enriched_df = df.withColumn(
            "value_percent",
            F.col("value") / F.sum("value").over(W.partitionBy(F.lit(1)))
        )

        # Using get_aggregates, we can accomplish the same logic:
        aggregates = A.get_aggregates(
            df,
            value_column="value",
            aggregate_functions=["sum"]
        )
        display(
            df
            .withColumn("value_percent", F.col("value") / F.lit(aggregates["sum"]))
        )
        +-----+----------+-----+-------------+
        |color|date      |value|value_percent|
        +-----+----------+-----+-------------+
        |  red|2024-01-01|  100|     0.666666|
        | blue|2024-01-01|   50|     0.333333|
        +-----+----------+-----+-------------+
    """
    value_column = E.normalize_column(value_column)
    if len(aggregate_functions) == 0:
        raise InvalidInputException("Must provide at least one aggregate_function")

    # Build agg functions based on what is specified
    select_columns = []
    for aggregate_function in aggregate_functions:
        if aggregate_function == "min":
            select_columns.append(F.min(value_column).alias(aggregate_function))
        elif aggregate_function == "max":
            select_columns.append(F.max(value_column).alias(aggregate_function))
        elif aggregate_function == "count":
            select_columns.append(F.count(value_column).alias(aggregate_function))
        elif aggregate_function == "avg":
            select_columns.append(F.avg(value_column).alias(aggregate_function))
        elif aggregate_function == "sum":
            select_columns.append(F.sum(value_column).alias(aggregate_function))
        elif aggregate_function == "count_distinct":
            select_columns.append(
                F.countDistinct(value_column).alias(aggregate_function)
            )
        else:
            raise InvalidInputException(
                f"Invalid aggregate_function supplied: {aggregate_function}. "
                'Must be one of "min", "max", "count", "avg", "sum", "count_distinct"'
            )

    aggregate_values = df.select(select_columns).first().asDict()
    return aggregate_values
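# Usage sketch (illustrative, not part of the module): assuming an active
# SparkSession named `spark`, multiple aggregate names can be combined in one
# call, which runs a single aggregation job and returns plain Python values:
#
#     df = spark.createDataFrame([
#         {"value": 100, "color": "red"},
#         {"value": 50, "color": "red"},
#     ])
#     get_aggregates(df, "value", ["count_distinct", "avg"])
#     # out: {'count_distinct': 2, 'avg': 75.0}
#
# Passing an empty list, or an unsupported function name, raises
# InvalidInputException as implemented above.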
@track_usage
def get_column_type(df: DataFrame, column: str | Column) -> str:
    """
    Return the spark column type as a string for the provided column on the dataframe.
    """
    return df.select(E.normalize_column(column)).dtypes[0][1]
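# Usage sketch (illustrative, not part of the module): assuming an active
# SparkSession named `spark`, get_column_type returns the dtype string that
# pyspark reports via DataFrame.dtypes for the selected column, and accepts
# either a column name or a Column expression:
#
#     df = spark.createDataFrame([{"value": 100, "color": "red"}])
#     get_column_type(df, "value")         # -> 'bigint'
#     get_column_type(df, F.col("color"))  # -> 'string'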