palimpzest 0.7.20__py3-none-any.whl → 0.8.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- palimpzest/__init__.py +37 -6
- palimpzest/agents/__init__.py +0 -0
- palimpzest/agents/compute_agents.py +0 -0
- palimpzest/agents/search_agents.py +637 -0
- palimpzest/constants.py +259 -197
- palimpzest/core/data/context.py +393 -0
- palimpzest/core/data/context_manager.py +163 -0
- palimpzest/core/data/dataset.py +634 -0
- palimpzest/core/data/{datareaders.py → iter_dataset.py} +202 -126
- palimpzest/core/elements/groupbysig.py +16 -13
- palimpzest/core/elements/records.py +166 -75
- palimpzest/core/lib/schemas.py +152 -390
- palimpzest/core/{data/dataclasses.py → models.py} +306 -170
- palimpzest/policy.py +2 -27
- palimpzest/prompts/__init__.py +35 -5
- palimpzest/prompts/agent_prompts.py +357 -0
- palimpzest/prompts/context_search.py +9 -0
- palimpzest/prompts/convert_prompts.py +61 -5
- palimpzest/prompts/filter_prompts.py +50 -5
- palimpzest/prompts/join_prompts.py +163 -0
- palimpzest/prompts/moa_proposer_convert_prompts.py +5 -5
- palimpzest/prompts/prompt_factory.py +358 -46
- palimpzest/prompts/validator.py +239 -0
- palimpzest/query/execution/all_sample_execution_strategy.py +134 -76
- palimpzest/query/execution/execution_strategy.py +210 -317
- palimpzest/query/execution/execution_strategy_type.py +5 -7
- palimpzest/query/execution/mab_execution_strategy.py +249 -136
- palimpzest/query/execution/parallel_execution_strategy.py +153 -244
- palimpzest/query/execution/single_threaded_execution_strategy.py +107 -64
- palimpzest/query/generators/generators.py +157 -330
- palimpzest/query/operators/__init__.py +15 -5
- palimpzest/query/operators/aggregate.py +50 -33
- palimpzest/query/operators/compute.py +201 -0
- palimpzest/query/operators/convert.py +27 -21
- palimpzest/query/operators/critique_and_refine_convert.py +7 -5
- palimpzest/query/operators/distinct.py +62 -0
- palimpzest/query/operators/filter.py +22 -13
- palimpzest/query/operators/join.py +402 -0
- palimpzest/query/operators/limit.py +3 -3
- palimpzest/query/operators/logical.py +198 -80
- palimpzest/query/operators/mixture_of_agents_convert.py +10 -8
- palimpzest/query/operators/physical.py +27 -21
- palimpzest/query/operators/project.py +3 -3
- palimpzest/query/operators/rag_convert.py +7 -7
- palimpzest/query/operators/retrieve.py +9 -9
- palimpzest/query/operators/scan.py +81 -42
- palimpzest/query/operators/search.py +524 -0
- palimpzest/query/operators/split_convert.py +10 -8
- palimpzest/query/optimizer/__init__.py +7 -9
- palimpzest/query/optimizer/cost_model.py +108 -441
- palimpzest/query/optimizer/optimizer.py +123 -181
- palimpzest/query/optimizer/optimizer_strategy.py +66 -61
- palimpzest/query/optimizer/plan.py +352 -67
- palimpzest/query/optimizer/primitives.py +43 -19
- palimpzest/query/optimizer/rules.py +484 -646
- palimpzest/query/optimizer/tasks.py +127 -58
- palimpzest/query/processor/config.py +41 -76
- palimpzest/query/processor/query_processor.py +73 -18
- palimpzest/query/processor/query_processor_factory.py +46 -38
- palimpzest/schemabuilder/schema_builder.py +15 -28
- palimpzest/utils/model_helpers.py +27 -77
- palimpzest/utils/progress.py +114 -102
- palimpzest/validator/__init__.py +0 -0
- palimpzest/validator/validator.py +306 -0
- {palimpzest-0.7.20.dist-info → palimpzest-0.8.0.dist-info}/METADATA +6 -1
- palimpzest-0.8.0.dist-info/RECORD +95 -0
- palimpzest/core/lib/fields.py +0 -141
- palimpzest/prompts/code_synthesis_prompts.py +0 -28
- palimpzest/query/execution/random_sampling_execution_strategy.py +0 -240
- palimpzest/query/generators/api_client_factory.py +0 -30
- palimpzest/query/operators/code_synthesis_convert.py +0 -488
- palimpzest/query/operators/map.py +0 -130
- palimpzest/query/processor/nosentinel_processor.py +0 -33
- palimpzest/query/processor/processing_strategy_type.py +0 -28
- palimpzest/query/processor/sentinel_processor.py +0 -88
- palimpzest/query/processor/streaming_processor.py +0 -149
- palimpzest/sets.py +0 -405
- palimpzest/utils/datareader_helpers.py +0 -61
- palimpzest/utils/demo_helpers.py +0 -75
- palimpzest/utils/field_helpers.py +0 -69
- palimpzest/utils/generation_helpers.py +0 -69
- palimpzest/utils/sandbox.py +0 -183
- palimpzest-0.7.20.dist-info/RECORD +0 -95
- /palimpzest/core/{elements/index.py → data/index_dataset.py} +0 -0
- {palimpzest-0.7.20.dist-info → palimpzest-0.8.0.dist-info}/WHEEL +0 -0
- {palimpzest-0.7.20.dist-info → palimpzest-0.8.0.dist-info}/licenses/LICENSE +0 -0
- {palimpzest-0.7.20.dist-info → palimpzest-0.8.0.dist-info}/top_level.txt +0 -0
|
@@ -6,9 +6,12 @@ from palimpzest.query.operators.convert import ConvertOp as _ConvertOp
|
|
|
6
6
|
from palimpzest.query.operators.convert import LLMConvert as _LLMConvert
|
|
7
7
|
from palimpzest.query.operators.convert import LLMConvertBonded as _LLMConvertBonded
|
|
8
8
|
from palimpzest.query.operators.convert import NonLLMConvert as _NonLLMConvert
|
|
9
|
+
from palimpzest.query.operators.distinct import DistinctOp as _DistinctOp
|
|
9
10
|
from palimpzest.query.operators.filter import FilterOp as _FilterOp
|
|
10
11
|
from palimpzest.query.operators.filter import LLMFilter as _LLMFilter
|
|
11
12
|
from palimpzest.query.operators.filter import NonLLMFilter as _NonLLMFilter
|
|
13
|
+
from palimpzest.query.operators.join import JoinOp as _JoinOp
|
|
14
|
+
from palimpzest.query.operators.join import NestedLoopsJoin as _NestedLoopsJoin
|
|
12
15
|
from palimpzest.query.operators.limit import LimitScanOp as _LimitScanOp
|
|
13
16
|
from palimpzest.query.operators.logical import (
|
|
14
17
|
Aggregate as _Aggregate,
|
|
@@ -17,10 +20,10 @@ from palimpzest.query.operators.logical import (
|
|
|
17
20
|
BaseScan as _BaseScan,
|
|
18
21
|
)
|
|
19
22
|
from palimpzest.query.operators.logical import (
|
|
20
|
-
|
|
23
|
+
ConvertScan as _ConvertScan,
|
|
21
24
|
)
|
|
22
25
|
from palimpzest.query.operators.logical import (
|
|
23
|
-
|
|
26
|
+
Distinct as _Distinct,
|
|
24
27
|
)
|
|
25
28
|
from palimpzest.query.operators.logical import (
|
|
26
29
|
FilteredScan as _FilteredScan,
|
|
@@ -28,6 +31,9 @@ from palimpzest.query.operators.logical import (
|
|
|
28
31
|
from palimpzest.query.operators.logical import (
|
|
29
32
|
GroupByAggregate as _GroupByAggregate,
|
|
30
33
|
)
|
|
34
|
+
from palimpzest.query.operators.logical import (
|
|
35
|
+
JoinOp as _LogicalJoinOp,
|
|
36
|
+
)
|
|
31
37
|
from palimpzest.query.operators.logical import (
|
|
32
38
|
LimitScan as _LimitScan,
|
|
33
39
|
)
|
|
@@ -44,7 +50,6 @@ from palimpzest.query.operators.mixture_of_agents_convert import MixtureOfAgents
|
|
|
44
50
|
from palimpzest.query.operators.physical import PhysicalOperator as _PhysicalOperator
|
|
45
51
|
from palimpzest.query.operators.project import ProjectOp as _ProjectOp
|
|
46
52
|
from palimpzest.query.operators.retrieve import RetrieveOp as _RetrieveOp
|
|
47
|
-
from palimpzest.query.operators.scan import CacheScanDataOp as _CacheScanDataOp
|
|
48
53
|
from palimpzest.query.operators.scan import MarshalAndScanDataOp as _MarshalAndScanDataOp
|
|
49
54
|
from palimpzest.query.operators.scan import ScanPhysicalOp as _ScanPhysicalOp
|
|
50
55
|
|
|
@@ -52,10 +57,11 @@ LOGICAL_OPERATORS = [
|
|
|
52
57
|
_LogicalOperator,
|
|
53
58
|
_Aggregate,
|
|
54
59
|
_BaseScan,
|
|
55
|
-
_CacheScan,
|
|
56
60
|
_ConvertScan,
|
|
61
|
+
_Distinct,
|
|
57
62
|
_FilteredScan,
|
|
58
63
|
_GroupByAggregate,
|
|
64
|
+
_LogicalJoinOp,
|
|
59
65
|
_LimitScan,
|
|
60
66
|
_Project,
|
|
61
67
|
_RetrieveScan,
|
|
@@ -66,10 +72,14 @@ PHYSICAL_OPERATORS = (
|
|
|
66
72
|
[_AggregateOp, _ApplyGroupByOp, _AverageAggregateOp, _CountAggregateOp]
|
|
67
73
|
# convert
|
|
68
74
|
+ [_ConvertOp, _NonLLMConvert, _LLMConvert, _LLMConvertBonded]
|
|
75
|
+
# distinct
|
|
76
|
+
+ [_DistinctOp]
|
|
69
77
|
# scan
|
|
70
|
-
+ [_ScanPhysicalOp, _MarshalAndScanDataOp
|
|
78
|
+
+ [_ScanPhysicalOp, _MarshalAndScanDataOp]
|
|
71
79
|
# filter
|
|
72
80
|
+ [_FilterOp, _NonLLMFilter, _LLMFilter]
|
|
81
|
+
# join
|
|
82
|
+
+ [_JoinOp, _NestedLoopsJoin]
|
|
73
83
|
# limit
|
|
74
84
|
+ [_LimitScanOp]
|
|
75
85
|
# mixture-of-agents
|
|
@@ -3,10 +3,10 @@ from __future__ import annotations
|
|
|
3
3
|
import time
|
|
4
4
|
|
|
5
5
|
from palimpzest.constants import NAIVE_EST_NUM_GROUPS, AggFunc
|
|
6
|
-
from palimpzest.core.data.dataclasses import OperatorCostEstimates, RecordOpStats
|
|
7
6
|
from palimpzest.core.elements.groupbysig import GroupBySig
|
|
8
7
|
from palimpzest.core.elements.records import DataRecord, DataRecordSet
|
|
9
|
-
from palimpzest.core.lib.schemas import
|
|
8
|
+
from palimpzest.core.lib.schemas import Average, Count
|
|
9
|
+
from palimpzest.core.models import OperatorCostEstimates, RecordOpStats
|
|
10
10
|
from palimpzest.query.operators.physical import PhysicalOperator
|
|
11
11
|
|
|
12
12
|
|
|
@@ -16,7 +16,7 @@ class AggregateOp(PhysicalOperator):
|
|
|
16
16
|
__call__ methods. Thus, we use a slightly modified abstract base class for
|
|
17
17
|
these operators.
|
|
18
18
|
"""
|
|
19
|
-
def __call__(self, candidates:
|
|
19
|
+
def __call__(self, candidates: list[DataRecord]) -> DataRecordSet:
|
|
20
20
|
raise NotImplementedError("Using __call__ from abstract method")
|
|
21
21
|
|
|
22
22
|
|
|
@@ -67,6 +67,8 @@ class ApplyGroupByOp(AggregateOp):
|
|
|
67
67
|
return state + 1
|
|
68
68
|
elif func.lower() == "average":
|
|
69
69
|
sum, cnt = state
|
|
70
|
+
if val is None:
|
|
71
|
+
return (sum, cnt)
|
|
70
72
|
return (sum + val, cnt + 1)
|
|
71
73
|
else:
|
|
72
74
|
raise Exception("Unknown agg function " + func)
|
|
@@ -77,11 +79,11 @@ class ApplyGroupByOp(AggregateOp):
|
|
|
77
79
|
return state
|
|
78
80
|
elif func.lower() == "average":
|
|
79
81
|
sum, cnt = state
|
|
80
|
-
return float(sum) / cnt
|
|
82
|
+
return float(sum) / cnt if cnt > 0 else None
|
|
81
83
|
else:
|
|
82
84
|
raise Exception("Unknown agg function " + func)
|
|
83
85
|
|
|
84
|
-
def __call__(self, candidates:
|
|
86
|
+
def __call__(self, candidates: list[DataRecord]) -> DataRecordSet:
|
|
85
87
|
start_time = time.time()
|
|
86
88
|
|
|
87
89
|
# build group array
|
|
@@ -107,17 +109,13 @@ class ApplyGroupByOp(AggregateOp):
|
|
|
107
109
|
agg_state[group] = state
|
|
108
110
|
|
|
109
111
|
# return list of data records (one per group)
|
|
110
|
-
drs = []
|
|
112
|
+
drs: list[DataRecord] = []
|
|
111
113
|
group_by_fields = self.group_by_sig.group_by_fields
|
|
112
114
|
agg_fields = self.group_by_sig.get_agg_field_names()
|
|
113
115
|
for g in agg_state:
|
|
114
|
-
dr = DataRecord
|
|
115
|
-
# NOTE: this will set the parent_id and source_idx to be the id of the final source record;
|
|
116
|
-
# in the near future we may want to have parent_id accept a list of ids
|
|
117
|
-
dr = DataRecord.from_parent(
|
|
116
|
+
dr = DataRecord.from_agg_parents(
|
|
118
117
|
schema=self.group_by_sig.output_schema(),
|
|
119
|
-
|
|
120
|
-
project_cols=[],
|
|
118
|
+
parent_records=candidates,
|
|
121
119
|
)
|
|
122
120
|
for i in range(0, len(g)):
|
|
123
121
|
k = g[i]
|
|
@@ -135,8 +133,8 @@ class ApplyGroupByOp(AggregateOp):
|
|
|
135
133
|
for dr in drs:
|
|
136
134
|
record_op_stats = RecordOpStats(
|
|
137
135
|
record_id=dr.id,
|
|
138
|
-
|
|
139
|
-
|
|
136
|
+
record_parent_ids=dr.parent_ids,
|
|
137
|
+
record_source_indices=dr.source_indices,
|
|
140
138
|
record_state=dr.to_dict(include_bytes=False),
|
|
141
139
|
full_op_id=self.get_full_op_id(),
|
|
142
140
|
logical_op_id=self.logical_op_id,
|
|
@@ -155,13 +153,20 @@ class AverageAggregateOp(AggregateOp):
|
|
|
155
153
|
# NOTE: we don't actually need / use agg_func here (yet)
|
|
156
154
|
|
|
157
155
|
def __init__(self, agg_func: AggFunc, *args, **kwargs):
|
|
158
|
-
|
|
156
|
+
# enforce that output schema is correct
|
|
157
|
+
assert kwargs["output_schema"] == Average, "AverageAggregateOp requires output_schema to be Average"
|
|
158
|
+
|
|
159
|
+
# enforce that input schema is a single numeric field
|
|
160
|
+
input_field_types = list(kwargs["input_schema"].model_fields.values())
|
|
161
|
+
assert len(input_field_types) == 1, "AverageAggregateOp requires input_schema to have exactly one field"
|
|
162
|
+
numeric_field_types = [bool, int, float, bool | None, int | None, float | None, int | float, int | float | None]
|
|
163
|
+
is_numeric = input_field_types[0].annotation in numeric_field_types
|
|
164
|
+
assert is_numeric, f"AverageAggregateOp requires input_schema to have a numeric field type, i.e. one of: {numeric_field_types}\nGot: {input_field_types[0]}"
|
|
165
|
+
|
|
166
|
+
# call parent constructor
|
|
159
167
|
super().__init__(*args, **kwargs)
|
|
160
168
|
self.agg_func = agg_func
|
|
161
169
|
|
|
162
|
-
if not self.input_schema.get_desc() == Number.get_desc():
|
|
163
|
-
raise Exception("Aggregate function AVERAGE is only defined over Numbers")
|
|
164
|
-
|
|
165
170
|
def __str__(self):
|
|
166
171
|
op = super().__str__()
|
|
167
172
|
op += f" Function: {str(self.agg_func)}\n"
|
|
@@ -184,19 +189,29 @@ class AverageAggregateOp(AggregateOp):
|
|
|
184
189
|
quality=1.0,
|
|
185
190
|
)
|
|
186
191
|
|
|
187
|
-
def __call__(self, candidates:
|
|
192
|
+
def __call__(self, candidates: list[DataRecord]) -> DataRecordSet:
|
|
188
193
|
start_time = time.time()
|
|
189
194
|
|
|
190
|
-
# NOTE:
|
|
191
|
-
#
|
|
192
|
-
|
|
193
|
-
|
|
195
|
+
# NOTE: we currently do not guarantee that input values conform to their specified type;
|
|
196
|
+
# as a result, we simply omit any values which do not parse to a float from the average
|
|
197
|
+
# NOTE: right now we perform a check in the constructor which enforces that the input_schema
|
|
198
|
+
# has a single field which is numeric in nature; in the future we may want to have a
|
|
199
|
+
# cleaner way of computing the value (rather than `float(list(candidate...))` below)
|
|
200
|
+
dr = DataRecord.from_agg_parents(schema=Average, parent_records=candidates)
|
|
201
|
+
summation, total = 0, 0
|
|
202
|
+
for candidate in candidates:
|
|
203
|
+
try:
|
|
204
|
+
summation += float(list(candidate.to_dict().values())[0])
|
|
205
|
+
total += 1
|
|
206
|
+
except Exception:
|
|
207
|
+
pass
|
|
208
|
+
dr.average = summation / total
|
|
194
209
|
|
|
195
210
|
# create RecordOpStats object
|
|
196
211
|
record_op_stats = RecordOpStats(
|
|
197
212
|
record_id=dr.id,
|
|
198
|
-
|
|
199
|
-
|
|
213
|
+
record_parent_ids=dr.parent_ids,
|
|
214
|
+
record_source_indices=dr.source_indices,
|
|
200
215
|
record_state=dr.to_dict(include_bytes=False),
|
|
201
216
|
full_op_id=self.get_full_op_id(),
|
|
202
217
|
logical_op_id=self.logical_op_id,
|
|
@@ -212,7 +227,10 @@ class CountAggregateOp(AggregateOp):
|
|
|
212
227
|
# NOTE: we don't actually need / use agg_func here (yet)
|
|
213
228
|
|
|
214
229
|
def __init__(self, agg_func: AggFunc, *args, **kwargs):
|
|
215
|
-
|
|
230
|
+
# enforce that output schema is correct
|
|
231
|
+
assert kwargs["output_schema"] == Count, "CountAggregateOp requires output_schema to be Count"
|
|
232
|
+
|
|
233
|
+
# call parent constructor
|
|
216
234
|
super().__init__(*args, **kwargs)
|
|
217
235
|
self.agg_func = agg_func
|
|
218
236
|
|
|
@@ -238,19 +256,18 @@ class CountAggregateOp(AggregateOp):
|
|
|
238
256
|
quality=1.0,
|
|
239
257
|
)
|
|
240
258
|
|
|
241
|
-
def __call__(self, candidates:
|
|
259
|
+
def __call__(self, candidates: list[DataRecord]) -> DataRecordSet:
|
|
242
260
|
start_time = time.time()
|
|
243
261
|
|
|
244
|
-
#
|
|
245
|
-
|
|
246
|
-
dr =
|
|
247
|
-
dr.value = len(candidates)
|
|
262
|
+
# create new DataRecord
|
|
263
|
+
dr = DataRecord.from_agg_parents(schema=Count, parent_records=candidates)
|
|
264
|
+
dr.count = len(candidates)
|
|
248
265
|
|
|
249
266
|
# create RecordOpStats object
|
|
250
267
|
record_op_stats = RecordOpStats(
|
|
251
268
|
record_id=dr.id,
|
|
252
|
-
|
|
253
|
-
|
|
269
|
+
record_parent_ids=dr.parent_ids,
|
|
270
|
+
record_source_indices=dr.source_indices,
|
|
254
271
|
record_state=dr.to_dict(include_bytes=False),
|
|
255
272
|
full_op_id=self.get_full_op_id(),
|
|
256
273
|
logical_op_id=self.logical_op_id,
|
|
@@ -0,0 +1,201 @@
|
|
|
1
|
+
import functools
|
|
2
|
+
import inspect
|
|
3
|
+
import os
|
|
4
|
+
import time
|
|
5
|
+
from typing import Any
|
|
6
|
+
|
|
7
|
+
from smolagents import CodeAgent, LiteLLMModel, tool
|
|
8
|
+
|
|
9
|
+
from palimpzest.core.data.context import Context
|
|
10
|
+
from palimpzest.core.data.context_manager import ContextManager
|
|
11
|
+
from palimpzest.core.elements.records import DataRecord, DataRecordSet
|
|
12
|
+
from palimpzest.core.models import GenerationStats, OperatorCostEstimates, RecordOpStats
|
|
13
|
+
from palimpzest.query.operators.physical import PhysicalOperator
|
|
14
|
+
|
|
15
|
+
# TODO: need to store final executed code in compute() operator so that humans can debug when human-in-the-loop
|
|
16
|
+
|
|
17
|
+
def make_tool(bound_method):
|
|
18
|
+
# Get the original function and bound instance
|
|
19
|
+
func = bound_method.__func__
|
|
20
|
+
instance = bound_method.__self__
|
|
21
|
+
|
|
22
|
+
# Get the signature and remove 'self'
|
|
23
|
+
sig = inspect.signature(func)
|
|
24
|
+
params = list(sig.parameters.values())[1:] # skip 'self'
|
|
25
|
+
new_sig = inspect.Signature(parameters=params, return_annotation=sig.return_annotation)
|
|
26
|
+
|
|
27
|
+
# Create a wrapper function dynamically
|
|
28
|
+
@functools.wraps(func)
|
|
29
|
+
def wrapper(*args, **kwargs):
|
|
30
|
+
return func(instance, *args, **kwargs)
|
|
31
|
+
|
|
32
|
+
# Update the __signature__ to reflect the new one without 'self'
|
|
33
|
+
wrapper.__signature__ = new_sig
|
|
34
|
+
|
|
35
|
+
return wrapper
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
class SmolAgentsCompute(PhysicalOperator):
|
|
39
|
+
"""
|
|
40
|
+
"""
|
|
41
|
+
def __init__(self, context_id: str, instruction: str, additional_contexts: list[Context] | None = None, *args, **kwargs):
|
|
42
|
+
super().__init__(*args, **kwargs)
|
|
43
|
+
self.context_id = context_id
|
|
44
|
+
self.instruction = instruction
|
|
45
|
+
self.additional_contexts = [] if additional_contexts is None else additional_contexts
|
|
46
|
+
# self.model_id = "anthropic/claude-3-7-sonnet-latest"
|
|
47
|
+
self.model_id = "openai/gpt-4o-mini-2024-07-18"
|
|
48
|
+
# self.model_id = "openai/gpt-4o-2024-08-06"
|
|
49
|
+
api_key = os.getenv("ANTHROPIC_API_KEY") if "anthropic" in self.model_id else os.getenv("OPENAI_API_KEY")
|
|
50
|
+
self.model = LiteLLMModel(model_id=self.model_id, api_key=api_key)
|
|
51
|
+
|
|
52
|
+
def __str__(self):
|
|
53
|
+
op = super().__str__()
|
|
54
|
+
op += f" Context ID: {self.context_id:20s}\n"
|
|
55
|
+
op += f" Instruction: {self.instruction:20s}\n"
|
|
56
|
+
op += f" Add. Ctxs: {self.additional_contexts}\n"
|
|
57
|
+
return op
|
|
58
|
+
|
|
59
|
+
def get_id_params(self):
|
|
60
|
+
id_params = super().get_id_params()
|
|
61
|
+
return {
|
|
62
|
+
"context_id": self.context_id,
|
|
63
|
+
"instruction": self.instruction,
|
|
64
|
+
"additional_contexts": self.additional_contexts,
|
|
65
|
+
**id_params,
|
|
66
|
+
}
|
|
67
|
+
|
|
68
|
+
def get_op_params(self):
|
|
69
|
+
op_params = super().get_op_params()
|
|
70
|
+
return {
|
|
71
|
+
"context_id": self.context_id,
|
|
72
|
+
"instruction": self.instruction,
|
|
73
|
+
"additional_contexts": self.additional_contexts,
|
|
74
|
+
**op_params,
|
|
75
|
+
}
|
|
76
|
+
|
|
77
|
+
def naive_cost_estimates(self, source_op_cost_estimates: OperatorCostEstimates) -> OperatorCostEstimates:
|
|
78
|
+
return OperatorCostEstimates(
|
|
79
|
+
cardinality=source_op_cost_estimates.cardinality,
|
|
80
|
+
time_per_record=100,
|
|
81
|
+
cost_per_record=1,
|
|
82
|
+
quality=1.0,
|
|
83
|
+
)
|
|
84
|
+
|
|
85
|
+
def _create_record_set(
|
|
86
|
+
self,
|
|
87
|
+
candidate: DataRecord,
|
|
88
|
+
generation_stats: GenerationStats,
|
|
89
|
+
total_time: float,
|
|
90
|
+
answer: dict[str, Any],
|
|
91
|
+
) -> DataRecordSet:
|
|
92
|
+
"""
|
|
93
|
+
Given an input DataRecord and a determination of whether it passed the filter or not,
|
|
94
|
+
construct the resulting RecordSet.
|
|
95
|
+
"""
|
|
96
|
+
# create new DataRecord and set passed_operator attribute
|
|
97
|
+
dr = DataRecord.from_parent(self.output_schema, parent_record=candidate)
|
|
98
|
+
for field in self.output_schema.model_fields:
|
|
99
|
+
if field in answer:
|
|
100
|
+
dr[field] = answer[field]
|
|
101
|
+
|
|
102
|
+
# create RecordOpStats object
|
|
103
|
+
record_op_stats = RecordOpStats(
|
|
104
|
+
record_id=dr.id,
|
|
105
|
+
record_parent_ids=dr.parent_ids,
|
|
106
|
+
record_source_indices=dr.source_indices,
|
|
107
|
+
record_state=dr.to_dict(include_bytes=False),
|
|
108
|
+
full_op_id=self.get_full_op_id(),
|
|
109
|
+
logical_op_id=self.logical_op_id,
|
|
110
|
+
op_name=self.op_name(),
|
|
111
|
+
time_per_record=total_time,
|
|
112
|
+
cost_per_record=generation_stats.cost_per_record,
|
|
113
|
+
model_name=self.get_model_name(),
|
|
114
|
+
total_input_tokens=generation_stats.total_input_tokens,
|
|
115
|
+
total_output_tokens=generation_stats.total_output_tokens,
|
|
116
|
+
total_input_cost=generation_stats.total_input_cost,
|
|
117
|
+
total_output_cost=generation_stats.total_output_cost,
|
|
118
|
+
llm_call_duration_secs=generation_stats.llm_call_duration_secs,
|
|
119
|
+
fn_call_duration_secs=generation_stats.fn_call_duration_secs,
|
|
120
|
+
total_llm_calls=generation_stats.total_llm_calls,
|
|
121
|
+
total_embedding_llm_calls=generation_stats.total_embedding_llm_calls,
|
|
122
|
+
answer={k: v.description if isinstance(v, Context) else v for k, v in answer.items()},
|
|
123
|
+
op_details={k: str(v) for k, v in self.get_id_params().items()},
|
|
124
|
+
)
|
|
125
|
+
|
|
126
|
+
return DataRecordSet([dr], [record_op_stats])
|
|
127
|
+
|
|
128
|
+
def __call__(self, candidate: DataRecord) -> Any:
|
|
129
|
+
start_time = time.time()
|
|
130
|
+
|
|
131
|
+
# get the input context object and its tools
|
|
132
|
+
input_context: Context = candidate.context
|
|
133
|
+
description = input_context.description
|
|
134
|
+
tools = [tool(make_tool(f)) for f in input_context.tools]
|
|
135
|
+
|
|
136
|
+
# update the description to include any additional contexts
|
|
137
|
+
for ctx in self.additional_contexts:
|
|
138
|
+
# TODO: remove additional context if it is an ancestor of the input context
|
|
139
|
+
# (not just if it is equal to the input context)
|
|
140
|
+
if ctx.id == input_context.id:
|
|
141
|
+
continue
|
|
142
|
+
description += f"\n\nHere is some additional Context which may be useful:\n\n{ctx.description}"
|
|
143
|
+
|
|
144
|
+
# perform the computation
|
|
145
|
+
instructions = f"\n\nHere is a description of the Context whose data you will be working with, as well as any previously computed results:\n\n{description}"
|
|
146
|
+
agent = CodeAgent(
|
|
147
|
+
tools=tools,
|
|
148
|
+
model=self.model,
|
|
149
|
+
add_base_tools=False,
|
|
150
|
+
instructions=instructions,
|
|
151
|
+
return_full_result=True,
|
|
152
|
+
additional_authorized_imports=["pandas", "io", "os"],
|
|
153
|
+
planning_interval=4,
|
|
154
|
+
max_steps=30,
|
|
155
|
+
)
|
|
156
|
+
result = agent.run(self.instruction)
|
|
157
|
+
# NOTE: you can see the system prompt with `agent.memory.system_prompt.system_prompt`
|
|
158
|
+
# full_steps = agent.memory.get_full_steps()
|
|
159
|
+
|
|
160
|
+
# compute generation stats
|
|
161
|
+
response = result.output
|
|
162
|
+
input_tokens = result.token_usage.input_tokens
|
|
163
|
+
output_tokens = result.token_usage.output_tokens
|
|
164
|
+
cost_per_input_token = (3.0 / 1e6) if "anthropic" in self.model_id else (0.15 / 1e6) # (2.5 / 1e6) #
|
|
165
|
+
cost_per_output_token = (15.0 / 1e6) if "anthropic" in self.model_id else (0.6 / 1e6) # (10.0 / 1e6) #
|
|
166
|
+
input_cost = input_tokens * cost_per_input_token
|
|
167
|
+
output_cost = output_tokens * cost_per_output_token
|
|
168
|
+
generation_stats = GenerationStats(
|
|
169
|
+
model_name=self.model_id,
|
|
170
|
+
total_input_tokens=input_tokens,
|
|
171
|
+
total_output_tokens=output_tokens,
|
|
172
|
+
total_input_cost=input_cost,
|
|
173
|
+
total_output_cost=output_cost,
|
|
174
|
+
cost_per_record=input_cost + output_cost,
|
|
175
|
+
llm_call_duration_secs=time.time() - start_time,
|
|
176
|
+
)
|
|
177
|
+
|
|
178
|
+
# update the description of the computed Context to include the result
|
|
179
|
+
new_description = f"RESULT: {response}\n\n"
|
|
180
|
+
cm = ContextManager()
|
|
181
|
+
cm.update_context(id=self.context_id, description=new_description)
|
|
182
|
+
|
|
183
|
+
# create and return record set
|
|
184
|
+
field_answers = {
|
|
185
|
+
"context": cm.get_context(id=self.context_id),
|
|
186
|
+
f"result-{self.context_id}": response,
|
|
187
|
+
}
|
|
188
|
+
record_set = self._create_record_set(
|
|
189
|
+
candidate,
|
|
190
|
+
generation_stats,
|
|
191
|
+
time.time() - start_time,
|
|
192
|
+
field_answers,
|
|
193
|
+
)
|
|
194
|
+
|
|
195
|
+
return record_set
|
|
196
|
+
|
|
197
|
+
# import json; json.dumps(agent.memory.get_full_steps())
|
|
198
|
+
# agent.memory.get_full_steps()[1].keys()
|
|
199
|
+
# dict_keys(['step_number', 'timing', 'model_input_messages', 'tool_calls', 'error', 'model_output_message', 'model_output', 'code_action', 'observations', 'observations_images',
|
|
200
|
+
# 'action_output', 'token_usage', 'is_final_answer'])
|
|
201
|
+
# agent.memory.get_full_steps()[1]['action_output']
|
|
@@ -4,6 +4,8 @@ import time
|
|
|
4
4
|
from abc import ABC, abstractmethod
|
|
5
5
|
from typing import Callable
|
|
6
6
|
|
|
7
|
+
from pydantic.fields import FieldInfo
|
|
8
|
+
|
|
7
9
|
from palimpzest.constants import (
|
|
8
10
|
MODEL_CARDS,
|
|
9
11
|
NAIVE_EST_NUM_INPUT_TOKENS,
|
|
@@ -13,12 +15,10 @@ from palimpzest.constants import (
|
|
|
13
15
|
Model,
|
|
14
16
|
PromptStrategy,
|
|
15
17
|
)
|
|
16
|
-
from palimpzest.core.data.dataclasses import GenerationStats, OperatorCostEstimates, RecordOpStats
|
|
17
18
|
from palimpzest.core.elements.records import DataRecord, DataRecordSet
|
|
18
|
-
from palimpzest.core.
|
|
19
|
-
from palimpzest.query.generators.generators import
|
|
19
|
+
from palimpzest.core.models import GenerationStats, OperatorCostEstimates, RecordOpStats
|
|
20
|
+
from palimpzest.query.generators.generators import Generator
|
|
20
21
|
from palimpzest.query.operators.physical import PhysicalOperator
|
|
21
|
-
from palimpzest.utils.model_helpers import get_vision_models
|
|
22
22
|
|
|
23
23
|
|
|
24
24
|
class ConvertOp(PhysicalOperator, ABC):
|
|
@@ -26,14 +26,12 @@ class ConvertOp(PhysicalOperator, ABC):
|
|
|
26
26
|
self,
|
|
27
27
|
cardinality: Cardinality = Cardinality.ONE_TO_ONE,
|
|
28
28
|
udf: Callable | None = None,
|
|
29
|
-
desc: str | None = None,
|
|
30
29
|
*args,
|
|
31
30
|
**kwargs,
|
|
32
31
|
):
|
|
33
32
|
super().__init__(*args, **kwargs)
|
|
34
33
|
self.cardinality = cardinality
|
|
35
34
|
self.udf = udf
|
|
36
|
-
self.desc = desc
|
|
37
35
|
|
|
38
36
|
def get_id_params(self):
|
|
39
37
|
id_params = super().get_id_params()
|
|
@@ -47,7 +45,7 @@ class ConvertOp(PhysicalOperator, ABC):
|
|
|
47
45
|
|
|
48
46
|
def get_op_params(self):
|
|
49
47
|
op_params = super().get_op_params()
|
|
50
|
-
op_params = {"cardinality": self.cardinality, "udf": self.udf,
|
|
48
|
+
op_params = {"cardinality": self.cardinality, "udf": self.udf, **op_params}
|
|
51
49
|
|
|
52
50
|
return op_params
|
|
53
51
|
|
|
@@ -78,8 +76,8 @@ class ConvertOp(PhysicalOperator, ABC):
|
|
|
78
76
|
setattr(dr, field, getattr(candidate, field))
|
|
79
77
|
|
|
80
78
|
# get input field names and output field names
|
|
81
|
-
input_fields = self.input_schema.
|
|
82
|
-
output_fields = self.output_schema.
|
|
79
|
+
input_fields = list(self.input_schema.model_fields)
|
|
80
|
+
output_fields = list(self.output_schema.model_fields)
|
|
83
81
|
|
|
84
82
|
# parse newly generated fields from the field_answers dictionary for this field; if the list
|
|
85
83
|
# of generated values is shorter than the number of records, we fill in with None
|
|
@@ -112,8 +110,8 @@ class ConvertOp(PhysicalOperator, ABC):
|
|
|
112
110
|
record_op_stats_lst = [
|
|
113
111
|
RecordOpStats(
|
|
114
112
|
record_id=dr.id,
|
|
115
|
-
|
|
116
|
-
|
|
113
|
+
record_parent_ids=dr.parent_ids,
|
|
114
|
+
record_source_indices=dr.source_indices,
|
|
117
115
|
record_state=dr.to_dict(include_bytes=False),
|
|
118
116
|
full_op_id=self.get_full_op_id(),
|
|
119
117
|
logical_op_id=self.logical_op_id,
|
|
@@ -122,7 +120,7 @@ class ConvertOp(PhysicalOperator, ABC):
|
|
|
122
120
|
cost_per_record=per_record_stats.cost_per_record,
|
|
123
121
|
model_name=self.get_model_name(),
|
|
124
122
|
answer={field_name: getattr(dr, field_name) for field_name in field_names},
|
|
125
|
-
input_fields=self.input_schema.
|
|
123
|
+
input_fields=list(self.input_schema.model_fields),
|
|
126
124
|
generated_fields=field_names,
|
|
127
125
|
total_input_tokens=per_record_stats.total_input_tokens,
|
|
128
126
|
total_output_tokens=per_record_stats.total_output_tokens,
|
|
@@ -148,7 +146,7 @@ class ConvertOp(PhysicalOperator, ABC):
|
|
|
148
146
|
pass
|
|
149
147
|
|
|
150
148
|
@abstractmethod
|
|
151
|
-
def convert(self, candidate: DataRecord, fields: dict[str,
|
|
149
|
+
def convert(self, candidate: DataRecord, fields: dict[str, FieldInfo]) -> tuple[dict[str, list], GenerationStats]:
|
|
152
150
|
"""
|
|
153
151
|
This abstract method will be implemented by subclasses of ConvertOp to process the input DataRecord
|
|
154
152
|
and generate the value(s) for each of the specified fields. If the convert operator is a one-to-many
|
|
@@ -182,7 +180,7 @@ class ConvertOp(PhysicalOperator, ABC):
|
|
|
182
180
|
|
|
183
181
|
# execute the convert
|
|
184
182
|
field_answers: dict[str, list]
|
|
185
|
-
fields = {field: field_type for field, field_type in self.output_schema.
|
|
183
|
+
fields = {field: field_type for field, field_type in self.output_schema.model_fields.items() if field in fields_to_generate}
|
|
186
184
|
field_answers, generation_stats = self.convert(candidate=candidate, fields=fields)
|
|
187
185
|
assert all([field in field_answers for field in fields_to_generate]), "Not all fields were generated!"
|
|
188
186
|
|
|
@@ -235,7 +233,7 @@ class NonLLMConvert(ConvertOp):
|
|
|
235
233
|
quality=1.0,
|
|
236
234
|
)
|
|
237
235
|
|
|
238
|
-
def convert(self, candidate: DataRecord, fields: dict[str,
|
|
236
|
+
def convert(self, candidate: DataRecord, fields: dict[str, FieldInfo]) -> tuple[dict[str, list], GenerationStats]:
|
|
239
237
|
# apply UDF to input record
|
|
240
238
|
start_time = time.time()
|
|
241
239
|
field_answers = {}
|
|
@@ -282,18 +280,21 @@ class LLMConvert(ConvertOp):
|
|
|
282
280
|
self,
|
|
283
281
|
model: Model,
|
|
284
282
|
prompt_strategy: PromptStrategy = PromptStrategy.COT_QA,
|
|
283
|
+
reasoning_effort: str | None = None,
|
|
285
284
|
*args,
|
|
286
285
|
**kwargs,
|
|
287
286
|
):
|
|
288
287
|
super().__init__(*args, **kwargs)
|
|
289
288
|
self.model = model
|
|
290
289
|
self.prompt_strategy = prompt_strategy
|
|
290
|
+
self.reasoning_effort = reasoning_effort
|
|
291
291
|
if model is not None:
|
|
292
|
-
self.generator =
|
|
292
|
+
self.generator = Generator(model, prompt_strategy, reasoning_effort, self.api_base, self.cardinality, self.verbose)
|
|
293
293
|
|
|
294
294
|
def __str__(self):
|
|
295
295
|
op = super().__str__()
|
|
296
296
|
op += f" Prompt Strategy: {self.prompt_strategy}\n"
|
|
297
|
+
op += f" Reasoning Effort: {self.reasoning_effort}\n"
|
|
297
298
|
return op
|
|
298
299
|
|
|
299
300
|
def get_id_params(self):
|
|
@@ -301,6 +302,7 @@ class LLMConvert(ConvertOp):
|
|
|
301
302
|
id_params = {
|
|
302
303
|
"model": None if self.model is None else self.model.value,
|
|
303
304
|
"prompt_strategy": None if self.prompt_strategy is None else self.prompt_strategy.value,
|
|
305
|
+
"reasoning_effort": self.reasoning_effort,
|
|
304
306
|
**id_params,
|
|
305
307
|
}
|
|
306
308
|
|
|
@@ -311,6 +313,7 @@ class LLMConvert(ConvertOp):
|
|
|
311
313
|
op_params = {
|
|
312
314
|
"model": self.model,
|
|
313
315
|
"prompt_strategy": self.prompt_strategy,
|
|
316
|
+
"reasoning_effort": self.reasoning_effort,
|
|
314
317
|
**op_params,
|
|
315
318
|
}
|
|
316
319
|
|
|
@@ -320,7 +323,7 @@ class LLMConvert(ConvertOp):
|
|
|
320
323
|
return None if self.model is None else self.model.value
|
|
321
324
|
|
|
322
325
|
def is_image_conversion(self) -> bool:
|
|
323
|
-
return self.
|
|
326
|
+
return self.prompt_strategy.is_image_prompt()
|
|
324
327
|
|
|
325
328
|
def naive_cost_estimates(self, source_op_cost_estimates: OperatorCostEstimates) -> OperatorCostEstimates:
|
|
326
329
|
"""
|
|
@@ -334,13 +337,16 @@ class LLMConvert(ConvertOp):
|
|
|
334
337
|
est_num_output_tokens = NAIVE_EST_NUM_OUTPUT_TOKENS
|
|
335
338
|
|
|
336
339
|
# get est. of conversion time per record from model card;
|
|
337
|
-
# NOTE: model will only be None for code synthesis, which uses GPT-3.5 as fallback
|
|
338
340
|
model_name = self.model.value if getattr(self, "model", None) is not None else Model.GPT_4o_MINI.value
|
|
339
341
|
model_conversion_time_per_record = MODEL_CARDS[model_name]["seconds_per_output_token"] * est_num_output_tokens
|
|
340
342
|
|
|
341
343
|
# get est. of conversion cost (in USD) per record from model card
|
|
344
|
+
usd_per_input_token = MODEL_CARDS[model_name].get("usd_per_input_token")
|
|
345
|
+
if getattr(self, "prompt_strategy", None) is not None and self.prompt_strategy.is_audio_prompt():
|
|
346
|
+
usd_per_input_token = MODEL_CARDS[model_name]["usd_per_audio_input_token"]
|
|
347
|
+
|
|
342
348
|
model_conversion_usd_per_record = (
|
|
343
|
-
|
|
349
|
+
usd_per_input_token * est_num_input_tokens
|
|
344
350
|
+ MODEL_CARDS[model_name]["usd_per_output_token"] * est_num_output_tokens
|
|
345
351
|
)
|
|
346
352
|
|
|
@@ -349,7 +355,7 @@ class LLMConvert(ConvertOp):
|
|
|
349
355
|
cardinality = selectivity * source_op_cost_estimates.cardinality
|
|
350
356
|
|
|
351
357
|
# estimate quality of output based on the strength of the model being used
|
|
352
|
-
quality = (MODEL_CARDS[model_name]["overall"] / 100.0)
|
|
358
|
+
quality = (MODEL_CARDS[model_name]["overall"] / 100.0)
|
|
353
359
|
|
|
354
360
|
return OperatorCostEstimates(
|
|
355
361
|
cardinality=cardinality,
|
|
@@ -361,7 +367,7 @@ class LLMConvert(ConvertOp):
|
|
|
361
367
|
|
|
362
368
|
class LLMConvertBonded(LLMConvert):
|
|
363
369
|
|
|
364
|
-
def convert(self, candidate: DataRecord, fields: dict[str,
|
|
370
|
+
def convert(self, candidate: DataRecord, fields: dict[str, FieldInfo]) -> tuple[dict[str, list], GenerationStats]:
|
|
365
371
|
# get the set of input fields to use for the convert operation
|
|
366
372
|
input_fields = self.get_input_fields()
|
|
367
373
|
|
|
@@ -2,10 +2,12 @@ from __future__ import annotations
|
|
|
2
2
|
|
|
3
3
|
from typing import Any
|
|
4
4
|
|
|
5
|
+
from pydantic.fields import FieldInfo
|
|
6
|
+
|
|
5
7
|
from palimpzest.constants import MODEL_CARDS, Model, PromptStrategy
|
|
6
|
-
from palimpzest.core.data.dataclasses import GenerationStats, OperatorCostEstimates
|
|
7
8
|
from palimpzest.core.elements.records import DataRecord
|
|
8
|
-
from palimpzest.
|
|
9
|
+
from palimpzest.core.models import GenerationStats, OperatorCostEstimates
|
|
10
|
+
from palimpzest.query.generators.generators import Generator
|
|
9
11
|
from palimpzest.query.operators.convert import LLMConvert
|
|
10
12
|
|
|
11
13
|
# TYPE DEFINITIONS
|
|
@@ -35,8 +37,8 @@ class CriticAndRefineConvert(LLMConvert):
|
|
|
35
37
|
raise ValueError(f"Unsupported prompt strategy: {self.prompt_strategy}")
|
|
36
38
|
|
|
37
39
|
# create generators
|
|
38
|
-
self.critic_generator =
|
|
39
|
-
self.refine_generator =
|
|
40
|
+
self.critic_generator = Generator(self.critic_model, self.critic_prompt_strategy, self.reasoning_effort, self.api_base, self.cardinality, self.verbose)
|
|
41
|
+
self.refine_generator = Generator(self.refine_model, self.refinement_prompt_strategy, self.reasoning_effort, self.api_base, self.cardinality, self.verbose)
|
|
40
42
|
|
|
41
43
|
def __str__(self):
|
|
42
44
|
op = super().__str__()
|
|
@@ -86,7 +88,7 @@ class CriticAndRefineConvert(LLMConvert):
|
|
|
86
88
|
|
|
87
89
|
return naive_op_cost_estimates
|
|
88
90
|
|
|
89
|
-
def convert(self, candidate: DataRecord, fields:
|
|
91
|
+
def convert(self, candidate: DataRecord, fields: dict[str, FieldInfo]) -> tuple[dict[FieldName, list[Any]], GenerationStats]:
|
|
90
92
|
# get input fields
|
|
91
93
|
input_fields = self.get_input_fields()
|
|
92
94
|
|