datachain 0.14.2__py3-none-any.whl → 0.39.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (137)
  1. datachain/__init__.py +20 -0
  2. datachain/asyn.py +11 -12
  3. datachain/cache.py +7 -7
  4. datachain/catalog/__init__.py +2 -2
  5. datachain/catalog/catalog.py +621 -507
  6. datachain/catalog/dependency.py +164 -0
  7. datachain/catalog/loader.py +28 -18
  8. datachain/checkpoint.py +43 -0
  9. datachain/cli/__init__.py +24 -33
  10. datachain/cli/commands/__init__.py +1 -8
  11. datachain/cli/commands/datasets.py +83 -52
  12. datachain/cli/commands/ls.py +17 -17
  13. datachain/cli/commands/show.py +4 -4
  14. datachain/cli/parser/__init__.py +8 -74
  15. datachain/cli/parser/job.py +95 -3
  16. datachain/cli/parser/studio.py +11 -4
  17. datachain/cli/parser/utils.py +1 -2
  18. datachain/cli/utils.py +2 -15
  19. datachain/client/azure.py +4 -4
  20. datachain/client/fsspec.py +45 -28
  21. datachain/client/gcs.py +6 -6
  22. datachain/client/hf.py +29 -2
  23. datachain/client/http.py +157 -0
  24. datachain/client/local.py +15 -11
  25. datachain/client/s3.py +17 -9
  26. datachain/config.py +4 -8
  27. datachain/data_storage/db_engine.py +12 -6
  28. datachain/data_storage/job.py +5 -1
  29. datachain/data_storage/metastore.py +1252 -186
  30. datachain/data_storage/schema.py +58 -45
  31. datachain/data_storage/serializer.py +105 -15
  32. datachain/data_storage/sqlite.py +286 -127
  33. datachain/data_storage/warehouse.py +250 -113
  34. datachain/dataset.py +353 -148
  35. datachain/delta.py +391 -0
  36. datachain/diff/__init__.py +27 -29
  37. datachain/error.py +60 -0
  38. datachain/func/__init__.py +2 -1
  39. datachain/func/aggregate.py +66 -42
  40. datachain/func/array.py +242 -38
  41. datachain/func/base.py +7 -4
  42. datachain/func/conditional.py +110 -60
  43. datachain/func/func.py +96 -45
  44. datachain/func/numeric.py +55 -38
  45. datachain/func/path.py +32 -20
  46. datachain/func/random.py +2 -2
  47. datachain/func/string.py +67 -37
  48. datachain/func/window.py +7 -8
  49. datachain/hash_utils.py +123 -0
  50. datachain/job.py +11 -7
  51. datachain/json.py +138 -0
  52. datachain/lib/arrow.py +58 -22
  53. datachain/lib/audio.py +245 -0
  54. datachain/lib/clip.py +14 -13
  55. datachain/lib/convert/flatten.py +5 -3
  56. datachain/lib/convert/python_to_sql.py +6 -10
  57. datachain/lib/convert/sql_to_python.py +8 -0
  58. datachain/lib/convert/values_to_tuples.py +156 -51
  59. datachain/lib/data_model.py +42 -20
  60. datachain/lib/dataset_info.py +36 -8
  61. datachain/lib/dc/__init__.py +8 -2
  62. datachain/lib/dc/csv.py +25 -28
  63. datachain/lib/dc/database.py +398 -0
  64. datachain/lib/dc/datachain.py +1289 -425
  65. datachain/lib/dc/datasets.py +320 -38
  66. datachain/lib/dc/hf.py +38 -24
  67. datachain/lib/dc/json.py +29 -32
  68. datachain/lib/dc/listings.py +112 -8
  69. datachain/lib/dc/pandas.py +16 -12
  70. datachain/lib/dc/parquet.py +35 -23
  71. datachain/lib/dc/records.py +31 -23
  72. datachain/lib/dc/storage.py +154 -64
  73. datachain/lib/dc/storage_pattern.py +251 -0
  74. datachain/lib/dc/utils.py +24 -16
  75. datachain/lib/dc/values.py +8 -9
  76. datachain/lib/file.py +622 -89
  77. datachain/lib/hf.py +69 -39
  78. datachain/lib/image.py +14 -14
  79. datachain/lib/listing.py +14 -11
  80. datachain/lib/listing_info.py +1 -2
  81. datachain/lib/meta_formats.py +3 -4
  82. datachain/lib/model_store.py +39 -7
  83. datachain/lib/namespaces.py +125 -0
  84. datachain/lib/projects.py +130 -0
  85. datachain/lib/pytorch.py +32 -21
  86. datachain/lib/settings.py +192 -56
  87. datachain/lib/signal_schema.py +427 -104
  88. datachain/lib/tar.py +1 -2
  89. datachain/lib/text.py +8 -7
  90. datachain/lib/udf.py +164 -76
  91. datachain/lib/udf_signature.py +60 -35
  92. datachain/lib/utils.py +118 -4
  93. datachain/lib/video.py +17 -9
  94. datachain/lib/webdataset.py +61 -56
  95. datachain/lib/webdataset_laion.py +15 -16
  96. datachain/listing.py +22 -10
  97. datachain/model/bbox.py +3 -1
  98. datachain/model/ultralytics/bbox.py +16 -12
  99. datachain/model/ultralytics/pose.py +16 -12
  100. datachain/model/ultralytics/segment.py +16 -12
  101. datachain/namespace.py +84 -0
  102. datachain/node.py +6 -6
  103. datachain/nodes_thread_pool.py +0 -1
  104. datachain/plugins.py +24 -0
  105. datachain/project.py +78 -0
  106. datachain/query/batch.py +40 -41
  107. datachain/query/dataset.py +604 -322
  108. datachain/query/dispatch.py +261 -154
  109. datachain/query/metrics.py +4 -6
  110. datachain/query/params.py +2 -3
  111. datachain/query/queue.py +3 -12
  112. datachain/query/schema.py +11 -6
  113. datachain/query/session.py +200 -33
  114. datachain/query/udf.py +34 -2
  115. datachain/remote/studio.py +171 -69
  116. datachain/script_meta.py +12 -12
  117. datachain/semver.py +68 -0
  118. datachain/sql/__init__.py +2 -0
  119. datachain/sql/functions/array.py +33 -1
  120. datachain/sql/postgresql_dialect.py +9 -0
  121. datachain/sql/postgresql_types.py +21 -0
  122. datachain/sql/sqlite/__init__.py +5 -1
  123. datachain/sql/sqlite/base.py +102 -29
  124. datachain/sql/sqlite/types.py +8 -13
  125. datachain/sql/types.py +70 -15
  126. datachain/studio.py +223 -46
  127. datachain/toolkit/split.py +31 -10
  128. datachain/utils.py +101 -59
  129. {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/METADATA +77 -22
  130. datachain-0.39.0.dist-info/RECORD +173 -0
  131. {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/WHEEL +1 -1
  132. datachain/cli/commands/query.py +0 -53
  133. datachain/query/utils.py +0 -42
  134. datachain-0.14.2.dist-info/RECORD +0 -158
  135. {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/entry_points.txt +0 -0
  136. {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/licenses/LICENSE +0 -0
  137. {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/top_level.txt +0 -0
datachain/lib/dc/datasets.py CHANGED
@@ -1,18 +1,19 @@
- from typing import (
-     TYPE_CHECKING,
-     Optional,
- )
+ from collections.abc import Sequence
+ from typing import TYPE_CHECKING, get_origin, get_type_hints
 
- from datachain.lib.dataset_info import DatasetInfo
- from datachain.lib.file import (
-     File,
+ from datachain.error import (
+     DatasetNotFoundError,
+     DatasetVersionNotFoundError,
+     ProjectNotFoundError,
  )
+ from datachain.lib.dataset_info import DatasetInfo
+ from datachain.lib.projects import get as get_project
  from datachain.lib.settings import Settings
  from datachain.lib.signal_schema import SignalSchema
  from datachain.query import Session
  from datachain.query.dataset import DatasetQuery
 
- from .utils import Sys
+ from .utils import Sys, is_studio
  from .values import read_values
 
  if TYPE_CHECKING:
@@ -25,21 +26,64 @@ if TYPE_CHECKING:
 
  def read_dataset(
      name: str,
-     version: Optional[int] = None,
-     session: Optional[Session] = None,
-     settings: Optional[dict] = None,
-     fallback_to_studio: bool = True,
+     namespace: str | None = None,
+     project: str | None = None,
+     version: str | int | None = None,
+     session: Session | None = None,
+     settings: dict | None = None,
+     delta: bool | None = False,
+     delta_on: str | Sequence[str] | None = (
+         "file.path",
+         "file.etag",
+         "file.version",
+     ),
+     delta_result_on: str | Sequence[str] | None = None,
+     delta_compare: str | Sequence[str] | None = None,
+     delta_retry: bool | str | None = None,
+     delta_unsafe: bool = False,
+     update: bool = False,
  ) -> "DataChain":
      """Get data from a saved Dataset. It returns the chain itself.
      If dataset or version is not found locally, it will try to pull it from Studio.
 
      Parameters:
-         name : dataset name
-         version : dataset version
-         session : Session to use for the chain.
-         settings : Settings to use for the chain.
-         fallback_to_studio : Try to pull dataset from Studio if not found locally.
-             Default is True.
+         name: The dataset name, which can be a fully qualified name including the
+             namespace and project. Alternatively, it can be a regular name, in which
+             case the explicitly defined namespace and project will be used if they are
+             set; otherwise, default values will be applied.
+         namespace: optional name of namespace in which dataset to read is created
+         project: optional name of project in which dataset to read is created
+         version: dataset version. Supports:
+             - Exact version strings: "1.2.3"
+             - Legacy integer versions: 1, 2, 3 (finds latest major version)
+             - Version specifiers (PEP 440): ">=1.0.0,<2.0.0", "~=1.4.2", "==1.2.*", etc.
+         session: Session to use for the chain.
+         settings: Settings to use for the chain.
+         delta: If True, only process new or changed files instead of reprocessing
+             everything. This saves time by skipping files that were already processed in
+             previous versions. The optimization is working when a new version of the
+             dataset is created.
+             Default is False.
+         delta_on: Field(s) that uniquely identify each record in the source data.
+             Used to detect which records are new or changed.
+             Default is ("file.path", "file.etag", "file.version").
+         delta_result_on: Field(s) in the result dataset that match `delta_on` fields.
+             Only needed if you rename the identifying fields during processing.
+             Default is None.
+         delta_compare: Field(s) used to detect if a record has changed.
+             If not specified, all fields except `delta_on` fields are used.
+             Default is None.
+         delta_retry: Controls retry behavior for failed records:
+             - String (field name): Reprocess records where this field is not empty
+               (error mode)
+             - True: Reprocess records missing from the result dataset (missing mode)
+             - None: No retry processing (default)
+         update: If True always checks for newer versions available on Studio, even if
+             some version of the dataset exists locally already. If False (default), it
+             will only fetch the dataset from Studio if it is not found locally.
+         delta_unsafe: Allow restricted ops in delta: merge, agg, union, group_by,
+             distinct.
+ 
 
      Example:
          ```py
@@ -48,11 +92,27 @@ def read_dataset(
          ```
 
          ```py
-         chain = dc.read_dataset("my_cats", fallback_to_studio=False)
+         import datachain as dc
+         chain = dc.read_dataset("dev.animals.my_cats")
+         ```
+ 
+         ```py
+         chain = dc.read_dataset("my_cats", version="1.0.0")
+         ```
+ 
+         ```py
+         # Using version specifiers (PEP 440)
+         chain = dc.read_dataset("my_cats", version=">=1.0.0,<2.0.0")
+         ```
+ 
+         ```py
+         # Legacy integer version support (finds latest in major version)
+         chain = dc.read_dataset("my_cats", version=1)  # Latest 1.x.x version
          ```
 
          ```py
-         chain = dc.read_dataset("my_cats", version=1)
+         # Always check for newer versions matching a version specifier from Studio
+         chain = dc.read_dataset("my_cats", version=">=1.0.0", update=True)
          ```
 
          ```py
@@ -66,10 +126,9 @@ def read_dataset(
          }
          chain = dc.read_dataset(
              name="my_cats",
-             version=1,
+             version="1.0.0",
              session=session,
              settings=settings,
-             fallback_to_studio=True,
          )
          ```
      """
@@ -77,34 +136,96 @@ def read_dataset(
 
      from .datachain import DataChain
 
-     query = DatasetQuery(
-         name=name,
-         version=version,
-         session=session,
-         indexing_column_types=File._datachain_column_types,
-         fallback_to_studio=fallback_to_studio,
-     )
      telemetry.send_event_once("class", "datachain_init", name=name, version=version)
+ 
+     session = Session.get(session)
+     catalog = session.catalog
+ 
+     namespace_name, project_name, name = catalog.get_full_dataset_name(
+         name,
+         project_name=project,
+         namespace_name=namespace,
+     )
+ 
+     if version is not None:
+         dataset = session.catalog.get_dataset_with_remote_fallback(
+             name, namespace_name, project_name, update=update
+         )
+ 
+         # Convert legacy integer versions to version specifiers
+         # For backward compatibility we still allow users to put version as integer
+         # in which case we convert it to a version specifier that finds the latest
+         # version where major part is equal to that input version.
+         # For example if user sets version=2, we convert it to ">=2.0.0,<3.0.0"
+         # which will find something like 2.4.3 (assuming 2.4.3 is the biggest among
+         # all 2.* dataset versions)
+         if isinstance(version, int):
+             version_spec = f">={version}.0.0,<{version + 1}.0.0"
+         else:
+             version_spec = str(version)
+ 
+         from packaging.specifiers import InvalidSpecifier, SpecifierSet
+ 
+         try:
+             # Try to parse as version specifier
+             SpecifierSet(version_spec)
+             # If it's a valid specifier set, find the latest compatible version
+             latest_compatible = dataset.latest_compatible_version(version_spec)
+             if not latest_compatible:
+                 raise DatasetVersionNotFoundError(
+                     f"No dataset {name} version matching specifier {version_spec}"
+                 )
+             version = latest_compatible
+         except InvalidSpecifier:
+             # If not a valid specifier, treat as exact version string
+             # This handles cases like "1.2.3" which are exact versions, not specifiers
+             pass
+ 
      if settings:
          _settings = Settings(**settings)
      else:
          _settings = Settings()
 
+     query = DatasetQuery(
+         name=name,
+         project_name=project_name,
+         namespace_name=namespace_name,
+         version=version,  # type: ignore[arg-type]
+         session=session,
+         update=update,
+     )
+ 
      signals_schema = SignalSchema({"sys": Sys})
      if query.feature_schema:
          signals_schema |= SignalSchema.deserialize(query.feature_schema)
      else:
          signals_schema |= SignalSchema.from_column_types(query.column_types or {})
-     return DataChain(query, _settings, signals_schema)
+ 
+     if delta:
+         signals_schema = signals_schema.clone_without_sys_signals()
+ 
+     chain = DataChain(query, _settings, signals_schema)
+ 
+     if delta:
+         chain = chain._as_delta(
+             on=delta_on,
+             right_on=delta_result_on,
+             compare=delta_compare,
+             delta_retry=delta_retry,
+             delta_unsafe=delta_unsafe,
+         )
+ 
+     return chain
 
 
  def datasets(
-     session: Optional[Session] = None,
-     settings: Optional[dict] = None,
+     session: Session | None = None,
+     settings: dict | None = None,
      in_memory: bool = False,
-     object_name: str = "dataset",
+     column: str | None = None,
      include_listing: bool = False,
      studio: bool = False,
+     attrs: list[str] | None = None,
  ) -> "DataChain":
      """Generate chain with list of registered datasets.
 
@@ -112,10 +233,15 @@ def datasets(
          session: Optional session instance. If not provided, uses default session.
          settings: Optional dictionary of settings to configure the chain.
          in_memory: If True, creates an in-memory session. Defaults to False.
-         object_name: Name of the output object in the chain. Defaults to "dataset".
+         column: Name of the output column in the chain. Defaults to None which
+             means no top level column will be created.
          include_listing: If True, includes listing datasets. Defaults to False.
          studio: If True, returns datasets from Studio only,
              otherwise returns all local datasets. Defaults to False.
+         attrs: Optional list of attributes to filter datasets on. It can be just
+             attribute without value e.g "NLP", or attribute with value
+             e.g "location=US". Attribute with value can also accept "*" to target
+             all that have specific name e.g "location=*"
 
      Returns:
          DataChain: A new DataChain instance containing dataset information.
@@ -124,8 +250,8 @@
          ```py
          import datachain as dc
 
-         chain = dc.datasets()
-         for ds in chain.collect("dataset"):
+         chain = dc.datasets(column="dataset")
+         for ds in chain.to_iter("dataset"):
              print(f"{ds.name}@v{ds.version}")
          ```
      """
@@ -139,11 +265,167 @@ def datasets(
              include_listing=include_listing, studio=studio
          )
      ]
+     datasets_values = [d for d in datasets_values if not d.is_temp]
+ 
+     if attrs:
+         for attr in attrs:
+             datasets_values = [d for d in datasets_values if d.has_attr(attr)]
+ 
+     if not column:
+         # flattening dataset fields
+         schema = {
+             k: get_origin(v) if get_origin(v) is dict else v
+             for k, v in get_type_hints(DatasetInfo).items()
+             if k in DatasetInfo.model_fields
+         }
+         data = {k: [] for k in DatasetInfo.model_fields}  # type: ignore[var-annotated]
+         for d in [d.model_dump() for d in datasets_values]:
+             for field, value in d.items():
+                 data[field].append(value)
+ 
+         return read_values(
+             session=session,
+             settings=settings,
+             in_memory=in_memory,
+             output=schema,
+             **data,  # type: ignore[arg-type]
+         )
 
      return read_values(
          session=session,
          settings=settings,
          in_memory=in_memory,
-         output={object_name: DatasetInfo},
-         **{object_name: datasets_values},  # type: ignore[arg-type]
+         output={column: DatasetInfo},
+         **{column: datasets_values},  # type: ignore[arg-type]
+     )
+ 
+ 
+ def delete_dataset(
+     name: str,
+     namespace: str | None = None,
+     project: str | None = None,
+     version: str | None = None,
+     force: bool | None = False,
+     studio: bool | None = False,
+     session: Session | None = None,
+     in_memory: bool = False,
+ ) -> None:
+     """Removes specific dataset version or all dataset versions, depending on
+     a force flag.
+ 
+     Args:
+         name: The dataset name, which can be a fully qualified name including the
+             namespace and project. Alternatively, it can be a regular name, in which
+             case the explicitly defined namespace and project will be used if they are
+             set; otherwise, default values will be applied.
+         namespace: optional name of namespace in which dataset to delete is created
+         project: optional name of project in which dataset to delete is created
+         version: Optional dataset version
+         force: If true, all datasets versions will be removed. Defaults to False.
+         studio: If True, removes dataset from Studio only, otherwise removes local
+             dataset. Defaults to False.
+         session: Optional session instance. If not provided, uses default session.
+         in_memory: If True, creates an in-memory session. Defaults to False.
+ 
+     Returns: None
+ 
+     Example:
+         ```py
+         import datachain as dc
+         dc.delete_dataset("cats")
+         ```
+ 
+         ```py
+         import datachain as dc
+         dc.delete_dataset("cats", version="1.0.0")
+         ```
+     """
+     from datachain.studio import remove_studio_dataset
+ 
+     session = Session.get(session, in_memory=in_memory)
+     catalog = session.catalog
+ 
+     namespace_name, project_name, name = catalog.get_full_dataset_name(
+         name,
+         project_name=project,
+         namespace_name=namespace,
+     )
+ 
+     if not is_studio() and studio:
+         return remove_studio_dataset(
+             None, name, namespace_name, project_name, version=version, force=force
+         )
+ 
+     try:
+         ds_project = get_project(project_name, namespace_name, session=session)
+     except ProjectNotFoundError:
+         raise DatasetNotFoundError(
+             f"Dataset {name} not found in namespace {namespace_name} and project",
+             f" {project_name}",
+         ) from None
+ 
+     if not force:
+         version = (
+             version
+             or catalog.get_dataset(
+                 name,
+                 namespace_name=ds_project.namespace.name,
+                 project_name=ds_project.name,
+             ).latest_version
+         )
+     else:
+         version = None
+     catalog.remove_dataset(name, ds_project, version=version, force=force)
+ 
+ 
+ def move_dataset(
+     src: str,
+     dest: str,
+     session: Session | None = None,
+     in_memory: bool = False,
+ ) -> None:
+     """Moves an entire dataset between namespaces and projects.
+ 
+     Args:
+         src: The source dataset name. This can be a fully qualified name that includes
+             the namespace and project, or a regular name. If a regular name is used,
+             default values will be applied. The source dataset will no longer exist
+             after the move.
+         dest: The destination dataset name. This can also be a fully qualified
+             name with a namespace and project, or just a regular name (default values
+             will be used in that case). The original dataset will be moved here.
+         session: An optional session instance. If not provided, the default session
+             will be used.
+         in_memory: If True, creates an in-memory session. Defaults to False.
+ 
+     Returns:
+         None
+ 
+     Examples:
+         ```python
+         import datachain as dc
+         dc.move_dataset("cats", "new_cats")
+         ```
+ 
+         ```python
+         import datachain as dc
+         dc.move_dataset("dev.animals.cats", "prod.animals.cats")
+         ```
+     """
+     session = Session.get(session, in_memory=in_memory)
+     catalog = session.catalog
+ 
+     namespace, project, name = catalog.get_full_dataset_name(src)
+     dest_namespace, dest_project, dest_name = catalog.get_full_dataset_name(dest)
+ 
+     dataset = catalog.get_dataset(name, namespace_name=namespace, project_name=project)
+ 
+     catalog.update_dataset(
+         dataset,
+         name=dest_name,
+         project_id=catalog.metastore.get_project(
+             dest_project,
+             dest_namespace,
+             create=is_studio(),
+         ).id,
      )
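
The new dataset APIs above (`read_dataset` with PEP 440 version specifiers and delta processing, `datasets` with attribute filters, plus `delete_dataset` and `move_dataset`) replace the removed `fallback_to_studio` and `object_name` arguments. A minimal usage sketch based on the docstrings in this diff; the dataset, namespace, and attribute names are illustrative:

```py
import datachain as dc

# Latest 1.x.x version, addressed by a fully qualified name
chain = dc.read_dataset("dev.animals.my_cats", version=">=1.0.0,<2.0.0")

# Delta processing: only new or changed files are picked up when a new
# dataset version is created
chain = dc.read_dataset("my_cats", delta=True, delta_on=("file.path", "file.etag"))

# Registered datasets filtered by attribute, wrapped in a "dataset" column
for ds in dc.datasets(column="dataset", attrs=["NLP"]).to_iter("dataset"):
    print(f"{ds.name}@v{ds.version}")

# Move a dataset across namespaces/projects, then drop a single version
dc.move_dataset("dev.animals.my_cats", "prod.animals.my_cats")
dc.delete_dataset("prod.animals.my_cats", version="1.0.0")
```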
datachain/lib/dc/hf.py CHANGED
@@ -1,8 +1,4 @@
- from typing import (
-     TYPE_CHECKING,
-     Optional,
-     Union,
- )
+ from typing import TYPE_CHECKING, Any
 
  from datachain.lib.data_model import dict_to_data_model
  from datachain.query import Session
@@ -19,24 +15,29 @@ if TYPE_CHECKING:
 
 
  def read_hf(
-     dataset: Union[str, "HFDatasetType"],
-     *args,
-     session: Optional[Session] = None,
-     settings: Optional[dict] = None,
-     object_name: str = "",
+     dataset: "HFDatasetType",
+     *args: Any,
+     session: Session | None = None,
+     settings: dict | None = None,
+     column: str = "",
      model_name: str = "",
-     **kwargs,
+     limit: int = 0,
+     **kwargs: Any,
  ) -> "DataChain":
-     """Generate chain from huggingface hub dataset.
+     """Generate chain from Hugging Face Hub dataset.
 
      Parameters:
-         dataset : Path or name of the dataset to read from Hugging Face Hub,
+         dataset: Path or name of the dataset to read from Hugging Face Hub,
              or an instance of `datasets.Dataset`-like object.
-         session : Session to use for the chain.
-         settings : Settings to use for the chain.
-         object_name : Generated object column name.
-         model_name : Generated model name.
-         kwargs : Parameters to pass to datasets.load_dataset.
+         args: Additional positional arguments to pass to `datasets.load_dataset`.
+         session: Session to use for the chain.
+         settings: Settings to use for the chain.
+         column: Generated object column name.
+         model_name: Generated model name.
+         limit: The maximum number of items to read from the HF dataset.
+             Applies `take(limit)` to `datasets.load_dataset`.
+             Defaults to 0 (no limit).
+         kwargs: Parameters to pass to `datasets.load_dataset`.
 
      Example:
          Load from Hugging Face Hub:
@@ -52,6 +53,18 @@ def read_hf(
          import datachain as dc
          chain = dc.read_hf(ds)
          ```
+ 
+         Streaming with limit, for large datasets:
+         ```py
+         import datachain as dc
+         ds = dc.read_hf("beans", split="train", streaming=True, limit=10)
+         ```
+ 
+         or use HF split syntax (not supported if streaming is enabled):
+         ```py
+         import datachain as dc
+         ds = dc.read_hf("beans", split="train[%10]")
+         ```
      """
      from datachain.lib.hf import HFGenerator, get_output_schema, stream_splits
 
@@ -62,12 +75,13 @@ def read_hf(
      if len(ds_dict) > 1:
          output = {"split": str}
 
-     model_name = model_name or object_name or ""
+     model_name = model_name or column or ""
      hf_features = next(iter(ds_dict.values())).features
-     output = output | get_output_schema(hf_features)
-     model = dict_to_data_model(model_name, output)
-     if object_name:
-         output = {object_name: model}
+     hf_output, normalized_names = get_output_schema(hf_features, list(output.keys()))
+     output = output | hf_output
+     model = dict_to_data_model(model_name, output, list(normalized_names.values()))
+     if column:
+         output = {column: model}
 
      chain = read_values(split=list(ds_dict.keys()), session=session, settings=settings)
-     return chain.gen(HFGenerator(dataset, model, *args, **kwargs), output=output)
+     return chain.gen(HFGenerator(dataset, model, limit, *args, **kwargs), output=output)
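
In `read_hf`, `object_name` becomes `column` and a `limit` parameter is added for bounded (including streaming) reads. A small sketch based on the updated docstring; the dataset name, limit, and column name are illustrative:

```py
import datachain as dc

# Stream the "beans" dataset from the Hugging Face Hub, stop after 100 rows,
# and nest the generated signals under a "row" column
chain = dc.read_hf("beans", split="train", streaming=True, limit=100, column="row")
```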
datachain/lib/dc/json.py CHANGED
@@ -1,18 +1,12 @@
  import os
- import os.path
  import re
- from typing import (
-     TYPE_CHECKING,
-     Optional,
-     Union,
- )
+ from typing import TYPE_CHECKING
 
+ import cloudpickle
+ 
+ from datachain.lib import meta_formats
  from datachain.lib.data_model import DataType
- from datachain.lib.file import (
-     File,
-     FileType,
- )
- from datachain.lib.meta_formats import read_meta
+ from datachain.lib.file import File, FileType
 
  if TYPE_CHECKING:
      from typing_extensions import ParamSpec
@@ -23,30 +17,30 @@ if TYPE_CHECKING:
 
 
  def read_json(
-     path: Union[str, os.PathLike[str]],
+     path: str | os.PathLike[str],
      type: FileType = "text",
-     spec: Optional[DataType] = None,
-     schema_from: Optional[str] = "auto",
-     jmespath: Optional[str] = None,
-     object_name: Optional[str] = "",
-     model_name: Optional[str] = None,
-     format: Optional[str] = "json",
-     nrows=None,
+     spec: DataType | None = None,
+     schema_from: str | None = "auto",
+     jmespath: str | None = None,
+     column: str | None = "",
+     model_name: str | None = None,
+     format: str | None = "json",
+     nrows: int | None = None,
      **kwargs,
  ) -> "DataChain":
      """Get data from JSON. It returns the chain itself.
 
      Parameters:
-         path : storage URI with directory. URI must start with storage prefix such
+         path: storage URI with directory. URI must start with storage prefix such
              as `s3://`, `gs://`, `az://` or "file:///"
-         type : read file as "binary", "text", or "image" data. Default is "text".
-         spec : optional Data Model
-         schema_from : path to sample to infer spec (if schema not provided)
-         object_name : generated object column name
-         model_name : optional generated model name
+         type: read file as "binary", "text", or "image" data. Default is "text".
+         spec: optional Data Model
+         schema_from: path to sample to infer spec (if schema not provided)
+         column: generated column name
+         model_name: optional generated model name
          format: "json", "jsonl"
-         jmespath : optional JMESPATH expression to reduce JSON
-         nrows : optional row limit for jsonl and JSON arrays
+         jmespath: optional JMESPATH expression to reduce JSON
+         nrows: optional row limit for jsonl and JSON arrays
 
      Example:
          infer JSON schema from data, reduce using JMESPATH
@@ -70,13 +64,13 @@ def read_json(
          name_end = re.search(r"\W", s).start() if re.search(r"\W", s) else len(s)  # type: ignore[union-attr]
          return s[:name_end]
 
-     if (not object_name) and jmespath:
-         object_name = jmespath_to_name(jmespath)
-     if not object_name:
-         object_name = format
+     if (not column) and jmespath:
+         column = jmespath_to_name(jmespath)
+     if not column:
+         column = format
      chain = read_storage(uri=path, type=type, **kwargs)
      signal_dict = {
-         object_name: read_meta(
+         column: meta_formats.read_meta(
              schema_from=schema_from,
              format=format,
              spec=spec,
@@ -88,4 +82,7 @@ def read_json(
      }
      # disable prefetch if nrows is set
      settings = {"prefetch": 0} if nrows else {}
+ 
+     cloudpickle.register_pickle_by_value(meta_formats)
+ 
      return chain.settings(**settings).gen(**signal_dict)  # type: ignore[misc, arg-type]
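
`read_json` follows the same `object_name` to `column` rename and now registers the `meta_formats` module for pickling by value. A minimal sketch, with an illustrative storage path; other arguments keep their documented defaults:

```py
import datachain as dc

# Infer the schema from the data and expose parsed records under a "meta" column;
# nrows caps how many JSON lines are read
chain = dc.read_json("gs://datachain-demo/jsonl/", format="jsonl", column="meta", nrows=100)
```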