datachain 0.14.2__py3-none-any.whl → 0.39.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- datachain/__init__.py +20 -0
- datachain/asyn.py +11 -12
- datachain/cache.py +7 -7
- datachain/catalog/__init__.py +2 -2
- datachain/catalog/catalog.py +621 -507
- datachain/catalog/dependency.py +164 -0
- datachain/catalog/loader.py +28 -18
- datachain/checkpoint.py +43 -0
- datachain/cli/__init__.py +24 -33
- datachain/cli/commands/__init__.py +1 -8
- datachain/cli/commands/datasets.py +83 -52
- datachain/cli/commands/ls.py +17 -17
- datachain/cli/commands/show.py +4 -4
- datachain/cli/parser/__init__.py +8 -74
- datachain/cli/parser/job.py +95 -3
- datachain/cli/parser/studio.py +11 -4
- datachain/cli/parser/utils.py +1 -2
- datachain/cli/utils.py +2 -15
- datachain/client/azure.py +4 -4
- datachain/client/fsspec.py +45 -28
- datachain/client/gcs.py +6 -6
- datachain/client/hf.py +29 -2
- datachain/client/http.py +157 -0
- datachain/client/local.py +15 -11
- datachain/client/s3.py +17 -9
- datachain/config.py +4 -8
- datachain/data_storage/db_engine.py +12 -6
- datachain/data_storage/job.py +5 -1
- datachain/data_storage/metastore.py +1252 -186
- datachain/data_storage/schema.py +58 -45
- datachain/data_storage/serializer.py +105 -15
- datachain/data_storage/sqlite.py +286 -127
- datachain/data_storage/warehouse.py +250 -113
- datachain/dataset.py +353 -148
- datachain/delta.py +391 -0
- datachain/diff/__init__.py +27 -29
- datachain/error.py +60 -0
- datachain/func/__init__.py +2 -1
- datachain/func/aggregate.py +66 -42
- datachain/func/array.py +242 -38
- datachain/func/base.py +7 -4
- datachain/func/conditional.py +110 -60
- datachain/func/func.py +96 -45
- datachain/func/numeric.py +55 -38
- datachain/func/path.py +32 -20
- datachain/func/random.py +2 -2
- datachain/func/string.py +67 -37
- datachain/func/window.py +7 -8
- datachain/hash_utils.py +123 -0
- datachain/job.py +11 -7
- datachain/json.py +138 -0
- datachain/lib/arrow.py +58 -22
- datachain/lib/audio.py +245 -0
- datachain/lib/clip.py +14 -13
- datachain/lib/convert/flatten.py +5 -3
- datachain/lib/convert/python_to_sql.py +6 -10
- datachain/lib/convert/sql_to_python.py +8 -0
- datachain/lib/convert/values_to_tuples.py +156 -51
- datachain/lib/data_model.py +42 -20
- datachain/lib/dataset_info.py +36 -8
- datachain/lib/dc/__init__.py +8 -2
- datachain/lib/dc/csv.py +25 -28
- datachain/lib/dc/database.py +398 -0
- datachain/lib/dc/datachain.py +1289 -425
- datachain/lib/dc/datasets.py +320 -38
- datachain/lib/dc/hf.py +38 -24
- datachain/lib/dc/json.py +29 -32
- datachain/lib/dc/listings.py +112 -8
- datachain/lib/dc/pandas.py +16 -12
- datachain/lib/dc/parquet.py +35 -23
- datachain/lib/dc/records.py +31 -23
- datachain/lib/dc/storage.py +154 -64
- datachain/lib/dc/storage_pattern.py +251 -0
- datachain/lib/dc/utils.py +24 -16
- datachain/lib/dc/values.py +8 -9
- datachain/lib/file.py +622 -89
- datachain/lib/hf.py +69 -39
- datachain/lib/image.py +14 -14
- datachain/lib/listing.py +14 -11
- datachain/lib/listing_info.py +1 -2
- datachain/lib/meta_formats.py +3 -4
- datachain/lib/model_store.py +39 -7
- datachain/lib/namespaces.py +125 -0
- datachain/lib/projects.py +130 -0
- datachain/lib/pytorch.py +32 -21
- datachain/lib/settings.py +192 -56
- datachain/lib/signal_schema.py +427 -104
- datachain/lib/tar.py +1 -2
- datachain/lib/text.py +8 -7
- datachain/lib/udf.py +164 -76
- datachain/lib/udf_signature.py +60 -35
- datachain/lib/utils.py +118 -4
- datachain/lib/video.py +17 -9
- datachain/lib/webdataset.py +61 -56
- datachain/lib/webdataset_laion.py +15 -16
- datachain/listing.py +22 -10
- datachain/model/bbox.py +3 -1
- datachain/model/ultralytics/bbox.py +16 -12
- datachain/model/ultralytics/pose.py +16 -12
- datachain/model/ultralytics/segment.py +16 -12
- datachain/namespace.py +84 -0
- datachain/node.py +6 -6
- datachain/nodes_thread_pool.py +0 -1
- datachain/plugins.py +24 -0
- datachain/project.py +78 -0
- datachain/query/batch.py +40 -41
- datachain/query/dataset.py +604 -322
- datachain/query/dispatch.py +261 -154
- datachain/query/metrics.py +4 -6
- datachain/query/params.py +2 -3
- datachain/query/queue.py +3 -12
- datachain/query/schema.py +11 -6
- datachain/query/session.py +200 -33
- datachain/query/udf.py +34 -2
- datachain/remote/studio.py +171 -69
- datachain/script_meta.py +12 -12
- datachain/semver.py +68 -0
- datachain/sql/__init__.py +2 -0
- datachain/sql/functions/array.py +33 -1
- datachain/sql/postgresql_dialect.py +9 -0
- datachain/sql/postgresql_types.py +21 -0
- datachain/sql/sqlite/__init__.py +5 -1
- datachain/sql/sqlite/base.py +102 -29
- datachain/sql/sqlite/types.py +8 -13
- datachain/sql/types.py +70 -15
- datachain/studio.py +223 -46
- datachain/toolkit/split.py +31 -10
- datachain/utils.py +101 -59
- {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/METADATA +77 -22
- datachain-0.39.0.dist-info/RECORD +173 -0
- {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/WHEEL +1 -1
- datachain/cli/commands/query.py +0 -53
- datachain/query/utils.py +0 -42
- datachain-0.14.2.dist-info/RECORD +0 -158
- {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/entry_points.txt +0 -0
- {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/licenses/LICENSE +0 -0
- {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/top_level.txt +0 -0
datachain/sql/sqlite/base.py
CHANGED
@@ -2,13 +2,12 @@ import logging
 import re
 import sqlite3
 import warnings
-from collections.abc import Iterable
+from collections.abc import Callable, Iterable
+from contextlib import closing
 from datetime import MAXYEAR, MINYEAR, datetime, timezone
 from functools import cache
 from types import MappingProxyType
-from typing import Callable, Optional
 
-import orjson
 import sqlalchemy as sa
 from sqlalchemy.dialects import sqlite
 from sqlalchemy.ext.compiler import compiles
@@ -16,6 +15,7 @@ from sqlalchemy.sql.elements import literal
 from sqlalchemy.sql.expression import case
 from sqlalchemy.sql.functions import func
 
+from datachain import json
 from datachain.sql.functions import (
     aggregate,
     array,
@@ -88,6 +88,9 @@ def setup():
     compiles(sql_path.file_ext, "sqlite")(compile_path_file_ext)
     compiles(array.length, "sqlite")(compile_array_length)
     compiles(array.contains, "sqlite")(compile_array_contains)
+    compiles(array.slice, "sqlite")(compile_array_slice)
+    compiles(array.join, "sqlite")(compile_array_join)
+    compiles(array.get_element, "sqlite")(compile_array_get_element)
     compiles(string.length, "sqlite")(compile_string_length)
     compiles(string.split, "sqlite")(compile_string_split)
     compiles(string.regexp_replace, "sqlite")(compile_string_regexp_replace)
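These compiles(...) calls hook datachain's custom SQL functions into SQLAlchemy's per-dialect compiler. A minimal standalone sketch of the same mechanism (the greatest function below is illustrative, not part of datachain):

    import sqlalchemy as sa
    from sqlalchemy.dialects import sqlite
    from sqlalchemy.ext.compiler import compiles
    from sqlalchemy.sql import expression


    class greatest(expression.FunctionElement):
        type = sa.Numeric()
        name = "greatest"
        inherit_cache = True


    @compiles(greatest, "sqlite")
    def compile_greatest_sqlite(element, compiler, **kwargs):
        # SQLite has no GREATEST(); rewrite it as the variadic MAX()
        return f"max({compiler.process(element.clauses, **kwargs)})"


    print(greatest(1, 2).compile(dialect=sqlite.dialect()))  # max(?, ?)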
@@ -109,7 +112,10 @@ def setup():
     compiles(numeric.int_hash_64, "sqlite")(compile_int_hash_64)
     compiles(numeric.bit_hamming_distance, "sqlite")(compile_bit_hamming_distance)
 
-    if load_usearch_extension(sqlite3.connect(":memory:")):
+    with closing(sqlite3.connect(":memory:")) as _usearch_conn:
+        usearch_available = load_usearch_extension(_usearch_conn)
+
+    if usearch_available:
         compiles(array.cosine_distance, "sqlite")(compile_cosine_distance_ext)
         compiles(array.euclidean_distance, "sqlite")(compile_euclidean_distance_ext)
     else:
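The usearch probe now runs on a connection wrapped in contextlib.closing, so the throwaway in-memory database is closed deterministically. A sketch of why closing is needed here: sqlite3.Connection's context-manager protocol only manages transactions, it does not close the connection:

    import sqlite3
    from contextlib import closing

    with closing(sqlite3.connect(":memory:")) as conn:
        ok = conn.execute("select 1").fetchone() == (1,)
    # conn.close() has been called here, no matter how the block exited
    print(ok)  # True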
@@ -129,7 +135,7 @@ def run_compiler_hook(name):
 
 
 def functions_exist(
-    names: Iterable[str], connection: Optional[sqlite3.Connection] = None
+    names: Iterable[str], connection: sqlite3.Connection | None = None
 ) -> bool:
     """
     Returns True if all function names are defined for the given connection.
@@ -143,23 +149,34 @@ def functions_exist(
             f"Found value of type {type(n).__name__}: {n!r}"
         )
 
+    close_connection = False
     if connection is None:
         connection = sqlite3.connect(":memory:")
+        close_connection = True
 
-    [previous 14-line implementation; its content is not captured in the diff source]
+    try:
+        if not names:
+            return True
+        column1 = sa.column("column1", sa.String)
+        func_name_query = column1.not_in(
+            sa.select(sa.column("name", sa.String)).select_from(
+                func.pragma_function_list()
+            )
+        )
+        query = (
+            sa.select(func.count() == 0)
+            .select_from(sa.values(column1).data([(n,) for n in names]))
+            .where(func_name_query)
+        )
+        comp = query.compile(dialect=sqlite_dialect)
+        if comp.params:
+            result = connection.execute(comp.string, comp.params)
+        else:
+            result = connection.execute(comp.string)
+        return bool(result.fetchone()[0])
+    finally:
+        if close_connection:
+            connection.close()
 
 
 def create_user_defined_sql_functions(connection):
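The rewritten functions_exist builds a VALUES list of the requested names, counts how many are missing from SQLite's pragma_function_list, and now closes the fallback in-memory connection in a finally block. A plain-sqlite3 sketch of the same membership check (not the datachain API; pragma_function_list requires SQLite 3.30+ or an introspection-enabled build):

    import sqlite3

    def functions_exist_sketch(names, connection=None):
        close_connection = connection is None
        conn = connection or sqlite3.connect(":memory:")
        try:
            rows = conn.execute("SELECT name FROM pragma_function_list()")
            defined = {row[0] for row in rows}
            return all(name in defined for name in names)
        finally:
            if close_connection:
                conn.close()

    print(functions_exist_sketch(["abs", "lower"]))      # True
    print(functions_exist_sketch(["no_such_function"]))  # False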
@@ -179,7 +196,7 @@ def missing_vector_function(name, exc):
 
 
 def sqlite_string_split(string: str, sep: str, maxsplit: int = -1) -> str:
-    return orjson.dumps(string.split(sep, maxsplit)).decode("utf-8")
+    return json.dumps(string.split(sep, maxsplit), ensure_ascii=False)
 
 
 def sqlite_int_hash_64(x: int) -> int:
@@ -198,9 +215,7 @@ def sqlite_int_hash_64(x: int) -> int:
 def sqlite_bit_hamming_distance(a: int, b: int) -> int:
     """Calculate the Hamming distance between two integers."""
     diff = (a & MAX_INT64) ^ (b & MAX_INT64)
-    if sys.version_info >= (3, 10):
-        return diff.bit_count()
-    return bin(diff).count("1")
+    return diff.bit_count()
 
 
 def sqlite_byte_hamming_distance(a: str, b: str) -> int:
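The removed branch was a Python 3.9 fallback: int.bit_count() only exists on 3.10+, and bin(diff).count("1") computes the same popcount via the binary string representation. A quick equivalence check:

    # bit_count() and the old string-based fallback agree for non-negative ints
    for x in (0, 1, 0b1011, 2**63 - 1):
        assert x.bit_count() == bin(x).count("1")
    print((0b1011).bit_count())  # 3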
@@ -212,7 +227,7 @@ def sqlite_byte_hamming_distance(a: str, b: str) -> int:
     elif len(b) < len(a):
         diff = len(a) - len(b)
         a = a[: len(b)]
-    return diff + sum(c1 != c2 for c1, c2 in zip(a, b))
+    return diff + sum(c1 != c2 for c1, c2 in zip(a, b, strict=False))
 
 
 def register_user_defined_sql_functions() -> None:
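strict=False spells out zip's default truncating behavior (the operands are already trimmed to equal length above); strict=True would instead raise on a length mismatch:

    print(list(zip("abc", "ab", strict=False)))  # [('a', 'a'), ('b', 'b')] - truncates
    # list(zip("abc", "ab", strict=True)) would raise ValueError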
@@ -270,6 +285,22 @@ def register_user_defined_sql_functions() -> None:
 
     _registered_function_creators["string_functions"] = create_string_functions
 
+    def create_array_functions(conn):
+        conn.create_function(
+            "json_array_get_element", 2, py_json_array_get_element, deterministic=True
+        )
+        conn.create_function(
+            "json_array_slice", 2, py_json_array_slice, deterministic=True
+        )
+        conn.create_function(
+            "json_array_slice", 3, py_json_array_slice, deterministic=True
+        )
+        conn.create_function(
+            "json_array_join", 2, py_json_array_join, deterministic=True
+        )
+
+    _registered_function_creators["array_functions"] = create_array_functions
+
     has_json_extension = functions_exist(["json_array_length", "json_array_contains"])
     if not has_json_extension:
 
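The new helpers are registered as deterministic scalar UDFs; note json_array_slice is registered twice so SQLite can dispatch on argument count (offset only, or offset plus length). A standalone sketch of that registration pattern, using the stdlib json module:

    import json
    import sqlite3

    def slice2(arr, offset):
        return json.dumps(json.loads(arr)[offset:])

    def slice3(arr, offset, length):
        return json.dumps(json.loads(arr)[offset : offset + length])

    conn = sqlite3.connect(":memory:")
    # Same name, different arity: SQLite picks the matching registration
    conn.create_function("json_array_slice", 2, slice2, deterministic=True)
    conn.create_function("json_array_slice", 3, slice3, deterministic=True)

    print(conn.execute("SELECT json_array_slice('[1,2,3,4]', 1)").fetchone()[0])     # [2, 3, 4]
    print(conn.execute("SELECT json_array_slice('[1,2,3,4]', 1, 2)").fetchone()[0])  # [2, 3]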
@@ -285,7 +316,11 @@ def register_user_defined_sql_functions() -> None:
 
 
 def adapt_datetime(val: datetime) -> str:
-    if not (val.tzinfo is timezone.utc or val.tzname() == "UTC"):
+    is_utc_check = val.tzinfo is timezone.utc
+    tzname_check = val.tzname() == "UTC"
+    combined_check = is_utc_check or tzname_check
+
+    if not combined_check:
         try:
             val = val.astimezone(timezone.utc)
         except (OverflowError, ValueError, OSError):
@@ -295,6 +330,7 @@ def adapt_datetime(val: datetime) -> str:
             val = datetime.min.replace(tzinfo=timezone.utc)
         else:
             raise
+
     return val.replace(tzinfo=None).isoformat(" ")
 
 
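adapt_datetime stores every value as a naive UTC ISO string; the refactor only names the intermediate checks. A standalone demonstration of the normalization it performs:

    from datetime import datetime, timedelta, timezone

    val = datetime(2024, 1, 1, 12, 0, tzinfo=timezone(timedelta(hours=-5)))
    if not (val.tzinfo is timezone.utc or val.tzname() == "UTC"):
        val = val.astimezone(timezone.utc)  # normalize any aware datetime to UTC
    print(val.replace(tzinfo=None).isoformat(" "))  # 2024-01-01 17:00:00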
@@ -429,13 +465,42 @@ def compile_byte_hamming_distance(element, compiler, **kwargs):
 
 
 def py_json_array_length(arr):
-    return len(orjson.loads(arr))
+    return len(json.loads(arr))
 
 
 def py_json_array_contains(arr, value, is_json):
     if is_json:
-        value = orjson.loads(value)
-    return value in orjson.loads(arr)
+        value = json.loads(value)
+    return value in json.loads(arr)
+
+
+def py_json_array_get_element(val, idx):
+    arr = json.loads(val)
+    try:
+        return arr[idx]
+    except IndexError:
+        return None
+
+
+def py_json_array_slice(val, offset: int, length: int | None = None):
+    arr = json.loads(val)
+    try:
+        return json.dumps(
+            list(arr[offset : offset + length] if length is not None else arr[offset:]),
+            ensure_ascii=False,
+        )
+    except IndexError:
+        return None
+
+
+def py_json_array_join(val, sep: str):
+    return sep.join(json.loads(val))
+
+
+def compile_array_get_element(element, compiler, **kwargs):
+    return compiler.process(
+        func.json_array_get_element(*element.clauses.clauses), **kwargs
+    )
 
 
 def compile_array_length(element, compiler, **kwargs):
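Semantics of the new helpers: get_element returns NULL for an out-of-range index but accepts Python-style negative indices, slice re-serializes the sublist to JSON, and join concatenates string elements. A quick check, with the stdlib json module standing in for datachain.json (assumed API-compatible for loads/dumps):

    import json

    def get_element(val, idx):
        arr = json.loads(val)
        try:
            return arr[idx]
        except IndexError:
            return None

    print(get_element('["a", "b"]', 1))   # b
    print(get_element('["a", "b"]', -1))  # b     (negative indices work)
    print(get_element('["a", "b"]', 5))   # None  (out of range -> SQL NULL)
    print(json.dumps(json.loads('[1, 2, 3, 4]')[1:3]))  # [2, 3] - what slice returns
    print("-".join(json.loads('["x", "y"]')))           # x-y    - what join returns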
@@ -448,6 +513,14 @@ def compile_array_contains(element, compiler, **kwargs):
     )
 
 
+def compile_array_slice(element, compiler, **kwargs):
+    return compiler.process(func.json_array_slice(*element.clauses.clauses), **kwargs)
+
+
+def compile_array_join(element, compiler, **kwargs):
+    return compiler.process(func.json_array_join(*element.clauses.clauses), **kwargs)
+
+
 def compile_string_length(element, compiler, **kwargs):
     return compiler.process(func.length(*element.clauses.clauses), **kwargs)
 
@@ -544,7 +617,7 @@ def compile_collect(element, compiler, **kwargs):
 
 
 @cache
-def usearch_sqlite_path() -> Optional[str]:
+def usearch_sqlite_path() -> str | None:
     try:
         import usearch
     except ImportError:
datachain/sql/sqlite/types.py
CHANGED
@@ -1,8 +1,8 @@
 import sqlite3
 
-import orjson
 from sqlalchemy import types
 
+from datachain import json
 from datachain.sql.types import TypeConverter, TypeReadConverter
 
 try:
@@ -28,26 +28,21 @@ class Array(types.UserDefinedType):
 
 
 def adapt_array(arr):
-    return orjson.dumps(arr).decode("utf-8")
+    return json.dumps(arr, ensure_ascii=False)
 
 
 def adapt_dict(dct):
-    return orjson.dumps(dct).decode("utf-8")
+    return json.dumps(dct, ensure_ascii=False)
 
 
 def convert_array(arr):
-    return orjson.loads(arr)
+    return json.loads(arr)
 
 
 def adapt_np_array(arr):
-    def _json_serialize(obj):
-        if isinstance(obj, np.generic):
-            obj = obj.item()
-        return obj
-
-    return orjson.dumps(
-        arr, option=orjson.OPT_SERIALIZE_NUMPY, default=_json_serialize
-    ).decode("utf-8")
+    # Primarily needed for UDF numpy results (e.g. WDS)
+    # tolist() gives nested Python lists + native scalars; ujson.dumps handles NaN/Inf.
+    return json.dumps(arr.tolist(), ensure_ascii=False)
 
 
 def adapt_np_generic(val):
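numpy.ndarray.tolist() recursively converts an array into plain lists of native Python scalars, which is what lets a generic JSON encoder replace the orjson-specific numpy handling. A short check (assumes numpy is installed):

    import numpy as np

    arr = np.array([[1.5, 2.5], [3.5, 4.5]], dtype=np.float32)
    as_list = arr.tolist()
    print(as_list)              # [[1.5, 2.5], [3.5, 4.5]]
    print(type(as_list[0][0]))  # <class 'float'> - native float, not np.float32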
@@ -74,5 +69,5 @@ class SQLiteTypeConverter(TypeConverter):
 class SQLiteTypeReadConverter(TypeReadConverter):
     def array(self, value, item_type, dialect):
         if isinstance(value, str):
-            value = orjson.loads(value)
+            value = json.loads(value)
         return super().array(value, item_type, dialect)
datachain/sql/types.py
CHANGED
@@ -12,14 +12,15 @@ for sqlite we can use `sqlite.register_converter`
 ( https://docs.python.org/3/library/sqlite3.html#sqlite3.register_converter )
 """
 
+import numbers
 from datetime import datetime
 from types import MappingProxyType
 from typing import Any, Union
 
-import orjson
 import sqlalchemy as sa
 from sqlalchemy import TypeDecorator, types
 
+from datachain import json as jsonlib
 from datachain.lib.data_model import StandardType
 
 _registry: dict[str, "TypeConverter"] = {}
@@ -58,9 +59,14 @@ def converter(dialect) -> "TypeConverter":
     try:
         return registry[name]
     except KeyError:
-        raise ValueError(
-            f"No type converter registered for dialect: {dialect.name!r}"
-        ) from None
+        # Fall back to default converter if specific dialect not found
+        try:
+            return registry["default"]
+        except KeyError:
+            raise ValueError(
+                f"No type converter registered for dialect: {dialect.name!r} "
+                f"and no default converter available"
+            ) from None
 
 
 def read_converter(dialect) -> "TypeReadConverter":
@@ -68,9 +74,14 @@ def read_converter(dialect) -> "TypeReadConverter":
     try:
         return read_converter_registry[name]
     except KeyError:
-        raise ValueError(
-            f"No read type converter registered for dialect: {dialect.name!r}"
-        ) from None
+        # Fall back to default converter if specific dialect not found
+        try:
+            return read_converter_registry["default"]
+        except KeyError:
+            raise ValueError(
+                f"No read type converter registered for dialect: {dialect.name!r} "
+                f"and no default converter available"
+            ) from None
 
 
 def type_defaults(dialect) -> "TypeDefaults":
@@ -78,7 +89,14 @@ def type_defaults(dialect) -> "TypeDefaults":
     try:
         return type_defaults_registry[name]
    except KeyError:
-        return type_defaults_registry["default"]
+        # Fall back to default converter if specific dialect not found
+        try:
+            return type_defaults_registry["default"]
+        except KeyError:
+            raise ValueError(
+                f"No type defaults registered for dialect: {dialect.name!r} "
+                f"and no default converter available"
+            ) from None
 
 
 def db_defaults(dialect) -> "DBDefaults":
@@ -86,7 +104,14 @@ def db_defaults(dialect) -> "DBDefaults":
     try:
         return db_defaults_registry[name]
     except KeyError:
-        return db_defaults_registry["default"]
+        # Fall back to default converter if specific dialect not found
+        try:
+            return db_defaults_registry["default"]
+        except KeyError:
+            raise ValueError(
+                f"No DB defaults registered for dialect: {dialect.name!r} "
+                f"and no default converter available"
+            ) from None
 
 
 class SQLType(TypeDecorator):
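All four lookup helpers now share one pattern: try the exact dialect name, fall back to the "default" registry entry, and only then raise a ValueError naming the dialect. A condensed sketch of the shared pattern:

    def lookup(registry, dialect_name, what):
        try:
            return registry[dialect_name]
        except KeyError:
            try:
                return registry["default"]
            except KeyError:
                raise ValueError(
                    f"No {what} registered for dialect: {dialect_name!r} "
                    f"and no default available"
                ) from None

    registry = {"default": "fallback-converter"}
    print(lookup(registry, "postgresql", "type converter"))  # fallback-converter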
@@ -312,10 +337,28 @@ class Array(SQLType):
 
     @classmethod
     def from_dict(cls, d: dict[str, Any]) -> Union[type["SQLType"], "SQLType"]:
-        sub_t = NAME_TYPES_MAPPING[
-            d["item_type"]["type"]
-        ]
-        return cls(sub_t.from_dict(d["item_type"]))  # type: ignore [attr-defined]
+        try:
+            array_item = d["item_type"]
+        except KeyError as e:
+            raise ValueError("Array type must have 'item_type' field") from e
+
+        if not isinstance(array_item, dict):
+            raise TypeError("Array 'item_type' field must be a dictionary")
+
+        try:
+            item_type = array_item["type"]
+        except KeyError as e:
+            raise ValueError("Array 'item_type' must have 'type' field") from e
+
+        try:
+            sub_t = NAME_TYPES_MAPPING[item_type]
+        except KeyError as e:
+            raise ValueError(f"Array item type '{item_type}' is not supported") from e
+
+        try:
+            return cls(sub_t.from_dict(d["item_type"]))  # type: ignore [attr-defined]
+        except KeyError as e:
+            raise ValueError(f"Array item type '{item_type}' is not supported") from e
 
     @staticmethod
     def default_value(dialect):
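The serialized form validated here is a dict whose item_type sub-dict carries its own type key, resolved through NAME_TYPES_MAPPING. A hypothetical example of the expected shape (the exact type-name strings are assumptions, not taken from the package):

    serialized = {
        "type": "Array",                 # assumed outer tag
        "item_type": {"type": "Int64"},  # assumed item type name
    }
    # from_dict(serialized) now validates each step:
    #   serialized["item_type"]          -> must exist and be a dict
    #   serialized["item_type"]["type"]  -> must exist
    #   NAME_TYPES_MAPPING["Int64"]      -> must be a registered type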
@@ -328,7 +371,7 @@ class Array(SQLType):
     def on_read_convert(self, value, dialect):
         r = read_converter(dialect).array(value, self.item_type, dialect)
         if isinstance(self.item_type, JSON):
-            r = [orjson.loads(item) if isinstance(item, str) else item for item in r]
+            r = [jsonlib.loads(item) if isinstance(item, str) else item for item in r]
         return r
 
 
@@ -403,6 +446,18 @@ class TypeReadConverter:
         return value
 
     def boolean(self, value):
+        if value is None or isinstance(value, bool):
+            return value
+
+        if isinstance(value, numbers.Integral):
+            return bool(value)
+        if isinstance(value, str):
+            normalized = value.strip().lower()
+            if normalized in {"true", "t", "yes", "y", "1"}:
+                return True
+            if normalized in {"false", "f", "no", "n", "0"}:
+                return False
+
         return value
 
     def int(self, value):
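The new boolean reader coerces integers and common string spellings, and passes anything unrecognized through unchanged. A standalone copy of the logic for a quick behavior check:

    import numbers

    def read_boolean(value):
        if value is None or isinstance(value, bool):
            return value
        if isinstance(value, numbers.Integral):
            return bool(value)
        if isinstance(value, str):
            normalized = value.strip().lower()
            if normalized in {"true", "t", "yes", "y", "1"}:
                return True
            if normalized in {"false", "f", "no", "n", "0"}:
                return False
        return value  # unrecognized inputs pass through unchanged

    print([read_boolean(v) for v in (1, 0, " YES ", "f", "maybe", None)])
    # [True, False, True, False, 'maybe', None]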
@@ -442,7 +497,7 @@ class TypeReadConverter:
         if isinstance(value, str):
             if value == "":
                 return {}
-            return orjson.loads(value)
+            return jsonlib.loads(value)
         return value
 
     def datetime(self, value):