datachain 0.2.10__py3-none-any.whl → 0.2.12__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.



Files changed (49)
  1. datachain/__init__.py +3 -4
  2. datachain/cache.py +10 -4
  3. datachain/catalog/catalog.py +35 -15
  4. datachain/cli.py +37 -32
  5. datachain/data_storage/metastore.py +24 -0
  6. datachain/data_storage/warehouse.py +3 -1
  7. datachain/job.py +56 -0
  8. datachain/lib/arrow.py +19 -7
  9. datachain/lib/clip.py +89 -66
  10. datachain/lib/convert/{type_converter.py → python_to_sql.py} +6 -6
  11. datachain/lib/convert/sql_to_python.py +23 -0
  12. datachain/lib/convert/values_to_tuples.py +51 -33
  13. datachain/lib/data_model.py +6 -27
  14. datachain/lib/dataset_info.py +70 -0
  15. datachain/lib/dc.py +646 -152
  16. datachain/lib/file.py +117 -15
  17. datachain/lib/image.py +1 -1
  18. datachain/lib/meta_formats.py +14 -2
  19. datachain/lib/model_store.py +3 -2
  20. datachain/lib/pytorch.py +10 -7
  21. datachain/lib/signal_schema.py +39 -14
  22. datachain/lib/text.py +2 -1
  23. datachain/lib/udf.py +56 -5
  24. datachain/lib/udf_signature.py +1 -1
  25. datachain/lib/webdataset.py +4 -3
  26. datachain/node.py +11 -8
  27. datachain/query/dataset.py +66 -147
  28. datachain/query/dispatch.py +15 -13
  29. datachain/query/schema.py +2 -0
  30. datachain/query/session.py +4 -4
  31. datachain/sql/functions/array.py +12 -0
  32. datachain/sql/functions/string.py +8 -0
  33. datachain/torch/__init__.py +1 -1
  34. datachain/utils.py +45 -0
  35. datachain-0.2.12.dist-info/METADATA +412 -0
  36. {datachain-0.2.10.dist-info → datachain-0.2.12.dist-info}/RECORD +40 -45
  37. {datachain-0.2.10.dist-info → datachain-0.2.12.dist-info}/WHEEL +1 -1
  38. datachain/lib/feature_registry.py +0 -77
  39. datachain/lib/gpt4_vision.py +0 -97
  40. datachain/lib/hf_image_to_text.py +0 -97
  41. datachain/lib/hf_pipeline.py +0 -90
  42. datachain/lib/image_transform.py +0 -103
  43. datachain/lib/iptc_exif_xmp.py +0 -76
  44. datachain/lib/unstructured.py +0 -41
  45. datachain/text/__init__.py +0 -3
  46. datachain-0.2.10.dist-info/METADATA +0 -430
  47. {datachain-0.2.10.dist-info → datachain-0.2.12.dist-info}/LICENSE +0 -0
  48. {datachain-0.2.10.dist-info → datachain-0.2.12.dist-info}/entry_points.txt +0 -0
  49. {datachain-0.2.10.dist-info → datachain-0.2.12.dist-info}/top_level.txt +0 -0
datachain/lib/dc.py CHANGED
@@ -1,22 +1,31 @@
  import copy
+ import os
  import re
  from collections.abc import Iterator, Sequence
+ from functools import wraps
  from typing import (
      TYPE_CHECKING,
      Any,
+     BinaryIO,
      Callable,
      ClassVar,
      Literal,
      Optional,
+     TypeVar,
      Union,
+     overload,
  )

+ import pandas as pd
  import sqlalchemy
  from pydantic import BaseModel, create_model
+ from sqlalchemy.sql.functions import GenericFunction

  from datachain import DataModel
  from datachain.lib.convert.values_to_tuples import values_to_tuples
  from datachain.lib.data_model import DataType
+ from datachain.lib.dataset_info import DatasetInfo
+ from datachain.lib.file import ExportPlacement as FileExportPlacement
  from datachain.lib.file import File, IndexedFile, get_file
  from datachain.lib.meta_formats import read_meta, read_schema
  from datachain.lib.model_store import ModelStore
@@ -24,7 +33,6 @@ from datachain.lib.settings import Settings
  from datachain.lib.signal_schema import SignalSchema
  from datachain.lib.udf import (
      Aggregator,
-     BatchMapper,
      Generator,
      Mapper,
      UDFBase,
@@ -38,29 +46,60 @@ from datachain.query.dataset import (
      detach,
  )
  from datachain.query.schema import Column, DatasetRow
+ from datachain.utils import inside_notebook

  if TYPE_CHECKING:
-     import pandas as pd
-     from typing_extensions import Self
+     from typing_extensions import Concatenate, ParamSpec, Self
+
+     P = ParamSpec("P")

  C = Column

+ _T = TypeVar("_T")
+ D = TypeVar("D", bound="DataChain")
+
+
+ def resolve_columns(
+     method: "Callable[Concatenate[D, P], D]",
+ ) -> "Callable[Concatenate[D, P], D]":
+     """Decorator that resolvs input column names to their actual DB names. This is
+     specially important for nested columns as user works with them by using dot
+     notation e.g (file.name) but are actually defined with default delimiter
+     in DB, e.g file__name.
+     If there are any sql functions in arguments, they will just be transferred as is
+     to a method.
+     """
+
+     @wraps(method)
+     def _inner(self: D, *args: "P.args", **kwargs: "P.kwargs") -> D:
+         resolved_args = self.signals_schema.resolve(
+             *[arg for arg in args if not isinstance(arg, GenericFunction)]
+         ).db_signals()
+
+         for idx, arg in enumerate(args):
+             if isinstance(arg, GenericFunction):
+                 resolved_args.insert(idx, arg)
+
+         return method(self, *resolved_args, **kwargs)
+
+     return _inner

- class DatasetPrepareError(DataChainParamsError):
-     def __init__(self, name, msg, output=None):
+
+ class DatasetPrepareError(DataChainParamsError):  # noqa: D101
+     def __init__(self, name, msg, output=None):  # noqa: D107
          name = f" '{name}'" if name else ""
          output = f" output '{output}'" if output else ""
          super().__init__(f"Dataset{name}{output} processing prepare error: {msg}")


- class DatasetFromValuesError(DataChainParamsError):
-     def __init__(self, name, msg):
+ class DatasetFromValuesError(DataChainParamsError):  # noqa: D101
+     def __init__(self, name, msg):  # noqa: D107
          name = f" '{name}'" if name else ""
-         super().__init__(f"Dataset {name} from values error: {msg}")
+         super().__init__(f"Dataset{name} from values error: {msg}")


- class DatasetMergeError(DataChainParamsError):
-     def __init__(self, on: Sequence[str], right_on: Optional[Sequence[str]], msg: str):
+ class DatasetMergeError(DataChainParamsError):  # noqa: D101
+     def __init__(self, on: Sequence[str], right_on: Optional[Sequence[str]], msg: str):  # noqa: D107
          on_str = ", ".join(on) if isinstance(on, Sequence) else ""
          right_on_str = (
              ", right_on='" + ", ".join(right_on) + "'"
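The new `resolve_columns` decorator rewrites user-facing dot notation into the flattened DB column names before calling the wrapped method. A minimal standalone sketch of that naming convention, assuming the default `__` delimiter described in its docstring:

```py
# Illustration only: dot-notation signal names map to flat DB names
# with the default "__" delimiter; this helper is not part of datachain.
def to_db_name(signal: str, delimiter: str = "__") -> str:
    return signal.replace(".", delimiter)

assert to_db_name("file.name") == "file__name"
assert to_db_name("file.parent") == "file__parent"
```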
@@ -74,6 +113,8 @@ OutputType = Union[None, DataType, Sequence[str], dict[str, DataType]]


  class Sys(DataModel):
+     """Model for internal DataChain signals `id` and `rand`."""
+
      id: int
      rand: int

@@ -86,7 +127,7 @@ class DataChain(DatasetQuery):
      enrich data.

      Data in DataChain is presented as Python classes with arbitrary set of fields,
-     including nested classes. The data classes have to inherit from `Feature` class.
+     including nested classes. The data classes have to inherit from `DataModel` class.
      The supported set of field types include: majority of the type supported by the
      underlyind library `Pydantic`.

@@ -98,34 +139,56 @@ class DataChain(DatasetQuery):

      `DataChain.from_dataset("name")` - reading from a dataset.

-     `DataChain.from_features(fib=[1, 2, 3, 5, 8])` - generating from a values.
+     `DataChain.from_values(fib=[1, 2, 3, 5, 8])` - generating from values.
+
+     `DataChain.from_pandas(pd.DataFrame(...))` - generating from pandas.
+
+     `DataChain.from_json("file.json")` - generating from json.

+     `DataChain.from_csv("file.csv")` - generating from csv.
+
+     `DataChain.from_parquet("file.parquet")` - generating from parquet.

      Example:
          ```py
-         from datachain import DataChain, Feature
-         from datachain.lib.claude import claude_processor
+         import os
+
+         from mistralai.client import MistralClient
+         from mistralai.models.chat_completion import ChatMessage
+
+         from datachain.dc import DataChain, Column

-         class Rating(Feature):
-             status: str = ""
-             explanation: str = ""
+         PROMPT = (
+             "Was this bot dialog successful? "
+             "Describe the 'result' as 'Yes' or 'No' in a short JSON"
+         )

-         PROMPT = "A 'user' is a human trying to find the best mobile plan.... "
-         MODEL = "claude-3-opus-20240229"
+         model = "mistral-large-latest"
+         api_key = os.environ["MISTRAL_API_KEY"]

          chain = (
-             DataChain.from_storage("s3://my-bucket/my")
-             .filter(C.name.glob("*.txt"))
+             DataChain.from_storage("gs://datachain-demo/chatbot-KiT/")
              .limit(5)
-             .map(claude=claude_processor(prompt=PROMPT, model=MODEL))
+             .settings(cache=True, parallel=5)
              .map(
-                 rating=lambda claude: Rating(
-                     **(json.loads(claude.content[0].text) if claude.content else {})
-                 ),
-                 output=Rating,
+                 mistral_response=lambda file: MistralClient(api_key=api_key)
+                 .chat(
+                     model=model,
+                     response_format={"type": "json_object"},
+                     messages=[
+                         ChatMessage(role="user", content=f"{PROMPT}: {file.read()}")
+                     ],
+                 )
+                 .choices[0]
+                 .message.content,
+             )
+             .save()
          )
-         chain.save("ratings")
-         print(chain)
+
+         try:
+             print(chain.select("mistral_response").results())
+         except Exception as e:
+             print(f"do you have the right Mistral API key? {e}")
          ```
      """

@@ -137,8 +200,9 @@ class DataChain(DatasetQuery):
      }

      def __init__(self, *args, **kwargs):
-         """This method needs to be redefined as a part of Dataset and DacaChin
-         decoupling."""
+         """This method needs to be redefined as a part of Dataset and DataChain
+         decoupling.
+         """
          super().__init__(
              *args,
              **kwargs,
@@ -147,19 +211,25 @@ class DataChain(DatasetQuery):
          self._settings = Settings()
          self._setup = {}

+         self.signals_schema = SignalSchema({"sys": Sys})
          if self.feature_schema:
-             self.signals_schema = SignalSchema.deserialize(self.feature_schema)
+             self.signals_schema |= SignalSchema.deserialize(self.feature_schema)
          else:
-             self.signals_schema = SignalSchema.from_column_types(self.column_types)
+             self.signals_schema |= SignalSchema.from_column_types(self.column_types)
+
+         self._sys = False

      @property
-     def schema(self):
-         return self.signals_schema.values if self.signals_schema else None
+     def schema(self) -> dict[str, DataType]:
+         """Get schema of the chain."""
+         return self._effective_signals_schema.values

-     def print_schema(self):
-         self.signals_schema.print_tree()
+     def print_schema(self) -> None:
+         """Print schema of the chain."""
+         self._effective_signals_schema.print_tree()

      def clone(self, new_table: bool = True) -> "Self":
+         """Make a copy of the chain in a new table."""
          obj = super().clone(new_table=new_table)
          obj.signals_schema = copy.deepcopy(self.signals_schema)
          return obj
@@ -171,7 +241,7 @@ class DataChain(DatasetQuery):
          parallel=None,
          workers=None,
          min_task_size=None,
-         include_sys: Optional[bool] = None,
+         sys: Optional[bool] = None,
      ) -> "Self":
          """Change settings for chain.

@@ -196,10 +266,8 @@ class DataChain(DatasetQuery):
          ```
          """
          chain = self.clone()
-         if include_sys is True:
-             chain.signals_schema = SignalSchema({"sys": Sys}) | chain.signals_schema
-         elif include_sys is False and "sys" in chain.signals_schema:
-             chain.signals_schema.remove("sys")
+         if sys is not None:
+             chain._sys = sys
          chain._settings.add(Settings(cache, batch, parallel, workers, min_task_size))
          return chain

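The `include_sys` setting is replaced by a simpler `sys` flag that is recorded on the chain and consulted later (see `_udf_to_obj` and `_effective_signals_schema` further down). A small sketch of how it might be used, based only on the API shown in this diff:

```py
from datachain import DataChain

# Sketch: settings(sys=True) asks the chain to expose the internal `sys`
# signals (sys.id, sys.rand) to UDFs and user-facing reads such as collect().
dc = DataChain.from_values(fib=[1, 2, 3, 5, 8]).settings(sys=True)
for row in dc.collect():
    print(row)
```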
@@ -208,17 +276,14 @@ class DataChain(DatasetQuery):
          self._settings = settings if settings else Settings()
          return self

-     def reset_schema(self, signals_schema: SignalSchema) -> "Self":
+     def reset_schema(self, signals_schema: SignalSchema) -> "Self":  # noqa: D102
          self.signals_schema = signals_schema
          return self

-     def add_schema(self, signals_schema: SignalSchema) -> "Self":
+     def add_schema(self, signals_schema: SignalSchema) -> "Self":  # noqa: D102
          self.signals_schema |= signals_schema
          return self

-     def get_file_signals(self) -> list[str]:
-         return list(self.signals_schema.get_file_signals())
-
      @classmethod
      def from_storage(
          cls,
@@ -228,10 +293,11 @@ class DataChain(DatasetQuery):
          session: Optional[Session] = None,
          recursive: Optional[bool] = True,
          object_name: str = "file",
+         update: bool = False,
          **kwargs,
      ) -> "Self":
-         """Get data from a storage as a list of file with all file attributes. It
-         returns the chain itself as usual.
+         """Get data from a storage as a list of file with all file attributes.
+         It returns the chain itself as usual.

          Parameters:
              path : storage URI with directory. URI must start with storage prefix such
@@ -239,6 +305,7 @@ class DataChain(DatasetQuery):
              type : read file as "binary", "text", or "image" data. Default is "binary".
              recursive : search recursively for the given path.
              object_name : Created object column name.
+             update : force storage reindexing. Default is False.

          Example:
              ```py
@@ -246,20 +313,24 @@ class DataChain(DatasetQuery):
              ```
          """
          func = get_file(type)
-         return cls(path, session=session, recursive=recursive, **kwargs).map(
-             **{object_name: func}
+         return (
+             cls(path, session=session, recursive=recursive, update=update, **kwargs)
+             .map(**{object_name: func})
+             .select(object_name)
          )

      @classmethod
      def from_dataset(cls, name: str, version: Optional[int] = None) -> "DataChain":
-         """Get data from dataset. It returns the chain itself.
+         """Get data from a saved Dataset. It returns the chain itself.

          Parameters:
              name : dataset name
              version : dataset version

-         Examples:
-             >>> chain = DataChain.from_dataset("my_cats")
+         Example:
+             ```py
+             chain = DataChain.from_dataset("my_cats")
+             ```
          """
          return DataChain(name=name, version=version)

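`from_storage` now accepts `update` to force re-indexing of the storage and narrows the resulting chain to the generated file object. A hedged usage sketch (the URI is the public demo bucket used elsewhere in this file; network access is assumed):

```py
from datachain import DataChain

# Sketch: update=True re-indexes the storage instead of reusing a cached listing.
files = DataChain.from_storage("gs://datachain-demo/chatbot-KiT/", update=True)
files.show(limit=5)
```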
@@ -275,6 +346,7 @@ class DataChain(DatasetQuery):
          model_name: Optional[str] = None,
          show_schema: Optional[bool] = False,
          meta_type: Optional[str] = "json",
+         nrows=None,
          **kwargs,
      ) -> "DataChain":
          """Get data from JSON. It returns the chain itself.
@@ -284,18 +356,23 @@ class DataChain(DatasetQuery):
                  as `s3://`, `gs://`, `az://` or "file:///"
              type : read file as "binary", "text", or "image" data. Default is "binary".
              spec : optional Data Model
-             schema_from : path to sample to infer spec from
+             schema_from : path to sample to infer spec (if schema not provided)
              object_name : generated object column name
-             model_name : generated model name
+             model_name : optional generated model name
              show_schema : print auto-generated schema
-             jmespath : JMESPATH expression to reduce JSON
+             jmespath : optional JMESPATH expression to reduce JSON
+             nrows : optional row limit for jsonl and JSON arrays

-         Examples:
+         Example:
              infer JSON schema from data, reduce using JMESPATH, print schema
-             >>> chain = DataChain.from_json("gs://json", jmespath="key1.key2")
+             ```py
+             chain = DataChain.from_json("gs://json", jmespath="key1.key2")
+             ```

              infer JSON schema from a particular path, print data model
-             >>> chain = DataChain.from_json("gs://json_ds", schema_from="gs://json/my.json")
+             ```py
+             chain = DataChain.from_json("gs://json_ds", schema_from="gs://json/my.json")
+             ```
          """
          if schema_from == "auto":
              schema_from = path
@@ -317,10 +394,40 @@ class DataChain(DatasetQuery):
                  model_name=model_name,
                  show_schema=show_schema,
                  jmespath=jmespath,
+                 nrows=nrows,
              )
          }
          return chain.gen(**signal_dict)  # type: ignore[arg-type]

+     @classmethod
+     def datasets(
+         cls, session: Optional[Session] = None, object_name: str = "dataset"
+     ) -> "DataChain":
+         """Generate chain with list of registered datasets.
+
+         Example:
+             ```py
+             from datachain import DataChain
+
+             chain = DataChain.datasets()
+             for ds in chain.collect("dataset"):
+                 print(f"{ds.name}@v{ds.version}")
+             ```
+         """
+         session = Session.get(session)
+         catalog = session.catalog
+
+         datasets = [
+             DatasetInfo.from_models(d, v, j)
+             for d, v, j in catalog.list_datasets_versions()
+         ]
+
+         return cls.from_values(
+             session=session,
+             output={object_name: DatasetInfo},
+             **{object_name: datasets},  # type: ignore[arg-type]
+         )
+
      def show_json_schema(  # type: ignore[override]
          self, jmespath: Optional[str] = None, model_name: Optional[str] = None
      ) -> "DataChain":
@@ -330,12 +437,14 @@ class DataChain(DatasetQuery):
              jmespath : JMESPATH expression to reduce JSON
              model_name : generated model name

-         Examples:
+         Example:
              print JSON schema and save to column "meta_from":
-             >>> uri = "gs://datachain-demo/coco2017/annotations_captions/"
-             >>> chain = DataChain.from_storage(uri)
-             >>> chain = chain.show_json_schema()
-             >>> chain.save()
+             ```py
+             uri = "gs://datachain-demo/coco2017/annotations_captions/"
+             chain = DataChain.from_storage(uri)
+             chain = chain.show_json_schema()
+             chain.save()
+             ```
          """
          return self.map(
              meta_schema=lambda file: read_schema(
@@ -370,11 +479,29 @@ class DataChain(DatasetQuery):
                  removed after process ends. Temp dataset are useful for optimization.
              version : version of a dataset. Default - the last version that exist.
          """
-         schema = self.signals_schema.serialize()
-         schema.pop("sys", None)
+         schema = self.signals_schema.clone_without_sys_signals().serialize()
          return super().save(name=name, version=version, feature_schema=schema)

      def apply(self, func, *args, **kwargs):
+         """Apply any function to the chain.
+
+         Useful for reusing in a chain of operations.
+
+         Example:
+             ```py
+             def parse_stem(chain):
+                 return chain.map(
+                     lambda file: file.get_file_stem()
+                     output={"stem": str}
+                 )
+
+             chain = (
+                 DataChain.from_storage("s3://my-bucket")
+                 .apply(parse_stem)
+                 .filter(C("stem").glob("*cat*"))
+             )
+             ```
+         """
          return func(self, *args, **kwargs)

      def map(
@@ -402,16 +529,19 @@ class DataChain(DatasetQuery):
              signal name in format of `map(my_sign=my_func)`. This helps define
              signal names and function in a nicer way.

-         Examples:
+         Example:
              Using signal_map and single type in output:
-             >>> chain = chain.map(value=lambda name: name[:-4] + ".json", output=str)
-             >>> chain.save("new_dataset")
+             ```py
+             chain = chain.map(value=lambda name: name[:-4] + ".json", output=str)
+             chain.save("new_dataset")
+             ```

              Using func and output as a map:
-             >>> chain = chain.map(lambda name: name[:-4] + ".json", output={"res": str})
-             >>> chain.save("new_dataset")
+             ```py
+             chain = chain.map(lambda name: name[:-4] + ".json", output={"res": str})
+             chain.save("new_dataset")
+             ```
          """
-
          udf_obj = self._udf_to_obj(Mapper, func, params, output, signal_map)

          chain = self.add_signals(
@@ -439,7 +569,6 @@ class DataChain(DatasetQuery):
          extracting multiple file records from a single tar file or bounding boxes from a
          single image file).
          """
-
          udf_obj = self._udf_to_obj(Generator, func, params, output, signal_map)
          chain = DatasetQuery.generate(
              self,
@@ -480,27 +609,6 @@ class DataChain(DatasetQuery):

          return chain.reset_schema(udf_obj.output).reset_settings(self._settings)

-     def batch_map(
-         self,
-         func: Optional[Callable] = None,
-         params: Union[None, str, Sequence[str]] = None,
-         output: OutputType = None,
-         **signal_map,
-     ) -> "Self":
-         """This is a batch version of map().
-
-         It accepts the same parameters plus an
-         additional parameter:
-         """
-         udf_obj = self._udf_to_obj(BatchMapper, func, params, output, signal_map)
-         chain = DatasetQuery.generate(
-             self,
-             udf_obj.to_udf_wrapper(self._settings.batch),
-             **self._settings.to_dict(),
-         )
-
-         return chain.add_schema(udf_obj.output).reset_settings(self._settings)
-
      def _udf_to_obj(
          self,
          target_class: type[UDFBase],
@@ -515,7 +623,11 @@ class DataChain(DatasetQuery):
          sign = UdfSignature.parse(name, signal_map, func, params, output, is_generator)
          DataModel.register(list(sign.output_schema.values.values()))

-         params_schema = self.signals_schema.slice(sign.params, self._setup)
+         signals_schema = self.signals_schema
+         if self._sys:
+             signals_schema = SignalSchema({"sys": Sys}) | signals_schema
+
+         params_schema = signals_schema.slice(sign.params, self._setup)

          return target_class._create(sign, params_schema)

@@ -531,9 +643,38 @@ class DataChain(DatasetQuery):
          return res

      @detach
-     def select(self, *args: str) -> "Self":
+     @resolve_columns
+     def order_by(self, *args, descending: bool = False) -> "Self":
+         """Orders by specified set of signals.
+
+         Parameters:
+             descending (bool): Whether to sort in descending order or not.
+         """
+         if descending:
+             args = tuple([sqlalchemy.desc(a) for a in args])
+
+         return super().order_by(*args)
+
+     @detach
+     def distinct(self, arg: str, *args: str) -> "Self":  # type: ignore[override]
+         """Removes duplicate rows based on uniqueness of some input column(s)
+         i.e if rows are found with the same value of input column(s), only one
+         row is left in the result set.
+
+         Example:
+             ```py
+             dc.distinct("file.parent", "file.name")
+             )
+             ```
+         """
+         return super().distinct(*self.signals_schema.resolve(arg, *args).db_signals())
+
+     @detach
+     def select(self, *args: str, _sys: bool = True) -> "Self":
          """Select only a specified set of signals."""
          new_schema = self.signals_schema.resolve(*args)
+         if _sys:
+             new_schema = SignalSchema({"sys": Sys}) | new_schema
          columns = new_schema.db_signals()
          chain = super().select(*columns)
          chain.signals_schema = new_schema
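The new `order_by` and `distinct` accept user-facing signal names and resolve them to DB columns (via `resolve_columns` in the case of `order_by`). A small sketch on an in-memory chain, with illustrative values only:

```py
from datachain import DataChain

# Sketch: dedupe on (name, size), then sort by name descending.
dc = DataChain.from_values(name=["b.txt", "a.txt", "a.txt"], size=[2, 1, 1])
dc = dc.distinct("name", "size").order_by("name", descending=True)
print(dc.results())
```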
@@ -548,45 +689,156 @@ class DataChain(DatasetQuery):
          chain.signals_schema = new_schema
          return chain

-     def iterate_flatten(self) -> Iterator[tuple[Any]]:
-         db_signals = self.signals_schema.db_signals()
+     @detach
+     def mutate(self, **kwargs) -> "Self":
+         """Create new signals based on existing signals.
+
+         This method is vectorized and more efficient compared to map(), and it does not
+         extract or download any data from the internal database. However, it can only
+         utilize predefined built-in functions and their combinations.
+
+         The supported functions:
+             Numerical: +, -, *, /, rand(), avg(), count(), func(),
+                 greatest(), least(), max(), min(), sum()
+             String: length(), split()
+             Filename: name(), parent(), file_stem(), file_ext()
+             Array: length(), sip_hash_64(), euclidean_distance(),
+                 cosine_distance()
+
+         Example:
+             ```py
+             dc.mutate(
+                 area=Column("image.height") * Column("image.width"),
+                 extension=file_ext(Column("file.name")),
+                 dist=cosine_distance(embedding_text, embedding_image)
+             )
+             ```
+         """
+         chain = super().mutate(**kwargs)
+         chain.signals_schema = self.signals_schema.mutate(kwargs)
+         return chain
+
+     @property
+     def _effective_signals_schema(self) -> "SignalSchema":
+         """Effective schema used for user-facing API like collect, to_pandas, etc."""
+         signals_schema = self.signals_schema
+         if not self._sys:
+             return signals_schema.clone_without_sys_signals()
+         return signals_schema
+
+     @overload
+     def collect_flatten(self) -> Iterator[tuple[Any, ...]]: ...
+
+     @overload
+     def collect_flatten(
+         self, *, row_factory: Callable[[list[str], tuple[Any, ...]], _T]
+     ) -> Iterator[_T]: ...
+
+     def collect_flatten(self, *, row_factory=None):
+         """Yields flattened rows of values as a tuple.
+
+         Args:
+             row_factory : A callable to convert row to a custom format.
+                 It should accept two arguments: a list of column names and
+                 a tuple of row values.
+         """
+         db_signals = self._effective_signals_schema.db_signals()
          with super().select(*db_signals).as_iterable() as rows:
+             if row_factory:
+                 rows = (row_factory(db_signals, r) for r in rows)
              yield from rows

+     @overload
+     def results(self) -> list[tuple[Any, ...]]: ...
+
+     @overload
      def results(
-         self, row_factory: Optional[Callable] = None, **kwargs
-     ) -> list[tuple[Any, ...]]:
-         rows = self.iterate_flatten()
-         if row_factory:
-             db_signals = self.signals_schema.db_signals()
-             rows = (row_factory(db_signals, r) for r in rows)
-         return list(rows)
+         self, *, row_factory: Callable[[list[str], tuple[Any, ...]], _T]
+     ) -> list[_T]: ...

-     def iterate(self, *cols: str) -> Iterator[list[DataType]]:
-         """Iterate over rows.
+     def results(self, *, row_factory=None):  # noqa: D102
+         if row_factory is None:
+             return list(self.collect_flatten())
+         return list(self.collect_flatten(row_factory=row_factory))

-         If columns are specified - limit them to specified
-         columns.
-         """
-         chain = self.select(*cols) if cols else self
-         for row in chain.iterate_flatten():
-             yield chain.signals_schema.row_to_features(
-                 row, catalog=chain.session.catalog, cache=chain._settings.cache
-             )
+     def to_records(self) -> list[dict[str, Any]]:
+         """Convert every row to a dictionary."""
+
+         def to_dict(cols: list[str], row: tuple[Any, ...]) -> dict[str, Any]:
+             return dict(zip(cols, row))
+
+         return self.results(row_factory=to_dict)
+
+     @overload
+     def collect(self) -> Iterator[tuple[DataType, ...]]: ...
+
+     @overload
+     def collect(self, col: str) -> Iterator[DataType]: ...  # type: ignore[overload-overlap]

-     def iterate_one(self, col: str) -> Iterator[DataType]:
-         for item in self.iterate(col):
-             yield item[0]
+     @overload
+     def collect(self, *cols: str) -> Iterator[tuple[DataType, ...]]: ...

-     def collect(self, *cols: str) -> list[list[DataType]]:
-         return list(self.iterate(*cols))
+     def collect(self, *cols: str) -> Iterator[Union[DataType, tuple[DataType, ...]]]:  # type: ignore[overload-overlap,misc]
+         """Yields rows of values, optionally limited to the specified columns.

-     def collect_one(self, col: str) -> list[DataType]:
-         return list(self.iterate_one(col))
+         Args:
+             *cols: Limit to the specified columns. By default, all columns are selected.

-     def to_pytorch(self, **kwargs):
-         """Convert to pytorch dataset format."""
+         Yields:
+             (DataType): Yields a single item if a column is selected.
+             (tuple[DataType, ...]): Yields a tuple of items if multiple columns are
+                 selected.

+         Example:
+             Iterating over all rows:
+             ```py
+             for row in dc.collect():
+                 print(row)
+             ```
+
+             Iterating over all rows with selected columns:
+             ```py
+             for name, size in dc.collect("file.name", "file.size"):
+                 print(name, size)
+             ```
+
+             Iterating over a single column:
+             ```py
+             for file in dc.collect("file.name"):
+                 print(file)
+             ```
+         """
+         chain = self.select(*cols) if cols else self
+         signals_schema = chain._effective_signals_schema
+         db_signals = signals_schema.db_signals()
+         with super().select(*db_signals).as_iterable() as rows:
+             for row in rows:
+                 ret = signals_schema.row_to_features(
+                     row, catalog=chain.session.catalog, cache=chain._settings.cache
+                 )
+                 yield ret[0] if len(cols) == 1 else tuple(ret)
+
+     def to_pytorch(
+         self, transform=None, tokenizer=None, tokenizer_kwargs=None, num_samples=0
+     ):
+         """Convert to pytorch dataset format.
+
+         Args:
+             transform (Transform): Torchvision transforms to apply to the dataset.
+             tokenizer (Callable): Tokenizer to use to tokenize text values.
+             tokenizer_kwargs (dict): Additional kwargs to pass when calling tokenizer.
+             num_samples (int): Number of random samples to draw for each epoch.
+                 This argument is ignored if `num_samples=0` (the default).
+
+         Example:
+             ```py
+             from torch.utils.data import DataLoader
+             loader = DataLoader(
+                 chain.select("file", "label").to_pytorch(),
+                 batch_size=16
+             )
+             ```
+         """
          from datachain.torch import PytorchDataset

          if self.attached:
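`iterate`, `iterate_one`, `collect_one`, and the old list-returning `collect` are replaced by a lazy, overloaded `collect` plus `collect_flatten`/`to_records`. A short sketch of the new read API on an in-memory chain:

```py
from datachain import DataChain

dc = DataChain.from_values(fib=[1, 2, 3, 5, 8])

# collect() now yields rows lazily; a single column yields bare values.
for value in dc.collect("fib"):
    print(value)

# to_records() returns plain dictionaries keyed by DB column names.
print(dc.to_records())
```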
@@ -594,9 +846,17 @@ class DataChain(DatasetQuery):
          else:
              chain = self.save()
          assert chain.name is not None  # for mypy
-         return PytorchDataset(chain.name, chain.version, catalog=self.catalog, **kwargs)
+         return PytorchDataset(
+             chain.name,
+             chain.version,
+             catalog=self.catalog,
+             transform=transform,
+             tokenizer=tokenizer,
+             tokenizer_kwargs=tokenizer_kwargs,
+             num_samples=num_samples,
+         )

-     def remove_file_signals(self) -> "Self":
+     def remove_file_signals(self) -> "Self":  # noqa: D102
          schema = self.signals_schema.clone_without_file_signals()
          return self.select(*schema.values.keys())

@@ -621,9 +881,11 @@ class DataChain(DatasetQuery):
              inner (bool): Whether to run inner join or outer join.
              rname (str): name prefix for conflicting signal names.

-         Examples:
-             >>> meta = meta_emd.merge(meta_pq, on=(C.name, C.emd__index),
-                                       right_on=(C.name, C.pq__index))
+         Example:
+             ```py
+             meta = meta_emd.merge(meta_pq, on=(C.name, C.emd__index),
+                                   right_on=(C.name, C.pq__index))
+             ```
          """
          if on is None:
              raise DatasetMergeError(["None"], None, "'on' must be specified")
@@ -637,8 +899,10 @@ class DataChain(DatasetQuery):
                  f"'on' must be 'str' or 'Sequence' object but got type '{type(on)}'",
              )

-         on_columns = self.signals_schema.resolve(*on).db_signals()
+         signals_schema = self.signals_schema.clone_without_sys_signals()
+         on_columns = signals_schema.resolve(*on).db_signals()

+         right_signals_schema = right_ds.signals_schema.clone_without_sys_signals()
          if right_on is not None:
              if isinstance(right_on, str):
                  right_on = [right_on]
@@ -655,7 +919,7 @@ class DataChain(DatasetQuery):
                      on, right_on, "'on' and 'right_on' must have the same length'"
                  )

-             right_on_columns = right_ds.signals_schema.resolve(*right_on).db_signals()
+             right_on_columns = right_signals_schema.resolve(*right_on).db_signals()

              if len(right_on_columns) != len(on_columns):
                  on_str = ", ".join(right_on_columns)
@@ -681,7 +945,9 @@ class DataChain(DatasetQuery):
          ds = self.join(right_ds, sqlalchemy.and_(*ops), inner, rname + "{name}")

          ds.feature_schema = None
-         ds.signals_schema = self.signals_schema.merge(right_ds.signals_schema, rname)
+         ds.signals_schema = SignalSchema({"sys": Sys}) | signals_schema.merge(
+             right_signals_schema, rname
+         )

          return ds

@@ -694,7 +960,13 @@ class DataChain(DatasetQuery):
          object_name: str = "",
          **fr_map,
      ) -> "DataChain":
-         """Generate chain from list of values."""
+         """Generate chain from list of values.
+
+         Example:
+             ```py
+             DataChain.from_values(fib=[1, 2, 3, 5, 8])
+             ```
+         """
          tuple_type, output, tuples = values_to_tuples(ds_name, output, **fr_map)

          def _func_fr() -> Iterator[tuple_type]:  # type: ignore[valid-type]
@@ -713,7 +985,16 @@ class DataChain(DatasetQuery):
          session: Optional[Session] = None,
          object_name: str = "",
      ) -> "DataChain":
-         """Generate chain from pandas data-frame."""
+         """Generate chain from pandas data-frame.
+
+         Example:
+             ```py
+             import pandas as pd
+
+             df = pd.DataFrame({"fib": [1, 2, 3, 5, 8]})
+             DataChain.from_pandas(df)
+             ```
+         """
          fr_map = {col.lower(): df[col].tolist() for col in df.columns}

          for column in fr_map:
@@ -731,11 +1012,76 @@ class DataChain(DatasetQuery):

          return cls.from_values(name, session, object_name=object_name, **fr_map)

+     def to_pandas(self, flatten=False) -> "pd.DataFrame":
+         """Return a pandas DataFrame from the chain.
+
+         Parameters:
+             flatten : Whether to use a multiindex or flatten column names.
+         """
+         headers, max_length = self._effective_signals_schema.get_headers_with_length()
+         if flatten or max_length < 2:
+             df = pd.DataFrame.from_records(self.to_records())
+             if headers:
+                 df.columns = [".".join(filter(None, header)) for header in headers]
+             return df
+
+         transposed_result = list(map(list, zip(*self.results())))
+         data = {tuple(n): val for n, val in zip(headers, transposed_result)}
+         return pd.DataFrame(data)
+
+     def show(
+         self,
+         limit: int = 20,
+         flatten=False,
+         transpose=False,
+         truncate=True,
+     ) -> None:
+         """Show a preview of the chain results.
+
+         Parameters:
+             limit : How many rows to show.
+             flatten : Whether to use a multiindex or flatten column names.
+             transpose : Whether to transpose rows and columns.
+             truncate : Whether or not to truncate the contents of columns.
+         """
+         dc = self.limit(limit) if limit > 0 else self
+         df = dc.to_pandas(flatten)
+         if transpose:
+             df = df.T
+
+         options: list = [
+             "display.max_columns",
+             None,
+             "display.multi_sparse",
+             False,
+         ]
+
+         try:
+             if columns := os.get_terminal_size().columns:
+                 options.extend(["display.width", columns])
+         except OSError:
+             pass
+
+         if not truncate:
+             options.extend(["display.max_colwidth", None])
+
+         with pd.option_context(*options):
+             if inside_notebook():
+                 from IPython.display import display
+
+                 display(df)
+             else:
+                 print(df)
+
+         if len(df) == limit:
+             print(f"\n[Limited by {len(df)} rows]")
+
      def parse_tabular(
          self,
          output: OutputType = None,
          object_name: str = "",
          model_name: str = "",
+         nrows: Optional[int] = None,
          **kwargs,
      ) -> "DataChain":
          """Generate chain from list of tabular files.
@@ -747,18 +1093,22 @@ class DataChain(DatasetQuery):
              object_name : Generated object column name.
              model_name : Generated model name.
              kwargs : Parameters to pass to pyarrow.dataset.dataset.
+             nrows : Optional row limit.

-         Examples:
+         Example:
              Reading a json lines file:
-             >>> dc = DataChain.from_storage("s3://mybucket/file.jsonl")
-             >>> dc = dc.parse_tabular(format="json")
+             ```py
+             dc = DataChain.from_storage("s3://mybucket/file.jsonl")
+             dc = dc.parse_tabular(format="json")
+             ```

              Reading a filtered list of files as a dataset:
-             >>> dc = DataChain.from_storage("s3://mybucket")
-             >>> dc = dc.filter(C("file.name").glob("*.jsonl"))
-             >>> dc = dc.parse_tabular(format="json")
+             ```py
+             dc = DataChain.from_storage("s3://mybucket")
+             dc = dc.filter(C("file.name").glob("*.jsonl"))
+             dc = dc.parse_tabular(format="json")
+             ```
          """
-
          from datachain.lib.arrow import ArrowGenerator, infer_schema, schema_to_output

          schema = None
@@ -781,7 +1131,7 @@ class DataChain(DatasetQuery):
                  for name, info in output.model_fields.items()
              }
              output = {"source": IndexedFile} | output  # type: ignore[assignment,operator]
-         return self.gen(ArrowGenerator(schema, **kwargs), output=output)
+         return self.gen(ArrowGenerator(schema, nrows, **kwargs), output=output)

      @staticmethod
      def _dict_to_data_model(
@@ -804,6 +1154,7 @@ class DataChain(DatasetQuery):
          output: OutputType = None,
          object_name: str = "",
          model_name: str = "",
+         nrows=None,
          **kwargs,
      ) -> "DataChain":
          """Generate chain from csv files.
@@ -818,13 +1169,18 @@ class DataChain(DatasetQuery):
                  case types will be inferred.
              object_name : Created object column name.
              model_name : Generated model name.
+             nrows : Optional row limit.

-         Examples:
+         Example:
              Reading a csv file:
-             >>> dc = DataChain.from_csv("s3://mybucket/file.csv")
+             ```py
+             dc = DataChain.from_csv("s3://mybucket/file.csv")
+             ```

              Reading csv files from a directory as a combined dataset:
-             >>> dc = DataChain.from_csv("s3://mybucket/dir")
+             ```py
+             dc = DataChain.from_csv("s3://mybucket/dir")
+             ```
          """
          from pyarrow.csv import ParseOptions, ReadOptions
          from pyarrow.dataset import CsvFileFormat
@@ -849,7 +1205,11 @@ class DataChain(DatasetQuery):
          read_options = ReadOptions(column_names=column_names)
          format = CsvFileFormat(parse_options=parse_options, read_options=read_options)
          return chain.parse_tabular(
-             output=output, object_name=object_name, model_name=model_name, format=format
+             output=output,
+             object_name=object_name,
+             model_name=model_name,
+             nrows=nrows,
+             format=format,
          )

      @classmethod
@@ -860,6 +1220,7 @@ class DataChain(DatasetQuery):
          output: Optional[dict[str, DataType]] = None,
          object_name: str = "",
          model_name: str = "",
+         nrows=None,
          **kwargs,
      ) -> "DataChain":
          """Generate chain from parquet files.
@@ -871,23 +1232,48 @@ class DataChain(DatasetQuery):
              output : Dictionary defining column names and their corresponding types.
              object_name : Created object column name.
              model_name : Generated model name.
+             nrows : Optional row limit.

-         Examples:
+         Example:
              Reading a single file:
-             >>> dc = DataChain.from_parquet("s3://mybucket/file.parquet")
+             ```py
+             dc = DataChain.from_parquet("s3://mybucket/file.parquet")
+             ```

              Reading a partitioned dataset from a directory:
-             >>> dc = DataChain.from_parquet("s3://mybucket/dir")
+             ```py
+             dc = DataChain.from_parquet("s3://mybucket/dir")
+             ```
          """
          chain = DataChain.from_storage(path, **kwargs)
          return chain.parse_tabular(
              output=output,
              object_name=object_name,
              model_name=model_name,
+             nrows=None,
              format="parquet",
              partitioning=partitioning,
          )

+     def to_parquet(
+         self,
+         path: Union[str, os.PathLike[str], BinaryIO],
+         partition_cols: Optional[Sequence[str]] = None,
+         **kwargs,
+     ) -> None:
+         """Save chain to parquet file.
+
+         Parameters:
+             path : Path or a file-like binary object to save the file.
+             partition_cols : Column names by which to partition the dataset.
+         """
+         _partition_cols = list(partition_cols) if partition_cols else None
+         return self.to_pandas().to_parquet(
+             path,
+             partition_cols=_partition_cols,
+             **kwargs,
+         )
+
      @classmethod
      def create_empty(
          cls,
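The new `to_parquet` writes the chain through `to_pandas().to_parquet(...)`. A sketch, assuming a pandas parquet engine such as pyarrow is installed and that a plain local path is acceptable to pandas here:

```py
from datachain import DataChain

# Sketch: materialize a small chain and write it to a local parquet file.
dc = DataChain.from_values(fib=[1, 2, 3, 5, 8])
dc.to_parquet("fib.parquet", partition_cols=None)
```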
@@ -901,9 +1287,11 @@ class DataChain(DatasetQuery):
              to_insert : records (or a single record) to insert. Each record is
                  a dictionary of signals and theirs values.

-         Examples:
-             >>> empty = DataChain.create_empty()
-             >>> single_record = DataChain.create_empty(DataChain.DEFAULT_FILE_RECORD)
+         Example:
+             ```py
+             empty = DataChain.create_empty()
+             single_record = DataChain.create_empty(DataChain.DEFAULT_FILE_RECORD)
+             ```
          """
          session = Session.get(session)
          catalog = session.catalog
@@ -929,18 +1317,47 @@ class DataChain(DatasetQuery):
          return DataChain(name=dsr.name)

      def sum(self, fr: DataType):  # type: ignore[override]
+         """Compute the sum of a column."""
          return self._extend_to_data_model("sum", fr)

      def avg(self, fr: DataType):  # type: ignore[override]
+         """Compute the average of a column."""
          return self._extend_to_data_model("avg", fr)

      def min(self, fr: DataType):  # type: ignore[override]
+         """Compute the minimum of a column."""
          return self._extend_to_data_model("min", fr)

      def max(self, fr: DataType):  # type: ignore[override]
+         """Compute the maximum of a column."""
          return self._extend_to_data_model("max", fr)

      def setup(self, **kwargs) -> "Self":
+         """Setup variables to pass to UDF functions.
+
+         Use before running map/gen/agg/batch_map to save an object and pass it as an
+         argument to the UDF.
+
+         Example:
+             ```py
+             import anthropic
+             from anthropic.types import Message
+
+             (
+                 DataChain.from_storage(DATA, type="text")
+                 .settings(parallel=4, cache=True)
+                 .setup(client=lambda: anthropic.Anthropic(api_key=API_KEY))
+                 .map(
+                     claude=lambda client, file: client.messages.create(
+                         model=MODEL,
+                         system=PROMPT,
+                         messages=[{"role": "user", "content": file.get_value()}],
+                     ),
+                     output=Message,
+                 )
+             )
+             ```
+         """
          intersection = set(self._setup.keys()) & set(kwargs.keys())
          if intersection:
              keys = ", ".join(intersection)
@@ -948,3 +1365,80 @@ class DataChain(DatasetQuery):

          self._setup = self._setup | kwargs
          return self
+
+     def export_files(
+         self,
+         output: str,
+         signal="file",
+         placement: FileExportPlacement = "fullpath",
+         use_cache: bool = True,
+     ) -> None:
+         """Method that exports all files from chain to some folder."""
+         if placement == "filename":
+             print("Checking if file names are unique")
+             if self.distinct(f"{signal}.name").count() != self.count():
+                 raise ValueError("Files with the same name found")
+
+         for file in self.collect(signal):
+             file.export(output, placement, use_cache)  # type: ignore[union-attr]
+
+     def shuffle(self) -> "Self":
+         """Shuffle the rows of the chain deterministically."""
+         return self.order_by("sys.rand")
+
+     def sample(self, n) -> "Self":
+         """Return a random sample from the chain.
+
+         Parameters:
+             n (int): Number of samples to draw.
+
+         NOTE: Samples are not deterministic, and streamed/paginated queries or
+         multiple workers will draw samples with replacement.
+         """
+         return super().sample(n)
+
+     @detach
+     def filter(self, *args) -> "Self":
+         """Filter the chain according to conditions.
+
+         Example:
+             Basic usage with built-in operators
+             ```py
+             dc.filter(C("width") < 200)
+             ```
+
+             Using glob to match patterns
+             ```py
+             dc.filter(C("file.name").glob("*.jpg))
+             ```
+
+             Using `datachain.sql.functions`
+             ```py
+             from datachain.sql.functions import string
+             dc.filter(string.length(C("file.name")) > 5)
+             ```
+
+             Combining filters with "or"
+             ```py
+             dc.filter(C("file.name").glob("cat*") | C("file.name").glob("dog*))
+             ```
+
+             Combining filters with "and"
+             ```py
+             dc.filter(
+                 C("file.name").glob("*.jpg) &
+                 (string.length(C("file.name")) > 5)
+             )
+             ```
+         """
+         return super().filter(*args)
+
+     @detach
+     def limit(self, n: int) -> "Self":
+         """Return the first n rows of the chain."""
+         return super().limit(n)
+
+     @detach
+     def offset(self, offset: int) -> "Self":
+         """Return the results starting with the offset row."""
+         return super().offset(offset)
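`export_files`, `shuffle`, `sample`, and the wrapped `filter`/`limit`/`offset` close out the new API. A hedged sketch that exports a filtered subset of files, using only calls shown in this diff (illustrative URI; network access assumed):

```py
from datachain.lib.dc import C, DataChain

# Sketch: export up to 10 *.txt files from the public demo bucket into
# ./output, keeping the full storage path layout ("fullpath" placement).
(
    DataChain.from_storage("gs://datachain-demo/chatbot-KiT/")
    .filter(C("file.name").glob("*.txt"))
    .limit(10)
    .export_files("output", placement="fullpath")
)
```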