judgeval 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- judgeval/__init__.py +83 -0
- judgeval/clients.py +19 -0
- judgeval/common/__init__.py +8 -0
- judgeval/common/exceptions.py +28 -0
- judgeval/common/logger.py +189 -0
- judgeval/common/tracer.py +587 -0
- judgeval/common/utils.py +763 -0
- judgeval/constants.py +55 -0
- judgeval/data/__init__.py +14 -0
- judgeval/data/api_example.py +111 -0
- judgeval/data/datasets/__init__.py +4 -0
- judgeval/data/datasets/dataset.py +407 -0
- judgeval/data/datasets/ground_truth.py +54 -0
- judgeval/data/datasets/utils.py +74 -0
- judgeval/data/example.py +76 -0
- judgeval/data/result.py +83 -0
- judgeval/data/scorer_data.py +86 -0
- judgeval/evaluation_run.py +130 -0
- judgeval/judges/__init__.py +7 -0
- judgeval/judges/base_judge.py +44 -0
- judgeval/judges/litellm_judge.py +49 -0
- judgeval/judges/mixture_of_judges.py +248 -0
- judgeval/judges/together_judge.py +55 -0
- judgeval/judges/utils.py +45 -0
- judgeval/judgment_client.py +244 -0
- judgeval/run_evaluation.py +355 -0
- judgeval/scorers/__init__.py +30 -0
- judgeval/scorers/base_scorer.py +51 -0
- judgeval/scorers/custom_scorer.py +134 -0
- judgeval/scorers/judgeval_scorers/__init__.py +21 -0
- judgeval/scorers/judgeval_scorers/answer_relevancy.py +19 -0
- judgeval/scorers/judgeval_scorers/contextual_precision.py +19 -0
- judgeval/scorers/judgeval_scorers/contextual_recall.py +19 -0
- judgeval/scorers/judgeval_scorers/contextual_relevancy.py +22 -0
- judgeval/scorers/judgeval_scorers/faithfulness.py +19 -0
- judgeval/scorers/judgeval_scorers/hallucination.py +19 -0
- judgeval/scorers/judgeval_scorers/json_correctness.py +32 -0
- judgeval/scorers/judgeval_scorers/summarization.py +20 -0
- judgeval/scorers/judgeval_scorers/tool_correctness.py +19 -0
- judgeval/scorers/prompt_scorer.py +439 -0
- judgeval/scorers/score.py +427 -0
- judgeval/scorers/utils.py +175 -0
- judgeval-0.0.1.dist-info/METADATA +40 -0
- judgeval-0.0.1.dist-info/RECORD +46 -0
- judgeval-0.0.1.dist-info/WHEEL +4 -0
- judgeval-0.0.1.dist-info/licenses/LICENSE.md +202 -0
judgeval/common/utils.py
ADDED
@@ -0,0 +1,763 @@
|
|
1
|
+
"""
|
2
|
+
This file contains utility functions used in repo scripts
|
3
|
+
|
4
|
+
For API calling, we support:
|
5
|
+
- parallelized model calls on the same prompt
|
6
|
+
- batched model calls on different prompts
|
7
|
+
|
8
|
+
NOTE: any function beginning with 'a', e.g. 'afetch_together_api_response', is an asynchronous function
|
9
|
+
"""
|
10
|
+
|
11
|
+
import concurrent.futures
|
12
|
+
from typing import List, Mapping, Dict, Union, Optional, Literal, Any
|
13
|
+
import asyncio
|
14
|
+
import litellm
|
15
|
+
import pydantic
|
16
|
+
import pprint
|
17
|
+
import os
|
18
|
+
from dotenv import load_dotenv
|
19
|
+
|
20
|
+
from judgeval.clients import async_together_client, together_client
|
21
|
+
from judgeval.constants import *
|
22
|
+
from judgeval.common.logger import debug, error
|
23
|
+
|
24
|
+
# Snapshot of LiteLLM's known model names, taken once at import time; a set so the
# routing checks in this module are constant-time membership tests.
LITELLM_SUPPORTED_MODELS = {*litellm.model_list}
|
25
|
+
|
26
|
+
class CustomModelParameters(pydantic.BaseModel):
    """
    Connection details for a custom model endpoint reached through LiteLLM.

    All three fields are required and must be non-empty strings; an empty value
    raises ``ValueError`` inside the validator (surfaced by pydantic as a
    ``ValidationError``).
    """

    # ``model_name`` starts with pydantic v2's protected ``model_`` prefix, which makes
    # pydantic emit a UserWarning when the class is defined. The field name is part of
    # the public interface, so disable the namespace check instead of renaming it.
    model_config = pydantic.ConfigDict(protected_namespaces=())

    model_name: str
    secret_key: str
    litellm_base_url: str

    @pydantic.field_validator('model_name')
    def validate_model_name(cls, v):
        """Reject an empty model name."""
        if not v:
            raise ValueError("Model name cannot be empty")
        return v

    @pydantic.field_validator('secret_key')
    def validate_secret_key(cls, v):
        """Reject an empty secret key."""
        if not v:
            raise ValueError("Secret key cannot be empty")
        return v

    @pydantic.field_validator('litellm_base_url')
    def validate_litellm_base_url(cls, v):
        """Reject an empty LiteLLM base URL."""
        if not v:
            raise ValueError("Litellm base URL cannot be empty")
        return v
|
48
|
+
|
49
|
+
class ChatCompletionRequest(pydantic.BaseModel):
    """
    Validated arguments for a single chat-completion call.

    Attributes:
        model: A key of ``TOGETHER_SUPPORTED_MODELS`` or a name in ``LITELLM_SUPPORTED_MODELS``.
        messages: Non-empty list of dicts, each with ``role`` ('system', 'user' or
            'assistant') and ``content``.
        response_format: Optional structured-output spec. This is a dict, a pydantic
            model instance, or a pydantic model class (the form LiteLLM accepts
            for structured outputs).
    """

    model: str
    messages: List[Dict[str, str]]
    # Typed as ``Any`` because a pydantic model *class* would fail pydantic's own
    # type check against ``BaseModel``/``Dict``. The ``before`` validator below
    # enforces the allowed kinds instead.
    response_format: Optional[Any] = None

    @pydantic.field_validator('messages')
    def validate_messages(cls, messages):
        """Require a non-empty list of role/content dicts with a known role."""
        if not messages:
            raise ValueError("Messages cannot be empty")

        for msg in messages:
            if not isinstance(msg, dict):
                raise TypeError("Message must be a dictionary")
            if 'role' not in msg:
                raise ValueError("Message missing required 'role' field")
            if 'content' not in msg:
                raise ValueError("Message missing required 'content' field")
            if msg['role'] not in ['system', 'user', 'assistant']:
                raise ValueError(f"Invalid role '{msg['role']}'. Must be 'system', 'user', or 'assistant'")

        return messages

    @pydantic.field_validator('model')
    def validate_model(cls, model):
        """Require a model known to either TogetherAI or LiteLLM."""
        if not model:
            raise ValueError("Model cannot be empty")
        if model not in TOGETHER_SUPPORTED_MODELS and model not in LITELLM_SUPPORTED_MODELS:
            raise ValueError(f"Model {model} is not in the list of supported models.")
        return model

    @pydantic.field_validator('response_format', mode='before')
    def validate_response_format(cls, response_format):
        """Allow None, a dict, a pydantic model instance, or a pydantic model class."""
        if response_format is None:
            return response_format
        is_model_class = isinstance(response_format, type) and issubclass(response_format, pydantic.BaseModel)
        if not (is_model_class or isinstance(response_format, (dict, pydantic.BaseModel))):
            raise TypeError("Response format must be a dictionary or pydantic model")
        return response_format
|
87
|
+
|
88
|
+
# Pick up variables from a local .env file. By default load_dotenv() does not
# override variables that are already present in the process environment.
load_dotenv()

# Default LiteLLM's log level to DEBUG, but keep any LITELLM_LOG value the user
# already set (in the environment or in .env) instead of overwriting it.
os.environ.setdefault('LITELLM_LOG', 'DEBUG')
|
91
|
+
|
92
|
+
def read_file(file_path: str) -> str:
    """Return the entire contents of ``file_path``, decoded as UTF-8 text."""
    with open(file_path, encoding="utf-8") as handle:
        contents = handle.read()
    return contents
|
95
|
+
|
96
|
+
|
97
|
+
def fetch_together_api_response(model: str, messages: List[Mapping], response_format: pydantic.BaseModel = None) -> str:
    """
    Fetches a single response from the Together API for a given model and messages.

    Args:
        model (str): Key into ``TOGETHER_SUPPORTED_MODELS``; the mapped value is the
            model name actually sent to Together.
        messages (List[Mapping]): Chat messages (dicts with ``role`` and ``content``).
        response_format (optional): Structured-output spec forwarded to Together as
            ``response_format``; omitted from the request when None.

    Returns:
        str: Content of the first choice of the completion.

    Raises:
        ValueError: If ``messages`` is empty, the request fails ``ChatCompletionRequest``
            validation, or ``model`` is not a TogetherAI model.
    """
    if messages is None or messages == []:
        raise ValueError("Messages cannot be empty")

    request = ChatCompletionRequest(
        model=model,
        messages=messages,
        response_format=response_format
    )

    # ChatCompletionRequest also accepts LiteLLM-only names; those have no Together
    # mapping and would otherwise be sent to Together as model=None.
    if request.model not in TOGETHER_SUPPORTED_MODELS:
        raise ValueError(f"Model {request.model} is not in the list of supported TogetherAI models: {TOGETHER_SUPPORTED_MODELS}.")

    debug(f"Calling Together API with model: {request.model}")
    debug(f"Messages: {request.messages}")

    completion_kwargs = {
        "model": TOGETHER_SUPPORTED_MODELS[request.model],
        "messages": request.messages,
    }
    if request.response_format is not None:
        debug(f"Using response format: {request.response_format}")
        completion_kwargs["response_format"] = request.response_format

    response = together_client.chat.completions.create(**completion_kwargs)
    content = response.choices[0].message.content

    # Guard the preview in case the API returns content=None; slicing None would raise.
    debug(f"Received response: {(content or '')[:100]}...")
    return content
|
129
|
+
|
130
|
+
|
131
|
+
async def afetch_together_api_response(model: str, messages: List[Mapping], response_format: pydantic.BaseModel = None) -> str:
    """
    ASYNCHRONOUSLY Fetches a single response from the Together API for a given model and messages.

    Args:
        model (str): Key into ``TOGETHER_SUPPORTED_MODELS``; the mapped value is the
            model name actually sent to Together.
        messages (List[Mapping]): Chat messages (dicts with ``role`` and ``content``).
        response_format (optional): Structured-output spec forwarded to Together as
            ``response_format``; omitted from the request when None.

    Returns:
        str: Content of the first choice of the completion.

    Raises:
        ValueError: If ``messages`` is empty, the request fails ``ChatCompletionRequest``
            validation, or ``model`` is not a TogetherAI model.
    """
    # Same up-front guard as the synchronous fetch_together_api_response.
    if messages is None or messages == []:
        raise ValueError("Messages cannot be empty")

    request = ChatCompletionRequest(
        model=model,
        messages=messages,
        response_format=response_format
    )

    # LiteLLM-only names pass ChatCompletionRequest but have no Together mapping;
    # without this check they would be sent to Together as model=None.
    if request.model not in TOGETHER_SUPPORTED_MODELS:
        raise ValueError(f"Model {request.model} is not in the list of supported TogetherAI models: {TOGETHER_SUPPORTED_MODELS}.")

    debug(f"Calling Together API with model: {request.model}")
    debug(f"Messages: {request.messages}")

    completion_kwargs = {
        "model": TOGETHER_SUPPORTED_MODELS[request.model],
        "messages": request.messages,
    }
    if request.response_format is not None:
        debug(f"Using response format: {request.response_format}")
        completion_kwargs["response_format"] = request.response_format

    response = await async_together_client.chat.completions.create(**completion_kwargs)
    return response.choices[0].message.content
|
157
|
+
|
158
|
+
|
159
|
+
def query_together_api_multiple_calls(models: List[str], messages: List[List[Mapping]], response_formats: List[pydantic.BaseModel] = None) -> List[str]:
    """
    Run several Together API chat calls concurrently on a thread pool.

    Args:
        models (List[str]): Together model key for each call.
        messages (List[List[Mapping]]): One message list (a single prompt) per call.
        response_formats (List[pydantic.BaseModel], optional): One structured-output
            spec per call; None means no response format for any call.

    Returns:
        List[str]: Responses aligned with the inputs by index. A call that raised is
        logged and left as None.

    Raises:
        ValueError: On an empty model list, an unsupported model, or list lengths that differ.
    """
    if not models:
        raise ValueError("Models list cannot be empty")

    for candidate in models:
        if candidate not in TOGETHER_SUPPORTED_MODELS:
            raise ValueError(f"Model {candidate} is not in the list of supported TogetherAI models: {TOGETHER_SUPPORTED_MODELS}.")

    formats = [None] * len(models) if response_formats is None else response_formats
    if len(models) != len(messages) or len(messages) != len(formats):
        raise ValueError("Number of models, messages, and response formats must be the same")

    validate_batched_chat_messages(messages)

    pool_size = int(os.getenv('NUM_WORKER_THREADS', MAX_WORKER_THREADS))
    results = [None] * len(messages)
    with concurrent.futures.ThreadPoolExecutor(max_workers=pool_size) as pool:
        # Remember each future's input position so completion order doesn't matter.
        position_of = {}
        for position, job in enumerate(zip(models, messages, formats)):
            position_of[pool.submit(fetch_together_api_response, *job)] = position

        for finished in concurrent.futures.as_completed(position_of):
            position = position_of[finished]
            try:
                results[position] = finished.result()
            except Exception as exc:
                error(f"Error in parallel call {position}: {str(exc)}")
                results[position] = None
    return results
|
206
|
+
|
207
|
+
|
208
|
+
async def aquery_together_api_multiple_calls(models: List[str], messages: List[List[Mapping]], response_formats: List[pydantic.BaseModel] = None) -> List[str]:
    """
    ASYNCHRONOUSLY run several Together API chat calls concurrently.

    Args:
        models (List[str]): Together model key for each call.
        messages (List[List[Mapping]]): One message list (a single prompt) per call.
        response_formats (List[pydantic.BaseModel], optional): One structured-output
            spec per call; None means no response format for any call.

    Returns:
        List[str]: Responses aligned with the inputs by index. A call that raised is
        logged and returned as None.

    Raises:
        ValueError: On an empty model list, an unsupported model, or list lengths that differ.
    """
    if not models:
        raise ValueError("Models list cannot be empty")

    for candidate in models:
        if candidate not in TOGETHER_SUPPORTED_MODELS:
            raise ValueError(f"Model {candidate} is not in the list of supported TogetherAI models: {TOGETHER_SUPPORTED_MODELS}.")

    formats = [None] * len(models) if response_formats is None else response_formats
    if len(models) != len(messages) or len(messages) != len(formats):
        raise ValueError("Number of models, messages, and response formats must be the same")

    validate_batched_chat_messages(messages)

    debug(f"Starting parallel Together API calls for {len(messages)} messages")

    async def _fetch_one(position, model_name, prompt, fmt):
        # Failures are logged and mapped to None so one bad call doesn't sink the batch.
        try:
            debug(f"Processing call {position} with model {model_name}")
            return await afetch_together_api_response(model_name, prompt, fmt)
        except Exception as exc:
            error(f"Error in parallel call {position}: {str(exc)}")
            return None

    # gather() returns results in argument order, i.e. aligned with the inputs.
    results = await asyncio.gather(*(
        _fetch_one(position, *job)
        for position, job in enumerate(zip(models, messages, formats))
    ))
    debug(f"Completed {len(messages)} parallel calls")
    return results
|
257
|
+
|
258
|
+
|
259
|
+
def fetch_litellm_api_response(model: str, messages: List[Mapping], response_format: pydantic.BaseModel = None) -> str:
    """
    Send one chat request through LiteLLM and return the first choice's content.

    Args:
        model (str): Model name; validated by ``ChatCompletionRequest``.
        messages (List[Mapping]): Chat messages (dicts with ``role`` and ``content``).
        response_format (optional): Structured-output spec; only passed to LiteLLM
            when not None.

    Returns:
        str: Content of the first choice of the completion.
    """
    request = ChatCompletionRequest(model=model, messages=messages, response_format=response_format)

    debug(f"Calling LiteLLM API with model: {request.model}")
    debug(f"Messages: {request.messages}")

    extra_kwargs = {}
    if request.response_format is not None:
        debug(f"Using response format: {request.response_format}")
        extra_kwargs["response_format"] = request.response_format

    response = litellm.completion(model=request.model, messages=request.messages, **extra_kwargs)
    return response.choices[0].message.content
|
285
|
+
|
286
|
+
|
287
|
+
def fetch_custom_litellm_api_response(custom_model_parameters: CustomModelParameters, messages: List[Mapping], response_format: pydantic.BaseModel = None) -> str:
    """
    Fetches a single response from a custom LiteLLM endpoint.

    Args:
        custom_model_parameters (CustomModelParameters): Supplies the model name,
            API key and base URL passed to ``litellm.completion``.
        messages (List[Mapping]): Chat messages to send; must not be None or ``[]``.
        response_format (optional): Structured-output spec; only passed when not None.

    Returns:
        str: Content of the first choice of the completion.

    Raises:
        ValueError: If ``messages`` is None/``[]``, or ``custom_model_parameters`` is
            None or not a ``CustomModelParameters`` instance.
    """
    if messages is None or messages == []:
        raise ValueError("Messages cannot be empty")
    if custom_model_parameters is None:
        raise ValueError("Custom model parameters cannot be empty")
    if not isinstance(custom_model_parameters, CustomModelParameters):
        raise ValueError("Custom model parameters must be a CustomModelParameters object")

    call_kwargs = dict(
        model=custom_model_parameters.model_name,
        messages=messages,
        api_key=custom_model_parameters.secret_key,
        base_url=custom_model_parameters.litellm_base_url,
    )
    if response_format is not None:
        call_kwargs["response_format"] = response_format

    response = litellm.completion(**call_kwargs)
    return response.choices[0].message.content
|
313
|
+
|
314
|
+
|
315
|
+
async def afetch_litellm_api_response(model: str, messages: List[Mapping], response_format: pydantic.BaseModel = None) -> str:
    """
    ASYNCHRONOUSLY send one chat request through LiteLLM and return the first choice's content.

    Args:
        model (str): Must be a member of ``LITELLM_SUPPORTED_MODELS``.
        messages (List[Mapping]): Chat messages; must not be None or ``[]``.
        response_format (optional): Structured-output spec; only passed when not None.

    Returns:
        str: Content of the first choice of the completion.

    Raises:
        ValueError: On empty messages, malformed messages, or an unsupported model.
        TypeError: On structurally malformed messages (raised by ``validate_chat_messages``).
    """
    if messages is None or messages == []:
        raise ValueError("Messages cannot be empty")

    validate_chat_messages(messages)

    if model not in LITELLM_SUPPORTED_MODELS:
        raise ValueError(f"Model {model} is not in the list of supported Litellm models: {LITELLM_SUPPORTED_MODELS}.")

    optional_kwargs = {} if response_format is None else {"response_format": response_format}
    response = await litellm.acompletion(model=model, messages=messages, **optional_kwargs)
    return response.choices[0].message.content
|
340
|
+
|
341
|
+
|
342
|
+
async def afetch_custom_litellm_api_response(custom_model_parameters: CustomModelParameters, messages: List[Mapping], response_format: pydantic.BaseModel = None) -> str:
    """
    ASYNCHRONOUSLY fetches a single response from a custom LiteLLM endpoint.

    Args:
        custom_model_parameters (CustomModelParameters): Supplies the model name,
            API key and base URL passed to ``litellm.acompletion``.
        messages (List[Mapping]): Chat messages to send; must not be None or ``[]``.
        response_format (optional): Structured-output spec; only passed when not None.

    Returns:
        str: Content of the first choice of the completion.

    Raises:
        ValueError: If ``messages`` is None/``[]``, or ``custom_model_parameters`` is
            None or not a ``CustomModelParameters`` instance.
    """
    if messages is None or messages == []:
        raise ValueError("Messages cannot be empty")
    if custom_model_parameters is None:
        raise ValueError("Custom model parameters cannot be empty")
    if not isinstance(custom_model_parameters, CustomModelParameters):
        raise ValueError("Custom model parameters must be a CustomModelParameters object")

    request_kwargs = {
        "model": custom_model_parameters.model_name,
        "messages": messages,
        "api_key": custom_model_parameters.secret_key,
        "base_url": custom_model_parameters.litellm_base_url,
    }
    if response_format is not None:
        request_kwargs["response_format"] = response_format

    response = await litellm.acompletion(**request_kwargs)
    return response.choices[0].message.content
|
371
|
+
|
372
|
+
|
373
|
+
def query_litellm_api_multiple_calls(models: List[str], messages: List[List[Mapping]], response_formats: List[pydantic.BaseModel] = None) -> List[str]:
    """
    Queries the Litellm API for multiple calls in parallel

    Args:
        models (List[str]): List of models to query, one per call.
        messages (List[List[Mapping]]): One message list (a single prompt) per call.
        response_formats (List[pydantic.BaseModel], optional): One response format per
            call if JSON forcing. Defaults to None, meaning no response format for any call.

    Returns:
        List[str]: Litellm responses for each model and message pair in order. Any exceptions in the thread call result in a None.

    Raises:
        ValueError: If the numbers of models, messages and response formats differ.
    """
    # Without this default, zip() below would raise TypeError on None.
    if response_formats is None:
        response_formats = [None] * len(models)
    # zip() would silently truncate mismatched lists and leave trailing results as None.
    if not (len(models) == len(messages) == len(response_formats)):
        raise ValueError("Number of models, messages, and response formats must be the same")

    num_workers = int(os.getenv('NUM_WORKER_THREADS', MAX_WORKER_THREADS))
    # Initialize results to maintain ordered outputs
    out = [None] * len(messages)
    with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) as executor:
        # Submit all queries to Litellm API with index, gets back the response content
        futures = {
            executor.submit(fetch_litellm_api_response, model, message, response_format): idx
            for idx, (model, message, response_format) in enumerate(zip(models, messages, response_formats))
        }

        # Collect results as they complete -- result is response content
        for future in concurrent.futures.as_completed(futures):
            idx = futures[future]
            try:
                out[idx] = future.result()
            except Exception as e:
                error(f"Error in parallel call {idx}: {str(e)}")
                out[idx] = None
    return out
|
402
|
+
|
403
|
+
|
404
|
+
async def aquery_litellm_api_multiple_calls(models: List[str], messages: List[List[Mapping]], response_formats: List[pydantic.BaseModel] = None) -> List[str]:
    """
    ASYNCHRONOUSLY queries the Litellm API for multiple calls concurrently.

    Args:
        models (List[str]): List of models to query, one per call.
        messages (List[List[Mapping]]): One message list (a single prompt) per call.
        response_formats (List[pydantic.BaseModel], optional): One response format per
            call if JSON forcing. Defaults to None, meaning no response format for any call.

    Returns:
        List[str]: Litellm responses for each model and message pair in order. Any exception in a call results in a None.

    Raises:
        ValueError: If the numbers of models, messages and response formats differ.
    """
    # Without this default, zip() below would raise TypeError on None.
    if response_formats is None:
        response_formats = [None] * len(models)
    # zip() would silently truncate mismatched lists and leave trailing results as None.
    if not (len(models) == len(messages) == len(response_formats)):
        raise ValueError("Number of models, messages, and response formats must be the same")

    # Initialize results to maintain ordered outputs
    out = [None] * len(messages)

    async def fetch_and_store(idx, model, message, response_format):
        try:
            out[idx] = await afetch_litellm_api_response(model, message, response_format)
        except Exception as e:
            error(f"Error in parallel call {idx}: {str(e)}")
            out[idx] = None

    tasks = [
        fetch_and_store(idx, model, message, response_format)
        for idx, (model, message, response_format) in enumerate(zip(models, messages, response_formats))
    ]

    await asyncio.gather(*tasks)
    return out
|
433
|
+
|
434
|
+
|
435
|
+
def validate_chat_messages(messages, batched: bool = False):
    """
    Validate chat message format before an API call.

    Args:
        messages: A list of message dicts, each with ``role`` ('system', 'user' or
            'assistant') and ``content``. When ``batched`` is True, items may also be
            lists of such dicts; each inner list is validated recursively.
        batched (bool): Whether ``messages`` may contain nested message lists.

    Raises:
        TypeError: If ``messages`` is not a list, or an item has the wrong type.
        ValueError: If a message dict lacks ``role``/``content`` or has an unknown role.
    """
    if not isinstance(messages, list):
        raise TypeError("Messages must be a list")

    for msg in messages:
        if batched and isinstance(msg, list):
            # A nested conversation: validate its messages instead of treating the
            # list itself as a message (which always failed the 'role' check).
            validate_chat_messages(msg)
            continue
        if not isinstance(msg, dict):
            if batched:
                raise TypeError("Each message must be a list")
            raise TypeError("Message must be a dictionary")
        if 'role' not in msg:
            raise ValueError("Message missing required 'role' field")
        if 'content' not in msg:
            raise ValueError("Message missing required 'content' field")
        if msg['role'] not in ['system', 'user', 'assistant']:
            raise ValueError(f"Invalid role '{msg['role']}'. Must be 'system', 'user', or 'assistant'")
|
452
|
+
|
453
|
+
def validate_batched_chat_messages(messages: List[List[Mapping]]):
    """
    Check that ``messages`` is a non-empty batch of conversations.

    Args:
        messages (List[List[Mapping]]): Batch where every item is a list of message
            dicts, each carrying ``role`` and ``content``.

    Raises:
        TypeError: If the batch or any item in it is not a list, or a message is malformed.
        ValueError: If the batch is empty or a message's fields are invalid.
    """
    if not isinstance(messages, list):
        raise TypeError("Batched messages must be a list")
    if not messages:
        raise ValueError("Batched messages cannot be empty")

    for conversation in messages:
        if isinstance(conversation, list):
            # Delegate per-message checks to the single-conversation validator.
            validate_chat_messages(conversation)
        else:
            raise TypeError("Each batch item must be a list of messages")
|
477
|
+
|
478
|
+
def get_chat_completion(model_type: str,
                        messages: Union[List[Mapping], List[List[Mapping]]],
                        response_format: pydantic.BaseModel = None,
                        batched: bool = False
                        ) -> Union[str, List[str]]:
    """
    Generates chat completions using a single model and potentially several messages. Supports closed-source and OSS models.

    Parameters:
        - model_type (str): The type of model to use for generating completions.
        - messages (Union[List[Mapping], List[List[Mapping]]]): The messages to be used for generating completions.
            If batched is True, this should be a list of lists of mappings.
        - response_format (pydantic.BaseModel, optional): The format of the response. Defaults to None.
            In batched mode the same format is applied to every call.
        - batched (bool, optional): Whether to process messages in batch mode. Defaults to False.
    Returns:
        - str: The generated chat completion(s). If batched is True, returns a list of strings.
    Raises:
        - ValueError: If messages are empty or the requested model is not supported by Litellm or TogetherAI.
    """
    # Check for empty messages list
    if not messages:
        raise ValueError("Messages cannot be empty")

    if batched:
        validate_batched_chat_messages(messages)
    else:
        validate_chat_messages(messages)

    if batched and model_type in TOGETHER_SUPPORTED_MODELS:
        return query_together_api_multiple_calls(models=[model_type] * len(messages),
                                                 messages=messages,
                                                 response_formats=[response_format] * len(messages))
    elif batched and model_type in LITELLM_SUPPORTED_MODELS:
        # query_litellm_api_multiple_calls takes one format per call via
        # ``response_formats``; it has no ``response_format`` keyword.
        return query_litellm_api_multiple_calls(models=[model_type] * len(messages),
                                                messages=messages,
                                                response_formats=[response_format] * len(messages))
    elif not batched and model_type in TOGETHER_SUPPORTED_MODELS:
        return fetch_together_api_response(model=model_type,
                                           messages=messages,
                                           response_format=response_format)
    elif not batched and model_type in LITELLM_SUPPORTED_MODELS:
        return fetch_litellm_api_response(model=model_type,
                                          messages=messages,
                                          response_format=response_format)

    raise ValueError(f"Model {model_type} is not supported by Litellm or TogetherAI for chat completions. Please check the model name and try again.")
|
528
|
+
|
529
|
+
|
530
|
+
async def aget_chat_completion(model_type: str,
                               messages: Union[List[Mapping], List[List[Mapping]]],
                               response_format: pydantic.BaseModel = None,
                               batched: bool = False
                               ) -> Union[str, List[str]]:
    """
    ASYNCHRONOUSLY generates chat completions using a single model and potentially several messages. Supports closed-source and OSS models.

    Parameters:
        - model_type (str): The type of model to use for generating completions.
        - messages (Union[List[Mapping], List[List[Mapping]]]): The messages to be used for generating completions.
            If batched is True, this should be a list of lists of mappings.
        - response_format (pydantic.BaseModel, optional): The format of the response. Defaults to None.
            In batched mode the same format is applied to every call.
        - batched (bool, optional): Whether to process messages in batch mode. Defaults to False.
    Returns:
        - str: The generated chat completion(s). If batched is True, returns a list of strings.
    Raises:
        - ValueError: If messages are empty or the requested model is not supported by Litellm or TogetherAI.
    """
    debug(f"Starting chat completion for model {model_type}, batched={batched}")

    # Same up-front guard as get_chat_completion; validate_chat_messages on its own
    # accepts an empty list.
    if not messages:
        raise ValueError("Messages cannot be empty")

    if batched:
        validate_batched_chat_messages(messages)
    else:
        validate_chat_messages(messages)

    if batched and model_type in TOGETHER_SUPPORTED_MODELS:
        debug("Using batched Together API call")
        return await aquery_together_api_multiple_calls(models=[model_type] * len(messages),
                                                        messages=messages,
                                                        response_formats=[response_format] * len(messages))
    elif batched and model_type in LITELLM_SUPPORTED_MODELS:
        debug("Using batched LiteLLM API call")
        return await aquery_litellm_api_multiple_calls(models=[model_type] * len(messages),
                                                       messages=messages,
                                                       response_formats=[response_format] * len(messages))
    elif not batched and model_type in TOGETHER_SUPPORTED_MODELS:
        debug("Using single Together API call")
        return await afetch_together_api_response(model=model_type,
                                                  messages=messages,
                                                  response_format=response_format)
    elif not batched and model_type in LITELLM_SUPPORTED_MODELS:
        debug("Using single LiteLLM API call")
        return await afetch_litellm_api_response(model=model_type,
                                                 messages=messages,
                                                 response_format=response_format)

    error(f"Model {model_type} not supported by either API")
    raise ValueError(f"Model {model_type} is not supported by Litellm or TogetherAI for chat completions. Please check the model name and try again.")
|
579
|
+
|
580
|
+
|
581
|
+
def get_completion_multiple_models(models: List[str], messages: List[List[Mapping]], response_formats: List[pydantic.BaseModel] = None) -> List[str]:
    """
    Retrieves completions from multiple models, one prompt per model. Supports closed-source and OSS models.

    Calls are partitioned by provider. All TogetherAI calls run in parallel first,
    then all LiteLLM calls run in parallel. The results are merged back into input order.

    Args:
        models (List[str]): List of models to query
        messages (List[List[Mapping]]): List of messages to query. Each inner object corresponds to a single prompt
            and is sent to the model at the same index.
        response_formats (List[pydantic.BaseModel], optional): A list of the format of the response if JSON forcing,
            one per model. Defaults to None (no response format for any call).

    Returns:
        List[str]: List of completions from the models in the order of the input models. A call that failed
            inside the provider helpers yields None at its position.
    Raises:
        ValueError: If ``models`` is empty, the list lengths differ, or a model is not supported by Litellm or Together
    """
    if models is None or models == []:
        raise ValueError("Models list cannot be empty")

    # Log after the guard so models=None reaches the ValueError above rather than
    # failing with a TypeError inside len().
    debug(f"Starting multiple model completion for {len(models)} models")

    # Also validates every inner message list (it calls validate_chat_messages on each),
    # so no second per-list validation pass is needed.
    validate_batched_chat_messages(messages)

    if len(models) != len(messages):
        error(f"Model/message count mismatch: {len(models)} vs {len(messages)}")
        raise ValueError(f"Number of models and messages must be the same: {len(models)} != {len(messages)}")
    if response_formats is None:
        response_formats = [None] * len(models)
    elif len(response_formats) != len(models):
        # zip() below would otherwise silently drop trailing calls and leave None in their slots.
        error(f"Model/response format count mismatch: {len(models)} vs {len(response_formats)}")
        raise ValueError(f"Number of models and response formats must be the same: {len(models)} != {len(response_formats)}")

    # Partition the model requests into TogetherAI and Litellm models, keeping each original index
    together_calls, litellm_calls = {}, {}  # index -> (model, message, response_format)
    for idx, (model, message, r_format) in enumerate(zip(models, messages, response_formats)):
        if model in TOGETHER_SUPPORTED_MODELS:
            debug(f"Model {model} routed to Together API")
            together_calls[idx] = (model, message, r_format)
        elif model in LITELLM_SUPPORTED_MODELS:
            debug(f"Model {model} routed to LiteLLM API")
            litellm_calls[idx] = (model, message, r_format)
        else:
            error(f"Model {model} not supported by either API")
            raise ValueError(f"Model {model} is not supported by Litellm or TogetherAI for chat completions. Please check the model name and try again.")

    together_responses, litellm_responses = [], []

    # Responses come back in the iteration order of together_calls
    if together_calls:
        debug(f"Executing {len(together_calls)} Together API calls")
        together_responses = query_together_api_multiple_calls(
            models=[model for model, _, _ in together_calls.values()],
            messages=[message for _, message, _ in together_calls.values()],
            response_formats=[r_format for _, _, r_format in together_calls.values()])

    # Responses come back in the iteration order of litellm_calls
    if litellm_calls:
        debug(f"Executing {len(litellm_calls)} LiteLLM API calls")
        litellm_responses = query_litellm_api_multiple_calls(
            models=[model for model, _, _ in litellm_calls.values()],
            messages=[message for _, message, _ in litellm_calls.values()],
            response_formats=[r_format for _, _, r_format in litellm_calls.values()])

    # Merge the responses in the order of the original models. Dicts preserve insertion
    # order, so the i-th response belongs to the i-th key (no O(n^2) pop(0) needed).
    debug("Merging responses")
    out = [None] * len(models)
    for idx, response in zip(together_calls, together_responses):
        out[idx] = response
    for idx, response in zip(litellm_calls, litellm_responses):
        out[idx] = response
    debug("Multiple model completion finished")
    return out
|
649
|
+
|
650
|
+
|
651
|
+
async def aget_completion_multiple_models(models: List[str], messages: List[List[Mapping]], response_formats: List[pydantic.BaseModel] = None) -> List[str]:
    """
    ASYNCHRONOUSLY retrieves completions from multiple models, one prompt per model. Supports closed-source and OSS models.

    Requests are partitioned by provider: models in TOGETHER_SUPPORTED_MODELS are sent as one
    batch through `aquery_together_api_multiple_calls`, models in LITELLM_SUPPORTED_MODELS as one
    batch through `aquery_litellm_api_multiple_calls`. The Together batch is awaited before the
    LiteLLM batch, and the results are stitched back into input order.

    Args:
        models (List[str]): List of models to query. Must be non-empty.
        messages (List[List[Mapping]]): List of chat message lists; messages[i] is sent to models[i].
        response_formats (List[pydantic.BaseModel], optional): response_formats[i] is the JSON-forcing
            format for models[i]. Defaults to None (no format forcing for any model).

    Returns:
        List[str]: List of completions from the models in the order of the input models

    Raises:
        ValueError: If `models` is empty, if `messages` or `response_formats` does not have exactly
            one entry per model, or if a model is not supported by Litellm or Together.
    """
    if not models:
        raise ValueError("Models list cannot be empty")

    if len(models) != len(messages):
        raise ValueError(f"Number of models and messages must be the same: {len(models)} != {len(messages)}")
    if response_formats is None:
        response_formats = [None] * len(models)
    elif len(response_formats) != len(models):
        # Without this check zip() below would silently drop the trailing models,
        # leaving None in their output slots.
        raise ValueError(f"Number of models and response formats must be the same: {len(models)} != {len(response_formats)}")

    validate_batched_chat_messages(messages)

    # Partition the model requests into TogetherAI and Litellm models, but keep the ordering saved.
    # Each dict maps original index -> (model, message, response_format); dicts preserve insertion order.
    together_calls, litellm_calls = {}, {}
    for idx, (model, message, r_format) in enumerate(zip(models, messages, response_formats)):
        if model in TOGETHER_SUPPORTED_MODELS:
            together_calls[idx] = (model, message, r_format)
        elif model in LITELLM_SUPPORTED_MODELS:
            litellm_calls[idx] = (model, message, r_format)
        else:
            raise ValueError(f"Model {model} is not supported by Litellm or TogetherAI for chat completions. Please check the model name and try again.")

    # Add validation before processing
    for msg_list in messages:
        validate_chat_messages(msg_list)

    # Responses from each provider, in the iteration order of the corresponding *_calls dict
    together_responses, litellm_responses = [], []
    if together_calls:
        together_responses = await aquery_together_api_multiple_calls(
            models=[model for model, _, _ in together_calls.values()],
            messages=[message for _, message, _ in together_calls.values()],
            response_formats=[r_format for _, _, r_format in together_calls.values()],
        )

    if litellm_calls:
        litellm_responses = await aquery_litellm_api_multiple_calls(
            models=[model for model, _, _ in litellm_calls.values()],
            messages=[message for _, message, _ in litellm_calls.values()],
            response_formats=[r_format for _, _, r_format in litellm_calls.values()],
        )

    # Merge the responses back into the order of the original models. Positional indexing
    # (instead of repeated pop(0)) is O(n) and leaves the provider result lists untouched;
    # it still raises IndexError if a provider returned fewer results than requested.
    out = [None] * len(models)
    for pos, idx in enumerate(together_calls):
        out[idx] = together_responses[pos]
    for pos, idx in enumerate(litellm_calls):
        out[idx] = litellm_responses[pos]
    return out
|
714
|
+
|
715
|
+
|
716
|
+
if __name__ == "__main__":
    # Manual smoke test of the completion helpers. Every call below sends live
    # requests to the configured model providers and pretty-prints the result.

    def _qa_prompt(question):
        """Build a chat consisting of the fixed system prompt followed by one user question."""
        return [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": question},
        ]

    # Batched: one model, several independent prompts.
    batched_prompts = [
        _qa_prompt(q)
        for q in ("What is the capital of France?", "What is the capital of Japan?")
    ]
    pprint.pprint(get_chat_completion(
        model_type="LLAMA3_405B_INSTRUCT_TURBO",
        messages=batched_prompts,
        batched=True,
    ))

    # Non batched: one model, a single prompt.
    pprint.pprint(get_chat_completion(
        model_type="LLAMA3_8B_INSTRUCT_TURBO",
        messages=_qa_prompt("What is the capital of France?"),
        batched=False,
    ))

    # Batched single completion to multiple models: prompt i is sent to model i.
    demo_models = ["LLAMA3_70B_INSTRUCT_TURBO", "LLAMA3_405B_INSTRUCT_TURBO", "gpt-4o-mini"]
    demo_questions = [
        "What is the capital of China?",
        "What is the capital of France?",
        "What is the capital of Japan?",
    ]
    pprint.pprint(get_completion_multiple_models(
        models=demo_models,
        messages=[_qa_prompt(q) for q in demo_questions],
    ))
|