agenta 0.32.0__py3-none-any.whl → 0.32.0a2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of agenta might be problematic. See the registry's advisory page for this release for more details.

@@ -1,19 +1,20 @@
1
+ from sre_parse import NOT_LITERAL_UNI_IGNORE
1
2
  from typing import Type, Any, Callable, Dict, Optional, Tuple, List
2
- from inspect import signature, iscoroutinefunction, Signature, Parameter, _empty
3
+ from inspect import signature, iscoroutinefunction, Signature, Parameter
3
4
  from functools import wraps
4
5
  from traceback import format_exception
5
6
  from asyncio import sleep
6
-
7
- from tempfile import NamedTemporaryFile
8
- from annotated_types import Ge, Le, Gt, Lt
7
+ from uuid import UUID
9
8
  from pydantic import BaseModel, HttpUrl, ValidationError
10
9
 
11
- from fastapi import Body, FastAPI, UploadFile, HTTPException, Request
10
+ from fastapi import Body, FastAPI, HTTPException, Request
12
11
 
13
- from agenta.sdk.middleware.auth import AuthMiddleware
14
- from agenta.sdk.middleware.otel import OTelMiddleware
15
- from agenta.sdk.middleware.config import ConfigMiddleware
12
+ from agenta.sdk.middleware.mock import MockMiddleware
13
+ from agenta.sdk.middleware.inline import InlineMiddleware
16
14
  from agenta.sdk.middleware.vault import VaultMiddleware
15
+ from agenta.sdk.middleware.config import ConfigMiddleware
16
+ from agenta.sdk.middleware.otel import OTelMiddleware
17
+ from agenta.sdk.middleware.auth import AuthMiddleware
17
18
  from agenta.sdk.middleware.cors import CORSMiddleware
18
19
 
19
20
  from agenta.sdk.context.routing import (
@@ -30,18 +31,9 @@ from agenta.sdk.utils.exceptions import suppress, display_exception
30
31
  from agenta.sdk.utils.logging import log
31
32
  from agenta.sdk.utils.helpers import get_current_version
32
33
  from agenta.sdk.types import (
33
- DictInput,
34
- FloatParam,
35
- InFile,
36
- IntParam,
37
- MultipleChoiceParam,
38
34
  MultipleChoice,
39
- GroupedMultipleChoiceParam,
40
- TextParam,
41
- MessagesInput,
42
- FileInputURL,
43
35
  BaseResponse,
44
- BinaryParam,
36
+ MCField,
45
37
  )
46
38
 
47
39
  import agenta as ag
@@ -124,6 +116,7 @@ class entrypoint:
124
116
  _middleware = False
125
117
  _run_path = "/run"
126
118
  _test_path = "/test"
119
+ _config_key = "ag_config"
127
120
  # LEGACY
128
121
  _legacy_playground_run_path = "/playground/run"
129
122
  _legacy_generate_path = "/generate"
@@ -140,13 +133,13 @@ class entrypoint:
140
133
  self.config_schema = config_schema
141
134
 
142
135
  signature_parameters = signature(func).parameters
143
- ingestible_files = self.extract_ingestible_files()
144
136
  config, default_parameters = self.parse_config()
145
137
 
146
138
  ### --- Middleware --- #
147
139
  if not entrypoint._middleware:
148
140
  entrypoint._middleware = True
149
-
141
+ app.add_middleware(MockMiddleware)
142
+ app.add_middleware(InlineMiddleware)
150
143
  app.add_middleware(VaultMiddleware)
151
144
  app.add_middleware(ConfigMiddleware)
152
145
  app.add_middleware(AuthMiddleware)
@@ -167,65 +160,84 @@ class entrypoint:
167
160
  }
168
161
  # LEGACY
169
162
 
170
- kwargs, _ = self.split_kwargs(kwargs, default_parameters)
171
-
172
- # TODO: Why is this not used in the run_wrapper?
173
- # self.ingest_files(kwargs, ingestible_files)
163
+ kwargs, _ = self.process_kwargs(kwargs, default_parameters)
164
+ if (
165
+ request.state.config["parameters"] is None
166
+ or request.state.config["references"] is None
167
+ ):
168
+ raise HTTPException(
169
+ status_code=400,
170
+ detail="Config not found based on provided references.",
171
+ )
174
172
 
175
- return await self.execute_wrapper(request, False, *args, **kwargs)
173
+ return await self.execute_wrapper(request, *args, **kwargs)
176
174
 
177
- self.update_run_wrapper_signature(
178
- wrapper=run_wrapper,
179
- ingestible_files=ingestible_files,
180
- )
175
+ self.update_run_wrapper_signature(wrapper=run_wrapper)
181
176
 
182
177
  run_route = f"{entrypoint._run_path}{route_path}"
183
- app.post(run_route, response_model=BaseResponse)(run_wrapper)
178
+ app.post(
179
+ run_route,
180
+ response_model=BaseResponse,
181
+ response_model_exclude_none=True,
182
+ )(run_wrapper)
184
183
 
185
184
  # LEGACY
186
185
  # TODO: Removing this implies breaking changes in :
187
186
  # - calls to /generate_deployed must be replaced with calls to /run
188
187
  if route_path == "":
189
188
  run_route = entrypoint._legacy_generate_deployed_path
190
- app.post(run_route, response_model=BaseResponse)(run_wrapper)
189
+ app.post(
190
+ run_route,
191
+ response_model=BaseResponse,
192
+ response_model_exclude_none=True,
193
+ )(run_wrapper)
191
194
  # LEGACY
192
195
  ### ----------- #
193
196
 
194
197
  ### --- Test --- #
195
198
  @wraps(func)
196
199
  async def test_wrapper(request: Request, *args, **kwargs) -> Any:
197
- kwargs, parameters = self.split_kwargs(kwargs, default_parameters)
198
-
199
- request.state.config["parameters"] = parameters
200
-
201
- # TODO: Why is this only used in the test_wrapper?
202
- self.ingest_files(kwargs, ingestible_files)
203
-
204
- return await self.execute_wrapper(request, True, *args, **kwargs)
205
-
206
- self.update_test_wrapper_signature(
207
- wrapper=test_wrapper,
208
- ingestible_files=ingestible_files,
209
- config_class=config,
210
- config_dict=default_parameters,
211
- )
200
+ kwargs, config = self.process_kwargs(kwargs, default_parameters)
201
+ request.state.inline = True
202
+ request.state.config["parameters"] = config
203
+ if request.state.config["references"]:
204
+ request.state.config["references"] = {
205
+ k: v
206
+ for k, v in request.state.config["references"].items()
207
+ if k.startswith("application")
208
+ } or None
209
+ return await self.execute_wrapper(request, *args, **kwargs)
210
+
211
+ self.update_test_wrapper_signature(wrapper=test_wrapper, config_instance=config)
212
212
 
213
213
  test_route = f"{entrypoint._test_path}{route_path}"
214
- app.post(test_route, response_model=BaseResponse)(test_wrapper)
214
+ app.post(
215
+ test_route,
216
+ response_model=BaseResponse,
217
+ response_model_exclude_none=True,
218
+ )(test_wrapper)
215
219
 
216
220
  # LEGACY
217
221
  # TODO: Removing this implies breaking changes in :
218
222
  # - calls to /generate must be replaced with calls to /test
219
223
  if route_path == "":
220
224
  test_route = entrypoint._legacy_generate_path
221
- app.post(test_route, response_model=BaseResponse)(test_wrapper)
225
+ app.post(
226
+ test_route,
227
+ response_model=BaseResponse,
228
+ response_model_exclude_none=True,
229
+ )(test_wrapper)
222
230
  # LEGACY
223
231
 
224
232
  # LEGACY
225
233
  # TODO: Removing this implies no breaking changes
226
234
  if route_path == "":
227
235
  test_route = entrypoint._legacy_playground_run_path
228
- app.post(test_route, response_model=BaseResponse)(test_wrapper)
236
+ app.post(
237
+ test_route,
238
+ response_model=BaseResponse,
239
+ response_model_exclude_none=True,
240
+ )(test_wrapper)
229
241
  # LEGACY
230
242
  ### ------------ #
231
243
 
@@ -235,11 +247,7 @@ class entrypoint:
235
247
  {
236
248
  "func": func.__name__,
237
249
  "endpoint": test_route,
238
- "params": (
239
- {**default_parameters, **signature_parameters}
240
- if not config
241
- else signature_parameters
242
- ),
250
+ "params": signature_parameters,
243
251
  "config": config,
244
252
  }
245
253
  )
@@ -263,18 +271,9 @@ class entrypoint:
263
271
 
264
272
  app.openapi_schema = None # Forces FastAPI to re-generate the schema
265
273
  openapi_schema = app.openapi()
266
-
267
274
  openapi_schema["agenta_sdk"] = {"version": get_current_version()}
268
-
269
275
  for _route in entrypoint.routes:
270
- self.override_schema(
271
- openapi_schema=openapi_schema,
272
- func_name=_route["func"],
273
- endpoint=_route["endpoint"],
274
- params=_route["params"],
275
- )
276
-
277
- if _route["config"] is not None: # new SDK version
276
+ if _route["config"] is not None:
278
277
  self.override_config_in_schema(
279
278
  openapi_schema=openapi_schema,
280
279
  func_name=_route["func"],
@@ -283,23 +282,15 @@ class entrypoint:
283
282
  )
284
283
  ### --------------- #
285
284
 
286
- def extract_ingestible_files(self) -> Dict[str, Parameter]:
287
- """Extract parameters annotated as InFile from function signature."""
288
-
289
- return {
290
- name: param
291
- for name, param in signature(self.func).parameters.items()
292
- if param.annotation is InFile
293
- }
294
-
295
- def parse_config(self) -> Dict[str, Any]:
285
+ def parse_config(self) -> Tuple[Optional[Type[BaseModel]], Dict[str, Any]]:
286
+ """Parse the config schema and return the config class and default parameters."""
296
287
  config = None
297
- default_parameters = ag.config.all()
288
+ default_parameters = {}
298
289
 
299
290
  if self.config_schema:
300
291
  try:
301
292
  config = self.config_schema() if self.config_schema else None
302
- default_parameters = config.dict() if config else default_parameters
293
+ default_parameters = config.dict() if config else {}
303
294
  except ValidationError as e:
304
295
  raise ValueError(
305
296
  f"Error initializing config_schema. Please ensure all required fields have default values: {str(e)}"
@@ -311,39 +302,22 @@ class entrypoint:
311
302
 
312
303
  return config, default_parameters
313
304
 
314
- def split_kwargs(
305
+ def process_kwargs(
315
306
  self, kwargs: Dict[str, Any], default_parameters: Dict[str, Any]
316
307
  ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
317
- arguments = {k: v for k, v in kwargs.items() if k not in default_parameters}
318
- parameters = {k: v for k, v in kwargs.items() if k in default_parameters}
319
-
320
- return arguments, parameters
321
-
322
- def ingest_file(
323
- self,
324
- upfile: UploadFile,
325
- ):
326
- temp_file = NamedTemporaryFile(delete=False)
327
- temp_file.write(upfile.file.read())
328
- temp_file.close()
329
-
330
- return InFile(file_name=upfile.filename, file_path=temp_file.name)
308
+ """Remove the config parameters from the kwargs."""
309
+ # Extract agenta_config if present
310
+ config_params = kwargs.pop(self._config_key, {})
311
+ if isinstance(config_params, BaseModel):
312
+ config_params = config_params.dict()
313
+ # Merge with default parameters
314
+ config = {**default_parameters, **config_params}
331
315
 
332
- def ingest_files(
333
- self,
334
- func_params: Dict[str, Any],
335
- ingestible_files: Dict[str, Parameter],
336
- ) -> None:
337
- """Ingest files specified in function parameters."""
338
-
339
- for name in ingestible_files:
340
- if name in func_params and func_params[name] is not None:
341
- func_params[name] = self.ingest_file(func_params[name])
316
+ return kwargs, config
342
317
 
343
318
  async def execute_wrapper(
344
319
  self,
345
320
  request: Request,
346
- inline: bool,
347
321
  *args,
348
322
  **kwargs,
349
323
  ):
@@ -355,11 +329,14 @@ class entrypoint:
355
329
  parameters = state.config.get("parameters")
356
330
  references = state.config.get("references")
357
331
  secrets = state.vault.get("secrets")
332
+ inline = state.inline
333
+ mock = state.mock
358
334
 
359
335
  with routing_context_manager(
360
336
  context=RoutingContext(
361
337
  parameters=parameters,
362
338
  secrets=secrets,
339
+ mock=mock,
363
340
  )
364
341
  ):
365
342
  with tracing_context_manager(
@@ -369,27 +346,17 @@ class entrypoint:
369
346
  references=references,
370
347
  )
371
348
  ):
372
- result = await self.execute_function(inline, *args, **kwargs)
373
-
374
- return result
375
-
376
- async def execute_function(
377
- self,
378
- inline: bool,
379
- *args,
380
- **kwargs,
381
- ):
382
- try:
383
- result = (
384
- await self.func(*args, **kwargs)
385
- if iscoroutinefunction(self.func)
386
- else self.func(*args, **kwargs)
387
- )
349
+ try:
350
+ result = (
351
+ await self.func(*args, **kwargs)
352
+ if iscoroutinefunction(self.func)
353
+ else self.func(*args, **kwargs)
354
+ )
388
355
 
389
- return await self.handle_success(result, inline)
356
+ return await self.handle_success(result, inline)
390
357
 
391
- except Exception as error: # pylint: disable=broad-except
392
- self.handle_failure(error)
358
+ except Exception as error: # pylint: disable=broad-except
359
+ self.handle_failure(error)
393
360
 
394
361
  async def handle_success(
395
362
  self,
@@ -398,17 +365,23 @@ class entrypoint:
398
365
  ):
399
366
  data = None
400
367
  tree = None
368
+ content_type = "text/plain"
369
+ tree_id = None
401
370
 
402
371
  with suppress():
372
+ if isinstance(result, (dict, list)):
373
+ content_type = "application/json"
403
374
  data = self.patch_result(result)
404
375
 
405
376
  if inline:
406
- tree = await self.fetch_inline_trace(inline)
377
+ tree, tree_id = await self.fetch_inline_trace(inline)
407
378
 
408
379
  try:
409
- return BaseResponse(data=data, tree=tree)
380
+ return BaseResponse(
381
+ data=data, tree=tree, content_type=content_type, tree_id=tree_id
382
+ )
410
383
  except:
411
- return BaseResponse(data=data)
384
+ return BaseResponse(data=data, content_type=content_type)
412
385
 
413
386
  def handle_failure(
414
387
  self,
@@ -416,12 +389,15 @@ class entrypoint:
416
389
  ):
417
390
  display_exception("Application Exception")
418
391
 
419
- status_code = 500
420
- message = str(error)
392
+ status_code = (
393
+ getattr(error, "status_code") if hasattr(error, "status_code") else 500
394
+ )
421
395
  stacktrace = format_exception(error, value=error, tb=error.__traceback__) # type: ignore
422
- detail = {"message": message, "stacktrace": stacktrace}
423
396
 
424
- raise HTTPException(status_code=status_code, detail=detail)
397
+ raise HTTPException(
398
+ status_code=status_code,
399
+ detail={"message": str(error), "stacktrace": stacktrace},
400
+ )
425
401
 
426
402
  def patch_result(
427
403
  self,
@@ -465,42 +441,33 @@ class entrypoint:
465
441
 
466
442
  async def fetch_inline_trace(
467
443
  self,
468
- inline,
444
+ inline: bool,
469
445
  ):
470
- WAIT_FOR_SPANS = True
471
446
  TIMEOUT = 1
472
447
  TIMESTEP = 0.1
473
- FINALSTEP = 0.001
474
448
  NOFSTEPS = TIMEOUT / TIMESTEP
475
449
 
476
- trace = None
477
-
478
450
  context = tracing_context.get()
479
451
 
480
452
  link = context.link
481
453
 
482
- trace_id = link.get("tree_id") if link else None
454
+ tree = None
455
+ _tree_id = link.get("tree_id") if link else None # in int format
456
+ tree_id = str(UUID(int=_tree_id)) if _tree_id else None # in uuid_as_str format
483
457
 
484
- if trace_id is not None:
458
+ if _tree_id is not None:
485
459
  if inline:
486
- if WAIT_FOR_SPANS:
487
- remaining_steps = NOFSTEPS
488
-
489
- while (
490
- not ag.tracing.is_inline_trace_ready(trace_id)
491
- and remaining_steps > 0
492
- ):
493
- await sleep(TIMESTEP)
494
-
495
- remaining_steps -= 1
496
-
497
- await sleep(FINALSTEP)
460
+ remaining_steps = NOFSTEPS
461
+ while (
462
+ not ag.tracing.is_inline_trace_ready(_tree_id)
463
+ and remaining_steps > 0
464
+ ):
465
+ await sleep(TIMESTEP)
498
466
 
499
- trace = ag.tracing.get_inline_trace(trace_id)
500
- else:
501
- trace = {"trace_id": trace_id}
467
+ remaining_steps -= 1
502
468
 
503
- return trace
469
+ tree = ag.tracing.get_inline_trace(_tree_id)
470
+ return tree, tree_id
504
471
 
505
472
  # --- OpenAPI --- #
506
473
 
@@ -542,74 +509,56 @@ class entrypoint:
542
509
  def update_test_wrapper_signature(
543
510
  self,
544
511
  wrapper: Callable[..., Any],
545
- config_class: Type[BaseModel], # TODO: change to our type
546
- config_dict: Dict[str, Any],
547
- ingestible_files: Dict[str, Parameter],
512
+ config_instance: Type[BaseModel], # TODO: change to our type
548
513
  ) -> None:
549
514
  """Update the function signature to include new parameters."""
550
515
 
551
516
  updated_params: List[Parameter] = []
552
- if config_class:
553
- self.add_config_params_to_parser(updated_params, config_class)
554
- else:
555
- self.deprecated_add_config_params_to_parser(updated_params, config_dict)
556
- self.add_func_params_to_parser(updated_params, ingestible_files)
517
+ self.add_config_params_to_parser(updated_params, config_instance)
518
+ self.add_func_params_to_parser(updated_params)
557
519
  self.update_wrapper_signature(wrapper, updated_params)
558
520
  self.add_request_to_signature(wrapper)
559
521
 
560
522
  def update_run_wrapper_signature(
561
523
  self,
562
524
  wrapper: Callable[..., Any],
563
- ingestible_files: Dict[str, Parameter],
564
525
  ) -> None:
565
526
  """Update the function signature to include new parameters."""
566
527
 
567
528
  updated_params: List[Parameter] = []
568
- self.add_func_params_to_parser(updated_params, ingestible_files)
569
- for param in [
570
- "config",
571
- "environment",
572
- ]: # we add the config and environment parameters
573
- updated_params.append(
574
- Parameter(
575
- name=param,
576
- kind=Parameter.KEYWORD_ONLY,
577
- default=Body(None),
578
- annotation=str,
579
- )
580
- )
529
+ self.add_func_params_to_parser(updated_params)
581
530
  self.update_wrapper_signature(wrapper, updated_params)
582
531
  self.add_request_to_signature(wrapper)
583
532
 
584
533
  def add_config_params_to_parser(
585
- self, updated_params: list, config_class: Type[BaseModel]
534
+ self, updated_params: list, config_instance: Type[BaseModel]
586
535
  ) -> None:
587
536
  """Add configuration parameters to function signature."""
588
- for name, field in config_class.__fields__.items():
537
+
538
+ for name, field in config_instance.model_fields.items():
589
539
  assert field.default is not None, f"Field {name} has no default value"
590
- updated_params.append(
591
- Parameter(
592
- name=name,
593
- kind=Parameter.KEYWORD_ONLY,
594
- annotation=field.annotation.__name__,
595
- default=Body(field.default),
596
- )
540
+
541
+ updated_params.append(
542
+ Parameter(
543
+ name=self._config_key,
544
+ kind=Parameter.KEYWORD_ONLY,
545
+ annotation=type(config_instance), # Get the actual class type
546
+ default=Body(config_instance), # Use the instance directly
597
547
  )
548
+ )
598
549
 
599
- def deprecated_add_config_params_to_parser(
600
- self, updated_params: list, config_dict: Dict[str, Any]
601
- ) -> None:
602
- """Add configuration parameters to function signature."""
603
- for name, param in config_dict.items():
550
+ def add_func_params_to_parser(self, updated_params: list) -> None:
551
+ """Add function parameters to function signature."""
552
+ for name, param in signature(self.func).parameters.items():
604
553
  assert (
605
- len(param.__class__.__bases__) == 1
606
- ), f"Inherited standard type of {param.__class__} needs to be one."
554
+ len(param.default.__class__.__bases__) == 1
555
+ ), f"Inherited standard type of {param.default.__class__} needs to be one."
607
556
  updated_params.append(
608
557
  Parameter(
609
- name=name,
610
- kind=Parameter.KEYWORD_ONLY,
611
- default=Body(param),
612
- annotation=param.__class__.__bases__[
558
+ name,
559
+ Parameter.KEYWORD_ONLY,
560
+ default=Body(..., embed=True),
561
+ annotation=param.default.__class__.__bases__[
613
562
  0
614
563
  ], # determines and get the base (parent/inheritance) type of the sdk-type at run-time. \
615
564
  # E.g __class__ is ag.MessagesInput() and accessing it parent type will return (<class 'list'>,), \
@@ -617,34 +566,6 @@ class entrypoint:
617
566
  )
618
567
  )
619
568
 
620
- def add_func_params_to_parser(
621
- self,
622
- updated_params: list,
623
- ingestible_files: Dict[str, Parameter],
624
- ) -> None:
625
- """Add function parameters to function signature."""
626
- for name, param in signature(self.func).parameters.items():
627
- if name in ingestible_files:
628
- updated_params.append(
629
- Parameter(name, param.kind, annotation=UploadFile)
630
- )
631
- else:
632
- assert (
633
- len(param.default.__class__.__bases__) == 1
634
- ), f"Inherited standard type of {param.default.__class__} needs to be one."
635
- updated_params.append(
636
- Parameter(
637
- name,
638
- Parameter.KEYWORD_ONLY,
639
- default=Body(..., embed=True),
640
- annotation=param.default.__class__.__bases__[
641
- 0
642
- ], # determines and get the base (parent/inheritance) type of the sdk-type at run-time. \
643
- # E.g __class__ is ag.MessagesInput() and accessing it parent type will return (<class 'list'>,), \
644
- # thus, why we are accessing the first item.
645
- )
646
- )
647
-
648
569
  def override_config_in_schema(
649
570
  self,
650
571
  openapi_schema: dict,
@@ -652,259 +573,26 @@ class entrypoint:
652
573
  endpoint: str,
653
574
  config: Type[BaseModel],
654
575
  ):
576
+ """Override config in OpenAPI schema to add agenta-specific metadata."""
655
577
  endpoint = endpoint[1:].replace("/", "_")
656
- schema_to_override = openapi_schema["components"]["schemas"][
657
- f"Body_{func_name}_{endpoint}_post"
658
- ]["properties"]
659
- # New logic
660
- for param_name, param_val in config.__fields__.items():
661
- if param_val.annotation is str:
662
- if any(
663
- isinstance(constraint, MultipleChoice)
664
- for constraint in param_val.metadata
665
- ):
666
- choices = next(
667
- constraint.choices
668
- for constraint in param_val.metadata
669
- if isinstance(constraint, MultipleChoice)
670
- )
671
- if isinstance(choices, dict):
672
- schema_to_override[param_name]["x-parameter"] = "grouped_choice"
673
- schema_to_override[param_name]["choices"] = choices
674
- elif isinstance(choices, list):
675
- schema_to_override[param_name]["x-parameter"] = "choice"
676
- schema_to_override[param_name]["enum"] = choices
677
- else:
678
- schema_to_override[param_name]["x-parameter"] = "text"
679
- if param_val.annotation is bool:
680
- schema_to_override[param_name]["x-parameter"] = "bool"
681
- if param_val.annotation in (int, float):
682
- schema_to_override[param_name]["x-parameter"] = (
683
- "int" if param_val.annotation is int else "float"
684
- )
685
- # Check for greater than or equal to constraint
686
- if any(isinstance(constraint, Ge) for constraint in param_val.metadata):
687
- min_value = next(
688
- constraint.ge
689
- for constraint in param_val.metadata
690
- if isinstance(constraint, Ge)
691
- )
692
- schema_to_override[param_name]["minimum"] = min_value
693
- # Check for greater than constraint
694
- elif any(
695
- isinstance(constraint, Gt) for constraint in param_val.metadata
696
- ):
697
- min_value = next(
698
- constraint.gt
699
- for constraint in param_val.metadata
700
- if isinstance(constraint, Gt)
701
- )
702
- schema_to_override[param_name]["exclusiveMinimum"] = min_value
703
- # Check for less than or equal to constraint
704
- if any(isinstance(constraint, Le) for constraint in param_val.metadata):
705
- max_value = next(
706
- constraint.le
707
- for constraint in param_val.metadata
708
- if isinstance(constraint, Le)
709
- )
710
- schema_to_override[param_name]["maximum"] = max_value
711
- # Check for less than constraint
712
- elif any(
713
- isinstance(constraint, Lt) for constraint in param_val.metadata
714
- ):
715
- max_value = next(
716
- constraint.lt
717
- for constraint in param_val.metadata
718
- if isinstance(constraint, Lt)
719
- )
720
- schema_to_override[param_name]["exclusiveMaximum"] = max_value
721
-
722
- def override_schema(
723
- self, openapi_schema: dict, func_name: str, endpoint: str, params: dict
724
- ):
725
- """
726
- Overrides the default openai schema generated by fastapi with additional information about:
727
- - The choices available for each MultipleChoiceParam instance
728
- - The min and max values for each FloatParam instance
729
- - The min and max values for each IntParam instance
730
- - The default value for DictInput instance
731
- - The default value for MessagesParam instance
732
- - The default value for FileInputURL instance
733
- - The default value for BinaryParam instance
734
- - ... [PLEASE ADD AT EACH CHANGE]
735
-
736
- Args:
737
- openapi_schema (dict): The openapi schema generated by fastapi
738
- func (str): The name of the function to override
739
- endpoint (str): The name of the endpoint to override
740
- params (dict(param_name, param_val)): The dictionary of the parameters for the function
741
- """
742
-
743
- def find_in_schema(
744
- schema_type_properties: dict, schema: dict, param_name: str, xparam: str
745
- ):
746
- """Finds a parameter in the schema based on its name and x-parameter value"""
747
- for _, value in schema.items():
748
- value_title_lower = str(value.get("title")).lower()
749
- value_title = (
750
- "_".join(value_title_lower.split())
751
- if len(value_title_lower.split()) >= 2
752
- else value_title_lower
753
- )
754
-
755
- if (
756
- isinstance(value, dict)
757
- and schema_type_properties.get("x-parameter") == xparam
758
- and value_title == param_name
759
- ):
760
- # this will update the default type schema with the properties gotten
761
- # from the schema type (param_val) __schema_properties__ classmethod
762
- for type_key, type_value in schema_type_properties.items():
763
- # BEFORE:
764
- # value = {'temperature': {'title': 'Temperature'}}
765
- value[type_key] = type_value
766
- # AFTER:
767
- # value = {'temperature': { "type": "number", "title": "Temperature", "x-parameter": "float" }}
768
- return value
769
-
770
- def get_type_from_param(param_val):
771
- param_type = "string"
772
- annotation = param_val.annotation
773
-
774
- if annotation == int:
775
- param_type = "integer"
776
- elif annotation == float:
777
- param_type = "number"
778
- elif annotation == dict:
779
- param_type = "object"
780
- elif annotation == bool:
781
- param_type = "boolean"
782
- elif annotation == list:
783
- param_type = "list"
784
- elif annotation == str:
785
- param_type = "string"
786
- else:
787
- print("ERROR, unhandled annotation:", annotation)
788
-
789
- return param_type
790
-
791
- # Goes from '/some/path' to 'some_path'
792
- endpoint = endpoint[1:].replace("/", "_")
793
-
794
- schema_to_override = openapi_schema["components"]["schemas"][
795
- f"Body_{func_name}_{endpoint}_post"
796
- ]["properties"]
797
-
798
- for param_name, param_val in params.items():
799
- if isinstance(param_val, GroupedMultipleChoiceParam):
800
- subschema = find_in_schema(
801
- param_val.__schema_type_properties__(),
802
- schema_to_override,
803
- param_name,
804
- "grouped_choice",
805
- )
806
- assert (
807
- subschema
808
- ), f"GroupedMultipleChoiceParam '{param_name}' is in the parameters but could not be found in the openapi.json"
809
- subschema["choices"] = param_val.choices # type: ignore
810
- subschema["default"] = param_val.default # type: ignore
811
-
812
- elif isinstance(param_val, MultipleChoiceParam):
813
- subschema = find_in_schema(
814
- param_val.__schema_type_properties__(),
815
- schema_to_override,
816
- param_name,
817
- "choice",
818
- )
819
- default = str(param_val)
820
- param_choices = param_val.choices # type: ignore
821
- choices = (
822
- [default] + param_choices
823
- if param_val not in param_choices
824
- else param_choices
825
- )
826
- subschema["enum"] = choices
827
- subschema["default"] = (
828
- default if default in param_choices else choices[0]
829
- )
830
-
831
- elif isinstance(param_val, FloatParam):
832
- subschema = find_in_schema(
833
- param_val.__schema_type_properties__(),
834
- schema_to_override,
835
- param_name,
836
- "float",
837
- )
838
- subschema["minimum"] = param_val.minval # type: ignore
839
- subschema["maximum"] = param_val.maxval # type: ignore
840
- subschema["default"] = param_val
841
-
842
- elif isinstance(param_val, IntParam):
843
- subschema = find_in_schema(
844
- param_val.__schema_type_properties__(),
845
- schema_to_override,
846
- param_name,
847
- "int",
848
- )
849
- subschema["minimum"] = param_val.minval # type: ignore
850
- subschema["maximum"] = param_val.maxval # type: ignore
851
- subschema["default"] = param_val
852
-
853
- elif isinstance(param_val, Parameter) and param_val.annotation is DictInput:
854
- subschema = find_in_schema(
855
- param_val.annotation.__schema_type_properties__(),
856
- schema_to_override,
857
- param_name,
858
- "dict",
859
- )
860
- subschema["default"] = param_val.default["default_keys"]
861
-
862
- elif isinstance(param_val, TextParam):
863
- subschema = find_in_schema(
864
- param_val.__schema_type_properties__(),
865
- schema_to_override,
866
- param_name,
867
- "text",
868
- )
869
- subschema["default"] = param_val
870
-
871
- elif (
872
- isinstance(param_val, Parameter)
873
- and param_val.annotation is MessagesInput
874
- ):
875
- subschema = find_in_schema(
876
- param_val.annotation.__schema_type_properties__(),
877
- schema_to_override,
878
- param_name,
879
- "messages",
880
- )
881
- subschema["default"] = param_val.default
882
-
883
- elif (
884
- isinstance(param_val, Parameter)
885
- and param_val.annotation is FileInputURL
886
- ):
887
- subschema = find_in_schema(
888
- param_val.annotation.__schema_type_properties__(),
889
- schema_to_override,
890
- param_name,
891
- "file_url",
892
- )
893
- subschema["default"] = "https://example.com"
894
-
895
- elif isinstance(param_val, BinaryParam):
896
- subschema = find_in_schema(
897
- param_val.__schema_type_properties__(),
898
- schema_to_override,
899
- param_name,
900
- "bool",
901
- )
902
- subschema["default"] = param_val.default # type: ignore
903
- else:
904
- subschema = {
905
- "title": str(param_name).capitalize(),
906
- "type": get_type_from_param(param_val),
907
- }
908
- if param_val.default != _empty:
909
- subschema["default"] = param_val.default # type: ignore
910
- schema_to_override[param_name] = subschema
578
+ schema_key = f"Body_{func_name}_{endpoint}_post"
579
+ schema_to_override = openapi_schema["components"]["schemas"][schema_key]
580
+
581
+ # Get the config class name to find its schema
582
+ config_class_name = type(config).__name__
583
+ config_schema = openapi_schema["components"]["schemas"][config_class_name]
584
+ # Process each field in the config class
585
+ for field_name, field in config.__class__.model_fields.items():
586
+ # Check if field has Annotated metadata for MultipleChoice
587
+ if hasattr(field, "metadata") and field.metadata:
588
+ for meta in field.metadata:
589
+ if isinstance(meta, MultipleChoice):
590
+ choices = meta.choices
591
+ if isinstance(choices, dict):
592
+ config_schema["properties"][field_name].update(
593
+ {"x-parameter": "grouped_choice", "choices": choices}
594
+ )
595
+ elif isinstance(choices, list):
596
+ config_schema["properties"][field_name].update(
597
+ {"x-parameter": "choice", "enum": choices}
598
+ )