datachain 0.14.1__py3-none-any.whl → 0.14.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -5,7 +5,7 @@ from typing import (
 
 from datachain.query import Session
 
-from .values import from_values
+from .values import read_values
 
 if TYPE_CHECKING:
     import pandas as pd
@@ -16,7 +16,7 @@ if TYPE_CHECKING:
 P = ParamSpec("P")
 
 
-def from_pandas(  # type: ignore[override]
+def read_pandas(  # type: ignore[override]
     df: "pd.DataFrame",
     name: str = "",
     session: Optional[Session] = None,
@@ -32,7 +32,7 @@ def from_pandas(  # type: ignore[override]
        import datachain as dc
 
        df = pd.DataFrame({"fib": [1, 2, 3, 5, 8]})
-       dc.from_pandas(df)
+       dc.read_pandas(df)
        ```
    """
    from .utils import DatasetPrepareError
@@ -46,7 +46,7 @@ def from_pandas(  # type: ignore[override]
            f"import from pandas error - '{column}' cannot be a column name",
        )
 
-    return from_values(
+    return read_values(
        name,
        session,
        settings=settings,
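
These hunks are part of the package-wide rename of the `from_*` entry points to `read_*`. A minimal before/after sketch of caller code (the DataFrame contents are illustrative only):

```python
import pandas as pd
import datachain as dc

df = pd.DataFrame({"fib": [1, 2, 3, 5, 8]})

# 0.14.1 entry point:
# chain = dc.from_pandas(df)

# 0.14.3 entry point, per the rename above:
chain = dc.read_pandas(df)
```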
@@ -15,7 +15,7 @@ if TYPE_CHECKING:
15
15
  P = ParamSpec("P")
16
16
 
17
17
 
18
- def from_parquet(
18
+ def read_parquet(
19
19
  path,
20
20
  partitioning: Any = "hive",
21
21
  output: Optional[dict[str, DataType]] = None,
@@ -43,18 +43,18 @@ def from_parquet(
43
43
  Reading a single file:
44
44
  ```py
45
45
  import datachain as dc
46
- dc.from_parquet("s3://mybucket/file.parquet")
46
+ dc.read_parquet("s3://mybucket/file.parquet")
47
47
  ```
48
48
 
49
49
  Reading a partitioned dataset from a directory:
50
50
  ```py
51
51
  import datachain as dc
52
- dc.from_parquet("s3://mybucket/dir")
52
+ dc.read_parquet("s3://mybucket/dir")
53
53
  ```
54
54
  """
55
- from .storage import from_storage
55
+ from .storage import read_storage
56
56
 
57
- chain = from_storage(path, session=session, settings=settings, **kwargs)
57
+ chain = read_storage(path, session=session, settings=settings, **kwargs)
58
58
  return chain.parse_tabular(
59
59
  output=output,
60
60
  object_name=object_name,
@@ -21,7 +21,7 @@ if TYPE_CHECKING:
 P = ParamSpec("P")
 
 
-def from_records(
+def read_records(
    to_insert: Optional[Union[dict, list[dict]]],
    session: Optional[Session] = None,
    settings: Optional[dict] = None,
@@ -40,10 +40,10 @@ def from_records(
        Example:
        ```py
        import datachain as dc
-       single_record = dc.from_records(dc.DEFAULT_FILE_RECORD)
+       single_record = dc.read_records(dc.DEFAULT_FILE_RECORD)
        ```
    """
-   from .datasets import from_dataset
+   from .datasets import read_dataset
 
    session = Session.get(session, in_memory=in_memory)
    catalog = session.catalog
@@ -87,4 +87,4 @@ def from_records(
    insert_q = dr.get_table().insert()
    for record in to_insert:
        db.execute(insert_q.values(**record))
-   return from_dataset(name=dsr.name, session=session, settings=settings)
+   return read_dataset(name=dsr.name, session=session, settings=settings)
@@ -21,7 +21,7 @@ if TYPE_CHECKING:
    from .datachain import DataChain
 
 
-def from_storage(
+def read_storage(
    uri: Union[str, os.PathLike[str], list[str], list[os.PathLike[str]]],
    *,
    type: FileType = "binary",
@@ -55,12 +55,12 @@ def from_storage(
        Simple call from s3:
        ```python
        import datachain as dc
-       chain = dc.from_storage("s3://my-bucket/my-dir")
+       chain = dc.read_storage("s3://my-bucket/my-dir")
        ```
 
        Multiple URIs:
        ```python
-       chain = dc.from_storage([
+       chain = dc.read_storage([
            "s3://bucket1/dir1",
            "s3://bucket2/dir2"
        ])
@@ -68,7 +68,7 @@ def from_storage(
 
        With AWS S3-compatible storage:
        ```python
-       chain = dc.from_storage(
+       chain = dc.read_storage(
            "s3://my-bucket/my-dir",
            client_config = {"aws_endpoint_url": "<minio-endpoint-url>"}
        )
@@ -77,7 +77,7 @@ def from_storage(
        Pass existing session
        ```py
        session = Session.get()
-       chain = dc.from_storage([
+       chain = dc.read_storage([
            "path/to/dir1",
            "path/to/dir2"
        ], session=session, recursive=True)
@@ -88,9 +88,9 @@ def from_storage(
    avoiding redundant updates for URIs pointing to the same storage location.
    """
    from .datachain import DataChain
-   from .datasets import from_dataset
-   from .records import from_records
-   from .values import from_values
+   from .datasets import read_dataset
+   from .records import read_records
+   from .values import read_values
 
    file_type = get_file_type(type)
 
@@ -122,7 +122,8 @@ def from_storage(
            )
            continue
 
-       dc = from_dataset(list_ds_name, session=session, settings=settings)
+       dc = read_dataset(list_ds_name, session=session, settings=settings)
+       dc._query.update = update
        dc.signals_schema = dc.signals_schema.mutate({f"{object_name}": file_type})
 
        if update or not list_ds_exists:
@@ -130,7 +131,7 @@ def from_storage(
            def lst_fn(ds_name, lst_uri):
                # disable prefetch for listing, as it pre-downloads all files
                (
-                   from_records(
+                   read_records(
                        DataChain.DEFAULT_FILE_RECORD,
                        session=session,
                        settings=settings,
@@ -144,7 +145,7 @@ def from_storage(
                    .save(ds_name, listing=True)
                )
 
-           dc._query.add_before_steps(
+           dc._query.set_listing_fn(
                lambda ds_name=list_ds_name, lst_uri=list_uri: lst_fn(ds_name, lst_uri)
            )
 
@@ -154,7 +155,7 @@ def from_storage(
        listed_ds_name.add(list_ds_name)
 
    if file_values:
-       file_chain = from_values(
+       file_chain = read_values(
            session=session,
            settings=settings,
            in_memory=in_memory,
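
Together, `dc._query.update = update` and `set_listing_fn` make the bucket listing lazy: the listing callable is registered on the query and only runs when steps are applied, honoring the update flag. A minimal usage sketch, assuming the `update` flag in scope above is exposed as a `read_storage` keyword argument:

```python
import datachain as dc

# First use lists the bucket lazily and saves the result as a listing dataset.
chain = dc.read_storage("s3://my-bucket/my-dir")

# While a non-expired listing dataset exists, it is reused instead of re-listing.
chain = dc.read_storage("s3://my-bucket/my-dir")

# update=True sets dc._query.update, forcing the registered listing to re-run.
chain = dc.read_storage("s3://my-bucket/my-dir", update=True)
```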
@@ -6,7 +6,7 @@ from typing import (
 
 from datachain.lib.convert.values_to_tuples import values_to_tuples
 from datachain.lib.data_model import dict_to_data_model
-from datachain.lib.dc.records import from_records
+from datachain.lib.dc.records import read_records
 from datachain.lib.dc.utils import OutputType
 from datachain.query import Session
 
@@ -18,7 +18,7 @@ if TYPE_CHECKING:
 P = ParamSpec("P")
 
 
-def from_values(
+def read_values(
    ds_name: str = "",
    session: Optional[Session] = None,
    settings: Optional[dict] = None,
@@ -32,7 +32,7 @@ def from_values(
        Example:
        ```py
        import datachain as dc
-       dc.from_values(fib=[1, 2, 3, 5, 8])
+       dc.read_values(fib=[1, 2, 3, 5, 8])
        ```
    """
    from .datachain import DataChain
@@ -42,7 +42,7 @@ def from_values(
    def _func_fr() -> Iterator[tuple_type]:  # type: ignore[valid-type]
        yield from tuples
 
-   chain = from_records(
+   chain = read_records(
        DataChain.DEFAULT_FILE_RECORD,
        session=session,
        settings=settings,
datachain/lib/listing.py CHANGED
@@ -4,6 +4,7 @@ import os
 import posixpath
 from collections.abc import Iterator
 from contextlib import contextmanager
+from datetime import datetime, timedelta, timezone
 from typing import TYPE_CHECKING, Callable, Optional, TypeVar, Union
 
 from fsspec.asyn import get_loop
@@ -32,6 +33,16 @@ logging.getLogger("aiobotocore.credentials").setLevel(logging.CRITICAL)
 logging.getLogger("gcsfs").setLevel(logging.CRITICAL)
 
 
+def listing_dataset_expired(lst_ds) -> bool:
+    """Function that checks if listing dataset is expired or not"""
+    lst_version = lst_ds.versions[-1]
+    if not lst_version.finished_at:
+        return False
+
+    expires = lst_version.finished_at + timedelta(seconds=LISTING_TTL)
+    return datetime.now(timezone.utc) > expires
+
+
 def list_bucket(uri: str, cache, client_config=None) -> Callable:
     """
     Function that returns another generator function that yields File objects
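
`listing_dataset_expired` compares the last listing version's `finished_at` timestamp against the listing TTL; an unfinished listing is treated as not expired. A self-contained sketch of the same check with stand-in objects (the TTL value and `SimpleNamespace` fixtures are purely illustrative; the real constant lives in `datachain.lib.listing`):

```python
from datetime import datetime, timedelta, timezone
from types import SimpleNamespace

LISTING_TTL = 60 * 60  # placeholder TTL in seconds, for illustration only

def listing_dataset_expired(lst_ds) -> bool:
    # Mirrors the helper added above: unfinished listings are never "expired".
    lst_version = lst_ds.versions[-1]
    if not lst_version.finished_at:
        return False
    expires = lst_version.finished_at + timedelta(seconds=LISTING_TTL)
    return datetime.now(timezone.utc) > expires

fresh = SimpleNamespace(versions=[SimpleNamespace(finished_at=datetime.now(timezone.utc))])
stale = SimpleNamespace(
    versions=[SimpleNamespace(finished_at=datetime.now(timezone.utc) - timedelta(hours=2))]
)
print(listing_dataset_expired(fresh))  # False: finished within the TTL window
print(listing_dataset_expired(stale))  # True: finished more than LISTING_TTL ago
```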
@@ -103,10 +103,10 @@ def read_meta(  # noqa: C901
    model_name=None,
    nrows=None,
 ) -> Callable:
-   from datachain import from_storage
+   from datachain import read_storage
 
    if schema_from:
-       file = next(from_storage(schema_from, type="text").limit(1).collect("file"))
+       file = next(read_storage(schema_from, type="text").limit(1).collect("file"))
        model_code = gen_datamodel_code(
            file, format=format, jmespath=jmespath, model_name=model_name
        )
datachain/lib/pytorch.py CHANGED
@@ -14,7 +14,7 @@ from torchvision.transforms import v2
 from datachain import Session
 from datachain.cache import get_temp_cache
 from datachain.catalog import Catalog, get_catalog
-from datachain.lib.dc.datasets import from_dataset
+from datachain.lib.dc.datasets import read_dataset
 from datachain.lib.settings import Settings
 from datachain.lib.text import convert_text
 from datachain.progress import CombinedDownloadCallback
@@ -122,7 +122,7 @@ class PytorchDataset(IterableDataset):
    ) -> Generator[tuple[Any, ...], None, None]:
        catalog = self._get_catalog()
        session = Session("PyTorch", catalog=catalog)
-       ds = from_dataset(
+       ds = read_dataset(
            name=self.name, version=self.version, session=session
        ).settings(cache=self.cache, prefetch=self.prefetch)
        ds = ds.remove_file_signals()
datachain/lib/udf.py CHANGED
@@ -145,7 +145,7 @@ class UDFBase(AbstractUDF):
            return emb[0].tolist()
 
        (
-           dc.from_storage(
+           dc.read_storage(
                "gs://datachain-demo/fashion-product-images/images", type="image"
            )
            .limit(5)
@@ -47,15 +47,20 @@ from datachain.error import (
    QueryScriptCancelError,
 )
 from datachain.func.base import Function
-from datachain.lib.listing import is_listing_dataset
+from datachain.lib.listing import (
+    is_listing_dataset,
+    listing_dataset_expired,
+)
 from datachain.lib.udf import UDFAdapter, _get_cache
 from datachain.progress import CombinedDownloadCallback, TqdmCombinedDownloadCallback
 from datachain.query.schema import C, UDFParamSpec, normalize_param
 from datachain.query.session import Session
+from datachain.query.udf import UdfInfo
 from datachain.sql.functions.random import rand
 from datachain.utils import (
    batched,
    determine_processes,
+    determine_workers,
    filtered_cloudpickle_dumps,
    get_datachain_executable,
    safe_closing,
@@ -71,7 +76,6 @@ if TYPE_CHECKING:
    from datachain.data_storage import AbstractWarehouse
    from datachain.dataset import DatasetRecord
    from datachain.lib.udf import UDFAdapter, UDFResult
-   from datachain.query.udf import UdfInfo
 
 P = ParamSpec("P")
 
@@ -411,20 +415,15 @@ class UDFStep(Step, ABC):
    def populate_udf_table(self, udf_table: "Table", query: Select) -> None:
        from datachain.catalog import QUERY_SCRIPT_CANCELED_EXIT_CODE
 
-       use_partitioning = self.partition_by is not None
-       batching = self.udf.get_batching(use_partitioning)
-       workers = self.workers
-       if (
-           not workers
-           and os.environ.get("DATACHAIN_DISTRIBUTED")
-           and os.environ.get("DATACHAIN_SETTINGS_WORKERS")
-       ):
-           # Enable distributed processing by default if the module is available,
-           # and a default number of workers is provided.
-           workers = True
+       rows_total = self.catalog.warehouse.query_count(query)
+       if rows_total == 0:
+           return
 
-       processes = determine_processes(self.parallel)
+       workers = determine_workers(self.workers, rows_total=rows_total)
+       processes = determine_processes(self.parallel, rows_total=rows_total)
 
+       use_partitioning = self.partition_by is not None
+       batching = self.udf.get_batching(use_partitioning)
        udf_fields = [str(c.name) for c in query.selected_columns]
 
        prefetch = self.udf.prefetch
@@ -438,23 +437,24 @@ class UDFStep(Step, ABC):
                    "distributed processing."
                )
 
-               from datachain.catalog.loader import get_distributed_class
-
-               distributor = get_distributed_class(
-                   min_task_size=self.min_task_size
-               )
-               distributor(
-                   self.udf,
-                   catalog,
-                   udf_table,
-                   query,
-                   workers,
-                   processes,
+               from datachain.catalog.loader import get_udf_distributor_class
+
+               udf_distributor_class = get_udf_distributor_class()
+               udf_distributor = udf_distributor_class(
+                   catalog=catalog,
+                   table=udf_table,
+                   query=query,
+                   udf_data=filtered_cloudpickle_dumps(self.udf),
+                   batching=batching,
+                   workers=workers,
+                   processes=processes,
                    udf_fields=udf_fields,
+                   rows_total=rows_total,
+                   use_cache=self.cache,
                    is_generator=self.is_generator,
-                   use_partitioning=use_partitioning,
-                   cache=self.cache,
+                   min_task_size=self.min_task_size,
                )
+               udf_distributor()
 
            elif processes:
                # Parallel processing (faster for more CPU-heavy UDFs)
@@ -462,19 +462,21 @@ class UDFStep(Step, ABC):
                        "In-memory databases cannot be used "
                        "with parallel processing."
                    )
-               udf_info: UdfInfo = {
-                   "udf_data": filtered_cloudpickle_dumps(self.udf),
-                   "catalog_init": catalog.get_init_params(),
-                   "metastore_clone_params": catalog.metastore.clone_params(),
-                   "warehouse_clone_params": catalog.warehouse.clone_params(),
-                   "table": udf_table,
-                   "query": query,
-                   "udf_fields": udf_fields,
-                   "batching": batching,
-                   "processes": processes,
-                   "is_generator": self.is_generator,
-                   "cache": self.cache,
-               }
+
+               udf_info = UdfInfo(
+                   udf_data=filtered_cloudpickle_dumps(self.udf),
+                   catalog_init=catalog.get_init_params(),
+                   metastore_clone_params=catalog.metastore.clone_params(),
+                   warehouse_clone_params=catalog.warehouse.clone_params(),
+                   table=udf_table,
+                   query=query,
+                   udf_fields=udf_fields,
+                   batching=batching,
+                   processes=processes,
+                   is_generator=self.is_generator,
+                   cache=self.cache,
+                   rows_total=rows_total,
+               )
 
                # Run the UDFDispatcher in another process to avoid needing
                # if __name__ == '__main__': in user scripts
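
The dict literal is replaced by constructing `UdfInfo` (now a `TypedDict` with a `rows_total` field, see datachain/query/udf.py below) with keyword arguments; at runtime this still yields a plain dict, so lookups such as `udf_info["rows_total"]` in the dispatcher keep working. A small standalone sketch of the pattern with a toy TypedDict:

```python
from typing import TypedDict

class JobInfo(TypedDict):
    # Toy stand-in for UdfInfo, just to show the construction pattern.
    rows_total: int
    processes: int
    cache: bool

info = JobInfo(rows_total=42, processes=4, cache=False)
print(info["rows_total"])  # 42 -- TypedDict calls build an ordinary dict at runtime
print(type(info) is dict)  # True
```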
@@ -1080,6 +1082,7 @@ class DatasetQuery:
        indexing_column_types: Optional[dict[str, Any]] = None,
        in_memory: bool = False,
        fallback_to_studio: bool = True,
+       update: bool = False,
    ) -> None:
        from datachain.remote.studio import is_token_set
 
@@ -1097,6 +1100,8 @@ class DatasetQuery:
        self.feature_schema: Optional[dict] = None
        self.column_types: Optional[dict[str, Any]] = None
        self.before_steps: list[Callable] = []
+       self.listing_fn: Optional[Callable] = None
+       self.update = update
 
        self.list_ds_name: Optional[str] = None
 
@@ -1190,23 +1195,30 @@ class DatasetQuery:
        col.table = self.table
        return col
 
-   def add_before_steps(self, fn: Callable) -> None:
-       """
-       Setting custom function to be run before applying steps
-       """
-       self.before_steps.append(fn)
+   def set_listing_fn(self, fn: Callable) -> None:
+       """Setting listing function to be run if needed"""
+       self.listing_fn = fn
 
    def apply_steps(self) -> QueryGenerator:
        """
        Apply the steps in the query and return the resulting
        sqlalchemy.SelectBase.
        """
-       for fn in self.before_steps:
-           fn()
+       if self.list_ds_name and not self.starting_step:
+           listing_ds = None
+           try:
+               listing_ds = self.catalog.get_dataset(self.list_ds_name)
+           except DatasetNotFoundError:
+               pass
+
+           if not listing_ds or self.update or listing_dataset_expired(listing_ds):
+               assert self.listing_fn
+               self.listing_fn()
+               listing_ds = self.catalog.get_dataset(self.list_ds_name)
 
-       if self.list_ds_name:
            # at this point we know what is our starting listing dataset name
-           self._set_starting_step(self.catalog.get_dataset(self.list_ds_name))  # type: ignore [arg-type]
+           self._set_starting_step(listing_ds)  # type: ignore [arg-type]
+
 
        query = self.clone()
 
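The new `apply_steps` branch re-lists only when the listing dataset is missing, `update` was requested, or the cached listing has expired. A condensed, hypothetical sketch of that decision with stand-in callables (names here are illustrative, not datachain API):

```python
def ensure_listing(get_dataset, run_listing, list_ds_name, update, expired):
    """Hypothetical condensation of the branch added to apply_steps above."""
    try:
        listing_ds = get_dataset(list_ds_name)
    except LookupError:  # stands in for DatasetNotFoundError
        listing_ds = None

    if listing_ds is None or update or expired(listing_ds):
        run_listing()                 # corresponds to self.listing_fn()
        listing_ds = get_dataset(list_ds_name)
    return listing_ds                 # becomes the starting step of the query
```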
@@ -11,11 +11,10 @@ import multiprocess
 from cloudpickle import load, loads
 from fsspec.callbacks import DEFAULT_CALLBACK, Callback
 from multiprocess import get_context
-from sqlalchemy.sql import func
 
 from datachain.catalog import Catalog
 from datachain.catalog.catalog import clone_catalog_with_cache
-from datachain.catalog.loader import get_distributed_class
+from datachain.catalog.loader import get_udf_distributor_class
 from datachain.lib.udf import _get_cache
 from datachain.query.batch import RowsOutput, RowsOutputBatch
 from datachain.query.dataset import (
@@ -59,6 +58,7 @@ def udf_entrypoint() -> int:
    dispatch = UDFDispatcher(udf_info)
 
    query = udf_info["query"]
+   rows_total = udf_info["rows_total"]
    batching = udf_info["batching"]
    n_workers = udf_info["processes"]
    if n_workers is True:
@@ -67,12 +67,6 @@ def udf_entrypoint() -> int:
    wh_cls, wh_args, wh_kwargs = udf_info["warehouse_clone_params"]
    warehouse: AbstractWarehouse = wh_cls(*wh_args, **wh_kwargs)
 
-   total_rows = next(
-       warehouse.db.execute(
-           query.with_only_columns(func.count(query.c.sys__id)).order_by(None)
-       )
-   )[0]
-
    with contextlib.closing(
        batching(warehouse.dataset_select_paginated, query, ids_only=True)
    ) as udf_inputs:
@@ -81,7 +75,7 @@ def udf_entrypoint() -> int:
        try:
            dispatch.run_udf_parallel(
                udf_inputs,
-               total_rows=total_rows,
+               rows_total=rows_total,
                n_workers=n_workers,
                processed_cb=processed_cb,
                download_cb=download_cb,
@@ -94,7 +88,7 @@ def udf_entrypoint() -> int:
 
 
 def udf_worker_entrypoint() -> int:
-   return get_distributed_class().run_worker()
+   return get_udf_distributor_class().run_worker()
 
 
 class UDFDispatcher:
@@ -164,14 +158,14 @@ class UDFDispatcher:
    def run_udf_parallel(  # noqa: C901, PLR0912
        self,
        input_rows: Iterable[RowsOutput],
-       total_rows: int,
+       rows_total: int,
        n_workers: Optional[int] = None,
        processed_cb: Callback = DEFAULT_CALLBACK,
        download_cb: Callback = DEFAULT_CALLBACK,
    ) -> None:
        n_workers = get_n_workers_from_arg(n_workers)
 
-       input_batch_size = total_rows // n_workers
+       input_batch_size = rows_total // n_workers
        if input_batch_size == 0:
            input_batch_size = 1
        elif input_batch_size > DEFAULT_BATCH_SIZE:
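
With `rows_total` now carried inside `udf_info`, the dispatcher no longer issues its own COUNT query; the per-worker batch size is derived directly from it. A quick sketch of the visible part of that arithmetic (only the branches shown in the hunk above):

```python
def input_batch_size(rows_total: int, n_workers: int) -> int:
    # Even split of rows across workers, with a floor of one row per batch.
    size = rows_total // n_workers
    return size if size > 0 else 1

print(input_batch_size(10, 4))  # 2
print(input_batch_size(3, 8))   # 1
```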
datachain/query/udf.py CHANGED
@@ -1,8 +1,10 @@
-from typing import TYPE_CHECKING, Any, Callable, Optional, TypedDict
+from abc import ABC, abstractmethod
+from typing import TYPE_CHECKING, Any, Callable, Optional, TypedDict, Union
 
 if TYPE_CHECKING:
    from sqlalchemy import Select, Table
 
+   from datachain.catalog import Catalog
    from datachain.query.batch import BatchingStrategy
 
 
@@ -18,3 +20,30 @@ class UdfInfo(TypedDict):
    processes: Optional[int]
    is_generator: bool
    cache: bool
+   rows_total: int
+
+
+class AbstractUDFDistributor(ABC):
+    @abstractmethod
+    def __init__(
+        self,
+        catalog: "Catalog",
+        table: "Table",
+        query: "Select",
+        udf_data: bytes,
+        batching: "BatchingStrategy",
+        workers: Union[bool, int],
+        processes: Union[bool, int],
+        udf_fields: list[str],
+        rows_total: int,
+        use_cache: bool,
+        is_generator: bool = False,
+        min_task_size: Optional[Union[str, int]] = None,
+    ) -> None: ...
+
+    @abstractmethod
+    def __call__(self) -> None: ...
+
+    @staticmethod
+    @abstractmethod
+    def run_worker() -> int: ...
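
Distributed backends are now expected to subclass `AbstractUDFDistributor` and be returned by `get_udf_distributor_class`. A hedged, do-nothing sketch of a conforming subclass (the body is illustrative only, not how datachain's real distributor behaves):

```python
from typing import Optional, Union

from datachain.query.udf import AbstractUDFDistributor  # introduced by the hunk above


class NoopUDFDistributor(AbstractUDFDistributor):
    def __init__(
        self,
        catalog,
        table,
        query,
        udf_data: bytes,
        batching,
        workers: Union[bool, int],
        processes: Union[bool, int],
        udf_fields: list[str],
        rows_total: int,
        use_cache: bool,
        is_generator: bool = False,
        min_task_size: Optional[Union[str, int]] = None,
    ) -> None:
        self.rows_total = rows_total
        self.workers = workers

    def __call__(self) -> None:
        # A real distributor would fan the serialized UDF out to workers here.
        print(f"would distribute {self.rows_total} rows across {self.workers} workers")

    @staticmethod
    def run_worker() -> int:
        return 0
```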
@@ -41,7 +41,7 @@ def train_test_split(
    from datachain.toolkit import train_test_split
 
    # Load a DataChain from a storage source (e.g., S3 bucket)
-   dc = dc.from_storage("s3://bucket/dir/")
+   dc = dc.read_storage("s3://bucket/dir/")
 
    # Perform a 70/30 train-test split
    train, test = train_test_split(dc, [0.7, 0.3])
datachain/utils.py CHANGED
@@ -286,15 +286,41 @@ def retry_with_backoff(retries=5, backoff_sec=1, errors=(Exception,)):
    return retry
 
 
-def determine_processes(parallel: Optional[Union[bool, int]]) -> Union[bool, int]:
+def determine_workers(
+    workers: Union[bool, int],
+    rows_total: Optional[int] = None,
+) -> Union[bool, int]:
+    """Determine the number of workers to use for distributed processing."""
+    if rows_total is not None and rows_total <= 1:
+        # Disable distributed processing if there is no rows or only one row.
+        return False
+    if (
+        workers is False
+        and os.environ.get("DATACHAIN_DISTRIBUTED")
+        and os.environ.get("DATACHAIN_SETTINGS_WORKERS")
+    ):
+        # Enable distributed processing by default if the module is available,
+        # and a default number of workers is provided.
+        workers = int(os.environ["DATACHAIN_SETTINGS_WORKERS"])
+    if not workers or workers <= 0:
+        return False
+    return workers
+
+
+def determine_processes(
+    parallel: Optional[Union[bool, int]] = None,
+    rows_total: Optional[int] = None,
+) -> Union[bool, int]:
+    """Determine the number of processes to use for parallel processing."""
+    if rows_total is not None and rows_total <= 1:
+        # Disable parallel processing if there is no rows or only one row.
+        return False
    if parallel is None and os.environ.get("DATACHAIN_SETTINGS_PARALLEL") is not None:
        parallel = int(os.environ["DATACHAIN_SETTINGS_PARALLEL"])
-   if parallel is None or parallel is False:
+   if parallel is None or parallel is False or parallel == 0:
        return False
    if parallel is True:
        return True
-   if parallel == 0:
-       return False
    if parallel < 0:
        return True
    return parallel
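
A short sketch of how the new signatures behave, based only on the logic shown above and assuming the `DATACHAIN_*` environment variables are unset:

```python
from datachain.utils import determine_processes, determine_workers

# rows_total <= 1 short-circuits both helpers.
print(determine_workers(4, rows_total=1))        # False
print(determine_processes(4, rows_total=1))      # False

# Otherwise the explicit settings win.
print(determine_workers(4, rows_total=100))      # 4
print(determine_workers(False, rows_total=100))  # False
print(determine_processes(-1, rows_total=100))   # True  (negative falls through to the `parallel < 0` branch)
print(determine_processes(0))                    # False (0 is now handled in the first check)
```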