datachain 0.2.0__py3-none-any.whl → 0.2.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

Files changed (39)
  1. datachain/__init__.py +0 -4
  2. datachain/catalog/catalog.py +17 -2
  3. datachain/cli.py +8 -1
  4. datachain/data_storage/db_engine.py +0 -2
  5. datachain/data_storage/schema.py +15 -26
  6. datachain/data_storage/sqlite.py +3 -0
  7. datachain/data_storage/warehouse.py +1 -7
  8. datachain/lib/arrow.py +7 -13
  9. datachain/lib/cached_stream.py +3 -85
  10. datachain/lib/clip.py +151 -0
  11. datachain/lib/dc.py +41 -59
  12. datachain/lib/feature.py +5 -1
  13. datachain/lib/feature_registry.py +3 -2
  14. datachain/lib/feature_utils.py +1 -2
  15. datachain/lib/file.py +17 -24
  16. datachain/lib/image.py +37 -79
  17. datachain/lib/pytorch.py +4 -2
  18. datachain/lib/signal_schema.py +3 -4
  19. datachain/lib/text.py +18 -49
  20. datachain/lib/udf.py +64 -55
  21. datachain/lib/udf_signature.py +11 -10
  22. datachain/lib/utils.py +17 -0
  23. datachain/lib/webdataset.py +2 -2
  24. datachain/listing.py +0 -3
  25. datachain/query/dataset.py +66 -46
  26. datachain/query/dispatch.py +2 -2
  27. datachain/query/schema.py +1 -8
  28. datachain/query/udf.py +16 -18
  29. datachain/sql/sqlite/base.py +34 -2
  30. datachain/sql/sqlite/vector.py +13 -5
  31. datachain/utils.py +28 -0
  32. {datachain-0.2.0.dist-info → datachain-0.2.2.dist-info}/METADATA +3 -2
  33. {datachain-0.2.0.dist-info → datachain-0.2.2.dist-info}/RECORD +37 -38
  34. {datachain-0.2.0.dist-info → datachain-0.2.2.dist-info}/WHEEL +1 -1
  35. datachain/_version.py +0 -16
  36. datachain/lib/reader.py +0 -49
  37. {datachain-0.2.0.dist-info → datachain-0.2.2.dist-info}/LICENSE +0 -0
  38. {datachain-0.2.0.dist-info → datachain-0.2.2.dist-info}/entry_points.txt +0 -0
  39. {datachain-0.2.0.dist-info → datachain-0.2.2.dist-info}/top_level.txt +0 -0
datachain/lib/udf.py CHANGED
@@ -1,15 +1,16 @@
 import inspect
 import sys
 import traceback
-from typing import TYPE_CHECKING, Callable, Optional
+from typing import TYPE_CHECKING, Callable
 
 from datachain.lib.feature import Feature
 from datachain.lib.signal_schema import SignalSchema
-from datachain.lib.utils import DataChainError, DataChainParamsError
-from datachain.query import Stream, udf
+from datachain.lib.udf_signature import UdfSignature
+from datachain.lib.utils import AbstractUDF, DataChainError, DataChainParamsError
+from datachain.query import udf
 
 if TYPE_CHECKING:
-    from dvxc.query.udf import UDFWrapper
+    from datachain.query.udf import UDFWrapper
 
 
 class UdfError(DataChainParamsError):
@@ -17,31 +18,68 @@ class UdfError(DataChainParamsError):
         super().__init__(f"UDF error: {msg}")
 
 
-class UDFBase:
+class UDFBase(AbstractUDF):
     is_input_batched = False
     is_output_batched = False
     is_input_grouped = False
 
-    def __init__(
-        self,
-        params: SignalSchema,
-        output: SignalSchema,
-        func: Optional[Callable] = None,
-    ):
+    def __init__(self):
+        self.params = None
+        self.output = None
+        self.params_spec = None
+        self.output_spec = None
+        self._contains_stream = None
+        self._catalog = None
+        self._func = None
+
+    def process(self, *args, **kwargs):
+        """Processing function that needs to be defined by user"""
+        if not self._func:
+            raise NotImplementedError("UDF processing is not implemented")
+        return self._func(*args, **kwargs)
+
+    def setup(self):
+        """Initialization process executed on each worker before processing begins.
+        This is needed for tasks like pre-loading ML models prior to scoring.
+        """
+
+    def teardown(self):
+        """Teardown process executed on each process/worker after processing ends.
+        This is needed for tasks like closing connections to end-points.
+        """
+
+    def _init(self, sign: UdfSignature, params: SignalSchema, func: Callable):
         self.params = params
-        self.output = output
-        self._func = func
+        self.output = sign.output_schema
 
-        params_spec = params.to_udf_spec()
+        params_spec = self.params.to_udf_spec()
         self.params_spec = list(params_spec.keys())
-        self._contains_stream = False
-        if params.contains_file():
-            self.params_spec.insert(0, Stream())  # type: ignore[arg-type]
-            self._contains_stream = True
+        self.output_spec = self.output.to_udf_spec()
 
-        self.output_spec = output.to_udf_spec()
+        self._func = func
 
-        self._catalog = None
+    @classmethod
+    def _create(
+        cls,
+        target_class: type["UDFBase"],
+        sign: UdfSignature,
+        params: SignalSchema,
+        catalog,
+    ) -> "UDFBase":
+        if isinstance(sign.func, AbstractUDF):
+            if not isinstance(sign.func, target_class):  # type: ignore[unreachable]
+                raise UdfError(
+                    f"cannot create UDF: provided UDF '{sign.func.__name__}'"
+                    f" must be a child of target class '{target_class.__name__}'",
+                )
+            result = sign.func
+            func = None
+        else:
+            result = target_class()
+            func = sign.func
+
+        result._init(sign, params, func)
+        return result
 
     @property
     def name(self):
@@ -58,25 +96,10 @@ class UDFBase:
         udf_wrapper = udf(self.params_spec, self.output_spec, batch=batch)
         return udf_wrapper(self)
 
-    def bootstrap(self):
-        """Initialization process executed on each worker before processing begins.
-        This is needed for tasks like pre-loading ML models prior to scoring.
-        """
-
-    def teardown(self):
-        """Teardown process executed on each process/worker after processing ends.
-        This is needed for tasks like closing connections to end-points.
-        """
-
-    def process(self, *args, **kwargs):
-        if not self._func:
-            raise NotImplementedError("UDF processing is not implemented")
-        return self._func(*args, **kwargs)
-
     def validate_results(self, results, *args, **kwargs):
         return results
 
-    def __call__(self, *rows, **kwargs):
+    def __call__(self, *rows):
         if self.is_input_grouped:
             objs = self._parse_grouped_rows(rows)
         else:
@@ -122,18 +145,10 @@ class UDFBase:
             rows = [rows]
         objs = []
        for row in rows:
-            if self._contains_stream:
-                stream, *row = row
-            else:
-                stream = None
-
             obj_row = self.params.row_to_objs(row)
-
-            if self._contains_stream:
-                for obj in obj_row:
-                    if isinstance(obj, Feature):
-                        obj._set_stream(self._catalog, stream, True)
-
+            for obj in obj_row:
+                if isinstance(obj, Feature):
+                    obj._set_stream(self._catalog, caching_enabled=True)
             objs.append(obj_row)
         return objs
 
@@ -150,13 +165,7 @@
                 output_map[name] = []
 
         for flat_obj in group:
-            if self._contains_stream:
-                position = 1
-                stream = flat_obj[0]
-            else:
-                position = 0
-                stream = None
-
+            position = 0
             for signal, (cls, length) in spec_map.items():
                 slice = flat_obj[position : position + length]
                 position += length
@@ -167,7 +176,7 @@
                     obj = slice[0]
 
                 if isinstance(obj, Feature):
-                    obj._set_stream(self._catalog, stream)
+                    obj._set_stream(self._catalog)
                 output_map[signal].append(obj)
 
         return list(output_map.values())
datachain/lib/udf_signature.py CHANGED
@@ -5,7 +5,7 @@ from typing import Callable, Optional, Union, get_args, get_origin
 
 from datachain.lib.feature import Feature, FeatureType, FeatureTypeNames
 from datachain.lib.signal_schema import SignalSchema
-from datachain.lib.utils import DataChainParamsError
+from datachain.lib.utils import AbstractUDF, DataChainParamsError
 
 
 class UdfSignatureError(DataChainParamsError):
@@ -49,10 +49,13 @@ class UdfSignature:
         else:
             if func is None:
                 raise UdfSignatureError(chain, "user function is not defined")
+
             udf_func = func
             signal_name = None
+
         if not callable(udf_func):
-            raise UdfSignatureError(chain, f"function '{func}' is not callable")
+            raise UdfSignatureError(chain, f"UDF '{udf_func}' is not callable")
+
         func_params_map_sign, func_outs_sign, is_iterator = (
             UdfSignature._func_signature(chain, udf_func)
         )
@@ -108,13 +111,6 @@ class UdfSignature:
         if isinstance(output, str):
            output = [output]
         if isinstance(output, Sequence):
-            if not func_outs_sign:
-                raise UdfSignatureError(
-                    chain,
-                    "output types are not specified. Specify types in 'output' as"
-                    " a dict or as function return value hint.",
-                )
-
             if len(func_outs_sign) != len(output):
                 raise UdfSignatureError(
                     chain,
@@ -158,8 +154,13 @@
 
     @staticmethod
     def _func_signature(
-        chain: str, func: Callable
+        chain: str, udf_func: Callable
     ) -> tuple[dict[str, type], Sequence[type], bool]:
+        if isinstance(udf_func, AbstractUDF):
+            func = udf_func.process  # type: ignore[unreachable]
+        else:
+            func = udf_func
+
         sign = inspect.signature(func)
 
         input_map = {prm.name: prm.annotation for prm in sign.parameters.values()}
datachain/lib/utils.py CHANGED
@@ -1,3 +1,20 @@
+from abc import ABC, abstractmethod
+
+
+class AbstractUDF(ABC):
+    @abstractmethod
+    def process(self, *args, **kwargs):
+        pass
+
+    @abstractmethod
+    def setup(self):
+        pass
+
+    @abstractmethod
+    def teardown(self):
+        pass
+
+
 class DataChainError(Exception):
     def __init__(self, message):
         super().__init__(message)
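The new AbstractUDF base class above only fixes the interface (process/setup/teardown). A minimal sketch of a class-based UDF written against that interface might look like the following; the class name MyScorer and its weights table are hypothetical and not part of this release, only the imported base class comes from the diff.

from datachain.lib.utils import AbstractUDF


class MyScorer(AbstractUDF):
    """Hypothetical class-based UDF; only the AbstractUDF interface is real."""

    def setup(self):
        # runs once per worker before any rows are processed,
        # e.g. to pre-load an ML model; here just a stand-in lookup table
        self.weights = {"cat": 1.0, "dog": 0.5}

    def process(self, label):
        # called with the signals declared for the UDF
        return self.weights.get(label, 0.0)

    def teardown(self):
        # runs once per worker after processing ends,
        # e.g. to release resources or close connections
        self.weights = None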
datachain/lib/webdataset.py CHANGED
@@ -2,6 +2,7 @@ import hashlib
 import json
 import tarfile
 from collections.abc import Iterator, Sequence
+from pathlib import Path
 from typing import (
     Any,
     Callable,
@@ -240,10 +241,9 @@ class TarStream(File):
 def get_tar_groups(stream, tar, core_extensions, spec, encoding="utf-8"):
     builder = Builder(stream, core_extensions, spec, tar, encoding)
 
-    for item in tar.getmembers():
+    for item in sorted(tar.getmembers(), key=lambda m: Path(m.name).stem):
         if not item.isfile():
             continue
-
         try:
             builder.add(item)
         except StopIteration:
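For context, the sort key added to get_tar_groups() places tar members that share a file stem next to each other, so the builder can assemble one sample from e.g. an image/JSON pair regardless of their physical order inside the archive. A small, self-contained illustration (the file names are made up):

from pathlib import Path

names = ["0002.jpg", "0001.json", "0001.jpg", "0002.json"]
print(sorted(names, key=lambda n: Path(n).stem))
# ['0001.json', '0001.jpg', '0002.jpg', '0002.json']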
datachain/listing.py CHANGED
@@ -20,9 +20,6 @@ if TYPE_CHECKING:
     from datachain.storage import Storage
 
 
-RANDOM_BITS = 63  # size of the random integer field
-
-
 class Listing:
     def __init__(
         self,
datachain/query/dataset.py CHANGED
@@ -1,3 +1,4 @@
+import ast
 import contextlib
 import datetime
 import inspect
@@ -51,9 +52,10 @@ from datachain.data_storage.schema import (
 from datachain.dataset import DatasetStatus, RowDict
 from datachain.error import DatasetNotFoundError, QueryScriptCancelError
 from datachain.progress import CombinedDownloadCallback
+from datachain.query.schema import DEFAULT_DELIMITER
 from datachain.sql.functions import rand
 from datachain.storage import Storage, StorageURI
-from datachain.utils import batched, determine_processes
+from datachain.utils import batched, determine_processes, inside_notebook
 
 from .batch import RowBatch
 from .metrics import metrics
@@ -62,7 +64,6 @@ from .session import Session
 from .udf import UDFBase, UDFClassWrapper, UDFFactory, UDFType
 
 if TYPE_CHECKING:
-    import pandas as pd
     from sqlalchemy.sql.elements import ClauseElement
     from sqlalchemy.sql.schema import Table
     from sqlalchemy.sql.selectable import GenerativeSelect
@@ -547,8 +548,9 @@ class UDF(Step, ABC):
         else:
             udf = self.udf
 
-        if hasattr(udf.func, "bootstrap") and callable(udf.func.bootstrap):
-            udf.func.bootstrap()
+        if hasattr(udf.func, "setup") and callable(udf.func.setup):
+            udf.func.setup()
+
 
         warehouse = self.catalog.warehouse
         with contextlib.closing(
@@ -599,12 +601,15 @@ class UDF(Step, ABC):
         # Create a dynamic module with the generated name
         dynamic_module = types.ModuleType(feature_module_name)
         # Get the import lines for the necessary objects from the main module
-        import_lines = [
-            source.getimport(obj, alias=name)
-            for name, obj in inspect.getmembers(sys.modules["__main__"], _imports)
-            if not (name.startswith("__") and name.endswith("__"))
-        ]
         main_module = sys.modules["__main__"]
+        if getattr(main_module, "__file__", None):
+            import_lines = list(get_imports(main_module))
+        else:
+            import_lines = [
+                source.getimport(obj, alias=name)
+                for name, obj in main_module.__dict__.items()
+                if _imports(obj) and not (name.startswith("__") and name.endswith("__"))
+            ]
 
         # Get the feature classes from the main module
         feature_classes = {
@@ -612,6 +617,10 @@ class UDF(Step, ABC):
             for name, obj in main_module.__dict__.items()
             if _feature_predicate(obj)
         }
+        if not feature_classes:
+            yield None
+            return
+
         # Get the source code of the feature classes
         feature_sources = [source.getsource(cls) for _, cls in feature_classes.items()]
         # Set the module name for the feature classes to the generated name
@@ -621,7 +630,7 @@ class UDF(Step, ABC):
         # Add the dynamic module to the sys.modules dictionary
         sys.modules[feature_module_name] = dynamic_module
         # Combine the import lines and feature sources
-        feature_file = "".join(import_lines) + "\n".join(feature_sources)
+        feature_file = "\n".join(import_lines) + "\n" + "\n".join(feature_sources)
 
         # Write the module content to a .py file
         with open(f"{feature_module_name}.py", "w") as module_file:
@@ -1362,33 +1371,11 @@ class DatasetQuery:
         cols = result.columns
         return [dict(zip(cols, row)) for row in result]
 
-    @classmethod
-    def create_empty_record(
-        cls, name: Optional[str] = None, session: Optional[Session] = None
-    ) -> "DatasetRecord":
-        session = Session.get(session)
-        if name is None:
-            name = session.generate_temp_dataset_name()
-        columns = session.catalog.warehouse.dataset_row_cls.file_columns()
-        return session.catalog.create_dataset(name, columns=columns)
-
-    @classmethod
-    def insert_record(
-        cls,
-        dsr: "DatasetRecord",
-        record: dict[str, Any],
-        session: Optional[Session] = None,
-    ) -> None:
-        session = Session.get(session)
-        dr = session.catalog.warehouse.dataset_rows(dsr)
-        insert_q = dr.get_table().insert().values(**record)
-        session.catalog.warehouse.db.execute(insert_q)
-
     def to_pandas(self) -> "pd.DataFrame":
-        import pandas as pd
-
         records = self.to_records()
-        return pd.DataFrame.from_records(records)
+        df = pd.DataFrame.from_records(records)
+        df.columns = [c.replace(DEFAULT_DELIMITER, ".") for c in df.columns]
+        return df
 
     def shuffle(self) -> "Self":
         # ToDo: implement shaffle based on seed and/or generating random column
@@ -1410,8 +1397,17 @@ class DatasetQuery:
 
     def show(self, limit=20) -> None:
         df = self.limit(limit).to_pandas()
-        no_footer = re.sub(r"\n\[\d+ rows x \d+ columns\]$", "", str(df))
-        print(no_footer.rstrip(" \n"))
+
+        options = ["display.max_colwidth", 50, "display.show_dimensions", False]
+        with pd.option_context(*options):
+            if inside_notebook():
+                from IPython.display import display
+
+                display(df)
+
+            else:
+                print(df.to_string())
+
         if len(df) == limit:
             print(f"[limited by {limit} objects]")
 
@@ -1692,6 +1688,15 @@ class DatasetQuery:
             storage.timestamp_str,
         )
 
+    def exec(self) -> "Self":
+        """Execute the query."""
+        try:
+            query = self.clone()
+            query.apply_steps()
+        finally:
+            self.cleanup()
+        return query
+
     def save(
         self,
         name: Optional[str] = None,
@@ -1737,22 +1742,16 @@ class DatasetQuery:
 
         # Exclude the id column and let the db create it to avoid unique
         # constraint violations.
-        cols = [col.name for col in dr.get_table().c if col.name != "id"]
-        assert cols
         q = query.exclude(("id",))
-
         if q._order_by_clauses:
             # ensuring we have id sorted by order by clause if it exists in a query
             q = q.add_columns(
                 f.row_number().over(order_by=q._order_by_clauses).label("id")
            )
-            cols.append("id")
-
-        self.catalog.warehouse.db.execute(
-            sqlalchemy.insert(dr.get_table()).from_select(cols, q),
-            **kwargs,
-        )
 
+        cols = tuple(c.name for c in q.columns)
+        insert_q = sqlalchemy.insert(dr.get_table()).from_select(cols, q)
+        self.catalog.warehouse.db.execute(insert_q, **kwargs)
         self.catalog.metastore.update_dataset_status(
             dataset, DatasetStatus.COMPLETE, version=version
         )
@@ -1884,3 +1883,24 @@ def _feature_predicate(obj):
 
 def _imports(obj):
     return not source.isfrommain(obj)
+
+
+def get_imports(m):
+    root = ast.parse(inspect.getsource(m))
+
+    for node in ast.iter_child_nodes(root):
+        if isinstance(node, ast.Import):
+            module = None
+        elif isinstance(node, ast.ImportFrom):
+            module = node.module
+        else:
+            continue
+
+        for n in node.names:
+            import_script = ""
+            if module:
+                import_script += f"from {module} "
+            import_script += f"import {n.name}"
+            if n.asname:
+                import_script += f" as {n.asname}"
+            yield import_script
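The new get_imports() helper reconstructs import statements from the main module's AST instead of introspecting live objects. A rough sketch of the same AST walk applied to a literal source string (get_imports() itself calls inspect.getsource() on a file-backed module) shows the strings it yields; the sample source is made up for illustration.

import ast

source = "import numpy as np\nfrom pathlib import Path\n"
for node in ast.iter_child_nodes(ast.parse(source)):
    if isinstance(node, ast.Import):
        module = None
    elif isinstance(node, ast.ImportFrom):
        module = node.module
    else:
        continue
    for name in node.names:
        stmt = f"from {module} " if module else ""
        stmt += f"import {name.name}"
        if name.asname:
            stmt += f" as {name.asname}"
        print(stmt)
# import numpy as np
# from pathlib import Path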
datachain/query/dispatch.py CHANGED
@@ -370,8 +370,8 @@ class UDFWorker:
         return WorkerCallback(self.done_queue)
 
     def run(self) -> None:
-        if hasattr(self.udf.func, "bootstrap") and callable(self.udf.func.bootstrap):
-            self.udf.func.bootstrap()
+        if hasattr(self.udf.func, "setup") and callable(self.udf.func.setup):
+            self.udf.func.setup()
         while (batch := get_from_queue(self.task_queue)) != STOP_SIGNAL:
             n_rows = len(batch.rows) if isinstance(batch, RowBatch) else 1
             udf_output = self.udf(
datachain/query/schema.py CHANGED
@@ -3,14 +3,12 @@ import json
 from abc import ABC, abstractmethod
 from datetime import datetime, timezone
 from fnmatch import fnmatch
-from random import getrandbits
 from typing import TYPE_CHECKING, Any, Callable, ClassVar, Optional, Union
 
 import attrs
 import sqlalchemy as sa
 from fsspec.callbacks import DEFAULT_CALLBACK, Callback
 
-from datachain.data_storage.warehouse import RANDOM_BITS
 from datachain.sql.types import JSON, Boolean, DateTime, Int, Int64, SQLType, String
 
 if TYPE_CHECKING:
@@ -217,7 +215,7 @@ class DatasetRow:
         "source": String,
         "parent": String,
         "name": String,
-        "size": Int,
+        "size": Int64,
         "location": JSON,
         "vtype": String,
         "dir_type": Int,
@@ -227,8 +225,6 @@ class DatasetRow:
         "last_modified": DateTime,
         "version": String,
         "etag": String,
-        # system column
-        "random": Int64,
     }
 
     @staticmethod
@@ -267,8 +263,6 @@ class DatasetRow:
 
         last_modified = last_modified or datetime.now(timezone.utc)
 
-        random = getrandbits(RANDOM_BITS)
-
         return (  # type: ignore [return-value]
             source,
             parent,
@@ -283,7 +277,6 @@ class DatasetRow:
             last_modified,
             version,
             etag,
-            random,
         )
 
     @staticmethod
datachain/query/udf.py CHANGED
@@ -14,6 +14,7 @@ from typing import (
 from fsspec.callbacks import DEFAULT_CALLBACK, Callback
 
 from datachain.dataset import RowDict
+from datachain.lib.utils import AbstractUDF
 
 from .batch import Batch, BatchingStrategy, NoBatching, Partition, RowBatch
 from .schema import (
@@ -58,14 +59,6 @@ class UDFProperties:
     def signal_names(self) -> Iterable[str]:
         return self.output.keys()
 
-    def parameter_parser(self) -> Callable:
-        """Generate a parameter list from a dataset row."""
-
-        def plist(catalog: "Catalog", row: "RowDict", **kwargs) -> list:
-            return [p.get_value(catalog, row, **kwargs) for p in self.params]
-
-        return plist
-
 
 def udf(
     params: Sequence[UDFParamSpec],
@@ -113,32 +106,37 @@ class UDFBase:
         self.func = func
         self.properties = properties
         self.signal_names = properties.signal_names()
-        self.parameter_parser = properties.parameter_parser()
         self.output = properties.output
 
     def __call__(
         self,
         catalog: "Catalog",
-        param: "BatchingResult",
+        arg: "BatchingResult",
         is_generator: bool = False,
         cache: bool = False,
         cb: Callback = DEFAULT_CALLBACK,
     ) -> Iterable[UDFResult]:
-        if isinstance(param, RowBatch):
+        if isinstance(self.func, AbstractUDF):
+            self.func._catalog = catalog  # type: ignore[unreachable]
+
+        if isinstance(arg, RowBatch):
             udf_inputs = [
-                self.parameter_parser(catalog, row, cache=cache, cb=cb)
-                for row in param.rows
+                self.bind_parameters(catalog, row, cache=cache, cb=cb)
+                for row in arg.rows
             ]
             udf_outputs = self.func(udf_inputs)
-            return self._process_results(param.rows, udf_outputs, is_generator)
-        if isinstance(param, RowDict):
-            udf_inputs = self.parameter_parser(catalog, param, cache=cache, cb=cb)
+            return self._process_results(arg.rows, udf_outputs, is_generator)
+        if isinstance(arg, RowDict):
+            udf_inputs = self.bind_parameters(catalog, arg, cache=cache, cb=cb)
             udf_outputs = self.func(*udf_inputs)
             if not is_generator:
                 # udf_outputs is generator already if is_generator=True
                 udf_outputs = [udf_outputs]
-            return self._process_results([param], udf_outputs, is_generator)
-        raise ValueError(f"unexpected UDF parameter {param}")
+            return self._process_results([arg], udf_outputs, is_generator)
+        raise ValueError(f"Unexpected UDF argument: {arg}")
+
+    def bind_parameters(self, catalog: "Catalog", row: "RowDict", **kwargs) -> list:
+        return [p.get_value(catalog, row, **kwargs) for p in self.properties.params]
 
     def _process_results(
         self,
datachain/sql/sqlite/base.py CHANGED
@@ -71,8 +71,6 @@ def setup():
     compiles(sql_path.name, "sqlite")(compile_path_name)
     compiles(sql_path.file_stem, "sqlite")(compile_path_file_stem)
     compiles(sql_path.file_ext, "sqlite")(compile_path_file_ext)
-    compiles(array.cosine_distance, "sqlite")(compile_cosine_distance)
-    compiles(array.euclidean_distance, "sqlite")(compile_euclidean_distance)
     compiles(array.length, "sqlite")(compile_array_length)
     compiles(string.length, "sqlite")(compile_string_length)
     compiles(string.split, "sqlite")(compile_string_split)
@@ -81,6 +79,13 @@ def setup():
     compiles(Values, "sqlite")(compile_values)
     compiles(random.rand, "sqlite")(compile_rand)
 
+    if load_usearch_extension(sqlite3.connect(":memory:")):
+        compiles(array.cosine_distance, "sqlite")(compile_cosine_distance_ext)
+        compiles(array.euclidean_distance, "sqlite")(compile_euclidean_distance_ext)
+    else:
+        compiles(array.cosine_distance, "sqlite")(compile_cosine_distance)
+        compiles(array.euclidean_distance, "sqlite")(compile_euclidean_distance)
+
     register_user_defined_sql_functions()
     setup_is_complete = True
 
@@ -246,11 +251,23 @@ def compile_path_file_ext(element, compiler, **kwargs):
     return compiler.process(path_file_ext(*element.clauses.clauses), **kwargs)
 
 
+def compile_cosine_distance_ext(element, compiler, **kwargs):
+    run_compiler_hook("cosine_distance")
+    return f"distance_cosine_f32({compiler.process(element.clauses, **kwargs)})"
+
+
 def compile_cosine_distance(element, compiler, **kwargs):
     run_compiler_hook("cosine_distance")
     return f"cosine_distance({compiler.process(element.clauses, **kwargs)})"
 
 
+def compile_euclidean_distance_ext(element, compiler, **kwargs):
+    run_compiler_hook("euclidean_distance")
+    return (
+        f"sqrt(distance_sqeuclidean_f32({compiler.process(element.clauses, **kwargs)}))"
+    )
+
+
 def compile_euclidean_distance(element, compiler, **kwargs):
     run_compiler_hook("euclidean_distance")
     return f"euclidean_distance({compiler.process(element.clauses, **kwargs)})"
@@ -330,3 +347,18 @@ def compile_values(element, compiler, **kwargs):
 
 def compile_rand(element, compiler, **kwargs):
     return compiler.process(func.random(), **kwargs)
+
+
+def load_usearch_extension(conn) -> bool:
+    try:
+        # usearch is part of the vector optional dependencies
+        # we use the extension's cosine and euclidean distance functions
+        from usearch import sqlite_path
+
+        conn.enable_load_extension(True)
+        conn.load_extension(sqlite_path())
+        conn.enable_load_extension(False)
+        return True
+
+    except Exception:  # noqa: BLE001
+        return False
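With the setup() change above, the vector distance functions prefer the usearch SQLite extension when it can be loaded and fall back to datachain's registered Python implementations otherwise. A minimal sketch for probing that decision in isolation, assuming datachain 0.2.2 is installed (usearch itself comes from the optional vector extra):

import sqlite3

from datachain.sql.sqlite.base import load_usearch_extension

# Probe a throwaway in-memory connection, exactly as setup() does.
if load_usearch_extension(sqlite3.connect(":memory:")):
    print("usearch extension loaded: distance_cosine_f32 / distance_sqeuclidean_f32")
else:
    print("usearch unavailable: using datachain's registered distance functions")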