datachain 0.14.2__py3-none-any.whl → 0.39.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (137)
  1. datachain/__init__.py +20 -0
  2. datachain/asyn.py +11 -12
  3. datachain/cache.py +7 -7
  4. datachain/catalog/__init__.py +2 -2
  5. datachain/catalog/catalog.py +621 -507
  6. datachain/catalog/dependency.py +164 -0
  7. datachain/catalog/loader.py +28 -18
  8. datachain/checkpoint.py +43 -0
  9. datachain/cli/__init__.py +24 -33
  10. datachain/cli/commands/__init__.py +1 -8
  11. datachain/cli/commands/datasets.py +83 -52
  12. datachain/cli/commands/ls.py +17 -17
  13. datachain/cli/commands/show.py +4 -4
  14. datachain/cli/parser/__init__.py +8 -74
  15. datachain/cli/parser/job.py +95 -3
  16. datachain/cli/parser/studio.py +11 -4
  17. datachain/cli/parser/utils.py +1 -2
  18. datachain/cli/utils.py +2 -15
  19. datachain/client/azure.py +4 -4
  20. datachain/client/fsspec.py +45 -28
  21. datachain/client/gcs.py +6 -6
  22. datachain/client/hf.py +29 -2
  23. datachain/client/http.py +157 -0
  24. datachain/client/local.py +15 -11
  25. datachain/client/s3.py +17 -9
  26. datachain/config.py +4 -8
  27. datachain/data_storage/db_engine.py +12 -6
  28. datachain/data_storage/job.py +5 -1
  29. datachain/data_storage/metastore.py +1252 -186
  30. datachain/data_storage/schema.py +58 -45
  31. datachain/data_storage/serializer.py +105 -15
  32. datachain/data_storage/sqlite.py +286 -127
  33. datachain/data_storage/warehouse.py +250 -113
  34. datachain/dataset.py +353 -148
  35. datachain/delta.py +391 -0
  36. datachain/diff/__init__.py +27 -29
  37. datachain/error.py +60 -0
  38. datachain/func/__init__.py +2 -1
  39. datachain/func/aggregate.py +66 -42
  40. datachain/func/array.py +242 -38
  41. datachain/func/base.py +7 -4
  42. datachain/func/conditional.py +110 -60
  43. datachain/func/func.py +96 -45
  44. datachain/func/numeric.py +55 -38
  45. datachain/func/path.py +32 -20
  46. datachain/func/random.py +2 -2
  47. datachain/func/string.py +67 -37
  48. datachain/func/window.py +7 -8
  49. datachain/hash_utils.py +123 -0
  50. datachain/job.py +11 -7
  51. datachain/json.py +138 -0
  52. datachain/lib/arrow.py +58 -22
  53. datachain/lib/audio.py +245 -0
  54. datachain/lib/clip.py +14 -13
  55. datachain/lib/convert/flatten.py +5 -3
  56. datachain/lib/convert/python_to_sql.py +6 -10
  57. datachain/lib/convert/sql_to_python.py +8 -0
  58. datachain/lib/convert/values_to_tuples.py +156 -51
  59. datachain/lib/data_model.py +42 -20
  60. datachain/lib/dataset_info.py +36 -8
  61. datachain/lib/dc/__init__.py +8 -2
  62. datachain/lib/dc/csv.py +25 -28
  63. datachain/lib/dc/database.py +398 -0
  64. datachain/lib/dc/datachain.py +1289 -425
  65. datachain/lib/dc/datasets.py +320 -38
  66. datachain/lib/dc/hf.py +38 -24
  67. datachain/lib/dc/json.py +29 -32
  68. datachain/lib/dc/listings.py +112 -8
  69. datachain/lib/dc/pandas.py +16 -12
  70. datachain/lib/dc/parquet.py +35 -23
  71. datachain/lib/dc/records.py +31 -23
  72. datachain/lib/dc/storage.py +154 -64
  73. datachain/lib/dc/storage_pattern.py +251 -0
  74. datachain/lib/dc/utils.py +24 -16
  75. datachain/lib/dc/values.py +8 -9
  76. datachain/lib/file.py +622 -89
  77. datachain/lib/hf.py +69 -39
  78. datachain/lib/image.py +14 -14
  79. datachain/lib/listing.py +14 -11
  80. datachain/lib/listing_info.py +1 -2
  81. datachain/lib/meta_formats.py +3 -4
  82. datachain/lib/model_store.py +39 -7
  83. datachain/lib/namespaces.py +125 -0
  84. datachain/lib/projects.py +130 -0
  85. datachain/lib/pytorch.py +32 -21
  86. datachain/lib/settings.py +192 -56
  87. datachain/lib/signal_schema.py +427 -104
  88. datachain/lib/tar.py +1 -2
  89. datachain/lib/text.py +8 -7
  90. datachain/lib/udf.py +164 -76
  91. datachain/lib/udf_signature.py +60 -35
  92. datachain/lib/utils.py +118 -4
  93. datachain/lib/video.py +17 -9
  94. datachain/lib/webdataset.py +61 -56
  95. datachain/lib/webdataset_laion.py +15 -16
  96. datachain/listing.py +22 -10
  97. datachain/model/bbox.py +3 -1
  98. datachain/model/ultralytics/bbox.py +16 -12
  99. datachain/model/ultralytics/pose.py +16 -12
  100. datachain/model/ultralytics/segment.py +16 -12
  101. datachain/namespace.py +84 -0
  102. datachain/node.py +6 -6
  103. datachain/nodes_thread_pool.py +0 -1
  104. datachain/plugins.py +24 -0
  105. datachain/project.py +78 -0
  106. datachain/query/batch.py +40 -41
  107. datachain/query/dataset.py +604 -322
  108. datachain/query/dispatch.py +261 -154
  109. datachain/query/metrics.py +4 -6
  110. datachain/query/params.py +2 -3
  111. datachain/query/queue.py +3 -12
  112. datachain/query/schema.py +11 -6
  113. datachain/query/session.py +200 -33
  114. datachain/query/udf.py +34 -2
  115. datachain/remote/studio.py +171 -69
  116. datachain/script_meta.py +12 -12
  117. datachain/semver.py +68 -0
  118. datachain/sql/__init__.py +2 -0
  119. datachain/sql/functions/array.py +33 -1
  120. datachain/sql/postgresql_dialect.py +9 -0
  121. datachain/sql/postgresql_types.py +21 -0
  122. datachain/sql/sqlite/__init__.py +5 -1
  123. datachain/sql/sqlite/base.py +102 -29
  124. datachain/sql/sqlite/types.py +8 -13
  125. datachain/sql/types.py +70 -15
  126. datachain/studio.py +223 -46
  127. datachain/toolkit/split.py +31 -10
  128. datachain/utils.py +101 -59
  129. {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/METADATA +77 -22
  130. datachain-0.39.0.dist-info/RECORD +173 -0
  131. {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/WHEEL +1 -1
  132. datachain/cli/commands/query.py +0 -53
  133. datachain/query/utils.py +0 -42
  134. datachain-0.14.2.dist-info/RECORD +0 -158
  135. {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/entry_points.txt +0 -0
  136. {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/licenses/LICENSE +0 -0
  137. {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/top_level.txt +0 -0
datachain/remote/studio.py CHANGED
@@ -1,47 +1,64 @@
-import base64
 import json
 import logging
 import os
 from collections.abc import AsyncIterator, Iterable, Iterator
 from datetime import datetime, timedelta, timezone
 from struct import unpack
-from typing import (
-    Any,
-    Generic,
-    Optional,
-    TypeVar,
-)
+from typing import Any, BinaryIO, Generic, TypeVar
 from urllib.parse import urlparse, urlunparse
 
 import websockets
 from requests.exceptions import HTTPError, Timeout
 
 from datachain.config import Config
+from datachain.dataset import DatasetRecord
 from datachain.error import DataChainError
 from datachain.utils import STUDIO_URL, retry_with_backoff
 
 T = TypeVar("T")
-LsData = Optional[list[dict[str, Any]]]
-DatasetInfoData = Optional[dict[str, Any]]
-DatasetRowsData = Optional[Iterable[dict[str, Any]]]
-DatasetJobVersionsData = Optional[dict[str, Any]]
-DatasetExportStatus = Optional[dict[str, Any]]
-DatasetExportSignedUrls = Optional[list[str]]
-FileUploadData = Optional[dict[str, Any]]
-JobData = Optional[dict[str, Any]]
+LsData = list[dict[str, Any]] | None
+DatasetInfoData = dict[str, Any] | None
+DatasetRowsData = Iterable[dict[str, Any]] | None
+DatasetJobVersionsData = dict[str, Any] | None
+DatasetExportStatus = dict[str, Any] | None
+DatasetExportSignedUrls = list[str] | None
+FileUploadData = dict[str, Any] | None
+JobData = dict[str, Any] | None
+JobListData = list[dict[str, Any]]
+ClusterListData = list[dict[str, Any]]
 
 logger = logging.getLogger("datachain")
 
 DATASET_ROWS_CHUNK_SIZE = 8192
 
 
+def get_studio_env_variable(name: str) -> Any:
+    """
+    Get the value of a DataChain Studio environment variable.
+    It first checks for the variable prefixed with 'DATACHAIN_STUDIO_',
+    then checks for the deprecated 'DVC_STUDIO_' prefix.
+    If neither is set, None is returned.
+    """
+    if (value := os.environ.get(f"DATACHAIN_STUDIO_{name}")) is not None:
+        return value
+    if (value := os.environ.get(f"DVC_STUDIO_{name}")) is not None:  # deprecated
+        logger.warning(
+            "Environment variable 'DVC_STUDIO_%s' is deprecated, "
+            "use 'DATACHAIN_STUDIO_%s' instead.",
+            name,
+            name,
+        )
+        return value
+    return None
+
+
 def _is_server_error(status_code: int) -> bool:
     return str(status_code).startswith("5")
 
 
 def is_token_set() -> bool:
     return (
-        bool(os.environ.get("DVC_STUDIO_TOKEN"))
+        bool(get_studio_env_variable("TOKEN"))
         or Config().read().get("studio", {}).get("token") is not None
     )
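Note: the new get_studio_env_variable helper centralizes this lookup, and the later hunks swap direct os.environ reads of DVC_STUDIO_* for it. A minimal sketch of its resolution order (variable names come from the hunk above; the values are illustrative):

    import os

    os.environ["DVC_STUDIO_TOKEN"] = "legacy-token"             # deprecated prefix only
    assert get_studio_env_variable("TOKEN") == "legacy-token"   # returned, with a warning logged

    os.environ["DATACHAIN_STUDIO_TOKEN"] = "new-token"          # new prefix takes precedence
    assert get_studio_env_variable("TOKEN") == "new-token"

    assert get_studio_env_variable("MISSING") is None           # neither prefix set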
 
@@ -56,10 +73,11 @@ def _parse_dates(obj: dict, date_fields: list[str]):
 
 
 class Response(Generic[T]):
-    def __init__(self, data: T, ok: bool, message: str) -> None:
+    def __init__(self, data: T, ok: bool, message: str, status: int) -> None:
         self.data = data
         self.ok = ok
         self.message = message
+        self.status = status
 
     def __repr__(self):
         return (
@@ -69,7 +87,7 @@ class Response(Generic[T]):
 
 
 class StudioClient:
-    def __init__(self, timeout: float = 3600.0, team: Optional[str] = None) -> None:
+    def __init__(self, timeout: float = 3600.0, team: str | None = None) -> None:
         self._check_dependencies()
         self.timeout = timeout
         self._config = None
@@ -77,12 +95,12 @@ class StudioClient:
 
     @property
     def token(self) -> str:
-        token = os.environ.get("DVC_STUDIO_TOKEN") or self.config.get("token")
+        token = get_studio_env_variable("TOKEN") or self.config.get("token")
 
         if not token:
             raise DataChainError(
                 "Studio token is not set. Use `datachain auth login` "
-                "or environment variable `DVC_STUDIO_TOKEN` to set it."
+                "or environment variable `DATACHAIN_STUDIO_TOKEN` to set it."
             )
 
         return token
@@ -90,8 +108,8 @@ class StudioClient:
     @property
     def url(self) -> str:
         return (
-            os.environ.get("DVC_STUDIO_URL") or self.config.get("url") or STUDIO_URL
-        ) + "/api"
+            get_studio_env_variable("URL") or self.config.get("url") or STUDIO_URL
+        ).rstrip("/") + "/api"
 
     @property
     def config(self) -> dict:
@@ -106,13 +124,13 @@ class StudioClient:
         return self._team
 
     def _get_team(self) -> str:
-        team = os.environ.get("DVC_STUDIO_TEAM") or self.config.get("team")
+        team = get_studio_env_variable("TEAM") or self.config.get("team")
 
         if not team:
             raise DataChainError(
                 "Studio team is not set. "
                 "Use `datachain auth team <team_name>` "
-                "or environment variable `DVC_STUDIO_TEAM` to set it. "
+                "or environment variable `DATACHAIN_STUDIO_TEAM` to set it. "
                 "You can also set `studio.team` in the config file."
             )
 
@@ -130,7 +148,7 @@ class StudioClient:
             ) from None
 
     def _send_request_msgpack(
-        self, route: str, data: dict[str, Any], method: Optional[str] = "POST"
+        self, route: str, data: dict[str, Any], method: str | None = "POST"
     ) -> Response[Any]:
         import msgpack
         import requests
@@ -164,11 +182,11 @@ class StudioClient:
             message = "Indexing in progress"
         else:
             message = content.get("message", "")
-        return Response(response_data, ok, message)
+        return Response(response_data, ok, message, response.status_code)
 
     @retry_with_backoff(retries=3, errors=(HTTPError, Timeout))
     def _send_request(
-        self, route: str, data: dict[str, Any], method: Optional[str] = "POST"
+        self, route: str, data: dict[str, Any], method: str | None = "POST"
     ) -> Response[Any]:
         """
         Function that communicates with the Studio API.
@@ -214,7 +232,46 @@ class StudioClient:
         else:
             message = ""
 
-        return Response(data, ok, message)
+        return Response(data, ok, message, response.status_code)
+
+    def _send_multipart_request(
+        self, route: str, files: dict[str, Any], params: dict[str, Any] | None = None
+    ) -> Response[Any]:
+        """
+        Function that communicates with the Studio API using multipart/form-data.
+        It will raise an exception, and try to retry, if a 5xx status code is
+        returned or if a Timeout exception is thrown from the requests lib.
+        """
+        import requests
+
+        # Add team_name to params
+        request_params = {**(params or {}), "team_name": self.team}
+
+        response = requests.post(
+            url=f"{self.url}/{route}",
+            files=files,
+            params=request_params,
+            headers={
+                "Authorization": f"token {self.token}",
+            },
+            timeout=self.timeout,
+        )
+
+        ok = response.ok
+        try:
+            data = json.loads(response.content.decode("utf-8"))
+        except json.decoder.JSONDecodeError:
+            data = {}
+
+        if not ok:
+            if response.status_code == 403:
+                message = f"Not authorized for the team {self.team}"
+            else:
+                message = data.get("message", "")
+        else:
+            message = ""
+
+        return Response(data, ok, message, response.status_code)
 
     @staticmethod
     def _unpacker_hook(code, data):
@@ -282,21 +339,27 @@ class StudioClient:
         response = self._send_request_msgpack("datachain/ls", {"source": path})
         yield path, response
 
-    def ls_datasets(self) -> Response[LsData]:
-        return self._send_request("datachain/datasets", {}, method="GET")
+    def ls_datasets(self, prefix: str | None = None) -> Response[LsData]:
+        return self._send_request(
+            "datachain/datasets", {"prefix": prefix}, method="GET"
+        )
 
     def edit_dataset(
         self,
         name: str,
-        new_name: Optional[str] = None,
-        description: Optional[str] = None,
-        labels: Optional[list[str]] = None,
+        namespace: str,
+        project: str,
+        new_name: str | None = None,
+        description: str | None = None,
+        attrs: list[str] | None = None,
     ) -> Response[DatasetInfoData]:
         body = {
             "new_name": new_name,
-            "dataset_name": name,
+            "name": name,
+            "namespace": namespace,
+            "project": project,
             "description": description,
-            "labels": labels,
+            "attrs": attrs,
         }
 
         return self._send_request(
@@ -307,44 +370,44 @@ class StudioClient:
     def rm_dataset(
         self,
         name: str,
-        version: Optional[int] = None,
-        force: Optional[bool] = False,
+        namespace: str,
+        project: str,
+        version: str | None = None,
+        force: bool | None = False,
     ) -> Response[DatasetInfoData]:
         return self._send_request(
             "datachain/datasets",
             {
-                "dataset_name": name,
-                "dataset_version": version,
+                "name": name,
+                "namespace": namespace,
+                "project": project,
+                "version": version,
                 "force": force,
             },
             method="DELETE",
         )
 
-    def dataset_info(self, name: str) -> Response[DatasetInfoData]:
+    def dataset_info(
+        self, namespace: str, project: str, name: str
+    ) -> Response[DatasetInfoData]:
         def _parse_dataset_info(dataset_info):
             _parse_dates(dataset_info, ["created_at", "finished_at"])
             for version in dataset_info.get("versions"):
                 _parse_dates(version, ["created_at"])
+            _parse_dates(dataset_info.get("project"), ["created_at"])
+            _parse_dates(dataset_info.get("project").get("namespace"), ["created_at"])
 
             return dataset_info
 
         response = self._send_request(
-            "datachain/datasets/info", {"dataset_name": name}, method="GET"
+            "datachain/datasets/info",
+            {"namespace": namespace, "project": project, "name": name},
+            method="GET",
         )
         if response.ok:
             response.data = _parse_dataset_info(response.data)
         return response
 
-    def dataset_rows_chunk(
-        self, name: str, version: int, offset: int
-    ) -> Response[DatasetRowsData]:
-        req_data = {"dataset_name": name, "dataset_version": version}
-        return self._send_request_msgpack(
-            "datachain/datasets/rows",
-            {**req_data, "offset": offset, "limit": DATASET_ROWS_CHUNK_SIZE},
-            method="GET",
-        )
-
     def dataset_job_versions(self, job_id: str) -> Response[DatasetJobVersionsData]:
         return self._send_request(
             "datachain/datasets/dataset_job_versions",
@@ -353,40 +416,57 @@ class StudioClient:
         )
 
     def export_dataset_table(
-        self, name: str, version: int
+        self, dataset: DatasetRecord, version: str
     ) -> Response[DatasetExportSignedUrls]:
         return self._send_request(
             "datachain/datasets/export",
-            {"dataset_name": name, "dataset_version": version},
+            {
+                "namespace": dataset.project.namespace.name,
+                "project": dataset.project.name,
+                "name": dataset.name,
+                "version": version,
+            },
             method="GET",
         )
 
     def dataset_export_status(
-        self, name: str, version: int
+        self, dataset: DatasetRecord, version: str
     ) -> Response[DatasetExportStatus]:
         return self._send_request(
             "datachain/datasets/export-status",
-            {"dataset_name": name, "dataset_version": version},
+            {
+                "namespace": dataset.project.namespace.name,
+                "project": dataset.project.name,
+                "name": dataset.name,
+                "version": version,
+            },
             method="GET",
         )
 
-    def upload_file(self, content: bytes, file_name: str) -> Response[FileUploadData]:
-        data = {
-            "file_content": base64.b64encode(content).decode("utf-8"),
-            "file_name": file_name,
-        }
-        return self._send_request("datachain/upload-file", data)
+    def upload_file(
+        self, file_obj: BinaryIO, file_name: str
+    ) -> Response[FileUploadData]:
+        # Prepare multipart form data
+        files = {"file": (file_name, file_obj, "application/octet-stream")}
+
+        return self._send_multipart_request("datachain/jobs/files", files)
 
     def create_job(
         self,
         query: str,
         query_type: str,
-        environment: Optional[str] = None,
-        workers: Optional[int] = None,
-        query_name: Optional[str] = None,
-        files: Optional[list[str]] = None,
-        python_version: Optional[str] = None,
-        requirements: Optional[str] = None,
+        environment: str | None = None,
+        workers: int | None = None,
+        query_name: str | None = None,
+        files: list[str] | None = None,
+        python_version: str | None = None,
+        requirements: str | None = None,
+        repository: str | None = None,
+        priority: int | None = None,
+        cluster: str | None = None,
+        start_time: str | None = None,
+        cron: str | None = None,
+        credentials_name: str | None = None,
    ) -> Response[JobData]:
        data = {
            "query": query,
@@ -397,12 +477,34 @@ class StudioClient:
             "files": files,
             "python_version": python_version,
             "requirements": requirements,
+            "repository": repository,
+            "priority": priority,
+            "compute_cluster_name": cluster,
+            "start_after": start_time,
+            "cron_expression": cron,
+            "credentials_name": credentials_name,
         }
-        return self._send_request("datachain/job", data)
+        return self._send_request("datachain/jobs/", data)
+
+    def get_jobs(
+        self,
+        status: str | None = None,
+        limit: int = 20,
+        job_id: str | None = None,
+    ) -> Response[JobListData]:
+        params: dict[str, Any] = {"limit": limit}
+        if status is not None:
+            params["status"] = status
+        if job_id is not None:
+            params["job_id"] = job_id
+        return self._send_request("datachain/jobs/", params, method="GET")
 
     def cancel_job(
         self,
         job_id: str,
     ) -> Response[JobData]:
-        url = f"datachain/job/{job_id}/cancel"
+        url = f"datachain/jobs/{job_id}/cancel"
         return self._send_request(url, data={}, method="POST")
+
+    def get_clusters(self) -> Response[ClusterListData]:
+        return self._send_request("datachain/clusters/", {}, method="GET")
datachain/script_meta.py CHANGED
@@ -1,6 +1,6 @@
 import re
 from dataclasses import dataclass
-from typing import Any, Optional
+from typing import Any
 
 try:
     import tomllib
@@ -59,23 +59,23 @@ class ScriptConfig:
 
     """
 
-    python_version: Optional[str]
+    python_version: str | None
     dependencies: list[str]
     attachments: dict[str, str]
     params: dict[str, Any]
     inputs: dict[str, Any]
     outputs: dict[str, Any]
-    num_workers: Optional[int] = None
+    num_workers: int | None = None
 
     def __init__(
         self,
-        python_version: Optional[str] = None,
-        dependencies: Optional[list[str]] = None,
-        attachments: Optional[dict[str, str]] = None,
-        params: Optional[dict[str, Any]] = None,
-        inputs: Optional[dict[str, Any]] = None,
-        outputs: Optional[dict[str, Any]] = None,
-        num_workers: Optional[int] = None,
+        python_version: str | None = None,
+        dependencies: list[str] | None = None,
+        attachments: dict[str, str] | None = None,
+        params: dict[str, Any] | None = None,
+        inputs: dict[str, Any] | None = None,
+        outputs: dict[str, Any] | None = None,
+        num_workers: int | None = None,
     ):
         self.python_version = python_version
         self.dependencies = dependencies or []
@@ -98,7 +98,7 @@ class ScriptConfig:
         return self.attachments.get(name, default)
 
     @staticmethod
-    def read(script: str) -> Optional[dict]:
+    def read(script: str) -> dict | None:
         """Converts inline script metadata to dict with all found data"""
         regex = (
             r"(?m)^# \/\/\/ (?P<type>[a-zA-Z0-9-]+)[ \t]*$[\r\n|\r|\n]"
@@ -119,7 +119,7 @@ class ScriptConfig:
         return None
 
     @staticmethod
-    def parse(script: str) -> Optional["ScriptConfig"]:
+    def parse(script: str) -> "ScriptConfig | None":
         """
         Method that is parsing inline script metadata from datachain script and
         instantiating ScriptConfig class with found data. If no inline metadata is
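Note: ScriptConfig.read scans the script for `# /// <type>` ... `# ///` comment blocks (the regex in the hunk above, following the PEP 723 inline-metadata convention) and parses the body as TOML via tomllib. A hedged sketch of the kind of block it targets; the keys mirror the dataclass fields, but the exact set accepted by a given DataChain version may differ:

    lines = [
        "# /// script",
        '# python_version = "3.12"',
        '# dependencies = ["pandas", "numpy"]',
        "# ///",
    ]
    script = "\n".join(lines) + "\n"

    meta = ScriptConfig.read(script)     # dict parsed from the TOML body, or None
    config = ScriptConfig.parse(script)  # ScriptConfig instance, or None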
datachain/semver.py ADDED
@@ -0,0 +1,68 @@
1
+ # Maximum version number for semver (major.minor.patch) is 999999.999999.999999
2
+ # this number was chosen because value("999999.999999.999999") < 2**63 - 1
3
+ MAX_VERSION_NUMBER = 999_999
4
+
5
+
6
+ def parse(version: str) -> tuple[int, int, int]:
7
+ """Parsing semver into 3 integers: major, minor, patch"""
8
+ validate(version)
9
+ parts = version.split(".")
10
+ return int(parts[0]), int(parts[1]), int(parts[2])
11
+
12
+
13
+ def validate(version: str) -> None:
14
+ """
15
+ Raises exception if version doesn't have valid semver format which is:
16
+ <major>.<minor>.<patch> or one of version parts is not positive integer
17
+ """
18
+ error_message = (
19
+ "Invalid version. It should be in format: <major>.<minor>.<patch> where"
20
+ " each version part is positive integer"
21
+ )
22
+ parts = version.split(".")
23
+ if len(parts) != 3:
24
+ raise ValueError(error_message)
25
+ for part in parts:
26
+ try:
27
+ val = int(part)
28
+ assert 0 <= val <= MAX_VERSION_NUMBER
29
+ except (ValueError, AssertionError):
30
+ raise ValueError(error_message) from None
31
+
32
+
33
+ def create(major: int = 0, minor: int = 0, patch: int = 0) -> str:
34
+ """Creates new semver from 3 integers: major, minor and patch"""
35
+ if not (
36
+ 0 <= major <= MAX_VERSION_NUMBER
37
+ and 0 <= minor <= MAX_VERSION_NUMBER
38
+ and 0 <= patch <= MAX_VERSION_NUMBER
39
+ ):
40
+ raise ValueError("Major, minor and patch must be greater or equal to zero")
41
+
42
+ return ".".join([str(major), str(minor), str(patch)])
43
+
44
+
45
+ def value(version: str) -> int:
46
+ """
47
+ Calculate integer value of a version. This is useful when comparing two versions.
48
+ """
49
+ major, minor, patch = parse(version)
50
+ limit = MAX_VERSION_NUMBER + 1
51
+ return major * (limit**2) + minor * limit + patch
52
+
53
+
54
+ def compare(v1: str, v2: str) -> int:
55
+ """
56
+ Compares 2 versions and returns:
57
+ -1 if v1 < v2
58
+ 0 if v1 == v2
59
+ 1 if v1 > v2
60
+ """
61
+ v1_val = value(v1)
62
+ v2_val = value(v2)
63
+
64
+ if v1_val < v2_val:
65
+ return -1
66
+ if v1_val > v2_val:
67
+ return 1
68
+ return 0
datachain/sql/__init__.py CHANGED
@@ -1,6 +1,8 @@
 from sqlalchemy.sql.elements import literal
 from sqlalchemy.sql.expression import column
 
+# Import PostgreSQL dialect registration (registers PostgreSQL type converter)
+from . import postgresql_dialect  # noqa: F401
 from .default import setup as default_setup
 from .selectable import select, values
 
datachain/sql/functions/array.py CHANGED
@@ -1,6 +1,6 @@
 from sqlalchemy.sql.functions import GenericFunction
 
-from datachain.sql.types import Boolean, Float, Int64
+from datachain.sql.types import Boolean, Float, Int64, String
 from datachain.sql.utils import compiler_not_implemented
 
 
@@ -48,6 +48,37 @@ class contains(GenericFunction):  # noqa: N801
     inherit_cache = True
 
 
+class slice(GenericFunction):  # noqa: N801
+    """
+    Returns a slice of the array.
+    """
+
+    package = "array"
+    name = "slice"
+    inherit_cache = True
+
+
+class join(GenericFunction):  # noqa: N801
+    """
+    Returns the concatenation of the array elements.
+    """
+
+    type = String()
+    package = "array"
+    name = "join"
+    inherit_cache = True
+
+
+class get_element(GenericFunction):  # noqa: N801
+    """
+    Returns the element at the given index in the array.
+    """
+
+    package = "array"
+    name = "get_element"
+    inherit_cache = True
+
+
 class sip_hash_64(GenericFunction):  # noqa: N801
     """
     Computes the SipHash-64 hash of the array.
@@ -63,4 +94,5 @@ compiler_not_implemented(cosine_distance)
 compiler_not_implemented(euclidean_distance)
 compiler_not_implemented(length)
 compiler_not_implemented(contains)
+compiler_not_implemented(get_element)
 compiler_not_implemented(sip_hash_64)
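Note: because each class sets package = "array", SQLAlchemy exposes these under the func.array namespace; compiler_not_implemented leaves a function without a default SQL compilation, so a backend (e.g. the SQLite setup in datachain/sql/sqlite/base.py) has to register its own before such expressions can execute. A hedged sketch of building the expressions (the column name and arguments are illustrative; the exact argument conventions are defined by the backend compilers):

    from sqlalchemy import column, func

    tags = column("tags")
    func.array.join(tags, ", ")      # String-typed result, per the class above
    func.array.get_element(tags, 0)
    func.array.slice(tags, 0, 2)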
datachain/sql/postgresql_dialect.py ADDED
@@ -0,0 +1,9 @@
+"""
+PostgreSQL dialect registration for DataChain.
+"""
+
+from datachain.sql.postgresql_types import PostgreSQLTypeConverter
+from datachain.sql.types import register_backend_types
+
+# Register PostgreSQL type converter
+register_backend_types("postgresql", PostgreSQLTypeConverter())
datachain/sql/postgresql_types.py ADDED
@@ -0,0 +1,21 @@
+"""
+PostgreSQL-specific type converter for DataChain.
+
+Handles PostgreSQL-specific type mappings that differ from the default dialect.
+"""
+
+from sqlalchemy.dialects import postgresql
+
+from datachain.sql.types import TypeConverter
+
+
+class PostgreSQLTypeConverter(TypeConverter):
+    """PostgreSQL-specific type converter."""
+
+    def datetime(self):
+        """PostgreSQL uses TIMESTAMP WITH TIME ZONE to preserve timezone information."""
+        return postgresql.TIMESTAMP(timezone=True)
+
+    def json(self):
+        """PostgreSQL uses JSONB for better performance and query capabilities."""
+        return postgresql.JSONB()
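Note: these overrides only matter at type-compilation (DDL) time. A small sketch of what they render to, using the stock SQLAlchemy PostgreSQL dialect:

    from sqlalchemy.dialects import postgresql

    conv = PostgreSQLTypeConverter()
    pg = postgresql.dialect()

    print(conv.datetime().compile(dialect=pg))  # TIMESTAMP WITH TIME ZONE
    print(conv.json().compile(dialect=pg))      # JSONB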
datachain/sql/sqlite/__init__.py CHANGED
@@ -1,4 +1,8 @@
-from .base import create_user_defined_sql_functions, setup, sqlite_dialect
+from .base import (
+    create_user_defined_sql_functions,
+    setup,
+    sqlite_dialect,
+)
 
 __all__ = [
     "create_user_defined_sql_functions",