deepeval 3.5.1__py3-none-any.whl → 3.5.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39)
  1. deepeval/_version.py +1 -1
  2. deepeval/config/settings.py +94 -2
  3. deepeval/config/utils.py +54 -1
  4. deepeval/constants.py +27 -0
  5. deepeval/integrations/langchain/__init__.py +2 -3
  6. deepeval/integrations/langchain/callback.py +126 -301
  7. deepeval/integrations/langchain/patch.py +24 -13
  8. deepeval/integrations/langchain/utils.py +203 -1
  9. deepeval/integrations/pydantic_ai/patcher.py +220 -185
  10. deepeval/integrations/pydantic_ai/utils.py +86 -0
  11. deepeval/metrics/conversational_g_eval/conversational_g_eval.py +1 -0
  12. deepeval/metrics/pii_leakage/pii_leakage.py +1 -1
  13. deepeval/models/embedding_models/azure_embedding_model.py +40 -9
  14. deepeval/models/embedding_models/local_embedding_model.py +54 -11
  15. deepeval/models/embedding_models/ollama_embedding_model.py +25 -7
  16. deepeval/models/embedding_models/openai_embedding_model.py +47 -5
  17. deepeval/models/llms/amazon_bedrock_model.py +31 -4
  18. deepeval/models/llms/anthropic_model.py +39 -13
  19. deepeval/models/llms/azure_model.py +37 -38
  20. deepeval/models/llms/deepseek_model.py +36 -7
  21. deepeval/models/llms/gemini_model.py +10 -0
  22. deepeval/models/llms/grok_model.py +50 -3
  23. deepeval/models/llms/kimi_model.py +37 -7
  24. deepeval/models/llms/local_model.py +38 -12
  25. deepeval/models/llms/ollama_model.py +15 -3
  26. deepeval/models/llms/openai_model.py +37 -44
  27. deepeval/models/mlllms/gemini_model.py +21 -3
  28. deepeval/models/mlllms/ollama_model.py +38 -13
  29. deepeval/models/mlllms/openai_model.py +18 -42
  30. deepeval/models/retry_policy.py +548 -64
  31. deepeval/prompt/api.py +13 -9
  32. deepeval/prompt/prompt.py +19 -9
  33. deepeval/tracing/tracing.py +87 -0
  34. deepeval/utils.py +12 -0
  35. {deepeval-3.5.1.dist-info → deepeval-3.5.3.dist-info}/METADATA +1 -1
  36. {deepeval-3.5.1.dist-info → deepeval-3.5.3.dist-info}/RECORD +39 -38
  37. {deepeval-3.5.1.dist-info → deepeval-3.5.3.dist-info}/LICENSE.md +0 -0
  38. {deepeval-3.5.1.dist-info → deepeval-3.5.3.dist-info}/WHEEL +0 -0
  39. {deepeval-3.5.1.dist-info → deepeval-3.5.3.dist-info}/entry_points.txt +0 -0
@@ -1,15 +1,12 @@
1
1
  from typing import Any, Optional, List, Dict
2
2
  from uuid import UUID
3
3
  from time import perf_counter
4
+ from deepeval.tracing.context import current_trace_context
4
5
  from deepeval.tracing.types import (
5
6
  LlmOutput,
6
7
  LlmToolCall,
7
- TraceAttributes,
8
8
  )
9
- from deepeval.metrics import BaseMetric, TaskCompletionMetric
10
- from deepeval.test_case import LLMTestCase
11
- from deepeval.test_run import global_test_run_manager
12
- import uuid
9
+ from deepeval.metrics import BaseMetric
13
10
 
14
11
  try:
15
12
  from langchain_core.callbacks.base import BaseCallbackHandler
@@ -20,11 +17,13 @@ try:
20
17
  # contains langchain imports
21
18
  from deepeval.integrations.langchain.utils import (
22
19
  parse_prompts_to_messages,
23
- prepare_dict,
24
20
  extract_name,
25
21
  safe_extract_model_name,
26
22
  safe_extract_token_usage,
23
+ enter_current_context,
24
+ exit_current_context,
27
25
  )
26
+ from deepeval.integrations.langchain.patch import tool
28
27
 
29
28
  langchain_installed = True
30
29
  except:
@@ -38,13 +37,8 @@ def is_langchain_installed():
38
37
  )
39
38
 
40
39
 
41
- # ASSUMPTIONS:
42
- # cycle for a single invoke call
43
- # one trace per cycle
44
-
45
40
  from deepeval.tracing import trace_manager
46
41
  from deepeval.tracing.types import (
47
- BaseSpan,
48
42
  LlmSpan,
49
43
  RetrieverSpan,
50
44
  TraceSpanStatus,
@@ -55,135 +49,32 @@ from deepeval.telemetry import capture_tracing_integration
55
49
 
56
50
  class CallbackHandler(BaseCallbackHandler):
57
51
 
58
- active_trace_id: Optional[str] = None
59
- metrics: List[BaseMetric] = []
60
- metric_collection: Optional[str] = None
61
-
62
52
  def __init__(
63
53
  self,
64
- metrics: List[BaseMetric] = [],
65
- metric_collection: Optional[str] = None,
66
54
  name: Optional[str] = None,
67
55
  tags: Optional[List[str]] = None,
68
56
  metadata: Optional[Dict[str, Any]] = None,
69
57
  thread_id: Optional[str] = None,
70
58
  user_id: Optional[str] = None,
59
+ metrics: Optional[List[BaseMetric]] = None,
60
+ metric_collection: Optional[str] = None,
71
61
  ):
72
62
  is_langchain_installed()
73
63
  with capture_tracing_integration("langchain.callback.CallbackHandler"):
64
+ trace = trace_manager.start_new_trace()
65
+
66
+ self.trace_uuid = trace.uuid
67
+
68
+ trace.name = name
69
+ trace.tags = tags
70
+ trace.metadata = metadata
71
+ trace.thread_id = thread_id
72
+ trace.user_id = user_id
74
73
  self.metrics = metrics
75
74
  self.metric_collection = metric_collection
76
- self.trace_attributes = TraceAttributes(
77
- name=name,
78
- tags=tags,
79
- metadata=metadata,
80
- thread_id=thread_id,
81
- user_id=user_id,
82
- )
75
+ current_trace_context.set(trace)
83
76
  super().__init__()
84
77
 
85
- def on_llm_new_token(
86
- self,
87
- token: str,
88
- *,
89
- chunk,
90
- run_id: UUID,
91
- parent_run_id: Optional[UUID] = None,
92
- tags: Optional[list[str]] = None,
93
- **kwargs: Any,
94
- ):
95
- llm_span: Optional[LlmSpan] = trace_manager.get_span_by_uuid(
96
- str(run_id)
97
- )
98
- if llm_span is None:
99
- return
100
- if llm_span.token_intervals is None:
101
- llm_span.token_intervals = {perf_counter(): token}
102
- else:
103
- llm_span.token_intervals[perf_counter()] = token
104
-
105
- def check_active_trace_id(self):
106
- if self.active_trace_id is None:
107
- self.active_trace_id = trace_manager.start_new_trace().uuid
108
-
109
- def add_span_to_trace(self, span: BaseSpan):
110
- trace_manager.add_span(span)
111
- trace_manager.add_span_to_trace(span)
112
-
113
- def end_span(self, span: BaseSpan):
114
- span.end_time = perf_counter()
115
- span.status = TraceSpanStatus.SUCCESS
116
- trace_manager.remove_span(str(span.uuid))
117
-
118
- ######## Conditions to add metric_collection to span ########
119
- if (
120
- self.metric_collection and span.parent_uuid is None
121
- ): # if span is a root span
122
- span.metric_collection = self.metric_collection
123
-
124
- ######## Conditions to add metrics to span ########
125
- if self.metrics and span.parent_uuid is None: # if span is a root span
126
-
127
- # prepare test_case for task_completion metric
128
- for metric in self.metrics:
129
- if isinstance(metric, TaskCompletionMetric):
130
- self.prepare_span_metric_test_case(metric, span)
131
-
132
- def end_trace(self, span: BaseSpan):
133
- current_trace = trace_manager.get_trace_by_uuid(self.active_trace_id)
134
-
135
- ######## Conditions send the trace for evaluation ########
136
- if self.metrics:
137
- trace_manager.evaluating = (
138
- True # to avoid posting the trace to the server
139
- )
140
- trace_manager.evaluation_loop = (
141
- True # to avoid traces being evaluated twice
142
- )
143
- trace_manager.integration_traces_to_evaluate.append(current_trace)
144
-
145
- if current_trace is not None:
146
- current_trace.input = span.input
147
- current_trace.output = span.output
148
-
149
- # set trace attributes
150
- if self.trace_attributes:
151
- if self.trace_attributes.name:
152
- current_trace.name = self.trace_attributes.name
153
- if self.trace_attributes.tags:
154
- current_trace.tags = self.trace_attributes.tags
155
- if self.trace_attributes.metadata:
156
- current_trace.metadata = self.trace_attributes.metadata
157
- if self.trace_attributes.thread_id:
158
- current_trace.thread_id = self.trace_attributes.thread_id
159
- if self.trace_attributes.user_id:
160
- current_trace.user_id = self.trace_attributes.user_id
161
-
162
- trace_manager.end_trace(self.active_trace_id)
163
- self.active_trace_id = None
164
-
165
- def prepare_span_metric_test_case(
166
- self, metric: TaskCompletionMetric, span: BaseSpan
167
- ):
168
- task_completion_metric = TaskCompletionMetric(
169
- threshold=metric.threshold,
170
- model=metric.model,
171
- include_reason=metric.include_reason,
172
- async_mode=metric.async_mode,
173
- strict_mode=metric.strict_mode,
174
- verbose_mode=metric.verbose_mode,
175
- )
176
- task_completion_metric.evaluation_cost = 0
177
- _llm_test_case = LLMTestCase(input="None", actual_output="None")
178
- _llm_test_case._trace_dict = trace_manager.create_nested_spans_dict(
179
- span
180
- )
181
- task, _ = task_completion_metric._extract_task_and_outcome(
182
- _llm_test_case
183
- )
184
- task_completion_metric.task = task
185
- span.metrics = [task_completion_metric]
186
-
187
78
  def on_chain_start(
188
79
  self,
189
80
  serialized: dict[str, Any],
@@ -195,43 +86,32 @@ class CallbackHandler(BaseCallbackHandler):
195
86
  metadata: Optional[dict[str, Any]] = None,
196
87
  **kwargs: Any,
197
88
  ) -> Any:
198
-
199
- self.check_active_trace_id()
200
- base_span = BaseSpan(
201
- uuid=str(run_id),
202
- status=TraceSpanStatus.ERRORED,
203
- children=[],
204
- trace_uuid=self.active_trace_id,
205
- parent_uuid=str(parent_run_id) if parent_run_id else None,
206
- start_time=perf_counter(),
207
- name=extract_name(serialized, **kwargs),
208
- input=inputs,
209
- metadata=prepare_dict(
210
- serialized=serialized, tags=tags, metadata=metadata, **kwargs
211
- ),
212
- # fallback for on_end callback
213
- end_time=perf_counter(),
214
- )
215
- self.add_span_to_trace(base_span)
89
+ if parent_run_id is None:
90
+ uuid_str = str(run_id)
91
+ base_span = enter_current_context(
92
+ uuid_str=uuid_str,
93
+ span_type="custom",
94
+ func_name=extract_name(serialized, **kwargs),
95
+ )
96
+ base_span.input = inputs
97
+ current_trace_context.get().input = inputs
98
+ base_span.metrics = self.metrics
99
+ base_span.metric_collection = self.metric_collection
216
100
 
217
101
  def on_chain_end(
218
102
  self,
219
- outputs: dict[str, Any],
103
+ output: Any,
220
104
  *,
221
105
  run_id: UUID,
222
106
  parent_run_id: Optional[UUID] = None,
223
- **kwargs: Any, # un-logged kwargs
107
+ **kwargs: Any,
224
108
  ) -> Any:
225
-
226
- base_span = trace_manager.get_span_by_uuid(str(run_id))
227
- if base_span is None:
228
- return
229
-
230
- base_span.output = outputs
231
- self.end_span(base_span)
232
-
233
- if parent_run_id is None:
234
- self.end_trace(base_span)
109
+ uuid_str = str(run_id)
110
+ base_span = trace_manager.get_span_by_uuid(uuid_str)
111
+ if base_span:
112
+ base_span.output = output
113
+ current_trace_context.get().output = output
114
+ exit_current_context(uuid_str=uuid_str)
235
115
 
236
116
  def on_llm_start(
237
117
  self,
@@ -244,36 +124,24 @@ class CallbackHandler(BaseCallbackHandler):
244
124
  metadata: Optional[dict[str, Any]] = None,
245
125
  **kwargs: Any,
246
126
  ) -> Any:
247
-
248
- self.check_active_trace_id()
249
-
250
- # extract input
127
+ uuid_str = str(run_id)
251
128
  input_messages = parse_prompts_to_messages(prompts, **kwargs)
252
-
253
- # extract model name
254
129
  model = safe_extract_model_name(metadata, **kwargs)
255
130
 
256
- llm_span = LlmSpan(
257
- uuid=str(run_id),
258
- status=TraceSpanStatus.ERRORED,
259
- children=[],
260
- trace_uuid=self.active_trace_id,
261
- parent_uuid=str(parent_run_id) if parent_run_id else None,
262
- start_time=perf_counter(),
263
- name=extract_name(serialized, **kwargs),
264
- input=input_messages,
265
- output="",
266
- metadata=prepare_dict(
267
- serialized=serialized, tags=tags, metadata=metadata, **kwargs
268
- ),
269
- model=model,
270
- # fallback for on_end callback
271
- end_time=perf_counter(),
272
- metric_collection=metadata.get("metric_collection", None),
273
- metrics=metadata.get("metrics", None),
131
+ llm_span: LlmSpan = enter_current_context(
132
+ uuid_str=uuid_str,
133
+ span_type="llm",
134
+ func_name=extract_name(serialized, **kwargs),
274
135
  )
275
136
 
276
- self.add_span_to_trace(llm_span)
137
+ llm_span.input = input_messages
138
+ llm_span.model = model
139
+ metrics = metadata.pop("metrics", None)
140
+ metric_collection = metadata.pop("metric_collection", None)
141
+ prompt = metadata.pop("prompt", None)
142
+ llm_span.metrics = metrics
143
+ llm_span.metric_collection = metric_collection
144
+ llm_span.prompt = prompt
277
145
 
278
146
  def on_llm_end(
279
147
  self,
@@ -283,12 +151,8 @@ class CallbackHandler(BaseCallbackHandler):
283
151
  parent_run_id: Optional[UUID] = None,
284
152
  **kwargs: Any, # un-logged kwargs
285
153
  ) -> Any:
286
- llm_span: LlmSpan = trace_manager.get_span_by_uuid(str(run_id))
287
- if llm_span is None:
288
- return
289
-
290
- if not isinstance(llm_span, LlmSpan):
291
- return
154
+ uuid_str = str(run_id)
155
+ llm_span: LlmSpan = trace_manager.get_span_by_uuid(uuid_str)
292
156
 
293
157
  output = ""
294
158
  total_input_tokens = 0
@@ -338,9 +202,38 @@ class CallbackHandler(BaseCallbackHandler):
338
202
  total_output_tokens if total_output_tokens > 0 else None
339
203
  )
340
204
 
341
- self.end_span(llm_span)
342
- if parent_run_id is None:
343
- self.end_trace(llm_span)
205
+ exit_current_context(uuid_str=uuid_str)
206
+
207
+ def on_llm_error(
208
+ self,
209
+ error: BaseException,
210
+ *,
211
+ run_id: UUID,
212
+ parent_run_id: Optional[UUID] = None,
213
+ **kwargs: Any,
214
+ ) -> Any:
215
+ uuid_str = str(run_id)
216
+ llm_span: LlmSpan = trace_manager.get_span_by_uuid(uuid_str)
217
+ llm_span.status = TraceSpanStatus.ERRORED
218
+ llm_span.error = str(error)
219
+ exit_current_context(uuid_str=uuid_str)
220
+
221
+ def on_llm_new_token(
222
+ self,
223
+ token: str,
224
+ *,
225
+ chunk,
226
+ run_id: UUID,
227
+ parent_run_id: Optional[UUID] = None,
228
+ tags: Optional[list[str]] = None,
229
+ **kwargs: Any,
230
+ ):
231
+ uuid_str = str(run_id)
232
+ llm_span: LlmSpan = trace_manager.get_span_by_uuid(uuid_str)
233
+ if llm_span.token_intervals is None:
234
+ llm_span.token_intervals = {perf_counter(): token}
235
+ else:
236
+ llm_span.token_intervals[perf_counter()] = token
344
237
 
345
238
  def on_tool_start(
346
239
  self,
@@ -354,27 +247,16 @@ class CallbackHandler(BaseCallbackHandler):
354
247
  inputs: Optional[dict[str, Any]] = None,
355
248
  **kwargs: Any,
356
249
  ) -> Any:
357
-
358
- self.check_active_trace_id()
359
-
360
- tool_span = ToolSpan(
361
- uuid=str(run_id),
362
- status=TraceSpanStatus.ERRORED,
363
- children=[],
364
- trace_uuid=self.active_trace_id,
365
- parent_uuid=str(parent_run_id) if parent_run_id else None,
366
- start_time=perf_counter(),
367
- name=extract_name(serialized, **kwargs),
368
- input=input_str,
369
- metadata=prepare_dict(
370
- serialized=serialized, tags=tags, metadata=metadata, **kwargs
371
- ),
372
- # fallback for on_end callback
373
- end_time=perf_counter(),
374
- metric_collection=metadata.get("metric_collection", None),
375
- metrics=metadata.get("metrics", None),
250
+ uuid_str = str(run_id)
251
+
252
+ tool_span = enter_current_context(
253
+ uuid_str=uuid_str,
254
+ span_type="tool",
255
+ func_name=extract_name(
256
+ serialized, **kwargs
257
+ ), # ignored when setting the input
376
258
  )
377
- self.add_span_to_trace(tool_span)
259
+ tool_span.input = inputs
378
260
 
379
261
  def on_tool_end(
380
262
  self,
@@ -385,16 +267,24 @@ class CallbackHandler(BaseCallbackHandler):
385
267
  **kwargs: Any, # un-logged kwargs
386
268
  ) -> Any:
387
269
 
388
- tool_span = trace_manager.get_span_by_uuid(str(run_id))
389
- if tool_span is None:
390
- return
391
-
270
+ uuid_str = str(run_id)
271
+ tool_span: ToolSpan = trace_manager.get_span_by_uuid(uuid_str)
392
272
  tool_span.output = output
273
+ exit_current_context(uuid_str=uuid_str)
393
274
 
394
- self.end_span(tool_span)
395
-
396
- if parent_run_id is None:
397
- self.end_trace(tool_span)
275
+ def on_tool_error(
276
+ self,
277
+ error: BaseException,
278
+ *,
279
+ run_id: UUID,
280
+ parent_run_id: Optional[UUID] = None,
281
+ **kwargs: Any, # un-logged kwargs
282
+ ) -> Any:
283
+ uuid_str = str(run_id)
284
+ tool_span: ToolSpan = trace_manager.get_span_by_uuid(uuid_str)
285
+ tool_span.status = TraceSpanStatus.ERRORED
286
+ tool_span.error = str(error)
287
+ exit_current_context(uuid_str=uuid_str)
398
288
 
399
289
  def on_retriever_start(
400
290
  self,
@@ -407,28 +297,16 @@ class CallbackHandler(BaseCallbackHandler):
407
297
  metadata: Optional[dict[str, Any]] = None,
408
298
  **kwargs: Any, # un-logged kwargs
409
299
  ) -> Any:
410
-
411
- self.check_active_trace_id()
412
-
413
- retriever_span = RetrieverSpan(
414
- uuid=str(run_id),
415
- status=TraceSpanStatus.ERRORED,
416
- children=[],
417
- trace_uuid=self.active_trace_id,
418
- parent_uuid=str(parent_run_id) if parent_run_id else None,
419
- start_time=perf_counter(),
420
- name=extract_name(serialized, **kwargs),
421
- embedder=metadata.get("ls_embedding_provider", "unknown"),
422
- metadata=prepare_dict(
423
- serialized=serialized, tags=tags, metadata=metadata, **kwargs
424
- ),
425
- # fallback for on_end callback
426
- end_time=perf_counter(),
300
+ uuid_str = str(run_id)
301
+ retriever_span = enter_current_context(
302
+ uuid_str=uuid_str,
303
+ span_type="retriever",
304
+ func_name=extract_name(serialized, **kwargs),
305
+ observe_kwargs={
306
+ "embedder": metadata.get("ls_embedding_provider", "unknown"),
307
+ },
427
308
  )
428
309
  retriever_span.input = query
429
- retriever_span.retrieval_context = []
430
-
431
- self.add_span_to_trace(retriever_span)
432
310
 
433
311
  def on_retriever_end(
434
312
  self,
@@ -438,11 +316,8 @@ class CallbackHandler(BaseCallbackHandler):
438
316
  parent_run_id: Optional[UUID] = None,
439
317
  **kwargs: Any, # un-logged kwargs
440
318
  ) -> Any:
441
-
442
- retriever_span = trace_manager.get_span_by_uuid(str(run_id))
443
-
444
- if retriever_span is None:
445
- return
319
+ uuid_str = str(run_id)
320
+ retriever_span: RetrieverSpan = trace_manager.get_span_by_uuid(uuid_str)
446
321
 
447
322
  # prepare output
448
323
  output_list = []
@@ -452,58 +327,8 @@ class CallbackHandler(BaseCallbackHandler):
452
327
  else:
453
328
  output_list.append(str(output))
454
329
 
455
- retriever_span.input = retriever_span.input
456
- retriever_span.retrieval_context = output_list
457
-
458
- self.end_span(retriever_span)
459
-
460
- if parent_run_id is None:
461
- self.end_trace(retriever_span)
462
-
463
- ################## on_error callbacks ###############
464
-
465
- def on_chain_error(
466
- self,
467
- error: BaseException,
468
- *,
469
- run_id: UUID,
470
- parent_run_id: Optional[UUID] = None,
471
- **kwargs: Any,
472
- ) -> None:
473
- base_span = trace_manager.get_span_by_uuid(str(run_id))
474
- if base_span is None:
475
- return
476
-
477
- base_span.end_time = perf_counter()
478
-
479
- def on_llm_error(
480
- self,
481
- error: BaseException,
482
- *,
483
- run_id: UUID,
484
- parent_run_id: Optional[UUID] = None,
485
- **kwargs: Any,
486
- ) -> Any:
487
-
488
- llm_span = trace_manager.get_span_by_uuid(str(run_id))
489
- if llm_span is None:
490
- return
491
-
492
- llm_span.end_time = perf_counter()
493
-
494
- def on_tool_error(
495
- self,
496
- error: BaseException,
497
- *,
498
- run_id: UUID,
499
- parent_run_id: Optional[UUID] = None,
500
- **kwargs: Any,
501
- ) -> Any:
502
- tool_span = trace_manager.get_span_by_uuid(str(run_id))
503
- if tool_span is None:
504
- return
505
-
506
- tool_span.end_time = perf_counter()
330
+ retriever_span.output = output_list
331
+ exit_current_context(uuid_str=uuid_str)
507
332
 
508
333
  def on_retriever_error(
509
334
  self,
@@ -511,10 +336,10 @@ class CallbackHandler(BaseCallbackHandler):
511
336
  *,
512
337
  run_id: UUID,
513
338
  parent_run_id: Optional[UUID] = None,
514
- **kwargs: Any,
339
+ **kwargs: Any, # un-logged kwargs
515
340
  ) -> Any:
516
- retriever_span = trace_manager.get_span_by_uuid(str(run_id))
517
- if retriever_span is None:
518
- return
519
-
520
- retriever_span.end_time = perf_counter()
341
+ uuid_str = str(run_id)
342
+ retriever_span: RetrieverSpan = trace_manager.get_span_by_uuid(uuid_str)
343
+ retriever_span.status = TraceSpanStatus.ERRORED
344
+ retriever_span.error = str(error)
345
+ exit_current_context(uuid_str=uuid_str)
@@ -1,7 +1,8 @@
1
- from langchain_core.tools import tool as original_tool, BaseTool
1
+ import functools
2
2
  from deepeval.metrics import BaseMetric
3
- from typing import List, Optional, Callable, Any
4
- from functools import wraps
3
+ from deepeval.tracing.context import current_span_context
4
+ from typing import List, Optional, Callable
5
+ from langchain_core.tools import tool as original_tool, BaseTool
5
6
 
6
7
 
7
8
  def tool(
@@ -16,17 +17,27 @@ def tool(
16
17
 
17
18
  # original_tool returns a decorator function, so we need to return a decorator
18
19
  def decorator(func: Callable) -> BaseTool:
19
-
20
- # Apply the original tool decorator to get the BaseTool
20
+ func = _patch_tool_decorator(func, metrics, metric_collection)
21
21
  tool_instance = original_tool(*args, **kwargs)(func)
22
-
23
- if isinstance(tool_instance, BaseTool):
24
- if tool_instance.metadata is None:
25
- tool_instance.metadata = {}
26
-
27
- tool_instance.metadata["metric_collection"] = metric_collection
28
- tool_instance.metadata["metrics"] = metrics
29
-
30
22
  return tool_instance
31
23
 
32
24
  return decorator
25
+
26
+
27
+ def _patch_tool_decorator(
28
+ func: Callable,
29
+ metrics: Optional[List[BaseMetric]] = None,
30
+ metric_collection: Optional[str] = None,
31
+ ):
32
+ original_func = func
33
+
34
+ @functools.wraps(original_func)
35
+ def wrapper(*args, **kwargs):
36
+ current_span = current_span_context.get()
37
+ current_span.metrics = metrics
38
+ current_span.metric_collection = metric_collection
39
+ res = original_func(*args, **kwargs)
40
+ return res
41
+
42
+ tool = wrapper
43
+ return tool