palimpzest 0.7.20__py3-none-any.whl → 0.8.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- palimpzest/__init__.py +37 -6
- palimpzest/agents/__init__.py +0 -0
- palimpzest/agents/compute_agents.py +0 -0
- palimpzest/agents/search_agents.py +637 -0
- palimpzest/constants.py +259 -197
- palimpzest/core/data/context.py +393 -0
- palimpzest/core/data/context_manager.py +163 -0
- palimpzest/core/data/dataset.py +634 -0
- palimpzest/core/data/{datareaders.py → iter_dataset.py} +202 -126
- palimpzest/core/elements/groupbysig.py +16 -13
- palimpzest/core/elements/records.py +166 -75
- palimpzest/core/lib/schemas.py +152 -390
- palimpzest/core/{data/dataclasses.py → models.py} +306 -170
- palimpzest/policy.py +2 -27
- palimpzest/prompts/__init__.py +35 -5
- palimpzest/prompts/agent_prompts.py +357 -0
- palimpzest/prompts/context_search.py +9 -0
- palimpzest/prompts/convert_prompts.py +61 -5
- palimpzest/prompts/filter_prompts.py +50 -5
- palimpzest/prompts/join_prompts.py +163 -0
- palimpzest/prompts/moa_proposer_convert_prompts.py +5 -5
- palimpzest/prompts/prompt_factory.py +358 -46
- palimpzest/prompts/validator.py +239 -0
- palimpzest/query/execution/all_sample_execution_strategy.py +134 -76
- palimpzest/query/execution/execution_strategy.py +210 -317
- palimpzest/query/execution/execution_strategy_type.py +5 -7
- palimpzest/query/execution/mab_execution_strategy.py +249 -136
- palimpzest/query/execution/parallel_execution_strategy.py +153 -244
- palimpzest/query/execution/single_threaded_execution_strategy.py +107 -64
- palimpzest/query/generators/generators.py +157 -330
- palimpzest/query/operators/__init__.py +15 -5
- palimpzest/query/operators/aggregate.py +50 -33
- palimpzest/query/operators/compute.py +201 -0
- palimpzest/query/operators/convert.py +27 -21
- palimpzest/query/operators/critique_and_refine_convert.py +7 -5
- palimpzest/query/operators/distinct.py +62 -0
- palimpzest/query/operators/filter.py +22 -13
- palimpzest/query/operators/join.py +402 -0
- palimpzest/query/operators/limit.py +3 -3
- palimpzest/query/operators/logical.py +198 -80
- palimpzest/query/operators/mixture_of_agents_convert.py +10 -8
- palimpzest/query/operators/physical.py +27 -21
- palimpzest/query/operators/project.py +3 -3
- palimpzest/query/operators/rag_convert.py +7 -7
- palimpzest/query/operators/retrieve.py +9 -9
- palimpzest/query/operators/scan.py +81 -42
- palimpzest/query/operators/search.py +524 -0
- palimpzest/query/operators/split_convert.py +10 -8
- palimpzest/query/optimizer/__init__.py +7 -9
- palimpzest/query/optimizer/cost_model.py +108 -441
- palimpzest/query/optimizer/optimizer.py +123 -181
- palimpzest/query/optimizer/optimizer_strategy.py +66 -61
- palimpzest/query/optimizer/plan.py +352 -67
- palimpzest/query/optimizer/primitives.py +43 -19
- palimpzest/query/optimizer/rules.py +484 -646
- palimpzest/query/optimizer/tasks.py +127 -58
- palimpzest/query/processor/config.py +41 -76
- palimpzest/query/processor/query_processor.py +73 -18
- palimpzest/query/processor/query_processor_factory.py +46 -38
- palimpzest/schemabuilder/schema_builder.py +15 -28
- palimpzest/utils/model_helpers.py +27 -77
- palimpzest/utils/progress.py +114 -102
- palimpzest/validator/__init__.py +0 -0
- palimpzest/validator/validator.py +306 -0
- {palimpzest-0.7.20.dist-info → palimpzest-0.8.0.dist-info}/METADATA +6 -1
- palimpzest-0.8.0.dist-info/RECORD +95 -0
- palimpzest/core/lib/fields.py +0 -141
- palimpzest/prompts/code_synthesis_prompts.py +0 -28
- palimpzest/query/execution/random_sampling_execution_strategy.py +0 -240
- palimpzest/query/generators/api_client_factory.py +0 -30
- palimpzest/query/operators/code_synthesis_convert.py +0 -488
- palimpzest/query/operators/map.py +0 -130
- palimpzest/query/processor/nosentinel_processor.py +0 -33
- palimpzest/query/processor/processing_strategy_type.py +0 -28
- palimpzest/query/processor/sentinel_processor.py +0 -88
- palimpzest/query/processor/streaming_processor.py +0 -149
- palimpzest/sets.py +0 -405
- palimpzest/utils/datareader_helpers.py +0 -61
- palimpzest/utils/demo_helpers.py +0 -75
- palimpzest/utils/field_helpers.py +0 -69
- palimpzest/utils/generation_helpers.py +0 -69
- palimpzest/utils/sandbox.py +0 -183
- palimpzest-0.7.20.dist-info/RECORD +0 -95
- /palimpzest/core/{elements/index.py → data/index_dataset.py} +0 -0
- {palimpzest-0.7.20.dist-info → palimpzest-0.8.0.dist-info}/WHEEL +0 -0
- {palimpzest-0.7.20.dist-info → palimpzest-0.8.0.dist-info}/licenses/LICENSE +0 -0
- {palimpzest-0.7.20.dist-info → palimpzest-0.8.0.dist-info}/top_level.txt +0 -0
|
@@ -1,17 +1,16 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
+
import json
|
|
3
4
|
import time
|
|
4
5
|
from abc import abstractmethod
|
|
5
|
-
from dataclasses import dataclass, field, fields
|
|
6
6
|
from typing import Any
|
|
7
7
|
|
|
8
|
-
import
|
|
8
|
+
from pydantic import BaseModel, Field
|
|
9
9
|
|
|
10
10
|
|
|
11
|
-
|
|
12
|
-
class GenerationStats:
|
|
11
|
+
class GenerationStats(BaseModel):
|
|
13
12
|
"""
|
|
14
|
-
|
|
13
|
+
Model for storing statistics about the execution of an operator on a single record.
|
|
15
14
|
"""
|
|
16
15
|
|
|
17
16
|
model_name: str | None = None
|
|
@@ -19,6 +18,15 @@ class GenerationStats:
|
|
|
19
18
|
# The raw answer as output from the generator (a list of strings, possibly of len 1)
|
|
20
19
|
# raw_answers: Optional[List[str]] = field(default_factory=list)
|
|
21
20
|
|
|
21
|
+
# the number of input audio tokens
|
|
22
|
+
input_audio_tokens: int = 0
|
|
23
|
+
|
|
24
|
+
# the number of input text tokens
|
|
25
|
+
input_text_tokens: int = 0
|
|
26
|
+
|
|
27
|
+
# the number of input image tokens
|
|
28
|
+
input_image_tokens: int = 0
|
|
29
|
+
|
|
22
30
|
# the total number of input tokens processed by this operator; None if this operation did not use an LLM
|
|
23
31
|
# typed as a float because GenerationStats may be amortized (i.e. divided) across a number of output records
|
|
24
32
|
total_input_tokens: float = 0.0
|
|
@@ -33,7 +41,7 @@ class GenerationStats:
|
|
|
33
41
|
# the total cost of processing the output tokens; None if this operation did not use an LLM
|
|
34
42
|
total_output_cost: float = 0.0
|
|
35
43
|
|
|
36
|
-
# the total cost of processing the output tokens; None if this operation did not use an LLM
|
|
44
|
+
# the total cost of processing the input and output tokens; None if this operation did not use an LLM
|
|
37
45
|
cost_per_record: float = 0.0
|
|
38
46
|
|
|
39
47
|
# (if applicable) the time (in seconds) spent executing a call to an LLM
|
|
@@ -50,7 +58,7 @@ class GenerationStats:
|
|
|
50
58
|
|
|
51
59
|
def __iadd__(self, other: GenerationStats) -> GenerationStats:
|
|
52
60
|
# self.raw_answers.extend(other.raw_answers)
|
|
53
|
-
for
|
|
61
|
+
for model_field in [
|
|
54
62
|
"total_input_tokens",
|
|
55
63
|
"total_output_tokens",
|
|
56
64
|
"total_input_cost",
|
|
@@ -61,7 +69,7 @@ class GenerationStats:
|
|
|
61
69
|
"total_llm_calls",
|
|
62
70
|
"total_embedding_llm_calls",
|
|
63
71
|
]:
|
|
64
|
-
setattr(self,
|
|
72
|
+
setattr(self, model_field, getattr(self, model_field) + getattr(other, model_field))
|
|
65
73
|
return self
|
|
66
74
|
|
|
67
75
|
def __add__(self, other: GenerationStats) -> GenerationStats:
|
|
@@ -89,7 +97,7 @@ class GenerationStats:
|
|
|
89
97
|
raise ZeroDivisionError("Cannot divide by zero")
|
|
90
98
|
if isinstance(quotient, int):
|
|
91
99
|
quotient = float(quotient)
|
|
92
|
-
for
|
|
100
|
+
for model_field in [
|
|
93
101
|
"total_input_tokens",
|
|
94
102
|
"total_output_tokens",
|
|
95
103
|
"total_input_cost",
|
|
@@ -100,7 +108,7 @@ class GenerationStats:
|
|
|
100
108
|
"total_llm_calls",
|
|
101
109
|
"total_embedding_llm_calls",
|
|
102
110
|
]:
|
|
103
|
-
setattr(self,
|
|
111
|
+
setattr(self, model_field, getattr(self, model_field) / quotient)
|
|
104
112
|
return self
|
|
105
113
|
|
|
106
114
|
def __truediv__(self, quotient: float) -> GenerationStats:
|
|
@@ -129,22 +137,30 @@ class GenerationStats:
|
|
|
129
137
|
assert not isinstance(other, GenerationStats), "This should not be called with a GenerationStats object"
|
|
130
138
|
return self
|
|
131
139
|
|
|
140
|
+
# NOTE: this is added temporarily to help track cost of compute agent writing PZ code;
|
|
141
|
+
# once we find a long-term solution for tracking that cost, we can remove this
|
|
142
|
+
def to_json(self, filepath: str | None = None) -> dict | None:
|
|
143
|
+
if filepath is None:
|
|
144
|
+
return self.model_dump(mode="json")
|
|
145
|
+
|
|
146
|
+
with open(filepath, "w") as f:
|
|
147
|
+
json.dump(self.model_dump(mode="json"), f)
|
|
148
|
+
|
|
132
149
|
|
|
133
|
-
|
|
134
|
-
class RecordOpStats:
|
|
150
|
+
class RecordOpStats(BaseModel):
|
|
135
151
|
"""
|
|
136
|
-
|
|
152
|
+
Model for storing statistics about the execution of an operator on a single record.
|
|
137
153
|
"""
|
|
138
154
|
|
|
139
155
|
##### REQUIRED FIELDS #####
|
|
140
156
|
# record id; an identifier for this record
|
|
141
|
-
record_id: str
|
|
157
|
+
record_id: str | int
|
|
142
158
|
|
|
143
|
-
# identifier for the parent of this record
|
|
144
|
-
|
|
159
|
+
# identifier for the parent(s) of this record
|
|
160
|
+
record_parent_ids: list[str | int] | None
|
|
145
161
|
|
|
146
|
-
# idenifier for the source
|
|
147
|
-
|
|
162
|
+
# idenifier for the source indices of this record
|
|
163
|
+
record_source_indices: list[str | int]
|
|
148
164
|
|
|
149
165
|
# a dictionary with the record state after being processed by the operator
|
|
150
166
|
record_state: dict[str, Any]
|
|
@@ -165,8 +181,11 @@ class RecordOpStats:
|
|
|
165
181
|
cost_per_record: float
|
|
166
182
|
|
|
167
183
|
##### NOT-OPTIONAL, BUT FILLED BY EXECUTION CLASS AFTER CONSTRUCTOR CALL #####
|
|
168
|
-
# the ID of the physical operation which produced the input record for this record at this operation
|
|
169
|
-
|
|
184
|
+
# the ID(s) of the physical operation(s) which produced the input record(s) for this record at this operation
|
|
185
|
+
source_unique_full_op_ids: list[str] | None = None
|
|
186
|
+
|
|
187
|
+
# the ID(s) of the logical operation(s) which produced the input record(s) for this record at this operation
|
|
188
|
+
source_unique_logical_op_ids: list[str] | None = None
|
|
170
189
|
|
|
171
190
|
# the ID of the physical plan which produced this record at this operation
|
|
172
191
|
plan_id: str = ""
|
|
@@ -207,8 +226,11 @@ class RecordOpStats:
|
|
|
207
226
|
# (if applicable) the filter text (or a string representation of the filter function) applied to this record
|
|
208
227
|
filter_str: str | None = None
|
|
209
228
|
|
|
229
|
+
# (if applicable) the join condition applied to this record
|
|
230
|
+
join_condition: str | None = None
|
|
231
|
+
|
|
210
232
|
# the True/False result of whether this record was output by the operator or not
|
|
211
|
-
# (can only be False if the operator is
|
|
233
|
+
# (can only be False if the operator is a Filter or Join)
|
|
212
234
|
passed_operator: bool = True
|
|
213
235
|
|
|
214
236
|
# (if applicable) the time (in seconds) spent executing a call to an LLM
|
|
@@ -230,16 +252,12 @@ class RecordOpStats:
|
|
|
230
252
|
image_operation: bool | None = None
|
|
231
253
|
|
|
232
254
|
# an OPTIONAL dictionary with more detailed information about this operation;
|
|
233
|
-
op_details: dict[str, Any] =
|
|
234
|
-
|
|
235
|
-
def to_json(self):
|
|
236
|
-
return {field.name: getattr(self, field.name) for field in fields(self)}
|
|
255
|
+
op_details: dict[str, Any] = Field(default_factory=dict)
|
|
237
256
|
|
|
238
257
|
|
|
239
|
-
|
|
240
|
-
class OperatorStats:
|
|
258
|
+
class OperatorStats(BaseModel):
|
|
241
259
|
"""
|
|
242
|
-
|
|
260
|
+
Model for storing statistics captured within a given operator.
|
|
243
261
|
"""
|
|
244
262
|
|
|
245
263
|
# the full ID of the physical operation in which these stats were collected
|
|
@@ -254,17 +272,26 @@ class OperatorStats:
|
|
|
254
272
|
# the total cost of this operation
|
|
255
273
|
total_op_cost: float = 0.0
|
|
256
274
|
|
|
275
|
+
# the total input tokens processed by this operation
|
|
276
|
+
total_input_tokens: int = 0
|
|
277
|
+
|
|
278
|
+
# the total output tokens processed by this operation
|
|
279
|
+
total_output_tokens: int = 0
|
|
280
|
+
|
|
257
281
|
# a list of RecordOpStats processed by the operation
|
|
258
|
-
record_op_stats_lst: list[RecordOpStats] =
|
|
282
|
+
record_op_stats_lst: list[RecordOpStats] = Field(default_factory=list)
|
|
259
283
|
|
|
260
|
-
# the full ID of the physical operator which
|
|
261
|
-
|
|
284
|
+
# the unique full ID(s) of the physical operator(s) which precede this one (used by PlanStats)
|
|
285
|
+
source_unique_full_op_ids: list[str] | None = None
|
|
286
|
+
|
|
287
|
+
# the unique full ID(s) of the logical operator(s) which precede this one (used by SentinelPlanStats)
|
|
288
|
+
source_unique_logical_op_ids: list[str] | None = None
|
|
262
289
|
|
|
263
290
|
# the ID of the physical plan which this operator is part of
|
|
264
291
|
plan_id: str = ""
|
|
265
292
|
|
|
266
293
|
# an OPTIONAL dictionary with more detailed information about this operation;
|
|
267
|
-
op_details: dict[str, Any] =
|
|
294
|
+
op_details: dict[str, Any] = Field(default_factory=dict)
|
|
268
295
|
|
|
269
296
|
def __iadd__(self, stats: OperatorStats | RecordOpStats) -> OperatorStats:
|
|
270
297
|
"""
|
|
@@ -280,34 +307,28 @@ class OperatorStats:
|
|
|
280
307
|
if isinstance(stats, OperatorStats):
|
|
281
308
|
self.total_op_time += stats.total_op_time
|
|
282
309
|
self.total_op_cost += stats.total_op_cost
|
|
310
|
+
self.total_input_tokens += stats.total_input_tokens
|
|
311
|
+
self.total_output_tokens += stats.total_output_tokens
|
|
283
312
|
self.record_op_stats_lst.extend(stats.record_op_stats_lst)
|
|
284
313
|
|
|
285
314
|
elif isinstance(stats, RecordOpStats):
|
|
286
|
-
stats.
|
|
315
|
+
stats.source_unique_full_op_ids = self.source_unique_full_op_ids
|
|
287
316
|
stats.plan_id = self.plan_id
|
|
288
317
|
self.record_op_stats_lst.append(stats)
|
|
289
318
|
self.total_op_time += stats.time_per_record
|
|
290
319
|
self.total_op_cost += stats.cost_per_record
|
|
320
|
+
self.total_input_tokens += stats.total_input_tokens
|
|
321
|
+
self.total_output_tokens += stats.total_output_tokens
|
|
291
322
|
|
|
292
323
|
else:
|
|
293
324
|
raise TypeError(f"Cannot add {type(stats)} to OperatorStats")
|
|
294
325
|
|
|
295
326
|
return self
|
|
296
327
|
|
|
297
|
-
def to_json(self):
|
|
298
|
-
return {
|
|
299
|
-
"full_op_id": self.full_op_id,
|
|
300
|
-
"op_name": self.op_name,
|
|
301
|
-
"total_op_time": self.total_op_time,
|
|
302
|
-
"total_op_cost": self.total_op_cost,
|
|
303
|
-
"record_op_stats_lst": [record_op_stats.to_json() for record_op_stats in self.record_op_stats_lst],
|
|
304
|
-
"op_details": self.op_details,
|
|
305
|
-
}
|
|
306
328
|
|
|
307
|
-
|
|
308
|
-
class BasePlanStats:
|
|
329
|
+
class BasePlanStats(BaseModel):
|
|
309
330
|
"""
|
|
310
|
-
|
|
331
|
+
Model for storing statistics captured for an entire plan.
|
|
311
332
|
|
|
312
333
|
This class is subclassed for tracking:
|
|
313
334
|
- PlanStats: the statistics for execution of a PhysicalPlan
|
|
@@ -331,7 +352,11 @@ class BasePlanStats:
|
|
|
331
352
|
# dictionary whose values are OperatorStats objects;
|
|
332
353
|
# PlanStats maps {full_op_id -> OperatorStats}
|
|
333
354
|
# SentinelPlanStats maps {logical_op_id -> {full_op_id -> OperatorStats}}
|
|
334
|
-
operator_stats: dict =
|
|
355
|
+
operator_stats: dict = Field(default_factory=dict)
|
|
356
|
+
|
|
357
|
+
# dictionary whose values are GenerationStats objects for validation;
|
|
358
|
+
# only used by SentinelPlanStats
|
|
359
|
+
validation_gen_stats: dict[str, GenerationStats] = Field(default_factory=dict)
|
|
335
360
|
|
|
336
361
|
# total runtime for the plan measured from the start to the end of PhysicalPlan.execute()
|
|
337
362
|
total_plan_time: float = 0.0
|
|
@@ -339,6 +364,12 @@ class BasePlanStats:
|
|
|
339
364
|
# total cost for plan
|
|
340
365
|
total_plan_cost: float = 0.0
|
|
341
366
|
|
|
367
|
+
# total input tokens processed by this plan
|
|
368
|
+
total_input_tokens: int = 0
|
|
369
|
+
|
|
370
|
+
# total output tokens processed by this plan
|
|
371
|
+
total_output_tokens: int = 0
|
|
372
|
+
|
|
342
373
|
# start time for the plan execution; should be set by calling PlanStats.start()
|
|
343
374
|
start_time: float | None = None
|
|
344
375
|
|
|
@@ -351,7 +382,9 @@ class BasePlanStats:
|
|
|
351
382
|
if self.start_time is None:
|
|
352
383
|
raise RuntimeError("PlanStats.start() must be called before PlanStats.finish()")
|
|
353
384
|
self.total_plan_time = time.time() - self.start_time
|
|
354
|
-
self.total_plan_cost = self.sum_op_costs()
|
|
385
|
+
self.total_plan_cost = self.sum_op_costs() + self.sum_validation_costs()
|
|
386
|
+
self.total_input_tokens = self.sum_input_tokens() + self.sum_validation_input_tokens()
|
|
387
|
+
self.total_output_tokens = self.sum_output_tokens() + self.sum_validation_output_tokens()
|
|
355
388
|
|
|
356
389
|
@staticmethod
|
|
357
390
|
@abstractmethod
|
|
@@ -369,9 +402,23 @@ class BasePlanStats:
|
|
|
369
402
|
pass
|
|
370
403
|
|
|
371
404
|
@abstractmethod
|
|
372
|
-
def
|
|
405
|
+
def sum_input_tokens(self) -> int:
|
|
406
|
+
"""
|
|
407
|
+
Sum the input tokens processed by all operators in this plan.
|
|
408
|
+
"""
|
|
409
|
+
pass
|
|
410
|
+
|
|
411
|
+
@abstractmethod
|
|
412
|
+
def sum_output_tokens(self) -> int:
|
|
413
|
+
"""
|
|
414
|
+
Sum the output tokens processed by all operators in this plan.
|
|
415
|
+
"""
|
|
416
|
+
pass
|
|
417
|
+
|
|
418
|
+
@abstractmethod
|
|
419
|
+
def add_record_op_stats(self, unique_full_op_id: str, record_op_stats: RecordOpStats | list[RecordOpStats]) -> None:
|
|
373
420
|
"""
|
|
374
|
-
Add the given RecordOpStats to this plan's operator stats.
|
|
421
|
+
Add the given RecordOpStats to this plan's operator stats for the given operator id.
|
|
375
422
|
"""
|
|
376
423
|
pass
|
|
377
424
|
|
|
@@ -389,14 +436,25 @@ class BasePlanStats:
|
|
|
389
436
|
"""
|
|
390
437
|
pass
|
|
391
438
|
|
|
392
|
-
|
|
393
|
-
def to_json(self) -> dict:
|
|
439
|
+
def sum_validation_costs(self) -> float:
|
|
394
440
|
"""
|
|
395
|
-
|
|
441
|
+
Sum the costs of all validation generations in this plan.
|
|
396
442
|
"""
|
|
397
|
-
|
|
443
|
+
return sum([gen_stats.cost_per_record for _, gen_stats in self.validation_gen_stats.items()])
|
|
444
|
+
|
|
445
|
+
def sum_validation_input_tokens(self) -> int:
|
|
446
|
+
"""
|
|
447
|
+
Sum the input tokens processed by all validation generations in this plan.
|
|
448
|
+
"""
|
|
449
|
+
return sum([gen_stats.total_input_tokens for _, gen_stats in self.validation_gen_stats.items()])
|
|
450
|
+
|
|
451
|
+
def sum_validation_output_tokens(self) -> int:
|
|
452
|
+
"""
|
|
453
|
+
Sum the output tokens processed by all validation generations in this plan.
|
|
454
|
+
"""
|
|
455
|
+
return sum([gen_stats.total_output_tokens for _, gen_stats in self.validation_gen_stats.items()])
|
|
456
|
+
|
|
398
457
|
|
|
399
|
-
@dataclass
|
|
400
458
|
class PlanStats(BasePlanStats):
|
|
401
459
|
"""
|
|
402
460
|
Subclass of BasePlanStats which captures statistics from the execution of a single PhysicalPlan.
|
|
@@ -406,17 +464,18 @@ class PlanStats(BasePlanStats):
|
|
|
406
464
|
"""
|
|
407
465
|
Initialize this PlanStats object from a PhysicalPlan object.
|
|
408
466
|
"""
|
|
467
|
+
# TODO?: have PhysicalPlan return PlanStats object
|
|
409
468
|
operator_stats = {}
|
|
410
|
-
for
|
|
411
|
-
|
|
412
|
-
operator_stats[
|
|
413
|
-
full_op_id=
|
|
469
|
+
for topo_idx, op in enumerate(plan):
|
|
470
|
+
unique_full_op_id = f"{topo_idx}-{op.get_full_op_id()}"
|
|
471
|
+
operator_stats[unique_full_op_id] = OperatorStats(
|
|
472
|
+
full_op_id=op.get_full_op_id(),
|
|
414
473
|
op_name=op.op_name(),
|
|
415
|
-
|
|
474
|
+
source_unique_full_op_ids=plan.get_source_unique_full_op_ids(topo_idx, op),
|
|
416
475
|
plan_id=plan.plan_id,
|
|
417
476
|
op_details={k: str(v) for k, v in op.get_id_params().items()},
|
|
418
477
|
)
|
|
419
|
-
|
|
478
|
+
|
|
420
479
|
return PlanStats(plan_id=plan.plan_id, plan_str=str(plan), operator_stats=operator_stats)
|
|
421
480
|
|
|
422
481
|
def sum_op_costs(self) -> float:
|
|
@@ -425,20 +484,31 @@ class PlanStats(BasePlanStats):
|
|
|
425
484
|
"""
|
|
426
485
|
return sum([op_stats.total_op_cost for _, op_stats in self.operator_stats.items()])
|
|
427
486
|
|
|
428
|
-
def
|
|
487
|
+
def sum_input_tokens(self) -> int:
|
|
488
|
+
"""
|
|
489
|
+
Sum the input tokens processed by all operators in this plan.
|
|
490
|
+
"""
|
|
491
|
+
return sum([op_stats.total_input_tokens for _, op_stats in self.operator_stats.items()])
|
|
492
|
+
|
|
493
|
+
def sum_output_tokens(self) -> int:
|
|
429
494
|
"""
|
|
430
|
-
|
|
495
|
+
Sum the output tokens processed by all operators in this plan.
|
|
496
|
+
"""
|
|
497
|
+
return sum([op_stats.total_output_tokens for _, op_stats in self.operator_stats.items()])
|
|
498
|
+
|
|
499
|
+
def add_record_op_stats(self, unique_full_op_id: str, record_op_stats: RecordOpStats | list[RecordOpStats]) -> None:
|
|
500
|
+
"""
|
|
501
|
+
Add the given RecordOpStats to this plan's operator stats for the given operator id.
|
|
431
502
|
"""
|
|
432
503
|
# normalize input type to be list[RecordOpStats]
|
|
433
504
|
record_op_stats_lst = record_op_stats if isinstance(record_op_stats, list) else [record_op_stats]
|
|
434
505
|
|
|
435
506
|
# update operator stats
|
|
436
507
|
for record_op_stats in record_op_stats_lst:
|
|
437
|
-
|
|
438
|
-
|
|
439
|
-
self.operator_stats[full_op_id] += record_op_stats
|
|
508
|
+
if unique_full_op_id in self.operator_stats:
|
|
509
|
+
self.operator_stats[unique_full_op_id] += record_op_stats
|
|
440
510
|
else:
|
|
441
|
-
raise ValueError(f"RecordOpStats with
|
|
511
|
+
raise ValueError(f"RecordOpStats with unique_full_op_id {unique_full_op_id} not found in PlanStats")
|
|
442
512
|
|
|
443
513
|
def __iadd__(self, plan_stats: PlanStats) -> None:
|
|
444
514
|
"""
|
|
@@ -450,30 +520,24 @@ class PlanStats(BasePlanStats):
|
|
|
450
520
|
"""
|
|
451
521
|
self.total_plan_time += plan_stats.total_plan_time
|
|
452
522
|
self.total_plan_cost += plan_stats.total_plan_cost
|
|
453
|
-
|
|
454
|
-
|
|
455
|
-
|
|
523
|
+
self.total_input_tokens += plan_stats.total_input_tokens
|
|
524
|
+
self.total_output_tokens += plan_stats.total_output_tokens
|
|
525
|
+
for unique_full_op_id, op_stats in plan_stats.operator_stats.items():
|
|
526
|
+
if unique_full_op_id in self.operator_stats:
|
|
527
|
+
self.operator_stats[unique_full_op_id] += op_stats
|
|
456
528
|
else:
|
|
457
|
-
self.operator_stats[
|
|
529
|
+
self.operator_stats[unique_full_op_id] = op_stats
|
|
458
530
|
|
|
459
531
|
def __str__(self) -> str:
|
|
460
532
|
stats = f"total_plan_time={self.total_plan_time} \n"
|
|
461
533
|
stats += f"total_plan_cost={self.total_plan_cost} \n"
|
|
534
|
+
stats += f"total_input_tokens={self.total_input_tokens} \n"
|
|
535
|
+
stats += f"total_output_tokens={self.total_output_tokens} \n"
|
|
462
536
|
for idx, op_stats in enumerate(self.operator_stats.values()):
|
|
463
537
|
stats += f"{idx}. {op_stats.op_name} time={op_stats.total_op_time} cost={op_stats.total_op_cost} \n"
|
|
464
538
|
return stats
|
|
465
539
|
|
|
466
|
-
def to_json(self) -> dict:
|
|
467
|
-
return {
|
|
468
|
-
"plan_id": self.plan_id,
|
|
469
|
-
"plan_str": self.plan_str,
|
|
470
|
-
"operator_stats": {full_op_id: op_stats.to_json() for full_op_id, op_stats in self.operator_stats.items()},
|
|
471
|
-
"total_plan_time": self.total_plan_time,
|
|
472
|
-
"total_plan_cost": self.total_plan_cost,
|
|
473
|
-
}
|
|
474
540
|
|
|
475
|
-
|
|
476
|
-
@dataclass
|
|
477
541
|
class SentinelPlanStats(BasePlanStats):
|
|
478
542
|
"""
|
|
479
543
|
Subclass of BasePlanStats which captures statistics from the execution of a single SentinelPlan.
|
|
@@ -484,18 +548,19 @@ class SentinelPlanStats(BasePlanStats):
|
|
|
484
548
|
Initialize this PlanStats object from a Sentinel object.
|
|
485
549
|
"""
|
|
486
550
|
operator_stats = {}
|
|
487
|
-
for
|
|
488
|
-
|
|
551
|
+
for topo_idx, (logical_op_id, op_set) in enumerate(plan):
|
|
552
|
+
unique_logical_op_id = f"{topo_idx}-{logical_op_id}"
|
|
553
|
+
operator_stats[unique_logical_op_id] = {}
|
|
489
554
|
for physical_op in op_set:
|
|
490
555
|
full_op_id = physical_op.get_full_op_id()
|
|
491
|
-
operator_stats[
|
|
556
|
+
operator_stats[unique_logical_op_id][full_op_id] = OperatorStats(
|
|
492
557
|
full_op_id=full_op_id,
|
|
493
558
|
op_name=physical_op.op_name(),
|
|
494
|
-
|
|
559
|
+
source_unique_logical_op_ids=plan.get_source_unique_logical_op_ids(unique_logical_op_id),
|
|
495
560
|
plan_id=plan.plan_id,
|
|
496
561
|
op_details={k: str(v) for k, v in physical_op.get_id_params().items()},
|
|
497
562
|
)
|
|
498
|
-
|
|
563
|
+
|
|
499
564
|
return SentinelPlanStats(plan_id=plan.plan_id, plan_str=str(plan), operator_stats=operator_stats)
|
|
500
565
|
|
|
501
566
|
def sum_op_costs(self) -> float:
|
|
@@ -504,24 +569,45 @@ class SentinelPlanStats(BasePlanStats):
|
|
|
504
569
|
"""
|
|
505
570
|
return sum(sum([op_stats.total_op_cost for _, op_stats in phys_op_stats.items()]) for _, phys_op_stats in self.operator_stats.items())
|
|
506
571
|
|
|
507
|
-
def
|
|
572
|
+
def sum_input_tokens(self) -> int:
|
|
573
|
+
"""
|
|
574
|
+
Sum the input tokens processed by all operators in this plan.
|
|
575
|
+
"""
|
|
576
|
+
return sum(sum([op_stats.total_input_tokens for _, op_stats in phys_op_stats.items()]) for _, phys_op_stats in self.operator_stats.items())
|
|
577
|
+
|
|
578
|
+
def sum_output_tokens(self) -> int:
|
|
508
579
|
"""
|
|
509
|
-
|
|
580
|
+
Sum the output tokens processed by all operators in this plan.
|
|
581
|
+
"""
|
|
582
|
+
return sum(sum([op_stats.total_output_tokens for _, op_stats in phys_op_stats.items()]) for _, phys_op_stats in self.operator_stats.items())
|
|
583
|
+
|
|
584
|
+
def add_record_op_stats(self, unique_logical_op_id: str, record_op_stats: RecordOpStats | list[RecordOpStats]) -> None:
|
|
585
|
+
"""
|
|
586
|
+
Add the given RecordOpStats to this plan's operator stats for the given operator set id.
|
|
510
587
|
"""
|
|
511
588
|
# normalize input type to be list[RecordOpStats]
|
|
512
589
|
record_op_stats_lst = record_op_stats if isinstance(record_op_stats, list) else [record_op_stats]
|
|
513
590
|
|
|
514
591
|
# update operator stats
|
|
515
592
|
for record_op_stats in record_op_stats_lst:
|
|
516
|
-
logical_op_id = record_op_stats.logical_op_id
|
|
517
593
|
full_op_id = record_op_stats.full_op_id
|
|
518
|
-
if
|
|
519
|
-
if full_op_id in self.operator_stats[
|
|
520
|
-
self.operator_stats[
|
|
594
|
+
if unique_logical_op_id in self.operator_stats:
|
|
595
|
+
if full_op_id in self.operator_stats[unique_logical_op_id]:
|
|
596
|
+
self.operator_stats[unique_logical_op_id][full_op_id] += record_op_stats
|
|
521
597
|
else:
|
|
522
598
|
raise ValueError(f"RecordOpStats with full_op_id {full_op_id} not found in SentinelPlanStats")
|
|
523
599
|
else:
|
|
524
|
-
raise ValueError(f"RecordOpStats with
|
|
600
|
+
raise ValueError(f"RecordOpStats with unique_logical_op_id {unique_logical_op_id} not found in SentinelPlanStats")
|
|
601
|
+
|
|
602
|
+
def add_validation_gen_stats(self, unique_logical_op_id: str, gen_stats: GenerationStats) -> None:
|
|
603
|
+
"""
|
|
604
|
+
Add the given GenerationStats to this plan's validation generation stats for the given logical operator id.
|
|
605
|
+
"""
|
|
606
|
+
if unique_logical_op_id in self.validation_gen_stats:
|
|
607
|
+
self.validation_gen_stats[unique_logical_op_id] += gen_stats
|
|
608
|
+
else:
|
|
609
|
+
self.validation_gen_stats[unique_logical_op_id] = gen_stats
|
|
610
|
+
|
|
525
611
|
|
|
526
612
|
def __iadd__(self, plan_stats: SentinelPlanStats) -> None:
|
|
527
613
|
"""
|
|
@@ -533,19 +619,29 @@ class SentinelPlanStats(BasePlanStats):
|
|
|
533
619
|
"""
|
|
534
620
|
self.total_plan_time += plan_stats.total_plan_time
|
|
535
621
|
self.total_plan_cost += plan_stats.total_plan_cost
|
|
536
|
-
|
|
622
|
+
self.total_input_tokens += plan_stats.total_input_tokens
|
|
623
|
+
self.total_output_tokens += plan_stats.total_output_tokens
|
|
624
|
+
for unique_logical_op_id, physical_op_stats in plan_stats.operator_stats.items():
|
|
537
625
|
for full_op_id, op_stats in physical_op_stats.items():
|
|
538
|
-
if
|
|
539
|
-
if full_op_id in self.operator_stats[
|
|
540
|
-
self.operator_stats[
|
|
626
|
+
if unique_logical_op_id in self.operator_stats:
|
|
627
|
+
if full_op_id in self.operator_stats[unique_logical_op_id]:
|
|
628
|
+
self.operator_stats[unique_logical_op_id][full_op_id] += op_stats
|
|
541
629
|
else:
|
|
542
|
-
self.operator_stats[
|
|
630
|
+
self.operator_stats[unique_logical_op_id][full_op_id] = op_stats
|
|
543
631
|
else:
|
|
544
|
-
self.operator_stats[
|
|
632
|
+
self.operator_stats[unique_logical_op_id] = physical_op_stats
|
|
633
|
+
|
|
634
|
+
for unique_logical_op_id, gen_stats in plan_stats.validation_gen_stats.items():
|
|
635
|
+
if unique_logical_op_id in self.validation_gen_stats:
|
|
636
|
+
self.validation_gen_stats[unique_logical_op_id] += gen_stats
|
|
637
|
+
else:
|
|
638
|
+
self.validation_gen_stats[unique_logical_op_id] = gen_stats
|
|
545
639
|
|
|
546
640
|
def __str__(self) -> str:
|
|
547
641
|
stats = f"total_plan_time={self.total_plan_time} \n"
|
|
548
642
|
stats += f"total_plan_cost={self.total_plan_cost} \n"
|
|
643
|
+
stats += f"total_input_tokens={self.total_input_tokens} \n"
|
|
644
|
+
stats += f"total_output_tokens={self.total_output_tokens} \n"
|
|
549
645
|
for outer_idx, physical_op_stats in enumerate(self.operator_stats.values()):
|
|
550
646
|
total_time = sum([op_stats.total_op_time for op_stats in physical_op_stats.values()])
|
|
551
647
|
total_cost = sum([op_stats.total_op_cost for op_stats in physical_op_stats.values()])
|
|
@@ -554,33 +650,20 @@ class SentinelPlanStats(BasePlanStats):
|
|
|
554
650
|
stats += f" {outer_idx}.{inner_idx}. {op_stats.op_name} time={op_stats.total_op_time} cost={op_stats.total_op_cost} \n"
|
|
555
651
|
return stats
|
|
556
652
|
|
|
557
|
-
def to_json(self) -> dict:
|
|
558
|
-
return {
|
|
559
|
-
"plan_id": self.plan_id,
|
|
560
|
-
"plan_str": self.plan_str,
|
|
561
|
-
"operator_stats": {
|
|
562
|
-
logical_op_id: {full_op_id: op_stats.to_json() for full_op_id, op_stats in physical_op_stats.items()}
|
|
563
|
-
for logical_op_id, physical_op_stats in self.operator_stats.items()
|
|
564
|
-
},
|
|
565
|
-
"total_plan_time": self.total_plan_time,
|
|
566
|
-
"total_plan_cost": self.total_plan_cost,
|
|
567
|
-
}
|
|
568
|
-
|
|
569
653
|
|
|
570
|
-
|
|
571
|
-
class ExecutionStats:
|
|
654
|
+
class ExecutionStats(BaseModel):
|
|
572
655
|
"""
|
|
573
|
-
|
|
656
|
+
Model for storing statistics captured for the entire execution of a workload.
|
|
574
657
|
"""
|
|
575
658
|
|
|
576
659
|
# string for identifying this workload execution
|
|
577
660
|
execution_id: str | None = None
|
|
578
661
|
|
|
579
662
|
# dictionary of SentinelPlanStats objects (one for each sentinel plan run during execution)
|
|
580
|
-
sentinel_plan_stats: dict[str, SentinelPlanStats] =
|
|
663
|
+
sentinel_plan_stats: dict[str, SentinelPlanStats] = Field(default_factory=dict)
|
|
581
664
|
|
|
582
665
|
# dictionary of PlanStats objects (one for each plan run during execution)
|
|
583
|
-
plan_stats: dict[str, PlanStats] =
|
|
666
|
+
plan_stats: dict[str, PlanStats] = Field(default_factory=dict)
|
|
584
667
|
|
|
585
668
|
# total time spent optimizing
|
|
586
669
|
optimization_time: float = 0.0
|
|
@@ -600,16 +683,25 @@ class ExecutionStats:
|
|
|
600
683
|
# total cost for the entire execution
|
|
601
684
|
total_execution_cost: float = 0.0
|
|
602
685
|
|
|
686
|
+
# total number of input tokens processed
|
|
687
|
+
total_input_tokens: int = 0
|
|
688
|
+
|
|
689
|
+
# total number of output tokens processed
|
|
690
|
+
total_output_tokens: int = 0
|
|
691
|
+
|
|
692
|
+
# total number of tokens processed
|
|
693
|
+
total_tokens: int = 0
|
|
694
|
+
|
|
603
695
|
# dictionary of sentinel plan strings; useful for printing executed sentinel plans in demos
|
|
604
|
-
sentinel_plan_strs: dict[str, str] =
|
|
696
|
+
sentinel_plan_strs: dict[str, str] = Field(default_factory=dict)
|
|
605
697
|
|
|
606
698
|
# dictionary of plan strings; useful for printing executed plans in demos
|
|
607
|
-
plan_strs: dict[str, str] =
|
|
699
|
+
plan_strs: dict[str, str] = Field(default_factory=dict)
|
|
608
700
|
|
|
609
701
|
# start time for the execution; should be set by calling ExecutionStats.start()
|
|
610
702
|
start_time: float | None = None
|
|
611
703
|
|
|
612
|
-
# end time for the optimization;
|
|
704
|
+
# end time for the optimization
|
|
613
705
|
optimization_end_time: float | None = None
|
|
614
706
|
|
|
615
707
|
def start(self) -> None:
|
|
@@ -647,6 +739,11 @@ class ExecutionStats:
|
|
|
647
739
|
self.plan_execution_cost = self.sum_plan_costs()
|
|
648
740
|
self.total_execution_cost = self.optimization_cost + self.plan_execution_cost
|
|
649
741
|
|
|
742
|
+
# compute the tokens for total execution
|
|
743
|
+
self.total_input_tokens = self.sum_input_tokens()
|
|
744
|
+
self.total_output_tokens = self.sum_output_tokens()
|
|
745
|
+
self.total_tokens = self.total_input_tokens + self.total_output_tokens
|
|
746
|
+
|
|
650
747
|
# compute plan_strs
|
|
651
748
|
self.plan_strs = {plan_id: plan_stats.plan_str for plan_id, plan_stats in self.plan_stats.items()}
|
|
652
749
|
|
|
@@ -654,7 +751,7 @@ class ExecutionStats:
|
|
|
654
751
|
"""
|
|
655
752
|
Sum the costs of all SentinelPlans in this execution.
|
|
656
753
|
"""
|
|
657
|
-
return sum([plan_stats.sum_op_costs() for _, plan_stats in self.sentinel_plan_stats.items()])
|
|
754
|
+
return sum([plan_stats.sum_op_costs() + plan_stats.sum_validation_costs() for _, plan_stats in self.sentinel_plan_stats.items()])
|
|
658
755
|
|
|
659
756
|
def sum_plan_costs(self) -> float:
|
|
660
757
|
"""
|
|
@@ -662,6 +759,22 @@ class ExecutionStats:
|
|
|
662
759
|
"""
|
|
663
760
|
return sum([plan_stats.sum_op_costs() for _, plan_stats in self.plan_stats.items()])
|
|
664
761
|
|
|
762
|
+
def sum_input_tokens(self) -> int:
|
|
763
|
+
"""
|
|
764
|
+
Sum the input tokens processed in this execution.
|
|
765
|
+
"""
|
|
766
|
+
sentinel_plan_input_tokens = sum([plan_stats.sum_input_tokens() for _, plan_stats in self.sentinel_plan_stats.items()])
|
|
767
|
+
plan_input_tokens = sum([plan_stats.sum_input_tokens() for _, plan_stats in self.plan_stats.items()])
|
|
768
|
+
return plan_input_tokens + sentinel_plan_input_tokens
|
|
769
|
+
|
|
770
|
+
def sum_output_tokens(self) -> int:
|
|
771
|
+
"""
|
|
772
|
+
Sum the output tokens processed in this execution.
|
|
773
|
+
"""
|
|
774
|
+
sentinel_plan_output_tokens = sum([plan_stats.sum_output_tokens() for _, plan_stats in self.sentinel_plan_stats.items()])
|
|
775
|
+
plan_output_tokens = sum([plan_stats.sum_output_tokens() for _, plan_stats in self.plan_stats.items()])
|
|
776
|
+
return plan_output_tokens + sentinel_plan_output_tokens
|
|
777
|
+
|
|
665
778
|
def add_plan_stats(self, plan_stats: PlanStats | SentinelPlanStats | list[PlanStats] | list[SentinelPlanStats]) -> None:
|
|
666
779
|
"""
|
|
667
780
|
Add the given PlanStats (or SentinelPlanStats) to this execution's plan stats.
|
|
@@ -686,43 +799,17 @@ class ExecutionStats:
|
|
|
686
799
|
else:
|
|
687
800
|
raise TypeError(f"Cannot add {type(plan_stats)} to ExecutionStats")
|
|
688
801
|
|
|
689
|
-
def
|
|
690
|
-
|
|
691
|
-
|
|
692
|
-
"""
|
|
693
|
-
for key, value in stats.items():
|
|
694
|
-
if isinstance(value, dict):
|
|
695
|
-
stats[key] = self.clean_json(value)
|
|
696
|
-
elif isinstance(value, np.int64):
|
|
697
|
-
stats[key] = int(value)
|
|
698
|
-
elif isinstance(value, np.float64):
|
|
699
|
-
stats[key] = float(value)
|
|
700
|
-
return stats
|
|
802
|
+
def to_json(self, filepath: str | None = None) -> dict | None:
|
|
803
|
+
if filepath is None:
|
|
804
|
+
return self.model_dump(mode="json")
|
|
701
805
|
|
|
702
|
-
|
|
703
|
-
|
|
704
|
-
"execution_id": self.execution_id,
|
|
705
|
-
"sentinel_plan_stats": {
|
|
706
|
-
plan_id: plan_stats.to_json() for plan_id, plan_stats in self.sentinel_plan_stats.items()
|
|
707
|
-
},
|
|
708
|
-
"plan_stats": {plan_id: plan_stats.to_json() for plan_id, plan_stats in self.plan_stats.items()},
|
|
709
|
-
"optimization_time": self.optimization_time,
|
|
710
|
-
"optimization_cost": self.optimization_cost,
|
|
711
|
-
"plan_execution_time": self.plan_execution_time,
|
|
712
|
-
"plan_execution_cost": self.plan_execution_cost,
|
|
713
|
-
"total_execution_time": self.total_execution_time,
|
|
714
|
-
"total_execution_cost": self.total_execution_cost,
|
|
715
|
-
"sentinel_plan_strs": self.sentinel_plan_strs,
|
|
716
|
-
"plan_strs": self.plan_strs,
|
|
717
|
-
}
|
|
718
|
-
stats = self.clean_json(stats)
|
|
719
|
-
return stats
|
|
806
|
+
with open(filepath, "w") as f:
|
|
807
|
+
json.dump(self.model_dump(mode="json"), f)
|
|
720
808
|
|
|
721
809
|
|
|
722
|
-
|
|
723
|
-
class OperatorCostEstimates:
|
|
810
|
+
class OperatorCostEstimates(BaseModel):
|
|
724
811
|
"""
|
|
725
|
-
|
|
812
|
+
Model for storing estimates of key metrics of interest for each operator.
|
|
726
813
|
"""
|
|
727
814
|
|
|
728
815
|
# (estimated) number of records output by this operator
|
|
@@ -765,10 +852,10 @@ class OperatorCostEstimates:
|
|
|
765
852
|
"""
|
|
766
853
|
Multiply all fields by a scalar.
|
|
767
854
|
"""
|
|
768
|
-
dct = {
|
|
855
|
+
dct = {field_name: getattr(self, field_name) * multiplier for field_name in self.model_fields}
|
|
769
856
|
return OperatorCostEstimates(**dct)
|
|
770
857
|
|
|
771
|
-
def
|
|
858
|
+
def model_post_init(self, __context: Any) -> None:
|
|
772
859
|
if self.cardinality_lower_bound is None and self.cardinality_upper_bound is None:
|
|
773
860
|
self.cardinality_lower_bound = self.cardinality
|
|
774
861
|
self.cardinality_upper_bound = self.cardinality
|
|
@@ -786,10 +873,9 @@ class OperatorCostEstimates:
|
|
|
786
873
|
self.quality_upper_bound = self.quality
|
|
787
874
|
|
|
788
875
|
|
|
789
|
-
|
|
790
|
-
class PlanCost:
|
|
876
|
+
class PlanCost(BaseModel):
|
|
791
877
|
"""
|
|
792
|
-
|
|
878
|
+
Model for storing the (cost, time, quality) estimates of (sub)-plans and their upper and lower bounds.
|
|
793
879
|
"""
|
|
794
880
|
|
|
795
881
|
# the expression cost
|
|
@@ -825,7 +911,16 @@ class PlanCost:
|
|
|
825
911
|
def __hash__(self):
|
|
826
912
|
return hash(f"{self.cost}-{self.time}-{self.quality}")
|
|
827
913
|
|
|
828
|
-
def
|
|
914
|
+
def __eq__(self, other: Any) -> bool:
|
|
915
|
+
if not isinstance(other, PlanCost):
|
|
916
|
+
return False
|
|
917
|
+
return (
|
|
918
|
+
self.cost == other.cost
|
|
919
|
+
and self.time == other.time
|
|
920
|
+
and self.quality == other.quality
|
|
921
|
+
)
|
|
922
|
+
|
|
923
|
+
def model_post_init(self, __context: Any) -> None:
|
|
829
924
|
if self.time_lower_bound is None and self.time_upper_bound is None:
|
|
830
925
|
self.time_lower_bound = self.time
|
|
831
926
|
self.time_upper_bound = self.time
|
|
@@ -838,30 +933,71 @@ class PlanCost:
|
|
|
838
933
|
self.quality_lower_bound = self.quality
|
|
839
934
|
self.quality_upper_bound = self.quality
|
|
840
935
|
|
|
936
|
+
def join_add(self, left_plan_cost: PlanCost, right_plan_cost: PlanCost, execution_strategy: str = "parallel") -> PlanCost:
|
|
937
|
+
"""
|
|
938
|
+
Add the PlanCost objects for two joined plans (left_plan_cost and right_plan_cost)
|
|
939
|
+
to the PlanCost object for the join operator. The execution strategy determines how
|
|
940
|
+
the input times are combined. If the execution strategy is "parallel", the input time
|
|
941
|
+
is the maximum of the two times. If the execution strategy is "sequential" (which is
|
|
942
|
+
currently anything else), the input time is the sum of the two times.
|
|
943
|
+
|
|
944
|
+
For quality, we compute the product of the operator quality with the average of the
|
|
945
|
+
two input qualities.
|
|
946
|
+
|
|
947
|
+
NOTE: we currently assume the updating of the op_estimates are handled by the caller
|
|
948
|
+
as there is not a universally correct meaning of addition of op_estimates.
|
|
949
|
+
"""
|
|
950
|
+
dct = {}
|
|
951
|
+
for model_field in ["cost", "cost_lower_bound", "cost_upper_bound"]:
|
|
952
|
+
op_field_value = getattr(self, model_field)
|
|
953
|
+
left_plan_field_value = getattr(left_plan_cost, model_field)
|
|
954
|
+
right_plan_field_value = getattr(right_plan_cost, model_field)
|
|
955
|
+
if op_field_value is not None and left_plan_field_value is not None and right_plan_field_value is not None:
|
|
956
|
+
dct[model_field] = op_field_value + left_plan_field_value + right_plan_field_value
|
|
957
|
+
|
|
958
|
+
for model_field in ["time", "time_lower_bound", "time_upper_bound"]:
|
|
959
|
+
op_field_value = getattr(self, model_field)
|
|
960
|
+
left_plan_field_value = getattr(left_plan_cost, model_field)
|
|
961
|
+
right_plan_field_value = getattr(right_plan_cost, model_field)
|
|
962
|
+
if op_field_value is not None and left_plan_field_value is not None and right_plan_field_value is not None:
|
|
963
|
+
if execution_strategy == "parallel":
|
|
964
|
+
dct[model_field] = op_field_value + max(left_plan_field_value, right_plan_field_value)
|
|
965
|
+
else:
|
|
966
|
+
dct[model_field] = op_field_value + left_plan_field_value + right_plan_field_value
|
|
967
|
+
|
|
968
|
+
for model_field in ["quality", "quality_lower_bound", "quality_upper_bound"]:
|
|
969
|
+
op_field_value = getattr(self, model_field)
|
|
970
|
+
left_plan_field_value = getattr(left_plan_cost, model_field)
|
|
971
|
+
right_plan_field_value = getattr(right_plan_cost, model_field)
|
|
972
|
+
if op_field_value is not None and left_plan_field_value is not None and right_plan_field_value is not None:
|
|
973
|
+
dct[model_field] = op_field_value * ((left_plan_field_value + right_plan_field_value) / 2.0)
|
|
974
|
+
|
|
975
|
+
return PlanCost(**dct)
|
|
976
|
+
|
|
841
977
|
def __iadd__(self, other: PlanCost) -> PlanCost:
|
|
842
978
|
"""
|
|
843
979
|
NOTE: we currently assume the updating of the op_estimates are handled by the caller
|
|
844
|
-
as there is not a universally correct meaning of addition of
|
|
980
|
+
as there is not a universally correct meaning of addition of op_estimates.
|
|
845
981
|
"""
|
|
846
982
|
self.cost += other.cost
|
|
847
983
|
self.time += other.time
|
|
848
984
|
self.quality *= other.quality
|
|
849
|
-
for
|
|
850
|
-
if getattr(self,
|
|
851
|
-
summation = getattr(self,
|
|
852
|
-
setattr(self,
|
|
985
|
+
for model_field in ["cost_lower_bound", "cost_upper_bound", "time_lower_bound", "time_upper_bound"]:
|
|
986
|
+
if getattr(self, model_field) is not None and getattr(other, model_field) is not None:
|
|
987
|
+
summation = getattr(self, model_field) + getattr(other, model_field)
|
|
988
|
+
setattr(self, model_field, summation)
|
|
853
989
|
|
|
854
|
-
for
|
|
855
|
-
if getattr(self,
|
|
856
|
-
product = getattr(self,
|
|
857
|
-
setattr(self,
|
|
990
|
+
for model_field in ["quality_lower_bound", "quality_upper_bound"]:
|
|
991
|
+
if getattr(self, model_field) is not None and getattr(other, model_field) is not None:
|
|
992
|
+
product = getattr(self, model_field) * getattr(other, model_field)
|
|
993
|
+
setattr(self, model_field, product)
|
|
858
994
|
|
|
859
995
|
return self
|
|
860
996
|
|
|
861
997
|
def __add__(self, other: PlanCost) -> PlanCost:
|
|
862
998
|
"""
|
|
863
999
|
NOTE: we currently assume the updating of the op_estimates are handled by the caller
|
|
864
|
-
as there is not a universally correct meaning of addition of
|
|
1000
|
+
as there is not a universally correct meaning of addition of op_estimates.
|
|
865
1001
|
"""
|
|
866
1002
|
dct = {
|
|
867
1003
|
field: getattr(self, field) + getattr(other, field)
|
|
@@ -874,7 +1010,7 @@ class PlanCost:
|
|
|
874
1010
|
"time_upper_bound",
|
|
875
1011
|
]
|
|
876
1012
|
}
|
|
877
|
-
for
|
|
878
|
-
dct[
|
|
1013
|
+
for model_field in ["quality", "quality_lower_bound", "quality_upper_bound"]:
|
|
1014
|
+
dct[model_field] = getattr(self, model_field) * getattr(other, model_field)
|
|
879
1015
|
|
|
880
1016
|
return PlanCost(**dct)
|