palimpzest 0.6.4__py3-none-any.whl → 0.7.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- palimpzest/__init__.py +5 -0
- palimpzest/constants.py +110 -43
- palimpzest/core/__init__.py +0 -78
- palimpzest/core/data/dataclasses.py +382 -44
- palimpzest/core/elements/filters.py +7 -3
- palimpzest/core/elements/index.py +70 -0
- palimpzest/core/elements/records.py +33 -11
- palimpzest/core/lib/fields.py +1 -0
- palimpzest/core/lib/schemas.py +4 -3
- palimpzest/prompts/moa_proposer_convert_prompts.py +0 -4
- palimpzest/prompts/prompt_factory.py +44 -7
- palimpzest/prompts/split_merge_prompts.py +56 -0
- palimpzest/prompts/split_proposer_prompts.py +55 -0
- palimpzest/query/execution/execution_strategy.py +435 -53
- palimpzest/query/execution/execution_strategy_type.py +20 -0
- palimpzest/query/execution/mab_execution_strategy.py +532 -0
- palimpzest/query/execution/parallel_execution_strategy.py +143 -172
- palimpzest/query/execution/random_sampling_execution_strategy.py +240 -0
- palimpzest/query/execution/single_threaded_execution_strategy.py +173 -203
- palimpzest/query/generators/api_client_factory.py +31 -0
- palimpzest/query/generators/generators.py +256 -76
- palimpzest/query/operators/__init__.py +1 -2
- palimpzest/query/operators/code_synthesis_convert.py +33 -18
- palimpzest/query/operators/convert.py +30 -97
- palimpzest/query/operators/critique_and_refine_convert.py +5 -6
- palimpzest/query/operators/filter.py +7 -10
- palimpzest/query/operators/logical.py +54 -10
- palimpzest/query/operators/map.py +130 -0
- palimpzest/query/operators/mixture_of_agents_convert.py +6 -6
- palimpzest/query/operators/physical.py +3 -12
- palimpzest/query/operators/rag_convert.py +66 -18
- palimpzest/query/operators/retrieve.py +230 -34
- palimpzest/query/operators/scan.py +5 -2
- palimpzest/query/operators/split_convert.py +169 -0
- palimpzest/query/operators/token_reduction_convert.py +8 -14
- palimpzest/query/optimizer/__init__.py +4 -16
- palimpzest/query/optimizer/cost_model.py +73 -266
- palimpzest/query/optimizer/optimizer.py +87 -58
- palimpzest/query/optimizer/optimizer_strategy.py +18 -97
- palimpzest/query/optimizer/optimizer_strategy_type.py +37 -0
- palimpzest/query/optimizer/plan.py +2 -3
- palimpzest/query/optimizer/primitives.py +5 -3
- palimpzest/query/optimizer/rules.py +336 -172
- palimpzest/query/optimizer/tasks.py +30 -100
- palimpzest/query/processor/config.py +38 -22
- palimpzest/query/processor/nosentinel_processor.py +16 -520
- palimpzest/query/processor/processing_strategy_type.py +28 -0
- palimpzest/query/processor/query_processor.py +38 -206
- palimpzest/query/processor/query_processor_factory.py +117 -130
- palimpzest/query/processor/sentinel_processor.py +90 -0
- palimpzest/query/processor/streaming_processor.py +25 -32
- palimpzest/sets.py +88 -41
- palimpzest/utils/model_helpers.py +8 -7
- palimpzest/utils/progress.py +368 -152
- palimpzest/utils/token_reduction_helpers.py +1 -3
- {palimpzest-0.6.4.dist-info → palimpzest-0.7.1.dist-info}/METADATA +19 -9
- palimpzest-0.7.1.dist-info/RECORD +96 -0
- {palimpzest-0.6.4.dist-info → palimpzest-0.7.1.dist-info}/WHEEL +1 -1
- palimpzest/query/processor/mab_sentinel_processor.py +0 -884
- palimpzest/query/processor/random_sampling_sentinel_processor.py +0 -639
- palimpzest/utils/index_helpers.py +0 -6
- palimpzest-0.6.4.dist-info/RECORD +0 -87
- {palimpzest-0.6.4.dist-info → palimpzest-0.7.1.dist-info/licenses}/LICENSE +0 -0
- {palimpzest-0.6.4.dist-info → palimpzest-0.7.1.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,130 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import time
|
|
4
|
+
from typing import Callable
|
|
5
|
+
|
|
6
|
+
from palimpzest.core.data.dataclasses import GenerationStats, OperatorCostEstimates, RecordOpStats
|
|
7
|
+
from palimpzest.core.elements.records import DataRecord, DataRecordSet
|
|
8
|
+
from palimpzest.core.lib.fields import Field
|
|
9
|
+
from palimpzest.query.operators.physical import PhysicalOperator
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class MapOp(PhysicalOperator):
    """
    Physical operator which applies a user-defined function (UDF) to each input
    DataRecord in order to produce the values of the output schema's fields.

    The UDF receives the candidate record as a plain dict and must return a dict
    mapping output field names to their computed values.
    """

    def __init__(self, udf: Callable | None = None, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # user-provided function: dict(record) -> dict(field name -> value)
        self.udf = udf

    def __str__(self):
        op = super().__str__()
        op += f" UDF: {self.udf.__name__}\n"
        return op

    def get_id_params(self):
        # include the UDF so two MapOps with different functions get distinct op ids
        id_params = super().get_id_params()
        return {"udf": self.udf, **id_params}

    def get_op_params(self):
        op_params = super().get_op_params()
        return {"udf": self.udf, **op_params}

    def _create_record_set(
        self,
        record: DataRecord,
        generation_stats: GenerationStats,
        total_time: float,
    ) -> DataRecordSet:
        """
        Given the output DataRecord produced by the map UDF, construct the resulting
        DataRecordSet containing that single record and its RecordOpStats.

        NOTE: the original docstring was copy-pasted from the filter operator; a map
        produces exactly one output record and has no pass/fail determination.
        """
        # create RecordOpStats object; UDF maps make no LLM calls, so cost_per_record is 0.0
        record_op_stats = RecordOpStats(
            record_id=record.id,
            record_parent_id=record.parent_id,
            record_source_idx=record.source_idx,
            record_state=record.to_dict(include_bytes=False),
            op_id=self.get_op_id(),
            logical_op_id=self.logical_op_id,
            op_name=self.op_name(),
            time_per_record=total_time,
            cost_per_record=0.0,
            fn_call_duration_secs=generation_stats.fn_call_duration_secs,
            answer=record.to_dict(include_bytes=False),
            op_details={k: str(v) for k, v in self.get_id_params().items()},
        )

        return DataRecordSet([record], [record_op_stats])

    def naive_cost_estimates(self, source_op_cost_estimates: OperatorCostEstimates) -> OperatorCostEstimates:
        """
        Compute naive cost estimates for the Map operation. These estimates assume that the map UDF
        (1) has no cost and (2) has perfect quality.
        """
        # estimate 1 ms single-threaded execution for udf function
        time_per_record = 0.001

        # assume the map fn has perfect quality (original comment mistakenly said "filter fn")
        return OperatorCostEstimates(
            cardinality=source_op_cost_estimates.cardinality,
            time_per_record=time_per_record,
            cost_per_record=0.0,
            quality=1.0,
        )

    def map(self, candidate: DataRecord, fields: dict[str, Field]) -> tuple[dict[str, list], GenerationStats]:
        """
        Apply the UDF to the candidate record.

        Returns (field_answers, generation_stats), where field_answers is the dict
        returned by the UDF and generation_stats records only the UDF's wall-clock time.
        The `fields` argument is unused here but kept for signature parity with convert().
        """
        # apply UDF to input record
        start_time = time.time()
        field_answers = {}
        try:
            # execute the UDF function
            field_answers = self.udf(candidate.to_dict())

            # answer should be a dictionary
            assert isinstance(field_answers, dict), (
                "UDF must return a dictionary mapping each input field to its value for map operations"
            )

            if self.verbose:
                print(f"{self.udf.__name__}")

        except Exception as e:
            print(f"Error invoking user-defined function for map: {e}")
            # bare raise preserves the original traceback (no re-raise chaining noise)
            raise

        # create generation stats object containing the time spent executing the UDF function
        generation_stats = GenerationStats(fn_call_duration_secs=time.time() - start_time)

        return field_answers, generation_stats

    def __call__(self, candidate: DataRecord) -> DataRecordSet:
        """
        This method converts an input DataRecord into an output DataRecordSet. The output DataRecordSet contains the
        DataRecord(s) output by the operator's map() method and their corresponding RecordOpStats objects.
        Some subclasses may override this __call__ method to implement their own custom logic.
        """
        start_time = time.time()

        # execute the map operation
        field_answers: dict[str, list]
        fields = dict(self.output_schema.field_map())
        field_answers, generation_stats = self.map(candidate=candidate, fields=fields)
        assert all(field in field_answers for field in fields), "Not all fields are present in output of map!"

        # construct DataRecord from field_answers
        dr = DataRecord.from_parent(schema=self.output_schema, parent_record=candidate)
        for field_name, field_value in field_answers.items():
            dr[field_name] = field_value

        # construct and return DataRecordSet
        record_set = self._create_record_set(
            record=dr,
            generation_stats=generation_stats,
            total_time=time.time() - start_time,
        )

        return record_set
|
|
@@ -1,10 +1,9 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
-
from typing import Any
|
|
4
|
-
|
|
5
3
|
from palimpzest.constants import MODEL_CARDS, Model, PromptStrategy
|
|
6
4
|
from palimpzest.core.data.dataclasses import GenerationStats, OperatorCostEstimates
|
|
7
5
|
from palimpzest.core.elements.records import DataRecord
|
|
6
|
+
from palimpzest.core.lib.fields import Field
|
|
8
7
|
from palimpzest.query.generators.generators import generator_factory
|
|
9
8
|
from palimpzest.query.operators.convert import LLMConvert
|
|
10
9
|
|
|
@@ -112,7 +111,7 @@ class MixtureOfAgentsConvert(LLMConvert):
|
|
|
112
111
|
|
|
113
112
|
return naive_op_cost_estimates
|
|
114
113
|
|
|
115
|
-
def convert(self, candidate: DataRecord, fields:
|
|
114
|
+
def convert(self, candidate: DataRecord, fields: dict[str, Field]) -> tuple[dict[str, list], GenerationStats]:
|
|
116
115
|
# get input fields
|
|
117
116
|
input_fields = self.get_input_fields()
|
|
118
117
|
|
|
@@ -120,8 +119,9 @@ class MixtureOfAgentsConvert(LLMConvert):
|
|
|
120
119
|
proposer_model_final_answers, proposer_model_generation_stats = [], []
|
|
121
120
|
for proposer_generator, temperature in zip(self.proposer_generators, self.temperatures):
|
|
122
121
|
gen_kwargs = {"project_cols": input_fields, "output_schema": self.output_schema, "temperature": temperature}
|
|
123
|
-
|
|
124
|
-
|
|
122
|
+
_, reasoning, generation_stats, _ = proposer_generator(candidate, fields, json_output=False, **gen_kwargs)
|
|
123
|
+
proposer_text = f"REASONING:{reasoning}\n"
|
|
124
|
+
proposer_model_final_answers.append(proposer_text)
|
|
125
125
|
proposer_model_generation_stats.append(generation_stats)
|
|
126
126
|
|
|
127
127
|
# call the aggregator
|
|
@@ -130,7 +130,7 @@ class MixtureOfAgentsConvert(LLMConvert):
|
|
|
130
130
|
"output_schema": self.output_schema,
|
|
131
131
|
"model_responses": proposer_model_final_answers,
|
|
132
132
|
}
|
|
133
|
-
field_answers, _, aggregator_gen_stats = self.aggregator_generator(candidate, fields, **gen_kwargs)
|
|
133
|
+
field_answers, _, aggregator_gen_stats, _ = self.aggregator_generator(candidate, fields, **gen_kwargs)
|
|
134
134
|
|
|
135
135
|
# compute the total generation stats
|
|
136
136
|
generation_stats = sum(proposer_model_generation_stats) + aggregator_gen_stats
|
|
@@ -125,6 +125,9 @@ class PhysicalOperator:
|
|
|
125
125
|
self.op_id = hash_for_id(hash_str)
|
|
126
126
|
|
|
127
127
|
return self.op_id
|
|
128
|
+
|
|
129
|
+
def get_logical_op_id(self) -> str | None:
    """Return the id of the logical operator this physical operator implements (None if unset)."""
    return self.logical_op_id
|
|
128
131
|
|
|
129
132
|
def __hash__(self):
|
|
130
133
|
return int(self.op_id, 16)
|
|
@@ -187,15 +190,3 @@ class PhysicalOperator:
|
|
|
187
190
|
|
|
188
191
|
def __call__(self, candidate: DataRecord) -> DataRecordSet:
|
|
189
192
|
raise NotImplementedError("Calling __call__ from abstract method")
|
|
190
|
-
|
|
191
|
-
@staticmethod
|
|
192
|
-
def execute_op_wrapper(operator: PhysicalOperator, op_input: DataRecord | list[DataRecord] | int) -> tuple[DataRecordSet, PhysicalOperator]:
|
|
193
|
-
"""
|
|
194
|
-
Wrapper function around operator execution which also and returns the operator.
|
|
195
|
-
This is useful in the parallel setting(s) where operators are executed by a worker pool,
|
|
196
|
-
and it is convenient to return the op_id along with the computation result.
|
|
197
|
-
"""
|
|
198
|
-
record_set = operator(op_input)
|
|
199
|
-
|
|
200
|
-
return record_set, operator, op_input
|
|
201
|
-
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
-
|
|
3
|
+
import time
|
|
4
4
|
|
|
5
5
|
from numpy import dot
|
|
6
6
|
from numpy.linalg import norm
|
|
@@ -9,11 +9,12 @@ from openai import OpenAI
|
|
|
9
9
|
from palimpzest.constants import (
|
|
10
10
|
MODEL_CARDS,
|
|
11
11
|
NAIVE_EST_NUM_OUTPUT_TOKENS,
|
|
12
|
+
Model,
|
|
12
13
|
)
|
|
13
14
|
from palimpzest.core.data.dataclasses import GenerationStats, OperatorCostEstimates
|
|
14
15
|
from palimpzest.core.elements.records import DataRecord
|
|
15
|
-
from palimpzest.core.lib.fields import StringField
|
|
16
|
-
from palimpzest.query.operators.convert import
|
|
16
|
+
from palimpzest.core.lib.fields import Field, StringField
|
|
17
|
+
from palimpzest.query.operators.convert import LLMConvert
|
|
17
18
|
|
|
18
19
|
|
|
19
20
|
class RAGConvert(LLMConvert):
|
|
@@ -21,7 +22,7 @@ class RAGConvert(LLMConvert):
|
|
|
21
22
|
super().__init__(*args, **kwargs)
|
|
22
23
|
# NOTE: in the future, we should abstract the embedding model to allow for different models
|
|
23
24
|
self.client = None
|
|
24
|
-
self.embedding_model =
|
|
25
|
+
self.embedding_model = Model.TEXT_EMBEDDING_3_SMALL
|
|
25
26
|
self.num_chunks_per_field = num_chunks_per_field
|
|
26
27
|
self.chunk_size = chunk_size
|
|
27
28
|
|
|
@@ -93,13 +94,39 @@ class RAGConvert(LLMConvert):
|
|
|
93
94
|
|
|
94
95
|
return chunks
|
|
95
96
|
|
|
96
|
-
def compute_embedding(self, text: str) -> list[float]:
|
|
97
|
+
def compute_embedding(self, text: str) -> tuple[list[float], GenerationStats]:
    """
    Embed a text string with the configured embedding model.

    Returns a tuple of (embedding vector, GenerationStats) where the stats
    object captures the latency and dollar cost of the embedding request.
    """
    embed_model = self.embedding_model.value

    # time the embedding request so its latency is attributed to this operator
    t0 = time.time()
    response = self.client.embeddings.create(input=text, model=embed_model)
    elapsed = time.time() - t0

    vector = response.data[0].embedding

    # price the call; embedding requests only consume input tokens
    input_tokens = response.usage.total_tokens
    input_cost = MODEL_CARDS[embed_model]["usd_per_input_token"] * input_tokens
    stats = GenerationStats(
        model_name=embed_model,  # NOTE: this should be overwritten by generation model in convert()
        total_input_tokens=input_tokens,
        total_output_tokens=0.0,
        total_input_cost=input_cost,
        total_output_cost=0.0,
        cost_per_record=input_cost,
        llm_call_duration_secs=elapsed,
        total_llm_calls=1,
        total_embedding_llm_calls=1,
    )

    return vector, stats
|
|
103
130
|
|
|
104
131
|
def compute_similarity(self, query_embedding: list[float], chunk_embedding: list[float]) -> float:
|
|
105
132
|
"""
|
|
@@ -107,18 +134,25 @@ class RAGConvert(LLMConvert):
|
|
|
107
134
|
"""
|
|
108
135
|
return dot(query_embedding, chunk_embedding) / (norm(query_embedding) * norm(chunk_embedding))
|
|
109
136
|
|
|
110
|
-
def get_chunked_candidate(self, candidate: DataRecord, input_fields: list[str], output_fields: list[str]) -> DataRecord:
|
|
137
|
+
def get_chunked_candidate(self, candidate: DataRecord, input_fields: list[str], output_fields: list[str]) -> tuple[DataRecord, GenerationStats]:
|
|
111
138
|
"""
|
|
112
139
|
For each text field, chunk the content and compute the chunk embeddings. Then select the top-k chunks
|
|
113
140
|
for each field. If a field is smaller than the chunk size, simply include the full field.
|
|
114
141
|
"""
|
|
142
|
+
# initialize stats for embedding costs
|
|
143
|
+
embed_stats = GenerationStats()
|
|
144
|
+
|
|
115
145
|
# compute embedding for output fields
|
|
116
146
|
output_fields_desc = ""
|
|
117
147
|
field_desc_map = self.output_schema.field_desc_map()
|
|
118
148
|
for field_name in output_fields:
|
|
119
149
|
output_fields_desc += f"- {field_name}: {field_desc_map[field_name]}\n"
|
|
120
|
-
query_embedding = self.compute_embedding(output_fields_desc)
|
|
150
|
+
query_embedding, query_embed_stats = self.compute_embedding(output_fields_desc)
|
|
151
|
+
|
|
152
|
+
# add cost of embedding the query to embed_stats
|
|
153
|
+
embed_stats += query_embed_stats
|
|
121
154
|
|
|
155
|
+
# for each input field, chunk its content and compute the (per-chunk) embeddings
|
|
122
156
|
for field_name in input_fields:
|
|
123
157
|
field = candidate.get_field_type(field_name)
|
|
124
158
|
|
|
@@ -133,14 +167,18 @@ class RAGConvert(LLMConvert):
|
|
|
133
167
|
candidate[field_name] = "[" + ", ".join(candidate[field_name]) + "]"
|
|
134
168
|
|
|
135
169
|
# skip this field if it is a string field and its length is less than the chunk size
|
|
136
|
-
if
|
|
170
|
+
if len(candidate[field_name]) < self.chunk_size:
|
|
137
171
|
continue
|
|
138
172
|
|
|
139
173
|
# chunk the content
|
|
140
174
|
chunks = self.chunk_text(candidate[field_name], self.chunk_size)
|
|
141
175
|
|
|
142
176
|
# compute embeddings for each chunk
|
|
143
|
-
chunk_embeddings = [self.compute_embedding(chunk) for chunk in chunks]
|
|
177
|
+
chunk_embeddings, chunk_embed_stats_lst = zip(*[self.compute_embedding(chunk) for chunk in chunks])
|
|
178
|
+
|
|
179
|
+
# add cost of embedding each chunk to embed_stats
|
|
180
|
+
for chunk_embed_stats in chunk_embed_stats_lst:
|
|
181
|
+
embed_stats += chunk_embed_stats
|
|
144
182
|
|
|
145
183
|
# select the top-k chunks
|
|
146
184
|
sorted_chunks = sorted(
|
|
@@ -154,29 +192,39 @@ class RAGConvert(LLMConvert):
|
|
|
154
192
|
top_k_chunks = [chunk for _, chunk in sorted(top_k_chunks, key=lambda tup: tup[0])]
|
|
155
193
|
candidate[field_name] = "...".join(top_k_chunks)
|
|
156
194
|
|
|
157
|
-
return candidate
|
|
195
|
+
return candidate, embed_stats
|
|
158
196
|
|
|
159
|
-
def convert(self, candidate: DataRecord, fields:
|
|
197
|
+
def convert(self, candidate: DataRecord, fields: dict[str, Field]) -> tuple[dict[str, list], GenerationStats]:
|
|
160
198
|
# set client
|
|
161
199
|
self.client = OpenAI() if self.client is None else self.client
|
|
162
200
|
|
|
163
201
|
# get the set of input fields to use for the convert operation
|
|
164
202
|
input_fields = self.get_input_fields()
|
|
203
|
+
output_fields = list(fields.keys())
|
|
165
204
|
|
|
166
205
|
# lookup most relevant chunks for each field using embedding search
|
|
167
206
|
candidate_copy = candidate.copy()
|
|
168
|
-
candidate_copy = self.get_chunked_candidate(candidate_copy, input_fields)
|
|
207
|
+
candidate_copy, embed_stats = self.get_chunked_candidate(candidate_copy, input_fields, output_fields)
|
|
169
208
|
|
|
170
209
|
# construct kwargs for generation
|
|
171
210
|
gen_kwargs = {"project_cols": input_fields, "output_schema": self.output_schema}
|
|
172
211
|
|
|
173
212
|
# generate outputs for all fields in a single query
|
|
174
|
-
field_answers, _, generation_stats = self.generator(candidate_copy, fields, **gen_kwargs)
|
|
213
|
+
field_answers, _, generation_stats, _ = self.generator(candidate_copy, fields, **gen_kwargs)
|
|
214
|
+
|
|
215
|
+
# NOTE: summing embedding stats with generation stats is messy because it will lead to misleading
|
|
216
|
+
# measurements of total_input_tokens and total_output_tokens. We should fix this in the future.
|
|
217
|
+
# The good news: as long as we compute the cost_per_record of each GenerationStats object correctly,
|
|
218
|
+
# then the total cost of the operation will be correct (which will roll-up to correctly computing
|
|
219
|
+
# the total cost of the operator, plan, and execution).
|
|
220
|
+
#
|
|
221
|
+
# combine stats from embedding with stats for generation
|
|
222
|
+
generation_stats += embed_stats
|
|
175
223
|
|
|
176
224
|
# if there was an error for any field, execute a conventional query on that field
|
|
177
|
-
for
|
|
225
|
+
for field_name, answers in field_answers.items():
|
|
178
226
|
if answers is None:
|
|
179
|
-
single_field_answers, _, single_field_stats = self.generator(candidate_copy, [
|
|
227
|
+
single_field_answers, _, single_field_stats, _ = self.generator(candidate_copy, {field_name: fields[field_name]}, **gen_kwargs)
|
|
180
228
|
field_answers.update(single_field_answers)
|
|
181
229
|
generation_stats += single_field_stats
|
|
182
230
|
|