palimpzest 0.7.20-py3-none-any.whl → 0.8.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (87)
  1. palimpzest/__init__.py +37 -6
  2. palimpzest/agents/__init__.py +0 -0
  3. palimpzest/agents/compute_agents.py +0 -0
  4. palimpzest/agents/search_agents.py +637 -0
  5. palimpzest/constants.py +259 -197
  6. palimpzest/core/data/context.py +393 -0
  7. palimpzest/core/data/context_manager.py +163 -0
  8. palimpzest/core/data/dataset.py +634 -0
  9. palimpzest/core/data/{datareaders.py → iter_dataset.py} +202 -126
  10. palimpzest/core/elements/groupbysig.py +16 -13
  11. palimpzest/core/elements/records.py +166 -75
  12. palimpzest/core/lib/schemas.py +152 -390
  13. palimpzest/core/{data/dataclasses.py → models.py} +306 -170
  14. palimpzest/policy.py +2 -27
  15. palimpzest/prompts/__init__.py +35 -5
  16. palimpzest/prompts/agent_prompts.py +357 -0
  17. palimpzest/prompts/context_search.py +9 -0
  18. palimpzest/prompts/convert_prompts.py +61 -5
  19. palimpzest/prompts/filter_prompts.py +50 -5
  20. palimpzest/prompts/join_prompts.py +163 -0
  21. palimpzest/prompts/moa_proposer_convert_prompts.py +5 -5
  22. palimpzest/prompts/prompt_factory.py +358 -46
  23. palimpzest/prompts/validator.py +239 -0
  24. palimpzest/query/execution/all_sample_execution_strategy.py +134 -76
  25. palimpzest/query/execution/execution_strategy.py +210 -317
  26. palimpzest/query/execution/execution_strategy_type.py +5 -7
  27. palimpzest/query/execution/mab_execution_strategy.py +249 -136
  28. palimpzest/query/execution/parallel_execution_strategy.py +153 -244
  29. palimpzest/query/execution/single_threaded_execution_strategy.py +107 -64
  30. palimpzest/query/generators/generators.py +157 -330
  31. palimpzest/query/operators/__init__.py +15 -5
  32. palimpzest/query/operators/aggregate.py +50 -33
  33. palimpzest/query/operators/compute.py +201 -0
  34. palimpzest/query/operators/convert.py +27 -21
  35. palimpzest/query/operators/critique_and_refine_convert.py +7 -5
  36. palimpzest/query/operators/distinct.py +62 -0
  37. palimpzest/query/operators/filter.py +22 -13
  38. palimpzest/query/operators/join.py +402 -0
  39. palimpzest/query/operators/limit.py +3 -3
  40. palimpzest/query/operators/logical.py +198 -80
  41. palimpzest/query/operators/mixture_of_agents_convert.py +10 -8
  42. palimpzest/query/operators/physical.py +27 -21
  43. palimpzest/query/operators/project.py +3 -3
  44. palimpzest/query/operators/rag_convert.py +7 -7
  45. palimpzest/query/operators/retrieve.py +9 -9
  46. palimpzest/query/operators/scan.py +81 -42
  47. palimpzest/query/operators/search.py +524 -0
  48. palimpzest/query/operators/split_convert.py +10 -8
  49. palimpzest/query/optimizer/__init__.py +7 -9
  50. palimpzest/query/optimizer/cost_model.py +108 -441
  51. palimpzest/query/optimizer/optimizer.py +123 -181
  52. palimpzest/query/optimizer/optimizer_strategy.py +66 -61
  53. palimpzest/query/optimizer/plan.py +352 -67
  54. palimpzest/query/optimizer/primitives.py +43 -19
  55. palimpzest/query/optimizer/rules.py +484 -646
  56. palimpzest/query/optimizer/tasks.py +127 -58
  57. palimpzest/query/processor/config.py +41 -76
  58. palimpzest/query/processor/query_processor.py +73 -18
  59. palimpzest/query/processor/query_processor_factory.py +46 -38
  60. palimpzest/schemabuilder/schema_builder.py +15 -28
  61. palimpzest/utils/model_helpers.py +27 -77
  62. palimpzest/utils/progress.py +114 -102
  63. palimpzest/validator/__init__.py +0 -0
  64. palimpzest/validator/validator.py +306 -0
  65. {palimpzest-0.7.20.dist-info → palimpzest-0.8.0.dist-info}/METADATA +6 -1
  66. palimpzest-0.8.0.dist-info/RECORD +95 -0
  67. palimpzest/core/lib/fields.py +0 -141
  68. palimpzest/prompts/code_synthesis_prompts.py +0 -28
  69. palimpzest/query/execution/random_sampling_execution_strategy.py +0 -240
  70. palimpzest/query/generators/api_client_factory.py +0 -30
  71. palimpzest/query/operators/code_synthesis_convert.py +0 -488
  72. palimpzest/query/operators/map.py +0 -130
  73. palimpzest/query/processor/nosentinel_processor.py +0 -33
  74. palimpzest/query/processor/processing_strategy_type.py +0 -28
  75. palimpzest/query/processor/sentinel_processor.py +0 -88
  76. palimpzest/query/processor/streaming_processor.py +0 -149
  77. palimpzest/sets.py +0 -405
  78. palimpzest/utils/datareader_helpers.py +0 -61
  79. palimpzest/utils/demo_helpers.py +0 -75
  80. palimpzest/utils/field_helpers.py +0 -69
  81. palimpzest/utils/generation_helpers.py +0 -69
  82. palimpzest/utils/sandbox.py +0 -183
  83. palimpzest-0.7.20.dist-info/RECORD +0 -95
  84. /palimpzest/core/{elements/index.py → data/index_dataset.py} +0 -0
  85. {palimpzest-0.7.20.dist-info → palimpzest-0.8.0.dist-info}/WHEEL +0 -0
  86. {palimpzest-0.7.20.dist-info → palimpzest-0.8.0.dist-info}/licenses/LICENSE +0 -0
  87. {palimpzest-0.7.20.dist-info → palimpzest-0.8.0.dist-info}/top_level.txt +0 -0
palimpzest/query/generators/generators.py
@@ -4,36 +4,28 @@ This file contains the Generator classes and generator factory.
 
 from __future__ import annotations
 
+import json
 import logging
 import os
-import re
 import time
 import warnings
-from abc import ABC, abstractmethod
-from collections import Counter
 from copy import deepcopy
 from typing import Any, Generic, TypeVar
 
+import litellm
+import regex as re  # Use regex instead of re to use variable-length lookbehind
 from colorama import Fore, Style
-from openai import OpenAI
-from openai.types.chat.chat_completion import ChatCompletion
-from together import Together
-from together.types.chat_completions import ChatCompletionResponse
+from pydantic.fields import FieldInfo
 
 from palimpzest.constants import (
     MODEL_CARDS,
-    APIClient,
     Cardinality,
     Model,
     PromptStrategy,
 )
-from palimpzest.core.data.dataclasses import GenerationStats
 from palimpzest.core.elements.records import DataRecord
-from palimpzest.core.lib.fields import Field, ListField
+from palimpzest.core.models import GenerationStats
 from palimpzest.prompts import PromptFactory
-from palimpzest.query.generators.api_client_factory import APIClientFactory
-from palimpzest.utils.generation_helpers import get_json_from_answer
-from palimpzest.utils.sandbox import API
 
 # DEFINITIONS
 GenerationOutput = tuple[dict, str | None, GenerationStats, list[dict]]
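
The `regex` dependency is not cosmetic: the comment-stripping pattern introduced below uses a variable-length lookbehind, which the standard library's `re` rejects. A standalone sketch (not part of the diff) of the difference:

# Sketch: stdlib `re` refuses variable-length lookbehind; `regex` accepts it.
import re

import regex

pattern = r"(?<!https?:)//.*?$"  # strip // comments without touching URLs

try:
    re.sub(pattern, "", '{"url": "https://x.io"}  // note', flags=re.MULTILINE)
except re.error as err:
    print(err)  # look-behind requires fixed-width pattern

print(regex.sub(pattern, "", '{"url": "https://x.io"}  // note', flags=regex.MULTILINE))
# prints: {"url": "https://x.io"}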
@@ -43,31 +35,71 @@ InputType = TypeVar("InputType")
 
 logger = logging.getLogger(__name__)
 
-def generator_factory(
-    model: Model, prompt_strategy: PromptStrategy, cardinality: Cardinality, verbose: bool = False
-) -> BaseGenerator:
+def get_json_from_answer(answer: str, model: Model, cardinality: Cardinality) -> dict[str, Any]:
     """
-    Factory function to return the correct generator based on the model, strategy, and cardinality.
+    This function parses an LLM response which is supposed to output a JSON object
+    and optimistically searches for the substring containing the JSON object.
     """
-    if model.is_openai_model():
-        return OpenAIGenerator(model, prompt_strategy, cardinality, verbose)
-
-    elif model.is_together_model():
-        return TogetherGenerator(model, prompt_strategy, cardinality, verbose)
-
-    raise Exception(f"Unsupported model: {model}")
-
-
-def get_api_key(key: str) -> str:
-    # get API key from environment or throw an exception if it's not set
-    if key not in os.environ:
-        raise ValueError("key not found in environment variables")
-
-    return os.environ[key]
-
-
+    # model-specific trimming for LLAMA3 responses
+    if model.is_llama_model():
+        answer = answer.split("---")[0]
+        answer = answer.replace("True", "true")
+        answer = answer.replace("False", "false")
+
+    # split off context / excess, which models sometimes output after answer
+    answer = answer.split("Context:")[0]
+    answer = answer.split("# this is the answer")[0]
+
+    # trim the answer to only include the JSON dictionary
+    if cardinality == Cardinality.ONE_TO_ONE:
+        if not answer.strip().startswith("{"):
+            # Find the start index of the actual JSON string assuming the prefix is followed by the JSON dictionary
+            start_index = answer.find("{")
+            if start_index != -1:
+                # Remove the prefix and any leading characters before the JSON starts
+                answer = answer[start_index:]
+
+        if not answer.strip().endswith("}"):
+            # Find the end index of the actual JSON string assuming the suffix is preceded by the JSON dictionary
+            end_index = answer.rfind("}")
+            if end_index != -1:
+                # Remove the suffix and any trailing characters after the JSON ends
+                answer = answer[: end_index + 1]
+
+    # otherwise, trim the answer to only include the JSON array
+    else:
+        if not answer.strip().startswith("["):
+            # Find the start index of the actual JSON string assuming the prefix is followed by the JSON array
+            start_index = answer.find("[")
+            if start_index != -1:
+                # Remove the prefix and any leading characters before the JSON starts
+                answer = answer[start_index:]
+
+        if not answer.strip().endswith("]"):
+            # Find the end index of the actual JSON string
+            # assuming the suffix is preceded by the JSON object/array
+            end_index = answer.rfind("]")
+            if end_index != -1:
+                # Remove the suffix and any trailing characters after the JSON ends
+                answer = answer[: end_index + 1]
+
+    # Handle weird escaped values. I am not sure why the model
+    # is returning these, but the JSON parser can't take them
+    answer = answer.replace(r"\_", "_")
+    answer = answer.replace("\\n", "\n")
+    # Remove https and http prefixes to not conflict with comment detection
+    # Handle comments in the JSON response. Use regex from // until end of line
+    answer = re.sub(r"(?<!https?:)\/\/.*?$", "", answer, flags=re.MULTILINE)
+    answer = re.sub(r",\n.*\.\.\.$", "", answer, flags=re.MULTILINE)
+    # Sanitize newlines in the JSON response
+    answer = answer.replace("\n", " ")
+
+    # finally, parse and return the JSON object; errors are handled by the caller
+    return json.loads(answer)
+
+# TODO: push parallelism of generations into LiteLLM rather than threadpool in executor
 # TODO: make sure answer parsing works with custom prompts / parsers (can defer this)
-class BaseGenerator(Generic[ContextType, InputType], ABC):
+class Generator(Generic[ContextType, InputType]):
     """
     Abstract base class for Generators.
     """
@@ -76,95 +108,20 @@ class BaseGenerator(Generic[ContextType, InputType], ABC):
         self,
         model: Model,
         prompt_strategy: PromptStrategy,
+        reasoning_effort: str | None = None,
+        api_base: str | None = None,
         cardinality: Cardinality = Cardinality.ONE_TO_ONE,
         verbose: bool = False,
-        system_role: str = "system",
     ):
         self.model = model
         self.model_name = model.value
         self.cardinality = cardinality
         self.prompt_strategy = prompt_strategy
+        self.reasoning_effort = reasoning_effort
+        self.api_base = api_base
         self.verbose = verbose
-        self.system_role = system_role
         self.prompt_factory = PromptFactory(prompt_strategy, model, cardinality)
 
-    @abstractmethod
-    def _get_client_or_model(self, **kwargs) -> Any:
-        """Returns a client (or local model) which can be invoked to perform the generation."""
-        pass
-
-    @abstractmethod
-    def _generate_completion(self, client_or_model: Any, payload: dict, **kwargs) -> Any:
-        """Generates a completion object using the client (or local model)."""
-        pass
-
-    @abstractmethod
-    def _get_completion_text(self, completion: Any, **kwargs) -> Any:
-        """Extract the completion text from the completion object."""
-        pass
-
-    @abstractmethod
-    def _get_usage(self, completion: Any, **kwargs) -> Any:
-        """Extract the usage statistics from the completion object."""
-        pass
-
-    @abstractmethod
-    def _get_finish_reason(self, completion: Any, **kwargs) -> Any:
-        """Extract the finish reason from the completion object."""
-        pass
-
-    @abstractmethod
-    def _get_answer_log_probs(self, completion: Any, **kwargs) -> Any:
-        """Extract the log probabilities from the completion object."""
-        pass
-
-    def _generate_payload(self, messages: list[dict], **kwargs) -> dict:
-        """
-        Generates the payload which will be fed into the client (or local model).
-
-        Each message will be a dictionary with the following format:
-        {
-            "role": "user" | "system",
-            "type": "text" | "image",
-            "content": str
-        }
-        """
-        # get basic parameters
-        model = self.model_name
-        temperature = kwargs.get("temperature", 0.0)
-
-        # construct messages and add system prompt if present
-        chat_messages, user_content = [], []
-        for message in messages:
-            # flush user content into a message and add system message
-            if message["role"] == "system":
-                if len(user_content) > 0:
-                    chat_messages.append({"role": "user", "content": user_content})
-                    user_content = []
-
-                chat_messages.append({"role": self.system_role, "content": message["content"]})
-
-            # add user content for text messages
-            elif message["role"] == "user" and message["type"] == "text":
-                user_content.append({"type": "text", "text": message["content"]})
-
-            # add user content for image messages
-            elif message["role"] == "user" and message["type"] == "image":
-                user_content.append({"type": "image_url", "image_url": {"url": message["content"]}})
-
-        # flush any remaining user content into a final message
-        if len(user_content) > 0:
-            chat_messages.append({"role": "user", "content": user_content})
-
-        # construct and return payload
-        payload = {
-            "model": model,
-            "temperature": temperature,
-            "messages": chat_messages,
-        }
-
-        return payload
-
     def _parse_reasoning(self, completion_text: str, **kwargs) -> str:
         """Extract the reasoning for the generated output from the completion object."""
         # use a custom reasoning parser if provided
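
With the provider-specific subclasses removed (see the deletions at the end of this file), a single `Generator` is constructed directly. A minimal usage sketch, with the model and prompt-strategy members assumed for illustration:

# Sketch: one Generator now serves every provider through LiteLLM.
# Model.GPT_4O and PromptStrategy.COT_QA are assumed enum members.
gen = Generator(
    model=Model.GPT_4O,
    prompt_strategy=PromptStrategy.COT_QA,
    reasoning_effort=None,  # provider-specific default resolved at call time
    api_base=None,          # only consulted for vLLM-served models
)
field_answers, reasoning, stats, messages = gen(candidate, fields)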
@@ -183,7 +140,7 @@ class BaseGenerator(Generic[ContextType, InputType], ABC):
         # otherwise, return the full completion text
         return completion_text
 
-    def _prepare_field_answers(self, field_answers: dict | list[dict], fields: dict[str, Field]) -> dict[str, list]:
+    def _prepare_field_answers(self, field_answers: dict | list[dict], fields: dict[str, FieldInfo]) -> dict[str, list]:
         """
         field_answers is a dictionary mapping fields to their values. For one-to-one converts, wrap each
         answer in a list. For one-to-many converts, invert the list of dictionaries into a dictionary with
@@ -205,7 +162,7 @@ class BaseGenerator(Generic[ContextType, InputType], ABC):
 
         return field_answers
 
-    def _check_convert_answer_text(self, answer_text: str, fields: dict[str, Field], throw_exception: bool=False) -> dict | list[dict] | None:
+    def _check_convert_answer_text(self, answer_text: str, fields: dict[str, FieldInfo], throw_exception: bool=False) -> dict | list[dict] | None:
         """
         Try parsing the answer text into a JSON object. If the parsing fails, return None.
         """
@@ -213,18 +170,6 @@ class BaseGenerator(Generic[ContextType, InputType], ABC):
             # extract json from the answer text
             field_answers = get_json_from_answer(answer_text, self.model, self.cardinality)
 
-            # TODO: wrap non-list outputs in a list if expected output is a list
-
-            # common error for one-to-one: if the output is a singleton list which contains a list, but the expected field type
-            # is a list of strings, or a list of floats, i.e. not a list of lists; then extract the inner list
-            if self.cardinality == Cardinality.ONE_TO_ONE:
-                for field, field_type in fields.items():
-                    answer = field_answers[field]
-                    field_type_is_not_list_of_lists = isinstance(field_type, ListField) and not issubclass(field_type.element_type, ListField)
-                    answer_is_list_of_lists = isinstance(answer, list) and len(answer) == 1 and isinstance(answer[0], list)
-                    if field_type_is_not_list_of_lists and answer_is_list_of_lists:
-                        field_answers[field] = answer[0]
-
             # prepare the field answers to match the expected output and return
             return self._prepare_field_answers(field_answers, fields)
 
@@ -234,7 +179,7 @@ class BaseGenerator(Generic[ContextType, InputType], ABC):
 
         return None
 
-    def _check_filter_answer_text(self, answer_text: str) -> dict | None:
+    def _check_bool_answer_text(self, answer_text: str) -> dict | None:
         """
         Return {"passed_operator": True} if and only if "true" is in the answer text.
         Return {"passed_operator": False} if and only if "false" is in the answer text.
@@ -249,7 +194,7 @@ class BaseGenerator(Generic[ContextType, InputType], ABC):
 
         return None
 
-    def _parse_convert_answer(self, completion_text: str, fields: dict[str, Field], json_output: bool) -> dict[str, list]:
+    def _parse_convert_answer(self, completion_text: str, fields: dict[str, FieldInfo], json_output: bool) -> dict[str, list]:
         """Extract the answer from the completion object for convert operations."""
         # if the model followed the default instructions, the completion text will place
         # its answer between "ANSWER:" and "---"
@@ -288,15 +233,15 @@ class BaseGenerator(Generic[ContextType, InputType], ABC):
 
         return self._check_convert_answer_text(completion_text, fields, throw_exception=True)
 
-    def _parse_filter_answer(self, completion_text: str) -> dict[str, list]:
-        """Extract the answer from the completion object for filter operations."""
+    def _parse_bool_answer(self, completion_text: str) -> dict[str, list]:
+        """Extract the answer from the completion object for filter and join operations."""
         # if the model followed the default instructions, the completion text will place
         # its answer between "ANSWER:" and "---"
         regex = re.compile("answer:(.*?)---", re.IGNORECASE | re.DOTALL)
         matches = regex.findall(completion_text)
         if len(matches) > 0:
             answer_text = matches[0].strip()
-            field_answers = self._check_filter_answer_text(answer_text)
+            field_answers = self._check_bool_answer_text(answer_text)
             if field_answers is not None:
                 return field_answers
 
@@ -305,18 +250,18 @@ class BaseGenerator(Generic[ContextType, InputType], ABC):
         matches = regex.findall(completion_text)
         if len(matches) > 0:
             answer_text = matches[0].strip()
-            field_answers = self._check_filter_answer_text(answer_text)
+            field_answers = self._check_bool_answer_text(answer_text)
             if field_answers is not None:
                 return field_answers
 
         # finally, try taking all of the text; throw an exception if this doesn't work
-        field_answers = self._check_filter_answer_text(completion_text)
+        field_answers = self._check_bool_answer_text(completion_text)
         if field_answers is None:
             raise Exception(f"Could not parse answer from completion text: {completion_text}")
 
         return field_answers
 
-    def _parse_answer(self, completion_text: str, fields: dict[str, Field] | None, json_output: bool, **kwargs) -> dict[str, list]:
+    def _parse_answer(self, completion_text: str, fields: dict[str, FieldInfo] | None, json_output: bool, **kwargs) -> dict[str, list]:
         """Extract the answer from the completion object."""
        # use a custom answer parser if provided
         if kwargs.get("parse_answer"):
@@ -328,16 +273,15 @@ class BaseGenerator(Generic[ContextType, InputType], ABC):
 
         # extract the per-field answers from the completion text
         field_answers = (
-            self._parse_filter_answer(completion_text)
-            if self.prompt_strategy.is_bool_prompt()
+            self._parse_bool_answer(completion_text)
+            if self.prompt_strategy.is_bool_prompt() or self.prompt_strategy.is_join_prompt()
             else self._parse_convert_answer(completion_text, fields, json_output)
         )
 
         return field_answers
 
-    def __call__(self, candidate: DataRecord, fields: dict[str, Field] | None, json_output: bool=True, **kwargs) -> GenerationOutput:
+    def __call__(self, candidate: DataRecord, fields: dict[str, FieldInfo] | None, right_candidate: DataRecord | None = None, json_output: bool=True, **kwargs) -> GenerationOutput:
         """Take the input record (`candidate`), generate the output `fields`, and return the generated output."""
-        client = self._get_client_or_model()
         logger.debug(f"Generating for candidate {candidate} with fields {fields}")
 
         # fields can only be None if the user provides an answer parser
@@ -352,23 +296,45 @@ class BaseGenerator(Generic[ContextType, InputType], ABC):
             warnings.warn("Provided `system_prompt` without providing `prompt`; setting `prompt` = `system_prompt`.")  # noqa: B028
 
         # generate a list of messages which can be used to construct a payload
-        messages = self.prompt_factory.create_messages(candidate, fields, **kwargs)
-
-        # create the chat payload
-        chat_payload = self._generate_payload(messages, **kwargs)
+        messages = self.prompt_factory.create_messages(candidate, fields, right_candidate, **kwargs)
 
         # generate the text completion
         start_time = time.time()
         completion = None
         try:
-            completion = self._generate_completion(client, chat_payload, **kwargs)
+            completion_kwargs = {}
+            if not self.model.is_o_model() and not self.model.is_gpt_5_model():
+                completion_kwargs = {"temperature": kwargs.get("temperature", 0.0), **completion_kwargs}
+            if self.prompt_strategy.is_audio_prompt():
+                completion_kwargs = {"modalities": ["text"], **completion_kwargs}
+            if self.model.is_reasoning_model():
+                if self.model.is_vertex_model():
+                    reasoning_effort = self.reasoning_effort
+                    if self.reasoning_effort is None and self.model == Model.GEMINI_2_5_PRO:
+                        reasoning_effort = "low"
+                    elif self.reasoning_effort is None:
+                        reasoning_effort = "disable"
+                    completion_kwargs = {"reasoning_effort": reasoning_effort, **completion_kwargs}
+                elif self.model.is_anthropic_model() and self.reasoning_effort is not None:
+                    completion_kwargs = {"reasoning_effort": self.reasoning_effort, **completion_kwargs}
+                elif self.model.is_openai_model():
+                    reasoning_effort = "minimal" if self.reasoning_effort is None else self.reasoning_effort
+                    completion_kwargs = {"reasoning_effort": reasoning_effort, **completion_kwargs}
+            if self.model.is_vllm_model():
+                completion_kwargs = {"api_base": self.api_base, **completion_kwargs}
+            completion = litellm.completion(model=self.model_name, messages=messages, **completion_kwargs)
             end_time = time.time()
             logger.debug(f"Generated completion in {end_time - start_time:.2f} seconds")
         # if there's an error generating the completion, we have to return an empty answer
         # and can only account for the time spent performing the failed generation
-        except Exception:
-            # logger.error(f"Error generating completion: {e}")
-            field_answers = {field_name: None for field_name in fields}
+        except Exception as e:
+            print(f"Error generating completion: {e}")
+            logger.error(f"Error generating completion: {e}")
+            field_answers = (
+                {"passed_operator": False}
+                if self.prompt_strategy.is_bool_prompt() or self.prompt_strategy.is_join_prompt()
+                else {field_name: None for field_name in fields}
+            )
             reasoning = None
             generation_stats = GenerationStats(
                 model_name=self.model_name,
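
The reasoning-effort branching above reduces to a small per-provider default table; a standalone sketch of the same resolution logic (provider families stubbed as strings rather than palimpzest's `Model` predicates):

# Sketch of the defaults applied above:
#   vertex    -> requested value, else "low" for Gemini 2.5 Pro, else "disable"
#   anthropic -> forwarded only when explicitly requested
#   openai    -> requested value, else "minimal"
def resolve_reasoning_effort(provider: str, requested: str | None, gemini_2_5_pro: bool = False) -> str | None:
    if provider == "vertex":
        if requested is not None:
            return requested
        return "low" if gemini_2_5_pro else "disable"
    if provider == "anthropic":
        return requested  # None means the parameter is not sent at all
    if provider == "openai":
        return "minimal" if requested is None else requested
    return None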
@@ -381,40 +347,57 @@ class BaseGenerator(Generic[ContextType, InputType], ABC):
 
         # parse usage statistics and create the GenerationStats
         generation_stats = None
         if completion is not None:
-            usage = self._get_usage(completion, **kwargs)
-            # finish_reason = self._get_finish_reason(completion, **kwargs)
-            # answer_log_probs = self._get_answer_log_probs(completion, **kwargs)
+            usage = completion.usage.model_dump()
 
-            # get cost per input/output token for the model and parse number of input and output tokens
-            usd_per_input_token = MODEL_CARDS[self.model_name]["usd_per_input_token"]
+            # get cost per input/output token for the model
+            usd_per_input_token = MODEL_CARDS[self.model_name].get("usd_per_input_token", 0.0)
+            usd_per_audio_input_token = MODEL_CARDS[self.model_name].get("usd_per_audio_input_token", 0.0)
             usd_per_output_token = MODEL_CARDS[self.model_name]["usd_per_output_token"]
-            input_tokens = usage["input_tokens"]
-            output_tokens = usage["output_tokens"]
+
+            # TODO: for some models (e.g. GPT-5) we cannot separate text from image prompt tokens yet;
+            # for now, we only use tokens from prompt_token_details if it's an audio prompt
+            # get output tokens (all text) and input tokens by modality
+            output_tokens = usage["completion_tokens"]
+            if self.prompt_strategy.is_audio_prompt():
+                input_audio_tokens = usage["prompt_tokens_details"].get("audio_tokens", 0)
+                input_text_tokens = usage["prompt_tokens_details"].get("text_tokens", 0)
+                input_image_tokens = 0
+            else:
+                input_audio_tokens = 0
+                input_text_tokens = usage["prompt_tokens"]
+                input_image_tokens = 0
+            input_tokens = input_audio_tokens + input_text_tokens + input_image_tokens
+
+            # compute the input and output token costs
+            total_input_cost = (input_text_tokens + input_image_tokens) * usd_per_input_token + input_audio_tokens * usd_per_audio_input_token
+            total_output_cost = output_tokens * usd_per_output_token
 
             generation_stats = GenerationStats(
                 model_name=self.model_name,
                 llm_call_duration_secs=end_time - start_time,
                 fn_call_duration_secs=0.0,
+                input_audio_tokens=input_audio_tokens,
+                input_text_tokens=input_text_tokens,
+                input_image_tokens=input_image_tokens,
                 total_input_tokens=input_tokens,
                 total_output_tokens=output_tokens,
-                total_input_cost=input_tokens * usd_per_input_token,
-                total_output_cost=output_tokens * usd_per_output_token,
-                cost_per_record=input_tokens * usd_per_input_token + output_tokens * usd_per_output_token,
+                total_input_cost=total_input_cost,
+                total_output_cost=total_output_cost,
+                cost_per_record=total_input_cost + total_output_cost,
                 total_llm_calls=1,
-                # "system_prompt": system_prompt,
-                # "prompt": prompt,
-                # "usage": usage,
-                # "finish_reason": finish_reason,
-                # "answer_log_probs": answer_log_probs,
-                # "answer": answer,
             )
 
         # pretty print prompt + full completion output for debugging
-        completion_text = self._get_completion_text(completion, **kwargs)
+        completion_text = completion.choices[0].message.content
         prompt = ""
         for message in messages:
             if message["role"] == "user":
-                prompt += message["content"] + "\n" if message["type"] == "text" else "<image>\n"
+                if message["type"] == "text":
+                    prompt += message["content"] + "\n"
+                elif message["type"] == "image":
+                    prompt += "<image>\n"
+                elif message["type"] == "input_audio":
+                    prompt += "<audio>\n"
         logger.debug(f"PROMPT:\n{prompt}")
         logger.debug(Fore.GREEN + f"{completion_text}\n" + Style.RESET_ALL)
 
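
The new cost computation prices audio input tokens separately from text/image input tokens; a worked example with invented rates and counts (not values from MODEL_CARDS):

# Hypothetical per-token rates and usage, for illustration only.
usd_per_input_token = 2.5e-06        # text and image input
usd_per_audio_input_token = 1.0e-05  # audio input
usd_per_output_token = 1.0e-05

input_text_tokens, input_image_tokens, input_audio_tokens = 500, 0, 2000
output_tokens = 300

total_input_cost = (
    (input_text_tokens + input_image_tokens) * usd_per_input_token
    + input_audio_tokens * usd_per_audio_input_token
)  # 0.00125 + 0.02 = 0.02125
total_output_cost = output_tokens * usd_per_output_token  # 0.003
print(total_input_cost + total_output_cost)  # 0.02425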
@@ -422,17 +405,20 @@ class BaseGenerator(Generic[ContextType, InputType], ABC):
         reasoning = None
         try:
             reasoning = self._parse_reasoning(completion_text, **kwargs)
-        except Exception:
-            # logger.error(f"Error parsing reasoning and answers: {e}")
-            logger.debug("TODO: undo this")
+        except Exception as e:
+            logger.error(f"Error parsing reasoning and answers: {e}")
             pass
 
         # parse field answers
-        field_answers = None if fields is None else {field_name: None for field_name in fields}
+        field_answers = None
+        if fields is not None and (self.prompt_strategy.is_bool_prompt() or self.prompt_strategy.is_join_prompt()):
+            field_answers = {"passed_operator": False}
+        elif fields is not None and not (self.prompt_strategy.is_bool_prompt() or self.prompt_strategy.is_join_prompt()):
+            field_answers = {field_name: None for field_name in fields}
         try:
             field_answers = self._parse_answer(completion_text, fields, json_output, **kwargs)
         except Exception as e:
-            # logger.error(f"Error parsing answers: {e}")
+            logger.error(f"Error parsing answers: {e}")
             os.makedirs("parse-answer-errors", exist_ok=True)
             ts = time.time()
             with open(f"parse-answer-errors/error-{ts}.txt", "w") as f:
@@ -448,162 +434,3 @@ class BaseGenerator(Generic[ContextType, InputType], ABC):
 
         logger.debug(f"Generated field answers: {field_answers}")
         return field_answers, reasoning, generation_stats, messages
-
-
-class OpenAIGenerator(BaseGenerator[str | list[str], str]):
-    """
-    Class for generating text using the OpenAI chat API.
-    """
-
-    def __init__(
-        self,
-        model: Model,
-        prompt_strategy: PromptStrategy,
-        cardinality: Cardinality = Cardinality.ONE_TO_ONE,
-        verbose: bool = False,
-    ):
-        # assert that model is an OpenAI model
-        assert model.is_openai_model()
-        super().__init__(model, prompt_strategy, cardinality, verbose, "developer")
-
-    def _get_client_or_model(self, **kwargs) -> OpenAI:
-        """Returns a client (or local model) which can be invoked to perform the generation."""
-        return APIClientFactory.get_client(APIClient.OPENAI, get_api_key("OPENAI_API_KEY"))
-
-    def _generate_completion(self, client: OpenAI, payload: dict, **kwargs) -> ChatCompletion:
-        """Generates a completion object using the client (or local model)."""
-        return client.chat.completions.create(**payload)
-
-    def _get_completion_text(self, completion: ChatCompletion, **kwargs) -> str:
-        """Extract the completion text from the completion object."""
-        return completion.choices[0].message.content
-
-    def _get_usage(self, completion: ChatCompletion, **kwargs) -> dict:
-        """Extract the usage statistics from the completion object."""
-        return {
-            "input_tokens": completion.usage.prompt_tokens,
-            "output_tokens": completion.usage.completion_tokens,
-        }
-
-    def _get_finish_reason(self, completion: ChatCompletion, **kwargs) -> str:
-        """Extract the finish reason from the completion object."""
-        return completion.choices[0].finish_reason
-
-    def _get_answer_log_probs(self, completion: ChatCompletion, **kwargs) -> list[float]:
-        """Extract the log probabilities from the completion object."""
-        return completion.choices[0].logprobs
-
-
-class TogetherGenerator(BaseGenerator[str | list[str], str]):
-    """
-    Class for generating text using the Together chat API.
-    """
-
-    def __init__(
-        self,
-        model: Model,
-        prompt_strategy: PromptStrategy,
-        cardinality: Cardinality = Cardinality.ONE_TO_ONE,
-        verbose: bool = False,
-    ):
-        # assert that model is a model offered by Together
-        assert model.is_together_model()
-        super().__init__(model, prompt_strategy, cardinality, verbose, "system")
-
-    def _generate_payload(self, messages: list[dict], **kwargs) -> dict:
-        """
-        Generates the payload which will be fed into the client (or local model).
-
-        Each message will be a dictionary with the following format:
-        {
-            "role": "user" | "system",
-            "type": "text" | "image",
-            "content": str
-        }
-
-        For LLAMA3, the payload needs to be in a {"role": <role>, "content": <content>} format.
-        """
-        # for other models, use our standard payload generation
-        if not self.model.is_llama_model():
-            return super()._generate_payload(messages, **kwargs)
-
-        # get basic parameters
-        model = self.model_name
-        temperature = kwargs.get("temperature", 0.0)
-
-        # construct messages in simple {"role": <role>, "content": <content>} format
-        chat_messages = []
-        for message in messages:
-            chat_messages.append({"role": message["role"], "content": message["content"]})
-
-        # construct and return payload
-        payload = {
-            "model": model,
-            "temperature": temperature,
-            "messages": chat_messages,
-        }
-
-        return payload
-
-    def _get_client_or_model(self, **kwargs) -> Together:
-        """Returns a client (or local model) which can be invoked to perform the generation."""
-        return APIClientFactory.get_client(APIClient.TOGETHER, get_api_key("TOGETHER_API_KEY"))
-
-    def _generate_completion(self, client: Together, payload: dict, **kwargs) -> ChatCompletionResponse:
-        """Generates a completion object using the client (or local model)."""
-        return client.chat.completions.create(**payload)
-
-    def _get_completion_text(self, completion: ChatCompletionResponse, **kwargs) -> str:
-        """Extract the completion text from the completion object."""
-        return completion.choices[0].message.content
-
-    def _get_usage(self, completion: ChatCompletionResponse, **kwargs) -> dict:
-        """Extract the usage statistics from the completion object."""
-        return {
-            "input_tokens": completion.usage.prompt_tokens,
-            "output_tokens": completion.usage.completion_tokens,
-        }
-
-    def _get_finish_reason(self, completion: ChatCompletionResponse, **kwargs) -> str:
-        """Extract the finish reason from the completion object."""
-        return completion.choices[0].finish_reason.value
-
-    def _get_answer_log_probs(self, completion: ChatCompletionResponse, **kwargs) -> list[float]:
-        """Extract the log probabilities from the completion object."""
-        return completion.choices[0].logprobs
-
-
-### CODE SYNTHESIS EXECUTION ###
-def code_execution(api: API, code: str, candidate_dict: dict[str, Any], verbose: bool = False):
-    inputs = {field_name: candidate_dict[field_name] for field_name in api.inputs}
-    response = api.api_execute(code, inputs)
-    pred = response["response"] if response["status"] and response["response"] else None
-    return pred
-
-
-def code_ensemble_execution(
-    api: API, code_ensemble: dict[str, str], candidate_dict: dict[str, Any], verbose: bool = True
-) -> GenerationOutput:
-    start_time = time.time()
-    try:
-        preds = list()
-        for _, code in code_ensemble.items():
-            pred = code_execution(api, code, candidate_dict)
-            preds.append(pred)
-
-        preds = [pred for pred in preds if pred is not None]
-
-        if len(preds) == 1:
-            majority_response = preds[0]
-            exec_stats = GenerationStats(fn_call_duration_secs=time.time() - start_time)
-            return majority_response, None, exec_stats
-
-        if len(preds) > 0:
-            majority_response = Counter(preds).most_common(1)[0][0]
-            exec_stats = GenerationStats(fn_call_duration_secs=time.time() - start_time)
-            return majority_response, None, exec_stats
-
-    except Exception:
-        pass
-
-    return None, None, GenerationStats(fn_call_duration_secs=time.time() - start_time)