palimpzest 0.7.21__py3-none-any.whl → 0.8.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (89)
  1. palimpzest/__init__.py +37 -6
  2. palimpzest/agents/__init__.py +0 -0
  3. palimpzest/agents/compute_agents.py +0 -0
  4. palimpzest/agents/search_agents.py +637 -0
  5. palimpzest/constants.py +343 -209
  6. palimpzest/core/data/context.py +393 -0
  7. palimpzest/core/data/context_manager.py +163 -0
  8. palimpzest/core/data/dataset.py +639 -0
  9. palimpzest/core/data/{datareaders.py → iter_dataset.py} +202 -126
  10. palimpzest/core/elements/groupbysig.py +16 -13
  11. palimpzest/core/elements/records.py +166 -75
  12. palimpzest/core/lib/schemas.py +152 -390
  13. palimpzest/core/{data/dataclasses.py → models.py} +306 -170
  14. palimpzest/policy.py +2 -27
  15. palimpzest/prompts/__init__.py +35 -5
  16. palimpzest/prompts/agent_prompts.py +357 -0
  17. palimpzest/prompts/context_search.py +9 -0
  18. palimpzest/prompts/convert_prompts.py +62 -6
  19. palimpzest/prompts/filter_prompts.py +51 -6
  20. palimpzest/prompts/join_prompts.py +163 -0
  21. palimpzest/prompts/moa_proposer_convert_prompts.py +6 -6
  22. palimpzest/prompts/prompt_factory.py +375 -47
  23. palimpzest/prompts/split_proposer_prompts.py +1 -1
  24. palimpzest/prompts/util_phrases.py +5 -0
  25. palimpzest/prompts/validator.py +239 -0
  26. palimpzest/query/execution/all_sample_execution_strategy.py +134 -76
  27. palimpzest/query/execution/execution_strategy.py +210 -317
  28. palimpzest/query/execution/execution_strategy_type.py +5 -7
  29. palimpzest/query/execution/mab_execution_strategy.py +249 -136
  30. palimpzest/query/execution/parallel_execution_strategy.py +153 -244
  31. palimpzest/query/execution/single_threaded_execution_strategy.py +107 -64
  32. palimpzest/query/generators/generators.py +160 -331
  33. palimpzest/query/operators/__init__.py +15 -5
  34. palimpzest/query/operators/aggregate.py +50 -33
  35. palimpzest/query/operators/compute.py +201 -0
  36. palimpzest/query/operators/convert.py +33 -19
  37. palimpzest/query/operators/critique_and_refine_convert.py +7 -5
  38. palimpzest/query/operators/distinct.py +62 -0
  39. palimpzest/query/operators/filter.py +26 -16
  40. palimpzest/query/operators/join.py +403 -0
  41. palimpzest/query/operators/limit.py +3 -3
  42. palimpzest/query/operators/logical.py +205 -77
  43. palimpzest/query/operators/mixture_of_agents_convert.py +10 -8
  44. palimpzest/query/operators/physical.py +27 -21
  45. palimpzest/query/operators/project.py +3 -3
  46. palimpzest/query/operators/rag_convert.py +7 -7
  47. palimpzest/query/operators/retrieve.py +9 -9
  48. palimpzest/query/operators/scan.py +81 -42
  49. palimpzest/query/operators/search.py +524 -0
  50. palimpzest/query/operators/split_convert.py +10 -8
  51. palimpzest/query/optimizer/__init__.py +7 -9
  52. palimpzest/query/optimizer/cost_model.py +108 -441
  53. palimpzest/query/optimizer/optimizer.py +123 -181
  54. palimpzest/query/optimizer/optimizer_strategy.py +66 -61
  55. palimpzest/query/optimizer/plan.py +352 -67
  56. palimpzest/query/optimizer/primitives.py +43 -19
  57. palimpzest/query/optimizer/rules.py +484 -646
  58. palimpzest/query/optimizer/tasks.py +127 -58
  59. palimpzest/query/processor/config.py +42 -76
  60. palimpzest/query/processor/query_processor.py +73 -18
  61. palimpzest/query/processor/query_processor_factory.py +46 -38
  62. palimpzest/schemabuilder/schema_builder.py +15 -28
  63. palimpzest/utils/model_helpers.py +32 -77
  64. palimpzest/utils/progress.py +114 -102
  65. palimpzest/validator/__init__.py +0 -0
  66. palimpzest/validator/validator.py +306 -0
  67. {palimpzest-0.7.21.dist-info → palimpzest-0.8.1.dist-info}/METADATA +6 -1
  68. palimpzest-0.8.1.dist-info/RECORD +95 -0
  69. palimpzest/core/lib/fields.py +0 -141
  70. palimpzest/prompts/code_synthesis_prompts.py +0 -28
  71. palimpzest/query/execution/random_sampling_execution_strategy.py +0 -240
  72. palimpzest/query/generators/api_client_factory.py +0 -30
  73. palimpzest/query/operators/code_synthesis_convert.py +0 -488
  74. palimpzest/query/operators/map.py +0 -130
  75. palimpzest/query/processor/nosentinel_processor.py +0 -33
  76. palimpzest/query/processor/processing_strategy_type.py +0 -28
  77. palimpzest/query/processor/sentinel_processor.py +0 -88
  78. palimpzest/query/processor/streaming_processor.py +0 -149
  79. palimpzest/sets.py +0 -405
  80. palimpzest/utils/datareader_helpers.py +0 -61
  81. palimpzest/utils/demo_helpers.py +0 -75
  82. palimpzest/utils/field_helpers.py +0 -69
  83. palimpzest/utils/generation_helpers.py +0 -69
  84. palimpzest/utils/sandbox.py +0 -183
  85. palimpzest-0.7.21.dist-info/RECORD +0 -95
  86. /palimpzest/core/{elements/index.py → data/index_dataset.py} +0 -0
  87. {palimpzest-0.7.21.dist-info → palimpzest-0.8.1.dist-info}/WHEEL +0 -0
  88. {palimpzest-0.7.21.dist-info → palimpzest-0.8.1.dist-info}/licenses/LICENSE +0 -0
  89. {palimpzest-0.7.21.dist-info → palimpzest-0.8.1.dist-info}/top_level.txt +0 -0
@@ -6,9 +6,12 @@ from palimpzest.query.operators.convert import ConvertOp as _ConvertOp
6
6
  from palimpzest.query.operators.convert import LLMConvert as _LLMConvert
7
7
  from palimpzest.query.operators.convert import LLMConvertBonded as _LLMConvertBonded
8
8
  from palimpzest.query.operators.convert import NonLLMConvert as _NonLLMConvert
9
+ from palimpzest.query.operators.distinct import DistinctOp as _DistinctOp
9
10
  from palimpzest.query.operators.filter import FilterOp as _FilterOp
10
11
  from palimpzest.query.operators.filter import LLMFilter as _LLMFilter
11
12
  from palimpzest.query.operators.filter import NonLLMFilter as _NonLLMFilter
13
+ from palimpzest.query.operators.join import JoinOp as _JoinOp
14
+ from palimpzest.query.operators.join import NestedLoopsJoin as _NestedLoopsJoin
12
15
  from palimpzest.query.operators.limit import LimitScanOp as _LimitScanOp
13
16
  from palimpzest.query.operators.logical import (
14
17
  Aggregate as _Aggregate,
@@ -17,10 +20,10 @@ from palimpzest.query.operators.logical import (
17
20
  BaseScan as _BaseScan,
18
21
  )
19
22
  from palimpzest.query.operators.logical import (
20
- CacheScan as _CacheScan,
23
+ ConvertScan as _ConvertScan,
21
24
  )
22
25
  from palimpzest.query.operators.logical import (
23
- ConvertScan as _ConvertScan,
26
+ Distinct as _Distinct,
24
27
  )
25
28
  from palimpzest.query.operators.logical import (
26
29
  FilteredScan as _FilteredScan,
@@ -28,6 +31,9 @@ from palimpzest.query.operators.logical import (
28
31
  from palimpzest.query.operators.logical import (
29
32
  GroupByAggregate as _GroupByAggregate,
30
33
  )
34
+ from palimpzest.query.operators.logical import (
35
+ JoinOp as _LogicalJoinOp,
36
+ )
31
37
  from palimpzest.query.operators.logical import (
32
38
  LimitScan as _LimitScan,
33
39
  )
@@ -44,7 +50,6 @@ from palimpzest.query.operators.mixture_of_agents_convert import MixtureOfAgents
44
50
  from palimpzest.query.operators.physical import PhysicalOperator as _PhysicalOperator
45
51
  from palimpzest.query.operators.project import ProjectOp as _ProjectOp
46
52
  from palimpzest.query.operators.retrieve import RetrieveOp as _RetrieveOp
47
- from palimpzest.query.operators.scan import CacheScanDataOp as _CacheScanDataOp
48
53
  from palimpzest.query.operators.scan import MarshalAndScanDataOp as _MarshalAndScanDataOp
49
54
  from palimpzest.query.operators.scan import ScanPhysicalOp as _ScanPhysicalOp
50
55
 
@@ -52,10 +57,11 @@ LOGICAL_OPERATORS = [
52
57
  _LogicalOperator,
53
58
  _Aggregate,
54
59
  _BaseScan,
55
- _CacheScan,
56
60
  _ConvertScan,
61
+ _Distinct,
57
62
  _FilteredScan,
58
63
  _GroupByAggregate,
64
+ _LogicalJoinOp,
59
65
  _LimitScan,
60
66
  _Project,
61
67
  _RetrieveScan,
@@ -66,10 +72,14 @@ PHYSICAL_OPERATORS = (
66
72
  [_AggregateOp, _ApplyGroupByOp, _AverageAggregateOp, _CountAggregateOp]
67
73
  # convert
68
74
  + [_ConvertOp, _NonLLMConvert, _LLMConvert, _LLMConvertBonded]
75
+ # distinct
76
+ + [_DistinctOp]
69
77
  # scan
70
- + [_ScanPhysicalOp, _MarshalAndScanDataOp, _CacheScanDataOp]
78
+ + [_ScanPhysicalOp, _MarshalAndScanDataOp]
71
79
  # filter
72
80
  + [_FilterOp, _NonLLMFilter, _LLMFilter]
81
+ # join
82
+ + [_JoinOp, _NestedLoopsJoin]
73
83
  # limit
74
84
  + [_LimitScanOp]
75
85
  # mixture-of-agents
@@ -3,10 +3,10 @@ from __future__ import annotations
3
3
  import time
4
4
 
5
5
  from palimpzest.constants import NAIVE_EST_NUM_GROUPS, AggFunc
6
- from palimpzest.core.data.dataclasses import OperatorCostEstimates, RecordOpStats
7
6
  from palimpzest.core.elements.groupbysig import GroupBySig
8
7
  from palimpzest.core.elements.records import DataRecord, DataRecordSet
9
- from palimpzest.core.lib.schemas import Number
8
+ from palimpzest.core.lib.schemas import Average, Count
9
+ from palimpzest.core.models import OperatorCostEstimates, RecordOpStats
10
10
  from palimpzest.query.operators.physical import PhysicalOperator
11
11
 
12
12
 
@@ -16,7 +16,7 @@ class AggregateOp(PhysicalOperator):
16
16
  __call__ methods. Thus, we use a slightly modified abstract base class for
17
17
  these operators.
18
18
  """
19
- def __call__(self, candidates: DataRecordSet) -> DataRecordSet:
19
+ def __call__(self, candidates: list[DataRecord]) -> DataRecordSet:
20
20
  raise NotImplementedError("Using __call__ from abstract method")
21
21
 
22
22
 
@@ -67,6 +67,8 @@ class ApplyGroupByOp(AggregateOp):
67
67
  return state + 1
68
68
  elif func.lower() == "average":
69
69
  sum, cnt = state
70
+ if val is None:
71
+ return (sum, cnt)
70
72
  return (sum + val, cnt + 1)
71
73
  else:
72
74
  raise Exception("Unknown agg function " + func)
@@ -77,11 +79,11 @@ class ApplyGroupByOp(AggregateOp):
77
79
  return state
78
80
  elif func.lower() == "average":
79
81
  sum, cnt = state
80
- return float(sum) / cnt
82
+ return float(sum) / cnt if cnt > 0 else None
81
83
  else:
82
84
  raise Exception("Unknown agg function " + func)
83
85
 
84
- def __call__(self, candidates: DataRecordSet) -> DataRecordSet:
86
+ def __call__(self, candidates: list[DataRecord]) -> DataRecordSet:
85
87
  start_time = time.time()
86
88
 
87
89
  # build group array
@@ -107,17 +109,13 @@ class ApplyGroupByOp(AggregateOp):
107
109
  agg_state[group] = state
108
110
 
109
111
  # return list of data records (one per group)
110
- drs = []
112
+ drs: list[DataRecord] = []
111
113
  group_by_fields = self.group_by_sig.group_by_fields
112
114
  agg_fields = self.group_by_sig.get_agg_field_names()
113
115
  for g in agg_state:
114
- dr = DataRecord(self.group_by_sig.output_schema())
115
- # NOTE: this will set the parent_id and source_idx to be the id of the final source record;
116
- # in the near future we may want to have parent_id accept a list of ids
117
- dr = DataRecord.from_parent(
116
+ dr = DataRecord.from_agg_parents(
118
117
  schema=self.group_by_sig.output_schema(),
119
- parent_record=candidates[-1],
120
- project_cols=[],
118
+ parent_records=candidates,
121
119
  )
122
120
  for i in range(0, len(g)):
123
121
  k = g[i]
@@ -135,8 +133,8 @@ class ApplyGroupByOp(AggregateOp):
135
133
  for dr in drs:
136
134
  record_op_stats = RecordOpStats(
137
135
  record_id=dr.id,
138
- record_parent_id=dr.parent_id,
139
- record_source_idx=dr.source_idx,
136
+ record_parent_ids=dr.parent_ids,
137
+ record_source_indices=dr.source_indices,
140
138
  record_state=dr.to_dict(include_bytes=False),
141
139
  full_op_id=self.get_full_op_id(),
142
140
  logical_op_id=self.logical_op_id,
@@ -155,13 +153,20 @@ class AverageAggregateOp(AggregateOp):
155
153
  # NOTE: we don't actually need / use agg_func here (yet)
156
154
 
157
155
  def __init__(self, agg_func: AggFunc, *args, **kwargs):
158
- kwargs["output_schema"] = Number
156
+ # enforce that output schema is correct
157
+ assert kwargs["output_schema"] == Average, "AverageAggregateOp requires output_schema to be Average"
158
+
159
+ # enforce that input schema is a single numeric field
160
+ input_field_types = list(kwargs["input_schema"].model_fields.values())
161
+ assert len(input_field_types) == 1, "AverageAggregateOp requires input_schema to have exactly one field"
162
+ numeric_field_types = [bool, int, float, bool | None, int | None, float | None, int | float, int | float | None]
163
+ is_numeric = input_field_types[0].annotation in numeric_field_types
164
+ assert is_numeric, f"AverageAggregateOp requires input_schema to have a numeric field type, i.e. one of: {numeric_field_types}\nGot: {input_field_types[0]}"
165
+
166
+ # call parent constructor
159
167
  super().__init__(*args, **kwargs)
160
168
  self.agg_func = agg_func
161
169
 
162
- if not self.input_schema.get_desc() == Number.get_desc():
163
- raise Exception("Aggregate function AVERAGE is only defined over Numbers")
164
-
165
170
  def __str__(self):
166
171
  op = super().__str__()
167
172
  op += f" Function: {str(self.agg_func)}\n"
@@ -184,19 +189,29 @@ class AverageAggregateOp(AggregateOp):
184
189
  quality=1.0,
185
190
  )
186
191
 
187
- def __call__(self, candidates: DataRecordSet) -> DataRecordSet:
192
+ def __call__(self, candidates: list[DataRecord]) -> DataRecordSet:
188
193
  start_time = time.time()
189
194
 
190
- # NOTE: this will set the parent_id and source_idx to be the id of the final source record;
191
- # in the near future we may want to have parent_id accept a list of ids
192
- dr = DataRecord.from_parent(schema=Number, parent_record=candidates[-1], project_cols=[])
193
- dr.value = sum(list(map(lambda c: float(c.value), candidates))) / len(candidates)
195
+ # NOTE: we currently do not guarantee that input values conform to their specified type;
196
+ # as a result, we simply omit any values which do not parse to a float from the average
197
+ # NOTE: right now we perform a check in the constructor which enforces that the input_schema
198
+ # has a single field which is numeric in nature; in the future we may want to have a
199
+ # cleaner way of computing the value (rather than `float(list(candidate...))` below)
200
+ dr = DataRecord.from_agg_parents(schema=Average, parent_records=candidates)
201
+ summation, total = 0, 0
202
+ for candidate in candidates:
203
+ try:
204
+ summation += float(list(candidate.to_dict().values())[0])
205
+ total += 1
206
+ except Exception:
207
+ pass
208
+ dr.average = summation / total
194
209
 
195
210
  # create RecordOpStats object
196
211
  record_op_stats = RecordOpStats(
197
212
  record_id=dr.id,
198
- record_parent_id=dr.parent_id,
199
- record_source_idx=dr.source_idx,
213
+ record_parent_ids=dr.parent_ids,
214
+ record_source_indices=dr.source_indices,
200
215
  record_state=dr.to_dict(include_bytes=False),
201
216
  full_op_id=self.get_full_op_id(),
202
217
  logical_op_id=self.logical_op_id,
@@ -212,7 +227,10 @@ class CountAggregateOp(AggregateOp):
212
227
  # NOTE: we don't actually need / use agg_func here (yet)
213
228
 
214
229
  def __init__(self, agg_func: AggFunc, *args, **kwargs):
215
- kwargs["output_schema"] = Number
230
+ # enforce that output schema is correct
231
+ assert kwargs["output_schema"] == Count, "CountAggregateOp requires output_schema to be Count"
232
+
233
+ # call parent constructor
216
234
  super().__init__(*args, **kwargs)
217
235
  self.agg_func = agg_func
218
236
 
@@ -238,19 +256,18 @@ class CountAggregateOp(AggregateOp):
238
256
  quality=1.0,
239
257
  )
240
258
 
241
- def __call__(self, candidates: DataRecordSet) -> DataRecordSet:
259
+ def __call__(self, candidates: list[DataRecord]) -> DataRecordSet:
242
260
  start_time = time.time()
243
261
 
244
- # NOTE: this will set the parent_id to be the id of the final source record;
245
- # in the near future we may want to have parent_id accept a list of ids
246
- dr = DataRecord.from_parent(schema=Number, parent_record=candidates[-1], project_cols=[])
247
- dr.value = len(candidates)
262
+ # create new DataRecord
263
+ dr = DataRecord.from_agg_parents(schema=Count, parent_records=candidates)
264
+ dr.count = len(candidates)
248
265
 
249
266
  # create RecordOpStats object
250
267
  record_op_stats = RecordOpStats(
251
268
  record_id=dr.id,
252
- record_parent_id=dr.parent_id,
253
- record_source_idx=dr.source_idx,
269
+ record_parent_ids=dr.parent_ids,
270
+ record_source_indices=dr.source_indices,
254
271
  record_state=dr.to_dict(include_bytes=False),
255
272
  full_op_id=self.get_full_op_id(),
256
273
  logical_op_id=self.logical_op_id,
@@ -0,0 +1,201 @@
1
+ import functools
2
+ import inspect
3
+ import os
4
+ import time
5
+ from typing import Any
6
+
7
+ from smolagents import CodeAgent, LiteLLMModel, tool
8
+
9
+ from palimpzest.core.data.context import Context
10
+ from palimpzest.core.data.context_manager import ContextManager
11
+ from palimpzest.core.elements.records import DataRecord, DataRecordSet
12
+ from palimpzest.core.models import GenerationStats, OperatorCostEstimates, RecordOpStats
13
+ from palimpzest.query.operators.physical import PhysicalOperator
14
+
15
+ # TODO: need to store final executed code in compute() operator so that humans can debug when human-in-the-loop
16
+
17
def make_tool(bound_method):
    """Wrap a bound method as a plain function whose signature omits the bound argument.

    The returned wrapper calls the underlying function with the bound instance supplied
    as its first argument. It copies name, docstring and other metadata from the
    underlying function via ``functools.wraps``. It also sets ``__signature__`` to the
    original signature minus its first parameter (``self``/``cls``), so signature-based
    tooling (e.g. the smolagents ``tool`` decorator) only sees caller-facing arguments.

    Callables that are not bound methods (plain functions, builtin methods, static
    methods) have no bound argument to strip and are returned unchanged.

    Args:
        bound_method: a bound method (``obj.method``) or any other callable.

    Returns:
        A callable with the same call behavior whose signature excludes the bound argument.
    """
    # previously, non-method callables crashed here with AttributeError on `__func__`
    if not inspect.ismethod(bound_method):
        return bound_method

    # get the original function and the object it is bound to
    func = bound_method.__func__
    instance = bound_method.__self__

    # build the signature without the first (bound) parameter
    sig = inspect.signature(func)
    params = list(sig.parameters.values())[1:]  # skip 'self' / 'cls'
    new_sig = inspect.Signature(parameters=params, return_annotation=sig.return_annotation)

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        return func(instance, *args, **kwargs)

    # `functools.wraps` sets `__wrapped__` to `func` (whose signature still has 'self');
    # an explicit `__signature__` takes precedence for `inspect.signature`
    wrapper.__signature__ = new_sig

    return wrapper
36
+
37
+
38
class SmolAgentsCompute(PhysicalOperator):
    """
    Physical operator that executes a natural-language ``instruction`` with a smolagents ``CodeAgent``.

    For each input record, the agent receives:
    - the tools of the record's ``context`` (each wrapped by ``make_tool`` and the smolagents
      ``tool`` decorator);
    - a description of that Context and of any ``additional_contexts``.

    After the agent finishes, the Context identified by ``context_id`` is updated through
    ``ContextManager`` with a description containing the agent's result. A single output
    record carrying the updated Context and the raw result is then returned.
    """

    def __init__(
        self,
        context_id: str,
        instruction: str,
        additional_contexts: list[Context] | None = None,
        *args,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        self.context_id = context_id
        self.instruction = instruction
        self.additional_contexts = [] if additional_contexts is None else additional_contexts

        # NOTE: the model is currently hard-coded; commented-out alternatives were
        # "anthropic/claude-3-7-sonnet-latest" and "openai/gpt-4o-2024-08-06"
        self.model_id = "openai/gpt-4o-mini-2024-07-18"
        api_key = os.getenv("ANTHROPIC_API_KEY") if "anthropic" in self.model_id else os.getenv("OPENAI_API_KEY")
        self.model = LiteLLMModel(model_id=self.model_id, api_key=api_key)

    def __str__(self):
        op = super().__str__()
        op += f" Context ID: {self.context_id:20s}\n"
        op += f" Instruction: {self.instruction:20s}\n"
        op += f" Add. Ctxs: {self.additional_contexts}\n"
        return op

    def get_id_params(self):
        id_params = super().get_id_params()
        return {
            "context_id": self.context_id,
            "instruction": self.instruction,
            "additional_contexts": self.additional_contexts,
            **id_params,
        }

    def get_op_params(self):
        op_params = super().get_op_params()
        return {
            "context_id": self.context_id,
            "instruction": self.instruction,
            "additional_contexts": self.additional_contexts,
            **op_params,
        }

    def naive_cost_estimates(self, source_op_cost_estimates: OperatorCostEstimates) -> OperatorCostEstimates:
        """
        Return fixed placeholder estimates: cardinality is passed through from the source operator,
        while time and cost per record are constants (100 and 1) and quality is 1.0.
        """
        return OperatorCostEstimates(
            cardinality=source_op_cost_estimates.cardinality,
            time_per_record=100,
            cost_per_record=1,
            quality=1.0,
        )

    def _create_record_set(
        self,
        candidate: DataRecord,
        generation_stats: GenerationStats,
        total_time: float,
        answer: dict[str, Any],
    ) -> DataRecordSet:
        """
        Build the output DataRecordSet for one input record.

        A new DataRecord is derived from ``candidate`` and every output-schema field present in
        ``answer`` is set on it. A matching RecordOpStats is built from ``generation_stats`` and
        ``total_time``. Context values in ``answer`` are recorded by their description.
        """
        # create new DataRecord and copy over the computed output fields
        dr = DataRecord.from_parent(self.output_schema, parent_record=candidate)
        for field in self.output_schema.model_fields:
            if field in answer:
                dr[field] = answer[field]

        # create RecordOpStats object
        record_op_stats = RecordOpStats(
            record_id=dr.id,
            record_parent_ids=dr.parent_ids,
            record_source_indices=dr.source_indices,
            record_state=dr.to_dict(include_bytes=False),
            full_op_id=self.get_full_op_id(),
            logical_op_id=self.logical_op_id,
            op_name=self.op_name(),
            time_per_record=total_time,
            cost_per_record=generation_stats.cost_per_record,
            model_name=self.get_model_name(),
            total_input_tokens=generation_stats.total_input_tokens,
            total_output_tokens=generation_stats.total_output_tokens,
            total_input_cost=generation_stats.total_input_cost,
            total_output_cost=generation_stats.total_output_cost,
            llm_call_duration_secs=generation_stats.llm_call_duration_secs,
            fn_call_duration_secs=generation_stats.fn_call_duration_secs,
            total_llm_calls=generation_stats.total_llm_calls,
            total_embedding_llm_calls=generation_stats.total_embedding_llm_calls,
            answer={k: v.description if isinstance(v, Context) else v for k, v in answer.items()},
            op_details={k: str(v) for k, v in self.get_id_params().items()},
        )

        return DataRecordSet([dr], [record_op_stats])

    def __call__(self, candidate: DataRecord) -> DataRecordSet:
        """
        Run the agent on ``candidate.context`` and return a DataRecordSet with one output record.

        Side effect: updates the Context identified by ``self.context_id`` via ``ContextManager``.
        """
        start_time = time.time()

        # get the input context object and its tools
        input_context: Context = candidate.context
        description = input_context.description
        tools = [tool(make_tool(f)) for f in input_context.tools]

        # update the description to include any additional contexts
        for ctx in self.additional_contexts:
            # TODO: remove additional context if it is an ancestor of the input context
            # (not just if it is equal to the input context)
            if ctx.id == input_context.id:
                continue
            description += f"\n\nHere is some additional Context which may be useful:\n\n{ctx.description}"

        # perform the computation
        instructions = f"\n\nHere is a description of the Context whose data you will be working with, as well as any previously computed results:\n\n{description}"
        agent = CodeAgent(
            tools=tools,
            model=self.model,
            add_base_tools=False,
            instructions=instructions,
            return_full_result=True,
            additional_authorized_imports=["pandas", "io", "os"],
            planning_interval=4,
            max_steps=30,
        )
        result = agent.run(self.instruction)
        # NOTE: you can see the system prompt with `agent.memory.system_prompt.system_prompt`

        # compute generation stats; the smolagents run result may carry `token_usage=None`
        # (e.g. when a step did not record usage), so treat missing usage as zero tokens
        # instead of failing with an AttributeError
        response = result.output
        token_usage = result.token_usage
        input_tokens = token_usage.input_tokens if token_usage is not None else 0
        output_tokens = token_usage.output_tokens if token_usage is not None else 0

        # hard-coded USD-per-token prices (anthropic branch vs. the gpt-4o-mini default);
        # gpt-4o pricing would be 2.5 / 1e6 input and 10.0 / 1e6 output
        cost_per_input_token = (3.0 / 1e6) if "anthropic" in self.model_id else (0.15 / 1e6)
        cost_per_output_token = (15.0 / 1e6) if "anthropic" in self.model_id else (0.6 / 1e6)
        input_cost = input_tokens * cost_per_input_token
        output_cost = output_tokens * cost_per_output_token
        generation_stats = GenerationStats(
            model_name=self.model_id,
            total_input_tokens=input_tokens,
            total_output_tokens=output_tokens,
            total_input_cost=input_cost,
            total_output_cost=output_cost,
            cost_per_record=input_cost + output_cost,
            llm_call_duration_secs=time.time() - start_time,
        )

        # update the description of the computed Context to include the result
        new_description = f"RESULT: {response}\n\n"
        cm = ContextManager()
        cm.update_context(id=self.context_id, description=new_description)

        # create and return record set
        field_answers = {
            "context": cm.get_context(id=self.context_id),
            f"result-{self.context_id}": response,
        }
        record_set = self._create_record_set(
            candidate,
            generation_stats,
            time.time() - start_time,
            field_answers,
        )

        return record_set
196
+
197
+ # import json; json.dumps(agent.memory.get_full_steps())
198
+ # agent.memory.get_full_steps()[1].keys()
199
+ # dict_keys(['step_number', 'timing', 'model_input_messages', 'tool_calls', 'error', 'model_output_message', 'model_output', 'code_action', 'observations', 'observations_images',
200
+ # 'action_output', 'token_usage', 'is_final_answer'])
201
+ # agent.memory.get_full_steps()[1]['action_output']
@@ -4,6 +4,8 @@ import time
4
4
  from abc import ABC, abstractmethod
5
5
  from typing import Callable
6
6
 
7
+ from pydantic.fields import FieldInfo
8
+
7
9
  from palimpzest.constants import (
8
10
  MODEL_CARDS,
9
11
  NAIVE_EST_NUM_INPUT_TOKENS,
@@ -13,12 +15,10 @@ from palimpzest.constants import (
13
15
  Model,
14
16
  PromptStrategy,
15
17
  )
16
- from palimpzest.core.data.dataclasses import GenerationStats, OperatorCostEstimates, RecordOpStats
17
18
  from palimpzest.core.elements.records import DataRecord, DataRecordSet
18
- from palimpzest.core.lib.fields import Field
19
- from palimpzest.query.generators.generators import generator_factory
19
+ from palimpzest.core.models import GenerationStats, OperatorCostEstimates, RecordOpStats
20
+ from palimpzest.query.generators.generators import Generator
20
21
  from palimpzest.query.operators.physical import PhysicalOperator
21
- from palimpzest.utils.model_helpers import get_vision_models
22
22
 
23
23
 
24
24
  class ConvertOp(PhysicalOperator, ABC):
@@ -40,6 +40,7 @@ class ConvertOp(PhysicalOperator, ABC):
40
40
  id_params = {
41
41
  "cardinality": self.cardinality.value,
42
42
  "udf": self.udf,
43
+ "desc": self.desc,
43
44
  **id_params,
44
45
  }
45
46
 
@@ -47,7 +48,12 @@ class ConvertOp(PhysicalOperator, ABC):
47
48
 
48
49
  def get_op_params(self):
49
50
  op_params = super().get_op_params()
50
- op_params = {"cardinality": self.cardinality, "udf": self.udf, "desc": self.desc, **op_params}
51
+ op_params = {
52
+ "cardinality": self.cardinality,
53
+ "udf": self.udf,
54
+ "desc": self.desc,
55
+ **op_params,
56
+ }
51
57
 
52
58
  return op_params
53
59
 
@@ -78,8 +84,8 @@ class ConvertOp(PhysicalOperator, ABC):
78
84
  setattr(dr, field, getattr(candidate, field))
79
85
 
80
86
  # get input field names and output field names
81
- input_fields = self.input_schema.field_names()
82
- output_fields = self.output_schema.field_names()
87
+ input_fields = list(self.input_schema.model_fields)
88
+ output_fields = list(self.output_schema.model_fields)
83
89
 
84
90
  # parse newly generated fields from the field_answers dictionary for this field; if the list
85
91
  # of generated values is shorter than the number of records, we fill in with None
@@ -112,8 +118,8 @@ class ConvertOp(PhysicalOperator, ABC):
112
118
  record_op_stats_lst = [
113
119
  RecordOpStats(
114
120
  record_id=dr.id,
115
- record_parent_id=dr.parent_id,
116
- record_source_idx=dr.source_idx,
121
+ record_parent_ids=dr.parent_ids,
122
+ record_source_indices=dr.source_indices,
117
123
  record_state=dr.to_dict(include_bytes=False),
118
124
  full_op_id=self.get_full_op_id(),
119
125
  logical_op_id=self.logical_op_id,
@@ -122,7 +128,7 @@ class ConvertOp(PhysicalOperator, ABC):
122
128
  cost_per_record=per_record_stats.cost_per_record,
123
129
  model_name=self.get_model_name(),
124
130
  answer={field_name: getattr(dr, field_name) for field_name in field_names},
125
- input_fields=self.input_schema.field_names(),
131
+ input_fields=list(self.input_schema.model_fields),
126
132
  generated_fields=field_names,
127
133
  total_input_tokens=per_record_stats.total_input_tokens,
128
134
  total_output_tokens=per_record_stats.total_output_tokens,
@@ -148,7 +154,7 @@ class ConvertOp(PhysicalOperator, ABC):
148
154
  pass
149
155
 
150
156
  @abstractmethod
151
- def convert(self, candidate: DataRecord, fields: dict[str, Field]) -> tuple[dict[str, list], GenerationStats]:
157
+ def convert(self, candidate: DataRecord, fields: dict[str, FieldInfo]) -> tuple[dict[str, list], GenerationStats]:
152
158
  """
153
159
  This abstract method will be implemented by subclasses of ConvertOp to process the input DataRecord
154
160
  and generate the value(s) for each of the specified fields. If the convert operator is a one-to-many
@@ -182,7 +188,7 @@ class ConvertOp(PhysicalOperator, ABC):
182
188
 
183
189
  # execute the convert
184
190
  field_answers: dict[str, list]
185
- fields = {field: field_type for field, field_type in self.output_schema.field_map().items() if field in fields_to_generate}
191
+ fields = {field: field_type for field, field_type in self.output_schema.model_fields.items() if field in fields_to_generate}
186
192
  field_answers, generation_stats = self.convert(candidate=candidate, fields=fields)
187
193
  assert all([field in field_answers for field in fields_to_generate]), "Not all fields were generated!"
188
194
 
@@ -235,7 +241,7 @@ class NonLLMConvert(ConvertOp):
235
241
  quality=1.0,
236
242
  )
237
243
 
238
- def convert(self, candidate: DataRecord, fields: dict[str, Field]) -> tuple[dict[str, list], GenerationStats]:
244
+ def convert(self, candidate: DataRecord, fields: dict[str, FieldInfo]) -> tuple[dict[str, list], GenerationStats]:
239
245
  # apply UDF to input record
240
246
  start_time = time.time()
241
247
  field_answers = {}
@@ -282,18 +288,21 @@ class LLMConvert(ConvertOp):
282
288
  self,
283
289
  model: Model,
284
290
  prompt_strategy: PromptStrategy = PromptStrategy.COT_QA,
291
+ reasoning_effort: str | None = None,
285
292
  *args,
286
293
  **kwargs,
287
294
  ):
288
295
  super().__init__(*args, **kwargs)
289
296
  self.model = model
290
297
  self.prompt_strategy = prompt_strategy
298
+ self.reasoning_effort = reasoning_effort
291
299
  if model is not None:
292
- self.generator = generator_factory(model, prompt_strategy, self.cardinality, self.verbose)
300
+ self.generator = Generator(model, prompt_strategy, reasoning_effort, self.api_base, self.cardinality, self.desc, self.verbose)
293
301
 
294
302
  def __str__(self):
295
303
  op = super().__str__()
296
304
  op += f" Prompt Strategy: {self.prompt_strategy}\n"
305
+ op += f" Reasoning Effort: {self.reasoning_effort}\n"
297
306
  return op
298
307
 
299
308
  def get_id_params(self):
@@ -301,6 +310,7 @@ class LLMConvert(ConvertOp):
301
310
  id_params = {
302
311
  "model": None if self.model is None else self.model.value,
303
312
  "prompt_strategy": None if self.prompt_strategy is None else self.prompt_strategy.value,
313
+ "reasoning_effort": self.reasoning_effort,
304
314
  **id_params,
305
315
  }
306
316
 
@@ -311,6 +321,7 @@ class LLMConvert(ConvertOp):
311
321
  op_params = {
312
322
  "model": self.model,
313
323
  "prompt_strategy": self.prompt_strategy,
324
+ "reasoning_effort": self.reasoning_effort,
314
325
  **op_params,
315
326
  }
316
327
 
@@ -320,7 +331,7 @@ class LLMConvert(ConvertOp):
320
331
  return None if self.model is None else self.model.value
321
332
 
322
333
  def is_image_conversion(self) -> bool:
323
- return self.model in get_vision_models()
334
+ return self.prompt_strategy.is_image_prompt()
324
335
 
325
336
  def naive_cost_estimates(self, source_op_cost_estimates: OperatorCostEstimates) -> OperatorCostEstimates:
326
337
  """
@@ -334,13 +345,16 @@ class LLMConvert(ConvertOp):
334
345
  est_num_output_tokens = NAIVE_EST_NUM_OUTPUT_TOKENS
335
346
 
336
347
  # get est. of conversion time per record from model card;
337
- # NOTE: model will only be None for code synthesis, which uses GPT-3.5 as fallback
338
348
  model_name = self.model.value if getattr(self, "model", None) is not None else Model.GPT_4o_MINI.value
339
349
  model_conversion_time_per_record = MODEL_CARDS[model_name]["seconds_per_output_token"] * est_num_output_tokens
340
350
 
341
351
  # get est. of conversion cost (in USD) per record from model card
352
+ usd_per_input_token = MODEL_CARDS[model_name].get("usd_per_input_token")
353
+ if getattr(self, "prompt_strategy", None) is not None and self.prompt_strategy.is_audio_prompt():
354
+ usd_per_input_token = MODEL_CARDS[model_name]["usd_per_audio_input_token"]
355
+
342
356
  model_conversion_usd_per_record = (
343
- MODEL_CARDS[model_name]["usd_per_input_token"] * est_num_input_tokens
357
+ usd_per_input_token * est_num_input_tokens
344
358
  + MODEL_CARDS[model_name]["usd_per_output_token"] * est_num_output_tokens
345
359
  )
346
360
 
@@ -349,7 +363,7 @@ class LLMConvert(ConvertOp):
349
363
  cardinality = selectivity * source_op_cost_estimates.cardinality
350
364
 
351
365
  # estimate quality of output based on the strength of the model being used
352
- quality = (MODEL_CARDS[model_name]["overall"] / 100.0) * source_op_cost_estimates.quality
366
+ quality = (MODEL_CARDS[model_name]["overall"] / 100.0)
353
367
 
354
368
  return OperatorCostEstimates(
355
369
  cardinality=cardinality,
@@ -361,7 +375,7 @@ class LLMConvert(ConvertOp):
361
375
 
362
376
  class LLMConvertBonded(LLMConvert):
363
377
 
364
- def convert(self, candidate: DataRecord, fields: dict[str, Field]) -> tuple[dict[str, list], GenerationStats]:
378
+ def convert(self, candidate: DataRecord, fields: dict[str, FieldInfo]) -> tuple[dict[str, list], GenerationStats]:
365
379
  # get the set of input fields to use for the convert operation
366
380
  input_fields = self.get_input_fields()
367
381
 
@@ -2,10 +2,12 @@ from __future__ import annotations
2
2
 
3
3
  from typing import Any
4
4
 
5
+ from pydantic.fields import FieldInfo
6
+
5
7
  from palimpzest.constants import MODEL_CARDS, Model, PromptStrategy
6
- from palimpzest.core.data.dataclasses import GenerationStats, OperatorCostEstimates
7
8
  from palimpzest.core.elements.records import DataRecord
8
- from palimpzest.query.generators.generators import generator_factory
9
+ from palimpzest.core.models import GenerationStats, OperatorCostEstimates
10
+ from palimpzest.query.generators.generators import Generator
9
11
  from palimpzest.query.operators.convert import LLMConvert
10
12
 
11
13
  # TYPE DEFINITIONS
@@ -35,8 +37,8 @@ class CriticAndRefineConvert(LLMConvert):
35
37
  raise ValueError(f"Unsupported prompt strategy: {self.prompt_strategy}")
36
38
 
37
39
  # create generators
38
- self.critic_generator = generator_factory(self.critic_model, self.critic_prompt_strategy, self.cardinality, self.verbose)
39
- self.refine_generator = generator_factory(self.refine_model, self.refinement_prompt_strategy, self.cardinality, self.verbose)
40
+ self.critic_generator = Generator(self.critic_model, self.critic_prompt_strategy, self.reasoning_effort, self.api_base, self.cardinality, self.desc, self.verbose)
41
+ self.refine_generator = Generator(self.refine_model, self.refinement_prompt_strategy, self.reasoning_effort, self.api_base, self.cardinality, self.desc, self.verbose)
40
42
 
41
43
  def __str__(self):
42
44
  op = super().__str__()
@@ -86,7 +88,7 @@ class CriticAndRefineConvert(LLMConvert):
86
88
 
87
89
  return naive_op_cost_estimates
88
90
 
89
- def convert(self, candidate: DataRecord, fields: list[str]) -> tuple[dict[FieldName, list[Any]], GenerationStats]:
91
+ def convert(self, candidate: DataRecord, fields: dict[str, FieldInfo]) -> tuple[dict[FieldName, list[Any]], GenerationStats]:
90
92
  # get input fields
91
93
  input_fields = self.get_input_fields()
92
94