from typing import Literal
import re
from pyspark.sql import functions as F, Column
from yipit_databricks_utils.helpers.telemetry import track_usage
from pyspark.sql.types import (
StringType,
IntegerType,
LongType,
FloatType,
DoubleType,
DecimalType,
BooleanType,
TimestampType,
DateType,
)
from etl_toolkit.expressions.core import normalize_column
from etl_toolkit.expressions.time import normalize_date, normalize_timestamp
from etl_toolkit.exceptions import InvalidInputException
from etl_toolkit import expressions as E
@track_usage
def normalize_text(
col: Column | str, case: Literal["lower", "upper"] = "lower"
) -> Column:
"""
Clean string type columns by removing extra white spaces and optionally converting to lower or upper case.
:param col: Column to be normalized. The column should be of string type. If a string is passed, it will be referenced as a Column.
:param case: The desired case of the output column, should be either "lower" or "upper".
Examples
--------
.. code-block:: python
:caption: Using E.normalize_text to clean up a column.
Notice how leading and extra spaces are cleaned along with capitalization.
from etl_toolkit import E, F
df = spark.createDataFrame([
{"input": " This is a Test String"},
])
display(
df.withColumn("output", E.normalize_text("input"))
)
+------------------------+----------------------+
|input |output |
+------------------------+----------------------+
| This is a Test String |this is a test string |
+------------------------+----------------------+
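.. code-block:: python
:caption: Using E.normalize_text with case="upper" (illustrative sketch reusing the DataFrame above).
The same whitespace cleanup applies; only the casing differs.
from etl_toolkit import E, F
display(
df.withColumn("output", E.normalize_text("input", case="upper"))
)
+------------------------+----------------------+
|input                   |output                |
+------------------------+----------------------+
| This is a Test String  |THIS IS A TEST STRING |
+------------------------+----------------------+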
"""
col = normalize_column(col)
if case.lower() not in ("lower", "upper"):
raise ValueError(f"case: {case} is invalid. It can only be 'lower' or 'upper'")
case_col = F.lower(col)
if case.lower() == "upper":
case_col = F.upper(col)
# remove any extra spaces
new_col = F.trim(F.regexp_replace(case_col, r"\s\s+", " ").cast("string"))
return new_col
@track_usage
def try_cast(column: Column | str, to_type: str) -> Column:
"""
Safely attempts to cast a column to the specified type, returning null if the cast fails.
This is useful for handling potentially dirty data where some values may not conform to
the desired type.
:param column: The column to cast. Can be a Column object or string column name.
:param to_type: Target type as string. For decimal type, specify as 'decimal(precision,scale)'.
Other supported types: 'string', 'int', 'long', 'float', 'double',
'boolean', 'timestamp', 'date'
:return: A Column containing the cast values, with nulls where casting failed
Examples
--------
.. code-block:: python
:caption: Example of safely casting string values to integers
from etl_toolkit import E, F
df = spark.createDataFrame([
{"value": "123"},
{"value": "abc"},
{"value": "456"}
])
display(
df.withColumn("cast_value", E.try_cast("value", "int"))
)
+-------+-----------+
|value |cast_value |
+-------+-----------+
|123 |123 |
|abc |null |
|456 |456 |
+-------+-----------+
.. code-block:: python
:caption: Example of safely casting strings to dates
from etl_toolkit import E, F
df = spark.createDataFrame([
{"date_str": "2024-01-01"},
{"date_str": "invalid"},
{"date_str": "2024-02-01"}
])
display(
df.withColumn("parsed_date", E.try_cast("date_str", "date"))
)
+------------+------------+
|date_str |parsed_date |
+------------+------------+
|2024-01-01 |2024-01-01 |
|invalid |null |
|2024-02-01 |2024-02-01 |
+------------+------------+
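.. code-block:: python
:caption: try_cast also accepts a Column expression instead of a column name (illustrative, reusing the DataFrame above).
from etl_toolkit import E, F
display(
df.withColumn("parsed_date", E.try_cast(F.col("date_str"), "date"))
)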
.. code-block:: python
:caption: Example of casting to decimal with custom precision
from etl_toolkit import E, F
df = spark.createDataFrame([
{"value": "123.456"},
{"value": "invalid"},
{"value": "789.012"}
])
# Cast to decimal with precision 38 and scale 18
display(
df.withColumn("cast_value", E.try_cast("value", "decimal(38,18)"))
)
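.. code-block:: python
:caption: Passing "decimal" without precision and scale is rejected (illustrative).
from etl_toolkit import E, F
# Raises InvalidInputException: a bare "decimal" is ambiguous, so the
# precision and scale must be spelled out, e.g. "decimal(38,18)".
df.withColumn("cast_value", E.try_cast("value", "decimal"))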
"""
input_column = normalize_column(column)
if to_type == "date":
return normalize_date(input_column)
elif to_type == "timestamp":
return normalize_timestamp(input_column)
decimal_match = re.match(r"decimal\((\d+),(\d+)\)", to_type.lower())
if decimal_match:
precision, scale = map(int, decimal_match.groups())
sql_type = f"decimal({precision},{scale})"
elif to_type.lower() == "decimal":
raise InvalidInputException(
"Decimal type requires precision and scale. Use format: 'decimal(precision,scale)'"
)
else:
# The TRY_CAST expression below uses SQL type names, so map "long"
# to its SQL equivalent, BIGINT.
sql_type = to_type.lower().replace("long", "bigint")
if isinstance(column, str):
col_expr = f"`{column}`"
else:
# Recover the underlying column name from the Column's string
# representation (e.g. Column<'value'> -> value) so it can be
# referenced inside the SQL expression below.
col_str = str(input_column)
match = re.search(r"'([^']*)'", col_str)
col_name = match.group(1) if match else col_str
col_expr = f"`{col_name}`"
return F.expr(f"TRY_CAST({col_expr} AS {sql_type})")