singlestoredb 1.16.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- singlestoredb/__init__.py +75 -0
- singlestoredb/ai/__init__.py +2 -0
- singlestoredb/ai/chat.py +139 -0
- singlestoredb/ai/embeddings.py +128 -0
- singlestoredb/alchemy/__init__.py +90 -0
- singlestoredb/apps/__init__.py +3 -0
- singlestoredb/apps/_cloud_functions.py +90 -0
- singlestoredb/apps/_config.py +72 -0
- singlestoredb/apps/_connection_info.py +18 -0
- singlestoredb/apps/_dashboards.py +47 -0
- singlestoredb/apps/_process.py +32 -0
- singlestoredb/apps/_python_udfs.py +100 -0
- singlestoredb/apps/_stdout_supress.py +30 -0
- singlestoredb/apps/_uvicorn_util.py +36 -0
- singlestoredb/auth.py +245 -0
- singlestoredb/config.py +484 -0
- singlestoredb/connection.py +1487 -0
- singlestoredb/converters.py +950 -0
- singlestoredb/docstring/__init__.py +33 -0
- singlestoredb/docstring/attrdoc.py +126 -0
- singlestoredb/docstring/common.py +230 -0
- singlestoredb/docstring/epydoc.py +267 -0
- singlestoredb/docstring/google.py +412 -0
- singlestoredb/docstring/numpydoc.py +562 -0
- singlestoredb/docstring/parser.py +100 -0
- singlestoredb/docstring/py.typed +1 -0
- singlestoredb/docstring/rest.py +256 -0
- singlestoredb/docstring/tests/__init__.py +1 -0
- singlestoredb/docstring/tests/_pydoctor.py +21 -0
- singlestoredb/docstring/tests/test_epydoc.py +729 -0
- singlestoredb/docstring/tests/test_google.py +1007 -0
- singlestoredb/docstring/tests/test_numpydoc.py +1100 -0
- singlestoredb/docstring/tests/test_parse_from_object.py +109 -0
- singlestoredb/docstring/tests/test_parser.py +248 -0
- singlestoredb/docstring/tests/test_rest.py +547 -0
- singlestoredb/docstring/tests/test_util.py +70 -0
- singlestoredb/docstring/util.py +141 -0
- singlestoredb/exceptions.py +120 -0
- singlestoredb/functions/__init__.py +16 -0
- singlestoredb/functions/decorator.py +201 -0
- singlestoredb/functions/dtypes.py +1793 -0
- singlestoredb/functions/ext/__init__.py +1 -0
- singlestoredb/functions/ext/arrow.py +375 -0
- singlestoredb/functions/ext/asgi.py +2133 -0
- singlestoredb/functions/ext/json.py +420 -0
- singlestoredb/functions/ext/mmap.py +413 -0
- singlestoredb/functions/ext/rowdat_1.py +724 -0
- singlestoredb/functions/ext/timer.py +89 -0
- singlestoredb/functions/ext/utils.py +218 -0
- singlestoredb/functions/signature.py +1578 -0
- singlestoredb/functions/typing/__init__.py +41 -0
- singlestoredb/functions/typing/numpy.py +20 -0
- singlestoredb/functions/typing/pandas.py +2 -0
- singlestoredb/functions/typing/polars.py +2 -0
- singlestoredb/functions/typing/pyarrow.py +2 -0
- singlestoredb/functions/utils.py +421 -0
- singlestoredb/fusion/__init__.py +11 -0
- singlestoredb/fusion/graphql.py +213 -0
- singlestoredb/fusion/handler.py +916 -0
- singlestoredb/fusion/handlers/__init__.py +0 -0
- singlestoredb/fusion/handlers/export.py +525 -0
- singlestoredb/fusion/handlers/files.py +690 -0
- singlestoredb/fusion/handlers/job.py +660 -0
- singlestoredb/fusion/handlers/models.py +250 -0
- singlestoredb/fusion/handlers/stage.py +502 -0
- singlestoredb/fusion/handlers/utils.py +324 -0
- singlestoredb/fusion/handlers/workspace.py +956 -0
- singlestoredb/fusion/registry.py +249 -0
- singlestoredb/fusion/result.py +399 -0
- singlestoredb/http/__init__.py +27 -0
- singlestoredb/http/connection.py +1267 -0
- singlestoredb/magics/__init__.py +34 -0
- singlestoredb/magics/run_personal.py +137 -0
- singlestoredb/magics/run_shared.py +134 -0
- singlestoredb/management/__init__.py +9 -0
- singlestoredb/management/billing_usage.py +148 -0
- singlestoredb/management/cluster.py +462 -0
- singlestoredb/management/export.py +295 -0
- singlestoredb/management/files.py +1102 -0
- singlestoredb/management/inference_api.py +105 -0
- singlestoredb/management/job.py +887 -0
- singlestoredb/management/manager.py +373 -0
- singlestoredb/management/organization.py +226 -0
- singlestoredb/management/region.py +169 -0
- singlestoredb/management/utils.py +423 -0
- singlestoredb/management/workspace.py +1927 -0
- singlestoredb/mysql/__init__.py +177 -0
- singlestoredb/mysql/_auth.py +298 -0
- singlestoredb/mysql/charset.py +214 -0
- singlestoredb/mysql/connection.py +2032 -0
- singlestoredb/mysql/constants/CLIENT.py +38 -0
- singlestoredb/mysql/constants/COMMAND.py +32 -0
- singlestoredb/mysql/constants/CR.py +78 -0
- singlestoredb/mysql/constants/ER.py +474 -0
- singlestoredb/mysql/constants/EXTENDED_TYPE.py +3 -0
- singlestoredb/mysql/constants/FIELD_TYPE.py +48 -0
- singlestoredb/mysql/constants/FLAG.py +15 -0
- singlestoredb/mysql/constants/SERVER_STATUS.py +10 -0
- singlestoredb/mysql/constants/VECTOR_TYPE.py +6 -0
- singlestoredb/mysql/constants/__init__.py +0 -0
- singlestoredb/mysql/converters.py +271 -0
- singlestoredb/mysql/cursors.py +896 -0
- singlestoredb/mysql/err.py +92 -0
- singlestoredb/mysql/optionfile.py +20 -0
- singlestoredb/mysql/protocol.py +450 -0
- singlestoredb/mysql/tests/__init__.py +19 -0
- singlestoredb/mysql/tests/base.py +126 -0
- singlestoredb/mysql/tests/conftest.py +37 -0
- singlestoredb/mysql/tests/test_DictCursor.py +132 -0
- singlestoredb/mysql/tests/test_SSCursor.py +141 -0
- singlestoredb/mysql/tests/test_basic.py +452 -0
- singlestoredb/mysql/tests/test_connection.py +851 -0
- singlestoredb/mysql/tests/test_converters.py +58 -0
- singlestoredb/mysql/tests/test_cursor.py +141 -0
- singlestoredb/mysql/tests/test_err.py +16 -0
- singlestoredb/mysql/tests/test_issues.py +514 -0
- singlestoredb/mysql/tests/test_load_local.py +75 -0
- singlestoredb/mysql/tests/test_nextset.py +88 -0
- singlestoredb/mysql/tests/test_optionfile.py +27 -0
- singlestoredb/mysql/tests/thirdparty/__init__.py +6 -0
- singlestoredb/mysql/tests/thirdparty/test_MySQLdb/__init__.py +9 -0
- singlestoredb/mysql/tests/thirdparty/test_MySQLdb/capabilities.py +323 -0
- singlestoredb/mysql/tests/thirdparty/test_MySQLdb/dbapi20.py +865 -0
- singlestoredb/mysql/tests/thirdparty/test_MySQLdb/test_MySQLdb_capabilities.py +110 -0
- singlestoredb/mysql/tests/thirdparty/test_MySQLdb/test_MySQLdb_dbapi20.py +224 -0
- singlestoredb/mysql/tests/thirdparty/test_MySQLdb/test_MySQLdb_nonstandard.py +101 -0
- singlestoredb/mysql/times.py +23 -0
- singlestoredb/notebook/__init__.py +16 -0
- singlestoredb/notebook/_objects.py +213 -0
- singlestoredb/notebook/_portal.py +352 -0
- singlestoredb/py.typed +0 -0
- singlestoredb/pytest.py +352 -0
- singlestoredb/server/__init__.py +0 -0
- singlestoredb/server/docker.py +452 -0
- singlestoredb/server/free_tier.py +267 -0
- singlestoredb/tests/__init__.py +0 -0
- singlestoredb/tests/alltypes.sql +307 -0
- singlestoredb/tests/alltypes_no_nulls.sql +208 -0
- singlestoredb/tests/empty.sql +0 -0
- singlestoredb/tests/ext_funcs/__init__.py +702 -0
- singlestoredb/tests/local_infile.csv +3 -0
- singlestoredb/tests/test.ipynb +18 -0
- singlestoredb/tests/test.sql +680 -0
- singlestoredb/tests/test2.ipynb +18 -0
- singlestoredb/tests/test2.sql +1 -0
- singlestoredb/tests/test_basics.py +1332 -0
- singlestoredb/tests/test_config.py +318 -0
- singlestoredb/tests/test_connection.py +3103 -0
- singlestoredb/tests/test_dbapi.py +27 -0
- singlestoredb/tests/test_exceptions.py +45 -0
- singlestoredb/tests/test_ext_func.py +1472 -0
- singlestoredb/tests/test_ext_func_data.py +1101 -0
- singlestoredb/tests/test_fusion.py +1527 -0
- singlestoredb/tests/test_http.py +288 -0
- singlestoredb/tests/test_management.py +1599 -0
- singlestoredb/tests/test_plugin.py +33 -0
- singlestoredb/tests/test_results.py +171 -0
- singlestoredb/tests/test_types.py +132 -0
- singlestoredb/tests/test_udf.py +737 -0
- singlestoredb/tests/test_udf_returns.py +459 -0
- singlestoredb/tests/test_vectorstore.py +51 -0
- singlestoredb/tests/test_xdict.py +333 -0
- singlestoredb/tests/utils.py +141 -0
- singlestoredb/types.py +373 -0
- singlestoredb/utils/__init__.py +0 -0
- singlestoredb/utils/config.py +950 -0
- singlestoredb/utils/convert_rows.py +69 -0
- singlestoredb/utils/debug.py +13 -0
- singlestoredb/utils/dtypes.py +205 -0
- singlestoredb/utils/events.py +65 -0
- singlestoredb/utils/mogrify.py +151 -0
- singlestoredb/utils/results.py +585 -0
- singlestoredb/utils/xdict.py +425 -0
- singlestoredb/vectorstore.py +192 -0
- singlestoredb/warnings.py +5 -0
- singlestoredb-1.16.1.dist-info/METADATA +165 -0
- singlestoredb-1.16.1.dist-info/RECORD +183 -0
- singlestoredb-1.16.1.dist-info/WHEEL +5 -0
- singlestoredb-1.16.1.dist-info/entry_points.txt +2 -0
- singlestoredb-1.16.1.dist-info/licenses/LICENSE +201 -0
- singlestoredb-1.16.1.dist-info/top_level.txt +3 -0
- sqlx/__init__.py +4 -0
- sqlx/magic.py +113 -0
singlestoredb/functions/signature.py

@@ -0,0 +1,1578 @@
#!/usr/bin/env python3
import dataclasses
import datetime
import inspect
import numbers
import os
import re
import string
import sys
import types
import typing
from collections.abc import Sequence
from typing import Any
from typing import Callable
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple
from typing import TypeVar
from typing import Union

try:
    import numpy as np
    has_numpy = True
except ImportError:
    has_numpy = False


from . import dtypes as dt
from . import utils
from .typing import Table
from .typing import Masked
from ..mysql.converters import escape_item  # type: ignore

if sys.version_info >= (3, 10):
    _UNION_TYPES = {typing.Union, types.UnionType}
else:
    _UNION_TYPES = {typing.Union}


def is_union(x: Any) -> bool:
    """Check if the object is a Union."""
    return typing.get_origin(x) in _UNION_TYPES
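
# Illustrative usage (editor's sketch, not part of the packaged file):
# is_union() keys off typing.get_origin(), so it recognizes both the
# typing.Union spelling and, on Python 3.10+, the PEP 604 `int | str` form.
#
#     >>> is_union(typing.Union[int, str])
#     True
#     >>> is_union(typing.Optional[int])    # Optional[T] is Union[T, None]
#     True
#     >>> is_union(int)
#     False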


class NoDefaultType:
    pass


NO_DEFAULT = NoDefaultType()


array_types: Tuple[Any, ...]

if has_numpy:
    array_types = (Sequence, np.ndarray)
    numpy_type_map = {
        np.integer: 'int64',
        np.int_: 'int64',
        np.int64: 'int64',
        np.int32: 'int32',
        np.int16: 'int16',
        np.int8: 'int8',
        np.uint: 'uint64',
        np.unsignedinteger: 'uint64',
        np.uint64: 'uint64',
        np.uint32: 'uint32',
        np.uint16: 'uint16',
        np.uint8: 'uint8',
        np.longlong: 'uint64',
        np.ulonglong: 'uint64',
        np.str_: 'str',
        np.bytes_: 'bytes',
        np.float64: 'float64',
        np.float32: 'float32',
        np.float16: 'float16',
        np.double: 'float64',
    }
    if hasattr(np, 'unicode_'):
        numpy_type_map[np.unicode_] = 'str'
    if hasattr(np, 'float_'):
        numpy_type_map[np.float_] = 'float64'
else:
    array_types = (Sequence,)
    numpy_type_map = {}

float_type_map = {
    'float': 'float64',
    'float_': 'float64',
    'float64': 'float64',
    'f8': 'float64',
    'double': 'float64',
    'float32': 'float32',
    'f4': 'float32',
    'float16': 'float16',
    'f2': 'float16',
    'float8': 'float8',
    'f1': 'float8',
}

int_type_map = {
    'int': 'int64',
    'integer': 'int64',
    'int_': 'int64',
    'int64': 'int64',
    'i8': 'int64',
    'int32': 'int32',
    'i4': 'int32',
    'int16': 'int16',
    'i2': 'int16',
    'int8': 'int8',
    'i1': 'int8',
    'uint': 'uint64',
    'uinteger': 'uint64',
    'uint_': 'uint64',
    'uint64': 'uint64',
    'u8': 'uint64',
    'uint32': 'uint32',
    'u4': 'uint32',
    'uint16': 'uint16',
    'u2': 'uint16',
    'uint8': 'uint8',
    'u1': 'uint8',
}

sql_type_map = {
    'bool': 'BOOL',
    'int8': 'TINYINT',
    'int16': 'SMALLINT',
    'int32': 'INT',
    'int64': 'BIGINT',
    'uint8': 'TINYINT UNSIGNED',
    'uint16': 'SMALLINT UNSIGNED',
    'uint32': 'INT UNSIGNED',
    'uint64': 'BIGINT UNSIGNED',
    'float32': 'FLOAT',
    'float64': 'DOUBLE',
    'str': 'TEXT',
    'bytes': 'BLOB',
    'null': 'NULL',
    'datetime': 'DATETIME',
    'datetime6': 'DATETIME(6)',
    'date': 'DATE',
    'time': 'TIME',
    'time6': 'TIME(6)',
}

sql_to_type_map = {
    'BOOL': 'bool',
    'TINYINT': 'int8',
    'TINYINT UNSIGNED': 'uint8',
    'SMALLINT': 'int16',
    'SMALLINT UNSIGNED': 'int16',
    'MEDIUMINT': 'int32',
    'MEDIUMINT UNSIGNED': 'int32',
    'INT24': 'int32',
    'INT24 UNSIGNED': 'int32',
    'INT': 'int32',
    'INT UNSIGNED': 'int32',
    'INTEGER': 'int32',
    'INTEGER UNSIGNED': 'int32',
    'BIGINT': 'int64',
    'BIGINT UNSIGNED': 'int64',
    'FLOAT': 'float32',
    'DOUBLE': 'float64',
    'REAL': 'float64',
    'DATE': 'date',
    'TIME': 'time',
    'TIME(6)': 'time6',
    'DATETIME': 'datetime',
    'DATETIME(6)': 'datetime',
    'TIMESTAMP': 'datetime',
    'TIMESTAMP(6)': 'datetime',
    'YEAR': 'uint64',
    'CHAR': 'str',
    'VARCHAR': 'str',
    'TEXT': 'str',
    'TINYTEXT': 'str',
    'MEDIUMTEXT': 'str',
    'LONGTEXT': 'str',
    'BINARY': 'bytes',
    'VARBINARY': 'bytes',
    'BLOB': 'bytes',
    'TINYBLOB': 'bytes',
    'MEDIUMBLOB': 'bytes',
    'LONGBLOB': 'bytes',
}
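
# Illustrative lookups (editor's sketch): the maps above translate between
# the normalized dtype names used throughout this module and SQL type names.
# A trailing '?' (attached elsewhere) marks nullability and is never part of
# these keys.
#
#     >>> sql_type_map['int64']
#     'BIGINT'
#     >>> sql_to_type_map['BIGINT UNSIGNED']
#     'int64'
#     >>> numpy_type_map[np.int32]    # populated only when numpy is installed
#     'int32'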


@dataclasses.dataclass
class ParamSpec:
    # Normalized data type of the parameter
    dtype: Any

    # Name of the parameter, if applicable
    name: str = ''

    # SQL type of the parameter
    sql_type: str = ''

    # Default value of the parameter, if applicable
    default: Any = NO_DEFAULT

    # Transformer function to apply to the parameter
    transformer: Optional[Callable[..., Any]] = None

    # Whether the parameter is optional (e.g., Union[T, None] or Optional[T])
    is_optional: bool = False
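
# Illustrative construction (editor's sketch): ParamSpec is the record that
# the schema helpers below emit for every parameter and return column; only
# `dtype` is required.
#
#     >>> ParamSpec(dtype='int64', name='x', default=0)
#     ParamSpec(dtype='int64', name='x', sql_type='', default=0, transformer=None, is_optional=False)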


class Collection:
    """Base class for collection data types."""

    def __init__(self, *item_dtypes: Union[List[type], type]):
        self.item_dtypes = item_dtypes


class TupleCollection(Collection):
    pass


class ArrayCollection(Collection):
    pass


def get_data_format(obj: Any) -> str:
    """Return the data format of the DataFrame / Table / vector."""
    # Cheating here a bit so we don't have to import pandas / polars / pyarrow
    # unless we absolutely need to
    if getattr(obj, '__module__', '').startswith('pandas.'):
        return 'pandas'
    if getattr(obj, '__module__', '').startswith('polars.'):
        return 'polars'
    if getattr(obj, '__module__', '').startswith('pyarrow.'):
        return 'arrow'
    if getattr(obj, '__module__', '').startswith('numpy.'):
        return 'numpy'
    if isinstance(obj, list):
        return 'list'
    return 'scalar'
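
# Illustrative behavior (editor's sketch): dispatch is purely on __module__,
# so none of the optional dependencies is imported just to run the check.
#
#     >>> get_data_format([1, 2, 3])
#     'list'
#     >>> get_data_format(42)
#     'scalar'
#     >>> import pandas as pd               # assuming pandas is installed
#     >>> get_data_format(pd.DataFrame)     # __module__ is 'pandas.core.frame'
#     'pandas'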


def escape_name(name: str) -> str:
    """Escape a function parameter name."""
    if '`' in name:
        name = name.replace('`', '``')
    return f'`{name}`'
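
# Illustrative escaping (editor's sketch): backticks are doubled, then the
# whole name is wrapped, matching MySQL-style identifier quoting.
#
#     >>> escape_name('col')
#     '`col`'
#     >>> escape_name('weird`name')
#     '`weird``name`'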


def simplify_dtype(dtype: Any) -> List[Any]:
    """
    Expand a type annotation to a flattened list of atomic types.

    This function will attempt to find the underlying type of a
    type annotation. For example, a Union of types will be flattened
    to a list of types. A Tuple or Array type will be expanded to
    a list of types. A TypeVar will be expanded to a list of
    constraints and bounds.

    Parameters
    ----------
    dtype : Any
        Python type annotation

    Returns
    -------
    List[Any]
        list of dtype strings, TupleCollections, and ArrayCollections

    """
    origin = typing.get_origin(dtype)
    atype = type(dtype)
    args = []

    # Flatten Unions
    if is_union(dtype):
        for x in typing.get_args(dtype):
            args.extend(simplify_dtype(x))

    # Expand custom types to individual types (does not support bounds)
    elif atype is TypeVar:
        for x in dtype.__constraints__:
            args.extend(simplify_dtype(x))
        if dtype.__bound__:
            args.extend(simplify_dtype(dtype.__bound__))

    # Sequence types
    elif origin is not None and inspect.isclass(origin) and issubclass(origin, Sequence):
        item_args: List[Union[List[type], type]] = []
        for x in typing.get_args(dtype):
            item_dtype = simplify_dtype(x)
            if len(item_dtype) == 1:
                item_args.append(item_dtype[0])
            else:
                item_args.append(item_dtype)
        if origin is tuple or origin is Tuple:
            args.append(TupleCollection(*item_args))
        elif len(item_args) > 1:
            raise TypeError('sequence types may only contain one item data type')
        else:
            args.append(ArrayCollection(*item_args))

    # Not a Union or TypeVar
    else:
        args.append(dtype)

    return args
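
# Illustrative flattening (editor's sketch): Unions (including Optional) come
# back as a flat list of atomic types, while sequence annotations are wrapped
# in the Collection classes above.
#
#     >>> simplify_dtype(Optional[Union[int, float]])
#     [<class 'int'>, <class 'float'>, <class 'NoneType'>]
#     >>> simplify_dtype(List[int])     # -> [ArrayCollection wrapping int]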


def normalize_dtype(dtype: Any) -> str:
    """
    Normalize the type annotation into a type name.

    Parameters
    ----------
    dtype : Any
        Type annotation, list of type annotations, or a string
        containing a SQL type name

    Returns
    -------
    str
        Normalized type name

    """
    if isinstance(dtype, list):
        return '|'.join(normalize_dtype(x) for x in dtype)

    if isinstance(dtype, str):
        return sql_to_dtype(dtype)

    if has_numpy and typing.get_origin(dtype) is np.dtype:
        dtype = typing.get_args(dtype)[0]

    # Specific types
    if dtype is None or dtype is type(None):  # noqa: E721
        return 'null'
    if dtype is int:
        return 'int64'
    if dtype is float:
        return 'float64'
    if dtype is bool:
        return 'bool'

    if utils.is_dataclass(dtype):
        dc_fields = dataclasses.fields(dtype)
        item_dtypes = ','.join(
            f'{normalize_dtype(simplify_dtype(x.type))}' for x in dc_fields
        )
        return f'tuple[{item_dtypes}]'

    if utils.is_typeddict(dtype):
        td_fields = utils.get_annotations(dtype)
        item_dtypes = ','.join(
            f'{normalize_dtype(simplify_dtype(td_fields[x]))}' for x in td_fields
        )
        return f'tuple[{item_dtypes}]'

    if utils.is_pydantic(dtype):
        pyd_fields = dtype.model_fields.values()
        item_dtypes = ','.join(
            f'{normalize_dtype(simplify_dtype(x.annotation))}'  # type: ignore
            for x in pyd_fields
        )
        return f'tuple[{item_dtypes}]'

    if utils.is_namedtuple(dtype):
        nt_fields = utils.get_annotations(dtype).values()
        item_dtypes = ','.join(
            f'{normalize_dtype(simplify_dtype(x))}' for x in nt_fields
        )
        return f'tuple[{item_dtypes}]'

    if not inspect.isclass(dtype):

        # Check for compound types
        origin = typing.get_origin(dtype)
        if origin is not None:

            # Tuple type
            if origin is Tuple:
                args = typing.get_args(dtype)
                item_dtypes = ','.join(normalize_dtype(x) for x in args)
                return f'tuple[{item_dtypes}]'

            # Array types
            elif inspect.isclass(origin) and issubclass(origin, array_types):
                args = typing.get_args(dtype)
                item_dtype = normalize_dtype(args[0])
                return f'array[{item_dtype}]'

            raise TypeError(f'unsupported type annotation: {dtype}')

    if isinstance(dtype, ArrayCollection):
        item_dtypes = ','.join(normalize_dtype(x) for x in dtype.item_dtypes)
        return f'array[{item_dtypes}]'

    if isinstance(dtype, TupleCollection):
        item_dtypes = ','.join(normalize_dtype(x) for x in dtype.item_dtypes)
        return f'tuple[{item_dtypes}]'

    # Check numpy types if it's available
    if dtype in numpy_type_map:
        return numpy_type_map[dtype]

    # Broad numeric types
    if issubclass(dtype, int):
        return 'int64'
    if issubclass(dtype, float):
        return 'float64'

    # Strings / Bytes
    if issubclass(dtype, str):
        return 'str'
    if issubclass(dtype, (bytes, bytearray)):
        return 'bytes'

    # Date / Times
    if issubclass(dtype, datetime.datetime):
        return 'datetime'
    if issubclass(dtype, datetime.date):
        return 'date'
    if issubclass(dtype, datetime.timedelta):
        return 'time'

    # Last resort, guess it by the name...
    name = dtype.__name__.lower()
    is_float = issubclass(dtype, numbers.Real)
    is_int = issubclass(dtype, numbers.Integral)
    if is_float:
        return float_type_map.get(name, 'float64')
    if is_int:
        return int_type_map.get(name, 'int64')

    raise TypeError(
        f'unsupported type annotation: {dtype}; '
        'use `args`/`returns` on the @udf/@tvf decorator to specify the data type',
    )
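
# Illustrative normalization (editor's sketch): scalar annotations map to
# dtype names, generics recurse, and the list form joins alternatives with
# '|' for collapse_dtypes() below.
#
#     >>> normalize_dtype(int)
#     'int64'
#     >>> normalize_dtype(List[float])
#     'array[float64]'
#     >>> normalize_dtype([int, type(None)])
#     'int64|null'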


def collapse_dtypes(dtypes: Union[str, List[str]], include_null: bool = False) -> str:
    """
    Collapse a dtype possibly containing multiple data types to one type.

    This function can fail if there is no single type that naturally
    encompasses all of the types in the list.

    Parameters
    ----------
    dtypes : str or list[str]
        The data types to collapse
    include_null : bool, optional
        Whether to force include null types in the result

    Returns
    -------
    str

    """
    if isinstance(dtypes, str) and '|' in dtypes:
        dtypes = dtypes.split('|')

    if not isinstance(dtypes, list):
        return dtypes

    orig_dtypes = dtypes
    dtypes = list(set(dtypes))

    is_nullable = include_null or 'null' in dtypes

    dtypes = [x for x in dtypes if x != 'null']

    if 'uint64' in dtypes:
        dtypes = [x for x in dtypes if x not in ('uint8', 'uint16', 'uint32')]
    if 'uint32' in dtypes:
        dtypes = [x for x in dtypes if x not in ('uint8', 'uint16')]
    if 'uint16' in dtypes:
        dtypes = [x for x in dtypes if x not in ('uint8',)]

    if 'int64' in dtypes:
        dtypes = [
            x for x in dtypes if x not in (
                'int8', 'int16', 'int32',
                'uint8', 'uint16', 'uint32',
            )
        ]
    if 'int32' in dtypes:
        dtypes = [
            x for x in dtypes if x not in (
                'int8', 'int16',
                'uint8', 'uint16',
            )
        ]
    if 'int16' in dtypes:
        dtypes = [x for x in dtypes if x not in ('int8', 'uint8')]

    if 'float64' in dtypes:
        dtypes = [
            x for x in dtypes if x not in (
                'float32',
                'int8', 'int16', 'int32', 'int64',
                'uint8', 'uint16', 'uint32', 'uint64',
            )
        ]
    if 'float32' in dtypes:
        dtypes = [
            x for x in dtypes if x not in (
                'int8', 'int16', 'int32',
                'uint8', 'uint16', 'uint32',
            )
        ]

    for i, item in enumerate(dtypes):

        if item.startswith('array[') and '|' in item:
            _, item_spec = item.split('[', 1)
            item_spec = item_spec[:-1]
            item = collapse_dtypes(item_spec.split('|'))
            dtypes[i] = f'array[{item}]'

        elif item.startswith('tuple[') and '|' in item:
            _, item_spec = item.split('[', 1)
            item_spec = item_spec[:-1]
            sub_dtypes = []
            for subitem_spec in item_spec.split(','):
                item = collapse_dtypes(subitem_spec.split('|'))
                sub_dtypes.append(item)
            dtypes[i] = f'tuple[{",".join(sub_dtypes)}]'

    if len(dtypes) > 1:
        raise TypeError(
            'types can not be collapsed to a single type: '
            f'{", ".join(orig_dtypes)}',
        )

    if not dtypes:
        return 'null'

    return dtypes[0] + ('?' if is_nullable else '')
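
# Illustrative collapsing (editor's sketch): narrower widths fold into the
# widest member, and a 'null' member (or include_null=True) becomes a
# trailing '?'; incompatible members raise.
#
#     >>> collapse_dtypes(['int8', 'int32', 'null'])
#     'int32?'
#     >>> collapse_dtypes('int64|float64')
#     'float64'
#     >>> collapse_dtypes(['int64', 'str'])   # raises TypeError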


def get_dataclass_schema(obj: Any) -> List[ParamSpec]:
    """
    Get the schema of a dataclass.

    Parameters
    ----------
    obj : dataclass
        The dataclass to get the schema of

    Returns
    -------
    List[ParamSpec]
        A list of parameter specifications for the dataclass fields

    """
    return [
        ParamSpec(
            name=f.name,
            dtype=f.type,
            default=NO_DEFAULT if f.default is dataclasses.MISSING else f.default,
        )
        for f in dataclasses.fields(obj)
    ]
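
# Illustrative schema extraction (editor's sketch):
#
#     >>> @dataclasses.dataclass
#     ... class Point:
#     ...     x: float
#     ...     y: float = 0.0
#     >>> [(p.name, p.dtype) for p in get_dataclass_schema(Point)]
#     [('x', <class 'float'>), ('y', <class 'float'>)]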


def get_typeddict_schema(obj: Any) -> List[ParamSpec]:
    """
    Get the schema of a TypedDict.

    Parameters
    ----------
    obj : TypedDict
        The TypedDict to get the schema of

    Returns
    -------
    List[ParamSpec]
        A list of parameter specifications for the TypedDict fields

    """
    return [
        ParamSpec(
            name=k,
            dtype=v,
            default=getattr(obj, k, NO_DEFAULT),
        )
        for k, v in utils.get_annotations(obj).items()
    ]


def get_pydantic_schema(obj: Any) -> List[ParamSpec]:
    """
    Get the schema of a pydantic model.

    Parameters
    ----------
    obj : pydantic.BaseModel
        The pydantic model to get the schema of

    Returns
    -------
    List[ParamSpec]
        A list of parameter specifications for the pydantic model fields

    """
    import pydantic_core
    return [
        ParamSpec(
            name=k,
            dtype=v.annotation,
            default=NO_DEFAULT
            if v.default is pydantic_core.PydanticUndefined else v.default,
        )
        for k, v in obj.model_fields.items()
    ]


def get_namedtuple_schema(obj: Any) -> List[ParamSpec]:
    """
    Get the schema of a named tuple.

    Parameters
    ----------
    obj : NamedTuple
        The named tuple to get the schema of

    Returns
    -------
    List[ParamSpec]
        A list of parameter specifications for the named tuple fields

    """
    return [
        ParamSpec(
            name=k,
            dtype=v,
            default=obj._field_defaults.get(k, NO_DEFAULT),
        )
        for k, v in utils.get_annotations(obj).items()
    ]


def get_table_schema(obj: Any) -> List[ParamSpec]:
    """
    Get the schema of a Table.

    Parameters
    ----------
    obj : Table
        The Table to get the schema of

    Returns
    -------
    List[ParamSpec]
        A list of parameter specifications for the Table fields

    """
    return [
        ParamSpec(
            name=k,
            dtype=v,
            default=getattr(obj, k, NO_DEFAULT),
        )
        for k, v in utils.get_annotations(obj).items()
    ]


def get_colspec(overrides: List[Any]) -> List[ParamSpec]:
    """
    Get the column specification from the overrides.

    Parameters
    ----------
    overrides : List[Any]
        The overrides to get the column specification from

    Returns
    -------
    List[ParamSpec]
        A list of parameter specifications for the column fields

    """
    if len(overrides) == 1:

        override = overrides[0]

        # Dataclass
        if utils.is_dataclass(override):
            return get_dataclass_schema(override)

        # TypedDict
        elif utils.is_typeddict(override):
            return get_typeddict_schema(override)

        # Named tuple
        elif utils.is_namedtuple(override):
            return get_namedtuple_schema(override)

        # Pydantic model
        elif utils.is_pydantic(override):
            return get_pydantic_schema(override)

    # List of types
    return [
        ParamSpec(
            name=getattr(x, 'name', ''),
            dtype=sql_to_dtype(x) if isinstance(x, str) else x,
            sql_type=x if isinstance(x, str) else '',
        ) for x in overrides
    ]
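
# Illustrative overrides (editor's sketch): strings are parsed as SQL type
# specs, anything else is passed through as a type annotation; a single
# dataclass / TypedDict / pydantic / NamedTuple override expands to its
# field schema via the helpers above.
#
#     >>> [p.dtype for p in get_colspec(['BIGINT NOT NULL', 'TEXT'])]
#     ['int64', 'str?']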


def unpack_masked_type(obj: Any) -> Any:
    """
    Unpack a masked type into a single type.

    Parameters
    ----------
    obj : Any
        The masked type to unpack

    Returns
    -------
    Any
        The unpacked type

    """
    if typing.get_origin(obj) is Masked:
        return typing.get_args(obj)[0]
    return obj


def unwrap_optional(annotation: Any) -> Tuple[Any, bool]:
    """
    Unwrap Optional[T] and Union[T, None] annotations to get the underlying type.
    Also indicates whether the type was optional.

    Examples:
        Optional[int] -> (int, True)
        Union[str, None] -> (str, True)
        Union[int, str, None] -> (Union[int, str], True)
        Union[int, str] -> (Union[int, str], False)
        int -> (int, False)

    Parameters
    ----------
    annotation : Any
        The type annotation to unwrap

    Returns
    -------
    Tuple[Any, bool]
        A tuple containing:
        - The unwrapped type annotation
        - A boolean indicating if the original type was optional (contained None)

    """
    origin = typing.get_origin(annotation)
    is_optional = False

    # Handle Union types (which includes Optional)
    if origin is Union:
        args = typing.get_args(annotation)
        # Check if None is in the union
        is_optional = type(None) in args

        # Filter out None/NoneType
        non_none_args = [arg for arg in args if arg is not type(None)]

        if not non_none_args:
            # If only None was in the Union
            from typing import Any
            return Any, is_optional
        elif len(non_none_args) == 1:
            # If there's only one type left, return it directly
            return non_none_args[0], is_optional
        else:
            # Recreate the Union with the remaining types
            return Union[tuple(non_none_args)], is_optional

    return annotation, is_optional
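
# Illustrative unwrapping (editor's sketch). Note that only typing.Union
# spellings are unwrapped here; PEP 604 `T | None` unions are flattened later
# by simplify_dtype().
#
#     >>> unwrap_optional(Optional[int])
#     (<class 'int'>, True)
#     >>> unwrap_optional(Union[int, str, None])
#     (typing.Union[int, str], True)
#     >>> unwrap_optional(int)
#     (<class 'int'>, False)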


def is_composite_type(spec: Any) -> bool:
    """
    Check if the object is a composite type (e.g., dataclass, TypedDict, etc.).

    Parameters
    ----------
    spec : Any
        The object to check

    Returns
    -------
    bool
        True if the object is a composite type, False otherwise

    """
    return inspect.isclass(spec) and \
        (
            utils.is_dataframe(spec)
            or utils.is_dataclass(spec)
            or utils.is_typeddict(spec)
            or utils.is_pydantic(spec)
            or utils.is_namedtuple(spec)
        )


def check_composite_type(colspec: List[ParamSpec], mode: str, type_name: str) -> bool:
    """
    Check if the column specification is a composite type.

    Parameters
    ----------
    colspec : List[ParamSpec]
        The column specification to check
    mode : str
        The mode of the function, either 'parameter' or 'return'
    type_name : str
        The name of the parent type

    Returns
    -------
    bool
        False if no composite type is present; a TypeError is raised otherwise

    """
    if mode == 'parameter':
        if is_composite_type(colspec[0].dtype):
            raise TypeError(
                'composite types are not allowed in a '
                f'{type_name}: {colspec[0].dtype.__name__}',
            )
    elif mode == 'return':
        if is_composite_type(colspec[0].dtype):
            raise TypeError(
                'composite types are not allowed in a '
                f'{type_name}: {colspec[0].dtype.__name__}',
            )
    return False


def get_schema(
    spec: Any,
    overrides: Optional[List[ParamSpec]] = None,
    mode: str = 'parameter',
) -> Tuple[List[ParamSpec], str, str]:
    """
    Expand a return type annotation into a list of types and field names.

    Parameters
    ----------
    spec : Any
        The return type specification
    overrides : List[ParamSpec], optional
        List of SQL type specifications for the return type
    mode : str
        The mode of the function, either 'parameter' or 'return'

    Returns
    -------
    Tuple[List[ParamSpec], str, str]
        A list of parameter specifications for the function,
        the normalized data format, and the SQL definition of the type

    """
    colspec = []
    data_format = ''
    function_type = 'udf'
    udf_parameter = '`returns=`' if mode == 'return' else '`args=`'

    spec, is_optional = unwrap_optional(spec)
    origin = typing.get_origin(spec)
    args = typing.get_args(spec)
    args_origins = [typing.get_origin(x) if x is not None else None for x in args]

    # Make sure that the result of a TVF is a list or dataframe
    if mode == 'return':

        # See if it's a Table subclass with annotations
        if inspect.isclass(origin) and origin is Table:

            function_type = 'tvf'

            if utils.is_dataframe(args[0]):
                if not overrides:
                    raise TypeError(
                        'column types must be specified by the '
                        '`returns=` parameter of the @udf decorator',
                    )

                if utils.get_module(args[0]) in ['pandas', 'polars', 'pyarrow']:
                    data_format = utils.get_module(args[0])
                    spec = args[0]
                else:
                    raise TypeError(
                        'only pandas.DataFrames, polars.DataFrames, '
                        'and pyarrow.Tables are supported as tables.',
                    )

            elif typing.get_origin(args[0]) is list:
                if len(args) != 1:
                    raise TypeError(
                        'only one list is supported within a table; to '
                        'return multiple columns, use a tuple, NamedTuple, '
                        'dataclass, TypedDict, or pydantic model',
                    )
                spec = typing.get_args(args[0])[0]
                data_format = 'list'

            elif all([utils.is_vector(x, include_masks=True) for x in args]):
                pass

            else:
                raise TypeError(
                    'return type for TVF must be a list, DataFrame / Table, '
                    'or tuple of vectors',
                )

        # Short circuit check for common valid types
        elif utils.is_vector(spec) or spec in {str, float, int, bytes}:
            pass

        # Try to catch some common mistakes
        elif origin in [tuple, dict] or tuple in args_origins or is_composite_type(spec):
            raise TypeError(
                'invalid return type for a UDF; expecting a scalar or vector, '
                f'but got {getattr(spec, "__name__", spec)}',
            )

    # Short circuit check for common valid types
    elif utils.is_vector(spec) or spec in {str, float, int, bytes}:
        pass

    # Error out for incorrect parameter types
    elif origin in [tuple, dict] or tuple in args_origins or is_composite_type(spec):
        raise TypeError(
            'parameter types must be scalar or vector, '
            f'got {getattr(spec, "__name__", spec)}',
        )

    #
    # Process each parameter / return type into a colspec
    #

    # Dataframe type
    if utils.is_dataframe(spec):
        if not overrides:
            raise TypeError(
                'column types must be specified in the '
                f'{udf_parameter} parameter of the @udf decorator for a DataFrame',
            )
        # colspec = get_colspec(overrides[0].dtype)
        colspec = overrides

    # Numpy array types
    elif utils.is_numpy(spec):
        data_format = 'numpy'
        if overrides:
            colspec = overrides
        elif len(typing.get_args(spec)) < 2:
            raise TypeError(
                'numpy array must have an element data type specified '
                f'in the {udf_parameter} parameter of the @udf decorator '
                'or with an NDArray type annotation',
            )
        else:
            colspec = [ParamSpec(dtype=typing.get_args(spec)[1])]
        check_composite_type(colspec, mode, 'numpy array')

    # Pandas Series
    elif utils.is_pandas_series(spec):
        data_format = 'pandas'
        if not overrides:
            raise TypeError(
                'pandas Series must have an element data type specified '
                f'in the {udf_parameter} parameter of the @udf decorator',
            )
        colspec = overrides
        check_composite_type(colspec, mode, 'pandas Series')

    # Polars Series
    elif utils.is_polars_series(spec):
        data_format = 'polars'
        if not overrides:
            raise TypeError(
                'polars Series must have an element data type specified '
                f'in the {udf_parameter} parameter of the @udf decorator',
            )
        colspec = overrides
        check_composite_type(colspec, mode, 'polars Series')

    # PyArrow Array
    elif utils.is_pyarrow_array(spec):
        data_format = 'arrow'
        if not overrides:
            raise TypeError(
                'pyarrow Arrays must have an element data type specified '
                f'in the {udf_parameter} parameter of the @udf decorator',
            )
        colspec = overrides
        check_composite_type(colspec, mode, 'pyarrow Array')

    # Return type is specified by a dataclass definition
    elif utils.is_dataclass(spec):
        colspec = overrides or get_dataclass_schema(spec)

    # Return type is specified by a TypedDict definition
    elif utils.is_typeddict(spec):
        colspec = overrides or get_typeddict_schema(spec)

    # Return type is specified by a pydantic model
    elif utils.is_pydantic(spec):
        colspec = overrides or get_pydantic_schema(spec)

    # Return type is specified by a named tuple
    elif utils.is_namedtuple(spec):
        colspec = overrides or get_namedtuple_schema(spec)

    # Unrecognized return type
    elif spec is not None:

        # Return type is specified by a SQL string
        if isinstance(spec, str):
            data_format = 'scalar'
            colspec = [ParamSpec(dtype=spec, is_optional=is_optional)]

        # Plain list vector
        elif typing.get_origin(spec) is list:
            data_format = 'list'
            colspec = [ParamSpec(dtype=typing.get_args(spec)[0], is_optional=is_optional)]

        # Multiple return values
        elif inspect.isclass(typing.get_origin(spec)) \
                and issubclass(typing.get_origin(spec), tuple):  # type: ignore[arg-type]

            # Make sure that the number of overrides matches the number of
            # return types or parameter types
            if overrides and len(typing.get_args(spec)) != len(overrides):
                raise ValueError(
                    f'number of {mode} types does not match the number of '
                    'overrides specified',
                )

            colspec = []
            out_data_formats = []

            # Get the colspec for each item in the tuple
            for i, x in enumerate(typing.get_args(spec)):
                params, out_data_format, _ = get_schema(
                    unpack_masked_type(x),
                    overrides=[overrides[i]] if overrides else [],
                    # Always pass UDF mode for individual items
                    mode=mode,
                )

                # Use the name from the overrides if specified
                if overrides:
                    if overrides[i] and not params[0].name:
                        params[0].name = overrides[i].name
                    elif not overrides[i].name:
                        params[0].name = f'{string.ascii_letters[i]}'

                colspec.append(params[0])
                out_data_formats.append(out_data_format)

            # Make sure that all the data formats are the same
            if len(set(out_data_formats)) > 1:
                raise TypeError(
                    'data formats must be all be the same vector / scalar type: '
                    f'{", ".join(out_data_formats)}',
                )

            if data_format != 'list' and out_data_formats:
                data_format = out_data_formats[0]

            # Since the colspec was computed by get_schema already, don't go
            # through the process of normalizing the dtypes again
            return colspec, data_format, function_type  # type: ignore

        # Use overrides if specified
        elif overrides:
            if not data_format:
                data_format = get_data_format(spec)
            colspec = overrides

        # Single value, no override
        else:
            if not data_format:
                data_format = 'scalar'
            colspec = [ParamSpec(dtype=spec, is_optional=is_optional)]

    out = []

    # Normalize colspec data types
    for c in colspec:

        if isinstance(c.dtype, str):
            dtype = c.dtype
        else:
            dtype = collapse_dtypes(
                [normalize_dtype(x) for x in simplify_dtype(c.dtype)],
                include_null=c.is_optional,
            )

        p = ParamSpec(
            name=c.name,
            dtype=dtype,
            sql_type=c.sql_type if isinstance(c.sql_type, str) else None,
            is_optional=c.is_optional,
        )

        out.append(p)

    return out, data_format, function_type
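
# Illustrative call (editor's sketch): a plain scalar annotation yields one
# ParamSpec with a collapsed dtype, the 'scalar' data format, and the 'udf'
# function type.
#
#     >>> specs, fmt, ftype = get_schema(Optional[float], mode='parameter')
#     >>> specs[0].dtype, fmt, ftype
#     ('float64?', 'scalar', 'udf')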


def vector_check(obj: Any) -> Tuple[Any, str]:
    """
    Check if the object is a vector type.

    Parameters
    ----------
    obj : Any
        The object to check

    Returns
    -------
    Tuple[Any, str]
        The scalar type and the data format:
        'scalar', 'list', 'numpy', 'pandas', or 'polars'

    """
    if utils.is_numpy(obj):
        if len(typing.get_args(obj)) < 2:
            return None, 'numpy'
        return typing.get_args(obj)[1], 'numpy'
    if utils.is_pandas_series(obj):
        if len(typing.get_args(obj)) < 2:
            return None, 'pandas'
        return typing.get_args(obj)[1], 'pandas'
    if utils.is_polars_series(obj):
        return None, 'polars'
    if utils.is_pyarrow_array(obj):
        return None, 'arrow'
    if obj is list or typing.get_origin(obj) is list:
        if len(typing.get_args(obj)) < 1:
            return None, 'list'
        return typing.get_args(obj)[0], 'list'
    return obj, 'scalar'
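
# Illustrative checks (editor's sketch):
#
#     >>> vector_check(List[int])
#     (<class 'int'>, 'list')
#     >>> vector_check(str)
#     (<class 'str'>, 'scalar')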


def get_masks(func: Callable[..., Any]) -> Tuple[List[bool], List[bool]]:
    """
    Get the list of masked parameters and return values for the function.

    Parameters
    ----------
    func : Callable
        The function to call as the endpoint

    Returns
    -------
    Tuple[List[bool], List[bool]]
        A Tuple containing the parameter / return value masks
        as lists of booleans

    """
    params = inspect.signature(func).parameters
    returns = inspect.signature(func).return_annotation

    ret_masks = []
    if typing.get_origin(returns) is Masked:
        ret_masks = [True]
    elif typing.get_origin(returns) is Table:
        for x in typing.get_args(returns):
            if typing.get_origin(x) is Masked:
                ret_masks.append(True)
            else:
                ret_masks.append(False)
        if not any(ret_masks):
            ret_masks = []

    return (
        [typing.get_origin(x.annotation) is Masked for x in params.values()],
        ret_masks,
    )
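
# Illustrative masks (editor's sketch, assuming Masked is subscripted the way
# this module expects, i.e. typing.get_origin(ann) is Masked):
#
#     >>> def f(x: Masked[list], y: int) -> Masked[list]: ...
#     >>> get_masks(f)
#     ([True, False], [True])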


def get_signature(
    func: Callable[..., Any],
    func_name: Optional[str] = None,
) -> Dict[str, Any]:
    '''
    Get the UDF signature of the Python callable.

    Parameters
    ----------
    func : Callable
        The function to extract the signature of
    func_name : str, optional
        Name override for function

    Returns
    -------
    Dict[str, Any]

    '''
    signature = inspect.signature(func)
    args: List[Dict[str, Any]] = []
    returns: List[Dict[str, Any]] = []

    attrs = getattr(func, '_singlestoredb_attrs', {})
    name = attrs.get('name', func_name if func_name else func.__name__)

    out: Dict[str, Any] = dict(name=name, args=args, returns=returns)

    # Do not allow variable positional or keyword arguments
    for p in signature.parameters.values():
        if p.kind == inspect.Parameter.VAR_POSITIONAL:
            raise TypeError('variable positional arguments are not supported')
        elif p.kind == inspect.Parameter.VAR_KEYWORD:
            raise TypeError('variable keyword arguments are not supported')

    # TODO: Use typing.get_type_hints() for parameters / return values?

    # Generate the parameter type and the corresponding SQL code for that parameter
    args_schema: List[ParamSpec] = []
    args_data_formats = []
    args_colspec = [x for x in get_colspec(attrs.get('args', []))]
    args_masks, ret_masks = get_masks(func)

    if args_colspec and len(args_colspec) != len(signature.parameters):
        raise ValueError(
            'number of args in the decorator does not match '
            'the number of parameters in the function signature',
        )

    params = list(signature.parameters.values())

    # Get the colspec for each parameter
    for i, param in enumerate(params):
        arg_schema, args_data_format, _ = get_schema(
            unpack_masked_type(param.annotation),
            overrides=[args_colspec[i]] if args_colspec else [],
            mode='parameter',
        )
        args_data_formats.append(args_data_format)

        if len(arg_schema) > 1:
            raise TypeError(
                'only one parameter type is supported; '
                f'got {len(arg_schema)} types for parameter {param.name}',
            )

        # Insert parameter names as needed
        if not arg_schema[0].name:
            arg_schema[0].name = param.name

        args_schema.append(arg_schema[0])

    for i, pspec in enumerate(args_schema):
        default_option = {}

        # Insert default values as needed
        if args_colspec and args_colspec[i].default is not NO_DEFAULT:
            default_option['default'] = args_colspec[i].default
        elif params and params[i].default is not param.empty:
            default_option['default'] = params[i].default

        # Generate SQL code for the parameter
        sql = pspec.sql_type or dtype_to_sql(
            pspec.dtype,
            force_nullable=args_masks[i] or pspec.is_optional,
            **default_option,
        )

        # Add parameter to args definitions
        args.append(
            dict(
                name=pspec.name,
                dtype=pspec.dtype,
                sql=sql,
                **default_option,
                transformer=pspec.transformer,
            ),
        )

    # Check that all the data formats are all the same
    if len(set(args_data_formats)) > 1:
        raise TypeError(
            'input data formats must be all be the same: '
            f'{", ".join(args_data_formats)}',
        )

    adf = out['args_data_format'] = args_data_formats[0] \
        if args_data_formats else 'scalar'

    returns_colspec = get_colspec(attrs.get('returns', []))

    # Generate the return types and the corresponding SQL code for those values
    ret_schema, out['returns_data_format'], function_type = get_schema(
        unpack_masked_type(signature.return_annotation),
        overrides=returns_colspec if returns_colspec else None,
        mode='return',
    )

    rdf = out['returns_data_format'] = out['returns_data_format'] or 'scalar'
    out['function_type'] = function_type

    # Reality check the input and output data formats
    if function_type == 'udf':
        if (adf == 'scalar' and rdf != 'scalar') or \
                (adf != 'scalar' and rdf == 'scalar'):
            raise TypeError(
                'Function can not have scalar arguments and a vector return type, '
                'or vice versa. Parameters and return values must all be either '
                'scalar or vector types.',
            )

    # All functions have to return a value, so if none was specified try to
    # insert a reasonable default that includes NULLs.
    if not ret_schema:
        ret_schema = [
            ParamSpec(
                dtype='int8?', sql_type='TINYINT NULL', default=None, is_optional=True,
            ),
        ]

    if function_type == 'udf' and len(ret_schema) > 1:
        raise ValueError(
            'UDFs can only return a single value; '
            f'got {len(ret_schema)} return values',
        )

    # Generate field names for the return values
    if function_type == 'tvf' or len(ret_schema) > 1:
        for i, rspec in enumerate(ret_schema):
            if not rspec.name:
                ret_schema[i] = ParamSpec(
                    name=string.ascii_letters[i],
                    dtype=rspec.dtype,
                    sql_type=rspec.sql_type,
                    transformer=rspec.transformer,
                )

    # Generate SQL code for the return values
    for i, rspec in enumerate(ret_schema):
        sql = rspec.sql_type or dtype_to_sql(
            rspec.dtype,
            force_nullable=(ret_masks[i] or rspec.is_optional)
            if ret_masks else rspec.is_optional,
            function_type=function_type,
        )
        returns.append(
            dict(
                name=rspec.name,
                dtype=rspec.dtype,
                sql=sql,
                transformer=rspec.transformer,
            ),
        )

    # Set the function endpoint
    out['endpoint'] = '/invoke'

    # Set the function doc string
    out['doc'] = func.__doc__

    return out
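
# Illustrative end-to-end run (editor's sketch; the 'BIGINT NOT NULL' strings
# assume the scalar branch of dtype_to_sql() below renders dtypes through
# sql_type_map, which sits past the end of this listing):
#
#     >>> def add(x: int, y: int) -> int:
#     ...     return x + y
#     >>> sig = get_signature(add)
#     >>> [a['sql'] for a in sig['args']]
#     ['BIGINT NOT NULL', 'BIGINT NOT NULL']
#     >>> sig['returns'][0]['sql']
#     'BIGINT NOT NULL'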


def sql_to_dtype(sql: str) -> str:
    """
    Convert a SQL type into a normalized data type identifier.

    Parameters
    ----------
    sql : str
        SQL data type specification

    Returns
    -------
    str

    """
    sql = re.sub(r'\s+', r' ', sql.upper().strip())

    m = re.match(r'(\w+)(\([^\)]+\))?', sql)
    if not m:
        raise TypeError(f'unrecognized data type: {sql}')

    sql_type = m.group(1)
    type_attrs = re.split(r'\s*,\s*', m.group(2) or '')

    if sql_type in ('DATETIME', 'TIME', 'TIMESTAMP') and \
            type_attrs and type_attrs[0] == '6':
        sql_type += '(6)'

    elif ' UNSIGNED' in sql:
        sql_type += ' UNSIGNED'

    try:
        dtype = sql_to_type_map[sql_type]
    except KeyError:
        raise TypeError(f'unrecognized data type: {sql_type}')

    if ' NOT NULL' not in sql:
        dtype += '?'

    return dtype
1416
|
+
|
|
1417
|
+
|
|
1418
|
+
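A short sketch of `sql_to_dtype` behavior. The `TINYINT`/`int8` pairing is taken from the default return schema above; the other entries of `sql_to_type_map` are assumptions here:

from singlestoredb.functions.signature import sql_to_dtype

# Nullability is encoded as a trailing '?': any SQL type without an
# explicit NOT NULL is treated as nullable.
sql_to_dtype('TINYINT NOT NULL')  # 'int8'
sql_to_dtype('TINYINT NULL')      # 'int8?'
sql_to_dtype('tinyint')           # 'int8?' (input is upper-cased first)

# DATETIME/TIME/TIMESTAMP with microsecond precision keep the suffix
# in the lookup key, so this is looked up as 'DATETIME(6)'.
sql_to_dtype('DATETIME(6) NOT NULL')

# Anything not in sql_to_type_map raises:
# sql_to_dtype('FOO')  ->  TypeError: unrecognized data type: FOO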
+def dtype_to_sql(
+    dtype: str,
+    default: Any = NO_DEFAULT,
+    field_names: Optional[List[str]] = None,
+    function_type: str = 'udf',
+    force_nullable: bool = False,
+) -> str:
+    """
+    Convert a collapsed dtype string to a SQL type.
+
+    Parameters
+    ----------
+    dtype : str
+        Simplified data type string
+    default : Any, optional
+        Default value
+    field_names : List[str], optional
+        Field names for tuple types
+    function_type : str, optional
+        Function type, either 'udf' or 'tvf'
+    force_nullable : bool, optional
+        Whether to force the type to be nullable
+
+    Returns
+    -------
+    str
+        SQL type specification
+
+    """
+    nullable = ' NOT NULL'
+    if dtype.endswith('?'):
+        nullable = ' NULL'
+        dtype = dtype[:-1]
+    elif '|null' in dtype:
+        nullable = ' NULL'
+        dtype = dtype.replace('|null', '')
+    elif force_nullable:
+        nullable = ' NULL'
+
+    if dtype == 'null':
+        nullable = ''
+
+    default_clause = ''
+    if default is not NO_DEFAULT:
+        if default is dt.NULL:
+            default = None
+        default_clause = f' DEFAULT {escape_item(default, "utf8")}'
+
+    if dtype.startswith('array['):
+        _, dtypes = dtype.split('[', 1)
+        dtypes = dtypes[:-1]
+        item_dtype = dtype_to_sql(dtypes, function_type=function_type)
+        return f'ARRAY({item_dtype}){nullable}{default_clause}'
+
+    if dtype.startswith('tuple['):
+        _, dtypes = dtype.split('[', 1)
+        dtypes = dtypes[:-1]
+        item_dtypes = []
+        for i, item in enumerate(dtypes.split(',')):
+            if field_names:
+                name = field_names[i]
+            else:
+                name = string.ascii_letters[i]
+            if '=' in item:
+                name, item = item.split('=', 1)
+            item_dtypes.append(
+                f'`{name}` ' + dtype_to_sql(item, function_type=function_type),
+            )
+        if function_type == 'udf':
+            return f'RECORD({", ".join(item_dtypes)}){nullable}{default_clause}'
+        else:
+            return re.sub(
+                r' NOT NULL\s*$', r'',
+                f'TABLE({", ".join(item_dtypes)}){nullable}{default_clause}',
+            )
+
+    return f'{sql_type_map[dtype]}{nullable}{default_clause}'
+
+
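The inverse direction via `dtype_to_sql`: a sketch assuming the same `int8`-to-`TINYINT` mapping in `sql_type_map`:

from singlestoredb.functions.signature import dtype_to_sql

dtype_to_sql('int8')         # 'TINYINT NOT NULL'
dtype_to_sql('int8?')        # 'TINYINT NULL'
dtype_to_sql('array[int8]')  # 'ARRAY(TINYINT NOT NULL) NOT NULL'

# Tuples become RECORDs for UDFs; unnamed fields get generated names
# ('a', 'b', ...) from string.ascii_letters.
dtype_to_sql('tuple[int8,int8?]')
# 'RECORD(`a` TINYINT NOT NULL, `b` TINYINT NULL) NOT NULL'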
+def signature_to_sql(
+    signature: Dict[str, Any],
+    url: Optional[str] = None,
+    data_format: str = 'rowdat_1',
+    app_mode: str = 'remote',
+    link: Optional[str] = None,
+    replace: bool = False,
+    database: Optional[str] = None,
+) -> str:
+    '''
+    Convert a dictionary function signature into SQL.
+
+    Parameters
+    ----------
+    signature : Dict[str, Any]
+        Function signature in the form of a dictionary as returned by
+        the `get_signature` function
+    url : str, optional
+        URL of the function endpoint; in remote mode, defaults to one built
+        from the SINGLESTOREDB_EXT_HOST and SINGLESTOREDB_EXT_PORT
+        environment variables
+    data_format : str, optional
+        Format of the data rows sent to the function
+    app_mode : str, optional
+        Mode that the external function application runs in
+    link : str, optional
+        Name of the LINK to use in the CREATE statement
+    replace : bool, optional
+        Whether to include OR REPLACE in the CREATE statement
+    database : str, optional
+        Database to qualify the function name with
+
+    Returns
+    -------
+    str : SQL-formatted function signature
+
+    '''
+    function_type = signature.get('function_type') or 'udf'
+
+    args = []
+    for arg in signature['args']:
+        # Use the default value from the Python function if SQL doesn't set one
+        default = ''
+        if not re.search(r'\s+default\s+\S', arg['sql'], flags=re.I):
+            if arg.get('default', None) is not None:
+                default = f' DEFAULT {escape_item(arg["default"], "utf8")}'
+        args.append(escape_name(arg['name']) + ' ' + arg['sql'] + default)
+
+    # Build the RETURNS clause from the return schema
+    returns = ''
+    if signature.get('returns'):
+        ret = signature['returns']
+        if function_type == 'tvf':
+            res = 'TABLE(' + ', '.join(
+                f'{escape_name(x["name"])} {x["sql"]}' for x in ret
+            ) + ')'
+        elif ret[0]['name'] and len(ret) > 1:
+            res = 'RECORD(' + ', '.join(
+                f'{escape_name(x["name"])} {x["sql"]}' for x in ret
+            ) + ')'
+        else:
+            res = ret[0]['sql']
+        returns = f' RETURNS {res}'
+    else:
+        raise ValueError(
+            'function signature must have a return type specified',
+        )
+
+    host = os.environ.get('SINGLESTOREDB_EXT_HOST', '127.0.0.1')
+    port = os.environ.get('SINGLESTOREDB_EXT_PORT', '8000')
+
+    if app_mode.lower() == 'remote':
+        url = url or f'https://{host}:{port}/invoke'
+    elif url is None:
+        raise ValueError('url cannot be `None`')
+
+    database_prefix = ''
+    if signature.get('database'):
+        database_prefix = escape_name(signature['database']) + '.'
+    elif database is not None:
+        database_prefix = escape_name(database) + '.'
+
+    or_replace = 'OR REPLACE ' if (bool(signature.get('replace')) or replace) else ''
+
+    link_str = ''
+    if link:
+        if not re.match(r'^[\w_]+$', link):
+            raise ValueError(f'invalid LINK name: {link}')
+        link_str = f' LINK {link}'
+
+    return (
+        f'CREATE {or_replace}EXTERNAL FUNCTION ' +
+        f'{database_prefix}{escape_name(signature["name"])}' +
+        '(' + ', '.join(args) + ')' + returns +
+        f' AS {app_mode.upper()} SERVICE "{url}" FORMAT {data_format.upper()}'
+        f'{link_str};'
+    )
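Putting the pieces together, a sketch of `signature_to_sql` with a hand-built signature dictionary; the field values are illustrative and `escape_name` is assumed to backtick-quote identifiers:

from singlestoredb.functions.signature import signature_to_sql

sig = dict(
    name='double_it',
    function_type='udf',
    args=[dict(name='x', sql='BIGINT NOT NULL')],
    returns=[dict(name='', sql='BIGINT NOT NULL')],
)

print(signature_to_sql(sig, url='https://example.com/invoke'))
# Expected output (one line, wrapped here for readability):
#   CREATE EXTERNAL FUNCTION `double_it`(`x` BIGINT NOT NULL)
#   RETURNS BIGINT NOT NULL AS REMOTE SERVICE
#   "https://example.com/invoke" FORMAT ROWDAT_1;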