datachain 0.14.2__py3-none-any.whl → 0.39.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- datachain/__init__.py +20 -0
- datachain/asyn.py +11 -12
- datachain/cache.py +7 -7
- datachain/catalog/__init__.py +2 -2
- datachain/catalog/catalog.py +621 -507
- datachain/catalog/dependency.py +164 -0
- datachain/catalog/loader.py +28 -18
- datachain/checkpoint.py +43 -0
- datachain/cli/__init__.py +24 -33
- datachain/cli/commands/__init__.py +1 -8
- datachain/cli/commands/datasets.py +83 -52
- datachain/cli/commands/ls.py +17 -17
- datachain/cli/commands/show.py +4 -4
- datachain/cli/parser/__init__.py +8 -74
- datachain/cli/parser/job.py +95 -3
- datachain/cli/parser/studio.py +11 -4
- datachain/cli/parser/utils.py +1 -2
- datachain/cli/utils.py +2 -15
- datachain/client/azure.py +4 -4
- datachain/client/fsspec.py +45 -28
- datachain/client/gcs.py +6 -6
- datachain/client/hf.py +29 -2
- datachain/client/http.py +157 -0
- datachain/client/local.py +15 -11
- datachain/client/s3.py +17 -9
- datachain/config.py +4 -8
- datachain/data_storage/db_engine.py +12 -6
- datachain/data_storage/job.py +5 -1
- datachain/data_storage/metastore.py +1252 -186
- datachain/data_storage/schema.py +58 -45
- datachain/data_storage/serializer.py +105 -15
- datachain/data_storage/sqlite.py +286 -127
- datachain/data_storage/warehouse.py +250 -113
- datachain/dataset.py +353 -148
- datachain/delta.py +391 -0
- datachain/diff/__init__.py +27 -29
- datachain/error.py +60 -0
- datachain/func/__init__.py +2 -1
- datachain/func/aggregate.py +66 -42
- datachain/func/array.py +242 -38
- datachain/func/base.py +7 -4
- datachain/func/conditional.py +110 -60
- datachain/func/func.py +96 -45
- datachain/func/numeric.py +55 -38
- datachain/func/path.py +32 -20
- datachain/func/random.py +2 -2
- datachain/func/string.py +67 -37
- datachain/func/window.py +7 -8
- datachain/hash_utils.py +123 -0
- datachain/job.py +11 -7
- datachain/json.py +138 -0
- datachain/lib/arrow.py +58 -22
- datachain/lib/audio.py +245 -0
- datachain/lib/clip.py +14 -13
- datachain/lib/convert/flatten.py +5 -3
- datachain/lib/convert/python_to_sql.py +6 -10
- datachain/lib/convert/sql_to_python.py +8 -0
- datachain/lib/convert/values_to_tuples.py +156 -51
- datachain/lib/data_model.py +42 -20
- datachain/lib/dataset_info.py +36 -8
- datachain/lib/dc/__init__.py +8 -2
- datachain/lib/dc/csv.py +25 -28
- datachain/lib/dc/database.py +398 -0
- datachain/lib/dc/datachain.py +1289 -425
- datachain/lib/dc/datasets.py +320 -38
- datachain/lib/dc/hf.py +38 -24
- datachain/lib/dc/json.py +29 -32
- datachain/lib/dc/listings.py +112 -8
- datachain/lib/dc/pandas.py +16 -12
- datachain/lib/dc/parquet.py +35 -23
- datachain/lib/dc/records.py +31 -23
- datachain/lib/dc/storage.py +154 -64
- datachain/lib/dc/storage_pattern.py +251 -0
- datachain/lib/dc/utils.py +24 -16
- datachain/lib/dc/values.py +8 -9
- datachain/lib/file.py +622 -89
- datachain/lib/hf.py +69 -39
- datachain/lib/image.py +14 -14
- datachain/lib/listing.py +14 -11
- datachain/lib/listing_info.py +1 -2
- datachain/lib/meta_formats.py +3 -4
- datachain/lib/model_store.py +39 -7
- datachain/lib/namespaces.py +125 -0
- datachain/lib/projects.py +130 -0
- datachain/lib/pytorch.py +32 -21
- datachain/lib/settings.py +192 -56
- datachain/lib/signal_schema.py +427 -104
- datachain/lib/tar.py +1 -2
- datachain/lib/text.py +8 -7
- datachain/lib/udf.py +164 -76
- datachain/lib/udf_signature.py +60 -35
- datachain/lib/utils.py +118 -4
- datachain/lib/video.py +17 -9
- datachain/lib/webdataset.py +61 -56
- datachain/lib/webdataset_laion.py +15 -16
- datachain/listing.py +22 -10
- datachain/model/bbox.py +3 -1
- datachain/model/ultralytics/bbox.py +16 -12
- datachain/model/ultralytics/pose.py +16 -12
- datachain/model/ultralytics/segment.py +16 -12
- datachain/namespace.py +84 -0
- datachain/node.py +6 -6
- datachain/nodes_thread_pool.py +0 -1
- datachain/plugins.py +24 -0
- datachain/project.py +78 -0
- datachain/query/batch.py +40 -41
- datachain/query/dataset.py +604 -322
- datachain/query/dispatch.py +261 -154
- datachain/query/metrics.py +4 -6
- datachain/query/params.py +2 -3
- datachain/query/queue.py +3 -12
- datachain/query/schema.py +11 -6
- datachain/query/session.py +200 -33
- datachain/query/udf.py +34 -2
- datachain/remote/studio.py +171 -69
- datachain/script_meta.py +12 -12
- datachain/semver.py +68 -0
- datachain/sql/__init__.py +2 -0
- datachain/sql/functions/array.py +33 -1
- datachain/sql/postgresql_dialect.py +9 -0
- datachain/sql/postgresql_types.py +21 -0
- datachain/sql/sqlite/__init__.py +5 -1
- datachain/sql/sqlite/base.py +102 -29
- datachain/sql/sqlite/types.py +8 -13
- datachain/sql/types.py +70 -15
- datachain/studio.py +223 -46
- datachain/toolkit/split.py +31 -10
- datachain/utils.py +101 -59
- {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/METADATA +77 -22
- datachain-0.39.0.dist-info/RECORD +173 -0
- {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/WHEEL +1 -1
- datachain/cli/commands/query.py +0 -53
- datachain/query/utils.py +0 -42
- datachain-0.14.2.dist-info/RECORD +0 -158
- {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/entry_points.txt +0 -0
- {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/licenses/LICENSE +0 -0
- {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/top_level.txt +0 -0
datachain/lib/signal_schema.py
CHANGED
|
@@ -1,30 +1,32 @@
|
|
|
1
1
|
import copy
|
|
2
|
+
import hashlib
|
|
3
|
+
import logging
|
|
4
|
+
import math
|
|
5
|
+
import types
|
|
2
6
|
import warnings
|
|
3
|
-
from collections.abc import Iterator, Sequence
|
|
7
|
+
from collections.abc import Callable, Iterable, Iterator, Mapping, Sequence
|
|
4
8
|
from dataclasses import dataclass
|
|
5
9
|
from datetime import datetime
|
|
6
10
|
from inspect import isclass
|
|
7
|
-
from typing import (
|
|
11
|
+
from typing import (
|
|
8
12
|
IO,
|
|
9
13
|
TYPE_CHECKING,
|
|
10
14
|
Annotated,
|
|
11
15
|
Any,
|
|
12
|
-
Callable,
|
|
13
|
-
Dict,
|
|
14
16
|
Final,
|
|
15
|
-
List,
|
|
16
17
|
Literal,
|
|
17
|
-
Mapping,
|
|
18
18
|
Optional,
|
|
19
19
|
Union,
|
|
20
20
|
get_args,
|
|
21
21
|
get_origin,
|
|
22
22
|
)
|
|
23
23
|
|
|
24
|
-
from pydantic import BaseModel, Field, create_model
|
|
24
|
+
from pydantic import BaseModel, Field, ValidationError, create_model
|
|
25
25
|
from sqlalchemy import ColumnElement
|
|
26
26
|
from typing_extensions import Literal as LiteralEx
|
|
27
27
|
|
|
28
|
+
from datachain import json
|
|
29
|
+
from datachain.func import literal
|
|
28
30
|
from datachain.func.func import Func
|
|
29
31
|
from datachain.lib.convert.python_to_sql import python_to_sql
|
|
30
32
|
from datachain.lib.convert.sql_to_python import sql_to_python
|
|
@@ -32,14 +34,16 @@ from datachain.lib.convert.unflatten import unflatten_to_json_pos
|
|
|
32
34
|
from datachain.lib.data_model import DataModel, DataType, DataValue
|
|
33
35
|
from datachain.lib.file import File
|
|
34
36
|
from datachain.lib.model_store import ModelStore
|
|
35
|
-
from datachain.lib.utils import DataChainParamsError
|
|
36
|
-
from datachain.query.schema import DEFAULT_DELIMITER, Column
|
|
37
|
+
from datachain.lib.utils import DataChainColumnError, DataChainParamsError
|
|
38
|
+
from datachain.query.schema import DEFAULT_DELIMITER, C, Column, ColumnMeta
|
|
37
39
|
from datachain.sql.types import SQLType
|
|
38
40
|
|
|
39
41
|
if TYPE_CHECKING:
|
|
40
42
|
from datachain.catalog import Catalog
|
|
41
43
|
|
|
42
44
|
|
|
45
|
+
logger = logging.getLogger(__name__)
|
|
46
|
+
|
|
43
47
|
NAMES_TO_TYPES = {
|
|
44
48
|
"int": int,
|
|
45
49
|
"str": str,
|
|
@@ -68,7 +72,7 @@ class SignalSchemaWarning(RuntimeWarning):
|
|
|
68
72
|
|
|
69
73
|
|
|
70
74
|
class SignalResolvingError(SignalSchemaError):
|
|
71
|
-
def __init__(self, path:
|
|
75
|
+
def __init__(self, path: list[str] | None, msg: str):
|
|
72
76
|
name = " '" + ".".join(path) + "'" if path else ""
|
|
73
77
|
super().__init__(f"cannot resolve signal name{name}: {msg}")
|
|
74
78
|
|
|
@@ -78,6 +82,55 @@ class SetupError(SignalSchemaError):
|
|
|
78
82
|
super().__init__(f"cannot setup value '{name}': {msg}")
|
|
79
83
|
|
|
80
84
|
|
|
85
|
+
def generate_merge_root_mapping(
|
|
86
|
+
left_names: Iterable[str],
|
|
87
|
+
right_names: Sequence[str],
|
|
88
|
+
*,
|
|
89
|
+
extract_root: Callable[[str], str],
|
|
90
|
+
prefix: str,
|
|
91
|
+
) -> dict[str, str]:
|
|
92
|
+
"""Compute root renames for schema merges.
|
|
93
|
+
|
|
94
|
+
Returns a mapping from each right-side root to the target root name while
|
|
95
|
+
preserving the order in which right-side roots first appear. The mapping
|
|
96
|
+
avoids collisions with roots already present on the left side and among
|
|
97
|
+
the right-side roots themselves. When a conflict is detected, the
|
|
98
|
+
``prefix`` string is used to derive candidate root names until a unique
|
|
99
|
+
one is found.
|
|
100
|
+
"""
|
|
101
|
+
|
|
102
|
+
existing_roots = {extract_root(name) for name in left_names}
|
|
103
|
+
|
|
104
|
+
right_root_order: list[str] = []
|
|
105
|
+
right_roots: set[str] = set()
|
|
106
|
+
for name in right_names:
|
|
107
|
+
root = extract_root(name)
|
|
108
|
+
if root not in right_roots:
|
|
109
|
+
right_roots.add(root)
|
|
110
|
+
right_root_order.append(root)
|
|
111
|
+
|
|
112
|
+
used_roots = set(existing_roots)
|
|
113
|
+
root_mapping: dict[str, str] = {}
|
|
114
|
+
|
|
115
|
+
for root in right_root_order:
|
|
116
|
+
if root not in used_roots:
|
|
117
|
+
root_mapping[root] = root
|
|
118
|
+
used_roots.add(root)
|
|
119
|
+
continue
|
|
120
|
+
|
|
121
|
+
suffix = 0
|
|
122
|
+
while True:
|
|
123
|
+
base = prefix if root in prefix else f"{prefix}{root}"
|
|
124
|
+
candidate_root = base if suffix == 0 else f"{base}_{suffix}"
|
|
125
|
+
if candidate_root not in used_roots and candidate_root not in right_roots:
|
|
126
|
+
root_mapping[root] = candidate_root
|
|
127
|
+
used_roots.add(candidate_root)
|
|
128
|
+
break
|
|
129
|
+
suffix += 1
|
|
130
|
+
|
|
131
|
+
return root_mapping
|
|
132
|
+
|
|
133
|
+
|
|
81
134
|
class SignalResolvingTypeError(SignalResolvingError):
|
|
82
135
|
def __init__(self, method: str, field):
|
|
83
136
|
super().__init__(
|
|
@@ -87,12 +140,18 @@ class SignalResolvingTypeError(SignalResolvingError):
|
|
|
87
140
|
)
|
|
88
141
|
|
|
89
142
|
|
|
143
|
+
class SignalRemoveError(SignalSchemaError):
|
|
144
|
+
def __init__(self, path: list[str] | None, msg: str):
|
|
145
|
+
name = " '" + ".".join(path) + "'" if path else ""
|
|
146
|
+
super().__init__(f"cannot remove signal name{name}: {msg}")
|
|
147
|
+
|
|
148
|
+
|
|
90
149
|
class CustomType(BaseModel):
|
|
91
150
|
schema_version: int = Field(ge=1, le=2, strict=True)
|
|
92
151
|
name: str
|
|
93
152
|
fields: dict[str, str]
|
|
94
|
-
bases: list[tuple[str, str,
|
|
95
|
-
hidden_fields:
|
|
153
|
+
bases: list[tuple[str, str, str | None]]
|
|
154
|
+
hidden_fields: list[str] | None = None
|
|
96
155
|
|
|
97
156
|
@classmethod
|
|
98
157
|
def deserialize(cls, data: dict[str, Any], type_name: str) -> "CustomType":
|
|
@@ -112,8 +171,8 @@ class CustomType(BaseModel):
|
|
|
112
171
|
|
|
113
172
|
def create_feature_model(
|
|
114
173
|
name: str,
|
|
115
|
-
fields: Mapping[str,
|
|
116
|
-
base:
|
|
174
|
+
fields: Mapping[str, type | tuple[type, Any] | None],
|
|
175
|
+
base: type | None = None,
|
|
117
176
|
) -> type[BaseModel]:
|
|
118
177
|
"""
|
|
119
178
|
This gets or returns a dynamic feature model for use in restoring a model
|
|
@@ -130,7 +189,7 @@ def create_feature_model(
|
|
|
130
189
|
**{
|
|
131
190
|
field_name: anno if isinstance(anno, tuple) else (anno, None)
|
|
132
191
|
for field_name, anno in fields.items()
|
|
133
|
-
},
|
|
192
|
+
}, # type: ignore[arg-type]
|
|
134
193
|
)
|
|
135
194
|
|
|
136
195
|
|
|
@@ -139,12 +198,12 @@ class SignalSchema:
|
|
|
139
198
|
values: dict[str, DataType]
|
|
140
199
|
tree: dict[str, Any]
|
|
141
200
|
setup_func: dict[str, Callable]
|
|
142
|
-
setup_values:
|
|
201
|
+
setup_values: dict[str, Any] | None
|
|
143
202
|
|
|
144
203
|
def __init__(
|
|
145
204
|
self,
|
|
146
205
|
values: dict[str, DataType],
|
|
147
|
-
setup:
|
|
206
|
+
setup: dict[str, Callable] | None = None,
|
|
148
207
|
):
|
|
149
208
|
self.values = values
|
|
150
209
|
self.tree = self._build_tree(values)
|
|
@@ -183,8 +242,8 @@ class SignalSchema:
|
|
|
183
242
|
return SignalSchema(signals)
|
|
184
243
|
|
|
185
244
|
@staticmethod
|
|
186
|
-
def _get_bases(fr: type) -> list[tuple[str, str,
|
|
187
|
-
bases: list[tuple[str, str,
|
|
245
|
+
def _get_bases(fr: type) -> list[tuple[str, str, str | None]]:
|
|
246
|
+
bases: list[tuple[str, str, str | None]] = []
|
|
188
247
|
for base in fr.__mro__:
|
|
189
248
|
model_store_name = (
|
|
190
249
|
ModelStore.get_name(base) if issubclass(base, DataModel) else None
|
|
@@ -250,6 +309,11 @@ class SignalSchema:
|
|
|
250
309
|
signals["_custom_types"] = custom_types
|
|
251
310
|
return signals
|
|
252
311
|
|
|
312
|
+
def hash(self) -> str:
|
|
313
|
+
"""Create SHA hash of this schema"""
|
|
314
|
+
json_str = json.dumps(self.serialize(), sort_keys=True, separators=(",", ":"))
|
|
315
|
+
return hashlib.sha256(json_str.encode("utf-8")).hexdigest()
|
|
316
|
+
|
|
253
317
|
@staticmethod
|
|
254
318
|
def _split_subtypes(type_name: str) -> list[str]:
|
|
255
319
|
"""This splits a list of subtypes, including proper square bracket handling."""
|
|
@@ -276,7 +340,7 @@ class SignalSchema:
|
|
|
276
340
|
@staticmethod
|
|
277
341
|
def _deserialize_custom_type(
|
|
278
342
|
type_name: str, custom_types: dict[str, Any]
|
|
279
|
-
) ->
|
|
343
|
+
) -> type | None:
|
|
280
344
|
"""Given a type name like MyType@v1 gets a type from ModelStore or recreates
|
|
281
345
|
it based on the information from the custom types dict that includes fields and
|
|
282
346
|
bases."""
|
|
@@ -309,7 +373,7 @@ class SignalSchema:
|
|
|
309
373
|
return None
|
|
310
374
|
|
|
311
375
|
@staticmethod
|
|
312
|
-
def _resolve_type(type_name: str, custom_types: dict[str, Any]) ->
|
|
376
|
+
def _resolve_type(type_name: str, custom_types: dict[str, Any]) -> type | None:
|
|
313
377
|
"""Convert a string-based type back into a python type."""
|
|
314
378
|
type_name = type_name.strip()
|
|
315
379
|
if not type_name:
|
|
@@ -318,7 +382,7 @@ class SignalSchema:
|
|
|
318
382
|
return None
|
|
319
383
|
|
|
320
384
|
bracket_idx = type_name.find("[")
|
|
321
|
-
subtypes:
|
|
385
|
+
subtypes: tuple[type | None, ...] | None = None
|
|
322
386
|
if bracket_idx > -1:
|
|
323
387
|
if bracket_idx == 0:
|
|
324
388
|
raise ValueError("Type cannot start with '['")
|
|
@@ -439,35 +503,54 @@ class SignalSchema:
|
|
|
439
503
|
res[db_name] = python_to_sql(type_)
|
|
440
504
|
return res
|
|
441
505
|
|
|
442
|
-
def row_to_objs(self, row: Sequence[Any]) -> list[
|
|
506
|
+
def row_to_objs(self, row: Sequence[Any]) -> list[Any]:
|
|
443
507
|
self._init_setup_values()
|
|
444
508
|
|
|
445
|
-
objs: list[
|
|
509
|
+
objs: list[Any] = []
|
|
446
510
|
pos = 0
|
|
447
511
|
for name, fr_type in self.values.items():
|
|
448
|
-
if self.setup_values and
|
|
449
|
-
objs.append(
|
|
512
|
+
if self.setup_values and name in self.setup_values:
|
|
513
|
+
objs.append(self.setup_values.get(name))
|
|
450
514
|
elif (fr := ModelStore.to_pydantic(fr_type)) is not None:
|
|
451
515
|
j, pos = unflatten_to_json_pos(fr, row, pos)
|
|
452
|
-
|
|
516
|
+
try:
|
|
517
|
+
obj = fr(**j)
|
|
518
|
+
except ValidationError as e:
|
|
519
|
+
if self._all_values_none(j):
|
|
520
|
+
logger.debug("Failed to create input for %s: %s", name, e)
|
|
521
|
+
obj = None
|
|
522
|
+
else:
|
|
523
|
+
raise
|
|
524
|
+
objs.append(obj)
|
|
453
525
|
else:
|
|
454
526
|
objs.append(row[pos])
|
|
455
527
|
pos += 1
|
|
456
528
|
return objs
|
|
457
529
|
|
|
458
|
-
|
|
459
|
-
|
|
460
|
-
|
|
530
|
+
@staticmethod
|
|
531
|
+
def _all_values_none(value: Any) -> bool:
|
|
532
|
+
if isinstance(value, dict):
|
|
533
|
+
return all(SignalSchema._all_values_none(v) for v in value.values())
|
|
534
|
+
if isinstance(value, (list, tuple, set)):
|
|
535
|
+
return all(SignalSchema._all_values_none(v) for v in value)
|
|
536
|
+
if isinstance(value, float):
|
|
537
|
+
# NaN is used to represent NULL and NaN float values in datachain
|
|
538
|
+
# Since SQLite does not have a separate NULL type, we need to check for NaN
|
|
539
|
+
return math.isnan(value) or value is None
|
|
540
|
+
return value is None
|
|
541
|
+
|
|
542
|
+
def get_file_signal(self) -> str | None:
|
|
543
|
+
for signal_name, signal_type in self.values.items():
|
|
544
|
+
if (fr := ModelStore.to_pydantic(signal_type)) is not None and issubclass(
|
|
461
545
|
fr, File
|
|
462
546
|
):
|
|
463
|
-
return
|
|
464
|
-
|
|
465
|
-
return False
|
|
547
|
+
return signal_name
|
|
548
|
+
return None
|
|
466
549
|
|
|
467
550
|
def slice(
|
|
468
551
|
self,
|
|
469
|
-
params: dict[str,
|
|
470
|
-
setup:
|
|
552
|
+
params: dict[str, DataType | Any],
|
|
553
|
+
setup: dict[str, Callable] | None = None,
|
|
471
554
|
is_batch: bool = False,
|
|
472
555
|
) -> "SignalSchema":
|
|
473
556
|
"""
|
|
@@ -491,9 +574,13 @@ class SignalSchema:
|
|
|
491
574
|
schema_origin = get_origin(schema_type)
|
|
492
575
|
param_origin = get_origin(param_type)
|
|
493
576
|
|
|
494
|
-
if schema_origin
|
|
577
|
+
if schema_origin in (Union, types.UnionType) and type(None) in get_args(
|
|
578
|
+
schema_type
|
|
579
|
+
):
|
|
495
580
|
schema_type = get_args(schema_type)[0]
|
|
496
|
-
if param_origin
|
|
581
|
+
if param_origin in (Union, types.UnionType) and type(None) in get_args(
|
|
582
|
+
param_type
|
|
583
|
+
):
|
|
497
584
|
param_type = get_args(param_type)[0]
|
|
498
585
|
|
|
499
586
|
if is_batch:
|
|
@@ -529,22 +616,97 @@ class SignalSchema:
|
|
|
529
616
|
pos = 0
|
|
530
617
|
for fr_cls in self.values.values():
|
|
531
618
|
if (fr := ModelStore.to_pydantic(fr_cls)) is None:
|
|
532
|
-
|
|
619
|
+
value = row[pos]
|
|
533
620
|
pos += 1
|
|
621
|
+
converted = self._convert_feature_value(fr_cls, value, catalog, cache)
|
|
622
|
+
res.append(converted)
|
|
534
623
|
else:
|
|
535
624
|
json, pos = unflatten_to_json_pos(fr, row, pos) # type: ignore[union-attr]
|
|
536
|
-
|
|
537
|
-
|
|
625
|
+
try:
|
|
626
|
+
obj = fr(**json)
|
|
627
|
+
SignalSchema._set_file_stream(obj, catalog, cache)
|
|
628
|
+
except ValidationError as e:
|
|
629
|
+
if self._all_values_none(json):
|
|
630
|
+
logger.debug("Failed to create feature for %s: %s", fr_cls, e)
|
|
631
|
+
obj = None
|
|
632
|
+
else:
|
|
633
|
+
raise
|
|
538
634
|
res.append(obj)
|
|
539
635
|
return res
|
|
540
636
|
|
|
637
|
+
def _convert_feature_value(
|
|
638
|
+
self,
|
|
639
|
+
annotation: DataType,
|
|
640
|
+
value: Any,
|
|
641
|
+
catalog: "Catalog",
|
|
642
|
+
cache: bool,
|
|
643
|
+
) -> Any:
|
|
644
|
+
"""Convert raw DB value into declared annotation if needed."""
|
|
645
|
+
if value is None:
|
|
646
|
+
return None
|
|
647
|
+
|
|
648
|
+
result = value
|
|
649
|
+
origin = get_origin(annotation)
|
|
650
|
+
|
|
651
|
+
if origin in (Union, types.UnionType):
|
|
652
|
+
non_none_args = [
|
|
653
|
+
arg for arg in get_args(annotation) if arg is not type(None)
|
|
654
|
+
]
|
|
655
|
+
if len(non_none_args) == 1:
|
|
656
|
+
annotation = non_none_args[0]
|
|
657
|
+
origin = get_origin(annotation)
|
|
658
|
+
else:
|
|
659
|
+
return result
|
|
660
|
+
|
|
661
|
+
if ModelStore.is_pydantic(annotation):
|
|
662
|
+
if isinstance(value, annotation):
|
|
663
|
+
obj = value
|
|
664
|
+
elif isinstance(value, Mapping):
|
|
665
|
+
obj = annotation(**value)
|
|
666
|
+
else:
|
|
667
|
+
return result
|
|
668
|
+
assert isinstance(obj, BaseModel)
|
|
669
|
+
SignalSchema._set_file_stream(obj, catalog, cache)
|
|
670
|
+
result = obj
|
|
671
|
+
elif origin is list:
|
|
672
|
+
args = get_args(annotation)
|
|
673
|
+
if args and isinstance(value, (list, tuple)):
|
|
674
|
+
item_type = args[0]
|
|
675
|
+
result = [
|
|
676
|
+
self._convert_feature_value(item_type, item, catalog, cache)
|
|
677
|
+
if item is not None
|
|
678
|
+
else None
|
|
679
|
+
for item in value
|
|
680
|
+
]
|
|
681
|
+
elif origin is dict:
|
|
682
|
+
args = get_args(annotation)
|
|
683
|
+
if len(args) == 2 and isinstance(value, dict):
|
|
684
|
+
key_type, val_type = args
|
|
685
|
+
result = {}
|
|
686
|
+
for key, val in value.items():
|
|
687
|
+
if key_type is str:
|
|
688
|
+
converted_key = key
|
|
689
|
+
else:
|
|
690
|
+
loaded_key = json.loads(key)
|
|
691
|
+
converted_key = self._convert_feature_value(
|
|
692
|
+
key_type, loaded_key, catalog, cache
|
|
693
|
+
)
|
|
694
|
+
converted_val = (
|
|
695
|
+
self._convert_feature_value(val_type, val, catalog, cache)
|
|
696
|
+
if val_type is not Any
|
|
697
|
+
else val
|
|
698
|
+
)
|
|
699
|
+
result[converted_key] = converted_val
|
|
700
|
+
|
|
701
|
+
return result
|
|
702
|
+
|
|
541
703
|
@staticmethod
|
|
542
704
|
def _set_file_stream(
|
|
543
705
|
obj: BaseModel, catalog: "Catalog", cache: bool = False
|
|
544
706
|
) -> None:
|
|
545
707
|
if isinstance(obj, File):
|
|
546
708
|
obj._set_stream(catalog, caching_enabled=cache)
|
|
547
|
-
for field, finfo in obj.model_fields.items():
|
|
709
|
+
for field, finfo in type(obj).model_fields.items():
|
|
548
710
|
if ModelStore.is_pydantic(finfo.annotation):
|
|
549
711
|
SignalSchema._set_file_stream(getattr(obj, field), catalog, cache)
|
|
550
712
|
|
|
@@ -566,8 +728,8 @@ class SignalSchema:
|
|
|
566
728
|
raise SignalResolvingError([col_name], "is not found")
|
|
567
729
|
|
|
568
730
|
def db_signals(
|
|
569
|
-
self, name:
|
|
570
|
-
) ->
|
|
731
|
+
self, name: str | None = None, as_columns=False, include_hidden: bool = True
|
|
732
|
+
) -> list[str] | list[Column]:
|
|
571
733
|
"""
|
|
572
734
|
Returns DB columns as strings or Column objects with proper types
|
|
573
735
|
Optionally, it can filter results by specific object, returning only his signals
|
|
@@ -583,6 +745,9 @@ class SignalSchema:
|
|
|
583
745
|
]
|
|
584
746
|
|
|
585
747
|
if name:
|
|
748
|
+
if "." in name:
|
|
749
|
+
name = ColumnMeta.to_db_name(name)
|
|
750
|
+
|
|
586
751
|
signals = [
|
|
587
752
|
s
|
|
588
753
|
for s in signals
|
|
@@ -591,6 +756,35 @@ class SignalSchema:
|
|
|
591
756
|
|
|
592
757
|
return signals # type: ignore[return-value]
|
|
593
758
|
|
|
759
|
+
def user_signals(
|
|
760
|
+
self,
|
|
761
|
+
*,
|
|
762
|
+
include_hidden: bool = True,
|
|
763
|
+
include_sys: bool = False,
|
|
764
|
+
) -> list[str]:
|
|
765
|
+
return [
|
|
766
|
+
".".join(path)
|
|
767
|
+
for path, _, has_subtree, _ in self.get_flat_tree(
|
|
768
|
+
include_hidden=include_hidden, include_sys=include_sys
|
|
769
|
+
)
|
|
770
|
+
if not has_subtree
|
|
771
|
+
]
|
|
772
|
+
|
|
773
|
+
def compare_signals(
|
|
774
|
+
self,
|
|
775
|
+
other: "SignalSchema",
|
|
776
|
+
*,
|
|
777
|
+
include_hidden: bool = True,
|
|
778
|
+
include_sys: bool = False,
|
|
779
|
+
) -> tuple[set[str], set[str]]:
|
|
780
|
+
left = set(
|
|
781
|
+
self.user_signals(include_hidden=include_hidden, include_sys=include_sys)
|
|
782
|
+
)
|
|
783
|
+
right = set(
|
|
784
|
+
other.user_signals(include_hidden=include_hidden, include_sys=include_sys)
|
|
785
|
+
)
|
|
786
|
+
return left - right, right - left
|
|
787
|
+
|
|
594
788
|
def resolve(self, *names: str) -> "SignalSchema":
|
|
595
789
|
schema = {}
|
|
596
790
|
for field in names:
|
|
@@ -601,37 +795,60 @@ class SignalSchema:
|
|
|
601
795
|
return SignalSchema(schema)
|
|
602
796
|
|
|
603
797
|
def _find_in_tree(self, path: list[str]) -> DataType:
|
|
798
|
+
if val := self.tree.get(".".join(path)):
|
|
799
|
+
# If the path is a single string, we can directly access it
|
|
800
|
+
# without traversing the tree.
|
|
801
|
+
return val[0]
|
|
802
|
+
|
|
604
803
|
curr_tree = self.tree
|
|
605
804
|
curr_type = None
|
|
606
805
|
i = 0
|
|
607
806
|
while curr_tree is not None and i < len(path):
|
|
608
807
|
if val := curr_tree.get(path[i]):
|
|
609
808
|
curr_type, curr_tree = val
|
|
610
|
-
elif i == 0 and len(path) > 1 and (val := curr_tree.get(".".join(path))):
|
|
611
|
-
curr_type, curr_tree = val
|
|
612
|
-
break
|
|
613
809
|
else:
|
|
614
810
|
curr_type = None
|
|
811
|
+
break
|
|
615
812
|
i += 1
|
|
616
813
|
|
|
617
|
-
if curr_type is None:
|
|
814
|
+
if curr_type is None or i < len(path):
|
|
815
|
+
# If we reached the end of the path and didn't find a type,
|
|
816
|
+
# or if we didn't traverse the entire path, raise an error.
|
|
618
817
|
raise SignalResolvingError(path, "is not found")
|
|
619
818
|
|
|
620
819
|
return curr_type
|
|
621
820
|
|
|
821
|
+
def group_by(
|
|
822
|
+
self, partition_by: Sequence[str], new_column: Sequence[Column]
|
|
823
|
+
) -> "SignalSchema":
|
|
824
|
+
orig_schema = SignalSchema(copy.deepcopy(self.values))
|
|
825
|
+
schema = orig_schema.to_partial(*partition_by)
|
|
826
|
+
|
|
827
|
+
vals = {c.name: sql_to_python(c) for c in new_column}
|
|
828
|
+
return SignalSchema(schema.values | vals)
|
|
829
|
+
|
|
622
830
|
def select_except_signals(self, *args: str) -> "SignalSchema":
|
|
623
|
-
|
|
624
|
-
|
|
625
|
-
|
|
626
|
-
raise SignalResolvingTypeError("select_except()", field)
|
|
831
|
+
def has_signal(signal: str):
|
|
832
|
+
signal = signal.replace(".", DEFAULT_DELIMITER)
|
|
833
|
+
return any(signal == s for s in self.db_signals())
|
|
627
834
|
|
|
628
|
-
|
|
835
|
+
schema = copy.deepcopy(self.values)
|
|
836
|
+
for signal in args:
|
|
837
|
+
if not isinstance(signal, str):
|
|
838
|
+
raise SignalResolvingTypeError("select_except()", signal)
|
|
839
|
+
|
|
840
|
+
if signal not in self.values:
|
|
841
|
+
if has_signal(signal):
|
|
842
|
+
raise SignalRemoveError(
|
|
843
|
+
signal.split("."),
|
|
844
|
+
"select_except() error - removing nested signal would"
|
|
845
|
+
" break parent schema, which isn't supported.",
|
|
846
|
+
)
|
|
629
847
|
raise SignalResolvingError(
|
|
630
|
-
|
|
631
|
-
"select_except() error - the
|
|
632
|
-
"inside of feature (not supported)",
|
|
848
|
+
signal.split("."),
|
|
849
|
+
"select_except() error - the signal does not exist",
|
|
633
850
|
)
|
|
634
|
-
del schema[
|
|
851
|
+
del schema[signal]
|
|
635
852
|
|
|
636
853
|
return SignalSchema(schema)
|
|
637
854
|
|
|
@@ -645,31 +862,49 @@ class SignalSchema:
|
|
|
645
862
|
|
|
646
863
|
def mutate(self, args_map: dict) -> "SignalSchema":
|
|
647
864
|
new_values = self.values.copy()
|
|
865
|
+
primitives = (bool, str, int, float)
|
|
648
866
|
|
|
649
867
|
for name, value in args_map.items():
|
|
868
|
+
current_type = None
|
|
869
|
+
|
|
870
|
+
if C.is_nested(name):
|
|
871
|
+
try:
|
|
872
|
+
current_type = self.get_column_type(name)
|
|
873
|
+
except SignalResolvingError as err:
|
|
874
|
+
msg = f"Creating new nested columns directly is not allowed: {name}"
|
|
875
|
+
raise ValueError(msg) from err
|
|
876
|
+
|
|
650
877
|
if isinstance(value, Column) and value.name in self.values:
|
|
651
878
|
# renaming existing signal
|
|
879
|
+
# Note: it won't touch nested signals here (e.g. file__path)
|
|
880
|
+
# we don't allow removing nested columns to keep objects consistent
|
|
652
881
|
del new_values[value.name]
|
|
653
882
|
new_values[name] = self.values[value.name]
|
|
654
|
-
|
|
655
|
-
if isinstance(value, Column):
|
|
883
|
+
elif isinstance(value, Column):
|
|
656
884
|
# adding new signal from existing signal field
|
|
657
|
-
|
|
658
|
-
|
|
659
|
-
value.name, with_subtree=True
|
|
660
|
-
)
|
|
661
|
-
continue
|
|
662
|
-
except SignalResolvingError:
|
|
663
|
-
pass
|
|
664
|
-
if isinstance(value, Func):
|
|
885
|
+
new_values[name] = self.get_column_type(value.name, with_subtree=True)
|
|
886
|
+
elif isinstance(value, Func):
|
|
665
887
|
# adding new signal with function
|
|
666
888
|
new_values[name] = value.get_result_type(self)
|
|
667
|
-
|
|
668
|
-
|
|
889
|
+
elif isinstance(value, primitives):
|
|
890
|
+
# For primitives, store the type, not the value
|
|
891
|
+
val = literal(value)
|
|
892
|
+
val.type = python_to_sql(type(value))()
|
|
893
|
+
new_values[name] = sql_to_python(val)
|
|
894
|
+
elif isinstance(value, ColumnElement):
|
|
669
895
|
# adding new signal
|
|
670
896
|
new_values[name] = sql_to_python(value)
|
|
671
|
-
|
|
672
|
-
|
|
897
|
+
else:
|
|
898
|
+
new_values[name] = value
|
|
899
|
+
|
|
900
|
+
if C.is_nested(name):
|
|
901
|
+
if current_type != new_values[name]:
|
|
902
|
+
msg = (
|
|
903
|
+
f"Altering nested column type is not allowed: {name}, "
|
|
904
|
+
f"current type: {current_type}, new type: {new_values[name]}"
|
|
905
|
+
)
|
|
906
|
+
raise ValueError(msg)
|
|
907
|
+
del new_values[name]
|
|
673
908
|
|
|
674
909
|
return SignalSchema(new_values)
|
|
675
910
|
|
|
@@ -683,12 +918,37 @@ class SignalSchema:
|
|
|
683
918
|
right_schema: "SignalSchema",
|
|
684
919
|
rname: str,
|
|
685
920
|
) -> "SignalSchema":
|
|
686
|
-
|
|
687
|
-
|
|
688
|
-
|
|
689
|
-
|
|
921
|
+
merged_values = dict(self.values)
|
|
922
|
+
|
|
923
|
+
right_names = list(right_schema.values.keys())
|
|
924
|
+
root_mapping = generate_merge_root_mapping(
|
|
925
|
+
self.values.keys(),
|
|
926
|
+
right_names,
|
|
927
|
+
extract_root=self._extract_root,
|
|
928
|
+
prefix=rname,
|
|
929
|
+
)
|
|
690
930
|
|
|
691
|
-
|
|
931
|
+
for key, type_ in right_schema.values.items():
|
|
932
|
+
root = self._extract_root(key)
|
|
933
|
+
tail = key.partition(".")[2]
|
|
934
|
+
mapped_root = root_mapping[root]
|
|
935
|
+
new_name = mapped_root if not tail else f"{mapped_root}.{tail}"
|
|
936
|
+
merged_values[new_name] = type_
|
|
937
|
+
|
|
938
|
+
return SignalSchema(merged_values)
|
|
939
|
+
|
|
940
|
+
@staticmethod
|
|
941
|
+
def _extract_root(name: str) -> str:
|
|
942
|
+
if "." in name:
|
|
943
|
+
return name.split(".", 1)[0]
|
|
944
|
+
return name
|
|
945
|
+
|
|
946
|
+
def append(self, right: "SignalSchema") -> "SignalSchema":
|
|
947
|
+
missing_schema = {
|
|
948
|
+
key: right.values[key]
|
|
949
|
+
for key in [k for k in right.values if k not in self.values]
|
|
950
|
+
}
|
|
951
|
+
return SignalSchema(self.values | missing_schema)
|
|
692
952
|
|
|
693
953
|
def get_signals(self, target_type: type[DataModel]) -> Iterator[str]:
|
|
694
954
|
for path, type_, has_subtree, _ in self.get_flat_tree():
|
|
@@ -701,29 +961,38 @@ class SignalSchema:
|
|
|
701
961
|
return create_model(
|
|
702
962
|
name,
|
|
703
963
|
__base__=(DataModel,), # type: ignore[call-overload]
|
|
704
|
-
**fields,
|
|
964
|
+
**fields, # type: ignore[arg-type]
|
|
705
965
|
)
|
|
706
966
|
|
|
707
967
|
@staticmethod
|
|
708
968
|
def _build_tree(
|
|
709
969
|
values: dict[str, DataType],
|
|
710
|
-
) -> dict[str, tuple[DataType,
|
|
970
|
+
) -> dict[str, tuple[DataType, dict | None]]:
|
|
711
971
|
return {
|
|
712
972
|
name: (val, SignalSchema._build_tree_for_type(val))
|
|
713
973
|
for name, val in values.items()
|
|
714
974
|
}
|
|
715
975
|
|
|
716
976
|
def get_flat_tree(
|
|
717
|
-
self,
|
|
977
|
+
self,
|
|
978
|
+
include_hidden: bool = True,
|
|
979
|
+
include_sys: bool = True,
|
|
718
980
|
) -> Iterator[tuple[list[str], DataType, bool, int]]:
|
|
719
|
-
yield from self._get_flat_tree(self.tree, [], 0, include_hidden)
|
|
981
|
+
yield from self._get_flat_tree(self.tree, [], 0, include_hidden, include_sys)
|
|
720
982
|
|
|
721
983
|
def _get_flat_tree(
|
|
722
|
-
self,
|
|
984
|
+
self,
|
|
985
|
+
tree: dict,
|
|
986
|
+
prefix: list[str],
|
|
987
|
+
depth: int,
|
|
988
|
+
include_hidden: bool,
|
|
989
|
+
include_sys: bool,
|
|
723
990
|
) -> Iterator[tuple[list[str], DataType, bool, int]]:
|
|
724
991
|
for name, (type_, substree) in tree.items():
|
|
725
992
|
suffix = name.split(".")
|
|
726
993
|
new_prefix = prefix + suffix
|
|
994
|
+
if not include_sys and new_prefix and new_prefix[0] == "sys":
|
|
995
|
+
continue
|
|
727
996
|
hidden_fields = getattr(type_, "_hidden_fields", None)
|
|
728
997
|
if hidden_fields and substree and not include_hidden:
|
|
729
998
|
substree = {
|
|
@@ -736,10 +1005,10 @@ class SignalSchema:
|
|
|
736
1005
|
yield new_prefix, type_, has_subtree, depth
|
|
737
1006
|
if substree is not None:
|
|
738
1007
|
yield from self._get_flat_tree(
|
|
739
|
-
substree, new_prefix, depth + 1, include_hidden
|
|
1008
|
+
substree, new_prefix, depth + 1, include_hidden, include_sys
|
|
740
1009
|
)
|
|
741
1010
|
|
|
742
|
-
def print_tree(self, indent: int = 2, start_at: int = 0, file:
|
|
1011
|
+
def print_tree(self, indent: int = 2, start_at: int = 0, file: IO | None = None):
|
|
743
1012
|
for path, type_, _, depth in self.get_flat_tree():
|
|
744
1013
|
total_indent = start_at + depth * indent
|
|
745
1014
|
col_name = " " * total_indent + path[-1]
|
|
@@ -769,7 +1038,28 @@ class SignalSchema:
|
|
|
769
1038
|
], max_length
|
|
770
1039
|
|
|
771
1040
|
def __or__(self, other):
|
|
772
|
-
|
|
1041
|
+
new_values = dict(self.values)
|
|
1042
|
+
|
|
1043
|
+
for name, new_type in other.values.items():
|
|
1044
|
+
if name in new_values:
|
|
1045
|
+
current_type = new_values[name]
|
|
1046
|
+
if current_type != new_type:
|
|
1047
|
+
raise DataChainColumnError(
|
|
1048
|
+
name,
|
|
1049
|
+
"signal already exists with a different type",
|
|
1050
|
+
)
|
|
1051
|
+
continue
|
|
1052
|
+
|
|
1053
|
+
root = self._extract_root(name)
|
|
1054
|
+
if any(self._extract_root(existing) == root for existing in new_values):
|
|
1055
|
+
raise DataChainColumnError(
|
|
1056
|
+
name,
|
|
1057
|
+
"signal root already exists in schema",
|
|
1058
|
+
)
|
|
1059
|
+
|
|
1060
|
+
new_values[name] = new_type
|
|
1061
|
+
|
|
1062
|
+
return self.__class__(new_values)
|
|
773
1063
|
|
|
774
1064
|
def __contains__(self, name: str):
|
|
775
1065
|
return name in self.values
|
|
@@ -778,15 +1068,20 @@ class SignalSchema:
|
|
|
778
1068
|
return self.values.pop(name)
|
|
779
1069
|
|
|
780
1070
|
@staticmethod
|
|
781
|
-
def _type_to_str(type_:
|
|
1071
|
+
def _type_to_str(type_: type | None, subtypes: list | None = None) -> str: # noqa: C901, PLR0911
|
|
782
1072
|
"""Convert a type to a string-based representation."""
|
|
783
1073
|
if type_ is None:
|
|
784
1074
|
return "NoneType"
|
|
785
1075
|
|
|
786
1076
|
origin = get_origin(type_)
|
|
787
1077
|
|
|
788
|
-
if origin
|
|
1078
|
+
if origin in (Union, types.UnionType):
|
|
789
1079
|
args = get_args(type_)
|
|
1080
|
+
if len(args) == 2 and type(None) in args:
|
|
1081
|
+
# This is an Optional type.
|
|
1082
|
+
non_none_type = args[0] if args[1] is type(None) else args[1]
|
|
1083
|
+
type_str = SignalSchema._type_to_str(non_none_type, subtypes)
|
|
1084
|
+
return f"Optional[{type_str}]"
|
|
790
1085
|
formatted_types = ", ".join(
|
|
791
1086
|
SignalSchema._type_to_str(arg, subtypes) for arg in args
|
|
792
1087
|
)
|
|
@@ -795,21 +1090,21 @@ class SignalSchema:
|
|
|
795
1090
|
args = get_args(type_)
|
|
796
1091
|
type_str = SignalSchema._type_to_str(args[0], subtypes)
|
|
797
1092
|
return f"Optional[{type_str}]"
|
|
798
|
-
if origin
|
|
1093
|
+
if origin is list:
|
|
799
1094
|
args = get_args(type_)
|
|
1095
|
+
if len(args) == 0:
|
|
1096
|
+
return "list"
|
|
800
1097
|
type_str = SignalSchema._type_to_str(args[0], subtypes)
|
|
801
1098
|
return f"list[{type_str}]"
|
|
802
|
-
if origin
|
|
1099
|
+
if origin is dict:
|
|
803
1100
|
args = get_args(type_)
|
|
804
|
-
|
|
805
|
-
|
|
806
|
-
)
|
|
807
|
-
|
|
808
|
-
f",
|
|
809
|
-
|
|
810
|
-
|
|
811
|
-
)
|
|
812
|
-
return f"dict[{type_str}{vals}]"
|
|
1101
|
+
if len(args) == 0:
|
|
1102
|
+
return "dict"
|
|
1103
|
+
key_type = SignalSchema._type_to_str(args[0], subtypes)
|
|
1104
|
+
if len(args) == 1:
|
|
1105
|
+
return f"dict[{key_type}, Any]"
|
|
1106
|
+
val_type = SignalSchema._type_to_str(args[1], subtypes)
|
|
1107
|
+
return f"dict[{key_type}, {val_type}]"
|
|
813
1108
|
if origin == Annotated:
|
|
814
1109
|
args = get_args(type_)
|
|
815
1110
|
return SignalSchema._type_to_str(args[0], subtypes)
|
|
@@ -823,7 +1118,7 @@ class SignalSchema:
|
|
|
823
1118
|
# Include this type in the list of all subtypes, if requested.
|
|
824
1119
|
subtypes.append(type_)
|
|
825
1120
|
if not hasattr(type_, "__name__"):
|
|
826
|
-
# This can happen for some third-party or custom types
|
|
1121
|
+
# This can happen for some third-party or custom types
|
|
827
1122
|
warnings.warn(
|
|
828
1123
|
f"Unable to determine name of type '{type_}'.",
|
|
829
1124
|
SignalSchemaWarning,
|
|
@@ -838,7 +1133,7 @@ class SignalSchema:
|
|
|
838
1133
|
@staticmethod
|
|
839
1134
|
def _build_tree_for_type(
|
|
840
1135
|
model: DataType,
|
|
841
|
-
) ->
|
|
1136
|
+
) -> dict[str, tuple[DataType, dict | None]] | None:
|
|
842
1137
|
if (fr := ModelStore.to_pydantic(model)) is not None:
|
|
843
1138
|
return SignalSchema._build_tree_for_model(fr)
|
|
844
1139
|
return None
|
|
@@ -846,8 +1141,8 @@ class SignalSchema:
|
|
|
846
1141
|
@staticmethod
|
|
847
1142
|
def _build_tree_for_model(
|
|
848
1143
|
model: type[BaseModel],
|
|
849
|
-
) ->
|
|
850
|
-
res: dict[str, tuple[DataType,
|
|
1144
|
+
) -> dict[str, tuple[DataType, dict | None]] | None:
|
|
1145
|
+
res: dict[str, tuple[DataType, dict | None]] = {}
|
|
851
1146
|
|
|
852
1147
|
for name, f_info in model.model_fields.items():
|
|
853
1148
|
anno = f_info.annotation
|
|
@@ -859,7 +1154,7 @@ class SignalSchema:
|
|
|
859
1154
|
|
|
860
1155
|
return res
|
|
861
1156
|
|
|
862
|
-
def to_partial(self, *columns: str) -> "SignalSchema":
|
|
1157
|
+
def to_partial(self, *columns: str) -> "SignalSchema": # noqa: C901
|
|
863
1158
|
"""
|
|
864
1159
|
Convert the schema to a partial schema with only the specified columns.
|
|
865
1160
|
|
|
@@ -896,15 +1191,21 @@ class SignalSchema:
|
|
|
896
1191
|
schema: dict[str, Any] = {}
|
|
897
1192
|
schema_custom_types: dict[str, CustomType] = {}
|
|
898
1193
|
|
|
899
|
-
data_model_bases:
|
|
1194
|
+
data_model_bases: list[tuple[str, str, str | None]] | None = None
|
|
900
1195
|
|
|
901
1196
|
signal_partials: dict[str, str] = {}
|
|
902
1197
|
partial_versions: dict[str, int] = {}
|
|
903
1198
|
|
|
904
1199
|
def _type_name_to_partial(signal_name: str, type_name: str) -> str:
|
|
905
|
-
if
|
|
1200
|
+
# Check if we need to create a partial for this type
|
|
1201
|
+
# Only create partials for custom types that are in the custom_types dict
|
|
1202
|
+
if type_name not in custom_types:
|
|
906
1203
|
return type_name
|
|
907
|
-
|
|
1204
|
+
|
|
1205
|
+
if "@" in type_name:
|
|
1206
|
+
model_name, _ = ModelStore.parse_name_version(type_name)
|
|
1207
|
+
else:
|
|
1208
|
+
model_name = type_name
|
|
908
1209
|
|
|
909
1210
|
if signal_name not in signal_partials:
|
|
910
1211
|
partial_versions.setdefault(model_name, 0)
|
|
@@ -928,6 +1229,14 @@ class SignalSchema:
|
|
|
928
1229
|
parent_type_partial = _type_name_to_partial(signal, parent_type)
|
|
929
1230
|
|
|
930
1231
|
schema[signal] = parent_type_partial
|
|
1232
|
+
|
|
1233
|
+
# If this is a complex signal without field specifier (just "file")
|
|
1234
|
+
# and it's a custom type, include the entire complex signal
|
|
1235
|
+
if len(column_parts) == 1 and parent_type in custom_types:
|
|
1236
|
+
# Include the entire complex signal - no need to create partial
|
|
1237
|
+
schema[signal] = parent_type
|
|
1238
|
+
continue
|
|
1239
|
+
|
|
931
1240
|
continue
|
|
932
1241
|
|
|
933
1242
|
if parent_type not in custom_types:
|
|
@@ -942,6 +1251,20 @@ class SignalSchema:
|
|
|
942
1251
|
f"Field {signal} not found in custom type {parent_type}"
|
|
943
1252
|
)
|
|
944
1253
|
|
|
1254
|
+
# Check if this is the last part and if the column type is a complex
|
|
1255
|
+
is_last_part = i == len(column_parts) - 1
|
|
1256
|
+
is_complex_signal = signal_type in custom_types
|
|
1257
|
+
|
|
1258
|
+
if is_last_part and is_complex_signal:
|
|
1259
|
+
schema[column] = signal_type
|
|
1260
|
+
# Also need to remove the partial schema entry we created for the
|
|
1261
|
+
# parent since we're promoting the nested complex column to root
|
|
1262
|
+
parent_signal = column_parts[0]
|
|
1263
|
+
schema.pop(parent_signal, None)
|
|
1264
|
+
# Don't create partial types for this case
|
|
1265
|
+
break
|
|
1266
|
+
|
|
1267
|
+
# Create partial type for this field
|
|
945
1268
|
partial_type = _type_name_to_partial(
|
|
946
1269
|
".".join(column_parts[: i + 1]),
|
|
947
1270
|
signal_type,
|