datachain 0.14.2__py3-none-any.whl → 0.39.0__py3-none-any.whl
This diff compares the contents of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the versions as they appear in their public registries.
- datachain/__init__.py +20 -0
- datachain/asyn.py +11 -12
- datachain/cache.py +7 -7
- datachain/catalog/__init__.py +2 -2
- datachain/catalog/catalog.py +621 -507
- datachain/catalog/dependency.py +164 -0
- datachain/catalog/loader.py +28 -18
- datachain/checkpoint.py +43 -0
- datachain/cli/__init__.py +24 -33
- datachain/cli/commands/__init__.py +1 -8
- datachain/cli/commands/datasets.py +83 -52
- datachain/cli/commands/ls.py +17 -17
- datachain/cli/commands/show.py +4 -4
- datachain/cli/parser/__init__.py +8 -74
- datachain/cli/parser/job.py +95 -3
- datachain/cli/parser/studio.py +11 -4
- datachain/cli/parser/utils.py +1 -2
- datachain/cli/utils.py +2 -15
- datachain/client/azure.py +4 -4
- datachain/client/fsspec.py +45 -28
- datachain/client/gcs.py +6 -6
- datachain/client/hf.py +29 -2
- datachain/client/http.py +157 -0
- datachain/client/local.py +15 -11
- datachain/client/s3.py +17 -9
- datachain/config.py +4 -8
- datachain/data_storage/db_engine.py +12 -6
- datachain/data_storage/job.py +5 -1
- datachain/data_storage/metastore.py +1252 -186
- datachain/data_storage/schema.py +58 -45
- datachain/data_storage/serializer.py +105 -15
- datachain/data_storage/sqlite.py +286 -127
- datachain/data_storage/warehouse.py +250 -113
- datachain/dataset.py +353 -148
- datachain/delta.py +391 -0
- datachain/diff/__init__.py +27 -29
- datachain/error.py +60 -0
- datachain/func/__init__.py +2 -1
- datachain/func/aggregate.py +66 -42
- datachain/func/array.py +242 -38
- datachain/func/base.py +7 -4
- datachain/func/conditional.py +110 -60
- datachain/func/func.py +96 -45
- datachain/func/numeric.py +55 -38
- datachain/func/path.py +32 -20
- datachain/func/random.py +2 -2
- datachain/func/string.py +67 -37
- datachain/func/window.py +7 -8
- datachain/hash_utils.py +123 -0
- datachain/job.py +11 -7
- datachain/json.py +138 -0
- datachain/lib/arrow.py +58 -22
- datachain/lib/audio.py +245 -0
- datachain/lib/clip.py +14 -13
- datachain/lib/convert/flatten.py +5 -3
- datachain/lib/convert/python_to_sql.py +6 -10
- datachain/lib/convert/sql_to_python.py +8 -0
- datachain/lib/convert/values_to_tuples.py +156 -51
- datachain/lib/data_model.py +42 -20
- datachain/lib/dataset_info.py +36 -8
- datachain/lib/dc/__init__.py +8 -2
- datachain/lib/dc/csv.py +25 -28
- datachain/lib/dc/database.py +398 -0
- datachain/lib/dc/datachain.py +1289 -425
- datachain/lib/dc/datasets.py +320 -38
- datachain/lib/dc/hf.py +38 -24
- datachain/lib/dc/json.py +29 -32
- datachain/lib/dc/listings.py +112 -8
- datachain/lib/dc/pandas.py +16 -12
- datachain/lib/dc/parquet.py +35 -23
- datachain/lib/dc/records.py +31 -23
- datachain/lib/dc/storage.py +154 -64
- datachain/lib/dc/storage_pattern.py +251 -0
- datachain/lib/dc/utils.py +24 -16
- datachain/lib/dc/values.py +8 -9
- datachain/lib/file.py +622 -89
- datachain/lib/hf.py +69 -39
- datachain/lib/image.py +14 -14
- datachain/lib/listing.py +14 -11
- datachain/lib/listing_info.py +1 -2
- datachain/lib/meta_formats.py +3 -4
- datachain/lib/model_store.py +39 -7
- datachain/lib/namespaces.py +125 -0
- datachain/lib/projects.py +130 -0
- datachain/lib/pytorch.py +32 -21
- datachain/lib/settings.py +192 -56
- datachain/lib/signal_schema.py +427 -104
- datachain/lib/tar.py +1 -2
- datachain/lib/text.py +8 -7
- datachain/lib/udf.py +164 -76
- datachain/lib/udf_signature.py +60 -35
- datachain/lib/utils.py +118 -4
- datachain/lib/video.py +17 -9
- datachain/lib/webdataset.py +61 -56
- datachain/lib/webdataset_laion.py +15 -16
- datachain/listing.py +22 -10
- datachain/model/bbox.py +3 -1
- datachain/model/ultralytics/bbox.py +16 -12
- datachain/model/ultralytics/pose.py +16 -12
- datachain/model/ultralytics/segment.py +16 -12
- datachain/namespace.py +84 -0
- datachain/node.py +6 -6
- datachain/nodes_thread_pool.py +0 -1
- datachain/plugins.py +24 -0
- datachain/project.py +78 -0
- datachain/query/batch.py +40 -41
- datachain/query/dataset.py +604 -322
- datachain/query/dispatch.py +261 -154
- datachain/query/metrics.py +4 -6
- datachain/query/params.py +2 -3
- datachain/query/queue.py +3 -12
- datachain/query/schema.py +11 -6
- datachain/query/session.py +200 -33
- datachain/query/udf.py +34 -2
- datachain/remote/studio.py +171 -69
- datachain/script_meta.py +12 -12
- datachain/semver.py +68 -0
- datachain/sql/__init__.py +2 -0
- datachain/sql/functions/array.py +33 -1
- datachain/sql/postgresql_dialect.py +9 -0
- datachain/sql/postgresql_types.py +21 -0
- datachain/sql/sqlite/__init__.py +5 -1
- datachain/sql/sqlite/base.py +102 -29
- datachain/sql/sqlite/types.py +8 -13
- datachain/sql/types.py +70 -15
- datachain/studio.py +223 -46
- datachain/toolkit/split.py +31 -10
- datachain/utils.py +101 -59
- {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/METADATA +77 -22
- datachain-0.39.0.dist-info/RECORD +173 -0
- {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/WHEEL +1 -1
- datachain/cli/commands/query.py +0 -53
- datachain/query/utils.py +0 -42
- datachain-0.14.2.dist-info/RECORD +0 -158
- {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/entry_points.txt +0 -0
- {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/licenses/LICENSE +0 -0
- {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/top_level.txt +0 -0
datachain/studio.py
CHANGED
@@ -1,10 +1,16 @@
 import asyncio
 import os
 import sys
-
+import warnings
+from datetime import datetime, timezone
+from typing import TYPE_CHECKING
+
+import dateparser
+import tabulate
 
 from datachain.config import Config, ConfigLevel
-from datachain.
+from datachain.data_storage.job import JobStatus
+from datachain.dataset import QUERY_DATASET_PREFIX, parse_dataset_name
 from datachain.error import DataChainError
 from datachain.remote.studio import StudioClient
 from datachain.utils import STUDIO_URL
@@ -16,6 +22,8 @@ POST_LOGIN_MESSAGE = (
     "Once you've logged in, return here "
     "and you'll be ready to start using DataChain with Studio."
 )
+RETRY_MAX_TIMES = 10
+RETRY_SLEEP_SEC = 1
 
 
 def process_jobs_args(args: "Namespace"):
@@ -35,14 +43,28 @@ def process_jobs_args(args: "Namespace"):
            args.workers,
            args.files,
            args.python_version,
+           args.repository,
            args.req,
            args.req_file,
+           args.priority,
+           args.cluster,
+           args.start_time,
+           args.cron,
+           args.no_wait,
+           args.credentials_name,
        )
 
     if args.cmd == "cancel":
         return cancel_job(args.id, args.team)
     if args.cmd == "logs":
         return show_job_logs(args.id, args.team)
+
+    if args.cmd == "ls":
+        return list_jobs(args.status, args.team, args.limit)
+
+    if args.cmd == "clusters":
+        return list_clusters(args.team)
+
     raise DataChainError(f"Unknown command '{args.cmd}'.")
 
 
@@ -60,14 +82,24 @@ def process_auth_cli_args(args: "Namespace"):
     return logout(args.local)
     if args.cmd == "token":
         return token()
-
     if args.cmd == "team":
         return set_team(args)
     raise DataChainError(f"Unknown command '{args.cmd}'.")
 
 
 def set_team(args: "Namespace"):
-
+    if args.team_name is None:
+        config = Config().read().get("studio", {})
+        team = config.get("team")
+        if team:
+            print(f"Default team is '{team}'")
+            return 0
+
+        raise DataChainError(
+            "No default team set. Use `datachain auth team <team_name>` to set one."
+        )
+
+    level = ConfigLevel.LOCAL if args.local else ConfigLevel.GLOBAL
     config = Config(level)
     with config.edit() as conf:
         studio_conf = conf.get("studio", {})
@@ -80,11 +112,13 @@ def set_team(args: "Namespace"):
 def login(args: "Namespace"):
     from dvc_studio_client.auth import StudioAuthError, get_access_token
 
+    from datachain.remote.studio import get_studio_env_variable
+
     config = Config().read().get("studio", {})
     name = args.name
     hostname = (
         args.hostname
-        or
+        or get_studio_env_variable("URL")
         or config.get("url")
         or STUDIO_URL
     )
@@ -113,6 +147,7 @@ def login(args: "Namespace"):
     level = ConfigLevel.LOCAL if args.local else ConfigLevel.GLOBAL
     config_path = save_config(hostname, access_token, level=level)
     print(f"Authentication complete. Saved token to {config_path}.")
+    print("You can now use 'datachain auth team' to set the default team.")
     return 0
 
 
@@ -141,7 +176,12 @@ def token():
     print(token)
 
 
-def list_datasets(team: Optional[str] = None, name: Optional[str] = None):
+def list_datasets(team: str | None = None, name: str | None = None):
+    def ds_full_name(ds: dict) -> str:
+        return (
+            f"{ds['project']['namespace']['name']}.{ds['project']['name']}.{ds['name']}"
+        )
+
     if name:
         yield from list_dataset_versions(team, name)
         return
@@ -158,18 +198,22 @@ def list_datasets(team: Optional[str] = None, name: Optional[str] = None):
 
     for d in response.data:
         name = d.get("name")
+        full_name = ds_full_name(d)
         if name and name.startswith(QUERY_DATASET_PREFIX):
             continue
 
         for v in d.get("versions", []):
             version = v.get("version")
-            yield (
+            yield (full_name, version)
 
 
-def list_dataset_versions(team: Optional[str] = None, name: str = ""):
+def list_dataset_versions(team: str | None = None, name: str = ""):
     client = StudioClient(team=team)
 
-
+    namespace_name, project_name, name = parse_dataset_name(name)
+    if not namespace_name or not project_name:
+        raise DataChainError(f"Missing namespace or project form dataset name {name}")
+    response = client.dataset_info(namespace_name, project_name, name)
 
     if not response.ok:
         raise DataChainError(response.message)
@@ -183,14 +227,18 @@ def list_dataset_versions(team: Optional[str] = None, name: str = ""):
 
 
 def edit_studio_dataset(
-    team_name:
+    team_name: str | None,
     name: str,
-
-
-
+    namespace: str,
+    project: str,
+    new_name: str | None = None,
+    description: str | None = None,
+    attrs: list[str] | None = None,
 ):
     client = StudioClient(team=team_name)
-    response = client.edit_dataset(
+    response = client.edit_dataset(
+        name, namespace, project, new_name, description, attrs
+    )
     if not response.ok:
         raise DataChainError(response.message)
@@ -198,13 +246,15 @@ def edit_studio_dataset(
 
 
 def remove_studio_dataset(
-    team_name:
+    team_name: str | None,
     name: str,
-
-
+    namespace: str,
+    project: str,
+    version: str | None = None,
+    force: bool | None = False,
 ):
     client = StudioClient(team=team_name)
-    response = client.rm_dataset(name, version, force)
+    response = client.rm_dataset(name, namespace, project, version, force)
     if not response.ok:
         raise DataChainError(response.message)
@@ -222,42 +272,102 @@ def save_config(hostname, token, level=ConfigLevel.GLOBAL):
     return config.config_file()
 
 
+def parse_start_time(start_time_str: str | None) -> str | None:
+    if not start_time_str:
+        return None
+
+    # dateparser#1246: it explores strptime patterns lacking a year, which
+    # triggers a CPython 3.13 DeprecationWarning. Suppress that noise until a
+    # new dateparser release includes the upstream fix.
+    # https://github.com/scrapinghub/dateparser/issues/1246
+    with warnings.catch_warnings():
+        warnings.filterwarnings(
+            "ignore",
+            category=DeprecationWarning,
+            module="dateparser\\.utils\\.strptime",
+        )
+        parsed_datetime = dateparser.parse(start_time_str)
+
+    if parsed_datetime is None:
+        raise DataChainError(
+            f"Could not parse datetime string: '{start_time_str}'. "
+            f"Supported formats include: '2024-01-15 14:30:00', 'tomorrow 3pm', "
+            f"'monday 9am', '2024-01-15T14:30:00Z', 'in 2 hours', etc."
+        )
+
+    # Convert to ISO format string
+    return parsed_datetime.isoformat()
+
+
 def show_logs_from_client(client, job_id):
     # Sync usage
     async def _run():
-
-
-
-
-
-
-
-
+        retry_count = 0
+        latest_status = None
+        processed_statuses = set()
+        while True:
+            async for message in client.tail_job_logs(job_id):
+                if "logs" in message:
+                    for log in message["logs"]:
+                        print(log["message"], end="")
+                elif "job" in message:
+                    latest_status = message["job"]["status"]
+                    if latest_status in processed_statuses:
+                        continue
+                    processed_statuses.add(latest_status)
+                    print(f"\n>>>> Job is now in {latest_status} status.")
+
+            try:
+                if retry_count > RETRY_MAX_TIMES or (
+                    latest_status and JobStatus[latest_status].finished()
+                ):
+                    break
+                await asyncio.sleep(RETRY_SLEEP_SEC)
+                retry_count += 1
+            except KeyError:
+                pass
+
+        return latest_status
+
+    final_status = asyncio.run(_run())
 
     response = client.dataset_job_versions(job_id)
     if not response.ok:
         raise DataChainError(response.message)
 
     response_data = response.data
-    if response_data:
+    if response_data and response_data.get("dataset_versions"):
         dataset_versions = response_data.get("dataset_versions", [])
         print("\n\n>>>> Dataset versions created during the job:")
         for version in dataset_versions:
             print(f"  - {version.get('dataset_name')}@v{version.get('version')}")
     else:
-        print("
+        print("\n\nNo dataset versions created during the job.")
+
+    exit_code_by_status = {
+        "FAILED": 1,
+        "CANCELED": 2,
+    }
+    return exit_code_by_status.get(final_status.upper(), 0) if final_status else 0
 
 
 def create_job(
     query_file: str,
-    team_name:
-    env_file:
-    env:
-    workers:
-    files:
-    python_version:
-
+    team_name: str | None,
+    env_file: str | None = None,
+    env: list[str] | None = None,
+    workers: int | None = None,
+    files: list[str] | None = None,
+    python_version: str | None = None,
+    repository: str | None = None,
+    req: list[str] | None = None,
+    req_file: str | None = None,
+    priority: int | None = None,
+    cluster: str | None = None,
+    start_time: str | None = None,
+    cron: str | None = None,
+    no_wait: bool | None = False,
+    credentials_name: str | None = None,
 ):
     query_type = "PYTHON" if query_file.endswith(".py") else "SHELL"
     with open(query_file) as f:
@@ -276,6 +386,11 @@ def create_job(
     client = StudioClient(team=team_name)
     file_ids = upload_files(client, files) if files else []
 
+    # Parse start_time if provided
+    parsed_start_time = parse_start_time(start_time)
+    if cron and parsed_start_time is None:
+        parsed_start_time = datetime.now(timezone.utc).isoformat()
+
     response = client.create_job(
         query=query,
         query_type=query_type,
@@ -284,7 +399,13 @@ def create_job(
         query_name=os.path.basename(query_file),
         files=file_ids,
         python_version=python_version,
+        repository=repository,
         requirements=requirements,
+        priority=priority,
+        cluster=cluster,
+        start_time=parsed_start_time,
+        cron=cron,
+        credentials_name=credentials_name,
     )
     if not response.ok:
         raise DataChainError(response.message)
@@ -292,12 +413,17 @@ def create_job(
     if not response.data:
         raise DataChainError("Failed to create job")
 
-    job_id = response.data.get("
+    job_id = response.data.get("id")
+
+    if parsed_start_time or cron:
+        print(f"Job {job_id} is scheduled as a task in Studio.")
+        return 0
+
     print(f"Job {job_id} created")
-    print("Open the job in Studio at", response.data.get("
+    print("Open the job in Studio at", response.data.get("url"))
     print("=" * 40)
 
-    show_logs_from_client(client, job_id)
+    return 0 if no_wait else show_logs_from_client(client, job_id)
 
 
 def upload_files(client: StudioClient, files: list[str]) -> list[str]:
@@ -305,21 +431,19 @@ def upload_files(client: StudioClient, files: list[str]) -> list[str]:
     for file in files:
         file_name = os.path.basename(file)
         with open(file, "rb") as f:
-
-            response = client.upload_file(file_content, file_name)
+            response = client.upload_file(f, file_name)
         if not response.ok:
             raise DataChainError(response.message)
 
         if not response.data:
             raise DataChainError(f"Failed to upload file {file_name}")
 
-        file_id
-        if file_id:
+        if file_id := response.data.get("id"):
            file_ids.append(str(file_id))
    return file_ids
 
 
-def cancel_job(job_id: str, team_name: Optional[str]):
+def cancel_job(job_id: str, team_name: str | None):
     token = Config().read().get("studio", {}).get("token")
     if not token:
         raise DataChainError(
@@ -334,7 +458,32 @@ def cancel_job(job_id: str, team_name: Optional[str]):
     print(f"Job {job_id} canceled")
 
 
-def show_job_logs(job_id: str, team_name: Optional[str]):
+def list_jobs(status: str | None, team_name: str | None, limit: int):
+    client = StudioClient(team=team_name)
+    response = client.get_jobs(status, limit)
+    if not response.ok:
+        raise DataChainError(response.message)
+
+    jobs = response.data or []
+    if not jobs:
+        print("No jobs found")
+        return
+
+    rows = [
+        {
+            "ID": job.get("id"),
+            "Name": job.get("name"),
+            "Status": job.get("status"),
+            "Created at": job.get("created_at"),
+            "Created by": job.get("created_by"),
+        }
+        for job in jobs
+    ]
+
+    print(tabulate.tabulate(rows, headers="keys", tablefmt="grid"))
+
+
+def show_job_logs(job_id: str, team_name: str | None):
     token = Config().read().get("studio", {}).get("token")
     if not token:
         raise DataChainError(
@@ -342,4 +491,32 @@ def show_job_logs(job_id: str, team_name: Optional[str]):
     )
 
     client = StudioClient(team=team_name)
-    show_logs_from_client(client, job_id)
+    return show_logs_from_client(client, job_id)
+
+
+def list_clusters(team_name: str | None):
+    client = StudioClient(team=team_name)
+    response = client.get_clusters()
+    if not response.ok:
+        raise DataChainError(response.message)
+
+    clusters = response.data or []
+    if not clusters:
+        print("No clusters found")
+        return
+
+    rows = [
+        {
+            "ID": cluster.get("id"),
+            "Name": cluster.get("name"),
+            "Status": cluster.get("status"),
+            "Cloud Provider": cluster.get("cloud_provider"),
+            "Cloud Credentials": cluster.get("cloud_credentials"),
+            "Is Active": cluster.get("is_active"),
+            "Is Default": cluster.get("default"),
+            "Max Workers": cluster.get("max_workers"),
+        }
+        for cluster in clusters
+    ]
+
+    print(tabulate.tabulate(rows, headers="keys", tablefmt="grid"))
datachain/toolkit/split.py
CHANGED
@@ -1,7 +1,7 @@
 import random
-from typing import Optional
 
 from datachain import C, DataChain
+from datachain.lib.signal_schema import SignalResolvingError
 
 RESOLUTION = 2**31 - 1  # Maximum positive value for a 32-bit signed integer.
 
@@ -9,7 +9,7 @@ RESOLUTION = 2**31 - 1  # Maximum positive value for a 32-bit signed integer.
 def train_test_split(
     dc: DataChain,
     weights: list[float],
-    seed:
+    seed: int | None = None,
 ) -> list[DataChain]:
     """
     Splits a DataChain into multiple subsets based on the provided weights.
@@ -60,7 +60,10 @@ def train_test_split(
     ```
 
     Note:
-
+        Splits reuse the same best-effort shuffle used by `DataChain.shuffle`. Results
+        are typically repeatable, but earlier operations such as `merge`, `union`, or
+        custom SQL that reshuffle rows can change the outcome between runs. Add order by
+        stable keys first when you need strict reproducibility.
     """
     if len(weights) < 2:
         raise ValueError("Weights should have at least two elements")
@@ -69,16 +72,34 @@ def train_test_split(
 
     weights_normalized = [weight / sum(weights) for weight in weights]
 
+    try:
+        dc.signals_schema.resolve("sys.rand")
+    except SignalResolvingError:
+        dc = dc.persist()
+
     rand_col = C("sys.rand")
     if seed is not None:
         uniform_seed = random.Random(seed).randrange(1, RESOLUTION)  # noqa: S311
         rand_col = (rand_col % RESOLUTION) * uniform_seed  # type: ignore[assignment]
     rand_col = rand_col % RESOLUTION  # type: ignore[assignment]
 
-
-
-
-
-    )
-
-
+    boundaries: list[int] = [0]
+    cumulative = 0.0
+    for weight in weights_normalized[:-1]:
+        cumulative += weight
+        boundary = round(cumulative * RESOLUTION)
+        boundaries.append(min(boundary, RESOLUTION))
+    boundaries.append(RESOLUTION)
+
+    splits: list[DataChain] = []
+    last_index = len(weights_normalized) - 1
+    for index in range(len(weights_normalized)):
+        lower = boundaries[index]
+        if index == last_index:
+            condition = rand_col >= lower
+        else:
+            upper = boundaries[index + 1]
+            condition = (rand_col >= lower) & (rand_col < upper)
+        splits.append(dc.filter(condition))
+
+    return splits