palimpzest 0.7.21__py3-none-any.whl → 0.8.0__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (87)
  1. palimpzest/__init__.py +37 -6
  2. palimpzest/agents/__init__.py +0 -0
  3. palimpzest/agents/compute_agents.py +0 -0
  4. palimpzest/agents/search_agents.py +637 -0
  5. palimpzest/constants.py +259 -197
  6. palimpzest/core/data/context.py +393 -0
  7. palimpzest/core/data/context_manager.py +163 -0
  8. palimpzest/core/data/dataset.py +634 -0
  9. palimpzest/core/data/{datareaders.py → iter_dataset.py} +202 -126
  10. palimpzest/core/elements/groupbysig.py +16 -13
  11. palimpzest/core/elements/records.py +166 -75
  12. palimpzest/core/lib/schemas.py +152 -390
  13. palimpzest/core/{data/dataclasses.py → models.py} +306 -170
  14. palimpzest/policy.py +2 -27
  15. palimpzest/prompts/__init__.py +35 -5
  16. palimpzest/prompts/agent_prompts.py +357 -0
  17. palimpzest/prompts/context_search.py +9 -0
  18. palimpzest/prompts/convert_prompts.py +61 -5
  19. palimpzest/prompts/filter_prompts.py +50 -5
  20. palimpzest/prompts/join_prompts.py +163 -0
  21. palimpzest/prompts/moa_proposer_convert_prompts.py +5 -5
  22. palimpzest/prompts/prompt_factory.py +358 -46
  23. palimpzest/prompts/validator.py +239 -0
  24. palimpzest/query/execution/all_sample_execution_strategy.py +134 -76
  25. palimpzest/query/execution/execution_strategy.py +210 -317
  26. palimpzest/query/execution/execution_strategy_type.py +5 -7
  27. palimpzest/query/execution/mab_execution_strategy.py +249 -136
  28. palimpzest/query/execution/parallel_execution_strategy.py +153 -244
  29. palimpzest/query/execution/single_threaded_execution_strategy.py +107 -64
  30. palimpzest/query/generators/generators.py +157 -330
  31. palimpzest/query/operators/__init__.py +15 -5
  32. palimpzest/query/operators/aggregate.py +50 -33
  33. palimpzest/query/operators/compute.py +201 -0
  34. palimpzest/query/operators/convert.py +27 -21
  35. palimpzest/query/operators/critique_and_refine_convert.py +7 -5
  36. palimpzest/query/operators/distinct.py +62 -0
  37. palimpzest/query/operators/filter.py +22 -13
  38. palimpzest/query/operators/join.py +402 -0
  39. palimpzest/query/operators/limit.py +3 -3
  40. palimpzest/query/operators/logical.py +198 -80
  41. palimpzest/query/operators/mixture_of_agents_convert.py +10 -8
  42. palimpzest/query/operators/physical.py +27 -21
  43. palimpzest/query/operators/project.py +3 -3
  44. palimpzest/query/operators/rag_convert.py +7 -7
  45. palimpzest/query/operators/retrieve.py +9 -9
  46. palimpzest/query/operators/scan.py +81 -42
  47. palimpzest/query/operators/search.py +524 -0
  48. palimpzest/query/operators/split_convert.py +10 -8
  49. palimpzest/query/optimizer/__init__.py +7 -9
  50. palimpzest/query/optimizer/cost_model.py +108 -441
  51. palimpzest/query/optimizer/optimizer.py +123 -181
  52. palimpzest/query/optimizer/optimizer_strategy.py +66 -61
  53. palimpzest/query/optimizer/plan.py +352 -67
  54. palimpzest/query/optimizer/primitives.py +43 -19
  55. palimpzest/query/optimizer/rules.py +484 -646
  56. palimpzest/query/optimizer/tasks.py +127 -58
  57. palimpzest/query/processor/config.py +41 -76
  58. palimpzest/query/processor/query_processor.py +73 -18
  59. palimpzest/query/processor/query_processor_factory.py +46 -38
  60. palimpzest/schemabuilder/schema_builder.py +15 -28
  61. palimpzest/utils/model_helpers.py +27 -77
  62. palimpzest/utils/progress.py +114 -102
  63. palimpzest/validator/__init__.py +0 -0
  64. palimpzest/validator/validator.py +306 -0
  65. {palimpzest-0.7.21.dist-info → palimpzest-0.8.0.dist-info}/METADATA +6 -1
  66. palimpzest-0.8.0.dist-info/RECORD +95 -0
  67. palimpzest/core/lib/fields.py +0 -141
  68. palimpzest/prompts/code_synthesis_prompts.py +0 -28
  69. palimpzest/query/execution/random_sampling_execution_strategy.py +0 -240
  70. palimpzest/query/generators/api_client_factory.py +0 -30
  71. palimpzest/query/operators/code_synthesis_convert.py +0 -488
  72. palimpzest/query/operators/map.py +0 -130
  73. palimpzest/query/processor/nosentinel_processor.py +0 -33
  74. palimpzest/query/processor/processing_strategy_type.py +0 -28
  75. palimpzest/query/processor/sentinel_processor.py +0 -88
  76. palimpzest/query/processor/streaming_processor.py +0 -149
  77. palimpzest/sets.py +0 -405
  78. palimpzest/utils/datareader_helpers.py +0 -61
  79. palimpzest/utils/demo_helpers.py +0 -75
  80. palimpzest/utils/field_helpers.py +0 -69
  81. palimpzest/utils/generation_helpers.py +0 -69
  82. palimpzest/utils/sandbox.py +0 -183
  83. palimpzest-0.7.21.dist-info/RECORD +0 -95
  84. palimpzest/core/{elements/index.py → data/index_dataset.py} +0 -0
  85. {palimpzest-0.7.21.dist-info → palimpzest-0.8.0.dist-info}/WHEEL +0 -0
  86. {palimpzest-0.7.21.dist-info → palimpzest-0.8.0.dist-info}/licenses/LICENSE +0 -0
  87. {palimpzest-0.7.21.dist-info → palimpzest-0.8.0.dist-info}/top_level.txt +0 -0
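
Note on the renames above: palimpzest/core/data/datareaders.py becomes palimpzest/core/data/iter_dataset.py, and palimpzest/core/data/dataclasses.py becomes palimpzest/core/models.py, so downstream imports must move with them. A minimal compatibility sketch (the try/except shim is illustrative and not part of the package; the class names are taken from the import hunks below):

# Hypothetical shim for code that must run against both 0.7.21 and 0.8.0;
# module paths come from the rename entries above and the hunks below.
try:
    from palimpzest.core.models import GenerationStats, OperatorCostEstimates  # 0.8.0
except ImportError:
    from palimpzest.core.data.dataclasses import GenerationStats, OperatorCostEstimates  # 0.7.21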

palimpzest/__init__.py
@@ -6,9 +6,12 @@ from palimpzest.query.operators.convert import ConvertOp as _ConvertOp
 from palimpzest.query.operators.convert import LLMConvert as _LLMConvert
 from palimpzest.query.operators.convert import LLMConvertBonded as _LLMConvertBonded
 from palimpzest.query.operators.convert import NonLLMConvert as _NonLLMConvert
+from palimpzest.query.operators.distinct import DistinctOp as _DistinctOp
 from palimpzest.query.operators.filter import FilterOp as _FilterOp
 from palimpzest.query.operators.filter import LLMFilter as _LLMFilter
 from palimpzest.query.operators.filter import NonLLMFilter as _NonLLMFilter
+from palimpzest.query.operators.join import JoinOp as _JoinOp
+from palimpzest.query.operators.join import NestedLoopsJoin as _NestedLoopsJoin
 from palimpzest.query.operators.limit import LimitScanOp as _LimitScanOp
 from palimpzest.query.operators.logical import (
     Aggregate as _Aggregate,
@@ -17,10 +20,10 @@ from palimpzest.query.operators.logical import (
     BaseScan as _BaseScan,
 )
 from palimpzest.query.operators.logical import (
-    CacheScan as _CacheScan,
+    ConvertScan as _ConvertScan,
 )
 from palimpzest.query.operators.logical import (
-    ConvertScan as _ConvertScan,
+    Distinct as _Distinct,
 )
 from palimpzest.query.operators.logical import (
     FilteredScan as _FilteredScan,
@@ -28,6 +31,9 @@ from palimpzest.query.operators.logical import (
 from palimpzest.query.operators.logical import (
     GroupByAggregate as _GroupByAggregate,
 )
+from palimpzest.query.operators.logical import (
+    JoinOp as _LogicalJoinOp,
+)
 from palimpzest.query.operators.logical import (
     LimitScan as _LimitScan,
 )
@@ -44,7 +50,6 @@ from palimpzest.query.operators.mixture_of_agents_convert import MixtureOfAgents
 from palimpzest.query.operators.physical import PhysicalOperator as _PhysicalOperator
 from palimpzest.query.operators.project import ProjectOp as _ProjectOp
 from palimpzest.query.operators.retrieve import RetrieveOp as _RetrieveOp
-from palimpzest.query.operators.scan import CacheScanDataOp as _CacheScanDataOp
 from palimpzest.query.operators.scan import MarshalAndScanDataOp as _MarshalAndScanDataOp
 from palimpzest.query.operators.scan import ScanPhysicalOp as _ScanPhysicalOp
 
@@ -52,10 +57,11 @@ LOGICAL_OPERATORS = [
     _LogicalOperator,
     _Aggregate,
     _BaseScan,
-    _CacheScan,
     _ConvertScan,
+    _Distinct,
     _FilteredScan,
     _GroupByAggregate,
+    _LogicalJoinOp,
     _LimitScan,
     _Project,
     _RetrieveScan,
@@ -66,10 +72,14 @@ PHYSICAL_OPERATORS = (
     [_AggregateOp, _ApplyGroupByOp, _AverageAggregateOp, _CountAggregateOp]
     # convert
     + [_ConvertOp, _NonLLMConvert, _LLMConvert, _LLMConvertBonded]
+    # distinct
+    + [_DistinctOp]
     # scan
-    + [_ScanPhysicalOp, _MarshalAndScanDataOp, _CacheScanDataOp]
+    + [_ScanPhysicalOp, _MarshalAndScanDataOp]
     # filter
     + [_FilterOp, _NonLLMFilter, _LLMFilter]
+    # join
+    + [_JoinOp, _NestedLoopsJoin]
     # limit
     + [_LimitScanOp]
     # mixture-of-agents
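
The hunks above remove the cache-scan operators (_CacheScan, _CacheScanDataOp) from both registries and add the new distinct and join operators. A quick sanity check against an installed build, assuming LOGICAL_OPERATORS remains importable from the package root as the hunk defines it at module level:

# Sketch: list the logical operators registered by an installed build.
from palimpzest import LOGICAL_OPERATORS

print(sorted(op.__name__ for op in LOGICAL_OPERATORS))
# on 0.8.0 this should include Distinct and JoinOp, and no longer CacheScan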

palimpzest/query/operators/aggregate.py
@@ -3,10 +3,10 @@ from __future__ import annotations
 import time
 
 from palimpzest.constants import NAIVE_EST_NUM_GROUPS, AggFunc
-from palimpzest.core.data.dataclasses import OperatorCostEstimates, RecordOpStats
 from palimpzest.core.elements.groupbysig import GroupBySig
 from palimpzest.core.elements.records import DataRecord, DataRecordSet
-from palimpzest.core.lib.schemas import Number
+from palimpzest.core.lib.schemas import Average, Count
+from palimpzest.core.models import OperatorCostEstimates, RecordOpStats
 from palimpzest.query.operators.physical import PhysicalOperator
 
 
@@ -16,7 +16,7 @@ class AggregateOp(PhysicalOperator):
     __call__ methods. Thus, we use a slightly modified abstract base class for
     these operators.
     """
-    def __call__(self, candidates: DataRecordSet) -> DataRecordSet:
+    def __call__(self, candidates: list[DataRecord]) -> DataRecordSet:
         raise NotImplementedError("Using __call__ from abstract method")
 
 
@@ -67,6 +67,8 @@ class ApplyGroupByOp(AggregateOp):
             return state + 1
         elif func.lower() == "average":
             sum, cnt = state
+            if val is None:
+                return (sum, cnt)
             return (sum + val, cnt + 1)
         else:
             raise Exception("Unknown agg function " + func)
@@ -77,11 +79,11 @@
             return state
         elif func.lower() == "average":
             sum, cnt = state
-            return float(sum) / cnt
+            return float(sum) / cnt if cnt > 0 else None
         else:
             raise Exception("Unknown agg function " + func)
 
-    def __call__(self, candidates: DataRecordSet) -> DataRecordSet:
+    def __call__(self, candidates: list[DataRecord]) -> DataRecordSet:
         start_time = time.time()
 
         # build group array
@@ -107,17 +109,13 @@
             agg_state[group] = state
 
         # return list of data records (one per group)
-        drs = []
+        drs: list[DataRecord] = []
         group_by_fields = self.group_by_sig.group_by_fields
         agg_fields = self.group_by_sig.get_agg_field_names()
         for g in agg_state:
-            dr = DataRecord(self.group_by_sig.output_schema())
-            # NOTE: this will set the parent_id and source_idx to be the id of the final source record;
-            # in the near future we may want to have parent_id accept a list of ids
-            dr = DataRecord.from_parent(
+            dr = DataRecord.from_agg_parents(
                 schema=self.group_by_sig.output_schema(),
-                parent_record=candidates[-1],
-                project_cols=[],
+                parent_records=candidates,
             )
             for i in range(0, len(g)):
                 k = g[i]
@@ -135,8 +133,8 @@
         for dr in drs:
             record_op_stats = RecordOpStats(
                 record_id=dr.id,
-                record_parent_id=dr.parent_id,
-                record_source_idx=dr.source_idx,
+                record_parent_ids=dr.parent_ids,
+                record_source_indices=dr.source_indices,
                 record_state=dr.to_dict(include_bytes=False),
                 full_op_id=self.get_full_op_id(),
                 logical_op_id=self.logical_op_id,
@@ -155,13 +153,20 @@ class AverageAggregateOp(AggregateOp):
     # NOTE: we don't actually need / use agg_func here (yet)
 
     def __init__(self, agg_func: AggFunc, *args, **kwargs):
-        kwargs["output_schema"] = Number
+        # enforce that output schema is correct
+        assert kwargs["output_schema"] == Average, "AverageAggregateOp requires output_schema to be Average"
+
+        # enforce that input schema is a single numeric field
+        input_field_types = list(kwargs["input_schema"].model_fields.values())
+        assert len(input_field_types) == 1, "AverageAggregateOp requires input_schema to have exactly one field"
+        numeric_field_types = [bool, int, float, bool | None, int | None, float | None, int | float, int | float | None]
+        is_numeric = input_field_types[0].annotation in numeric_field_types
+        assert is_numeric, f"AverageAggregateOp requires input_schema to have a numeric field type, i.e. one of: {numeric_field_types}\nGot: {input_field_types[0]}"
+
+        # call parent constructor
         super().__init__(*args, **kwargs)
         self.agg_func = agg_func
 
-        if not self.input_schema.get_desc() == Number.get_desc():
-            raise Exception("Aggregate function AVERAGE is only defined over Numbers")
-
     def __str__(self):
         op = super().__str__()
         op += f" Function: {str(self.agg_func)}\n"
@@ -184,19 +189,29 @@
             quality=1.0,
         )
 
-    def __call__(self, candidates: DataRecordSet) -> DataRecordSet:
+    def __call__(self, candidates: list[DataRecord]) -> DataRecordSet:
         start_time = time.time()
 
-        # NOTE: this will set the parent_id and source_idx to be the id of the final source record;
-        # in the near future we may want to have parent_id accept a list of ids
-        dr = DataRecord.from_parent(schema=Number, parent_record=candidates[-1], project_cols=[])
-        dr.value = sum(list(map(lambda c: float(c.value), candidates))) / len(candidates)
+        # NOTE: we currently do not guarantee that input values conform to their specified type;
+        # as a result, we simply omit any values which do not parse to a float from the average
+        # NOTE: right now we perform a check in the constructor which enforces that the input_schema
+        # has a single field which is numeric in nature; in the future we may want to have a
+        # cleaner way of computing the value (rather than `float(list(candidate...))` below)
+        dr = DataRecord.from_agg_parents(schema=Average, parent_records=candidates)
+        summation, total = 0, 0
+        for candidate in candidates:
+            try:
+                summation += float(list(candidate.to_dict().values())[0])
+                total += 1
+            except Exception:
+                pass
+        dr.average = summation / total
 
         # create RecordOpStats object
         record_op_stats = RecordOpStats(
             record_id=dr.id,
-            record_parent_id=dr.parent_id,
-            record_source_idx=dr.source_idx,
+            record_parent_ids=dr.parent_ids,
+            record_source_indices=dr.source_indices,
             record_state=dr.to_dict(include_bytes=False),
             full_op_id=self.get_full_op_id(),
             logical_op_id=self.logical_op_id,
@@ -212,7 +227,10 @@ class CountAggregateOp(AggregateOp):
     # NOTE: we don't actually need / use agg_func here (yet)
 
     def __init__(self, agg_func: AggFunc, *args, **kwargs):
-        kwargs["output_schema"] = Number
+        # enforce that output schema is correct
+        assert kwargs["output_schema"] == Count, "CountAggregateOp requires output_schema to be Count"
+
+        # call parent constructor
        super().__init__(*args, **kwargs)
         self.agg_func = agg_func
 
@@ -238,19 +256,18 @@
             quality=1.0,
         )
 
-    def __call__(self, candidates: DataRecordSet) -> DataRecordSet:
+    def __call__(self, candidates: list[DataRecord]) -> DataRecordSet:
         start_time = time.time()
 
-        # NOTE: this will set the parent_id to be the id of the final source record;
-        # in the near future we may want to have parent_id accept a list of ids
-        dr = DataRecord.from_parent(schema=Number, parent_record=candidates[-1], project_cols=[])
-        dr.value = len(candidates)
+        # create new DataRecord
+        dr = DataRecord.from_agg_parents(schema=Count, parent_records=candidates)
+        dr.count = len(candidates)
 
         # create RecordOpStats object
         record_op_stats = RecordOpStats(
             record_id=dr.id,
-            record_parent_id=dr.parent_id,
-            record_source_idx=dr.source_idx,
+            record_parent_ids=dr.parent_ids,
+            record_source_indices=dr.source_indices,
             record_state=dr.to_dict(include_bytes=False),
             full_op_id=self.get_full_op_id(),
             logical_op_id=self.logical_op_id,
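
Three behavioral changes stand out in this file: aggregates now accept a plain list[DataRecord] rather than a DataRecordSet, provenance becomes multi-parent (from_agg_parents, parent_ids, source_indices), and the average becomes null-tolerant: values that are missing or fail to parse as float are skipped, and an empty group now yields None instead of raising ZeroDivisionError. A standalone sketch of that folding logic (illustrative, not the package's code):

# Null-tolerant average mirroring the update/finalize logic above: skip
# values that are missing or do not parse as float, and return None when
# no value survives (avoiding ZeroDivisionError on empty groups).
def tolerant_average(values: list) -> float | None:
    total, count = 0.0, 0
    for v in values:
        if v is None:
            continue
        try:
            total += float(v)
        except (TypeError, ValueError):
            continue
        count += 1
    return total / count if count > 0 else None

assert tolerant_average([1, "2", None, "n/a"]) == 1.5
assert tolerant_average([]) is None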

palimpzest/query/operators/compute.py (new file)
@@ -0,0 +1,201 @@
+import functools
+import inspect
+import os
+import time
+from typing import Any
+
+from smolagents import CodeAgent, LiteLLMModel, tool
+
+from palimpzest.core.data.context import Context
+from palimpzest.core.data.context_manager import ContextManager
+from palimpzest.core.elements.records import DataRecord, DataRecordSet
+from palimpzest.core.models import GenerationStats, OperatorCostEstimates, RecordOpStats
+from palimpzest.query.operators.physical import PhysicalOperator
+
+# TODO: need to store final executed code in compute() operator so that humans can debug when human-in-the-loop
+
+def make_tool(bound_method):
+    # Get the original function and bound instance
+    func = bound_method.__func__
+    instance = bound_method.__self__
+
+    # Get the signature and remove 'self'
+    sig = inspect.signature(func)
+    params = list(sig.parameters.values())[1:]  # skip 'self'
+    new_sig = inspect.Signature(parameters=params, return_annotation=sig.return_annotation)
+
+    # Create a wrapper function dynamically
+    @functools.wraps(func)
+    def wrapper(*args, **kwargs):
+        return func(instance, *args, **kwargs)
+
+    # Update the __signature__ to reflect the new one without 'self'
+    wrapper.__signature__ = new_sig
+
+    return wrapper
+
+
+class SmolAgentsCompute(PhysicalOperator):
+    """
+    """
+    def __init__(self, context_id: str, instruction: str, additional_contexts: list[Context] | None = None, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.context_id = context_id
+        self.instruction = instruction
+        self.additional_contexts = [] if additional_contexts is None else additional_contexts
+        # self.model_id = "anthropic/claude-3-7-sonnet-latest"
+        self.model_id = "openai/gpt-4o-mini-2024-07-18"
+        # self.model_id = "openai/gpt-4o-2024-08-06"
+        api_key = os.getenv("ANTHROPIC_API_KEY") if "anthropic" in self.model_id else os.getenv("OPENAI_API_KEY")
+        self.model = LiteLLMModel(model_id=self.model_id, api_key=api_key)
+
+    def __str__(self):
+        op = super().__str__()
+        op += f" Context ID: {self.context_id:20s}\n"
+        op += f" Instruction: {self.instruction:20s}\n"
+        op += f" Add. Ctxs: {self.additional_contexts}\n"
+        return op
+
+    def get_id_params(self):
+        id_params = super().get_id_params()
+        return {
+            "context_id": self.context_id,
+            "instruction": self.instruction,
+            "additional_contexts": self.additional_contexts,
+            **id_params,
+        }
+
+    def get_op_params(self):
+        op_params = super().get_op_params()
+        return {
+            "context_id": self.context_id,
+            "instruction": self.instruction,
+            "additional_contexts": self.additional_contexts,
+            **op_params,
+        }
+
+    def naive_cost_estimates(self, source_op_cost_estimates: OperatorCostEstimates) -> OperatorCostEstimates:
+        return OperatorCostEstimates(
+            cardinality=source_op_cost_estimates.cardinality,
+            time_per_record=100,
+            cost_per_record=1,
+            quality=1.0,
+        )
+
+    def _create_record_set(
+        self,
+        candidate: DataRecord,
+        generation_stats: GenerationStats,
+        total_time: float,
+        answer: dict[str, Any],
+    ) -> DataRecordSet:
+        """
+        Given an input DataRecord and a determination of whether it passed the filter or not,
+        construct the resulting RecordSet.
+        """
+        # create new DataRecord and set passed_operator attribute
+        dr = DataRecord.from_parent(self.output_schema, parent_record=candidate)
+        for field in self.output_schema.model_fields:
+            if field in answer:
+                dr[field] = answer[field]
+
+        # create RecordOpStats object
+        record_op_stats = RecordOpStats(
+            record_id=dr.id,
+            record_parent_ids=dr.parent_ids,
+            record_source_indices=dr.source_indices,
+            record_state=dr.to_dict(include_bytes=False),
+            full_op_id=self.get_full_op_id(),
+            logical_op_id=self.logical_op_id,
+            op_name=self.op_name(),
+            time_per_record=total_time,
+            cost_per_record=generation_stats.cost_per_record,
+            model_name=self.get_model_name(),
+            total_input_tokens=generation_stats.total_input_tokens,
+            total_output_tokens=generation_stats.total_output_tokens,
+            total_input_cost=generation_stats.total_input_cost,
+            total_output_cost=generation_stats.total_output_cost,
+            llm_call_duration_secs=generation_stats.llm_call_duration_secs,
+            fn_call_duration_secs=generation_stats.fn_call_duration_secs,
+            total_llm_calls=generation_stats.total_llm_calls,
+            total_embedding_llm_calls=generation_stats.total_embedding_llm_calls,
+            answer={k: v.description if isinstance(v, Context) else v for k, v in answer.items()},
+            op_details={k: str(v) for k, v in self.get_id_params().items()},
+        )
+
+        return DataRecordSet([dr], [record_op_stats])
+
+    def __call__(self, candidate: DataRecord) -> Any:
+        start_time = time.time()
+
+        # get the input context object and its tools
+        input_context: Context = candidate.context
+        description = input_context.description
+        tools = [tool(make_tool(f)) for f in input_context.tools]
+
+        # update the description to include any additional contexts
+        for ctx in self.additional_contexts:
+            # TODO: remove additional context if it is an ancestor of the input context
+            # (not just if it is equal to the input context)
+            if ctx.id == input_context.id:
+                continue
+            description += f"\n\nHere is some additional Context which may be useful:\n\n{ctx.description}"
+
+        # perform the computation
+        instructions = f"\n\nHere is a description of the Context whose data you will be working with, as well as any previously computed results:\n\n{description}"
+        agent = CodeAgent(
+            tools=tools,
+            model=self.model,
+            add_base_tools=False,
+            instructions=instructions,
+            return_full_result=True,
+            additional_authorized_imports=["pandas", "io", "os"],
+            planning_interval=4,
+            max_steps=30,
+        )
+        result = agent.run(self.instruction)
+        # NOTE: you can see the system prompt with `agent.memory.system_prompt.system_prompt`
+        # full_steps = agent.memory.get_full_steps()
+
+        # compute generation stats
+        response = result.output
+        input_tokens = result.token_usage.input_tokens
+        output_tokens = result.token_usage.output_tokens
+        cost_per_input_token = (3.0 / 1e6) if "anthropic" in self.model_id else (0.15 / 1e6)  # (2.5 / 1e6) #
+        cost_per_output_token = (15.0 / 1e6) if "anthropic" in self.model_id else (0.6 / 1e6)  # (10.0 / 1e6) #
+        input_cost = input_tokens * cost_per_input_token
+        output_cost = output_tokens * cost_per_output_token
+        generation_stats = GenerationStats(
+            model_name=self.model_id,
+            total_input_tokens=input_tokens,
+            total_output_tokens=output_tokens,
+            total_input_cost=input_cost,
+            total_output_cost=output_cost,
+            cost_per_record=input_cost + output_cost,
+            llm_call_duration_secs=time.time() - start_time,
+        )
+
+        # update the description of the computed Context to include the result
+        new_description = f"RESULT: {response}\n\n"
+        cm = ContextManager()
+        cm.update_context(id=self.context_id, description=new_description)
+
+        # create and return record set
+        field_answers = {
+            "context": cm.get_context(id=self.context_id),
+            f"result-{self.context_id}": response,
+        }
+        record_set = self._create_record_set(
+            candidate,
+            generation_stats,
+            time.time() - start_time,
+            field_answers,
+        )
+
+        return record_set
+
+        # import json; json.dumps(agent.memory.get_full_steps())
+        # agent.memory.get_full_steps()[1].keys()
+        # dict_keys(['step_number', 'timing', 'model_input_messages', 'tool_calls', 'error', 'model_output_message', 'model_output', 'code_action', 'observations', 'observations_images',
+        # 'action_output', 'token_usage', 'is_final_answer'])
+        # agent.memory.get_full_steps()[1]['action_output']
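
The make_tool helper above exists because smolagents' tool decorator expects a plain function whose signature matches what the agent will call, while a Context exposes its tools as bound methods whose signatures still begin with self. The trick is to rebuild the signature without the first parameter and attach it to a wrapper that re-binds the instance. A self-contained demonstration of the same technique (the Store class is illustrative):

import functools
import inspect

def strip_self(bound_method):
    # same technique as make_tool above: unwrap the bound method, drop
    # 'self' from the signature, and re-bind the instance in a wrapper
    func = bound_method.__func__
    instance = bound_method.__self__
    sig = inspect.signature(func)
    new_sig = sig.replace(parameters=list(sig.parameters.values())[1:])

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        return func(instance, *args, **kwargs)

    wrapper.__signature__ = new_sig
    return wrapper

class Store:  # illustrative class, not from palimpzest
    def lookup(self, key: str) -> str:
        return f"value-for-{key}"

fn = strip_self(Store().lookup)
print(inspect.signature(fn))  # (key: str) -> str, with no 'self'
print(fn("alpha"))            # value-for-alpha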

palimpzest/query/operators/convert.py
@@ -4,6 +4,8 @@ import time
 from abc import ABC, abstractmethod
 from typing import Callable
 
+from pydantic.fields import FieldInfo
+
 from palimpzest.constants import (
     MODEL_CARDS,
     NAIVE_EST_NUM_INPUT_TOKENS,
@@ -13,12 +15,10 @@ from palimpzest.constants import (
     Model,
     PromptStrategy,
 )
-from palimpzest.core.data.dataclasses import GenerationStats, OperatorCostEstimates, RecordOpStats
 from palimpzest.core.elements.records import DataRecord, DataRecordSet
-from palimpzest.core.lib.fields import Field
-from palimpzest.query.generators.generators import generator_factory
+from palimpzest.core.models import GenerationStats, OperatorCostEstimates, RecordOpStats
+from palimpzest.query.generators.generators import Generator
 from palimpzest.query.operators.physical import PhysicalOperator
-from palimpzest.utils.model_helpers import get_vision_models
 
 
 class ConvertOp(PhysicalOperator, ABC):
@@ -26,14 +26,12 @@
         self,
         cardinality: Cardinality = Cardinality.ONE_TO_ONE,
         udf: Callable | None = None,
-        desc: str | None = None,
         *args,
         **kwargs,
     ):
         super().__init__(*args, **kwargs)
         self.cardinality = cardinality
         self.udf = udf
-        self.desc = desc
 
     def get_id_params(self):
         id_params = super().get_id_params()
@@ -47,7 +45,7 @@
 
     def get_op_params(self):
         op_params = super().get_op_params()
-        op_params = {"cardinality": self.cardinality, "udf": self.udf, "desc": self.desc, **op_params}
+        op_params = {"cardinality": self.cardinality, "udf": self.udf, **op_params}
 
         return op_params
 
@@ -78,8 +76,8 @@
             setattr(dr, field, getattr(candidate, field))
 
         # get input field names and output field names
-        input_fields = self.input_schema.field_names()
-        output_fields = self.output_schema.field_names()
+        input_fields = list(self.input_schema.model_fields)
+        output_fields = list(self.output_schema.model_fields)
 
         # parse newly generated fields from the field_answers dictionary for this field; if the list
         # of generated values is shorter than the number of records, we fill in with None
@@ -112,8 +110,8 @@
         record_op_stats_lst = [
             RecordOpStats(
                 record_id=dr.id,
-                record_parent_id=dr.parent_id,
-                record_source_idx=dr.source_idx,
+                record_parent_ids=dr.parent_ids,
+                record_source_indices=dr.source_indices,
                 record_state=dr.to_dict(include_bytes=False),
                 full_op_id=self.get_full_op_id(),
                 logical_op_id=self.logical_op_id,
@@ -122,7 +120,7 @@
                 cost_per_record=per_record_stats.cost_per_record,
                 model_name=self.get_model_name(),
                 answer={field_name: getattr(dr, field_name) for field_name in field_names},
-                input_fields=self.input_schema.field_names(),
+                input_fields=list(self.input_schema.model_fields),
                 generated_fields=field_names,
                 total_input_tokens=per_record_stats.total_input_tokens,
                 total_output_tokens=per_record_stats.total_output_tokens,
@@ -148,7 +146,7 @@
         pass
 
     @abstractmethod
-    def convert(self, candidate: DataRecord, fields: dict[str, Field]) -> tuple[dict[str, list], GenerationStats]:
+    def convert(self, candidate: DataRecord, fields: dict[str, FieldInfo]) -> tuple[dict[str, list], GenerationStats]:
         """
         This abstract method will be implemented by subclasses of ConvertOp to process the input DataRecord
         and generate the value(s) for each of the specified fields. If the convert operator is a one-to-many
@@ -182,7 +180,7 @@
 
         # execute the convert
         field_answers: dict[str, list]
-        fields = {field: field_type for field, field_type in self.output_schema.field_map().items() if field in fields_to_generate}
+        fields = {field: field_type for field, field_type in self.output_schema.model_fields.items() if field in fields_to_generate}
         field_answers, generation_stats = self.convert(candidate=candidate, fields=fields)
         assert all([field in field_answers for field in fields_to_generate]), "Not all fields were generated!"
 
@@ -235,7 +233,7 @@ class NonLLMConvert(ConvertOp):
             quality=1.0,
         )
 
-    def convert(self, candidate: DataRecord, fields: dict[str, Field]) -> tuple[dict[str, list], GenerationStats]:
+    def convert(self, candidate: DataRecord, fields: dict[str, FieldInfo]) -> tuple[dict[str, list], GenerationStats]:
         # apply UDF to input record
         start_time = time.time()
         field_answers = {}
@@ -282,18 +280,21 @@ class LLMConvert(ConvertOp):
         self,
         model: Model,
         prompt_strategy: PromptStrategy = PromptStrategy.COT_QA,
+        reasoning_effort: str | None = None,
         *args,
         **kwargs,
     ):
         super().__init__(*args, **kwargs)
         self.model = model
         self.prompt_strategy = prompt_strategy
+        self.reasoning_effort = reasoning_effort
         if model is not None:
-            self.generator = generator_factory(model, prompt_strategy, self.cardinality, self.verbose)
+            self.generator = Generator(model, prompt_strategy, reasoning_effort, self.api_base, self.cardinality, self.verbose)
 
     def __str__(self):
         op = super().__str__()
         op += f" Prompt Strategy: {self.prompt_strategy}\n"
+        op += f" Reasoning Effort: {self.reasoning_effort}\n"
         return op
 
     def get_id_params(self):
@@ -301,6 +302,7 @@
         id_params = {
             "model": None if self.model is None else self.model.value,
             "prompt_strategy": None if self.prompt_strategy is None else self.prompt_strategy.value,
+            "reasoning_effort": self.reasoning_effort,
             **id_params,
         }
 
@@ -311,6 +313,7 @@
         op_params = {
             "model": self.model,
             "prompt_strategy": self.prompt_strategy,
+            "reasoning_effort": self.reasoning_effort,
             **op_params,
         }
 
@@ -320,7 +323,7 @@
         return None if self.model is None else self.model.value
 
     def is_image_conversion(self) -> bool:
-        return self.model in get_vision_models()
+        return self.prompt_strategy.is_image_prompt()
 
     def naive_cost_estimates(self, source_op_cost_estimates: OperatorCostEstimates) -> OperatorCostEstimates:
         """
@@ -334,13 +337,16 @@
         est_num_output_tokens = NAIVE_EST_NUM_OUTPUT_TOKENS
 
         # get est. of conversion time per record from model card;
-        # NOTE: model will only be None for code synthesis, which uses GPT-3.5 as fallback
         model_name = self.model.value if getattr(self, "model", None) is not None else Model.GPT_4o_MINI.value
         model_conversion_time_per_record = MODEL_CARDS[model_name]["seconds_per_output_token"] * est_num_output_tokens
 
         # get est. of conversion cost (in USD) per record from model card
+        usd_per_input_token = MODEL_CARDS[model_name].get("usd_per_input_token")
+        if getattr(self, "prompt_strategy", None) is not None and self.prompt_strategy.is_audio_prompt():
+            usd_per_input_token = MODEL_CARDS[model_name]["usd_per_audio_input_token"]
+
         model_conversion_usd_per_record = (
-            MODEL_CARDS[model_name]["usd_per_input_token"] * est_num_input_tokens
+            usd_per_input_token * est_num_input_tokens
             + MODEL_CARDS[model_name]["usd_per_output_token"] * est_num_output_tokens
         )
 
@@ -349,7 +355,7 @@
         cardinality = selectivity * source_op_cost_estimates.cardinality
 
         # estimate quality of output based on the strength of the model being used
-        quality = (MODEL_CARDS[model_name]["overall"] / 100.0) * source_op_cost_estimates.quality
+        quality = (MODEL_CARDS[model_name]["overall"] / 100.0)
 
         return OperatorCostEstimates(
             cardinality=cardinality,
@@ -361,7 +367,7 @@
 
 class LLMConvertBonded(LLMConvert):
 
-    def convert(self, candidate: DataRecord, fields: dict[str, Field]) -> tuple[dict[str, list], GenerationStats]:
+    def convert(self, candidate: DataRecord, fields: dict[str, FieldInfo]) -> tuple[dict[str, list], GenerationStats]:
         # get the set of input fields to use for the convert operation
         input_fields = self.get_input_fields()
 
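
The recurring edits in this file (field_names() to list(schema.model_fields), field_map() to model_fields.items(), and Field to pydantic.fields.FieldInfo) indicate that schemas are now plain pydantic models. A minimal sketch of the accessors the new code relies on (the Email model is illustrative and not from the package; requires pydantic v2):

from pydantic import BaseModel
from pydantic.fields import FieldInfo

class Email(BaseModel):  # illustrative schema, not from palimpzest
    sender: str
    subject: str

# list(schema.model_fields) replaces the old schema.field_names()
assert list(Email.model_fields) == ["sender", "subject"]

# schema.model_fields.items() yields (name, FieldInfo) pairs, matching
# the dict[str, FieldInfo] that convert() now accepts
for name, info in Email.model_fields.items():
    assert isinstance(info, FieldInfo)
    print(name, info.annotation)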

palimpzest/query/operators/critique_and_refine_convert.py
@@ -2,10 +2,12 @@ from __future__ import annotations
 
 from typing import Any
 
+from pydantic.fields import FieldInfo
+
 from palimpzest.constants import MODEL_CARDS, Model, PromptStrategy
-from palimpzest.core.data.dataclasses import GenerationStats, OperatorCostEstimates
 from palimpzest.core.elements.records import DataRecord
-from palimpzest.query.generators.generators import generator_factory
+from palimpzest.core.models import GenerationStats, OperatorCostEstimates
+from palimpzest.query.generators.generators import Generator
 from palimpzest.query.operators.convert import LLMConvert
 
 # TYPE DEFINITIONS
@@ -35,8 +37,8 @@ class CriticAndRefineConvert(LLMConvert):
             raise ValueError(f"Unsupported prompt strategy: {self.prompt_strategy}")
 
         # create generators
-        self.critic_generator = generator_factory(self.critic_model, self.critic_prompt_strategy, self.cardinality, self.verbose)
-        self.refine_generator = generator_factory(self.refine_model, self.refinement_prompt_strategy, self.cardinality, self.verbose)
+        self.critic_generator = Generator(self.critic_model, self.critic_prompt_strategy, self.reasoning_effort, self.api_base, self.cardinality, self.verbose)
+        self.refine_generator = Generator(self.refine_model, self.refinement_prompt_strategy, self.reasoning_effort, self.api_base, self.cardinality, self.verbose)
 
     def __str__(self):
         op = super().__str__()
@@ -86,7 +88,7 @@
 
         return naive_op_cost_estimates
 
-    def convert(self, candidate: DataRecord, fields: list[str]) -> tuple[dict[FieldName, list[Any]], GenerationStats]:
+    def convert(self, candidate: DataRecord, fields: dict[str, FieldInfo]) -> tuple[dict[FieldName, list[Any]], GenerationStats]:
         # get input fields
         input_fields = self.get_input_fields()
 
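
As in convert.py, the generator_factory(model, prompt_strategy, cardinality, verbose) calls are replaced with direct Generator(...) construction taking two extra arguments. A hedged migration sketch for downstream callers (positional order inferred from the call sites above; the enum import locations are assumed to be palimpzest.constants, consistent with the imports elsewhere in this diff):

from palimpzest.constants import Cardinality, Model, PromptStrategy
from palimpzest.query.generators.generators import Generator

# 0.7.21: generator = generator_factory(model, prompt_strategy, cardinality, verbose)
generator = Generator(
    Model.GPT_4o_MINI,        # model
    PromptStrategy.COT_QA,    # prompt_strategy
    None,                     # reasoning_effort (new in 0.8.0)
    None,                     # api_base (new in 0.8.0; assumed None for hosted APIs)
    Cardinality.ONE_TO_ONE,   # cardinality
    False,                    # verbose
)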