datachain 0.30.5__py3-none-any.whl → 0.39.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (119) hide show
  1. datachain/__init__.py +4 -0
  2. datachain/asyn.py +11 -12
  3. datachain/cache.py +5 -5
  4. datachain/catalog/__init__.py +0 -2
  5. datachain/catalog/catalog.py +276 -354
  6. datachain/catalog/dependency.py +164 -0
  7. datachain/catalog/loader.py +8 -3
  8. datachain/checkpoint.py +43 -0
  9. datachain/cli/__init__.py +10 -17
  10. datachain/cli/commands/__init__.py +1 -8
  11. datachain/cli/commands/datasets.py +42 -27
  12. datachain/cli/commands/ls.py +15 -15
  13. datachain/cli/commands/show.py +2 -2
  14. datachain/cli/parser/__init__.py +3 -43
  15. datachain/cli/parser/job.py +1 -1
  16. datachain/cli/parser/utils.py +1 -2
  17. datachain/cli/utils.py +2 -15
  18. datachain/client/azure.py +2 -2
  19. datachain/client/fsspec.py +34 -23
  20. datachain/client/gcs.py +3 -3
  21. datachain/client/http.py +157 -0
  22. datachain/client/local.py +11 -7
  23. datachain/client/s3.py +3 -3
  24. datachain/config.py +4 -8
  25. datachain/data_storage/db_engine.py +12 -6
  26. datachain/data_storage/job.py +2 -0
  27. datachain/data_storage/metastore.py +716 -137
  28. datachain/data_storage/schema.py +20 -27
  29. datachain/data_storage/serializer.py +105 -15
  30. datachain/data_storage/sqlite.py +114 -114
  31. datachain/data_storage/warehouse.py +140 -48
  32. datachain/dataset.py +109 -89
  33. datachain/delta.py +117 -42
  34. datachain/diff/__init__.py +25 -33
  35. datachain/error.py +24 -0
  36. datachain/func/aggregate.py +9 -11
  37. datachain/func/array.py +12 -12
  38. datachain/func/base.py +7 -4
  39. datachain/func/conditional.py +9 -13
  40. datachain/func/func.py +63 -45
  41. datachain/func/numeric.py +5 -7
  42. datachain/func/string.py +2 -2
  43. datachain/hash_utils.py +123 -0
  44. datachain/job.py +11 -7
  45. datachain/json.py +138 -0
  46. datachain/lib/arrow.py +18 -15
  47. datachain/lib/audio.py +60 -59
  48. datachain/lib/clip.py +14 -13
  49. datachain/lib/convert/python_to_sql.py +6 -10
  50. datachain/lib/convert/values_to_tuples.py +151 -53
  51. datachain/lib/data_model.py +23 -19
  52. datachain/lib/dataset_info.py +7 -7
  53. datachain/lib/dc/__init__.py +2 -1
  54. datachain/lib/dc/csv.py +22 -26
  55. datachain/lib/dc/database.py +37 -34
  56. datachain/lib/dc/datachain.py +518 -324
  57. datachain/lib/dc/datasets.py +38 -30
  58. datachain/lib/dc/hf.py +16 -20
  59. datachain/lib/dc/json.py +17 -18
  60. datachain/lib/dc/listings.py +5 -8
  61. datachain/lib/dc/pandas.py +3 -6
  62. datachain/lib/dc/parquet.py +33 -21
  63. datachain/lib/dc/records.py +9 -13
  64. datachain/lib/dc/storage.py +103 -65
  65. datachain/lib/dc/storage_pattern.py +251 -0
  66. datachain/lib/dc/utils.py +17 -14
  67. datachain/lib/dc/values.py +3 -6
  68. datachain/lib/file.py +187 -50
  69. datachain/lib/hf.py +7 -5
  70. datachain/lib/image.py +13 -13
  71. datachain/lib/listing.py +5 -5
  72. datachain/lib/listing_info.py +1 -2
  73. datachain/lib/meta_formats.py +2 -3
  74. datachain/lib/model_store.py +20 -8
  75. datachain/lib/namespaces.py +59 -7
  76. datachain/lib/projects.py +51 -9
  77. datachain/lib/pytorch.py +31 -23
  78. datachain/lib/settings.py +188 -85
  79. datachain/lib/signal_schema.py +302 -64
  80. datachain/lib/text.py +8 -7
  81. datachain/lib/udf.py +103 -63
  82. datachain/lib/udf_signature.py +59 -34
  83. datachain/lib/utils.py +20 -0
  84. datachain/lib/video.py +3 -4
  85. datachain/lib/webdataset.py +31 -36
  86. datachain/lib/webdataset_laion.py +15 -16
  87. datachain/listing.py +12 -5
  88. datachain/model/bbox.py +3 -1
  89. datachain/namespace.py +22 -3
  90. datachain/node.py +6 -6
  91. datachain/nodes_thread_pool.py +0 -1
  92. datachain/plugins.py +24 -0
  93. datachain/project.py +4 -4
  94. datachain/query/batch.py +10 -12
  95. datachain/query/dataset.py +376 -194
  96. datachain/query/dispatch.py +112 -84
  97. datachain/query/metrics.py +3 -4
  98. datachain/query/params.py +2 -3
  99. datachain/query/queue.py +2 -1
  100. datachain/query/schema.py +7 -6
  101. datachain/query/session.py +190 -33
  102. datachain/query/udf.py +9 -6
  103. datachain/remote/studio.py +90 -53
  104. datachain/script_meta.py +12 -12
  105. datachain/sql/sqlite/base.py +37 -25
  106. datachain/sql/sqlite/types.py +1 -1
  107. datachain/sql/types.py +36 -5
  108. datachain/studio.py +49 -40
  109. datachain/toolkit/split.py +31 -10
  110. datachain/utils.py +39 -48
  111. {datachain-0.30.5.dist-info → datachain-0.39.0.dist-info}/METADATA +26 -38
  112. datachain-0.39.0.dist-info/RECORD +173 -0
  113. datachain/cli/commands/query.py +0 -54
  114. datachain/query/utils.py +0 -36
  115. datachain-0.30.5.dist-info/RECORD +0 -168
  116. {datachain-0.30.5.dist-info → datachain-0.39.0.dist-info}/WHEEL +0 -0
  117. {datachain-0.30.5.dist-info → datachain-0.39.0.dist-info}/entry_points.txt +0 -0
  118. {datachain-0.30.5.dist-info → datachain-0.39.0.dist-info}/licenses/LICENSE +0 -0
  119. {datachain-0.30.5.dist-info → datachain-0.39.0.dist-info}/top_level.txt +0 -0
@@ -1,30 +1,31 @@
1
1
  import copy
2
+ import hashlib
3
+ import logging
4
+ import math
5
+ import types
2
6
  import warnings
3
- from collections.abc import Iterator, Sequence
7
+ from collections.abc import Callable, Iterable, Iterator, Mapping, Sequence
4
8
  from dataclasses import dataclass
5
9
  from datetime import datetime
6
10
  from inspect import isclass
7
- from typing import ( # noqa: UP035
11
+ from typing import (
8
12
  IO,
9
13
  TYPE_CHECKING,
10
14
  Annotated,
11
15
  Any,
12
- Callable,
13
- Dict,
14
16
  Final,
15
- List,
16
17
  Literal,
17
- Mapping,
18
18
  Optional,
19
19
  Union,
20
20
  get_args,
21
21
  get_origin,
22
22
  )
23
23
 
24
- from pydantic import BaseModel, Field, create_model
24
+ from pydantic import BaseModel, Field, ValidationError, create_model
25
25
  from sqlalchemy import ColumnElement
26
26
  from typing_extensions import Literal as LiteralEx
27
27
 
28
+ from datachain import json
28
29
  from datachain.func import literal
29
30
  from datachain.func.func import Func
30
31
  from datachain.lib.convert.python_to_sql import python_to_sql
@@ -33,7 +34,7 @@ from datachain.lib.convert.unflatten import unflatten_to_json_pos
33
34
  from datachain.lib.data_model import DataModel, DataType, DataValue
34
35
  from datachain.lib.file import File
35
36
  from datachain.lib.model_store import ModelStore
36
- from datachain.lib.utils import DataChainParamsError
37
+ from datachain.lib.utils import DataChainColumnError, DataChainParamsError
37
38
  from datachain.query.schema import DEFAULT_DELIMITER, C, Column, ColumnMeta
38
39
  from datachain.sql.types import SQLType
39
40
 
@@ -41,6 +42,8 @@ if TYPE_CHECKING:
41
42
  from datachain.catalog import Catalog
42
43
 
43
44
 
45
+ logger = logging.getLogger(__name__)
46
+
44
47
  NAMES_TO_TYPES = {
45
48
  "int": int,
46
49
  "str": str,
@@ -69,7 +72,7 @@ class SignalSchemaWarning(RuntimeWarning):
69
72
 
70
73
 
71
74
  class SignalResolvingError(SignalSchemaError):
72
- def __init__(self, path: Optional[list[str]], msg: str):
75
+ def __init__(self, path: list[str] | None, msg: str):
73
76
  name = " '" + ".".join(path) + "'" if path else ""
74
77
  super().__init__(f"cannot resolve signal name{name}: {msg}")
75
78
 
@@ -79,6 +82,55 @@ class SetupError(SignalSchemaError):
79
82
  super().__init__(f"cannot setup value '{name}': {msg}")
80
83
 
81
84
 
85
def generate_merge_root_mapping(
    left_names: Iterable[str],
    right_names: Sequence[str],
    *,
    extract_root: Callable[[str], str],
    prefix: str,
) -> dict[str, str]:
    """Compute root renames for schema merges.

    Maps every right-side root to a collision-free target root, preserving
    the order in which right-side roots first appear. Targets never clash
    with left-side roots nor with other right-side roots; on a clash, a
    candidate is derived from ``prefix`` (with ``_1``, ``_2``, ... appended
    until unique).
    """
    taken = {extract_root(name) for name in left_names}

    # Ordered de-duplication of right-side roots (dict preserves insertion order).
    ordered_right = dict.fromkeys(extract_root(name) for name in right_names)
    right_roots = set(ordered_right)

    mapping: dict[str, str] = {}
    for root in ordered_right:
        if root not in taken:
            # No clash: the root keeps its own name.
            mapping[root] = root
            taken.add(root)
            continue

        # Clash: derive a prefixed stem. NOTE: this is a substring test on
        # purpose — when the root already occurs inside the prefix, the
        # prefix alone is used as the stem (matches released behavior).
        stem = prefix if root in prefix else f"{prefix}{root}"
        candidate = stem
        counter = 0
        while candidate in taken or candidate in right_roots:
            counter += 1
            candidate = f"{stem}_{counter}"
        mapping[root] = candidate
        taken.add(candidate)

    return mapping
132
+
133
+
82
134
  class SignalResolvingTypeError(SignalResolvingError):
83
135
  def __init__(self, method: str, field):
84
136
  super().__init__(
@@ -89,7 +141,7 @@ class SignalResolvingTypeError(SignalResolvingError):
89
141
 
90
142
 
91
143
  class SignalRemoveError(SignalSchemaError):
92
- def __init__(self, path: Optional[list[str]], msg: str):
144
+ def __init__(self, path: list[str] | None, msg: str):
93
145
  name = " '" + ".".join(path) + "'" if path else ""
94
146
  super().__init__(f"cannot remove signal name{name}: {msg}")
95
147
 
@@ -98,8 +150,8 @@ class CustomType(BaseModel):
98
150
  schema_version: int = Field(ge=1, le=2, strict=True)
99
151
  name: str
100
152
  fields: dict[str, str]
101
- bases: list[tuple[str, str, Optional[str]]]
102
- hidden_fields: Optional[list[str]] = None
153
+ bases: list[tuple[str, str, str | None]]
154
+ hidden_fields: list[str] | None = None
103
155
 
104
156
  @classmethod
105
157
  def deserialize(cls, data: dict[str, Any], type_name: str) -> "CustomType":
@@ -119,8 +171,8 @@ class CustomType(BaseModel):
119
171
 
120
172
  def create_feature_model(
121
173
  name: str,
122
- fields: Mapping[str, Union[type, None, tuple[type, Any]]],
123
- base: Optional[type] = None,
174
+ fields: Mapping[str, type | tuple[type, Any] | None],
175
+ base: type | None = None,
124
176
  ) -> type[BaseModel]:
125
177
  """
126
178
  This gets or returns a dynamic feature model for use in restoring a model
@@ -137,7 +189,7 @@ def create_feature_model(
137
189
  **{
138
190
  field_name: anno if isinstance(anno, tuple) else (anno, None)
139
191
  for field_name, anno in fields.items()
140
- },
192
+ }, # type: ignore[arg-type]
141
193
  )
142
194
 
143
195
 
@@ -146,12 +198,12 @@ class SignalSchema:
146
198
  values: dict[str, DataType]
147
199
  tree: dict[str, Any]
148
200
  setup_func: dict[str, Callable]
149
- setup_values: Optional[dict[str, Any]]
201
+ setup_values: dict[str, Any] | None
150
202
 
151
203
  def __init__(
152
204
  self,
153
205
  values: dict[str, DataType],
154
- setup: Optional[dict[str, Callable]] = None,
206
+ setup: dict[str, Callable] | None = None,
155
207
  ):
156
208
  self.values = values
157
209
  self.tree = self._build_tree(values)
@@ -190,8 +242,8 @@ class SignalSchema:
190
242
  return SignalSchema(signals)
191
243
 
192
244
  @staticmethod
193
- def _get_bases(fr: type) -> list[tuple[str, str, Optional[str]]]:
194
- bases: list[tuple[str, str, Optional[str]]] = []
245
+ def _get_bases(fr: type) -> list[tuple[str, str, str | None]]:
246
+ bases: list[tuple[str, str, str | None]] = []
195
247
  for base in fr.__mro__:
196
248
  model_store_name = (
197
249
  ModelStore.get_name(base) if issubclass(base, DataModel) else None
@@ -257,6 +309,11 @@ class SignalSchema:
257
309
  signals["_custom_types"] = custom_types
258
310
  return signals
259
311
 
312
def hash(self) -> str:
    """Return a deterministic SHA-256 hex digest of this schema.

    Canonical JSON (sorted keys, no whitespace) makes the digest stable
    across runs for equal serialized schemas. ``json`` here is datachain's
    own json module (imported at the top of this file).
    """
    payload = json.dumps(self.serialize(), sort_keys=True, separators=(",", ":"))
    return hashlib.sha256(payload.encode("utf-8")).hexdigest()
316
+
260
317
  @staticmethod
261
318
  def _split_subtypes(type_name: str) -> list[str]:
262
319
  """This splits a list of subtypes, including proper square bracket handling."""
@@ -283,7 +340,7 @@ class SignalSchema:
283
340
  @staticmethod
284
341
  def _deserialize_custom_type(
285
342
  type_name: str, custom_types: dict[str, Any]
286
- ) -> Optional[type]:
343
+ ) -> type | None:
287
344
  """Given a type name like MyType@v1 gets a type from ModelStore or recreates
288
345
  it based on the information from the custom types dict that includes fields and
289
346
  bases."""
@@ -316,7 +373,7 @@ class SignalSchema:
316
373
  return None
317
374
 
318
375
  @staticmethod
319
- def _resolve_type(type_name: str, custom_types: dict[str, Any]) -> Optional[type]:
376
+ def _resolve_type(type_name: str, custom_types: dict[str, Any]) -> type | None:
320
377
  """Convert a string-based type back into a python type."""
321
378
  type_name = type_name.strip()
322
379
  if not type_name:
@@ -325,7 +382,7 @@ class SignalSchema:
325
382
  return None
326
383
 
327
384
  bracket_idx = type_name.find("[")
328
- subtypes: Optional[tuple[Optional[type], ...]] = None
385
+ subtypes: tuple[type | None, ...] | None = None
329
386
  if bracket_idx > -1:
330
387
  if bracket_idx == 0:
331
388
  raise ValueError("Type cannot start with '['")
@@ -456,13 +513,33 @@ class SignalSchema:
456
513
  objs.append(self.setup_values.get(name))
457
514
  elif (fr := ModelStore.to_pydantic(fr_type)) is not None:
458
515
  j, pos = unflatten_to_json_pos(fr, row, pos)
459
- objs.append(fr(**j))
516
+ try:
517
+ obj = fr(**j)
518
+ except ValidationError as e:
519
+ if self._all_values_none(j):
520
+ logger.debug("Failed to create input for %s: %s", name, e)
521
+ obj = None
522
+ else:
523
+ raise
524
+ objs.append(obj)
460
525
  else:
461
526
  objs.append(row[pos])
462
527
  pos += 1
463
528
  return objs
464
529
 
465
- def get_file_signal(self) -> Optional[str]:
530
@staticmethod
def _all_values_none(value: Any) -> bool:
    """Return True when *value* is None/NaN or a container holding only such values."""
    if isinstance(value, dict):
        return all(map(SignalSchema._all_values_none, value.values()))
    if isinstance(value, (list, tuple, set)):
        return all(map(SignalSchema._all_values_none, value))
    if isinstance(value, float):
        # NaN stands in for NULL float values in datachain: SQLite has no
        # separate NULL float type, so NaN must count as "none" here.
        return math.isnan(value)
    return value is None
541
+
542
+ def get_file_signal(self) -> str | None:
466
543
  for signal_name, signal_type in self.values.items():
467
544
  if (fr := ModelStore.to_pydantic(signal_type)) is not None and issubclass(
468
545
  fr, File
@@ -472,8 +549,8 @@ class SignalSchema:
472
549
 
473
550
  def slice(
474
551
  self,
475
- params: dict[str, Union[DataType, Any]],
476
- setup: Optional[dict[str, Callable]] = None,
552
+ params: dict[str, DataType | Any],
553
+ setup: dict[str, Callable] | None = None,
477
554
  is_batch: bool = False,
478
555
  ) -> "SignalSchema":
479
556
  """
@@ -497,9 +574,13 @@ class SignalSchema:
497
574
  schema_origin = get_origin(schema_type)
498
575
  param_origin = get_origin(param_type)
499
576
 
500
- if schema_origin is Union and type(None) in get_args(schema_type):
577
+ if schema_origin in (Union, types.UnionType) and type(None) in get_args(
578
+ schema_type
579
+ ):
501
580
  schema_type = get_args(schema_type)[0]
502
- if param_origin is Union and type(None) in get_args(param_type):
581
+ if param_origin in (Union, types.UnionType) and type(None) in get_args(
582
+ param_type
583
+ ):
503
584
  param_type = get_args(param_type)[0]
504
585
 
505
586
  if is_batch:
@@ -535,15 +616,90 @@ class SignalSchema:
535
616
  pos = 0
536
617
  for fr_cls in self.values.values():
537
618
  if (fr := ModelStore.to_pydantic(fr_cls)) is None:
538
- res.append(row[pos])
619
+ value = row[pos]
539
620
  pos += 1
621
+ converted = self._convert_feature_value(fr_cls, value, catalog, cache)
622
+ res.append(converted)
540
623
  else:
541
624
  json, pos = unflatten_to_json_pos(fr, row, pos) # type: ignore[union-attr]
542
- obj = fr(**json)
543
- SignalSchema._set_file_stream(obj, catalog, cache)
625
+ try:
626
+ obj = fr(**json)
627
+ SignalSchema._set_file_stream(obj, catalog, cache)
628
+ except ValidationError as e:
629
+ if self._all_values_none(json):
630
+ logger.debug("Failed to create feature for %s: %s", fr_cls, e)
631
+ obj = None
632
+ else:
633
+ raise
544
634
  res.append(obj)
545
635
  return res
546
636
 
637
def _convert_feature_value(
    self,
    annotation: DataType,
    value: Any,
    catalog: "Catalog",
    cache: bool,
) -> Any:
    """Convert raw DB value into declared annotation if needed.

    Recursively rebuilds pydantic models, lists and dicts according to the
    declared annotation; anything that does not match a handled shape is
    returned unchanged.
    """
    if value is None:
        return None

    origin = get_origin(annotation)

    # Unwrap Optional[X] (Union[X, None]); genuine multi-type unions are
    # left as-is since there is no single target type to convert to.
    if origin in (Union, types.UnionType):
        concrete = [arg for arg in get_args(annotation) if arg is not type(None)]
        if len(concrete) != 1:
            return value
        annotation = concrete[0]
        origin = get_origin(annotation)

    if ModelStore.is_pydantic(annotation):
        if isinstance(value, annotation):
            model = value
        elif isinstance(value, Mapping):
            model = annotation(**value)
        else:
            return value
        assert isinstance(model, BaseModel)
        # Attach catalog/cache so File-like members can stream their content.
        SignalSchema._set_file_stream(model, catalog, cache)
        return model

    if origin is list:
        item_args = get_args(annotation)
        if item_args and isinstance(value, (list, tuple)):
            elem_type = item_args[0]
            return [
                None
                if elem is None
                else self._convert_feature_value(elem_type, elem, catalog, cache)
                for elem in value
            ]
        return value

    if origin is dict:
        kv_args = get_args(annotation)
        if len(kv_args) == 2 and isinstance(value, dict):
            key_type, val_type = kv_args
            rebuilt: dict = {}
            for raw_key, raw_val in value.items():
                if key_type is str:
                    new_key = raw_key
                else:
                    # Non-string keys are stored JSON-encoded in the DB.
                    new_key = self._convert_feature_value(
                        key_type, json.loads(raw_key), catalog, cache
                    )
                if val_type is Any:
                    rebuilt[new_key] = raw_val
                else:
                    rebuilt[new_key] = self._convert_feature_value(
                        val_type, raw_val, catalog, cache
                    )
            return rebuilt
        return value

    return value
702
+
547
703
  @staticmethod
548
704
  def _set_file_stream(
549
705
  obj: BaseModel, catalog: "Catalog", cache: bool = False
@@ -572,8 +728,8 @@ class SignalSchema:
572
728
  raise SignalResolvingError([col_name], "is not found")
573
729
 
574
730
  def db_signals(
575
- self, name: Optional[str] = None, as_columns=False, include_hidden: bool = True
576
- ) -> Union[list[str], list[Column]]:
731
+ self, name: str | None = None, as_columns=False, include_hidden: bool = True
732
+ ) -> list[str] | list[Column]:
577
733
  """
578
734
  Returns DB columns as strings or Column objects with proper types
579
735
  Optionally, it can filter results by specific object, returning only his signals
@@ -600,6 +756,35 @@ class SignalSchema:
600
756
 
601
757
  return signals # type: ignore[return-value]
602
758
 
759
def user_signals(
    self,
    *,
    include_hidden: bool = True,
    include_sys: bool = False,
) -> list[str]:
    """Return dotted names of all leaf signals (nodes without a sub-tree)."""
    names: list[str] = []
    for path, _, has_subtree, _ in self.get_flat_tree(
        include_hidden=include_hidden, include_sys=include_sys
    ):
        if not has_subtree:
            names.append(".".join(path))
    return names
772
+
773
def compare_signals(
    self,
    other: "SignalSchema",
    *,
    include_hidden: bool = True,
    include_sys: bool = False,
) -> tuple[set[str], set[str]]:
    """Return (signals only in self, signals only in other) as name sets."""
    mine = set(
        self.user_signals(include_hidden=include_hidden, include_sys=include_sys)
    )
    theirs = set(
        other.user_signals(include_hidden=include_hidden, include_sys=include_sys)
    )
    return mine.difference(theirs), theirs.difference(mine)
787
+
603
788
  def resolve(self, *names: str) -> "SignalSchema":
604
789
  schema = {}
605
790
  for field in names:
@@ -733,12 +918,30 @@ class SignalSchema:
733
918
  right_schema: "SignalSchema",
734
919
  rname: str,
735
920
  ) -> "SignalSchema":
736
- schema_right = {
737
- rname + key if key in self.values else key: type_
738
- for key, type_ in right_schema.values.items()
739
- }
921
+ merged_values = dict(self.values)
922
+
923
+ right_names = list(right_schema.values.keys())
924
+ root_mapping = generate_merge_root_mapping(
925
+ self.values.keys(),
926
+ right_names,
927
+ extract_root=self._extract_root,
928
+ prefix=rname,
929
+ )
930
+
931
+ for key, type_ in right_schema.values.items():
932
+ root = self._extract_root(key)
933
+ tail = key.partition(".")[2]
934
+ mapped_root = root_mapping[root]
935
+ new_name = mapped_root if not tail else f"{mapped_root}.{tail}"
936
+ merged_values[new_name] = type_
740
937
 
741
- return SignalSchema(self.values | schema_right)
938
+ return SignalSchema(merged_values)
939
+
940
@staticmethod
def _extract_root(name: str) -> str:
    """Return the top-level component of a dotted signal name."""
    # str.partition yields the whole string when no "." is present,
    # so this covers both the dotted and the plain case.
    return name.partition(".")[0]
742
945
 
743
946
  def append(self, right: "SignalSchema") -> "SignalSchema":
744
947
  missing_schema = {
@@ -758,29 +961,38 @@ class SignalSchema:
758
961
  return create_model(
759
962
  name,
760
963
  __base__=(DataModel,), # type: ignore[call-overload]
761
- **fields,
964
+ **fields, # type: ignore[arg-type]
762
965
  )
763
966
 
764
967
@staticmethod
def _build_tree(
    values: dict[str, DataType],
) -> dict[str, tuple[DataType, dict | None]]:
    """Pair every signal with the sub-tree derived from its type (None for leaves)."""
    tree: dict[str, tuple[DataType, dict | None]] = {}
    for name, val in values.items():
        tree[name] = (val, SignalSchema._build_tree_for_type(val))
    return tree
772
975
 
773
976
def get_flat_tree(
    self,
    include_hidden: bool = True,
    include_sys: bool = True,
) -> Iterator[tuple[list[str], DataType, bool, int]]:
    """Yield (path, type, has_subtree, depth) for every node in the schema tree."""
    root_prefix: list[str] = []
    yield from self._get_flat_tree(
        self.tree, root_prefix, 0, include_hidden, include_sys
    )
777
982
 
778
983
  def _get_flat_tree(
779
- self, tree: dict, prefix: list[str], depth: int, include_hidden: bool
984
+ self,
985
+ tree: dict,
986
+ prefix: list[str],
987
+ depth: int,
988
+ include_hidden: bool,
989
+ include_sys: bool,
780
990
  ) -> Iterator[tuple[list[str], DataType, bool, int]]:
781
991
  for name, (type_, substree) in tree.items():
782
992
  suffix = name.split(".")
783
993
  new_prefix = prefix + suffix
994
+ if not include_sys and new_prefix and new_prefix[0] == "sys":
995
+ continue
784
996
  hidden_fields = getattr(type_, "_hidden_fields", None)
785
997
  if hidden_fields and substree and not include_hidden:
786
998
  substree = {
@@ -793,10 +1005,10 @@ class SignalSchema:
793
1005
  yield new_prefix, type_, has_subtree, depth
794
1006
  if substree is not None:
795
1007
  yield from self._get_flat_tree(
796
- substree, new_prefix, depth + 1, include_hidden
1008
+ substree, new_prefix, depth + 1, include_hidden, include_sys
797
1009
  )
798
1010
 
799
- def print_tree(self, indent: int = 2, start_at: int = 0, file: Optional[IO] = None):
1011
+ def print_tree(self, indent: int = 2, start_at: int = 0, file: IO | None = None):
800
1012
  for path, type_, _, depth in self.get_flat_tree():
801
1013
  total_indent = start_at + depth * indent
802
1014
  col_name = " " * total_indent + path[-1]
@@ -826,7 +1038,28 @@ class SignalSchema:
826
1038
  ], max_length
827
1039
 
828
1040
def __or__(self, other):
    """Merge another schema's signals into a copy of this one.

    Signals present in both schemas must have identical types; a brand-new
    signal may not reuse the root (first dotted component) of any signal
    already in the merged result.

    Raises:
        DataChainColumnError: on a type conflict or a root-name collision.
    """
    new_values = dict(self.values)
    # Track roots in a set maintained incrementally instead of rescanning
    # every existing name per incoming signal (was O(n^2) overall).
    used_roots = {self._extract_root(existing) for existing in new_values}

    for name, new_type in other.values.items():
        if name in new_values:
            if new_values[name] != new_type:
                raise DataChainColumnError(
                    name,
                    "signal already exists with a different type",
                )
            continue

        root = self._extract_root(name)
        if root in used_roots:
            raise DataChainColumnError(
                name,
                "signal root already exists in schema",
            )

        new_values[name] = new_type
        used_roots.add(root)

    return self.__class__(new_values)
830
1063
 
831
1064
def __contains__(self, name: str):
    """Support ``name in schema`` membership tests over top-level signals."""
    return name in self.values.keys()
@@ -835,15 +1068,20 @@ class SignalSchema:
835
1068
  return self.values.pop(name)
836
1069
 
837
1070
  @staticmethod
838
- def _type_to_str(type_: Optional[type], subtypes: Optional[list] = None) -> str: # noqa: PLR0911
1071
+ def _type_to_str(type_: type | None, subtypes: list | None = None) -> str: # noqa: C901, PLR0911
839
1072
  """Convert a type to a string-based representation."""
840
1073
  if type_ is None:
841
1074
  return "NoneType"
842
1075
 
843
1076
  origin = get_origin(type_)
844
1077
 
845
- if origin == Union:
1078
+ if origin in (Union, types.UnionType):
846
1079
  args = get_args(type_)
1080
+ if len(args) == 2 and type(None) in args:
1081
+ # This is an Optional type.
1082
+ non_none_type = args[0] if args[1] is type(None) else args[1]
1083
+ type_str = SignalSchema._type_to_str(non_none_type, subtypes)
1084
+ return f"Optional[{type_str}]"
847
1085
  formatted_types = ", ".join(
848
1086
  SignalSchema._type_to_str(arg, subtypes) for arg in args
849
1087
  )
@@ -852,21 +1090,21 @@ class SignalSchema:
852
1090
  args = get_args(type_)
853
1091
  type_str = SignalSchema._type_to_str(args[0], subtypes)
854
1092
  return f"Optional[{type_str}]"
855
- if origin in (list, List): # noqa: UP006
1093
+ if origin is list:
856
1094
  args = get_args(type_)
1095
+ if len(args) == 0:
1096
+ return "list"
857
1097
  type_str = SignalSchema._type_to_str(args[0], subtypes)
858
1098
  return f"list[{type_str}]"
859
- if origin in (dict, Dict): # noqa: UP006
1099
+ if origin is dict:
860
1100
  args = get_args(type_)
861
- type_str = (
862
- SignalSchema._type_to_str(args[0], subtypes) if len(args) > 0 else ""
863
- )
864
- vals = (
865
- f", {SignalSchema._type_to_str(args[1], subtypes)}"
866
- if len(args) > 1
867
- else ""
868
- )
869
- return f"dict[{type_str}{vals}]"
1101
+ if len(args) == 0:
1102
+ return "dict"
1103
+ key_type = SignalSchema._type_to_str(args[0], subtypes)
1104
+ if len(args) == 1:
1105
+ return f"dict[{key_type}, Any]"
1106
+ val_type = SignalSchema._type_to_str(args[1], subtypes)
1107
+ return f"dict[{key_type}, {val_type}]"
870
1108
  if origin == Annotated:
871
1109
  args = get_args(type_)
872
1110
  return SignalSchema._type_to_str(args[0], subtypes)
@@ -880,7 +1118,7 @@ class SignalSchema:
880
1118
  # Include this type in the list of all subtypes, if requested.
881
1119
  subtypes.append(type_)
882
1120
  if not hasattr(type_, "__name__"):
883
- # This can happen for some third-party or custom types, mostly on Python 3.9
1121
+ # This can happen for some third-party or custom types
884
1122
  warnings.warn(
885
1123
  f"Unable to determine name of type '{type_}'.",
886
1124
  SignalSchemaWarning,
@@ -895,7 +1133,7 @@ class SignalSchema:
895
1133
  @staticmethod
896
1134
  def _build_tree_for_type(
897
1135
  model: DataType,
898
- ) -> Optional[dict[str, tuple[DataType, Optional[dict]]]]:
1136
+ ) -> dict[str, tuple[DataType, dict | None]] | None:
899
1137
  if (fr := ModelStore.to_pydantic(model)) is not None:
900
1138
  return SignalSchema._build_tree_for_model(fr)
901
1139
  return None
@@ -903,8 +1141,8 @@ class SignalSchema:
903
1141
  @staticmethod
904
1142
  def _build_tree_for_model(
905
1143
  model: type[BaseModel],
906
- ) -> Optional[dict[str, tuple[DataType, Optional[dict]]]]:
907
- res: dict[str, tuple[DataType, Optional[dict]]] = {}
1144
+ ) -> dict[str, tuple[DataType, dict | None]] | None:
1145
+ res: dict[str, tuple[DataType, dict | None]] = {}
908
1146
 
909
1147
  for name, f_info in model.model_fields.items():
910
1148
  anno = f_info.annotation
@@ -953,7 +1191,7 @@ class SignalSchema:
953
1191
  schema: dict[str, Any] = {}
954
1192
  schema_custom_types: dict[str, CustomType] = {}
955
1193
 
956
- data_model_bases: Optional[list[tuple[str, str, Optional[str]]]] = None
1194
+ data_model_bases: list[tuple[str, str, str | None]] | None = None
957
1195
 
958
1196
  signal_partials: dict[str, str] = {}
959
1197
  partial_versions: dict[str, int] = {}
datachain/lib/text.py CHANGED
@@ -1,16 +1,17 @@
1
- from typing import Any, Callable, Optional, Union
1
+ from collections.abc import Callable
2
+ from typing import Any
2
3
 
3
4
  import torch
4
5
  from transformers.tokenization_utils_base import PreTrainedTokenizerBase
5
6
 
6
7
 
7
8
  def convert_text(
8
- text: Union[str, list[str]],
9
- tokenizer: Optional[Callable] = None,
10
- tokenizer_kwargs: Optional[dict[str, Any]] = None,
11
- encoder: Optional[Callable] = None,
12
- device: Optional[Union[str, torch.device]] = None,
13
- ) -> Union[str, list[str], torch.Tensor]:
9
+ text: str | list[str],
10
+ tokenizer: Callable | None = None,
11
+ tokenizer_kwargs: dict[str, Any] | None = None,
12
+ encoder: Callable | None = None,
13
+ device: str | torch.device | None = None,
14
+ ) -> str | list[str] | torch.Tensor:
14
15
  """
15
16
  Tokenize and otherwise transform text.
16
17