singlestoredb 0.3.3__py3-none-any.whl → 1.0.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of singlestoredb might be problematic. Click here for more details.

Files changed (121)
  1. singlestoredb/__init__.py +33 -2
  2. singlestoredb/alchemy/__init__.py +90 -0
  3. singlestoredb/auth.py +6 -4
  4. singlestoredb/config.py +116 -16
  5. singlestoredb/connection.py +489 -523
  6. singlestoredb/converters.py +275 -26
  7. singlestoredb/exceptions.py +30 -4
  8. singlestoredb/functions/__init__.py +1 -0
  9. singlestoredb/functions/decorator.py +142 -0
  10. singlestoredb/functions/dtypes.py +1639 -0
  11. singlestoredb/functions/ext/__init__.py +2 -0
  12. singlestoredb/functions/ext/arrow.py +375 -0
  13. singlestoredb/functions/ext/asgi.py +661 -0
  14. singlestoredb/functions/ext/json.py +427 -0
  15. singlestoredb/functions/ext/mmap.py +306 -0
  16. singlestoredb/functions/ext/rowdat_1.py +744 -0
  17. singlestoredb/functions/signature.py +673 -0
  18. singlestoredb/fusion/__init__.py +11 -0
  19. singlestoredb/fusion/graphql.py +213 -0
  20. singlestoredb/fusion/handler.py +621 -0
  21. singlestoredb/fusion/handlers/__init__.py +0 -0
  22. singlestoredb/fusion/handlers/stage.py +257 -0
  23. singlestoredb/fusion/handlers/utils.py +162 -0
  24. singlestoredb/fusion/handlers/workspace.py +412 -0
  25. singlestoredb/fusion/registry.py +164 -0
  26. singlestoredb/fusion/result.py +399 -0
  27. singlestoredb/http/__init__.py +27 -0
  28. singlestoredb/http/connection.py +1192 -0
  29. singlestoredb/management/__init__.py +3 -2
  30. singlestoredb/management/billing_usage.py +148 -0
  31. singlestoredb/management/cluster.py +19 -14
  32. singlestoredb/management/manager.py +100 -40
  33. singlestoredb/management/organization.py +188 -0
  34. singlestoredb/management/region.py +6 -8
  35. singlestoredb/management/utils.py +253 -4
  36. singlestoredb/management/workspace.py +1153 -35
  37. singlestoredb/mysql/__init__.py +177 -0
  38. singlestoredb/mysql/_auth.py +298 -0
  39. singlestoredb/mysql/charset.py +214 -0
  40. singlestoredb/mysql/connection.py +1814 -0
  41. singlestoredb/mysql/constants/CLIENT.py +38 -0
  42. singlestoredb/mysql/constants/COMMAND.py +32 -0
  43. singlestoredb/mysql/constants/CR.py +78 -0
  44. singlestoredb/mysql/constants/ER.py +474 -0
  45. singlestoredb/mysql/constants/FIELD_TYPE.py +32 -0
  46. singlestoredb/mysql/constants/FLAG.py +15 -0
  47. singlestoredb/mysql/constants/SERVER_STATUS.py +10 -0
  48. singlestoredb/mysql/constants/__init__.py +0 -0
  49. singlestoredb/mysql/converters.py +271 -0
  50. singlestoredb/mysql/cursors.py +713 -0
  51. singlestoredb/mysql/err.py +92 -0
  52. singlestoredb/mysql/optionfile.py +20 -0
  53. singlestoredb/mysql/protocol.py +388 -0
  54. singlestoredb/mysql/tests/__init__.py +19 -0
  55. singlestoredb/mysql/tests/base.py +126 -0
  56. singlestoredb/mysql/tests/conftest.py +37 -0
  57. singlestoredb/mysql/tests/test_DictCursor.py +132 -0
  58. singlestoredb/mysql/tests/test_SSCursor.py +141 -0
  59. singlestoredb/mysql/tests/test_basic.py +452 -0
  60. singlestoredb/mysql/tests/test_connection.py +851 -0
  61. singlestoredb/mysql/tests/test_converters.py +58 -0
  62. singlestoredb/mysql/tests/test_cursor.py +141 -0
  63. singlestoredb/mysql/tests/test_err.py +16 -0
  64. singlestoredb/mysql/tests/test_issues.py +514 -0
  65. singlestoredb/mysql/tests/test_load_local.py +75 -0
  66. singlestoredb/mysql/tests/test_nextset.py +88 -0
  67. singlestoredb/mysql/tests/test_optionfile.py +27 -0
  68. singlestoredb/mysql/tests/thirdparty/__init__.py +6 -0
  69. singlestoredb/mysql/tests/thirdparty/test_MySQLdb/__init__.py +9 -0
  70. singlestoredb/mysql/tests/thirdparty/test_MySQLdb/capabilities.py +323 -0
  71. singlestoredb/mysql/tests/thirdparty/test_MySQLdb/dbapi20.py +865 -0
  72. singlestoredb/mysql/tests/thirdparty/test_MySQLdb/test_MySQLdb_capabilities.py +110 -0
  73. singlestoredb/mysql/tests/thirdparty/test_MySQLdb/test_MySQLdb_dbapi20.py +224 -0
  74. singlestoredb/mysql/tests/thirdparty/test_MySQLdb/test_MySQLdb_nonstandard.py +101 -0
  75. singlestoredb/mysql/times.py +23 -0
  76. singlestoredb/pytest.py +283 -0
  77. singlestoredb/tests/empty.sql +0 -0
  78. singlestoredb/tests/ext_funcs/__init__.py +385 -0
  79. singlestoredb/tests/test.sql +210 -0
  80. singlestoredb/tests/test2.sql +1 -0
  81. singlestoredb/tests/test_basics.py +482 -117
  82. singlestoredb/tests/test_config.py +13 -15
  83. singlestoredb/tests/test_connection.py +241 -289
  84. singlestoredb/tests/test_dbapi.py +27 -0
  85. singlestoredb/tests/test_exceptions.py +0 -2
  86. singlestoredb/tests/test_ext_func.py +1193 -0
  87. singlestoredb/tests/test_ext_func_data.py +1101 -0
  88. singlestoredb/tests/test_fusion.py +465 -0
  89. singlestoredb/tests/test_http.py +32 -28
  90. singlestoredb/tests/test_management.py +588 -10
  91. singlestoredb/tests/test_plugin.py +33 -0
  92. singlestoredb/tests/test_results.py +11 -14
  93. singlestoredb/tests/test_types.py +0 -2
  94. singlestoredb/tests/test_udf.py +687 -0
  95. singlestoredb/tests/test_xdict.py +0 -2
  96. singlestoredb/tests/utils.py +3 -4
  97. singlestoredb/types.py +4 -5
  98. singlestoredb/utils/config.py +71 -12
  99. singlestoredb/utils/convert_rows.py +0 -2
  100. singlestoredb/utils/debug.py +13 -0
  101. singlestoredb/utils/mogrify.py +151 -0
  102. singlestoredb/utils/results.py +4 -3
  103. singlestoredb/utils/xdict.py +12 -12
  104. singlestoredb-1.0.3.dist-info/METADATA +139 -0
  105. singlestoredb-1.0.3.dist-info/RECORD +112 -0
  106. {singlestoredb-0.3.3.dist-info → singlestoredb-1.0.3.dist-info}/WHEEL +1 -1
  107. singlestoredb-1.0.3.dist-info/entry_points.txt +2 -0
  108. singlestoredb/drivers/__init__.py +0 -46
  109. singlestoredb/drivers/base.py +0 -200
  110. singlestoredb/drivers/cymysql.py +0 -40
  111. singlestoredb/drivers/http.py +0 -49
  112. singlestoredb/drivers/mariadb.py +0 -42
  113. singlestoredb/drivers/mysqlconnector.py +0 -51
  114. singlestoredb/drivers/mysqldb.py +0 -62
  115. singlestoredb/drivers/pymysql.py +0 -39
  116. singlestoredb/drivers/pyodbc.py +0 -67
  117. singlestoredb/http.py +0 -794
  118. singlestoredb-0.3.3.dist-info/METADATA +0 -105
  119. singlestoredb-0.3.3.dist-info/RECORD +0 -46
  120. {singlestoredb-0.3.3.dist-info → singlestoredb-1.0.3.dist-info}/LICENSE +0 -0
  121. {singlestoredb-0.3.3.dist-info → singlestoredb-1.0.3.dist-info}/top_level.txt +0 -0
@@ -1,8 +1,7 @@
1
1
  #!/usr/bin/env python
2
2
  """Data value conversion utilities."""
3
- from __future__ import annotations
4
-
5
3
  import datetime
4
+ import re
6
5
  from base64 import b64decode
7
6
  from decimal import Decimal
8
7
  from json import loads as json_loads
@@ -14,13 +13,246 @@ from typing import Optional
14
13
  from typing import Set
15
14
  from typing import Union
16
15
 
16
+ try:
17
+ import shapely.wkt
18
+ has_shapely = True
19
+ except ImportError:
20
+ has_shapely = False
21
+
22
+ try:
23
+ import pygeos
24
+ has_pygeos = True
25
+ except ImportError:
26
+ has_pygeos = False
27
+
28
+
29
+ # Cache fromisoformat methods if they exist
30
+ _dt_datetime_fromisoformat = None
31
+ if hasattr(datetime.datetime, 'fromisoformat'):
32
+ _dt_datetime_fromisoformat = datetime.datetime.fromisoformat # type: ignore
33
+ _dt_time_fromisoformat = None
34
+ if hasattr(datetime.time, 'fromisoformat'):
35
+ _dt_time_fromisoformat = datetime.time.fromisoformat # type: ignore
36
+ _dt_date_fromisoformat = None
37
+ if hasattr(datetime.date, 'fromisoformat'):
38
+ _dt_date_fromisoformat = datetime.date.fromisoformat # type: ignore
39
+
40
+
41
+ def _convert_second_fraction(s: str) -> int:
42
+ if not s:
43
+ return 0
44
+ # Pad zeros to ensure the fraction length in microseconds
45
+ s = s.ljust(6, '0')
46
+ return int(s[:6])
47
+
48
+
49
+ DATETIME_RE = re.compile(
50
+ r'(\d{1,4})-(\d{1,2})-(\d{1,2})[T ](\d{1,2}):(\d{1,2}):(\d{1,2})(?:.(\d{1,6}))?',
51
+ )
52
+
53
+ ZERO_DATETIMES = set([
54
+ '0000-00-00 00:00:00',
55
+ '0000-00-00 00:00:00.000',
56
+ '0000-00-00 00:00:00.000000',
57
+ '0000-00-00T00:00:00',
58
+ '0000-00-00T00:00:00.000',
59
+ '0000-00-00T00:00:00.000000',
60
+ ])
61
+ ZERO_DATES = set([
62
+ '0000-00-00',
63
+ ])
64
+
65
+
66
+ def datetime_fromisoformat(
67
+ obj: Union[str, bytes, bytearray],
68
+ ) -> Union[datetime.datetime, str, None]:
69
+ """Returns a DATETIME or TIMESTAMP column value as a datetime object:
70
+
71
+ >>> datetime_fromisoformat('2007-02-25 23:06:20')
72
+ datetime.datetime(2007, 2, 25, 23, 6, 20)
73
+ >>> datetime_fromisoformat('2007-02-25T23:06:20')
74
+ datetime.datetime(2007, 2, 25, 23, 6, 20)
75
+
76
+ Illegal values are returned as str or None:
77
+
78
+ >>> datetime_fromisoformat('2007-02-31T23:06:20')
79
+ '2007-02-31T23:06:20'
80
+ >>> datetime_fromisoformat('0000-00-00 00:00:00')
81
+ None
82
+
83
+ """
84
+ if isinstance(obj, (bytes, bytearray)):
85
+ obj = obj.decode('ascii')
86
+
87
+ if obj in ZERO_DATETIMES:
88
+ return None
89
+
90
+ # Use datetime methods if possible
91
+ if _dt_datetime_fromisoformat is not None:
92
+ try:
93
+ if ' ' in obj or 'T' in obj:
94
+ return _dt_datetime_fromisoformat(obj)
95
+ if _dt_date_fromisoformat is not None:
96
+ date = _dt_date_fromisoformat(obj)
97
+ return datetime.datetime(date.year, date.month, date.day)
98
+ except ValueError:
99
+ return obj
100
+
101
+ m = DATETIME_RE.match(obj)
102
+ if not m:
103
+ mdate = date_fromisoformat(obj)
104
+ if type(mdate) is str:
105
+ return mdate
106
+ return datetime.datetime(mdate.year, mdate.month, mdate.day) # type: ignore
107
+
108
+ try:
109
+ groups = list(m.groups())
110
+ groups[-1] = _convert_second_fraction(groups[-1])
111
+ return datetime.datetime(*[int(x) for x in groups]) # type: ignore
112
+ except ValueError:
113
+ mdate = date_fromisoformat(obj)
114
+ if type(mdate) is str:
115
+ return mdate
116
+ return datetime.datetime(mdate.year, mdate.month, mdate.day) # type: ignore
117
+
118
+
119
+ TIMEDELTA_RE = re.compile(r'(-)?(\d{1,3}):(\d{1,2}):(\d{1,2})(?:.(\d{1,6}))?')
120
+
121
+
122
+ def timedelta_fromisoformat(
123
+ obj: Union[str, bytes, bytearray],
124
+ ) -> Union[datetime.timedelta, str, None]:
125
+ """Returns a TIME column as a timedelta object:
126
+
127
+ >>> timedelta_fromisoformat('25:06:17')
128
+ datetime.timedelta(days=1, seconds=3977)
129
+ >>> timedelta_fromisoformat('-25:06:17')
130
+ datetime.timedelta(days=-2, seconds=82423)
131
+
132
+ Illegal values are returned as string:
133
+
134
+ >>> timedelta_fromisoformat('random crap')
135
+ 'random crap'
136
+
137
+ Note that MySQL always returns TIME columns as (+|-)HH:MM:SS, but
138
+ can accept values as (+|-)DD HH:MM:SS. The latter format will not
139
+ be parsed correctly by this function.
140
+ """
141
+ if isinstance(obj, (bytes, bytearray)):
142
+ obj = obj.decode('ascii')
143
+
144
+ m = TIMEDELTA_RE.match(obj)
145
+ if not m:
146
+ return obj
147
+
148
+ try:
149
+ groups = list(m.groups())
150
+ groups[-1] = _convert_second_fraction(groups[-1])
151
+ negate = -1 if groups[0] else 1
152
+ hours, minutes, seconds, microseconds = groups[1:]
153
+
154
+ tdelta = (
155
+ datetime.timedelta(
156
+ hours=int(hours),
157
+ minutes=int(minutes),
158
+ seconds=int(seconds),
159
+ microseconds=int(microseconds),
160
+ )
161
+ * negate
162
+ )
163
+ return tdelta
164
+ except ValueError:
165
+ return obj
166
+
167
+
168
+ TIME_RE = re.compile(r'(\d{1,2}):(\d{1,2}):(\d{1,2})(?:.(\d{1,6}))?')
169
+
170
+
171
+ def time_fromisoformat(
172
+ obj: Union[str, bytes, bytearray],
173
+ ) -> Union[datetime.time, str, None]:
174
+ """Returns a TIME column as a time object:
175
+
176
+ >>> time_fromisoformat('15:06:17')
177
+ datetime.time(15, 6, 17)
178
+
179
+ Illegal values are returned as str:
180
+
181
+ >>> time_fromisoformat('-25:06:17')
182
+ '-25:06:17'
183
+ >>> time_fromisoformat('random crap')
184
+ 'random crap'
185
+
186
+ Note that MySQL always returns TIME columns as (+|-)HH:MM:SS, but
187
+ can accept values as (+|-)DD HH:MM:SS. The latter format will not
188
+ be parsed correctly by this function.
189
+
190
+ Also note that MySQL's TIME column corresponds more closely to
191
+ Python's timedelta and not time. However if you want TIME columns
192
+ to be treated as time-of-day and not a time offset, then you can
193
+ use set this function as the converter for FIELD_TYPE.TIME.
194
+ """
195
+ if isinstance(obj, (bytes, bytearray)):
196
+ obj = obj.decode('ascii')
197
+
198
+ # Use datetime methods if possible
199
+ if _dt_time_fromisoformat is not None:
200
+ try:
201
+ return _dt_time_fromisoformat(obj)
202
+ except ValueError:
203
+ return obj
204
+
205
+ m = TIME_RE.match(obj)
206
+ if not m:
207
+ return obj
208
+
209
+ try:
210
+ groups = list(m.groups())
211
+ groups[-1] = _convert_second_fraction(groups[-1])
212
+ hours, minutes, seconds, microseconds = groups
213
+ return datetime.time(
214
+ hour=int(hours),
215
+ minute=int(minutes),
216
+ second=int(seconds),
217
+ microsecond=int(microseconds),
218
+ )
219
+ except ValueError:
220
+ return obj
221
+
222
+
223
+ def date_fromisoformat(
224
+ obj: Union[str, bytes, bytearray],
225
+ ) -> Union[datetime.date, str, None]:
226
+ """Returns a DATE column as a date object:
17
227
 
18
- datetime_fromisoformat = datetime.datetime.fromisoformat
19
- time_fromisoformat = datetime.time.fromisoformat
20
- date_fromisoformat = datetime.date.fromisoformat
21
- datetime_min = datetime.datetime.min
22
- date_min = datetime.date.min
23
- datetime_combine = datetime.datetime.combine
228
+ >>> date_fromisoformat('2007-02-26')
229
+ datetime.date(2007, 2, 26)
230
+
231
+ Illegal values are returned as str or None:
232
+
233
+ >>> date_fromisoformat('2007-02-31')
234
+ '2007-02-31'
235
+ >>> date_fromisoformat('0000-00-00')
236
+ None
237
+
238
+ """
239
+ if isinstance(obj, (bytes, bytearray)):
240
+ obj = obj.decode('ascii')
241
+
242
+ if obj in ZERO_DATES:
243
+ return None
244
+
245
+ # Use datetime methods if possible
246
+ if _dt_date_fromisoformat is not None:
247
+ try:
248
+ return _dt_date_fromisoformat(obj)
249
+ except ValueError:
250
+ return obj
251
+
252
+ try:
253
+ return datetime.date(*[int(x) for x in obj.split('-', 2)])
254
+ except ValueError:
255
+ return obj
24
256
 
25
257
 
26
258
  def identity(x: Any) -> Optional[Any]:
@@ -118,7 +350,7 @@ def decimal_or_none(x: Any) -> Optional[Decimal]:
118
350
  return Decimal(x)
119
351
 
120
352
 
121
- def date_or_none(x: Optional[str]) -> Optional[datetime.date]:
353
+ def date_or_none(x: Optional[str]) -> Optional[Union[datetime.date, str]]:
122
354
  """
123
355
  Convert value to a date.
124
356
 
@@ -137,13 +369,10 @@ def date_or_none(x: Optional[str]) -> Optional[datetime.date]:
137
369
  """
138
370
  if x is None:
139
371
  return None
140
- try:
141
- return date_fromisoformat(x)
142
- except ValueError:
143
- return None
372
+ return date_fromisoformat(x)
144
373
 
145
374
 
146
- def time_or_none(x: Optional[str]) -> Optional[datetime.timedelta]:
375
+ def timedelta_or_none(x: Optional[str]) -> Optional[Union[datetime.timedelta, str]]:
147
376
  """
148
377
  Convert value to a timedelta.
149
378
 
@@ -162,13 +391,32 @@ def time_or_none(x: Optional[str]) -> Optional[datetime.timedelta]:
162
391
  """
163
392
  if x is None:
164
393
  return None
165
- try:
166
- return datetime_combine(date_min, time_fromisoformat(x)) - datetime_min
167
- except ValueError:
394
+ return timedelta_fromisoformat(x)
395
+
396
+
397
+ def time_or_none(x: Optional[str]) -> Optional[Union[datetime.time, str]]:
398
+ """
399
+ Convert value to a time.
400
+
401
+ Parameters
402
+ ----------
403
+ x : Any
404
+ Arbitrary value
405
+
406
+ Returns
407
+ -------
408
+ datetime.time
409
+ If value can be cast to a time
410
+ None
411
+ If input value is None
412
+
413
+ """
414
+ if x is None:
168
415
  return None
416
+ return time_fromisoformat(x)
169
417
 
170
418
 
171
- def datetime_or_none(x: Optional[str]) -> Optional[datetime.datetime]:
419
+ def datetime_or_none(x: Optional[str]) -> Optional[Union[datetime.datetime, str]]:
172
420
  """
173
421
  Convert value to a datetime.
174
422
 
@@ -187,10 +435,7 @@ def datetime_or_none(x: Optional[str]) -> Optional[datetime.datetime]:
187
435
  """
188
436
  if x is None:
189
437
  return None
190
- try:
191
- return datetime_fromisoformat(x)
192
- except ValueError:
193
- return None
438
+ return datetime_fromisoformat(x)
194
439
 
195
440
 
196
441
  def none(x: Any) -> None:
@@ -267,14 +512,18 @@ def geometry_or_none(x: Optional[str]) -> Optional[Any]:
267
512
 
268
513
  Returns
269
514
  -------
270
- ???
515
+ shapely object or pygeos object or str
271
516
  If value is valid geometry value
272
517
  None
273
- If input value is None
518
+ If input value is None or empty
274
519
 
275
520
  """
276
- if x is None:
521
+ if x is None or not x:
277
522
  return None
523
+ if has_shapely:
524
+ return shapely.wkt.loads(x)
525
+ if has_pygeos:
526
+ return pygeos.io.from_wkt(x)
278
527
  return x
279
528
 
280
529
 
@@ -291,7 +540,7 @@ converters: Dict[int, Callable[..., Any]] = {
291
540
  8: int_or_none,
292
541
  9: int_or_none,
293
542
  10: date_or_none,
294
- 11: time_or_none,
543
+ 11: timedelta_or_none,
295
544
  12: datetime_or_none,
296
545
  13: int_or_none,
297
546
  14: date_or_none,
@@ -1,11 +1,13 @@
1
1
  #!/usr/bin/env python
2
2
  """Database exeception classes."""
3
- from __future__ import annotations
4
-
5
3
  from typing import Optional
6
4
 
7
5
 
8
- class Error(Exception):
6
+ class MySQLError(Exception):
7
+ """All MySQL-related exceptions."""
8
+
9
+
10
+ class Error(MySQLError):
9
11
  """
10
12
  Generic database exception.
11
13
 
@@ -54,7 +56,7 @@ class Error(Exception):
54
56
  return self.errmsg
55
57
 
56
58
 
57
- class Warning(Exception):
59
+ class Warning(Warning, MySQLError): # type: ignore
58
60
  """Exception for important warnings like data truncations, etc."""
59
61
 
60
62
 
@@ -92,3 +94,27 @@ class NotSupportedError(DatabaseError):
92
94
 
93
95
  class ManagementError(Error):
94
96
  """Exception for errors in the management API."""
97
+
98
+ def __init__(
99
+ self, errno: Optional[int] = None, msg: Optional[str] = None,
100
+ response: Optional[str] = None,
101
+ ):
102
+ self.errno = errno
103
+ self.errmsg = msg
104
+ self.response = response
105
+ super(Exception, self).__init__(errno, msg)
106
+
107
+ def __str__(self) -> str:
108
+ """Return string representation."""
109
+ prefix = []
110
+ if self.errno is not None:
111
+ prefix.append(f'{self.errno}')
112
+ if self.response is not None:
113
+ prefix.append(f'({self.response})')
114
+ if prefix and self.errmsg:
115
+ return ' '.join(prefix) + ': ' + self.errmsg
116
+ elif prefix:
117
+ return ' '.join(prefix)
118
+ elif self.errmsg:
119
+ return f'{self.errmsg}'
120
+ return 'Unknown error'
@@ -0,0 +1 @@
1
+ from .decorator import udf # noqa: F401
@@ -0,0 +1,142 @@
1
+ import functools
2
+ from typing import Any
3
+ from typing import Callable
4
+ from typing import Dict
5
+ from typing import List
6
+ from typing import Optional
7
+ from typing import Union
8
+
9
+ from .dtypes import DataType
10
+
11
+
12
+ def listify(x: Any) -> List[Any]:
13
+ """Make sure sure value is a list."""
14
+ if x is None:
15
+ return []
16
+ if isinstance(x, (list, tuple, set)):
17
+ return list(x)
18
+ return [x]
19
+
20
+
21
+ def udf(
22
+ func: Optional[Callable[..., Any]] = None,
23
+ *,
24
+ name: Optional[str] = None,
25
+ args: Optional[Union[DataType, List[DataType], Dict[str, DataType]]] = None,
26
+ returns: Optional[str] = None,
27
+ data_format: Optional[str] = None,
28
+ include_masks: bool = False,
29
+ ) -> Callable[..., Any]:
30
+ """
31
+ Apply attributes to a UDF.
32
+
33
+ Parameters
34
+ ----------
35
+ func : callable, optional
36
+ The UDF to apply parameters to
37
+ name : str, optional
38
+ The name to use for the UDF in the database
39
+ args : str | Callable | List[str | Callable] | Dict[str, str | Callable], optional
40
+ Specifies the data types of the function arguments. Typically,
41
+ the function data types are derived from the function parameter
42
+ annotations. These annotations can be overridden. If the function
43
+ takes a single type for all parameters, `args` can be set to a
44
+ SQL string describing all parameters. If the function takes more
45
+ than one parameter and all of the parameters are being manually
46
+ defined, a list of SQL strings may be used (one for each parameter).
47
+ A dictionary of SQL strings may be used to specify a parameter type
48
+ for a subset of parameters; the keys are the names of the
49
+ function parameters. Callables may also be used for datatypes. This
50
+ is primarily for using the functions in the ``dtypes`` module that
51
+ are associated with SQL types with all default options (e.g., ``dt.FLOAT``).
52
+ returns : str, optional
53
+ Specifies the return data type of the function. If not specified,
54
+ the type annotation from the function is used.
55
+ data_format : str, optional
56
+ The data format of each parameter: python, pandas, arrow, polars
57
+ include_masks : bool, optional
58
+ Should boolean masks be included with each input parameter to indicate
59
+ which elements are NULL? This is only used when a input parameters are
60
+ configured to a vector type (numpy, pandas, polars, arrow).
61
+
62
+ Returns
63
+ -------
64
+ Callable
65
+
66
+ """
67
+ if args is None:
68
+ pass
69
+ elif isinstance(args, (list, tuple)):
70
+ args = list(args)
71
+ for i, item in enumerate(args):
72
+ if callable(item):
73
+ args[i] = item()
74
+ for item in args:
75
+ if not isinstance(item, str):
76
+ raise TypeError(f'unrecognized type for parameter: {item}')
77
+ elif isinstance(args, dict):
78
+ args = dict(args)
79
+ for k, v in list(args.items()):
80
+ if callable(v):
81
+ args[k] = v()
82
+ for item in args.values():
83
+ if not isinstance(item, str):
84
+ raise TypeError(f'unrecognized type for parameter: {item}')
85
+ elif callable(args):
86
+ args = args()
87
+ elif isinstance(args, str):
88
+ args = args
89
+ else:
90
+ raise TypeError(f'unrecognized data type for args: {args}')
91
+
92
+ if returns is None:
93
+ pass
94
+ elif callable(returns):
95
+ returns = returns()
96
+ elif isinstance(returns, str):
97
+ returns = returns
98
+ else:
99
+ raise TypeError(f'unrecognized return type: {returns}')
100
+
101
+ if returns is not None and not isinstance(returns, str):
102
+ raise TypeError(f'unrecognized return type: {returns}')
103
+
104
+ if include_masks and data_format == 'python':
105
+ raise RuntimeError(
106
+ 'include_masks is only valid when using '
107
+ 'vectors for input parameters',
108
+ )
109
+
110
+ _singlestoredb_attrs = { # type: ignore
111
+ k: v for k, v in dict(
112
+ name=name,
113
+ args=args,
114
+ returns=returns,
115
+ data_format=data_format,
116
+ include_masks=include_masks,
117
+ ).items() if v is not None
118
+ }
119
+
120
+ # No func was specified, this is an uncalled decorator that will get
121
+ # called later, so the wrapper much be created with the func passed
122
+ # in at that time.
123
+ if func is None:
124
+ def decorate(func: Callable[..., Any]) -> Callable[..., Any]:
125
+ def wrapper(*args: Any, **kwargs: Any) -> Callable[..., Any]:
126
+ return func(*args, **kwargs) # type: ignore
127
+ wrapper._singlestoredb_attrs = _singlestoredb_attrs # type: ignore
128
+ return functools.wraps(func)(wrapper)
129
+ return decorate
130
+
131
+ def wrapper(*args: Any, **kwargs: Any) -> Callable[..., Any]:
132
+ return func(*args, **kwargs) # type: ignore
133
+
134
+ wrapper._singlestoredb_attrs = _singlestoredb_attrs # type: ignore
135
+
136
+ return functools.wraps(func)(wrapper)
137
+
138
+
139
+ udf.pandas = functools.partial(udf, data_format='pandas') # type: ignore
140
+ udf.polars = functools.partial(udf, data_format='polars') # type: ignore
141
+ udf.arrow = functools.partial(udf, data_format='arrow') # type: ignore
142
+ udf.numpy = functools.partial(udf, data_format='numpy') # type: ignore