datachain 0.7.11__py3-none-any.whl → 0.8.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

datachain/lib/listing.py CHANGED
@@ -15,6 +15,7 @@ from datachain.utils import uses_glob
 
 if TYPE_CHECKING:
     from datachain.lib.dc import DataChain
+    from datachain.query.session import Session
 
 LISTING_TTL = 4 * 60 * 60  # cached listing lasts 4 hours
 LISTING_PREFIX = "lst__"  # listing datasets start with this name
@@ -108,3 +109,46 @@ def listing_uri_from_name(dataset_name: str) -> str:
     if not is_listing_dataset(dataset_name):
         raise ValueError(f"Dataset {dataset_name} is not a listing")
     return dataset_name.removeprefix(LISTING_PREFIX)
+
+
+def get_listing(
+    uri: str, session: "Session", update: bool = False
+) -> tuple[str, str, str, bool]:
+    """Returns the correct listing dataset name to use when saving a listing
+    operation, taking existing listings and their reusability into account.
+    It also returns a boolean indicating whether the returned dataset name is
+    reused / already exists (on update it always returns False, simply because
+    there was no reason to complicate it so far), and the correct listing path
+    that should be used to find rows based on the uri.
+    """
+    from datachain.client.local import FileClient
+
+    catalog = session.catalog
+    cache = catalog.cache
+    client_config = catalog.client_config
+
+    client = Client.get_client(uri, cache, **client_config)
+    ds_name, list_uri, list_path = parse_listing_uri(uri, cache, client_config)
+    listing = None
+
+    listings = [
+        ls for ls in catalog.listings() if not ls.is_expired and ls.contains(ds_name)
+    ]
+
+    # if there is no need to update, choose the most recent listing;
+    # otherwise, use the exact original `ds_name`:
+    #   - if a "bigger" listing exists, we don't want to update it; it's better
+    #     to create a new "smaller" one on "update=True"
+    #   - if an exact listing exists, it will have the same name as `ds_name`
+    #     anyway below
+    if listings and not update:
+        listing = sorted(listings, key=lambda ls: ls.created_at)[-1]
+
+    # for the local file system we need to fix the listing path / prefix
+    # if we are reusing an existing listing
+    if isinstance(client, FileClient) and listing and listing.name != ds_name:
+        list_path = f'{ds_name.strip("/").removeprefix(listing.name)}/{list_path}'
+
+    ds_name = listing.name if listing else ds_name
+
+    return ds_name, list_uri, list_path, bool(listing)
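
For orientation, a rough usage sketch of the new helper, inferred from the signature above; the storage URI is a placeholder and the way the Session is obtained here is an assumption, not something this diff shows:

    from datachain.lib.listing import get_listing
    from datachain.query.session import Session

    session = Session.get()  # assumption: a default session factory; use your own session setup
    ds_name, list_uri, list_path, exists = get_listing(
        "s3://example-bucket/images/", session, update=False
    )
    if not exists:
        # the caller is expected to (re)create the listing dataset under ds_name
        print(f"would list {list_uri} into dataset {ds_name}")
    print(f"rows for this uri are found under path prefix {list_path!r}")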
@@ -38,38 +38,41 @@ def process_json(data_string, jmespath):
     return json_dict
 
 
-# Print a dynamic datamodel-codegen output from JSON or CSV on stdout
-def read_schema(source_file, data_type="csv", expr=None, model_name=None):
+def gen_datamodel_code(
+    source_file, format="json", jmespath=None, model_name=None
+) -> str:
+    """Generates Python code with Pydantic models that corresponds
+    to the provided JSON, CSV, or JSONL file.
+    It supports root JSON arrays (samples the first entry).
+    """
     data_string = ""
     # using uiid to get around issue #1617
     if not model_name:
         # comply with Python class names
         uid_str = str(generate_uuid()).replace("-", "")
-        model_name = f"Model{data_type}{uid_str}"
-    try:
-        with source_file.open() as fd:  # CSV can be larger than memory
-            if data_type == "csv":
-                data_string += fd.readline().replace("\r", "")
-                data_string += fd.readline().replace("\r", "")
-            elif data_type == "jsonl":
-                data_string = fd.readline().replace("\r", "")
-            else:
-                data_string = fd.read()  # other meta must fit into RAM
-    except OSError as e:
-        print(f"An unexpected file error occurred: {e}")
-        return
-    if data_type in ("json", "jsonl"):
-        json_object = process_json(data_string, expr)
-        if data_type == "json" and isinstance(json_object, list):
+        model_name = f"Model{format}{uid_str}"
+
+    with source_file.open() as fd:  # CSV can be larger than memory
+        if format == "csv":
+            data_string += fd.readline().replace("\r", "")
+            data_string += fd.readline().replace("\r", "")
+        elif format == "jsonl":
+            data_string = fd.readline().replace("\r", "")
+        else:
+            data_string = fd.read()  # other meta must fit into RAM
+
+    if format in ("json", "jsonl"):
+        json_object = process_json(data_string, jmespath)
+        if format == "json" and isinstance(json_object, list):
             json_object = json_object[0]  # sample the 1st object from JSON array
-        if data_type == "jsonl":
-            data_type = "json"  # treat json line as plain JSON in auto-schema
+        if format == "jsonl":
+            format = "json"  # treat json line as plain JSON in auto-schema
         data_string = json.dumps(json_object)
 
     import datamodel_code_generator
 
     input_file_types = {i.value: i for i in datamodel_code_generator.InputFileType}
-    input_file_type = input_file_types[data_type]
+    input_file_type = input_file_types[format]
     with tempfile.TemporaryDirectory() as tmpdir:
         output = Path(tmpdir) / "model.py"
         datamodel_code_generator.generate(
@@ -95,36 +98,29 @@ spec = {model_name}
 def read_meta(  # noqa: C901
     spec=None,
     schema_from=None,
-    meta_type="json",
+    format="json",
     jmespath=None,
-    print_schema=False,
     model_name=None,
     nrows=None,
 ) -> Callable:
     from datachain.lib.dc import DataChain
 
     if schema_from:
-        chain = (
-            DataChain.from_storage(schema_from, type="text")
-            .limit(1)
-            .map(  # dummy column created (#1615)
-                meta_schema=lambda file: read_schema(
-                    file, data_type=meta_type, expr=jmespath, model_name=model_name
-                ),
-                output=str,
-            )
+        file = next(
+            DataChain.from_storage(schema_from, type="text").limit(1).collect("file")
         )
-        (model_output,) = chain.collect("meta_schema")
-        assert isinstance(model_output, str)
-        if print_schema:
-            print(f"{model_output}")
+        model_code = gen_datamodel_code(
+            file, format=format, jmespath=jmespath, model_name=model_name
+        )
+        assert isinstance(model_code, str)
+
         # Below 'spec' should be a dynamically converted DataModel from Pydantic
         if not spec:
            gl = globals()
-            exec(model_output, gl)  # type: ignore[arg-type] # noqa: S102
+            exec(model_code, gl)  # type: ignore[arg-type] # noqa: S102
            spec = gl["spec"]
 
-    if not (spec) and not (schema_from):
+    if not spec and not schema_from:
         raise ValueError(
             "Must provide a static schema in spec: or metadata sample in schema_from:"
         )
@@ -136,7 +132,7 @@ def read_meta(  # noqa: C901
     def parse_data(
         file: File,
         data_model=spec,
-        meta_type=meta_type,
+        format=format,
         jmespath=jmespath,
         nrows=nrows,
     ) -> Iterator[spec]:
@@ -148,7 +144,7 @@ def read_meta(  # noqa: C901
             except ValidationError as e:
                 print(f"Validation error occurred in row {nrow} file {file.name}:", e)
 
-        if meta_type == "csv":
+        if format == "csv":
            with (
                file.open() as fd
            ):  # TODO: if schema is statically given, should allow CSV without headers
@@ -156,7 +152,7 @@
                for row in reader:  # CSV can be larger than memory
                    yield from validator(row)
 
-        if meta_type == "json":
+        if format == "json":
            try:
                with file.open() as fd:  # JSON must fit into RAM
                    data_string = fd.read()
@@ -174,7 +170,7 @@
                    return
            yield from validator(json_dict, nrow)
 
-        if meta_type == "jsonl":
+        if format == "jsonl":
            try:
                nrow = 0
                with file.open() as fd:
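
As a quick illustration of the renamed helper, the sketch below samples one file the same way the new read_meta body does and passes it to gen_datamodel_code; the storage path and model name are placeholders, and gen_datamodel_code is assumed to be importable from the module patched above:

    from datachain.lib.dc import DataChain
    # assumption: gen_datamodel_code imported from the module this hunk modifies

    file = next(
        DataChain.from_storage("s3://example-bucket/sample.jsonl", type="text")
        .limit(1)
        .collect("file")
    )
    model_code = gen_datamodel_code(file, format="jsonl", model_name="SampleModel")
    # model_code is Python source defining Pydantic models; per the
    # `spec = {model_name}` template visible in the hunk context above,
    # it also binds the generated model to a `spec` variable.
    print(model_code)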
datachain/lib/udf.py CHANGED
@@ -85,7 +85,6 @@ class UDFAdapter:
         udf_fields: "Sequence[str]",
         udf_inputs: "Iterable[RowsOutput]",
         catalog: "Catalog",
-        is_generator: bool,
         cache: bool,
         download_cb: Callback = DEFAULT_CALLBACK,
         processed_cb: Callback = DEFAULT_CALLBACK,
datachain/query/batch.py CHANGED
@@ -7,6 +7,7 @@ from typing import TYPE_CHECKING, Callable, Optional, Union
 
 from datachain.data_storage.schema import PARTITION_COLUMN_ID
 from datachain.data_storage.warehouse import SELECT_BATCH_SIZE
+from datachain.query.utils import get_query_column, get_query_id_column
 
 if TYPE_CHECKING:
     from sqlalchemy import Select
@@ -23,11 +24,14 @@ RowsOutput = Union[Sequence, RowsOutputBatch]
 class BatchingStrategy(ABC):
     """BatchingStrategy provides means of batching UDF executions."""
 
+    is_batching: bool
+
     @abstractmethod
     def __call__(
         self,
-        execute: Callable[..., Generator[Sequence, None, None]],
+        execute: Callable,
         query: "Select",
+        ids_only: bool = False,
     ) -> Generator[RowsOutput, None, None]:
         """Apply the provided parameters to the UDF."""
 
@@ -38,11 +42,16 @@ class NoBatching(BatchingStrategy):
     batch UDF calls.
     """
 
+    is_batching = False
+
     def __call__(
         self,
-        execute: Callable[..., Generator[Sequence, None, None]],
+        execute: Callable,
         query: "Select",
+        ids_only: bool = False,
     ) -> Generator[Sequence, None, None]:
+        if ids_only:
+            query = query.with_only_columns(get_query_id_column(query))
         return execute(query)
 
 
@@ -52,14 +61,20 @@ class Batch(BatchingStrategy):
     is passed a sequence of multiple parameter sets.
     """
 
+    is_batching = True
+
     def __init__(self, count: int):
         self.count = count
 
     def __call__(
         self,
-        execute: Callable[..., Generator[Sequence, None, None]],
+        execute: Callable,
         query: "Select",
+        ids_only: bool = False,
     ) -> Generator[RowsOutputBatch, None, None]:
+        if ids_only:
+            query = query.with_only_columns(get_query_id_column(query))
+
         # choose page size that is a multiple of the batch size
         page_size = math.ceil(SELECT_BATCH_SIZE / self.count) * self.count
 
@@ -84,19 +99,30 @@ class Partition(BatchingStrategy):
     Dataset rows need to be sorted by the grouping column.
     """
 
+    is_batching = True
+
     def __call__(
         self,
-        execute: Callable[..., Generator[Sequence, None, None]],
+        execute: Callable,
         query: "Select",
+        ids_only: bool = False,
     ) -> Generator[RowsOutputBatch, None, None]:
+        id_col = get_query_id_column(query)
+        if (partition_col := get_query_column(query, PARTITION_COLUMN_ID)) is None:
+            raise RuntimeError("partition column not found in query")
+
+        if ids_only:
+            query = query.with_only_columns(id_col, partition_col)
+
         current_partition: Optional[int] = None
         batch: list[Sequence] = []
 
         query_fields = [str(c.name) for c in query.selected_columns]
+        id_column_idx = query_fields.index("sys__id")
         partition_column_idx = query_fields.index(PARTITION_COLUMN_ID)
 
         ordered_query = query.order_by(None).order_by(
-            PARTITION_COLUMN_ID,
+            partition_col,
            *query._order_by_clauses,
        )
 
@@ -108,7 +134,7 @@
                    if len(batch) > 0:
                        yield RowsOutputBatch(batch)
                        batch = []
-                batch.append(row)
+                batch.append([row[id_column_idx]] if ids_only else row)
 
        if len(batch) > 0:
            yield RowsOutputBatch(batch)
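
To make the new ids_only flag concrete, here is a minimal sketch of the query narrowing the strategies above perform, written with plain SQLAlchemy; the table and columns are invented for illustration, and get_query_id_column is assumed to return the query's sys__id column, as the surrounding code suggests:

    import sqlalchemy as sa
    from datachain.query.utils import get_query_id_column

    metadata = sa.MetaData()
    rows = sa.Table(
        "rows",
        metadata,
        sa.Column("sys__id", sa.Integer),
        sa.Column("val", sa.Text),
    )
    query = sa.select(rows)

    # the same narrowing the batching strategies apply when ids_only=True
    ids_query = query.with_only_columns(get_query_id_column(query))
    print(ids_query)  # roughly: SELECT rows.sys__id FROM rows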
@@ -43,8 +43,9 @@ from datachain.data_storage.schema import (
 from datachain.dataset import DatasetStatus, RowDict
 from datachain.error import DatasetNotFoundError, QueryScriptCancelError
 from datachain.func.base import Function
-from datachain.lib.udf import UDFAdapter
 from datachain.progress import CombinedDownloadCallback
+from datachain.query.schema import C, UDFParamSpec, normalize_param
+from datachain.query.session import Session
 from datachain.sql.functions.random import rand
 from datachain.utils import (
     batched,
@@ -53,9 +54,6 @@ from datachain.utils import (
     get_datachain_executable,
 )
 
-from .schema import C, UDFParamSpec, normalize_param
-from .session import Session
-
 if TYPE_CHECKING:
     from sqlalchemy.sql.elements import ClauseElement
     from sqlalchemy.sql.schema import Table
@@ -65,7 +63,8 @@ if TYPE_CHECKING:
     from datachain.catalog import Catalog
     from datachain.data_storage import AbstractWarehouse
     from datachain.dataset import DatasetRecord
-    from datachain.lib.udf import UDFResult
+    from datachain.lib.udf import UDFAdapter, UDFResult
+    from datachain.query.udf import UdfInfo
 
 P = ParamSpec("P")
 
@@ -301,7 +300,7 @@ def adjust_outputs(
     return row
 
 
-def get_udf_col_types(warehouse: "AbstractWarehouse", udf: UDFAdapter) -> list[tuple]:
+def get_udf_col_types(warehouse: "AbstractWarehouse", udf: "UDFAdapter") -> list[tuple]:
     """Optimization: Precompute UDF column types so these don't have to be computed
     in the convert_type function for each row in a loop."""
     dialect = warehouse.db.dialect
@@ -322,7 +321,7 @@ def process_udf_outputs(
     warehouse: "AbstractWarehouse",
     udf_table: "Table",
     udf_results: Iterator[Iterable["UDFResult"]],
-    udf: UDFAdapter,
+    udf: "UDFAdapter",
     batch_size: int = INSERT_BATCH_SIZE,
     cb: Callback = DEFAULT_CALLBACK,
 ) -> None:
@@ -347,6 +346,8 @@ def process_udf_outputs(
     for row_chunk in batched(rows, batch_size):
         warehouse.insert_rows(udf_table, row_chunk)
 
+    warehouse.insert_rows_done(udf_table)
+
 
 def get_download_callback() -> Callback:
     return CombinedDownloadCallback(
@@ -366,7 +367,7 @@ def get_generated_callback(is_generator: bool = False) -> Callback:
 
 @frozen
 class UDFStep(Step, ABC):
-    udf: UDFAdapter
+    udf: "UDFAdapter"
     catalog: "Catalog"
     partition_by: Optional[PartitionByType] = None
     parallel: Optional[int] = None
@@ -440,7 +441,7 @@ class UDFStep(Step, ABC):
                 raise RuntimeError(
                     "In-memory databases cannot be used with parallel processing."
                 )
-            udf_info = {
+            udf_info: UdfInfo = {
                "udf_data": filtered_cloudpickle_dumps(self.udf),
                "catalog_init": self.catalog.get_init_params(),
                "metastore_clone_params": self.catalog.metastore.clone_params(),
@@ -464,8 +465,8 @@ class UDFStep(Step, ABC):
 
            with subprocess.Popen(cmd, env=envs, stdin=subprocess.PIPE) as process:  # noqa: S603
                process.communicate(process_data)
-                if process.poll():
-                    raise RuntimeError("UDF Execution Failed!")
+                if retval := process.poll():
+                    raise RuntimeError(f"UDF Execution Failed! Exit code: {retval}")
        else:
            # Otherwise process single-threaded (faster for smaller UDFs)
            warehouse = self.catalog.warehouse
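
The reworked error path uses the walrus operator to capture the child's exit status; a minimal standalone illustration of the same pattern (a throwaway command, not datachain's actual worker invocation):

    import subprocess

    # `false` exits with status 1, so the captured retval is truthy and raises
    with subprocess.Popen(["false"], stdin=subprocess.PIPE) as process:
        process.communicate(b"")
        if retval := process.poll():
            raise RuntimeError(f"UDF Execution Failed! Exit code: {retval}")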
@@ -479,7 +480,6 @@ class UDFStep(Step, ABC):
                    udf_fields,
                    udf_inputs,
                    self.catalog,
-                    self.is_generator,
                    self.cache,
                    download_cb,
                    processed_cb,
@@ -496,8 +496,6 @@ class UDFStep(Step, ABC):
                    processed_cb.close()
                    generated_cb.close()
 
-                warehouse.insert_rows_done(udf_table)
-
        except QueryScriptCancelError:
            self.catalog.warehouse.close()
            sys.exit(QUERY_SCRIPT_CANCELED_EXIT_CODE)
@@ -1069,6 +1067,7 @@ class DatasetQuery:
         if "sys__id" in self.column_types:
             self.column_types.pop("sys__id")
         self.starting_step = QueryStep(self.catalog, name, self.version)
+        self.dialect = self.catalog.warehouse.db.dialect
 
     def __iter__(self):
         return iter(self.db_results())
@@ -1490,7 +1489,7 @@
     @detach
     def add_signals(
         self,
-        udf: UDFAdapter,
+        udf: "UDFAdapter",
         parallel: Optional[int] = None,
         workers: Union[bool, int] = False,
         min_task_size: Optional[int] = None,
@@ -1534,7 +1533,7 @@
     @detach
     def generate(
         self,
-        udf: UDFAdapter,
+        udf: "UDFAdapter",
         parallel: Optional[int] = None,
         workers: Union[bool, int] = False,
         min_task_size: Optional[int] = None,
@@ -1616,7 +1615,9 @@
         )
         version = version or dataset.latest_version
 
-        self.session.add_dataset_version(dataset=dataset, version=version)
+        self.session.add_dataset_version(
+            dataset=dataset, version=version, listing=kwargs.get("listing", False)
+        )
 
         dr = self.catalog.warehouse.dataset_rows(dataset)