datachain 0.14.2__py3-none-any.whl → 0.39.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (137)
  1. datachain/__init__.py +20 -0
  2. datachain/asyn.py +11 -12
  3. datachain/cache.py +7 -7
  4. datachain/catalog/__init__.py +2 -2
  5. datachain/catalog/catalog.py +621 -507
  6. datachain/catalog/dependency.py +164 -0
  7. datachain/catalog/loader.py +28 -18
  8. datachain/checkpoint.py +43 -0
  9. datachain/cli/__init__.py +24 -33
  10. datachain/cli/commands/__init__.py +1 -8
  11. datachain/cli/commands/datasets.py +83 -52
  12. datachain/cli/commands/ls.py +17 -17
  13. datachain/cli/commands/show.py +4 -4
  14. datachain/cli/parser/__init__.py +8 -74
  15. datachain/cli/parser/job.py +95 -3
  16. datachain/cli/parser/studio.py +11 -4
  17. datachain/cli/parser/utils.py +1 -2
  18. datachain/cli/utils.py +2 -15
  19. datachain/client/azure.py +4 -4
  20. datachain/client/fsspec.py +45 -28
  21. datachain/client/gcs.py +6 -6
  22. datachain/client/hf.py +29 -2
  23. datachain/client/http.py +157 -0
  24. datachain/client/local.py +15 -11
  25. datachain/client/s3.py +17 -9
  26. datachain/config.py +4 -8
  27. datachain/data_storage/db_engine.py +12 -6
  28. datachain/data_storage/job.py +5 -1
  29. datachain/data_storage/metastore.py +1252 -186
  30. datachain/data_storage/schema.py +58 -45
  31. datachain/data_storage/serializer.py +105 -15
  32. datachain/data_storage/sqlite.py +286 -127
  33. datachain/data_storage/warehouse.py +250 -113
  34. datachain/dataset.py +353 -148
  35. datachain/delta.py +391 -0
  36. datachain/diff/__init__.py +27 -29
  37. datachain/error.py +60 -0
  38. datachain/func/__init__.py +2 -1
  39. datachain/func/aggregate.py +66 -42
  40. datachain/func/array.py +242 -38
  41. datachain/func/base.py +7 -4
  42. datachain/func/conditional.py +110 -60
  43. datachain/func/func.py +96 -45
  44. datachain/func/numeric.py +55 -38
  45. datachain/func/path.py +32 -20
  46. datachain/func/random.py +2 -2
  47. datachain/func/string.py +67 -37
  48. datachain/func/window.py +7 -8
  49. datachain/hash_utils.py +123 -0
  50. datachain/job.py +11 -7
  51. datachain/json.py +138 -0
  52. datachain/lib/arrow.py +58 -22
  53. datachain/lib/audio.py +245 -0
  54. datachain/lib/clip.py +14 -13
  55. datachain/lib/convert/flatten.py +5 -3
  56. datachain/lib/convert/python_to_sql.py +6 -10
  57. datachain/lib/convert/sql_to_python.py +8 -0
  58. datachain/lib/convert/values_to_tuples.py +156 -51
  59. datachain/lib/data_model.py +42 -20
  60. datachain/lib/dataset_info.py +36 -8
  61. datachain/lib/dc/__init__.py +8 -2
  62. datachain/lib/dc/csv.py +25 -28
  63. datachain/lib/dc/database.py +398 -0
  64. datachain/lib/dc/datachain.py +1289 -425
  65. datachain/lib/dc/datasets.py +320 -38
  66. datachain/lib/dc/hf.py +38 -24
  67. datachain/lib/dc/json.py +29 -32
  68. datachain/lib/dc/listings.py +112 -8
  69. datachain/lib/dc/pandas.py +16 -12
  70. datachain/lib/dc/parquet.py +35 -23
  71. datachain/lib/dc/records.py +31 -23
  72. datachain/lib/dc/storage.py +154 -64
  73. datachain/lib/dc/storage_pattern.py +251 -0
  74. datachain/lib/dc/utils.py +24 -16
  75. datachain/lib/dc/values.py +8 -9
  76. datachain/lib/file.py +622 -89
  77. datachain/lib/hf.py +69 -39
  78. datachain/lib/image.py +14 -14
  79. datachain/lib/listing.py +14 -11
  80. datachain/lib/listing_info.py +1 -2
  81. datachain/lib/meta_formats.py +3 -4
  82. datachain/lib/model_store.py +39 -7
  83. datachain/lib/namespaces.py +125 -0
  84. datachain/lib/projects.py +130 -0
  85. datachain/lib/pytorch.py +32 -21
  86. datachain/lib/settings.py +192 -56
  87. datachain/lib/signal_schema.py +427 -104
  88. datachain/lib/tar.py +1 -2
  89. datachain/lib/text.py +8 -7
  90. datachain/lib/udf.py +164 -76
  91. datachain/lib/udf_signature.py +60 -35
  92. datachain/lib/utils.py +118 -4
  93. datachain/lib/video.py +17 -9
  94. datachain/lib/webdataset.py +61 -56
  95. datachain/lib/webdataset_laion.py +15 -16
  96. datachain/listing.py +22 -10
  97. datachain/model/bbox.py +3 -1
  98. datachain/model/ultralytics/bbox.py +16 -12
  99. datachain/model/ultralytics/pose.py +16 -12
  100. datachain/model/ultralytics/segment.py +16 -12
  101. datachain/namespace.py +84 -0
  102. datachain/node.py +6 -6
  103. datachain/nodes_thread_pool.py +0 -1
  104. datachain/plugins.py +24 -0
  105. datachain/project.py +78 -0
  106. datachain/query/batch.py +40 -41
  107. datachain/query/dataset.py +604 -322
  108. datachain/query/dispatch.py +261 -154
  109. datachain/query/metrics.py +4 -6
  110. datachain/query/params.py +2 -3
  111. datachain/query/queue.py +3 -12
  112. datachain/query/schema.py +11 -6
  113. datachain/query/session.py +200 -33
  114. datachain/query/udf.py +34 -2
  115. datachain/remote/studio.py +171 -69
  116. datachain/script_meta.py +12 -12
  117. datachain/semver.py +68 -0
  118. datachain/sql/__init__.py +2 -0
  119. datachain/sql/functions/array.py +33 -1
  120. datachain/sql/postgresql_dialect.py +9 -0
  121. datachain/sql/postgresql_types.py +21 -0
  122. datachain/sql/sqlite/__init__.py +5 -1
  123. datachain/sql/sqlite/base.py +102 -29
  124. datachain/sql/sqlite/types.py +8 -13
  125. datachain/sql/types.py +70 -15
  126. datachain/studio.py +223 -46
  127. datachain/toolkit/split.py +31 -10
  128. datachain/utils.py +101 -59
  129. {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/METADATA +77 -22
  130. datachain-0.39.0.dist-info/RECORD +173 -0
  131. {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/WHEEL +1 -1
  132. datachain/cli/commands/query.py +0 -53
  133. datachain/query/utils.py +0 -42
  134. datachain-0.14.2.dist-info/RECORD +0 -158
  135. {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/entry_points.txt +0 -0
  136. {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/licenses/LICENSE +0 -0
  137. {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/top_level.txt +0 -0
@@ -1,30 +1,32 @@
1
1
  import copy
2
+ import hashlib
3
+ import logging
4
+ import math
5
+ import types
2
6
  import warnings
3
- from collections.abc import Iterator, Sequence
7
+ from collections.abc import Callable, Iterable, Iterator, Mapping, Sequence
4
8
  from dataclasses import dataclass
5
9
  from datetime import datetime
6
10
  from inspect import isclass
7
- from typing import ( # noqa: UP035
11
+ from typing import (
8
12
  IO,
9
13
  TYPE_CHECKING,
10
14
  Annotated,
11
15
  Any,
12
- Callable,
13
- Dict,
14
16
  Final,
15
- List,
16
17
  Literal,
17
- Mapping,
18
18
  Optional,
19
19
  Union,
20
20
  get_args,
21
21
  get_origin,
22
22
  )
23
23
 
24
- from pydantic import BaseModel, Field, create_model
24
+ from pydantic import BaseModel, Field, ValidationError, create_model
25
25
  from sqlalchemy import ColumnElement
26
26
  from typing_extensions import Literal as LiteralEx
27
27
 
28
+ from datachain import json
29
+ from datachain.func import literal
28
30
  from datachain.func.func import Func
29
31
  from datachain.lib.convert.python_to_sql import python_to_sql
30
32
  from datachain.lib.convert.sql_to_python import sql_to_python
@@ -32,14 +34,16 @@ from datachain.lib.convert.unflatten import unflatten_to_json_pos
32
34
  from datachain.lib.data_model import DataModel, DataType, DataValue
33
35
  from datachain.lib.file import File
34
36
  from datachain.lib.model_store import ModelStore
35
- from datachain.lib.utils import DataChainParamsError
36
- from datachain.query.schema import DEFAULT_DELIMITER, Column
37
+ from datachain.lib.utils import DataChainColumnError, DataChainParamsError
38
+ from datachain.query.schema import DEFAULT_DELIMITER, C, Column, ColumnMeta
37
39
  from datachain.sql.types import SQLType
38
40
 
39
41
  if TYPE_CHECKING:
40
42
  from datachain.catalog import Catalog
41
43
 
42
44
 
45
+ logger = logging.getLogger(__name__)
46
+
43
47
  NAMES_TO_TYPES = {
44
48
  "int": int,
45
49
  "str": str,
@@ -68,7 +72,7 @@ class SignalSchemaWarning(RuntimeWarning):
68
72
 
69
73
 
70
74
  class SignalResolvingError(SignalSchemaError):
71
- def __init__(self, path: Optional[list[str]], msg: str):
75
+ def __init__(self, path: list[str] | None, msg: str):
72
76
  name = " '" + ".".join(path) + "'" if path else ""
73
77
  super().__init__(f"cannot resolve signal name{name}: {msg}")
74
78
 
@@ -78,6 +82,55 @@ class SetupError(SignalSchemaError):
78
82
  super().__init__(f"cannot setup value '{name}': {msg}")
79
83
 
80
84
 
85
+ def generate_merge_root_mapping(
86
+ left_names: Iterable[str],
87
+ right_names: Sequence[str],
88
+ *,
89
+ extract_root: Callable[[str], str],
90
+ prefix: str,
91
+ ) -> dict[str, str]:
92
+ """Compute root renames for schema merges.
93
+
94
+ Returns a mapping from each right-side root to the target root name while
95
+ preserving the order in which right-side roots first appear. The mapping
96
+ avoids collisions with roots already present on the left side and among
97
+ the right-side roots themselves. When a conflict is detected, the
98
+ ``prefix`` string is used to derive candidate root names until a unique
99
+ one is found.
100
+ """
101
+
102
+ existing_roots = {extract_root(name) for name in left_names}
103
+
104
+ right_root_order: list[str] = []
105
+ right_roots: set[str] = set()
106
+ for name in right_names:
107
+ root = extract_root(name)
108
+ if root not in right_roots:
109
+ right_roots.add(root)
110
+ right_root_order.append(root)
111
+
112
+ used_roots = set(existing_roots)
113
+ root_mapping: dict[str, str] = {}
114
+
115
+ for root in right_root_order:
116
+ if root not in used_roots:
117
+ root_mapping[root] = root
118
+ used_roots.add(root)
119
+ continue
120
+
121
+ suffix = 0
122
+ while True:
123
+ base = prefix if root in prefix else f"{prefix}{root}"
124
+ candidate_root = base if suffix == 0 else f"{base}_{suffix}"
125
+ if candidate_root not in used_roots and candidate_root not in right_roots:
126
+ root_mapping[root] = candidate_root
127
+ used_roots.add(candidate_root)
128
+ break
129
+ suffix += 1
130
+
131
+ return root_mapping
132
+
133
+
81
134
  class SignalResolvingTypeError(SignalResolvingError):
82
135
  def __init__(self, method: str, field):
83
136
  super().__init__(
@@ -87,12 +140,18 @@ class SignalResolvingTypeError(SignalResolvingError):
87
140
  )
88
141
 
89
142
 
143
+ class SignalRemoveError(SignalSchemaError):
144
+ def __init__(self, path: list[str] | None, msg: str):
145
+ name = " '" + ".".join(path) + "'" if path else ""
146
+ super().__init__(f"cannot remove signal name{name}: {msg}")
147
+
148
+
90
149
  class CustomType(BaseModel):
91
150
  schema_version: int = Field(ge=1, le=2, strict=True)
92
151
  name: str
93
152
  fields: dict[str, str]
94
- bases: list[tuple[str, str, Optional[str]]]
95
- hidden_fields: Optional[list[str]] = None
153
+ bases: list[tuple[str, str, str | None]]
154
+ hidden_fields: list[str] | None = None
96
155
 
97
156
  @classmethod
98
157
  def deserialize(cls, data: dict[str, Any], type_name: str) -> "CustomType":
@@ -112,8 +171,8 @@ class CustomType(BaseModel):
112
171
 
113
172
  def create_feature_model(
114
173
  name: str,
115
- fields: Mapping[str, Union[type, None, tuple[type, Any]]],
116
- base: Optional[type] = None,
174
+ fields: Mapping[str, type | tuple[type, Any] | None],
175
+ base: type | None = None,
117
176
  ) -> type[BaseModel]:
118
177
  """
119
178
  This gets or returns a dynamic feature model for use in restoring a model
@@ -130,7 +189,7 @@ def create_feature_model(
130
189
  **{
131
190
  field_name: anno if isinstance(anno, tuple) else (anno, None)
132
191
  for field_name, anno in fields.items()
133
- },
192
+ }, # type: ignore[arg-type]
134
193
  )
135
194
 
136
195
 
@@ -139,12 +198,12 @@ class SignalSchema:
139
198
  values: dict[str, DataType]
140
199
  tree: dict[str, Any]
141
200
  setup_func: dict[str, Callable]
142
- setup_values: Optional[dict[str, Any]]
201
+ setup_values: dict[str, Any] | None
143
202
 
144
203
  def __init__(
145
204
  self,
146
205
  values: dict[str, DataType],
147
- setup: Optional[dict[str, Callable]] = None,
206
+ setup: dict[str, Callable] | None = None,
148
207
  ):
149
208
  self.values = values
150
209
  self.tree = self._build_tree(values)
@@ -183,8 +242,8 @@ class SignalSchema:
183
242
  return SignalSchema(signals)
184
243
 
185
244
  @staticmethod
186
- def _get_bases(fr: type) -> list[tuple[str, str, Optional[str]]]:
187
- bases: list[tuple[str, str, Optional[str]]] = []
245
+ def _get_bases(fr: type) -> list[tuple[str, str, str | None]]:
246
+ bases: list[tuple[str, str, str | None]] = []
188
247
  for base in fr.__mro__:
189
248
  model_store_name = (
190
249
  ModelStore.get_name(base) if issubclass(base, DataModel) else None
@@ -250,6 +309,11 @@ class SignalSchema:
250
309
  signals["_custom_types"] = custom_types
251
310
  return signals
252
311
 
312
+ def hash(self) -> str:
313
+ """Create SHA hash of this schema"""
314
+ json_str = json.dumps(self.serialize(), sort_keys=True, separators=(",", ":"))
315
+ return hashlib.sha256(json_str.encode("utf-8")).hexdigest()
316
+
253
317
  @staticmethod
254
318
  def _split_subtypes(type_name: str) -> list[str]:
255
319
  """This splits a list of subtypes, including proper square bracket handling."""
@@ -276,7 +340,7 @@ class SignalSchema:
276
340
  @staticmethod
277
341
  def _deserialize_custom_type(
278
342
  type_name: str, custom_types: dict[str, Any]
279
- ) -> Optional[type]:
343
+ ) -> type | None:
280
344
  """Given a type name like MyType@v1 gets a type from ModelStore or recreates
281
345
  it based on the information from the custom types dict that includes fields and
282
346
  bases."""
@@ -309,7 +373,7 @@ class SignalSchema:
309
373
  return None
310
374
 
311
375
  @staticmethod
312
- def _resolve_type(type_name: str, custom_types: dict[str, Any]) -> Optional[type]:
376
+ def _resolve_type(type_name: str, custom_types: dict[str, Any]) -> type | None:
313
377
  """Convert a string-based type back into a python type."""
314
378
  type_name = type_name.strip()
315
379
  if not type_name:
@@ -318,7 +382,7 @@ class SignalSchema:
318
382
  return None
319
383
 
320
384
  bracket_idx = type_name.find("[")
321
- subtypes: Optional[tuple[Optional[type], ...]] = None
385
+ subtypes: tuple[type | None, ...] | None = None
322
386
  if bracket_idx > -1:
323
387
  if bracket_idx == 0:
324
388
  raise ValueError("Type cannot start with '['")
@@ -439,35 +503,54 @@ class SignalSchema:
439
503
  res[db_name] = python_to_sql(type_)
440
504
  return res
441
505
 
442
- def row_to_objs(self, row: Sequence[Any]) -> list[DataValue]:
506
+ def row_to_objs(self, row: Sequence[Any]) -> list[Any]:
443
507
  self._init_setup_values()
444
508
 
445
- objs: list[DataValue] = []
509
+ objs: list[Any] = []
446
510
  pos = 0
447
511
  for name, fr_type in self.values.items():
448
- if self.setup_values and (val := self.setup_values.get(name, None)):
449
- objs.append(val)
512
+ if self.setup_values and name in self.setup_values:
513
+ objs.append(self.setup_values.get(name))
450
514
  elif (fr := ModelStore.to_pydantic(fr_type)) is not None:
451
515
  j, pos = unflatten_to_json_pos(fr, row, pos)
452
- objs.append(fr(**j))
516
+ try:
517
+ obj = fr(**j)
518
+ except ValidationError as e:
519
+ if self._all_values_none(j):
520
+ logger.debug("Failed to create input for %s: %s", name, e)
521
+ obj = None
522
+ else:
523
+ raise
524
+ objs.append(obj)
453
525
  else:
454
526
  objs.append(row[pos])
455
527
  pos += 1
456
528
  return objs
457
529
 
458
- def contains_file(self) -> bool:
459
- for type_ in self.values.values():
460
- if (fr := ModelStore.to_pydantic(type_)) is not None and issubclass(
530
+ @staticmethod
531
+ def _all_values_none(value: Any) -> bool:
532
+ if isinstance(value, dict):
533
+ return all(SignalSchema._all_values_none(v) for v in value.values())
534
+ if isinstance(value, (list, tuple, set)):
535
+ return all(SignalSchema._all_values_none(v) for v in value)
536
+ if isinstance(value, float):
537
+ # NaN is used to represent NULL and NaN float values in datachain
538
+ # Since SQLite does not have a separate NULL type, we need to check for NaN
539
+ return math.isnan(value) or value is None
540
+ return value is None
541
+
542
+ def get_file_signal(self) -> str | None:
543
+ for signal_name, signal_type in self.values.items():
544
+ if (fr := ModelStore.to_pydantic(signal_type)) is not None and issubclass(
461
545
  fr, File
462
546
  ):
463
- return True
464
-
465
- return False
547
+ return signal_name
548
+ return None
466
549
 
467
550
  def slice(
468
551
  self,
469
- params: dict[str, Union[DataType, Any]],
470
- setup: Optional[dict[str, Callable]] = None,
552
+ params: dict[str, DataType | Any],
553
+ setup: dict[str, Callable] | None = None,
471
554
  is_batch: bool = False,
472
555
  ) -> "SignalSchema":
473
556
  """
@@ -491,9 +574,13 @@ class SignalSchema:
491
574
  schema_origin = get_origin(schema_type)
492
575
  param_origin = get_origin(param_type)
493
576
 
494
- if schema_origin is Union and type(None) in get_args(schema_type):
577
+ if schema_origin in (Union, types.UnionType) and type(None) in get_args(
578
+ schema_type
579
+ ):
495
580
  schema_type = get_args(schema_type)[0]
496
- if param_origin is Union and type(None) in get_args(param_type):
581
+ if param_origin in (Union, types.UnionType) and type(None) in get_args(
582
+ param_type
583
+ ):
497
584
  param_type = get_args(param_type)[0]
498
585
 
499
586
  if is_batch:
@@ -529,22 +616,97 @@ class SignalSchema:
529
616
  pos = 0
530
617
  for fr_cls in self.values.values():
531
618
  if (fr := ModelStore.to_pydantic(fr_cls)) is None:
532
- res.append(row[pos])
619
+ value = row[pos]
533
620
  pos += 1
621
+ converted = self._convert_feature_value(fr_cls, value, catalog, cache)
622
+ res.append(converted)
534
623
  else:
535
624
  json, pos = unflatten_to_json_pos(fr, row, pos) # type: ignore[union-attr]
536
- obj = fr(**json)
537
- SignalSchema._set_file_stream(obj, catalog, cache)
625
+ try:
626
+ obj = fr(**json)
627
+ SignalSchema._set_file_stream(obj, catalog, cache)
628
+ except ValidationError as e:
629
+ if self._all_values_none(json):
630
+ logger.debug("Failed to create feature for %s: %s", fr_cls, e)
631
+ obj = None
632
+ else:
633
+ raise
538
634
  res.append(obj)
539
635
  return res
540
636
 
637
+ def _convert_feature_value(
638
+ self,
639
+ annotation: DataType,
640
+ value: Any,
641
+ catalog: "Catalog",
642
+ cache: bool,
643
+ ) -> Any:
644
+ """Convert raw DB value into declared annotation if needed."""
645
+ if value is None:
646
+ return None
647
+
648
+ result = value
649
+ origin = get_origin(annotation)
650
+
651
+ if origin in (Union, types.UnionType):
652
+ non_none_args = [
653
+ arg for arg in get_args(annotation) if arg is not type(None)
654
+ ]
655
+ if len(non_none_args) == 1:
656
+ annotation = non_none_args[0]
657
+ origin = get_origin(annotation)
658
+ else:
659
+ return result
660
+
661
+ if ModelStore.is_pydantic(annotation):
662
+ if isinstance(value, annotation):
663
+ obj = value
664
+ elif isinstance(value, Mapping):
665
+ obj = annotation(**value)
666
+ else:
667
+ return result
668
+ assert isinstance(obj, BaseModel)
669
+ SignalSchema._set_file_stream(obj, catalog, cache)
670
+ result = obj
671
+ elif origin is list:
672
+ args = get_args(annotation)
673
+ if args and isinstance(value, (list, tuple)):
674
+ item_type = args[0]
675
+ result = [
676
+ self._convert_feature_value(item_type, item, catalog, cache)
677
+ if item is not None
678
+ else None
679
+ for item in value
680
+ ]
681
+ elif origin is dict:
682
+ args = get_args(annotation)
683
+ if len(args) == 2 and isinstance(value, dict):
684
+ key_type, val_type = args
685
+ result = {}
686
+ for key, val in value.items():
687
+ if key_type is str:
688
+ converted_key = key
689
+ else:
690
+ loaded_key = json.loads(key)
691
+ converted_key = self._convert_feature_value(
692
+ key_type, loaded_key, catalog, cache
693
+ )
694
+ converted_val = (
695
+ self._convert_feature_value(val_type, val, catalog, cache)
696
+ if val_type is not Any
697
+ else val
698
+ )
699
+ result[converted_key] = converted_val
700
+
701
+ return result
702
+
541
703
  @staticmethod
542
704
  def _set_file_stream(
543
705
  obj: BaseModel, catalog: "Catalog", cache: bool = False
544
706
  ) -> None:
545
707
  if isinstance(obj, File):
546
708
  obj._set_stream(catalog, caching_enabled=cache)
547
- for field, finfo in obj.model_fields.items():
709
+ for field, finfo in type(obj).model_fields.items():
548
710
  if ModelStore.is_pydantic(finfo.annotation):
549
711
  SignalSchema._set_file_stream(getattr(obj, field), catalog, cache)
550
712
 
@@ -566,8 +728,8 @@ class SignalSchema:
566
728
  raise SignalResolvingError([col_name], "is not found")
567
729
 
568
730
  def db_signals(
569
- self, name: Optional[str] = None, as_columns=False, include_hidden: bool = True
570
- ) -> Union[list[str], list[Column]]:
731
+ self, name: str | None = None, as_columns=False, include_hidden: bool = True
732
+ ) -> list[str] | list[Column]:
571
733
  """
572
734
  Returns DB columns as strings or Column objects with proper types
573
735
  Optionally, it can filter results by specific object, returning only his signals
@@ -583,6 +745,9 @@ class SignalSchema:
583
745
  ]
584
746
 
585
747
  if name:
748
+ if "." in name:
749
+ name = ColumnMeta.to_db_name(name)
750
+
586
751
  signals = [
587
752
  s
588
753
  for s in signals
@@ -591,6 +756,35 @@ class SignalSchema:
591
756
 
592
757
  return signals # type: ignore[return-value]
593
758
 
759
+ def user_signals(
760
+ self,
761
+ *,
762
+ include_hidden: bool = True,
763
+ include_sys: bool = False,
764
+ ) -> list[str]:
765
+ return [
766
+ ".".join(path)
767
+ for path, _, has_subtree, _ in self.get_flat_tree(
768
+ include_hidden=include_hidden, include_sys=include_sys
769
+ )
770
+ if not has_subtree
771
+ ]
772
+
773
+ def compare_signals(
774
+ self,
775
+ other: "SignalSchema",
776
+ *,
777
+ include_hidden: bool = True,
778
+ include_sys: bool = False,
779
+ ) -> tuple[set[str], set[str]]:
780
+ left = set(
781
+ self.user_signals(include_hidden=include_hidden, include_sys=include_sys)
782
+ )
783
+ right = set(
784
+ other.user_signals(include_hidden=include_hidden, include_sys=include_sys)
785
+ )
786
+ return left - right, right - left
787
+
594
788
  def resolve(self, *names: str) -> "SignalSchema":
595
789
  schema = {}
596
790
  for field in names:
@@ -601,37 +795,60 @@ class SignalSchema:
601
795
  return SignalSchema(schema)
602
796
 
603
797
  def _find_in_tree(self, path: list[str]) -> DataType:
798
+ if val := self.tree.get(".".join(path)):
799
+ # If the path is a single string, we can directly access it
800
+ # without traversing the tree.
801
+ return val[0]
802
+
604
803
  curr_tree = self.tree
605
804
  curr_type = None
606
805
  i = 0
607
806
  while curr_tree is not None and i < len(path):
608
807
  if val := curr_tree.get(path[i]):
609
808
  curr_type, curr_tree = val
610
- elif i == 0 and len(path) > 1 and (val := curr_tree.get(".".join(path))):
611
- curr_type, curr_tree = val
612
- break
613
809
  else:
614
810
  curr_type = None
811
+ break
615
812
  i += 1
616
813
 
617
- if curr_type is None:
814
+ if curr_type is None or i < len(path):
815
+ # If we reached the end of the path and didn't find a type,
816
+ # or if we didn't traverse the entire path, raise an error.
618
817
  raise SignalResolvingError(path, "is not found")
619
818
 
620
819
  return curr_type
621
820
 
821
+ def group_by(
822
+ self, partition_by: Sequence[str], new_column: Sequence[Column]
823
+ ) -> "SignalSchema":
824
+ orig_schema = SignalSchema(copy.deepcopy(self.values))
825
+ schema = orig_schema.to_partial(*partition_by)
826
+
827
+ vals = {c.name: sql_to_python(c) for c in new_column}
828
+ return SignalSchema(schema.values | vals)
829
+
622
830
  def select_except_signals(self, *args: str) -> "SignalSchema":
623
- schema = copy.deepcopy(self.values)
624
- for field in args:
625
- if not isinstance(field, str):
626
- raise SignalResolvingTypeError("select_except()", field)
831
+ def has_signal(signal: str):
832
+ signal = signal.replace(".", DEFAULT_DELIMITER)
833
+ return any(signal == s for s in self.db_signals())
627
834
 
628
- if field not in self.values:
835
+ schema = copy.deepcopy(self.values)
836
+ for signal in args:
837
+ if not isinstance(signal, str):
838
+ raise SignalResolvingTypeError("select_except()", signal)
839
+
840
+ if signal not in self.values:
841
+ if has_signal(signal):
842
+ raise SignalRemoveError(
843
+ signal.split("."),
844
+ "select_except() error - removing nested signal would"
845
+ " break parent schema, which isn't supported.",
846
+ )
629
847
  raise SignalResolvingError(
630
- field.split("."),
631
- "select_except() error - the feature name does not exist or "
632
- "inside of feature (not supported)",
848
+ signal.split("."),
849
+ "select_except() error - the signal does not exist",
633
850
  )
634
- del schema[field]
851
+ del schema[signal]
635
852
 
636
853
  return SignalSchema(schema)
637
854
 
@@ -645,31 +862,49 @@ class SignalSchema:
645
862
 
646
863
  def mutate(self, args_map: dict) -> "SignalSchema":
647
864
  new_values = self.values.copy()
865
+ primitives = (bool, str, int, float)
648
866
 
649
867
  for name, value in args_map.items():
868
+ current_type = None
869
+
870
+ if C.is_nested(name):
871
+ try:
872
+ current_type = self.get_column_type(name)
873
+ except SignalResolvingError as err:
874
+ msg = f"Creating new nested columns directly is not allowed: {name}"
875
+ raise ValueError(msg) from err
876
+
650
877
  if isinstance(value, Column) and value.name in self.values:
651
878
  # renaming existing signal
879
+ # Note: it won't touch nested signals here (e.g. file__path)
880
+ # we don't allow removing nested columns to keep objects consistent
652
881
  del new_values[value.name]
653
882
  new_values[name] = self.values[value.name]
654
- continue
655
- if isinstance(value, Column):
883
+ elif isinstance(value, Column):
656
884
  # adding new signal from existing signal field
657
- try:
658
- new_values[name] = self.get_column_type(
659
- value.name, with_subtree=True
660
- )
661
- continue
662
- except SignalResolvingError:
663
- pass
664
- if isinstance(value, Func):
885
+ new_values[name] = self.get_column_type(value.name, with_subtree=True)
886
+ elif isinstance(value, Func):
665
887
  # adding new signal with function
666
888
  new_values[name] = value.get_result_type(self)
667
- continue
668
- if isinstance(value, ColumnElement):
889
+ elif isinstance(value, primitives):
890
+ # For primitives, store the type, not the value
891
+ val = literal(value)
892
+ val.type = python_to_sql(type(value))()
893
+ new_values[name] = sql_to_python(val)
894
+ elif isinstance(value, ColumnElement):
669
895
  # adding new signal
670
896
  new_values[name] = sql_to_python(value)
671
- continue
672
- new_values[name] = value
897
+ else:
898
+ new_values[name] = value
899
+
900
+ if C.is_nested(name):
901
+ if current_type != new_values[name]:
902
+ msg = (
903
+ f"Altering nested column type is not allowed: {name}, "
904
+ f"current type: {current_type}, new type: {new_values[name]}"
905
+ )
906
+ raise ValueError(msg)
907
+ del new_values[name]
673
908
 
674
909
  return SignalSchema(new_values)
675
910
 
@@ -683,12 +918,37 @@ class SignalSchema:
683
918
  right_schema: "SignalSchema",
684
919
  rname: str,
685
920
  ) -> "SignalSchema":
686
- schema_right = {
687
- rname + key if key in self.values else key: type_
688
- for key, type_ in right_schema.values.items()
689
- }
921
+ merged_values = dict(self.values)
922
+
923
+ right_names = list(right_schema.values.keys())
924
+ root_mapping = generate_merge_root_mapping(
925
+ self.values.keys(),
926
+ right_names,
927
+ extract_root=self._extract_root,
928
+ prefix=rname,
929
+ )
690
930
 
691
- return SignalSchema(self.values | schema_right)
931
+ for key, type_ in right_schema.values.items():
932
+ root = self._extract_root(key)
933
+ tail = key.partition(".")[2]
934
+ mapped_root = root_mapping[root]
935
+ new_name = mapped_root if not tail else f"{mapped_root}.{tail}"
936
+ merged_values[new_name] = type_
937
+
938
+ return SignalSchema(merged_values)
939
+
940
+ @staticmethod
941
+ def _extract_root(name: str) -> str:
942
+ if "." in name:
943
+ return name.split(".", 1)[0]
944
+ return name
945
+
946
+ def append(self, right: "SignalSchema") -> "SignalSchema":
947
+ missing_schema = {
948
+ key: right.values[key]
949
+ for key in [k for k in right.values if k not in self.values]
950
+ }
951
+ return SignalSchema(self.values | missing_schema)
692
952
 
693
953
  def get_signals(self, target_type: type[DataModel]) -> Iterator[str]:
694
954
  for path, type_, has_subtree, _ in self.get_flat_tree():
@@ -701,29 +961,38 @@ class SignalSchema:
701
961
  return create_model(
702
962
  name,
703
963
  __base__=(DataModel,), # type: ignore[call-overload]
704
- **fields,
964
+ **fields, # type: ignore[arg-type]
705
965
  )
706
966
 
707
967
  @staticmethod
708
968
  def _build_tree(
709
969
  values: dict[str, DataType],
710
- ) -> dict[str, tuple[DataType, Optional[dict]]]:
970
+ ) -> dict[str, tuple[DataType, dict | None]]:
711
971
  return {
712
972
  name: (val, SignalSchema._build_tree_for_type(val))
713
973
  for name, val in values.items()
714
974
  }
715
975
 
716
976
  def get_flat_tree(
717
- self, include_hidden: bool = True
977
+ self,
978
+ include_hidden: bool = True,
979
+ include_sys: bool = True,
718
980
  ) -> Iterator[tuple[list[str], DataType, bool, int]]:
719
- yield from self._get_flat_tree(self.tree, [], 0, include_hidden)
981
+ yield from self._get_flat_tree(self.tree, [], 0, include_hidden, include_sys)
720
982
 
721
983
  def _get_flat_tree(
722
- self, tree: dict, prefix: list[str], depth: int, include_hidden: bool
984
+ self,
985
+ tree: dict,
986
+ prefix: list[str],
987
+ depth: int,
988
+ include_hidden: bool,
989
+ include_sys: bool,
723
990
  ) -> Iterator[tuple[list[str], DataType, bool, int]]:
724
991
  for name, (type_, substree) in tree.items():
725
992
  suffix = name.split(".")
726
993
  new_prefix = prefix + suffix
994
+ if not include_sys and new_prefix and new_prefix[0] == "sys":
995
+ continue
727
996
  hidden_fields = getattr(type_, "_hidden_fields", None)
728
997
  if hidden_fields and substree and not include_hidden:
729
998
  substree = {
@@ -736,10 +1005,10 @@ class SignalSchema:
736
1005
  yield new_prefix, type_, has_subtree, depth
737
1006
  if substree is not None:
738
1007
  yield from self._get_flat_tree(
739
- substree, new_prefix, depth + 1, include_hidden
1008
+ substree, new_prefix, depth + 1, include_hidden, include_sys
740
1009
  )
741
1010
 
742
- def print_tree(self, indent: int = 2, start_at: int = 0, file: Optional[IO] = None):
1011
+ def print_tree(self, indent: int = 2, start_at: int = 0, file: IO | None = None):
743
1012
  for path, type_, _, depth in self.get_flat_tree():
744
1013
  total_indent = start_at + depth * indent
745
1014
  col_name = " " * total_indent + path[-1]
@@ -769,7 +1038,28 @@ class SignalSchema:
769
1038
  ], max_length
770
1039
 
771
1040
  def __or__(self, other):
772
- return self.__class__(self.values | other.values)
1041
+ new_values = dict(self.values)
1042
+
1043
+ for name, new_type in other.values.items():
1044
+ if name in new_values:
1045
+ current_type = new_values[name]
1046
+ if current_type != new_type:
1047
+ raise DataChainColumnError(
1048
+ name,
1049
+ "signal already exists with a different type",
1050
+ )
1051
+ continue
1052
+
1053
+ root = self._extract_root(name)
1054
+ if any(self._extract_root(existing) == root for existing in new_values):
1055
+ raise DataChainColumnError(
1056
+ name,
1057
+ "signal root already exists in schema",
1058
+ )
1059
+
1060
+ new_values[name] = new_type
1061
+
1062
+ return self.__class__(new_values)
773
1063
 
774
1064
  def __contains__(self, name: str):
775
1065
  return name in self.values
@@ -778,15 +1068,20 @@ class SignalSchema:
778
1068
  return self.values.pop(name)
779
1069
 
780
1070
  @staticmethod
781
- def _type_to_str(type_: Optional[type], subtypes: Optional[list] = None) -> str: # noqa: PLR0911
1071
+ def _type_to_str(type_: type | None, subtypes: list | None = None) -> str: # noqa: C901, PLR0911
782
1072
  """Convert a type to a string-based representation."""
783
1073
  if type_ is None:
784
1074
  return "NoneType"
785
1075
 
786
1076
  origin = get_origin(type_)
787
1077
 
788
- if origin == Union:
1078
+ if origin in (Union, types.UnionType):
789
1079
  args = get_args(type_)
1080
+ if len(args) == 2 and type(None) in args:
1081
+ # This is an Optional type.
1082
+ non_none_type = args[0] if args[1] is type(None) else args[1]
1083
+ type_str = SignalSchema._type_to_str(non_none_type, subtypes)
1084
+ return f"Optional[{type_str}]"
790
1085
  formatted_types = ", ".join(
791
1086
  SignalSchema._type_to_str(arg, subtypes) for arg in args
792
1087
  )
@@ -795,21 +1090,21 @@ class SignalSchema:
795
1090
  args = get_args(type_)
796
1091
  type_str = SignalSchema._type_to_str(args[0], subtypes)
797
1092
  return f"Optional[{type_str}]"
798
- if origin in (list, List): # noqa: UP006
1093
+ if origin is list:
799
1094
  args = get_args(type_)
1095
+ if len(args) == 0:
1096
+ return "list"
800
1097
  type_str = SignalSchema._type_to_str(args[0], subtypes)
801
1098
  return f"list[{type_str}]"
802
- if origin in (dict, Dict): # noqa: UP006
1099
+ if origin is dict:
803
1100
  args = get_args(type_)
804
- type_str = (
805
- SignalSchema._type_to_str(args[0], subtypes) if len(args) > 0 else ""
806
- )
807
- vals = (
808
- f", {SignalSchema._type_to_str(args[1], subtypes)}"
809
- if len(args) > 1
810
- else ""
811
- )
812
- return f"dict[{type_str}{vals}]"
1101
+ if len(args) == 0:
1102
+ return "dict"
1103
+ key_type = SignalSchema._type_to_str(args[0], subtypes)
1104
+ if len(args) == 1:
1105
+ return f"dict[{key_type}, Any]"
1106
+ val_type = SignalSchema._type_to_str(args[1], subtypes)
1107
+ return f"dict[{key_type}, {val_type}]"
813
1108
  if origin == Annotated:
814
1109
  args = get_args(type_)
815
1110
  return SignalSchema._type_to_str(args[0], subtypes)
@@ -823,7 +1118,7 @@ class SignalSchema:
823
1118
  # Include this type in the list of all subtypes, if requested.
824
1119
  subtypes.append(type_)
825
1120
  if not hasattr(type_, "__name__"):
826
- # This can happen for some third-party or custom types, mostly on Python 3.9
1121
+ # This can happen for some third-party or custom types
827
1122
  warnings.warn(
828
1123
  f"Unable to determine name of type '{type_}'.",
829
1124
  SignalSchemaWarning,
@@ -838,7 +1133,7 @@ class SignalSchema:
838
1133
  @staticmethod
839
1134
  def _build_tree_for_type(
840
1135
  model: DataType,
841
- ) -> Optional[dict[str, tuple[DataType, Optional[dict]]]]:
1136
+ ) -> dict[str, tuple[DataType, dict | None]] | None:
842
1137
  if (fr := ModelStore.to_pydantic(model)) is not None:
843
1138
  return SignalSchema._build_tree_for_model(fr)
844
1139
  return None
@@ -846,8 +1141,8 @@ class SignalSchema:
846
1141
  @staticmethod
847
1142
  def _build_tree_for_model(
848
1143
  model: type[BaseModel],
849
- ) -> Optional[dict[str, tuple[DataType, Optional[dict]]]]:
850
- res: dict[str, tuple[DataType, Optional[dict]]] = {}
1144
+ ) -> dict[str, tuple[DataType, dict | None]] | None:
1145
+ res: dict[str, tuple[DataType, dict | None]] = {}
851
1146
 
852
1147
  for name, f_info in model.model_fields.items():
853
1148
  anno = f_info.annotation
@@ -859,7 +1154,7 @@ class SignalSchema:
859
1154
 
860
1155
  return res
861
1156
 
862
- def to_partial(self, *columns: str) -> "SignalSchema":
1157
+ def to_partial(self, *columns: str) -> "SignalSchema": # noqa: C901
863
1158
  """
864
1159
  Convert the schema to a partial schema with only the specified columns.
865
1160
 
@@ -896,15 +1191,21 @@ class SignalSchema:
896
1191
  schema: dict[str, Any] = {}
897
1192
  schema_custom_types: dict[str, CustomType] = {}
898
1193
 
899
- data_model_bases: Optional[list[tuple[str, str, Optional[str]]]] = None
1194
+ data_model_bases: list[tuple[str, str, str | None]] | None = None
900
1195
 
901
1196
  signal_partials: dict[str, str] = {}
902
1197
  partial_versions: dict[str, int] = {}
903
1198
 
904
1199
  def _type_name_to_partial(signal_name: str, type_name: str) -> str:
905
- if "@" not in type_name:
1200
+ # Check if we need to create a partial for this type
1201
+ # Only create partials for custom types that are in the custom_types dict
1202
+ if type_name not in custom_types:
906
1203
  return type_name
907
- model_name, _ = ModelStore.parse_name_version(type_name)
1204
+
1205
+ if "@" in type_name:
1206
+ model_name, _ = ModelStore.parse_name_version(type_name)
1207
+ else:
1208
+ model_name = type_name
908
1209
 
909
1210
  if signal_name not in signal_partials:
910
1211
  partial_versions.setdefault(model_name, 0)
@@ -928,6 +1229,14 @@ class SignalSchema:
928
1229
  parent_type_partial = _type_name_to_partial(signal, parent_type)
929
1230
 
930
1231
  schema[signal] = parent_type_partial
1232
+
1233
+ # If this is a complex signal without field specifier (just "file")
1234
+ # and it's a custom type, include the entire complex signal
1235
+ if len(column_parts) == 1 and parent_type in custom_types:
1236
+ # Include the entire complex signal - no need to create partial
1237
+ schema[signal] = parent_type
1238
+ continue
1239
+
931
1240
  continue
932
1241
 
933
1242
  if parent_type not in custom_types:
@@ -942,6 +1251,20 @@ class SignalSchema:
942
1251
  f"Field {signal} not found in custom type {parent_type}"
943
1252
  )
944
1253
 
1254
+ # Check if this is the last part and if the column type is a complex
1255
+ is_last_part = i == len(column_parts) - 1
1256
+ is_complex_signal = signal_type in custom_types
1257
+
1258
+ if is_last_part and is_complex_signal:
1259
+ schema[column] = signal_type
1260
+ # Also need to remove the partial schema entry we created for the
1261
+ # parent since we're promoting the nested complex column to root
1262
+ parent_signal = column_parts[0]
1263
+ schema.pop(parent_signal, None)
1264
+ # Don't create partial types for this case
1265
+ break
1266
+
1267
+ # Create partial type for this field
945
1268
  partial_type = _type_name_to_partial(
946
1269
  ".".join(column_parts[: i + 1]),
947
1270
  signal_type,