datachain 0.30.5__py3-none-any.whl → 0.39.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (119)
  1. datachain/__init__.py +4 -0
  2. datachain/asyn.py +11 -12
  3. datachain/cache.py +5 -5
  4. datachain/catalog/__init__.py +0 -2
  5. datachain/catalog/catalog.py +276 -354
  6. datachain/catalog/dependency.py +164 -0
  7. datachain/catalog/loader.py +8 -3
  8. datachain/checkpoint.py +43 -0
  9. datachain/cli/__init__.py +10 -17
  10. datachain/cli/commands/__init__.py +1 -8
  11. datachain/cli/commands/datasets.py +42 -27
  12. datachain/cli/commands/ls.py +15 -15
  13. datachain/cli/commands/show.py +2 -2
  14. datachain/cli/parser/__init__.py +3 -43
  15. datachain/cli/parser/job.py +1 -1
  16. datachain/cli/parser/utils.py +1 -2
  17. datachain/cli/utils.py +2 -15
  18. datachain/client/azure.py +2 -2
  19. datachain/client/fsspec.py +34 -23
  20. datachain/client/gcs.py +3 -3
  21. datachain/client/http.py +157 -0
  22. datachain/client/local.py +11 -7
  23. datachain/client/s3.py +3 -3
  24. datachain/config.py +4 -8
  25. datachain/data_storage/db_engine.py +12 -6
  26. datachain/data_storage/job.py +2 -0
  27. datachain/data_storage/metastore.py +716 -137
  28. datachain/data_storage/schema.py +20 -27
  29. datachain/data_storage/serializer.py +105 -15
  30. datachain/data_storage/sqlite.py +114 -114
  31. datachain/data_storage/warehouse.py +140 -48
  32. datachain/dataset.py +109 -89
  33. datachain/delta.py +117 -42
  34. datachain/diff/__init__.py +25 -33
  35. datachain/error.py +24 -0
  36. datachain/func/aggregate.py +9 -11
  37. datachain/func/array.py +12 -12
  38. datachain/func/base.py +7 -4
  39. datachain/func/conditional.py +9 -13
  40. datachain/func/func.py +63 -45
  41. datachain/func/numeric.py +5 -7
  42. datachain/func/string.py +2 -2
  43. datachain/hash_utils.py +123 -0
  44. datachain/job.py +11 -7
  45. datachain/json.py +138 -0
  46. datachain/lib/arrow.py +18 -15
  47. datachain/lib/audio.py +60 -59
  48. datachain/lib/clip.py +14 -13
  49. datachain/lib/convert/python_to_sql.py +6 -10
  50. datachain/lib/convert/values_to_tuples.py +151 -53
  51. datachain/lib/data_model.py +23 -19
  52. datachain/lib/dataset_info.py +7 -7
  53. datachain/lib/dc/__init__.py +2 -1
  54. datachain/lib/dc/csv.py +22 -26
  55. datachain/lib/dc/database.py +37 -34
  56. datachain/lib/dc/datachain.py +518 -324
  57. datachain/lib/dc/datasets.py +38 -30
  58. datachain/lib/dc/hf.py +16 -20
  59. datachain/lib/dc/json.py +17 -18
  60. datachain/lib/dc/listings.py +5 -8
  61. datachain/lib/dc/pandas.py +3 -6
  62. datachain/lib/dc/parquet.py +33 -21
  63. datachain/lib/dc/records.py +9 -13
  64. datachain/lib/dc/storage.py +103 -65
  65. datachain/lib/dc/storage_pattern.py +251 -0
  66. datachain/lib/dc/utils.py +17 -14
  67. datachain/lib/dc/values.py +3 -6
  68. datachain/lib/file.py +187 -50
  69. datachain/lib/hf.py +7 -5
  70. datachain/lib/image.py +13 -13
  71. datachain/lib/listing.py +5 -5
  72. datachain/lib/listing_info.py +1 -2
  73. datachain/lib/meta_formats.py +2 -3
  74. datachain/lib/model_store.py +20 -8
  75. datachain/lib/namespaces.py +59 -7
  76. datachain/lib/projects.py +51 -9
  77. datachain/lib/pytorch.py +31 -23
  78. datachain/lib/settings.py +188 -85
  79. datachain/lib/signal_schema.py +302 -64
  80. datachain/lib/text.py +8 -7
  81. datachain/lib/udf.py +103 -63
  82. datachain/lib/udf_signature.py +59 -34
  83. datachain/lib/utils.py +20 -0
  84. datachain/lib/video.py +3 -4
  85. datachain/lib/webdataset.py +31 -36
  86. datachain/lib/webdataset_laion.py +15 -16
  87. datachain/listing.py +12 -5
  88. datachain/model/bbox.py +3 -1
  89. datachain/namespace.py +22 -3
  90. datachain/node.py +6 -6
  91. datachain/nodes_thread_pool.py +0 -1
  92. datachain/plugins.py +24 -0
  93. datachain/project.py +4 -4
  94. datachain/query/batch.py +10 -12
  95. datachain/query/dataset.py +376 -194
  96. datachain/query/dispatch.py +112 -84
  97. datachain/query/metrics.py +3 -4
  98. datachain/query/params.py +2 -3
  99. datachain/query/queue.py +2 -1
  100. datachain/query/schema.py +7 -6
  101. datachain/query/session.py +190 -33
  102. datachain/query/udf.py +9 -6
  103. datachain/remote/studio.py +90 -53
  104. datachain/script_meta.py +12 -12
  105. datachain/sql/sqlite/base.py +37 -25
  106. datachain/sql/sqlite/types.py +1 -1
  107. datachain/sql/types.py +36 -5
  108. datachain/studio.py +49 -40
  109. datachain/toolkit/split.py +31 -10
  110. datachain/utils.py +39 -48
  111. {datachain-0.30.5.dist-info → datachain-0.39.0.dist-info}/METADATA +26 -38
  112. datachain-0.39.0.dist-info/RECORD +173 -0
  113. datachain/cli/commands/query.py +0 -54
  114. datachain/query/utils.py +0 -36
  115. datachain-0.30.5.dist-info/RECORD +0 -168
  116. {datachain-0.30.5.dist-info → datachain-0.39.0.dist-info}/WHEEL +0 -0
  117. {datachain-0.30.5.dist-info → datachain-0.39.0.dist-info}/entry_points.txt +0 -0
  118. {datachain-0.30.5.dist-info → datachain-0.39.0.dist-info}/licenses/LICENSE +0 -0
  119. {datachain-0.30.5.dist-info → datachain-0.39.0.dist-info}/top_level.txt +0 -0
datachain/query/session.py CHANGED
@@ -1,21 +1,37 @@
  import atexit
- import gc
  import logging
+ import os
  import re
  import sys
- from typing import TYPE_CHECKING, ClassVar, Optional
+ import traceback
+ from collections.abc import Callable
+ from typing import TYPE_CHECKING, ClassVar
  from uuid import uuid4
+ from weakref import WeakSet

  from datachain.catalog import get_catalog
- from datachain.error import TableMissingError
+ from datachain.data_storage import JobQueryType, JobStatus
+ from datachain.error import JobNotFoundError, TableMissingError

  if TYPE_CHECKING:
      from datachain.catalog import Catalog
-     from datachain.dataset import DatasetRecord
+     from datachain.job import Job

  logger = logging.getLogger("datachain")


+ def is_script_run() -> bool:
+     """
+     Returns True if this was ran as python script, e.g python my_script.py.
+     Otherwise (if interactive or module run) returns False.
+     """
+     try:
+         argv0 = sys.argv[0]
+     except (IndexError, AttributeError):
+         return False
+     return bool(argv0) and argv0 not in ("-c", "-m", "ipython")
+
+
  class Session:
      """
      Session is a context that keeps track of temporary DataChain datasets for a proper
@@ -39,10 +55,18 @@ class Session:
          catalog (Catalog): Catalog object.
      """

-     GLOBAL_SESSION_CTX: Optional["Session"] = None
+     GLOBAL_SESSION_CTX: "Session | None" = None
      SESSION_CONTEXTS: ClassVar[list["Session"]] = []
+     _ALL_SESSIONS: ClassVar[WeakSet["Session"]] = WeakSet()
      ORIGINAL_EXCEPT_HOOK = None

+     # Job management - class-level to ensure one job per process
+     _CURRENT_JOB: ClassVar["Job | None"] = None
+     _JOB_STATUS: ClassVar[JobStatus | None] = None
+     _OWNS_JOB: ClassVar[bool | None] = None
+     _JOB_HOOKS_REGISTERED: ClassVar[bool] = False
+     _JOB_FINALIZE_HOOK: ClassVar[Callable[[], None] | None] = None
+
      DATASET_PREFIX = "session_"
      GLOBAL_SESSION_NAME = "global"
      SESSION_UUID_LEN = 6
@@ -51,8 +75,8 @@ class Session:
      def __init__(
          self,
          name="",
-         catalog: Optional["Catalog"] = None,
-         client_config: Optional[dict] = None,
+         catalog: "Catalog | None" = None,
+         client_config: dict | None = None,
          in_memory: bool = False,
      ):
          if re.match(r"^[0-9a-zA-Z]*$", name) is None:
@@ -69,7 +93,7 @@ class Session:
          self.catalog = catalog or get_catalog(
              client_config=client_config, in_memory=in_memory
          )
-         self.dataset_versions: list[tuple[DatasetRecord, str, bool]] = []
+         Session._ALL_SESSIONS.add(self)

      def __enter__(self):
          # Push the current context onto the stack
@@ -78,9 +102,8 @@ class Session:
          return self

      def __exit__(self, exc_type, exc_val, exc_tb):
-         if exc_type:
-             self._cleanup_created_versions()
-
+         # Don't cleanup created versions on exception
+         # Datasets should persist even if the session fails
          self._cleanup_temp_datasets()
          if self.is_new_catalog:
              self.catalog.metastore.close_on_exit()
@@ -88,11 +111,116 @@ class Session:

          if Session.SESSION_CONTEXTS:
              Session.SESSION_CONTEXTS.pop()
+         Session._ALL_SESSIONS.discard(self)

-     def add_dataset_version(
-         self, dataset: "DatasetRecord", version: str, listing: bool = False
-     ) -> None:
-         self.dataset_versions.append((dataset, version, listing))
+     def get_or_create_job(self) -> "Job":
+         """
+         Get or create a Job for this process.
+
+         Returns:
+             Job: The active Job instance.
+
+         Behavior:
+             - If a job already exists, it is returned.
+             - If ``DATACHAIN_JOB_ID`` is set, the corresponding job is fetched.
+             - Otherwise, a new job is created:
+                 * Name = absolute path to the Python script.
+                 * Query = empty string.
+                 * Parent = last job with the same name, if available.
+                 * Status = "running".
+               Exit hooks are registered to finalize the job.
+
+         Note:
+             Job is shared across all Session instances to ensure one job per process.
+         """
+         if Session._CURRENT_JOB:
+             return Session._CURRENT_JOB
+
+         if env_job_id := os.getenv("DATACHAIN_JOB_ID"):
+             # SaaS run: just fetch existing job
+             Session._CURRENT_JOB = self.catalog.metastore.get_job(env_job_id)
+             if not Session._CURRENT_JOB:
+                 raise JobNotFoundError(
+                     f"Job {env_job_id} from DATACHAIN_JOB_ID env not found"
+                 )
+             Session._OWNS_JOB = False
+         else:
+             # Local run: create new job
+             if is_script_run():
+                 script = os.path.abspath(sys.argv[0])
+             else:
+                 # Interactive session or module run - use unique name to avoid
+                 # linking unrelated sessions
+                 script = str(uuid4())
+             python_version = f"{sys.version_info.major}.{sys.version_info.minor}"
+
+             # try to find the parent job
+             parent = self.catalog.metastore.get_last_job_by_name(script)
+
+             job_id = self.catalog.metastore.create_job(
+                 name=script,
+                 query="",
+                 query_type=JobQueryType.PYTHON,
+                 status=JobStatus.RUNNING,
+                 python_version=python_version,
+                 parent_job_id=parent.id if parent else None,
+             )
+             Session._CURRENT_JOB = self.catalog.metastore.get_job(job_id)
+             Session._OWNS_JOB = True
+             Session._JOB_STATUS = JobStatus.RUNNING
+
+             # register cleanup hooks only once
+             if not Session._JOB_HOOKS_REGISTERED:
+
+                 def _finalize_success_hook() -> None:
+                     self._finalize_job_success()
+
+                 Session._JOB_FINALIZE_HOOK = _finalize_success_hook
+                 atexit.register(Session._JOB_FINALIZE_HOOK)
+                 Session._JOB_HOOKS_REGISTERED = True
+
+         assert Session._CURRENT_JOB is not None
+         return Session._CURRENT_JOB
+
+     def _finalize_job_success(self):
+         """Mark the current job as completed."""
+         if (
+             Session._CURRENT_JOB
+             and Session._OWNS_JOB
+             and Session._JOB_STATUS == JobStatus.RUNNING
+         ):
+             self.catalog.metastore.set_job_status(
+                 Session._CURRENT_JOB.id, JobStatus.COMPLETE
+             )
+             Session._JOB_STATUS = JobStatus.COMPLETE
+
+     def _finalize_job_as_canceled(self):
+         """Mark the current job as canceled."""
+         if (
+             Session._CURRENT_JOB
+             and Session._OWNS_JOB
+             and Session._JOB_STATUS == JobStatus.RUNNING
+         ):
+             self.catalog.metastore.set_job_status(
+                 Session._CURRENT_JOB.id, JobStatus.CANCELED
+             )
+             Session._JOB_STATUS = JobStatus.CANCELED
+
+     def _finalize_job_as_failed(self, exc_type, exc_value, tb):
+         """Mark the current job as failed with error details."""
+         if (
+             Session._CURRENT_JOB
+             and Session._OWNS_JOB
+             and Session._JOB_STATUS == JobStatus.RUNNING
+         ):
+             error_stack = "".join(traceback.format_exception(exc_type, exc_value, tb))
+             self.catalog.metastore.set_job_status(
+                 Session._CURRENT_JOB.id,
+                 JobStatus.FAILED,
+                 error_message=str(exc_value),
+                 error_stack=error_stack,
+             )
+             Session._JOB_STATUS = JobStatus.FAILED

      def generate_temp_dataset_name(self) -> str:
          return self.get_temp_prefix() + uuid4().hex[: self.TEMP_TABLE_UUID_LEN]
@@ -113,22 +241,12 @@ class Session:
          except TableMissingError:
              pass

-     def _cleanup_created_versions(self) -> None:
-         if not self.dataset_versions:
-             return
-
-         for dataset, version, listing in self.dataset_versions:
-             if not listing:
-                 self.catalog.remove_dataset_version(dataset, version)
-
-         self.dataset_versions.clear()
-
      @classmethod
      def get(
          cls,
-         session: Optional["Session"] = None,
-         catalog: Optional["Catalog"] = None,
-         client_config: Optional[dict] = None,
+         session: "Session | None" = None,
+         catalog: "Catalog | None" = None,
+         client_config: dict | None = None,
          in_memory: bool = False,
      ) -> "Session":
          """Creates a Session() object from a catalog.
@@ -173,33 +291,72 @@ class Session:

      @staticmethod
      def except_hook(exc_type, exc_value, exc_traceback):
-         Session.GLOBAL_SESSION_CTX.__exit__(exc_type, exc_value, exc_traceback)
+         if Session.GLOBAL_SESSION_CTX:
+             # Handle KeyboardInterrupt specially - mark as canceled and exit with
+             # signal code
+             if exc_type is KeyboardInterrupt:
+                 Session.GLOBAL_SESSION_CTX._finalize_job_as_canceled()
+             else:
+                 Session.GLOBAL_SESSION_CTX._finalize_job_as_failed(
+                     exc_type, exc_value, exc_traceback
+                 )
+             Session.GLOBAL_SESSION_CTX.__exit__(exc_type, exc_value, exc_traceback)
+
          Session._global_cleanup()

+         # Always delegate to original hook if it exists
          if Session.ORIGINAL_EXCEPT_HOOK:
              Session.ORIGINAL_EXCEPT_HOOK(exc_type, exc_value, exc_traceback)

+         if exc_type is KeyboardInterrupt:
+             # Exit with SIGINT signal code (128 + 2 = 130, or -2 in subprocess terms)
+             sys.exit(130)
+
      @classmethod
      def cleanup_for_tests(cls):
+         cls._close_all_contexts()
          if cls.GLOBAL_SESSION_CTX is not None:
              cls.GLOBAL_SESSION_CTX.__exit__(None, None, None)
              cls.GLOBAL_SESSION_CTX = None
              atexit.unregister(cls._global_cleanup)

+         # Reset job-related class variables
+         if cls._JOB_FINALIZE_HOOK:
+             try:
+                 atexit.unregister(cls._JOB_FINALIZE_HOOK)
+             except ValueError:
+                 pass  # Hook was not registered
+         cls._CURRENT_JOB = None
+         cls._JOB_STATUS = None
+         cls._OWNS_JOB = None
+         cls._JOB_HOOKS_REGISTERED = False
+         cls._JOB_FINALIZE_HOOK = None
+
          if cls.ORIGINAL_EXCEPT_HOOK:
              sys.excepthook = cls.ORIGINAL_EXCEPT_HOOK

      @staticmethod
      def _global_cleanup():
+         Session._close_all_contexts()
          if Session.GLOBAL_SESSION_CTX is not None:
              Session.GLOBAL_SESSION_CTX.__exit__(None, None, None)

-         for obj in gc.get_objects():  # Get all tracked objects
+         for session in list(Session._ALL_SESSIONS):
              try:
-                 if isinstance(obj, Session):
-                     # Cleanup temp dataset for session variables.
-                     obj.__exit__(None, None, None)
+                 session.__exit__(None, None, None)
              except ReferenceError:
                  continue  # Object has been finalized already
              except Exception as e:  # noqa: BLE001
                  logger.error(f"Exception while cleaning up session: {e}")  # noqa: G004
+
+     @classmethod
+     def _close_all_contexts(cls) -> None:
+         while cls.SESSION_CONTEXTS:
+             session = cls.SESSION_CONTEXTS.pop()
+             try:
+                 session.__exit__(None, None, None)
+             except Exception as exc:  # noqa: BLE001
+                 logger.error(
+                     "Exception while closing session context during cleanup: %s",
+                     exc,
+                 )
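
The new job lifecycle above is easiest to see in use. A minimal sketch (not part of the package diff; it assumes a local run where DATACHAIN_JOB_ID is unset and a default catalog is available):

    from datachain.query.session import Session

    with Session("example") as session:
        job = session.get_or_create_job()        # creates a job with status RUNNING
        # The job is class-level state, so repeated calls (even from other
        # Session instances in the same process) return the same job.
        assert session.get_or_create_job().id == job.id

    # On normal interpreter exit the registered atexit hook marks the job COMPLETE;
    # except_hook marks it FAILED on an unhandled exception and CANCELED on
    # KeyboardInterrupt, then exits with code 130.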
datachain/query/udf.py CHANGED
@@ -1,5 +1,6 @@
  from abc import ABC, abstractmethod
- from typing import TYPE_CHECKING, Any, Callable, Optional, TypedDict, Union
+ from collections.abc import Callable
+ from typing import TYPE_CHECKING, Any, TypedDict

  if TYPE_CHECKING:
      from sqlalchemy import Select, Table
@@ -17,10 +18,11 @@ class UdfInfo(TypedDict):
      query: "Select"
      udf_fields: list[str]
      batching: "BatchingStrategy"
-     processes: Optional[int]
+     processes: int | None
      is_generator: bool
      cache: bool
      rows_total: int
+     batch_size: int


  class AbstractUDFDistributor(ABC):
@@ -32,13 +34,14 @@ class AbstractUDFDistributor(ABC):
          query: "Select",
          udf_data: bytes,
          batching: "BatchingStrategy",
-         workers: Union[bool, int],
-         processes: Union[bool, int],
+         workers: bool | int,
+         processes: bool | int,
          udf_fields: list[str],
          rows_total: int,
          use_cache: bool,
          is_generator: bool = False,
-         min_task_size: Optional[Union[str, int]] = None,
+         min_task_size: str | int | None = None,
+         batch_size: int | None = None,
      ) -> None: ...

      @abstractmethod
@@ -46,4 +49,4 @@ class AbstractUDFDistributor(ABC):

      @staticmethod
      @abstractmethod
-     def run_udf(fd: Optional[int] = None) -> int: ...
+     def run_udf() -> int: ...
datachain/remote/studio.py CHANGED
@@ -1,16 +1,10 @@
- import base64
  import json
  import logging
  import os
  from collections.abc import AsyncIterator, Iterable, Iterator
  from datetime import datetime, timedelta, timezone
  from struct import unpack
- from typing import (
-     Any,
-     Generic,
-     Optional,
-     TypeVar,
- )
+ from typing import Any, BinaryIO, Generic, TypeVar
  from urllib.parse import urlparse, urlunparse

  import websockets
@@ -22,16 +16,17 @@ from datachain.error import DataChainError
  from datachain.utils import STUDIO_URL, retry_with_backoff

  T = TypeVar("T")
- LsData = Optional[list[dict[str, Any]]]
- DatasetInfoData = Optional[dict[str, Any]]
- DatasetRowsData = Optional[Iterable[dict[str, Any]]]
- DatasetJobVersionsData = Optional[dict[str, Any]]
- DatasetExportStatus = Optional[dict[str, Any]]
- DatasetExportSignedUrls = Optional[list[str]]
- FileUploadData = Optional[dict[str, Any]]
- JobData = Optional[dict[str, Any]]
- JobListData = dict[str, Any]
- ClusterListData = dict[str, Any]
+ LsData = list[dict[str, Any]] | None
+ DatasetInfoData = dict[str, Any] | None
+ DatasetRowsData = Iterable[dict[str, Any]] | None
+ DatasetJobVersionsData = dict[str, Any] | None
+ DatasetExportStatus = dict[str, Any] | None
+ DatasetExportSignedUrls = list[str] | None
+ FileUploadData = dict[str, Any] | None
+ JobData = dict[str, Any] | None
+ JobListData = list[dict[str, Any]]
+ ClusterListData = list[dict[str, Any]]
+
  logger = logging.getLogger("datachain")

  DATASET_ROWS_CHUNK_SIZE = 8192
@@ -92,7 +87,7 @@ class Response(Generic[T]):


  class StudioClient:
-     def __init__(self, timeout: float = 3600.0, team: Optional[str] = None) -> None:
+     def __init__(self, timeout: float = 3600.0, team: str | None = None) -> None:
          self._check_dependencies()
          self.timeout = timeout
          self._config = None
@@ -153,7 +148,7 @@ class StudioClient:
          ) from None

      def _send_request_msgpack(
-         self, route: str, data: dict[str, Any], method: Optional[str] = "POST"
+         self, route: str, data: dict[str, Any], method: str | None = "POST"
      ) -> Response[Any]:
          import msgpack
          import requests
@@ -191,7 +186,7 @@ class StudioClient:

      @retry_with_backoff(retries=3, errors=(HTTPError, Timeout))
      def _send_request(
-         self, route: str, data: dict[str, Any], method: Optional[str] = "POST"
+         self, route: str, data: dict[str, Any], method: str | None = "POST"
      ) -> Response[Any]:
          """
          Function that communicate Studio API.
@@ -239,6 +234,45 @@ class StudioClient:

          return Response(data, ok, message, response.status_code)

+     def _send_multipart_request(
+         self, route: str, files: dict[str, Any], params: dict[str, Any] | None = None
+     ) -> Response[Any]:
+         """
+         Function that communicates with Studio API using multipart/form-data.
+         It will raise an exception, and try to retry, if 5xx status code is
+         returned, or if Timeout exceptions is thrown from the requests lib
+         """
+         import requests
+
+         # Add team_name to params
+         request_params = {**(params or {}), "team_name": self.team}
+
+         response = requests.post(
+             url=f"{self.url}/{route}",
+             files=files,
+             params=request_params,
+             headers={
+                 "Authorization": f"token {self.token}",
+             },
+             timeout=self.timeout,
+         )
+
+         ok = response.ok
+         try:
+             data = json.loads(response.content.decode("utf-8"))
+         except json.decoder.JSONDecodeError:
+             data = {}
+
+         if not ok:
+             if response.status_code == 403:
+                 message = f"Not authorized for the team {self.team}"
+             else:
+                 message = data.get("message", "")
+         else:
+             message = ""
+
+         return Response(data, ok, message, response.status_code)
+
      @staticmethod
      def _unpacker_hook(code, data):
          import msgpack
@@ -305,7 +339,7 @@ class StudioClient:
          response = self._send_request_msgpack("datachain/ls", {"source": path})
          yield path, response

-     def ls_datasets(self, prefix: Optional[str] = None) -> Response[LsData]:
+     def ls_datasets(self, prefix: str | None = None) -> Response[LsData]:
          return self._send_request(
              "datachain/datasets", {"prefix": prefix}, method="GET"
          )
@@ -315,9 +349,9 @@
          name: str,
          namespace: str,
          project: str,
-         new_name: Optional[str] = None,
-         description: Optional[str] = None,
-         attrs: Optional[list[str]] = None,
+         new_name: str | None = None,
+         description: str | None = None,
+         attrs: list[str] | None = None,
      ) -> Response[DatasetInfoData]:
          body = {
              "new_name": new_name,
@@ -338,8 +372,8 @@
          name: str,
          namespace: str,
          project: str,
-         version: Optional[str] = None,
-         force: Optional[bool] = False,
+         version: str | None = None,
+         force: bool | None = False,
      ) -> Response[DatasetInfoData]:
          return self._send_request(
              "datachain/datasets",
@@ -409,29 +443,30 @@ class StudioClient:
              method="GET",
          )

-     def upload_file(self, content: bytes, file_name: str) -> Response[FileUploadData]:
-         data = {
-             "file_content": base64.b64encode(content).decode("utf-8"),
-             "file_name": file_name,
-         }
-         return self._send_request("datachain/upload-file", data)
+     def upload_file(
+         self, file_obj: BinaryIO, file_name: str
+     ) -> Response[FileUploadData]:
+         # Prepare multipart form data
+         files = {"file": (file_name, file_obj, "application/octet-stream")}
+
+         return self._send_multipart_request("datachain/jobs/files", files)

      def create_job(
          self,
          query: str,
          query_type: str,
-         environment: Optional[str] = None,
-         workers: Optional[int] = None,
-         query_name: Optional[str] = None,
-         files: Optional[list[str]] = None,
-         python_version: Optional[str] = None,
-         requirements: Optional[str] = None,
-         repository: Optional[str] = None,
-         priority: Optional[int] = None,
-         cluster: Optional[str] = None,
-         start_time: Optional[str] = None,
-         cron: Optional[str] = None,
-         credentials_name: Optional[str] = None,
+         environment: str | None = None,
+         workers: int | None = None,
+         query_name: str | None = None,
+         files: list[str] | None = None,
+         python_version: str | None = None,
+         requirements: str | None = None,
+         repository: str | None = None,
+         priority: int | None = None,
+         cluster: str | None = None,
+         start_time: str | None = None,
+         cron: str | None = None,
+         credentials_name: str | None = None,
      ) -> Response[JobData]:
          data = {
              "query": query,
@@ -449,25 +484,27 @@ class StudioClient:
              "cron_expression": cron,
              "credentials_name": credentials_name,
          }
-         return self._send_request("datachain/job", data)
+         return self._send_request("datachain/jobs/", data)

      def get_jobs(
          self,
-         status: Optional[str] = None,
+         status: str | None = None,
          limit: int = 20,
+         job_id: str | None = None,
      ) -> Response[JobListData]:
-         return self._send_request(
-             "datachain/jobs",
-             {"status": status, "limit": limit} if status else {"limit": limit},
-             method="GET",
-         )
+         params: dict[str, Any] = {"limit": limit}
+         if status is not None:
+             params["status"] = status
+         if job_id is not None:
+             params["job_id"] = job_id
+         return self._send_request("datachain/jobs/", params, method="GET")

      def cancel_job(
          self,
          job_id: str,
      ) -> Response[JobData]:
-         url = f"datachain/job/{job_id}/cancel"
+         url = f"datachain/jobs/{job_id}/cancel"
          return self._send_request(url, data={}, method="POST")

      def get_clusters(self) -> Response[ClusterListData]:
-         return self._send_request("datachain/clusters", {}, method="GET")
+         return self._send_request("datachain/clusters/", {}, method="GET")
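
upload_file now streams the file as multipart/form-data to the datachain/jobs/files route instead of base64-encoding its bytes into a JSON body. A rough usage sketch (not part of the diff; the team name and file are placeholders, and it assumes Response exposes the ok/data fields it is constructed with above):

    from datachain.remote.studio import StudioClient

    client = StudioClient(team="my-team")          # hypothetical team name
    with open("attachment.bin", "rb") as f:        # any binary file object works
        response = client.upload_file(f, "attachment.bin")
    if response.ok:
        print(response.data)                       # metadata returned by Studio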
datachain/script_meta.py CHANGED
@@ -1,6 +1,6 @@
  import re
  from dataclasses import dataclass
- from typing import Any, Optional
+ from typing import Any

  try:
      import tomllib
@@ -59,23 +59,23 @@ class ScriptConfig:

      """

-     python_version: Optional[str]
+     python_version: str | None
      dependencies: list[str]
      attachments: dict[str, str]
      params: dict[str, Any]
      inputs: dict[str, Any]
      outputs: dict[str, Any]
-     num_workers: Optional[int] = None
+     num_workers: int | None = None

      def __init__(
          self,
-         python_version: Optional[str] = None,
-         dependencies: Optional[list[str]] = None,
-         attachments: Optional[dict[str, str]] = None,
-         params: Optional[dict[str, Any]] = None,
-         inputs: Optional[dict[str, Any]] = None,
-         outputs: Optional[dict[str, Any]] = None,
-         num_workers: Optional[int] = None,
+         python_version: str | None = None,
+         dependencies: list[str] | None = None,
+         attachments: dict[str, str] | None = None,
+         params: dict[str, Any] | None = None,
+         inputs: dict[str, Any] | None = None,
+         outputs: dict[str, Any] | None = None,
+         num_workers: int | None = None,
      ):
          self.python_version = python_version
          self.dependencies = dependencies or []
@@ -98,7 +98,7 @@ class ScriptConfig:
          return self.attachments.get(name, default)

      @staticmethod
-     def read(script: str) -> Optional[dict]:
+     def read(script: str) -> dict | None:
          """Converts inline script metadata to dict with all found data"""
          regex = (
              r"(?m)^# \/\/\/ (?P<type>[a-zA-Z0-9-]+)[ \t]*$[\r\n|\r|\n]"
@@ -119,7 +119,7 @@ class ScriptConfig:
          return None

      @staticmethod
-     def parse(script: str) -> Optional["ScriptConfig"]:
+     def parse(script: str) -> "ScriptConfig | None":
          """
          Method that is parsing inline script metadata from datachain script and
          instantiating ScriptConfig class with found data. If no inline metadata is