datachain 0.7.0__py3-none-any.whl → 0.7.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of datachain might be problematic.

Files changed (52)
  1. datachain/__init__.py +0 -3
  2. datachain/catalog/catalog.py +8 -6
  3. datachain/cli.py +1 -1
  4. datachain/client/fsspec.py +9 -9
  5. datachain/data_storage/schema.py +2 -2
  6. datachain/data_storage/sqlite.py +5 -4
  7. datachain/data_storage/warehouse.py +18 -18
  8. datachain/func/__init__.py +49 -0
  9. datachain/{lib/func → func}/aggregate.py +13 -11
  10. datachain/func/array.py +176 -0
  11. datachain/func/base.py +23 -0
  12. datachain/func/conditional.py +81 -0
  13. datachain/func/func.py +384 -0
  14. datachain/func/path.py +110 -0
  15. datachain/func/random.py +23 -0
  16. datachain/func/string.py +154 -0
  17. datachain/func/window.py +49 -0
  18. datachain/lib/arrow.py +24 -12
  19. datachain/lib/data_model.py +25 -9
  20. datachain/lib/dataset_info.py +2 -2
  21. datachain/lib/dc.py +94 -56
  22. datachain/lib/hf.py +1 -1
  23. datachain/lib/signal_schema.py +1 -1
  24. datachain/lib/utils.py +1 -0
  25. datachain/lib/webdataset_laion.py +5 -5
  26. datachain/model/__init__.py +6 -0
  27. datachain/model/bbox.py +102 -0
  28. datachain/model/pose.py +88 -0
  29. datachain/model/segment.py +47 -0
  30. datachain/model/ultralytics/__init__.py +27 -0
  31. datachain/model/ultralytics/bbox.py +147 -0
  32. datachain/model/ultralytics/pose.py +113 -0
  33. datachain/model/ultralytics/segment.py +91 -0
  34. datachain/nodes_fetcher.py +2 -2
  35. datachain/query/dataset.py +57 -34
  36. datachain/sql/__init__.py +0 -2
  37. datachain/sql/functions/__init__.py +0 -26
  38. datachain/sql/selectable.py +11 -5
  39. datachain/sql/sqlite/base.py +11 -2
  40. datachain/toolkit/split.py +6 -2
  41. {datachain-0.7.0.dist-info → datachain-0.7.2.dist-info}/METADATA +72 -71
  42. {datachain-0.7.0.dist-info → datachain-0.7.2.dist-info}/RECORD +46 -35
  43. {datachain-0.7.0.dist-info → datachain-0.7.2.dist-info}/WHEEL +1 -1
  44. datachain/lib/func/__init__.py +0 -32
  45. datachain/lib/func/func.py +0 -152
  46. datachain/lib/models/__init__.py +0 -5
  47. datachain/lib/models/bbox.py +0 -45
  48. datachain/lib/models/pose.py +0 -37
  49. datachain/lib/models/yolo.py +0 -39
  50. {datachain-0.7.0.dist-info → datachain-0.7.2.dist-info}/LICENSE +0 -0
  51. {datachain-0.7.0.dist-info → datachain-0.7.2.dist-info}/entry_points.txt +0 -0
  52. {datachain-0.7.0.dist-info → datachain-0.7.2.dist-info}/top_level.txt +0 -0
datachain/func/window.py ADDED
@@ -0,0 +1,49 @@
+from dataclasses import dataclass
+
+from datachain.query.schema import ColumnMeta
+
+
+@dataclass
+class Window:
+    """Represents a window specification for SQL window functions."""
+
+    partition_by: str
+    order_by: str
+    desc: bool = False
+
+
+def window(partition_by: str, order_by: str, desc: bool = False) -> Window:
+    """
+    Defines a window specification for SQL window functions.
+
+    The `window` function specifies how to partition and order the result set
+    for the associated window function. It is used to define the scope of the rows
+    that the window function will operate on.
+
+    Args:
+        partition_by (str): The column name by which to partition the result set.
+            Rows with the same value in the partition column
+            will be grouped together for the window function.
+        order_by (str): The column name by which to order the rows
+            within each partition. This determines the sequence in which
+            the window function is applied.
+        desc (bool, optional): If True, the rows will be ordered in descending order.
+            Defaults to False, which orders the rows
+            in ascending order.
+
+    Returns:
+        Window: A Window object representing the window specification.
+
+    Example:
+        ```py
+        window = func.window(partition_by="signal.category", order_by="created_at")
+        dc.mutate(
+            row_number=func.row_number().over(window),
+        )
+        ```
+    """
+    return Window(
+        ColumnMeta.to_db_name(partition_by),
+        ColumnMeta.to_db_name(order_by),
+        desc,
+    )
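The new `datachain/func/window.py` introduces the `Window` spec consumed by window functions such as `row_number()`. A minimal usage sketch adapted from the docstring above; the dataset name and the `signal.category`/`created_at` signals are illustrative:

```py
from datachain import DataChain, func

# Hypothetical dataset; replace with your own chain.
dc = DataChain.from_dataset("my-dataset")

# Number rows within each category, newest first.
w = func.window(partition_by="signal.category", order_by="created_at", desc=True)
dc = dc.mutate(row_number=func.row_number().over(w))
```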
datachain/lib/arrow.py CHANGED
@@ -116,31 +116,43 @@ def infer_schema(chain: "DataChain", **kwargs) -> pa.Schema:
     return pa.unify_schemas(schemas)
 
 
-def schema_to_output(schema: pa.Schema, col_names: Optional[Sequence[str]] = None):
-    """Generate UDF output schema from pyarrow schema."""
+def schema_to_output(
+    schema: pa.Schema, col_names: Optional[Sequence[str]] = None
+) -> tuple[dict[str, type], list[str]]:
+    """
+    Generate UDF output schema from pyarrow schema.
+    Returns a tuple of output schema and original column names (since they may be
+    normalized in the output dict).
+    """
+    signal_schema = _get_datachain_schema(schema)
+    if signal_schema:
+        return signal_schema.values, list(signal_schema.values)
+
     if col_names and (len(schema) != len(col_names)):
         raise ValueError(
             "Error generating output from Arrow schema - "
             f"Schema has {len(schema)} columns but got {len(col_names)} column names."
         )
     if not col_names:
-        col_names = schema.names
-    signal_schema = _get_datachain_schema(schema)
-    if signal_schema:
-        return signal_schema.values
-    columns = list(normalize_col_names(col_names).keys())  # type: ignore[arg-type]
+        col_names = schema.names or []
+
+    normalized_col_dict = normalize_col_names(col_names)
+    col_names = list(normalized_col_dict)
+
     hf_schema = _get_hf_schema(schema)
     if hf_schema:
         return {
-            column: hf_type for hf_type, column in zip(hf_schema[1].values(), columns)
-        }
+            column: hf_type for hf_type, column in zip(hf_schema[1].values(), col_names)
+        }, list(normalized_col_dict.values())
+
     output = {}
-    for field, column in zip(schema, columns):
-        dtype = arrow_type_mapper(field.type, column)  # type: ignore[assignment]
+    for field, column in zip(schema, col_names):
+        dtype = arrow_type_mapper(field.type, column)
         if field.nullable and not ModelStore.is_pydantic(dtype):
             dtype = Optional[dtype]  # type: ignore[assignment]
         output[column] = dtype
-    return output
+
+    return output, list(normalized_col_dict.values())
 
 
 def arrow_type_mapper(col_type: pa.DataType, column: str = "") -> type:  # noqa: PLR0911
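`schema_to_output` now returns both the normalized output schema and the original column names. A quick sketch of the new contract; the column names are illustrative:

```py
import pyarrow as pa

from datachain.lib.arrow import schema_to_output

schema = pa.schema([("file name", pa.string()), ("size", pa.int64())])
output, original_names = schema_to_output(schema)
# `output` maps normalized column names (e.g. "file_name") to Python types;
# `original_names` keeps the names as they appear in the Arrow schema, so
# callers such as dict_to_data_model() can set up validation aliases.
```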
datachain/lib/data_model.py CHANGED
@@ -1,8 +1,8 @@
 from collections.abc import Sequence
 from datetime import datetime
-from typing import ClassVar, Union, get_args, get_origin
+from typing import ClassVar, Optional, Union, get_args, get_origin
 
-from pydantic import BaseModel, Field, create_model
+from pydantic import AliasChoices, BaseModel, Field, create_model
 
 from datachain.lib.model_store import ModelStore
 from datachain.lib.utils import normalize_col_names
@@ -60,17 +60,33 @@ def is_chain_type(t: type) -> bool:
     return False
 
 
-def dict_to_data_model(name: str, data_dict: dict[str, DataType]) -> type[BaseModel]:
-    # Gets a map of a normalized_name -> original_name
-    columns = normalize_col_names(list(data_dict.keys()))
-    # We reverse if for convenience to original_name -> normalized_name
-    columns = {v: k for k, v in columns.items()}
+def dict_to_data_model(
+    name: str,
+    data_dict: dict[str, DataType],
+    original_names: Optional[list[str]] = None,
+) -> type[BaseModel]:
+    if not original_names:
+        # Gets a map of a normalized_name -> original_name
+        columns = normalize_col_names(list(data_dict))
+        data_dict = dict(zip(columns.keys(), data_dict.values()))
+        original_names = list(columns.values())
 
     fields = {
-        columns[name]: (anno, Field(alias=name)) for name, anno in data_dict.items()
+        name: (
+            anno,
+            Field(
+                validation_alias=AliasChoices(name, original_names[idx] or name),
+                default=None,
+            ),
+        )
+        for idx, (name, anno) in enumerate(data_dict.items())
     }
+
+    class _DataModelStrict(BaseModel, extra="forbid"):
+        pass
+
     return create_model(
         name,
-        __base__=(DataModel,),  # type: ignore[call-overload]
+        __base__=_DataModelStrict,
        **fields,
    )  # type: ignore[call-overload]
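`dict_to_data_model` now accepts the original (pre-normalization) names and wires them in as `validation_alias=AliasChoices(...)`, so a model built from normalized names can still validate raw rows. A sketch under those assumptions; the model and field names are illustrative:

```py
from datachain.lib.data_model import dict_to_data_model

Model = dict_to_data_model(
    "ExampleModel",
    {"file_name": str, "size": int},       # normalized names -> types
    original_names=["file name", "size"],  # names as they appeared in the source
)

# Either the normalized or the original name is accepted during validation.
row = Model.model_validate({"file name": "a.jpg", "size": 42})
```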
datachain/lib/dataset_info.py CHANGED
@@ -23,8 +23,8 @@ class DatasetInfo(DataModel):
     finished_at: Optional[datetime] = Field(default=None)
     num_objects: Optional[int] = Field(default=None)
     size: Optional[int] = Field(default=None)
-    params: dict[str, str] = Field(default=dict)
-    metrics: dict[str, Any] = Field(default=dict)
+    params: dict[str, str] = Field(default={})
+    metrics: dict[str, Any] = Field(default={})
     error_message: str = Field(default="")
     error_stack: str = Field(default="")
 
datachain/lib/dc.py CHANGED
@@ -28,13 +28,14 @@ from sqlalchemy.sql.sqltypes import NullType
 from datachain.client import Client
 from datachain.client.local import FileClient
 from datachain.dataset import DatasetRecord
+from datachain.func.base import Function
+from datachain.func.func import Func
 from datachain.lib.convert.python_to_sql import python_to_sql
 from datachain.lib.convert.values_to_tuples import values_to_tuples
 from datachain.lib.data_model import DataModel, DataType, DataValue, dict_to_data_model
 from datachain.lib.dataset_info import DatasetInfo
 from datachain.lib.file import ArrowRow, File, get_file_type
 from datachain.lib.file import ExportPlacement as FileExportPlacement
-from datachain.lib.func import Func
 from datachain.lib.listing import (
     list_bucket,
     ls,
@@ -112,9 +113,29 @@ class DatasetFromValuesError(DataChainParamsError):  # noqa: D101
         super().__init__(f"Dataset{name} from values error: {msg}")
 
 
-def _get_merge_error_str(col: Union[str, sqlalchemy.ColumnElement]) -> str:
+MergeColType = Union[str, Function, sqlalchemy.ColumnElement]
+
+
+def _validate_merge_on(
+    on: Union[MergeColType, Sequence[MergeColType]],
+    ds: "DataChain",
+) -> Sequence[MergeColType]:
+    if isinstance(on, (str, sqlalchemy.ColumnElement)):
+        return [on]
+    if isinstance(on, Function):
+        return [on.get_column(table=ds._query.table)]
+    if isinstance(on, Sequence):
+        return [
+            c.get_column(table=ds._query.table) if isinstance(c, Function) else c
+            for c in on
+        ]
+
+
+def _get_merge_error_str(col: MergeColType) -> str:
     if isinstance(col, str):
         return col
+    if isinstance(col, Function):
+        return f"{col.name}()"
     if isinstance(col, sqlalchemy.Column):
         return col.name.replace(DEFAULT_DELIMITER, ".")
     if isinstance(col, sqlalchemy.ColumnElement) and hasattr(col, "name"):
@@ -125,11 +146,13 @@ def _get_merge_error_str(col: Union[str, sqlalchemy.ColumnElement]) -> str:
 class DatasetMergeError(DataChainParamsError):  # noqa: D101
     def __init__(  # noqa: D107
         self,
-        on: Sequence[Union[str, sqlalchemy.ColumnElement]],
-        right_on: Optional[Sequence[Union[str, sqlalchemy.ColumnElement]]],
+        on: Union[MergeColType, Sequence[MergeColType]],
+        right_on: Optional[Union[MergeColType, Sequence[MergeColType]]],
         msg: str,
     ):
-        def _get_str(on: Sequence[Union[str, sqlalchemy.ColumnElement]]) -> str:
+        def _get_str(
+            on: Union[MergeColType, Sequence[MergeColType]],
+        ) -> str:
             if not isinstance(on, Sequence):
                 return str(on)  # type: ignore[unreachable]
             return ", ".join([_get_merge_error_str(col) for col in on])
@@ -348,6 +371,9 @@ class DataChain:
                 enable all available CPUs (default=1)
             workers : number of distributed workers. Only for Studio mode. (default=1)
             min_task_size : minimum number of tasks (default=1)
+            prefetch: number of workers to use for downloading files in advance.
+                This is enabled by default and uses 2 workers.
+                To disable prefetching, set it to 0.
 
         Example:
             ```py
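The `prefetch` option is new in the docstring above. Assuming it is passed through `DataChain.settings()` like the other options documented there, usage would look like:

```py
# Illustrative only: tune or disable prefetching of files.
chain = chain.settings(prefetch=4)  # use 4 prefetch workers
chain = chain.settings(prefetch=0)  # disable prefetching entirely
```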
@@ -648,6 +674,7 @@
         col: str,
         model_name: Optional[str] = None,
         object_name: Optional[str] = None,
+        schema_sample_size: int = 1,
     ) -> "DataChain":
         """Explodes a column containing JSON objects (dict or str DataChain type) into
         individual columns based on the schema of the JSON. Schema is inferred from
@@ -659,6 +686,9 @@
                 automatically.
             object_name: optional generated object column name. By default generates the
                 name automatically.
+            schema_sample_size: the number of rows to use for inferring the schema of
+                the JSON (in case some fields are optional and it's not enough to
+                analyze a single row).
 
         Returns:
             DataChain: A new DataChain instance with the new set of columns.
@@ -669,21 +699,22 @@
 
         from datachain.lib.arrow import schema_to_output
 
-        json_value = next(self.limit(1).collect(col))
-        json_dict = (
+        json_values = list(self.limit(schema_sample_size).collect(col))
+        json_dicts = [
             json.loads(json_value) if isinstance(json_value, str) else json_value
-        )
+            for json_value in json_values
+        ]
 
-        if not isinstance(json_dict, dict):
+        if any(not isinstance(json_dict, dict) for json_dict in json_dicts):
             raise TypeError(f"Column {col} should be a string or dict type with JSON")
 
-        schema = pa.Table.from_pylist([json_dict]).schema
-        output = schema_to_output(schema, None)
+        schema = pa.Table.from_pylist(json_dicts).schema
+        output, original_names = schema_to_output(schema, None)
 
         if not model_name:
             model_name = f"{col.title()}ExplodedModel"
 
-        model = dict_to_data_model(model_name, output)
+        model = dict_to_data_model(model_name, output, original_names)
 
         def json_to_model(json_value: Union[str, dict]):
             json_dict = (
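With `schema_sample_size`, `explode()` can infer the JSON schema from several rows rather than a single one, which matters when some fields are optional. A sketch; the column name "meta" is illustrative:

```py
# Infer the exploded schema from the first 10 rows of the "meta" JSON column.
chain = chain.explode("meta", schema_sample_size=10)
```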
@@ -776,7 +807,7 @@
         ```py
         uri = "gs://datachain-demo/coco2017/annotations_captions/"
         chain = DataChain.from_storage(uri)
-        chain = chain.show_json_schema()
+        chain = chain.print_json_schema()
         chain.save()
         ```
     """
@@ -1119,7 +1150,7 @@
     def group_by(
         self,
         *,
-        partition_by: Union[str, Sequence[str]],
+        partition_by: Union[str, Func, Sequence[Union[str, Func]]],
         **kwargs: Func,
     ) -> "Self":
         """Group rows by specified set of signals and return new signals
@@ -1136,36 +1167,47 @@
             )
             ```
         """
-        if isinstance(partition_by, str):
+        if isinstance(partition_by, (str, Func)):
             partition_by = [partition_by]
         if not partition_by:
             raise ValueError("At least one column should be provided for partition_by")
 
-        if not kwargs:
-            raise ValueError("At least one column should be provided for group_by")
-        for col_name, func in kwargs.items():
-            if not isinstance(func, Func):
-                raise DataChainColumnError(
-                    col_name,
-                    f"Column {col_name} has type {type(func)} but expected Func object",
-                )
-
         partition_by_columns: list[Column] = []
         signal_columns: list[Column] = []
         schema_fields: dict[str, DataType] = {}
 
         # validate partition_by columns and add them to the schema
-        for col_name in partition_by:
-            col_db_name = ColumnMeta.to_db_name(col_name)
-            col_type = self.signals_schema.get_column_type(col_db_name)
-            col = Column(col_db_name, python_to_sql(col_type))
-            partition_by_columns.append(col)
+        for col in partition_by:
+            if isinstance(col, str):
+                col_db_name = ColumnMeta.to_db_name(col)
+                col_type = self.signals_schema.get_column_type(col_db_name)
+                column = Column(col_db_name, python_to_sql(col_type))
+            elif isinstance(col, Function):
+                column = col.get_column(self.signals_schema)
+                col_db_name = column.name
+                col_type = column.type.python_type
+            else:
+                raise DataChainColumnError(
+                    col,
+                    (
+                        f"partition_by column {col} has type {type(col)}"
+                        " but expected str or Function"
+                    ),
+                )
+            partition_by_columns.append(column)
             schema_fields[col_db_name] = col_type
 
         # validate signal columns and add them to the schema
+        if not kwargs:
+            raise ValueError("At least one column should be provided for group_by")
         for col_name, func in kwargs.items():
-            col = func.get_column(self.signals_schema, label=col_name)
-            signal_columns.append(col)
+            if not isinstance(func, Func):
+                raise DataChainColumnError(
+                    col_name,
+                    f"Column {col_name} has type {type(func)} but expected Func object",
+                )
+            column = func.get_column(self.signals_schema, label=col_name)
+            signal_columns.append(column)
             schema_fields[col_name] = func.get_result_type(self.signals_schema)
 
         return self._evolve(
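`group_by()` now also accepts `Func` objects in `partition_by`, so rows can be grouped by a computed key. A sketch; `func.path.parent` is assumed to be among the helpers in the new `datachain.func` package, and the signal names are illustrative:

```py
from datachain import func

# Group files by their parent directory and count them.
chain = chain.group_by(
    cnt=func.count(),
    partition_by=func.path.parent("file.path"),
)
```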
@@ -1413,25 +1455,16 @@
     def merge(
         self,
         right_ds: "DataChain",
-        on: Union[
-            str,
-            sqlalchemy.ColumnElement,
-            Sequence[Union[str, sqlalchemy.ColumnElement]],
-        ],
-        right_on: Union[
-            str,
-            sqlalchemy.ColumnElement,
-            Sequence[Union[str, sqlalchemy.ColumnElement]],
-            None,
-        ] = None,
+        on: Union[MergeColType, Sequence[MergeColType]],
+        right_on: Optional[Union[MergeColType, Sequence[MergeColType]]] = None,
         inner=False,
         rname="right_",
     ) -> "Self":
         """Merge two chains based on the specified criteria.
 
         Parameters:
-            right_ds : Chain to join with.
-            on : Predicate or list of Predicates to join on. If both chains have the
+            right_ds: Chain to join with.
+            on: Predicate or list of Predicates to join on. If both chains have the
                 same predicates then this predicate is enough for the join. Otherwise,
                 `right_on` parameter has to specify the predicates for the other chain.
             right_on: Optional predicate or list of Predicates
@@ -1448,23 +1481,24 @@
         if on is None:
             raise DatasetMergeError(["None"], None, "'on' must be specified")
 
-        if isinstance(on, (str, sqlalchemy.ColumnElement)):
-            on = [on]
-        elif not isinstance(on, Sequence):
+        on = _validate_merge_on(on, self)
+        if not on:
             raise DatasetMergeError(
                 on,
                 right_on,
-                f"'on' must be 'str' or 'Sequence' object but got type '{type(on)}'",
+                (
+                    "'on' must be 'str', 'Func' or 'Sequence' object "
+                    f"but got type '{type(on)}'"
+                ),
             )
 
         if right_on is not None:
-            if isinstance(right_on, (str, sqlalchemy.ColumnElement)):
-                right_on = [right_on]
-            elif not isinstance(right_on, Sequence):
+            right_on = _validate_merge_on(right_on, right_ds)
+            if not right_on:
                 raise DatasetMergeError(
                     on,
                     right_on,
-                    "'right_on' must be 'str' or 'Sequence' object"
+                    "'right_on' must be 'str', 'Func' or 'Sequence' object"
                     f" but got type '{type(right_on)}'",
                 )
 
@@ -1480,10 +1514,12 @@
 
         def _resolve(
             ds: DataChain,
-            col: Union[str, sqlalchemy.ColumnElement],
+            col: Union[str, Function, sqlalchemy.ColumnElement],
             side: Union[str, None],
         ):
             try:
+                if isinstance(col, Function):
+                    return ds.c(col.get_column())
                 return ds.c(col) if isinstance(col, (str, C)) else col
             except ValueError:
                 if side:
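`merge()` similarly gains `Function` support for `on`/`right_on`, so two chains can be joined on a computed key. A sketch; `func.path.file_stem` is assumed from the new `datachain.func.path` module, and the chain and signal names are illustrative:

```py
from datachain import func

# Join images to captions by the file name without its extension.
merged = images.merge(
    captions,
    on=func.path.file_stem("file.path"),
    right_on=func.path.file_stem("caption_file.path"),
    inner=True,
)
```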
@@ -1834,13 +1870,14 @@
         if col_names or not output:
             try:
                 schema = infer_schema(self, **kwargs)
-                output = schema_to_output(schema, col_names)
+                output, _ = schema_to_output(schema, col_names)
             except ValueError as e:
                 raise DatasetPrepareError(self.name, e) from e
 
         if isinstance(output, dict):
             model_name = model_name or object_name or ""
             model = dict_to_data_model(model_name, output)
+            output = model
         else:
             model = output  # type: ignore[assignment]
 
@@ -1851,6 +1888,7 @@
             name: info.annotation  # type: ignore[misc]
             for name, info in output.model_fields.items()
         }
+
         if source:
             output = {"source": ArrowRow} | output  # type: ignore[assignment,operator]
         return self.gen(
@@ -2389,9 +2427,9 @@
         dc.filter(C("file.name").glob("*.jpg"))
         ```
 
-        Using `datachain.sql.functions`
+        Using `datachain.func`
         ```py
-        from datachain.sql.functions import string
+        from datachain.func import string
         dc.filter(string.length(C("file.name")) > 5)
         ```
 
datachain/lib/hf.py CHANGED
@@ -98,7 +98,7 @@ class HFGenerator(Generator):
         with tqdm(desc=desc, unit=" rows") as pbar:
             for row in ds:
                 output_dict = {}
-                if split:
+                if split and "split" in self.output_schema.model_fields:
                     output_dict["split"] = split
                 for name, feat in ds.features.items():
                     anno = self.output_schema.model_fields[name].annotation
datachain/lib/signal_schema.py CHANGED
@@ -23,12 +23,12 @@ from pydantic import BaseModel, create_model
 from sqlalchemy import ColumnElement
 from typing_extensions import Literal as LiteralEx
 
+from datachain.func.func import Func
 from datachain.lib.convert.python_to_sql import python_to_sql
 from datachain.lib.convert.sql_to_python import sql_to_python
 from datachain.lib.convert.unflatten import unflatten_to_json_pos
 from datachain.lib.data_model import DataModel, DataType, DataValue
 from datachain.lib.file import File
-from datachain.lib.func import Func
 from datachain.lib.model_store import ModelStore
 from datachain.lib.utils import DataChainParamsError
 from datachain.query.schema import DEFAULT_DELIMITER, Column
datachain/lib/utils.py CHANGED
@@ -33,6 +33,7 @@ class DataChainColumnError(DataChainParamsError):
 
 
 def normalize_col_names(col_names: Sequence[str]) -> dict[str, str]:
+    """Returns normalized_name -> original_name dict."""
     gen_col_counter = 0
     new_col_names = {}
     org_col_names = set(col_names)
datachain/lib/webdataset_laion.py CHANGED
@@ -49,11 +49,11 @@ class WDSLaion(WDSBasic):
 class LaionMeta(BaseModel):
     file: File
     index: Optional[int] = Field(default=None)
-    b32_img: list[float] = Field(default=None)
-    b32_txt: list[float] = Field(default=None)
-    l14_img: list[float] = Field(default=None)
-    l14_txt: list[float] = Field(default=None)
-    dedup: list[float] = Field(default=None)
+    b32_img: list[float] = Field(default=[])
+    b32_txt: list[float] = Field(default=[])
+    l14_img: list[float] = Field(default=[])
+    l14_txt: list[float] = Field(default=[])
+    dedup: list[float] = Field(default=[])
 
 
 def process_laion_meta(file: File) -> Iterator[LaionMeta]:
datachain/model/__init__.py ADDED
@@ -0,0 +1,6 @@
+from . import ultralytics
+from .bbox import BBox, OBBox
+from .pose import Pose, Pose3D
+from .segment import Segment
+
+__all__ = ["BBox", "OBBox", "Pose", "Pose3D", "Segment", "ultralytics"]
datachain/model/bbox.py ADDED
@@ -0,0 +1,102 @@
+from pydantic import Field
+
+from datachain.lib.data_model import DataModel
+
+
+class BBox(DataModel):
+    """
+    A data model for representing bounding box.
+
+    Attributes:
+        title (str): The title of the bounding box.
+        coords (list[int]): The coordinates of the bounding box.
+
+    The bounding box is defined by two points:
+        - (x1, y1): The top-left corner of the box.
+        - (x2, y2): The bottom-right corner of the box.
+    """
+
+    title: str = Field(default="")
+    coords: list[int] = Field(default=[])
+
+    @staticmethod
+    def from_list(coords: list[float], title: str = "") -> "BBox":
+        assert len(coords) == 4, "Bounding box must be a list of 4 coordinates."
+        assert all(
+            isinstance(value, (int, float)) for value in coords
+        ), "Bounding box coordinates must be floats or integers."
+        return BBox(
+            title=title,
+            coords=[round(c) for c in coords],
+        )
+
+    @staticmethod
+    def from_dict(coords: dict[str, float], title: str = "") -> "BBox":
+        assert isinstance(coords, dict) and set(coords) == {
+            "x1",
+            "y1",
+            "x2",
+            "y2",
+        }, "Bounding box must be a dictionary with keys 'x1', 'y1', 'x2' and 'y2'."
+        return BBox.from_list(
+            [coords["x1"], coords["y1"], coords["x2"], coords["y2"]],
+            title=title,
+        )
+
+
+class OBBox(DataModel):
+    """
+    A data model for representing oriented bounding boxes.
+
+    Attributes:
+        title (str): The title of the oriented bounding box.
+        coords (list[int]): The coordinates of the oriented bounding box.
+
+    The oriented bounding box is defined by four points:
+        - (x1, y1): The first corner of the box.
+        - (x2, y2): The second corner of the box.
+        - (x3, y3): The third corner of the box.
+        - (x4, y4): The fourth corner of the box.
+    """
+
+    title: str = Field(default="")
+    coords: list[int] = Field(default=[])
+
+    @staticmethod
+    def from_list(coords: list[float], title: str = "") -> "OBBox":
+        assert (
+            len(coords) == 8
+        ), "Oriented bounding box must be a list of 8 coordinates."
+        assert all(
+            isinstance(value, (int, float)) for value in coords
+        ), "Oriented bounding box coordinates must be floats or integers."
+        return OBBox(
+            title=title,
+            coords=[round(c) for c in coords],
+        )
+
+    @staticmethod
+    def from_dict(coords: dict[str, float], title: str = "") -> "OBBox":
+        assert isinstance(coords, dict) and set(coords) == {
+            "x1",
+            "y1",
+            "x2",
+            "y2",
+            "x3",
+            "y3",
+            "x4",
+            "y4",
+        }, "Oriented bounding box must be a dictionary with coordinates."
+        return OBBox.from_list(
+            [
+                coords["x1"],
+                coords["y1"],
+                coords["x2"],
+                coords["y2"],
+                coords["x3"],
+                coords["y3"],
+                coords["x4"],
+                coords["y4"],
+            ],
+            title=title,
+        )
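The new `BBox`/`OBBox` models round incoming coordinates to integers. A short sketch of the constructors defined above; the values are illustrative:

```py
from datachain.model import BBox

box = BBox.from_list([10.5, 20.1, 110.0, 220.7], title="person")
same = BBox.from_dict(
    {"x1": 10.5, "y1": 20.1, "x2": 110.0, "y2": 220.7}, title="person"
)
# Both produce integer coords via round(): [10, 20, 110, 221]
assert box.coords == same.coords == [10, 20, 110, 221]
```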