libentry 1.22.4__py3-none-any.whl → 1.23.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
libentry/schema.py CHANGED
@@ -2,9 +2,10 @@
2
2
 
3
3
  __author__ = "xi"
4
4
  __all__ = [
5
+ "APISignature",
6
+ "get_api_signature",
5
7
  "SchemaField",
6
8
  "Schema",
7
- "ParseContext",
8
9
  "parse_type",
9
10
  "QueryAPIOutput",
10
11
  "query_api",
@@ -13,14 +14,73 @@ __all__ = [
13
14
  import enum
14
15
  from dataclasses import asdict, dataclass, is_dataclass
15
16
  from inspect import signature
16
- from typing import Any, Dict, Iterable, List, Literal, Mapping, MutableMapping, NoReturn, Optional, Sequence, Union, \
17
- get_args, \
18
- get_origin
17
+ from typing import Any, Dict, Iterable, List, Literal, Mapping, MutableMapping, NoReturn, Optional, Sequence, Type, \
18
+ Union, get_args, get_origin
19
19
 
20
- from pydantic import BaseModel, Field, create_model
20
+ from pydantic import BaseModel, ConfigDict, Field, RootModel, create_model
21
21
  from pydantic_core import PydanticUndefined
22
22
 
23
23
 
24
+ class APISignature(BaseModel):
25
+ input_types: List[Any]
26
+ input_model: Optional[Type[BaseModel]] = None
27
+ bundled_model: Optional[Type[BaseModel]] = None
28
+ output_type: Optional[Any] = None
29
+ output_model: Optional[Type[BaseModel]] = None
30
+
31
+
32
+ def get_api_signature(fn, ignores: List[str] = ("self", "cls")) -> APISignature:
33
+ sig = signature(fn)
34
+
35
+ input_types = []
36
+ fields = {}
37
+ for name, param in sig.parameters.items():
38
+ if name in ignores:
39
+ continue
40
+
41
+ annotation = param.annotation
42
+ if annotation is sig.empty:
43
+ annotation = Any
44
+
45
+ input_types.append(annotation)
46
+
47
+ default = param.default
48
+ field = Field() if default is sig.empty else Field(default=default)
49
+ fields[name] = (annotation, field)
50
+
51
+ input_model = None
52
+ bundled_model = None
53
+ if len(input_types) == 1:
54
+ for annotation in input_types:
55
+ origin = get_origin(annotation) or annotation
56
+ if isinstance(origin, type) and issubclass(origin, BaseModel):
57
+ input_model = origin
58
+ if input_model is None:
59
+ name = "".join(word.capitalize() for word in fn.__name__.split("_"))
60
+ bundled_model = create_model(
61
+ f"{name}Request*",
62
+ __config__=ConfigDict(extra="forbid"),
63
+ **fields
64
+ )
65
+
66
+ output_type = None
67
+ output_model = None
68
+ output_annotation = sig.return_annotation
69
+ if output_annotation is not None and output_annotation is not NoReturn:
70
+ if output_annotation is sig.empty:
71
+ output_annotation = Any
72
+ output_type = output_annotation
73
+ output_model = RootModel[output_annotation]
74
+
75
+ return APISignature(
76
+ input_types=input_types,
77
+ input_model=input_model,
78
+ bundled_model=bundled_model,
79
+ output_type=output_type,
80
+ output_model=output_model
81
+ )
82
+
83
+
24
84
  class SchemaField(BaseModel):
25
85
  name: str = Field()
26
86
  type: Union[str, List[str]] = Field("Any")
@@ -75,13 +135,14 @@ def type_parser(fn):
75
135
 
76
136
 
77
137
  @type_parser
78
- def _parse_basic_types(context: ParseContext):
138
+ def _parse_basic_types(context: ParseContext) -> Optional[str]:
79
139
  if context.origin in {int, float, str, bool}:
80
140
  return context.origin.__name__
141
+ return None
81
142
 
82
143
 
83
144
  @type_parser
84
- def _parse_dict(context: ParseContext):
145
+ def _parse_dict(context: ParseContext) -> Optional[str]:
85
146
  if issubclass(context.origin, Mapping):
86
147
  dict_args = get_args(context.annotation)
87
148
  if dict_args:
@@ -95,10 +156,11 @@ def _parse_dict(context: ParseContext):
95
156
  return f"Dict[{key_type},{value_type}]"
96
157
  else:
97
158
  return "Dict"
159
+ return None
98
160
 
99
161
 
100
162
  @type_parser
101
- def _parse_list(context: ParseContext):
163
+ def _parse_list(context: ParseContext) -> Optional[str]:
102
164
  if issubclass(context.origin, Sequence):
103
165
  list_args = get_args(context.annotation)
104
166
  if list_args:
@@ -110,16 +172,18 @@ def _parse_list(context: ParseContext):
110
172
  return f"List[{elem_type}]"
111
173
  else:
112
174
  return "List"
175
+ return None
113
176
 
114
177
 
115
178
  @type_parser
116
- def _parse_enum(context: ParseContext):
179
+ def _parse_enum(context: ParseContext) -> Optional[str]:
117
180
  if issubclass(context.origin, enum.Enum):
118
181
  return f"Enum[{','.join(e.name for e in context.origin)}]"
182
+ return None
119
183
 
120
184
 
121
185
  @type_parser
122
- def _parse_base_model(context: ParseContext):
186
+ def _parse_base_model(context: ParseContext) -> Optional[str]:
123
187
  origin = context.origin
124
188
  if issubclass(origin, BaseModel):
125
189
  _module = origin.__module__
@@ -130,6 +194,11 @@ def _parse_base_model(context: ParseContext):
130
194
  is_not_base_class = origin is not BaseModel
131
195
  if is_new_model and is_not_base_class:
132
196
  schema = Schema(name=model_name)
197
+ context.schemas[model_name] = schema
198
+ # Once the Schema object is created, it should be put into the context immediately.
199
+ # This is to prevent this object from being repeatedly parsed.
200
+ # Repeated parsing will cause dead recursion!
201
+
133
202
  fields = origin.model_fields
134
203
  assert isinstance(fields, Mapping)
135
204
  for name, field in fields.items():
@@ -152,13 +221,13 @@ def _parse_base_model(context: ParseContext):
152
221
  if is_dataclass(md):
153
222
  schema_field.metadata.update(asdict(md))
154
223
  schema.fields.append(schema_field)
155
- context.schemas[model_name] = schema
156
224
 
157
225
  return model_name
226
+ return None
158
227
 
159
228
 
160
229
  @type_parser
161
- def _parse_iterable(context: ParseContext):
230
+ def _parse_iterable(context: ParseContext) -> Optional[str]:
162
231
  if context.origin.__name__ in {"Iterable", "Generator", "range"} and issubclass(context.origin, Iterable):
163
232
  iter_args = get_args(context.annotation)
164
233
  if len(iter_args) != 1:
@@ -167,20 +236,23 @@ def _parse_iterable(context: ParseContext):
167
236
  if isinstance(iter_type, list):
168
237
  raise TypeError("\"Union\" cannot be used as the type of iterable elements.")
169
238
  return f"Iter[{iter_type}]"
239
+ return None
170
240
 
171
241
 
172
242
  @type_parser
173
- def _parse_none_type(context: ParseContext):
243
+ def _parse_none_type(context: ParseContext) -> Optional[str]:
174
244
  origin = context.origin
175
245
  if origin.__module__ == "builtins" and origin.__name__ == "NoneType":
176
246
  return "NoneType"
247
+ return None
177
248
 
178
249
 
179
250
  @type_parser
180
- def _parse_ndarray(context: ParseContext):
251
+ def _parse_ndarray(context: ParseContext) -> Optional[str]:
181
252
  origin = context.origin
182
253
  if origin.__module__ == "numpy" and origin.__name__ == "ndarray":
183
254
  return "numpy.ndarray"
255
+ return None
184
256
 
185
257
 
186
258
  def generic_parser(fn):
@@ -189,25 +261,28 @@ def generic_parser(fn):
189
261
 
190
262
 
191
263
  @generic_parser
192
- def _parse_any(context: ParseContext):
264
+ def _parse_any(context: ParseContext) -> Optional[str]:
193
265
  if context.origin is Any or str(context.origin) == str(Any):
194
266
  return "Any"
267
+ return None
195
268
 
196
269
 
197
270
  @generic_parser
198
- def _parse_union(context: ParseContext):
271
+ def _parse_union(context: ParseContext) -> Optional[List[str]]:
199
272
  if context.origin is Union or str(context.origin) == str(Union):
200
273
  return [
201
274
  parse_type(arg, context.schemas)
202
275
  for arg in get_args(context.annotation)
203
276
  ]
277
+ return None
204
278
 
205
279
 
206
280
  @generic_parser
207
- def _parse_literal(context: ParseContext):
281
+ def _parse_literal(context: ParseContext) -> Optional[str]:
208
282
  if context.origin is Literal or str(context.origin) == str(Literal):
209
283
  enum_args = get_args(context.annotation)
210
284
  return f"Enum[{','.join(map(str, enum_args))}]"
285
+ return None
211
286
 
212
287
 
213
288
  class QueryAPIOutput(BaseModel):
@@ -217,49 +292,33 @@ class QueryAPIOutput(BaseModel):
217
292
  bundled_input: bool
218
293
 
219
294
 
220
- def query_api(fn) -> QueryAPIOutput:
221
- sig = signature(fn)
222
-
223
- fields = {}
224
- for name, param in sig.parameters.items():
225
- if name in ["self", "cls"]:
226
- continue
227
-
228
- annotation = param.annotation
229
- if annotation is sig.empty:
230
- annotation = Any
231
-
232
- default = param.default
233
- field = Field() if default is sig.empty else Field(default)
234
- fields[name] = (annotation, field)
235
-
236
- args_model = None
237
- if len(fields) == 1:
238
- for annotation, _ in fields.values():
239
- origin = get_origin(annotation)
240
- if origin is None:
241
- origin = annotation
242
- if isinstance(origin, type) and issubclass(origin, BaseModel):
243
- args_model = origin
244
- bundle = args_model is None
245
- if bundle:
246
- name = "".join(word.capitalize() for word in fn.__name__.split("_"))
247
- args_model = create_model(f"{name}Request*", **fields)
295
+ def query_api(obj) -> QueryAPIOutput:
296
+ api_models = obj if isinstance(obj, APISignature) else get_api_signature(obj)
248
297
 
249
298
  context = {}
299
+
300
+ args_model = api_models.input_model or api_models.bundled_model
250
301
  input_schema = parse_type(args_model, context)
302
+
251
303
  output_schema = None
252
- return_annotation = sig.return_annotation
253
- if return_annotation is not None and return_annotation is not NoReturn:
254
- if return_annotation is sig.empty:
255
- return_annotation = Any
256
- output_schema = parse_type(return_annotation, context)
304
+ if api_models.output_type is not None:
305
+ output_schema = parse_type(api_models.output_type, context)
257
306
  if isinstance(output_schema, list):
258
307
  output_schema = output_schema[0]
259
308
 
309
+ # output_schema = None
310
+ # sig = signature(fn)
311
+ # return_annotation = sig.return_annotation
312
+ # if return_annotation is not None and return_annotation is not NoReturn:
313
+ # if return_annotation is sig.empty:
314
+ # return_annotation = Any
315
+ # output_schema = parse_type(return_annotation, context)
316
+ # if isinstance(output_schema, list):
317
+ # output_schema = output_schema[0]
318
+
260
319
  return QueryAPIOutput(
261
320
  input_schema=input_schema,
262
321
  output_schema=output_schema,
263
322
  context=context,
264
- bundled_input=bundle,
323
+ bundled_input=api_models.bundled_model is not None,
265
324
  )
libentry/service/flask.py CHANGED
@@ -380,7 +380,7 @@ def run_service(
380
380
  if backlog is None or backlog < num_threads * 2:
381
381
  backlog = num_threads * 2
382
382
 
383
- def ssl_context(config, default_ssl_context_factory):
383
+ def ssl_context(config, _default_ssl_context_factory):
384
384
  import ssl
385
385
  context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
386
386
  context.load_cert_chain(
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.2
2
2
  Name: libentry
3
- Version: 1.22.4
3
+ Version: 1.23.1
4
4
  Summary: Entries for experimental utilities.
5
5
  Home-page: https://github.com/XoriieInpottn/libentry
6
6
  Author: xi
@@ -29,6 +29,7 @@ Dynamic: requires-dist
29
29
  Dynamic: summary
30
30
 
31
31
  # libentry
32
+ [![Ask DeepWiki](https://deepwiki.com/badge.svg)](https://deepwiki.com/XoriieInpottn/libentry)
32
33
 
33
34
  ## Define a Service Class
34
35
  1. Define a normal python class.
@@ -6,21 +6,25 @@ libentry/executor.py,sha256=cTV0WxJi0nU1TP-cOwmeodN8DD6L1691M2HIQsJtGrU,6582
6
6
  libentry/experiment.py,sha256=ejgAHDXWIe9x4haUzIFuz1WasLY0_aD1z_vyEVGjTu8,4922
7
7
  libentry/json.py,sha256=CubUUu29h7idLaC4d66vKhjBgVHKN1rZOv-Tw2qM17k,1916
8
8
  libentry/logging.py,sha256=IiYoCUzm8XTK1fduA-NA0FI2Qz_m81NEPV3d3tEfgdI,1349
9
- libentry/schema.py,sha256=o6JcdR00Yj4_Qjmlo100OlQpMVnl0PgvvwVVrL9limw,8268
9
+ libentry/schema.py,sha256=jf_rNe-1VRd4TxcZH-M9TmMJYZTHgfqWggpUCq8bNUU,10337
10
10
  libentry/test_api.py,sha256=Xw7B7sH6g1iCTV5sFzyBF3JAJzeOr9xg0AyezTNsnIk,4452
11
11
  libentry/utils.py,sha256=O7P6GadtUIjq0N2IZH7PhHZDUM3NebzcqyDqytet7CM,683
12
+ libentry/mcp/__init__.py,sha256=1oLL20yLB1GL9IbFiZD8OReDqiCpFr-yetIR6x1cNkI,23
13
+ libentry/mcp/api.py,sha256=uoGBYCesMj6umlJpRulKZNS3trm9oG3LUSg1otPDS_8,2362
14
+ libentry/mcp/client.py,sha256=lM_bTF40pbdYdBrMmoOqUDRzlNgjqEKh5d4IVkpI6D8,21512
15
+ libentry/mcp/service.py,sha256=KDpEUhHuyVXjc_J5Z9_aciJbTcEy9dYA44rpdgAAwGE,32322
16
+ libentry/mcp/types.py,sha256=xTQCnKAgeJNss4klJ33MrWHGCzG_LeR3urizO_Z9q9U,12239
12
17
  libentry/service/__init__.py,sha256=1oLL20yLB1GL9IbFiZD8OReDqiCpFr-yetIR6x1cNkI,23
13
18
  libentry/service/common.py,sha256=OVaW2afgKA6YqstJmtnprBCqQEUZEWotZ6tHavmJJeU,42
14
- libentry/service/flask.py,sha256=SDaZnhkS3Zk6y8CytVO_awwQ3RUiY7qSuMkYAgTu_SU,13816
15
- libentry/service/flask_mcp.py,sha256=guzDVVT4gfjhFhnLbMSTWYARyxqbEv1gDaI6SLKurdU,11540
19
+ libentry/service/flask.py,sha256=2egCFFhRAfLpmSyibgaJ-3oexI-j27P1bmaPEn-hSlc,13817
16
20
  libentry/service/list.py,sha256=ElHWhTgShGOhaxMUEwVbMXos0NQKjHsODboiQ-3AMwE,1397
17
21
  libentry/service/running.py,sha256=FrPJoJX6wYxcHIysoatAxhW3LajCCm0Gx6l7__6sULQ,5105
18
22
  libentry/service/start.py,sha256=mZT7b9rVULvzy9GTZwxWnciCHgv9dbGN2JbxM60OMn4,1270
19
23
  libentry/service/stop.py,sha256=wOpwZgrEJ7QirntfvibGq-XsTC6b3ELhzRW2zezh-0s,1187
20
- libentry-1.22.4.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
21
- libentry-1.22.4.dist-info/METADATA,sha256=B9WWlSfqYclWBLafid59zTg2GMh78j5cp7uOGTKznUo,1040
22
- libentry-1.22.4.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
23
- libentry-1.22.4.dist-info/entry_points.txt,sha256=1v_nLVDsjvVJp9SWhl4ef2zZrsLTBtFWgrYFgqvQBgc,61
24
- libentry-1.22.4.dist-info/top_level.txt,sha256=u2uF6-X5fn2Erf9PYXOg_6tntPqTpyT-yzUZrltEd6I,9
25
- libentry-1.22.4.dist-info/zip-safe,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
26
- libentry-1.22.4.dist-info/RECORD,,
24
+ libentry-1.23.1.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
25
+ libentry-1.23.1.dist-info/METADATA,sha256=HUq8bgDgb7i2Dm9zNcXAzYE-0cpzwYbDCNUnkHEHsks,1135
26
+ libentry-1.23.1.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
27
+ libentry-1.23.1.dist-info/entry_points.txt,sha256=1v_nLVDsjvVJp9SWhl4ef2zZrsLTBtFWgrYFgqvQBgc,61
28
+ libentry-1.23.1.dist-info/top_level.txt,sha256=u2uF6-X5fn2Erf9PYXOg_6tntPqTpyT-yzUZrltEd6I,9
29
+ libentry-1.23.1.dist-info/zip-safe,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
30
+ libentry-1.23.1.dist-info/RECORD,,
@@ -1,337 +0,0 @@
1
- #!/usr/bin/env python3
2
-
3
- __author__ = "xi"
4
- __all__ = [
5
- "MCPMethod",
6
- ]
7
-
8
- import asyncio
9
- import traceback
10
- from inspect import signature
11
- from types import GeneratorType
12
- from typing import Callable, Dict, Iterable, Optional, Type, Union
13
-
14
- from flask import Flask, request as flask_request
15
- from pydantic import BaseModel
16
-
17
- from libentry import api, json, logger
18
- from libentry.api import list_api_info
19
- from libentry.schema import query_api
20
-
21
- try:
22
- from gunicorn.app.base import BaseApplication
23
- except ImportError:
24
- class BaseApplication:
25
-
26
- def load(self) -> Flask:
27
- pass
28
-
29
- def run(self):
30
- flask_server = self.load()
31
- assert hasattr(self, "options")
32
- bind = getattr(self, "options")["bind"]
33
- pos = bind.rfind(":")
34
- host = bind[:pos]
35
- port = int(bind[pos + 1:])
36
- logger.warn("Your system doesn't support gunicorn.")
37
- logger.warn("Use Flask directly.")
38
- logger.warn("Options like \"num_threads\", \"num_workers\" are ignored.")
39
- return flask_server.run(host=host, port=port)
40
-
41
-
42
- class MCPMethod:
43
-
44
- def __init__(self, fn: Callable, method: str = None):
45
- self.fn = fn
46
- assert hasattr(fn, "__name__")
47
- self.__name__ = fn.__name__
48
- self.method = self.__name__ if method is None else method
49
-
50
- self.input_schema = None
51
- params = signature(fn).parameters
52
- if len(params) == 1:
53
- for name, value in params.items():
54
- annotation = value.annotation
55
- if isinstance(annotation, type) and issubclass(annotation, BaseModel):
56
- self.input_schema = annotation
57
-
58
- def __call__(self, request: dict) -> Union[dict, Iterable[dict]]:
59
- try:
60
- jsonrpc_version = request["jsonrpc"]
61
- request_id = request["id"]
62
- method = request["method"]
63
- except KeyError:
64
- raise RuntimeError("Invalid JSON-RPC specification.")
65
-
66
- if not isinstance(request_id, (str, int)):
67
- raise RuntimeError(
68
- f"Request ID should be an integer or string. "
69
- f"Got {type(request_id)}."
70
- )
71
-
72
- if method != self.method:
73
- raise RuntimeError(
74
- f"Method missmatch."
75
- f"Expect {self.method}, got {method}."
76
- )
77
-
78
- params = request.get("params", {})
79
-
80
- try:
81
- if self.input_schema is not None:
82
- # Note that "input_schema is not None" means:
83
- # (1) The function has only one argument;
84
- # (2) The arguments is a BaseModel.
85
- # In this case, the request data can be directly validated as a "BaseModel" and
86
- # subsequently passed to the function as a single object.
87
- pydantic_params = self.input_schema.model_validate(params)
88
- result = self.fn(pydantic_params)
89
- else:
90
- # The function has multiple arguments, and the request data bundle them as a single object.
91
- # So, they should be unpacked before pass to the function.
92
- result = self.fn(**params)
93
- except Exception as e:
94
- if isinstance(e, (SystemExit, KeyboardInterrupt)):
95
- raise e
96
- return {
97
- "jsonrpc": jsonrpc_version,
98
- "id": request_id,
99
- "error": self._make_error(e)
100
- }
101
-
102
- if not isinstance(result, (GeneratorType, range)):
103
- return {
104
- "jsonrpc": jsonrpc_version,
105
- "id": request_id,
106
- "result": result
107
- }
108
-
109
- return ({
110
- "jsonrpc": jsonrpc_version,
111
- "id": request_id,
112
- "result": item
113
- } for item in result)
114
-
115
- @staticmethod
116
- def _make_error(e):
117
- err_cls = e.__class__
118
- err_name = err_cls.__name__
119
- module = err_cls.__module__
120
- if module != "builtins":
121
- err_name = f"{module}.{err_name}"
122
- return {
123
- "code": 1,
124
- "message": f"{err_name}: {str(e)}",
125
- "data": traceback.format_exc()
126
- }
127
-
128
-
129
- class FlaskMethod:
130
-
131
- def __init__(self, method, api_info, app):
132
- self.method = MCPMethod(method)
133
- self.api_info = api_info
134
- self.app = app
135
- assert hasattr(method, "__name__")
136
- self.__name__ = method.__name__
137
-
138
- CONTENT_TYPE_JSON = "application/json"
139
- CONTENT_TYPE_SSE = "text/event-stream"
140
-
141
- def __call__(self):
142
- args = flask_request.args
143
- data = flask_request.data
144
- content_type = flask_request.content_type
145
- accepts = flask_request.accept_mimetypes
146
-
147
- json_from_url = {**args}
148
- if data:
149
- if (not content_type) or content_type == self.CONTENT_TYPE_JSON:
150
- json_from_data = json.loads(data)
151
- else:
152
- return self.app.error(f"Unsupported Content-Type: \"{content_type}\".")
153
- else:
154
- json_from_data = {}
155
-
156
- conflicts = json_from_url.keys() & json_from_data.keys()
157
- if len(conflicts) > 0:
158
- return self.app.error(f"Duplicated fields: \"{conflicts}\".")
159
-
160
- input_json = {**json_from_url, **json_from_data}
161
- print(input_json)
162
-
163
- try:
164
- output_json = self.method(input_json)
165
- except Exception as e:
166
- return self.app.error(str(e))
167
-
168
- if isinstance(output_json, Dict):
169
- if self.CONTENT_TYPE_JSON in accepts:
170
- return self.app.ok(json.dumps(output_json), mimetype=self.CONTENT_TYPE_JSON)
171
- else:
172
- return self.app.error(f"Unsupported Accept: \"{[*accepts]}\".")
173
- elif isinstance(output_json, (GeneratorType, range)):
174
- if self.CONTENT_TYPE_SSE in accepts:
175
- # todo
176
- return self.app.ok(json.dumps(output_json), mimetype=self.CONTENT_TYPE_SSE)
177
- else:
178
- return self.app.error(f"Unsupported Accept: \"{[*accepts]}\".")
179
-
180
-
181
- class FlaskServer(Flask):
182
-
183
- def __init__(self, service):
184
- super().__init__(__name__)
185
- self.service = service
186
-
187
- logger.info("Initializing Flask application.")
188
- self.api_info_list = list_api_info(service)
189
- if len(self.api_info_list) == 0:
190
- logger.error("No API found, nothing to serve.")
191
- return
192
-
193
- for fn, api_info in self.api_info_list:
194
- method = api_info.method
195
- path = api_info.path
196
- if asyncio.iscoroutinefunction(fn):
197
- logger.error(f"Async function \"{fn.__name__}\" is not supported.")
198
- continue
199
- logger.info(f"Serving {method}-API for {path}")
200
-
201
- wrapped_fn = FlaskMethod(fn, api_info, self)
202
- if method == "GET":
203
- self.get(path)(wrapped_fn)
204
- elif method == "POST":
205
- self.post(path)(wrapped_fn)
206
- else:
207
- raise RuntimeError(f"Unsupported method \"{method}\" for ")
208
-
209
- for fn, api_info in list_api_info(self):
210
- method = api_info.method
211
- path = api_info.path
212
-
213
- if any(api_info.path == a.path for _, a in self.api_info_list):
214
- logger.info(f"Use custom implementation of {path}.")
215
- continue
216
-
217
- if asyncio.iscoroutinefunction(fn):
218
- logger.error(f"Async function \"{fn.__name__}\" is not supported.")
219
- continue
220
- logger.info(f"Serving {method}-API for {path}")
221
-
222
- wrapped_fn = FlaskMethod(fn, api_info, self)
223
- if method == "GET":
224
- self.get(path)(wrapped_fn)
225
- elif method == "POST":
226
- self.post(path)(wrapped_fn)
227
- else:
228
- raise RuntimeError(f"Unsupported method \"{method}\" for ")
229
-
230
- logger.info("Flask application initialized.")
231
-
232
- @api.get("/")
233
- def index(self, name: str = None):
234
- if name is None:
235
- all_api = []
236
- for _, api_info in self.api_info_list:
237
- all_api.append({"path": api_info.path})
238
- return all_api
239
-
240
- for fn, api_info in self.api_info_list:
241
- if api_info.path == "/" + name:
242
- return query_api(fn).model_dump()
243
-
244
- return f"No API named \"{name}\""
245
-
246
- @api.get()
247
- def live(self):
248
- return "OK"
249
-
250
- def ok(self, body: Union[str, Iterable[str]], mimetype: str):
251
- return self.response_class(body, status=200, mimetype=mimetype)
252
-
253
- def error(self, body: str, mimetype="text"):
254
- return self.response_class(body, status=500, mimetype=mimetype)
255
-
256
-
257
- class GunicornApplication(BaseApplication):
258
-
259
- def __init__(self, service_type, service_config=None, options=None):
260
- self.service_type = service_type
261
- self.service_config = service_config
262
- self.options = options or {}
263
- super().__init__()
264
-
265
- def load_config(self):
266
- config = {
267
- key: value
268
- for key, value in self.options.items()
269
- if key in self.cfg.settings and value is not None
270
- }
271
- for key, value in config.items():
272
- self.cfg.set(key.lower(), value)
273
-
274
- def load(self):
275
- logger.info("Initializing the service.")
276
- if isinstance(self.service_type, type) or callable(self.service_type):
277
- service = self.service_type(self.service_config) if self.service_config else self.service_type()
278
- elif self.service_config is None:
279
- logger.warning(
280
- "Be careful! It is not recommended to start the server from a service instance. "
281
- "Use service_type and service_config instead."
282
- )
283
- service = self.service_type
284
- else:
285
- raise TypeError(f"Invalid service type \"{type(self.service_type)}\".")
286
- logger.info("Service initialized.")
287
-
288
- return FlaskServer(service)
289
-
290
-
291
- def run_service(
292
- service_type: Union[Type, Callable],
293
- service_config=None,
294
- host: str = "0.0.0.0",
295
- port: int = 8888,
296
- num_workers: int = 1,
297
- num_threads: int = 20,
298
- num_connections: Optional[int] = 1000,
299
- backlog: Optional[int] = 1000,
300
- worker_class: str = "gthread",
301
- timeout: int = 60,
302
- keyfile: Optional[str] = None,
303
- keyfile_password: Optional[str] = None,
304
- certfile: Optional[str] = None
305
- ):
306
- logger.info("Starting gunicorn server.")
307
- if num_connections is None or num_connections < num_threads * 2:
308
- num_connections = num_threads * 2
309
- if backlog is None or backlog < num_threads * 2:
310
- backlog = num_threads * 2
311
-
312
- def ssl_context(config, default_ssl_context_factory):
313
- import ssl
314
- context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
315
- context.load_cert_chain(
316
- certfile=config.certfile,
317
- keyfile=config.keyfile,
318
- password=keyfile_password
319
- )
320
- context.minimum_version = ssl.TLSVersion.TLSv1_3
321
- return context
322
-
323
- options = {
324
- "bind": f"{host}:{port}",
325
- "workers": num_workers,
326
- "threads": num_threads,
327
- "timeout": timeout,
328
- "worker_connections": num_connections,
329
- "backlog": backlog,
330
- "keyfile": keyfile,
331
- "certfile": certfile,
332
- "worker_class": worker_class,
333
- "ssl_context": ssl_context
334
- }
335
- for name, value in options.items():
336
- logger.info(f"Option {name}: {value}")
337
- GunicornApplication(service_type, service_config, options).run()