ingestify 0.3.3__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
ingestify/exceptions.py CHANGED
@@ -8,3 +8,7 @@ class ConfigurationError(IngestifyError):
 
 class DuplicateFile(IngestifyError):
     pass
+
+
+class SaveError(IngestifyError):
+    pass
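
0.4.0 adds a SaveError exception alongside DuplicateFile. A minimal usage sketch; the wrapper function below is illustrative and not part of the package:

from ingestify.exceptions import SaveError

def save_or_raise(repository, bucket, dataset):
    # Illustrative wrapper: translate any low-level persistence failure
    # into the new SaveError so callers only need to catch one type.
    try:
        repository.save(bucket, dataset)
    except SaveError:
        raise
    except Exception as exc:
        raise SaveError(f"failed to save dataset {dataset.dataset_id}") from exc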
@@ -1,50 +1,22 @@
-import json
-from datetime import datetime
-from typing import Type, Any, TypeVar
+from ingestify.domain import DatasetCreated
+from ingestify.domain.models.dataset.events import RevisionAdded
+from ingestify.domain.models.event import DomainEvent
 
-from dataclass_factory import Schema, Factory, NameStyle
-from dataclass_factory.schema_helpers import type_checker
 
-from ingestify.domain import DatasetCreated, Identifier
-from ingestify.domain.models.dataset.events import MetadataUpdated, RevisionAdded
+event_types = {
+    DatasetCreated.event_type: DatasetCreated,
+    RevisionAdded.event_type: RevisionAdded,
+}
 
-isotime_schema = Schema(
-    parser=lambda x: datetime.fromisoformat(x.replace("Z", "+00:00")),  # type: ignore
-    serializer=lambda x: datetime.isoformat(x).replace("+00:00", "Z"),
-)
 
-identifier_schema = Schema(
-    # json.loads(x) for backwards compatibility
-    parser=lambda x: Identifier(x if isinstance(x, dict) else json.loads(x)),
-    serializer=lambda x: dict(x),
-)
+def deserialize(event_dict: dict) -> DomainEvent:
+    event_cls = event_types[event_dict["event_type"]]
+    return event_cls.model_validate(event_dict)
 
-factory = Factory(
-    schemas={
-        datetime: isotime_schema,
-        Identifier: identifier_schema,
-        DatasetCreated: Schema(
-            pre_parse=type_checker(DatasetCreated.event_type, "event_type")
-        ),
-        MetadataUpdated: Schema(
-            pre_parse=type_checker(MetadataUpdated.event_type, "event_type")
-        ),
-        RevisionAdded: Schema(
-            pre_parse=type_checker(RevisionAdded.event_type, "event_type")
-        ),
-        # ClipSelectionContent: Schema(pre_parse=type_checker(ClipSelectionContent.content_type, field="contentType")),
-        # TeamInfoImageContent: Schema(pre_parse=type_checker(TeamInfoImageContent.content_type, field="contentType")),
-        # StaticVideoContent: Schema(pre_parse=type_checker(StaticVideoContent.content_type, field="contentType"))
-    },
-    default_schema=Schema(),
-)
 
-T = TypeVar("T")
+def serialize(event: DomainEvent) -> dict:
+    event_dict = event.model_dump(mode="json")
 
-
-def serialize(data: T, class_: Type[T] = None) -> Any:
-    return factory.dump(data, class_)
-
-
-def unserialize(data: Any, class_: Type[T]) -> T:
-    return factory.load(data, class_)
+    # Make sure event_type is always part of the event_dict. Pydantic might skip it when the type is ClassVar
+    event_dict["event_type"] = event.event_type
+    return event_dict
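
The serialization module drops dataclass_factory in favour of Pydantic's model_dump/model_validate, with a small event_types registry used to dispatch on "event_type". A standalone sketch of the same pattern with a toy event (ToyCreated and its field are illustrative, not ingestify classes), showing why event_type has to be re-added by hand when it is declared as a ClassVar:

from typing import ClassVar
from pydantic import BaseModel

class ToyCreated(BaseModel):
    event_type: ClassVar[str] = "toy_created"  # ClassVar: not a model field
    dataset_id: str

event_types = {ToyCreated.event_type: ToyCreated}

def serialize(event) -> dict:
    d = event.model_dump(mode="json")
    d["event_type"] = event.event_type  # model_dump skips ClassVar attributes
    return d

def deserialize(d: dict):
    # Look up the concrete class by its event_type, then validate the payload.
    return event_types[d["event_type"]].model_validate(d)

restored = deserialize(serialize(ToyCreated(dataset_id="abc")))
assert isinstance(restored, ToyCreated)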
@@ -1,13 +1,24 @@
+import itertools
 import json
 import uuid
+from collections import defaultdict
 from typing import Optional, Union, List
 
-from sqlalchemy import create_engine, func, text, tuple_
+from sqlalchemy import (
+    create_engine,
+    func,
+    text,
+    tuple_,
+    Table,
+    insert,
+    Transaction,
+    Connection,
+)
 from sqlalchemy.engine import make_url
 from sqlalchemy.exc import NoSuchModuleError
 from sqlalchemy.orm import Session, joinedload
 
-from ingestify.domain import File
+from ingestify.domain import File, Revision
 from ingestify.domain.models import (
     Dataset,
     DatasetCollection,
@@ -15,11 +26,22 @@ from ingestify.domain.models import (
     Identifier,
     Selector,
 )
+from ingestify.domain.models.base import BaseModel
 from ingestify.domain.models.dataset.collection_metadata import (
     DatasetCollectionMetadata,
 )
-
-from .mapping import dataset_table, metadata
+from ingestify.domain.models.ingestion.ingestion_job_summary import IngestionJobSummary
+from ingestify.domain.models.task.task_summary import TaskSummary
+from ingestify.exceptions import IngestifyError
+
+from .tables import (
+    metadata,
+    dataset_table,
+    file_table,
+    revision_table,
+    ingestion_job_summary_table,
+    task_summary_table,
+)
 
 
 def parse_value(v):
@@ -113,6 +135,31 @@ class SqlAlchemyDatasetRepository(DatasetRepository):
     def session(self):
         return self.session_provider.get()
 
+    def _upsert(self, connection: Connection, table: Table, entities: list[dict]):
+        dialect = self.session.bind.dialect.name
+        if dialect == "mysql":
+            from sqlalchemy.dialects.mysql import insert
+        elif dialect == "postgresql":
+            from sqlalchemy.dialects.postgresql import insert
+        elif dialect == "sqlite":
+            from sqlalchemy.dialects.sqlite import insert
+        else:
+            raise IngestifyError(f"Don't know how to do an upsert in {dialect}")
+
+        stmt = insert(table).values(entities)
+
+        primary_key_columns = [column for column in table.columns if column.primary_key]
+
+        set_ = {
+            name: getattr(stmt.excluded, name)
+            for name, column in table.columns.items()
+            if column not in primary_key_columns
+        }
+
+        stmt = stmt.on_conflict_do_update(index_elements=primary_key_columns, set_=set_)
+
+        connection.execute(stmt)
+
     def _filter_query(
         self,
         query,
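
The new _upsert builds a dialect-specific INSERT and resolves conflicts on the primary key. For reference, on_conflict_do_update exists on the SQLite and PostgreSQL dialect inserts; the MySQL dialect insert exposes on_duplicate_key_update instead, so the mysql branch above would likely need that variant. A standalone sketch of the SQLite path against an in-memory database (the items table and its rows are illustrative):

from sqlalchemy import Column, Integer, MetaData, String, Table, create_engine
from sqlalchemy.dialects.sqlite import insert as sqlite_insert

metadata = MetaData()
items = Table(
    "items",
    metadata,
    Column("id", String, primary_key=True),
    Column("name", String),
    Column("version", Integer),
)

engine = create_engine("sqlite:///:memory:")
metadata.create_all(engine)

rows = [{"id": "a", "name": "first", "version": 1}]
stmt = sqlite_insert(items).values(rows)
# Non-primary-key columns are overwritten from the incoming row on conflict.
stmt = stmt.on_conflict_do_update(
    index_elements=[items.c.id],
    set_={"name": stmt.excluded.name, "version": stmt.excluded.version},
)
with engine.begin() as conn:
    conn.execute(stmt)  # first run inserts
    conn.execute(stmt)  # second run updates the existing row in place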
@@ -122,11 +169,11 @@ class SqlAlchemyDatasetRepository(DatasetRepository):
         dataset_id: Optional[Union[str, List[str]]] = None,
         selector: Optional[Union[Selector, List[Selector]]] = None,
     ):
-        query = query.filter(Dataset.bucket == bucket)
+        query = query.filter(dataset_table.c.bucket == bucket)
         if dataset_type:
-            query = query.filter(Dataset.dataset_type == dataset_type)
+            query = query.filter(dataset_table.c.dataset_type == dataset_type)
         if provider:
-            query = query.filter(Dataset.provider == provider)
+            query = query.filter(dataset_table.c.provider == provider)
         if dataset_id is not None:
             if isinstance(dataset_id, list):
                 if len(dataset_id) == 0:
@@ -134,9 +181,9 @@ class SqlAlchemyDatasetRepository(DatasetRepository):
                     # return an empty DatasetCollection
                    return DatasetCollection()
 
-                query = query.filter(Dataset.dataset_id.in_(dataset_id))
+                query = query.filter(dataset_table.c.dataset_id.in_(dataset_id))
             else:
-                query = query.filter(Dataset.dataset_id == dataset_id)
+                query = query.filter(dataset_table.c.dataset_id == dataset_id)
 
         dialect = self.session.bind.dialect.name
 
@@ -175,7 +222,7 @@ class SqlAlchemyDatasetRepository(DatasetRepository):
                 else:
                     column = column.as_string()
             else:
-                column = func.json_extract(Dataset.identifier, f"$.{k}")
+                column = func.json_extract(dataset_table.c.identifier, f"$.{k}")
             columns.append(column)
 
         values = []
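
These filters now target the Core dataset_table columns (table.c.<name>) rather than the ORM Dataset entity, and the queries yield plain Row objects. A small sketch of that filtering style (the toy table, engine URL, and values are illustrative, not the package's schema):

from sqlalchemy import Column, MetaData, String, Table, create_engine, select

metadata = MetaData()
dataset_table = Table(
    "dataset",
    metadata,
    Column("dataset_id", String, primary_key=True),
    Column("bucket", String),
    Column("provider", String),
)

engine = create_engine("sqlite:///:memory:")
metadata.create_all(engine)

# Core-style filtering: reference columns via table.c instead of ORM attributes.
stmt = select(dataset_table).where(
    dataset_table.c.bucket == "main",
    dataset_table.c.provider == "example_provider",
)
with engine.connect() as conn:
    rows = conn.execute(stmt).all()  # Row objects: row.dataset_id, row.bucket, ...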
@@ -189,6 +236,60 @@ class SqlAlchemyDatasetRepository(DatasetRepository):
             query = query.filter(text(where))
         return query
 
+    def load_datasets(self, dataset_ids: list[str]) -> list[Dataset]:
+        if not dataset_ids:
+            return []
+
+        dataset_rows = list(
+            self.session.query(dataset_table).filter(
+                dataset_table.c.dataset_id.in_(dataset_ids)
+            )
+        )
+        revisions_per_dataset = {}
+        rows = (
+            self.session.query(revision_table)
+            .filter(revision_table.c.dataset_id.in_(dataset_ids))
+            .order_by(revision_table.c.dataset_id)
+        )
+
+        for dataset_id, revisions in itertools.groupby(
+            rows, key=lambda row: row.dataset_id
+        ):
+            revisions_per_dataset[dataset_id] = list(revisions)
+
+        files_per_revision = {}
+        rows = (
+            self.session.query(file_table)
+            .filter(file_table.c.dataset_id.in_(dataset_ids))
+            .order_by(file_table.c.dataset_id, file_table.c.revision_id)
+        )
+
+        for (dataset_id, revision_id), files in itertools.groupby(
+            rows, key=lambda row: (row.dataset_id, row.revision_id)
+        ):
+            files_per_revision[(dataset_id, revision_id)] = list(files)
+
+        datasets = []
+        for dataset_row in dataset_rows:
+            dataset_id = dataset_row.dataset_id
+            revisions = []
+            for revision_row in revisions_per_dataset.get(dataset_id, []):
+                files = [
+                    File.model_validate(file_row)
+                    for file_row in files_per_revision.get(
+                        (dataset_id, revision_row.revision_id), []
+                    )
+                ]
+                revision = Revision.model_validate(
+                    {**revision_row._mapping, "modified_files": files}
+                )
+                revisions.append(revision)
+
+            datasets.append(
+                Dataset.model_validate({**dataset_row._mapping, "revisions": revisions})
+            )
+        return datasets
+
     def get_dataset_collection(
         self,
         bucket: str,
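
load_datasets reassembles nested Dataset objects from three flat queries using itertools.groupby, which only groups consecutive rows and therefore depends on the order_by on the grouping key. A tiny standalone illustration of that constraint (the row dicts are illustrative):

import itertools

rows = [
    {"dataset_id": "a", "revision_id": 1},
    {"dataset_id": "b", "revision_id": 1},
    {"dataset_id": "a", "revision_id": 2},
]

# Unsorted input: "a" shows up as two separate groups.
unsorted_keys = [key for key, _ in itertools.groupby(rows, key=lambda r: r["dataset_id"])]
assert unsorted_keys == ["a", "b", "a"]

# Sorting by the grouping key first (what the ORDER BY achieves) fixes it.
rows.sort(key=lambda r: r["dataset_id"])
grouped = {
    key: list(group)
    for key, group in itertools.groupby(rows, key=lambda r: r["dataset_id"])
}
assert sorted(grouped) == ["a", "b"]
assert len(grouped["a"]) == 2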
@@ -209,17 +310,19 @@ class SqlAlchemyDatasetRepository(DatasetRepository):
         )
 
         if not metadata_only:
-            dataset_query = apply_query_filter(self.session.query(Dataset))
-            datasets = list(dataset_query)
+            dataset_query = apply_query_filter(
+                self.session.query(dataset_table.c.dataset_id)
+            )
+            dataset_ids = [row.dataset_id for row in dataset_query]
+            datasets = self.load_datasets(dataset_ids)
         else:
             datasets = []
 
         metadata_result_row = apply_query_filter(
             self.session.query(
-                func.min(File.modified_at).label("first_modified_at"),
-                func.max(File.modified_at).label("last_modified_at"),
+                func.max(dataset_table.c.last_modified_at).label("last_modified_at"),
                 func.count().label("row_count"),
-            ).join(Dataset, Dataset.dataset_id == File.dataset_id)
+            )
         ).first()
         dataset_collection_metadata = DatasetCollectionMetadata(*metadata_result_row)
 
@@ -228,12 +331,153 @@ class SqlAlchemyDatasetRepository(DatasetRepository):
     def save(self, bucket: str, dataset: Dataset):
         # Just make sure
         dataset.bucket = bucket
-        self.session.add(dataset)
-        self.session.commit()
+
+        self._save([dataset])
+
+    def connect(self):
+        return self.session_provider.engine.connect()
+
+    def _save(self, datasets: list[Dataset]):
+        """Only do upserts. Never delete. Rows get only deleted when an entire Dataset is removed."""
+        datasets_entities = []
+        revision_entities = []
+        file_entities = []
+
+        for dataset in datasets:
+            datasets_entities.append(dataset.model_dump(exclude={"revisions"}))
+            for revision in dataset.revisions:
+                revision_entities.append(
+                    {
+                        **revision.model_dump(
+                            exclude={"is_squashed", "modified_files"}
+                        ),
+                        "dataset_id": dataset.dataset_id,
+                    }
+                )
+                for file in revision.modified_files:
+                    file_entities.append(
+                        {
+                            **file.model_dump(),
+                            "dataset_id": dataset.dataset_id,
+                            "revision_id": revision.revision_id,
+                        }
+                    )
+
+        with self.connect() as connection:
+            try:
+                self._upsert(connection, dataset_table, datasets_entities)
+                self._upsert(connection, revision_table, revision_entities)
+                self._upsert(connection, file_table, file_entities)
+            except Exception:
+                connection.rollback()
+                raise
+            else:
+                connection.commit()
 
     def destroy(self, dataset: Dataset):
-        self.session.delete(dataset)
-        self.session.commit()
+        with self.connect() as connection:
+            try:
+                # Delete modified files related to the dataset
+                file_table.delete().where(
+                    file_table.c.dataset_id == dataset.dataset_id
+                ).execute()
+
+                # Delete revisions related to the dataset
+                revision_table.delete().where(
+                    revision_table.c.dataset_id == dataset.dataset_id
+                ).execute()
+
+                # Delete the dataset itself
+                dataset_table.delete().where(
+                    dataset_table.c.dataset_id == dataset.dataset_id
+                ).execute()
+
+                connection.commit()
+            except Exception:
+                connection.rollback()
+                raise
 
     def next_identity(self):
         return str(uuid.uuid4())
+
+    # TODO: consider moving the IngestionJobSummary methods to a different Repository
+    def save_ingestion_job_summary(self, ingestion_job_summary: IngestionJobSummary):
+        ingestion_job_summary_entities = [
+            ingestion_job_summary.model_dump(exclude={"task_summaries"})
+        ]
+        task_summary_entities = []
+        for task_summary in ingestion_job_summary.task_summaries:
+            task_summary_entities.append(
+                {
+                    **task_summary.model_dump(),
+                    "ingestion_job_summary_id": ingestion_job_summary.ingestion_job_summary_id,
+                }
+            )
+
+        with self.session_provider.engine.connect() as connection:
+            try:
+                self._upsert(
+                    connection,
+                    ingestion_job_summary_table,
+                    ingestion_job_summary_entities,
+                )
+                if task_summary_entities:
+                    self._upsert(connection, task_summary_table, task_summary_entities)
+            except Exception:
+                connection.rollback()
+                raise
+            else:
+                connection.commit()
+
+    def load_ingestion_job_summaries(self) -> list[IngestionJobSummary]:
+        ingestion_job_summary_ids = [
+            row.ingestion_job_summary_id
+            for row in self.session.query(
+                ingestion_job_summary_table.c.ingestion_job_summary_id
+            )
+        ]
+
+        ingestion_job_summary_rows = list(
+            self.session.query(ingestion_job_summary_table).filter(
+                ingestion_job_summary_table.c.ingestion_job_summary_id.in_(
+                    ingestion_job_summary_ids
+                )
+            )
+        )
+
+        task_summary_entities_per_job_summary = {}
+        rows = (
+            self.session.query(task_summary_table)
+            .filter(
+                task_summary_table.c.ingestion_job_summary_id.in_(
+                    ingestion_job_summary_ids
+                )
+            )
+            .order_by(task_summary_table.c.ingestion_job_summary_id)
+        )
+
+        for ingestion_job_summary_id, task_summaries_rows in itertools.groupby(
+            rows, key=lambda row: row.ingestion_job_summary_id
+        ):
+            task_summary_entities_per_job_summary[ingestion_job_summary_id] = list(
+                task_summaries_rows
+            )
+
+        ingestion_job_summaries = []
+        for ingestion_job_summary_row in ingestion_job_summary_rows:
+            task_summaries = [
+                TaskSummary.model_validate(row)
+                for row in task_summary_entities_per_job_summary.get(
+                    ingestion_job_summary_row.ingestion_job_summary_id, []
+                )
+            ]
+
+            ingestion_job_summaries.append(
+                IngestionJobSummary.model_validate(
+                    {
+                        **ingestion_job_summary_row._mapping,
+                        "task_summaries": task_summaries,
+                    }
+                )
+            )
+        return ingestion_job_summaries
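
Both loaders rebuild Pydantic models from SQLAlchemy Row objects via row._mapping, which exposes a row as a read-only mapping keyed by column name, so extra keys such as "task_summaries" can be merged in before validation. A small standalone illustration of that pattern (the toy table and model are illustrative, not ingestify classes):

from pydantic import BaseModel
from sqlalchemy import Column, Integer, MetaData, String, Table, create_engine, insert, select

class Summary(BaseModel):
    summary_id: str
    task_count: int
    task_summaries: list[str] = []

metadata = MetaData()
summary_table = Table(
    "summary",
    metadata,
    Column("summary_id", String, primary_key=True),
    Column("task_count", Integer),
)

engine = create_engine("sqlite:///:memory:")
metadata.create_all(engine)
with engine.begin() as conn:
    conn.execute(insert(summary_table).values(summary_id="s1", task_count=2))

with engine.connect() as conn:
    row = conn.execute(select(summary_table)).first()

# row._mapping behaves like a dict keyed by column name, so it can be merged
# with extra keys before handing the result to model_validate.
summary = Summary.model_validate({**row._mapping, "task_summaries": ["t1", "t2"]})
assert summary.task_count == 2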