from functools import reduce
from datetime import datetime, date
from pyspark.sql import functions as F, Column, DataFrame
from yipit_databricks_utils.helpers.telemetry import track_usage
from etl_toolkit.expressions.core import (
    normalize_literal,
    normalize_column,
    is_numeric_type,
    is_datetime_type,
)


def any(*conditions: list[Column | str] | Column | str) -> Column:
"""
Expression that can take a series of boolean expressions and wrap them in OR statements.
This will return a boolean Column that is True if any one condition is True. It is recommended you use ``E.any`` over
``|`` expressions in pyspark as it is easier to read and makes chaining multiple conditions simpler to write.
e.g. [x>2, y>3] -> (x>2) | (y>3)
Returns a python ``None`` object if no conditions are passed.
:param conditions: A series of ``Column`` expressions that can be passed in as a list or directly, separated by commas. If strings are passed, they will be treated as Columns. Expressions must be of boolean type.
Examples
-----------
.. code-block:: python
:caption: Simple OR expression using E.any
from etl_toolkit import E, F
df = spark.createDataFrame([
{"input": 100},
])
display(
df
.select(
"input",
E.any(
F.col("input") < 0,
F.col("input") < 1000
).alias("output")
)
)
+--------------+--------------+
|input |output |
+--------------+--------------+
|100 | True|
+--------------+--------------+
.. code-block:: python
:caption: Pass in any number of conditions to use in E.any. It can be helpful to pass a list if building conditions dynamically.
from etl_toolkit import E, F
df = spark.createDataFrame([
{"input": 100},
])
columns = ["input"]
conditions = [F.col(column_name) > 10 for column_name in columns]
display(
df
.select(
"input",
E.any(conditions).alias("output"),
)
)
+--------------+--------------+
|input |output |
+--------------+--------------+
|100 | True|
+--------------+--------------+
"""
    flattened_conditions = _flatten_expressions(conditions)
    if len(flattened_conditions) == 0:
        return None
    return reduce(lambda a, b: a | b, flattened_conditions)


def all(*conditions: list[Column | str] | Column | str) -> Column:
"""
Expression that can take a series of boolean expressions and wrap them in AND statements.
This will return a boolean Column that is True if all conditions are True. It is recommended you use ``E.all`` over
``&`` expressions in pyspark as it is easier to read and makes chaining multiple conditions simpler to write.
e.g. [x>2, y>3] -> (x>2) & (y>3)
Returns a python ``None`` object if no conditions are passed.
:param conditions: A series of ``Column`` expressions that can be passed in as a list or directly, separated by commas. If strings are passed, they will be treated as Columns. Expressions must be of boolean type.
Examples
-----------
.. code-block:: python
:caption: Simple AND expression using E.all
from etl_toolkit import E, F
df = spark.createDataFrame([
{"input": 100},
])
display(
df
.select(
"input",
E.all(
F.col("input") < 0,
F.col("input") < 1000
).alias("output")
)
)
+--------------+--------------+
|input |output |
+--------------+--------------+
|100 | False|
+--------------+--------------+
.. code-block:: python
:caption: Pass in any number of conditions to use in E.all. It can be helpful to pass a list if building conditions dynamically.
from etl_toolkit import E, F
df = spark.createDataFrame([
{"input": 100},
])
columns = ["input"]
conditions = [F.col(column_name) > 10 for column_name in columns]
display(
df
.select(
"input",
E.all(conditions).alias("output"),
)
)
+--------------+--------------+
|input |output |
+--------------+--------------+
|100 | True|
+--------------+--------------+
"""
    flattened_conditions = _flatten_expressions(conditions)
    if len(flattened_conditions) == 0:
        return None
    return reduce(lambda a, b: a & b, flattened_conditions)


def between(
    column: Column | str,
    lower_bound_column: Column | str | int | float | date | datetime,
    upper_bound_column: Column | str | int | float | date | datetime,
    include_lower_bound: bool = True,
    include_upper_bound: bool = True,
) -> Column:
"""
Expression to return a boolean that indicates if a column is within two other columns or values (bounds).
The definition considers values at the bounds valid, but can be adjusted using
``include_lower_bound`` and ``include_upper_bound``
e.g. [x, a, b] -> a <= x <= b
:param column: A Column to compare against the lower and upper bounds. The column must be a compatible type with the bounds. Passing a string will be evaluated as a Column.
:param lower_bound_column: A Column that represents the minimum bound of the expression. The column must be a compatible type with the primary column and other bound. Passing a non-column type will be evaluated as a literal.
:param upper_bound_column: A Column that represents the maximum bound of the expression. The column must be a compatible type with the primary column and other bound. Passing a non-column type will be evaluated as a literal.
:param include_lower_bound: A boolean flag that if ``True`` will consider values at the minimum bound valid (i.e. <= expression), otherwise values as the bound are not valid (i.e. < expression)
:param include_upper_bound: A boolean flag that if ``True`` will consider values at the maximum bound valid (i.e. >= expression), otherwise values as the bound are not valid (i.e. > expression)
Examples
-----------
.. code-block:: python
:caption: Using E.between to handle filtering outliers. Notice how the bounds can be expressed as python literal values or Columns.
from etl_toolkit import E, F
df = spark.createDataFrame([
{"input": 0},
{"input": 20},
{"input": 100},
{"input": 120},
{"input": 300},
])
display(
df
.where(E.between("input", 10, 200))
)
+--------------+
|input |
+--------------+
|20 |
+--------------+
|100 |
+--------------+
|120 |
+--------------+
.. code-block:: python
:caption: You can make the function not include bounds with the optional flags.
from etl_toolkit import E, F
df = spark.createDataFrame([
{"input": 0},
{"input": 20},
{"input": 100},
{"input": 120},
{"input": 300},
])
display(
df
.where(E.between("input", 20, 120, include_upper_bound=False, include_lower_bound=False))
)
+--------------+
|input |
+--------------+
|100 |
+--------------+
"""
    if is_numeric_type(column) or is_datetime_type(column):
        column = normalize_literal(column)
    else:
        column = normalize_column(column)

    if is_numeric_type(lower_bound_column) or is_datetime_type(lower_bound_column):
        lower_bound = normalize_literal(lower_bound_column)
    else:
        lower_bound = normalize_column(lower_bound_column)

    if is_numeric_type(upper_bound_column) or is_datetime_type(upper_bound_column):
        upper_bound = normalize_literal(upper_bound_column)
    else:
        upper_bound = normalize_column(upper_bound_column)
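    # Pick the comparison operators based on which bounds are inclusive.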
    if include_upper_bound and include_lower_bound:
        return all(
            [
                column >= lower_bound,
                column <= upper_bound,
            ]
        )
    elif include_lower_bound:
        return all(
            [
                column >= lower_bound,
                column < upper_bound,
            ]
        )
    elif include_upper_bound:
        return all(
            [
                column > lower_bound,
                column <= upper_bound,
            ]
        )
    return all(
        [
            column > lower_bound,
            column < upper_bound,
        ]
    )


def _flatten_expressions(expressions: list[Column | str] | Column | str) -> list:
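    """
    Flatten expressions passed as Columns, strings, or lists of either into a flat
    list of Columns (strings are normalized to Columns). A bare string is treated
    as a single column rather than being iterated character by character.
    """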
    # Given a single string, we shouldn't iterate over it
    if isinstance(expressions, str):
        return [normalize_column(expressions)]

    flattened_expressions = []
    for expression in expressions:
        if isinstance(expression, list):
            for item in expression:
                flattened_expressions.append(normalize_column(item))
        else:
            flattened_expressions.append(normalize_column(expression))
    return flattened_expressions