palimpzest 0.7.21__py3-none-any.whl → 0.8.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (89)
  1. palimpzest/__init__.py +37 -6
  2. palimpzest/agents/__init__.py +0 -0
  3. palimpzest/agents/compute_agents.py +0 -0
  4. palimpzest/agents/search_agents.py +637 -0
  5. palimpzest/constants.py +343 -209
  6. palimpzest/core/data/context.py +393 -0
  7. palimpzest/core/data/context_manager.py +163 -0
  8. palimpzest/core/data/dataset.py +639 -0
  9. palimpzest/core/data/{datareaders.py → iter_dataset.py} +202 -126
  10. palimpzest/core/elements/groupbysig.py +16 -13
  11. palimpzest/core/elements/records.py +166 -75
  12. palimpzest/core/lib/schemas.py +152 -390
  13. palimpzest/core/{data/dataclasses.py → models.py} +306 -170
  14. palimpzest/policy.py +2 -27
  15. palimpzest/prompts/__init__.py +35 -5
  16. palimpzest/prompts/agent_prompts.py +357 -0
  17. palimpzest/prompts/context_search.py +9 -0
  18. palimpzest/prompts/convert_prompts.py +62 -6
  19. palimpzest/prompts/filter_prompts.py +51 -6
  20. palimpzest/prompts/join_prompts.py +163 -0
  21. palimpzest/prompts/moa_proposer_convert_prompts.py +6 -6
  22. palimpzest/prompts/prompt_factory.py +375 -47
  23. palimpzest/prompts/split_proposer_prompts.py +1 -1
  24. palimpzest/prompts/util_phrases.py +5 -0
  25. palimpzest/prompts/validator.py +239 -0
  26. palimpzest/query/execution/all_sample_execution_strategy.py +134 -76
  27. palimpzest/query/execution/execution_strategy.py +210 -317
  28. palimpzest/query/execution/execution_strategy_type.py +5 -7
  29. palimpzest/query/execution/mab_execution_strategy.py +249 -136
  30. palimpzest/query/execution/parallel_execution_strategy.py +153 -244
  31. palimpzest/query/execution/single_threaded_execution_strategy.py +107 -64
  32. palimpzest/query/generators/generators.py +160 -331
  33. palimpzest/query/operators/__init__.py +15 -5
  34. palimpzest/query/operators/aggregate.py +50 -33
  35. palimpzest/query/operators/compute.py +201 -0
  36. palimpzest/query/operators/convert.py +33 -19
  37. palimpzest/query/operators/critique_and_refine_convert.py +7 -5
  38. palimpzest/query/operators/distinct.py +62 -0
  39. palimpzest/query/operators/filter.py +26 -16
  40. palimpzest/query/operators/join.py +403 -0
  41. palimpzest/query/operators/limit.py +3 -3
  42. palimpzest/query/operators/logical.py +205 -77
  43. palimpzest/query/operators/mixture_of_agents_convert.py +10 -8
  44. palimpzest/query/operators/physical.py +27 -21
  45. palimpzest/query/operators/project.py +3 -3
  46. palimpzest/query/operators/rag_convert.py +7 -7
  47. palimpzest/query/operators/retrieve.py +9 -9
  48. palimpzest/query/operators/scan.py +81 -42
  49. palimpzest/query/operators/search.py +524 -0
  50. palimpzest/query/operators/split_convert.py +10 -8
  51. palimpzest/query/optimizer/__init__.py +7 -9
  52. palimpzest/query/optimizer/cost_model.py +108 -441
  53. palimpzest/query/optimizer/optimizer.py +123 -181
  54. palimpzest/query/optimizer/optimizer_strategy.py +66 -61
  55. palimpzest/query/optimizer/plan.py +352 -67
  56. palimpzest/query/optimizer/primitives.py +43 -19
  57. palimpzest/query/optimizer/rules.py +484 -646
  58. palimpzest/query/optimizer/tasks.py +127 -58
  59. palimpzest/query/processor/config.py +42 -76
  60. palimpzest/query/processor/query_processor.py +73 -18
  61. palimpzest/query/processor/query_processor_factory.py +46 -38
  62. palimpzest/schemabuilder/schema_builder.py +15 -28
  63. palimpzest/utils/model_helpers.py +32 -77
  64. palimpzest/utils/progress.py +114 -102
  65. palimpzest/validator/__init__.py +0 -0
  66. palimpzest/validator/validator.py +306 -0
  67. {palimpzest-0.7.21.dist-info → palimpzest-0.8.1.dist-info}/METADATA +6 -1
  68. palimpzest-0.8.1.dist-info/RECORD +95 -0
  69. palimpzest/core/lib/fields.py +0 -141
  70. palimpzest/prompts/code_synthesis_prompts.py +0 -28
  71. palimpzest/query/execution/random_sampling_execution_strategy.py +0 -240
  72. palimpzest/query/generators/api_client_factory.py +0 -30
  73. palimpzest/query/operators/code_synthesis_convert.py +0 -488
  74. palimpzest/query/operators/map.py +0 -130
  75. palimpzest/query/processor/nosentinel_processor.py +0 -33
  76. palimpzest/query/processor/processing_strategy_type.py +0 -28
  77. palimpzest/query/processor/sentinel_processor.py +0 -88
  78. palimpzest/query/processor/streaming_processor.py +0 -149
  79. palimpzest/sets.py +0 -405
  80. palimpzest/utils/datareader_helpers.py +0 -61
  81. palimpzest/utils/demo_helpers.py +0 -75
  82. palimpzest/utils/field_helpers.py +0 -69
  83. palimpzest/utils/generation_helpers.py +0 -69
  84. palimpzest/utils/sandbox.py +0 -183
  85. palimpzest-0.7.21.dist-info/RECORD +0 -95
  86. /palimpzest/core/{elements/index.py → data/index_dataset.py} +0 -0
  87. {palimpzest-0.7.21.dist-info → palimpzest-0.8.1.dist-info}/WHEEL +0 -0
  88. {palimpzest-0.7.21.dist-info → palimpzest-0.8.1.dist-info}/licenses/LICENSE +0 -0
  89. {palimpzest-0.7.21.dist-info → palimpzest-0.8.1.dist-info}/top_level.txt +0 -0
@@ -4,36 +4,28 @@ This file contains the Generator classes and generator factory.
 
 from __future__ import annotations
 
+import json
 import logging
 import os
-import re
 import time
 import warnings
-from abc import ABC, abstractmethod
-from collections import Counter
 from copy import deepcopy
 from typing import Any, Generic, TypeVar
 
+import litellm
+import regex as re  # Use regex instead of re to used variable length lookbehind
 from colorama import Fore, Style
-from openai import OpenAI
-from openai.types.chat.chat_completion import ChatCompletion
-from together import Together
-from together.types.chat_completions import ChatCompletionResponse
+from pydantic.fields import FieldInfo
 
 from palimpzest.constants import (
     MODEL_CARDS,
-    APIClient,
     Cardinality,
     Model,
     PromptStrategy,
 )
-from palimpzest.core.data.dataclasses import GenerationStats
 from palimpzest.core.elements.records import DataRecord
-from palimpzest.core.lib.fields import Field, ListField
+from palimpzest.core.models import GenerationStats
 from palimpzest.prompts import PromptFactory
-from palimpzest.query.generators.api_client_factory import APIClientFactory
-from palimpzest.utils.generation_helpers import get_json_from_answer
-from palimpzest.utils.sandbox import API
 
 # DEFINITIONS
 GenerationOutput = tuple[dict, str | None, GenerationStats, list[dict]]
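The import changes above are the core of this release: the provider-specific OpenAI and Together clients are dropped and every generation is routed through litellm. A minimal sketch of the unified call path, assuming an API key for the target provider is configured in the environment (the model name below is illustrative, not taken from this diff):

    import litellm

    # One entry point regardless of provider; litellm routes on the model string.
    messages = [{"role": "user", "content": "Reply with a JSON object with a single key 'answer'."}]
    response = litellm.completion(model="gpt-4o-mini", messages=messages, temperature=0.0)
    print(response.choices[0].message.content)  # completion text, as read in __call__ below
    print(response.usage.model_dump())          # usage dict, as consumed for GenerationStats below
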
@@ -43,31 +35,71 @@ InputType = TypeVar("InputType")
 
 logger = logging.getLogger(__name__)
 
-def generator_factory(
-    model: Model, prompt_strategy: PromptStrategy, cardinality: Cardinality, verbose: bool = False
-) -> BaseGenerator:
+def get_json_from_answer(answer: str, model: Model, cardinality: Cardinality) -> dict[str, Any]:
     """
-    Factory function to return the correct generator based on the model, strategy, and cardinality.
+    This function parses an LLM response which is supposed to output a JSON object
+    and optimistically searches for the substring containing the JSON object.
     """
-    if model.is_openai_model():
-        return OpenAIGenerator(model, prompt_strategy, cardinality, verbose)
-
-    elif model.is_together_model():
-        return TogetherGenerator(model, prompt_strategy, cardinality, verbose)
-
-    raise Exception(f"Unsupported model: {model}")
-
-
-def get_api_key(key: str) -> str:
-    # get API key from environment or throw an exception if it's not set
-    if key not in os.environ:
-        raise ValueError("key not found in environment variables")
-
-    return os.environ[key]
-
-
+    # model-specific trimming for LLAMA3 responses
+    if model.is_llama_model():
+        answer = answer.split("---")[0]
+        answer = answer.replace("True", "true")
+        answer = answer.replace("False", "false")
+
+    # split off context / excess, which models sometimes output after answer
+    answer = answer.split("Context:")[0]
+    answer = answer.split("# this is the answer")[0]
+
+    # trim the answer to only include the JSON dictionary
+    if cardinality == Cardinality.ONE_TO_ONE:
+        if not answer.strip().startswith("{"):
+            # Find the start index of the actual JSON string assuming the prefix is followed by the JSON dictionary
+            start_index = answer.find("{")
+            if start_index != -1:
+                # Remove the prefix and any leading characters before the JSON starts
+                answer = answer[start_index:]
+
+        if not answer.strip().endswith("}"):
+            # Find the end index of the actual JSON string assuming the suffix is preceded by the JSON dictionary
+            end_index = answer.rfind("}")
+            if end_index != -1:
+                # Remove the suffix and any trailing characters after the JSON ends
+                answer = answer[: end_index + 1]
+
+    # otherwise, trim the answer to only include the JSON array
+    else:
+        if not answer.strip().startswith("["):
+            # Find the start index of the actual JSON string assuming the prefix is followed by the JSON array
+            start_index = answer.find("[")
+            if start_index != -1:
+                # Remove the prefix and any leading characters before the JSON starts
+                answer = answer[start_index:]
+
+        if not answer.strip().endswith("]"):
+            # Find the end index of the actual JSON string
+            # assuming the suffix is preceded by the JSON object/array
+            end_index = answer.rfind("]")
+            if end_index != -1:
+                # Remove the suffix and any trailing characters after the JSON ends
+                answer = answer[: end_index + 1]
+
+    # Handle weird escaped values. I am not sure why the model
+    # is returning these, but the JSON parser can't take them
+    answer = answer.replace(r"\_", "_")
+    answer = answer.replace("\\n", "\n")
+    # Remove https and http prefixes to not conflict with comment detection
+    # Handle comments in the JSON response. Use regex from // until end of line
+    answer = re.sub(r"(?<!https?:)\/\/.*?$", "", answer, flags=re.MULTILINE)
+    answer = re.sub(r",\n.*\.\.\.$", "", answer, flags=re.MULTILINE)
+    # Sanitize newlines in the JSON response
+    answer = answer.replace("\n", " ")
+
+    # finally, parse and return the JSON object; errors are handled by the caller
+    return json.loads(answer)
+
+# TODO: push parallelism of generations into LiteLLM rather than threadpool in executor
 # TODO: make sure answer parsing works with custom prompts / parsers (can defer this)
-class BaseGenerator(Generic[ContextType, InputType], ABC):
+class Generator(Generic[ContextType, InputType]):
     """
     Abstract base class for Generators.
     """
@@ -76,94 +108,21 @@ class BaseGenerator(Generic[ContextType, InputType], ABC):
         self,
         model: Model,
         prompt_strategy: PromptStrategy,
+        reasoning_effort: str | None = None,
+        api_base: str | None = None,
         cardinality: Cardinality = Cardinality.ONE_TO_ONE,
+        desc: str | None = None,
         verbose: bool = False,
-        system_role: str = "system",
     ):
         self.model = model
         self.model_name = model.value
         self.cardinality = cardinality
         self.prompt_strategy = prompt_strategy
+        self.reasoning_effort = reasoning_effort
+        self.api_base = api_base
+        self.desc = desc
         self.verbose = verbose
-        self.system_role = system_role
-        self.prompt_factory = PromptFactory(prompt_strategy, model, cardinality)
-
-    @abstractmethod
-    def _get_client_or_model(self, **kwargs) -> Any:
-        """Returns a client (or local model) which can be invoked to perform the generation."""
-        pass
-
-    @abstractmethod
-    def _generate_completion(self, client_or_model: Any, payload: dict, **kwargs) -> Any:
-        """Generates a completion object using the client (or local model)."""
-        pass
-
-    @abstractmethod
-    def _get_completion_text(self, completion: Any, **kwargs) -> Any:
-        """Extract the completion text from the completion object."""
-        pass
-
-    @abstractmethod
-    def _get_usage(self, completion: Any, **kwargs) -> Any:
-        """Extract the usage statistics from the completion object."""
-        pass
-
-    @abstractmethod
-    def _get_finish_reason(self, completion: Any, **kwargs) -> Any:
-        """Extract the finish reason from the completion object."""
-        pass
-
-    @abstractmethod
-    def _get_answer_log_probs(self, completion: Any, **kwargs) -> Any:
-        """Extract the log probabilities from the completion object."""
-        pass
-
-    def _generate_payload(self, messages: list[dict], **kwargs) -> dict:
-        """
-        Generates the payload which will be fed into the client (or local model).
-
-        Each message will be a dictionary with the following format:
-        {
-            "role": "user" | "system",
-            "type": "text" | "image",
-            "content": str
-        }
-        """
-        # get basic parameters
-        model = self.model_name
-        temperature = kwargs.get("temperature", 0.0)
-
-        # construct messages and add system prompt if present
-        chat_messages, user_content = [], []
-        for message in messages:
-            # flush user content into a message and add system message
-            if message["role"] == "system":
-                if len(user_content) > 0:
-                    chat_messages.append({"role": "user", "content": user_content})
-                    user_content = []
-
-                chat_messages.append({"role": self.system_role, "content": message["content"]})
-
-            # add user content for text messages
-            elif message["role"] == "user" and message["type"] == "text":
-                user_content.append({"type": "text", "text": message["content"]})
-
-            # add user content for image messages
-            elif message["role"] == "user" and message["type"] == "image":
-                user_content.append({"type": "image_url", "image_url": {"url": message["content"]}})
-
-        # flush any remaining user content into a final message
-        if len(user_content) > 0:
-            chat_messages.append({"role": "user", "content": user_content})
-
-        # construct and return payload
-        payload = {
-            "model": model,
-            "temperature": temperature,
-            "messages": chat_messages,
-        }
-
-        return payload
+        self.prompt_factory = PromptFactory(prompt_strategy, model, cardinality, desc)
 
     def _parse_reasoning(self, completion_text: str, **kwargs) -> str:
         """Extract the reasoning for the generated output from the completion object."""
@@ -183,7 +142,7 @@ class BaseGenerator(Generic[ContextType, InputType], ABC):
         # otherwise, return the full completion text
         return completion_text
 
-    def _prepare_field_answers(self, field_answers: dict | list[dict], fields: dict[str, Field]) -> dict[str, list]:
+    def _prepare_field_answers(self, field_answers: dict | list[dict], fields: dict[str, FieldInfo]) -> dict[str, list]:
         """
         field_answers is a dictionary mapping fields to their values. For one-to-one converts, wrap each
         answer in a list. For one-to-many converts, invert the list of dictionaries into a dictionary with
@@ -205,7 +164,7 @@ class BaseGenerator(Generic[ContextType, InputType], ABC):
 
         return field_answers
 
-    def _check_convert_answer_text(self, answer_text: str, fields: dict[str, Field], throw_exception: bool=False) -> dict | list[dict] | None:
+    def _check_convert_answer_text(self, answer_text: str, fields: dict[str, FieldInfo], throw_exception: bool=False) -> dict | list[dict] | None:
         """
         Try parsing the answer text into a JSON object. If the parsing fails, return None.
         """
@@ -213,18 +172,6 @@ class BaseGenerator(Generic[ContextType, InputType], ABC):
             # extract json from the answer text
             field_answers = get_json_from_answer(answer_text, self.model, self.cardinality)
 
-            # TODO: wrap non-list outputs in a list if expected output is a list
-
-            # common error for one-to-one: if the output is a singleton list which contains a list, but the expected field type
-            # is a list of strings, or a list of floats, i.e. not a list of lists; then extract the inner list
-            if self.cardinality == Cardinality.ONE_TO_ONE:
-                for field, field_type in fields.items():
-                    answer = field_answers[field]
-                    field_type_is_not_list_of_lists = isinstance(field_type, ListField) and not issubclass(field_type.element_type, ListField)
-                    answer_is_list_of_lists = isinstance(answer, list) and len(answer) == 1 and isinstance(answer[0], list)
-                    if field_type_is_not_list_of_lists and answer_is_list_of_lists:
-                        field_answers[field] = answer[0]
-
             # prepare the field answers to match the expected output and return
             return self._prepare_field_answers(field_answers, fields)
 
@@ -234,7 +181,7 @@ class BaseGenerator(Generic[ContextType, InputType], ABC):
 
         return None
 
-    def _check_filter_answer_text(self, answer_text: str) -> dict | None:
+    def _check_bool_answer_text(self, answer_text: str) -> dict | None:
         """
         Return {"passed_operator": True} if and only if "true" is in the answer text.
         Return {"passed_operator": False} if and only if "false" is in the answer text.
@@ -249,7 +196,7 @@ class BaseGenerator(Generic[ContextType, InputType], ABC):
 
         return None
 
-    def _parse_convert_answer(self, completion_text: str, fields: dict[str, Field], json_output: bool) -> dict[str, list]:
+    def _parse_convert_answer(self, completion_text: str, fields: dict[str, FieldInfo], json_output: bool) -> dict[str, list]:
         """Extract the answer from the completion object for convert operations."""
         # if the model followed the default instructions, the completion text will place
         # its answer between "ANSWER:" and "---"
@@ -288,15 +235,15 @@ class BaseGenerator(Generic[ContextType, InputType], ABC):
 
         return self._check_convert_answer_text(completion_text, fields, throw_exception=True)
 
-    def _parse_filter_answer(self, completion_text: str) -> dict[str, list]:
-        """Extract the answer from the completion object for filter operations."""
+    def _parse_bool_answer(self, completion_text: str) -> dict[str, list]:
+        """Extract the answer from the completion object for filter and join operations."""
         # if the model followed the default instructions, the completion text will place
         # its answer between "ANSWER:" and "---"
         regex = re.compile("answer:(.*?)---", re.IGNORECASE | re.DOTALL)
         matches = regex.findall(completion_text)
         if len(matches) > 0:
             answer_text = matches[0].strip()
-            field_answers = self._check_filter_answer_text(answer_text)
+            field_answers = self._check_bool_answer_text(answer_text)
             if field_answers is not None:
                 return field_answers
 
@@ -305,18 +252,18 @@ class BaseGenerator(Generic[ContextType, InputType], ABC):
         matches = regex.findall(completion_text)
         if len(matches) > 0:
             answer_text = matches[0].strip()
-            field_answers = self._check_filter_answer_text(answer_text)
+            field_answers = self._check_bool_answer_text(answer_text)
             if field_answers is not None:
                 return field_answers
 
         # finally, try taking all of the text; throw an exception if this doesn't work
-        field_answers = self._check_filter_answer_text(completion_text)
+        field_answers = self._check_bool_answer_text(completion_text)
         if field_answers is None:
             raise Exception(f"Could not parse answer from completion text: {completion_text}")
 
         return field_answers
 
-    def _parse_answer(self, completion_text: str, fields: dict[str, Field] | None, json_output: bool, **kwargs) -> dict[str, list]:
+    def _parse_answer(self, completion_text: str, fields: dict[str, FieldInfo] | None, json_output: bool, **kwargs) -> dict[str, list]:
         """Extract the answer from the completion object."""
         # use a custom answer parser if provided
         if kwargs.get("parse_answer"):
@@ -328,16 +275,15 @@ class BaseGenerator(Generic[ContextType, InputType], ABC):
 
         # extract the per-field answers from the completion text
         field_answers = (
-            self._parse_filter_answer(completion_text)
-            if self.prompt_strategy.is_bool_prompt()
+            self._parse_bool_answer(completion_text)
+            if self.prompt_strategy.is_bool_prompt() or self.prompt_strategy.is_join_prompt()
             else self._parse_convert_answer(completion_text, fields, json_output)
         )
 
         return field_answers
 
-    def __call__(self, candidate: DataRecord, fields: dict[str, Field] | None, json_output: bool=True, **kwargs) -> GenerationOutput:
+    def __call__(self, candidate: DataRecord, fields: dict[str, FieldInfo] | None, right_candidate: DataRecord | None = None, json_output: bool=True, **kwargs) -> GenerationOutput:
         """Take the input record (`candidate`), generate the output `fields`, and return the generated output."""
-        client = self._get_client_or_model()
         logger.debug(f"Generating for candidate {candidate} with fields {fields}")
 
         # fields can only be None if the user provides an answer parser
@@ -352,23 +298,45 @@ class BaseGenerator(Generic[ContextType, InputType], ABC):
             warnings.warn("Provided `system_prompt` without providing `prompt`; setting `prompt` = `system_prompt`.")  # noqa: B028
 
         # generate a list of messages which can be used to construct a payload
-        messages = self.prompt_factory.create_messages(candidate, fields, **kwargs)
-
-        # create the chat payload
-        chat_payload = self._generate_payload(messages, **kwargs)
+        messages = self.prompt_factory.create_messages(candidate, fields, right_candidate, **kwargs)
 
         # generate the text completion
         start_time = time.time()
         completion = None
         try:
-            completion = self._generate_completion(client, chat_payload, **kwargs)
+            completion_kwargs = {}
+            if not self.model.is_o_model() and not self.model.is_gpt_5_model():
+                completion_kwargs = {"temperature": kwargs.get("temperature", 0.0), **completion_kwargs}
+            if self.prompt_strategy.is_audio_prompt():
+                completion_kwargs = {"modalities": ["text"], **completion_kwargs}
+            if self.model.is_reasoning_model():
+                if self.model.is_vertex_model():
+                    reasoning_effort = self.reasoning_effort
+                    if self.reasoning_effort is None and self.model == Model.GEMINI_2_5_PRO:
+                        reasoning_effort = "low"
+                    elif self.reasoning_effort is None:
+                        reasoning_effort = "disable"
+                    completion_kwargs = {"reasoning_effort": reasoning_effort, **completion_kwargs}
+                elif self.model.is_anthropic_model() and self.reasoning_effort is not None:
+                    completion_kwargs = {"reasoning_effort": self.reasoning_effort, **completion_kwargs}
+                elif self.model.is_openai_model():
+                    reasoning_effort = "minimal" if self.reasoning_effort is None else self.reasoning_effort
+                    completion_kwargs = {"reasoning_effort": reasoning_effort, **completion_kwargs}
+            if self.model.is_vllm_model():
+                completion_kwargs = {"api_base": self.api_base, **completion_kwargs}
+            completion = litellm.completion(model=self.model_name, messages=messages, **completion_kwargs)
             end_time = time.time()
             logger.debug(f"Generated completion in {end_time - start_time:.2f} seconds")
         # if there's an error generating the completion, we have to return an empty answer
         # and can only account for the time spent performing the failed generation
-        except Exception:
-            # logger.error(f"Error generating completion: {e}")
-            field_answers = {field_name: None for field_name in fields}
+        except Exception as e:
+            print(f"Error generating completion: {e}")
+            logger.error(f"Error generating completion: {e}")
+            field_answers = (
+                {"passed_operator": False}
+                if self.prompt_strategy.is_bool_prompt() or self.prompt_strategy.is_join_prompt()
+                else {field_name: None for field_name in fields}
+            )
             reasoning = None
             generation_stats = GenerationStats(
                 model_name=self.model_name,
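For a concrete sense of what the branching above assembles: with an OpenAI reasoning model and no configured reasoning_effort, the request reduces to roughly the following hedged sketch (the model identifier is illustrative and not taken from this diff; temperature is omitted on purpose for o-series/GPT-5-style models):

    import litellm

    completion = litellm.completion(
        model="o4-mini",  # illustrative reasoning-model identifier
        messages=[{"role": "user", "content": "Answer TRUE or FALSE: 7 is a prime number."}],
        reasoning_effort="minimal",  # the default chosen above when none is configured
    )
    print(completion.choices[0].message.content)
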
@@ -381,40 +349,57 @@ class BaseGenerator(Generic[ContextType, InputType], ABC):
         # parse usage statistics and create the GenerationStats
         generation_stats = None
         if completion is not None:
-            usage = self._get_usage(completion, **kwargs)
-            # finish_reason = self._get_finish_reason(completion, **kwargs)
-            # answer_log_probs = self._get_answer_log_probs(completion, **kwargs)
+            usage = completion.usage.model_dump()
 
-            # get cost per input/output token for the model and parse number of input and output tokens
-            usd_per_input_token = MODEL_CARDS[self.model_name]["usd_per_input_token"]
+            # get cost per input/output token for the model
+            usd_per_input_token = MODEL_CARDS[self.model_name].get("usd_per_input_token", 0.0)
+            usd_per_audio_input_token = MODEL_CARDS[self.model_name].get("usd_per_audio_input_token", 0.0)
             usd_per_output_token = MODEL_CARDS[self.model_name]["usd_per_output_token"]
-            input_tokens = usage["input_tokens"]
-            output_tokens = usage["output_tokens"]
+
+            # TODO: for some models (e.g. GPT-5) we cannot separate text from image prompt tokens yet;
+            # for now, we only use tokens from prompt_token_details if it's an audio prompt
+            # get output tokens (all text) and input tokens by modality
+            output_tokens = usage["completion_tokens"]
+            if self.prompt_strategy.is_audio_prompt():
+                input_audio_tokens = usage["prompt_tokens_details"].get("audio_tokens", 0)
+                input_text_tokens = usage["prompt_tokens_details"].get("text_tokens", 0)
+                input_image_tokens = 0
+            else:
+                input_audio_tokens = 0
+                input_text_tokens = usage["prompt_tokens"]
+                input_image_tokens = 0
+            input_tokens = input_audio_tokens + input_text_tokens + input_image_tokens
+
+            # compute the input and output token costs
+            total_input_cost = (input_text_tokens + input_image_tokens) * usd_per_input_token + input_audio_tokens * usd_per_audio_input_token
+            total_output_cost = output_tokens * usd_per_output_token
 
             generation_stats = GenerationStats(
                 model_name=self.model_name,
                 llm_call_duration_secs=end_time - start_time,
                 fn_call_duration_secs=0.0,
+                input_audio_tokens=input_audio_tokens,
+                input_text_tokens=input_text_tokens,
+                input_image_tokens=input_image_tokens,
                 total_input_tokens=input_tokens,
                 total_output_tokens=output_tokens,
-                total_input_cost=input_tokens * usd_per_input_token,
-                total_output_cost=output_tokens * usd_per_output_token,
-                cost_per_record=input_tokens * usd_per_input_token + output_tokens * usd_per_output_token,
+                total_input_cost=total_input_cost,
+                total_output_cost=total_output_cost,
+                cost_per_record=total_input_cost + total_output_cost,
                 total_llm_calls=1,
-                # "system_prompt": system_prompt,
-                # "prompt": prompt,
-                # "usage": usage,
-                # "finish_reason": finish_reason,
-                # "answer_log_probs": answer_log_probs,
-                # "answer": answer,
             )
 
         # pretty print prompt + full completion output for debugging
-        completion_text = self._get_completion_text(completion, **kwargs)
+        completion_text = completion.choices[0].message.content
         prompt = ""
         for message in messages:
             if message["role"] == "user":
-                prompt += message["content"] + "\n" if message["type"] == "text" else "<image>\n"
+                if message["type"] == "text":
+                    prompt += message["content"] + "\n"
+                elif message["type"] == "image":
+                    prompt += "<image>\n"
+                elif message["type"] == "input_audio":
+                    prompt += "<audio>\n"
         logger.debug(f"PROMPT:\n{prompt}")
         logger.debug(Fore.GREEN + f"{completion_text}\n" + Style.RESET_ALL)
 
  logger.debug(Fore.GREEN + f"{completion_text}\n" + Style.RESET_ALL)
420
405
 
@@ -422,17 +407,20 @@ class BaseGenerator(Generic[ContextType, InputType], ABC):
         reasoning = None
         try:
             reasoning = self._parse_reasoning(completion_text, **kwargs)
-        except Exception:
-            # logger.error(f"Error parsing reasoning and answers: {e}")
-            logger.debug("TODO: undo this")
+        except Exception as e:
+            logger.error(f"Error parsing reasoning and answers: {e}")
             pass
 
         # parse field answers
-        field_answers = None if fields is None else {field_name: None for field_name in fields}
+        field_answers = None
+        if fields is not None and (self.prompt_strategy.is_bool_prompt() or self.prompt_strategy.is_join_prompt()):
+            field_answers = {"passed_operator": False}
+        elif fields is not None and not (self.prompt_strategy.is_bool_prompt() or self.prompt_strategy.is_join_prompt()):
+            field_answers = {field_name: None for field_name in fields}
         try:
             field_answers = self._parse_answer(completion_text, fields, json_output, **kwargs)
         except Exception as e:
-            # logger.error(f"Error parsing answers: {e}")
+            logger.error(f"Error parsing answers: {e}")
             os.makedirs("parse-answer-errors", exist_ok=True)
             ts = time.time()
             with open(f"parse-answer-errors/error-{ts}.txt", "w") as f:
@@ -448,162 +436,3 @@ class BaseGenerator(Generic[ContextType, InputType], ABC):
 
         logger.debug(f"Generated field answers: {field_answers}")
         return field_answers, reasoning, generation_stats, messages
-
-
-class OpenAIGenerator(BaseGenerator[str | list[str], str]):
-    """
-    Class for generating text using the OpenAI chat API.
-    """
-
-    def __init__(
-        self,
-        model: Model,
-        prompt_strategy: PromptStrategy,
-        cardinality: Cardinality = Cardinality.ONE_TO_ONE,
-        verbose: bool = False,
-    ):
-        # assert that model is an OpenAI model
-        assert model.is_openai_model()
-        super().__init__(model, prompt_strategy, cardinality, verbose, "developer")
-
-    def _get_client_or_model(self, **kwargs) -> OpenAI:
-        """Returns a client (or local model) which can be invoked to perform the generation."""
-        return APIClientFactory.get_client(APIClient.OPENAI, get_api_key("OPENAI_API_KEY"))
-
-    def _generate_completion(self, client: OpenAI, payload: dict, **kwargs) -> ChatCompletion:
-        """Generates a completion object using the client (or local model)."""
-        return client.chat.completions.create(**payload)
-
-    def _get_completion_text(self, completion: ChatCompletion, **kwargs) -> str:
-        """Extract the completion text from the completion object."""
-        return completion.choices[0].message.content
-
-    def _get_usage(self, completion: ChatCompletion, **kwargs) -> dict:
-        """Extract the usage statistics from the completion object."""
-        return {
-            "input_tokens": completion.usage.prompt_tokens,
-            "output_tokens": completion.usage.completion_tokens,
-        }
-
-    def _get_finish_reason(self, completion: ChatCompletion, **kwargs) -> str:
-        """Extract the finish reason from the completion object."""
-        return completion.choices[0].finish_reason
-
-    def _get_answer_log_probs(self, completion: ChatCompletion, **kwargs) -> list[float]:
-        """Extract the log probabilities from the completion object."""
-        return completion.choices[0].logprobs
-
-
-class TogetherGenerator(BaseGenerator[str | list[str], str]):
-    """
-    Class for generating text using the Together chat API.
-    """
-
-    def __init__(
-        self,
-        model: Model,
-        prompt_strategy: PromptStrategy,
-        cardinality: Cardinality = Cardinality.ONE_TO_ONE,
-        verbose: bool = False,
-    ):
-        # assert that model is a model offered by Together
-        assert model.is_together_model()
-        super().__init__(model, prompt_strategy, cardinality, verbose, "system")
-
-    def _generate_payload(self, messages: list[dict], **kwargs) -> dict:
-        """
-        Generates the payload which will be fed into the client (or local model).
-
-        Each message will be a dictionary with the following format:
-        {
-            "role": "user" | "system",
-            "type": "text" | "image",
-            "content": str
-        }
-
-        For LLAMA3, the payload needs to be in a {"role": <role>, "content": <content>} format.
-        """
-        # for other models, use our standard payload generation
-        if not self.model.is_llama_model():
-            return super()._generate_payload(messages, **kwargs)
-
-        # get basic parameters
-        model = self.model_name
-        temperature = kwargs.get("temperature", 0.0)
-
-        # construct messages in simple {"role": <role>, "content": <content>} format
-        chat_messages = []
-        for message in messages:
-            chat_messages.append({"role": message["role"], "content": message["content"]})
-
-        # construct and return payload
-        payload = {
-            "model": model,
-            "temperature": temperature,
-            "messages": chat_messages,
-        }
-
-        return payload
-
-    def _get_client_or_model(self, **kwargs) -> Together:
-        """Returns a client (or local model) which can be invoked to perform the generation."""
-        return APIClientFactory.get_client(APIClient.TOGETHER, get_api_key("TOGETHER_API_KEY"))
-
-    def _generate_completion(self, client: Together, payload: dict, **kwargs) -> ChatCompletionResponse:
-        """Generates a completion object using the client (or local model)."""
-        return client.chat.completions.create(**payload)
-
-    def _get_completion_text(self, completion: ChatCompletionResponse, **kwargs) -> str:
-        """Extract the completion text from the completion object."""
-        return completion.choices[0].message.content
-
-    def _get_usage(self, completion: ChatCompletionResponse, **kwargs) -> dict:
-        """Extract the usage statistics from the completion object."""
-        return {
-            "input_tokens": completion.usage.prompt_tokens,
-            "output_tokens": completion.usage.completion_tokens,
-        }
-
-    def _get_finish_reason(self, completion: ChatCompletionResponse, **kwargs) -> str:
-        """Extract the finish reason from the completion object."""
-        return completion.choices[0].finish_reason.value
-
-    def _get_answer_log_probs(self, completion: ChatCompletionResponse, **kwargs) -> list[float]:
-        """Extract the log probabilities from the completion object."""
-        return completion.choices[0].logprobs
-
-
-### CODE SYNTHESIS EXECUTION ###
-def code_execution(api: API, code: str, candidate_dict: dict[str, Any], verbose: bool = False):
-    inputs = {field_name: candidate_dict[field_name] for field_name in api.inputs}
-    response = api.api_execute(code, inputs)
-    pred = response["response"] if response["status"] and response["response"] else None
-    return pred
-
-
-def code_ensemble_execution(
-    api: API, code_ensemble: dict[str, str], candidate_dict: dict[str, Any], verbose: bool = True
-) -> GenerationOutput:
-    start_time = time.time()
-    try:
-        preds = list()
-        for _, code in code_ensemble.items():
-            pred = code_execution(api, code, candidate_dict)
-            preds.append(pred)
-
-        preds = [pred for pred in preds if pred is not None]
-
-        if len(preds) == 1:
-            majority_response = preds[0]
-            exec_stats = GenerationStats(fn_call_duration_secs=time.time() - start_time)
-            return majority_response, None, exec_stats
-
-        if len(preds) > 0:
-            majority_response = Counter(preds).most_common(1)[0][0]
-            exec_stats = GenerationStats(fn_call_duration_secs=time.time() - start_time)
-            return majority_response, None, exec_stats
-
-    except Exception:
-        pass
-
-    return None, None, GenerationStats(fn_call_duration_secs=time.time() - start_time)