Source code for etl_toolkit.expressions.mapping

from dataclasses import dataclass
from itertools import chain

from pyspark.sql import functions as F, Column
from yipit_databricks_utils.helpers.telemetry import track_usage

from etl_toolkit.expressions.core import normalize_column, normalize_literal


@dataclass
class CaseStatement:
    when: Column | str
    then: Column | str | dict[str, str | Column]

    @property
    def expression(self) -> tuple[Column, Column]:
        when = normalize_column(self.when)

        # If `then` is a dict, generate a named struct dynamically so the
        # dict keys can be referenced directly as fields in pyspark
        if isinstance(self.then, dict):
            then = F.named_struct(
                *list(
                    chain.from_iterable(
                        [
                            (normalize_literal(col), normalize_literal(value))
                            for col, value in self.then.items()
                        ]
                    )
                )
            )

        else:
            then = normalize_literal(self.then)

        return (when, then)


case = CaseStatement
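
# Illustrative sketch (assumed usage, not part of the original module): a
# ``case`` pairs a boolean condition with the value returned when it matches,
# and ``expression`` normalizes both sides into Columns so the tuple can be
# splatted straight into ``F.when``:
#
#   c = case(F.col("x") == 1, "test")
#   when_col, then_col = c.expression
#   F.when(when_col, then_col)  # CASE WHEN (x = 1) THEN 'test' END
#
# A dict ``then`` is normalized to a named struct, so
# ``case(F.col("x") == 1, {"a": "foo"}).expression`` returns the condition
# together with ``F.named_struct(F.lit("a"), F.lit("foo"))`` (assuming
# normalize_literal wraps plain strings with F.lit).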


def assign(
    then: Column | str | dict[str, str | Column],
    when: Column | str,
) -> case:
    return case(when, then)
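
# Illustrative note (not part of the original module): ``assign`` is ``case``
# with the arguments flipped, which can read more naturally in long mapping
# lists where the value of interest comes first:
#
#   assign("Tech", F.col("ticker") == "aapl")
#   # builds the same CaseStatement as
#   case(F.col("ticker") == "aapl", "Tech")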


def chain_cases(conditions: list[case], otherwise: Column | str = None) -> Column:
    """
    Utility function to write case statements in Pyspark more efficiently.
    It accepts a list of conditions that are ``E.case`` or ``E.assign``.

    .. tip::
        ``E.chain_assigns`` and ``E.chain_cases`` are aliases of each other. They both
        produce the same results; however, stylistically it is better to use ``E.case``
        with ``E.chain_cases`` and ``E.assign`` with ``E.chain_assigns``.

    The ``E.case`` function makes it straightforward to write when/then cases.
    For example, ``E.case(F.col('x') == 1, 'test')`` -> case when x=1 then 'test'.
    The case function accepts a boolean expression Column as the first argument and a
    ``Column``, ``string``, or ``dict`` as the second argument, which is returned if the
    boolean expression is True. If a string is specified for the second argument, it is
    treated as a literal.

    The ``E.assign`` function is similar to ``E.case``, however the arguments are reversed.
    For example, ``E.assign('test', F.col('x') == 1)`` -> case when x=1 then 'test'.
    The assign function can make for more readable code when defining a long series of
    mapping conditions, since it can be easier to see the match value first before the
    condition.

    A default value can be set via the ``otherwise`` argument if no cases are met.

    :param conditions: List of ``E.case`` or ``E.assign`` conditions that are evaluated in
        order for matches. If a case matches, the associated value is used for the row.
        Must have at least one condition in the list.
    :param otherwise: A default value to use if no cases are matched from ``conditions``.
        If a string is passed it will be treated as a literal.

    Examples
    -----------

    .. code-block:: python
        :caption: Using E.chain_cases to define a mapping column

        from etl_toolkit import E, F

        df = spark.createDataFrame([
            {"input": "aapl"},
            {"input": "nke"},
            {"input": "dpz"},
        ])
        sector_mapping = E.chain_cases([
            E.case(F.col("input") == "aapl", "Tech"),
            E.case(F.col("input") == "nke", "Retail"),
            E.case(F.col("input") == "dpz", "Food"),
        ])
        display(
            df
            .withColumn("output", sector_mapping)
        )

        +--------------+--------------+
        |input         |output        |
        +--------------+--------------+
        |aapl          |Tech          |
        +--------------+--------------+
        |nke           |Retail        |
        +--------------+--------------+
        |dpz           |Food          |
        +--------------+--------------+

    .. code-block:: python
        :caption: Using E.chain_assigns to define a mapping column. Notice that it can be
            easier to read the match value first and then the associated condition.

        from etl_toolkit import E, F

        df = spark.createDataFrame([
            {"input": "aapl"},
            {"input": "nke"},
            {"input": "dpz"},
        ])
        sector_mapping = E.chain_assigns([
            E.assign("Tech", F.col("input") == "aapl"),
            E.assign("Retail", F.col("input") == "nke"),
            E.assign("Food", F.col("input") == "dpz"),
        ])
        display(
            df
            .withColumn("output", sector_mapping)
        )

        +--------------+--------------+
        |input         |output        |
        +--------------+--------------+
        |aapl          |Tech          |
        +--------------+--------------+
        |nke           |Retail        |
        +--------------+--------------+
        |dpz           |Food          |
        +--------------+--------------+

    .. code-block:: python
        :caption: Using the otherwise parameter to specify default values for no matches.

        from etl_toolkit import E, F

        df = spark.createDataFrame([
            {"input": "aapl"},
            {"input": "nke"},
            {"input": "tsla"},
        ])
        sector_mapping = E.chain_cases([
            E.case(F.col("input") == "aapl", "Tech"),
            E.case(F.col("input") == "nke", "Retail"),
            E.case(F.col("input") == "dpz", "Food"),
        ], otherwise="N/A")
        display(
            df
            .withColumn("output", sector_mapping)
        )

        +--------------+--------------+
        |input         |output        |
        +--------------+--------------+
        |aapl          |Tech          |
        +--------------+--------------+
        |nke           |Retail        |
        +--------------+--------------+
        |tsla          |N/A           |
        +--------------+--------------+

    .. code-block:: python
        :caption: Using dict values in ``E.case`` will return a struct column for matches.
            The nested fields can also be accessed in pyspark directly if the full struct
            isn't desired. This approach works well when multiple columns need to be assigned
            for a given condition, rather than creating multiple case statement expressions
            with repeated logic.

        from etl_toolkit import E, F

        df = spark.createDataFrame([
            {"input": "aapl"},
            {"input": "nke"},
            {"input": "tsla"},
        ])
        sector_mapping = E.chain_cases([
            E.case(
                F.col("input") == "aapl",
                {
                    "sector": "Tech",
                    "subsector": "Consumer Devices",
                },
            ),
            E.case(
                F.col("input") == "nke",
                {
                    "sector": "Retail",
                    "subsector": "Apparel",
                },
            ),
        ])
        display(
            df
            .withColumn("output", sector_mapping)
            .withColumn("sector", sector_mapping.sector)
            .withColumn("subsector", sector_mapping.subsector)
        )

        +--------------+----------------------------------------------------+--------------+-----------------+
        |input         |output                                              |sector        |subsector        |
        +--------------+----------------------------------------------------+--------------+-----------------+
        |aapl          |{"sector": "Tech", "subsector": "Consumer Devices"} |Tech          |Consumer Devices |
        +--------------+----------------------------------------------------+--------------+-----------------+
        |nke           |{"sector": "Retail", "subsector": "Apparel"}        |Retail        |Apparel          |
        +--------------+----------------------------------------------------+--------------+-----------------+
        |tsla          |null                                                |null          |null             |
        +--------------+----------------------------------------------------+--------------+-----------------+
    """
    if not isinstance(conditions, list):
        raise ValueError("Must input a list of condition(s) (E.assign or E.case)")
    if len(conditions) == 0:
        raise ValueError(
            "At least 1 condition (E.assign or E.case) must be provided in the conditions list"
        )

    case_statement = F.when(*conditions[0].expression)
    for condition in conditions[1:]:
        case_statement = case_statement.when(*condition.expression)
    case_statement = case_statement.otherwise(normalize_literal(otherwise))
    return case_statement
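
# Illustrative sketch (not part of the original module) of the expression
# chain_cases builds for two conditions and a default:
#
#   chain_cases(
#       [case(F.col("x") == 1, "a"), case(F.col("x") == 2, "b")],
#       otherwise="c",
#   )
#   # is equivalent to
#   F.when(F.col("x") == 1, F.lit("a")) \
#       .when(F.col("x") == 2, F.lit("b")) \
#       .otherwise(F.lit("c"))
#   # (assuming normalize_literal wraps plain strings with F.lit)
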
# Alias to make naming consistent between E.assign and E.case
chain_assigns = chain_cases


def chain_cases_multiple(
    conditions: list[case], otherwise: Column | str = None
) -> Column:
    """
    Utility function to write case statements in Pyspark that return multiple matching cases.
    It accepts a list of conditions that are ``E.case`` or ``E.assign``.

    .. tip::
        ``E.chain_assigns_multiple`` and ``E.chain_cases_multiple`` are aliases of each other.
        They both produce the same results; however, stylistically it is better to use
        ``E.case`` with ``E.chain_cases_multiple`` and ``E.assign`` with
        ``E.chain_assigns_multiple``.

    The ``E.case`` function makes it straightforward to write when/then cases.
    For example, ``E.case(F.col('x') == 1, 'test')`` -> case when x=1 then 'test'.
    The case function accepts a boolean expression Column as the first argument and a
    ``Column``, ``string``, or ``dict`` as the second argument, which is returned if the
    boolean expression is True. If a string is specified for the second argument, it is
    treated as a literal.

    The ``E.assign`` function is similar to ``E.case``, however the arguments are reversed.
    For example, ``E.assign('test', F.col('x') == 1)`` -> case when x=1 then 'test'.
    The assign function can make for more readable code when defining a long series of
    mapping conditions, since it can be easier to see the match value first before the
    condition.

    A default value can be set via the ``otherwise`` argument if no cases are met.

    Unlike ``chain_cases``, this function returns an array of all matching cases, allowing
    multiple matches per row.

    :param conditions: List of ``E.case`` or ``E.assign`` conditions that are evaluated for
        matches. If a case matches, the associated value is added to the result array for
        the row. Must have at least one condition in the list.
    :param otherwise: A default value to use if no cases are matched from ``conditions``.
        If a string is passed it will be treated as a literal.
    :return: A Column containing an array of all matching case values.

    Examples
    -----------

    .. code-block:: python
        :caption: Using E.chain_cases_multiple to define a mapping column with multiple matches

        from etl_toolkit import E, F

        df = spark.createDataFrame([
            {"input": "aapl", "market_cap": 2000},
            {"input": "nke", "market_cap": 150},
            {"input": "tsla", "market_cap": 800},
        ])
        sector_mapping = E.chain_cases_multiple([
            E.case(F.col("input") == "aapl", "Tech"),
            E.case(F.col("input") == "nke", "Retail"),
            E.case(F.col("market_cap") > 500, "Large Cap"),
            E.case(F.col("market_cap") <= 500, "Small Cap"),
        ])
        display(
            df
            .withColumn("output", sector_mapping)
        )

        +--------------+------------+------------------------+
        |input         |market_cap  |output                  |
        +--------------+------------+------------------------+
        |aapl          |2000        |["Tech", "Large Cap"]   |
        +--------------+------------+------------------------+
        |nke           |150         |["Retail", "Small Cap"] |
        +--------------+------------+------------------------+
        |tsla          |800         |["Large Cap"]           |
        +--------------+------------+------------------------+

    .. code-block:: python
        :caption: Using the otherwise parameter to specify default values for no matches.

        from etl_toolkit import E, F

        df = spark.createDataFrame([
            {"input": "aapl", "market_cap": 2000},
            {"input": "nke", "market_cap": 150},
            {"input": "xyz", "market_cap": 50},
        ])
        sector_mapping = E.chain_cases_multiple([
            E.case(F.col("input") == "aapl", "Tech"),
            E.case(F.col("input") == "nke", "Retail"),
            E.case(F.col("market_cap") > 500, "Large Cap"),
        ], otherwise="Unknown")
        display(
            df
            .withColumn("output", sector_mapping)
        )

        +--------------+------------+------------------------+
        |input         |market_cap  |output                  |
        +--------------+------------+------------------------+
        |aapl          |2000        |["Tech", "Large Cap"]   |
        +--------------+------------+------------------------+
        |nke           |150         |["Retail"]              |
        +--------------+------------+------------------------+
        |xyz           |50          |["Unknown"]             |
        +--------------+------------+------------------------+
    """
    if not isinstance(conditions, list):
        raise ValueError("Must input a list of condition(s) (E.assign or E.case)")
    if len(conditions) == 0:
        raise ValueError(
            "At least 1 condition (E.assign or E.case) must be provided in the conditions list"
        )

    # Create an array with one element per condition: the matched value or null
    matches = F.array(
        *[
            F.when(
                normalize_column(condition.when), normalize_literal(condition.then)
            ).otherwise(F.lit(None))
            for condition in conditions
        ]
    )

    # Filter out the null entries so only matched values remain
    result = F.filter(matches, lambda x: x.isNotNull())

    # Fall back to [otherwise] if no cases matched (the filtered array is empty)
    if otherwise is not None:
        result = F.when(F.size(result) > 0, result).otherwise(
            F.array(normalize_literal(otherwise))
        )

    return result
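
# Illustrative sketch (not part of the original module): for each condition,
# chain_cases_multiple builds an array element that is either the matched
# value or null, then drops the nulls. For two conditions this is roughly:
#
#   F.filter(
#       F.array(
#           F.when(cond_1, F.lit("a")).otherwise(F.lit(None)),
#           F.when(cond_2, F.lit("b")).otherwise(F.lit(None)),
#       ),
#       lambda x: x.isNotNull(),
#   )
#
# The docstring above refers to ``E.chain_assigns_multiple`` as an alias of
# this function; if it is not defined elsewhere in the package, the alias
# would presumably be ``chain_assigns_multiple = chain_cases_multiple``.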