from dataclasses import dataclass
from datetime import datetime, date, timedelta
from typing import Literal, Optional, List
from pyspark.sql import DataFrame, SparkSession, functions as F
from etl_toolkit import expressions as E, A
from etl_toolkit.exceptions import InvalidInputException
from yipit_databricks_utils.helpers.pyspark_utils import get_spark_session
from yipit_databricks_utils.helpers.telemetry import track_usage
@dataclass
class Holiday:
    """
    A named holiday that covers either a single day or a span of days. Instances are
    passed to ``calendar`` (via ``custom_holidays``) to define custom holidays.

    :param name: Name of the holiday
    :param start_date: First day of the holiday
    :param end_date: Last day of the holiday when it spans several days. Defaults to None,
        which makes the holiday a single day (``start_date``)
    :param description: Optional free-form description of the holiday. Defaults to None
    :return: A Holiday object
    :raises InvalidInputException: If end_date is given and falls before start_date

    Examples:

    .. code-block:: python
        :caption: Create a single day holiday

        from etl_toolkit import A
        from datetime import date

        holiday = A.Holiday(
            name="New Year's Day",
            start_date=date(2024, 1, 1)
        )

    .. code-block:: python
        :caption: Create a multi-day holiday period

        holiday = A.Holiday(
            name="Winter Break",
            start_date=date(2024, 12, 24),
            end_date=date(2024, 12, 26),
            description="Company winter holiday"
        )

    Notes:
        - Without an end_date the holiday is treated as a single day
        - ``to_row`` yields ``(name, start_date, end_date)``, substituting start_date for a
          missing end_date, for use in calendar generation
    """

    name: str
    start_date: date
    end_date: Optional[date] = None
    description: Optional[str] = None

    def __post_init__(self):
        # Single-day holidays have no range to validate.
        if not self.end_date:
            return
        if self.start_date > self.end_date:
            raise InvalidInputException(
                f"Holiday {self.name}: start_date {self.start_date} must be before end_date {self.end_date}"
            )

    @property
    def to_row(self) -> tuple:
        # A single-day holiday ends on the day it starts.
        last_day = self.end_date if self.end_date else self.start_date
        return (self.name, self.start_date, last_day)


# Lowercase alias; the calendar() docstring examples call it as A.holiday(...).
holiday = Holiday  # type: ignore # noqa
@dataclass
class Period:
    """
    A fiscal period (a quarter or a half) bounded by explicit start and end dates.
    Used to define custom fiscal periods in calendar configurations.

    :param start_date: Start date of the period
    :param end_date: End date of the period
    :param type: Type of period - either "quarter" or "half"
    :return: A Period object
    :raises InvalidInputException: If start_date is after end_date, or if type is not
        "quarter" or "half"

    Examples:

    .. code-block:: python
        :caption: Create a fiscal quarter

        from etl_toolkit import A
        from datetime import date

        q1 = A.Period(
            start_date=date(2024, 1, 1),
            end_date=date(2024, 3, 31),
            type="quarter"
        )

    .. code-block:: python
        :caption: Create a fiscal half

        h1 = A.Period(
            start_date=date(2024, 1, 1),
            end_date=date(2024, 6, 30),
            type="half"
        )

    Notes:
        - ``duration_days`` counts both the start and end dates
        - ``type`` is validated at construction time; any value other than "quarter" or
          "half" is rejected
    """

    start_date: date
    end_date: date
    type: Literal["quarter", "half"]

    def __post_init__(self):
        # Date ordering is checked first so callers see the same error as before when both
        # the dates and the type are invalid.
        if self.start_date > self.end_date:
            raise InvalidInputException(
                f"Period with {self.type}: start_date {self.start_date} must be before end_date {self.end_date}"
            )
        # ``Literal`` is only a static-typing hint; enforce the documented contract at runtime.
        if self.type not in ("quarter", "half"):
            raise InvalidInputException(
                f"Period type must be either 'quarter' or 'half', got {self.type!r}"
            )

    @property
    def duration_days(self) -> int:
        # +1 so that both endpoints are counted.
        return (self.end_date - self.start_date).days + 1


# Lowercase alias; the calendar() docstring examples call it as A.period(...).
period = Period
@dataclass
class CalendarConfiguration:
    """
    Calendar configuration settings consumed by the calendar-generation helpers.

    :param calendar_id: Identifier for this calendar configuration
    :param start_date: Start date of the calendar range
    :param end_date: End date of the calendar range
    :param quarters: Optional list of custom quarter Period objects
    :param halves: Optional list of custom half-year Period objects
    :param holidays: Optional list of custom Holiday objects
    :param country_holiday_codes: Optional list of country codes for default holidays.
        Defaults to None (no default holidays)
    :param fiscal_year_start_month: Month when fiscal year starts (1-12)
    :param fiscal_year_start_day: Day of month when fiscal year starts (1-31)
    :raises InvalidInputException: If start_date is after end_date or if fiscal year config is invalid
    """

    calendar_id: str
    start_date: date
    end_date: date
    quarters: Optional[List[Period]] = None
    halves: Optional[List[Period]] = None
    holidays: Optional[List[Holiday]] = None
    # Must be None, not a tuple: _get_default_holidays relies on a falsy value to mean
    # "no country codes", and a one-element tuple like (None,) is truthy.
    country_holiday_codes: Optional[List[str]] = None
    fiscal_year_start_month: int = 1
    fiscal_year_start_day: int = 1

    def __post_init__(self):
        self.validate_dates()
        self.validate_fiscal_config()

    def validate_dates(self) -> None:
        """Raise InvalidInputException when start_date is after end_date."""
        if self.start_date > self.end_date:
            raise InvalidInputException(
                f"start_date {self.start_date} must be before end_date {self.end_date}"
            )

    def validate_fiscal_config(self) -> None:
        """Raise InvalidInputException when the fiscal month/day are out of range."""
        if self.fiscal_year_start_month not in range(1, 13):
            raise InvalidInputException(
                f"fiscal_year_start_month must be between 1 and 12, got {self.fiscal_year_start_month}"
            )
        if self.fiscal_year_start_day not in range(1, 32):
            raise InvalidInputException(
                f"fiscal_year_start_day must be between 1 and 31, got {self.fiscal_year_start_day}"
            )
def _generate_base_calendar(
    config: CalendarConfiguration, spark: SparkSession
) -> DataFrame:
    """
    Build the spine of the calendar: one row per DAY step that ``A.periods`` produces
    between ``config.start_date`` and ``config.end_date``.

    :param config: Calendar configuration supplying the date range
    :param spark: SparkSession used by ``A.periods``
    :return: DataFrame with a single ``day`` column (``period_start`` passed through
        ``E.normalize_date``)
    """
    daily_periods = A.periods(
        start=config.start_date,
        end=config.end_date,
        step_unit="DAY",
        spark=spark,
    )
    day_column = E.normalize_date("period_start").alias("day")
    return daily_periods.select(day_column)
def _get_month_periods(config: CalendarConfiguration, spark: SparkSession) -> DataFrame:
    """
    Return the calendar months touched by the configured date range.

    Months are derived from daily steps truncated to the first of the month, as
    ``_get_week_periods`` does for weeks. Stepping by MONTH from ``config.start_date``
    would carry a mid-month start date into every step (the default start date is
    today minus 5*365 days). The month windows would then be misaligned and some days
    would fall outside every window.

    :param config: Calendar configuration supplying the date range
    :param spark: SparkSession used by ``A.periods``
    :return: DataFrame with ``period_start`` (first day of month) and ``period_end``
        (last day of month), one row per month
    """
    return (
        A.periods(
            start=config.start_date,
            end=config.end_date,
            step_unit="DAY",
            steps=1,
            spark=spark,
        )
        .select(F.trunc(F.col("period_start"), "month").alias("period_start"))
        .distinct()
        .withColumn("period_end", F.last_day(F.col("period_start")))
    )
def _get_week_periods(config: CalendarConfiguration, spark: SparkSession) -> DataFrame:
    """
    Return the distinct weeks touched by the configured date range.

    Every day in the range is truncated to its week start with ``E.date_trunc("WEEK", ...)``;
    each week ends six days after it starts.

    NOTE(review): Spark's built-in ``date_trunc('WEEK')`` truncates to Monday, while the
    ``calendar`` docs describe Sunday-Saturday weeks. Confirm which convention
    ``E.date_trunc`` follows.

    :param config: Calendar configuration supplying the date range
    :param spark: SparkSession used by ``A.periods``
    :return: DataFrame with distinct ``period_start`` / ``period_end`` pairs
    """
    daily_periods = A.periods(
        start=config.start_date,
        end=config.end_date,
        step_unit="DAY",
        steps=1,
        spark=spark,
    )
    truncated = daily_periods.withColumn(
        "week_start", E.date_trunc("WEEK", "period_start")
    )
    week_bounds = truncated.select(
        F.col("week_start").alias("period_start"),
        F.date_add(F.col("week_start"), 6).alias("period_end"),
    )
    return week_bounds.distinct()
def _get_quarter_periods(
    config: CalendarConfiguration, spark: SparkSession
) -> DataFrame:
    """
    Return quarter boundaries for the calendar.

    When ``config.quarters`` is non-empty, those custom periods are used as-is and
    ``quarter_label`` is null, because no labelling scheme for arbitrary fiscal quarters can
    be inferred here. Otherwise calendar quarters are derived from the daily range truncated
    to the quarter start. Stepping by QUARTER from an unaligned ``start_date`` would leave
    gaps between ``period_end`` and the next start.

    :param config: Calendar configuration supplying custom quarters or the date range
    :param spark: SparkSession used to build the dataframe
    :return: DataFrame with ``quarter_period_start``, ``quarter_period_end`` and ``quarter_label``
    """
    if config.quarters:
        return spark.createDataFrame(
            [(q.start_date, q.end_date) for q in config.quarters],
            ["quarter_period_start", "quarter_period_end"],
        ).withColumn("quarter_label", F.lit(None).cast("string"))

    return (
        A.periods(
            start=config.start_date,
            end=config.end_date,
            step_unit="DAY",
            steps=1,
            spark=spark,
        )
        .select(F.trunc(F.col("period_start"), "quarter").alias("quarter_period_start"))
        .distinct()
        .withColumns(
            {
                # Last day of the third month of the quarter.
                "quarter_period_end": F.last_day(
                    F.add_months(F.col("quarter_period_start"), 2)
                ),
                "quarter_label": E.quarter_label("quarter_period_start"),
            }
        )
    )
def _add_period_metadata(
    df: DataFrame, config: CalendarConfiguration, spark: SparkSession
) -> DataFrame:
    """
    Attach month, week and quarter boundaries to every calendar day.

    Each period dataframe is left-joined on ``E.between(day, start, end)``, so a day with no
    matching period keeps null period columns. The month and week frames both use the
    generic names ``period_start`` / ``period_end``. These are renamed right after each join
    so the following join does not see duplicate column names.

    :param df: Calendar dataframe with a ``day`` column
    :param config: Calendar configuration passed to the period builders
    :param spark: SparkSession passed to the period builders
    :return: ``df`` plus month_*, week_* and quarter_* columns
    """
    months = _get_month_periods(config, spark)
    weeks = _get_week_periods(config, spark)
    quarters = _get_quarter_periods(config, spark)

    with_months = (
        df.join(
            months,
            E.between(df.day, months.period_start, months.period_end),
            "left",
        )
        .withColumnRenamed("period_start", "month_period_start")
        .withColumnRenamed("period_end", "month_period_end")
    )
    with_weeks = (
        with_months.join(
            weeks,
            E.between(df.day, weeks.period_start, weeks.period_end),
            "left",
        )
        .withColumnRenamed("period_start", "week_period_start")
        .withColumnRenamed("period_end", "week_period_end")
    )
    quarter_match = E.between(
        df.day, quarters.quarter_period_start, quarters.quarter_period_end
    )
    return with_weeks.join(quarters, quarter_match, "left")
def _get_default_holidays(
    spark: SparkSession,
    start_date: date,
    end_date: date,
    country_codes: Optional[List[str]] = None,
    table_name: str = "yd_lib_etl_toolkit.test_fixtures_bronze.calendar_holidays",
) -> DataFrame:
    """
    Get national (region-less) holidays from the calendar holidays table.

    The result has exactly one row per ``holiday_date``. Joining it onto the calendar therefore
    never duplicates a day, even when several requested countries, or several holidays, fall
    on the same date. In that case the names are de-duplicated, sorted and joined with ", ".

    :param spark: SparkSession used to read the table
    :param start_date: First holiday date to include
    :param end_date: Last holiday date to include
    :param country_codes: Country codes to include. When empty or None, an empty frame is returned
    :param table_name: Fully qualified table to read holidays from
    :return: DataFrame with ``holiday_date`` and ``holiday_name`` columns
    """
    source = spark.table(table_name)
    if not country_codes:
        # Empty, but with the same two-column schema as the populated branch so callers
        # see consistent columns either way.
        return source.select(F.col("holiday_date"), F.col("holiday_name")).limit(0)

    # Sorting makes the combined name deterministic (collect_set has no defined order).
    names = F.array_sort(F.collect_set(F.col("holiday_name")))
    return (
        source.where(
            E.all(
                F.col("holiday_date").between(start_date, end_date),
                F.col("country_code").isin(country_codes),
                F.col("region_code").isNull(),
            )
        )
        .groupBy("holiday_date")
        .agg(
            # Keep the name null when every matched row had a null name, so such dates are
            # still not flagged as holidays downstream.
            F.when(F.size(names) > 0, F.array_join(names, ", ")).alias("holiday_name")
        )
    )
def _add_holiday_metadata(
    df: DataFrame, config: CalendarConfiguration, spark: SparkSession
) -> DataFrame:
    """
    Attach holiday information and business-day flags to every calendar day.

    Custom holidays (``config.holidays``) take precedence and are range-joined on
    ``E.between(day, holiday_start, holiday_end)``. Overlapping custom holidays therefore yield
    one row per matching holiday. Otherwise default holidays for ``config.country_holiday_codes``
    are equi-joined on ``holiday_date``.

    :param df: Calendar dataframe with a ``day`` column
    :param config: Calendar configuration with custom holidays and/or country codes
    :param spark: SparkSession used to build/read holiday data
    :return: ``df`` with the holiday join columns plus ``is_holiday`` and ``is_business_day``
    """
    if config.holidays:
        # Custom holidays take precedence
        holidays_df = spark.createDataFrame(
            [h.to_row for h in config.holidays],
            ["holiday_name", "holiday_start", "holiday_end"],
        )
        df = df.join(
            holidays_df,
            E.between(df.day, holidays_df.holiday_start, holidays_df.holiday_end),
            "left",
        )
    else:
        # Use default holidays from the calendar_holidays table
        holidays_df = _get_default_holidays(
            spark, config.start_date, config.end_date, config.country_holiday_codes
        )
        df = df.join(holidays_df, df.day == holidays_df.holiday_date, "left")

    # withColumns resolves every expression against the *input* DataFrame, so
    # is_business_day cannot reference the is_holiday column added in the same call.
    # Build the flag once and reuse the expression instead.
    is_holiday = F.col("holiday_name").isNotNull()
    # Spark dayofweek: 1 = Sunday, 7 = Saturday.
    is_weekend = F.dayofweek(F.col("day")).isin([1, 7])
    return df.withColumns(
        {
            "is_holiday": is_holiday,
            "is_business_day": ~is_weekend & ~is_holiday,
        }
    )
def _generate_halves_from_quarters(quarters: List[Period]) -> List[Period]:
    """
    Pair consecutive quarters into half-year periods.

    Quarters are ordered by start date and grouped two at a time. Each half starts on the
    first quarter's start date and ends on the second quarter's end date. A trailing
    unpaired quarter is ignored.

    :param quarters: List of quarter Period objects
    :return: List of half-year Period objects; empty when fewer than four quarters are given
    """
    # Fewer than four quarters is treated as "no halves".
    if not quarters or len(quarters) < 4:
        return []
    ordered = sorted(quarters, key=lambda quarter: quarter.start_date)
    # zip stops at the shorter slice, which drops an odd trailing quarter.
    return [
        Period(start_date=first.start_date, end_date=second.end_date, type="half")
        for first, second in zip(ordered[0::2], ordered[1::2])
    ]
def _add_half_year_metadata(
    df: DataFrame, config: CalendarConfiguration, spark: SparkSession
) -> DataFrame:
    """
    Attach half-year boundaries and labels to every calendar day.

    When ``config.halves`` is non-empty, days are left-joined onto those custom periods with
    ``E.between`` and ``half_year_label`` is null, because the fiscal numbering of an arbitrary
    list of halves cannot be inferred here. Otherwise calendar halves are used:
    January-June is ``1HY<yy>`` and July-December is ``2HY<yy>``.

    :param df: Calendar dataframe with a ``day`` column
    :param config: Calendar configuration providing optional custom halves
    :param spark: SparkSession used to materialize custom halves
    :return: ``df`` plus ``half_year_period_start``, ``half_year_period_end``, ``half_year_label``
    """
    if config.halves:
        halves_df = spark.createDataFrame(
            [(h.start_date, h.end_date) for h in config.halves],
            ["half_year_period_start", "half_year_period_end"],
        )
        return df.join(
            halves_df,
            E.between(
                df.day, halves_df.half_year_period_start, halves_df.half_year_period_end
            ),
            "left",
        ).withColumn("half_year_label", F.lit(None).cast("string"))

    day = F.col("day")
    in_first_half = F.month(day) <= 6
    year_start = F.trunc(day, "year")
    half_start = F.when(in_first_half, year_start).otherwise(F.add_months(year_start, 6))
    return df.withColumns(
        {
            "half_year_period_start": half_start,
            # Last day of the sixth month of the half.
            "half_year_period_end": F.last_day(F.add_months(half_start, 5)),
            "half_year_label": F.concat(
                F.when(in_first_half, F.lit("1")).otherwise(F.lit("2")),
                F.lit("HY"),
                F.date_format(day, "yy"),
            ),
        }
    )


def _days_elapsed(period_start_column: str):
    """Return the 1-based day number of ``day`` within the period starting at ``period_start_column``."""
    return (F.datediff(F.col("day"), F.col(period_start_column)) + 1).cast("int")


@track_usage
def calendar(
    calendar_id: str = "standard",
    start_date: Optional[date] = None,
    end_date: Optional[date] = None,
    custom_holidays: Optional[List[Holiday]] = None,
    country_holiday_codes: Optional[List[str]] = None,
    quarters: Optional[List[Period]] = None,
    halves: Optional[List[Period]] = None,
    fiscal_year_start_month: int = 1,
    fiscal_year_start_day: int = 1,
    qa: bool = False,
    spark: Optional[SparkSession] = None,
) -> DataFrame:
    """
    Generate a dataframe with periodicities (day, week, month, quarter, half year, year) and calendar
    metadata like holidays and business days. The calendar can be customized with custom holidays
    and period definitions.

    The function returns calendar data formatted similarly to the Freeport calendar features with standardized fields
    for dates, periods, and calendar metadata.

    :param calendar_id: Identifier for the calendar configuration to use. Defaults to "standard". Any other value
        adds the leap-year columns listed below.
    :param start_date: Optional minimum date to include in the calendar. Defaults to 5 * 365 days before today.
    :param end_date: Optional maximum date to include in the calendar. Defaults to 5 * 365 days after today.
    :param custom_holidays: Optional list of Holiday objects defining custom holidays to include. When given,
        these replace the default country holidays.
    :param country_holiday_codes: Optional list of two letter country codes to include default holidays.
    :param quarters: Optional list of Period objects defining custom quarter date ranges.
    :param halves: Optional list of Period objects defining custom half-year date ranges. When omitted and
        ``quarters`` is given, halves are built by pairing consecutive quarters.
    :param fiscal_year_start_month: Month when fiscal year starts (1-12). Defaults to 1 (January). Currently it is
        validated and used only by the QA columns.
    :param fiscal_year_start_day: Day of month when fiscal year starts (1-31). Defaults to 1. Currently it is
        only validated.
    :param qa: Enable QA mode to add additional columns for validation.
    :param spark: Optional SparkSession. If not provided, will attempt to get session from yipit_databricks_utils. Generally, this is **not needed** as the session is automatically generated in databricks. It is used by library developers.
    :return: DataFrame: Calendar data ordered by ``day`` with the following columns:
        - day (date): The calendar date
        - calendar_type (string): The ``calendar_id`` that was passed in
        - year_label (string): Calendar year of ``day`` (e.g. 2024)
        - year_period_start (date): Start date of year period
        - year_period_end (date): End date of year period
        - quarter_period_start (date): Start date of quarter
        - quarter_period_end (date): End date of quarter
        - quarter_label (string): Quarter label (e.g. 1Q24); not populated for custom quarters
        - month_period_start (date): Start date of month
        - month_period_end (date): End date of month
        - week_period_start (date): Start of week (as computed by ``E.date_trunc("WEEK", ...)``)
        - week_period_end (date): End of week (six days after week_period_start)
        - half_year_period_start (date): Start of half-year period
        - half_year_period_end (date): End of half-year period
        - half_year_label (string): Half-year label (e.g. 1HY24); null for custom halves
        - is_business_day (boolean): Not a Saturday/Sunday and not a holiday
        - is_holiday (boolean): Is this a holiday
        - holiday_name (string): Name of holiday if applicable
        - days_in_week (int): Day number within the current week (1 on the first day)
        - days_in_month (int): Day number within the current month (1 on the first day)
        - days_in_quarter (int): Day number within the current quarter (1 on the first day)
        - days_in_year (int): Day number within the current year (1 on the first day)
        - days_in_half_year (int): Day number within the current half year (1 on the first day)

        Additional columns when ``calendar_id`` is not "standard":
        - leap_year (boolean): Is this a leap year
        - leap_day (boolean): Is this a leap day (Feb 29)

        Additional columns when ``qa`` is True:
        - is_fiscal_period (boolean), fiscal_year_offset (int), period_type (string)

    Examples
    ^^^^^^^^^^

    .. code-block:: python
        :caption: Generate standard calendar for default date range

        from etl_toolkit import A

        df = A.calendar()

    .. code-block:: python
        :caption: Generate calendar with specific date range

        from etl_toolkit import A
        from datetime import date

        df = A.calendar(
            start_date=date(2024, 1, 1),
            end_date=date(2024, 12, 31)
        )

    .. code-block:: python
        :caption: Generate calendar with custom holidays

        from etl_toolkit import A
        from datetime import date

        df = A.calendar(
            custom_holidays=[
                A.holiday(
                    name="Company Holiday",
                    start_date=date(2024, 12, 24),
                    end_date=date(2024, 12, 26)
                )
            ]
        )

    .. code-block:: python
        :caption: Generate calendar with custom fiscal quarters

        from etl_toolkit import A
        from datetime import date

        df = A.calendar(
            quarters=[
                A.period(
                    start_date=date(2024, 10, 1),
                    end_date=date(2024, 12, 31),
                    type="quarter"
                ),
                A.period(
                    start_date=date(2025, 1, 1),
                    end_date=date(2025, 3, 31),
                    type="quarter"
                ),
                A.period(
                    start_date=date(2025, 4, 1),
                    end_date=date(2025, 6, 30),
                    type="quarter"
                ),
                A.period(
                    start_date=date(2025, 7, 1),
                    end_date=date(2025, 9, 30),
                    type="quarter"
                )
            ]
        )

    .. code-block:: python
        :caption: Generate calendar with holidays from multiple countries

        from etl_toolkit import A
        from datetime import date

        df = A.calendar(
            start_date=date(2024, 1, 1),
            end_date=date(2024, 12, 31),
            country_holiday_codes=["US", "GB", "JP"]  # Include US, UK, and Japan holidays
        )
    """
    spark = spark or get_spark_session()

    # Default to a window of 5 * 365 days on either side of today. Read the clock once so
    # both bounds are computed from the same "today".
    today = datetime.now().date()
    if start_date is None:
        start_date = today - timedelta(days=5 * 365)
    if end_date is None:
        end_date = today + timedelta(days=5 * 365)

    # Custom quarters imply custom halves unless halves were supplied explicitly.
    if quarters and not halves:
        halves = _generate_halves_from_quarters(quarters)

    config = CalendarConfiguration(
        calendar_id=calendar_id,
        start_date=start_date,
        end_date=end_date,
        quarters=quarters,
        halves=halves,
        holidays=custom_holidays,
        country_holiday_codes=country_holiday_codes,
        fiscal_year_start_month=fiscal_year_start_month,
        fiscal_year_start_day=fiscal_year_start_day,
    )

    # Generate base calendar
    df = _generate_base_calendar(config, spark)
    df = df.withColumn("calendar_type", F.lit(calendar_id))
    df = df.withColumn("year_label", F.year(F.col("day")).cast("string"))

    # Add month / week / quarter period metadata
    df = _add_period_metadata(df, config, spark)

    # Add year period information
    df = df.withColumns(
        {
            "year_period_start": E.date_trunc("YEAR", "day"),
            "year_period_end": E.date_end("YEAR", "day"),
        }
    )

    # Add half-year period information (custom halves or calendar halves)
    df = _add_half_year_metadata(df, config, spark)

    # Day-number-within-period counters, all 1-based so they agree with days_in_month
    # (dayofmonth), which is 1 on the first day of the month.
    df = df.withColumns(
        {
            "days_in_week": _days_elapsed("week_period_start"),
            "days_in_month": F.dayofmonth(F.col("day")),
            "days_in_quarter": _days_elapsed("quarter_period_start"),
            "days_in_year": _days_elapsed("year_period_start"),
            "days_in_half_year": _days_elapsed("half_year_period_start"),
        }
    )

    # Add holiday metadata
    df = _add_holiday_metadata(df, config, spark)

    # Add fiscal calendar specific columns
    if calendar_id != "standard":
        year = F.year(F.col("day"))
        df = df.withColumns(
            {
                # Gregorian rule: divisible by 4, except centuries not divisible by 400.
                "leap_year": E.all(
                    year % 4 == 0,
                    E.any(year % 100 != 0, year % 400 == 0),
                ),
                "leap_day": E.all(
                    F.month(F.col("day")) == 2, F.dayofmonth(F.col("day")) == 29
                ),
            }
        )

    # Add QA columns if requested
    if qa:
        df = df.withColumns(
            {
                "is_fiscal_period": F.month(F.col("day")) >= fiscal_year_start_month,
                "fiscal_year_offset": F.lit(fiscal_year_start_month - 1),
                "period_type": F.when(
                    F.col("calendar_type") == "standard", "standard"
                ).otherwise("fiscal"),
            }
        )

    return df.orderBy("day")