agenta 0.19.9__py3-none-any.whl → 0.20.0a1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of agenta might be problematic; see the registry's advisory for this release for more details.

@@ -8,7 +8,6 @@ from typing import Any, Callable, Optional
8
8
  import agenta as ag
9
9
  from agenta.sdk.decorators.base import BaseDecorator
10
10
  from agenta.sdk.tracing.logger import llm_logger as logging
11
- from agenta.sdk.tracing.tracing_context import tracing_context, TracingContext
12
11
  from agenta.sdk.utils.debug import debug, DEBUG, SHIFT
13
12
 
14
13
 
@@ -38,149 +37,55 @@ class instrument(BaseDecorator):
38
37
  """
39
38
 
40
39
  def __init__(
41
- self, config: Optional[dict] = None, spankind: str = "workflow"
40
+ self,
41
+ config: Optional[dict] = None,
42
+ spankind: str = "workflow",
42
43
  ) -> None:
43
44
  self.config = config
44
45
  self.spankind = spankind
45
- self.tracing = ag.tracing
46
46
 
47
47
  def __call__(self, func: Callable[..., Any]):
48
48
  is_coroutine_function = inspect.iscoroutinefunction(func)
49
49
 
50
- @debug()
51
- @wraps(func)
52
- async def async_wrapper(*args, **kwargs):
53
- result = None
50
+ def get_inputs(*args, **kwargs):
54
51
  func_args = inspect.getfullargspec(func).args
55
52
  input_dict = {name: value for name, value in zip(func_args, args)}
56
53
  input_dict.update(kwargs)
57
54
 
58
- async def wrapped_func(*args, **kwargs):
59
- # logging.debug(" ".join([">..", str(tracing_context.get())]))
60
-
61
- token = None
62
- if tracing_context.get() is None:
63
- token = tracing_context.set(TracingContext())
64
-
65
- # logging.debug(" ".join([">>.", str(tracing_context.get())]))
55
+ return input_dict
66
56
 
67
- self.tracing.start_span(
57
+ @wraps(func)
58
+ async def async_wrapper(*args, **kwargs):
59
+ async def wrapped_func(*args, **kwargs):
60
+ with ag.tracing.Context(
68
61
  name=func.__name__,
69
- input=input_dict,
62
+ input=get_inputs(*args, **kwargs),
70
63
  spankind=self.spankind,
71
64
  config=self.config,
72
- )
73
-
74
- try:
65
+ ):
75
66
  result = await func(*args, **kwargs)
76
67
 
77
- self.tracing.set_status(status="OK")
78
- self.tracing.end_span(
79
- outputs=(
80
- {"message": result}
81
- if not isinstance(result, dict)
82
- else result
83
- )
84
- )
85
-
86
- # logging.debug(" ".join(["<<.", str(tracing_context.get())]))
87
-
88
- if token is not None:
89
- tracing_context.reset(token)
90
-
91
- # logging.debug(" ".join(["<..", str(tracing_context.get())]))
68
+ ag.tracing.store_outputs(result)
92
69
 
93
70
  return result
94
71
 
95
- except Exception as e:
96
- result = {
97
- "message": str(e),
98
- "stacktrace": traceback.format_exc(),
99
- }
100
-
101
- self.tracing.set_attributes(
102
- {"traceback_exception": traceback.format_exc()}
103
- )
104
- self.tracing.set_status(status="ERROR")
105
- self.tracing.end_span(outputs=result)
106
-
107
- # logging.debug(" ".join(["<<.", str(tracing_context.get())]))
108
-
109
- if token is not None:
110
- tracing_context.reset(token)
111
-
112
- # logging.debug(" ".join(["<..", str(tracing_context.get())]))
113
-
114
- raise e
115
-
116
72
  return await wrapped_func(*args, **kwargs)
117
73
 
118
74
  @wraps(func)
119
75
  def sync_wrapper(*args, **kwargs):
120
- result = None
121
- func_args = inspect.getfullargspec(func).args
122
- input_dict = {name: value for name, value in zip(func_args, args)}
123
- input_dict.update(kwargs)
124
-
125
76
  def wrapped_func(*args, **kwargs):
126
- # logging.debug(" ".join([">..", str(tracing_context.get())]))
127
-
128
- token = None
129
- if tracing_context.get() is None:
130
- token = tracing_context.set(TracingContext())
131
-
132
- # logging.debug(" ".join([">>.", str(tracing_context.get())]))
133
-
134
- span = self.tracing.start_span(
77
+ with ag.tracing.Context(
135
78
  name=func.__name__,
136
- input=input_dict,
79
+ input=get_inputs(*args, **kwargs),
137
80
  spankind=self.spankind,
138
81
  config=self.config,
139
- )
140
-
141
- try:
82
+ ):
142
83
  result = func(*args, **kwargs)
143
84
 
144
- self.tracing.set_status(status="OK")
145
- self.tracing.end_span(
146
- outputs=(
147
- {"message": result}
148
- if not isinstance(result, dict)
149
- else result
150
- )
151
- )
152
-
153
- # logging.debug(" ".join(["<<.", str(tracing_context.get())]))
154
-
155
- if token is not None:
156
- tracing_context.reset(token)
157
-
158
- # logging.debug(" ".join(["<..", str(tracing_context.get())]))
85
+ ag.tracing.store_outputs(result)
159
86
 
160
87
  return result
161
88
 
162
- except Exception as e:
163
- result = {
164
- "message": str(e),
165
- "stacktrace": traceback.format_exc(),
166
- }
167
-
168
- self.tracing.set_attributes(
169
- {"traceback_exception": traceback.format_exc()}
170
- )
171
-
172
- self.tracing.set_status(status="ERROR")
173
- self.tracing.end_span(outputs=result)
174
-
175
- # logging.debug(" ".join(["<<.", str(tracing_context.get())]))
176
-
177
- if token is not None:
178
- tracing_context.reset(token)
179
-
180
- # logging.debug(" ".join(["<..", str(tracing_context.get())]))
181
-
182
- raise e
183
-
184
89
  return wrapped_func(*args, **kwargs)
185
90
 
186
91
  return async_wrapper if is_coroutine_function else sync_wrapper
@@ -1,5 +1,9 @@
1
1
  import agenta as ag
2
2
 
3
+ from agenta.sdk.tracing.tracing_context import tracing_context, TracingContext
4
+
5
+ from agenta.sdk.utils.debug import debug
6
+
3
7
 
4
8
  def litellm_handler():
5
9
  try:
@@ -23,20 +27,25 @@ def litellm_handler():
23
27
  LitellmCustomLogger (object): custom logger that allows us to override the events to capture.
24
28
  """
25
29
 
30
+ def __init__(self):
31
+ self.span = None
32
+
26
33
  @property
27
34
  def _trace(self):
28
35
  return ag.tracing
29
36
 
37
+ @debug()
30
38
  def log_pre_api_call(self, model, messages, kwargs):
31
39
  call_type = kwargs.get("call_type")
32
40
  span_kind = (
33
41
  "llm" if call_type in ["completion", "acompletion"] else "embedding"
34
42
  )
35
43
 
36
- ag.tracing.start_span(
44
+ self.span = ag.tracing.open_span(
37
45
  name=f"{span_kind}_call",
38
46
  input={"messages": kwargs["messages"]},
39
47
  spankind=span_kind,
48
+ active=False,
40
49
  )
41
50
  ag.tracing.set_attributes(
42
51
  {
@@ -49,9 +58,10 @@ def litellm_handler():
49
58
  }
50
59
  )
51
60
 
61
+ @debug()
52
62
  def log_stream_event(self, kwargs, response_obj, start_time, end_time):
53
- ag.tracing.set_status(status="OK")
54
- ag.tracing.end_span(
63
+ ag.tracing.set_status(status="OK", span_id=self.span.id)
64
+ ag.tracing.store_outputs(
55
65
  outputs={
56
66
  "message": kwargs.get(
57
67
  "complete_streaming_response"
@@ -65,13 +75,16 @@ def litellm_handler():
65
75
  "response_cost"
66
76
  ), # litellm calculates response cost
67
77
  },
78
+ span_id=self.span.id,
68
79
  )
80
+ ag.tracing.close_span(span_id=self.span.id)
69
81
 
82
+ @debug()
70
83
  def log_success_event(
71
84
  self, kwargs, response_obj: ModelResponse, start_time, end_time
72
85
  ):
73
- ag.tracing.set_status(status="OK")
74
- ag.tracing.end_span(
86
+ ag.tracing.set_status(status="OK", span_id=self.span.id)
87
+ ag.tracing.store_outputs(
75
88
  outputs={
76
89
  "message": response_obj.choices[0].message.content,
77
90
  "usage": (
@@ -83,12 +96,15 @@ def litellm_handler():
83
96
  "response_cost"
84
97
  ), # litellm calculates response cost
85
98
  },
99
+ span_id=self.span.id,
86
100
  )
101
+ ag.tracing.close_span(span_id=self.span.id)
87
102
 
103
+ @debug()
88
104
  def log_failure_event(
89
105
  self, kwargs, response_obj: ModelResponse, start_time, end_time
90
106
  ):
91
- ag.tracing.set_status(status="ERROR")
107
+ ag.tracing.set_status(status="ERROR", span_id=self.span.id)
92
108
  ag.tracing.set_attributes(
93
109
  {
94
110
  "traceback_exception": repr(
@@ -98,10 +114,11 @@ def litellm_handler():
98
114
  "end_time"
99
115
  ], # datetime object of when call was completed
100
116
  },
117
+ span_id=self.span.id,
101
118
  )
102
- ag.tracing.end_span(
119
+ ag.tracing.store_outputs(
103
120
  outputs={
104
- "message": kwargs["exception"], # the Exception raised
121
+ "message": repr(kwargs["exception"]), # the Exception raised
105
122
  "usage": (
106
123
  response_obj.usage.dict()
107
124
  if hasattr(response_obj, "usage")
@@ -111,13 +128,16 @@ def litellm_handler():
111
128
  "response_cost"
112
129
  ), # litellm calculates response cost
113
130
  },
131
+ span_id=self.span.id,
114
132
  )
133
+ ag.tracing.close_span(span_id=self.span.id)
115
134
 
135
+ @debug()
116
136
  async def async_log_stream_event(
117
137
  self, kwargs, response_obj, start_time, end_time
118
138
  ):
119
- ag.tracing.set_status(status="OK")
120
- ag.tracing.end_span(
139
+ ag.tracing.set_status(status="OK", span_id=self.span.id)
140
+ ag.tracing.store_outputs(
121
141
  outputs={
122
142
  "message": kwargs.get(
123
143
  "complete_streaming_response"
@@ -131,13 +151,16 @@ def litellm_handler():
131
151
  "response_cost"
132
152
  ), # litellm calculates response cost
133
153
  },
154
+ span_id=self.span.id,
134
155
  )
156
+ ag.tracing.close_span(span_id=self.span.id)
135
157
 
158
+ @debug()
136
159
  async def async_log_success_event(
137
160
  self, kwargs, response_obj, start_time, end_time
138
161
  ):
139
- ag.tracing.set_status(status="OK")
140
- ag.tracing.end_span(
162
+ ag.tracing.set_status(status="OK", span_id=self.span.id)
163
+ ag.tracing.store_outputs(
141
164
  outputs={
142
165
  "message": response_obj.choices[0].message.content,
143
166
  "usage": (
@@ -149,12 +172,15 @@ def litellm_handler():
149
172
  "response_cost"
150
173
  ), # litellm calculates response cost
151
174
  },
175
+ span_id=self.span.id,
152
176
  )
177
+ ag.tracing.close_span(span_id=self.span.id)
153
178
 
179
+ @debug()
154
180
  async def async_log_failure_event(
155
181
  self, kwargs, response_obj, start_time, end_time
156
182
  ):
157
- ag.tracing.set_status(status="ERROR")
183
+ ag.tracing.set_status(status="ERROR", span_id=self.span.id)
158
184
  ag.tracing.set_attributes(
159
185
  {
160
186
  "traceback_exception": kwargs[
@@ -164,8 +190,9 @@ def litellm_handler():
164
190
  "end_time"
165
191
  ], # datetime object of when call was completed
166
192
  },
193
+ span_id=self.span.id,
167
194
  )
168
- ag.tracing.end_span(
195
+ ag.tracing.store_outputs(
169
196
  outputs={
170
197
  "message": repr(kwargs["exception"]), # the Exception raised
171
198
  "usage": (
@@ -177,6 +204,8 @@ def litellm_handler():
177
204
  "response_cost"
178
205
  ), # litellm calculates response cost
179
206
  },
207
+ span_id=self.span.id,
180
208
  )
209
+ ag.tracing.close_span(span_id=self.span.id)
181
210
 
182
211
  return LitellmHandler()