datachain 0.14.2__py3-none-any.whl → 0.39.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- datachain/__init__.py +20 -0
- datachain/asyn.py +11 -12
- datachain/cache.py +7 -7
- datachain/catalog/__init__.py +2 -2
- datachain/catalog/catalog.py +621 -507
- datachain/catalog/dependency.py +164 -0
- datachain/catalog/loader.py +28 -18
- datachain/checkpoint.py +43 -0
- datachain/cli/__init__.py +24 -33
- datachain/cli/commands/__init__.py +1 -8
- datachain/cli/commands/datasets.py +83 -52
- datachain/cli/commands/ls.py +17 -17
- datachain/cli/commands/show.py +4 -4
- datachain/cli/parser/__init__.py +8 -74
- datachain/cli/parser/job.py +95 -3
- datachain/cli/parser/studio.py +11 -4
- datachain/cli/parser/utils.py +1 -2
- datachain/cli/utils.py +2 -15
- datachain/client/azure.py +4 -4
- datachain/client/fsspec.py +45 -28
- datachain/client/gcs.py +6 -6
- datachain/client/hf.py +29 -2
- datachain/client/http.py +157 -0
- datachain/client/local.py +15 -11
- datachain/client/s3.py +17 -9
- datachain/config.py +4 -8
- datachain/data_storage/db_engine.py +12 -6
- datachain/data_storage/job.py +5 -1
- datachain/data_storage/metastore.py +1252 -186
- datachain/data_storage/schema.py +58 -45
- datachain/data_storage/serializer.py +105 -15
- datachain/data_storage/sqlite.py +286 -127
- datachain/data_storage/warehouse.py +250 -113
- datachain/dataset.py +353 -148
- datachain/delta.py +391 -0
- datachain/diff/__init__.py +27 -29
- datachain/error.py +60 -0
- datachain/func/__init__.py +2 -1
- datachain/func/aggregate.py +66 -42
- datachain/func/array.py +242 -38
- datachain/func/base.py +7 -4
- datachain/func/conditional.py +110 -60
- datachain/func/func.py +96 -45
- datachain/func/numeric.py +55 -38
- datachain/func/path.py +32 -20
- datachain/func/random.py +2 -2
- datachain/func/string.py +67 -37
- datachain/func/window.py +7 -8
- datachain/hash_utils.py +123 -0
- datachain/job.py +11 -7
- datachain/json.py +138 -0
- datachain/lib/arrow.py +58 -22
- datachain/lib/audio.py +245 -0
- datachain/lib/clip.py +14 -13
- datachain/lib/convert/flatten.py +5 -3
- datachain/lib/convert/python_to_sql.py +6 -10
- datachain/lib/convert/sql_to_python.py +8 -0
- datachain/lib/convert/values_to_tuples.py +156 -51
- datachain/lib/data_model.py +42 -20
- datachain/lib/dataset_info.py +36 -8
- datachain/lib/dc/__init__.py +8 -2
- datachain/lib/dc/csv.py +25 -28
- datachain/lib/dc/database.py +398 -0
- datachain/lib/dc/datachain.py +1289 -425
- datachain/lib/dc/datasets.py +320 -38
- datachain/lib/dc/hf.py +38 -24
- datachain/lib/dc/json.py +29 -32
- datachain/lib/dc/listings.py +112 -8
- datachain/lib/dc/pandas.py +16 -12
- datachain/lib/dc/parquet.py +35 -23
- datachain/lib/dc/records.py +31 -23
- datachain/lib/dc/storage.py +154 -64
- datachain/lib/dc/storage_pattern.py +251 -0
- datachain/lib/dc/utils.py +24 -16
- datachain/lib/dc/values.py +8 -9
- datachain/lib/file.py +622 -89
- datachain/lib/hf.py +69 -39
- datachain/lib/image.py +14 -14
- datachain/lib/listing.py +14 -11
- datachain/lib/listing_info.py +1 -2
- datachain/lib/meta_formats.py +3 -4
- datachain/lib/model_store.py +39 -7
- datachain/lib/namespaces.py +125 -0
- datachain/lib/projects.py +130 -0
- datachain/lib/pytorch.py +32 -21
- datachain/lib/settings.py +192 -56
- datachain/lib/signal_schema.py +427 -104
- datachain/lib/tar.py +1 -2
- datachain/lib/text.py +8 -7
- datachain/lib/udf.py +164 -76
- datachain/lib/udf_signature.py +60 -35
- datachain/lib/utils.py +118 -4
- datachain/lib/video.py +17 -9
- datachain/lib/webdataset.py +61 -56
- datachain/lib/webdataset_laion.py +15 -16
- datachain/listing.py +22 -10
- datachain/model/bbox.py +3 -1
- datachain/model/ultralytics/bbox.py +16 -12
- datachain/model/ultralytics/pose.py +16 -12
- datachain/model/ultralytics/segment.py +16 -12
- datachain/namespace.py +84 -0
- datachain/node.py +6 -6
- datachain/nodes_thread_pool.py +0 -1
- datachain/plugins.py +24 -0
- datachain/project.py +78 -0
- datachain/query/batch.py +40 -41
- datachain/query/dataset.py +604 -322
- datachain/query/dispatch.py +261 -154
- datachain/query/metrics.py +4 -6
- datachain/query/params.py +2 -3
- datachain/query/queue.py +3 -12
- datachain/query/schema.py +11 -6
- datachain/query/session.py +200 -33
- datachain/query/udf.py +34 -2
- datachain/remote/studio.py +171 -69
- datachain/script_meta.py +12 -12
- datachain/semver.py +68 -0
- datachain/sql/__init__.py +2 -0
- datachain/sql/functions/array.py +33 -1
- datachain/sql/postgresql_dialect.py +9 -0
- datachain/sql/postgresql_types.py +21 -0
- datachain/sql/sqlite/__init__.py +5 -1
- datachain/sql/sqlite/base.py +102 -29
- datachain/sql/sqlite/types.py +8 -13
- datachain/sql/types.py +70 -15
- datachain/studio.py +223 -46
- datachain/toolkit/split.py +31 -10
- datachain/utils.py +101 -59
- {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/METADATA +77 -22
- datachain-0.39.0.dist-info/RECORD +173 -0
- {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/WHEEL +1 -1
- datachain/cli/commands/query.py +0 -53
- datachain/query/utils.py +0 -42
- datachain-0.14.2.dist-info/RECORD +0 -158
- {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/entry_points.txt +0 -0
- {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/licenses/LICENSE +0 -0
- {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/top_level.txt +0 -0
datachain/remote/studio.py
CHANGED

@@ -1,47 +1,64 @@
-import base64
 import json
 import logging
 import os
 from collections.abc import AsyncIterator, Iterable, Iterator
 from datetime import datetime, timedelta, timezone
 from struct import unpack
-from typing import (
-    Any,
-    Generic,
-    Optional,
-    TypeVar,
-)
+from typing import Any, BinaryIO, Generic, TypeVar
 from urllib.parse import urlparse, urlunparse
 
 import websockets
 from requests.exceptions import HTTPError, Timeout
 
 from datachain.config import Config
+from datachain.dataset import DatasetRecord
 from datachain.error import DataChainError
 from datachain.utils import STUDIO_URL, retry_with_backoff
 
 T = TypeVar("T")
-LsData = Optional[list[dict[str, Any]]]
-DatasetInfoData = Optional[dict[str, Any]]
-DatasetRowsData = Optional[Iterable[dict[str, Any]]]
-DatasetJobVersionsData = Optional[dict[str, Any]]
-DatasetExportStatus = Optional[dict[str, Any]]
-DatasetExportSignedUrls = Optional[list[str]]
-FileUploadData = Optional[dict[str, Any]]
-JobData = Optional[dict[str, Any]]
+LsData = list[dict[str, Any]] | None
+DatasetInfoData = dict[str, Any] | None
+DatasetRowsData = Iterable[dict[str, Any]] | None
+DatasetJobVersionsData = dict[str, Any] | None
+DatasetExportStatus = dict[str, Any] | None
+DatasetExportSignedUrls = list[str] | None
+FileUploadData = dict[str, Any] | None
+JobData = dict[str, Any] | None
+JobListData = list[dict[str, Any]]
+ClusterListData = list[dict[str, Any]]
 
 logger = logging.getLogger("datachain")
 
 DATASET_ROWS_CHUNK_SIZE = 8192
 
 
+def get_studio_env_variable(name: str) -> Any:
+    """
+    Get the value of a DataChain Studio environment variable.
+    It first checks for the variable prefixed with 'DATACHAIN_STUDIO_',
+    then checks for the deprecated 'DVC_STUDIO_' prefix.
+    If neither is set, it returns the provided default value.
+    """
+    if (value := os.environ.get(f"DATACHAIN_STUDIO_{name}")) is not None:
+        return value
+    if (value := os.environ.get(f"DVC_STUDIO_{name}")) is not None:  # deprecated
+        logger.warning(
+            "Environment variable 'DVC_STUDIO_%s' is deprecated, "
+            "use 'DATACHAIN_STUDIO_%s' instead.",
+            name,
+            name,
+        )
+        return value
+    return None
+
+
 def _is_server_error(status_code: int) -> bool:
     return str(status_code).startswith("5")
 
 
 def is_token_set() -> bool:
     return (
-        bool(os.environ.get("DVC_STUDIO_TOKEN"))
+        bool(get_studio_env_variable("TOKEN"))
         or Config().read().get("studio", {}).get("token") is not None
     )
 

@@ -56,10 +73,11 @@ def _parse_dates(obj: dict, date_fields: list[str]):
 
 
 class Response(Generic[T]):
-    def __init__(self, data: T, ok: bool, message: str) -> None:
+    def __init__(self, data: T, ok: bool, message: str, status: int) -> None:
         self.data = data
         self.ok = ok
         self.message = message
+        self.status = status
 
     def __repr__(self):
         return (

@@ -69,7 +87,7 @@ class Response(Generic[T]):
 
 
 class StudioClient:
-    def __init__(self, timeout: float = 3600.0, team: Optional[str] = None) -> None:
+    def __init__(self, timeout: float = 3600.0, team: str | None = None) -> None:
         self._check_dependencies()
         self.timeout = timeout
         self._config = None

@@ -77,12 +95,12 @@
 
     @property
     def token(self) -> str:
-        token = os.environ.get("DVC_STUDIO_TOKEN") or self.config.get("token")
+        token = get_studio_env_variable("TOKEN") or self.config.get("token")
 
         if not token:
             raise DataChainError(
                 "Studio token is not set. Use `datachain auth login` "
-                "or environment variable `DVC_STUDIO_TOKEN` to set it."
+                "or environment variable `DATACHAIN_STUDIO_TOKEN` to set it."
             )
 
         return token

@@ -90,8 +108,8 @@
     @property
     def url(self) -> str:
         return (
-            os.environ.get("DVC_STUDIO_URL") or self.config.get("url") or STUDIO_URL
-        ) + "/api"
+            get_studio_env_variable("URL") or self.config.get("url") or STUDIO_URL
+        ).rstrip("/") + "/api"
 
     @property
     def config(self) -> dict:

@@ -106,13 +124,13 @@
         return self._team
 
     def _get_team(self) -> str:
-        team = os.environ.get("DVC_STUDIO_TEAM") or self.config.get("team")
+        team = get_studio_env_variable("TEAM") or self.config.get("team")
 
         if not team:
             raise DataChainError(
                 "Studio team is not set. "
                 "Use `datachain auth team <team_name>` "
-                "or environment variable `DVC_STUDIO_TEAM` to set it. "
+                "or environment variable `DATACHAIN_STUDIO_TEAM` to set it. "
                 "You can also set `studio.team` in the config file."
             )
 

@@ -130,7 +148,7 @@
             ) from None
 
     def _send_request_msgpack(
-        self, route: str, data: dict[str, Any], method: Optional[str] = "POST"
+        self, route: str, data: dict[str, Any], method: str | None = "POST"
     ) -> Response[Any]:
         import msgpack
         import requests

@@ -164,11 +182,11 @@
             message = "Indexing in progress"
         else:
             message = content.get("message", "")
-        return Response(response_data, ok, message)
+        return Response(response_data, ok, message, response.status_code)
 
     @retry_with_backoff(retries=3, errors=(HTTPError, Timeout))
     def _send_request(
-        self, route: str, data: dict[str, Any], method: Optional[str] = "POST"
+        self, route: str, data: dict[str, Any], method: str | None = "POST"
     ) -> Response[Any]:
         """
         Function that communicate Studio API.

@@ -214,7 +232,46 @@
         else:
             message = ""
 
-        return Response(data, ok, message)
+        return Response(data, ok, message, response.status_code)
+
+    def _send_multipart_request(
+        self, route: str, files: dict[str, Any], params: dict[str, Any] | None = None
+    ) -> Response[Any]:
+        """
+        Function that communicates with Studio API using multipart/form-data.
+        It will raise an exception, and try to retry, if 5xx status code is
+        returned, or if Timeout exceptions is thrown from the requests lib
+        """
+        import requests
+
+        # Add team_name to params
+        request_params = {**(params or {}), "team_name": self.team}
+
+        response = requests.post(
+            url=f"{self.url}/{route}",
+            files=files,
+            params=request_params,
+            headers={
+                "Authorization": f"token {self.token}",
+            },
+            timeout=self.timeout,
+        )
+
+        ok = response.ok
+        try:
+            data = json.loads(response.content.decode("utf-8"))
+        except json.decoder.JSONDecodeError:
+            data = {}
+
+        if not ok:
+            if response.status_code == 403:
+                message = f"Not authorized for the team {self.team}"
+            else:
+                message = data.get("message", "")
+        else:
+            message = ""
+
+        return Response(data, ok, message, response.status_code)
 
     @staticmethod
     def _unpacker_hook(code, data):

@@ -282,21 +339,27 @@
         response = self._send_request_msgpack("datachain/ls", {"source": path})
         yield path, response
 
-    def ls_datasets(self) -> Response[LsData]:
-        return self._send_request("datachain/datasets", {}, method="GET")
+    def ls_datasets(self, prefix: str | None = None) -> Response[LsData]:
+        return self._send_request(
+            "datachain/datasets", {"prefix": prefix}, method="GET"
+        )
 
     def edit_dataset(
         self,
         name: str,
-        new_name: Optional[str] = None,
-        description: Optional[str] = None,
-        labels: Optional[list[str]] = None,
+        namespace: str,
+        project: str,
+        new_name: str | None = None,
+        description: str | None = None,
+        attrs: list[str] | None = None,
    ) -> Response[DatasetInfoData]:
         body = {
             "new_name": new_name,
-            "dataset_name": name,
+            "name": name,
+            "namespace": namespace,
+            "project": project,
             "description": description,
-            "labels": labels,
+            "attrs": attrs,
         }
 
         return self._send_request(

@@ -307,44 +370,44 @@
     def rm_dataset(
         self,
         name: str,
-        version: Optional[int] = None,
-        force: Optional[bool] = False,
+        namespace: str,
+        project: str,
+        version: str | None = None,
+        force: bool | None = False,
     ) -> Response[DatasetInfoData]:
         return self._send_request(
             "datachain/datasets",
             {
-                "dataset_name": name,
-                "dataset_version": version,
+                "name": name,
+                "namespace": namespace,
+                "project": project,
+                "version": version,
                 "force": force,
             },
             method="DELETE",
         )
 
-    def dataset_info(self, name: str) -> Response[DatasetInfoData]:
+    def dataset_info(
+        self, namespace: str, project: str, name: str
+    ) -> Response[DatasetInfoData]:
         def _parse_dataset_info(dataset_info):
             _parse_dates(dataset_info, ["created_at", "finished_at"])
             for version in dataset_info.get("versions"):
                 _parse_dates(version, ["created_at"])
+            _parse_dates(dataset_info.get("project"), ["created_at"])
+            _parse_dates(dataset_info.get("project").get("namespace"), ["created_at"])
 
             return dataset_info
 
         response = self._send_request(
-            "datachain/datasets/info", {"dataset_name": name}, method="GET"
+            "datachain/datasets/info",
+            {"namespace": namespace, "project": project, "name": name},
+            method="GET",
         )
         if response.ok:
             response.data = _parse_dataset_info(response.data)
         return response
 
-    def dataset_rows_chunk(
-        self, name: str, version: int, offset: int
-    ) -> Response[DatasetRowsData]:
-        req_data = {"dataset_name": name, "dataset_version": version}
-        return self._send_request_msgpack(
-            "datachain/datasets/rows",
-            {**req_data, "offset": offset, "limit": DATASET_ROWS_CHUNK_SIZE},
-            method="GET",
-        )
-
     def dataset_job_versions(self, job_id: str) -> Response[DatasetJobVersionsData]:
         return self._send_request(
             "datachain/datasets/dataset_job_versions",

@@ -353,40 +416,57 @@
         )
 
     def export_dataset_table(
-        self, name: str, version: int
+        self, dataset: DatasetRecord, version: str
     ) -> Response[DatasetExportSignedUrls]:
         return self._send_request(
             "datachain/datasets/export",
-            {"dataset_name": name, "dataset_version": version},
+            {
+                "namespace": dataset.project.namespace.name,
+                "project": dataset.project.name,
+                "name": dataset.name,
+                "version": version,
+            },
             method="GET",
         )
 
     def dataset_export_status(
-        self, name: str, version: int
+        self, dataset: DatasetRecord, version: str
     ) -> Response[DatasetExportStatus]:
         return self._send_request(
             "datachain/datasets/export-status",
-            {"dataset_name": name, "dataset_version": version},
+            {
+                "namespace": dataset.project.namespace.name,
+                "project": dataset.project.name,
+                "name": dataset.name,
+                "version": version,
+            },
             method="GET",
         )
 
-    def upload_file(self, file: bytes, file_name: str) -> Response[FileUploadData]:
-        data = {
-            "file_content": base64.b64encode(file).decode("utf-8"),
-            "file_name": file_name,
-        }
-        return self._send_request("datachain/datasets/upload-file", data)
+    def upload_file(
+        self, file_obj: BinaryIO, file_name: str
+    ) -> Response[FileUploadData]:
+        # Prepare multipart form data
+        files = {"file": (file_name, file_obj, "application/octet-stream")}
+
+        return self._send_multipart_request("datachain/jobs/files", files)
 
     def create_job(
         self,
         query: str,
         query_type: str,
-        environment: Optional[str] = None,
-        workers: Optional[int] = None,
-        query_name: Optional[str] = None,
-        files: Optional[list[str]] = None,
-        python_version: Optional[str] = None,
-        requirements: Optional[str] = None,
+        environment: str | None = None,
+        workers: int | None = None,
+        query_name: str | None = None,
+        files: list[str] | None = None,
+        python_version: str | None = None,
+        requirements: str | None = None,
+        repository: str | None = None,
+        priority: int | None = None,
+        cluster: str | None = None,
+        start_time: str | None = None,
+        cron: str | None = None,
+        credentials_name: str | None = None,
     ) -> Response[JobData]:
         data = {
             "query": query,

@@ -397,12 +477,34 @@
             "files": files,
             "python_version": python_version,
             "requirements": requirements,
+            "repository": repository,
+            "priority": priority,
+            "compute_cluster_name": cluster,
+            "start_after": start_time,
+            "cron_expression": cron,
+            "credentials_name": credentials_name,
         }
-        return self._send_request("datachain/job", data)
+        return self._send_request("datachain/jobs/", data)
+
+    def get_jobs(
+        self,
+        status: str | None = None,
+        limit: int = 20,
+        job_id: str | None = None,
+    ) -> Response[JobListData]:
+        params: dict[str, Any] = {"limit": limit}
+        if status is not None:
+            params["status"] = status
+        if job_id is not None:
+            params["job_id"] = job_id
+        return self._send_request("datachain/jobs/", params, method="GET")
 
     def cancel_job(
         self,
         job_id: str,
     ) -> Response[JobData]:
-        url = f"datachain/job/{job_id}/cancel"
+        url = f"datachain/jobs/{job_id}/cancel"
         return self._send_request(url, data={}, method="POST")
+
+    def get_clusters(self) -> Response[ClusterListData]:
+        return self._send_request("datachain/clusters/", {}, method="GET")
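A minimal sketch of the surface added above: the `DATACHAIN_STUDIO_*` / deprecated `DVC_STUDIO_*` environment-variable fallback, the multipart `upload_file`, and the HTTP `status` now carried on every `Response`. A configured token, team, and reachable Studio API are assumed; nothing here is verified by the diff itself.

    import io
    import os

    from datachain.remote.studio import StudioClient, get_studio_env_variable

    # The new prefix wins when both are set; the deprecated one still
    # resolves but logs a warning (see get_studio_env_variable above).
    os.environ["DVC_STUDIO_TEAM"] = "old-team"
    assert get_studio_env_variable("TEAM") == "old-team"
    os.environ["DATACHAIN_STUDIO_TEAM"] = "new-team"
    assert get_studio_env_variable("TEAM") == "new-team"

    # upload_file now streams a file object as multipart/form-data to
    # "datachain/jobs/files" instead of embedding base64 content in JSON.
    client = StudioClient()  # assumes DATACHAIN_STUDIO_TOKEN (or config) is set
    response = client.upload_file(io.BytesIO(b"print('hi')"), "script.py")
    print(response.status, response.ok, response.message)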
datachain/script_meta.py
CHANGED

@@ -1,6 +1,6 @@
 import re
 from dataclasses import dataclass
-from typing import Any, Optional
+from typing import Any
 
 try:
     import tomllib

@@ -59,23 +59,23 @@ class ScriptConfig:
 
     """
 
-    python_version: Optional[str]
+    python_version: str | None
     dependencies: list[str]
     attachments: dict[str, str]
     params: dict[str, Any]
     inputs: dict[str, Any]
     outputs: dict[str, Any]
-    num_workers: Optional[int] = None
+    num_workers: int | None = None
 
     def __init__(
         self,
-        python_version: Optional[str] = None,
-        dependencies: Optional[list[str]] = None,
-        attachments: Optional[dict[str, str]] = None,
-        params: Optional[dict[str, Any]] = None,
-        inputs: Optional[dict[str, Any]] = None,
-        outputs: Optional[dict[str, Any]] = None,
-        num_workers: Optional[int] = None,
+        python_version: str | None = None,
+        dependencies: list[str] | None = None,
+        attachments: dict[str, str] | None = None,
+        params: dict[str, Any] | None = None,
+        inputs: dict[str, Any] | None = None,
+        outputs: dict[str, Any] | None = None,
+        num_workers: int | None = None,
     ):
         self.python_version = python_version
         self.dependencies = dependencies or []

@@ -98,7 +98,7 @@ class ScriptConfig:
         return self.attachments.get(name, default)
 
     @staticmethod
-    def read(script: str) -> Optional[dict]:
+    def read(script: str) -> dict | None:
         """Converts inline script metadata to dict with all found data"""
         regex = (
             r"(?m)^# \/\/\/ (?P<type>[a-zA-Z0-9-]+)[ \t]*$[\r\n|\r|\n]"

@@ -119,7 +119,7 @@ class ScriptConfig:
         return None
 
     @staticmethod
-    def parse(script: str) -> Optional["ScriptConfig"]:
+    def parse(script: str) -> "ScriptConfig | None":
         """
         Method that is parsing inline script metadata from datachain script and
         instantiating ScriptConfig class with found data. If no inline metadata is
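The `read`/`parse` pair consumes PEP 723-style inline metadata comment blocks, per the regex above. A sketch of the input format; the TOML keys shown are illustrative only, since the key names DataChain actually consumes are not visible in this hunk:

    from datachain.script_meta import ScriptConfig

    script = '# /// script\n# dependencies = ["numpy"]\n# ///\n'

    meta = ScriptConfig.read(script)    # dict with all found data, or None
    config = ScriptConfig.parse(script)  # ScriptConfig instance, or None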
datachain/semver.py
ADDED

@@ -0,0 +1,68 @@
+# Maximum version number for semver (major.minor.patch) is 999999.999999.999999
+# this number was chosen because value("999999.999999.999999") < 2**63 - 1
+MAX_VERSION_NUMBER = 999_999
+
+
+def parse(version: str) -> tuple[int, int, int]:
+    """Parsing semver into 3 integers: major, minor, patch"""
+    validate(version)
+    parts = version.split(".")
+    return int(parts[0]), int(parts[1]), int(parts[2])
+
+
+def validate(version: str) -> None:
+    """
+    Raises exception if version doesn't have valid semver format which is:
+    <major>.<minor>.<patch> or one of version parts is not positive integer
+    """
+    error_message = (
+        "Invalid version. It should be in format: <major>.<minor>.<patch> where"
+        " each version part is positive integer"
+    )
+    parts = version.split(".")
+    if len(parts) != 3:
+        raise ValueError(error_message)
+    for part in parts:
+        try:
+            val = int(part)
+            assert 0 <= val <= MAX_VERSION_NUMBER
+        except (ValueError, AssertionError):
+            raise ValueError(error_message) from None
+
+
+def create(major: int = 0, minor: int = 0, patch: int = 0) -> str:
+    """Creates new semver from 3 integers: major, minor and patch"""
+    if not (
+        0 <= major <= MAX_VERSION_NUMBER
+        and 0 <= minor <= MAX_VERSION_NUMBER
+        and 0 <= patch <= MAX_VERSION_NUMBER
+    ):
+        raise ValueError("Major, minor and patch must be greater or equal to zero")
+
+    return ".".join([str(major), str(minor), str(patch)])
+
+
+def value(version: str) -> int:
+    """
+    Calculate integer value of a version. This is useful when comparing two versions.
+    """
+    major, minor, patch = parse(version)
+    limit = MAX_VERSION_NUMBER + 1
+    return major * (limit**2) + minor * limit + patch
+
+
+def compare(v1: str, v2: str) -> int:
+    """
+    Compares 2 versions and returns:
+        -1 if v1 < v2
+        0 if v1 == v2
+        1 if v1 > v2
+    """
+    v1_val = value(v1)
+    v2_val = value(v2)
+
+    if v1_val < v2_val:
+        return -1
+    if v1_val > v2_val:
+        return 1
+    return 0
datachain/sql/__init__.py
CHANGED

@@ -1,6 +1,8 @@
 from sqlalchemy.sql.elements import literal
 from sqlalchemy.sql.expression import column
 
+# Import PostgreSQL dialect registration (registers PostgreSQL type converter)
+from . import postgresql_dialect  # noqa: F401
 from .default import setup as default_setup
 from .selectable import select, values
 
datachain/sql/functions/array.py
CHANGED

@@ -1,6 +1,6 @@
 from sqlalchemy.sql.functions import GenericFunction
 
-from datachain.sql.types import Boolean, Float, Int64
+from datachain.sql.types import Boolean, Float, Int64, String
 from datachain.sql.utils import compiler_not_implemented
 
 

@@ -48,6 +48,37 @@ class contains(GenericFunction):  # noqa: N801
     inherit_cache = True
 
 
+class slice(GenericFunction):  # noqa: N801
+    """
+    Returns a slice of the array.
+    """
+
+    package = "array"
+    name = "slice"
+    inherit_cache = True
+
+
+class join(GenericFunction):  # noqa: N801
+    """
+    Returns the concatenation of the array elements.
+    """
+
+    type = String()
+    package = "array"
+    name = "join"
+    inherit_cache = True
+
+
+class get_element(GenericFunction):  # noqa: N801
+    """
+    Returns the element at the given index in the array.
+    """
+
+    package = "array"
+    name = "get_element"
+    inherit_cache = True
+
+
 class sip_hash_64(GenericFunction):  # noqa: N801
     """
     Computes the SipHash-64 hash of the array.

@@ -63,4 +94,5 @@ compiler_not_implemented(cosine_distance)
 compiler_not_implemented(euclidean_distance)
 compiler_not_implemented(length)
 compiler_not_implemented(contains)
+compiler_not_implemented(get_element)
 compiler_not_implemented(sip_hash_64)
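`slice`, `join`, and `get_element` are SQLAlchemy `GenericFunction` subclasses registered under the `array` package, so once this module is imported they are reachable through `sqlalchemy.func`. A construction-only sketch; how each renders to SQL is dialect-specific (for SQLite, presumably wired up in `datachain/sql/sqlite/base.py`, which this release also changes), and `get_element` deliberately fails to compile on dialects that don't implement it:

    import sqlalchemy as sa

    import datachain.sql.functions.array  # noqa: F401  # registers the functions

    tags = sa.column("tags")
    joined = sa.func.array.join(tags, ",")      # carries a String() result type
    first = sa.func.array.get_element(tags, 0)  # needs a dialect-level compiler
    window = sa.func.array.slice(tags, 0, 2)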
datachain/sql/postgresql_dialect.py
ADDED

@@ -0,0 +1,9 @@
+"""
+PostgreSQL dialect registration for DataChain.
+"""
+
+from datachain.sql.postgresql_types import PostgreSQLTypeConverter
+from datachain.sql.types import register_backend_types
+
+# Register PostgreSQL type converter
+register_backend_types("postgresql", PostgreSQLTypeConverter())
datachain/sql/postgresql_types.py
ADDED

@@ -0,0 +1,21 @@
+"""
+PostgreSQL-specific type converter for DataChain.
+
+Handles PostgreSQL-specific type mappings that differ from the default dialect.
+"""
+
+from sqlalchemy.dialects import postgresql
+
+from datachain.sql.types import TypeConverter
+
+
+class PostgreSQLTypeConverter(TypeConverter):
+    """PostgreSQL-specific type converter."""
+
+    def datetime(self):
+        """PostgreSQL uses TIMESTAMP WITH TIME ZONE to preserve timezone information."""
+        return postgresql.TIMESTAMP(timezone=True)
+
+    def json(self):
+        """PostgreSQL uses JSONB for better performance and query capabilities."""
+        return postgresql.JSONB()