palimpzest 0.7.21__py3-none-any.whl → 0.8.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- palimpzest/__init__.py +37 -6
- palimpzest/agents/__init__.py +0 -0
- palimpzest/agents/compute_agents.py +0 -0
- palimpzest/agents/search_agents.py +637 -0
- palimpzest/constants.py +343 -209
- palimpzest/core/data/context.py +393 -0
- palimpzest/core/data/context_manager.py +163 -0
- palimpzest/core/data/dataset.py +639 -0
- palimpzest/core/data/{datareaders.py → iter_dataset.py} +202 -126
- palimpzest/core/elements/groupbysig.py +16 -13
- palimpzest/core/elements/records.py +166 -75
- palimpzest/core/lib/schemas.py +152 -390
- palimpzest/core/{data/dataclasses.py → models.py} +306 -170
- palimpzest/policy.py +2 -27
- palimpzest/prompts/__init__.py +35 -5
- palimpzest/prompts/agent_prompts.py +357 -0
- palimpzest/prompts/context_search.py +9 -0
- palimpzest/prompts/convert_prompts.py +62 -6
- palimpzest/prompts/filter_prompts.py +51 -6
- palimpzest/prompts/join_prompts.py +163 -0
- palimpzest/prompts/moa_proposer_convert_prompts.py +6 -6
- palimpzest/prompts/prompt_factory.py +375 -47
- palimpzest/prompts/split_proposer_prompts.py +1 -1
- palimpzest/prompts/util_phrases.py +5 -0
- palimpzest/prompts/validator.py +239 -0
- palimpzest/query/execution/all_sample_execution_strategy.py +134 -76
- palimpzest/query/execution/execution_strategy.py +210 -317
- palimpzest/query/execution/execution_strategy_type.py +5 -7
- palimpzest/query/execution/mab_execution_strategy.py +249 -136
- palimpzest/query/execution/parallel_execution_strategy.py +153 -244
- palimpzest/query/execution/single_threaded_execution_strategy.py +107 -64
- palimpzest/query/generators/generators.py +160 -331
- palimpzest/query/operators/__init__.py +15 -5
- palimpzest/query/operators/aggregate.py +50 -33
- palimpzest/query/operators/compute.py +201 -0
- palimpzest/query/operators/convert.py +33 -19
- palimpzest/query/operators/critique_and_refine_convert.py +7 -5
- palimpzest/query/operators/distinct.py +62 -0
- palimpzest/query/operators/filter.py +26 -16
- palimpzest/query/operators/join.py +403 -0
- palimpzest/query/operators/limit.py +3 -3
- palimpzest/query/operators/logical.py +205 -77
- palimpzest/query/operators/mixture_of_agents_convert.py +10 -8
- palimpzest/query/operators/physical.py +27 -21
- palimpzest/query/operators/project.py +3 -3
- palimpzest/query/operators/rag_convert.py +7 -7
- palimpzest/query/operators/retrieve.py +9 -9
- palimpzest/query/operators/scan.py +81 -42
- palimpzest/query/operators/search.py +524 -0
- palimpzest/query/operators/split_convert.py +10 -8
- palimpzest/query/optimizer/__init__.py +7 -9
- palimpzest/query/optimizer/cost_model.py +108 -441
- palimpzest/query/optimizer/optimizer.py +123 -181
- palimpzest/query/optimizer/optimizer_strategy.py +66 -61
- palimpzest/query/optimizer/plan.py +352 -67
- palimpzest/query/optimizer/primitives.py +43 -19
- palimpzest/query/optimizer/rules.py +484 -646
- palimpzest/query/optimizer/tasks.py +127 -58
- palimpzest/query/processor/config.py +42 -76
- palimpzest/query/processor/query_processor.py +73 -18
- palimpzest/query/processor/query_processor_factory.py +46 -38
- palimpzest/schemabuilder/schema_builder.py +15 -28
- palimpzest/utils/model_helpers.py +32 -77
- palimpzest/utils/progress.py +114 -102
- palimpzest/validator/__init__.py +0 -0
- palimpzest/validator/validator.py +306 -0
- {palimpzest-0.7.21.dist-info → palimpzest-0.8.1.dist-info}/METADATA +6 -1
- palimpzest-0.8.1.dist-info/RECORD +95 -0
- palimpzest/core/lib/fields.py +0 -141
- palimpzest/prompts/code_synthesis_prompts.py +0 -28
- palimpzest/query/execution/random_sampling_execution_strategy.py +0 -240
- palimpzest/query/generators/api_client_factory.py +0 -30
- palimpzest/query/operators/code_synthesis_convert.py +0 -488
- palimpzest/query/operators/map.py +0 -130
- palimpzest/query/processor/nosentinel_processor.py +0 -33
- palimpzest/query/processor/processing_strategy_type.py +0 -28
- palimpzest/query/processor/sentinel_processor.py +0 -88
- palimpzest/query/processor/streaming_processor.py +0 -149
- palimpzest/sets.py +0 -405
- palimpzest/utils/datareader_helpers.py +0 -61
- palimpzest/utils/demo_helpers.py +0 -75
- palimpzest/utils/field_helpers.py +0 -69
- palimpzest/utils/generation_helpers.py +0 -69
- palimpzest/utils/sandbox.py +0 -183
- palimpzest-0.7.21.dist-info/RECORD +0 -95
- /palimpzest/core/{elements/index.py → data/index_dataset.py} +0 -0
- {palimpzest-0.7.21.dist-info → palimpzest-0.8.1.dist-info}/WHEEL +0 -0
- {palimpzest-0.7.21.dist-info → palimpzest-0.8.1.dist-info}/licenses/LICENSE +0 -0
- {palimpzest-0.7.21.dist-info → palimpzest-0.8.1.dist-info}/top_level.txt +0 -0
|
@@ -6,9 +6,12 @@ from palimpzest.query.operators.convert import ConvertOp as _ConvertOp
|
|
|
6
6
|
from palimpzest.query.operators.convert import LLMConvert as _LLMConvert
|
|
7
7
|
from palimpzest.query.operators.convert import LLMConvertBonded as _LLMConvertBonded
|
|
8
8
|
from palimpzest.query.operators.convert import NonLLMConvert as _NonLLMConvert
|
|
9
|
+
from palimpzest.query.operators.distinct import DistinctOp as _DistinctOp
|
|
9
10
|
from palimpzest.query.operators.filter import FilterOp as _FilterOp
|
|
10
11
|
from palimpzest.query.operators.filter import LLMFilter as _LLMFilter
|
|
11
12
|
from palimpzest.query.operators.filter import NonLLMFilter as _NonLLMFilter
|
|
13
|
+
from palimpzest.query.operators.join import JoinOp as _JoinOp
|
|
14
|
+
from palimpzest.query.operators.join import NestedLoopsJoin as _NestedLoopsJoin
|
|
12
15
|
from palimpzest.query.operators.limit import LimitScanOp as _LimitScanOp
|
|
13
16
|
from palimpzest.query.operators.logical import (
|
|
14
17
|
Aggregate as _Aggregate,
|
|
@@ -17,10 +20,10 @@ from palimpzest.query.operators.logical import (
|
|
|
17
20
|
BaseScan as _BaseScan,
|
|
18
21
|
)
|
|
19
22
|
from palimpzest.query.operators.logical import (
|
|
20
|
-
|
|
23
|
+
ConvertScan as _ConvertScan,
|
|
21
24
|
)
|
|
22
25
|
from palimpzest.query.operators.logical import (
|
|
23
|
-
|
|
26
|
+
Distinct as _Distinct,
|
|
24
27
|
)
|
|
25
28
|
from palimpzest.query.operators.logical import (
|
|
26
29
|
FilteredScan as _FilteredScan,
|
|
@@ -28,6 +31,9 @@ from palimpzest.query.operators.logical import (
|
|
|
28
31
|
from palimpzest.query.operators.logical import (
|
|
29
32
|
GroupByAggregate as _GroupByAggregate,
|
|
30
33
|
)
|
|
34
|
+
from palimpzest.query.operators.logical import (
|
|
35
|
+
JoinOp as _LogicalJoinOp,
|
|
36
|
+
)
|
|
31
37
|
from palimpzest.query.operators.logical import (
|
|
32
38
|
LimitScan as _LimitScan,
|
|
33
39
|
)
|
|
@@ -44,7 +50,6 @@ from palimpzest.query.operators.mixture_of_agents_convert import MixtureOfAgents
|
|
|
44
50
|
from palimpzest.query.operators.physical import PhysicalOperator as _PhysicalOperator
|
|
45
51
|
from palimpzest.query.operators.project import ProjectOp as _ProjectOp
|
|
46
52
|
from palimpzest.query.operators.retrieve import RetrieveOp as _RetrieveOp
|
|
47
|
-
from palimpzest.query.operators.scan import CacheScanDataOp as _CacheScanDataOp
|
|
48
53
|
from palimpzest.query.operators.scan import MarshalAndScanDataOp as _MarshalAndScanDataOp
|
|
49
54
|
from palimpzest.query.operators.scan import ScanPhysicalOp as _ScanPhysicalOp
|
|
50
55
|
|
|
@@ -52,10 +57,11 @@ LOGICAL_OPERATORS = [
|
|
|
52
57
|
_LogicalOperator,
|
|
53
58
|
_Aggregate,
|
|
54
59
|
_BaseScan,
|
|
55
|
-
_CacheScan,
|
|
56
60
|
_ConvertScan,
|
|
61
|
+
_Distinct,
|
|
57
62
|
_FilteredScan,
|
|
58
63
|
_GroupByAggregate,
|
|
64
|
+
_LogicalJoinOp,
|
|
59
65
|
_LimitScan,
|
|
60
66
|
_Project,
|
|
61
67
|
_RetrieveScan,
|
|
@@ -66,10 +72,14 @@ PHYSICAL_OPERATORS = (
|
|
|
66
72
|
[_AggregateOp, _ApplyGroupByOp, _AverageAggregateOp, _CountAggregateOp]
|
|
67
73
|
# convert
|
|
68
74
|
+ [_ConvertOp, _NonLLMConvert, _LLMConvert, _LLMConvertBonded]
|
|
75
|
+
# distinct
|
|
76
|
+
+ [_DistinctOp]
|
|
69
77
|
# scan
|
|
70
|
-
+ [_ScanPhysicalOp, _MarshalAndScanDataOp
|
|
78
|
+
+ [_ScanPhysicalOp, _MarshalAndScanDataOp]
|
|
71
79
|
# filter
|
|
72
80
|
+ [_FilterOp, _NonLLMFilter, _LLMFilter]
|
|
81
|
+
# join
|
|
82
|
+
+ [_JoinOp, _NestedLoopsJoin]
|
|
73
83
|
# limit
|
|
74
84
|
+ [_LimitScanOp]
|
|
75
85
|
# mixture-of-agents
|
|
@@ -3,10 +3,10 @@ from __future__ import annotations
|
|
|
3
3
|
import time
|
|
4
4
|
|
|
5
5
|
from palimpzest.constants import NAIVE_EST_NUM_GROUPS, AggFunc
|
|
6
|
-
from palimpzest.core.data.dataclasses import OperatorCostEstimates, RecordOpStats
|
|
7
6
|
from palimpzest.core.elements.groupbysig import GroupBySig
|
|
8
7
|
from palimpzest.core.elements.records import DataRecord, DataRecordSet
|
|
9
|
-
from palimpzest.core.lib.schemas import
|
|
8
|
+
from palimpzest.core.lib.schemas import Average, Count
|
|
9
|
+
from palimpzest.core.models import OperatorCostEstimates, RecordOpStats
|
|
10
10
|
from palimpzest.query.operators.physical import PhysicalOperator
|
|
11
11
|
|
|
12
12
|
|
|
@@ -16,7 +16,7 @@ class AggregateOp(PhysicalOperator):
|
|
|
16
16
|
__call__ methods. Thus, we use a slightly modified abstract base class for
|
|
17
17
|
these operators.
|
|
18
18
|
"""
|
|
19
|
-
def __call__(self, candidates:
|
|
19
|
+
def __call__(self, candidates: list[DataRecord]) -> DataRecordSet:
|
|
20
20
|
raise NotImplementedError("Using __call__ from abstract method")
|
|
21
21
|
|
|
22
22
|
|
|
@@ -67,6 +67,8 @@ class ApplyGroupByOp(AggregateOp):
|
|
|
67
67
|
return state + 1
|
|
68
68
|
elif func.lower() == "average":
|
|
69
69
|
sum, cnt = state
|
|
70
|
+
if val is None:
|
|
71
|
+
return (sum, cnt)
|
|
70
72
|
return (sum + val, cnt + 1)
|
|
71
73
|
else:
|
|
72
74
|
raise Exception("Unknown agg function " + func)
|
|
@@ -77,11 +79,11 @@ class ApplyGroupByOp(AggregateOp):
|
|
|
77
79
|
return state
|
|
78
80
|
elif func.lower() == "average":
|
|
79
81
|
sum, cnt = state
|
|
80
|
-
return float(sum) / cnt
|
|
82
|
+
return float(sum) / cnt if cnt > 0 else None
|
|
81
83
|
else:
|
|
82
84
|
raise Exception("Unknown agg function " + func)
|
|
83
85
|
|
|
84
|
-
def __call__(self, candidates:
|
|
86
|
+
def __call__(self, candidates: list[DataRecord]) -> DataRecordSet:
|
|
85
87
|
start_time = time.time()
|
|
86
88
|
|
|
87
89
|
# build group array
|
|
@@ -107,17 +109,13 @@ class ApplyGroupByOp(AggregateOp):
|
|
|
107
109
|
agg_state[group] = state
|
|
108
110
|
|
|
109
111
|
# return list of data records (one per group)
|
|
110
|
-
drs = []
|
|
112
|
+
drs: list[DataRecord] = []
|
|
111
113
|
group_by_fields = self.group_by_sig.group_by_fields
|
|
112
114
|
agg_fields = self.group_by_sig.get_agg_field_names()
|
|
113
115
|
for g in agg_state:
|
|
114
|
-
dr = DataRecord
|
|
115
|
-
# NOTE: this will set the parent_id and source_idx to be the id of the final source record;
|
|
116
|
-
# in the near future we may want to have parent_id accept a list of ids
|
|
117
|
-
dr = DataRecord.from_parent(
|
|
116
|
+
dr = DataRecord.from_agg_parents(
|
|
118
117
|
schema=self.group_by_sig.output_schema(),
|
|
119
|
-
|
|
120
|
-
project_cols=[],
|
|
118
|
+
parent_records=candidates,
|
|
121
119
|
)
|
|
122
120
|
for i in range(0, len(g)):
|
|
123
121
|
k = g[i]
|
|
@@ -135,8 +133,8 @@ class ApplyGroupByOp(AggregateOp):
|
|
|
135
133
|
for dr in drs:
|
|
136
134
|
record_op_stats = RecordOpStats(
|
|
137
135
|
record_id=dr.id,
|
|
138
|
-
|
|
139
|
-
|
|
136
|
+
record_parent_ids=dr.parent_ids,
|
|
137
|
+
record_source_indices=dr.source_indices,
|
|
140
138
|
record_state=dr.to_dict(include_bytes=False),
|
|
141
139
|
full_op_id=self.get_full_op_id(),
|
|
142
140
|
logical_op_id=self.logical_op_id,
|
|
@@ -155,13 +153,20 @@ class AverageAggregateOp(AggregateOp):
|
|
|
155
153
|
# NOTE: we don't actually need / use agg_func here (yet)
|
|
156
154
|
|
|
157
155
|
def __init__(self, agg_func: AggFunc, *args, **kwargs):
|
|
158
|
-
|
|
156
|
+
# enforce that output schema is correct
|
|
157
|
+
assert kwargs["output_schema"] == Average, "AverageAggregateOp requires output_schema to be Average"
|
|
158
|
+
|
|
159
|
+
# enforce that input schema is a single numeric field
|
|
160
|
+
input_field_types = list(kwargs["input_schema"].model_fields.values())
|
|
161
|
+
assert len(input_field_types) == 1, "AverageAggregateOp requires input_schema to have exactly one field"
|
|
162
|
+
numeric_field_types = [bool, int, float, bool | None, int | None, float | None, int | float, int | float | None]
|
|
163
|
+
is_numeric = input_field_types[0].annotation in numeric_field_types
|
|
164
|
+
assert is_numeric, f"AverageAggregateOp requires input_schema to have a numeric field type, i.e. one of: {numeric_field_types}\nGot: {input_field_types[0]}"
|
|
165
|
+
|
|
166
|
+
# call parent constructor
|
|
159
167
|
super().__init__(*args, **kwargs)
|
|
160
168
|
self.agg_func = agg_func
|
|
161
169
|
|
|
162
|
-
if not self.input_schema.get_desc() == Number.get_desc():
|
|
163
|
-
raise Exception("Aggregate function AVERAGE is only defined over Numbers")
|
|
164
|
-
|
|
165
170
|
def __str__(self):
|
|
166
171
|
op = super().__str__()
|
|
167
172
|
op += f" Function: {str(self.agg_func)}\n"
|
|
@@ -184,19 +189,29 @@ class AverageAggregateOp(AggregateOp):
|
|
|
184
189
|
quality=1.0,
|
|
185
190
|
)
|
|
186
191
|
|
|
187
|
-
def __call__(self, candidates:
|
|
192
|
+
def __call__(self, candidates: list[DataRecord]) -> DataRecordSet:
|
|
188
193
|
start_time = time.time()
|
|
189
194
|
|
|
190
|
-
# NOTE:
|
|
191
|
-
#
|
|
192
|
-
|
|
193
|
-
|
|
195
|
+
# NOTE: we currently do not guarantee that input values conform to their specified type;
|
|
196
|
+
# as a result, we simply omit any values which do not parse to a float from the average
|
|
197
|
+
# NOTE: right now we perform a check in the constructor which enforces that the input_schema
|
|
198
|
+
# has a single field which is numeric in nature; in the future we may want to have a
|
|
199
|
+
# cleaner way of computing the value (rather than `float(list(candidate...))` below)
|
|
200
|
+
dr = DataRecord.from_agg_parents(schema=Average, parent_records=candidates)
|
|
201
|
+
summation, total = 0, 0
|
|
202
|
+
for candidate in candidates:
|
|
203
|
+
try:
|
|
204
|
+
summation += float(list(candidate.to_dict().values())[0])
|
|
205
|
+
total += 1
|
|
206
|
+
except Exception:
|
|
207
|
+
pass
|
|
208
|
+
dr.average = summation / total
|
|
194
209
|
|
|
195
210
|
# create RecordOpStats object
|
|
196
211
|
record_op_stats = RecordOpStats(
|
|
197
212
|
record_id=dr.id,
|
|
198
|
-
|
|
199
|
-
|
|
213
|
+
record_parent_ids=dr.parent_ids,
|
|
214
|
+
record_source_indices=dr.source_indices,
|
|
200
215
|
record_state=dr.to_dict(include_bytes=False),
|
|
201
216
|
full_op_id=self.get_full_op_id(),
|
|
202
217
|
logical_op_id=self.logical_op_id,
|
|
@@ -212,7 +227,10 @@ class CountAggregateOp(AggregateOp):
|
|
|
212
227
|
# NOTE: we don't actually need / use agg_func here (yet)
|
|
213
228
|
|
|
214
229
|
def __init__(self, agg_func: AggFunc, *args, **kwargs):
|
|
215
|
-
|
|
230
|
+
# enforce that output schema is correct
|
|
231
|
+
assert kwargs["output_schema"] == Count, "CountAggregateOp requires output_schema to be Count"
|
|
232
|
+
|
|
233
|
+
# call parent constructor
|
|
216
234
|
super().__init__(*args, **kwargs)
|
|
217
235
|
self.agg_func = agg_func
|
|
218
236
|
|
|
@@ -238,19 +256,18 @@ class CountAggregateOp(AggregateOp):
|
|
|
238
256
|
quality=1.0,
|
|
239
257
|
)
|
|
240
258
|
|
|
241
|
-
def __call__(self, candidates:
|
|
259
|
+
def __call__(self, candidates: list[DataRecord]) -> DataRecordSet:
|
|
242
260
|
start_time = time.time()
|
|
243
261
|
|
|
244
|
-
#
|
|
245
|
-
|
|
246
|
-
dr =
|
|
247
|
-
dr.value = len(candidates)
|
|
262
|
+
# create new DataRecord
|
|
263
|
+
dr = DataRecord.from_agg_parents(schema=Count, parent_records=candidates)
|
|
264
|
+
dr.count = len(candidates)
|
|
248
265
|
|
|
249
266
|
# create RecordOpStats object
|
|
250
267
|
record_op_stats = RecordOpStats(
|
|
251
268
|
record_id=dr.id,
|
|
252
|
-
|
|
253
|
-
|
|
269
|
+
record_parent_ids=dr.parent_ids,
|
|
270
|
+
record_source_indices=dr.source_indices,
|
|
254
271
|
record_state=dr.to_dict(include_bytes=False),
|
|
255
272
|
full_op_id=self.get_full_op_id(),
|
|
256
273
|
logical_op_id=self.logical_op_id,
|
|
@@ -0,0 +1,201 @@
|
|
|
1
|
+
import functools
|
|
2
|
+
import inspect
|
|
3
|
+
import os
|
|
4
|
+
import time
|
|
5
|
+
from typing import Any
|
|
6
|
+
|
|
7
|
+
from smolagents import CodeAgent, LiteLLMModel, tool
|
|
8
|
+
|
|
9
|
+
from palimpzest.core.data.context import Context
|
|
10
|
+
from palimpzest.core.data.context_manager import ContextManager
|
|
11
|
+
from palimpzest.core.elements.records import DataRecord, DataRecordSet
|
|
12
|
+
from palimpzest.core.models import GenerationStats, OperatorCostEstimates, RecordOpStats
|
|
13
|
+
from palimpzest.query.operators.physical import PhysicalOperator
|
|
14
|
+
|
|
15
|
+
# TODO: need to store final executed code in compute() operator so that humans can debug when human-in-the-loop
|
|
16
|
+
|
|
17
|
+
def make_tool(bound_method):
|
|
18
|
+
# Get the original function and bound instance
|
|
19
|
+
func = bound_method.__func__
|
|
20
|
+
instance = bound_method.__self__
|
|
21
|
+
|
|
22
|
+
# Get the signature and remove 'self'
|
|
23
|
+
sig = inspect.signature(func)
|
|
24
|
+
params = list(sig.parameters.values())[1:] # skip 'self'
|
|
25
|
+
new_sig = inspect.Signature(parameters=params, return_annotation=sig.return_annotation)
|
|
26
|
+
|
|
27
|
+
# Create a wrapper function dynamically
|
|
28
|
+
@functools.wraps(func)
|
|
29
|
+
def wrapper(*args, **kwargs):
|
|
30
|
+
return func(instance, *args, **kwargs)
|
|
31
|
+
|
|
32
|
+
# Update the __signature__ to reflect the new one without 'self'
|
|
33
|
+
wrapper.__signature__ = new_sig
|
|
34
|
+
|
|
35
|
+
return wrapper
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
class SmolAgentsCompute(PhysicalOperator):
|
|
39
|
+
"""
|
|
40
|
+
"""
|
|
41
|
+
def __init__(self, context_id: str, instruction: str, additional_contexts: list[Context] | None = None, *args, **kwargs):
|
|
42
|
+
super().__init__(*args, **kwargs)
|
|
43
|
+
self.context_id = context_id
|
|
44
|
+
self.instruction = instruction
|
|
45
|
+
self.additional_contexts = [] if additional_contexts is None else additional_contexts
|
|
46
|
+
# self.model_id = "anthropic/claude-3-7-sonnet-latest"
|
|
47
|
+
self.model_id = "openai/gpt-4o-mini-2024-07-18"
|
|
48
|
+
# self.model_id = "openai/gpt-4o-2024-08-06"
|
|
49
|
+
api_key = os.getenv("ANTHROPIC_API_KEY") if "anthropic" in self.model_id else os.getenv("OPENAI_API_KEY")
|
|
50
|
+
self.model = LiteLLMModel(model_id=self.model_id, api_key=api_key)
|
|
51
|
+
|
|
52
|
+
def __str__(self):
|
|
53
|
+
op = super().__str__()
|
|
54
|
+
op += f" Context ID: {self.context_id:20s}\n"
|
|
55
|
+
op += f" Instruction: {self.instruction:20s}\n"
|
|
56
|
+
op += f" Add. Ctxs: {self.additional_contexts}\n"
|
|
57
|
+
return op
|
|
58
|
+
|
|
59
|
+
def get_id_params(self):
|
|
60
|
+
id_params = super().get_id_params()
|
|
61
|
+
return {
|
|
62
|
+
"context_id": self.context_id,
|
|
63
|
+
"instruction": self.instruction,
|
|
64
|
+
"additional_contexts": self.additional_contexts,
|
|
65
|
+
**id_params,
|
|
66
|
+
}
|
|
67
|
+
|
|
68
|
+
def get_op_params(self):
|
|
69
|
+
op_params = super().get_op_params()
|
|
70
|
+
return {
|
|
71
|
+
"context_id": self.context_id,
|
|
72
|
+
"instruction": self.instruction,
|
|
73
|
+
"additional_contexts": self.additional_contexts,
|
|
74
|
+
**op_params,
|
|
75
|
+
}
|
|
76
|
+
|
|
77
|
+
def naive_cost_estimates(self, source_op_cost_estimates: OperatorCostEstimates) -> OperatorCostEstimates:
|
|
78
|
+
return OperatorCostEstimates(
|
|
79
|
+
cardinality=source_op_cost_estimates.cardinality,
|
|
80
|
+
time_per_record=100,
|
|
81
|
+
cost_per_record=1,
|
|
82
|
+
quality=1.0,
|
|
83
|
+
)
|
|
84
|
+
|
|
85
|
+
def _create_record_set(
|
|
86
|
+
self,
|
|
87
|
+
candidate: DataRecord,
|
|
88
|
+
generation_stats: GenerationStats,
|
|
89
|
+
total_time: float,
|
|
90
|
+
answer: dict[str, Any],
|
|
91
|
+
) -> DataRecordSet:
|
|
92
|
+
"""
|
|
93
|
+
Given an input DataRecord and a determination of whether it passed the filter or not,
|
|
94
|
+
construct the resulting RecordSet.
|
|
95
|
+
"""
|
|
96
|
+
# create new DataRecord and set passed_operator attribute
|
|
97
|
+
dr = DataRecord.from_parent(self.output_schema, parent_record=candidate)
|
|
98
|
+
for field in self.output_schema.model_fields:
|
|
99
|
+
if field in answer:
|
|
100
|
+
dr[field] = answer[field]
|
|
101
|
+
|
|
102
|
+
# create RecordOpStats object
|
|
103
|
+
record_op_stats = RecordOpStats(
|
|
104
|
+
record_id=dr.id,
|
|
105
|
+
record_parent_ids=dr.parent_ids,
|
|
106
|
+
record_source_indices=dr.source_indices,
|
|
107
|
+
record_state=dr.to_dict(include_bytes=False),
|
|
108
|
+
full_op_id=self.get_full_op_id(),
|
|
109
|
+
logical_op_id=self.logical_op_id,
|
|
110
|
+
op_name=self.op_name(),
|
|
111
|
+
time_per_record=total_time,
|
|
112
|
+
cost_per_record=generation_stats.cost_per_record,
|
|
113
|
+
model_name=self.get_model_name(),
|
|
114
|
+
total_input_tokens=generation_stats.total_input_tokens,
|
|
115
|
+
total_output_tokens=generation_stats.total_output_tokens,
|
|
116
|
+
total_input_cost=generation_stats.total_input_cost,
|
|
117
|
+
total_output_cost=generation_stats.total_output_cost,
|
|
118
|
+
llm_call_duration_secs=generation_stats.llm_call_duration_secs,
|
|
119
|
+
fn_call_duration_secs=generation_stats.fn_call_duration_secs,
|
|
120
|
+
total_llm_calls=generation_stats.total_llm_calls,
|
|
121
|
+
total_embedding_llm_calls=generation_stats.total_embedding_llm_calls,
|
|
122
|
+
answer={k: v.description if isinstance(v, Context) else v for k, v in answer.items()},
|
|
123
|
+
op_details={k: str(v) for k, v in self.get_id_params().items()},
|
|
124
|
+
)
|
|
125
|
+
|
|
126
|
+
return DataRecordSet([dr], [record_op_stats])
|
|
127
|
+
|
|
128
|
+
def __call__(self, candidate: DataRecord) -> Any:
|
|
129
|
+
start_time = time.time()
|
|
130
|
+
|
|
131
|
+
# get the input context object and its tools
|
|
132
|
+
input_context: Context = candidate.context
|
|
133
|
+
description = input_context.description
|
|
134
|
+
tools = [tool(make_tool(f)) for f in input_context.tools]
|
|
135
|
+
|
|
136
|
+
# update the description to include any additional contexts
|
|
137
|
+
for ctx in self.additional_contexts:
|
|
138
|
+
# TODO: remove additional context if it is an ancestor of the input context
|
|
139
|
+
# (not just if it is equal to the input context)
|
|
140
|
+
if ctx.id == input_context.id:
|
|
141
|
+
continue
|
|
142
|
+
description += f"\n\nHere is some additional Context which may be useful:\n\n{ctx.description}"
|
|
143
|
+
|
|
144
|
+
# perform the computation
|
|
145
|
+
instructions = f"\n\nHere is a description of the Context whose data you will be working with, as well as any previously computed results:\n\n{description}"
|
|
146
|
+
agent = CodeAgent(
|
|
147
|
+
tools=tools,
|
|
148
|
+
model=self.model,
|
|
149
|
+
add_base_tools=False,
|
|
150
|
+
instructions=instructions,
|
|
151
|
+
return_full_result=True,
|
|
152
|
+
additional_authorized_imports=["pandas", "io", "os"],
|
|
153
|
+
planning_interval=4,
|
|
154
|
+
max_steps=30,
|
|
155
|
+
)
|
|
156
|
+
result = agent.run(self.instruction)
|
|
157
|
+
# NOTE: you can see the system prompt with `agent.memory.system_prompt.system_prompt`
|
|
158
|
+
# full_steps = agent.memory.get_full_steps()
|
|
159
|
+
|
|
160
|
+
# compute generation stats
|
|
161
|
+
response = result.output
|
|
162
|
+
input_tokens = result.token_usage.input_tokens
|
|
163
|
+
output_tokens = result.token_usage.output_tokens
|
|
164
|
+
cost_per_input_token = (3.0 / 1e6) if "anthropic" in self.model_id else (0.15 / 1e6) # (2.5 / 1e6) #
|
|
165
|
+
cost_per_output_token = (15.0 / 1e6) if "anthropic" in self.model_id else (0.6 / 1e6) # (10.0 / 1e6) #
|
|
166
|
+
input_cost = input_tokens * cost_per_input_token
|
|
167
|
+
output_cost = output_tokens * cost_per_output_token
|
|
168
|
+
generation_stats = GenerationStats(
|
|
169
|
+
model_name=self.model_id,
|
|
170
|
+
total_input_tokens=input_tokens,
|
|
171
|
+
total_output_tokens=output_tokens,
|
|
172
|
+
total_input_cost=input_cost,
|
|
173
|
+
total_output_cost=output_cost,
|
|
174
|
+
cost_per_record=input_cost + output_cost,
|
|
175
|
+
llm_call_duration_secs=time.time() - start_time,
|
|
176
|
+
)
|
|
177
|
+
|
|
178
|
+
# update the description of the computed Context to include the result
|
|
179
|
+
new_description = f"RESULT: {response}\n\n"
|
|
180
|
+
cm = ContextManager()
|
|
181
|
+
cm.update_context(id=self.context_id, description=new_description)
|
|
182
|
+
|
|
183
|
+
# create and return record set
|
|
184
|
+
field_answers = {
|
|
185
|
+
"context": cm.get_context(id=self.context_id),
|
|
186
|
+
f"result-{self.context_id}": response,
|
|
187
|
+
}
|
|
188
|
+
record_set = self._create_record_set(
|
|
189
|
+
candidate,
|
|
190
|
+
generation_stats,
|
|
191
|
+
time.time() - start_time,
|
|
192
|
+
field_answers,
|
|
193
|
+
)
|
|
194
|
+
|
|
195
|
+
return record_set
|
|
196
|
+
|
|
197
|
+
# import json; json.dumps(agent.memory.get_full_steps())
|
|
198
|
+
# agent.memory.get_full_steps()[1].keys()
|
|
199
|
+
# dict_keys(['step_number', 'timing', 'model_input_messages', 'tool_calls', 'error', 'model_output_message', 'model_output', 'code_action', 'observations', 'observations_images',
|
|
200
|
+
# 'action_output', 'token_usage', 'is_final_answer'])
|
|
201
|
+
# agent.memory.get_full_steps()[1]['action_output']
|
|
@@ -4,6 +4,8 @@ import time
|
|
|
4
4
|
from abc import ABC, abstractmethod
|
|
5
5
|
from typing import Callable
|
|
6
6
|
|
|
7
|
+
from pydantic.fields import FieldInfo
|
|
8
|
+
|
|
7
9
|
from palimpzest.constants import (
|
|
8
10
|
MODEL_CARDS,
|
|
9
11
|
NAIVE_EST_NUM_INPUT_TOKENS,
|
|
@@ -13,12 +15,10 @@ from palimpzest.constants import (
|
|
|
13
15
|
Model,
|
|
14
16
|
PromptStrategy,
|
|
15
17
|
)
|
|
16
|
-
from palimpzest.core.data.dataclasses import GenerationStats, OperatorCostEstimates, RecordOpStats
|
|
17
18
|
from palimpzest.core.elements.records import DataRecord, DataRecordSet
|
|
18
|
-
from palimpzest.core.
|
|
19
|
-
from palimpzest.query.generators.generators import
|
|
19
|
+
from palimpzest.core.models import GenerationStats, OperatorCostEstimates, RecordOpStats
|
|
20
|
+
from palimpzest.query.generators.generators import Generator
|
|
20
21
|
from palimpzest.query.operators.physical import PhysicalOperator
|
|
21
|
-
from palimpzest.utils.model_helpers import get_vision_models
|
|
22
22
|
|
|
23
23
|
|
|
24
24
|
class ConvertOp(PhysicalOperator, ABC):
|
|
@@ -40,6 +40,7 @@ class ConvertOp(PhysicalOperator, ABC):
|
|
|
40
40
|
id_params = {
|
|
41
41
|
"cardinality": self.cardinality.value,
|
|
42
42
|
"udf": self.udf,
|
|
43
|
+
"desc": self.desc,
|
|
43
44
|
**id_params,
|
|
44
45
|
}
|
|
45
46
|
|
|
@@ -47,7 +48,12 @@ class ConvertOp(PhysicalOperator, ABC):
|
|
|
47
48
|
|
|
48
49
|
def get_op_params(self):
|
|
49
50
|
op_params = super().get_op_params()
|
|
50
|
-
op_params = {
|
|
51
|
+
op_params = {
|
|
52
|
+
"cardinality": self.cardinality,
|
|
53
|
+
"udf": self.udf,
|
|
54
|
+
"desc": self.desc,
|
|
55
|
+
**op_params,
|
|
56
|
+
}
|
|
51
57
|
|
|
52
58
|
return op_params
|
|
53
59
|
|
|
@@ -78,8 +84,8 @@ class ConvertOp(PhysicalOperator, ABC):
|
|
|
78
84
|
setattr(dr, field, getattr(candidate, field))
|
|
79
85
|
|
|
80
86
|
# get input field names and output field names
|
|
81
|
-
input_fields = self.input_schema.
|
|
82
|
-
output_fields = self.output_schema.
|
|
87
|
+
input_fields = list(self.input_schema.model_fields)
|
|
88
|
+
output_fields = list(self.output_schema.model_fields)
|
|
83
89
|
|
|
84
90
|
# parse newly generated fields from the field_answers dictionary for this field; if the list
|
|
85
91
|
# of generated values is shorter than the number of records, we fill in with None
|
|
@@ -112,8 +118,8 @@ class ConvertOp(PhysicalOperator, ABC):
|
|
|
112
118
|
record_op_stats_lst = [
|
|
113
119
|
RecordOpStats(
|
|
114
120
|
record_id=dr.id,
|
|
115
|
-
|
|
116
|
-
|
|
121
|
+
record_parent_ids=dr.parent_ids,
|
|
122
|
+
record_source_indices=dr.source_indices,
|
|
117
123
|
record_state=dr.to_dict(include_bytes=False),
|
|
118
124
|
full_op_id=self.get_full_op_id(),
|
|
119
125
|
logical_op_id=self.logical_op_id,
|
|
@@ -122,7 +128,7 @@ class ConvertOp(PhysicalOperator, ABC):
|
|
|
122
128
|
cost_per_record=per_record_stats.cost_per_record,
|
|
123
129
|
model_name=self.get_model_name(),
|
|
124
130
|
answer={field_name: getattr(dr, field_name) for field_name in field_names},
|
|
125
|
-
input_fields=self.input_schema.
|
|
131
|
+
input_fields=list(self.input_schema.model_fields),
|
|
126
132
|
generated_fields=field_names,
|
|
127
133
|
total_input_tokens=per_record_stats.total_input_tokens,
|
|
128
134
|
total_output_tokens=per_record_stats.total_output_tokens,
|
|
@@ -148,7 +154,7 @@ class ConvertOp(PhysicalOperator, ABC):
|
|
|
148
154
|
pass
|
|
149
155
|
|
|
150
156
|
@abstractmethod
|
|
151
|
-
def convert(self, candidate: DataRecord, fields: dict[str,
|
|
157
|
+
def convert(self, candidate: DataRecord, fields: dict[str, FieldInfo]) -> tuple[dict[str, list], GenerationStats]:
|
|
152
158
|
"""
|
|
153
159
|
This abstract method will be implemented by subclasses of ConvertOp to process the input DataRecord
|
|
154
160
|
and generate the value(s) for each of the specified fields. If the convert operator is a one-to-many
|
|
@@ -182,7 +188,7 @@ class ConvertOp(PhysicalOperator, ABC):
|
|
|
182
188
|
|
|
183
189
|
# execute the convert
|
|
184
190
|
field_answers: dict[str, list]
|
|
185
|
-
fields = {field: field_type for field, field_type in self.output_schema.
|
|
191
|
+
fields = {field: field_type for field, field_type in self.output_schema.model_fields.items() if field in fields_to_generate}
|
|
186
192
|
field_answers, generation_stats = self.convert(candidate=candidate, fields=fields)
|
|
187
193
|
assert all([field in field_answers for field in fields_to_generate]), "Not all fields were generated!"
|
|
188
194
|
|
|
@@ -235,7 +241,7 @@ class NonLLMConvert(ConvertOp):
|
|
|
235
241
|
quality=1.0,
|
|
236
242
|
)
|
|
237
243
|
|
|
238
|
-
def convert(self, candidate: DataRecord, fields: dict[str,
|
|
244
|
+
def convert(self, candidate: DataRecord, fields: dict[str, FieldInfo]) -> tuple[dict[str, list], GenerationStats]:
|
|
239
245
|
# apply UDF to input record
|
|
240
246
|
start_time = time.time()
|
|
241
247
|
field_answers = {}
|
|
@@ -282,18 +288,21 @@ class LLMConvert(ConvertOp):
|
|
|
282
288
|
self,
|
|
283
289
|
model: Model,
|
|
284
290
|
prompt_strategy: PromptStrategy = PromptStrategy.COT_QA,
|
|
291
|
+
reasoning_effort: str | None = None,
|
|
285
292
|
*args,
|
|
286
293
|
**kwargs,
|
|
287
294
|
):
|
|
288
295
|
super().__init__(*args, **kwargs)
|
|
289
296
|
self.model = model
|
|
290
297
|
self.prompt_strategy = prompt_strategy
|
|
298
|
+
self.reasoning_effort = reasoning_effort
|
|
291
299
|
if model is not None:
|
|
292
|
-
self.generator =
|
|
300
|
+
self.generator = Generator(model, prompt_strategy, reasoning_effort, self.api_base, self.cardinality, self.desc, self.verbose)
|
|
293
301
|
|
|
294
302
|
def __str__(self):
|
|
295
303
|
op = super().__str__()
|
|
296
304
|
op += f" Prompt Strategy: {self.prompt_strategy}\n"
|
|
305
|
+
op += f" Reasoning Effort: {self.reasoning_effort}\n"
|
|
297
306
|
return op
|
|
298
307
|
|
|
299
308
|
def get_id_params(self):
|
|
@@ -301,6 +310,7 @@ class LLMConvert(ConvertOp):
|
|
|
301
310
|
id_params = {
|
|
302
311
|
"model": None if self.model is None else self.model.value,
|
|
303
312
|
"prompt_strategy": None if self.prompt_strategy is None else self.prompt_strategy.value,
|
|
313
|
+
"reasoning_effort": self.reasoning_effort,
|
|
304
314
|
**id_params,
|
|
305
315
|
}
|
|
306
316
|
|
|
@@ -311,6 +321,7 @@ class LLMConvert(ConvertOp):
|
|
|
311
321
|
op_params = {
|
|
312
322
|
"model": self.model,
|
|
313
323
|
"prompt_strategy": self.prompt_strategy,
|
|
324
|
+
"reasoning_effort": self.reasoning_effort,
|
|
314
325
|
**op_params,
|
|
315
326
|
}
|
|
316
327
|
|
|
@@ -320,7 +331,7 @@ class LLMConvert(ConvertOp):
|
|
|
320
331
|
return None if self.model is None else self.model.value
|
|
321
332
|
|
|
322
333
|
def is_image_conversion(self) -> bool:
|
|
323
|
-
return self.
|
|
334
|
+
return self.prompt_strategy.is_image_prompt()
|
|
324
335
|
|
|
325
336
|
def naive_cost_estimates(self, source_op_cost_estimates: OperatorCostEstimates) -> OperatorCostEstimates:
|
|
326
337
|
"""
|
|
@@ -334,13 +345,16 @@ class LLMConvert(ConvertOp):
|
|
|
334
345
|
est_num_output_tokens = NAIVE_EST_NUM_OUTPUT_TOKENS
|
|
335
346
|
|
|
336
347
|
# get est. of conversion time per record from model card;
|
|
337
|
-
# NOTE: model will only be None for code synthesis, which uses GPT-3.5 as fallback
|
|
338
348
|
model_name = self.model.value if getattr(self, "model", None) is not None else Model.GPT_4o_MINI.value
|
|
339
349
|
model_conversion_time_per_record = MODEL_CARDS[model_name]["seconds_per_output_token"] * est_num_output_tokens
|
|
340
350
|
|
|
341
351
|
# get est. of conversion cost (in USD) per record from model card
|
|
352
|
+
usd_per_input_token = MODEL_CARDS[model_name].get("usd_per_input_token")
|
|
353
|
+
if getattr(self, "prompt_strategy", None) is not None and self.prompt_strategy.is_audio_prompt():
|
|
354
|
+
usd_per_input_token = MODEL_CARDS[model_name]["usd_per_audio_input_token"]
|
|
355
|
+
|
|
342
356
|
model_conversion_usd_per_record = (
|
|
343
|
-
|
|
357
|
+
usd_per_input_token * est_num_input_tokens
|
|
344
358
|
+ MODEL_CARDS[model_name]["usd_per_output_token"] * est_num_output_tokens
|
|
345
359
|
)
|
|
346
360
|
|
|
@@ -349,7 +363,7 @@ class LLMConvert(ConvertOp):
|
|
|
349
363
|
cardinality = selectivity * source_op_cost_estimates.cardinality
|
|
350
364
|
|
|
351
365
|
# estimate quality of output based on the strength of the model being used
|
|
352
|
-
quality = (MODEL_CARDS[model_name]["overall"] / 100.0)
|
|
366
|
+
quality = (MODEL_CARDS[model_name]["overall"] / 100.0)
|
|
353
367
|
|
|
354
368
|
return OperatorCostEstimates(
|
|
355
369
|
cardinality=cardinality,
|
|
@@ -361,7 +375,7 @@ class LLMConvert(ConvertOp):
|
|
|
361
375
|
|
|
362
376
|
class LLMConvertBonded(LLMConvert):
|
|
363
377
|
|
|
364
|
-
def convert(self, candidate: DataRecord, fields: dict[str,
|
|
378
|
+
def convert(self, candidate: DataRecord, fields: dict[str, FieldInfo]) -> tuple[dict[str, list], GenerationStats]:
|
|
365
379
|
# get the set of input fields to use for the convert operation
|
|
366
380
|
input_fields = self.get_input_fields()
|
|
367
381
|
|
|
@@ -2,10 +2,12 @@ from __future__ import annotations
|
|
|
2
2
|
|
|
3
3
|
from typing import Any
|
|
4
4
|
|
|
5
|
+
from pydantic.fields import FieldInfo
|
|
6
|
+
|
|
5
7
|
from palimpzest.constants import MODEL_CARDS, Model, PromptStrategy
|
|
6
|
-
from palimpzest.core.data.dataclasses import GenerationStats, OperatorCostEstimates
|
|
7
8
|
from palimpzest.core.elements.records import DataRecord
|
|
8
|
-
from palimpzest.
|
|
9
|
+
from palimpzest.core.models import GenerationStats, OperatorCostEstimates
|
|
10
|
+
from palimpzest.query.generators.generators import Generator
|
|
9
11
|
from palimpzest.query.operators.convert import LLMConvert
|
|
10
12
|
|
|
11
13
|
# TYPE DEFINITIONS
|
|
@@ -35,8 +37,8 @@ class CriticAndRefineConvert(LLMConvert):
|
|
|
35
37
|
raise ValueError(f"Unsupported prompt strategy: {self.prompt_strategy}")
|
|
36
38
|
|
|
37
39
|
# create generators
|
|
38
|
-
self.critic_generator =
|
|
39
|
-
self.refine_generator =
|
|
40
|
+
self.critic_generator = Generator(self.critic_model, self.critic_prompt_strategy, self.reasoning_effort, self.api_base, self.cardinality, self.desc, self.verbose)
|
|
41
|
+
self.refine_generator = Generator(self.refine_model, self.refinement_prompt_strategy, self.reasoning_effort, self.api_base, self.cardinality, self.desc, self.verbose)
|
|
40
42
|
|
|
41
43
|
def __str__(self):
|
|
42
44
|
op = super().__str__()
|
|
@@ -86,7 +88,7 @@ class CriticAndRefineConvert(LLMConvert):
|
|
|
86
88
|
|
|
87
89
|
return naive_op_cost_estimates
|
|
88
90
|
|
|
89
|
-
def convert(self, candidate: DataRecord, fields:
|
|
91
|
+
def convert(self, candidate: DataRecord, fields: dict[str, FieldInfo]) -> tuple[dict[FieldName, list[Any]], GenerationStats]:
|
|
90
92
|
# get input fields
|
|
91
93
|
input_fields = self.get_input_fields()
|
|
92
94
|
|