Source code for etl_toolkit.expressions.normalization

from typing import Literal
import re

from pyspark.sql import functions as F, Column
from yipit_databricks_utils.helpers.telemetry import track_usage
from pyspark.sql.types import (
    StringType,
    IntegerType,
    LongType,
    FloatType,
    DoubleType,
    DecimalType,
    BooleanType,
    TimestampType,
    DateType,
)

from etl_toolkit.expressions.core import normalize_column
from etl_toolkit.expressions.time import normalize_date, normalize_timestamp
from etl_toolkit.exceptions import InvalidInputException

from etl_toolkit import expressions as E


@track_usage
def normalize_text(
    col: Column | str, case: Literal["lower", "upper"] = "lower"
) -> Column:
    """
    Clean string type columns by removing extra white spaces and optionally
    converting to lower or upper case.

    :param col: Column to be normalized. The column should be of string type.
        If a string is passed, it will be referenced as a Column.
    :param case: The desired case of the output column, should be either
        "lower" or "upper".

    Examples
    --------
    .. code-block:: python
        :caption: Using E.normalize_text to clean up a column. Notice how leading and extra spaces are cleaned along with capitalization.

        from etl_toolkit import E, F

        df = spark.createDataFrame([
            {"input": "  This is a  Test String"},
        ])
        display(
            df.withColumn("output", E.normalize_text("input"))
        )

        +------------------------+----------------------+
        |input                   |output                |
        +------------------------+----------------------+
        |  This is a  Test String|this is a test string |
        +------------------------+----------------------+
    """
    col = normalize_column(col)
    if case.lower() not in ("lower", "upper"):
        raise ValueError(f"case: {case} is invalid. It can only be 'lower' or 'upper'")
    case_col = F.lower(col)
    if case.lower() == "upper":
        case_col = F.upper(col)
    # remove any extra spaces
    new_col = F.trim(F.regexp_replace(case_col, r"\s\s+", " ").cast("string"))
    return new_col
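
# A minimal usage sketch for normalize_text's "upper" branch, which the
# docstring example above does not show. This is an illustrative assumption,
# not library code: it presumes a Spark session named `spark` is available
# (as in a notebook), mirroring the docstring examples.
#
#     df = spark.createDataFrame([{"input": "  mixed   Case  text "}])
#     df = df.withColumn("shout", normalize_text("input", case="upper"))
#     # "  mixed   Case  text "  ->  "MIXED CASE TEXT"
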
@track_usage
def try_cast(column: Column | str, to_type: str) -> Column:
    """
    Safely attempts to cast a column to the specified type, returning null if
    the cast fails. This is useful for handling potentially dirty data where
    some values may not conform to the desired type.

    :param column: The column to cast. Can be a Column object or string column name.
    :param to_type: Target type as string. For decimal type, specify as
        'decimal(precision,scale)'. Other supported types: 'string', 'int',
        'long', 'float', 'double', 'boolean', 'timestamp', 'date'
    :return: A Column containing the cast values, with nulls where casting failed

    Examples
    --------
    .. code-block:: python
        :caption: Example of safely casting string values to integers

        from etl_toolkit import E, F

        df = spark.createDataFrame([
            {"value": "123"},
            {"value": "abc"},
            {"value": "456"}
        ])
        display(
            df.withColumn("cast_value", E.try_cast("value", "int"))
        )

        +-------+-----------+
        |value  |cast_value |
        +-------+-----------+
        |123    |123        |
        |abc    |null       |
        |456    |456        |
        +-------+-----------+

    .. code-block:: python
        :caption: Example of safely casting strings to dates

        from etl_toolkit import E, F

        df = spark.createDataFrame([
            {"date_str": "2024-01-01"},
            {"date_str": "invalid"},
            {"date_str": "2024-02-01"}
        ])
        display(
            df.withColumn("parsed_date", E.try_cast("date_str", "date"))
        )

        +------------+------------+
        |date_str    |parsed_date |
        +------------+------------+
        |2024-01-01  |2024-01-01  |
        |invalid     |null        |
        |2024-02-01  |2024-02-01  |
        +------------+------------+

    .. code-block:: python
        :caption: Example of casting to decimal with custom precision

        from etl_toolkit import E, F

        df = spark.createDataFrame([
            {"value": "123.456"},
            {"value": "invalid"},
            {"value": "789.012"}
        ])
        # Cast to decimal with precision 38 and scale 18
        display(
            df.withColumn("cast_value", E.try_cast("value", "decimal(38,18)"))
        )
    """
    input_column = normalize_column(column)

    if to_type == "date":
        return normalize_date(input_column)
    elif to_type == "timestamp":
        return normalize_timestamp(input_column)

    decimal_match = re.match(r"decimal\((\d+),(\d+)\)", to_type.lower())
    if decimal_match:
        precision, scale = map(int, decimal_match.groups())
        sql_type = f"decimal({precision},{scale})"
    elif to_type.lower() == "decimal":
        raise InvalidInputException(
            "Decimal type requires precision and scale. Use format: 'decimal(precision,scale)'"
        )
    else:
        sql_type = to_type.lower().replace("long", "bigint")

    if isinstance(column, str):
        col_expr = f"`{column}`"
    else:
        # Extract the column name from the Column's string representation,
        # e.g. Column<'value'> -> value, so it can be referenced in SQL.
        col_str = str(input_column)
        match = re.search(r"'([^']*)'", col_str)
        col_name = match.group(1) if match else col_str
        col_expr = f"`{col_name}`"

    return F.expr(f"TRY_CAST({col_expr} AS {sql_type})")
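
# A hedged, self-contained demo of try_cast (an illustrative sketch, not part
# of the library's public surface). It assumes a local Spark session can be
# built and only runs when the module is executed directly, so importing the
# module is unaffected.
if __name__ == "__main__":
    from pyspark.sql import SparkSession

    spark = SparkSession.builder.master("local[1]").getOrCreate()
    demo_df = spark.createDataFrame(
        [{"value": "123.456"}, {"value": "not a number"}]
    )
    # Failed casts become null instead of raising, so dirty rows are kept.
    demo_df.select(
        "value",
        try_cast("value", "decimal(10,3)").alias("as_decimal"),
        try_cast("value", "int").alias("as_int"),
    ).show()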