palimpzest 0.6.4__py3-none-any.whl → 0.7.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (64)
  1. palimpzest/__init__.py +5 -0
  2. palimpzest/constants.py +110 -43
  3. palimpzest/core/__init__.py +0 -78
  4. palimpzest/core/data/dataclasses.py +382 -44
  5. palimpzest/core/elements/filters.py +7 -3
  6. palimpzest/core/elements/index.py +70 -0
  7. palimpzest/core/elements/records.py +33 -11
  8. palimpzest/core/lib/fields.py +1 -0
  9. palimpzest/core/lib/schemas.py +4 -3
  10. palimpzest/prompts/moa_proposer_convert_prompts.py +0 -4
  11. palimpzest/prompts/prompt_factory.py +44 -7
  12. palimpzest/prompts/split_merge_prompts.py +56 -0
  13. palimpzest/prompts/split_proposer_prompts.py +55 -0
  14. palimpzest/query/execution/execution_strategy.py +435 -53
  15. palimpzest/query/execution/execution_strategy_type.py +20 -0
  16. palimpzest/query/execution/mab_execution_strategy.py +532 -0
  17. palimpzest/query/execution/parallel_execution_strategy.py +143 -172
  18. palimpzest/query/execution/random_sampling_execution_strategy.py +240 -0
  19. palimpzest/query/execution/single_threaded_execution_strategy.py +173 -203
  20. palimpzest/query/generators/api_client_factory.py +31 -0
  21. palimpzest/query/generators/generators.py +256 -76
  22. palimpzest/query/operators/__init__.py +1 -2
  23. palimpzest/query/operators/code_synthesis_convert.py +33 -18
  24. palimpzest/query/operators/convert.py +30 -97
  25. palimpzest/query/operators/critique_and_refine_convert.py +5 -6
  26. palimpzest/query/operators/filter.py +7 -10
  27. palimpzest/query/operators/logical.py +54 -10
  28. palimpzest/query/operators/map.py +130 -0
  29. palimpzest/query/operators/mixture_of_agents_convert.py +6 -6
  30. palimpzest/query/operators/physical.py +3 -12
  31. palimpzest/query/operators/rag_convert.py +66 -18
  32. palimpzest/query/operators/retrieve.py +230 -34
  33. palimpzest/query/operators/scan.py +5 -2
  34. palimpzest/query/operators/split_convert.py +169 -0
  35. palimpzest/query/operators/token_reduction_convert.py +8 -14
  36. palimpzest/query/optimizer/__init__.py +4 -16
  37. palimpzest/query/optimizer/cost_model.py +73 -266
  38. palimpzest/query/optimizer/optimizer.py +87 -58
  39. palimpzest/query/optimizer/optimizer_strategy.py +18 -97
  40. palimpzest/query/optimizer/optimizer_strategy_type.py +37 -0
  41. palimpzest/query/optimizer/plan.py +2 -3
  42. palimpzest/query/optimizer/primitives.py +5 -3
  43. palimpzest/query/optimizer/rules.py +336 -172
  44. palimpzest/query/optimizer/tasks.py +30 -100
  45. palimpzest/query/processor/config.py +38 -22
  46. palimpzest/query/processor/nosentinel_processor.py +16 -520
  47. palimpzest/query/processor/processing_strategy_type.py +28 -0
  48. palimpzest/query/processor/query_processor.py +38 -206
  49. palimpzest/query/processor/query_processor_factory.py +117 -130
  50. palimpzest/query/processor/sentinel_processor.py +90 -0
  51. palimpzest/query/processor/streaming_processor.py +25 -32
  52. palimpzest/sets.py +88 -41
  53. palimpzest/utils/model_helpers.py +8 -7
  54. palimpzest/utils/progress.py +368 -152
  55. palimpzest/utils/token_reduction_helpers.py +1 -3
  56. {palimpzest-0.6.4.dist-info → palimpzest-0.7.0.dist-info}/METADATA +19 -9
  57. palimpzest-0.7.0.dist-info/RECORD +96 -0
  58. {palimpzest-0.6.4.dist-info → palimpzest-0.7.0.dist-info}/WHEEL +1 -1
  59. palimpzest/query/processor/mab_sentinel_processor.py +0 -884
  60. palimpzest/query/processor/random_sampling_sentinel_processor.py +0 -639
  61. palimpzest/utils/index_helpers.py +0 -6
  62. palimpzest-0.6.4.dist-info/RECORD +0 -87
  63. {palimpzest-0.6.4.dist-info → palimpzest-0.7.0.dist-info/licenses}/LICENSE +0 -0
  64. {palimpzest-0.6.4.dist-info → palimpzest-0.7.0.dist-info}/top_level.txt +0 -0
palimpzest/core/data/dataclasses.py

@@ -1,5 +1,7 @@
  from __future__ import annotations
 
+ import time
+ from abc import abstractmethod
  from dataclasses import dataclass, field, fields
  from typing import Any
 
@@ -38,6 +40,12 @@ class GenerationStats:
  # (if applicable) the time (in seconds) spent executing a call to a function
  fn_call_duration_secs: float = 0.0
 
+ # (if applicable) the total number of LLM calls made by this operator
+ total_llm_calls: int = 0
+
+ # (if applicable) the total number of embedding LLM calls made by this operator
+ total_embedding_llm_calls: int = 0
+
  def __iadd__(self, other: GenerationStats) -> GenerationStats:
  # self.raw_answers.extend(other.raw_answers)
  for dataclass_field in [
@@ -48,6 +56,8 @@ class GenerationStats:
  "cost_per_record",
  "llm_call_duration_secs",
  "fn_call_duration_secs",
+ "total_llm_calls",
+ "total_embedding_llm_calls",
  ]:
  setattr(self, dataclass_field, getattr(self, dataclass_field) + getattr(other, dataclass_field))
  return self
@@ -63,6 +73,8 @@ class GenerationStats:
  "llm_call_duration_secs",
  "fn_call_duration_secs",
  "cost_per_record",
+ "total_llm_calls",
+ "total_embedding_llm_calls",
  ]
  }
  # dct['raw_answers'] = self.raw_answers + other.raw_answers
@@ -83,6 +95,8 @@ class GenerationStats:
  "cost_per_record",
  "llm_call_duration_secs",
  "fn_call_duration_secs",
+ "total_llm_calls",
+ "total_embedding_llm_calls",
  ]:
  setattr(self, dataclass_field, getattr(self, dataclass_field) / quotient)
  return self
@@ -101,6 +115,8 @@ class GenerationStats:
  "total_output_cost",
  "llm_call_duration_secs",
  "fn_call_duration_secs",
+ "total_llm_calls",
+ "total_embedding_llm_calls",
  "cost_per_record",
  ]
  }
@@ -108,6 +124,7 @@ class GenerationStats:
  return GenerationStats(**dct)
 
  def __radd__(self, other: int) -> GenerationStats:
+ assert not isinstance(other, GenerationStats), "This should not be called with a GenerationStats object"
  return self
 
 
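The two new counters ride along with the existing accumulation logic: they are added to the same field lists that __iadd__ and the other aggregation helpers above iterate over, so call counts are summed and averaged exactly like costs and latencies. A minimal sketch of the behavior (the import path follows the file list above; all values are made up):

    from palimpzest.core.data.dataclasses import GenerationStats

    stats = GenerationStats(cost_per_record=0.002, total_llm_calls=1)
    more = GenerationStats(cost_per_record=0.003, total_llm_calls=2, total_embedding_llm_calls=1)

    # __iadd__ walks its field list, so the new counters are summed like every other field
    stats += more
    assert stats.total_llm_calls == 3 and stats.total_embedding_llm_calls == 1

The new assert in __radd__ only guards the `0 + stats` path (e.g. from built-in sum()); stats-to-stats addition never reaches it.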
@@ -198,6 +215,12 @@ class RecordOpStats:
  # (if applicable) the time (in seconds) spent executing a UDF or calling an external api
  fn_call_duration_secs: float = 0.0
 
+ # (if applicable) the total number of LLM calls made by this operator
+ total_llm_calls: int = 0
+
+ # (if applicable) the total number of embedding LLM calls made by this operator
+ total_embedding_llm_calls: int = 0
+
  # (if applicable) a boolean indicating whether this is the statistics captured from a failed convert operation
  failed_convert: bool | None = None
 
@@ -232,32 +255,41 @@ class OperatorStats:
  # a list of RecordOpStats processed by the operation
  record_op_stats_lst: list[RecordOpStats] = field(default_factory=list)
 
+ # the ID of the physical operator which precedes this one
+ source_op_id: str | None = None
+
+ # the ID of the physical plan which this operator is part of
+ plan_id: str = ""
+
  # an OPTIONAL dictionary with more detailed information about this operation;
  op_details: dict[str, Any] = field(default_factory=dict)
 
- def add_record_op_stats(
- self,
- record_op_stats_lst: RecordOpStats | list[RecordOpStats],
- source_op_id: str | None,
- plan_id: str,
- ):
- # convert individual record into list
- if not isinstance(record_op_stats_lst, list):
- record_op_stats_lst = [record_op_stats_lst]
-
- # update op stats
- for record_op_stats in record_op_stats_lst:
- record_op_stats.source_op_id = source_op_id
- record_op_stats.plan_id = plan_id
- self.record_op_stats_lst.append(record_op_stats)
- self.total_op_time += record_op_stats.time_per_record
- self.total_op_cost += record_op_stats.cost_per_record
-
- def __iadd__(self, op_stats: OperatorStats):
- """NOTE: we assume the execution layer guarantees these op_stats belong to the same operator."""
- self.total_op_time += op_stats.total_op_time
- self.total_op_cost += op_stats.total_op_cost
- self.record_op_stats_lst.extend(op_stats.record_op_stats_lst)
+ def __iadd__(self, stats: OperatorStats | RecordOpStats) -> OperatorStats:
+ """
+ Sum the given stats to this operator's stats. The given stats can be either:
+
+ 1. an OperatorStats object
+ 2. a RecordOpStats object
+
+ NOTE: in case (1.) we assume the execution layer guarantees that `stats` is
+ generated by the same operator in the same plan. Thus, we assume the
+ op_ids, op_name, source_op_id, etc. do not need to be updated.
+ """
+ if isinstance(stats, OperatorStats):
+ self.total_op_time += stats.total_op_time
+ self.total_op_cost += stats.total_op_cost
+ self.record_op_stats_lst.extend(stats.record_op_stats_lst)
+
+ elif isinstance(stats, RecordOpStats):
+ stats.source_op_id = self.source_op_id
+ stats.plan_id = self.plan_id
+ self.record_op_stats_lst.append(stats)
+ self.total_op_time += stats.time_per_record
+ self.total_op_cost += stats.cost_per_record
+
+ else:
+ raise TypeError(f"Cannot add {type(stats)} to OperatorStats")
+
  return self
 
  def to_json(self):
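With source_op_id and plan_id now stored on the operator itself, the old add_record_op_stats(...) helper collapses into a single polymorphic __iadd__: adding another OperatorStats merges totals and record lists, while adding a RecordOpStats stamps it with this operator's source_op_id/plan_id and appends it. A minimal sketch, assuming the keyword construction used by from_plan further below and made-up operator ids:

    from palimpzest.core.data.dataclasses import OperatorStats

    a = OperatorStats(op_id="convert-1", op_name="LLMConvert", plan_id="plan-0")
    b = OperatorStats(op_id="convert-1", op_name="LLMConvert", plan_id="plan-0")
    a.total_op_time, a.total_op_cost = 1.5, 0.50
    b.total_op_time, b.total_op_cost = 0.5, 0.25

    a += b          # OperatorStats branch: sums time/cost, extends record_op_stats_lst
    assert a.total_op_cost == 0.75

    try:
        a += 42     # anything that is not OperatorStats / RecordOpStats is rejected
    except TypeError:
        pass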
@@ -270,11 +302,22 @@ class OperatorStats:
  "op_details": self.op_details,
  }
 
-
  @dataclass
- class PlanStats:
+ class BasePlanStats:
  """
  Dataclass for storing statistics captured for an entire plan.
+
+ This class is subclassed for tracking:
+ - PlanStats: the statistics for execution of a PhysicalPlan
+ - SentinelPlanStats: the statistics for execution of a SentinelPlan
+
+ The key difference between the two subclasses is that the `operator_stats`
+ field in the PlanStats maps from the physical operator ids to their corresponding
+ OperatorStats objects.
+
+ The `operator_stats` field in the SentinelPlanStats maps from a logical operator id
+ to another dictionary which maps from the physical operator ids to their corresponding
+ OperatorStats objects.
  """
 
  # id for identifying the physical plan
@@ -283,8 +326,10 @@ class PlanStats:
  # string representation of the physical plan
  plan_str: str | None = None
 
- # dictionary of OperatorStats objects (one for each operator)
- operator_stats: dict[str, OperatorStats] = field(default_factory=dict)
+ # dictionary whose values are OperatorStats objects;
+ # PlanStats maps {physical_op_id -> OperatorStats}
+ # SentinelPlanStats maps {logical_op_id -> {physical_op_id -> OperatorStats}}
+ operator_stats: dict = field(default_factory=dict)
 
  # total runtime for the plan measured from the start to the end of PhysicalPlan.execute()
  total_plan_time: float = 0.0
@@ -292,7 +337,108 @@ class PlanStats:
  # total cost for plan
  total_plan_cost: float = 0.0
 
- def __iadd__(self, plan_stats: PlanStats):
+ # start time for the plan execution; should be set by calling PlanStats.start()
+ start_time: float | None = None
+
+ def start(self) -> None:
+ """Start the timer for this plan execution."""
+ self.start_time = time.time()
+
+ def finish(self) -> None:
+ """Finish the timer for this plan execution."""
+ if self.start_time is None:
+ raise RuntimeError("PlanStats.start() must be called before PlanStats.finish()")
+ self.total_plan_time = time.time() - self.start_time
+ self.total_plan_cost = self.sum_op_costs()
+
+ @staticmethod
+ @abstractmethod
+ def from_plan(plan) -> BasePlanStats:
+ """
+ Initialize this PlanStats object from a PhysicalPlan or SentinelPlan object.
+ """
+ pass
+
+ @abstractmethod
+ def sum_op_costs(self) -> float:
+ """
+ Sum the costs of all operators in this plan.
+ """
+ pass
+
+ @abstractmethod
+ def add_record_op_stats(self, record_op_stats: RecordOpStats | list[RecordOpStats]) -> None:
+ """
+ Add the given RecordOpStats to this plan's operator stats.
+ """
+ pass
+
+ @abstractmethod
+ def __iadd__(self, plan_stats: BasePlanStats) -> None:
+ """
+ Add the given PlanStats to this plan's operator stats.
+ """
+ pass
+
+ @abstractmethod
+ def __str__(self) -> str:
+ """
+ Return a string representation of this plan's statistics.
+ """
+ pass
+
+ @abstractmethod
+ def to_json(self) -> dict:
+ """
+ Return a JSON representation of this plan's statistics.
+ """
+ pass
+
+ @dataclass
+ class PlanStats(BasePlanStats):
+ """
+ Subclass of BasePlanStats which captures statistics from the execution of a single PhysicalPlan.
+ """
+ @staticmethod
+ def from_plan(plan) -> PlanStats:
+ """
+ Initialize this PlanStats object from a PhysicalPlan object.
+ """
+ operator_stats = {}
+ for op_idx, op in enumerate(plan.operators):
+ op_id = op.get_op_id()
+ operator_stats[op_id] = OperatorStats(
+ op_id=op_id,
+ op_name=op.op_name(),
+ source_op_id=None if op_idx == 0 else plan.operators[op_idx - 1].get_op_id(),
+ plan_id=plan.plan_id,
+ op_details={k: str(v) for k, v in op.get_id_params().items()},
+ )
+
+ return PlanStats(plan_id=plan.plan_id, plan_str=str(plan), operator_stats=operator_stats)
+
+ def sum_op_costs(self) -> float:
+ """
+ Sum the costs of all operators in this plan.
+ """
+ return sum([op_stats.total_op_cost for _, op_stats in self.operator_stats.items()])
+
+ def add_record_op_stats(self, record_op_stats: RecordOpStats | list[RecordOpStats]) -> None:
+ """
+ Add the given RecordOpStats to this plan's operator stats.
+ """
+ # normalize input type to be list[RecordOpStats]
+ record_op_stats_lst = record_op_stats if isinstance(record_op_stats, list) else [record_op_stats]
+
+ # update operator stats
+ for record_op_stats in record_op_stats_lst:
+ op_id = record_op_stats.op_id
+ if op_id in self.operator_stats:
+ self.operator_stats[op_id] += record_op_stats
+ else:
+ raise ValueError(f"RecordOpStats with physical_op_id {op_id} not found in PlanStats")
+
+ def __iadd__(self, plan_stats: PlanStats) -> None:
  """
  NOTE: we assume the execution layer guarantees:
  1. these plan_stats belong to the same plan
@@ -302,24 +448,20 @@ class PlanStats:
  """
  self.total_plan_time += plan_stats.total_plan_time
  self.total_plan_cost += plan_stats.total_plan_cost
- for op, op_stats in plan_stats.operator_stats.items():
- if op in self.operator_stats:
- self.operator_stats[op] += op_stats
+ for op_id, op_stats in plan_stats.operator_stats.items():
+ if op_id in self.operator_stats:
+ self.operator_stats[op_id] += op_stats
  else:
- self.operator_stats[op] = op_stats
+ self.operator_stats[op_id] = op_stats
 
- def finalize(self, total_plan_time: float):
- self.total_plan_time = total_plan_time
- self.total_plan_cost = sum([op_stats.total_op_cost for _, op_stats in self.operator_stats.items()])
-
- def __str__(self):
- stats = f"Total_plan_time={self.total_plan_time} \n"
- stats += f"Total_plan_cost={self.total_plan_cost} \n"
+ def __str__(self) -> str:
+ stats = f"total_plan_time={self.total_plan_time} \n"
+ stats += f"total_plan_cost={self.total_plan_cost} \n"
  for idx, op_stats in enumerate(self.operator_stats.values()):
  stats += f"{idx}. {op_stats.op_name} time={op_stats.total_op_time} cost={op_stats.total_op_cost} \n"
  return stats
 
- def to_json(self):
+ def to_json(self) -> dict:
  return {
  "plan_id": self.plan_id,
  "plan_str": self.plan_str,
@@ -329,6 +471,100 @@ class PlanStats:
  }
 
 
+ @dataclass
+ class SentinelPlanStats(BasePlanStats):
+ """
+ Subclass of BasePlanStats which captures statistics from the execution of a single SentinelPlan.
+ """
+ @staticmethod
+ def from_plan(plan) -> SentinelPlanStats:
+ """
+ Initialize this PlanStats object from a Sentinel object.
+ """
+ operator_stats = {}
+ for op_set_idx, (logical_op_id, op_set) in enumerate(plan):
+ operator_stats[logical_op_id] = {}
+ for physical_op in op_set:
+ op_id = physical_op.get_op_id()
+ operator_stats[logical_op_id][op_id] = OperatorStats(
+ op_id=op_id,
+ op_name=physical_op.op_name(),
+ source_op_id=None if op_set_idx == 0 else plan.logical_op_ids[op_set_idx - 1],
+ plan_id=plan.plan_id,
+ op_details={k: str(v) for k, v in physical_op.get_id_params().items()},
+ )
+
+ return SentinelPlanStats(plan_id=plan.plan_id, plan_str=str(plan), operator_stats=operator_stats)
+
+ def sum_op_costs(self) -> float:
+ """
+ Sum the costs of all operators in this plan.
+ """
+ return sum(sum([op_stats.total_op_cost for _, op_stats in phys_op_stats.items()]) for _, phys_op_stats in self.operator_stats.items())
+
+ def add_record_op_stats(self, record_op_stats: RecordOpStats | list[RecordOpStats]) -> None:
+ """
+ Add the given RecordOpStats to this plan's operator stats.
+ """
+ # normalize input type to be list[RecordOpStats]
+ record_op_stats_lst = record_op_stats if isinstance(record_op_stats, list) else [record_op_stats]
+
+ # update operator stats
+ for record_op_stats in record_op_stats_lst:
+ logical_op_id = record_op_stats.logical_op_id
+ physical_op_id = record_op_stats.op_id
+ if logical_op_id in self.operator_stats:
+ if physical_op_id in self.operator_stats[logical_op_id]:
+ self.operator_stats[logical_op_id][physical_op_id] += record_op_stats
+ else:
+ raise ValueError(f"RecordOpStats with physical_op_id {physical_op_id} not found in SentinelPlanStats")
+ else:
+ raise ValueError(f"RecordOpStats with logical_op_id {logical_op_id} not found in SentinelPlanStats")
+
+ def __iadd__(self, plan_stats: SentinelPlanStats) -> None:
+ """
+ NOTE: we assume the execution layer guarantees:
+ 1. these plan_stats belong to the same plan
+ 2. these plan_stats come from sequential (non-overlapping) executions of the same plan
+
+ The latter criteria implies it is okay for this method to sum the plan (and operator) runtimes.
+ """
+ self.total_plan_time += plan_stats.total_plan_time
+ self.total_plan_cost += plan_stats.total_plan_cost
+ for logical_op_id, physical_op_stats in plan_stats.operator_stats.items():
+ for physical_op_id, op_stats in physical_op_stats.items():
+ if logical_op_id in self.operator_stats:
+ if physical_op_id in self.operator_stats[logical_op_id]:
+ self.operator_stats[logical_op_id][physical_op_id] += op_stats
+ else:
+ self.operator_stats[logical_op_id][physical_op_id] = op_stats
+ else:
+ self.operator_stats[logical_op_id] = physical_op_stats
+
+ def __str__(self) -> str:
+ stats = f"total_plan_time={self.total_plan_time} \n"
+ stats += f"total_plan_cost={self.total_plan_cost} \n"
+ for outer_idx, physical_op_stats in enumerate(self.operator_stats.values()):
+ total_time = sum([op_stats.total_op_time for op_stats in physical_op_stats.values()])
+ total_cost = sum([op_stats.total_op_cost for op_stats in physical_op_stats.values()])
+ stats += f"{outer_idx}. total_time={total_time} total_cost={total_cost} \n"
+ for inner_idx, op_stats in enumerate(physical_op_stats.values()):
+ stats += f" {outer_idx}.{inner_idx}. {op_stats.op_name} time={op_stats.total_op_time} cost={op_stats.total_op_cost} \n"
+ return stats
+
+ def to_json(self) -> dict:
+ return {
+ "plan_id": self.plan_id,
+ "plan_str": self.plan_str,
+ "operator_stats": {
+ logical_op_id: {physical_op_id: op_stats.to_json() for physical_op_id, op_stats in physical_op_stats.items()}
+ for logical_op_id, physical_op_stats in self.operator_stats.items()
+ },
+ "total_plan_time": self.total_plan_time,
+ "total_plan_cost": self.total_plan_cost,
+ }
+
+
 
  @dataclass
  class ExecutionStats:
@@ -338,28 +574,130 @@ class ExecutionStats:
  # string for identifying this workload execution
  execution_id: str | None = None
 
+ # dictionary of SentinelPlanStats objects (one for each sentinel plan run during execution)
+ sentinel_plan_stats: dict[str, SentinelPlanStats] = field(default_factory=dict)
+
  # dictionary of PlanStats objects (one for each plan run during execution)
  plan_stats: dict[str, PlanStats] = field(default_factory=dict)
 
  # total time spent optimizing
- total_optimization_time: float = 0.0
+ optimization_time: float = 0.0
 
- # total runtime for a plan's execution
+ # total cost of optimizing
+ optimization_cost: float = 0.0
+
+ # total time spent executing the optimized plan
+ plan_execution_time: float = 0.0
+
+ # total cost of executing the optimized plan
+ plan_execution_cost: float = 0.0
+
+ # total runtime for the entire execution
  total_execution_time: float = 0.0
 
- # total cost for a plan's execution
+ # total cost for the entire execution
  total_execution_cost: float = 0.0
 
+ # dictionary of sentinel plan strings; useful for printing executed sentinel plans in demos
+ sentinel_plan_strs: dict[str, str] = field(default_factory=dict)
+
  # dictionary of plan strings; useful for printing executed plans in demos
  plan_strs: dict[str, str] = field(default_factory=dict)
 
+ # start time for the execution; should be set by calling ExecutionStats.start()
+ start_time: float | None = None
+
+ # end time for the optimization;
+ optimization_end_time: float | None = None
+
+ def start(self) -> None:
+ """Start the timer for this execution."""
+ self.start_time = time.time()
+
+ def finish_optimization(self) -> None:
+ """Finish the timer for the optimization phase of this execution."""
+ if self.start_time is None:
+ raise RuntimeError("ExecutionStats.start() must be called before ExecutionStats.finish_optimization()")
+
+ # compute optimization time and cost
+ self.optimization_end_time = time.time()
+ self.optimization_time = self.optimization_end_time - self.start_time
+ self.optimization_cost = self.sum_sentinel_plan_costs()
+
+ # compute sentinel_plan_strs
+ self.sentinel_plan_strs = {plan_id: plan_stats.plan_str for plan_id, plan_stats in self.sentinel_plan_stats.items()}
+
+ def finish(self) -> None:
+ """Finish the timer for this execution."""
+ if self.start_time is None:
+ raise RuntimeError("ExecutionStats.start() must be called before ExecutionStats.finish()")
+
+ # compute time for plan and total execution
+ end_time = time.time()
+ self.plan_execution_time = (
+ end_time - self.optimization_end_time
+ if self.optimization_end_time is not None
+ else end_time - self.start_time
+ )
+ self.total_execution_time = end_time - self.start_time
+
+ # compute the cost for plan and total execution
+ self.plan_execution_cost = self.sum_plan_costs()
+ self.total_execution_cost = self.optimization_cost + self.plan_execution_cost
+
+ # compute plan_strs
+ self.plan_strs = {plan_id: plan_stats.plan_str for plan_id, plan_stats in self.plan_stats.items()}
+
+ def sum_sentinel_plan_costs(self) -> float:
+ """
+ Sum the costs of all SentinelPlans in this execution.
+ """
+ return sum([plan_stats.sum_op_costs() for _, plan_stats in self.sentinel_plan_stats.items()])
+
+ def sum_plan_costs(self) -> float:
+ """
+ Sum the costs of all PhysicalPlans in this execution.
+ """
+ return sum([plan_stats.sum_op_costs() for _, plan_stats in self.plan_stats.items()])
+
+ def add_plan_stats(self, plan_stats: PlanStats | SentinelPlanStats | list[PlanStats] | list[SentinelPlanStats]) -> None:
+ """
+ Add the given PlanStats (or SentinelPlanStats) to this execution's plan stats.
+
+ NOTE: we make the assumption that the same plan cannot be run more than once in parallel,
+ i.e. each plan stats object for an individual plan comes from two different (sequential)
+ periods in time. Thus, PlanStats objects can be summed.
+ """
+ # normalize input type to be list[PlanStats] or list[SentinelPlanStats]
+ if isinstance(plan_stats, (PlanStats, SentinelPlanStats)):
+ plan_stats = [plan_stats]
+
+ for plan_stats_obj in plan_stats:
+ if isinstance(plan_stats_obj, PlanStats) and plan_stats_obj.plan_id not in self.plan_stats:
+ self.plan_stats[plan_stats_obj.plan_id] = plan_stats_obj
+ elif isinstance(plan_stats_obj, PlanStats):
+ self.plan_stats[plan_stats_obj.plan_id] += plan_stats_obj
+ elif isinstance(plan_stats_obj, SentinelPlanStats) and plan_stats_obj.plan_id not in self.sentinel_plan_stats:
+ self.sentinel_plan_stats[plan_stats_obj.plan_id] = plan_stats_obj
+ elif isinstance(plan_stats_obj, SentinelPlanStats):
+ self.sentinel_plan_stats[plan_stats_obj.plan_id] += plan_stats_obj
+ else:
+ raise TypeError(f"Cannot add {type(plan_stats)} to ExecutionStats")
+
  def to_json(self):
  return {
  "execution_id": self.execution_id,
+ "sentinel_plan_stats": {
+ plan_id: plan_stats.to_json() for plan_id, plan_stats in self.sentinel_plan_stats.items()
+ },
  "plan_stats": {plan_id: plan_stats.to_json() for plan_id, plan_stats in self.plan_stats.items()},
- "total_optimization_time": self.total_optimization_time,
+ "optimization_time": self.optimization_time,
+ "optimization_cost": self.optimization_cost,
+ "plan_execution_time": self.plan_execution_time,
+ "plan_execution_cost": self.plan_execution_cost,
  "total_execution_time": self.total_execution_time,
  "total_execution_cost": self.total_execution_cost,
+ "sentinel_plan_strs": self.sentinel_plan_strs,
  "plan_strs": self.plan_strs,
  }
 
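ExecutionStats now splits an overall run into an optimization phase (sentinel plans) and a plan-execution phase, each with its own time and cost, using the same timestamp-based protocol as BasePlanStats. A minimal sketch of the intended call order (method names as added above; the plan-stats objects would come from actual plan executions):

    from palimpzest.core.data.dataclasses import ExecutionStats

    exec_stats = ExecutionStats(execution_id="exec-0")
    exec_stats.start()
    # ... run sentinel plans, then exec_stats.add_plan_stats(sentinel_plan_stats) ...
    exec_stats.finish_optimization()   # optimization_time/cost, sentinel_plan_strs
    # ... run the chosen physical plan, then exec_stats.add_plan_stats(plan_stats) ...
    exec_stats.finish()                # plan_execution_time/cost plus the overall totals

add_plan_stats dispatches on type, so PlanStats and SentinelPlanStats land in separate dictionaries, and repeated stats for the same plan_id are merged with +=.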
palimpzest/core/elements/filters.py

@@ -16,17 +16,20 @@ class Filter:
  self.filter_fn = filter_fn
 
  def serialize(self) -> dict[str, Any]:
- return {"filter_condition": self.filter_condition, "filter_fn": str(self.filter_fn)}
+ return {
+ "filter_condition": self.filter_condition,
+ "filter_fn": self.filter_fn.__name__ if self.filter_fn is not None else None,
+ }
 
  def get_filter_str(self) -> str:
- return self.filter_condition if self.filter_condition is not None else str(self.filter_fn)
+ return self.filter_condition if self.filter_condition is not None else self.filter_fn.__name__
 
  def __repr__(self) -> str:
  return "Filter(" + self.get_filter_str() + ")"
 
  def __hash__(self) -> int:
  # custom hash function
- return hash(self.filter_condition) if self.filter_condition is not None else hash(str(self.filter_fn))
+ return hash(self.filter_condition) if self.filter_condition is not None else hash(self.filter_fn.__name__)
 
  def __eq__(self, other) -> bool:
  # __eq__ should be defined for consistency with __hash__
@@ -35,5 +38,6 @@ class Filter:
  and self.filter_condition == other.filter_condition
  and self.filter_fn == other.filter_fn
  )
+
  def __str__(self) -> str:
  return self.get_filter_str()
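Filter now serializes, prints, and hashes UDF filters by function name rather than by str(filter_fn), which embeds a memory address and therefore differs between runs. A minimal sketch, assuming Filter can be constructed with just a filter_fn keyword (as the None-checks above suggest); the record attribute used by the UDF is hypothetical:

    from palimpzest.core.elements.filters import Filter

    def mentions_mit(record) -> bool:     # hypothetical UDF
        return "MIT" in record.contents   # `contents` is an assumed field name

    f = Filter(filter_fn=mentions_mit)
    print(f.serialize())   # {'filter_condition': None, 'filter_fn': 'mentions_mit'}
    print(hash(f) == hash(Filter(filter_fn=mentions_mit)))   # True: equal for any Filter wrapping this function name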
palimpzest/core/elements/index.py (new file)

@@ -0,0 +1,70 @@
+ from __future__ import annotations
+
+ from abc import ABC, abstractmethod
+
+ from chromadb.api.models.Collection import Collection
+ from ragatouille.RAGPretrainedModel import RAGPretrainedModel
+
+
+ def index_factory(index: Collection | RAGPretrainedModel) -> PZIndex:
+ """
+ Factory function to create a PZ index based on the type of the provided index.
+
+ Args:
+ index (Collection | RAGPretrainedModel): The index provided by the user.
+
+ Returns:
+ PZIndex: The PZ wrapped Index.
+ """
+ if isinstance(index, Collection):
+ return ChromaIndex(index)
+ elif isinstance(index, RAGPretrainedModel):
+ return RagatouilleIndex(index)
+ else:
+ raise TypeError(f"Unsupported index type: {type(index)}\nindex must be a `chromadb.api.models.Collection.Collection` or `ragatouille.RAGPretrainedModel.RAGPretrainedModel`")
+
+
+ class BaseIndex(ABC):
+
+ def __init__(self, index: Collection | RAGPretrainedModel):
+ self.index = index
+
+ def __str__(self):
+ """
+ Return a string representation of the index.
+ """
+ return f"{self.__class__.__name__}"
+
+ @abstractmethod
+ def search(self, query_embedding: list[float] | list[list[float]], results_per_query: int = 1) -> list | list[list]:
+ """
+ Query the index with a string or a list of strings.
+
+ Args:
+ query (str | list[str]): The query string or list of strings to search for.
+ results_per_query (int): The number of top results to retrieve for each query.
+
+ Returns:
+ list | list[list]: The top results for the query. If query is a list, the top
+ results for each query in the list are returned. Each list will contain the
+ raw elements yielded by the index. This way, users can program against the
+ results they expect to get from e.g. chromadb or ragatouille.
+ """
+ pass
+
+
+ class ChromaIndex(BaseIndex):
+ def __init__(self, index: Collection):
+ assert isinstance(index, Collection), "ChromaIndex input must be a `chromadb.api.models.Collection.Collection`"
+ super().__init__(index)
+
+
+
+ class RagatouilleIndex(BaseIndex):
+ def __init__(self, index: RAGPretrainedModel):
+ assert isinstance(index, RAGPretrainedModel), "RagatouilleIndex input must be a `ragatouille.RAGPretrainedModel.RAGPretrainedModel`"
+ super().__init__(index)
+
+
+ # define type for PZIndex
+ PZIndex = ChromaIndex | RagatouilleIndex
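The new index module gives palimpzest a common wrapper around user-provided indexes: index_factory wraps a chromadb Collection as a ChromaIndex and a ragatouille RAGPretrainedModel as a RagatouilleIndex, and rejects anything else with a TypeError. A minimal sketch of the dispatch contract (note that, in the file as shown, search is still abstract on BaseIndex, so only the type checking is exercised here):

    from palimpzest.core.elements.index import index_factory

    # a chromadb Collection -> ChromaIndex, a RAGPretrainedModel -> RagatouilleIndex;
    # any other object trips the factory's TypeError branch:
    try:
        index_factory("not an index")
    except TypeError as err:
        print(err)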