from dataclasses import dataclass
from itertools import chain
from pyspark.sql import functions as F, Column
from yipit_databricks_utils.helpers.telemetry import track_usage
from etl_toolkit.expressions.core import normalize_column, normalize_literal
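# A CaseStatement pairs a boolean `when` condition with the `then` value that is
# returned when the condition matches. chain_cases / chain_cases_multiple consume
# a list of these (built via `case` or `assign`) to assemble the final expression.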
@dataclass
class CaseStatement:
when: Column | str
then: Column | str | dict[str, str | Column]
@property
def expression(self) -> tuple[Column, Column]:
when = normalize_column(self.when)
# If `then` is a dict, build a named struct dynamically so the dict keys
# can be referenced directly in pyspark, e.g. mapping_col.sector
if isinstance(self.then, dict):
then = F.named_struct(
*list(
chain.from_iterable(
[
(normalize_literal(col), normalize_literal(value))
for col, value in self.then.items()
]
)
)
)
else:
then = normalize_literal(self.then)
return (when, then)
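# Short alias so conditions can be written as E.case(...) in pipelines.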
case = CaseStatement
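# `assign` mirrors `case` with the argument order reversed: the value to
# return comes first, followed by the boolean condition.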
def assign(
then: Column | str | dict[str, str | Column],
when: Column | str,
) -> case:
return case(when, then)
def chain_cases(conditions: list[case], otherwise: Column | str | None = None) -> Column:
"""
Utility function to write case statements in Pyspark more efficiently.
It accepts a list of conditions built with ``E.case`` or ``E.assign``.
.. tip:: ``E.chain_assigns`` and ``E.chain_cases`` are aliases of each other.
They both produce the same results; however, stylistically it is better to use
``E.case`` with ``E.chain_cases`` and ``E.assign`` with ``E.chain_assigns``.
The ``E.case`` function makes it straightforward to write when/then cases.
For example, ``E.case(F.col('x') == 1, 'test')`` -> case when x=1 then 'test'.
The case function accepts a boolean expression Column as the first argument and a
``Column``, ``string``, or ``dict`` for the second argument that is returned if the boolean expression is True.
If a string is specified for the second argument, it is treated as a literal.
The ``E.assign`` function is similar to ``E.case``, however the arguments are reversed.
For example, ``E.assign('test', F.col('x') == 1)`` -> case when x=1 then 'test'.
The assign function can make for more readable code when defining a long series of mapping conditions,
since it can be easier to see the match value first before the condition.
A default value can be set via the ``otherwise`` argument if no cases are met.
:param conditions: List of ``E.case`` or ``E.assign`` conditions that are evaluated in order for matches. If a case matches, the associated value is used for the row. Must have at least one condition in the list.
:param otherwise: A default value to use if no cases are matched from ``conditions``. If a string is passed it will be treated as a literal.
Examples
-----------
.. code-block:: python
:caption: Using E.chain_cases to define a mapping column
from etl_toolkit import E, F
df = spark.createDataFrame([
{"input": "aapl"},
{"input": "nke"},
{"input": "dpz"},
])
sector_mapping = E.chain_cases([
E.case(F.col("input") == "aapl", "Tech"),
E.case(F.col("input") == "nke", "Retail"),
E.case(F.col("input") == "dpz", "Food"),
])
display(
df
.withColumn("output", sector_mapping)
)
+--------------+--------------+
|input |output |
+--------------+--------------+
|aapl |Tech |
+--------------+--------------+
|nke |Retail |
+--------------+--------------+
|dpz |Food |
+--------------+--------------+
.. code-block:: python
:caption: Using E.chain_assigns to define a mapping column. Notice that it can be easier to read the match value first and then the associated condition.
from etl_toolkit import E, F
df = spark.createDataFrame([
{"input": "aapl"},
{"input": "nke"},
{"input": "dpz"},
])
sector_mapping = E.chain_assigns([
E.assign("Tech", F.col("input") == "aapl"),
E.assign("Retail", F.col("input") == "nke"),
E.assign("Food", F.col("input") == "dpz"),
])
display(
df
.withColumn("output", sector_mapping)
)
+--------------+--------------+
|input |output |
+--------------+--------------+
|aapl |Tech |
+--------------+--------------+
|nke |Retail |
+--------------+--------------+
|dpz |Food |
+--------------+--------------+
.. code-block:: python
:caption: Using the otherwise parameter to specify default values for no matches.
from etl_toolkit import E, F
df = spark.createDataFrame([
{"input": "aapl"},
{"input": "nke"},
{"input": "tsla"},
])
sector_mapping = E.chain_cases([
E.case(F.col("input") == "aapl", "Tech"),
E.case(F.col("input") == "nke", "Retail"),
E.case(F.col("input") == "dpz", "Food"),
], otherwise="N/A")
display(
df
.withColumn("output", sector_mapping)
)
+--------------+--------------+
|input |output |
+--------------+--------------+
|aapl |Tech |
+--------------+--------------+
|nke |Retail |
+--------------+--------------+
|tsla |N/A |
+--------------+--------------+
.. code-block:: python
:caption: Using dict values in ``E.case`` returns a struct column for matches.
The nested fields can also be accessed directly in pyspark if the full struct isn't desired.
This approach works well when multiple columns need to be assigned for a given condition,
rather than creating multiple case statement expressions with repeated logic.
from etl_toolkit import E, F
df = spark.createDataFrame([
{"input": "aapl"},
{"input": "nke"},
{"input": "dpz"},
])
sector_mapping = E.chain_cases([
E.case(
F.col("input") == "aapl",
{
"sector": "Technology",
"subsector": "Consumer Devices",
},
),
E.case(
F.col("input") == "nke",
{
"sector": "Retail",
"subsector": "Apparel",
},
),
])
display(
df
.withColumn("output", sector_mapping)
.withColumn("sector", sector_mapping.sector)
.withColumn("subsector", sector_mapping.subsector)
)
+--------------+-----------------------------------------------------------+--------------+-----------------+
|input         |output                                                     |sector        |subsector        |
+--------------+-----------------------------------------------------------+--------------+-----------------+
|aapl          |{"sector": "Technology", "subsector": "Consumer Devices"} |Technology    |Consumer Devices |
+--------------+-----------------------------------------------------------+--------------+-----------------+
|nke           |{"sector": "Retail", "subsector": "Apparel"}              |Retail        |Apparel          |
+--------------+-----------------------------------------------------------+--------------+-----------------+
|dpz           |null                                                       |null          |null             |
+--------------+-----------------------------------------------------------+--------------+-----------------+
"""
if not isinstance(conditions, list):
raise ValueError("Must input a list of condition(s) (E.assign or E.case)")
if len(conditions) == 0:
raise ValueError(
"At least 1 condition (E.assign or E.case) must be provided in the conditions list"
)
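# Build the CASE expression by chaining F.when calls, one per condition,
# evaluated in the order the conditions were provided.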
case_statement = F.when(*conditions[0].expression)
for condition in conditions[1:]:
case_statement = case_statement.when(*condition.expression)
case_statement = case_statement.otherwise(normalize_literal(otherwise))
return case_statement
# Alias to make naming consistent between E.assign and E.case
chain_assigns = chain_cases
def chain_cases_multiple(
    conditions: list[case], otherwise: Column | str | None = None
) -> Column:
"""
Utility function to write case statements in Pyspark that return multiple matching cases.
It accepts a list of conditions built with ``E.case`` or ``E.assign``.
.. tip:: ``E.chain_assigns_multiple`` and ``E.chain_cases_multiple`` are aliases of each other.
They both produce the same results; however, stylistically it is better to use
``E.case`` with ``E.chain_cases_multiple`` and ``E.assign`` with ``E.chain_assigns_multiple``.
The ``E.case`` function makes it straightforward to write when/then cases.
For example, ``E.case(F.col('x') == 1, 'test')`` -> case when x=1 then 'test'.
The case function accepts a boolean expression Column as the first argument and a
``Column``, ``string``, or ``dict`` for the second argument that is returned if the boolean expression is True.
If a string is specified for the second argument, it is treated as a literal.
The ``E.assign`` function is similar to ``E.case``, however the arguments are reversed.
For example, ``E.assign('test', F.col('x') == 1)`` -> case when x=1 then 'test'.
The assign function can make for more readable code when defining a long series of mapping conditions,
since it can be easier to see the match value first before the condition.
A default value can be set via the ``otherwise`` argument if no cases are met.
Unlike the original ``chain_cases``, this function returns an array of all matching cases,
allowing for multiple matches per row.
:param conditions: List of ``E.case`` or ``E.assign`` conditions that are evaluated for matches. If a case matches, the associated value is added to the result array for the row. Must have at least one condition in the list.
:param otherwise: A default value to use if no cases are matched from ``conditions``. If a string is passed it will be treated as a literal.
:return: A Column containing an array of all matching case values.
Examples
-----------
.. code-block:: python
:caption: Using E.chain_cases_multiple to define a mapping column with multiple matches
from etl_toolkit import E, F
df = spark.createDataFrame([
{"input": "aapl", "market_cap": 2000},
{"input": "nke", "market_cap": 150},
{"input": "tsla", "market_cap": 800},
])
sector_mapping = E.chain_cases_multiple([
E.case(F.col("input") == "aapl", "Tech"),
E.case(F.col("input") == "nke", "Retail"),
E.case(F.col("market_cap") > 500, "Large Cap"),
E.case(F.col("market_cap") <= 500, "Small Cap"),
])
display(
df
.withColumn("output", sector_mapping)
)
+--------------+------------+------------------------+
|input |market_cap |output |
+--------------+------------+------------------------+
|aapl |2000 |["Tech", "Large Cap"] |
+--------------+------------+------------------------+
|nke |150 |["Retail", "Small Cap"] |
+--------------+------------+------------------------+
|tsla |800 |["Large Cap"] |
+--------------+------------+------------------------+
.. code-block:: python
:caption: Using the otherwise parameter to specify default values for no matches.
from etl_toolkit import E, F
df = spark.createDataFrame([
{"input": "aapl", "market_cap": 2000},
{"input": "nke", "market_cap": 150},
{"input": "xyz", "market_cap": 50},
])
sector_mapping = E.chain_cases_multiple([
E.case(F.col("input") == "aapl", "Tech"),
E.case(F.col("input") == "nke", "Retail"),
E.case(F.col("market_cap") > 500, "Large Cap"),
], otherwise="Unknown")
display(
df
.withColumn("output", sector_mapping)
)
+--------------+------------+------------------------+
|input |market_cap |output |
+--------------+------------+------------------------+
|aapl |2000 |["Tech", "Large Cap"] |
+--------------+------------+------------------------+
|nke |150 |["Retail"] |
+--------------+------------+------------------------+
|xyz |50 |["Unknown"] |
+--------------+------------+------------------------+
"""
if not isinstance(conditions, list):
raise ValueError("Must input a list of condition(s) (E.assign or E.case)")
if len(conditions) == 0:
raise ValueError(
"At least 1 condition (E.assign or E.case) must be provided in the conditions list"
)
# Create an array of all matching cases
matches = F.array(
*[
F.when(
normalize_column(condition.when), normalize_literal(condition.then)
).otherwise(F.lit(None))
for condition in conditions
]
)
# Filter out None values
result = F.filter(matches, lambda x: x.isNotNull())
# Apply otherwise condition if the array is empty
if otherwise is not None:
result = F.when(F.size(result) > 0, result).otherwise(
F.array(normalize_literal(otherwise))
)
return result
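# Alias to make naming consistent between E.assign and E.case
# (the docstring above refers to E.chain_assigns_multiple).
chain_assigns_multiple = chain_cases_multiple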