agenta 0.14.7a0__py3-none-any.whl → 0.14.7a1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of agenta might be problematic; see the package's page on the registry for more details.

@@ -1,493 +0,0 @@
1
- """The code for the Agenta SDK"""
2
-
3
- import os
4
- import sys
5
- import time
6
- import inspect
7
- import argparse
8
- import asyncio
9
- import traceback
10
- import functools
11
- from pathlib import Path
12
- from tempfile import NamedTemporaryFile
13
- from typing import Any, Callable, Dict, Optional, Tuple, List
14
-
15
- from fastapi.middleware.cors import CORSMiddleware
16
- from fastapi import Body, FastAPI, UploadFile, HTTPException
17
-
18
- import agenta
19
- from agenta.sdk.context import save_context
20
- from agenta.sdk.router import router as router
21
- from agenta.sdk.tracing.llm_tracing import Tracing
22
- from agenta.sdk.decorators.base import BaseDecorator
23
- from agenta.sdk.types import (
24
- Context,
25
- DictInput,
26
- FloatParam,
27
- InFile,
28
- IntParam,
29
- MultipleChoiceParam,
30
- GroupedMultipleChoiceParam,
31
- TextParam,
32
- MessagesInput,
33
- FileInputURL,
34
- FuncResponse,
35
- BinaryParam,
36
- )
37
-
38
# The FastAPI application on which `entrypoint` registers its POST routes.
app = FastAPI()

# Every origin is accepted for cross-origin requests.
# NOTE(review): a wildcard origin combined with allow_credentials=True is very
# permissive for credentialed cross-origin calls — confirm this is intended.
origins = ["*"]

app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_credentials=True,
    allow_origins=origins,
)

# Mount the SDK's shared router at the application root.
app.include_router(router, prefix="")
53
-
54
-
55
class entrypoint(BaseDecorator):
    """Decorator class to wrap a function for HTTP POST, terminal exposure and enable tracing.

    Decorating a function registers two POST routes on the module-level ``app``:

    - ``/generate``: configuration values may be sent in the request body.
    - ``/generate_deployed``: configuration is pulled by ``environment`` or
      ``config`` name (falling back to the config named ``"default"``).

    When the decorated function lives in the script being executed, it is also
    run once with arguments parsed from the command line.

    Args:
        BaseDecorator (object): base decorator class

    Example:
    ```python
    import agenta as ag

    @ag.entrypoint(enable_tracing=True) # Defaults to False
    async def chain_of_prompts_llm(prompt: str):
        return ...
    ```
    """

    def __init__(self, enable_tracing: bool = False):
        # When True, the user function is wrapped with ``trace()`` before registration.
        self.enable_tracing = enable_tracing

    def __call__(self, func: Callable[..., Any]):
        """Register ``func`` as the ``/generate`` and ``/generate_deployed`` routes.

        Args:
            func: The user function (sync or async) to expose.

        Returns:
            The ``FuncResponse`` of the terminal run when ``func`` lives in the
            script being executed; otherwise ``None``. The decorated name is
            therefore not a callable; the HTTP routes are the exposed interface.
        """
        endpoint_name = "generate"
        func_signature = inspect.signature(func)
        config_params = agenta.config.all()
        ingestible_files = self.extract_ingestible_files(func_signature)

        if self.enable_tracing:
            from .tracing import trace

            func = trace()(func)

        @functools.wraps(func)
        async def wrapper(*args, **kwargs) -> Any:
            # Body values whose names match a config parameter update the
            # config; all remaining values are forwarded to the user function.
            func_params, api_config_params = self.split_kwargs(kwargs, config_params)
            self.ingest_files(func_params, ingestible_files)
            agenta.config.set(**api_config_params)
            llm_result = await self.execute_function(
                func, *args, params=func_params, config_params=config_params
            )
            return llm_result

        @functools.wraps(func)
        async def wrapper_deployed(*args, **kwargs) -> Any:
            func_params = {
                k: v for k, v in kwargs.items() if k not in ["config", "environment"]
            }
            # Uploaded files must be converted to InFile here too, exactly as in
            # ``wrapper``; otherwise the user function receives a raw UploadFile.
            self.ingest_files(func_params, ingestible_files)

            # An explicit environment takes precedence over a named config.
            if "environment" in kwargs and kwargs["environment"] is not None:
                agenta.config.pull(environment_name=kwargs["environment"])
            elif "config" in kwargs and kwargs["config"] is not None:
                agenta.config.pull(config_name=kwargs["config"])
            else:
                agenta.config.pull(config_name="default")

            llm_result = await self.execute_function(
                func, *args, params=func_params, config_params=config_params
            )
            return llm_result

        self.update_function_signature(
            wrapper, func_signature, config_params, ingestible_files
        )
        route = f"/{endpoint_name}"
        app.post(route, response_model=FuncResponse)(wrapper)

        self.update_deployed_function_signature(
            wrapper_deployed,
            func_signature,
            ingestible_files,
        )
        route_deployed = f"/{endpoint_name}_deployed"
        app.post(route_deployed, response_model=FuncResponse)(wrapper_deployed)

        # Enrich the generated OpenAPI body schema of /generate with choices,
        # bounds and defaults derived from the parameter objects.
        self.override_schema(
            openapi_schema=app.openapi(),
            func_name=func.__name__,
            endpoint=endpoint_name,
            params={**config_params, **func_signature.parameters},
        )

        if self.is_main_script(func):
            result = self.handle_terminal_run(
                func,
                func_signature.parameters,  # type: ignore
                config_params,
                ingestible_files,
            )
            return result

    def extract_ingestible_files(
        self,
        func_signature: inspect.Signature,
    ) -> Dict[str, inspect.Parameter]:
        """Extract parameters annotated as InFile from function signature.

        Returns:
            A mapping of parameter name to ``inspect.Parameter`` for every
            parameter whose annotation is exactly ``InFile``.
        """

        return {
            name: param
            for name, param in func_signature.parameters.items()
            if param.annotation is InFile
        }

    def split_kwargs(
        self, kwargs: Dict[str, Any], config_params: Dict[str, Any]
    ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        """Split keyword arguments into function parameters and API configuration parameters.

        Returns:
            ``(func_params, api_config_params)``: the kwargs whose names are not
            keys of ``config_params``, and those whose names are.
        """

        func_params = {k: v for k, v in kwargs.items() if k not in config_params}
        api_config_params = {k: v for k, v in kwargs.items() if k in config_params}
        return func_params, api_config_params

    def ingest_file(self, upfile: UploadFile):
        """Copy an uploaded file to a named temporary file and wrap it as ``InFile``.

        The temporary file is created with ``delete=False`` so its path remains
        valid after this call; this method does not remove it.
        """
        import shutil

        # ``with`` closes the handle even if copying fails; copyfileobj streams
        # the upload in chunks instead of reading it fully into memory.
        with NamedTemporaryFile(delete=False) as temp_file:
            shutil.copyfileobj(upfile.file, temp_file)
        return InFile(file_name=upfile.filename, file_path=temp_file.name)

    def ingest_files(
        self,
        func_params: Dict[str, Any],
        ingestible_files: Dict[str, inspect.Parameter],
    ) -> None:
        """Ingest files specified in function parameters.

        Mutates ``func_params`` in place, replacing each non-``None`` uploaded
        file with the ``InFile`` returned by ``ingest_file``.
        """

        for name in ingestible_files:
            if name in func_params and func_params[name] is not None:
                func_params[name] = self.ingest_file(func_params[name])

    async def execute_function(self, func: Callable[..., Any], *args, **func_params):
        """Execute the function and handle any exceptions.

        Args:
            func: The user function; may be synchronous or a coroutine function.
            *args: Positional arguments forwarded to ``func``.
            **func_params: Must contain ``params`` (keyword arguments for
                ``func``). ``config_params`` is accepted but not used here.

        Returns:
            A ``FuncResponse`` built from a ``dict`` or ``str`` result, with the
            elapsed time in seconds rounded to 4 decimals as ``latency``. Any
            other result type (including ``Context``, which is saved first)
            yields ``None``.

        Raises:
            HTTPException: Via ``handle_exception`` when ``func`` (or building
                the response) raises.
        """

        try:
            # Backward compatibility: support both synchronous and asynchronous
            # user functions.
            is_coroutine_function = inspect.iscoroutinefunction(func)
            start_time = time.perf_counter()
            if is_coroutine_function:
                result = await func(*args, **func_params["params"])
            else:
                result = func(*args, **func_params["params"])

            end_time = time.perf_counter()
            latency = end_time - start_time

            if isinstance(result, Context):
                save_context(result)
            if isinstance(result, Dict):
                return FuncResponse(**result, latency=round(latency, 4))
            if isinstance(result, str):
                return FuncResponse(message=result, latency=round(latency, 4))  # type: ignore
        except Exception as e:
            self.handle_exception(e)
        # Only reachable if ``handle_exception`` is overridden to not raise.
        return FuncResponse(message="Unexpected error occurred", latency=0)  # type: ignore

    def handle_exception(self, e: Exception):
        """Convert ``e`` into an ``HTTPException`` and raise it.

        The status code is ``e.status_code`` when that attribute exists,
        otherwise 500. The detail carries the message and the formatted traceback.

        Raises:
            HTTPException: Always.
        """

        status_code: int = getattr(e, "status_code", 500)
        # The (type, value, tb) form works on every Python 3 version, unlike the
        # single-exception form that only exists since 3.10.
        traceback_str = traceback.format_exception(type(e), e, e.__traceback__)
        raise HTTPException(
            status_code=status_code,
            detail={"error": str(e), "traceback": "".join(traceback_str)},
        ) from e

    def update_wrapper_signature(
        self, wrapper: Callable[..., Any], updated_params: List
    ):
        """
        Updates the signature of a wrapper function with a new list of parameters.

        Args:
            wrapper (callable): A callable object, such as a function or a method, that requires a signature update.
            updated_params (List[inspect.Parameter]): A list of `inspect.Parameter` objects representing the updated parameters
                for the wrapper function.
        """

        wrapper_signature = inspect.signature(wrapper)
        wrapper_signature = wrapper_signature.replace(parameters=updated_params)
        wrapper.__signature__ = wrapper_signature

    def update_function_signature(
        self,
        wrapper: Callable[..., Any],
        func_signature: inspect.Signature,
        config_params: Dict[str, Any],
        ingestible_files: Dict[str, inspect.Parameter],
    ) -> None:
        """Give ``wrapper`` the config parameters followed by the function parameters."""

        updated_params = []
        self.add_config_params_to_parser(updated_params, config_params)
        self.add_func_params_to_parser(updated_params, func_signature, ingestible_files)
        self.update_wrapper_signature(wrapper, updated_params)

    def update_deployed_function_signature(
        self,
        wrapper: Callable[..., Any],
        func_signature: inspect.Signature,
        ingestible_files: Dict[str, inspect.Parameter],
    ) -> None:
        """Give ``wrapper`` the function parameters plus optional ``config`` and ``environment`` strings."""
        updated_params = []
        self.add_func_params_to_parser(updated_params, func_signature, ingestible_files)
        for param in [
            "config",
            "environment",
        ]:  # we add the config and environment parameters
            updated_params.append(
                inspect.Parameter(
                    param,
                    inspect.Parameter.KEYWORD_ONLY,
                    default=Body(None),
                    annotation=str,
                )
            )
        self.update_wrapper_signature(wrapper, updated_params)

    def add_config_params_to_parser(
        self, updated_params: list, config_params: Dict[str, Any]
    ) -> None:
        """Append one keyword-only body parameter per config value (default = current value)."""
        for name, param in config_params.items():
            updated_params.append(
                inspect.Parameter(
                    name,
                    inspect.Parameter.KEYWORD_ONLY,
                    default=Body(param),
                    annotation=Optional[type(param)],
                )
            )

    def add_func_params_to_parser(
        self,
        updated_params: list,
        func_signature: inspect.Signature,
        ingestible_files: Dict[str, inspect.Parameter],
    ) -> None:
        """Append the user function's parameters as FastAPI inputs.

        File parameters become ``UploadFile`` inputs; all others become required
        embedded body fields. Every parameter is made keyword-only. Keeping the
        original kind of a file parameter (typically POSITIONAL_OR_KEYWORD)
        would place it after keyword-only parameters, and ``Signature.replace``
        rejects that ordering with ``ValueError``.
        """
        for name, param in func_signature.parameters.items():
            if name in ingestible_files:
                updated_params.append(
                    inspect.Parameter(
                        name, inspect.Parameter.KEYWORD_ONLY, annotation=UploadFile
                    )
                )
            else:
                updated_params.append(
                    inspect.Parameter(
                        name,
                        inspect.Parameter.KEYWORD_ONLY,
                        default=Body(..., embed=True),
                        annotation=param.annotation,
                    )
                )

    def is_main_script(self, func: Callable) -> bool:
        """
        Check if the script containing the function is the main script being run.

        Decorator wrappers (e.g. the one added by ``trace()``, which uses
        ``functools.wraps``) are unwrapped first, so the check uses the file
        that defines the user's function rather than the wrapper's module.

        Args:
            func (Callable): The function object to check.

        Returns:
            bool: True if the script containing the function is the main script, False otherwise.

        Example:
            if is_main_script(my_function):
                print("This is the main script.")
        """
        func_file = inspect.getfile(inspect.unwrap(func))
        func_module = os.path.splitext(os.path.basename(func_file))[0]
        script_module = os.path.splitext(os.path.basename(sys.argv[0]))[0]
        return script_module == func_module

    def handle_terminal_run(
        self,
        func: Callable,
        func_params: Dict[str, inspect.Parameter],
        config_params: Dict[str, Any],
        ingestible_files: Dict,
    ):
        """
        Parses command line arguments and sets configuration when script is run from the terminal.

        Function parameters become positional CLI arguments. Config parameters
        become ``--name`` options defaulting to their current value.

        Args:
            func (Callable): The user function to execute.
            func_params (dict): A dictionary containing the function parameters and their annotations.
            config_params (dict): A dictionary containing the configuration parameters.
            ingestible_files (dict): A dictionary containing the files that should be ingested.

        Returns:
            The result of ``execute_function``, after printing a summary.
        """

        # For required parameters, we add them as arguments
        parser = argparse.ArgumentParser()
        for name, param in func_params.items():
            if name in ingestible_files:
                arg_type = str  # a file path; converted to InFile below
            elif param.annotation is inspect.Parameter.empty:
                # Unannotated parameter: argparse would otherwise call
                # ``inspect.Parameter.empty`` as a converter and reject every value.
                arg_type = str
            else:
                arg_type = param.annotation
            parser.add_argument(name, type=arg_type)

        for name, param in config_params.items():
            if type(param) is MultipleChoiceParam:
                parser.add_argument(
                    f"--{name}",
                    type=str,
                    default=param.default,
                    choices=param.choices,
                )
            else:
                parser.add_argument(
                    f"--{name}",
                    type=type(param),
                    default=param,
                )

        args = parser.parse_args()

        # split the arg list into the arg in the app_param and
        # the args from the sig.parameter
        args_config_params = {k: v for k, v in vars(args).items() if k in config_params}
        args_func_params = {
            k: v for k, v in vars(args).items() if k not in config_params
        }
        for name in ingestible_files:
            args_func_params[name] = InFile(
                file_name=Path(args_func_params[name]).stem,
                file_path=args_func_params[name],
            )

        agenta.config.set(**args_config_params)

        # NOTE(review): asyncio.get_event_loop() is deprecated when no loop is
        # running on recent Python versions; consider asyncio.run().
        loop = asyncio.get_event_loop()
        result = loop.run_until_complete(
            self.execute_function(
                func,
                **{"params": args_func_params, "config_params": args_config_params},
            )
        )
        print(
            f"\n========== Result ==========\n\nMessage: {result.message}\nCost: {result.cost}\nToken Usage: {result.usage}"
        )
        return result

    def override_schema(
        self, openapi_schema: dict, func_name: str, endpoint: str, params: dict
    ):
        """
        Overrides the default OpenAPI schema generated by FastAPI with additional information about:
        - The choices available for each MultipleChoiceParam / GroupedMultipleChoiceParam instance
        - The min and max values for each FloatParam instance
        - The min and max values for each IntParam instance
        - The default value for DictInput instance
        - The default value for MessagesInput instance
        - The default value for FileInputURL instance
        - The default value for BinaryParam instance
        - ... [PLEASE ADD AT EACH CHANGE]

        The schema dict is mutated in place.

        Args:
            openapi_schema (dict): The openapi schema generated by fastapi
            func_name (str): The name of the function to override
            endpoint (str): The name of the endpoint to override
            params (dict(param_name, param_val)): The dictionary of the parameters for the function
        """

        def find_in_schema(schema: dict, param_name: str, xparam: str):
            """Finds a parameter in the schema based on its name and x-parameter value"""
            for _, value in schema.items():
                # Check the type before calling ``.get`` on the value.
                if not isinstance(value, dict):
                    continue
                value_title_lower = str(value.get("title")).lower()
                # Multi-word titles ("Max Tokens") are matched as snake_case.
                value_title = (
                    "_".join(value_title_lower.split())
                    if len(value_title_lower.split()) >= 2
                    else value_title_lower
                )

                if value.get("x-parameter") == xparam and value_title == param_name:
                    return value

        schema_to_override = openapi_schema["components"]["schemas"][
            f"Body_{func_name}_{endpoint}_post"
        ]["properties"]
        for param_name, param_val in params.items():
            if isinstance(param_val, GroupedMultipleChoiceParam):
                subschema = find_in_schema(
                    schema_to_override, param_name, "grouped_choice"
                )
                assert (
                    subschema
                ), f"GroupedMultipleChoiceParam '{param_name}' is in the parameters but could not be found in the openapi.json"
                subschema["choices"] = param_val.choices
                subschema["default"] = param_val.default
            if isinstance(param_val, MultipleChoiceParam):
                subschema = find_in_schema(schema_to_override, param_name, "choice")
                default = str(param_val)
                param_choices = param_val.choices
                # Make sure the current value is always one of the enum options.
                choices = (
                    [default] + param_choices
                    if param_val not in param_choices
                    else param_choices
                )
                subschema["enum"] = choices
                subschema["default"] = (
                    default if default in param_choices else choices[0]
                )
            if isinstance(param_val, FloatParam):
                subschema = find_in_schema(schema_to_override, param_name, "float")
                subschema["minimum"] = param_val.minval
                subschema["maximum"] = param_val.maxval
                subschema["default"] = param_val
            if isinstance(param_val, IntParam):
                subschema = find_in_schema(schema_to_override, param_name, "int")
                subschema["minimum"] = param_val.minval
                subschema["maximum"] = param_val.maxval
                subschema["default"] = param_val
            if (
                isinstance(param_val, inspect.Parameter)
                and param_val.annotation is DictInput
            ):
                subschema = find_in_schema(schema_to_override, param_name, "dict")
                subschema["default"] = param_val.default["default_keys"]
            if isinstance(param_val, TextParam):
                subschema = find_in_schema(schema_to_override, param_name, "text")
                subschema["default"] = param_val
            if (
                isinstance(param_val, inspect.Parameter)
                and param_val.annotation is MessagesInput
            ):
                subschema = find_in_schema(schema_to_override, param_name, "messages")
                subschema["default"] = param_val.default
            if (
                isinstance(param_val, inspect.Parameter)
                and param_val.annotation is FileInputURL
            ):
                subschema = find_in_schema(schema_to_override, param_name, "file_url")
                subschema["default"] = "https://example.com"
            if isinstance(param_val, BinaryParam):
                subschema = find_in_schema(schema_to_override, param_name, "bool")
                subschema["default"] = param_val.default
@@ -1,112 +0,0 @@
1
- # Stdlib Imports
2
- import os
3
- import inspect
4
- from functools import wraps
5
- from typing import Any, Callable
6
-
7
- # Own Imports
8
- import agenta as ag
9
- from agenta.sdk.decorators.base import BaseDecorator
10
-
11
-
12
class span(BaseDecorator):
    """Decorator class for starting and ending spans of a trace.

    The returned wrapper is always a coroutine function, whether the decorated
    function is synchronous or asynchronous. Keyword arguments are recorded as
    the span input; positional arguments are not recorded.

    An ``Exception`` raised by the decorated function is not propagated. The
    span is marked ``"ERROR"`` and the wrapper returns ``{"message": str(exc)}``.
    Any non-dict result is likewise wrapped as ``{"message": result}``.

    Args:
        BaseDecorator (object): base decorator class

    Example:
    ```python
    import agenta as ag

    @ag.span(type="llm")
    async def openai_llm_call(prompt: str):
        return ...
    ```
    """

    def __init__(self, type: str):
        # Span kind recorded for every call (e.g. "llm").
        self.type = type
        self.trace = ag.llm_tracing()

    def __call__(self, func: Callable[..., Any]) -> Callable[..., Any]:
        @wraps(func)
        async def wrapper(*args, **kwargs):
            result = None
            # Named ``active_span`` to avoid shadowing the ``span`` class.
            active_span = self.trace.start_span(
                name=func.__name__,
                input=kwargs,
                spankind=self.type,
            )
            try:
                is_coroutine_function = inspect.iscoroutinefunction(func)
                if is_coroutine_function:
                    result = await func(*args, **kwargs)
                else:
                    result = func(*args, **kwargs)
                self.trace.update_span_status(span=active_span, value="OK")
            except Exception as e:
                result = str(e)
                self.trace.update_span_status(span=active_span, value="ERROR")
            finally:
                # The span is always ended, including on cancellation.
                if not isinstance(result, dict):
                    result = {"message": result}
                self.trace.end_span(outputs=result, span=active_span)
            # Returning after (not inside) ``finally``: a ``return`` in ``finally``
            # would discard in-flight BaseExceptions such as asyncio.CancelledError
            # or KeyboardInterrupt, and errors raised while recording the status.
            return result

        return wrapper
58
-
59
-
60
class trace(BaseDecorator):
    """Decorator class for starting and ending the parent span of a trace.

    Each call starts a parent span that records the keyword arguments, the
    current ``ag.config.all()`` values, and the ``AGENTA_LLM_RUN_ENVIRONMENT``
    environment variable. The returned wrapper is always a coroutine function.

    An ``Exception`` raised by the decorated function is not propagated. The
    span is marked ``"ERROR"`` and the wrapper returns ``{"message": str(exc)}``.
    Any non-dict result is likewise wrapped as ``{"message": result}``.

    Args:
        BaseDecorator (object): base decorator class

    Example:
    ```python
    import agenta as ag

    @ag.trace()
    async def generate(prompt: str):
        return ...
    ```
    """

    def __init__(self):
        self.trace = ag.llm_tracing()

    def __call__(self, func: Callable[..., Any]) -> Callable[..., Any]:
        @wraps(func)
        async def wrapper(*args, **kwargs):
            result = None
            self.trace.start_parent_span(
                name=func.__name__,
                inputs=kwargs,
                config=ag.config.all(),
                environment=os.environ.get("AGENTA_LLM_RUN_ENVIRONMENT"),  # type: ignore
            )
            try:
                is_coroutine_function = inspect.iscoroutinefunction(func)
                if is_coroutine_function:
                    result = await func(*args, **kwargs)
                else:
                    result = func(*args, **kwargs)

                self.trace.update_span_status(span=self.trace.active_trace, value="OK")
            except Exception as e:
                result = str(e)
                self.trace.update_span_status(
                    span=self.trace.active_trace, value="ERROR"
                )
            finally:
                # Recording always ends, including on cancellation.
                if not isinstance(result, dict):
                    result = {"message": result}

                self.trace.end_recording(
                    outputs=result,
                    span=self.trace.active_trace,
                )
            # Returning after (not inside) ``finally``: a ``return`` in ``finally``
            # would discard in-flight BaseExceptions such as asyncio.CancelledError
            # or KeyboardInterrupt, and errors raised while recording the status.
            return result

        return wrapper