datachain 0.14.2__py3-none-any.whl → 0.39.0__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to their public registry. It is provided for informational purposes only.
Files changed (137)
  1. datachain/__init__.py +20 -0
  2. datachain/asyn.py +11 -12
  3. datachain/cache.py +7 -7
  4. datachain/catalog/__init__.py +2 -2
  5. datachain/catalog/catalog.py +621 -507
  6. datachain/catalog/dependency.py +164 -0
  7. datachain/catalog/loader.py +28 -18
  8. datachain/checkpoint.py +43 -0
  9. datachain/cli/__init__.py +24 -33
  10. datachain/cli/commands/__init__.py +1 -8
  11. datachain/cli/commands/datasets.py +83 -52
  12. datachain/cli/commands/ls.py +17 -17
  13. datachain/cli/commands/show.py +4 -4
  14. datachain/cli/parser/__init__.py +8 -74
  15. datachain/cli/parser/job.py +95 -3
  16. datachain/cli/parser/studio.py +11 -4
  17. datachain/cli/parser/utils.py +1 -2
  18. datachain/cli/utils.py +2 -15
  19. datachain/client/azure.py +4 -4
  20. datachain/client/fsspec.py +45 -28
  21. datachain/client/gcs.py +6 -6
  22. datachain/client/hf.py +29 -2
  23. datachain/client/http.py +157 -0
  24. datachain/client/local.py +15 -11
  25. datachain/client/s3.py +17 -9
  26. datachain/config.py +4 -8
  27. datachain/data_storage/db_engine.py +12 -6
  28. datachain/data_storage/job.py +5 -1
  29. datachain/data_storage/metastore.py +1252 -186
  30. datachain/data_storage/schema.py +58 -45
  31. datachain/data_storage/serializer.py +105 -15
  32. datachain/data_storage/sqlite.py +286 -127
  33. datachain/data_storage/warehouse.py +250 -113
  34. datachain/dataset.py +353 -148
  35. datachain/delta.py +391 -0
  36. datachain/diff/__init__.py +27 -29
  37. datachain/error.py +60 -0
  38. datachain/func/__init__.py +2 -1
  39. datachain/func/aggregate.py +66 -42
  40. datachain/func/array.py +242 -38
  41. datachain/func/base.py +7 -4
  42. datachain/func/conditional.py +110 -60
  43. datachain/func/func.py +96 -45
  44. datachain/func/numeric.py +55 -38
  45. datachain/func/path.py +32 -20
  46. datachain/func/random.py +2 -2
  47. datachain/func/string.py +67 -37
  48. datachain/func/window.py +7 -8
  49. datachain/hash_utils.py +123 -0
  50. datachain/job.py +11 -7
  51. datachain/json.py +138 -0
  52. datachain/lib/arrow.py +58 -22
  53. datachain/lib/audio.py +245 -0
  54. datachain/lib/clip.py +14 -13
  55. datachain/lib/convert/flatten.py +5 -3
  56. datachain/lib/convert/python_to_sql.py +6 -10
  57. datachain/lib/convert/sql_to_python.py +8 -0
  58. datachain/lib/convert/values_to_tuples.py +156 -51
  59. datachain/lib/data_model.py +42 -20
  60. datachain/lib/dataset_info.py +36 -8
  61. datachain/lib/dc/__init__.py +8 -2
  62. datachain/lib/dc/csv.py +25 -28
  63. datachain/lib/dc/database.py +398 -0
  64. datachain/lib/dc/datachain.py +1289 -425
  65. datachain/lib/dc/datasets.py +320 -38
  66. datachain/lib/dc/hf.py +38 -24
  67. datachain/lib/dc/json.py +29 -32
  68. datachain/lib/dc/listings.py +112 -8
  69. datachain/lib/dc/pandas.py +16 -12
  70. datachain/lib/dc/parquet.py +35 -23
  71. datachain/lib/dc/records.py +31 -23
  72. datachain/lib/dc/storage.py +154 -64
  73. datachain/lib/dc/storage_pattern.py +251 -0
  74. datachain/lib/dc/utils.py +24 -16
  75. datachain/lib/dc/values.py +8 -9
  76. datachain/lib/file.py +622 -89
  77. datachain/lib/hf.py +69 -39
  78. datachain/lib/image.py +14 -14
  79. datachain/lib/listing.py +14 -11
  80. datachain/lib/listing_info.py +1 -2
  81. datachain/lib/meta_formats.py +3 -4
  82. datachain/lib/model_store.py +39 -7
  83. datachain/lib/namespaces.py +125 -0
  84. datachain/lib/projects.py +130 -0
  85. datachain/lib/pytorch.py +32 -21
  86. datachain/lib/settings.py +192 -56
  87. datachain/lib/signal_schema.py +427 -104
  88. datachain/lib/tar.py +1 -2
  89. datachain/lib/text.py +8 -7
  90. datachain/lib/udf.py +164 -76
  91. datachain/lib/udf_signature.py +60 -35
  92. datachain/lib/utils.py +118 -4
  93. datachain/lib/video.py +17 -9
  94. datachain/lib/webdataset.py +61 -56
  95. datachain/lib/webdataset_laion.py +15 -16
  96. datachain/listing.py +22 -10
  97. datachain/model/bbox.py +3 -1
  98. datachain/model/ultralytics/bbox.py +16 -12
  99. datachain/model/ultralytics/pose.py +16 -12
  100. datachain/model/ultralytics/segment.py +16 -12
  101. datachain/namespace.py +84 -0
  102. datachain/node.py +6 -6
  103. datachain/nodes_thread_pool.py +0 -1
  104. datachain/plugins.py +24 -0
  105. datachain/project.py +78 -0
  106. datachain/query/batch.py +40 -41
  107. datachain/query/dataset.py +604 -322
  108. datachain/query/dispatch.py +261 -154
  109. datachain/query/metrics.py +4 -6
  110. datachain/query/params.py +2 -3
  111. datachain/query/queue.py +3 -12
  112. datachain/query/schema.py +11 -6
  113. datachain/query/session.py +200 -33
  114. datachain/query/udf.py +34 -2
  115. datachain/remote/studio.py +171 -69
  116. datachain/script_meta.py +12 -12
  117. datachain/semver.py +68 -0
  118. datachain/sql/__init__.py +2 -0
  119. datachain/sql/functions/array.py +33 -1
  120. datachain/sql/postgresql_dialect.py +9 -0
  121. datachain/sql/postgresql_types.py +21 -0
  122. datachain/sql/sqlite/__init__.py +5 -1
  123. datachain/sql/sqlite/base.py +102 -29
  124. datachain/sql/sqlite/types.py +8 -13
  125. datachain/sql/types.py +70 -15
  126. datachain/studio.py +223 -46
  127. datachain/toolkit/split.py +31 -10
  128. datachain/utils.py +101 -59
  129. {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/METADATA +77 -22
  130. datachain-0.39.0.dist-info/RECORD +173 -0
  131. {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/WHEEL +1 -1
  132. datachain/cli/commands/query.py +0 -53
  133. datachain/query/utils.py +0 -42
  134. datachain-0.14.2.dist-info/RECORD +0 -158
  135. {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/entry_points.txt +0 -0
  136. {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/licenses/LICENSE +0 -0
  137. {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/top_level.txt +0 -0
datachain/cli/commands/ls.py CHANGED
@@ -1,13 +1,14 @@
 import shlex
 from collections.abc import Iterable, Iterator
 from itertools import chain
-from typing import TYPE_CHECKING, Optional
-
-if TYPE_CHECKING:
-    from datachain.catalog import Catalog
+from typing import TYPE_CHECKING
 
 from datachain.cli.utils import determine_flavors
 from datachain.config import Config
+from datachain.query.session import Session
+
+if TYPE_CHECKING:
+    from datachain.catalog import Catalog
 
 
 def ls(
@@ -16,7 +17,7 @@ def ls(
     studio: bool = False,
     local: bool = False,
     all: bool = True,
-    team: Optional[str] = None,
+    team: str | None = None,
     **kwargs,
 ):
     token = Config().read().get("studio", {}).get("token")
@@ -32,18 +33,15 @@ def ls(
 def ls_local(
     sources,
     long: bool = False,
-    catalog: Optional["Catalog"] = None,
+    catalog=None,
     client_config=None,
     **kwargs,
 ):
     from datachain import listings
 
     if sources:
-        if catalog is None:
-            from datachain.catalog import get_catalog
-
-            catalog = get_catalog(client_config=client_config)
-
+        session = Session.get(catalog=catalog, client_config=client_config)
+        catalog = session.catalog
         actual_sources = list(ls_urls(sources, catalog=catalog, long=long, **kwargs))
         if len(actual_sources) == 1:
             for _, entries in actual_sources:
@@ -63,8 +61,8 @@ def ls_local(
             print(format_ls_entry(entry))
     else:
         # Collect results in a list here to prevent interference from `tqdm` and `print`
-        listing = list(listings().collect("listing"))
-        for ls in listing:
+        listing = listings().to_list("listing")
+        for (ls,) in listing:
             print(format_ls_entry(f"{ls.uri}@v{ls.version}"))  # type: ignore[union-attr]
 
 
@@ -78,7 +76,7 @@ def format_ls_entry(entry: str) -> str:
 def ls_remote(
     paths: Iterable[str],
     long: bool = False,
-    team: Optional[str] = None,
+    team: str | None = None,
 ):
     from datachain.node import long_line_str
     from datachain.remote.studio import StudioClient
@@ -145,7 +143,7 @@ def _ls_urls_flat(
     long: bool,
     catalog: "Catalog",
     **kwargs,
-) -> Iterator[tuple[str, Iterator[str]]]:
+) -> Iterator[tuple[str, Iterable[str]]]:
     from datachain.client import Client
     from datachain.node import long_line_str
 
@@ -154,7 +152,9 @@ def _ls_urls_flat(
         if client_cls.is_root_url(source):
             buckets = client_cls.ls_buckets(**catalog.client_config)
             if long:
-                values = (long_line_str(b.name, b.created) for b in buckets)
+                values: Iterable[str] = (
+                    long_line_str(b.name, b.created) for b in buckets
+                )
             else:
                 values = (b.name for b in buckets)
             yield source, values
@@ -164,7 +164,7 @@ def _ls_urls_flat(
         if long:
             fields.append("last_modified")
         for data_source, results in catalog.ls([source], fields=fields, **kwargs):
-            values = (_node_data_to_ls_values(r, long) for r in results)
+            values = [_node_data_to_ls_values(r, long) for r in results]
             found = True
             yield data_source.dirname(), values
         if not found:
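
Context for the ls_local change: 0.39 replaces collect() with to_list(), which returns one tuple per row, hence the new `for (ls,) in listing` unpacking. A minimal sketch of the pattern, assuming the 0.39-era read_values constructor and made-up data:

    import datachain as dc

    chain = dc.read_values(num=[1, 2, 3])  # illustrative values only

    # collect("num") used to yield bare values; to_list("num") returns a list
    # of row tuples, so one selected signal still arrives as a 1-tuple:
    for (num,) in chain.to_list("num"):
        print(num)
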
datachain/cli/commands/show.py CHANGED
@@ -1,5 +1,5 @@
 from collections.abc import Sequence
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING
 
 from datachain.lib.signal_schema import SignalSchema
 
@@ -10,7 +10,7 @@ if TYPE_CHECKING:
 def show(
     catalog: "Catalog",
     name: str,
-    version: Optional[int] = None,
+    version: str | None = None,
     limit: int = 10,
     offset: int = 0,
     columns: Sequence[str] = (),
@@ -42,8 +42,8 @@ def show(
     print("Name: ", name)
     if dataset.description:
         print("Description: ", dataset.description)
-    if dataset.labels:
-        print("Labels: ", ",".join(dataset.labels))
+    if dataset.attrs:
+        print("Attributes: ", ",".join(dataset.attrs))
     print("\n")
 
     show_records(records, collapse_columns=not no_collapse, hidden_fields=hidden_fields)
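
The version parameter moving from Optional[int] to str | None matches the switch to semver-style dataset versions (note the new datachain/semver.py in the file list). A stdlib-only illustration of the value shape now expected; the version string is made up:

    major, minor, patch = map(int, "1.0.2".split("."))
    print(major, minor, patch)  # 1 0 2
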
datachain/cli/parser/__init__.py CHANGED
@@ -3,7 +3,7 @@ from importlib.metadata import PackageNotFoundError, version
 
 import shtab
 
-from datachain.cli.utils import BooleanOptionalAction, KeyValueArgs
+from datachain.cli.utils import BooleanOptionalAction
 
 from .job import add_jobs_parser
 from .studio import add_auth_parser
@@ -16,9 +16,7 @@ from .utils import (
     add_update_arg,
     find_columns_type,
 )
-from .utils import (
-    CustomArgumentParser as ArgumentParser,
-)
+from .utils import CustomArgumentParser as ArgumentParser
 
 
 def get_parser() -> ArgumentParser:  # noqa: PLR0915
@@ -217,29 +215,9 @@ def get_parser() -> ArgumentParser:  # noqa: PLR0915
         help="Dataset description",
     )
     parse_edit_dataset.add_argument(
-        "--labels",
+        "--attrs",
         nargs="+",
-        help="Dataset labels",
-    )
-    parse_edit_dataset.add_argument(
-        "--studio",
-        action="store_true",
-        default=False,
-        help="Edit dataset from Studio",
-    )
-    parse_edit_dataset.add_argument(
-        "-L",
-        "--local",
-        action="store_true",
-        default=False,
-        help="Edit local dataset only",
-    )
-    parse_edit_dataset.add_argument(
-        "-a",
-        "--all",
-        action="store_true",
-        default=True,
-        help="Edit both datasets from studio and local",
+        help="Dataset attributes",
     )
     parse_edit_dataset.add_argument(
         "--team",
@@ -302,7 +280,7 @@ def get_parser() -> ArgumentParser:  # noqa: PLR0915
         "--version",
         action="store",
         default=None,
-        type=int,
+        type=str,
         help="Dataset version",
     )
     rm_dataset_parser.add_argument(
@@ -315,21 +293,7 @@ def get_parser() -> ArgumentParser:  # noqa: PLR0915
         "--studio",
         action="store_true",
         default=False,
-        help="Remove dataset from Studio",
-    )
-    rm_dataset_parser.add_argument(
-        "-L",
-        "--local",
-        action="store_true",
-        default=False,
-        help="Remove local datasets only",
-    )
-    rm_dataset_parser.add_argument(
-        "-a",
-        "--all",
-        action="store_true",
-        default=True,
-        help="Remove both local and studio",
+        help="Remove dataset from Studio only",
     )
     rm_dataset_parser.add_argument(
         "--team",
@@ -495,43 +459,12 @@ def get_parser() -> ArgumentParser:  # noqa: PLR0915
         "--version",
         action="store",
         default=None,
-        type=int,
+        type=str,
         help="Dataset version",
     )
     show_parser.add_argument("--schema", action="store_true", help="Show schema")
     add_show_args(show_parser)
 
-    query_parser = subp.add_parser(
-        "query",
-        parents=[parent_parser],
-        description="Create a new dataset with a query script.",
-        formatter_class=CustomHelpFormatter,
-    )
-    add_anon_arg(query_parser)
-    query_parser.add_argument(
-        "script", metavar="<script.py>", type=str, help="Filepath for script"
-    )
-    query_parser.add_argument(
-        "--parallel",
-        nargs="?",
-        type=int,
-        const=-1,
-        default=None,
-        metavar="N",
-        help=(
-            "Use multiprocessing to run any query script UDFs with N worker processes. "
-            "N defaults to the CPU count"
-        ),
-    )
-    query_parser.add_argument(
-        "-p",
-        "--param",
-        metavar="param=value",
-        nargs=1,
-        action=KeyValueArgs,
-        help="Query parameters",
-    )
-
     parse_clear_cache = subp.add_parser(
         "clear-cache",
         parents=[parent_parser],
@@ -550,6 +483,7 @@ def get_parser() -> ArgumentParser:  # noqa: PLR0915
 
     subp.add_parser("internal-run-udf", parents=[parent_parser])
     subp.add_parser("internal-run-udf-worker", parents=[parent_parser])
+
     add_completion_parser(subp, [parent_parser])
     return parser
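
Both --version options above likewise switch from type=int to type=str. A self-contained argparse sketch (not the real datachain parser) of why the int variant had to go:

    from argparse import ArgumentParser

    p = ArgumentParser(prog="demo")
    p.add_argument("--version", action="store", default=None, type=str)

    print(p.parse_args(["--version", "1.2.3"]).version)  # 1.2.3
    # with type=int the same input would abort: int("1.2.3") raises ValueError
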
datachain/cli/parser/job.py CHANGED
@@ -13,11 +13,16 @@ def add_jobs_parser(subparsers, parent_parser) -> None:
     )
     jobs_subparser = jobs_parser.add_subparsers(
         dest="cmd",
-        help="Use `datachain auth CMD --help` to display command-specific help",
+        help="Use `datachain job CMD --help` to display command-specific help",
     )
 
     studio_run_help = "Run a job in Studio"
-    studio_run_description = "Run a job in Studio."
+    studio_run_description = "Run a job in Studio. \n"
+    studio_run_description += (
+        "When using --start-time or --cron,"
+        " the job is scheduled to run but won't start immediately"
+        " (can be seen in the Tasks tab in UI)"
+    )
 
     studio_run_parser = jobs_subparser.add_parser(
         "run",
@@ -51,6 +56,20 @@ def add_jobs_parser(subparsers, parent_parser) -> None:
         help="Environment variables in KEY=VALUE format",
     )
 
+    studio_run_parser.add_argument(
+        "--cluster",
+        type=str,
+        action="store",
+        help="Compute cluster to run the job on",
+    )
+
+    studio_run_parser.add_argument(
+        "-c",
+        "--credentials-name",
+        action="store",
+        help="Name of the credentials to use for the job",
+    )
+
     studio_run_parser.add_argument(
         "--workers",
         type=int,
@@ -64,7 +83,12 @@ def add_jobs_parser(subparsers, parent_parser) -> None:
     studio_run_parser.add_argument(
         "--python-version",
         action="store",
-        help="Python version for the job (e.g., 3.9, 3.10, 3.11)",
+        help="Python version for the job (e.g., 3.10, 3.11, 3.12, 3.13)",
+    )
+    studio_run_parser.add_argument(
+        "--repository",
+        action="store",
+        help="Repository URL to clone before running the job",
     )
     studio_run_parser.add_argument(
         "--req-file",
@@ -77,6 +101,56 @@ def add_jobs_parser(subparsers, parent_parser) -> None:
         nargs="+",
         help="Python package requirements",
     )
+    studio_run_parser.add_argument(
+        "--priority",
+        type=int,
+        default=5,
+        help="Priority for the job in range 0-5. "
+        "Lower value is higher priority (default: 5)",
+    )
+    studio_run_parser.add_argument(
+        "--start-time",
+        action="store",
+        help="Time to schedule a task in YYYY-MM-DDTHH:mm format or natural language.",
+    )
+    studio_run_parser.add_argument(
+        "--cron", action="store", help="Cron expression for the cron task."
+    )
+    studio_run_parser.add_argument(
+        "--no-wait",
+        action="store_true",
+        help="Do not wait for the job to finish",
+    )
+
+    studio_ls_help = "List jobs in Studio"
+    studio_ls_description = "List jobs in Studio."
+
+    studio_ls_parser = jobs_subparser.add_parser(
+        "ls",
+        parents=[parent_parser],
+        description=studio_ls_description,
+        help=studio_ls_help,
+        formatter_class=CustomHelpFormatter,
+    )
+
+    studio_ls_parser.add_argument(
+        "--status",
+        action="store",
+        help="Status to filter jobs by",
+    )
+
+    studio_ls_parser.add_argument(
+        "--team",
+        action="store",
+        default=None,
+        help="Team to list jobs for (default: from config)",
+    )
+    studio_ls_parser.add_argument(
+        "--limit",
+        type=int,
+        default=20,
+        help="Limit the number of jobs returned (default: 20)",
+    )
 
     studio_cancel_help = "Cancel a job in Studio"
     studio_cancel_description = "Cancel a running job in Studio."
@@ -123,3 +197,21 @@ def add_jobs_parser(subparsers, parent_parser) -> None:
         default=None,
         help="Team to check logs for (default: from config)",
     )
+
+    studio_clusters_help = "List compute clusters in Studio"
+    studio_clusters_description = "List compute clusters in Studio."
+
+    studio_clusters_parser = jobs_subparser.add_parser(
+        "clusters",
+        parents=[parent_parser],
+        description=studio_clusters_description,
+        help=studio_clusters_help,
+        formatter_class=CustomHelpFormatter,
+    )
+
+    studio_clusters_parser.add_argument(
+        "--team",
+        action="store",
+        default=None,
+        help="Team to list clusters for (default: from config)",
+    )
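
A self-contained sketch (not the real datachain parser) that rebuilds just the new scheduling flags of `job run` to show how they parse; the sample argv is made up:

    from argparse import ArgumentParser

    run = ArgumentParser(prog="datachain job run")
    run.add_argument("--priority", type=int, default=5)   # 0-5, lower runs first
    run.add_argument("--start-time", action="store")      # schedule, don't start now
    run.add_argument("--cron", action="store")            # recurring schedule
    run.add_argument("--no-wait", action="store_true")    # return without waiting

    args = run.parse_args(["--cron", "0 3 * * *", "--no-wait"])
    print(args.priority, args.cron, args.no_wait)  # 5 0 3 * * * True
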
datachain/cli/parser/studio.py CHANGED
@@ -89,8 +89,13 @@ def add_auth_parser(subparsers, parent_parser) -> None:
         help="Remove the token from the local project config",
     )
 
-    auth_team_help = "Set default team for Studio operations"
-    auth_team_description = "Set the default team for Studio operations."
+    auth_team_help = "Set or show default team for Studio operations"
+    auth_team_description = (
+        "Set or show the default team for Studio operations. "
+        "This will be used globally by default. "
+        "Use --local to set the team locally for the current project. "
+        "If no team name is provided, the default team will be shown."
+    )
 
     team_parser = auth_subparser.add_parser(
         "team",
@@ -102,13 +107,15 @@ def add_auth_parser(subparsers, parent_parser) -> None:
     team_parser.add_argument(
         "team_name",
         action="store",
+        default=None,
+        nargs="?",
         help="Name of the team to set as default",
     )
     team_parser.add_argument(
-        "--global",
+        "--local",
         action="store_true",
         default=False,
-        help="Set team globally for all projects",
+        help="Set team locally for the current project",
     )
 
     auth_token_help = "View Studio authentication token"  # noqa: S105
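
With nargs="?" the team_name positional is now optional, so running the command without an argument parses as None and the CLI shows the current default team instead of erroring. A minimal argparse sketch of that pattern:

    from argparse import ArgumentParser

    team = ArgumentParser(prog="datachain auth team")
    team.add_argument("team_name", nargs="?", default=None)
    team.add_argument("--local", action="store_true", default=False)

    print(team.parse_args([]).team_name)           # None -> show default team
    print(team.parse_args(["my-team"]).team_name)  # my-team -> set default
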
datachain/cli/parser/utils.py CHANGED
@@ -1,5 +1,4 @@
 from argparse import Action, ArgumentParser, ArgumentTypeError, HelpFormatter
-from typing import Union
 
 from datachain.cli.utils import CommaSeparatedArgs
 
@@ -44,7 +43,7 @@ def parse_find_column(column: str) -> str:
     )
 
 
-def add_sources_arg(parser: ArgumentParser, nargs: Union[str, int] = "+") -> Action:
+def add_sources_arg(parser: ArgumentParser, nargs: str | int = "+") -> Action:
     return parser.add_argument(
         "sources",
         type=str,
datachain/cli/utils.py CHANGED
@@ -1,6 +1,5 @@
 import logging
-from argparse import SUPPRESS, Action, ArgumentError, Namespace, _AppendAction
-from typing import Optional
+from argparse import SUPPRESS, Action, Namespace, _AppendAction
 
 from datachain.error import DataChainError
 
@@ -64,18 +63,6 @@ class CommaSeparatedArgs(_AppendAction):  # pylint: disable=protected-access
         setattr(namespace, self.dest, list(dict.fromkeys(items)))
 
 
-class KeyValueArgs(_AppendAction):  # pylint: disable=protected-access
-    def __call__(self, parser, namespace, values, option_string=None):
-        items = getattr(namespace, self.dest) or {}
-        for raw_value in filter(bool, values):
-            key, sep, value = raw_value.partition("=")
-            if not key or not sep or value == "":
-                raise ArgumentError(self, f"expected 'key=value', got {raw_value!r}")
-            items[key.strip()] = value
-
-        setattr(namespace, self.dest, items)
-
-
 def get_logging_level(args: Namespace) -> int:
     if args.quiet:
         return logging.CRITICAL
@@ -84,7 +71,7 @@ def get_logging_level(args: Namespace) -> int:
         return logging.INFO
 
 
-def determine_flavors(studio: bool, local: bool, all: bool, token: Optional[str]):
+def determine_flavors(studio: bool, local: bool, all: bool, token: str | None):
     if studio and not token:
         raise DataChainError(
             "Not logged in to Studio. Log in with 'datachain auth login'."
datachain/client/azure.py CHANGED
@@ -1,4 +1,4 @@
-from typing import Any, Optional
+from typing import Any
 from urllib.parse import parse_qs, urlsplit, urlunsplit
 
 from adlfs import AzureBlobFileSystem
@@ -15,7 +15,7 @@ class AzureClient(Client):
     protocol = "az"
 
     def info_to_file(self, v: dict[str, Any], path: str) -> File:
-        version_id = v.get("version_id")
+        version_id = v.get("version_id") if self._is_version_aware() else None
         return File(
             source=self.uri,
             path=path,
@@ -65,7 +65,7 @@ class AzureClient(Client):
                 if entries:
                     await result_queue.put(entries)
                     pbar.update(len(entries))
-                if not found:
+                if not found and prefix:
                     raise FileNotFoundError(
                         f"Unable to resolve remote path: {prefix}"
                     )
@@ -73,7 +73,7 @@ class AzureClient(Client):
         result_queue.put_nowait(None)
 
     @classmethod
-    def version_path(cls, path: str, version_id: Optional[str]) -> str:
+    def version_path(cls, path: str, version_id: str | None) -> str:
         parts = list(urlsplit(path))
         query = parse_qs(parts[3])
         if "versionid" in query:
datachain/client/fsspec.py CHANGED
@@ -10,15 +10,7 @@ from abc import ABC, abstractmethod
 from collections.abc import AsyncIterator, Iterator, Sequence
 from datetime import datetime
 from shutil import copy2
-from typing import (
-    TYPE_CHECKING,
-    Any,
-    BinaryIO,
-    ClassVar,
-    NamedTuple,
-    Optional,
-    Union,
-)
+from typing import TYPE_CHECKING, Any, BinaryIO, ClassVar, NamedTuple
 from urllib.parse import urlparse
 
 from dvc_objects.fs.system import reflink
@@ -44,11 +36,12 @@ FETCH_WORKERS = 100
 DELIMITER = "/"  # Path delimiter.
 
 DATA_SOURCE_URI_PATTERN = re.compile(r"^[\w]+:\/\/.*$")
+CLOUD_STORAGE_PROTOCOLS = {"s3", "gs", "az", "hf"}
 
-ResultQueue = asyncio.Queue[Optional[Sequence["File"]]]
+ResultQueue = asyncio.Queue[Sequence["File"] | None]
 
 
-def _is_win_local_path(uri: str) -> bool:
+def is_win_local_path(uri: str) -> bool:
     if sys.platform == "win32":
         if len(uri) >= 1 and uri[0] == "\\":
             return True
@@ -62,10 +55,20 @@ def _is_win_local_path(uri: str) -> bool:
     return False
 
 
+def is_cloud_uri(uri: str) -> bool:
+    protocol = urlparse(uri).scheme
+    return protocol in CLOUD_STORAGE_PROTOCOLS
+
+
+def get_cloud_schemes() -> list[str]:
+    """Get list of cloud storage scheme prefixes."""
+    return [f"{p}://" for p in CLOUD_STORAGE_PROTOCOLS]
+
+
 class Bucket(NamedTuple):
     name: str
     uri: "StorageURI"
-    created: Optional[datetime]
+    created: datetime | None
 
 
 class Client(ABC):
@@ -77,21 +80,22 @@ class Client(ABC):
     def __init__(self, name: str, fs_kwargs: dict[str, Any], cache: Cache) -> None:
         self.name = name
         self.fs_kwargs = fs_kwargs
-        self._fs: Optional[AbstractFileSystem] = None
+        self._fs: AbstractFileSystem | None = None
         self.cache = cache
         self.uri = self.get_uri(self.name)
 
     @staticmethod
-    def get_implementation(url: Union[str, os.PathLike[str]]) -> type["Client"]:
+    def get_implementation(url: str | os.PathLike[str]) -> type["Client"]:  # noqa: PLR0911
         from .azure import AzureClient
         from .gcs import GCSClient
         from .hf import HfClient
+        from .http import HTTPClient, HTTPSClient
         from .local import FileClient
         from .s3 import ClientS3
 
         protocol = urlparse(os.fspath(url)).scheme
 
-        if not protocol or _is_win_local_path(os.fspath(url)):
+        if not protocol or is_win_local_path(os.fspath(url)):
             return FileClient
         if protocol == ClientS3.protocol:
             return ClientS3
@@ -103,9 +107,18 @@ class Client(ABC):
             return FileClient
         if protocol == HfClient.protocol:
             return HfClient
+        if protocol == HTTPClient.protocol:
+            return HTTPClient
+        if protocol == HTTPSClient.protocol:
+            return HTTPSClient
 
         raise NotImplementedError(f"Unsupported protocol: {protocol}")
 
+    @classmethod
+    def path_to_uri(cls, path: str) -> str:
+        """Convert a path-like object to a URI. Default: identity."""
+        return path
+
     @staticmethod
     def is_data_source_uri(name: str) -> bool:
         # Returns True if name is one of supported data sources URIs, e.g s3 bucket
@@ -118,9 +131,7 @@ class Client(ABC):
         return cls.get_uri(storage_name), rel_path
 
     @staticmethod
-    def get_client(
-        source: Union[str, os.PathLike[str]], cache: Cache, **kwargs
-    ) -> "Client":
+    def get_client(source: str | os.PathLike[str], cache: Cache, **kwargs) -> "Client":
         cls = Client.get_implementation(source)
         storage_url, _ = cls.split_url(os.fspath(source))
         if os.name == "nt":
@@ -136,7 +147,7 @@ class Client(ABC):
         return fs
 
     @classmethod
-    def version_path(cls, path: str, version_id: Optional[str]) -> str:
+    def version_path(cls, path: str, version_id: str | None) -> str:
         return path
 
     @classmethod
@@ -207,24 +218,25 @@ class Client(ABC):
         )
 
     async def get_current_etag(self, file: "File") -> str:
+        file_path = file.get_path_normalized()
         kwargs = {}
-        if getattr(self.fs, "version_aware", False):
+        if self._is_version_aware():
             kwargs["version_id"] = file.version
         info = await self.fs._info(
-            self.get_full_path(file.path, file.version), **kwargs
+            self.get_full_path(file_path, file.version), **kwargs
        )
-        return self.info_to_file(info, file.path).etag
+        return self.info_to_file(info, file_path).etag
 
-    def get_file_info(self, path: str, version_id: Optional[str] = None) -> "File":
+    def get_file_info(self, path: str, version_id: str | None = None) -> "File":
         info = self.fs.info(self.get_full_path(path, version_id), version_id=version_id)
         return self.info_to_file(info, path)
 
-    async def get_size(self, path: str, version_id: Optional[str] = None) -> int:
+    async def get_size(self, path: str, version_id: str | None = None) -> int:
         return await self.fs._size(
             self.version_path(path, version_id), version_id=version_id
         )
 
-    async def get_file(self, lpath, rpath, callback, version_id: Optional[str] = None):
+    async def get_file(self, lpath, rpath, callback, version_id: str | None = None):
         return await self.fs._get_file(
             self.version_path(lpath, version_id),
             rpath,
@@ -326,15 +338,19 @@ class Client(ABC):
         """
         return not (key.startswith("/") or key.endswith("/") or "//" in key)
 
+    def _is_version_aware(self) -> bool:
+        return getattr(self.fs, "version_aware", False)
+
     async def ls_dir(self, path):
-        if getattr(self.fs, "version_aware", False):
+        kwargs = {}
+        if self._is_version_aware():
             kwargs = {"versions": True}
         return await self.fs._ls(path, detail=True, **kwargs)
 
     def rel_path(self, path: str) -> str:
         return self.fs.split_path(path)[1]
 
-    def get_full_path(self, rel_path: str, version_id: Optional[str] = None) -> str:
+    def get_full_path(self, rel_path: str, version_id: str | None = None) -> str:
         return self.version_path(f"{self.PREFIX}{self.name}/{rel_path}", version_id)
 
     @abstractmethod
@@ -382,7 +398,8 @@ class Client(ABC):
             return open(cache_path, mode="rb")
         assert not file.location
         return FileWrapper(
-            self.fs.open(self.get_full_path(file.path, file.version)), cb
+            self.fs.open(self.get_full_path(file.get_path_normalized(), file.version)),
+            cb,
         )  # type: ignore[return-value]
 
     def upload(self, data: bytes, path: str) -> "File":
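
Usage sketch for the new module-level helpers in this file; the definitions are copied from the hunk above, the calls are illustrative:

    from urllib.parse import urlparse

    CLOUD_STORAGE_PROTOCOLS = {"s3", "gs", "az", "hf"}

    def is_cloud_uri(uri: str) -> bool:
        protocol = urlparse(uri).scheme
        return protocol in CLOUD_STORAGE_PROTOCOLS

    def get_cloud_schemes() -> list[str]:
        """Get list of cloud storage scheme prefixes."""
        return [f"{p}://" for p in CLOUD_STORAGE_PROTOCOLS]

    print(is_cloud_uri("s3://bucket/key"))  # True
    print(is_cloud_uri("/local/path"))      # False
    print(sorted(get_cloud_schemes()))      # ['az://', 'gs://', 'hf://', 's3://']
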