# Source code for etl_toolkit.writer.mutable_table

import json
from typing import Optional
from datetime import datetime

from pyspark.sql import functions as F, Window as W
from pyspark.sql import DataFrame as SparkDataFrame
from pyspark.sql.connect.dataframe import DataFrame

from yipit_databricks_utils.helpers.pyspark_utils import get_spark_session
from yipit_databricks_utils.helpers.telemetry import track_usage

from etl_toolkit.writer.table_create import (
    create_table,
    validate_catalog_name,
    validate_database_name,
    validate_table_name,
    validate_notebook_path,
)
from etl_toolkit.exceptions import BaseETLToolkitException


@track_usage
def create_mutable_table(
    *dfs: DataFrame | list[DataFrame],
    primary_column_name: str = "uuid",
    url_filter_column_names: Optional[list[str]] = None,
    column_definitions: Optional[list[dict]] = None,
    ag_grid_kwargs: Optional[dict] = None,
    dash_grid_options: Optional[dict] = None,
    color_scale: Optional[dict] = None,
    catalog_name: Optional[str] = None,
    database_name: Optional[str] = None,
    table_name: Optional[str] = None,
    **kwargs,
):
    """
    Create a table that supports the ``MutableTable`` data app structure.

    Can start from an initial series of dataframes that will have metadata fields
    added to them automatically.

    - Dataframes may be empty so long as they have a defined schema
    - Additional data app options can be passed in from this function and will be
      stored as table properties. These will be fetched when the data app renders
      to control MutableTable behavior.
    - After the table is created, visit the data app while specifying
      ``?table_name=<table_name>`` to see the table in the data app

    .. tip:: If using optional arguments, it is helpful to understand Dash AG grid
        and the options that are available to customize table behavior.

    :param dfs: One or more DataFrames to be written to the table. If multiple
        DataFrames are provided, they are appended to the table in the order they
        are given. A list (or tuple) of DataFrames is flattened in place.
    :param primary_column_name: The column name (``str``) of the table to be used
        as primary key. It must be a UUID, as new records that originate from the
        data app will be a UUID4.
    :param url_filter_column_names: List of column names that exist on the table
        that can be used to filter the table down via URL parameters when viewing
        in the data app.
    :param column_definitions: List of column definitions described as a list of
        dictionaries that affect the table behavior in the data app. This matches
        the syntax of column definitions for AG Grid generally. If not specified,
        the table's schema is used to establish reasonable defaults for AG Grid.
    :param ag_grid_kwargs: Optional dict of keyword argument values to be passed
        as ``kwargs`` to ``AgGrid`` when rendering in the data app.
    :param dash_grid_options: Optional dict of configuration values to be passed
        to the ``dashGridOptions`` argument of ``AgGrid`` when rendering in the
        data app.
    :param color_scale: Optional dict to override color styles for ``pending``,
        ``staged``, ``committed``, or ``synced`` rows. Each key should map to a
        css-valid color property.
    :param catalog_name: The name of the catalog to create the table in. If not
        provided, the catalog name will be derived from the notebook path.
    :param database_name: The name of the database to create the table in. If not
        provided, the database name will be derived from the notebook path.
    :param table_name: The name of the table to create. If not provided, the table
        name will be derived from the notebook path.
    :param kwargs: Additional keyword arguments forwarded to ``create_table``.
    :raises BaseETLToolkitException: If only some of the table name overrides are
        given, if an input is not a DataFrame, if a DataFrame lacks the primary
        key column, or if an option is not JSON-serializable.

    Examples
    --------
    .. code-block:: python
        :caption: Creating a mutable table

        import uuid
        from etl_toolkit import create_mutable_table

        df = spark.createDataFrame([
            {"uuid": str(uuid.uuid4()), "name": "Leopold"},
            {"uuid": str(uuid.uuid4()), "name": "Molly"}
        ])

        create_mutable_table(df, primary_column_name="uuid",)

    .. code-block:: python
        :caption: Creating a mutable table with custom configurations

        import uuid
        from etl_toolkit import create_mutable_table

        df = spark.createDataFrame([
            {"uuid": str(uuid.uuid4()), "name": "Leopold"},
            {"uuid": str(uuid.uuid4()), "name": "Molly"}
        ])

        create_mutable_table(
            df,
            primary_column_name="uuid",
            url_filter_column_names=["name"],
            column_definitions=[
                {
                    "field": "uuid",
                    "editable": False,
                    "cellDataType": "string",
                }
            ],
        )
    """
    _validate_table_name_inputs(
        catalog_name=catalog_name,
        database_name=database_name,
        table_name=table_name,
    )

    # The ALTER TABLE statements below need a concrete three-part name. When the
    # caller gave no overrides, resolve it from the notebook path (the same
    # fallback sync_mutable_table uses) instead of formatting "None.None.None".
    if any([catalog_name, database_name, table_name]):
        target_catalog, target_database, target_table = (
            catalog_name,
            database_name,
            table_name,
        )
    else:
        target_catalog, target_database, target_table = validate_notebook_path()
    full_table_name = f"{target_catalog}.{target_database}.{target_table}"

    # Validate and serialize the optional data app settings before anything is
    # written, so a bad option cannot leave a half-configured table behind.
    table_properties = {
        "ENABLED": "true",
        "ROW_ID_FIELD": primary_column_name,
    }
    optional_settings = (
        ("URL_FILTER_COLUMNS", "url_filter_column_names", url_filter_column_names),
        ("COLUMN_DEFINITIONS", "column_definitions", column_definitions),
        ("AG_GRID_KWARGS", "ag_grid_kwargs", ag_grid_kwargs),
        ("DASH_GRID_OPTIONS", "dash_grid_options", dash_grid_options),
        ("COLOR_SCALE", "color_scale", color_scale),
    )
    for property_suffix, argument_name, setting in optional_settings:
        if setting:
            _validate_json_serializable(setting, argument_name)
            table_properties[property_suffix] = json.dumps(setting)

    spark = get_spark_session()

    # The signature accepts DataFrames and lists of DataFrames; flatten so every
    # element is checked and extended individually.
    input_dfs = []
    for item in dfs:
        if isinstance(item, (list, tuple)):
            input_dfs.extend(item)
        else:
            input_dfs.append(item)

    for df in input_dfs:
        if not isinstance(df, (SparkDataFrame, DataFrame)):
            raise BaseETLToolkitException("Invalid dataframe type")
        if primary_column_name not in df.columns:
            raise BaseETLToolkitException("Invalid table schema, missing primary key")

    # Metadata columns start out NULL; sync_mutable_table fills them in when it
    # merges changelog records.
    metadata_column_types = {
        "_mt_record_primary_key": "string",
        "_mt_change_id": "string",
        "_mt_change_group_id": "string",
        "_mt_change_type": "string",
        "_mt_change_owner_id": "string",
        "_mt_is_staged": "boolean",
        "_mt_staged_timestamp": "timestamp",
        "_mt_is_committed": "boolean",
        "_mt_commit_id": "string",
        "_mt_committed_timestamp": "timestamp",
        "_mt_is_synced": "boolean",
        "_mt_synced_timestamp": "timestamp",
        "_mt_create_timestamp": "timestamp",
        "_mt_update_timestamp": "timestamp",
        "_mt_additional_metadata": "string",
    }
    final_dfs = [
        df.withColumns(
            {
                column_name: F.lit(None).cast(column_type)
                for column_name, column_type in metadata_column_types.items()
            }
        )
        for df in input_dfs
    ]

    # create_table receives the caller's original (possibly None) names so its own
    # defaulting behaviour is left untouched.
    create_table(
        *final_dfs,
        catalog_name=catalog_name,
        database_name=database_name,
        table_name=table_name,
        **kwargs,
    )

    # Store the data app configuration as table properties, one statement each.
    for property_suffix, property_value in table_properties.items():
        spark.sql(
            f"""
            ALTER TABLE {full_table_name}
            SET TBLPROPERTIES('atlas.mutable_table.{property_suffix}' = {_sql_string_literal(property_value)})
            """
        )


def _sql_string_literal(value: str) -> str:
    """
    Render ``value`` as a single-quoted Spark SQL string literal.

    Backslashes and single quotes are backslash-escaped so that JSON payloads
    (which contain ``\\"`` and ``\\uXXXX`` sequences) survive the SQL parser
    unchanged. This assumes the default parser setting
    (``spark.sql.parser.escapedStringLiterals`` disabled).
    """
    escaped = value.replace("\\", "\\\\").replace("'", "\\'")
    return f"'{escaped}'"
@track_usage
def sync_mutable_table(
    changelog_table_name,
    catalog_name: Optional[str] = None,
    database_name: Optional[str] = None,
    table_name: Optional[str] = None,
    connection_info: Optional[dict] = None,
):
    """
    Run this operation on a mutable table to sync committed changes from the data
    app to the delta table.

    - After syncing, any queries on the delta table will be up to date with edits
      from the data app
    - Only committed changes are synced. Staged changes are ignored.
    - After the sync completes, the data app state will be updated to reflect that
      the committed changes are now in a "synced" state

    :param changelog_table_name: The name of the data app changelog table.
    :param catalog_name: The name of the catalog of the mutable table. If not
        provided, the catalog name will be derived from the notebook path.
    :param database_name: The name of the database of the mutable table. If not
        provided, the database name will be derived from the notebook path.
    :param table_name: The name of the mutable table. If not provided, the table
        name will be derived from the notebook path.
    :param connection_info: Dictionary of postgres connection credentials. These
        are used to authenticate to the database instance.
    :raises BaseETLToolkitException: If required arguments are missing, or the
        target table was not created with ``create_mutable_table``.

    .. warning:: Make sure to have ``psycopg[binary]==3.2.9`` installed in the
        environment when running these functions.

    .. tip:: If using lakebase, the user should be a client ID for the service
        principal permissioned to connect to the database. The password should be
        an oauth token for the service principal given the workspace the lakebase
        instance lives in.

    Examples
    --------
    .. code-block:: python
        :caption: Syncing a mutable table

        from etl_toolkit import sync_mutable_table

        connection_info = {
            "dbname": "databricks_postgres",
            "user": client_id,
            "password": token,
            "port": 5432,
            "host": lakebase_host,
        }

        sync_mutable_table(
            changelog_table_name="yd_dbr_example_pg.public.changelog_records",
            connection_info=connection_info,
        )
    """
    if any([catalog_name, database_name, table_name]):
        _validate_table_name_inputs(
            catalog_name=catalog_name,
            database_name=database_name,
            table_name=table_name,
        )
    else:
        catalog_name, database_name, table_name = validate_notebook_path()

    if changelog_table_name is None:
        raise BaseETLToolkitException("Changelog table name must be specified")

    if connection_info is None:
        raise BaseETLToolkitException("Connection information must be specified")

    source_table = f"{catalog_name}.{database_name}.{table_name}"
    spark = get_spark_session()

    # Read the table properties once; a missing key must produce a clear error
    # rather than an AttributeError on a None row.
    table_properties = {
        row["key"]: row["value"]
        for row in spark.sql(f"SHOW TBLPROPERTIES {source_table}").collect()
    }
    enabled_flag = table_properties.get("atlas.mutable_table.ENABLED") or ""
    if enabled_flag.lower() != "true":
        raise BaseETLToolkitException(
            "Table is not a mutable table, check that `create_mutable_table` was used initially"
        )

    primary_key_field = table_properties.get("atlas.mutable_table.ROW_ID_FIELD")
    if not primary_key_field:
        raise BaseETLToolkitException(
            "Table is missing the `atlas.mutable_table.ROW_ID_FIELD` property, "
            "check that `create_mutable_table` was used initially"
        )

    # NOTE(review): this is a naive datetime holding UTC wall-clock time; it is
    # used both as a Spark literal and as a Postgres parameter - confirm both
    # sides interpret naive values as UTC.
    now = datetime.utcnow()

    from delta.tables import DeltaTable

    table = DeltaTable.forName(spark, source_table)
    source_df = spark.table(source_table)

    # Ignore metadata columns when identifying column-level edits for a sync operation
    source_df_trimmed = source_df.select(
        [col for col in source_df.columns if not col.startswith("_mt")]
    )

    # Fetch row-level changes that are committed but not synced from the data app
    df = (
        spark.table(changelog_table_name)
        .where(F.col("record_source_table") == source_table)
        .where(F.col("is_staged"))
        .where(F.col("is_committed"))
        .where(~F.col("is_synced"))
        .orderBy(F.asc("committed_timestamp"), F.asc("staged_timestamp"))
    )
    # Only the change ids are needed later to flag the records as synced.
    # NOTE(review): `df` is lazy and is evaluated again for the merges below, so
    # records committed in between may be merged without being flagged - confirm
    # whether a snapshot is needed.
    change_ids = [row["change_id"] for row in df.select("change_id").collect()]

    # Group changes based on primary key and order by commit timestamp
    # Merges will happen on each group, allowing for multiple row-level edits to be
    # processed in a single sync. staged_timestamp breaks ties between edits that
    # share a commit timestamp, matching the ordering of the changelog query above.
    df_parsed = df.withColumns(
        {
            "change_fields_parsed": F.from_json(
                "change_fields", source_df_trimmed.schema
            ),
            primary_key_field: F.col("record_primary_key"),
            "rank": F.row_number().over(
                W.partitionBy("record_primary_key").orderBy(
                    F.asc("committed_timestamp"),
                    F.asc("staged_timestamp"),
                )
            ),
        }
    )
    df_parsed.display()
    groupings = df_parsed.groupBy("rank").count().orderBy("rank").collect()

    # For each group merge in changes to the table
    # We preserve the row/column values if there wasn't an edit to it
    # Metadata columns are updated automatically
    merged_columns = _get_merge_columns(source_df_trimmed)
    metadata_columns = [
        F.col("new.record_primary_key").alias("_mt_record_primary_key"),
        F.col("new.change_id").alias("_mt_change_id"),
        F.col("new.change_group_id").alias("_mt_change_group_id"),
        F.col("new.change_type").alias("_mt_change_type"),
        F.col("new.change_owner_id").alias("_mt_change_owner_id"),
        F.col("new.is_staged").alias("_mt_is_staged"),
        F.col("new.staged_timestamp").alias("_mt_staged_timestamp"),
        F.col("new.commit_id").alias("_mt_commit_id"),
        F.col("new.is_committed").alias("_mt_is_committed"),
        F.col("new.committed_timestamp").alias("_mt_committed_timestamp"),
        F.lit(True).alias("_mt_is_synced"),
        F.lit(now).alias("_mt_synced_timestamp"),
        F.col("new.create_timestamp").alias("_mt_create_timestamp"),
        F.col("new.update_timestamp").alias("_mt_update_timestamp"),
        F.col("new.additional_metadata").alias("_mt_additional_metadata"),
    ]
    for group in groupings:
        print(
            f"Updating group {group.rank} / {len(groupings)} containing {group['count']} rows .."
        )
        joined = (
            df_parsed.alias("new")
            .where(F.col("rank") == group.rank)
            .join(
                source_df_trimmed.alias("original"),
                [primary_key_field],
                how="left",
            )
            .select(merged_columns + metadata_columns)
        )
        joined.display()
        (
            table.merge(
                joined, joined[primary_key_field] == table.toDF()[primary_key_field]
            )
            .whenNotMatchedInsertAll()
            .whenMatchedUpdateAll()
            .execute()
        )

    # Mark all changelog records that were processed as synced
    _mark_changelog_records_as_synced(
        change_ids,
        now,
        connection_info,
    )
def _mark_changelog_records_as_synced(
    primary_keys: list[str],
    current_time: datetime,
    connection_info: dict,
):
    """
    Sets a list of ChangelogRecord instances to be in a synced state in the
    Postgres database.

    :param primary_keys: ``change_id`` values of the changelog records to flag.
    :param current_time: Timestamp written to ``synced_timestamp``.
    :param connection_info: Postgres connection credentials
        (``dbname``, ``user``, ``password``, ``host``, ``port``).
    """
    # Nothing to flag: skip the database round trip entirely.
    if not primary_keys:
        return

    import psycopg

    for change_id in primary_keys:
        print(f"Updating change_id: {change_id} ..")

    # The connection context manager closes the connection on exit, so the
    # explicit commit is issued while the connection is still open.
    with psycopg.connect(_get_conn_string(connection_info)) as conn:
        with conn.cursor() as cur:
            cur.execute(
                """
                UPDATE public.changelog_records
                SET is_synced = TRUE, synced_timestamp = %s
                WHERE change_id = ANY(%s)
                """,
                (current_time, list(primary_keys)),
            )
        conn.commit()


def _get_merge_columns(
    source_df_trimmed: "DataFrame",
):
    """
    Build one column expression per source column that picks the edited value
    from the changelog record (alias ``new``) when that column changed, and the
    existing table value (alias ``original``) otherwise.

    :param source_df_trimmed: The mutable table's DataFrame without ``_mt``
        metadata columns; its ``dtypes`` drive the casts.
    :return: List of aliased column expressions, in source column order.
    """
    merged_columns = []
    for column_name, dtype in source_df_trimmed.dtypes:
        json_path = f"$.{column_name}"
        # Arrays and structs need real JSON parsing; scalars can be cast directly.
        is_nested = dtype.lower().startswith(("array", "struct"))

        def extract(json_column):
            raw_value = F.get_json_object(json_column, json_path)
            if is_nested:
                return F.from_json(raw_value, schema=dtype).cast(dtype)
            return raw_value.cast(dtype)

        new_value = extract("new.change_fields")
        prior_value = extract("new.prior_fields")

        # A column counts as changed only when it appears in change_fields AND
        # differs from prior_fields; this ignores fields in inserted records that
        # did not change from their prior value.
        is_changed = F.array_contains(
            F.json_object_keys("new.change_fields"),
            column_name,
        ) & ~new_value.eqNullSafe(prior_value)

        merged_columns.append(
            F.when(is_changed, new_value)
            .otherwise(F.col(f"original.{column_name}"))
            .alias(column_name)
        )
    return merged_columns


def _quote_conninfo_value(value) -> str:
    """
    Quote a libpq connection-string value when it needs it.

    Values that are empty or contain whitespace, single quotes or backslashes are
    wrapped in single quotes with ``\\`` and ``'`` backslash-escaped; all other
    values are returned unchanged.
    """
    text = str(value)
    if text and not any(ch.isspace() or ch in "'\\" for ch in text):
        return text
    escaped = text.replace("\\", "\\\\").replace("'", "\\'")
    return f"'{escaped}'"


def _get_conn_string(connection_info: dict) -> str:
    """
    Build a libpq connection string (with ``sslmode=require``) from the
    ``dbname``, ``user``, ``password``, ``host`` and ``port`` entries of
    ``connection_info``.

    :raises KeyError: If one of the required entries is missing.
    """
    keywords = ("dbname", "user", "password", "host", "port")
    parts = [
        f"{keyword}={_quote_conninfo_value(connection_info[keyword])}"
        for keyword in keywords
    ]
    parts.append("sslmode=require")
    return " ".join(parts)


def _validate_table_name_inputs(
    catalog_name: Optional[str] = None,
    database_name: Optional[str] = None,
    table_name: Optional[str] = None,
):
    """
    Validate manually supplied table name overrides.

    Does nothing when no override is given. When any override is given, all three
    must be present; each is lowercased and passed to its validator.

    :raises BaseETLToolkitException: If only some of the three names are given.
    """
    overrides = (catalog_name, database_name, table_name)
    if not any(overrides):
        return

    # If any override is provided, all must be provided
    if not all(overrides):
        raise BaseETLToolkitException(
            "When providing manual overrides, all three parameters must be specified: "
            "catalog_name, database_name, and table_name"
        )

    # Normalize and validate the provided names
    validate_catalog_name(catalog_name.lower())
    validate_database_name(database_name.lower())
    validate_table_name(table_name.lower())


def _validate_json_serializable(obj: dict, name: str):
    """
    Ensure ``obj`` can be serialized with ``json.dumps``.

    :param obj: The value to check.
    :param name: Argument name used in the error message.
    :raises BaseETLToolkitException: If ``obj`` is not JSON-serializable.
    """
    try:
        json.dumps(obj)
    except (TypeError, ValueError) as e:
        raise BaseETLToolkitException(
            f"{name} contains non-JSON-serializable data: {e}"
        ) from e