singlestoredb 1.15.0__cp38-abi3-win_amd64.whl → 1.15.2__cp38-abi3-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of singlestoredb might be problematic. Click here for more details.

Files changed (49)
  1. _singlestoredb_accel.pyd +0 -0
  2. singlestoredb/__init__.py +1 -1
  3. singlestoredb/ai/chat.py +14 -0
  4. singlestoredb/apps/_python_udfs.py +18 -3
  5. singlestoredb/apps/_stdout_supress.py +1 -1
  6. singlestoredb/apps/_uvicorn_util.py +4 -0
  7. singlestoredb/config.py +24 -0
  8. singlestoredb/converters.py +1 -1
  9. singlestoredb/docstring/__init__.py +33 -0
  10. singlestoredb/docstring/attrdoc.py +126 -0
  11. singlestoredb/docstring/common.py +230 -0
  12. singlestoredb/docstring/epydoc.py +267 -0
  13. singlestoredb/docstring/google.py +412 -0
  14. singlestoredb/docstring/numpydoc.py +562 -0
  15. singlestoredb/docstring/parser.py +100 -0
  16. singlestoredb/docstring/py.typed +1 -0
  17. singlestoredb/docstring/rest.py +256 -0
  18. singlestoredb/docstring/tests/__init__.py +1 -0
  19. singlestoredb/docstring/tests/_pydoctor.py +21 -0
  20. singlestoredb/docstring/tests/test_epydoc.py +729 -0
  21. singlestoredb/docstring/tests/test_google.py +1007 -0
  22. singlestoredb/docstring/tests/test_numpydoc.py +1100 -0
  23. singlestoredb/docstring/tests/test_parse_from_object.py +109 -0
  24. singlestoredb/docstring/tests/test_parser.py +248 -0
  25. singlestoredb/docstring/tests/test_rest.py +547 -0
  26. singlestoredb/docstring/tests/test_util.py +70 -0
  27. singlestoredb/docstring/util.py +141 -0
  28. singlestoredb/functions/decorator.py +19 -18
  29. singlestoredb/functions/ext/asgi.py +304 -32
  30. singlestoredb/functions/ext/timer.py +2 -11
  31. singlestoredb/functions/ext/utils.py +55 -6
  32. singlestoredb/functions/signature.py +374 -241
  33. singlestoredb/fusion/handlers/files.py +4 -4
  34. singlestoredb/fusion/handlers/models.py +1 -1
  35. singlestoredb/fusion/handlers/stage.py +4 -4
  36. singlestoredb/management/cluster.py +1 -1
  37. singlestoredb/management/manager.py +15 -5
  38. singlestoredb/management/region.py +12 -2
  39. singlestoredb/management/workspace.py +17 -25
  40. singlestoredb/tests/ext_funcs/__init__.py +39 -0
  41. singlestoredb/tests/test_connection.py +18 -8
  42. singlestoredb/tests/test_management.py +24 -57
  43. singlestoredb/tests/test_udf.py +43 -15
  44. {singlestoredb-1.15.0.dist-info → singlestoredb-1.15.2.dist-info}/METADATA +1 -1
  45. {singlestoredb-1.15.0.dist-info → singlestoredb-1.15.2.dist-info}/RECORD +49 -30
  46. {singlestoredb-1.15.0.dist-info → singlestoredb-1.15.2.dist-info}/LICENSE +0 -0
  47. {singlestoredb-1.15.0.dist-info → singlestoredb-1.15.2.dist-info}/WHEEL +0 -0
  48. {singlestoredb-1.15.0.dist-info → singlestoredb-1.15.2.dist-info}/entry_points.txt +0 -0
  49. {singlestoredb-1.15.0.dist-info → singlestoredb-1.15.2.dist-info}/top_level.txt +0 -0
@@ -16,7 +16,6 @@ from typing import List
16
16
  from typing import Optional
17
17
  from typing import Sequence
18
18
  from typing import Tuple
19
- from typing import Type
20
19
  from typing import TypeVar
21
20
  from typing import Union
22
21
 
@@ -188,6 +187,27 @@ sql_to_type_map = {
188
187
  }
189
188
 
190
189
 
190
+ @dataclasses.dataclass
191
+ class ParamSpec:
192
+ # Normalized data type of the parameter
193
+ dtype: Any
194
+
195
+ # Name of the parameter, if applicable
196
+ name: str = ''
197
+
198
+ # SQL type of the parameter
199
+ sql_type: str = ''
200
+
201
+ # Default value of the parameter, if applicable
202
+ default: Any = NO_DEFAULT
203
+
204
+ # Transformer function to apply to the parameter
205
+ transformer: Optional[Callable[..., Any]] = None
206
+
207
+ # Whether the parameter is optional (e.g., Union[T, None] or Optional[T])
208
+ is_optional: bool = False
209
+
210
+
191
211
  class Collection:
192
212
  """Base class for collection data types."""
193
213
 
@@ -519,10 +539,7 @@ def collapse_dtypes(dtypes: Union[str, List[str]], include_null: bool = False) -
519
539
  return dtypes[0] + ('?' if is_nullable else '')
520
540
 
521
541
 
522
- def get_dataclass_schema(
523
- obj: Any,
524
- include_default: bool = False,
525
- ) -> List[Union[Tuple[str, Any], Tuple[str, Any, Any]]]:
542
+ def get_dataclass_schema(obj: Any) -> List[ParamSpec]:
526
543
  """
527
544
  Get the schema of a dataclass.
528
545
 
@@ -533,25 +550,21 @@ def get_dataclass_schema(
533
550
 
534
551
  Returns
535
552
  -------
536
- List[Tuple[str, Any]] | List[Tuple[str, Any, Any]]
537
- A list of tuples containing the field names and field types
553
+ List[ParamSpec]
554
+ A list of parameter specifications for the dataclass fields
538
555
 
539
556
  """
540
- if include_default:
541
- return [
542
- (
543
- f.name, f.type,
544
- NO_DEFAULT if f.default is dataclasses.MISSING else f.default,
545
- )
546
- for f in dataclasses.fields(obj)
547
- ]
548
- return [(f.name, f.type) for f in dataclasses.fields(obj)]
557
+ return [
558
+ ParamSpec(
559
+ name=f.name,
560
+ dtype=f.type,
561
+ default=NO_DEFAULT if f.default is dataclasses.MISSING else f.default,
562
+ )
563
+ for f in dataclasses.fields(obj)
564
+ ]
549
565
 
550
566
 
551
- def get_typeddict_schema(
552
- obj: Any,
553
- include_default: bool = False,
554
- ) -> List[Union[Tuple[str, Any], Tuple[str, Any, Any]]]:
567
+ def get_typeddict_schema(obj: Any) -> List[ParamSpec]:
555
568
  """
556
569
  Get the schema of a TypedDict.
557
570
 
@@ -559,27 +572,24 @@ def get_typeddict_schema(
559
572
  ----------
560
573
  obj : TypedDict
561
574
  The TypedDict to get the schema of
562
- include_default : bool, optional
563
- Whether to include the default value in the column specification
564
575
 
565
576
  Returns
566
577
  -------
567
- List[Tuple[str, Any]] | List[Tuple[str, Any, Any]]
568
- A list of tuples containing the field names and field types
578
+ List[ParamSpec]
579
+ A list of parameter specifications for the TypedDict fields
569
580
 
570
581
  """
571
- if include_default:
572
- return [
573
- (k, v, getattr(obj, k, NO_DEFAULT))
574
- for k, v in utils.get_annotations(obj).items()
575
- ]
576
- return list(utils.get_annotations(obj).items())
582
+ return [
583
+ ParamSpec(
584
+ name=k,
585
+ dtype=v,
586
+ default=getattr(obj, k, NO_DEFAULT),
587
+ )
588
+ for k, v in utils.get_annotations(obj).items()
589
+ ]
577
590
 
578
591
 
579
- def get_pydantic_schema(
580
- obj: Any,
581
- include_default: bool = False,
582
- ) -> List[Union[Tuple[str, Any], Tuple[str, Any, Any]]]:
592
+ def get_pydantic_schema(obj: Any) -> List[ParamSpec]:
583
593
  """
584
594
  Get the schema of a pydantic model.
585
595
 
@@ -587,31 +597,26 @@ def get_pydantic_schema(
587
597
  ----------
588
598
  obj : pydantic.BaseModel
589
599
  The pydantic model to get the schema of
590
- include_default : bool, optional
591
- Whether to include the default value in the column specification
592
600
 
593
601
  Returns
594
602
  -------
595
- List[Tuple[str, Any]] | List[Tuple[str, Any, Any]]
596
- A list of tuples containing the field names and field types
603
+ List[ParamSpec]
604
+ A list of parameter specifications for the pydantic model fields
597
605
 
598
606
  """
599
607
  import pydantic_core
600
- if include_default:
601
- return [
602
- (
603
- k, v.annotation,
604
- NO_DEFAULT if v.default is pydantic_core.PydanticUndefined else v.default,
605
- )
606
- for k, v in obj.model_fields.items()
607
- ]
608
- return [(k, v.annotation) for k, v in obj.model_fields.items()]
608
+ return [
609
+ ParamSpec(
610
+ name=k,
611
+ dtype=v.annotation,
612
+ default=NO_DEFAULT
613
+ if v.default is pydantic_core.PydanticUndefined else v.default,
614
+ )
615
+ for k, v in obj.model_fields.items()
616
+ ]
609
617
 
610
618
 
611
- def get_namedtuple_schema(
612
- obj: Any,
613
- include_default: bool = False,
614
- ) -> List[Union[Tuple[Any, str], Tuple[Any, str, Any]]]:
619
+ def get_namedtuple_schema(obj: Any) -> List[ParamSpec]:
615
620
  """
616
621
  Get the schema of a named tuple.
617
622
 
@@ -619,30 +624,26 @@ def get_namedtuple_schema(
619
624
  ----------
620
625
  obj : NamedTuple
621
626
  The named tuple to get the schema of
622
- include_default : bool, optional
623
- Whether to include the default value in the column specification
624
627
 
625
628
  Returns
626
629
  -------
627
- List[Tuple[Any, str]] | List[Tuple[Any, str, Any]]
628
- A list of tuples containing the field names and field types
630
+ List[ParamSpec]
631
+ A list of parameter specifications for the named tuple fields
629
632
 
630
633
  """
631
- if include_default:
632
- return [
633
- (
634
- k, v,
635
- obj._field_defaults.get(k, NO_DEFAULT),
634
+ return [
635
+ (
636
+ ParamSpec(
637
+ name=k,
638
+ dtype=v,
639
+ default=obj._field_defaults.get(k, NO_DEFAULT),
636
640
  )
637
- for k, v in utils.get_annotations(obj).items()
638
- ]
639
- return list(utils.get_annotations(obj).items())
641
+ )
642
+ for k, v in utils.get_annotations(obj).items()
643
+ ]
640
644
 
641
645
 
642
- def get_table_schema(
643
- obj: Any,
644
- include_default: bool = False,
645
- ) -> List[Union[Tuple[Any, str], Tuple[Any, str, Any]]]:
646
+ def get_table_schema(obj: Any) -> List[ParamSpec]:
646
647
  """
647
648
  Get the schema of a Table.
648
649
 
@@ -650,90 +651,66 @@ def get_table_schema(
650
651
  ----------
651
652
  obj : Table
652
653
  The Table to get the schema of
653
- include_default : bool, optional
654
- Whether to include the default value in the column specification
655
654
 
656
655
  Returns
657
656
  -------
658
- List[Tuple[Any, str]] | List[Tuple[Any, str, Any]]
659
- A list of tuples containing the field names and field types
657
+ List[ParamSpec]
658
+ A list of parameter specifications for the Table fields
660
659
 
661
660
  """
662
- if include_default:
663
- return [
664
- (k, v, getattr(obj, k, NO_DEFAULT))
665
- for k, v in utils.get_annotations(obj).items()
666
- ]
667
- return list(utils.get_annotations(obj).items())
661
+ return [
662
+ ParamSpec(
663
+ name=k,
664
+ dtype=v,
665
+ default=getattr(obj, k, NO_DEFAULT),
666
+ )
667
+ for k, v in utils.get_annotations(obj).items()
668
+ ]
668
669
 
669
670
 
670
- def get_colspec(
671
- overrides: Any,
672
- include_default: bool = False,
673
- ) -> List[Union[Tuple[str, Any], Tuple[str, Any, Any]]]:
671
+ def get_colspec(overrides: List[Any]) -> List[ParamSpec]:
674
672
  """
675
673
  Get the column specification from the overrides.
676
674
 
677
675
  Parameters
678
676
  ----------
679
- overrides : Any
677
+ overrides : List[Any]
680
678
  The overrides to get the column specification from
681
- include_default : bool, optional
682
- Whether to include the default value in the column specification
683
679
 
684
680
  Returns
685
681
  -------
686
- List[Tuple[str, Any]] | List[Tuple[str, Any, Any]]
687
- A list of tuples containing the field names and field types
682
+ List[ParamSpec]
683
+ A list of parameter specifications for the column fields
688
684
 
689
685
  """
690
- overrides_colspec = []
686
+ if len(overrides) == 1:
691
687
 
692
- if overrides:
688
+ override = overrides[0]
693
689
 
694
690
  # Dataclass
695
- if utils.is_dataclass(overrides):
696
- overrides_colspec = get_dataclass_schema(
697
- overrides, include_default=include_default,
698
- )
691
+ if utils.is_dataclass(override):
692
+ return get_dataclass_schema(override)
699
693
 
700
694
  # TypedDict
701
- elif utils.is_typeddict(overrides):
702
- overrides_colspec = get_typeddict_schema(
703
- overrides, include_default=include_default,
704
- )
695
+ elif utils.is_typeddict(override):
696
+ return get_typeddict_schema(override)
705
697
 
706
698
  # Named tuple
707
- elif utils.is_namedtuple(overrides):
708
- overrides_colspec = get_namedtuple_schema(
709
- overrides, include_default=include_default,
710
- )
699
+ elif utils.is_namedtuple(override):
700
+ return get_namedtuple_schema(override)
711
701
 
712
702
  # Pydantic model
713
- elif utils.is_pydantic(overrides):
714
- overrides_colspec = get_pydantic_schema(
715
- overrides, include_default=include_default,
716
- )
717
-
718
- # List of types
719
- elif isinstance(overrides, list):
720
- if include_default:
721
- overrides_colspec = [
722
- (getattr(x, 'name', ''), x, NO_DEFAULT) for x in overrides
723
- ]
724
- else:
725
- overrides_colspec = [(getattr(x, 'name', ''), x) for x in overrides]
726
-
727
- # Other
728
- else:
729
- if include_default:
730
- overrides_colspec = [
731
- (getattr(overrides, 'name', ''), overrides, NO_DEFAULT),
732
- ]
733
- else:
734
- overrides_colspec = [(getattr(overrides, 'name', ''), overrides)]
703
+ elif utils.is_pydantic(override):
704
+ return get_pydantic_schema(override)
735
705
 
736
- return overrides_colspec
706
+ # List of types
707
+ return [
708
+ ParamSpec(
709
+ name=getattr(x, 'name', ''),
710
+ dtype=sql_to_dtype(x) if isinstance(x, str) else x,
711
+ sql_type=x if isinstance(x, str) else '',
712
+ ) for x in overrides
713
+ ]
737
714
 
738
715
 
739
716
  def unpack_masked_type(obj: Any) -> Any:
@@ -756,11 +733,121 @@ def unpack_masked_type(obj: Any) -> Any:
756
733
  return obj
757
734
 
758
735
 
736
+ def unwrap_optional(annotation: Any) -> Tuple[Any, bool]:
737
+ """
738
+ Unwrap Optional[T] and Union[T, None] annotations to get the underlying type.
739
+ Also indicates whether the type was optional.
740
+
741
+ Examples:
742
+ Optional[int] -> (int, True)
743
+ Union[str, None] -> (str, True)
744
+ Union[int, str, None] -> (Union[int, str], True)
745
+ Union[int, str] -> (Union[int, str], False)
746
+ int -> (int, False)
747
+
748
+ Parameters
749
+ ----------
750
+ annotation : Any
751
+ The type annotation to unwrap
752
+
753
+ Returns
754
+ -------
755
+ Tuple[Any, bool]
756
+ A tuple containing:
757
+ - The unwrapped type annotation
758
+ - A boolean indicating if the original type was optional (contained None)
759
+
760
+ """
761
+ origin = typing.get_origin(annotation)
762
+ is_optional = False
763
+
764
+ # Handle Union types (which includes Optional)
765
+ if origin is Union:
766
+ args = typing.get_args(annotation)
767
+ # Check if None is in the union
768
+ is_optional = type(None) in args
769
+
770
+ # Filter out None/NoneType
771
+ non_none_args = [arg for arg in args if arg is not type(None)]
772
+
773
+ if not non_none_args:
774
+ # If only None was in the Union
775
+ from typing import Any
776
+ return Any, is_optional
777
+ elif len(non_none_args) == 1:
778
+ # If there's only one type left, return it directly
779
+ return non_none_args[0], is_optional
780
+ else:
781
+ # Recreate the Union with the remaining types
782
+ return Union[tuple(non_none_args)], is_optional
783
+
784
+ return annotation, is_optional
785
+
786
+
787
+ def is_composite_type(spec: Any) -> bool:
788
+ """
789
+ Check if the object is a composite type (e.g., dataclass, TypedDict, etc.).
790
+
791
+ Parameters
792
+ ----------
793
+ spec : Any
794
+ The object to check
795
+
796
+ Returns
797
+ -------
798
+ bool
799
+ True if the object is a composite type, False otherwise
800
+
801
+ """
802
+ return inspect.isclass(spec) and \
803
+ (
804
+ utils.is_dataframe(spec)
805
+ or utils.is_dataclass(spec)
806
+ or utils.is_typeddict(spec)
807
+ or utils.is_pydantic(spec)
808
+ or utils.is_namedtuple(spec)
809
+ )
810
+
811
+
812
+ def check_composite_type(colspec: List[ParamSpec], mode: str, type_name: str) -> bool:
813
+ """
814
+ Check if the column specification is a composite type.
815
+
816
+ Parameters
817
+ ----------
818
+ colspec : List[ParamSpec]
819
+ The column specification to check
820
+ mode : str
821
+ The mode of the function, either 'parameter' or 'return'
822
+ type_name : str
823
+ The name of the parent type
824
+
825
+ Returns
826
+ -------
827
+ bool
828
+ Verify the composite type is valid for the given mode
829
+
830
+ """
831
+ if mode == 'parameter':
832
+ if is_composite_type(colspec[0].dtype):
833
+ raise TypeError(
834
+ 'composite types are not allowed in a '
835
+ f'{type_name}: {colspec[0].dtype.__name__}',
836
+ )
837
+ elif mode == 'return':
838
+ if is_composite_type(colspec[0].dtype):
839
+ raise TypeError(
840
+ 'composite types are not allowed in a '
841
+ f'{type_name}: {colspec[0].dtype.__name__}',
842
+ )
843
+ return False
844
+
845
+
759
846
  def get_schema(
760
847
  spec: Any,
761
- overrides: Optional[Union[List[str], Type[Any]]] = None,
848
+ overrides: Optional[List[ParamSpec]] = None,
762
849
  mode: str = 'parameter',
763
- ) -> Tuple[List[Tuple[str, Any, Optional[str]]], str, str]:
850
+ ) -> Tuple[List[ParamSpec], str, str]:
764
851
  """
765
852
  Expand a return type annotation into a list of types and field names.
766
853
 
@@ -768,23 +855,24 @@ def get_schema(
768
855
  ----------
769
856
  spec : Any
770
857
  The return type specification
771
- overrides : List[str], optional
858
+ overrides : List[ParamSpec], optional
772
859
  List of SQL type specifications for the return type
773
860
  mode : str
774
861
  The mode of the function, either 'parameter' or 'return'
775
862
 
776
863
  Returns
777
864
  -------
778
- Tuple[List[Tuple[str, Any, Optional[str]]], str, str]
779
- A list of tuples containing the field names and field types,
780
- the normalized data format, optionally the SQL
781
- definition of the type, and the data format of the type
865
+ Tuple[List[ParamSpec], str, str]
866
+ A list of parameter specifications for the function,
867
+ the normalized data format, and the SQL definition of the type
782
868
 
783
869
  """
784
870
  colspec = []
785
871
  data_format = ''
786
872
  function_type = 'udf'
873
+ udf_parameter = '`returns=`' if mode == 'return' else '`args=`'
787
874
 
875
+ spec, is_optional = unwrap_optional(spec)
788
876
  origin = typing.get_origin(spec)
789
877
  args = typing.get_args(spec)
790
878
  args_origins = [typing.get_origin(x) if x is not None else None for x in args]
@@ -833,113 +921,104 @@ def get_schema(
833
921
  )
834
922
 
835
923
  # Short circuit check for common valid types
836
- elif utils.is_vector(spec) or spec in [str, float, int, bytes]:
924
+ elif utils.is_vector(spec) or spec in {str, float, int, bytes}:
837
925
  pass
838
926
 
839
927
  # Try to catch some common mistakes
840
- elif origin in [tuple, dict] or tuple in args_origins or \
841
- (
842
- inspect.isclass(spec) and
843
- (
844
- utils.is_dataframe(spec)
845
- or utils.is_dataclass(spec)
846
- or utils.is_typeddict(spec)
847
- or utils.is_pydantic(spec)
848
- or utils.is_namedtuple(spec)
849
- )
850
- ):
928
+ elif origin in [tuple, dict] or tuple in args_origins or is_composite_type(spec):
851
929
  raise TypeError(
852
- 'invalid return type for a UDF; '
853
- f'expecting a scalar or vector, but got {spec}',
930
+ 'invalid return type for a UDF; expecting a scalar or vector, '
931
+ f'but got {getattr(spec, "__name__", spec)}',
854
932
  )
855
933
 
856
934
  # Short circuit check for common valid types
857
- elif utils.is_vector(spec) or spec in [str, float, int, bytes]:
935
+ elif utils.is_vector(spec) or spec in {str, float, int, bytes}:
858
936
  pass
859
937
 
860
938
  # Error out for incorrect parameter types
861
- elif origin in [tuple, dict] or tuple in args_origins or \
862
- (
863
- inspect.isclass(spec) and
864
- (
865
- utils.is_dataframe(spec)
866
- or utils.is_dataclass(spec)
867
- or utils.is_typeddict(spec)
868
- or utils.is_pydantic(spec)
869
- or utils.is_namedtuple(spec)
870
- )
871
- ):
872
- raise TypeError(f'parameter types must be scalar or vector, got {spec}')
939
+ elif origin in [tuple, dict] or tuple in args_origins or is_composite_type(spec):
940
+ raise TypeError(
941
+ 'parameter types must be scalar or vector, '
942
+ f'got {getattr(spec, "__name__", spec)}',
943
+ )
873
944
 
874
945
  #
875
946
  # Process each parameter / return type into a colspec
876
947
  #
877
948
 
878
- # Compute overrides colspec from various formats
879
- overrides_colspec = get_colspec(overrides)
880
-
881
949
  # Dataframe type
882
950
  if utils.is_dataframe(spec):
883
- colspec = overrides_colspec
951
+ if not overrides:
952
+ raise TypeError(
953
+ 'column types must be specified in the '
954
+ f'{udf_parameter} parameter of the @udf decorator for a DataFrame',
955
+ )
956
+ # colspec = get_colspec(overrides[0].dtype)
957
+ colspec = overrides
884
958
 
885
959
  # Numpy array types
886
960
  elif utils.is_numpy(spec):
887
961
  data_format = 'numpy'
888
962
  if overrides:
889
- colspec = overrides_colspec
963
+ colspec = overrides
890
964
  elif len(typing.get_args(spec)) < 2:
891
965
  raise TypeError(
892
- 'numpy array must have a data type specified '
893
- 'in the @udf decorator or with an NDArray type annotation',
966
+ 'numpy array must have an element data type specified '
967
+ f'in the {udf_parameter} parameter of the @udf decorator '
968
+ 'or with an NDArray type annotation',
894
969
  )
895
970
  else:
896
- colspec = [('', typing.get_args(spec)[1])]
971
+ colspec = [ParamSpec(dtype=typing.get_args(spec)[1])]
972
+ check_composite_type(colspec, mode, 'numpy array')
897
973
 
898
974
  # Pandas Series
899
975
  elif utils.is_pandas_series(spec):
900
976
  data_format = 'pandas'
901
977
  if not overrides:
902
978
  raise TypeError(
903
- 'pandas Series must have a data type specified '
904
- 'in the @udf decorator',
979
+ 'pandas Series must have an element data type specified '
980
+ f'in the {udf_parameter} parameter of the @udf decorator',
905
981
  )
906
- colspec = overrides_colspec
982
+ colspec = overrides
983
+ check_composite_type(colspec, mode, 'pandas Series')
907
984
 
908
985
  # Polars Series
909
986
  elif utils.is_polars_series(spec):
910
987
  data_format = 'polars'
911
988
  if not overrides:
912
989
  raise TypeError(
913
- 'polars Series must have a data type specified '
914
- 'in the @udf decorator',
990
+ 'polars Series must have an element data type specified '
991
+ f'in the {udf_parameter} parameter of the @udf decorator',
915
992
  )
916
- colspec = overrides_colspec
993
+ colspec = overrides
994
+ check_composite_type(colspec, mode, 'polars Series')
917
995
 
918
996
  # PyArrow Array
919
997
  elif utils.is_pyarrow_array(spec):
920
998
  data_format = 'arrow'
921
999
  if not overrides:
922
1000
  raise TypeError(
923
- 'pyarrow Arrays must have a data type specified '
924
- 'in the @udf decorator',
1001
+ 'pyarrow Arrays must have an element data type specified '
1002
+ f'in the {udf_parameter} parameter of the @udf decorator',
925
1003
  )
926
- colspec = overrides_colspec
1004
+ colspec = overrides
1005
+ check_composite_type(colspec, mode, 'pyarrow Array')
927
1006
 
928
1007
  # Return type is specified by a dataclass definition
929
1008
  elif utils.is_dataclass(spec):
930
- colspec = overrides_colspec or get_dataclass_schema(spec)
1009
+ colspec = overrides or get_dataclass_schema(spec)
931
1010
 
932
1011
  # Return type is specified by a TypedDict definition
933
1012
  elif utils.is_typeddict(spec):
934
- colspec = overrides_colspec or get_typeddict_schema(spec)
1013
+ colspec = overrides or get_typeddict_schema(spec)
935
1014
 
936
1015
  # Return type is specified by a pydantic model
937
1016
  elif utils.is_pydantic(spec):
938
- colspec = overrides_colspec or get_pydantic_schema(spec)
1017
+ colspec = overrides or get_pydantic_schema(spec)
939
1018
 
940
1019
  # Return type is specified by a named tuple
941
1020
  elif utils.is_namedtuple(spec):
942
- colspec = overrides_colspec or get_namedtuple_schema(spec)
1021
+ colspec = overrides or get_namedtuple_schema(spec)
943
1022
 
944
1023
  # Unrecognized return type
945
1024
  elif spec is not None:
@@ -947,30 +1026,20 @@ def get_schema(
947
1026
  # Return type is specified by a SQL string
948
1027
  if isinstance(spec, str):
949
1028
  data_format = 'scalar'
950
- colspec = [(getattr(spec, 'name', ''), spec)]
1029
+ colspec = [ParamSpec(dtype=spec, is_optional=is_optional)]
951
1030
 
952
1031
  # Plain list vector
953
1032
  elif typing.get_origin(spec) is list:
954
1033
  data_format = 'list'
955
- colspec = [('', typing.get_args(spec)[0])]
1034
+ colspec = [ParamSpec(dtype=typing.get_args(spec)[0], is_optional=is_optional)]
956
1035
 
957
1036
  # Multiple return values
958
1037
  elif inspect.isclass(typing.get_origin(spec)) \
959
1038
  and issubclass(typing.get_origin(spec), tuple): # type: ignore[arg-type]
960
1039
 
961
- out_names, out_overrides = [], []
962
-
963
- # Get the colspec for the overrides
964
- if overrides:
965
- out_colspec = [
966
- x for x in get_colspec(overrides, include_default=True)
967
- ]
968
- out_names = [x[0] for x in out_colspec]
969
- out_overrides = [x[1] for x in out_colspec]
970
-
971
1040
  # Make sure that the number of overrides matches the number of
972
1041
  # return types or parameter types
973
- if out_overrides and len(typing.get_args(spec)) != len(out_overrides):
1042
+ if overrides and len(typing.get_args(spec)) != len(overrides):
974
1043
  raise ValueError(
975
1044
  f'number of {mode} types does not match the number of '
976
1045
  'overrides specified',
@@ -981,20 +1050,21 @@ def get_schema(
981
1050
 
982
1051
  # Get the colspec for each item in the tuple
983
1052
  for i, x in enumerate(typing.get_args(spec)):
984
- out_item, out_data_format, _ = get_schema(
1053
+ params, out_data_format, _ = get_schema(
985
1054
  unpack_masked_type(x),
986
- overrides=out_overrides[i] if out_overrides else [],
1055
+ overrides=[overrides[i]] if overrides else [],
987
1056
  # Always pass UDF mode for individual items
988
1057
  mode=mode,
989
1058
  )
990
1059
 
991
1060
  # Use the name from the overrides if specified
992
- if out_names and out_names[i] and not out_item[0][0]:
993
- out_item = [(out_names[i], *out_item[0][1:])]
994
- elif not out_item[0][0]:
995
- out_item = [(f'{string.ascii_letters[i]}', *out_item[0][1:])]
1061
+ if overrides:
1062
+ if overrides[i] and not params[0].name:
1063
+ params[0].name = overrides[i].name
1064
+ elif not overrides[i].name:
1065
+ params[0].name = f'{string.ascii_letters[i]}'
996
1066
 
997
- colspec += out_item
1067
+ colspec.append(params[0])
998
1068
  out_data_formats.append(out_data_format)
999
1069
 
1000
1070
  # Make sure that all the data formats are the same
@@ -1015,25 +1085,35 @@ def get_schema(
1015
1085
  elif overrides:
1016
1086
  if not data_format:
1017
1087
  data_format = get_data_format(spec)
1018
- colspec = overrides_colspec
1088
+ colspec = overrides
1019
1089
 
1020
1090
  # Single value, no override
1021
1091
  else:
1022
1092
  if not data_format:
1023
1093
  data_format = 'scalar'
1024
- colspec = [('', spec)]
1094
+ colspec = [ParamSpec(dtype=spec, is_optional=is_optional)]
1025
1095
 
1026
1096
  out = []
1027
1097
 
1028
1098
  # Normalize colspec data types
1029
- for k, v, *_ in colspec:
1030
- out.append((
1031
- k,
1032
- collapse_dtypes(
1033
- [normalize_dtype(x) for x in simplify_dtype(v)],
1034
- ),
1035
- v if isinstance(v, str) else None,
1036
- ))
1099
+ for c in colspec:
1100
+
1101
+ if isinstance(c.dtype, str):
1102
+ dtype = c.dtype
1103
+ else:
1104
+ dtype = collapse_dtypes(
1105
+ [normalize_dtype(x) for x in simplify_dtype(c.dtype)],
1106
+ include_null=c.is_optional,
1107
+ )
1108
+
1109
+ p = ParamSpec(
1110
+ name=c.name,
1111
+ dtype=dtype,
1112
+ sql_type=c.sql_type if isinstance(c.sql_type, str) else None,
1113
+ is_optional=c.is_optional,
1114
+ )
1115
+
1116
+ out.append(p)
1037
1117
 
1038
1118
  return out, data_format, function_type
1039
1119
 
@@ -1149,14 +1229,12 @@ def get_signature(
1149
1229
  # TODO: Use typing.get_type_hints() for parameters / return values?
1150
1230
 
1151
1231
  # Generate the parameter type and the corresponding SQL code for that parameter
1152
- args_schema = []
1232
+ args_schema: List[ParamSpec] = []
1153
1233
  args_data_formats = []
1154
- args_colspec = [x for x in get_colspec(attrs.get('args', []), include_default=True)]
1155
- args_overrides = [x[1] for x in args_colspec]
1156
- args_defaults = [x[2] for x in args_colspec] # type: ignore
1234
+ args_colspec = [x for x in get_colspec(attrs.get('args', []))]
1157
1235
  args_masks, ret_masks = get_masks(func)
1158
1236
 
1159
- if args_overrides and len(args_overrides) != len(signature.parameters):
1237
+ if args_colspec and len(args_colspec) != len(signature.parameters):
1160
1238
  raise ValueError(
1161
1239
  'number of args in the decorator does not match '
1162
1240
  'the number of parameters in the function signature',
@@ -1168,33 +1246,49 @@ def get_signature(
1168
1246
  for i, param in enumerate(params):
1169
1247
  arg_schema, args_data_format, _ = get_schema(
1170
1248
  unpack_masked_type(param.annotation),
1171
- overrides=args_overrides[i] if args_overrides else [],
1249
+ overrides=[args_colspec[i]] if args_colspec else [],
1172
1250
  mode='parameter',
1173
1251
  )
1174
1252
  args_data_formats.append(args_data_format)
1175
1253
 
1254
+ if len(arg_schema) > 1:
1255
+ raise TypeError(
1256
+ 'only one parameter type is supported; '
1257
+ f'got {len(arg_schema)} types for parameter {param.name}',
1258
+ )
1259
+
1176
1260
  # Insert parameter names as needed
1177
- if not arg_schema[0][0]:
1178
- args_schema.append((param.name, *arg_schema[0][1:]))
1261
+ if not arg_schema[0].name:
1262
+ arg_schema[0].name = param.name
1179
1263
 
1180
- for i, (name, atype, sql) in enumerate(args_schema):
1264
+ args_schema.append(arg_schema[0])
1265
+
1266
+ for i, pspec in enumerate(args_schema):
1181
1267
  default_option = {}
1182
1268
 
1183
1269
  # Insert default values as needed
1184
- if args_defaults:
1185
- if args_defaults[i] is not NO_DEFAULT:
1186
- default_option['default'] = args_defaults[i]
1187
- else:
1188
- if params[i].default is not param.empty:
1189
- default_option['default'] = params[i].default
1270
+ if args_colspec and args_colspec[i].default is not NO_DEFAULT:
1271
+ default_option['default'] = args_colspec[i].default
1272
+ elif params and params[i].default is not param.empty:
1273
+ default_option['default'] = params[i].default
1190
1274
 
1191
1275
  # Generate SQL code for the parameter
1192
- sql = sql or dtype_to_sql(
1193
- atype, force_nullable=args_masks[i], **default_option,
1276
+ sql = pspec.sql_type or dtype_to_sql(
1277
+ pspec.dtype,
1278
+ force_nullable=args_masks[i] or pspec.is_optional,
1279
+ **default_option,
1194
1280
  )
1195
1281
 
1196
1282
  # Add parameter to args definitions
1197
- args.append(dict(name=name, dtype=atype, sql=sql, **default_option))
1283
+ args.append(
1284
+ dict(
1285
+ name=pspec.name,
1286
+ dtype=pspec.dtype,
1287
+ sql=sql,
1288
+ **default_option,
1289
+ transformer=pspec.transformer,
1290
+ ),
1291
+ )
1198
1292
 
1199
1293
  # Check that all the data formats are all the same
1200
1294
  if len(set(args_data_formats)) > 1:
@@ -1203,37 +1297,73 @@ def get_signature(
1203
1297
  f'{", ".join(args_data_formats)}',
1204
1298
  )
1205
1299
 
1206
- out['args_data_format'] = args_data_formats[0] if args_data_formats else 'scalar'
1300
+ adf = out['args_data_format'] = args_data_formats[0] \
1301
+ if args_data_formats else 'scalar'
1302
+
1303
+ returns_colspec = get_colspec(attrs.get('returns', []))
1207
1304
 
1208
1305
  # Generate the return types and the corresponding SQL code for those values
1209
1306
  ret_schema, out['returns_data_format'], function_type = get_schema(
1210
1307
  unpack_masked_type(signature.return_annotation),
1211
- overrides=attrs.get('returns', None),
1308
+ overrides=returns_colspec if returns_colspec else None,
1212
1309
  mode='return',
1213
1310
  )
1214
1311
 
1215
- out['returns_data_format'] = out['returns_data_format'] or 'scalar'
1312
+ rdf = out['returns_data_format'] = out['returns_data_format'] or 'scalar'
1216
1313
  out['function_type'] = function_type
1217
1314
 
1315
+ # Reality check the input and output data formats
1316
+ if function_type == 'udf':
1317
+ if (adf == 'scalar' and rdf != 'scalar') or \
1318
+ (adf != 'scalar' and rdf == 'scalar'):
1319
+ raise TypeError(
1320
+ 'Function can not have scalar arguments and a vector return type, '
1321
+ 'or vice versa. Parameters and return values must all be either ',
1322
+ 'scalar or vector types.',
1323
+ )
1324
+
1218
1325
  # All functions have to return a value, so if none was specified try to
1219
1326
  # insert a reasonable default that includes NULLs.
1220
1327
  if not ret_schema:
1221
- ret_schema = [('', 'int8?', 'TINYINT NULL')]
1328
+ ret_schema = [
1329
+ ParamSpec(
1330
+ dtype='int8?', sql_type='TINYINT NULL', default=None, is_optional=True,
1331
+ ),
1332
+ ]
1333
+
1334
+ if function_type == 'udf' and len(ret_schema) > 1:
1335
+ raise ValueError(
1336
+ 'UDFs can only return a single value; '
1337
+ f'got {len(ret_schema)} return values',
1338
+ )
1222
1339
 
1223
1340
  # Generate field names for the return values
1224
1341
  if function_type == 'tvf' or len(ret_schema) > 1:
1225
- for i, (name, rtype, sql) in enumerate(ret_schema):
1226
- if not name:
1227
- ret_schema[i] = (string.ascii_letters[i], rtype, sql)
1342
+ for i, rspec in enumerate(ret_schema):
1343
+ if not rspec.name:
1344
+ ret_schema[i] = ParamSpec(
1345
+ name=string.ascii_letters[i],
1346
+ dtype=rspec.dtype,
1347
+ sql_type=rspec.sql_type,
1348
+ transformer=rspec.transformer,
1349
+ )
1228
1350
 
1229
1351
  # Generate SQL code for the return values
1230
- for i, (name, rtype, sql) in enumerate(ret_schema):
1231
- sql = sql or dtype_to_sql(
1232
- rtype,
1233
- force_nullable=ret_masks[i] if ret_masks else False,
1352
+ for i, rspec in enumerate(ret_schema):
1353
+ sql = rspec.sql_type or dtype_to_sql(
1354
+ rspec.dtype,
1355
+ force_nullable=(ret_masks[i] or rspec.is_optional)
1356
+ if ret_masks else rspec.is_optional,
1234
1357
  function_type=function_type,
1235
1358
  )
1236
- returns.append(dict(name=name, dtype=rtype, sql=sql))
1359
+ returns.append(
1360
+ dict(
1361
+ name=rspec.name,
1362
+ dtype=rspec.dtype,
1363
+ sql=sql,
1364
+ transformer=rspec.transformer,
1365
+ ),
1366
+ )
1237
1367
 
1238
1368
  # Set the function endpoint
1239
1369
  out['endpoint'] = '/invoke'
@@ -1370,6 +1500,7 @@ def signature_to_sql(
1370
1500
  app_mode: str = 'remote',
1371
1501
  link: Optional[str] = None,
1372
1502
  replace: bool = False,
1503
+ database: Optional[str] = None,
1373
1504
  ) -> str:
1374
1505
  '''
1375
1506
  Convert a dictionary function signature into SQL.
@@ -1424,9 +1555,11 @@ def signature_to_sql(
1424
1555
  elif url is None:
1425
1556
  raise ValueError('url can not be `None`')
1426
1557
 
1427
- database = ''
1558
+ database_prefix = ''
1428
1559
  if signature.get('database'):
1429
- database = escape_name(signature['database']) + '.'
1560
+ database_prefix = escape_name(signature['database']) + '.'
1561
+ elif database is not None:
1562
+ database_prefix = escape_name(database) + '.'
1430
1563
 
1431
1564
  or_replace = 'OR REPLACE ' if (bool(signature.get('replace')) or replace) else ''
1432
1565
 
@@ -1438,7 +1571,7 @@ def signature_to_sql(
1438
1571
 
1439
1572
  return (
1440
1573
  f'CREATE {or_replace}EXTERNAL FUNCTION ' +
1441
- f'{database}{escape_name(signature["name"])}' +
1574
+ f'{database_prefix}{escape_name(signature["name"])}' +
1442
1575
  '(' + ', '.join(args) + ')' + returns +
1443
1576
  f' AS {app_mode.upper()} SERVICE "{url}" FORMAT {data_format.upper()}'
1444
1577
  f'{link_str};'