singlestoredb-1.16.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (183)
  1. singlestoredb/__init__.py +75 -0
  2. singlestoredb/ai/__init__.py +2 -0
  3. singlestoredb/ai/chat.py +139 -0
  4. singlestoredb/ai/embeddings.py +128 -0
  5. singlestoredb/alchemy/__init__.py +90 -0
  6. singlestoredb/apps/__init__.py +3 -0
  7. singlestoredb/apps/_cloud_functions.py +90 -0
  8. singlestoredb/apps/_config.py +72 -0
  9. singlestoredb/apps/_connection_info.py +18 -0
  10. singlestoredb/apps/_dashboards.py +47 -0
  11. singlestoredb/apps/_process.py +32 -0
  12. singlestoredb/apps/_python_udfs.py +100 -0
  13. singlestoredb/apps/_stdout_supress.py +30 -0
  14. singlestoredb/apps/_uvicorn_util.py +36 -0
  15. singlestoredb/auth.py +245 -0
  16. singlestoredb/config.py +484 -0
  17. singlestoredb/connection.py +1487 -0
  18. singlestoredb/converters.py +950 -0
  19. singlestoredb/docstring/__init__.py +33 -0
  20. singlestoredb/docstring/attrdoc.py +126 -0
  21. singlestoredb/docstring/common.py +230 -0
  22. singlestoredb/docstring/epydoc.py +267 -0
  23. singlestoredb/docstring/google.py +412 -0
  24. singlestoredb/docstring/numpydoc.py +562 -0
  25. singlestoredb/docstring/parser.py +100 -0
  26. singlestoredb/docstring/py.typed +1 -0
  27. singlestoredb/docstring/rest.py +256 -0
  28. singlestoredb/docstring/tests/__init__.py +1 -0
  29. singlestoredb/docstring/tests/_pydoctor.py +21 -0
  30. singlestoredb/docstring/tests/test_epydoc.py +729 -0
  31. singlestoredb/docstring/tests/test_google.py +1007 -0
  32. singlestoredb/docstring/tests/test_numpydoc.py +1100 -0
  33. singlestoredb/docstring/tests/test_parse_from_object.py +109 -0
  34. singlestoredb/docstring/tests/test_parser.py +248 -0
  35. singlestoredb/docstring/tests/test_rest.py +547 -0
  36. singlestoredb/docstring/tests/test_util.py +70 -0
  37. singlestoredb/docstring/util.py +141 -0
  38. singlestoredb/exceptions.py +120 -0
  39. singlestoredb/functions/__init__.py +16 -0
  40. singlestoredb/functions/decorator.py +201 -0
  41. singlestoredb/functions/dtypes.py +1793 -0
  42. singlestoredb/functions/ext/__init__.py +1 -0
  43. singlestoredb/functions/ext/arrow.py +375 -0
  44. singlestoredb/functions/ext/asgi.py +2133 -0
  45. singlestoredb/functions/ext/json.py +420 -0
  46. singlestoredb/functions/ext/mmap.py +413 -0
  47. singlestoredb/functions/ext/rowdat_1.py +724 -0
  48. singlestoredb/functions/ext/timer.py +89 -0
  49. singlestoredb/functions/ext/utils.py +218 -0
  50. singlestoredb/functions/signature.py +1578 -0
  51. singlestoredb/functions/typing/__init__.py +41 -0
  52. singlestoredb/functions/typing/numpy.py +20 -0
  53. singlestoredb/functions/typing/pandas.py +2 -0
  54. singlestoredb/functions/typing/polars.py +2 -0
  55. singlestoredb/functions/typing/pyarrow.py +2 -0
  56. singlestoredb/functions/utils.py +421 -0
  57. singlestoredb/fusion/__init__.py +11 -0
  58. singlestoredb/fusion/graphql.py +213 -0
  59. singlestoredb/fusion/handler.py +916 -0
  60. singlestoredb/fusion/handlers/__init__.py +0 -0
  61. singlestoredb/fusion/handlers/export.py +525 -0
  62. singlestoredb/fusion/handlers/files.py +690 -0
  63. singlestoredb/fusion/handlers/job.py +660 -0
  64. singlestoredb/fusion/handlers/models.py +250 -0
  65. singlestoredb/fusion/handlers/stage.py +502 -0
  66. singlestoredb/fusion/handlers/utils.py +324 -0
  67. singlestoredb/fusion/handlers/workspace.py +956 -0
  68. singlestoredb/fusion/registry.py +249 -0
  69. singlestoredb/fusion/result.py +399 -0
  70. singlestoredb/http/__init__.py +27 -0
  71. singlestoredb/http/connection.py +1267 -0
  72. singlestoredb/magics/__init__.py +34 -0
  73. singlestoredb/magics/run_personal.py +137 -0
  74. singlestoredb/magics/run_shared.py +134 -0
  75. singlestoredb/management/__init__.py +9 -0
  76. singlestoredb/management/billing_usage.py +148 -0
  77. singlestoredb/management/cluster.py +462 -0
  78. singlestoredb/management/export.py +295 -0
  79. singlestoredb/management/files.py +1102 -0
  80. singlestoredb/management/inference_api.py +105 -0
  81. singlestoredb/management/job.py +887 -0
  82. singlestoredb/management/manager.py +373 -0
  83. singlestoredb/management/organization.py +226 -0
  84. singlestoredb/management/region.py +169 -0
  85. singlestoredb/management/utils.py +423 -0
  86. singlestoredb/management/workspace.py +1927 -0
  87. singlestoredb/mysql/__init__.py +177 -0
  88. singlestoredb/mysql/_auth.py +298 -0
  89. singlestoredb/mysql/charset.py +214 -0
  90. singlestoredb/mysql/connection.py +2032 -0
  91. singlestoredb/mysql/constants/CLIENT.py +38 -0
  92. singlestoredb/mysql/constants/COMMAND.py +32 -0
  93. singlestoredb/mysql/constants/CR.py +78 -0
  94. singlestoredb/mysql/constants/ER.py +474 -0
  95. singlestoredb/mysql/constants/EXTENDED_TYPE.py +3 -0
  96. singlestoredb/mysql/constants/FIELD_TYPE.py +48 -0
  97. singlestoredb/mysql/constants/FLAG.py +15 -0
  98. singlestoredb/mysql/constants/SERVER_STATUS.py +10 -0
  99. singlestoredb/mysql/constants/VECTOR_TYPE.py +6 -0
  100. singlestoredb/mysql/constants/__init__.py +0 -0
  101. singlestoredb/mysql/converters.py +271 -0
  102. singlestoredb/mysql/cursors.py +896 -0
  103. singlestoredb/mysql/err.py +92 -0
  104. singlestoredb/mysql/optionfile.py +20 -0
  105. singlestoredb/mysql/protocol.py +450 -0
  106. singlestoredb/mysql/tests/__init__.py +19 -0
  107. singlestoredb/mysql/tests/base.py +126 -0
  108. singlestoredb/mysql/tests/conftest.py +37 -0
  109. singlestoredb/mysql/tests/test_DictCursor.py +132 -0
  110. singlestoredb/mysql/tests/test_SSCursor.py +141 -0
  111. singlestoredb/mysql/tests/test_basic.py +452 -0
  112. singlestoredb/mysql/tests/test_connection.py +851 -0
  113. singlestoredb/mysql/tests/test_converters.py +58 -0
  114. singlestoredb/mysql/tests/test_cursor.py +141 -0
  115. singlestoredb/mysql/tests/test_err.py +16 -0
  116. singlestoredb/mysql/tests/test_issues.py +514 -0
  117. singlestoredb/mysql/tests/test_load_local.py +75 -0
  118. singlestoredb/mysql/tests/test_nextset.py +88 -0
  119. singlestoredb/mysql/tests/test_optionfile.py +27 -0
  120. singlestoredb/mysql/tests/thirdparty/__init__.py +6 -0
  121. singlestoredb/mysql/tests/thirdparty/test_MySQLdb/__init__.py +9 -0
  122. singlestoredb/mysql/tests/thirdparty/test_MySQLdb/capabilities.py +323 -0
  123. singlestoredb/mysql/tests/thirdparty/test_MySQLdb/dbapi20.py +865 -0
  124. singlestoredb/mysql/tests/thirdparty/test_MySQLdb/test_MySQLdb_capabilities.py +110 -0
  125. singlestoredb/mysql/tests/thirdparty/test_MySQLdb/test_MySQLdb_dbapi20.py +224 -0
  126. singlestoredb/mysql/tests/thirdparty/test_MySQLdb/test_MySQLdb_nonstandard.py +101 -0
  127. singlestoredb/mysql/times.py +23 -0
  128. singlestoredb/notebook/__init__.py +16 -0
  129. singlestoredb/notebook/_objects.py +213 -0
  130. singlestoredb/notebook/_portal.py +352 -0
  131. singlestoredb/py.typed +0 -0
  132. singlestoredb/pytest.py +352 -0
  133. singlestoredb/server/__init__.py +0 -0
  134. singlestoredb/server/docker.py +452 -0
  135. singlestoredb/server/free_tier.py +267 -0
  136. singlestoredb/tests/__init__.py +0 -0
  137. singlestoredb/tests/alltypes.sql +307 -0
  138. singlestoredb/tests/alltypes_no_nulls.sql +208 -0
  139. singlestoredb/tests/empty.sql +0 -0
  140. singlestoredb/tests/ext_funcs/__init__.py +702 -0
  141. singlestoredb/tests/local_infile.csv +3 -0
  142. singlestoredb/tests/test.ipynb +18 -0
  143. singlestoredb/tests/test.sql +680 -0
  144. singlestoredb/tests/test2.ipynb +18 -0
  145. singlestoredb/tests/test2.sql +1 -0
  146. singlestoredb/tests/test_basics.py +1332 -0
  147. singlestoredb/tests/test_config.py +318 -0
  148. singlestoredb/tests/test_connection.py +3103 -0
  149. singlestoredb/tests/test_dbapi.py +27 -0
  150. singlestoredb/tests/test_exceptions.py +45 -0
  151. singlestoredb/tests/test_ext_func.py +1472 -0
  152. singlestoredb/tests/test_ext_func_data.py +1101 -0
  153. singlestoredb/tests/test_fusion.py +1527 -0
  154. singlestoredb/tests/test_http.py +288 -0
  155. singlestoredb/tests/test_management.py +1599 -0
  156. singlestoredb/tests/test_plugin.py +33 -0
  157. singlestoredb/tests/test_results.py +171 -0
  158. singlestoredb/tests/test_types.py +132 -0
  159. singlestoredb/tests/test_udf.py +737 -0
  160. singlestoredb/tests/test_udf_returns.py +459 -0
  161. singlestoredb/tests/test_vectorstore.py +51 -0
  162. singlestoredb/tests/test_xdict.py +333 -0
  163. singlestoredb/tests/utils.py +141 -0
  164. singlestoredb/types.py +373 -0
  165. singlestoredb/utils/__init__.py +0 -0
  166. singlestoredb/utils/config.py +950 -0
  167. singlestoredb/utils/convert_rows.py +69 -0
  168. singlestoredb/utils/debug.py +13 -0
  169. singlestoredb/utils/dtypes.py +205 -0
  170. singlestoredb/utils/events.py +65 -0
  171. singlestoredb/utils/mogrify.py +151 -0
  172. singlestoredb/utils/results.py +585 -0
  173. singlestoredb/utils/xdict.py +425 -0
  174. singlestoredb/vectorstore.py +192 -0
  175. singlestoredb/warnings.py +5 -0
  176. singlestoredb-1.16.1.dist-info/METADATA +165 -0
  177. singlestoredb-1.16.1.dist-info/RECORD +183 -0
  178. singlestoredb-1.16.1.dist-info/WHEEL +5 -0
  179. singlestoredb-1.16.1.dist-info/entry_points.txt +2 -0
  180. singlestoredb-1.16.1.dist-info/licenses/LICENSE +201 -0
  181. singlestoredb-1.16.1.dist-info/top_level.txt +3 -0
  182. sqlx/__init__.py +4 -0
  183. sqlx/magic.py +113 -0
singlestoredb/functions/ext/asgi.py @@ -0,0 +1,2133 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ Web application for SingleStoreDB external functions.
4
+
5
+ This module supplies a function that can create web apps intended for use
6
+ with the external function feature of SingleStoreDB. The application
7
+ function is a standard ASGI <https://asgi.readthedocs.io/en/latest/index.html>
8
+ request handler for use with servers such as Uvicorn <https://www.uvicorn.org>.
9
+
10
+ An external function web application can be created using the `create_app`
11
+ function. By default, the exported Python functions are specified by
12
+ environment variables starting with SINGLESTOREDB_EXT_FUNCTIONS. See the
13
+ documentation in `create_app` for the full syntax. If the application is
14
+ created in Python code rather than from the command-line, exported
15
+ functions can be specified in the parameters.
16
+
17
+ An example of starting a server is shown below.
18
+
19
+ Example
20
+ -------
21
+ > SINGLESTOREDB_EXT_FUNCTIONS='myfuncs.[percentile_90,percentile_95]' \
22
+ python3 -m singlestoredb.functions.ext.asgi
23
+
24
+ """
25
+ import argparse
26
+ import asyncio
27
+ import contextvars
28
+ import dataclasses
29
+ import datetime
30
+ import functools
31
+ import importlib.util
32
+ import inspect
33
+ import io
34
+ import itertools
35
+ import json
36
+ import logging
37
+ import os
38
+ import re
39
+ import secrets
40
+ import sys
41
+ import tempfile
42
+ import textwrap
43
+ import threading
44
+ import time
45
+ import traceback
46
+ import typing
47
+ import urllib
48
+ import uuid
49
+ import zipfile
50
+ import zipimport
51
+ from collections.abc import Awaitable
52
+ from collections.abc import Iterable
53
+ from collections.abc import Sequence
54
+ from types import ModuleType
55
+ from typing import Any
56
+ from typing import Callable
57
+ from typing import Dict
58
+ from typing import List
59
+ from typing import Optional
60
+ from typing import Set
61
+ from typing import Tuple
62
+ from typing import Union
63
+
64
+ from . import arrow
65
+ from . import json as jdata
66
+ from . import rowdat_1
67
+ from . import utils
68
+ from ... import connection
69
+ from ... import manage_workspaces
70
+ from ...config import get_option
71
+ from ...mysql.constants import FIELD_TYPE as ft
72
+ from ..signature import get_signature
73
+ from ..signature import signature_to_sql
74
+ from ..typing import Masked
75
+ from ..typing import Table
76
+ from .timer import Timer
77
+ from singlestoredb.docstring.parser import parse
78
+ from singlestoredb.functions.dtypes import escape_name
79
+
80
+ try:
81
+ import cloudpickle
82
+ has_cloudpickle = True
83
+ except ImportError:
84
+ has_cloudpickle = False
85
+
86
+ try:
87
+ from pydantic import BaseModel
88
+ has_pydantic = True
89
+ except ImportError:
90
+ has_pydantic = False
91
+
92
+
93
+ logger = utils.get_logger('singlestoredb.functions.ext.asgi')
94
+
95
+ # If a number of processes is specified, create a pool of workers
96
+ num_processes = max(0, int(os.environ.get('SINGLESTOREDB_EXT_NUM_PROCESSES', 0)))
97
+ if num_processes > 1:
98
+ try:
99
+ from ray.util.multiprocessing import Pool
100
+ except ImportError:
101
+ from multiprocessing import Pool
102
+ func_map = Pool(num_processes).starmap
103
+ else:
104
+ func_map = itertools.starmap
105
+
106
+
107
+ async def to_thread(
108
+ func: Any, /, *args: Any, **kwargs: Dict[str, Any],
109
+ ) -> Any:
110
+ loop = asyncio.get_running_loop()
111
+ ctx = contextvars.copy_context()
112
+ func_call = functools.partial(ctx.run, func, *args, **kwargs)
113
+ return await loop.run_in_executor(None, func_call)
114
+
115
+
116
+ # Use negative values to indicate unsigned ints / binary data / usec time precision
117
+ rowdat_1_type_map = {
118
+ 'bool': ft.LONGLONG,
119
+ 'int8': ft.LONGLONG,
120
+ 'int16': ft.LONGLONG,
121
+ 'int32': ft.LONGLONG,
122
+ 'int64': ft.LONGLONG,
123
+ 'uint8': -ft.LONGLONG,
124
+ 'uint16': -ft.LONGLONG,
125
+ 'uint32': -ft.LONGLONG,
126
+ 'uint64': -ft.LONGLONG,
127
+ 'float32': ft.DOUBLE,
128
+ 'float64': ft.DOUBLE,
129
+ 'str': ft.STRING,
130
+ 'bytes': -ft.STRING,
131
+ }
132
+
133
+
134
+ def get_func_names(funcs: str) -> List[Tuple[str, str]]:
135
+ """
136
+ Parse all function names from string.
137
+
138
+ Parameters
139
+ ----------
140
+ funcs : str
141
+ String containing one or more function names. The syntax is
142
+ as follows: [func-name-1@func-alias-1,func-name-2@func-alias-2,...].
143
+ The optional '@name' portion is an alias if you want the function
144
+ to be renamed.
145
+
146
+ Returns
147
+ -------
148
+ List[Tuple[str, str]] : a list of tuples containing the names and aliases
149
+ of each function.
150
+
151
+ """
152
+ if funcs.startswith('['):
153
+ func_names = funcs.replace('[', '').replace(']', '').split(',')
154
+ func_names = [x.strip() for x in func_names]
155
+ else:
156
+ func_names = [funcs]
157
+
158
+ out = []
159
+ for name in func_names:
160
+ alias = name
161
+ if '@' in name:
162
+ name, alias = name.split('@', 1)
163
+ out.append((name, alias))
164
+
165
+ return out
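A quick illustration of the spec syntax parsed above (not part of the packaged source):

from singlestoredb.functions.ext.asgi import get_func_names

assert get_func_names('[p90@q90, p95]') == [('p90', 'q90'), ('p95', 'p95')]
assert get_func_names('median') == [('median', 'median')]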
166
+
167
+
168
+ def as_tuple(x: Any) -> Any:
169
+ """Convert object to tuple."""
170
+ if has_pydantic and isinstance(x, BaseModel):
171
+ return tuple(x.model_dump().values())
172
+ if dataclasses.is_dataclass(x):
173
+ return dataclasses.astuple(x) # type: ignore
174
+ if isinstance(x, dict):
175
+ return tuple(x.values())
176
+ return tuple(x)
177
+
178
+
179
+ def as_list_of_tuples(x: Any) -> Any:
180
+ """Convert object to a list of tuples."""
181
+ if isinstance(x, Table):
182
+ x = x[0]
183
+ if isinstance(x, (list, tuple)) and len(x) > 0:
184
+ if isinstance(x[0], (list, tuple)):
185
+ return x
186
+ if has_pydantic and isinstance(x[0], BaseModel):
187
+ return [tuple(y.model_dump().values()) for y in x]
188
+ if dataclasses.is_dataclass(x[0]):
189
+ return [dataclasses.astuple(y) for y in x]
190
+ if isinstance(x[0], dict):
191
+ return [tuple(y.values()) for y in x]
192
+ return [(y,) for y in x]
193
+ return x
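A short illustration of how results are normalized by the helper above (not part of the packaged source):

from singlestoredb.functions.ext.asgi import as_list_of_tuples

assert as_list_of_tuples([1, 2, 3]) == [(1,), (2,), (3,)]               # scalars become 1-tuples
assert as_list_of_tuples([(1, 'a'), (2, 'b')]) == [(1, 'a'), (2, 'b')]  # rows pass through unchanged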
194
+
195
+
196
+ def get_dataframe_columns(df: Any) -> List[Any]:
197
+ """Return columns of data from a dataframe/table."""
198
+ if isinstance(df, Table):
199
+ if len(df) == 1:
200
+ df = df[0]
201
+ else:
202
+ return list(df)
203
+
204
+ if isinstance(df, Masked):
205
+ return [df]
206
+
207
+ if isinstance(df, tuple):
208
+ return list(df)
209
+
210
+ rtype = str(type(df)).lower()
211
+ if 'dataframe' in rtype:
212
+ return [df[x] for x in df.columns]
213
+ elif 'table' in rtype:
214
+ return df.columns
215
+ elif 'series' in rtype:
216
+ return [df]
217
+ elif 'array' in rtype:
218
+ return [df]
219
+ elif 'tuple' in rtype:
220
+ return list(df)
221
+
222
+ raise TypeError(
223
+ 'Unsupported data type for dataframe columns: '
224
+ f'{rtype}',
225
+ )
226
+
227
+
228
+ def get_array_class(data_format: str) -> Callable[..., Any]:
229
+ """
230
+ Get the array class for the current data format.
231
+
232
+ """
233
+ if data_format == 'polars':
234
+ import polars as pl
235
+ array_cls = pl.Series
236
+ elif data_format == 'arrow':
237
+ import pyarrow as pa
238
+ array_cls = pa.array
239
+ elif data_format == 'pandas':
240
+ import pandas as pd
241
+ array_cls = pd.Series
242
+ else:
243
+ import numpy as np
244
+ array_cls = np.array
245
+ return array_cls
246
+
247
+
248
+ def get_masked_params(func: Callable[..., Any]) -> List[bool]:
249
+ """
250
+ Get the list of masked parameters for the function.
251
+
252
+ Parameters
253
+ ----------
254
+ func : Callable
255
+ The function to call as the endpoint
256
+
257
+ Returns
258
+ -------
259
+ List[bool]
260
+ Boolean list of masked parameters
261
+
262
+ """
263
+ params = inspect.signature(func).parameters
264
+ return [typing.get_origin(x.annotation) is Masked for x in params.values()]
265
+
266
+
267
+ def build_tuple(x: Any) -> Any:
268
+ """Convert object to tuple."""
269
+ return tuple(x) if isinstance(x, Masked) else (x, None)
270
+
271
+
272
+ def cancel_on_event(
273
+ cancel_event: threading.Event,
274
+ ) -> None:
275
+ """
276
+ Cancel the function call if the cancel event is set.
277
+
278
+ Parameters
279
+ ----------
280
+ cancel_event : threading.Event
281
+ The event to check for cancellation
282
+
283
+ Raises
284
+ ------
285
+ asyncio.CancelledError
286
+ If the cancel event is set
287
+
288
+ """
289
+ if cancel_event.is_set():
290
+ task = asyncio.current_task()
291
+ if task is not None:
292
+ task.cancel()
293
+ raise asyncio.CancelledError(
294
+ 'Function call was cancelled by client',
295
+ )
296
+
297
+
298
+ def build_udf_endpoint(
299
+ func: Callable[..., Any],
300
+ returns_data_format: str,
301
+ ) -> Callable[..., Any]:
302
+ """
303
+ Build a UDF endpoint for scalar / list types (row-based).
304
+
305
+ Parameters
306
+ ----------
307
+ func : Callable
308
+ The function to call as the endpoint
309
+ returns_data_format : str
310
+ The format of the return values
311
+
312
+ Returns
313
+ -------
314
+ Callable
315
+ The function endpoint
316
+
317
+ """
318
+ if returns_data_format in ['scalar', 'list']:
319
+
320
+ is_async = asyncio.iscoroutinefunction(func)
321
+
322
+ async def do_func(
323
+ cancel_event: threading.Event,
324
+ timer: Timer,
325
+ row_ids: Sequence[int],
326
+ rows: Sequence[Sequence[Any]],
327
+ ) -> Tuple[Sequence[int], List[Tuple[Any, ...]]]:
328
+ '''Call function on given rows of data.'''
329
+ out = []
330
+ async with timer('call_function'):
331
+ for row in rows:
332
+ cancel_on_event(cancel_event)
333
+ if is_async:
334
+ out.append(await func(*row))
335
+ else:
336
+ out.append(func(*row))
337
+ return row_ids, list(zip(out))
338
+
339
+ return do_func
340
+
341
+ return build_vector_udf_endpoint(func, returns_data_format)
342
+
343
+
344
+ def build_vector_udf_endpoint(
345
+ func: Callable[..., Any],
346
+ returns_data_format: str,
347
+ ) -> Callable[..., Any]:
348
+ """
349
+ Build a UDF endpoint for vector formats (column-based).
350
+
351
+ Parameters
352
+ ----------
353
+ func : Callable
354
+ The function to call as the endpoint
355
+ returns_data_format : str
356
+ The format of the return values
357
+
358
+ Returns
359
+ -------
360
+ Callable
361
+ The function endpoint
362
+
363
+ """
364
+ masks = get_masked_params(func)
365
+ array_cls = get_array_class(returns_data_format)
366
+ is_async = asyncio.iscoroutinefunction(func)
367
+
368
+ async def do_func(
369
+ cancel_event: threading.Event,
370
+ timer: Timer,
371
+ row_ids: Sequence[int],
372
+ cols: Sequence[Tuple[Sequence[Any], Optional[Sequence[bool]]]],
373
+ ) -> Tuple[
374
+ Sequence[int],
375
+ List[Tuple[Sequence[Any], Optional[Sequence[bool]]]],
376
+ ]:
377
+ '''Call function on given columns of data.'''
378
+ row_ids = array_cls(row_ids)
379
+
380
+ # Call the function with `cols` as the function parameters
381
+ async with timer('call_function'):
382
+ if cols and cols[0]:
383
+ if is_async:
384
+ out = await func(*[x if m else x[0] for x, m in zip(cols, masks)])
385
+ else:
386
+ out = func(*[x if m else x[0] for x, m in zip(cols, masks)])
387
+ else:
388
+ if is_async:
389
+ out = await func()
390
+ else:
391
+ out = func()
392
+
393
+ cancel_on_event(cancel_event)
394
+
395
+ # Single masked value
396
+ if isinstance(out, Masked):
397
+ return row_ids, [tuple(out)]
398
+
399
+ # Multiple return values
400
+ if isinstance(out, tuple):
401
+ return row_ids, [build_tuple(x) for x in out]
402
+
403
+ # Single return value
404
+ return row_ids, [(out, None)]
405
+
406
+ return do_func
407
+
408
+
409
+ def build_tvf_endpoint(
410
+ func: Callable[..., Any],
411
+ returns_data_format: str,
412
+ ) -> Callable[..., Any]:
413
+ """
414
+ Build a TVF endpoint for scalar / list types (row-based).
415
+
416
+ Parameters
417
+ ----------
418
+ func : Callable
419
+ The function to call as the endpoint
420
+ returns_data_format : str
421
+ The format of the return values
422
+
423
+ Returns
424
+ -------
425
+ Callable
426
+ The function endpoint
427
+
428
+ """
429
+ if returns_data_format in ['scalar', 'list']:
430
+
431
+ is_async = asyncio.iscoroutinefunction(func)
432
+
433
+ async def do_func(
434
+ cancel_event: threading.Event,
435
+ timer: Timer,
436
+ row_ids: Sequence[int],
437
+ rows: Sequence[Sequence[Any]],
438
+ ) -> Tuple[Sequence[int], List[Tuple[Any, ...]]]:
439
+ '''Call function on given rows of data.'''
440
+ out_ids: List[int] = []
441
+ out = []
442
+ # Call function on each row of data
443
+ async with timer('call_function'):
444
+ for i, row in zip(row_ids, rows):
445
+ cancel_on_event(cancel_event)
446
+ if is_async:
447
+ res = await func(*row)
448
+ else:
449
+ res = func(*row)
450
+ out.extend(as_list_of_tuples(res))
451
+ out_ids.extend([row_ids[i]] * (len(out)-len(out_ids)))
452
+ return out_ids, out
453
+
454
+ return do_func
455
+
456
+ return build_vector_tvf_endpoint(func, returns_data_format)
457
+
458
+
459
+ def build_vector_tvf_endpoint(
460
+ func: Callable[..., Any],
461
+ returns_data_format: str,
462
+ ) -> Callable[..., Any]:
463
+ """
464
+ Build a TVF endpoint for vector formats (column-based).
465
+
466
+ Parameters
467
+ ----------
468
+ func : Callable
469
+ The function to call as the endpoint
470
+ returns_data_format : str
471
+ The format of the return values
472
+
473
+ Returns
474
+ -------
475
+ Callable
476
+ The function endpoint
477
+
478
+ """
479
+ masks = get_masked_params(func)
480
+ array_cls = get_array_class(returns_data_format)
481
+
482
+ async def do_func(
483
+ cancel_event: threading.Event,
484
+ timer: Timer,
485
+ row_ids: Sequence[int],
486
+ cols: Sequence[Tuple[Sequence[Any], Optional[Sequence[bool]]]],
487
+ ) -> Tuple[
488
+ Sequence[int],
489
+ List[Tuple[Sequence[Any], Optional[Sequence[bool]]]],
490
+ ]:
491
+ '''Call function on given columns of data.'''
492
+ # NOTE: There is no way to determine which row ID belongs to
493
+ # each result row, so we just have to use the same
494
+ # row ID for all rows in the result.
495
+
496
+ is_async = asyncio.iscoroutinefunction(func)
497
+
498
+ # Call function on each column of data
499
+ async with timer('call_function'):
500
+ if cols and cols[0]:
501
+ if is_async:
502
+ func_res = await func(
503
+ *[x if m else x[0] for x, m in zip(cols, masks)],
504
+ )
505
+ else:
506
+ func_res = func(
507
+ *[x if m else x[0] for x, m in zip(cols, masks)],
508
+ )
509
+ else:
510
+ if is_async:
511
+ func_res = await func()
512
+ else:
513
+ func_res = func()
514
+
515
+ res = get_dataframe_columns(func_res)
516
+
517
+ cancel_on_event(cancel_event)
518
+
519
+ # Generate row IDs
520
+ if isinstance(res[0], Masked):
521
+ row_ids = array_cls([row_ids[0]] * len(res[0][0]))
522
+ else:
523
+ row_ids = array_cls([row_ids[0]] * len(res[0]))
524
+
525
+ return row_ids, [build_tuple(x) for x in res]
526
+
527
+ return do_func
528
+
529
+
530
+ def make_func(
531
+ name: str,
532
+ func: Callable[..., Any],
533
+ ) -> Tuple[Callable[..., Any], Dict[str, Any]]:
534
+ """
535
+ Make a function endpoint.
536
+
537
+ Parameters
538
+ ----------
539
+ name : str
540
+ Name of the function to create
541
+ func : Callable
542
+ The function to call as the endpoint
543
+ database : str, optional
544
+ The database to use for the function definition
545
+
546
+ Returns
547
+ -------
548
+ (Callable, Dict[str, Any])
549
+
550
+ """
551
+ info: Dict[str, Any] = {}
552
+
553
+ sig = get_signature(func, func_name=name)
554
+
555
+ function_type = sig.get('function_type', 'udf')
556
+ args_data_format = sig.get('args_data_format', 'scalar')
557
+ returns_data_format = sig.get('returns_data_format', 'scalar')
558
+ timeout = (
559
+ func._singlestoredb_attrs.get('timeout') or # type: ignore
560
+ get_option('external_function.timeout')
561
+ )
562
+
563
+ if function_type == 'tvf':
564
+ do_func = build_tvf_endpoint(func, returns_data_format)
565
+ else:
566
+ do_func = build_udf_endpoint(func, returns_data_format)
567
+
568
+ do_func.__name__ = name
569
+ do_func.__doc__ = func.__doc__
570
+
571
+ # Store signature for generating CREATE FUNCTION calls
572
+ info['signature'] = sig
573
+
574
+ # Set data format
575
+ info['args_data_format'] = args_data_format
576
+ info['returns_data_format'] = returns_data_format
577
+
578
+ # Set function type
579
+ info['function_type'] = function_type
580
+
581
+ # Set timeout
582
+ info['timeout'] = max(timeout, 1)
583
+
584
+ # Set async flag
585
+ info['is_async'] = asyncio.iscoroutinefunction(func)
586
+
587
+ # Setup argument types for rowdat_1 parser
588
+ colspec = []
589
+ for x in sig['args']:
590
+ dtype = x['dtype'].replace('?', '')
591
+ if dtype not in rowdat_1_type_map:
592
+ raise TypeError(f'no data type mapping for {dtype}')
593
+ colspec.append((x['name'], rowdat_1_type_map[dtype]))
594
+ info['colspec'] = colspec
595
+
596
+ # Setup return type
597
+ returns = []
598
+ for x in sig['returns']:
599
+ dtype = x['dtype'].replace('?', '')
600
+ if dtype not in rowdat_1_type_map:
601
+ raise TypeError(f'no data type mapping for {dtype}')
602
+ returns.append((x['name'], rowdat_1_type_map[dtype]))
603
+ info['returns'] = returns
604
+
605
+ return do_func, info
606
+
607
+
608
+ async def cancel_on_timeout(timeout: int) -> None:
609
+ """Cancel request if it takes too long."""
610
+ await asyncio.sleep(timeout)
611
+ raise asyncio.CancelledError(
612
+ 'Function call was cancelled due to timeout',
613
+ )
614
+
615
+
616
+ async def cancel_on_disconnect(
617
+ receive: Callable[..., Awaitable[Any]],
618
+ ) -> None:
619
+ """Cancel request if client disconnects."""
620
+ while True:
621
+ message = await receive()
622
+ if message.get('type', '') == 'http.disconnect':
623
+ raise asyncio.CancelledError(
624
+ 'Function call was cancelled by client',
625
+ )
626
+
627
+
628
+ async def cancel_all_tasks(tasks: Iterable[asyncio.Task[Any]]) -> None:
629
+ """Cancel all tasks."""
630
+ for task in tasks:
631
+ task.cancel()
632
+ await asyncio.gather(*tasks, return_exceptions=True)
633
+
634
+
635
+ def start_counter() -> float:
636
+ """Start a timer and return the start time."""
637
+ return time.perf_counter()
638
+
639
+
640
+ def end_counter(start: float) -> float:
641
+ """End a timer and return the elapsed time."""
642
+ return time.perf_counter() - start
643
+
644
+
645
+ class Application(object):
646
+ """
647
+ Create an external function application.
648
+
649
+ If `functions` is None, the environment is searched for function
650
+ specifications in variables starting with `SINGLESTOREDB_EXT_FUNCTIONS`.
651
+ Any number of environment variables can be specified as long as they
652
+ have this prefix. The format of the environment variable value is the
653
+ same as for the `functions` parameter.
654
+
655
+ Parameters
656
+ ----------
657
+ functions : str or Iterable[str], optional
658
+ Python functions are specified using a string format as follows:
659
+ * Single function : <pkg1>.<func1>
660
+ * Multiple functions : <pkg1>.[<func1-name,func2-name,...]
661
+ * Function aliases : <pkg1>.[<func1@alias1,func2@alias2,...]
662
+ * Multiple packages : <pkg1>.<func1>:<pkg2>.<func2>
663
+ app_mode : str, optional
664
+ The mode of operation for the application: remote, managed, or collocated
665
+ url : str, optional
666
+ The URL of the function API
667
+ data_format : str, optional
668
+ The format of the data rows: 'rowdat_1' or 'json'
669
+ data_version : str, optional
670
+ The version of the call format to expect: '1.0'
671
+ link_name : str, optional
672
+ The link name to use for the external function application. This is
673
+ only for pre-existing links, and can only be used without
674
+ ``link_config`` and ``link_credentials``.
675
+ link_config : Dict[str, Any], optional
676
+ The CONFIG section of a LINK definition. This dictionary gets
677
+ converted to JSON for the CREATE LINK call.
678
+ link_credentials : Dict[str, Any], optional
679
+ The CREDENTIALS section of a LINK definition. This dictionary gets
680
+ converted to JSON for the CREATE LINK call.
681
+ name_prefix : str, optional
682
+ Prefix to add to function names when registering with the database
683
+ name_suffix : str, optional
684
+ Suffix to add to function names when registering with the database
685
+ function_database : str, optional
686
+ The database to use for external function definitions.
687
+ log_file : str, optional
688
+ File path to write logs to instead of console. If None, logs are
689
+ written to console. When specified, application logger handlers
690
+ are replaced with a file handler.
691
+ log_level : str, optional
692
+ Logging level for the application logger. Valid values are 'info',
693
+ 'debug', 'warning', 'error'. Defaults to 'info'.
694
+ disable_metrics : bool, optional
695
+ Disable logging of function call metrics. Defaults to False.
696
+ app_name : str, optional
697
+ Name for the application instance. Used to create a logger-specific
698
+ name. If not provided, a random name will be generated.
699
+
700
+ """
701
+
702
+ # Plain text response start
703
+ text_response_dict: Dict[str, Any] = dict(
704
+ type='http.response.start',
705
+ status=200,
706
+ headers=[(b'content-type', b'text/plain')],
707
+ )
708
+
709
+ # Error response start
710
+ error_response_dict: Dict[str, Any] = dict(
711
+ type='http.response.start',
712
+ status=500,
713
+ headers=[(b'content-type', b'text/plain')],
714
+ )
715
+
716
+ # Timeout response start
717
+ timeout_response_dict: Dict[str, Any] = dict(
718
+ type='http.response.start',
719
+ status=504,
720
+ headers=[(b'content-type', b'text/plain')],
721
+ )
722
+
723
+ # Cancel response start
724
+ cancel_response_dict: Dict[str, Any] = dict(
725
+ type='http.response.start',
726
+ status=503,
727
+ headers=[(b'content-type', b'text/plain')],
728
+ )
729
+
730
+ # JSON response start
731
+ json_response_dict: Dict[str, Any] = dict(
732
+ type='http.response.start',
733
+ status=200,
734
+ headers=[(b'content-type', b'application/json')],
735
+ )
736
+
737
+ # ROWDAT_1 response start
738
+ rowdat_1_response_dict: Dict[str, Any] = dict(
739
+ type='http.response.start',
740
+ status=200,
741
+ headers=[(b'content-type', b'x-application/rowdat_1')],
742
+ )
743
+
744
+ # Apache Arrow response start
745
+ arrow_response_dict: Dict[str, Any] = dict(
746
+ type='http.response.start',
747
+ status=200,
748
+ headers=[(b'content-type', b'application/vnd.apache.arrow.file')],
749
+ )
750
+
751
+ # Path not found response start
752
+ path_not_found_response_dict: Dict[str, Any] = dict(
753
+ type='http.response.start',
754
+ status=404,
755
+ )
756
+
757
+ # Response body template
758
+ body_response_dict: Dict[str, Any] = dict(
759
+ type='http.response.body',
760
+ )
761
+
762
+ # Data format + version handlers
763
+ handlers = {
764
+ (b'application/octet-stream', b'1.0', 'scalar'): dict(
765
+ load=rowdat_1.load,
766
+ dump=rowdat_1.dump,
767
+ response=rowdat_1_response_dict,
768
+ ),
769
+ (b'application/octet-stream', b'1.0', 'list'): dict(
770
+ load=rowdat_1.load,
771
+ dump=rowdat_1.dump,
772
+ response=rowdat_1_response_dict,
773
+ ),
774
+ (b'application/octet-stream', b'1.0', 'pandas'): dict(
775
+ load=rowdat_1.load_pandas,
776
+ dump=rowdat_1.dump_pandas,
777
+ response=rowdat_1_response_dict,
778
+ ),
779
+ (b'application/octet-stream', b'1.0', 'numpy'): dict(
780
+ load=rowdat_1.load_numpy,
781
+ dump=rowdat_1.dump_numpy,
782
+ response=rowdat_1_response_dict,
783
+ ),
784
+ (b'application/octet-stream', b'1.0', 'polars'): dict(
785
+ load=rowdat_1.load_polars,
786
+ dump=rowdat_1.dump_polars,
787
+ response=rowdat_1_response_dict,
788
+ ),
789
+ (b'application/octet-stream', b'1.0', 'arrow'): dict(
790
+ load=rowdat_1.load_arrow,
791
+ dump=rowdat_1.dump_arrow,
792
+ response=rowdat_1_response_dict,
793
+ ),
794
+ (b'application/json', b'1.0', 'scalar'): dict(
795
+ load=jdata.load,
796
+ dump=jdata.dump,
797
+ response=json_response_dict,
798
+ ),
799
+ (b'application/json', b'1.0', 'list'): dict(
800
+ load=jdata.load,
801
+ dump=jdata.dump,
802
+ response=json_response_dict,
803
+ ),
804
+ (b'application/json', b'1.0', 'pandas'): dict(
805
+ load=jdata.load_pandas,
806
+ dump=jdata.dump_pandas,
807
+ response=json_response_dict,
808
+ ),
809
+ (b'application/json', b'1.0', 'numpy'): dict(
810
+ load=jdata.load_numpy,
811
+ dump=jdata.dump_numpy,
812
+ response=json_response_dict,
813
+ ),
814
+ (b'application/json', b'1.0', 'polars'): dict(
815
+ load=jdata.load_polars,
816
+ dump=jdata.dump_polars,
817
+ response=json_response_dict,
818
+ ),
819
+ (b'application/json', b'1.0', 'arrow'): dict(
820
+ load=jdata.load_arrow,
821
+ dump=jdata.dump_arrow,
822
+ response=json_response_dict,
823
+ ),
824
+ (b'application/vnd.apache.arrow.file', b'1.0', 'scalar'): dict(
825
+ load=arrow.load,
826
+ dump=arrow.dump,
827
+ response=arrow_response_dict,
828
+ ),
829
+ (b'application/vnd.apache.arrow.file', b'1.0', 'pandas'): dict(
830
+ load=arrow.load_pandas,
831
+ dump=arrow.dump_pandas,
832
+ response=arrow_response_dict,
833
+ ),
834
+ (b'application/vnd.apache.arrow.file', b'1.0', 'numpy'): dict(
835
+ load=arrow.load_numpy,
836
+ dump=arrow.dump_numpy,
837
+ response=arrow_response_dict,
838
+ ),
839
+ (b'application/vnd.apache.arrow.file', b'1.0', 'polars'): dict(
840
+ load=arrow.load_polars,
841
+ dump=arrow.dump_polars,
842
+ response=arrow_response_dict,
843
+ ),
844
+ (b'application/vnd.apache.arrow.file', b'1.0', 'arrow'): dict(
845
+ load=arrow.load_arrow,
846
+ dump=arrow.dump_arrow,
847
+ response=arrow_response_dict,
848
+ ),
849
+ }
850
+
851
+ # Valid URL paths
852
+ invoke_path = ('invoke',)
853
+ show_create_function_path = ('show', 'create_function')
854
+ show_function_info_path = ('show', 'function_info')
855
+ status = ('status',)
856
+
857
+ def __init__(
858
+ self,
859
+ functions: Optional[
860
+ Union[
861
+ str,
862
+ Iterable[str],
863
+ Callable[..., Any],
864
+ Iterable[Callable[..., Any]],
865
+ ModuleType,
866
+ Iterable[ModuleType],
867
+ ]
868
+ ] = None,
869
+ app_mode: str = get_option('external_function.app_mode'),
870
+ url: str = get_option('external_function.url'),
871
+ data_format: str = get_option('external_function.data_format'),
872
+ data_version: str = get_option('external_function.data_version'),
873
+ link_name: Optional[str] = get_option('external_function.link_name'),
874
+ link_config: Optional[Dict[str, Any]] = None,
875
+ link_credentials: Optional[Dict[str, Any]] = None,
876
+ name_prefix: str = get_option('external_function.name_prefix'),
877
+ name_suffix: str = get_option('external_function.name_suffix'),
878
+ function_database: Optional[str] = None,
879
+ log_file: Optional[str] = get_option('external_function.log_file'),
880
+ log_level: str = get_option('external_function.log_level'),
881
+ disable_metrics: bool = get_option('external_function.disable_metrics'),
882
+ app_name: Optional[str] = get_option('external_function.app_name'),
883
+ ) -> None:
884
+ if link_name and (link_config or link_credentials):
885
+ raise ValueError(
886
+ '`link_name` can not be used with `link_config` or `link_credentials`',
887
+ )
888
+
889
+ if link_config is None:
890
+ link_config = json.loads(
891
+ get_option('external_function.link_config') or '{}',
892
+ ) or None
893
+
894
+ if link_credentials is None:
895
+ link_credentials = json.loads(
896
+ get_option('external_function.link_credentials') or '{}',
897
+ ) or None
898
+
899
+ # Generate application name if not provided
900
+ if app_name is None:
901
+ app_name = f'udf_app_{secrets.token_hex(4)}'
902
+
903
+ self.name = app_name
904
+
905
+ # Create logger instance specific to this application
906
+ self.logger = utils.get_logger(f'singlestoredb.functions.ext.asgi.{self.name}')
907
+
908
+ # List of functions specs
909
+ specs: List[Union[str, Callable[..., Any], ModuleType]] = []
910
+
911
+ # Look up Python function specifications
912
+ if functions is None:
913
+ env_vars = [
914
+ x for x in os.environ.keys()
915
+ if x.startswith('SINGLESTOREDB_EXT_FUNCTIONS')
916
+ ]
917
+ if env_vars:
918
+ specs = [os.environ[x] for x in env_vars]
919
+ else:
920
+ import __main__
921
+ specs = [__main__]
922
+
923
+ elif isinstance(functions, ModuleType):
924
+ specs = [functions]
925
+
926
+ elif isinstance(functions, str):
927
+ specs = [functions]
928
+
929
+ elif callable(functions):
930
+ specs = [functions]
931
+
932
+ else:
933
+ specs = list(functions)
934
+
935
+ # Add functions to application
936
+ endpoints = dict()
937
+ external_functions = dict()
938
+ for funcs in itertools.chain(specs):
939
+
940
+ if isinstance(funcs, str):
941
+ # Module name
942
+ if importlib.util.find_spec(funcs) is not None:
943
+ items = importlib.import_module(funcs)
944
+ for x in vars(items).values():
945
+ if not hasattr(x, '_singlestoredb_attrs'):
946
+ continue
947
+ name = x._singlestoredb_attrs.get('name', x.__name__)
948
+ name = f'{name_prefix}{name}{name_suffix}'
949
+ external_functions[x.__name__] = x
950
+ func, info = make_func(name, x)
951
+ endpoints[name.encode('utf-8')] = func, info
952
+
953
+ # Fully qualified function name
954
+ elif '.' in funcs:
955
+ pkg_path, func_names = funcs.rsplit('.', 1)
956
+ pkg = importlib.import_module(pkg_path)
957
+
958
+ if pkg is None:
959
+ raise RuntimeError(f'Could not locate module: {pkg}')
960
+
961
+ # Add endpoint for each exported function
962
+ for name, alias in get_func_names(func_names):
963
+ item = getattr(pkg, name)
964
+ alias = f'{name_prefix}{name}{name_suffix}'
965
+ external_functions[name] = item
966
+ func, info = make_func(alias, item)
967
+ endpoints[alias.encode('utf-8')] = func, info
968
+
969
+ else:
970
+ raise RuntimeError(f'Could not locate module: {funcs}')
971
+
972
+ elif isinstance(funcs, ModuleType):
973
+ for x in vars(funcs).values():
974
+ if not hasattr(x, '_singlestoredb_attrs'):
975
+ continue
976
+ name = x._singlestoredb_attrs.get('name', x.__name__)
977
+ name = f'{name_prefix}{name}{name_suffix}'
978
+ external_functions[x.__name__] = x
979
+ func, info = make_func(name, x)
980
+ endpoints[name.encode('utf-8')] = func, info
981
+
982
+ else:
983
+ alias = funcs.__name__
984
+ external_functions[funcs.__name__] = funcs
985
+ alias = f'{name_prefix}{alias}{name_suffix}'
986
+ func, info = make_func(alias, funcs)
987
+ endpoints[alias.encode('utf-8')] = func, info
988
+
989
+ self.app_mode = app_mode
990
+ self.url = url
991
+ self.data_format = data_format
992
+ self.data_version = data_version
993
+ self.link_name = link_name
994
+ self.link_config = link_config
995
+ self.link_credentials = link_credentials
996
+ self.endpoints = endpoints
997
+ self.external_functions = external_functions
998
+ self.function_database = function_database
999
+ self.log_file = log_file
1000
+ self.log_level = log_level
1001
+ self.disable_metrics = disable_metrics
1002
+
1003
+ # Configure logging
1004
+ self._configure_logging()
1005
+
1006
+ def _configure_logging(self) -> None:
1007
+ """Configure logging based on the log_file settings."""
1008
+ # Set logger level
1009
+ self.logger.setLevel(getattr(logging, self.log_level.upper()))
1010
+
1011
+ # Remove all existing handlers to ensure clean configuration
1012
+ self.logger.handlers.clear()
1013
+
1014
+ # Configure log file if specified
1015
+ if self.log_file:
1016
+ # Create file handler
1017
+ file_handler = logging.FileHandler(self.log_file)
1018
+ file_handler.setLevel(getattr(logging, self.log_level.upper()))
1019
+
1020
+ # Use JSON formatter for file logging
1021
+ formatter = utils.JSONFormatter()
1022
+ file_handler.setFormatter(formatter)
1023
+
1024
+ # Add the handler to the logger
1025
+ self.logger.addHandler(file_handler)
1026
+ else:
1027
+ # For console logging, create a new stream handler with JSON formatter
1028
+ console_handler = logging.StreamHandler()
1029
+ console_handler.setLevel(getattr(logging, self.log_level.upper()))
1030
+ console_handler.setFormatter(utils.JSONFormatter())
1031
+ self.logger.addHandler(console_handler)
1032
+
1033
+ # Prevent propagation to avoid duplicate or differently formatted messages
1034
+ self.logger.propagate = False
1035
+
1036
+ def get_uvicorn_log_config(self) -> Dict[str, Any]:
1037
+ """
1038
+ Create uvicorn log config that matches the Application's logging format.
1039
+
1040
+ This method returns the log configuration used by uvicorn, allowing external
1041
+ users to match the logging format of the Application class.
1042
+
1043
+ Returns
1044
+ -------
1045
+ Dict[str, Any]
1046
+ Log configuration dictionary compatible with uvicorn's log_config parameter
1047
+
1048
+ """
1049
+ log_config = {
1050
+ 'version': 1,
1051
+ 'disable_existing_loggers': False,
1052
+ 'formatters': {
1053
+ 'json': {
1054
+ '()': 'singlestoredb.functions.ext.utils.JSONFormatter',
1055
+ },
1056
+ },
1057
+ 'handlers': {
1058
+ 'default': {
1059
+ 'class': (
1060
+ 'logging.FileHandler' if self.log_file
1061
+ else 'logging.StreamHandler'
1062
+ ),
1063
+ 'formatter': 'json',
1064
+ },
1065
+ },
1066
+ 'loggers': {
1067
+ 'uvicorn': {
1068
+ 'handlers': ['default'],
1069
+ 'level': self.log_level.upper(),
1070
+ 'propagate': False,
1071
+ },
1072
+ 'uvicorn.error': {
1073
+ 'handlers': ['default'],
1074
+ 'level': self.log_level.upper(),
1075
+ 'propagate': False,
1076
+ },
1077
+ 'uvicorn.access': {
1078
+ 'handlers': ['default'],
1079
+ 'level': self.log_level.upper(),
1080
+ 'propagate': False,
1081
+ },
1082
+ },
1083
+ }
1084
+
1085
+ # Add filename to file handler if log file is specified
1086
+ if self.log_file:
1087
+ log_config['handlers']['default']['filename'] = self.log_file # type: ignore
1088
+
1089
+ return log_config
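A possible way to keep uvicorn's own records in the same JSON format (illustrative only; the module path and port are assumptions):

import uvicorn
from singlestoredb.functions.ext.asgi import Application

app = Application(functions='myfuncs.percentile_90', log_level='debug')
uvicorn.run(app, port=8000, log_config=app.get_uvicorn_log_config())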
1090
+
1091
+ async def __call__(
1092
+ self,
1093
+ scope: Dict[str, Any],
1094
+ receive: Callable[..., Awaitable[Any]],
1095
+ send: Callable[..., Awaitable[Any]],
1096
+ ) -> None:
1097
+ '''
1098
+ Application request handler.
1099
+
1100
+ Parameters
1101
+ ----------
1102
+ scope : dict
1103
+ ASGI request scope
1104
+ receive : Callable
1105
+ Function to receive request information
1106
+ send : Callable
1107
+ Function to send response information
1108
+
1109
+ '''
1110
+ request_id = str(uuid.uuid4())
1111
+
1112
+ timer = Timer(
1113
+ app_name=self.name,
1114
+ id=request_id,
1115
+ timestamp=datetime.datetime.now(
1116
+ datetime.timezone.utc,
1117
+ ).strftime('%Y-%m-%dT%H:%M:%S.%fZ'),
1118
+ )
1119
+ call_timer = Timer(
1120
+ app_name=self.name,
1121
+ id=request_id,
1122
+ timestamp=datetime.datetime.now(
1123
+ datetime.timezone.utc,
1124
+ ).strftime('%Y-%m-%dT%H:%M:%S.%fZ'),
1125
+ )
1126
+
1127
+ if scope['type'] != 'http':
1128
+ raise ValueError(f"Expected HTTP scope, got {scope['type']}")
1129
+
1130
+ method = scope['method']
1131
+ path = tuple(x for x in scope['path'].split('/') if x)
1132
+ headers = dict(scope['headers'])
1133
+
1134
+ content_type = headers.get(
1135
+ b'content-type',
1136
+ b'application/octet-stream',
1137
+ )
1138
+ accepts = headers.get(b'accepts', content_type)
1139
+ func_name = headers.get(b's2-ef-name', b'')
1140
+ func_endpoint = self.endpoints.get(func_name)
1141
+ ignore_cancel = headers.get(b's2-ef-ignore-cancel', b'false') == b'true'
1142
+
1143
+ timer.metadata['function'] = func_name.decode('utf-8') if func_name else ''
1144
+ call_timer.metadata['function'] = timer.metadata['function']
1145
+
1146
+ func = None
1147
+ func_info: Dict[str, Any] = {}
1148
+ if func_endpoint is not None:
1149
+ func, func_info = func_endpoint
1150
+
1151
+ # Call the endpoint
1152
+ if method == 'POST' and func is not None and path == self.invoke_path:
1153
+
1154
+ self.logger.info(
1155
+ 'Function call initiated',
1156
+ extra={
1157
+ 'app_name': self.name,
1158
+ 'request_id': request_id,
1159
+ 'function_name': func_name.decode('utf-8'),
1160
+ 'content_type': content_type.decode('utf-8'),
1161
+ 'accepts': accepts.decode('utf-8'),
1162
+ },
1163
+ )
1164
+
1165
+ args_data_format = func_info['args_data_format']
1166
+ returns_data_format = func_info['returns_data_format']
1167
+ data = []
1168
+ more_body = True
1169
+ with timer('receive_data'):
1170
+ while more_body:
1171
+ request = await receive()
1172
+ if request.get('type', '') == 'http.disconnect':
1173
+ raise RuntimeError('client disconnected')
1174
+ data.append(request['body'])
1175
+ more_body = request.get('more_body', False)
1176
+
1177
+ data_version = headers.get(b's2-ef-version', b'')
1178
+ input_handler = self.handlers[(content_type, data_version, args_data_format)]
1179
+ output_handler = self.handlers[(accepts, data_version, returns_data_format)]
1180
+
1181
+ try:
1182
+ all_tasks = []
1183
+ result = []
1184
+
1185
+ cancel_event = threading.Event()
1186
+
1187
+ with timer('parse_input'):
1188
+ inputs = input_handler['load']( # type: ignore
1189
+ func_info['colspec'], b''.join(data),
1190
+ )
1191
+
1192
+ func_task = asyncio.create_task(
1193
+ func(cancel_event, call_timer, *inputs)
1194
+ if func_info['is_async']
1195
+ else to_thread(
1196
+ lambda: asyncio.run(
1197
+ func(cancel_event, call_timer, *inputs),
1198
+ ),
1199
+ ),
1200
+ )
1201
+ disconnect_task = asyncio.create_task(
1202
+ asyncio.sleep(int(1e9))
1203
+ if ignore_cancel else cancel_on_disconnect(receive),
1204
+ )
1205
+ timeout_task = asyncio.create_task(
1206
+ cancel_on_timeout(func_info['timeout']),
1207
+ )
1208
+
1209
+ all_tasks += [func_task, disconnect_task, timeout_task]
1210
+
1211
+ async with timer('function_wrapper'):
1212
+ done, pending = await asyncio.wait(
1213
+ all_tasks, return_when=asyncio.FIRST_COMPLETED,
1214
+ )
1215
+
1216
+ await cancel_all_tasks(pending)
1217
+
1218
+ for task in done:
1219
+ if task is disconnect_task:
1220
+ cancel_event.set()
1221
+ raise asyncio.CancelledError(
1222
+ 'Function call was cancelled by client disconnect',
1223
+ )
1224
+
1225
+ elif task is timeout_task:
1226
+ cancel_event.set()
1227
+ raise asyncio.TimeoutError(
1228
+ 'Function call was cancelled due to timeout',
1229
+ )
1230
+
1231
+ elif task is func_task:
1232
+ result.extend(task.result())
1233
+
1234
+ with timer('format_output'):
1235
+ body = output_handler['dump'](
1236
+ [x[1] for x in func_info['returns']], *result, # type: ignore
1237
+ )
1238
+
1239
+ await send(output_handler['response'])
1240
+
1241
+ except asyncio.TimeoutError:
1242
+ self.logger.exception(
1243
+ 'Function call timeout',
1244
+ extra={
1245
+ 'app_name': self.name,
1246
+ 'request_id': request_id,
1247
+ 'function_name': func_name.decode('utf-8'),
1248
+ 'timeout': func_info['timeout'],
1249
+ },
1250
+ )
1251
+ body = (
1252
+ 'TimeoutError: Function call timed out after ' +
1253
+ str(func_info['timeout']) +
1254
+ ' seconds'
1255
+ ).encode('utf-8')
1256
+ await send(self.timeout_response_dict)
1257
+
1258
+ except asyncio.CancelledError:
1259
+ self.logger.exception(
1260
+ 'Function call cancelled',
1261
+ extra={
1262
+ 'app_name': self.name,
1263
+ 'request_id': request_id,
1264
+ 'function_name': func_name.decode('utf-8'),
1265
+ },
1266
+ )
1267
+ body = b'CancelledError: Function call was cancelled'
1268
+ await send(self.cancel_response_dict)
1269
+
1270
+ except Exception as e:
1271
+ self.logger.exception(
1272
+ 'Function call error',
1273
+ extra={
1274
+ 'app_name': self.name,
1275
+ 'request_id': request_id,
1276
+ 'function_name': func_name.decode('utf-8'),
1277
+ 'exception_type': type(e).__name__,
1278
+ },
1279
+ )
1280
+ msg = traceback.format_exc().strip().split(' File ')[-1]
1281
+ if msg.startswith('"/tmp/ipykernel_'):
1282
+ msg = 'Line ' + msg.split(', line ')[-1]
1283
+ else:
1284
+ msg = 'File ' + msg
1285
+ body = msg.encode('utf-8')
1286
+ await send(self.error_response_dict)
1287
+
1288
+ finally:
1289
+ await cancel_all_tasks(all_tasks)
1290
+
1291
+ # Handle api reflection
1292
+ elif method == 'GET' and path == self.show_create_function_path:
1293
+ host = headers.get(b'host', b'localhost:80')
1294
+ reflected_url = f'{scope["scheme"]}://{host.decode("utf-8")}/invoke'
1295
+
1296
+ syntax = []
1297
+ for key, (endpoint, endpoint_info) in self.endpoints.items():
1298
+ if not func_name or key == func_name:
1299
+ syntax.append(
1300
+ signature_to_sql(
1301
+ endpoint_info['signature'],
1302
+ url=self.url or reflected_url,
1303
+ data_format=self.data_format,
1304
+ database=self.function_database or None,
1305
+ ),
1306
+ )
1307
+ body = '\n'.join(syntax).encode('utf-8')
1308
+
1309
+ await send(self.text_response_dict)
1310
+
1311
+ # Return function info
1312
+ elif method == 'GET' and (path == self.show_function_info_path or not path):
1313
+ functions = self.get_function_info()
1314
+ body = json.dumps(dict(functions=functions)).encode('utf-8')
1315
+ await send(self.text_response_dict)
1316
+
1317
+ # Return status
1318
+ elif method == 'GET' and path == self.status:
1319
+ body = json.dumps(dict(status='ok')).encode('utf-8')
1320
+ await send(self.text_response_dict)
1321
+
1322
+ # Path not found
1323
+ else:
1324
+ body = b''
1325
+ await send(self.path_not_found_response_dict)
1326
+
1327
+ # Send body
1328
+ with timer('send_response'):
1329
+ out = self.body_response_dict.copy()
1330
+ out['body'] = body
1331
+ await send(out)
1332
+
1333
+ for k, v in call_timer.metrics.items():
1334
+ timer.metrics[k] = v
1335
+
1336
+ if not self.disable_metrics:
1337
+ metrics = timer.finish()
1338
+ self.logger.info(
1339
+ 'Function call metrics',
1340
+ extra={
1341
+ 'app_name': self.name,
1342
+ 'request_id': request_id,
1343
+ 'function_name': timer.metadata.get('function', ''),
1344
+ 'metrics': metrics,
1345
+ },
1346
+ )
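The URL paths handled above can be exercised directly once a server is running; an illustrative check (the port is an assumption, the paths correspond to status, show_function_info_path, and show_create_function_path defined above):

import urllib.request

for path in ('status', 'show/function_info', 'show/create_function'):
    with urllib.request.urlopen(f'http://localhost:8000/{path}') as resp:
        print(path, '->', resp.read().decode('utf-8')[:200])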
1347
+
1348
+ def _create_link(
1349
+ self,
1350
+ config: Optional[Dict[str, Any]],
1351
+ credentials: Optional[Dict[str, Any]],
1352
+ ) -> Tuple[str, str]:
1353
+ """Generate CREATE LINK command."""
1354
+ if self.link_name:
1355
+ return self.link_name, ''
1356
+
1357
+ if not config and not credentials:
1358
+ return '', ''
1359
+
1360
+ link_name = f'py_ext_func_link_{secrets.token_hex(14)}'
1361
+ out = [f'CREATE LINK {link_name} AS HTTP']
1362
+
1363
+ if config:
1364
+ out.append(f"CONFIG '{json.dumps(config)}'")
1365
+
1366
+ if credentials:
1367
+ out.append(f"CREDENTIALS '{json.dumps(credentials)}'")
1368
+
1369
+ return link_name, ' '.join(out) + ';'
1370
+
1371
+ def _locate_app_functions(self, cur: Any) -> Tuple[Set[str], Set[str]]:
1372
+ """Locate all current functions and links belonging to this app."""
1373
+ funcs, links = set(), set()
1374
+ if self.function_database:
1375
+ database_prefix = escape_name(self.function_database) + '.'
1376
+ cur.execute(f'SHOW FUNCTIONS IN {escape_name(self.function_database)}')
1377
+ else:
1378
+ database_prefix = ''
1379
+ cur.execute('SHOW FUNCTIONS')
1380
+
1381
+ for row in list(cur):
1382
+ name, ftype, link = row[0], row[1], row[-1]
1383
+ # Only look at external functions
1384
+ if 'external' not in ftype.lower():
1385
+ continue
1386
+ # See if function URL matches url
1387
+ cur.execute(f'SHOW CREATE FUNCTION {database_prefix}{escape_name(name)}')
1388
+ for fname, _, code, *_ in list(cur):
1389
+ m = re.search(r" (?:\w+) (?:SERVICE|MANAGED) '([^']+)'", code)
1390
+ if m and m.group(1) == self.url:
1391
+ funcs.add(f'{database_prefix}{escape_name(fname)}')
1392
+ if link and re.match(r'^py_ext_func_link_\S{14}$', link):
1393
+ links.add(link)
1394
+
1395
+ return funcs, links
1396
+
1397
+ def get_function_info(
1398
+ self,
1399
+ func_name: Optional[str] = None,
1400
+ ) -> Dict[str, Any]:
1401
+ """
1402
+ Return the functions and function signature information.
1403
+
1404
+ Returns
1405
+ -------
1406
+ Dict[str, Any]
1407
+
1408
+ """
1409
+ functions = {}
1410
+ no_default = object()
1411
+
1412
+ # Generate CREATE FUNCTION SQL for each function using get_create_functions
1413
+ create_sqls = self.get_create_functions(replace=True)
1414
+ sql_map = {}
1415
+ for (_, info), sql in zip(self.endpoints.values(), create_sqls):
1416
+ sig = info['signature']
1417
+ sql_map[sig['name']] = sql
1418
+
1419
+ for key, (func, info) in self.endpoints.items():
1420
+ # Get info from docstring
1421
+ doc_summary = ''
1422
+ doc_long_description = ''
1423
+ doc_params = {}
1424
+ doc_returns = None
1425
+ doc_examples = []
1426
+ if func.__doc__:
1427
+ try:
1428
+ docs = parse(func.__doc__)
1429
+ doc_params = {p.arg_name: p for p in docs.params}
1430
+ doc_returns = docs.returns
1431
+ if not docs.short_description and docs.long_description:
1432
+ doc_summary = docs.long_description or ''
1433
+ else:
1434
+ doc_summary = docs.short_description or ''
1435
+ doc_long_description = docs.long_description or ''
1436
+ for ex in docs.examples:
1437
+ ex_dict: Dict[str, Any] = {
1438
+ 'description': None,
1439
+ 'code': None,
1440
+ 'output': None,
1441
+ }
1442
+ if ex.description:
1443
+ ex_dict['description'] = ex.description
1444
+ if ex.snippet:
1445
+ code, output = [], []
1446
+ for line in ex.snippet.split('\n'):
1447
+ line = line.rstrip()
1448
+ if re.match(r'^(\w+>|>>>|\.\.\.)', line):
1449
+ code.append(line)
1450
+ else:
1451
+ output.append(line)
1452
+ ex_dict['code'] = '\n'.join(code) or None
1453
+ ex_dict['output'] = '\n'.join(output) or None
1454
+ if ex.post_snippet:
1455
+ ex_dict['postscript'] = ex.post_snippet
1456
+ doc_examples.append(ex_dict)
1457
+
1458
+ except Exception as e:
1459
+ self.logger.warning(
1460
+ 'Could not parse docstring for function',
1461
+ extra={
1462
+ 'app_name': self.name,
1463
+ 'function_name': key.decode('utf-8'),
1464
+ 'error': str(e),
1465
+ },
1466
+ )
1467
+
1468
+ if not func_name or key == func_name:
1469
+ sig = info['signature']
1470
+ args = []
1471
+
1472
+ # Function arguments
1473
+ for i, a in enumerate(sig.get('args', [])):
1474
+ name = a['name']
1475
+ dtype = a['dtype']
1476
+ nullable = '?' in dtype
1477
+ args.append(
1478
+ dict(
1479
+ name=name,
1480
+ dtype=dtype.replace('?', ''),
1481
+ nullable=nullable,
1482
+ description=(doc_params[name].description or '')
1483
+ if name in doc_params else '',
1484
+ ),
1485
+ )
1486
+ if a.get('default', no_default) is not no_default:
1487
+ args[-1]['default'] = a['default']
1488
+
1489
+ # Return values
1490
+ ret = sig.get('returns', [])
1491
+ returns = []
1492
+
1493
+ for a in ret:
1494
+ dtype = a['dtype']
1495
+ nullable = '?' in dtype
1496
+ returns.append(
1497
+ dict(
1498
+ dtype=dtype.replace('?', ''),
1499
+ nullable=nullable,
1500
+ description=doc_returns.description
1501
+ if doc_returns else '',
1502
+ ),
1503
+ )
1504
+ if a.get('name', None):
1505
+ returns[-1]['name'] = a['name']
1506
+ if a.get('default', no_default) is not no_default:
1507
+ returns[-1]['default'] = a['default']
1508
+
1509
+ sql = sql_map.get(sig['name'], '')
1510
+ functions[sig['name']] = dict(
1511
+ args=args,
1512
+ returns=returns,
1513
+ function_type=info['function_type'],
1514
+ sql_statement=sql,
1515
+ summary=doc_summary,
1516
+ long_description=doc_long_description,
1517
+ examples=doc_examples,
1518
+ )
1519
+
1520
+ return functions
1521
+
1522
+ def get_create_functions(
1523
+ self,
1524
+ replace: bool = False,
1525
+ ) -> List[str]:
1526
+ """
1527
+ Generate CREATE FUNCTION code for all functions.
1528
+
1529
+ Parameters
1530
+ ----------
1531
+ replace : bool, optional
1532
+ Should existing functions be replaced?
1533
+
1534
+ Returns
1535
+ -------
1536
+ List[str]
1537
+
1538
+ """
1539
+ if not self.endpoints:
1540
+ return []
1541
+
1542
+ out = []
1543
+ link = ''
1544
+ if self.app_mode.lower() == 'remote':
1545
+ link, link_str = self._create_link(self.link_config, self.link_credentials)
1546
+ if link and link_str:
1547
+ out.append(link_str)
1548
+
1549
+ for key, (endpoint, endpoint_info) in self.endpoints.items():
1550
+ out.append(
1551
+ signature_to_sql(
1552
+ endpoint_info['signature'],
1553
+ url=self.url,
1554
+ data_format=self.data_format,
1555
+ app_mode=self.app_mode,
1556
+ replace=replace,
1557
+ link=link or None,
1558
+ database=self.function_database or None,
1559
+ ),
1560
+ )
1561
+
1562
+ return out
1563
+
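+ # Editor's note: illustrative sketch, not part of the original source.
+ # Assuming an already-constructed `app = Application(...)`, the generated
+ # DDL can be previewed without touching the database:
+ #
+ #     for stmt in app.get_create_functions(replace=True):
+ #         print(stmt)    # one CREATE [OR REPLACE] ... FUNCTION per endpoint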
1564
+ def register_functions(
1565
+ self,
1566
+ *connection_args: Any,
1567
+ replace: bool = False,
1568
+ **connection_kwargs: Any,
1569
+ ) -> None:
1570
+ """
1571
+ Register functions with the database.
1572
+
1573
+ Parameters
1574
+ ----------
1575
+ *connection_args : Any
1576
+ Database connection parameters
1577
+ replace : bool, optional
1578
+ Should existing functions be replaced?
1579
+ **connection_kwargs : Any
1580
+ Database connection parameters
1581
+
1582
+ """
1583
+ with connection.connect(*connection_args, **connection_kwargs) as conn:
1584
+ with conn.cursor() as cur:
1585
+ if replace:
1586
+ funcs, links = self._locate_app_functions(cur)
1587
+ for fname in funcs:
1588
+ cur.execute(f'DROP FUNCTION IF EXISTS {fname}')
1589
+ for link in links:
1590
+ cur.execute(f'DROP LINK {link}')
1591
+ for func in self.get_create_functions(replace=replace):
1592
+ cur.execute(func)
1593
+
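+ # Editor's note: illustrative sketch, not part of the original source.
+ # `register_functions` forwards its arguments to `connection.connect`, so a
+ # typical call could pass a connection URL (credentials are hypothetical):
+ #
+ #     app.register_functions('user:pass@host:3306/dbname', replace=True)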
1594
+ def drop_functions(
1595
+ self,
1596
+ *connection_args: Any,
1597
+ **connection_kwargs: Any,
1598
+ ) -> None:
1599
+ """
1600
+ Drop registered functions from database.
1601
+
1602
+ Parameters
1603
+ ----------
1604
+ *connection_args : Any
1605
+ Database connection parameters
1606
+ **connection_kwargs : Any
1607
+ Database connection parameters
1608
+
1609
+ """
1610
+ with connection.connect(*connection_args, **connection_kwargs) as conn:
1611
+ with conn.cursor() as cur:
1612
+ funcs, links = self._locate_app_functions(cur)
1613
+ for fname in funcs:
1614
+ cur.execute(f'DROP FUNCTION IF EXISTS {fname}')
1615
+ for link in links:
1616
+ cur.execute(f'DROP LINK {link}')
1617
+
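+ # Editor's note: illustrative sketch, not part of the original source.
+ # The inverse of `register_functions`, using the same hypothetical URL:
+ #
+ #     app.drop_functions('user:pass@host:3306/dbname')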
1618
+ async def call(
1619
+ self,
1620
+ name: str,
1621
+ data_in: io.BytesIO,
1622
+ data_out: io.BytesIO,
1623
+ data_format: Optional[str] = None,
1624
+ data_version: Optional[str] = None,
1625
+ ) -> None:
1626
+ """
1627
+ Call a function in the application.
1628
+
1629
+ Parameters
1630
+ ----------
1631
+ name : str
1632
+ Name of the function to call
1633
+ data_in : io.BytesIO
1634
+ The input data rows
1635
+ data_out : io.BytesIO
1636
+ The output data rows
1637
+ data_format : str, optional
1638
+ The format of the input and output data
1639
+ data_version : str, optional
1640
+ The version of the data format
1641
+
1642
+ """
1643
+ data_format = data_format or self.data_format
1644
+ data_version = data_version or self.data_version
1645
+
1646
+ async def receive() -> Dict[str, Any]:
1647
+ return dict(body=data_in.read())
1648
+
1649
+ async def send(content: Dict[str, Any]) -> None:
1650
+ status = content.get('status', 200)
1651
+ if status != 200:
1652
+ raise KeyError(f'an error occurred when calling `{name}`: {status}')
1653
+ data_out.write(content.get('body', b''))
1654
+
1655
+ accepts = dict(
1656
+ json=b'application/json',
1657
+ rowdat_1=b'application/octet-stream',
1658
+ arrow=b'application/vnd.apache.arrow.file',
1659
+ )
1660
+
1661
+ # Mock an ASGI scope
1662
+ scope = dict(
1663
+ type='http',
1664
+ path='invoke',
1665
+ method='POST',
1666
+ headers={
1667
+ b'content-type': accepts[data_format.lower()],
1668
+ b'accepts': accepts[data_format.lower()],
1669
+ b's2-ef-name': name.encode('utf-8'),
1670
+ b's2-ef-version': data_version.encode('utf-8'),
1671
+ b's2-ef-ignore-cancel': b'true',
1672
+ },
1673
+ )
1674
+
1675
+ await self(scope, receive, send)
1676
+
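+ # Editor's note: illustrative sketch, not part of the original source.
+ # `call` drives a function in-process by mocking an ASGI request, which is
+ # handy in tests. Assuming a registered function named `add_one` and the
+ # JSON data format (the exact row payload shape is indicative only):
+ #
+ #     import asyncio, io, json
+ #     buf_in = io.BytesIO(json.dumps({'data': [[1, 10]]}).encode('utf-8'))
+ #     buf_out = io.BytesIO()
+ #     asyncio.run(app.call('add_one', buf_in, buf_out, data_format='json'))
+ #     print(buf_out.getvalue())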
1677
+ def to_environment(
1678
+ self,
1679
+ name: str,
1680
+ destination: str = '.',
1681
+ version: Optional[str] = None,
1682
+ dependencies: Optional[List[str]] = None,
1683
+ authors: Optional[List[Dict[str, str]]] = None,
1684
+ maintainers: Optional[List[Dict[str, str]]] = None,
1685
+ description: Optional[str] = None,
1686
+ container_service: Optional[Dict[str, Any]] = None,
1687
+ external_function: Optional[Dict[str, Any]] = None,
1688
+ external_function_remote: Optional[Dict[str, Any]] = None,
1689
+ external_function_collocated: Optional[Dict[str, Any]] = None,
1690
+ overwrite: bool = False,
1691
+ ) -> None:
1692
+ """
1693
+ Convert application to an environment file.
1694
+
1695
+ Parameters
1696
+ ----------
1697
+ name : str
1698
+ Name of the output environment
1699
+ destination : str, optional
1700
+ Location of the output file
1701
+ version : str, optional
1702
+ Version of the package
1703
+ dependencies : List[str], optional
1704
+ List of dependency specifications, as in a requirements.txt file
1705
+ authors : List[Dict[str, str]], optional
1706
+ Dictionaries of author information. Keys may include: email, name
1707
+ maintainers : List[Dict[str, str]], optional
1708
+ Dictionaries of maintainer information. Keys may include: email, name
1709
+ description : str, optional
1710
+ Description of package
1711
+ container_service : Dict[str, Any], optional
1712
+ Container service specifications
1713
+ external_function : Dict[str, Any], optional
1714
+ External function specifications (applies to both remote and collocated)
1715
+ external_function_remote : Dict[str, Any], optional
1716
+ Remote external function specifications
1717
+ external_function_collocated : Dict[str, Any], optional
1718
+ Collocated external function specifications
1719
+ overwrite : bool, optional
1720
+ Should the destination file be overwritten if it exists?
1721
+
1722
+ """
1723
+ if not has_cloudpickle:
1724
+ raise RuntimeError('the cloudpickle package is required for this operation')
1725
+
1726
+ # Write to temporary location if a remote destination is specified
1727
+ tmpdir = None
1728
+ if destination.startswith('stage://'):
1729
+ tmpdir = tempfile.TemporaryDirectory()
1730
+ local_path = os.path.join(tmpdir.name, f'{name}.env')
1731
+ else:
1732
+ local_path = os.path.join(destination, f'{name}.env')
1733
+ if not overwrite and os.path.exists(local_path):
1734
+ raise OSError(f'path already exists: {local_path}')
1735
+
1736
+ with zipfile.ZipFile(local_path, mode='w') as z:
1737
+ # Write metadata
1738
+ z.writestr(
1739
+ 'pyproject.toml', utils.to_toml({
1740
+ 'project': dict(
1741
+ name=name,
1742
+ version=version,
1743
+ dependencies=dependencies,
1744
+ requires_python='== ' +
1745
+ '.'.join(str(x) for x in sys.version_info[:3]),
1746
+ authors=authors,
1747
+ maintainers=maintainers,
1748
+ description=description,
1749
+ ),
1750
+ 'tool.container-service': container_service,
1751
+ 'tool.external-function': external_function,
1752
+ 'tool.external-function.remote': external_function_remote,
1753
+ 'tool.external-function.collocated': external_function_collocated,
1754
+ }),
1755
+ )
1756
+
1757
+ # Write Python package
1758
+ z.writestr(
1759
+ f'{name}/__init__.py',
1760
+ textwrap.dedent(f'''
1761
+ import pickle as _pkl
1762
+ globals().update(
1763
+ _pkl.loads({cloudpickle.dumps(self.external_functions)}),
1764
+ )
1765
+ __all__ = {list(self.external_functions.keys())}''').strip(),
1766
+ )
1767
+
1768
+ # Upload to Stage as needed
1769
+ if destination.startswith('stage://'):
1770
+ url = urllib.parse.urlparse(re.sub(r'/+$', r'', destination) + '/')
1771
+ if not url.path or url.path == '/':
1772
+ raise ValueError(f'no stage path was specified: {destination}')
1773
+
1774
+ mgr = manage_workspaces()
1775
+ if url.hostname:
1776
+ wsg = mgr.get_workspace_group(url.hostname)
1777
+ elif os.environ.get('SINGLESTOREDB_WORKSPACE_GROUP'):
1778
+ wsg = mgr.get_workspace_group(
1779
+ os.environ['SINGLESTOREDB_WORKSPACE_GROUP'],
1780
+ )
1781
+ else:
1782
+ raise ValueError(f'no workspace group specified: {destination}')
1783
+
1784
+ # Make intermediate directories
1785
+ if url.path.count('/') > 1:
1786
+ wsg.stage.mkdirs(os.path.dirname(url.path))
1787
+
1788
+ wsg.stage.upload_file(
1789
+ local_path, url.path + f'{name}.env',
1790
+ overwrite=overwrite,
1791
+ )
1792
+ os.remove(local_path)
1793
+
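+ # Editor's note: illustrative sketch, not part of the original source.
+ # Packaging the app's functions into an environment file, either locally or
+ # straight into a Stage path (the workspace group name is hypothetical):
+ #
+ #     app.to_environment('my_udfs', destination='.', overwrite=True)
+ #     app.to_environment('my_udfs', destination='stage://my-group/envs/',
+ #                        overwrite=True)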
1794
+
1795
+ def main(argv: Optional[List[str]] = None) -> None:
1796
+ """
1797
+ Main program for HTTP-based Python UDFs.
1798
+
1799
+ Parameters
1800
+ ----------
1801
+ argv : List[str], optional
1802
+ List of command-line parameters
1803
+
1804
+ """
1805
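+ # Editor's note: illustrative sketch, not part of the original source.
+ # A typical invocation (module path and connection string are hypothetical):
+ #
+ #     python -m singlestoredb.functions.ext.asgi \
+ #         --db 'user:pass@host:3306/dbname' \
+ #         --data-format rowdat_1 --replace-existing \
+ #         my_package.my_udfs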
+ try:
1806
+ import uvicorn
1807
+ except ImportError:
1808
+ raise ImportError('the uvicorn package is required to run this command')
1809
+
1810
+ # Should we run in embedded mode (typically for Jupyter)?
1811
+ try:
1812
+ asyncio.get_running_loop()
1813
+ use_async = True
1814
+ except RuntimeError:
1815
+ use_async = False
1816
+
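+ # Editor's note (not part of the original source): in a Jupyter kernel an
+ # event loop is already running, so `asyncio.get_running_loop()` succeeds
+ # above and the server is later started with
+ # `asyncio.create_task(_run_uvicorn(...))` instead of blocking in
+ # `uvicorn.run(...)`.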
1817
+ # Temporary directory for Stage environment files
1818
+ tmpdir = None
1819
+
1820
+ # If an environment file is specified, we may have to process the
1821
+ # command line twice, using its settings as the new defaults.
1822
+ functions = []
1823
+ defaults: Dict[str, Any] = {}
1824
+ for i in range(2):
1825
+
1826
+ parser = argparse.ArgumentParser(
1827
+ prog='python -m singlestoredb.functions.ext.asgi',
1828
+ description='Run an HTTP-based Python UDF server',
1829
+ )
1830
+ parser.add_argument(
1831
+ '--url', metavar='url',
1832
+ default=defaults.get(
1833
+ 'url',
1834
+ get_option('external_function.url'),
1835
+ ),
1836
+ help='URL of the UDF server endpoint',
1837
+ )
1838
+ parser.add_argument(
1839
+ '--host', metavar='host',
1840
+ default=defaults.get(
1841
+ 'host',
1842
+ get_option('external_function.host'),
1843
+ ),
1844
+ help='bind socket to this host',
1845
+ )
1846
+ parser.add_argument(
1847
+ '--port', metavar='port', type=int,
1848
+ default=defaults.get(
1849
+ 'port',
1850
+ get_option('external_function.port'),
1851
+ ),
1852
+ help='bind socket to this port',
1853
+ )
1854
+ parser.add_argument(
1855
+ '--db', metavar='conn-str',
1856
+ default=defaults.get(
1857
+ 'connection',
1858
+ get_option('external_function.connection'),
1859
+ ),
1860
+ help='connection string to use for registering functions',
1861
+ )
1862
+ parser.add_argument(
1863
+ '--replace-existing', action='store_true',
1864
+ help='should existing functions of the same name '
1865
+ 'in the database be replaced?',
1866
+ )
1867
+ parser.add_argument(
1868
+ '--data-format', metavar='format',
1869
+ default=defaults.get(
1870
+ 'data_format',
1871
+ get_option('external_function.data_format'),
1872
+ ),
1873
+ choices=['rowdat_1', 'json'],
1874
+ help='format of the data rows',
1875
+ )
1876
+ parser.add_argument(
1877
+ '--data-version', metavar='version',
1878
+ default=defaults.get(
1879
+ 'data_version',
1880
+ get_option('external_function.data_version'),
1881
+ ),
1882
+ help='version of the data row format',
1883
+ )
1884
+ parser.add_argument(
1885
+ '--link-name', metavar='name',
1886
+ default=defaults.get(
1887
+ 'link_name',
1888
+ get_option('external_function.link_name'),
1889
+ ) or '',
1890
+ help='name of the link to use for connections',
1891
+ )
1892
+ parser.add_argument(
1893
+ '--link-config', metavar='json',
1894
+ default=str(
1895
+ defaults.get(
1896
+ 'link_config',
1897
+ get_option('external_function.link_config'),
1898
+ ) or '{}',
1899
+ ),
1900
+ help='link config in JSON format',
1901
+ )
1902
+ parser.add_argument(
1903
+ '--link-credentials', metavar='json',
1904
+ default=str(
1905
+ defaults.get(
1906
+ 'link_credentials',
1907
+ get_option('external_function.link_credentials'),
1908
+ ) or '{}',
1909
+ ),
1910
+ help='link credentials in JSON format',
1911
+ )
1912
+ parser.add_argument(
1913
+ '--log-level', metavar='[info|debug|warning|error]',
1914
+ default=defaults.get(
1915
+ 'log_level',
1916
+ get_option('external_function.log_level'),
1917
+ ),
1918
+ help='logging level',
1919
+ )
1920
+ parser.add_argument(
1921
+ '--log-file', metavar='filepath',
1922
+ default=defaults.get(
1923
+ 'log_file',
1924
+ get_option('external_function.log_file'),
1925
+ ),
1926
+ help='file path to write logs to instead of the console',
1927
+ )
1928
+ parser.add_argument(
1929
+ '--disable-metrics', action='store_true',
1930
+ default=defaults.get(
1931
+ 'disable_metrics',
1932
+ get_option('external_function.disable_metrics'),
1933
+ ),
1934
+ help='disable logging of function call metrics',
1935
+ )
1936
+ parser.add_argument(
1937
+ '--name-prefix', metavar='name_prefix',
1938
+ default=defaults.get(
1939
+ 'name_prefix',
1940
+ get_option('external_function.name_prefix'),
1941
+ ),
1942
+ help='prefix to add to function names',
1943
+ )
1944
+ parser.add_argument(
1945
+ '--name-suffix', metavar='name_suffix',
1946
+ default=defaults.get(
1947
+ 'name_suffix',
1948
+ get_option('external_function.name_suffix'),
1949
+ ),
1950
+ help='suffix to add to function names',
1951
+ )
1952
+ parser.add_argument(
1953
+ '--function-database', metavar='function_database',
1954
+ default=defaults.get(
1955
+ 'function_database',
1956
+ get_option('external_function.function_database'),
1957
+ ),
1958
+ help='database to use for the function definition',
1959
+ )
1960
+ parser.add_argument(
1961
+ '--app-name', metavar='app_name',
1962
+ default=defaults.get(
1963
+ 'app_name',
1964
+ get_option('external_function.app_name'),
1965
+ ),
1966
+ help='name for the application instance',
1967
+ )
1968
+ parser.add_argument(
1969
+ 'functions', metavar='module.or.func.path', nargs='*',
1970
+ help='functions or modules to export in UDF server',
1971
+ )
1972
+
1973
+ args = parser.parse_args(argv)
1974
+
1975
+ if i > 0:
1976
+ break
1977
+
1978
+ # Download Stage files as needed
1979
+ for j, f in enumerate(args.functions):
1980
+ if f.startswith('stage://'):
1981
+ url = urllib.parse.urlparse(f)
1982
+ if not url.path or url.path == '/':
1983
+ raise ValueError(f'no stage path was specified: {f}')
1984
+ if url.path.endswith('/'):
1985
+ raise ValueError(f'an environment file must be specified: {f}')
1986
+
1987
+ mgr = manage_workspaces()
1988
+ if url.hostname:
1989
+ wsg = mgr.get_workspace_group(url.hostname)
1990
+ elif os.environ.get('SINGLESTOREDB_WORKSPACE_GROUP'):
1991
+ wsg = mgr.get_workspace_group(
1992
+ os.environ['SINGLESTOREDB_WORKSPACE_GROUP'],
1993
+ )
1994
+ else:
1995
+ raise ValueError(f'no workspace group specified: {f}')
1996
+
1997
+ if tmpdir is None:
1998
+ tmpdir = tempfile.TemporaryDirectory()
1999
+
2000
+ local_path = os.path.join(tmpdir.name, url.path.split('/')[-1])
2001
+ wsg.stage.download_file(url.path, local_path)
2002
+ args.functions[j] = local_path
2003
+
2004
+ elif f.startswith('http://') or f.startswith('https://'):
2005
+ if tmpdir is None:
2006
+ tmpdir = tempfile.TemporaryDirectory()
2007
+
2008
+ local_path = os.path.join(tmpdir.name, f.split('/')[-1])
2009
+ urllib.request.urlretrieve(f, local_path)
2010
+ args.functions[j] = local_path
2011
+
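+ # Editor's note (not part of the original source): at this point every
+ # positional argument refers to something local, since stage:// and
+ # http(s):// references were downloaded above; what remains is either an
+ # import path such as `my_package.my_udfs` or a local environment (.env zip)
+ # file, which the zipfile check below separates out.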
2012
+ # See if any of the args are zip files (assume they are environment files)
2013
+ modules = [(x, zipfile.is_zipfile(x)) for x in args.functions]
2014
+ envs = [x[0] for x in modules if x[1]]
2015
+ others = [x[0] for x in modules if not x[1]]
2016
+
2017
+ if envs and len(envs) > 1:
2018
+ raise RuntimeError('only one environment file may be specified')
2019
+
2020
+ if envs and others:
2021
+ raise RuntimeError('environment files and other modules cannot be mixed')
2022
+
2023
+ # See if an environment file was specified. If so, use those settings
2024
+ # as the defaults and reprocess the command line.
2025
+ if envs:
2026
+ # Add pyproject.toml variables and redo command-line processing
2027
+ defaults = utils.read_config(
2028
+ envs[0],
2029
+ ['tool.external-function', 'tool.external-function.remote'],
2030
+ )
2031
+
2032
+ # Load zip file as a module
2033
+ modname = os.path.splitext(os.path.basename(envs[0]))[0]
2034
+ zi = zipimport.zipimporter(envs[0])
2035
+ mod = zi.load_module(modname)
2036
+ if mod is None:
2037
+ raise RuntimeError(f'environment file could not be imported: {envs[0]}')
2038
+ functions = [mod]
2039
+
2040
+ if defaults:
2041
+ continue
2042
+
2043
+ args.functions = functions or args.functions or None
2044
+ args.replace_existing = args.replace_existing \
2045
+ or defaults.get('replace_existing') \
2046
+ or get_option('external_function.replace_existing')
2047
+
2048
+ # Substitute in host / port if specified
2049
+ if args.host != defaults.get('host') or args.port != defaults.get('port'):
2050
+ u = urllib.parse.urlparse(args.url)
2051
+ args.url = u._replace(netloc=f'{args.host}:{args.port}').geturl()
2052
+
2053
+ # Create application from functions / module
2054
+ app = Application(
2055
+ functions=args.functions,
2056
+ url=args.url,
2057
+ data_format=args.data_format,
2058
+ data_version=args.data_version,
2059
+ link_name=args.link_name or None,
2060
+ link_config=json.loads(args.link_config) or None,
2061
+ link_credentials=json.loads(args.link_credentials) or None,
2062
+ app_mode='remote',
2063
+ name_prefix=args.name_prefix,
2064
+ name_suffix=args.name_suffix,
2065
+ function_database=args.function_database or None,
2066
+ log_file=args.log_file,
2067
+ log_level=args.log_level,
2068
+ disable_metrics=args.disable_metrics,
2069
+ app_name=args.app_name,
2070
+ )
2071
+
2072
+ funcs = app.get_create_functions(replace=args.replace_existing)
2073
+ if not funcs:
2074
+ raise RuntimeError('no functions specified')
2075
+
2076
+ for f in funcs:
2077
+ app.logger.info(f)
2078
+
2079
+ try:
2080
+ if args.db:
2081
+ app.logger.info('Registering functions with database')
2082
+ app.register_functions(
2083
+ args.db,
2084
+ replace=args.replace_existing,
2085
+ )
2086
+
2087
+ app_args = {
2088
+ k: v for k, v in dict(
2089
+ host=args.host or None,
2090
+ port=args.port or None,
2091
+ log_level=args.log_level,
2092
+ lifespan='off',
2093
+ ).items() if v is not None
2094
+ }
2095
+
2096
+ # Configure uvicorn logging to use the same JSON format as the Application logger
2097
+ app_args['log_config'] = app.get_uvicorn_log_config()
2098
+
2099
+ if use_async:
2100
+ asyncio.create_task(_run_uvicorn(uvicorn, app, app_args, db=args.db))
2101
+ else:
2102
+ uvicorn.run(app, **app_args)
2103
+
2104
+ finally:
2105
+ if not use_async and args.db:
2106
+ app.logger.info('Dropping functions from database')
2107
+ app.drop_functions(args.db)
2108
+
2109
+
2110
+ async def _run_uvicorn(
2111
+ uvicorn: Any,
2112
+ app: Any,
2113
+ app_args: Any,
2114
+ db: Optional[str] = None,
2115
+ ) -> None:
2116
+ """Run uvicorn server and clean up functions after shutdown."""
2117
+ await uvicorn.Server(uvicorn.Config(app, **app_args)).serve()
2118
+ if db:
2119
+ app.logger.info('Dropping functions from database')
2120
+ app.drop_functions(db)
2121
+
2122
+
2123
+ create_app = Application
2124
+
2125
+
2126
+ if __name__ == '__main__':
2127
+ try:
2128
+ main()
2129
+ except RuntimeError as exc:
2130
+ logger.error(str(exc))
2131
+ sys.exit(1)
2132
+ except KeyboardInterrupt:
2133
+ pass