def load_partitions(
    context: dg.AssetExecutionContext, asset_key: dg.AssetKey, partitions: set[str]
) -> Generator[pl.DataFrame, None, None]:
    """
    Lazily load an asset's value for each of the given partitions.

    This is a generator: nothing is loaded until it is iterated, and one
    partition is loaded per iteration step. The asset value loader is used
    as a context manager so that it is closed once the generator is
    exhausted or closed, instead of being left open.

    Args:
        context: The Dagster execution context; its ``instance`` is handed
            to the asset value loader.
        asset_key: The key of the asset to load data from.
        partitions: A set of partition keys to load data for.

    Yields:
        The loaded value for each partition key, in the iteration order of
        ``partitions``. Values are annotated as ``pl.DataFrame``; the actual
        type is whatever the loader returns for the asset.
    """
    # Function-scope import, kept from the original.
    # NOTE(review): presumably this avoids a circular import with the
    # top-level `definitions` module -- confirm before hoisting it.
    from definitions import definitions

    # Entering the loader as a context manager ensures it is closed when the
    # generator finishes (or is closed early by the consumer).
    with definitions.get_asset_value_loader(instance=context.instance) as loader:
        for partition_key in partitions:
            yield loader.load_asset_value(
                asset_key=asset_key, partition_key=partition_key
            )