judgeval 0.1.0__py3-none-any.whl → 0.23.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (234)
  1. judgeval/__init__.py +173 -10
  2. judgeval/api/__init__.py +523 -0
  3. judgeval/api/api_types.py +413 -0
  4. judgeval/cli.py +112 -0
  5. judgeval/constants.py +7 -30
  6. judgeval/data/__init__.py +1 -3
  7. judgeval/data/evaluation_run.py +125 -0
  8. judgeval/data/example.py +14 -40
  9. judgeval/data/judgment_types.py +396 -146
  10. judgeval/data/result.py +11 -18
  11. judgeval/data/scorer_data.py +3 -26
  12. judgeval/data/scripts/openapi_transform.py +5 -5
  13. judgeval/data/trace.py +115 -194
  14. judgeval/dataset/__init__.py +335 -0
  15. judgeval/env.py +55 -0
  16. judgeval/evaluation/__init__.py +346 -0
  17. judgeval/exceptions.py +28 -0
  18. judgeval/integrations/langgraph/__init__.py +13 -0
  19. judgeval/integrations/openlit/__init__.py +51 -0
  20. judgeval/judges/__init__.py +2 -2
  21. judgeval/judges/litellm_judge.py +77 -16
  22. judgeval/judges/together_judge.py +88 -17
  23. judgeval/judges/utils.py +7 -20
  24. judgeval/judgment_attribute_keys.py +55 -0
  25. judgeval/{common/logger.py → logger.py} +24 -8
  26. judgeval/prompt/__init__.py +330 -0
  27. judgeval/scorers/__init__.py +11 -11
  28. judgeval/scorers/agent_scorer.py +15 -19
  29. judgeval/scorers/api_scorer.py +21 -23
  30. judgeval/scorers/base_scorer.py +54 -36
  31. judgeval/scorers/example_scorer.py +1 -3
  32. judgeval/scorers/judgeval_scorers/api_scorers/__init__.py +2 -24
  33. judgeval/scorers/judgeval_scorers/api_scorers/answer_correctness.py +2 -10
  34. judgeval/scorers/judgeval_scorers/api_scorers/answer_relevancy.py +2 -2
  35. judgeval/scorers/judgeval_scorers/api_scorers/faithfulness.py +2 -10
  36. judgeval/scorers/judgeval_scorers/api_scorers/instruction_adherence.py +2 -14
  37. judgeval/scorers/judgeval_scorers/api_scorers/prompt_scorer.py +171 -59
  38. judgeval/scorers/score.py +64 -47
  39. judgeval/scorers/utils.py +2 -107
  40. judgeval/tracer/__init__.py +1111 -2
  41. judgeval/tracer/constants.py +1 -0
  42. judgeval/tracer/exporters/__init__.py +40 -0
  43. judgeval/tracer/exporters/s3.py +119 -0
  44. judgeval/tracer/exporters/store.py +59 -0
  45. judgeval/tracer/exporters/utils.py +32 -0
  46. judgeval/tracer/keys.py +63 -0
  47. judgeval/tracer/llm/__init__.py +7 -0
  48. judgeval/tracer/llm/config.py +78 -0
  49. judgeval/tracer/llm/constants.py +9 -0
  50. judgeval/tracer/llm/llm_anthropic/__init__.py +3 -0
  51. judgeval/tracer/llm/llm_anthropic/config.py +6 -0
  52. judgeval/tracer/llm/llm_anthropic/messages.py +452 -0
  53. judgeval/tracer/llm/llm_anthropic/messages_stream.py +322 -0
  54. judgeval/tracer/llm/llm_anthropic/wrapper.py +59 -0
  55. judgeval/tracer/llm/llm_google/__init__.py +3 -0
  56. judgeval/tracer/llm/llm_google/config.py +6 -0
  57. judgeval/tracer/llm/llm_google/generate_content.py +127 -0
  58. judgeval/tracer/llm/llm_google/wrapper.py +30 -0
  59. judgeval/tracer/llm/llm_openai/__init__.py +3 -0
  60. judgeval/tracer/llm/llm_openai/beta_chat_completions.py +216 -0
  61. judgeval/tracer/llm/llm_openai/chat_completions.py +501 -0
  62. judgeval/tracer/llm/llm_openai/config.py +6 -0
  63. judgeval/tracer/llm/llm_openai/responses.py +506 -0
  64. judgeval/tracer/llm/llm_openai/utils.py +42 -0
  65. judgeval/tracer/llm/llm_openai/wrapper.py +63 -0
  66. judgeval/tracer/llm/llm_together/__init__.py +3 -0
  67. judgeval/tracer/llm/llm_together/chat_completions.py +406 -0
  68. judgeval/tracer/llm/llm_together/config.py +6 -0
  69. judgeval/tracer/llm/llm_together/wrapper.py +52 -0
  70. judgeval/tracer/llm/providers.py +19 -0
  71. judgeval/tracer/managers.py +167 -0
  72. judgeval/tracer/processors/__init__.py +220 -0
  73. judgeval/tracer/utils.py +19 -0
  74. judgeval/trainer/__init__.py +14 -0
  75. judgeval/trainer/base_trainer.py +122 -0
  76. judgeval/trainer/config.py +123 -0
  77. judgeval/trainer/console.py +144 -0
  78. judgeval/trainer/fireworks_trainer.py +392 -0
  79. judgeval/trainer/trainable_model.py +252 -0
  80. judgeval/trainer/trainer.py +70 -0
  81. judgeval/utils/async_utils.py +39 -0
  82. judgeval/utils/decorators/__init__.py +0 -0
  83. judgeval/utils/decorators/dont_throw.py +37 -0
  84. judgeval/utils/decorators/use_once.py +13 -0
  85. judgeval/utils/file_utils.py +74 -28
  86. judgeval/utils/guards.py +36 -0
  87. judgeval/utils/meta.py +27 -0
  88. judgeval/utils/project.py +15 -0
  89. judgeval/utils/serialize.py +253 -0
  90. judgeval/utils/testing.py +70 -0
  91. judgeval/utils/url.py +10 -0
  92. judgeval/{version_check.py → utils/version_check.py} +5 -3
  93. judgeval/utils/wrappers/README.md +3 -0
  94. judgeval/utils/wrappers/__init__.py +15 -0
  95. judgeval/utils/wrappers/immutable_wrap_async.py +74 -0
  96. judgeval/utils/wrappers/immutable_wrap_async_iterator.py +84 -0
  97. judgeval/utils/wrappers/immutable_wrap_sync.py +66 -0
  98. judgeval/utils/wrappers/immutable_wrap_sync_iterator.py +84 -0
  99. judgeval/utils/wrappers/mutable_wrap_async.py +67 -0
  100. judgeval/utils/wrappers/mutable_wrap_sync.py +67 -0
  101. judgeval/utils/wrappers/py.typed +0 -0
  102. judgeval/utils/wrappers/utils.py +35 -0
  103. judgeval/v1/__init__.py +88 -0
  104. judgeval/v1/data/__init__.py +7 -0
  105. judgeval/v1/data/example.py +44 -0
  106. judgeval/v1/data/scorer_data.py +42 -0
  107. judgeval/v1/data/scoring_result.py +44 -0
  108. judgeval/v1/datasets/__init__.py +6 -0
  109. judgeval/v1/datasets/dataset.py +214 -0
  110. judgeval/v1/datasets/dataset_factory.py +94 -0
  111. judgeval/v1/evaluation/__init__.py +6 -0
  112. judgeval/v1/evaluation/evaluation.py +182 -0
  113. judgeval/v1/evaluation/evaluation_factory.py +17 -0
  114. judgeval/v1/instrumentation/__init__.py +6 -0
  115. judgeval/v1/instrumentation/llm/__init__.py +7 -0
  116. judgeval/v1/instrumentation/llm/config.py +78 -0
  117. judgeval/v1/instrumentation/llm/constants.py +11 -0
  118. judgeval/v1/instrumentation/llm/llm_anthropic/__init__.py +5 -0
  119. judgeval/v1/instrumentation/llm/llm_anthropic/config.py +6 -0
  120. judgeval/v1/instrumentation/llm/llm_anthropic/messages.py +414 -0
  121. judgeval/v1/instrumentation/llm/llm_anthropic/messages_stream.py +307 -0
  122. judgeval/v1/instrumentation/llm/llm_anthropic/wrapper.py +61 -0
  123. judgeval/v1/instrumentation/llm/llm_google/__init__.py +5 -0
  124. judgeval/v1/instrumentation/llm/llm_google/config.py +6 -0
  125. judgeval/v1/instrumentation/llm/llm_google/generate_content.py +121 -0
  126. judgeval/v1/instrumentation/llm/llm_google/wrapper.py +30 -0
  127. judgeval/v1/instrumentation/llm/llm_openai/__init__.py +5 -0
  128. judgeval/v1/instrumentation/llm/llm_openai/beta_chat_completions.py +212 -0
  129. judgeval/v1/instrumentation/llm/llm_openai/chat_completions.py +477 -0
  130. judgeval/v1/instrumentation/llm/llm_openai/config.py +6 -0
  131. judgeval/v1/instrumentation/llm/llm_openai/responses.py +472 -0
  132. judgeval/v1/instrumentation/llm/llm_openai/utils.py +41 -0
  133. judgeval/v1/instrumentation/llm/llm_openai/wrapper.py +63 -0
  134. judgeval/v1/instrumentation/llm/llm_together/__init__.py +5 -0
  135. judgeval/v1/instrumentation/llm/llm_together/chat_completions.py +382 -0
  136. judgeval/v1/instrumentation/llm/llm_together/config.py +6 -0
  137. judgeval/v1/instrumentation/llm/llm_together/wrapper.py +57 -0
  138. judgeval/v1/instrumentation/llm/providers.py +19 -0
  139. judgeval/v1/integrations/claude_agent_sdk/__init__.py +119 -0
  140. judgeval/v1/integrations/claude_agent_sdk/wrapper.py +564 -0
  141. judgeval/v1/integrations/langgraph/__init__.py +13 -0
  142. judgeval/v1/integrations/openlit/__init__.py +47 -0
  143. judgeval/v1/internal/api/__init__.py +525 -0
  144. judgeval/v1/internal/api/api_types.py +413 -0
  145. judgeval/v1/prompts/__init__.py +6 -0
  146. judgeval/v1/prompts/prompt.py +29 -0
  147. judgeval/v1/prompts/prompt_factory.py +189 -0
  148. judgeval/v1/py.typed +0 -0
  149. judgeval/v1/scorers/__init__.py +6 -0
  150. judgeval/v1/scorers/api_scorer.py +82 -0
  151. judgeval/v1/scorers/base_scorer.py +17 -0
  152. judgeval/v1/scorers/built_in/__init__.py +17 -0
  153. judgeval/v1/scorers/built_in/answer_correctness.py +28 -0
  154. judgeval/v1/scorers/built_in/answer_relevancy.py +28 -0
  155. judgeval/v1/scorers/built_in/built_in_factory.py +26 -0
  156. judgeval/v1/scorers/built_in/faithfulness.py +28 -0
  157. judgeval/v1/scorers/built_in/instruction_adherence.py +28 -0
  158. judgeval/v1/scorers/custom_scorer/__init__.py +6 -0
  159. judgeval/v1/scorers/custom_scorer/custom_scorer.py +50 -0
  160. judgeval/v1/scorers/custom_scorer/custom_scorer_factory.py +16 -0
  161. judgeval/v1/scorers/prompt_scorer/__init__.py +6 -0
  162. judgeval/v1/scorers/prompt_scorer/prompt_scorer.py +86 -0
  163. judgeval/v1/scorers/prompt_scorer/prompt_scorer_factory.py +85 -0
  164. judgeval/v1/scorers/scorers_factory.py +49 -0
  165. judgeval/v1/tracer/__init__.py +7 -0
  166. judgeval/v1/tracer/base_tracer.py +520 -0
  167. judgeval/v1/tracer/exporters/__init__.py +14 -0
  168. judgeval/v1/tracer/exporters/in_memory_span_exporter.py +25 -0
  169. judgeval/v1/tracer/exporters/judgment_span_exporter.py +42 -0
  170. judgeval/v1/tracer/exporters/noop_span_exporter.py +19 -0
  171. judgeval/v1/tracer/exporters/span_store.py +50 -0
  172. judgeval/v1/tracer/judgment_tracer_provider.py +70 -0
  173. judgeval/v1/tracer/processors/__init__.py +6 -0
  174. judgeval/v1/tracer/processors/_lifecycles/__init__.py +28 -0
  175. judgeval/v1/tracer/processors/_lifecycles/agent_id_processor.py +53 -0
  176. judgeval/v1/tracer/processors/_lifecycles/context_keys.py +11 -0
  177. judgeval/v1/tracer/processors/_lifecycles/customer_id_processor.py +29 -0
  178. judgeval/v1/tracer/processors/_lifecycles/registry.py +18 -0
  179. judgeval/v1/tracer/processors/judgment_span_processor.py +165 -0
  180. judgeval/v1/tracer/processors/noop_span_processor.py +42 -0
  181. judgeval/v1/tracer/tracer.py +67 -0
  182. judgeval/v1/tracer/tracer_factory.py +38 -0
  183. judgeval/v1/trainers/__init__.py +5 -0
  184. judgeval/v1/trainers/base_trainer.py +62 -0
  185. judgeval/v1/trainers/config.py +123 -0
  186. judgeval/v1/trainers/console.py +144 -0
  187. judgeval/v1/trainers/fireworks_trainer.py +392 -0
  188. judgeval/v1/trainers/trainable_model.py +252 -0
  189. judgeval/v1/trainers/trainers_factory.py +37 -0
  190. judgeval/v1/utils.py +18 -0
  191. judgeval/version.py +5 -0
  192. judgeval/warnings.py +4 -0
  193. judgeval-0.23.0.dist-info/METADATA +266 -0
  194. judgeval-0.23.0.dist-info/RECORD +201 -0
  195. judgeval-0.23.0.dist-info/entry_points.txt +2 -0
  196. judgeval/clients.py +0 -34
  197. judgeval/common/__init__.py +0 -13
  198. judgeval/common/api/__init__.py +0 -3
  199. judgeval/common/api/api.py +0 -352
  200. judgeval/common/api/constants.py +0 -165
  201. judgeval/common/exceptions.py +0 -27
  202. judgeval/common/storage/__init__.py +0 -6
  203. judgeval/common/storage/s3_storage.py +0 -98
  204. judgeval/common/tracer/__init__.py +0 -31
  205. judgeval/common/tracer/constants.py +0 -22
  206. judgeval/common/tracer/core.py +0 -1916
  207. judgeval/common/tracer/otel_exporter.py +0 -108
  208. judgeval/common/tracer/otel_span_processor.py +0 -234
  209. judgeval/common/tracer/span_processor.py +0 -37
  210. judgeval/common/tracer/span_transformer.py +0 -211
  211. judgeval/common/tracer/trace_manager.py +0 -92
  212. judgeval/common/utils.py +0 -940
  213. judgeval/data/datasets/__init__.py +0 -4
  214. judgeval/data/datasets/dataset.py +0 -341
  215. judgeval/data/datasets/eval_dataset_client.py +0 -214
  216. judgeval/data/tool.py +0 -5
  217. judgeval/data/trace_run.py +0 -37
  218. judgeval/evaluation_run.py +0 -75
  219. judgeval/integrations/langgraph.py +0 -843
  220. judgeval/judges/mixture_of_judges.py +0 -286
  221. judgeval/judgment_client.py +0 -369
  222. judgeval/rules.py +0 -521
  223. judgeval/run_evaluation.py +0 -684
  224. judgeval/scorers/judgeval_scorers/api_scorers/derailment_scorer.py +0 -14
  225. judgeval/scorers/judgeval_scorers/api_scorers/execution_order.py +0 -52
  226. judgeval/scorers/judgeval_scorers/api_scorers/hallucination.py +0 -28
  227. judgeval/scorers/judgeval_scorers/api_scorers/tool_dependency.py +0 -20
  228. judgeval/scorers/judgeval_scorers/api_scorers/tool_order.py +0 -27
  229. judgeval/utils/alerts.py +0 -93
  230. judgeval/utils/requests.py +0 -50
  231. judgeval-0.1.0.dist-info/METADATA +0 -202
  232. judgeval-0.1.0.dist-info/RECORD +0 -73
  233. {judgeval-0.1.0.dist-info → judgeval-0.23.0.dist-info}/WHEEL +0 -0
  234. {judgeval-0.1.0.dist-info → judgeval-0.23.0.dist-info}/licenses/LICENSE.md +0 -0
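Note on the renamed modules in the list above (`judgeval/{common/logger.py → logger.py}`, `judgeval/{version_check.py → utils/version_check.py}`): for downstream code these are import-path changes. A minimal migration sketch, assuming the logger's exported name `judgeval_logger` (visible in the deleted 0.1.0 sources below) is unchanged in 0.23.0; verify against the 0.23.0 sources before relying on it:

# Hypothetical migration for the logger module rename listed above.
# The 0.1.0 path and the `judgeval_logger` name are confirmed by the
# deleted utils.py below; the 0.23.0 path is inferred from the rename entry.

# 0.1.0:
# from judgeval.common.logger import judgeval_logger

# 0.23.0 (assumed):
from judgeval.logger import judgeval_logger

judgeval_logger.info("logger import migrated across the rename")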
judgeval/common/utils.py DELETED
@@ -1,940 +0,0 @@
- """
- This file contains utility functions used in repo scripts
-
- For API calling, we support:
- - parallelized model calls on the same prompt
- - batched model calls on different prompts
-
- NOTE: any function beginning with 'a', e.g. 'afetch_together_api_response', is an asynchronous function
- """
-
- # Standard library imports
- import asyncio
- import concurrent.futures
- import os
- from types import TracebackType
- from judgeval.common.api.constants import ROOT_API
- from judgeval.utils.requests import requests
- import pprint
- from typing import Any, Dict, List, Mapping, Optional, TypeAlias, Union, TypeGuard
-
- # Third-party imports
- import litellm
- import pydantic
- from dotenv import load_dotenv
-
- # Local application/library-specific imports
- from judgeval.clients import async_together_client, together_client
- from judgeval.constants import (
-     ACCEPTABLE_MODELS,
-     MAX_WORKER_THREADS,
-     TOGETHER_SUPPORTED_MODELS,
-     LITELLM_SUPPORTED_MODELS,
- )
- from judgeval.common.logger import judgeval_logger
-
-
- class CustomModelParameters(pydantic.BaseModel):
-     model_name: str
-     secret_key: str
-     litellm_base_url: str
-
-     @pydantic.field_validator("model_name")
-     @classmethod
-     def validate_model_name(cls, v):
-         if not v:
-             raise ValueError("Model name cannot be empty")
-         return v
-
-     @pydantic.field_validator("secret_key")
-     @classmethod
-     def validate_secret_key(cls, v):
-         if not v:
-             raise ValueError("Secret key cannot be empty")
-         return v
-
-     @pydantic.field_validator("litellm_base_url")
-     @classmethod
-     def validate_litellm_base_url(cls, v):
-         if not v:
-             raise ValueError("Litellm base URL cannot be empty")
-         return v
-
-
- class ChatCompletionRequest(pydantic.BaseModel):
-     model: str
-     messages: List[Dict[str, str]]
-     response_format: Optional[Union[pydantic.BaseModel, Dict[str, Any]]] = None
-
-     @pydantic.field_validator("messages")
-     @classmethod
-     def validate_messages(cls, messages):
-         if not messages:
-             raise ValueError("Messages cannot be empty")
-
-         for msg in messages:
-             if not isinstance(msg, dict):
-                 raise TypeError("Message must be a dictionary")
-             if "role" not in msg:
-                 raise ValueError("Message missing required 'role' field")
-             if "content" not in msg:
-                 raise ValueError("Message missing required 'content' field")
-             if msg["role"] not in ["system", "user", "assistant"]:
-                 raise ValueError(
-                     f"Invalid role '{msg['role']}'. Must be 'system', 'user', or 'assistant'"
-                 )
-
-         return messages
-
-     @pydantic.field_validator("model")
-     @classmethod
-     def validate_model(cls, model):
-         if not model:
-             raise ValueError("Model cannot be empty")
-         if model not in ACCEPTABLE_MODELS:
-             raise ValueError(f"Model {model} is not in the list of supported models.")
-         return model
-
-     @pydantic.field_validator("response_format", mode="before")
-     @classmethod
-     def validate_response_format(cls, response_format):
-         if response_format is not None:
-             if not isinstance(response_format, (dict, pydantic.BaseModel)):
-                 raise TypeError(
-                     "Response format must be a dictionary or pydantic model"
-                 )
-             # Optional: Add additional validation for required fields if needed
-             # For example, checking for 'type': 'json' in OpenAI's format
-         return response_format
-
-
- os.environ["LITELLM_LOG"] = "DEBUG"
-
- load_dotenv()
-
-
- def read_file(file_path: str) -> str:
-     with open(file_path, "r", encoding="utf-8") as file:
-         return file.read()
-
-
- def validate_api_key(judgment_api_key: str):
-     """
-     Validates that the user api key is valid
-     """
-     response = requests.post(
-         f"{ROOT_API}/auth/validate_api_key/",
-         headers={
-             "Content-Type": "application/json",
-             "Authorization": f"Bearer {judgment_api_key}",
-         },
-         json={},
-         verify=True,
-     )
-     if response.status_code == 200:
-         return True, response.json()
-     else:
-         return False, response.json().get("detail", "Error validating API key")
-
-
- def fetch_together_api_response(
-     model: str, messages: List[Mapping], response_format: pydantic.BaseModel = None
- ) -> str:
-     """
-     Fetches a single response from the Together API for a given model and messages.
-     """
-     # Validate request
-     if messages is None or messages == []:
-         raise ValueError("Messages cannot be empty")
-
-     request = ChatCompletionRequest(
-         model=model, messages=messages, response_format=response_format
-     )
-
-     if request.response_format is not None:
-         response = together_client.chat.completions.create(
-             model=request.model,
-             messages=request.messages,
-             response_format=request.response_format,
-         )
-     else:
-         response = together_client.chat.completions.create(
-             model=request.model,
-             messages=request.messages,
-         )
-
-     return response.choices[0].message.content
-
-
- async def afetch_together_api_response(
-     model: str, messages: List[Mapping], response_format: pydantic.BaseModel = None
- ) -> str:
-     """
-     ASYNCHRONOUSLY Fetches a single response from the Together API for a given model and messages.
-     """
-     request = ChatCompletionRequest(
-         model=model, messages=messages, response_format=response_format
-     )
-
-     if request.response_format is not None:
-         response = await async_together_client.chat.completions.create(
-             model=request.model,
-             messages=request.messages,
-             response_format=request.response_format,
-         )
-     else:
-         response = await async_together_client.chat.completions.create(
-             model=request.model,
-             messages=request.messages,
-         )
-     return response.choices[0].message.content
-
-
- def query_together_api_multiple_calls(
-     models: List[str],
-     messages: List[List[Mapping]],
-     response_formats: List[pydantic.BaseModel] | None = None,
- ) -> List[Union[str, None]]:
-     """
-     Queries the Together API for multiple calls in parallel
-
-     Args:
-         models (List[str]): List of models to query
-         messages (List[List[Mapping]]): List of messages to query. Each inner object corresponds to a single prompt.
-         response_formats (List[pydantic.BaseModel], optional): A list of the format of the response if JSON forcing. Defaults to None.
-
-     Returns:
-         List[str]: TogetherAI responses for each model and message pair in order. Any exceptions in the thread call result in a None.
-     """
-     # Check for empty models list
-     if not models:
-         raise ValueError("Models list cannot be empty")
-
-     # Validate all models are supported
-     for model in models:
-         if model not in ACCEPTABLE_MODELS:
-             raise ValueError(
-                 f"Model {model} is not in the list of supported models: {ACCEPTABLE_MODELS}."
-             )
-
-     # Validate input lengths match
-     if response_formats is None:
-         response_formats = [None] * len(models)
-     if not (len(models) == len(messages) == len(response_formats)):
-         raise ValueError(
-             "Number of models, messages, and response formats must be the same"
-         )
-
-     # Validate message format
-     validate_batched_chat_messages(messages)
-
-     num_workers = int(os.getenv("NUM_WORKER_THREADS", MAX_WORKER_THREADS))
-     # Initialize results to maintain ordered outputs
-     out: List[str | None] = [None] * len(messages)
-     with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) as executor:
-         # Submit all queries to together API with index, gets back the response content
-         futures = {
-             executor.submit(
-                 fetch_together_api_response, model, message, response_format
-             ): idx
-             for idx, (model, message, response_format) in enumerate(
-                 zip(models, messages, response_formats)
-             )
-         }
-
-         # Collect results as they complete -- result is response content
-         for future in concurrent.futures.as_completed(futures):
-             idx = futures[future]
-             try:
-                 out[idx] = future.result()
-             except Exception as e:
-                 judgeval_logger.error(f"Error in parallel call {idx}: {str(e)}")
-                 out[idx] = None
-     return out
-
-
- async def aquery_together_api_multiple_calls(
-     models: List[str],
-     messages: List[List[Mapping]],
-     response_formats: List[pydantic.BaseModel] | None = None,
- ) -> List[Union[str, None]]:
-     """
-     Queries the Together API for multiple calls in parallel
-
-     Args:
-         models (List[str]): List of models to query
-         messages (List[List[Mapping]]): List of messages to query. Each inner object corresponds to a single prompt.
-         response_formats (List[pydantic.BaseModel], optional): A list of the format of the response if JSON forcing. Defaults to None.
-
-     Returns:
-         List[str]: TogetherAI responses for each model and message pair in order. Any exceptions in the thread call result in a None.
-     """
-     # Check for empty models list
-     if not models:
-         raise ValueError("Models list cannot be empty")
-
-     # Validate all models are supported
-     for model in models:
-         if model not in ACCEPTABLE_MODELS:
-             raise ValueError(
-                 f"Model {model} is not in the list of supported models: {ACCEPTABLE_MODELS}."
-             )
-
-     # Validate input lengths match
-     if response_formats is None:
-         response_formats = [None] * len(models)
-     if not (len(models) == len(messages) == len(response_formats)):
-         raise ValueError(
-             "Number of models, messages, and response formats must be the same"
-         )
-
-     # Validate message format
-     validate_batched_chat_messages(messages)
-
-     out: List[Union[str, None]] = [None] * len(messages)
-
-     async def fetch_and_store(idx, model, message, response_format):
-         try:
-             out[idx] = await afetch_together_api_response(
-                 model, message, response_format
-             )
-         except Exception as e:
-             judgeval_logger.error(f"Error in parallel call {idx}: {str(e)}")
-             out[idx] = None
-
-     tasks = [
-         fetch_and_store(idx, model, message, response_format)
-         for idx, (model, message, response_format) in enumerate(
-             zip(models, messages, response_formats)
-         )
-     ]
-
-     await asyncio.gather(*tasks)
-     return out
-
-
- def fetch_litellm_api_response(
-     model: str, messages: List[Mapping], response_format: pydantic.BaseModel = None
- ) -> str:
-     """
-     Fetches a single response from the Litellm API for a given model and messages.
-     """
-     request = ChatCompletionRequest(
-         model=model, messages=messages, response_format=response_format
-     )
-
-     if request.response_format is not None:
-         response = litellm.completion(
-             model=request.model,
-             messages=request.messages,
-             response_format=request.response_format,
-         )
-     else:
-         response = litellm.completion(
-             model=request.model,
-             messages=request.messages,
-         )
-     return response.choices[0].message.content
-
-
- def fetch_custom_litellm_api_response(
-     custom_model_parameters: CustomModelParameters,
-     messages: List[Mapping],
-     response_format: pydantic.BaseModel = None,
- ) -> str:
-     if messages is None or messages == []:
-         raise ValueError("Messages cannot be empty")
-
-     if custom_model_parameters is None:
-         raise ValueError("Custom model parameters cannot be empty")
-
-     if not isinstance(custom_model_parameters, CustomModelParameters):
-         raise ValueError(
-             "Custom model parameters must be a CustomModelParameters object"
-         )
-
-     if response_format is not None:
-         response = litellm.completion(
-             model=custom_model_parameters.model_name,
-             messages=messages,
-             api_key=custom_model_parameters.secret_key,
-             base_url=custom_model_parameters.litellm_base_url,
-             response_format=response_format,
-         )
-     else:
-         response = litellm.completion(
-             model=custom_model_parameters.model_name,
-             messages=messages,
-             api_key=custom_model_parameters.secret_key,
-             base_url=custom_model_parameters.litellm_base_url,
-         )
-     return response.choices[0].message.content
-
-
- async def afetch_litellm_api_response(
-     model: str, messages: List[Mapping], response_format: pydantic.BaseModel = None
- ) -> str:
-     """
-     ASYNCHRONOUSLY Fetches a single response from the Litellm API for a given model and messages.
-     """
-     if messages is None or messages == []:
-         raise ValueError("Messages cannot be empty")
-
-     # Add validation
-     validate_chat_messages(messages)
-
-     if model not in ACCEPTABLE_MODELS:
-         raise ValueError(
-             f"Model {model} is not in the list of supported models: {ACCEPTABLE_MODELS}."
-         )
-
-     if response_format is not None:
-         response = await litellm.acompletion(
-             model=model, messages=messages, response_format=response_format
-         )
-     else:
-         response = await litellm.acompletion(
-             model=model,
-             messages=messages,
-         )
-     return response.choices[0].message.content
-
-
- async def afetch_custom_litellm_api_response(
-     custom_model_parameters: CustomModelParameters,
-     messages: List[Mapping],
-     response_format: pydantic.BaseModel = None,
- ) -> str:
-     """
-     ASYNCHRONOUSLY Fetches a single response from the Litellm API for a given model and messages.
-     """
-     if messages is None or messages == []:
-         raise ValueError("Messages cannot be empty")
-
-     if custom_model_parameters is None:
-         raise ValueError("Custom model parameters cannot be empty")
-
-     if not isinstance(custom_model_parameters, CustomModelParameters):
-         raise ValueError(
-             "Custom model parameters must be a CustomModelParameters object"
-         )
-
-     if response_format is not None:
-         response = await litellm.acompletion(
-             model=custom_model_parameters.model_name,
-             messages=messages,
-             api_key=custom_model_parameters.secret_key,
-             base_url=custom_model_parameters.litellm_base_url,
-             response_format=response_format,
-         )
-     else:
-         response = await litellm.acompletion(
-             model=custom_model_parameters.model_name,
-             messages=messages,
-             api_key=custom_model_parameters.secret_key,
-             base_url=custom_model_parameters.litellm_base_url,
-         )
-     return response.choices[0].message.content
-
-
- def query_litellm_api_multiple_calls(
-     models: List[str],
-     messages: List[List[Mapping]],
-     response_formats: List[pydantic.BaseModel] | None = None,
- ) -> List[Union[str, None]]:
-     """
-     Queries the Litellm API for multiple calls in parallel
-
-     Args:
-         models (List[str]): List of models to query
-         messages (List[List[Mapping]]): List of messages to query
-         response_formats (List[pydantic.BaseModel], optional): A list of the format of the response if JSON forcing. Defaults to None.
-
-     Returns:
-         List[str]: Litellm responses for each model and message pair in order. Any exceptions in the thread call result in a None.
-     """
-     num_workers = int(os.getenv("NUM_WORKER_THREADS", MAX_WORKER_THREADS))
-     # Initialize results to maintain ordered outputs
-     out: List[Union[str, None]] = [None] * len(messages)
-     with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) as executor:
-         # Submit all queries to Litellm API with index, gets back the response content
-         futures = {
-             executor.submit(
-                 fetch_litellm_api_response, model, message, response_format
-             ): idx
-             for idx, (model, message, response_format) in enumerate(
-                 zip(models, messages, response_formats or [None] * len(messages))
-             )
-         }
-
-         # Collect results as they complete -- result is response content
-         for future in concurrent.futures.as_completed(futures):
-             idx = futures[future]
-             try:
-                 out[idx] = future.result()
-             except Exception as e:
-                 judgeval_logger.error(f"Error in parallel call {idx}: {str(e)}")
-                 out[idx] = None
-     return out
-
-
- async def aquery_litellm_api_multiple_calls(
-     models: List[str],
-     messages: List[List[Mapping]],
-     response_formats: List[pydantic.BaseModel] | None = None,
- ) -> List[Union[str, None]]:
-     """
-     Queries the Litellm API for multiple calls in parallel
-
-     Args:
-         models (List[str]): List of models to query
-         messages (List[List[Mapping]]): List of messages to query
-         response_formats (List[pydantic.BaseModel], optional): A list of the format of the response if JSON forcing. Defaults to None.
-
-     Returns:
-         List[str]: Litellm responses for each model and message pair in order. Any exceptions in the thread call result in a None.
-     """
-     # Initialize results to maintain ordered outputs
-     out: List[Union[str, None]] = [None] * len(messages)
-
-     async def fetch_and_store(idx, model, message, response_format):
-         try:
-             out[idx] = await afetch_litellm_api_response(
-                 model, message, response_format
-             )
-         except Exception as e:
-             judgeval_logger.error(f"Error in parallel call {idx}: {str(e)}")
-             out[idx] = None
-
-     tasks = [
-         fetch_and_store(idx, model, message, response_format)
-         for idx, (model, message, response_format) in enumerate(
-             zip(models, messages, response_formats or [None] * len(messages))
-         )
-     ]
-
-     await asyncio.gather(*tasks)
-     return out
-
-
- def validate_chat_messages(messages, batched: bool = False):
-     """Validate chat message format before API call"""
-     if not isinstance(messages, list):
-         raise TypeError("Messages must be a list")
-
-     for msg in messages:
-         if not isinstance(msg, dict):
-             if batched and not isinstance(msg, list):
-                 raise TypeError("Each message must be a list")
-             elif not batched:
-                 raise TypeError("Message must be a dictionary")
-         if "role" not in msg:
-             raise ValueError("Message missing required 'role' field")
-         if "content" not in msg:
-             raise ValueError("Message missing required 'content' field")
-         if msg["role"] not in ["system", "user", "assistant"]:
-             raise ValueError(
-                 f"Invalid role '{msg['role']}'. Must be 'system', 'user', or 'assistant'"
-             )
-
-
- def validate_batched_chat_messages(messages):
-     """
-     Validate format of batched chat messages before API call
-
-     Args:
-         messages (List[List[Mapping]]): List of message lists, where each inner list contains
-             message dictionaries with 'role' and 'content' fields
-
-     Raises:
-         TypeError: If messages format is invalid
-         ValueError: If message content is invalid
-     """
-     if not isinstance(messages, list):
-         raise TypeError("Batched messages must be a list")
-
-     if not messages:
-         raise ValueError("Batched messages cannot be empty")
-
-     for message_list in messages:
-         if not isinstance(message_list, list):
-             raise TypeError("Each batch item must be a list of messages")
-
-         # Validate individual messages using existing function
-         validate_chat_messages(message_list)
-
-
- def is_batched_messages(
-     messages: Union[List[Mapping], List[List[Mapping]]],
- ) -> TypeGuard[List[List[Mapping]]]:
-     return isinstance(messages, list) and all(isinstance(msg, list) for msg in messages)
-
-
- def is_simple_messages(
-     messages: Union[List[Mapping], List[List[Mapping]]],
- ) -> TypeGuard[List[Mapping]]:
-     return isinstance(messages, list) and all(
-         not isinstance(msg, list) for msg in messages
-     )
-
-
- def get_chat_completion(
-     model_type: str,
-     messages: Union[List[Mapping], List[List[Mapping]]],
-     response_format: pydantic.BaseModel = None,
-     batched: bool = False,
- ) -> Union[str, List[str | None]]:
-     """
-     Generates chat completions using a single model and potentially several messages. Supports closed-source and OSS models.
-
-     Parameters:
-         - model_type (str): The type of model to use for generating completions.
-         - messages (Union[List[Mapping], List[List[Mapping]]]): The messages to be used for generating completions.
-             If batched is True, this should be a list of lists of mappings.
-         - response_format (pydantic.BaseModel, optional): The format of the response. Defaults to None.
-         - batched (bool, optional): Whether to process messages in batch mode. Defaults to False.
-     Returns:
-         - str: The generated chat completion(s). If batched is True, returns a list of strings.
-     Raises:
-         - ValueError: If requested model is not supported by Litellm or TogetherAI.
-     """
-
-     # Check for empty messages list
-     if not messages or messages == []:
-         raise ValueError("Messages cannot be empty")
-
-     # Add validation
-     if batched:
-         validate_batched_chat_messages(messages)
-     else:
-         validate_chat_messages(messages)
-
-     if (
-         batched
-         and is_batched_messages(messages)
-         and model_type in TOGETHER_SUPPORTED_MODELS
-     ):
-         return query_together_api_multiple_calls(
-             models=[model_type] * len(messages),
-             messages=messages,
-             response_formats=[response_format] * len(messages),
-         )
-     elif (
-         batched
-         and is_batched_messages(messages)
-         and model_type in LITELLM_SUPPORTED_MODELS
-     ):
-         return query_litellm_api_multiple_calls(
-             models=[model_type] * len(messages),
-             messages=messages,
-             response_formats=[response_format] * len(messages),
-         )
-     elif (
-         not batched
-         and is_simple_messages(messages)
-         and model_type in TOGETHER_SUPPORTED_MODELS
-     ):
-         return fetch_together_api_response(
-             model=model_type, messages=messages, response_format=response_format
-         )
-     elif (
-         not batched
-         and is_simple_messages(messages)
-         and model_type in LITELLM_SUPPORTED_MODELS
-     ):
-         return fetch_litellm_api_response(
-             model=model_type, messages=messages, response_format=response_format
-         )
-
-     raise ValueError(
-         f"Model {model_type} is not supported by Litellm or TogetherAI for chat completions. Please check the model name and try again."
-     )
-
-
- async def aget_chat_completion(
-     model_type: str,
-     messages: Union[List[Mapping], List[List[Mapping]]],
-     response_format: pydantic.BaseModel = None,
-     batched: bool = False,
- ) -> Union[str, List[str | None]]:
-     """
-     ASYNCHRONOUSLY generates chat completions using a single model and potentially several messages. Supports closed-source and OSS models.
-
-     Parameters:
-         - model_type (str): The type of model to use for generating completions.
-         - messages (Union[List[Mapping], List[List[Mapping]]]): The messages to be used for generating completions.
-             If batched is True, this should be a list of lists of mappings.
-         - response_format (pydantic.BaseModel, optional): The format of the response. Defaults to None.
-         - batched (bool, optional): Whether to process messages in batch mode. Defaults to False.
-     Returns:
-         - str: The generated chat completion(s). If batched is True, returns a list of strings.
-     Raises:
-         - ValueError: If requested model is not supported by Litellm or TogetherAI.
-     """
-
-     if batched:
-         validate_batched_chat_messages(messages)
-     else:
-         validate_chat_messages(messages)
-
-     if (
-         batched
-         and is_batched_messages(messages)
-         and model_type in TOGETHER_SUPPORTED_MODELS
-     ):
-         return await aquery_together_api_multiple_calls(
-             models=[model_type] * len(messages),
-             messages=messages,
-             response_formats=[response_format] * len(messages),
-         )
-     elif (
-         batched
-         and is_batched_messages(messages)
-         and model_type in LITELLM_SUPPORTED_MODELS
-     ):
-         return await aquery_litellm_api_multiple_calls(
-             models=[model_type] * len(messages),
-             messages=messages,
-             response_formats=[response_format] * len(messages),
-         )
-     elif (
-         not batched
-         and is_simple_messages(messages)
-         and model_type in TOGETHER_SUPPORTED_MODELS
-     ):
-         return await afetch_together_api_response(
-             model=model_type, messages=messages, response_format=response_format
-         )
-     elif (
-         not batched
-         and is_simple_messages(messages)
-         and model_type in LITELLM_SUPPORTED_MODELS
-     ):
-         return await afetch_litellm_api_response(
-             model=model_type, messages=messages, response_format=response_format
-         )
-
-     judgeval_logger.error(f"Model {model_type} not supported by either API")
-     raise ValueError(
-         f"Model {model_type} is not supported by Litellm or TogetherAI for chat completions. Please check the model name and try again."
-     )
-
-
- def get_completion_multiple_models(
-     models: List[str],
-     messages: List[List[Mapping]],
-     response_formats: List[pydantic.BaseModel] | None = None,
- ) -> List[str | None]:
-     """
-     Retrieves completions for a single prompt from multiple models in parallel. Supports closed-source and OSS models.
-
-     Args:
-         models (List[str]): List of models to query
-         messages (List[List[Mapping]]): List of messages to query. Each inner object corresponds to a single prompt.
-         response_formats (List[pydantic.BaseModel], optional): A list of the format of the response if JSON forcing. Defaults to None.
-
-     Returns:
-         List[str]: List of completions from the models in the order of the input models
-     Raises:
-         ValueError: If a model is not supported by Litellm or Together
-     """
-
-     if models is None or models == []:
-         raise ValueError("Models list cannot be empty")
-
-     validate_batched_chat_messages(messages)
-
-     if len(models) != len(messages):
-         judgeval_logger.error(
-             f"Model/message count mismatch: {len(models)} vs {len(messages)}"
-         )
-         raise ValueError(
-             f"Number of models and messages must be the same: {len(models)} != {len(messages)}"
-         )
-     if response_formats is None:
-         response_formats = [None] * len(models)
-     # Partition the model requests into TogetherAI and Litellm models, but keep the ordering saved
-     together_calls, litellm_calls = {}, {}  # index -> model, message, response_format
-     together_responses, litellm_responses = [], []
-     for idx, (model, message, r_format) in enumerate(
-         zip(models, messages, response_formats)
-     ):
-         if model in TOGETHER_SUPPORTED_MODELS:
-             together_calls[idx] = (model, message, r_format)
-         elif model in LITELLM_SUPPORTED_MODELS:
-             litellm_calls[idx] = (model, message, r_format)
-         else:
-             judgeval_logger.error(f"Model {model} not supported by either API")
-             raise ValueError(
-                 f"Model {model} is not supported by Litellm or TogetherAI for chat completions. Please check the model name and try again."
-             )
-
-     # Add validation before processing
-     for msg_list in messages:
-         validate_chat_messages(msg_list)
-
-     # Get the responses from the TogetherAI models
-     # List of responses from the TogetherAI models in order of the together_calls dict
-     if together_calls:
-         together_responses = query_together_api_multiple_calls(
-             models=[model for model, _, _ in together_calls.values()],
-             messages=[message for _, message, _ in together_calls.values()],
-             response_formats=[format for _, _, format in together_calls.values()],
-         )
-
-     # Get the responses from the Litellm models
-     if litellm_calls:
-         litellm_responses = query_litellm_api_multiple_calls(
-             models=[model for model, _, _ in litellm_calls.values()],
-             messages=[message for _, message, _ in litellm_calls.values()],
-             response_formats=[format for _, _, format in litellm_calls.values()],
-         )
-
-     # Merge the responses in the order of the original models
-     out: List[Union[str, None]] = [None] * len(models)
-     for idx, (model, message, r_format) in together_calls.items():
-         out[idx] = together_responses.pop(0)
-     for idx, (model, message, r_format) in litellm_calls.items():
-         out[idx] = litellm_responses.pop(0)
-     return out
-
-
- async def aget_completion_multiple_models(
-     models: List[str],
-     messages: List[List[Mapping]],
-     response_formats: List[pydantic.BaseModel] | None = None,
- ) -> List[str | None]:
-     """
-     ASYNCHRONOUSLY retrieves completions for a single prompt from multiple models in parallel. Supports closed-source and OSS models.
-
-     Args:
-         models (List[str]): List of models to query
-         messages (List[List[Mapping]]): List of messages to query. Each inner object corresponds to a single prompt.
-         response_formats (List[pydantic.BaseModel], optional): A list of the format of the response if JSON forcing. Defaults to None.
-
-     Returns:
-         List[str]: List of completions from the models in the order of the input models
-     Raises:
-         ValueError: If a model is not supported by Litellm or Together
-     """
-     if models is None or models == []:
-         raise ValueError("Models list cannot be empty")
-
-     if len(models) != len(messages):
-         raise ValueError(
-             f"Number of models and messages must be the same: {len(models)} != {len(messages)}"
-         )
-     if response_formats is None:
-         response_formats = [None] * len(models)
-
-     validate_batched_chat_messages(messages)
-
-     # Partition the model requests into TogetherAI and Litellm models, but keep the ordering saved
-     together_calls, litellm_calls = {}, {}  # index -> model, message, response_format
-     together_responses, litellm_responses = [], []
-     for idx, (model, message, r_format) in enumerate(
-         zip(models, messages, response_formats)
-     ):
-         if model in TOGETHER_SUPPORTED_MODELS:
-             together_calls[idx] = (model, message, r_format)
-         elif model in LITELLM_SUPPORTED_MODELS:
-             litellm_calls[idx] = (model, message, r_format)
-         else:
-             raise ValueError(
-                 f"Model {model} is not supported by Litellm or TogetherAI for chat completions. Please check the model name and try again."
-             )
-
-     # Add validation before processing
-     for msg_list in messages:
-         validate_chat_messages(msg_list)
-
-     # Get the responses from the TogetherAI models
-     # List of responses from the TogetherAI models in order of the together_calls dict
-     if together_calls:
-         together_responses = await aquery_together_api_multiple_calls(
-             models=[model for model, _, _ in together_calls.values()],
-             messages=[message for _, message, _ in together_calls.values()],
-             response_formats=[format for _, _, format in together_calls.values()],
-         )
-
-     # Get the responses from the Litellm models
-     if litellm_calls:
-         litellm_responses = await aquery_litellm_api_multiple_calls(
-             models=[model for model, _, _ in litellm_calls.values()],
-             messages=[message for _, message, _ in litellm_calls.values()],
-             response_formats=[format for _, _, format in litellm_calls.values()],
-         )
-
-     # Merge the responses in the order of the original models
-     out: List[Union[str, None]] = [None] * len(models)
-     for idx, (model, message, r_format) in together_calls.items():
-         out[idx] = together_responses.pop(0)
-     for idx, (model, message, r_format) in litellm_calls.items():
-         out[idx] = litellm_responses.pop(0)
-     return out
-
-
- if __name__ == "__main__":
-     batched_messages: List[List[Mapping]] = [
-         [
-             {"role": "system", "content": "You are a helpful assistant."},
-             {"role": "user", "content": "What is the capital of France?"},
-         ],
-         [
-             {"role": "system", "content": "You are a helpful assistant."},
-             {"role": "user", "content": "What is the capital of Japan?"},
-         ],
-     ]
-
-     non_batched_messages: List[Mapping] = [
-         {"role": "system", "content": "You are a helpful assistant."},
-         {"role": "user", "content": "What is the capital of France?"},
-     ]
-
-     batched_messages_2: List[List[Mapping]] = [
-         [
-             {"role": "system", "content": "You are a helpful assistant."},
-             {"role": "user", "content": "What is the capital of China?"},
-         ],
-         [
-             {"role": "system", "content": "You are a helpful assistant."},
-             {"role": "user", "content": "What is the capital of France?"},
-         ],
-         [
-             {"role": "system", "content": "You are a helpful assistant."},
-             {"role": "user", "content": "What is the capital of Japan?"},
-         ],
-     ]
-
-     # Batched
-     pprint.pprint(
-         get_chat_completion(
-             model_type="LLAMA3_405B_INSTRUCT_TURBO",
-             messages=batched_messages,
-             batched=True,
-         )
-     )
-
-     # Non batched
-     pprint.pprint(
-         get_chat_completion(
-             model_type="LLAMA3_8B_INSTRUCT_TURBO",
-             messages=non_batched_messages,
-             batched=False,
-         )
-     )
-
-     # Batched single completion to multiple models
-     pprint.pprint(
-         get_completion_multiple_models(
-             models=[
-                 "LLAMA3_70B_INSTRUCT_TURBO",
-                 "LLAMA3_405B_INSTRUCT_TURBO",
-                 "gpt-4.1-mini",
-             ],
-             messages=batched_messages_2,
-         )
-     )
-
- ExcInfo: TypeAlias = tuple[type[BaseException], BaseException, TracebackType]
- OptExcInfo: TypeAlias = ExcInfo | tuple[None, None, None]
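The deleted module wrapped litellm and the Together SDK behind helpers such as fetch_litellm_api_response. Code that imported those helpers from judgeval.common.utils can call litellm directly; a minimal sketch under that assumption, using only the calls the deleted code itself made (the model name is taken from the example block above, not a recommendation):

# Minimal stand-in for the removed fetch_litellm_api_response helper.
import litellm

def fetch_response(model: str, messages: list) -> str:
    # Same call the deleted helper made, minus the ChatCompletionRequest validation.
    response = litellm.completion(model=model, messages=messages)
    return response.choices[0].message.content

print(fetch_response("gpt-4.1-mini", [{"role": "user", "content": "Hello"}]))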