deepeval 3.8.2__py3-none-any.whl → 3.8.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -17,8 +17,9 @@ def wrap_crew_kickoff():
             func_name="kickoff",
             metric_collection=metric_collection,
             metrics=metrics,
-        ):
+        ) as observer:
             result = original_kickoff(self, *args, **kwargs)
+            observer.result = str(result) if result else None
 
         return result
 
@@ -36,8 +37,9 @@ def wrap_crew_kickoff_for_each():
             func_name="kickoff_for_each",
             metric_collection=metric_collection,
             metrics=metrics,
-        ):
+        ) as observer:
             result = original_kickoff_for_each(self, *args, **kwargs)
+            observer.result = str(result) if result else None
 
         return result
 
@@ -55,8 +57,9 @@ def wrap_crew_kickoff_async():
             func_name="kickoff_async",
             metric_collection=metric_collection,
             metrics=metrics,
-        ):
+        ) as observer:
             result = await original_kickoff_async(self, *args, **kwargs)
+            observer.result = str(result) if result else None
 
         return result
 
@@ -74,33 +77,61 @@ def wrap_crew_kickoff_for_each_async():
             func_name="kickoff_for_each_async",
             metric_collection=metric_collection,
             metrics=metrics,
-        ):
+        ) as observer:
             result = await original_kickoff_for_each_async(
                 self, *args, **kwargs
             )
+            observer.result = str(result) if result else None
 
         return result
 
     Crew.kickoff_for_each_async = wrapper
 
 
-def wrap_llm_call():
-    original_llm_call = LLM.call
+def wrap_crew_akickoff():
+    if not hasattr(Crew, "akickoff"):
+        return
 
-    @wraps(original_llm_call)
-    def wrapper(self, *args, **kwargs):
+    original_akickoff = Crew.akickoff
+
+    @wraps(original_akickoff)
+    async def wrapper(self, *args, **kwargs):
+        metric_collection, metrics = _check_metrics_and_metric_collection(self)
+        with Observer(
+            span_type="crew",
+            func_name="akickoff",
+            metric_collection=metric_collection,
+            metrics=metrics,
+        ) as observer:
+            result = await original_akickoff(self, *args, **kwargs)
+            observer.result = str(result) if result else None
+
+        return result
+
+    Crew.akickoff = wrapper
+
+
+def wrap_crew_akickoff_for_each():
+    if not hasattr(Crew, "akickoff_for_each"):
+        return
+
+    original_akickoff_for_each = Crew.akickoff_for_each
+
+    @wraps(original_akickoff_for_each)
+    async def wrapper(self, *args, **kwargs):
         metric_collection, metrics = _check_metrics_and_metric_collection(self)
         with Observer(
-            span_type="llm",
-            func_name="call",
-            observe_kwargs={"model": "temp_model"},
+            span_type="crew",
+            func_name="akickoff_for_each",
             metric_collection=metric_collection,
             metrics=metrics,
-        ):
-            result = original_llm_call(self, *args, **kwargs)
+        ) as observer:
+            result = await original_akickoff_for_each(self, *args, **kwargs)
+            observer.result = str(result) if result else None
+
         return result
 
-    LLM.call = wrapper
+    Crew.akickoff_for_each = wrapper
 
 
 def wrap_agent_execute_task():
@@ -114,13 +145,36 @@ def wrap_agent_execute_task():
             func_name="execute_task",
             metric_collection=metric_collection,
             metrics=metrics,
-        ):
+        ) as observer:
             result = original_execute_task(self, *args, **kwargs)
+            observer.result = str(result) if result else None
         return result
 
     Agent.execute_task = wrapper
 
 
+def wrap_agent_aexecute_task():
+    if not hasattr(Agent, "aexecute_task"):
+        return
+
+    original_aexecute_task = Agent.aexecute_task
+
+    @wraps(original_aexecute_task)
+    async def wrapper(self, *args, **kwargs):
+        metric_collection, metrics = _check_metrics_and_metric_collection(self)
+        with Observer(
+            span_type="agent",
+            func_name="aexecute_task",
+            metric_collection=metric_collection,
+            metrics=metrics,
+        ) as observer:
+            result = await original_aexecute_task(self, *args, **kwargs)
+            observer.result = str(result) if result else None
+        return result
+
+    Agent.aexecute_task = wrapper
+
+
 def _check_metrics_and_metric_collection(obj: Any):
     metric_collection = getattr(obj, "_metric_collection", None)
     metrics = getattr(obj, "_metrics", None)
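The CrewAI patches above all follow the same shape: keep a reference to the original method, re-bind it behind a @wraps wrapper that opens a deepeval Observer span, and record the stringified return value on the observer before returning it unchanged. A minimal standalone sketch of that pattern follows; the Target class is made up for illustration, and the Observer import path is an assumption, while the Observer arguments themselves come from the diff:

    from functools import wraps

    from deepeval.tracing import Observer  # import path assumed


    class Target:
        def run(self) -> str:
            return "done"


    def wrap_target_run():
        # Same hasattr guard the diff adds for Crew.akickoff / Agent.aexecute_task:
        # skip patching if this version of the library lacks the method.
        if not hasattr(Target, "run"):
            return

        original_run = Target.run

        @wraps(original_run)
        def wrapper(self, *args, **kwargs):
            with Observer(span_type="agent", func_name="run") as observer:
                result = original_run(self, *args, **kwargs)
                # New in 3.8.4: the span now carries the call's output.
                observer.result = str(result) if result else None
            return result

        Target.run = wrapper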
@@ -84,6 +84,7 @@ class CallbackHandler(BaseCallbackHandler):
         user_id: Optional[str] = None,
         metrics: Optional[List[BaseMetric]] = None,
         metric_collection: Optional[str] = None,
+        test_case_id: Optional[str] = None,
     ):
         is_langchain_installed()
         with capture_tracing_integration("langchain.callback.CallbackHandler"):
@@ -108,6 +109,7 @@ class CallbackHandler(BaseCallbackHandler):
                 "metadata": metadata,
                 "thread_id": thread_id,
                 "user_id": user_id,
+                "test_case_id": test_case_id,
             }
             self._trace_init_fields: Dict[str, Any] = dict(
                 self._original_init_fields
@@ -200,6 +202,8 @@ class CallbackHandler(BaseCallbackHandler):
             trace.thread_id = fields["thread_id"]
         if fields.get("user_id") is not None:
             trace.user_id = fields["user_id"]
+        if fields.get("test_case_id") is not None:
+            trace.test_case_id = fields["test_case_id"]
         # prevent re-applying on every callback
         self._trace_init_fields = {}
 
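The CallbackHandler change threads a new optional test_case_id from the constructor into the trace it creates. A hedged usage sketch; the deepeval import path is an assumption, while passing handlers through config={"callbacks": [...]} is LangChain's standard callback mechanism:

    from deepeval.integrations.langchain import CallbackHandler  # path assumed

    handler = CallbackHandler(
        metric_collection="my-collection",   # existing constructor field
        test_case_id="test-case-123",        # new in 3.8.4: attached to the trace
    )

    # Any LangChain runnable accepts callbacks via its config dict, e.g.:
    # result = chain.invoke({"input": "..."}, config={"callbacks": [handler]})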
@@ -21,6 +21,7 @@ from deepeval.tracing.types import (
 from deepeval.tracing.trace_context import (
     current_llm_context,
     current_agent_context,
+    current_trace_context,
 )
 from deepeval.test_case import ToolCall
 from deepeval.tracing.utils import make_json_serializable
@@ -40,7 +41,10 @@ try:
         LLMChatStartEvent,
         LLMChatEndEvent,
     )
-    from llama_index_instrumentation.dispatcher import Dispatcher
+    from llama_index.core.instrumentation import Dispatcher
+    from llama_index.core.instrumentation.events.retrieval import (
+        RetrievalEndEvent,
+    )
     from deepeval.integrations.llama_index.utils import (
         parse_id,
         prepare_input_llm_test_case_params,
@@ -82,15 +86,23 @@ class LLamaIndexHandler(BaseEventHandler, BaseSpanHandler):
                 input_messages.append({"role": role, "content": content})
 
             llm_span_context = current_llm_context.get()
-            # create the span
+
+            parent_span = trace_manager.get_span_by_uuid(event.span_id)
+            if parent_span:
+                trace_uuid = parent_span.trace_uuid
+            else:
+                current_trace = current_trace_context.get()
+                if current_trace:
+                    trace_uuid = current_trace.uuid
+                else:
+                    trace_uuid = trace_manager.start_new_trace().uuid
+
             llm_span = LlmSpan(
                 name="ConfidentLLMSpan",
                 uuid=str(uuid.uuid4()),
                 status=TraceSpanStatus.IN_PROGRESS,
                 children=[],
-                trace_uuid=trace_manager.get_span_by_uuid(
-                    event.span_id
-                ).trace_uuid,
+                trace_uuid=trace_uuid,
                 parent_uuid=event.span_id,
                 start_time=perf_counter(),
                 model=getattr(event, "model_dict", {}).get(
@@ -128,6 +140,13 @@ class LLamaIndexHandler(BaseEventHandler, BaseSpanHandler):
                 trace_manager.remove_span(llm_span.uuid)
                 del self.open_ai_astream_to_llm_span_map[event.span_id]
 
+        if isinstance(event, RetrievalEndEvent):
+            span = trace_manager.get_span_by_uuid(event.span_id)
+            if span:
+                span.retrieval_context = [
+                    node.node.get_content() for node in event.nodes
+                ]
+
     def new_span(
         self,
         id_: str,
@@ -139,18 +158,30 @@ class LLamaIndexHandler(BaseEventHandler, BaseSpanHandler):
     ) -> Optional[LlamaIndexBaseSpan]:
         class_name, method_name = parse_id(id_)
 
-        # check if it is a root span
-        if parent_span_id is None:
-            trace_uuid = trace_manager.start_new_trace().uuid
-        elif class_name == "Workflow" and method_name == "run":
-            trace_uuid = trace_manager.start_new_trace().uuid
-            parent_span_id = None  # since workflow is the root span, we need to set the parent span id to None
+        current_trace = current_trace_context.get()
+        trace_uuid = None
+
+        if parent_span_id is None or (
+            class_name == "Workflow" and method_name == "run"
+        ):
+            if current_trace:
+                trace_uuid = current_trace.uuid
+            else:
+                trace_uuid = trace_manager.start_new_trace().uuid
+
+            if class_name == "Workflow" and method_name == "run":
+                parent_span_id = None
+
         elif trace_manager.get_span_by_uuid(parent_span_id):
             trace_uuid = trace_manager.get_span_by_uuid(
                 parent_span_id
             ).trace_uuid
+
         else:
-            trace_uuid = trace_manager.start_new_trace().uuid
+            if current_trace:
+                trace_uuid = current_trace.uuid
+            else:
+                trace_uuid = trace_manager.start_new_trace().uuid
 
         self.root_span_trace_id_map[id_] = trace_uuid
 
@@ -195,7 +226,7 @@ class LLamaIndexHandler(BaseEventHandler, BaseSpanHandler):
                     else None
                 ),
             )
-        elif method_name == "acall":
+        elif method_name in ["acall", "call_tool", "acall_tool"]:
             span = ToolSpan(
                 uuid=id_,
                 status=TraceSpanStatus.IN_PROGRESS,
@@ -206,7 +237,7 @@ class LLamaIndexHandler(BaseEventHandler, BaseSpanHandler):
                 input=bound_args.arguments,
                 name="Tool",
             )
-        # prepare input test case params for the span
+
         prepare_input_llm_test_case_params(
             class_name, method_name, span, bound_args.arguments
         )
@@ -215,6 +246,22 @@ class LLamaIndexHandler(BaseEventHandler, BaseSpanHandler):
 
         return span
 
+    def _get_output_value(self, result: Any) -> Any:
+        """Helper to ensure AgentChatResponse and similar objects are serialized as dicts."""
+        if hasattr(result, "response") and hasattr(result, "sources"):
+            if hasattr(result, "model_dump"):
+                return result.model_dump()
+            if hasattr(result, "to_dict"):
+                return result.to_dict()
+            return {"response": result.response, "sources": result.sources}
+
+        if hasattr(result, "response"):
+            if hasattr(result, "model_dump"):
+                return result.model_dump()
+            return {"response": result.response}
+
+        return result
+
     def prepare_to_exit_span(
         self,
         id_: str,
@@ -229,7 +276,8 @@ class LLamaIndexHandler(BaseEventHandler, BaseSpanHandler):
             return None
 
         class_name, method_name = parse_id(id_)
-        if method_name == "call_tool":
+
+        if method_name in ["call_tool", "acall_tool"]:
             output_json = make_json_serializable(result)
             if output_json and isinstance(output_json, dict):
                 if base_span.tools_called is None:
@@ -243,7 +291,7 @@ class LLamaIndexHandler(BaseEventHandler, BaseSpanHandler):
                 )
         base_span.end_time = perf_counter()
         base_span.status = TraceSpanStatus.SUCCESS
-        base_span.output = result
+        base_span.output = self._get_output_value(result)
 
         if isinstance(base_span, ToolSpan):
             result_json = make_json_serializable(result)
@@ -265,7 +313,8 @@ class LLamaIndexHandler(BaseEventHandler, BaseSpanHandler):
 
         if base_span.parent_uuid is None:
             trace_manager.end_trace(base_span.trace_uuid)
-            self.root_span_trace_id_map.pop(base_span.uuid)
+            if base_span.uuid in self.root_span_trace_id_map:
+                self.root_span_trace_id_map.pop(base_span.uuid)
 
         return base_span
 
@@ -282,13 +331,12 @@ class LLamaIndexHandler(BaseEventHandler, BaseSpanHandler):
             return None
 
         base_span.end_time = perf_counter()
-        base_span.status = (
-            TraceSpanStatus.SUCCESS
-        )  # find a way to add error and handle the span without the parent id
+        base_span.status = TraceSpanStatus.SUCCESS
 
         if base_span.parent_uuid is None:
             trace_manager.end_trace(base_span.trace_uuid)
-            self.root_span_trace_id_map.pop(base_span.uuid)
+            if base_span.uuid in self.root_span_trace_id_map:
+                self.root_span_trace_id_map.pop(base_span.uuid)
 
         return base_span
342
 
@@ -36,7 +36,10 @@ try:
36
36
  SpanProcessor as _SpanProcessor,
37
37
  TracerProvider,
38
38
  )
39
- from opentelemetry.sdk.trace.export import BatchSpanProcessor
39
+ from opentelemetry.sdk.trace.export import (
40
+ BatchSpanProcessor,
41
+ SimpleSpanProcessor,
42
+ )
40
43
  from opentelemetry.exporter.otlp.proto.http.trace_exporter import (
41
44
  OTLPSpanExporter,
42
45
  )
@@ -172,7 +175,9 @@ class ConfidentInstrumentationSettings(InstrumentationSettings):
172
175
  trace_provider.add_span_processor(span_interceptor)
173
176
 
174
177
  if is_test_mode:
175
- trace_provider.add_span_processor(BatchSpanProcessor(test_exporter))
178
+ trace_provider.add_span_processor(
179
+ SimpleSpanProcessor(ConfidentSpanExporter())
180
+ )
176
181
  else:
177
182
  trace_provider.add_span_processor(
178
183
  BatchSpanProcessor(
@@ -345,7 +350,6 @@ class SpanInterceptor(SpanProcessor):
345
350
  trace.status = TraceSpanStatus.SUCCESS
346
351
  trace.end_time = perf_counter()
347
352
  trace_manager.traces_to_evaluate.append(trace)
348
- test_exporter.clear_span_json_list()
349
353
 
350
354
  def _add_agent_span(self, span, name):
351
355
  span.set_attribute("confident.span.type", "agent")
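In test mode the settings now register a SimpleSpanProcessor, which exports each span synchronously as it ends, instead of a BatchSpanProcessor, which queues spans and flushes them on a background thread. The trade-off in plain OpenTelemetry SDK terms (the ConsoleSpanExporter here is only for illustration; the diff wires in ConfidentSpanExporter):

    from opentelemetry.sdk.trace import TracerProvider
    from opentelemetry.sdk.trace.export import (
        BatchSpanProcessor,
        ConsoleSpanExporter,
        SimpleSpanProcessor,
    )

    provider = TracerProvider()

    in_test_mode = True
    if in_test_mode:
        # Synchronous export: spans are flushed immediately, so test
        # assertions see them deterministically.
        provider.add_span_processor(SimpleSpanProcessor(ConsoleSpanExporter()))
    else:
        # Batched export: lower overhead in production, flushed asynchronously.
        provider.add_span_processor(BatchSpanProcessor(ConsoleSpanExporter()))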
@@ -7,6 +7,7 @@ from .api import (
     ReasoningEffort,
     OutputType,
     PromptInterpolationType,
+    Tool,
 )
 
 __all__ = [
@@ -18,4 +19,5 @@ __all__ = [
     "ReasoningEffort",
     "OutputType",
     "PromptInterpolationType",
+    "Tool",
 ]
deepeval/prompt/api.py CHANGED
@@ -1,6 +1,14 @@
-from pydantic import BaseModel, Field, AliasChoices, ConfigDict
+from pydantic import (
+    BaseModel,
+    Field,
+    AliasChoices,
+    ConfigDict,
+    model_validator,
+    model_serializer,
+)
 from enum import Enum
-from typing import List, Optional
+import uuid
+from typing import List, Optional, Dict, Any, Union, Type
 from pydantic import TypeAdapter
 
 from deepeval.utils import make_model_config
@@ -33,6 +41,12 @@ class ModelProvider(Enum):
     OPENROUTER = "OPENROUTER"
 
 
+class ToolMode(str, Enum):
+    ALLOW_ADDITIONAL = "ALLOW_ADDITIONAL"
+    NO_ADDITIONAL = "NO_ADDITIONAL"
+    STRICT = "STRICT"
+
+
 class ModelSettings(BaseModel):
     provider: Optional[ModelProvider] = None
     name: Optional[str] = None
@@ -100,6 +114,7 @@ class OutputSchemaField(BaseModel):
     id: str
     type: SchemaDataType
     name: str
+    description: Optional[str] = None
     required: Optional[bool] = False
     parent_id: Optional[str] = Field(
         default=None,
@@ -109,8 +124,36 @@ class OutputSchemaField(BaseModel):
 
 
 class OutputSchema(BaseModel):
+    id: Optional[str] = None
     fields: Optional[List[OutputSchemaField]] = None
+    name: Optional[str] = None
+
+
+class Tool(BaseModel):
+    id: str = Field(default_factory=lambda: str(uuid.uuid4()))
     name: str
+    description: Optional[str] = None
+    mode: ToolMode
+    structured_schema: Optional[Union[Type[BaseModel], OutputSchema]] = Field(
+        serialization_alias="structuredSchema",
+        validation_alias=AliasChoices("structured_schema", "structuredSchema"),
+    )
+
+    @model_validator(mode="after")
+    def update_schema(self):
+        if not isinstance(self.structured_schema, OutputSchema):
+            from deepeval.prompt.utils import construct_output_schema
+
+            self.structured_schema = construct_output_schema(
+                self.structured_schema
+            )
+        return self
+
+    @property
+    def input_schema(self) -> Dict[str, Any]:
+        from deepeval.prompt.utils import output_schema_to_json_schema
+
+        return output_schema_to_json_schema(self.structured_schema)
 
 
 ###################################
@@ -186,6 +229,7 @@ class PromptHttpResponse(BaseModel):
         serialization_alias="outputSchema",
         validation_alias=AliasChoices("output_schema", "outputSchema"),
     )
+    tools: Optional[List[Tool]] = None
 
 
 class PromptPushRequest(BaseModel):
@@ -196,6 +240,7 @@ class PromptPushRequest(BaseModel):
     alias: str
     text: Optional[str] = None
     messages: Optional[List[PromptMessage]] = None
+    tools: Optional[List[Tool]] = None
     interpolation_type: PromptInterpolationType = Field(
         serialization_alias="interpolationType"
     )
@@ -215,6 +260,7 @@ class PromptUpdateRequest(BaseModel):
 
     text: Optional[str] = None
     messages: Optional[List[PromptMessage]] = None
+    tools: Optional[List[Tool]] = None
     interpolation_type: PromptInterpolationType = Field(
         serialization_alias="interpolationType"
    )
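Based on the model definitions above, a Tool can be declared either with an explicit OutputSchema or with a plain Pydantic model class, which the update_schema validator converts through construct_output_schema. A hedged sketch, assuming Tool and ToolMode are imported from this module; GetWeatherArgs is a made-up schema for the example:

    from pydantic import BaseModel

    from deepeval.prompt.api import Tool, ToolMode


    class GetWeatherArgs(BaseModel):
        city: str
        unit: str = "celsius"


    weather_tool = Tool(
        name="get_weather",
        description="Look up the current weather for a city.",
        mode=ToolMode.STRICT,
        # A BaseModel subclass is accepted and converted to an OutputSchema
        # by the model validator shown in the diff.
        structured_schema=GetWeatherArgs,
    )

    # input_schema renders the stored OutputSchema as a JSON-schema-style dict:
    # print(weather_tool.input_schema)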
deepeval/prompt/prompt.py CHANGED
@@ -25,6 +25,7 @@ from deepeval.prompt.api import (
     ModelSettings,
     OutputSchema,
     OutputType,
+    Tool,
 )
 from deepeval.prompt.utils import (
     interpolate_text,
@@ -101,6 +102,7 @@ class CachedPrompt(BaseModel):
     model_settings: Optional[ModelSettings]
     output_type: Optional[OutputType]
     output_schema: Optional[OutputSchema]
+    tools: Optional[List[Tool]] = None
 
 
 class Prompt:
@@ -131,6 +133,7 @@ class Prompt:
             interpolation_type or PromptInterpolationType.FSTRING
         )
         self.confident_api_key = confident_api_key
+        self.tools: Optional[List[Tool]] = None
 
         self._version = None
         self._prompt_version_id: Optional[str] = None
@@ -308,6 +311,7 @@ class Prompt:
         model_settings: Optional[ModelSettings] = None,
         output_type: Optional[OutputType] = None,
         output_schema: Optional[OutputSchema] = None,
+        tools: Optional[List[Tool]] = None,
     ):
         if portalocker is None or not self.alias:
             return
@@ -354,6 +358,7 @@ class Prompt:
             "model_settings": model_settings,
             "output_type": output_type,
             "output_schema": output_schema,
+            "tools": tools,
         }
 
         if cache_key == VERSION_CACHE_KEY:
@@ -415,6 +420,7 @@ class Prompt:
                 self.output_schema = construct_base_model(
                     cached_prompt.output_schema
                 )
+                self.tools = cached_prompt.tools
 
         end_time = time.perf_counter()
         time_taken = format(end_time - start_time, ".2f")
@@ -494,6 +500,7 @@ class Prompt:
                     self.output_schema = construct_base_model(
                         cached_prompt.output_schema
                     )
+                    self.tools = cached_prompt.tools
                     return
             except Exception:
                 pass
@@ -547,6 +554,7 @@ class Prompt:
                 model_settings=data.get("modelSettings", None),
                 output_type=data.get("outputType", None),
                 output_schema=data.get("outputSchema", None),
+                tools=data.get("tools", None),
             )
         except Exception:
             if fallback_to_cache:
@@ -573,6 +581,7 @@ class Prompt:
         self.output_schema = construct_base_model(
             response.output_schema
         )
+        self.tools = response.tools
 
         end_time = time.perf_counter()
         time_taken = format(end_time - start_time, ".2f")
@@ -594,6 +603,7 @@ class Prompt:
             model_settings=response.model_settings,
             output_type=response.output_type,
             output_schema=response.output_schema,
+            tools=response.tools,
         )
 
     def push(
@@ -606,6 +616,7 @@ class Prompt:
         model_settings: Optional[ModelSettings] = None,
         output_type: Optional[OutputType] = None,
         output_schema: Optional[Type[BaseModel]] = None,
+        tools: Optional[List[Tool]] = None,
        _verbose: Optional[bool] = True,
     ):
         if self.alias is None:
@@ -628,6 +639,7 @@ class Prompt:
             output_type=output_type or self.output_type,
             output_schema=construct_output_schema(output_schema)
             or construct_output_schema(self.output_schema),
+            tools=tools or self.tools,
         )
         try:
             body = body.model_dump(
@@ -655,6 +667,7 @@ class Prompt:
         self.model_settings = model_settings or self.model_settings
         self.output_type = output_type or self.output_type
         self.output_schema = output_schema or self.output_schema
+        self.tools = tools or self.tools
         self.type = PromptType.TEXT if text_template else PromptType.LIST
         if _verbose:
             console = Console()
@@ -674,6 +687,7 @@ class Prompt:
         model_settings: Optional[ModelSettings] = None,
         output_type: Optional[OutputType] = None,
         output_schema: Optional[Type[BaseModel]] = None,
+        tools: Optional[List[Tool]] = None,
     ):
         if self.alias is None:
             raise ValueError(
@@ -687,6 +701,7 @@ class Prompt:
             model_settings=model_settings,
             output_type=output_type,
             output_schema=construct_output_schema(output_schema),
+            tools=tools,
         )
         try:
             body = body.model_dump(
@@ -712,6 +727,7 @@ class Prompt:
         self.model_settings = model_settings
         self.output_type = output_type
         self.output_schema = output_schema
+        self.tools = tools
         self.type = PromptType.TEXT if text else PromptType.LIST
         console = Console()
         console.print("✅ Prompt successfully updated on Confident AI!")
@@ -796,6 +812,10 @@ class Prompt:
                 messages=data.get("messages", None),
                 type=data["type"],
                 interpolation_type=data["interpolationType"],
+                model_settings=data.get("modelSettings", None),
+                output_type=data.get("outputType", None),
+                output_schema=data.get("outputSchema", None),
+                tools=data.get("tools", None),
             )
 
             # Update the cache with fresh data from server
@@ -808,6 +828,10 @@ class Prompt:
                 prompt_version_id=response.id,
                 type=response.type,
                 interpolation_type=response.interpolation_type,
+                model_settings=response.model_settings,
+                output_type=response.output_type,
+                output_schema=response.output_schema,
+                tools=response.tools,
             )
 
             # Update in-memory properties with fresh data (thread-safe)
@@ -819,6 +843,12 @@ class Prompt:
             self._prompt_version_id = response.id
             self.type = response.type
             self.interpolation_type = response.interpolation_type
+            self.model_settings = response.model_settings
+            self.output_type = response.output_type
+            self.output_schema = construct_base_model(
+                response.output_schema
+            )
+            self.tools = response.tools
 
         except Exception:
             pass
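Taken together with the api.py changes, prompts can now carry tool definitions end to end: push and update accept a tools list, and pulling a prompt (from the server or the local cache) populates self.tools. A hedged usage sketch, assuming the surrounding Prompt API behaves as in earlier releases (alias-based constructor, push/pull methods) and reusing the hypothetical weather_tool from the Tool example above:

    from deepeval.prompt import Prompt, Tool  # Tool is newly re-exported in 3.8.4

    prompt = Prompt(alias="weather-assistant")

    # Publish a new version that includes the tool definition
    # (template/messages arguments as in previous releases, names assumed):
    # prompt.push(text_template="You are a weather assistant.", tools=[weather_tool])

    # Later, pulling the prompt also restores its tools alongside the template:
    # prompt.pull()
    # print(prompt.tools)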