from yipit_databricks_utils import create_table as create_table_ydbu
from yipit_databricks_utils.future import append_table
from yipit_databricks_utils.helpers.workspace import get_current_notebook_path
from yipit_databricks_utils.helpers.telemetry import track_usage
from yipit_databricks_utils.helpers.jobs import is_inside_job_run
from yipit_databricks_utils.helpers.dbutils import get_dbutils
from yipit_databricks_utils.helpers.logging import get_logger
from yipit_databricks_client import get_spark_session
from pyspark.sql.utils import AnalysisException
from pyspark.sql import DataFrame as SparkDataFrame
from pyspark.sql.connect.dataframe import DataFrame
import itertools
import os
import re
from etl_toolkit.exceptions import BaseETLToolkitException
from typing import Optional
logger = get_logger()
def validate_name(name: str, pattern: str, error_message: str):
if not re.match(pattern, name.lower()):
raise BaseETLToolkitException(error_message)
def validate_database_name(database_name: str):
pattern = r"^[a-z0-9_]+_(bronze|silver|gold|sandbox)$"
error_message = (
f"Invalid database name: {database_name}. "
"Database name must be lowercase, contain only letters, numbers, and underscores, "
"and end with _bronze, _silver, _gold, or _sandbox."
)
validate_name(database_name, pattern, error_message)
def validate_catalog_name(catalog_name: str):
pattern = r"^[a-z0-9_]+$"
error_message = (
f"Invalid catalog name: {catalog_name}. "
"Catalog name must be lowercase and contain only letters, numbers, and underscores."
)
validate_name(catalog_name, pattern, error_message)
def validate_table_name(table_name: str):
pattern = r"^[a-z0-9_]+$"
error_message = (
f"Invalid table name: {table_name}. "
"Table name must be lowercase and contain only letters, numbers, and underscores."
)
validate_name(table_name, pattern, error_message)
def validate_column_names(df: DataFrame):
valid_pattern = re.compile(r"^[a-z_][a-z0-9_]*$")
invalid_columns = [col for col in df.columns if not valid_pattern.match(col)]
if invalid_columns:
raise BaseETLToolkitException(
f"The following columns have invalid names: {', '.join(invalid_columns)}. "
"Column names must be lowercase, start with a letter or underscore, "
"and can only contain lowercase letters, numbers, and underscores."
)
def validate_dataframe(flat_dfs):
# Validate all DataFrames before writing
for df in flat_dfs:
if not isinstance(df, (SparkDataFrame, DataFrame)):
raise BaseETLToolkitException(
"All inputs must be Spark DataFrames or Spark Connect DataFrames."
)
validate_column_names(df)
@track_usage
def validate_notebook_path():
try:
path_components = get_current_notebook_path().split("/")[-3:]
# Normalize each component to lowercase before validation
catalog_name, database_name, table_name = [
comp.lower() for comp in path_components
]
validate_catalog_name(catalog_name)
validate_database_name(database_name)
validate_table_name(table_name)
return catalog_name, database_name, table_name
    except (ValueError, IndexError):
raise BaseETLToolkitException(
"etl_toolkit.create_table must be called within a golden-path structured repo "
"with the standard folder nesting for the calling notebook, where the full "
"path of the notebook reflects the table name being created. Please edit the "
            "notebook path to follow the golden path conventions. "
"The pattern for the path is: /<ticker>/pipelines/<catalog>/<database>/<table> "
"(ex: /lyft/pipelines/yd_prodicution/lyft_gold/receipts_deduped)."
)
def resolve_notebook_path(dbutils=None):
if dbutils is None:
dbutils = get_dbutils()
context = dbutils.notebook.entry_point.getDbutils().notebook().getContext()
in_repo = not context.mlflowGitRelativePath().isEmpty()
if in_repo:
notebook_path = context.mlflowGitRelativePath().get()
else:
notebook_path = get_current_notebook_path()
return notebook_path
def resolve_repo_url(dbutils=None):
    if dbutils is None:
        dbutils = get_dbutils()
    context = dbutils.notebook.entry_point.getDbutils().notebook().getContext()
    in_repo = not context.mlflowGitRelativePath().isEmpty()
if in_repo:
repo_url = context.mlflowGitRepoUrl().get()
else:
repo_url = None
return repo_url
@track_usage
def create_table(
*dfs: DataFrame | list[DataFrame],
catalog_name: Optional[str] = None,
database_name: Optional[str] = None,
table_name: Optional[str] = None,
**kwargs,
):
"""
Creates a new table or overwrites an existing table with the provided DataFrame(s).
The function supports two modes of operation:
1. Path-based (default): Uses the notebook path to determine table location
2. Manual override: Uses explicitly provided catalog, database, and table names (only for specific
engineering usage, not recommended otherwise)
This function is designed to simplify the process of creating and managing tables in a Databricks environment.
It resolves the table name based on the current notebook path, following a specific folder structure convention
as laid out in the Golden Path standards.
If multiple DataFrames are provided, they are appended to the table sequentially.
:param dfs: One or more DataFrames to be written to the table. If multiple DataFrames are provided,
they are appended to the table in the order they are given.
:param catalog_name: The name of the catalog to create the table in. If not provided, the catalog name
will be derived from the notebook path.
:param database_name: The name of the database to create the table in. If not provided, the database
name will be derived from the notebook path.
:param table_name: The name of the table to create. If not provided, the table name will be derived
from the notebook path.
    :param kwargs: Additional keyword arguments to be passed to the underlying create_table_ydbu function.
        Notable parameters include:

        - table_comment: A string comment to be added to the table metadata.
        - file_retention_days: Integer controlling how many days of table file history (lookback) to retain
          (see the retention example below).
        - minimum_versions: Integer specifying the minimum number of historical table versions to keep accessible.
.. warning:: Using this ``create_table`` function requires following the Golden Path standards.
Ensure that your notebook is located in the correct folder structure for proper table name resolution.
``create_table`` will use the folder structure to determine the catalog, database, and table name.
Incorrect pathing will result in errors.
The expected path structure, as outlined in the Golden Path standards, is:
``/<ticker>/pipelines/<catalog>/<database>/<table>``
For example, a notebook path of ``/uber/pipelines/yd_production/uber_gold/coverage``
        that calls ``create_table`` will generate a table named ``yd_production.uber_gold.coverage``.
    .. tip:: When multiple DataFrames are provided, they are appended to the table sequentially.
        The first DataFrame creates the table, and subsequent DataFrames are appended.
        This allows for flexible data loading patterns and more performant table creation as it
        avoids inefficient union operations.
Examples
--------
.. code-block:: python
:caption: Creating a table with a single DataFrame
from etl_toolkit import create_table
df = spark.createDataFrame([
{"id": 1, "name": "Leopold"},
{"id": 2, "name": "Molly"}
])
create_table(df)
.. code-block:: python
:caption: Creating a table with multiple DataFrames
from etl_toolkit import create_table
df1 = spark.createDataFrame([
{"id": 1, "name": "Leopold"}
])
df2 = spark.createDataFrame([
{"id": 2, "name": "Molly"}
])
create_table(df1, df2)
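
    A single list of DataFrames may also be passed; it is flattened before writing,
    so the call below is equivalent to the previous example.

    .. code-block:: python
        :caption: Creating a table from a list of DataFrames

        from etl_toolkit import create_table

        dfs = [
            spark.createDataFrame([{"id": 1, "name": "Leopold"}]),
            spark.createDataFrame([{"id": 2, "name": "Molly"}]),
        ]
        create_table(dfs)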
.. code-block:: python
:caption: Creating a table with a custom comment
from etl_toolkit import create_table
        df = spark.createDataFrame([
            {"id": 1, "name": "Leopold"},
            {"id": 2, "name": "Molly"},
            {"id": 3, "name": "Stephen"}
        ])
create_table(df, table_comment="This table contains information about characters in a novel.")
.. code-block:: python
:caption: Creating a table with custom overrides
from etl_toolkit import create_table
        df = spark.createDataFrame([
            {"id": 1, "name": "Leopold"},
            {"id": 2, "name": "Molly"},
            {"id": 3, "name": "Stephen"}
        ])
create_table(
df,
catalog_name='my_catalog',
database_name='my_database_gold',
table_name='my_table'
)
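
    The retention-related keyword arguments described above are forwarded to the
    underlying writer; the values below are purely illustrative.

    .. code-block:: python
        :caption: Creating a table with custom retention settings

        from etl_toolkit import create_table

        df = spark.createDataFrame([
            {"id": 1, "name": "Leopold"},
            {"id": 2, "name": "Molly"}
        ])
        create_table(
            df,
            file_retention_days=30,
            minimum_versions=5,
        )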
"""
dbutils = get_dbutils()
notebook_path = resolve_notebook_path(dbutils)
repo_url = resolve_repo_url(dbutils)
# Determine if we're using manual override or path-based naming
if any([catalog_name, database_name, table_name]):
# If any override is provided, all must be provided
if not all([catalog_name, database_name, table_name]):
raise BaseETLToolkitException(
"When providing manual overrides, all three parameters must be specified: "
"catalog_name, database_name, and table_name"
)
# Normalize and validate the provided names
catalog_name = catalog_name.lower()
database_name = database_name.lower()
table_name = table_name.lower()
validate_catalog_name(catalog_name)
validate_database_name(database_name)
validate_table_name(table_name)
else:
# Use path-based naming
catalog_name, database_name, table_name = validate_notebook_path()
final_table_comment = f"""
\n
------------------------------------------------------
This table was created via the `etl_toolkit`. Details:
* Repo URL: {repo_url}
* Repo Path: {notebook_path}
"""
if kwargs.get("table_comment"):
table_comment = kwargs.pop("table_comment")
final_table_comment = table_comment + final_table_comment
    # Flatten the inputs: DataFrames passed positionally or as a single list
    flat_dfs = (
        list(itertools.chain.from_iterable(dfs))
        if dfs and isinstance(dfs[0], list)
        else list(dfs)
    )
if not flat_dfs:
raise BaseETLToolkitException(
"You must input either a list of DataFrames or a single DataFrame."
)
validate_dataframe(flat_dfs)
table_stats = []
# Write DataFrames
for idx, df in enumerate(flat_dfs):
if idx == 0:
create_table_ydbu(
database_name,
table_name,
df,
overwrite=True,
catalog_name=catalog_name,
table_comment=final_table_comment,
**kwargs,
)
else:
            # Enable schema merging for appends without mutating the caller's kwargs,
            # and avoid passing spark_options twice to append_table
            spark_options = {**kwargs.get("spark_options", {}), "mergeSchema": True}
            append_table(
                database_name,
                table_name,
                df,
                catalog_name=catalog_name,
                spark_options=spark_options,
                **{k: v for k, v in kwargs.items() if k != "spark_options"},
            )
table_stats.append(
{
"df": df,
"catalog_name": catalog_name,
"database_name": database_name,
"table_name": table_name,
"is_created": idx == 0,
}
)
if not is_inside_job_run():
print_table_stats(table_stats)
def print_table_stats(table_stats):
for stats in table_stats:
df = stats["df"]
catalog_name = stats["catalog_name"]
database_name = stats["database_name"]
table_name = stats["table_name"]
is_created = stats["is_created"]
table_action = "Created" if is_created else "Appended to"
row_count = df.count()
print(f"\nTable {table_action}: {catalog_name}.{database_name}.{table_name}")
print(f"Row count: {row_count}")
print(f"Schema: {df.schema}")
print("Sample data:")
display(df.limit(10))
print("-" * 50) # Add a separator between table stats