singlestoredb 1.16.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (183) hide show
  1. singlestoredb/__init__.py +75 -0
  2. singlestoredb/ai/__init__.py +2 -0
  3. singlestoredb/ai/chat.py +139 -0
  4. singlestoredb/ai/embeddings.py +128 -0
  5. singlestoredb/alchemy/__init__.py +90 -0
  6. singlestoredb/apps/__init__.py +3 -0
  7. singlestoredb/apps/_cloud_functions.py +90 -0
  8. singlestoredb/apps/_config.py +72 -0
  9. singlestoredb/apps/_connection_info.py +18 -0
  10. singlestoredb/apps/_dashboards.py +47 -0
  11. singlestoredb/apps/_process.py +32 -0
  12. singlestoredb/apps/_python_udfs.py +100 -0
  13. singlestoredb/apps/_stdout_supress.py +30 -0
  14. singlestoredb/apps/_uvicorn_util.py +36 -0
  15. singlestoredb/auth.py +245 -0
  16. singlestoredb/config.py +484 -0
  17. singlestoredb/connection.py +1487 -0
  18. singlestoredb/converters.py +950 -0
  19. singlestoredb/docstring/__init__.py +33 -0
  20. singlestoredb/docstring/attrdoc.py +126 -0
  21. singlestoredb/docstring/common.py +230 -0
  22. singlestoredb/docstring/epydoc.py +267 -0
  23. singlestoredb/docstring/google.py +412 -0
  24. singlestoredb/docstring/numpydoc.py +562 -0
  25. singlestoredb/docstring/parser.py +100 -0
  26. singlestoredb/docstring/py.typed +1 -0
  27. singlestoredb/docstring/rest.py +256 -0
  28. singlestoredb/docstring/tests/__init__.py +1 -0
  29. singlestoredb/docstring/tests/_pydoctor.py +21 -0
  30. singlestoredb/docstring/tests/test_epydoc.py +729 -0
  31. singlestoredb/docstring/tests/test_google.py +1007 -0
  32. singlestoredb/docstring/tests/test_numpydoc.py +1100 -0
  33. singlestoredb/docstring/tests/test_parse_from_object.py +109 -0
  34. singlestoredb/docstring/tests/test_parser.py +248 -0
  35. singlestoredb/docstring/tests/test_rest.py +547 -0
  36. singlestoredb/docstring/tests/test_util.py +70 -0
  37. singlestoredb/docstring/util.py +141 -0
  38. singlestoredb/exceptions.py +120 -0
  39. singlestoredb/functions/__init__.py +16 -0
  40. singlestoredb/functions/decorator.py +201 -0
  41. singlestoredb/functions/dtypes.py +1793 -0
  42. singlestoredb/functions/ext/__init__.py +1 -0
  43. singlestoredb/functions/ext/arrow.py +375 -0
  44. singlestoredb/functions/ext/asgi.py +2133 -0
  45. singlestoredb/functions/ext/json.py +420 -0
  46. singlestoredb/functions/ext/mmap.py +413 -0
  47. singlestoredb/functions/ext/rowdat_1.py +724 -0
  48. singlestoredb/functions/ext/timer.py +89 -0
  49. singlestoredb/functions/ext/utils.py +218 -0
  50. singlestoredb/functions/signature.py +1578 -0
  51. singlestoredb/functions/typing/__init__.py +41 -0
  52. singlestoredb/functions/typing/numpy.py +20 -0
  53. singlestoredb/functions/typing/pandas.py +2 -0
  54. singlestoredb/functions/typing/polars.py +2 -0
  55. singlestoredb/functions/typing/pyarrow.py +2 -0
  56. singlestoredb/functions/utils.py +421 -0
  57. singlestoredb/fusion/__init__.py +11 -0
  58. singlestoredb/fusion/graphql.py +213 -0
  59. singlestoredb/fusion/handler.py +916 -0
  60. singlestoredb/fusion/handlers/__init__.py +0 -0
  61. singlestoredb/fusion/handlers/export.py +525 -0
  62. singlestoredb/fusion/handlers/files.py +690 -0
  63. singlestoredb/fusion/handlers/job.py +660 -0
  64. singlestoredb/fusion/handlers/models.py +250 -0
  65. singlestoredb/fusion/handlers/stage.py +502 -0
  66. singlestoredb/fusion/handlers/utils.py +324 -0
  67. singlestoredb/fusion/handlers/workspace.py +956 -0
  68. singlestoredb/fusion/registry.py +249 -0
  69. singlestoredb/fusion/result.py +399 -0
  70. singlestoredb/http/__init__.py +27 -0
  71. singlestoredb/http/connection.py +1267 -0
  72. singlestoredb/magics/__init__.py +34 -0
  73. singlestoredb/magics/run_personal.py +137 -0
  74. singlestoredb/magics/run_shared.py +134 -0
  75. singlestoredb/management/__init__.py +9 -0
  76. singlestoredb/management/billing_usage.py +148 -0
  77. singlestoredb/management/cluster.py +462 -0
  78. singlestoredb/management/export.py +295 -0
  79. singlestoredb/management/files.py +1102 -0
  80. singlestoredb/management/inference_api.py +105 -0
  81. singlestoredb/management/job.py +887 -0
  82. singlestoredb/management/manager.py +373 -0
  83. singlestoredb/management/organization.py +226 -0
  84. singlestoredb/management/region.py +169 -0
  85. singlestoredb/management/utils.py +423 -0
  86. singlestoredb/management/workspace.py +1927 -0
  87. singlestoredb/mysql/__init__.py +177 -0
  88. singlestoredb/mysql/_auth.py +298 -0
  89. singlestoredb/mysql/charset.py +214 -0
  90. singlestoredb/mysql/connection.py +2032 -0
  91. singlestoredb/mysql/constants/CLIENT.py +38 -0
  92. singlestoredb/mysql/constants/COMMAND.py +32 -0
  93. singlestoredb/mysql/constants/CR.py +78 -0
  94. singlestoredb/mysql/constants/ER.py +474 -0
  95. singlestoredb/mysql/constants/EXTENDED_TYPE.py +3 -0
  96. singlestoredb/mysql/constants/FIELD_TYPE.py +48 -0
  97. singlestoredb/mysql/constants/FLAG.py +15 -0
  98. singlestoredb/mysql/constants/SERVER_STATUS.py +10 -0
  99. singlestoredb/mysql/constants/VECTOR_TYPE.py +6 -0
  100. singlestoredb/mysql/constants/__init__.py +0 -0
  101. singlestoredb/mysql/converters.py +271 -0
  102. singlestoredb/mysql/cursors.py +896 -0
  103. singlestoredb/mysql/err.py +92 -0
  104. singlestoredb/mysql/optionfile.py +20 -0
  105. singlestoredb/mysql/protocol.py +450 -0
  106. singlestoredb/mysql/tests/__init__.py +19 -0
  107. singlestoredb/mysql/tests/base.py +126 -0
  108. singlestoredb/mysql/tests/conftest.py +37 -0
  109. singlestoredb/mysql/tests/test_DictCursor.py +132 -0
  110. singlestoredb/mysql/tests/test_SSCursor.py +141 -0
  111. singlestoredb/mysql/tests/test_basic.py +452 -0
  112. singlestoredb/mysql/tests/test_connection.py +851 -0
  113. singlestoredb/mysql/tests/test_converters.py +58 -0
  114. singlestoredb/mysql/tests/test_cursor.py +141 -0
  115. singlestoredb/mysql/tests/test_err.py +16 -0
  116. singlestoredb/mysql/tests/test_issues.py +514 -0
  117. singlestoredb/mysql/tests/test_load_local.py +75 -0
  118. singlestoredb/mysql/tests/test_nextset.py +88 -0
  119. singlestoredb/mysql/tests/test_optionfile.py +27 -0
  120. singlestoredb/mysql/tests/thirdparty/__init__.py +6 -0
  121. singlestoredb/mysql/tests/thirdparty/test_MySQLdb/__init__.py +9 -0
  122. singlestoredb/mysql/tests/thirdparty/test_MySQLdb/capabilities.py +323 -0
  123. singlestoredb/mysql/tests/thirdparty/test_MySQLdb/dbapi20.py +865 -0
  124. singlestoredb/mysql/tests/thirdparty/test_MySQLdb/test_MySQLdb_capabilities.py +110 -0
  125. singlestoredb/mysql/tests/thirdparty/test_MySQLdb/test_MySQLdb_dbapi20.py +224 -0
  126. singlestoredb/mysql/tests/thirdparty/test_MySQLdb/test_MySQLdb_nonstandard.py +101 -0
  127. singlestoredb/mysql/times.py +23 -0
  128. singlestoredb/notebook/__init__.py +16 -0
  129. singlestoredb/notebook/_objects.py +213 -0
  130. singlestoredb/notebook/_portal.py +352 -0
  131. singlestoredb/py.typed +0 -0
  132. singlestoredb/pytest.py +352 -0
  133. singlestoredb/server/__init__.py +0 -0
  134. singlestoredb/server/docker.py +452 -0
  135. singlestoredb/server/free_tier.py +267 -0
  136. singlestoredb/tests/__init__.py +0 -0
  137. singlestoredb/tests/alltypes.sql +307 -0
  138. singlestoredb/tests/alltypes_no_nulls.sql +208 -0
  139. singlestoredb/tests/empty.sql +0 -0
  140. singlestoredb/tests/ext_funcs/__init__.py +702 -0
  141. singlestoredb/tests/local_infile.csv +3 -0
  142. singlestoredb/tests/test.ipynb +18 -0
  143. singlestoredb/tests/test.sql +680 -0
  144. singlestoredb/tests/test2.ipynb +18 -0
  145. singlestoredb/tests/test2.sql +1 -0
  146. singlestoredb/tests/test_basics.py +1332 -0
  147. singlestoredb/tests/test_config.py +318 -0
  148. singlestoredb/tests/test_connection.py +3103 -0
  149. singlestoredb/tests/test_dbapi.py +27 -0
  150. singlestoredb/tests/test_exceptions.py +45 -0
  151. singlestoredb/tests/test_ext_func.py +1472 -0
  152. singlestoredb/tests/test_ext_func_data.py +1101 -0
  153. singlestoredb/tests/test_fusion.py +1527 -0
  154. singlestoredb/tests/test_http.py +288 -0
  155. singlestoredb/tests/test_management.py +1599 -0
  156. singlestoredb/tests/test_plugin.py +33 -0
  157. singlestoredb/tests/test_results.py +171 -0
  158. singlestoredb/tests/test_types.py +132 -0
  159. singlestoredb/tests/test_udf.py +737 -0
  160. singlestoredb/tests/test_udf_returns.py +459 -0
  161. singlestoredb/tests/test_vectorstore.py +51 -0
  162. singlestoredb/tests/test_xdict.py +333 -0
  163. singlestoredb/tests/utils.py +141 -0
  164. singlestoredb/types.py +373 -0
  165. singlestoredb/utils/__init__.py +0 -0
  166. singlestoredb/utils/config.py +950 -0
  167. singlestoredb/utils/convert_rows.py +69 -0
  168. singlestoredb/utils/debug.py +13 -0
  169. singlestoredb/utils/dtypes.py +205 -0
  170. singlestoredb/utils/events.py +65 -0
  171. singlestoredb/utils/mogrify.py +151 -0
  172. singlestoredb/utils/results.py +585 -0
  173. singlestoredb/utils/xdict.py +425 -0
  174. singlestoredb/vectorstore.py +192 -0
  175. singlestoredb/warnings.py +5 -0
  176. singlestoredb-1.16.1.dist-info/METADATA +165 -0
  177. singlestoredb-1.16.1.dist-info/RECORD +183 -0
  178. singlestoredb-1.16.1.dist-info/WHEEL +5 -0
  179. singlestoredb-1.16.1.dist-info/entry_points.txt +2 -0
  180. singlestoredb-1.16.1.dist-info/licenses/LICENSE +201 -0
  181. singlestoredb-1.16.1.dist-info/top_level.txt +3 -0
  182. sqlx/__init__.py +4 -0
  183. sqlx/magic.py +113 -0
@@ -0,0 +1,1578 @@
1
+ #!/usr/bin/env python3
2
+ import dataclasses
3
+ import datetime
4
+ import inspect
5
+ import numbers
6
+ import os
7
+ import re
8
+ import string
9
+ import sys
10
+ import types
11
+ import typing
12
+ from collections.abc import Sequence
13
+ from typing import Any
14
+ from typing import Callable
15
+ from typing import Dict
16
+ from typing import List
17
+ from typing import Optional
18
+ from typing import Tuple
19
+ from typing import TypeVar
20
+ from typing import Union
21
+
22
+ try:
23
+ import numpy as np
24
+ has_numpy = True
25
+ except ImportError:
26
+ has_numpy = False
27
+
28
+
29
+ from . import dtypes as dt
30
+ from . import utils
31
+ from .typing import Table
32
+ from .typing import Masked
33
+ from ..mysql.converters import escape_item # type: ignore
34
+
35
+ if sys.version_info >= (3, 10):
36
+ _UNION_TYPES = {typing.Union, types.UnionType}
37
+ else:
38
+ _UNION_TYPES = {typing.Union}
39
+
40
+
41
def is_union(x: Any) -> bool:
    """Return True when *x* is a Union annotation (typing or PEP 604)."""
    origin = typing.get_origin(x)
    return origin in _UNION_TYPES
44
+
45
+
46
class NoDefaultType:
    """Sentinel class marking that no default value was supplied."""
48
+
49
+
50
# Singleton sentinel used for "no default value given"
NO_DEFAULT = NoDefaultType()


# Types accepted as array-like values; numpy arrays are included only
# when numpy could be imported.
array_types: Tuple[Any, ...]

if has_numpy:
    array_types = (Sequence, np.ndarray)
    # Map numpy scalar classes to normalized dtype names.
    numpy_type_map = {
        np.integer: 'int64',
        np.int_: 'int64',
        np.int64: 'int64',
        np.int32: 'int32',
        np.int16: 'int16',
        np.int8: 'int8',
        np.uint: 'uint64',
        np.unsignedinteger: 'uint64',
        np.uint64: 'uint64',
        np.uint32: 'uint32',
        np.uint16: 'uint16',
        np.uint8: 'uint8',
        # np.longlong is a *signed* 64-bit integer; it was previously
        # mapped to 'uint64' (np.ulonglong is the unsigned variant).
        np.longlong: 'int64',
        np.ulonglong: 'uint64',
        np.str_: 'str',
        np.bytes_: 'bytes',
        np.float64: 'float64',
        np.float32: 'float32',
        np.float16: 'float16',
        np.double: 'float64',
    }
    # Aliases removed in numpy 2.0; register them only on older numpy.
    if hasattr(np, 'unicode_'):
        numpy_type_map[np.unicode_] = 'str'
    if hasattr(np, 'float_'):
        numpy_type_map[np.float_] = 'float64'
else:
    array_types = (Sequence,)
    numpy_type_map = {}

# Normalize float type-name aliases to canonical dtype names.
float_type_map = {
    'float': 'float64',
    'float_': 'float64',
    'float64': 'float64',
    'f8': 'float64',
    'double': 'float64',
    'float32': 'float32',
    'f4': 'float32',
    'float16': 'float16',
    'f2': 'float16',
    'float8': 'float8',
    'f1': 'float8',
}

# Normalize integer type-name aliases to canonical dtype names.
int_type_map = {
    'int': 'int64',
    'integer': 'int64',
    'int_': 'int64',
    'int64': 'int64',
    'i8': 'int64',
    'int32': 'int32',
    'i4': 'int32',
    'int16': 'int16',
    'i2': 'int16',
    'int8': 'int8',
    'i1': 'int8',
    'uint': 'uint64',
    'uinteger': 'uint64',
    'uint_': 'uint64',
    'uint64': 'uint64',
    'u8': 'uint64',
    'uint32': 'uint32',
    'u4': 'uint32',
    'uint16': 'uint16',
    'u2': 'uint16',
    'uint8': 'uint8',
    'u1': 'uint8',
}

# Map normalized dtype names to SQL type names.
sql_type_map = {
    'bool': 'BOOL',
    'int8': 'TINYINT',
    'int16': 'SMALLINT',
    'int32': 'INT',
    'int64': 'BIGINT',
    'uint8': 'TINYINT UNSIGNED',
    'uint16': 'SMALLINT UNSIGNED',
    'uint32': 'INT UNSIGNED',
    'uint64': 'BIGINT UNSIGNED',
    'float32': 'FLOAT',
    'float64': 'DOUBLE',
    'str': 'TEXT',
    'bytes': 'BLOB',
    'null': 'NULL',
    'datetime': 'DATETIME',
    'datetime6': 'DATETIME(6)',
    'date': 'DATE',
    'time': 'TIME',
    'time6': 'TIME(6)',
}

# Map SQL type names to normalized dtype names.
# NOTE(review): several unsigned SQL types here map to *signed* dtype
# names ('SMALLINT UNSIGNED' -> 'int16', 'INT UNSIGNED' -> 'int32',
# 'BIGINT UNSIGNED' -> 'int64'), unlike 'TINYINT UNSIGNED' -> 'uint8'.
# Confirm whether this asymmetry is intentional before changing it.
sql_to_type_map = {
    'BOOL': 'bool',
    'TINYINT': 'int8',
    'TINYINT UNSIGNED': 'uint8',
    'SMALLINT': 'int16',
    'SMALLINT UNSIGNED': 'int16',
    'MEDIUMINT': 'int32',
    'MEDIUMINT UNSIGNED': 'int32',
    'INT24': 'int32',
    'INT24 UNSIGNED': 'int32',
    'INT': 'int32',
    'INT UNSIGNED': 'int32',
    'INTEGER': 'int32',
    'INTEGER UNSIGNED': 'int32',
    'BIGINT': 'int64',
    'BIGINT UNSIGNED': 'int64',
    'FLOAT': 'float32',
    'DOUBLE': 'float64',
    'REAL': 'float64',
    'DATE': 'date',
    'TIME': 'time',
    'TIME(6)': 'time6',
    'DATETIME': 'datetime',
    'DATETIME(6)': 'datetime',
    'TIMESTAMP': 'datetime',
    'TIMESTAMP(6)': 'datetime',
    'YEAR': 'uint64',
    'CHAR': 'str',
    'VARCHAR': 'str',
    'TEXT': 'str',
    'TINYTEXT': 'str',
    'MEDIUMTEXT': 'str',
    'LONGTEXT': 'str',
    'BINARY': 'bytes',
    'VARBINARY': 'bytes',
    'BLOB': 'bytes',
    'TINYBLOB': 'bytes',
    'MEDIUMBLOB': 'bytes',
    'LONGBLOB': 'bytes',
}
188
+
189
+
190
@dataclasses.dataclass
class ParamSpec:
    """Specification of one function parameter or return column."""

    # Normalized data type of the parameter
    dtype: Any

    # Parameter name ('' when the parameter is unnamed)
    name: str = ''

    # SQL type string for the parameter ('' when not specified)
    sql_type: str = ''

    # Default value; NO_DEFAULT sentinel when none was given
    default: Any = NO_DEFAULT

    # Optional callable used to transform the parameter's value
    transformer: Optional[Callable[..., Any]] = None

    # True when the annotation was Optional[T] / Union[T, None]
    is_optional: bool = False
209
+
210
+
211
class Collection:
    """Base class for collection (container) data types."""

    def __init__(self, *item_dtypes: Union[List[type], type]):
        # Keep the item type(s) exactly as passed in
        self.item_dtypes = item_dtypes
216
+
217
+
218
class TupleCollection(Collection):
    """Collection whose items may be heterogeneous (tuple-like)."""
220
+
221
+
222
class ArrayCollection(Collection):
    """Collection whose items share a single type (array-like)."""
224
+
225
+
226
+ def get_data_format(obj: Any) -> str:
227
+ """Return the data format of the DataFrame / Table / vector."""
228
+ # Cheating here a bit so we don't have to import pandas / polars / pyarrow
229
+ # unless we absolutely need to
230
+ if getattr(obj, '__module__', '').startswith('pandas.'):
231
+ return 'pandas'
232
+ if getattr(obj, '__module__', '').startswith('polars.'):
233
+ return 'polars'
234
+ if getattr(obj, '__module__', '').startswith('pyarrow.'):
235
+ return 'arrow'
236
+ if getattr(obj, '__module__', '').startswith('numpy.'):
237
+ return 'numpy'
238
+ if isinstance(obj, list):
239
+ return 'list'
240
+ return 'scalar'
241
+
242
+
243
def escape_name(name: str) -> str:
    """Quote a function parameter name with backticks, doubling embedded ones."""
    escaped = name.replace('`', '``')
    return '`' + escaped + '`'
248
+
249
+
250
def simplify_dtype(dtype: Any) -> List[Any]:
    """
    Expand a type annotation to a flattened list of atomic types.

    This function will attempt to find the underlying type of a
    type annotation. For example, a Union of types will be flattened
    to a list of types. A Tuple or Array type will be expanded to
    a list of types. A TypeVar will be expanded to a list of
    constraints and bounds.

    Parameters
    ----------
    dtype : Any
        Python type annotation

    Returns
    -------
    List[Any]
        list of dtype strings, TupleCollections, and ArrayCollections

    """
    origin = typing.get_origin(dtype)
    atype = type(dtype)
    args = []

    # Flatten Unions (recursing so nested unions flatten fully)
    if is_union(dtype):
        for x in typing.get_args(dtype):
            args.extend(simplify_dtype(x))

    # Expand custom types to individual types (does not support bounds)
    elif atype is TypeVar:
        for x in dtype.__constraints__:
            args.extend(simplify_dtype(x))
        if dtype.__bound__:
            args.extend(simplify_dtype(dtype.__bound__))

    # Sequence types: generics whose origin class is a Sequence subclass
    # (e.g. List[int], Tuple[int, str]); plain `str` has no origin and
    # falls through to the final branch.
    elif origin is not None and inspect.isclass(origin) and issubclass(origin, Sequence):
        # Each element annotation simplifies to one type or a list of types
        item_args: List[Union[List[type], type]] = []
        for x in typing.get_args(dtype):
            item_dtype = simplify_dtype(x)
            if len(item_dtype) == 1:
                item_args.append(item_dtype[0])
            else:
                item_args.append(item_dtype)
        # Tuples keep per-position types; other sequences must be uniform
        if origin is tuple or origin is Tuple:
            args.append(TupleCollection(*item_args))
        elif len(item_args) > 1:
            raise TypeError('sequence types may only contain one item data type')
        else:
            args.append(ArrayCollection(*item_args))

    # Not a Union or TypeVar: return the annotation unchanged
    else:
        args.append(dtype)

    return args
308
+
309
+
310
def normalize_dtype(dtype: Any) -> str:
    """
    Normalize the type annotation into a type name.

    Parameters
    ----------
    dtype : Any
        Type annotation, list of type annotations, or a string
        containing a SQL type name

    Returns
    -------
    str
        Normalized type name

    Raises
    ------
    TypeError
        If the annotation can not be mapped to a dtype name

    """
    # A list of annotations becomes a '|'-separated union of names
    if isinstance(dtype, list):
        return '|'.join(normalize_dtype(x) for x in dtype)

    # SQL type strings map through the SQL-to-dtype table
    if isinstance(dtype, str):
        return sql_to_dtype(dtype)

    # Unwrap np.dtype[...] annotations. Guard on has_numpy so this
    # module still works when numpy is not installed (previously `np`
    # was referenced unconditionally here, raising NameError).
    if has_numpy and typing.get_origin(dtype) is np.dtype:
        dtype = typing.get_args(dtype)[0]

    # Specific types
    if dtype is None or dtype is type(None):  # noqa: E721
        return 'null'
    if dtype is int:
        return 'int64'
    if dtype is float:
        return 'float64'
    if dtype is bool:
        return 'bool'

    # Composite types normalize to tuple[...] of their field types
    if utils.is_dataclass(dtype):
        dc_fields = dataclasses.fields(dtype)
        item_dtypes = ','.join(
            f'{normalize_dtype(simplify_dtype(x.type))}' for x in dc_fields
        )
        return f'tuple[{item_dtypes}]'

    if utils.is_typeddict(dtype):
        # Iterate the annotation *types* directly. Previously this took
        # the annotation keys and subscripted the TypedDict class with
        # the field name, which is not a valid lookup.
        td_fields = utils.get_annotations(dtype).values()
        item_dtypes = ','.join(
            f'{normalize_dtype(simplify_dtype(x))}' for x in td_fields
        )
        return f'tuple[{item_dtypes}]'

    if utils.is_pydantic(dtype):
        pyd_fields = dtype.model_fields.values()
        item_dtypes = ','.join(
            f'{normalize_dtype(simplify_dtype(x.annotation))}'  # type: ignore
            for x in pyd_fields
        )
        return f'tuple[{item_dtypes}]'

    if utils.is_namedtuple(dtype):
        # Same fix as the TypedDict branch: iterate the annotation
        # values (the field types) rather than subscripting the class.
        nt_fields = utils.get_annotations(dtype).values()
        item_dtypes = ','.join(
            f'{normalize_dtype(simplify_dtype(x))}' for x in nt_fields
        )
        return f'tuple[{item_dtypes}]'

    if not inspect.isclass(dtype):

        # Check for compound (generic) types
        origin = typing.get_origin(dtype)
        if origin is not None:

            # Tuple type. typing.get_origin() returns the builtin
            # `tuple` for Tuple[...] annotations, so check both forms;
            # checking only `Tuple` (as before) meant tuple annotations
            # fell into the array branch below and lost all but their
            # first element type.
            if origin is tuple or origin is Tuple:
                args = typing.get_args(dtype)
                item_dtypes = ','.join(normalize_dtype(x) for x in args)
                return f'tuple[{item_dtypes}]'

            # Array types (any other Sequence-like generic)
            elif inspect.isclass(origin) and issubclass(origin, array_types):
                args = typing.get_args(dtype)
                item_dtype = normalize_dtype(args[0])
                return f'array[{item_dtype}]'

        raise TypeError(f'unsupported type annotation: {dtype}')

    if isinstance(dtype, ArrayCollection):
        item_dtypes = ','.join(normalize_dtype(x) for x in dtype.item_dtypes)
        return f'array[{item_dtypes}]'

    if isinstance(dtype, TupleCollection):
        item_dtypes = ','.join(normalize_dtype(x) for x in dtype.item_dtypes)
        return f'tuple[{item_dtypes}]'

    # Check numpy types if it's available
    if dtype in numpy_type_map:
        return numpy_type_map[dtype]

    # Broad numeric types
    if issubclass(dtype, int):
        return 'int64'
    if issubclass(dtype, float):
        return 'float64'

    # Strings / Bytes
    if issubclass(dtype, str):
        return 'str'
    if issubclass(dtype, (bytes, bytearray)):
        return 'bytes'

    # Date / Times (datetime before date: datetime subclasses date)
    if issubclass(dtype, datetime.datetime):
        return 'datetime'
    if issubclass(dtype, datetime.date):
        return 'date'
    if issubclass(dtype, datetime.timedelta):
        return 'time'

    # Last resort, guess it by the name...
    name = dtype.__name__.lower()
    is_float = issubclass(dtype, numbers.Real)
    is_int = issubclass(dtype, numbers.Integral)
    # numbers.Integral is a subclass of numbers.Real, so check the more
    # specific integer case first; previously Real was tested first,
    # which classified every Integral type here as a float.
    if is_int:
        return int_type_map.get(name, 'int64')
    if is_float:
        return float_type_map.get(name, 'float64')

    raise TypeError(
        f'unsupported type annotation: {dtype}; '
        'use `args`/`returns` on the @udf/@tvf decorator to specify the data type',
    )
439
+
440
+
441
def collapse_dtypes(dtypes: Union[str, List[str]], include_null: bool = False) -> str:
    """
    Collapse a dtype possibly containing multiple data types to one type.

    This function can fail if there is no single type that naturally
    encompasses all of the types in the list.

    Parameters
    ----------
    dtypes : str or list[str]
        The data types to collapse
    include_null : bool, optional
        Whether to force include null types in the result

    Returns
    -------
    str

    """
    # A '|'-delimited string is treated as a list of alternatives
    if isinstance(dtypes, str) and '|' in dtypes:
        dtypes = dtypes.split('|')

    # A plain single type name passes through unchanged
    if not isinstance(dtypes, list):
        return dtypes

    # Keep the originals for the error message below
    orig_dtypes = dtypes
    dtypes = list(set(dtypes))

    # 'null' does not survive as a separate type; it becomes the '?' suffix
    is_nullable = include_null or 'null' in dtypes

    dtypes = [x for x in dtypes if x != 'null']

    # Wider unsigned ints absorb narrower unsigned ints
    if 'uint64' in dtypes:
        dtypes = [x for x in dtypes if x not in ('uint8', 'uint16', 'uint32')]
    if 'uint32' in dtypes:
        dtypes = [x for x in dtypes if x not in ('uint8', 'uint16')]
    if 'uint16' in dtypes:
        dtypes = [x for x in dtypes if x not in ('uint8',)]

    # Wider signed ints absorb narrower ints of either signedness
    # (an intN can represent every uint(N-1) value)
    if 'int64' in dtypes:
        dtypes = [
            x for x in dtypes if x not in (
                'int8', 'int16', 'int32',
                'uint8', 'uint16', 'uint32',
            )
        ]
    if 'int32' in dtypes:
        dtypes = [
            x for x in dtypes if x not in (
                'int8', 'int16',
                'uint8', 'uint16',
            )
        ]
    if 'int16' in dtypes:
        dtypes = [x for x in dtypes if x not in ('int8', 'uint8')]

    # Floats absorb smaller floats and exactly-representable ints
    if 'float64' in dtypes:
        dtypes = [
            x for x in dtypes if x not in (
                'float32',
                'int8', 'int16', 'int32', 'int64',
                'uint8', 'uint16', 'uint32', 'uint64',
            )
        ]
    if 'float32' in dtypes:
        dtypes = [
            x for x in dtypes if x not in (
                'int8', 'int16', 'int32',
                'uint8', 'uint16', 'uint32',
            )
        ]

    # Recursively collapse union alternatives inside array[...] / tuple[...]
    # NOTE(review): the '[' / '|' / ',' splits below are not bracket-aware,
    # so deeply nested compound specs may not be handled — confirm against
    # the spec grammar produced by normalize_dtype.
    for i, item in enumerate(dtypes):

        if item.startswith('array[') and '|' in item:
            _, item_spec = item.split('[', 1)
            item_spec = item_spec[:-1]
            item = collapse_dtypes(item_spec.split('|'))
            dtypes[i] = f'array[{item}]'

        elif item.startswith('tuple[') and '|' in item:
            _, item_spec = item.split('[', 1)
            item_spec = item_spec[:-1]
            sub_dtypes = []
            for subitem_spec in item_spec.split(','):
                item = collapse_dtypes(subitem_spec.split('|'))
                sub_dtypes.append(item)
            dtypes[i] = f'tuple[{",".join(sub_dtypes)}]'

    # After absorption, more than one survivor means no common type exists
    if len(dtypes) > 1:
        raise TypeError(
            'types can not be collapsed to a single type: '
            f'{", ".join(orig_dtypes)}',
        )

    # Only 'null' entries were present
    if not dtypes:
        return 'null'

    # '?' suffix marks a nullable type
    return dtypes[0] + ('?' if is_nullable else '')
540
+
541
+
542
def get_dataclass_schema(obj: Any) -> List[ParamSpec]:
    """
    Get the schema of a dataclass.

    Parameters
    ----------
    obj : dataclass
        The dataclass to get the schema of

    Returns
    -------
    List[ParamSpec]
        A list of parameter specifications for the dataclass fields

    """
    specs: List[ParamSpec] = []
    for fld in dataclasses.fields(obj):
        # dataclasses.MISSING marks fields without an explicit default
        if fld.default is dataclasses.MISSING:
            default = NO_DEFAULT
        else:
            default = fld.default
        specs.append(ParamSpec(name=fld.name, dtype=fld.type, default=default))
    return specs
565
+
566
+
567
def get_typeddict_schema(obj: Any) -> List[ParamSpec]:
    """
    Get the schema of a TypedDict.

    Parameters
    ----------
    obj : TypedDict
        The TypedDict to get the schema of

    Returns
    -------
    List[ParamSpec]
        A list of parameter specifications for the TypedDict fields

    """
    specs: List[ParamSpec] = []
    for field_name, field_type in utils.get_annotations(obj).items():
        specs.append(
            ParamSpec(
                name=field_name,
                dtype=field_type,
                default=getattr(obj, field_name, NO_DEFAULT),
            ),
        )
    return specs
590
+
591
+
592
def get_pydantic_schema(obj: Any) -> List[ParamSpec]:
    """
    Get the schema of a pydantic model.

    Parameters
    ----------
    obj : pydantic.BaseModel
        The pydantic model to get the schema of

    Returns
    -------
    List[ParamSpec]
        A list of parameter specifications for the pydantic model fields

    """
    # Imported lazily so pydantic is only required when actually used
    import pydantic_core
    undefined = pydantic_core.PydanticUndefined
    specs: List[ParamSpec] = []
    for field_name, field_info in obj.model_fields.items():
        if field_info.default is undefined:
            default = NO_DEFAULT
        else:
            default = field_info.default
        specs.append(
            ParamSpec(
                name=field_name,
                dtype=field_info.annotation,
                default=default,
            ),
        )
    return specs
617
+
618
+
619
def get_namedtuple_schema(obj: Any) -> List[ParamSpec]:
    """
    Get the schema of a named tuple.

    Parameters
    ----------
    obj : NamedTuple
        The named tuple to get the schema of

    Returns
    -------
    List[ParamSpec]
        A list of parameter specifications for the named tuple fields

    """
    # _field_defaults only contains fields that declared a default
    defaults = obj._field_defaults
    return [
        ParamSpec(
            name=field_name,
            dtype=field_type,
            default=defaults.get(field_name, NO_DEFAULT),
        )
        for field_name, field_type in utils.get_annotations(obj).items()
    ]
644
+
645
+
646
def get_table_schema(obj: Any) -> List[ParamSpec]:
    """
    Get the schema of a Table.

    Parameters
    ----------
    obj : Table
        The Table to get the schema of

    Returns
    -------
    List[ParamSpec]
        A list of parameter specifications for the Table fields

    """
    specs: List[ParamSpec] = []
    for field_name, field_type in utils.get_annotations(obj).items():
        specs.append(
            ParamSpec(
                name=field_name,
                dtype=field_type,
                default=getattr(obj, field_name, NO_DEFAULT),
            ),
        )
    return specs
669
+
670
+
671
def get_colspec(overrides: List[Any]) -> List[ParamSpec]:
    """
    Get the column specification from the overrides.

    Parameters
    ----------
    overrides : List[Any]
        The overrides to get the column specification from

    Returns
    -------
    List[ParamSpec]
        A list of parameter specifications for the column fields

    """
    # A single schema-bearing class expands to its full field schema
    if len(overrides) == 1:
        first = overrides[0]
        if utils.is_dataclass(first):
            return get_dataclass_schema(first)
        if utils.is_typeddict(first):
            return get_typeddict_schema(first)
        if utils.is_namedtuple(first):
            return get_namedtuple_schema(first)
        if utils.is_pydantic(first):
            return get_pydantic_schema(first)

    # Otherwise treat each override as a type annotation or SQL type string
    specs: List[ParamSpec] = []
    for item in overrides:
        if isinstance(item, str):
            spec = ParamSpec(
                name=getattr(item, 'name', ''),
                dtype=sql_to_dtype(item),
                sql_type=item,
            )
        else:
            spec = ParamSpec(
                name=getattr(item, 'name', ''),
                dtype=item,
                sql_type='',
            )
        specs.append(spec)
    return specs
714
+
715
+
716
def unpack_masked_type(obj: Any) -> Any:
    """
    Unpack a masked type into a single type.

    Parameters
    ----------
    obj : Any
        The masked type to unpack

    Returns
    -------
    Any
        The underlying type for Masked[T]; otherwise the object unchanged

    """
    if typing.get_origin(obj) is not Masked:
        return obj
    return typing.get_args(obj)[0]
734
+
735
+
736
+ def unwrap_optional(annotation: Any) -> Tuple[Any, bool]:
737
+ """
738
+ Unwrap Optional[T] and Union[T, None] annotations to get the underlying type.
739
+ Also indicates whether the type was optional.
740
+
741
+ Examples:
742
+ Optional[int] -> (int, True)
743
+ Union[str, None] -> (str, True)
744
+ Union[int, str, None] -> (Union[int, str], True)
745
+ Union[int, str] -> (Union[int, str], False)
746
+ int -> (int, False)
747
+
748
+ Parameters
749
+ ----------
750
+ annotation : Any
751
+ The type annotation to unwrap
752
+
753
+ Returns
754
+ -------
755
+ Tuple[Any, bool]
756
+ A tuple containing:
757
+ - The unwrapped type annotation
758
+ - A boolean indicating if the original type was optional (contained None)
759
+
760
+ """
761
+ origin = typing.get_origin(annotation)
762
+ is_optional = False
763
+
764
+ # Handle Union types (which includes Optional)
765
+ if origin is Union:
766
+ args = typing.get_args(annotation)
767
+ # Check if None is in the union
768
+ is_optional = type(None) in args
769
+
770
+ # Filter out None/NoneType
771
+ non_none_args = [arg for arg in args if arg is not type(None)]
772
+
773
+ if not non_none_args:
774
+ # If only None was in the Union
775
+ from typing import Any
776
+ return Any, is_optional
777
+ elif len(non_none_args) == 1:
778
+ # If there's only one type left, return it directly
779
+ return non_none_args[0], is_optional
780
+ else:
781
+ # Recreate the Union with the remaining types
782
+ return Union[tuple(non_none_args)], is_optional
783
+
784
+ return annotation, is_optional
785
+
786
+
787
def is_composite_type(spec: Any) -> bool:
    """
    Check if the object is a composite type (e.g., dataclass, TypedDict, etc.).

    Parameters
    ----------
    spec : Any
        The object to check

    Returns
    -------
    bool
        True if the object is a composite type, False otherwise

    """
    if not inspect.isclass(spec):
        return False
    checks = (
        utils.is_dataframe,
        utils.is_dataclass,
        utils.is_typeddict,
        utils.is_pydantic,
        utils.is_namedtuple,
    )
    return any(check(spec) for check in checks)
810
+
811
+
812
def check_composite_type(colspec: List[ParamSpec], mode: str, type_name: str) -> bool:
    """
    Check if the column specification is a composite type.

    Parameters
    ----------
    colspec : List[ParamSpec]
        The column specification to check
    mode : str
        The mode of the function, either 'parameter' or 'return'
    type_name : str
        The name of the parent type

    Returns
    -------
    bool
        Verify the composite type is valid for the given mode

    """
    # The 'parameter' and 'return' checks are identical: only the first
    # column's dtype is inspected, and a composite dtype is rejected.
    if mode in ('parameter', 'return') and is_composite_type(colspec[0].dtype):
        raise TypeError(
            'composite types are not allowed in a '
            f'{type_name}: {colspec[0].dtype.__name__}',
        )
    return False
844
+
845
+
846
def get_schema(
    spec: Any,
    overrides: Optional[List[ParamSpec]] = None,
    mode: str = 'parameter',
) -> Tuple[List[ParamSpec], str, str]:
    """
    Expand a return type annotation into a list of types and field names.

    Parameters
    ----------
    spec : Any
        The return type specification
    overrides : List[ParamSpec], optional
        List of SQL type specifications for the return type
    mode : str
        The mode of the function, either 'parameter' or 'return'

    Returns
    -------
    Tuple[List[ParamSpec], str, str]
        A list of parameter specifications for the function,
        the normalized data format, and the SQL definition of the type

    """
    colspec = []
    data_format = ''
    # Default to 'udf'; only a Table[...] return annotation switches to 'tvf'.
    function_type = 'udf'
    # Which decorator keyword to mention in error messages for this mode.
    udf_parameter = '`returns=`' if mode == 'return' else '`args=`'

    # Strip Optional[...] / Union[..., None] wrappers first; is_optional
    # records whether NULLs are permitted.
    spec, is_optional = unwrap_optional(spec)
    origin = typing.get_origin(spec)
    args = typing.get_args(spec)
    args_origins = [typing.get_origin(x) if x is not None else None for x in args]

    # Make sure that the result of a TVF is a list or dataframe
    if mode == 'return':

        # See if it's a Table subclass with annotations
        if inspect.isclass(origin) and origin is Table:

            function_type = 'tvf'

            if utils.is_dataframe(args[0]):
                # DataFrames carry no element type info, so column types
                # must come from the decorator.
                if not overrides:
                    raise TypeError(
                        'column types must be specified by the '
                        '`returns=` parameter of the @udf decorator',
                    )

                if utils.get_module(args[0]) in ['pandas', 'polars', 'pyarrow']:
                    data_format = utils.get_module(args[0])
                    spec = args[0]
                else:
                    raise TypeError(
                        'only pandas.DataFrames, polars.DataFrames, '
                        'and pyarrow.Tables are supported as tables.',
                    )

            elif typing.get_origin(args[0]) is list:
                if len(args) != 1:
                    raise TypeError(
                        'only one list is supported within a table; to '
                        'return multiple columns, use a tuple, NamedTuple, '
                        'dataclass, TypedDict, or pydantic model',
                    )
                # Unwrap Table[List[T]] down to T; rows come back as a list.
                spec = typing.get_args(args[0])[0]
                data_format = 'list'

            elif all([utils.is_vector(x, include_masks=True) for x in args]):
                # Tuple-of-vectors form; handled by the tuple branch below.
                pass

            else:
                raise TypeError(
                    'return type for TVF must be a list, DataFrame / Table, '
                    'or tuple of vectors',
                )

        # Short circuit check for common valid types
        elif utils.is_vector(spec) or spec in {str, float, int, bytes}:
            pass

        # Try to catch some common mistakes
        elif origin in [tuple, dict] or tuple in args_origins or is_composite_type(spec):
            raise TypeError(
                'invalid return type for a UDF; expecting a scalar or vector, '
                f'but got {getattr(spec, "__name__", spec)}',
            )

    # Short circuit check for common valid types
    elif utils.is_vector(spec) or spec in {str, float, int, bytes}:
        pass

    # Error out for incorrect parameter types
    elif origin in [tuple, dict] or tuple in args_origins or is_composite_type(spec):
        raise TypeError(
            'parameter types must be scalar or vector, '
            f'got {getattr(spec, "__name__", spec)}',
        )

    #
    # Process each parameter / return type into a colspec
    #

    # Dataframe type
    if utils.is_dataframe(spec):
        if not overrides:
            raise TypeError(
                'column types must be specified in the '
                f'{udf_parameter} parameter of the @udf decorator for a DataFrame',
            )
        # colspec = get_colspec(overrides[0].dtype)
        colspec = overrides

    # Numpy array types
    elif utils.is_numpy(spec):
        data_format = 'numpy'
        if overrides:
            colspec = overrides
        elif len(typing.get_args(spec)) < 2:
            # NDArray without an explicit element dtype annotation.
            raise TypeError(
                'numpy array must have an element data type specified '
                f'in the {udf_parameter} parameter of the @udf decorator '
                'or with an NDArray type annotation',
            )
        else:
            # NDArray[shape, dtype]: element type is the second type arg.
            colspec = [ParamSpec(dtype=typing.get_args(spec)[1])]
        check_composite_type(colspec, mode, 'numpy array')

    # Pandas Series
    elif utils.is_pandas_series(spec):
        data_format = 'pandas'
        if not overrides:
            raise TypeError(
                'pandas Series must have an element data type specified '
                f'in the {udf_parameter} parameter of the @udf decorator',
            )
        colspec = overrides
        check_composite_type(colspec, mode, 'pandas Series')

    # Polars Series
    elif utils.is_polars_series(spec):
        data_format = 'polars'
        if not overrides:
            raise TypeError(
                'polars Series must have an element data type specified '
                f'in the {udf_parameter} parameter of the @udf decorator',
            )
        colspec = overrides
        check_composite_type(colspec, mode, 'polars Series')

    # PyArrow Array
    elif utils.is_pyarrow_array(spec):
        data_format = 'arrow'
        if not overrides:
            raise TypeError(
                'pyarrow Arrays must have an element data type specified '
                f'in the {udf_parameter} parameter of the @udf decorator',
            )
        colspec = overrides
        check_composite_type(colspec, mode, 'pyarrow Array')

    # Return type is specified by a dataclass definition
    elif utils.is_dataclass(spec):
        colspec = overrides or get_dataclass_schema(spec)

    # Return type is specified by a TypedDict definition
    elif utils.is_typeddict(spec):
        colspec = overrides or get_typeddict_schema(spec)

    # Return type is specified by a pydantic model
    elif utils.is_pydantic(spec):
        colspec = overrides or get_pydantic_schema(spec)

    # Return type is specified by a named tuple
    elif utils.is_namedtuple(spec):
        colspec = overrides or get_namedtuple_schema(spec)

    # Unrecognized return type
    elif spec is not None:

        # Return type is specified by a SQL string
        if isinstance(spec, str):
            data_format = 'scalar'
            colspec = [ParamSpec(dtype=spec, is_optional=is_optional)]

        # Plain list vector
        elif typing.get_origin(spec) is list:
            data_format = 'list'
            colspec = [ParamSpec(dtype=typing.get_args(spec)[0], is_optional=is_optional)]

        # Multiple return values
        elif inspect.isclass(typing.get_origin(spec)) \
                and issubclass(typing.get_origin(spec), tuple):  # type: ignore[arg-type]

            # Make sure that the number of overrides matches the number of
            # return types or parameter types
            if overrides and len(typing.get_args(spec)) != len(overrides):
                raise ValueError(
                    f'number of {mode} types does not match the number of '
                    'overrides specified',
                )

            colspec = []
            out_data_formats = []

            # Get the colspec for each item in the tuple (recursive call,
            # one element at a time).
            for i, x in enumerate(typing.get_args(spec)):
                params, out_data_format, _ = get_schema(
                    unpack_masked_type(x),
                    overrides=[overrides[i]] if overrides else [],
                    # Always pass UDF mode for individual items
                    mode=mode,
                )

                # Use the name from the overrides if specified; otherwise
                # fall back to generated names 'a', 'b', 'c', ...
                if overrides:
                    if overrides[i] and not params[0].name:
                        params[0].name = overrides[i].name
                    elif not overrides[i].name:
                        params[0].name = f'{string.ascii_letters[i]}'

                colspec.append(params[0])
                out_data_formats.append(out_data_format)

            # Make sure that all the data formats are the same
            if len(set(out_data_formats)) > 1:
                raise TypeError(
                    'data formats must be all be the same vector / scalar type: '
                    f'{", ".join(out_data_formats)}',
                )

            if data_format != 'list' and out_data_formats:
                data_format = out_data_formats[0]

            # Since the colspec was computed by get_schema already, don't go
            # through the process of normalizing the dtypes again
            return colspec, data_format, function_type  # type: ignore

        # Use overrides if specified
        elif overrides:
            if not data_format:
                data_format = get_data_format(spec)
            colspec = overrides

        # Single value, no override
        else:
            if not data_format:
                data_format = 'scalar'
            colspec = [ParamSpec(dtype=spec, is_optional=is_optional)]

    out = []

    # Normalize colspec data types: string dtypes are taken verbatim;
    # Python type annotations are collapsed into normalized dtype strings.
    for c in colspec:

        if isinstance(c.dtype, str):
            dtype = c.dtype
        else:
            dtype = collapse_dtypes(
                [normalize_dtype(x) for x in simplify_dtype(c.dtype)],
                include_null=c.is_optional,
            )

        p = ParamSpec(
            name=c.name,
            dtype=dtype,
            sql_type=c.sql_type if isinstance(c.sql_type, str) else None,
            is_optional=c.is_optional,
        )

        out.append(p)

    return out, data_format, function_type
1119
+
1120
+
1121
def vector_check(obj: Any) -> Tuple[Any, str]:
    """
    Check if the object is a vector type.

    Parameters
    ----------
    obj : Any
        The object to check

    Returns
    -------
    Tuple[Any, str]
        The scalar type and the data format:
        'scalar', 'list', 'numpy', 'pandas', or 'polars'

    """
    # Numpy / pandas carry the element type as the second type argument
    # (when annotated); polars and pyarrow expose no element type here.
    if utils.is_numpy(obj):
        type_args = typing.get_args(obj)
        element = type_args[1] if len(type_args) >= 2 else None
        return element, 'numpy'

    if utils.is_pandas_series(obj):
        type_args = typing.get_args(obj)
        element = type_args[1] if len(type_args) >= 2 else None
        return element, 'pandas'

    if utils.is_polars_series(obj):
        return None, 'polars'

    if utils.is_pyarrow_array(obj):
        return None, 'arrow'

    # Plain list annotation: element type is the first type argument.
    if obj is list or typing.get_origin(obj) is list:
        type_args = typing.get_args(obj)
        element = type_args[0] if type_args else None
        return element, 'list'

    # Anything else is treated as a scalar of its own type.
    return obj, 'scalar'
1154
+
1155
+
1156
def get_masks(func: Callable[..., Any]) -> Tuple[List[bool], List[bool]]:
    """
    Get the list of masked parameters and return values for the function.

    Parameters
    ----------
    func : Callable
        The function to call as the endpoint

    Returns
    -------
    Tuple[List[bool], List[bool]]
        A Tuple containing the parameter / return value masks
        as lists of booleans

    """
    # Compute the signature once instead of twice (inspect.signature is
    # not free; the original called it separately for params and returns).
    signature = inspect.signature(func)
    params = signature.parameters
    returns = signature.return_annotation

    ret_masks = []
    if typing.get_origin(returns) is Masked:
        # Single masked return value.
        ret_masks = [True]
    elif typing.get_origin(returns) is Table:
        # Table return: record a flag per column.
        for x in typing.get_args(returns):
            ret_masks.append(typing.get_origin(x) is Masked)
        # If no column is masked, report no masks at all (empty list),
        # matching the single-return case.
        if not any(ret_masks):
            ret_masks = []

    return (
        [typing.get_origin(x.annotation) is Masked for x in params.values()],
        ret_masks,
    )
1192
+
1193
+
1194
def get_signature(
    func: Callable[..., Any],
    func_name: Optional[str] = None,
) -> Dict[str, Any]:
    '''
    Print the UDF signature of the Python callable.

    Parameters
    ----------
    func : Callable
        The function to extract the signature of
    func_name : str, optional
        Name override for function

    Returns
    -------
    Dict[str, Any]
        Keys include 'name', 'args', 'returns', 'args_data_format',
        'returns_data_format', 'function_type', 'endpoint', and 'doc'

    '''
    signature = inspect.signature(func)
    args: List[Dict[str, Any]] = []
    returns: List[Dict[str, Any]] = []

    attrs = getattr(func, '_singlestoredb_attrs', {})
    name = attrs.get('name', func_name if func_name else func.__name__)

    out: Dict[str, Any] = dict(name=name, args=args, returns=returns)

    # Do not allow variable positional or keyword arguments
    for p in signature.parameters.values():
        if p.kind == inspect.Parameter.VAR_POSITIONAL:
            raise TypeError('variable positional arguments are not supported')
        elif p.kind == inspect.Parameter.VAR_KEYWORD:
            raise TypeError('variable keyword arguments are not supported')

    # TODO: Use typing.get_type_hints() for parameters / return values?

    # Generate the parameter type and the corresponding SQL code for that parameter
    args_schema: List[ParamSpec] = []
    args_data_formats = []
    args_colspec = list(get_colspec(attrs.get('args', [])))
    args_masks, ret_masks = get_masks(func)

    if args_colspec and len(args_colspec) != len(signature.parameters):
        raise ValueError(
            'number of args in the decorator does not match '
            'the number of parameters in the function signature',
        )

    params = list(signature.parameters.values())

    # Get the colspec for each parameter
    for i, param in enumerate(params):
        arg_schema, args_data_format, _ = get_schema(
            unpack_masked_type(param.annotation),
            overrides=[args_colspec[i]] if args_colspec else [],
            mode='parameter',
        )
        args_data_formats.append(args_data_format)

        if len(arg_schema) > 1:
            raise TypeError(
                'only one parameter type is supported; '
                f'got {len(arg_schema)} types for parameter {param.name}',
            )

        # Insert parameter names as needed
        if not arg_schema[0].name:
            arg_schema[0].name = param.name

        args_schema.append(arg_schema[0])

    for i, pspec in enumerate(args_schema):
        default_option = {}

        # Insert default values as needed: decorator overrides take
        # precedence over Python-level default values.
        # NOTE: the original compared against `param.empty` using the loop
        # variable leaked from the previous loop; use the class attribute
        # explicitly instead.
        if args_colspec and args_colspec[i].default is not NO_DEFAULT:
            default_option['default'] = args_colspec[i].default
        elif params and params[i].default is not inspect.Parameter.empty:
            default_option['default'] = params[i].default

        # Generate SQL code for the parameter
        sql = pspec.sql_type or dtype_to_sql(
            pspec.dtype,
            force_nullable=args_masks[i] or pspec.is_optional,
            **default_option,
        )

        # Add parameter to args definitions
        args.append(
            dict(
                name=pspec.name,
                dtype=pspec.dtype,
                sql=sql,
                **default_option,
                transformer=pspec.transformer,
            ),
        )

    # Check that all the data formats are all the same
    if len(set(args_data_formats)) > 1:
        raise TypeError(
            'input data formats must be all be the same: '
            f'{", ".join(args_data_formats)}',
        )

    adf = out['args_data_format'] = args_data_formats[0] \
        if args_data_formats else 'scalar'

    returns_colspec = get_colspec(attrs.get('returns', []))

    # Generate the return types and the corresponding SQL code for those values
    ret_schema, out['returns_data_format'], function_type = get_schema(
        unpack_masked_type(signature.return_annotation),
        overrides=returns_colspec if returns_colspec else None,
        mode='return',
    )

    rdf = out['returns_data_format'] = out['returns_data_format'] or 'scalar'
    out['function_type'] = function_type

    # Reality check the input and output data formats.
    # BUGFIX: the original raise passed two separate string arguments to
    # TypeError (a stray comma split the message), producing a garbled
    # tuple-style error; the message is now a single concatenated string.
    if function_type == 'udf':
        if (adf == 'scalar' and rdf != 'scalar') or \
                (adf != 'scalar' and rdf == 'scalar'):
            raise TypeError(
                'Function can not have scalar arguments and a vector return type, '
                'or vice versa. Parameters and return values must all be either '
                'scalar or vector types.',
            )

    # All functions have to return a value, so if none was specified try to
    # insert a reasonable default that includes NULLs.
    if not ret_schema:
        ret_schema = [
            ParamSpec(
                dtype='int8?', sql_type='TINYINT NULL', default=None, is_optional=True,
            ),
        ]

    if function_type == 'udf' and len(ret_schema) > 1:
        raise ValueError(
            'UDFs can only return a single value; '
            f'got {len(ret_schema)} return values',
        )

    # Generate field names for the return values.
    # BUGFIX: the rebuilt ParamSpec now carries `is_optional` forward;
    # previously it was dropped, silently losing nullability on auto-named
    # return columns.
    if function_type == 'tvf' or len(ret_schema) > 1:
        for i, rspec in enumerate(ret_schema):
            if not rspec.name:
                ret_schema[i] = ParamSpec(
                    name=string.ascii_letters[i],
                    dtype=rspec.dtype,
                    sql_type=rspec.sql_type,
                    is_optional=rspec.is_optional,
                    transformer=rspec.transformer,
                )

    # Generate SQL code for the return values
    for i, rspec in enumerate(ret_schema):
        sql = rspec.sql_type or dtype_to_sql(
            rspec.dtype,
            force_nullable=(ret_masks[i] or rspec.is_optional)
            if ret_masks else rspec.is_optional,
            function_type=function_type,
        )
        returns.append(
            dict(
                name=rspec.name,
                dtype=rspec.dtype,
                sql=sql,
                transformer=rspec.transformer,
            ),
        )

    # Set the function endpoint
    out['endpoint'] = '/invoke'

    # Set the function doc string
    out['doc'] = func.__doc__

    return out
1375
+
1376
+
1377
def sql_to_dtype(sql: str) -> str:
    """
    Convert a SQL type into a normalized data type identifier.

    Parameters
    ----------
    sql : str
        SQL data type specification

    Returns
    -------
    str
        Normalized dtype name; a trailing '?' marks a nullable type

    Raises
    ------
    TypeError
        If the SQL type can not be parsed or is not a known type

    """
    # Normalize whitespace and case before parsing.
    sql = re.sub(r'\s+', r' ', sql.upper().strip())

    # BUGFIX: the inner group now captures only the text *inside* the
    # parentheses.  Previously group(2) included the parens themselves
    # ('(6)'), so the `== '6'` check below could never match and
    # DATETIME(6) / TIME(6) / TIMESTAMP(6) lost their precision suffix.
    m = re.match(r'(\w+)(?:\(([^\)]+)\))?', sql)
    if not m:
        raise TypeError(f'unrecognized data type: {sql}')

    sql_type = m.group(1)
    # Comma-separated attributes from inside the parentheses, e.g.
    # '10,2' for DECIMAL(10,2) or '6' for DATETIME(6).
    type_attrs = re.split(r'\s*,\s*', m.group(2) or '')

    if sql_type in ('DATETIME', 'TIME', 'TIMESTAMP') and \
            type_attrs and type_attrs[0] == '6':
        # Microsecond-precision temporal types map to distinct dtypes.
        sql_type += '(6)'

    elif ' UNSIGNED' in sql:
        sql_type += ' UNSIGNED'

    try:
        dtype = sql_to_type_map[sql_type]
    except KeyError:
        raise TypeError(f'unrecognized data type: {sql_type}')

    # Types are nullable unless explicitly declared NOT NULL.
    if ' NOT NULL' not in sql:
        dtype += '?'

    return dtype
1416
+
1417
+
1418
def dtype_to_sql(
    dtype: str,
    default: Any = NO_DEFAULT,
    field_names: Optional[List[str]] = None,
    function_type: str = 'udf',
    force_nullable: bool = False,
) -> str:
    """
    Convert a collapsed dtype string to a SQL type.

    Parameters
    ----------
    dtype : str
        Simplified data type string
    default : Any, optional
        Default value
    field_names : List[str], optional
        Field names for tuple types
    function_type : str, optional
        Function type, either 'udf' or 'tvf'
    force_nullable : bool, optional
        Whether to force the type to be nullable

    Returns
    -------
    str

    """
    # Nullability markers: trailing '?' or an embedded '|null' both mean
    # NULL is allowed; otherwise the column is NOT NULL unless forced.
    nullable = ' NOT NULL'
    if dtype.endswith('?'):
        nullable = ' NULL'
        dtype = dtype[:-1]
    elif '|null' in dtype:
        nullable = ' NULL'
        dtype = dtype.replace('|null', '')
    elif force_nullable:
        nullable = ' NULL'

    # A bare 'null' dtype gets no nullability clause at all.
    if dtype == 'null':
        nullable = ''

    default_clause = ''
    if default is not NO_DEFAULT:
        # dt.NULL is a sentinel for SQL NULL; translate it to Python None
        # so escape_item renders it as NULL.
        if default is dt.NULL:
            default = None
        default_clause = f' DEFAULT {escape_item(default, "utf8")}'

    # array[T] -> ARRAY(<sql for T>); recurse for the element type.
    if dtype.startswith('array['):
        _, dtypes = dtype.split('[', 1)
        dtypes = dtypes[:-1]
        item_dtype = dtype_to_sql(dtypes, function_type=function_type)
        return f'ARRAY({item_dtype}){nullable}{default_clause}'

    # tuple[a=T1,b=T2,...] -> RECORD(...) for UDFs, TABLE(...) for TVFs.
    if dtype.startswith('tuple['):
        _, dtypes = dtype.split('[', 1)
        dtypes = dtypes[:-1]
        item_dtypes = []
        # NOTE(review): this naive split(',') assumes tuple items contain
        # no commas of their own (i.e. no nested tuple/array dtypes here)
        # -- TODO confirm nested composites can't reach this path.
        for i, item in enumerate(dtypes.split(',')):
            # Field name priority: explicit field_names, then generated
            # letters, then an inline 'name=dtype' override.
            if field_names:
                name = field_names[i]
            else:
                name = string.ascii_letters[i]
            if '=' in item:
                name, item = item.split('=', 1)
            item_dtypes.append(
                f'`{name}` ' + dtype_to_sql(item, function_type=function_type),
            )
        if function_type == 'udf':
            return f'RECORD({", ".join(item_dtypes)}){nullable}{default_clause}'
        else:
            # TABLE(...) return types never carry a trailing NOT NULL;
            # strip it if it ends up at the end of the string.
            return re.sub(
                r' NOT NULL\s*$', r'',
                f'TABLE({", ".join(item_dtypes)}){nullable}{default_clause}',
            )

    # Scalar dtype: straight lookup in the module-level type map.
    return f'{sql_type_map[dtype]}{nullable}{default_clause}'
1494
+
1495
+
1496
def signature_to_sql(
    signature: Dict[str, Any],
    url: Optional[str] = None,
    data_format: str = 'rowdat_1',
    app_mode: str = 'remote',
    link: Optional[str] = None,
    replace: bool = False,
    database: Optional[str] = None,
) -> str:
    '''
    Convert a dictionary function signature into SQL.

    Parameters
    ----------
    signature : Dict[str, Any]
        Function signature in the form of a dictionary as returned by
        the `get_signature` function
    url : str, optional
        Service URL; required unless ``app_mode`` is 'remote', in which
        case a default is built from environment variables
    data_format : str, optional
        Wire format name emitted in the FORMAT clause
    app_mode : str, optional
        Service mode keyword placed in the AS ... SERVICE clause
    link : str, optional
        Name of a LINK object to attach to the function
    replace : bool, optional
        Whether to emit CREATE OR REPLACE
    database : str, optional
        Database to qualify the function name with (the signature's own
        'database' key takes precedence)

    Returns
    -------
    str : SQL formatted function signature

    '''
    function_type = signature.get('function_type') or 'udf'

    args = []
    for arg in signature['args']:
        # Use default value from Python function if SQL doesn't set one.
        # (The original had a redundant re-assignment of `default = ''`
        # inside the branch; the two conditions are merged here.)
        default = ''
        if not re.search(r'\s+default\s+\S', arg['sql'], flags=re.I) \
                and arg.get('default', None) is not None:
            default = f' DEFAULT {escape_item(arg["default"], "utf8")}'
        args.append(escape_name(arg['name']) + ' ' + arg['sql'] + default)

    returns = ''
    if signature.get('returns'):
        ret = signature['returns']
        if function_type == 'tvf':
            # TVFs return a TABLE(...) of named columns.
            res = 'TABLE(' + ', '.join(
                f'{escape_name(x["name"])} {x["sql"]}' for x in ret
            ) + ')'
        elif ret[0]['name'] and len(ret) > 1:
            # Multiple named return values become a RECORD(...).
            res = 'RECORD(' + ', '.join(
                f'{escape_name(x["name"])} {x["sql"]}' for x in ret
            ) + ')'
        else:
            res = ret[0]['sql']
        returns = f' RETURNS {res}'
    else:
        raise ValueError(
            'function signature must have a return type specified',
        )

    host = os.environ.get('SINGLESTOREDB_EXT_HOST', '127.0.0.1')
    port = os.environ.get('SINGLESTOREDB_EXT_PORT', '8000')

    if app_mode.lower() == 'remote':
        url = url or f'https://{host}:{port}/invoke'
    elif url is None:
        raise ValueError('url can not be `None`')

    database_prefix = ''
    if signature.get('database'):
        database_prefix = escape_name(signature['database']) + '.'
    elif database is not None:
        database_prefix = escape_name(database) + '.'

    or_replace = 'OR REPLACE ' if (bool(signature.get('replace')) or replace) else ''

    link_str = ''
    if link:
        # Restrict link names to word characters to avoid injecting SQL.
        if not re.match(r'^[\w_]+$', link):
            raise ValueError(f'invalid LINK name: {link}')
        link_str = f' LINK {link}'

    return (
        f'CREATE {or_replace}EXTERNAL FUNCTION ' +
        f'{database_prefix}{escape_name(signature["name"])}' +
        '(' + ', '.join(args) + ')' + returns +
        f' AS {app_mode.upper()} SERVICE "{url}" FORMAT {data_format.upper()}'
        f'{link_str};'
    )