datachain 0.2.11__py3-none-any.whl → 0.2.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of datachain has been flagged as possibly problematic.
- datachain/__init__.py +3 -4
- datachain/cache.py +10 -4
- datachain/catalog/catalog.py +42 -16
- datachain/cli.py +48 -32
- datachain/data_storage/metastore.py +24 -0
- datachain/data_storage/warehouse.py +3 -1
- datachain/job.py +56 -0
- datachain/lib/arrow.py +19 -7
- datachain/lib/clip.py +89 -66
- datachain/lib/convert/{type_converter.py → python_to_sql.py} +6 -6
- datachain/lib/convert/sql_to_python.py +23 -0
- datachain/lib/convert/values_to_tuples.py +51 -33
- datachain/lib/data_model.py +6 -27
- datachain/lib/dataset_info.py +70 -0
- datachain/lib/dc.py +618 -156
- datachain/lib/file.py +130 -22
- datachain/lib/image.py +1 -1
- datachain/lib/meta_formats.py +14 -2
- datachain/lib/model_store.py +3 -2
- datachain/lib/pytorch.py +10 -7
- datachain/lib/signal_schema.py +19 -11
- datachain/lib/text.py +2 -1
- datachain/lib/udf.py +56 -5
- datachain/lib/udf_signature.py +1 -1
- datachain/node.py +11 -8
- datachain/query/dataset.py +62 -28
- datachain/query/schema.py +2 -0
- datachain/query/session.py +4 -4
- datachain/sql/functions/array.py +12 -0
- datachain/sql/functions/string.py +8 -0
- datachain/torch/__init__.py +1 -1
- datachain/utils.py +6 -0
- datachain-0.2.13.dist-info/METADATA +411 -0
- {datachain-0.2.11.dist-info → datachain-0.2.13.dist-info}/RECORD +38 -42
- {datachain-0.2.11.dist-info → datachain-0.2.13.dist-info}/WHEEL +1 -1
- datachain/lib/gpt4_vision.py +0 -97
- datachain/lib/hf_image_to_text.py +0 -97
- datachain/lib/hf_pipeline.py +0 -90
- datachain/lib/image_transform.py +0 -103
- datachain/lib/iptc_exif_xmp.py +0 -76
- datachain/lib/unstructured.py +0 -41
- datachain/text/__init__.py +0 -3
- datachain-0.2.11.dist-info/METADATA +0 -431
- {datachain-0.2.11.dist-info → datachain-0.2.13.dist-info}/LICENSE +0 -0
- {datachain-0.2.11.dist-info → datachain-0.2.13.dist-info}/entry_points.txt +0 -0
- {datachain-0.2.11.dist-info → datachain-0.2.13.dist-info}/top_level.txt +0 -0
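The bulk of this release is the expanded `DataChain` API in `datachain/lib/dc.py`: new `datasets()`, `order_by(descending=...)`, `distinct()`, `mutate()`, `collect()`, `collect_flatten()`, `to_records()`, `to_parquet()`, `export_files()`, `shuffle()`, `sample()`, `filter()`, `limit()`, and `offset()` methods, an `update` flag on `from_storage()`, and `nrows` limits on the JSON/CSV/parquet readers, while `batch_map()` and the `gpt4_vision`/`hf_*`/`image_transform`/`unstructured` helper modules are removed. Below is a minimal sketch pieced together from the docstrings in the diff that follows; the bucket URI is a placeholder and the `datachain.lib.dc` import path is an assumption, not confirmed by this diff.

```py
# Sketch only: exercises chain APIs added in 0.2.13, per the diff below.
# "s3://my-bucket/" is a placeholder; adjust the import to your install.
from datachain.lib.dc import C, DataChain

chain = (
    DataChain.from_storage("s3://my-bucket/", update=True)  # update: force reindexing
    .filter(C("file.name").glob("*.csv"))
    .order_by("file.size", descending=True)  # resolves nested names like file.size
    .distinct("file.name")                   # de-duplicate rows by column(s)
)

# collect() now yields single items or tuples depending on selected columns.
for name, size in chain.collect("file.name", "file.size"):
    print(name, size)

chain.to_parquet("files.parquet")  # new exporter, backed by to_pandas()
```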
datachain/lib/dc.py
CHANGED
@@ -1,23 +1,31 @@
 import copy
+import os
 import re
 from collections.abc import Iterator, Sequence
+from functools import wraps
 from typing import (
     TYPE_CHECKING,
     Any,
+    BinaryIO,
     Callable,
     ClassVar,
     Literal,
     Optional,
+    TypeVar,
     Union,
+    overload,
 )
 
 import pandas as pd
 import sqlalchemy
 from pydantic import BaseModel, create_model
+from sqlalchemy.sql.functions import GenericFunction
 
 from datachain import DataModel
 from datachain.lib.convert.values_to_tuples import values_to_tuples
 from datachain.lib.data_model import DataType
+from datachain.lib.dataset_info import DatasetInfo
+from datachain.lib.file import ExportPlacement as FileExportPlacement
 from datachain.lib.file import File, IndexedFile, get_file
 from datachain.lib.meta_formats import read_meta, read_schema
 from datachain.lib.model_store import ModelStore
@@ -25,7 +33,6 @@ from datachain.lib.settings import Settings
 from datachain.lib.signal_schema import SignalSchema
 from datachain.lib.udf import (
     Aggregator,
-    BatchMapper,
     Generator,
     Mapper,
     UDFBase,
@@ -42,26 +49,57 @@ from datachain.query.schema import Column, DatasetRow
 from datachain.utils import inside_notebook
 
 if TYPE_CHECKING:
-    from typing_extensions import Self
+    from typing_extensions import Concatenate, ParamSpec, Self
+
+    P = ParamSpec("P")
 
 C = Column
 
+_T = TypeVar("_T")
+D = TypeVar("D", bound="DataChain")
+
+
+def resolve_columns(
+    method: "Callable[Concatenate[D, P], D]",
+) -> "Callable[Concatenate[D, P], D]":
+    """Decorator that resolvs input column names to their actual DB names. This is
+    specially important for nested columns as user works with them by using dot
+    notation e.g (file.name) but are actually defined with default delimiter
+    in DB, e.g file__name.
+    If there are any sql functions in arguments, they will just be transferred as is
+    to a method.
+    """
+
+    @wraps(method)
+    def _inner(self: D, *args: "P.args", **kwargs: "P.kwargs") -> D:
+        resolved_args = self.signals_schema.resolve(
+            *[arg for arg in args if not isinstance(arg, GenericFunction)]
+        ).db_signals()
+
+        for idx, arg in enumerate(args):
+            if isinstance(arg, GenericFunction):
+                resolved_args.insert(idx, arg)
+
+        return method(self, *resolved_args, **kwargs)
+
+    return _inner
 
-class DatasetPrepareError(DataChainParamsError):
-    def __init__(self, name, msg, output=None):
+
+class DatasetPrepareError(DataChainParamsError):  # noqa: D101
+    def __init__(self, name, msg, output=None):  # noqa: D107
         name = f" '{name}'" if name else ""
         output = f" output '{output}'" if output else ""
         super().__init__(f"Dataset{name}{output} processing prepare error: {msg}")
 
 
-class DatasetFromValuesError(DataChainParamsError):
-    def __init__(self, name, msg):
+class DatasetFromValuesError(DataChainParamsError):  # noqa: D101
+    def __init__(self, name, msg):  # noqa: D107
         name = f" '{name}'" if name else ""
-        super().__init__(f"Dataset
+        super().__init__(f"Dataset{name} from values error: {msg}")
 
 
-class DatasetMergeError(DataChainParamsError):
-    def __init__(self, on: Sequence[str], right_on: Optional[Sequence[str]], msg: str):
+class DatasetMergeError(DataChainParamsError):  # noqa: D101
+    def __init__(self, on: Sequence[str], right_on: Optional[Sequence[str]], msg: str):  # noqa: D107
         on_str = ", ".join(on) if isinstance(on, Sequence) else ""
         right_on_str = (
             ", right_on='" + ", ".join(right_on) + "'"
@@ -75,6 +113,8 @@ OutputType = Union[None, DataType, Sequence[str], dict[str, DataType]]
 
 
 class Sys(DataModel):
+    """Model for internal DataChain signals `id` and `rand`."""
+
     id: int
     rand: int
 
@@ -87,7 +127,7 @@ class DataChain(DatasetQuery):
     enrich data.
 
     Data in DataChain is presented as Python classes with arbitrary set of fields,
-    including nested classes. The data classes have to inherit from `
+    including nested classes. The data classes have to inherit from `DataModel` class.
     The supported set of field types include: majority of the type supported by the
     underlyind library `Pydantic`.
 
@@ -99,34 +139,56 @@ class DataChain(DatasetQuery):
 
     `DataChain.from_dataset("name")` - reading from a dataset.
 
-    `DataChain.
+    `DataChain.from_values(fib=[1, 2, 3, 5, 8])` - generating from values.
+
+    `DataChain.from_pandas(pd.DataFrame(...))` - generating from pandas.
+
+    `DataChain.from_json("file.json")` - generating from json.
 
+    `DataChain.from_csv("file.csv")` - generating from csv.
+
+    `DataChain.from_parquet("file.parquet")` - generating from parquet.
 
     Example:
         ```py
-
-
+        import os
+
+        from mistralai.client import MistralClient
+        from mistralai.models.chat_completion import ChatMessage
+
+        from datachain.dc import DataChain, Column
 
-
-
-
+        PROMPT = (
+            "Was this bot dialog successful? "
+            "Describe the 'result' as 'Yes' or 'No' in a short JSON"
+        )
 
-
-
+        model = "mistral-large-latest"
+        api_key = os.environ["MISTRAL_API_KEY"]
 
         chain = (
-            DataChain.from_storage("
-            .filter(C.name.glob("*.txt"))
+            DataChain.from_storage("gs://datachain-demo/chatbot-KiT/")
             .limit(5)
-            .
+            .settings(cache=True, parallel=5)
             .map(
-
-
-
-
+                mistral_response=lambda file: MistralClient(api_key=api_key)
+                .chat(
+                    model=model,
+                    response_format={"type": "json_object"},
+                    messages=[
+                        ChatMessage(role="user", content=f"{PROMPT}: {file.read()}")
+                    ],
+                )
+                .choices[0]
+                .message.content,
+            )
+            .save()
         )
-
-
+
+        try:
+            print(chain.select("mistral_response").results())
+        except Exception as e:
+            print(f"do you have the right Mistral API key? {e}")
         ```
     """
 
@@ -138,8 +200,9 @@ class DataChain(DatasetQuery):
     }
 
     def __init__(self, *args, **kwargs):
-        """This method needs to be redefined as a part of Dataset and
-        decoupling.
+        """This method needs to be redefined as a part of Dataset and DataChain
+        decoupling.
+        """
         super().__init__(
             *args,
             **kwargs,
@@ -148,19 +211,25 @@ class DataChain(DatasetQuery):
         self._settings = Settings()
         self._setup = {}
 
+        self.signals_schema = SignalSchema({"sys": Sys})
         if self.feature_schema:
-            self.signals_schema
+            self.signals_schema |= SignalSchema.deserialize(self.feature_schema)
         else:
-            self.signals_schema
+            self.signals_schema |= SignalSchema.from_column_types(self.column_types)
+
+        self._sys = False
 
     @property
-    def schema(self):
-
+    def schema(self) -> dict[str, DataType]:
+        """Get schema of the chain."""
+        return self._effective_signals_schema.values
 
-    def print_schema(self):
-
+    def print_schema(self) -> None:
+        """Print schema of the chain."""
+        self._effective_signals_schema.print_tree()
 
     def clone(self, new_table: bool = True) -> "Self":
+        """Make a copy of the chain in a new table."""
        obj = super().clone(new_table=new_table)
         obj.signals_schema = copy.deepcopy(self.signals_schema)
         return obj
@@ -172,7 +241,7 @@ class DataChain(DatasetQuery):
         parallel=None,
         workers=None,
         min_task_size=None,
-
+        sys: Optional[bool] = None,
     ) -> "Self":
         """Change settings for chain.
 
@@ -197,10 +266,8 @@ class DataChain(DatasetQuery):
         ```
         """
         chain = self.clone()
-        if
-            chain.
-        elif include_sys is False and "sys" in chain.signals_schema:
-            chain.signals_schema.remove("sys")
+        if sys is not None:
+            chain._sys = sys
         chain._settings.add(Settings(cache, batch, parallel, workers, min_task_size))
         return chain
 
@@ -209,17 +276,14 @@ class DataChain(DatasetQuery):
         self._settings = settings if settings else Settings()
         return self
 
-    def reset_schema(self, signals_schema: SignalSchema) -> "Self":
+    def reset_schema(self, signals_schema: SignalSchema) -> "Self":  # noqa: D102
         self.signals_schema = signals_schema
         return self
 
-    def add_schema(self, signals_schema: SignalSchema) -> "Self":
+    def add_schema(self, signals_schema: SignalSchema) -> "Self":  # noqa: D102
         self.signals_schema |= signals_schema
         return self
 
-    def get_file_signals(self) -> list[str]:
-        return list(self.signals_schema.get_file_signals())
-
     @classmethod
     def from_storage(
         cls,
@@ -229,10 +293,11 @@ class DataChain(DatasetQuery):
         session: Optional[Session] = None,
         recursive: Optional[bool] = True,
         object_name: str = "file",
+        update: bool = False,
         **kwargs,
     ) -> "Self":
-        """Get data from a storage as a list of file with all file attributes.
-        returns the chain itself as usual.
+        """Get data from a storage as a list of file with all file attributes.
+        It returns the chain itself as usual.
 
         Parameters:
             path : storage URI with directory. URI must start with storage prefix such
@@ -240,6 +305,7 @@ class DataChain(DatasetQuery):
             type : read file as "binary", "text", or "image" data. Default is "binary".
             recursive : search recursively for the given path.
             object_name : Created object column name.
+            update : force storage reindexing. Default is False.
 
         Example:
             ```py
@@ -247,20 +313,24 @@ class DataChain(DatasetQuery):
             ```
         """
         func = get_file(type)
-        return
-            **
+        return (
+            cls(path, session=session, recursive=recursive, update=update, **kwargs)
+            .map(**{object_name: func})
+            .select(object_name)
         )
 
     @classmethod
     def from_dataset(cls, name: str, version: Optional[int] = None) -> "DataChain":
-        """Get data from
+        """Get data from a saved Dataset. It returns the chain itself.
 
         Parameters:
             name : dataset name
             version : dataset version
 
-
-
+        Example:
+            ```py
+            chain = DataChain.from_dataset("my_cats")
+            ```
         """
         return DataChain(name=name, version=version)
 
@@ -276,6 +346,7 @@ class DataChain(DatasetQuery):
         model_name: Optional[str] = None,
         show_schema: Optional[bool] = False,
         meta_type: Optional[str] = "json",
+        nrows=None,
         **kwargs,
     ) -> "DataChain":
         """Get data from JSON. It returns the chain itself.
@@ -285,18 +356,23 @@ class DataChain(DatasetQuery):
             as `s3://`, `gs://`, `az://` or "file:///"
             type : read file as "binary", "text", or "image" data. Default is "binary".
             spec : optional Data Model
-            schema_from : path to sample to infer spec
+            schema_from : path to sample to infer spec (if schema not provided)
             object_name : generated object column name
-            model_name : generated model name
+            model_name : optional generated model name
             show_schema : print auto-generated schema
-            jmespath : JMESPATH expression to reduce JSON
+            jmespath : optional JMESPATH expression to reduce JSON
+            nrows : optional row limit for jsonl and JSON arrays
 
-
+        Example:
             infer JSON schema from data, reduce using JMESPATH, print schema
-
+            ```py
+            chain = DataChain.from_json("gs://json", jmespath="key1.key2")
+            ```
 
             infer JSON schema from a particular path, print data model
-
+            ```py
+            chain = DataChain.from_json("gs://json_ds", schema_from="gs://json/my.json")
+            ```
         """
         if schema_from == "auto":
             schema_from = path
@@ -318,10 +394,40 @@ class DataChain(DatasetQuery):
                 model_name=model_name,
                 show_schema=show_schema,
                 jmespath=jmespath,
+                nrows=nrows,
             )
         }
         return chain.gen(**signal_dict)  # type: ignore[arg-type]
 
+    @classmethod
+    def datasets(
+        cls, session: Optional[Session] = None, object_name: str = "dataset"
+    ) -> "DataChain":
+        """Generate chain with list of registered datasets.
+
+        Example:
+            ```py
+            from datachain import DataChain
+
+            chain = DataChain.datasets()
+            for ds in chain.collect("dataset"):
+                print(f"{ds.name}@v{ds.version}")
+            ```
+        """
+        session = Session.get(session)
+        catalog = session.catalog
+
+        datasets = [
+            DatasetInfo.from_models(d, v, j)
+            for d, v, j in catalog.list_datasets_versions()
+        ]
+
+        return cls.from_values(
+            session=session,
+            output={object_name: DatasetInfo},
+            **{object_name: datasets},  # type: ignore[arg-type]
+        )
+
     def show_json_schema(  # type: ignore[override]
         self, jmespath: Optional[str] = None, model_name: Optional[str] = None
     ) -> "DataChain":
@@ -331,12 +437,14 @@ class DataChain(DatasetQuery):
             jmespath : JMESPATH expression to reduce JSON
             model_name : generated model name
 
-
+        Example:
             print JSON schema and save to column "meta_from":
-
-
-
-
+            ```py
+            uri = "gs://datachain-demo/coco2017/annotations_captions/"
+            chain = DataChain.from_storage(uri)
+            chain = chain.show_json_schema()
+            chain.save()
+            ```
         """
         return self.map(
             meta_schema=lambda file: read_schema(
@@ -371,11 +479,29 @@ class DataChain(DatasetQuery):
             removed after process ends. Temp dataset are useful for optimization.
             version : version of a dataset. Default - the last version that exist.
         """
-        schema = self.signals_schema.serialize()
-        schema.pop("sys", None)
+        schema = self.signals_schema.clone_without_sys_signals().serialize()
         return super().save(name=name, version=version, feature_schema=schema)
 
     def apply(self, func, *args, **kwargs):
+        """Apply any function to the chain.
+
+        Useful for reusing in a chain of operations.
+
+        Example:
+            ```py
+            def parse_stem(chain):
+                return chain.map(
+                    lambda file: file.get_file_stem()
+                    output={"stem": str}
+                )
+
+            chain = (
+                DataChain.from_storage("s3://my-bucket")
+                .apply(parse_stem)
+                .filter(C("stem").glob("*cat*"))
+            )
+            ```
+        """
         return func(self, *args, **kwargs)
 
     def map(
@@ -403,16 +529,19 @@ class DataChain(DatasetQuery):
             signal name in format of `map(my_sign=my_func)`. This helps define
             signal names and function in a nicer way.
 
-
+        Example:
             Using signal_map and single type in output:
-
-
+            ```py
+            chain = chain.map(value=lambda name: name[:-4] + ".json", output=str)
+            chain.save("new_dataset")
+            ```
 
             Using func and output as a map:
-
-
+            ```py
+            chain = chain.map(lambda name: name[:-4] + ".json", output={"res": str})
+            chain.save("new_dataset")
+            ```
         """
-
         udf_obj = self._udf_to_obj(Mapper, func, params, output, signal_map)
 
         chain = self.add_signals(
@@ -440,7 +569,6 @@ class DataChain(DatasetQuery):
         extracting multiple file records from a single tar file or bounding boxes from a
         single image file).
         """
-
         udf_obj = self._udf_to_obj(Generator, func, params, output, signal_map)
         chain = DatasetQuery.generate(
             self,
@@ -481,27 +609,6 @@ class DataChain(DatasetQuery):
 
         return chain.reset_schema(udf_obj.output).reset_settings(self._settings)
 
-    def batch_map(
-        self,
-        func: Optional[Callable] = None,
-        params: Union[None, str, Sequence[str]] = None,
-        output: OutputType = None,
-        **signal_map,
-    ) -> "Self":
-        """This is a batch version of map().
-
-        It accepts the same parameters plus an
-        additional parameter:
-        """
-        udf_obj = self._udf_to_obj(BatchMapper, func, params, output, signal_map)
-        chain = DatasetQuery.generate(
-            self,
-            udf_obj.to_udf_wrapper(self._settings.batch),
-            **self._settings.to_dict(),
-        )
-
-        return chain.add_schema(udf_obj.output).reset_settings(self._settings)
-
     def _udf_to_obj(
         self,
         target_class: type[UDFBase],
@@ -516,7 +623,11 @@ class DataChain(DatasetQuery):
         sign = UdfSignature.parse(name, signal_map, func, params, output, is_generator)
         DataModel.register(list(sign.output_schema.values.values()))
 
-
+        signals_schema = self.signals_schema
+        if self._sys:
+            signals_schema = SignalSchema({"sys": Sys}) | signals_schema
+
+        params_schema = signals_schema.slice(sign.params, self._setup)
 
         return target_class._create(sign, params_schema)
 
@@ -532,9 +643,38 @@ class DataChain(DatasetQuery):
         return res
 
     @detach
-
+    @resolve_columns
+    def order_by(self, *args, descending: bool = False) -> "Self":
+        """Orders by specified set of signals.
+
+        Parameters:
+            descending (bool): Whether to sort in descending order or not.
+        """
+        if descending:
+            args = tuple([sqlalchemy.desc(a) for a in args])
+
+        return super().order_by(*args)
+
+    @detach
+    def distinct(self, arg: str, *args: str) -> "Self":  # type: ignore[override]
+        """Removes duplicate rows based on uniqueness of some input column(s)
+        i.e if rows are found with the same value of input column(s), only one
+        row is left in the result set.
+
+        Example:
+            ```py
+            dc.distinct("file.parent", "file.name")
+            )
+            ```
+        """
+        return super().distinct(*self.signals_schema.resolve(arg, *args).db_signals())
+
+    @detach
+    def select(self, *args: str, _sys: bool = True) -> "Self":
         """Select only a specified set of signals."""
         new_schema = self.signals_schema.resolve(*args)
+        if _sys:
+            new_schema = SignalSchema({"sys": Sys}) | new_schema
         columns = new_schema.db_signals()
         chain = super().select(*columns)
         chain.signals_schema = new_schema
@@ -549,45 +689,156 @@ class DataChain(DatasetQuery):
         chain.signals_schema = new_schema
         return chain
 
-
-
+    @detach
+    def mutate(self, **kwargs) -> "Self":
+        """Create new signals based on existing signals.
+
+        This method is vectorized and more efficient compared to map(), and it does not
+        extract or download any data from the internal database. However, it can only
+        utilize predefined built-in functions and their combinations.
+
+        The supported functions:
+           Numerical:   +, -, *, /, rand(), avg(), count(), func(),
+                        greatest(), least(), max(), min(), sum()
+           String:      length(), split()
+           Filename:    name(), parent(), file_stem(), file_ext()
+           Array:       length(), sip_hash_64(), euclidean_distance(),
+                        cosine_distance()
+
+        Example:
+            ```py
+            dc.mutate(
+                area=Column("image.height") * Column("image.width"),
+                extension=file_ext(Column("file.name")),
+                dist=cosine_distance(embedding_text, embedding_image)
+            )
+            ```
+        """
+        chain = super().mutate(**kwargs)
+        chain.signals_schema = self.signals_schema.mutate(kwargs)
+        return chain
+
+    @property
+    def _effective_signals_schema(self) -> "SignalSchema":
+        """Effective schema used for user-facing API like collect, to_pandas, etc."""
+        signals_schema = self.signals_schema
+        if not self._sys:
+            return signals_schema.clone_without_sys_signals()
+        return signals_schema
+
+    @overload
+    def collect_flatten(self) -> Iterator[tuple[Any, ...]]: ...
+
+    @overload
+    def collect_flatten(
+        self, *, row_factory: Callable[[list[str], tuple[Any, ...]], _T]
+    ) -> Iterator[_T]: ...
+
+    def collect_flatten(self, *, row_factory=None):
+        """Yields flattened rows of values as a tuple.
+
+        Args:
+            row_factory : A callable to convert row to a custom format.
+                It should accept two arguments: a list of column names and
+                a tuple of row values.
+        """
+        db_signals = self._effective_signals_schema.db_signals()
         with super().select(*db_signals).as_iterable() as rows:
+            if row_factory:
+                rows = (row_factory(db_signals, r) for r in rows)
             yield from rows
 
+    @overload
+    def results(self) -> list[tuple[Any, ...]]: ...
+
+    @overload
     def results(
-        self, row_factory:
-    ) -> list[
-        rows = self.iterate_flatten()
-        if row_factory:
-            db_signals = self.signals_schema.db_signals()
-            rows = (row_factory(db_signals, r) for r in rows)
-        return list(rows)
+        self, *, row_factory: Callable[[list[str], tuple[Any, ...]], _T]
+    ) -> list[_T]: ...
 
-    def
-
+    def results(self, *, row_factory=None):  # noqa: D102
+        if row_factory is None:
+            return list(self.collect_flatten())
+        return list(self.collect_flatten(row_factory=row_factory))
 
-
-
-
-
-
-
-        )
+    def to_records(self) -> list[dict[str, Any]]:
+        """Convert every row to a dictionary."""
+
+        def to_dict(cols: list[str], row: tuple[Any, ...]) -> dict[str, Any]:
+            return dict(zip(cols, row))
+
+        return self.results(row_factory=to_dict)
 
-
-
-            yield item[0]
+    @overload
+    def collect(self) -> Iterator[tuple[DataType, ...]]: ...
 
-
-
+    @overload
+    def collect(self, col: str) -> Iterator[DataType]: ...  # type: ignore[overload-overlap]
 
-
-
+    @overload
+    def collect(self, *cols: str) -> Iterator[tuple[DataType, ...]]: ...
 
-    def
-        """
+    def collect(self, *cols: str) -> Iterator[Union[DataType, tuple[DataType, ...]]]:  # type: ignore[overload-overlap,misc]
+        """Yields rows of values, optionally limited to the specified columns.
 
+        Args:
+            *cols: Limit to the specified columns. By default, all columns are selected.
+
+        Yields:
+            (DataType): Yields a single item if a column is selected.
+            (tuple[DataType, ...]): Yields a tuple of items if multiple columns are
+                selected.
+
+        Example:
+            Iterating over all rows:
+            ```py
+            for row in dc.collect():
+                print(row)
+            ```
+
+            Iterating over all rows with selected columns:
+            ```py
+            for name, size in dc.collect("file.name", "file.size"):
+                print(name, size)
+            ```
+
+            Iterating over a single column:
+            ```py
+            for file in dc.collect("file.name"):
+                print(file)
+            ```
+        """
+        chain = self.select(*cols) if cols else self
+        signals_schema = chain._effective_signals_schema
+        db_signals = signals_schema.db_signals()
+        with super().select(*db_signals).as_iterable() as rows:
+            for row in rows:
+                ret = signals_schema.row_to_features(
+                    row, catalog=chain.session.catalog, cache=chain._settings.cache
+                )
+                yield ret[0] if len(cols) == 1 else tuple(ret)
+
+    def to_pytorch(
+        self, transform=None, tokenizer=None, tokenizer_kwargs=None, num_samples=0
+    ):
+        """Convert to pytorch dataset format.
+
+        Args:
+            transform (Transform): Torchvision transforms to apply to the dataset.
+            tokenizer (Callable): Tokenizer to use to tokenize text values.
+            tokenizer_kwargs (dict): Additional kwargs to pass when calling tokenizer.
+            num_samples (int): Number of random samples to draw for each epoch.
+                This argument is ignored if `num_samples=0` (the default).
+
+        Example:
+            ```py
+            from torch.utils.data import DataLoader
+            loader = DataLoader(
+                chain.select("file", "label").to_pytorch(),
+                batch_size=16
+            )
+            ```
+        """
         from datachain.torch import PytorchDataset
 
         if self.attached:
@@ -595,9 +846,17 @@ class DataChain(DatasetQuery):
         else:
             chain = self.save()
         assert chain.name is not None  # for mypy
-        return PytorchDataset(
+        return PytorchDataset(
+            chain.name,
+            chain.version,
+            catalog=self.catalog,
+            transform=transform,
+            tokenizer=tokenizer,
+            tokenizer_kwargs=tokenizer_kwargs,
+            num_samples=num_samples,
+        )
 
-    def remove_file_signals(self) -> "Self":
+    def remove_file_signals(self) -> "Self":  # noqa: D102
         schema = self.signals_schema.clone_without_file_signals()
         return self.select(*schema.values.keys())
 
@@ -622,9 +881,11 @@ class DataChain(DatasetQuery):
             inner (bool): Whether to run inner join or outer join.
             rname (str): name prefix for conflicting signal names.
 
-
-
-
+        Example:
+            ```py
+            meta = meta_emd.merge(meta_pq, on=(C.name, C.emd__index),
+                                  right_on=(C.name, C.pq__index))
+            ```
         """
         if on is None:
             raise DatasetMergeError(["None"], None, "'on' must be specified")
@@ -638,8 +899,10 @@ class DataChain(DatasetQuery):
                 f"'on' must be 'str' or 'Sequence' object but got type '{type(on)}'",
             )
 
-
+        signals_schema = self.signals_schema.clone_without_sys_signals()
+        on_columns = signals_schema.resolve(*on).db_signals()
 
+        right_signals_schema = right_ds.signals_schema.clone_without_sys_signals()
         if right_on is not None:
             if isinstance(right_on, str):
                 right_on = [right_on]
@@ -656,7 +919,7 @@ class DataChain(DatasetQuery):
                 on, right_on, "'on' and 'right_on' must have the same length'"
             )
 
-            right_on_columns =
+            right_on_columns = right_signals_schema.resolve(*right_on).db_signals()
 
         if len(right_on_columns) != len(on_columns):
             on_str = ", ".join(right_on_columns)
@@ -682,7 +945,9 @@ class DataChain(DatasetQuery):
         ds = self.join(right_ds, sqlalchemy.and_(*ops), inner, rname + "{name}")
 
         ds.feature_schema = None
-        ds.signals_schema =
+        ds.signals_schema = SignalSchema({"sys": Sys}) | signals_schema.merge(
+            right_signals_schema, rname
+        )
 
         return ds
 
@@ -695,7 +960,13 @@ class DataChain(DatasetQuery):
         object_name: str = "",
         **fr_map,
     ) -> "DataChain":
-        """Generate chain from list of values.
+        """Generate chain from list of values.
+
+        Example:
+            ```py
+            DataChain.from_values(fib=[1, 2, 3, 5, 8])
+            ```
+        """
         tuple_type, output, tuples = values_to_tuples(ds_name, output, **fr_map)
 
         def _func_fr() -> Iterator[tuple_type]:  # type: ignore[valid-type]
@@ -714,7 +985,16 @@ class DataChain(DatasetQuery):
         session: Optional[Session] = None,
         object_name: str = "",
     ) -> "DataChain":
-        """Generate chain from pandas data-frame.
+        """Generate chain from pandas data-frame.
+
+        Example:
+            ```py
+            import pandas as pd
+
+            df = pd.DataFrame({"fib": [1, 2, 3, 5, 8]})
+            DataChain.from_pandas(df)
+            ```
+        """
         fr_map = {col.lower(): df[col].tolist() for col in df.columns}
 
         for column in fr_map:
@@ -733,7 +1013,12 @@ class DataChain(DatasetQuery):
         return cls.from_values(name, session, object_name=object_name, **fr_map)
 
     def to_pandas(self, flatten=False) -> "pd.DataFrame":
-
+        """Return a pandas DataFrame from the chain.
+
+        Parameters:
+            flatten : Whether to use a multiindex or flatten column names.
+        """
+        headers, max_length = self._effective_signals_schema.get_headers_with_length()
         if flatten or max_length < 2:
             df = pd.DataFrame.from_records(self.to_records())
             if headers:
@@ -744,15 +1029,43 @@ class DataChain(DatasetQuery):
         data = {tuple(n): val for n, val in zip(headers, transposed_result)}
         return pd.DataFrame(data)
 
-    def show(
+    def show(
+        self,
+        limit: int = 20,
+        flatten=False,
+        transpose=False,
+        truncate=True,
+    ) -> None:
+        """Show a preview of the chain results.
+
+        Parameters:
+            limit : How many rows to show.
+            flatten : Whether to use a multiindex or flatten column names.
+            transpose : Whether to transpose rows and columns.
+            truncate : Whether or not to truncate the contents of columns.
+        """
         dc = self.limit(limit) if limit > 0 else self
         df = dc.to_pandas(flatten)
         if transpose:
             df = df.T
 
-
-            "display.max_columns",
-
+        options: list = [
+            "display.max_columns",
+            None,
+            "display.multi_sparse",
+            False,
+        ]
+
+        try:
+            if columns := os.get_terminal_size().columns:
+                options.extend(["display.width", columns])
+        except OSError:
+            pass
+
+        if not truncate:
+            options.extend(["display.max_colwidth", None])
+
+        with pd.option_context(*options):
             if inside_notebook():
                 from IPython.display import display
 
@@ -768,6 +1081,7 @@ class DataChain(DatasetQuery):
         output: OutputType = None,
         object_name: str = "",
         model_name: str = "",
+        nrows: Optional[int] = None,
         **kwargs,
     ) -> "DataChain":
         """Generate chain from list of tabular files.
@@ -779,18 +1093,22 @@ class DataChain(DatasetQuery):
            object_name : Generated object column name.
            model_name : Generated model name.
            kwargs : Parameters to pass to pyarrow.dataset.dataset.
+            nrows : Optional row limit.
 
-
+        Example:
            Reading a json lines file:
-
-
+            ```py
+            dc = DataChain.from_storage("s3://mybucket/file.jsonl")
+            dc = dc.parse_tabular(format="json")
+            ```
 
            Reading a filtered list of files as a dataset:
-
-
-
+            ```py
+            dc = DataChain.from_storage("s3://mybucket")
+            dc = dc.filter(C("file.name").glob("*.jsonl"))
+            dc = dc.parse_tabular(format="json")
+            ```
        """
-
        from datachain.lib.arrow import ArrowGenerator, infer_schema, schema_to_output
 
        schema = None
@@ -813,7 +1131,7 @@ class DataChain(DatasetQuery):
            for name, info in output.model_fields.items()
        }
        output = {"source": IndexedFile} | output  # type: ignore[assignment,operator]
-        return self.gen(ArrowGenerator(schema, **kwargs), output=output)
+        return self.gen(ArrowGenerator(schema, nrows, **kwargs), output=output)
 
    @staticmethod
    def _dict_to_data_model(
@@ -836,6 +1154,7 @@ class DataChain(DatasetQuery):
        output: OutputType = None,
        object_name: str = "",
        model_name: str = "",
+        nrows=None,
        **kwargs,
    ) -> "DataChain":
        """Generate chain from csv files.
@@ -850,13 +1169,18 @@ class DataChain(DatasetQuery):
                case types will be inferred.
            object_name : Created object column name.
            model_name : Generated model name.
+            nrows : Optional row limit.
 
-
+        Example:
            Reading a csv file:
-
+            ```py
+            dc = DataChain.from_csv("s3://mybucket/file.csv")
+            ```
 
            Reading csv files from a directory as a combined dataset:
-
+            ```py
+            dc = DataChain.from_csv("s3://mybucket/dir")
+            ```
        """
        from pyarrow.csv import ParseOptions, ReadOptions
        from pyarrow.dataset import CsvFileFormat
@@ -881,7 +1205,11 @@ class DataChain(DatasetQuery):
        read_options = ReadOptions(column_names=column_names)
        format = CsvFileFormat(parse_options=parse_options, read_options=read_options)
        return chain.parse_tabular(
-            output=output,
+            output=output,
+            object_name=object_name,
+            model_name=model_name,
+            nrows=nrows,
+            format=format,
        )
 
    @classmethod
@@ -892,6 +1220,7 @@ class DataChain(DatasetQuery):
        output: Optional[dict[str, DataType]] = None,
        object_name: str = "",
        model_name: str = "",
+        nrows=None,
        **kwargs,
    ) -> "DataChain":
        """Generate chain from parquet files.
@@ -903,23 +1232,48 @@ class DataChain(DatasetQuery):
            output : Dictionary defining column names and their corresponding types.
            object_name : Created object column name.
            model_name : Generated model name.
+            nrows : Optional row limit.
 
-
+        Example:
            Reading a single file:
-
+            ```py
+            dc = DataChain.from_parquet("s3://mybucket/file.parquet")
+            ```
 
            Reading a partitioned dataset from a directory:
-
+            ```py
+            dc = DataChain.from_parquet("s3://mybucket/dir")
+            ```
        """
        chain = DataChain.from_storage(path, **kwargs)
        return chain.parse_tabular(
            output=output,
            object_name=object_name,
            model_name=model_name,
+            nrows=None,
            format="parquet",
            partitioning=partitioning,
        )
 
+    def to_parquet(
+        self,
+        path: Union[str, os.PathLike[str], BinaryIO],
+        partition_cols: Optional[Sequence[str]] = None,
+        **kwargs,
+    ) -> None:
+        """Save chain to parquet file.
+
+        Parameters:
+            path : Path or a file-like binary object to save the file.
+            partition_cols : Column names by which to partition the dataset.
+        """
+        _partition_cols = list(partition_cols) if partition_cols else None
+        return self.to_pandas().to_parquet(
+            path,
+            partition_cols=_partition_cols,
+            **kwargs,
+        )
+
    @classmethod
    def create_empty(
        cls,
@@ -933,9 +1287,11 @@ class DataChain(DatasetQuery):
            to_insert : records (or a single record) to insert. Each record is
                a dictionary of signals and theirs values.
 
-
-
-
+        Example:
+            ```py
+            empty = DataChain.create_empty()
+            single_record = DataChain.create_empty(DataChain.DEFAULT_FILE_RECORD)
+            ```
        """
        session = Session.get(session)
        catalog = session.catalog
@@ -961,18 +1317,47 @@ class DataChain(DatasetQuery):
        return DataChain(name=dsr.name)
 
    def sum(self, fr: DataType):  # type: ignore[override]
+        """Compute the sum of a column."""
        return self._extend_to_data_model("sum", fr)
 
    def avg(self, fr: DataType):  # type: ignore[override]
+        """Compute the average of a column."""
        return self._extend_to_data_model("avg", fr)
 
    def min(self, fr: DataType):  # type: ignore[override]
+        """Compute the minimum of a column."""
        return self._extend_to_data_model("min", fr)
 
    def max(self, fr: DataType):  # type: ignore[override]
+        """Compute the maximum of a column."""
        return self._extend_to_data_model("max", fr)
 
    def setup(self, **kwargs) -> "Self":
+        """Setup variables to pass to UDF functions.
+
+        Use before running map/gen/agg/batch_map to save an object and pass it as an
+        argument to the UDF.
+
+        Example:
+            ```py
+            import anthropic
+            from anthropic.types import Message
+
+            (
+                DataChain.from_storage(DATA, type="text")
+                .settings(parallel=4, cache=True)
+                .setup(client=lambda: anthropic.Anthropic(api_key=API_KEY))
+                .map(
+                    claude=lambda client, file: client.messages.create(
+                        model=MODEL,
+                        system=PROMPT,
+                        messages=[{"role": "user", "content": file.get_value()}],
+                    ),
+                    output=Message,
+                )
+            )
+            ```
+        """
        intersection = set(self._setup.keys()) & set(kwargs.keys())
        if intersection:
            keys = ", ".join(intersection)
@@ -980,3 +1365,80 @@ class DataChain(DatasetQuery):
 
        self._setup = self._setup | kwargs
        return self
+
+    def export_files(
+        self,
+        output: str,
+        signal="file",
+        placement: FileExportPlacement = "fullpath",
+        use_cache: bool = True,
+    ) -> None:
+        """Method that exports all files from chain to some folder."""
+        if placement == "filename":
+            print("Checking if file names are unique")
+            if self.distinct(f"{signal}.name").count() != self.count():
+                raise ValueError("Files with the same name found")
+
+        for file in self.collect(signal):
+            file.export(output, placement, use_cache)  # type: ignore[union-attr]
+
+    def shuffle(self) -> "Self":
+        """Shuffle the rows of the chain deterministically."""
+        return self.order_by("sys.rand")
+
+    def sample(self, n) -> "Self":
+        """Return a random sample from the chain.
+
+        Parameters:
+            n (int): Number of samples to draw.
+
+        NOTE: Samples are not deterministic, and streamed/paginated queries or
+        multiple workers will draw samples with replacement.
+        """
+        return super().sample(n)
+
+    @detach
+    def filter(self, *args) -> "Self":
+        """Filter the chain according to conditions.
+
+        Example:
+            Basic usage with built-in operators
+            ```py
+            dc.filter(C("width") < 200)
+            ```
+
+            Using glob to match patterns
+            ```py
+            dc.filter(C("file.name").glob("*.jpg))
+            ```
+
+            Using `datachain.sql.functions`
+            ```py
+            from datachain.sql.functions import string
+            dc.filter(string.length(C("file.name")) > 5)
+            ```
+
+            Combining filters with "or"
+            ```py
+            dc.filter(C("file.name").glob("cat*") | C("file.name").glob("dog*))
+            ```
+
+            Combining filters with "and"
+            ```py
+            dc.filter(
+                C("file.name").glob("*.jpg) &
+                (string.length(C("file.name")) > 5)
+            )
+            ```
+        """
+        return super().filter(*args)
+
+    @detach
+    def limit(self, n: int) -> "Self":
+        """Return the first n rows of the chain."""
+        return super().limit(n)
+
+    @detach
+    def offset(self, offset: int) -> "Self":
+        """Return the results starting with the offset row."""
+        return super().offset(offset)