datachain 0.1.13__py3-none-any.whl → 0.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of datachain might be problematic.

Files changed (49)
  1. datachain/__init__.py +0 -4
  2. datachain/asyn.py +3 -3
  3. datachain/catalog/__init__.py +3 -3
  4. datachain/catalog/catalog.py +6 -6
  5. datachain/catalog/loader.py +3 -3
  6. datachain/cli.py +10 -2
  7. datachain/client/azure.py +37 -1
  8. datachain/client/fsspec.py +1 -1
  9. datachain/client/local.py +1 -1
  10. datachain/data_storage/__init__.py +1 -1
  11. datachain/data_storage/metastore.py +11 -3
  12. datachain/data_storage/schema.py +12 -7
  13. datachain/data_storage/sqlite.py +3 -0
  14. datachain/data_storage/warehouse.py +31 -30
  15. datachain/dataset.py +1 -3
  16. datachain/lib/arrow.py +85 -0
  17. datachain/lib/cached_stream.py +3 -85
  18. datachain/lib/dc.py +382 -179
  19. datachain/lib/feature.py +46 -91
  20. datachain/lib/feature_registry.py +4 -1
  21. datachain/lib/feature_utils.py +2 -2
  22. datachain/lib/file.py +30 -44
  23. datachain/lib/image.py +9 -2
  24. datachain/lib/meta_formats.py +66 -34
  25. datachain/lib/settings.py +5 -5
  26. datachain/lib/signal_schema.py +103 -105
  27. datachain/lib/udf.py +10 -38
  28. datachain/lib/udf_signature.py +11 -6
  29. datachain/lib/webdataset_laion.py +5 -22
  30. datachain/listing.py +8 -8
  31. datachain/node.py +1 -1
  32. datachain/progress.py +1 -1
  33. datachain/query/builtins.py +1 -1
  34. datachain/query/dataset.py +42 -119
  35. datachain/query/dispatch.py +1 -1
  36. datachain/query/metrics.py +19 -0
  37. datachain/query/schema.py +13 -3
  38. datachain/sql/__init__.py +1 -1
  39. datachain/sql/sqlite/base.py +34 -2
  40. datachain/sql/sqlite/vector.py +13 -5
  41. datachain/utils.py +1 -122
  42. {datachain-0.1.13.dist-info → datachain-0.2.1.dist-info}/METADATA +11 -4
  43. {datachain-0.1.13.dist-info → datachain-0.2.1.dist-info}/RECORD +47 -47
  44. {datachain-0.1.13.dist-info → datachain-0.2.1.dist-info}/WHEEL +1 -1
  45. datachain/_version.py +0 -16
  46. datachain/lib/parquet.py +0 -32
  47. {datachain-0.1.13.dist-info → datachain-0.2.1.dist-info}/LICENSE +0 -0
  48. {datachain-0.1.13.dist-info → datachain-0.2.1.dist-info}/entry_points.txt +0 -0
  49. {datachain-0.1.13.dist-info → datachain-0.2.1.dist-info}/top_level.txt +0 -0
datachain/lib/feature.py CHANGED
@@ -4,8 +4,10 @@ import re
 import warnings
 from collections.abc import Iterable, Sequence
 from datetime import datetime
+from functools import lru_cache
 from types import GenericAlias
 from typing import (
+    TYPE_CHECKING,
     Any,
     ClassVar,
     Literal,
@@ -22,7 +24,7 @@ from typing_extensions import Literal as LiteralEx
 
 from datachain.lib.feature_registry import Registry
 from datachain.query import C
-from datachain.query.udf import UDFOutputSpec
+from datachain.query.schema import DEFAULT_DELIMITER
 from datachain.sql.types import (
     JSON,
     Array,
@@ -38,6 +40,9 @@ from datachain.sql.types import (
     String,
 )
 
+if TYPE_CHECKING:
+    from datachain.catalog import Catalog
+
 FeatureStandardType = Union[
     type[int],
     type[str],
@@ -62,6 +67,7 @@ TYPE_TO_DATACHAIN = {
     bool: Boolean,
     datetime: DateTime,  # Note, list of datetime is not supported yet
    bytes: Binary,  # Note, list of bytes is not supported yet
+    list: Array,
     dict: JSON,
 }
 
@@ -108,8 +114,6 @@ warnings.filterwarnings(
 # skipped within loops.
 feature_classes_lookup: dict[type, bool] = {}
 
-DEFAULT_DELIMITER = "__"
-
 
 class Feature(BaseModel):
     """A base class for defining data classes that serve as inputs and outputs for
@@ -117,9 +121,6 @@ class Feature(BaseModel):
     `pydantic`'s BaseModel.
     """
 
-    _is_shallow: ClassVar[bool] = False
-    _expand_class_name: ClassVar[bool] = False
-    _delimiter: ClassVar[str] = DEFAULT_DELIMITER
     _is_file: ClassVar[bool] = False
     _version: ClassVar[int] = 1
 
@@ -135,20 +136,6 @@ class Feature(BaseModel):
     def _name(cls) -> str:
         return f"{cls.__name__}@{cls._version}"
 
-    def _get_value_with_check(self, *args: Any, **kwargs: Any) -> Any:
-        signature = inspect.signature(self.get_value)
-        for i, (name, prm) in enumerate(signature.parameters.items()):
-            if prm.default == inspect.Parameter.empty:
-                if i < len(args):
-                    continue
-                if name not in kwargs:
-                    raise ValueError(
-                        f"unable to get value for class {self.__class__.__name__}"
-                        f" due to a missing parameter {name} in get_value()"
-                    )
-
-        return self.get_value(*args, **kwargs)
-
     @classmethod
     def __pydantic_init_subclass__(cls):
         Registry.add(cls)
@@ -162,9 +149,10 @@ class Feature(BaseModel):
 
     @classmethod
     def _normalize(cls, name: str) -> str:
-        if cls._delimiter and cls._delimiter.lower() in name.lower():
+        if DEFAULT_DELIMITER in name:
             raise RuntimeError(
-                f"variable '{name}' cannot be used because it contains {cls._delimiter}"
+                f"variable '{name}' cannot be used "
+                f"because it contains {DEFAULT_DELIMITER}"
             )
         return Feature._to_snake_case(name)
 
@@ -174,7 +162,7 @@ class Feature(BaseModel):
         s1 = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", name)
         return re.sub("([a-z0-9])([A-Z])", r"\1_\2", s1).lower()
 
-    def _set_stream(self, catalog, stream=None, caching_enabled: bool = False) -> None:
+    def _set_stream(self, catalog: "Catalog", caching_enabled: bool = False) -> None:
        pass
 
     @classmethod
@@ -187,35 +175,6 @@ class Feature(BaseModel):
             if Feature.is_feature(anno):
                 yield from anno.get_file_signals([*path, name])  # type: ignore[union-attr]
 
-    @classmethod
-    def _flatten_full_schema(cls, fields, name_path):
-        for name, f_info in fields.items():
-            anno = f_info.annotation
-            name = cls._normalize(name)
-
-            orig = get_origin(anno)
-            if orig == list:
-                anno = get_args(anno)
-                if isinstance(anno, tuple):
-                    anno = anno[0]
-                is_list = True
-            else:
-                is_list = False
-
-            if Feature.is_feature(anno):
-                lst = copy.copy(name_path)
-                lst = [] if anno._is_shallow else [*lst, name]
-
-                if is_list:
-                    yield anno._delimiter.join(lst), Array(JSON)
-                else:
-                    yield from cls._flatten_full_schema(anno.model_fields, lst)
-            else:
-                typ = convert_type_to_datachain(anno)
-                if is_list:
-                    typ = Array(typ)
-                yield cls._delimiter.join([*name_path, name]), typ
-
     @classmethod
     def is_feature(cls, anno) -> bool:
         if anno in feature_classes_lookup:
@@ -242,22 +201,10 @@ class Feature(BaseModel):
     def is_feature_type(cls, t: type) -> bool:
         if cls.is_standard_type(t):
             return True
-        if get_origin(t) == list and len(get_args(t)) == 1:
+        if get_origin(t) is list and len(get_args(t)) == 1:
             return cls.is_feature_type(get_args(t)[0])
         return cls.is_feature(t)
 
-    @classmethod
-    def _to_udf_spec(cls):
-        return list(cls._flatten_full_schema(cls.model_fields, []))
-
-    @staticmethod
-    def _features_to_udf_spec(fr_classes: Sequence[type["Feature"]]) -> UDFOutputSpec:
-        return dict(
-            item
-            for b in fr_classes
-            for item in b._to_udf_spec()  # type: ignore[attr-defined]
-        )
-
     def _flatten_fields_values(self, fields, model):
         for name, f_info in fields.items():
             anno = f_info.annotation
@@ -280,16 +227,15 @@ class Feature(BaseModel):
             yield value
 
     def _flatten(self):
-        return tuple(self._flatten_generator())
-
-    def _flatten_generator(self):
-        # Optimization: Use a generator instead of a tuple if all values are going to
-        # be used immediately in another comprehension or function call.
-        return self._flatten_fields_values(self.model_fields, self)
+        return tuple(self._flatten_fields_values(self.model_fields, self))
 
     @staticmethod
     def _flatten_list(objs):
-        return tuple(val for obj in objs for val in obj._flatten_generator())
+        return tuple(
+            val
+            for obj in objs
+            for val in obj._flatten_fields_values(obj.model_fields, obj)
+        )
 
     @classmethod
     def _unflatten_with_path(cls, dump, name_path: list[str]):
@@ -300,14 +246,12 @@ class Feature(BaseModel):
             lst = copy.copy(name_path)
 
             if inspect.isclass(anno) and issubclass(anno, Feature):
-                if not cls._is_shallow:
-                    lst.append(name_norm)
-
+                lst.append(name_norm)
                 val = anno._unflatten_with_path(dump, lst)
                 res[name] = val
             else:
                 lst.append(name_norm)
-                curr_path = cls._delimiter.join(lst)
+                curr_path = DEFAULT_DELIMITER.join(lst)
                 res[name] = dump[curr_path]
         return cls(**res)
 
@@ -336,6 +280,18 @@ class Feature(BaseModel):
                 pos += 1
         return res, pos
 
+    @classmethod
+    @lru_cache(maxsize=1000)
+    def build_tree(cls):
+        res = {}
+
+        for name, f_info in cls.model_fields.items():
+            anno = f_info.annotation
+            subtree = anno.build_tree() if Feature.is_feature(anno) else None
+            res[name] = (anno, subtree)
+
+        return res
+
 
 class RestrictedAttribute:
     """Descriptor implementing an attribute that can only be accessed through
@@ -374,7 +330,7 @@ class FeatureAttributeWrapper:
 
     @property
     def name(self) -> str:
-        return self.cls._delimiter.join(self.prefix)
+        return DEFAULT_DELIMITER.join(self.prefix)
 
     def __getattr__(self, name):
         field_info = self.cls.model_fields.get(name)
@@ -401,22 +357,16 @@ def _resolve(cls, name, field_info, prefix: list[str]):
         except TypeError:
             anno_sql_class = NullType
         new_prefix = copy.copy(prefix)
-        if not cls._is_shallow:
-            new_prefix.append(norm_name)
-        return C(cls._delimiter.join(new_prefix), anno_sql_class)
+        new_prefix.append(norm_name)
+        return C(DEFAULT_DELIMITER.join(new_prefix), anno_sql_class)
 
-    if not cls._is_shallow:
-        return FeatureAttributeWrapper(anno, [*prefix, norm_name])
-
-    new_prefix_value = copy.copy(prefix)
-    if not cls._is_shallow:
-        new_prefix_value.append(norm_name)
-    return FeatureAttributeWrapper(anno, new_prefix_value)
+    return FeatureAttributeWrapper(anno, [*prefix, norm_name])
 
 
 def convert_type_to_datachain(typ):  # noqa: PLR0911
     if inspect.isclass(typ) and issubclass(typ, SQLType):
         return typ
+
     res = TYPE_TO_DATACHAIN.get(typ)
     if res:
         return res
@@ -430,7 +380,12 @@ def convert_type_to_datachain(typ):  # noqa: PLR0911
     if inspect.isclass(orig) and (issubclass(list, orig) or issubclass(tuple, orig)):
         if args is None or len(args) != 1:
             raise TypeError(f"Cannot resolve type '{typ}' for flattening features")
-        next_type = convert_type_to_datachain(args[0])
+
+        args0 = args[0]
+        if Feature.is_feature(args0):
+            return Array(JSON())
+
+        next_type = convert_type_to_datachain(args0)
         return Array(next_type)
 
     if inspect.isclass(orig) and issubclass(dict, orig):
@@ -443,10 +398,10 @@ def convert_type_to_datachain(typ):  # noqa: PLR0911
     if orig == Union and len(args) >= 2:
         args_no_nones = [arg for arg in args if arg != type(None)]
         if len(args_no_nones) == 2:
-            args_no_dicts = [arg for arg in args_no_nones if arg != dict]
-            if len(args_no_dicts) == 1 and get_origin(args_no_dicts[0]) == list:
+            args_no_dicts = [arg for arg in args_no_nones if arg is not dict]
+            if len(args_no_dicts) == 1 and get_origin(args_no_dicts[0]) is list:
                 arg = get_args(args_no_dicts[0])
-                if len(arg) == 1 and arg[0] == dict:
+                if len(arg) == 1 and arg[0] is dict:
                     return JSON
 
     raise TypeError(f"Cannot recognize type {typ}")
datachain/lib/feature_registry.py CHANGED
@@ -1,5 +1,8 @@
+import logging
 from typing import Any, ClassVar, Optional
 
+logger = logging.getLogger(__name__)
+
 
 class Registry:
     reg: ClassVar[dict[str, dict[int, Any]]] = {}
@@ -14,7 +17,7 @@ class Registry:
         version = fr._version  # type: ignore[attr-defined]
         if version in cls.reg[name]:
             full_name = f"{name}@{version}"
-            raise ValueError(f"Feature {full_name} is already registered")
+            logger.warning("Feature %s is already registered", full_name)
         cls.reg[name][version] = fr
 
     @classmethod
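The registry change only affects duplicate registrations: re-adding a Feature with the same name and `_version` now logs a warning instead of raising `ValueError`. A rough sketch of the new behavior, assuming datachain 0.2.1; the `Embedding` class is illustrative:

```python
import logging

from datachain.lib.feature import Feature
from datachain.lib.feature_registry import Registry

logging.basicConfig(level=logging.WARNING)


class Embedding(Feature):
    value: float


# Embedding was already registered by Feature.__pydantic_init_subclass__,
# so this second add() is a duplicate and now only emits a warning.
Registry.add(Embedding)
```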
datachain/lib/feature_utils.py CHANGED
@@ -40,7 +40,7 @@ def pydantic_to_feature(data_cls: type[BaseModel]) -> type[Feature]:
         anno = field_info.annotation
         if anno not in TYPE_TO_DATACHAIN:
             orig = get_origin(anno)
-            if orig == list:
+            if orig is list:
                 anno = get_args(anno)  # type: ignore[assignment]
                 if isinstance(anno, Sequence):
                     anno = anno[0]  # type: ignore[unreachable]
@@ -122,7 +122,7 @@ def features_to_tuples(
     if isinstance(output, dict):
         raise FeatureToTupleError(
             ds_name,
-            f"output type must be dict[str, FeatureType] while "
+            "output type must be dict[str, FeatureType] while "
             f"'{type(output).__name__}' is given",
         )
     else:
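The `== list` → `is list` changes here (and the matching ones in feature.py) are safe because `typing.get_origin()` returns the bare container class, so an identity check is sufficient. For example:

```python
from typing import Optional, get_origin

assert get_origin(list[int]) is list
assert get_origin(dict[str, int]) is dict
assert get_origin(Optional[list[int]]) is not list  # Union origin, handled separately
```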
datachain/lib/file.py CHANGED
@@ -1,30 +1,24 @@
 import json
 from abc import ABC, abstractmethod
 from datetime import datetime
-from io import BytesIO
 from pathlib import Path
-from typing import Any, ClassVar, Literal, Optional, Union
+from typing import TYPE_CHECKING, Any, ClassVar, Literal, Optional, Union
+from urllib.parse import unquote, urlparse
+from urllib.request import url2pathname
 
-from fsspec import Callback
+from fsspec.implementations.local import LocalFileSystem
 from pydantic import Field, field_validator
 
-from datachain.lib.feature import Feature
-from datachain.utils import TIME_ZERO
-
-try:
-    from PIL import Image
-except ImportError as exc:
-    raise ImportError(
-        "Missing dependencies for computer vision:\n"
-        "To install run:\n\n"
-        "  pip install 'datachain[cv]'\n"
-    ) from exc
-
 from datachain.cache import UniqueId
 from datachain.client.fileslice import FileSlice
 from datachain.lib.cached_stream import PreCachedStream, PreDownloadStream
+from datachain.lib.feature import Feature
 from datachain.lib.utils import DataChainError
 from datachain.sql.types import JSON, Int, String
+from datachain.utils import TIME_ZERO
+
+if TYPE_CHECKING:
+    from datachain.catalog import Catalog
 
 
 class FileFeature(Feature):
@@ -49,7 +43,7 @@ class VFileError(DataChainError):
 
 class FileError(DataChainError):
     def __init__(self, file: "File", message: str):
-        super().__init__(f"Error in file {file.get_full_path()}: {message}")
+        super().__init__(f"Error in file {file.get_uri()}: {message}")
 
 
 class VFile(ABC):
@@ -190,26 +184,17 @@ class File(FileFeature):
 
     def open(self):
         if self._stream is None:
-            if self._catalog is None:
-                raise FileError(self, "stream is not set")
-            self._stream = self._open_stream()
+            raise FileError(self, "stream is not set")
 
         if self.location:
             return VFileRegistry.resolve(self, self.location)
 
         return self._stream
 
-    def _set_stream(
-        self, catalog=None, stream=None, caching_enabled: bool = False
-    ) -> None:
-        if self._catalog is None and catalog is None:
-            raise DataChainError(f"Cannot set file '{stream}' without catalog")
-
-        if catalog:
-            self._catalog = catalog
-
+    def _set_stream(self, catalog: "Catalog", caching_enabled: bool = False) -> None:
+        self._catalog = catalog
         stream_class = PreCachedStream if caching_enabled else PreDownloadStream
-        self._stream = stream_class(stream, self.size, self._catalog, self.get_uid())
+        self._stream = stream_class(self._catalog, self.get_uid())
         self._caching_enabled = caching_enabled
 
     def get_uid(self) -> UniqueId:
@@ -237,22 +222,23 @@ class File(FileFeature):
     def get_full_name(self):
         return (Path(self.parent) / self.name).as_posix()
 
-    def get_full_path(self):
+    def get_uri(self):
        return f"{self.source}/{self.get_full_name()}"
 
-    def _open_stream(self, cache: bool = False, cb: Optional[Callback] = None):
-        client = self._catalog.get_client(self.source)
-        uid = self.get_uid()
-        return client.open_object(uid, use_cache=cache, cb=cb)
+    def get_path(self) -> str:
+        path = unquote(self.get_uri())
+        fs = self.get_fs()
+        if isinstance(fs, LocalFileSystem):
+            # Drop file:// protocol
+            path = urlparse(path).path
+            path = url2pathname(path)
+        return path
 
-
-BinaryFile = File
+    def get_fs(self):
+        return self._catalog.get_client(self.source).fs
 
 
-class ImageFile(File):
-    def get_value(self):
-        value = super().get_value()
-        return Image.open(BytesIO(value))
+BinaryFile = File
 
 
 class TextFile(File):
@@ -260,10 +246,8 @@ class TextFile(File):
         super().__init__(**kwargs)
         self._stream = None
 
-    def _set_stream(
-        self, catalog=None, stream=None, caching_enabled: bool = False
-    ) -> None:
-        super()._set_stream(catalog, stream, caching_enabled)
+    def _set_stream(self, catalog: "Catalog", caching_enabled: bool = False) -> None:
+        super()._set_stream(catalog, caching_enabled)
         self._stream.set_mode("r")
 
 
@@ -272,6 +256,8 @@ def get_file(type: Literal["binary", "text", "image"] = "binary"):
     if type == "text":
         file = TextFile
     elif type == "image":
+        from datachain.lib.image import ImageFile
+
        file = ImageFile  # type: ignore[assignment]
 
     def get_file_type(
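In file.py, `get_full_path()` is renamed to `get_uri()`, and the new `get_path()`/`get_fs()` helpers resolve a usable filesystem path. A rough standalone sketch of what `get_path()` does for a local file, based only on the diff above; the sample values are illustrative and the result shown is for a POSIX system:

```python
from urllib.parse import unquote, urlparse
from urllib.request import url2pathname

source = "file:///tmp"                # File.source
full_name = "data/report%201.json"    # File.get_full_name()

uri = f"{source}/{full_name}"         # File.get_uri()
path = unquote(uri)                   # File.get_path(): unquote first
# for a LocalFileSystem, drop the file:// scheme and convert to an OS path
path = url2pathname(urlparse(path).path)

print(uri)   # file:///tmp/data/report%201.json
print(path)  # /tmp/data/report 1.json
```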
datachain/lib/image.py CHANGED
@@ -1,7 +1,8 @@
 import inspect
+from io import BytesIO
 from typing import Any, Callable, Optional
 
-from datachain.lib.file import ImageFile
+from datachain.lib.file import File
 
 try:
     import torch
@@ -16,6 +17,12 @@ except ImportError as exc:
 from datachain.lib.reader import FeatureReader
 
 
+class ImageFile(File):
+    def get_value(self):
+        value = super().get_value()
+        return Image.open(BytesIO(value))
+
+
 def convert_image(
     img: Image.Image,
     mode: str = "RGB",
@@ -48,7 +55,7 @@ def convert_image(
         and inspect.ismethod(getattr(open_clip_model, method_name))
     ):
         raise ValueError(
-            f"Unable to render Image: 'open_clip_model' doesn't support"
+            "Unable to render Image: 'open_clip_model' doesn't support"
             f" '{method_name}()'"
         )
     img = open_clip_model.encode_image(img)
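`ImageFile` now lives in `datachain.lib.image`, and `get_file()` imports it lazily, so the binary/text paths no longer require the computer-vision extras. Its `get_value()` simply wraps the raw bytes in a PIL image; a minimal standalone sketch of that step, with illustrative sample bytes:

```python
from io import BytesIO

from PIL import Image

# Stand-in for the bytes File.get_value() would return for an image object.
buf = BytesIO()
Image.new("RGB", (2, 2), "red").save(buf, format="PNG")
raw = buf.getvalue()

# This is essentially what ImageFile.get_value() does with those bytes.
img = Image.open(BytesIO(raw))
print(img.size)  # (2, 2)
```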
datachain/lib/meta_formats.py CHANGED
@@ -11,6 +11,7 @@ from collections.abc import Iterator
 from typing import Any, Callable
 
 import jmespath as jsp
+from pydantic import ValidationError
 
 from datachain.lib.feature_utils import pydantic_to_feature  # noqa: F401
 from datachain.lib.file import File
@@ -25,46 +26,48 @@ def generate_uuid():
 # JSON decoder
 def load_json_from_string(json_string):
     try:
-        data = json.loads(json_string)
-        print("Successfully parsed JSON", file=sys.stderr)
-        return data
+        return json.loads(json_string)
     except json.JSONDecodeError:
-        print("Failed to decode JSON: The string is not formatted correctly.")
-        return None
+        print(f"Failed to decode JSON: {json_string} is not formatted correctly.")
+        return None
 
 
-# Read valid JSON and return a data object sample
+# Validate and reduce JSON
 def process_json(data_string, jmespath):
     json_dict = load_json_from_string(data_string)
     if jmespath:
         json_dict = jsp.search(jmespath, json_dict)
-    # we allow non-list JSONs here to print the root schema
-    # but if jmespath expression is given, we assume a list
-    if not isinstance(json_dict, list):
-        raise ValueError("JMESPATH expression must resolve to a list")
-        return None
-    json_dict = json_dict[0]  # sample the first object
-    return json.dumps(json_dict)
+    return json_dict
 
 
 # Print a dynamic datamodel-codegen output from JSON or CSV on stdout
-def read_schema(source_file, data_type="csv", expr=None):
+def read_schema(source_file, data_type="csv", expr=None, model_name=None):
     data_string = ""
-    uid_str = str(generate_uuid()).replace("-", "")  # comply with Python class names
     # using uiid to get around issue #1617
-    model_name = f"Model{uid_str}"
+    if not model_name:
+        uid_str = str(generate_uuid()).replace(
+            "-", ""
+        )  # comply with Python class names
+        model_name = f"Model{data_type}{uid_str}"
     try:
         with source_file.open() as fd:  # CSV can be larger than memory
             if data_type == "csv":
                 data_string += fd.readline().decode("utf-8", "ignore").replace("\r", "")
                 data_string += fd.readline().decode("utf-8", "ignore").replace("\r", "")
+            elif data_type == "jsonl":
+                data_string = fd.readline().decode("utf-8", "ignore").replace("\r", "")
             else:
                 data_string = fd.read()  # other meta must fit into RAM
     except OSError as e:
         print(f"An unexpected file error occurred: {e}")
         return
-    if data_type == "json":
-        data_string = process_json(data_string, expr)
+    if data_type in ("json", "jsonl"):
+        json_object = process_json(data_string, expr)
+        if data_type == "json" and isinstance(json_object, list):
+            json_object = json_object[0]  # sample the 1st object from JSON array
+        if data_type == "jsonl":
+            data_type = "json"  # treat json line as plain JSON in auto-schema
+        data_string = json.dumps(json_object)
     command = [
         "datamodel-codegen",
         "--input-file-type",
@@ -73,8 +76,8 @@ def read_schema(source_file, data_type="csv", expr=None):
         model_name,
     ]
     try:
-        result = subprocess.run(
-            command,  # noqa: S603
+        result = subprocess.run(  # noqa: S603
+            command,
             input=data_string,
             text=True,
             capture_output=True,
@@ -87,13 +90,19 @@ def read_schema(source_file, data_type="csv", expr=None):
         model_output = f"An error occurred in datamodel-codegen: {e.stderr}"
     print(f"{model_output}")
     print("\n" + f"spec=pydantic_to_feature({model_name})" + "\n")
+    return model_output
 
 
 #
 # UDF mapper which calls chain in the setup to infer the dynamic schema
 #
-def read_meta(
-    spec=None, schema_from=None, meta_type="json", jmespath=None, show_schema=False
+def read_meta(  # noqa: C901
+    spec=None,
+    schema_from=None,
+    meta_type="json",
+    jmespath=None,
+    show_schema=False,
+    model_name=None,
 ) -> Callable:
     from datachain.lib.dc import DataChain
 
@@ -108,7 +117,7 @@ def read_meta(
         .limit(1)
         .map(  # dummy column created (#1615)
             meta_schema=lambda file: read_schema(
-                file, data_type=meta_type, expr=jmespath
+                file, data_type=meta_type, expr=jmespath, model_name=model_name
             ),
             output=str,
         )
@@ -119,6 +128,7 @@ def read_meta(
     sys.stdout = current_stdout
     model_output = captured_output.getvalue()
     captured_output.close()
+
     if show_schema:
         print(f"{model_output}")
     # Below 'spec' should be a dynamically converted Feature from Pydantic datamodel
@@ -135,30 +145,52 @@ def read_meta(
     #
     # UDF mapper parsing a JSON or CSV file using schema spec
     #
+
     def parse_data(
-        file: File, data_model=spec, meta_type=meta_type, jmespath=jmespath
+        file: File,
+        DataModel=spec,  # noqa: N803
+        meta_type=meta_type,
+        jmespath=jmespath,
     ) -> Iterator[spec]:
+        def validator(json_object: dict) -> spec:
+            json_string = json.dumps(json_object)
+            try:
+                data_instance = DataModel.model_validate_json(json_string)
+                yield data_instance
+            except ValidationError as e:
+                print(f"Validation error occurred in file {file.name}:", e)
+
         if meta_type == "csv":
            with (
                file.open() as fd
            ):  # TODO: if schema is statically given, should allow CSV without headers
                reader = csv.DictReader(fd)
                for row in reader:  # CSV can be larger than memory
-                    json_string = json.dumps(row)
-                    yield data_model.model_validate_json(json_string)
+                    yield from validator(row)
+
        if meta_type == "json":
            try:
                with file.open() as fd:  # JSON must fit into RAM
                    data_string = fd.read()
            except OSError as e:
-                print(f"An unexpected file error occurred: {e}")
-            json_object = load_json_from_string(data_string)
-            if jmespath:
-                json_object = jsp.search(jmespath, json_object)
+                print(f"An unexpected file error occurred in file {file.name}: {e}")
+            json_object = process_json(data_string, jmespath)
            if not isinstance(json_object, list):
-                raise ValueError("JSON expression must resolve in a list of objects")
-            for json_dict in json_object:
-                json_string = json.dumps(json_dict)
-                yield data_model.model_validate_json(json_string)
+                yield from validator(json_object)
+
+            else:
+                for json_dict in json_object:
+                    yield from validator(json_dict)
+
+        if meta_type == "jsonl":
+            try:
+                with file.open() as fd:
+                    data_string = fd.readline().replace("\r", "")
+                    while data_string:
+                        json_object = process_json(data_string, jmespath)
+                        data_string = fd.readline()
+                        yield from validator(json_object)
+            except OSError as e:
+                print(f"An unexpected file error occurred in file {file.name}: {e}")
 
     return parse_data
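The meta_formats.py rewrite adds JSONL support and per-row validation: each object is reduced with an optional jmespath expression, validated against the pydantic model, and invalid rows are reported and skipped instead of aborting the whole file. A standalone sketch of that pattern outside of DataChain, with an illustrative `Record` model and sample lines:

```python
import json

import jmespath as jsp
from pydantic import BaseModel, ValidationError


class Record(BaseModel):
    id: int
    name: str


lines = [
    '{"payload": {"id": 1, "name": "a"}}',
    '{"payload": {"id": "oops", "name": "b"}}',  # invalid: reported and skipped
]


def parse_jsonl(lines, expr="payload"):
    for line in lines:
        obj = json.loads(line)
        if expr:
            obj = jsp.search(expr, obj)  # what process_json() does
        try:
            # what the inner validator() does
            yield Record.model_validate_json(json.dumps(obj))
        except ValidationError as e:
            print("Validation error occurred:", e)


print(list(parse_jsonl(lines)))  # [Record(id=1, name='a')]
```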