datachain 0.2.13__py3-none-any.whl → 0.2.15__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry, and is provided for informational purposes only.

datachain/lib/dc.py CHANGED
@@ -33,6 +33,7 @@ from datachain.lib.settings import Settings
 from datachain.lib.signal_schema import SignalSchema
 from datachain.lib.udf import (
     Aggregator,
+    BatchMapper,
     Generator,
     Mapper,
     UDFBase,
@@ -192,6 +193,8 @@ class DataChain(DatasetQuery):
         ```
     """

+    max_row_count: Optional[int] = None
+
     DEFAULT_FILE_RECORD: ClassVar[dict] = {
         "source": "",
         "name": "",
@@ -237,7 +240,6 @@ class DataChain(DatasetQuery):
     def settings(
         self,
         cache=None,
-        batch=None,
         parallel=None,
         workers=None,
         min_task_size=None,
@@ -250,7 +252,6 @@ class DataChain(DatasetQuery):

         Parameters:
             cache : data caching (default=False)
-            batch : size of the batch (default=1000)
             parallel : number of thread for processors. True is a special value to
                 enable all available CPUs (default=1)
             workers : number of distributed workers. Only for Studio mode. (default=1)
@@ -268,7 +269,7 @@ class DataChain(DatasetQuery):
         chain = self.clone()
         if sys is not None:
             chain._sys = sys
-        chain._settings.add(Settings(cache, batch, parallel, workers, min_task_size))
+        chain._settings.add(Settings(cache, parallel, workers, min_task_size))
         return chain

     def reset_settings(self, settings: Optional[Settings] = None) -> "Self":
@@ -342,9 +343,9 @@ class DataChain(DatasetQuery):
         spec: Optional[DataType] = None,
         schema_from: Optional[str] = "auto",
         jmespath: Optional[str] = None,
-        object_name: str = "",
+        object_name: Optional[str] = "",
         model_name: Optional[str] = None,
-        show_schema: Optional[bool] = False,
+        print_schema: Optional[bool] = False,
         meta_type: Optional[str] = "json",
         nrows=None,
         **kwargs,
@@ -359,17 +360,17 @@ class DataChain(DatasetQuery):
             schema_from : path to sample to infer spec (if schema not provided)
             object_name : generated object column name
             model_name : optional generated model name
-            show_schema : print auto-generated schema
+            print_schema : print auto-generated schema
             jmespath : optional JMESPATH expression to reduce JSON
             nrows : optional row limit for jsonl and JSON arrays

         Example:
-            infer JSON schema from data, reduce using JMESPATH, print schema
+            infer JSON schema from data, reduce using JMESPATH
             ```py
             chain = DataChain.from_json("gs://json", jmespath="key1.key2")
             ```

-            infer JSON schema from a particular path, print data model
+            infer JSON schema from a particular path
             ```py
             chain = DataChain.from_json("gs://json_ds", schema_from="gs://json/my.json")
             ```
@@ -384,7 +385,67 @@ class DataChain(DatasetQuery):
         if (not object_name) and jmespath:
             object_name = jmespath_to_name(jmespath)
         if not object_name:
-            object_name = "json"
+            object_name = meta_type
+        chain = DataChain.from_storage(path=path, type=type, **kwargs)
+        signal_dict = {
+            object_name: read_meta(
+                schema_from=schema_from,
+                meta_type=meta_type,
+                spec=spec,
+                model_name=model_name,
+                print_schema=print_schema,
+                jmespath=jmespath,
+                nrows=nrows,
+            )
+        }
+        return chain.gen(**signal_dict)  # type: ignore[misc, arg-type]
+
+    @classmethod
+    def from_jsonl(
+        cls,
+        path,
+        type: Literal["binary", "text", "image"] = "text",
+        spec: Optional[DataType] = None,
+        schema_from: Optional[str] = "auto",
+        jmespath: Optional[str] = None,
+        object_name: Optional[str] = "",
+        model_name: Optional[str] = None,
+        print_schema: Optional[bool] = False,
+        meta_type: Optional[str] = "jsonl",
+        nrows=None,
+        **kwargs,
+    ) -> "DataChain":
+        """Get data from JSON lines. It returns the chain itself.
+
+        Parameters:
+            path : storage URI with directory. URI must start with storage prefix such
+                as `s3://`, `gs://`, `az://` or "file:///"
+            type : read file as "binary", "text", or "image" data. Default is "binary".
+            spec : optional Data Model
+            schema_from : path to sample to infer spec (if schema not provided)
+            object_name : generated object column name
+            model_name : optional generated model name
+            print_schema : print auto-generated schema
+            jmespath : optional JMESPATH expression to reduce JSON
+            nrows : optional row limit for jsonl and JSON arrays
+
+        Example:
+            infer JSONl schema from data, limit parsing to 1 row
+            ```py
+            chain = DataChain.from_jsonl("gs://myjsonl", nrows=1)
+            ```
+        """
+        if schema_from == "auto":
+            schema_from = path
+
+        def jmespath_to_name(s: str):
+            name_end = re.search(r"\W", s).start() if re.search(r"\W", s) else len(s)  # type: ignore[union-attr]
+            return s[:name_end]
+
+        if (not object_name) and jmespath:
+            object_name = jmespath_to_name(jmespath)
+        if not object_name:
+            object_name = meta_type
         chain = DataChain.from_storage(path=path, type=type, **kwargs)
         signal_dict = {
             object_name: read_meta(
@@ -392,12 +453,12 @@ class DataChain(DatasetQuery):
                 meta_type=meta_type,
                 spec=spec,
                 model_name=model_name,
-                show_schema=show_schema,
+                print_schema=print_schema,
                 jmespath=jmespath,
                 nrows=nrows,
             )
         }
-        return chain.gen(**signal_dict)  # type: ignore[arg-type]
+        return chain.gen(**signal_dict)  # type: ignore[misc, arg-type]

     @classmethod
     def datasets(
@@ -428,7 +489,7 @@ class DataChain(DatasetQuery):
             **{object_name: datasets},  # type: ignore[arg-type]
         )

-    def show_json_schema(  # type: ignore[override]
+    def print_json_schema(  # type: ignore[override]
         self, jmespath: Optional[str] = None, model_name: Optional[str] = None
     ) -> "DataChain":
         """Print JSON data model and save it. It returns the chain itself.
@@ -453,7 +514,7 @@ class DataChain(DatasetQuery):
             output=str,
         )

-    def show_jsonl_schema(  # type: ignore[override]
+    def print_jsonl_schema(  # type: ignore[override]
         self, jmespath: Optional[str] = None, model_name: Optional[str] = None
     ) -> "DataChain":
         """Print JSON data model and save it. It returns the chain itself.
@@ -538,14 +599,16 @@ class DataChain(DatasetQuery):

         Using func and output as a map:
             ```py
-            chain = chain.map(lambda name: name[:-4] + ".json", output={"res": str})
+            chain = chain.map(
+                lambda name: name.split("."), output={"stem": str, "ext": str}
+            )
             chain.save("new_dataset")
             ```
         """
         udf_obj = self._udf_to_obj(Mapper, func, params, output, signal_map)

         chain = self.add_signals(
-            udf_obj.to_udf_wrapper(self._settings.batch),
+            udf_obj.to_udf_wrapper(),
             **self._settings.to_dict(),
         )

@@ -558,7 +621,7 @@ class DataChain(DatasetQuery):
         output: OutputType = None,
         **signal_map,
     ) -> "Self":
-        """Apply a function to each row to create new rows (with potentially new
+        r"""Apply a function to each row to create new rows (with potentially new
         signals). The function needs to return a new objects for each of the new rows.
         It returns a chain itself with new signals.

@@ -568,11 +631,20 @@ class DataChain(DatasetQuery):
         one key differences: It produces a sequence of rows for each input row (like
         extracting multiple file records from a single tar file or bounding boxes from a
         single image file).
+
+        Example:
+            ```py
+            chain = chain.gen(
+                line=lambda file: [l for l in file.read().split("\n")],
+                output=str,
+            )
+            chain.save("new_dataset")
+            ```
         """
         udf_obj = self._udf_to_obj(Generator, func, params, output, signal_map)
         chain = DatasetQuery.generate(
             self,
-            udf_obj.to_udf_wrapper(self._settings.batch),
+            udf_obj.to_udf_wrapper(),
             **self._settings.to_dict(),
         )

@@ -592,23 +664,68 @@ class DataChain(DatasetQuery):

         Input-output relationship: N:M

-        This method bears similarity to `gen()` and map(), employing a comparable set of
-        parameters, yet differs in two crucial aspects:
+        This method bears similarity to `gen()` and `map()`, employing a comparable set
+        of parameters, yet differs in two crucial aspects:
         1. The `partition_by` parameter: This specifies the column name or a list of
            column names that determine the grouping criteria for aggregation.
         2. Group-based UDF function input: Instead of individual rows, the function
            receives a list all rows within each group defined by `partition_by`.
+
+        Example:
+            ```py
+            chain = chain.agg(
+                total=lambda category, amount: [sum(amount)],
+                output=float,
+                partition_by="category",
+            )
+            chain.save("new_dataset")
+            ```
         """
         udf_obj = self._udf_to_obj(Aggregator, func, params, output, signal_map)
         chain = DatasetQuery.generate(
             self,
-            udf_obj.to_udf_wrapper(self._settings.batch),
+            udf_obj.to_udf_wrapper(),
             partition_by=partition_by,
             **self._settings.to_dict(),
         )

         return chain.reset_schema(udf_obj.output).reset_settings(self._settings)

+    def batch_map(
+        self,
+        func: Optional[Callable] = None,
+        params: Union[None, str, Sequence[str]] = None,
+        output: OutputType = None,
+        batch: int = 1000,
+        **signal_map,
+    ) -> "Self":
+        """This is a batch version of `map()`.
+
+        Input-output relationship: N:N
+
+        It accepts the same parameters plus an
+        additional parameter:
+
+            batch : Size of each batch passed to `func`. Defaults to 1000.
+
+        Example:
+            ```py
+            chain = chain.batch_map(
+                sqrt=lambda size: np.sqrt(size),
+                output=float
+            )
+            chain.save("new_dataset")
+            ```
+        """
+        udf_obj = self._udf_to_obj(BatchMapper, func, params, output, signal_map)
+        chain = DatasetQuery.add_signals(
+            self,
+            udf_obj.to_udf_wrapper(batch),
+            **self._settings.to_dict(),
+        )
+
+        return chain.add_schema(udf_obj.output).reset_settings(self._settings)
+
     def _udf_to_obj(
         self,
         target_class: type[UDFBase],
@@ -951,6 +1068,41 @@ class DataChain(DatasetQuery):

         return ds

+    def subtract(  # type: ignore[override]
+        self,
+        other: "DataChain",
+        on: Optional[Union[str, Sequence[str]]] = None,
+    ) -> "Self":
+        """Remove rows that appear in another chain.
+
+        Parameters:
+            other: chain whose rows will be removed from `self`
+            on: columns to consider for determining row equality. If unspecified,
+                defaults to all common columns between `self` and `other`.
+        """
+        if isinstance(on, str):
+            on = [on]
+        if on is None:
+            other_columns = set(other._effective_signals_schema.db_signals())
+            signals = [
+                c
+                for c in self._effective_signals_schema.db_signals()
+                if c in other_columns
+            ]
+            if not signals:
+                raise DataChainParamsError("subtract(): no common columns")
+        elif not isinstance(on, Sequence):
+            raise TypeError(
+                f"'on' must be 'str' or 'Sequence' object but got type '{type(on)}'",
+            )
+        elif not on:
+            raise DataChainParamsError(
+                "'on' cannot be empty",
+            )
+        else:
+            signals = self.signals_schema.resolve(*on).db_signals()
+        return super()._subtract(other, signals)
+
     @classmethod
     def from_values(
         cls,
@@ -1081,6 +1233,7 @@ class DataChain(DatasetQuery):
         output: OutputType = None,
         object_name: str = "",
         model_name: str = "",
+        source: bool = True,
         nrows: Optional[int] = None,
         **kwargs,
     ) -> "DataChain":
@@ -1092,8 +1245,9 @@ class DataChain(DatasetQuery):
                 case types will be inferred.
             object_name : Generated object column name.
             model_name : Generated model name.
-            kwargs : Parameters to pass to pyarrow.dataset.dataset.
+            source : Whether to include info about the source file.
             nrows : Optional row limit.
+            kwargs : Parameters to pass to pyarrow.dataset.dataset.

         Example:
             Reading a json lines file:
@@ -1120,18 +1274,24 @@ class DataChain(DatasetQuery):
         except ValueError as e:
             raise DatasetPrepareError(self.name, e) from e

+        if isinstance(output, dict):
+            model_name = model_name or object_name or ""
+            model = DataChain._dict_to_data_model(model_name, output)
+        else:
+            model = output  # type: ignore[assignment]
+
         if object_name:
-            if isinstance(output, dict):
-                model_name = model_name or object_name
-                output = DataChain._dict_to_data_model(model_name, output)
-            output = {object_name: output}  # type: ignore[dict-item]
+            output = {object_name: model}  # type: ignore[dict-item]
         elif isinstance(output, type(BaseModel)):
             output = {
                 name: info.annotation  # type: ignore[misc]
                 for name, info in output.model_fields.items()
             }
-        output = {"source": IndexedFile} | output  # type: ignore[assignment,operator]
-        return self.gen(ArrowGenerator(schema, nrows, **kwargs), output=output)
+        if source:
+            output = {"source": IndexedFile} | output  # type: ignore[assignment,operator]
+        return self.gen(
+            ArrowGenerator(schema, model, source, nrows, **kwargs), output=output
+        )

     @staticmethod
     def _dict_to_data_model(
@@ -1150,10 +1310,10 @@ class DataChain(DatasetQuery):
         path,
         delimiter: str = ",",
         header: bool = True,
-        column_names: Optional[list[str]] = None,
         output: OutputType = None,
         object_name: str = "",
         model_name: str = "",
+        source: bool = True,
         nrows=None,
         **kwargs,
     ) -> "DataChain":
@@ -1169,6 +1329,7 @@ class DataChain(DatasetQuery):
                 case types will be inferred.
             object_name : Created object column name.
             model_name : Generated model name.
+            source : Whether to include info about the source file.
             nrows : Optional row limit.

         Example:
@@ -1187,6 +1348,7 @@ class DataChain(DatasetQuery):

         chain = DataChain.from_storage(path, **kwargs)

+        column_names = None
         if not header:
             if not output:
                 msg = "error parsing csv - provide output if no header"
@@ -1208,6 +1370,7 @@ class DataChain(DatasetQuery):
             output=output,
             object_name=object_name,
             model_name=model_name,
+            source=source,
             nrows=nrows,
             format=format,
         )
@@ -1220,6 +1383,7 @@ class DataChain(DatasetQuery):
         output: Optional[dict[str, DataType]] = None,
         object_name: str = "",
         model_name: str = "",
+        source: bool = True,
         nrows=None,
         **kwargs,
     ) -> "DataChain":
@@ -1232,6 +1396,7 @@ class DataChain(DatasetQuery):
             output : Dictionary defining column names and their corresponding types.
             object_name : Created object column name.
             model_name : Generated model name.
+            source : Whether to include info about the source file.
             nrows : Optional row limit.

         Example:
@@ -1250,6 +1415,7 @@ class DataChain(DatasetQuery):
             output=output,
             object_name=object_name,
             model_name=model_name,
+            source=source,
             nrows=None,
             format="parquet",
             partitioning=partitioning,
@@ -1436,7 +1602,18 @@ class DataChain(DatasetQuery):
     @detach
     def limit(self, n: int) -> "Self":
         """Return the first n rows of the chain."""
-        return super().limit(n)
+        n = max(n, 0)
+
+        if self.max_row_count is None:
+            self.max_row_count = n
+            return super().limit(n)
+
+        limit = min(n, self.max_row_count)
+        if limit == self.max_row_count:
+            return self
+
+        self.max_row_count = limit
+        return super().limit(self.max_row_count)

     @detach
     def offset(self, offset: int) -> "Self":
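
The hunks above add `from_jsonl()`, `batch_map()`, and `subtract()` to `DataChain`. A minimal usage sketch based only on the docstrings in this diff; the `num` signal and its values are illustrative, not taken from the package:

```py
from datachain.lib.dc import DataChain

# Two small in-memory chains; "num" is a made-up signal name.
left = DataChain.from_values(num=[1, 2, 3, 4])
right = DataChain.from_values(num=[3, 4])

# batch_map(): N:N mapping where the function receives a batch of values
# and returns one result per input row; the batch size is a call argument.
doubled = left.batch_map(double=lambda num: [n * 2 for n in num], output=int, batch=2)

# subtract(): drop rows of `left` that also appear in `right`, keyed on "num"
# (all common columns are used when `on` is omitted).
remaining = left.subtract(right, on="num")

# from_jsonl() mirrors from_json() for JSON-lines input, e.g.:
# chain = DataChain.from_jsonl("gs://myjsonl", nrows=1)
```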
datachain/lib/file.py CHANGED
@@ -12,7 +12,6 @@ from urllib.parse import unquote, urlparse
 from urllib.request import url2pathname

 from fsspec.callbacks import DEFAULT_CALLBACK, Callback
-from fsspec.implementations.local import LocalFileSystem
 from PIL import Image
 from pydantic import Field, field_validator

@@ -283,9 +282,8 @@ class File(DataModel):
     def get_path(self) -> str:
         """Returns file path."""
         path = unquote(self.get_uri())
-        fs = self.get_fs()
-        if isinstance(fs, LocalFileSystem):
-            # Drop file:// protocol
+        source = urlparse(self.source)
+        if source.scheme == "file":
             path = urlparse(path).path
             path = url2pathname(path)
         return path
@@ -300,13 +298,10 @@ class File(DataModel):
         elif placement == "etag":
             path = f"{self.etag}{self.get_file_suffix()}"
         elif placement == "fullpath":
-            fs = self.get_fs()
-            if isinstance(fs, LocalFileSystem):
-                path = unquote(self.get_full_name())
-            else:
-                path = (
-                    Path(urlparse(self.source).netloc) / unquote(self.get_full_name())
-                ).as_posix()
+            path = unquote(self.get_full_name())
+            source = urlparse(self.source)
+            if source.scheme and source.scheme != "file":
+                path = posixpath.join(source.netloc, path)
         elif placement == "checksum":
             raise NotImplementedError("Checksum placement not implemented yet")
         else:
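
The `file.py` hunks replace the fsspec `LocalFileSystem` isinstance checks with a URI-scheme check on `File.source`. A standalone sketch of that scheme logic (the sources and relative path are illustrative):

```py
import posixpath
from urllib.parse import urlparse

# Mirrors the new checks: a "file" scheme keeps the plain OS path, while a
# remote scheme prefixes the relative path with the bucket (netloc).
for source in ("file:///home/user/data", "s3://bucket"):
    scheme = urlparse(source).scheme
    if scheme and scheme != "file":
        print(posixpath.join(urlparse(source).netloc, "dir/a.csv"))  # bucket/dir/a.csv
    else:
        print("dir/a.csv")
```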
datachain/lib/meta_formats.py CHANGED
@@ -11,9 +11,9 @@ from collections.abc import Iterator
 from typing import Any, Callable

 import jmespath as jsp
-from pydantic import ValidationError
+from pydantic import Field, ValidationError  # noqa: F401

-from datachain.lib.data_model import ModelStore  # noqa: F401
+from datachain.lib.data_model import DataModel  # noqa: F401
 from datachain.lib.file import File


@@ -87,7 +87,8 @@ def read_schema(source_file, data_type="csv", expr=None, model_name=None):
     except subprocess.CalledProcessError as e:
         model_output = f"An error occurred in datamodel-codegen: {e.stderr}"
     print(f"{model_output}")
-    print("\n" + f"ModelStore.register({model_name})" + "\n")
+    print("\n" + "from datachain.lib.data_model import DataModel" + "\n")
+    print("\n" + f"DataModel.register({model_name})" + "\n")
     print("\n" + f"spec={model_name}" + "\n")
     return model_output

@@ -100,7 +101,7 @@ def read_meta(  # noqa: C901
     schema_from=None,
     meta_type="json",
     jmespath=None,
-    show_schema=False,
+    print_schema=False,
     model_name=None,
     nrows=None,
 ) -> Callable:
@@ -128,7 +129,7 @@ def read_meta(  # noqa: C901
         model_output = captured_output.getvalue()
         captured_output.close()

-    if show_schema:
+    if print_schema:
         print(f"{model_output}")
     # Below 'spec' should be a dynamically converted DataModel from Pydantic
     if not spec:
@@ -147,18 +148,18 @@ def read_meta(  # noqa: C901

     def parse_data(
         file: File,
-        DataModel=spec,  # noqa: N803
+        data_model=spec,
         meta_type=meta_type,
         jmespath=jmespath,
         nrows=nrows,
     ) -> Iterator[spec]:
-        def validator(json_object: dict) -> spec:
+        def validator(json_object: dict, nrow=0) -> spec:
             json_string = json.dumps(json_object)
             try:
-                data_instance = DataModel.model_validate_json(json_string)
+                data_instance = data_model.model_validate_json(json_string)
                 yield data_instance
             except ValidationError as e:
-                print(f"Validation error occurred in file {file.name}:", e)
+                print(f"Validation error occurred in row {nrow} file {file.name}:", e)

         if meta_type == "csv":
             with (
@@ -184,7 +185,7 @@ def read_meta(  # noqa: C901
                 nrow = nrow + 1
                 if nrows is not None and nrow > nrows:
                     return
-                yield from validator(json_dict)
+                yield from validator(json_dict, nrow)

         if meta_type == "jsonl":
             try:
@@ -197,7 +198,7 @@ def read_meta(  # noqa: C901
                     return
                 json_object = process_json(data_string, jmespath)
                 data_string = fd.readline()
-                yield from validator(json_object)
+                yield from validator(json_object, nrow)
             except OSError as e:
                 print(f"An unexpected file error occurred in file {file.name}: {e}")
datachain/lib/settings.py CHANGED
@@ -7,11 +7,8 @@ class SettingsError(DataChainParamsError):


 class Settings:
-    def __init__(
-        self, cache=None, batch=None, parallel=None, workers=None, min_task_size=None
-    ):
+    def __init__(self, cache=None, parallel=None, workers=None, min_task_size=None):
         self._cache = cache
-        self._batch = batch
         self.parallel = parallel
         self._workers = workers
         self.min_task_size = min_task_size
@@ -22,12 +19,6 @@ class Settings:
                 f" while {cache.__class__.__name__} was given"
             )

-        if not isinstance(batch, int) and batch is not None:
-            raise SettingsError(
-                "'batch' argument must be int or None"
-                f" while {batch.__class__.__name__} was given"
-            )
-
         if not isinstance(parallel, int) and parallel is not None:
             raise SettingsError(
                 "'parallel' argument must be int or None"
@@ -54,10 +45,6 @@ class Settings:
     def cache(self):
         return self._cache if self._cache is not None else False

-    @property
-    def batch(self):
-        return self._batch if self._batch is not None else 1
-
     @property
     def workers(self):
         return self._workers if self._workers is not None else False
@@ -66,8 +53,6 @@ class Settings:
         res = {}
         if self._cache is not None:
             res["cache"] = self.cache
-        if self._batch is not None:
-            res["batch"] = self.batch
         if self.parallel is not None:
             res["parallel"] = self.parallel
         if self._workers is not None:
@@ -78,7 +63,6 @@ class Settings:

     def add(self, settings: "Settings"):
         self._cache = settings._cache or self._cache
-        self._batch = settings._batch or self._batch
         self.parallel = settings.parallel or self.parallel
         self._workers = settings._workers or self._workers
         self.min_task_size = settings.min_task_size or self.min_task_size
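
With `batch` removed from `Settings`, a chain-level `settings(batch=...)` is no longer accepted; the batch size is passed to `batch_map()` instead. A small sketch of the remaining settings (values are illustrative):

```py
from datachain.lib.settings import Settings

# 0.2.15 Settings accepts cache/parallel/workers/min_task_size only.
s = Settings(cache=True, parallel=4, workers=2, min_task_size=10)
print(s.to_dict())  # includes the keys that were explicitly set

# The batch size now travels with the call: chain.batch_map(..., batch=500)
```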
datachain/lib/udf.py CHANGED
@@ -225,11 +225,10 @@ class UDFBase(AbstractUDF):
     def __call__(self, *rows, cache, download_cb):
         if self.is_input_grouped:
             objs = self._parse_grouped_rows(rows[0], cache, download_cb)
+        elif self.is_input_batched:
+            objs = zip(*self._parse_rows(rows[0], cache, download_cb))
         else:
-            objs = self._parse_rows(rows, cache, download_cb)
-
-        if not self.is_input_batched:
-            objs = objs[0]
+            objs = self._parse_rows([rows], cache, download_cb)[0]

         result_objs = self.process_safe(objs)

@@ -259,17 +258,24 @@ class UDFBase(AbstractUDF):

         if not self.is_output_batched:
             res = list(res)
-            assert len(res) == 1, (
-                f"{self.name} returns {len(res)} " f"rows while it's not batched"
-            )
+            assert (
+                len(res) == 1
+            ), f"{self.name} returns {len(res)} rows while it's not batched"
             if isinstance(res[0], tuple):
                 res = res[0]
+        elif (
+            self.is_input_batched
+            and self.is_output_batched
+            and not self.is_input_grouped
+        ):
+            res = list(res)
+            assert len(res) == len(
+                rows[0]
+            ), f"{self.name} returns {len(res)} rows while len(rows[0]) expected"

         return res

     def _parse_rows(self, rows, cache, download_cb):
-        if not self.is_input_batched:
-            rows = [rows]
         objs = []
         for row in rows:
             obj_row = self.params.row_to_objs(row)
@@ -330,7 +336,9 @@ class Mapper(UDFBase):
     """Inherit from this class to pass to `DataChain.map()`."""


-class BatchMapper(Mapper):
+class BatchMapper(UDFBase):
+    """Inherit from this class to pass to `DataChain.batch_map()`."""
+
     is_input_batched = True
     is_output_batched = True