agenta 0.32.0__py3-none-any.whl → 0.32.0a2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of agenta might be problematic. Click here for more details.
- agenta/__init__.py +3 -1
- agenta/client/backend/client.py +22 -14
- agenta/client/backend/core/http_client.py +3 -3
- agenta/sdk/__init__.py +1 -1
- agenta/sdk/context/routing.py +1 -0
- agenta/sdk/decorators/routing.py +164 -476
- agenta/sdk/decorators/tracing.py +16 -4
- agenta/sdk/litellm/litellm.py +44 -8
- agenta/sdk/litellm/mockllm.py +27 -0
- agenta/sdk/litellm/mocks/__init__.py +32 -0
- agenta/sdk/managers/vault.py +16 -0
- agenta/sdk/middleware/auth.py +5 -1
- agenta/sdk/middleware/config.py +16 -7
- agenta/sdk/middleware/inline.py +38 -0
- agenta/sdk/middleware/mock.py +33 -0
- agenta/sdk/middleware/vault.py +6 -19
- agenta/sdk/tracing/exporters.py +0 -1
- agenta/sdk/tracing/inline.py +23 -29
- agenta/sdk/types.py +334 -4
- {agenta-0.32.0.dist-info → agenta-0.32.0a2.dist-info}/METADATA +1 -1
- {agenta-0.32.0.dist-info → agenta-0.32.0a2.dist-info}/RECORD +23 -18
- {agenta-0.32.0.dist-info → agenta-0.32.0a2.dist-info}/WHEEL +0 -0
- {agenta-0.32.0.dist-info → agenta-0.32.0a2.dist-info}/entry_points.txt +0 -0
agenta/sdk/decorators/routing.py
CHANGED
|
@@ -1,19 +1,20 @@
|
|
|
1
|
+
from sre_parse import NOT_LITERAL_UNI_IGNORE
|
|
1
2
|
from typing import Type, Any, Callable, Dict, Optional, Tuple, List
|
|
2
|
-
from inspect import signature, iscoroutinefunction, Signature, Parameter
|
|
3
|
+
from inspect import signature, iscoroutinefunction, Signature, Parameter
|
|
3
4
|
from functools import wraps
|
|
4
5
|
from traceback import format_exception
|
|
5
6
|
from asyncio import sleep
|
|
6
|
-
|
|
7
|
-
from tempfile import NamedTemporaryFile
|
|
8
|
-
from annotated_types import Ge, Le, Gt, Lt
|
|
7
|
+
from uuid import UUID
|
|
9
8
|
from pydantic import BaseModel, HttpUrl, ValidationError
|
|
10
9
|
|
|
11
|
-
from fastapi import Body, FastAPI,
|
|
10
|
+
from fastapi import Body, FastAPI, HTTPException, Request
|
|
12
11
|
|
|
13
|
-
from agenta.sdk.middleware.
|
|
14
|
-
from agenta.sdk.middleware.
|
|
15
|
-
from agenta.sdk.middleware.config import ConfigMiddleware
|
|
12
|
+
from agenta.sdk.middleware.mock import MockMiddleware
|
|
13
|
+
from agenta.sdk.middleware.inline import InlineMiddleware
|
|
16
14
|
from agenta.sdk.middleware.vault import VaultMiddleware
|
|
15
|
+
from agenta.sdk.middleware.config import ConfigMiddleware
|
|
16
|
+
from agenta.sdk.middleware.otel import OTelMiddleware
|
|
17
|
+
from agenta.sdk.middleware.auth import AuthMiddleware
|
|
17
18
|
from agenta.sdk.middleware.cors import CORSMiddleware
|
|
18
19
|
|
|
19
20
|
from agenta.sdk.context.routing import (
|
|
@@ -30,18 +31,9 @@ from agenta.sdk.utils.exceptions import suppress, display_exception
|
|
|
30
31
|
from agenta.sdk.utils.logging import log
|
|
31
32
|
from agenta.sdk.utils.helpers import get_current_version
|
|
32
33
|
from agenta.sdk.types import (
|
|
33
|
-
DictInput,
|
|
34
|
-
FloatParam,
|
|
35
|
-
InFile,
|
|
36
|
-
IntParam,
|
|
37
|
-
MultipleChoiceParam,
|
|
38
34
|
MultipleChoice,
|
|
39
|
-
GroupedMultipleChoiceParam,
|
|
40
|
-
TextParam,
|
|
41
|
-
MessagesInput,
|
|
42
|
-
FileInputURL,
|
|
43
35
|
BaseResponse,
|
|
44
|
-
|
|
36
|
+
MCField,
|
|
45
37
|
)
|
|
46
38
|
|
|
47
39
|
import agenta as ag
|
|
@@ -124,6 +116,7 @@ class entrypoint:
|
|
|
124
116
|
_middleware = False
|
|
125
117
|
_run_path = "/run"
|
|
126
118
|
_test_path = "/test"
|
|
119
|
+
_config_key = "ag_config"
|
|
127
120
|
# LEGACY
|
|
128
121
|
_legacy_playground_run_path = "/playground/run"
|
|
129
122
|
_legacy_generate_path = "/generate"
|
|
@@ -140,13 +133,13 @@ class entrypoint:
|
|
|
140
133
|
self.config_schema = config_schema
|
|
141
134
|
|
|
142
135
|
signature_parameters = signature(func).parameters
|
|
143
|
-
ingestible_files = self.extract_ingestible_files()
|
|
144
136
|
config, default_parameters = self.parse_config()
|
|
145
137
|
|
|
146
138
|
### --- Middleware --- #
|
|
147
139
|
if not entrypoint._middleware:
|
|
148
140
|
entrypoint._middleware = True
|
|
149
|
-
|
|
141
|
+
app.add_middleware(MockMiddleware)
|
|
142
|
+
app.add_middleware(InlineMiddleware)
|
|
150
143
|
app.add_middleware(VaultMiddleware)
|
|
151
144
|
app.add_middleware(ConfigMiddleware)
|
|
152
145
|
app.add_middleware(AuthMiddleware)
|
|
@@ -167,65 +160,84 @@ class entrypoint:
|
|
|
167
160
|
}
|
|
168
161
|
# LEGACY
|
|
169
162
|
|
|
170
|
-
kwargs, _ = self.
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
|
|
163
|
+
kwargs, _ = self.process_kwargs(kwargs, default_parameters)
|
|
164
|
+
if (
|
|
165
|
+
request.state.config["parameters"] is None
|
|
166
|
+
or request.state.config["references"] is None
|
|
167
|
+
):
|
|
168
|
+
raise HTTPException(
|
|
169
|
+
status_code=400,
|
|
170
|
+
detail="Config not found based on provided references.",
|
|
171
|
+
)
|
|
174
172
|
|
|
175
|
-
return await self.execute_wrapper(request,
|
|
173
|
+
return await self.execute_wrapper(request, *args, **kwargs)
|
|
176
174
|
|
|
177
|
-
self.update_run_wrapper_signature(
|
|
178
|
-
wrapper=run_wrapper,
|
|
179
|
-
ingestible_files=ingestible_files,
|
|
180
|
-
)
|
|
175
|
+
self.update_run_wrapper_signature(wrapper=run_wrapper)
|
|
181
176
|
|
|
182
177
|
run_route = f"{entrypoint._run_path}{route_path}"
|
|
183
|
-
app.post(
|
|
178
|
+
app.post(
|
|
179
|
+
run_route,
|
|
180
|
+
response_model=BaseResponse,
|
|
181
|
+
response_model_exclude_none=True,
|
|
182
|
+
)(run_wrapper)
|
|
184
183
|
|
|
185
184
|
# LEGACY
|
|
186
185
|
# TODO: Removing this implies breaking changes in :
|
|
187
186
|
# - calls to /generate_deployed must be replaced with calls to /run
|
|
188
187
|
if route_path == "":
|
|
189
188
|
run_route = entrypoint._legacy_generate_deployed_path
|
|
190
|
-
app.post(
|
|
189
|
+
app.post(
|
|
190
|
+
run_route,
|
|
191
|
+
response_model=BaseResponse,
|
|
192
|
+
response_model_exclude_none=True,
|
|
193
|
+
)(run_wrapper)
|
|
191
194
|
# LEGACY
|
|
192
195
|
### ----------- #
|
|
193
196
|
|
|
194
197
|
### --- Test --- #
|
|
195
198
|
@wraps(func)
|
|
196
199
|
async def test_wrapper(request: Request, *args, **kwargs) -> Any:
|
|
197
|
-
kwargs,
|
|
198
|
-
|
|
199
|
-
request.state.config["parameters"] =
|
|
200
|
-
|
|
201
|
-
|
|
202
|
-
|
|
203
|
-
|
|
204
|
-
|
|
205
|
-
|
|
206
|
-
|
|
207
|
-
|
|
208
|
-
|
|
209
|
-
config_class=config,
|
|
210
|
-
config_dict=default_parameters,
|
|
211
|
-
)
|
|
200
|
+
kwargs, config = self.process_kwargs(kwargs, default_parameters)
|
|
201
|
+
request.state.inline = True
|
|
202
|
+
request.state.config["parameters"] = config
|
|
203
|
+
if request.state.config["references"]:
|
|
204
|
+
request.state.config["references"] = {
|
|
205
|
+
k: v
|
|
206
|
+
for k, v in request.state.config["references"].items()
|
|
207
|
+
if k.startswith("application")
|
|
208
|
+
} or None
|
|
209
|
+
return await self.execute_wrapper(request, *args, **kwargs)
|
|
210
|
+
|
|
211
|
+
self.update_test_wrapper_signature(wrapper=test_wrapper, config_instance=config)
|
|
212
212
|
|
|
213
213
|
test_route = f"{entrypoint._test_path}{route_path}"
|
|
214
|
-
app.post(
|
|
214
|
+
app.post(
|
|
215
|
+
test_route,
|
|
216
|
+
response_model=BaseResponse,
|
|
217
|
+
response_model_exclude_none=True,
|
|
218
|
+
)(test_wrapper)
|
|
215
219
|
|
|
216
220
|
# LEGACY
|
|
217
221
|
# TODO: Removing this implies breaking changes in :
|
|
218
222
|
# - calls to /generate must be replaced with calls to /test
|
|
219
223
|
if route_path == "":
|
|
220
224
|
test_route = entrypoint._legacy_generate_path
|
|
221
|
-
app.post(
|
|
225
|
+
app.post(
|
|
226
|
+
test_route,
|
|
227
|
+
response_model=BaseResponse,
|
|
228
|
+
response_model_exclude_none=True,
|
|
229
|
+
)(test_wrapper)
|
|
222
230
|
# LEGACY
|
|
223
231
|
|
|
224
232
|
# LEGACY
|
|
225
233
|
# TODO: Removing this implies no breaking changes
|
|
226
234
|
if route_path == "":
|
|
227
235
|
test_route = entrypoint._legacy_playground_run_path
|
|
228
|
-
app.post(
|
|
236
|
+
app.post(
|
|
237
|
+
test_route,
|
|
238
|
+
response_model=BaseResponse,
|
|
239
|
+
response_model_exclude_none=True,
|
|
240
|
+
)(test_wrapper)
|
|
229
241
|
# LEGACY
|
|
230
242
|
### ------------ #
|
|
231
243
|
|
|
@@ -235,11 +247,7 @@ class entrypoint:
|
|
|
235
247
|
{
|
|
236
248
|
"func": func.__name__,
|
|
237
249
|
"endpoint": test_route,
|
|
238
|
-
"params":
|
|
239
|
-
{**default_parameters, **signature_parameters}
|
|
240
|
-
if not config
|
|
241
|
-
else signature_parameters
|
|
242
|
-
),
|
|
250
|
+
"params": signature_parameters,
|
|
243
251
|
"config": config,
|
|
244
252
|
}
|
|
245
253
|
)
|
|
@@ -263,18 +271,9 @@ class entrypoint:
|
|
|
263
271
|
|
|
264
272
|
app.openapi_schema = None # Forces FastAPI to re-generate the schema
|
|
265
273
|
openapi_schema = app.openapi()
|
|
266
|
-
|
|
267
274
|
openapi_schema["agenta_sdk"] = {"version": get_current_version()}
|
|
268
|
-
|
|
269
275
|
for _route in entrypoint.routes:
|
|
270
|
-
|
|
271
|
-
openapi_schema=openapi_schema,
|
|
272
|
-
func_name=_route["func"],
|
|
273
|
-
endpoint=_route["endpoint"],
|
|
274
|
-
params=_route["params"],
|
|
275
|
-
)
|
|
276
|
-
|
|
277
|
-
if _route["config"] is not None: # new SDK version
|
|
276
|
+
if _route["config"] is not None:
|
|
278
277
|
self.override_config_in_schema(
|
|
279
278
|
openapi_schema=openapi_schema,
|
|
280
279
|
func_name=_route["func"],
|
|
@@ -283,23 +282,15 @@ class entrypoint:
|
|
|
283
282
|
)
|
|
284
283
|
### --------------- #
|
|
285
284
|
|
|
286
|
-
def
|
|
287
|
-
"""
|
|
288
|
-
|
|
289
|
-
return {
|
|
290
|
-
name: param
|
|
291
|
-
for name, param in signature(self.func).parameters.items()
|
|
292
|
-
if param.annotation is InFile
|
|
293
|
-
}
|
|
294
|
-
|
|
295
|
-
def parse_config(self) -> Dict[str, Any]:
|
|
285
|
+
def parse_config(self) -> Tuple[Optional[Type[BaseModel]], Dict[str, Any]]:
|
|
286
|
+
"""Parse the config schema and return the config class and default parameters."""
|
|
296
287
|
config = None
|
|
297
|
-
default_parameters =
|
|
288
|
+
default_parameters = {}
|
|
298
289
|
|
|
299
290
|
if self.config_schema:
|
|
300
291
|
try:
|
|
301
292
|
config = self.config_schema() if self.config_schema else None
|
|
302
|
-
default_parameters = config.dict() if config else
|
|
293
|
+
default_parameters = config.dict() if config else {}
|
|
303
294
|
except ValidationError as e:
|
|
304
295
|
raise ValueError(
|
|
305
296
|
f"Error initializing config_schema. Please ensure all required fields have default values: {str(e)}"
|
|
@@ -311,39 +302,22 @@ class entrypoint:
|
|
|
311
302
|
|
|
312
303
|
return config, default_parameters
|
|
313
304
|
|
|
314
|
-
def
|
|
305
|
+
def process_kwargs(
|
|
315
306
|
self, kwargs: Dict[str, Any], default_parameters: Dict[str, Any]
|
|
316
307
|
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
|
|
317
|
-
|
|
318
|
-
|
|
319
|
-
|
|
320
|
-
|
|
321
|
-
|
|
322
|
-
|
|
323
|
-
|
|
324
|
-
upfile: UploadFile,
|
|
325
|
-
):
|
|
326
|
-
temp_file = NamedTemporaryFile(delete=False)
|
|
327
|
-
temp_file.write(upfile.file.read())
|
|
328
|
-
temp_file.close()
|
|
329
|
-
|
|
330
|
-
return InFile(file_name=upfile.filename, file_path=temp_file.name)
|
|
308
|
+
"""Remove the config parameters from the kwargs."""
|
|
309
|
+
# Extract agenta_config if present
|
|
310
|
+
config_params = kwargs.pop(self._config_key, {})
|
|
311
|
+
if isinstance(config_params, BaseModel):
|
|
312
|
+
config_params = config_params.dict()
|
|
313
|
+
# Merge with default parameters
|
|
314
|
+
config = {**default_parameters, **config_params}
|
|
331
315
|
|
|
332
|
-
|
|
333
|
-
self,
|
|
334
|
-
func_params: Dict[str, Any],
|
|
335
|
-
ingestible_files: Dict[str, Parameter],
|
|
336
|
-
) -> None:
|
|
337
|
-
"""Ingest files specified in function parameters."""
|
|
338
|
-
|
|
339
|
-
for name in ingestible_files:
|
|
340
|
-
if name in func_params and func_params[name] is not None:
|
|
341
|
-
func_params[name] = self.ingest_file(func_params[name])
|
|
316
|
+
return kwargs, config
|
|
342
317
|
|
|
343
318
|
async def execute_wrapper(
|
|
344
319
|
self,
|
|
345
320
|
request: Request,
|
|
346
|
-
inline: bool,
|
|
347
321
|
*args,
|
|
348
322
|
**kwargs,
|
|
349
323
|
):
|
|
@@ -355,11 +329,14 @@ class entrypoint:
|
|
|
355
329
|
parameters = state.config.get("parameters")
|
|
356
330
|
references = state.config.get("references")
|
|
357
331
|
secrets = state.vault.get("secrets")
|
|
332
|
+
inline = state.inline
|
|
333
|
+
mock = state.mock
|
|
358
334
|
|
|
359
335
|
with routing_context_manager(
|
|
360
336
|
context=RoutingContext(
|
|
361
337
|
parameters=parameters,
|
|
362
338
|
secrets=secrets,
|
|
339
|
+
mock=mock,
|
|
363
340
|
)
|
|
364
341
|
):
|
|
365
342
|
with tracing_context_manager(
|
|
@@ -369,27 +346,17 @@ class entrypoint:
|
|
|
369
346
|
references=references,
|
|
370
347
|
)
|
|
371
348
|
):
|
|
372
|
-
|
|
373
|
-
|
|
374
|
-
|
|
375
|
-
|
|
376
|
-
|
|
377
|
-
|
|
378
|
-
inline: bool,
|
|
379
|
-
*args,
|
|
380
|
-
**kwargs,
|
|
381
|
-
):
|
|
382
|
-
try:
|
|
383
|
-
result = (
|
|
384
|
-
await self.func(*args, **kwargs)
|
|
385
|
-
if iscoroutinefunction(self.func)
|
|
386
|
-
else self.func(*args, **kwargs)
|
|
387
|
-
)
|
|
349
|
+
try:
|
|
350
|
+
result = (
|
|
351
|
+
await self.func(*args, **kwargs)
|
|
352
|
+
if iscoroutinefunction(self.func)
|
|
353
|
+
else self.func(*args, **kwargs)
|
|
354
|
+
)
|
|
388
355
|
|
|
389
|
-
|
|
356
|
+
return await self.handle_success(result, inline)
|
|
390
357
|
|
|
391
|
-
|
|
392
|
-
|
|
358
|
+
except Exception as error: # pylint: disable=broad-except
|
|
359
|
+
self.handle_failure(error)
|
|
393
360
|
|
|
394
361
|
async def handle_success(
|
|
395
362
|
self,
|
|
@@ -398,17 +365,23 @@ class entrypoint:
|
|
|
398
365
|
):
|
|
399
366
|
data = None
|
|
400
367
|
tree = None
|
|
368
|
+
content_type = "text/plain"
|
|
369
|
+
tree_id = None
|
|
401
370
|
|
|
402
371
|
with suppress():
|
|
372
|
+
if isinstance(result, (dict, list)):
|
|
373
|
+
content_type = "application/json"
|
|
403
374
|
data = self.patch_result(result)
|
|
404
375
|
|
|
405
376
|
if inline:
|
|
406
|
-
tree = await self.fetch_inline_trace(inline)
|
|
377
|
+
tree, tree_id = await self.fetch_inline_trace(inline)
|
|
407
378
|
|
|
408
379
|
try:
|
|
409
|
-
return BaseResponse(
|
|
380
|
+
return BaseResponse(
|
|
381
|
+
data=data, tree=tree, content_type=content_type, tree_id=tree_id
|
|
382
|
+
)
|
|
410
383
|
except:
|
|
411
|
-
return BaseResponse(data=data)
|
|
384
|
+
return BaseResponse(data=data, content_type=content_type)
|
|
412
385
|
|
|
413
386
|
def handle_failure(
|
|
414
387
|
self,
|
|
@@ -416,12 +389,15 @@ class entrypoint:
|
|
|
416
389
|
):
|
|
417
390
|
display_exception("Application Exception")
|
|
418
391
|
|
|
419
|
-
status_code =
|
|
420
|
-
|
|
392
|
+
status_code = (
|
|
393
|
+
getattr(error, "status_code") if hasattr(error, "status_code") else 500
|
|
394
|
+
)
|
|
421
395
|
stacktrace = format_exception(error, value=error, tb=error.__traceback__) # type: ignore
|
|
422
|
-
detail = {"message": message, "stacktrace": stacktrace}
|
|
423
396
|
|
|
424
|
-
raise HTTPException(
|
|
397
|
+
raise HTTPException(
|
|
398
|
+
status_code=status_code,
|
|
399
|
+
detail={"message": str(error), "stacktrace": stacktrace},
|
|
400
|
+
)
|
|
425
401
|
|
|
426
402
|
def patch_result(
|
|
427
403
|
self,
|
|
@@ -465,42 +441,33 @@ class entrypoint:
|
|
|
465
441
|
|
|
466
442
|
async def fetch_inline_trace(
|
|
467
443
|
self,
|
|
468
|
-
inline,
|
|
444
|
+
inline: bool,
|
|
469
445
|
):
|
|
470
|
-
WAIT_FOR_SPANS = True
|
|
471
446
|
TIMEOUT = 1
|
|
472
447
|
TIMESTEP = 0.1
|
|
473
|
-
FINALSTEP = 0.001
|
|
474
448
|
NOFSTEPS = TIMEOUT / TIMESTEP
|
|
475
449
|
|
|
476
|
-
trace = None
|
|
477
|
-
|
|
478
450
|
context = tracing_context.get()
|
|
479
451
|
|
|
480
452
|
link = context.link
|
|
481
453
|
|
|
482
|
-
|
|
454
|
+
tree = None
|
|
455
|
+
_tree_id = link.get("tree_id") if link else None # in int format
|
|
456
|
+
tree_id = str(UUID(int=_tree_id)) if _tree_id else None # in uuid_as_str format
|
|
483
457
|
|
|
484
|
-
if
|
|
458
|
+
if _tree_id is not None:
|
|
485
459
|
if inline:
|
|
486
|
-
|
|
487
|
-
|
|
488
|
-
|
|
489
|
-
|
|
490
|
-
|
|
491
|
-
|
|
492
|
-
):
|
|
493
|
-
await sleep(TIMESTEP)
|
|
494
|
-
|
|
495
|
-
remaining_steps -= 1
|
|
496
|
-
|
|
497
|
-
await sleep(FINALSTEP)
|
|
460
|
+
remaining_steps = NOFSTEPS
|
|
461
|
+
while (
|
|
462
|
+
not ag.tracing.is_inline_trace_ready(_tree_id)
|
|
463
|
+
and remaining_steps > 0
|
|
464
|
+
):
|
|
465
|
+
await sleep(TIMESTEP)
|
|
498
466
|
|
|
499
|
-
|
|
500
|
-
else:
|
|
501
|
-
trace = {"trace_id": trace_id}
|
|
467
|
+
remaining_steps -= 1
|
|
502
468
|
|
|
503
|
-
|
|
469
|
+
tree = ag.tracing.get_inline_trace(_tree_id)
|
|
470
|
+
return tree, tree_id
|
|
504
471
|
|
|
505
472
|
# --- OpenAPI --- #
|
|
506
473
|
|
|
@@ -542,74 +509,56 @@ class entrypoint:
|
|
|
542
509
|
def update_test_wrapper_signature(
|
|
543
510
|
self,
|
|
544
511
|
wrapper: Callable[..., Any],
|
|
545
|
-
|
|
546
|
-
config_dict: Dict[str, Any],
|
|
547
|
-
ingestible_files: Dict[str, Parameter],
|
|
512
|
+
config_instance: Type[BaseModel], # TODO: change to our type
|
|
548
513
|
) -> None:
|
|
549
514
|
"""Update the function signature to include new parameters."""
|
|
550
515
|
|
|
551
516
|
updated_params: List[Parameter] = []
|
|
552
|
-
|
|
553
|
-
|
|
554
|
-
else:
|
|
555
|
-
self.deprecated_add_config_params_to_parser(updated_params, config_dict)
|
|
556
|
-
self.add_func_params_to_parser(updated_params, ingestible_files)
|
|
517
|
+
self.add_config_params_to_parser(updated_params, config_instance)
|
|
518
|
+
self.add_func_params_to_parser(updated_params)
|
|
557
519
|
self.update_wrapper_signature(wrapper, updated_params)
|
|
558
520
|
self.add_request_to_signature(wrapper)
|
|
559
521
|
|
|
560
522
|
def update_run_wrapper_signature(
|
|
561
523
|
self,
|
|
562
524
|
wrapper: Callable[..., Any],
|
|
563
|
-
ingestible_files: Dict[str, Parameter],
|
|
564
525
|
) -> None:
|
|
565
526
|
"""Update the function signature to include new parameters."""
|
|
566
527
|
|
|
567
528
|
updated_params: List[Parameter] = []
|
|
568
|
-
self.add_func_params_to_parser(updated_params
|
|
569
|
-
for param in [
|
|
570
|
-
"config",
|
|
571
|
-
"environment",
|
|
572
|
-
]: # we add the config and environment parameters
|
|
573
|
-
updated_params.append(
|
|
574
|
-
Parameter(
|
|
575
|
-
name=param,
|
|
576
|
-
kind=Parameter.KEYWORD_ONLY,
|
|
577
|
-
default=Body(None),
|
|
578
|
-
annotation=str,
|
|
579
|
-
)
|
|
580
|
-
)
|
|
529
|
+
self.add_func_params_to_parser(updated_params)
|
|
581
530
|
self.update_wrapper_signature(wrapper, updated_params)
|
|
582
531
|
self.add_request_to_signature(wrapper)
|
|
583
532
|
|
|
584
533
|
def add_config_params_to_parser(
|
|
585
|
-
self, updated_params: list,
|
|
534
|
+
self, updated_params: list, config_instance: Type[BaseModel]
|
|
586
535
|
) -> None:
|
|
587
536
|
"""Add configuration parameters to function signature."""
|
|
588
|
-
|
|
537
|
+
|
|
538
|
+
for name, field in config_instance.model_fields.items():
|
|
589
539
|
assert field.default is not None, f"Field {name} has no default value"
|
|
590
|
-
|
|
591
|
-
|
|
592
|
-
|
|
593
|
-
|
|
594
|
-
|
|
595
|
-
|
|
596
|
-
)
|
|
540
|
+
|
|
541
|
+
updated_params.append(
|
|
542
|
+
Parameter(
|
|
543
|
+
name=self._config_key,
|
|
544
|
+
kind=Parameter.KEYWORD_ONLY,
|
|
545
|
+
annotation=type(config_instance), # Get the actual class type
|
|
546
|
+
default=Body(config_instance), # Use the instance directly
|
|
597
547
|
)
|
|
548
|
+
)
|
|
598
549
|
|
|
599
|
-
def
|
|
600
|
-
|
|
601
|
-
|
|
602
|
-
"""Add configuration parameters to function signature."""
|
|
603
|
-
for name, param in config_dict.items():
|
|
550
|
+
def add_func_params_to_parser(self, updated_params: list) -> None:
|
|
551
|
+
"""Add function parameters to function signature."""
|
|
552
|
+
for name, param in signature(self.func).parameters.items():
|
|
604
553
|
assert (
|
|
605
|
-
len(param.__class__.__bases__) == 1
|
|
606
|
-
), f"Inherited standard type of {param.__class__} needs to be one."
|
|
554
|
+
len(param.default.__class__.__bases__) == 1
|
|
555
|
+
), f"Inherited standard type of {param.default.__class__} needs to be one."
|
|
607
556
|
updated_params.append(
|
|
608
557
|
Parameter(
|
|
609
|
-
name
|
|
610
|
-
|
|
611
|
-
default=Body(
|
|
612
|
-
annotation=param.__class__.__bases__[
|
|
558
|
+
name,
|
|
559
|
+
Parameter.KEYWORD_ONLY,
|
|
560
|
+
default=Body(..., embed=True),
|
|
561
|
+
annotation=param.default.__class__.__bases__[
|
|
613
562
|
0
|
|
614
563
|
], # determines and get the base (parent/inheritance) type of the sdk-type at run-time. \
|
|
615
564
|
# E.g __class__ is ag.MessagesInput() and accessing it parent type will return (<class 'list'>,), \
|
|
@@ -617,34 +566,6 @@ class entrypoint:
|
|
|
617
566
|
)
|
|
618
567
|
)
|
|
619
568
|
|
|
620
|
-
def add_func_params_to_parser(
|
|
621
|
-
self,
|
|
622
|
-
updated_params: list,
|
|
623
|
-
ingestible_files: Dict[str, Parameter],
|
|
624
|
-
) -> None:
|
|
625
|
-
"""Add function parameters to function signature."""
|
|
626
|
-
for name, param in signature(self.func).parameters.items():
|
|
627
|
-
if name in ingestible_files:
|
|
628
|
-
updated_params.append(
|
|
629
|
-
Parameter(name, param.kind, annotation=UploadFile)
|
|
630
|
-
)
|
|
631
|
-
else:
|
|
632
|
-
assert (
|
|
633
|
-
len(param.default.__class__.__bases__) == 1
|
|
634
|
-
), f"Inherited standard type of {param.default.__class__} needs to be one."
|
|
635
|
-
updated_params.append(
|
|
636
|
-
Parameter(
|
|
637
|
-
name,
|
|
638
|
-
Parameter.KEYWORD_ONLY,
|
|
639
|
-
default=Body(..., embed=True),
|
|
640
|
-
annotation=param.default.__class__.__bases__[
|
|
641
|
-
0
|
|
642
|
-
], # determines and get the base (parent/inheritance) type of the sdk-type at run-time. \
|
|
643
|
-
# E.g __class__ is ag.MessagesInput() and accessing it parent type will return (<class 'list'>,), \
|
|
644
|
-
# thus, why we are accessing the first item.
|
|
645
|
-
)
|
|
646
|
-
)
|
|
647
|
-
|
|
648
569
|
def override_config_in_schema(
|
|
649
570
|
self,
|
|
650
571
|
openapi_schema: dict,
|
|
@@ -652,259 +573,26 @@ class entrypoint:
|
|
|
652
573
|
endpoint: str,
|
|
653
574
|
config: Type[BaseModel],
|
|
654
575
|
):
|
|
576
|
+
"""Override config in OpenAPI schema to add agenta-specific metadata."""
|
|
655
577
|
endpoint = endpoint[1:].replace("/", "_")
|
|
656
|
-
|
|
657
|
-
|
|
658
|
-
|
|
659
|
-
#
|
|
660
|
-
|
|
661
|
-
|
|
662
|
-
|
|
663
|
-
|
|
664
|
-
|
|
665
|
-
|
|
666
|
-
|
|
667
|
-
|
|
668
|
-
|
|
669
|
-
if isinstance(
|
|
670
|
-
|
|
671
|
-
|
|
672
|
-
|
|
673
|
-
|
|
674
|
-
|
|
675
|
-
|
|
676
|
-
|
|
677
|
-
else:
|
|
678
|
-
schema_to_override[param_name]["x-parameter"] = "text"
|
|
679
|
-
if param_val.annotation is bool:
|
|
680
|
-
schema_to_override[param_name]["x-parameter"] = "bool"
|
|
681
|
-
if param_val.annotation in (int, float):
|
|
682
|
-
schema_to_override[param_name]["x-parameter"] = (
|
|
683
|
-
"int" if param_val.annotation is int else "float"
|
|
684
|
-
)
|
|
685
|
-
# Check for greater than or equal to constraint
|
|
686
|
-
if any(isinstance(constraint, Ge) for constraint in param_val.metadata):
|
|
687
|
-
min_value = next(
|
|
688
|
-
constraint.ge
|
|
689
|
-
for constraint in param_val.metadata
|
|
690
|
-
if isinstance(constraint, Ge)
|
|
691
|
-
)
|
|
692
|
-
schema_to_override[param_name]["minimum"] = min_value
|
|
693
|
-
# Check for greater than constraint
|
|
694
|
-
elif any(
|
|
695
|
-
isinstance(constraint, Gt) for constraint in param_val.metadata
|
|
696
|
-
):
|
|
697
|
-
min_value = next(
|
|
698
|
-
constraint.gt
|
|
699
|
-
for constraint in param_val.metadata
|
|
700
|
-
if isinstance(constraint, Gt)
|
|
701
|
-
)
|
|
702
|
-
schema_to_override[param_name]["exclusiveMinimum"] = min_value
|
|
703
|
-
# Check for less than or equal to constraint
|
|
704
|
-
if any(isinstance(constraint, Le) for constraint in param_val.metadata):
|
|
705
|
-
max_value = next(
|
|
706
|
-
constraint.le
|
|
707
|
-
for constraint in param_val.metadata
|
|
708
|
-
if isinstance(constraint, Le)
|
|
709
|
-
)
|
|
710
|
-
schema_to_override[param_name]["maximum"] = max_value
|
|
711
|
-
# Check for less than constraint
|
|
712
|
-
elif any(
|
|
713
|
-
isinstance(constraint, Lt) for constraint in param_val.metadata
|
|
714
|
-
):
|
|
715
|
-
max_value = next(
|
|
716
|
-
constraint.lt
|
|
717
|
-
for constraint in param_val.metadata
|
|
718
|
-
if isinstance(constraint, Lt)
|
|
719
|
-
)
|
|
720
|
-
schema_to_override[param_name]["exclusiveMaximum"] = max_value
|
|
721
|
-
|
|
722
|
-
def override_schema(
|
|
723
|
-
self, openapi_schema: dict, func_name: str, endpoint: str, params: dict
|
|
724
|
-
):
|
|
725
|
-
"""
|
|
726
|
-
Overrides the default openai schema generated by fastapi with additional information about:
|
|
727
|
-
- The choices available for each MultipleChoiceParam instance
|
|
728
|
-
- The min and max values for each FloatParam instance
|
|
729
|
-
- The min and max values for each IntParam instance
|
|
730
|
-
- The default value for DictInput instance
|
|
731
|
-
- The default value for MessagesParam instance
|
|
732
|
-
- The default value for FileInputURL instance
|
|
733
|
-
- The default value for BinaryParam instance
|
|
734
|
-
- ... [PLEASE ADD AT EACH CHANGE]
|
|
735
|
-
|
|
736
|
-
Args:
|
|
737
|
-
openapi_schema (dict): The openapi schema generated by fastapi
|
|
738
|
-
func (str): The name of the function to override
|
|
739
|
-
endpoint (str): The name of the endpoint to override
|
|
740
|
-
params (dict(param_name, param_val)): The dictionary of the parameters for the function
|
|
741
|
-
"""
|
|
742
|
-
|
|
743
|
-
def find_in_schema(
|
|
744
|
-
schema_type_properties: dict, schema: dict, param_name: str, xparam: str
|
|
745
|
-
):
|
|
746
|
-
"""Finds a parameter in the schema based on its name and x-parameter value"""
|
|
747
|
-
for _, value in schema.items():
|
|
748
|
-
value_title_lower = str(value.get("title")).lower()
|
|
749
|
-
value_title = (
|
|
750
|
-
"_".join(value_title_lower.split())
|
|
751
|
-
if len(value_title_lower.split()) >= 2
|
|
752
|
-
else value_title_lower
|
|
753
|
-
)
|
|
754
|
-
|
|
755
|
-
if (
|
|
756
|
-
isinstance(value, dict)
|
|
757
|
-
and schema_type_properties.get("x-parameter") == xparam
|
|
758
|
-
and value_title == param_name
|
|
759
|
-
):
|
|
760
|
-
# this will update the default type schema with the properties gotten
|
|
761
|
-
# from the schema type (param_val) __schema_properties__ classmethod
|
|
762
|
-
for type_key, type_value in schema_type_properties.items():
|
|
763
|
-
# BEFORE:
|
|
764
|
-
# value = {'temperature': {'title': 'Temperature'}}
|
|
765
|
-
value[type_key] = type_value
|
|
766
|
-
# AFTER:
|
|
767
|
-
# value = {'temperature': { "type": "number", "title": "Temperature", "x-parameter": "float" }}
|
|
768
|
-
return value
|
|
769
|
-
|
|
770
|
-
def get_type_from_param(param_val):
|
|
771
|
-
param_type = "string"
|
|
772
|
-
annotation = param_val.annotation
|
|
773
|
-
|
|
774
|
-
if annotation == int:
|
|
775
|
-
param_type = "integer"
|
|
776
|
-
elif annotation == float:
|
|
777
|
-
param_type = "number"
|
|
778
|
-
elif annotation == dict:
|
|
779
|
-
param_type = "object"
|
|
780
|
-
elif annotation == bool:
|
|
781
|
-
param_type = "boolean"
|
|
782
|
-
elif annotation == list:
|
|
783
|
-
param_type = "list"
|
|
784
|
-
elif annotation == str:
|
|
785
|
-
param_type = "string"
|
|
786
|
-
else:
|
|
787
|
-
print("ERROR, unhandled annotation:", annotation)
|
|
788
|
-
|
|
789
|
-
return param_type
|
|
790
|
-
|
|
791
|
-
# Goes from '/some/path' to 'some_path'
|
|
792
|
-
endpoint = endpoint[1:].replace("/", "_")
|
|
793
|
-
|
|
794
|
-
schema_to_override = openapi_schema["components"]["schemas"][
|
|
795
|
-
f"Body_{func_name}_{endpoint}_post"
|
|
796
|
-
]["properties"]
|
|
797
|
-
|
|
798
|
-
for param_name, param_val in params.items():
|
|
799
|
-
if isinstance(param_val, GroupedMultipleChoiceParam):
|
|
800
|
-
subschema = find_in_schema(
|
|
801
|
-
param_val.__schema_type_properties__(),
|
|
802
|
-
schema_to_override,
|
|
803
|
-
param_name,
|
|
804
|
-
"grouped_choice",
|
|
805
|
-
)
|
|
806
|
-
assert (
|
|
807
|
-
subschema
|
|
808
|
-
), f"GroupedMultipleChoiceParam '{param_name}' is in the parameters but could not be found in the openapi.json"
|
|
809
|
-
subschema["choices"] = param_val.choices # type: ignore
|
|
810
|
-
subschema["default"] = param_val.default # type: ignore
|
|
811
|
-
|
|
812
|
-
elif isinstance(param_val, MultipleChoiceParam):
|
|
813
|
-
subschema = find_in_schema(
|
|
814
|
-
param_val.__schema_type_properties__(),
|
|
815
|
-
schema_to_override,
|
|
816
|
-
param_name,
|
|
817
|
-
"choice",
|
|
818
|
-
)
|
|
819
|
-
default = str(param_val)
|
|
820
|
-
param_choices = param_val.choices # type: ignore
|
|
821
|
-
choices = (
|
|
822
|
-
[default] + param_choices
|
|
823
|
-
if param_val not in param_choices
|
|
824
|
-
else param_choices
|
|
825
|
-
)
|
|
826
|
-
subschema["enum"] = choices
|
|
827
|
-
subschema["default"] = (
|
|
828
|
-
default if default in param_choices else choices[0]
|
|
829
|
-
)
|
|
830
|
-
|
|
831
|
-
elif isinstance(param_val, FloatParam):
|
|
832
|
-
subschema = find_in_schema(
|
|
833
|
-
param_val.__schema_type_properties__(),
|
|
834
|
-
schema_to_override,
|
|
835
|
-
param_name,
|
|
836
|
-
"float",
|
|
837
|
-
)
|
|
838
|
-
subschema["minimum"] = param_val.minval # type: ignore
|
|
839
|
-
subschema["maximum"] = param_val.maxval # type: ignore
|
|
840
|
-
subschema["default"] = param_val
|
|
841
|
-
|
|
842
|
-
elif isinstance(param_val, IntParam):
|
|
843
|
-
subschema = find_in_schema(
|
|
844
|
-
param_val.__schema_type_properties__(),
|
|
845
|
-
schema_to_override,
|
|
846
|
-
param_name,
|
|
847
|
-
"int",
|
|
848
|
-
)
|
|
849
|
-
subschema["minimum"] = param_val.minval # type: ignore
|
|
850
|
-
subschema["maximum"] = param_val.maxval # type: ignore
|
|
851
|
-
subschema["default"] = param_val
|
|
852
|
-
|
|
853
|
-
elif isinstance(param_val, Parameter) and param_val.annotation is DictInput:
|
|
854
|
-
subschema = find_in_schema(
|
|
855
|
-
param_val.annotation.__schema_type_properties__(),
|
|
856
|
-
schema_to_override,
|
|
857
|
-
param_name,
|
|
858
|
-
"dict",
|
|
859
|
-
)
|
|
860
|
-
subschema["default"] = param_val.default["default_keys"]
|
|
861
|
-
|
|
862
|
-
elif isinstance(param_val, TextParam):
|
|
863
|
-
subschema = find_in_schema(
|
|
864
|
-
param_val.__schema_type_properties__(),
|
|
865
|
-
schema_to_override,
|
|
866
|
-
param_name,
|
|
867
|
-
"text",
|
|
868
|
-
)
|
|
869
|
-
subschema["default"] = param_val
|
|
870
|
-
|
|
871
|
-
elif (
|
|
872
|
-
isinstance(param_val, Parameter)
|
|
873
|
-
and param_val.annotation is MessagesInput
|
|
874
|
-
):
|
|
875
|
-
subschema = find_in_schema(
|
|
876
|
-
param_val.annotation.__schema_type_properties__(),
|
|
877
|
-
schema_to_override,
|
|
878
|
-
param_name,
|
|
879
|
-
"messages",
|
|
880
|
-
)
|
|
881
|
-
subschema["default"] = param_val.default
|
|
882
|
-
|
|
883
|
-
elif (
|
|
884
|
-
isinstance(param_val, Parameter)
|
|
885
|
-
and param_val.annotation is FileInputURL
|
|
886
|
-
):
|
|
887
|
-
subschema = find_in_schema(
|
|
888
|
-
param_val.annotation.__schema_type_properties__(),
|
|
889
|
-
schema_to_override,
|
|
890
|
-
param_name,
|
|
891
|
-
"file_url",
|
|
892
|
-
)
|
|
893
|
-
subschema["default"] = "https://example.com"
|
|
894
|
-
|
|
895
|
-
elif isinstance(param_val, BinaryParam):
|
|
896
|
-
subschema = find_in_schema(
|
|
897
|
-
param_val.__schema_type_properties__(),
|
|
898
|
-
schema_to_override,
|
|
899
|
-
param_name,
|
|
900
|
-
"bool",
|
|
901
|
-
)
|
|
902
|
-
subschema["default"] = param_val.default # type: ignore
|
|
903
|
-
else:
|
|
904
|
-
subschema = {
|
|
905
|
-
"title": str(param_name).capitalize(),
|
|
906
|
-
"type": get_type_from_param(param_val),
|
|
907
|
-
}
|
|
908
|
-
if param_val.default != _empty:
|
|
909
|
-
subschema["default"] = param_val.default # type: ignore
|
|
910
|
-
schema_to_override[param_name] = subschema
|
|
578
|
+
schema_key = f"Body_{func_name}_{endpoint}_post"
|
|
579
|
+
schema_to_override = openapi_schema["components"]["schemas"][schema_key]
|
|
580
|
+
|
|
581
|
+
# Get the config class name to find its schema
|
|
582
|
+
config_class_name = type(config).__name__
|
|
583
|
+
config_schema = openapi_schema["components"]["schemas"][config_class_name]
|
|
584
|
+
# Process each field in the config class
|
|
585
|
+
for field_name, field in config.__class__.model_fields.items():
|
|
586
|
+
# Check if field has Annotated metadata for MultipleChoice
|
|
587
|
+
if hasattr(field, "metadata") and field.metadata:
|
|
588
|
+
for meta in field.metadata:
|
|
589
|
+
if isinstance(meta, MultipleChoice):
|
|
590
|
+
choices = meta.choices
|
|
591
|
+
if isinstance(choices, dict):
|
|
592
|
+
config_schema["properties"][field_name].update(
|
|
593
|
+
{"x-parameter": "grouped_choice", "choices": choices}
|
|
594
|
+
)
|
|
595
|
+
elif isinstance(choices, list):
|
|
596
|
+
config_schema["properties"][field_name].update(
|
|
597
|
+
{"x-parameter": "choice", "enum": choices}
|
|
598
|
+
)
|