datachain 0.3.7__py3-none-any.whl → 0.3.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of datachain might be problematic.

@@ -1540,87 +1540,6 @@ class Catalog:
         dataset = self.get_dataset(name)
         return self.update_dataset(dataset, **update_data)
 
-    def merge_datasets(
-        self,
-        src: DatasetRecord,
-        dst: DatasetRecord,
-        src_version: int,
-        dst_version: Optional[int] = None,
-    ) -> DatasetRecord:
-        """
-        Merges records from source to destination dataset.
-        It will create new version
-        of a dataset with records merged from old version and the source, unless
-        existing version is specified for destination in which case it must
-        be in non final status as datasets are immutable
-        """
-        if (
-            dst_version
-            and not dst.is_valid_next_version(dst_version)
-            and dst.get_version(dst_version).is_final_status()
-        ):
-            raise DatasetInvalidVersionError(
-                f"Version {dst_version} must be higher than the current latest one"
-            )
-
-        src_dep = self.get_dataset_dependencies(src.name, src_version)
-        dst_dep = self.get_dataset_dependencies(
-            dst.name,
-            dst.latest_version,  # type: ignore[arg-type]
-        )
-
-        if dst.has_version(dst_version):  # type: ignore[arg-type]
-            # case where we don't create new version, but append to the existing one
-            self.warehouse.merge_dataset_rows(
-                src,
-                dst,
-                src_version,
-                dst_version=dst_version,  # type: ignore[arg-type]
-            )
-            merged_schema = src.serialized_schema | dst.serialized_schema
-            self.update_dataset(dst, schema=merged_schema)
-            self.update_dataset_version_with_warehouse_info(
-                dst,
-                dst_version,  # type: ignore[arg-type]
-                schema=merged_schema,
-            )
-            for dep in src_dep:
-                if dep and dep not in dst_dep:
-                    self.metastore.add_dependency(
-                        dep,
-                        dst.name,
-                        dst_version,  # type: ignore[arg-type]
-                    )
-        else:
-            # case where we create new version of merged results
-            src_dr = self.warehouse.dataset_rows(src, src_version)
-            dst_dr = self.warehouse.dataset_rows(dst)
-
-            merge_result_columns = list(
-                {
-                    c.name: c for c in list(src_dr.table.c) + list(dst_dr.table.c)
-                }.values()
-            )
-
-            dst_version = dst_version or dst.next_version
-            dst = self.create_new_dataset_version(
-                dst,
-                dst_version,
-                columns=merge_result_columns,
-            )
-            self.warehouse.merge_dataset_rows(
-                src,
-                dst,
-                src_version,
-                dst_version,
-            )
-            self.update_dataset_version_with_warehouse_info(dst, dst_version)
-            for dep in set(src_dep + dst_dep):
-                if dep:
-                    self.metastore.add_dependency(dep, dst.name, dst_version)
-
-        return dst
-
     def get_file_signals(
         self, dataset_name: str, dataset_version: int, row: RowDict
     ) -> Optional[dict]:
@@ -1641,17 +1560,8 @@ class Catalog:
         version = self.get_dataset(dataset_name).get_version(dataset_version)
 
         file_signals_values = {}
-        file_schemas = {}
-        # TODO: To remove after we properly fix deserialization
-        for signal, type_name in version.feature_schema.items():
-            from datachain.lib.model_store import ModelStore
-
-            type_name_parsed, v = ModelStore.parse_name_version(type_name)
-            fr = ModelStore.get(type_name_parsed, v)
-            if fr and issubclass(fr, File):
-                file_schemas[signal] = type_name
 
-        schema = SignalSchema.deserialize(file_schemas)
+        schema = SignalSchema.deserialize(version.feature_schema)
         for file_signals in schema.get_signals(File):
             prefix = file_signals.replace(".", DEFAULT_DELIMITER) + DEFAULT_DELIMITER
             file_signals_values[file_signals] = {
@@ -1997,7 +1907,7 @@ class Catalog:
         """
         from datachain.query.dataset import ExecutionResult
 
-        feature_file = tempfile.NamedTemporaryFile(
+        feature_file = tempfile.NamedTemporaryFile(  # noqa: SIM115
             dir=os.getcwd(), suffix=".py", delete=False
         )
         _, feature_module = os.path.split(feature_file.name)
datachain/cli.py CHANGED
@@ -336,36 +336,6 @@ def get_parser() -> ArgumentParser:  # noqa: PLR0915
         help="Display size using powers of 1000 not 1024",
     )
 
-    parse_merge_datasets = subp.add_parser(
-        "merge-datasets", parents=[parent_parser], description="Merges datasets"
-    )
-    parse_merge_datasets.add_argument(
-        "--src",
-        action="store",
-        default=None,
-        help="Source dataset name",
-    )
-    parse_merge_datasets.add_argument(
-        "--dst",
-        action="store",
-        default=None,
-        help="Destination dataset name",
-    )
-    parse_merge_datasets.add_argument(
-        "--src-version",
-        action="store",
-        default=None,
-        type=int,
-        help="Source dataset version",
-    )
-    parse_merge_datasets.add_argument(
-        "--dst-version",
-        action="store",
-        default=None,
-        type=int,
-        help="Destination dataset version",
-    )
-
     parse_ls = subp.add_parser(
         "ls", parents=[parent_parser], description="List storage contents"
     )
@@ -996,13 +966,6 @@ def main(argv: Optional[list[str]] = None) -> int:  # noqa: C901, PLR0912, PLR09
             new_name=args.new_name,
             labels=args.labels,
         )
-    elif args.command == "merge-datasets":
-        catalog.merge_datasets(
-            catalog.get_dataset(args.src),
-            catalog.get_dataset(args.dst),
-            args.src_version,
-            dst_version=args.dst_version,
-        )
    elif args.command == "ls":
        ls(
            args.sources,
datachain/lib/arrow.py CHANGED
@@ -95,7 +95,7 @@ def schema_to_output(schema: pa.Schema, col_names: Optional[Sequence[str]] = Non
         if not column:
             column = f"c{default_column}"
             default_column += 1
-        dtype = _arrow_type_mapper(field.type)  # type: ignore[assignment]
+        dtype = arrow_type_mapper(field.type)  # type: ignore[assignment]
         if field.nullable:
             dtype = Optional[dtype]  # type: ignore[assignment]
         output[column] = dtype
@@ -103,7 +103,7 @@ def schema_to_output(schema: pa.Schema, col_names: Optional[Sequence[str]] = Non
     return output
 
 
-def _arrow_type_mapper(col_type: pa.DataType) -> type:  # noqa: PLR0911
+def arrow_type_mapper(col_type: pa.DataType) -> type:  # noqa: PLR0911
     """Convert pyarrow types to basic types."""
     from datetime import datetime
 
@@ -122,16 +122,16 @@ def _arrow_type_mapper(col_type: pa.DataType) -> type:  # noqa: PLR0911
     if pa.types.is_string(col_type) or pa.types.is_large_string(col_type):
         return str
     if pa.types.is_list(col_type):
-        return list[_arrow_type_mapper(col_type.value_type)]  # type: ignore[return-value, misc]
+        return list[arrow_type_mapper(col_type.value_type)]  # type: ignore[return-value, misc]
     if pa.types.is_struct(col_type) or pa.types.is_map(col_type):
         return dict
     if isinstance(col_type, pa.lib.DictionaryType):
-        return _arrow_type_mapper(col_type.value_type)  # type: ignore[return-value]
+        return arrow_type_mapper(col_type.value_type)  # type: ignore[return-value]
     raise TypeError(f"{col_type!r} datatypes not supported")
 
 
 def _nrows_file(file: File, nrows: int) -> str:
-    tf = NamedTemporaryFile(delete=False)
+    tf = NamedTemporaryFile(delete=False)  # noqa: SIM115
     with file.open(mode="r") as reader:
         with open(tf.name, "a") as writer:
             for row, line in enumerate(reader):
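The rename above makes the type mapper public. A minimal sketch of how it could be called, assuming `arrow_type_mapper` is importable from `datachain.lib.arrow` and using only the branches visible in the hunk (string, list, struct):

```py
# Hedged sketch: exercising the now-public arrow_type_mapper (renamed from
# _arrow_type_mapper in 0.3.9). Other pyarrow types may map differently.
import pyarrow as pa

from datachain.lib.arrow import arrow_type_mapper

print(arrow_type_mapper(pa.string()))                      # <class 'str'>
print(arrow_type_mapper(pa.list_(pa.string())))            # list[str]
print(arrow_type_mapper(pa.struct([("a", pa.string())])))  # <class 'dict'>
```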
datachain/lib/clip.py CHANGED
@@ -1,5 +1,5 @@
 import inspect
-from typing import TYPE_CHECKING, Any, Callable, Literal, Union
+from typing import TYPE_CHECKING, Any, Callable, Literal, Optional, Union
 
 import torch
 from transformers.modeling_utils import PreTrainedModel
@@ -39,6 +39,7 @@ def clip_similarity_scores(
     tokenizer: Callable,
     prob: bool = False,
     image_to_text: bool = True,
+    device: Optional[Union[str, torch.device]] = None,
 ) -> list[list[float]]:
     """
     Calculate CLIP similarity scores between one or more images and/or text.
@@ -52,6 +53,7 @@ def clip_similarity_scores(
         prob : Compute softmax probabilities.
         image_to_text : Whether to compute for image-to-text or text-to-image. Ignored
             if only one of images or text provided.
+        device : Device to use. Defaults is None - use model's device.
 
 
     Example:
@@ -130,17 +132,26 @@ def clip_similarity_scores(
         ```
     """
 
+    if device is None:
+        if hasattr(model, "device"):
+            device = model.device
+        else:
+            device = next(model.parameters()).device
+    else:
+        model = model.to(device)
     with torch.no_grad():
         if images is not None:
             encoder = _get_encoder(model, "image")
             image_features = convert_images(
-                images, transform=preprocess, encoder=encoder
+                images, transform=preprocess, encoder=encoder, device=device
             )
             image_features /= image_features.norm(dim=-1, keepdim=True)  # type: ignore[union-attr]
 
         if text is not None:
             encoder = _get_encoder(model, "text")
-            text_features = convert_text(text, tokenizer, encoder=encoder)
+            text_features = convert_text(
+                text, tokenizer, encoder=encoder, device=device
+            )
             text_features /= text_features.norm(dim=-1, keepdim=True)  # type: ignore[union-attr]
 
         if images is not None and text is not None:
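A hedged usage sketch of the new `device` argument. The keyword names follow the docstring shown in the hunk, and the open_clip model setup and dummy image are assumptions, not part of the diff:

```py
# Assumed setup: an open_clip model/preprocess/tokenizer triple and a blank PIL
# image. The new `device` argument (0.3.9) moves the model and inputs to that
# device; when omitted, the model's own device is used, as in the hunk above.
import open_clip
from PIL import Image

from datachain.lib.clip import clip_similarity_scores

# Add pretrained="..." weights for meaningful scores; omitted to keep this offline.
model, _, preprocess = open_clip.create_model_and_transforms("ViT-B-32")
tokenizer = open_clip.get_tokenizer("ViT-B-32")
img = Image.new("RGB", (224, 224))

scores = clip_similarity_scores(
    images=img,
    text="a photo of a banana",
    model=model,
    preprocess=preprocess,
    tokenizer=tokenizer,
    device="cpu",
)
```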
@@ -73,6 +73,9 @@ def python_to_sql(typ):  # noqa: PLR0911
     if len(args) == 2 and (type(None) in args):
         return python_to_sql(args[0])
 
+    if _is_union_str_literal(orig, args):
+        return String
+
     if _is_json_inside_union(orig, args):
         return JSON
 
@@ -94,3 +97,9 @@ def _is_json_inside_union(orig, args) -> bool:
     if any(inspect.isclass(arg) and issubclass(arg, BaseModel) for arg in args):
         return True
     return False
+
+
+def _is_union_str_literal(orig, args) -> bool:
+    if orig != Union:
+        return False
+    return all(arg is str or get_origin(arg) in (Literal, LiteralEx) for arg in args)
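A small sketch of what the new branch enables, assuming `python_to_sql` is imported from `datachain.lib.convert.python_to_sql` (the path used by the dc.py imports below):

```py
# With 0.3.9, a Union of str and Literal members maps to String instead of
# falling through to the generic union handling (per _is_union_str_literal).
from typing import Literal, Optional, Union

from datachain.lib.convert.python_to_sql import python_to_sql

print(python_to_sql(Union[str, Literal["train", "val", "test"]]))  # expected: String
print(python_to_sql(Optional[str]))  # Optional unwrapping, unchanged behaviour
```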
@@ -2,7 +2,7 @@ from collections.abc import Sequence
 from datetime import datetime
 from typing import ClassVar, Union, get_args, get_origin
 
-from pydantic import BaseModel
+from pydantic import BaseModel, create_model
 
 from datachain.lib.model_store import ModelStore
 
@@ -57,3 +57,12 @@ def is_chain_type(t: type) -> bool:
         return is_chain_type(args[0])
 
     return False
+
+
+def dict_to_data_model(name: str, data_dict: dict[str, DataType]) -> type[BaseModel]:
+    fields = {name: (anno, ...) for name, anno in data_dict.items()}
+    return create_model(
+        name,
+        __base__=(DataModel,),  # type: ignore[call-overload]
+        **fields,
+    )  # type: ignore[call-overload]
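This helper previously lived on `DataChain` as `_dict_to_data_model` (removed further down in dc.py). A brief sketch of the module-level function, with the model name and fields invented for illustration:

```py
# dict_to_data_model builds a pydantic model derived from DataModel with all
# fields required; "Landmark" and its fields are hypothetical examples.
from datachain.lib.data_model import dict_to_data_model

Landmark = dict_to_data_model("Landmark", {"x": float, "y": float, "label": str})
point = Landmark(x=0.25, y=0.75, label="left_eye")
print(point)
```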
datachain/lib/dc.py CHANGED
@@ -18,14 +18,13 @@ from typing import (
 
 import pandas as pd
 import sqlalchemy
-from pydantic import BaseModel, create_model
+from pydantic import BaseModel
 from sqlalchemy.sql.functions import GenericFunction
 from sqlalchemy.sql.sqltypes import NullType
 
-from datachain import DataModel
 from datachain.lib.convert.python_to_sql import python_to_sql
 from datachain.lib.convert.values_to_tuples import values_to_tuples
-from datachain.lib.data_model import DataType
+from datachain.lib.data_model import DataModel, DataType, dict_to_data_model
 from datachain.lib.dataset_info import DatasetInfo
 from datachain.lib.file import ExportPlacement as FileExportPlacement
 from datachain.lib.file import File, IndexedFile, get_file
@@ -55,6 +54,8 @@ from datachain.utils import inside_notebook
 if TYPE_CHECKING:
     from typing_extensions import Concatenate, ParamSpec, Self
 
+    from datachain.lib.hf import HFDatasetType
+
     P = ParamSpec("P")
 
 C = Column
@@ -77,12 +78,12 @@ def resolve_columns(
     @wraps(method)
     def _inner(self: D, *args: "P.args", **kwargs: "P.kwargs") -> D:
         resolved_args = self.signals_schema.resolve(
-            *[arg for arg in args if not isinstance(arg, GenericFunction)]
+            *[arg for arg in args if not isinstance(arg, GenericFunction)]  # type: ignore[arg-type]
         ).db_signals()
 
         for idx, arg in enumerate(args):
             if isinstance(arg, GenericFunction):
-                resolved_args.insert(idx, arg)
+                resolved_args.insert(idx, arg)  # type: ignore[arg-type]
 
         return method(self, *resolved_args, **kwargs)
 
@@ -208,23 +209,28 @@ class DataChain(DatasetQuery):
         "size": 0,
     }
 
-    def __init__(self, *args, **kwargs):
+    def __init__(self, *args, settings: Optional[dict] = None, **kwargs):
         """This method needs to be redefined as a part of Dataset and DataChain
         decoupling.
         """
-        super().__init__(
+        super().__init__(  # type: ignore[misc]
             *args,
             **kwargs,
             indexing_column_types=File._datachain_column_types,
         )
-        self._settings = Settings()
-        self._setup = {}
+        if settings:
+            self._settings = Settings(**settings)
+        else:
+            self._settings = Settings()
+        self._setup: dict = {}
 
         self.signals_schema = SignalSchema({"sys": Sys})
         if self.feature_schema:
             self.signals_schema |= SignalSchema.deserialize(self.feature_schema)
         else:
-            self.signals_schema |= SignalSchema.from_column_types(self.column_types)
+            self.signals_schema |= SignalSchema.from_column_types(
+                self.column_types or {}
+            )
 
         self._sys = False
@@ -309,6 +315,7 @@ class DataChain(DatasetQuery):
         *,
         type: Literal["binary", "text", "image"] = "binary",
         session: Optional[Session] = None,
+        settings: Optional[dict] = None,
         in_memory: bool = False,
         recursive: Optional[bool] = True,
         object_name: str = "file",
@@ -336,6 +343,7 @@ class DataChain(DatasetQuery):
             cls(
                 path,
                 session=session,
+                settings=settings,
                 recursive=recursive,
                 update=update,
                 in_memory=in_memory,
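A hedged sketch of the new `settings` pass-through on the factory methods. The bucket path and the setting keys (`cache`, `parallel`) are assumptions, not taken from the diff; the dict is forwarded to `Settings(**settings)` in `__init__` as shown above:

```py
# Illustrative only: keys accepted by Settings may differ from those shown.
from datachain.lib.dc import DataChain

chain = DataChain.from_storage(
    "s3://example-bucket/images/",
    type="image",
    settings={"cache": True, "parallel": 4},
)
```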
@@ -489,6 +497,7 @@ class DataChain(DatasetQuery):
     def datasets(
         cls,
         session: Optional[Session] = None,
+        settings: Optional[dict] = None,
         in_memory: bool = False,
         object_name: str = "dataset",
     ) -> "DataChain":
@@ -513,6 +522,7 @@ class DataChain(DatasetQuery):
 
         return cls.from_values(
             session=session,
+            settings=settings,
             in_memory=in_memory,
             output={object_name: DatasetInfo},
             **{object_name: datasets},  # type: ignore[arg-type]
@@ -895,7 +905,7 @@ class DataChain(DatasetQuery):
         if isinstance(value, Column):
             # renaming existing column
             for signal in schema.db_signals(name=value.name, as_columns=True):
-                mutated[signal.name.replace(value.name, name, 1)] = signal
+                mutated[signal.name.replace(value.name, name, 1)] = signal  # type: ignore[union-attr]
         else:
             # adding new signal
             mutated[name] = value
@@ -1086,7 +1096,7 @@ class DataChain(DatasetQuery):
         )
 
         signals_schema = self.signals_schema.clone_without_sys_signals()
-        on_columns = signals_schema.resolve(*on).db_signals()
+        on_columns: list[str] = signals_schema.resolve(*on).db_signals()  # type: ignore[assignment]
 
         right_signals_schema = right_ds.signals_schema.clone_without_sys_signals()
         if right_on is not None:
@@ -1105,7 +1115,9 @@ class DataChain(DatasetQuery):
                     on, right_on, "'on' and 'right_on' must have the same length'"
                 )
 
-            right_on_columns = right_signals_schema.resolve(*right_on).db_signals()
+            right_on_columns: list[str] = right_signals_schema.resolve(
+                *right_on
+            ).db_signals()  # type: ignore[assignment]
 
             if len(right_on_columns) != len(on_columns):
                 on_str = ", ".join(right_on_columns)
@@ -1141,17 +1153,35 @@ class DataChain(DatasetQuery):
         self,
         other: "DataChain",
         on: Optional[Union[str, Sequence[str]]] = None,
+        right_on: Optional[Union[str, Sequence[str]]] = None,
     ) -> "Self":
         """Remove rows that appear in another chain.
 
         Parameters:
             other: chain whose rows will be removed from `self`
-            on: columns to consider for determining row equality. If unspecified,
-                defaults to all common columns between `self` and `other`.
+            on: columns to consider for determining row equality in `self`.
+                If unspecified, defaults to all common columns
+                between `self` and `other`.
+            right_on: columns to consider for determining row equality in `other`.
+                If unspecified, defaults to the same values as `on`.
         """
         if isinstance(on, str):
+            if not on:
+                raise DataChainParamsError("'on' cannot be an empty string")
             on = [on]
-        if on is None:
+        elif isinstance(on, Sequence):
+            if not on or any(not col for col in on):
+                raise DataChainParamsError("'on' cannot contain empty strings")
+
+        if isinstance(right_on, str):
+            if not right_on:
+                raise DataChainParamsError("'right_on' cannot be an empty string")
+            right_on = [right_on]
+        elif isinstance(right_on, Sequence):
+            if not right_on or any(not col for col in right_on):
+                raise DataChainParamsError("'right_on' cannot contain empty strings")
+
+        if on is None and right_on is None:
             other_columns = set(other._effective_signals_schema.db_signals())
             signals = [
                 c
@@ -1160,16 +1190,29 @@ class DataChain(DatasetQuery):
             ]
             if not signals:
                 raise DataChainParamsError("subtract(): no common columns")
-        elif not isinstance(on, Sequence):
-            raise TypeError(
-                f"'on' must be 'str' or 'Sequence' object but got type '{type(on)}'",
-            )
-        elif not on:
+        elif on is not None and right_on is None:
+            right_on = on
+            signals = list(self.signals_schema.resolve(*on).db_signals())
+        elif on is None and right_on is not None:
             raise DataChainParamsError(
-                "'on' cannot be empty",
+                "'on' must be specified when 'right_on' is provided"
             )
         else:
-            signals = self.signals_schema.resolve(*on).db_signals()
+            if not isinstance(on, Sequence) or not isinstance(right_on, Sequence):
+                raise TypeError(
+                    "'on' and 'right_on' must be 'str' or 'Sequence' object"
+                )
+            if len(on) != len(right_on):
+                raise DataChainParamsError(
+                    "'on' and 'right_on' must have the same length"
+                )
+            signals = list(
+                zip(
+                    self.signals_schema.resolve(*on).db_signals(),
+                    other.signals_schema.resolve(*right_on).db_signals(),
+                )  # type: ignore[arg-type]
+            )
+
         return super()._subtract(other, signals)  # type: ignore[arg-type]
 
     @classmethod
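A brief usage sketch of the extended `subtract()`; the column names and values are hypothetical:

```py
# With 0.3.9, rows can be removed across chains whose key columns are named
# differently, by pairing `on` (columns in the left chain) with `right_on`
# (columns in the right chain).
from datachain.lib.dc import DataChain

left = DataChain.from_values(key=["a", "b", "c"], value=[1, 2, 3])
right = DataChain.from_values(other_key=["b", "c"])

remaining = left.subtract(right, on="key", right_on="other_key")

# Previous behaviour is unchanged: with only `on`, the same columns are used on
# both sides; with neither, all common columns are compared.
```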
@@ -1177,6 +1220,7 @@ class DataChain(DatasetQuery):
         cls,
         ds_name: str = "",
         session: Optional[Session] = None,
+        settings: Optional[dict] = None,
         in_memory: bool = False,
         output: OutputType = None,
         object_name: str = "",
@@ -1195,10 +1239,13 @@ class DataChain(DatasetQuery):
             yield from tuples
 
         chain = DataChain.from_records(
-            DataChain.DEFAULT_FILE_RECORD, session=session, in_memory=in_memory
+            DataChain.DEFAULT_FILE_RECORD,
+            session=session,
+            settings=settings,
+            in_memory=in_memory,
         )
         if object_name:
-            output = {object_name: DataChain._dict_to_data_model(object_name, output)}  # type: ignore[arg-type]
+            output = {object_name: dict_to_data_model(object_name, output)}  # type: ignore[arg-type]
         return chain.gen(_func_fr, output=output)
 
     @classmethod
@@ -1207,6 +1254,7 @@ class DataChain(DatasetQuery):
         df: "pd.DataFrame",
         name: str = "",
         session: Optional[Session] = None,
+        settings: Optional[dict] = None,
         in_memory: bool = False,
         object_name: str = "",
     ) -> "DataChain":
@@ -1236,7 +1284,12 @@ class DataChain(DatasetQuery):
         )
 
         return cls.from_values(
-            name, session, object_name=object_name, in_memory=in_memory, **fr_map
+            name,
+            session,
+            settings=settings,
+            object_name=object_name,
+            in_memory=in_memory,
+            **fr_map,
         )
 
     def to_pandas(self, flatten=False) -> "pd.DataFrame":
@@ -1306,6 +1359,59 @@ class DataChain(DatasetQuery):
         if len(df) == limit:
             print(f"\n[Limited by {len(df)} rows]")
 
+    @classmethod
+    def from_hf(
+        cls,
+        dataset: Union[str, "HFDatasetType"],
+        *args,
+        session: Optional[Session] = None,
+        settings: Optional[dict] = None,
+        object_name: str = "",
+        model_name: str = "",
+        **kwargs,
+    ) -> "DataChain":
+        """Generate chain from huggingface hub dataset.
+
+        Parameters:
+            dataset : Path or name of the dataset to read from Hugging Face Hub,
+                or an instance of `datasets.Dataset`-like object.
+            session : Session to use for the chain.
+            settings : Settings to use for the chain.
+            object_name : Generated object column name.
+            model_name : Generated model name.
+            kwargs : Parameters to pass to datasets.load_dataset.
+
+        Example:
+            Load from Hugging Face Hub:
+            ```py
+            DataChain.from_hf("beans", split="train")
+            ```
+
+            Generate chain from loaded dataset:
+            ```py
+            from datasets import load_dataset
+            ds = load_dataset("beans", split="train")
+            DataChain.from_hf(ds)
+            ```
+        """
+        from datachain.lib.hf import HFGenerator, get_output_schema, stream_splits
+
+        output: dict[str, DataType] = {}
+        ds_dict = stream_splits(dataset, *args, **kwargs)
+        if len(ds_dict) > 1:
+            output = {"split": str}
+
+        model_name = model_name or object_name or ""
+        output = output | get_output_schema(next(iter(ds_dict.values())), model_name)
+        model = dict_to_data_model(model_name, output)
+        if object_name:
+            output = {object_name: model}
+
+        chain = DataChain.from_values(
+            split=list(ds_dict.keys()), session=session, settings=settings
+        )
+        return chain.gen(HFGenerator(dataset, model, *args, **kwargs), output=output)
+
     def parse_tabular(
         self,
         output: OutputType = None,
@@ -1367,7 +1473,7 @@ class DataChain(DatasetQuery):
 
         if isinstance(output, dict):
             model_name = model_name or object_name or ""
-            model = DataChain._dict_to_data_model(model_name, output)
+            model = dict_to_data_model(model_name, output)
         else:
             model = output  # type: ignore[assignment]
 
@@ -1384,17 +1490,6 @@ class DataChain(DatasetQuery):
             ArrowGenerator(schema, model, source, nrows, **kwargs), output=output
         )
 
-    @staticmethod
-    def _dict_to_data_model(
-        name: str, data_dict: dict[str, DataType]
-    ) -> type[BaseModel]:
-        fields = {name: (anno, ...) for name, anno in data_dict.items()}
-        return create_model(
-            name,
-            __base__=(DataModel,),  # type: ignore[call-overload]
-            **fields,
-        )  # type: ignore[call-overload]
-
     @classmethod
     def from_csv(
         cls,
@@ -1543,6 +1638,7 @@ class DataChain(DatasetQuery):
         cls,
         to_insert: Optional[Union[dict, list[dict]]],
         session: Optional[Session] = None,
+        settings: Optional[dict] = None,
         in_memory: bool = False,
         schema: Optional[dict[str, DataType]] = None,
     ) -> "DataChain":
@@ -1597,7 +1693,7 @@ class DataChain(DatasetQuery):
         insert_q = dr.get_table().insert()
         for record in to_insert:
             db.execute(insert_q.values(**record))
-        return DataChain(name=dsr.name)
+        return DataChain(name=dsr.name, settings=settings)
 
     def sum(self, fr: DataType):  # type: ignore[override]
         """Compute the sum of a column."""