plexus-python-common 1.0.7 (py3-none-any.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of plexus-python-common might be problematic.
- plexus/common/__init__.py +6 -0
- plexus/common/carto/OSMFile.py +259 -0
- plexus/common/carto/OSMNode.py +25 -0
- plexus/common/carto/OSMTags.py +101 -0
- plexus/common/carto/OSMWay.py +24 -0
- plexus/common/carto/__init__.py +11 -0
- plexus/common/config.py +84 -0
- plexus/common/pose.py +107 -0
- plexus/common/proj.py +305 -0
- plexus/common/utils/__init__.py +0 -0
- plexus/common/utils/bagutils.py +218 -0
- plexus/common/utils/datautils.py +195 -0
- plexus/common/utils/jsonutils.py +92 -0
- plexus/common/utils/ormutils.py +335 -0
- plexus/common/utils/s3utils.py +118 -0
- plexus/common/utils/shutils.py +234 -0
- plexus/common/utils/strutils.py +285 -0
- plexus_python_common-1.0.7.dist-info/METADATA +29 -0
- plexus_python_common-1.0.7.dist-info/RECORD +21 -0
- plexus_python_common-1.0.7.dist-info/WHEEL +5 -0
- plexus_python_common-1.0.7.dist-info/top_level.txt +1 -0
plexus/common/utils/datautils.py
@@ -0,0 +1,195 @@
+import datetime
+from collections.abc import Callable, Generator
+from typing import Any
+
+import pyparsing as pp
+import ujson as json
+from iker.common.utils.funcutils import singleton
+from iker.common.utils.jsonutils import JsonType
+from iker.common.utils.randutils import randomizer
+
+from plexus.common.utils.strutils import BagName, UserName, VehicleName
+from plexus.common.utils.strutils import dot_case_parser, kebab_case_parser, snake_case_parser
+from plexus.common.utils.strutils import hex_string_parser
+from plexus.common.utils.strutils import parse_bag_name, parse_user_name, parse_vehicle_name
+from plexus.common.utils.strutils import strict_abspath_parser, strict_relpath_parser
+from plexus.common.utils.strutils import tag_parser, topic_parser, vin_code_chars, vin_code_parser
+from plexus.common.utils.strutils import uuid_parser
+
+
+def make_compute_vin_code_check_digit() -> Callable[[str], str]:
+    trans_nums = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 7, 9, 2, 3, 4, 5, 6, 7, 8, 9]
+    weights = [8, 7, 6, 5, 4, 3, 2, 10, 0, 9, 8, 7, 6, 5, 4, 3, 2]
+
+    trans_dict = {vin_code_char: trans_num for vin_code_char, trans_num in zip(vin_code_chars, trans_nums)}
+
+    def func(vin_code: str) -> str:
+        remainder = sum(trans_dict[vin_code_char] * weight for vin_code_char, weight in zip(vin_code, weights)) % 11
+        return "X" if remainder == 10 else str(remainder)
+
+    return func
+
+
+compute_vin_code_check_digit = make_compute_vin_code_check_digit()
+
+
+def make_validate_string(element: pp.ParserElement) -> Callable[[str], None]:
+    def func(s: str) -> None:
+        try:
+            if not element.parse_string(s, parse_all=True):
+                raise ValueError(f"failed to parse '{s}'")
+        except Exception as e:
+            raise ValueError(f"encountered error while parsing '{s}'") from e
+
+    return func
+
+
+validate_hex_string = make_validate_string(hex_string_parser)
+
+validate_snake_case = make_validate_string(snake_case_parser)
+validate_kebab_case = make_validate_string(kebab_case_parser)
+validate_dot_case = make_validate_string(dot_case_parser)
+
+validate_uuid = make_validate_string(uuid_parser)
+
+validate_strict_relpath = make_validate_string(strict_relpath_parser)
+validate_strict_abspath = make_validate_string(strict_abspath_parser)
+
+validate_tag = make_validate_string(tag_parser)
+validate_topic = make_validate_string(topic_parser)
+
+
+def validate_vin_code(vin_code: str):
+    make_validate_string(vin_code_parser)(vin_code)
+    check_digit = compute_vin_code_check_digit(vin_code)
+    if check_digit != vin_code[8]:
+        raise ValueError(f"get wrong VIN code check digit from '{vin_code}', expected '{check_digit}'")
+
+
+def make_validate_parse_string(parse: Callable[[str], Any]) -> Callable[[str], None]:
+    def func(s: str) -> None:
+        try:
+            if not parse(s):
+                raise ValueError(f"failed to parse '{s}'")
+        except Exception as e:
+            raise ValueError(f"encountered error while parsing '{s}'") from e
+
+    return func
+
+
+validate_user_name = make_validate_parse_string(parse_user_name)
+validate_vehicle_name = make_validate_parse_string(parse_vehicle_name)
+validate_bag_name = make_validate_parse_string(parse_bag_name)
+
+
+def validate_dt_timezone(dt: datetime.datetime):
+    if dt.tzinfo != datetime.timezone.utc:
+        raise ValueError(f"dt '{dt}' is not in UTC")
+
+
+def validate_json_type_dump_size(json_type: JsonType, dump_size_limit: int = 10000):
+    dump_string = json.dumps(json_type, ensure_ascii=False)
+    if len(dump_string) > dump_size_limit:
+        raise ValueError(f"dump size exceeds the maximum length '{dump_size_limit}'")
+
+
+def random_vin_code() -> str:
+    vin_code = randomizer().random_string(vin_code_chars, 17)
+    check_digit = compute_vin_code_check_digit(vin_code)
+    return vin_code[:8] + check_digit + vin_code[9:]
+
+
+@singleton
+def known_topics() -> list[str]:
+    return [
+        "/sensor/camera/front_center",
+        "/sensor/camera/front_left",
+        "/sensor/camera/front_right",
+        "/sensor/camera/side_left",
+        "/sensor/camera/side_right",
+        "/sensor/camera/rear_left",
+        "/sensor/camera/rear_right",
+        "/sensor/lidar/front_center",
+        "/sensor/lidar/front_left_corner",
+        "/sensor/lidar/front_right_corner",
+        "/sensor/lidar/side_left",
+        "/sensor/lidar/side_right",
+    ]
+
+
+@singleton
+def known_user_names() -> list[UserName]:
+    return [
+        UserName("adam", "anderson"),
+        UserName("ben", "bennett"),
+        UserName("charlie", "clark"),
+        UserName("david", "dixon"),
+        UserName("evan", "edwards"),
+        UserName("frank", "fisher"),
+        UserName("george", "graham"),
+        UserName("henry", "harrison"),
+        UserName("isaac", "irving"),
+        UserName("jack", "jacobs"),
+        UserName("kevin", "kennedy"),
+        UserName("luke", "lawson"),
+        UserName("michael", "mitchell"),
+        UserName("nathan", "newton"),
+        UserName("oscar", "owens"),
+        UserName("paul", "peterson"),
+        UserName("quincy", "quinn"),
+        UserName("ryan", "robinson"),
+        UserName("sam", "stevens"),
+        UserName("tom", "thomas"),
+        UserName("umar", "underwood"),
+        UserName("victor", "vaughan"),
+        UserName("william", "walker"),
+        UserName("xander", "xavier"),
+        UserName("yale", "young"),
+        UserName("zane", "zimmerman"),
+    ]
+
+
+@singleton
+def known_vehicle_names() -> list[VehicleName]:
+    return [
+        VehicleName("cascadia", "antelope", "00000", "3AKJGLD5XLS000000"),
+        VehicleName("cascadia", "bear", "00001", "3AKJGLD51LS000001"),
+        VehicleName("cascadia", "cheetah", "00002", "3AKJGLD53LS000002"),
+        VehicleName("cascadia", "dolphin", "00003", "3AKJGLD55LS000003"),
+        VehicleName("cascadia", "eagle", "00004", "3AKJGLD57LS000004"),
+        VehicleName("cascadia", "falcon", "00005", "3AKJGLD59LS000005"),
+        VehicleName("cascadia", "gorilla", "00006", "3AKJGLD50LS000006"),
+        VehicleName("cascadia", "hawk", "00007", "3AKJGLD52LS000007"),
+        VehicleName("cascadia", "iguana", "00008", "3AKJGLD54LS000008"),
+        VehicleName("cascadia", "jaguar", "00009", "3AKJGLD56LS000009"),
+        VehicleName("cascadia", "koala", "00010", "3AKJGLD52LS000010"),
+        VehicleName("cascadia", "leopard", "00011", "3AKJGLD54LS000011"),
+        VehicleName("cascadia", "mongoose", "00012", "3AKJGLD56LS000012"),
+        VehicleName("cascadia", "narwhal", "00013", "3AKJGLD58LS000013"),
+        VehicleName("cascadia", "otter", "00014", "3AKJGLD5XLS000014"),
+        VehicleName("cascadia", "panther", "00015", "3AKJGLD51LS000015"),
+        VehicleName("cascadia", "quail", "00016", "3AKJGLD53LS000016"),
+        VehicleName("cascadia", "rhino", "00017", "3AKJGLD55LS000017"),
+        VehicleName("cascadia", "snake", "00018", "3AKJGLD57LS000018"),
+        VehicleName("cascadia", "tiger", "00019", "3AKJGLD59LS000019"),
+        VehicleName("cascadia", "urial", "00020", "3AKJGLD55LS000020"),
+        VehicleName("cascadia", "vulture", "00021", "3AKJGLD57LS000021"),
+        VehicleName("cascadia", "wolf", "00022", "3AKJGLD59LS000022"),
+        VehicleName("cascadia", "xerus", "00023", "3AKJGLD50LS000023"),
+        VehicleName("cascadia", "yak", "00024", "3AKJGLD52LS000024"),
+        VehicleName("cascadia", "zebra", "00025", "3AKJGLD54LS000025"),
+    ]
+
+
+def random_bag_names_sequence(
+    min_record_dt: datetime.datetime,
+    max_record_dt: datetime.datetime,
+    min_sequence_length: int,
+    max_sequence_length: int,
+) -> Generator[BagName, None, None]:
+    vehicle_name = randomizer().choose(known_vehicle_names())
+    record_dt = randomizer().random_datetime(begin=min_record_dt, end=max_record_dt)
+    bags_count = randomizer().next_int(min_sequence_length, max_sequence_length)
+
+    for record_sn in range(bags_count):
+        yield BagName(vehicle_name=vehicle_name, record_dt=record_dt, record_sn=record_sn)
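
Usage note (editor's addition, not part of the packaged diff): the sketch below shows one way the VIN helpers and synthetic bag-name generator from datautils.py might be exercised. The UTC bounds passed to random_bag_names_sequence, and the assumption that iker's randomizer().random_datetime accepts timezone-aware datetimes, are illustrative, not something this diff confirms.

# Sketch only: exercising plexus/common/utils/datautils.py (assumes the package
# and its iker dependency are installed).
import datetime

from plexus.common.utils.datautils import compute_vin_code_check_digit
from plexus.common.utils.datautils import random_bag_names_sequence, random_vin_code, validate_vin_code

vin = random_vin_code()
validate_vin_code(vin)  # raises ValueError if the check digit at position 9 were wrong
print(vin, compute_vin_code_check_digit(vin))

# Synthetic bag names for one randomly chosen known vehicle; the UTC bounds below
# are illustrative assumptions.
for bag_name in random_bag_names_sequence(
    min_record_dt=datetime.datetime(2024, 1, 1, tzinfo=datetime.timezone.utc),
    max_record_dt=datetime.datetime(2024, 12, 31, tzinfo=datetime.timezone.utc),
    min_sequence_length=1,
    max_sequence_length=3,
):
    print(bag_name)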

plexus/common/utils/jsonutils.py
@@ -0,0 +1,92 @@
+import datetime
+import os
+from collections.abc import Generator, Iterable
+from typing import Any
+
+import ujson as json
+from iker.common.utils.dtutils import dt_format, dt_parse, extended_format
+from iker.common.utils.jsonutils import JsonType, JsonValueCompatible
+from iker.common.utils.jsonutils import json_reformat
+from iker.common.utils.sequtils import batch_yield
+
+from plexus.common.utils.shutils import collect_volumed_filenames, populate_volumed_filenames
+
+__all__ = [
+    "json_datetime_decoder",
+    "json_datetime_encoder",
+    "json_loads",
+    "json_dumps",
+    "read_chunked_jsonl",
+    "write_chunked_jsonl",
+]
+
+
+def json_datetime_decoder(v: Any) -> datetime.datetime:
+    if isinstance(v, str):
+        return json_datetime_decoder(dt_parse(v, extended_format(with_us=True, with_tz=True)))
+    if isinstance(v, datetime.datetime):
+        return v.replace(tzinfo=datetime.timezone.utc)
+    raise ValueError("unexpected type of value for datetime decoder")
+
+
+def json_datetime_encoder(v: Any) -> str:
+    if isinstance(v, str):
+        return json_datetime_encoder(dt_parse(v, extended_format(with_us=True, with_tz=True)))
+    if isinstance(v, datetime.datetime):
+        return dt_format(v.replace(tzinfo=datetime.timezone.utc), extended_format(with_us=True, with_tz=True))
+    raise ValueError("unexpected type of value for datetime encoder")
+
+
+def json_deserializer(obj):
+    def value_formatter(value: JsonValueCompatible) -> JsonType:
+        if not isinstance(value, str):
+            return value
+        try:
+            return dt_parse(value, extended_format(with_us=True, with_tz=True))
+        except Exception:
+            return value
+
+    return json_reformat(obj, value_formatter=value_formatter)
+
+
+def json_serializer(obj):
+    def unregistered_formatter(unregistered: Any) -> JsonType:
+        if isinstance(unregistered, datetime.datetime):
+            return dt_format(unregistered, extended_format(with_us=True, with_tz=True))
+        return None
+
+    return json_reformat(obj, raise_if_unregistered=False, unregistered_formatter=unregistered_formatter)
+
+
+def json_loads(s: str) -> JsonType:
+    return json_deserializer(json.loads(s))
+
+
+def json_dumps(obj: JsonType) -> str:
+    return json.dumps(json_serializer(obj), ensure_ascii=False, escape_forward_slashes=False)
+
+
+def read_chunked_jsonl(template: str) -> Generator[tuple[JsonType, str], None, None]:
+    for path, _ in collect_volumed_filenames(template):
+        with open(path, mode="r", encoding="utf-8") as fh:
+            for line in fh:
+                yield json_loads(line), path
+
+
+def write_chunked_jsonl(records: Iterable[JsonType], template: str, chunk_size: int) -> list[tuple[str, int]]:
+    generator = populate_volumed_filenames(template)
+    entry = []
+    for batch_index, batch in enumerate(batch_yield(records, chunk_size)):
+        path, _ = next(generator)
+        lines = 0
+        with open(path, mode="w") as fh:
+            for record in batch:
+                fh.write(json_dumps(record))
+                fh.write("\n")
+                lines += 1
+        entry.append((path, lines))
+    if len(entry) == 1:
+        path, lines = entry[0]
+        os.rename(path, template)
+        return [(template, lines)]
+    return entry
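
Usage note (editor's addition, not part of the packaged diff): write_chunked_jsonl splits an iterable of JSON-compatible records into volume files derived from a filename template, renaming the single output back to the template when only one chunk is produced, while read_chunked_jsonl walks the volumes back. The sketch below uses a plain "records.jsonl" template as an assumption; the actual volume-naming convention is defined by populate_volumed_filenames and collect_volumed_filenames in plexus.common.utils.shutils, which this hunk does not show.

# Sketch only: round-tripping chunked JSONL via plexus/common/utils/jsonutils.py.
# The template string is an assumption; see plexus.common.utils.shutils for the
# real volume-naming convention.
import datetime

from plexus.common.utils.jsonutils import read_chunked_jsonl, write_chunked_jsonl

records = [
    {"seq": i, "created_at": datetime.datetime.now(datetime.timezone.utc)}
    for i in range(10)
]

# Datetimes are serialized by json_dumps through the unregistered-value formatter;
# with chunk_size=4 the ten records are spread over three volume files.
written = write_chunked_jsonl(records, "records.jsonl", chunk_size=4)
for path, line_count in written:
    print(path, line_count)

for record, path in read_chunked_jsonl("records.jsonl"):
    print(path, record)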

plexus/common/utils/ormutils.py
@@ -0,0 +1,335 @@
+import datetime
+from typing import Self, TypeVar
+
+import pydantic as pdt
+import sqlalchemy as sa
+import sqlalchemy.dialects.postgresql as sa_pg
+from sqlmodel import Field, SQLModel
+
+from plexus.common.utils.datautils import validate_dt_timezone
+from plexus.common.utils.jsonutils import json_datetime_encoder
+
+__all__ = [
+    "compare_postgresql_types",
+    "validate_model_extended",
+    "collect_model_tables",
+    "model_copy_from",
+    "make_base_model",
+    "make_serial_model_mixin",
+    "make_record_model_mixin",
+    "make_snapshot_model_mixin",
+    "serial_model_mixin",
+    "record_model_mixin",
+    "snapshot_model_mixin",
+    "SerialModel",
+    "RecordModel",
+    "SnapshotModel",
+    "make_snapshot_model_trigger",
+]
+
+ModelT = TypeVar("ModelT", bound=SQLModel)
+ModelU = TypeVar("ModelU", bound=SQLModel)
+
+
+def compare_postgresql_types(type_a, type_b) -> bool:
+    """
+    Compares two Postgresql-specific column types to determine if they are equivalent.
+    This includes types from sqlalchemy.dialects.postgresql like ARRAY, JSON, UUID, etc.
+    """
+    if not isinstance(type_a, type(type_b)):
+        return False
+    if isinstance(type_a, sa_pg.ARRAY):
+        return compare_postgresql_types(type_a.item_type, type_b.item_type)
+    if isinstance(type_a, (sa_pg.VARCHAR, sa_pg.CHAR, sa_pg.TEXT)):
+        return type_a.length == type_b.length
+    if isinstance(type_a, (sa_pg.TIMESTAMP, sa_pg.TIME)):
+        return type_a.timezone == type_b.timezone
+    if isinstance(type_a, sa_pg.NUMERIC):
+        return type_a.precision == type_b.precision and type_a.scale == type_b.scale
+    return type(type_a) in {
+        sa_pg.BOOLEAN,
+        sa_pg.INTEGER,
+        sa_pg.BIGINT,
+        sa_pg.SMALLINT,
+        sa_pg.FLOAT,
+        sa_pg.DOUBLE_PRECISION,
+        sa_pg.REAL,
+        sa_pg.DATE,
+        sa_pg.UUID,
+        sa_pg.JSON,
+        sa_pg.JSONB,
+        sa_pg.HSTORE,
+    }
+
+
+def validate_model_extended(model_base: type[SQLModel], model_extended: type[SQLModel]) -> bool:
+    """
+    Validates if `model_extended` is an extension of `model_base` by checking if all fields in `model_base`
+    are present in `model_extended` with compatible types.
+
+    :param model_base: The base model class to compare against.
+    :param model_extended: The model class that is expected to extend the base model.
+    :return: True if `model_extended` extends `model_base` correctly, False otherwise.
+    """
+    columns_a = {column.name: column.type for column in model_base.__table__.columns}
+    columns_b = {column.name: column.type for column in model_extended.__table__.columns}
+
+    for field_a, field_a_type in columns_a.items():
+        field_b_type = columns_b.get(field_a)
+        if field_b_type is None or not compare_postgresql_types(field_a_type, field_b_type):
+            return False
+    return True
+
+
+def collect_model_tables(*models: ModelT) -> sa.MetaData:
+    metadata = sa.MetaData()
+    for base in models:
+        for table in base.metadata.tables.values():
+            table.to_metadata(metadata)
+    return metadata
+
+
+def model_copy_from(dst: ModelT, src: ModelU, **kwargs) -> ModelT:
+    if not isinstance(dst, SQLModel) or not isinstance(src, SQLModel):
+        raise TypeError("both 'dst' and 'src' must be instances of SQLModel or its subclasses")
+
+    for field, value in src.model_dump(**kwargs).items():
+        if field not in dst.model_fields:
+            continue
+        # Skip fields that are not present in the destination model
+        if value is None and dst.model_fields[field].required:
+            raise ValueError(f"field '{field}' is required but got None")
+
+        # Only set the field if it exists in the destination model
+        if hasattr(dst, field):
+            # If the field is a SQLModel, recursively copy it
+            if isinstance(value, SQLModel):
+                value = model_copy_from(getattr(dst, field), value, **kwargs)
+            elif isinstance(value, list) and all(isinstance(item, SQLModel) for item in value):
+                value = [model_copy_from(dst_item, src_item, **kwargs)
+                         for dst_item, src_item in zip(getattr(dst, field), value)]
+
+            setattr(dst, field, value)
+
+    return dst
+
+
+def make_base_model() -> type[SQLModel]:
+    """
+    Creates a base SQLModel class with custom metadata and JSON encoding for datetime fields.
+    Use this as a base for all models that require these configurations.
+    """
+
+    class BaseModel(SQLModel):
+        metadata = sa.MetaData()
+        model_config = pdt.ConfigDict(json_encoders={datetime.datetime: json_datetime_encoder})
+
+    return BaseModel
+
+
+def make_serial_model_mixin() -> type[SQLModel]:
+    """
+    Creates a mixin class for SQLModel models that adds a unique identifier field `sid`.
+    Use this mixin to add an auto-incremented primary key to your models.
+    """
+
+    class ModelMixin(SQLModel):
+        sid: int | None = Field(
+            sa_column=sa.Column(sa_pg.BIGINT, primary_key=True, autoincrement=True),
+            default=None,
+            description="Unique auto-incremented primary key for the record",
+        )
+
+    return ModelMixin
+
+
+def make_record_model_mixin() -> type[SQLModel]:
+    """
+    Creates a mixin class for SQLModel models that adds common fields and validation logic for updatable records.
+    This mixin includes `sid`, `created_at`, and `updated_at` fields, along with validation for timestamps.
+    """
+
+    class ModelMixin(SQLModel):
+        sid: int | None = Field(
+            sa_column=sa.Column(sa_pg.BIGINT, primary_key=True, autoincrement=True),
+            default=None,
+            description="Unique auto-incremented primary key for the record",
+        )
+        created_at: datetime.datetime | None = Field(
+            sa_column=sa.Column(sa_pg.TIMESTAMP(timezone=True)),
+            default=None,
+            description="Timestamp (with timezone) when the record was created",
+        )
+        updated_at: datetime.datetime | None = Field(
+            sa_column=sa.Column(sa_pg.TIMESTAMP(timezone=True)),
+            default=None,
+            description="Timestamp (with timezone) when the record was last updated",
+        )
+
+        @pdt.field_validator("created_at", mode="after")
+        @classmethod
+        def validate_created_at(cls, v: datetime.datetime) -> datetime.datetime:
+            if v is not None:
+                validate_dt_timezone(v)
+            return v
+
+        @pdt.field_validator("updated_at", mode="after")
+        @classmethod
+        def validate_updated_at(cls, v: datetime.datetime) -> datetime.datetime:
+            if v is not None:
+                validate_dt_timezone(v)
+            return v
+
+        @pdt.model_validator(mode="after")
+        @classmethod
+        def validate_created_at_updated_at(cls, m: Self) -> Self:
+            if m.created_at is not None and m.updated_at is not None and m.created_at > m.updated_at:
+                raise ValueError(f"create time '{m.created_at}' is greater than update time '{m.updated_at}'")
+            return m
+
+        @classmethod
+        def make_created_at_index(cls, index_name: str) -> sa.Index:
+            """
+            Helper to create an index on the `created_at` field with the given index name.
+            """
+            return sa.Index(index_name, "created_at")
+
+    return ModelMixin
+
+
+def make_snapshot_model_mixin() -> type[SQLModel]:
+    """
+    Provides a mixin class for SQLModel models that adds common fields and validation logic for record snapshots.
+    A snapshot model tracks the full change history of an entity: when any field changes, the current record (with a
+    NULL expiration time) is updated to set its expiration time, and a new record with the updated values is created.
+
+    The mixin includes the following fields:
+    - `sid`: Unique, auto-incremented primary key identifying each snapshot of the record in the change history.
+    - `created_at`: Time (with timezone) when this snapshot of the record was created and became active.
+    - `expired_at`: Time (with timezone) when this snapshot of the record was superseded or became inactive;
+      `None` if still active.
+    - `record_sid`: Foreign key to the record this snapshot belongs to; used to link snapshots together.
+    """
+
+    class ModelMixin(SQLModel):
+        sid: int | None = Field(
+            sa_column=sa.Column(sa_pg.BIGINT, primary_key=True, autoincrement=True),
+            default=None,
+            description="Unique auto-incremented primary key for each record snapshot",
+        )
+        created_at: datetime.datetime | None = Field(
+            sa_column=sa.Column(sa_pg.TIMESTAMP(timezone=True)),
+            default=None,
+            description="Timestamp (with timezone) when this record snapshot became active",
+        )
+        expired_at: datetime.datetime | None = Field(
+            sa_column=sa.Column(sa_pg.TIMESTAMP(timezone=True)),
+            default=None,
+            description="Timestamp (with timezone) when this record snapshot became inactive; None if still active",
+        )
+        record_sid: int | None = Field(
+            sa_column=sa.Column(sa_pg.BIGINT, nullable=True),
+            default=None,
+            description="Foreign key to the record this snapshot belongs to",
+        )
+
+        @pdt.field_validator("created_at", mode="after")
+        @classmethod
+        def validate_created_at(cls, v: datetime.datetime) -> datetime.datetime:
+            if v is not None:
+                validate_dt_timezone(v)
+            return v
+
+        @pdt.field_validator("expired_at", mode="after")
+        @classmethod
+        def validate_expired_at(cls, v: datetime.datetime) -> datetime.datetime:
+            if v is not None:
+                validate_dt_timezone(v)
+            return v
+
+        @pdt.model_validator(mode="after")
+        @classmethod
+        def validate_created_at_expired_at(cls, m: Self) -> Self:
+            if m.created_at is not None and m.expired_at is not None and m.created_at > m.expired_at:
+                raise ValueError(f"create time '{m.created_at}' is greater than expire time '{m.expired_at}'")
+            return m
+
+        @classmethod
+        def make_created_at_expired_at_index(cls, index_name: str) -> sa.Index:
+            return sa.Index(index_name, "created_at", "expired_at")
+
+    return ModelMixin
+
+
+serial_model_mixin = make_serial_model_mixin()
+record_model_mixin = make_record_model_mixin()
+snapshot_model_mixin = make_snapshot_model_mixin()
+
+
+class SerialModel(make_base_model(), make_serial_model_mixin(), table=True):
+    pass
+
+
+class RecordModel(make_base_model(), make_record_model_mixin(), table=True):
+    pass
+
+
+class SnapshotModel(make_base_model(), make_snapshot_model_mixin(), table=True):
+    pass
+
+
+def make_snapshot_model_trigger(engine: sa.Engine, model: type[SQLModel]):
+    table_name = model.__tablename__
+    if not table_name:
+        raise ValueError("missing '__tablename__' attribute")
+
+    if not validate_model_extended(SnapshotModel, model):
+        raise ValueError("not an extended model of 'SnapshotModel'")
+
+    record_sid_seq_name = f"{table_name}_record_sid_seq"
+    snapshot_auto_update_function_name = f"{table_name}_snapshot_auto_update_function"
+    snapshot_auto_update_trigger_name = f"{table_name}_snapshot_auto_update_trigger"
+
+    # language=postgresql
+    create_record_sid_seq_sql = f"""
+    CREATE SEQUENCE "{record_sid_seq_name}"
+    START WITH 1 INCREMENT BY 1 NO MINVALUE NO MAXVALUE CACHE 1;
+    """
+
+    # language=postgresql
+    create_snapshot_auto_update_function_sql = f"""
+    CREATE FUNCTION "{snapshot_auto_update_function_name}"()
+    RETURNS TRIGGER AS $$
+    BEGIN
+        IF NEW."record_sid" IS NULL THEN
+            NEW."record_sid" := nextval('{record_sid_seq_name}');
+        END IF;
+
+        IF NEW."created_at" IS NULL THEN
+            NEW."created_at" := CURRENT_TIMESTAMP;
+        END IF;
+
+        IF NEW."record_sid" IS NOT NULL THEN
+            UPDATE "{table_name}"
+            SET "expired_at" = NEW."created_at"
+            WHERE "record_sid" = NEW."record_sid" AND "expired_at" IS NULL;
+        END IF;
+
+        RETURN NEW;
+    END;
+    $$ LANGUAGE plpgsql;
+    """
+
+    # language=postgresql
+    create_snapshot_auto_update_trigger_sql = f"""
+    CREATE TRIGGER "{snapshot_auto_update_trigger_name}"
+    BEFORE INSERT ON "{table_name}"
+    FOR EACH ROW
+    EXECUTE FUNCTION "{snapshot_auto_update_function_name}"();
+    """
+
+    with engine.connect() as conn:
+        conn.execute(sa.text(create_record_sid_seq_sql))
+        conn.execute(sa.text(create_snapshot_auto_update_function_sql))
+        conn.execute(sa.text(create_snapshot_auto_update_trigger_sql))
+        conn.commit()
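
Usage note (editor's addition, not part of the packaged diff): a snapshot table is declared by combining make_base_model() and make_snapshot_model_mixin(), after which make_snapshot_model_trigger installs the sequence, trigger function, and BEFORE INSERT trigger shown above so that inserting a new snapshot expires the previously active row with the same record_sid. The table name, extra column, and connection URL in the sketch below are illustrative assumptions.

# Sketch only: declaring a snapshot table and installing its auto-update trigger.
# "vehicle_snapshots", the "plate" column, and the connection URL are hypothetical.
import sqlalchemy as sa
import sqlalchemy.dialects.postgresql as sa_pg
from sqlmodel import Field

from plexus.common.utils.ormutils import make_base_model, make_snapshot_model_mixin
from plexus.common.utils.ormutils import make_snapshot_model_trigger


class VehicleSnapshot(make_base_model(), make_snapshot_model_mixin(), table=True):
    __tablename__ = "vehicle_snapshots"

    plate: str | None = Field(
        sa_column=sa.Column(sa_pg.VARCHAR(16)),
        default=None,
        description="License plate captured in this snapshot",
    )


engine = sa.create_engine("postgresql://user:pass@localhost:5432/plexus")

# Create the table first, then the sequence/function/trigger so that each insert
# expires the previously active snapshot sharing the same record_sid.
VehicleSnapshot.metadata.create_all(engine)
make_snapshot_model_trigger(engine, VehicleSnapshot)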