palimpzest 0.6.4__py3-none-any.whl → 0.7.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- palimpzest/__init__.py +5 -0
- palimpzest/constants.py +110 -43
- palimpzest/core/__init__.py +0 -78
- palimpzest/core/data/dataclasses.py +382 -44
- palimpzest/core/elements/filters.py +7 -3
- palimpzest/core/elements/index.py +70 -0
- palimpzest/core/elements/records.py +33 -11
- palimpzest/core/lib/fields.py +1 -0
- palimpzest/core/lib/schemas.py +4 -3
- palimpzest/prompts/moa_proposer_convert_prompts.py +0 -4
- palimpzest/prompts/prompt_factory.py +44 -7
- palimpzest/prompts/split_merge_prompts.py +56 -0
- palimpzest/prompts/split_proposer_prompts.py +55 -0
- palimpzest/query/execution/execution_strategy.py +435 -53
- palimpzest/query/execution/execution_strategy_type.py +20 -0
- palimpzest/query/execution/mab_execution_strategy.py +532 -0
- palimpzest/query/execution/parallel_execution_strategy.py +143 -172
- palimpzest/query/execution/random_sampling_execution_strategy.py +240 -0
- palimpzest/query/execution/single_threaded_execution_strategy.py +173 -203
- palimpzest/query/generators/api_client_factory.py +31 -0
- palimpzest/query/generators/generators.py +256 -76
- palimpzest/query/operators/__init__.py +1 -2
- palimpzest/query/operators/code_synthesis_convert.py +33 -18
- palimpzest/query/operators/convert.py +30 -97
- palimpzest/query/operators/critique_and_refine_convert.py +5 -6
- palimpzest/query/operators/filter.py +7 -10
- palimpzest/query/operators/logical.py +54 -10
- palimpzest/query/operators/map.py +130 -0
- palimpzest/query/operators/mixture_of_agents_convert.py +6 -6
- palimpzest/query/operators/physical.py +3 -12
- palimpzest/query/operators/rag_convert.py +66 -18
- palimpzest/query/operators/retrieve.py +230 -34
- palimpzest/query/operators/scan.py +5 -2
- palimpzest/query/operators/split_convert.py +169 -0
- palimpzest/query/operators/token_reduction_convert.py +8 -14
- palimpzest/query/optimizer/__init__.py +4 -16
- palimpzest/query/optimizer/cost_model.py +73 -266
- palimpzest/query/optimizer/optimizer.py +87 -58
- palimpzest/query/optimizer/optimizer_strategy.py +18 -97
- palimpzest/query/optimizer/optimizer_strategy_type.py +37 -0
- palimpzest/query/optimizer/plan.py +2 -3
- palimpzest/query/optimizer/primitives.py +5 -3
- palimpzest/query/optimizer/rules.py +336 -172
- palimpzest/query/optimizer/tasks.py +30 -100
- palimpzest/query/processor/config.py +38 -22
- palimpzest/query/processor/nosentinel_processor.py +16 -520
- palimpzest/query/processor/processing_strategy_type.py +28 -0
- palimpzest/query/processor/query_processor.py +38 -206
- palimpzest/query/processor/query_processor_factory.py +117 -130
- palimpzest/query/processor/sentinel_processor.py +90 -0
- palimpzest/query/processor/streaming_processor.py +25 -32
- palimpzest/sets.py +88 -41
- palimpzest/utils/model_helpers.py +8 -7
- palimpzest/utils/progress.py +368 -152
- palimpzest/utils/token_reduction_helpers.py +1 -3
- {palimpzest-0.6.4.dist-info → palimpzest-0.7.1.dist-info}/METADATA +19 -9
- palimpzest-0.7.1.dist-info/RECORD +96 -0
- {palimpzest-0.6.4.dist-info → palimpzest-0.7.1.dist-info}/WHEEL +1 -1
- palimpzest/query/processor/mab_sentinel_processor.py +0 -884
- palimpzest/query/processor/random_sampling_sentinel_processor.py +0 -639
- palimpzest/utils/index_helpers.py +0 -6
- palimpzest-0.6.4.dist-info/RECORD +0 -87
- {palimpzest-0.6.4.dist-info → palimpzest-0.7.1.dist-info/licenses}/LICENSE +0 -0
- {palimpzest-0.6.4.dist-info → palimpzest-0.7.1.dist-info}/top_level.txt +0 -0
palimpzest/query/generators/generators.py

@@ -1,8 +1,10 @@
 """
 This file contains the Generator classes and generator factory.
 """
+
 from __future__ import annotations
 
+import logging
 import os
 import re
 import time
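Note: 0.7.1 replaces the generators' print-based debugging with the standard `logging` module (a module-level `logger` is defined in the next hunk). A minimal sketch of how a consumer might surface the new debug output, assuming an otherwise default logging configuration:

import logging

# show palimpzest's new debug output (prompts, completions, timings)
logging.basicConfig(level=logging.DEBUG)
logging.getLogger("palimpzest.query.generators.generators").setLevel(logging.DEBUG)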
@@ -15,40 +17,42 @@ from typing import Any, Generic, TypeVar
 from colorama import Fore, Style
 from openai import OpenAI
 from openai.types.chat.chat_completion import ChatCompletion
-
-# from tenacity import retry, stop_after_attempt, wait_exponential
 from together import Together
 from together.types.chat_completions import ChatCompletionResponse
 
 from palimpzest.constants import (
     MODEL_CARDS,
-
-    # RETRY_MAX_SECS,
-    # RETRY_MULTIPLIER,
+    APIClient,
     Cardinality,
     Model,
     PromptStrategy,
 )
 from palimpzest.core.data.dataclasses import GenerationStats
 from palimpzest.core.elements.records import DataRecord
+from palimpzest.core.lib.fields import Field, ListField
 from palimpzest.prompts import PromptFactory
+from palimpzest.query.generators.api_client_factory import APIClientFactory
 from palimpzest.utils.generation_helpers import get_json_from_answer
 from palimpzest.utils.sandbox import API
 
 # DEFINITIONS
-GenerationOutput = tuple[dict, str | None, GenerationStats]
+GenerationOutput = tuple[dict, str | None, GenerationStats, list[dict]]
 ContextType = TypeVar("ContextType")
 InputType = TypeVar("InputType")
 
 
-def generator_factory(model: Model, prompt_strategy: PromptStrategy, cardinality: Cardinality, verbose: bool = False) -> BaseGenerator:
+logger = logging.getLogger(__name__)
+
+def generator_factory(
+    model: Model, prompt_strategy: PromptStrategy, cardinality: Cardinality, verbose: bool = False
+) -> BaseGenerator:
     """
     Factory function to return the correct generator based on the model, strategy, and cardinality.
     """
     if model in [Model.GPT_4o, Model.GPT_4o_MINI, Model.GPT_4o_V, Model.GPT_4o_MINI_V]:
         return OpenAIGenerator(model, prompt_strategy, cardinality, verbose)
 
-    elif model in [Model.MIXTRAL, Model.LLAMA3, Model.LLAMA3_V]:
+    elif model in [Model.MIXTRAL, Model.LLAMA3, Model.LLAMA3_V, Model.DEEPSEEK]:
         return TogetherGenerator(model, prompt_strategy, cardinality, verbose)
 
     raise Exception(f"Unsupported model: {model}")
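Note: `generator_factory` now routes `Model.DEEPSEEK` to the `TogetherGenerator`. A usage sketch of the factory; the `PromptStrategy.COT_QA` value is assumed here for illustration (the enum member an operator actually passes depends on its configuration):

from palimpzest.constants import Cardinality, Model, PromptStrategy
from palimpzest.query.generators.generators import generator_factory

# DEEPSEEK is newly supported in 0.7.1 and is served through the Together API
generator = generator_factory(Model.DEEPSEEK, PromptStrategy.COT_QA, Cardinality.ONE_TO_ONE, verbose=True)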
@@ -69,7 +73,15 @@ class BaseGenerator(Generic[ContextType, InputType], ABC):
     """
     Abstract base class for Generators.
     """
-    def __init__(self, model: Model, prompt_strategy: PromptStrategy, cardinality: Cardinality = Cardinality.ONE_TO_ONE, verbose: bool = False, system_role: str = "system"):
+
+    def __init__(
+        self,
+        model: Model,
+        prompt_strategy: PromptStrategy,
+        cardinality: Cardinality = Cardinality.ONE_TO_ONE,
+        verbose: bool = False,
+        system_role: str = "system",
+    ):
         self.model = model
         self.model_name = model.value
         self.cardinality = cardinality
@@ -77,11 +89,6 @@ class BaseGenerator(Generic[ContextType, InputType], ABC):
         self.verbose = verbose
         self.system_role = system_role
         self.prompt_factory = PromptFactory(prompt_strategy, model, cardinality)
-        self.messages = None
-
-    def get_messages(self) -> list[dict] | None:
-        """Returns the messages used in the last generation."""
-        return self.messages
 
     @abstractmethod
     def _get_client_or_model(self, **kwargs) -> Any:
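Note: with `self.messages` and `get_messages()` removed, a generator no longer carries state from its last call; the messages are instead returned as the fourth element of `GenerationOutput` (see the type alias change above). A migration sketch:

# 0.6.4:
#   field_answers, reasoning, stats = generator(candidate, fields)
#   messages = generator.get_messages()

# 0.7.1: the messages used for the generation travel in the return value
field_answers, reasoning, stats, messages = generator(candidate, fields)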
@@ -160,7 +167,7 @@
 
         return payload
 
-    def _parse_reasoning(self, completion_text: str, **kwargs) ->
+    def _parse_reasoning(self, completion_text: str, **kwargs) -> str:
         """Extract the reasoning for the generated output from the completion object."""
         # use a custom reasoning parser if provided
         if kwargs.get("parse_reasoning"):
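Note: `_parse_reasoning` assumes the model emits its chain of thought before an "ANSWER:" marker; the rewritten body appears in the next hunk. A small, self-contained illustration of the extraction regex on a made-up completion:

import re

completion_text = 'The document names an effective date.\nANSWER: {"effective_date": "2021-03-01"}\n---'
regex = re.compile("(.*?)answer:.*", re.IGNORECASE | re.DOTALL)
matches = regex.findall(completion_text)
print(matches[0].strip())  # The document names an effective date.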
@@ -169,66 +176,174 @@
 
         # if the model followed the default instructions, the completion text will have reasoning
         # before the "ANSWER:"; if this is the case, we simply extract and return that full section
-
+        if "answer" in completion_text.lower():
+            regex = re.compile("(.*?)answer:.*", re.IGNORECASE | re.DOTALL)
+            matches = regex.findall(completion_text)
+            if len(matches) > 0:
+                return matches[0].strip()
+
+        # otherwise, return the full completion text
+        return completion_text
+
+    def _prepare_field_answers(self, field_answers: dict | list[dict], fields: dict[str, Field]) -> dict[str, list]:
+        """
+        field_answers is a dictionary mapping fields to their values. For one-to-one converts, wrap each
+        answer in a list. For one-to-many converts, invert the list of dictionaries into a dictionary with
+        list values.
+        """
+        # if this is a one-to-one convert, we need to wrap each answer in a list
+        if self.cardinality == Cardinality.ONE_TO_ONE:
+            field_answers = {field_name: [field_answers[field_name]] for field_name in fields}
+
+        # otherwise, we need to invert the list of dictionaries into a dictionary with list values
+        else:
+            field_answers_lst: list[dict] = deepcopy(field_answers)
+
+            field_answers = {field_name: [] for field_name in fields}
+            for answer_dict in field_answers_lst:
+                for field_name in fields:
+                    answer = answer_dict.get(field_name, None)
+                    field_answers[field_name].append(answer)
+
+        return field_answers
+
+    def _check_convert_answer_text(self, answer_text: str, fields: dict[str, Field], throw_exception: bool=False) -> dict | list[dict] | None:
+        """
+        Try parsing the answer text into a JSON object. If the parsing fails, return None.
+        """
+        try:
+            # extract json from the answer text
+            field_answers = get_json_from_answer(answer_text, self.model, self.cardinality)
+
+            # TODO: wrap non-list outputs in a list if expected output is a list
+
+            # common error: if the output is a singleton list which contains a list, but the expected field type
+            # is a list of strings, or a list of floats, i.e. not a list of lists; then extract the inner list
+            for field, field_type in fields.items():
+                answer = field_answers[field]
+                field_type_is_not_list_of_lists = isinstance(field_type, ListField) and not issubclass(field_type.element_type, ListField)
+                answer_is_list_of_lists = isinstance(answer, list) and len(answer) == 1 and isinstance(answer[0], list)
+                if field_type_is_not_list_of_lists and answer_is_list_of_lists:
+                    field_answers[field] = answer[0]
+
+            # prepare the field answers to match the expected output and return
+            return self._prepare_field_answers(field_answers, fields)
+
+        except Exception as e:
+            if throw_exception:
+                raise e
+
+            return None
+
+    def _check_filter_answer_text(self, answer_text: str) -> dict | None:
+        """
+        Return {"passed_operator": True} if and only if "true" is in the answer text.
+        Return {"passed_operator": False} if and only if "false" is in the answer text.
+        Otherwise, return None.
+        """
+        # NOTE: we may be able to eliminate this condition by specifying this JSON output in the prompt;
+        # however, that would also need to coincide with a change to allow the parse_answer_fn to set "passed_operator"
+        if "true" in answer_text.lower():
+            return {"passed_operator": True}
+        elif "false" in answer_text.lower():
+            return {"passed_operator": False}
+
+        return None
+
+    def _parse_convert_answer(self, completion_text: str, fields: dict[str, Field], json_output: bool) -> dict[str, list]:
+        """Extract the answer from the completion object for convert operations."""
+        # if the model followed the default instructions, the completion text will place
+        # its answer between "ANSWER:" and "---"
+        regex = re.compile("answer:(.*?)---", re.IGNORECASE | re.DOTALL)
         matches = regex.findall(completion_text)
         if len(matches) > 0:
-
+            answer_text = matches[0].strip()
 
-
-
+            # if we don't expect a JSON output, return the answer text as is
+            if not json_output:
+                return answer_text
 
-
-
-
-
-            parse_answer_fn = kwargs.get("parse_answer")
-            return parse_answer_fn(completion_text)
+            # otherwise, try to parse the answer text into a JSON object
+            field_answers = self._check_convert_answer_text(answer_text, fields)
+            if field_answers is not None:
+                return field_answers
 
+        # if the first regex didn't find an answer, try taking all the text after "ANSWER:"
+        regex = re.compile("answer:(.*)", re.IGNORECASE | re.DOTALL)
+        matches = regex.findall(completion_text)
+        if len(matches) > 0:
+            answer_text = matches[0].strip()
+
+            # if we don't expect a JSON output, return the answer text as is
+            if not json_output:
+                return answer_text
+
+            # otherwise, try to parse the answer text into a JSON object
+            field_answers = self._check_convert_answer_text(answer_text, fields)
+            if field_answers is not None:
+                return field_answers
+
+        # finally, try taking all of the text; for JSON output, throw an exception if parsing fails
+        if not json_output:
+            return completion_text
+
+        return self._check_convert_answer_text(completion_text, fields, throw_exception=True)
+
+    def _parse_filter_answer(self, completion_text: str) -> dict[str, list]:
+        """Extract the answer from the completion object for filter operations."""
         # if the model followed the default instructions, the completion text will place
         # its answer between "ANSWER:" and "---"
-        answer_text = None
         regex = re.compile("answer:(.*?)---", re.IGNORECASE | re.DOTALL)
         matches = regex.findall(completion_text)
         if len(matches) > 0:
             answer_text = matches[0].strip()
+            field_answers = self._check_filter_answer_text(answer_text)
+            if field_answers is not None:
+                return field_answers
 
-        #
-
-
-
-            answer_text = matches[0].strip()
+        # if the first regex didn't find an answer, try taking all the text after "ANSWER:"
+        regex = re.compile("answer:(.*)", re.IGNORECASE | re.DOTALL)
+        matches = regex.findall(completion_text)
+        if len(matches) > 0:
+            answer_text = matches[0].strip()
+            field_answers = self._check_filter_answer_text(answer_text)
+            if field_answers is not None:
+                return field_answers
 
-        #
-
-
-
-        return {"passed_operator": "true" in answer_text.lower()}
+        # finally, try taking all of the text; throw an exception if this doesn't work
+        field_answers = self._check_filter_answer_text(completion_text)
+        if field_answers is None:
+            raise Exception(f"Could not parse answer from completion text: {completion_text}")
 
-
-        field_answers = get_json_from_answer(answer_text, self.model, self.cardinality)
+        return field_answers
 
-
-
-
+    def _parse_answer(self, completion_text: str, fields: dict[str, Field] | None, json_output: bool, **kwargs) -> dict[str, list]:
+        """Extract the answer from the completion object."""
+        # use a custom answer parser if provided
+        if kwargs.get("parse_answer"):
+            parse_answer_fn = kwargs.get("parse_answer")
+            return parse_answer_fn(completion_text)
 
-        #
-
-        field_answers_lst: list[dict] = deepcopy(field_answers)
+        # fields should be a dict if a custom answer parser is not provided
+        assert isinstance(fields, dict), "Fields must be provided if a custom answer parser is not provided."
 
-
-
-
-
-
+        # extract the per-field answers from the completion text
+        field_answers = (
+            self._parse_filter_answer(completion_text)
+            if self.prompt_strategy.is_bool_prompt()
+            else self._parse_convert_answer(completion_text, fields, json_output)
+        )
 
         return field_answers
 
-    def __call__(self, candidate: DataRecord, fields:
+    def __call__(self, candidate: DataRecord, fields: dict[str, Field] | None, json_output: bool=True, **kwargs) -> GenerationOutput:
         """Take the input record (`candidate`), generate the output `fields`, and return the generated output."""
         client = self._get_client_or_model()
+        logger.debug(f"Generating for candidate {candidate} with fields {fields}")
 
         # fields can only be None if the user provides an answer parser
-
+        fields_check = fields is not None or "parse_answer" in kwargs
+        assert fields_check, "`fields` must be provided if `parse_answer` function is not provided in kwargs."
 
         # if the user (or operator) provides a system prompt instead of a prompt, treat this as
         # the prompt and print a warning
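Note: the rewritten parsers share a three-stage fallback: (1) take the text between "ANSWER:" and "---", (2) take everything after "ANSWER:", (3) take the entire completion, raising only if the last attempt fails. For filters, `_check_filter_answer_text` simply scans for "true"/"false". A self-contained illustration of the filter path (the completion text is made up; the helper mirrors the method above):

import re

def check_filter_answer_text(answer_text: str) -> dict | None:
    # mirrors BaseGenerator._check_filter_answer_text
    if "true" in answer_text.lower():
        return {"passed_operator": True}
    elif "false" in answer_text.lower():
        return {"passed_operator": False}
    return None

completion = "REASONING: the record describes a dog, not a cat.\nANSWER: FALSE\n---"
matches = re.compile("answer:(.*?)---", re.IGNORECASE | re.DOTALL).findall(completion)
print(check_filter_answer_text(matches[0]))  # {'passed_operator': False}

Note that "true" is checked before "false", so a completion containing both strings passes the filter.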
@@ -238,10 +353,10 @@ class BaseGenerator(Generic[ContextType, InputType], ABC):
             warnings.warn("Provided `system_prompt` without providing `prompt`; setting `prompt` = `system_prompt`.")  # noqa: B028
 
         # generate a list of messages which can be used to construct a payload
-
+        messages = self.prompt_factory.create_messages(candidate, fields, **kwargs)
 
         # create the chat payload
-        chat_payload = self._generate_payload(
+        chat_payload = self._generate_payload(messages, **kwargs)
 
         # generate the text completion
         start_time = time.time()
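Note: `__call__` now builds the message list locally and threads it through to the return value. Per the docstring added to `TogetherGenerator._generate_payload` later in this diff, each message is a dict with `role`, `type`, and `content` keys; an illustrative (made-up) example:

messages = [
    {"role": "system", "type": "text", "content": "You are a helpful assistant..."},
    {"role": "user", "type": "text", "content": "Context: ...\n\nQuestion: ..."},
    # vision inputs arrive as image-typed messages:
    # {"role": "user", "type": "image", "content": "<image payload>"},
]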
@@ -249,16 +364,20 @@
         try:
             completion = self._generate_completion(client, chat_payload, **kwargs)
             end_time = time.time()
-
+            logger.debug(f"Generated completion in {end_time - start_time:.2f} seconds")
         # if there's an error generating the completion, we have to return an empty answer
         # and can only account for the time spent performing the failed generation
         except Exception as e:
-
+            logger.error(f"Error generating completion: {e}")
             field_answers = {field_name: None for field_name in fields}
             reasoning = None
-            generation_stats = GenerationStats(
+            generation_stats = GenerationStats(
+                model_name=self.model_name,
+                llm_call_duration_secs=time.time() - start_time,
+                total_llm_calls=1,
+            )
 
-            return field_answers, reasoning, generation_stats
+            return field_answers, reasoning, generation_stats, messages
 
         # parse usage statistics and create the GenerationStats
         generation_stats = None
@@ -282,6 +401,7 @@
             total_input_cost=input_tokens * usd_per_input_token,
             total_output_cost=output_tokens * usd_per_output_token,
             cost_per_record=input_tokens * usd_per_input_token + output_tokens * usd_per_output_token,
+            total_llm_calls=1,
             # "system_prompt": system_prompt,
             # "prompt": prompt,
             # "usage": usage,
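Note: `GenerationStats` now records `total_llm_calls` alongside the token costs. The cost arithmetic in this hunk, worked with hypothetical token counts and per-token rates (for illustration only):

# hypothetical usage and rates
input_tokens, output_tokens = 1_200, 250
usd_per_input_token, usd_per_output_token = 0.15 / 1e6, 0.60 / 1e6

total_input_cost = input_tokens * usd_per_input_token     # 0.00018
total_output_cost = output_tokens * usd_per_output_token  # 0.00015
cost_per_record = total_input_cost + total_output_cost    # 0.00033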
@@ -292,46 +412,65 @@
 
         # pretty print prompt + full completion output for debugging
         completion_text = self._get_completion_text(completion, **kwargs)
-
-
-
-            if message["
-
-
-            print(Fore.GREEN + f"{completion_text}\n" + Style.RESET_ALL)
+        prompt = ""
+        for message in messages:
+            if message["role"] == "user":
+                prompt += message["content"] + "\n" if message["type"] == "text" else "<image>\n"
+        logger.debug(f"PROMPT:\n{prompt}")
+        logger.debug(Fore.GREEN + f"{completion_text}\n" + Style.RESET_ALL)
 
         # parse reasoning
         reasoning = None
         try:
             reasoning = self._parse_reasoning(completion_text, **kwargs)
         except Exception as e:
-
+            logger.error(f"Error parsing reasoning and answers: {e}")
 
         # parse field answers
         field_answers = None if fields is None else {field_name: None for field_name in fields}
         try:
-            field_answers = self._parse_answer(completion_text, fields, **kwargs)
+            field_answers = self._parse_answer(completion_text, fields, json_output, **kwargs)
         except Exception as e:
-
-
-
+            logger.error(f"Error parsing answers: {e}")
+            os.makedirs("parse-answer-errors", exist_ok=True)
+            ts = time.time()
+            with open(f"parse-answer-errors/error-{ts}.txt", "w") as f:
+                f.write(f"{str(self.model_name)}\n")
+                f.write("#####\n")
+                f.write(f"{str(self.prompt_strategy)}\n")
+                f.write("#####\n")
+                f.write(f"{str(completion_text)}\n")
+                f.write("#####\n")
+                f.write(f"{str(fields)}\n")
+                f.write("#####\n")
+                f.write(f"{str(e)}\n")
+
+        logger.debug(f"Generated field answers: {field_answers}")
+        return field_answers, reasoning, generation_stats, messages
 
 
 class OpenAIGenerator(BaseGenerator[str | list[str], str]):
     """
     Class for generating text using the OpenAI chat API.
     """
-    def __init__(self, model: Model, prompt_strategy: PromptStrategy, cardinality: Cardinality = Cardinality.ONE_TO_ONE, verbose: bool = False):
+
+    def __init__(
+        self,
+        model: Model,
+        prompt_strategy: PromptStrategy,
+        cardinality: Cardinality = Cardinality.ONE_TO_ONE,
+        verbose: bool = False,
+    ):
         # assert that model is an OpenAI model
         assert model in [Model.GPT_4o, Model.GPT_4o_MINI, Model.GPT_4o_V, Model.GPT_4o_MINI_V]
         super().__init__(model, prompt_strategy, cardinality, verbose, "developer")
 
     def _get_client_or_model(self, **kwargs) -> OpenAI:
         """Returns a client (or local model) which can be invoked to perform the generation."""
-        return
+        return APIClientFactory.get_client(APIClient.OPENAI, get_api_key("OPENAI_API_KEY"))
 
     def _generate_completion(self, client: OpenAI, payload: dict, **kwargs) -> ChatCompletion:
-        """Generates a completion object using the client (or local model)."""
+        """Generates a completion object using the client (or local model)."""
         return client.chat.completions.create(**payload)
 
     def _get_completion_text(self, completion: ChatCompletion, **kwargs) -> str:
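Note: both generators now obtain their API clients from the new `APIClientFactory` (added in `palimpzest/query/generators/api_client_factory.py`, whose body is not shown in this excerpt) rather than constructing a client per call. A plausible motivation is client reuse across generations; the following is a hypothetical sketch of such a factory, not the actual implementation, and the `APIClient` enum shape is assumed:

from enum import Enum

from openai import OpenAI
from together import Together

class APIClient(str, Enum):
    # assumed shape of palimpzest.constants.APIClient
    OPENAI = "openai"
    TOGETHER = "together"

class APIClientFactory:
    """Hypothetical sketch: cache one client per (provider, api_key) pair."""
    _clients: dict = {}
    _constructors = {APIClient.OPENAI: OpenAI, APIClient.TOGETHER: Together}

    @classmethod
    def get_client(cls, provider: APIClient, api_key: str):
        key = (provider, api_key)
        if key not in cls._clients:
            cls._clients[key] = cls._constructors[provider](api_key=api_key)
        return cls._clients[key]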
@@ -358,14 +497,56 @@ class TogetherGenerator(BaseGenerator[str | list[str], str]):
     """
     Class for generating text using the Together chat API.
     """
-    def __init__(self, model: Model, prompt_strategy: PromptStrategy, cardinality: Cardinality = Cardinality.ONE_TO_ONE, verbose: bool = False):
+
+    def __init__(
+        self,
+        model: Model,
+        prompt_strategy: PromptStrategy,
+        cardinality: Cardinality = Cardinality.ONE_TO_ONE,
+        verbose: bool = False,
+    ):
         # assert that model is a model offered by Together
-        assert model in [Model.MIXTRAL, Model.LLAMA3, Model.LLAMA3_V]
+        assert model in [Model.MIXTRAL, Model.LLAMA3, Model.LLAMA3_V, Model.DEEPSEEK]
         super().__init__(model, prompt_strategy, cardinality, verbose, "system")
 
+    def _generate_payload(self, messages: list[dict], **kwargs) -> dict:
+        """
+        Generates the payload which will be fed into the client (or local model).
+
+        Each message will be a dictionary with the following format:
+        {
+            "role": "user" | "system",
+            "type": "text" | "image",
+            "content": str
+        }
+
+        For LLAMA3, the payload needs to be in a {"role": <role>, "content": <content>} format.
+        """
+        # for other models, use our standard payload generation
+        if self.model != Model.LLAMA3:
+            return super()._generate_payload(messages, **kwargs)
+
+        # get basic parameters
+        model = self.model_name
+        temperature = kwargs.get("temperature", 0.0)
+
+        # construct messages in simple {"role": <role>, "content": <content>} format
+        chat_messages = []
+        for message in messages:
+            chat_messages.append({"role": message["role"], "content": message["content"]})
+
+        # construct and return payload
+        payload = {
+            "model": model,
+            "temperature": temperature,
+            "messages": chat_messages,
+        }
+
+        return payload
+
     def _get_client_or_model(self, **kwargs) -> Together:
         """Returns a client (or local model) which can be invoked to perform the generation."""
-        return
+        return APIClientFactory.get_client(APIClient.TOGETHER, get_api_key("TOGETHER_API_KEY"))
 
     def _generate_completion(self, client: Together, payload: dict, **kwargs) -> ChatCompletionResponse:
         """Generates a completion object using the client (or local model)."""
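Note: for LLAMA3, `_generate_payload` drops the `type` key and keeps only role/content pairs. An illustration of the flattening, with made-up message contents:

messages = [
    {"role": "system", "type": "text", "content": "You are a helpful assistant."},
    {"role": "user", "type": "text", "content": "Extract the title."},
]
chat_messages = [{"role": m["role"], "content": m["content"]} for m in messages]
# [{'role': 'system', 'content': 'You are a helpful assistant.'},
#  {'role': 'user', 'content': 'Extract the title.'}]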
@@ -391,7 +572,6 @@ class TogetherGenerator(BaseGenerator[str | list[str], str]):
         return completion.choices[0].logprobs
 
 
-
 ### CODE SYNTHESIS EXECUTION ###
 def code_execution(api: API, code: str, candidate_dict: dict[str, Any], verbose: bool = False):
     inputs = {field_name: candidate_dict[field_name] for field_name in api.inputs}
palimpzest/query/operators/__init__.py

@@ -5,7 +5,6 @@ from palimpzest.query.operators.aggregate import CountAggregateOp as _CountAggre
 from palimpzest.query.operators.convert import ConvertOp as _ConvertOp
 from palimpzest.query.operators.convert import LLMConvert as _LLMConvert
 from palimpzest.query.operators.convert import LLMConvertBonded as _LLMConvertBonded
-from palimpzest.query.operators.convert import LLMConvertConventional as _LLMConvertConventional
 from palimpzest.query.operators.convert import NonLLMConvert as _NonLLMConvert
 from palimpzest.query.operators.filter import FilterOp as _FilterOp
 from palimpzest.query.operators.filter import LLMFilter as _LLMFilter
@@ -66,7 +65,7 @@ PHYSICAL_OPERATORS = (
     # aggregate
     [_AggregateOp, _ApplyGroupByOp, _AverageAggregateOp, _CountAggregateOp]
     # convert
-    + [_ConvertOp, _NonLLMConvert, _LLMConvert, _LLMConvertBonded, _LLMConvertConventional]
+    + [_ConvertOp, _NonLLMConvert, _LLMConvert, _LLMConvertBonded]
     # scan
     + [_ScanPhysicalOp, _MarshalAndScanDataOp, _CacheScanDataOp]
     # filter
|