datachain 0.8.8__py3-none-any.whl → 0.8.10__py3-none-any.whl

This diff compares the contents of two publicly released package versions exactly as they appear in their public registry. It is provided for informational purposes only.

Potentially problematic release: this version of datachain might be problematic.

datachain/lib/udf.py CHANGED
@@ -16,6 +16,7 @@ from datachain.lib.convert.flatten import flatten
 from datachain.lib.data_model import DataValue
 from datachain.lib.file import File
 from datachain.lib.utils import AbstractUDF, DataChainError, DataChainParamsError
+from datachain.progress import CombinedDownloadCallback
 from datachain.query.batch import (
     Batch,
     BatchingStrategy,
@@ -301,20 +302,42 @@ async def _prefetch_input(
     return row


+def _remove_prefetched(row: T) -> None:
+    for obj in row:
+        if isinstance(obj, File):
+            catalog = obj._catalog
+            assert catalog is not None
+            try:
+                catalog.cache.remove(obj)
+            except Exception as e:  # noqa: BLE001
+                print(f"Failed to remove prefetched item {obj.name!r}: {e!s}")
+
+
 def _prefetch_inputs(
     prepared_inputs: "Iterable[T]",
     prefetch: int = 0,
     download_cb: Optional["Callback"] = None,
-    after_prefetch: "Callable[[], None]" = noop,
+    after_prefetch: Optional[Callable[[], None]] = None,
+    remove_prefetched: bool = False,
 ) -> "abc.Generator[T, None, None]":
-    if prefetch > 0:
-        f = partial(
-            _prefetch_input,
-            download_cb=download_cb,
-            after_prefetch=after_prefetch,
-        )
-        prepared_inputs = AsyncMapper(f, prepared_inputs, workers=prefetch).iterate()  # type: ignore[assignment]
-    yield from prepared_inputs
+    if not prefetch:
+        yield from prepared_inputs
+        return
+
+    if after_prefetch is None:
+        after_prefetch = noop
+        if isinstance(download_cb, CombinedDownloadCallback):
+            after_prefetch = download_cb.increment_file_count
+
+    f = partial(_prefetch_input, download_cb=download_cb, after_prefetch=after_prefetch)
+    mapper = AsyncMapper(f, prepared_inputs, workers=prefetch)
+    with closing(mapper.iterate()) as row_iter:
+        for row in row_iter:
+            try:
+                yield row  # type: ignore[misc]
+            finally:
+                if remove_prefetched:
+                    _remove_prefetched(row)


 def _get_cache(
@@ -351,7 +374,13 @@ class Mapper(UDFBase):
                    )

        prepared_inputs = _prepare_rows(udf_inputs)
-        prepared_inputs = _prefetch_inputs(prepared_inputs, self.prefetch)
+        prepared_inputs = _prefetch_inputs(
+            prepared_inputs,
+            self.prefetch,
+            download_cb=download_cb,
+            remove_prefetched=bool(self.prefetch) and not cache,
+        )
+
        with closing(prepared_inputs):
            for id_, *udf_args in prepared_inputs:
                result_objs = self.process_safe(udf_args)
@@ -391,9 +420,9 @@ class BatchMapper(UDFBase):
            )
            result_objs = list(self.process_safe(udf_args))
            n_objs = len(result_objs)
-            assert (
-                n_objs == n_rows
-            ), f"{self.name} returns {n_objs} rows, but {n_rows} were expected"
+            assert n_objs == n_rows, (
+                f"{self.name} returns {n_objs} rows, but {n_rows} were expected"
+            )
            udf_outputs = (self._flatten_row(row) for row in result_objs)
            output = [
                {"sys__id": row_id} | dict(zip(self.signal_names, signals))
@@ -429,15 +458,22 @@ class Generator(UDFBase):
                    row, udf_fields, catalog, cache, download_cb
                )

+        def _process_row(row):
+            with safe_closing(self.process_safe(row)) as result_objs:
+                for result_obj in result_objs:
+                    udf_output = self._flatten_row(result_obj)
+                    yield dict(zip(self.signal_names, udf_output))
+
        prepared_inputs = _prepare_rows(udf_inputs)
-        prepared_inputs = _prefetch_inputs(prepared_inputs, self.prefetch)
+        prepared_inputs = _prefetch_inputs(
+            prepared_inputs,
+            self.prefetch,
+            download_cb=download_cb,
+            remove_prefetched=bool(self.prefetch) and not cache,
+        )
        with closing(prepared_inputs):
-            for row in prepared_inputs:
-                result_objs = self.process_safe(row)
-                udf_outputs = (self._flatten_row(row) for row in result_objs)
-                output = (dict(zip(self.signal_names, row)) for row in udf_outputs)
-                processed_cb.relative_update(1)
-                yield output
+            for row in processed_cb.wrap(prepared_inputs):
+                yield _process_row(row)

        self.teardown()

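The reworked `_prefetch_inputs` above yields each prefetched row and, when `remove_prefetched=True`, evicts that row's files from the local cache once the consumer is done with it; `Mapper` and `Generator` enable this only when prefetching is on and caching is off (`bool(self.prefetch) and not cache`). Below is a minimal, library-independent sketch of the same bounded prefetch-then-cleanup pattern; the helper names and the thread-pool approach are illustrative, not datachain's AsyncMapper:

from collections import deque
from concurrent.futures import ThreadPoolExecutor
from typing import Callable, Iterable, Iterator, TypeVar

T = TypeVar("T")

def prefetch_rows(
    rows: Iterable[T],
    fetch: Callable[[T], T],       # e.g. download the row's file into a cache
    cleanup: Callable[[T], None],  # e.g. evict that file from the cache afterwards
    prefetch: int = 0,
) -> Iterator[T]:
    # No prefetching requested: pass rows straight through.
    if not prefetch:
        yield from rows
        return
    with ThreadPoolExecutor(max_workers=prefetch) as pool:
        pending: deque = deque()
        for row in rows:
            pending.append(pool.submit(fetch, row))
            # Keep at most `prefetch` downloads in flight.
            if len(pending) >= prefetch:
                ready = pending.popleft().result()
                try:
                    yield ready
                finally:
                    cleanup(ready)
        while pending:
            ready = pending.popleft().result()
            try:
                yield ready
            finally:
                cleanup(ready)
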
datachain/model/bbox.py CHANGED
@@ -22,9 +22,9 @@ class BBox(DataModel):
     @staticmethod
     def from_list(coords: list[float], title: str = "") -> "BBox":
         assert len(coords) == 4, "Bounding box must be a list of 4 coordinates."
-        assert all(
-            isinstance(value, (int, float)) for value in coords
-        ), "Bounding box coordinates must be floats or integers."
+        assert all(isinstance(value, (int, float)) for value in coords), (
+            "Bounding box coordinates must be floats or integers."
+        )
         return BBox(
             title=title,
             coords=[round(c) for c in coords],
@@ -64,12 +64,12 @@ class OBBox(DataModel):

     @staticmethod
     def from_list(coords: list[float], title: str = "") -> "OBBox":
-        assert (
-            len(coords) == 8
-        ), "Oriented bounding box must be a list of 8 coordinates."
-        assert all(
-            isinstance(value, (int, float)) for value in coords
-        ), "Oriented bounding box coordinates must be floats or integers."
+        assert len(coords) == 8, (
+            "Oriented bounding box must be a list of 8 coordinates."
+        )
+        assert all(isinstance(value, (int, float)) for value in coords), (
+            "Oriented bounding box coordinates must be floats or integers."
+        )
         return OBBox(
             title=title,
             coords=[round(c) for c in coords],
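The assertion changes above are purely stylistic (the message moves into a parenthesized second operand); validation behavior is unchanged. For reference, a small usage sketch of the guarded constructor, with made-up values:

# BBox.from_list checks length and numeric types, then rounds to integers.
from datachain.model.bbox import BBox  # path as shown in this diff

box = BBox.from_list([10.2, 20.7, 110.0, 220.5], title="cat")
print(box.coords)  # expected: [10, 21, 110, 220]
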
datachain/model/pose.py CHANGED
@@ -22,9 +22,9 @@ class Pose(DataModel):
     def from_list(points: list[list[float]]) -> "Pose":
         assert len(points) == 2, "Pose must be a list of 2 lists: x and y coordinates."
         points_x, points_y = points
-        assert (
-            len(points_x) == len(points_y) == 17
-        ), "Pose x and y coordinates must have the same length of 17."
+        assert len(points_x) == len(points_y) == 17, (
+            "Pose x and y coordinates must have the same length of 17."
+        )
         assert all(
             isinstance(value, (int, float)) for value in [*points_x, *points_y]
         ), "Pose coordinates must be floats or integers."
@@ -61,13 +61,13 @@ class Pose3D(DataModel):

     @staticmethod
     def from_list(points: list[list[float]]) -> "Pose3D":
-        assert (
-            len(points) == 3
-        ), "Pose3D must be a list of 3 lists: x, y coordinates and visible."
+        assert len(points) == 3, (
+            "Pose3D must be a list of 3 lists: x, y coordinates and visible."
+        )
         points_x, points_y, points_v = points
-        assert (
-            len(points_x) == len(points_y) == len(points_v) == 17
-        ), "Pose3D x, y coordinates and visible must have the same length of 17."
+        assert len(points_x) == len(points_y) == len(points_v) == 17, (
+            "Pose3D x, y coordinates and visible must have the same length of 17."
+        )
         assert all(
             isinstance(value, (int, float))
             for value in [*points_x, *points_y, *points_v]
@@ -22,13 +22,13 @@ class Segment(DataModel):

     @staticmethod
     def from_list(points: list[list[float]], title: str = "") -> "Segment":
-        assert (
-            len(points) == 2
-        ), "Segment must be a list of 2 lists: x and y coordinates."
+        assert len(points) == 2, (
+            "Segment must be a list of 2 lists: x and y coordinates."
+        )
         points_x, points_y = points
-        assert len(points_x) == len(
-            points_y
-        ), "Segment x and y coordinates must have the same length."
+        assert len(points_x) == len(points_y), (
+            "Segment x and y coordinates must have the same length."
+        )
         assert all(
             isinstance(value, (int, float)) for value in [*points_x, *points_y]
         ), "Segment coordinates must be floats or integers."
datachain/progress.py CHANGED
@@ -1,14 +1,5 @@
-"""Manages progress bars."""
-
-import logging
-from threading import RLock
-
 from fsspec import Callback
 from fsspec.callbacks import TqdmCallback
-from tqdm.auto import tqdm
-
-logger = logging.getLogger(__name__)
-tqdm.set_lock(RLock())


 class CombinedDownloadCallback(Callback):
@@ -24,10 +15,6 @@ class CombinedDownloadCallback(Callback):
 class TqdmCombinedDownloadCallback(CombinedDownloadCallback, TqdmCallback):
     def __init__(self, tqdm_kwargs=None, *args, **kwargs):
         self.files_count = 0
-        tqdm_kwargs = tqdm_kwargs or {}
-        tqdm_kwargs.setdefault("postfix", {}).setdefault("files", self.files_count)
-        kwargs = kwargs or {}
-        kwargs["tqdm_cls"] = tqdm
         super().__init__(tqdm_kwargs, *args, **kwargs)

     def increment_file_count(self, n: int = 1) -> None:
@@ -336,15 +336,16 @@ def process_udf_outputs(
     for udf_output in udf_results:
         if not udf_output:
             continue
-        for row in udf_output:
-            cb.relative_update()
-            rows.append(adjust_outputs(warehouse, row, udf_col_types))
-            if len(rows) >= batch_size or (
-                len(rows) % 10 == 0 and psutil.virtual_memory().percent > 80
-            ):
-                for row_chunk in batched(rows, batch_size):
-                    warehouse.insert_rows(udf_table, row_chunk)
-                rows.clear()
+        with safe_closing(udf_output):
+            for row in udf_output:
+                cb.relative_update()
+                rows.append(adjust_outputs(warehouse, row, udf_col_types))
+                if len(rows) >= batch_size or (
+                    len(rows) % 10 == 0 and psutil.virtual_memory().percent > 80
+                ):
+                    for row_chunk in batched(rows, batch_size):
+                        warehouse.insert_rows(udf_table, row_chunk)
+                    rows.clear()

     if rows:
         for row_chunk in batched(rows, batch_size):
@@ -355,7 +356,7 @@

 def get_download_callback(suffix: str = "", **kwargs) -> CombinedDownloadCallback:
     return TqdmCombinedDownloadCallback(
-        {
+        tqdm_kwargs={
             "desc": "Download" + suffix,
             "unit": "B",
             "unit_scale": True,
@@ -363,6 +364,7 @@ def get_download_callback(suffix: str = "", **kwargs) -> CombinedDownloadCallbac
             "leave": False,
             **kwargs,
         },
+        tqdm_cls=tqdm,
     )


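`get_download_callback` now passes its arguments by keyword, handing `tqdm_cls` straight through to fsspec's `TqdmCallback` instead of injecting it inside the subclass (which, per the progress.py hunks above, no longer sets up a module-level tqdm lock or postfix). A rough sketch of driving such a callback directly, with illustrative values:

from fsspec.callbacks import TqdmCallback
from tqdm.auto import tqdm

# Keyword-only construction, mirroring the style of get_download_callback above.
cb = TqdmCallback(
    tqdm_kwargs={"desc": "Download", "unit": "B", "unit_scale": True, "leave": False},
    tqdm_cls=tqdm,
)
cb.set_size(1_000_000)       # total bytes expected
cb.relative_update(250_000)  # report progress as chunks arrive
cb.close()
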
@@ -873,6 +875,7 @@ class SQLJoin(Step):
     query2: "DatasetQuery"
     predicates: Union[JoinPredicateType, tuple[JoinPredicateType, ...]]
     inner: bool
+    full: bool
     rname: str

     def get_query(self, dq: "DatasetQuery", temp_tables: list[str]) -> sa.Subquery:
@@ -975,14 +978,14 @@ class SQLJoin(Step):
         self.validate_expression(join_expression, q1, q2)

         def q(*columns):
-            join_query = self.catalog.warehouse.join(
+            return self.catalog.warehouse.join(
                 q1,
                 q2,
                 join_expression,
                 inner=self.inner,
+                full=self.full,
+                columns=columns,
             )
-            return sqlalchemy.select(*columns).select_from(join_query)
-            # return sqlalchemy.select(*subquery.c).select_from(subquery)

         return step_result(
             q,
@@ -1487,6 +1490,7 @@ class DatasetQuery:
         dataset_query: "DatasetQuery",
         predicates: Union[JoinPredicateType, Sequence[JoinPredicateType]],
         inner=False,
+        full=False,
         rname="{name}_right",
     ) -> "Self":
         left = self.clone(new_table=False)
@@ -1502,7 +1506,9 @@ class DatasetQuery:
             if isinstance(predicates, (str, ColumnClause, ColumnElement))
             else tuple(predicates)
         )
-        new_query.steps = [SQLJoin(self.catalog, left, right, predicates, inner, rname)]
+        new_query.steps = [
+            SQLJoin(self.catalog, left, right, predicates, inner, full, rname)
+        ]
        return new_query

    @detach
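`SQLJoin` and `DatasetQuery.join` now carry a `full` flag down to `warehouse.join`, which also receives the output columns directly instead of having the step wrap the join in a separate select. A small SQLAlchemy-only sketch of the SQL shape a full (outer) join produces; table and column names are invented for illustration, and the exact statement datachain builds may differ:

import sqlalchemy as sa

left = sa.table("left_t", sa.column("id"), sa.column("x"))
right = sa.table("right_t", sa.column("id"), sa.column("y"))

# full=True renders a FULL OUTER JOIN, keeping rows from both sides.
stmt = sa.select(left.c.id, left.c.x, right.c.y).select_from(
    left.join(right, left.c.id == right.c.id, full=True)
)
print(stmt)
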
@@ -75,7 +75,7 @@ class StudioClient:

         if not token:
             raise DataChainError(
-                "Studio token is not set. Use `datachain studio login` "
+                "Studio token is not set. Use `datachain auth login` "
                 "or environment variable `DVC_STUDIO_TOKEN` to set it."
             )

@@ -105,7 +105,7 @@ class StudioClient:
         if not team:
             raise DataChainError(
                 "Studio team is not set. "
-                "Use `datachain studio team <team_name>` "
+                "Use `datachain auth team <team_name>` "
                 "or environment variable `DVC_STUDIO_TEAM` to set it."
                 "You can also set it in the config file as team under studio."
             )
@@ -4,6 +4,7 @@ import sqlite3
 import warnings
 from collections.abc import Iterable
 from datetime import MAXYEAR, MINYEAR, datetime, timezone
+from functools import cache
 from types import MappingProxyType
 from typing import Callable, Optional

@@ -526,24 +527,44 @@ def compile_collect(element, compiler, **kwargs):
     return compiler.process(func.json_group_array(*element.clauses.clauses), **kwargs)


-def load_usearch_extension(conn: sqlite3.Connection) -> bool:
+@cache
+def usearch_sqlite_path() -> Optional[str]:
     try:
-        # usearch is part of the vector optional dependencies
-        # we use the extension's cosine and euclidean distance functions
-        from usearch import sqlite_path
+        import usearch
+    except ImportError:
+        return None

-        conn.enable_load_extension(True)
+    with warnings.catch_warnings():
+        # usearch binary is not available for Windows, see: https://github.com/unum-cloud/usearch/issues/427.
+        # and, sometimes fail to download the binary in other platforms
+        # triggering UserWarning.

-        with warnings.catch_warnings():
-            # usearch binary is not available for Windows, see: https://github.com/unum-cloud/usearch/issues/427.
-            # and, sometimes fail to download the binary in other platforms
-            # triggering UserWarning.
+        warnings.filterwarnings("ignore", category=UserWarning, module="usearch")

-            warnings.filterwarnings("ignore", category=UserWarning, module="usearch")
-            conn.load_extension(sqlite_path())
+        try:
+            return usearch.sqlite_path()
+        except FileNotFoundError:
+            return None

-        conn.enable_load_extension(False)
-        return True

-    except Exception:  # noqa: BLE001
+def load_usearch_extension(conn: sqlite3.Connection) -> bool:
+    # usearch is part of the vector optional dependencies
+    # we use the extension's cosine and euclidean distance functions
+    ext_path = usearch_sqlite_path()
+    if ext_path is None:
+        return False
+
+    try:
+        conn.enable_load_extension(True)
+    except AttributeError:
+        # sqlite3 module is not built with loadable extension support by default.
+        return False
+
+    try:
+        conn.load_extension(ext_path)
+    except sqlite3.OperationalError:
         return False
+    else:
+        return True
+    finally:
+        conn.enable_load_extension(False)
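`load_usearch_extension` is split so that locating the usearch SQLite binary is a separate, `functools.cache`-memoized step, and the broad `except Exception` gives way to targeted handling (ImportError, FileNotFoundError, AttributeError, sqlite3.OperationalError) that still returns `False` when loading is impossible. A quick usage sketch; the import path is an assumption, as it is not shown in this diff:

import sqlite3

from datachain.sql.sqlite.base import load_usearch_extension  # assumed module path

conn = sqlite3.connect(":memory:")
# Returns False rather than raising when usearch or extension loading is unavailable.
print("usearch extension loaded:", load_usearch_extension(conn))
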
datachain/studio.py CHANGED
@@ -1,9 +1,8 @@
 import asyncio
 import os
+import sys
 from typing import TYPE_CHECKING, Optional

-from tabulate import tabulate
-
 from datachain.catalog.catalog import raise_remote_error
 from datachain.config import Config, ConfigLevel
 from datachain.dataset import QUERY_DATASET_PREFIX
@@ -21,6 +20,13 @@ POST_LOGIN_MESSAGE = (


 def process_jobs_args(args: "Namespace"):
+    if args.cmd is None:
+        print(
+            f"Use 'datachain {args.command} --help' to see available options",
+            file=sys.stderr,
+        )
+        return 1
+
     if args.cmd == "run":
         return create_job(
             args.query_file,
@@ -41,20 +47,20 @@ def process_jobs_args(args: "Namespace"):
     raise DataChainError(f"Unknown command '{args.cmd}'.")


-def process_studio_cli_args(args: "Namespace"):
+def process_auth_cli_args(args: "Namespace"):
+    if args.cmd is None:
+        print(
+            f"Use 'datachain {args.command} --help' to see available options",
+            file=sys.stderr,
+        )
+        return 1
+
     if args.cmd == "login":
         return login(args)
     if args.cmd == "logout":
         return logout()
     if args.cmd == "token":
         return token()
-    if args.cmd == "dataset":
-        rows = [
-            {"Name": name, "Version": version}
-            for name, version in list_datasets(args.team)
-        ]
-        print(tabulate(rows, headers="keys"))
-        return 0

     if args.cmd == "team":
         return set_team(args)
@@ -89,7 +95,7 @@ def login(args: "Namespace"):
         raise DataChainError(
             "Token already exists. "
             "To login with a different token, "
-            "logout using `datachain studio logout`."
+            "logout using `datachain auth logout`."
         )

     open_browser = not args.no_open
@@ -115,12 +121,12 @@ def logout():
     token = conf.get("studio", {}).get("token")
     if not token:
         raise DataChainError(
-            "Not logged in to Studio. Log in with 'datachain studio login'."
+            "Not logged in to Studio. Log in with 'datachain auth login'."
         )

     del conf["studio"]["token"]

-    print("Logged out from Studio. (you can log back in with 'datachain studio login')")
+    print("Logged out from Studio. (you can log back in with 'datachain auth login')")


 def token():
@@ -128,7 +134,7 @@ def token():
     token = config.get("token")
     if not token:
         raise DataChainError(
-            "Not logged in to Studio. Log in with 'datachain studio login'."
+            "Not logged in to Studio. Log in with 'datachain auth login'."
         )

     print(token)
@@ -293,7 +299,7 @@ def cancel_job(job_id: str, team_name: Optional[str]):
     token = Config().read().get("studio", {}).get("token")
     if not token:
         raise DataChainError(
-            "Not logged in to Studio. Log in with 'datachain studio login'."
+            "Not logged in to Studio. Log in with 'datachain auth login'."
         )

     client = StudioClient(team=team_name)
@@ -308,7 +314,7 @@ def show_job_logs(job_id: str, team_name: Optional[str]):
     token = Config().read().get("studio", {}).get("token")
     if not token:
         raise DataChainError(
-            "Not logged in to Studio. Log in with 'datachain studio login'."
+            "Not logged in to Studio. Log in with 'datachain auth login'."
         )

     client = StudioClient(team=team_name)
@@ -1,6 +1,6 @@
 Metadata-Version: 2.2
 Name: datachain
-Version: 0.8.8
+Version: 0.8.10
 Summary: Wrangle unstructured AI data at scale
 Author-email: Dmitry Petrov <support@dvc.org>
 License: Apache-2.0
@@ -99,7 +99,7 @@ Requires-Dist: unstructured[pdf]<0.16.12; extra == "examples"
 Requires-Dist: pdfplumber==0.11.5; extra == "examples"
 Requires-Dist: huggingface_hub[hf_transfer]; extra == "examples"
 Requires-Dist: onnx==1.16.1; extra == "examples"
-Requires-Dist: ultralytics==8.3.58; extra == "examples"
+Requires-Dist: ultralytics==8.3.61; extra == "examples"

 ================
 |logo| DataChain
@@ -189,13 +189,14 @@ Python code:

 .. code:: py

+    import os
     from mistralai import Mistral
     from datachain import File, DataChain, Column

     PROMPT = "Was this dialog successful? Answer in a single word: Success or Failure."

     def eval_dialogue(file: File) -> bool:
-        client = Mistral()
+        client = Mistral(api_key = os.environ["MISTRAL_API_KEY"])
         response = client.chat.complete(
             model="open-mixtral-8x22b",
             messages=[{"role": "system", "content": PROMPT},