judgeval 0.0.44__py3-none-any.whl → 0.0.46__py3-none-any.whl
This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between these versions as they appear in their respective public registries.
- judgeval/__init__.py +5 -4
- judgeval/clients.py +6 -6
- judgeval/common/__init__.py +7 -2
- judgeval/common/exceptions.py +2 -3
- judgeval/common/logger.py +74 -49
- judgeval/common/s3_storage.py +30 -23
- judgeval/common/tracer.py +1273 -939
- judgeval/common/utils.py +416 -244
- judgeval/constants.py +73 -61
- judgeval/data/__init__.py +1 -1
- judgeval/data/custom_example.py +3 -2
- judgeval/data/datasets/dataset.py +80 -54
- judgeval/data/datasets/eval_dataset_client.py +131 -181
- judgeval/data/example.py +67 -43
- judgeval/data/result.py +11 -9
- judgeval/data/scorer_data.py +4 -2
- judgeval/data/tool.py +25 -16
- judgeval/data/trace.py +57 -29
- judgeval/data/trace_run.py +5 -11
- judgeval/evaluation_run.py +22 -82
- judgeval/integrations/langgraph.py +546 -184
- judgeval/judges/base_judge.py +1 -2
- judgeval/judges/litellm_judge.py +33 -11
- judgeval/judges/mixture_of_judges.py +128 -78
- judgeval/judges/together_judge.py +22 -9
- judgeval/judges/utils.py +14 -5
- judgeval/judgment_client.py +259 -271
- judgeval/rules.py +169 -142
- judgeval/run_evaluation.py +462 -305
- judgeval/scorers/api_scorer.py +20 -11
- judgeval/scorers/exceptions.py +1 -0
- judgeval/scorers/judgeval_scorer.py +77 -58
- judgeval/scorers/judgeval_scorers/api_scorers/__init__.py +46 -15
- judgeval/scorers/judgeval_scorers/api_scorers/answer_correctness.py +3 -2
- judgeval/scorers/judgeval_scorers/api_scorers/answer_relevancy.py +3 -2
- judgeval/scorers/judgeval_scorers/api_scorers/classifier_scorer.py +12 -11
- judgeval/scorers/judgeval_scorers/api_scorers/comparison.py +7 -5
- judgeval/scorers/judgeval_scorers/api_scorers/contextual_precision.py +3 -2
- judgeval/scorers/judgeval_scorers/api_scorers/contextual_recall.py +3 -2
- judgeval/scorers/judgeval_scorers/api_scorers/contextual_relevancy.py +5 -2
- judgeval/scorers/judgeval_scorers/api_scorers/derailment_scorer.py +2 -1
- judgeval/scorers/judgeval_scorers/api_scorers/execution_order.py +17 -8
- judgeval/scorers/judgeval_scorers/api_scorers/faithfulness.py +3 -2
- judgeval/scorers/judgeval_scorers/api_scorers/groundedness.py +3 -2
- judgeval/scorers/judgeval_scorers/api_scorers/hallucination.py +3 -2
- judgeval/scorers/judgeval_scorers/api_scorers/instruction_adherence.py +3 -2
- judgeval/scorers/judgeval_scorers/api_scorers/json_correctness.py +8 -9
- judgeval/scorers/judgeval_scorers/api_scorers/summarization.py +4 -4
- judgeval/scorers/judgeval_scorers/api_scorers/tool_dependency.py +5 -5
- judgeval/scorers/judgeval_scorers/api_scorers/tool_order.py +5 -2
- judgeval/scorers/judgeval_scorers/classifiers/text2sql/text2sql_scorer.py +9 -10
- judgeval/scorers/prompt_scorer.py +48 -37
- judgeval/scorers/score.py +86 -53
- judgeval/scorers/utils.py +11 -7
- judgeval/tracer/__init__.py +1 -1
- judgeval/utils/alerts.py +23 -12
- judgeval/utils/{data_utils.py → file_utils.py} +5 -9
- judgeval/utils/requests.py +29 -0
- judgeval/version_check.py +5 -2
- {judgeval-0.0.44.dist-info → judgeval-0.0.46.dist-info}/METADATA +79 -135
- judgeval-0.0.46.dist-info/RECORD +69 -0
- judgeval-0.0.44.dist-info/RECORD +0 -68
- {judgeval-0.0.44.dist-info → judgeval-0.0.46.dist-info}/WHEEL +0 -0
- {judgeval-0.0.44.dist-info → judgeval-0.0.46.dist-info}/licenses/LICENSE.md +0 -0
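One change worth calling out from the list above: HTTP calls now go through a new `judgeval/utils/requests.py` module, and the `judgeval/common/utils.py` diff below switches from `import requests` to `from judgeval.utils.requests import requests`. The contents of that new module are not shown in this section; the sketch below is only a guess at what such a wrapper commonly looks like (a `requests.Session` preconfigured with retries, re-exported under the name `requests`) and is not the package's actual code.

```python
# Hypothetical sketch of a judgeval/utils/requests.py-style wrapper.
# The real module is not shown in this diff; the retry policy and the
# exposed name are illustrative assumptions only.
import requests as _requests
from requests.adapters import HTTPAdapter
from urllib3.util.retry import Retry


def _build_session() -> _requests.Session:
    # Retry transient failures before surfacing an error to the caller.
    session = _requests.Session()
    retry = Retry(total=3, backoff_factor=0.5, status_forcelist=(502, 503, 504))
    adapter = HTTPAdapter(max_retries=retry)
    session.mount("http://", adapter)
    session.mount("https://", adapter)
    return session


# Re-exported as `requests` so call sites such as validate_api_key() can keep
# writing `requests.post(...)` after only changing their import.
requests = _build_session()
```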
judgeval/common/utils.py
CHANGED
```diff
@@ -2,7 +2,7 @@
 
 This file contains utility functions used in repo scripts
 
 For API calling, we support:
-    - parallelized model calls on the same prompt
+    - parallelized model calls on the same prompt
     - batched model calls on different prompts
 
 NOTE: any function beginning with 'a', e.g. 'afetch_together_api_response', is an asynchronous function
@@ -13,9 +13,9 @@ import asyncio
 import concurrent.futures
 import os
 from types import TracebackType
-import requests
+from judgeval.utils.requests import requests
 import pprint
-from typing import Any, Dict, List,
+from typing import Any, Dict, List, Mapping, Optional, TypeAlias, Union, TypeGuard
 
 # Third-party imports
 import litellm
@@ -24,7 +24,13 @@ from dotenv import load_dotenv
 
 # Local application/library-specific imports
 from judgeval.clients import async_together_client, together_client
-from judgeval.constants import
+from judgeval.constants import (
+    ACCEPTABLE_MODELS,
+    MAX_WORKER_THREADS,
+    ROOT_API,
+    TOGETHER_SUPPORTED_MODELS,
+    LITELLM_SUPPORTED_MODELS,
+)
 from judgeval.common.logger import debug, error
 
 
@@ -32,72 +38,80 @@ class CustomModelParameters(pydantic.BaseModel):
     model_name: str
     secret_key: str
     litellm_base_url: str
-
-    @pydantic.field_validator(
+
+    @pydantic.field_validator("model_name")
     def validate_model_name(cls, v):
         if not v:
             raise ValueError("Model name cannot be empty")
         return v
-
-    @pydantic.field_validator(
+
+    @pydantic.field_validator("secret_key")
     def validate_secret_key(cls, v):
         if not v:
             raise ValueError("Secret key cannot be empty")
         return v
-
-    @pydantic.field_validator(
+
+    @pydantic.field_validator("litellm_base_url")
     def validate_litellm_base_url(cls, v):
         if not v:
             raise ValueError("Litellm base URL cannot be empty")
         return v
 
+
 class ChatCompletionRequest(pydantic.BaseModel):
     model: str
     messages: List[Dict[str, str]]
     response_format: Optional[Union[pydantic.BaseModel, Dict[str, Any]]] = None
-
-    @pydantic.field_validator(
+
+    @pydantic.field_validator("messages")
    def validate_messages(cls, messages):
        if not messages:
            raise ValueError("Messages cannot be empty")
-
+
        for msg in messages:
            if not isinstance(msg, dict):
                raise TypeError("Message must be a dictionary")
-            if
+            if "role" not in msg:
                raise ValueError("Message missing required 'role' field")
-            if
+            if "content" not in msg:
                raise ValueError("Message missing required 'content' field")
-            if msg[
-                raise ValueError(
-
+            if msg["role"] not in ["system", "user", "assistant"]:
+                raise ValueError(
+                    f"Invalid role '{msg['role']}'. Must be 'system', 'user', or 'assistant'"
+                )
+
        return messages
 
-    @pydantic.field_validator(
+    @pydantic.field_validator("model")
    def validate_model(cls, model):
        if not model:
            raise ValueError("Model cannot be empty")
        if model not in ACCEPTABLE_MODELS:
            raise ValueError(f"Model {model} is not in the list of supported models.")
        return model
-
-    @pydantic.field_validator(
+
+    @pydantic.field_validator("response_format", mode="before")
    def validate_response_format(cls, response_format):
        if response_format is not None:
            if not isinstance(response_format, (dict, pydantic.BaseModel)):
-                raise TypeError(
+                raise TypeError(
+                    "Response format must be a dictionary or pydantic model"
+                )
            # Optional: Add additional validation for required fields if needed
            # For example, checking for 'type': 'json' in OpenAI's format
        return response_format
 
-
+
+os.environ["LITELLM_LOG"] = "DEBUG"
 
 load_dotenv()
 
+
 def read_file(file_path: str) -> str:
-    with open(file_path, "r", encoding=
+    with open(file_path, "r", encoding="utf-8") as file:
        return file.read()
 
+
 def validate_api_key(judgment_api_key: str):
    """
    Validates that the user api key is valid
@@ -109,66 +123,67 @@ def validate_api_key(judgment_api_key: str):
            "Authorization": f"Bearer {judgment_api_key}",
        },
        json={},  # Empty body now
-        verify=True
+        verify=True,
    )
    if response.status_code == 200:
        return True, response.json()
    else:
        return False, response.json().get("detail", "Error validating API key")
 
-
+
+def fetch_together_api_response(
+    model: str, messages: List[Mapping], response_format: pydantic.BaseModel = None
+) -> str:
    """
    Fetches a single response from the Together API for a given model and messages.
    """
    # Validate request
    if messages is None or messages == []:
        raise ValueError("Messages cannot be empty")
-
+
    request = ChatCompletionRequest(
-
-        messages=messages,
-        response_format=response_format
+        model=model, messages=messages, response_format=response_format
    )
-
+
    debug(f"Calling Together API with model: {request.model}")
    debug(f"Messages: {request.messages}")
-
+
    if request.response_format is not None:
        debug(f"Using response format: {request.response_format}")
        response = together_client.chat.completions.create(
            model=request.model,
            messages=request.messages,
-            response_format=request.response_format
+            response_format=request.response_format,
        )
    else:
        response = together_client.chat.completions.create(
            model=request.model,
            messages=request.messages,
        )
-
+
    debug(f"Received response: {response.choices[0].message.content[:100]}...")
    return response.choices[0].message.content
 
 
-async def afetch_together_api_response(
+async def afetch_together_api_response(
+    model: str, messages: List[Mapping], response_format: pydantic.BaseModel = None
+) -> str:
    """
    ASYNCHRONOUSLY Fetches a single response from the Together API for a given model and messages.
    """
    request = ChatCompletionRequest(
-        model=model,
-        messages=messages,
-        response_format=response_format
+        model=model, messages=messages, response_format=response_format
    )
-
+
    debug(f"Calling Together API with model: {request.model}")
    debug(f"Messages: {request.messages}")
-
+
    if request.response_format is not None:
        debug(f"Using response format: {request.response_format}")
        response = await async_together_client.chat.completions.create(
            model=request.model,
            messages=request.messages,
-            response_format=request.response_format
+            response_format=request.response_format,
        )
    else:
        response = await async_together_client.chat.completions.create(
@@ -178,7 +193,11 @@ async def afetch_together_api_response(model: str, messages: List[Mapping], resp
    return response.choices[0].message.content
 
 
-def query_together_api_multiple_calls(
+def query_together_api_multiple_calls(
+    models: List[str],
+    messages: List[List[Mapping]],
+    response_formats: List[pydantic.BaseModel] | None = None,
+) -> List[Union[str, None]]:
    """
    Queries the Together API for multiple calls in parallel
 
@@ -197,25 +216,35 @@ def query_together_api_multiple_calls(models: List[str], messages: List[List[Map
    # Validate all models are supported
    for model in models:
        if model not in ACCEPTABLE_MODELS:
-            raise ValueError(
+            raise ValueError(
+                f"Model {model} is not in the list of supported models: {ACCEPTABLE_MODELS}."
+            )
 
    # Validate input lengths match
    if response_formats is None:
        response_formats = [None] * len(models)
    if not (len(models) == len(messages) == len(response_formats)):
-        raise ValueError(
+        raise ValueError(
+            "Number of models, messages, and response formats must be the same"
+        )
 
    # Validate message format
    validate_batched_chat_messages(messages)
 
-    num_workers = int(os.getenv(
+    num_workers = int(os.getenv("NUM_WORKER_THREADS", MAX_WORKER_THREADS))
    # Initialize results to maintain ordered outputs
-    out = [None] * len(messages)
+    out: List[str | None] = [None] * len(messages)
    with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) as executor:
        # Submit all queries to together API with index, gets back the response content
-        futures = {
-
-
+        futures = {
+            executor.submit(
+                fetch_together_api_response, model, message, response_format
+            ): idx
+            for idx, (model, message, response_format) in enumerate(
+                zip(models, messages, response_formats)
+            )
+        }
+
        # Collect results as they complete -- result is response content
        for future in concurrent.futures.as_completed(futures):
            idx = futures[future]
@@ -223,11 +252,15 @@ def query_together_api_multiple_calls(models: List[str], messages: List[List[Map
                out[idx] = future.result()
            except Exception as e:
                error(f"Error in parallel call {idx}: {str(e)}")
-                out[idx] = None
+                out[idx] = None
    return out
 
 
-async def aquery_together_api_multiple_calls(
+async def aquery_together_api_multiple_calls(
+    models: List[str],
+    messages: List[List[Mapping]],
+    response_formats: List[pydantic.BaseModel] | None = None,
+) -> List[Union[str, None]]:
    """
    Queries the Together API for multiple calls in parallel
 
@@ -246,57 +279,65 @@ async def aquery_together_api_multiple_calls(models: List[str], messages: List[L
    # Validate all models are supported
    for model in models:
        if model not in ACCEPTABLE_MODELS:
-            raise ValueError(
+            raise ValueError(
+                f"Model {model} is not in the list of supported models: {ACCEPTABLE_MODELS}."
+            )
 
    # Validate input lengths match
    if response_formats is None:
        response_formats = [None] * len(models)
    if not (len(models) == len(messages) == len(response_formats)):
-        raise ValueError(
+        raise ValueError(
+            "Number of models, messages, and response formats must be the same"
+        )
 
    # Validate message format
    validate_batched_chat_messages(messages)
 
    debug(f"Starting parallel Together API calls for {len(messages)} messages")
-    out = [None] * len(messages)
-
+    out: List[Union[str, None]] = [None] * len(messages)
+
    async def fetch_and_store(idx, model, message, response_format):
        try:
            debug(f"Processing call {idx} with model {model}")
-            out[idx] = await afetch_together_api_response(
+            out[idx] = await afetch_together_api_response(
+                model, message, response_format
+            )
        except Exception as e:
            error(f"Error in parallel call {idx}: {str(e)}")
            out[idx] = None
 
    tasks = [
        fetch_and_store(idx, model, message, response_format)
-        for idx, (model, message, response_format) in enumerate(
+        for idx, (model, message, response_format) in enumerate(
+            zip(models, messages, response_formats)
+        )
    ]
-
+
    await asyncio.gather(*tasks)
    debug(f"Completed {len(messages)} parallel calls")
    return out
 
 
-def fetch_litellm_api_response(
+def fetch_litellm_api_response(
+    model: str, messages: List[Mapping], response_format: pydantic.BaseModel = None
+) -> str:
    """
    Fetches a single response from the Litellm API for a given model and messages.
    """
    request = ChatCompletionRequest(
-        model=model,
-        messages=messages,
-        response_format=response_format
+        model=model, messages=messages, response_format=response_format
    )
-
+
    debug(f"Calling LiteLLM API with model: {request.model}")
    debug(f"Messages: {request.messages}")
-
+
    if request.response_format is not None:
        debug(f"Using response format: {request.response_format}")
        response = litellm.completion(
            model=request.model,
            messages=request.messages,
-            response_format=request.response_format
+            response_format=request.response_format,
        )
    else:
        response = litellm.completion(
@@ -306,23 +347,29 @@ def fetch_litellm_api_response(model: str, messages: List[Mapping], response_for
    return response.choices[0].message.content
 
 
-def fetch_custom_litellm_api_response(
+def fetch_custom_litellm_api_response(
+    custom_model_parameters: CustomModelParameters,
+    messages: List[Mapping],
+    response_format: pydantic.BaseModel = None,
+) -> str:
    if messages is None or messages == []:
        raise ValueError("Messages cannot be empty")
-
+
    if custom_model_parameters is None:
        raise ValueError("Custom model parameters cannot be empty")
-
+
    if not isinstance(custom_model_parameters, CustomModelParameters):
-        raise ValueError(
-
+        raise ValueError(
+            "Custom model parameters must be a CustomModelParameters object"
+        )
+
    if response_format is not None:
        response = litellm.completion(
            model=custom_model_parameters.model_name,
            messages=messages,
            api_key=custom_model_parameters.secret_key,
            base_url=custom_model_parameters.litellm_base_url,
-            response_format=response_format
+            response_format=response_format,
        )
    else:
        response = litellm.completion(
@@ -334,53 +381,61 @@ def fetch_custom_litellm_api_response(custom_model_parameters: CustomModelParame
    return response.choices[0].message.content
 
 
-async def afetch_litellm_api_response(
+async def afetch_litellm_api_response(
+    model: str, messages: List[Mapping], response_format: pydantic.BaseModel = None
+) -> str:
    """
    ASYNCHRONOUSLY Fetches a single response from the Litellm API for a given model and messages.
    """
    if messages is None or messages == []:
        raise ValueError("Messages cannot be empty")
-
+
    # Add validation
    validate_chat_messages(messages)
-
+
    if model not in ACCEPTABLE_MODELS:
-        raise ValueError(
-
+        raise ValueError(
+            f"Model {model} is not in the list of supported models: {ACCEPTABLE_MODELS}."
+        )
+
    if response_format is not None:
        response = await litellm.acompletion(
-            model=model,
-            messages=messages,
-            response_format=response_format
+            model=model, messages=messages, response_format=response_format
        )
    else:
        response = await litellm.acompletion(
            model=model,
-            messages=messages,
+            messages=messages,
        )
    return response.choices[0].message.content
 
 
-async def afetch_custom_litellm_api_response(
+async def afetch_custom_litellm_api_response(
+    custom_model_parameters: CustomModelParameters,
+    messages: List[Mapping],
+    response_format: pydantic.BaseModel = None,
+) -> str:
    """
    ASYNCHRONOUSLY Fetches a single response from the Litellm API for a given model and messages.
    """
    if messages is None or messages == []:
        raise ValueError("Messages cannot be empty")
-
+
    if custom_model_parameters is None:
        raise ValueError("Custom model parameters cannot be empty")
-
+
    if not isinstance(custom_model_parameters, CustomModelParameters):
-        raise ValueError(
-
+        raise ValueError(
+            "Custom model parameters must be a CustomModelParameters object"
+        )
+
    if response_format is not None:
        response = await litellm.acompletion(
            model=custom_model_parameters.model_name,
            messages=messages,
            api_key=custom_model_parameters.secret_key,
            base_url=custom_model_parameters.litellm_base_url,
-            response_format=response_format
+            response_format=response_format,
        )
    else:
        response = await litellm.acompletion(
@@ -392,26 +447,36 @@ async def afetch_custom_litellm_api_response(custom_model_parameters: CustomMode
    return response.choices[0].message.content
 
 
-def query_litellm_api_multiple_calls(
+def query_litellm_api_multiple_calls(
+    models: List[str],
+    messages: List[List[Mapping]],
+    response_formats: List[pydantic.BaseModel] | None = None,
+) -> List[Union[str, None]]:
    """
    Queries the Litellm API for multiple calls in parallel
 
    Args:
        models (List[str]): List of models to query
-        messages (List[Mapping]): List of messages to query
+        messages (List[List[Mapping]]): List of messages to query
        response_formats (List[pydantic.BaseModel], optional): A list of the format of the response if JSON forcing. Defaults to None.
 
    Returns:
        List[str]: Litellm responses for each model and message pair in order. Any exceptions in the thread call result in a None.
    """
-    num_workers = int(os.getenv(
+    num_workers = int(os.getenv("NUM_WORKER_THREADS", MAX_WORKER_THREADS))
    # Initialize results to maintain ordered outputs
-    out = [None] * len(messages)
+    out: List[Union[str, None]] = [None] * len(messages)
    with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) as executor:
        # Submit all queries to Litellm API with index, gets back the response content
-        futures = {
-
-
+        futures = {
+            executor.submit(
+                fetch_litellm_api_response, model, message, response_format
+            ): idx
+            for idx, (model, message, response_format) in enumerate(
+                zip(models, messages, response_formats or [None] * len(messages))
+            )
+        }
+
        # Collect results as they complete -- result is response content
        for future in concurrent.futures.as_completed(futures):
            idx = futures[future]
@@ -419,37 +484,45 @@ def query_litellm_api_multiple_calls(models: List[str], messages: List[Mapping],
                out[idx] = future.result()
            except Exception as e:
                error(f"Error in parallel call {idx}: {str(e)}")
-                out[idx] = None
+                out[idx] = None
    return out
 
 
-async def aquery_litellm_api_multiple_calls(
+async def aquery_litellm_api_multiple_calls(
+    models: List[str],
+    messages: List[List[Mapping]],
+    response_formats: List[pydantic.BaseModel] | None = None,
+) -> List[Union[str, None]]:
    """
    Queries the Litellm API for multiple calls in parallel
 
    Args:
        models (List[str]): List of models to query
-        messages (List[Mapping]): List of messages to query
+        messages (List[List[Mapping]]): List of messages to query
        response_formats (List[pydantic.BaseModel], optional): A list of the format of the response if JSON forcing. Defaults to None.
-
+
    Returns:
        List[str]: Litellm responses for each model and message pair in order. Any exceptions in the thread call result in a None.
    """
    # Initialize results to maintain ordered outputs
-    out = [None] * len(messages)
-
+    out: List[Union[str, None]] = [None] * len(messages)
+
    async def fetch_and_store(idx, model, message, response_format):
        try:
-            out[idx] = await afetch_litellm_api_response(
+            out[idx] = await afetch_litellm_api_response(
+                model, message, response_format
+            )
        except Exception as e:
            error(f"Error in parallel call {idx}: {str(e)}")
            out[idx] = None
 
    tasks = [
        fetch_and_store(idx, model, message, response_format)
-        for idx, (model, message, response_format) in enumerate(
+        for idx, (model, message, response_format) in enumerate(
+            zip(models, messages, response_formats or [None] * len(messages))
+        )
    ]
-
+
    await asyncio.gather(*tasks)
    return out
 
@@ -458,56 +531,75 @@ def validate_chat_messages(messages, batched: bool = False):
    """Validate chat message format before API call"""
    if not isinstance(messages, list):
        raise TypeError("Messages must be a list")
-
+
    for msg in messages:
        if not isinstance(msg, dict):
            if batched and not isinstance(msg, list):
                raise TypeError("Each message must be a list")
            elif not batched:
                raise TypeError("Message must be a dictionary")
-        if
+        if "role" not in msg:
            raise ValueError("Message missing required 'role' field")
-        if
+        if "content" not in msg:
            raise ValueError("Message missing required 'content' field")
-        if msg[
-            raise ValueError(
+        if msg["role"] not in ["system", "user", "assistant"]:
+            raise ValueError(
+                f"Invalid role '{msg['role']}'. Must be 'system', 'user', or 'assistant'"
+            )
+
 
-def validate_batched_chat_messages(messages
+def validate_batched_chat_messages(messages):
    """
    Validate format of batched chat messages before API call
-
+
    Args:
        messages (List[List[Mapping]]): List of message lists, where each inner list contains
            message dictionaries with 'role' and 'content' fields
-
+
    Raises:
        TypeError: If messages format is invalid
        ValueError: If message content is invalid
    """
    if not isinstance(messages, list):
        raise TypeError("Batched messages must be a list")
-
+
    if not messages:
        raise ValueError("Batched messages cannot be empty")
-
+
    for message_list in messages:
        if not isinstance(message_list, list):
            raise TypeError("Each batch item must be a list of messages")
-
+
        # Validate individual messages using existing function
        validate_chat_messages(message_list)
 
-
-
-
-
-
+
+def is_batched_messages(
+    messages: Union[List[Mapping], List[List[Mapping]]],
+) -> TypeGuard[List[List[Mapping]]]:
+    return isinstance(messages, list) and all(isinstance(msg, list) for msg in messages)
+
+
+def is_simple_messages(
+    messages: Union[List[Mapping], List[List[Mapping]]],
+) -> TypeGuard[List[Mapping]]:
+    return isinstance(messages, list) and all(
+        not isinstance(msg, list) for msg in messages
+    )
+
+
+def get_chat_completion(
+    model_type: str,
+    messages: Union[List[Mapping], List[List[Mapping]]],
+    response_format: pydantic.BaseModel = None,
+    batched: bool = False,
+) -> Union[str, List[str | None]]:
    """
    Generates chat completions using a single model and potentially several messages. Supports closed-source and OSS models.
 
    Parameters:
    - model_type (str): The type of model to use for generating completions.
-    - messages (Union[List[Mapping], List[List[Mapping]]]): The messages to be used for generating completions.
+    - messages (Union[List[Mapping], List[List[Mapping]]]): The messages to be used for generating completions.
        If batched is True, this should be a list of lists of mappings.
    - response_format (pydantic.BaseModel, optional): The format of the response. Defaults to None.
    - batched (bool, optional): Whether to process messages in batch mode. Defaults to False.
@@ -516,50 +608,71 @@ def get_chat_completion(model_type: str,
    Raises:
    - ValueError: If requested model is not supported by Litellm or TogetherAI.
    """
-
+
    # Check for empty messages list
    if not messages or messages == []:
        raise ValueError("Messages cannot be empty")
-
+
    # Add validation
    if batched:
        validate_batched_chat_messages(messages)
    else:
        validate_chat_messages(messages)
-
-    if
-
-
-
-
-        return
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+
+    if (
+        batched
+        and is_batched_messages(messages)
+        and model_type in TOGETHER_SUPPORTED_MODELS
+    ):
+        return query_together_api_multiple_calls(
+            models=[model_type] * len(messages),
+            messages=messages,
+            response_formats=[response_format] * len(messages),
+        )
+    elif (
+        batched
+        and is_batched_messages(messages)
+        and model_type in LITELLM_SUPPORTED_MODELS
+    ):
+        return query_litellm_api_multiple_calls(
+            models=[model_type] * len(messages),
+            messages=messages,
+            response_formats=[response_format] * len(messages),
+        )
+    elif (
+        not batched
+        and is_simple_messages(messages)
+        and model_type in TOGETHER_SUPPORTED_MODELS
+    ):
+        return fetch_together_api_response(
+            model=model_type, messages=messages, response_format=response_format
+        )
+    elif (
+        not batched
+        and is_simple_messages(messages)
+        and model_type in LITELLM_SUPPORTED_MODELS
+    ):
+        return fetch_litellm_api_response(
+            model=model_type, messages=messages, response_format=response_format
+        )
+
+    raise ValueError(
+        f"Model {model_type} is not supported by Litellm or TogetherAI for chat completions. Please check the model name and try again."
+    )
+
+
+async def aget_chat_completion(
+    model_type: str,
+    messages: Union[List[Mapping], List[List[Mapping]]],
+    response_format: pydantic.BaseModel = None,
+    batched: bool = False,
+) -> Union[str, List[str | None]]:
    """
    ASYNCHRONOUSLY generates chat completions using a single model and potentially several messages. Supports closed-source and OSS models.
 
    Parameters:
    - model_type (str): The type of model to use for generating completions.
-    - messages (Union[List[Mapping], List[List[Mapping]]]): The messages to be used for generating completions.
+    - messages (Union[List[Mapping], List[List[Mapping]]]): The messages to be used for generating completions.
        If batched is True, this should be a list of lists of mappings.
    - response_format (pydantic.BaseModel, optional): The format of the response. Defaults to None.
    - batched (bool, optional): Whether to process messages in batch mode. Defaults to False.
@@ -569,38 +682,64 @@ async def aget_chat_completion(model_type: str,
    - ValueError: If requested model is not supported by Litellm or TogetherAI.
    """
    debug(f"Starting chat completion for model {model_type}, batched={batched}")
-
+
    if batched:
        validate_batched_chat_messages(messages)
    else:
        validate_chat_messages(messages)
-
-    if
+
+    if (
+        batched
+        and is_batched_messages(messages)
+        and model_type in TOGETHER_SUPPORTED_MODELS
+    ):
        debug("Using batched Together API call")
-        return await aquery_together_api_multiple_calls(
-
-
-
+        return await aquery_together_api_multiple_calls(
+            models=[model_type] * len(messages),
+            messages=messages,
+            response_formats=[response_format] * len(messages),
+        )
+    elif (
+        batched
+        and is_batched_messages(messages)
+        and model_type in LITELLM_SUPPORTED_MODELS
+    ):
        debug("Using batched LiteLLM API call")
-        return await aquery_litellm_api_multiple_calls(
-
-
-
+        return await aquery_litellm_api_multiple_calls(
+            models=[model_type] * len(messages),
+            messages=messages,
+            response_formats=[response_format] * len(messages),
+        )
+    elif (
+        not batched
+        and is_simple_messages(messages)
+        and model_type in TOGETHER_SUPPORTED_MODELS
+    ):
        debug("Using single Together API call")
-        return await afetch_together_api_response(
-
-
-    elif
+        return await afetch_together_api_response(
+            model=model_type, messages=messages, response_format=response_format
+        )
+    elif (
+        not batched
+        and is_simple_messages(messages)
+        and model_type in LITELLM_SUPPORTED_MODELS
+    ):
        debug("Using single LiteLLM API call")
-        return await afetch_litellm_api_response(
-
-
-
+        return await afetch_litellm_api_response(
+            model=model_type, messages=messages, response_format=response_format
+        )
+
    error(f"Model {model_type} not supported by either API")
-    raise ValueError(
+    raise ValueError(
+        f"Model {model_type} is not supported by Litellm or TogetherAI for chat completions. Please check the model name and try again."
+    )
 
 
-def get_completion_multiple_models(
+def get_completion_multiple_models(
+    models: List[str],
+    messages: List[List[Mapping]],
+    response_formats: List[pydantic.BaseModel] | None = None,
+) -> List[str | None]:
    """
    Retrieves completions for a single prompt from multiple models in parallel. Supports closed-source and OSS models.
 
@@ -608,28 +747,32 @@ def get_completion_multiple_models(models: List[str], messages: List[List[Mappin
        models (List[str]): List of models to query
        messages (List[List[Mapping]]): List of messages to query. Each inner object corresponds to a single prompt.
        response_formats (List[pydantic.BaseModel], optional): A list of the format of the response if JSON forcing. Defaults to None.
-
+
    Returns:
        List[str]: List of completions from the models in the order of the input models
    Raises:
        ValueError: If a model is not supported by Litellm or Together
    """
    debug(f"Starting multiple model completion for {len(models)} models")
-
+
    if models is None or models == []:
        raise ValueError("Models list cannot be empty")
-
+
    validate_batched_chat_messages(messages)
-
+
    if len(models) != len(messages):
        error(f"Model/message count mismatch: {len(models)} vs {len(messages)}")
-        raise ValueError(
+        raise ValueError(
+            f"Number of models and messages must be the same: {len(models)} != {len(messages)}"
+        )
    if response_formats is None:
        response_formats = [None] * len(models)
    # Partition the model requests into TogetherAI and Litellm models, but keep the ordering saved
    together_calls, litellm_calls = {}, {}  # index -> model, message, response_format
    together_responses, litellm_responses = [], []
-    for idx, (model, message, r_format) in enumerate(
+    for idx, (model, message, r_format) in enumerate(
+        zip(models, messages, response_formats)
+    ):
        if model in TOGETHER_SUPPORTED_MODELS:
            debug(f"Model {model} routed to Together API")
            together_calls[idx] = (model, message, r_format)
@@ -638,39 +781,49 @@ def get_completion_multiple_models(models: List[str], messages: List[List[Mappin
            litellm_calls[idx] = (model, message, r_format)
        else:
            error(f"Model {model} not supported by either API")
-            raise ValueError(
-
+            raise ValueError(
+                f"Model {model} is not supported by Litellm or TogetherAI for chat completions. Please check the model name and try again."
+            )
+
    # Add validation before processing
    for msg_list in messages:
        validate_chat_messages(msg_list)
-
+
    # Get the responses from the TogetherAI models
    # List of responses from the TogetherAI models in order of the together_calls dict
    if together_calls:
        debug(f"Executing {len(together_calls)} Together API calls")
-        together_responses = query_together_api_multiple_calls(
-
-
-
+        together_responses = query_together_api_multiple_calls(
+            models=[model for model, _, _ in together_calls.values()],
+            messages=[message for _, message, _ in together_calls.values()],
+            response_formats=[format for _, _, format in together_calls.values()],
+        )
+
    # Get the responses from the Litellm models
    if litellm_calls:
        debug(f"Executing {len(litellm_calls)} LiteLLM API calls")
-        litellm_responses = query_litellm_api_multiple_calls(
-
-
+        litellm_responses = query_litellm_api_multiple_calls(
+            models=[model for model, _, _ in litellm_calls.values()],
+            messages=[message for _, message, _ in litellm_calls.values()],
+            response_formats=[format for _, _, format in litellm_calls.values()],
+        )
 
    # Merge the responses in the order of the original models
    debug("Merging responses")
-    out = [None] * len(models)
+    out: List[Union[str, None]] = [None] * len(models)
    for idx, (model, message, r_format) in together_calls.items():
        out[idx] = together_responses.pop(0)
    for idx, (model, message, r_format) in litellm_calls.items():
        out[idx] = litellm_responses.pop(0)
    debug("Multiple model completion finished")
-    return out
+    return out
 
 
-async def aget_completion_multiple_models(
+async def aget_completion_multiple_models(
+    models: List[str],
+    messages: List[List[Mapping]],
+    response_formats: List[pydantic.BaseModel] | None = None,
+) -> List[str | None]:
    """
    ASYNCHRONOUSLY retrieves completions for a single prompt from multiple models in parallel. Supports closed-source and OSS models.
 
@@ -678,7 +831,7 @@ async def aget_completion_multiple_models(models: List[str], messages: List[List
        models (List[str]): List of models to query
        messages (List[List[Mapping]]): List of messages to query. Each inner object corresponds to a single prompt.
        response_formats (List[pydantic.BaseModel], optional): A list of the format of the response if JSON forcing. Defaults to None.
-
+
    Returns:
        List[str]: List of completions from the models in the order of the input models
    Raises:
@@ -686,48 +839,54 @@ async def aget_completion_multiple_models(models: List[str], messages: List[List
    """
    if models is None or models == []:
        raise ValueError("Models list cannot be empty")
-
+
    if len(models) != len(messages):
-        raise ValueError(
+        raise ValueError(
+            f"Number of models and messages must be the same: {len(models)} != {len(messages)}"
+        )
    if response_formats is None:
        response_formats = [None] * len(models)
 
    validate_batched_chat_messages(messages)
-
+
    # Partition the model requests into TogetherAI and Litellm models, but keep the ordering saved
    together_calls, litellm_calls = {}, {}  # index -> model, message, response_format
    together_responses, litellm_responses = [], []
-    for idx, (model, message, r_format) in enumerate(
+    for idx, (model, message, r_format) in enumerate(
+        zip(models, messages, response_formats)
+    ):
        if model in TOGETHER_SUPPORTED_MODELS:
            together_calls[idx] = (model, message, r_format)
        elif model in LITELLM_SUPPORTED_MODELS:
            litellm_calls[idx] = (model, message, r_format)
        else:
-            raise ValueError(
-
+            raise ValueError(
+                f"Model {model} is not supported by Litellm or TogetherAI for chat completions. Please check the model name and try again."
+            )
+
    # Add validation before processing
    for msg_list in messages:
        validate_chat_messages(msg_list)
-
+
    # Get the responses from the TogetherAI models
    # List of responses from the TogetherAI models in order of the together_calls dict
    if together_calls:
        together_responses = await aquery_together_api_multiple_calls(
-            models=[model for model, _, _ in together_calls.values()],
-            messages=[message for _, message, _ in together_calls.values()],
-            response_formats=[format for _, _, format in together_calls.values()]
+            models=[model for model, _, _ in together_calls.values()],
+            messages=[message for _, message, _ in together_calls.values()],
+            response_formats=[format for _, _, format in together_calls.values()],
        )
-
+
    # Get the responses from the Litellm models
    if litellm_calls:
        litellm_responses = await aquery_litellm_api_multiple_calls(
            models=[model for model, _, _ in litellm_calls.values()],
            messages=[message for _, message, _ in litellm_calls.values()],
-            response_formats=[format for _, _, format in litellm_calls.values()]
+            response_formats=[format for _, _, format in litellm_calls.values()],
        )
 
    # Merge the responses in the order of the original models
-    out = [None] * len(models)
+    out: List[Union[str, None]] = [None] * len(models)
    for idx, (model, message, r_format) in together_calls.items():
        out[idx] = together_responses.pop(0)
    for idx, (model, message, r_format) in litellm_calls.items():
@@ -736,53 +895,66 @@ async def aget_completion_multiple_models(models: List[str], messages: List[List
 
 
 if __name__ == "__main__":
-
-
-
-
-    messages=[
-        [
-            {"role": "system", "content": "You are a helpful assistant."},
-            {"role": "user", "content": "What is the capital of France?"},
-        ],
-        [
-            {"role": "system", "content": "You are a helpful assistant."},
-            {"role": "user", "content": "What is the capital of Japan?"},
-        ]
+    batched_messages: List[List[Mapping]] = [
+        [
+            {"role": "system", "content": "You are a helpful assistant."},
+            {"role": "user", "content": "What is the capital of France?"},
        ],
-
-
+        [
+            {"role": "system", "content": "You are a helpful assistant."},
+            {"role": "user", "content": "What is the capital of Japan?"},
+        ],
+    ]
 
-
-
-
-
+    non_batched_messages: List[Mapping] = [
+        {"role": "system", "content": "You are a helpful assistant."},
+        {"role": "user", "content": "What is the capital of France?"},
+    ]
+
+    batched_messages_2: List[List[Mapping]] = [
+        [
+            {"role": "system", "content": "You are a helpful assistant."},
+            {"role": "user", "content": "What is the capital of China?"},
+        ],
+        [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "What is the capital of France?"},
        ],
-
-
+        [
+            {"role": "system", "content": "You are a helpful assistant."},
+            {"role": "user", "content": "What is the capital of Japan?"},
+        ],
+    ]
+
+    # Batched
+    pprint.pprint(
+        get_chat_completion(
+            model_type="LLAMA3_405B_INSTRUCT_TURBO",
+            messages=batched_messages,
+            batched=True,
+        )
+    )
+
+    # Non batched
+    pprint.pprint(
+        get_chat_completion(
+            model_type="LLAMA3_8B_INSTRUCT_TURBO",
+            messages=non_batched_messages,
+            batched=False,
+        )
+    )
 
    # Batched single completion to multiple models
-    pprint.pprint(
-
-
-
-
-
-            {"role": "system", "content": "You are a helpful assistant."},
-            {"role": "user", "content": "What is the capital of China?"},
+    pprint.pprint(
+        get_completion_multiple_models(
+            models=[
+                "LLAMA3_70B_INSTRUCT_TURBO",
+                "LLAMA3_405B_INSTRUCT_TURBO",
+                "gpt-4.1-mini",
            ],
-
-
-
-
-            [
-                {"role": "system", "content": "You are a helpful assistant."},
-                {"role": "user", "content": "What is the capital of Japan?"},
-            ]
-        ]
-    ))
-
+            messages=batched_messages_2,
+        )
+    )
+
 
 ExcInfo: TypeAlias = tuple[type[BaseException], BaseException, TracebackType]
 OptExcInfo: TypeAlias = ExcInfo | tuple[None, None, None]
```