palimpzest 0.7.21__py3-none-any.whl → 0.8.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- palimpzest/__init__.py +37 -6
- palimpzest/agents/__init__.py +0 -0
- palimpzest/agents/compute_agents.py +0 -0
- palimpzest/agents/search_agents.py +637 -0
- palimpzest/constants.py +259 -197
- palimpzest/core/data/context.py +393 -0
- palimpzest/core/data/context_manager.py +163 -0
- palimpzest/core/data/dataset.py +634 -0
- palimpzest/core/data/{datareaders.py → iter_dataset.py} +202 -126
- palimpzest/core/elements/groupbysig.py +16 -13
- palimpzest/core/elements/records.py +166 -75
- palimpzest/core/lib/schemas.py +152 -390
- palimpzest/core/{data/dataclasses.py → models.py} +306 -170
- palimpzest/policy.py +2 -27
- palimpzest/prompts/__init__.py +35 -5
- palimpzest/prompts/agent_prompts.py +357 -0
- palimpzest/prompts/context_search.py +9 -0
- palimpzest/prompts/convert_prompts.py +61 -5
- palimpzest/prompts/filter_prompts.py +50 -5
- palimpzest/prompts/join_prompts.py +163 -0
- palimpzest/prompts/moa_proposer_convert_prompts.py +5 -5
- palimpzest/prompts/prompt_factory.py +358 -46
- palimpzest/prompts/validator.py +239 -0
- palimpzest/query/execution/all_sample_execution_strategy.py +134 -76
- palimpzest/query/execution/execution_strategy.py +210 -317
- palimpzest/query/execution/execution_strategy_type.py +5 -7
- palimpzest/query/execution/mab_execution_strategy.py +249 -136
- palimpzest/query/execution/parallel_execution_strategy.py +153 -244
- palimpzest/query/execution/single_threaded_execution_strategy.py +107 -64
- palimpzest/query/generators/generators.py +157 -330
- palimpzest/query/operators/__init__.py +15 -5
- palimpzest/query/operators/aggregate.py +50 -33
- palimpzest/query/operators/compute.py +201 -0
- palimpzest/query/operators/convert.py +27 -21
- palimpzest/query/operators/critique_and_refine_convert.py +7 -5
- palimpzest/query/operators/distinct.py +62 -0
- palimpzest/query/operators/filter.py +22 -13
- palimpzest/query/operators/join.py +402 -0
- palimpzest/query/operators/limit.py +3 -3
- palimpzest/query/operators/logical.py +198 -80
- palimpzest/query/operators/mixture_of_agents_convert.py +10 -8
- palimpzest/query/operators/physical.py +27 -21
- palimpzest/query/operators/project.py +3 -3
- palimpzest/query/operators/rag_convert.py +7 -7
- palimpzest/query/operators/retrieve.py +9 -9
- palimpzest/query/operators/scan.py +81 -42
- palimpzest/query/operators/search.py +524 -0
- palimpzest/query/operators/split_convert.py +10 -8
- palimpzest/query/optimizer/__init__.py +7 -9
- palimpzest/query/optimizer/cost_model.py +108 -441
- palimpzest/query/optimizer/optimizer.py +123 -181
- palimpzest/query/optimizer/optimizer_strategy.py +66 -61
- palimpzest/query/optimizer/plan.py +352 -67
- palimpzest/query/optimizer/primitives.py +43 -19
- palimpzest/query/optimizer/rules.py +484 -646
- palimpzest/query/optimizer/tasks.py +127 -58
- palimpzest/query/processor/config.py +41 -76
- palimpzest/query/processor/query_processor.py +73 -18
- palimpzest/query/processor/query_processor_factory.py +46 -38
- palimpzest/schemabuilder/schema_builder.py +15 -28
- palimpzest/utils/model_helpers.py +27 -77
- palimpzest/utils/progress.py +114 -102
- palimpzest/validator/__init__.py +0 -0
- palimpzest/validator/validator.py +306 -0
- {palimpzest-0.7.21.dist-info → palimpzest-0.8.0.dist-info}/METADATA +6 -1
- palimpzest-0.8.0.dist-info/RECORD +95 -0
- palimpzest/core/lib/fields.py +0 -141
- palimpzest/prompts/code_synthesis_prompts.py +0 -28
- palimpzest/query/execution/random_sampling_execution_strategy.py +0 -240
- palimpzest/query/generators/api_client_factory.py +0 -30
- palimpzest/query/operators/code_synthesis_convert.py +0 -488
- palimpzest/query/operators/map.py +0 -130
- palimpzest/query/processor/nosentinel_processor.py +0 -33
- palimpzest/query/processor/processing_strategy_type.py +0 -28
- palimpzest/query/processor/sentinel_processor.py +0 -88
- palimpzest/query/processor/streaming_processor.py +0 -149
- palimpzest/sets.py +0 -405
- palimpzest/utils/datareader_helpers.py +0 -61
- palimpzest/utils/demo_helpers.py +0 -75
- palimpzest/utils/field_helpers.py +0 -69
- palimpzest/utils/generation_helpers.py +0 -69
- palimpzest/utils/sandbox.py +0 -183
- palimpzest-0.7.21.dist-info/RECORD +0 -95
- /palimpzest/core/{elements/index.py → data/index_dataset.py} +0 -0
- {palimpzest-0.7.21.dist-info → palimpzest-0.8.0.dist-info}/WHEEL +0 -0
- {palimpzest-0.7.21.dist-info → palimpzest-0.8.0.dist-info}/licenses/LICENSE +0 -0
- {palimpzest-0.7.21.dist-info → palimpzest-0.8.0.dist-info}/top_level.txt +0 -0
palimpzest/query/generators/generators.py:

```diff
@@ -4,36 +4,28 @@ This file contains the Generator classes and generator factory.
 
 from __future__ import annotations
 
+import json
 import logging
 import os
-import re
 import time
 import warnings
-from abc import ABC, abstractmethod
-from collections import Counter
 from copy import deepcopy
 from typing import Any, Generic, TypeVar
 
+import litellm
+import regex as re # Use regex instead of re to used variable length lookbehind
 from colorama import Fore, Style
-from
-from openai.types.chat.chat_completion import ChatCompletion
-from together import Together
-from together.types.chat_completions import ChatCompletionResponse
+from pydantic.fields import FieldInfo
 
 from palimpzest.constants import (
     MODEL_CARDS,
-    APIClient,
     Cardinality,
     Model,
     PromptStrategy,
 )
-from palimpzest.core.data.dataclasses import GenerationStats
 from palimpzest.core.elements.records import DataRecord
-from palimpzest.core.
+from palimpzest.core.models import GenerationStats
 from palimpzest.prompts import PromptFactory
-from palimpzest.query.generators.api_client_factory import APIClientFactory
-from palimpzest.utils.generation_helpers import get_json_from_answer
-from palimpzest.utils.sandbox import API
 
 # DEFINITIONS
 GenerationOutput = tuple[dict, str | None, GenerationStats, list[dict]]
```
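The import changes above drop the provider-specific OpenAI/Together clients in favor of `litellm`, and replace palimpzest's own field classes with pydantic's `FieldInfo`. As an illustrative sketch (not part of the diff), the `dict[str, FieldInfo]` mapping that the generator methods now expect is exactly the shape pydantic v2 exposes on any model class:

```python
# Illustrative sketch only: a pydantic v2 model's `model_fields` attribute is a
# dict[str, FieldInfo], which matches the `fields` argument type used below.
from pydantic import BaseModel, Field
from pydantic.fields import FieldInfo


class Paper(BaseModel):
    title: str = Field(description="Title of the paper")
    year: int = Field(description="Publication year")


fields: dict[str, FieldInfo] = Paper.model_fields
print(list(fields))            # ['title', 'year']
print(type(fields["title"]))   # <class 'pydantic.fields.FieldInfo'>
```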
```diff
@@ -43,31 +35,71 @@ InputType = TypeVar("InputType")
 
 logger = logging.getLogger(__name__)
 
-def
-    model: Model, prompt_strategy: PromptStrategy, cardinality: Cardinality, verbose: bool = False
-) -> BaseGenerator:
+def get_json_from_answer(answer: str, model: Model, cardinality: Cardinality) -> dict[str, Any]:
     """
-
+    This function parses an LLM response which is supposed to output a JSON object
+    and optimistically searches for the substring containing the JSON object.
     """
-
-
-
-
-
-
-
-
-
-
-    #
-    if
-
-
-
-
-
+    # model-specific trimming for LLAMA3 responses
+    if model.is_llama_model():
+        answer = answer.split("---")[0]
+        answer = answer.replace("True", "true")
+        answer = answer.replace("False", "false")
+
+    # split off context / excess, which models sometimes output after answer
+    answer = answer.split("Context:")[0]
+    answer = answer.split("# this is the answer")[0]
+
+    # trim the answer to only include the JSON dictionary
+    if cardinality == Cardinality.ONE_TO_ONE:
+        if not answer.strip().startswith("{"):
+            # Find the start index of the actual JSON string assuming the prefix is followed by the JSON dictionary
+            start_index = answer.find("{")
+            if start_index != -1:
+                # Remove the prefix and any leading characters before the JSON starts
+                answer = answer[start_index:]
+
+        if not answer.strip().endswith("}"):
+            # Find the end index of the actual JSON string assuming the suffix is preceded by the JSON dictionary
+            end_index = answer.rfind("}")
+            if end_index != -1:
+                # Remove the suffix and any trailing characters after the JSON ends
+                answer = answer[: end_index + 1]
+
+    # otherwise, trim the answer to only include the JSON array
+    else:
+        if not answer.strip().startswith("["):
+            # Find the start index of the actual JSON string assuming the prefix is followed by the JSON array
+            start_index = answer.find("[")
+            if start_index != -1:
+                # Remove the prefix and any leading characters before the JSON starts
+                answer = answer[start_index:]
+
+        if not answer.strip().endswith("]"):
+            # Find the end index of the actual JSON string
+            # assuming the suffix is preceded by the JSON object/array
+            end_index = answer.rfind("]")
+            if end_index != -1:
+                # Remove the suffix and any trailing characters after the JSON ends
+                answer = answer[: end_index + 1]
+
+    # Handle weird escaped values. I am not sure why the model
+    # is returning these, but the JSON parser can't take them
+    answer = answer.replace(r"\_", "_")
+    answer = answer.replace("\\n", "\n")
+    # Remove https and http prefixes to not conflict with comment detection
+    # Handle comments in the JSON response. Use regex from // until end of line
+    answer = re.sub(r"(?<!https?:)\/\/.*?$", "", answer, flags=re.MULTILINE)
+    answer = re.sub(r",\n.*\.\.\.$", "", answer, flags=re.MULTILINE)
+    # Sanitize newlines in the JSON response
+    answer = answer.replace("\n", " ")
+
+    # finally, parse and return the JSON object; errors are handled by the caller
+    return json.loads(answer)
+
+# TODO: push parallelism of generations into LiteLLM rather than threadpool in executor
 # TODO: make sure answer parsing works with custom prompts / parsers (can defer this)
-class BaseGenerator(Generic[ContextType, InputType], ABC):
+class Generator(Generic[ContextType, InputType]):
     """
     Abstract base class for Generators.
     """
```
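The new module-level `get_json_from_answer` replaces the helper previously imported from `palimpzest.utils.generation_helpers`. A quick usage sketch with a made-up LLM response (the enum members come from the imports above; `Model.GEMINI_2_5_PRO` is one of the members referenced elsewhere in this diff, and any non-Llama model skips the Llama-specific trimming):

```python
# Hypothetical LLM output: surrounding prose, a trailing // comment, and
# "Context:" text all get trimmed before json.loads is applied.
from palimpzest.constants import Cardinality, Model
from palimpzest.query.generators.generators import get_json_from_answer

raw_answer = """Here is the extraction:
{
  "title": "A Study of Extraction",  // taken from the first page
  "year": 2021
}
Context: the title appears in the header."""

parsed = get_json_from_answer(raw_answer, Model.GEMINI_2_5_PRO, Cardinality.ONE_TO_ONE)
print(parsed)  # {'title': 'A Study of Extraction', 'year': 2021}
```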
```diff
@@ -76,95 +108,20 @@ class BaseGenerator(Generic[ContextType, InputType], ABC):
         self,
         model: Model,
         prompt_strategy: PromptStrategy,
+        reasoning_effort: str | None = None,
+        api_base: str | None = None,
         cardinality: Cardinality = Cardinality.ONE_TO_ONE,
         verbose: bool = False,
-        system_role: str = "system",
     ):
         self.model = model
         self.model_name = model.value
         self.cardinality = cardinality
         self.prompt_strategy = prompt_strategy
+        self.reasoning_effort = reasoning_effort
+        self.api_base = api_base
         self.verbose = verbose
-        self.system_role = system_role
         self.prompt_factory = PromptFactory(prompt_strategy, model, cardinality)
 
-    @abstractmethod
-    def _get_client_or_model(self, **kwargs) -> Any:
-        """Returns a client (or local model) which can be invoked to perform the generation."""
-        pass
-
-    @abstractmethod
-    def _generate_completion(self, client_or_model: Any, payload: dict, **kwargs) -> Any:
-        """Generates a completion object using the client (or local model)."""
-        pass
-
-    @abstractmethod
-    def _get_completion_text(self, completion: Any, **kwargs) -> Any:
-        """Extract the completion text from the completion object."""
-        pass
-
-    @abstractmethod
-    def _get_usage(self, completion: Any, **kwargs) -> Any:
-        """Extract the usage statistics from the completion object."""
-        pass
-
-    @abstractmethod
-    def _get_finish_reason(self, completion: Any, **kwargs) -> Any:
-        """Extract the finish reason from the completion object."""
-        pass
-
-    @abstractmethod
-    def _get_answer_log_probs(self, completion: Any, **kwargs) -> Any:
-        """Extract the log probabilities from the completion object."""
-        pass
-
-    def _generate_payload(self, messages: list[dict], **kwargs) -> dict:
-        """
-        Generates the payload which will be fed into the client (or local model).
-
-        Each message will be a dictionary with the following format:
-        {
-            "role": "user" | "system",
-            "type": "text" | "image",
-            "content": str
-        }
-        """
-        # get basic parameters
-        model = self.model_name
-        temperature = kwargs.get("temperature", 0.0)
-
-        # construct messages and add system prompt if present
-        chat_messages, user_content = [], []
-        for message in messages:
-            # flush user content into a message and add system message
-            if message["role"] == "system":
-                if len(user_content) > 0:
-                    chat_messages.append({"role": "user", "content": user_content})
-                    user_content = []
-
-                chat_messages.append({"role": self.system_role, "content": message["content"]})
-
-            # add user content for text messages
-            elif message["role"] == "user" and message["type"] == "text":
-                user_content.append({"type": "text", "text": message["content"]})
-
-            # add user content for image messages
-            elif message["role"] == "user" and message["type"] == "image":
-                user_content.append({"type": "image_url", "image_url": {"url": message["content"]}})
-
-        # flush any remaining user content into a final message
-        if len(user_content) > 0:
-            chat_messages.append({"role": "user", "content": user_content})
-
-        # construct and return payload
-        payload = {
-            "model": model,
-            "temperature": temperature,
-            "messages": chat_messages,
-        }
-
-        return payload
-
     def _parse_reasoning(self, completion_text: str, **kwargs) -> str:
         """Extract the reasoning for the generated output from the completion object."""
         # use a custom reasoning parser if provided
```
```diff
@@ -183,7 +140,7 @@ class BaseGenerator(Generic[ContextType, InputType], ABC):
         # otherwise, return the full completion text
         return completion_text
 
-    def _prepare_field_answers(self, field_answers: dict | list[dict], fields: dict[str,
+    def _prepare_field_answers(self, field_answers: dict | list[dict], fields: dict[str, FieldInfo]) -> dict[str, list]:
         """
         field_answers is a dictionary mapping fields to their values. For one-to-one converts, wrap each
         answer in a list. For one-to-many converts, invert the list of dictionaries into a dictionary with
```
```diff
@@ -205,7 +162,7 @@ class BaseGenerator(Generic[ContextType, InputType], ABC):
 
         return field_answers
 
-    def _check_convert_answer_text(self, answer_text: str, fields: dict[str,
+    def _check_convert_answer_text(self, answer_text: str, fields: dict[str, FieldInfo], throw_exception: bool=False) -> dict | list[dict] | None:
         """
         Try parsing the answer text into a JSON object. If the parsing fails, return None.
         """
```
```diff
@@ -213,18 +170,6 @@ class BaseGenerator(Generic[ContextType, InputType], ABC):
         # extract json from the answer text
         field_answers = get_json_from_answer(answer_text, self.model, self.cardinality)
 
-        # TODO: wrap non-list outputs in a list if expected output is a list
-
-        # common error for one-to-one: if the output is a singleton list which contains a list, but the expected field type
-        # is a list of strings, or a list of floats, i.e. not a list of lists; then extract the inner list
-        if self.cardinality == Cardinality.ONE_TO_ONE:
-            for field, field_type in fields.items():
-                answer = field_answers[field]
-                field_type_is_not_list_of_lists = isinstance(field_type, ListField) and not issubclass(field_type.element_type, ListField)
-                answer_is_list_of_lists = isinstance(answer, list) and len(answer) == 1 and isinstance(answer[0], list)
-                if field_type_is_not_list_of_lists and answer_is_list_of_lists:
-                    field_answers[field] = answer[0]
-
         # prepare the field answers to match the expected output and return
         return self._prepare_field_answers(field_answers, fields)
 
```
```diff
@@ -234,7 +179,7 @@ class BaseGenerator(Generic[ContextType, InputType], ABC):
 
         return None
 
-    def
+    def _check_bool_answer_text(self, answer_text: str) -> dict | None:
         """
         Return {"passed_operator": True} if and only if "true" is in the answer text.
         Return {"passed_operator": False} if and only if "false" is in the answer text.
```
```diff
@@ -249,7 +194,7 @@ class BaseGenerator(Generic[ContextType, InputType], ABC):
 
         return None
 
-    def _parse_convert_answer(self, completion_text: str, fields: dict[str,
+    def _parse_convert_answer(self, completion_text: str, fields: dict[str, FieldInfo], json_output: bool) -> dict[str, list]:
         """Extract the answer from the completion object for convert operations."""
         # if the model followed the default instructions, the completion text will place
         # its answer between "ANSWER:" and "---"
```
```diff
@@ -288,15 +233,15 @@ class BaseGenerator(Generic[ContextType, InputType], ABC):
 
         return self._check_convert_answer_text(completion_text, fields, throw_exception=True)
 
-    def
-        """Extract the answer from the completion object for filter operations."""
+    def _parse_bool_answer(self, completion_text: str) -> dict[str, list]:
+        """Extract the answer from the completion object for filter and join operations."""
         # if the model followed the default instructions, the completion text will place
         # its answer between "ANSWER:" and "---"
         regex = re.compile("answer:(.*?)---", re.IGNORECASE | re.DOTALL)
         matches = regex.findall(completion_text)
         if len(matches) > 0:
             answer_text = matches[0].strip()
-            field_answers = self.
+            field_answers = self._check_bool_answer_text(answer_text)
             if field_answers is not None:
                 return field_answers
 
```
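The renamed `_parse_bool_answer` (formerly the filter-only parser) still looks for the text between "ANSWER:" and "---". A standalone sketch of that extraction on a sample completion (the diff uses the `regex` package aliased as `re`; this particular pattern also works with the standard library):

```python
import re

completion_text = """REASONING: the record describes a clinical trial,
so the filter condition holds.
ANSWER: true
---
"""

pattern = re.compile("answer:(.*?)---", re.IGNORECASE | re.DOTALL)
matches = pattern.findall(completion_text)
answer_text = matches[0].strip() if matches else completion_text

# simplified stand-in for _check_bool_answer_text, which keys off the
# "true" / "false" substrings in the extracted answer
field_answers = {"passed_operator": "true" in answer_text.lower()}
print(field_answers)  # {'passed_operator': True}
```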
```diff
@@ -305,18 +250,18 @@ class BaseGenerator(Generic[ContextType, InputType], ABC):
         matches = regex.findall(completion_text)
         if len(matches) > 0:
             answer_text = matches[0].strip()
-            field_answers = self.
+            field_answers = self._check_bool_answer_text(answer_text)
             if field_answers is not None:
                 return field_answers
 
         # finally, try taking all of the text; throw an exception if this doesn't work
-        field_answers = self.
+        field_answers = self._check_bool_answer_text(completion_text)
         if field_answers is None:
             raise Exception(f"Could not parse answer from completion text: {completion_text}")
 
         return field_answers
 
-    def _parse_answer(self, completion_text: str, fields: dict[str,
+    def _parse_answer(self, completion_text: str, fields: dict[str, FieldInfo] | None, json_output: bool, **kwargs) -> dict[str, list]:
         """Extract the answer from the completion object."""
         # use a custom answer parser if provided
         if kwargs.get("parse_answer"):
```
```diff
@@ -328,16 +273,15 @@ class BaseGenerator(Generic[ContextType, InputType], ABC):
 
         # extract the per-field answers from the completion text
         field_answers = (
-            self.
-            if self.prompt_strategy.is_bool_prompt()
+            self._parse_bool_answer(completion_text)
+            if self.prompt_strategy.is_bool_prompt() or self.prompt_strategy.is_join_prompt()
             else self._parse_convert_answer(completion_text, fields, json_output)
         )
 
         return field_answers
 
-    def __call__(self, candidate: DataRecord, fields: dict[str,
+    def __call__(self, candidate: DataRecord, fields: dict[str, FieldInfo] | None, right_candidate: DataRecord | None = None, json_output: bool=True, **kwargs) -> GenerationOutput:
         """Take the input record (`candidate`), generate the output `fields`, and return the generated output."""
-        client = self._get_client_or_model()
         logger.debug(f"Generating for candidate {candidate} with fields {fields}")
 
         # fields can only be None if the user provides an answer parser
```
```diff
@@ -352,23 +296,45 @@ class BaseGenerator(Generic[ContextType, InputType], ABC):
             warnings.warn("Provided `system_prompt` without providing `prompt`; setting `prompt` = `system_prompt`.") # noqa: B028
 
         # generate a list of messages which can be used to construct a payload
-        messages = self.prompt_factory.create_messages(candidate, fields, **kwargs)
-
-        # create the chat payload
-        chat_payload = self._generate_payload(messages, **kwargs)
+        messages = self.prompt_factory.create_messages(candidate, fields, right_candidate, **kwargs)
 
         # generate the text completion
        start_time = time.time()
         completion = None
         try:
-
+            completion_kwargs = {}
+            if not self.model.is_o_model() and not self.model.is_gpt_5_model():
+                completion_kwargs = {"temperature": kwargs.get("temperature", 0.0), **completion_kwargs}
+            if self.prompt_strategy.is_audio_prompt():
+                completion_kwargs = {"modalities": ["text"], **completion_kwargs}
+            if self.model.is_reasoning_model():
+                if self.model.is_vertex_model():
+                    reasoning_effort = self.reasoning_effort
+                    if self.reasoning_effort is None and self.model == Model.GEMINI_2_5_PRO:
+                        reasoning_effort = "low"
+                    elif self.reasoning_effort is None:
+                        reasoning_effort = "disable"
+                    completion_kwargs = {"reasoning_effort": reasoning_effort, **completion_kwargs}
+                elif self.model.is_anthropic_model() and self.reasoning_effort is not None:
+                    completion_kwargs = {"reasoning_effort": self.reasoning_effort, **completion_kwargs}
+                elif self.model.is_openai_model():
+                    reasoning_effort = "minimal" if self.reasoning_effort is None else self.reasoning_effort
+                    completion_kwargs = {"reasoning_effort": reasoning_effort, **completion_kwargs}
+            if self.model.is_vllm_model():
+                completion_kwargs = {"api_base": self.api_base, **completion_kwargs}
+            completion = litellm.completion(model=self.model_name, messages=messages, **completion_kwargs)
             end_time = time.time()
             logger.debug(f"Generated completion in {end_time - start_time:.2f} seconds")
         # if there's an error generating the completion, we have to return an empty answer
         # and can only account for the time spent performing the failed generation
-        except Exception:
-
-
+        except Exception as e:
+            print(f"Error generating completion: {e}")
+            logger.error(f"Error generating completion: {e}")
+            field_answers = (
+                {"passed_operator": False}
+                if self.prompt_strategy.is_bool_prompt() or self.prompt_strategy.is_join_prompt()
+                else {field_name: None for field_name in fields}
+            )
             reasoning = None
             generation_stats = GenerationStats(
                 model_name=self.model_name,
```
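All providers now go through a single `litellm.completion` call, with the provider-specific knobs (temperature, `reasoning_effort`, `api_base`) collected into `completion_kwargs`. A minimal sketch of that call shape, assuming an API key is configured; the model identifier and message are placeholders, and whether `reasoning_effort` applies depends on the provider (the diff only sets it for reasoning-capable models):

```python
import litellm

completion_kwargs = {"temperature": 0.0}            # omitted for o-series / GPT-5 models in the diff
# completion_kwargs["reasoning_effort"] = "low"     # only set for reasoning-capable models
# completion_kwargs["api_base"] = "http://localhost:8000/v1"  # only set for vLLM-served models

completion = litellm.completion(
    model="openai/gpt-4o-mini",                     # placeholder model identifier
    messages=[{"role": "user", "content": "Reply with ANSWER: true ---"}],
    **completion_kwargs,
)
print(completion.choices[0].message.content)
print(completion.usage.model_dump())                # token usage consumed by the cost accounting below
```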
```diff
@@ -381,40 +347,57 @@ class BaseGenerator(Generic[ContextType, InputType], ABC):
         # parse usage statistics and create the GenerationStats
         generation_stats = None
         if completion is not None:
-            usage =
-            # finish_reason = self._get_finish_reason(completion, **kwargs)
-            # answer_log_probs = self._get_answer_log_probs(completion, **kwargs)
+            usage = completion.usage.model_dump()
 
-            # get cost per input/output token for the model
-            usd_per_input_token = MODEL_CARDS[self.model_name]
+            # get cost per input/output token for the model
+            usd_per_input_token = MODEL_CARDS[self.model_name].get("usd_per_input_token", 0.0)
+            usd_per_audio_input_token = MODEL_CARDS[self.model_name].get("usd_per_audio_input_token", 0.0)
             usd_per_output_token = MODEL_CARDS[self.model_name]["usd_per_output_token"]
-
-
+
+            # TODO: for some models (e.g. GPT-5) we cannot separate text from image prompt tokens yet;
+            # for now, we only use tokens from prompt_token_details if it's an audio prompt
+            # get output tokens (all text) and input tokens by modality
+            output_tokens = usage["completion_tokens"]
+            if self.prompt_strategy.is_audio_prompt():
+                input_audio_tokens = usage["prompt_tokens_details"].get("audio_tokens", 0)
+                input_text_tokens = usage["prompt_tokens_details"].get("text_tokens", 0)
+                input_image_tokens = 0
+            else:
+                input_audio_tokens = 0
+                input_text_tokens = usage["prompt_tokens"]
+                input_image_tokens = 0
+            input_tokens = input_audio_tokens + input_text_tokens + input_image_tokens
+
+            # compute the input and output token costs
+            total_input_cost = (input_text_tokens + input_image_tokens) * usd_per_input_token + input_audio_tokens * usd_per_audio_input_token
+            total_output_cost = output_tokens * usd_per_output_token
 
             generation_stats = GenerationStats(
                 model_name=self.model_name,
                 llm_call_duration_secs=end_time - start_time,
                 fn_call_duration_secs=0.0,
+                input_audio_tokens=input_audio_tokens,
+                input_text_tokens=input_text_tokens,
+                input_image_tokens=input_image_tokens,
                 total_input_tokens=input_tokens,
                 total_output_tokens=output_tokens,
-                total_input_cost=
-                total_output_cost=
-                cost_per_record=
+                total_input_cost=total_input_cost,
+                total_output_cost=total_output_cost,
+                cost_per_record=total_input_cost + total_output_cost,
                 total_llm_calls=1,
-                # "system_prompt": system_prompt,
-                # "prompt": prompt,
-                # "usage": usage,
-                # "finish_reason": finish_reason,
-                # "answer_log_probs": answer_log_probs,
-                # "answer": answer,
             )
 
         # pretty print prompt + full completion output for debugging
-        completion_text =
+        completion_text = completion.choices[0].message.content
         prompt = ""
         for message in messages:
             if message["role"] == "user":
-
+                if message["type"] == "text":
+                    prompt += message["content"] + "\n"
+                elif message["type"] == "image":
+                    prompt += "<image>\n"
+                elif message["type"] == "input_audio":
+                    prompt += "<audio>\n"
         logger.debug(f"PROMPT:\n{prompt}")
         logger.debug(Fore.GREEN + f"{completion_text}\n" + Style.RESET_ALL)
 
```
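Cost accounting now derives everything from the usage block and the per-token prices in `MODEL_CARDS`. A worked example of the arithmetic above, with made-up prices and token counts:

```python
# Hypothetical per-token prices and token counts; real values come from
# MODEL_CARDS[self.model_name] and completion.usage in the diff above.
usd_per_input_token = 0.15 / 1_000_000        # $0.15 per 1M input tokens (made up)
usd_per_audio_input_token = 1.00 / 1_000_000  # $1.00 per 1M audio tokens (made up)
usd_per_output_token = 0.60 / 1_000_000       # $0.60 per 1M output tokens (made up)

input_text_tokens, input_image_tokens, input_audio_tokens = 1_200, 0, 0
output_tokens = 350

total_input_cost = (input_text_tokens + input_image_tokens) * usd_per_input_token \
    + input_audio_tokens * usd_per_audio_input_token
total_output_cost = output_tokens * usd_per_output_token
cost_per_record = total_input_cost + total_output_cost

print(f"{total_input_cost:.6f} + {total_output_cost:.6f} = {cost_per_record:.6f} USD")
# 0.000180 + 0.000210 = 0.000390 USD
```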
```diff
@@ -422,17 +405,20 @@ class BaseGenerator(Generic[ContextType, InputType], ABC):
         reasoning = None
         try:
             reasoning = self._parse_reasoning(completion_text, **kwargs)
-        except Exception:
-
-            logger.debug("TODO: undo this")
+        except Exception as e:
+            logger.error(f"Error parsing reasoning and answers: {e}")
             pass
 
         # parse field answers
-        field_answers = None
+        field_answers = None
+        if fields is not None and (self.prompt_strategy.is_bool_prompt() or self.prompt_strategy.is_join_prompt()):
+            field_answers = {"passed_operator": False}
+        elif fields is not None and not (self.prompt_strategy.is_bool_prompt() or self.prompt_strategy.is_join_prompt()):
+            field_answers = {field_name: None for field_name in fields}
         try:
             field_answers = self._parse_answer(completion_text, fields, json_output, **kwargs)
         except Exception as e:
-
+            logger.error(f"Error parsing answers: {e}")
             os.makedirs("parse-answer-errors", exist_ok=True)
             ts = time.time()
             with open(f"parse-answer-errors/error-{ts}.txt", "w") as f:
```
```diff
@@ -448,162 +434,3 @@ class BaseGenerator(Generic[ContextType, InputType], ABC):
 
         logger.debug(f"Generated field answers: {field_answers}")
         return field_answers, reasoning, generation_stats, messages
-
-
-class OpenAIGenerator(BaseGenerator[str | list[str], str]):
-    """
-    Class for generating text using the OpenAI chat API.
-    """
-
-    def __init__(
-        self,
-        model: Model,
-        prompt_strategy: PromptStrategy,
-        cardinality: Cardinality = Cardinality.ONE_TO_ONE,
-        verbose: bool = False,
-    ):
-        # assert that model is an OpenAI model
-        assert model.is_openai_model()
-        super().__init__(model, prompt_strategy, cardinality, verbose, "developer")
-
-    def _get_client_or_model(self, **kwargs) -> OpenAI:
-        """Returns a client (or local model) which can be invoked to perform the generation."""
-        return APIClientFactory.get_client(APIClient.OPENAI, get_api_key("OPENAI_API_KEY"))
-
-    def _generate_completion(self, client: OpenAI, payload: dict, **kwargs) -> ChatCompletion:
-        """Generates a completion object using the client (or local model)."""
-        return client.chat.completions.create(**payload)
-
-    def _get_completion_text(self, completion: ChatCompletion, **kwargs) -> str:
-        """Extract the completion text from the completion object."""
-        return completion.choices[0].message.content
-
-    def _get_usage(self, completion: ChatCompletion, **kwargs) -> dict:
-        """Extract the usage statistics from the completion object."""
-        return {
-            "input_tokens": completion.usage.prompt_tokens,
-            "output_tokens": completion.usage.completion_tokens,
-        }
-
-    def _get_finish_reason(self, completion: ChatCompletion, **kwargs) -> str:
-        """Extract the finish reason from the completion object."""
-        return completion.choices[0].finish_reason
-
-    def _get_answer_log_probs(self, completion: ChatCompletion, **kwargs) -> list[float]:
-        """Extract the log probabilities from the completion object."""
-        return completion.choices[0].logprobs
-
-
-class TogetherGenerator(BaseGenerator[str | list[str], str]):
-    """
-    Class for generating text using the Together chat API.
-    """
-
-    def __init__(
-        self,
-        model: Model,
-        prompt_strategy: PromptStrategy,
-        cardinality: Cardinality = Cardinality.ONE_TO_ONE,
-        verbose: bool = False,
-    ):
-        # assert that model is a model offered by Together
-        assert model.is_together_model()
-        super().__init__(model, prompt_strategy, cardinality, verbose, "system")
-
-    def _generate_payload(self, messages: list[dict], **kwargs) -> dict:
-        """
-        Generates the payload which will be fed into the client (or local model).
-
-        Each message will be a dictionary with the following format:
-        {
-            "role": "user" | "system",
-            "type": "text" | "image",
-            "content": str
-        }
-
-        For LLAMA3, the payload needs to be in a {"role": <role>, "content": <content>} format.
-        """
-        # for other models, use our standard payload generation
-        if not self.model.is_llama_model():
-            return super()._generate_payload(messages, **kwargs)
-
-        # get basic parameters
-        model = self.model_name
-        temperature = kwargs.get("temperature", 0.0)
-
-        # construct messages in simple {"role": <role>, "content": <content>} format
-        chat_messages = []
-        for message in messages:
-            chat_messages.append({"role": message["role"], "content": message["content"]})
-
-        # construct and return payload
-        payload = {
-            "model": model,
-            "temperature": temperature,
-            "messages": chat_messages,
-        }
-
-        return payload
-
-    def _get_client_or_model(self, **kwargs) -> Together:
-        """Returns a client (or local model) which can be invoked to perform the generation."""
-        return APIClientFactory.get_client(APIClient.TOGETHER, get_api_key("TOGETHER_API_KEY"))
-
-    def _generate_completion(self, client: Together, payload: dict, **kwargs) -> ChatCompletionResponse:
-        """Generates a completion object using the client (or local model)."""
-        return client.chat.completions.create(**payload)
-
-    def _get_completion_text(self, completion: ChatCompletionResponse, **kwargs) -> str:
-        """Extract the completion text from the completion object."""
-        return completion.choices[0].message.content
-
-    def _get_usage(self, completion: ChatCompletionResponse, **kwargs) -> dict:
-        """Extract the usage statistics from the completion object."""
-        return {
-            "input_tokens": completion.usage.prompt_tokens,
-            "output_tokens": completion.usage.completion_tokens,
-        }
-
-    def _get_finish_reason(self, completion: ChatCompletionResponse, **kwargs) -> str:
-        """Extract the finish reason from the completion object."""
-        return completion.choices[0].finish_reason.value
-
-    def _get_answer_log_probs(self, completion: ChatCompletionResponse, **kwargs) -> list[float]:
-        """Extract the log probabilities from the completion object."""
-        return completion.choices[0].logprobs
-
-
-### CODE SYNTHESIS EXECUTION ###
-def code_execution(api: API, code: str, candidate_dict: dict[str, Any], verbose: bool = False):
-    inputs = {field_name: candidate_dict[field_name] for field_name in api.inputs}
-    response = api.api_execute(code, inputs)
-    pred = response["response"] if response["status"] and response["response"] else None
-    return pred
-
-
-def code_ensemble_execution(
-    api: API, code_ensemble: dict[str, str], candidate_dict: dict[str, Any], verbose: bool = True
-) -> GenerationOutput:
-    start_time = time.time()
-    try:
-        preds = list()
-        for _, code in code_ensemble.items():
-            pred = code_execution(api, code, candidate_dict)
-            preds.append(pred)
-
-        preds = [pred for pred in preds if pred is not None]
-
-        if len(preds) == 1:
-            majority_response = preds[0]
-            exec_stats = GenerationStats(fn_call_duration_secs=time.time() - start_time)
-            return majority_response, None, exec_stats
-
-        if len(preds) > 0:
-            majority_response = Counter(preds).most_common(1)[0][0]
-            exec_stats = GenerationStats(fn_call_duration_secs=time.time() - start_time)
-            return majority_response, None, exec_stats
-
-    except Exception:
-        pass
-
-    return None, None, GenerationStats(fn_call_duration_secs=time.time() - start_time)
```