datachain 0.2.10__py3-none-any.whl → 0.2.12__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Note: this version of datachain has been flagged as potentially problematic.
- datachain/__init__.py +3 -4
- datachain/cache.py +10 -4
- datachain/catalog/catalog.py +35 -15
- datachain/cli.py +37 -32
- datachain/data_storage/metastore.py +24 -0
- datachain/data_storage/warehouse.py +3 -1
- datachain/job.py +56 -0
- datachain/lib/arrow.py +19 -7
- datachain/lib/clip.py +89 -66
- datachain/lib/convert/{type_converter.py → python_to_sql.py} +6 -6
- datachain/lib/convert/sql_to_python.py +23 -0
- datachain/lib/convert/values_to_tuples.py +51 -33
- datachain/lib/data_model.py +6 -27
- datachain/lib/dataset_info.py +70 -0
- datachain/lib/dc.py +646 -152
- datachain/lib/file.py +117 -15
- datachain/lib/image.py +1 -1
- datachain/lib/meta_formats.py +14 -2
- datachain/lib/model_store.py +3 -2
- datachain/lib/pytorch.py +10 -7
- datachain/lib/signal_schema.py +39 -14
- datachain/lib/text.py +2 -1
- datachain/lib/udf.py +56 -5
- datachain/lib/udf_signature.py +1 -1
- datachain/lib/webdataset.py +4 -3
- datachain/node.py +11 -8
- datachain/query/dataset.py +66 -147
- datachain/query/dispatch.py +15 -13
- datachain/query/schema.py +2 -0
- datachain/query/session.py +4 -4
- datachain/sql/functions/array.py +12 -0
- datachain/sql/functions/string.py +8 -0
- datachain/torch/__init__.py +1 -1
- datachain/utils.py +45 -0
- datachain-0.2.12.dist-info/METADATA +412 -0
- {datachain-0.2.10.dist-info → datachain-0.2.12.dist-info}/RECORD +40 -45
- {datachain-0.2.10.dist-info → datachain-0.2.12.dist-info}/WHEEL +1 -1
- datachain/lib/feature_registry.py +0 -77
- datachain/lib/gpt4_vision.py +0 -97
- datachain/lib/hf_image_to_text.py +0 -97
- datachain/lib/hf_pipeline.py +0 -90
- datachain/lib/image_transform.py +0 -103
- datachain/lib/iptc_exif_xmp.py +0 -76
- datachain/lib/unstructured.py +0 -41
- datachain/text/__init__.py +0 -3
- datachain-0.2.10.dist-info/METADATA +0 -430
- {datachain-0.2.10.dist-info → datachain-0.2.12.dist-info}/LICENSE +0 -0
- {datachain-0.2.10.dist-info → datachain-0.2.12.dist-info}/entry_points.txt +0 -0
- {datachain-0.2.10.dist-info → datachain-0.2.12.dist-info}/top_level.txt +0 -0
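The bulk of this release is a rework of `datachain/lib/dc.py`, shown in full below. `batch_map()` and the Feature-era helper modules (`feature_registry.py`, `gpt4_vision.py`, `hf_pipeline.py`, `unstructured.py`, and friends) are removed, while `DataChain` gains `datasets()`, `order_by(descending=...)`, `distinct()`, `mutate()`, a reworked `collect()` family, pandas interop (`to_pandas()`, `show()`), `to_parquet()`, `export_files()`, `shuffle()`, `sample()`, and explicit `filter()`/`limit()`/`offset()` wrappers. A minimal sketch of the new surface, with an illustrative bucket URI; method names and signatures are taken from the diff below:

```py
from datachain.lib.dc import C, DataChain

chain = (
    DataChain.from_storage("gs://datachain-demo/chatbot-KiT/")
    .filter(C("file.name").glob("*.txt"))
    .order_by("file.size", descending=True)  # new in 0.2.12
    .limit(10)
)
chain.show(5)                           # pandas-backed preview, new in 0.2.12
records = chain.to_records()            # rows as dicts, new in 0.2.12
chain.to_parquet("/tmp/files.parquet")  # chain -> parquet via pandas
```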
datachain/lib/dc.py
CHANGED
@@ -1,22 +1,31 @@
 import copy
+import os
 import re
 from collections.abc import Iterator, Sequence
+from functools import wraps
 from typing import (
     TYPE_CHECKING,
     Any,
+    BinaryIO,
     Callable,
     ClassVar,
     Literal,
     Optional,
+    TypeVar,
     Union,
+    overload,
 )

+import pandas as pd
 import sqlalchemy
 from pydantic import BaseModel, create_model
+from sqlalchemy.sql.functions import GenericFunction

 from datachain import DataModel
 from datachain.lib.convert.values_to_tuples import values_to_tuples
 from datachain.lib.data_model import DataType
+from datachain.lib.dataset_info import DatasetInfo
+from datachain.lib.file import ExportPlacement as FileExportPlacement
 from datachain.lib.file import File, IndexedFile, get_file
 from datachain.lib.meta_formats import read_meta, read_schema
 from datachain.lib.model_store import ModelStore
@@ -24,7 +33,6 @@ from datachain.lib.settings import Settings
 from datachain.lib.signal_schema import SignalSchema
 from datachain.lib.udf import (
     Aggregator,
-    BatchMapper,
     Generator,
     Mapper,
     UDFBase,
@@ -38,29 +46,60 @@ from datachain.query.dataset import (
     detach,
 )
 from datachain.query.schema import Column, DatasetRow
+from datachain.utils import inside_notebook

 if TYPE_CHECKING:
-    import pandas as pd
-    from typing_extensions import Self
+    from typing_extensions import Concatenate, ParamSpec, Self
+
+    P = ParamSpec("P")

 C = Column

+_T = TypeVar("_T")
+D = TypeVar("D", bound="DataChain")
+
+
+def resolve_columns(
+    method: "Callable[Concatenate[D, P], D]",
+) -> "Callable[Concatenate[D, P], D]":
+    """Decorator that resolvs input column names to their actual DB names. This is
+    specially important for nested columns as user works with them by using dot
+    notation e.g (file.name) but are actually defined with default delimiter
+    in DB, e.g file__name.
+    If there are any sql functions in arguments, they will just be transferred as is
+    to a method.
+    """
+
+    @wraps(method)
+    def _inner(self: D, *args: "P.args", **kwargs: "P.kwargs") -> D:
+        resolved_args = self.signals_schema.resolve(
+            *[arg for arg in args if not isinstance(arg, GenericFunction)]
+        ).db_signals()
+
+        for idx, arg in enumerate(args):
+            if isinstance(arg, GenericFunction):
+                resolved_args.insert(idx, arg)
+
+        return method(self, *resolved_args, **kwargs)
+
+    return _inner

-class DatasetPrepareError(DataChainParamsError):
-    def __init__(self, name, msg, output=None):
+
+class DatasetPrepareError(DataChainParamsError):  # noqa: D101
+    def __init__(self, name, msg, output=None):  # noqa: D107
         name = f" '{name}'" if name else ""
         output = f" output '{output}'" if output else ""
         super().__init__(f"Dataset{name}{output} processing prepare error: {msg}")


-class DatasetFromValuesError(DataChainParamsError):
-    def __init__(self, name, msg):
+class DatasetFromValuesError(DataChainParamsError):  # noqa: D101
+    def __init__(self, name, msg):  # noqa: D107
         name = f" '{name}'" if name else ""
-        super().__init__(f"Dataset…
+        super().__init__(f"Dataset{name} from values error: {msg}")


-class DatasetMergeError(DataChainParamsError):
-    def __init__(self, on: Sequence[str], right_on: Optional[Sequence[str]], msg: str):
+class DatasetMergeError(DataChainParamsError):  # noqa: D101
+    def __init__(self, on: Sequence[str], right_on: Optional[Sequence[str]], msg: str):  # noqa: D107
         on_str = ", ".join(on) if isinstance(on, Sequence) else ""
         right_on_str = (
             ", right_on='" + ", ".join(right_on) + "'"
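The `resolve_columns` decorator added above is what lets methods like the new `order_by` accept user-facing dotted signal names. A toy illustration of the idea, not datachain code (datachain resolves names through `SignalSchema.resolve()` rather than plain string substitution):

```py
from functools import wraps

DELIMITER = "__"  # dotted signal names map to this delimiter in DB columns


def resolve_columns_demo(method):
    """Translate 'file.name'-style string args to 'file__name' before calling."""

    @wraps(method)
    def _inner(self, *args, **kwargs):
        resolved = [
            a.replace(".", DELIMITER) if isinstance(a, str) else a for a in args
        ]
        return method(self, *resolved, **kwargs)

    return _inner


class FakeChain:
    @resolve_columns_demo
    def order_by(self, *cols):
        return cols


print(FakeChain().order_by("file.name", "file.size"))
# -> ('file__name', 'file__size')
```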
@@ -74,6 +113,8 @@ OutputType = Union[None, DataType, Sequence[str], dict[str, DataType]]
 
 
 class Sys(DataModel):
+    """Model for internal DataChain signals `id` and `rand`."""
+
     id: int
     rand: int
 
@@ -86,7 +127,7 @@ class DataChain(DatasetQuery):
     enrich data.
 
     Data in DataChain is presented as Python classes with arbitrary set of fields,
-    including nested classes. The data classes have to inherit from `…
+    including nested classes. The data classes have to inherit from `DataModel` class.
     The supported set of field types include: majority of the type supported by the
     underlyind library `Pydantic`.
 
@@ -98,34 +139,56 @@ class DataChain(DatasetQuery):
 
     `DataChain.from_dataset("name")` - reading from a dataset.
 
-    `DataChain.…
+    `DataChain.from_values(fib=[1, 2, 3, 5, 8])` - generating from values.
+
+    `DataChain.from_pandas(pd.DataFrame(...))` - generating from pandas.
+
+    `DataChain.from_json("file.json")` - generating from json.
 
+    `DataChain.from_csv("file.csv")` - generating from csv.
+
+    `DataChain.from_parquet("file.parquet")` - generating from parquet.
 
     Example:
         ```py
-        …
-        …
+        import os
+
+        from mistralai.client import MistralClient
+        from mistralai.models.chat_completion import ChatMessage
+
+        from datachain.dc import DataChain, Column
 
-        …
-        …
-        …
+        PROMPT = (
+            "Was this bot dialog successful? "
+            "Describe the 'result' as 'Yes' or 'No' in a short JSON"
+        )
 
-        …
-        …
+        model = "mistral-large-latest"
+        api_key = os.environ["MISTRAL_API_KEY"]
 
         chain = (
-            DataChain.from_storage("…
-            .filter(C.name.glob("*.txt"))
+            DataChain.from_storage("gs://datachain-demo/chatbot-KiT/")
             .limit(5)
-            .…
+            .settings(cache=True, parallel=5)
             .map(
-                …
-                …
-                …
-                …
+                mistral_response=lambda file: MistralClient(api_key=api_key)
+                .chat(
+                    model=model,
+                    response_format={"type": "json_object"},
+                    messages=[
+                        ChatMessage(role="user", content=f"{PROMPT}: {file.read()}")
+                    ],
+                )
+                .choices[0]
+                .message.content,
+            )
+            .save()
         )
-        …
-        …
+
+        try:
+            print(chain.select("mistral_response").results())
+        except Exception as e:
+            print(f"do you have the right Mistral API key? {e}")
         ```
     """
 
@@ -137,8 +200,9 @@ class DataChain(DatasetQuery):
     }
 
     def __init__(self, *args, **kwargs):
-        """This method needs to be redefined as a part of Dataset and
-        decoupling."""
+        """This method needs to be redefined as a part of Dataset and DataChain
+        decoupling.
+        """
         super().__init__(
             *args,
             **kwargs,
@@ -147,19 +211,25 @@ class DataChain(DatasetQuery):
         self._settings = Settings()
         self._setup = {}
 
+        self.signals_schema = SignalSchema({"sys": Sys})
         if self.feature_schema:
-            self.signals_schema = SignalSchema.deserialize(self.feature_schema)
+            self.signals_schema |= SignalSchema.deserialize(self.feature_schema)
         else:
-            self.signals_schema = SignalSchema.from_column_types(self.column_types)
+            self.signals_schema |= SignalSchema.from_column_types(self.column_types)
+
+        self._sys = False
 
     @property
-    def schema(self):
-        …
+    def schema(self) -> dict[str, DataType]:
+        """Get schema of the chain."""
+        return self._effective_signals_schema.values
 
-    def print_schema(self):
-        …
+    def print_schema(self) -> None:
+        """Print schema of the chain."""
+        self._effective_signals_schema.print_tree()
 
     def clone(self, new_table: bool = True) -> "Self":
+        """Make a copy of the chain in a new table."""
         obj = super().clone(new_table=new_table)
         obj.signals_schema = copy.deepcopy(self.signals_schema)
         return obj
@@ -171,7 +241,7 @@ class DataChain(DatasetQuery):
         parallel=None,
         workers=None,
         min_task_size=None,
-        include_sys: Optional[bool] = None,
+        sys: Optional[bool] = None,
     ) -> "Self":
         """Change settings for chain.
 
@@ -196,10 +266,8 @@ class DataChain(DatasetQuery):
         ```
         """
         chain = self.clone()
-        if …
-            chain.…
-        elif include_sys is False and "sys" in chain.signals_schema:
-            chain.signals_schema.remove("sys")
+        if sys is not None:
+            chain._sys = sys
         chain._settings.add(Settings(cache, batch, parallel, workers, min_task_size))
         return chain
 
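Where 0.2.10's `settings()` took `include_sys` and edited the schema immediately, 0.2.12 stores the flag on `chain._sys` and applies it lazily through `_effective_signals_schema` (added further down). A short sketch, assuming `chain` is an existing `DataChain`:

```py
# Sys signals (id, rand) are hidden from user-facing output by default.
chain = chain.settings(sys=True)   # include sys signals in results
chain = chain.settings(sys=False)  # hide them again; sys=None leaves it as is
```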
@@ -208,17 +276,14 @@ class DataChain(DatasetQuery):
         self._settings = settings if settings else Settings()
         return self
 
-    def reset_schema(self, signals_schema: SignalSchema) -> "Self":
+    def reset_schema(self, signals_schema: SignalSchema) -> "Self":  # noqa: D102
         self.signals_schema = signals_schema
         return self
 
-    def add_schema(self, signals_schema: SignalSchema) -> "Self":
+    def add_schema(self, signals_schema: SignalSchema) -> "Self":  # noqa: D102
         self.signals_schema |= signals_schema
         return self
 
-    def get_file_signals(self) -> list[str]:
-        return list(self.signals_schema.get_file_signals())
-
     @classmethod
     def from_storage(
         cls,
@@ -228,10 +293,11 @@ class DataChain(DatasetQuery):
         session: Optional[Session] = None,
         recursive: Optional[bool] = True,
         object_name: str = "file",
+        update: bool = False,
         **kwargs,
     ) -> "Self":
-        """Get data from a storage as a list of file with all file attributes.
-        returns the chain itself as usual.
+        """Get data from a storage as a list of file with all file attributes.
+        It returns the chain itself as usual.
 
         Parameters:
             path : storage URI with directory. URI must start with storage prefix such
@@ -239,6 +305,7 @@ class DataChain(DatasetQuery):
             type : read file as "binary", "text", or "image" data. Default is "binary".
             recursive : search recursively for the given path.
             object_name : Created object column name.
+            update : force storage reindexing. Default is False.
 
         Example:
             ```py
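The new `update` flag is passed straight through to the underlying query constructor, forcing the storage listing to be refreshed instead of reusing a previously indexed bucket. Sketch, with an illustrative URI:

```py
dc = DataChain.from_storage("s3://mybucket", type="text", update=True)
```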
@@ -246,20 +313,24 @@ class DataChain(DatasetQuery):
             ```
         """
         func = get_file(type)
-        return …
-            **…
+        return (
+            cls(path, session=session, recursive=recursive, update=update, **kwargs)
+            .map(**{object_name: func})
+            .select(object_name)
         )
 
     @classmethod
     def from_dataset(cls, name: str, version: Optional[int] = None) -> "DataChain":
-        """Get data from …
+        """Get data from a saved Dataset. It returns the chain itself.
 
         Parameters:
             name : dataset name
             version : dataset version
 
-        …
-        …
+        Example:
+            ```py
+            chain = DataChain.from_dataset("my_cats")
+            ```
         """
         return DataChain(name=name, version=version)
 
@@ -275,6 +346,7 @@ class DataChain(DatasetQuery):
         model_name: Optional[str] = None,
         show_schema: Optional[bool] = False,
         meta_type: Optional[str] = "json",
+        nrows=None,
         **kwargs,
     ) -> "DataChain":
         """Get data from JSON. It returns the chain itself.
@@ -284,18 +356,23 @@ class DataChain(DatasetQuery):
             as `s3://`, `gs://`, `az://` or "file:///"
             type : read file as "binary", "text", or "image" data. Default is "binary".
             spec : optional Data Model
-            schema_from : path to sample to infer spec
+            schema_from : path to sample to infer spec (if schema not provided)
             object_name : generated object column name
-            model_name : generated model name
+            model_name : optional generated model name
             show_schema : print auto-generated schema
-            jmespath : JMESPATH expression to reduce JSON
+            jmespath : optional JMESPATH expression to reduce JSON
+            nrows : optional row limit for jsonl and JSON arrays
 
-        …
+        Example:
             infer JSON schema from data, reduce using JMESPATH, print schema
-        …
+            ```py
+            chain = DataChain.from_json("gs://json", jmespath="key1.key2")
+            ```
 
             infer JSON schema from a particular path, print data model
-        …
+            ```py
+            chain = DataChain.from_json("gs://json_ds", schema_from="gs://json/my.json")
+            ```
         """
         if schema_from == "auto":
             schema_from = path
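`nrows` gives `from_json` (and, further down, `parse_tabular` and `from_csv`) a cheap way to sample large inputs. A sketch with an illustrative URI; `meta_type="jsonl"` is assumed here to select the JSON-lines reader that the `nrows` description refers to:

```py
chain = DataChain.from_json(
    "gs://mybucket/events.jsonl",
    meta_type="jsonl",
    nrows=100,  # parse only the first 100 records
)
```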
@@ -317,10 +394,40 @@ class DataChain(DatasetQuery):
                 model_name=model_name,
                 show_schema=show_schema,
                 jmespath=jmespath,
+                nrows=nrows,
             )
         }
         return chain.gen(**signal_dict)  # type: ignore[arg-type]
 
+    @classmethod
+    def datasets(
+        cls, session: Optional[Session] = None, object_name: str = "dataset"
+    ) -> "DataChain":
+        """Generate chain with list of registered datasets.
+
+        Example:
+            ```py
+            from datachain import DataChain
+
+            chain = DataChain.datasets()
+            for ds in chain.collect("dataset"):
+                print(f"{ds.name}@v{ds.version}")
+            ```
+        """
+        session = Session.get(session)
+        catalog = session.catalog
+
+        datasets = [
+            DatasetInfo.from_models(d, v, j)
+            for d, v, j in catalog.list_datasets_versions()
+        ]
+
+        return cls.from_values(
+            session=session,
+            output={object_name: DatasetInfo},
+            **{object_name: datasets},  # type: ignore[arg-type]
+        )
+
     def show_json_schema(  # type: ignore[override]
         self, jmespath: Optional[str] = None, model_name: Optional[str] = None
     ) -> "DataChain":
@@ -330,12 +437,14 @@ class DataChain(DatasetQuery):
             jmespath : JMESPATH expression to reduce JSON
             model_name : generated model name
 
-        …
+        Example:
             print JSON schema and save to column "meta_from":
-        …
-        …
-        …
-        …
+            ```py
+            uri = "gs://datachain-demo/coco2017/annotations_captions/"
+            chain = DataChain.from_storage(uri)
+            chain = chain.show_json_schema()
+            chain.save()
+            ```
         """
         return self.map(
             meta_schema=lambda file: read_schema(
@@ -370,11 +479,29 @@ class DataChain(DatasetQuery):
             removed after process ends. Temp dataset are useful for optimization.
             version : version of a dataset. Default - the last version that exist.
         """
-        schema = self.signals_schema.serialize()
-        schema.pop("sys", None)
+        schema = self.signals_schema.clone_without_sys_signals().serialize()
         return super().save(name=name, version=version, feature_schema=schema)
 
     def apply(self, func, *args, **kwargs):
+        """Apply any function to the chain.
+
+        Useful for reusing in a chain of operations.
+
+        Example:
+            ```py
+            def parse_stem(chain):
+                return chain.map(
+                    lambda file: file.get_file_stem()
+                    output={"stem": str}
+                )
+
+            chain = (
+                DataChain.from_storage("s3://my-bucket")
+                .apply(parse_stem)
+                .filter(C("stem").glob("*cat*"))
+            )
+            ```
+        """
         return func(self, *args, **kwargs)
 
     def map(
@@ -402,16 +529,19 @@ class DataChain(DatasetQuery):
             signal name in format of `map(my_sign=my_func)`. This helps define
             signal names and function in a nicer way.
 
-        …
+        Example:
             Using signal_map and single type in output:
-        …
-        …
+            ```py
+            chain = chain.map(value=lambda name: name[:-4] + ".json", output=str)
+            chain.save("new_dataset")
+            ```
 
             Using func and output as a map:
-        …
-        …
+            ```py
+            chain = chain.map(lambda name: name[:-4] + ".json", output={"res": str})
+            chain.save("new_dataset")
+            ```
         """
-
         udf_obj = self._udf_to_obj(Mapper, func, params, output, signal_map)
 
         chain = self.add_signals(
@@ -439,7 +569,6 @@ class DataChain(DatasetQuery):
         extracting multiple file records from a single tar file or bounding boxes from a
         single image file).
         """
-
         udf_obj = self._udf_to_obj(Generator, func, params, output, signal_map)
         chain = DatasetQuery.generate(
             self,
@@ -480,27 +609,6 @@ class DataChain(DatasetQuery):
 
         return chain.reset_schema(udf_obj.output).reset_settings(self._settings)
 
-    def batch_map(
-        self,
-        func: Optional[Callable] = None,
-        params: Union[None, str, Sequence[str]] = None,
-        output: OutputType = None,
-        **signal_map,
-    ) -> "Self":
-        """This is a batch version of map().
-
-        It accepts the same parameters plus an
-        additional parameter:
-        """
-        udf_obj = self._udf_to_obj(BatchMapper, func, params, output, signal_map)
-        chain = DatasetQuery.generate(
-            self,
-            udf_obj.to_udf_wrapper(self._settings.batch),
-            **self._settings.to_dict(),
-        )
-
-        return chain.add_schema(udf_obj.output).reset_settings(self._settings)
-
     def _udf_to_obj(
         self,
         target_class: type[UDFBase],
@@ -515,7 +623,11 @@ class DataChain(DatasetQuery):
         sign = UdfSignature.parse(name, signal_map, func, params, output, is_generator)
         DataModel.register(list(sign.output_schema.values.values()))
 
-        params_schema = self.signals_schema.slice(sign.params, self._setup)
+        signals_schema = self.signals_schema
+        if self._sys:
+            signals_schema = SignalSchema({"sys": Sys}) | signals_schema
+
+        params_schema = signals_schema.slice(sign.params, self._setup)
 
         return target_class._create(sign, params_schema)
 
@@ -531,9 +643,38 @@ class DataChain(DatasetQuery):
         return res
 
     @detach
-    def select(self, *args: str) -> "Self":
+    @resolve_columns
+    def order_by(self, *args, descending: bool = False) -> "Self":
+        """Orders by specified set of signals.
+
+        Parameters:
+            descending (bool): Whether to sort in descending order or not.
+        """
+        if descending:
+            args = tuple([sqlalchemy.desc(a) for a in args])
+
+        return super().order_by(*args)
+
+    @detach
+    def distinct(self, arg: str, *args: str) -> "Self":  # type: ignore[override]
+        """Removes duplicate rows based on uniqueness of some input column(s)
+        i.e if rows are found with the same value of input column(s), only one
+        row is left in the result set.
+
+        Example:
+            ```py
+            dc.distinct("file.parent", "file.name")
+            )
+            ```
+        """
+        return super().distinct(*self.signals_schema.resolve(arg, *args).db_signals())
+
+    @detach
+    def select(self, *args: str, _sys: bool = True) -> "Self":
         """Select only a specified set of signals."""
         new_schema = self.signals_schema.resolve(*args)
+        if _sys:
+            new_schema = SignalSchema({"sys": Sys}) | new_schema
         columns = new_schema.db_signals()
         chain = super().select(*columns)
         chain.signals_schema = new_schema
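Both new methods take user-facing signal names: `@resolve_columns` translates them for `order_by`, and `distinct` resolves them explicitly via `signals_schema.resolve(...).db_signals()`. A sketch reusing the `my_cats` dataset name from the `from_dataset` docstring:

```py
dc = (
    DataChain.from_dataset("my_cats")
    .order_by("file.size", descending=True)  # wraps each column in sqlalchemy.desc()
    .distinct("file.parent", "file.name")    # one row per (parent, name) pair
)
```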
@@ -548,45 +689,156 @@ class DataChain(DatasetQuery):
         chain.signals_schema = new_schema
         return chain
 
-    …
-    …
+    @detach
+    def mutate(self, **kwargs) -> "Self":
+        """Create new signals based on existing signals.
+
+        This method is vectorized and more efficient compared to map(), and it does not
+        extract or download any data from the internal database. However, it can only
+        utilize predefined built-in functions and their combinations.
+
+        The supported functions:
+           Numerical:   +, -, *, /, rand(), avg(), count(), func(),
+                        greatest(), least(), max(), min(), sum()
+           String:      length(), split()
+           Filename:    name(), parent(), file_stem(), file_ext()
+           Array:       length(), sip_hash_64(), euclidean_distance(),
+                        cosine_distance()
+
+        Example:
+            ```py
+            dc.mutate(
+                area=Column("image.height") * Column("image.width"),
+                extension=file_ext(Column("file.name")),
+                dist=cosine_distance(embedding_text, embedding_image)
+            )
+            ```
+        """
+        chain = super().mutate(**kwargs)
+        chain.signals_schema = self.signals_schema.mutate(kwargs)
+        return chain
+
+    @property
+    def _effective_signals_schema(self) -> "SignalSchema":
+        """Effective schema used for user-facing API like collect, to_pandas, etc."""
+        signals_schema = self.signals_schema
+        if not self._sys:
+            return signals_schema.clone_without_sys_signals()
+        return signals_schema
+
+    @overload
+    def collect_flatten(self) -> Iterator[tuple[Any, ...]]: ...
+
+    @overload
+    def collect_flatten(
+        self, *, row_factory: Callable[[list[str], tuple[Any, ...]], _T]
+    ) -> Iterator[_T]: ...
+
+    def collect_flatten(self, *, row_factory=None):
+        """Yields flattened rows of values as a tuple.
+
+        Args:
+            row_factory : A callable to convert row to a custom format.
+                It should accept two arguments: a list of column names and
+                a tuple of row values.
+        """
+        db_signals = self._effective_signals_schema.db_signals()
         with super().select(*db_signals).as_iterable() as rows:
+            if row_factory:
+                rows = (row_factory(db_signals, r) for r in rows)
             yield from rows
 
+    @overload
+    def results(self) -> list[tuple[Any, ...]]: ...
+
+    @overload
     def results(
-        self, row_factory: …
-    ) -> list[…
-        rows = self.iterate_flatten()
-        if row_factory:
-            db_signals = self.signals_schema.db_signals()
-            rows = (row_factory(db_signals, r) for r in rows)
-        return list(rows)
+        self, *, row_factory: Callable[[list[str], tuple[Any, ...]], _T]
+    ) -> list[_T]: ...
 
-    def …
-        …
+    def results(self, *, row_factory=None):  # noqa: D102
+        if row_factory is None:
+            return list(self.collect_flatten())
+        return list(self.collect_flatten(row_factory=row_factory))
 
-    …
-    …
-    …
-    …
-    …
-    …
-    …
-    …
+    def to_records(self) -> list[dict[str, Any]]:
+        """Convert every row to a dictionary."""
+
+        def to_dict(cols: list[str], row: tuple[Any, ...]) -> dict[str, Any]:
+            return dict(zip(cols, row))
+
+        return self.results(row_factory=to_dict)
+
+    @overload
+    def collect(self) -> Iterator[tuple[DataType, ...]]: ...
+
+    @overload
+    def collect(self, col: str) -> Iterator[DataType]: ...  # type: ignore[overload-overlap]
 
-    …
-    …
-            yield item[0]
+    @overload
+    def collect(self, *cols: str) -> Iterator[tuple[DataType, ...]]: ...
 
-    def collect(self, *cols: str) -> …
-        …
+    def collect(self, *cols: str) -> Iterator[Union[DataType, tuple[DataType, ...]]]:  # type: ignore[overload-overlap,misc]
+        """Yields rows of values, optionally limited to the specified columns.
 
-        …
-        …
+        Args:
+            *cols: Limit to the specified columns. By default, all columns are selected.
 
-        …
-        …
+        Yields:
+            (DataType): Yields a single item if a column is selected.
+            (tuple[DataType, ...]): Yields a tuple of items if multiple columns are
+                selected.
 
+        Example:
+            Iterating over all rows:
+            ```py
+            for row in dc.collect():
+                print(row)
+            ```
+
+            Iterating over all rows with selected columns:
+            ```py
+            for name, size in dc.collect("file.name", "file.size"):
+                print(name, size)
+            ```
+
+            Iterating over a single column:
+            ```py
+            for file in dc.collect("file.name"):
+                print(file)
+            ```
+        """
+        chain = self.select(*cols) if cols else self
+        signals_schema = chain._effective_signals_schema
+        db_signals = signals_schema.db_signals()
+        with super().select(*db_signals).as_iterable() as rows:
+            for row in rows:
+                ret = signals_schema.row_to_features(
+                    row, catalog=chain.session.catalog, cache=chain._settings.cache
+                )
+                yield ret[0] if len(cols) == 1 else tuple(ret)
+
+    def to_pytorch(
+        self, transform=None, tokenizer=None, tokenizer_kwargs=None, num_samples=0
+    ):
+        """Convert to pytorch dataset format.
+
+        Args:
+            transform (Transform): Torchvision transforms to apply to the dataset.
+            tokenizer (Callable): Tokenizer to use to tokenize text values.
+            tokenizer_kwargs (dict): Additional kwargs to pass when calling tokenizer.
+            num_samples (int): Number of random samples to draw for each epoch.
+                This argument is ignored if `num_samples=0` (the default).
+
+        Example:
+            ```py
+            from torch.utils.data import DataLoader
+            loader = DataLoader(
+                chain.select("file", "label").to_pytorch(),
+                batch_size=16
+            )
+            ```
+        """
         from datachain.torch import PytorchDataset
 
         if self.attached:
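The old `iterate`/`results` pair becomes a small family: `collect()` reconstructs feature objects, `collect_flatten()` streams raw row tuples, `results()` materializes them, and `to_records()` returns dicts keyed by DB column name. A sketch, assuming `dc` carries the default `file` signal:

```py
for name, size in dc.collect("file.name", "file.size"):
    print(name, size)          # feature values, one tuple per row

flat = dc.results()            # list of flat row tuples
dicts = dc.to_records()        # list of dicts keyed by DB column name

# row_factory reshapes each flat row; here into a column-name -> value mapping.
mapped = dc.results(row_factory=lambda cols, row: dict(zip(cols, row)))
```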
@@ -594,9 +846,17 @@ class DataChain(DatasetQuery):
         else:
             chain = self.save()
         assert chain.name is not None  # for mypy
-        return PytorchDataset(…
+        return PytorchDataset(
+            chain.name,
+            chain.version,
+            catalog=self.catalog,
+            transform=transform,
+            tokenizer=tokenizer,
+            tokenizer_kwargs=tokenizer_kwargs,
+            num_samples=num_samples,
+        )
 
-    def remove_file_signals(self) -> "Self":
+    def remove_file_signals(self) -> "Self":  # noqa: D102
         schema = self.signals_schema.clone_without_file_signals()
         return self.select(*schema.values.keys())
 
@@ -621,9 +881,11 @@ class DataChain(DatasetQuery):
             inner (bool): Whether to run inner join or outer join.
             rname (str): name prefix for conflicting signal names.
 
-        …
-        …
-        …
+        Example:
+            ```py
+            meta = meta_emd.merge(meta_pq, on=(C.name, C.emd__index),
+                                  right_on=(C.name, C.pq__index))
+            ```
         """
         if on is None:
             raise DatasetMergeError(["None"], None, "'on' must be specified")
@@ -637,8 +899,10 @@ class DataChain(DatasetQuery):
             f"'on' must be 'str' or 'Sequence' object but got type '{type(on)}'",
         )
 
-        …
+        signals_schema = self.signals_schema.clone_without_sys_signals()
+        on_columns = signals_schema.resolve(*on).db_signals()
 
+        right_signals_schema = right_ds.signals_schema.clone_without_sys_signals()
         if right_on is not None:
             if isinstance(right_on, str):
                 right_on = [right_on]
@@ -655,7 +919,7 @@ class DataChain(DatasetQuery):
             on, right_on, "'on' and 'right_on' must have the same length'"
         )
 
-        right_on_columns = …
+        right_on_columns = right_signals_schema.resolve(*right_on).db_signals()
 
         if len(right_on_columns) != len(on_columns):
             on_str = ", ".join(right_on_columns)
@@ -681,7 +945,9 @@ class DataChain(DatasetQuery):
         ds = self.join(right_ds, sqlalchemy.and_(*ops), inner, rname + "{name}")
 
         ds.feature_schema = None
-        ds.signals_schema = …
+        ds.signals_schema = SignalSchema({"sys": Sys}) | signals_schema.merge(
+            right_signals_schema, rname
+        )
 
         return ds
 
@@ -694,7 +960,13 @@ class DataChain(DatasetQuery):
         object_name: str = "",
         **fr_map,
     ) -> "DataChain":
-        """Generate chain from list of values."""
+        """Generate chain from list of values.
+
+        Example:
+            ```py
+            DataChain.from_values(fib=[1, 2, 3, 5, 8])
+            ```
+        """
         tuple_type, output, tuples = values_to_tuples(ds_name, output, **fr_map)
 
         def _func_fr() -> Iterator[tuple_type]:  # type: ignore[valid-type]
@@ -713,7 +985,16 @@ class DataChain(DatasetQuery):
         session: Optional[Session] = None,
         object_name: str = "",
     ) -> "DataChain":
-        """Generate chain from pandas data-frame."""
+        """Generate chain from pandas data-frame.
+
+        Example:
+            ```py
+            import pandas as pd
+
+            df = pd.DataFrame({"fib": [1, 2, 3, 5, 8]})
+            DataChain.from_pandas(df)
+            ```
+        """
         fr_map = {col.lower(): df[col].tolist() for col in df.columns}
 
         for column in fr_map:
@@ -731,11 +1012,76 @@ class DataChain(DatasetQuery):
 
         return cls.from_values(name, session, object_name=object_name, **fr_map)
 
+    def to_pandas(self, flatten=False) -> "pd.DataFrame":
+        """Return a pandas DataFrame from the chain.
+
+        Parameters:
+            flatten : Whether to use a multiindex or flatten column names.
+        """
+        headers, max_length = self._effective_signals_schema.get_headers_with_length()
+        if flatten or max_length < 2:
+            df = pd.DataFrame.from_records(self.to_records())
+            if headers:
+                df.columns = [".".join(filter(None, header)) for header in headers]
+            return df
+
+        transposed_result = list(map(list, zip(*self.results())))
+        data = {tuple(n): val for n, val in zip(headers, transposed_result)}
+        return pd.DataFrame(data)
+
+    def show(
+        self,
+        limit: int = 20,
+        flatten=False,
+        transpose=False,
+        truncate=True,
+    ) -> None:
+        """Show a preview of the chain results.
+
+        Parameters:
+            limit : How many rows to show.
+            flatten : Whether to use a multiindex or flatten column names.
+            transpose : Whether to transpose rows and columns.
+            truncate : Whether or not to truncate the contents of columns.
+        """
+        dc = self.limit(limit) if limit > 0 else self
+        df = dc.to_pandas(flatten)
+        if transpose:
+            df = df.T
+
+        options: list = [
+            "display.max_columns",
+            None,
+            "display.multi_sparse",
+            False,
+        ]
+
+        try:
+            if columns := os.get_terminal_size().columns:
+                options.extend(["display.width", columns])
+        except OSError:
+            pass
+
+        if not truncate:
+            options.extend(["display.max_colwidth", None])
+
+        with pd.option_context(*options):
+            if inside_notebook():
+                from IPython.display import display
+
+                display(df)
+            else:
+                print(df)
+
+        if len(df) == limit:
+            print(f"\n[Limited by {len(df)} rows]")
+
     def parse_tabular(
         self,
         output: OutputType = None,
         object_name: str = "",
         model_name: str = "",
+        nrows: Optional[int] = None,
         **kwargs,
     ) -> "DataChain":
         """Generate chain from list of tabular files.
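`to_pandas()` builds the frame from `to_records()`; with nested signals and `flatten=False` it produces MultiIndex columns, otherwise dotted flat names. `show()` layers terminal-width and notebook handling on top. Sketch:

```py
df = dc.to_pandas()                # MultiIndex columns for nested signals
flat = dc.to_pandas(flatten=True)  # "file.name"-style flat column names
dc.show(limit=5, transpose=True, truncate=False)
```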
@@ -747,18 +1093,22 @@ class DataChain(DatasetQuery):
             object_name : Generated object column name.
             model_name : Generated model name.
             kwargs : Parameters to pass to pyarrow.dataset.dataset.
+            nrows : Optional row limit.
 
-        …
+        Example:
             Reading a json lines file:
-        …
-        …
+            ```py
+            dc = DataChain.from_storage("s3://mybucket/file.jsonl")
+            dc = dc.parse_tabular(format="json")
+            ```
 
             Reading a filtered list of files as a dataset:
-        …
-        …
-        …
+            ```py
+            dc = DataChain.from_storage("s3://mybucket")
+            dc = dc.filter(C("file.name").glob("*.jsonl"))
+            dc = dc.parse_tabular(format="json")
+            ```
         """
-
         from datachain.lib.arrow import ArrowGenerator, infer_schema, schema_to_output
 
         schema = None
@@ -781,7 +1131,7 @@ class DataChain(DatasetQuery):
             for name, info in output.model_fields.items()
         }
         output = {"source": IndexedFile} | output  # type: ignore[assignment,operator]
-        return self.gen(ArrowGenerator(schema, **kwargs), output=output)
+        return self.gen(ArrowGenerator(schema, nrows, **kwargs), output=output)
 
     @staticmethod
     def _dict_to_data_model(
@@ -804,6 +1154,7 @@ class DataChain(DatasetQuery):
         output: OutputType = None,
         object_name: str = "",
         model_name: str = "",
+        nrows=None,
         **kwargs,
     ) -> "DataChain":
         """Generate chain from csv files.
@@ -818,13 +1169,18 @@ class DataChain(DatasetQuery):
             case types will be inferred.
             object_name : Created object column name.
             model_name : Generated model name.
+            nrows : Optional row limit.
 
-        …
+        Example:
             Reading a csv file:
-        …
+            ```py
+            dc = DataChain.from_csv("s3://mybucket/file.csv")
+            ```
 
             Reading csv files from a directory as a combined dataset:
-        …
+            ```py
+            dc = DataChain.from_csv("s3://mybucket/dir")
+            ```
         """
         from pyarrow.csv import ParseOptions, ReadOptions
         from pyarrow.dataset import CsvFileFormat
@@ -849,7 +1205,11 @@ class DataChain(DatasetQuery):
         read_options = ReadOptions(column_names=column_names)
         format = CsvFileFormat(parse_options=parse_options, read_options=read_options)
         return chain.parse_tabular(
-            output=output, …
+            output=output,
+            object_name=object_name,
+            model_name=model_name,
+            nrows=nrows,
+            format=format,
         )
 
     @classmethod
@@ -860,6 +1220,7 @@ class DataChain(DatasetQuery):
         output: Optional[dict[str, DataType]] = None,
         object_name: str = "",
         model_name: str = "",
+        nrows=None,
         **kwargs,
     ) -> "DataChain":
         """Generate chain from parquet files.
@@ -871,23 +1232,48 @@ class DataChain(DatasetQuery):
             output : Dictionary defining column names and their corresponding types.
             object_name : Created object column name.
             model_name : Generated model name.
+            nrows : Optional row limit.
 
-        …
+        Example:
             Reading a single file:
-        …
+            ```py
+            dc = DataChain.from_parquet("s3://mybucket/file.parquet")
+            ```
 
             Reading a partitioned dataset from a directory:
-        …
+            ```py
+            dc = DataChain.from_parquet("s3://mybucket/dir")
+            ```
         """
         chain = DataChain.from_storage(path, **kwargs)
         return chain.parse_tabular(
             output=output,
             object_name=object_name,
             model_name=model_name,
+            nrows=None,
             format="parquet",
             partitioning=partitioning,
         )
 
+    def to_parquet(
+        self,
+        path: Union[str, os.PathLike[str], BinaryIO],
+        partition_cols: Optional[Sequence[str]] = None,
+        **kwargs,
+    ) -> None:
+        """Save chain to parquet file.
+
+        Parameters:
+            path : Path or a file-like binary object to save the file.
+            partition_cols : Column names by which to partition the dataset.
+        """
+        _partition_cols = list(partition_cols) if partition_cols else None
+        return self.to_pandas().to_parquet(
+            path,
+            partition_cols=_partition_cols,
+            **kwargs,
+        )
+
     @classmethod
     def create_empty(
         cls,
|
|
|
901
1287
|
to_insert : records (or a single record) to insert. Each record is
|
|
902
1288
|
a dictionary of signals and theirs values.
|
|
903
1289
|
|
|
904
|
-
|
|
905
|
-
|
|
906
|
-
|
|
1290
|
+
Example:
|
|
1291
|
+
```py
|
|
1292
|
+
empty = DataChain.create_empty()
|
|
1293
|
+
single_record = DataChain.create_empty(DataChain.DEFAULT_FILE_RECORD)
|
|
1294
|
+
```
|
|
907
1295
|
"""
|
|
908
1296
|
session = Session.get(session)
|
|
909
1297
|
catalog = session.catalog
|
|
@@ -929,18 +1317,47 @@ class DataChain(DatasetQuery):
|
|
|
929
1317
|
return DataChain(name=dsr.name)
|
|
930
1318
|
|
|
931
1319
|
def sum(self, fr: DataType): # type: ignore[override]
|
|
1320
|
+
"""Compute the sum of a column."""
|
|
932
1321
|
return self._extend_to_data_model("sum", fr)
|
|
933
1322
|
|
|
934
1323
|
def avg(self, fr: DataType): # type: ignore[override]
|
|
1324
|
+
"""Compute the average of a column."""
|
|
935
1325
|
return self._extend_to_data_model("avg", fr)
|
|
936
1326
|
|
|
937
1327
|
def min(self, fr: DataType): # type: ignore[override]
|
|
1328
|
+
"""Compute the minimum of a column."""
|
|
938
1329
|
return self._extend_to_data_model("min", fr)
|
|
939
1330
|
|
|
940
1331
|
def max(self, fr: DataType): # type: ignore[override]
|
|
1332
|
+
"""Compute the maximum of a column."""
|
|
941
1333
|
return self._extend_to_data_model("max", fr)
|
|
942
1334
|
|
|
943
1335
|
def setup(self, **kwargs) -> "Self":
|
|
1336
|
+
"""Setup variables to pass to UDF functions.
|
|
1337
|
+
|
|
1338
|
+
Use before running map/gen/agg/batch_map to save an object and pass it as an
|
|
1339
|
+
argument to the UDF.
|
|
1340
|
+
|
|
1341
|
+
Example:
|
|
1342
|
+
```py
|
|
1343
|
+
import anthropic
|
|
1344
|
+
from anthropic.types import Message
|
|
1345
|
+
|
|
1346
|
+
(
|
|
1347
|
+
DataChain.from_storage(DATA, type="text")
|
|
1348
|
+
.settings(parallel=4, cache=True)
|
|
1349
|
+
.setup(client=lambda: anthropic.Anthropic(api_key=API_KEY))
|
|
1350
|
+
.map(
|
|
1351
|
+
claude=lambda client, file: client.messages.create(
|
|
1352
|
+
model=MODEL,
|
|
1353
|
+
system=PROMPT,
|
|
1354
|
+
messages=[{"role": "user", "content": file.get_value()}],
|
|
1355
|
+
),
|
|
1356
|
+
output=Message,
|
|
1357
|
+
)
|
|
1358
|
+
)
|
|
1359
|
+
```
|
|
1360
|
+
"""
|
|
944
1361
|
intersection = set(self._setup.keys()) & set(kwargs.keys())
|
|
945
1362
|
if intersection:
|
|
946
1363
|
keys = ", ".join(intersection)
|
|
@@ -948,3 +1365,80 @@ class DataChain(DatasetQuery):
|
|
|
948
1365
|
|
|
949
1366
|
self._setup = self._setup | kwargs
|
|
950
1367
|
return self
|
|
1368
|
+
|
|
1369
|
+
def export_files(
|
|
1370
|
+
self,
|
|
1371
|
+
output: str,
|
|
1372
|
+
signal="file",
|
|
1373
|
+
placement: FileExportPlacement = "fullpath",
|
|
1374
|
+
use_cache: bool = True,
|
|
1375
|
+
) -> None:
|
|
1376
|
+
"""Method that exports all files from chain to some folder."""
|
|
1377
|
+
if placement == "filename":
|
|
1378
|
+
print("Checking if file names are unique")
|
|
1379
|
+
if self.distinct(f"{signal}.name").count() != self.count():
|
|
1380
|
+
raise ValueError("Files with the same name found")
|
|
1381
|
+
|
|
1382
|
+
for file in self.collect(signal):
|
|
1383
|
+
file.export(output, placement, use_cache) # type: ignore[union-attr]
|
|
1384
|
+
|
|
1385
|
+
def shuffle(self) -> "Self":
|
|
1386
|
+
"""Shuffle the rows of the chain deterministically."""
|
|
1387
|
+
return self.order_by("sys.rand")
|
|
1388
|
+
|
|
1389
|
+
def sample(self, n) -> "Self":
|
|
1390
|
+
"""Return a random sample from the chain.
|
|
1391
|
+
|
|
1392
|
+
Parameters:
|
|
1393
|
+
n (int): Number of samples to draw.
|
|
1394
|
+
|
|
1395
|
+
NOTE: Samples are not deterministic, and streamed/paginated queries or
|
|
1396
|
+
multiple workers will draw samples with replacement.
|
|
1397
|
+
"""
|
|
1398
|
+
return super().sample(n)
|
|
1399
|
+
|
|
1400
|
+
@detach
|
|
1401
|
+
def filter(self, *args) -> "Self":
|
|
1402
|
+
"""Filter the chain according to conditions.
|
|
1403
|
+
|
|
1404
|
+
Example:
|
|
1405
|
+
Basic usage with built-in operators
|
|
1406
|
+
```py
|
|
1407
|
+
dc.filter(C("width") < 200)
|
|
1408
|
+
```
|
|
1409
|
+
|
|
1410
|
+
Using glob to match patterns
|
|
1411
|
+
```py
|
|
1412
|
+
dc.filter(C("file.name").glob("*.jpg))
|
|
1413
|
+
```
|
|
1414
|
+
|
|
1415
|
+
Using `datachain.sql.functions`
|
|
1416
|
+
```py
|
|
1417
|
+
from datachain.sql.functions import string
|
|
1418
|
+
dc.filter(string.length(C("file.name")) > 5)
|
|
1419
|
+
```
|
|
1420
|
+
|
|
1421
|
+
Combining filters with "or"
|
|
1422
|
+
```py
|
|
1423
|
+
dc.filter(C("file.name").glob("cat*") | C("file.name").glob("dog*))
|
|
1424
|
+
```
|
|
1425
|
+
|
|
1426
|
+
Combining filters with "and"
|
|
1427
|
+
```py
|
|
1428
|
+
dc.filter(
|
|
1429
|
+
C("file.name").glob("*.jpg) &
|
|
1430
|
+
(string.length(C("file.name")) > 5)
|
|
1431
|
+
)
|
|
1432
|
+
```
|
|
1433
|
+
"""
|
|
1434
|
+
return super().filter(*args)
|
|
1435
|
+
|
|
1436
|
+
@detach
|
|
1437
|
+
def limit(self, n: int) -> "Self":
|
|
1438
|
+
"""Return the first n rows of the chain."""
|
|
1439
|
+
return super().limit(n)
|
|
1440
|
+
|
|
1441
|
+
@detach
|
|
1442
|
+
def offset(self, offset: int) -> "Self":
|
|
1443
|
+
"""Return the results starting with the offset row."""
|
|
1444
|
+
return super().offset(offset)
|