palimpzest 0.6.4__py3-none-any.whl → 0.7.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (64) hide show
  1. palimpzest/__init__.py +5 -0
  2. palimpzest/constants.py +110 -43
  3. palimpzest/core/__init__.py +0 -78
  4. palimpzest/core/data/dataclasses.py +382 -44
  5. palimpzest/core/elements/filters.py +7 -3
  6. palimpzest/core/elements/index.py +70 -0
  7. palimpzest/core/elements/records.py +33 -11
  8. palimpzest/core/lib/fields.py +1 -0
  9. palimpzest/core/lib/schemas.py +4 -3
  10. palimpzest/prompts/moa_proposer_convert_prompts.py +0 -4
  11. palimpzest/prompts/prompt_factory.py +44 -7
  12. palimpzest/prompts/split_merge_prompts.py +56 -0
  13. palimpzest/prompts/split_proposer_prompts.py +55 -0
  14. palimpzest/query/execution/execution_strategy.py +435 -53
  15. palimpzest/query/execution/execution_strategy_type.py +20 -0
  16. palimpzest/query/execution/mab_execution_strategy.py +532 -0
  17. palimpzest/query/execution/parallel_execution_strategy.py +143 -172
  18. palimpzest/query/execution/random_sampling_execution_strategy.py +240 -0
  19. palimpzest/query/execution/single_threaded_execution_strategy.py +173 -203
  20. palimpzest/query/generators/api_client_factory.py +31 -0
  21. palimpzest/query/generators/generators.py +256 -76
  22. palimpzest/query/operators/__init__.py +1 -2
  23. palimpzest/query/operators/code_synthesis_convert.py +33 -18
  24. palimpzest/query/operators/convert.py +30 -97
  25. palimpzest/query/operators/critique_and_refine_convert.py +5 -6
  26. palimpzest/query/operators/filter.py +7 -10
  27. palimpzest/query/operators/logical.py +54 -10
  28. palimpzest/query/operators/map.py +130 -0
  29. palimpzest/query/operators/mixture_of_agents_convert.py +6 -6
  30. palimpzest/query/operators/physical.py +3 -12
  31. palimpzest/query/operators/rag_convert.py +66 -18
  32. palimpzest/query/operators/retrieve.py +230 -34
  33. palimpzest/query/operators/scan.py +5 -2
  34. palimpzest/query/operators/split_convert.py +169 -0
  35. palimpzest/query/operators/token_reduction_convert.py +8 -14
  36. palimpzest/query/optimizer/__init__.py +4 -16
  37. palimpzest/query/optimizer/cost_model.py +73 -266
  38. palimpzest/query/optimizer/optimizer.py +87 -58
  39. palimpzest/query/optimizer/optimizer_strategy.py +18 -97
  40. palimpzest/query/optimizer/optimizer_strategy_type.py +37 -0
  41. palimpzest/query/optimizer/plan.py +2 -3
  42. palimpzest/query/optimizer/primitives.py +5 -3
  43. palimpzest/query/optimizer/rules.py +336 -172
  44. palimpzest/query/optimizer/tasks.py +30 -100
  45. palimpzest/query/processor/config.py +38 -22
  46. palimpzest/query/processor/nosentinel_processor.py +16 -520
  47. palimpzest/query/processor/processing_strategy_type.py +28 -0
  48. palimpzest/query/processor/query_processor.py +38 -206
  49. palimpzest/query/processor/query_processor_factory.py +117 -130
  50. palimpzest/query/processor/sentinel_processor.py +90 -0
  51. palimpzest/query/processor/streaming_processor.py +25 -32
  52. palimpzest/sets.py +88 -41
  53. palimpzest/utils/model_helpers.py +8 -7
  54. palimpzest/utils/progress.py +368 -152
  55. palimpzest/utils/token_reduction_helpers.py +1 -3
  56. {palimpzest-0.6.4.dist-info → palimpzest-0.7.0.dist-info}/METADATA +19 -9
  57. palimpzest-0.7.0.dist-info/RECORD +96 -0
  58. {palimpzest-0.6.4.dist-info → palimpzest-0.7.0.dist-info}/WHEEL +1 -1
  59. palimpzest/query/processor/mab_sentinel_processor.py +0 -884
  60. palimpzest/query/processor/random_sampling_sentinel_processor.py +0 -639
  61. palimpzest/utils/index_helpers.py +0 -6
  62. palimpzest-0.6.4.dist-info/RECORD +0 -87
  63. {palimpzest-0.6.4.dist-info → palimpzest-0.7.0.dist-info/licenses}/LICENSE +0 -0
  64. {palimpzest-0.6.4.dist-info → palimpzest-0.7.0.dist-info}/top_level.txt +0 -0
@@ -1,8 +1,10 @@
1
1
  """
2
2
  This file contains the Generator classes and generator factory.
3
3
  """
4
+
4
5
  from __future__ import annotations
5
6
 
7
+ import logging
6
8
  import os
7
9
  import re
8
10
  import time
@@ -15,40 +17,42 @@ from typing import Any, Generic, TypeVar
15
17
  from colorama import Fore, Style
16
18
  from openai import OpenAI
17
19
  from openai.types.chat.chat_completion import ChatCompletion
18
-
19
- # from tenacity import retry, stop_after_attempt, wait_exponential
20
20
  from together import Together
21
21
  from together.types.chat_completions import ChatCompletionResponse
22
22
 
23
23
  from palimpzest.constants import (
24
24
  MODEL_CARDS,
25
- # RETRY_MAX_ATTEMPTS,
26
- # RETRY_MAX_SECS,
27
- # RETRY_MULTIPLIER,
25
+ APIClient,
28
26
  Cardinality,
29
27
  Model,
30
28
  PromptStrategy,
31
29
  )
32
30
  from palimpzest.core.data.dataclasses import GenerationStats
33
31
  from palimpzest.core.elements.records import DataRecord
32
+ from palimpzest.core.lib.fields import Field, ListField
34
33
  from palimpzest.prompts import PromptFactory
34
+ from palimpzest.query.generators.api_client_factory import APIClientFactory
35
35
  from palimpzest.utils.generation_helpers import get_json_from_answer
36
36
  from palimpzest.utils.sandbox import API
37
37
 
38
38
  # DEFINITIONS
39
- GenerationOutput = tuple[dict, str | None, GenerationStats]
39
+ GenerationOutput = tuple[dict, str | None, GenerationStats, list[dict]]
40
40
  ContextType = TypeVar("ContextType")
41
41
  InputType = TypeVar("InputType")
42
42
 
43
43
 
44
- def generator_factory(model: Model, prompt_strategy: PromptStrategy, cardinality: Cardinality, verbose: bool = False) -> BaseGenerator:
44
+ logger = logging.getLogger(__name__)
45
+
46
+ def generator_factory(
47
+ model: Model, prompt_strategy: PromptStrategy, cardinality: Cardinality, verbose: bool = False
48
+ ) -> BaseGenerator:
45
49
  """
46
50
  Factory function to return the correct generator based on the model, strategy, and cardinality.
47
51
  """
48
52
  if model in [Model.GPT_4o, Model.GPT_4o_MINI, Model.GPT_4o_V, Model.GPT_4o_MINI_V]:
49
53
  return OpenAIGenerator(model, prompt_strategy, cardinality, verbose)
50
54
 
51
- elif model in [Model.MIXTRAL, Model.LLAMA3, Model.LLAMA3_V]:
55
+ elif model in [Model.MIXTRAL, Model.LLAMA3, Model.LLAMA3_V, Model.DEEPSEEK]:
52
56
  return TogetherGenerator(model, prompt_strategy, cardinality, verbose)
53
57
 
54
58
  raise Exception(f"Unsupported model: {model}")
@@ -69,7 +73,15 @@ class BaseGenerator(Generic[ContextType, InputType], ABC):
69
73
  """
70
74
  Abstract base class for Generators.
71
75
  """
72
- def __init__(self, model: Model, prompt_strategy: PromptStrategy, cardinality: Cardinality = Cardinality.ONE_TO_ONE, verbose: bool = False, system_role: str = "system"):
76
+
77
+ def __init__(
78
+ self,
79
+ model: Model,
80
+ prompt_strategy: PromptStrategy,
81
+ cardinality: Cardinality = Cardinality.ONE_TO_ONE,
82
+ verbose: bool = False,
83
+ system_role: str = "system",
84
+ ):
73
85
  self.model = model
74
86
  self.model_name = model.value
75
87
  self.cardinality = cardinality
@@ -77,11 +89,6 @@ class BaseGenerator(Generic[ContextType, InputType], ABC):
77
89
  self.verbose = verbose
78
90
  self.system_role = system_role
79
91
  self.prompt_factory = PromptFactory(prompt_strategy, model, cardinality)
80
- self.messages = None
81
-
82
- def get_messages(self) -> list[dict] | None:
83
- """Returns the messages used in the last generation."""
84
- return self.messages
85
92
 
86
93
  @abstractmethod
87
94
  def _get_client_or_model(self, **kwargs) -> Any:
@@ -160,7 +167,7 @@ class BaseGenerator(Generic[ContextType, InputType], ABC):
160
167
 
161
168
  return payload
162
169
 
163
- def _parse_reasoning(self, completion_text: str, **kwargs) -> Any:
170
+ def _parse_reasoning(self, completion_text: str, **kwargs) -> str:
164
171
  """Extract the reasoning for the generated output from the completion object."""
165
172
  # use a custom reasoning parser if provided
166
173
  if kwargs.get("parse_reasoning"):
@@ -169,66 +176,174 @@ class BaseGenerator(Generic[ContextType, InputType], ABC):
169
176
 
170
177
  # if the model followed the default instructions, the completion text will have reasoning
171
178
  # before the "ANSWER:"; if this is the case, we simply extract and return that full section
172
- regex = re.compile("(.*?)answer:.*", re.IGNORECASE | re.DOTALL)
179
+ if "answer" in completion_text.lower():
180
+ regex = re.compile("(.*?)answer:.*", re.IGNORECASE | re.DOTALL)
181
+ matches = regex.findall(completion_text)
182
+ if len(matches) > 0:
183
+ return matches[0].strip()
184
+
185
+ # otherwise, return the full completion text
186
+ return completion_text
187
+
188
+ def _prepare_field_answers(self, field_answers: dict | list[dict], fields: dict[str, Field]) -> dict[str, list]:
189
+ """
190
+ field_answers is a dictionary mapping fields to their values. For one-to-one converts, wrap each
191
+ answer in a list. For one-to-many converts, invert the list of dictionaries into a dictionary with
192
+ list values.
193
+ """
194
+ # if this is a one-to-one convert, we need to wrap each answer in a list
195
+ if self.cardinality == Cardinality.ONE_TO_ONE:
196
+ field_answers = {field_name: [field_answers[field_name]] for field_name in fields}
197
+
198
+ # otherwise, we need to invert the list of dictionaries into a dictionary with list values
199
+ else:
200
+ field_answers_lst: list[dict] = deepcopy(field_answers)
201
+
202
+ field_answers = {field_name: [] for field_name in fields}
203
+ for answer_dict in field_answers_lst:
204
+ for field_name in fields:
205
+ answer = answer_dict.get(field_name, None)
206
+ field_answers[field_name].append(answer)
207
+
208
+ return field_answers
209
+
210
+ def _check_convert_answer_text(self, answer_text: str, fields: dict[str, Field], throw_exception: bool=False) -> dict | list[dict] | None:
211
+ """
212
+ Try parsing the answer text into a JSON object. If the parsing fails, return None.
213
+ """
214
+ try:
215
+ # extract json from the answer text
216
+ field_answers = get_json_from_answer(answer_text, self.model, self.cardinality)
217
+
218
+ # TODO: wrap non-list outputs in a list if expected output is a list
219
+
220
+ # common error: if the output is a singleton list which contains a list, but the expected field type
221
+ # is a list of strings, or a list of floats, i.e. not a list of lists; then extract the inner list
222
+ for field, field_type in fields.items():
223
+ answer = field_answers[field]
224
+ field_type_is_not_list_of_lists = isinstance(field_type, ListField) and not issubclass(field_type.element_type, ListField)
225
+ answer_is_list_of_lists = isinstance(answer, list) and len(answer) == 1 and isinstance(answer[0], list)
226
+ if field_type_is_not_list_of_lists and answer_is_list_of_lists:
227
+ field_answers[field] = answer[0]
228
+
229
+ # prepare the field answers to match the expected output and return
230
+ return self._prepare_field_answers(field_answers, fields)
231
+
232
+ except Exception as e:
233
+ if throw_exception:
234
+ raise e
235
+
236
+ return None
237
+
238
+ def _check_filter_answer_text(self, answer_text: str) -> dict | None:
239
+ """
240
+ Return {"passed_operator": True} if and only if "true" is in the answer text.
241
+ Return {"passed_operator": False} if and only if "false" is in the answer text.
242
+ Otherwise, return None.
243
+ """
244
+ # NOTE: we may be able to eliminate this condition by specifying this JSON output in the prompt;
245
+ # however, that would also need to coincide with a change to allow the parse_answer_fn to set "passed_operator"
246
+ if "true" in answer_text.lower():
247
+ return {"passed_operator": True}
248
+ elif "false" in answer_text.lower():
249
+ return {"passed_operator": False}
250
+
251
+ return None
252
+
253
+ def _parse_convert_answer(self, completion_text: str, fields: dict[str, Field], json_output: bool) -> dict[str, list]:
254
+ """Extract the answer from the completion object for convert operations."""
255
+ # if the model followed the default instructions, the completion text will place
256
+ # its answer between "ANSWER:" and "---"
257
+ regex = re.compile("answer:(.*?)---", re.IGNORECASE | re.DOTALL)
173
258
  matches = regex.findall(completion_text)
174
259
  if len(matches) > 0:
175
- return matches[0].strip()
260
+ answer_text = matches[0].strip()
176
261
 
177
- # otherwise, return None
178
- return None
262
+ # if we don't expect a JSON output, return the answer text as is
263
+ if not json_output:
264
+ return answer_text
179
265
 
180
- def _parse_answer(self, completion_text: str, fields: list[str] | None, **kwargs) -> Any:
181
- """Extract the answer from the completion object."""
182
- # use a custom answer parser if provided
183
- if kwargs.get("parse_answer"):
184
- parse_answer_fn = kwargs.get("parse_answer")
185
- return parse_answer_fn(completion_text)
266
+ # otherwise, try to parse the answer text into a JSON object
267
+ field_answers = self._check_convert_answer_text(answer_text, fields)
268
+ if field_answers is not None:
269
+ return field_answers
186
270
 
271
+ # if the first regex didn't find an answer, try taking all the text after "ANSWER:"
272
+ regex = re.compile("answer:(.*)", re.IGNORECASE | re.DOTALL)
273
+ matches = regex.findall(completion_text)
274
+ if len(matches) > 0:
275
+ answer_text = matches[0].strip()
276
+
277
+ # if we don't expect a JSON output, return the answer text as is
278
+ if not json_output:
279
+ return answer_text
280
+
281
+ # otherwise, try to parse the answer text into a JSON object
282
+ field_answers = self._check_convert_answer_text(answer_text, fields)
283
+ if field_answers is not None:
284
+ return field_answers
285
+
286
+ # finally, try taking all of the text; for JSON output, throw an exception if parsing fails
287
+ if not json_output:
288
+ return completion_text
289
+
290
+ return self._check_convert_answer_text(completion_text, fields, throw_exception=True)
291
+
292
+ def _parse_filter_answer(self, completion_text: str) -> dict[str, list]:
293
+ """Extract the answer from the completion object for filter operations."""
187
294
  # if the model followed the default instructions, the completion text will place
188
295
  # its answer between "ANSWER:" and "---"
189
- answer_text = None
190
296
  regex = re.compile("answer:(.*?)---", re.IGNORECASE | re.DOTALL)
191
297
  matches = regex.findall(completion_text)
192
298
  if len(matches) > 0:
193
299
  answer_text = matches[0].strip()
300
+ field_answers = self._check_filter_answer_text(answer_text)
301
+ if field_answers is not None:
302
+ return field_answers
194
303
 
195
- # otherwise, take all the text after "ANSWER:" (or just all of the text)
196
- else:
197
- regex = re.compile("answer:(.*?)", re.IGNORECASE | re.DOTALL)
198
- matches = regex.findall(completion_text)
199
- answer_text = matches[0].strip() if len(matches) > 0 else completion_text
304
+ # if the first regex didn't find an answer, try taking all the text after "ANSWER:"
305
+ regex = re.compile("answer:(.*)", re.IGNORECASE | re.DOTALL)
306
+ matches = regex.findall(completion_text)
307
+ if len(matches) > 0:
308
+ answer_text = matches[0].strip()
309
+ field_answers = self._check_filter_answer_text(answer_text)
310
+ if field_answers is not None:
311
+ return field_answers
200
312
 
201
- # if this is a filter operator, return True if and only if "true" is in the answer text
202
- # NOTE: we may be able to elimiate this condition by specifying this JSON output in the prompt;
203
- # however, that would also need to coincide with a change to allow the parse_answer_fn to set "passed_operator"
204
- if self.prompt_strategy in [PromptStrategy.COT_BOOL, PromptStrategy.COT_BOOL_IMAGE]:
205
- return {"passed_operator": "true" in answer_text.lower()}
313
+ # finally, try taking all of the text; throw an exception if this doesn't work
314
+ field_answers = self._check_filter_answer_text(completion_text)
315
+ if field_answers is None:
316
+ raise Exception(f"Could not parse answer from completion text: {completion_text}")
206
317
 
207
- # parse the answer text into a JSON object and return it
208
- field_answers = get_json_from_answer(answer_text, self.model, self.cardinality)
318
+ return field_answers
209
319
 
210
- # if this is a one-to-one convert, we need to wrap each answer in a list
211
- if self.cardinality == Cardinality.ONE_TO_ONE:
212
- field_answers = {field_name: [field_answers[field_name]] for field_name in fields}
320
+ def _parse_answer(self, completion_text: str, fields: dict[str, Field] | None, json_output: bool, **kwargs) -> dict[str, list]:
321
+ """Extract the answer from the completion object."""
322
+ # use a custom answer parser if provided
323
+ if kwargs.get("parse_answer"):
324
+ parse_answer_fn = kwargs.get("parse_answer")
325
+ return parse_answer_fn(completion_text)
213
326
 
214
- # otherwise, we need to invert the list of dictionaries into a dictionary with list values
215
- else:
216
- field_answers_lst: list[dict] = deepcopy(field_answers)
327
+ # fields should be a dict if a custom answer parser is not provided
328
+ assert isinstance(fields, dict), "Fields must be provided if a custom answer parser is not provided."
217
329
 
218
- field_answers = {field_name: [] for field_name in fields}
219
- for answer_dict in field_answers_lst:
220
- for field_name in fields:
221
- answer = answer_dict.get(field_name, None)
222
- field_answers[field_name].append(answer)
330
+ # extract the per-field answers from the completion text
331
+ field_answers = (
332
+ self._parse_filter_answer(completion_text)
333
+ if self.prompt_strategy.is_bool_prompt()
334
+ else self._parse_convert_answer(completion_text, fields, json_output)
335
+ )
223
336
 
224
337
  return field_answers
225
338
 
226
- def __call__(self, candidate: DataRecord, fields: list[str] | None, **kwargs) -> GenerationOutput:
339
+ def __call__(self, candidate: DataRecord, fields: dict[str, Field] | None, json_output: bool=True, **kwargs) -> GenerationOutput:
227
340
  """Take the input record (`candidate`), generate the output `fields`, and return the generated output."""
228
341
  client = self._get_client_or_model()
342
+ logger.debug(f"Generating for candidate {candidate} with fields {fields}")
229
343
 
230
344
  # fields can only be None if the user provides an answer parser
231
- assert fields is not None or "parse_answer" in kwargs, "`fields` must be provided if `parse_answer` function is not provided in kwargs."
345
+ fields_check = fields is not None or "parse_answer" in kwargs
346
+ assert fields_check, "`fields` must be provided if `parse_answer` function is not provided in kwargs."
232
347
 
233
348
  # if the user (or operator) provides a system prompt instead of a prompt, treat this as
234
349
  # the prompt and print a warning
@@ -238,10 +353,10 @@ class BaseGenerator(Generic[ContextType, InputType], ABC):
238
353
  warnings.warn("Provided `system_prompt` without providing `prompt`; setting `prompt` = `system_prompt`.") # noqa: B028
239
354
 
240
355
  # generate a list of messages which can be used to construct a payload
241
- self.messages = self.prompt_factory.create_messages(candidate, fields, **kwargs)
356
+ messages = self.prompt_factory.create_messages(candidate, fields, **kwargs)
242
357
 
243
358
  # create the chat payload
244
- chat_payload = self._generate_payload(self.messages, **kwargs)
359
+ chat_payload = self._generate_payload(messages, **kwargs)
245
360
 
246
361
  # generate the text completion
247
362
  start_time = time.time()
@@ -249,16 +364,20 @@ class BaseGenerator(Generic[ContextType, InputType], ABC):
249
364
  try:
250
365
  completion = self._generate_completion(client, chat_payload, **kwargs)
251
366
  end_time = time.time()
252
-
367
+ logger.debug(f"Generated completion in {end_time - start_time:.2f} seconds")
253
368
  # if there's an error generating the completion, we have to return an empty answer
254
369
  # and can only account for the time spent performing the failed generation
255
370
  except Exception as e:
256
- print(f"Error generating completion: {e}")
371
+ logger.error(f"Error generating completion: {e}")
257
372
  field_answers = {field_name: None for field_name in fields}
258
373
  reasoning = None
259
- generation_stats = GenerationStats(model_name=self.model_name, llm_call_duration_secs=time.time() - start_time)
374
+ generation_stats = GenerationStats(
375
+ model_name=self.model_name,
376
+ llm_call_duration_secs=time.time() - start_time,
377
+ total_llm_calls=1,
378
+ )
260
379
 
261
- return field_answers, reasoning, generation_stats
380
+ return field_answers, reasoning, generation_stats, messages
262
381
 
263
382
  # parse usage statistics and create the GenerationStats
264
383
  generation_stats = None
@@ -282,6 +401,7 @@ class BaseGenerator(Generic[ContextType, InputType], ABC):
282
401
  total_input_cost=input_tokens * usd_per_input_token,
283
402
  total_output_cost=output_tokens * usd_per_output_token,
284
403
  cost_per_record=input_tokens * usd_per_input_token + output_tokens * usd_per_output_token,
404
+ total_llm_calls=1,
285
405
  # "system_prompt": system_prompt,
286
406
  # "prompt": prompt,
287
407
  # "usage": usage,
@@ -292,46 +412,65 @@ class BaseGenerator(Generic[ContextType, InputType], ABC):
292
412
 
293
413
  # pretty print prompt + full completion output for debugging
294
414
  completion_text = self._get_completion_text(completion, **kwargs)
295
- if self.verbose:
296
- prompt = ""
297
- for message in self.messages:
298
- if message["role"] == "user":
299
- prompt += message["content"] + "\n" if message["type"] == "text" else "<image>\n"
300
- print(f"PROMPT:\n{prompt}")
301
- print(Fore.GREEN + f"{completion_text}\n" + Style.RESET_ALL)
415
+ prompt = ""
416
+ for message in messages:
417
+ if message["role"] == "user":
418
+ prompt += message["content"] + "\n" if message["type"] == "text" else "<image>\n"
419
+ logger.debug(f"PROMPT:\n{prompt}")
420
+ logger.debug(Fore.GREEN + f"{completion_text}\n" + Style.RESET_ALL)
302
421
 
303
422
  # parse reasoning
304
423
  reasoning = None
305
424
  try:
306
425
  reasoning = self._parse_reasoning(completion_text, **kwargs)
307
426
  except Exception as e:
308
- print(f"Error parsing reasoning and answers: {e}")
427
+ logger.error(f"Error parsing reasoning and answers: {e}")
309
428
 
310
429
  # parse field answers
311
430
  field_answers = None if fields is None else {field_name: None for field_name in fields}
312
431
  try:
313
- field_answers = self._parse_answer(completion_text, fields, **kwargs)
432
+ field_answers = self._parse_answer(completion_text, fields, json_output, **kwargs)
314
433
  except Exception as e:
315
- print(f"Error parsing answers: {e}")
316
-
317
- return field_answers, reasoning, generation_stats
434
+ logger.error(f"Error parsing answers: {e}")
435
+ os.makedirs("parse-answer-errors", exist_ok=True)
436
+ ts = time.time()
437
+ with open(f"parse-answer-errors/error-{ts}.txt", "w") as f:
438
+ f.write(f"{str(self.model_name)}\n")
439
+ f.write("#####\n")
440
+ f.write(f"{str(self.prompt_strategy)}\n")
441
+ f.write("#####\n")
442
+ f.write(f"{str(completion_text)}\n")
443
+ f.write("#####\n")
444
+ f.write(f"{str(fields)}\n")
445
+ f.write("#####\n")
446
+ f.write(f"{str(e)}\n")
447
+
448
+ logger.debug(f"Generated field answers: {field_answers}")
449
+ return field_answers, reasoning, generation_stats, messages
318
450
 
319
451
 
320
452
  class OpenAIGenerator(BaseGenerator[str | list[str], str]):
321
453
  """
322
454
  Class for generating text using the OpenAI chat API.
323
455
  """
324
- def __init__(self, model: Model, prompt_strategy: PromptStrategy, cardinality: Cardinality = Cardinality.ONE_TO_ONE, verbose: bool = False):
456
+
457
+ def __init__(
458
+ self,
459
+ model: Model,
460
+ prompt_strategy: PromptStrategy,
461
+ cardinality: Cardinality = Cardinality.ONE_TO_ONE,
462
+ verbose: bool = False,
463
+ ):
325
464
  # assert that model is an OpenAI model
326
465
  assert model in [Model.GPT_4o, Model.GPT_4o_MINI, Model.GPT_4o_V, Model.GPT_4o_MINI_V]
327
466
  super().__init__(model, prompt_strategy, cardinality, verbose, "developer")
328
467
 
329
468
  def _get_client_or_model(self, **kwargs) -> OpenAI:
330
469
  """Returns a client (or local model) which can be invoked to perform the generation."""
331
- return OpenAI(api_key=get_api_key("OPENAI_API_KEY"))
470
+ return APIClientFactory.get_client(APIClient.OPENAI, get_api_key("OPENAI_API_KEY"))
332
471
 
333
472
  def _generate_completion(self, client: OpenAI, payload: dict, **kwargs) -> ChatCompletion:
334
- """Generates a completion object using the client (or local model)."""
473
+ """Generates a completion object using the client (or local model)."""
335
474
  return client.chat.completions.create(**payload)
336
475
 
337
476
  def _get_completion_text(self, completion: ChatCompletion, **kwargs) -> str:
@@ -358,14 +497,56 @@ class TogetherGenerator(BaseGenerator[str | list[str], str]):
358
497
  """
359
498
  Class for generating text using the Together chat API.
360
499
  """
361
- def __init__(self, model: Model, prompt_strategy: PromptStrategy, cardinality: Cardinality = Cardinality.ONE_TO_ONE, verbose: bool = False):
500
+
501
+ def __init__(
502
+ self,
503
+ model: Model,
504
+ prompt_strategy: PromptStrategy,
505
+ cardinality: Cardinality = Cardinality.ONE_TO_ONE,
506
+ verbose: bool = False,
507
+ ):
362
508
  # assert that model is a model offered by Together
363
- assert model in [Model.MIXTRAL, Model.LLAMA3, Model.LLAMA3_V]
509
+ assert model in [Model.MIXTRAL, Model.LLAMA3, Model.LLAMA3_V, Model.DEEPSEEK]
364
510
  super().__init__(model, prompt_strategy, cardinality, verbose, "system")
365
511
 
512
+ def _generate_payload(self, messages: list[dict], **kwargs) -> dict:
513
+ """
514
+ Generates the payload which will be fed into the client (or local model).
515
+
516
+ Each message will be a dictionary with the following format:
517
+ {
518
+ "role": "user" | "system",
519
+ "type": "text" | "image",
520
+ "content": str
521
+ }
522
+
523
+ For LLAMA3, the payload needs to be in a {"role": <role>, "content": <content>} format.
524
+ """
525
+ # for other models, use our standard payload generation
526
+ if self.model != Model.LLAMA3:
527
+ return super()._generate_payload(messages, **kwargs)
528
+
529
+ # get basic parameters
530
+ model = self.model_name
531
+ temperature = kwargs.get("temperature", 0.0)
532
+
533
+ # construct messages in simple {"role": <role>, "content": <content>} format
534
+ chat_messages = []
535
+ for message in messages:
536
+ chat_messages.append({"role": message["role"], "content": message["content"]})
537
+
538
+ # construct and return payload
539
+ payload = {
540
+ "model": model,
541
+ "temperature": temperature,
542
+ "messages": chat_messages,
543
+ }
544
+
545
+ return payload
546
+
366
547
  def _get_client_or_model(self, **kwargs) -> Together:
367
548
  """Returns a client (or local model) which can be invoked to perform the generation."""
368
- return Together(api_key=get_api_key("TOGETHER_API_KEY"))
549
+ return APIClientFactory.get_client(APIClient.TOGETHER, get_api_key("TOGETHER_API_KEY"))
369
550
 
370
551
  def _generate_completion(self, client: Together, payload: dict, **kwargs) -> ChatCompletionResponse:
371
552
  """Generates a completion object using the client (or local model)."""
@@ -391,7 +572,6 @@ class TogetherGenerator(BaseGenerator[str | list[str], str]):
391
572
  return completion.choices[0].logprobs
392
573
 
393
574
 
394
-
395
575
  ### CODE SYNTHESIS EXECUTION ###
396
576
  def code_execution(api: API, code: str, candidate_dict: dict[str, Any], verbose: bool = False):
397
577
  inputs = {field_name: candidate_dict[field_name] for field_name in api.inputs}
@@ -5,7 +5,6 @@ from palimpzest.query.operators.aggregate import CountAggregateOp as _CountAggre
5
5
  from palimpzest.query.operators.convert import ConvertOp as _ConvertOp
6
6
  from palimpzest.query.operators.convert import LLMConvert as _LLMConvert
7
7
  from palimpzest.query.operators.convert import LLMConvertBonded as _LLMConvertBonded
8
- from palimpzest.query.operators.convert import LLMConvertConventional as _LLMConvertConventional
9
8
  from palimpzest.query.operators.convert import NonLLMConvert as _NonLLMConvert
10
9
  from palimpzest.query.operators.filter import FilterOp as _FilterOp
11
10
  from palimpzest.query.operators.filter import LLMFilter as _LLMFilter
@@ -66,7 +65,7 @@ PHYSICAL_OPERATORS = (
66
65
  # aggregate
67
66
  [_AggregateOp, _ApplyGroupByOp, _AverageAggregateOp, _CountAggregateOp]
68
67
  # convert
69
- + [_ConvertOp, _NonLLMConvert, _LLMConvert, _LLMConvertConventional, _LLMConvertBonded]
68
+ + [_ConvertOp, _NonLLMConvert, _LLMConvert, _LLMConvertBonded]
70
69
  # scan
71
70
  + [_ScanPhysicalOp, _MarshalAndScanDataOp, _CacheScanDataOp]
72
71
  # filter