deepeval 3.6.6__py3-none-any.whl → 3.6.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (92)
  1. deepeval/_version.py +1 -1
  2. deepeval/benchmarks/equity_med_qa/equity_med_qa.py +1 -0
  3. deepeval/cli/main.py +42 -0
  4. deepeval/confident/api.py +1 -0
  5. deepeval/config/settings.py +22 -4
  6. deepeval/constants.py +8 -1
  7. deepeval/dataset/dataset.py +2 -11
  8. deepeval/dataset/utils.py +1 -1
  9. deepeval/evaluate/evaluate.py +5 -1
  10. deepeval/evaluate/execute.py +97 -42
  11. deepeval/evaluate/utils.py +20 -116
  12. deepeval/integrations/crewai/__init__.py +6 -1
  13. deepeval/integrations/crewai/handler.py +1 -1
  14. deepeval/integrations/crewai/subs.py +51 -0
  15. deepeval/integrations/crewai/wrapper.py +45 -5
  16. deepeval/metrics/answer_relevancy/answer_relevancy.py +12 -3
  17. deepeval/metrics/api.py +281 -0
  18. deepeval/metrics/argument_correctness/argument_correctness.py +12 -2
  19. deepeval/metrics/bias/bias.py +12 -3
  20. deepeval/metrics/contextual_precision/contextual_precision.py +12 -3
  21. deepeval/metrics/contextual_recall/contextual_recall.py +12 -3
  22. deepeval/metrics/contextual_relevancy/contextual_relevancy.py +12 -1
  23. deepeval/metrics/conversation_completeness/conversation_completeness.py +12 -0
  24. deepeval/metrics/conversational_dag/conversational_dag.py +12 -0
  25. deepeval/metrics/conversational_dag/nodes.py +12 -4
  26. deepeval/metrics/conversational_g_eval/conversational_g_eval.py +73 -59
  27. deepeval/metrics/dag/dag.py +12 -0
  28. deepeval/metrics/dag/nodes.py +12 -4
  29. deepeval/metrics/faithfulness/faithfulness.py +12 -1
  30. deepeval/metrics/g_eval/g_eval.py +11 -0
  31. deepeval/metrics/hallucination/hallucination.py +12 -1
  32. deepeval/metrics/indicator.py +8 -2
  33. deepeval/metrics/json_correctness/json_correctness.py +12 -1
  34. deepeval/metrics/knowledge_retention/knowledge_retention.py +12 -0
  35. deepeval/metrics/mcp/mcp_task_completion.py +13 -0
  36. deepeval/metrics/mcp/multi_turn_mcp_use_metric.py +13 -0
  37. deepeval/metrics/mcp_use_metric/mcp_use_metric.py +12 -1
  38. deepeval/metrics/misuse/misuse.py +12 -1
  39. deepeval/metrics/multimodal_metrics/image_coherence/image_coherence.py +3 -0
  40. deepeval/metrics/multimodal_metrics/image_editing/image_editing.py +3 -0
  41. deepeval/metrics/multimodal_metrics/image_helpfulness/image_helpfulness.py +3 -0
  42. deepeval/metrics/multimodal_metrics/image_reference/image_reference.py +3 -0
  43. deepeval/metrics/multimodal_metrics/multimodal_answer_relevancy/multimodal_answer_relevancy.py +6 -1
  44. deepeval/metrics/multimodal_metrics/multimodal_contextual_precision/multimodal_contextual_precision.py +6 -1
  45. deepeval/metrics/multimodal_metrics/multimodal_contextual_recall/multimodal_contextual_recall.py +3 -0
  46. deepeval/metrics/multimodal_metrics/multimodal_contextual_relevancy/multimodal_contextual_relevancy.py +3 -0
  47. deepeval/metrics/multimodal_metrics/multimodal_faithfulness/multimodal_faithfulness.py +3 -0
  48. deepeval/metrics/multimodal_metrics/multimodal_g_eval/multimodal_g_eval.py +3 -0
  49. deepeval/metrics/multimodal_metrics/multimodal_tool_correctness/multimodal_tool_correctness.py +10 -5
  50. deepeval/metrics/non_advice/non_advice.py +12 -0
  51. deepeval/metrics/pii_leakage/pii_leakage.py +12 -1
  52. deepeval/metrics/prompt_alignment/prompt_alignment.py +12 -1
  53. deepeval/metrics/role_adherence/role_adherence.py +12 -0
  54. deepeval/metrics/role_violation/role_violation.py +12 -0
  55. deepeval/metrics/summarization/summarization.py +12 -1
  56. deepeval/metrics/task_completion/task_completion.py +3 -0
  57. deepeval/metrics/tool_correctness/tool_correctness.py +8 -0
  58. deepeval/metrics/toxicity/toxicity.py +12 -0
  59. deepeval/metrics/turn_relevancy/turn_relevancy.py +12 -0
  60. deepeval/models/llms/grok_model.py +1 -1
  61. deepeval/models/llms/openai_model.py +2 -0
  62. deepeval/openai/__init__.py +14 -32
  63. deepeval/openai/extractors.py +24 -34
  64. deepeval/openai/patch.py +256 -161
  65. deepeval/openai/types.py +20 -0
  66. deepeval/openai/utils.py +98 -56
  67. deepeval/prompt/__init__.py +19 -1
  68. deepeval/prompt/api.py +160 -0
  69. deepeval/prompt/prompt.py +244 -62
  70. deepeval/prompt/utils.py +144 -2
  71. deepeval/synthesizer/chunking/context_generator.py +209 -152
  72. deepeval/synthesizer/chunking/doc_chunker.py +46 -12
  73. deepeval/synthesizer/synthesizer.py +8 -5
  74. deepeval/test_case/api.py +131 -0
  75. deepeval/test_run/__init__.py +1 -0
  76. deepeval/test_run/hyperparameters.py +47 -8
  77. deepeval/test_run/test_run.py +104 -1
  78. deepeval/tracing/api.py +3 -1
  79. deepeval/tracing/message_types/__init__.py +10 -0
  80. deepeval/tracing/message_types/base.py +6 -0
  81. deepeval/tracing/message_types/messages.py +14 -0
  82. deepeval/tracing/message_types/tools.py +18 -0
  83. deepeval/tracing/otel/utils.py +1 -1
  84. deepeval/tracing/trace_context.py +73 -4
  85. deepeval/tracing/tracing.py +51 -3
  86. deepeval/tracing/types.py +16 -0
  87. deepeval/tracing/utils.py +8 -0
  88. {deepeval-3.6.6.dist-info → deepeval-3.6.7.dist-info}/METADATA +1 -1
  89. {deepeval-3.6.6.dist-info → deepeval-3.6.7.dist-info}/RECORD +92 -84
  90. {deepeval-3.6.6.dist-info → deepeval-3.6.7.dist-info}/LICENSE.md +0 -0
  91. {deepeval-3.6.6.dist-info → deepeval-3.6.7.dist-info}/WHEEL +0 -0
  92. {deepeval-3.6.6.dist-info → deepeval-3.6.7.dist-info}/entry_points.txt +0 -0
deepeval/prompt/prompt.py CHANGED
@@ -1,11 +1,11 @@
 from enum import Enum
-from typing import Literal, Optional, List, Dict
+from typing import Optional, List, Dict, Type, Literal
 from rich.progress import Progress, SpinnerColumn, TextColumn, BarColumn
 from rich.console import Console
 import time
 import json
 import os
-from pydantic import BaseModel
+from pydantic import BaseModel, ValidationError
 import asyncio
 import portalocker
 import threading
@@ -17,8 +17,20 @@ from deepeval.prompt.api import (
     PromptInterpolationType,
     PromptPushRequest,
     PromptVersionsHttpResponse,
+    PromptMessageList,
+    PromptUpdateRequest,
+    ModelSettings,
+    OutputSchema,
+    OutputType,
+    ReasoningEffort,
+    Verbosity,
+    ModelProvider,
+)
+from deepeval.prompt.utils import (
+    interpolate_text,
+    construct_base_model,
+    construct_output_schema,
 )
-from deepeval.prompt.utils import interpolate_text
 from deepeval.confident.api import Api, Endpoints, HttpMethods
 from deepeval.constants import HIDDEN_DIR
 
@@ -73,45 +85,51 @@ class CachedPrompt(BaseModel):
     prompt_version_id: str
     type: PromptType
     interpolation_type: PromptInterpolationType
+    model_settings: Optional[ModelSettings]
+    output_type: Optional[OutputType]
+    output_schema: Optional[OutputSchema]
 
     class Config:
         use_enum_values = True
 
 
 class Prompt:
-    label: Optional[str] = None
-    _prompt_version_id: Optional[str] = None
-    _type: Optional[PromptType] = None
-    _interpolation_type: Optional[PromptInterpolationType] = None
 
     def __init__(
         self,
         alias: Optional[str] = None,
-        template: Optional[str] = None,
+        text_template: Optional[str] = None,
         messages_template: Optional[List[PromptMessage]] = None,
+        model_settings: Optional[ModelSettings] = None,
+        output_type: Optional[OutputType] = None,
+        output_schema: Optional[Type[BaseModel]] = None,
     ):
-        if alias is None and template is None:
-            raise TypeError(
-                "Unable to create Prompt where 'alias' and 'template' are both None. Please provide at least one to continue."
-            )
-        if template and messages_template:
+        if text_template and messages_template:
             raise TypeError(
-                "Unable to create Prompt where 'template' and 'messages_template' are both provided. Please provide only one to continue."
+                "Unable to create Prompt where 'text_template' and 'messages_template' are both provided. Please provide only one to continue."
             )
-
         self.alias = alias
-        self._text_template = template
-        self._messages_template = messages_template
+        self.text_template = text_template
+        self.messages_template = messages_template
+        self.model_settings: Optional[ModelSettings] = model_settings
+        self.output_type: Optional[OutputType] = output_type
+        self.output_schema: Optional[Type[BaseModel]] = output_schema
+        self.label: Optional[str] = None
+        self.interpolation_type: Optional[PromptInterpolationType] = None
+
         self._version = None
+        self._prompt_version_id: Optional[str] = None
         self._polling_tasks: Dict[str, Dict[str, asyncio.Task]] = {}
         self._refresh_map: Dict[str, Dict[str, int]] = {}
         self._lock = (
             threading.Lock()
         )  # Protect instance attributes from race conditions
-        if template:
-            self._type = PromptType.TEXT
+
+        self.type: Optional[PromptType] = None
+        if text_template:
+            self.type = PromptType.TEXT
         elif messages_template:
-            self._type = PromptType.LIST
+            self.type = PromptType.LIST
 
     def __del__(self):
         """Cleanup polling tasks when instance is destroyed"""
@@ -135,12 +153,48 @@ class Prompt:
     def version(self, value):
         self._version = value
 
+    def load(self, file_path: str, messages_key: Optional[str] = None):
+        _, ext = os.path.splitext(file_path)
+        if ext != ".json" and ext != ".txt":
+            raise ValueError("Only .json and .txt files are supported")
+
+        file_name = os.path.basename(file_path).split(".")[0]
+        self.alias = file_name
+        with open(file_path, "r") as f:
+            content = f.read()
+        try:
+            data = json.loads(content)
+        except:
+            self.text_template = content
+            return content
+
+        text_template = None
+        messages_template = None
+        try:
+            if isinstance(data, list):
+                messages_template = PromptMessageList.validate_python(data)
+            elif isinstance(data, dict):
+                if messages_key is None:
+                    raise ValueError(
+                        "messages `key` must be provided if file is a dictionary"
+                    )
+                messages = data[messages_key]
+                messages_template = PromptMessageList.validate_python(messages)
+            else:
+                text_template = content
+        except ValidationError:
+            text_template = content
+
+        self.text_template = text_template
+        self.messages_template = messages_template
+        return text_template or messages_template
+
     def interpolate(self, **kwargs):
         with self._lock:
-            prompt_type = self._type
-            text_template = self._text_template
-            messages_template = self._messages_template
-            interpolation_type = self._interpolation_type
+            prompt_type = self.type
+            text_template = self.text_template
+            messages_template = self.messages_template
+            interpolation_type = self.interpolation_type
 
         if prompt_type == PromptType.TEXT:
             if text_template is None:
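Taken together, the constructor rework (`template` is now `text_template`, structured-output metadata is accepted up front, and an alias-less, template-less `Prompt()` no longer raises) and the new `load()` method above change how a `Prompt` is typically built. A minimal sketch of both; the alias, file name, `messages` key, and `Answer` model are illustrative and do not appear in this diff:

    from pydantic import BaseModel
    from deepeval.prompt import Prompt  # assuming the package's usual export

    class Answer(BaseModel):  # illustrative output schema
        verdict: str
        confidence: float

    # 'template' is now 'text_template'; model/output metadata can be set up front.
    prompt = Prompt(
        alias="qa-prompt",
        text_template="Answer the question: {question}",
        output_schema=Answer,
    )

    # load(): non-JSON content becomes the text template, a JSON list becomes
    # the messages template, and for a JSON object `messages_key` selects the
    # list inside it. The alias is derived from the file name, which is why an
    # alias-less Prompt() is now legal.
    file_prompt = Prompt()
    file_prompt.load("chat_prompt.json", messages_key="messages")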
@@ -166,7 +220,11 @@ class Prompt:
             )
             return interpolated_messages
         else:
-            raise ValueError(f"Unsupported prompt type: {prompt_type}")
+            raise ValueError(f"Unsupported prompt type: {self.type}")
+
+    ############################################
+    ### Utils
+    ############################################
 
     def _get_versions(self) -> List:
         if self.alias is None:
@@ -232,6 +290,9 @@ class Prompt:
         prompt_version_id: Optional[str] = None,
         type: Optional[PromptType] = None,
         interpolation_type: Optional[PromptInterpolationType] = None,
+        model_settings: Optional[ModelSettings] = None,
+        output_type: Optional[OutputType] = None,
+        output_schema: Optional[OutputSchema] = None,
     ):
         if not self.alias:
             return
@@ -276,6 +337,9 @@ class Prompt:
             "prompt_version_id": prompt_version_id,
             "type": type,
             "interpolation_type": interpolation_type,
+            "model_settings": model_settings,
+            "output_type": output_type,
+            "output_schema": output_schema,
         }
 
         if cache_key == VERSION_CACHE_KEY:
@@ -313,14 +377,27 @@ class Prompt:
             raise ValueError("Unable to fetch prompt and load from cache")
 
         with self._lock:
-            self.version = cached_prompt.version
+            self._version = cached_prompt.version
             self.label = cached_prompt.label
-            self._text_template = cached_prompt.template
-            self._messages_template = cached_prompt.messages_template
+            self.text_template = cached_prompt.template
+            self.messages_template = cached_prompt.messages_template
             self._prompt_version_id = cached_prompt.prompt_version_id
-            self._type = PromptType(cached_prompt.type)
-            self._interpolation_type = PromptInterpolationType(
-                cached_prompt.interpolation_type
+            self.type = (
+                PromptType(cached_prompt.type) if cached_prompt.type else None
+            )
+            self.interpolation_type = (
+                PromptInterpolationType(cached_prompt.interpolation_type)
+                if cached_prompt.interpolation_type
+                else None
+            )
+            self.model_settings = cached_prompt.model_settings
+            self.output_type = (
+                OutputType(cached_prompt.output_type)
+                if cached_prompt.output_type
+                else None
+            )
+            self.output_schema = construct_base_model(
+                cached_prompt.output_schema
             )
 
         end_time = time.perf_counter()
@@ -330,6 +407,10 @@ class Prompt:
                 description=f"{progress.tasks[task_id].description}[rgb(25,227,160)]Loaded from cache! ({time_taken}s)",
             )
 
+    ############################################
+    ### Pull, Push, Update
+    ############################################
+
     def pull(
         self,
         version: Optional[str] = None,
@@ -369,18 +450,33 @@ class Prompt:
                 )
                 if cached_prompt:
                     with self._lock:
-                        self.version = cached_prompt.version
+                        self._version = cached_prompt.version
                         self.label = cached_prompt.label
-                        self._text_template = cached_prompt.template
-                        self._messages_template = (
-                            cached_prompt.messages_template
-                        )
+                        self.text_template = cached_prompt.template
+                        self.messages_template = cached_prompt.messages_template
                         self._prompt_version_id = (
                             cached_prompt.prompt_version_id
                         )
-                        self._type = PromptType(cached_prompt.type)
-                        self._interpolation_type = PromptInterpolationType(
-                            cached_prompt.interpolation_type
+                        self.type = (
+                            PromptType(cached_prompt.type)
+                            if cached_prompt.type
+                            else None
+                        )
+                        self.interpolation_type = (
+                            PromptInterpolationType(
+                                cached_prompt.interpolation_type
+                            )
+                            if cached_prompt.interpolation_type
+                            else None
+                        )
+                        self.model_settings = cached_prompt.model_settings
+                        self.output_type = (
+                            OutputType(cached_prompt.output_type)
+                            if cached_prompt.output_type
+                            else None
+                        )
+                        self.output_schema = construct_base_model(
+                            cached_prompt.output_schema
                         )
                     return
             except:
@@ -432,6 +528,9 @@ class Prompt:
                 messages=data.get("messages", None),
                 type=data["type"],
                 interpolation_type=data["interpolationType"],
+                model_settings=data.get("modelSettings", None),
+                output_type=data.get("outputType", None),
+                output_schema=data.get("outputSchema", None),
             )
         except Exception:
             if fallback_to_cache:
@@ -446,13 +545,18 @@ class Prompt:
                     raise
 
             with self._lock:
-                self.version = response.version
+                self._version = response.version
                 self.label = response.label
-                self._text_template = response.text
-                self._messages_template = response.messages
+                self.text_template = response.text
+                self.messages_template = response.messages
                 self._prompt_version_id = response.id
-                self._type = response.type
-                self._interpolation_type = response.interpolation_type
+                self.type = response.type
+                self.interpolation_type = response.interpolation_type
+                self.model_settings = response.model_settings
+                self.output_type = response.output_type
+                self.output_schema = construct_base_model(
+                    response.output_schema
+                )
 
             end_time = time.perf_counter()
             time_taken = format(end_time - start_time, ".2f")
@@ -471,6 +575,9 @@ class Prompt:
             prompt_version_id=response.id,
             type=response.type,
             interpolation_type=response.interpolation_type,
+            model_settings=response.model_settings,
+            output_type=response.output_type,
+            output_schema=response.output_schema,
         )
 
     def push(
@@ -480,26 +587,36 @@ class Prompt:
         interpolation_type: Optional[
             PromptInterpolationType
         ] = PromptInterpolationType.FSTRING,
+        model_settings: Optional[ModelSettings] = None,
+        output_type: Optional[OutputType] = None,
+        output_schema: Optional[Type[BaseModel]] = None,
+        _verbose: Optional[bool] = True,
     ):
         if self.alias is None:
             raise ValueError(
                 "Prompt alias is not set. Please set an alias to continue."
             )
-
-        if text is None and messages is None:
+        text_template = text or self.text_template
+        messages_template = messages or self.messages_template
+        if text_template is None and messages_template is None:
             raise ValueError("Either text or messages must be provided")
-
-        if text is not None and messages is not None:
+        if text_template is not None and messages_template is not None:
             raise ValueError("Only one of text or messages can be provided")
 
         body = PromptPushRequest(
             alias=self.alias,
-            text=text,
-            messages=messages,
-            interpolation_type=interpolation_type,
+            text=text_template,
+            messages=messages_template,
+            interpolation_type=interpolation_type or self.interpolation_type,
+            model_settings=model_settings or self.model_settings,
+            output_type=output_type or self.output_type,
+            output_schema=construct_output_schema(output_schema)
+            or construct_output_schema(self.output_schema),
         )
         try:
-            body = body.model_dump(by_alias=True, exclude_none=True)
+            body = body.model_dump(
+                by_alias=True, exclude_none=True, mode="json"
+            )
         except AttributeError:
             # Pydantic version below 2.0
             body = body.dict(by_alias=True, exclude_none=True)
@@ -510,13 +627,78 @@ class Prompt:
             endpoint=Endpoints.PROMPTS_ENDPOINT,
             body=body,
         )
+        versions = self._get_versions()
 
-        if link:
-            console = Console()
-            console.print(
-                "✅ Prompt successfully pushed to Confident AI! View at "
-                f"[link={link}]{link}[/link]"
+        if link and versions:
+            self._prompt_version_id = versions[-1].id
+            self.text_template = text_template
+            self.messages_template = messages_template
+            self.interpolation_type = (
+                interpolation_type or self.interpolation_type
+            )
+            self.model_settings = model_settings or self.model_settings
+            self.output_type = output_type or self.output_type
+            self.output_schema = output_schema or self.output_schema
+            self.type = PromptType.TEXT if text_template else PromptType.LIST
+            if _verbose:
+                console = Console()
+                console.print(
+                    "✅ Prompt successfully pushed to Confident AI! View at "
+                    f"[link={link}]{link}[/link]"
+                )
+
+    def update(
+        self,
+        version: str,
+        text: Optional[str] = None,
+        messages: Optional[List[PromptMessage]] = None,
+        interpolation_type: Optional[
+            PromptInterpolationType
+        ] = PromptInterpolationType.FSTRING,
+        model_settings: Optional[ModelSettings] = None,
+        output_type: Optional[OutputType] = None,
+        output_schema: Optional[Type[BaseModel]] = None,
+    ):
+        if self.alias is None:
+            raise ValueError(
+                "Prompt alias is not set. Please set an alias to continue."
+            )
+
+        body = PromptUpdateRequest(
+            text=text,
+            messages=messages,
+            interpolation_type=interpolation_type,
+            model_settings=model_settings,
+            output_type=output_type,
+            output_schema=construct_output_schema(output_schema),
+        )
+        try:
+            body = body.model_dump(
+                by_alias=True, exclude_none=True, mode="json"
             )
+        except AttributeError:
+            body = body.dict(by_alias=True, exclude_none=True)
+        api = Api()
+        data, _ = api.send_request(
+            method=HttpMethods.PUT,
+            endpoint=Endpoints.PROMPTS_VERSION_ID_ENDPOINT,
+            url_params={
+                "alias": self.alias,
+                "versionId": version,
+            },
+            body=body,
+        )
+        if data:
+            self._version = version
+            self.text_template = text
+            self.messages_template = messages
+            self.interpolation_type = interpolation_type
+            self.model_settings = model_settings
+            self.output_type = output_type
+            self.output_schema = output_schema
+            self.type = PromptType.TEXT if text else PromptType.LIST
+            console = Console()
+            console.print("✅ Prompt successfully updated on Confident AI!")
 
     ############################################
     ### Polling
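The hunk above reworks `push()` (it now falls back to templates and settings already set on the instance, serializes with `mode="json"`, and records the new version id locally) and adds `update()`, which edits an existing version in place via a PUT request. A hedged sketch of how the pair might be called; the alias, template text, version id placeholder, and the `Answer` model from the earlier sketch are all illustrative:

    prompt = Prompt(alias="qa-prompt", text_template="Answer: {question}")

    # push() can now pick up text_template/model settings from the instance,
    # so only the structured output schema is passed here.
    prompt.push(output_schema=Answer)

    # update() targets one existing version by its id.
    prompt.update(
        version="<version-id>",  # placeholder for a real version id
        text="Answer concisely: {question}",
    )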
@@ -614,13 +796,13 @@ class Prompt:
 
                 # Update in-memory properties with fresh data (thread-safe)
                 with self._lock:
-                    self.version = response.version
+                    self._version = response.version
                     self.label = response.label
-                    self._text_template = response.text
-                    self._messages_template = response.messages
+                    self.text_template = response.text
+                    self.messages_template = response.messages
                     self._prompt_version_id = response.id
-                    self._type = response.type
-                    self._interpolation_type = response.interpolation_type
+                    self.type = response.type
+                    self.interpolation_type = response.interpolation_type
 
             except Exception:
                 pass
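Seen from the caller's side, `pull()` now also hydrates model settings, an output type, and a structured output schema (rebuilt into a Pydantic model class via `construct_base_model`). A usage sketch under the assumption that the alias and version exist on Confident AI:

    from deepeval.prompt import Prompt

    prompt = Prompt(alias="qa-prompt")   # placeholder alias
    prompt.pull(version="00.00.01")      # illustrative version string

    text = prompt.interpolate(question="What does deepeval do?")

    # New in 3.6.7: populated from the fetched version; output_schema is a
    # dynamically rebuilt Pydantic model class (or None).
    print(prompt.model_settings, prompt.output_type)
    if prompt.output_schema is not None:
        print(prompt.output_schema.model_json_schema())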
deepeval/prompt/utils.py CHANGED
@@ -1,7 +1,19 @@
 import re
+import uuid
 from jinja2 import Template
+from typing import Any, Dict, Type, Optional, List
+from pydantic import BaseModel, create_model
 
-from deepeval.prompt.api import PromptInterpolationType
+from deepeval.prompt.api import (
+    PromptInterpolationType,
+    OutputSchema,
+    SchemaDataType,
+    OutputSchemaField,
+)
+
+###################################
+# Interpolation
+###################################
 
 
 def interpolate_mustache(text: str, **kwargs) -> str:
@@ -47,4 +59,134 @@ def interpolate_text(
     elif interpolation_type == PromptInterpolationType.JINJA:
         return interpolate_jinja(text, **kwargs)
 
-    raise ValueError(f"Unsupported interpolation type: {interpolation_type}")
+
+###################################
+# Output Schema Deconstruction
+###################################
+
+schema_type_map: Dict[str, Any] = {
+    SchemaDataType.STRING.value: str,
+    SchemaDataType.INTEGER.value: int,
+    SchemaDataType.FLOAT.value: float,
+    SchemaDataType.BOOLEAN.value: bool,
+    SchemaDataType.NULL.value: type(None),
+    SchemaDataType.OBJECT.value: dict,
+}
+
+
+def construct_nested_base_model(
+    parent: OutputSchemaField,
+    parent_id_map: Dict[Optional[str], List[OutputSchemaField]],
+    model_name: str,
+) -> Type[BaseModel]:
+    child_fields: Dict[str, tuple] = {}
+    for child in parent_id_map.get(parent.id, []):
+        child_type = (
+            child.type.value if hasattr(child.type, "value") else child.type
+        )
+        if child_type == SchemaDataType.OBJECT.value:
+            python_type = construct_nested_base_model(
+                child, parent_id_map, child.name
+            )
+        else:
+            python_type = schema_type_map.get(child_type, Any)
+        default = ... if child.required else None
+        child_fields[child.name or child.id] = (python_type, default)
+    return create_model(model_name, **child_fields)
+
+
+def construct_base_model(
+    schema: Optional[OutputSchema] = None,
+) -> Type[BaseModel]:
+    if not schema:
+        return None
+    if not schema.fields:
+        return create_model(schema.name)
+
+    parent_id_map: Dict[Optional[str], List[OutputSchemaField]] = {}
+    for field in schema.fields:
+        parent_id = field.parent_id or None
+        if parent_id_map.get(parent_id) is None:
+            parent_id_map[parent_id] = []
+        parent_id_map[parent_id].append(field)
+
+    root_fields: Dict[str, tuple] = {}
+    for field in parent_id_map.get(None, []):
+        field_type = (
+            field.type.value if hasattr(field.type, "value") else field.type
+        )
+        if field_type == SchemaDataType.OBJECT.value:
+            python_type = construct_nested_base_model(
+                field, parent_id_map, field.name
+            )
+        else:
+            python_type = schema_type_map.get(field_type, Any)
+        default = ... if field.required else None
+        root_fields[field.name] = (python_type, default)
+
+    return create_model(schema.name, **root_fields)
+
+
+###################################
+# Output Schema Construction
+###################################
+
+
+def _process_model(
+    model_class: Type[BaseModel],
+    parent_id: Optional[str] = None,
+) -> List[OutputSchemaField]:
+    fields = []
+    model_fields = model_class.model_fields
+    for field_name, field_info in model_fields.items():
+        field_id = str(uuid.uuid4())
+        annotation = field_info.annotation
+        field_type = "STRING"
+        if annotation == str:
+            field_type = "STRING"
+        elif annotation == int:
+            field_type = "INTEGER"
+        elif annotation == float:
+            field_type = "FLOAT"
+        elif annotation == bool:
+            field_type = "BOOLEAN"
+        elif annotation == list:
+            raise ValueError("Unsupported structured output: list")
+        elif annotation == dict:
+            raise ValueError("Unsupported structured output: dict")
+        elif (
+            hasattr(annotation, "__bases__")
+            and BaseModel in annotation.__bases__
+        ):
+            field_type = "OBJECT"
+            parent_field = OutputSchemaField(
+                id=field_id,
+                name=field_name,
+                type=field_type,
+                required=field_info.default is ...,
+                parent_id=parent_id,
+            )
+            fields.append(parent_field)
+            nested_fields = _process_model(annotation, field_id)
+            fields.extend(nested_fields)
+            continue
+        required = field_info.default is ...
+        fields.append(
+            OutputSchemaField(
+                id=field_id,
+                name=field_name,
+                type=field_type,
+                required=required,
+                parent_id=parent_id,
+            )
+        )
+    return fields
+
+
+def construct_output_schema(
+    base_model_class: Optional[Type[BaseModel]] = None,
+) -> Optional[OutputSchema]:
+    if base_model_class is None:
+        return None
+    all_fields = _process_model(base_model_class)
+    return OutputSchema(fields=all_fields, name=base_model_class.__name__)
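These two helpers are near-inverses: `construct_output_schema` flattens a Pydantic model into `OutputSchemaField` rows linked by `parent_id`, and `construct_base_model` rebuilds a model class from such rows with `create_model`. A round-trip sketch; the `Person` and `Address` models are illustrative and not part of the diff:

    from pydantic import BaseModel
    from deepeval.prompt.utils import construct_base_model, construct_output_schema

    class Address(BaseModel):  # illustrative nested model
        city: str
        zip_code: str

    class Person(BaseModel):
        name: str
        age: int
        address: Address

    # Pydantic model -> flat OutputSchema; the Address fields become rows
    # whose parent_id points at the "address" OBJECT field.
    schema = construct_output_schema(Person)

    # OutputSchema -> dynamically created Pydantic model class.
    Rebuilt = construct_base_model(schema)
    print(Rebuilt(name="Ada", age=36, address={"city": "London", "zip_code": "N1"}))

One caveat worth flagging: `_process_model` tests `field_info.default is ...`, the Pydantic v1 convention, but it also reads `model_fields`, which is v2 API; under v2, required fields report `PydanticUndefined` rather than `Ellipsis`, so `required` may come out `False` for them.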