datachain 0.7.10__py3-none-any.whl → 0.8.0__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry, and is provided for informational purposes only.


datachain/lib/diff.py ADDED
@@ -0,0 +1,197 @@
+import random
+import string
+from collections.abc import Sequence
+from typing import TYPE_CHECKING, Optional, Union
+
+import sqlalchemy as sa
+
+from datachain.lib.signal_schema import SignalSchema
+from datachain.query.schema import Column
+from datachain.sql.types import String
+
+if TYPE_CHECKING:
+    from datachain.lib.dc import DataChain
+
+
+C = Column
+
+
+def compare(  # noqa: PLR0912, PLR0915, C901
+    left: "DataChain",
+    right: "DataChain",
+    on: Union[str, Sequence[str]],
+    right_on: Optional[Union[str, Sequence[str]]] = None,
+    compare: Optional[Union[str, Sequence[str]]] = None,
+    right_compare: Optional[Union[str, Sequence[str]]] = None,
+    added: bool = True,
+    deleted: bool = True,
+    modified: bool = True,
+    same: bool = True,
+    status_col: Optional[str] = None,
+) -> "DataChain":
+    """Comparing two chains by identifying rows that are added, deleted, modified
+    or same"""
+    dialect = left._query.dialect
+
+    rname = "right_"
+
+    def _rprefix(c: str, rc: str) -> str:
+        """Returns prefix of right of two companion left - right columns
+        from merge. If companion columns have the same name then prefix will
+        be present in right column name, otherwise it won't.
+        """
+        return rname if c == rc else ""
+
+    def _to_list(obj: Union[str, Sequence[str]]) -> list[str]:
+        return [obj] if isinstance(obj, str) else list(obj)
+
+    if on is None:
+        raise ValueError("'on' must be specified")
+
+    on = _to_list(on)
+    if right_on:
+        right_on = _to_list(right_on)
+        if len(on) != len(right_on):
+            raise ValueError("'on' and 'right_on' must be have the same length")
+
+    if compare:
+        compare = _to_list(compare)
+
+    if right_compare:
+        if not compare:
+            raise ValueError("'compare' must be defined if 'right_compare' is defined")
+
+        right_compare = _to_list(right_compare)
+        if len(compare) != len(right_compare):
+            raise ValueError(
+                "'compare' and 'right_compare' must be have the same length"
+            )
+
+    if not any([added, deleted, modified, same]):
+        raise ValueError(
+            "At least one of added, deleted, modified, same flags must be set"
+        )
+
+    # we still need status column for internal implementation even if not
+    # needed in output
+    need_status_col = bool(status_col)
+    status_col = status_col or "diff_" + "".join(
+        random.choice(string.ascii_letters)  # noqa: S311
+        for _ in range(10)
+    )
+
+    # calculate on and compare column names
+    right_on = right_on or on
+    cols = left.signals_schema.clone_without_sys_signals().db_signals()
+    right_cols = right.signals_schema.clone_without_sys_signals().db_signals()
+
+    on = left.signals_schema.resolve(*on).db_signals()  # type: ignore[assignment]
+    right_on = right.signals_schema.resolve(*right_on).db_signals()  # type: ignore[assignment]
+    if compare:
+        right_compare = right_compare or compare
+        compare = left.signals_schema.resolve(*compare).db_signals()  # type: ignore[assignment]
+        right_compare = right.signals_schema.resolve(*right_compare).db_signals()  # type: ignore[assignment]
+    elif not compare and len(cols) != len(right_cols):
+        # here we will mark all rows that are not added or deleted as modified since
+        # there was no explicit list of compare columns provided (meaning we need
+        # to check all columns to determine if row is modified or same), but
+        # the number of columns on left and right is not the same (one of the chains
+        # have additional column)
+        compare = None
+        right_compare = None
+    else:
+        compare = [c for c in cols if c in right_cols]  # type: ignore[misc, assignment]
+        right_compare = compare
+
+    diff_cond = []
+
+    if added:
+        added_cond = sa.and_(
+            *[
+                C(c) == None  # noqa: E711
+                for c in [f"{_rprefix(c, rc)}{rc}" for c, rc in zip(on, right_on)]
+            ]
+        )
+        diff_cond.append((added_cond, "A"))
+    if modified and compare:
+        modified_cond = sa.or_(
+            *[
+                C(c) != C(f"{_rprefix(c, rc)}{rc}")
+                for c, rc in zip(compare, right_compare)  # type: ignore[arg-type]
+            ]
+        )
+        diff_cond.append((modified_cond, "M"))
+    if same and compare:
+        same_cond = sa.and_(
+            *[
+                C(c) == C(f"{_rprefix(c, rc)}{rc}")
+                for c, rc in zip(compare, right_compare)  # type: ignore[arg-type]
+            ]
+        )
+        diff_cond.append((same_cond, "S"))
+
+    diff = sa.case(*diff_cond, else_=None if compare else "M").label(status_col)
+    diff.type = String()
+
+    left_right_merge = left.merge(
+        right, on=on, right_on=right_on, inner=False, rname=rname
+    )
+    left_right_merge_select = left_right_merge._query.select(
+        *(
+            [C(c) for c in left_right_merge.signals_schema.db_signals("sys")]
+            + [C(c) for c in on]
+            + [C(c) for c in cols if c not in on]
+            + [diff]
+        )
+    )
+
+    diff_col = sa.literal("D").label(status_col)
+    diff_col.type = String()
+
+    right_left_merge = right.merge(
+        left, on=right_on, right_on=on, inner=False, rname=rname
+    ).filter(
+        sa.and_(
+            *[C(f"{_rprefix(c, rc)}{c}") == None for c, rc in zip(on, right_on)]  # noqa: E711
+        )
+    )
+
+    def _default_val(chain: "DataChain", col: str):
+        col_type = chain._query.column_types[col]  # type: ignore[index]
+        val = sa.literal(col_type.default_value(dialect)).label(col)
+        val.type = col_type()
+        return val
+
+    right_left_merge_select = right_left_merge._query.select(
+        *(
+            [C(c) for c in right_left_merge.signals_schema.db_signals("sys")]
+            + [
+                C(c) if c == rc else _default_val(left, c)
+                for c, rc in zip(on, right_on)
+            ]
+            + [
+                C(c) if c in right_cols else _default_val(left, c)  # type: ignore[arg-type]
+                for c in cols
+                if c not in on
+            ]
+            + [diff_col]
+        )
+    )
+
+    if not deleted:
+        res = left_right_merge_select
+    elif deleted and not any([added, modified, same]):
+        res = right_left_merge_select
+    else:
+        res = left_right_merge_select.union(right_left_merge_select)
+
+    res = res.filter(C(status_col) != None)  # noqa: E711
+
+    schema = left.signals_schema
+    if need_status_col:
+        res = res.select()
+        schema = SignalSchema({status_col: str}) | schema
+    else:
+        res = res.select_except(C(status_col))
+
+    return left._evolve(query=res, signal_schema=schema)
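
The new module exposes a single compare() helper that labels rows as added ("A"), deleted ("D"), modified ("M") or same ("S"). A minimal usage sketch follows, calling it directly from datachain.lib.diff as added above; the dataset name, key column and compare columns are hypothetical.

from datachain.lib.dc import DataChain
from datachain.lib.diff import compare

# Hypothetical dataset with two versions sharing an "id" key column.
old = DataChain.from_dataset("persons", version=1)
new = DataChain.from_dataset("persons", version=2)

# Keep a visible status column named "diff"; compare only "name" and "age"
# when deciding whether a row is modified or same.
diff_chain = compare(
    new,
    old,
    on="id",
    compare=["name", "age"],
    status_col="diff",
)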
datachain/lib/file.py CHANGED
@@ -17,7 +17,6 @@ from urllib.request import url2pathname
 
 from fsspec.callbacks import DEFAULT_CALLBACK, Callback
 from PIL import Image
-from pyarrow.dataset import dataset
 from pydantic import Field, field_validator
 
 from datachain.client.fileslice import FileSlice
@@ -452,6 +451,8 @@ class ArrowRow(DataModel):
     @contextmanager
     def open(self):
        """Stream row contents from indexed file."""
+        from pyarrow.dataset import dataset
+
        if self.file._caching_enabled:
            self.file.ensure_cached()
        path = self.file.get_local_path()
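
The only change to file.py is deferring the pyarrow.dataset import from module level into ArrowRow.open(), so importing datachain.lib.file no longer pulls in pyarrow. A standalone sketch of the deferred-import pattern (not datachain code):

class Reader:
    def open(self, path: str):
        # The heavy dependency is imported on first use instead of at module import time.
        import pyarrow.dataset as ds

        return ds.dataset(path)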
@@ -6,7 +6,6 @@ from collections.abc import Iterator
 from pathlib import Path
 from typing import Callable
 
-import datamodel_code_generator
 import jmespath as jsp
 from pydantic import BaseModel, ConfigDict, Field, ValidationError  # noqa: F401
 
@@ -39,36 +38,41 @@ def process_json(data_string, jmespath):
     return json_dict
 
 
-# Print a dynamic datamodel-codegen output from JSON or CSV on stdout
-def read_schema(source_file, data_type="csv", expr=None, model_name=None):
+def gen_datamodel_code(
+    source_file, format="json", jmespath=None, model_name=None
+) -> str:
+    """Generates Python code with Pydantic models that corresponds
+    to the provided JSON, CSV, or JSONL file.
+    It support root JSON arrays (samples the first entry).
+    """
     data_string = ""
     # using uiid to get around issue #1617
     if not model_name:
         # comply with Python class names
         uid_str = str(generate_uuid()).replace("-", "")
-        model_name = f"Model{data_type}{uid_str}"
-    try:
-        with source_file.open() as fd:  # CSV can be larger than memory
-            if data_type == "csv":
-                data_string += fd.readline().replace("\r", "")
-                data_string += fd.readline().replace("\r", "")
-            elif data_type == "jsonl":
-                data_string = fd.readline().replace("\r", "")
-            else:
-                data_string = fd.read()  # other meta must fit into RAM
-    except OSError as e:
-        print(f"An unexpected file error occurred: {e}")
-        return
-    if data_type in ("json", "jsonl"):
-        json_object = process_json(data_string, expr)
-        if data_type == "json" and isinstance(json_object, list):
+        model_name = f"Model{format}{uid_str}"
+
+    with source_file.open() as fd:  # CSV can be larger than memory
+        if format == "csv":
+            data_string += fd.readline().replace("\r", "")
+            data_string += fd.readline().replace("\r", "")
+        elif format == "jsonl":
+            data_string = fd.readline().replace("\r", "")
+        else:
+            data_string = fd.read()  # other meta must fit into RAM
+
+    if format in ("json", "jsonl"):
+        json_object = process_json(data_string, jmespath)
+        if format == "json" and isinstance(json_object, list):
             json_object = json_object[0]  # sample the 1st object from JSON array
-        if data_type == "jsonl":
-            data_type = "json"  # treat json line as plain JSON in auto-schema
+        if format == "jsonl":
+            format = "json"  # treat json line as plain JSON in auto-schema
         data_string = json.dumps(json_object)
 
+    import datamodel_code_generator
+
     input_file_types = {i.value: i for i in datamodel_code_generator.InputFileType}
-    input_file_type = input_file_types[data_type]
+    input_file_type = input_file_types[format]
     with tempfile.TemporaryDirectory() as tmpdir:
        output = Path(tmpdir) / "model.py"
        datamodel_code_generator.generate(
@@ -94,36 +98,29 @@ spec = {model_name}
 def read_meta(  # noqa: C901
     spec=None,
     schema_from=None,
-    meta_type="json",
+    format="json",
     jmespath=None,
-    print_schema=False,
     model_name=None,
     nrows=None,
 ) -> Callable:
     from datachain.lib.dc import DataChain
 
     if schema_from:
-        chain = (
-            DataChain.from_storage(schema_from, type="text")
-            .limit(1)
-            .map(  # dummy column created (#1615)
-                meta_schema=lambda file: read_schema(
-                    file, data_type=meta_type, expr=jmespath, model_name=model_name
-                ),
-                output=str,
-            )
+        file = next(
+            DataChain.from_storage(schema_from, type="text").limit(1).collect("file")
         )
-        (model_output,) = chain.collect("meta_schema")
-        assert isinstance(model_output, str)
-        if print_schema:
-            print(f"{model_output}")
+        model_code = gen_datamodel_code(
+            file, format=format, jmespath=jmespath, model_name=model_name
+        )
+        assert isinstance(model_code, str)
+
         # Below 'spec' should be a dynamically converted DataModel from Pydantic
         if not spec:
             gl = globals()
-            exec(model_output, gl)  # type: ignore[arg-type] # noqa: S102
+            exec(model_code, gl)  # type: ignore[arg-type] # noqa: S102
             spec = gl["spec"]
 
-    if not (spec) and not (schema_from):
+    if not spec and not schema_from:
         raise ValueError(
             "Must provide a static schema in spec: or metadata sample in schema_from:"
         )
@@ -135,7 +132,7 @@ def read_meta( # noqa: C901
     def parse_data(
         file: File,
         data_model=spec,
-        meta_type=meta_type,
+        format=format,
         jmespath=jmespath,
         nrows=nrows,
     ) -> Iterator[spec]:
@@ -147,7 +144,7 @@ def read_meta( # noqa: C901
             except ValidationError as e:
                 print(f"Validation error occurred in row {nrow} file {file.name}:", e)
 
-        if meta_type == "csv":
+        if format == "csv":
             with (
                 file.open() as fd
             ):  # TODO: if schema is statically given, should allow CSV without headers
@@ -155,7 +152,7 @@ def read_meta( # noqa: C901
                 for row in reader:  # CSV can be larger than memory
                     yield from validator(row)
 
-        if meta_type == "json":
+        if format == "json":
             try:
                 with file.open() as fd:  # JSON must fit into RAM
                     data_string = fd.read()
@@ -173,7 +170,7 @@ def read_meta( # noqa: C901
                         return
                    yield from validator(json_dict, nrow)
 
-        if meta_type == "jsonl":
+        if format == "jsonl":
            try:
                nrow = 0
                with file.open() as fd:
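
In this module, read_schema becomes gen_datamodel_code, which returns the generated Pydantic model source instead of printing it, and read_meta now takes format= instead of meta_type= (print_schema= is gone). A hedged sketch of the new helper, mirroring how read_meta obtains a File above; the storage URI is a placeholder and gen_datamodel_code is assumed to be imported from this module (its path is not visible in this diff).

from datachain.lib.dc import DataChain

# Grab a single File object pointing at a sample document (placeholder URI).
file = next(
    DataChain.from_storage("s3://bucket/sample.json", type="text")
    .limit(1)
    .collect("file")
)

# Returns Python source defining a Pydantic model inferred from the sample.
model_code = gen_datamodel_code(file, format="json", model_name="Sample")
print(model_code)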
datachain/lib/pytorch.py CHANGED
@@ -7,7 +7,6 @@ from torch import float32
 from torch.distributed import get_rank, get_world_size
 from torch.utils.data import IterableDataset, get_worker_info
 from torchvision.transforms import v2
-from tqdm import tqdm
 
 from datachain import Session
 from datachain.asyn import AsyncMapper
@@ -112,10 +111,7 @@ class PytorchDataset(IterableDataset):
            from datachain.lib.udf import _prefetch_input
 
            rows = AsyncMapper(_prefetch_input, rows, workers=self.prefetch).iterate()
-
-        desc = f"Parsed PyTorch dataset for rank={total_rank} worker"
-        with tqdm(rows, desc=desc, unit=" rows", position=total_rank) as rows_it:
-            yield from map(self._process_row, rows_it)
+        yield from map(self._process_row, rows)
 
     def _process_row(self, row_features):
        row = []
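
This change only removes the per-rank tqdm progress bar from PytorchDataset.__iter__; iteration itself is unchanged. For context, a hedged sketch of how such a dataset is typically consumed (the dataset name is a placeholder), which now runs without per-worker progress bars:

from torch.utils.data import DataLoader

from datachain.lib.dc import DataChain

chain = DataChain.from_dataset("fashion-mnist")  # placeholder dataset name
loader = DataLoader(chain.to_pytorch(), batch_size=16, num_workers=2)

for batch in loader:
    pass  # training loop goes here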
@@ -402,9 +402,20 @@ class SignalSchema:
             if ModelStore.is_pydantic(finfo.annotation):
                 SignalSchema._set_file_stream(getattr(obj, field), catalog, cache)
 
-    def get_column_type(self, col_name: str) -> DataType:
+    def get_column_type(self, col_name: str, with_subtree: bool = False) -> DataType:
+        """
+        Returns column type by column name.
+
+        If `with_subtree` is True, then it will return the type of the column
+        even if it has a subtree (e.g. model with nested fields), otherwise it will
+        return the type of the column (standard type field, not the model).
+
+        If column is not found, raises `SignalResolvingError`.
+        """
         for path, _type, has_subtree, _ in self.get_flat_tree():
-            if not has_subtree and DEFAULT_DELIMITER.join(path) == col_name:
+            if (with_subtree or not has_subtree) and DEFAULT_DELIMITER.join(
+                path
+            ) == col_name:
                 return _type
         raise SignalResolvingError([col_name], "is not found")
 
@@ -492,14 +503,25 @@ class SignalSchema:
                 # renaming existing signal
                 del new_values[value.name]
                 new_values[name] = self.values[value.name]
-            elif isinstance(value, Func):
+                continue
+            if isinstance(value, Column):
+                # adding new signal from existing signal field
+                try:
+                    new_values[name] = self.get_column_type(
+                        value.name, with_subtree=True
+                    )
+                    continue
+                except SignalResolvingError:
+                    pass
+            if isinstance(value, Func):
                 # adding new signal with function
                 new_values[name] = value.get_result_type(self)
-            elif isinstance(value, ColumnElement):
+                continue
+            if isinstance(value, ColumnElement):
                 # adding new signal
                 new_values[name] = sql_to_python(value)
-            else:
-                new_values[name] = value
+                continue
+            new_values[name] = value
 
         return SignalSchema(new_values)
 
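
get_column_type() gains a with_subtree= flag, and mutate() uses it so that a Column referring to an existing (possibly nested) signal now creates a new signal of that field's type instead of falling through to the generic case. A hedged sketch at the DataChain level; the dataset and signal names are placeholders.

from datachain.lib.dc import DataChain
from datachain.query.schema import Column

chain = DataChain.from_dataset("my-dataset")  # placeholder dataset name

# Copies the existing nested "file.path" field into a new "path_copy" signal;
# its type is resolved via SignalSchema.get_column_type(..., with_subtree=True).
chain = chain.mutate(path_copy=Column("file.path"))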
@@ -35,7 +35,6 @@ from sqlalchemy.sql.schema import TableClause
 from sqlalchemy.sql.selectable import Select
 
 from datachain.asyn import ASYNC_WORKERS, AsyncMapper, OrderedMapper
-from datachain.catalog import QUERY_SCRIPT_CANCELED_EXIT_CODE, get_catalog
 from datachain.data_storage.schema import (
     PARTITION_COLUMN_ID,
     partition_col_names,
@@ -394,6 +393,8 @@ class UDFStep(Step, ABC):
     """
 
     def populate_udf_table(self, udf_table: "Table", query: Select) -> None:
+        from datachain.catalog import QUERY_SCRIPT_CANCELED_EXIT_CODE
+
         use_partitioning = self.partition_by is not None
         batching = self.udf.get_batching(use_partitioning)
         workers = self.workers
@@ -1068,6 +1069,7 @@ class DatasetQuery:
         if "sys__id" in self.column_types:
             self.column_types.pop("sys__id")
         self.starting_step = QueryStep(self.catalog, name, self.version)
+        self.dialect = self.catalog.warehouse.db.dialect
 
     def __iter__(self):
         return iter(self.db_results())
@@ -1087,6 +1089,8 @@ class DatasetQuery:
     def delete(
         name: str, version: Optional[int] = None, catalog: Optional["Catalog"] = None
     ) -> None:
+        from datachain.catalog import get_catalog
+
        catalog = catalog or get_catalog()
        version = version or catalog.get_dataset(name).latest_version
        catalog.remove_dataset(name, version)
@@ -2,7 +2,7 @@ import base64
 import json
 import logging
 import os
-from collections.abc import Iterable, Iterator
+from collections.abc import AsyncIterator, Iterable, Iterator
 from datetime import datetime, timedelta, timezone
 from struct import unpack
 from typing import (
@@ -11,6 +11,9 @@ from typing import (
     Optional,
     TypeVar,
 )
+from urllib.parse import urlparse, urlunparse
+
+import websockets
 
 from datachain.config import Config
 from datachain.dataset import DatasetStats
@@ -22,6 +25,7 @@ LsData = Optional[list[dict[str, Any]]]
 DatasetInfoData = Optional[dict[str, Any]]
 DatasetStatsData = Optional[DatasetStats]
 DatasetRowsData = Optional[Iterable[dict[str, Any]]]
+DatasetJobVersionsData = Optional[dict[str, Any]]
 DatasetExportStatus = Optional[dict[str, Any]]
 DatasetExportSignedUrls = Optional[list[str]]
 FileUploadData = Optional[dict[str, Any]]
@@ -231,6 +235,40 @@ class StudioClient:
 
         return msgpack.ExtType(code, data)
 
+    async def tail_job_logs(self, job_id: str) -> AsyncIterator[dict]:
+        """
+        Follow job logs via websocket connection.
+
+        Args:
+            job_id: ID of the job to follow logs for
+
+        Yields:
+            Dict containing either job status updates or log messages
+        """
+        parsed_url = urlparse(self.url)
+        ws_url = urlunparse(
+            parsed_url._replace(scheme="wss" if parsed_url.scheme == "https" else "ws")
+        )
+        ws_url = f"{ws_url}/logs/follow/?job_id={job_id}&team_name={self.team}"
+
+        async with websockets.connect(
+            ws_url,
+            additional_headers={"Authorization": f"token {self.token}"},
+        ) as websocket:
+            while True:
+                try:
+                    message = await websocket.recv()
+                    data = json.loads(message)
+
+                    # Yield the parsed message data
+                    yield data
+
+                except websockets.exceptions.ConnectionClosed:
+                    break
+                except Exception as e:  # noqa: BLE001
+                    logger.error("Error receiving websocket message: %s", e)
+                    break
+
     def ls(self, paths: Iterable[str]) -> Iterator[tuple[str, Response[LsData]]]:
         # TODO: change LsData (response.data value) to be list of lists
         # to handle cases where a path will be expanded (i.e. globs)
@@ -302,6 +340,13 @@ class StudioClient:
             method="GET",
         )
 
+    def dataset_job_versions(self, job_id: str) -> Response[DatasetJobVersionsData]:
+        return self._send_request(
+            "datachain/datasets/dataset_job_versions",
+            {"job_id": job_id},
+            method="GET",
+        )
+
     def dataset_stats(self, name: str, version: int) -> Response[DatasetStatsData]:
         response = self._send_request(
             "datachain/datasets/stats",
@@ -359,3 +404,10 @@ class StudioClient:
             "requirements": requirements,
         }
         return self._send_request("datachain/job", data)
+
+    def cancel_job(
+        self,
+        job_id: str,
+    ) -> Response[JobData]:
+        url = f"datachain/job/{job_id}/cancel"
+        return self._send_request(url, data={}, method="POST")
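
Together these additions give StudioClient a websocket log follower (tail_job_logs), a job-to-dataset-versions lookup (dataset_job_versions) and a cancel endpoint (cancel_job). A hedged consumption sketch, mirroring how create_job in datachain/studio.py uses them below; the module path, team name and job id are assumptions or placeholders.

import asyncio

from datachain.remote.studio import StudioClient  # assumed module path

client = StudioClient(team="my-team")  # placeholder team


async def follow(job_id: str) -> None:
    # Each message is a dict carrying either a "logs" list or a "job" status update.
    async for message in client.tail_job_logs(job_id):
        if "logs" in message:
            for log in message["logs"]:
                print(log["message"], end="")
        elif "job" in message:
            print(f"\nJob status: {message['job']['status']}")


asyncio.run(follow("1234"))  # placeholder job id
client.cancel_job("1234")    # POST datachain/job/1234/cancel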
datachain/studio.py CHANGED
@@ -1,3 +1,4 @@
+import asyncio
 import os
 from typing import TYPE_CHECKING, Optional
 
@@ -19,7 +20,7 @@ POST_LOGIN_MESSAGE = (
 )
 
 
-def process_studio_cli_args(args: "Namespace"):
+def process_studio_cli_args(args: "Namespace"):  # noqa: PLR0911
     if args.cmd == "login":
         return login(args)
     if args.cmd == "logout":
@@ -47,6 +48,9 @@ def process_studio_cli_args(args: "Namespace"):
             args.req_file,
         )
 
+    if args.cmd == "cancel":
+        return cancel_job(args.job_id, args.team)
+
     if args.cmd == "team":
         return set_team(args)
     raise DataChainError(f"Unknown command '{args.cmd}'.")
@@ -227,8 +231,34 @@ def create_job(
     if not response.data:
         raise DataChainError("Failed to create job")
 
-    print(f"Job {response.data.get('job', {}).get('id')} created")
+    job_id = response.data.get("job", {}).get("id")
+    print(f"Job {job_id} created")
     print("Open the job in Studio at", response.data.get("job", {}).get("url"))
+    print("=" * 40)
+
+    # Sync usage
+    async def _run():
+        async for message in client.tail_job_logs(job_id):
+            if "logs" in message:
+                for log in message["logs"]:
+                    print(log["message"], end="")
+            elif "job" in message:
+                print(f"\n>>>> Job is now in {message['job']['status']} status.")
+
+    asyncio.run(_run())
+
+    response = client.dataset_job_versions(job_id)
+    if not response.ok:
+        raise_remote_error(response.message)
+
+    response_data = response.data
+    if response_data:
+        dataset_versions = response_data.get("dataset_versions", [])
+        print("\n\n>>>> Dataset versions created during the job:")
+        for version in dataset_versions:
+            print(f" - {version.get('dataset_name')}@v{version.get('version')}")
+    else:
+        print("No dataset versions created during the job.")
 
 
 def upload_files(client: StudioClient, files: list[str]) -> list[str]:
@@ -248,3 +278,18 @@ def upload_files(client: StudioClient, files: list[str]) -> list[str]:
         if file_id:
             file_ids.append(str(file_id))
     return file_ids
+
+
+def cancel_job(job_id: str, team_name: Optional[str]):
+    token = Config().read().get("studio", {}).get("token")
+    if not token:
+        raise DataChainError(
+            "Not logged in to Studio. Log in with 'datachain studio login'."
+        )
+
+    client = StudioClient(team=team_name)
+    response = client.cancel_job(job_id)
+    if not response.ok:
+        raise_remote_error(response.message)
+
+    print(f"Job {job_id} canceled")