singlestoredb 1.16.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- singlestoredb/__init__.py +75 -0
- singlestoredb/ai/__init__.py +2 -0
- singlestoredb/ai/chat.py +139 -0
- singlestoredb/ai/embeddings.py +128 -0
- singlestoredb/alchemy/__init__.py +90 -0
- singlestoredb/apps/__init__.py +3 -0
- singlestoredb/apps/_cloud_functions.py +90 -0
- singlestoredb/apps/_config.py +72 -0
- singlestoredb/apps/_connection_info.py +18 -0
- singlestoredb/apps/_dashboards.py +47 -0
- singlestoredb/apps/_process.py +32 -0
- singlestoredb/apps/_python_udfs.py +100 -0
- singlestoredb/apps/_stdout_supress.py +30 -0
- singlestoredb/apps/_uvicorn_util.py +36 -0
- singlestoredb/auth.py +245 -0
- singlestoredb/config.py +484 -0
- singlestoredb/connection.py +1487 -0
- singlestoredb/converters.py +950 -0
- singlestoredb/docstring/__init__.py +33 -0
- singlestoredb/docstring/attrdoc.py +126 -0
- singlestoredb/docstring/common.py +230 -0
- singlestoredb/docstring/epydoc.py +267 -0
- singlestoredb/docstring/google.py +412 -0
- singlestoredb/docstring/numpydoc.py +562 -0
- singlestoredb/docstring/parser.py +100 -0
- singlestoredb/docstring/py.typed +1 -0
- singlestoredb/docstring/rest.py +256 -0
- singlestoredb/docstring/tests/__init__.py +1 -0
- singlestoredb/docstring/tests/_pydoctor.py +21 -0
- singlestoredb/docstring/tests/test_epydoc.py +729 -0
- singlestoredb/docstring/tests/test_google.py +1007 -0
- singlestoredb/docstring/tests/test_numpydoc.py +1100 -0
- singlestoredb/docstring/tests/test_parse_from_object.py +109 -0
- singlestoredb/docstring/tests/test_parser.py +248 -0
- singlestoredb/docstring/tests/test_rest.py +547 -0
- singlestoredb/docstring/tests/test_util.py +70 -0
- singlestoredb/docstring/util.py +141 -0
- singlestoredb/exceptions.py +120 -0
- singlestoredb/functions/__init__.py +16 -0
- singlestoredb/functions/decorator.py +201 -0
- singlestoredb/functions/dtypes.py +1793 -0
- singlestoredb/functions/ext/__init__.py +1 -0
- singlestoredb/functions/ext/arrow.py +375 -0
- singlestoredb/functions/ext/asgi.py +2133 -0
- singlestoredb/functions/ext/json.py +420 -0
- singlestoredb/functions/ext/mmap.py +413 -0
- singlestoredb/functions/ext/rowdat_1.py +724 -0
- singlestoredb/functions/ext/timer.py +89 -0
- singlestoredb/functions/ext/utils.py +218 -0
- singlestoredb/functions/signature.py +1578 -0
- singlestoredb/functions/typing/__init__.py +41 -0
- singlestoredb/functions/typing/numpy.py +20 -0
- singlestoredb/functions/typing/pandas.py +2 -0
- singlestoredb/functions/typing/polars.py +2 -0
- singlestoredb/functions/typing/pyarrow.py +2 -0
- singlestoredb/functions/utils.py +421 -0
- singlestoredb/fusion/__init__.py +11 -0
- singlestoredb/fusion/graphql.py +213 -0
- singlestoredb/fusion/handler.py +916 -0
- singlestoredb/fusion/handlers/__init__.py +0 -0
- singlestoredb/fusion/handlers/export.py +525 -0
- singlestoredb/fusion/handlers/files.py +690 -0
- singlestoredb/fusion/handlers/job.py +660 -0
- singlestoredb/fusion/handlers/models.py +250 -0
- singlestoredb/fusion/handlers/stage.py +502 -0
- singlestoredb/fusion/handlers/utils.py +324 -0
- singlestoredb/fusion/handlers/workspace.py +956 -0
- singlestoredb/fusion/registry.py +249 -0
- singlestoredb/fusion/result.py +399 -0
- singlestoredb/http/__init__.py +27 -0
- singlestoredb/http/connection.py +1267 -0
- singlestoredb/magics/__init__.py +34 -0
- singlestoredb/magics/run_personal.py +137 -0
- singlestoredb/magics/run_shared.py +134 -0
- singlestoredb/management/__init__.py +9 -0
- singlestoredb/management/billing_usage.py +148 -0
- singlestoredb/management/cluster.py +462 -0
- singlestoredb/management/export.py +295 -0
- singlestoredb/management/files.py +1102 -0
- singlestoredb/management/inference_api.py +105 -0
- singlestoredb/management/job.py +887 -0
- singlestoredb/management/manager.py +373 -0
- singlestoredb/management/organization.py +226 -0
- singlestoredb/management/region.py +169 -0
- singlestoredb/management/utils.py +423 -0
- singlestoredb/management/workspace.py +1927 -0
- singlestoredb/mysql/__init__.py +177 -0
- singlestoredb/mysql/_auth.py +298 -0
- singlestoredb/mysql/charset.py +214 -0
- singlestoredb/mysql/connection.py +2032 -0
- singlestoredb/mysql/constants/CLIENT.py +38 -0
- singlestoredb/mysql/constants/COMMAND.py +32 -0
- singlestoredb/mysql/constants/CR.py +78 -0
- singlestoredb/mysql/constants/ER.py +474 -0
- singlestoredb/mysql/constants/EXTENDED_TYPE.py +3 -0
- singlestoredb/mysql/constants/FIELD_TYPE.py +48 -0
- singlestoredb/mysql/constants/FLAG.py +15 -0
- singlestoredb/mysql/constants/SERVER_STATUS.py +10 -0
- singlestoredb/mysql/constants/VECTOR_TYPE.py +6 -0
- singlestoredb/mysql/constants/__init__.py +0 -0
- singlestoredb/mysql/converters.py +271 -0
- singlestoredb/mysql/cursors.py +896 -0
- singlestoredb/mysql/err.py +92 -0
- singlestoredb/mysql/optionfile.py +20 -0
- singlestoredb/mysql/protocol.py +450 -0
- singlestoredb/mysql/tests/__init__.py +19 -0
- singlestoredb/mysql/tests/base.py +126 -0
- singlestoredb/mysql/tests/conftest.py +37 -0
- singlestoredb/mysql/tests/test_DictCursor.py +132 -0
- singlestoredb/mysql/tests/test_SSCursor.py +141 -0
- singlestoredb/mysql/tests/test_basic.py +452 -0
- singlestoredb/mysql/tests/test_connection.py +851 -0
- singlestoredb/mysql/tests/test_converters.py +58 -0
- singlestoredb/mysql/tests/test_cursor.py +141 -0
- singlestoredb/mysql/tests/test_err.py +16 -0
- singlestoredb/mysql/tests/test_issues.py +514 -0
- singlestoredb/mysql/tests/test_load_local.py +75 -0
- singlestoredb/mysql/tests/test_nextset.py +88 -0
- singlestoredb/mysql/tests/test_optionfile.py +27 -0
- singlestoredb/mysql/tests/thirdparty/__init__.py +6 -0
- singlestoredb/mysql/tests/thirdparty/test_MySQLdb/__init__.py +9 -0
- singlestoredb/mysql/tests/thirdparty/test_MySQLdb/capabilities.py +323 -0
- singlestoredb/mysql/tests/thirdparty/test_MySQLdb/dbapi20.py +865 -0
- singlestoredb/mysql/tests/thirdparty/test_MySQLdb/test_MySQLdb_capabilities.py +110 -0
- singlestoredb/mysql/tests/thirdparty/test_MySQLdb/test_MySQLdb_dbapi20.py +224 -0
- singlestoredb/mysql/tests/thirdparty/test_MySQLdb/test_MySQLdb_nonstandard.py +101 -0
- singlestoredb/mysql/times.py +23 -0
- singlestoredb/notebook/__init__.py +16 -0
- singlestoredb/notebook/_objects.py +213 -0
- singlestoredb/notebook/_portal.py +352 -0
- singlestoredb/py.typed +0 -0
- singlestoredb/pytest.py +352 -0
- singlestoredb/server/__init__.py +0 -0
- singlestoredb/server/docker.py +452 -0
- singlestoredb/server/free_tier.py +267 -0
- singlestoredb/tests/__init__.py +0 -0
- singlestoredb/tests/alltypes.sql +307 -0
- singlestoredb/tests/alltypes_no_nulls.sql +208 -0
- singlestoredb/tests/empty.sql +0 -0
- singlestoredb/tests/ext_funcs/__init__.py +702 -0
- singlestoredb/tests/local_infile.csv +3 -0
- singlestoredb/tests/test.ipynb +18 -0
- singlestoredb/tests/test.sql +680 -0
- singlestoredb/tests/test2.ipynb +18 -0
- singlestoredb/tests/test2.sql +1 -0
- singlestoredb/tests/test_basics.py +1332 -0
- singlestoredb/tests/test_config.py +318 -0
- singlestoredb/tests/test_connection.py +3103 -0
- singlestoredb/tests/test_dbapi.py +27 -0
- singlestoredb/tests/test_exceptions.py +45 -0
- singlestoredb/tests/test_ext_func.py +1472 -0
- singlestoredb/tests/test_ext_func_data.py +1101 -0
- singlestoredb/tests/test_fusion.py +1527 -0
- singlestoredb/tests/test_http.py +288 -0
- singlestoredb/tests/test_management.py +1599 -0
- singlestoredb/tests/test_plugin.py +33 -0
- singlestoredb/tests/test_results.py +171 -0
- singlestoredb/tests/test_types.py +132 -0
- singlestoredb/tests/test_udf.py +737 -0
- singlestoredb/tests/test_udf_returns.py +459 -0
- singlestoredb/tests/test_vectorstore.py +51 -0
- singlestoredb/tests/test_xdict.py +333 -0
- singlestoredb/tests/utils.py +141 -0
- singlestoredb/types.py +373 -0
- singlestoredb/utils/__init__.py +0 -0
- singlestoredb/utils/config.py +950 -0
- singlestoredb/utils/convert_rows.py +69 -0
- singlestoredb/utils/debug.py +13 -0
- singlestoredb/utils/dtypes.py +205 -0
- singlestoredb/utils/events.py +65 -0
- singlestoredb/utils/mogrify.py +151 -0
- singlestoredb/utils/results.py +585 -0
- singlestoredb/utils/xdict.py +425 -0
- singlestoredb/vectorstore.py +192 -0
- singlestoredb/warnings.py +5 -0
- singlestoredb-1.16.1.dist-info/METADATA +165 -0
- singlestoredb-1.16.1.dist-info/RECORD +183 -0
- singlestoredb-1.16.1.dist-info/WHEEL +5 -0
- singlestoredb-1.16.1.dist-info/entry_points.txt +2 -0
- singlestoredb-1.16.1.dist-info/licenses/LICENSE +201 -0
- singlestoredb-1.16.1.dist-info/top_level.txt +3 -0
- sqlx/__init__.py +4 -0
- sqlx/magic.py +113 -0
|
@@ -0,0 +1,2133 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
"""
|
|
3
|
+
Web application for SingleStoreDB external functions.
|
|
4
|
+
|
|
5
|
+
This module supplies a function that can create web apps intended for use
|
|
6
|
+
with the external function feature of SingleStoreDB. The application
|
|
7
|
+
function is a standard ASGI <https://asgi.readthedocs.io/en/latest/index.html>
|
|
8
|
+
request handler for use with servers such as Uvicorn <https://www.uvicorn.org>.
|
|
9
|
+
|
|
10
|
+
An external function web application can be created using the `create_app`
|
|
11
|
+
function. By default, the exported Python functions are specified by
|
|
12
|
+
environment variables starting with SINGLESTOREDB_EXT_FUNCTIONS. See the
|
|
13
|
+
documentation in `create_app` for the full syntax. If the application is
|
|
14
|
+
created in Python code rather than from the command-line, exported
|
|
15
|
+
functions can be specified in the parameters.
|
|
16
|
+
|
|
17
|
+
An example of starting a server is shown below.
|
|
18
|
+
|
|
19
|
+
Example
|
|
20
|
+
-------
|
|
21
|
+
> SINGLESTOREDB_EXT_FUNCTIONS='myfuncs.[percentile_90,percentile_95]' \
|
|
22
|
+
python3 -m singlestoredb.functions.ext.asgi
|
|
23
|
+
|
|
24
|
+
"""
|
|
25
|
+
import argparse
|
|
26
|
+
import asyncio
|
|
27
|
+
import contextvars
|
|
28
|
+
import dataclasses
|
|
29
|
+
import datetime
|
|
30
|
+
import functools
|
|
31
|
+
import importlib.util
|
|
32
|
+
import inspect
|
|
33
|
+
import io
|
|
34
|
+
import itertools
|
|
35
|
+
import json
|
|
36
|
+
import logging
|
|
37
|
+
import os
|
|
38
|
+
import re
|
|
39
|
+
import secrets
|
|
40
|
+
import sys
|
|
41
|
+
import tempfile
|
|
42
|
+
import textwrap
|
|
43
|
+
import threading
|
|
44
|
+
import time
|
|
45
|
+
import traceback
|
|
46
|
+
import typing
|
|
47
|
+
import urllib
|
|
48
|
+
import uuid
|
|
49
|
+
import zipfile
|
|
50
|
+
import zipimport
|
|
51
|
+
from collections.abc import Awaitable
|
|
52
|
+
from collections.abc import Iterable
|
|
53
|
+
from collections.abc import Sequence
|
|
54
|
+
from types import ModuleType
|
|
55
|
+
from typing import Any
|
|
56
|
+
from typing import Callable
|
|
57
|
+
from typing import Dict
|
|
58
|
+
from typing import List
|
|
59
|
+
from typing import Optional
|
|
60
|
+
from typing import Set
|
|
61
|
+
from typing import Tuple
|
|
62
|
+
from typing import Union
|
|
63
|
+
|
|
64
|
+
from . import arrow
|
|
65
|
+
from . import json as jdata
|
|
66
|
+
from . import rowdat_1
|
|
67
|
+
from . import utils
|
|
68
|
+
from ... import connection
|
|
69
|
+
from ... import manage_workspaces
|
|
70
|
+
from ...config import get_option
|
|
71
|
+
from ...mysql.constants import FIELD_TYPE as ft
|
|
72
|
+
from ..signature import get_signature
|
|
73
|
+
from ..signature import signature_to_sql
|
|
74
|
+
from ..typing import Masked
|
|
75
|
+
from ..typing import Table
|
|
76
|
+
from .timer import Timer
|
|
77
|
+
from singlestoredb.docstring.parser import parse
|
|
78
|
+
from singlestoredb.functions.dtypes import escape_name
|
|
79
|
+
|
|
80
|
+
try:
|
|
81
|
+
import cloudpickle
|
|
82
|
+
has_cloudpickle = True
|
|
83
|
+
except ImportError:
|
|
84
|
+
has_cloudpickle = False
|
|
85
|
+
|
|
86
|
+
try:
|
|
87
|
+
from pydantic import BaseModel
|
|
88
|
+
has_pydantic = True
|
|
89
|
+
except ImportError:
|
|
90
|
+
has_pydantic = False
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
logger = utils.get_logger('singlestoredb.functions.ext.asgi')
|
|
94
|
+
|
|
95
|
+
# If a number of processes is specified, create a pool of workers
# so incoming function calls can be fanned out across processes via a
# starmap-style callable; otherwise fall back to serial, in-process
# itertools.starmap (same call shape, no parallelism).
num_processes = max(0, int(os.environ.get('SINGLESTOREDB_EXT_NUM_PROCESSES', 0)))
if num_processes > 1:
    try:
        # Prefer Ray's multiprocessing-compatible Pool when installed
        from ray.util.multiprocessing import Pool
    except ImportError:
        from multiprocessing import Pool
    # NOTE(review): the Pool is created eagerly at import time and never
    # explicitly closed -- presumably it lives for the process lifetime.
    func_map = Pool(num_processes).starmap
else:
    func_map = itertools.starmap
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
async def to_thread(
    func: Any, /, *args: Any, **kwargs: Any,
) -> Any:
    """
    Run ``func(*args, **kwargs)`` in the default executor thread pool.

    The current ``contextvars`` context is copied so context variables
    set by the caller are visible inside the worker thread (mirrors
    ``asyncio.to_thread``).

    Parameters
    ----------
    func : Callable
        The callable to run in a worker thread
    *args : Any
        Positional arguments passed through to ``func``
    **kwargs : Any
        Keyword arguments passed through to ``func``

    Returns
    -------
    Any
        Whatever ``func`` returns

    """
    loop = asyncio.get_running_loop()
    ctx = contextvars.copy_context()
    # Bind args/kwargs now; run_in_executor only accepts a no-arg callable
    func_call = functools.partial(ctx.run, func, *args, **kwargs)
    return await loop.run_in_executor(None, func_call)
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
# Use negative values to indicate unsigned ints / binary data / usec time precision
# Maps signature dtype names to the MySQL field-type codes used when
# serializing/deserializing the rowdat_1 wire format.
rowdat_1_type_map = {
    # Bool and every signed integer width travel as LONGLONG
    'bool': ft.LONGLONG,
    'int8': ft.LONGLONG,
    'int16': ft.LONGLONG,
    'int32': ft.LONGLONG,
    'int64': ft.LONGLONG,
    # Negated code flags the integer as unsigned
    'uint8': -ft.LONGLONG,
    'uint16': -ft.LONGLONG,
    'uint32': -ft.LONGLONG,
    'uint64': -ft.LONGLONG,
    # Both float widths travel as DOUBLE
    'float32': ft.DOUBLE,
    'float64': ft.DOUBLE,
    'str': ft.STRING,
    # Negated code flags the payload as binary rather than text
    'bytes': -ft.STRING,
}
|
|
132
|
+
|
|
133
|
+
|
|
134
|
+
def get_func_names(funcs: str) -> List[Tuple[str, str]]:
    """
    Parse all function names from string.

    Parameters
    ----------
    funcs : str
        String containing one or more function names. The syntax is
        as follows: [func-name-1@func-alias-1,func-name-2@func-alias-2,...].
        The optional '@name' portion is an alias if you want the function
        to be renamed.

    Returns
    -------
    List[Tuple[str, str]]
        A list of (name, alias) tuples, one per function. When no alias
        was given, the alias equals the name.

    """
    if funcs.startswith('['):
        # Bracketed list form: drop the brackets, split on commas,
        # and trim whitespace around each entry
        func_names = [
            x.strip()
            for x in funcs.replace('[', '').replace(']', '').split(',')
        ]
    else:
        # Single function specification
        func_names = [funcs]

    out = []
    for name in func_names:
        # Without an explicit '@alias', the alias is the name itself
        alias = name
        if '@' in name:
            name, alias = name.split('@', 1)
        out.append((name, alias))

    return out
|
|
166
|
+
|
|
167
|
+
|
|
168
|
+
def as_tuple(x: Any) -> Any:
    """Convert an object into a flat tuple of its values."""
    if has_pydantic and isinstance(x, BaseModel):
        values = x.model_dump().values()
    elif dataclasses.is_dataclass(x):
        # astuple already recurses into the dataclass fields
        return dataclasses.astuple(x)  # type: ignore
    elif isinstance(x, dict):
        values = x.values()
    else:
        values = x
    return tuple(values)
|
|
177
|
+
|
|
178
|
+
|
|
179
|
+
def as_list_of_tuples(x: Any) -> Any:
    """Convert an object to a list of tuples."""
    if isinstance(x, Table):
        x = x[0]
    # Anything that is not a non-empty list/tuple passes through untouched
    if not isinstance(x, (list, tuple)) or len(x) == 0:
        return x
    first = x[0]
    if isinstance(first, (list, tuple)):
        return x
    if has_pydantic and isinstance(first, BaseModel):
        return [tuple(item.model_dump().values()) for item in x]
    if dataclasses.is_dataclass(first):
        return [dataclasses.astuple(item) for item in x]
    if isinstance(first, dict):
        return [tuple(item.values()) for item in x]
    # Scalars become one-element rows
    return [(item,) for item in x]
|
|
194
|
+
|
|
195
|
+
|
|
196
|
+
def get_dataframe_columns(df: Any) -> List[Any]:
    """Return columns of data from a dataframe/table-like object."""
    if isinstance(df, Table):
        # Multi-element Tables are already a collection of columns
        if len(df) != 1:
            return list(df)
        df = df[0]

    if isinstance(df, Masked):
        return [df]
    if isinstance(df, tuple):
        return list(df)

    # Duck-type on the class name so pandas / polars / pyarrow objects
    # are handled without importing any of them here
    rtype = str(type(df)).lower()
    if 'dataframe' in rtype:
        return [df[name] for name in df.columns]
    if 'table' in rtype:
        return df.columns
    if 'series' in rtype or 'array' in rtype:
        return [df]
    if 'tuple' in rtype:
        return list(df)

    raise TypeError(
        'Unsupported data type for dataframe columns: '
        f'{rtype}',
    )
|
|
226
|
+
|
|
227
|
+
|
|
228
|
+
def get_array_class(data_format: str) -> Callable[..., Any]:
    """
    Return the array constructor for the requested data format.

    Imports are done lazily so only the selected backend needs to be
    installed; anything other than 'polars', 'arrow', or 'pandas'
    falls back to numpy.

    """
    if data_format == 'polars':
        import polars as pl
        return pl.Series
    if data_format == 'arrow':
        import pyarrow as pa
        return pa.array
    if data_format == 'pandas':
        import pandas as pd
        return pd.Series
    import numpy as np
    return np.array
|
|
246
|
+
|
|
247
|
+
|
|
248
|
+
def get_masked_params(func: Callable[..., Any]) -> List[bool]:
    """
    Return which parameters of the function are masked.

    Parameters
    ----------
    func : Callable
        The function to call as the endpoint

    Returns
    -------
    List[bool]
        One flag per parameter; True where the annotation is Masked[...]

    """
    flags: List[bool] = []
    for param in inspect.signature(func).parameters.values():
        flags.append(typing.get_origin(param.annotation) is Masked)
    return flags
|
|
265
|
+
|
|
266
|
+
|
|
267
|
+
def build_tuple(x: Any) -> Any:
    """Return *x* in (value, mask) tuple form; Masked already carries both."""
    if isinstance(x, Masked):
        return tuple(x)
    return (x, None)
|
|
270
|
+
|
|
271
|
+
|
|
272
|
+
def cancel_on_event(
    cancel_event: threading.Event,
) -> None:
    """
    Abort the in-progress function call if the cancel event is set.

    Parameters
    ----------
    cancel_event : threading.Event
        The event to check for cancellation

    Raises
    ------
    asyncio.CancelledError
        If the cancel event is set

    """
    if not cancel_event.is_set():
        return
    # Cancel the surrounding task (when running inside one), then raise
    # immediately so the caller unwinds right away.
    task = asyncio.current_task()
    if task is not None:
        task.cancel()
    raise asyncio.CancelledError(
        'Function call was cancelled by client',
    )
|
|
296
|
+
|
|
297
|
+
|
|
298
|
+
def build_udf_endpoint(
    func: Callable[..., Any],
    returns_data_format: str,
) -> Callable[..., Any]:
    """
    Build a UDF endpoint for scalar / list types (row-based).

    Parameters
    ----------
    func : Callable
        The function to call as the endpoint
    returns_data_format : str
        The format of the return values

    Returns
    -------
    Callable
        The function endpoint

    """
    # Vector (column-based) formats are handled by a separate builder
    if returns_data_format not in ('scalar', 'list'):
        return build_vector_udf_endpoint(func, returns_data_format)

    is_async = asyncio.iscoroutinefunction(func)

    async def do_func(
        cancel_event: threading.Event,
        timer: Timer,
        row_ids: Sequence[int],
        rows: Sequence[Sequence[Any]],
    ) -> Tuple[Sequence[int], List[Tuple[Any, ...]]]:
        '''Call function on given rows of data.'''
        results = []
        async with timer('call_function'):
            for row in rows:
                # Bail out promptly if the client cancelled the call
                cancel_on_event(cancel_event)
                results.append(await func(*row) if is_async else func(*row))
        # Each scalar result becomes a one-element row tuple
        return row_ids, list(zip(results))

    return do_func
|
|
342
|
+
|
|
343
|
+
|
|
344
|
+
def build_vector_udf_endpoint(
    func: Callable[..., Any],
    returns_data_format: str,
) -> Callable[..., Any]:
    """
    Build a UDF endpoint for vector formats (column-based).

    The returned coroutine receives whole columns of values and invokes
    ``func`` once per batch rather than once per row.

    Parameters
    ----------
    func : Callable
        The function to call as the endpoint
    returns_data_format : str
        The format of the return values

    Returns
    -------
    Callable
        The function endpoint

    """
    # Which parameters accept a (data, mask) pair rather than bare data
    masks = get_masked_params(func)
    array_cls = get_array_class(returns_data_format)
    is_async = asyncio.iscoroutinefunction(func)

    async def do_func(
        cancel_event: threading.Event,
        timer: Timer,
        row_ids: Sequence[int],
        cols: Sequence[Tuple[Sequence[Any], Optional[Sequence[bool]]]],
    ) -> Tuple[
        Sequence[int],
        List[Tuple[Sequence[Any], Optional[Sequence[bool]]]],
    ]:
        '''Call function on given columns of data.'''
        row_ids = array_cls(row_ids)

        # Call the function with `cols` as the function parameters.
        # Masked parameters receive the full (data, mask) pair; plain
        # parameters receive just the data portion x[0].
        async with timer('call_function'):
            if cols and cols[0]:
                if is_async:
                    out = await func(*[x if m else x[0] for x, m in zip(cols, masks)])
                else:
                    out = func(*[x if m else x[0] for x, m in zip(cols, masks)])
            else:
                # No input columns at all: call with no arguments
                if is_async:
                    out = await func()
                else:
                    out = func()

        cancel_on_event(cancel_event)

        # Single masked value
        if isinstance(out, Masked):
            return row_ids, [tuple(out)]

        # Multiple return values
        if isinstance(out, tuple):
            return row_ids, [build_tuple(x) for x in out]

        # Single return value (no null mask)
        return row_ids, [(out, None)]

    return do_func
|
|
407
|
+
|
|
408
|
+
|
|
409
|
+
def build_tvf_endpoint(
    func: Callable[..., Any],
    returns_data_format: str,
) -> Callable[..., Any]:
    """
    Build a TVF endpoint for scalar / list types (row-based).

    Parameters
    ----------
    func : Callable
        The function to call as the endpoint
    returns_data_format : str
        The format of the return values

    Returns
    -------
    Callable
        The function endpoint

    """
    if returns_data_format in ['scalar', 'list']:

        is_async = asyncio.iscoroutinefunction(func)

        async def do_func(
            cancel_event: threading.Event,
            timer: Timer,
            row_ids: Sequence[int],
            rows: Sequence[Sequence[Any]],
        ) -> Tuple[Sequence[int], List[Tuple[Any, ...]]]:
            '''Call function on given rows of data.'''
            out_ids: List[int] = []
            out = []
            # Call function on each row of data; every output row produced
            # by an input row is tagged with that input row's ID
            async with timer('call_function'):
                for i, row in zip(row_ids, rows):
                    cancel_on_event(cancel_event)
                    if is_async:
                        res = await func(*row)
                    else:
                        res = func(*row)
                    out.extend(as_list_of_tuples(res))
                    # FIX: `i` already *is* the row ID from the zip above.
                    # The previous `row_ids[i]` indexed the ID list by the
                    # ID value, which was only correct when the IDs
                    # happened to be 0..n-1 and failed otherwise.
                    out_ids.extend([i] * (len(out) - len(out_ids)))
            return out_ids, out

        return do_func

    return build_vector_tvf_endpoint(func, returns_data_format)
|
|
457
|
+
|
|
458
|
+
|
|
459
|
+
def build_vector_tvf_endpoint(
    func: Callable[..., Any],
    returns_data_format: str,
) -> Callable[..., Any]:
    """
    Build a TVF endpoint for vector formats (column-based).

    The returned coroutine receives whole columns of values, invokes
    ``func`` once per batch, and unpacks the resulting dataframe/table
    into columns.

    Parameters
    ----------
    func : Callable
        The function to call as the endpoint
    returns_data_format : str
        The format of the return values

    Returns
    -------
    Callable
        The function endpoint

    """
    # Which parameters accept a (data, mask) pair rather than bare data
    masks = get_masked_params(func)
    array_cls = get_array_class(returns_data_format)

    async def do_func(
        cancel_event: threading.Event,
        timer: Timer,
        row_ids: Sequence[int],
        cols: Sequence[Tuple[Sequence[Any], Optional[Sequence[bool]]]],
    ) -> Tuple[
        Sequence[int],
        List[Tuple[Sequence[Any], Optional[Sequence[bool]]]],
    ]:
        '''Call function on given columns of data.'''
        # NOTE: There is no way to determine which row ID belongs to
        #       each result row, so we just have to use the same
        #       row ID for all rows in the result.

        is_async = asyncio.iscoroutinefunction(func)

        # Call function on each column of data; masked parameters get the
        # full (data, mask) pair, plain parameters get just the data x[0]
        async with timer('call_function'):
            if cols and cols[0]:
                if is_async:
                    func_res = await func(
                        *[x if m else x[0] for x, m in zip(cols, masks)],
                    )
                else:
                    func_res = func(
                        *[x if m else x[0] for x, m in zip(cols, masks)],
                    )
            else:
                # No input columns at all: call with no arguments
                if is_async:
                    func_res = await func()
                else:
                    func_res = func()

        res = get_dataframe_columns(func_res)

        cancel_on_event(cancel_event)

        # Generate row IDs: repeat the first input ID once per result row
        # (len of the first result column, or of its data part if masked)
        if isinstance(res[0], Masked):
            row_ids = array_cls([row_ids[0]] * len(res[0][0]))
        else:
            row_ids = array_cls([row_ids[0]] * len(res[0]))

        return row_ids, [build_tuple(x) for x in res]

    return do_func
|
|
528
|
+
|
|
529
|
+
|
|
530
|
+
def make_func(
    name: str,
    func: Callable[..., Any],
) -> Tuple[Callable[..., Any], Dict[str, Any]]:
    """
    Make a function endpoint.

    Wraps ``func`` in the appropriate UDF / TVF endpoint and collects the
    metadata needed to register and serve it.

    Parameters
    ----------
    name : str
        Name of the function to create
    func : Callable
        The function to call as the endpoint

    Returns
    -------
    (Callable, Dict[str, Any])
        The endpoint coroutine and its metadata: signature, data formats,
        function type, timeout, async flag, and rowdat_1 column specs

    """
    info: Dict[str, Any] = {}

    sig = get_signature(func, func_name=name)

    function_type = sig.get('function_type', 'udf')
    args_data_format = sig.get('args_data_format', 'scalar')
    returns_data_format = sig.get('returns_data_format', 'scalar')
    # Per-function timeout from the decorator attrs, falling back to the
    # global option when unset or falsy
    timeout = (
        func._singlestoredb_attrs.get('timeout') or  # type: ignore
        get_option('external_function.timeout')
    )

    if function_type == 'tvf':
        do_func = build_tvf_endpoint(func, returns_data_format)
    else:
        do_func = build_udf_endpoint(func, returns_data_format)

    # Make the endpoint present itself as the wrapped function
    do_func.__name__ = name
    do_func.__doc__ = func.__doc__

    # Store signature for generating CREATE FUNCTION calls
    info['signature'] = sig

    # Set data format
    info['args_data_format'] = args_data_format
    info['returns_data_format'] = returns_data_format

    # Set function type
    info['function_type'] = function_type

    # Set timeout (clamped to at least one second)
    info['timeout'] = max(timeout, 1)

    # Set async flag
    info['is_async'] = asyncio.iscoroutinefunction(func)

    # Setup argument types for rowdat_1 parser; the '?' nullable marker is
    # stripped before looking up the wire type code
    colspec = []
    for x in sig['args']:
        dtype = x['dtype'].replace('?', '')
        if dtype not in rowdat_1_type_map:
            raise TypeError(f'no data type mapping for {dtype}')
        colspec.append((x['name'], rowdat_1_type_map[dtype]))
    info['colspec'] = colspec

    # Setup return type
    returns = []
    for x in sig['returns']:
        dtype = x['dtype'].replace('?', '')
        if dtype not in rowdat_1_type_map:
            raise TypeError(f'no data type mapping for {dtype}')
        returns.append((x['name'], rowdat_1_type_map[dtype]))
    info['returns'] = returns

    return do_func, info
|
|
606
|
+
|
|
607
|
+
|
|
608
|
+
async def cancel_on_timeout(timeout: int) -> None:
    """Sleep for ``timeout`` seconds, then cancel the request."""
    await asyncio.sleep(timeout)
    raise asyncio.CancelledError('Function call was cancelled due to timeout')
|
|
614
|
+
|
|
615
|
+
|
|
616
|
+
async def cancel_on_disconnect(
    receive: Callable[..., Awaitable[Any]],
) -> None:
    """Watch the ASGI receive channel; cancel once the client disconnects."""
    while True:
        message = await receive()
        # Ignore everything except the disconnect notification
        if message.get('type', '') != 'http.disconnect':
            continue
        raise asyncio.CancelledError(
            'Function call was cancelled by client',
        )
|
|
626
|
+
|
|
627
|
+
|
|
628
|
+
async def cancel_all_tasks(tasks: Iterable[asyncio.Task[Any]]) -> None:
    """
    Cancel all given tasks and wait for them to finish.

    Parameters
    ----------
    tasks : Iterable[asyncio.Task]
        The tasks to cancel

    """
    # Materialize first: if `tasks` is a one-shot iterator, the cancel
    # loop would exhaust it and the gather below would await nothing.
    pending = list(tasks)
    for task in pending:
        task.cancel()
    # Wait for cancellation to take effect; return_exceptions swallows
    # the resulting CancelledError (and any other exception) per task.
    await asyncio.gather(*pending, return_exceptions=True)
|
|
633
|
+
|
|
634
|
+
|
|
635
|
+
def start_counter() -> float:
    """Return a high-resolution timestamp marking the start of a timing."""
    return time.perf_counter()
|
|
638
|
+
|
|
639
|
+
|
|
640
|
+
def end_counter(start: float) -> float:
|
|
641
|
+
"""End a timer and return the elapsed time."""
|
|
642
|
+
return time.perf_counter() - start
|
|
643
|
+
|
|
644
|
+
|
|
645
|
+
class Application(object):
    """
    Create an external function application.

    If `functions` is None, the environment is searched for function
    specifications in variables starting with `SINGLESTOREDB_EXT_FUNCTIONS`.
    Any number of environment variables can be specified as long as they
    have this prefix. The format of the environment variable value is the
    same as for the `functions` parameter.

    Parameters
    ----------
    functions : str or Iterable[str], optional
        Python functions are specified using a string format as follows:
        * Single function : <pkg1>.<func1>
        * Multiple functions : <pkg1>.[func1-name,func2-name,...]
        * Function aliases : <pkg1>.[func1@alias1,func2@alias2,...]
        * Multiple packages : <pkg1>.<func1>:<pkg2>.<func2>
    app_mode : str, optional
        The mode of operation for the application: remote, managed, or collocated
    url : str, optional
        The URL of the function API
    data_format : str, optional
        The format of the data rows: 'rowdat_1' or 'json'
    data_version : str, optional
        The version of the call format to expect: '1.0'
    link_name : str, optional
        The link name to use for the external function application. This is
        only for pre-existing links, and can only be used without
        ``link_config`` and ``link_credentials``.
    link_config : Dict[str, Any], optional
        The CONFIG section of a LINK definition. This dictionary gets
        converted to JSON for the CREATE LINK call.
    link_credentials : Dict[str, Any], optional
        The CREDENTIALS section of a LINK definition. This dictionary gets
        converted to JSON for the CREATE LINK call.
    name_prefix : str, optional
        Prefix to add to function names when registering with the database
    name_suffix : str, optional
        Suffix to add to function names when registering with the database
    function_database : str, optional
        The database to use for external function definitions.
    log_file : str, optional
        File path to write logs to instead of console. If None, logs are
        written to console. When specified, application logger handlers
        are replaced with a file handler.
    log_level : str, optional
        Logging level for the application logger. Valid values are 'info',
        'debug', 'warning', 'error'. Defaults to 'info'.
    disable_metrics : bool, optional
        Disable logging of function call metrics. Defaults to False.
    app_name : str, optional
        Name for the application instance. Used to create a logger-specific
        name. If not provided, a random name will be generated.

    """

    #
    # Pre-built ASGI `http.response.start` messages, one per outcome.
    # These are shared class-level templates; handlers send them as-is.
    #

    # Plain text response start
    text_response_dict: Dict[str, Any] = dict(
        type='http.response.start',
        status=200,
        headers=[(b'content-type', b'text/plain')],
    )

    # Error response start
    error_response_dict: Dict[str, Any] = dict(
        type='http.response.start',
        status=500,
        headers=[(b'content-type', b'text/plain')],
    )

    # Timeout response start
    timeout_response_dict: Dict[str, Any] = dict(
        type='http.response.start',
        status=504,
        headers=[(b'content-type', b'text/plain')],
    )

    # Cancel response start
    cancel_response_dict: Dict[str, Any] = dict(
        type='http.response.start',
        status=503,
        headers=[(b'content-type', b'text/plain')],
    )

    # JSON response start
    json_response_dict: Dict[str, Any] = dict(
        type='http.response.start',
        status=200,
        headers=[(b'content-type', b'application/json')],
    )

    # ROWDAT_1 response start
    rowdat_1_response_dict: Dict[str, Any] = dict(
        type='http.response.start',
        status=200,
        headers=[(b'content-type', b'x-application/rowdat_1')],
    )

    # Apache Arrow response start
    arrow_response_dict: Dict[str, Any] = dict(
        type='http.response.start',
        status=200,
        headers=[(b'content-type', b'application/vnd.apache.arrow.file')],
    )

    # Path not found response start
    path_not_found_response_dict: Dict[str, Any] = dict(
        type='http.response.start',
        status=404,
    )

    # Response body template; a copy gets a 'body' key added per request
    body_response_dict: Dict[str, Any] = dict(
        type='http.response.body',
    )

    # Data format + version handlers. Keys are
    # (content-type header, data version header, Python data format);
    # each value supplies the matching `load` / `dump` converters and the
    # `http.response.start` message to send with the result.
    handlers = {
        (b'application/octet-stream', b'1.0', 'scalar'): dict(
            load=rowdat_1.load,
            dump=rowdat_1.dump,
            response=rowdat_1_response_dict,
        ),
        (b'application/octet-stream', b'1.0', 'list'): dict(
            load=rowdat_1.load,
            dump=rowdat_1.dump,
            response=rowdat_1_response_dict,
        ),
        (b'application/octet-stream', b'1.0', 'pandas'): dict(
            load=rowdat_1.load_pandas,
            dump=rowdat_1.dump_pandas,
            response=rowdat_1_response_dict,
        ),
        (b'application/octet-stream', b'1.0', 'numpy'): dict(
            load=rowdat_1.load_numpy,
            dump=rowdat_1.dump_numpy,
            response=rowdat_1_response_dict,
        ),
        (b'application/octet-stream', b'1.0', 'polars'): dict(
            load=rowdat_1.load_polars,
            dump=rowdat_1.dump_polars,
            response=rowdat_1_response_dict,
        ),
        (b'application/octet-stream', b'1.0', 'arrow'): dict(
            load=rowdat_1.load_arrow,
            dump=rowdat_1.dump_arrow,
            response=rowdat_1_response_dict,
        ),
        (b'application/json', b'1.0', 'scalar'): dict(
            load=jdata.load,
            dump=jdata.dump,
            response=json_response_dict,
        ),
        (b'application/json', b'1.0', 'list'): dict(
            load=jdata.load,
            dump=jdata.dump,
            response=json_response_dict,
        ),
        (b'application/json', b'1.0', 'pandas'): dict(
            load=jdata.load_pandas,
            dump=jdata.dump_pandas,
            response=json_response_dict,
        ),
        (b'application/json', b'1.0', 'numpy'): dict(
            load=jdata.load_numpy,
            dump=jdata.dump_numpy,
            response=json_response_dict,
        ),
        (b'application/json', b'1.0', 'polars'): dict(
            load=jdata.load_polars,
            dump=jdata.dump_polars,
            response=json_response_dict,
        ),
        (b'application/json', b'1.0', 'arrow'): dict(
            load=jdata.load_arrow,
            dump=jdata.dump_arrow,
            response=json_response_dict,
        ),
        # NOTE: there is intentionally no (arrow, '1.0', 'list') entry;
        # only the combinations below are supported for Arrow transport.
        (b'application/vnd.apache.arrow.file', b'1.0', 'scalar'): dict(
            load=arrow.load,
            dump=arrow.dump,
            response=arrow_response_dict,
        ),
        (b'application/vnd.apache.arrow.file', b'1.0', 'pandas'): dict(
            load=arrow.load_pandas,
            dump=arrow.dump_pandas,
            response=arrow_response_dict,
        ),
        (b'application/vnd.apache.arrow.file', b'1.0', 'numpy'): dict(
            load=arrow.load_numpy,
            dump=arrow.dump_numpy,
            response=arrow_response_dict,
        ),
        (b'application/vnd.apache.arrow.file', b'1.0', 'polars'): dict(
            load=arrow.load_polars,
            dump=arrow.dump_polars,
            response=arrow_response_dict,
        ),
        (b'application/vnd.apache.arrow.file', b'1.0', 'arrow'): dict(
            load=arrow.load_arrow,
            dump=arrow.dump_arrow,
            response=arrow_response_dict,
        ),
    }

    # Valid URL paths (as tuples of non-empty path segments)
    invoke_path = ('invoke',)
    show_create_function_path = ('show', 'create_function')
    show_function_info_path = ('show', 'function_info')
    status = ('status',)
|
|
856
|
+
|
|
857
|
+
def __init__(
|
|
858
|
+
self,
|
|
859
|
+
functions: Optional[
|
|
860
|
+
Union[
|
|
861
|
+
str,
|
|
862
|
+
Iterable[str],
|
|
863
|
+
Callable[..., Any],
|
|
864
|
+
Iterable[Callable[..., Any]],
|
|
865
|
+
ModuleType,
|
|
866
|
+
Iterable[ModuleType],
|
|
867
|
+
]
|
|
868
|
+
] = None,
|
|
869
|
+
app_mode: str = get_option('external_function.app_mode'),
|
|
870
|
+
url: str = get_option('external_function.url'),
|
|
871
|
+
data_format: str = get_option('external_function.data_format'),
|
|
872
|
+
data_version: str = get_option('external_function.data_version'),
|
|
873
|
+
link_name: Optional[str] = get_option('external_function.link_name'),
|
|
874
|
+
link_config: Optional[Dict[str, Any]] = None,
|
|
875
|
+
link_credentials: Optional[Dict[str, Any]] = None,
|
|
876
|
+
name_prefix: str = get_option('external_function.name_prefix'),
|
|
877
|
+
name_suffix: str = get_option('external_function.name_suffix'),
|
|
878
|
+
function_database: Optional[str] = None,
|
|
879
|
+
log_file: Optional[str] = get_option('external_function.log_file'),
|
|
880
|
+
log_level: str = get_option('external_function.log_level'),
|
|
881
|
+
disable_metrics: bool = get_option('external_function.disable_metrics'),
|
|
882
|
+
app_name: Optional[str] = get_option('external_function.app_name'),
|
|
883
|
+
) -> None:
|
|
884
|
+
if link_name and (link_config or link_credentials):
|
|
885
|
+
raise ValueError(
|
|
886
|
+
'`link_name` can not be used with `link_config` or `link_credentials`',
|
|
887
|
+
)
|
|
888
|
+
|
|
889
|
+
if link_config is None:
|
|
890
|
+
link_config = json.loads(
|
|
891
|
+
get_option('external_function.link_config') or '{}',
|
|
892
|
+
) or None
|
|
893
|
+
|
|
894
|
+
if link_credentials is None:
|
|
895
|
+
link_credentials = json.loads(
|
|
896
|
+
get_option('external_function.link_credentials') or '{}',
|
|
897
|
+
) or None
|
|
898
|
+
|
|
899
|
+
# Generate application name if not provided
|
|
900
|
+
if app_name is None:
|
|
901
|
+
app_name = f'udf_app_{secrets.token_hex(4)}'
|
|
902
|
+
|
|
903
|
+
self.name = app_name
|
|
904
|
+
|
|
905
|
+
# Create logger instance specific to this application
|
|
906
|
+
self.logger = utils.get_logger(f'singlestoredb.functions.ext.asgi.{self.name}')
|
|
907
|
+
|
|
908
|
+
# List of functions specs
|
|
909
|
+
specs: List[Union[str, Callable[..., Any], ModuleType]] = []
|
|
910
|
+
|
|
911
|
+
# Look up Python function specifications
|
|
912
|
+
if functions is None:
|
|
913
|
+
env_vars = [
|
|
914
|
+
x for x in os.environ.keys()
|
|
915
|
+
if x.startswith('SINGLESTOREDB_EXT_FUNCTIONS')
|
|
916
|
+
]
|
|
917
|
+
if env_vars:
|
|
918
|
+
specs = [os.environ[x] for x in env_vars]
|
|
919
|
+
else:
|
|
920
|
+
import __main__
|
|
921
|
+
specs = [__main__]
|
|
922
|
+
|
|
923
|
+
elif isinstance(functions, ModuleType):
|
|
924
|
+
specs = [functions]
|
|
925
|
+
|
|
926
|
+
elif isinstance(functions, str):
|
|
927
|
+
specs = [functions]
|
|
928
|
+
|
|
929
|
+
elif callable(functions):
|
|
930
|
+
specs = [functions]
|
|
931
|
+
|
|
932
|
+
else:
|
|
933
|
+
specs = list(functions)
|
|
934
|
+
|
|
935
|
+
# Add functions to application
|
|
936
|
+
endpoints = dict()
|
|
937
|
+
external_functions = dict()
|
|
938
|
+
for funcs in itertools.chain(specs):
|
|
939
|
+
|
|
940
|
+
if isinstance(funcs, str):
|
|
941
|
+
# Module name
|
|
942
|
+
if importlib.util.find_spec(funcs) is not None:
|
|
943
|
+
items = importlib.import_module(funcs)
|
|
944
|
+
for x in vars(items).values():
|
|
945
|
+
if not hasattr(x, '_singlestoredb_attrs'):
|
|
946
|
+
continue
|
|
947
|
+
name = x._singlestoredb_attrs.get('name', x.__name__)
|
|
948
|
+
name = f'{name_prefix}{name}{name_suffix}'
|
|
949
|
+
external_functions[x.__name__] = x
|
|
950
|
+
func, info = make_func(name, x)
|
|
951
|
+
endpoints[name.encode('utf-8')] = func, info
|
|
952
|
+
|
|
953
|
+
# Fully qualified function name
|
|
954
|
+
elif '.' in funcs:
|
|
955
|
+
pkg_path, func_names = funcs.rsplit('.', 1)
|
|
956
|
+
pkg = importlib.import_module(pkg_path)
|
|
957
|
+
|
|
958
|
+
if pkg is None:
|
|
959
|
+
raise RuntimeError(f'Could not locate module: {pkg}')
|
|
960
|
+
|
|
961
|
+
# Add endpoint for each exported function
|
|
962
|
+
for name, alias in get_func_names(func_names):
|
|
963
|
+
item = getattr(pkg, name)
|
|
964
|
+
alias = f'{name_prefix}{name}{name_suffix}'
|
|
965
|
+
external_functions[name] = item
|
|
966
|
+
func, info = make_func(alias, item)
|
|
967
|
+
endpoints[alias.encode('utf-8')] = func, info
|
|
968
|
+
|
|
969
|
+
else:
|
|
970
|
+
raise RuntimeError(f'Could not locate module: {funcs}')
|
|
971
|
+
|
|
972
|
+
elif isinstance(funcs, ModuleType):
|
|
973
|
+
for x in vars(funcs).values():
|
|
974
|
+
if not hasattr(x, '_singlestoredb_attrs'):
|
|
975
|
+
continue
|
|
976
|
+
name = x._singlestoredb_attrs.get('name', x.__name__)
|
|
977
|
+
name = f'{name_prefix}{name}{name_suffix}'
|
|
978
|
+
external_functions[x.__name__] = x
|
|
979
|
+
func, info = make_func(name, x)
|
|
980
|
+
endpoints[name.encode('utf-8')] = func, info
|
|
981
|
+
|
|
982
|
+
else:
|
|
983
|
+
alias = funcs.__name__
|
|
984
|
+
external_functions[funcs.__name__] = funcs
|
|
985
|
+
alias = f'{name_prefix}{alias}{name_suffix}'
|
|
986
|
+
func, info = make_func(alias, funcs)
|
|
987
|
+
endpoints[alias.encode('utf-8')] = func, info
|
|
988
|
+
|
|
989
|
+
self.app_mode = app_mode
|
|
990
|
+
self.url = url
|
|
991
|
+
self.data_format = data_format
|
|
992
|
+
self.data_version = data_version
|
|
993
|
+
self.link_name = link_name
|
|
994
|
+
self.link_config = link_config
|
|
995
|
+
self.link_credentials = link_credentials
|
|
996
|
+
self.endpoints = endpoints
|
|
997
|
+
self.external_functions = external_functions
|
|
998
|
+
self.function_database = function_database
|
|
999
|
+
self.log_file = log_file
|
|
1000
|
+
self.log_level = log_level
|
|
1001
|
+
self.disable_metrics = disable_metrics
|
|
1002
|
+
|
|
1003
|
+
# Configure logging
|
|
1004
|
+
self._configure_logging()
|
|
1005
|
+
|
|
1006
|
+
def _configure_logging(self) -> None:
|
|
1007
|
+
"""Configure logging based on the log_file settings."""
|
|
1008
|
+
# Set logger level
|
|
1009
|
+
self.logger.setLevel(getattr(logging, self.log_level.upper()))
|
|
1010
|
+
|
|
1011
|
+
# Remove all existing handlers to ensure clean configuration
|
|
1012
|
+
self.logger.handlers.clear()
|
|
1013
|
+
|
|
1014
|
+
# Configure log file if specified
|
|
1015
|
+
if self.log_file:
|
|
1016
|
+
# Create file handler
|
|
1017
|
+
file_handler = logging.FileHandler(self.log_file)
|
|
1018
|
+
file_handler.setLevel(getattr(logging, self.log_level.upper()))
|
|
1019
|
+
|
|
1020
|
+
# Use JSON formatter for file logging
|
|
1021
|
+
formatter = utils.JSONFormatter()
|
|
1022
|
+
file_handler.setFormatter(formatter)
|
|
1023
|
+
|
|
1024
|
+
# Add the handler to the logger
|
|
1025
|
+
self.logger.addHandler(file_handler)
|
|
1026
|
+
else:
|
|
1027
|
+
# For console logging, create a new stream handler with JSON formatter
|
|
1028
|
+
console_handler = logging.StreamHandler()
|
|
1029
|
+
console_handler.setLevel(getattr(logging, self.log_level.upper()))
|
|
1030
|
+
console_handler.setFormatter(utils.JSONFormatter())
|
|
1031
|
+
self.logger.addHandler(console_handler)
|
|
1032
|
+
|
|
1033
|
+
# Prevent propagation to avoid duplicate or differently formatted messages
|
|
1034
|
+
self.logger.propagate = False
|
|
1035
|
+
|
|
1036
|
+
def get_uvicorn_log_config(self) -> Dict[str, Any]:
|
|
1037
|
+
"""
|
|
1038
|
+
Create uvicorn log config that matches the Application's logging format.
|
|
1039
|
+
|
|
1040
|
+
This method returns the log configuration used by uvicorn, allowing external
|
|
1041
|
+
users to match the logging format of the Application class.
|
|
1042
|
+
|
|
1043
|
+
Returns
|
|
1044
|
+
-------
|
|
1045
|
+
Dict[str, Any]
|
|
1046
|
+
Log configuration dictionary compatible with uvicorn's log_config parameter
|
|
1047
|
+
|
|
1048
|
+
"""
|
|
1049
|
+
log_config = {
|
|
1050
|
+
'version': 1,
|
|
1051
|
+
'disable_existing_loggers': False,
|
|
1052
|
+
'formatters': {
|
|
1053
|
+
'json': {
|
|
1054
|
+
'()': 'singlestoredb.functions.ext.utils.JSONFormatter',
|
|
1055
|
+
},
|
|
1056
|
+
},
|
|
1057
|
+
'handlers': {
|
|
1058
|
+
'default': {
|
|
1059
|
+
'class': (
|
|
1060
|
+
'logging.FileHandler' if self.log_file
|
|
1061
|
+
else 'logging.StreamHandler'
|
|
1062
|
+
),
|
|
1063
|
+
'formatter': 'json',
|
|
1064
|
+
},
|
|
1065
|
+
},
|
|
1066
|
+
'loggers': {
|
|
1067
|
+
'uvicorn': {
|
|
1068
|
+
'handlers': ['default'],
|
|
1069
|
+
'level': self.log_level.upper(),
|
|
1070
|
+
'propagate': False,
|
|
1071
|
+
},
|
|
1072
|
+
'uvicorn.error': {
|
|
1073
|
+
'handlers': ['default'],
|
|
1074
|
+
'level': self.log_level.upper(),
|
|
1075
|
+
'propagate': False,
|
|
1076
|
+
},
|
|
1077
|
+
'uvicorn.access': {
|
|
1078
|
+
'handlers': ['default'],
|
|
1079
|
+
'level': self.log_level.upper(),
|
|
1080
|
+
'propagate': False,
|
|
1081
|
+
},
|
|
1082
|
+
},
|
|
1083
|
+
}
|
|
1084
|
+
|
|
1085
|
+
# Add filename to file handler if log file is specified
|
|
1086
|
+
if self.log_file:
|
|
1087
|
+
log_config['handlers']['default']['filename'] = self.log_file # type: ignore
|
|
1088
|
+
|
|
1089
|
+
return log_config
|
|
1090
|
+
|
|
1091
|
+
    async def __call__(
        self,
        scope: Dict[str, Any],
        receive: Callable[..., Awaitable[Any]],
        send: Callable[..., Awaitable[Any]],
    ) -> None:
        '''
        Application request handler (ASGI entry point).

        Routes POST /invoke to the named external function and GET paths
        to the reflection / info / status endpoints.

        Parameters
        ----------
        scope : dict
            ASGI request scope
        receive : Callable
            Function to receive request information
        send : Callable
            Function to send response information

        '''
        # Correlation id attached to every log record for this request.
        request_id = str(uuid.uuid4())

        # `timer` measures the request pipeline stages; `call_timer` is
        # handed to the function wrapper and its metrics are merged into
        # `timer` at the end.
        timer = Timer(
            app_name=self.name,
            id=request_id,
            timestamp=datetime.datetime.now(
                datetime.timezone.utc,
            ).strftime('%Y-%m-%dT%H:%M:%S.%fZ'),
        )
        call_timer = Timer(
            app_name=self.name,
            id=request_id,
            timestamp=datetime.datetime.now(
                datetime.timezone.utc,
            ).strftime('%Y-%m-%dT%H:%M:%S.%fZ'),
        )

        if scope['type'] != 'http':
            raise ValueError(f"Expected HTTP scope, got {scope['type']}")

        method = scope['method']
        path = tuple(x for x in scope['path'].split('/') if x)
        headers = dict(scope['headers'])

        # Headers arrive as bytes; default to the rowdat_1 content type.
        content_type = headers.get(
            b'content-type',
            b'application/octet-stream',
        )
        accepts = headers.get(b'accepts', content_type)
        func_name = headers.get(b's2-ef-name', b'')
        func_endpoint = self.endpoints.get(func_name)
        ignore_cancel = headers.get(b's2-ef-ignore-cancel', b'false') == b'true'

        timer.metadata['function'] = func_name.decode('utf-8') if func_name else ''
        call_timer.metadata['function'] = timer.metadata['function']

        func = None
        func_info: Dict[str, Any] = {}
        if func_endpoint is not None:
            func, func_info = func_endpoint

        # Call the endpoint
        if method == 'POST' and func is not None and path == self.invoke_path:

            self.logger.info(
                'Function call initiated',
                extra={
                    'app_name': self.name,
                    'request_id': request_id,
                    'function_name': func_name.decode('utf-8'),
                    'content_type': content_type.decode('utf-8'),
                    'accepts': accepts.decode('utf-8'),
                },
            )

            args_data_format = func_info['args_data_format']
            returns_data_format = func_info['returns_data_format']

            # Collect the (possibly chunked) request body.
            data = []
            more_body = True
            with timer('receive_data'):
                while more_body:
                    request = await receive()
                    if request.get('type', '') == 'http.disconnect':
                        raise RuntimeError('client disconnected')
                    data.append(request['body'])
                    more_body = request.get('more_body', False)

            # Select the converters matching the negotiated formats.
            data_version = headers.get(b's2-ef-version', b'')
            input_handler = self.handlers[(content_type, data_version, args_data_format)]
            output_handler = self.handlers[(accepts, data_version, returns_data_format)]

            try:
                all_tasks = []
                result = []

                # Event used to signal cancellation into the running UDF.
                cancel_event = threading.Event()

                with timer('parse_input'):
                    inputs = input_handler['load'](  # type: ignore
                        func_info['colspec'], b''.join(data),
                    )

                # Race the function against a client-disconnect watcher
                # and a timeout watcher; first one to finish wins.
                # Sync functions run in a thread with their own event loop.
                func_task = asyncio.create_task(
                    func(cancel_event, call_timer, *inputs)
                    if func_info['is_async']
                    else to_thread(
                        lambda: asyncio.run(
                            func(cancel_event, call_timer, *inputs),
                        ),
                    ),
                )
                disconnect_task = asyncio.create_task(
                    asyncio.sleep(int(1e9))
                    if ignore_cancel else cancel_on_disconnect(receive),
                )
                timeout_task = asyncio.create_task(
                    cancel_on_timeout(func_info['timeout']),
                )

                all_tasks += [func_task, disconnect_task, timeout_task]

                async with timer('function_wrapper'):
                    done, pending = await asyncio.wait(
                        all_tasks, return_when=asyncio.FIRST_COMPLETED,
                    )

                # Whatever didn't finish first gets cancelled.
                await cancel_all_tasks(pending)

                for task in done:
                    if task is disconnect_task:
                        cancel_event.set()
                        raise asyncio.CancelledError(
                            'Function call was cancelled by client disconnect',
                        )

                    elif task is timeout_task:
                        cancel_event.set()
                        raise asyncio.TimeoutError(
                            'Function call was cancelled due to timeout',
                        )

                    elif task is func_task:
                        result.extend(task.result())

                with timer('format_output'):
                    body = output_handler['dump'](
                        [x[1] for x in func_info['returns']], *result,  # type: ignore
                    )

                await send(output_handler['response'])

            except asyncio.TimeoutError:
                self.logger.exception(
                    'Function call timeout',
                    extra={
                        'app_name': self.name,
                        'request_id': request_id,
                        'function_name': func_name.decode('utf-8'),
                        'timeout': func_info['timeout'],
                    },
                )
                body = (
                    'TimeoutError: Function call timed out after ' +
                    str(func_info['timeout']) +
                    ' seconds'
                ).encode('utf-8')
                await send(self.timeout_response_dict)

            except asyncio.CancelledError:
                self.logger.exception(
                    'Function call cancelled',
                    extra={
                        'app_name': self.name,
                        'request_id': request_id,
                        'function_name': func_name.decode('utf-8'),
                    },
                )
                body = b'CancelledError: Function call was cancelled'
                await send(self.cancel_response_dict)

            except Exception as e:
                self.logger.exception(
                    'Function call error',
                    extra={
                        'app_name': self.name,
                        'request_id': request_id,
                        'function_name': func_name.decode('utf-8'),
                        'exception_type': type(e).__name__,
                    },
                )
                # Report only the innermost frame of the traceback;
                # ipykernel temp-file paths are reduced to a line number.
                msg = traceback.format_exc().strip().split(' File ')[-1]
                if msg.startswith('"/tmp/ipykernel_'):
                    msg = 'Line ' + msg.split(', line ')[-1]
                else:
                    msg = 'File ' + msg
                body = msg.encode('utf-8')
                await send(self.error_response_dict)

            finally:
                # Make sure no watcher / function task outlives the request.
                await cancel_all_tasks(all_tasks)

        # Handle api reflection
        elif method == 'GET' and path == self.show_create_function_path:
            host = headers.get(b'host', b'localhost:80')
            reflected_url = f'{scope["scheme"]}://{host.decode("utf-8")}/invoke'

            syntax = []
            for key, (endpoint, endpoint_info) in self.endpoints.items():
                if not func_name or key == func_name:
                    syntax.append(
                        signature_to_sql(
                            endpoint_info['signature'],
                            url=self.url or reflected_url,
                            data_format=self.data_format,
                            database=self.function_database or None,
                        ),
                    )
            body = '\n'.join(syntax).encode('utf-8')

            await send(self.text_response_dict)

        # Return function info
        elif method == 'GET' and (path == self.show_function_info_path or not path):
            functions = self.get_function_info()
            body = json.dumps(dict(functions=functions)).encode('utf-8')
            await send(self.text_response_dict)

        # Return status
        elif method == 'GET' and path == self.status:
            body = json.dumps(dict(status='ok')).encode('utf-8')
            await send(self.text_response_dict)

        # Path not found
        else:
            body = b''
            await send(self.path_not_found_response_dict)

        # Send body
        with timer('send_response'):
            out = self.body_response_dict.copy()
            out['body'] = body
            await send(out)

        # Merge the in-function metrics into the request-level metrics.
        for k, v in call_timer.metrics.items():
            timer.metrics[k] = v

        if not self.disable_metrics:
            metrics = timer.finish()
            self.logger.info(
                'Function call metrics',
                extra={
                    'app_name': self.name,
                    'request_id': request_id,
                    'function_name': timer.metadata.get('function', ''),
                    'metrics': metrics,
                },
            )
|
|
1347
|
+
|
|
1348
|
+
def _create_link(
|
|
1349
|
+
self,
|
|
1350
|
+
config: Optional[Dict[str, Any]],
|
|
1351
|
+
credentials: Optional[Dict[str, Any]],
|
|
1352
|
+
) -> Tuple[str, str]:
|
|
1353
|
+
"""Generate CREATE LINK command."""
|
|
1354
|
+
if self.link_name:
|
|
1355
|
+
return self.link_name, ''
|
|
1356
|
+
|
|
1357
|
+
if not config and not credentials:
|
|
1358
|
+
return '', ''
|
|
1359
|
+
|
|
1360
|
+
link_name = f'py_ext_func_link_{secrets.token_hex(14)}'
|
|
1361
|
+
out = [f'CREATE LINK {link_name} AS HTTP']
|
|
1362
|
+
|
|
1363
|
+
if config:
|
|
1364
|
+
out.append(f"CONFIG '{json.dumps(config)}'")
|
|
1365
|
+
|
|
1366
|
+
if credentials:
|
|
1367
|
+
out.append(f"CREDENTIALS '{json.dumps(credentials)}'")
|
|
1368
|
+
|
|
1369
|
+
return link_name, ' '.join(out) + ';'
|
|
1370
|
+
|
|
1371
|
+
def _locate_app_functions(self, cur: Any) -> Tuple[Set[str], Set[str]]:
|
|
1372
|
+
"""Locate all current functions and links belonging to this app."""
|
|
1373
|
+
funcs, links = set(), set()
|
|
1374
|
+
if self.function_database:
|
|
1375
|
+
database_prefix = escape_name(self.function_database) + '.'
|
|
1376
|
+
cur.execute(f'SHOW FUNCTIONS IN {escape_name(self.function_database)}')
|
|
1377
|
+
else:
|
|
1378
|
+
database_prefix = ''
|
|
1379
|
+
cur.execute('SHOW FUNCTIONS')
|
|
1380
|
+
|
|
1381
|
+
for row in list(cur):
|
|
1382
|
+
name, ftype, link = row[0], row[1], row[-1]
|
|
1383
|
+
# Only look at external functions
|
|
1384
|
+
if 'external' not in ftype.lower():
|
|
1385
|
+
continue
|
|
1386
|
+
# See if function URL matches url
|
|
1387
|
+
cur.execute(f'SHOW CREATE FUNCTION {database_prefix}{escape_name(name)}')
|
|
1388
|
+
for fname, _, code, *_ in list(cur):
|
|
1389
|
+
m = re.search(r" (?:\w+) (?:SERVICE|MANAGED) '([^']+)'", code)
|
|
1390
|
+
if m and m.group(1) == self.url:
|
|
1391
|
+
funcs.add(f'{database_prefix}{escape_name(fname)}')
|
|
1392
|
+
if link and re.match(r'^py_ext_func_link_\S{14}$', link):
|
|
1393
|
+
links.add(link)
|
|
1394
|
+
|
|
1395
|
+
return funcs, links
|
|
1396
|
+
|
|
1397
|
+
def get_function_info(
|
|
1398
|
+
self,
|
|
1399
|
+
func_name: Optional[str] = None,
|
|
1400
|
+
) -> Dict[str, Any]:
|
|
1401
|
+
"""
|
|
1402
|
+
Return the functions and function signature information.
|
|
1403
|
+
|
|
1404
|
+
Returns
|
|
1405
|
+
-------
|
|
1406
|
+
Dict[str, Any]
|
|
1407
|
+
|
|
1408
|
+
"""
|
|
1409
|
+
functions = {}
|
|
1410
|
+
no_default = object()
|
|
1411
|
+
|
|
1412
|
+
# Generate CREATE FUNCTION SQL for each function using get_create_functions
|
|
1413
|
+
create_sqls = self.get_create_functions(replace=True)
|
|
1414
|
+
sql_map = {}
|
|
1415
|
+
for (_, info), sql in zip(self.endpoints.values(), create_sqls):
|
|
1416
|
+
sig = info['signature']
|
|
1417
|
+
sql_map[sig['name']] = sql
|
|
1418
|
+
|
|
1419
|
+
for key, (func, info) in self.endpoints.items():
|
|
1420
|
+
# Get info from docstring
|
|
1421
|
+
doc_summary = ''
|
|
1422
|
+
doc_long_description = ''
|
|
1423
|
+
doc_params = {}
|
|
1424
|
+
doc_returns = None
|
|
1425
|
+
doc_examples = []
|
|
1426
|
+
if func.__doc__:
|
|
1427
|
+
try:
|
|
1428
|
+
docs = parse(func.__doc__)
|
|
1429
|
+
doc_params = {p.arg_name: p for p in docs.params}
|
|
1430
|
+
doc_returns = docs.returns
|
|
1431
|
+
if not docs.short_description and docs.long_description:
|
|
1432
|
+
doc_summary = docs.long_description or ''
|
|
1433
|
+
else:
|
|
1434
|
+
doc_summary = docs.short_description or ''
|
|
1435
|
+
doc_long_description = docs.long_description or ''
|
|
1436
|
+
for ex in docs.examples:
|
|
1437
|
+
ex_dict: Dict[str, Any] = {
|
|
1438
|
+
'description': None,
|
|
1439
|
+
'code': None,
|
|
1440
|
+
'output': None,
|
|
1441
|
+
}
|
|
1442
|
+
if ex.description:
|
|
1443
|
+
ex_dict['description'] = ex.description
|
|
1444
|
+
if ex.snippet:
|
|
1445
|
+
code, output = [], []
|
|
1446
|
+
for line in ex.snippet.split('\n'):
|
|
1447
|
+
line = line.rstrip()
|
|
1448
|
+
if re.match(r'^(\w+>|>>>|\.\.\.)', line):
|
|
1449
|
+
code.append(line)
|
|
1450
|
+
else:
|
|
1451
|
+
output.append(line)
|
|
1452
|
+
ex_dict['code'] = '\n'.join(code) or None
|
|
1453
|
+
ex_dict['output'] = '\n'.join(output) or None
|
|
1454
|
+
if ex.post_snippet:
|
|
1455
|
+
ex_dict['postscript'] = ex.post_snippet
|
|
1456
|
+
doc_examples.append(ex_dict)
|
|
1457
|
+
|
|
1458
|
+
except Exception as e:
|
|
1459
|
+
self.logger.warning(
|
|
1460
|
+
'Could not parse docstring for function',
|
|
1461
|
+
extra={
|
|
1462
|
+
'app_name': self.name,
|
|
1463
|
+
'function_name': key.decode('utf-8'),
|
|
1464
|
+
'error': str(e),
|
|
1465
|
+
},
|
|
1466
|
+
)
|
|
1467
|
+
|
|
1468
|
+
if not func_name or key == func_name:
|
|
1469
|
+
sig = info['signature']
|
|
1470
|
+
args = []
|
|
1471
|
+
|
|
1472
|
+
# Function arguments
|
|
1473
|
+
for i, a in enumerate(sig.get('args', [])):
|
|
1474
|
+
name = a['name']
|
|
1475
|
+
dtype = a['dtype']
|
|
1476
|
+
nullable = '?' in dtype
|
|
1477
|
+
args.append(
|
|
1478
|
+
dict(
|
|
1479
|
+
name=name,
|
|
1480
|
+
dtype=dtype.replace('?', ''),
|
|
1481
|
+
nullable=nullable,
|
|
1482
|
+
description=(doc_params[name].description or '')
|
|
1483
|
+
if name in doc_params else '',
|
|
1484
|
+
),
|
|
1485
|
+
)
|
|
1486
|
+
if a.get('default', no_default) is not no_default:
|
|
1487
|
+
args[-1]['default'] = a['default']
|
|
1488
|
+
|
|
1489
|
+
# Return values
|
|
1490
|
+
ret = sig.get('returns', [])
|
|
1491
|
+
returns = []
|
|
1492
|
+
|
|
1493
|
+
for a in ret:
|
|
1494
|
+
dtype = a['dtype']
|
|
1495
|
+
nullable = '?' in dtype
|
|
1496
|
+
returns.append(
|
|
1497
|
+
dict(
|
|
1498
|
+
dtype=dtype.replace('?', ''),
|
|
1499
|
+
nullable=nullable,
|
|
1500
|
+
description=doc_returns.description
|
|
1501
|
+
if doc_returns else '',
|
|
1502
|
+
),
|
|
1503
|
+
)
|
|
1504
|
+
if a.get('name', None):
|
|
1505
|
+
returns[-1]['name'] = a['name']
|
|
1506
|
+
if a.get('default', no_default) is not no_default:
|
|
1507
|
+
returns[-1]['default'] = a['default']
|
|
1508
|
+
|
|
1509
|
+
sql = sql_map.get(sig['name'], '')
|
|
1510
|
+
functions[sig['name']] = dict(
|
|
1511
|
+
args=args,
|
|
1512
|
+
returns=returns,
|
|
1513
|
+
function_type=info['function_type'],
|
|
1514
|
+
sql_statement=sql,
|
|
1515
|
+
summary=doc_summary,
|
|
1516
|
+
long_description=doc_long_description,
|
|
1517
|
+
examples=doc_examples,
|
|
1518
|
+
)
|
|
1519
|
+
|
|
1520
|
+
return functions
|
|
1521
|
+
|
|
1522
|
+
def get_create_functions(
    self,
    replace: bool = False,
) -> List[str]:
    """
    Generate CREATE FUNCTION code for all functions.

    Parameters
    ----------
    replace : bool, optional
        Should existing functions be replaced?

    Returns
    -------
    List[str]

    """
    if not self.endpoints:
        return []

    statements: List[str] = []

    # Remote mode may require a LINK object; emit its DDL first.
    link = ''
    if self.app_mode.lower() == 'remote':
        link, link_sql = self._create_link(self.link_config, self.link_credentials)
        if link and link_sql:
            statements.append(link_sql)

    # One CREATE FUNCTION statement per registered endpoint.
    statements.extend(
        signature_to_sql(
            endpoint_info['signature'],
            url=self.url,
            data_format=self.data_format,
            app_mode=self.app_mode,
            replace=replace,
            link=link or None,
            database=self.function_database or None,
        )
        for _endpoint, endpoint_info in self.endpoints.values()
    )

    return statements
def register_functions(
    self,
    *connection_args: Any,
    replace: bool = False,
    **connection_kwargs: Any,
) -> None:
    """
    Register functions with the database.

    Parameters
    ----------
    *connection_args : Any
        Database connection parameters
    replace : bool, optional
        Should existing functions be replaced?
    **connection_kwargs : Any
        Database connection parameters

    """
    with connection.connect(*connection_args, **connection_kwargs) as conn, \
            conn.cursor() as cur:
        if replace:
            # Remove anything previously registered by this app before
            # re-creating the functions.
            funcs, links = self._locate_app_functions(cur)
            for fname in funcs:
                cur.execute(f'DROP FUNCTION IF EXISTS {fname}')
            for link in links:
                cur.execute(f'DROP LINK {link}')
        for func in self.get_create_functions(replace=replace):
            cur.execute(func)
def drop_functions(
    self,
    *connection_args: Any,
    **connection_kwargs: Any,
) -> None:
    """
    Drop registered functions from database.

    Parameters
    ----------
    *connection_args : Any
        Database connection parameters
    **connection_kwargs : Any
        Database connection parameters

    """
    with connection.connect(*connection_args, **connection_kwargs) as conn, \
            conn.cursor() as cur:
        # Locate every function / link this app registered, then drop them.
        funcs, links = self._locate_app_functions(cur)
        for fname in funcs:
            cur.execute(f'DROP FUNCTION IF EXISTS {fname}')
        for link in links:
            cur.execute(f'DROP LINK {link}')
async def call(
    self,
    name: str,
    data_in: io.BytesIO,
    data_out: io.BytesIO,
    data_format: Optional[str] = None,
    data_version: Optional[str] = None,
) -> None:
    """
    Call a function in the application.

    Parameters
    ----------
    name : str
        Name of the function to call
    data_in : io.BytesIO
        The input data rows
    data_out : io.BytesIO
        The output data rows
    data_format : str, optional
        The format of the input and output data
    data_version : str, optional
        The version of the data format

    """
    fmt = data_format or self.data_format
    version = data_version or self.data_version

    # Content-type bytes for each supported row format.
    content_types = dict(
        json=b'application/json',
        rowdat_1=b'application/octet-stream',
        arrow=b'application/vnd.apache.arrow.file',
    )
    content_type = content_types[fmt.lower()]

    # Minimal ASGI receive callable: deliver the whole body in one message.
    async def receive() -> Dict[str, Any]:
        return dict(body=data_in.read())

    # Minimal ASGI send callable: raise on non-200, otherwise copy body out.
    async def send(content: Dict[str, Any]) -> None:
        status = content.get('status', 200)
        if status != 200:
            raise KeyError(f'error occurred when calling `{name}`: {status}')
        data_out.write(content.get('body', b''))

    # Mock an ASGI scope
    scope = dict(
        type='http',
        path='invoke',
        method='POST',
        headers={
            b'content-type': content_type,
            b'accepts': content_type,
            b's2-ef-name': name.encode('utf-8'),
            b's2-ef-version': version.encode('utf-8'),
            b's2-ef-ignore-cancel': b'true',
        },
    )

    await self(scope, receive, send)
def to_environment(
    self,
    name: str,
    destination: str = '.',
    version: Optional[str] = None,
    dependencies: Optional[List[str]] = None,
    authors: Optional[List[Dict[str, str]]] = None,
    maintainers: Optional[List[Dict[str, str]]] = None,
    description: Optional[str] = None,
    container_service: Optional[Dict[str, Any]] = None,
    external_function: Optional[Dict[str, Any]] = None,
    external_function_remote: Optional[Dict[str, Any]] = None,
    external_function_collocated: Optional[Dict[str, Any]] = None,
    overwrite: bool = False,
) -> None:
    """
    Convert application to an environment file.

    The environment file is a zip archive containing a ``pyproject.toml``
    and a Python package whose functions are cloudpickled into its
    ``__init__.py``.

    Parameters
    ----------
    name : str
        Name of the output environment
    destination : str, optional
        Location of the output file; a local directory or a ``stage://`` URL
    version : str, optional
        Version of the package
    dependencies : List[str], optional
        List of dependency specifications like in a requirements.txt file
    authors : List[Dict[str, Any]], optional
        Dictionaries of author information. Keys may include: email, name
    maintainers : List[Dict[str, Any]], optional
        Dictionaries of maintainer information. Keys may include: email, name
    description : str, optional
        Description of package
    container_service : Dict[str, Any], optional
        Container service specifications
    external_function : Dict[str, Any], optional
        External function specifications (applies to both remote and collocated)
    external_function_remote : Dict[str, Any], optional
        Remote external function specifications
    external_function_collocated : Dict[str, Any], optional
        Collocated external function specifications
    overwrite : bool, optional
        Should destination file be overwritten if it exists?

    Raises
    ------
    RuntimeError
        If the cloudpickle package is not installed
    OSError
        If the local destination exists and ``overwrite`` is False
    ValueError
        If a ``stage://`` destination has no path or no workspace group

    """
    if not has_cloudpickle:
        # NOTE: corrected error message typo ('cloudpicke' -> 'cloudpickle')
        raise RuntimeError('the cloudpickle package is required for this operation')

    # Write to temporary location if a remote destination is specified
    tmpdir = None
    if destination.startswith('stage://'):
        tmpdir = tempfile.TemporaryDirectory()
        local_path = os.path.join(tmpdir.name, f'{name}.env')
    else:
        local_path = os.path.join(destination, f'{name}.env')
        if not overwrite and os.path.exists(local_path):
            raise OSError(f'path already exists: {local_path}')

    with zipfile.ZipFile(local_path, mode='w') as z:
        # Write metadata
        z.writestr(
            'pyproject.toml', utils.to_toml({
                'project': dict(
                    name=name,
                    version=version,
                    dependencies=dependencies,
                    # Pin to the exact running interpreter version: the
                    # pickled functions are only safe to load on it.
                    requires_python='== ' +
                    '.'.join(str(x) for x in sys.version_info[:3]),
                    authors=authors,
                    maintainers=maintainers,
                    description=description,
                ),
                'tool.container-service': container_service,
                'tool.external-function': external_function,
                'tool.external-function.remote': external_function_remote,
                'tool.external-function.collocated': external_function_collocated,
            }),
        )

        # Write Python package: __init__.py unpickles the functions into
        # the module namespace at import time.
        z.writestr(
            f'{name}/__init__.py',
            textwrap.dedent(f'''
                import pickle as _pkl
                globals().update(
                    _pkl.loads({cloudpickle.dumps(self.external_functions)}),
                )
                __all__ = {list(self.external_functions.keys())}''').strip(),
        )

    # Upload to Stage as needed
    if destination.startswith('stage://'):
        url = urllib.parse.urlparse(re.sub(r'/+$', r'', destination) + '/')
        if not url.path or url.path == '/':
            raise ValueError(f'no stage path was specified: {destination}')

        mgr = manage_workspaces()
        if url.hostname:
            wsg = mgr.get_workspace_group(url.hostname)
        elif os.environ.get('SINGLESTOREDB_WORKSPACE_GROUP'):
            wsg = mgr.get_workspace_group(
                os.environ['SINGLESTOREDB_WORKSPACE_GROUP'],
            )
        else:
            raise ValueError(f'no workspace group specified: {destination}')

        # Make intermediate directories
        if url.path.count('/') > 1:
            wsg.stage.mkdirs(os.path.dirname(url.path))

        wsg.stage.upload_file(
            local_path, url.path + f'{name}.env',
            overwrite=overwrite,
        )
        os.remove(local_path)
def main(argv: Optional[List[str]] = None) -> None:
    """
    Main program for HTTP-based Python UDFs

    Parses command-line options (possibly twice, when an environment file
    supplies new defaults), optionally registers the functions with a
    database, then serves the UDF application with uvicorn.

    Parameters
    ----------
    argv : List[str], optional
        List of command-line parameters

    """
    try:
        import uvicorn
    except ImportError:
        raise ImportError('the uvicorn package is required to run this command')

    # Should we run in embedded mode (typically for Jupyter)
    try:
        asyncio.get_running_loop()
        use_async = True
    except RuntimeError:
        use_async = False

    # Temporary directory for Stage environment files
    tmpdir = None

    # Depending on whether we find an environment file specified, we
    # may have to process the command line twice.
    functions = []
    defaults: Dict[str, Any] = {}
    for i in range(2):

        # The parser is rebuilt each pass so that `defaults` (populated
        # from an environment file on pass 1) feed into argument defaults.
        parser = argparse.ArgumentParser(
            prog='python -m singlestoredb.functions.ext.asgi',
            description='Run an HTTP-based Python UDF server',
        )
        parser.add_argument(
            '--url', metavar='url',
            default=defaults.get(
                'url',
                get_option('external_function.url'),
            ),
            help='URL of the UDF server endpoint',
        )
        parser.add_argument(
            '--host', metavar='host',
            default=defaults.get(
                'host',
                get_option('external_function.host'),
            ),
            help='bind socket to this host',
        )
        parser.add_argument(
            '--port', metavar='port', type=int,
            default=defaults.get(
                'port',
                get_option('external_function.port'),
            ),
            help='bind socket to this port',
        )
        parser.add_argument(
            '--db', metavar='conn-str',
            default=defaults.get(
                'connection',
                get_option('external_function.connection'),
            ),
            help='connection string to use for registering functions',
        )
        parser.add_argument(
            '--replace-existing', action='store_true',
            help='should existing functions of the same name '
                 'in the database be replaced?',
        )
        parser.add_argument(
            '--data-format', metavar='format',
            default=defaults.get(
                'data_format',
                get_option('external_function.data_format'),
            ),
            choices=['rowdat_1', 'json'],
            help='format of the data rows',
        )
        parser.add_argument(
            '--data-version', metavar='version',
            default=defaults.get(
                'data_version',
                get_option('external_function.data_version'),
            ),
            help='version of the data row format',
        )
        parser.add_argument(
            '--link-name', metavar='name',
            default=defaults.get(
                'link_name',
                get_option('external_function.link_name'),
            ) or '',
            help='name of the link to use for connections',
        )
        parser.add_argument(
            '--link-config', metavar='json',
            default=str(
                defaults.get(
                    'link_config',
                    get_option('external_function.link_config'),
                ) or '{}',
            ),
            help='link config in JSON format',
        )
        parser.add_argument(
            '--link-credentials', metavar='json',
            default=str(
                defaults.get(
                    'link_credentials',
                    get_option('external_function.link_credentials'),
                ) or '{}',
            ),
            help='link credentials in JSON format',
        )
        parser.add_argument(
            '--log-level', metavar='[info|debug|warning|error]',
            default=defaults.get(
                'log_level',
                get_option('external_function.log_level'),
            ),
            help='logging level',
        )
        parser.add_argument(
            '--log-file', metavar='filepath',
            default=defaults.get(
                'log_file',
                get_option('external_function.log_file'),
            ),
            help='File path to write logs to instead of console',
        )
        parser.add_argument(
            '--disable-metrics', action='store_true',
            default=defaults.get(
                'disable_metrics',
                get_option('external_function.disable_metrics'),
            ),
            help='Disable logging of function call metrics',
        )
        parser.add_argument(
            '--name-prefix', metavar='name_prefix',
            default=defaults.get(
                'name_prefix',
                get_option('external_function.name_prefix'),
            ),
            help='Prefix to add to function names',
        )
        parser.add_argument(
            '--name-suffix', metavar='name_suffix',
            default=defaults.get(
                'name_suffix',
                get_option('external_function.name_suffix'),
            ),
            help='Suffix to add to function names',
        )
        parser.add_argument(
            '--function-database', metavar='function_database',
            default=defaults.get(
                'function_database',
                get_option('external_function.function_database'),
            ),
            help='Database to use for the function definition',
        )
        parser.add_argument(
            '--app-name', metavar='app_name',
            default=defaults.get(
                'app_name',
                get_option('external_function.app_name'),
            ),
            help='Name for the application instance',
        )
        parser.add_argument(
            'functions', metavar='module.or.func.path', nargs='*',
            help='functions or modules to export in UDF server',
        )

        args = parser.parse_args(argv)

        # The second pass exists only to re-parse with the environment
        # file's defaults applied; nothing below needs to run again.
        if i > 0:
            break

        # Download Stage files as needed
        # NOTE(review): this loop reuses `i` from the outer loop; safe
        # because the `i > 0` check above has already run, but confusing.
        for i, f in enumerate(args.functions):
            if f.startswith('stage://'):
                url = urllib.parse.urlparse(f)
                if not url.path or url.path == '/':
                    raise ValueError(f'no stage path was specified: {f}')
                if url.path.endswith('/'):
                    raise ValueError(f'an environment file must be specified: {f}')

                mgr = manage_workspaces()
                if url.hostname:
                    wsg = mgr.get_workspace_group(url.hostname)
                elif os.environ.get('SINGLESTOREDB_WORKSPACE_GROUP'):
                    wsg = mgr.get_workspace_group(
                        os.environ['SINGLESTOREDB_WORKSPACE_GROUP'],
                    )
                else:
                    raise ValueError(f'no workspace group specified: {f}')

                if tmpdir is None:
                    tmpdir = tempfile.TemporaryDirectory()

                local_path = os.path.join(tmpdir.name, url.path.split('/')[-1])
                wsg.stage.download_file(url.path, local_path)
                args.functions[i] = local_path

            elif f.startswith('http://') or f.startswith('https://'):
                if tmpdir is None:
                    tmpdir = tempfile.TemporaryDirectory()

                local_path = os.path.join(tmpdir.name, f.split('/')[-1])
                urllib.request.urlretrieve(f, local_path)
                args.functions[i] = local_path

        # See if any of the args are zip files (assume they are environment files)
        modules = [(x, zipfile.is_zipfile(x)) for x in args.functions]
        envs = [x[0] for x in modules if x[1]]
        others = [x[0] for x in modules if not x[1]]

        if envs and len(envs) > 1:
            raise RuntimeError('only one environment file may be specified')

        if envs and others:
            raise RuntimeError('environment files and other modules can not be mixed.')

        # See if an environment file was specified. If so, use those settings
        # as the defaults and reprocess command line.
        if envs:
            # Add pyproject.toml variables and redo command-line processing
            defaults = utils.read_config(
                envs[0],
                ['tool.external-function', 'tool.external-function.remote'],
            )

            # Load zip file as a module
            modname = os.path.splitext(os.path.basename(envs[0]))[0]
            zi = zipimport.zipimporter(envs[0])
            mod = zi.load_module(modname)
            if mod is None:
                raise RuntimeError(f'environment file could not be imported: {envs[0]}')
            functions = [mod]

        # Only re-parse when the environment file actually supplied defaults.
        if defaults:
            continue

    args.functions = functions or args.functions or None
    args.replace_existing = args.replace_existing \
        or defaults.get('replace_existing') \
        or get_option('external_function.replace_existing')

    # Substitute in host / port if specified
    if args.host != defaults.get('host') or args.port != defaults.get('port'):
        u = urllib.parse.urlparse(args.url)
        args.url = u._replace(netloc=f'{args.host}:{args.port}').geturl()

    # Create application from functions / module
    app = Application(
        functions=args.functions,
        url=args.url,
        data_format=args.data_format,
        data_version=args.data_version,
        link_name=args.link_name or None,
        link_config=json.loads(args.link_config) or None,
        link_credentials=json.loads(args.link_credentials) or None,
        app_mode='remote',
        name_prefix=args.name_prefix,
        name_suffix=args.name_suffix,
        function_database=args.function_database or None,
        log_file=args.log_file,
        log_level=args.log_level,
        disable_metrics=args.disable_metrics,
        app_name=args.app_name,
    )

    funcs = app.get_create_functions(replace=args.replace_existing)
    if not funcs:
        raise RuntimeError('no functions specified')

    # Log the CREATE FUNCTION statements that will be used.
    for f in funcs:
        app.logger.info(f)

    try:
        if args.db:
            app.logger.info('Registering functions with database')
            app.register_functions(
                args.db,
                replace=args.replace_existing,
            )

        app_args = {
            k: v for k, v in dict(
                host=args.host or None,
                port=args.port or None,
                log_level=args.log_level,
                lifespan='off',
            ).items() if v is not None
        }

        # Configure uvicorn logging to use JSON format matching Application's format
        app_args['log_config'] = app.get_uvicorn_log_config()

        if use_async:
            # Embedded mode: schedule the server on the running loop;
            # _run_uvicorn drops the functions itself after shutdown.
            asyncio.create_task(_run_uvicorn(uvicorn, app, app_args, db=args.db))
        else:
            uvicorn.run(app, **app_args)

    finally:
        # In blocking mode, clean up registered functions on exit;
        # in async mode _run_uvicorn handles this after serve() returns.
        if not use_async and args.db:
            app.logger.info('Dropping functions from database')
            app.drop_functions(args.db)
async def _run_uvicorn(
    uvicorn: Any,
    app: Any,
    app_args: Any,
    db: Optional[str] = None,
) -> None:
    """Serve *app* with uvicorn, then drop registered functions on shutdown."""
    config = uvicorn.Config(app, **app_args)
    server = uvicorn.Server(config)
    await server.serve()

    # After the server shuts down, remove the functions we registered.
    if db:
        app.logger.info('Dropping functions from database')
        app.drop_functions(db)
# Backwards-compatible alias: ``create_app`` is the conventional ASGI
# application-factory name (e.g. ``uvicorn module:create_app``).
create_app = Application


if __name__ == '__main__':
    try:
        main()
    except RuntimeError as exc:
        # Configuration errors are reported as a single log line rather
        # than a traceback, and exit with a failure status.
        logger.error(str(exc))
        sys.exit(1)
    except KeyboardInterrupt:
        # Ctrl-C during serving is a normal shutdown, not an error.
        pass