singlestoredb 1.12.4__py3-none-any.whl → 1.13.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


Files changed (29)
  1. singlestoredb/__init__.py +1 -1
  2. singlestoredb/apps/__init__.py +1 -0
  3. singlestoredb/apps/_config.py +6 -0
  4. singlestoredb/apps/_connection_info.py +8 -0
  5. singlestoredb/apps/_python_udfs.py +85 -0
  6. singlestoredb/config.py +14 -2
  7. singlestoredb/functions/__init__.py +11 -1
  8. singlestoredb/functions/decorator.py +102 -252
  9. singlestoredb/functions/dtypes.py +545 -198
  10. singlestoredb/functions/ext/asgi.py +288 -90
  11. singlestoredb/functions/ext/json.py +29 -36
  12. singlestoredb/functions/ext/mmap.py +1 -1
  13. singlestoredb/functions/ext/rowdat_1.py +50 -70
  14. singlestoredb/functions/signature.py +816 -144
  15. singlestoredb/functions/typing.py +41 -0
  16. singlestoredb/functions/utils.py +342 -0
  17. singlestoredb/http/connection.py +3 -1
  18. singlestoredb/management/manager.py +6 -1
  19. singlestoredb/management/utils.py +2 -2
  20. singlestoredb/tests/ext_funcs/__init__.py +476 -237
  21. singlestoredb/tests/test_ext_func.py +192 -3
  22. singlestoredb/tests/test_udf.py +101 -131
  23. singlestoredb/tests/test_udf_returns.py +459 -0
  24. {singlestoredb-1.12.4.dist-info → singlestoredb-1.13.0.dist-info}/METADATA +2 -1
  25. {singlestoredb-1.12.4.dist-info → singlestoredb-1.13.0.dist-info}/RECORD +29 -25
  26. {singlestoredb-1.12.4.dist-info → singlestoredb-1.13.0.dist-info}/LICENSE +0 -0
  27. {singlestoredb-1.12.4.dist-info → singlestoredb-1.13.0.dist-info}/WHEEL +0 -0
  28. {singlestoredb-1.12.4.dist-info → singlestoredb-1.13.0.dist-info}/entry_points.txt +0 -0
  29. {singlestoredb-1.12.4.dist-info → singlestoredb-1.13.0.dist-info}/top_level.txt +0 -0
@@ -16,6 +16,7 @@ from typing import List
 from typing import Optional
 from typing import Sequence
 from typing import Tuple
+from typing import Type
 from typing import TypeVar
 from typing import Union
 
@@ -25,13 +26,11 @@ try:
 except ImportError:
     has_numpy = False
 
-try:
-    import pydantic
-    has_pydantic = True
-except ImportError:
-    has_pydantic = False
 
 from . import dtypes as dt
+from . import utils
+from .typing import Table
+from .typing import Masked
 from ..mysql.converters import escape_item  # type: ignore
 
 if sys.version_info >= (3, 10):
@@ -40,6 +39,18 @@ else:
     _UNION_TYPES = {typing.Union}
 
 
+def is_union(x: Any) -> bool:
+    """Check if the object is a Union."""
+    return typing.get_origin(x) in _UNION_TYPES
+
+
+class NoDefaultType:
+    pass
+
+
+NO_DEFAULT = NoDefaultType()
+
+
 array_types: Tuple[Any, ...]
 
 if has_numpy:
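
The new NO_DEFAULT sentinel and is_union() helper are small but load-bearing: the sentinel distinguishes "no default given" from an explicit None default, and is_union() centralizes Union detection. A minimal standalone sketch (the 3.10+ branch of _UNION_TYPES is assumed to include types.UnionType; only the else branch appears in this diff):

    import sys
    import types
    import typing
    from typing import Any, Optional, Union

    if sys.version_info >= (3, 10):
        _UNION_TYPES = {typing.Union, types.UnionType}  # assumed 3.10+ branch
    else:
        _UNION_TYPES = {typing.Union}

    def is_union(x: Any) -> bool:
        """Check if the object is a Union."""
        return typing.get_origin(x) in _UNION_TYPES

    class NoDefaultType:
        pass

    NO_DEFAULT = NoDefaultType()

    assert is_union(Optional[int])      # Optional[int] is Union[int, None]
    assert is_union(Union[int, str])
    assert not is_union(int)

    # The sentinel separates "no default" from an explicit None default
    def describe(default: Any = NO_DEFAULT) -> str:
        return 'no default' if default is NO_DEFAULT else f'default={default!r}'

    assert describe() == 'no default'
    assert describe(None) == 'default=None'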
@@ -192,6 +203,23 @@ class ArrayCollection(Collection):
     pass
 
 
+def get_data_format(obj: Any) -> str:
+    """Return the data format of the DataFrame / Table / vector."""
+    # Cheating here a bit so we don't have to import pandas / polars / pyarrow
+    # unless we absolutely need to
+    if getattr(obj, '__module__', '').startswith('pandas.'):
+        return 'pandas'
+    if getattr(obj, '__module__', '').startswith('polars.'):
+        return 'polars'
+    if getattr(obj, '__module__', '').startswith('pyarrow.'):
+        return 'arrow'
+    if getattr(obj, '__module__', '').startswith('numpy.'):
+        return 'numpy'
+    if isinstance(obj, list):
+        return 'list'
+    return 'scalar'
+
+
 def escape_name(name: str) -> str:
     """Escape a function parameter name."""
     if '`' in name:
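
get_data_format() avoids importing pandas, polars, or pyarrow just to recognize their objects: it sniffs __module__, which resolves through an instance to its class's module. A sketch of the same duck-typing check, runnable without any of those packages installed:

    from typing import Any

    def get_data_format(obj: Any) -> str:
        # __module__ resolves to the class's module for instances,
        # e.g. 'pandas.core.frame' for a DataFrame
        mod = getattr(obj, '__module__', '')
        if mod.startswith('pandas.'):
            return 'pandas'
        if mod.startswith('polars.'):
            return 'polars'
        if mod.startswith('pyarrow.'):
            return 'arrow'
        if mod.startswith('numpy.'):
            return 'numpy'
        if isinstance(obj, list):
            return 'list'
        return 'scalar'

    assert get_data_format([1, 2, 3]) == 'list'
    assert get_data_format(42) == 'scalar'
    # With pandas installed: get_data_format(pd.DataFrame()) == 'pandas',
    # since type(pd.DataFrame()).__module__ is 'pandas.core.frame'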
@@ -203,6 +231,12 @@ def simplify_dtype(dtype: Any) -> List[Any]:
     """
     Expand a type annotation to a flattened list of atomic types.
 
+    This function will attempt to find the underlying type of a
+    type annotation. For example, a Union of types will be flattened
+    to a list of types. A Tuple or Array type will be expanded to
+    a list of types. A TypeVar will be expanded to a list of
+    constraints and bounds.
+
     Parameters
     ----------
     dtype : Any
@@ -210,7 +244,8 @@ def simplify_dtype(dtype: Any) -> List[Any]:
210
244
 
211
245
  Returns
212
246
  -------
213
- List[Any] -- list of dtype strings, TupleCollections, and ArrayCollections
247
+ List[Any]
248
+ list of dtype strings, TupleCollections, and ArrayCollections
214
249
 
215
250
  """
216
251
  origin = typing.get_origin(dtype)
@@ -218,7 +253,7 @@ def simplify_dtype(dtype: Any) -> List[Any]:
     args = []
 
     # Flatten Unions
-    if origin in _UNION_TYPES:
+    if is_union(dtype):
         for x in typing.get_args(dtype):
             args.extend(simplify_dtype(x))
 
@@ -230,7 +265,7 @@ def simplify_dtype(dtype: Any) -> List[Any]:
         args.extend(simplify_dtype(dtype.__bound__))
 
     # Sequence types
-    elif origin is not None and issubclass(origin, Sequence):
+    elif origin is not None and inspect.isclass(origin) and issubclass(origin, Sequence):
         item_args: List[Union[List[type], type]] = []
         for x in typing.get_args(dtype):
             item_dtype = simplify_dtype(x)
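
The new inspect.isclass(origin) guard matters because typing.get_origin() can return non-class objects (e.g. for Literal or other special forms), which would make issubclass() raise. The Union-flattening step this function builds on can be sketched in isolation (a simplified sketch handling only typing.Union, not the 3.10+ X | Y form):

    import typing
    from typing import List, Optional, Union

    def flatten_union(dtype: object) -> List[object]:
        # Nested Optional/Union annotations reduce to a flat list of
        # atomic types; typing already collapses nested Unions
        if typing.get_origin(dtype) is typing.Union:
            out: List[object] = []
            for arg in typing.get_args(dtype):
                out.extend(flatten_union(arg))
            return out
        return [dtype]

    assert flatten_union(Optional[Union[int, float]]) == [int, float, type(None)]
    assert flatten_union(int) == [int]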
@@ -252,14 +287,31 @@ def simplify_dtype(dtype: Any) -> List[Any]:
     return args
 
 
-def classify_dtype(dtype: Any) -> str:
-    """Classify the type annotation into a type name."""
+def normalize_dtype(dtype: Any) -> str:
+    """
+    Normalize the type annotation into a type name.
+
+    Parameters
+    ----------
+    dtype : Any
+        Type annotation, list of type annotations, or a string
+        containing a SQL type name
+
+    Returns
+    -------
+    str
+        Normalized type name
+
+    """
     if isinstance(dtype, list):
-        return '|'.join(classify_dtype(x) for x in dtype)
+        return '|'.join(normalize_dtype(x) for x in dtype)
 
     if isinstance(dtype, str):
         return sql_to_dtype(dtype)
 
+    if typing.get_origin(dtype) is np.dtype:
+        dtype = typing.get_args(dtype)[0]
+
     # Specific types
     if dtype is None or dtype is type(None):  # noqa: E721
         return 'null'
@@ -270,45 +322,61 @@ def classify_dtype(dtype: Any) -> str:
     if dtype is bool:
         return 'bool'
 
-    if dataclasses.is_dataclass(dtype):
-        fields = dataclasses.fields(dtype)
+    if utils.is_dataclass(dtype):
+        dc_fields = dataclasses.fields(dtype)
+        item_dtypes = ','.join(
+            f'{normalize_dtype(simplify_dtype(x.type))}' for x in dc_fields
+        )
+        return f'tuple[{item_dtypes}]'
+
+    if utils.is_typeddict(dtype):
+        td_fields = utils.get_annotations(dtype).keys()
+        item_dtypes = ','.join(
+            f'{normalize_dtype(simplify_dtype(dtype[x]))}' for x in td_fields
+        )
+        return f'tuple[{item_dtypes}]'
+
+    if utils.is_pydantic(dtype):
+        pyd_fields = dtype.model_fields.values()
         item_dtypes = ','.join(
-            f'{classify_dtype(simplify_dtype(x.type))}' for x in fields
+            f'{normalize_dtype(simplify_dtype(x.annotation))}'  # type: ignore
+            for x in pyd_fields
         )
         return f'tuple[{item_dtypes}]'
 
-    if has_pydantic and inspect.isclass(dtype) and issubclass(dtype, pydantic.BaseModel):
-        fields = dtype.model_fields.values()
+    if utils.is_namedtuple(dtype):
+        nt_fields = utils.get_annotations(dtype).values()
         item_dtypes = ','.join(
-            f'{classify_dtype(simplify_dtype(x.annotation))}'  # type: ignore
-            for x in fields
+            f'{normalize_dtype(simplify_dtype(dtype[x]))}' for x in nt_fields
         )
         return f'tuple[{item_dtypes}]'
 
     if not inspect.isclass(dtype):
+
         # Check for compound types
         origin = typing.get_origin(dtype)
         if origin is not None:
+
             # Tuple type
             if origin is Tuple:
                 args = typing.get_args(dtype)
-                item_dtypes = ','.join(classify_dtype(x) for x in args)
+                item_dtypes = ','.join(normalize_dtype(x) for x in args)
                 return f'tuple[{item_dtypes}]'
 
             # Array types
-            elif issubclass(origin, array_types):
+            elif inspect.isclass(origin) and issubclass(origin, array_types):
                 args = typing.get_args(dtype)
-                item_dtype = classify_dtype(args[0])
+                item_dtype = normalize_dtype(args[0])
                 return f'array[{item_dtype}]'
 
         raise TypeError(f'unsupported type annotation: {dtype}')
 
     if isinstance(dtype, ArrayCollection):
-        item_dtypes = ','.join(classify_dtype(x) for x in dtype.item_dtypes)
+        item_dtypes = ','.join(normalize_dtype(x) for x in dtype.item_dtypes)
         return f'array[{item_dtypes}]'
 
     if isinstance(dtype, TupleCollection):
-        item_dtypes = ','.join(classify_dtype(x) for x in dtype.item_dtypes)
+        item_dtypes = ','.join(normalize_dtype(x) for x in dtype.item_dtypes)
         return f'tuple[{item_dtypes}]'
 
     # Check numpy types if it's available
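
Under the new utils-based checks, every struct-like annotation (dataclass, TypedDict, pydantic model, NamedTuple) normalizes to the same tuple[...] shape. An illustration of the intended mapping (the exact scalar spellings come from this module's dtype table and are assumed here):

    import dataclasses
    from typing import NamedTuple, Optional

    @dataclasses.dataclass
    class Point:
        x: int
        y: float

    class Pair(NamedTuple):
        key: str
        value: Optional[int]

    # Per the branches above, both reduce to tuple[...] type names,
    # conceptually (scalar names assumed):
    #   normalize_dtype(simplify_dtype(Point)) -> 'tuple[int64,float64]'
    #   normalize_dtype(simplify_dtype(Pair))  -> 'tuple[str,int64|null]'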
@@ -346,31 +414,39 @@ def classify_dtype(dtype: Any) -> str:
 
     raise TypeError(
         f'unsupported type annotation: {dtype}; '
-        'use `args`/`returns` on the @udf/@tvf decotator to specify the data type',
+        'use `args`/`returns` on the @udf/@tvf decorator to specify the data type',
     )
 
 
-def collapse_dtypes(dtypes: Union[str, List[str]]) -> str:
+def collapse_dtypes(dtypes: Union[str, List[str]], include_null: bool = False) -> str:
     """
     Collapse a dtype possibly containing multiple data types to one type.
 
+    This function can fail if there is no single type that naturally
+    encompasses all of the types in the list.
+
     Parameters
    ----------
     dtypes : str or list[str]
         The data types to collapse
+    include_null : bool, optional
+        Whether to force include null types in the result
 
     Returns
     -------
     str
 
     """
+    if isinstance(dtypes, str) and '|' in dtypes:
+        dtypes = dtypes.split('|')
+
     if not isinstance(dtypes, list):
         return dtypes
 
     orig_dtypes = dtypes
     dtypes = list(set(dtypes))
 
-    is_nullable = 'null' in dtypes
+    is_nullable = include_null or 'null' in dtypes
 
     dtypes = [x for x in dtypes if x != 'null']
 
@@ -443,7 +519,602 @@ def collapse_dtypes(dtypes: Union[str, List[str]]) -> str:
     return dtypes[0] + ('?' if is_nullable else '')
 
 
-def get_signature(func: Callable[..., Any], name: Optional[str] = None) -> Dict[str, Any]:
+def get_dataclass_schema(
+    obj: Any,
+    include_default: bool = False,
+) -> List[Union[Tuple[str, Any], Tuple[str, Any, Any]]]:
+    """
+    Get the schema of a dataclass.
+
+    Parameters
+    ----------
+    obj : dataclass
+        The dataclass to get the schema of
+    include_default : bool, optional
+        Whether to include the default value in the column specification
+
+    Returns
+    -------
+    List[Tuple[str, Any]] | List[Tuple[str, Any, Any]]
+        A list of tuples containing the field names and field types
+
+    """
+    if include_default:
+        return [
+            (
+                f.name, f.type,
+                NO_DEFAULT if f.default is dataclasses.MISSING else f.default,
+            )
+            for f in dataclasses.fields(obj)
+        ]
+    return [(f.name, f.type) for f in dataclasses.fields(obj)]
+
+
+def get_typeddict_schema(
+    obj: Any,
+    include_default: bool = False,
+) -> List[Union[Tuple[str, Any], Tuple[str, Any, Any]]]:
+    """
+    Get the schema of a TypedDict.
+
+    Parameters
+    ----------
+    obj : TypedDict
+        The TypedDict to get the schema of
+    include_default : bool, optional
+        Whether to include the default value in the column specification
+
+    Returns
+    -------
+    List[Tuple[str, Any]] | List[Tuple[str, Any, Any]]
+        A list of tuples containing the field names and field types
+
+    """
+    if include_default:
+        return [
+            (k, v, getattr(obj, k, NO_DEFAULT))
+            for k, v in utils.get_annotations(obj).items()
+        ]
+    return list(utils.get_annotations(obj).items())
+
+
+def get_pydantic_schema(
+    obj: Any,
+    include_default: bool = False,
+) -> List[Union[Tuple[str, Any], Tuple[str, Any, Any]]]:
+    """
+    Get the schema of a pydantic model.
+
+    Parameters
+    ----------
+    obj : pydantic.BaseModel
+        The pydantic model to get the schema of
+    include_default : bool, optional
+        Whether to include the default value in the column specification
+
+    Returns
+    -------
+    List[Tuple[str, Any]] | List[Tuple[str, Any, Any]]
+        A list of tuples containing the field names and field types
+
+    """
+    import pydantic_core
+    if include_default:
+        return [
+            (
+                k, v.annotation,
+                NO_DEFAULT if v.default is pydantic_core.PydanticUndefined else v.default,
+            )
+            for k, v in obj.model_fields.items()
+        ]
+    return [(k, v.annotation) for k, v in obj.model_fields.items()]
+
+
+def get_namedtuple_schema(
+    obj: Any,
+    include_default: bool = False,
+) -> List[Union[Tuple[Any, str], Tuple[Any, str, Any]]]:
+    """
+    Get the schema of a named tuple.
+
+    Parameters
+    ----------
+    obj : NamedTuple
+        The named tuple to get the schema of
+    include_default : bool, optional
+        Whether to include the default value in the column specification
+
+    Returns
+    -------
+    List[Tuple[Any, str]] | List[Tuple[Any, str, Any]]
+        A list of tuples containing the field names and field types
+
+    """
+    if include_default:
+        return [
+            (
+                k, v,
+                obj._field_defaults.get(k, NO_DEFAULT),
+            )
+            for k, v in utils.get_annotations(obj).items()
+        ]
+    return list(utils.get_annotations(obj).items())
+
+
+def get_table_schema(
+    obj: Any,
+    include_default: bool = False,
+) -> List[Union[Tuple[Any, str], Tuple[Any, str, Any]]]:
+    """
+    Get the schema of a Table.
+
+    Parameters
+    ----------
+    obj : Table
+        The Table to get the schema of
+    include_default : bool, optional
+        Whether to include the default value in the column specification
+
+    Returns
+    -------
+    List[Tuple[Any, str]] | List[Tuple[Any, str, Any]]
+        A list of tuples containing the field names and field types
+
+    """
+    if include_default:
+        return [
+            (k, v, getattr(obj, k, NO_DEFAULT))
+            for k, v in utils.get_annotations(obj).items()
+        ]
+    return list(utils.get_annotations(obj).items())
+
+
+def get_colspec(
+    overrides: Any,
+    include_default: bool = False,
+) -> List[Union[Tuple[str, Any], Tuple[str, Any, Any]]]:
+    """
+    Get the column specification from the overrides.
+
+    Parameters
+    ----------
+    overrides : Any
+        The overrides to get the column specification from
+    include_default : bool, optional
+        Whether to include the default value in the column specification
+
+    Returns
+    -------
+    List[Tuple[str, Any]] | List[Tuple[str, Any, Any]]
+        A list of tuples containing the field names and field types
+
+    """
+    overrides_colspec = []
+
+    if overrides:
+
+        # Dataclass
+        if utils.is_dataclass(overrides):
+            overrides_colspec = get_dataclass_schema(
+                overrides, include_default=include_default,
+            )
+
+        # TypedDict
+        elif utils.is_typeddict(overrides):
+            overrides_colspec = get_typeddict_schema(
+                overrides, include_default=include_default,
+            )
+
+        # Named tuple
+        elif utils.is_namedtuple(overrides):
+            overrides_colspec = get_namedtuple_schema(
+                overrides, include_default=include_default,
+            )
+
+        # Pydantic model
+        elif utils.is_pydantic(overrides):
+            overrides_colspec = get_pydantic_schema(
+                overrides, include_default=include_default,
+            )
+
+        # List of types
+        elif isinstance(overrides, list):
+            if include_default:
+                overrides_colspec = [
+                    (getattr(x, 'name', ''), x, NO_DEFAULT) for x in overrides
+                ]
+            else:
+                overrides_colspec = [(getattr(x, 'name', ''), x) for x in overrides]
+
+        # Other
+        else:
+            if include_default:
+                overrides_colspec = [
+                    (getattr(overrides, 'name', ''), overrides, NO_DEFAULT),
+                ]
+            else:
+                overrides_colspec = [(getattr(overrides, 'name', ''), overrides)]
+
+    return overrides_colspec
+
+
+def unpack_masked_type(obj: Any) -> Any:
+    """
+    Unpack a masked type into a single type.
+
+    Parameters
+    ----------
+    obj : Any
+        The masked type to unpack
+
+    Returns
+    -------
+    Any
+        The unpacked type
+
+    """
+    if typing.get_origin(obj) is Masked:
+        return typing.get_args(obj)[0]
+    return obj
+
+
+def get_schema(
+    spec: Any,
+    overrides: Optional[Union[List[str], Type[Any]]] = None,
+    mode: str = 'parameter',
+) -> Tuple[List[Tuple[str, Any, Optional[str]]], str, str]:
+    """
+    Expand a return type annotation into a list of types and field names.
+
+    Parameters
+    ----------
+    spec : Any
+        The return type specification
+    overrides : List[str], optional
+        List of SQL type specifications for the return type
+    mode : str
+        The mode of the function, either 'parameter' or 'return'
+
+    Returns
+    -------
+    Tuple[List[Tuple[str, Any, Optional[str]]], str, str]
+        A tuple containing the column specification (field name,
+        normalized type, and optionally the SQL definition of the
+        type), the data format, and the function type
+
+    """
+    colspec = []
+    data_format = ''
+    function_type = 'udf'
+
+    origin = typing.get_origin(spec)
+    args = typing.get_args(spec)
+    args_origins = [typing.get_origin(x) if x is not None else None for x in args]
+
+    # Make sure that the result of a TVF is a list or dataframe
+    if mode == 'return':
+
+        # See if it's a Table subclass with annotations
+        if inspect.isclass(origin) and origin is Table:
+
+            function_type = 'tvf'
+
+            if utils.is_dataframe(args[0]):
+                if not overrides:
+                    raise TypeError(
+                        'column types must be specified by the '
+                        '`returns=` parameter of the @udf decorator',
+                    )
+
+                if utils.get_module(args[0]) in ['pandas', 'polars', 'pyarrow']:
+                    data_format = utils.get_module(args[0])
+                    spec = args[0]
+                else:
+                    raise TypeError(
+                        'only pandas.DataFrames, polars.DataFrames, '
+                        'and pyarrow.Tables are supported as tables.',
+                    )
+
+            elif typing.get_origin(args[0]) is list:
+                if len(args) != 1:
+                    raise TypeError(
+                        'only one list is supported within a table; to '
+                        'return multiple columns, use a tuple, NamedTuple, '
+                        'dataclass, TypedDict, or pydantic model',
+                    )
+                spec = typing.get_args(args[0])[0]
+                data_format = 'list'
+
+            elif all([utils.is_vector(x, include_masks=True) for x in args]):
+                pass
+
+            else:
+                raise TypeError(
+                    'return type for TVF must be a list, DataFrame / Table, '
+                    'or tuple of vectors',
+                )
+
+        # Short circuit check for common valid types
+        elif utils.is_vector(spec) or spec in [str, float, int, bytes]:
+            pass
+
+        # Try to catch some common mistakes
+        elif origin in [tuple, dict] or tuple in args_origins or \
+                (
+                    inspect.isclass(spec) and
+                    (
+                        utils.is_dataframe(spec)
+                        or utils.is_dataclass(spec)
+                        or utils.is_typeddict(spec)
+                        or utils.is_pydantic(spec)
+                        or utils.is_namedtuple(spec)
+                    )
+                ):
+            raise TypeError(
+                'invalid return type for a UDF; '
+                f'expecting a scalar or vector, but got {spec}',
+            )
+
+    # Short circuit check for common valid types
+    elif utils.is_vector(spec) or spec in [str, float, int, bytes]:
+        pass
+
+    # Error out for incorrect parameter types
+    elif origin in [tuple, dict] or tuple in args_origins or \
+            (
+                inspect.isclass(spec) and
+                (
+                    utils.is_dataframe(spec)
+                    or utils.is_dataclass(spec)
+                    or utils.is_typeddict(spec)
+                    or utils.is_pydantic(spec)
+                    or utils.is_namedtuple(spec)
+                )
+            ):
+        raise TypeError(f'parameter types must be scalar or vector, got {spec}')
+
+    #
+    # Process each parameter / return type into a colspec
+    #
+
+    # Compute overrides colspec from various formats
+    overrides_colspec = get_colspec(overrides)
+
+    # Dataframe type
+    if utils.is_dataframe(spec):
+        colspec = overrides_colspec
+
+    # Numpy array types
+    elif utils.is_numpy(spec):
+        data_format = 'numpy'
+        if overrides:
+            colspec = overrides_colspec
+        elif len(typing.get_args(spec)) < 2:
+            raise TypeError(
+                'numpy array must have a data type specified '
+                'in the @udf decorator or with an NDArray type annotation',
+            )
+        else:
+            colspec = [('', typing.get_args(spec)[1])]
+
+    # Pandas Series
+    elif utils.is_pandas_series(spec):
+        data_format = 'pandas'
+        if not overrides:
+            raise TypeError(
+                'pandas Series must have a data type specified '
+                'in the @udf decorator',
+            )
+        colspec = overrides_colspec
+
+    # Polars Series
+    elif utils.is_polars_series(spec):
+        data_format = 'polars'
+        if not overrides:
+            raise TypeError(
+                'polars Series must have a data type specified '
+                'in the @udf decorator',
+            )
+        colspec = overrides_colspec
+
+    # PyArrow Array
+    elif utils.is_pyarrow_array(spec):
+        data_format = 'arrow'
+        if not overrides:
+            raise TypeError(
+                'pyarrow Arrays must have a data type specified '
+                'in the @udf decorator',
+            )
+        colspec = overrides_colspec
+
+    # Return type is specified by a dataclass definition
+    elif utils.is_dataclass(spec):
+        colspec = overrides_colspec or get_dataclass_schema(spec)
+
+    # Return type is specified by a TypedDict definition
+    elif utils.is_typeddict(spec):
+        colspec = overrides_colspec or get_typeddict_schema(spec)
+
+    # Return type is specified by a pydantic model
+    elif utils.is_pydantic(spec):
+        colspec = overrides_colspec or get_pydantic_schema(spec)
+
+    # Return type is specified by a named tuple
+    elif utils.is_namedtuple(spec):
+        colspec = overrides_colspec or get_namedtuple_schema(spec)
+
+    # Unrecognized return type
+    elif spec is not None:
+
+        # Return type is specified by a SQL string
+        if isinstance(spec, str):
+            data_format = 'scalar'
+            colspec = [(getattr(spec, 'name', ''), spec)]
+
+        # Plain list vector
+        elif typing.get_origin(spec) is list:
+            data_format = 'list'
+            colspec = [('', typing.get_args(spec)[0])]
+
+        # Multiple return values
+        elif inspect.isclass(typing.get_origin(spec)) \
+                and issubclass(typing.get_origin(spec), tuple):  # type: ignore[arg-type]
+
+            out_names, out_overrides = [], []
+
+            # Get the colspec for the overrides
+            if overrides:
+                out_colspec = [
+                    x for x in get_colspec(overrides, include_default=True)
+                ]
+                out_names = [x[0] for x in out_colspec]
+                out_overrides = [x[1] for x in out_colspec]
+
+            # Make sure that the number of overrides matches the number of
+            # return types or parameter types
+            if out_overrides and len(typing.get_args(spec)) != len(out_overrides):
+                raise ValueError(
+                    f'number of {mode} types does not match the number of '
+                    'overrides specified',
+                )
+
+            colspec = []
+            out_data_formats = []
+
+            # Get the colspec for each item in the tuple
+            for i, x in enumerate(typing.get_args(spec)):
+                out_item, out_data_format, _ = get_schema(
+                    unpack_masked_type(x),
+                    overrides=out_overrides[i] if out_overrides else [],
+                    # Always pass UDF mode for individual items
+                    mode=mode,
+                )
+
+                # Use the name from the overrides if specified
+                if out_names and out_names[i] and not out_item[0][0]:
+                    out_item = [(out_names[i], *out_item[0][1:])]
+                elif not out_item[0][0]:
+                    out_item = [(f'{string.ascii_letters[i]}', *out_item[0][1:])]
+
+                colspec += out_item
+                out_data_formats.append(out_data_format)
+
+            # Make sure that all the data formats are the same
+            if len(set(out_data_formats)) > 1:
+                raise TypeError(
+                    'data formats must all be the same vector / scalar type: '
+                    f'{", ".join(out_data_formats)}',
+                )
+
+            if data_format != 'list' and out_data_formats:
+                data_format = out_data_formats[0]
+
+            # Since the colspec was computed by get_schema already, don't go
+            # through the process of normalizing the dtypes again
+            return colspec, data_format, function_type  # type: ignore
+
+        # Use overrides if specified
+        elif overrides:
+            if not data_format:
+                data_format = get_data_format(spec)
+            colspec = overrides_colspec
+
+        # Single value, no override
+        else:
+            if not data_format:
+                data_format = 'scalar'
+            colspec = [('', spec)]
+
+    out = []
+
+    # Normalize colspec data types
+    for k, v, *_ in colspec:
+        out.append((
+            k,
+            collapse_dtypes(
+                [normalize_dtype(x) for x in simplify_dtype(v)],
+            ),
+            v if isinstance(v, str) else None,
+        ))
+
+    return out, data_format, function_type
+
+
+def vector_check(obj: Any) -> Tuple[Any, str]:
+    """
+    Check if the object is a vector type.
+
+    Parameters
+    ----------
+    obj : Any
+        The object to check
+
+    Returns
+    -------
+    Tuple[Any, str]
+        The scalar type and the data format:
+        'scalar', 'list', 'numpy', 'pandas', 'polars', or 'arrow'
+
+    """
+    if utils.is_numpy(obj):
+        if len(typing.get_args(obj)) < 2:
+            return None, 'numpy'
+        return typing.get_args(obj)[1], 'numpy'
+    if utils.is_pandas_series(obj):
+        if len(typing.get_args(obj)) < 2:
+            return None, 'pandas'
+        return typing.get_args(obj)[1], 'pandas'
+    if utils.is_polars_series(obj):
+        return None, 'polars'
+    if utils.is_pyarrow_array(obj):
+        return None, 'arrow'
+    if obj is list or typing.get_origin(obj) is list:
+        if len(typing.get_args(obj)) < 1:
+            return None, 'list'
+        return typing.get_args(obj)[0], 'list'
+    return obj, 'scalar'
+
+
+def get_masks(func: Callable[..., Any]) -> Tuple[List[bool], List[bool]]:
+    """
+    Get the list of masked parameters and return values for the function.
+
+    Parameters
+    ----------
+    func : Callable
+        The function to call as the endpoint
+
+    Returns
+    -------
+    Tuple[List[bool], List[bool]]
+        A tuple containing the parameter / return value masks
+        as lists of booleans
+
+    """
+    params = inspect.signature(func).parameters
+    returns = inspect.signature(func).return_annotation
+
+    ret_masks = []
+    if typing.get_origin(returns) is Masked:
+        ret_masks = [True]
+    elif typing.get_origin(returns) is Table:
+        for x in typing.get_args(returns):
+            if typing.get_origin(x) is Masked:
+                ret_masks.append(True)
+            else:
+                ret_masks.append(False)
+        if not any(ret_masks):
+            ret_masks = []
+
+    return (
+        [typing.get_origin(x.annotation) is Masked for x in params.values()],
+        ret_masks,
+    )
+
+
+def get_signature(
+    func: Callable[..., Any],
+    func_name: Optional[str] = None,
+) -> Dict[str, Any]:
     '''
     Print the UDF signature of the Python callable.
 
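The get_schema() machinery is what lets a TVF declare its result set with the new Table annotation from singlestoredb.functions.typing. A hedged usage sketch (the @udf import path and decorator call form are assumptions; per the code above, a Table[List[...]] of NamedTuple rows yields function_type 'tvf' and data_format 'list'):

    from typing import List, NamedTuple, Optional

    from singlestoredb.functions import udf  # assumed import path
    from singlestoredb.functions.typing import Table

    class Row(NamedTuple):
        name: str
        score: Optional[float]

    @udf
    def top_rows(n: int) -> Table[List[Row]]:
        # One list of struct rows -> one table column per NamedTuple field
        return [Row('a', 1.0), Row('b', None)][:n]
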
@@ -451,7 +1122,7 @@ def get_signature(func: Callable[..., Any], name: Optional[str] = None) -> Dict[
     ----------
     func : Callable
         The function to extract the signature of
-    name : str, optional
+    func_name : str, optional
         Name override for function
 
     Returns
@@ -461,138 +1132,113 @@ def get_signature(func: Callable[..., Any], name: Optional[str] = None) -> Dict[
     '''
     signature = inspect.signature(func)
     args: List[Dict[str, Any]] = []
+    returns: List[Dict[str, Any]] = []
+
     attrs = getattr(func, '_singlestoredb_attrs', {})
-    name = attrs.get('name', name if name else func.__name__)
-    function_type = attrs.get('function_type', 'udf')
-    out: Dict[str, Any] = dict(name=name, args=args)
-
-    arg_names = [x for x in signature.parameters]
-    defaults = [
-        x.default if x.default is not inspect.Parameter.empty else None
-        for x in signature.parameters.values()
-    ]
-    annotations = {
-        k: x.annotation for k, x in signature.parameters.items()
-        if x.annotation is not inspect.Parameter.empty
-    }
+    name = attrs.get('name', func_name if func_name else func.__name__)
+
+    out: Dict[str, Any] = dict(name=name, args=args, returns=returns)
 
+    # Do not allow variable positional or keyword arguments
     for p in signature.parameters.values():
         if p.kind == inspect.Parameter.VAR_POSITIONAL:
             raise TypeError('variable positional arguments are not supported')
         elif p.kind == inspect.Parameter.VAR_KEYWORD:
             raise TypeError('variable keyword arguments are not supported')
 
-    args_overrides = attrs.get('args', None)
-    returns_overrides = attrs.get('returns', None)
-    output_fields = attrs.get('output_fields', None)
+    # TODO: Use typing.get_type_hints() for parameters / return values?
 
-    spec_diff = set(arg_names).difference(set(annotations.keys()))
+    # Generate the parameter type and the corresponding SQL code for that parameter
+    args_schema = []
+    args_data_formats = []
+    args_colspec = [x for x in get_colspec(attrs.get('args', []), include_default=True)]
+    args_overrides = [x[1] for x in args_colspec]
+    args_defaults = [x[2] for x in args_colspec]  # type: ignore
+    args_masks, ret_masks = get_masks(func)
 
-    # Make sure all arguments are annotated
-    if spec_diff and args_overrides is None:
-        raise TypeError(
-            'missing annotations for {} in {}'
-            .format(', '.join(spec_diff), name),
+    if args_overrides and len(args_overrides) != len(signature.parameters):
+        raise ValueError(
+            'number of args in the decorator does not match '
+            'the number of parameters in the function signature',
         )
-    elif isinstance(args_overrides, dict):
-        for s in spec_diff:
-            if s not in args_overrides:
-                raise TypeError(
-                    'missing annotations for {} in {}'
-                    .format(', '.join(spec_diff), name),
-                )
-    elif isinstance(args_overrides, list):
-        if len(arg_names) != len(args_overrides):
-            raise TypeError(
-                'number of annotations does not match in {}: {}'
-                .format(name, ', '.join(spec_diff)),
-            )
 
-    for i, arg in enumerate(arg_names):
-        if isinstance(args_overrides, list):
-            sql = args_overrides[i]
-            arg_type = sql_to_dtype(sql)
-        elif isinstance(args_overrides, dict) and arg in args_overrides:
-            sql = args_overrides[arg]
-            arg_type = sql_to_dtype(sql)
-        elif isinstance(args_overrides, str):
-            sql = args_overrides
-            arg_type = sql_to_dtype(sql)
-        elif args_overrides is not None \
-                and not isinstance(args_overrides, (list, dict, str)):
-            raise TypeError(f'unrecognized type for arguments: {args_overrides}')
+    params = list(signature.parameters.values())
+
+    # Get the colspec for each parameter
+    for i, param in enumerate(params):
+        arg_schema, args_data_format, _ = get_schema(
+            unpack_masked_type(param.annotation),
+            overrides=args_overrides[i] if args_overrides else [],
+            mode='parameter',
+        )
+        args_data_formats.append(args_data_format)
+
+        # Insert parameter names as needed
+        if not arg_schema[0][0]:
+            args_schema.append((param.name, *arg_schema[0][1:]))
+
+    for i, (name, atype, sql) in enumerate(args_schema):
+        default_option = {}
+
+        # Insert default values as needed
+        if args_defaults:
+            if args_defaults[i] is not NO_DEFAULT:
+                default_option['default'] = args_defaults[i]
         else:
-            arg_type = collapse_dtypes([
-                classify_dtype(x) for x in simplify_dtype(annotations[arg])
-            ])
-            sql = dtype_to_sql(arg_type, function_type=function_type)
-        args.append(dict(name=arg, dtype=arg_type, sql=sql, default=defaults[i]))
-
-    if returns_overrides is None \
-            and signature.return_annotation is inspect.Signature.empty:
-        raise TypeError(f'no return value annotation in function {name}')
-
-    if isinstance(returns_overrides, str):
-        sql = returns_overrides
-        out_type = sql_to_dtype(sql)
-    elif isinstance(returns_overrides, list):
-        if not output_fields:
-            output_fields = [
-                string.ascii_letters[i] for i in range(len(returns_overrides))
-            ]
-        out_type = 'tuple[' + collapse_dtypes([
-            classify_dtype(x)
-            for x in simplify_dtype(returns_overrides)
-        ]).replace('|', ',') + ']'
-        sql = dtype_to_sql(
-            out_type, function_type=function_type, field_names=output_fields,
+            if params[i].default is not param.empty:
+                default_option['default'] = params[i].default
+
+        # Generate SQL code for the parameter
+        sql = sql or dtype_to_sql(
+            atype, force_nullable=args_masks[i], **default_option,
         )
-    elif dataclasses.is_dataclass(returns_overrides):
-        out_type = collapse_dtypes([
-            classify_dtype(x)
-            for x in simplify_dtype([x.type for x in returns_overrides.fields])
-        ])
-        sql = dtype_to_sql(
-            out_type,
-            function_type=function_type,
-            field_names=[x.name for x in returns_overrides.fields],
+
+        # Add parameter to args definitions
+        args.append(dict(name=name, dtype=atype, sql=sql, **default_option))
+
+    # Check that all the data formats are all the same
+    if len(set(args_data_formats)) > 1:
+        raise TypeError(
+            'input data formats must all be the same: '
+            f'{", ".join(args_data_formats)}',
        )
-    elif has_pydantic and inspect.isclass(returns_overrides) \
-            and issubclass(returns_overrides, pydantic.BaseModel):
-        out_type = collapse_dtypes([
-            classify_dtype(x)
-            for x in simplify_dtype([x for x in returns_overrides.model_fields.values()])
-        ])
-        sql = dtype_to_sql(
-            out_type,
+
+    out['args_data_format'] = args_data_formats[0] if args_data_formats else 'scalar'
+
+    # Generate the return types and the corresponding SQL code for those values
+    ret_schema, out['returns_data_format'], function_type = get_schema(
+        unpack_masked_type(signature.return_annotation),
+        overrides=attrs.get('returns', None),
+        mode='return',
+    )
+
+    out['returns_data_format'] = out['returns_data_format'] or 'scalar'
+    out['function_type'] = function_type
+
+    # All functions have to return a value, so if none was specified try to
+    # insert a reasonable default that includes NULLs.
+    if not ret_schema:
+        ret_schema = [('', 'int8?', 'TINYINT NULL')]
+
+    # Generate field names for the return values
+    if function_type == 'tvf' or len(ret_schema) > 1:
+        for i, (name, rtype, sql) in enumerate(ret_schema):
+            if not name:
+                ret_schema[i] = (string.ascii_letters[i], rtype, sql)
+
+    # Generate SQL code for the return values
+    for i, (name, rtype, sql) in enumerate(ret_schema):
+        sql = sql or dtype_to_sql(
+            rtype,
+            force_nullable=ret_masks[i] if ret_masks else False,
             function_type=function_type,
-            field_names=[x for x in returns_overrides.model_fields.keys()],
-        )
-    elif returns_overrides is not None and not isinstance(returns_overrides, str):
-        raise TypeError(f'unrecognized type for return value: {returns_overrides}')
-    else:
-        if not output_fields:
-            if dataclasses.is_dataclass(signature.return_annotation):
-                output_fields = [
-                    x.name for x in dataclasses.fields(signature.return_annotation)
-                ]
-            elif has_pydantic and inspect.isclass(signature.return_annotation) \
-                    and issubclass(signature.return_annotation, pydantic.BaseModel):
-                output_fields = list(signature.return_annotation.model_fields.keys())
-        out_type = collapse_dtypes([
-            classify_dtype(x) for x in simplify_dtype(signature.return_annotation)
-        ])
-        sql = dtype_to_sql(
-            out_type, function_type=function_type, field_names=output_fields,
         )
-    out['returns'] = dict(dtype=out_type, sql=sql, default=None)
-
-    copied_keys = ['database', 'environment', 'packages', 'resources', 'replace']
-    for key in copied_keys:
-        if attrs.get(key):
-            out[key] = attrs[key]
+        returns.append(dict(name=name, dtype=rtype, sql=sql))
 
+    # Set the function endpoint
     out['endpoint'] = '/invoke'
+
+    # Set the function doc string
     out['doc'] = func.__doc__
 
     return out
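
The rewritten get_signature() delegates all type analysis to get_schema() and reports 'returns' as a list of column dicts instead of a single dict. For a simple scalar UDF, the resulting shape is roughly (field values illustrative, not captured output):

    from typing import Optional

    def add_one(x: Optional[int]) -> Optional[int]:
        return None if x is None else x + 1

    # get_signature(add_one) would produce something like:
    # {
    #     'name': 'add_one',
    #     'args': [{'name': 'x', 'dtype': 'int64?', 'sql': 'BIGINT NULL'}],
    #     'returns': [{'name': '', 'dtype': 'int64?', 'sql': 'BIGINT NULL'}],
    #     'args_data_format': 'scalar',
    #     'returns_data_format': 'scalar',
    #     'function_type': 'udf',
    #     'endpoint': '/invoke',
    #     'doc': None,
    # }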
@@ -641,9 +1287,10 @@ def sql_to_dtype(sql: str) -> str:
 
 def dtype_to_sql(
     dtype: str,
-    default: Any = None,
+    default: Any = NO_DEFAULT,
     field_names: Optional[List[str]] = None,
     function_type: str = 'udf',
+    force_nullable: bool = False,
 ) -> str:
     """
     Convert a collapsed dtype string to a SQL type.
@@ -656,6 +1303,10 @@ def dtype_to_sql(
         Default value
     field_names : List[str], optional
         Field names for tuple types
+    function_type : str, optional
+        Function type, either 'udf' or 'tvf'
+    force_nullable : bool, optional
+        Whether to force the type to be nullable
 
     Returns
     -------
@@ -666,12 +1317,17 @@ def dtype_to_sql(
     if dtype.endswith('?'):
         nullable = ' NULL'
         dtype = dtype[:-1]
+    elif '|null' in dtype:
+        nullable = ' NULL'
+        dtype = dtype.replace('|null', '')
+    elif force_nullable:
+        nullable = ' NULL'
 
     if dtype == 'null':
         nullable = ''
 
     default_clause = ''
-    if default is not None:
+    if default is not NO_DEFAULT:
         if default is dt.NULL:
             default = None
         default_clause = f' DEFAULT {escape_item(default, "utf8")}'
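
Three independent paths now mark a column NULL: a trailing '?', a '|null' member left in the collapsed dtype, and the new force_nullable flag driven by Masked annotations. Expected input/output pairs (SQL spellings assumed from the 'int8?' -> 'TINYINT NULL' default used in get_signature above):

    # dtype_to_sql('int8')                      -> 'TINYINT NOT NULL'
    # dtype_to_sql('int8?')                     -> 'TINYINT NULL'
    # dtype_to_sql('int8|null')                 -> 'TINYINT NULL'  (new path)
    # dtype_to_sql('int8', force_nullable=True) -> 'TINYINT NULL'  (new flag)
    #
    # With the NO_DEFAULT sentinel, an explicit default=None now emits a
    # 'DEFAULT NULL' clause; omitting the argument emits no DEFAULT at all.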
@@ -729,6 +1385,8 @@ def signature_to_sql(
     str : SQL formatted function signature
 
     '''
+    function_type = signature.get('function_type') or 'udf'
+
     args = []
     for arg in signature['args']:
         # Use default value from Python function if SQL doesn't set one
@@ -741,8 +1399,22 @@ def signature_to_sql(
 
     returns = ''
     if signature.get('returns'):
-        res = signature['returns']['sql']
+        ret = signature['returns']
+        if function_type == 'tvf':
+            res = 'TABLE(' + ', '.join(
+                f'{escape_name(x["name"])} {x["sql"]}' for x in ret
+            ) + ')'
+        elif ret[0]['name'] and len(ret) > 1:
+            res = 'RECORD(' + ', '.join(
+                f'{escape_name(x["name"])} {x["sql"]}' for x in ret
+            ) + ')'
+        else:
+            res = ret[0]['sql']
         returns = f' RETURNS {res}'
+    else:
+        raise ValueError(
+            'function signature must have a return type specified',
+        )
 
     host = os.environ.get('SINGLESTOREDB_EXT_HOST', '127.0.0.1')
     port = os.environ.get('SINGLESTOREDB_EXT_PORT', '8000')
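
With 'returns' now a list, signature_to_sql() can emit the three RETURNS shapes directly. Sketches of the generated clauses (identifiers and backtick quoting from escape_name are illustrative, not captured output):

    # Scalar UDF, one unnamed column:
    #   ... RETURNS BIGINT NULL ...
    # Multi-column UDF with named fields:
    #   ... RETURNS RECORD(`a` BIGINT NULL, `b` DOUBLE NULL) ...
    # TVF (function_type == 'tvf'):
    #   ... RETURNS TABLE(`a` BIGINT NULL, `b` DOUBLE NULL) ...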