judgeval-0.0.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46)
  1. judgeval/__init__.py +83 -0
  2. judgeval/clients.py +19 -0
  3. judgeval/common/__init__.py +8 -0
  4. judgeval/common/exceptions.py +28 -0
  5. judgeval/common/logger.py +189 -0
  6. judgeval/common/tracer.py +587 -0
  7. judgeval/common/utils.py +763 -0
  8. judgeval/constants.py +55 -0
  9. judgeval/data/__init__.py +14 -0
  10. judgeval/data/api_example.py +111 -0
  11. judgeval/data/datasets/__init__.py +4 -0
  12. judgeval/data/datasets/dataset.py +407 -0
  13. judgeval/data/datasets/ground_truth.py +54 -0
  14. judgeval/data/datasets/utils.py +74 -0
  15. judgeval/data/example.py +76 -0
  16. judgeval/data/result.py +83 -0
  17. judgeval/data/scorer_data.py +86 -0
  18. judgeval/evaluation_run.py +130 -0
  19. judgeval/judges/__init__.py +7 -0
  20. judgeval/judges/base_judge.py +44 -0
  21. judgeval/judges/litellm_judge.py +49 -0
  22. judgeval/judges/mixture_of_judges.py +248 -0
  23. judgeval/judges/together_judge.py +55 -0
  24. judgeval/judges/utils.py +45 -0
  25. judgeval/judgment_client.py +244 -0
  26. judgeval/run_evaluation.py +355 -0
  27. judgeval/scorers/__init__.py +30 -0
  28. judgeval/scorers/base_scorer.py +51 -0
  29. judgeval/scorers/custom_scorer.py +134 -0
  30. judgeval/scorers/judgeval_scorers/__init__.py +21 -0
  31. judgeval/scorers/judgeval_scorers/answer_relevancy.py +19 -0
  32. judgeval/scorers/judgeval_scorers/contextual_precision.py +19 -0
  33. judgeval/scorers/judgeval_scorers/contextual_recall.py +19 -0
  34. judgeval/scorers/judgeval_scorers/contextual_relevancy.py +22 -0
  35. judgeval/scorers/judgeval_scorers/faithfulness.py +19 -0
  36. judgeval/scorers/judgeval_scorers/hallucination.py +19 -0
  37. judgeval/scorers/judgeval_scorers/json_correctness.py +32 -0
  38. judgeval/scorers/judgeval_scorers/summarization.py +20 -0
  39. judgeval/scorers/judgeval_scorers/tool_correctness.py +19 -0
  40. judgeval/scorers/prompt_scorer.py +439 -0
  41. judgeval/scorers/score.py +427 -0
  42. judgeval/scorers/utils.py +175 -0
  43. judgeval-0.0.1.dist-info/METADATA +40 -0
  44. judgeval-0.0.1.dist-info/RECORD +46 -0
  45. judgeval-0.0.1.dist-info/WHEEL +4 -0
  46. judgeval-0.0.1.dist-info/licenses/LICENSE.md +202 -0
judgeval/common/utils.py
@@ -0,0 +1,763 @@
+ """
+ This file contains utility functions used in repo scripts
+
+ For API calling, we support:
+ - parallelized model calls on the same prompt
+ - batched model calls on different prompts
+
+ NOTE: any function beginning with 'a', e.g. 'afetch_together_api_response', is an asynchronous function
+ """
+
+ import concurrent.futures
+ from typing import List, Mapping, Dict, Union, Optional, Literal, Any
+ import asyncio
+ import litellm
+ import pydantic
+ import pprint
+ import os
+ from dotenv import load_dotenv
+
+ from judgeval.clients import async_together_client, together_client
+ from judgeval.constants import *
+ from judgeval.common.logger import debug, error
+
+ LITELLM_SUPPORTED_MODELS = set(litellm.model_list)
+
+ class CustomModelParameters(pydantic.BaseModel):
+     model_name: str
+     secret_key: str
+     litellm_base_url: str
+
+     @pydantic.field_validator('model_name')
+     def validate_model_name(cls, v):
+         if not v:
+             raise ValueError("Model name cannot be empty")
+         return v
+
+     @pydantic.field_validator('secret_key')
+     def validate_secret_key(cls, v):
+         if not v:
+             raise ValueError("Secret key cannot be empty")
+         return v
+
+     @pydantic.field_validator('litellm_base_url')
+     def validate_litellm_base_url(cls, v):
+         if not v:
+             raise ValueError("Litellm base URL cannot be empty")
+         return v
+
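+ # Illustrative sketch added by the editor (not part of the released file):
+ # constructing CustomModelParameters for a model served behind a
+ # LiteLLM-compatible endpoint. The model name, key, and URL below are
+ # hypothetical placeholders.
+ #
+ # custom_params = CustomModelParameters(
+ #     model_name="my-org/my-custom-model",
+ #     secret_key="sk-placeholder",
+ #     litellm_base_url="https://litellm.example.com",
+ # )
+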
+ class ChatCompletionRequest(pydantic.BaseModel):
+     model: str
+     messages: List[Dict[str, str]]
+     response_format: Optional[Union[pydantic.BaseModel, Dict[str, Any]]] = None
+
+     @pydantic.field_validator('messages')
+     def validate_messages(cls, messages):
+         if not messages:
+             raise ValueError("Messages cannot be empty")
+
+         for msg in messages:
+             if not isinstance(msg, dict):
+                 raise TypeError("Message must be a dictionary")
+             if 'role' not in msg:
+                 raise ValueError("Message missing required 'role' field")
+             if 'content' not in msg:
+                 raise ValueError("Message missing required 'content' field")
+             if msg['role'] not in ['system', 'user', 'assistant']:
+                 raise ValueError(f"Invalid role '{msg['role']}'. Must be 'system', 'user', or 'assistant'")
+
+         return messages
+
+     @pydantic.field_validator('model')
+     def validate_model(cls, model):
+         if not model:
+             raise ValueError("Model cannot be empty")
+         if model not in TOGETHER_SUPPORTED_MODELS and model not in LITELLM_SUPPORTED_MODELS:
+             raise ValueError(f"Model {model} is not in the list of supported models.")
+         return model
+
+     @pydantic.field_validator('response_format', mode='before')
+     def validate_response_format(cls, response_format):
+         if response_format is not None:
+             if not isinstance(response_format, (dict, pydantic.BaseModel)):
+                 raise TypeError("Response format must be a dictionary or pydantic model")
+             # Optional: Add additional validation for required fields if needed
+             # For example, checking for 'type': 'json' in OpenAI's format
+         return response_format
+
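+ # Illustrative sketch added by the editor (not part of the released file):
+ # the shape of request the validators above accept -- every message needs a
+ # 'role' of 'system', 'user', or 'assistant' and a 'content' string, and the
+ # model must appear in TOGETHER_SUPPORTED_MODELS or LITELLM_SUPPORTED_MODELS.
+ #
+ # request = ChatCompletionRequest(
+ #     model="gpt-4o-mini",
+ #     messages=[
+ #         {"role": "system", "content": "You are a helpful assistant."},
+ #         {"role": "user", "content": "What is the capital of France?"},
+ #     ],
+ # )
+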
+ os.environ['LITELLM_LOG'] = 'DEBUG'
+
+ load_dotenv()
+
+ def read_file(file_path: str) -> str:
+     with open(file_path, "r", encoding='utf-8') as file:
+         return file.read()
+
+
+ def fetch_together_api_response(model: str, messages: List[Mapping], response_format: pydantic.BaseModel = None) -> str:
+     """
+     Fetches a single response from the Together API for a given model and messages.
+     """
+     # Validate request
+     if messages is None or messages == []:
+         raise ValueError("Messages cannot be empty")
+
+     request = ChatCompletionRequest(
+         model=model,
+         messages=messages,
+         response_format=response_format
+     )
+
+     debug(f"Calling Together API with model: {request.model}")
+     debug(f"Messages: {request.messages}")
+
+     if request.response_format is not None:
+         debug(f"Using response format: {request.response_format}")
+         response = together_client.chat.completions.create(
+             model=TOGETHER_SUPPORTED_MODELS.get(request.model),
+             messages=request.messages,
+             response_format=request.response_format
+         )
+     else:
+         response = together_client.chat.completions.create(
+             model=TOGETHER_SUPPORTED_MODELS.get(request.model),
+             messages=request.messages,
+         )
+
+     debug(f"Received response: {response.choices[0].message.content[:100]}...")
+     return response.choices[0].message.content
+
+
+ async def afetch_together_api_response(model: str, messages: List[Mapping], response_format: pydantic.BaseModel = None) -> str:
+     """
+     ASYNCHRONOUSLY fetches a single response from the Together API for a given model and messages.
+     """
+     request = ChatCompletionRequest(
+         model=model,
+         messages=messages,
+         response_format=response_format
+     )
+
+     debug(f"Calling Together API with model: {request.model}")
+     debug(f"Messages: {request.messages}")
+
+     if request.response_format is not None:
+         debug(f"Using response format: {request.response_format}")
+         response = await async_together_client.chat.completions.create(
+             model=TOGETHER_SUPPORTED_MODELS.get(request.model),
+             messages=request.messages,
+             response_format=request.response_format
+         )
+     else:
+         response = await async_together_client.chat.completions.create(
+             model=TOGETHER_SUPPORTED_MODELS.get(request.model),
+             messages=request.messages,
+         )
+     return response.choices[0].message.content
+
+
+ def query_together_api_multiple_calls(models: List[str], messages: List[List[Mapping]], response_formats: List[pydantic.BaseModel] = None) -> List[str]:
+     """
+     Queries the Together API for multiple calls in parallel
+
+     Args:
+         models (List[str]): List of models to query
+         messages (List[List[Mapping]]): List of messages to query. Each inner object corresponds to a single prompt.
+         response_formats (List[pydantic.BaseModel], optional): A list of the format of the response if JSON forcing. Defaults to None.
+
+     Returns:
+         List[str]: TogetherAI responses for each model and message pair in order. Any exceptions in the thread call result in a None.
+     """
+     # Check for empty models list
+     if not models:
+         raise ValueError("Models list cannot be empty")
+
+     # Validate all models are supported
+     for model in models:
+         if model not in TOGETHER_SUPPORTED_MODELS:
+             raise ValueError(f"Model {model} is not in the list of supported TogetherAI models: {TOGETHER_SUPPORTED_MODELS}.")
+
+     # Validate input lengths match
+     if response_formats is None:
+         response_formats = [None] * len(models)
+     if not (len(models) == len(messages) == len(response_formats)):
+         raise ValueError("Number of models, messages, and response formats must be the same")
+
+     # Validate message format
+     validate_batched_chat_messages(messages)
+
+     num_workers = int(os.getenv('NUM_WORKER_THREADS', MAX_WORKER_THREADS))
+     # Initialize results to maintain ordered outputs
+     out = [None] * len(messages)
+     with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) as executor:
+         # Submit all queries to together API with index, gets back the response content
+         futures = {executor.submit(fetch_together_api_response, model, message, response_format): idx
+                    for idx, (model, message, response_format) in enumerate(zip(models, messages, response_formats))}
+
+         # Collect results as they complete -- result is response content
+         for future in concurrent.futures.as_completed(futures):
+             idx = futures[future]
+             try:
+                 out[idx] = future.result()
+             except Exception as e:
+                 error(f"Error in parallel call {idx}: {str(e)}")
+                 out[idx] = None
+     return out
+
+
+ async def aquery_together_api_multiple_calls(models: List[str], messages: List[List[Mapping]], response_formats: List[pydantic.BaseModel] = None) -> List[str]:
+     """
+     Queries the Together API for multiple calls in parallel
+
+     Args:
+         models (List[str]): List of models to query
+         messages (List[List[Mapping]]): List of messages to query. Each inner object corresponds to a single prompt.
+         response_formats (List[pydantic.BaseModel], optional): A list of the format of the response if JSON forcing. Defaults to None.
+
+     Returns:
+         List[str]: TogetherAI responses for each model and message pair in order. Any exceptions in the thread call result in a None.
+     """
+     # Check for empty models list
+     if not models:
+         raise ValueError("Models list cannot be empty")
+
+     # Validate all models are supported
+     for model in models:
+         if model not in TOGETHER_SUPPORTED_MODELS:
+             raise ValueError(f"Model {model} is not in the list of supported TogetherAI models: {TOGETHER_SUPPORTED_MODELS}.")
+
+     # Validate input lengths match
+     if response_formats is None:
+         response_formats = [None] * len(models)
+     if not (len(models) == len(messages) == len(response_formats)):
+         raise ValueError("Number of models, messages, and response formats must be the same")
+
+     # Validate message format
+     validate_batched_chat_messages(messages)
+
+     debug(f"Starting parallel Together API calls for {len(messages)} messages")
+     out = [None] * len(messages)
+
+     async def fetch_and_store(idx, model, message, response_format):
+         try:
+             debug(f"Processing call {idx} with model {model}")
+             out[idx] = await afetch_together_api_response(model, message, response_format)
+         except Exception as e:
+             error(f"Error in parallel call {idx}: {str(e)}")
+             out[idx] = None
+
+     tasks = [
+         fetch_and_store(idx, model, message, response_format)
+         for idx, (model, message, response_format) in enumerate(zip(models, messages, response_formats))
+     ]
+
+     await asyncio.gather(*tasks)
+     debug(f"Completed {len(messages)} parallel calls")
+     return out
+
+
+ def fetch_litellm_api_response(model: str, messages: List[Mapping], response_format: pydantic.BaseModel = None) -> str:
+     """
+     Fetches a single response from the Litellm API for a given model and messages.
+     """
+     request = ChatCompletionRequest(
+         model=model,
+         messages=messages,
+         response_format=response_format
+     )
+
+     debug(f"Calling LiteLLM API with model: {request.model}")
+     debug(f"Messages: {request.messages}")
+
+     if request.response_format is not None:
+         debug(f"Using response format: {request.response_format}")
+         response = litellm.completion(
+             model=request.model,
+             messages=request.messages,
+             response_format=request.response_format
+         )
+     else:
+         response = litellm.completion(
+             model=request.model,
+             messages=request.messages,
+         )
+     return response.choices[0].message.content
+
+
+ def fetch_custom_litellm_api_response(custom_model_parameters: CustomModelParameters, messages: List[Mapping], response_format: pydantic.BaseModel = None) -> str:
+     """
+     Fetches a single response from a custom model endpoint routed through LiteLLM.
+     """
+     if messages is None or messages == []:
+         raise ValueError("Messages cannot be empty")
+
+     if custom_model_parameters is None:
+         raise ValueError("Custom model parameters cannot be empty")
+
+     if not isinstance(custom_model_parameters, CustomModelParameters):
+         raise ValueError("Custom model parameters must be a CustomModelParameters object")
+
+     if response_format is not None:
+         response = litellm.completion(
+             model=custom_model_parameters.model_name,
+             messages=messages,
+             api_key=custom_model_parameters.secret_key,
+             base_url=custom_model_parameters.litellm_base_url,
+             response_format=response_format
+         )
+     else:
+         response = litellm.completion(
+             model=custom_model_parameters.model_name,
+             messages=messages,
+             api_key=custom_model_parameters.secret_key,
+             base_url=custom_model_parameters.litellm_base_url,
+         )
+     return response.choices[0].message.content
+
+
+ async def afetch_litellm_api_response(model: str, messages: List[Mapping], response_format: pydantic.BaseModel = None) -> str:
+     """
+     ASYNCHRONOUSLY fetches a single response from the Litellm API for a given model and messages.
+     """
+     if messages is None or messages == []:
+         raise ValueError("Messages cannot be empty")
+
+     # Add validation
+     validate_chat_messages(messages)
+
+     if model not in LITELLM_SUPPORTED_MODELS:
+         raise ValueError(f"Model {model} is not in the list of supported Litellm models: {LITELLM_SUPPORTED_MODELS}.")
+
+     if response_format is not None:
+         response = await litellm.acompletion(
+             model=model,
+             messages=messages,
+             response_format=response_format
+         )
+     else:
+         response = await litellm.acompletion(
+             model=model,
+             messages=messages,
+         )
+     return response.choices[0].message.content
+
+
+ async def afetch_custom_litellm_api_response(custom_model_parameters: CustomModelParameters, messages: List[Mapping], response_format: pydantic.BaseModel = None) -> str:
+     """
+     ASYNCHRONOUSLY fetches a single response from a custom model endpoint routed through LiteLLM.
+     """
+     if messages is None or messages == []:
+         raise ValueError("Messages cannot be empty")
+
+     if custom_model_parameters is None:
+         raise ValueError("Custom model parameters cannot be empty")
+
+     if not isinstance(custom_model_parameters, CustomModelParameters):
+         raise ValueError("Custom model parameters must be a CustomModelParameters object")
+
+     if response_format is not None:
+         response = await litellm.acompletion(
+             model=custom_model_parameters.model_name,
+             messages=messages,
+             api_key=custom_model_parameters.secret_key,
+             base_url=custom_model_parameters.litellm_base_url,
+             response_format=response_format
+         )
+     else:
+         response = await litellm.acompletion(
+             model=custom_model_parameters.model_name,
+             messages=messages,
+             api_key=custom_model_parameters.secret_key,
+             base_url=custom_model_parameters.litellm_base_url,
+         )
+     return response.choices[0].message.content
+
+
+ def query_litellm_api_multiple_calls(models: List[str], messages: List[Mapping], response_formats: List[pydantic.BaseModel] = None) -> List[str]:
+     """
+     Queries the Litellm API for multiple calls in parallel
+
+     Args:
+         models (List[str]): List of models to query
+         messages (List[Mapping]): List of messages to query
+         response_formats (List[pydantic.BaseModel], optional): A list of the format of the response if JSON forcing. Defaults to None.
+
+     Returns:
+         List[str]: Litellm responses for each model and message pair in order. Any exceptions in the thread call result in a None.
+     """
+     # Default to no response format for every call so the zip below stays aligned
+     if response_formats is None:
+         response_formats = [None] * len(models)
+
+     num_workers = int(os.getenv('NUM_WORKER_THREADS', MAX_WORKER_THREADS))
+     # Initialize results to maintain ordered outputs
+     out = [None] * len(messages)
+     with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) as executor:
+         # Submit all queries to Litellm API with index, gets back the response content
+         futures = {executor.submit(fetch_litellm_api_response, model, message, response_format): idx
+                    for idx, (model, message, response_format) in enumerate(zip(models, messages, response_formats))}
+
+         # Collect results as they complete -- result is response content
+         for future in concurrent.futures.as_completed(futures):
+             idx = futures[future]
+             try:
+                 out[idx] = future.result()
+             except Exception as e:
+                 error(f"Error in parallel call {idx}: {str(e)}")
+                 out[idx] = None
+     return out
+
+
+ async def aquery_litellm_api_multiple_calls(models: List[str], messages: List[Mapping], response_formats: List[pydantic.BaseModel] = None) -> List[str]:
+     """
+     Queries the Litellm API for multiple calls in parallel
+
+     Args:
+         models (List[str]): List of models to query
+         messages (List[Mapping]): List of messages to query
+         response_formats (List[pydantic.BaseModel], optional): A list of the format of the response if JSON forcing. Defaults to None.
+
+     Returns:
+         List[str]: Litellm responses for each model and message pair in order. Any exceptions in the thread call result in a None.
+     """
+     # Default to no response format for every call so the zip below stays aligned
+     if response_formats is None:
+         response_formats = [None] * len(models)
+
+     # Initialize results to maintain ordered outputs
+     out = [None] * len(messages)
+
+     async def fetch_and_store(idx, model, message, response_format):
+         try:
+             out[idx] = await afetch_litellm_api_response(model, message, response_format)
+         except Exception as e:
+             error(f"Error in parallel call {idx}: {str(e)}")
+             out[idx] = None
+
+     tasks = [
+         fetch_and_store(idx, model, message, response_format)
+         for idx, (model, message, response_format) in enumerate(zip(models, messages, response_formats))
+     ]
+
+     await asyncio.gather(*tasks)
+     return out
+
+
+ def validate_chat_messages(messages, batched: bool = False):
+     """Validate chat message format before API call"""
+     if not isinstance(messages, list):
+         raise TypeError("Messages must be a list")
+
+     for msg in messages:
+         if not isinstance(msg, dict):
+             if batched and isinstance(msg, list):
+                 # Batched mode: each item is itself a list of messages, so validate it recursively
+                 validate_chat_messages(msg)
+                 continue
+             elif batched:
+                 raise TypeError("Each message must be a list")
+             else:
+                 raise TypeError("Message must be a dictionary")
+         if 'role' not in msg:
+             raise ValueError("Message missing required 'role' field")
+         if 'content' not in msg:
+             raise ValueError("Message missing required 'content' field")
+         if msg['role'] not in ['system', 'user', 'assistant']:
+             raise ValueError(f"Invalid role '{msg['role']}'. Must be 'system', 'user', or 'assistant'")
+
+ def validate_batched_chat_messages(messages: List[List[Mapping]]):
+     """
+     Validate format of batched chat messages before API call
+
+     Args:
+         messages (List[List[Mapping]]): List of message lists, where each inner list contains
+             message dictionaries with 'role' and 'content' fields
+
+     Raises:
+         TypeError: If messages format is invalid
+         ValueError: If message content is invalid
+     """
+     if not isinstance(messages, list):
+         raise TypeError("Batched messages must be a list")
+
+     if not messages:
+         raise ValueError("Batched messages cannot be empty")
+
+     for message_list in messages:
+         if not isinstance(message_list, list):
+             raise TypeError("Each batch item must be a list of messages")
+
+         # Validate individual messages using existing function
+         validate_chat_messages(message_list)
+
+ def get_chat_completion(model_type: str,
+                         messages: Union[List[Mapping], List[List[Mapping]]],
+                         response_format: pydantic.BaseModel = None,
+                         batched: bool = False
+                         ) -> Union[str, List[str]]:
+     """
+     Generates chat completions using a single model and potentially several messages. Supports closed-source and OSS models.
+
+     Parameters:
+         - model_type (str): The type of model to use for generating completions.
+         - messages (Union[List[Mapping], List[List[Mapping]]]): The messages to be used for generating completions.
+           If batched is True, this should be a list of lists of mappings.
+         - response_format (pydantic.BaseModel, optional): The format of the response. Defaults to None.
+         - batched (bool, optional): Whether to process messages in batch mode. Defaults to False.
+     Returns:
+         - str: The generated chat completion(s). If batched is True, returns a list of strings.
+     Raises:
+         - ValueError: If requested model is not supported by Litellm or TogetherAI.
+     """
+     # Check for empty messages list
+     if not messages or messages == []:
+         raise ValueError("Messages cannot be empty")
+
+     # Add validation
+     if batched:
+         validate_batched_chat_messages(messages)
+     else:
+         validate_chat_messages(messages)
+
+     if batched and model_type in TOGETHER_SUPPORTED_MODELS:
+         return query_together_api_multiple_calls(models=[model_type] * len(messages),
+                                                  messages=messages,
+                                                  response_formats=[response_format] * len(messages))
+     elif batched and model_type in LITELLM_SUPPORTED_MODELS:
+         return query_litellm_api_multiple_calls(models=[model_type] * len(messages),
+                                                 messages=messages,
+                                                 response_formats=[response_format] * len(messages))
+     elif not batched and model_type in TOGETHER_SUPPORTED_MODELS:
+         return fetch_together_api_response(model=model_type,
+                                            messages=messages,
+                                            response_format=response_format)
+     elif not batched and model_type in LITELLM_SUPPORTED_MODELS:
+         return fetch_litellm_api_response(model=model_type,
+                                           messages=messages,
+                                           response_format=response_format)
+
+     raise ValueError(f"Model {model_type} is not supported by Litellm or TogetherAI for chat completions. Please check the model name and try again.")
+
+
+ async def aget_chat_completion(model_type: str,
+                                messages: Union[List[Mapping], List[List[Mapping]]],
+                                response_format: pydantic.BaseModel = None,
+                                batched: bool = False
+                                ) -> Union[str, List[str]]:
+     """
+     ASYNCHRONOUSLY generates chat completions using a single model and potentially several messages. Supports closed-source and OSS models.
+
+     Parameters:
+         - model_type (str): The type of model to use for generating completions.
+         - messages (Union[List[Mapping], List[List[Mapping]]]): The messages to be used for generating completions.
+           If batched is True, this should be a list of lists of mappings.
+         - response_format (pydantic.BaseModel, optional): The format of the response. Defaults to None.
+         - batched (bool, optional): Whether to process messages in batch mode. Defaults to False.
+     Returns:
+         - str: The generated chat completion(s). If batched is True, returns a list of strings.
+     Raises:
+         - ValueError: If requested model is not supported by Litellm or TogetherAI.
+     """
+     debug(f"Starting chat completion for model {model_type}, batched={batched}")
+
+     if batched:
+         validate_batched_chat_messages(messages)
+     else:
+         validate_chat_messages(messages)
+
+     if batched and model_type in TOGETHER_SUPPORTED_MODELS:
+         debug("Using batched Together API call")
+         return await aquery_together_api_multiple_calls(models=[model_type] * len(messages),
+                                                         messages=messages,
+                                                         response_formats=[response_format] * len(messages))
+     elif batched and model_type in LITELLM_SUPPORTED_MODELS:
+         debug("Using batched LiteLLM API call")
+         return await aquery_litellm_api_multiple_calls(models=[model_type] * len(messages),
+                                                        messages=messages,
+                                                        response_formats=[response_format] * len(messages))
+     elif not batched and model_type in TOGETHER_SUPPORTED_MODELS:
+         debug("Using single Together API call")
+         return await afetch_together_api_response(model=model_type,
+                                                   messages=messages,
+                                                   response_format=response_format)
+     elif not batched and model_type in LITELLM_SUPPORTED_MODELS:
+         debug("Using single LiteLLM API call")
+         return await afetch_litellm_api_response(model=model_type,
+                                                  messages=messages,
+                                                  response_format=response_format)
+
+     error(f"Model {model_type} not supported by either API")
+     raise ValueError(f"Model {model_type} is not supported by Litellm or TogetherAI for chat completions. Please check the model name and try again.")
+
+
580
+
581
+ def get_completion_multiple_models(models: List[str], messages: List[List[Mapping]], response_formats: List[pydantic.BaseModel] = None) -> List[str]:
582
+ """
583
+ Retrieves completions for a single prompt from multiple models in parallel. Supports closed-source and OSS models.
584
+
585
+ Args:
586
+ models (List[str]): List of models to query
587
+ messages (List[List[Mapping]]): List of messages to query. Each inner object corresponds to a single prompt.
588
+ response_formats (List[pydantic.BaseModel], optional): A list of the format of the response if JSON forcing. Defaults to None.
589
+
590
+ Returns:
591
+ List[str]: List of completions from the models in the order of the input models
592
+ Raises:
593
+ ValueError: If a model is not supported by Litellm or Together
594
+ """
595
+ debug(f"Starting multiple model completion for {len(models)} models")
596
+
597
+ if models is None or models == []:
598
+ raise ValueError("Models list cannot be empty")
599
+
600
+ validate_batched_chat_messages(messages)
601
+
602
+ if len(models) != len(messages):
603
+ error(f"Model/message count mismatch: {len(models)} vs {len(messages)}")
604
+ raise ValueError(f"Number of models and messages must be the same: {len(models)} != {len(messages)}")
605
+ if response_formats is None:
606
+ response_formats = [None] * len(models)
607
+ # Partition the model requests into TogetherAI and Litellm models, but keep the ordering saved
608
+ together_calls, litellm_calls = {}, {} # index -> model, message, response_format
609
+ together_responses, litellm_responses = [], []
610
+ for idx, (model, message, r_format) in enumerate(zip(models, messages, response_formats)):
611
+ if model in TOGETHER_SUPPORTED_MODELS:
612
+ debug(f"Model {model} routed to Together API")
613
+ together_calls[idx] = (model, message, r_format)
614
+ elif model in LITELLM_SUPPORTED_MODELS:
615
+ debug(f"Model {model} routed to LiteLLM API")
616
+ litellm_calls[idx] = (model, message, r_format)
617
+ else:
618
+ error(f"Model {model} not supported by either API")
619
+ raise ValueError(f"Model {model} is not supported by Litellm or TogetherAI for chat completions. Please check the model name and try again.")
620
+
621
+ # Add validation before processing
622
+ for msg_list in messages:
623
+ validate_chat_messages(msg_list)
624
+
625
+ # Get the responses from the TogetherAI models
626
+ # List of responses from the TogetherAI models in order of the together_calls dict
627
+ if together_calls:
628
+ debug(f"Executing {len(together_calls)} Together API calls")
629
+ together_responses = query_together_api_multiple_calls(models=[model for model, _, _ in together_calls.values()],
630
+ messages=[message for _, message, _ in together_calls.values()],
631
+ response_formats=[format for _, _, format in together_calls.values()])
632
+
633
+ # Get the responses from the Litellm models
634
+ if litellm_calls:
635
+ debug(f"Executing {len(litellm_calls)} LiteLLM API calls")
636
+ litellm_responses = query_litellm_api_multiple_calls(models=[model for model, _, _ in litellm_calls.values()],
637
+ messages=[message for _, message, _ in litellm_calls.values()],
638
+ response_formats=[format for _, _, format in litellm_calls.values()])
639
+
640
+ # Merge the responses in the order of the original models
641
+ debug("Merging responses")
642
+ out = [None] * len(models)
643
+ for idx, (model, message, r_format) in together_calls.items():
644
+ out[idx] = together_responses.pop(0)
645
+ for idx, (model, message, r_format) in litellm_calls.items():
646
+ out[idx] = litellm_responses.pop(0)
647
+ debug("Multiple model completion finished")
648
+ return out
649
+
650
+
+ async def aget_completion_multiple_models(models: List[str], messages: List[List[Mapping]], response_formats: List[pydantic.BaseModel] = None) -> List[str]:
+     """
+     ASYNCHRONOUSLY retrieves completions for a single prompt from multiple models in parallel. Supports closed-source and OSS models.
+
+     Args:
+         models (List[str]): List of models to query
+         messages (List[List[Mapping]]): List of messages to query. Each inner object corresponds to a single prompt.
+         response_formats (List[pydantic.BaseModel], optional): A list of the format of the response if JSON forcing. Defaults to None.
+
+     Returns:
+         List[str]: List of completions from the models in the order of the input models
+     Raises:
+         ValueError: If a model is not supported by Litellm or Together
+     """
+     if models is None or models == []:
+         raise ValueError("Models list cannot be empty")
+
+     if len(models) != len(messages):
+         raise ValueError(f"Number of models and messages must be the same: {len(models)} != {len(messages)}")
+     if response_formats is None:
+         response_formats = [None] * len(models)
+
+     validate_batched_chat_messages(messages)
+
+     # Partition the model requests into TogetherAI and Litellm models, but keep the ordering saved
+     together_calls, litellm_calls = {}, {}  # index -> model, message, response_format
+     together_responses, litellm_responses = [], []
+     for idx, (model, message, r_format) in enumerate(zip(models, messages, response_formats)):
+         if model in TOGETHER_SUPPORTED_MODELS:
+             together_calls[idx] = (model, message, r_format)
+         elif model in LITELLM_SUPPORTED_MODELS:
+             litellm_calls[idx] = (model, message, r_format)
+         else:
+             raise ValueError(f"Model {model} is not supported by Litellm or TogetherAI for chat completions. Please check the model name and try again.")
+
+     # Add validation before processing
+     for msg_list in messages:
+         validate_chat_messages(msg_list)
+
+     # Get the responses from the TogetherAI models
+     # List of responses from the TogetherAI models in order of the together_calls dict
+     if together_calls:
+         together_responses = await aquery_together_api_multiple_calls(
+             models=[model for model, _, _ in together_calls.values()],
+             messages=[message for _, message, _ in together_calls.values()],
+             response_formats=[format for _, _, format in together_calls.values()]
+         )
+
+     # Get the responses from the Litellm models
+     if litellm_calls:
+         litellm_responses = await aquery_litellm_api_multiple_calls(
+             models=[model for model, _, _ in litellm_calls.values()],
+             messages=[message for _, message, _ in litellm_calls.values()],
+             response_formats=[format for _, _, format in litellm_calls.values()]
+         )
+
+     # Merge the responses in the order of the original models
+     out = [None] * len(models)
+     for idx, (model, message, r_format) in together_calls.items():
+         out[idx] = together_responses.pop(0)
+     for idx, (model, message, r_format) in litellm_calls.items():
+         out[idx] = litellm_responses.pop(0)
+     return out
+
+
+ if __name__ == "__main__":
+
+     # Batched
+     pprint.pprint(get_chat_completion(
+         model_type="LLAMA3_405B_INSTRUCT_TURBO",
+         messages=[
+             [
+                 {"role": "system", "content": "You are a helpful assistant."},
+                 {"role": "user", "content": "What is the capital of France?"},
+             ],
+             [
+                 {"role": "system", "content": "You are a helpful assistant."},
+                 {"role": "user", "content": "What is the capital of Japan?"},
+             ]
+         ],
+         batched=True
+     ))
+
+     # Non batched
+     pprint.pprint(get_chat_completion(
+         model_type="LLAMA3_8B_INSTRUCT_TURBO",
+         messages=[
+             {"role": "system", "content": "You are a helpful assistant."},
+             {"role": "user", "content": "What is the capital of France?"},
+         ],
+         batched=False
+     ))
+
+     # Batched single completion to multiple models
+     pprint.pprint(get_completion_multiple_models(
+         models=[
+             "LLAMA3_70B_INSTRUCT_TURBO", "LLAMA3_405B_INSTRUCT_TURBO", "gpt-4o-mini"
+         ],
+         messages=[
+             [
+                 {"role": "system", "content": "You are a helpful assistant."},
+                 {"role": "user", "content": "What is the capital of China?"},
+             ],
+             [
+                 {"role": "system", "content": "You are a helpful assistant."},
+                 {"role": "user", "content": "What is the capital of France?"},
+             ],
+             [
+                 {"role": "system", "content": "You are a helpful assistant."},
+                 {"role": "user", "content": "What is the capital of Japan?"},
+             ]
+         ]
+     ))
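+
+     # Illustrative sketch added by the editor (not part of the released file):
+     # the async variants defined above take the same arguments as their sync
+     # counterparts and are driven with asyncio, e.g.:
+     #
+     # pprint.pprint(asyncio.run(aget_chat_completion(
+     #     model_type="LLAMA3_8B_INSTRUCT_TURBO",
+     #     messages=[
+     #         {"role": "system", "content": "You are a helpful assistant."},
+     #         {"role": "user", "content": "What is the capital of France?"},
+     #     ],
+     #     batched=False
+     # )))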