datachain 0.7.10__py3-none-any.whl → 0.8.0__py3-none-any.whl
This diff shows the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as published.
Potentially problematic release: this version of datachain may be problematic.
- datachain/catalog/catalog.py +53 -41
- datachain/cli.py +25 -3
- datachain/client/__init__.py +1 -2
- datachain/data_storage/sqlite.py +20 -6
- datachain/lib/dc.py +160 -110
- datachain/lib/diff.py +197 -0
- datachain/lib/file.py +2 -1
- datachain/lib/meta_formats.py +40 -43
- datachain/lib/pytorch.py +1 -5
- datachain/lib/signal_schema.py +28 -6
- datachain/query/dataset.py +5 -1
- datachain/remote/studio.py +53 -1
- datachain/studio.py +47 -2
- datachain/toolkit/split.py +19 -6
- {datachain-0.7.10.dist-info → datachain-0.8.0.dist-info}/METADATA +10 -10
- {datachain-0.7.10.dist-info → datachain-0.8.0.dist-info}/RECORD +20 -19
- {datachain-0.7.10.dist-info → datachain-0.8.0.dist-info}/LICENSE +0 -0
- {datachain-0.7.10.dist-info → datachain-0.8.0.dist-info}/WHEEL +0 -0
- {datachain-0.7.10.dist-info → datachain-0.8.0.dist-info}/entry_points.txt +0 -0
- {datachain-0.7.10.dist-info → datachain-0.8.0.dist-info}/top_level.txt +0 -0
datachain/lib/diff.py
ADDED
@@ -0,0 +1,197 @@
import random
import string
from collections.abc import Sequence
from typing import TYPE_CHECKING, Optional, Union

import sqlalchemy as sa

from datachain.lib.signal_schema import SignalSchema
from datachain.query.schema import Column
from datachain.sql.types import String

if TYPE_CHECKING:
    from datachain.lib.dc import DataChain


C = Column


def compare(  # noqa: PLR0912, PLR0915, C901
    left: "DataChain",
    right: "DataChain",
    on: Union[str, Sequence[str]],
    right_on: Optional[Union[str, Sequence[str]]] = None,
    compare: Optional[Union[str, Sequence[str]]] = None,
    right_compare: Optional[Union[str, Sequence[str]]] = None,
    added: bool = True,
    deleted: bool = True,
    modified: bool = True,
    same: bool = True,
    status_col: Optional[str] = None,
) -> "DataChain":
    """Comparing two chains by identifying rows that are added, deleted, modified
    or same"""
    dialect = left._query.dialect

    rname = "right_"

    def _rprefix(c: str, rc: str) -> str:
        """Returns prefix of right of two companion left - right columns
        from merge. If companion columns have the same name then prefix will
        be present in right column name, otherwise it won't.
        """
        return rname if c == rc else ""

    def _to_list(obj: Union[str, Sequence[str]]) -> list[str]:
        return [obj] if isinstance(obj, str) else list(obj)

    if on is None:
        raise ValueError("'on' must be specified")

    on = _to_list(on)
    if right_on:
        right_on = _to_list(right_on)
        if len(on) != len(right_on):
            raise ValueError("'on' and 'right_on' must be have the same length")

    if compare:
        compare = _to_list(compare)

    if right_compare:
        if not compare:
            raise ValueError("'compare' must be defined if 'right_compare' is defined")

        right_compare = _to_list(right_compare)
        if len(compare) != len(right_compare):
            raise ValueError(
                "'compare' and 'right_compare' must be have the same length"
            )

    if not any([added, deleted, modified, same]):
        raise ValueError(
            "At least one of added, deleted, modified, same flags must be set"
        )

    # we still need status column for internal implementation even if not
    # needed in output
    need_status_col = bool(status_col)
    status_col = status_col or "diff_" + "".join(
        random.choice(string.ascii_letters)  # noqa: S311
        for _ in range(10)
    )

    # calculate on and compare column names
    right_on = right_on or on
    cols = left.signals_schema.clone_without_sys_signals().db_signals()
    right_cols = right.signals_schema.clone_without_sys_signals().db_signals()

    on = left.signals_schema.resolve(*on).db_signals()  # type: ignore[assignment]
    right_on = right.signals_schema.resolve(*right_on).db_signals()  # type: ignore[assignment]
    if compare:
        right_compare = right_compare or compare
        compare = left.signals_schema.resolve(*compare).db_signals()  # type: ignore[assignment]
        right_compare = right.signals_schema.resolve(*right_compare).db_signals()  # type: ignore[assignment]
    elif not compare and len(cols) != len(right_cols):
        # here we will mark all rows that are not added or deleted as modified since
        # there was no explicit list of compare columns provided (meaning we need
        # to check all columns to determine if row is modified or same), but
        # the number of columns on left and right is not the same (one of the chains
        # have additional column)
        compare = None
        right_compare = None
    else:
        compare = [c for c in cols if c in right_cols]  # type: ignore[misc, assignment]
        right_compare = compare

    diff_cond = []

    if added:
        added_cond = sa.and_(
            *[
                C(c) == None  # noqa: E711
                for c in [f"{_rprefix(c, rc)}{rc}" for c, rc in zip(on, right_on)]
            ]
        )
        diff_cond.append((added_cond, "A"))
    if modified and compare:
        modified_cond = sa.or_(
            *[
                C(c) != C(f"{_rprefix(c, rc)}{rc}")
                for c, rc in zip(compare, right_compare)  # type: ignore[arg-type]
            ]
        )
        diff_cond.append((modified_cond, "M"))
    if same and compare:
        same_cond = sa.and_(
            *[
                C(c) == C(f"{_rprefix(c, rc)}{rc}")
                for c, rc in zip(compare, right_compare)  # type: ignore[arg-type]
            ]
        )
        diff_cond.append((same_cond, "S"))

    diff = sa.case(*diff_cond, else_=None if compare else "M").label(status_col)
    diff.type = String()

    left_right_merge = left.merge(
        right, on=on, right_on=right_on, inner=False, rname=rname
    )
    left_right_merge_select = left_right_merge._query.select(
        *(
            [C(c) for c in left_right_merge.signals_schema.db_signals("sys")]
            + [C(c) for c in on]
            + [C(c) for c in cols if c not in on]
            + [diff]
        )
    )

    diff_col = sa.literal("D").label(status_col)
    diff_col.type = String()

    right_left_merge = right.merge(
        left, on=right_on, right_on=on, inner=False, rname=rname
    ).filter(
        sa.and_(
            *[C(f"{_rprefix(c, rc)}{c}") == None for c, rc in zip(on, right_on)]  # noqa: E711
        )
    )

    def _default_val(chain: "DataChain", col: str):
        col_type = chain._query.column_types[col]  # type: ignore[index]
        val = sa.literal(col_type.default_value(dialect)).label(col)
        val.type = col_type()
        return val

    right_left_merge_select = right_left_merge._query.select(
        *(
            [C(c) for c in right_left_merge.signals_schema.db_signals("sys")]
            + [
                C(c) if c == rc else _default_val(left, c)
                for c, rc in zip(on, right_on)
            ]
            + [
                C(c) if c in right_cols else _default_val(left, c)  # type: ignore[arg-type]
                for c in cols
                if c not in on
            ]
            + [diff_col]
        )
    )

    if not deleted:
        res = left_right_merge_select
    elif deleted and not any([added, modified, same]):
        res = right_left_merge_select
    else:
        res = left_right_merge_select.union(right_left_merge_select)

    res = res.filter(C(status_col) != None)  # noqa: E711

    schema = left.signals_schema
    if need_status_col:
        res = res.select()
        schema = SignalSchema({status_col: str}) | schema
    else:
        res = res.select_except(C(status_col))

    return left._evolve(query=res, signal_schema=schema)
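The new compare() helper underpins chain diffing in 0.8.0. A minimal usage sketch, assuming two small in-memory chains keyed by an "id" column (the column names and values are illustrative, not taken from the diff):

from datachain.lib.dc import DataChain
from datachain.lib.diff import compare

left = DataChain.from_values(id=[1, 2, 3], value=["a", "b", "c"])
right = DataChain.from_values(id=[2, 3, 4], value=["b", "x", "z"])

# Rows matched on "id" are labeled "M" (modified) or "S" (same) based on the
# remaining columns; rows present on only one side become "A" or "D".
result = compare(left, right, on="id", status_col="diff")
result.show()

Passing status_col keeps the label column in the output; omitting it returns only the rows in the requested categories, without the label.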
datachain/lib/file.py
CHANGED
@@ -17,7 +17,6 @@ from urllib.request import url2pathname
 
 from fsspec.callbacks import DEFAULT_CALLBACK, Callback
 from PIL import Image
-from pyarrow.dataset import dataset
 from pydantic import Field, field_validator
 
 from datachain.client.fileslice import FileSlice

@@ -452,6 +451,8 @@ class ArrowRow(DataModel):
     @contextmanager
     def open(self):
         """Stream row contents from indexed file."""
+        from pyarrow.dataset import dataset
+
         if self.file._caching_enabled:
             self.file.ensure_cached()
             path = self.file.get_local_path()
datachain/lib/meta_formats.py
CHANGED
@@ -6,7 +6,6 @@ from collections.abc import Iterator
 from pathlib import Path
 from typing import Callable
 
-import datamodel_code_generator
 import jmespath as jsp
 from pydantic import BaseModel, ConfigDict, Field, ValidationError  # noqa: F401
 

@@ -39,36 +38,41 @@ def process_json(data_string, jmespath):
     return json_dict
 
 
-[2 removed lines not shown in the diff view]
+def gen_datamodel_code(
+    source_file, format="json", jmespath=None, model_name=None
+) -> str:
+    """Generates Python code with Pydantic models that corresponds
+    to the provided JSON, CSV, or JSONL file.
+    It support root JSON arrays (samples the first entry).
+    """
     data_string = ""
     # using uiid to get around issue #1617
     if not model_name:
         # comply with Python class names
         uid_str = str(generate_uuid()).replace("-", "")
-        model_name = f"Model{…
-[13 removed lines not shown in the diff view]
-        json_object = process_json(data_string, expr)
-        if data_type == "json" and isinstance(json_object, list):
+        model_name = f"Model{format}{uid_str}"
+
+    with source_file.open() as fd:  # CSV can be larger than memory
+        if format == "csv":
+            data_string += fd.readline().replace("\r", "")
+            data_string += fd.readline().replace("\r", "")
+        elif format == "jsonl":
+            data_string = fd.readline().replace("\r", "")
+        else:
+            data_string = fd.read()  # other meta must fit into RAM
+
+    if format in ("json", "jsonl"):
+        json_object = process_json(data_string, jmespath)
+        if format == "json" and isinstance(json_object, list):
             json_object = json_object[0]  # sample the 1st object from JSON array
-        if …
-[removed line not shown in the diff view]
+        if format == "jsonl":
+            format = "json"  # treat json line as plain JSON in auto-schema
         data_string = json.dumps(json_object)
 
+    import datamodel_code_generator
+
    input_file_types = {i.value: i for i in datamodel_code_generator.InputFileType}
-    input_file_type = input_file_types[…
+    input_file_type = input_file_types[format]
     with tempfile.TemporaryDirectory() as tmpdir:
         output = Path(tmpdir) / "model.py"
         datamodel_code_generator.generate(

@@ -94,36 +98,29 @@ spec = {model_name}
 def read_meta(  # noqa: C901
     spec=None,
     schema_from=None,
-[removed line not shown in the diff view]
+    format="json",
     jmespath=None,
-    print_schema=False,
     model_name=None,
     nrows=None,
 ) -> Callable:
     from datachain.lib.dc import DataChain
 
     if schema_from:
-[removed line not shown in the diff view]
-        DataChain.from_storage(schema_from, type="text")
-        .limit(1)
-        .map(  # dummy column created (#1615)
-            meta_schema=lambda file: read_schema(
-                file, data_type=meta_type, expr=jmespath, model_name=model_name
-            ),
-            output=str,
-        )
+        file = next(
+            DataChain.from_storage(schema_from, type="text").limit(1).collect("file")
         )
-[4 removed lines not shown in the diff view]
+        model_code = gen_datamodel_code(
+            file, format=format, jmespath=jmespath, model_name=model_name
+        )
+        assert isinstance(model_code, str)
+
     # Below 'spec' should be a dynamically converted DataModel from Pydantic
     if not spec:
         gl = globals()
-        exec(…
+        exec(model_code, gl)  # type: ignore[arg-type]  # noqa: S102
         spec = gl["spec"]
 
-    if not …
+    if not spec and not schema_from:
         raise ValueError(
             "Must provide a static schema in spec: or metadata sample in schema_from:"
         )

@@ -135,7 +132,7 @@ def read_meta(  # noqa: C901
     def parse_data(
         file: File,
         data_model=spec,
-[removed line not shown in the diff view]
+        format=format,
         jmespath=jmespath,
         nrows=nrows,
     ) -> Iterator[spec]:

@@ -147,7 +144,7 @@ def read_meta(  # noqa: C901
         except ValidationError as e:
             print(f"Validation error occurred in row {nrow} file {file.name}:", e)
 
-        if …
+        if format == "csv":
            with (
                file.open() as fd
            ):  # TODO: if schema is statically given, should allow CSV without headers

@@ -155,7 +152,7 @@ def read_meta(  # noqa: C901
            for row in reader:  # CSV can be larger than memory
                yield from validator(row)
 
-        if …
+        if format == "json":
            try:
                with file.open() as fd:  # JSON must fit into RAM
                    data_string = fd.read()

@@ -173,7 +170,7 @@ def read_meta(  # noqa: C901
                return
            yield from validator(json_dict, nrow)
 
-        if …
+        if format == "jsonl":
            try:
                nrow = 0
                with file.open() as fd:
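The reworked helpers can also be driven directly; a sketch of generating model code for a sample file, mirroring the read_meta() path above (the storage URI is a placeholder):

from datachain.lib.dc import DataChain
from datachain.lib.meta_formats import gen_datamodel_code

# Sample one file from storage and emit Python source defining the Pydantic model(s).
file = next(
    DataChain.from_storage("gs://bucket/meta.jsonl", type="text").limit(1).collect("file")
)
model_code = gen_datamodel_code(file, format="jsonl", jmespath=None)
print(model_code)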
datachain/lib/pytorch.py
CHANGED
@@ -7,7 +7,6 @@ from torch import float32
 from torch.distributed import get_rank, get_world_size
 from torch.utils.data import IterableDataset, get_worker_info
 from torchvision.transforms import v2
-from tqdm import tqdm
 
 from datachain import Session
 from datachain.asyn import AsyncMapper

@@ -112,10 +111,7 @@ class PytorchDataset(IterableDataset):
            from datachain.lib.udf import _prefetch_input
 
            rows = AsyncMapper(_prefetch_input, rows, workers=self.prefetch).iterate()
-
-        desc = f"Parsed PyTorch dataset for rank={total_rank} worker"
-        with tqdm(rows, desc=desc, unit=" rows", position=total_rank) as rows_it:
-            yield from map(self._process_row, rows_it)
+        yield from map(self._process_row, rows)
 
     def _process_row(self, row_features):
         row = []
datachain/lib/signal_schema.py
CHANGED
@@ -402,9 +402,20 @@ class SignalSchema:
            if ModelStore.is_pydantic(finfo.annotation):
                SignalSchema._set_file_stream(getattr(obj, field), catalog, cache)
 
-    def get_column_type(self, col_name: str) -> DataType:
+    def get_column_type(self, col_name: str, with_subtree: bool = False) -> DataType:
+        """
+        Returns column type by column name.
+
+        If `with_subtree` is True, then it will return the type of the column
+        even if it has a subtree (e.g. model with nested fields), otherwise it will
+        return the type of the column (standard type field, not the model).
+
+        If column is not found, raises `SignalResolvingError`.
+        """
         for path, _type, has_subtree, _ in self.get_flat_tree():
-            if not has_subtree and DEFAULT_DELIMITER.join(path) == col_name:
+            if (with_subtree or not has_subtree) and DEFAULT_DELIMITER.join(
+                path
+            ) == col_name:
                 return _type
         raise SignalResolvingError([col_name], "is not found")
 

@@ -492,14 +503,25 @@ class SignalSchema:
                # renaming existing signal
                del new_values[value.name]
                new_values[name] = self.values[value.name]
-[removed line not shown in the diff view]
+                continue
+            if isinstance(value, Column):
+                # adding new signal from existing signal field
+                try:
+                    new_values[name] = self.get_column_type(
+                        value.name, with_subtree=True
+                    )
+                    continue
+                except SignalResolvingError:
+                    pass
+            if isinstance(value, Func):
                # adding new signal with function
                new_values[name] = value.get_result_type(self)
-[removed line not shown in the diff view]
+                continue
+            if isinstance(value, ColumnElement):
                # adding new signal
                new_values[name] = sql_to_python(value)
-[2 removed lines not shown in the diff view]
+                continue
+            new_values[name] = value
 
         return SignalSchema(new_values)
 
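A hedged sketch of the extended lookup, assuming a storage-backed chain whose File signal is named "file" (the URI is a placeholder); this is the resolution path the new Column branch of mutate() relies on:

from datachain.lib.dc import DataChain

chain = DataChain.from_storage("gs://bucket/images/")  # placeholder URI
schema = chain.signals_schema

# With with_subtree=True the top-level model type is reachable; without it only
# leaf (non-model) columns resolve, as before.
file_type = schema.get_column_type("file", with_subtree=True)
path_type = schema.get_column_type("file__path")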
datachain/query/dataset.py
CHANGED
@@ -35,7 +35,6 @@ from sqlalchemy.sql.schema import TableClause
 from sqlalchemy.sql.selectable import Select
 
 from datachain.asyn import ASYNC_WORKERS, AsyncMapper, OrderedMapper
-from datachain.catalog import QUERY_SCRIPT_CANCELED_EXIT_CODE, get_catalog
 from datachain.data_storage.schema import (
     PARTITION_COLUMN_ID,
     partition_col_names,

@@ -394,6 +393,8 @@ class UDFStep(Step, ABC):
         """
 
     def populate_udf_table(self, udf_table: "Table", query: Select) -> None:
+        from datachain.catalog import QUERY_SCRIPT_CANCELED_EXIT_CODE
+
         use_partitioning = self.partition_by is not None
         batching = self.udf.get_batching(use_partitioning)
         workers = self.workers

@@ -1068,6 +1069,7 @@ class DatasetQuery:
         if "sys__id" in self.column_types:
             self.column_types.pop("sys__id")
         self.starting_step = QueryStep(self.catalog, name, self.version)
+        self.dialect = self.catalog.warehouse.db.dialect
 
     def __iter__(self):
         return iter(self.db_results())

@@ -1087,6 +1089,8 @@ class DatasetQuery:
     def delete(
         name: str, version: Optional[int] = None, catalog: Optional["Catalog"] = None
     ) -> None:
+        from datachain.catalog import get_catalog
+
         catalog = catalog or get_catalog()
         version = version or catalog.get_dataset(name).latest_version
         catalog.remove_dataset(name, version)
datachain/remote/studio.py
CHANGED
@@ -2,7 +2,7 @@ import base64
 import json
 import logging
 import os
-from collections.abc import Iterable, Iterator
+from collections.abc import AsyncIterator, Iterable, Iterator
 from datetime import datetime, timedelta, timezone
 from struct import unpack
 from typing import (

@@ -11,6 +11,9 @@ from typing import (
     Optional,
     TypeVar,
 )
+from urllib.parse import urlparse, urlunparse
+
+import websockets
 
 from datachain.config import Config
 from datachain.dataset import DatasetStats

@@ -22,6 +25,7 @@ LsData = Optional[list[dict[str, Any]]]
 DatasetInfoData = Optional[dict[str, Any]]
 DatasetStatsData = Optional[DatasetStats]
 DatasetRowsData = Optional[Iterable[dict[str, Any]]]
+DatasetJobVersionsData = Optional[dict[str, Any]]
 DatasetExportStatus = Optional[dict[str, Any]]
 DatasetExportSignedUrls = Optional[list[str]]
 FileUploadData = Optional[dict[str, Any]]

@@ -231,6 +235,40 @@ class StudioClient:
 
         return msgpack.ExtType(code, data)
 
+    async def tail_job_logs(self, job_id: str) -> AsyncIterator[dict]:
+        """
+        Follow job logs via websocket connection.
+
+        Args:
+            job_id: ID of the job to follow logs for
+
+        Yields:
+            Dict containing either job status updates or log messages
+        """
+        parsed_url = urlparse(self.url)
+        ws_url = urlunparse(
+            parsed_url._replace(scheme="wss" if parsed_url.scheme == "https" else "ws")
+        )
+        ws_url = f"{ws_url}/logs/follow/?job_id={job_id}&team_name={self.team}"
+
+        async with websockets.connect(
+            ws_url,
+            additional_headers={"Authorization": f"token {self.token}"},
+        ) as websocket:
+            while True:
+                try:
+                    message = await websocket.recv()
+                    data = json.loads(message)
+
+                    # Yield the parsed message data
+                    yield data
+
+                except websockets.exceptions.ConnectionClosed:
+                    break
+                except Exception as e:  # noqa: BLE001
+                    logger.error("Error receiving websocket message: %s", e)
+                    break
+
     def ls(self, paths: Iterable[str]) -> Iterator[tuple[str, Response[LsData]]]:
         # TODO: change LsData (response.data value) to be list of lists
         # to handle cases where a path will be expanded (i.e. globs)

@@ -302,6 +340,13 @@ class StudioClient:
             method="GET",
         )
 
+    def dataset_job_versions(self, job_id: str) -> Response[DatasetJobVersionsData]:
+        return self._send_request(
+            "datachain/datasets/dataset_job_versions",
+            {"job_id": job_id},
+            method="GET",
+        )
+
     def dataset_stats(self, name: str, version: int) -> Response[DatasetStatsData]:
         response = self._send_request(
             "datachain/datasets/stats",

@@ -359,3 +404,10 @@ class StudioClient:
             "requirements": requirements,
         }
         return self._send_request("datachain/job", data)
+
+    def cancel_job(
+        self,
+        job_id: str,
+    ) -> Response[JobData]:
+        url = f"datachain/job/{job_id}/cancel"
+        return self._send_request(url, data={}, method="POST")
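tail_job_logs() is an async generator, so callers drive it with asyncio. A minimal consumption sketch (the job id and team name are placeholders), essentially what datachain/studio.py does below:

import asyncio

from datachain.remote.studio import StudioClient

async def follow(job_id: str) -> None:
    client = StudioClient(team="my-team")  # team name is illustrative
    async for message in client.tail_job_logs(job_id):
        for log in message.get("logs", []):
            print(log["message"], end="")
        if "job" in message:
            print(f"\nJob status: {message['job']['status']}")

asyncio.run(follow("JOB_ID"))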
datachain/studio.py
CHANGED
@@ -1,3 +1,4 @@
+import asyncio
 import os
 from typing import TYPE_CHECKING, Optional
 

@@ -19,7 +20,7 @@ POST_LOGIN_MESSAGE = (
 )
 
 
-def process_studio_cli_args(args: "Namespace"):
+def process_studio_cli_args(args: "Namespace"):  # noqa: PLR0911
     if args.cmd == "login":
         return login(args)
     if args.cmd == "logout":

@@ -47,6 +48,9 @@ def process_studio_cli_args(args: "Namespace"):
            args.req_file,
        )
 
+    if args.cmd == "cancel":
+        return cancel_job(args.job_id, args.team)
+
     if args.cmd == "team":
         return set_team(args)
     raise DataChainError(f"Unknown command '{args.cmd}'.")

@@ -227,8 +231,34 @@ def create_job(
     if not response.data:
         raise DataChainError("Failed to create job")
 
-[removed line not shown in the diff view]
+    job_id = response.data.get("job", {}).get("id")
+    print(f"Job {job_id} created")
     print("Open the job in Studio at", response.data.get("job", {}).get("url"))
+    print("=" * 40)
+
+    # Sync usage
+    async def _run():
+        async for message in client.tail_job_logs(job_id):
+            if "logs" in message:
+                for log in message["logs"]:
+                    print(log["message"], end="")
+            elif "job" in message:
+                print(f"\n>>>> Job is now in {message['job']['status']} status.")
+
+    asyncio.run(_run())
+
+    response = client.dataset_job_versions(job_id)
+    if not response.ok:
+        raise_remote_error(response.message)
+
+    response_data = response.data
+    if response_data:
+        dataset_versions = response_data.get("dataset_versions", [])
+        print("\n\n>>>> Dataset versions created during the job:")
+        for version in dataset_versions:
+            print(f" - {version.get('dataset_name')}@v{version.get('version')}")
+    else:
+        print("No dataset versions created during the job.")
 
 
 def upload_files(client: StudioClient, files: list[str]) -> list[str]:

@@ -248,3 +278,18 @@ def upload_files(client: StudioClient, files: list[str]) -> list[str]:
        if file_id:
            file_ids.append(str(file_id))
    return file_ids
+
+
+def cancel_job(job_id: str, team_name: Optional[str]):
+    token = Config().read().get("studio", {}).get("token")
+    if not token:
+        raise DataChainError(
+            "Not logged in to Studio. Log in with 'datachain studio login'."
+        )
+
+    client = StudioClient(team=team_name)
+    response = client.cancel_job(job_id)
+    if not response.ok:
+        raise_remote_error(response.message)
+
+    print(f"Job {job_id} canceled")
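Besides the CLI wiring above (args.cmd == "cancel"), the cancel path can be called programmatically; a hedged sketch, assuming a Studio token is already configured via the login flow and using a placeholder job id:

from datachain.studio import cancel_job

# Passing None for team_name is assumed to fall back to the configured default team.
cancel_job("JOB_ID", team_name=None)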