judgeval 0.0.44__py3-none-any.whl → 0.0.46__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (64)
  1. judgeval/__init__.py +5 -4
  2. judgeval/clients.py +6 -6
  3. judgeval/common/__init__.py +7 -2
  4. judgeval/common/exceptions.py +2 -3
  5. judgeval/common/logger.py +74 -49
  6. judgeval/common/s3_storage.py +30 -23
  7. judgeval/common/tracer.py +1273 -939
  8. judgeval/common/utils.py +416 -244
  9. judgeval/constants.py +73 -61
  10. judgeval/data/__init__.py +1 -1
  11. judgeval/data/custom_example.py +3 -2
  12. judgeval/data/datasets/dataset.py +80 -54
  13. judgeval/data/datasets/eval_dataset_client.py +131 -181
  14. judgeval/data/example.py +67 -43
  15. judgeval/data/result.py +11 -9
  16. judgeval/data/scorer_data.py +4 -2
  17. judgeval/data/tool.py +25 -16
  18. judgeval/data/trace.py +57 -29
  19. judgeval/data/trace_run.py +5 -11
  20. judgeval/evaluation_run.py +22 -82
  21. judgeval/integrations/langgraph.py +546 -184
  22. judgeval/judges/base_judge.py +1 -2
  23. judgeval/judges/litellm_judge.py +33 -11
  24. judgeval/judges/mixture_of_judges.py +128 -78
  25. judgeval/judges/together_judge.py +22 -9
  26. judgeval/judges/utils.py +14 -5
  27. judgeval/judgment_client.py +259 -271
  28. judgeval/rules.py +169 -142
  29. judgeval/run_evaluation.py +462 -305
  30. judgeval/scorers/api_scorer.py +20 -11
  31. judgeval/scorers/exceptions.py +1 -0
  32. judgeval/scorers/judgeval_scorer.py +77 -58
  33. judgeval/scorers/judgeval_scorers/api_scorers/__init__.py +46 -15
  34. judgeval/scorers/judgeval_scorers/api_scorers/answer_correctness.py +3 -2
  35. judgeval/scorers/judgeval_scorers/api_scorers/answer_relevancy.py +3 -2
  36. judgeval/scorers/judgeval_scorers/api_scorers/classifier_scorer.py +12 -11
  37. judgeval/scorers/judgeval_scorers/api_scorers/comparison.py +7 -5
  38. judgeval/scorers/judgeval_scorers/api_scorers/contextual_precision.py +3 -2
  39. judgeval/scorers/judgeval_scorers/api_scorers/contextual_recall.py +3 -2
  40. judgeval/scorers/judgeval_scorers/api_scorers/contextual_relevancy.py +5 -2
  41. judgeval/scorers/judgeval_scorers/api_scorers/derailment_scorer.py +2 -1
  42. judgeval/scorers/judgeval_scorers/api_scorers/execution_order.py +17 -8
  43. judgeval/scorers/judgeval_scorers/api_scorers/faithfulness.py +3 -2
  44. judgeval/scorers/judgeval_scorers/api_scorers/groundedness.py +3 -2
  45. judgeval/scorers/judgeval_scorers/api_scorers/hallucination.py +3 -2
  46. judgeval/scorers/judgeval_scorers/api_scorers/instruction_adherence.py +3 -2
  47. judgeval/scorers/judgeval_scorers/api_scorers/json_correctness.py +8 -9
  48. judgeval/scorers/judgeval_scorers/api_scorers/summarization.py +4 -4
  49. judgeval/scorers/judgeval_scorers/api_scorers/tool_dependency.py +5 -5
  50. judgeval/scorers/judgeval_scorers/api_scorers/tool_order.py +5 -2
  51. judgeval/scorers/judgeval_scorers/classifiers/text2sql/text2sql_scorer.py +9 -10
  52. judgeval/scorers/prompt_scorer.py +48 -37
  53. judgeval/scorers/score.py +86 -53
  54. judgeval/scorers/utils.py +11 -7
  55. judgeval/tracer/__init__.py +1 -1
  56. judgeval/utils/alerts.py +23 -12
  57. judgeval/utils/{data_utils.py → file_utils.py} +5 -9
  58. judgeval/utils/requests.py +29 -0
  59. judgeval/version_check.py +5 -2
  60. {judgeval-0.0.44.dist-info → judgeval-0.0.46.dist-info}/METADATA +79 -135
  61. judgeval-0.0.46.dist-info/RECORD +69 -0
  62. judgeval-0.0.44.dist-info/RECORD +0 -68
  63. {judgeval-0.0.44.dist-info → judgeval-0.0.46.dist-info}/WHEEL +0 -0
  64. {judgeval-0.0.44.dist-info → judgeval-0.0.46.dist-info}/licenses/LICENSE.md +0 -0
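Worth noting before the diff below: judgeval/common/utils.py now pulls requests from the new judgeval/utils/requests.py module (entry 58 above, +29 lines) instead of importing the library directly. The body of that module is not included in this diff, so the following is only a hypothetical sketch of what such a wrapper commonly looks like, assuming its purpose is to route HTTP calls through one shared session with common defaults; every name and default below is illustrative, not taken from the package.

# Hypothetical sketch only -- the real judgeval/utils/requests.py is not shown
# in this diff. The idea: expose a requests-compatible object whose verbs all
# go through one Session with shared defaults.
from typing import Any

import requests as _requests


class _SessionWrapper:
    """Funnel every call through a single Session so defaults apply uniformly."""

    def __init__(self) -> None:
        self._session = _requests.Session()

    def request(self, method: str, url: str, **kwargs: Any) -> _requests.Response:
        kwargs.setdefault("timeout", 30)  # assumed default, purely illustrative
        return self._session.request(method, url, **kwargs)

    def get(self, url: str, **kwargs: Any) -> _requests.Response:
        return self.request("GET", url, **kwargs)

    def post(self, url: str, **kwargs: Any) -> _requests.Response:
        return self.request("POST", url, **kwargs)


# Matches the import style used in the diff below:
# from judgeval.utils.requests import requests
requests = _SessionWrapper()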
judgeval/common/utils.py CHANGED
@@ -2,7 +2,7 @@
  This file contains utility functions used in repo scripts

  For API calling, we support:
- - parallelized model calls on the same prompt
+ - parallelized model calls on the same prompt
  - batched model calls on different prompts

  NOTE: any function beginning with 'a', e.g. 'afetch_together_api_response', is an asynchronous function
@@ -13,9 +13,9 @@ import asyncio
  import concurrent.futures
  import os
  from types import TracebackType
- import requests
+ from judgeval.utils.requests import requests
  import pprint
- from typing import Any, Dict, List, Literal, Mapping, Optional, TypeAlias, Union
+ from typing import Any, Dict, List, Mapping, Optional, TypeAlias, Union, TypeGuard

  # Third-party imports
  import litellm
@@ -24,7 +24,13 @@ from dotenv import load_dotenv

  # Local application/library-specific imports
  from judgeval.clients import async_together_client, together_client
- from judgeval.constants import *
+ from judgeval.constants import (
+ ACCEPTABLE_MODELS,
+ MAX_WORKER_THREADS,
+ ROOT_API,
+ TOGETHER_SUPPORTED_MODELS,
+ LITELLM_SUPPORTED_MODELS,
+ )
  from judgeval.common.logger import debug, error


@@ -32,72 +38,80 @@ class CustomModelParameters(pydantic.BaseModel):
  model_name: str
  secret_key: str
  litellm_base_url: str
-
- @pydantic.field_validator('model_name')
+
+ @pydantic.field_validator("model_name")
  def validate_model_name(cls, v):
  if not v:
  raise ValueError("Model name cannot be empty")
  return v
-
- @pydantic.field_validator('secret_key')
+
+ @pydantic.field_validator("secret_key")
  def validate_secret_key(cls, v):
  if not v:
  raise ValueError("Secret key cannot be empty")
  return v
-
- @pydantic.field_validator('litellm_base_url')
+
+ @pydantic.field_validator("litellm_base_url")
  def validate_litellm_base_url(cls, v):
  if not v:
  raise ValueError("Litellm base URL cannot be empty")
  return v

+
  class ChatCompletionRequest(pydantic.BaseModel):
  model: str
  messages: List[Dict[str, str]]
  response_format: Optional[Union[pydantic.BaseModel, Dict[str, Any]]] = None
-
- @pydantic.field_validator('messages')
+
+ @pydantic.field_validator("messages")
  def validate_messages(cls, messages):
  if not messages:
  raise ValueError("Messages cannot be empty")
-
+
  for msg in messages:
  if not isinstance(msg, dict):
  raise TypeError("Message must be a dictionary")
- if 'role' not in msg:
+ if "role" not in msg:
  raise ValueError("Message missing required 'role' field")
- if 'content' not in msg:
+ if "content" not in msg:
  raise ValueError("Message missing required 'content' field")
- if msg['role'] not in ['system', 'user', 'assistant']:
- raise ValueError(f"Invalid role '{msg['role']}'. Must be 'system', 'user', or 'assistant'")
-
+ if msg["role"] not in ["system", "user", "assistant"]:
+ raise ValueError(
+ f"Invalid role '{msg['role']}'. Must be 'system', 'user', or 'assistant'"
+ )
+
  return messages

- @pydantic.field_validator('model')
+ @pydantic.field_validator("model")
  def validate_model(cls, model):
  if not model:
  raise ValueError("Model cannot be empty")
  if model not in ACCEPTABLE_MODELS:
  raise ValueError(f"Model {model} is not in the list of supported models.")
  return model
-
- @pydantic.field_validator('response_format', mode='before')
+
+ @pydantic.field_validator("response_format", mode="before")
  def validate_response_format(cls, response_format):
  if response_format is not None:
  if not isinstance(response_format, (dict, pydantic.BaseModel)):
- raise TypeError("Response format must be a dictionary or pydantic model")
+ raise TypeError(
+ "Response format must be a dictionary or pydantic model"
+ )
  # Optional: Add additional validation for required fields if needed
  # For example, checking for 'type': 'json' in OpenAI's format
  return response_format

- os.environ['LITELLM_LOG'] = 'DEBUG'
+
+ os.environ["LITELLM_LOG"] = "DEBUG"

  load_dotenv()

+
  def read_file(file_path: str) -> str:
- with open(file_path, "r", encoding='utf-8') as file:
+ with open(file_path, "r", encoding="utf-8") as file:
  return file.read()

+
  def validate_api_key(judgment_api_key: str):
  """
  Validates that the user api key is valid
@@ -109,66 +123,67 @@ def validate_api_key(judgment_api_key: str):
  "Authorization": f"Bearer {judgment_api_key}",
  },
  json={}, # Empty body now
- verify=True
+ verify=True,
  )
  if response.status_code == 200:
  return True, response.json()
  else:
  return False, response.json().get("detail", "Error validating API key")

- def fetch_together_api_response(model: str, messages: List[Mapping], response_format: pydantic.BaseModel = None) -> str:
+
+ def fetch_together_api_response(
+ model: str, messages: List[Mapping], response_format: pydantic.BaseModel = None
+ ) -> str:
  """
  Fetches a single response from the Together API for a given model and messages.
  """
  # Validate request
  if messages is None or messages == []:
  raise ValueError("Messages cannot be empty")
-
+
  request = ChatCompletionRequest(
- model=model,
- messages=messages,
- response_format=response_format
+ model=model, messages=messages, response_format=response_format
  )
-
+
  debug(f"Calling Together API with model: {request.model}")
  debug(f"Messages: {request.messages}")
-
+
  if request.response_format is not None:
  debug(f"Using response format: {request.response_format}")
  response = together_client.chat.completions.create(
  model=request.model,
  messages=request.messages,
- response_format=request.response_format
+ response_format=request.response_format,
  )
  else:
  response = together_client.chat.completions.create(
  model=request.model,
  messages=request.messages,
  )
-
+
  debug(f"Received response: {response.choices[0].message.content[:100]}...")
  return response.choices[0].message.content


- async def afetch_together_api_response(model: str, messages: List[Mapping], response_format: pydantic.BaseModel = None) -> str:
+ async def afetch_together_api_response(
+ model: str, messages: List[Mapping], response_format: pydantic.BaseModel = None
+ ) -> str:
  """
  ASYNCHRONOUSLY Fetches a single response from the Together API for a given model and messages.
  """
  request = ChatCompletionRequest(
- model=model,
- messages=messages,
- response_format=response_format
+ model=model, messages=messages, response_format=response_format
  )
-
+
  debug(f"Calling Together API with model: {request.model}")
  debug(f"Messages: {request.messages}")
-
+
  if request.response_format is not None:
  debug(f"Using response format: {request.response_format}")
  response = await async_together_client.chat.completions.create(
  model=request.model,
  messages=request.messages,
- response_format=request.response_format
+ response_format=request.response_format,
  )
  else:
  response = await async_together_client.chat.completions.create(
@@ -178,7 +193,11 @@ async def afetch_together_api_response(model: str, messages: List[Mapping], resp
  return response.choices[0].message.content


- def query_together_api_multiple_calls(models: List[str], messages: List[List[Mapping]], response_formats: List[pydantic.BaseModel] = None) -> List[str]:
+ def query_together_api_multiple_calls(
+ models: List[str],
+ messages: List[List[Mapping]],
+ response_formats: List[pydantic.BaseModel] | None = None,
+ ) -> List[Union[str, None]]:
  """
  Queries the Together API for multiple calls in parallel

@@ -197,25 +216,35 @@ def query_together_api_multiple_calls(models: List[str], messages: List[List[Map
  # Validate all models are supported
  for model in models:
  if model not in ACCEPTABLE_MODELS:
- raise ValueError(f"Model {model} is not in the list of supported models: {ACCEPTABLE_MODELS}.")
+ raise ValueError(
+ f"Model {model} is not in the list of supported models: {ACCEPTABLE_MODELS}."
+ )

  # Validate input lengths match
  if response_formats is None:
  response_formats = [None] * len(models)
  if not (len(models) == len(messages) == len(response_formats)):
- raise ValueError("Number of models, messages, and response formats must be the same")
+ raise ValueError(
+ "Number of models, messages, and response formats must be the same"
+ )

  # Validate message format
  validate_batched_chat_messages(messages)

- num_workers = int(os.getenv('NUM_WORKER_THREADS', MAX_WORKER_THREADS))
+ num_workers = int(os.getenv("NUM_WORKER_THREADS", MAX_WORKER_THREADS))
  # Initialize results to maintain ordered outputs
- out = [None] * len(messages)
+ out: List[str | None] = [None] * len(messages)
  with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) as executor:
  # Submit all queries to together API with index, gets back the response content
- futures = {executor.submit(fetch_together_api_response, model, message, response_format): idx \
- for idx, (model, message, response_format) in enumerate(zip(models, messages, response_formats))}
-
+ futures = {
+ executor.submit(
+ fetch_together_api_response, model, message, response_format
+ ): idx
+ for idx, (model, message, response_format) in enumerate(
+ zip(models, messages, response_formats)
+ )
+ }
+
  # Collect results as they complete -- result is response content
  for future in concurrent.futures.as_completed(futures):
  idx = futures[future]
@@ -223,11 +252,15 @@ def query_together_api_multiple_calls(models: List[str], messages: List[List[Map
  out[idx] = future.result()
  except Exception as e:
  error(f"Error in parallel call {idx}: {str(e)}")
- out[idx] = None
+ out[idx] = None
  return out


- async def aquery_together_api_multiple_calls(models: List[str], messages: List[List[Mapping]], response_formats: List[pydantic.BaseModel] = None) -> List[str]:
+ async def aquery_together_api_multiple_calls(
+ models: List[str],
+ messages: List[List[Mapping]],
+ response_formats: List[pydantic.BaseModel] | None = None,
+ ) -> List[Union[str, None]]:
  """
  Queries the Together API for multiple calls in parallel

@@ -246,57 +279,65 @@ async def aquery_together_api_multiple_calls(models: List[str], messages: List[L
  # Validate all models are supported
  for model in models:
  if model not in ACCEPTABLE_MODELS:
- raise ValueError(f"Model {model} is not in the list of supported models: {ACCEPTABLE_MODELS}.")
+ raise ValueError(
+ f"Model {model} is not in the list of supported models: {ACCEPTABLE_MODELS}."
+ )

  # Validate input lengths match
  if response_formats is None:
  response_formats = [None] * len(models)
  if not (len(models) == len(messages) == len(response_formats)):
- raise ValueError("Number of models, messages, and response formats must be the same")
+ raise ValueError(
+ "Number of models, messages, and response formats must be the same"
+ )

  # Validate message format
  validate_batched_chat_messages(messages)

  debug(f"Starting parallel Together API calls for {len(messages)} messages")
- out = [None] * len(messages)
-
+ out: List[Union[str, None]] = [None] * len(messages)
+
  async def fetch_and_store(idx, model, message, response_format):
  try:
  debug(f"Processing call {idx} with model {model}")
- out[idx] = await afetch_together_api_response(model, message, response_format)
+ out[idx] = await afetch_together_api_response(
+ model, message, response_format
+ )
  except Exception as e:
  error(f"Error in parallel call {idx}: {str(e)}")
  out[idx] = None

  tasks = [
  fetch_and_store(idx, model, message, response_format)
- for idx, (model, message, response_format) in enumerate(zip(models, messages, response_formats))
+ for idx, (model, message, response_format) in enumerate(
+ zip(models, messages, response_formats)
+ )
  ]
-
+
  await asyncio.gather(*tasks)
  debug(f"Completed {len(messages)} parallel calls")
  return out


- def fetch_litellm_api_response(model: str, messages: List[Mapping], response_format: pydantic.BaseModel = None) -> str:
+ def fetch_litellm_api_response(
+ model: str, messages: List[Mapping], response_format: pydantic.BaseModel = None
+ ) -> str:
  """
  Fetches a single response from the Litellm API for a given model and messages.
  """
  request = ChatCompletionRequest(
- model=model,
- messages=messages,
- response_format=response_format
+ model=model, messages=messages, response_format=response_format
  )
-
+
  debug(f"Calling LiteLLM API with model: {request.model}")
  debug(f"Messages: {request.messages}")
-
+
  if request.response_format is not None:
  debug(f"Using response format: {request.response_format}")
  response = litellm.completion(
  model=request.model,
  messages=request.messages,
- response_format=request.response_format
+ response_format=request.response_format,
  )
  else:
  response = litellm.completion(
@@ -306,23 +347,29 @@ def fetch_litellm_api_response(model: str, messages: List[Mapping], response_for
  return response.choices[0].message.content


- def fetch_custom_litellm_api_response(custom_model_parameters: CustomModelParameters, messages: List[Mapping], response_format: pydantic.BaseModel = None) -> str:
+ def fetch_custom_litellm_api_response(
+ custom_model_parameters: CustomModelParameters,
+ messages: List[Mapping],
+ response_format: pydantic.BaseModel = None,
+ ) -> str:
  if messages is None or messages == []:
  raise ValueError("Messages cannot be empty")
-
+
  if custom_model_parameters is None:
  raise ValueError("Custom model parameters cannot be empty")
-
+
  if not isinstance(custom_model_parameters, CustomModelParameters):
- raise ValueError("Custom model parameters must be a CustomModelParameters object")
-
+ raise ValueError(
+ "Custom model parameters must be a CustomModelParameters object"
+ )
+
  if response_format is not None:
  response = litellm.completion(
  model=custom_model_parameters.model_name,
  messages=messages,
  api_key=custom_model_parameters.secret_key,
  base_url=custom_model_parameters.litellm_base_url,
- response_format=response_format
+ response_format=response_format,
  )
  else:
  response = litellm.completion(
@@ -334,53 +381,61 @@ def fetch_custom_litellm_api_response(custom_model_parame
  return response.choices[0].message.content


- async def afetch_litellm_api_response(model: str, messages: List[Mapping], response_format: pydantic.BaseModel = None) -> str:
+ async def afetch_litellm_api_response(
+ model: str, messages: List[Mapping], response_format: pydantic.BaseModel = None
+ ) -> str:
  """
  ASYNCHRONOUSLY Fetches a single response from the Litellm API for a given model and messages.
  """
  if messages is None or messages == []:
  raise ValueError("Messages cannot be empty")
-
+
  # Add validation
  validate_chat_messages(messages)
-
+
  if model not in ACCEPTABLE_MODELS:
- raise ValueError(f"Model {model} is not in the list of supported models: {ACCEPTABLE_MODELS}.")
-
+ raise ValueError(
+ f"Model {model} is not in the list of supported models: {ACCEPTABLE_MODELS}."
+ )
+
  if response_format is not None:
  response = await litellm.acompletion(
- model=model,
- messages=messages,
- response_format=response_format
+ model=model, messages=messages, response_format=response_format
  )
  else:
  response = await litellm.acompletion(
  model=model,
- messages=messages,
+ messages=messages,
  )
  return response.choices[0].message.content


- async def afetch_custom_litellm_api_response(custom_model_parameters: CustomModelParameters, messages: List[Mapping], response_format: pydantic.BaseModel = None) -> str:
+ async def afetch_custom_litellm_api_response(
+ custom_model_parameters: CustomModelParameters,
+ messages: List[Mapping],
+ response_format: pydantic.BaseModel = None,
+ ) -> str:
  """
  ASYNCHRONOUSLY Fetches a single response from the Litellm API for a given model and messages.
  """
  if messages is None or messages == []:
  raise ValueError("Messages cannot be empty")
-
+
  if custom_model_parameters is None:
  raise ValueError("Custom model parameters cannot be empty")
-
+
  if not isinstance(custom_model_parameters, CustomModelParameters):
- raise ValueError("Custom model parameters must be a CustomModelParameters object")
-
+ raise ValueError(
+ "Custom model parameters must be a CustomModelParameters object"
+ )
+
  if response_format is not None:
  response = await litellm.acompletion(
  model=custom_model_parameters.model_name,
  messages=messages,
  api_key=custom_model_parameters.secret_key,
  base_url=custom_model_parameters.litellm_base_url,
- response_format=response_format
+ response_format=response_format,
  )
  else:
  response = await litellm.acompletion(
@@ -392,26 +447,36 @@ async def afetch_custom_litellm_api_response(custom_mode
  return response.choices[0].message.content


- def query_litellm_api_multiple_calls(models: List[str], messages: List[Mapping], response_formats: List[pydantic.BaseModel] = None) -> List[str]:
+ def query_litellm_api_multiple_calls(
+ models: List[str],
+ messages: List[List[Mapping]],
+ response_formats: List[pydantic.BaseModel] | None = None,
+ ) -> List[Union[str, None]]:
  """
  Queries the Litellm API for multiple calls in parallel

  Args:
  models (List[str]): List of models to query
- messages (List[Mapping]): List of messages to query
+ messages (List[List[Mapping]]): List of messages to query
  response_formats (List[pydantic.BaseModel], optional): A list of the format of the response if JSON forcing. Defaults to None.

  Returns:
  List[str]: Litellm responses for each model and message pair in order. Any exceptions in the thread call result in a None.
  """
- num_workers = int(os.getenv('NUM_WORKER_THREADS', MAX_WORKER_THREADS))
+ num_workers = int(os.getenv("NUM_WORKER_THREADS", MAX_WORKER_THREADS))
  # Initialize results to maintain ordered outputs
- out = [None] * len(messages)
+ out: List[Union[str, None]] = [None] * len(messages)
  with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) as executor:
  # Submit all queries to Litellm API with index, gets back the response content
- futures = {executor.submit(fetch_litellm_api_response, model, message, response_format): idx \
- for idx, (model, message, response_format) in enumerate(zip(models, messages, response_formats))}
-
+ futures = {
+ executor.submit(
+ fetch_litellm_api_response, model, message, response_format
+ ): idx
+ for idx, (model, message, response_format) in enumerate(
+ zip(models, messages, response_formats or [None] * len(messages))
+ )
+ }
+
  # Collect results as they complete -- result is response content
  for future in concurrent.futures.as_completed(futures):
  idx = futures[future]
@@ -419,37 +484,45 @@ def query_litellm_api_multiple_calls(models: List[str], messages: List[Mapping],
  out[idx] = future.result()
  except Exception as e:
  error(f"Error in parallel call {idx}: {str(e)}")
- out[idx] = None
+ out[idx] = None
  return out


- async def aquery_litellm_api_multiple_calls(models: List[str], messages: List[Mapping], response_formats: List[pydantic.BaseModel] = None) -> List[str]:
+ async def aquery_litellm_api_multiple_calls(
+ models: List[str],
+ messages: List[List[Mapping]],
+ response_formats: List[pydantic.BaseModel] | None = None,
+ ) -> List[Union[str, None]]:
  """
  Queries the Litellm API for multiple calls in parallel

  Args:
  models (List[str]): List of models to query
- messages (List[Mapping]): List of messages to query
+ messages (List[List[Mapping]]): List of messages to query
  response_formats (List[pydantic.BaseModel], optional): A list of the format of the response if JSON forcing. Defaults to None.
-
+
  Returns:
  List[str]: Litellm responses for each model and message pair in order. Any exceptions in the thread call result in a None.
  """
  # Initialize results to maintain ordered outputs
- out = [None] * len(messages)
-
+ out: List[Union[str, None]] = [None] * len(messages)
+
  async def fetch_and_store(idx, model, message, response_format):
  try:
- out[idx] = await afetch_litellm_api_response(model, message, response_format)
+ out[idx] = await afetch_litellm_api_response(
+ model, message, response_format
+ )
  except Exception as e:
  error(f"Error in parallel call {idx}: {str(e)}")
  out[idx] = None

  tasks = [
  fetch_and_store(idx, model, message, response_format)
- for idx, (model, message, response_format) in enumerate(zip(models, messages, response_formats))
+ for idx, (model, message, response_format) in enumerate(
+ zip(models, messages, response_formats or [None] * len(messages))
+ )
  ]
-
+
  await asyncio.gather(*tasks)
  return out

@@ -458,56 +531,75 @@ def validate_chat_messages(messages, batched: bool = False):
  """Validate chat message format before API call"""
  if not isinstance(messages, list):
  raise TypeError("Messages must be a list")
-
+
  for msg in messages:
  if not isinstance(msg, dict):
  if batched and not isinstance(msg, list):
  raise TypeError("Each message must be a list")
  elif not batched:
  raise TypeError("Message must be a dictionary")
- if 'role' not in msg:
+ if "role" not in msg:
  raise ValueError("Message missing required 'role' field")
- if 'content' not in msg:
+ if "content" not in msg:
  raise ValueError("Message missing required 'content' field")
- if msg['role'] not in ['system', 'user', 'assistant']:
- raise ValueError(f"Invalid role '{msg['role']}'. Must be 'system', 'user', or 'assistant'")
+ if msg["role"] not in ["system", "user", "assistant"]:
+ raise ValueError(
+ f"Invalid role '{msg['role']}'. Must be 'system', 'user', or 'assistant'"
+ )
+

- def validate_batched_chat_messages(messages: List[List[Mapping]]):
+ def validate_batched_chat_messages(messages):
  """
  Validate format of batched chat messages before API call
-
+
  Args:
  messages (List[List[Mapping]]): List of message lists, where each inner list contains
  message dictionaries with 'role' and 'content' fields
-
+
  Raises:
  TypeError: If messages format is invalid
  ValueError: If message content is invalid
  """
  if not isinstance(messages, list):
  raise TypeError("Batched messages must be a list")
-
+
  if not messages:
  raise ValueError("Batched messages cannot be empty")
-
+
  for message_list in messages:
  if not isinstance(message_list, list):
  raise TypeError("Each batch item must be a list of messages")
-
+
  # Validate individual messages using existing function
  validate_chat_messages(message_list)

- def get_chat_completion(model_type: str,
- messages : Union[List[Mapping], List[List[Mapping]]],
- response_format: pydantic.BaseModel = None,
- batched: bool = False
- ) -> Union[str, List[str]]:
+
+ def is_batched_messages(
+ messages: Union[List[Mapping], List[List[Mapping]]],
+ ) -> TypeGuard[List[List[Mapping]]]:
+ return isinstance(messages, list) and all(isinstance(msg, list) for msg in messages)
+
+
+ def is_simple_messages(
+ messages: Union[List[Mapping], List[List[Mapping]]],
+ ) -> TypeGuard[List[Mapping]]:
+ return isinstance(messages, list) and all(
+ not isinstance(msg, list) for msg in messages
+ )
+
+
+ def get_chat_completion(
+ model_type: str,
+ messages: Union[List[Mapping], List[List[Mapping]]],
+ response_format: pydantic.BaseModel = None,
+ batched: bool = False,
+ ) -> Union[str, List[str | None]]:
  """
  Generates chat completions using a single model and potentially several messages. Supports closed-source and OSS models.

  Parameters:
  - model_type (str): The type of model to use for generating completions.
- - messages (Union[List[Mapping], List[List[Mapping]]]): The messages to be used for generating completions.
+ - messages (Union[List[Mapping], List[List[Mapping]]]): The messages to be used for generating completions.
  If batched is True, this should be a list of lists of mappings.
  - response_format (pydantic.BaseModel, optional): The format of the response. Defaults to None.
  - batched (bool, optional): Whether to process messages in batch mode. Defaults to False.
@@ -516,50 +608,71 @@ def get_chat_completion(model_type: str,
  Raises:
  - ValueError: If requested model is not supported by Litellm or TogetherAI.
  """
-
+
  # Check for empty messages list
  if not messages or messages == []:
  raise ValueError("Messages cannot be empty")
-
+
  # Add validation
  if batched:
  validate_batched_chat_messages(messages)
  else:
  validate_chat_messages(messages)
-
- if batched and model_type in TOGETHER_SUPPORTED_MODELS:
- return query_together_api_multiple_calls(models=[model_type] * len(messages),
- messages=messages,
- response_formats=[response_format] * len(messages))
- elif batched and model_type in LITELLM_SUPPORTED_MODELS:
- return query_litellm_api_multiple_calls(models=[model_type] * len(messages),
- messages=messages,
- response_format=response_format)
- elif not batched and model_type in TOGETHER_SUPPORTED_MODELS:
- return fetch_together_api_response(model=model_type,
- messages=messages,
- response_format=response_format)
- elif not batched and model_type in LITELLM_SUPPORTED_MODELS:
- return fetch_litellm_api_response(model=model_type,
- messages=messages,
- response_format=response_format)
-
-
-
- raise ValueError(f"Model {model_type} is not supported by Litellm or TogetherAI for chat completions. Please check the model name and try again.")
-
-
- async def aget_chat_completion(model_type: str,
- messages : Union[List[Mapping], List[List[Mapping]]],
- response_format: pydantic.BaseModel = None,
- batched: bool = False
- ) -> Union[str, List[str]]:
+
+ if (
+ batched
+ and is_batched_messages(messages)
+ and model_type in TOGETHER_SUPPORTED_MODELS
+ ):
+ return query_together_api_multiple_calls(
+ models=[model_type] * len(messages),
+ messages=messages,
+ response_formats=[response_format] * len(messages),
+ )
+ elif (
+ batched
+ and is_batched_messages(messages)
+ and model_type in LITELLM_SUPPORTED_MODELS
+ ):
+ return query_litellm_api_multiple_calls(
+ models=[model_type] * len(messages),
+ messages=messages,
+ response_formats=[response_format] * len(messages),
+ )
+ elif (
+ not batched
+ and is_simple_messages(messages)
+ and model_type in TOGETHER_SUPPORTED_MODELS
+ ):
+ return fetch_together_api_response(
+ model=model_type, messages=messages, response_format=response_format
+ )
+ elif (
+ not batched
+ and is_simple_messages(messages)
+ and model_type in LITELLM_SUPPORTED_MODELS
+ ):
+ return fetch_litellm_api_response(
+ model=model_type, messages=messages, response_format=response_format
+ )
+
+ raise ValueError(
+ f"Model {model_type} is not supported by Litellm or TogetherAI for chat completions. Please check the model name and try again."
+ )
+
+
+ async def aget_chat_completion(
+ model_type: str,
+ messages: Union[List[Mapping], List[List[Mapping]]],
+ response_format: pydantic.BaseModel = None,
+ batched: bool = False,
+ ) -> Union[str, List[str | None]]:
  """
  ASYNCHRONOUSLY generates chat completions using a single model and potentially several messages. Supports closed-source and OSS models.

  Parameters:
  - model_type (str): The type of model to use for generating completions.
- - messages (Union[List[Mapping], List[List[Mapping]]]): The messages to be used for generating completions.
+ - messages (Union[List[Mapping], List[List[Mapping]]]): The messages to be used for generating completions.
  If batched is True, this should be a list of lists of mappings.
  - response_format (pydantic.BaseModel, optional): The format of the response. Defaults to None.
  - batched (bool, optional): Whether to process messages in batch mode. Defaults to False.
@@ -569,38 +682,64 @@ async def aget_chat_completion(model_type: str,
  - ValueError: If requested model is not supported by Litellm or TogetherAI.
  """
  debug(f"Starting chat completion for model {model_type}, batched={batched}")
-
+
  if batched:
  validate_batched_chat_messages(messages)
  else:
  validate_chat_messages(messages)
-
- if batched and model_type in TOGETHER_SUPPORTED_MODELS:
+
+ if (
+ batched
+ and is_batched_messages(messages)
+ and model_type in TOGETHER_SUPPORTED_MODELS
+ ):
  debug("Using batched Together API call")
- return await aquery_together_api_multiple_calls(models=[model_type] * len(messages),
- messages=messages,
- response_formats=[response_format] * len(messages))
- elif batched and model_type in LITELLM_SUPPORTED_MODELS:
+ return await aquery_together_api_multiple_calls(
+ models=[model_type] * len(messages),
+ messages=messages,
+ response_formats=[response_format] * len(messages),
+ )
+ elif (
+ batched
+ and is_batched_messages(messages)
+ and model_type in LITELLM_SUPPORTED_MODELS
+ ):
  debug("Using batched LiteLLM API call")
- return await aquery_litellm_api_multiple_calls(models=[model_type] * len(messages),
- messages=messages,
- response_formats=[response_format] * len(messages))
- elif not batched and model_type in TOGETHER_SUPPORTED_MODELS:
+ return await aquery_litellm_api_multiple_calls(
+ models=[model_type] * len(messages),
+ messages=messages,
+ response_formats=[response_format] * len(messages),
+ )
+ elif (
+ not batched
+ and is_simple_messages(messages)
+ and model_type in TOGETHER_SUPPORTED_MODELS
+ ):
  debug("Using single Together API call")
- return await afetch_together_api_response(model=model_type,
- messages=messages,
- response_format=response_format)
- elif not batched and model_type in LITELLM_SUPPORTED_MODELS:
+ return await afetch_together_api_response(
+ model=model_type, messages=messages, response_format=response_format
+ )
+ elif (
+ not batched
+ and is_simple_messages(messages)
+ and model_type in LITELLM_SUPPORTED_MODELS
+ ):
  debug("Using single LiteLLM API call")
- return await afetch_litellm_api_response(model=model_type,
- messages=messages,
- response_format=response_format)
-
+ return await afetch_litellm_api_response(
+ model=model_type, messages=messages, response_format=response_format
+ )
+
  error(f"Model {model_type} not supported by either API")
- raise ValueError(f"Model {model_type} is not supported by Litellm or TogetherAI for chat completions. Please check the model name and try again.")
+ raise ValueError(
+ f"Model {model_type} is not supported by Litellm or TogetherAI for chat completions. Please check the model name and try again."
+ )


- def get_completion_multiple_models(models: List[str], messages: List[List[Mapping]], response_formats: List[pydantic.BaseModel] = None) -> List[str]:
+ def get_completion_multiple_models(
+ models: List[str],
+ messages: List[List[Mapping]],
+ response_formats: List[pydantic.BaseModel] | None = None,
+ ) -> List[str | None]:
  """
  Retrieves completions for a single prompt from multiple models in parallel. Supports closed-source and OSS models.

@@ -608,28 +747,32 @@ def get_completion_multiple_models(models: List[str], messages: List[List[Mappin
  models (List[str]): List of models to query
  messages (List[List[Mapping]]): List of messages to query. Each inner object corresponds to a single prompt.
  response_formats (List[pydantic.BaseModel], optional): A list of the format of the response if JSON forcing. Defaults to None.
-
+
  Returns:
  List[str]: List of completions from the models in the order of the input models
  Raises:
  ValueError: If a model is not supported by Litellm or Together
  """
  debug(f"Starting multiple model completion for {len(models)} models")
-
+
  if models is None or models == []:
  raise ValueError("Models list cannot be empty")
-
+
  validate_batched_chat_messages(messages)
-
+
  if len(models) != len(messages):
  error(f"Model/message count mismatch: {len(models)} vs {len(messages)}")
- raise ValueError(f"Number of models and messages must be the same: {len(models)} != {len(messages)}")
+ raise ValueError(
+ f"Number of models and messages must be the same: {len(models)} != {len(messages)}"
+ )
  if response_formats is None:
  response_formats = [None] * len(models)
  # Partition the model requests into TogetherAI and Litellm models, but keep the ordering saved
  together_calls, litellm_calls = {}, {} # index -> model, message, response_format
  together_responses, litellm_responses = [], []
- for idx, (model, message, r_format) in enumerate(zip(models, messages, response_formats)):
+ for idx, (model, message, r_format) in enumerate(
+ zip(models, messages, response_formats)
+ ):
  if model in TOGETHER_SUPPORTED_MODELS:
  debug(f"Model {model} routed to Together API")
  together_calls[idx] = (model, message, r_format)
@@ -638,39 +781,49 @@ def get_completion_multiple_models(models: List[str], messages: List[List[Mappin
  litellm_calls[idx] = (model, message, r_format)
  else:
  error(f"Model {model} not supported by either API")
- raise ValueError(f"Model {model} is not supported by Litellm or TogetherAI for chat completions. Please check the model name and try again.")
-
+ raise ValueError(
+ f"Model {model} is not supported by Litellm or TogetherAI for chat completions. Please check the model name and try again."
+ )
+
  # Add validation before processing
  for msg_list in messages:
  validate_chat_messages(msg_list)
-
+
  # Get the responses from the TogetherAI models
  # List of responses from the TogetherAI models in order of the together_calls dict
  if together_calls:
  debug(f"Executing {len(together_calls)} Together API calls")
- together_responses = query_together_api_multiple_calls(models=[model for model, _, _ in together_calls.values()],
- messages=[message for _, message, _ in together_calls.values()],
- response_formats=[format for _, _, format in together_calls.values()])
-
+ together_responses = query_together_api_multiple_calls(
+ models=[model for model, _, _ in together_calls.values()],
+ messages=[message for _, message, _ in together_calls.values()],
+ response_formats=[format for _, _, format in together_calls.values()],
+ )
+
  # Get the responses from the Litellm models
  if litellm_calls:
  debug(f"Executing {len(litellm_calls)} LiteLLM API calls")
- litellm_responses = query_litellm_api_multiple_calls(models=[model for model, _, _ in litellm_calls.values()],
- messages=[message for _, message, _ in litellm_calls.values()],
- response_formats=[format for _, _, format in litellm_calls.values()])
+ litellm_responses = query_litellm_api_multiple_calls(
+ models=[model for model, _, _ in litellm_calls.values()],
+ messages=[message for _, message, _ in litellm_calls.values()],
+ response_formats=[format for _, _, format in litellm_calls.values()],
+ )

  # Merge the responses in the order of the original models
  debug("Merging responses")
- out = [None] * len(models)
+ out: List[Union[str, None]] = [None] * len(models)
  for idx, (model, message, r_format) in together_calls.items():
  out[idx] = together_responses.pop(0)
  for idx, (model, message, r_format) in litellm_calls.items():
  out[idx] = litellm_responses.pop(0)
  debug("Multiple model completion finished")
- return out
+ return out


- async def aget_completion_multiple_models(models: List[str], messages: List[List[Mapping]], response_formats: List[pydantic.BaseModel] = None) -> List[str]:
+ async def aget_completion_multiple_models(
+ models: List[str],
+ messages: List[List[Mapping]],
+ response_formats: List[pydantic.BaseModel] | None = None,
+ ) -> List[str | None]:
  """
  ASYNCHRONOUSLY retrieves completions for a single prompt from multiple models in parallel. Supports closed-source and OSS models.

@@ -678,7 +831,7 @@ async def aget_completion_multiple_models(models: List[str], messages: List[List
  models (List[str]): List of models to query
  messages (List[List[Mapping]]): List of messages to query. Each inner object corresponds to a single prompt.
  response_formats (List[pydantic.BaseModel], optional): A list of the format of the response if JSON forcing. Defaults to None.
-
+
  Returns:
  List[str]: List of completions from the models in the order of the input models
  Raises:
@@ -686,48 +839,54 @@ async def aget_completion_multiple_models(models: List[str], messages: List[List
  """
  if models is None or models == []:
  raise ValueError("Models list cannot be empty")
-
+
  if len(models) != len(messages):
- raise ValueError(f"Number of models and messages must be the same: {len(models)} != {len(messages)}")
+ raise ValueError(
+ f"Number of models and messages must be the same: {len(models)} != {len(messages)}"
+ )
  if response_formats is None:
  response_formats = [None] * len(models)

  validate_batched_chat_messages(messages)
-
+
  # Partition the model requests into TogetherAI and Litellm models, but keep the ordering saved
  together_calls, litellm_calls = {}, {} # index -> model, message, response_format
  together_responses, litellm_responses = [], []
- for idx, (model, message, r_format) in enumerate(zip(models, messages, response_formats)):
+ for idx, (model, message, r_format) in enumerate(
+ zip(models, messages, response_formats)
+ ):
  if model in TOGETHER_SUPPORTED_MODELS:
  together_calls[idx] = (model, message, r_format)
  elif model in LITELLM_SUPPORTED_MODELS:
  litellm_calls[idx] = (model, message, r_format)
  else:
- raise ValueError(f"Model {model} is not supported by Litellm or TogetherAI for chat completions. Please check the model name and try again.")
-
+ raise ValueError(
+ f"Model {model} is not supported by Litellm or TogetherAI for chat completions. Please check the model name and try again."
+ )
+
  # Add validation before processing
  for msg_list in messages:
  validate_chat_messages(msg_list)
-
+
  # Get the responses from the TogetherAI models
  # List of responses from the TogetherAI models in order of the together_calls dict
  if together_calls:
  together_responses = await aquery_together_api_multiple_calls(
- models=[model for model, _, _ in together_calls.values()],
- messages=[message for _, message, _ in together_calls.values()],
- response_formats=[format for _, _, format in together_calls.values()]
+ models=[model for model, _, _ in together_calls.values()],
+ messages=[message for _, message, _ in together_calls.values()],
+ response_formats=[format for _, _, format in together_calls.values()],
  )
-
+
  # Get the responses from the Litellm models
  if litellm_calls:
  litellm_responses = await aquery_litellm_api_multiple_calls(
  models=[model for model, _, _ in litellm_calls.values()],
  messages=[message for _, message, _ in litellm_calls.values()],
- response_formats=[format for _, _, format in litellm_calls.values()]
+ response_formats=[format for _, _, format in litellm_calls.values()],
  )

  # Merge the responses in the order of the original models
- out = [None] * len(models)
+ out: List[Union[str, None]] = [None] * len(models)
  for idx, (model, message, r_format) in together_calls.items():
  out[idx] = together_responses.pop(0)
  for idx, (model, message, r_format) in litellm_calls.items():
@@ -736,53 +895,66 @@ async def aget_completion_multiple_models(models: List[str], messages: List[List


  if __name__ == "__main__":
-
- # Batched
- pprint.pprint(get_chat_completion(
- model_type="LLAMA3_405B_INSTRUCT_TURBO",
- messages=[
- [
- {"role": "system", "content": "You are a helpful assistant."},
- {"role": "user", "content": "What is the capital of France?"},
- ],
- [
- {"role": "system", "content": "You are a helpful assistant."},
- {"role": "user", "content": "What is the capital of Japan?"},
- ]
+ batched_messages: List[List[Mapping]] = [
+ [
+ {"role": "system", "content": "You are a helpful assistant."},
+ {"role": "user", "content": "What is the capital of France?"},
  ],
- batched=True
- ))
+ [
+ {"role": "system", "content": "You are a helpful assistant."},
+ {"role": "user", "content": "What is the capital of Japan?"},
+ ],
+ ]

- # Non batched
- pprint.pprint(get_chat_completion(
- model_type="LLAMA3_8B_INSTRUCT_TURBO",
- messages=[
+ non_batched_messages: List[Mapping] = [
+ {"role": "system", "content": "You are a helpful assistant."},
+ {"role": "user", "content": "What is the capital of France?"},
+ ]
+
+ batched_messages_2: List[List[Mapping]] = [
+ [
+ {"role": "system", "content": "You are a helpful assistant."},
+ {"role": "user", "content": "What is the capital of China?"},
+ ],
+ [
  {"role": "system", "content": "You are a helpful assistant."},
  {"role": "user", "content": "What is the capital of France?"},
  ],
- batched=False
- ))
+ [
+ {"role": "system", "content": "You are a helpful assistant."},
+ {"role": "user", "content": "What is the capital of Japan?"},
+ ],
+ ]
+
+ # Batched
+ pprint.pprint(
+ get_chat_completion(
+ model_type="LLAMA3_405B_INSTRUCT_TURBO",
+ messages=batched_messages,
+ batched=True,
+ )
+ )
+
+ # Non batched
+ pprint.pprint(
+ get_chat_completion(
+ model_type="LLAMA3_8B_INSTRUCT_TURBO",
+ messages=non_batched_messages,
+ batched=False,
+ )
+ )

  # Batched single completion to multiple models
- pprint.pprint(get_completion_multiple_models(
- models=[
- "LLAMA3_70B_INSTRUCT_TURBO", "LLAMA3_405B_INSTRUCT_TURBO", "gpt-4.1-mini"
- ],
- messages=[
- [
- {"role": "system", "content": "You are a helpful assistant."},
- {"role": "user", "content": "What is the capital of China?"},
+ pprint.pprint(
+ get_completion_multiple_models(
+ models=[
+ "LLAMA3_70B_INSTRUCT_TURBO",
+ "LLAMA3_405B_INSTRUCT_TURBO",
+ "gpt-4.1-mini",
  ],
- [
- {"role": "system", "content": "You are a helpful assistant."},
- {"role": "user", "content": "What is the capital of France?"},
- ],
- [
- {"role": "system", "content": "You are a helpful assistant."},
- {"role": "user", "content": "What is the capital of Japan?"},
- ]
- ]
- ))
-
+ messages=batched_messages_2,
+ )
+ )
+
  ExcInfo: TypeAlias = tuple[type[BaseException], BaseException, TracebackType]
  OptExcInfo: TypeAlias = ExcInfo | tuple[None, None, None]
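A note on the TypeGuard helpers introduced in this diff: is_batched_messages and is_simple_messages let a static type checker narrow Union[List[Mapping], List[List[Mapping]]] to the concrete shape before get_chat_completion dispatches to the batched or single-call path. The standalone sketch below illustrates the same pattern outside the judgeval internals; the function names and message payloads here are placeholders, not part of the package.

# Standalone illustration of the TypeGuard narrowing pattern used by
# get_chat_completion above; names and payloads are placeholders.
from typing import List, Mapping, TypeGuard, Union

Message = Mapping[str, str]


def is_batched(
    msgs: Union[List[Message], List[List[Message]]],
) -> TypeGuard[List[List[Message]]]:
    # True only when every element is itself a list of messages
    return all(isinstance(m, list) for m in msgs)


def count_turns(msgs: Union[List[Message], List[List[Message]]]) -> int:
    if is_batched(msgs):
        # Type checkers now treat msgs as List[List[Message]]
        return sum(len(conversation) for conversation in msgs)
    # ...and here as List[Message]
    return len(msgs)


single = [{"role": "user", "content": "hi"}]
batched = [single, [{"role": "user", "content": "hello"}]]
print(count_turns(single))   # 1
print(count_turns(batched))  # 2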