palimpzest 0.7.21__py3-none-any.whl → 0.8.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- palimpzest/__init__.py +37 -6
- palimpzest/agents/__init__.py +0 -0
- palimpzest/agents/compute_agents.py +0 -0
- palimpzest/agents/search_agents.py +637 -0
- palimpzest/constants.py +343 -209
- palimpzest/core/data/context.py +393 -0
- palimpzest/core/data/context_manager.py +163 -0
- palimpzest/core/data/dataset.py +639 -0
- palimpzest/core/data/{datareaders.py → iter_dataset.py} +202 -126
- palimpzest/core/elements/groupbysig.py +16 -13
- palimpzest/core/elements/records.py +166 -75
- palimpzest/core/lib/schemas.py +152 -390
- palimpzest/core/{data/dataclasses.py → models.py} +306 -170
- palimpzest/policy.py +2 -27
- palimpzest/prompts/__init__.py +35 -5
- palimpzest/prompts/agent_prompts.py +357 -0
- palimpzest/prompts/context_search.py +9 -0
- palimpzest/prompts/convert_prompts.py +62 -6
- palimpzest/prompts/filter_prompts.py +51 -6
- palimpzest/prompts/join_prompts.py +163 -0
- palimpzest/prompts/moa_proposer_convert_prompts.py +6 -6
- palimpzest/prompts/prompt_factory.py +375 -47
- palimpzest/prompts/split_proposer_prompts.py +1 -1
- palimpzest/prompts/util_phrases.py +5 -0
- palimpzest/prompts/validator.py +239 -0
- palimpzest/query/execution/all_sample_execution_strategy.py +134 -76
- palimpzest/query/execution/execution_strategy.py +210 -317
- palimpzest/query/execution/execution_strategy_type.py +5 -7
- palimpzest/query/execution/mab_execution_strategy.py +249 -136
- palimpzest/query/execution/parallel_execution_strategy.py +153 -244
- palimpzest/query/execution/single_threaded_execution_strategy.py +107 -64
- palimpzest/query/generators/generators.py +160 -331
- palimpzest/query/operators/__init__.py +15 -5
- palimpzest/query/operators/aggregate.py +50 -33
- palimpzest/query/operators/compute.py +201 -0
- palimpzest/query/operators/convert.py +33 -19
- palimpzest/query/operators/critique_and_refine_convert.py +7 -5
- palimpzest/query/operators/distinct.py +62 -0
- palimpzest/query/operators/filter.py +26 -16
- palimpzest/query/operators/join.py +403 -0
- palimpzest/query/operators/limit.py +3 -3
- palimpzest/query/operators/logical.py +205 -77
- palimpzest/query/operators/mixture_of_agents_convert.py +10 -8
- palimpzest/query/operators/physical.py +27 -21
- palimpzest/query/operators/project.py +3 -3
- palimpzest/query/operators/rag_convert.py +7 -7
- palimpzest/query/operators/retrieve.py +9 -9
- palimpzest/query/operators/scan.py +81 -42
- palimpzest/query/operators/search.py +524 -0
- palimpzest/query/operators/split_convert.py +10 -8
- palimpzest/query/optimizer/__init__.py +7 -9
- palimpzest/query/optimizer/cost_model.py +108 -441
- palimpzest/query/optimizer/optimizer.py +123 -181
- palimpzest/query/optimizer/optimizer_strategy.py +66 -61
- palimpzest/query/optimizer/plan.py +352 -67
- palimpzest/query/optimizer/primitives.py +43 -19
- palimpzest/query/optimizer/rules.py +484 -646
- palimpzest/query/optimizer/tasks.py +127 -58
- palimpzest/query/processor/config.py +42 -76
- palimpzest/query/processor/query_processor.py +73 -18
- palimpzest/query/processor/query_processor_factory.py +46 -38
- palimpzest/schemabuilder/schema_builder.py +15 -28
- palimpzest/utils/model_helpers.py +32 -77
- palimpzest/utils/progress.py +114 -102
- palimpzest/validator/__init__.py +0 -0
- palimpzest/validator/validator.py +306 -0
- {palimpzest-0.7.21.dist-info → palimpzest-0.8.1.dist-info}/METADATA +6 -1
- palimpzest-0.8.1.dist-info/RECORD +95 -0
- palimpzest/core/lib/fields.py +0 -141
- palimpzest/prompts/code_synthesis_prompts.py +0 -28
- palimpzest/query/execution/random_sampling_execution_strategy.py +0 -240
- palimpzest/query/generators/api_client_factory.py +0 -30
- palimpzest/query/operators/code_synthesis_convert.py +0 -488
- palimpzest/query/operators/map.py +0 -130
- palimpzest/query/processor/nosentinel_processor.py +0 -33
- palimpzest/query/processor/processing_strategy_type.py +0 -28
- palimpzest/query/processor/sentinel_processor.py +0 -88
- palimpzest/query/processor/streaming_processor.py +0 -149
- palimpzest/sets.py +0 -405
- palimpzest/utils/datareader_helpers.py +0 -61
- palimpzest/utils/demo_helpers.py +0 -75
- palimpzest/utils/field_helpers.py +0 -69
- palimpzest/utils/generation_helpers.py +0 -69
- palimpzest/utils/sandbox.py +0 -183
- palimpzest-0.7.21.dist-info/RECORD +0 -95
- /palimpzest/core/{elements/index.py → data/index_dataset.py} +0 -0
- {palimpzest-0.7.21.dist-info → palimpzest-0.8.1.dist-info}/WHEEL +0 -0
- {palimpzest-0.7.21.dist-info → palimpzest-0.8.1.dist-info}/licenses/LICENSE +0 -0
- {palimpzest-0.7.21.dist-info → palimpzest-0.8.1.dist-info}/top_level.txt +0 -0
palimpzest/query/generators/generators.py
@@ -4,36 +4,28 @@ This file contains the Generator classes and generator factory.
 
 from __future__ import annotations
 
+import json
 import logging
 import os
-import re
 import time
 import warnings
-from abc import ABC, abstractmethod
-from collections import Counter
 from copy import deepcopy
 from typing import Any, Generic, TypeVar
 
+import litellm
+import regex as re # Use regex instead of re to used variable length lookbehind
 from colorama import Fore, Style
-from
-from openai.types.chat.chat_completion import ChatCompletion
-from together import Together
-from together.types.chat_completions import ChatCompletionResponse
+from pydantic.fields import FieldInfo
 
 from palimpzest.constants import (
     MODEL_CARDS,
-    APIClient,
     Cardinality,
     Model,
    PromptStrategy,
 )
-from palimpzest.core.data.dataclasses import GenerationStats
 from palimpzest.core.elements.records import DataRecord
-from palimpzest.core.
+from palimpzest.core.models import GenerationStats
 from palimpzest.prompts import PromptFactory
-from palimpzest.query.generators.api_client_factory import APIClientFactory
-from palimpzest.utils.generation_helpers import get_json_from_answer
-from palimpzest.utils.sandbox import API
 
 # DEFINITIONS
 GenerationOutput = tuple[dict, str | None, GenerationStats, list[dict]]
@@ -43,31 +35,71 @@ InputType = TypeVar("InputType")
 
 logger = logging.getLogger(__name__)
 
-def
-    model: Model, prompt_strategy: PromptStrategy, cardinality: Cardinality, verbose: bool = False
-) -> BaseGenerator:
+def get_json_from_answer(answer: str, model: Model, cardinality: Cardinality) -> dict[str, Any]:
     """
-
+    This function parses an LLM response which is supposed to output a JSON object
+    and optimistically searches for the substring containing the JSON object.
     """
-
-
-
-
-
-
-
-
-
-
-    #
-    if
-
-
-
-
-
+    # model-specific trimming for LLAMA3 responses
+    if model.is_llama_model():
+        answer = answer.split("---")[0]
+        answer = answer.replace("True", "true")
+        answer = answer.replace("False", "false")
+
+    # split off context / excess, which models sometimes output after answer
+    answer = answer.split("Context:")[0]
+    answer = answer.split("# this is the answer")[0]
+
+    # trim the answer to only include the JSON dictionary
+    if cardinality == Cardinality.ONE_TO_ONE:
+        if not answer.strip().startswith("{"):
+            # Find the start index of the actual JSON string assuming the prefix is followed by the JSON dictionary
+            start_index = answer.find("{")
+            if start_index != -1:
+                # Remove the prefix and any leading characters before the JSON starts
+                answer = answer[start_index:]
+
+        if not answer.strip().endswith("}"):
+            # Find the end index of the actual JSON string assuming the suffix is preceded by the JSON dictionary
+            end_index = answer.rfind("}")
+            if end_index != -1:
+                # Remove the suffix and any trailing characters after the JSON ends
+                answer = answer[: end_index + 1]
+
+    # otherwise, trim the answer to only include the JSON array
+    else:
+        if not answer.strip().startswith("["):
+            # Find the start index of the actual JSON string assuming the prefix is followed by the JSON array
+            start_index = answer.find("[")
+            if start_index != -1:
+                # Remove the prefix and any leading characters before the JSON starts
+                answer = answer[start_index:]
+
+        if not answer.strip().endswith("]"):
+            # Find the end index of the actual JSON string
+            # assuming the suffix is preceded by the JSON object/array
+            end_index = answer.rfind("]")
+            if end_index != -1:
+                # Remove the suffix and any trailing characters after the JSON ends
+                answer = answer[: end_index + 1]
+
+    # Handle weird escaped values. I am not sure why the model
+    # is returning these, but the JSON parser can't take them
+    answer = answer.replace(r"\_", "_")
+    answer = answer.replace("\\n", "\n")
+    # Remove https and http prefixes to not conflict with comment detection
+    # Handle comments in the JSON response. Use regex from // until end of line
+    answer = re.sub(r"(?<!https?:)\/\/.*?$", "", answer, flags=re.MULTILINE)
+    answer = re.sub(r",\n.*\.\.\.$", "", answer, flags=re.MULTILINE)
+    # Sanitize newlines in the JSON response
+    answer = answer.replace("\n", " ")
+
+    # finally, parse and return the JSON object; errors are handled by the caller
+    return json.loads(answer)
+
+# TODO: push parallelism of generations into LiteLLM rather than threadpool in executor
 # TODO: make sure answer parsing works with custom prompts / parsers (can defer this)
-class
+class Generator(Generic[ContextType, InputType]):
     """
     Abstract base class for Generators.
     """
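Note: a hedged usage sketch (not part of the diff) of the module-level get_json_from_answer helper added in the hunk above. The Model member used here is an assumption for illustration; any non-LLAMA model takes the same path.

from palimpzest.constants import Cardinality, Model
from palimpzest.query.generators.generators import get_json_from_answer

# a typical completion that wraps the JSON payload in prose
raw_answer = 'ANSWER: {"title": "Palimpzest", "year": 2024}\n---\nContext: trailing explanation'

# the helper trims everything outside the outermost {...} and parses what remains
parsed = get_json_from_answer(
    raw_answer,
    model=Model.GPT_4o,  # assumed enum member; non-LLAMA models skip the LLAMA-specific trimming
    cardinality=Cardinality.ONE_TO_ONE,
)
assert parsed == {"title": "Palimpzest", "year": 2024}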
@@ -76,94 +108,21 @@ class BaseGenerator(Generic[ContextType, InputType], ABC):
         self,
         model: Model,
         prompt_strategy: PromptStrategy,
+        reasoning_effort: str | None = None,
+        api_base: str | None = None,
         cardinality: Cardinality = Cardinality.ONE_TO_ONE,
+        desc: str | None = None,
         verbose: bool = False,
-        system_role: str = "system",
     ):
         self.model = model
         self.model_name = model.value
         self.cardinality = cardinality
         self.prompt_strategy = prompt_strategy
+        self.reasoning_effort = reasoning_effort
+        self.api_base = api_base
+        self.desc = desc
         self.verbose = verbose
-        self.
-        self.prompt_factory = PromptFactory(prompt_strategy, model, cardinality)
-
-    @abstractmethod
-    def _get_client_or_model(self, **kwargs) -> Any:
-        """Returns a client (or local model) which can be invoked to perform the generation."""
-        pass
-
-    @abstractmethod
-    def _generate_completion(self, client_or_model: Any, payload: dict, **kwargs) -> Any:
-        """Generates a completion object using the client (or local model)."""
-        pass
-
-    @abstractmethod
-    def _get_completion_text(self, completion: Any, **kwargs) -> Any:
-        """Extract the completion text from the completion object."""
-        pass
-
-    @abstractmethod
-    def _get_usage(self, completion: Any, **kwargs) -> Any:
-        """Extract the usage statistics from the completion object."""
-        pass
-
-    @abstractmethod
-    def _get_finish_reason(self, completion: Any, **kwargs) -> Any:
-        """Extract the finish reason from the completion object."""
-        pass
-
-    @abstractmethod
-    def _get_answer_log_probs(self, completion: Any, **kwargs) -> Any:
-        """Extract the log probabilities from the completion object."""
-        pass
-
-    def _generate_payload(self, messages: list[dict], **kwargs) -> dict:
-        """
-        Generates the payload which will be fed into the client (or local model).
-
-        Each message will be a dictionary with the following format:
-        {
-            "role": "user" | "system",
-            "type": "text" | "image",
-            "content": str
-        }
-        """
-        # get basic parameters
-        model = self.model_name
-        temperature = kwargs.get("temperature", 0.0)
-
-        # construct messages and add system prompt if present
-        chat_messages, user_content = [], []
-        for message in messages:
-            # flush user content into a message and add system message
-            if message["role"] == "system":
-                if len(user_content) > 0:
-                    chat_messages.append({"role": "user", "content": user_content})
-                    user_content = []
-
-                chat_messages.append({"role": self.system_role, "content": message["content"]})
-
-            # add user content for text messages
-            elif message["role"] == "user" and message["type"] == "text":
-                user_content.append({"type": "text", "text": message["content"]})
-
-            # add user content for image messages
-            elif message["role"] == "user" and message["type"] == "image":
-                user_content.append({"type": "image_url", "image_url": {"url": message["content"]}})
-
-        # flush any remaining user content into a final message
-        if len(user_content) > 0:
-            chat_messages.append({"role": "user", "content": user_content})
-
-        # construct and return payload
-        payload = {
-            "model": model,
-            "temperature": temperature,
-            "messages": chat_messages,
-        }
-
-        return payload
+        self.prompt_factory = PromptFactory(prompt_strategy, model, cardinality, desc)
 
     def _parse_reasoning(self, completion_text: str, **kwargs) -> str:
         """Extract the reasoning for the generated output from the completion object."""
@@ -183,7 +142,7 @@ class BaseGenerator(Generic[ContextType, InputType], ABC):
         # otherwise, return the full completion text
         return completion_text
 
-    def _prepare_field_answers(self, field_answers: dict | list[dict], fields: dict[str,
+    def _prepare_field_answers(self, field_answers: dict | list[dict], fields: dict[str, FieldInfo]) -> dict[str, list]:
         """
         field_answers is a dictionary mapping fields to their values. For one-to-one converts, wrap each
         answer in a list. For one-to-many converts, invert the list of dictionaries into a dictionary with
@@ -205,7 +164,7 @@ class BaseGenerator(Generic[ContextType, InputType], ABC):
 
         return field_answers
 
-    def _check_convert_answer_text(self, answer_text: str, fields: dict[str,
+    def _check_convert_answer_text(self, answer_text: str, fields: dict[str, FieldInfo], throw_exception: bool=False) -> dict | list[dict] | None:
         """
         Try parsing the answer text into a JSON object. If the parsing fails, return None.
         """
@@ -213,18 +172,6 @@ class BaseGenerator(Generic[ContextType, InputType], ABC):
         # extract json from the answer text
         field_answers = get_json_from_answer(answer_text, self.model, self.cardinality)
 
-        # TODO: wrap non-list outputs in a list if expected output is a list
-
-        # common error for one-to-one: if the output is a singleton list which contains a list, but the expected field type
-        # is a list of strings, or a list of floats, i.e. not a list of lists; then extract the inner list
-        if self.cardinality == Cardinality.ONE_TO_ONE:
-            for field, field_type in fields.items():
-                answer = field_answers[field]
-                field_type_is_not_list_of_lists = isinstance(field_type, ListField) and not issubclass(field_type.element_type, ListField)
-                answer_is_list_of_lists = isinstance(answer, list) and len(answer) == 1 and isinstance(answer[0], list)
-                if field_type_is_not_list_of_lists and answer_is_list_of_lists:
-                    field_answers[field] = answer[0]
-
         # prepare the field answers to match the expected output and return
         return self._prepare_field_answers(field_answers, fields)
 
@@ -234,7 +181,7 @@ class BaseGenerator(Generic[ContextType, InputType], ABC):
 
         return None
 
-    def
+    def _check_bool_answer_text(self, answer_text: str) -> dict | None:
         """
         Return {"passed_operator": True} if and only if "true" is in the answer text.
         Return {"passed_operator": False} if and only if "false" is in the answer text.
@@ -249,7 +196,7 @@ class BaseGenerator(Generic[ContextType, InputType], ABC):
 
         return None
 
-    def _parse_convert_answer(self, completion_text: str, fields: dict[str,
+    def _parse_convert_answer(self, completion_text: str, fields: dict[str, FieldInfo], json_output: bool) -> dict[str, list]:
         """Extract the answer from the completion object for convert operations."""
         # if the model followed the default instructions, the completion text will place
         # its answer between "ANSWER:" and "---"
@@ -288,15 +235,15 @@ class BaseGenerator(Generic[ContextType, InputType], ABC):
 
         return self._check_convert_answer_text(completion_text, fields, throw_exception=True)
 
-    def
-        """Extract the answer from the completion object for filter operations."""
+    def _parse_bool_answer(self, completion_text: str) -> dict[str, list]:
+        """Extract the answer from the completion object for filter and join operations."""
         # if the model followed the default instructions, the completion text will place
         # its answer between "ANSWER:" and "---"
         regex = re.compile("answer:(.*?)---", re.IGNORECASE | re.DOTALL)
         matches = regex.findall(completion_text)
         if len(matches) > 0:
             answer_text = matches[0].strip()
-            field_answers = self.
+            field_answers = self._check_bool_answer_text(answer_text)
             if field_answers is not None:
                 return field_answers
 
@@ -305,18 +252,18 @@ class BaseGenerator(Generic[ContextType, InputType], ABC):
         matches = regex.findall(completion_text)
         if len(matches) > 0:
             answer_text = matches[0].strip()
-            field_answers = self.
+            field_answers = self._check_bool_answer_text(answer_text)
             if field_answers is not None:
                 return field_answers
 
         # finally, try taking all of the text; throw an exception if this doesn't work
-        field_answers = self.
+        field_answers = self._check_bool_answer_text(completion_text)
         if field_answers is None:
             raise Exception(f"Could not parse answer from completion text: {completion_text}")
 
         return field_answers
 
-    def _parse_answer(self, completion_text: str, fields: dict[str,
+    def _parse_answer(self, completion_text: str, fields: dict[str, FieldInfo] | None, json_output: bool, **kwargs) -> dict[str, list]:
         """Extract the answer from the completion object."""
         # use a custom answer parser if provided
         if kwargs.get("parse_answer"):
@@ -328,16 +275,15 @@ class BaseGenerator(Generic[ContextType, InputType], ABC):
 
         # extract the per-field answers from the completion text
         field_answers = (
-            self.
-            if self.prompt_strategy.is_bool_prompt()
+            self._parse_bool_answer(completion_text)
+            if self.prompt_strategy.is_bool_prompt() or self.prompt_strategy.is_join_prompt()
             else self._parse_convert_answer(completion_text, fields, json_output)
         )
 
         return field_answers
 
-    def __call__(self, candidate: DataRecord, fields: dict[str,
+    def __call__(self, candidate: DataRecord, fields: dict[str, FieldInfo] | None, right_candidate: DataRecord | None = None, json_output: bool=True, **kwargs) -> GenerationOutput:
         """Take the input record (`candidate`), generate the output `fields`, and return the generated output."""
-        client = self._get_client_or_model()
         logger.debug(f"Generating for candidate {candidate} with fields {fields}")
 
         # fields can only be None if the user provides an answer parser
@@ -352,23 +298,45 @@ class BaseGenerator(Generic[ContextType, InputType], ABC):
             warnings.warn("Provided `system_prompt` without providing `prompt`; setting `prompt` = `system_prompt`.") # noqa: B028
 
         # generate a list of messages which can be used to construct a payload
-        messages = self.prompt_factory.create_messages(candidate, fields, **kwargs)
-
-        # create the chat payload
-        chat_payload = self._generate_payload(messages, **kwargs)
+        messages = self.prompt_factory.create_messages(candidate, fields, right_candidate, **kwargs)
 
         # generate the text completion
         start_time = time.time()
         completion = None
         try:
-
+            completion_kwargs = {}
+            if not self.model.is_o_model() and not self.model.is_gpt_5_model():
+                completion_kwargs = {"temperature": kwargs.get("temperature", 0.0), **completion_kwargs}
+            if self.prompt_strategy.is_audio_prompt():
+                completion_kwargs = {"modalities": ["text"], **completion_kwargs}
+            if self.model.is_reasoning_model():
+                if self.model.is_vertex_model():
+                    reasoning_effort = self.reasoning_effort
+                    if self.reasoning_effort is None and self.model == Model.GEMINI_2_5_PRO:
+                        reasoning_effort = "low"
+                    elif self.reasoning_effort is None:
+                        reasoning_effort = "disable"
+                    completion_kwargs = {"reasoning_effort": reasoning_effort, **completion_kwargs}
+                elif self.model.is_anthropic_model() and self.reasoning_effort is not None:
+                    completion_kwargs = {"reasoning_effort": self.reasoning_effort, **completion_kwargs}
+                elif self.model.is_openai_model():
+                    reasoning_effort = "minimal" if self.reasoning_effort is None else self.reasoning_effort
+                    completion_kwargs = {"reasoning_effort": reasoning_effort, **completion_kwargs}
+            if self.model.is_vllm_model():
+                completion_kwargs = {"api_base": self.api_base, **completion_kwargs}
+            completion = litellm.completion(model=self.model_name, messages=messages, **completion_kwargs)
             end_time = time.time()
             logger.debug(f"Generated completion in {end_time - start_time:.2f} seconds")
         # if there's an error generating the completion, we have to return an empty answer
         # and can only account for the time spent performing the failed generation
-        except Exception:
-
-
+        except Exception as e:
+            print(f"Error generating completion: {e}")
+            logger.error(f"Error generating completion: {e}")
+            field_answers = (
+                {"passed_operator": False}
+                if self.prompt_strategy.is_bool_prompt() or self.prompt_strategy.is_join_prompt()
+                else {field_name: None for field_name in fields}
+            )
             reasoning = None
             generation_stats = GenerationStats(
                 model_name=self.model_name,
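Note: a condensed, standalone restatement (not the packaged code) of the litellm kwarg selection added in the hunk above. The boolean flags stand in for the Model capability predicates referenced there, and the audio-modality branch is omitted.

def pick_completion_kwargs(reasoning_effort, *, is_o_or_gpt5, is_reasoning, is_vertex,
                           is_anthropic, is_openai, is_vllm, is_gemini_2_5_pro=False,
                           api_base=None, temperature=0.0) -> dict:
    kwargs = {}
    if not is_o_or_gpt5:
        kwargs["temperature"] = temperature  # temperature only for non o-series / non GPT-5 models
    if is_reasoning:
        if is_vertex:
            # vertex reasoning models default to "low" (Gemini 2.5 Pro) or "disable"
            kwargs["reasoning_effort"] = reasoning_effort or ("low" if is_gemini_2_5_pro else "disable")
        elif is_anthropic and reasoning_effort is not None:
            kwargs["reasoning_effort"] = reasoning_effort
        elif is_openai:
            kwargs["reasoning_effort"] = reasoning_effort or "minimal"
    if is_vllm:
        kwargs["api_base"] = api_base  # vLLM-served models need an explicit endpoint
    return kwargs

# e.g. an OpenAI o-series model with no explicit effort ends up with only reasoning_effort="minimal"
assert pick_completion_kwargs(None, is_o_or_gpt5=True, is_reasoning=True, is_vertex=False,
                              is_anthropic=False, is_openai=True, is_vllm=False) == {"reasoning_effort": "minimal"}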
@@ -381,40 +349,57 @@ class BaseGenerator(Generic[ContextType, InputType], ABC):
         # parse usage statistics and create the GenerationStats
         generation_stats = None
         if completion is not None:
-            usage =
-            # finish_reason = self._get_finish_reason(completion, **kwargs)
-            # answer_log_probs = self._get_answer_log_probs(completion, **kwargs)
+            usage = completion.usage.model_dump()
 
-            # get cost per input/output token for the model
-            usd_per_input_token = MODEL_CARDS[self.model_name]
+            # get cost per input/output token for the model
+            usd_per_input_token = MODEL_CARDS[self.model_name].get("usd_per_input_token", 0.0)
+            usd_per_audio_input_token = MODEL_CARDS[self.model_name].get("usd_per_audio_input_token", 0.0)
             usd_per_output_token = MODEL_CARDS[self.model_name]["usd_per_output_token"]
-
-
+
+            # TODO: for some models (e.g. GPT-5) we cannot separate text from image prompt tokens yet;
+            # for now, we only use tokens from prompt_token_details if it's an audio prompt
+            # get output tokens (all text) and input tokens by modality
+            output_tokens = usage["completion_tokens"]
+            if self.prompt_strategy.is_audio_prompt():
+                input_audio_tokens = usage["prompt_tokens_details"].get("audio_tokens", 0)
+                input_text_tokens = usage["prompt_tokens_details"].get("text_tokens", 0)
+                input_image_tokens = 0
+            else:
+                input_audio_tokens = 0
+                input_text_tokens = usage["prompt_tokens"]
+                input_image_tokens = 0
+            input_tokens = input_audio_tokens + input_text_tokens + input_image_tokens
+
+            # compute the input and output token costs
+            total_input_cost = (input_text_tokens + input_image_tokens) * usd_per_input_token + input_audio_tokens * usd_per_audio_input_token
+            total_output_cost = output_tokens * usd_per_output_token
 
             generation_stats = GenerationStats(
                 model_name=self.model_name,
                 llm_call_duration_secs=end_time - start_time,
                 fn_call_duration_secs=0.0,
+                input_audio_tokens=input_audio_tokens,
+                input_text_tokens=input_text_tokens,
+                input_image_tokens=input_image_tokens,
                 total_input_tokens=input_tokens,
                 total_output_tokens=output_tokens,
-                total_input_cost=
-                total_output_cost=
-                cost_per_record=
+                total_input_cost=total_input_cost,
+                total_output_cost=total_output_cost,
+                cost_per_record=total_input_cost + total_output_cost,
                 total_llm_calls=1,
-                # "system_prompt": system_prompt,
-                # "prompt": prompt,
-                # "usage": usage,
-                # "finish_reason": finish_reason,
-                # "answer_log_probs": answer_log_probs,
-                # "answer": answer,
             )
 
         # pretty print prompt + full completion output for debugging
-        completion_text =
+        completion_text = completion.choices[0].message.content
         prompt = ""
         for message in messages:
             if message["role"] == "user":
-
+                if message["type"] == "text":
+                    prompt += message["content"] + "\n"
+                elif message["type"] == "image":
+                    prompt += "<image>\n"
+                elif message["type"] == "input_audio":
+                    prompt += "<audio>\n"
         logger.debug(f"PROMPT:\n{prompt}")
         logger.debug(Fore.GREEN + f"{completion_text}\n" + Style.RESET_ALL)
 
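Note: a small self-contained check of the cost arithmetic added in the hunk above, using made-up per-token prices (the real values live in MODEL_CARDS).

# made-up prices, for illustration only
usd_per_input_token = 2.5e-06
usd_per_audio_input_token = 1.0e-05
usd_per_output_token = 1.0e-05

# text-only prompt: all prompt tokens count as text, audio and image are zero
input_text_tokens, input_audio_tokens, input_image_tokens = 1200, 0, 0
output_tokens = 150

total_input_cost = (input_text_tokens + input_image_tokens) * usd_per_input_token + input_audio_tokens * usd_per_audio_input_token
total_output_cost = output_tokens * usd_per_output_token
cost_per_record = total_input_cost + total_output_cost  # 0.003 + 0.0015 = 0.0045 USD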
@@ -422,17 +407,20 @@ class BaseGenerator(Generic[ContextType, InputType], ABC):
         reasoning = None
         try:
             reasoning = self._parse_reasoning(completion_text, **kwargs)
-        except Exception:
-
-            logger.debug("TODO: undo this")
+        except Exception as e:
+            logger.error(f"Error parsing reasoning and answers: {e}")
             pass
 
         # parse field answers
-        field_answers = None
+        field_answers = None
+        if fields is not None and (self.prompt_strategy.is_bool_prompt() or self.prompt_strategy.is_join_prompt()):
+            field_answers = {"passed_operator": False}
+        elif fields is not None and not (self.prompt_strategy.is_bool_prompt() or self.prompt_strategy.is_join_prompt()):
+            field_answers = {field_name: None for field_name in fields}
         try:
             field_answers = self._parse_answer(completion_text, fields, json_output, **kwargs)
         except Exception as e:
-
+            logger.error(f"Error parsing answers: {e}")
             os.makedirs("parse-answer-errors", exist_ok=True)
             ts = time.time()
             with open(f"parse-answer-errors/error-{ts}.txt", "w") as f:
@@ -448,162 +436,3 @@ class BaseGenerator(Generic[ContextType, InputType], ABC):
 
         logger.debug(f"Generated field answers: {field_answers}")
         return field_answers, reasoning, generation_stats, messages
-
-
-class OpenAIGenerator(BaseGenerator[str | list[str], str]):
-    """
-    Class for generating text using the OpenAI chat API.
-    """
-
-    def __init__(
-        self,
-        model: Model,
-        prompt_strategy: PromptStrategy,
-        cardinality: Cardinality = Cardinality.ONE_TO_ONE,
-        verbose: bool = False,
-    ):
-        # assert that model is an OpenAI model
-        assert model.is_openai_model()
-        super().__init__(model, prompt_strategy, cardinality, verbose, "developer")
-
-    def _get_client_or_model(self, **kwargs) -> OpenAI:
-        """Returns a client (or local model) which can be invoked to perform the generation."""
-        return APIClientFactory.get_client(APIClient.OPENAI, get_api_key("OPENAI_API_KEY"))
-
-    def _generate_completion(self, client: OpenAI, payload: dict, **kwargs) -> ChatCompletion:
-        """Generates a completion object using the client (or local model)."""
-        return client.chat.completions.create(**payload)
-
-    def _get_completion_text(self, completion: ChatCompletion, **kwargs) -> str:
-        """Extract the completion text from the completion object."""
-        return completion.choices[0].message.content
-
-    def _get_usage(self, completion: ChatCompletion, **kwargs) -> dict:
-        """Extract the usage statistics from the completion object."""
-        return {
-            "input_tokens": completion.usage.prompt_tokens,
-            "output_tokens": completion.usage.completion_tokens,
-        }
-
-    def _get_finish_reason(self, completion: ChatCompletion, **kwargs) -> str:
-        """Extract the finish reason from the completion object."""
-        return completion.choices[0].finish_reason
-
-    def _get_answer_log_probs(self, completion: ChatCompletion, **kwargs) -> list[float]:
-        """Extract the log probabilities from the completion object."""
-        return completion.choices[0].logprobs
-
-
-class TogetherGenerator(BaseGenerator[str | list[str], str]):
-    """
-    Class for generating text using the Together chat API.
-    """
-
-    def __init__(
-        self,
-        model: Model,
-        prompt_strategy: PromptStrategy,
-        cardinality: Cardinality = Cardinality.ONE_TO_ONE,
-        verbose: bool = False,
-    ):
-        # assert that model is a model offered by Together
-        assert model.is_together_model()
-        super().__init__(model, prompt_strategy, cardinality, verbose, "system")
-
-    def _generate_payload(self, messages: list[dict], **kwargs) -> dict:
-        """
-        Generates the payload which will be fed into the client (or local model).
-
-        Each message will be a dictionary with the following format:
-        {
-            "role": "user" | "system",
-            "type": "text" | "image",
-            "content": str
-        }
-
-        For LLAMA3, the payload needs to be in a {"role": <role>, "content": <content>} format.
-        """
-        # for other models, use our standard payload generation
-        if not self.model.is_llama_model():
-            return super()._generate_payload(messages, **kwargs)
-
-        # get basic parameters
-        model = self.model_name
-        temperature = kwargs.get("temperature", 0.0)
-
-        # construct messages in simple {"role": <role>, "content": <content>} format
-        chat_messages = []
-        for message in messages:
-            chat_messages.append({"role": message["role"], "content": message["content"]})
-
-        # construct and return payload
-        payload = {
-            "model": model,
-            "temperature": temperature,
-            "messages": chat_messages,
-        }
-
-        return payload
-
-    def _get_client_or_model(self, **kwargs) -> Together:
-        """Returns a client (or local model) which can be invoked to perform the generation."""
-        return APIClientFactory.get_client(APIClient.TOGETHER, get_api_key("TOGETHER_API_KEY"))
-
-    def _generate_completion(self, client: Together, payload: dict, **kwargs) -> ChatCompletionResponse:
-        """Generates a completion object using the client (or local model)."""
-        return client.chat.completions.create(**payload)
-
-    def _get_completion_text(self, completion: ChatCompletionResponse, **kwargs) -> str:
-        """Extract the completion text from the completion object."""
-        return completion.choices[0].message.content
-
-    def _get_usage(self, completion: ChatCompletionResponse, **kwargs) -> dict:
-        """Extract the usage statistics from the completion object."""
-        return {
-            "input_tokens": completion.usage.prompt_tokens,
-            "output_tokens": completion.usage.completion_tokens,
-        }
-
-    def _get_finish_reason(self, completion: ChatCompletionResponse, **kwargs) -> str:
-        """Extract the finish reason from the completion object."""
-        return completion.choices[0].finish_reason.value
-
-    def _get_answer_log_probs(self, completion: ChatCompletionResponse, **kwargs) -> list[float]:
-        """Extract the log probabilities from the completion object."""
-        return completion.choices[0].logprobs
-
-
-### CODE SYNTHESIS EXECUTION ###
-def code_execution(api: API, code: str, candidate_dict: dict[str, Any], verbose: bool = False):
-    inputs = {field_name: candidate_dict[field_name] for field_name in api.inputs}
-    response = api.api_execute(code, inputs)
-    pred = response["response"] if response["status"] and response["response"] else None
-    return pred
-
-
-def code_ensemble_execution(
-    api: API, code_ensemble: dict[str, str], candidate_dict: dict[str, Any], verbose: bool = True
-) -> GenerationOutput:
-    start_time = time.time()
-    try:
-        preds = list()
-        for _, code in code_ensemble.items():
-            pred = code_execution(api, code, candidate_dict)
-            preds.append(pred)
-
-        preds = [pred for pred in preds if pred is not None]
-
-        if len(preds) == 1:
-            majority_response = preds[0]
-            exec_stats = GenerationStats(fn_call_duration_secs=time.time() - start_time)
-            return majority_response, None, exec_stats
-
-        if len(preds) > 0:
-            majority_response = Counter(preds).most_common(1)[0][0]
-            exec_stats = GenerationStats(fn_call_duration_secs=time.time() - start_time)
-            return majority_response, None, exec_stats
-
-    except Exception:
-        pass
-
-    return None, None, GenerationStats(fn_call_duration_secs=time.time() - start_time)