hyperion-sdk 0.2.0.dev1741815359__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,623 @@
1
+ """The data catalog."""
2
+
3
+ import asyncio
4
+ import datetime
5
+ import re
6
+ import tempfile
7
+ from collections.abc import Iterable, Iterator
8
+ from contextlib import ExitStack, contextmanager
9
+ from pathlib import Path
10
+ from typing import IO, Any, BinaryIO, ClassVar, Generic, TypeAlias, TypedDict, TypeVar, cast
11
+ from uuid import uuid4
12
+
13
+ import botocore.exceptions
14
+ import fastavro
15
+ import fastavro.validation
16
+ import fastavro.write
17
+
18
+ from hyperion.asyncutils import AsyncTaskQueue, aiter_any
19
+ from hyperion.catalog.schema import SchemaStore
20
+ from hyperion.config import storage_config
21
+ from hyperion.dateutils import (
22
+ TimeResolution,
23
+ TimeResolutionUnit,
24
+ assure_timezone,
25
+ iter_dates_between,
26
+ quantize_datetime,
27
+ truncate_datetime,
28
+ utcnow,
29
+ )
30
+ from hyperion.entities.catalog import (
31
+ AssetProtocol,
32
+ DataLakeAsset,
33
+ FeatureAsset,
34
+ PersistentStoreAsset,
35
+ get_prefixed_path,
36
+ )
37
+ from hyperion.infrastructure.aws import S3Client
38
+ from hyperion.infrastructure.queue import ArrivalEvent, DataLakeArrivalMessage, Queue
39
+ from hyperion.logging import get_logger
40
+
41
# Names exported by ``from ... import *``.
__all__ = ["AssetNotFoundError", "Catalog", "CatalogError"]

# Module-level logger shared by everything in this file.
logger = get_logger("catalog")

# Asset kinds accepted by ``Catalog.repartition`` / ``AssetRepartitioner``.
RepartitionableAssetType: TypeAlias = FeatureAsset | DataLakeAsset
# Type variable bound to the alias above; used to parametrize ``AssetRepartitioner``.
RepartitionableAsset = TypeVar("RepartitionableAsset", bound=RepartitionableAssetType)
47
+
48
+
49
class StoreBucketConfig(TypedDict):
    """Configuration for storing assets in a bucket."""

    # Name of the bucket the assets of one kind are stored in.
    bucket: str
    # Prefix handed to the asset's ``get_path`` when object keys are built.
    prefix: str
54
+
55
+
56
class CatalogError(Exception):
    """Base class for catalog errors."""


class AssetNotFoundError(CatalogError):
    """Raised when an asset is not found in the catalog.

    Attributes:
        asset: The asset that could not be found.
    """

    def __init__(self, asset: "AssetProtocol") -> None:
        """Build the error for a missing asset.

        Args:
            asset (AssetProtocol): The asset that was looked up; its ``name`` is used in the message.
        """
        super().__init__(f"Asset {asset.name!r} not found in the catalog.")
        # Keep the asset itself so handlers do not have to parse the message.
        self.asset = asset
65
+
66
+
67
def _unpack_args(*args: Any, **kwargs: Any) -> tuple[Any, ...]:
    """Flatten call arguments into a single tuple.

    Positional arguments come first, followed by the keyword argument values ordered by
    keyword name, so the result does not depend on the order keywords were passed in.

    Args:
        *args (Any): Positional arguments.
        **kwargs (Any): Keyword arguments.

    Returns:
        tuple[Any, ...]: The flattened arguments.
    """
    return (*args, *(kwargs[name] for name in sorted(kwargs)))


class PersistentStore:
    """A persistent store for assets.

    Instances are cached per class and per constructor arguments: constructing the store
    again with the same arguments returns the same object, including the local copy it has
    already retrieved. The arguments are used as a dictionary key, so they must be hashable.
    """

    # TODO: Unfinished business
    # https://github.com/Zephyr-Trade/FVE-map/issues/9
    # Cache of created instances, keyed by (class, *constructor arguments).
    _instances: ClassVar[dict[tuple[Any, ...], "PersistentStore"]] = {}

    def __new__(cls, *args: Any, **kwargs: Any) -> "PersistentStore":
        """Return the cached instance for these arguments, creating it on first use."""
        # The class is part of the key because subclasses share this dictionary; without it
        # a subclass call could be answered with an instance of a different class.
        cache_key = (cls, *_unpack_args(*args, **kwargs))
        instance = cls._instances.get(cache_key)
        if instance is None:
            instance = super().__new__(cls)
            cls._instances[cache_key] = instance
        return instance

    def __init__(
        self, asset: "PersistentStoreAsset", persistent_store_bucket: str, persistent_store_prefix: str = ""
    ) -> None:
        """Initialize the persistent store.

        Args:
            asset (PersistentStoreAsset): The asset to store.
            persistent_store_bucket (str): The bucket to store the asset in.
            persistent_store_prefix (str): The prefix for the asset in the bucket.
        """
        self.asset = asset
        self.persistent_store_bucket = persistent_store_bucket
        self.persistent_store_prefix = persistent_store_prefix
        # __init__ runs on every construction, even when __new__ handed back a cached
        # instance, so the download state must survive instead of being reset to None.
        self._local_path: Path | None = getattr(self, "_local_path", None)
        self._etag: str | None = getattr(self, "_etag", None)

    def cleanup(self) -> None:
        """Clean up the persistent store.

        This method deletes the local file if it exists.
        """
        if self._local_path is None:
            logger.debug("Persistent store was not retrieved, nothing to clean up.", asset=self.asset)
            return
        logger.info(
            "Cleaning up previously retrieved persistent store.", asset=self.asset, path=self._local_path.as_posix()
        )
        self._local_path.unlink(missing_ok=True)
        self._local_path = None

    def retrieve(self) -> None:
        """Retrieve the persistent store from the S3 bucket.

        The remote ETag is compared with the one remembered from the previous download;
        the file is downloaded again only when there is no local copy or the ETags differ.

        Raises:
            AssetNotFoundError: If the object attributes lookup fails with ``NoSuchKey``.
            botocore.exceptions.ClientError: For any other client error.
        """
        s3_client = S3Client()
        remote_path = self.asset.get_path(self.persistent_store_prefix)
        try:
            remote_etag = s3_client.get_object_attributes(self.persistent_store_bucket, remote_path).etag
        except botocore.exceptions.ClientError as error:
            if error.response["Error"]["Code"] == "NoSuchKey":
                raise AssetNotFoundError(self.asset) from error
            raise
        if self._local_path is not None:
            if remote_etag == self._etag:
                logger.info(
                    "Persistent store previously retrieved.",
                    asset=self.asset,
                    path=self._local_path.as_posix(),
                    etag=remote_etag,
                )
                return
            logger.info(
                "Forcing re-download of an outdated previously retrieved store.",
                asset=self.asset,
                local_etag=self._etag,
                remote_etag=remote_etag,
            )
            self.cleanup()

        local_path = Path(tempfile.gettempdir()) / f"{uuid4().hex}.asset"
        logger.info("Retrieving persistent store.", asset=self.asset, path=local_path.as_posix())
        s3_client.download(self.persistent_store_bucket, remote_path, local_path)
        # Remember the download only after it succeeded.
        self._local_path = local_path
        self._etag = remote_etag

    def __enter__(self) -> "PersistentStore":
        """Retrieve the store and return it so ``with store as s`` is usable."""
        self.retrieve()
        return self

    def __exit__(self, *args: Any) -> None:
        """Delete the local copy when leaving the ``with`` block."""
        self.cleanup()
156
+
157
+
158
class WritablePersistentStore(PersistentStore):
    """A writable persistent store for assets."""

    # TODO: Unfinished business
    # https://github.com/Zephyr-Trade/FVE-map/issues/9
    def store(self, data: Iterable[dict[str, Any]]) -> None:
        """Store data in the persistent store.

        The records are written to a temporary Avro file using the asset's schema and the
        file is then uploaded to the persistent store bucket.

        Args:
            data (Iterable[dict[str, Any]]): The data to store.
        """
        with tempfile.TemporaryFile("+wb") as file:
            logger.info("Pouring persistent store asset into temporary file.", asset=self.asset, path=file.name)
            schema = SchemaStore.from_config().get_asset_schema(self.asset)
            _write_avro(file, schema, data, self.asset.to_metadata())
            # Writing leaves the position at the end of the file; rewind so the upload
            # starts from the first byte, as Catalog._prepare_asset_storage does.
            file.seek(0)
            s3_client = S3Client()
            s3_client.upload(file, self.persistent_store_bucket, self.asset.get_path(self.persistent_store_prefix))
175
+
176
+
177
def _write_avro(
    fp: BinaryIO | IO[bytes], schema: dict[str, Any], data: Iterable[dict[str, Any]], metadata: dict[str, str]
) -> None:
    """Write records into ``fp`` as an Avro file via ``fastavro.writer``.

    Args:
        fp (BinaryIO | IO[bytes]): Binary file object to write into.
        schema (dict[str, Any]): Avro schema passed to the writer.
        data (Iterable[dict[str, Any]]): Records to write.
        metadata (dict[str, str]): Metadata passed to the writer.
    """
    # Writer settings shared by every asset written through this helper.
    writer_options: dict[str, Any] = {
        "codec": "deflate",
        "validator": True,
        "codec_compression_level": 7,
        "strict": False,
        "strict_allow_default": True,
        "metadata": metadata,
    }
    fastavro.writer(fp, records=data, schema=schema, **writer_options)
191
+
192
+
193
class Catalog:
    """The data catalog.

    The catalog is responsible for storing and retrieving assets.
    """

    def __init__(
        self,
        *,
        data_lake_bucket: str,
        feature_store_bucket: str,
        persistent_store_bucket: str,
        data_lake_prefix: str = "",
        feature_store_prefix: str = "",
        persistent_store_prefix: str = "",
        queue: Queue | None = None,
    ) -> None:
        """Initialize the catalog.

        Args:
            data_lake_bucket (str): The bucket for data lake assets.
            feature_store_bucket (str): The bucket for feature store assets.
            persistent_store_bucket (str): The bucket for persistent store assets.
            data_lake_prefix (str): The prefix for data lake assets.
            feature_store_prefix (str): The prefix for feature store assets.
            persistent_store_prefix (str): The prefix for persistent store assets.
            queue (Queue, optional): The queue to use for notifications. Defaults to None,
                in which case the queue is created with ``Queue.from_config()``.
        """
        self.data_lake_bucket = data_lake_bucket
        self.feature_store_bucket = feature_store_bucket
        self.persistent_store_bucket = persistent_store_bucket
        self.data_lake_prefix = data_lake_prefix
        self.feature_store_prefix = feature_store_prefix
        self.persistent_store_prefix = persistent_store_prefix
        # Maps each supported asset class to the bucket and prefix it lives under.
        self._config_map: dict[type[AssetProtocol], StoreBucketConfig] = {
            DataLakeAsset: {"bucket": self.data_lake_bucket, "prefix": self.data_lake_prefix},
            PersistentStoreAsset: {"bucket": self.persistent_store_bucket, "prefix": self.persistent_store_prefix},
            FeatureAsset: {"bucket": self.feature_store_bucket, "prefix": self.feature_store_prefix},
        }
        # Created lazily by the ``s3_client`` property.
        self._s3_client: S3Client | None = None
        # Only fall back to the configured queue when none was given; a truthiness test
        # would also discard a caller-supplied queue that happens to evaluate as false.
        self.queue = queue if queue is not None else Queue.from_config()

    @property
    def s3_client(self) -> S3Client:
        """Get the S3 client, creating it on first access."""
        if self._s3_client is None:
            self._s3_client = S3Client()
        return self._s3_client

    @staticmethod
    def from_config() -> "Catalog":
        """Create a catalog from the configuration."""
        return Catalog(
            data_lake_bucket=storage_config.data_lake_bucket,
            feature_store_bucket=storage_config.feature_store_bucket,
            persistent_store_bucket=storage_config.persistent_store_bucket,
            data_lake_prefix=storage_config.data_lake_prefix,
            # The feature store has its own prefix; the data lake prefix must not be reused here.
            feature_store_prefix=storage_config.feature_store_prefix,
            persistent_store_prefix=storage_config.persistent_store_prefix,
        )

    @contextmanager
    def _prepare_asset_storage(self, asset: AssetProtocol, data: Iterable[dict[str, Any]]) -> Iterator[IO[bytes]]:
        """Write the asset data into a temporary Avro file and yield it rewound to the start.

        The temporary file is removed when the context exits.
        """
        with tempfile.NamedTemporaryFile("+wb") as file:
            schema = SchemaStore.from_config().get_asset_schema(asset)
            path = Path(file.name)
            logger.info("Pouring asset into a temporary file.", asset=asset, file=path.as_posix())
            _write_avro(file, schema, data, asset.to_metadata())
            logger.info("Avro file was created successfully.", path=path.as_posix(), size=file.tell())
            # Rewind so the consumer reads the file from its first byte.
            file.seek(0)
            yield file

    async def store_asset_async(
        self, asset: AssetProtocol, data: Iterable[dict[str, Any]], notify: bool = True
    ) -> None:
        """Store an asset in its bucket asynchronously.

        Args:
            asset (AssetProtocol): The asset to store.
            data (Iterable[dict[str, Any]]): The data to store.
            notify (bool, optional): Whether to send a notification. Defaults to True.
        """
        store_config = self.get_store_config(asset)
        logger.info("Preparing asset storage.", asset=asset, **store_config)
        with ExitStack() as stack:
            # Calling the context manager factory only creates the manager; the Avro file is
            # written when the context is entered, so that is the step run in a worker thread.
            file = await asyncio.to_thread(stack.enter_context, self._prepare_asset_storage(asset, data))
            await self.s3_client.upload_async(
                file, bucket=store_config["bucket"], name=asset.get_path(store_config["prefix"])
            )
        if notify:
            self._notify_asset_arrival(asset)

    def _notify_asset_arrival(self, asset: AssetProtocol) -> None:
        """Send an arrival message to the queue; only data lake assets are announced."""
        if not isinstance(asset, DataLakeAsset):
            logger.debug("Skipping notification for asset, not DataLakeAsset type.", asset=asset)
            return
        message = DataLakeArrivalMessage(asset=asset, event=ArrivalEvent.ARRIVED)
        logger.info("Sending data lake arrival message.", asset=asset, message=message, queue=self.queue)
        self.queue.send(message)

    def get_feature_data(
        self,
        name: str,
        resolution: TimeResolution | str,
        the_now: datetime.datetime | None = None,
        tolerance: int = 0,
    ) -> Iterator[dict[str, Any]]:
        """Get the data of the most recent available feature partition.

        Starting at ``the_now`` and stepping back by one ``resolution.delta`` at a time, up to
        ``tolerance`` extra steps, the first feature partition that exists is retrieved.

        Args:
            name (str): The name of the feature.
            resolution (TimeResolution | str): The time resolution of the feature.
            the_now (datetime.datetime, optional): The reference time. Defaults to None,
                in which case ``utcnow()`` is used.
            tolerance (int, optional): How many older partitions may be tried. Defaults to 0.

        Returns:
            Iterator[dict[str, Any]]: The rows of the first partition found.

        Raises:
            ValueError: If ``tolerance`` is negative.
            AssetNotFoundError: If none of the tried partitions exists.
        """
        if tolerance < 0:
            # Without this guard no partition would be tried and the final raise below
            # would fail on an unbound variable instead of reporting the real problem.
            raise ValueError(f"Tolerance must not be negative, got {tolerance!r}.")
        resolution = resolution if isinstance(resolution, TimeResolution) else TimeResolution.from_str(resolution)
        the_now = the_now or utcnow()
        try_timestamps = [the_now - resolution.delta * i for i in range(0, tolerance + 1)]
        for timestamp in try_timestamps:
            feature_partition_date = quantize_datetime(timestamp, resolution)
            feature_asset = FeatureAsset(name, feature_partition_date, resolution)
            logger.debug(
                f"Trying to find feature data for feature {feature_asset.feature_name!r}.", feature_asset=feature_asset
            )
            try:
                # ``retrieve_asset`` is a generator and would only fail once iterated,
                # so existence is checked eagerly here.
                self.get_asset_file_size(feature_asset)
            except AssetNotFoundError as error:
                logger.error(f"Feature asset was not found - {error}.", feature_asset=feature_asset)
                continue
            return self.retrieve_asset(feature_asset)
        raise AssetNotFoundError(feature_asset)

    def store_asset(self, asset: AssetProtocol, data: Iterable[dict[str, Any]], notify: bool = True) -> None:
        """Store an asset in its bucket.

        Args:
            asset (AssetProtocol): The asset to store.
            data (Iterable[dict[str, Any]]): The data to store.
            notify (bool, optional): Whether to send a notification. Defaults to True.
        """
        store_config = self.get_store_config(asset)
        logger.info("Preparing asset storage.", asset=asset, **store_config)
        with self._prepare_asset_storage(asset, data) as file:
            self.s3_client.upload(file, bucket=store_config["bucket"], name=asset.get_path(store_config["prefix"]))
        if notify:
            self._notify_asset_arrival(asset)

    def get_store_config(self, asset: AssetProtocol | type[AssetProtocol]) -> StoreBucketConfig:
        """Get the bucket and prefix configured for an asset or an asset class.

        Args:
            asset (AssetProtocol | type[AssetProtocol]): The asset instance or asset class.

        Returns:
            StoreBucketConfig: The bucket and prefix for the asset's class.

        Raises:
            KeyError: If the asset's class is not one of the configured asset classes.
        """
        config_key = asset if isinstance(asset, type) else asset.__class__
        try:
            return self._config_map[config_key]
        except KeyError:
            logger.error("Attempting to get store config for an unsupported asset type.", asset=asset)
            raise

    def get_asset_file_size(self, asset: AssetProtocol) -> int:
        """Find asset avro file and get its file size in bytes.

        Args:
            asset (AssetProtocol): The asset to get the file size for.

        Returns:
            int: The file size in bytes.

        Raises:
            AssetNotFoundError: If the lookup fails with the ``NoSuchKey`` error code.
            botocore.exceptions.ClientError: For any other client error.
        """
        store_config = self.get_store_config(asset)

        object_path = asset.get_path(store_config["prefix"])
        logger.info("Getting attributes for an asset.", asset=asset, **store_config)
        try:
            return self.s3_client.get_object_attributes(store_config["bucket"], object_path).object_size
        except botocore.exceptions.ClientError as error:
            if error.response["Error"]["Code"] == "NoSuchKey":
                raise AssetNotFoundError(asset) from error
            raise

    def retrieve_asset(self, asset: AssetProtocol) -> Iterator[dict[str, Any]]:
        """Retrieve an asset based on its type and store config.

        This is a generator: nothing is looked up or downloaded until it is iterated, and
        the temporary file holding the download lives until the iteration ends.

        Args:
            asset (AssetProtocol): The asset to retrieve.

        Yields:
            dict[str, Any]: The asset data.

        Raises:
            AssetNotFoundError: If the asset does not exist.
            TypeError: If a row read from the file is not a dictionary.
        """
        store_config = self.get_store_config(asset)
        file_size = self.get_asset_file_size(asset)
        logger.info("Preparing asset for retrieval.", asset=asset, file_size=file_size, **store_config)
        with tempfile.NamedTemporaryFile("+wb") as file:
            logger.info("Downloading asset into a temporary file.", asset=asset, path=file.name)
            self.s3_client.download(store_config["bucket"], asset.get_path(store_config["prefix"]), file)
            file.seek(0)
            for row_number, row in enumerate(fastavro.reader(file), start=1):
                if isinstance(row, dict):
                    yield row
                else:
                    logger.error(
                        "Unexpected data found in data lake asset row.",
                        asset=asset,
                        expected="dict",
                        row_number=row_number,
                        got=str(type(row)),
                    )
                    raise TypeError("Unexpected data received when reading asset data.")

    def iter_datalake_partitions(self, asset_name: str, date_part: str | None = None) -> Iterator[DataLakeAsset]:
        """Iterate over data lake partitions.

        Providing the date part can significantly reduce the number of keys to iterate over.
        Partition dates are stored as ISO formatted strings, therefore the date part should be in the same format.
        E.g. to only iterate over partitions for January 2025, provide '2025-01'.

        Keys that do not look like a partition of the asset are logged and skipped.

        Args:
            asset_name (str): The name of the asset.
            date_part (str, optional): The date part to filter by. Defaults to None.

        Yields:
            DataLakeAsset: The data lake asset.
        """
        store_config = self.get_store_config(DataLakeAsset)
        keys_prefix = get_prefixed_path(f"{asset_name}/date={date_part if date_part else ''}", store_config["prefix"])
        version_patt = re.compile(r"v(?P<version>\d+)\.avro")
        for key in self.s3_client.iter_objects(store_config["bucket"], keys_prefix):
            # NOTE(review): exactly three path components are expected; if the listed keys
            # include a non-empty store prefix, every key would be skipped here - confirm
            # what ``iter_objects`` returns.
            try:
                key_asset_name, partition, filename = key.split("/")
            except ValueError:
                logger.warning(
                    "The key path does not have 'Asset/Partition/Version' format and will be skipped.", key=key
                )
                continue
            if key_asset_name != asset_name:
                logger.warning(
                    "The key path does not match the asset name and will be skipped.", key=key, asset_name=asset_name
                )
                continue
            if (match := version_patt.match(filename)) is None:
                logger.warning("The key path does not match the version pattern and will be skipped.", key=key)
                continue
            version = int(match.group("version"))
            try:
                partition_date = datetime.datetime.fromisoformat(partition.split("date=")[1])
            except (IndexError, ValueError):
                # A single malformed key must not abort the whole listing; treat it like
                # the other unrecognized keys above.
                logger.warning("The key path does not contain a valid partition date and will be skipped.", key=key)
                continue
            logger.debug(
                "Found data lake partition.", asset_name=asset_name, partition_date=partition_date, version=version
            )
            yield DataLakeAsset(asset_name, assure_timezone(partition_date), version)

    def find_latest_datalake_partition(self, asset_name: str, date_part: str | None = None) -> DataLakeAsset:
        """Find the latest data lake partition.

        Providing the date part can significantly reduce the number of keys to iterate over.
        Partition dates are stored as ISO formatted strings, therefore the date part should be in the same format.
        E.g. to only iterate over partitions for January 2025, provide '2025-01'.

        Args:
            asset_name (str): The name of the asset.
            date_part (str, optional): The date part to filter by. Defaults to None.

        Returns:
            DataLakeAsset: The latest data lake partition.

        Raises:
            StopIteration: If no partition is found.
        """
        return next(
            iter(
                sorted(
                    self.iter_datalake_partitions(asset_name, date_part),
                    key=lambda partition: partition.date,
                    reverse=True,
                )
            )
        )

    def iter_feature_store_partitions(
        self,
        feature_name: str,
        resolution: TimeResolution | str,
        date_from: datetime.datetime,
        date_to: datetime.datetime,
        version: int = 1,
    ) -> Iterator[FeatureAsset]:
        """Iterate over feature store partitions relevant for a given time range.

        For a given time range, finds all feature store partitions that could contain
        data for that range based on the feature's resolution. For example, for dates
        between 2025-01-01 and 2025-01-15 with 7d resolution, this would check partitions
        2025-01-08, 2025-01-15, and 2025-01-22, since data points from those dates would
        be stored in these quantized partitions.

        Args:
            feature_name (str): The name of the feature.
            resolution (TimeResolution | str): The time resolution of the feature.
            date_from (datetime.datetime): Start of the time range to find partitions for.
            date_to (datetime.datetime): End of the time range to find partitions for.
            version (int): Schema version of the feature. Defaults to 1.

        Yields:
            FeatureAsset: Feature store assets that could contain data for the time range.
        """
        resolution = resolution if isinstance(resolution, TimeResolution) else TimeResolution.from_str(resolution)

        dates = iter_dates_between(date_from, date_to, resolution.unit)

        # A set removes duplicates when several dates quantize to the same partition.
        partition_dates = {quantize_datetime(date, resolution) for date in dates}

        for partition_date in sorted(partition_dates):
            feature_asset = FeatureAsset(feature_name, partition_date, resolution, schema_version=version)

            try:
                self.get_asset_file_size(feature_asset)
            except AssetNotFoundError:
                logger.debug("No feature store partition found for timestamp.", asset=feature_asset)
                continue

            logger.debug("Found partition for the feature.", asset=feature_asset)
            yield feature_asset

    async def repartition(
        self,
        asset: DataLakeAsset | FeatureAsset,
        granularity: TimeResolutionUnit,
        date_attribute: str = "timestamp",
        data: Iterable[dict[str, Any]] | None = None,
    ) -> None:
        """Repartition a data lake or feature asset based on a time resolution unit.

        If data is not provided, the asset is retrieved from the catalog.

        Args:
            asset (DataLakeAsset | FeatureAsset): The asset to repartition.
            granularity (TimeResolutionUnit): The time resolution unit to use.
            date_attribute (str, optional): The date attribute to use. Defaults to "timestamp".
            data (Iterable[dict[str, Any]], optional): The data to repartition. Defaults to None.
        """
        repartitioner = AssetRepartitioner(self, asset, granularity, date_attribute)
        await repartitioner.repartition(data)
517
+
518
+
519
class AssetRepartitioner(Generic[RepartitionableAsset]):
    """A class to repartition a data lake or feature asset based on a time resolution unit."""

    def __init__(
        self,
        catalog: Catalog,
        asset: RepartitionableAsset,
        granularity: TimeResolutionUnit,
        date_attribute: str = "timestamp",
    ) -> None:
        """Initialize the repartitioner.

        Args:
            catalog (Catalog): The catalog to use.
            asset (DataLakeAsset | FeatureAsset): The asset to repartition.
            granularity (TimeResolutionUnit): The time resolution unit to use.
            date_attribute (str, optional): The date attribute to use. Defaults to "timestamp".
        """
        self.catalog = catalog
        self.asset = asset
        self.granularity = granularity
        self.date_attribute = date_attribute
        # Not read anywhere in this class; kept for code outside of it that may rely on it.
        self._partition_name = "date" if isinstance(self.asset, DataLakeAsset) else "timestamp"

        # One (temporary file, partition asset, avro writer) handler per partition date.
        self._state: dict[datetime.datetime, tuple[IO[bytes], RepartitionableAsset, fastavro.write.Writer]] = {}

    def __enter__(self) -> None:
        """Start a repartitioning run with no open handlers."""
        self._state = {}

    def __exit__(self, *args: Any) -> None:
        """Close every temporary file opened during the run."""
        for file, _, __ in self._state.values():
            logger.info("Closing temporary file.", path=file.name)
            file.close()

    def _create_partition_asset(self, partition_date: datetime.datetime) -> RepartitionableAsset:
        """Create an asset of the same kind as the source asset for the given partition date.

        Raises:
            TypeError: If the source asset is neither a DataLakeAsset nor a FeatureAsset.
        """
        if isinstance(self.asset, DataLakeAsset):
            return cast(RepartitionableAsset, DataLakeAsset(self.asset.name, partition_date, self.asset.schema_version))
        if isinstance(self.asset, FeatureAsset):
            return cast(
                RepartitionableAsset,
                FeatureAsset(
                    self.asset.name,
                    partition_date,
                    self.asset.resolution,
                    self.asset.schema_version,
                    self.asset.partition_keys,
                ),
            )
        raise TypeError(f"Unsupported asset type {type(self.asset)!r}.")

    def _get_handler(
        self, partition_date: datetime.datetime
    ) -> tuple[IO[bytes], RepartitionableAsset, fastavro.write.Writer]:
        """Get the handler for a partition date, opening a new temporary file on first use."""
        if partition_date in self._state:
            return self._state[partition_date]
        logger.info("Creating a new handler for partition date.", partition_date=partition_date.isoformat())
        # Deliberately not a ``with`` block: the file stays open until ``__exit__`` closes it.
        file = tempfile.NamedTemporaryFile("+wb")  # noqa: SIM115
        logger.info(f"Partition will be temporarily stored in {file.name!r}.", path=file.name)
        asset = self._create_partition_asset(partition_date)
        writer = fastavro.write.Writer(
            file,
            schema=SchemaStore.from_config().get_asset_schema(asset),
            codec="deflate",
            validator=True,
            metadata=asset.to_metadata(),
        )
        handler = (file, asset, writer)
        self._state[partition_date] = handler
        return handler

    async def repartition(self, data: Iterable[dict[str, Any]] | None = None) -> None:
        """Repartition the asset.

        This method reads the asset data, partitions it based on the date attribute and granularity,
        and uploads the partitioned data to the bucket configured for the asset.
        If data is not provided, the asset is retrieved from the catalog.

        Args:
            data (Iterable[dict[str, Any]], optional): The data to repartition. Defaults to None.

        Raises:
            ValueError: If a record's date attribute is not a ``datetime.datetime``.
        """
        # Only retrieve the asset when no data was passed at all; an explicitly passed
        # empty iterable means there is nothing to repartition.
        if data is None:
            data = self.catalog.retrieve_asset(self.asset)
        with self:
            for record in data:
                timestamp = record.get(self.date_attribute)
                if not isinstance(timestamp, datetime.datetime):
                    raise ValueError(
                        f"Asset {self.asset!r} cannot be repartitioned using date attribute "
                        f"{self.date_attribute!r} - it is not a valid datetime."
                    )
                partition_date = truncate_datetime(timestamp, self.granularity)
                _, __, writer = self._get_handler(partition_date)
                writer.write(record)
            store_config = self.catalog.get_store_config(self.asset)
            logger.info("Finished creating partitioned avro files.")

            # The uploads stay inside ``with self`` because leaving it closes the files.
            async with AsyncTaskQueue[None](maxsize=5) as queue:
                async for file, asset, writer in aiter_any(self._state.values()):
                    logger.info("Dumping and uploading asset from temporary file.", asset=asset, path=file.name)
                    writer.dump()
                    file.seek(0)
                    await queue.add_task(
                        self.catalog.s3_client.upload_async(
                            file, bucket=store_config["bucket"], name=asset.get_path(store_config["prefix"])
                        )
                    )