datachain 0.8.8__py3-none-any.whl → 0.8.10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of datachain might be problematic.
- datachain/cli/__init__.py +14 -7
- datachain/cli/commands/datasets.py +2 -3
- datachain/cli/parser/__init__.py +69 -82
- datachain/cli/parser/job.py +20 -25
- datachain/cli/parser/studio.py +41 -65
- datachain/cli/parser/utils.py +1 -1
- datachain/cli/utils.py +1 -1
- datachain/client/local.py +1 -1
- datachain/data_storage/sqlite.py +38 -7
- datachain/data_storage/warehouse.py +2 -2
- datachain/lib/arrow.py +1 -1
- datachain/lib/convert/python_to_sql.py +15 -3
- datachain/lib/convert/unflatten.py +1 -2
- datachain/lib/dc.py +26 -5
- datachain/lib/file.py +27 -4
- datachain/lib/listing.py +4 -4
- datachain/lib/pytorch.py +3 -1
- datachain/lib/udf.py +56 -20
- datachain/model/bbox.py +9 -9
- datachain/model/pose.py +9 -9
- datachain/model/segment.py +6 -6
- datachain/progress.py +0 -13
- datachain/query/dataset.py +20 -14
- datachain/remote/studio.py +2 -2
- datachain/sql/sqlite/base.py +35 -14
- datachain/studio.py +22 -16
- {datachain-0.8.8.dist-info → datachain-0.8.10.dist-info}/METADATA +4 -3
- {datachain-0.8.8.dist-info → datachain-0.8.10.dist-info}/RECORD +32 -32
- {datachain-0.8.8.dist-info → datachain-0.8.10.dist-info}/LICENSE +0 -0
- {datachain-0.8.8.dist-info → datachain-0.8.10.dist-info}/WHEEL +0 -0
- {datachain-0.8.8.dist-info → datachain-0.8.10.dist-info}/entry_points.txt +0 -0
- {datachain-0.8.8.dist-info → datachain-0.8.10.dist-info}/top_level.txt +0 -0
datachain/lib/udf.py
CHANGED

@@ -16,6 +16,7 @@ from datachain.lib.convert.flatten import flatten
 from datachain.lib.data_model import DataValue
 from datachain.lib.file import File
 from datachain.lib.utils import AbstractUDF, DataChainError, DataChainParamsError
+from datachain.progress import CombinedDownloadCallback
 from datachain.query.batch import (
     Batch,
     BatchingStrategy,
@@ -301,20 +302,42 @@ async def _prefetch_input(
     return row


+def _remove_prefetched(row: T) -> None:
+    for obj in row:
+        if isinstance(obj, File):
+            catalog = obj._catalog
+            assert catalog is not None
+            try:
+                catalog.cache.remove(obj)
+            except Exception as e:  # noqa: BLE001
+                print(f"Failed to remove prefetched item {obj.name!r}: {e!s}")
+
+
 def _prefetch_inputs(
     prepared_inputs: "Iterable[T]",
     prefetch: int = 0,
     download_cb: Optional["Callback"] = None,
-    after_prefetch:
+    after_prefetch: Optional[Callable[[], None]] = None,
+    remove_prefetched: bool = False,
 ) -> "abc.Generator[T, None, None]":
-    if prefetch
-
-
-
-
-
-
-
+    if not prefetch:
+        yield from prepared_inputs
+        return
+
+    if after_prefetch is None:
+        after_prefetch = noop
+        if isinstance(download_cb, CombinedDownloadCallback):
+            after_prefetch = download_cb.increment_file_count
+
+    f = partial(_prefetch_input, download_cb=download_cb, after_prefetch=after_prefetch)
+    mapper = AsyncMapper(f, prepared_inputs, workers=prefetch)
+    with closing(mapper.iterate()) as row_iter:
+        for row in row_iter:
+            try:
+                yield row  # type: ignore[misc]
+            finally:
+                if remove_prefetched:
+                    _remove_prefetched(row)


 def _get_cache(
@@ -351,7 +374,13 @@ class Mapper(UDFBase):
                 )

         prepared_inputs = _prepare_rows(udf_inputs)
-        prepared_inputs = _prefetch_inputs(
+        prepared_inputs = _prefetch_inputs(
+            prepared_inputs,
+            self.prefetch,
+            download_cb=download_cb,
+            remove_prefetched=bool(self.prefetch) and not cache,
+        )
+
         with closing(prepared_inputs):
             for id_, *udf_args in prepared_inputs:
                 result_objs = self.process_safe(udf_args)
@@ -391,9 +420,9 @@ class BatchMapper(UDFBase):
            )
            result_objs = list(self.process_safe(udf_args))
            n_objs = len(result_objs)
-           assert (
-               n_objs == n_rows
-           ), f"{self.name} returns {n_objs} rows, but {n_rows} were expected"
+           assert n_objs == n_rows, (
+               f"{self.name} returns {n_objs} rows, but {n_rows} were expected"
+           )
            udf_outputs = (self._flatten_row(row) for row in result_objs)
            output = [
                {"sys__id": row_id} | dict(zip(self.signal_names, signals))
@@ -429,15 +458,22 @@ class Generator(UDFBase):
                    row, udf_fields, catalog, cache, download_cb
                )

+        def _process_row(row):
+            with safe_closing(self.process_safe(row)) as result_objs:
+                for result_obj in result_objs:
+                    udf_output = self._flatten_row(result_obj)
+                    yield dict(zip(self.signal_names, udf_output))
+
         prepared_inputs = _prepare_rows(udf_inputs)
-        prepared_inputs = _prefetch_inputs(
+        prepared_inputs = _prefetch_inputs(
+            prepared_inputs,
+            self.prefetch,
+            download_cb=download_cb,
+            remove_prefetched=bool(self.prefetch) and not cache,
+        )
         with closing(prepared_inputs):
-            for row in prepared_inputs:
-
-                udf_outputs = (self._flatten_row(row) for row in result_objs)
-                output = (dict(zip(self.signal_names, row)) for row in udf_outputs)
-                processed_cb.relative_update(1)
-                yield output
+            for row in processed_cb.wrap(prepared_inputs):
+                yield _process_row(row)

         self.teardown()

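The net effect of this udf.py change: _prefetch_inputs now short-circuits when prefetching is disabled, wires the download callback's file counter in as the after_prefetch hook, and can evict each prefetched file from the local cache once its row has been yielded; Mapper and Generator request that eviction whenever prefetching is on but caching is off. A rough sketch of how this surfaces through the public chain API (a minimal example assuming the usual from_storage/settings/map entry points; the bucket and the file_size function are illustrative):

    from datachain import DataChain
    from datachain.lib.file import File

    def file_size(file: File) -> int:
        # By the time the UDF runs, the file has already been prefetched locally.
        return len(file.read())

    chain = (
        DataChain.from_storage("s3://example-bucket/")   # hypothetical bucket
        .settings(prefetch=4, cache=False)  # prefetch 4 files ahead, no persistent cache
        .map(size=file_size)
    )

With cache=False, each prefetched copy is now removed right after its row is processed, so long runs no longer accumulate temporary files on disk.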
datachain/model/bbox.py
CHANGED

@@ -22,9 +22,9 @@ class BBox(DataModel):
     @staticmethod
     def from_list(coords: list[float], title: str = "") -> "BBox":
         assert len(coords) == 4, "Bounding box must be a list of 4 coordinates."
-        assert all(
-            isinstance(value, (int, float)) for value in coords
-        ), "Bounding box coordinates must be floats or integers."
+        assert all(isinstance(value, (int, float)) for value in coords), (
+            "Bounding box coordinates must be floats or integers."
+        )
         return BBox(
             title=title,
             coords=[round(c) for c in coords],
@@ -64,12 +64,12 @@ class OBBox(DataModel):

     @staticmethod
     def from_list(coords: list[float], title: str = "") -> "OBBox":
-        assert (
-            len(coords) == 8
-        ), "Oriented bounding box must be a list of 8 coordinates."
-        assert all(
-            isinstance(value, (int, float)) for value in coords
-        ), "Oriented bounding box coordinates must be floats or integers."
+        assert len(coords) == 8, (
+            "Oriented bounding box must be a list of 8 coordinates."
+        )
+        assert all(isinstance(value, (int, float)) for value in coords), (
+            "Oriented bounding box coordinates must be floats or integers."
+        )
         return OBBox(
             title=title,
             coords=[round(c) for c in coords],

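These are formatting-only changes: the assert conditions and messages are unchanged, only wrapped differently. For reference, a minimal sketch of the constructors being validated (coordinate values are illustrative):

    from datachain.model.bbox import BBox, OBBox

    bbox = BBox.from_list([10.5, 20.0, 110.5, 220.0], title="cat")    # exactly 4 coords
    obbox = OBBox.from_list([0, 0, 10, 0, 10, 5, 0, 5], title="car")  # exactly 8 coords

    # Passing the wrong number of coordinates still fails with the same message:
    # BBox.from_list([1, 2, 3]) -> AssertionError: Bounding box must be a list of 4 coordinates.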
datachain/model/pose.py
CHANGED

@@ -22,9 +22,9 @@ class Pose(DataModel):
     def from_list(points: list[list[float]]) -> "Pose":
         assert len(points) == 2, "Pose must be a list of 2 lists: x and y coordinates."
         points_x, points_y = points
-        assert (
-            len(points_x) == len(points_y) == 17
-        ), "Pose x and y coordinates must have the same length of 17."
+        assert len(points_x) == len(points_y) == 17, (
+            "Pose x and y coordinates must have the same length of 17."
+        )
         assert all(
             isinstance(value, (int, float)) for value in [*points_x, *points_y]
         ), "Pose coordinates must be floats or integers."
@@ -61,13 +61,13 @@ class Pose3D(DataModel):

     @staticmethod
     def from_list(points: list[list[float]]) -> "Pose3D":
-        assert (
-            len(points) == 3
-        ), "Pose3D must be a list of 3 lists: x, y coordinates and visible."
+        assert len(points) == 3, (
+            "Pose3D must be a list of 3 lists: x, y coordinates and visible."
+        )
         points_x, points_y, points_v = points
-        assert (
-            len(points_x) == len(points_y) == len(points_v) == 17
-        ), "Pose3D x, y coordinates and visible must have the same length of 17."
+        assert len(points_x) == len(points_y) == len(points_v) == 17, (
+            "Pose3D x, y coordinates and visible must have the same length of 17."
+        )
         assert all(
             isinstance(value, (int, float))
             for value in [*points_x, *points_y, *points_v]

datachain/model/segment.py
CHANGED

@@ -22,13 +22,13 @@ class Segment(DataModel):

     @staticmethod
     def from_list(points: list[list[float]], title: str = "") -> "Segment":
-        assert (
-            len(points) == 2
-        ), "Segment must be a list of 2 lists: x and y coordinates."
+        assert len(points) == 2, (
+            "Segment must be a list of 2 lists: x and y coordinates."
+        )
         points_x, points_y = points
-        assert len(points_x) == len(
-            points_y
-        ), "Segment x and y coordinates must have the same length."
+        assert len(points_x) == len(points_y), (
+            "Segment x and y coordinates must have the same length."
+        )
         assert all(
             isinstance(value, (int, float)) for value in [*points_x, *points_y]
         ), "Segment coordinates must be floats or integers."

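Same story for pose.py and segment.py: the validation logic is untouched, only the assert layout changes. A small usage sketch of the constructors these asserts guard (values are illustrative; the 17-point requirement comes straight from the asserts above):

    from datachain.model.pose import Pose
    from datachain.model.segment import Segment

    # Segment: two equally long lists of x and y coordinates.
    seg = Segment.from_list([[0, 10, 10, 0], [0, 0, 5, 5]], title="roof")

    # Pose: exactly 17 x and 17 y keypoint coordinates.
    xs = [float(i) for i in range(17)]
    ys = [float(i * 2) for i in range(17)]
    pose = Pose.from_list([xs, ys])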
datachain/progress.py
CHANGED

@@ -1,14 +1,5 @@
-"""Manages progress bars."""
-
-import logging
-from threading import RLock
-
 from fsspec import Callback
 from fsspec.callbacks import TqdmCallback
-from tqdm.auto import tqdm
-
-logger = logging.getLogger(__name__)
-tqdm.set_lock(RLock())


 class CombinedDownloadCallback(Callback):
@@ -24,10 +15,6 @@ class CombinedDownloadCallback(Callback):
 class TqdmCombinedDownloadCallback(CombinedDownloadCallback, TqdmCallback):
     def __init__(self, tqdm_kwargs=None, *args, **kwargs):
         self.files_count = 0
-        tqdm_kwargs = tqdm_kwargs or {}
-        tqdm_kwargs.setdefault("postfix", {}).setdefault("files", self.files_count)
-        kwargs = kwargs or {}
-        kwargs["tqdm_cls"] = tqdm
         super().__init__(tqdm_kwargs, *args, **kwargs)

     def increment_file_count(self, n: int = 1) -> None:

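The callback no longer forces a tqdm class or a files postfix on its own; callers pass those in explicitly (see the get_download_callback change in datachain/query/dataset.py below). A hedged sketch of constructing the callback the way that caller now does, assuming the fsspec TqdmCallback keyword arguments:

    from tqdm.auto import tqdm
    from datachain.progress import TqdmCombinedDownloadCallback

    cb = TqdmCombinedDownloadCallback(
        tqdm_kwargs={"desc": "Download", "unit": "B", "unit_scale": True, "leave": False},
        tqdm_cls=tqdm,  # previously injected by the callback itself
    )
    cb.increment_file_count()  # file counter kept by CombinedDownloadCallback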
datachain/query/dataset.py
CHANGED

@@ -336,15 +336,16 @@ def process_udf_outputs(
     for udf_output in udf_results:
         if not udf_output:
             continue
-
-
-
-
-            len(rows)
-
-
-
-
+        with safe_closing(udf_output):
+            for row in udf_output:
+                cb.relative_update()
+                rows.append(adjust_outputs(warehouse, row, udf_col_types))
+                if len(rows) >= batch_size or (
+                    len(rows) % 10 == 0 and psutil.virtual_memory().percent > 80
+                ):
+                    for row_chunk in batched(rows, batch_size):
+                        warehouse.insert_rows(udf_table, row_chunk)
+                    rows.clear()

     if rows:
         for row_chunk in batched(rows, batch_size):
@@ -355,7 +356,7 @@ def process_udf_outputs(

 def get_download_callback(suffix: str = "", **kwargs) -> CombinedDownloadCallback:
     return TqdmCombinedDownloadCallback(
-        {
+        tqdm_kwargs={
             "desc": "Download" + suffix,
             "unit": "B",
             "unit_scale": True,
@@ -363,6 +364,7 @@ def get_download_callback(suffix: str = "", **kwargs) -> CombinedDownloadCallback
             "leave": False,
             **kwargs,
         },
+        tqdm_cls=tqdm,
     )


@@ -873,6 +875,7 @@ class SQLJoin(Step):
     query2: "DatasetQuery"
     predicates: Union[JoinPredicateType, tuple[JoinPredicateType, ...]]
     inner: bool
+    full: bool
     rname: str

     def get_query(self, dq: "DatasetQuery", temp_tables: list[str]) -> sa.Subquery:
@@ -975,14 +978,14 @@ class SQLJoin(Step):
         self.validate_expression(join_expression, q1, q2)

         def q(*columns):
-
+            return self.catalog.warehouse.join(
                 q1,
                 q2,
                 join_expression,
                 inner=self.inner,
+                full=self.full,
+                columns=columns,
             )
-            return sqlalchemy.select(*columns).select_from(join_query)
-            # return sqlalchemy.select(*subquery.c).select_from(subquery)

         return step_result(
             q,
@@ -1487,6 +1490,7 @@ class DatasetQuery:
         dataset_query: "DatasetQuery",
         predicates: Union[JoinPredicateType, Sequence[JoinPredicateType]],
         inner=False,
+        full=False,
         rname="{name}_right",
     ) -> "Self":
         left = self.clone(new_table=False)
@@ -1502,7 +1506,9 @@
             if isinstance(predicates, (str, ColumnClause, ColumnElement))
             else tuple(predicates)
         )
-        new_query.steps = [
+        new_query.steps = [
+            SQLJoin(self.catalog, left, right, predicates, inner, full, rname)
+        ]
         return new_query

     @detach

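Besides the streaming insert rework in process_udf_outputs, this file threads a new full flag through SQLJoin and DatasetQuery.join, enabling full outer joins. A sketch against the lower-level API whose new signature appears above (it assumes a DatasetQuery can be built from a saved dataset name; the dataset and column names are illustrative):

    from datachain.query.dataset import DatasetQuery

    left = DatasetQuery(name="employees")
    right = DatasetQuery(name="teams")

    # Keep unmatched rows from both sides; right-hand column collisions get the
    # default "{name}_right" suffix from rname.
    joined = left.join(right, predicates="team_id", full=True)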
datachain/remote/studio.py
CHANGED

@@ -75,7 +75,7 @@ class StudioClient:

         if not token:
             raise DataChainError(
-                "Studio token is not set. Use `datachain
+                "Studio token is not set. Use `datachain auth login` "
                 "or environment variable `DVC_STUDIO_TOKEN` to set it."
             )

@@ -105,7 +105,7 @@ class StudioClient:
         if not team:
             raise DataChainError(
                 "Studio team is not set. "
-                "Use `datachain
+                "Use `datachain auth team <team_name>` "
                 "or environment variable `DVC_STUDIO_TEAM` to set it."
                 "You can also set it in the config file as team under studio."
             )

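Both messages now point at the renamed auth command group. In practice, the flow they describe looks like this (commands exactly as named in the messages; the team name is a placeholder):

    datachain auth login
    datachain auth team <team_name>
    # or, without logging in, via the environment variables the messages mention:
    #   DVC_STUDIO_TOKEN=<token> DVC_STUDIO_TEAM=<team_name> datachain ...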
datachain/sql/sqlite/base.py
CHANGED

@@ -4,6 +4,7 @@ import sqlite3
 import warnings
 from collections.abc import Iterable
 from datetime import MAXYEAR, MINYEAR, datetime, timezone
+from functools import cache
 from types import MappingProxyType
 from typing import Callable, Optional

@@ -526,24 +527,44 @@ def compile_collect(element, compiler, **kwargs):
     return compiler.process(func.json_group_array(*element.clauses.clauses), **kwargs)


-
+@cache
+def usearch_sqlite_path() -> Optional[str]:
     try:
-
-
-
+        import usearch
+    except ImportError:
+        return None

-
+    with warnings.catch_warnings():
+        # usearch binary is not available for Windows, see: https://github.com/unum-cloud/usearch/issues/427.
+        # and, sometimes fail to download the binary in other platforms
+        # triggering UserWarning.

-
-        # usearch binary is not available for Windows, see: https://github.com/unum-cloud/usearch/issues/427.
-        # and, sometimes fail to download the binary in other platforms
-        # triggering UserWarning.
+        warnings.filterwarnings("ignore", category=UserWarning, module="usearch")

-
-
+        try:
+            return usearch.sqlite_path()
+        except FileNotFoundError:
+            return None

-    conn.enable_load_extension(False)
-    return True

-
+def load_usearch_extension(conn: sqlite3.Connection) -> bool:
+    # usearch is part of the vector optional dependencies
+    # we use the extension's cosine and euclidean distance functions
+    ext_path = usearch_sqlite_path()
+    if ext_path is None:
+        return False
+
+    try:
+        conn.enable_load_extension(True)
+    except AttributeError:
+        # sqlite3 module is not built with loadable extension support by default.
+        return False
+
+    try:
+        conn.load_extension(ext_path)
+    except sqlite3.OperationalError:
         return False
+    else:
+        return True
+    finally:
+        conn.enable_load_extension(False)
datachain/studio.py
CHANGED

@@ -1,9 +1,8 @@
 import asyncio
 import os
+import sys
 from typing import TYPE_CHECKING, Optional

-from tabulate import tabulate
-
 from datachain.catalog.catalog import raise_remote_error
 from datachain.config import Config, ConfigLevel
 from datachain.dataset import QUERY_DATASET_PREFIX
@@ -21,6 +20,13 @@ POST_LOGIN_MESSAGE = (


 def process_jobs_args(args: "Namespace"):
+    if args.cmd is None:
+        print(
+            f"Use 'datachain {args.command} --help' to see available options",
+            file=sys.stderr,
+        )
+        return 1
+
     if args.cmd == "run":
         return create_job(
             args.query_file,
@@ -41,20 +47,20 @@ def process_jobs_args(args: "Namespace"):
     raise DataChainError(f"Unknown command '{args.cmd}'.")


-def
+def process_auth_cli_args(args: "Namespace"):
+    if args.cmd is None:
+        print(
+            f"Use 'datachain {args.command} --help' to see available options",
+            file=sys.stderr,
+        )
+        return 1
+
     if args.cmd == "login":
         return login(args)
     if args.cmd == "logout":
         return logout()
     if args.cmd == "token":
         return token()
-    if args.cmd == "dataset":
-        rows = [
-            {"Name": name, "Version": version}
-            for name, version in list_datasets(args.team)
-        ]
-        print(tabulate(rows, headers="keys"))
-        return 0

     if args.cmd == "team":
         return set_team(args)
@@ -89,7 +95,7 @@ def login(args: "Namespace"):
         raise DataChainError(
             "Token already exists. "
             "To login with a different token, "
-            "logout using `datachain
+            "logout using `datachain auth logout`."
         )

     open_browser = not args.no_open
@@ -115,12 +121,12 @@ def logout():
     token = conf.get("studio", {}).get("token")
     if not token:
         raise DataChainError(
-            "Not logged in to Studio. Log in with 'datachain
+            "Not logged in to Studio. Log in with 'datachain auth login'."
         )

     del conf["studio"]["token"]

-    print("Logged out from Studio. (you can log back in with 'datachain
+    print("Logged out from Studio. (you can log back in with 'datachain auth login')")


 def token():
@@ -128,7 +134,7 @@ def token():
     token = config.get("token")
     if not token:
         raise DataChainError(
-            "Not logged in to Studio. Log in with 'datachain
+            "Not logged in to Studio. Log in with 'datachain auth login'."
         )

     print(token)
@@ -293,7 +299,7 @@ def cancel_job(job_id: str, team_name: Optional[str]):
     token = Config().read().get("studio", {}).get("token")
     if not token:
         raise DataChainError(
-            "Not logged in to Studio. Log in with 'datachain
+            "Not logged in to Studio. Log in with 'datachain auth login'."
         )

     client = StudioClient(team=team_name)
@@ -308,7 +314,7 @@ def show_job_logs(job_id: str, team_name: Optional[str]):
     token = Config().read().get("studio", {}).get("token")
     if not token:
         raise DataChainError(
-            "Not logged in to Studio. Log in with 'datachain
+            "Not logged in to Studio. Log in with 'datachain auth login'."
         )

     client = StudioClient(team=team_name)

{datachain-0.8.8.dist-info → datachain-0.8.10.dist-info}/METADATA
CHANGED

@@ -1,6 +1,6 @@
 Metadata-Version: 2.2
 Name: datachain
-Version: 0.8.8
+Version: 0.8.10
 Summary: Wrangle unstructured AI data at scale
 Author-email: Dmitry Petrov <support@dvc.org>
 License: Apache-2.0
@@ -99,7 +99,7 @@ Requires-Dist: unstructured[pdf]<0.16.12; extra == "examples"
 Requires-Dist: pdfplumber==0.11.5; extra == "examples"
 Requires-Dist: huggingface_hub[hf_transfer]; extra == "examples"
 Requires-Dist: onnx==1.16.1; extra == "examples"
-Requires-Dist: ultralytics==8.3.
+Requires-Dist: ultralytics==8.3.61; extra == "examples"

 ================
 |logo| DataChain
@@ -189,13 +189,14 @@ Python code:

 .. code:: py

+    import os
     from mistralai import Mistral
     from datachain import File, DataChain, Column

     PROMPT = "Was this dialog successful? Answer in a single word: Success or Failure."

     def eval_dialogue(file: File) -> bool:
-        client = Mistral()
+        client = Mistral(api_key = os.environ["MISTRAL_API_KEY"])
         response = client.chat.complete(
             model="open-mixtral-8x22b",
             messages=[{"role": "system", "content": PROMPT},