plexus-python-common 1.0.31__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,1428 @@
1
+ import datetime
2
+ from typing import Protocol, Self
3
+
4
+ import pydantic as pdt
5
+ import sqlalchemy as sa
6
+ import sqlalchemy.dialects.postgresql as sa_pg
7
+ import sqlalchemy.exc as sa_exc
8
+ import sqlalchemy.orm as sa_orm
9
+ from sqlmodel import Field, SQLModel
10
+
11
+ from plexus.common.utils.datautils import validate_dt_timezone
12
+ from plexus.common.utils.jsonutils import json_datetime_encoder
13
+
14
+ __all__ = [
15
+ "compare_postgresql_types",
16
+ "model_name_of",
17
+ "validate_model_extended",
18
+ "collect_model_tables",
19
+ "model_copy_from",
20
+ "make_base_model",
21
+ "SerialModelMixinProtocol",
22
+ "RecordModelMixinProtocol",
23
+ "SnapshotModelMixinProtocol",
24
+ "RevisionModelMixinProtocol",
25
+ "SerialModelMixin",
26
+ "RecordModelMixin",
27
+ "SnapshotModelMixin",
28
+ "RevisionModelMixin",
29
+ "make_serial_model_mixin",
30
+ "make_record_model_mixin",
31
+ "make_snapshot_model_mixin",
32
+ "make_revision_model_mixin",
33
+ "serial_model_mixin",
34
+ "record_model_mixin",
35
+ "snapshot_model_mixin",
36
+ "revision_model_mixin",
37
+ "SerialModel",
38
+ "RecordModel",
39
+ "SnapshotModel",
40
+ "RevisionModel",
41
+ "clone_serial_model_instance",
42
+ "clone_record_model_instance",
43
+ "clone_snapshot_model_instance",
44
+ "clone_revision_model_instance",
45
+ "make_snapshot_model_trigger",
46
+ "make_revision_model_trigger",
47
+ "db_make_order_by_clause",
48
+ "db_create_serial_model",
49
+ "db_create_serial_models",
50
+ "db_read_serial_model",
51
+ "db_read_serial_models",
52
+ "db_update_serial_model",
53
+ "db_delete_serial_model",
54
+ "db_create_record_model",
55
+ "db_create_record_models",
56
+ "db_update_record_model",
57
+ "db_create_snapshot_model",
58
+ "db_create_snapshot_models",
59
+ "db_read_snapshot_models_of_record",
60
+ "db_read_latest_snapshot_model_of_record",
61
+ "db_read_active_snapshot_model_of_record",
62
+ "db_read_expired_snapshot_models_of_record",
63
+ "db_read_latest_snapshot_models",
64
+ "db_read_active_snapshot_models",
65
+ "db_update_snapshot_model",
66
+ "db_expire_snapshot_model",
67
+ "db_activate_snapshot_model",
68
+ "db_create_revision_model",
69
+ "db_create_revision_models",
70
+ "db_read_revision_models_of_record",
71
+ "db_read_latest_revision_model_of_record",
72
+ "db_read_active_revision_model_of_record",
73
+ "db_read_expired_revision_models_of_record",
74
+ "db_read_latest_revision_models",
75
+ "db_read_active_revision_models",
76
+ "db_update_revision_model",
77
+ "db_expire_revision_model",
78
+ "db_activate_revision_model",
79
+ ]
80
+
81
+
82
def compare_postgresql_types(type_a, type_b) -> bool:
    """
    Checks whether two Postgresql column type objects describe an equivalent column type.

    ``type_a`` has to be an instance of the class of ``type_b``. Beyond that, arrays are compared by their item
    type, character types by length, timestamp/time types by their timezone flag and numeric types by precision
    and scale. Every other type is accepted only if its exact class is one of the plain types listed below.

    :param type_a: The first column type object.
    :param type_b: The second column type object.
    :return: True if both types are considered equivalent, False otherwise.
    """
    if not isinstance(type_a, type(type_b)):
        return False

    if isinstance(type_a, sa_pg.ARRAY):
        # Arrays are equivalent when their element types are equivalent.
        return compare_postgresql_types(type_a.item_type, type_b.item_type)

    if isinstance(type_a, (sa_pg.VARCHAR, sa_pg.CHAR, sa_pg.TEXT)):
        equivalent = type_a.length == type_b.length
    elif isinstance(type_a, (sa_pg.TIMESTAMP, sa_pg.TIME)):
        equivalent = type_a.timezone == type_b.timezone
    elif isinstance(type_a, sa_pg.NUMERIC):
        equivalent = type_a.precision == type_b.precision and type_a.scale == type_b.scale
    else:
        # Types without parameters worth comparing: matching on the exact class is enough.
        plain_types = {
            sa_pg.BOOLEAN,
            sa_pg.INTEGER,
            sa_pg.BIGINT,
            sa_pg.SMALLINT,
            sa_pg.FLOAT,
            sa_pg.DOUBLE_PRECISION,
            sa_pg.REAL,
            sa_pg.DATE,
            sa_pg.UUID,
            sa_pg.JSON,
            sa_pg.JSONB,
            sa_pg.HSTORE,
        }
        equivalent = type(type_a) in plain_types
    return equivalent
111
+
112
+
113
def model_name_of(model: type[SQLModel], fallback_classname: bool = True) -> str | None:
    """
    Returns the table name of a model class.

    :param model: The model class to inspect.
    :param fallback_classname: When True, the class name is returned if the model has no (non-empty)
        ``__tablename__``; when False, None is returned in that case.
    :return: The table name, the class name as fallback, or None.
    """
    # A default is required here: without it a class lacking ``__tablename__`` raises AttributeError
    # and the fallback below could never be reached.
    table_name = getattr(model, "__tablename__", None)
    if table_name:
        return table_name
    return model.__name__ if fallback_classname else None
118
+
119
+
120
def validate_model_extended(model_base: type[SQLModel], model_extended: type[SQLModel]) -> bool:
    """
    Tells whether ``model_extended`` carries every column of ``model_base`` with an equivalent type.

    Columns are matched by name on the ``__table__`` of both models and their types are compared with
    ``compare_postgresql_types``. Additional columns on ``model_extended`` are ignored.

    :param model_base: The base model class to compare against.
    :param model_extended: The model class that is expected to extend the base model.
    :return: True if ``model_extended`` extends ``model_base`` correctly, False otherwise.
    """
    extended_types = {column.name: column.type for column in model_extended.__table__.columns}

    for base_column in model_base.__table__.columns:
        extended_type = extended_types.get(base_column.name)
        if extended_type is None:
            # The column is missing from the extended model.
            return False
        if not compare_postgresql_types(base_column.type, extended_type):
            return False
    return True
137
+
138
+
139
+ def collect_model_tables[ModelT: SQLModel](*models: ModelT) -> sa.MetaData:
140
+ metadata = sa.MetaData()
141
+ for base in models:
142
+ for table in base.metadata.tables.values():
143
+ table.to_metadata(metadata)
144
+ return metadata
145
+
146
+
147
+ def model_copy_from[ModelT: SQLModel, ModelU: SQLModel](dst: ModelT, src: ModelU, **kwargs) -> ModelT:
148
+ if not isinstance(dst, SQLModel) or not isinstance(src, SQLModel):
149
+ raise TypeError("both 'dst' and 'src' must be instances of SQLModel or its subclasses")
150
+
151
+ for field, value in src.model_dump(**kwargs).items():
152
+ if field not in dst.model_fields:
153
+ continue
154
+ # Skip fields that are not present in the destination model
155
+ if value is None and dst.model_fields[field].required:
156
+ raise ValueError(f"field '{field}' is required but got None")
157
+
158
+ # Only set the field if it exists in the destination model
159
+ if hasattr(dst, field):
160
+ # If the field is a SQLModel, recursively copy it
161
+ if isinstance(value, SQLModel):
162
+ value = model_copy_from(getattr(dst, field), value, **kwargs)
163
+ elif isinstance(value, list) and all(isinstance(item, SQLModel) for item in value):
164
+ value = [model_copy_from(dst_item, src_item, **kwargs)
165
+ for dst_item, src_item in zip(getattr(dst, field), value)]
166
+
167
+ setattr(dst, field, value)
168
+
169
+ return dst
170
+
171
+
172
def make_base_model() -> type[SQLModel]:
    """
    Builds a new SQLModel base class that owns a fresh ``MetaData`` object and encodes datetime values to JSON
    with ``json_datetime_encoder``.

    Every call returns a distinct class with its own metadata. Use the result as the base of all models that
    need this configuration.

    :return: The newly created base model class.
    """
    datetime_encoders = {datetime.datetime: json_datetime_encoder}

    class BaseModel(SQLModel):
        metadata = sa.MetaData()
        model_config = pdt.ConfigDict(json_encoders=datetime_encoders)

    return BaseModel
183
+
184
+
185
+ class SerialModelMixinProtocol(Protocol):
186
+ sid: int | None
187
+
188
+
189
# ``Protocol`` must be listed explicitly: a class that only inherits from a protocol is an ordinary
# (nominal) class, not a protocol itself.
class RecordModelMixinProtocol(SerialModelMixinProtocol, Protocol):
    """Structural type of a record model: ``sid`` plus creation and update timestamps."""

    created_at: datetime.datetime | None
    updated_at: datetime.datetime | None

    @classmethod
    def make_index_created_at(cls, index_name: str) -> sa.Index:
        """
        Helper to create an index on the ``created_at`` field with the given index name.

        :param index_name: Name of the index to create.
        :return: The created SQLAlchemy Index object.
        """
        ...
201
+
202
+
203
# ``Protocol`` must be listed explicitly: a class that only inherits from a protocol is an ordinary
# (nominal) class, not a protocol itself.
class SnapshotModelMixinProtocol(SerialModelMixinProtocol, Protocol):
    """Structural type of a snapshot model: ``sid``, creation/expiration timestamps and ``record_sid``."""

    created_at: datetime.datetime | None
    expired_at: datetime.datetime | None
    record_sid: int | None

    @classmethod
    def make_index_created_at_expired_at(cls, index_name: str) -> sa.Index:
        """
        Helper to create an index on the ``created_at`` and ``expired_at`` fields with the given index name.

        :param index_name: Name of the index to create.
        :return: The created SQLAlchemy Index object.
        """
        ...

    @classmethod
    def make_active_unique_index_record_sid(cls, index_name: str) -> sa.Index:
        """
        Helper to create a unique index on the ``record_sid`` field for active records (where ``expired_at`` is NULL).
        This ensures that there is only one active snapshot per record at any given time.

        :param index_name: Name of the index to create.
        :return: The created SQLAlchemy Index object.
        """
        ...

    @classmethod
    def make_active_index_for(cls, index_name: str, *fields: str) -> sa.Index:
        """
        Helper to create a non-unique index on the specified fields for active records (where ``expired_at`` is NULL).
        This allows efficient querying of active snapshots based on the specified fields.

        :param index_name: Name of the index to create.
        :param fields: Fields to include in the index.
        :return: The created SQLAlchemy Index object.
        """
        ...

    @classmethod
    def make_active_unique_index_for(cls, index_name: str, *fields: str) -> sa.Index:
        """
        Helper to create a unique index on the specified fields for active records (where ``expired_at`` is NULL).
        This ensures that there is only one active snapshot per combination of the specified fields at any given
        time.

        :param index_name: Name of the index to create.
        :param fields: Fields to include in the unique index.
        :return: The created SQLAlchemy Index object.
        """
        ...
249
+
250
+
251
# ``Protocol`` must be listed explicitly: a class that only inherits from a protocol is an ordinary
# (nominal) class, not a protocol itself.
class RevisionModelMixinProtocol(SerialModelMixinProtocol, Protocol):
    """Structural type of a revision model: ``sid``, three timestamps, ``record_sid`` and ``revision``."""

    created_at: datetime.datetime | None
    updated_at: datetime.datetime | None
    expired_at: datetime.datetime | None
    record_sid: int | None
    revision: int | None

    @classmethod
    def make_index_created_at_updated_at_expired_at(cls, index_name: str) -> sa.Index:
        """
        Helper to create an index on the ``created_at``, ``updated_at``, and ``expired_at`` fields with the given
        index name.

        :param index_name: Name of the index to create.
        :return: The created SQLAlchemy Index object.
        """
        ...

    @classmethod
    def make_unique_index_record_sid_revision(cls, index_name: str) -> sa.Index:
        """
        Helper to create a unique index on the ``record_sid`` and ``revision`` fields.
        This ensures that each revision number is unique per record.

        :param index_name: Name of the index to create.
        :return: The created SQLAlchemy Index object.
        """
        ...

    @classmethod
    def make_active_unique_index_record_sid(cls, index_name: str) -> sa.Index:
        """
        Helper to create a unique index on the ``record_sid`` field for active records (where ``expired_at`` is NULL).
        This ensures that there is only one active revision per record at any given time.

        :param index_name: Name of the index to create.
        :return: The created SQLAlchemy Index object.
        """
        ...

    @classmethod
    def make_active_index_for(cls, index_name: str, *fields: str) -> sa.Index:
        """
        Helper to create a non-unique index on the specified fields for active records (where ``expired_at`` is NULL).
        This allows efficient querying of active revisions based on the specified fields.

        :param index_name: Name of the index to create.
        :param fields: Fields to include in the index.
        :return: The created SQLAlchemy Index object.
        """
        ...

    @classmethod
    def make_active_unique_index_for(cls, index_name: str, *fields: str) -> sa.Index:
        """
        Helper to create a unique index on the specified fields for active records (where ``expired_at`` is NULL).
        This ensures that there is only one active revision per combination of the specified fields at any given
        time.

        :param index_name: Name of the index to create.
        :param fields: Fields to include in the unique index.
        :return: The created SQLAlchemy Index object.
        """
        ...
310
+
311
+
312
+ # At the present time, we cannot express intersection of Protocol and SQLModel directly.
313
+ # Thus, we define union types here for the mixins.
314
+ SerialModelMixin = SerialModelMixinProtocol | SQLModel
315
+ RecordModelMixin = RecordModelMixinProtocol | SQLModel
316
+ SnapshotModelMixin = SnapshotModelMixinProtocol | SQLModel
317
+ RevisionModelMixin = RevisionModelMixinProtocol | SQLModel
318
+
319
+
320
def make_serial_model_mixin() -> type[SerialModelMixin]:
    """
    Builds a fresh mixin class for SQLModel models that declares the ``sid`` field, an auto-incremented
    BIGINT primary key.

    Every call returns a new class with its own column object.

    :return: The newly created mixin class.
    """

    class ModelMixin(SQLModel):
        sid: int | None = Field(
            default=None,
            description="Unique auto-incremented primary key for the record",
            sa_column=sa.Column(sa_pg.BIGINT, primary_key=True, autoincrement=True),
        )

    return ModelMixin
334
+
335
+
336
def make_record_model_mixin() -> type[RecordModelMixin]:
    """
    Builds a fresh mixin class for SQLModel models that represent updatable records.

    The mixin declares ``sid``, ``created_at`` and ``updated_at``. Both timestamps are passed through
    ``validate_dt_timezone`` when set, and a model whose creation time is later than its update time is rejected.
    Every call returns a new class with its own column objects.

    :return: The newly created mixin class.
    """

    class ModelMixin(SQLModel):
        sid: int | None = Field(
            default=None,
            description="Unique auto-incremented primary key for the record",
            sa_column=sa.Column(sa_pg.BIGINT, primary_key=True, autoincrement=True),
        )
        created_at: datetime.datetime | None = Field(
            default=None,
            description="Timestamp (with timezone) when the record was created",
            sa_column=sa.Column(sa_pg.TIMESTAMP(timezone=True)),
        )
        updated_at: datetime.datetime | None = Field(
            default=None,
            description="Timestamp (with timezone) when the record was last updated",
            sa_column=sa.Column(sa_pg.TIMESTAMP(timezone=True)),
        )

        @pdt.field_validator("created_at", mode="after")
        @classmethod
        def validate_created_at(cls, v: datetime.datetime) -> datetime.datetime:
            # An unset timestamp is allowed; only actual values are checked.
            if v is None:
                return v
            validate_dt_timezone(v)
            return v

        @pdt.field_validator("updated_at", mode="after")
        @classmethod
        def validate_updated_at(cls, v: datetime.datetime) -> datetime.datetime:
            # An unset timestamp is allowed; only actual values are checked.
            if v is None:
                return v
            validate_dt_timezone(v)
            return v

        @pdt.model_validator(mode="after")
        def validate_created_at_updated_at(self) -> Self:
            # The ordering can only be checked when both timestamps are set.
            if self.created_at is None or self.updated_at is None:
                return self
            if self.created_at > self.updated_at:
                raise ValueError(f"create time '{self.created_at}' is greater than update time '{self.updated_at}'")
            return self

        @classmethod
        def make_index_created_at(cls, index_name: str) -> sa.Index:
            return sa.Index(index_name, "created_at")

    return ModelMixin
384
+
385
+
386
def make_snapshot_model_mixin() -> type[SnapshotModelMixin]:
    """
    Builds a fresh mixin class for SQLModel models that store record snapshots.

    A snapshot model keeps the full change history of an entity: when any field changes, the current row (the
    one with a NULL expiration time) gets its expiration time set and a new row with the updated values is
    created.

    Fields declared by the mixin:
      - ``sid``: Unique, auto-incremented primary key identifying each snapshot of the record.
      - ``created_at``: Time (with timezone) when this snapshot was created and became active.
      - ``expired_at``: Time (with timezone) when this snapshot was superseded or became inactive;
        ``None`` if still active.
      - ``record_sid``: Key of the record this snapshot belongs to; used to link snapshots together.

    Both timestamps are passed through ``validate_dt_timezone`` when set, and a model whose creation time is
    later than its expiration time is rejected. Every call returns a new class with its own column objects.

    :return: The newly created mixin class.
    """
    # Partial-index condition that selects the active (not yet expired) rows.
    active_rows_condition = '"expired_at" IS NULL'

    class ModelMixin(SQLModel):
        sid: int | None = Field(
            default=None,
            description="Unique auto-incremented primary key for each record snapshot",
            sa_column=sa.Column(sa_pg.BIGINT, primary_key=True, autoincrement=True),
        )
        created_at: datetime.datetime | None = Field(
            default=None,
            description="Timestamp (with timezone) when this record snapshot became active",
            sa_column=sa.Column(sa_pg.TIMESTAMP(timezone=True)),
        )
        expired_at: datetime.datetime | None = Field(
            default=None,
            description="Timestamp (with timezone) when this record snapshot became inactive; None if still active",
            sa_column=sa.Column(sa_pg.TIMESTAMP(timezone=True)),
        )
        record_sid: int | None = Field(
            default=None,
            description="Foreign key to the record this snapshot belongs to",
            sa_column=sa.Column(sa_pg.BIGINT, nullable=True),
        )

        @pdt.field_validator("created_at", mode="after")
        @classmethod
        def validate_created_at(cls, v: datetime.datetime) -> datetime.datetime:
            # An unset timestamp is allowed; only actual values are checked.
            if v is None:
                return v
            validate_dt_timezone(v)
            return v

        @pdt.field_validator("expired_at", mode="after")
        @classmethod
        def validate_expired_at(cls, v: datetime.datetime) -> datetime.datetime:
            # An unset timestamp is allowed; only actual values are checked.
            if v is None:
                return v
            validate_dt_timezone(v)
            return v

        @pdt.model_validator(mode="after")
        def validate_created_at_expired_at(self) -> Self:
            # The ordering can only be checked when both timestamps are set.
            if self.created_at is None or self.expired_at is None:
                return self
            if self.created_at > self.expired_at:
                raise ValueError(f"create time '{self.created_at}' is greater than expire time '{self.expired_at}'")
            return self

        @classmethod
        def make_index_created_at_expired_at(cls, index_name: str) -> sa.Index:
            return sa.Index(index_name, "created_at", "expired_at")

        @classmethod
        def make_active_unique_index_record_sid(cls, index_name: str) -> sa.Index:
            where_active = sa.text(active_rows_condition)
            return sa.Index(index_name, "record_sid", unique=True, postgresql_where=where_active)

        @classmethod
        def make_active_index_for(cls, index_name: str, *fields: str) -> sa.Index:
            where_active = sa.text(active_rows_condition)
            return sa.Index(index_name, *fields, postgresql_where=where_active)

        @classmethod
        def make_active_unique_index_for(cls, index_name: str, *fields: str) -> sa.Index:
            where_active = sa.text(active_rows_condition)
            return sa.Index(index_name, *fields, unique=True, postgresql_where=where_active)

    return ModelMixin
473
+
474
+
475
def make_revision_model_mixin() -> type[RevisionModelMixin]:
    """
    Builds a fresh mixin class for SQLModel models that store record revisions.

    A revision model keeps the full change history of an entity: when any field changes, the current row (the
    one with a NULL expiration time) gets its expiration time set and a new row with the updated values is
    created.

    Fields declared by the mixin:
      - ``sid``: Unique, auto-incremented primary key identifying each revision of the record.
      - ``created_at``: Time (with timezone) when the record was first created.
      - ``updated_at``: Time (with timezone) when the record was updated and this revision became active.
      - ``expired_at``: Time (with timezone) when this revision was superseded or became inactive;
        ``None`` if still active.
      - ``record_sid``: Key of the record this revision belongs to; used to link revisions together.
      - ``revision``: Revision number for the record, used to track changes over time.

    The timestamps are passed through ``validate_dt_timezone`` when set, must be ordered as
    ``created_at <= updated_at <= expired_at`` where present, and ``revision`` must be positive when set.
    Every call returns a new class with its own column objects.

    :return: The newly created mixin class.
    """
    # Partial-index condition that selects the active (not yet expired) rows.
    active_rows_condition = '"expired_at" IS NULL'

    class ModelMixin(SQLModel):
        sid: int | None = Field(
            default=None,
            description="Unique auto-incremented primary key for each record revision",
            sa_column=sa.Column(sa_pg.BIGINT, primary_key=True, autoincrement=True),
        )
        created_at: datetime.datetime | None = Field(
            default=None,
            description="Timestamp (with timezone) when this record is first created (preserved across revisions)",
            sa_column=sa.Column(sa_pg.TIMESTAMP(timezone=True)),
        )
        updated_at: datetime.datetime | None = Field(
            default=None,
            description="Timestamp (with timezone) when this record is updated and this record revision became active",
            sa_column=sa.Column(sa_pg.TIMESTAMP(timezone=True)),
        )
        expired_at: datetime.datetime | None = Field(
            default=None,
            description="Timestamp (with timezone) when this record revision became inactive; None if still active",
            sa_column=sa.Column(sa_pg.TIMESTAMP(timezone=True)),
        )
        record_sid: int | None = Field(
            default=None,
            description="Auto-incremented key of the record this revision belongs to",
            sa_column=sa.Column(sa_pg.BIGINT, nullable=True),
        )
        revision: int | None = Field(
            default=None,
            description="Revision number for the record",
            sa_column=sa.Column(sa_pg.INTEGER, nullable=True),
        )

        @pdt.field_validator("created_at", mode="after")
        @classmethod
        def validate_created_at(cls, v: datetime.datetime) -> datetime.datetime:
            # An unset timestamp is allowed; only actual values are checked.
            if v is None:
                return v
            validate_dt_timezone(v)
            return v

        @pdt.field_validator("updated_at", mode="after")
        @classmethod
        def validate_updated_at(cls, v: datetime.datetime) -> datetime.datetime:
            # An unset timestamp is allowed; only actual values are checked.
            if v is None:
                return v
            validate_dt_timezone(v)
            return v

        @pdt.field_validator("expired_at", mode="after")
        @classmethod
        def validate_expired_at(cls, v: datetime.datetime) -> datetime.datetime:
            # An unset timestamp is allowed; only actual values are checked.
            if v is None:
                return v
            validate_dt_timezone(v)
            return v

        @pdt.field_validator("revision", mode="after")
        @classmethod
        def validate_revision(cls, v: int) -> int:
            # An unset revision is allowed; a set one has to be strictly positive.
            if v is None:
                return v
            if v > 0:
                return v
            raise ValueError("revision number must be positive integer")

        @pdt.model_validator(mode="after")
        def validate_created_at_updated_at_expired_at(self) -> Self:
            # Both ordering checks involve ``updated_at``; without it there is nothing to compare.
            if self.updated_at is None:
                return self
            if self.created_at is not None and self.created_at > self.updated_at:
                raise ValueError(f"create time '{self.created_at}' is greater than update time '{self.updated_at}'")
            if self.expired_at is not None and self.updated_at > self.expired_at:
                raise ValueError(f"update time '{self.updated_at}' is greater than expire time '{self.expired_at}'")
            return self

        @classmethod
        def make_index_created_at_updated_at_expired_at(cls, index_name: str) -> sa.Index:
            return sa.Index(index_name, "created_at", "updated_at", "expired_at")

        @classmethod
        def make_unique_index_record_sid_revision(cls, index_name: str) -> sa.Index:
            return sa.Index(index_name, "record_sid", "revision", unique=True)

        @classmethod
        def make_active_unique_index_record_sid(cls, index_name: str) -> sa.Index:
            where_active = sa.text(active_rows_condition)
            return sa.Index(index_name, "record_sid", unique=True, postgresql_where=where_active)

        @classmethod
        def make_active_index_for(cls, index_name: str, *fields: str) -> sa.Index:
            where_active = sa.text(active_rows_condition)
            return sa.Index(index_name, *fields, postgresql_where=where_active)

        @classmethod
        def make_active_unique_index_for(cls, index_name: str, *fields: str) -> sa.Index:
            where_active = sa.text(active_rows_condition)
            return sa.Index(index_name, *fields, unique=True, postgresql_where=where_active)

    return ModelMixin
594
+
595
+
596
+ serial_model_mixin = make_serial_model_mixin()
597
+ record_model_mixin = make_record_model_mixin()
598
+ snapshot_model_mixin = make_snapshot_model_mixin()
599
+ revision_model_mixin = make_revision_model_mixin()
600
+
601
+
602
# Table model assembled from a fresh base model and a fresh serial mixin; it declares no fields of its own.
class SerialModel(make_base_model(), make_serial_model_mixin(), table=True):
    ...
604
+
605
+
606
# Table model assembled from a fresh base model and a fresh record mixin; it declares no fields of its own.
class RecordModel(make_base_model(), make_record_model_mixin(), table=True):
    ...
608
+
609
+
610
# Table model assembled from a fresh base model and a fresh snapshot mixin; it declares no fields of its own.
# ``make_snapshot_model_trigger`` uses its columns as the reference set a snapshot model must provide.
class SnapshotModel(make_base_model(), make_snapshot_model_mixin(), table=True):
    ...
612
+
613
+
614
# Table model assembled from a fresh base model and a fresh revision mixin; it declares no fields of its own.
# ``make_revision_model_trigger`` uses its columns as the reference set a revision model must provide.
class RevisionModel(make_base_model(), make_revision_model_mixin(), table=True):
    ...
616
+
617
+
618
+ def make_snapshot_model_trigger[SnapshotModelT: SnapshotModelMixin](engine: sa.Engine, model: type[SnapshotModelT]):
619
+ """
620
+ Creates the necessary database objects (sequence, function, trigger) to support automatic snapshot management
621
+ for the given snapshot model. This includes a sequence for `record_sid`, a function to handle snapshot updates,
622
+ and a trigger to invoke the function before inserts. The model must extend `SnapshotModel`.
623
+
624
+ :param engine: SQLAlchemy engine connected to the target database.
625
+ :param model: The snapshot model class extending `SnapshotModel`.
626
+ """
627
+ table_name = model_name_of(model, fallback_classname=False)
628
+ if not table_name:
629
+ raise ValueError("cannot determine table name from model")
630
+
631
+ if not validate_model_extended(SnapshotModel, model):
632
+ raise ValueError("not an extended model of 'SnapshotModel'")
633
+
634
+ record_sid_seq_name = f"{table_name}_record_sid_seq"
635
+ snapshot_auto_update_function_name = f"{table_name}_snapshot_auto_update_function"
636
+ snapshot_auto_update_trigger_name = f"{table_name}_snapshot_auto_update_trigger"
637
+
638
+ # language=postgresql
639
+ create_record_sid_seq_sql = f"""
640
+ CREATE SEQUENCE "{record_sid_seq_name}" START WITH 1 INCREMENT BY 1 NO MINVALUE NO MAXVALUE CACHE 1;
641
+ """
642
+
643
+ # language=postgresql
644
+ create_snapshot_auto_update_function_sql = f"""
645
+ CREATE FUNCTION "{snapshot_auto_update_function_name}"()
646
+ RETURNS TRIGGER AS $$
647
+ BEGIN
648
+ IF NEW."record_sid" IS NULL THEN
649
+ IF NEW."created_at" IS NULL THEN
650
+ NEW."created_at" := CURRENT_TIMESTAMP;
651
+ END IF;
652
+
653
+ NEW."expired_at" := NULL;
654
+ NEW."record_sid" := nextval('{record_sid_seq_name}');
655
+ ELSE
656
+ IF NEW."created_at" IS NULL THEN
657
+ NEW."created_at" := CURRENT_TIMESTAMP;
658
+ END IF;
659
+
660
+ NEW."expired_at" := NULL;
661
+
662
+ UPDATE "{table_name}"
663
+ SET "expired_at" = NEW."created_at"
664
+ WHERE "record_sid" = NEW."record_sid" AND "expired_at" IS NULL;
665
+ END IF;
666
+
667
+ RETURN NEW;
668
+ END;
669
+ $$ LANGUAGE plpgsql;
670
+ """
671
+
672
+ # language=postgresql
673
+ create_snapshot_auto_update_trigger_sql = f"""
674
+ CREATE TRIGGER "{snapshot_auto_update_trigger_name}"
675
+ BEFORE INSERT ON "{table_name}"
676
+ FOR EACH ROW
677
+ EXECUTE FUNCTION "{snapshot_auto_update_function_name}"();
678
+ """
679
+
680
+ with engine.connect() as conn:
681
+ with conn.begin():
682
+ conn.execute(sa.text(create_record_sid_seq_sql))
683
+ conn.execute(sa.text(create_snapshot_auto_update_function_sql))
684
+ conn.execute(sa.text(create_snapshot_auto_update_trigger_sql))
685
+
686
+
687
+ def make_revision_model_trigger[RevisionModelT: RevisionModelMixin](engine: sa.Engine, model: type[RevisionModelT]):
688
+ """
689
+ Creates the necessary database objects (sequence, function, trigger) to support automatic revision management
690
+ for the given revision model. This includes a sequence for `record_sid`, a function to handle revision updates,
691
+ and a trigger to invoke the function before inserts. The model must extend `RevisionModel`.
692
+
693
+ :param engine: SQLAlchemy engine connected to the target database.
694
+ :param model: The revision model class extending `RevisionModel`.
695
+ """
696
+ table_name = model_name_of(model, fallback_classname=False)
697
+ if not table_name:
698
+ raise ValueError("cannot determine table name from model")
699
+
700
+ if not validate_model_extended(RevisionModel, model):
701
+ raise ValueError("not an extended model of 'RevisionModel'")
702
+
703
+ record_sid_seq_name = f"{table_name}_record_sid_seq"
704
+ revision_auto_update_function_name = f"{table_name}_revision_auto_update_function"
705
+ revision_auto_update_trigger_name = f"{table_name}_revision_auto_update_trigger"
706
+
707
+ # language=postgresql
708
+ create_record_sid_seq_sql = f"""
709
+ CREATE SEQUENCE "{record_sid_seq_name}" START WITH 1 INCREMENT BY 1 NO MINVALUE NO MAXVALUE CACHE 1;
710
+ """
711
+
712
+ # language=postgresql
713
+ create_revision_auto_update_function_sql = f"""
714
+ CREATE FUNCTION "{revision_auto_update_function_name}"()
715
+ RETURNS TRIGGER AS $$
716
+ BEGIN
717
+ IF NEW."record_sid" IS NULL THEN
718
+ IF NEW."created_at" IS NULL THEN
719
+ NEW."created_at" := CURRENT_TIMESTAMP;
720
+ END IF;
721
+
722
+ NEW."updated_at" := NEW."created_at";
723
+ NEW."expired_at" := NULL;
724
+ NEW."record_sid" := nextval('{record_sid_seq_name}');
725
+ NEW."revision" := 1;
726
+ ELSE
727
+ SELECT MAX("created_at") INTO NEW."created_at"
728
+ FROM "{table_name}"
729
+ WHERE "record_sid" = NEW."record_sid";
730
+
731
+ IF NEW."updated_at" IS NULL THEN
732
+ NEW."updated_at" := CURRENT_TIMESTAMP;
733
+ END IF;
734
+
735
+ NEW."expired_at" := NULL;
736
+
737
+ SELECT COALESCE(MAX("revision"), 0) + 1 INTO NEW."revision"
738
+ FROM "{table_name}"
739
+ WHERE "record_sid" = NEW."record_sid";
740
+
741
+ UPDATE "{table_name}"
742
+ SET "expired_at" = NEW."updated_at"
743
+ WHERE "record_sid" = NEW."record_sid" AND "expired_at" IS NULL;
744
+ END IF;
745
+
746
+ RETURN NEW;
747
+ END;
748
+ $$ LANGUAGE plpgsql;
749
+ """
750
+
751
+ # language=postgresql
752
+ create_revision_auto_update_trigger_sql = f"""
753
+ CREATE TRIGGER "{revision_auto_update_trigger_name}"
754
+ BEFORE INSERT ON "{table_name}"
755
+ FOR EACH ROW
756
+ EXECUTE FUNCTION "{revision_auto_update_function_name}"();
757
+ """
758
+
759
+ with engine.connect() as conn:
760
+ with conn.begin():
761
+ conn.execute(sa.text(create_record_sid_seq_sql))
762
+ conn.execute(sa.text(create_revision_auto_update_function_sql))
763
+ conn.execute(sa.text(create_revision_auto_update_trigger_sql))
764
+
765
+
766
+ def clone_serial_model_instance[SerialModelT: SerialModelMixin](
767
+ model: type[SerialModelT],
768
+ instance: SerialModelMixin,
769
+ *,
770
+ clear_meta_fields: bool = True,
771
+ inplace: bool = False,
772
+ ) -> SerialModelT:
773
+ result = model.model_validate(instance)
774
+ result = instance if inplace else result
775
+ if clear_meta_fields:
776
+ result.sid = None
777
+ return result
778
+
779
+
780
+ def clone_record_model_instance[RecordModelT: RecordModelMixin](
781
+ model: type[RecordModelT],
782
+ instance: RecordModelMixin,
783
+ *,
784
+ clear_meta_fields: bool = True,
785
+ inplace: bool = False,
786
+ ) -> RecordModelT:
787
+ result = model.model_validate(instance)
788
+ result = instance if inplace else result
789
+ if clear_meta_fields:
790
+ result.sid = None
791
+ result.created_at = None
792
+ result.updated_at = None
793
+ return result
794
+
795
+
796
+ def clone_snapshot_model_instance[SnapshotModelT: SnapshotModelMixin](
797
+ model: type[SnapshotModelT],
798
+ instance: SnapshotModelMixin,
799
+ *,
800
+ clear_meta_fields: bool = True,
801
+ inplace: bool = False,
802
+ ) -> SnapshotModelT:
803
+ result = model.model_validate(instance)
804
+ result = instance if inplace else result
805
+ if clear_meta_fields:
806
+ result.sid = None
807
+ result.created_at = None
808
+ result.expired_at = None
809
+ result.record_sid = None
810
+ return result
811
+
812
+
813
+ def clone_revision_model_instance[RevisionModelT: RevisionModelMixin](
814
+ model: type[RevisionModelT],
815
+ instance: RevisionModelMixin,
816
+ *,
817
+ clear_meta_fields: bool = True,
818
+ inplace: bool = False,
819
+ ) -> RevisionModelT:
820
+ result = model.model_validate(instance)
821
+ result = instance if inplace else result
822
+ if clear_meta_fields:
823
+ result.sid = None
824
+ result.created_at = None
825
+ result.updated_at = None
826
+ result.expired_at = None
827
+ result.record_sid = None
828
+ result.revision = None
829
+ return result
830
+
831
+
832
+ def db_make_order_by_clause[SerialModelT: SerialModelMixin](
833
+ model: type[SerialModelT],
834
+ order_by: list[str] | None = None,
835
+ ):
836
+ order_criteria = []
837
+ if order_by:
838
+ for field in order_by:
839
+ if field.startswith("-"):
840
+ order_criteria.append(sa.desc(getattr(model, field[1:])))
841
+ else:
842
+ order_criteria.append(sa.asc(getattr(model, field)))
843
+ else:
844
+ order_criteria.append(model.sid)
845
+ return order_criteria
846
+
847
+
848
+ def db_create_serial_model[SerialModelT: SerialModelMixin](
849
+ db: sa_orm.Session,
850
+ model: type[SerialModelT],
851
+ instance: SerialModelMixin,
852
+ ) -> SerialModelT:
853
+ db_instance = clone_serial_model_instance(model, instance)
854
+ db.add(db_instance)
855
+ db.flush()
856
+
857
+ return db_instance
858
+
859
+
860
+ def db_create_serial_models[SerialModelT: SerialModelMixin](
861
+ db: sa_orm.Session,
862
+ model: type[SerialModelT],
863
+ instances: list[SerialModelMixin],
864
+ ) -> list[SerialModelT]:
865
+ db_instances = [clone_serial_model_instance(model, instance) for instance in instances]
866
+ db.add_all(db_instances)
867
+ db.flush()
868
+
869
+ return db_instances
870
+
871
+
872
+ def db_read_serial_model[SerialModelT: SerialModelMixin](
873
+ db: sa_orm.Session,
874
+ model: type[SerialModelT],
875
+ sid: int,
876
+ ) -> SerialModelT:
877
+ db_instance = db.query(model).where(model.sid == sid).one_or_none()
878
+ if db_instance is None:
879
+ raise sa_exc.NoResultFound(f"'{model_name_of(model)}' of specified sid '{sid}' not found")
880
+
881
+ return db_instance
882
+
883
+
884
+ def db_read_serial_models[SerialModelT: SerialModelMixin](
885
+ db: sa_orm.Session,
886
+ model: type[SerialModelT],
887
+ skip: int | None = None,
888
+ limit: int | None = None,
889
+ order_by: list[str] | None = None,
890
+ ) -> list[SerialModelT]:
891
+ query = db.query(model).order_by(*db_make_order_by_clause(model, order_by))
892
+ if skip is not None:
893
+ query = query.offset(skip)
894
+ if limit is not None:
895
+ query = query.limit(limit)
896
+ return query.all()
897
+
898
+
899
+ def db_update_serial_model[SerialModelT: SerialModelMixin](
900
+ db: sa_orm.Session,
901
+ model: type[SerialModelT],
902
+ instance: SerialModelMixin,
903
+ sid: int,
904
+ ) -> SerialModelT:
905
+ db_instance = db.query(model).where(model.sid == sid).one_or_none()
906
+ if db_instance is None:
907
+ raise sa_exc.NoResultFound(f"'{model_name_of(model)}' of specified sid '{sid}' not found")
908
+
909
+ db_instance = model_copy_from(db_instance, clone_serial_model_instance(model, instance), exclude_none=True)
910
+ db_instance = clone_serial_model_instance(model, db_instance, clear_meta_fields=False, inplace=True)
911
+ db.flush()
912
+
913
+ return db_instance
914
+
915
+
916
+ def db_delete_serial_model[SerialModelT: SerialModelMixin](
917
+ db: sa_orm.Session,
918
+ model: type[SerialModelT],
919
+ sid: int,
920
+ ) -> SerialModelT:
921
+ db_instance = db.query(model).where(model.sid == sid).one_or_none()
922
+ if db_instance is None:
923
+ raise sa_exc.NoResultFound(f"'{model_name_of(model)}' of specified sid '{sid}' not found")
924
+
925
+ db.delete(db_instance)
926
+ db.flush()
927
+
928
+ return db_instance
929
+
930
+
931
+ def db_create_record_model[RecordModelT: RecordModelMixin](
932
+ db: sa_orm.Session,
933
+ model: type[RecordModelT],
934
+ instance: RecordModelMixin,
935
+ created_at: datetime.datetime | None = None,
936
+ ) -> RecordModelT:
937
+ db_instance = clone_record_model_instance(model, instance)
938
+ db_instance.created_at = created_at
939
+ db_instance.updated_at = created_at
940
+ db.add(db_instance)
941
+ db.flush()
942
+
943
+ return db_instance
944
+
945
+
946
+ def db_create_record_models[RecordModelT: RecordModelMixin](
947
+ db: sa_orm.Session,
948
+ model: type[RecordModelT],
949
+ instances: list[RecordModelMixin],
950
+ created_at: datetime.datetime | None = None,
951
+ ) -> list[RecordModelT]:
952
+ db_instances = [clone_record_model_instance(model, instance) for instance in instances]
953
+ for db_instance in db_instances:
954
+ db_instance.created_at = created_at
955
+ db_instance.updated_at = created_at
956
+ db.add_all(db_instances)
957
+ db.flush()
958
+
959
+ return db_instances
960
+
961
+
962
+ def db_update_record_model[RecordModelT: RecordModelMixin](
963
+ db: sa_orm.Session,
964
+ model: type[RecordModelT],
965
+ instance: RecordModelMixin,
966
+ sid: int,
967
+ updated_at: datetime.datetime,
968
+ ) -> RecordModelT:
969
+ db_instance = db.query(model).where(model.sid == sid).one_or_none()
970
+ if db_instance is None:
971
+ raise sa_exc.NoResultFound(f"'{model_name_of(model)}' of specified sid '{sid}' not found")
972
+
973
+ db_instance = model_copy_from(db_instance, clone_record_model_instance(model, instance), exclude_none=True)
974
+ db_instance.updated_at = updated_at
975
+ db_instance = clone_record_model_instance(model, db_instance, clear_meta_fields=False, inplace=True)
976
+ db.flush()
977
+
978
+ return db_instance
979
+
980
+
981
+ def db_create_snapshot_model[SnapshotModelT: SnapshotModelMixin](
982
+ db: sa_orm.Session,
983
+ model: type[SnapshotModelT],
984
+ instance: SnapshotModelMixin,
985
+ created_at: datetime.datetime,
986
+ ) -> SnapshotModelT:
987
+ db_instance = clone_snapshot_model_instance(model, instance)
988
+ db_instance.created_at = created_at
989
+ db_instance.expired_at = None
990
+ db.add(db_instance)
991
+ db.flush()
992
+
993
+ db_instance.record_sid = db_instance.sid
994
+ db.flush()
995
+
996
+ return db_instance
997
+
998
+
999
+ def db_create_snapshot_models[SnapshotModelT: SnapshotModelMixin](
1000
+ db: sa_orm.Session,
1001
+ model: type[SnapshotModelT],
1002
+ instances: list[SnapshotModelMixin],
1003
+ created_at: datetime.datetime,
1004
+ ) -> list[SnapshotModelT]:
1005
+ db_instances = [clone_snapshot_model_instance(model, instance) for instance in instances]
1006
+ for db_instance in db_instances:
1007
+ db_instance.created_at = created_at
1008
+ db_instance.expired_at = None
1009
+ db.add_all(db_instances)
1010
+ db.flush()
1011
+
1012
+ for db_instance in db_instances:
1013
+ db_instance.record_sid = db_instance.sid
1014
+ db.flush()
1015
+
1016
+ return db_instances
1017
+
1018
+
1019
+ def db_read_snapshot_models_of_record[SnapshotModelT: SnapshotModelMixin](
1020
+ db: sa_orm.Session,
1021
+ model: type[SnapshotModelT],
1022
+ record_sid: int,
1023
+ ) -> list[SnapshotModelT]:
1024
+ return (
1025
+ db
1026
+ .query(model)
1027
+ .where(model.record_sid == record_sid)
1028
+ .order_by(model.created_at.desc())
1029
+ .all()
1030
+ )
1031
+
1032
+
1033
+ def db_read_latest_snapshot_model_of_record[SnapshotModelT: SnapshotModelMixin](
1034
+ db: sa_orm.Session,
1035
+ model: type[SnapshotModelT],
1036
+ record_sid: int,
1037
+ ) -> SnapshotModelT:
1038
+ db_instance = (
1039
+ db
1040
+ .query(model)
1041
+ .where(model.record_sid == record_sid)
1042
+ .order_by(model.created_at.desc())
1043
+ .first()
1044
+ )
1045
+ if db_instance is None:
1046
+ raise sa_exc.NoResultFound(f"'{model_name_of(model)}' of specified record_sid '{record_sid}' not found")
1047
+
1048
+ return db_instance
1049
+
1050
+
1051
+ def db_read_active_snapshot_model_of_record[SnapshotModelT: SnapshotModelMixin](
1052
+ db: sa_orm.Session,
1053
+ model: type[SnapshotModelT],
1054
+ record_sid: int,
1055
+ ) -> SnapshotModelT:
1056
+ db_instance = db.query(model).where(model.record_sid == record_sid, model.expired_at.is_(None)).one_or_none()
1057
+ if db_instance is None:
1058
+ raise sa_exc.NoResultFound(f"Active '{model_name_of(model)}' of specified record_sid '{record_sid}' not found")
1059
+
1060
+ return db_instance
1061
+
1062
+
1063
+ def db_read_expired_snapshot_models_of_record[SnapshotModelT: SnapshotModelMixin](
1064
+ db: sa_orm.Session,
1065
+ model: type[SnapshotModelT],
1066
+ record_sid: int,
1067
+ ) -> list[SnapshotModelT]:
1068
+ return (
1069
+ db
1070
+ .query(model)
1071
+ .where(model.record_sid == record_sid, model.expired_at.is_not(None))
1072
+ .order_by(model.created_at.desc())
1073
+ .all()
1074
+ )
1075
+
1076
+
1077
+ def db_read_latest_snapshot_models[SnapshotModelT: SnapshotModelMixin](
1078
+ db: sa_orm.Session,
1079
+ model: type[SnapshotModelT],
1080
+ skip: int | None = None,
1081
+ limit: int | None = None,
1082
+ order_by: list[str] | None = None,
1083
+ ) -> list[SnapshotModelT]:
1084
+ subquery = (
1085
+ db
1086
+ .query(model.record_sid,
1087
+ sa.func.max(model.created_at).label("max_created_at"))
1088
+ .group_by(model.record_sid)
1089
+ .subquery()
1090
+ )
1091
+
1092
+ query = (
1093
+ db
1094
+ .query(model)
1095
+ .join(subquery,
1096
+ sa.and_(model.record_sid == subquery.c.record_sid, model.created_at == subquery.c.max_created_at))
1097
+ .order_by(*db_make_order_by_clause(model, order_by))
1098
+ )
1099
+ if skip is not None:
1100
+ query = query.offset(skip)
1101
+ if limit is not None:
1102
+ query = query.limit(limit)
1103
+ return query.all()
1104
+
1105
+
1106
+ def db_read_active_snapshot_models[SnapshotModelT: SnapshotModelMixin](
1107
+ db: sa_orm.Session,
1108
+ model: type[SnapshotModelT],
1109
+ skip: int | None = None,
1110
+ limit: int | None = None,
1111
+ order_by: list[str] | None = None,
1112
+ ) -> list[SnapshotModelT]:
1113
+ query = db.query(model).where(model.expired_at.is_(None)).order_by(*db_make_order_by_clause(model, order_by))
1114
+ if skip is not None:
1115
+ query = query.offset(skip)
1116
+ if limit is not None:
1117
+ query = query.limit(limit)
1118
+ return query.all()
1119
+
1120
+
1121
+ def db_update_snapshot_model[SnapshotModelT: SnapshotModelMixin](
1122
+ db: sa_orm.Session,
1123
+ model: type[SnapshotModelT],
1124
+ instance: SnapshotModelMixin,
1125
+ record_sid: int,
1126
+ updated_at: datetime.datetime,
1127
+ ) -> SnapshotModelT:
1128
+ db_instance = db.query(model).where(model.record_sid == record_sid, model.expired_at.is_(None)).one_or_none()
1129
+ if db_instance is None:
1130
+ raise sa_exc.NoResultFound(f"Active '{model_name_of(model)}' of specified record_sid '{record_sid}' not found")
1131
+
1132
+ db_instance.expired_at = updated_at
1133
+ db_instance = clone_snapshot_model_instance(model, db_instance, clear_meta_fields=False, inplace=True)
1134
+ db.flush()
1135
+
1136
+ db_new_instance = clone_snapshot_model_instance(model, instance)
1137
+ db_new_instance.record_sid = record_sid
1138
+ db_new_instance.created_at = updated_at
1139
+ db_new_instance.expired_at = None
1140
+ db.add(db_new_instance)
1141
+ db.flush()
1142
+
1143
+ return db_new_instance
1144
+
1145
+
1146
+ def db_expire_snapshot_model[SnapshotModelT: SnapshotModelMixin](
1147
+ db: sa_orm.Session,
1148
+ model: type[SnapshotModelT],
1149
+ record_sid: int,
1150
+ updated_at: datetime.datetime,
1151
+ ) -> SnapshotModelT:
1152
+ db_instance = (
1153
+ db
1154
+ .query(model)
1155
+ .where(model.record_sid == record_sid, model.expired_at.is_(None))
1156
+ .one_or_none()
1157
+ )
1158
+ if db_instance is None:
1159
+ raise sa_exc.NoResultFound(f"Active '{model_name_of(model)}' of specified record_sid '{record_sid}' not found")
1160
+
1161
+ db_instance.expired_at = updated_at
1162
+ db_instance = clone_snapshot_model_instance(model, db_instance, clear_meta_fields=False, inplace=True)
1163
+ db.flush()
1164
+
1165
+ return db_instance
1166
+
1167
+
1168
+ def db_activate_snapshot_model[SnapshotModelT: SnapshotModelMixin](
1169
+ db: sa_orm.Session,
1170
+ model: type[SnapshotModelT],
1171
+ record_sid: int,
1172
+ updated_at: datetime.datetime,
1173
+ ) -> SnapshotModelT:
1174
+ db_instance = db.query(model).where(model.record_sid == record_sid, model.expired_at.is_(None)).one_or_none()
1175
+ if db_instance is not None:
1176
+ raise sa_exc.MultipleResultsFound(
1177
+ f"Active '{model_name_of(model)}' of specified record_sid '{record_sid}' already exists")
1178
+
1179
+ db_instance = (
1180
+ db
1181
+ .query(model)
1182
+ .where(model.record_sid == record_sid, model.expired_at.is_not(None))
1183
+ .order_by(model.created_at.desc())
1184
+ .first()
1185
+ )
1186
+ if db_instance is None:
1187
+ raise sa_exc.NoResultFound(f"Expired '{model_name_of(model)}' of specified record_sid '{record_sid}' not found")
1188
+
1189
+ db_new_instance = clone_snapshot_model_instance(model, db_instance)
1190
+ db_new_instance.record_sid = record_sid
1191
+ db_new_instance.created_at = db_instance.expired_at
1192
+ db_new_instance.expired_at = updated_at
1193
+ db_new_instance = clone_snapshot_model_instance(model, db_new_instance, clear_meta_fields=False, inplace=True)
1194
+ db_new_instance.created_at = updated_at
1195
+ db_new_instance.expired_at = None
1196
+ db.add(db_new_instance)
1197
+ db.flush()
1198
+
1199
+ return db_new_instance
1200
+
1201
+
1202
+ def db_create_revision_model[RevisionModelT: RevisionModelMixin](
1203
+ db: sa_orm.Session,
1204
+ model: type[RevisionModelT],
1205
+ instance: RevisionModelMixin,
1206
+ created_at: datetime.datetime,
1207
+ ) -> RevisionModelT:
1208
+ db_instance = clone_revision_model_instance(model, instance)
1209
+ db_instance.created_at = created_at
1210
+ db_instance.updated_at = created_at
1211
+ db_instance.expired_at = None
1212
+ db_instance.revision = 1
1213
+ db.add(db_instance)
1214
+ db.flush()
1215
+
1216
+ db_instance.record_sid = db_instance.sid
1217
+ db.flush()
1218
+
1219
+ return db_instance
1220
+
1221
+
1222
+ def db_create_revision_models[RevisionModelT: RevisionModelMixin](
1223
+ db: sa_orm.Session,
1224
+ model: type[RevisionModelT],
1225
+ instances: list[RevisionModelMixin],
1226
+ created_at: datetime.datetime,
1227
+ ) -> list[RevisionModelT]:
1228
+ db_instances = [clone_revision_model_instance(model, instance) for instance in instances]
1229
+ for db_instance in db_instances:
1230
+ db_instance.created_at = created_at
1231
+ db_instance.updated_at = created_at
1232
+ db_instance.expired_at = None
1233
+ db_instance.revision = 1
1234
+ db.add_all(db_instances)
1235
+ db.flush()
1236
+
1237
+ for db_instance in db_instances:
1238
+ db_instance.record_sid = db_instance.sid
1239
+ db.flush()
1240
+
1241
+ return db_instances
1242
+
1243
+
1244
+ def db_read_revision_models_of_record[RevisionModelT: RevisionModelMixin](
1245
+ db: sa_orm.Session,
1246
+ model: type[RevisionModelT],
1247
+ record_sid: int,
1248
+ ) -> list[RevisionModelT]:
1249
+ return (
1250
+ db
1251
+ .query(model)
1252
+ .where(model.record_sid == record_sid)
1253
+ .order_by(model.revision.desc())
1254
+ .all()
1255
+ )
1256
+
1257
+
1258
+ def db_read_latest_revision_model_of_record[RevisionModelT: RevisionModelMixin](
1259
+ db: sa_orm.Session,
1260
+ model: type[RevisionModelT],
1261
+ record_sid: int,
1262
+ ) -> RevisionModelT:
1263
+ db_instance = (
1264
+ db
1265
+ .query(model)
1266
+ .where(model.record_sid == record_sid)
1267
+ .order_by(model.revision.desc())
1268
+ .first()
1269
+ )
1270
+ if db_instance is None:
1271
+ raise sa_exc.NoResultFound(f"'{model_name_of(model)}' of specified record_sid '{record_sid}' not found")
1272
+
1273
+ return db_instance
1274
+
1275
+
1276
+ def db_read_active_revision_model_of_record[RevisionModelT: RevisionModelMixin](
1277
+ db: sa_orm.Session,
1278
+ model: type[RevisionModelT],
1279
+ record_sid: int,
1280
+ ) -> RevisionModelT:
1281
+ db_instance = db.query(model).where(model.record_sid == record_sid, model.expired_at.is_(None)).one_or_none()
1282
+ if db_instance is None:
1283
+ raise sa_exc.NoResultFound(f"Active '{model_name_of(model)}' of specified record_sid '{record_sid}' not found")
1284
+
1285
+ return db_instance
1286
+
1287
+
1288
+ def db_read_expired_revision_models_of_record[RevisionModelT: RevisionModelMixin](
1289
+ db: sa_orm.Session,
1290
+ model: type[RevisionModelT],
1291
+ record_sid: int,
1292
+ ) -> list[RevisionModelT]:
1293
+ return (
1294
+ db
1295
+ .query(model)
1296
+ .where(model.record_sid == record_sid, model.expired_at.is_not(None))
1297
+ .order_by(model.revision.desc())
1298
+ .all()
1299
+ )
1300
+
1301
+
1302
+ def db_read_latest_revision_models[RevisionModelT: RevisionModelMixin](
1303
+ db: sa_orm.Session,
1304
+ model: type[RevisionModelT],
1305
+ skip: int | None = None,
1306
+ limit: int | None = None,
1307
+ order_by: list[str] | None = None,
1308
+ ) -> list[RevisionModelT]:
1309
+ subquery = (
1310
+ db
1311
+ .query(model.record_sid,
1312
+ sa.func.max(model.revision).label("max_revision"))
1313
+ .group_by(model.record_sid)
1314
+ .subquery()
1315
+ )
1316
+
1317
+ query = (
1318
+ db
1319
+ .query(model)
1320
+ .join(subquery,
1321
+ sa.and_(model.record_sid == subquery.c.record_sid, model.revision == subquery.c.max_revision))
1322
+ .order_by(*db_make_order_by_clause(model, order_by))
1323
+ )
1324
+ if skip is not None:
1325
+ query = query.offset(skip)
1326
+ if limit is not None:
1327
+ query = query.limit(limit)
1328
+ return query.all()
1329
+
1330
+
1331
+ def db_read_active_revision_models[RevisionModelT: RevisionModelMixin](
1332
+ db: sa_orm.Session,
1333
+ model: type[RevisionModelT],
1334
+ skip: int | None = None,
1335
+ limit: int | None = None,
1336
+ order_by: list[str] | None = None,
1337
+ ) -> list[RevisionModelT]:
1338
+ query = db.query(model).where(model.expired_at.is_(None)).order_by(*db_make_order_by_clause(model, order_by))
1339
+ if skip is not None:
1340
+ query = query.offset(skip)
1341
+ if limit is not None:
1342
+ query = query.limit(limit)
1343
+ return query.all()
1344
+
1345
+
1346
+ def db_update_revision_model[RevisionModelT: RevisionModelMixin](
1347
+ db: sa_orm.Session,
1348
+ model: type[RevisionModelT],
1349
+ instance: RevisionModelMixin,
1350
+ record_sid: int,
1351
+ updated_at: datetime.datetime,
1352
+ ) -> RevisionModelT:
1353
+ db_instance = db.query(model).where(model.record_sid == record_sid, model.expired_at.is_(None)).one_or_none()
1354
+ if db_instance is None:
1355
+ raise sa_exc.NoResultFound(f"Active '{model_name_of(model)}' of specified record_sid '{record_sid}' not found")
1356
+
1357
+ db_instance.expired_at = updated_at
1358
+ db_instance = clone_revision_model_instance(model, db_instance, clear_meta_fields=False, inplace=True)
1359
+ db.flush()
1360
+
1361
+ db_new_instance = clone_revision_model_instance(model, instance)
1362
+ db_new_instance.record_sid = record_sid
1363
+ db_new_instance.created_at = db_instance.created_at
1364
+ db_new_instance.updated_at = updated_at
1365
+ db_new_instance.expired_at = None
1366
+ db_new_instance.revision = db_instance.revision + 1
1367
+ db.add(db_new_instance)
1368
+ db.flush()
1369
+
1370
+ return db_new_instance
1371
+
1372
+
1373
+ def db_expire_revision_model[RevisionModelT: RevisionModelMixin](
1374
+ db: sa_orm.Session,
1375
+ model: type[RevisionModelT],
1376
+ record_sid: int,
1377
+ updated_at: datetime.datetime,
1378
+ ) -> RevisionModelT:
1379
+ db_instance = (
1380
+ db
1381
+ .query(model)
1382
+ .where(model.record_sid == record_sid, model.expired_at.is_(None))
1383
+ .one_or_none()
1384
+ )
1385
+ if db_instance is None:
1386
+ raise sa_exc.NoResultFound(f"Active '{model_name_of(model)}' of specified record_sid '{record_sid}' not found")
1387
+
1388
+ db_instance.expired_at = updated_at
1389
+ db_instance = clone_revision_model_instance(model, db_instance, clear_meta_fields=False, inplace=True)
1390
+ db.flush()
1391
+
1392
+ return db_instance
1393
+
1394
+
1395
+ def db_activate_revision_model[RevisionModelT: RevisionModelMixin](
1396
+ db: sa_orm.Session,
1397
+ model: type[RevisionModelT],
1398
+ record_sid: int,
1399
+ updated_at: datetime.datetime,
1400
+ ) -> RevisionModelT:
1401
+ db_instance = db.query(model).where(model.record_sid == record_sid, model.expired_at.is_(None)).one_or_none()
1402
+ if db_instance is not None:
1403
+ raise sa_exc.MultipleResultsFound(
1404
+ f"Active '{model_name_of(model)}' of specified record_sid '{record_sid}' already exists")
1405
+
1406
+ db_instance = (
1407
+ db
1408
+ .query(model)
1409
+ .where(model.record_sid == record_sid, model.expired_at.is_not(None))
1410
+ .order_by(model.revision.desc())
1411
+ .first()
1412
+ )
1413
+ if db_instance is None:
1414
+ raise sa_exc.NoResultFound(f"Expired '{model_name_of(model)}' of specified record_sid '{record_sid}' not found")
1415
+
1416
+ db_new_instance = clone_revision_model_instance(model, db_instance)
1417
+ db_new_instance.record_sid = record_sid
1418
+ db_new_instance.created_at = db_instance.created_at
1419
+ db_new_instance.updated_at = db_instance.expired_at
1420
+ db_new_instance.expired_at = updated_at
1421
+ db_new_instance.revision = db_instance.revision + 1
1422
+ db_new_instance = clone_revision_model_instance(model, db_new_instance, clear_meta_fields=False, inplace=True)
1423
+ db_new_instance.updated_at = updated_at
1424
+ db_new_instance.expired_at = None
1425
+ db.add(db_new_instance)
1426
+ db.flush()
1427
+
1428
+ return db_new_instance