datachain 0.2.14__py3-none-any.whl → 0.2.16__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release. This version of datachain might be problematic.
- datachain/catalog/catalog.py +5 -7
- datachain/cli.py +1 -1
- datachain/data_storage/metastore.py +2 -2
- datachain/data_storage/sqlite.py +21 -0
- datachain/data_storage/warehouse.py +28 -8
- datachain/lib/arrow.py +27 -8
- datachain/lib/convert/flatten.py +10 -5
- datachain/lib/convert/python_to_sql.py +1 -1
- datachain/lib/data_model.py +6 -1
- datachain/lib/dc.py +102 -32
- datachain/lib/meta_formats.py +6 -6
- datachain/lib/settings.py +1 -17
- datachain/lib/signal_schema.py +4 -1
- datachain/lib/udf.py +18 -10
- datachain/query/dataset.py +10 -46
- datachain/sql/types.py +5 -1
- {datachain-0.2.14.dist-info → datachain-0.2.16.dist-info}/METADATA +1 -1
- {datachain-0.2.14.dist-info → datachain-0.2.16.dist-info}/RECORD +22 -22
- {datachain-0.2.14.dist-info → datachain-0.2.16.dist-info}/WHEEL +1 -1
- {datachain-0.2.14.dist-info → datachain-0.2.16.dist-info}/LICENSE +0 -0
- {datachain-0.2.14.dist-info → datachain-0.2.16.dist-info}/entry_points.txt +0 -0
- {datachain-0.2.14.dist-info → datachain-0.2.16.dist-info}/top_level.txt +0 -0
datachain/catalog/catalog.py
CHANGED
@@ -1217,16 +1217,14 @@ class Catalog:
     def get_temp_table_names(self) -> list[str]:
         return self.warehouse.get_temp_table_names()

-    def
+    def cleanup_tables(self, names: Iterable[str]) -> None:
         """
-        Drop tables
+        Drop tables passed.

-        This should be implemented
-
-        needed. When running the same `DatasetQuery` multiple times we
-        may use the same temporary table names.
+        This should be implemented to ensure that the provided tables
+        are cleaned up as soon as they are no longer needed.
         """
-        self.warehouse.
+        self.warehouse.cleanup_tables(names)
         self.id_generator.delete_uris(names)

     def create_dataset_from_sources(

datachain/cli.py
CHANGED
@@ -910,7 +910,7 @@ def garbage_collect(catalog: "Catalog"):
         print("Nothing to clean up.")
     else:
         print(f"Garbage collecting {len(temp_tables)} tables.")
-        catalog.
+        catalog.cleanup_tables(temp_tables)


 def completion(shell: str) -> str:

datachain/data_storage/metastore.py
CHANGED

@@ -97,7 +97,7 @@ class AbstractMetastore(ABC, Serializable):
     def close(self) -> None:
         """Closes any active database or HTTP connections."""

-    def
+    def cleanup_tables(self, temp_table_names: list[str]) -> None:
         """Cleanup temp tables."""

     def cleanup_for_tests(self) -> None:

@@ -457,7 +457,7 @@ class AbstractDBMetastore(AbstractMetastore):
         """Closes any active database connections."""
         self.db.close()

-    def
+    def cleanup_tables(self, temp_table_names: list[str]) -> None:
         """Cleanup temp tables."""
         self.id_generator.delete_uris(temp_table_names)

datachain/data_storage/sqlite.py
CHANGED
@@ -42,6 +42,7 @@ if TYPE_CHECKING:
     from sqlalchemy.dialects.sqlite import Insert
     from sqlalchemy.schema import SchemaItem
     from sqlalchemy.sql.elements import ColumnClause, ColumnElement, TextClause
+    from sqlalchemy.sql.selectable import Select
     from sqlalchemy.types import TypeEngine


@@ -705,3 +706,23 @@ class SQLiteWarehouse(AbstractWarehouse):
         client_config=None,
     ) -> list[str]:
         raise NotImplementedError("Exporting dataset table not implemented for SQLite")
+
+    def create_pre_udf_table(self, query: "Select") -> "Table":
+        """
+        Create a temporary table from a query for use in a UDF.
+        """
+        columns = [
+            sqlalchemy.Column(c.name, c.type)
+            for c in query.selected_columns
+            if c.name != "sys__id"
+        ]
+        table = self.create_udf_table(columns)
+
+        select_q = query.with_only_columns(
+            *[c for c in query.selected_columns if c.name != "sys__id"]
+        )
+        self.db.execute(
+            table.insert().from_select(list(select_q.selected_columns), select_q)
+        )
+
+        return table

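For readers unfamiliar with the "materialize a query into a table" pattern that `create_pre_udf_table` uses, here is a minimal standalone sketch in plain SQLAlchemy against an in-memory SQLite database. The table names (`src`, `snapshot`) and the sample data are made up for illustration and are not part of datachain.

```py
import sqlalchemy as sa

engine = sa.create_engine("sqlite://")  # in-memory database for the sketch
meta = sa.MetaData()
src = sa.Table(
    "src", meta, sa.Column("name", sa.String), sa.Column("size", sa.Integer)
)
meta.create_all(engine)

with engine.begin() as conn:
    conn.execute(src.insert(), [{"name": "a", "size": 1}, {"name": "b", "size": 2}])

# Materialize a query: copy its columns into a new table, then INSERT ... FROM SELECT.
query = sa.select(src.c.name, src.c.size).where(src.c.size > 1)
snapshot = sa.Table(
    "snapshot", meta, *[sa.Column(c.name, c.type) for c in query.selected_columns]
)
meta.create_all(engine)

with engine.begin() as conn:
    conn.execute(snapshot.insert().from_select(list(query.selected_columns), query))
    print(conn.execute(sa.select(snapshot)).fetchall())  # [('b', 2)]
```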
datachain/data_storage/warehouse.py
CHANGED

@@ -2,6 +2,8 @@ import glob
 import json
 import logging
 import posixpath
+import random
+import string
 from abc import ABC, abstractmethod
 from collections.abc import Generator, Iterable, Iterator, Sequence
 from typing import TYPE_CHECKING, Any, Optional, Union

@@ -24,6 +26,7 @@ from datachain.utils import sql_escape_like
 if TYPE_CHECKING:
     from sqlalchemy.sql._typing import _ColumnsClauseArgument
     from sqlalchemy.sql.elements import ColumnElement
+    from sqlalchemy.sql.selectable import Select
     from sqlalchemy.types import TypeEngine

 from datachain.data_storage import AbstractIDGenerator, schema

@@ -252,6 +255,12 @@ class AbstractWarehouse(ABC, Serializable):
         prefix = self.DATASET_SOURCE_TABLE_PREFIX
         return f"{prefix}{dataset_name}_{version}"

+    def temp_table_name(self) -> str:
+        return self.TMP_TABLE_NAME_PREFIX + _random_string(6)
+
+    def udf_table_name(self) -> str:
+        return self.UDF_TABLE_NAME_PREFIX + _random_string(6)
+
     #
     # Datasets
     #

@@ -869,8 +878,8 @@ class AbstractWarehouse(ABC, Serializable):

     def create_udf_table(
         self,
-        name: str,
         columns: Sequence["sa.Column"] = (),
+        name: Optional[str] = None,
     ) -> "sa.Table":
         """
         Create a temporary table for storing custom signals generated by a UDF.

@@ -878,7 +887,7 @@ class AbstractWarehouse(ABC, Serializable):
         and UDFs are run in other processes when run in parallel.
         """
         tbl = sa.Table(
-            name,
+            name or self.udf_table_name(),
             sa.MetaData(),
             sa.Column("sys__id", Int, primary_key=True),
             *columns,

@@ -886,6 +895,12 @@ class AbstractWarehouse(ABC, Serializable):
         self.db.create_table(tbl, if_not_exists=True)
         return tbl

+    @abstractmethod
+    def create_pre_udf_table(self, query: "Select") -> "Table":
+        """
+        Create a temporary table from a query for use in a UDF.
+        """
+
     def is_temp_table_name(self, name: str) -> bool:
         """Returns if the given table name refers to a temporary
         or no longer needed table."""

@@ -900,14 +915,12 @@ class AbstractWarehouse(ABC, Serializable):
             if self.is_temp_table_name(t)
         ]

-    def
+    def cleanup_tables(self, names: Iterable[str]) -> None:
         """
-        Drop tables
+        Drop tables passed.

-        This should be implemented
-
-        needed. When running the same `DatasetQuery` multiple times we
-        may use the same temporary table names.
+        This should be implemented to ensure that the provided tables
+        are cleaned up as soon as they are no longer needed.
         """
         for name in names:
             self.db.drop_table(Table(name, self.db.metadata), if_exists=True)

@@ -937,3 +950,10 @@ class AbstractWarehouse(ABC, Serializable):
                 & (tq.c.is_latest == true())
             )
         )
+
+
+def _random_string(length: int) -> str:
+    return "".join(
+        random.choice(string.ascii_letters + string.digits)  # noqa: S311
+        for i in range(length)
+    )

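A standalone sketch of the table-naming scheme introduced here: a fixed prefix plus a short random suffix, so concurrent runs do not reuse each other's temporary tables. The prefix value below is a placeholder, not the package's actual constant.

```py
import random
import string

TMP_TABLE_NAME_PREFIX = "tmp_"  # placeholder prefix, not the package's constant


def _random_string(length: int) -> str:
    return "".join(
        random.choice(string.ascii_letters + string.digits) for _ in range(length)
    )


def temp_table_name() -> str:
    # Fixed prefix + random suffix keeps concurrent runs from colliding.
    return TMP_TABLE_NAME_PREFIX + _random_string(6)


print(temp_table_name())  # e.g. "tmp_a1B2c3"
```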
datachain/lib/arrow.py
CHANGED
@@ -10,13 +10,17 @@ from datachain.lib.file import File, IndexedFile
 from datachain.lib.udf import Generator

 if TYPE_CHECKING:
+    from pydantic import BaseModel
+
     from datachain.lib.dc import DataChain


 class ArrowGenerator(Generator):
     def __init__(
         self,
-
+        input_schema: Optional["pa.Schema"] = None,
+        output_schema: Optional[type["BaseModel"]] = None,
+        source: bool = True,
         nrows: Optional[int] = None,
         **kwargs,
     ):

@@ -25,24 +29,36 @@ class ArrowGenerator(Generator):

         Parameters:

-
+        input_schema : Optional pyarrow schema for validation.
+        output_schema : Optional pydantic model for validation.
+        source : Whether to include info about the source file.
         nrows : Optional row limit.
         kwargs: Parameters to pass to pyarrow.dataset.dataset.
         """
         super().__init__()
-        self.
+        self.input_schema = input_schema
+        self.output_schema = output_schema
+        self.source = source
         self.nrows = nrows
         self.kwargs = kwargs

     def process(self, file: File):
         path = file.get_path()
-        ds = dataset(
+        ds = dataset(
+            path, filesystem=file.get_fs(), schema=self.input_schema, **self.kwargs
+        )
         index = 0
         with tqdm(desc="Parsed by pyarrow", unit=" rows") as pbar:
-            for record_batch in ds.to_batches():
+            for record_batch in ds.to_batches(use_threads=False):
                 for record in record_batch.to_pylist():
-
-
+                    vals = list(record.values())
+                    if self.output_schema:
+                        fields = self.output_schema.model_fields
+                        vals = [self.output_schema(**dict(zip(fields, vals)))]
+                    if self.source:
+                        yield [IndexedFile(file=file, index=index), *vals]
+                    else:
+                        yield vals
                     index += 1
                     if self.nrows and index >= self.nrows:
                         return

@@ -76,7 +92,10 @@ def schema_to_output(schema: pa.Schema, col_names: Optional[Sequence[str]] = Non
         if not column:
             column = f"c{default_column}"
             default_column += 1
-
+        dtype = _arrow_type_mapper(field.type)  # type: ignore[assignment]
+        if field.nullable:
+            dtype = Optional[dtype]  # type: ignore[assignment]
+        output[column] = dtype

     return output

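The nullable handling added to `schema_to_output` can be illustrated without datachain: walk a pyarrow schema and wrap nullable fields in `Optional`. The tiny type table below is illustrative only and is not the package's `_arrow_type_mapper`.

```py
from typing import Optional

import pyarrow as pa

# Illustrative mapping from a few arrow types to Python types (not datachain's).
_PA_TO_PY = {pa.int64(): int, pa.float64(): float, pa.string(): str, pa.bool_(): bool}


def schema_to_annotations(schema: pa.Schema) -> dict:
    out = {}
    for field in schema:
        py_type = _PA_TO_PY.get(field.type, object)
        # Nullable arrow fields become Optional Python annotations.
        out[field.name] = Optional[py_type] if field.nullable else py_type
    return out


schema = pa.schema(
    [pa.field("id", pa.int64(), nullable=False), pa.field("label", pa.string())]
)
print(schema_to_annotations(schema))  # {'id': int, 'label': Optional[str]}
```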
datachain/lib/convert/flatten.py
CHANGED
@@ -41,17 +41,22 @@ def flatten_list(obj_list):
     )


+def _flatten_list_field(value: list):
+    assert isinstance(value, list)
+    if value and ModelStore.is_pydantic(type(value[0])):
+        return [val.model_dump() for val in value]
+    if value and isinstance(value[0], list):
+        return [_flatten_list_field(v) for v in value]
+    return value
+
+
 def _flatten_fields_values(fields, obj: BaseModel):
     for name, f_info in fields.items():
         anno = f_info.annotation
         # Optimization: Access attributes directly to skip the model_dump() call.
         value = getattr(obj, name)
-
         if isinstance(value, list):
-
-            yield [val.model_dump() for val in value]
-        else:
-            yield value
+            yield _flatten_list_field(value)
         elif isinstance(value, dict):
             yield {
                 key: val.model_dump() if ModelStore.is_pydantic(type(val)) else val

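A self-contained sketch (plain pydantic, with a hypothetical `Point` model) of the recursive flattening that `_flatten_list_field` adds: nested lists of models become nested lists of plain dicts.

```py
from pydantic import BaseModel


class Point(BaseModel):  # hypothetical model
    x: int
    y: int


def flatten_list_field(value: list):
    """Recursively turn (nested) lists of pydantic models into plain dicts."""
    if value and isinstance(value[0], BaseModel):
        return [v.model_dump() for v in value]
    if value and isinstance(value[0], list):
        return [flatten_list_field(v) for v in value]
    return value


print(flatten_list_field([[Point(x=1, y=2)], [Point(x=3, y=4)]]))
# [[{'x': 1, 'y': 2}], [{'x': 3, 'y': 4}]]
```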
datachain/lib/convert/python_to_sql.py
CHANGED

@@ -82,7 +82,7 @@ def python_to_sql(typ):  # noqa: PLR0911
 def _is_json_inside_union(orig, args) -> bool:
     if orig == Union and len(args) >= 2:
         # List in JSON: Union[dict, list[dict]]
-        args_no_nones = [arg for arg in args if arg != type(None)]
+        args_no_nones = [arg for arg in args if arg != type(None)]  # noqa: E721
         if len(args_no_nones) == 2:
             args_no_dicts = [arg for arg in args_no_nones if arg is not dict]
             if len(args_no_dicts) == 1 and get_origin(args_no_dicts[0]) is list:

datachain/lib/data_model.py
CHANGED
@@ -47,7 +47,12 @@ def is_chain_type(t: type) -> bool:
     if any(t is ft or t is get_args(ft)[0] for ft in get_args(StandardType)):
         return True

-
+    orig = get_origin(t)
+    args = get_args(t)
+    if orig is list and len(args) == 1:
         return is_chain_type(get_args(t)[0])

+    if orig is Union and len(args) == 2 and (type(None) in args):
+        return is_chain_type(args[0])
+
     return False

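The new `Union` branch relies on the fact that `Optional[T]` is `Union[T, None]` with exactly two arguments, so the wrapped type can be checked recursively. A minimal standard-library illustration:

```py
from typing import Optional, Union, get_args, get_origin


def unwrap_optional(t):
    """If t is Optional[T] (i.e. Union[T, None] with two args), return T."""
    if get_origin(t) is Union:
        args = get_args(t)
        if len(args) == 2 and type(None) in args:
            return next(a for a in args if a is not type(None))
    return t


assert get_origin(Optional[int]) is Union
assert unwrap_optional(Optional[list[str]]) == list[str]
assert unwrap_optional(int) is int
```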
datachain/lib/dc.py
CHANGED
@@ -33,6 +33,7 @@ from datachain.lib.settings import Settings
 from datachain.lib.signal_schema import SignalSchema
 from datachain.lib.udf import (
     Aggregator,
+    BatchMapper,
     Generator,
     Mapper,
     UDFBase,

@@ -237,7 +238,6 @@ class DataChain(DatasetQuery):
     def settings(
         self,
         cache=None,
-        batch=None,
         parallel=None,
         workers=None,
         min_task_size=None,

@@ -250,7 +250,6 @@ class DataChain(DatasetQuery):

         Parameters:
             cache : data caching (default=False)
-            batch : size of the batch (default=1000)
             parallel : number of thread for processors. True is a special value to
                 enable all available CPUs (default=1)
             workers : number of distributed workers. Only for Studio mode. (default=1)

@@ -268,7 +267,7 @@ class DataChain(DatasetQuery):
         chain = self.clone()
         if sys is not None:
             chain._sys = sys
-        chain._settings.add(Settings(cache,
+        chain._settings.add(Settings(cache, parallel, workers, min_task_size))
         return chain

     def reset_settings(self, settings: Optional[Settings] = None) -> "Self":

@@ -344,7 +343,7 @@ class DataChain(DatasetQuery):
         jmespath: Optional[str] = None,
         object_name: Optional[str] = "",
         model_name: Optional[str] = None,
-
+        print_schema: Optional[bool] = False,
         meta_type: Optional[str] = "json",
         nrows=None,
         **kwargs,

@@ -359,7 +358,7 @@ class DataChain(DatasetQuery):
             schema_from : path to sample to infer spec (if schema not provided)
             object_name : generated object column name
             model_name : optional generated model name
-
+            print_schema : print auto-generated schema
             jmespath : optional JMESPATH expression to reduce JSON
             nrows : optional row limit for jsonl and JSON arrays

@@ -392,7 +391,7 @@ class DataChain(DatasetQuery):
             meta_type=meta_type,
             spec=spec,
             model_name=model_name,
-
+            print_schema=print_schema,
             jmespath=jmespath,
             nrows=nrows,
         )

@@ -409,7 +408,7 @@ class DataChain(DatasetQuery):
         jmespath: Optional[str] = None,
         object_name: Optional[str] = "",
         model_name: Optional[str] = None,
-
+        print_schema: Optional[bool] = False,
         meta_type: Optional[str] = "jsonl",
         nrows=None,
         **kwargs,

@@ -424,7 +423,7 @@ class DataChain(DatasetQuery):
             schema_from : path to sample to infer spec (if schema not provided)
             object_name : generated object column name
             model_name : optional generated model name
-
+            print_schema : print auto-generated schema
             jmespath : optional JMESPATH expression to reduce JSON
             nrows : optional row limit for jsonl and JSON arrays

@@ -452,7 +451,7 @@ class DataChain(DatasetQuery):
             meta_type=meta_type,
             spec=spec,
             model_name=model_name,
-
+            print_schema=print_schema,
             jmespath=jmespath,
             nrows=nrows,
         )

@@ -488,7 +487,7 @@ class DataChain(DatasetQuery):
             **{object_name: datasets},  # type: ignore[arg-type]
         )

-    def
+    def print_json_schema(  # type: ignore[override]
         self, jmespath: Optional[str] = None, model_name: Optional[str] = None
     ) -> "DataChain":
         """Print JSON data model and save it. It returns the chain itself.

@@ -513,7 +512,7 @@ class DataChain(DatasetQuery):
             output=str,
         )

-    def
+    def print_jsonl_schema(  # type: ignore[override]
         self, jmespath: Optional[str] = None, model_name: Optional[str] = None
     ) -> "DataChain":
         """Print JSON data model and save it. It returns the chain itself.

@@ -598,14 +597,16 @@ class DataChain(DatasetQuery):

         Using func and output as a map:
         ```py
-        chain = chain.map(
+        chain = chain.map(
+            lambda name: name.split("."), output={"stem": str, "ext": str}
+        )
         chain.save("new_dataset")
         ```
         """
         udf_obj = self._udf_to_obj(Mapper, func, params, output, signal_map)

         chain = self.add_signals(
-            udf_obj.to_udf_wrapper(
+            udf_obj.to_udf_wrapper(),
             **self._settings.to_dict(),
         )

@@ -618,7 +619,7 @@ class DataChain(DatasetQuery):
         output: OutputType = None,
         **signal_map,
     ) -> "Self":
-        """Apply a function to each row to create new rows (with potentially new
+        r"""Apply a function to each row to create new rows (with potentially new
         signals). The function needs to return a new objects for each of the new rows.
         It returns a chain itself with new signals.

@@ -628,11 +629,20 @@ class DataChain(DatasetQuery):
         one key differences: It produces a sequence of rows for each input row (like
         extracting multiple file records from a single tar file or bounding boxes from a
         single image file).
+
+        Example:
+        ```py
+        chain = chain.gen(
+            line=lambda file: [l for l in file.read().split("\n")],
+            output=str,
+        )
+        chain.save("new_dataset")
+        ```
         """
         udf_obj = self._udf_to_obj(Generator, func, params, output, signal_map)
         chain = DatasetQuery.generate(
             self,
-            udf_obj.to_udf_wrapper(
+            udf_obj.to_udf_wrapper(),
             **self._settings.to_dict(),
         )

@@ -652,23 +662,68 @@ class DataChain(DatasetQuery):

         Input-output relationship: N:M

-        This method bears similarity to `gen()` and map()
-        parameters, yet differs in two crucial aspects:
+        This method bears similarity to `gen()` and `map()`, employing a comparable set
+        of parameters, yet differs in two crucial aspects:
         1. The `partition_by` parameter: This specifies the column name or a list of
            column names that determine the grouping criteria for aggregation.
         2. Group-based UDF function input: Instead of individual rows, the function
            receives a list all rows within each group defined by `partition_by`.
+
+        Example:
+        ```py
+        chain = chain.agg(
+            total=lambda category, amount: [sum(amount)],
+            output=float,
+            partition_by="category",
+        )
+        chain.save("new_dataset")
+        ```
         """
         udf_obj = self._udf_to_obj(Aggregator, func, params, output, signal_map)
         chain = DatasetQuery.generate(
             self,
-            udf_obj.to_udf_wrapper(
+            udf_obj.to_udf_wrapper(),
             partition_by=partition_by,
             **self._settings.to_dict(),
         )

         return chain.reset_schema(udf_obj.output).reset_settings(self._settings)

+    def batch_map(
+        self,
+        func: Optional[Callable] = None,
+        params: Union[None, str, Sequence[str]] = None,
+        output: OutputType = None,
+        batch: int = 1000,
+        **signal_map,
+    ) -> "Self":
+        """This is a batch version of `map()`.
+
+        Input-output relationship: N:N
+
+        It accepts the same parameters plus an
+        additional parameter:
+
+            batch : Size of each batch passed to `func`. Defaults to 1000.
+
+        Example:
+        ```py
+        chain = chain.batch_map(
+            sqrt=lambda size: np.sqrt(size),
+            output=float
+        )
+        chain.save("new_dataset")
+        ```
+        """
+        udf_obj = self._udf_to_obj(BatchMapper, func, params, output, signal_map)
+        chain = DatasetQuery.add_signals(
+            self,
+            udf_obj.to_udf_wrapper(batch),
+            **self._settings.to_dict(),
+        )
+
+        return chain.add_schema(udf_obj.output).reset_settings(self._settings)
+
     def _udf_to_obj(
         self,
         target_class: type[UDFBase],

@@ -1067,7 +1122,7 @@ class DataChain(DatasetQuery):
         def _func_fr() -> Iterator[tuple_type]:  # type: ignore[valid-type]
             yield from tuples

-        chain = DataChain.
+        chain = DataChain.from_records(DataChain.DEFAULT_FILE_RECORD, session=session)
         if object_name:
             output = {object_name: DataChain._dict_to_data_model(object_name, output)}  # type: ignore[arg-type]
         return chain.gen(_func_fr, output=output)

@@ -1176,6 +1231,7 @@ class DataChain(DatasetQuery):
         output: OutputType = None,
         object_name: str = "",
         model_name: str = "",
+        source: bool = True,
         nrows: Optional[int] = None,
         **kwargs,
     ) -> "DataChain":

@@ -1187,8 +1243,9 @@ class DataChain(DatasetQuery):
                 case types will be inferred.
             object_name : Generated object column name.
             model_name : Generated model name.
-
+            source : Whether to include info about the source file.
             nrows : Optional row limit.
+            kwargs : Parameters to pass to pyarrow.dataset.dataset.

         Example:
             Reading a json lines file:

@@ -1215,18 +1272,24 @@ class DataChain(DatasetQuery):
         except ValueError as e:
             raise DatasetPrepareError(self.name, e) from e

+        if isinstance(output, dict):
+            model_name = model_name or object_name or ""
+            model = DataChain._dict_to_data_model(model_name, output)
+        else:
+            model = output  # type: ignore[assignment]
+
         if object_name:
-
-            model_name = model_name or object_name
-            output = DataChain._dict_to_data_model(model_name, output)
-            output = {object_name: output}  # type: ignore[dict-item]
+            output = {object_name: model}  # type: ignore[dict-item]
         elif isinstance(output, type(BaseModel)):
             output = {
                 name: info.annotation  # type: ignore[misc]
                 for name, info in output.model_fields.items()
             }
-
-
+        if source:
+            output = {"source": IndexedFile} | output  # type: ignore[assignment,operator]
+        return self.gen(
+            ArrowGenerator(schema, model, source, nrows, **kwargs), output=output
+        )

     @staticmethod
     def _dict_to_data_model(

@@ -1245,10 +1308,10 @@ class DataChain(DatasetQuery):
         path,
         delimiter: str = ",",
         header: bool = True,
-        column_names: Optional[list[str]] = None,
         output: OutputType = None,
         object_name: str = "",
         model_name: str = "",
+        source: bool = True,
         nrows=None,
         **kwargs,
     ) -> "DataChain":

@@ -1264,6 +1327,7 @@ class DataChain(DatasetQuery):
                 case types will be inferred.
             object_name : Created object column name.
             model_name : Generated model name.
+            source : Whether to include info about the source file.
             nrows : Optional row limit.

         Example:

@@ -1282,6 +1346,7 @@ class DataChain(DatasetQuery):

         chain = DataChain.from_storage(path, **kwargs)

+        column_names = None
         if not header:
             if not output:
                 msg = "error parsing csv - provide output if no header"

@@ -1303,6 +1368,7 @@ class DataChain(DatasetQuery):
             output=output,
             object_name=object_name,
             model_name=model_name,
+            source=source,
             nrows=nrows,
             format=format,
         )

@@ -1315,6 +1381,7 @@ class DataChain(DatasetQuery):
         output: Optional[dict[str, DataType]] = None,
         object_name: str = "",
         model_name: str = "",
+        source: bool = True,
         nrows=None,
         **kwargs,
     ) -> "DataChain":

@@ -1327,6 +1394,7 @@ class DataChain(DatasetQuery):
             output : Dictionary defining column names and their corresponding types.
             object_name : Created object column name.
             model_name : Generated model name.
+            source : Whether to include info about the source file.
             nrows : Optional row limit.

         Example:

@@ -1345,6 +1413,7 @@ class DataChain(DatasetQuery):
             output=output,
             object_name=object_name,
             model_name=model_name,
+            source=source,
             nrows=None,
             format="parquet",
             partitioning=partitioning,

@@ -1370,13 +1439,14 @@ class DataChain(DatasetQuery):
         )

     @classmethod
-    def
+    def from_records(
         cls,
         to_insert: Optional[Union[dict, list[dict]]],
         session: Optional[Session] = None,
     ) -> "DataChain":
-        """Create
-        generating a
+        """Create a DataChain from the provided records. This method can be used for
+        programmatically generating a chain in contrast of reading data from storages
+        or other sources.

         Parameters:
             to_insert : records (or a single record) to insert. Each record is

@@ -1384,8 +1454,8 @@ class DataChain(DatasetQuery):

         Example:
             ```py
-            empty = DataChain.
-            single_record = DataChain.
+            empty = DataChain.from_records()
+            single_record = DataChain.from_records(DataChain.DEFAULT_FILE_RECORD)
             ```
         """
         session = Session.get(session)

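A library-independent sketch of the N:N contract that the new `batch_map()` documents: the function receives a batch of inputs and must return one output per input row. The helper names here (`batched`, `sqrt_batch`) are made up for illustration and are not datachain APIs.

```py
import math


def batched(iterable, size):
    """Yield lists of at most `size` items."""
    chunk = []
    for item in iterable:
        chunk.append(item)
        if len(chunk) == size:
            yield chunk
            chunk = []
    if chunk:
        yield chunk


def sqrt_batch(sizes):
    # The batch UDF: one output value per input value.
    return [math.sqrt(s) for s in sizes]


rows = [1, 4, 9, 16, 25]
out = [y for chunk in batched(rows, 2) for y in sqrt_batch(chunk)]
assert len(out) == len(rows)  # N:N relationship
print(out)  # [1.0, 2.0, 3.0, 4.0, 5.0]
```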
datachain/lib/meta_formats.py
CHANGED
@@ -101,7 +101,7 @@ def read_meta(  # noqa: C901
     schema_from=None,
     meta_type="json",
     jmespath=None,
-
+    print_schema=False,
     model_name=None,
     nrows=None,
 ) -> Callable:

@@ -129,7 +129,7 @@ def read_meta(  # noqa: C901
     model_output = captured_output.getvalue()
     captured_output.close()

-    if
+    if print_schema:
         print(f"{model_output}")
     # Below 'spec' should be a dynamically converted DataModel from Pydantic
     if not spec:

@@ -153,13 +153,13 @@ def read_meta(  # noqa: C901
         jmespath=jmespath,
         nrows=nrows,
     ) -> Iterator[spec]:
-        def validator(json_object: dict) -> spec:
+        def validator(json_object: dict, nrow=0) -> spec:
             json_string = json.dumps(json_object)
             try:
                 data_instance = data_model.model_validate_json(json_string)
                 yield data_instance
             except ValidationError as e:
-                print(f"Validation error occurred in file {file.name}:", e)
+                print(f"Validation error occurred in row {nrow} file {file.name}:", e)

         if meta_type == "csv":
             with (

@@ -185,7 +185,7 @@ def read_meta(  # noqa: C901
                 nrow = nrow + 1
                 if nrows is not None and nrow > nrows:
                     return
-                yield from validator(json_dict)
+                yield from validator(json_dict, nrow)

         if meta_type == "jsonl":
             try:

@@ -198,7 +198,7 @@ def read_meta(  # noqa: C901
                     return
                 json_object = process_json(data_string, jmespath)
                 data_string = fd.readline()
-                yield from validator(json_object)
+                yield from validator(json_object, nrow)
             except OSError as e:
                 print(f"An unexpected file error occurred in file {file.name}: {e}")

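A small sketch (plain pydantic, with a hypothetical `Record` model) of the row-numbered error reporting added to `validator`: when a JSONL row fails validation, the message now identifies the row rather than only the file.

```py
from pydantic import BaseModel, ValidationError


class Record(BaseModel):  # hypothetical row model
    id: int
    name: str


lines = ['{"id": 1, "name": "a"}', '{"id": "oops", "name": "b"}']
for nrow, line in enumerate(lines, start=1):
    try:
        Record.model_validate_json(line)
    except ValidationError as e:
        # Report the failing row number, mirroring the change above.
        print(f"Validation error occurred in row {nrow}:", e)
```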
datachain/lib/settings.py
CHANGED
@@ -7,11 +7,8 @@ class SettingsError(DataChainParamsError):


 class Settings:
-    def __init__(
-        self, cache=None, batch=None, parallel=None, workers=None, min_task_size=None
-    ):
+    def __init__(self, cache=None, parallel=None, workers=None, min_task_size=None):
         self._cache = cache
-        self._batch = batch
         self.parallel = parallel
         self._workers = workers
         self.min_task_size = min_task_size

@@ -22,12 +19,6 @@ class Settings:
                 f" while {cache.__class__.__name__} was given"
             )

-        if not isinstance(batch, int) and batch is not None:
-            raise SettingsError(
-                "'batch' argument must be int or None"
-                f" while {batch.__class__.__name__} was given"
-            )
-
         if not isinstance(parallel, int) and parallel is not None:
             raise SettingsError(
                 "'parallel' argument must be int or None"

@@ -54,10 +45,6 @@ class Settings:
     def cache(self):
         return self._cache if self._cache is not None else False

-    @property
-    def batch(self):
-        return self._batch if self._batch is not None else 1
-
     @property
     def workers(self):
         return self._workers if self._workers is not None else False

@@ -66,8 +53,6 @@ class Settings:
         res = {}
         if self._cache is not None:
             res["cache"] = self.cache
-        if self._batch is not None:
-            res["batch"] = self.batch
         if self.parallel is not None:
             res["parallel"] = self.parallel
         if self._workers is not None:

@@ -78,7 +63,6 @@ class Settings:

     def add(self, settings: "Settings"):
         self._cache = settings._cache or self._cache
-        self._batch = settings._batch or self._batch
         self.parallel = settings.parallel or self.parallel
         self._workers = settings._workers or self._workers
         self.min_task_size = settings.min_task_size or self.min_task_size

datachain/lib/signal_schema.py
CHANGED
@@ -243,8 +243,11 @@ class SignalSchema:
         curr_type = None
         i = 0
         while curr_tree is not None and i < len(path):
-            if val := curr_tree.get(path[i]
+            if val := curr_tree.get(path[i]):
                 curr_type, curr_tree = val
+            elif i == 0 and len(path) > 1 and (val := curr_tree.get(".".join(path))):
+                curr_type, curr_tree = val
+                break
             else:
                 curr_type = None
             i += 1

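A standalone sketch of the lookup fallback added here: walk the tree by path segments, but if the first segment misses, try the whole dotted name as a single key. The example tree is hypothetical.

```py
def resolve(tree: dict, path: list[str]):
    """Walk `tree` by path segments, falling back to the dotted name as one key."""
    node = tree
    for i, key in enumerate(path):
        if key in node:
            node = node[key]
        elif i == 0 and len(path) > 1 and ".".join(path) in node:
            return node[".".join(path)]
        else:
            return None
    return node


tree = {"file": {"path": str}, "emb.vector": list}  # hypothetical signal tree
assert resolve(tree, ["file", "path"]) is str
assert resolve(tree, ["emb", "vector"]) is list
```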
datachain/lib/udf.py
CHANGED
@@ -225,11 +225,10 @@ class UDFBase(AbstractUDF):
     def __call__(self, *rows, cache, download_cb):
         if self.is_input_grouped:
             objs = self._parse_grouped_rows(rows[0], cache, download_cb)
+        elif self.is_input_batched:
+            objs = zip(*self._parse_rows(rows[0], cache, download_cb))
         else:
-            objs = self._parse_rows(rows, cache, download_cb)
-
-        if not self.is_input_batched:
-            objs = objs[0]
+            objs = self._parse_rows([rows], cache, download_cb)[0]

         result_objs = self.process_safe(objs)

@@ -259,17 +258,24 @@ class UDFBase(AbstractUDF):

         if not self.is_output_batched:
             res = list(res)
-            assert
-
-            )
+            assert (
+                len(res) == 1
+            ), f"{self.name} returns {len(res)} rows while it's not batched"
             if isinstance(res[0], tuple):
                 res = res[0]
+        elif (
+            self.is_input_batched
+            and self.is_output_batched
+            and not self.is_input_grouped
+        ):
+            res = list(res)
+            assert len(res) == len(
+                rows[0]
+            ), f"{self.name} returns {len(res)} rows while len(rows[0]) expected"

         return res

     def _parse_rows(self, rows, cache, download_cb):
-        if not self.is_input_batched:
-            rows = [rows]
         objs = []
         for row in rows:
             obj_row = self.params.row_to_objs(row)

@@ -330,7 +336,9 @@ class Mapper(UDFBase):
     """Inherit from this class to pass to `DataChain.map()`."""


-class BatchMapper(
+class BatchMapper(UDFBase):
+    """Inherit from this class to pass to `DataChain.batch_map()`."""
+
     is_input_batched = True
     is_output_batched = True

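The batched-input branch added to `__call__` transposes the parsed rows with `zip(*...)`, so the UDF sees one sequence per parameter rather than per-row tuples; a tiny illustration:

```py
# Parsed batch of rows, one tuple per row: (name, size).
rows = [("a.txt", 1), ("b.txt", 2), ("c.txt", 3)]

# zip(*rows) transposes rows into one sequence per parameter.
names, sizes = zip(*rows)
assert names == ("a.txt", "b.txt", "c.txt")
assert sizes == (1, 2, 3)
```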
datachain/query/dataset.py
CHANGED
@@ -262,9 +262,7 @@ class DatasetDiffOperation(Step):
         temp_tables.extend(self.dq.temp_table_names)

         # creating temp table that will hold subtract results
-        temp_table_name = self.catalog.warehouse.
-            6
-        )
+        temp_table_name = self.catalog.warehouse.temp_table_name()
         temp_tables.append(temp_table_name)

         columns = [

@@ -448,9 +446,6 @@ class UDFStep(Step, ABC):
         to select
         """

-    def udf_table_name(self) -> str:
-        return self.catalog.warehouse.UDF_TABLE_NAME_PREFIX + _random_string(6)
-
     def populate_udf_table(self, udf_table: "Table", query: Select) -> None:
         use_partitioning = self.partition_by is not None
         batching = self.udf.properties.get_batching(use_partitioning)

@@ -574,9 +569,7 @@
             list_partition_by = [self.partition_by]

         # create table with partitions
-        tbl = self.catalog.warehouse.create_udf_table(
-            self.udf_table_name(), partition_columns()
-        )
+        tbl = self.catalog.warehouse.create_udf_table(partition_columns())

         # fill table with partitions
         cols = [

@@ -638,37 +631,12 @@ class UDFSignal(UDFStep):
             for (col_name, col_type) in self.udf.output.items()
         ]

-        return self.catalog.warehouse.create_udf_table(
-            self.udf_table_name(), udf_output_columns
-        )
-
-    def create_pre_udf_table(self, query: Select) -> "Table":
-        columns = [
-            sqlalchemy.Column(c.name, c.type)
-            for c in query.selected_columns
-            if c.name != "sys__id"
-        ]
-        table = self.catalog.warehouse.create_udf_table(self.udf_table_name(), columns)
-        select_q = query.with_only_columns(
-            *[c for c in query.selected_columns if c.name != "sys__id"]
-        )
-
-        # if there is order by clause we need row_number to preserve order
-        # if there is no order by clause we still need row_number to generate
-        # unique ids as uniqueness is important for this table
-        select_q = select_q.add_columns(
-            f.row_number().over(order_by=select_q._order_by_clauses).label("sys__id")
-        )
-
-        self.catalog.warehouse.db.execute(
-            table.insert().from_select(list(select_q.selected_columns), select_q)
-        )
-        return table
+        return self.catalog.warehouse.create_udf_table(udf_output_columns)

     def process_input_query(self, query: Select) -> tuple[Select, list["Table"]]:
         if os.getenv("DATACHAIN_DISABLE_QUERY_CACHE", "") not in ("", "0"):
             return query, []
-        table = self.create_pre_udf_table(query)
+        table = self.catalog.warehouse.create_pre_udf_table(query)
         q: Select = sqlalchemy.select(*table.c)
         if query._order_by_clauses:
             # we are adding ordering only if it's explicitly added by user in

@@ -732,7 +700,7 @@ class RowGenerator(UDFStep):
     def create_udf_table(self, query: Select) -> "Table":
         warehouse = self.catalog.warehouse

-        table_name = self.udf_table_name()
+        table_name = self.catalog.warehouse.udf_table_name()
         columns: tuple[Column, ...] = tuple(
             Column(name, typ) for name, typ in self.udf.output.items()
         )

@@ -1233,10 +1201,10 @@ class DatasetQuery:
         # implementations, as errors may close or render unusable the existing
         # connections.
         metastore = self.catalog.metastore.clone(use_new_connection=True)
-        metastore.
+        metastore.cleanup_tables(self.temp_table_names)
         metastore.close()
         warehouse = self.catalog.warehouse.clone(use_new_connection=True)
-        warehouse.
+        warehouse.cleanup_tables(self.temp_table_names)
         warehouse.close()
         self.temp_table_names = []

@@ -1415,6 +1383,9 @@
     @detach
     def limit(self, n: int) -> "Self":
         query = self.clone(new_table=False)
+        for step in query.steps:
+            if isinstance(step, SQLLimit) and step.n < n:
+                return query
         query.steps.append(SQLLimit(n))
         return query

@@ -1802,10 +1773,3 @@ def query_wrapper(dataset_query: DatasetQuery) -> DatasetQuery:

     _send_result(dataset_query)
     return dataset_query
-
-
-def _random_string(length: int) -> str:
-    return "".join(
-        random.choice(string.ascii_letters + string.digits)  # noqa: S311
-        for i in range(length)
-    )

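A toy illustration of the idea behind the `limit()` short-circuit above: once a smaller limit is already in place, adding a larger one changes nothing, so the effective limit is simply the minimum. This is a sketch, not the `DatasetQuery` implementation.

```py
def effective_limit(limits):
    """The smallest limit wins; adding a larger one is a no-op."""
    result = None
    for n in limits:
        if result is None or n < result:
            result = n
    return result


rows = list(range(10))
assert rows[: effective_limit([3, 5])] == [0, 1, 2]
assert effective_limit([5]) == effective_limit([5, 8])
```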
datachain/sql/types.py
CHANGED
@@ -12,6 +12,7 @@ for sqlite we can use `sqlite.register_converter`
 ( https://docs.python.org/3/library/sqlite3.html#sqlite3.register_converter )
 """

+import json
 from datetime import datetime
 from types import MappingProxyType
 from typing import Any, Union

@@ -247,7 +248,10 @@ class Array(SQLType):
         return type_defaults(dialect).array()

     def on_read_convert(self, value, dialect):
-
+        r = read_converter(dialect).array(value, self.item_type, dialect)
+        if isinstance(self.item_type, JSON):
+            r = [json.loads(item) if isinstance(item, str) else item for item in r]
+        return r


 class JSON(SQLType):

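A minimal sketch of the read-side conversion added to `Array.on_read_convert`: when the array's item type is JSON, string items coming back from the database are decoded into Python objects, while already-decoded items pass through.

```py
import json


def convert_json_array(values):
    """Decode JSON strings inside an array column; leave other items untouched."""
    return [json.loads(v) if isinstance(v, str) else v for v in values]


assert convert_json_array(['{"a": 1}', {"b": 2}]) == [{"a": 1}, {"b": 2}]
```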
|
@@ -2,7 +2,7 @@ datachain/__init__.py,sha256=GeyhE-5LgfJav2OKYGaieP2lBvf2Gm-ihj7thnK9zjI,800
|
|
|
2
2
|
datachain/__main__.py,sha256=hG3Y4ARGEqe1AWwNMd259rBlqtphx1Wk39YbueQ0yV8,91
|
|
3
3
|
datachain/asyn.py,sha256=CKCFQJ0CbB3r04S7mUTXxriKzPnOvdUaVPXjM8vCtJw,7644
|
|
4
4
|
datachain/cache.py,sha256=N6PCEFJlWRpq7f_zeBNoaURFCJFAV7ibsLJqyiMHbBg,4207
|
|
5
|
-
datachain/cli.py,sha256=
|
|
5
|
+
datachain/cli.py,sha256=DbmI1sXs7-KCQz6RdLE_JAp3XO3yrTSRJ71LdUzx-XE,33099
|
|
6
6
|
datachain/cli_utils.py,sha256=jrn9ejGXjybeO1ur3fjdSiAyCHZrX0qsLLbJzN9ErPM,2418
|
|
7
7
|
datachain/config.py,sha256=PfC7W5yO6HFO6-iMB4YB-0RR88LPiGmD6sS_SfVbGso,1979
|
|
8
8
|
datachain/dataset.py,sha256=MZezyuJWNj_3PEtzr0epPMNyWAOTrhTSPI5FmemV6L4,14470
|
|
@@ -17,7 +17,7 @@ datachain/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
|
17
17
|
datachain/storage.py,sha256=RiSJLYdHUjnrEWkLBKPcETHpAxld_B2WxLg711t0aZI,3733
|
|
18
18
|
datachain/utils.py,sha256=kgH5NPj47eC_KrFTd6ZS206lKVhnJVFt5XsqkK6ppTc,12483
|
|
19
19
|
datachain/catalog/__init__.py,sha256=g2iAAFx_gEIrqshXlhSEbrc8qDaEH11cjU40n3CHDz4,409
|
|
20
|
-
datachain/catalog/catalog.py,sha256=
|
|
20
|
+
datachain/catalog/catalog.py,sha256=LZo9pIgi_HOUWpxX1c7RMt5OnrlDHXx2YpL5oP8X0kk,80397
|
|
21
21
|
datachain/catalog/datasource.py,sha256=D-VWIVDCM10A8sQavLhRXdYSCG7F4o4ifswEF80_NAQ,1412
|
|
22
22
|
datachain/catalog/loader.py,sha256=GJ8zhEYkC7TuaPzCsjJQ4LtTdECu-wwYzC12MikPOMQ,7307
|
|
23
23
|
datachain/catalog/subclass.py,sha256=B5R0qxeTYEyVAAPM1RutBPSoXZc8L5mVVZeSGXki9Sw,2096
|
|
@@ -32,41 +32,41 @@ datachain/data_storage/__init__.py,sha256=cEOJpyu1JDZtfUupYucCDNFI6e5Wmp_Oyzq6rZ
|
|
|
32
32
|
datachain/data_storage/db_engine.py,sha256=rgBuqJ-M1j5QyqiUQuJRewctuvRRj8LBDL54-aPEFxE,3287
|
|
33
33
|
datachain/data_storage/id_generator.py,sha256=VlDALKijggegAnNMJwuMETJgnLoPYxpkrkld5DNTPQw,3839
|
|
34
34
|
datachain/data_storage/job.py,sha256=w-7spowjkOa1P5fUVtJou3OltT0L48P0RYWZ9rSJ9-s,383
|
|
35
|
-
datachain/data_storage/metastore.py,sha256=
|
|
35
|
+
datachain/data_storage/metastore.py,sha256=ody-hWyrisGuNlzy24bc7QBqPXWIg64NcucIhZYronk,54842
|
|
36
36
|
datachain/data_storage/schema.py,sha256=FQvt5MUMSnI5ZAE7Nthae4aaJpt8JC4nH8KiWDuhJkk,8135
|
|
37
37
|
datachain/data_storage/serializer.py,sha256=6G2YtOFqqDzJf1KbvZraKGXl2XHZyVml2krunWUum5o,927
|
|
38
|
-
datachain/data_storage/sqlite.py,sha256=
|
|
39
|
-
datachain/data_storage/warehouse.py,sha256=
|
|
38
|
+
datachain/data_storage/sqlite.py,sha256=w0d_cZ2u9LpQYFFXll22mnxHaxPOoJdHlsKAZmONQpA,25605
|
|
39
|
+
datachain/data_storage/warehouse.py,sha256=3iD946WXgGxohZ5lagmwydFZr7j7RceZW423QXU_7_U,33120
|
|
40
40
|
datachain/lib/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
41
|
-
datachain/lib/arrow.py,sha256=
|
|
41
|
+
datachain/lib/arrow.py,sha256=9C5AVH6tLo9hwzav-1tLLnmWP-3_SReYCOfcOC54pu0,4437
|
|
42
42
|
datachain/lib/clip.py,sha256=16u4b_y2Y15nUS2UN_8ximMo6r_-_4IQpmct2ol-e-g,5730
|
|
43
|
-
datachain/lib/data_model.py,sha256=
|
|
43
|
+
datachain/lib/data_model.py,sha256=qfTtQNncS5pt9SvXdMEa5kClniaT6XBGBfO7onEz2TI,1632
|
|
44
44
|
datachain/lib/dataset_info.py,sha256=lONGr71ozo1DS4CQEhnpKORaU4qFb6Ketv8Xm8CVm2U,2188
|
|
45
|
-
datachain/lib/dc.py,sha256=
|
|
45
|
+
datachain/lib/dc.py,sha256=6RtwA7MC3hosxi9RBgpOXjkv46SdN99g9N_u4mCDUUo,56071
|
|
46
46
|
datachain/lib/file.py,sha256=n9GBmZ1CjzDjHkbUBsUrs8JOJrAoh3MV2Cc8hBkex20,11957
|
|
47
47
|
datachain/lib/image.py,sha256=TgYhRhzd4nkytfFMeykQkPyzqb5Le_-tU81unVMPn4Q,2328
|
|
48
|
-
datachain/lib/meta_formats.py,sha256=
|
|
48
|
+
datachain/lib/meta_formats.py,sha256=jlSYWRUeDMjun_YCsQ2JxyaDJpEpokzHDPmKUAoCXnU,7034
|
|
49
49
|
datachain/lib/model_store.py,sha256=c4USXsBBjrGH8VOh4seIgOiav-qHOwdoixtxfLgU63c,2409
|
|
50
50
|
datachain/lib/pytorch.py,sha256=9PsypKseyKfIimTmTQOgb-pbNXgeeAHLdlWx0qRPULY,5660
|
|
51
|
-
datachain/lib/settings.py,sha256=
|
|
52
|
-
datachain/lib/signal_schema.py,sha256=
|
|
51
|
+
datachain/lib/settings.py,sha256=39thOpYJw-zPirzeNO6pmRC2vPrQvt4eBsw1xLWDFsw,2344
|
|
52
|
+
datachain/lib/signal_schema.py,sha256=XQTINSN_FJK76Jn8qd03g6J0cum58knP8U7Iuw-zKyU,14704
|
|
53
53
|
datachain/lib/text.py,sha256=dVe2Ilc_gW2EV0kun0UwegiCkapWcd20cef7CgINWHU,1083
|
|
54
|
-
datachain/lib/udf.py,sha256=
|
|
54
|
+
datachain/lib/udf.py,sha256=IjuDt2B8E3xEHhcJnaK_ZhmivdrOYPXz5uf7ylpktws,11815
|
|
55
55
|
datachain/lib/udf_signature.py,sha256=gMStcEeYJka5M6cg50Z9orC6y6HzCAJ3MkFqqn1fjZg,7137
|
|
56
56
|
datachain/lib/utils.py,sha256=5-kJlAZE0D9nXXweAjo7-SP_AWGo28feaDByONYaooQ,463
|
|
57
57
|
datachain/lib/vfile.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
58
58
|
datachain/lib/webdataset.py,sha256=nIa6ubv94CwnATeeSdE7f_F9Zkz9LuBTfbXvFg3_-Ak,8295
|
|
59
59
|
datachain/lib/webdataset_laion.py,sha256=PQP6tQmUP7Xu9fPuAGK1JDBYA6T5UufYMUTGaxgspJA,2118
|
|
60
60
|
datachain/lib/convert/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
61
|
-
datachain/lib/convert/flatten.py,sha256=
|
|
62
|
-
datachain/lib/convert/python_to_sql.py,sha256=
|
|
61
|
+
datachain/lib/convert/flatten.py,sha256=YMoC00BqEy3zSpvCp6Q0DfxihuPmgjUJj1g2cesWGPs,1790
|
|
62
|
+
datachain/lib/convert/python_to_sql.py,sha256=4gplGlr_Kg-Z40OpJUzJiarDWj7pwbUOk-dPOYYCJ9Q,2629
|
|
63
63
|
datachain/lib/convert/sql_to_python.py,sha256=HK414fexSQ4Ur-OY7_pKvDKEGdtos1CeeAFa4RxH4nU,532
|
|
64
64
|
datachain/lib/convert/unflatten.py,sha256=Ogvh_5wg2f38_At_1lN0D_e2uZOOpYEvwvB2xdq56Tw,2012
|
|
65
65
|
datachain/lib/convert/values_to_tuples.py,sha256=aVoHWMOUGLAiS6_BBwKJqVIne91VffOW6-dWyNE7oHg,3715
|
|
66
66
|
datachain/query/__init__.py,sha256=tv-spkjUCYamMN9ys_90scYrZ8kJ7C7d1MTYVmxGtk4,325
|
|
67
67
|
datachain/query/batch.py,sha256=j-_ZcuQra2Ro3Wj4crtqQCg-7xuv-p84hr4QHdvT7as,3479
|
|
68
68
|
datachain/query/builtins.py,sha256=ZKNs49t8Oa_OaboCBIEqtXZt7c1Qe9OR_C_HpoDriIU,2781
|
|
69
|
-
datachain/query/dataset.py,sha256=
|
|
69
|
+
datachain/query/dataset.py,sha256=iTz3c5nJ-WmoQ5zcvKGT9ly6xVKJtD_fk76LA7zecWk,60164
|
|
70
70
|
datachain/query/dispatch.py,sha256=oGX9ZuoKWPB_EyqAZD_eULcO3OejY44_keSmFS6SHT0,13315
|
|
71
71
|
datachain/query/metrics.py,sha256=vsECqbZfoSDBnvC3GQlziKXmISVYDLgHP1fMPEOtKyo,640
|
|
72
72
|
datachain/query/params.py,sha256=O_j89mjYRLOwWNhYZl-z7mi-rkdP7WyFmaDufsdTryE,863
|
|
@@ -77,7 +77,7 @@ datachain/remote/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,
|
|
|
77
77
|
datachain/remote/studio.py,sha256=f5s6qSZ9uB4URGUoU_8_W1KZRRQQVSm6cgEBkBUEfuE,7226
|
|
78
78
|
datachain/sql/__init__.py,sha256=A2djrbQwSMUZZEIKGnm-mnRA-NDSbiDJNpAmmwGNyIo,303
|
|
79
79
|
datachain/sql/selectable.py,sha256=fBM-wS1TUA42kVEAAiwqGtibIevyZAEritwt8PZGyLQ,1589
|
|
80
|
-
datachain/sql/types.py,sha256=
|
|
80
|
+
datachain/sql/types.py,sha256=SShudhdIpdfTKDxWDDqOajYRkTCkIgQbilA94g4i-4E,10389
|
|
81
81
|
datachain/sql/utils.py,sha256=rzlJw08etivdrcuQPqNVvVWhuVSyUPUQEEc6DOhu258,818
|
|
82
82
|
datachain/sql/default/__init__.py,sha256=XQ2cEZpzWiABqjV-6yYHUBGI9vN_UHxbxZENESmVAWw,45
|
|
83
83
|
datachain/sql/default/base.py,sha256=h44005q3qtMc9cjWmRufWwcBr5CfK_dnvG4IrcSQs_8,536
|
|
@@ -92,9 +92,9 @@ datachain/sql/sqlite/base.py,sha256=Jb1csbIARjEvwbylnvgNA7ChozSyoL3CQzOGBUf8QAw,
|
|
|
92
92
|
datachain/sql/sqlite/types.py,sha256=yzvp0sXSEoEYXs6zaYC_2YubarQoZH-MiUNXcpuEP4s,1573
|
|
93
93
|
datachain/sql/sqlite/vector.py,sha256=ncW4eu2FlJhrP_CIpsvtkUabZlQdl2D5Lgwy_cbfqR0,469
|
|
94
94
|
datachain/torch/__init__.py,sha256=gIS74PoEPy4TB3X6vx9nLO0Y3sLJzsA8ckn8pRWihJM,579
|
|
95
|
-
datachain-0.2.
|
|
96
|
-
datachain-0.2.
|
|
97
|
-
datachain-0.2.
|
|
98
|
-
datachain-0.2.
|
|
99
|
-
datachain-0.2.
|
|
100
|
-
datachain-0.2.
|
|
95
|
+
datachain-0.2.16.dist-info/LICENSE,sha256=8DnqK5yoPI_E50bEg_zsHKZHY2HqPy4rYN338BHQaRA,11344
|
|
96
|
+
datachain-0.2.16.dist-info/METADATA,sha256=1f326fK-ZnS0nPvETuUj9PaI4R5SatpGVDIsQiJ0OvM,14577
|
|
97
|
+
datachain-0.2.16.dist-info/WHEEL,sha256=R0nc6qTxuoLk7ShA2_Y-UWkN8ZdfDBG2B6Eqpz2WXbs,91
|
|
98
|
+
datachain-0.2.16.dist-info/entry_points.txt,sha256=0GMJS6B_KWq0m3VT98vQI2YZodAMkn4uReZ_okga9R4,49
|
|
99
|
+
datachain-0.2.16.dist-info/top_level.txt,sha256=lZPpdU_2jJABLNIg2kvEOBi8PtsYikbN1OdMLHk8bTg,10
|
|
100
|
+
datachain-0.2.16.dist-info/RECORD,,
|
|
File without changes
File without changes
File without changes