judgeval 0.7.1__py3-none-any.whl → 0.9.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (96)
  1. judgeval/__init__.py +139 -12
  2. judgeval/api/__init__.py +501 -0
  3. judgeval/api/api_types.py +344 -0
  4. judgeval/cli.py +2 -4
  5. judgeval/constants.py +10 -26
  6. judgeval/data/evaluation_run.py +49 -26
  7. judgeval/data/example.py +2 -2
  8. judgeval/data/judgment_types.py +266 -82
  9. judgeval/data/result.py +4 -5
  10. judgeval/data/scorer_data.py +4 -2
  11. judgeval/data/tool.py +2 -2
  12. judgeval/data/trace.py +7 -50
  13. judgeval/data/trace_run.py +7 -4
  14. judgeval/{dataset.py → dataset/__init__.py} +43 -28
  15. judgeval/env.py +67 -0
  16. judgeval/{run_evaluation.py → evaluation/__init__.py} +29 -95
  17. judgeval/exceptions.py +27 -0
  18. judgeval/integrations/langgraph/__init__.py +788 -0
  19. judgeval/judges/__init__.py +2 -2
  20. judgeval/judges/litellm_judge.py +75 -15
  21. judgeval/judges/together_judge.py +86 -18
  22. judgeval/judges/utils.py +7 -21
  23. judgeval/{common/logger.py → logger.py} +8 -6
  24. judgeval/scorers/__init__.py +0 -4
  25. judgeval/scorers/agent_scorer.py +3 -7
  26. judgeval/scorers/api_scorer.py +8 -13
  27. judgeval/scorers/base_scorer.py +52 -32
  28. judgeval/scorers/example_scorer.py +1 -3
  29. judgeval/scorers/judgeval_scorers/api_scorers/__init__.py +0 -14
  30. judgeval/scorers/judgeval_scorers/api_scorers/prompt_scorer.py +45 -20
  31. judgeval/scorers/judgeval_scorers/api_scorers/tool_dependency.py +2 -2
  32. judgeval/scorers/judgeval_scorers/api_scorers/tool_order.py +3 -3
  33. judgeval/scorers/score.py +21 -31
  34. judgeval/scorers/trace_api_scorer.py +5 -0
  35. judgeval/scorers/utils.py +1 -103
  36. judgeval/tracer/__init__.py +1075 -2
  37. judgeval/tracer/constants.py +1 -0
  38. judgeval/tracer/exporters/__init__.py +37 -0
  39. judgeval/tracer/exporters/s3.py +119 -0
  40. judgeval/tracer/exporters/store.py +43 -0
  41. judgeval/tracer/exporters/utils.py +32 -0
  42. judgeval/tracer/keys.py +67 -0
  43. judgeval/tracer/llm/__init__.py +1233 -0
  44. judgeval/{common/tracer → tracer/llm}/providers.py +5 -10
  45. judgeval/{local_eval_queue.py → tracer/local_eval_queue.py} +15 -10
  46. judgeval/tracer/managers.py +188 -0
  47. judgeval/tracer/processors/__init__.py +181 -0
  48. judgeval/tracer/utils.py +20 -0
  49. judgeval/trainer/__init__.py +5 -0
  50. judgeval/{common/trainer → trainer}/config.py +12 -9
  51. judgeval/{common/trainer → trainer}/console.py +2 -9
  52. judgeval/{common/trainer → trainer}/trainable_model.py +12 -7
  53. judgeval/{common/trainer → trainer}/trainer.py +119 -17
  54. judgeval/utils/async_utils.py +2 -3
  55. judgeval/utils/decorators.py +24 -0
  56. judgeval/utils/file_utils.py +37 -4
  57. judgeval/utils/guards.py +32 -0
  58. judgeval/utils/meta.py +14 -0
  59. judgeval/{common/api/json_encoder.py → utils/serialize.py} +7 -1
  60. judgeval/utils/testing.py +88 -0
  61. judgeval/utils/url.py +10 -0
  62. judgeval/{version_check.py → utils/version_check.py} +3 -3
  63. judgeval/version.py +5 -0
  64. judgeval/warnings.py +4 -0
  65. {judgeval-0.7.1.dist-info → judgeval-0.9.0.dist-info}/METADATA +12 -14
  66. judgeval-0.9.0.dist-info/RECORD +80 -0
  67. judgeval/clients.py +0 -35
  68. judgeval/common/__init__.py +0 -13
  69. judgeval/common/api/__init__.py +0 -3
  70. judgeval/common/api/api.py +0 -375
  71. judgeval/common/api/constants.py +0 -186
  72. judgeval/common/exceptions.py +0 -27
  73. judgeval/common/storage/__init__.py +0 -6
  74. judgeval/common/storage/s3_storage.py +0 -97
  75. judgeval/common/tracer/__init__.py +0 -31
  76. judgeval/common/tracer/constants.py +0 -22
  77. judgeval/common/tracer/core.py +0 -2427
  78. judgeval/common/tracer/otel_exporter.py +0 -108
  79. judgeval/common/tracer/otel_span_processor.py +0 -188
  80. judgeval/common/tracer/span_processor.py +0 -37
  81. judgeval/common/tracer/span_transformer.py +0 -207
  82. judgeval/common/tracer/trace_manager.py +0 -101
  83. judgeval/common/trainer/__init__.py +0 -5
  84. judgeval/common/utils.py +0 -948
  85. judgeval/integrations/langgraph.py +0 -844
  86. judgeval/judges/mixture_of_judges.py +0 -287
  87. judgeval/judgment_client.py +0 -267
  88. judgeval/rules.py +0 -521
  89. judgeval/scorers/judgeval_scorers/api_scorers/execution_order.py +0 -52
  90. judgeval/scorers/judgeval_scorers/api_scorers/hallucination.py +0 -28
  91. judgeval/utils/alerts.py +0 -93
  92. judgeval/utils/requests.py +0 -50
  93. judgeval-0.7.1.dist-info/RECORD +0 -82
  94. {judgeval-0.7.1.dist-info → judgeval-0.9.0.dist-info}/WHEEL +0 -0
  95. {judgeval-0.7.1.dist-info → judgeval-0.9.0.dist-info}/entry_points.txt +0 -0
  96. {judgeval-0.7.1.dist-info → judgeval-0.9.0.dist-info}/licenses/LICENSE.md +0 -0
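The renames above move most public modules out of the judgeval.common package and up to the top level. A minimal sketch (not taken from the diff itself) of how import paths map between 0.7.1 and 0.9.0, inferred only from the file renames listed above and assuming the moved modules keep their 0.7.1 public names:

    # 0.7.1 layout (old paths, left-hand side of the renames above)
    #   from judgeval.common.logger import judgeval_logger
    #   from judgeval.common.trainer import ...
    #   import judgeval.run_evaluation
    #   import judgeval.dataset            # was a single module, judgeval/dataset.py

    # 0.9.0 layout (new paths), assuming the public names themselves are unchanged
    from judgeval.logger import judgeval_logger  # judgeval/{common/logger.py -> logger.py}
    import judgeval.dataset                      # judgeval/{dataset.py -> dataset/__init__.py}
    import judgeval.evaluation                   # judgeval/{run_evaluation.py -> evaluation/__init__.py}
    import judgeval.trainer                      # judgeval/{common/trainer -> trainer}/
    import judgeval.exceptions                   # replaces the deleted judgeval/common/exceptions.py

Whether each symbol is re-exported under the same name in 0.9.0 is an assumption; only the module paths are taken from the rename list.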
judgeval/common/utils.py DELETED
@@ -1,948 +0,0 @@
- """
- This file contains utility functions used in repo scripts
-
- For API calling, we support:
- - parallelized model calls on the same prompt
- - batched model calls on different prompts
-
- NOTE: any function beginning with 'a', e.g. 'afetch_together_api_response', is an asynchronous function
- """
-
- # Standard library imports
- import asyncio
- import concurrent.futures
- import os
- from types import TracebackType
- from judgeval.common.api.constants import ROOT_API
- from judgeval.utils.requests import requests
- import pprint
- from typing import Any, Dict, List, Optional, TypeAlias, Union, TypeGuard
-
- # Third-party imports
- import litellm
- import pydantic
- from dotenv import load_dotenv
-
- # Local application/library-specific imports
- from judgeval.clients import async_together_client, together_client
- from judgeval.constants import (
-     ACCEPTABLE_MODELS,
-     MAX_WORKER_THREADS,
-     TOGETHER_SUPPORTED_MODELS,
-     LITELLM_SUPPORTED_MODELS,
- )
- from judgeval.common.logger import judgeval_logger
-
-
- class CustomModelParameters(pydantic.BaseModel):
-     model_name: str
-     secret_key: str
-     litellm_base_url: str
-
-     @pydantic.field_validator("model_name")
-     @classmethod
-     def validate_model_name(cls, v):
-         if not v:
-             raise ValueError("Model name cannot be empty")
-         return v
-
-     @pydantic.field_validator("secret_key")
-     @classmethod
-     def validate_secret_key(cls, v):
-         if not v:
-             raise ValueError("Secret key cannot be empty")
-         return v
-
-     @pydantic.field_validator("litellm_base_url")
-     @classmethod
-     def validate_litellm_base_url(cls, v):
-         if not v:
-             raise ValueError("Litellm base URL cannot be empty")
-         return v
-
-
- class ChatCompletionRequest(pydantic.BaseModel):
-     model: str
-     messages: List[Dict[str, str]]
-     response_format: Optional[Union[pydantic.BaseModel, Dict[str, Any]]] = None
-
-     @pydantic.field_validator("messages")
-     @classmethod
-     def validate_messages(cls, messages):
-         if not messages:
-             raise ValueError("Messages cannot be empty")
-
-         for msg in messages:
-             if not isinstance(msg, dict):
-                 raise TypeError("Message must be a dictionary")
-             if "role" not in msg:
-                 raise ValueError("Message missing required 'role' field")
-             if "content" not in msg:
-                 raise ValueError("Message missing required 'content' field")
-             if msg["role"] not in ["system", "user", "assistant"]:
-                 raise ValueError(
-                     f"Invalid role '{msg['role']}'. Must be 'system', 'user', or 'assistant'"
-                 )
-
-         return messages
-
-     @pydantic.field_validator("model")
-     @classmethod
-     def validate_model(cls, model):
-         if not model:
-             raise ValueError("Model cannot be empty")
-         if model not in ACCEPTABLE_MODELS:
-             raise ValueError(f"Model {model} is not in the list of supported models.")
-         return model
-
-     @pydantic.field_validator("response_format", mode="before")
-     @classmethod
-     def validate_response_format(cls, response_format):
-         if response_format is not None:
-             if not isinstance(response_format, (dict, pydantic.BaseModel)):
-                 raise TypeError(
-                     "Response format must be a dictionary or pydantic model"
-                 )
-             # Optional: Add additional validation for required fields if needed
-             # For example, checking for 'type': 'json' in OpenAI's format
-         return response_format
-
-
- os.environ["LITELLM_LOG"] = "DEBUG"
-
- load_dotenv()
-
-
- def read_file(file_path: str) -> str:
-     with open(file_path, "r", encoding="utf-8") as file:
-         return file.read()
-
-
- def validate_api_key(judgment_api_key: str):
-     """
-     Validates that the user api key is valid
-     """
-     response = requests.post(
-         f"{ROOT_API}/auth/validate_api_key/",
-         headers={
-             "Content-Type": "application/json",
-             "Authorization": f"Bearer {judgment_api_key}",
-         },
-         json={},
-         verify=True,
-     )
-     if response.status_code == 200:
-         return True, response.json()
-     else:
-         return False, response.json().get("detail", "Error validating API key")
-
-
- def fetch_together_api_response(
-     model: str,
-     messages: List[Dict[str, str]],
-     response_format: Union[pydantic.BaseModel, None] = None,
- ) -> str:
-     """
-     Fetches a single response from the Together API for a given model and messages.
-     """
-     # Validate request
-     if messages is None or messages == []:
-         raise ValueError("Messages cannot be empty")
-
-     request = ChatCompletionRequest(
-         model=model, messages=messages, response_format=response_format
-     )
-
-     if request.response_format is not None:
-         response = together_client.chat.completions.create(
-             model=request.model,
-             messages=request.messages,
-             response_format=request.response_format,
-         )
-     else:
-         response = together_client.chat.completions.create(
-             model=request.model,
-             messages=request.messages,
-         )
-
-     return response.choices[0].message.content
-
-
- async def afetch_together_api_response(
-     model: str,
-     messages: List[Dict],
-     response_format: Union[pydantic.BaseModel, None] = None,
- ) -> str:
-     """
-     ASYNCHRONOUSLY Fetches a single response from the Together API for a given model and messages.
-     """
-     request = ChatCompletionRequest(
-         model=model, messages=messages, response_format=response_format
-     )
-
-     if request.response_format is not None:
-         response = await async_together_client.chat.completions.create(
-             model=request.model,
-             messages=request.messages,
-             response_format=request.response_format,
-         )
-     else:
-         response = await async_together_client.chat.completions.create(
-             model=request.model,
-             messages=request.messages,
-         )
-     return response.choices[0].message.content
-
-
- def query_together_api_multiple_calls(
-     models: List[str],
-     messages: List[List[Dict]],
-     response_formats: Union[List[Union[pydantic.BaseModel, None]], None] = None,
- ) -> List[Union[str, None]]:
-     """
-     Queries the Together API for multiple calls in parallel
-
-     Args:
-         models (List[str]): List of models to query
-         messages (List[List[Mapping]]): List of messages to query. Each inner object corresponds to a single prompt.
-         response_formats (List[pydantic.BaseModel], optional): A list of the format of the response if JSON forcing. Defaults to None.
-
-     Returns:
-         List[str]: TogetherAI responses for each model and message pair in order. Any exceptions in the thread call result in a None.
-     """
-     # Check for empty models list
-     if not models:
-         raise ValueError("Models list cannot be empty")
-
-     # Validate all models are supported
-     for model in models:
-         if model not in ACCEPTABLE_MODELS:
-             raise ValueError(
-                 f"Model {model} is not in the list of supported models: {ACCEPTABLE_MODELS}."
-             )
-
-     # Validate input lengths match
-     if response_formats is None:
-         response_formats = [None] * len(models)
-     if not (len(models) == len(messages) == len(response_formats)):
-         raise ValueError(
-             "Number of models, messages, and response formats must be the same"
-         )
-
-     # Validate message format
-     validate_batched_chat_messages(messages)
-
-     num_workers = int(os.getenv("NUM_WORKER_THREADS", MAX_WORKER_THREADS))
-     # Initialize results to maintain ordered outputs
-     out: List[Union[str, None]] = [None] * len(messages)
-     with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) as executor:
-         # Submit all queries to together API with index, gets back the response content
-         futures = {
-             executor.submit(
-                 fetch_together_api_response, model, message, response_format
-             ): idx
-             for idx, (model, message, response_format) in enumerate(
-                 zip(models, messages, response_formats)
-             )
-         }
-
-         # Collect results as they complete -- result is response content
-         for future in concurrent.futures.as_completed(futures):
-             idx = futures[future]
-             try:
-                 out[idx] = future.result()
-             except Exception as e:
-                 judgeval_logger.error(f"Error in parallel call {idx}: {str(e)}")
-                 out[idx] = None
-     return out
-
-
- async def aquery_together_api_multiple_calls(
-     models: List[str],
-     messages: List[List[Dict]],
-     response_formats: Union[List[Union[pydantic.BaseModel, None]], None] = None,
- ) -> List[Union[str, None]]:
-     """
-     Queries the Together API for multiple calls in parallel
-
-     Args:
-         models (List[str]): List of models to query
-         messages (List[List[Mapping]]): List of messages to query. Each inner object corresponds to a single prompt.
-         response_formats (List[pydantic.BaseModel], optional): A list of the format of the response if JSON forcing. Defaults to None.
-
-     Returns:
-         List[str]: TogetherAI responses for each model and message pair in order. Any exceptions in the thread call result in a None.
-     """
-     # Check for empty models list
-     if not models:
-         raise ValueError("Models list cannot be empty")
-
-     # Validate all models are supported
-     for model in models:
-         if model not in ACCEPTABLE_MODELS:
-             raise ValueError(
-                 f"Model {model} is not in the list of supported models: {ACCEPTABLE_MODELS}."
-             )
-
-     # Validate input lengths match
-     if response_formats is None:
-         response_formats = [None] * len(models)
-     if not (len(models) == len(messages) == len(response_formats)):
-         raise ValueError(
-             "Number of models, messages, and response formats must be the same"
-         )
-
-     # Validate message format
-     validate_batched_chat_messages(messages)
-
-     out: List[Union[str, None]] = [None] * len(messages)
-
-     async def fetch_and_store(idx, model, message, response_format):
-         try:
-             out[idx] = await afetch_together_api_response(
-                 model, message, response_format
-             )
-         except Exception as e:
-             judgeval_logger.error(f"Error in parallel call {idx}: {str(e)}")
-             out[idx] = None
-
-     tasks = [
-         fetch_and_store(idx, model, message, response_format)
-         for idx, (model, message, response_format) in enumerate(
-             zip(models, messages, response_formats)
-         )
-     ]
-
-     await asyncio.gather(*tasks)
-     return out
-
-
- def fetch_litellm_api_response(
-     model: str,
-     messages: List[Dict[str, str]],
-     response_format: Union[pydantic.BaseModel, None] = None,
- ) -> str:
-     """
-     Fetches a single response from the Litellm API for a given model and messages.
-     """
-     request = ChatCompletionRequest(
-         model=model, messages=messages, response_format=response_format
-     )
-
-     if request.response_format is not None:
-         response = litellm.completion(
-             model=request.model,
-             messages=request.messages,
-             response_format=request.response_format,
-         )
-     else:
-         response = litellm.completion(
-             model=request.model,
-             messages=request.messages,
-         )
-     return response.choices[0].message.content
-
-
- def fetch_custom_litellm_api_response(
-     custom_model_parameters: CustomModelParameters,
-     messages: List[Dict[str, str]],
-     response_format: Union[pydantic.BaseModel, None] = None,
- ) -> str:
-     if messages is None or messages == []:
-         raise ValueError("Messages cannot be empty")
-
-     if custom_model_parameters is None:
-         raise ValueError("Custom model parameters cannot be empty")
-
-     if not isinstance(custom_model_parameters, CustomModelParameters):
-         raise ValueError(
-             "Custom model parameters must be a CustomModelParameters object"
-         )
-
-     if response_format is not None:
-         response = litellm.completion(
-             model=custom_model_parameters.model_name,
-             messages=messages,
-             api_key=custom_model_parameters.secret_key,
-             base_url=custom_model_parameters.litellm_base_url,
-             response_format=response_format,
-         )
-     else:
-         response = litellm.completion(
-             model=custom_model_parameters.model_name,
-             messages=messages,
-             api_key=custom_model_parameters.secret_key,
-             base_url=custom_model_parameters.litellm_base_url,
-         )
-     return response.choices[0].message.content
-
-
- async def afetch_litellm_api_response(
-     model: str,
-     messages: List[Dict[str, str]],
-     response_format: Union[pydantic.BaseModel, None] = None,
- ) -> str:
-     """
-     ASYNCHRONOUSLY Fetches a single response from the Litellm API for a given model and messages.
-     """
-     if messages is None or messages == []:
-         raise ValueError("Messages cannot be empty")
-
-     # Add validation
-     validate_chat_messages(messages)
-
-     if model not in ACCEPTABLE_MODELS:
-         raise ValueError(
-             f"Model {model} is not in the list of supported models: {ACCEPTABLE_MODELS}."
-         )
-
-     if response_format is not None:
-         response = await litellm.acompletion(
-             model=model, messages=messages, response_format=response_format
-         )
-     else:
-         response = await litellm.acompletion(
-             model=model,
-             messages=messages,
-         )
-     return response.choices[0].message.content
-
-
- async def afetch_custom_litellm_api_response(
-     custom_model_parameters: CustomModelParameters,
-     messages: List[Dict[str, str]],
-     response_format: Union[pydantic.BaseModel, None] = None,
- ) -> str:
-     """
-     ASYNCHRONOUSLY Fetches a single response from the Litellm API for a given model and messages.
-     """
-     if messages is None or messages == []:
-         raise ValueError("Messages cannot be empty")
-
-     if custom_model_parameters is None:
-         raise ValueError("Custom model parameters cannot be empty")
-
-     if not isinstance(custom_model_parameters, CustomModelParameters):
-         raise ValueError(
-             "Custom model parameters must be a CustomModelParameters object"
-         )
-
-     if response_format is not None:
-         response = await litellm.acompletion(
-             model=custom_model_parameters.model_name,
-             messages=messages,
-             api_key=custom_model_parameters.secret_key,
-             base_url=custom_model_parameters.litellm_base_url,
-             response_format=response_format,
-         )
-     else:
-         response = await litellm.acompletion(
-             model=custom_model_parameters.model_name,
-             messages=messages,
-             api_key=custom_model_parameters.secret_key,
-             base_url=custom_model_parameters.litellm_base_url,
-         )
-     return response.choices[0].message.content
-
-
- def query_litellm_api_multiple_calls(
-     models: List[str],
-     messages: List[List[Dict]],
-     response_formats: Union[List[Union[pydantic.BaseModel, None]], None] = None,
- ) -> List[Union[str, None]]:
-     """
-     Queries the Litellm API for multiple calls in parallel
-
-     Args:
-         models (List[str]): List of models to query
-         messages (List[List[Mapping]]): List of messages to query
-         response_formats (List[pydantic.BaseModel], optional): A list of the format of the response if JSON forcing. Defaults to None.
-
-     Returns:
-         List[str]: Litellm responses for each model and message pair in order. Any exceptions in the thread call result in a None.
-     """
-     num_workers = int(os.getenv("NUM_WORKER_THREADS", MAX_WORKER_THREADS))
-     # Initialize results to maintain ordered outputs
-     out: List[Union[str, None]] = [None] * len(messages)
-     with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) as executor:
-         # Submit all queries to Litellm API with index, gets back the response content
-         futures = {
-             executor.submit(
-                 fetch_litellm_api_response, model, message, response_format
-             ): idx
-             for idx, (model, message, response_format) in enumerate(
-                 zip(models, messages, response_formats or [None] * len(messages))
-             )
-         }
-
-         # Collect results as they complete -- result is response content
-         for future in concurrent.futures.as_completed(futures):
-             idx = futures[future]
-             try:
-                 out[idx] = future.result()
-             except Exception as e:
-                 judgeval_logger.error(f"Error in parallel call {idx}: {str(e)}")
-                 out[idx] = None
-     return out
-
-
- async def aquery_litellm_api_multiple_calls(
-     models: List[str],
-     messages: List[List[Dict[str, str]]],
-     response_formats: Union[List[Union[pydantic.BaseModel, None]], None] = None,
- ) -> List[Union[str, None]]:
-     """
-     Queries the Litellm API for multiple calls in parallel
-
-     Args:
-         models (List[str]): List of models to query
-         messages (List[List[Mapping]]): List of messages to query
-         response_formats (List[pydantic.BaseModel], optional): A list of the format of the response if JSON forcing. Defaults to None.
-
-     Returns:
-         List[str]: Litellm responses for each model and message pair in order. Any exceptions in the thread call result in a None.
-     """
-     # Initialize results to maintain ordered outputs
-     out: List[Union[str, None]] = [None] * len(messages)
-
-     async def fetch_and_store(idx, model, message, response_format):
-         try:
-             out[idx] = await afetch_litellm_api_response(
-                 model, message, response_format
-             )
-         except Exception as e:
-             judgeval_logger.error(f"Error in parallel call {idx}: {str(e)}")
-             out[idx] = None
-
-     tasks = [
-         fetch_and_store(idx, model, message, response_format)
-         for idx, (model, message, response_format) in enumerate(
-             zip(models, messages, response_formats or [None] * len(messages))
-         )
-     ]
-
-     await asyncio.gather(*tasks)
-     return out
-
-
- def validate_chat_messages(messages, batched: bool = False):
-     """Validate chat message format before API call"""
-     if not isinstance(messages, list):
-         raise TypeError("Messages must be a list")
-
-     for msg in messages:
-         if not isinstance(msg, dict):
-             if batched and not isinstance(msg, list):
-                 raise TypeError("Each message must be a list")
-             elif not batched:
-                 raise TypeError("Message must be a dictionary")
-         if "role" not in msg:
-             raise ValueError("Message missing required 'role' field")
-         if "content" not in msg:
-             raise ValueError("Message missing required 'content' field")
-         if msg["role"] not in ["system", "user", "assistant"]:
-             raise ValueError(
-                 f"Invalid role '{msg['role']}'. Must be 'system', 'user', or 'assistant'"
-             )
-
-
- def validate_batched_chat_messages(messages):
-     """
-     Validate format of batched chat messages before API call
-
-     Args:
-         messages (List[List[Mapping]]): List of message lists, where each inner list contains
-             message dictionaries with 'role' and 'content' fields
-
-     Raises:
-         TypeError: If messages format is invalid
-         ValueError: If message content is invalid
-     """
-     if not isinstance(messages, list):
-         raise TypeError("Batched messages must be a list")
-
-     if not messages:
-         raise ValueError("Batched messages cannot be empty")
-
-     for message_list in messages:
-         if not isinstance(message_list, list):
-             raise TypeError("Each batch item must be a list of messages")
-
-         # Validate individual messages using existing function
-         validate_chat_messages(message_list)
-
-
- def is_batched_messages(
-     messages: Union[List[Dict[str, str]], List[List[Dict[str, str]]]],
- ) -> TypeGuard[List[List[Dict[str, str]]]]:
-     return isinstance(messages, list) and all(isinstance(msg, list) for msg in messages)
-
-
- def is_simple_messages(
-     messages: Union[List[Dict[str, str]], List[List[Dict[str, str]]]],
- ) -> TypeGuard[List[Dict[str, str]]]:
-     return isinstance(messages, list) and all(
-         not isinstance(msg, list) for msg in messages
-     )
-
-
- def get_chat_completion(
-     model_type: str,
-     messages: Union[List[Dict[str, str]], List[List[Dict[str, str]]]],
-     response_format: Union[pydantic.BaseModel, None] = None,
-     batched: bool = False,
- ) -> Union[str, List[Union[str, None]]]:
-     """
-     Generates chat completions using a single model and potentially several messages. Supports closed-source and OSS models.
-
-     Parameters:
-         - model_type (str): The type of model to use for generating completions.
-         - messages (Union[List[Mapping], List[List[Mapping]]]): The messages to be used for generating completions.
-             If batched is True, this should be a list of lists of mappings.
-         - response_format (pydantic.BaseModel, optional): The format of the response. Defaults to None.
-         - batched (bool, optional): Whether to process messages in batch mode. Defaults to False.
-     Returns:
-         - str: The generated chat completion(s). If batched is True, returns a list of strings.
-     Raises:
-         - ValueError: If requested model is not supported by Litellm or TogetherAI.
-     """
-
-     # Check for empty messages list
-     if not messages or messages == []:
-         raise ValueError("Messages cannot be empty")
-
-     # Add validation
-     if batched:
-         validate_batched_chat_messages(messages)
-     else:
-         validate_chat_messages(messages)
-
-     if (
-         batched
-         and is_batched_messages(messages)
-         and model_type in TOGETHER_SUPPORTED_MODELS
-     ):
-         return query_together_api_multiple_calls(
-             models=[model_type] * len(messages),
-             messages=messages,
-             response_formats=[response_format] * len(messages),
-         )
-     elif (
-         batched
-         and is_batched_messages(messages)
-         and model_type in LITELLM_SUPPORTED_MODELS
-     ):
-         return query_litellm_api_multiple_calls(
-             models=[model_type] * len(messages),
-             messages=messages,
-             response_formats=[response_format] * len(messages),
-         )
-     elif (
-         not batched
-         and is_simple_messages(messages)
-         and model_type in TOGETHER_SUPPORTED_MODELS
-     ):
-         return fetch_together_api_response(
-             model=model_type, messages=messages, response_format=response_format
-         )
-     elif (
-         not batched
-         and is_simple_messages(messages)
-         and model_type in LITELLM_SUPPORTED_MODELS
-     ):
-         return fetch_litellm_api_response(
-             model=model_type, messages=messages, response_format=response_format
-         )
-
-     raise ValueError(
-         f"Model {model_type} is not supported by Litellm or TogetherAI for chat completions. Please check the model name and try again."
-     )
-
-
- async def aget_chat_completion(
-     model_type: str,
-     messages: Union[List[Dict[str, str]], List[List[Dict[str, str]]]],
-     response_format: Union[pydantic.BaseModel, None] = None,
-     batched: bool = False,
- ) -> Union[str, List[Union[str, None]]]:
-     """
-     ASYNCHRONOUSLY generates chat completions using a single model and potentially several messages. Supports closed-source and OSS models.
-
-     Parameters:
-         - model_type (str): The type of model to use for generating completions.
-         - messages (Union[List[Mapping], List[List[Mapping]]]): The messages to be used for generating completions.
-             If batched is True, this should be a list of lists of mappings.
-         - response_format (pydantic.BaseModel, optional): The format of the response. Defaults to None.
-         - batched (bool, optional): Whether to process messages in batch mode. Defaults to False.
-     Returns:
-         - str: The generated chat completion(s). If batched is True, returns a list of strings.
-     Raises:
-         - ValueError: If requested model is not supported by Litellm or TogetherAI.
-     """
-
-     if batched:
-         validate_batched_chat_messages(messages)
-     else:
-         validate_chat_messages(messages)
-
-     if (
-         batched
-         and is_batched_messages(messages)
-         and model_type in TOGETHER_SUPPORTED_MODELS
-     ):
-         return await aquery_together_api_multiple_calls(
-             models=[model_type] * len(messages),
-             messages=messages,
-             response_formats=[response_format] * len(messages),
-         )
-     elif (
-         batched
-         and is_batched_messages(messages)
-         and model_type in LITELLM_SUPPORTED_MODELS
-     ):
-         return await aquery_litellm_api_multiple_calls(
-             models=[model_type] * len(messages),
-             messages=messages,
-             response_formats=[response_format] * len(messages),
-         )
-     elif (
-         not batched
-         and is_simple_messages(messages)
-         and model_type in TOGETHER_SUPPORTED_MODELS
-     ):
-         return await afetch_together_api_response(
-             model=model_type, messages=messages, response_format=response_format
-         )
-     elif (
-         not batched
-         and is_simple_messages(messages)
-         and model_type in LITELLM_SUPPORTED_MODELS
-     ):
-         return await afetch_litellm_api_response(
-             model=model_type, messages=messages, response_format=response_format
-         )
-
-     judgeval_logger.error(f"Model {model_type} not supported by either API")
-     raise ValueError(
-         f"Model {model_type} is not supported by Litellm or TogetherAI for chat completions. Please check the model name and try again."
-     )
-
-
- def get_completion_multiple_models(
-     models: List[str],
-     messages: List[List[Dict[str, str]]],
-     response_formats: Union[List[Union[pydantic.BaseModel, None]], None] = None,
- ) -> List[Union[str, None]]:
-     """
-     Retrieves completions for a single prompt from multiple models in parallel. Supports closed-source and OSS models.
-
-     Args:
-         models (List[str]): List of models to query
-         messages (List[List[Mapping]]): List of messages to query. Each inner object corresponds to a single prompt.
-         response_formats (List[pydantic.BaseModel], optional): A list of the format of the response if JSON forcing. Defaults to None.
-
-     Returns:
-         List[str]: List of completions from the models in the order of the input models
-     Raises:
-         ValueError: If a model is not supported by Litellm or Together
-     """
-
-     if models is None or models == []:
-         raise ValueError("Models list cannot be empty")
-
-     validate_batched_chat_messages(messages)
-
-     if len(models) != len(messages):
-         judgeval_logger.error(
-             f"Model/message count mismatch: {len(models)} vs {len(messages)}"
-         )
-         raise ValueError(
-             f"Number of models and messages must be the same: {len(models)} != {len(messages)}"
-         )
-     if response_formats is None:
-         response_formats = [None] * len(models)
-     # Partition the model requests into TogetherAI and Litellm models, but keep the ordering saved
-     together_calls, litellm_calls = {}, {}  # index -> model, message, response_format
-     together_responses, litellm_responses = [], []
-     for idx, (model, message, r_format) in enumerate(
-         zip(models, messages, response_formats)
-     ):
-         if model in TOGETHER_SUPPORTED_MODELS:
-             together_calls[idx] = (model, message, r_format)
-         elif model in LITELLM_SUPPORTED_MODELS:
-             litellm_calls[idx] = (model, message, r_format)
-         else:
-             judgeval_logger.error(f"Model {model} not supported by either API")
-             raise ValueError(
-                 f"Model {model} is not supported by Litellm or TogetherAI for chat completions. Please check the model name and try again."
-             )
-
-     # Add validation before processing
-     for msg_list in messages:
-         validate_chat_messages(msg_list)
-
-     # Get the responses from the TogetherAI models
-     # List of responses from the TogetherAI models in order of the together_calls dict
-     if together_calls:
-         together_responses = query_together_api_multiple_calls(
-             models=[model for model, _, _ in together_calls.values()],
-             messages=[message for _, message, _ in together_calls.values()],
-             response_formats=[format for _, _, format in together_calls.values()],
-         )
-
-     # Get the responses from the Litellm models
-     if litellm_calls:
-         litellm_responses = query_litellm_api_multiple_calls(
-             models=[model for model, _, _ in litellm_calls.values()],
-             messages=[message for _, message, _ in litellm_calls.values()],
-             response_formats=[format for _, _, format in litellm_calls.values()],
-         )
-
-     # Merge the responses in the order of the original models
-     out: List[Union[str, None]] = [None] * len(models)
-     for idx, (model, message, r_format) in together_calls.items():
-         out[idx] = together_responses.pop(0)
-     for idx, (model, message, r_format) in litellm_calls.items():
-         out[idx] = litellm_responses.pop(0)
-     return out
-
-
- async def aget_completion_multiple_models(
-     models: List[str],
-     messages: List[List[Dict[str, str]]],
-     response_formats: Union[List[Union[pydantic.BaseModel, None]], None] = None,
- ) -> List[Union[str, None]]:
-     """
-     ASYNCHRONOUSLY retrieves completions for a single prompt from multiple models in parallel. Supports closed-source and OSS models.
-
-     Args:
-         models (List[str]): List of models to query
-         messages (List[List[Mapping]]): List of messages to query. Each inner object corresponds to a single prompt.
-         response_formats (List[pydantic.BaseModel], optional): A list of the format of the response if JSON forcing. Defaults to None.
-
-     Returns:
-         List[str]: List of completions from the models in the order of the input models
-     Raises:
-         ValueError: If a model is not supported by Litellm or Together
-     """
-     if models is None or models == []:
-         raise ValueError("Models list cannot be empty")
-
-     if len(models) != len(messages):
-         raise ValueError(
-             f"Number of models and messages must be the same: {len(models)} != {len(messages)}"
-         )
-     if response_formats is None:
-         response_formats = [None] * len(models)
-
-     validate_batched_chat_messages(messages)
-
-     # Partition the model requests into TogetherAI and Litellm models, but keep the ordering saved
-     together_calls, litellm_calls = {}, {}  # index -> model, message, response_format
-     together_responses, litellm_responses = [], []
-     for idx, (model, message, r_format) in enumerate(
-         zip(models, messages, response_formats)
-     ):
-         if model in TOGETHER_SUPPORTED_MODELS:
-             together_calls[idx] = (model, message, r_format)
-         elif model in LITELLM_SUPPORTED_MODELS:
-             litellm_calls[idx] = (model, message, r_format)
-         else:
-             raise ValueError(
-                 f"Model {model} is not supported by Litellm or TogetherAI for chat completions. Please check the model name and try again."
-             )
-
-     # Add validation before processing
-     for msg_list in messages:
-         validate_chat_messages(msg_list)
-
-     # Get the responses from the TogetherAI models
-     # List of responses from the TogetherAI models in order of the together_calls dict
-     if together_calls:
-         together_responses = await aquery_together_api_multiple_calls(
-             models=[model for model, _, _ in together_calls.values()],
-             messages=[message for _, message, _ in together_calls.values()],
-             response_formats=[format for _, _, format in together_calls.values()],
-         )
-
-     # Get the responses from the Litellm models
-     if litellm_calls:
-         litellm_responses = await aquery_litellm_api_multiple_calls(
-             models=[model for model, _, _ in litellm_calls.values()],
-             messages=[message for _, message, _ in litellm_calls.values()],
-             response_formats=[format for _, _, format in litellm_calls.values()],
-         )
-
-     # Merge the responses in the order of the original models
-     out: List[Union[str, None]] = [None] * len(models)
-     for idx, (model, message, r_format) in together_calls.items():
-         out[idx] = together_responses.pop(0)
-     for idx, (model, message, r_format) in litellm_calls.items():
-         out[idx] = litellm_responses.pop(0)
-     return out
-
-
- if __name__ == "__main__":
-     batched_messages: List[List[Dict[str, str]]] = [
-         [
-             {"role": "system", "content": "You are a helpful assistant."},
-             {"role": "user", "content": "What is the capital of France?"},
-         ],
-         [
-             {"role": "system", "content": "You are a helpful assistant."},
-             {"role": "user", "content": "What is the capital of Japan?"},
-         ],
-     ]
-
-     non_batched_messages: List[Dict[str, str]] = [
-         {"role": "system", "content": "You are a helpful assistant."},
-         {"role": "user", "content": "What is the capital of France?"},
-     ]
-
-     batched_messages_2: List[List[Dict[str, str]]] = [
-         [
-             {"role": "system", "content": "You are a helpful assistant."},
-             {"role": "user", "content": "What is the capital of China?"},
-         ],
-         [
-             {"role": "system", "content": "You are a helpful assistant."},
-             {"role": "user", "content": "What is the capital of France?"},
-         ],
-         [
-             {"role": "system", "content": "You are a helpful assistant."},
-             {"role": "user", "content": "What is the capital of Japan?"},
-         ],
-     ]
-
-     # Batched
-     pprint.pprint(
-         get_chat_completion(
-             model_type="LLAMA3_405B_INSTRUCT_TURBO",
-             messages=batched_messages,
-             batched=True,
-         )
-     )
-
-     # Non batched
-     pprint.pprint(
-         get_chat_completion(
-             model_type="LLAMA3_8B_INSTRUCT_TURBO",
-             messages=non_batched_messages,
-             batched=False,
-         )
-     )
-
-     # Batched single completion to multiple models
-     pprint.pprint(
-         get_completion_multiple_models(
-             models=[
-                 "LLAMA3_70B_INSTRUCT_TURBO",
-                 "LLAMA3_405B_INSTRUCT_TURBO",
-                 "gpt-4.1-mini",
-             ],
-             messages=batched_messages_2,
-         )
-     )
-
- ExcInfo: TypeAlias = tuple[type[BaseException], BaseException, TracebackType]
- OptExcInfo: TypeAlias = Union[ExcInfo, tuple[None, None, None]]
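For reference, the core pattern the deleted query_*_multiple_calls helpers relied on is an ordered fan-out: submit one call per (model, messages) pair, key each future by its input index so results land back in input order, and record failures as None rather than aborting the batch. The sketch below illustrates that pattern in isolation with only the standard library; fan_out_ordered and its fetch parameter are hypothetical stand-ins, not part of judgeval 0.9.0.

    import concurrent.futures
    from typing import Callable, List, Optional, Sequence

    def fan_out_ordered(
        fetch: Callable[[str, list], str],  # hypothetical single-call function, e.g. one provider request
        models: Sequence[str],
        messages: Sequence[list],
        max_workers: int = 10,
    ) -> List[Optional[str]]:
        # Preallocate so each result lands at the index of the request that produced it.
        out: List[Optional[str]] = [None] * len(models)
        with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
            # Map each future back to its position in the input.
            futures = {
                executor.submit(fetch, model, message): idx
                for idx, (model, message) in enumerate(zip(models, messages))
            }
            for future in concurrent.futures.as_completed(futures):
                idx = futures[future]
                try:
                    out[idx] = future.result()
                except Exception:
                    # A failed call leaves None in its slot instead of failing the whole batch.
                    out[idx] = None
        return out

The async variants in the deleted file followed the same idea, but filled the preallocated list from coroutines awaited together with asyncio.gather instead of a thread pool.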