datachain 0.5.0__py3-none-any.whl → 0.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of datachain might be problematic. Click here for more details.

@@ -1,5 +1,8 @@
1
1
  import atexit
2
+ import logging
3
+ import os
2
4
  import re
5
+ import sys
3
6
  from typing import TYPE_CHECKING, Optional
4
7
  from uuid import uuid4
5
8
 
@@ -9,6 +12,8 @@ from datachain.error import TableMissingError
9
12
  if TYPE_CHECKING:
10
13
  from datachain.catalog import Catalog
11
14
 
15
+ logger = logging.getLogger("datachain")
16
+
12
17
 
13
18
  class Session:
14
19
  """
@@ -35,6 +40,7 @@ class Session:
35
40
 
36
41
  GLOBAL_SESSION_CTX: Optional["Session"] = None
37
42
  GLOBAL_SESSION: Optional["Session"] = None
43
+ ORIGINAL_EXCEPT_HOOK = None
38
44
 
39
45
  DATASET_PREFIX = "session_"
40
46
  GLOBAL_SESSION_NAME = "global"
@@ -58,6 +64,7 @@ class Session:
58
64
 
59
65
  session_uuid = uuid4().hex[: self.SESSION_UUID_LEN]
60
66
  self.name = f"{name}_{session_uuid}"
67
+ self.job_id = os.getenv("DATACHAIN_JOB_ID") or str(uuid4())
61
68
  self.is_new_catalog = not catalog
62
69
  self.catalog = catalog or get_catalog(
63
70
  client_config=client_config, in_memory=in_memory
@@ -67,6 +74,9 @@ class Session:
67
74
  return self
68
75
 
69
76
  def __exit__(self, exc_type, exc_val, exc_tb):
77
+ if exc_type:
78
+ self._cleanup_created_versions(self.name)
79
+
70
80
  self._cleanup_temp_datasets()
71
81
  if self.is_new_catalog:
72
82
  self.catalog.metastore.close_on_exit()
@@ -88,6 +98,21 @@ class Session:
88
98
  except TableMissingError:
89
99
  pass
90
100
 
101
+ def _cleanup_created_versions(self, job_id: str) -> None:
102
+ versions = self.catalog.metastore.get_job_dataset_versions(job_id)
103
+ if not versions:
104
+ return
105
+
106
+ datasets = {}
107
+ for dataset_name, version in versions:
108
+ if dataset_name not in datasets:
109
+ datasets[dataset_name] = self.catalog.get_dataset(dataset_name)
110
+ dataset = datasets[dataset_name]
111
+ logger.info(
112
+ "Removing dataset version %s@%s due to exception", dataset_name, version
113
+ )
114
+ self.catalog.remove_dataset_version(dataset, version)
115
+
91
116
  @classmethod
92
117
  def get(
93
118
  cls,
@@ -114,9 +139,23 @@ class Session:
114
139
  in_memory=in_memory,
115
140
  )
116
141
  cls.GLOBAL_SESSION = cls.GLOBAL_SESSION_CTX.__enter__()
142
+
117
143
  atexit.register(cls._global_cleanup)
144
+ cls.ORIGINAL_EXCEPT_HOOK = sys.excepthook
145
+ sys.excepthook = cls.except_hook
146
+
118
147
  return cls.GLOBAL_SESSION
119
148
 
149
+ @staticmethod
150
+ def except_hook(exc_type, exc_value, exc_traceback):
151
+ Session._global_cleanup()
152
+ if Session.GLOBAL_SESSION_CTX is not None:
153
+ job_id = Session.GLOBAL_SESSION_CTX.job_id
154
+ Session.GLOBAL_SESSION_CTX._cleanup_created_versions(job_id)
155
+
156
+ if Session.ORIGINAL_EXCEPT_HOOK:
157
+ Session.ORIGINAL_EXCEPT_HOOK(exc_type, exc_value, exc_traceback)
158
+
120
159
  @classmethod
121
160
  def cleanup_for_tests(cls):
122
161
  if cls.GLOBAL_SESSION_CTX is not None:
@@ -125,6 +164,9 @@ class Session:
125
164
  cls.GLOBAL_SESSION_CTX = None
126
165
  atexit.unregister(cls._global_cleanup)
127
166
 
167
+ if cls.ORIGINAL_EXCEPT_HOOK:
168
+ sys.excepthook = cls.ORIGINAL_EXCEPT_HOOK
169
+
128
170
  @staticmethod
129
171
  def _global_cleanup():
130
172
  if Session.GLOBAL_SESSION_CTX is not None:
@@ -37,6 +37,18 @@ class regexp_replace(GenericFunction): # noqa: N801
37
37
  inherit_cache = True
38
38
 
39
39
 
40
+ class replace(GenericFunction): # noqa: N801
41
+ """
42
+ Replaces substring with another string.
43
+ """
44
+
45
+ type = String()
46
+ package = "string"
47
+ name = "replace"
48
+ inherit_cache = True
49
+
50
+
40
51
  compiler_not_implemented(length)
41
52
  compiler_not_implemented(split)
42
53
  compiler_not_implemented(regexp_replace)
54
+ compiler_not_implemented(replace)
@@ -78,7 +78,8 @@ def setup():
78
78
  compiles(array.length, "sqlite")(compile_array_length)
79
79
  compiles(string.length, "sqlite")(compile_string_length)
80
80
  compiles(string.split, "sqlite")(compile_string_split)
81
- compiles(string.regexp_replace, "sqlite")(compile_regexp_replace)
81
+ compiles(string.regexp_replace, "sqlite")(compile_string_regexp_replace)
82
+ compiles(string.replace, "sqlite")(compile_string_replace)
82
83
  compiles(conditional.greatest, "sqlite")(compile_greatest)
83
84
  compiles(conditional.least, "sqlite")(compile_least)
84
85
  compiles(Values, "sqlite")(compile_values)
@@ -273,10 +274,6 @@ def path_file_ext(path):
273
274
  return func.substr(path, func.length(path) - path_file_ext_length(path) + 1)
274
275
 
275
276
 
276
- def compile_regexp_replace(element, compiler, **kwargs):
277
- return f"regexp_replace({compiler.process(element.clauses, **kwargs)})"
278
-
279
-
280
277
  def compile_path_parent(element, compiler, **kwargs):
281
278
  return compiler.process(path_parent(*element.clauses.clauses), **kwargs)
282
279
 
@@ -331,6 +328,14 @@ def compile_string_split(element, compiler, **kwargs):
331
328
  return compiler.process(func.split(*element.clauses.clauses), **kwargs)
332
329
 
333
330
 
331
+ def compile_string_regexp_replace(element, compiler, **kwargs):
332
+ return f"regexp_replace({compiler.process(element.clauses, **kwargs)})"
333
+
334
+
335
+ def compile_string_replace(element, compiler, **kwargs):
336
+ return compiler.process(func.replace(*element.clauses.clauses), **kwargs)
337
+
338
+
334
339
  def compile_greatest(element, compiler, **kwargs):
335
340
  """
336
341
  Compiles a sql function for `greatest(*args)` taking 1 or more args
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: datachain
3
- Version: 0.5.0
3
+ Version: 0.6.0
4
4
  Summary: Wrangle unstructured AI data at scale
5
5
  Author-email: Dmitry Petrov <support@dvc.org>
6
6
  License: Apache-2.0
@@ -18,7 +18,7 @@ datachain/storage.py,sha256=RiSJLYdHUjnrEWkLBKPcETHpAxld_B2WxLg711t0aZI,3733
18
18
  datachain/telemetry.py,sha256=0A4IOPPp9VlP5pyW9eBfaTK3YhHGzHl7dQudQjUAx9A,994
19
19
  datachain/utils.py,sha256=KeFSRHsiYthnTu4a6bH-rw04mX1m8krTX0f2NqfQGFI,12114
20
20
  datachain/catalog/__init__.py,sha256=g2iAAFx_gEIrqshXlhSEbrc8qDaEH11cjU40n3CHDz4,409
21
- datachain/catalog/catalog.py,sha256=FuKuIiCwPgN5Ea25hnFe_ZFZH9YEUZ2ma9k_Lczk-JU,63867
21
+ datachain/catalog/catalog.py,sha256=BsMyk2RQibQYHgrmovFZeSEpPVMTwgb_7ntVYdc7t-E,64090
22
22
  datachain/catalog/datasource.py,sha256=D-VWIVDCM10A8sQavLhRXdYSCG7F4o4ifswEF80_NAQ,1412
23
23
  datachain/catalog/loader.py,sha256=-6VelNfXUdgUnwInVyA8g86Boxv2xqhTh9xNS-Zlwig,8242
24
24
  datachain/client/__init__.py,sha256=T4wiYL9KIM0ZZ_UqIyzV8_ufzYlewmizlV4iymHNluE,86
@@ -33,17 +33,17 @@ datachain/data_storage/__init__.py,sha256=cEOJpyu1JDZtfUupYucCDNFI6e5Wmp_Oyzq6rZ
33
33
  datachain/data_storage/db_engine.py,sha256=81Ol1of9TTTzD97ORajCnP366Xz2mEJt6C-kTUCaru4,3406
34
34
  datachain/data_storage/id_generator.py,sha256=lCEoU0BM37Ai2aRpSbwo5oQT0GqZnSpYwwvizathRMQ,4292
35
35
  datachain/data_storage/job.py,sha256=w-7spowjkOa1P5fUVtJou3OltT0L48P0RYWZ9rSJ9-s,383
36
- datachain/data_storage/metastore.py,sha256=NV4FJ_W16Q19Sx70i5Qtre-n4DC2kMD0qw0vBz3j7Ks,52228
36
+ datachain/data_storage/metastore.py,sha256=HfCxk4lmDUg2Q4WsFNQGMWxllP0mToA00fxkFTwdNIE,52919
37
37
  datachain/data_storage/schema.py,sha256=AGbjyEir5UmRZXI3m0jChZogUh5wd8csj6-YlUWaAxQ,8383
38
38
  datachain/data_storage/serializer.py,sha256=6G2YtOFqqDzJf1KbvZraKGXl2XHZyVml2krunWUum5o,927
39
- datachain/data_storage/sqlite.py,sha256=EBKJncuzcyQfcKFm2mUjvHjHRTODsteM-k_zndunBrw,28834
39
+ datachain/data_storage/sqlite.py,sha256=fW08P7AbJ0cDbTbcTKuAGpvMXvBjg-QkGsKT_Dslyws,28383
40
40
  datachain/data_storage/warehouse.py,sha256=fXhVfao3NfWFGbbG5uJ-Ga4bX1FiKVfcbDyQgECYfk8,32122
41
41
  datachain/lib/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
42
- datachain/lib/arrow.py,sha256=aUsoQmxDmuSnB8Ik9p57Y66gc_dgx6NBqkDDIfLsvno,7630
42
+ datachain/lib/arrow.py,sha256=0R2CYsN82nNa5_03iS6jVix9EKeeqNZNAMgpSQP2hfo,9482
43
43
  datachain/lib/clip.py,sha256=lm5CzVi4Cj1jVLEKvERKArb-egb9j1Ls-fwTItT6vlI,6150
44
- datachain/lib/data_model.py,sha256=gHIjlow84GMRDa78yLL1Ud-N18or21fnTyPEwsatpXY,2045
44
+ datachain/lib/data_model.py,sha256=ECTbvlnzM98hp2mZ4fo82Yi0-MuoqTIQasQKGIyd89I,2040
45
45
  datachain/lib/dataset_info.py,sha256=srPPhI2UHf6hFPBecyFEVw2SS5aPisIIMsvGgKqi7ss,2366
46
- datachain/lib/dc.py,sha256=yTyHrKIswCzdlvl2n-wdEVZEEF5VQpkLJPzPfUL9CTU,72054
46
+ datachain/lib/dc.py,sha256=XmAFU9k79wUHIh0gYab8j-wF4vIlyW6opJcOy8fmoVc,76666
47
47
  datachain/lib/file.py,sha256=LjTW_-PDAnoUhvyB4bJ8Y8n__XGqrxvmd9mDOF0Gir8,14875
48
48
  datachain/lib/hf.py,sha256=cPnmLuprr0pYABH7KqA5FARQ1JGlywdDwD3yDzVAm4k,5920
49
49
  datachain/lib/image.py,sha256=AMXYwQsmarZjRbPCZY3M1jDsM2WAB_b3cTY4uOIuXNU,2675
@@ -53,11 +53,11 @@ datachain/lib/meta_formats.py,sha256=3f-0vpMTesagS9iMd3y9-u9r-7g0eqYsxmK4fVfNWlw
53
53
  datachain/lib/model_store.py,sha256=DNIv8Y6Jtk1_idNLzIpsThOsdW2BMAudyUCbPUcgcxk,2515
54
54
  datachain/lib/pytorch.py,sha256=W-ARi2xH1f1DUkVfRuerW-YWYgSaJASmNCxtz2lrJGI,6072
55
55
  datachain/lib/settings.py,sha256=39thOpYJw-zPirzeNO6pmRC2vPrQvt4eBsw1xLWDFsw,2344
56
- datachain/lib/signal_schema.py,sha256=iqgubjCBRiUJB30miv05qFX4uU04dA_Pzi3DCUsHZGs,24177
56
+ datachain/lib/signal_schema.py,sha256=gj45dRQuOsKDmaKaJxb5j63HYVGw-Ks1fyAS1FpyOWA,24145
57
57
  datachain/lib/tar.py,sha256=3WIzao6yD5fbLqXLTt9GhPGNonbFIs_fDRu-9vgLgsA,1038
58
58
  datachain/lib/text.py,sha256=UNHm8fhidk7wdrWqacEWaA6I9ykfYqarQ2URby7jc7M,1261
59
- datachain/lib/udf.py,sha256=nG7DDuPgZ5ZuijwvDoCq-OZMxlDM8vFNzyxMmik0Y1c,11716
60
- datachain/lib/udf_signature.py,sha256=gMStcEeYJka5M6cg50Z9orC6y6HzCAJ3MkFqqn1fjZg,7137
59
+ datachain/lib/udf.py,sha256=GvhWLCXZUY7sz1QMRBj1AJDSzzhyj15xs3Ia9hjJrJE,12697
60
+ datachain/lib/udf_signature.py,sha256=GXw24A-Olna6DWCdgy2bC-gZh_gLGPQ-KvjuI6pUjC0,7281
61
61
  datachain/lib/utils.py,sha256=5-kJlAZE0D9nXXweAjo7-SP_AWGo28feaDByONYaooQ,463
62
62
  datachain/lib/vfile.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
63
63
  datachain/lib/webdataset.py,sha256=o7SHk5HOUWsZ5Ln04xOM04eQqiBHiJNO7xLgyVBrwo8,6924
@@ -67,17 +67,16 @@ datachain/lib/convert/flatten.py,sha256=Uebc5CeqCsacp-nr6IG9i6OGuUavXqdqnoGctZBk
67
67
  datachain/lib/convert/python_to_sql.py,sha256=40SAOdoOgikZRhn8iomCPDRoxC3RFxjJLivEAA9MHDU,2880
68
68
  datachain/lib/convert/sql_to_python.py,sha256=lGnKzSF_tz9Y_5SSKkrIU95QEjpcDzvOxIRkEKTQag0,443
69
69
  datachain/lib/convert/unflatten.py,sha256=Ogvh_5wg2f38_At_1lN0D_e2uZOOpYEvwvB2xdq56Tw,2012
70
- datachain/lib/convert/values_to_tuples.py,sha256=YOdbjzHq-uj6-cV2Qq43G72eN2avMNDGl4x5t6yQMl8,3931
70
+ datachain/lib/convert/values_to_tuples.py,sha256=varRCnSMT_pZmHznrd2Yi05qXLLz_v9YH_pOCpHSkdc,3921
71
71
  datachain/query/__init__.py,sha256=0NBOZVgIDpCcj1Ci883dQ9A0iiwe03xzmotkOCFbxYc,293
72
- datachain/query/batch.py,sha256=-vlpINJiertlnaoUVv1C95RatU0F6zuhpIYRufJRo1M,3660
73
- datachain/query/dataset.py,sha256=tLCTaj4K93BY93GgOPv9PknZByEF89zpHc7y9s8ZF_w,53610
74
- datachain/query/dispatch.py,sha256=CFAc09O6UllcyUSSEY1GUlEMPzeO8RYhXinNN4HBl9M,12405
72
+ datachain/query/batch.py,sha256=3QlwshhpUc1amZRtXWVXEEuq47hEQgQlY0Ji48DR6hg,3508
73
+ datachain/query/dataset.py,sha256=MF_E7yjbFQV6NcP4gKbJFXiWuoQkpQ7-Jmxa59FxenE,53630
74
+ datachain/query/dispatch.py,sha256=wjjTWw6sFQbB9SKRh78VbfvwSMgJXCfqJklS3-9KnCU,12025
75
75
  datachain/query/metrics.py,sha256=r5b0ygYhokbXp8Mg3kCH8iFSRw0jxzyeBe-C-J_bKFc,938
76
76
  datachain/query/params.py,sha256=O_j89mjYRLOwWNhYZl-z7mi-rkdP7WyFmaDufsdTryE,863
77
77
  datachain/query/queue.py,sha256=waqM_KzavU8C-G95-4211Nd4GXna_u2747Chgwtgz2w,3839
78
78
  datachain/query/schema.py,sha256=I8zLWJuWl5N332ni9mAzDYtcxMJupVPgWkSDe8spNEk,8019
79
- datachain/query/session.py,sha256=UPH5Z4fzCDsvj81ji0e8GA6Mgra3bOAEpVq4htqOtis,4317
80
- datachain/query/udf.py,sha256=HB2hbEuiGA4ch9P2mh9iLA5Jj9mRj-4JFy9VfjTLJ8U,3622
79
+ datachain/query/session.py,sha256=kpFFJMfWBnxaMPojMGhJRbk-BOsSYI8Ckl6vvqnx7d0,5787
81
80
  datachain/remote/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
82
81
  datachain/remote/studio.py,sha256=f5s6qSZ9uB4URGUoU_8_W1KZRRQQVSm6cgEBkBUEfuE,7226
83
82
  datachain/sql/__init__.py,sha256=A2djrbQwSMUZZEIKGnm-mnRA-NDSbiDJNpAmmwGNyIo,303
@@ -91,15 +90,15 @@ datachain/sql/functions/array.py,sha256=EB7nJSncUc1PuxlHyzU2gVhF8DuXaxpGlxb5e8X2
91
90
  datachain/sql/functions/conditional.py,sha256=q7YUKfunXeEldXaxgT-p5pUTcOEVU_tcQ2BJlquTRPs,207
92
91
  datachain/sql/functions/path.py,sha256=zixpERotTFP6LZ7I4TiGtyRA8kXOoZmH1yzH9oRW0mg,1294
93
92
  datachain/sql/functions/random.py,sha256=vBwEEj98VH4LjWixUCygQ5Bz1mv1nohsCG0-ZTELlVg,271
94
- datachain/sql/functions/string.py,sha256=NSQIpmtQgm68hz3TFJsgHMBuo4MjBNhDSyEIC3pWkT8,916
93
+ datachain/sql/functions/string.py,sha256=DYgiw8XSk7ge7GXvyRI1zbaMruIizNeI-puOjriQGZQ,1148
95
94
  datachain/sql/sqlite/__init__.py,sha256=TAdJX0Bg28XdqPO-QwUVKy8rg78cgMileHvMNot7d04,166
96
- datachain/sql/sqlite/base.py,sha256=WLPHBhZbXbiqPoRV1VgDrXJqku4UuvJpBhYeQ0k5rI8,13364
95
+ datachain/sql/sqlite/base.py,sha256=3gDMLKSWkxnbiZ1dykYa5VuHSSlg5sLY9ihMqcH_o1M,13578
97
96
  datachain/sql/sqlite/types.py,sha256=yzvp0sXSEoEYXs6zaYC_2YubarQoZH-MiUNXcpuEP4s,1573
98
97
  datachain/sql/sqlite/vector.py,sha256=ncW4eu2FlJhrP_CIpsvtkUabZlQdl2D5Lgwy_cbfqR0,469
99
98
  datachain/torch/__init__.py,sha256=gIS74PoEPy4TB3X6vx9nLO0Y3sLJzsA8ckn8pRWihJM,579
100
- datachain-0.5.0.dist-info/LICENSE,sha256=8DnqK5yoPI_E50bEg_zsHKZHY2HqPy4rYN338BHQaRA,11344
101
- datachain-0.5.0.dist-info/METADATA,sha256=tKSZNiHZY0WJ_w6irkpSF7qDfuOTfiYNEQ6St3eBs-M,17156
102
- datachain-0.5.0.dist-info/WHEEL,sha256=GV9aMThwP_4oNCtvEC2ec3qUYutgWeAzklro_0m4WJQ,91
103
- datachain-0.5.0.dist-info/entry_points.txt,sha256=0GMJS6B_KWq0m3VT98vQI2YZodAMkn4uReZ_okga9R4,49
104
- datachain-0.5.0.dist-info/top_level.txt,sha256=lZPpdU_2jJABLNIg2kvEOBi8PtsYikbN1OdMLHk8bTg,10
105
- datachain-0.5.0.dist-info/RECORD,,
99
+ datachain-0.6.0.dist-info/LICENSE,sha256=8DnqK5yoPI_E50bEg_zsHKZHY2HqPy4rYN338BHQaRA,11344
100
+ datachain-0.6.0.dist-info/METADATA,sha256=4nxP9eUg6o9ymkwy-hz4DsqRM5IBtqhInNE7vsE0lxY,17156
101
+ datachain-0.6.0.dist-info/WHEEL,sha256=GV9aMThwP_4oNCtvEC2ec3qUYutgWeAzklro_0m4WJQ,91
102
+ datachain-0.6.0.dist-info/entry_points.txt,sha256=0GMJS6B_KWq0m3VT98vQI2YZodAMkn4uReZ_okga9R4,49
103
+ datachain-0.6.0.dist-info/top_level.txt,sha256=lZPpdU_2jJABLNIg2kvEOBi8PtsYikbN1OdMLHk8bTg,10
104
+ datachain-0.6.0.dist-info/RECORD,,
datachain/query/udf.py DELETED
@@ -1,126 +0,0 @@
1
- import typing
2
- from collections.abc import Iterable, Iterator, Sequence
3
- from dataclasses import dataclass
4
- from typing import (
5
- TYPE_CHECKING,
6
- Any,
7
- )
8
-
9
- from fsspec.callbacks import DEFAULT_CALLBACK, Callback
10
-
11
- from datachain.dataset import RowDict
12
-
13
- from .batch import (
14
- Batch,
15
- BatchingStrategy,
16
- NoBatching,
17
- Partition,
18
- RowsOutputBatch,
19
- UDFInputBatch,
20
- )
21
- from .schema import UDFParameter
22
-
23
- if TYPE_CHECKING:
24
- from datachain.catalog import Catalog
25
-
26
- from .batch import RowsOutput, UDFInput
27
-
28
- ColumnType = Any
29
-
30
-
31
- # Specification for the output of a UDF
32
- UDFOutputSpec = typing.Mapping[str, ColumnType]
33
-
34
- # Result type when calling the UDF wrapper around the actual
35
- # Python function / class implementing it.
36
- UDFResult = dict[str, Any]
37
-
38
-
39
- @dataclass
40
- class UDFProperties:
41
- """Container for basic UDF properties."""
42
-
43
- params: list[UDFParameter]
44
- output: UDFOutputSpec
45
- batch: int = 1
46
-
47
- def get_batching(self, use_partitioning: bool = False) -> BatchingStrategy:
48
- if use_partitioning:
49
- return Partition()
50
- if self.batch == 1:
51
- return NoBatching()
52
- if self.batch > 1:
53
- return Batch(self.batch)
54
- raise ValueError(f"invalid batch size {self.batch}")
55
-
56
- def signal_names(self) -> Iterable[str]:
57
- return self.output.keys()
58
-
59
-
60
- class UDFBase:
61
- """A base class for implementing stateful UDFs."""
62
-
63
- def __init__(
64
- self,
65
- properties: UDFProperties,
66
- ):
67
- self.properties = properties
68
- self.signal_names = properties.signal_names()
69
- self.output = properties.output
70
-
71
- def run(
72
- self,
73
- udf_fields: "Sequence[str]",
74
- udf_inputs: "Iterable[RowsOutput]",
75
- catalog: "Catalog",
76
- is_generator: bool,
77
- cache: bool,
78
- download_cb: Callback = DEFAULT_CALLBACK,
79
- processed_cb: Callback = DEFAULT_CALLBACK,
80
- ) -> Iterator[Iterable["UDFResult"]]:
81
- for batch in udf_inputs:
82
- if isinstance(batch, RowsOutputBatch):
83
- n_rows = len(batch.rows)
84
- inputs: UDFInput = UDFInputBatch(
85
- [RowDict(zip(udf_fields, row)) for row in batch.rows]
86
- )
87
- else:
88
- n_rows = 1
89
- inputs = RowDict(zip(udf_fields, batch))
90
- output = self.run_once(catalog, inputs, is_generator, cache, cb=download_cb)
91
- processed_cb.relative_update(n_rows)
92
- yield output
93
-
94
- def run_once(
95
- self,
96
- catalog: "Catalog",
97
- arg: "UDFInput",
98
- is_generator: bool = False,
99
- cache: bool = False,
100
- cb: Callback = DEFAULT_CALLBACK,
101
- ) -> Iterable[UDFResult]:
102
- raise NotImplementedError
103
-
104
- def bind_parameters(self, catalog: "Catalog", row: "RowDict", **kwargs) -> list:
105
- return [p.get_value(catalog, row, **kwargs) for p in self.properties.params]
106
-
107
- def _process_results(
108
- self,
109
- rows: Sequence["RowDict"],
110
- results: Sequence[Sequence[Any]],
111
- is_generator=False,
112
- ) -> Iterable[UDFResult]:
113
- """Create a list of dictionaries representing UDF results."""
114
-
115
- # outputting rows
116
- if is_generator:
117
- # each row in results is a tuple of column values
118
- return (dict(zip(self.signal_names, row)) for row in results)
119
-
120
- # outputting signals
121
- row_ids = [row["sys__id"] for row in rows]
122
- return [
123
- {"sys__id": row_id} | dict(zip(self.signal_names, signals))
124
- for row_id, signals in zip(row_ids, results)
125
- if signals is not None # skip rows with no output
126
- ]