agenta 0.14.1a0__py3-none-any.whl → 0.14.1a1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of agenta might be problematic. Click here for more details.

@@ -1,515 +0,0 @@
1
- """The code for the Agenta SDK"""
2
-
3
- import os
4
- import sys
5
- import time
6
- import inspect
7
- import argparse
8
- import asyncio
9
- import traceback
10
- import functools
11
- from pathlib import Path
12
- from tempfile import NamedTemporaryFile
13
- from typing import Any, Callable, Dict, Optional, Tuple, List
14
-
15
- from fastapi.middleware.cors import CORSMiddleware
16
- from fastapi import Body, FastAPI, UploadFile, HTTPException
17
-
18
- import agenta
19
- from .context import save_context
20
- from .tracing.llm_tracing import Tracing
21
- from .router import router as router
22
- from .types import (
23
- Context,
24
- DictInput,
25
- FloatParam,
26
- InFile,
27
- IntParam,
28
- MultipleChoiceParam,
29
- GroupedMultipleChoiceParam,
30
- TextParam,
31
- MessagesInput,
32
- FileInputURL,
33
- FuncResponse,
34
- BinaryParam,
35
- )
36
-
37
# Module-level FastAPI application; `entrypoint` registers its POST routes here.
app = FastAPI()

origins = ["*"]

# Permissive CORS policy: any origin, method and header is accepted, and
# credentials are allowed.
# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive; confirm this is intended for deployed apps.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=origins,
    allow_headers=["*"],
    allow_methods=["*"],
)

# Mount the routes defined in `.router` at the application root (no prefix).
app.include_router(router, prefix="")
52
-
53
-
54
def ingest_file(upfile: UploadFile):
    """Persist an uploaded file to a named temporary file on disk.

    The upload is streamed into a ``NamedTemporaryFile(delete=False)`` so the
    file survives this call. If copying fails, the handle is still closed and
    the partially written temporary file is removed before the error
    propagates.

    Args:
        upfile: The upload; its ``file`` object is read and its ``filename``
            is reported back.

    Returns:
        InFile: ``file_name`` is ``upfile.filename`` and ``file_path`` is the
        path of the temporary copy.
    """
    import shutil  # local import keeps the module's import block unchanged

    temp_file = NamedTemporaryFile(delete=False)
    try:
        # `with` guarantees the handle is closed even if the copy raises.
        with temp_file:
            # Copy in chunks rather than loading the whole upload into memory.
            shutil.copyfileobj(upfile.file, temp_file)
    except BaseException:
        # delete=False means nothing else will clean up a half-written file.
        try:
            os.unlink(temp_file.name)
        except OSError:
            pass
        raise
    return InFile(file_name=upfile.filename, file_path=temp_file.name)
59
-
60
-
61
def entrypoint(func: Callable[..., Any]) -> Callable[..., Any]:
    """
    Decorator to wrap a function for HTTP POST and terminal exposure.

    Registers two POST routes on the module-level ``app``:

    * ``/generate`` (playground): configuration parameters arrive in the
      request body and are applied with ``agenta.config.set``.
    * ``/generate_deployed``: configuration is pulled by ``environment``
      name, else by ``config`` name, else the ``"default"`` config.

    Both routes run ``func`` through ``execute_function`` inside a parent
    tracing span. The OpenAPI schema is then patched by ``override_schema``.
    When the decorated function lives in the script being executed, it is
    also run once from the command line via ``handle_terminal_run``.

    Args:
        func: Function to wrap.

    Returns:
        None. NOTE(review): the annotation says Callable, but the decorated
        name is rebound to ``None``.
    """

    endpoint_name = "generate"
    func_signature = inspect.signature(func)
    # Configuration registered before decoration; serves as the defaults.
    config_params = agenta.config.all()
    ingestible_files = extract_ingestible_files(func_signature)

    # Initialize tracing
    tracing = agenta.llm_tracing()

    @functools.wraps(func)
    async def wrapper(*args, **kwargs) -> Any:
        func_params, api_config_params = split_kwargs(kwargs, config_params)

        # Start tracing. Record the configuration actually used for this
        # request (defaults overridden by the request body), consistent with
        # the deployed and shell paths, rather than only the defaults.
        tracing.start_parent_span(
            name=func.__name__,
            inputs=func_params,
            config={**config_params, **api_config_params},
            environment="playground",  # type: ignore #NOTE: wrapper is only called in playground
        )

        # Ingest files, prepare configurations and run llm app
        ingest_files(func_params, ingestible_files)
        agenta.config.set(**api_config_params)
        llm_result = await execute_function(
            func, *args, params=func_params, config_params=config_params
        )

        # End trace recording
        tracing.end_recording(
            outputs=llm_result.dict(),
            span=tracing.active_trace,
        )
        return llm_result

    @functools.wraps(func)
    async def wrapper_deployed(*args, **kwargs) -> Any:
        # Everything except the two config selectors is forwarded to `func`.
        func_params = {
            k: v for k, v in kwargs.items() if k not in ["config", "environment"]
        }
        environment = kwargs.get("environment")
        config_name = kwargs.get("config")
        # Precedence: environment, then named config, then "default".
        if environment is not None:
            agenta.config.pull(environment_name=environment)
        elif config_name is not None:
            agenta.config.pull(config_name=config_name)
        else:
            agenta.config.pull(config_name="default")

        config = agenta.config.all()

        # Start tracing
        tracing.start_parent_span(
            name=func.__name__,
            inputs=func_params,
            config=config,
            environment=environment,  # type: ignore #NOTE: None when no environment was requested
        )

        llm_result = await execute_function(
            func, *args, params=func_params, config_params=config
        )

        # End trace recording
        tracing.end_recording(
            outputs=llm_result.dict(),
            span=tracing.active_trace,
        )
        return llm_result

    update_function_signature(wrapper, func_signature, config_params, ingestible_files)
    route = f"/{endpoint_name}"
    app.post(route, response_model=FuncResponse)(wrapper)

    update_deployed_function_signature(
        wrapper_deployed,
        func_signature,
        ingestible_files,
    )
    route_deployed = f"/{endpoint_name}_deployed"
    app.post(route_deployed, response_model=FuncResponse)(wrapper_deployed)
    # Must run after the routes are registered so app.openapi() includes them.
    override_schema(
        openapi_schema=app.openapi(),
        func_name=func.__name__,
        endpoint=endpoint_name,
        params={**config_params, **func_signature.parameters},
    )

    if is_main_script(func):
        handle_terminal_run(
            func, func_signature.parameters, config_params, ingestible_files, tracing
        )
    return None
162
-
163
-
164
def extract_ingestible_files(
    func_signature: inspect.Signature,
) -> Dict[str, inspect.Parameter]:
    """Collect the parameters of ``func_signature`` whose annotation is exactly ``InFile``.

    Returns:
        Mapping of parameter name to its ``inspect.Parameter``, in signature order.
    """
    ingestible: Dict[str, inspect.Parameter] = {}
    for param_name, parameter in func_signature.parameters.items():
        # Identity check: subclasses / Optional[InFile] are not matched.
        if parameter.annotation is InFile:
            ingestible[param_name] = parameter
    return ingestible
174
-
175
-
176
def split_kwargs(
    kwargs: Dict[str, Any], config_params: Dict[str, Any]
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Partition request kwargs into (function params, configuration params).

    A key belongs to the configuration side when it is a key of
    ``config_params``; everything else goes to the function side. The
    insertion order of ``kwargs`` is preserved on both sides.
    """
    function_side: Dict[str, Any] = {}
    config_side: Dict[str, Any] = {}
    for key, value in kwargs.items():
        bucket = config_side if key in config_params else function_side
        bucket[key] = value
    return function_side, config_side
184
-
185
-
186
def ingest_files(
    func_params: Dict[str, Any], ingestible_files: Dict[str, inspect.Parameter]
) -> None:
    """Replace uploaded-file values in ``func_params`` (in place) with ingested ``InFile`` objects.

    Names from ``ingestible_files`` that are absent from ``func_params`` or
    whose value is ``None`` are left untouched.
    """
    for file_param in ingestible_files:
        upload = func_params.get(file_param)
        if upload is None:
            continue
        func_params[file_param] = ingest_file(upload)
194
-
195
-
196
async def execute_function(func: Callable[..., Any], *args, **func_params):
    """Run the user's app function and wrap its result in a ``FuncResponse``.

    For backward compatibility both synchronous and ``async`` functions are
    supported: coroutine functions are awaited, plain functions are called.

    Args:
        func: The function to run.
        *args: Positional arguments forwarded to ``func``.
        **func_params: Must contain ``params``, the keyword arguments
            forwarded to ``func``. Other keys (e.g. ``config_params``) are
            accepted but not used here.

    Returns:
        FuncResponse: built from a returned ``dict`` (unpacked) or ``str``
        (as ``message``), with ``latency`` in seconds rounded to 4 places.

    Raises:
        HTTPException: via ``handle_exception`` for any error. This includes
        a return value that is neither a dict nor a str; such values used to
        fall through and yield ``None``.
    """
    try:
        start_time = time.perf_counter()
        if inspect.iscoroutinefunction(func):
            result = await func(*args, **func_params["params"])
        else:
            result = func(*args, **func_params["params"])
        latency = round(time.perf_counter() - start_time, 4)

        if isinstance(result, Context):
            save_context(result)
        if isinstance(result, dict):
            return FuncResponse(**result, latency=latency)
        if isinstance(result, str):
            return FuncResponse(message=result, latency=latency)  # type: ignore
        # Any other type previously returned None implicitly and crashed the
        # caller on `.dict()`; report it clearly instead.
        func_name = getattr(func, "__name__", repr(func))
        raise TypeError(
            f"{func_name} returned {type(result).__name__}; expected a dict or str"
        )
    except Exception as e:
        handle_exception(e)
    # handle_exception always raises; this fallback only satisfies type checkers.
    return FuncResponse(message="Unexpected error occurred", latency=0)  # type: ignore
224
-
225
-
226
def handle_exception(e: Exception):
    """Convert ``e`` into an ``HTTPException`` carrying its message and traceback.

    Args:
        e: The exception raised while running the app function.

    Raises:
        HTTPException: always. The status code is ``e.status_code`` when it
        is an integer, otherwise 500. ``detail`` holds ``{"error": str(e),
        "traceback": <formatted traceback>}``. ``e`` is chained as the cause.
    """
    status_code = getattr(e, "status_code", 500)
    # Some libraries expose status_code=None (or a non-int); HTTPException needs an int.
    if not isinstance(status_code, int) or isinstance(status_code, bool):
        status_code = 500
    # (type, value, tb) form works on every Python 3 version.
    traceback_str = traceback.format_exception(type(e), e, e.__traceback__)
    raise HTTPException(
        status_code=status_code,
        detail={"error": str(e), "traceback": "".join(traceback_str)},
    ) from e
235
-
236
-
237
def update_wrapper_signature(wrapper: Callable[..., Any], updated_params: List):
    """
    Replace the parameter list reported by ``inspect.signature(wrapper)``.

    The return annotation is kept; only the parameters change.

    Args:
        wrapper (callable): Callable whose ``__signature__`` is overwritten.
        updated_params (List[inspect.Parameter]): The new parameters, in order.
    """
    current = inspect.signature(wrapper)
    setattr(wrapper, "__signature__", current.replace(parameters=updated_params))
250
-
251
-
252
def update_function_signature(
    wrapper: Callable[..., Any],
    func_signature: inspect.Signature,
    config_params: Dict[str, Any],
    ingestible_files: Dict[str, inspect.Parameter],
) -> None:
    """Rewrite the playground wrapper's signature so FastAPI exposes every input.

    The new parameter list holds the configuration parameters first, then
    the wrapped function's own parameters.
    """
    signature_params: list = []
    # Configuration knobs first ...
    add_config_params_to_parser(signature_params, config_params)
    # ... then the user's own inputs.
    add_func_params_to_parser(signature_params, func_signature, ingestible_files)
    update_wrapper_signature(wrapper, updated_params=signature_params)
264
-
265
-
266
def update_deployed_function_signature(
    wrapper: Callable[..., Any],
    func_signature: inspect.Signature,
    ingestible_files: Dict[str, inspect.Parameter],
) -> None:
    """Rewrite the deployed wrapper's signature.

    The wrapped function's parameters come first, followed by two optional
    keyword-only string body fields, ``config`` and ``environment``, that
    select which configuration to pull.
    """
    signature_params: list = []
    add_func_params_to_parser(signature_params, func_signature, ingestible_files)
    # Config selectors; each gets its own Body(None) default.
    signature_params.extend(
        inspect.Parameter(
            selector,
            inspect.Parameter.KEYWORD_ONLY,
            default=Body(None),
            annotation=str,
        )
        for selector in ("config", "environment")
    )
    update_wrapper_signature(wrapper, signature_params)
287
-
288
-
289
def add_config_params_to_parser(
    updated_params: list, config_params: Dict[str, Any]
) -> None:
    """Append one keyword-only body parameter per configuration entry to ``updated_params``.

    Each parameter defaults to ``Body(<current value>)`` and is annotated as
    ``Optional[<type of the current value>]``.
    """
    updated_params.extend(
        inspect.Parameter(
            param_name,
            inspect.Parameter.KEYWORD_ONLY,
            default=Body(default_value),
            annotation=Optional[type(default_value)],
        )
        for param_name, default_value in config_params.items()
    )
302
-
303
-
304
def add_func_params_to_parser(
    updated_params: list,
    func_signature: inspect.Signature,
    ingestible_files: Dict[str, inspect.Parameter],
) -> None:
    """Append the wrapped function's parameters to ``updated_params``.

    Parameters listed in ``ingestible_files`` become required ``UploadFile``
    parameters. Every other parameter becomes a required, embedded body field
    that keeps its original annotation.

    All appended parameters are keyword-only. The list is passed to
    ``inspect.Signature.replace``, which rejects a positional parameter that
    follows a keyword-only one. File parameters used to keep their original
    (usually positional) kind, so any file parameter placed after a config or
    regular parameter raised ``ValueError`` ("wrong parameter order").
    """
    for name, param in func_signature.parameters.items():
        if name in ingestible_files:
            new_param = inspect.Parameter(
                name, inspect.Parameter.KEYWORD_ONLY, annotation=UploadFile
            )
        else:
            new_param = inspect.Parameter(
                name,
                inspect.Parameter.KEYWORD_ONLY,
                default=Body(..., embed=True),
                annotation=param.annotation,
            )
        updated_params.append(new_param)
324
-
325
-
326
def is_main_script(func: Callable) -> bool:
    """
    Tell whether ``func`` is defined in the script that was launched.

    Compares the base filename without extension of ``sys.argv[0]`` with
    that of the file defining ``func``.

    Args:
        func (Callable): The function object to check.

    Returns:
        bool: True when both file stems match, False otherwise.

    Example:
        if is_main_script(my_function):
            print("This is the main script.")
    """

    def _stem(path: str) -> str:
        # "dir/app.py" -> "app"
        return os.path.splitext(os.path.basename(path))[0]

    launched = _stem(sys.argv[0])
    defining = _stem(inspect.getfile(func))
    return launched == defining
344
-
345
-
346
def handle_terminal_run(
    func: Callable,
    func_params: Dict[str, Any],
    config_params: Dict[str, Any],
    ingestible_files: Dict,
    tracing: Tracing,
) -> None:
    """
    Parses command line arguments and sets configuration when script is run from the terminal.

    Every parameter of ``func`` becomes a required positional CLI argument;
    file parameters are given as paths. Every configuration parameter becomes
    an optional ``--<name>`` flag defaulting to its current value.

    The parsed configuration is applied with ``agenta.config.set``, and
    ``func`` runs once inside a "shell" tracing span. The message, cost and
    token usage are then printed.

    Args:
        func (Callable): The decorated function to run.
        func_params (dict): Parameter name -> ``inspect.Parameter`` of ``func``.
        config_params (dict): A dictionary containing the configuration parameters.
        ingestible_files (dict): A dictionary containing the files that should be ingested.
        tracing (Tracing): The tracing object
    """

    # For required parameters, we add them as arguments
    parser = argparse.ArgumentParser()
    for name, param in func_params.items():
        if name in ingestible_files:
            parser.add_argument(name, type=str)
        elif param.annotation is inspect.Parameter.empty:
            # Unannotated parameter: passing the `empty` marker class as the
            # converter would make argparse reject every value.
            parser.add_argument(name, type=str)
        else:
            parser.add_argument(name, type=param.annotation)

    for name, param in config_params.items():
        # Exact type check on purpose: other param types take the generic branch.
        if type(param) is MultipleChoiceParam:
            parser.add_argument(
                f"--{name}",
                type=str,
                default=param.default,
                choices=param.choices,
            )
        else:
            parser.add_argument(
                f"--{name}",
                type=type(param),
                default=param,
            )

    args = parser.parse_args()

    # split the arg list into the arg in the app_param and
    # the args from the sig.parameter
    args_config_params = {k: v for k, v in vars(args).items() if k in config_params}
    args_func_params = {k: v for k, v in vars(args).items() if k not in config_params}
    for name in ingestible_files:
        args_func_params[name] = InFile(
            file_name=Path(args_func_params[name]).stem,
            file_path=args_func_params[name],
        )
    agenta.config.set(**args_config_params)

    # Start tracing
    tracing.start_parent_span(
        name=func.__name__,
        inputs=args_func_params,
        config=args_config_params,
        environment="shell",  # type: ignore
    )

    # Reuse the current event loop when there is one. get_event_loop() raises
    # RuntimeError when no loop is set (non-main threads; Python 3.14+), so
    # create and install one in that case.
    try:
        loop = asyncio.get_event_loop()
    except RuntimeError:
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
    result = loop.run_until_complete(
        execute_function(
            func, **{"params": args_func_params, "config_params": args_config_params}
        )
    )

    # End trace recording
    tracing.end_recording(
        outputs=result.dict(),
        span=tracing.active_trace,  # type: ignore
    )
    print(
        f"\n========== Result ==========\n\nMessage: {result.message}\nCost: {result.cost}\nToken Usage: {result.usage}"
    )
422
-
423
-
424
def override_schema(openapi_schema: dict, func_name: str, endpoint: str, params: dict):
    """
    Overrides the default OpenAPI schema generated by FastAPI with additional information about:
    - The choices and default for each GroupedMultipleChoiceParam instance
    - The choices available for each MultipleChoiceParam instance
    - The min and max values for each FloatParam instance
    - The min and max values for each IntParam instance
    - The default value for DictInput instance
    - The default value for TextParam instance
    - The default value for MessagesInput instance
    - The default value for FileInputURL instance
    - The default value for BinaryParam instance
    - ... [PLEASE ADD AT EACH CHANGE]

    The property sub-schemas inside ``openapi_schema`` are modified in place;
    nothing is returned.

    Args:
        openapi_schema (dict): The OpenAPI schema generated by FastAPI
        func_name (str): The name of the function to override
        endpoint (str): The name of the endpoint to override
        params (dict(param_name, param_val)): The dictionary of the parameters for the function

    Raises:
        KeyError: if ``Body_{func_name}_{endpoint}_post`` is not in the schema.
        ValueError: if a special parameter has no property tagged with the
            matching ``x-parameter`` value.
    """

    def find_in_schema(schema: dict, param_name: str, xparam: str):
        """Return the property for ``param_name`` tagged ``x-parameter == xparam``, else None."""
        # Body properties are presumably keyed by parameter name (no aliases
        # are used here), so try a direct lookup first.
        direct = schema.get(param_name)
        if isinstance(direct, dict) and direct.get("x-parameter") == xparam:
            return direct
        # Fallback: match the lower-cased title with whitespace runs joined by
        # "_" (e.g. "Max Tokens" -> "max_tokens").
        for value in schema.values():
            # Type check first: calling .get() on a non-dict would raise.
            if not isinstance(value, dict) or value.get("x-parameter") != xparam:
                continue
            title_lower = str(value.get("title")).lower()
            words = title_lower.split()
            title = "_".join(words) if len(words) >= 2 else title_lower
            if title == param_name:
                return value
        return None

    def require_subschema(param_name: str, xparam: str, kind: str) -> dict:
        """Like find_in_schema, but fail loudly instead of returning None."""
        subschema = find_in_schema(schema_to_override, param_name, xparam)
        if subschema is None:
            raise ValueError(
                f"{kind} '{param_name}' is in the parameters but could not be found in the openapi.json"
            )
        return subschema

    schema_to_override = openapi_schema["components"]["schemas"][
        f"Body_{func_name}_{endpoint}_post"
    ]["properties"]
    # Independent `if`s (not elif): a value matching several types gets every override.
    for param_name, param_val in params.items():
        if isinstance(param_val, GroupedMultipleChoiceParam):
            subschema = require_subschema(
                param_name, "grouped_choice", "GroupedMultipleChoiceParam"
            )
            subschema["choices"] = param_val.choices
            subschema["default"] = param_val.default
        if isinstance(param_val, MultipleChoiceParam):
            subschema = require_subschema(param_name, "choice", "MultipleChoiceParam")
            default = str(param_val)
            param_choices = param_val.choices
            # Make sure the current value is selectable even if not listed.
            choices = (
                [default] + param_choices
                if param_val not in param_choices
                else param_choices
            )
            subschema["enum"] = choices
            subschema["default"] = default if default in param_choices else choices[0]
        if isinstance(param_val, FloatParam):
            subschema = require_subschema(param_name, "float", "FloatParam")
            subschema["minimum"] = param_val.minval
            subschema["maximum"] = param_val.maxval
            subschema["default"] = param_val
        if isinstance(param_val, IntParam):
            subschema = require_subschema(param_name, "int", "IntParam")
            subschema["minimum"] = param_val.minval
            subschema["maximum"] = param_val.maxval
            subschema["default"] = param_val
        if (
            isinstance(param_val, inspect.Parameter)
            and param_val.annotation is DictInput
        ):
            subschema = require_subschema(param_name, "dict", "DictInput")
            subschema["default"] = param_val.default["default_keys"]
        if isinstance(param_val, TextParam):
            subschema = require_subschema(param_name, "text", "TextParam")
            subschema["default"] = param_val
        if (
            isinstance(param_val, inspect.Parameter)
            and param_val.annotation is MessagesInput
        ):
            subschema = require_subschema(param_name, "messages", "MessagesInput")
            subschema["default"] = param_val.default
        if (
            isinstance(param_val, inspect.Parameter)
            and param_val.annotation is FileInputURL
        ):
            subschema = require_subschema(param_name, "file_url", "FileInputURL")
            subschema["default"] = "https://example.com"
        if isinstance(param_val, BinaryParam):
            subschema = require_subschema(param_name, "bool", "BinaryParam")
            subschema["default"] = param_val.default
@@ -1,125 +0,0 @@
1
- # Own Imports
2
- from agenta.sdk import llm_tracing
3
-
4
- # Third Party Imports
5
- from litellm.utils import ModelResponse
6
- from litellm.integrations.custom_logger import CustomLogger as LitellmCustomLogger
7
-
8
-
9
class AgentaLiteLLMHandler(LitellmCustomLogger):
    """This handler is responsible for logging certain events when using litellm to call LLMs.

    ``log_pre_api_call`` starts a span for each call. The stream, success and
    failure callbacks (sync and async) set the active span's status to "OK"
    or "ERROR" and end it with ``{"message", "usage", "cost"}`` outputs.

    Lookups in litellm's ``kwargs`` are tolerant: a missing key yields a
    fallback (or None) instead of raising ``KeyError`` inside the callback,
    which would otherwise leave the span open.

    Args:
        LitellmCustomLogger (object): custom logger that allows us to override the events to capture.
    """

    @property
    def _trace(self):
        # Resolved via agenta's llm_tracing() accessor on every access.
        return llm_tracing()

    @staticmethod
    def _build_outputs(kwargs, response_obj, message_key):
        """Assemble span outputs from a litellm callback.

        ``kwargs[message_key]`` and ``kwargs["usage"]`` are used when present.
        Otherwise the message falls back to the first choice's content in
        ``response_obj`` and usage to ``response_obj.usage``, else None.
        """
        message = kwargs.get(message_key)
        if message is None and response_obj is not None:
            try:
                message = response_obj.choices[0].message.content
            except (AttributeError, IndexError, KeyError, TypeError):
                message = None
        usage = kwargs.get("usage")
        if usage is None:
            usage = getattr(response_obj, "usage", None)
            # NOTE(review): usage objects are presumably pydantic-style
            # models; convert to a plain dict when they support it.
            if hasattr(usage, "dict"):
                usage = usage.dict()
        return {
            "message": message,
            "usage": usage,
            "cost": kwargs.get("response_cost"),
        }

    def _end_span_ok(self, kwargs, response_obj, message_key):
        """Mark the active span OK and end it with this callback's outputs."""
        self._trace.update_span_status(span=self._trace.active_span, value="OK")
        self._trace.end_span(
            outputs=self._build_outputs(kwargs, response_obj, message_key),
            span=self._trace.active_span,
        )

    def _end_span_error(self, kwargs, response_obj):
        """Mark the active span ERROR, attach failure details and end it."""
        self._trace.update_span_status(span=self._trace.active_span, value="ERROR")
        self._trace.set_span_attribute(
            attributes={
                "traceback_exception": kwargs.get("traceback_exception"),
                "call_end_time": kwargs.get("end_time"),
            },
        )
        self._trace.end_span(
            outputs=self._build_outputs(kwargs, response_obj, "exception"),
            span=self._trace.active_span,
        )

    def log_pre_api_call(self, model, messages, kwargs):
        # Only (a)completion calls are tagged as "llm" spans.
        is_llm_call = kwargs.get("call_type") in ["completion", "acompletion"]
        self._trace.start_span(
            name="pre_api_call",
            input=(
                {"messages": messages}
                if isinstance(messages, list)
                else {"inputs": messages}
            ),
            spankind="llm" if is_llm_call else "unset",
        )
        # temperature may be absent when the caller did not set it.
        optional_params = kwargs.get("optional_params") or {}
        self._trace.set_span_attribute(
            "model_config",
            {
                "model": kwargs.get("model"),
                "temperature": optional_params.get("temperature"),
            },
        )

    def log_stream_event(self, kwargs, response_obj, start_time, end_time):
        self._end_span_ok(kwargs, response_obj, "complete_streaming_response")

    def log_success_event(
        self, kwargs, response_obj: ModelResponse, start_time, end_time
    ):
        self._end_span_ok(kwargs, response_obj, "message")

    def log_failure_event(
        self, kwargs, response_obj: ModelResponse, start_time, end_time
    ):
        self._end_span_error(kwargs, response_obj)

    async def async_log_stream_event(self, kwargs, response_obj, start_time, end_time):
        self._end_span_ok(kwargs, response_obj, "complete_streaming_response")

    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        self._end_span_ok(kwargs, response_obj, "message")

    async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
        self._end_span_error(kwargs, response_obj)


# Module-level handler instance.
agenta_litellm_handler = AgentaLiteLLMHandler()
@@ -1,41 +0,0 @@
1
- # Stdlib Imports
2
- import inspect
3
- from functools import wraps
4
-
5
- # Own Imports
6
- import agenta as ag
7
-
8
-
9
def span(type: str):
    """Decorator that records a tracing span around each call of the wrapped function.

    The returned wrapper is always ``async``. Coroutine functions are
    awaited and plain functions are called directly. The span's input is the
    call's keyword arguments; positional arguments are not recorded.

    Args:
        type: Span kind passed to ``start_span`` as ``spankind``. (Shadows the
            builtin, but callers may pass it by keyword, so the name stays.)

    Returns:
        A decorator. The decorated function returns a dict: the original
        result if it was a dict, otherwise ``{"message": result}``.

        An ``Exception`` raised by the wrapped function marks the span
        "ERROR" and is returned as ``{"message": str(exc)}`` instead of
        propagating. Other ``BaseException``s (e.g. ``KeyboardInterrupt``,
        ``asyncio.CancelledError``) still end the span but now propagate;
        previously a ``return`` inside ``finally`` swallowed them.
    """

    tracing = ag.llm_tracing()

    def decorator(func):
        @wraps(func)
        async def wrapper(*args, **kwargs):
            result = None
            current_span = tracing.start_span(
                name=func.__name__,
                input=kwargs,
                spankind=type,
            )
            try:
                if inspect.iscoroutinefunction(func):
                    result = await func(*args, **kwargs)
                else:
                    result = func(*args, **kwargs)
                tracing.update_span_status(span=current_span, value="OK")
            except Exception as e:
                # NOTE(review): errors are deliberately converted into a
                # result value rather than re-raised.
                result = str(e)
                tracing.update_span_status(span=current_span, value="ERROR")
            finally:
                # Always close the span, even while a BaseException propagates.
                if not isinstance(result, dict):
                    result = {"message": result}
                tracing.end_span(outputs=result, span=current_span)
            # Returning outside `finally` lets cancellation/interrupts propagate.
            return result

        return wrapper

    return decorator