datachain 0.3.16__py3-none-any.whl → 0.3.18__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release. This version of datachain might be problematic.

@@ -3,7 +3,6 @@ import inspect
  import logging
  import os
  import random
- import re
  import string
  import subprocess
  import sys
@@ -36,7 +35,6 @@ from sqlalchemy.sql.selectable import Select

  from datachain.asyn import ASYNC_WORKERS, AsyncMapper, OrderedMapper
  from datachain.catalog import QUERY_SCRIPT_CANCELED_EXIT_CODE, get_catalog
- from datachain.client import Client
  from datachain.data_storage.schema import (
      PARTITION_COLUMN_ID,
      partition_col_names,
@@ -46,7 +44,6 @@ from datachain.dataset import DatasetStatus, RowDict
  from datachain.error import DatasetNotFoundError, QueryScriptCancelError
  from datachain.progress import CombinedDownloadCallback
  from datachain.sql.functions import rand
- from datachain.storage import Storage, StorageURI
  from datachain.utils import (
      batched,
      determine_processes,
@@ -56,7 +53,7 @@ from datachain.utils import (

  from .schema import C, UDFParamSpec, normalize_param
  from .session import Session
- from .udf import UDFBase, UDFClassWrapper, UDFFactory, UDFType
+ from .udf import UDFBase

  if TYPE_CHECKING:
      from sqlalchemy.sql.elements import ClauseElement
@@ -77,9 +74,7 @@ INSERT_BATCH_SIZE = 10000

  PartitionByType = Union[ColumnElement, Sequence[ColumnElement]]
  JoinPredicateType = Union[str, ColumnClause, ColumnElement]
- # dependency can be either dataset_name + dataset_version tuple or just storage uri
- # depending what type of dependency we are adding
- DatasetDependencyType = Union[tuple[str, int], StorageURI]
+ DatasetDependencyType = tuple[str, int]

  logger = logging.getLogger("datachain")

@@ -185,38 +180,6 @@ class QueryStep(StartingStep):
          )


- @frozen
- class IndexingStep(StartingStep):
-     path: str
-     catalog: "Catalog"
-     kwargs: dict[str, Any]
-     recursive: Optional[bool] = True
-
-     def apply(self):
-         self.catalog.index([self.path], **self.kwargs)
-         uri, path = Client.parse_url(self.path)
-         _partial_id, partial_path = self.catalog.metastore.get_valid_partial_id(
-             uri, path
-         )
-         dataset = self.catalog.get_dataset(Storage.dataset_name(uri, partial_path))
-         dataset_rows = self.catalog.warehouse.dataset_rows(
-             dataset, dataset.latest_version
-         )
-
-         def q(*columns):
-             col_names = [c.name for c in columns]
-             return self.catalog.warehouse.nodes_dataset_query(
-                 dataset_rows,
-                 column_names=col_names,
-                 path=path,
-                 recursive=self.recursive,
-             )
-
-         storage = self.catalog.metastore.get_storage(uri)
-
-         return step_result(q, dataset_rows.c, dependencies=[storage.uri])
-
-
  def generator_then_call(generator, func: Callable):
      """
      Yield items from generator then execute a function and yield
@@ -230,7 +193,7 @@ def generator_then_call(generator, func: Callable):
  class DatasetDiffOperation(Step):
      """
      Abstract class for operations that are calculation some kind of diff between
-     datasets queries like subtract, changed etc.
+     datasets queries like subtract etc.
      """

      dq: "DatasetQuery"
@@ -304,28 +267,6 @@ class Subtract(DatasetDiffOperation):
          return sq.select().except_(sq.select().where(where_clause))


- @frozen
- class Changed(DatasetDiffOperation):
-     """
-     Calculates rows that are changed in a source query compared to target query
-     Changed means it has same source + path but different last_modified
-     Example:
-         >>> ds = DatasetQuery(name="dogs_cats") # some older dataset with embeddings
-         >>> ds_updated = (
-             DatasetQuery("gs://dvcx-datalakes/dogs-and-cats")
-             .filter(C.size > 1000) # we can also filter out source query
-             .changed(ds)
-             .add_signals(calc_embeddings) # calculae embeddings only on changed rows
-             .union(ds) # union with old dataset that's missing updated rows
-             .save("dogs_cats_updated")
-         )
-
-     """
-
-     def query(self, source_query: Select, target_query: Select) -> Select:
-         return self.catalog.warehouse.changed_query(source_query, target_query)
-
-
  def adjust_outputs(
      warehouse: "AbstractWarehouse", row: dict[str, Any], udf_col_types: list[tuple]
  ) -> dict[str, Any]:
@@ -423,7 +364,7 @@ def get_generated_callback(is_generator: bool = False) -> Callback:

  @frozen
  class UDFStep(Step, ABC):
-     udf: UDFType
+     udf: UDFBase
      catalog: "Catalog"
      partition_by: Optional[PartitionByType] = None
      parallel: Optional[int] = None
@@ -529,12 +470,6 @@ class UDFStep(Step, ABC):

          else:
              # Otherwise process single-threaded (faster for smaller UDFs)
-             # Optionally instantiate the UDF instance if a class is provided.
-             if isinstance(self.udf, UDFFactory):
-                 udf: UDFBase = self.udf()
-             else:
-                 udf = self.udf
-
              warehouse = self.catalog.warehouse

              with contextlib.closing(
@@ -544,7 +479,7 @@ class UDFStep(Step, ABC):
              processed_cb = get_processed_callback()
              generated_cb = get_generated_callback(self.is_generator)
              try:
-                 udf_results = udf.run(
+                 udf_results = self.udf.run(
                      udf_fields,
                      udf_inputs,
                      self.catalog,
@@ -557,7 +492,7 @@ class UDFStep(Step, ABC):
                      warehouse,
                      udf_table,
                      udf_results,
-                     udf,
+                     self.udf,
                      cb=generated_cb,
                  )
              finally:
@@ -1096,28 +1031,14 @@ class ResultIter:
  class DatasetQuery:
      def __init__(
          self,
-         path: str = "",
-         name: str = "",
+         name: str,
          version: Optional[int] = None,
          catalog: Optional["Catalog"] = None,
-         client_config=None,
-         recursive: Optional[bool] = True,
          session: Optional[Session] = None,
-         anon: bool = False,
-         indexing_feature_schema: Optional[dict] = None,
          indexing_column_types: Optional[dict[str, Any]] = None,
-         update: Optional[bool] = False,
          in_memory: bool = False,
      ):
-         if client_config is None:
-             client_config = {}
-
-         if anon:
-             client_config["anon"] = True
-
-         self.session = Session.get(
-             session, catalog=catalog, client_config=client_config, in_memory=in_memory
-         )
+         self.session = Session.get(session, catalog=catalog, in_memory=in_memory)
          self.catalog = catalog or self.session.catalog
          self.steps: list[Step] = []
          self._chunk_index: Optional[int] = None
@@ -1131,26 +1052,14 @@ class DatasetQuery:
          self.feature_schema: Optional[dict] = None
          self.column_types: Optional[dict[str, Any]] = None

-         if path:
-             kwargs = {"update": True} if update else {}
-             self.starting_step = IndexingStep(path, self.catalog, kwargs, recursive)
-             self.feature_schema = indexing_feature_schema
-             self.column_types = indexing_column_types
-         elif name:
-             self.name = name
-             ds = self.catalog.get_dataset(name)
-             self.version = version or ds.latest_version
-             self.feature_schema = ds.get_version(self.version).feature_schema
-             self.column_types = copy(ds.schema)
-             if "sys__id" in self.column_types:
-                 self.column_types.pop("sys__id")
-             self.starting_step = QueryStep(self.catalog, name, self.version)
-         else:
-             raise ValueError("must provide path or name")
-
-     @staticmethod
-     def is_storage_path(path):
-         return bool(re.compile(r"^[a-zA-Z0-9]+://").match(path))
+         self.name = name
+         ds = self.catalog.get_dataset(name)
+         self.version = version or ds.latest_version
+         self.feature_schema = ds.get_version(self.version).feature_schema
+         self.column_types = copy(ds.schema)
+         if "sys__id" in self.column_types:
+             self.column_types.pop("sys__id")
+         self.starting_step = QueryStep(self.catalog, name, self.version)

      def __iter__(self):
          return iter(self.db_results())
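Note on the constructor change above: DatasetQuery no longer accepts a storage path (or the related client_config/anon/update/recursive arguments); a saved dataset name is now the required first argument. A minimal sketch of new-style construction, assuming the module path and a previously saved dataset called "dogs" (both placeholders, not taken from this diff):

    from datachain.query.dataset import DatasetQuery  # import path is an assumption

    dq = DatasetQuery("dogs")            # name is now required and positional
    # DatasetQuery("s3://bucket/path")   # 0.3.16-style path indexing is gone
    for row in dq:                       # __iter__ still delegates to db_results()
        print(row)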
@@ -1556,7 +1465,7 @@ class DatasetQuery:
      @detach
      def add_signals(
          self,
-         udf: UDFType,
+         udf: UDFBase,
          parallel: Optional[int] = None,
          workers: Union[bool, int] = False,
          min_task_size: Optional[int] = None,
@@ -1577,9 +1486,6 @@ class DatasetQuery:
          at least that minimum number of rows to each distributed worker, mostly useful
          if there are a very large number of small tasks to process.
          """
-         if isinstance(udf, UDFClassWrapper):  # type: ignore[unreachable]
-             # This is a bare decorated class, "instantiate" it now.
-             udf = udf()  # type: ignore[unreachable]
          query = self.clone()
          query.steps.append(
              UDFSignal(
@@ -1595,34 +1501,21 @@ class DatasetQuery:
          return query

      @detach
-     def subtract(self, dq: "DatasetQuery") -> "Self":
-         return self._subtract(dq, on=[("source", "source"), ("path", "path")])
-
-     @detach
-     def _subtract(self, dq: "DatasetQuery", on: Sequence[tuple[str, str]]) -> "Self":
+     def subtract(self, dq: "DatasetQuery", on: Sequence[tuple[str, str]]) -> "Self":
          query = self.clone()
          query.steps.append(Subtract(dq, self.catalog, on=on))
          return query

-     @detach
-     def changed(self, dq: "DatasetQuery") -> "Self":
-         query = self.clone()
-         query.steps.append(Changed(dq, self.catalog))
-         return query
-
      @detach
      def generate(
          self,
-         udf: UDFType,
+         udf: UDFBase,
          parallel: Optional[int] = None,
          workers: Union[bool, int] = False,
          min_task_size: Optional[int] = None,
          partition_by: Optional[PartitionByType] = None,
          cache: bool = False,
      ) -> "Self":
-         if isinstance(udf, UDFClassWrapper):  # type: ignore[unreachable]
-             # This is a bare decorated class, "instantiate" it now.
-             udf = udf()  # type: ignore[unreachable]
          query = self.clone()
          steps = query.steps
          steps.append(
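Because the old subtract() wrapper and its source/path default are removed above, callers now pass the match keys explicitly. A hedged sketch of the equivalent call under the new signature (dataset names are placeholders):

    # Reproduces the previous implicit behaviour by naming the old default keys.
    diff = DatasetQuery("dogs_new").subtract(
        DatasetQuery("dogs_old"),
        on=[("source", "source"), ("path", "path")],
    )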
@@ -1640,24 +1533,13 @@ class DatasetQuery:

      def _add_dependencies(self, dataset: "DatasetRecord", version: int):
          for dependency in self.dependencies:
-             if isinstance(dependency, tuple):
-                 # dataset dependency
-                 ds_dependency_name, ds_dependency_version = dependency
-                 self.catalog.metastore.add_dataset_dependency(
-                     dataset.name,
-                     version,
-                     ds_dependency_name,
-                     ds_dependency_version,
-                 )
-             else:
-                 # storage dependency - its name is a valid StorageURI
-                 storage = self.catalog.metastore.get_storage(dependency)
-                 self.catalog.metastore.add_storage_dependency(
-                     StorageURI(dataset.name),
-                     version,
-                     storage.uri,
-                     storage.timestamp_str,
-                 )
+             ds_dependency_name, ds_dependency_version = dependency
+             self.catalog.metastore.add_dataset_dependency(
+                 dataset.name,
+                 version,
+                 ds_dependency_name,
+                 ds_dependency_version,
+             )

      def exec(self) -> "Self":
          """Execute the query."""
@@ -27,7 +27,7 @@ from datachain.query.queue import (
      put_into_queue,
      unmarshal,
  )
- from datachain.query.udf import UDFBase, UDFFactory, UDFResult
+ from datachain.query.udf import UDFBase, UDFResult
  from datachain.utils import batched_it

  DEFAULT_BATCH_SIZE = 10000
@@ -156,8 +156,6 @@ class UDFDispatcher:

      @property
      def batch_size(self):
-         if not self.udf:
-             self.udf = self.udf_factory()
          if self._batch_size is None:
              if hasattr(self.udf, "properties") and hasattr(
                  self.udf.properties, "batch"
@@ -181,18 +179,7 @@ class UDFDispatcher:
          self.catalog = Catalog(
              id_generator, metastore, warehouse, **self.catalog_init_params
          )
-         udf = loads(self.udf_data)
-         # isinstance cannot be used here, as cloudpickle packages the entire class
-         # definition, and so these two types are not considered exactly equal,
-         # even if they have the same import path.
-         if full_module_type_path(type(udf)) != full_module_type_path(UDFFactory):
-             self.udf = udf
-         else:
-             self.udf = None
-             self.udf_factory = udf
-         if not self.udf:
-             self.udf = self.udf_factory()
-
+         self.udf = loads(self.udf_data)
          return UDFWorker(
              self.catalog,
              self.udf,
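With the UDFFactory branch removed, UDFDispatcher assumes udf_data deserializes directly into a ready-to-run UDF object. A rough sketch of that round trip, assuming the payload is produced with cloudpickle (the deleted comment in this hunk mentions cloudpickle, but the serializer actually used for udf_data is not shown in this diff):

    import cloudpickle  # assumption: udf_data is a cloudpickle payload

    class MyUDF:                          # stand-in for a UDFBase subclass
        def run(self, *args, **kwargs):
            return []

    udf_data = cloudpickle.dumps(MyUDF())
    udf = cloudpickle.loads(udf_data)     # what `self.udf = loads(self.udf_data)` now yields
    print(type(udf).__name__)             # -> MyUDF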
datachain/query/schema.py CHANGED
@@ -9,6 +9,7 @@ import attrs
  import sqlalchemy as sa
  from fsspec.callbacks import DEFAULT_CALLBACK, Callback

+ from datachain.lib.file import File
  from datachain.sql.types import JSON, Boolean, DateTime, Int64, SQLType, String

  if TYPE_CHECKING:
@@ -19,6 +20,17 @@ if TYPE_CHECKING:
  DEFAULT_DELIMITER = "__"


+ def file_signals(row, signal_name="file"):
+     # TODO this is workaround until we decide what to do with these classes
+     prefix = f"{signal_name}{DEFAULT_DELIMITER}"
+     return {
+         c_name.removeprefix(prefix): c_value
+         for c_name, c_value in row.items()
+         if c_name.startswith(prefix)
+         and DEFAULT_DELIMITER not in c_name.removeprefix(prefix)
+     }
+
+
  class ColumnMeta(type):
      @staticmethod
      def to_db_name(name: str) -> str:
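The new file_signals() helper strips the "file__" prefix from flat row columns, dropping nested signals, so the result can be handed to File._from_row() in the hunks below. A small illustration on a made-up row dict (column names are examples, not the full schema):

    row = {
        "file__source": "s3://bucket",
        "file__path": "images/cat.jpg",
        "file__location__offset": 0,   # nested signal: contains another "__", so excluded
        "other__source": "ignored",    # different prefix: also excluded
    }
    print(file_signals(row))
    # -> {'source': 's3://bucket', 'path': 'images/cat.jpg'}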
@@ -86,11 +98,11 @@ class Object(UDFParameter):
          cb: Callback = DEFAULT_CALLBACK,
          **kwargs,
      ) -> Any:
-         client = catalog.get_client(row["source"])
-         uid = catalog._get_row_uid(row)
+         file = File._from_row(file_signals(row))
+         client = catalog.get_client(file.source)
          if cache:
-             client.download(uid, callback=cb)
-         with client.open_object(uid, use_cache=cache, cb=cb) as f:
+             client.download(file, callback=cb)
+         with client.open_object(file, use_cache=cache, cb=cb) as f:
              return self.reader(f)

      async def get_value_async(
@@ -103,12 +115,12 @@ class Object(UDFParameter):
          cb: Callback = DEFAULT_CALLBACK,
          **kwargs,
      ) -> Any:
-         client = catalog.get_client(row["source"])
-         uid = catalog._get_row_uid(row)
+         file = File._from_row(file_signals(row))
+         client = catalog.get_client(file.source)
          if cache:
-             await client._download(uid, callback=cb)
+             await client._download(file, callback=cb)
          obj = await mapper.to_thread(
-             functools.partial(client.open_object, uid, use_cache=cache, cb=cb)
+             functools.partial(client.open_object, file, use_cache=cache, cb=cb)
          )
          with obj:
              return await mapper.to_thread(self.reader, obj)
@@ -129,11 +141,11 @@ class Stream(UDFParameter):
          cb: Callback = DEFAULT_CALLBACK,
          **kwargs,
      ) -> Any:
-         client = catalog.get_client(row["source"])
-         uid = catalog._get_row_uid(row)
+         file = File._from_row(file_signals(row))
+         client = catalog.get_client(file.source)
          if cache:
-             client.download(uid, callback=cb)
-         return client.open_object(uid, use_cache=cache, cb=cb)
+             client.download(file, callback=cb)
+         return client.open_object(file, use_cache=cache, cb=cb)

      async def get_value_async(
          self,
@@ -145,12 +157,12 @@ class Stream(UDFParameter):
          cb: Callback = DEFAULT_CALLBACK,
          **kwargs,
      ) -> Any:
-         client = catalog.get_client(row["source"])
-         uid = catalog._get_row_uid(row)
+         file = File._from_row(file_signals(row))
+         client = catalog.get_client(file.source)
          if cache:
-             await client._download(uid, callback=cb)
+             await client._download(file, callback=cb)
          return await mapper.to_thread(
-             functools.partial(client.open_object, uid, use_cache=cache, cb=cb)
+             functools.partial(client.open_object, file, use_cache=cache, cb=cb)
          )


@@ -178,10 +190,10 @@ class LocalFilename(UDFParameter):
              # If the glob pattern is specified and the row filename
              # does not match it, then return None
              return None
-         client = catalog.get_client(row["source"])
-         uid = catalog._get_row_uid(row)
-         client.download(uid, callback=cb)
-         return client.cache.get_path(uid)
+         file = File._from_row(file_signals(row))
+         client = catalog.get_client(file.source)
+         client.download(file, callback=cb)
+         return client.cache.get_path(file)

      async def get_value_async(
          self,
@@ -197,10 +209,10 @@ class LocalFilename(UDFParameter):
              # If the glob pattern is specified and the row filename
              # does not match it, then return None
              return None
-         client = catalog.get_client(row["source"])
-         uid = catalog._get_row_uid(row)
-         await client._download(uid, callback=cb)
-         return client.cache.get_path(uid)
+         file = File._from_row(file_signals(row))
+         client = catalog.get_client(file.source)
+         await client._download(file, callback=cb)
+         return client.cache.get_path(file)


  UDFParamSpec = Union[str, Column, UDFParameter]
datachain/query/udf.py CHANGED
@@ -1,14 +1,9 @@
  import typing
- from collections.abc import Iterable, Iterator, Mapping, Sequence
+ from collections.abc import Iterable, Iterator, Sequence
  from dataclasses import dataclass
- from functools import WRAPPER_ASSIGNMENTS
- from inspect import isclass
  from typing import (
      TYPE_CHECKING,
      Any,
-     Callable,
-     Optional,
-     Union,
  )

  from fsspec.callbacks import DEFAULT_CALLBACK, Callback
@@ -23,11 +18,7 @@ from .batch import (
      RowsOutputBatch,
      UDFInputBatch,
  )
- from .schema import (
-     UDFParameter,
-     UDFParamSpec,
-     normalize_param,
- )
+ from .schema import UDFParameter

  if TYPE_CHECKING:
      from datachain.catalog import Catalog
@@ -66,41 +57,6 @@ class UDFProperties:
          return self.output.keys()


- def udf(
-     params: Sequence[UDFParamSpec],
-     output: UDFOutputSpec,
-     *,
-     method: Optional[str] = None,  # only used for class-based UDFs
-     batch: int = 1,
- ):
-     """
-     Decorate a function or a class to be used as a UDF.
-
-     The decorator expects both the outputs and inputs of the UDF to be specified.
-     The outputs are defined as a collection of tuples containing the signal name
-     and type.
-     Parameters are defined as a list of column objects (e.g. C.name).
-     Optionally, UDFs can be run on batches of rows to improve performance, this
-     is determined by the 'batch' parameter. When operating on batches of inputs,
-     the UDF function will be called with a single argument - a list
-     of tuples containing inputs (e.g. ((input1_a, input1_b), (input2_a, input2b))).
-     """
-     if isinstance(params, str):
-         params = (params,)
-     if not isinstance(output, Mapping):
-         raise TypeError(f"'output' must be a mapping, got {type(output).__name__}")
-
-     properties = UDFProperties([normalize_param(p) for p in params], output, batch)
-
-     def decorator(udf_base: Union[Callable, type]):
-         if isclass(udf_base):
-             return UDFClassWrapper(udf_base, properties, method=method)
-         if callable(udf_base):
-             return UDFWrapper(udf_base, properties)
-
-     return decorator
-
-
  class UDFBase:
      """A base class for implementing stateful UDFs."""

@@ -168,105 +124,3 @@ class UDFBase:
              for row_id, signals in zip(row_ids, results)
              if signals is not None  # skip rows with no output
          ]
-
-
- class UDFClassWrapper:
-     """
-     A wrapper for class-based (stateful) UDFs.
-     """
-
-     def __init__(
-         self,
-         udf_class: type,
-         properties: UDFProperties,
-         method: Optional[str] = None,
-     ):
-         self.udf_class = udf_class
-         self.udf_method = method
-         self.properties = properties
-         self.output = properties.output
-
-     def __call__(self, *args, **kwargs) -> "UDFFactory":
-         return UDFFactory(
-             self.udf_class,
-             args,
-             kwargs,
-             self.properties,
-             self.udf_method,
-         )
-
-
- class UDFWrapper(UDFBase):
-     """A wrapper class for function UDFs to be used in custom signal generation."""
-
-     def __init__(
-         self,
-         func: Callable,
-         properties: UDFProperties,
-     ):
-         self.func = func
-         super().__init__(properties)
-         # This emulates the behavior of functools.wraps for a class decorator
-         for attr in WRAPPER_ASSIGNMENTS:
-             if hasattr(func, attr):
-                 setattr(self, attr, getattr(func, attr))
-
-     def run_once(
-         self,
-         catalog: "Catalog",
-         arg: "UDFInput",
-         is_generator: bool = False,
-         cache: bool = False,
-         cb: Callback = DEFAULT_CALLBACK,
-     ) -> Iterable[UDFResult]:
-         if isinstance(arg, UDFInputBatch):
-             udf_inputs = [
-                 self.bind_parameters(catalog, row, cache=cache, cb=cb)
-                 for row in arg.rows
-             ]
-             udf_outputs = self.func(udf_inputs)
-             return self._process_results(arg.rows, udf_outputs, is_generator)
-         if isinstance(arg, RowDict):
-             udf_inputs = self.bind_parameters(catalog, arg, cache=cache, cb=cb)
-             udf_outputs = self.func(*udf_inputs)
-             if not is_generator:
-                 # udf_outputs is generator already if is_generator=True
-                 udf_outputs = [udf_outputs]
-             return self._process_results([arg], udf_outputs, is_generator)
-         raise ValueError(f"Unexpected UDF argument: {arg}")
-
-     # This emulates the behavior of functools.wraps for a class decorator
-     def __repr__(self):
-         return repr(self.func)
-
-
- class UDFFactory:
-     """
-     A wrapper for late instantiation of UDF classes, primarily for use in parallelized
-     execution.
-     """
-
-     def __init__(
-         self,
-         udf_class: type,
-         args,
-         kwargs,
-         properties: UDFProperties,
-         method: Optional[str] = None,
-     ):
-         self.udf_class = udf_class
-         self.udf_method = method
-         self.args = args
-         self.kwargs = kwargs
-         self.properties = properties
-         self.output = properties.output
-
-     def __call__(self) -> UDFWrapper:
-         udf_func = self.udf_class(*self.args, **self.kwargs)
-         if self.udf_method:
-             udf_func = getattr(udf_func, self.udf_method)
-
-         return UDFWrapper(udf_func, self.properties)
-
-
- UDFType = Union[UDFBase, UDFFactory]
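The @udf decorator deleted above was the old function-style entry point; only the UDFBase class remains in this module. A hedged reconstruction of the removed usage, based on the deleted docstring — the import path and column/type names are illustrative, not taken from this diff:

    # 0.3.16-style code that stops working after this change.
    from datachain.query import C, udf        # import path is an assumption
    from datachain.sql.types import Int64

    @udf(params=(C.name,), output={"name_len": Int64})
    def name_len(name):
        return (len(name),)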
datachain/sql/types.py CHANGED
@@ -12,11 +12,11 @@ for sqlite we can use `sqlite.register_converter`
  ( https://docs.python.org/3/library/sqlite3.html#sqlite3.register_converter )
  """

- import json
  from datetime import datetime
  from types import MappingProxyType
  from typing import Any, Union

+ import orjson
  import sqlalchemy as sa
  from sqlalchemy import TypeDecorator, types

@@ -312,7 +312,7 @@ class Array(SQLType):
      def on_read_convert(self, value, dialect):
          r = read_converter(dialect).array(value, self.item_type, dialect)
          if isinstance(self.item_type, JSON):
-             r = [json.loads(item) if isinstance(item, str) else item for item in r]
+             r = [orjson.loads(item) if isinstance(item, str) else item for item in r]
          return r


@@ -420,6 +420,8 @@ class TypeReadConverter:
          return [item_type.on_read_convert(x, dialect) for x in value]

      def json(self, value):
+         if isinstance(value, str):
+             return orjson.loads(value)
          return value

      def datetime(self, value):
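TypeReadConverter.json() now decodes string values with orjson instead of returning them untouched. A standalone sketch of the new behaviour (re-implementing the two-line method outside of SQLAlchemy):

    import orjson

    def json_read(value):
        # mirrors TypeReadConverter.json() after this change
        if isinstance(value, str):
            return orjson.loads(value)
        return value

    print(json_read('{"a": 1}'))   # {'a': 1}  -- JSON strings are now parsed
    print(json_read({"a": 1}))     # {'a': 1}  -- non-strings pass through unchanged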