palimpzest 0.6.3__py3-none-any.whl → 0.7.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- palimpzest/__init__.py +5 -0
- palimpzest/constants.py +110 -43
- palimpzest/core/__init__.py +0 -78
- palimpzest/core/data/dataclasses.py +382 -44
- palimpzest/core/elements/filters.py +7 -3
- palimpzest/core/elements/index.py +70 -0
- palimpzest/core/elements/records.py +33 -11
- palimpzest/core/lib/fields.py +1 -0
- palimpzest/core/lib/schemas.py +4 -3
- palimpzest/prompts/moa_proposer_convert_prompts.py +0 -4
- palimpzest/prompts/prompt_factory.py +44 -7
- palimpzest/prompts/split_merge_prompts.py +56 -0
- palimpzest/prompts/split_proposer_prompts.py +55 -0
- palimpzest/query/execution/execution_strategy.py +435 -53
- palimpzest/query/execution/execution_strategy_type.py +20 -0
- palimpzest/query/execution/mab_execution_strategy.py +532 -0
- palimpzest/query/execution/parallel_execution_strategy.py +143 -172
- palimpzest/query/execution/random_sampling_execution_strategy.py +240 -0
- palimpzest/query/execution/single_threaded_execution_strategy.py +173 -203
- palimpzest/query/generators/api_client_factory.py +31 -0
- palimpzest/query/generators/generators.py +256 -76
- palimpzest/query/operators/__init__.py +1 -2
- palimpzest/query/operators/code_synthesis_convert.py +33 -18
- palimpzest/query/operators/convert.py +30 -97
- palimpzest/query/operators/critique_and_refine_convert.py +5 -6
- palimpzest/query/operators/filter.py +7 -10
- palimpzest/query/operators/logical.py +54 -10
- palimpzest/query/operators/map.py +130 -0
- palimpzest/query/operators/mixture_of_agents_convert.py +6 -6
- palimpzest/query/operators/physical.py +3 -12
- palimpzest/query/operators/rag_convert.py +66 -18
- palimpzest/query/operators/retrieve.py +230 -34
- palimpzest/query/operators/scan.py +5 -2
- palimpzest/query/operators/split_convert.py +169 -0
- palimpzest/query/operators/token_reduction_convert.py +8 -14
- palimpzest/query/optimizer/__init__.py +4 -16
- palimpzest/query/optimizer/cost_model.py +73 -266
- palimpzest/query/optimizer/optimizer.py +87 -58
- palimpzest/query/optimizer/optimizer_strategy.py +18 -97
- palimpzest/query/optimizer/optimizer_strategy_type.py +37 -0
- palimpzest/query/optimizer/plan.py +2 -3
- palimpzest/query/optimizer/primitives.py +5 -3
- palimpzest/query/optimizer/rules.py +336 -172
- palimpzest/query/optimizer/tasks.py +30 -100
- palimpzest/query/processor/config.py +38 -22
- palimpzest/query/processor/nosentinel_processor.py +16 -520
- palimpzest/query/processor/processing_strategy_type.py +28 -0
- palimpzest/query/processor/query_processor.py +38 -206
- palimpzest/query/processor/query_processor_factory.py +117 -130
- palimpzest/query/processor/sentinel_processor.py +90 -0
- palimpzest/query/processor/streaming_processor.py +25 -32
- palimpzest/sets.py +88 -41
- palimpzest/utils/model_helpers.py +8 -7
- palimpzest/utils/progress.py +368 -152
- palimpzest/utils/token_reduction_helpers.py +1 -3
- {palimpzest-0.6.3.dist-info → palimpzest-0.7.0.dist-info}/METADATA +28 -24
- palimpzest-0.7.0.dist-info/RECORD +96 -0
- {palimpzest-0.6.3.dist-info → palimpzest-0.7.0.dist-info}/WHEEL +1 -1
- palimpzest/query/processor/mab_sentinel_processor.py +0 -884
- palimpzest/query/processor/random_sampling_sentinel_processor.py +0 -639
- palimpzest/utils/index_helpers.py +0 -6
- palimpzest-0.6.3.dist-info/RECORD +0 -87
- {palimpzest-0.6.3.dist-info → palimpzest-0.7.0.dist-info/licenses}/LICENSE +0 -0
- {palimpzest-0.6.3.dist-info → palimpzest-0.7.0.dist-info}/top_level.txt +0 -0
|
@@ -1,5 +1,7 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
+
import time
|
|
4
|
+
from abc import abstractmethod
|
|
3
5
|
from dataclasses import dataclass, field, fields
|
|
4
6
|
from typing import Any
|
|
5
7
|
|
|
@@ -38,6 +40,12 @@ class GenerationStats:
|
|
|
38
40
|
# (if applicable) the time (in seconds) spent executing a call to a function
|
|
39
41
|
fn_call_duration_secs: float = 0.0
|
|
40
42
|
|
|
43
|
+
# (if applicable) the total number of LLM calls made by this operator
|
|
44
|
+
total_llm_calls: int = 0
|
|
45
|
+
|
|
46
|
+
# (if applicable) the total number of embedding LLM calls made by this operator
|
|
47
|
+
total_embedding_llm_calls: int = 0
|
|
48
|
+
|
|
41
49
|
def __iadd__(self, other: GenerationStats) -> GenerationStats:
|
|
42
50
|
# self.raw_answers.extend(other.raw_answers)
|
|
43
51
|
for dataclass_field in [
|
|
@@ -48,6 +56,8 @@ class GenerationStats:
|
|
|
48
56
|
"cost_per_record",
|
|
49
57
|
"llm_call_duration_secs",
|
|
50
58
|
"fn_call_duration_secs",
|
|
59
|
+
"total_llm_calls",
|
|
60
|
+
"total_embedding_llm_calls",
|
|
51
61
|
]:
|
|
52
62
|
setattr(self, dataclass_field, getattr(self, dataclass_field) + getattr(other, dataclass_field))
|
|
53
63
|
return self
|
|
@@ -63,6 +73,8 @@ class GenerationStats:
|
|
|
63
73
|
"llm_call_duration_secs",
|
|
64
74
|
"fn_call_duration_secs",
|
|
65
75
|
"cost_per_record",
|
|
76
|
+
"total_llm_calls",
|
|
77
|
+
"total_embedding_llm_calls",
|
|
66
78
|
]
|
|
67
79
|
}
|
|
68
80
|
# dct['raw_answers'] = self.raw_answers + other.raw_answers
|
|
@@ -83,6 +95,8 @@ class GenerationStats:
|
|
|
83
95
|
"cost_per_record",
|
|
84
96
|
"llm_call_duration_secs",
|
|
85
97
|
"fn_call_duration_secs",
|
|
98
|
+
"total_llm_calls",
|
|
99
|
+
"total_embedding_llm_calls",
|
|
86
100
|
]:
|
|
87
101
|
setattr(self, dataclass_field, getattr(self, dataclass_field) / quotient)
|
|
88
102
|
return self
|
|
@@ -101,6 +115,8 @@ class GenerationStats:
|
|
|
101
115
|
"total_output_cost",
|
|
102
116
|
"llm_call_duration_secs",
|
|
103
117
|
"fn_call_duration_secs",
|
|
118
|
+
"total_llm_calls",
|
|
119
|
+
"total_embedding_llm_calls",
|
|
104
120
|
"cost_per_record",
|
|
105
121
|
]
|
|
106
122
|
}
|
|
@@ -108,6 +124,7 @@ class GenerationStats:
|
|
|
108
124
|
return GenerationStats(**dct)
|
|
109
125
|
|
|
110
126
|
def __radd__(self, other: int) -> GenerationStats:
|
|
127
|
+
assert not isinstance(other, GenerationStats), "This should not be called with a GenerationStats object"
|
|
111
128
|
return self
|
|
112
129
|
|
|
113
130
|
|
|
@@ -198,6 +215,12 @@ class RecordOpStats:
|
|
|
198
215
|
# (if applicable) the time (in seconds) spent executing a UDF or calling an external api
|
|
199
216
|
fn_call_duration_secs: float = 0.0
|
|
200
217
|
|
|
218
|
+
# (if applicable) the total number of LLM calls made by this operator
|
|
219
|
+
total_llm_calls: int = 0
|
|
220
|
+
|
|
221
|
+
# (if applicable) the total number of embedding LLM calls made by this operator
|
|
222
|
+
total_embedding_llm_calls: int = 0
|
|
223
|
+
|
|
201
224
|
# (if applicable) a boolean indicating whether this is the statistics captured from a failed convert operation
|
|
202
225
|
failed_convert: bool | None = None
|
|
203
226
|
|
|
@@ -232,32 +255,41 @@ class OperatorStats:
|
|
|
232
255
|
# a list of RecordOpStats processed by the operation
|
|
233
256
|
record_op_stats_lst: list[RecordOpStats] = field(default_factory=list)
|
|
234
257
|
|
|
258
|
+
# the ID of the physical operator which precedes this one
|
|
259
|
+
source_op_id: str | None = None
|
|
260
|
+
|
|
261
|
+
# the ID of the physical plan which this operator is part of
|
|
262
|
+
plan_id: str = ""
|
|
263
|
+
|
|
235
264
|
# an OPTIONAL dictionary with more detailed information about this operation;
|
|
236
265
|
op_details: dict[str, Any] = field(default_factory=dict)
|
|
237
266
|
|
|
238
|
-
def
|
|
239
|
-
|
|
240
|
-
|
|
241
|
-
|
|
242
|
-
|
|
243
|
-
|
|
244
|
-
|
|
245
|
-
|
|
246
|
-
|
|
247
|
-
|
|
248
|
-
|
|
249
|
-
|
|
250
|
-
|
|
251
|
-
|
|
252
|
-
self.record_op_stats_lst.
|
|
253
|
-
|
|
254
|
-
|
|
255
|
-
|
|
256
|
-
|
|
257
|
-
|
|
258
|
-
|
|
259
|
-
|
|
260
|
-
|
|
267
|
+
def __iadd__(self, stats: OperatorStats | RecordOpStats) -> OperatorStats:
|
|
268
|
+
"""
|
|
269
|
+
Sum the given stats to this operator's stats. The given stats can be either:
|
|
270
|
+
|
|
271
|
+
1. an OperatorStats object
|
|
272
|
+
2. a RecordOpStats object
|
|
273
|
+
|
|
274
|
+
NOTE: in case (1.) we assume the execution layer guarantees that `stats` is
|
|
275
|
+
generated by the same operator in the same plan. Thus, we assume the
|
|
276
|
+
op_ids, op_name, source_op_id, etc. do not need to be updated.
|
|
277
|
+
"""
|
|
278
|
+
if isinstance(stats, OperatorStats):
|
|
279
|
+
self.total_op_time += stats.total_op_time
|
|
280
|
+
self.total_op_cost += stats.total_op_cost
|
|
281
|
+
self.record_op_stats_lst.extend(stats.record_op_stats_lst)
|
|
282
|
+
|
|
283
|
+
elif isinstance(stats, RecordOpStats):
|
|
284
|
+
stats.source_op_id = self.source_op_id
|
|
285
|
+
stats.plan_id = self.plan_id
|
|
286
|
+
self.record_op_stats_lst.append(stats)
|
|
287
|
+
self.total_op_time += stats.time_per_record
|
|
288
|
+
self.total_op_cost += stats.cost_per_record
|
|
289
|
+
|
|
290
|
+
else:
|
|
291
|
+
raise TypeError(f"Cannot add {type(stats)} to OperatorStats")
|
|
292
|
+
|
|
261
293
|
return self
|
|
262
294
|
|
|
263
295
|
def to_json(self):
|
|
@@ -270,11 +302,22 @@ class OperatorStats:
|
|
|
270
302
|
"op_details": self.op_details,
|
|
271
303
|
}
|
|
272
304
|
|
|
273
|
-
|
|
274
305
|
@dataclass
|
|
275
|
-
class
|
|
306
|
+
class BasePlanStats:
|
|
276
307
|
"""
|
|
277
308
|
Dataclass for storing statistics captured for an entire plan.
|
|
309
|
+
|
|
310
|
+
This class is subclassed for tracking:
|
|
311
|
+
- PlanStats: the statistics for execution of a PhysicalPlan
|
|
312
|
+
- SentinelPlanStats: the statistics for execution of a SentinelPlan
|
|
313
|
+
|
|
314
|
+
The key difference between the two subclasses is that the `operator_stats`
|
|
315
|
+
field in the PlanStats maps from the physical operator ids to their corresponding
|
|
316
|
+
OperatorStats objects.
|
|
317
|
+
|
|
318
|
+
The `operator_stats` field in the SentinelPlanStats maps from a logical operator id
|
|
319
|
+
to another dictionary which maps from the physical operator ids to their corresponding
|
|
320
|
+
OperatorStats objects.
|
|
278
321
|
"""
|
|
279
322
|
|
|
280
323
|
# id for identifying the physical plan
|
|
@@ -283,8 +326,10 @@ class PlanStats:
|
|
|
283
326
|
# string representation of the physical plan
|
|
284
327
|
plan_str: str | None = None
|
|
285
328
|
|
|
286
|
-
# dictionary
|
|
287
|
-
|
|
329
|
+
# dictionary whose values are OperatorStats objects;
|
|
330
|
+
# PlanStats maps {physical_op_id -> OperatorStats}
|
|
331
|
+
# SentinelPlanStats maps {logical_op_id -> {physical_op_id -> OperatorStats}}
|
|
332
|
+
operator_stats: dict = field(default_factory=dict)
|
|
288
333
|
|
|
289
334
|
# total runtime for the plan measured from the start to the end of PhysicalPlan.execute()
|
|
290
335
|
total_plan_time: float = 0.0
|
|
@@ -292,7 +337,108 @@ class PlanStats:
|
|
|
292
337
|
# total cost for plan
|
|
293
338
|
total_plan_cost: float = 0.0
|
|
294
339
|
|
|
295
|
-
|
|
340
|
+
# start time for the plan execution; should be set by calling PlanStats.start()
|
|
341
|
+
start_time: float | None = None
|
|
342
|
+
|
|
343
|
+
def start(self) -> None:
|
|
344
|
+
"""Start the timer for this plan execution."""
|
|
345
|
+
self.start_time = time.time()
|
|
346
|
+
|
|
347
|
+
def finish(self) -> None:
|
|
348
|
+
"""Finish the timer for this plan execution."""
|
|
349
|
+
if self.start_time is None:
|
|
350
|
+
raise RuntimeError("PlanStats.start() must be called before PlanStats.finish()")
|
|
351
|
+
self.total_plan_time = time.time() - self.start_time
|
|
352
|
+
self.total_plan_cost = self.sum_op_costs()
|
|
353
|
+
|
|
354
|
+
@staticmethod
|
|
355
|
+
@abstractmethod
|
|
356
|
+
def from_plan(plan) -> BasePlanStats:
|
|
357
|
+
"""
|
|
358
|
+
Initialize this PlanStats object from a PhysicalPlan or SentinelPlan object.
|
|
359
|
+
"""
|
|
360
|
+
pass
|
|
361
|
+
|
|
362
|
+
@abstractmethod
|
|
363
|
+
def sum_op_costs(self) -> float:
|
|
364
|
+
"""
|
|
365
|
+
Sum the costs of all operators in this plan.
|
|
366
|
+
"""
|
|
367
|
+
pass
|
|
368
|
+
|
|
369
|
+
@abstractmethod
|
|
370
|
+
def add_record_op_stats(self, record_op_stats: RecordOpStats | list[RecordOpStats]) -> None:
|
|
371
|
+
"""
|
|
372
|
+
Add the given RecordOpStats to this plan's operator stats.
|
|
373
|
+
"""
|
|
374
|
+
pass
|
|
375
|
+
|
|
376
|
+
@abstractmethod
|
|
377
|
+
def __iadd__(self, plan_stats: BasePlanStats) -> None:
|
|
378
|
+
"""
|
|
379
|
+
Add the given PlanStats to this plan's operator stats.
|
|
380
|
+
"""
|
|
381
|
+
pass
|
|
382
|
+
|
|
383
|
+
@abstractmethod
|
|
384
|
+
def __str__(self) -> str:
|
|
385
|
+
"""
|
|
386
|
+
Return a string representation of this plan's statistics.
|
|
387
|
+
"""
|
|
388
|
+
pass
|
|
389
|
+
|
|
390
|
+
@abstractmethod
|
|
391
|
+
def to_json(self) -> dict:
|
|
392
|
+
"""
|
|
393
|
+
Return a JSON representation of this plan's statistics.
|
|
394
|
+
"""
|
|
395
|
+
pass
|
|
396
|
+
|
|
397
|
+
@dataclass
|
|
398
|
+
class PlanStats(BasePlanStats):
|
|
399
|
+
"""
|
|
400
|
+
Subclass of BasePlanStats which captures statistics from the execution of a single PhysicalPlan.
|
|
401
|
+
"""
|
|
402
|
+
@staticmethod
|
|
403
|
+
def from_plan(plan) -> PlanStats:
|
|
404
|
+
"""
|
|
405
|
+
Initialize this PlanStats object from a PhysicalPlan object.
|
|
406
|
+
"""
|
|
407
|
+
operator_stats = {}
|
|
408
|
+
for op_idx, op in enumerate(plan.operators):
|
|
409
|
+
op_id = op.get_op_id()
|
|
410
|
+
operator_stats[op_id] = OperatorStats(
|
|
411
|
+
op_id=op_id,
|
|
412
|
+
op_name=op.op_name(),
|
|
413
|
+
source_op_id=None if op_idx == 0 else plan.operators[op_idx - 1].get_op_id(),
|
|
414
|
+
plan_id=plan.plan_id,
|
|
415
|
+
op_details={k: str(v) for k, v in op.get_id_params().items()},
|
|
416
|
+
)
|
|
417
|
+
|
|
418
|
+
return PlanStats(plan_id=plan.plan_id, plan_str=str(plan), operator_stats=operator_stats)
|
|
419
|
+
|
|
420
|
+
def sum_op_costs(self) -> float:
|
|
421
|
+
"""
|
|
422
|
+
Sum the costs of all operators in this plan.
|
|
423
|
+
"""
|
|
424
|
+
return sum([op_stats.total_op_cost for _, op_stats in self.operator_stats.items()])
|
|
425
|
+
|
|
426
|
+
def add_record_op_stats(self, record_op_stats: RecordOpStats | list[RecordOpStats]) -> None:
|
|
427
|
+
"""
|
|
428
|
+
Add the given RecordOpStats to this plan's operator stats.
|
|
429
|
+
"""
|
|
430
|
+
# normalize input type to be list[RecordOpStats]
|
|
431
|
+
record_op_stats_lst = record_op_stats if isinstance(record_op_stats, list) else [record_op_stats]
|
|
432
|
+
|
|
433
|
+
# update operator stats
|
|
434
|
+
for record_op_stats in record_op_stats_lst:
|
|
435
|
+
op_id = record_op_stats.op_id
|
|
436
|
+
if op_id in self.operator_stats:
|
|
437
|
+
self.operator_stats[op_id] += record_op_stats
|
|
438
|
+
else:
|
|
439
|
+
raise ValueError(f"RecordOpStats with physical_op_id {op_id} not found in PlanStats")
|
|
440
|
+
|
|
441
|
+
def __iadd__(self, plan_stats: PlanStats) -> None:
|
|
296
442
|
"""
|
|
297
443
|
NOTE: we assume the execution layer guarantees:
|
|
298
444
|
1. these plan_stats belong to the same plan
|
|
@@ -302,24 +448,20 @@ class PlanStats:
|
|
|
302
448
|
"""
|
|
303
449
|
self.total_plan_time += plan_stats.total_plan_time
|
|
304
450
|
self.total_plan_cost += plan_stats.total_plan_cost
|
|
305
|
-
for
|
|
306
|
-
if
|
|
307
|
-
self.operator_stats[
|
|
451
|
+
for op_id, op_stats in plan_stats.operator_stats.items():
|
|
452
|
+
if op_id in self.operator_stats:
|
|
453
|
+
self.operator_stats[op_id] += op_stats
|
|
308
454
|
else:
|
|
309
|
-
self.operator_stats[
|
|
455
|
+
self.operator_stats[op_id] = op_stats
|
|
310
456
|
|
|
311
|
-
def
|
|
312
|
-
self.total_plan_time
|
|
313
|
-
self.total_plan_cost
|
|
314
|
-
|
|
315
|
-
def __str__(self):
|
|
316
|
-
stats = f"Total_plan_time={self.total_plan_time} \n"
|
|
317
|
-
stats += f"Total_plan_cost={self.total_plan_cost} \n"
|
|
457
|
+
def __str__(self) -> str:
|
|
458
|
+
stats = f"total_plan_time={self.total_plan_time} \n"
|
|
459
|
+
stats += f"total_plan_cost={self.total_plan_cost} \n"
|
|
318
460
|
for idx, op_stats in enumerate(self.operator_stats.values()):
|
|
319
461
|
stats += f"{idx}. {op_stats.op_name} time={op_stats.total_op_time} cost={op_stats.total_op_cost} \n"
|
|
320
462
|
return stats
|
|
321
463
|
|
|
322
|
-
def to_json(self):
|
|
464
|
+
def to_json(self) -> dict:
|
|
323
465
|
return {
|
|
324
466
|
"plan_id": self.plan_id,
|
|
325
467
|
"plan_str": self.plan_str,
|
|
@@ -329,6 +471,100 @@ class PlanStats:
|
|
|
329
471
|
}
|
|
330
472
|
|
|
331
473
|
|
|
474
|
+
@dataclass
|
|
475
|
+
class SentinelPlanStats(BasePlanStats):
|
|
476
|
+
"""
|
|
477
|
+
Subclass of BasePlanStats which captures statistics from the execution of a single SentinelPlan.
|
|
478
|
+
"""
|
|
479
|
+
@staticmethod
|
|
480
|
+
def from_plan(plan) -> SentinelPlanStats:
|
|
481
|
+
"""
|
|
482
|
+
Initialize this PlanStats object from a Sentinel object.
|
|
483
|
+
"""
|
|
484
|
+
operator_stats = {}
|
|
485
|
+
for op_set_idx, (logical_op_id, op_set) in enumerate(plan):
|
|
486
|
+
operator_stats[logical_op_id] = {}
|
|
487
|
+
for physical_op in op_set:
|
|
488
|
+
op_id = physical_op.get_op_id()
|
|
489
|
+
operator_stats[logical_op_id][op_id] = OperatorStats(
|
|
490
|
+
op_id=op_id,
|
|
491
|
+
op_name=physical_op.op_name(),
|
|
492
|
+
source_op_id=None if op_set_idx == 0 else plan.logical_op_ids[op_set_idx - 1],
|
|
493
|
+
plan_id=plan.plan_id,
|
|
494
|
+
op_details={k: str(v) for k, v in physical_op.get_id_params().items()},
|
|
495
|
+
)
|
|
496
|
+
|
|
497
|
+
return SentinelPlanStats(plan_id=plan.plan_id, plan_str=str(plan), operator_stats=operator_stats)
|
|
498
|
+
|
|
499
|
+
def sum_op_costs(self) -> float:
|
|
500
|
+
"""
|
|
501
|
+
Sum the costs of all operators in this plan.
|
|
502
|
+
"""
|
|
503
|
+
return sum(sum([op_stats.total_op_cost for _, op_stats in phys_op_stats.items()]) for _, phys_op_stats in self.operator_stats.items())
|
|
504
|
+
|
|
505
|
+
def add_record_op_stats(self, record_op_stats: RecordOpStats | list[RecordOpStats]) -> None:
|
|
506
|
+
"""
|
|
507
|
+
Add the given RecordOpStats to this plan's operator stats.
|
|
508
|
+
"""
|
|
509
|
+
# normalize input type to be list[RecordOpStats]
|
|
510
|
+
record_op_stats_lst = record_op_stats if isinstance(record_op_stats, list) else [record_op_stats]
|
|
511
|
+
|
|
512
|
+
# update operator stats
|
|
513
|
+
for record_op_stats in record_op_stats_lst:
|
|
514
|
+
logical_op_id = record_op_stats.logical_op_id
|
|
515
|
+
physical_op_id = record_op_stats.op_id
|
|
516
|
+
if logical_op_id in self.operator_stats:
|
|
517
|
+
if physical_op_id in self.operator_stats[logical_op_id]:
|
|
518
|
+
self.operator_stats[logical_op_id][physical_op_id] += record_op_stats
|
|
519
|
+
else:
|
|
520
|
+
raise ValueError(f"RecordOpStats with physical_op_id {physical_op_id} not found in SentinelPlanStats")
|
|
521
|
+
else:
|
|
522
|
+
raise ValueError(f"RecordOpStats with logical_op_id {logical_op_id} not found in SentinelPlanStats")
|
|
523
|
+
|
|
524
|
+
def __iadd__(self, plan_stats: SentinelPlanStats) -> None:
|
|
525
|
+
"""
|
|
526
|
+
NOTE: we assume the execution layer guarantees:
|
|
527
|
+
1. these plan_stats belong to the same plan
|
|
528
|
+
2. these plan_stats come from sequential (non-overlapping) executions of the same plan
|
|
529
|
+
|
|
530
|
+
The latter criteria implies it is okay for this method to sum the plan (and operator) runtimes.
|
|
531
|
+
"""
|
|
532
|
+
self.total_plan_time += plan_stats.total_plan_time
|
|
533
|
+
self.total_plan_cost += plan_stats.total_plan_cost
|
|
534
|
+
for logical_op_id, physical_op_stats in plan_stats.operator_stats.items():
|
|
535
|
+
for physical_op_id, op_stats in physical_op_stats.items():
|
|
536
|
+
if logical_op_id in self.operator_stats:
|
|
537
|
+
if physical_op_id in self.operator_stats[logical_op_id]:
|
|
538
|
+
self.operator_stats[logical_op_id][physical_op_id] += op_stats
|
|
539
|
+
else:
|
|
540
|
+
self.operator_stats[logical_op_id][physical_op_id] = op_stats
|
|
541
|
+
else:
|
|
542
|
+
self.operator_stats[logical_op_id] = physical_op_stats
|
|
543
|
+
|
|
544
|
+
def __str__(self) -> str:
|
|
545
|
+
stats = f"total_plan_time={self.total_plan_time} \n"
|
|
546
|
+
stats += f"total_plan_cost={self.total_plan_cost} \n"
|
|
547
|
+
for outer_idx, physical_op_stats in enumerate(self.operator_stats.values()):
|
|
548
|
+
total_time = sum([op_stats.total_op_time for op_stats in physical_op_stats.values()])
|
|
549
|
+
total_cost = sum([op_stats.total_op_cost for op_stats in physical_op_stats.values()])
|
|
550
|
+
stats += f"{outer_idx}. total_time={total_time} total_cost={total_cost} \n"
|
|
551
|
+
for inner_idx, op_stats in enumerate(physical_op_stats.values()):
|
|
552
|
+
stats += f" {outer_idx}.{inner_idx}. {op_stats.op_name} time={op_stats.total_op_time} cost={op_stats.total_op_cost} \n"
|
|
553
|
+
return stats
|
|
554
|
+
|
|
555
|
+
def to_json(self) -> dict:
|
|
556
|
+
return {
|
|
557
|
+
"plan_id": self.plan_id,
|
|
558
|
+
"plan_str": self.plan_str,
|
|
559
|
+
"operator_stats": {
|
|
560
|
+
logical_op_id: {physical_op_id: op_stats.to_json() for physical_op_id, op_stats in physical_op_stats.items()}
|
|
561
|
+
for logical_op_id, physical_op_stats in self.operator_stats.items()
|
|
562
|
+
},
|
|
563
|
+
"total_plan_time": self.total_plan_time,
|
|
564
|
+
"total_plan_cost": self.total_plan_cost,
|
|
565
|
+
}
|
|
566
|
+
|
|
567
|
+
|
|
332
568
|
@dataclass
|
|
333
569
|
class ExecutionStats:
|
|
334
570
|
"""
|
|
@@ -338,28 +574,130 @@ class ExecutionStats:
|
|
|
338
574
|
# string for identifying this workload execution
|
|
339
575
|
execution_id: str | None = None
|
|
340
576
|
|
|
577
|
+
# dictionary of SentinelPlanStats objects (one for each sentinel plan run during execution)
|
|
578
|
+
sentinel_plan_stats: dict[str, SentinelPlanStats] = field(default_factory=dict)
|
|
579
|
+
|
|
341
580
|
# dictionary of PlanStats objects (one for each plan run during execution)
|
|
342
581
|
plan_stats: dict[str, PlanStats] = field(default_factory=dict)
|
|
343
582
|
|
|
344
583
|
# total time spent optimizing
|
|
345
|
-
|
|
584
|
+
optimization_time: float = 0.0
|
|
346
585
|
|
|
347
|
-
# total
|
|
586
|
+
# total cost of optimizing
|
|
587
|
+
optimization_cost: float = 0.0
|
|
588
|
+
|
|
589
|
+
# total time spent executing the optimized plan
|
|
590
|
+
plan_execution_time: float = 0.0
|
|
591
|
+
|
|
592
|
+
# total cost of executing the optimized plan
|
|
593
|
+
plan_execution_cost: float = 0.0
|
|
594
|
+
|
|
595
|
+
# total runtime for the entire execution
|
|
348
596
|
total_execution_time: float = 0.0
|
|
349
597
|
|
|
350
|
-
# total cost for
|
|
598
|
+
# total cost for the entire execution
|
|
351
599
|
total_execution_cost: float = 0.0
|
|
352
600
|
|
|
601
|
+
# dictionary of sentinel plan strings; useful for printing executed sentinel plans in demos
|
|
602
|
+
sentinel_plan_strs: dict[str, str] = field(default_factory=dict)
|
|
603
|
+
|
|
353
604
|
# dictionary of plan strings; useful for printing executed plans in demos
|
|
354
605
|
plan_strs: dict[str, str] = field(default_factory=dict)
|
|
355
606
|
|
|
607
|
+
# start time for the execution; should be set by calling ExecutionStats.start()
|
|
608
|
+
start_time: float | None = None
|
|
609
|
+
|
|
610
|
+
# end time for the optimization;
|
|
611
|
+
optimization_end_time: float | None = None
|
|
612
|
+
|
|
613
|
+
def start(self) -> None:
|
|
614
|
+
"""Start the timer for this execution."""
|
|
615
|
+
self.start_time = time.time()
|
|
616
|
+
|
|
617
|
+
def finish_optimization(self) -> None:
|
|
618
|
+
"""Finish the timer for the optimization phase of this execution."""
|
|
619
|
+
if self.start_time is None:
|
|
620
|
+
raise RuntimeError("ExecutionStats.start() must be called before ExecutionStats.finish_optimization()")
|
|
621
|
+
|
|
622
|
+
# compute optimization time and cost
|
|
623
|
+
self.optimization_end_time = time.time()
|
|
624
|
+
self.optimization_time = self.optimization_end_time - self.start_time
|
|
625
|
+
self.optimization_cost = self.sum_sentinel_plan_costs()
|
|
626
|
+
|
|
627
|
+
# compute sentinel_plan_strs
|
|
628
|
+
self.sentinel_plan_strs = {plan_id: plan_stats.plan_str for plan_id, plan_stats in self.sentinel_plan_stats.items()}
|
|
629
|
+
|
|
630
|
+
def finish(self) -> None:
|
|
631
|
+
"""Finish the timer for this execution."""
|
|
632
|
+
if self.start_time is None:
|
|
633
|
+
raise RuntimeError("ExecutionStats.start() must be called before ExecutionStats.finish()")
|
|
634
|
+
|
|
635
|
+
# compute time for plan and total execution
|
|
636
|
+
end_time = time.time()
|
|
637
|
+
self.plan_execution_time = (
|
|
638
|
+
end_time - self.optimization_end_time
|
|
639
|
+
if self.optimization_end_time is not None
|
|
640
|
+
else end_time - self.start_time
|
|
641
|
+
)
|
|
642
|
+
self.total_execution_time = end_time - self.start_time
|
|
643
|
+
|
|
644
|
+
# compute the cost for plan and total execution
|
|
645
|
+
self.plan_execution_cost = self.sum_plan_costs()
|
|
646
|
+
self.total_execution_cost = self.optimization_cost + self.plan_execution_cost
|
|
647
|
+
|
|
648
|
+
# compute plan_strs
|
|
649
|
+
self.plan_strs = {plan_id: plan_stats.plan_str for plan_id, plan_stats in self.plan_stats.items()}
|
|
650
|
+
|
|
651
|
+
def sum_sentinel_plan_costs(self) -> float:
|
|
652
|
+
"""
|
|
653
|
+
Sum the costs of all SentinelPlans in this execution.
|
|
654
|
+
"""
|
|
655
|
+
return sum([plan_stats.sum_op_costs() for _, plan_stats in self.sentinel_plan_stats.items()])
|
|
656
|
+
|
|
657
|
+
def sum_plan_costs(self) -> float:
|
|
658
|
+
"""
|
|
659
|
+
Sum the costs of all PhysicalPlans in this execution.
|
|
660
|
+
"""
|
|
661
|
+
return sum([plan_stats.sum_op_costs() for _, plan_stats in self.plan_stats.items()])
|
|
662
|
+
|
|
663
|
+
def add_plan_stats(self, plan_stats: PlanStats | SentinelPlanStats | list[PlanStats] | list[SentinelPlanStats]) -> None:
|
|
664
|
+
"""
|
|
665
|
+
Add the given PlanStats (or SentinelPlanStats) to this execution's plan stats.
|
|
666
|
+
|
|
667
|
+
NOTE: we make the assumption that the same plan cannot be run more than once in parallel,
|
|
668
|
+
i.e. each plan stats object for an individual plan comes from two different (sequential)
|
|
669
|
+
periods in time. Thus, PlanStats objects can be summed.
|
|
670
|
+
"""
|
|
671
|
+
# normalize input type to be list[PlanStats] or list[SentinelPlanStats]
|
|
672
|
+
if isinstance(plan_stats, (PlanStats, SentinelPlanStats)):
|
|
673
|
+
plan_stats = [plan_stats]
|
|
674
|
+
|
|
675
|
+
for plan_stats_obj in plan_stats:
|
|
676
|
+
if isinstance(plan_stats_obj, PlanStats) and plan_stats_obj.plan_id not in self.plan_stats:
|
|
677
|
+
self.plan_stats[plan_stats_obj.plan_id] = plan_stats_obj
|
|
678
|
+
elif isinstance(plan_stats_obj, PlanStats):
|
|
679
|
+
self.plan_stats[plan_stats_obj.plan_id] += plan_stats_obj
|
|
680
|
+
elif isinstance(plan_stats_obj, SentinelPlanStats) and plan_stats_obj.plan_id not in self.sentinel_plan_stats:
|
|
681
|
+
self.sentinel_plan_stats[plan_stats_obj.plan_id] = plan_stats_obj
|
|
682
|
+
elif isinstance(plan_stats_obj, SentinelPlanStats):
|
|
683
|
+
self.sentinel_plan_stats[plan_stats_obj.plan_id] += plan_stats_obj
|
|
684
|
+
else:
|
|
685
|
+
raise TypeError(f"Cannot add {type(plan_stats)} to ExecutionStats")
|
|
686
|
+
|
|
356
687
|
def to_json(self):
|
|
357
688
|
return {
|
|
358
689
|
"execution_id": self.execution_id,
|
|
690
|
+
"sentinel_plan_stats": {
|
|
691
|
+
plan_id: plan_stats.to_json() for plan_id, plan_stats in self.sentinel_plan_stats.items()
|
|
692
|
+
},
|
|
359
693
|
"plan_stats": {plan_id: plan_stats.to_json() for plan_id, plan_stats in self.plan_stats.items()},
|
|
360
|
-
"
|
|
694
|
+
"optimization_time": self.optimization_time,
|
|
695
|
+
"optimization_cost": self.optimization_cost,
|
|
696
|
+
"plan_execution_time": self.plan_execution_time,
|
|
697
|
+
"plan_execution_cost": self.plan_execution_cost,
|
|
361
698
|
"total_execution_time": self.total_execution_time,
|
|
362
699
|
"total_execution_cost": self.total_execution_cost,
|
|
700
|
+
"sentinel_plan_strs": self.sentinel_plan_strs,
|
|
363
701
|
"plan_strs": self.plan_strs,
|
|
364
702
|
}
|
|
365
703
|
|
|
@@ -16,17 +16,20 @@ class Filter:
|
|
|
16
16
|
self.filter_fn = filter_fn
|
|
17
17
|
|
|
18
18
|
def serialize(self) -> dict[str, Any]:
|
|
19
|
-
return {
|
|
19
|
+
return {
|
|
20
|
+
"filter_condition": self.filter_condition,
|
|
21
|
+
"filter_fn": self.filter_fn.__name__ if self.filter_fn is not None else None,
|
|
22
|
+
}
|
|
20
23
|
|
|
21
24
|
def get_filter_str(self) -> str:
|
|
22
|
-
return self.filter_condition if self.filter_condition is not None else
|
|
25
|
+
return self.filter_condition if self.filter_condition is not None else self.filter_fn.__name__
|
|
23
26
|
|
|
24
27
|
def __repr__(self) -> str:
|
|
25
28
|
return "Filter(" + self.get_filter_str() + ")"
|
|
26
29
|
|
|
27
30
|
def __hash__(self) -> int:
|
|
28
31
|
# custom hash function
|
|
29
|
-
return hash(self.filter_condition) if self.filter_condition is not None else hash(
|
|
32
|
+
return hash(self.filter_condition) if self.filter_condition is not None else hash(self.filter_fn.__name__)
|
|
30
33
|
|
|
31
34
|
def __eq__(self, other) -> bool:
|
|
32
35
|
# __eq__ should be defined for consistency with __hash__
|
|
@@ -35,5 +38,6 @@ class Filter:
|
|
|
35
38
|
and self.filter_condition == other.filter_condition
|
|
36
39
|
and self.filter_fn == other.filter_fn
|
|
37
40
|
)
|
|
41
|
+
|
|
38
42
|
def __str__(self) -> str:
|
|
39
43
|
return self.get_filter_str()
|
|
@@ -0,0 +1,70 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from abc import ABC, abstractmethod
|
|
4
|
+
|
|
5
|
+
from chromadb.api.models.Collection import Collection
|
|
6
|
+
from ragatouille.RAGPretrainedModel import RAGPretrainedModel
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def index_factory(index: Collection | RAGPretrainedModel) -> PZIndex:
|
|
10
|
+
"""
|
|
11
|
+
Factory function to create a PZ index based on the type of the provided index.
|
|
12
|
+
|
|
13
|
+
Args:
|
|
14
|
+
index (Collection | RAGPretrainedModel): The index provided by the user.
|
|
15
|
+
|
|
16
|
+
Returns:
|
|
17
|
+
PZIndex: The PZ wrapped Index.
|
|
18
|
+
"""
|
|
19
|
+
if isinstance(index, Collection):
|
|
20
|
+
return ChromaIndex(index)
|
|
21
|
+
elif isinstance(index, RAGPretrainedModel):
|
|
22
|
+
return RagatouilleIndex(index)
|
|
23
|
+
else:
|
|
24
|
+
raise TypeError(f"Unsupported index type: {type(index)}\nindex must be a `chromadb.api.models.Collection.Collection` or `ragatouille.RAGPretrainedModel.RAGPretrainedModel`")
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class BaseIndex(ABC):
|
|
28
|
+
|
|
29
|
+
def __init__(self, index: Collection | RAGPretrainedModel):
|
|
30
|
+
self.index = index
|
|
31
|
+
|
|
32
|
+
def __str__(self):
|
|
33
|
+
"""
|
|
34
|
+
Return a string representation of the index.
|
|
35
|
+
"""
|
|
36
|
+
return f"{self.__class__.__name__}"
|
|
37
|
+
|
|
38
|
+
@abstractmethod
|
|
39
|
+
def search(self, query_embedding: list[float] | list[list[float]], results_per_query: int = 1) -> list | list[list]:
|
|
40
|
+
"""
|
|
41
|
+
Query the index with a string or a list of strings.
|
|
42
|
+
|
|
43
|
+
Args:
|
|
44
|
+
query (str | list[str]): The query string or list of strings to search for.
|
|
45
|
+
results_per_query (int): The number of top results to retrieve for each query.
|
|
46
|
+
|
|
47
|
+
Returns:
|
|
48
|
+
list | list[list]: The top results for the query. If query is a list, the top
|
|
49
|
+
results for each query in the list are returned. Each list will contain the
|
|
50
|
+
raw elements yielded by the index. This way, users can program against the
|
|
51
|
+
results they expect to get from e.g. chromadb or ragatouille.
|
|
52
|
+
"""
|
|
53
|
+
pass
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
class ChromaIndex(BaseIndex):
|
|
57
|
+
def __init__(self, index: Collection):
|
|
58
|
+
assert isinstance(index, Collection), "ChromaIndex input must be a `chromadb.api.models.Collection.Collection`"
|
|
59
|
+
super().__init__(index)
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
class RagatouilleIndex(BaseIndex):
|
|
64
|
+
def __init__(self, index: RAGPretrainedModel):
|
|
65
|
+
assert isinstance(index, RAGPretrainedModel), "RagatouilleIndex input must be a `ragatouille.RAGPretrainedModel.RAGPretrainedModel`"
|
|
66
|
+
super().__init__(index)
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
# define type for PZIndex
|
|
70
|
+
PZIndex = ChromaIndex | RagatouilleIndex
|