datachain 0.5.0__py3-none-any.whl → 0.6.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of datachain might be problematic. Click here for more details.
- datachain/catalog/catalog.py +8 -0
- datachain/data_storage/metastore.py +20 -1
- datachain/data_storage/sqlite.py +24 -32
- datachain/lib/arrow.py +64 -19
- datachain/lib/convert/values_to_tuples.py +2 -2
- datachain/lib/data_model.py +1 -1
- datachain/lib/dc.py +131 -12
- datachain/lib/signal_schema.py +6 -6
- datachain/lib/udf.py +208 -160
- datachain/lib/udf_signature.py +8 -6
- datachain/query/batch.py +0 -10
- datachain/query/dataset.py +7 -7
- datachain/query/dispatch.py +2 -14
- datachain/query/session.py +42 -0
- datachain/sql/functions/string.py +12 -0
- datachain/sql/sqlite/base.py +10 -5
- {datachain-0.5.0.dist-info → datachain-0.6.0.dist-info}/METADATA +1 -1
- {datachain-0.5.0.dist-info → datachain-0.6.0.dist-info}/RECORD +22 -23
- datachain/query/udf.py +0 -126
- {datachain-0.5.0.dist-info → datachain-0.6.0.dist-info}/LICENSE +0 -0
- {datachain-0.5.0.dist-info → datachain-0.6.0.dist-info}/WHEEL +0 -0
- {datachain-0.5.0.dist-info → datachain-0.6.0.dist-info}/entry_points.txt +0 -0
- {datachain-0.5.0.dist-info → datachain-0.6.0.dist-info}/top_level.txt +0 -0
datachain/query/session.py
CHANGED
|
@@ -1,5 +1,8 @@
|
|
|
1
1
|
import atexit
|
|
2
|
+
import logging
|
|
3
|
+
import os
|
|
2
4
|
import re
|
|
5
|
+
import sys
|
|
3
6
|
from typing import TYPE_CHECKING, Optional
|
|
4
7
|
from uuid import uuid4
|
|
5
8
|
|
|
@@ -9,6 +12,8 @@ from datachain.error import TableMissingError
|
|
|
9
12
|
if TYPE_CHECKING:
|
|
10
13
|
from datachain.catalog import Catalog
|
|
11
14
|
|
|
15
|
+
logger = logging.getLogger("datachain")
|
|
16
|
+
|
|
12
17
|
|
|
13
18
|
class Session:
|
|
14
19
|
"""
|
|
@@ -35,6 +40,7 @@ class Session:
|
|
|
35
40
|
|
|
36
41
|
GLOBAL_SESSION_CTX: Optional["Session"] = None
|
|
37
42
|
GLOBAL_SESSION: Optional["Session"] = None
|
|
43
|
+
ORIGINAL_EXCEPT_HOOK = None
|
|
38
44
|
|
|
39
45
|
DATASET_PREFIX = "session_"
|
|
40
46
|
GLOBAL_SESSION_NAME = "global"
|
|
@@ -58,6 +64,7 @@ class Session:
|
|
|
58
64
|
|
|
59
65
|
session_uuid = uuid4().hex[: self.SESSION_UUID_LEN]
|
|
60
66
|
self.name = f"{name}_{session_uuid}"
|
|
67
|
+
self.job_id = os.getenv("DATACHAIN_JOB_ID") or str(uuid4())
|
|
61
68
|
self.is_new_catalog = not catalog
|
|
62
69
|
self.catalog = catalog or get_catalog(
|
|
63
70
|
client_config=client_config, in_memory=in_memory
|
|
@@ -67,6 +74,9 @@ class Session:
|
|
|
67
74
|
return self
|
|
68
75
|
|
|
69
76
|
def __exit__(self, exc_type, exc_val, exc_tb):
|
|
77
|
+
if exc_type:
|
|
78
|
+
self._cleanup_created_versions(self.name)
|
|
79
|
+
|
|
70
80
|
self._cleanup_temp_datasets()
|
|
71
81
|
if self.is_new_catalog:
|
|
72
82
|
self.catalog.metastore.close_on_exit()
|
|
@@ -88,6 +98,21 @@ class Session:
|
|
|
88
98
|
except TableMissingError:
|
|
89
99
|
pass
|
|
90
100
|
|
|
101
|
+
def _cleanup_created_versions(self, job_id: str) -> None:
|
|
102
|
+
versions = self.catalog.metastore.get_job_dataset_versions(job_id)
|
|
103
|
+
if not versions:
|
|
104
|
+
return
|
|
105
|
+
|
|
106
|
+
datasets = {}
|
|
107
|
+
for dataset_name, version in versions:
|
|
108
|
+
if dataset_name not in datasets:
|
|
109
|
+
datasets[dataset_name] = self.catalog.get_dataset(dataset_name)
|
|
110
|
+
dataset = datasets[dataset_name]
|
|
111
|
+
logger.info(
|
|
112
|
+
"Removing dataset version %s@%s due to exception", dataset_name, version
|
|
113
|
+
)
|
|
114
|
+
self.catalog.remove_dataset_version(dataset, version)
|
|
115
|
+
|
|
91
116
|
@classmethod
|
|
92
117
|
def get(
|
|
93
118
|
cls,
|
|
@@ -114,9 +139,23 @@ class Session:
|
|
|
114
139
|
in_memory=in_memory,
|
|
115
140
|
)
|
|
116
141
|
cls.GLOBAL_SESSION = cls.GLOBAL_SESSION_CTX.__enter__()
|
|
142
|
+
|
|
117
143
|
atexit.register(cls._global_cleanup)
|
|
144
|
+
cls.ORIGINAL_EXCEPT_HOOK = sys.excepthook
|
|
145
|
+
sys.excepthook = cls.except_hook
|
|
146
|
+
|
|
118
147
|
return cls.GLOBAL_SESSION
|
|
119
148
|
|
|
149
|
+
@staticmethod
|
|
150
|
+
def except_hook(exc_type, exc_value, exc_traceback):
|
|
151
|
+
Session._global_cleanup()
|
|
152
|
+
if Session.GLOBAL_SESSION_CTX is not None:
|
|
153
|
+
job_id = Session.GLOBAL_SESSION_CTX.job_id
|
|
154
|
+
Session.GLOBAL_SESSION_CTX._cleanup_created_versions(job_id)
|
|
155
|
+
|
|
156
|
+
if Session.ORIGINAL_EXCEPT_HOOK:
|
|
157
|
+
Session.ORIGINAL_EXCEPT_HOOK(exc_type, exc_value, exc_traceback)
|
|
158
|
+
|
|
120
159
|
@classmethod
|
|
121
160
|
def cleanup_for_tests(cls):
|
|
122
161
|
if cls.GLOBAL_SESSION_CTX is not None:
|
|
@@ -125,6 +164,9 @@ class Session:
|
|
|
125
164
|
cls.GLOBAL_SESSION_CTX = None
|
|
126
165
|
atexit.unregister(cls._global_cleanup)
|
|
127
166
|
|
|
167
|
+
if cls.ORIGINAL_EXCEPT_HOOK:
|
|
168
|
+
sys.excepthook = cls.ORIGINAL_EXCEPT_HOOK
|
|
169
|
+
|
|
128
170
|
@staticmethod
|
|
129
171
|
def _global_cleanup():
|
|
130
172
|
if Session.GLOBAL_SESSION_CTX is not None:
|
|
@@ -37,6 +37,18 @@ class regexp_replace(GenericFunction): # noqa: N801
|
|
|
37
37
|
inherit_cache = True
|
|
38
38
|
|
|
39
39
|
|
|
40
|
+
class replace(GenericFunction): # noqa: N801
|
|
41
|
+
"""
|
|
42
|
+
Replaces substring with another string.
|
|
43
|
+
"""
|
|
44
|
+
|
|
45
|
+
type = String()
|
|
46
|
+
package = "string"
|
|
47
|
+
name = "replace"
|
|
48
|
+
inherit_cache = True
|
|
49
|
+
|
|
50
|
+
|
|
40
51
|
compiler_not_implemented(length)
|
|
41
52
|
compiler_not_implemented(split)
|
|
42
53
|
compiler_not_implemented(regexp_replace)
|
|
54
|
+
compiler_not_implemented(replace)
|
datachain/sql/sqlite/base.py
CHANGED
|
@@ -78,7 +78,8 @@ def setup():
|
|
|
78
78
|
compiles(array.length, "sqlite")(compile_array_length)
|
|
79
79
|
compiles(string.length, "sqlite")(compile_string_length)
|
|
80
80
|
compiles(string.split, "sqlite")(compile_string_split)
|
|
81
|
-
compiles(string.regexp_replace, "sqlite")(
|
|
81
|
+
compiles(string.regexp_replace, "sqlite")(compile_string_regexp_replace)
|
|
82
|
+
compiles(string.replace, "sqlite")(compile_string_replace)
|
|
82
83
|
compiles(conditional.greatest, "sqlite")(compile_greatest)
|
|
83
84
|
compiles(conditional.least, "sqlite")(compile_least)
|
|
84
85
|
compiles(Values, "sqlite")(compile_values)
|
|
@@ -273,10 +274,6 @@ def path_file_ext(path):
|
|
|
273
274
|
return func.substr(path, func.length(path) - path_file_ext_length(path) + 1)
|
|
274
275
|
|
|
275
276
|
|
|
276
|
-
def compile_regexp_replace(element, compiler, **kwargs):
|
|
277
|
-
return f"regexp_replace({compiler.process(element.clauses, **kwargs)})"
|
|
278
|
-
|
|
279
|
-
|
|
280
277
|
def compile_path_parent(element, compiler, **kwargs):
|
|
281
278
|
return compiler.process(path_parent(*element.clauses.clauses), **kwargs)
|
|
282
279
|
|
|
@@ -331,6 +328,14 @@ def compile_string_split(element, compiler, **kwargs):
|
|
|
331
328
|
return compiler.process(func.split(*element.clauses.clauses), **kwargs)
|
|
332
329
|
|
|
333
330
|
|
|
331
|
+
def compile_string_regexp_replace(element, compiler, **kwargs):
|
|
332
|
+
return f"regexp_replace({compiler.process(element.clauses, **kwargs)})"
|
|
333
|
+
|
|
334
|
+
|
|
335
|
+
def compile_string_replace(element, compiler, **kwargs):
|
|
336
|
+
return compiler.process(func.replace(*element.clauses.clauses), **kwargs)
|
|
337
|
+
|
|
338
|
+
|
|
334
339
|
def compile_greatest(element, compiler, **kwargs):
|
|
335
340
|
"""
|
|
336
341
|
Compiles a sql function for `greatest(*args)` taking 1 or more args
|
|
@@ -18,7 +18,7 @@ datachain/storage.py,sha256=RiSJLYdHUjnrEWkLBKPcETHpAxld_B2WxLg711t0aZI,3733
|
|
|
18
18
|
datachain/telemetry.py,sha256=0A4IOPPp9VlP5pyW9eBfaTK3YhHGzHl7dQudQjUAx9A,994
|
|
19
19
|
datachain/utils.py,sha256=KeFSRHsiYthnTu4a6bH-rw04mX1m8krTX0f2NqfQGFI,12114
|
|
20
20
|
datachain/catalog/__init__.py,sha256=g2iAAFx_gEIrqshXlhSEbrc8qDaEH11cjU40n3CHDz4,409
|
|
21
|
-
datachain/catalog/catalog.py,sha256=
|
|
21
|
+
datachain/catalog/catalog.py,sha256=BsMyk2RQibQYHgrmovFZeSEpPVMTwgb_7ntVYdc7t-E,64090
|
|
22
22
|
datachain/catalog/datasource.py,sha256=D-VWIVDCM10A8sQavLhRXdYSCG7F4o4ifswEF80_NAQ,1412
|
|
23
23
|
datachain/catalog/loader.py,sha256=-6VelNfXUdgUnwInVyA8g86Boxv2xqhTh9xNS-Zlwig,8242
|
|
24
24
|
datachain/client/__init__.py,sha256=T4wiYL9KIM0ZZ_UqIyzV8_ufzYlewmizlV4iymHNluE,86
|
|
@@ -33,17 +33,17 @@ datachain/data_storage/__init__.py,sha256=cEOJpyu1JDZtfUupYucCDNFI6e5Wmp_Oyzq6rZ
|
|
|
33
33
|
datachain/data_storage/db_engine.py,sha256=81Ol1of9TTTzD97ORajCnP366Xz2mEJt6C-kTUCaru4,3406
|
|
34
34
|
datachain/data_storage/id_generator.py,sha256=lCEoU0BM37Ai2aRpSbwo5oQT0GqZnSpYwwvizathRMQ,4292
|
|
35
35
|
datachain/data_storage/job.py,sha256=w-7spowjkOa1P5fUVtJou3OltT0L48P0RYWZ9rSJ9-s,383
|
|
36
|
-
datachain/data_storage/metastore.py,sha256=
|
|
36
|
+
datachain/data_storage/metastore.py,sha256=HfCxk4lmDUg2Q4WsFNQGMWxllP0mToA00fxkFTwdNIE,52919
|
|
37
37
|
datachain/data_storage/schema.py,sha256=AGbjyEir5UmRZXI3m0jChZogUh5wd8csj6-YlUWaAxQ,8383
|
|
38
38
|
datachain/data_storage/serializer.py,sha256=6G2YtOFqqDzJf1KbvZraKGXl2XHZyVml2krunWUum5o,927
|
|
39
|
-
datachain/data_storage/sqlite.py,sha256=
|
|
39
|
+
datachain/data_storage/sqlite.py,sha256=fW08P7AbJ0cDbTbcTKuAGpvMXvBjg-QkGsKT_Dslyws,28383
|
|
40
40
|
datachain/data_storage/warehouse.py,sha256=fXhVfao3NfWFGbbG5uJ-Ga4bX1FiKVfcbDyQgECYfk8,32122
|
|
41
41
|
datachain/lib/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
42
|
-
datachain/lib/arrow.py,sha256=
|
|
42
|
+
datachain/lib/arrow.py,sha256=0R2CYsN82nNa5_03iS6jVix9EKeeqNZNAMgpSQP2hfo,9482
|
|
43
43
|
datachain/lib/clip.py,sha256=lm5CzVi4Cj1jVLEKvERKArb-egb9j1Ls-fwTItT6vlI,6150
|
|
44
|
-
datachain/lib/data_model.py,sha256=
|
|
44
|
+
datachain/lib/data_model.py,sha256=ECTbvlnzM98hp2mZ4fo82Yi0-MuoqTIQasQKGIyd89I,2040
|
|
45
45
|
datachain/lib/dataset_info.py,sha256=srPPhI2UHf6hFPBecyFEVw2SS5aPisIIMsvGgKqi7ss,2366
|
|
46
|
-
datachain/lib/dc.py,sha256=
|
|
46
|
+
datachain/lib/dc.py,sha256=XmAFU9k79wUHIh0gYab8j-wF4vIlyW6opJcOy8fmoVc,76666
|
|
47
47
|
datachain/lib/file.py,sha256=LjTW_-PDAnoUhvyB4bJ8Y8n__XGqrxvmd9mDOF0Gir8,14875
|
|
48
48
|
datachain/lib/hf.py,sha256=cPnmLuprr0pYABH7KqA5FARQ1JGlywdDwD3yDzVAm4k,5920
|
|
49
49
|
datachain/lib/image.py,sha256=AMXYwQsmarZjRbPCZY3M1jDsM2WAB_b3cTY4uOIuXNU,2675
|
|
@@ -53,11 +53,11 @@ datachain/lib/meta_formats.py,sha256=3f-0vpMTesagS9iMd3y9-u9r-7g0eqYsxmK4fVfNWlw
|
|
|
53
53
|
datachain/lib/model_store.py,sha256=DNIv8Y6Jtk1_idNLzIpsThOsdW2BMAudyUCbPUcgcxk,2515
|
|
54
54
|
datachain/lib/pytorch.py,sha256=W-ARi2xH1f1DUkVfRuerW-YWYgSaJASmNCxtz2lrJGI,6072
|
|
55
55
|
datachain/lib/settings.py,sha256=39thOpYJw-zPirzeNO6pmRC2vPrQvt4eBsw1xLWDFsw,2344
|
|
56
|
-
datachain/lib/signal_schema.py,sha256=
|
|
56
|
+
datachain/lib/signal_schema.py,sha256=gj45dRQuOsKDmaKaJxb5j63HYVGw-Ks1fyAS1FpyOWA,24145
|
|
57
57
|
datachain/lib/tar.py,sha256=3WIzao6yD5fbLqXLTt9GhPGNonbFIs_fDRu-9vgLgsA,1038
|
|
58
58
|
datachain/lib/text.py,sha256=UNHm8fhidk7wdrWqacEWaA6I9ykfYqarQ2URby7jc7M,1261
|
|
59
|
-
datachain/lib/udf.py,sha256=
|
|
60
|
-
datachain/lib/udf_signature.py,sha256=
|
|
59
|
+
datachain/lib/udf.py,sha256=GvhWLCXZUY7sz1QMRBj1AJDSzzhyj15xs3Ia9hjJrJE,12697
|
|
60
|
+
datachain/lib/udf_signature.py,sha256=GXw24A-Olna6DWCdgy2bC-gZh_gLGPQ-KvjuI6pUjC0,7281
|
|
61
61
|
datachain/lib/utils.py,sha256=5-kJlAZE0D9nXXweAjo7-SP_AWGo28feaDByONYaooQ,463
|
|
62
62
|
datachain/lib/vfile.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
63
63
|
datachain/lib/webdataset.py,sha256=o7SHk5HOUWsZ5Ln04xOM04eQqiBHiJNO7xLgyVBrwo8,6924
|
|
@@ -67,17 +67,16 @@ datachain/lib/convert/flatten.py,sha256=Uebc5CeqCsacp-nr6IG9i6OGuUavXqdqnoGctZBk
|
|
|
67
67
|
datachain/lib/convert/python_to_sql.py,sha256=40SAOdoOgikZRhn8iomCPDRoxC3RFxjJLivEAA9MHDU,2880
|
|
68
68
|
datachain/lib/convert/sql_to_python.py,sha256=lGnKzSF_tz9Y_5SSKkrIU95QEjpcDzvOxIRkEKTQag0,443
|
|
69
69
|
datachain/lib/convert/unflatten.py,sha256=Ogvh_5wg2f38_At_1lN0D_e2uZOOpYEvwvB2xdq56Tw,2012
|
|
70
|
-
datachain/lib/convert/values_to_tuples.py,sha256=
|
|
70
|
+
datachain/lib/convert/values_to_tuples.py,sha256=varRCnSMT_pZmHznrd2Yi05qXLLz_v9YH_pOCpHSkdc,3921
|
|
71
71
|
datachain/query/__init__.py,sha256=0NBOZVgIDpCcj1Ci883dQ9A0iiwe03xzmotkOCFbxYc,293
|
|
72
|
-
datachain/query/batch.py,sha256
|
|
73
|
-
datachain/query/dataset.py,sha256=
|
|
74
|
-
datachain/query/dispatch.py,sha256=
|
|
72
|
+
datachain/query/batch.py,sha256=3QlwshhpUc1amZRtXWVXEEuq47hEQgQlY0Ji48DR6hg,3508
|
|
73
|
+
datachain/query/dataset.py,sha256=MF_E7yjbFQV6NcP4gKbJFXiWuoQkpQ7-Jmxa59FxenE,53630
|
|
74
|
+
datachain/query/dispatch.py,sha256=wjjTWw6sFQbB9SKRh78VbfvwSMgJXCfqJklS3-9KnCU,12025
|
|
75
75
|
datachain/query/metrics.py,sha256=r5b0ygYhokbXp8Mg3kCH8iFSRw0jxzyeBe-C-J_bKFc,938
|
|
76
76
|
datachain/query/params.py,sha256=O_j89mjYRLOwWNhYZl-z7mi-rkdP7WyFmaDufsdTryE,863
|
|
77
77
|
datachain/query/queue.py,sha256=waqM_KzavU8C-G95-4211Nd4GXna_u2747Chgwtgz2w,3839
|
|
78
78
|
datachain/query/schema.py,sha256=I8zLWJuWl5N332ni9mAzDYtcxMJupVPgWkSDe8spNEk,8019
|
|
79
|
-
datachain/query/session.py,sha256=
|
|
80
|
-
datachain/query/udf.py,sha256=HB2hbEuiGA4ch9P2mh9iLA5Jj9mRj-4JFy9VfjTLJ8U,3622
|
|
79
|
+
datachain/query/session.py,sha256=kpFFJMfWBnxaMPojMGhJRbk-BOsSYI8Ckl6vvqnx7d0,5787
|
|
81
80
|
datachain/remote/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
82
81
|
datachain/remote/studio.py,sha256=f5s6qSZ9uB4URGUoU_8_W1KZRRQQVSm6cgEBkBUEfuE,7226
|
|
83
82
|
datachain/sql/__init__.py,sha256=A2djrbQwSMUZZEIKGnm-mnRA-NDSbiDJNpAmmwGNyIo,303
|
|
@@ -91,15 +90,15 @@ datachain/sql/functions/array.py,sha256=EB7nJSncUc1PuxlHyzU2gVhF8DuXaxpGlxb5e8X2
|
|
|
91
90
|
datachain/sql/functions/conditional.py,sha256=q7YUKfunXeEldXaxgT-p5pUTcOEVU_tcQ2BJlquTRPs,207
|
|
92
91
|
datachain/sql/functions/path.py,sha256=zixpERotTFP6LZ7I4TiGtyRA8kXOoZmH1yzH9oRW0mg,1294
|
|
93
92
|
datachain/sql/functions/random.py,sha256=vBwEEj98VH4LjWixUCygQ5Bz1mv1nohsCG0-ZTELlVg,271
|
|
94
|
-
datachain/sql/functions/string.py,sha256=
|
|
93
|
+
datachain/sql/functions/string.py,sha256=DYgiw8XSk7ge7GXvyRI1zbaMruIizNeI-puOjriQGZQ,1148
|
|
95
94
|
datachain/sql/sqlite/__init__.py,sha256=TAdJX0Bg28XdqPO-QwUVKy8rg78cgMileHvMNot7d04,166
|
|
96
|
-
datachain/sql/sqlite/base.py,sha256=
|
|
95
|
+
datachain/sql/sqlite/base.py,sha256=3gDMLKSWkxnbiZ1dykYa5VuHSSlg5sLY9ihMqcH_o1M,13578
|
|
97
96
|
datachain/sql/sqlite/types.py,sha256=yzvp0sXSEoEYXs6zaYC_2YubarQoZH-MiUNXcpuEP4s,1573
|
|
98
97
|
datachain/sql/sqlite/vector.py,sha256=ncW4eu2FlJhrP_CIpsvtkUabZlQdl2D5Lgwy_cbfqR0,469
|
|
99
98
|
datachain/torch/__init__.py,sha256=gIS74PoEPy4TB3X6vx9nLO0Y3sLJzsA8ckn8pRWihJM,579
|
|
100
|
-
datachain-0.
|
|
101
|
-
datachain-0.
|
|
102
|
-
datachain-0.
|
|
103
|
-
datachain-0.
|
|
104
|
-
datachain-0.
|
|
105
|
-
datachain-0.
|
|
99
|
+
datachain-0.6.0.dist-info/LICENSE,sha256=8DnqK5yoPI_E50bEg_zsHKZHY2HqPy4rYN338BHQaRA,11344
|
|
100
|
+
datachain-0.6.0.dist-info/METADATA,sha256=4nxP9eUg6o9ymkwy-hz4DsqRM5IBtqhInNE7vsE0lxY,17156
|
|
101
|
+
datachain-0.6.0.dist-info/WHEEL,sha256=GV9aMThwP_4oNCtvEC2ec3qUYutgWeAzklro_0m4WJQ,91
|
|
102
|
+
datachain-0.6.0.dist-info/entry_points.txt,sha256=0GMJS6B_KWq0m3VT98vQI2YZodAMkn4uReZ_okga9R4,49
|
|
103
|
+
datachain-0.6.0.dist-info/top_level.txt,sha256=lZPpdU_2jJABLNIg2kvEOBi8PtsYikbN1OdMLHk8bTg,10
|
|
104
|
+
datachain-0.6.0.dist-info/RECORD,,
|
datachain/query/udf.py
DELETED
|
@@ -1,126 +0,0 @@
|
|
|
1
|
-
import typing
|
|
2
|
-
from collections.abc import Iterable, Iterator, Sequence
|
|
3
|
-
from dataclasses import dataclass
|
|
4
|
-
from typing import (
|
|
5
|
-
TYPE_CHECKING,
|
|
6
|
-
Any,
|
|
7
|
-
)
|
|
8
|
-
|
|
9
|
-
from fsspec.callbacks import DEFAULT_CALLBACK, Callback
|
|
10
|
-
|
|
11
|
-
from datachain.dataset import RowDict
|
|
12
|
-
|
|
13
|
-
from .batch import (
|
|
14
|
-
Batch,
|
|
15
|
-
BatchingStrategy,
|
|
16
|
-
NoBatching,
|
|
17
|
-
Partition,
|
|
18
|
-
RowsOutputBatch,
|
|
19
|
-
UDFInputBatch,
|
|
20
|
-
)
|
|
21
|
-
from .schema import UDFParameter
|
|
22
|
-
|
|
23
|
-
if TYPE_CHECKING:
|
|
24
|
-
from datachain.catalog import Catalog
|
|
25
|
-
|
|
26
|
-
from .batch import RowsOutput, UDFInput
|
|
27
|
-
|
|
28
|
-
ColumnType = Any
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
# Specification for the output of a UDF
|
|
32
|
-
UDFOutputSpec = typing.Mapping[str, ColumnType]
|
|
33
|
-
|
|
34
|
-
# Result type when calling the UDF wrapper around the actual
|
|
35
|
-
# Python function / class implementing it.
|
|
36
|
-
UDFResult = dict[str, Any]
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
@dataclass
|
|
40
|
-
class UDFProperties:
|
|
41
|
-
"""Container for basic UDF properties."""
|
|
42
|
-
|
|
43
|
-
params: list[UDFParameter]
|
|
44
|
-
output: UDFOutputSpec
|
|
45
|
-
batch: int = 1
|
|
46
|
-
|
|
47
|
-
def get_batching(self, use_partitioning: bool = False) -> BatchingStrategy:
|
|
48
|
-
if use_partitioning:
|
|
49
|
-
return Partition()
|
|
50
|
-
if self.batch == 1:
|
|
51
|
-
return NoBatching()
|
|
52
|
-
if self.batch > 1:
|
|
53
|
-
return Batch(self.batch)
|
|
54
|
-
raise ValueError(f"invalid batch size {self.batch}")
|
|
55
|
-
|
|
56
|
-
def signal_names(self) -> Iterable[str]:
|
|
57
|
-
return self.output.keys()
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
class UDFBase:
|
|
61
|
-
"""A base class for implementing stateful UDFs."""
|
|
62
|
-
|
|
63
|
-
def __init__(
|
|
64
|
-
self,
|
|
65
|
-
properties: UDFProperties,
|
|
66
|
-
):
|
|
67
|
-
self.properties = properties
|
|
68
|
-
self.signal_names = properties.signal_names()
|
|
69
|
-
self.output = properties.output
|
|
70
|
-
|
|
71
|
-
def run(
|
|
72
|
-
self,
|
|
73
|
-
udf_fields: "Sequence[str]",
|
|
74
|
-
udf_inputs: "Iterable[RowsOutput]",
|
|
75
|
-
catalog: "Catalog",
|
|
76
|
-
is_generator: bool,
|
|
77
|
-
cache: bool,
|
|
78
|
-
download_cb: Callback = DEFAULT_CALLBACK,
|
|
79
|
-
processed_cb: Callback = DEFAULT_CALLBACK,
|
|
80
|
-
) -> Iterator[Iterable["UDFResult"]]:
|
|
81
|
-
for batch in udf_inputs:
|
|
82
|
-
if isinstance(batch, RowsOutputBatch):
|
|
83
|
-
n_rows = len(batch.rows)
|
|
84
|
-
inputs: UDFInput = UDFInputBatch(
|
|
85
|
-
[RowDict(zip(udf_fields, row)) for row in batch.rows]
|
|
86
|
-
)
|
|
87
|
-
else:
|
|
88
|
-
n_rows = 1
|
|
89
|
-
inputs = RowDict(zip(udf_fields, batch))
|
|
90
|
-
output = self.run_once(catalog, inputs, is_generator, cache, cb=download_cb)
|
|
91
|
-
processed_cb.relative_update(n_rows)
|
|
92
|
-
yield output
|
|
93
|
-
|
|
94
|
-
def run_once(
|
|
95
|
-
self,
|
|
96
|
-
catalog: "Catalog",
|
|
97
|
-
arg: "UDFInput",
|
|
98
|
-
is_generator: bool = False,
|
|
99
|
-
cache: bool = False,
|
|
100
|
-
cb: Callback = DEFAULT_CALLBACK,
|
|
101
|
-
) -> Iterable[UDFResult]:
|
|
102
|
-
raise NotImplementedError
|
|
103
|
-
|
|
104
|
-
def bind_parameters(self, catalog: "Catalog", row: "RowDict", **kwargs) -> list:
|
|
105
|
-
return [p.get_value(catalog, row, **kwargs) for p in self.properties.params]
|
|
106
|
-
|
|
107
|
-
def _process_results(
|
|
108
|
-
self,
|
|
109
|
-
rows: Sequence["RowDict"],
|
|
110
|
-
results: Sequence[Sequence[Any]],
|
|
111
|
-
is_generator=False,
|
|
112
|
-
) -> Iterable[UDFResult]:
|
|
113
|
-
"""Create a list of dictionaries representing UDF results."""
|
|
114
|
-
|
|
115
|
-
# outputting rows
|
|
116
|
-
if is_generator:
|
|
117
|
-
# each row in results is a tuple of column values
|
|
118
|
-
return (dict(zip(self.signal_names, row)) for row in results)
|
|
119
|
-
|
|
120
|
-
# outputting signals
|
|
121
|
-
row_ids = [row["sys__id"] for row in rows]
|
|
122
|
-
return [
|
|
123
|
-
{"sys__id": row_id} | dict(zip(self.signal_names, signals))
|
|
124
|
-
for row_id, signals in zip(row_ids, results)
|
|
125
|
-
if signals is not None # skip rows with no output
|
|
126
|
-
]
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|