datachain 0.30.5__py3-none-any.whl → 0.39.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- datachain/__init__.py +4 -0
- datachain/asyn.py +11 -12
- datachain/cache.py +5 -5
- datachain/catalog/__init__.py +0 -2
- datachain/catalog/catalog.py +276 -354
- datachain/catalog/dependency.py +164 -0
- datachain/catalog/loader.py +8 -3
- datachain/checkpoint.py +43 -0
- datachain/cli/__init__.py +10 -17
- datachain/cli/commands/__init__.py +1 -8
- datachain/cli/commands/datasets.py +42 -27
- datachain/cli/commands/ls.py +15 -15
- datachain/cli/commands/show.py +2 -2
- datachain/cli/parser/__init__.py +3 -43
- datachain/cli/parser/job.py +1 -1
- datachain/cli/parser/utils.py +1 -2
- datachain/cli/utils.py +2 -15
- datachain/client/azure.py +2 -2
- datachain/client/fsspec.py +34 -23
- datachain/client/gcs.py +3 -3
- datachain/client/http.py +157 -0
- datachain/client/local.py +11 -7
- datachain/client/s3.py +3 -3
- datachain/config.py +4 -8
- datachain/data_storage/db_engine.py +12 -6
- datachain/data_storage/job.py +2 -0
- datachain/data_storage/metastore.py +716 -137
- datachain/data_storage/schema.py +20 -27
- datachain/data_storage/serializer.py +105 -15
- datachain/data_storage/sqlite.py +114 -114
- datachain/data_storage/warehouse.py +140 -48
- datachain/dataset.py +109 -89
- datachain/delta.py +117 -42
- datachain/diff/__init__.py +25 -33
- datachain/error.py +24 -0
- datachain/func/aggregate.py +9 -11
- datachain/func/array.py +12 -12
- datachain/func/base.py +7 -4
- datachain/func/conditional.py +9 -13
- datachain/func/func.py +63 -45
- datachain/func/numeric.py +5 -7
- datachain/func/string.py +2 -2
- datachain/hash_utils.py +123 -0
- datachain/job.py +11 -7
- datachain/json.py +138 -0
- datachain/lib/arrow.py +18 -15
- datachain/lib/audio.py +60 -59
- datachain/lib/clip.py +14 -13
- datachain/lib/convert/python_to_sql.py +6 -10
- datachain/lib/convert/values_to_tuples.py +151 -53
- datachain/lib/data_model.py +23 -19
- datachain/lib/dataset_info.py +7 -7
- datachain/lib/dc/__init__.py +2 -1
- datachain/lib/dc/csv.py +22 -26
- datachain/lib/dc/database.py +37 -34
- datachain/lib/dc/datachain.py +518 -324
- datachain/lib/dc/datasets.py +38 -30
- datachain/lib/dc/hf.py +16 -20
- datachain/lib/dc/json.py +17 -18
- datachain/lib/dc/listings.py +5 -8
- datachain/lib/dc/pandas.py +3 -6
- datachain/lib/dc/parquet.py +33 -21
- datachain/lib/dc/records.py +9 -13
- datachain/lib/dc/storage.py +103 -65
- datachain/lib/dc/storage_pattern.py +251 -0
- datachain/lib/dc/utils.py +17 -14
- datachain/lib/dc/values.py +3 -6
- datachain/lib/file.py +187 -50
- datachain/lib/hf.py +7 -5
- datachain/lib/image.py +13 -13
- datachain/lib/listing.py +5 -5
- datachain/lib/listing_info.py +1 -2
- datachain/lib/meta_formats.py +2 -3
- datachain/lib/model_store.py +20 -8
- datachain/lib/namespaces.py +59 -7
- datachain/lib/projects.py +51 -9
- datachain/lib/pytorch.py +31 -23
- datachain/lib/settings.py +188 -85
- datachain/lib/signal_schema.py +302 -64
- datachain/lib/text.py +8 -7
- datachain/lib/udf.py +103 -63
- datachain/lib/udf_signature.py +59 -34
- datachain/lib/utils.py +20 -0
- datachain/lib/video.py +3 -4
- datachain/lib/webdataset.py +31 -36
- datachain/lib/webdataset_laion.py +15 -16
- datachain/listing.py +12 -5
- datachain/model/bbox.py +3 -1
- datachain/namespace.py +22 -3
- datachain/node.py +6 -6
- datachain/nodes_thread_pool.py +0 -1
- datachain/plugins.py +24 -0
- datachain/project.py +4 -4
- datachain/query/batch.py +10 -12
- datachain/query/dataset.py +376 -194
- datachain/query/dispatch.py +112 -84
- datachain/query/metrics.py +3 -4
- datachain/query/params.py +2 -3
- datachain/query/queue.py +2 -1
- datachain/query/schema.py +7 -6
- datachain/query/session.py +190 -33
- datachain/query/udf.py +9 -6
- datachain/remote/studio.py +90 -53
- datachain/script_meta.py +12 -12
- datachain/sql/sqlite/base.py +37 -25
- datachain/sql/sqlite/types.py +1 -1
- datachain/sql/types.py +36 -5
- datachain/studio.py +49 -40
- datachain/toolkit/split.py +31 -10
- datachain/utils.py +39 -48
- {datachain-0.30.5.dist-info → datachain-0.39.0.dist-info}/METADATA +26 -38
- datachain-0.39.0.dist-info/RECORD +173 -0
- datachain/cli/commands/query.py +0 -54
- datachain/query/utils.py +0 -36
- datachain-0.30.5.dist-info/RECORD +0 -168
- {datachain-0.30.5.dist-info → datachain-0.39.0.dist-info}/WHEEL +0 -0
- {datachain-0.30.5.dist-info → datachain-0.39.0.dist-info}/entry_points.txt +0 -0
- {datachain-0.30.5.dist-info → datachain-0.39.0.dist-info}/licenses/LICENSE +0 -0
- {datachain-0.30.5.dist-info → datachain-0.39.0.dist-info}/top_level.txt +0 -0
datachain/sql/sqlite/base.py
CHANGED
@@ -2,20 +2,20 @@ import logging
 import re
 import sqlite3
 import warnings
-from collections.abc import Iterable
+from collections.abc import Callable, Iterable
+from contextlib import closing
 from datetime import MAXYEAR, MINYEAR, datetime, timezone
 from functools import cache
 from types import MappingProxyType
-from typing import Callable, Optional

 import sqlalchemy as sa
-import ujson as json
 from sqlalchemy.dialects import sqlite
 from sqlalchemy.ext.compiler import compiles
 from sqlalchemy.sql.elements import literal
 from sqlalchemy.sql.expression import case
 from sqlalchemy.sql.functions import func

+from datachain import json
 from datachain.sql.functions import (
     aggregate,
     array,
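Most of the churn in this file (and across the release) is mechanical: `Optional[X]` annotations are rewritten to the PEP 604 `X | None` form available since Python 3.10, `Callable` now comes from `collections.abc`, and `ujson` is replaced by the package's own `datachain.json` module. For reference, the two annotation styles are interchangeable at runtime; a small sketch:

```python
from typing import Optional

def old_style(name: Optional[str] = None) -> Optional[int]: ...
def new_style(name: str | None = None) -> int | None: ...

# PEP 604 unions compare equal to their typing.Optional spelling (Python 3.10+).
assert (str | None) == Optional[str]
```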
@@ -112,7 +112,10 @@ def setup():
     compiles(numeric.int_hash_64, "sqlite")(compile_int_hash_64)
     compiles(numeric.bit_hamming_distance, "sqlite")(compile_bit_hamming_distance)

-    if load_usearch_extension(sqlite3.connect(":memory:")):
+    with closing(sqlite3.connect(":memory:")) as _usearch_conn:
+        usearch_available = load_usearch_extension(_usearch_conn)
+
+    if usearch_available:
         compiles(array.cosine_distance, "sqlite")(compile_cosine_distance_ext)
         compiles(array.euclidean_distance, "sqlite")(compile_euclidean_distance_ext)
     else:
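`setup()` now probes for the usearch SQLite extension through a connection wrapped in `contextlib.closing`, so the temporary in-memory connection is released even if the probe fails. A minimal sketch of the same pattern; `probe_extension` below is a stand-in, not the real `load_usearch_extension`:

```python
import sqlite3
from contextlib import closing

def probe_extension(conn: sqlite3.Connection) -> bool:
    # Stand-in probe: report whether this build allows loadable extensions.
    try:
        conn.enable_load_extension(True)
        conn.enable_load_extension(False)
    except AttributeError:
        return False
    return True

# closing() guarantees conn.close() runs, unlike a bare sqlite3.connect()
# call passed straight into the probe.
with closing(sqlite3.connect(":memory:")) as conn:
    available = probe_extension(conn)
print("extension loading available:", available)
```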
@@ -132,7 +135,7 @@ def run_compiler_hook(name):


 def functions_exist(
-    names: Iterable[str], connection: Optional[sqlite3.Connection] = None
+    names: Iterable[str], connection: sqlite3.Connection | None = None
 ) -> bool:
     """
     Returns True if all function names are defined for the given connection.
@@ -146,23 +149,34 @@ def functions_exist(
             f"Found value of type {type(n).__name__}: {n!r}"
         )

+    close_connection = False
     if connection is None:
         connection = sqlite3.connect(":memory:")
+        close_connection = True

-    (previous 14-line implementation not shown in source diff)
+    try:
+        if not names:
+            return True
+        column1 = sa.column("column1", sa.String)
+        func_name_query = column1.not_in(
+            sa.select(sa.column("name", sa.String)).select_from(
+                func.pragma_function_list()
+            )
+        )
+        query = (
+            sa.select(func.count() == 0)
+            .select_from(sa.values(column1).data([(n,) for n in names]))
+            .where(func_name_query)
+        )
+        comp = query.compile(dialect=sqlite_dialect)
+        if comp.params:
+            result = connection.execute(comp.string, comp.params)
+        else:
+            result = connection.execute(comp.string)
+        return bool(result.fetchone()[0])
+    finally:
+        if close_connection:
+            connection.close()


 def create_user_defined_sql_functions(connection):
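The rewritten `functions_exist` issues one query against SQLite's `pragma_function_list()` table-valued function and only closes connections it opened itself. A rough standalone illustration of the underlying check, assuming a SQLite build that exposes `pragma_function_list` (the function names below are arbitrary examples):

```python
import sqlite3

def missing_functions(names: list[str], connection: sqlite3.Connection) -> list[str]:
    # pragma_function_list() enumerates every SQL function the connection knows,
    # including ones registered at runtime with create_function().
    rows = connection.execute("SELECT name FROM pragma_function_list()").fetchall()
    known = {name for (name,) in rows}
    return [n for n in names if n not in known]

conn = sqlite3.connect(":memory:")
conn.create_function("split_part", 3, lambda s, sep, i: s.split(sep)[i - 1])
print(missing_functions(["lower", "split_part", "no_such_fn"], conn))
# expected: ['no_such_fn']
conn.close()
```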
@@ -201,9 +215,7 @@ def sqlite_int_hash_64(x: int) -> int:
 def sqlite_bit_hamming_distance(a: int, b: int) -> int:
     """Calculate the Hamming distance between two integers."""
     diff = (a & MAX_INT64) ^ (b & MAX_INT64)
-    (removed line not shown in source diff)
-    return diff.bit_count()
-    return bin(diff).count("1")
+    return diff.bit_count()


 def sqlite_byte_hamming_distance(a: str, b: str) -> int:
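Dropping the `bin(diff).count("1")` fallback is safe once the package requires Python 3.10+, where `int.bit_count()` gives the same popcount. A quick check (`MASK64` here is a local stand-in for the module's `MAX_INT64` masking constant):

```python
MASK64 = (1 << 64) - 1  # stand-in 64-bit mask

def hamming(a: int, b: int) -> int:
    diff = (a & MASK64) ^ (b & MASK64)
    return diff.bit_count()  # equivalent to bin(diff).count("1")

assert hamming(0b1011, 0b0010) == 2   # bits 0 and 3 differ
assert hamming(-1, 0) == 64           # -1 masks to sixty-four set bits
```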
@@ -215,7 +227,7 @@ def sqlite_byte_hamming_distance(a: str, b: str) -> int:
     elif len(b) < len(a):
         diff = len(a) - len(b)
         a = a[: len(b)]
-    return diff + sum(c1 != c2 for c1, c2 in zip(a, b))
+    return diff + sum(c1 != c2 for c1, c2 in zip(a, b, strict=False))


 def register_user_defined_sql_functions() -> None:
@@ -470,7 +482,7 @@ def py_json_array_get_element(val, idx):
         return None


-def py_json_array_slice(val, offset: int, length: Optional[int] = None):
+def py_json_array_slice(val, offset: int, length: int | None = None):
     arr = json.loads(val)
     try:
         return json.dumps(
@@ -605,7 +617,7 @@ def compile_collect(element, compiler, **kwargs):


 @cache
-def usearch_sqlite_path() -> Optional[str]:
+def usearch_sqlite_path() -> str | None:
     try:
         import usearch
     except ImportError:
datachain/sql/sqlite/types.py
CHANGED
datachain/sql/types.py
CHANGED
@@ -12,14 +12,15 @@ for sqlite we can use `sqlite.register_converter`
 ( https://docs.python.org/3/library/sqlite3.html#sqlite3.register_converter )
 """

+import numbers
 from datetime import datetime
 from types import MappingProxyType
 from typing import Any, Union

 import sqlalchemy as sa
-import ujson as jsonlib
 from sqlalchemy import TypeDecorator, types

+from datachain import json as jsonlib
 from datachain.lib.data_model import StandardType

 _registry: dict[str, "TypeConverter"] = {}
@@ -336,10 +337,28 @@ class Array(SQLType):

     @classmethod
     def from_dict(cls, d: dict[str, Any]) -> Union[type["SQLType"], "SQLType"]:
-        (previous 4-line implementation not fully shown in source diff)
+        try:
+            array_item = d["item_type"]
+        except KeyError as e:
+            raise ValueError("Array type must have 'item_type' field") from e
+
+        if not isinstance(array_item, dict):
+            raise TypeError("Array 'item_type' field must be a dictionary")
+
+        try:
+            item_type = array_item["type"]
+        except KeyError as e:
+            raise ValueError("Array 'item_type' must have 'type' field") from e
+
+        try:
+            sub_t = NAME_TYPES_MAPPING[item_type]
+        except KeyError as e:
+            raise ValueError(f"Array item type '{item_type}' is not supported") from e
+
+        try:
+            return cls(sub_t.from_dict(d["item_type"]))  # type: ignore [attr-defined]
+        except KeyError as e:
+            raise ValueError(f"Array item type '{item_type}' is not supported") from e

     @staticmethod
     def default_value(dialect):
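`Array.from_dict` now raises a descriptive `ValueError`/`TypeError` for each malformed input instead of letting a bare `KeyError` escape. Roughly, it expects serialized type dictionaries shaped like the hypothetical examples below ("Int64" and "String" stand in for whatever names `NAME_TYPES_MAPPING` actually defines):

```python
# Hypothetical inputs to Array.from_dict() and how the new validation treats them.
ok = {"type": "Array", "item_type": {"type": "Int64"}}        # -> Array(Int64)
nested = {
    "type": "Array",
    "item_type": {"type": "Array", "item_type": {"type": "String"}},
}                                                             # -> Array(Array(String))

missing = {"type": "Array"}                                   # ValueError: needs 'item_type'
not_a_dict = {"type": "Array", "item_type": "Int64"}          # TypeError: must be a dictionary
unknown = {"type": "Array", "item_type": {"type": "Nope"}}    # ValueError: not supported
```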
@@ -427,6 +446,18 @@ class TypeReadConverter:
         return value

     def boolean(self, value):
+        if value is None or isinstance(value, bool):
+            return value
+
+        if isinstance(value, numbers.Integral):
+            return bool(value)
+        if isinstance(value, str):
+            normalized = value.strip().lower()
+            if normalized in {"true", "t", "yes", "y", "1"}:
+                return True
+            if normalized in {"false", "f", "no", "n", "0"}:
+                return False
+
         return value

     def int(self, value):
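The new boolean read converter normalizes whatever a backend hands back for boolean columns: native bools and `None` pass through, integers are coerced, common textual spellings are mapped, and anything unrecognized falls through unchanged. A standalone sketch of the same rules:

```python
import numbers

def read_boolean(value):
    # Mirrors the converter shown above.
    if value is None or isinstance(value, bool):
        return value
    if isinstance(value, numbers.Integral):
        return bool(value)
    if isinstance(value, str):
        normalized = value.strip().lower()
        if normalized in {"true", "t", "yes", "y", "1"}:
            return True
        if normalized in {"false", "f", "no", "n", "0"}:
            return False
    return value

assert read_boolean(1) is True
assert read_boolean("  No ") is False
assert read_boolean("maybe") == "maybe"  # unknown text is returned untouched
```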
datachain/studio.py
CHANGED
@@ -1,8 +1,9 @@
 import asyncio
 import os
 import sys
+import warnings
 from datetime import datetime, timezone
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING

 import dateparser
 import tabulate
@@ -175,7 +176,7 @@ def token():
     print(token)


-def list_datasets(team: Optional[str] = None, name: Optional[str] = None):
+def list_datasets(team: str | None = None, name: str | None = None):
     def ds_full_name(ds: dict) -> str:
         return (
             f"{ds['project']['namespace']['name']}.{ds['project']['name']}.{ds['name']}"
@@ -206,7 +207,7 @@ def list_datasets(team: Optional[str] = None, name: Optional[str] = None):
             yield (full_name, version)


-def list_dataset_versions(team: Optional[str] = None, name: str = ""):
+def list_dataset_versions(team: str | None = None, name: str = ""):
     client = StudioClient(team=team)

     namespace_name, project_name, name = parse_dataset_name(name)
@@ -226,13 +227,13 @@ def list_dataset_versions(team: Optional[str] = None, name: str = ""):
|
|
|
226
227
|
|
|
227
228
|
|
|
228
229
|
def edit_studio_dataset(
|
|
229
|
-
team_name:
|
|
230
|
+
team_name: str | None,
|
|
230
231
|
name: str,
|
|
231
232
|
namespace: str,
|
|
232
233
|
project: str,
|
|
233
|
-
new_name:
|
|
234
|
-
description:
|
|
235
|
-
attrs:
|
|
234
|
+
new_name: str | None = None,
|
|
235
|
+
description: str | None = None,
|
|
236
|
+
attrs: list[str] | None = None,
|
|
236
237
|
):
|
|
237
238
|
client = StudioClient(team=team_name)
|
|
238
239
|
response = client.edit_dataset(
|
|
@@ -245,12 +246,12 @@ def edit_studio_dataset(


 def remove_studio_dataset(
-    team_name: Optional[str],
+    team_name: str | None,
     name: str,
     namespace: str,
     project: str,
-    version: Optional[str] = None,
-    force: Optional[bool] = False,
+    version: str | None = None,
+    force: bool | None = False,
 ):
     client = StudioClient(team=team_name)
     response = client.rm_dataset(name, namespace, project, version, force)
@@ -271,12 +272,21 @@ def save_config(hostname, token, level=ConfigLevel.GLOBAL):
     return config.config_file()


-def parse_start_time(start_time_str: Optional[str]) -> Optional[str]:
+def parse_start_time(start_time_str: str | None) -> str | None:
     if not start_time_str:
         return None

-    # (removed comment not shown in source diff)
-    parsed_datetime = dateparser.parse(start_time_str)
+    # dateparser#1246: it explores strptime patterns lacking a year, which
+    # triggers a CPython 3.13 DeprecationWarning. Suppress that noise until a
+    # new dateparser release includes the upstream fix.
+    # https://github.com/scrapinghub/dateparser/issues/1246
+    with warnings.catch_warnings():
+        warnings.filterwarnings(
+            "ignore",
+            category=DeprecationWarning,
+            module="dateparser\\.utils\\.strptime",
+        )
+        parsed_datetime = dateparser.parse(start_time_str)

     if parsed_datetime is None:
         raise DataChainError(
@@ -343,21 +353,21 @@ def show_logs_from_client(client, job_id):

 def create_job(
     query_file: str,
-    team_name: Optional[str],
-    env_file: Optional[str] = None,
-    env: Optional[list[str]] = None,
-    workers: Optional[int] = None,
-    files: Optional[list[str]] = None,
-    python_version: Optional[str] = None,
-    repository: Optional[str] = None,
-    req: Optional[list[str]] = None,
-    req_file: Optional[str] = None,
-    priority: Optional[int] = None,
-    cluster: Optional[str] = None,
-    start_time: Optional[str] = None,
-    cron: Optional[str] = None,
-    no_wait: Optional[bool] = False,
-    credentials_name: Optional[str] = None,
+    team_name: str | None,
+    env_file: str | None = None,
+    env: list[str] | None = None,
+    workers: int | None = None,
+    files: list[str] | None = None,
+    python_version: str | None = None,
+    repository: str | None = None,
+    req: list[str] | None = None,
+    req_file: str | None = None,
+    priority: int | None = None,
+    cluster: str | None = None,
+    start_time: str | None = None,
+    cron: str | None = None,
+    no_wait: bool | None = False,
+    credentials_name: str | None = None,
 ):
     query_type = "PYTHON" if query_file.endswith(".py") else "SHELL"
     with open(query_file) as f:
@@ -403,14 +413,14 @@ def create_job(
     if not response.data:
         raise DataChainError("Failed to create job")

-    job_id = response.data.get("
+    job_id = response.data.get("id")

     if parsed_start_time or cron:
         print(f"Job {job_id} is scheduled as a task in Studio.")
         return 0

     print(f"Job {job_id} created")
-    print("Open the job in Studio at", response.data.get("
+    print("Open the job in Studio at", response.data.get("url"))
     print("=" * 40)

     return 0 if no_wait else show_logs_from_client(client, job_id)
@@ -421,21 +431,19 @@ def upload_files(client: StudioClient, files: list[str]) -> list[str]:
     for file in files:
         file_name = os.path.basename(file)
         with open(file, "rb") as f:
-            file_content = f.read()
-            response = client.upload_file(file_content, file_name)
+            response = client.upload_file(f, file_name)
         if not response.ok:
             raise DataChainError(response.message)

         if not response.data:
             raise DataChainError(f"Failed to upload file {file_name}")

-        file_id = response.data.get("id")
-        if file_id:
+        if file_id := response.data.get("id"):
             file_ids.append(str(file_id))
     return file_ids


-def cancel_job(job_id: str, team_name: Optional[str]):
+def cancel_job(job_id: str, team_name: str | None):
     token = Config().read().get("studio", {}).get("token")
     if not token:
         raise DataChainError(
@@ -450,13 +458,13 @@ def cancel_job(job_id: str, team_name: Optional[str]):
     print(f"Job {job_id} canceled")


-def list_jobs(status: Optional[str], team_name: Optional[str], limit: int):
+def list_jobs(status: str | None, team_name: str | None, limit: int):
     client = StudioClient(team=team_name)
     response = client.get_jobs(status, limit)
     if not response.ok:
         raise DataChainError(response.message)

-    jobs = response.data
+    jobs = response.data or []
     if not jobs:
         print("No jobs found")
         return
@@ -475,7 +483,7 @@ def list_jobs(status: Optional[str], team_name: Optional[str], limit: int):
     print(tabulate.tabulate(rows, headers="keys", tablefmt="grid"))


-def show_job_logs(job_id: str, team_name: Optional[str]):
+def show_job_logs(job_id: str, team_name: str | None):
     token = Config().read().get("studio", {}).get("token")
     if not token:
         raise DataChainError(
@@ -486,13 +494,13 @@ def show_job_logs(job_id: str, team_name: Optional[str]):
     return show_logs_from_client(client, job_id)


-def list_clusters(team_name: Optional[str]):
+def list_clusters(team_name: str | None):
     client = StudioClient(team=team_name)
     response = client.get_clusters()
     if not response.ok:
         raise DataChainError(response.message)

-    clusters = response.data
+    clusters = response.data or []
     if not clusters:
         print("No clusters found")
         return
@@ -505,6 +513,7 @@ def list_clusters(team_name: Optional[str]):
             "Cloud Provider": cluster.get("cloud_provider"),
             "Cloud Credentials": cluster.get("cloud_credentials"),
             "Is Active": cluster.get("is_active"),
+            "Is Default": cluster.get("default"),
             "Max Workers": cluster.get("max_workers"),
         }
         for cluster in clusters
datachain/toolkit/split.py
CHANGED
@@ -1,7 +1,7 @@
 import random
-from typing import Optional

 from datachain import C, DataChain
+from datachain.lib.signal_schema import SignalResolvingError

 RESOLUTION = 2**31 - 1  # Maximum positive value for a 32-bit signed integer.

@@ -9,7 +9,7 @@ RESOLUTION = 2**31 - 1  # Maximum positive value for a 32-bit signed integer.
 def train_test_split(
     dc: DataChain,
     weights: list[float],
-    seed: Optional[int] = None,
+    seed: int | None = None,
 ) -> list[DataChain]:
     """
     Splits a DataChain into multiple subsets based on the provided weights.
@@ -60,7 +60,10 @@
     ```

     Note:
-        (previous note text not shown in source diff)
+        Splits reuse the same best-effort shuffle used by `DataChain.shuffle`. Results
+        are typically repeatable, but earlier operations such as `merge`, `union`, or
+        custom SQL that reshuffle rows can change the outcome between runs. Add order by
+        stable keys first when you need strict reproducibility.
     """
     if len(weights) < 2:
         raise ValueError("Weights should have at least two elements")
@@ -69,16 +72,34 @@


     weights_normalized = [weight / sum(weights) for weight in weights]
+    try:
+        dc.signals_schema.resolve("sys.rand")
+    except SignalResolvingError:
+        dc = dc.persist()
+
     rand_col = C("sys.rand")
     if seed is not None:
         uniform_seed = random.Random(seed).randrange(1, RESOLUTION)  # noqa: S311
         rand_col = (rand_col % RESOLUTION) * uniform_seed  # type: ignore[assignment]
     rand_col = rand_col % RESOLUTION  # type: ignore[assignment]

-    (previous 7-line implementation not shown in source diff)
+    boundaries: list[int] = [0]
+    cumulative = 0.0
+    for weight in weights_normalized[:-1]:
+        cumulative += weight
+        boundary = round(cumulative * RESOLUTION)
+        boundaries.append(min(boundary, RESOLUTION))
+    boundaries.append(RESOLUTION)
+
+    splits: list[DataChain] = []
+    last_index = len(weights_normalized) - 1
+    for index in range(len(weights_normalized)):
+        lower = boundaries[index]
+        if index == last_index:
+            condition = rand_col >= lower
+        else:
+            upper = boundaries[index + 1]
+            condition = (rand_col >= lower) & (rand_col < upper)
+        splits.append(dc.filter(condition))
+
+    return splits