datachain 0.30.5__py3-none-any.whl → 0.39.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- datachain/__init__.py +4 -0
- datachain/asyn.py +11 -12
- datachain/cache.py +5 -5
- datachain/catalog/__init__.py +0 -2
- datachain/catalog/catalog.py +276 -354
- datachain/catalog/dependency.py +164 -0
- datachain/catalog/loader.py +8 -3
- datachain/checkpoint.py +43 -0
- datachain/cli/__init__.py +10 -17
- datachain/cli/commands/__init__.py +1 -8
- datachain/cli/commands/datasets.py +42 -27
- datachain/cli/commands/ls.py +15 -15
- datachain/cli/commands/show.py +2 -2
- datachain/cli/parser/__init__.py +3 -43
- datachain/cli/parser/job.py +1 -1
- datachain/cli/parser/utils.py +1 -2
- datachain/cli/utils.py +2 -15
- datachain/client/azure.py +2 -2
- datachain/client/fsspec.py +34 -23
- datachain/client/gcs.py +3 -3
- datachain/client/http.py +157 -0
- datachain/client/local.py +11 -7
- datachain/client/s3.py +3 -3
- datachain/config.py +4 -8
- datachain/data_storage/db_engine.py +12 -6
- datachain/data_storage/job.py +2 -0
- datachain/data_storage/metastore.py +716 -137
- datachain/data_storage/schema.py +20 -27
- datachain/data_storage/serializer.py +105 -15
- datachain/data_storage/sqlite.py +114 -114
- datachain/data_storage/warehouse.py +140 -48
- datachain/dataset.py +109 -89
- datachain/delta.py +117 -42
- datachain/diff/__init__.py +25 -33
- datachain/error.py +24 -0
- datachain/func/aggregate.py +9 -11
- datachain/func/array.py +12 -12
- datachain/func/base.py +7 -4
- datachain/func/conditional.py +9 -13
- datachain/func/func.py +63 -45
- datachain/func/numeric.py +5 -7
- datachain/func/string.py +2 -2
- datachain/hash_utils.py +123 -0
- datachain/job.py +11 -7
- datachain/json.py +138 -0
- datachain/lib/arrow.py +18 -15
- datachain/lib/audio.py +60 -59
- datachain/lib/clip.py +14 -13
- datachain/lib/convert/python_to_sql.py +6 -10
- datachain/lib/convert/values_to_tuples.py +151 -53
- datachain/lib/data_model.py +23 -19
- datachain/lib/dataset_info.py +7 -7
- datachain/lib/dc/__init__.py +2 -1
- datachain/lib/dc/csv.py +22 -26
- datachain/lib/dc/database.py +37 -34
- datachain/lib/dc/datachain.py +518 -324
- datachain/lib/dc/datasets.py +38 -30
- datachain/lib/dc/hf.py +16 -20
- datachain/lib/dc/json.py +17 -18
- datachain/lib/dc/listings.py +5 -8
- datachain/lib/dc/pandas.py +3 -6
- datachain/lib/dc/parquet.py +33 -21
- datachain/lib/dc/records.py +9 -13
- datachain/lib/dc/storage.py +103 -65
- datachain/lib/dc/storage_pattern.py +251 -0
- datachain/lib/dc/utils.py +17 -14
- datachain/lib/dc/values.py +3 -6
- datachain/lib/file.py +187 -50
- datachain/lib/hf.py +7 -5
- datachain/lib/image.py +13 -13
- datachain/lib/listing.py +5 -5
- datachain/lib/listing_info.py +1 -2
- datachain/lib/meta_formats.py +2 -3
- datachain/lib/model_store.py +20 -8
- datachain/lib/namespaces.py +59 -7
- datachain/lib/projects.py +51 -9
- datachain/lib/pytorch.py +31 -23
- datachain/lib/settings.py +188 -85
- datachain/lib/signal_schema.py +302 -64
- datachain/lib/text.py +8 -7
- datachain/lib/udf.py +103 -63
- datachain/lib/udf_signature.py +59 -34
- datachain/lib/utils.py +20 -0
- datachain/lib/video.py +3 -4
- datachain/lib/webdataset.py +31 -36
- datachain/lib/webdataset_laion.py +15 -16
- datachain/listing.py +12 -5
- datachain/model/bbox.py +3 -1
- datachain/namespace.py +22 -3
- datachain/node.py +6 -6
- datachain/nodes_thread_pool.py +0 -1
- datachain/plugins.py +24 -0
- datachain/project.py +4 -4
- datachain/query/batch.py +10 -12
- datachain/query/dataset.py +376 -194
- datachain/query/dispatch.py +112 -84
- datachain/query/metrics.py +3 -4
- datachain/query/params.py +2 -3
- datachain/query/queue.py +2 -1
- datachain/query/schema.py +7 -6
- datachain/query/session.py +190 -33
- datachain/query/udf.py +9 -6
- datachain/remote/studio.py +90 -53
- datachain/script_meta.py +12 -12
- datachain/sql/sqlite/base.py +37 -25
- datachain/sql/sqlite/types.py +1 -1
- datachain/sql/types.py +36 -5
- datachain/studio.py +49 -40
- datachain/toolkit/split.py +31 -10
- datachain/utils.py +39 -48
- {datachain-0.30.5.dist-info → datachain-0.39.0.dist-info}/METADATA +26 -38
- datachain-0.39.0.dist-info/RECORD +173 -0
- datachain/cli/commands/query.py +0 -54
- datachain/query/utils.py +0 -36
- datachain-0.30.5.dist-info/RECORD +0 -168
- {datachain-0.30.5.dist-info → datachain-0.39.0.dist-info}/WHEEL +0 -0
- {datachain-0.30.5.dist-info → datachain-0.39.0.dist-info}/entry_points.txt +0 -0
- {datachain-0.30.5.dist-info → datachain-0.39.0.dist-info}/licenses/LICENSE +0 -0
- {datachain-0.30.5.dist-info → datachain-0.39.0.dist-info}/top_level.txt +0 -0
datachain/lib/signal_schema.py
CHANGED
|
@@ -1,30 +1,31 @@
|
|
|
1
1
|
import copy
|
|
2
|
+
import hashlib
|
|
3
|
+
import logging
|
|
4
|
+
import math
|
|
5
|
+
import types
|
|
2
6
|
import warnings
|
|
3
|
-
from collections.abc import Iterator, Sequence
|
|
7
|
+
from collections.abc import Callable, Iterable, Iterator, Mapping, Sequence
|
|
4
8
|
from dataclasses import dataclass
|
|
5
9
|
from datetime import datetime
|
|
6
10
|
from inspect import isclass
|
|
7
|
-
from typing import (
|
|
11
|
+
from typing import (
|
|
8
12
|
IO,
|
|
9
13
|
TYPE_CHECKING,
|
|
10
14
|
Annotated,
|
|
11
15
|
Any,
|
|
12
|
-
Callable,
|
|
13
|
-
Dict,
|
|
14
16
|
Final,
|
|
15
|
-
List,
|
|
16
17
|
Literal,
|
|
17
|
-
Mapping,
|
|
18
18
|
Optional,
|
|
19
19
|
Union,
|
|
20
20
|
get_args,
|
|
21
21
|
get_origin,
|
|
22
22
|
)
|
|
23
23
|
|
|
24
|
-
from pydantic import BaseModel, Field, create_model
|
|
24
|
+
from pydantic import BaseModel, Field, ValidationError, create_model
|
|
25
25
|
from sqlalchemy import ColumnElement
|
|
26
26
|
from typing_extensions import Literal as LiteralEx
|
|
27
27
|
|
|
28
|
+
from datachain import json
|
|
28
29
|
from datachain.func import literal
|
|
29
30
|
from datachain.func.func import Func
|
|
30
31
|
from datachain.lib.convert.python_to_sql import python_to_sql
|
|
@@ -33,7 +34,7 @@ from datachain.lib.convert.unflatten import unflatten_to_json_pos
|
|
|
33
34
|
from datachain.lib.data_model import DataModel, DataType, DataValue
|
|
34
35
|
from datachain.lib.file import File
|
|
35
36
|
from datachain.lib.model_store import ModelStore
|
|
36
|
-
from datachain.lib.utils import DataChainParamsError
|
|
37
|
+
from datachain.lib.utils import DataChainColumnError, DataChainParamsError
|
|
37
38
|
from datachain.query.schema import DEFAULT_DELIMITER, C, Column, ColumnMeta
|
|
38
39
|
from datachain.sql.types import SQLType
|
|
39
40
|
|
|
@@ -41,6 +42,8 @@ if TYPE_CHECKING:
|
|
|
41
42
|
from datachain.catalog import Catalog
|
|
42
43
|
|
|
43
44
|
|
|
45
|
+
logger = logging.getLogger(__name__)
|
|
46
|
+
|
|
44
47
|
NAMES_TO_TYPES = {
|
|
45
48
|
"int": int,
|
|
46
49
|
"str": str,
|
|
@@ -69,7 +72,7 @@ class SignalSchemaWarning(RuntimeWarning):
|
|
|
69
72
|
|
|
70
73
|
|
|
71
74
|
class SignalResolvingError(SignalSchemaError):
|
|
72
|
-
def __init__(self, path:
|
|
75
|
+
def __init__(self, path: list[str] | None, msg: str):
|
|
73
76
|
name = " '" + ".".join(path) + "'" if path else ""
|
|
74
77
|
super().__init__(f"cannot resolve signal name{name}: {msg}")
|
|
75
78
|
|
|
@@ -79,6 +82,55 @@ class SetupError(SignalSchemaError):
|
|
|
79
82
|
super().__init__(f"cannot setup value '{name}': {msg}")
|
|
80
83
|
|
|
81
84
|
|
|
85
|
+
def generate_merge_root_mapping(
|
|
86
|
+
left_names: Iterable[str],
|
|
87
|
+
right_names: Sequence[str],
|
|
88
|
+
*,
|
|
89
|
+
extract_root: Callable[[str], str],
|
|
90
|
+
prefix: str,
|
|
91
|
+
) -> dict[str, str]:
|
|
92
|
+
"""Compute root renames for schema merges.
|
|
93
|
+
|
|
94
|
+
Returns a mapping from each right-side root to the target root name while
|
|
95
|
+
preserving the order in which right-side roots first appear. The mapping
|
|
96
|
+
avoids collisions with roots already present on the left side and among
|
|
97
|
+
the right-side roots themselves. When a conflict is detected, the
|
|
98
|
+
``prefix`` string is used to derive candidate root names until a unique
|
|
99
|
+
one is found.
|
|
100
|
+
"""
|
|
101
|
+
|
|
102
|
+
existing_roots = {extract_root(name) for name in left_names}
|
|
103
|
+
|
|
104
|
+
right_root_order: list[str] = []
|
|
105
|
+
right_roots: set[str] = set()
|
|
106
|
+
for name in right_names:
|
|
107
|
+
root = extract_root(name)
|
|
108
|
+
if root not in right_roots:
|
|
109
|
+
right_roots.add(root)
|
|
110
|
+
right_root_order.append(root)
|
|
111
|
+
|
|
112
|
+
used_roots = set(existing_roots)
|
|
113
|
+
root_mapping: dict[str, str] = {}
|
|
114
|
+
|
|
115
|
+
for root in right_root_order:
|
|
116
|
+
if root not in used_roots:
|
|
117
|
+
root_mapping[root] = root
|
|
118
|
+
used_roots.add(root)
|
|
119
|
+
continue
|
|
120
|
+
|
|
121
|
+
suffix = 0
|
|
122
|
+
while True:
|
|
123
|
+
base = prefix if root in prefix else f"{prefix}{root}"
|
|
124
|
+
candidate_root = base if suffix == 0 else f"{base}_{suffix}"
|
|
125
|
+
if candidate_root not in used_roots and candidate_root not in right_roots:
|
|
126
|
+
root_mapping[root] = candidate_root
|
|
127
|
+
used_roots.add(candidate_root)
|
|
128
|
+
break
|
|
129
|
+
suffix += 1
|
|
130
|
+
|
|
131
|
+
return root_mapping
|
|
132
|
+
|
|
133
|
+
|
|
82
134
|
class SignalResolvingTypeError(SignalResolvingError):
|
|
83
135
|
def __init__(self, method: str, field):
|
|
84
136
|
super().__init__(
|
|
@@ -89,7 +141,7 @@ class SignalResolvingTypeError(SignalResolvingError):
|
|
|
89
141
|
|
|
90
142
|
|
|
91
143
|
class SignalRemoveError(SignalSchemaError):
|
|
92
|
-
def __init__(self, path:
|
|
144
|
+
def __init__(self, path: list[str] | None, msg: str):
|
|
93
145
|
name = " '" + ".".join(path) + "'" if path else ""
|
|
94
146
|
super().__init__(f"cannot remove signal name{name}: {msg}")
|
|
95
147
|
|
|
@@ -98,8 +150,8 @@ class CustomType(BaseModel):
|
|
|
98
150
|
schema_version: int = Field(ge=1, le=2, strict=True)
|
|
99
151
|
name: str
|
|
100
152
|
fields: dict[str, str]
|
|
101
|
-
bases: list[tuple[str, str,
|
|
102
|
-
hidden_fields:
|
|
153
|
+
bases: list[tuple[str, str, str | None]]
|
|
154
|
+
hidden_fields: list[str] | None = None
|
|
103
155
|
|
|
104
156
|
@classmethod
|
|
105
157
|
def deserialize(cls, data: dict[str, Any], type_name: str) -> "CustomType":
|
|
@@ -119,8 +171,8 @@ class CustomType(BaseModel):
|
|
|
119
171
|
|
|
120
172
|
def create_feature_model(
|
|
121
173
|
name: str,
|
|
122
|
-
fields: Mapping[str,
|
|
123
|
-
base:
|
|
174
|
+
fields: Mapping[str, type | tuple[type, Any] | None],
|
|
175
|
+
base: type | None = None,
|
|
124
176
|
) -> type[BaseModel]:
|
|
125
177
|
"""
|
|
126
178
|
This gets or returns a dynamic feature model for use in restoring a model
|
|
@@ -137,7 +189,7 @@ def create_feature_model(
|
|
|
137
189
|
**{
|
|
138
190
|
field_name: anno if isinstance(anno, tuple) else (anno, None)
|
|
139
191
|
for field_name, anno in fields.items()
|
|
140
|
-
},
|
|
192
|
+
}, # type: ignore[arg-type]
|
|
141
193
|
)
|
|
142
194
|
|
|
143
195
|
|
|
@@ -146,12 +198,12 @@ class SignalSchema:
|
|
|
146
198
|
values: dict[str, DataType]
|
|
147
199
|
tree: dict[str, Any]
|
|
148
200
|
setup_func: dict[str, Callable]
|
|
149
|
-
setup_values:
|
|
201
|
+
setup_values: dict[str, Any] | None
|
|
150
202
|
|
|
151
203
|
def __init__(
|
|
152
204
|
self,
|
|
153
205
|
values: dict[str, DataType],
|
|
154
|
-
setup:
|
|
206
|
+
setup: dict[str, Callable] | None = None,
|
|
155
207
|
):
|
|
156
208
|
self.values = values
|
|
157
209
|
self.tree = self._build_tree(values)
|
|
@@ -190,8 +242,8 @@ class SignalSchema:
|
|
|
190
242
|
return SignalSchema(signals)
|
|
191
243
|
|
|
192
244
|
@staticmethod
|
|
193
|
-
def _get_bases(fr: type) -> list[tuple[str, str,
|
|
194
|
-
bases: list[tuple[str, str,
|
|
245
|
+
def _get_bases(fr: type) -> list[tuple[str, str, str | None]]:
|
|
246
|
+
bases: list[tuple[str, str, str | None]] = []
|
|
195
247
|
for base in fr.__mro__:
|
|
196
248
|
model_store_name = (
|
|
197
249
|
ModelStore.get_name(base) if issubclass(base, DataModel) else None
|
|
@@ -257,6 +309,11 @@ class SignalSchema:
|
|
|
257
309
|
signals["_custom_types"] = custom_types
|
|
258
310
|
return signals
|
|
259
311
|
|
|
312
|
+
def hash(self) -> str:
|
|
313
|
+
"""Create SHA hash of this schema"""
|
|
314
|
+
json_str = json.dumps(self.serialize(), sort_keys=True, separators=(",", ":"))
|
|
315
|
+
return hashlib.sha256(json_str.encode("utf-8")).hexdigest()
|
|
316
|
+
|
|
260
317
|
@staticmethod
|
|
261
318
|
def _split_subtypes(type_name: str) -> list[str]:
|
|
262
319
|
"""This splits a list of subtypes, including proper square bracket handling."""
|
|
@@ -283,7 +340,7 @@ class SignalSchema:
|
|
|
283
340
|
@staticmethod
|
|
284
341
|
def _deserialize_custom_type(
|
|
285
342
|
type_name: str, custom_types: dict[str, Any]
|
|
286
|
-
) ->
|
|
343
|
+
) -> type | None:
|
|
287
344
|
"""Given a type name like MyType@v1 gets a type from ModelStore or recreates
|
|
288
345
|
it based on the information from the custom types dict that includes fields and
|
|
289
346
|
bases."""
|
|
@@ -316,7 +373,7 @@ class SignalSchema:
|
|
|
316
373
|
return None
|
|
317
374
|
|
|
318
375
|
@staticmethod
|
|
319
|
-
def _resolve_type(type_name: str, custom_types: dict[str, Any]) ->
|
|
376
|
+
def _resolve_type(type_name: str, custom_types: dict[str, Any]) -> type | None:
|
|
320
377
|
"""Convert a string-based type back into a python type."""
|
|
321
378
|
type_name = type_name.strip()
|
|
322
379
|
if not type_name:
|
|
@@ -325,7 +382,7 @@ class SignalSchema:
|
|
|
325
382
|
return None
|
|
326
383
|
|
|
327
384
|
bracket_idx = type_name.find("[")
|
|
328
|
-
subtypes:
|
|
385
|
+
subtypes: tuple[type | None, ...] | None = None
|
|
329
386
|
if bracket_idx > -1:
|
|
330
387
|
if bracket_idx == 0:
|
|
331
388
|
raise ValueError("Type cannot start with '['")
|
|
@@ -456,13 +513,33 @@ class SignalSchema:
|
|
|
456
513
|
objs.append(self.setup_values.get(name))
|
|
457
514
|
elif (fr := ModelStore.to_pydantic(fr_type)) is not None:
|
|
458
515
|
j, pos = unflatten_to_json_pos(fr, row, pos)
|
|
459
|
-
|
|
516
|
+
try:
|
|
517
|
+
obj = fr(**j)
|
|
518
|
+
except ValidationError as e:
|
|
519
|
+
if self._all_values_none(j):
|
|
520
|
+
logger.debug("Failed to create input for %s: %s", name, e)
|
|
521
|
+
obj = None
|
|
522
|
+
else:
|
|
523
|
+
raise
|
|
524
|
+
objs.append(obj)
|
|
460
525
|
else:
|
|
461
526
|
objs.append(row[pos])
|
|
462
527
|
pos += 1
|
|
463
528
|
return objs
|
|
464
529
|
|
|
465
|
-
|
|
530
|
+
@staticmethod
|
|
531
|
+
def _all_values_none(value: Any) -> bool:
|
|
532
|
+
if isinstance(value, dict):
|
|
533
|
+
return all(SignalSchema._all_values_none(v) for v in value.values())
|
|
534
|
+
if isinstance(value, (list, tuple, set)):
|
|
535
|
+
return all(SignalSchema._all_values_none(v) for v in value)
|
|
536
|
+
if isinstance(value, float):
|
|
537
|
+
# NaN is used to represent NULL and NaN float values in datachain
|
|
538
|
+
# Since SQLite does not have a separate NULL type, we need to check for NaN
|
|
539
|
+
return math.isnan(value) or value is None
|
|
540
|
+
return value is None
|
|
541
|
+
|
|
542
|
+
def get_file_signal(self) -> str | None:
|
|
466
543
|
for signal_name, signal_type in self.values.items():
|
|
467
544
|
if (fr := ModelStore.to_pydantic(signal_type)) is not None and issubclass(
|
|
468
545
|
fr, File
|
|
@@ -472,8 +549,8 @@ class SignalSchema:
|
|
|
472
549
|
|
|
473
550
|
def slice(
|
|
474
551
|
self,
|
|
475
|
-
params: dict[str,
|
|
476
|
-
setup:
|
|
552
|
+
params: dict[str, DataType | Any],
|
|
553
|
+
setup: dict[str, Callable] | None = None,
|
|
477
554
|
is_batch: bool = False,
|
|
478
555
|
) -> "SignalSchema":
|
|
479
556
|
"""
|
|
@@ -497,9 +574,13 @@ class SignalSchema:
|
|
|
497
574
|
schema_origin = get_origin(schema_type)
|
|
498
575
|
param_origin = get_origin(param_type)
|
|
499
576
|
|
|
500
|
-
if schema_origin
|
|
577
|
+
if schema_origin in (Union, types.UnionType) and type(None) in get_args(
|
|
578
|
+
schema_type
|
|
579
|
+
):
|
|
501
580
|
schema_type = get_args(schema_type)[0]
|
|
502
|
-
if param_origin
|
|
581
|
+
if param_origin in (Union, types.UnionType) and type(None) in get_args(
|
|
582
|
+
param_type
|
|
583
|
+
):
|
|
503
584
|
param_type = get_args(param_type)[0]
|
|
504
585
|
|
|
505
586
|
if is_batch:
|
|
@@ -535,15 +616,90 @@ class SignalSchema:
|
|
|
535
616
|
pos = 0
|
|
536
617
|
for fr_cls in self.values.values():
|
|
537
618
|
if (fr := ModelStore.to_pydantic(fr_cls)) is None:
|
|
538
|
-
|
|
619
|
+
value = row[pos]
|
|
539
620
|
pos += 1
|
|
621
|
+
converted = self._convert_feature_value(fr_cls, value, catalog, cache)
|
|
622
|
+
res.append(converted)
|
|
540
623
|
else:
|
|
541
624
|
json, pos = unflatten_to_json_pos(fr, row, pos) # type: ignore[union-attr]
|
|
542
|
-
|
|
543
|
-
|
|
625
|
+
try:
|
|
626
|
+
obj = fr(**json)
|
|
627
|
+
SignalSchema._set_file_stream(obj, catalog, cache)
|
|
628
|
+
except ValidationError as e:
|
|
629
|
+
if self._all_values_none(json):
|
|
630
|
+
logger.debug("Failed to create feature for %s: %s", fr_cls, e)
|
|
631
|
+
obj = None
|
|
632
|
+
else:
|
|
633
|
+
raise
|
|
544
634
|
res.append(obj)
|
|
545
635
|
return res
|
|
546
636
|
|
|
637
|
+
def _convert_feature_value(
|
|
638
|
+
self,
|
|
639
|
+
annotation: DataType,
|
|
640
|
+
value: Any,
|
|
641
|
+
catalog: "Catalog",
|
|
642
|
+
cache: bool,
|
|
643
|
+
) -> Any:
|
|
644
|
+
"""Convert raw DB value into declared annotation if needed."""
|
|
645
|
+
if value is None:
|
|
646
|
+
return None
|
|
647
|
+
|
|
648
|
+
result = value
|
|
649
|
+
origin = get_origin(annotation)
|
|
650
|
+
|
|
651
|
+
if origin in (Union, types.UnionType):
|
|
652
|
+
non_none_args = [
|
|
653
|
+
arg for arg in get_args(annotation) if arg is not type(None)
|
|
654
|
+
]
|
|
655
|
+
if len(non_none_args) == 1:
|
|
656
|
+
annotation = non_none_args[0]
|
|
657
|
+
origin = get_origin(annotation)
|
|
658
|
+
else:
|
|
659
|
+
return result
|
|
660
|
+
|
|
661
|
+
if ModelStore.is_pydantic(annotation):
|
|
662
|
+
if isinstance(value, annotation):
|
|
663
|
+
obj = value
|
|
664
|
+
elif isinstance(value, Mapping):
|
|
665
|
+
obj = annotation(**value)
|
|
666
|
+
else:
|
|
667
|
+
return result
|
|
668
|
+
assert isinstance(obj, BaseModel)
|
|
669
|
+
SignalSchema._set_file_stream(obj, catalog, cache)
|
|
670
|
+
result = obj
|
|
671
|
+
elif origin is list:
|
|
672
|
+
args = get_args(annotation)
|
|
673
|
+
if args and isinstance(value, (list, tuple)):
|
|
674
|
+
item_type = args[0]
|
|
675
|
+
result = [
|
|
676
|
+
self._convert_feature_value(item_type, item, catalog, cache)
|
|
677
|
+
if item is not None
|
|
678
|
+
else None
|
|
679
|
+
for item in value
|
|
680
|
+
]
|
|
681
|
+
elif origin is dict:
|
|
682
|
+
args = get_args(annotation)
|
|
683
|
+
if len(args) == 2 and isinstance(value, dict):
|
|
684
|
+
key_type, val_type = args
|
|
685
|
+
result = {}
|
|
686
|
+
for key, val in value.items():
|
|
687
|
+
if key_type is str:
|
|
688
|
+
converted_key = key
|
|
689
|
+
else:
|
|
690
|
+
loaded_key = json.loads(key)
|
|
691
|
+
converted_key = self._convert_feature_value(
|
|
692
|
+
key_type, loaded_key, catalog, cache
|
|
693
|
+
)
|
|
694
|
+
converted_val = (
|
|
695
|
+
self._convert_feature_value(val_type, val, catalog, cache)
|
|
696
|
+
if val_type is not Any
|
|
697
|
+
else val
|
|
698
|
+
)
|
|
699
|
+
result[converted_key] = converted_val
|
|
700
|
+
|
|
701
|
+
return result
|
|
702
|
+
|
|
547
703
|
@staticmethod
|
|
548
704
|
def _set_file_stream(
|
|
549
705
|
obj: BaseModel, catalog: "Catalog", cache: bool = False
|
|
@@ -572,8 +728,8 @@ class SignalSchema:
|
|
|
572
728
|
raise SignalResolvingError([col_name], "is not found")
|
|
573
729
|
|
|
574
730
|
def db_signals(
|
|
575
|
-
self, name:
|
|
576
|
-
) ->
|
|
731
|
+
self, name: str | None = None, as_columns=False, include_hidden: bool = True
|
|
732
|
+
) -> list[str] | list[Column]:
|
|
577
733
|
"""
|
|
578
734
|
Returns DB columns as strings or Column objects with proper types
|
|
579
735
|
Optionally, it can filter results by specific object, returning only his signals
|
|
@@ -600,6 +756,35 @@ class SignalSchema:
|
|
|
600
756
|
|
|
601
757
|
return signals # type: ignore[return-value]
|
|
602
758
|
|
|
759
|
+
def user_signals(
|
|
760
|
+
self,
|
|
761
|
+
*,
|
|
762
|
+
include_hidden: bool = True,
|
|
763
|
+
include_sys: bool = False,
|
|
764
|
+
) -> list[str]:
|
|
765
|
+
return [
|
|
766
|
+
".".join(path)
|
|
767
|
+
for path, _, has_subtree, _ in self.get_flat_tree(
|
|
768
|
+
include_hidden=include_hidden, include_sys=include_sys
|
|
769
|
+
)
|
|
770
|
+
if not has_subtree
|
|
771
|
+
]
|
|
772
|
+
|
|
773
|
+
def compare_signals(
|
|
774
|
+
self,
|
|
775
|
+
other: "SignalSchema",
|
|
776
|
+
*,
|
|
777
|
+
include_hidden: bool = True,
|
|
778
|
+
include_sys: bool = False,
|
|
779
|
+
) -> tuple[set[str], set[str]]:
|
|
780
|
+
left = set(
|
|
781
|
+
self.user_signals(include_hidden=include_hidden, include_sys=include_sys)
|
|
782
|
+
)
|
|
783
|
+
right = set(
|
|
784
|
+
other.user_signals(include_hidden=include_hidden, include_sys=include_sys)
|
|
785
|
+
)
|
|
786
|
+
return left - right, right - left
|
|
787
|
+
|
|
603
788
|
def resolve(self, *names: str) -> "SignalSchema":
|
|
604
789
|
schema = {}
|
|
605
790
|
for field in names:
|
|
@@ -733,12 +918,30 @@ class SignalSchema:
|
|
|
733
918
|
right_schema: "SignalSchema",
|
|
734
919
|
rname: str,
|
|
735
920
|
) -> "SignalSchema":
|
|
736
|
-
|
|
737
|
-
|
|
738
|
-
|
|
739
|
-
|
|
921
|
+
merged_values = dict(self.values)
|
|
922
|
+
|
|
923
|
+
right_names = list(right_schema.values.keys())
|
|
924
|
+
root_mapping = generate_merge_root_mapping(
|
|
925
|
+
self.values.keys(),
|
|
926
|
+
right_names,
|
|
927
|
+
extract_root=self._extract_root,
|
|
928
|
+
prefix=rname,
|
|
929
|
+
)
|
|
930
|
+
|
|
931
|
+
for key, type_ in right_schema.values.items():
|
|
932
|
+
root = self._extract_root(key)
|
|
933
|
+
tail = key.partition(".")[2]
|
|
934
|
+
mapped_root = root_mapping[root]
|
|
935
|
+
new_name = mapped_root if not tail else f"{mapped_root}.{tail}"
|
|
936
|
+
merged_values[new_name] = type_
|
|
740
937
|
|
|
741
|
-
return SignalSchema(
|
|
938
|
+
return SignalSchema(merged_values)
|
|
939
|
+
|
|
940
|
+
@staticmethod
|
|
941
|
+
def _extract_root(name: str) -> str:
|
|
942
|
+
if "." in name:
|
|
943
|
+
return name.split(".", 1)[0]
|
|
944
|
+
return name
|
|
742
945
|
|
|
743
946
|
def append(self, right: "SignalSchema") -> "SignalSchema":
|
|
744
947
|
missing_schema = {
|
|
@@ -758,29 +961,38 @@ class SignalSchema:
|
|
|
758
961
|
return create_model(
|
|
759
962
|
name,
|
|
760
963
|
__base__=(DataModel,), # type: ignore[call-overload]
|
|
761
|
-
**fields,
|
|
964
|
+
**fields, # type: ignore[arg-type]
|
|
762
965
|
)
|
|
763
966
|
|
|
764
967
|
@staticmethod
|
|
765
968
|
def _build_tree(
|
|
766
969
|
values: dict[str, DataType],
|
|
767
|
-
) -> dict[str, tuple[DataType,
|
|
970
|
+
) -> dict[str, tuple[DataType, dict | None]]:
|
|
768
971
|
return {
|
|
769
972
|
name: (val, SignalSchema._build_tree_for_type(val))
|
|
770
973
|
for name, val in values.items()
|
|
771
974
|
}
|
|
772
975
|
|
|
773
976
|
def get_flat_tree(
|
|
774
|
-
self,
|
|
977
|
+
self,
|
|
978
|
+
include_hidden: bool = True,
|
|
979
|
+
include_sys: bool = True,
|
|
775
980
|
) -> Iterator[tuple[list[str], DataType, bool, int]]:
|
|
776
|
-
yield from self._get_flat_tree(self.tree, [], 0, include_hidden)
|
|
981
|
+
yield from self._get_flat_tree(self.tree, [], 0, include_hidden, include_sys)
|
|
777
982
|
|
|
778
983
|
def _get_flat_tree(
|
|
779
|
-
self,
|
|
984
|
+
self,
|
|
985
|
+
tree: dict,
|
|
986
|
+
prefix: list[str],
|
|
987
|
+
depth: int,
|
|
988
|
+
include_hidden: bool,
|
|
989
|
+
include_sys: bool,
|
|
780
990
|
) -> Iterator[tuple[list[str], DataType, bool, int]]:
|
|
781
991
|
for name, (type_, substree) in tree.items():
|
|
782
992
|
suffix = name.split(".")
|
|
783
993
|
new_prefix = prefix + suffix
|
|
994
|
+
if not include_sys and new_prefix and new_prefix[0] == "sys":
|
|
995
|
+
continue
|
|
784
996
|
hidden_fields = getattr(type_, "_hidden_fields", None)
|
|
785
997
|
if hidden_fields and substree and not include_hidden:
|
|
786
998
|
substree = {
|
|
@@ -793,10 +1005,10 @@ class SignalSchema:
|
|
|
793
1005
|
yield new_prefix, type_, has_subtree, depth
|
|
794
1006
|
if substree is not None:
|
|
795
1007
|
yield from self._get_flat_tree(
|
|
796
|
-
substree, new_prefix, depth + 1, include_hidden
|
|
1008
|
+
substree, new_prefix, depth + 1, include_hidden, include_sys
|
|
797
1009
|
)
|
|
798
1010
|
|
|
799
|
-
def print_tree(self, indent: int = 2, start_at: int = 0, file:
|
|
1011
|
+
def print_tree(self, indent: int = 2, start_at: int = 0, file: IO | None = None):
|
|
800
1012
|
for path, type_, _, depth in self.get_flat_tree():
|
|
801
1013
|
total_indent = start_at + depth * indent
|
|
802
1014
|
col_name = " " * total_indent + path[-1]
|
|
@@ -826,7 +1038,28 @@ class SignalSchema:
|
|
|
826
1038
|
], max_length
|
|
827
1039
|
|
|
828
1040
|
def __or__(self, other):
|
|
829
|
-
|
|
1041
|
+
new_values = dict(self.values)
|
|
1042
|
+
|
|
1043
|
+
for name, new_type in other.values.items():
|
|
1044
|
+
if name in new_values:
|
|
1045
|
+
current_type = new_values[name]
|
|
1046
|
+
if current_type != new_type:
|
|
1047
|
+
raise DataChainColumnError(
|
|
1048
|
+
name,
|
|
1049
|
+
"signal already exists with a different type",
|
|
1050
|
+
)
|
|
1051
|
+
continue
|
|
1052
|
+
|
|
1053
|
+
root = self._extract_root(name)
|
|
1054
|
+
if any(self._extract_root(existing) == root for existing in new_values):
|
|
1055
|
+
raise DataChainColumnError(
|
|
1056
|
+
name,
|
|
1057
|
+
"signal root already exists in schema",
|
|
1058
|
+
)
|
|
1059
|
+
|
|
1060
|
+
new_values[name] = new_type
|
|
1061
|
+
|
|
1062
|
+
return self.__class__(new_values)
|
|
830
1063
|
|
|
831
1064
|
def __contains__(self, name: str):
|
|
832
1065
|
return name in self.values
|
|
@@ -835,15 +1068,20 @@ class SignalSchema:
|
|
|
835
1068
|
return self.values.pop(name)
|
|
836
1069
|
|
|
837
1070
|
@staticmethod
|
|
838
|
-
def _type_to_str(type_:
|
|
1071
|
+
def _type_to_str(type_: type | None, subtypes: list | None = None) -> str: # noqa: C901, PLR0911
|
|
839
1072
|
"""Convert a type to a string-based representation."""
|
|
840
1073
|
if type_ is None:
|
|
841
1074
|
return "NoneType"
|
|
842
1075
|
|
|
843
1076
|
origin = get_origin(type_)
|
|
844
1077
|
|
|
845
|
-
if origin
|
|
1078
|
+
if origin in (Union, types.UnionType):
|
|
846
1079
|
args = get_args(type_)
|
|
1080
|
+
if len(args) == 2 and type(None) in args:
|
|
1081
|
+
# This is an Optional type.
|
|
1082
|
+
non_none_type = args[0] if args[1] is type(None) else args[1]
|
|
1083
|
+
type_str = SignalSchema._type_to_str(non_none_type, subtypes)
|
|
1084
|
+
return f"Optional[{type_str}]"
|
|
847
1085
|
formatted_types = ", ".join(
|
|
848
1086
|
SignalSchema._type_to_str(arg, subtypes) for arg in args
|
|
849
1087
|
)
|
|
@@ -852,21 +1090,21 @@ class SignalSchema:
|
|
|
852
1090
|
args = get_args(type_)
|
|
853
1091
|
type_str = SignalSchema._type_to_str(args[0], subtypes)
|
|
854
1092
|
return f"Optional[{type_str}]"
|
|
855
|
-
if origin
|
|
1093
|
+
if origin is list:
|
|
856
1094
|
args = get_args(type_)
|
|
1095
|
+
if len(args) == 0:
|
|
1096
|
+
return "list"
|
|
857
1097
|
type_str = SignalSchema._type_to_str(args[0], subtypes)
|
|
858
1098
|
return f"list[{type_str}]"
|
|
859
|
-
if origin
|
|
1099
|
+
if origin is dict:
|
|
860
1100
|
args = get_args(type_)
|
|
861
|
-
|
|
862
|
-
|
|
863
|
-
)
|
|
864
|
-
|
|
865
|
-
f",
|
|
866
|
-
|
|
867
|
-
|
|
868
|
-
)
|
|
869
|
-
return f"dict[{type_str}{vals}]"
|
|
1101
|
+
if len(args) == 0:
|
|
1102
|
+
return "dict"
|
|
1103
|
+
key_type = SignalSchema._type_to_str(args[0], subtypes)
|
|
1104
|
+
if len(args) == 1:
|
|
1105
|
+
return f"dict[{key_type}, Any]"
|
|
1106
|
+
val_type = SignalSchema._type_to_str(args[1], subtypes)
|
|
1107
|
+
return f"dict[{key_type}, {val_type}]"
|
|
870
1108
|
if origin == Annotated:
|
|
871
1109
|
args = get_args(type_)
|
|
872
1110
|
return SignalSchema._type_to_str(args[0], subtypes)
|
|
@@ -880,7 +1118,7 @@ class SignalSchema:
|
|
|
880
1118
|
# Include this type in the list of all subtypes, if requested.
|
|
881
1119
|
subtypes.append(type_)
|
|
882
1120
|
if not hasattr(type_, "__name__"):
|
|
883
|
-
# This can happen for some third-party or custom types
|
|
1121
|
+
# This can happen for some third-party or custom types
|
|
884
1122
|
warnings.warn(
|
|
885
1123
|
f"Unable to determine name of type '{type_}'.",
|
|
886
1124
|
SignalSchemaWarning,
|
|
@@ -895,7 +1133,7 @@ class SignalSchema:
|
|
|
895
1133
|
@staticmethod
|
|
896
1134
|
def _build_tree_for_type(
|
|
897
1135
|
model: DataType,
|
|
898
|
-
) ->
|
|
1136
|
+
) -> dict[str, tuple[DataType, dict | None]] | None:
|
|
899
1137
|
if (fr := ModelStore.to_pydantic(model)) is not None:
|
|
900
1138
|
return SignalSchema._build_tree_for_model(fr)
|
|
901
1139
|
return None
|
|
@@ -903,8 +1141,8 @@ class SignalSchema:
|
|
|
903
1141
|
@staticmethod
|
|
904
1142
|
def _build_tree_for_model(
|
|
905
1143
|
model: type[BaseModel],
|
|
906
|
-
) ->
|
|
907
|
-
res: dict[str, tuple[DataType,
|
|
1144
|
+
) -> dict[str, tuple[DataType, dict | None]] | None:
|
|
1145
|
+
res: dict[str, tuple[DataType, dict | None]] = {}
|
|
908
1146
|
|
|
909
1147
|
for name, f_info in model.model_fields.items():
|
|
910
1148
|
anno = f_info.annotation
|
|
@@ -953,7 +1191,7 @@ class SignalSchema:
|
|
|
953
1191
|
schema: dict[str, Any] = {}
|
|
954
1192
|
schema_custom_types: dict[str, CustomType] = {}
|
|
955
1193
|
|
|
956
|
-
data_model_bases:
|
|
1194
|
+
data_model_bases: list[tuple[str, str, str | None]] | None = None
|
|
957
1195
|
|
|
958
1196
|
signal_partials: dict[str, str] = {}
|
|
959
1197
|
partial_versions: dict[str, int] = {}
|
datachain/lib/text.py
CHANGED
|
@@ -1,16 +1,17 @@
|
|
|
1
|
-
from
|
|
1
|
+
from collections.abc import Callable
|
|
2
|
+
from typing import Any
|
|
2
3
|
|
|
3
4
|
import torch
|
|
4
5
|
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
|
|
5
6
|
|
|
6
7
|
|
|
7
8
|
def convert_text(
|
|
8
|
-
text:
|
|
9
|
-
tokenizer:
|
|
10
|
-
tokenizer_kwargs:
|
|
11
|
-
encoder:
|
|
12
|
-
device:
|
|
13
|
-
) ->
|
|
9
|
+
text: str | list[str],
|
|
10
|
+
tokenizer: Callable | None = None,
|
|
11
|
+
tokenizer_kwargs: dict[str, Any] | None = None,
|
|
12
|
+
encoder: Callable | None = None,
|
|
13
|
+
device: str | torch.device | None = None,
|
|
14
|
+
) -> str | list[str] | torch.Tensor:
|
|
14
15
|
"""
|
|
15
16
|
Tokenize and otherwise transform text.
|
|
16
17
|
|