plexus-python-common 1.0.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of plexus-python-common might be problematic.

@@ -0,0 +1,195 @@
+ import datetime
+ from collections.abc import Callable, Generator
+ from typing import Any
+
+ import pyparsing as pp
+ import ujson as json
+ from iker.common.utils.funcutils import singleton
+ from iker.common.utils.jsonutils import JsonType
+ from iker.common.utils.randutils import randomizer
+
+ from plexus.common.utils.strutils import BagName, UserName, VehicleName
+ from plexus.common.utils.strutils import dot_case_parser, kebab_case_parser, snake_case_parser
+ from plexus.common.utils.strutils import hex_string_parser
+ from plexus.common.utils.strutils import parse_bag_name, parse_user_name, parse_vehicle_name
+ from plexus.common.utils.strutils import strict_abspath_parser, strict_relpath_parser
+ from plexus.common.utils.strutils import tag_parser, topic_parser, vin_code_chars, vin_code_parser
+ from plexus.common.utils.strutils import uuid_parser
+
+
+ def make_compute_vin_code_check_digit() -> Callable[[str], str]:
+     trans_nums = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 7, 9, 2, 3, 4, 5, 6, 7, 8, 9]
+     weights = [8, 7, 6, 5, 4, 3, 2, 10, 0, 9, 8, 7, 6, 5, 4, 3, 2]
+
+     trans_dict = {vin_code_char: trans_num for vin_code_char, trans_num in zip(vin_code_chars, trans_nums)}
+
+     def func(vin_code: str) -> str:
+         remainder = sum(trans_dict[vin_code_char] * weight for vin_code_char, weight in zip(vin_code, weights)) % 11
+         return "X" if remainder == 10 else str(remainder)
+
+     return func
+
+
+ compute_vin_code_check_digit = make_compute_vin_code_check_digit()
+
+
+ def make_validate_string(element: pp.ParserElement) -> Callable[[str], None]:
+     def func(s: str) -> None:
+         try:
+             if not element.parse_string(s, parse_all=True):
+                 raise ValueError(f"failed to parse '{s}'")
+         except Exception as e:
+             raise ValueError(f"encountered error while parsing '{s}'") from e
+
+     return func
+
+
+ validate_hex_string = make_validate_string(hex_string_parser)
+
+ validate_snake_case = make_validate_string(snake_case_parser)
+ validate_kebab_case = make_validate_string(kebab_case_parser)
+ validate_dot_case = make_validate_string(dot_case_parser)
+
+ validate_uuid = make_validate_string(uuid_parser)
+
+ validate_strict_relpath = make_validate_string(strict_relpath_parser)
+ validate_strict_abspath = make_validate_string(strict_abspath_parser)
+
+ validate_tag = make_validate_string(tag_parser)
+ validate_topic = make_validate_string(topic_parser)
+
+
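For reference, a minimal sketch of how the validators built above are consumed (this hunk is imported elsewhere as plexus.common.utils.datautils, judging by the third hunk's imports). Each validator returns None on success and raises ValueError otherwise, so a boolean wrapper is a thin try/except:

from plexus.common.utils.datautils import validate_topic

def is_valid_topic(candidate: str) -> bool:
    # validate_topic parses the whole string against topic_parser and raises on failure.
    try:
        validate_topic(candidate)
        return True
    except ValueError:
        return False

The other validators (validate_snake_case, validate_uuid, and so on) behave the same way; which strings each one accepts is defined by the pyparsing grammars in plexus.common.utils.strutils, which are not shown in this diff.
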
+ def validate_vin_code(vin_code: str):
+     make_validate_string(vin_code_parser)(vin_code)
+     check_digit = compute_vin_code_check_digit(vin_code)
+     if check_digit != vin_code[8]:
+         raise ValueError(f"get wrong VIN code check digit from '{vin_code}', expected '{check_digit}'")
+
+
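As an illustration of the rule enforced above: the ninth character of a 17-character VIN must equal the weighted-remainder check digit, with a remainder of 10 written as "X". Assuming vin_code_chars from strutils is the standard 33-character VIN alphabet (digits 0-9 followed by the letters excluding I, O and Q, in order, which is what the transliteration table above lines up with), the first VIN from known_vehicle_names() further down checks out:

vin = "3AKJGLD5XLS000000"  # first entry of known_vehicle_names() below
assert compute_vin_code_check_digit(vin) == "X" == vin[8]
validate_vin_code(vin)                        # passes: grammar and check digit agree

corrupted = vin[:8] + "1" + vin[9:]
validate_vin_code(corrupted)                  # raises ValueError: check digit mismatch
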
+ def make_validate_parse_string(parse: Callable[[str], Any]) -> Callable[[str], None]:
+     def func(s: str) -> None:
+         try:
+             if not parse(s):
+                 raise ValueError(f"failed to parse '{s}'")
+         except Exception as e:
+             raise ValueError(f"encountered error while parsing '{s}'") from e
+
+     return func
+
+
+ validate_user_name = make_validate_parse_string(parse_user_name)
+ validate_vehicle_name = make_validate_parse_string(parse_vehicle_name)
+ validate_bag_name = make_validate_parse_string(parse_bag_name)
+
+
+ def validate_dt_timezone(dt: datetime.datetime):
+     if dt.tzinfo != datetime.timezone.utc:
+         raise ValueError(f"dt '{dt}' is not in UTC")
+
+
+ def validate_json_type_dump_size(json_type: JsonType, dump_size_limit: int = 10000):
+     dump_string = json.dumps(json_type, ensure_ascii=False)
+     if len(dump_string) > dump_size_limit:
+         raise ValueError(f"dump size exceeds the maximum length '{dump_size_limit}'")
+
+
+ def random_vin_code() -> str:
+     vin_code = randomizer().random_string(vin_code_chars, 17)
+     check_digit = compute_vin_code_check_digit(vin_code)
+     return vin_code[:8] + check_digit + vin_code[9:]
+
+
+ @singleton
+ def known_topics() -> list[str]:
+     return [
+         "/sensor/camera/front_center",
+         "/sensor/camera/front_left",
+         "/sensor/camera/front_right",
+         "/sensor/camera/side_left",
+         "/sensor/camera/side_right",
+         "/sensor/camera/rear_left",
+         "/sensor/camera/rear_right",
+         "/sensor/lidar/front_center",
+         "/sensor/lidar/front_left_corner",
+         "/sensor/lidar/front_right_corner",
+         "/sensor/lidar/side_left",
+         "/sensor/lidar/side_right",
+     ]
+
+
+ @singleton
+ def known_user_names() -> list[UserName]:
+     return [
+         UserName("adam", "anderson"),
+         UserName("ben", "bennett"),
+         UserName("charlie", "clark"),
+         UserName("david", "dixon"),
+         UserName("evan", "edwards"),
+         UserName("frank", "fisher"),
+         UserName("george", "graham"),
+         UserName("henry", "harrison"),
+         UserName("isaac", "irving"),
+         UserName("jack", "jacobs"),
+         UserName("kevin", "kennedy"),
+         UserName("luke", "lawson"),
+         UserName("michael", "mitchell"),
+         UserName("nathan", "newton"),
+         UserName("oscar", "owens"),
+         UserName("paul", "peterson"),
+         UserName("quincy", "quinn"),
+         UserName("ryan", "robinson"),
+         UserName("sam", "stevens"),
+         UserName("tom", "thomas"),
+         UserName("umar", "underwood"),
+         UserName("victor", "vaughan"),
+         UserName("william", "walker"),
+         UserName("xander", "xavier"),
+         UserName("yale", "young"),
+         UserName("zane", "zimmerman"),
+     ]
+
+
+ @singleton
+ def known_vehicle_names() -> list[VehicleName]:
+     return [
+         VehicleName("cascadia", "antelope", "00000", "3AKJGLD5XLS000000"),
+         VehicleName("cascadia", "bear", "00001", "3AKJGLD51LS000001"),
+         VehicleName("cascadia", "cheetah", "00002", "3AKJGLD53LS000002"),
+         VehicleName("cascadia", "dolphin", "00003", "3AKJGLD55LS000003"),
+         VehicleName("cascadia", "eagle", "00004", "3AKJGLD57LS000004"),
+         VehicleName("cascadia", "falcon", "00005", "3AKJGLD59LS000005"),
+         VehicleName("cascadia", "gorilla", "00006", "3AKJGLD50LS000006"),
+         VehicleName("cascadia", "hawk", "00007", "3AKJGLD52LS000007"),
+         VehicleName("cascadia", "iguana", "00008", "3AKJGLD54LS000008"),
+         VehicleName("cascadia", "jaguar", "00009", "3AKJGLD56LS000009"),
+         VehicleName("cascadia", "koala", "00010", "3AKJGLD52LS000010"),
+         VehicleName("cascadia", "leopard", "00011", "3AKJGLD54LS000011"),
+         VehicleName("cascadia", "mongoose", "00012", "3AKJGLD56LS000012"),
+         VehicleName("cascadia", "narwhal", "00013", "3AKJGLD58LS000013"),
+         VehicleName("cascadia", "otter", "00014", "3AKJGLD5XLS000014"),
+         VehicleName("cascadia", "panther", "00015", "3AKJGLD51LS000015"),
+         VehicleName("cascadia", "quail", "00016", "3AKJGLD53LS000016"),
+         VehicleName("cascadia", "rhino", "00017", "3AKJGLD55LS000017"),
+         VehicleName("cascadia", "snake", "00018", "3AKJGLD57LS000018"),
+         VehicleName("cascadia", "tiger", "00019", "3AKJGLD59LS000019"),
+         VehicleName("cascadia", "urial", "00020", "3AKJGLD55LS000020"),
+         VehicleName("cascadia", "vulture", "00021", "3AKJGLD57LS000021"),
+         VehicleName("cascadia", "wolf", "00022", "3AKJGLD59LS000022"),
+         VehicleName("cascadia", "xerus", "00023", "3AKJGLD50LS000023"),
+         VehicleName("cascadia", "yak", "00024", "3AKJGLD52LS000024"),
+         VehicleName("cascadia", "zebra", "00025", "3AKJGLD54LS000025"),
+     ]
+
+
+ def random_bag_names_sequence(
+     min_record_dt: datetime.datetime,
+     max_record_dt: datetime.datetime,
+     min_sequence_length: int,
+     max_sequence_length: int,
+ ) -> Generator[BagName, None, None]:
+     vehicle_name = randomizer().choose(known_vehicle_names())
+     record_dt = randomizer().random_datetime(begin=min_record_dt, end=max_record_dt)
+     bags_count = randomizer().next_int(min_sequence_length, max_sequence_length)
+
+     for record_sn in range(bags_count):
+         yield BagName(vehicle_name=vehicle_name, record_dt=record_dt, record_sn=record_sn)
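
A usage sketch for the generator above (the datetime bounds are arbitrary, and BagName is assumed to expose its constructor arguments as attributes):

import datetime

bag_names = list(random_bag_names_sequence(
    min_record_dt=datetime.datetime(2024, 1, 1, tzinfo=datetime.timezone.utc),
    max_record_dt=datetime.datetime(2024, 12, 31, tzinfo=datetime.timezone.utc),
    min_sequence_length=1,
    max_sequence_length=5,
))

# Every bag in one sequence shares the vehicle and record time; only record_sn increments.
assert all(bag.vehicle_name == bag_names[0].vehicle_name for bag in bag_names)
assert [bag.record_sn for bag in bag_names] == list(range(len(bag_names)))
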
@@ -0,0 +1,92 @@
+ import datetime
+ import os
+ from collections.abc import Generator, Iterable
+ from typing import Any
+
+ import ujson as json
+ from iker.common.utils.dtutils import dt_format, dt_parse, extended_format
+ from iker.common.utils.jsonutils import JsonType, JsonValueCompatible
+ from iker.common.utils.jsonutils import json_reformat
+ from iker.common.utils.sequtils import batch_yield
+
+ from plexus.common.utils.shutils import collect_volumed_filenames, populate_volumed_filenames
+
+ __all__ = [
+     "json_datetime_decoder",
+     "json_datetime_encoder",
+     "json_loads",
+     "json_dumps",
+     "read_chunked_jsonl",
+     "write_chunked_jsonl",
+ ]
+
+
+ def json_datetime_decoder(v: Any) -> datetime.datetime:
+     if isinstance(v, str):
+         return json_datetime_decoder(dt_parse(v, extended_format(with_us=True, with_tz=True)))
+     if isinstance(v, datetime.datetime):
+         return v.replace(tzinfo=datetime.timezone.utc)
+     raise ValueError("unexpected type of value for datetime decoder")
+
+
+ def json_datetime_encoder(v: Any) -> str:
+     if isinstance(v, str):
+         return json_datetime_encoder(dt_parse(v, extended_format(with_us=True, with_tz=True)))
+     if isinstance(v, datetime.datetime):
+         return dt_format(v.replace(tzinfo=datetime.timezone.utc), extended_format(with_us=True, with_tz=True))
+     raise ValueError("unexpected type of value for datetime encoder")
+
+
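A small round-trip sketch for the two helpers above. The wire format itself is whatever iker's extended_format(with_us=True, with_tz=True) produces; the property relied on here is that encoding and then decoding returns the same UTC instant:

import datetime

dt = datetime.datetime(2024, 5, 1, 12, 30, 45, 123456, tzinfo=datetime.timezone.utc)

encoded = json_datetime_encoder(dt)       # timestamp string with microseconds and timezone
decoded = json_datetime_decoder(encoded)  # back to a tz-aware UTC datetime

assert isinstance(encoded, str)
assert decoded == dt
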
+ def json_deserializer(obj):
+     def value_formatter(value: JsonValueCompatible) -> JsonType:
+         if not isinstance(value, str):
+             return value
+         try:
+             return dt_parse(value, extended_format(with_us=True, with_tz=True))
+         except Exception:
+             return value
+
+     return json_reformat(obj, value_formatter=value_formatter)
+
+
+ def json_serializer(obj):
+     def unregistered_formatter(unregistered: Any) -> JsonType:
+         if isinstance(unregistered, datetime.datetime):
+             return dt_format(unregistered, extended_format(with_us=True, with_tz=True))
+         return None
+
+     return json_reformat(obj, raise_if_unregistered=False, unregistered_formatter=unregistered_formatter)
+
+
+ def json_loads(s: str) -> JsonType:
+     return json_deserializer(json.loads(s))
+
+
+ def json_dumps(obj: JsonType) -> str:
+     return json.dumps(json_serializer(obj), ensure_ascii=False, escape_forward_slashes=False)
+
+
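Taken together, json_dumps and json_loads give a JSON round trip in which datetime values inside nested structures survive: json_serializer formats them on the way out, and json_deserializer re-parses any string that matches the datetime format on the way in (assuming iker's json_reformat walks nested containers). A sketch:

import datetime

record = {
    "name": "calibration-run",
    "created_at": datetime.datetime(2024, 5, 1, 12, 0, 0, tzinfo=datetime.timezone.utc),
    "topics": ["/sensor/camera/front_center", "/sensor/lidar/front_center"],
}

restored = json_loads(json_dumps(record))

assert restored["created_at"] == record["created_at"]   # datetime survives the round trip
assert restored["topics"] == record["topics"]           # plain strings are left untouched
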
+ def read_chunked_jsonl(template: str) -> Generator[tuple[JsonType, str], None, None]:
+     for path, _ in collect_volumed_filenames(template):
+         with open(path, mode="r", encoding="utf-8") as fh:
+             for line in fh:
+                 yield json_loads(line), path
+
+
+ def write_chunked_jsonl(records: Iterable[JsonType], template: str, chunk_size: int) -> list[tuple[str, int]]:
+     generator = populate_volumed_filenames(template)
+     entry = []
+     for batch_index, batch in enumerate(batch_yield(records, chunk_size)):
+         path, _ = next(generator)
+         lines = 0
+         with open(path, mode="w") as fh:
+             for record in batch:
+                 fh.write(json_dumps(record))
+                 fh.write("\n")
+                 lines += 1
+         entry.append((path, lines))
+     if len(entry) == 1:
+         path, lines = entry[0]
+         os.rename(path, template)
+         return [(template, lines)]
+     return entry
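
A usage sketch for the chunked-JSONL pair above. The template string is interpreted by populate_volumed_filenames / collect_volumed_filenames from plexus.common.utils.shutils, whose placeholder syntax is not shown in this diff, so the template below is purely illustrative:

records = [{"index": i} for i in range(10)]

# Write at most 4 records per chunk file; the return value lists (path, line_count) pairs.
written = write_chunked_jsonl(records, "/tmp/records.jsonl", chunk_size=4)
for path, line_count in written:
    print(path, line_count)

# Stream every record back, along with the chunk file it was read from.
for record, path in read_chunked_jsonl("/tmp/records.jsonl"):
    print(path, record["index"])

Note the single-chunk special case in write_chunked_jsonl: when everything fits into one chunk, the output file is renamed to the template path itself.
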
@@ -0,0 +1,335 @@
+ import datetime
+ from typing import Self, TypeVar
+
+ import pydantic as pdt
+ import sqlalchemy as sa
+ import sqlalchemy.dialects.postgresql as sa_pg
+ from sqlmodel import Field, SQLModel
+
+ from plexus.common.utils.datautils import validate_dt_timezone
+ from plexus.common.utils.jsonutils import json_datetime_encoder
+
+ __all__ = [
+     "compare_postgresql_types",
+     "validate_model_extended",
+     "collect_model_tables",
+     "model_copy_from",
+     "make_base_model",
+     "make_serial_model_mixin",
+     "make_record_model_mixin",
+     "make_snapshot_model_mixin",
+     "serial_model_mixin",
+     "record_model_mixin",
+     "snapshot_model_mixin",
+     "SerialModel",
+     "RecordModel",
+     "SnapshotModel",
+     "make_snapshot_model_trigger",
+ ]
+
+ ModelT = TypeVar("ModelT", bound=SQLModel)
+ ModelU = TypeVar("ModelU", bound=SQLModel)
+
+
+ def compare_postgresql_types(type_a, type_b) -> bool:
+     """
+     Compares two Postgresql-specific column types to determine if they are equivalent.
+     This includes types from sqlalchemy.dialects.postgresql like ARRAY, JSON, UUID, etc.
+     """
+     if not isinstance(type_a, type(type_b)):
+         return False
+     if isinstance(type_a, sa_pg.ARRAY):
+         return compare_postgresql_types(type_a.item_type, type_b.item_type)
+     if isinstance(type_a, (sa_pg.VARCHAR, sa_pg.CHAR, sa_pg.TEXT)):
+         return type_a.length == type_b.length
+     if isinstance(type_a, (sa_pg.TIMESTAMP, sa_pg.TIME)):
+         return type_a.timezone == type_b.timezone
+     if isinstance(type_a, sa_pg.NUMERIC):
+         return type_a.precision == type_b.precision and type_a.scale == type_b.scale
+     return type(type_a) in {
+         sa_pg.BOOLEAN,
+         sa_pg.INTEGER,
+         sa_pg.BIGINT,
+         sa_pg.SMALLINT,
+         sa_pg.FLOAT,
+         sa_pg.DOUBLE_PRECISION,
+         sa_pg.REAL,
+         sa_pg.DATE,
+         sa_pg.UUID,
+         sa_pg.JSON,
+         sa_pg.JSONB,
+         sa_pg.HSTORE,
+     }
+
+
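The comparison above is structural rather than identity-based: parameterized types match only when their parameters agree, while the bare types in the final set match on class alone. A few illustrative calls:

assert compare_postgresql_types(sa_pg.VARCHAR(255), sa_pg.VARCHAR(255))
assert not compare_postgresql_types(sa_pg.VARCHAR(255), sa_pg.VARCHAR(64))               # lengths differ
assert not compare_postgresql_types(sa_pg.TIMESTAMP(timezone=True), sa_pg.TIMESTAMP())   # timezone flag differs
assert compare_postgresql_types(sa_pg.ARRAY(sa_pg.INTEGER), sa_pg.ARRAY(sa_pg.INTEGER))  # compares item types
assert not compare_postgresql_types(sa_pg.INTEGER(), sa_pg.BIGINT())                     # different classes
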
+ def validate_model_extended(model_base: type[SQLModel], model_extended: type[SQLModel]) -> bool:
+     """
+     Validates if `model_extended` is an extension of `model_base` by checking if all fields in `model_base`
+     are present in `model_extended` with compatible types.
+
+     :param model_base: The base model class to compare against.
+     :param model_extended: The model class that is expected to extend the base model.
+     :return: True if `model_extended` extends `model_base` correctly, False otherwise.
+     """
+     columns_a = {column.name: column.type for column in model_base.__table__.columns}
+     columns_b = {column.name: column.type for column in model_extended.__table__.columns}
+
+     for field_a, field_a_type in columns_a.items():
+         field_b_type = columns_b.get(field_a)
+         if field_b_type is None or not compare_postgresql_types(field_a_type, field_b_type):
+             return False
+     return True
+
+
+ def collect_model_tables(*models: ModelT) -> sa.MetaData:
+     metadata = sa.MetaData()
+     for base in models:
+         for table in base.metadata.tables.values():
+             table.to_metadata(metadata)
+     return metadata
+
+
+ def model_copy_from(dst: ModelT, src: ModelU, **kwargs) -> ModelT:
+     if not isinstance(dst, SQLModel) or not isinstance(src, SQLModel):
+         raise TypeError("both 'dst' and 'src' must be instances of SQLModel or its subclasses")
+
+     for field, value in src.model_dump(**kwargs).items():
+         if field not in dst.model_fields:
+             continue
+         # Skip fields that are not present in the destination model
+         if value is None and dst.model_fields[field].required:
+             raise ValueError(f"field '{field}' is required but got None")
+
+         # Only set the field if it exists in the destination model
+         if hasattr(dst, field):
+             # If the field is a SQLModel, recursively copy it
+             if isinstance(value, SQLModel):
+                 value = model_copy_from(getattr(dst, field), value, **kwargs)
+             elif isinstance(value, list) and all(isinstance(item, SQLModel) for item in value):
+                 value = [model_copy_from(dst_item, src_item, **kwargs)
+                          for dst_item, src_item in zip(getattr(dst, field), value)]
+
+             setattr(dst, field, value)
+
+     return dst
+
+
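A sketch of model_copy_from with two plain (non-table) SQLModel classes; the class and field names are illustrative. Fields present on both sides are copied from src onto dst, while fields the destination does not declare are skipped:

class UserIn(SQLModel):
    name: str
    email: str
    nickname: str | None = None

class UserRow(SQLModel):
    name: str
    email: str

src = UserIn(name="adam", email="adam@anderson.example", nickname="ad")
dst = UserRow(name="", email="")

model_copy_from(dst, src)
assert dst.name == "adam" and dst.email == "adam@anderson.example"
# "nickname" has no counterpart on UserRow and is silently skipped.
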
+ def make_base_model() -> type[SQLModel]:
+     """
+     Creates a base SQLModel class with custom metadata and JSON encoding for datetime fields.
+     Use this as a base for all models that require these configurations.
+     """
+
+     class BaseModel(SQLModel):
+         metadata = sa.MetaData()
+         model_config = pdt.ConfigDict(json_encoders={datetime.datetime: json_datetime_encoder})
+
+     return BaseModel
+
+
+ def make_serial_model_mixin() -> type[SQLModel]:
+     """
+     Creates a mixin class for SQLModel models that adds a unique identifier field `sid`.
+     Use this mixin to add an auto-incremented primary key to your models.
+     """
+
+     class ModelMixin(SQLModel):
+         sid: int | None = Field(
+             sa_column=sa.Column(sa_pg.BIGINT, primary_key=True, autoincrement=True),
+             default=None,
+             description="Unique auto-incremented primary key for the record",
+         )
+
+     return ModelMixin
+
+
+ def make_record_model_mixin() -> type[SQLModel]:
+     """
+     Creates a mixin class for SQLModel models that adds common fields and validation logic for updatable records.
+     This mixin includes `sid`, `created_at`, and `updated_at` fields, along with validation for timestamps.
+     """
+
+     class ModelMixin(SQLModel):
+         sid: int | None = Field(
+             sa_column=sa.Column(sa_pg.BIGINT, primary_key=True, autoincrement=True),
+             default=None,
+             description="Unique auto-incremented primary key for the record",
+         )
+         created_at: datetime.datetime | None = Field(
+             sa_column=sa.Column(sa_pg.TIMESTAMP(timezone=True)),
+             default=None,
+             description="Timestamp (with timezone) when the record was created",
+         )
+         updated_at: datetime.datetime | None = Field(
+             sa_column=sa.Column(sa_pg.TIMESTAMP(timezone=True)),
+             default=None,
+             description="Timestamp (with timezone) when the record was last updated",
+         )
+
+         @pdt.field_validator("created_at", mode="after")
+         @classmethod
+         def validate_created_at(cls, v: datetime.datetime) -> datetime.datetime:
+             if v is not None:
+                 validate_dt_timezone(v)
+             return v
+
+         @pdt.field_validator("updated_at", mode="after")
+         @classmethod
+         def validate_updated_at(cls, v: datetime.datetime) -> datetime.datetime:
+             if v is not None:
+                 validate_dt_timezone(v)
+             return v
+
+         @pdt.model_validator(mode="after")
+         @classmethod
+         def validate_created_at_updated_at(cls, m: Self) -> Self:
+             if m.created_at is not None and m.updated_at is not None and m.created_at > m.updated_at:
+                 raise ValueError(f"create time '{m.created_at}' is greater than update time '{m.updated_at}'")
+             return m
+
+         @classmethod
+         def make_created_at_index(cls, index_name: str) -> sa.Index:
+             """
+             Helper to create an index on the `created_at` field with the given index name.
+             """
+             return sa.Index(index_name, "created_at")
+
+     return ModelMixin
+
+
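A sketch of how these factories are meant to be combined into a concrete table (mirroring the SerialModel / RecordModel / SnapshotModel declarations further down); the table name, extra column, and index name here are illustrative:

class VehicleRecord(make_base_model(), make_record_model_mixin(), table=True):
    __tablename__ = "vehicle_record"
    __table_args__ = (make_record_model_mixin().make_created_at_index("ix_vehicle_record_created_at"),)

    vin_code: str = Field(sa_column=sa.Column(sa_pg.VARCHAR(17), nullable=False))

# sid, created_at and updated_at columns, plus their timestamp validators, come from the mixin.
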
+ def make_snapshot_model_mixin() -> type[SQLModel]:
+     """
+     Provides a mixin class for SQLModel models that adds common fields and validation logic for record snapshots.
+     A snapshot model tracks the full change history of an entity: when any field changes, the current record (with a
+     NULL expiration time) is updated to set its expiration time, and a new record with the updated values is created.
+
+     The mixin includes the following fields:
+     - `sid`: Unique, auto-incremented primary key identifying each snapshot of the record in the change history.
+     - `created_at`: Time (with timezone) when this snapshot of the record was created and became active.
+     - `expired_at`: Time (with timezone) when this snapshot of the record was superseded or became inactive;
+       `None` if still active.
+     - `record_sid`: Foreign key to the record this snapshot belongs to; used to link snapshots together.
+     """
+
+     class ModelMixin(SQLModel):
+         sid: int | None = Field(
+             sa_column=sa.Column(sa_pg.BIGINT, primary_key=True, autoincrement=True),
+             default=None,
+             description="Unique auto-incremented primary key for each record snapshot",
+         )
+         created_at: datetime.datetime | None = Field(
+             sa_column=sa.Column(sa_pg.TIMESTAMP(timezone=True)),
+             default=None,
+             description="Timestamp (with timezone) when this record snapshot became active",
+         )
+         expired_at: datetime.datetime | None = Field(
+             sa_column=sa.Column(sa_pg.TIMESTAMP(timezone=True)),
+             default=None,
+             description="Timestamp (with timezone) when this record snapshot became inactive; None if still active",
+         )
+         record_sid: int | None = Field(
+             sa_column=sa.Column(sa_pg.BIGINT, nullable=True),
+             default=None,
+             description="Foreign key to the record this snapshot belongs to",
+         )
+
+         @pdt.field_validator("created_at", mode="after")
+         @classmethod
+         def validate_created_at(cls, v: datetime.datetime) -> datetime.datetime:
+             if v is not None:
+                 validate_dt_timezone(v)
+             return v
+
+         @pdt.field_validator("expired_at", mode="after")
+         @classmethod
+         def validate_expired_at(cls, v: datetime.datetime) -> datetime.datetime:
+             if v is not None:
+                 validate_dt_timezone(v)
+             return v
+
+         @pdt.model_validator(mode="after")
+         @classmethod
+         def validate_created_at_expired_at(cls, m: Self) -> Self:
+             if m.created_at is not None and m.expired_at is not None and m.created_at > m.expired_at:
+                 raise ValueError(f"create time '{m.created_at}' is greater than expire time '{m.expired_at}'")
+             return m
+
+         @classmethod
+         def make_created_at_expired_at_index(cls, index_name: str) -> sa.Index:
+             return sa.Index(index_name, "created_at", "expired_at")
+
+     return ModelMixin
+
+
+ serial_model_mixin = make_serial_model_mixin()
+ record_model_mixin = make_record_model_mixin()
+ snapshot_model_mixin = make_snapshot_model_mixin()
+
+
+ class SerialModel(make_base_model(), make_serial_model_mixin(), table=True):
+     pass
+
+
+ class RecordModel(make_base_model(), make_record_model_mixin(), table=True):
+     pass
+
+
+ class SnapshotModel(make_base_model(), make_snapshot_model_mixin(), table=True):
+     pass
+
+
+ def make_snapshot_model_trigger(engine: sa.Engine, model: type[SQLModel]):
+     table_name = model.__tablename__
+     if not table_name:
+         raise ValueError("missing '__tablename__' attribute")
+
+     if not validate_model_extended(SnapshotModel, model):
+         raise ValueError("not an extended model of 'SnapshotModel'")
+
+     record_sid_seq_name = f"{table_name}_record_sid_seq"
+     snapshot_auto_update_function_name = f"{table_name}_snapshot_auto_update_function"
+     snapshot_auto_update_trigger_name = f"{table_name}_snapshot_auto_update_trigger"
+
+     # language=postgresql
+     create_record_sid_seq_sql = f"""
+     CREATE SEQUENCE "{record_sid_seq_name}"
+     START WITH 1 INCREMENT BY 1 NO MINVALUE NO MAXVALUE CACHE 1;
+     """
+
+     # language=postgresql
+     create_snapshot_auto_update_function_sql = f"""
+     CREATE FUNCTION "{snapshot_auto_update_function_name}"()
+     RETURNS TRIGGER AS $$
+     BEGIN
+         IF NEW."record_sid" IS NULL THEN
+             NEW."record_sid" := nextval('{record_sid_seq_name}');
+         END IF;
+
+         IF NEW."created_at" IS NULL THEN
+             NEW."created_at" := CURRENT_TIMESTAMP;
+         END IF;
+
+         IF NEW."record_sid" IS NOT NULL THEN
+             UPDATE "{table_name}"
+             SET "expired_at" = NEW."created_at"
+             WHERE "record_sid" = NEW."record_sid" AND "expired_at" IS NULL;
+         END IF;
+
+         RETURN NEW;
+     END;
+     $$ LANGUAGE plpgsql;
+     """
+
+     # language=postgresql
+     create_snapshot_auto_update_trigger_sql = f"""
+     CREATE TRIGGER "{snapshot_auto_update_trigger_name}"
+     BEFORE INSERT ON "{table_name}"
+     FOR EACH ROW
+     EXECUTE FUNCTION "{snapshot_auto_update_function_name}"();
+     """
+
+     with engine.connect() as conn:
+         conn.execute(sa.text(create_record_sid_seq_sql))
+         conn.execute(sa.text(create_snapshot_auto_update_function_sql))
+         conn.execute(sa.text(create_snapshot_auto_update_trigger_sql))
+         conn.commit()
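
Finally, a sketch of the intended wiring for the trigger helper, assuming a PostgreSQL engine (the model, table name and DSN below are illustrative): extend the SnapshotModel column set with domain fields, create the table, then install the trigger. After that, every INSERT stamps created_at when missing, draws record_sid from the per-table sequence when missing, and expires the previously active snapshot carrying the same record_sid.

class VehicleStateSnapshot(make_base_model(), make_snapshot_model_mixin(), table=True):
    __tablename__ = "vehicle_state_snapshot"

    vin_code: str = Field(sa_column=sa.Column(sa_pg.VARCHAR(17), nullable=False))
    state: str = Field(sa_column=sa.Column(sa_pg.VARCHAR(32), nullable=False))

engine = sa.create_engine("postgresql+psycopg2://user:password@localhost/plexus")  # illustrative DSN
VehicleStateSnapshot.metadata.create_all(engine)

# Passes validate_model_extended because all SnapshotModel columns are present with matching types.
make_snapshot_model_trigger(engine, VehicleStateSnapshot)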