singlestoredb 1.12.4__cp38-abi3-win32.whl → 1.13.0__cp38-abi3-win32.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of singlestoredb might be problematic. Click here for more details.
- _singlestoredb_accel.pyd +0 -0
- singlestoredb/__init__.py +1 -1
- singlestoredb/apps/__init__.py +1 -0
- singlestoredb/apps/_config.py +6 -0
- singlestoredb/apps/_connection_info.py +8 -0
- singlestoredb/apps/_python_udfs.py +85 -0
- singlestoredb/config.py +14 -2
- singlestoredb/functions/__init__.py +11 -1
- singlestoredb/functions/decorator.py +102 -252
- singlestoredb/functions/dtypes.py +545 -198
- singlestoredb/functions/ext/asgi.py +288 -90
- singlestoredb/functions/ext/json.py +29 -36
- singlestoredb/functions/ext/mmap.py +1 -1
- singlestoredb/functions/ext/rowdat_1.py +50 -70
- singlestoredb/functions/signature.py +816 -144
- singlestoredb/functions/typing.py +41 -0
- singlestoredb/functions/utils.py +342 -0
- singlestoredb/http/connection.py +3 -1
- singlestoredb/management/manager.py +6 -1
- singlestoredb/management/utils.py +2 -2
- singlestoredb/tests/ext_funcs/__init__.py +476 -237
- singlestoredb/tests/test_ext_func.py +192 -3
- singlestoredb/tests/test_udf.py +101 -131
- singlestoredb/tests/test_udf_returns.py +459 -0
- {singlestoredb-1.12.4.dist-info → singlestoredb-1.13.0.dist-info}/METADATA +2 -1
- {singlestoredb-1.12.4.dist-info → singlestoredb-1.13.0.dist-info}/RECORD +30 -26
- {singlestoredb-1.12.4.dist-info → singlestoredb-1.13.0.dist-info}/LICENSE +0 -0
- {singlestoredb-1.12.4.dist-info → singlestoredb-1.13.0.dist-info}/WHEEL +0 -0
- {singlestoredb-1.12.4.dist-info → singlestoredb-1.13.0.dist-info}/entry_points.txt +0 -0
- {singlestoredb-1.12.4.dist-info → singlestoredb-1.13.0.dist-info}/top_level.txt +0 -0
|
@@ -16,6 +16,7 @@ from typing import List
|
|
|
16
16
|
from typing import Optional
|
|
17
17
|
from typing import Sequence
|
|
18
18
|
from typing import Tuple
|
|
19
|
+
from typing import Type
|
|
19
20
|
from typing import TypeVar
|
|
20
21
|
from typing import Union
|
|
21
22
|
|
|
@@ -25,13 +26,11 @@ try:
|
|
|
25
26
|
except ImportError:
|
|
26
27
|
has_numpy = False
|
|
27
28
|
|
|
28
|
-
try:
|
|
29
|
-
import pydantic
|
|
30
|
-
has_pydantic = True
|
|
31
|
-
except ImportError:
|
|
32
|
-
has_pydantic = False
|
|
33
29
|
|
|
34
30
|
from . import dtypes as dt
|
|
31
|
+
from . import utils
|
|
32
|
+
from .typing import Table
|
|
33
|
+
from .typing import Masked
|
|
35
34
|
from ..mysql.converters import escape_item # type: ignore
|
|
36
35
|
|
|
37
36
|
if sys.version_info >= (3, 10):
|
|
@@ -40,6 +39,18 @@ else:
|
|
|
40
39
|
_UNION_TYPES = {typing.Union}
|
|
41
40
|
|
|
42
41
|
|
|
42
|
+
def is_union(x: Any) -> bool:
    """
    Return whether `x` is a Union type annotation.

    The check compares the annotation's origin, as reported by
    ``typing.get_origin``, against the module-level ``_UNION_TYPES`` set.

    Parameters
    ----------
    x : Any
        The type annotation to inspect

    Returns
    -------
    bool

    """
    annotation_origin = typing.get_origin(x)
    return annotation_origin in _UNION_TYPES
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
class NoDefaultType:
    """
    Sentinel type marking a field / parameter that has no default value.

    Only the single module-level instance ``NO_DEFAULT`` is meant to be
    used, and it should be compared with ``is``.  Copying or pickling the
    sentinel returns that same instance, so identity checks keep working
    after ``copy.deepcopy`` or a pickle round trip.

    """

    def __repr__(self) -> str:
        return 'NO_DEFAULT'

    def __reduce__(self) -> str:
        # Returning a string tells copy / pickle to treat the object as the
        # module global of that name, which preserves the singleton identity
        return 'NO_DEFAULT'


NO_DEFAULT = NoDefaultType()
|
|
52
|
+
|
|
53
|
+
|
|
43
54
|
array_types: Tuple[Any, ...]
|
|
44
55
|
|
|
45
56
|
if has_numpy:
|
|
@@ -192,6 +203,23 @@ class ArrayCollection(Collection):
|
|
|
192
203
|
pass
|
|
193
204
|
|
|
194
205
|
|
|
206
|
+
def get_data_format(obj: Any) -> str:
    """
    Return the data format of the DataFrame / Table / vector.

    The format is inferred from the top-level package named in the
    object's ``__module__`` attribute.  This avoids importing pandas /
    polars / pyarrow / numpy just to perform the check.  Both dotted
    module paths (``pandas.core.frame``) and bare package names
    (``numpy``, as used by ``numpy.ndarray``) are recognized.

    Parameters
    ----------
    obj : Any
        A value or a type annotation to classify

    Returns
    -------
    str
        One of 'pandas', 'polars', 'arrow', 'numpy', 'list', or 'scalar'

    """
    formats_by_package = {
        'pandas': 'pandas',
        'polars': 'polars',
        'pyarrow': 'arrow',
        'numpy': 'numpy',
    }

    # ``__module__`` may be missing (most builtin instances) or None
    # (e.g. functions created without ``__name__`` in their globals)
    module = getattr(obj, '__module__', None)
    package = module.split('.', 1)[0] if isinstance(module, str) else ''

    if package in formats_by_package:
        return formats_by_package[package]

    if isinstance(obj, list):
        return 'list'

    return 'scalar'
|
|
221
|
+
|
|
222
|
+
|
|
195
223
|
def escape_name(name: str) -> str:
|
|
196
224
|
"""Escape a function parameter name."""
|
|
197
225
|
if '`' in name:
|
|
@@ -203,6 +231,12 @@ def simplify_dtype(dtype: Any) -> List[Any]:
|
|
|
203
231
|
"""
|
|
204
232
|
Expand a type annotation to a flattened list of atomic types.
|
|
205
233
|
|
|
234
|
+
This function will attempt to find the underlying type of a
|
|
235
|
+
type annotation. For example, a Union of types will be flattened
|
|
236
|
+
to a list of types. A Tuple or Array type will be expanded to
|
|
237
|
+
a list of types. A TypeVar will be expanded to a list of
|
|
238
|
+
constraints and bounds.
|
|
239
|
+
|
|
206
240
|
Parameters
|
|
207
241
|
----------
|
|
208
242
|
dtype : Any
|
|
@@ -210,7 +244,8 @@ def simplify_dtype(dtype: Any) -> List[Any]:
|
|
|
210
244
|
|
|
211
245
|
Returns
|
|
212
246
|
-------
|
|
213
|
-
List[Any]
|
|
247
|
+
List[Any]
|
|
248
|
+
list of dtype strings, TupleCollections, and ArrayCollections
|
|
214
249
|
|
|
215
250
|
"""
|
|
216
251
|
origin = typing.get_origin(dtype)
|
|
@@ -218,7 +253,7 @@ def simplify_dtype(dtype: Any) -> List[Any]:
|
|
|
218
253
|
args = []
|
|
219
254
|
|
|
220
255
|
# Flatten Unions
|
|
221
|
-
if
|
|
256
|
+
if is_union(dtype):
|
|
222
257
|
for x in typing.get_args(dtype):
|
|
223
258
|
args.extend(simplify_dtype(x))
|
|
224
259
|
|
|
@@ -230,7 +265,7 @@ def simplify_dtype(dtype: Any) -> List[Any]:
|
|
|
230
265
|
args.extend(simplify_dtype(dtype.__bound__))
|
|
231
266
|
|
|
232
267
|
# Sequence types
|
|
233
|
-
elif origin is not None and issubclass(origin, Sequence):
|
|
268
|
+
elif origin is not None and inspect.isclass(origin) and issubclass(origin, Sequence):
|
|
234
269
|
item_args: List[Union[List[type], type]] = []
|
|
235
270
|
for x in typing.get_args(dtype):
|
|
236
271
|
item_dtype = simplify_dtype(x)
|
|
@@ -252,14 +287,31 @@ def simplify_dtype(dtype: Any) -> List[Any]:
|
|
|
252
287
|
return args
|
|
253
288
|
|
|
254
289
|
|
|
255
|
-
def
|
|
256
|
-
"""
|
|
290
|
+
def normalize_dtype(dtype: Any) -> str:
|
|
291
|
+
"""
|
|
292
|
+
Normalize the type annotation into a type name.
|
|
293
|
+
|
|
294
|
+
Parameters
|
|
295
|
+
----------
|
|
296
|
+
dtype : Any
|
|
297
|
+
Type annotation, list of type annotations, or a string
|
|
298
|
+
containing a SQL type name
|
|
299
|
+
|
|
300
|
+
Returns
|
|
301
|
+
-------
|
|
302
|
+
str
|
|
303
|
+
Normalized type name
|
|
304
|
+
|
|
305
|
+
"""
|
|
257
306
|
if isinstance(dtype, list):
|
|
258
|
-
return '|'.join(
|
|
307
|
+
return '|'.join(normalize_dtype(x) for x in dtype)
|
|
259
308
|
|
|
260
309
|
if isinstance(dtype, str):
|
|
261
310
|
return sql_to_dtype(dtype)
|
|
262
311
|
|
|
312
|
+
if typing.get_origin(dtype) is np.dtype:
|
|
313
|
+
dtype = typing.get_args(dtype)[0]
|
|
314
|
+
|
|
263
315
|
# Specific types
|
|
264
316
|
if dtype is None or dtype is type(None): # noqa: E721
|
|
265
317
|
return 'null'
|
|
@@ -270,45 +322,61 @@ def classify_dtype(dtype: Any) -> str:
|
|
|
270
322
|
if dtype is bool:
|
|
271
323
|
return 'bool'
|
|
272
324
|
|
|
273
|
-
if
|
|
274
|
-
|
|
325
|
+
if utils.is_dataclass(dtype):
|
|
326
|
+
dc_fields = dataclasses.fields(dtype)
|
|
327
|
+
item_dtypes = ','.join(
|
|
328
|
+
f'{normalize_dtype(simplify_dtype(x.type))}' for x in dc_fields
|
|
329
|
+
)
|
|
330
|
+
return f'tuple[{item_dtypes}]'
|
|
331
|
+
|
|
332
|
+
if utils.is_typeddict(dtype):
|
|
333
|
+
td_fields = utils.get_annotations(dtype).keys()
|
|
334
|
+
item_dtypes = ','.join(
|
|
335
|
+
f'{normalize_dtype(simplify_dtype(dtype[x]))}' for x in td_fields
|
|
336
|
+
)
|
|
337
|
+
return f'tuple[{item_dtypes}]'
|
|
338
|
+
|
|
339
|
+
if utils.is_pydantic(dtype):
|
|
340
|
+
pyd_fields = dtype.model_fields.values()
|
|
275
341
|
item_dtypes = ','.join(
|
|
276
|
-
f'{
|
|
342
|
+
f'{normalize_dtype(simplify_dtype(x.annotation))}' # type: ignore
|
|
343
|
+
for x in pyd_fields
|
|
277
344
|
)
|
|
278
345
|
return f'tuple[{item_dtypes}]'
|
|
279
346
|
|
|
280
|
-
if
|
|
281
|
-
|
|
347
|
+
if utils.is_namedtuple(dtype):
|
|
348
|
+
nt_fields = utils.get_annotations(dtype).values()
|
|
282
349
|
item_dtypes = ','.join(
|
|
283
|
-
f'{
|
|
284
|
-
for x in fields
|
|
350
|
+
f'{normalize_dtype(simplify_dtype(dtype[x]))}' for x in nt_fields
|
|
285
351
|
)
|
|
286
352
|
return f'tuple[{item_dtypes}]'
|
|
287
353
|
|
|
288
354
|
if not inspect.isclass(dtype):
|
|
355
|
+
|
|
289
356
|
# Check for compound types
|
|
290
357
|
origin = typing.get_origin(dtype)
|
|
291
358
|
if origin is not None:
|
|
359
|
+
|
|
292
360
|
# Tuple type
|
|
293
361
|
if origin is Tuple:
|
|
294
362
|
args = typing.get_args(dtype)
|
|
295
|
-
item_dtypes = ','.join(
|
|
363
|
+
item_dtypes = ','.join(normalize_dtype(x) for x in args)
|
|
296
364
|
return f'tuple[{item_dtypes}]'
|
|
297
365
|
|
|
298
366
|
# Array types
|
|
299
|
-
elif issubclass(origin, array_types):
|
|
367
|
+
elif inspect.isclass(origin) and issubclass(origin, array_types):
|
|
300
368
|
args = typing.get_args(dtype)
|
|
301
|
-
item_dtype =
|
|
369
|
+
item_dtype = normalize_dtype(args[0])
|
|
302
370
|
return f'array[{item_dtype}]'
|
|
303
371
|
|
|
304
372
|
raise TypeError(f'unsupported type annotation: {dtype}')
|
|
305
373
|
|
|
306
374
|
if isinstance(dtype, ArrayCollection):
|
|
307
|
-
item_dtypes = ','.join(
|
|
375
|
+
item_dtypes = ','.join(normalize_dtype(x) for x in dtype.item_dtypes)
|
|
308
376
|
return f'array[{item_dtypes}]'
|
|
309
377
|
|
|
310
378
|
if isinstance(dtype, TupleCollection):
|
|
311
|
-
item_dtypes = ','.join(
|
|
379
|
+
item_dtypes = ','.join(normalize_dtype(x) for x in dtype.item_dtypes)
|
|
312
380
|
return f'tuple[{item_dtypes}]'
|
|
313
381
|
|
|
314
382
|
# Check numpy types if it's available
|
|
@@ -346,31 +414,39 @@ def classify_dtype(dtype: Any) -> str:
|
|
|
346
414
|
|
|
347
415
|
raise TypeError(
|
|
348
416
|
f'unsupported type annotation: {dtype}; '
|
|
349
|
-
'use `args`/`returns` on the @udf/@tvf
|
|
417
|
+
'use `args`/`returns` on the @udf/@tvf decorator to specify the data type',
|
|
350
418
|
)
|
|
351
419
|
|
|
352
420
|
|
|
353
|
-
def collapse_dtypes(dtypes: Union[str, List[str]]) -> str:
|
|
421
|
+
def collapse_dtypes(dtypes: Union[str, List[str]], include_null: bool = False) -> str:
|
|
354
422
|
"""
|
|
355
423
|
Collapse a dtype possibly containing multiple data types to one type.
|
|
356
424
|
|
|
425
|
+
This function can fail if there is no single type that naturally
|
|
426
|
+
encompasses all of the types in the list.
|
|
427
|
+
|
|
357
428
|
Parameters
|
|
358
429
|
----------
|
|
359
430
|
dtypes : str or list[str]
|
|
360
431
|
The data types to collapse
|
|
432
|
+
include_null : bool, optional
|
|
433
|
+
Whether to force include null types in the result
|
|
361
434
|
|
|
362
435
|
Returns
|
|
363
436
|
-------
|
|
364
437
|
str
|
|
365
438
|
|
|
366
439
|
"""
|
|
440
|
+
if isinstance(dtypes, str) and '|' in dtypes:
|
|
441
|
+
dtypes = dtypes.split('|')
|
|
442
|
+
|
|
367
443
|
if not isinstance(dtypes, list):
|
|
368
444
|
return dtypes
|
|
369
445
|
|
|
370
446
|
orig_dtypes = dtypes
|
|
371
447
|
dtypes = list(set(dtypes))
|
|
372
448
|
|
|
373
|
-
is_nullable = 'null' in dtypes
|
|
449
|
+
is_nullable = include_null or 'null' in dtypes
|
|
374
450
|
|
|
375
451
|
dtypes = [x for x in dtypes if x != 'null']
|
|
376
452
|
|
|
@@ -443,7 +519,602 @@ def collapse_dtypes(dtypes: Union[str, List[str]]) -> str:
|
|
|
443
519
|
return dtypes[0] + ('?' if is_nullable else '')
|
|
444
520
|
|
|
445
521
|
|
|
446
|
-
def
|
|
522
|
+
def get_dataclass_schema(
    obj: Any,
    include_default: bool = False,
) -> List[Union[Tuple[str, Any], Tuple[str, Any, Any]]]:
    """
    Get the schema of a dataclass.

    Parameters
    ----------
    obj : dataclass
        The dataclass to get the schema of
    include_default : bool, optional
        Whether to include the default value in the column specification

    Returns
    -------
    List[Tuple[str, Any]] | List[Tuple[str, Any, Any]]
        A list of tuples containing the field names and field types, plus
        the default value (or ``NO_DEFAULT``) when `include_default` is set.
        Field types are taken verbatim from ``dataclasses.fields``; they are
        strings when the dataclass was defined under postponed annotations.

    """
    dc_fields = dataclasses.fields(obj)

    if not include_default:
        return [(f.name, f.type) for f in dc_fields]

    return [(f.name, f.type, _dataclass_field_default(f)) for f in dc_fields]


def _dataclass_field_default(field: Any) -> Any:
    """Return the default of a dataclass field, or ``NO_DEFAULT`` if none."""
    if field.default is not dataclasses.MISSING:
        return field.default
    # Fields declared with ``field(default_factory=...)`` leave ``default``
    # as MISSING; call the factory, exactly as instantiation would
    if field.default_factory is not dataclasses.MISSING:
        return field.default_factory()
    return NO_DEFAULT
|
|
549
|
+
|
|
550
|
+
|
|
551
|
+
def get_typeddict_schema(
    obj: Any,
    include_default: bool = False,
) -> List[Union[Tuple[str, Any], Tuple[str, Any, Any]]]:
    """
    Get the schema of a TypedDict.

    Parameters
    ----------
    obj : TypedDict
        The TypedDict to get the schema of
    include_default : bool, optional
        Whether to include the default value in the column specification

    Returns
    -------
    List[Tuple[str, Any]] | List[Tuple[str, Any, Any]]
        A list of tuples containing the field names and field types, plus
        the default value (or ``NO_DEFAULT``) when `include_default` is set

    """
    annotations = utils.get_annotations(obj)

    if not include_default:
        return list(annotations.items())

    # Defaults can only come from values assigned in the class body, so look
    # them up in the class's own namespace.  ``getattr`` is not used because
    # TypedDict classes subclass ``dict``: a field named e.g. ``items`` or
    # ``keys`` would otherwise pick up the inherited dict method as a default.
    class_attrs = getattr(obj, '__dict__', {})
    return [
        (name, dtype, class_attrs.get(name, NO_DEFAULT))
        for name, dtype in annotations.items()
    ]
|
|
577
|
+
|
|
578
|
+
|
|
579
|
+
def get_pydantic_schema(
    obj: Any,
    include_default: bool = False,
) -> List[Union[Tuple[str, Any], Tuple[str, Any, Any]]]:
    """
    Get the schema of a pydantic model.

    Fields are read from the model class's ``model_fields`` mapping.

    Parameters
    ----------
    obj : pydantic.BaseModel
        The pydantic model to get the schema of
    include_default : bool, optional
        Whether to include the default value in the column specification

    Returns
    -------
    List[Tuple[str, Any]] | List[Tuple[str, Any, Any]]
        A list of tuples containing the field names and field annotations,
        plus the default value (or ``NO_DEFAULT``) when `include_default`
        is set

    """
    import pydantic_core

    model_fields = obj.model_fields

    if not include_default:
        return [(name, info.annotation) for name, info in model_fields.items()]

    # pydantic marks "no default" with its own PydanticUndefined sentinel;
    # translate that into this module's NO_DEFAULT sentinel
    undefined = pydantic_core.PydanticUndefined
    schema: List[Tuple[str, Any, Any]] = []
    for name, info in model_fields.items():
        default = NO_DEFAULT if info.default is undefined else info.default
        schema.append((name, info.annotation, default))
    return schema
|
|
609
|
+
|
|
610
|
+
|
|
611
|
+
def get_namedtuple_schema(
    obj: Any,
    include_default: bool = False,
) -> List[Union[Tuple[str, Any], Tuple[str, Any, Any]]]:
    """
    Get the schema of a named tuple.

    Parameters
    ----------
    obj : NamedTuple
        The named tuple to get the schema of
    include_default : bool, optional
        Whether to include the default value in the column specification

    Returns
    -------
    List[Tuple[str, Any]] | List[Tuple[str, Any, Any]]
        A list of ``(field name, field type)`` tuples, with the default
        value (or ``NO_DEFAULT``) appended when `include_default` is set

    """
    annotations = utils.get_annotations(obj)

    if not include_default:
        return list(annotations.items())

    # ``_field_defaults`` holds only the fields that declare a default
    return [
        (name, dtype, obj._field_defaults.get(name, NO_DEFAULT))
        for name, dtype in annotations.items()
    ]
|
|
640
|
+
|
|
641
|
+
|
|
642
|
+
def get_table_schema(
    obj: Any,
    include_default: bool = False,
) -> List[Union[Tuple[str, Any], Tuple[str, Any, Any]]]:
    """
    Get the schema of a Table.

    Parameters
    ----------
    obj : Table
        The Table to get the schema of
    include_default : bool, optional
        Whether to include the default value in the column specification

    Returns
    -------
    List[Tuple[str, Any]] | List[Tuple[str, Any, Any]]
        A list of ``(field name, field type)`` tuples, with the default
        value (or ``NO_DEFAULT``) appended when `include_default` is set

    """
    annotations = utils.get_annotations(obj)

    if not include_default:
        return list(annotations.items())

    # Defaults are looked up as attributes of the Table object
    return [
        (name, dtype, getattr(obj, name, NO_DEFAULT))
        for name, dtype in annotations.items()
    ]
|
|
668
|
+
|
|
669
|
+
|
|
670
|
+
def get_colspec(
    overrides: Any,
    include_default: bool = False,
) -> List[Union[Tuple[str, Any], Tuple[str, Any, Any]]]:
    """
    Build a column specification from a set of type overrides.

    `overrides` may be a dataclass, TypedDict, named tuple, or pydantic
    model (whose fields become the columns), a list of types, or a single
    type.  For lists and single types the column name is taken from the
    item's ``name`` attribute, or '' when it has none.

    Parameters
    ----------
    overrides : Any
        The overrides to get the column specification from
    include_default : bool, optional
        Whether to include the default value in the column specification

    Returns
    -------
    List[Tuple[str, Any]] | List[Tuple[str, Any, Any]]
        A list of tuples containing the field names and field types
        (empty when `overrides` is falsy)

    """
    if not overrides:
        return []

    # Structured types are checked in this order; the first match wins
    structured_extractors = (
        (utils.is_dataclass, get_dataclass_schema),
        (utils.is_typeddict, get_typeddict_schema),
        (utils.is_namedtuple, get_namedtuple_schema),
        (utils.is_pydantic, get_pydantic_schema),
    )
    for matches, extract in structured_extractors:
        if matches(overrides):
            return extract(overrides, include_default=include_default)

    # A list of types, or a single type treated as a one-item list
    items = overrides if isinstance(overrides, list) else [overrides]
    if include_default:
        return [(getattr(item, 'name', ''), item, NO_DEFAULT) for item in items]
    return [(getattr(item, 'name', ''), item) for item in items]
|
|
737
|
+
|
|
738
|
+
|
|
739
|
+
def unpack_masked_type(obj: Any) -> Any:
    """
    Strip a ``Masked[...]`` wrapper from a type annotation.

    Parameters
    ----------
    obj : Any
        The type annotation, possibly of the form ``Masked[T]``

    Returns
    -------
    Any
        ``T`` for ``Masked[T]``; otherwise `obj` unchanged

    """
    is_masked = typing.get_origin(obj) is Masked
    return typing.get_args(obj)[0] if is_masked else obj
|
|
757
|
+
|
|
758
|
+
|
|
759
|
+
def get_schema(
    spec: Any,
    overrides: Optional[Union[List[str], Type[Any]]] = None,
    mode: str = 'parameter',
) -> Tuple[List[Tuple[str, Any, Optional[str]]], str, str]:
    """
    Expand a parameter or return type annotation into a list of fields.

    Parameters
    ----------
    spec : Any
        The parameter or return type annotation
    overrides : List[str] or type, optional
        SQL type specifications, or a dataclass / TypedDict / NamedTuple /
        pydantic model describing them, which take precedence over the
        types found in `spec`
    mode : str, optional
        The mode of the function, either 'parameter' or 'return'.
        ``Table[...]`` annotations are only accepted in 'return' mode and
        mark the function as a TVF.

    Returns
    -------
    Tuple[List[Tuple[str, Any, Optional[str]]], str, str]
        A 3-tuple of:

        * a list of ``(name, normalized dtype, SQL type)`` tuples, where the
          SQL type is the original type when it was given as a SQL string,
          otherwise None
        * the data format, e.g. 'scalar', 'list', 'numpy', 'pandas',
          'polars' or 'arrow' ('' when `spec` is None)
        * the function type, 'udf' or 'tvf'

    Raises
    ------
    TypeError
        If the annotation is not supported in the given mode, a vector type
        has no data type specification, or tuple items mix data formats
    ValueError
        If the number of overrides does not match the number of tuple items

    """
    colspec = []
    data_format = ''
    function_type = 'udf'

    origin = typing.get_origin(spec)
    args = typing.get_args(spec)
    args_origins = [typing.get_origin(x) if x is not None else None for x in args]

    # Make sure that the result of a TVF is a list or dataframe
    if mode == 'return':

        # See if it's a Table subclass with annotations
        if inspect.isclass(origin) and origin is Table:

            function_type = 'tvf'

            if utils.is_dataframe(args[0]):
                # DataFrame annotations carry no column types, so they must
                # come from the decorator
                if not overrides:
                    raise TypeError(
                        'column types must be specified by the '
                        '`returns=` parameter of the @udf decorator',
                    )

                # NOTE(review): this yields 'pyarrow' while pyarrow arrays and
                # get_data_format() use 'arrow' -- confirm consumers accept both
                if utils.get_module(args[0]) in ['pandas', 'polars', 'pyarrow']:
                    data_format = utils.get_module(args[0])
                    spec = args[0]
                else:
                    raise TypeError(
                        'only pandas.DataFrames, polars.DataFrames, '
                        'and pyarrow.Tables are supported as tables.',
                    )

            elif typing.get_origin(args[0]) is list:
                if len(args) != 1:
                    raise TypeError(
                        'only one list is supported within a table; to '
                        'return multiple columns, use a tuple, NamedTuple, '
                        'dataclass, TypedDict, or pydantic model',
                    )
                # Table[List[X]]: each row is described by X
                spec = typing.get_args(args[0])[0]
                data_format = 'list'

            elif all(utils.is_vector(x, include_masks=True) for x in args):
                pass

            else:
                raise TypeError(
                    'return type for TVF must be a list, DataFrame / Table, '
                    'or tuple of vectors',
                )

        # Short circuit check for common valid types
        elif utils.is_vector(spec) or spec in [str, float, int, bytes]:
            pass

        # Try to catch some common mistakes
        elif _is_structured_annotation(spec, origin, args_origins):
            raise TypeError(
                'invalid return type for a UDF; '
                f'expecting a scalar or vector, but got {spec}',
            )

    # Short circuit check for common valid types
    elif utils.is_vector(spec) or spec in [str, float, int, bytes]:
        pass

    # Error out for incorrect parameter types
    elif _is_structured_annotation(spec, origin, args_origins):
        raise TypeError(f'parameter types must be scalar or vector, got {spec}')

    #
    # Process each parameter / return type into a colspec
    #

    # Compute overrides colspec from various formats
    overrides_colspec = get_colspec(overrides)

    # Dataframe type
    if utils.is_dataframe(spec):
        colspec = overrides_colspec

    # Numpy array types
    elif utils.is_numpy(spec):
        data_format = 'numpy'
        if overrides:
            colspec = overrides_colspec
        elif len(typing.get_args(spec)) < 2:
            raise TypeError(
                'numpy array must have a data type specified '
                'in the @udf decorator or with an NDArray type annotation',
            )
        else:
            # The element dtype is the second type argument of the array
            colspec = [('', typing.get_args(spec)[1])]

    # Pandas Series
    elif utils.is_pandas_series(spec):
        data_format = 'pandas'
        if not overrides:
            raise TypeError(
                'pandas Series must have a data type specified '
                'in the @udf decorator',
            )
        colspec = overrides_colspec

    # Polars Series
    elif utils.is_polars_series(spec):
        data_format = 'polars'
        if not overrides:
            raise TypeError(
                'polars Series must have a data type specified '
                'in the @udf decorator',
            )
        colspec = overrides_colspec

    # PyArrow Array
    elif utils.is_pyarrow_array(spec):
        data_format = 'arrow'
        if not overrides:
            raise TypeError(
                'pyarrow Arrays must have a data type specified '
                'in the @udf decorator',
            )
        colspec = overrides_colspec

    # Return type is specified by a dataclass definition
    elif utils.is_dataclass(spec):
        colspec = overrides_colspec or get_dataclass_schema(spec)

    # Return type is specified by a TypedDict definition
    elif utils.is_typeddict(spec):
        colspec = overrides_colspec or get_typeddict_schema(spec)

    # Return type is specified by a pydantic model
    elif utils.is_pydantic(spec):
        colspec = overrides_colspec or get_pydantic_schema(spec)

    # Return type is specified by a named tuple
    elif utils.is_namedtuple(spec):
        colspec = overrides_colspec or get_namedtuple_schema(spec)

    # Unrecognized return type
    elif spec is not None:

        # Return type is specified by a SQL string
        if isinstance(spec, str):
            data_format = 'scalar'
            colspec = [(getattr(spec, 'name', ''), spec)]

        # Plain list vector
        elif typing.get_origin(spec) is list:
            data_format = 'list'
            colspec = [('', typing.get_args(spec)[0])]

        # Multiple return values
        elif inspect.isclass(typing.get_origin(spec)) \
                and issubclass(typing.get_origin(spec), tuple):  # type: ignore[arg-type]

            item_specs = typing.get_args(spec)
            out_names: List[Any] = []
            out_overrides: List[Any] = []

            # Get the colspec for the overrides
            if overrides:
                out_colspec = list(get_colspec(overrides, include_default=True))
                out_names = [x[0] for x in out_colspec]
                out_overrides = [x[1] for x in out_colspec]

            # Make sure that the number of overrides matches the number of
            # return types or parameter types
            if out_overrides and len(item_specs) != len(out_overrides):
                raise ValueError(
                    f'number of {mode} types does not match the number of '
                    'overrides specified',
                )

            colspec = []
            out_data_formats = []

            # Get the colspec for each item in the tuple
            for i, x in enumerate(item_specs):
                # Each item is checked in the same mode as the enclosing tuple
                out_item, out_data_format, _ = get_schema(
                    unpack_masked_type(x),
                    overrides=out_overrides[i] if out_overrides else [],
                    mode=mode,
                )

                # Name precedence: the item's own name, then the override's
                # name, then a positional letter (a, b, c, ...)
                if out_names and out_names[i] and not out_item[0][0]:
                    out_item = [(out_names[i], *out_item[0][1:])]
                elif not out_item[0][0]:
                    out_item = [(f'{string.ascii_letters[i]}', *out_item[0][1:])]

                colspec += out_item
                out_data_formats.append(out_data_format)

            # Make sure that all the data formats are the same
            if len(set(out_data_formats)) > 1:
                raise TypeError(
                    'data formats must be all be the same vector / scalar type: '
                    f'{", ".join(out_data_formats)}',
                )

            # Rows of Table[List[Tuple[...]]] keep the 'list' format even
            # though each tuple item is a scalar
            if data_format != 'list' and out_data_formats:
                data_format = out_data_formats[0]

            # Since the colspec was computed by get_schema already, don't go
            # through the process of normalizing the dtypes again
            return colspec, data_format, function_type  # type: ignore

        # Use overrides if specified
        elif overrides:
            if not data_format:
                data_format = get_data_format(spec)
            colspec = overrides_colspec

        # Single value, no override
        else:
            if not data_format:
                data_format = 'scalar'
            colspec = [('', spec)]

    out = []

    # Normalize colspec data types
    for k, v, *_ in colspec:
        out.append((
            k,
            collapse_dtypes(
                [normalize_dtype(x) for x in simplify_dtype(v)],
            ),
            v if isinstance(v, str) else None,
        ))

    return out, data_format, function_type


def _is_structured_annotation(
    spec: Any,
    origin: Any,
    args_origins: List[Any],
) -> bool:
    """Return True if `spec` describes a record / table rather than a scalar or vector."""
    if origin in [tuple, dict] or tuple in args_origins:
        return True
    return bool(
        inspect.isclass(spec) and
        (
            utils.is_dataframe(spec)
            or utils.is_dataclass(spec)
            or utils.is_typeddict(spec)
            or utils.is_pydantic(spec)
            or utils.is_namedtuple(spec)
        ),
    )
|
|
1039
|
+
|
|
1040
|
+
|
|
1041
|
+
def vector_check(obj: Any) -> Tuple[Any, str]:
    """
    Check if the object is a vector type.

    Parameters
    ----------
    obj : Any
        The type annotation to check

    Returns
    -------
    Tuple[Any, str]
        The element type and the data format.  The data format is one of
        'numpy', 'pandas', 'polars', 'arrow', 'list', or 'scalar'.

        * numpy / pandas: the second type argument of the annotation, or
          None when the annotation has fewer than two type arguments
        * polars / arrow: always None
        * list: the list's type argument, or None for a bare ``list``
        * scalar: `obj` itself

    """
    if utils.is_numpy(obj):
        if len(typing.get_args(obj)) < 2:
            return None, 'numpy'
        return typing.get_args(obj)[1], 'numpy'
    if utils.is_pandas_series(obj):
        if len(typing.get_args(obj)) < 2:
            return None, 'pandas'
        return typing.get_args(obj)[1], 'pandas'
    if utils.is_polars_series(obj):
        return None, 'polars'
    if utils.is_pyarrow_array(obj):
        return None, 'arrow'
    if obj is list or typing.get_origin(obj) is list:
        if len(typing.get_args(obj)) < 1:
            return None, 'list'
        return typing.get_args(obj)[0], 'list'
    return obj, 'scalar'
|
|
1074
|
+
|
|
1075
|
+
|
|
1076
|
+
def get_masks(func: Callable[..., Any]) -> Tuple[List[bool], List[bool]]:
    """
    Report which parameters and return values of `func` are ``Masked``.

    Parameters
    ----------
    func : Callable
        The function to call as the endpoint

    Returns
    -------
    Tuple[List[bool], List[bool]]
        One flag per parameter telling whether it is annotated as
        ``Masked[...]``; and the return-value flags: ``[True]`` for a
        ``Masked[...]`` return, one flag per ``Table[...]`` item when at
        least one item is masked, otherwise an empty list

    """
    sig = inspect.signature(func)
    ret_annotation = sig.return_annotation

    param_masks = [
        typing.get_origin(p.annotation) is Masked for p in sig.parameters.values()
    ]

    ret_origin = typing.get_origin(ret_annotation)
    if ret_origin is Masked:
        ret_masks = [True]
    elif ret_origin is Table:
        ret_masks = [
            typing.get_origin(item) is Masked
            for item in typing.get_args(ret_annotation)
        ]
        # An all-False list is reported as "no masks"
        if not any(ret_masks):
            ret_masks = []
    else:
        ret_masks = []

    return param_masks, ret_masks
|
|
1112
|
+
|
|
1113
|
+
|
|
1114
|
+
def get_signature(
    func: Callable[..., Any],
    func_name: Optional[str] = None,
) -> Dict[str, Any]:
    '''
    Build the UDF signature description of the Python callable.

    Parameters
    ----------
    func : Callable
        The function to extract the signature of
    func_name : str, optional
        Name override for function. A ``name`` stored in
        ``func._singlestoredb_attrs`` takes precedence over it, and
        ``func.__name__`` is used when neither is given.

    Returns
    -------
    Dict[str, Any]
        Signature description with the following keys:

        - ``name``, ``args`` and ``returns``.
        - ``args_data_format`` and ``returns_data_format``.
        - ``function_type``, ``endpoint`` and ``doc``.

        Each ``args`` entry has ``name``, ``dtype`` and ``sql``, plus
        ``default`` when a default applies. Each ``returns`` entry has
        ``name``, ``dtype`` and ``sql``.

    Raises
    ------
    TypeError
        If the function accepts ``*args`` / ``**kwargs``, or if its
        parameters do not all share the same data format
    ValueError
        If the decorator declares a different number of arguments than
        the function signature has

    '''
    signature = inspect.signature(func)
    args: List[Dict[str, Any]] = []
    returns: List[Dict[str, Any]] = []

    attrs = getattr(func, '_singlestoredb_attrs', {})
    name = attrs.get('name', func_name if func_name else func.__name__)

    out: Dict[str, Any] = dict(name=name, args=args, returns=returns)

    # Do not allow variable positional or keyword arguments
    for p in signature.parameters.values():
        if p.kind == inspect.Parameter.VAR_POSITIONAL:
            raise TypeError('variable positional arguments are not supported')
        elif p.kind == inspect.Parameter.VAR_KEYWORD:
            raise TypeError('variable keyword arguments are not supported')

    # TODO: Use typing.get_type_hints() for parameters / return values?

    # Type overrides, defaults, and NULL masks declared for the parameters
    args_schema = []
    args_data_formats = []
    args_colspec = list(get_colspec(attrs.get('args', []), include_default=True))
    args_overrides = [x[1] for x in args_colspec]
    args_defaults = [x[2] for x in args_colspec]  # type: ignore
    args_masks, ret_masks = get_masks(func)

    if args_overrides and len(args_overrides) != len(signature.parameters):
        raise ValueError(
            'number of args in the decorator does not match '
            'the number of parameters in the function signature',
        )

    params = list(signature.parameters.values())

    # Get the colspec for each parameter
    for i, param in enumerate(params):
        arg_schema, args_data_format, _ = get_schema(
            unpack_masked_type(param.annotation),
            overrides=args_overrides[i] if args_overrides else [],
            mode='parameter',
        )
        args_data_formats.append(args_data_format)

        # Every parameter must contribute exactly one entry so that the
        # indexes below stay aligned with params / args_defaults /
        # args_masks. Only the first schema column is used, and the Python
        # parameter name is inserted when the schema leaves it unnamed.
        arg_name, *arg_rest = arg_schema[0]
        args_schema.append((arg_name or param.name, *arg_rest))

    for i, (arg_name, atype, sql) in enumerate(args_schema):
        default_option: Dict[str, Any] = {}

        # When the decorator declares ``args``, only its defaults are used;
        # otherwise fall back to the defaults in the Python signature.
        if args_defaults:
            if args_defaults[i] is not NO_DEFAULT:
                default_option['default'] = args_defaults[i]
        elif params[i].default is not inspect.Parameter.empty:
            default_option['default'] = params[i].default

        # Generate SQL code for the parameter unless an explicit SQL type
        # was supplied
        sql = sql or dtype_to_sql(
            atype, force_nullable=args_masks[i], **default_option,
        )

        # Add parameter to args definitions
        args.append(dict(name=arg_name, dtype=atype, sql=sql, **default_option))

    # Check that all the data formats are all the same
    if len(set(args_data_formats)) > 1:
        raise TypeError(
            'input data formats must all be the same: '
            f'{", ".join(args_data_formats)}',
        )

    out['args_data_format'] = args_data_formats[0] if args_data_formats else 'scalar'

    # Generate the return types and the corresponding SQL code for those values
    ret_schema, out['returns_data_format'], function_type = get_schema(
        unpack_masked_type(signature.return_annotation),
        overrides=attrs.get('returns', None),
        mode='return',
    )

    out['returns_data_format'] = out['returns_data_format'] or 'scalar'
    out['function_type'] = function_type

    # All functions have to return a value, so if none was specified try to
    # insert a reasonable default that includes NULLs.
    if not ret_schema:
        ret_schema = [('', 'int8?', 'TINYINT NULL')]

    # Table-valued functions and multi-column results need named fields;
    # unnamed ones get single-letter names by position.
    if function_type == 'tvf' or len(ret_schema) > 1:
        for i, (ret_name, rtype, ret_sql) in enumerate(ret_schema):
            if not ret_name:
                ret_schema[i] = (string.ascii_letters[i], rtype, ret_sql)

    # Generate SQL code for the return values
    for i, (ret_name, rtype, ret_sql) in enumerate(ret_schema):
        ret_sql = ret_sql or dtype_to_sql(
            rtype,
            # Columns beyond the mask list (or with no masks) are not
            # forced to be nullable
            force_nullable=ret_masks[i] if i < len(ret_masks) else False,
            function_type=function_type,
        )
        returns.append(dict(name=ret_name, dtype=rtype, sql=ret_sql))

    # Set the function endpoint
    out['endpoint'] = '/invoke'

    # Set the function doc string
    out['doc'] = func.__doc__

    return out
|
|
@@ -641,9 +1287,10 @@ def sql_to_dtype(sql: str) -> str:
|
|
|
641
1287
|
|
|
642
1288
|
def dtype_to_sql(
|
|
643
1289
|
dtype: str,
|
|
644
|
-
default: Any =
|
|
1290
|
+
default: Any = NO_DEFAULT,
|
|
645
1291
|
field_names: Optional[List[str]] = None,
|
|
646
1292
|
function_type: str = 'udf',
|
|
1293
|
+
force_nullable: bool = False,
|
|
647
1294
|
) -> str:
|
|
648
1295
|
"""
|
|
649
1296
|
Convert a collapsed dtype string to a SQL type.
|
|
@@ -656,6 +1303,10 @@ def dtype_to_sql(
|
|
|
656
1303
|
Default value
|
|
657
1304
|
field_names : List[str], optional
|
|
658
1305
|
Field names for tuple types
|
|
1306
|
+
function_type : str, optional
|
|
1307
|
+
Function type, either 'udf' or 'tvf'
|
|
1308
|
+
force_nullable : bool, optional
|
|
1309
|
+
Whether to force the type to be nullable
|
|
659
1310
|
|
|
660
1311
|
Returns
|
|
661
1312
|
-------
|
|
@@ -666,12 +1317,17 @@ def dtype_to_sql(
|
|
|
666
1317
|
if dtype.endswith('?'):
|
|
667
1318
|
nullable = ' NULL'
|
|
668
1319
|
dtype = dtype[:-1]
|
|
1320
|
+
elif '|null' in dtype:
|
|
1321
|
+
nullable = ' NULL'
|
|
1322
|
+
dtype = dtype.replace('|null', '')
|
|
1323
|
+
elif force_nullable:
|
|
1324
|
+
nullable = ' NULL'
|
|
669
1325
|
|
|
670
1326
|
if dtype == 'null':
|
|
671
1327
|
nullable = ''
|
|
672
1328
|
|
|
673
1329
|
default_clause = ''
|
|
674
|
-
if default is not
|
|
1330
|
+
if default is not NO_DEFAULT:
|
|
675
1331
|
if default is dt.NULL:
|
|
676
1332
|
default = None
|
|
677
1333
|
default_clause = f' DEFAULT {escape_item(default, "utf8")}'
|
|
@@ -729,6 +1385,8 @@ def signature_to_sql(
|
|
|
729
1385
|
str : SQL formatted function signature
|
|
730
1386
|
|
|
731
1387
|
'''
|
|
1388
|
+
function_type = signature.get('function_type') or 'udf'
|
|
1389
|
+
|
|
732
1390
|
args = []
|
|
733
1391
|
for arg in signature['args']:
|
|
734
1392
|
# Use default value from Python function if SQL doesn't set one
|
|
@@ -741,8 +1399,22 @@ def signature_to_sql(
|
|
|
741
1399
|
|
|
742
1400
|
returns = ''
|
|
743
1401
|
if signature.get('returns'):
|
|
744
|
-
|
|
1402
|
+
ret = signature['returns']
|
|
1403
|
+
if function_type == 'tvf':
|
|
1404
|
+
res = 'TABLE(' + ', '.join(
|
|
1405
|
+
f'{escape_name(x["name"])} {x["sql"]}' for x in ret
|
|
1406
|
+
) + ')'
|
|
1407
|
+
elif ret[0]['name'] and len(ret) > 1:
|
|
1408
|
+
res = 'RECORD(' + ', '.join(
|
|
1409
|
+
f'{escape_name(x["name"])} {x["sql"]}' for x in ret
|
|
1410
|
+
) + ')'
|
|
1411
|
+
else:
|
|
1412
|
+
res = ret[0]['sql']
|
|
745
1413
|
returns = f' RETURNS {res}'
|
|
1414
|
+
else:
|
|
1415
|
+
raise ValueError(
|
|
1416
|
+
'function signature must have a return type specified',
|
|
1417
|
+
)
|
|
746
1418
|
|
|
747
1419
|
host = os.environ.get('SINGLESTOREDB_EXT_HOST', '127.0.0.1')
|
|
748
1420
|
port = os.environ.get('SINGLESTOREDB_EXT_PORT', '8000')
|