deepeval-3.7.5-py3-none-any.whl → deepeval-3.7.7-py3-none-any.whl

This diff shows the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registries.
Files changed (150)
  1. deepeval/_version.py +1 -1
  2. deepeval/cli/main.py +2022 -759
  3. deepeval/cli/utils.py +208 -36
  4. deepeval/config/dotenv_handler.py +19 -0
  5. deepeval/config/settings.py +675 -245
  6. deepeval/config/utils.py +9 -1
  7. deepeval/dataset/api.py +23 -1
  8. deepeval/dataset/golden.py +106 -21
  9. deepeval/evaluate/evaluate.py +0 -3
  10. deepeval/evaluate/execute.py +162 -315
  11. deepeval/evaluate/utils.py +6 -30
  12. deepeval/key_handler.py +124 -51
  13. deepeval/metrics/__init__.py +0 -4
  14. deepeval/metrics/answer_relevancy/answer_relevancy.py +89 -132
  15. deepeval/metrics/answer_relevancy/template.py +102 -179
  16. deepeval/metrics/arena_g_eval/arena_g_eval.py +98 -96
  17. deepeval/metrics/arena_g_eval/template.py +17 -1
  18. deepeval/metrics/argument_correctness/argument_correctness.py +81 -87
  19. deepeval/metrics/argument_correctness/template.py +19 -2
  20. deepeval/metrics/base_metric.py +19 -41
  21. deepeval/metrics/bias/bias.py +102 -108
  22. deepeval/metrics/bias/template.py +14 -2
  23. deepeval/metrics/contextual_precision/contextual_precision.py +56 -92
  24. deepeval/metrics/contextual_recall/contextual_recall.py +58 -85
  25. deepeval/metrics/contextual_relevancy/contextual_relevancy.py +53 -83
  26. deepeval/metrics/conversation_completeness/conversation_completeness.py +101 -119
  27. deepeval/metrics/conversation_completeness/template.py +23 -3
  28. deepeval/metrics/conversational_dag/conversational_dag.py +12 -8
  29. deepeval/metrics/conversational_dag/nodes.py +66 -123
  30. deepeval/metrics/conversational_dag/templates.py +16 -0
  31. deepeval/metrics/conversational_g_eval/conversational_g_eval.py +47 -66
  32. deepeval/metrics/dag/dag.py +10 -0
  33. deepeval/metrics/dag/nodes.py +63 -126
  34. deepeval/metrics/dag/templates.py +14 -0
  35. deepeval/metrics/exact_match/exact_match.py +9 -1
  36. deepeval/metrics/faithfulness/faithfulness.py +82 -136
  37. deepeval/metrics/g_eval/g_eval.py +93 -79
  38. deepeval/metrics/g_eval/template.py +18 -1
  39. deepeval/metrics/g_eval/utils.py +7 -6
  40. deepeval/metrics/goal_accuracy/goal_accuracy.py +91 -76
  41. deepeval/metrics/goal_accuracy/template.py +21 -3
  42. deepeval/metrics/hallucination/hallucination.py +60 -75
  43. deepeval/metrics/hallucination/template.py +13 -0
  44. deepeval/metrics/indicator.py +11 -10
  45. deepeval/metrics/json_correctness/json_correctness.py +40 -38
  46. deepeval/metrics/json_correctness/template.py +10 -0
  47. deepeval/metrics/knowledge_retention/knowledge_retention.py +60 -97
  48. deepeval/metrics/knowledge_retention/schema.py +9 -3
  49. deepeval/metrics/knowledge_retention/template.py +12 -0
  50. deepeval/metrics/mcp/mcp_task_completion.py +72 -43
  51. deepeval/metrics/mcp/multi_turn_mcp_use_metric.py +93 -75
  52. deepeval/metrics/mcp/schema.py +4 -0
  53. deepeval/metrics/mcp/template.py +59 -0
  54. deepeval/metrics/mcp_use_metric/mcp_use_metric.py +58 -64
  55. deepeval/metrics/mcp_use_metric/template.py +12 -0
  56. deepeval/metrics/misuse/misuse.py +77 -97
  57. deepeval/metrics/misuse/template.py +15 -0
  58. deepeval/metrics/multimodal_metrics/__init__.py +0 -1
  59. deepeval/metrics/multimodal_metrics/image_coherence/image_coherence.py +37 -38
  60. deepeval/metrics/multimodal_metrics/image_editing/image_editing.py +55 -76
  61. deepeval/metrics/multimodal_metrics/image_helpfulness/image_helpfulness.py +37 -38
  62. deepeval/metrics/multimodal_metrics/image_reference/image_reference.py +37 -38
  63. deepeval/metrics/multimodal_metrics/text_to_image/text_to_image.py +57 -76
  64. deepeval/metrics/non_advice/non_advice.py +79 -105
  65. deepeval/metrics/non_advice/template.py +12 -0
  66. deepeval/metrics/pattern_match/pattern_match.py +12 -4
  67. deepeval/metrics/pii_leakage/pii_leakage.py +75 -106
  68. deepeval/metrics/pii_leakage/template.py +14 -0
  69. deepeval/metrics/plan_adherence/plan_adherence.py +63 -89
  70. deepeval/metrics/plan_adherence/template.py +11 -0
  71. deepeval/metrics/plan_quality/plan_quality.py +63 -87
  72. deepeval/metrics/plan_quality/template.py +9 -0
  73. deepeval/metrics/prompt_alignment/prompt_alignment.py +78 -86
  74. deepeval/metrics/prompt_alignment/template.py +12 -0
  75. deepeval/metrics/role_adherence/role_adherence.py +48 -71
  76. deepeval/metrics/role_adherence/template.py +14 -0
  77. deepeval/metrics/role_violation/role_violation.py +75 -108
  78. deepeval/metrics/role_violation/template.py +12 -0
  79. deepeval/metrics/step_efficiency/step_efficiency.py +55 -65
  80. deepeval/metrics/step_efficiency/template.py +11 -0
  81. deepeval/metrics/summarization/summarization.py +115 -183
  82. deepeval/metrics/summarization/template.py +19 -0
  83. deepeval/metrics/task_completion/task_completion.py +67 -73
  84. deepeval/metrics/tool_correctness/tool_correctness.py +43 -42
  85. deepeval/metrics/tool_use/schema.py +4 -0
  86. deepeval/metrics/tool_use/template.py +16 -2
  87. deepeval/metrics/tool_use/tool_use.py +72 -94
  88. deepeval/metrics/topic_adherence/schema.py +4 -0
  89. deepeval/metrics/topic_adherence/template.py +21 -1
  90. deepeval/metrics/topic_adherence/topic_adherence.py +68 -81
  91. deepeval/metrics/toxicity/template.py +13 -0
  92. deepeval/metrics/toxicity/toxicity.py +80 -99
  93. deepeval/metrics/turn_contextual_precision/schema.py +3 -3
  94. deepeval/metrics/turn_contextual_precision/template.py +9 -2
  95. deepeval/metrics/turn_contextual_precision/turn_contextual_precision.py +154 -154
  96. deepeval/metrics/turn_contextual_recall/schema.py +3 -3
  97. deepeval/metrics/turn_contextual_recall/template.py +8 -1
  98. deepeval/metrics/turn_contextual_recall/turn_contextual_recall.py +148 -143
  99. deepeval/metrics/turn_contextual_relevancy/schema.py +2 -2
  100. deepeval/metrics/turn_contextual_relevancy/template.py +8 -1
  101. deepeval/metrics/turn_contextual_relevancy/turn_contextual_relevancy.py +154 -157
  102. deepeval/metrics/turn_faithfulness/schema.py +1 -1
  103. deepeval/metrics/turn_faithfulness/template.py +8 -1
  104. deepeval/metrics/turn_faithfulness/turn_faithfulness.py +180 -203
  105. deepeval/metrics/turn_relevancy/template.py +14 -0
  106. deepeval/metrics/turn_relevancy/turn_relevancy.py +56 -69
  107. deepeval/metrics/utils.py +161 -91
  108. deepeval/models/__init__.py +2 -0
  109. deepeval/models/base_model.py +44 -6
  110. deepeval/models/embedding_models/azure_embedding_model.py +34 -12
  111. deepeval/models/embedding_models/local_embedding_model.py +22 -7
  112. deepeval/models/embedding_models/ollama_embedding_model.py +17 -6
  113. deepeval/models/embedding_models/openai_embedding_model.py +3 -2
  114. deepeval/models/llms/__init__.py +2 -0
  115. deepeval/models/llms/amazon_bedrock_model.py +229 -73
  116. deepeval/models/llms/anthropic_model.py +143 -48
  117. deepeval/models/llms/azure_model.py +169 -95
  118. deepeval/models/llms/constants.py +2032 -0
  119. deepeval/models/llms/deepseek_model.py +82 -35
  120. deepeval/models/llms/gemini_model.py +126 -67
  121. deepeval/models/llms/grok_model.py +128 -65
  122. deepeval/models/llms/kimi_model.py +129 -87
  123. deepeval/models/llms/litellm_model.py +94 -18
  124. deepeval/models/llms/local_model.py +115 -16
  125. deepeval/models/llms/ollama_model.py +97 -76
  126. deepeval/models/llms/openai_model.py +169 -311
  127. deepeval/models/llms/portkey_model.py +58 -16
  128. deepeval/models/llms/utils.py +5 -2
  129. deepeval/models/retry_policy.py +10 -5
  130. deepeval/models/utils.py +56 -4
  131. deepeval/simulator/conversation_simulator.py +49 -2
  132. deepeval/simulator/template.py +16 -1
  133. deepeval/synthesizer/synthesizer.py +19 -17
  134. deepeval/test_case/api.py +24 -45
  135. deepeval/test_case/arena_test_case.py +7 -2
  136. deepeval/test_case/conversational_test_case.py +55 -6
  137. deepeval/test_case/llm_test_case.py +60 -6
  138. deepeval/test_run/api.py +3 -0
  139. deepeval/test_run/test_run.py +6 -1
  140. deepeval/utils.py +26 -0
  141. {deepeval-3.7.5.dist-info → deepeval-3.7.7.dist-info}/METADATA +3 -3
  142. {deepeval-3.7.5.dist-info → deepeval-3.7.7.dist-info}/RECORD +145 -148
  143. deepeval/metrics/multimodal_metrics/multimodal_g_eval/__init__.py +0 -0
  144. deepeval/metrics/multimodal_metrics/multimodal_g_eval/multimodal_g_eval.py +0 -386
  145. deepeval/metrics/multimodal_metrics/multimodal_g_eval/schema.py +0 -11
  146. deepeval/metrics/multimodal_metrics/multimodal_g_eval/template.py +0 -133
  147. deepeval/metrics/multimodal_metrics/multimodal_g_eval/utils.py +0 -68
  148. {deepeval-3.7.5.dist-info → deepeval-3.7.7.dist-info}/LICENSE.md +0 -0
  149. {deepeval-3.7.5.dist-info → deepeval-3.7.7.dist-info}/WHEEL +0 -0
  150. {deepeval-3.7.5.dist-info → deepeval-3.7.7.dist-info}/entry_points.txt +0 -0
@@ -13,6 +13,7 @@ from deepeval.models.utils import (
13
13
  require_secret_api_key,
14
14
  normalize_kwargs_and_extract_aliases,
15
15
  )
16
+ from deepeval.utils import require_param
16
17
 
17
18
 
18
19
  retry_azure = create_retry_decorator(PS.AZURE)
@@ -31,7 +32,7 @@ class AzureOpenAIEmbeddingModel(DeepEvalBaseEmbeddingModel):
31
32
  api_key: Optional[str] = None,
32
33
  base_url: Optional[str] = None,
33
34
  deployment_name: Optional[str] = None,
34
- openai_api_version: Optional[str] = None,
35
+ api_version: Optional[str] = None,
35
36
  generation_kwargs: Optional[Dict] = None,
36
37
  **kwargs,
37
38
  ):
@@ -53,25 +54,46 @@ class AzureOpenAIEmbeddingModel(DeepEvalBaseEmbeddingModel):
53
54
 
54
55
  if api_key is not None:
55
56
  # keep it secret, keep it safe from serializings, logging and alike
56
- self.api_key: SecretStr | None = SecretStr(api_key)
57
+ self.api_key: Optional[SecretStr] = SecretStr(api_key)
57
58
  else:
58
59
  self.api_key = settings.AZURE_OPENAI_API_KEY
59
60
 
60
- self.openai_api_version = (
61
- openai_api_version or settings.OPENAI_API_VERSION
61
+ api_version = api_version or settings.OPENAI_API_VERSION
62
+ if base_url is not None:
63
+ base_url = str(base_url).rstrip("/")
64
+ elif settings.AZURE_OPENAI_ENDPOINT is not None:
65
+ base_url = str(settings.AZURE_OPENAI_ENDPOINT).rstrip("/")
66
+
67
+ deployment_name = (
68
+ deployment_name or settings.AZURE_EMBEDDING_DEPLOYMENT_NAME
69
+ )
70
+
71
+ model = model or settings.AZURE_EMBEDDING_MODEL_NAME or deployment_name
72
+
73
+ # validation
74
+ self.deployment_name = require_param(
75
+ deployment_name,
76
+ provider_label="AzureOpenAIEmbeddingModel",
77
+ env_var_name="AZURE_EMBEDDING_DEPLOYMENT_NAME",
78
+ param_hint="deployment_name",
62
79
  )
63
- self.base_url = (
64
- base_url
65
- or settings.AZURE_OPENAI_ENDPOINT
66
- and str(settings.AZURE_OPENAI_ENDPOINT)
80
+
81
+ self.base_url = require_param(
82
+ base_url,
83
+ provider_label="AzureOpenAIEmbeddingModel",
84
+ env_var_name="AZURE_OPENAI_ENDPOINT",
85
+ param_hint="base_url",
67
86
  )
68
87
 
69
- self.deployment_name = (
70
- deployment_name or settings.AZURE_EMBEDDING_DEPLOYMENT_NAME
88
+ self.api_version = require_param(
89
+ api_version,
90
+ provider_label="AzureOpenAIEmbeddingModel",
91
+ env_var_name="OPENAI_API_VERSION",
92
+ param_hint="api_version",
71
93
  )
94
+
72
95
  # Keep sanitized kwargs for client call to strip legacy keys
73
96
  self.kwargs = normalized_kwargs
74
- model = model or self.deployment_name
75
97
  self.generation_kwargs = generation_kwargs or {}
76
98
  super().__init__(model)
77
99
 
@@ -126,7 +148,7 @@ class AzureOpenAIEmbeddingModel(DeepEvalBaseEmbeddingModel):
126
148
 
127
149
  client_init_kwargs = dict(
128
150
  api_key=api_key,
129
- api_version=self.openai_api_version,
151
+ api_version=self.api_version,
130
152
  azure_endpoint=self.base_url,
131
153
  azure_deployment=self.deployment_name,
132
154
  **client_kwargs,
@@ -12,7 +12,7 @@ from deepeval.models.retry_policy import (
12
12
  sdk_retries_for,
13
13
  )
14
14
  from deepeval.constants import ProviderSlug as PS
15
-
15
+ from deepeval.utils import require_param
16
16
 
17
17
  # consistent retry rules
18
18
  retry_local = create_retry_decorator(PS.LOCAL)
@@ -31,16 +31,31 @@ class LocalEmbeddingModel(DeepEvalBaseEmbeddingModel):
31
31
  settings = get_settings()
32
32
  if api_key is not None:
33
33
  # keep it secret, keep it safe from serializings, logging and alike
34
- self.api_key: SecretStr | None = SecretStr(api_key)
34
+ self.api_key: Optional[SecretStr] = SecretStr(api_key)
35
35
  else:
36
36
  self.api_key = get_settings().LOCAL_EMBEDDING_API_KEY
37
37
 
38
- self.base_url = (
39
- base_url
40
- or settings.LOCAL_EMBEDDING_BASE_URL
41
- and str(settings.LOCAL_EMBEDDING_BASE_URL)
42
- )
38
+ if base_url is not None:
39
+ base_url = str(base_url).rstrip("/")
40
+ elif settings.LOCAL_EMBEDDING_BASE_URL is not None:
41
+ base_url = str(settings.LOCAL_EMBEDDING_BASE_URL).rstrip("/")
42
+
43
43
  model = model or settings.LOCAL_EMBEDDING_MODEL_NAME
44
+ # validation
45
+ model = require_param(
46
+ model,
47
+ provider_label="LocalEmbeddingModel",
48
+ env_var_name="LOCAL_EMBEDDING_MODEL_NAME",
49
+ param_hint="model",
50
+ )
51
+
52
+ self.base_url = require_param(
53
+ base_url,
54
+ provider_label="LocalEmbeddingModel",
55
+ env_var_name="LOCAL_EMBEDDING_BASE_URL",
56
+ param_hint="base_url",
57
+ )
58
+
44
59
  # Keep sanitized kwargs for client call to strip legacy keys
45
60
  self.kwargs = kwargs
46
61
  self.generation_kwargs = generation_kwargs or {}
@@ -10,7 +10,7 @@ from deepeval.models.retry_policy import (
10
10
  create_retry_decorator,
11
11
  )
12
12
  from deepeval.constants import ProviderSlug as PS
13
-
13
+ from deepeval.utils import require_param
14
14
 
15
15
  retry_ollama = create_retry_decorator(PS.OLLAMA)
16
16
 
@@ -37,12 +37,23 @@ class OllamaEmbeddingModel(DeepEvalBaseEmbeddingModel):
37
37
 
38
38
  settings = get_settings()
39
39
 
40
- self.base_url = (
41
- base_url
42
- or settings.LOCAL_EMBEDDING_BASE_URL
43
- and str(settings.LOCAL_EMBEDDING_BASE_URL)
44
- )
40
+ if base_url is not None:
41
+ self.base_url = str(base_url).rstrip("/")
42
+ elif settings.LOCAL_EMBEDDING_BASE_URL is not None:
43
+ self.base_url = str(settings.LOCAL_EMBEDDING_BASE_URL).rstrip("/")
44
+ else:
45
+ self.base_url = "http://localhost:11434"
46
+
45
47
  model = model or settings.LOCAL_EMBEDDING_MODEL_NAME
48
+
49
+ # validation
50
+ model = require_param(
51
+ model,
52
+ provider_label="OllamaEmbeddingModel",
53
+ env_var_name="LOCAL_EMBEDDING_MODEL_NAME",
54
+ param_hint="model",
55
+ )
56
+
46
57
  # Keep sanitized kwargs for client call to strip legacy keys
47
58
  self.kwargs = normalized_kwargs
48
59
  self.generation_kwargs = generation_kwargs or {}
@@ -2,6 +2,7 @@ from typing import Dict, Optional, List
2
2
  from openai import OpenAI, AsyncOpenAI
3
3
  from pydantic import SecretStr
4
4
 
5
+ from deepeval.errors import DeepEvalError
5
6
  from deepeval.config.settings import get_settings
6
7
  from deepeval.models.utils import (
7
8
  require_secret_api_key,
@@ -51,13 +52,13 @@ class OpenAIEmbeddingModel(DeepEvalBaseEmbeddingModel):
51
52
 
52
53
  if api_key is not None:
53
54
  # keep it secret, keep it safe from serializings, logging and alike
54
- self.api_key: SecretStr | None = SecretStr(api_key)
55
+ self.api_key: Optional[SecretStr] = SecretStr(api_key)
55
56
  else:
56
57
  self.api_key = get_settings().OPENAI_API_KEY
57
58
 
58
59
  model = model if model else default_openai_embedding_model
59
60
  if model not in valid_openai_embedding_models:
60
- raise ValueError(
61
+ raise DeepEvalError(
61
62
  f"Invalid model. Available OpenAI Embedding models: {', '.join(valid_openai_embedding_models)}"
62
63
  )
63
64
  self.kwargs = normalized_kwargs
@@ -9,6 +9,7 @@ from .litellm_model import LiteLLMModel
9
9
  from .kimi_model import KimiModel
10
10
  from .grok_model import GrokModel
11
11
  from .deepseek_model import DeepSeekModel
12
+ from .portkey_model import PortkeyModel
12
13
 
13
14
  __all__ = [
14
15
  "AzureOpenAIModel",
@@ -22,4 +23,5 @@ __all__ = [
22
23
  "KimiModel",
23
24
  "GrokModel",
24
25
  "DeepSeekModel",
26
+ "PortkeyModel",
25
27
  ]
@@ -1,130 +1,286 @@
1
- from typing import Optional, Tuple, Union, Dict
1
+ import base64
2
+ from typing import Optional, Tuple, Union, Dict, List
2
3
  from contextlib import AsyncExitStack
3
- from pydantic import BaseModel
4
4
 
5
+ from pydantic import BaseModel, SecretStr
6
+
7
+ from deepeval.config.settings import get_settings
8
+ from deepeval.utils import (
9
+ require_dependency,
10
+ require_param,
11
+ )
5
12
  from deepeval.models.retry_policy import (
6
13
  create_retry_decorator,
7
14
  sdk_retries_for,
8
15
  )
16
+ from deepeval.test_case import MLLMImage
17
+ from deepeval.utils import check_if_multimodal, convert_to_multi_modal_array
9
18
  from deepeval.models import DeepEvalBaseLLM
19
+ from deepeval.models.llms.constants import BEDROCK_MODELS_DATA
10
20
  from deepeval.models.llms.utils import trim_and_load_json, safe_asyncio_run
11
21
  from deepeval.constants import ProviderSlug as PS
22
+ from deepeval.models.utils import (
23
+ require_costs,
24
+ normalize_kwargs_and_extract_aliases,
25
+ )
12
26
 
13
- # check aiobotocore availability
14
- try:
15
- from aiobotocore.session import get_session
16
- from botocore.config import Config
17
-
18
- aiobotocore_available = True
19
- except ImportError:
20
- aiobotocore_available = False
21
27
 
22
- # define retry policy
23
28
  retry_bedrock = create_retry_decorator(PS.BEDROCK)
24
29
 
25
-
26
- def _check_aiobotocore_available():
27
- if not aiobotocore_available:
28
- raise ImportError(
29
- "aiobotocore and botocore are required for this functionality. "
30
- "Install them via your package manager (e.g. pip install aiobotocore botocore)"
31
- )
30
+ _ALIAS_MAP = {
31
+ "model": ["model_id"],
32
+ "region": ["region_name"],
33
+ "cost_per_input_token": ["input_token_cost"],
34
+ "cost_per_output_token": ["output_token_cost"],
35
+ }
32
36
 
33
37
 
34
38
  class AmazonBedrockModel(DeepEvalBaseLLM):
35
39
  def __init__(
36
40
  self,
37
- model_id: str,
38
- region_name: str,
41
+ model: Optional[str] = None,
39
42
  aws_access_key_id: Optional[str] = None,
40
43
  aws_secret_access_key: Optional[str] = None,
41
- input_token_cost: float = 0,
42
- output_token_cost: float = 0,
44
+ cost_per_input_token: Optional[float] = None,
45
+ cost_per_output_token: Optional[float] = None,
46
+ region: Optional[str] = None,
43
47
  generation_kwargs: Optional[Dict] = None,
44
48
  **kwargs,
45
49
  ):
46
- _check_aiobotocore_available()
47
- super().__init__(model_id)
48
-
49
- self.model_id = model_id
50
- self.region_name = region_name
51
- self.aws_access_key_id = aws_access_key_id
52
- self.aws_secret_access_key = aws_secret_access_key
53
- self.input_token_cost = input_token_cost
54
- self.output_token_cost = output_token_cost
55
-
56
- # prepare aiobotocore session, config, and async exit stack
57
- self._session = get_session()
50
+ settings = get_settings()
51
+
52
+ normalized_kwargs, alias_values = normalize_kwargs_and_extract_aliases(
53
+ "AmazonBedrockModel",
54
+ kwargs,
55
+ _ALIAS_MAP,
56
+ )
57
+
58
+ # Backwards compatibility for renamed params
59
+ if model is None and "model" in alias_values:
60
+ model = alias_values["model"]
61
+ if (
62
+ cost_per_input_token is None
63
+ and "cost_per_input_token" in alias_values
64
+ ):
65
+ cost_per_input_token = alias_values["cost_per_input_token"]
66
+ if (
67
+ cost_per_output_token is None
68
+ and "cost_per_output_token" in alias_values
69
+ ):
70
+ cost_per_output_token = alias_values["cost_per_output_token"]
71
+
72
+ # Secrets: prefer explicit args -> settings -> then AWS default chain
73
+ if aws_access_key_id is not None:
74
+ self.aws_access_key_id: Optional[SecretStr] = SecretStr(
75
+ aws_access_key_id
76
+ )
77
+ else:
78
+ self.aws_access_key_id = settings.AWS_ACCESS_KEY_ID
79
+
80
+ if aws_secret_access_key is not None:
81
+ self.aws_secret_access_key: Optional[SecretStr] = SecretStr(
82
+ aws_secret_access_key
83
+ )
84
+ else:
85
+ self.aws_secret_access_key = settings.AWS_SECRET_ACCESS_KEY
86
+
87
+ # Dependencies: aiobotocore & botocore
88
+ aiobotocore_session = require_dependency(
89
+ "aiobotocore.session",
90
+ provider_label="AmazonBedrockModel",
91
+ install_hint="Install it with `pip install aiobotocore`.",
92
+ )
93
+ self.botocore_module = require_dependency(
94
+ "botocore",
95
+ provider_label="AmazonBedrockModel",
96
+ install_hint="Install it with `pip install botocore`.",
97
+ )
98
+ self._session = aiobotocore_session.get_session()
58
99
  self._exit_stack = AsyncExitStack()
59
- self.kwargs = kwargs
100
+
101
+ # Defaults from settings
102
+ model = model or settings.AWS_BEDROCK_MODEL_NAME
103
+ region = region or settings.AWS_BEDROCK_REGION
104
+
105
+ cost_per_input_token = (
106
+ cost_per_input_token
107
+ if cost_per_input_token is not None
108
+ else settings.AWS_BEDROCK_COST_PER_INPUT_TOKEN
109
+ )
110
+ cost_per_output_token = (
111
+ cost_per_output_token
112
+ if cost_per_output_token is not None
113
+ else settings.AWS_BEDROCK_COST_PER_OUTPUT_TOKEN
114
+ )
115
+
116
+ # Required params
117
+ model = require_param(
118
+ model,
119
+ provider_label="AmazonBedrockModel",
120
+ env_var_name="AWS_BEDROCK_MODEL_NAME",
121
+ param_hint="model",
122
+ )
123
+ region = require_param(
124
+ region,
125
+ provider_label="AmazonBedrockModel",
126
+ env_var_name="AWS_BEDROCK_REGION",
127
+ param_hint="region",
128
+ )
129
+
130
+ self.model_data = BEDROCK_MODELS_DATA.get(model)
131
+ cost_per_input_token, cost_per_output_token = require_costs(
132
+ self.model_data,
133
+ model,
134
+ "AWS_BEDROCK_COST_PER_INPUT_TOKEN",
135
+ "AWS_BEDROCK_COST_PER_OUTPUT_TOKEN",
136
+ cost_per_input_token,
137
+ cost_per_output_token,
138
+ )
139
+
140
+ # Final attributes
141
+ self.region = region
142
+ self.cost_per_input_token = float(cost_per_input_token or 0.0)
143
+ self.cost_per_output_token = float(cost_per_output_token or 0.0)
144
+
145
+ self.kwargs = normalized_kwargs
60
146
  self.generation_kwargs = generation_kwargs or {}
61
147
  self._client = None
62
148
  self._sdk_retry_mode: Optional[bool] = None
63
149
 
150
+ super().__init__(model)
151
+
64
152
  ###############################################
65
153
  # Generate functions
66
154
  ###############################################
67
155
 
68
156
  def generate(
69
157
  self, prompt: str, schema: Optional[BaseModel] = None
70
- ) -> Tuple[Union[str, Dict], float]:
158
+ ) -> Tuple[Union[str, BaseModel], float]:
71
159
  return safe_asyncio_run(self.a_generate(prompt, schema))
72
160
 
73
161
  @retry_bedrock
74
162
  async def a_generate(
75
163
  self, prompt: str, schema: Optional[BaseModel] = None
76
- ) -> Tuple[Union[str, Dict], float]:
77
-
78
- try:
164
+ ) -> Tuple[Union[str, BaseModel], float]:
165
+ if check_if_multimodal(prompt):
166
+ prompt = convert_to_multi_modal_array(input=prompt)
167
+ payload = self.generate_payload(prompt)
168
+ else:
79
169
  payload = self.get_converse_request_body(prompt)
80
- client = await self._ensure_client()
81
- response = await client.converse(
82
- modelId=self.model_id,
83
- messages=payload["messages"],
84
- inferenceConfig=payload["inferenceConfig"],
85
- )
86
- message = response["output"]["message"]["content"][0]["text"]
87
- cost = self.calculate_cost(
88
- response["usage"]["inputTokens"],
89
- response["usage"]["outputTokens"],
90
- )
91
- if schema is None:
92
- return message, cost
93
- else:
94
- json_output = trim_and_load_json(message)
95
- return schema.model_validate(json_output), cost
96
- finally:
97
- await self.close()
170
+
171
+ payload = self.get_converse_request_body(prompt)
172
+ client = await self._ensure_client()
173
+ response = await client.converse(
174
+ modelId=self.get_model_name(),
175
+ messages=payload["messages"],
176
+ inferenceConfig=payload["inferenceConfig"],
177
+ )
178
+ message = response["output"]["message"]["content"][0]["text"]
179
+ cost = self.calculate_cost(
180
+ response["usage"]["inputTokens"],
181
+ response["usage"]["outputTokens"],
182
+ )
183
+ if schema is None:
184
+ return message, cost
185
+ else:
186
+ json_output = trim_and_load_json(message)
187
+ return schema.model_validate(json_output), cost
188
+
189
+ def generate_payload(
190
+ self, multimodal_input: Optional[List[Union[str, MLLMImage]]] = None
191
+ ):
192
+ multimodal_input = [] if multimodal_input is None else multimodal_input
193
+ content = []
194
+ for element in multimodal_input:
195
+ if isinstance(element, str):
196
+ content.append({"text": element})
197
+ elif isinstance(element, MLLMImage):
198
+ # Bedrock doesn't support external URLs - must convert everything to bytes
199
+ element.ensure_images_loaded()
200
+
201
+ image_format = (
202
+ (element.mimeType or "image/jpeg").split("/")[-1].upper()
203
+ )
204
+ image_format = "JPEG" if image_format == "JPG" else image_format
205
+
206
+ try:
207
+ image_raw_bytes = base64.b64decode(element.dataBase64)
208
+ except Exception:
209
+ raise ValueError(
210
+ f"Invalid base64 data in MLLMImage: {element._id}"
211
+ )
212
+
213
+ content.append(
214
+ {
215
+ "image": {
216
+ "format": image_format,
217
+ "source": {"bytes": image_raw_bytes},
218
+ }
219
+ }
220
+ )
221
+
222
+ return {
223
+ "messages": [{"role": "user", "content": content}],
224
+ "inferenceConfig": {
225
+ **self.generation_kwargs,
226
+ },
227
+ }
228
+
229
+ #########################
230
+ # Capabilities #
231
+ #########################
232
+
233
+ def supports_log_probs(self) -> Union[bool, None]:
234
+ return self.model_data.supports_log_probs
235
+
236
+ def supports_temperature(self) -> Union[bool, None]:
237
+ return self.model_data.supports_temperature
238
+
239
+ def supports_multimodal(self) -> Union[bool, None]:
240
+ return self.model_data.supports_multimodal
241
+
242
+ def supports_structured_outputs(self) -> Union[bool, None]:
243
+ return self.model_data.supports_structured_outputs
244
+
245
+ def supports_json_mode(self) -> Union[bool, None]:
246
+ return self.model_data.supports_json
98
247
 
99
248
  ###############################################
100
249
  # Client management
101
250
  ###############################################
102
251
 
103
252
  async def _ensure_client(self):
253
+
104
254
  use_sdk = sdk_retries_for(PS.BEDROCK)
105
255
 
106
256
  # only rebuild if client is missing or the sdk retry mode changes
107
257
  if self._client is None or self._sdk_retry_mode != use_sdk:
108
- # Close any previous
109
- if self._client is not None:
110
- await self._exit_stack.aclose()
111
- self._client = None
112
258
 
113
259
  # create retry config for botocore
114
260
  retries_config = {"max_attempts": (5 if use_sdk else 1)}
115
261
  if use_sdk:
116
262
  retries_config["mode"] = "adaptive"
117
263
 
264
+ Config = self.botocore_module.config.Config
118
265
  config = Config(retries=retries_config)
119
266
 
120
- cm = self._session.create_client(
121
- "bedrock-runtime",
122
- region_name=self.region_name,
123
- aws_access_key_id=self.aws_access_key_id,
124
- aws_secret_access_key=self.aws_secret_access_key,
125
- config=config,
267
+ client_kwargs = {
268
+ "region_name": self.region,
269
+ "config": config,
126
270
  **self.kwargs,
127
- )
271
+ }
272
+
273
+ if self.aws_access_key_id is not None:
274
+ client_kwargs["aws_access_key_id"] = (
275
+ self.aws_access_key_id.get_secret_value()
276
+ )
277
+ if self.aws_secret_access_key is not None:
278
+ client_kwargs["aws_secret_access_key"] = (
279
+ self.aws_secret_access_key.get_secret_value()
280
+ )
281
+
282
+ cm = self._session.create_client("bedrock-runtime", **client_kwargs)
283
+
128
284
  self._client = await self._exit_stack.enter_async_context(cm)
129
285
  self._sdk_retry_mode = use_sdk
130
286
 
@@ -148,13 +304,13 @@ class AmazonBedrockModel(DeepEvalBaseLLM):
148
304
  }
149
305
 
150
306
  def calculate_cost(self, input_tokens: int, output_tokens: int) -> float:
151
- return (
152
- input_tokens * self.input_token_cost
153
- + output_tokens * self.output_token_cost
154
- )
307
+ if self.model_data.input_price and self.model_data.output_price:
308
+ input_cost = input_tokens * self.model_data.input_price
309
+ output_cost = output_tokens * self.model_data.output_price
310
+ return input_cost + output_cost
155
311
 
156
312
  def load_model(self):
157
313
  pass
158
314
 
159
315
  def get_model_name(self) -> str:
160
- return self.model_id
316
+ return self.name