datachain 0.7.11__py3-none-any.whl → 0.8.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of datachain might be problematic.

@@ -1,34 +1,37 @@
 import contextlib
-from collections.abc import Iterator, Sequence
+from collections.abc import Iterable, Sequence
 from itertools import chain
 from multiprocessing import cpu_count
 from sys import stdin
-from typing import Optional
+from threading import Timer
+from typing import TYPE_CHECKING, Optional
 
 import attrs
 import multiprocess
 from cloudpickle import load, loads
 from fsspec.callbacks import DEFAULT_CALLBACK, Callback
 from multiprocess import get_context
+from sqlalchemy.sql import func
 
 from datachain.catalog import Catalog
 from datachain.catalog.loader import get_distributed_class
-from datachain.lib.udf import UDFAdapter, UDFResult
+from datachain.query.batch import RowsOutput, RowsOutputBatch
 from datachain.query.dataset import (
     get_download_callback,
     get_generated_callback,
     get_processed_callback,
     process_udf_outputs,
 )
-from datachain.query.queue import (
-    get_from_queue,
-    marshal,
-    msgpack_pack,
-    msgpack_unpack,
-    put_into_queue,
-    unmarshal,
-)
-from datachain.utils import batched_it
+from datachain.query.queue import get_from_queue, put_into_queue
+from datachain.query.udf import UdfInfo
+from datachain.query.utils import get_query_id_column
+from datachain.utils import batched, flatten
+
+if TYPE_CHECKING:
+    from sqlalchemy import Select, Table
+
+    from datachain.data_storage import AbstractMetastore, AbstractWarehouse
+    from datachain.lib.udf import UDFAdapter
 
 DEFAULT_BATCH_SIZE = 10000
 STOP_SIGNAL = "STOP"
@@ -38,10 +41,6 @@ FAILED_STATUS = "FAILED"
 NOTIFY_STATUS = "NOTIFY"
 
 
-def full_module_type_path(typ: type) -> str:
-    return f"{typ.__module__}.{typ.__qualname__}"
-
-
 def get_n_workers_from_arg(n_workers: Optional[int] = None) -> int:
     if not n_workers:
         return cpu_count()
@@ -52,55 +51,42 @@ def get_n_workers_from_arg(n_workers: Optional[int] = None) -> int:
 
 def udf_entrypoint() -> int:
     # Load UDF info from stdin
-    udf_info = load(stdin.buffer)
-
-    (
-        warehouse_class,
-        warehouse_args,
-        warehouse_kwargs,
-    ) = udf_info["warehouse_clone_params"]
-    warehouse = warehouse_class(*warehouse_args, **warehouse_kwargs)
+    udf_info: UdfInfo = load(stdin.buffer)
 
     # Parallel processing (faster for more CPU-heavy UDFs)
-    dispatch = UDFDispatcher(
-        udf_info["udf_data"],
-        udf_info["catalog_init"],
-        udf_info["metastore_clone_params"],
-        udf_info["warehouse_clone_params"],
-        udf_fields=udf_info["udf_fields"],
-        cache=udf_info["cache"],
-        is_generator=udf_info.get("is_generator", False),
-    )
+    dispatch = UDFDispatcher(udf_info)
 
     query = udf_info["query"]
     batching = udf_info["batching"]
-    table = udf_info["table"]
     n_workers = udf_info["processes"]
-    udf = loads(udf_info["udf_data"])
     if n_workers is True:
-        # Use default number of CPUs (cores)
-        n_workers = None
+        n_workers = None  # Use default number of CPUs (cores)
+
+    wh_cls, wh_args, wh_kwargs = udf_info["warehouse_clone_params"]
+    warehouse: AbstractWarehouse = wh_cls(*wh_args, **wh_kwargs)
+
+    total_rows = next(
+        warehouse.db.execute(
+            query.with_only_columns(func.count(query.c.sys__id)).order_by(None)
+        )
+    )[0]
 
     with contextlib.closing(
-        batching(warehouse.dataset_select_paginated, query)
+        batching(warehouse.dataset_select_paginated, query, ids_only=True)
     ) as udf_inputs:
         download_cb = get_download_callback()
         processed_cb = get_processed_callback()
-        generated_cb = get_generated_callback(dispatch.is_generator)
         try:
-            udf_results = dispatch.run_udf_parallel(
-                marshal(udf_inputs),
+            dispatch.run_udf_parallel(
+                udf_inputs,
+                total_rows=total_rows,
                 n_workers=n_workers,
                 processed_cb=processed_cb,
                 download_cb=download_cb,
             )
-            process_udf_outputs(warehouse, table, udf_results, udf, cb=generated_cb)
        finally:
            download_cb.close()
            processed_cb.close()
-            generated_cb.close()
-
-    warehouse.insert_rows_done(table)
 
     return 0
 
@@ -114,32 +100,17 @@ class UDFDispatcher:
     task_queue: Optional[multiprocess.Queue] = None
     done_queue: Optional[multiprocess.Queue] = None
 
-    def __init__(
-        self,
-        udf_data,
-        catalog_init_params,
-        metastore_clone_params,
-        warehouse_clone_params,
-        udf_fields: "Sequence[str]",
-        cache: bool,
-        is_generator: bool = False,
-        buffer_size: int = DEFAULT_BATCH_SIZE,
-    ):
-        self.udf_data = udf_data
-        self.catalog_init_params = catalog_init_params
-        (
-            self.metastore_class,
-            self.metastore_args,
-            self.metastore_kwargs,
-        ) = metastore_clone_params
-        (
-            self.warehouse_class,
-            self.warehouse_args,
-            self.warehouse_kwargs,
-        ) = warehouse_clone_params
-        self.udf_fields = udf_fields
-        self.cache = cache
-        self.is_generator = is_generator
+    def __init__(self, udf_info: UdfInfo, buffer_size: int = DEFAULT_BATCH_SIZE):
+        self.udf_data = udf_info["udf_data"]
+        self.catalog_init_params = udf_info["catalog_init"]
+        self.metastore_clone_params = udf_info["metastore_clone_params"]
+        self.warehouse_clone_params = udf_info["warehouse_clone_params"]
+        self.query = udf_info["query"]
+        self.table = udf_info["table"]
+        self.udf_fields = udf_info["udf_fields"]
+        self.cache = udf_info["cache"]
+        self.is_generator = udf_info["is_generator"]
+        self.is_batching = udf_info["batching"].is_batching
         self.buffer_size = buffer_size
         self.catalog = None
         self.task_queue = None
@@ -148,12 +119,10 @@ class UDFDispatcher:
 
     def _create_worker(self) -> "UDFWorker":
         if not self.catalog:
-            metastore = self.metastore_class(
-                *self.metastore_args, **self.metastore_kwargs
-            )
-            warehouse = self.warehouse_class(
-                *self.warehouse_args, **self.warehouse_kwargs
-            )
+            ms_cls, ms_args, ms_kwargs = self.metastore_clone_params
+            metastore: AbstractMetastore = ms_cls(*ms_args, **ms_kwargs)
+            ws_cls, ws_args, ws_kwargs = self.warehouse_clone_params
+            warehouse: AbstractWarehouse = ws_cls(*ws_args, **ws_kwargs)
             self.catalog = Catalog(metastore, warehouse, **self.catalog_init_params)
         self.udf = loads(self.udf_data)
         return UDFWorker(
@@ -161,7 +130,10 @@ class UDFDispatcher:
             self.udf,
             self.task_queue,
             self.done_queue,
+            self.query,
+            self.table,
             self.is_generator,
+            self.is_batching,
             self.cache,
             self.udf_fields,
         )
@@ -189,26 +161,27 @@ class UDFDispatcher:
 
     def run_udf_parallel(  # noqa: C901, PLR0912
         self,
-        input_rows,
+        input_rows: Iterable[RowsOutput],
+        total_rows: int,
         n_workers: Optional[int] = None,
-        input_queue=None,
         processed_cb: Callback = DEFAULT_CALLBACK,
         download_cb: Callback = DEFAULT_CALLBACK,
-    ) -> Iterator[Sequence[UDFResult]]:
+    ) -> None:
         n_workers = get_n_workers_from_arg(n_workers)
 
+        input_batch_size = total_rows // n_workers
+        if input_batch_size == 0:
+            input_batch_size = 1
+        elif input_batch_size > DEFAULT_BATCH_SIZE:
+            input_batch_size = DEFAULT_BATCH_SIZE
+
         if self.buffer_size < n_workers:
             raise RuntimeError(
                 "Parallel run error: buffer size is smaller than "
                 f"number of workers: {self.buffer_size} < {n_workers}"
             )
 
-        if input_queue:
-            streaming_mode = True
-            self.task_queue = input_queue
-        else:
-            streaming_mode = False
-            self.task_queue = self.ctx.Queue()
+        self.task_queue = self.ctx.Queue()
         self.done_queue = self.ctx.Queue()
         pool = [
             self.ctx.Process(name=f"Worker-UDF-{i}", target=self._run_worker)
@@ -223,41 +196,41 @@ class UDFDispatcher:
         # Will be set to True when the input is exhausted
         input_finished = False
 
-        if not streaming_mode:
-            # Stop all workers after the input rows have finished processing
-            input_data = chain(input_rows, [STOP_SIGNAL] * n_workers)
+        if not self.is_batching:
+            input_rows = batched(flatten(input_rows), input_batch_size)
 
-            # Add initial buffer of tasks
-            for _ in range(self.buffer_size):
-                try:
-                    put_into_queue(self.task_queue, next(input_data))
-                except StopIteration:
-                    input_finished = True
-                    break
+        # Stop all workers after the input rows have finished processing
+        input_data = chain(input_rows, [STOP_SIGNAL] * n_workers)
+
+        # Add initial buffer of tasks
+        for _ in range(self.buffer_size):
+            try:
+                put_into_queue(self.task_queue, next(input_data))
+            except StopIteration:
+                input_finished = True
+                break
 
         # Process all tasks
         while n_workers > 0:
             result = get_from_queue(self.done_queue)
+
+            if downloaded := result.get("downloaded"):
+                download_cb.relative_update(downloaded)
+            if processed := result.get("processed"):
+                processed_cb.relative_update(processed)
+
             status = result["status"]
-            if status == NOTIFY_STATUS:
-                if downloaded := result.get("downloaded"):
-                    download_cb.relative_update(downloaded)
-                if processed := result.get("processed"):
-                    processed_cb.relative_update(processed)
+            if status in (OK_STATUS, NOTIFY_STATUS):
+                pass  # Do nothing here
             elif status == FINISHED_STATUS:
-                # Worker finished
-                n_workers -= 1
-            elif status == OK_STATUS:
-                if processed := result.get("processed"):
-                    processed_cb.relative_update(processed)
-                yield msgpack_unpack(result["result"])
+                n_workers -= 1  # Worker finished
             else:  # Failed / error
                 n_workers -= 1
                 if exc := result.get("exception"):
                     raise exc
                 raise RuntimeError("Internal error: Parallel UDF execution failed")
 
-            if status == OK_STATUS and not streaming_mode and not input_finished:
+            if status == OK_STATUS and not input_finished:
                 try:
                     put_into_queue(self.task_queue, next(input_data))
                 except StopIteration:
@@ -311,11 +284,14 @@ class ProcessedCallback(Callback):
 
 @attrs.define
 class UDFWorker:
-    catalog: Catalog
-    udf: UDFAdapter
+    catalog: "Catalog"
+    udf: "UDFAdapter"
     task_queue: "multiprocess.Queue"
     done_queue: "multiprocess.Queue"
+    query: "Select"
+    table: "Table"
     is_generator: bool
+    is_batching: bool
     cache: bool
     udf_fields: Sequence[str]
     cb: Callback = attrs.field()
@@ -326,30 +302,54 @@ class UDFWorker:
 
     def run(self) -> None:
         processed_cb = ProcessedCallback()
+        generated_cb = get_generated_callback(self.is_generator)
+
         udf_results = self.udf.run(
             self.udf_fields,
-            unmarshal(self.get_inputs()),
+            self.get_inputs(),
             self.catalog,
-            self.is_generator,
             self.cache,
             download_cb=self.cb,
             processed_cb=processed_cb,
         )
-        for udf_output in udf_results:
-            for batch in batched_it(udf_output, DEFAULT_BATCH_SIZE):
-                put_into_queue(
-                    self.done_queue,
-                    {
-                        "status": OK_STATUS,
-                        "result": msgpack_pack(list(batch)),
-                    },
-                )
+        process_udf_outputs(
+            self.catalog.warehouse,
+            self.table,
+            self.notify_and_process(udf_results, processed_cb),
+            self.udf,
+            cb=generated_cb,
+        )
+
+        put_into_queue(
+            self.done_queue,
+            {"status": FINISHED_STATUS, "processed": processed_cb.processed_rows},
+        )
+
+    def notify_and_process(self, udf_results, processed_cb):
+        for row in udf_results:
             put_into_queue(
                 self.done_queue,
-                {"status": NOTIFY_STATUS, "processed": processed_cb.processed_rows},
+                {"status": OK_STATUS, "processed": processed_cb.processed_rows},
             )
-        put_into_queue(self.done_queue, {"status": FINISHED_STATUS})
+            yield row
 
     def get_inputs(self):
-        while (batch := get_from_queue(self.task_queue)) != STOP_SIGNAL:
-            yield batch
+        warehouse = self.catalog.warehouse.clone()
+        col_id = get_query_id_column(self.query)
+
+        if self.is_batching:
+            while (batch := get_from_queue(self.task_queue)) != STOP_SIGNAL:
+                ids = [row[0] for row in batch.rows]
+                rows = warehouse.dataset_rows_select(self.query.where(col_id.in_(ids)))
+                yield RowsOutputBatch(list(rows))
+        else:
+            while (batch := get_from_queue(self.task_queue)) != STOP_SIGNAL:
+                yield from warehouse.dataset_rows_select(
+                    self.query.where(col_id.in_(batch))
+                )
+
+
+class RepeatTimer(Timer):
+    def run(self):
+        while not self.finished.wait(self.interval):
+            self.function(*self.args, **self.kwargs)
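The `RepeatTimer` added at the end of this file is a `threading.Timer` subclass that keeps re-firing its callback every `interval` seconds until it is cancelled, instead of firing once. A minimal standalone sketch of that behaviour (the `ticks` demo and its names are illustrative only, not part of the package):

import time
from threading import Timer

class RepeatTimer(Timer):
    # Same pattern as the class added above: wait() returns False on each
    # timeout, so the callback keeps firing until cancel() sets the event.
    def run(self):
        while not self.finished.wait(self.interval):
            self.function(*self.args, **self.kwargs)

ticks = []
timer = RepeatTimer(0.1, ticks.append, args=("tick",))
timer.start()
time.sleep(0.35)
timer.cancel()       # stops the repetition; a plain Timer would fire only once
print(len(ticks))    # roughly 3 callbacks ran while the main thread slept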
@@ -69,7 +69,7 @@ class Session:
         self.catalog = catalog or get_catalog(
             client_config=client_config, in_memory=in_memory
         )
-        self.dataset_versions: list[tuple[DatasetRecord, int]] = []
+        self.dataset_versions: list[tuple[DatasetRecord, int, bool]] = []
 
     def __enter__(self):
         # Push the current context onto the stack
@@ -89,8 +89,10 @@
         if Session.SESSION_CONTEXTS:
             Session.SESSION_CONTEXTS.pop()
 
-    def add_dataset_version(self, dataset: "DatasetRecord", version: int) -> None:
-        self.dataset_versions.append((dataset, version))
+    def add_dataset_version(
+        self, dataset: "DatasetRecord", version: int, listing: bool = False
+    ) -> None:
+        self.dataset_versions.append((dataset, version, listing))
 
     def generate_temp_dataset_name(self) -> str:
         return self.get_temp_prefix() + uuid4().hex[: self.TEMP_TABLE_UUID_LEN]
@@ -111,8 +113,9 @@
         if not self.dataset_versions:
             return
 
-        for dataset, version in self.dataset_versions:
-            self.catalog.remove_dataset_version(dataset, version)
+        for dataset, version, listing in self.dataset_versions:
+            if not listing:
+                self.catalog.remove_dataset_version(dataset, version)
 
         self.dataset_versions.clear()
 
datachain/query/udf.py ADDED
@@ -0,0 +1,20 @@
+from typing import TYPE_CHECKING, Any, Callable, Optional, TypedDict
+
+if TYPE_CHECKING:
+    from sqlalchemy import Select, Table
+
+    from datachain.query.batch import BatchingStrategy
+
+
+class UdfInfo(TypedDict):
+    udf_data: bytes
+    catalog_init: dict[str, Any]
+    metastore_clone_params: tuple[Callable[..., Any], list[Any], dict[str, Any]]
+    warehouse_clone_params: tuple[Callable[..., Any], list[Any], dict[str, Any]]
+    table: "Table"
+    query: "Select"
+    udf_fields: list[str]
+    batching: "BatchingStrategy"
+    processes: Optional[int]
+    is_generator: bool
+    cache: bool
datachain/query/utils.py ADDED

@@ -0,0 +1,42 @@
+from typing import TYPE_CHECKING, Optional, Union
+
+from sqlalchemy import Column
+
+if TYPE_CHECKING:
+    from sqlalchemy import ColumnElement, Select, TextClause
+
+
+ColT = Union[Column, "ColumnElement", "TextClause"]
+
+
+def column_name(col: ColT) -> str:
+    """Returns column name from column element."""
+    return col.name if isinstance(col, Column) else str(col)
+
+
+def get_query_column(query: "Select", name: str) -> Optional[ColT]:
+    """Returns column element from query by name or None if column not found."""
+    return next((col for col in query.inner_columns if column_name(col) == name), None)
+
+
+def get_query_id_column(query: "Select") -> ColT:
+    """Returns ID column element from query or None if column not found."""
+    col = get_query_column(query, "sys__id")
+    if col is None:
+        raise RuntimeError("sys__id column not found in query")
+    return col
+
+
+def select_only_columns(query: "Select", *names: str) -> "Select":
+    """Returns query selecting defined columns only."""
+    if not names:
+        return query
+
+    cols: list[ColT] = []
+    for name in names:
+        col = get_query_column(query, name)
+        if col is None:
+            raise ValueError(f"Column '{name}' not found in query")
+        cols.append(col)
+
+    return query.with_only_columns(*cols)
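A hedged usage sketch of these helpers, assuming this added module is the `datachain.query.utils` referenced by the dispatcher import above; the throwaway SQLAlchemy table `rows` and its columns are examples, not part of the package:

import sqlalchemy as sa
from datachain.query.utils import get_query_id_column, select_only_columns

meta = sa.MetaData()
rows = sa.Table(
    "rows", meta,
    sa.Column("sys__id", sa.Integer, primary_key=True),
    sa.Column("path", sa.String),
    sa.Column("size", sa.Integer),
)

query = sa.select(rows)
col_id = get_query_id_column(query)              # the sys__id column element
print(select_only_columns(query, "sys__id", "path"))   # narrows the select list
print(query.where(col_id.in_([1, 2, 3])))        # the id-filter pattern UDFWorker.get_inputs uses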
@@ -2,7 +2,7 @@ import base64
 import json
 import logging
 import os
-from collections.abc import Iterable, Iterator
+from collections.abc import AsyncIterator, Iterable, Iterator
 from datetime import datetime, timedelta, timezone
 from struct import unpack
 from typing import (
@@ -11,6 +11,9 @@ from typing import (
     Optional,
     TypeVar,
 )
+from urllib.parse import urlparse, urlunparse
+
+import websockets
 
 from datachain.config import Config
 from datachain.dataset import DatasetStats
@@ -22,6 +25,7 @@ LsData = Optional[list[dict[str, Any]]]
 DatasetInfoData = Optional[dict[str, Any]]
 DatasetStatsData = Optional[DatasetStats]
 DatasetRowsData = Optional[Iterable[dict[str, Any]]]
+DatasetJobVersionsData = Optional[dict[str, Any]]
 DatasetExportStatus = Optional[dict[str, Any]]
 DatasetExportSignedUrls = Optional[list[str]]
 FileUploadData = Optional[dict[str, Any]]
@@ -231,6 +235,40 @@ class StudioClient:
 
         return msgpack.ExtType(code, data)
 
+    async def tail_job_logs(self, job_id: str) -> AsyncIterator[dict]:
+        """
+        Follow job logs via websocket connection.
+
+        Args:
+            job_id: ID of the job to follow logs for
+
+        Yields:
+            Dict containing either job status updates or log messages
+        """
+        parsed_url = urlparse(self.url)
+        ws_url = urlunparse(
+            parsed_url._replace(scheme="wss" if parsed_url.scheme == "https" else "ws")
+        )
+        ws_url = f"{ws_url}/logs/follow/?job_id={job_id}&team_name={self.team}"
+
+        async with websockets.connect(
+            ws_url,
+            additional_headers={"Authorization": f"token {self.token}"},
+        ) as websocket:
+            while True:
+                try:
+                    message = await websocket.recv()
+                    data = json.loads(message)
+
+                    # Yield the parsed message data
+                    yield data
+
+                except websockets.exceptions.ConnectionClosed:
+                    break
+                except Exception as e:  # noqa: BLE001
+                    logger.error("Error receiving websocket message: %s", e)
+                    break
+
     def ls(self, paths: Iterable[str]) -> Iterator[tuple[str, Response[LsData]]]:
         # TODO: change LsData (response.data value) to be list of lists
         # to handle cases where a path will be expanded (i.e. globs)
@@ -302,6 +340,13 @@ class StudioClient:
             method="GET",
         )
 
+    def dataset_job_versions(self, job_id: str) -> Response[DatasetJobVersionsData]:
+        return self._send_request(
+            "datachain/datasets/dataset_job_versions",
+            {"job_id": job_id},
+            method="GET",
+        )
+
     def dataset_stats(self, name: str, version: int) -> Response[DatasetStatsData]:
         response = self._send_request(
             "datachain/datasets/stats",
@@ -359,3 +404,10 @@ class StudioClient:
             "requirements": requirements,
         }
         return self._send_request("datachain/job", data)
+
+    def cancel_job(
+        self,
+        job_id: str,
+    ) -> Response[JobData]:
+        url = f"datachain/job/{job_id}/cancel"
+        return self._send_request(url, data={}, method="POST")
datachain/studio.py CHANGED
@@ -1,3 +1,4 @@
+import asyncio
 import os
 from typing import TYPE_CHECKING, Optional
 
@@ -19,7 +20,7 @@ POST_LOGIN_MESSAGE = (
 )
 
 
-def process_studio_cli_args(args: "Namespace"):
+def process_studio_cli_args(args: "Namespace"):  # noqa: PLR0911
     if args.cmd == "login":
         return login(args)
     if args.cmd == "logout":
@@ -47,6 +48,9 @@ def process_studio_cli_args(args: "Namespace"):
             args.req_file,
         )
 
+    if args.cmd == "cancel":
+        return cancel_job(args.job_id, args.team)
+
     if args.cmd == "team":
         return set_team(args)
     raise DataChainError(f"Unknown command '{args.cmd}'.")
@@ -227,8 +231,34 @@ def create_job(
     if not response.data:
         raise DataChainError("Failed to create job")
 
-    print(f"Job {response.data.get('job', {}).get('id')} created")
+    job_id = response.data.get("job", {}).get("id")
+    print(f"Job {job_id} created")
     print("Open the job in Studio at", response.data.get("job", {}).get("url"))
+    print("=" * 40)
+
+    # Sync usage
+    async def _run():
+        async for message in client.tail_job_logs(job_id):
+            if "logs" in message:
+                for log in message["logs"]:
+                    print(log["message"], end="")
+            elif "job" in message:
+                print(f"\n>>>> Job is now in {message['job']['status']} status.")
+
+    asyncio.run(_run())
+
+    response = client.dataset_job_versions(job_id)
+    if not response.ok:
+        raise_remote_error(response.message)
+
+    response_data = response.data
+    if response_data:
+        dataset_versions = response_data.get("dataset_versions", [])
+        print("\n\n>>>> Dataset versions created during the job:")
+        for version in dataset_versions:
+            print(f" - {version.get('dataset_name')}@v{version.get('version')}")
+    else:
+        print("No dataset versions created during the job.")
 
 
 def upload_files(client: StudioClient, files: list[str]) -> list[str]:
@@ -248,3 +278,18 @@ def upload_files(client: StudioClient, files: list[str]) -> list[str]:
         if file_id:
             file_ids.append(str(file_id))
     return file_ids
+
+
+def cancel_job(job_id: str, team_name: Optional[str]):
+    token = Config().read().get("studio", {}).get("token")
+    if not token:
+        raise DataChainError(
+            "Not logged in to Studio. Log in with 'datachain studio login'."
+        )
+
+    client = StudioClient(team=team_name)
+    response = client.cancel_job(job_id)
+    if not response.ok:
+        raise_remote_error(response.message)
+
+    print(f"Job {job_id} canceled")
datachain/utils.py CHANGED
@@ -263,7 +263,7 @@ def batched_it(iterable: Iterable[_T_co], n: int) -> Iterator[Iterator[_T_co]]:
 
 def flatten(items):
     for item in items:
-        if isinstance(item, list):
+        if isinstance(item, (list, tuple)):
             yield from item
         else:
             yield item
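A quick check of the behaviour change: `flatten()` still flattens only one level deep, but now unpacks tuples as well as lists, while anything else passes through unchanged. A small sketch:

from datachain.utils import flatten

print(list(flatten([1, (2, 3), [4, 5], "ab"])))
# 0.8.1:  [1, 2, 3, 4, 5, 'ab']  -- the tuple is unpacked
# 0.7.11 would have yielded the tuple (2, 3) as a single item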