palimpzest 0.7.20__py3-none-any.whl → 0.8.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (87) hide show
  1. palimpzest/__init__.py +37 -6
  2. palimpzest/agents/__init__.py +0 -0
  3. palimpzest/agents/compute_agents.py +0 -0
  4. palimpzest/agents/search_agents.py +637 -0
  5. palimpzest/constants.py +259 -197
  6. palimpzest/core/data/context.py +393 -0
  7. palimpzest/core/data/context_manager.py +163 -0
  8. palimpzest/core/data/dataset.py +634 -0
  9. palimpzest/core/data/{datareaders.py → iter_dataset.py} +202 -126
  10. palimpzest/core/elements/groupbysig.py +16 -13
  11. palimpzest/core/elements/records.py +166 -75
  12. palimpzest/core/lib/schemas.py +152 -390
  13. palimpzest/core/{data/dataclasses.py → models.py} +306 -170
  14. palimpzest/policy.py +2 -27
  15. palimpzest/prompts/__init__.py +35 -5
  16. palimpzest/prompts/agent_prompts.py +357 -0
  17. palimpzest/prompts/context_search.py +9 -0
  18. palimpzest/prompts/convert_prompts.py +61 -5
  19. palimpzest/prompts/filter_prompts.py +50 -5
  20. palimpzest/prompts/join_prompts.py +163 -0
  21. palimpzest/prompts/moa_proposer_convert_prompts.py +5 -5
  22. palimpzest/prompts/prompt_factory.py +358 -46
  23. palimpzest/prompts/validator.py +239 -0
  24. palimpzest/query/execution/all_sample_execution_strategy.py +134 -76
  25. palimpzest/query/execution/execution_strategy.py +210 -317
  26. palimpzest/query/execution/execution_strategy_type.py +5 -7
  27. palimpzest/query/execution/mab_execution_strategy.py +249 -136
  28. palimpzest/query/execution/parallel_execution_strategy.py +153 -244
  29. palimpzest/query/execution/single_threaded_execution_strategy.py +107 -64
  30. palimpzest/query/generators/generators.py +157 -330
  31. palimpzest/query/operators/__init__.py +15 -5
  32. palimpzest/query/operators/aggregate.py +50 -33
  33. palimpzest/query/operators/compute.py +201 -0
  34. palimpzest/query/operators/convert.py +27 -21
  35. palimpzest/query/operators/critique_and_refine_convert.py +7 -5
  36. palimpzest/query/operators/distinct.py +62 -0
  37. palimpzest/query/operators/filter.py +22 -13
  38. palimpzest/query/operators/join.py +402 -0
  39. palimpzest/query/operators/limit.py +3 -3
  40. palimpzest/query/operators/logical.py +198 -80
  41. palimpzest/query/operators/mixture_of_agents_convert.py +10 -8
  42. palimpzest/query/operators/physical.py +27 -21
  43. palimpzest/query/operators/project.py +3 -3
  44. palimpzest/query/operators/rag_convert.py +7 -7
  45. palimpzest/query/operators/retrieve.py +9 -9
  46. palimpzest/query/operators/scan.py +81 -42
  47. palimpzest/query/operators/search.py +524 -0
  48. palimpzest/query/operators/split_convert.py +10 -8
  49. palimpzest/query/optimizer/__init__.py +7 -9
  50. palimpzest/query/optimizer/cost_model.py +108 -441
  51. palimpzest/query/optimizer/optimizer.py +123 -181
  52. palimpzest/query/optimizer/optimizer_strategy.py +66 -61
  53. palimpzest/query/optimizer/plan.py +352 -67
  54. palimpzest/query/optimizer/primitives.py +43 -19
  55. palimpzest/query/optimizer/rules.py +484 -646
  56. palimpzest/query/optimizer/tasks.py +127 -58
  57. palimpzest/query/processor/config.py +41 -76
  58. palimpzest/query/processor/query_processor.py +73 -18
  59. palimpzest/query/processor/query_processor_factory.py +46 -38
  60. palimpzest/schemabuilder/schema_builder.py +15 -28
  61. palimpzest/utils/model_helpers.py +27 -77
  62. palimpzest/utils/progress.py +114 -102
  63. palimpzest/validator/__init__.py +0 -0
  64. palimpzest/validator/validator.py +306 -0
  65. {palimpzest-0.7.20.dist-info → palimpzest-0.8.0.dist-info}/METADATA +6 -1
  66. palimpzest-0.8.0.dist-info/RECORD +95 -0
  67. palimpzest/core/lib/fields.py +0 -141
  68. palimpzest/prompts/code_synthesis_prompts.py +0 -28
  69. palimpzest/query/execution/random_sampling_execution_strategy.py +0 -240
  70. palimpzest/query/generators/api_client_factory.py +0 -30
  71. palimpzest/query/operators/code_synthesis_convert.py +0 -488
  72. palimpzest/query/operators/map.py +0 -130
  73. palimpzest/query/processor/nosentinel_processor.py +0 -33
  74. palimpzest/query/processor/processing_strategy_type.py +0 -28
  75. palimpzest/query/processor/sentinel_processor.py +0 -88
  76. palimpzest/query/processor/streaming_processor.py +0 -149
  77. palimpzest/sets.py +0 -405
  78. palimpzest/utils/datareader_helpers.py +0 -61
  79. palimpzest/utils/demo_helpers.py +0 -75
  80. palimpzest/utils/field_helpers.py +0 -69
  81. palimpzest/utils/generation_helpers.py +0 -69
  82. palimpzest/utils/sandbox.py +0 -183
  83. palimpzest-0.7.20.dist-info/RECORD +0 -95
  84. /palimpzest/core/{elements/index.py → data/index_dataset.py} +0 -0
  85. {palimpzest-0.7.20.dist-info → palimpzest-0.8.0.dist-info}/WHEEL +0 -0
  86. {palimpzest-0.7.20.dist-info → palimpzest-0.8.0.dist-info}/licenses/LICENSE +0 -0
  87. {palimpzest-0.7.20.dist-info → palimpzest-0.8.0.dist-info}/top_level.txt +0 -0
@@ -1,17 +1,16 @@
1
1
  from __future__ import annotations
2
2
 
3
+ import json
3
4
  import time
4
5
  from abc import abstractmethod
5
- from dataclasses import dataclass, field, fields
6
6
  from typing import Any
7
7
 
8
- import numpy as np
8
+ from pydantic import BaseModel, Field
9
9
 
10
10
 
11
- @dataclass
12
- class GenerationStats:
11
+ class GenerationStats(BaseModel):
13
12
  """
14
- Dataclass for storing statistics about the execution of an operator on a single record.
13
+ Model for storing statistics about the execution of an operator on a single record.
15
14
  """
16
15
 
17
16
  model_name: str | None = None
@@ -19,6 +18,15 @@ class GenerationStats:
19
18
  # The raw answer as output from the generator (a list of strings, possibly of len 1)
20
19
  # raw_answers: Optional[List[str]] = field(default_factory=list)
21
20
 
21
+ # the number of input audio tokens
22
+ input_audio_tokens: int = 0
23
+
24
+ # the number of input text tokens
25
+ input_text_tokens: int = 0
26
+
27
+ # the number of input image tokens
28
+ input_image_tokens: int = 0
29
+
22
30
  # the total number of input tokens processed by this operator; None if this operation did not use an LLM
23
31
  # typed as a float because GenerationStats may be amortized (i.e. divided) across a number of output records
24
32
  total_input_tokens: float = 0.0
@@ -33,7 +41,7 @@ class GenerationStats:
33
41
  # the total cost of processing the output tokens; None if this operation did not use an LLM
34
42
  total_output_cost: float = 0.0
35
43
 
36
- # the total cost of processing the output tokens; None if this operation did not use an LLM
44
+ # the total cost of processing the input and output tokens; None if this operation did not use an LLM
37
45
  cost_per_record: float = 0.0
38
46
 
39
47
  # (if applicable) the time (in seconds) spent executing a call to an LLM
@@ -50,7 +58,7 @@ class GenerationStats:
50
58
 
51
59
  def __iadd__(self, other: GenerationStats) -> GenerationStats:
52
60
  # self.raw_answers.extend(other.raw_answers)
53
- for dataclass_field in [
61
+ for model_field in [
54
62
  "total_input_tokens",
55
63
  "total_output_tokens",
56
64
  "total_input_cost",
@@ -61,7 +69,7 @@ class GenerationStats:
61
69
  "total_llm_calls",
62
70
  "total_embedding_llm_calls",
63
71
  ]:
64
- setattr(self, dataclass_field, getattr(self, dataclass_field) + getattr(other, dataclass_field))
72
+ setattr(self, model_field, getattr(self, model_field) + getattr(other, model_field))
65
73
  return self
66
74
 
67
75
  def __add__(self, other: GenerationStats) -> GenerationStats:
@@ -89,7 +97,7 @@ class GenerationStats:
89
97
  raise ZeroDivisionError("Cannot divide by zero")
90
98
  if isinstance(quotient, int):
91
99
  quotient = float(quotient)
92
- for dataclass_field in [
100
+ for model_field in [
93
101
  "total_input_tokens",
94
102
  "total_output_tokens",
95
103
  "total_input_cost",
@@ -100,7 +108,7 @@ class GenerationStats:
100
108
  "total_llm_calls",
101
109
  "total_embedding_llm_calls",
102
110
  ]:
103
- setattr(self, dataclass_field, getattr(self, dataclass_field) / quotient)
111
+ setattr(self, model_field, getattr(self, model_field) / quotient)
104
112
  return self
105
113
 
106
114
  def __truediv__(self, quotient: float) -> GenerationStats:
@@ -129,22 +137,30 @@ class GenerationStats:
129
137
  assert not isinstance(other, GenerationStats), "This should not be called with a GenerationStats object"
130
138
  return self
131
139
 
140
+ # NOTE: this is added temporarily to help track cost of compute agent writing PZ code;
141
+ # once we find a long-term solution for tracking that cost, we can remove this
142
+ def to_json(self, filepath: str | None = None) -> dict | None:
143
+ if filepath is None:
144
+ return self.model_dump(mode="json")
145
+
146
+ with open(filepath, "w") as f:
147
+ json.dump(self.model_dump(mode="json"), f)
148
+
132
149
 
133
- @dataclass
134
- class RecordOpStats:
150
+ class RecordOpStats(BaseModel):
135
151
  """
136
- Dataclass for storing statistics about the execution of an operator on a single record.
152
+ Model for storing statistics about the execution of an operator on a single record.
137
153
  """
138
154
 
139
155
  ##### REQUIRED FIELDS #####
140
156
  # record id; an identifier for this record
141
- record_id: str
157
+ record_id: str | int
142
158
 
143
- # identifier for the parent of this record
144
- record_parent_id: str
159
+ # identifier for the parent(s) of this record
160
+ record_parent_ids: list[str | int] | None
145
161
 
146
- # idenifier for the source idx of this record
147
- record_source_idx: str
162
+ # identifier for the source indices of this record
163
+ record_source_indices: list[str | int]
148
164
 
149
165
  # a dictionary with the record state after being processed by the operator
150
166
  record_state: dict[str, Any]
@@ -165,8 +181,11 @@ class RecordOpStats:
165
181
  cost_per_record: float
166
182
 
167
183
  ##### NOT-OPTIONAL, BUT FILLED BY EXECUTION CLASS AFTER CONSTRUCTOR CALL #####
168
- # the ID of the physical operation which produced the input record for this record at this operation
169
- source_full_op_id: str | None = None
184
+ # the ID(s) of the physical operation(s) which produced the input record(s) for this record at this operation
185
+ source_unique_full_op_ids: list[str] | None = None
186
+
187
+ # the ID(s) of the logical operation(s) which produced the input record(s) for this record at this operation
188
+ source_unique_logical_op_ids: list[str] | None = None
170
189
 
171
190
  # the ID of the physical plan which produced this record at this operation
172
191
  plan_id: str = ""
@@ -207,8 +226,11 @@ class RecordOpStats:
207
226
  # (if applicable) the filter text (or a string representation of the filter function) applied to this record
208
227
  filter_str: str | None = None
209
228
 
229
+ # (if applicable) the join condition applied to this record
230
+ join_condition: str | None = None
231
+
210
232
  # the True/False result of whether this record was output by the operator or not
211
- # (can only be False if the operator is as Filter)
233
+ # (can only be False if the operator is a Filter or Join)
212
234
  passed_operator: bool = True
213
235
 
214
236
  # (if applicable) the time (in seconds) spent executing a call to an LLM
@@ -230,16 +252,12 @@ class RecordOpStats:
230
252
  image_operation: bool | None = None
231
253
 
232
254
  # an OPTIONAL dictionary with more detailed information about this operation;
233
- op_details: dict[str, Any] = field(default_factory=dict)
234
-
235
- def to_json(self):
236
- return {field.name: getattr(self, field.name) for field in fields(self)}
255
+ op_details: dict[str, Any] = Field(default_factory=dict)
237
256
 
238
257
 
239
- @dataclass
240
- class OperatorStats:
258
+ class OperatorStats(BaseModel):
241
259
  """
242
- Dataclass for storing statistics captured within a given operator.
260
+ Model for storing statistics captured within a given operator.
243
261
  """
244
262
 
245
263
  # the full ID of the physical operation in which these stats were collected
@@ -254,17 +272,26 @@ class OperatorStats:
254
272
  # the total cost of this operation
255
273
  total_op_cost: float = 0.0
256
274
 
275
+ # the total input tokens processed by this operation
276
+ total_input_tokens: int = 0
277
+
278
+ # the total output tokens processed by this operation
279
+ total_output_tokens: int = 0
280
+
257
281
  # a list of RecordOpStats processed by the operation
258
- record_op_stats_lst: list[RecordOpStats] = field(default_factory=list)
282
+ record_op_stats_lst: list[RecordOpStats] = Field(default_factory=list)
259
283
 
260
- # the full ID of the physical operator which precedes this one
261
- source_full_op_id: str | None = None
284
+ # the unique full ID(s) of the physical operator(s) which precede this one (used by PlanStats)
285
+ source_unique_full_op_ids: list[str] | None = None
286
+
287
+ # the unique full ID(s) of the logical operator(s) which precede this one (used by SentinelPlanStats)
288
+ source_unique_logical_op_ids: list[str] | None = None
262
289
 
263
290
  # the ID of the physical plan which this operator is part of
264
291
  plan_id: str = ""
265
292
 
266
293
  # an OPTIONAL dictionary with more detailed information about this operation;
267
- op_details: dict[str, Any] = field(default_factory=dict)
294
+ op_details: dict[str, Any] = Field(default_factory=dict)
268
295
 
269
296
  def __iadd__(self, stats: OperatorStats | RecordOpStats) -> OperatorStats:
270
297
  """
@@ -280,34 +307,28 @@ class OperatorStats:
280
307
  if isinstance(stats, OperatorStats):
281
308
  self.total_op_time += stats.total_op_time
282
309
  self.total_op_cost += stats.total_op_cost
310
+ self.total_input_tokens += stats.total_input_tokens
311
+ self.total_output_tokens += stats.total_output_tokens
283
312
  self.record_op_stats_lst.extend(stats.record_op_stats_lst)
284
313
 
285
314
  elif isinstance(stats, RecordOpStats):
286
- stats.source_full_op_id = self.source_full_op_id
315
+ stats.source_unique_full_op_ids = self.source_unique_full_op_ids
287
316
  stats.plan_id = self.plan_id
288
317
  self.record_op_stats_lst.append(stats)
289
318
  self.total_op_time += stats.time_per_record
290
319
  self.total_op_cost += stats.cost_per_record
320
+ self.total_input_tokens += stats.total_input_tokens
321
+ self.total_output_tokens += stats.total_output_tokens
291
322
 
292
323
  else:
293
324
  raise TypeError(f"Cannot add {type(stats)} to OperatorStats")
294
325
 
295
326
  return self
296
327
 
297
- def to_json(self):
298
- return {
299
- "full_op_id": self.full_op_id,
300
- "op_name": self.op_name,
301
- "total_op_time": self.total_op_time,
302
- "total_op_cost": self.total_op_cost,
303
- "record_op_stats_lst": [record_op_stats.to_json() for record_op_stats in self.record_op_stats_lst],
304
- "op_details": self.op_details,
305
- }
306
328
 
307
- @dataclass
308
- class BasePlanStats:
329
+ class BasePlanStats(BaseModel):
309
330
  """
310
- Dataclass for storing statistics captured for an entire plan.
331
+ Model for storing statistics captured for an entire plan.
311
332
 
312
333
  This class is subclassed for tracking:
313
334
  - PlanStats: the statistics for execution of a PhysicalPlan
@@ -331,7 +352,11 @@ class BasePlanStats:
331
352
  # dictionary whose values are OperatorStats objects;
332
353
  # PlanStats maps {full_op_id -> OperatorStats}
333
354
  # SentinelPlanStats maps {logical_op_id -> {full_op_id -> OperatorStats}}
334
- operator_stats: dict = field(default_factory=dict)
355
+ operator_stats: dict = Field(default_factory=dict)
356
+
357
+ # dictionary whose values are GenerationStats objects for validation;
358
+ # only used by SentinelPlanStats
359
+ validation_gen_stats: dict[str, GenerationStats] = Field(default_factory=dict)
335
360
 
336
361
  # total runtime for the plan measured from the start to the end of PhysicalPlan.execute()
337
362
  total_plan_time: float = 0.0
@@ -339,6 +364,12 @@ class BasePlanStats:
339
364
  # total cost for plan
340
365
  total_plan_cost: float = 0.0
341
366
 
367
+ # total input tokens processed by this plan
368
+ total_input_tokens: int = 0
369
+
370
+ # total output tokens processed by this plan
371
+ total_output_tokens: int = 0
372
+
342
373
  # start time for the plan execution; should be set by calling PlanStats.start()
343
374
  start_time: float | None = None
344
375
 
@@ -351,7 +382,9 @@ class BasePlanStats:
351
382
  if self.start_time is None:
352
383
  raise RuntimeError("PlanStats.start() must be called before PlanStats.finish()")
353
384
  self.total_plan_time = time.time() - self.start_time
354
- self.total_plan_cost = self.sum_op_costs()
385
+ self.total_plan_cost = self.sum_op_costs() + self.sum_validation_costs()
386
+ self.total_input_tokens = self.sum_input_tokens() + self.sum_validation_input_tokens()
387
+ self.total_output_tokens = self.sum_output_tokens() + self.sum_validation_output_tokens()
355
388
 
356
389
  @staticmethod
357
390
  @abstractmethod
@@ -369,9 +402,23 @@ class BasePlanStats:
369
402
  pass
370
403
 
371
404
  @abstractmethod
372
- def add_record_op_stats(self, record_op_stats: RecordOpStats | list[RecordOpStats]) -> None:
405
+ def sum_input_tokens(self) -> int:
406
+ """
407
+ Sum the input tokens processed by all operators in this plan.
408
+ """
409
+ pass
410
+
411
+ @abstractmethod
412
+ def sum_output_tokens(self) -> int:
413
+ """
414
+ Sum the output tokens processed by all operators in this plan.
415
+ """
416
+ pass
417
+
418
+ @abstractmethod
419
+ def add_record_op_stats(self, unique_full_op_id: str, record_op_stats: RecordOpStats | list[RecordOpStats]) -> None:
373
420
  """
374
- Add the given RecordOpStats to this plan's operator stats.
421
+ Add the given RecordOpStats to this plan's operator stats for the given operator id.
375
422
  """
376
423
  pass
377
424
 
@@ -389,14 +436,25 @@ class BasePlanStats:
389
436
  """
390
437
  pass
391
438
 
392
- @abstractmethod
393
- def to_json(self) -> dict:
439
+ def sum_validation_costs(self) -> float:
394
440
  """
395
- Return a JSON representation of this plan's statistics.
441
+ Sum the costs of all validation generations in this plan.
396
442
  """
397
- pass
443
+ return sum([gen_stats.cost_per_record for _, gen_stats in self.validation_gen_stats.items()])
444
+
445
+ def sum_validation_input_tokens(self) -> int:
446
+ """
447
+ Sum the input tokens processed by all validation generations in this plan.
448
+ """
449
+ return sum([gen_stats.total_input_tokens for _, gen_stats in self.validation_gen_stats.items()])
450
+
451
+ def sum_validation_output_tokens(self) -> int:
452
+ """
453
+ Sum the output tokens processed by all validation generations in this plan.
454
+ """
455
+ return sum([gen_stats.total_output_tokens for _, gen_stats in self.validation_gen_stats.items()])
456
+
398
457
 
399
- @dataclass
400
458
  class PlanStats(BasePlanStats):
401
459
  """
402
460
  Subclass of BasePlanStats which captures statistics from the execution of a single PhysicalPlan.
@@ -406,17 +464,18 @@ class PlanStats(BasePlanStats):
406
464
  """
407
465
  Initialize this PlanStats object from a PhysicalPlan object.
408
466
  """
467
+ # TODO?: have PhysicalPlan return PlanStats object
409
468
  operator_stats = {}
410
- for op_idx, op in enumerate(plan.operators):
411
- full_op_id = op.get_full_op_id()
412
- operator_stats[full_op_id] = OperatorStats(
413
- full_op_id=full_op_id,
469
+ for topo_idx, op in enumerate(plan):
470
+ unique_full_op_id = f"{topo_idx}-{op.get_full_op_id()}"
471
+ operator_stats[unique_full_op_id] = OperatorStats(
472
+ full_op_id=op.get_full_op_id(),
414
473
  op_name=op.op_name(),
415
- source_full_op_id=None if op_idx == 0 else plan.operators[op_idx - 1].get_full_op_id(),
474
+ source_unique_full_op_ids=plan.get_source_unique_full_op_ids(topo_idx, op),
416
475
  plan_id=plan.plan_id,
417
476
  op_details={k: str(v) for k, v in op.get_id_params().items()},
418
477
  )
419
-
478
+
420
479
  return PlanStats(plan_id=plan.plan_id, plan_str=str(plan), operator_stats=operator_stats)
421
480
 
422
481
  def sum_op_costs(self) -> float:
@@ -425,20 +484,31 @@ class PlanStats(BasePlanStats):
425
484
  """
426
485
  return sum([op_stats.total_op_cost for _, op_stats in self.operator_stats.items()])
427
486
 
428
- def add_record_op_stats(self, record_op_stats: RecordOpStats | list[RecordOpStats]) -> None:
487
+ def sum_input_tokens(self) -> int:
488
+ """
489
+ Sum the input tokens processed by all operators in this plan.
490
+ """
491
+ return sum([op_stats.total_input_tokens for _, op_stats in self.operator_stats.items()])
492
+
493
+ def sum_output_tokens(self) -> int:
429
494
  """
430
- Add the given RecordOpStats to this plan's operator stats.
495
+ Sum the output tokens processed by all operators in this plan.
496
+ """
497
+ return sum([op_stats.total_output_tokens for _, op_stats in self.operator_stats.items()])
498
+
499
+ def add_record_op_stats(self, unique_full_op_id: str, record_op_stats: RecordOpStats | list[RecordOpStats]) -> None:
500
+ """
501
+ Add the given RecordOpStats to this plan's operator stats for the given operator id.
431
502
  """
432
503
  # normalize input type to be list[RecordOpStats]
433
504
  record_op_stats_lst = record_op_stats if isinstance(record_op_stats, list) else [record_op_stats]
434
505
 
435
506
  # update operator stats
436
507
  for record_op_stats in record_op_stats_lst:
437
- full_op_id = record_op_stats.full_op_id
438
- if full_op_id in self.operator_stats:
439
- self.operator_stats[full_op_id] += record_op_stats
508
+ if unique_full_op_id in self.operator_stats:
509
+ self.operator_stats[unique_full_op_id] += record_op_stats
440
510
  else:
441
- raise ValueError(f"RecordOpStats with full_op_id {full_op_id} not found in PlanStats")
511
+ raise ValueError(f"RecordOpStats with unique_full_op_id {unique_full_op_id} not found in PlanStats")
442
512
 
443
513
  def __iadd__(self, plan_stats: PlanStats) -> None:
444
514
  """
@@ -450,30 +520,24 @@ class PlanStats(BasePlanStats):
450
520
  """
451
521
  self.total_plan_time += plan_stats.total_plan_time
452
522
  self.total_plan_cost += plan_stats.total_plan_cost
453
- for full_op_id, op_stats in plan_stats.operator_stats.items():
454
- if full_op_id in self.operator_stats:
455
- self.operator_stats[full_op_id] += op_stats
523
+ self.total_input_tokens += plan_stats.total_input_tokens
524
+ self.total_output_tokens += plan_stats.total_output_tokens
525
+ for unique_full_op_id, op_stats in plan_stats.operator_stats.items():
526
+ if unique_full_op_id in self.operator_stats:
527
+ self.operator_stats[unique_full_op_id] += op_stats
456
528
  else:
457
- self.operator_stats[full_op_id] = op_stats
529
+ self.operator_stats[unique_full_op_id] = op_stats
458
530
 
459
531
  def __str__(self) -> str:
460
532
  stats = f"total_plan_time={self.total_plan_time} \n"
461
533
  stats += f"total_plan_cost={self.total_plan_cost} \n"
534
+ stats += f"total_input_tokens={self.total_input_tokens} \n"
535
+ stats += f"total_output_tokens={self.total_output_tokens} \n"
462
536
  for idx, op_stats in enumerate(self.operator_stats.values()):
463
537
  stats += f"{idx}. {op_stats.op_name} time={op_stats.total_op_time} cost={op_stats.total_op_cost} \n"
464
538
  return stats
465
539
 
466
- def to_json(self) -> dict:
467
- return {
468
- "plan_id": self.plan_id,
469
- "plan_str": self.plan_str,
470
- "operator_stats": {full_op_id: op_stats.to_json() for full_op_id, op_stats in self.operator_stats.items()},
471
- "total_plan_time": self.total_plan_time,
472
- "total_plan_cost": self.total_plan_cost,
473
- }
474
540
 
475
-
476
- @dataclass
477
541
  class SentinelPlanStats(BasePlanStats):
478
542
  """
479
543
  Subclass of BasePlanStats which captures statistics from the execution of a single SentinelPlan.
@@ -484,18 +548,19 @@ class SentinelPlanStats(BasePlanStats):
484
548
  Initialize this PlanStats object from a Sentinel object.
485
549
  """
486
550
  operator_stats = {}
487
- for op_set_idx, (logical_op_id, op_set) in enumerate(plan):
488
- operator_stats[logical_op_id] = {}
551
+ for topo_idx, (logical_op_id, op_set) in enumerate(plan):
552
+ unique_logical_op_id = f"{topo_idx}-{logical_op_id}"
553
+ operator_stats[unique_logical_op_id] = {}
489
554
  for physical_op in op_set:
490
555
  full_op_id = physical_op.get_full_op_id()
491
- operator_stats[logical_op_id][full_op_id] = OperatorStats(
556
+ operator_stats[unique_logical_op_id][full_op_id] = OperatorStats(
492
557
  full_op_id=full_op_id,
493
558
  op_name=physical_op.op_name(),
494
- source_full_op_id=None if op_set_idx == 0 else plan.logical_op_ids[op_set_idx - 1], # NOTE: this may be a reason to keep `source_op_id` instead of `source_full_op_id`
559
+ source_unique_logical_op_ids=plan.get_source_unique_logical_op_ids(unique_logical_op_id),
495
560
  plan_id=plan.plan_id,
496
561
  op_details={k: str(v) for k, v in physical_op.get_id_params().items()},
497
562
  )
498
-
563
+
499
564
  return SentinelPlanStats(plan_id=plan.plan_id, plan_str=str(plan), operator_stats=operator_stats)
500
565
 
501
566
  def sum_op_costs(self) -> float:
@@ -504,24 +569,45 @@ class SentinelPlanStats(BasePlanStats):
504
569
  """
505
570
  return sum(sum([op_stats.total_op_cost for _, op_stats in phys_op_stats.items()]) for _, phys_op_stats in self.operator_stats.items())
506
571
 
507
- def add_record_op_stats(self, record_op_stats: RecordOpStats | list[RecordOpStats]) -> None:
572
+ def sum_input_tokens(self) -> int:
573
+ """
574
+ Sum the input tokens processed by all operators in this plan.
575
+ """
576
+ return sum(sum([op_stats.total_input_tokens for _, op_stats in phys_op_stats.items()]) for _, phys_op_stats in self.operator_stats.items())
577
+
578
+ def sum_output_tokens(self) -> int:
508
579
  """
509
- Add the given RecordOpStats to this plan's operator stats.
580
+ Sum the output tokens processed by all operators in this plan.
581
+ """
582
+ return sum(sum([op_stats.total_output_tokens for _, op_stats in phys_op_stats.items()]) for _, phys_op_stats in self.operator_stats.items())
583
+
584
+ def add_record_op_stats(self, unique_logical_op_id: str, record_op_stats: RecordOpStats | list[RecordOpStats]) -> None:
585
+ """
586
+ Add the given RecordOpStats to this plan's operator stats for the given operator set id.
510
587
  """
511
588
  # normalize input type to be list[RecordOpStats]
512
589
  record_op_stats_lst = record_op_stats if isinstance(record_op_stats, list) else [record_op_stats]
513
590
 
514
591
  # update operator stats
515
592
  for record_op_stats in record_op_stats_lst:
516
- logical_op_id = record_op_stats.logical_op_id
517
593
  full_op_id = record_op_stats.full_op_id
518
- if logical_op_id in self.operator_stats:
519
- if full_op_id in self.operator_stats[logical_op_id]:
520
- self.operator_stats[logical_op_id][full_op_id] += record_op_stats
594
+ if unique_logical_op_id in self.operator_stats:
595
+ if full_op_id in self.operator_stats[unique_logical_op_id]:
596
+ self.operator_stats[unique_logical_op_id][full_op_id] += record_op_stats
521
597
  else:
522
598
  raise ValueError(f"RecordOpStats with full_op_id {full_op_id} not found in SentinelPlanStats")
523
599
  else:
524
- raise ValueError(f"RecordOpStats with logical_op_id {logical_op_id} not found in SentinelPlanStats")
600
+ raise ValueError(f"RecordOpStats with unique_logical_op_id {unique_logical_op_id} not found in SentinelPlanStats")
601
+
602
+ def add_validation_gen_stats(self, unique_logical_op_id: str, gen_stats: GenerationStats) -> None:
603
+ """
604
+ Add the given GenerationStats to this plan's validation generation stats for the given logical operator id.
605
+ """
606
+ if unique_logical_op_id in self.validation_gen_stats:
607
+ self.validation_gen_stats[unique_logical_op_id] += gen_stats
608
+ else:
609
+ self.validation_gen_stats[unique_logical_op_id] = gen_stats
610
+
525
611
 
526
612
  def __iadd__(self, plan_stats: SentinelPlanStats) -> None:
527
613
  """
@@ -533,19 +619,29 @@ class SentinelPlanStats(BasePlanStats):
533
619
  """
534
620
  self.total_plan_time += plan_stats.total_plan_time
535
621
  self.total_plan_cost += plan_stats.total_plan_cost
536
- for logical_op_id, physical_op_stats in plan_stats.operator_stats.items():
622
+ self.total_input_tokens += plan_stats.total_input_tokens
623
+ self.total_output_tokens += plan_stats.total_output_tokens
624
+ for unique_logical_op_id, physical_op_stats in plan_stats.operator_stats.items():
537
625
  for full_op_id, op_stats in physical_op_stats.items():
538
- if logical_op_id in self.operator_stats:
539
- if full_op_id in self.operator_stats[logical_op_id]:
540
- self.operator_stats[logical_op_id][full_op_id] += op_stats
626
+ if unique_logical_op_id in self.operator_stats:
627
+ if full_op_id in self.operator_stats[unique_logical_op_id]:
628
+ self.operator_stats[unique_logical_op_id][full_op_id] += op_stats
541
629
  else:
542
- self.operator_stats[logical_op_id][full_op_id] = op_stats
630
+ self.operator_stats[unique_logical_op_id][full_op_id] = op_stats
543
631
  else:
544
- self.operator_stats[logical_op_id] = physical_op_stats
632
+ self.operator_stats[unique_logical_op_id] = physical_op_stats
633
+
634
+ for unique_logical_op_id, gen_stats in plan_stats.validation_gen_stats.items():
635
+ if unique_logical_op_id in self.validation_gen_stats:
636
+ self.validation_gen_stats[unique_logical_op_id] += gen_stats
637
+ else:
638
+ self.validation_gen_stats[unique_logical_op_id] = gen_stats
545
639
 
546
640
  def __str__(self) -> str:
547
641
  stats = f"total_plan_time={self.total_plan_time} \n"
548
642
  stats += f"total_plan_cost={self.total_plan_cost} \n"
643
+ stats += f"total_input_tokens={self.total_input_tokens} \n"
644
+ stats += f"total_output_tokens={self.total_output_tokens} \n"
549
645
  for outer_idx, physical_op_stats in enumerate(self.operator_stats.values()):
550
646
  total_time = sum([op_stats.total_op_time for op_stats in physical_op_stats.values()])
551
647
  total_cost = sum([op_stats.total_op_cost for op_stats in physical_op_stats.values()])
@@ -554,33 +650,20 @@ class SentinelPlanStats(BasePlanStats):
554
650
  stats += f" {outer_idx}.{inner_idx}. {op_stats.op_name} time={op_stats.total_op_time} cost={op_stats.total_op_cost} \n"
555
651
  return stats
556
652
 
557
- def to_json(self) -> dict:
558
- return {
559
- "plan_id": self.plan_id,
560
- "plan_str": self.plan_str,
561
- "operator_stats": {
562
- logical_op_id: {full_op_id: op_stats.to_json() for full_op_id, op_stats in physical_op_stats.items()}
563
- for logical_op_id, physical_op_stats in self.operator_stats.items()
564
- },
565
- "total_plan_time": self.total_plan_time,
566
- "total_plan_cost": self.total_plan_cost,
567
- }
568
-
569
653
 
570
- @dataclass
571
- class ExecutionStats:
654
+ class ExecutionStats(BaseModel):
572
655
  """
573
- Dataclass for storing statistics captured for the entire execution of a workload.
656
+ Model for storing statistics captured for the entire execution of a workload.
574
657
  """
575
658
 
576
659
  # string for identifying this workload execution
577
660
  execution_id: str | None = None
578
661
 
579
662
  # dictionary of SentinelPlanStats objects (one for each sentinel plan run during execution)
580
- sentinel_plan_stats: dict[str, SentinelPlanStats] = field(default_factory=dict)
663
+ sentinel_plan_stats: dict[str, SentinelPlanStats] = Field(default_factory=dict)
581
664
 
582
665
  # dictionary of PlanStats objects (one for each plan run during execution)
583
- plan_stats: dict[str, PlanStats] = field(default_factory=dict)
666
+ plan_stats: dict[str, PlanStats] = Field(default_factory=dict)
584
667
 
585
668
  # total time spent optimizing
586
669
  optimization_time: float = 0.0
@@ -600,16 +683,25 @@ class ExecutionStats:
600
683
  # total cost for the entire execution
601
684
  total_execution_cost: float = 0.0
602
685
 
686
+ # total number of input tokens processed
687
+ total_input_tokens: int = 0
688
+
689
+ # total number of output tokens processed
690
+ total_output_tokens: int = 0
691
+
692
+ # total number of tokens processed
693
+ total_tokens: int = 0
694
+
603
695
  # dictionary of sentinel plan strings; useful for printing executed sentinel plans in demos
604
- sentinel_plan_strs: dict[str, str] = field(default_factory=dict)
696
+ sentinel_plan_strs: dict[str, str] = Field(default_factory=dict)
605
697
 
606
698
  # dictionary of plan strings; useful for printing executed plans in demos
607
- plan_strs: dict[str, str] = field(default_factory=dict)
699
+ plan_strs: dict[str, str] = Field(default_factory=dict)
608
700
 
609
701
  # start time for the execution; should be set by calling ExecutionStats.start()
610
702
  start_time: float | None = None
611
703
 
612
- # end time for the optimization;
704
+ # end time for the optimization;
613
705
  optimization_end_time: float | None = None
614
706
 
615
707
  def start(self) -> None:
@@ -647,6 +739,11 @@ class ExecutionStats:
647
739
  self.plan_execution_cost = self.sum_plan_costs()
648
740
  self.total_execution_cost = self.optimization_cost + self.plan_execution_cost
649
741
 
742
+ # compute the tokens for total execution
743
+ self.total_input_tokens = self.sum_input_tokens()
744
+ self.total_output_tokens = self.sum_output_tokens()
745
+ self.total_tokens = self.total_input_tokens + self.total_output_tokens
746
+
650
747
  # compute plan_strs
651
748
  self.plan_strs = {plan_id: plan_stats.plan_str for plan_id, plan_stats in self.plan_stats.items()}
652
749
 
@@ -654,7 +751,7 @@ class ExecutionStats:
654
751
  """
655
752
  Sum the costs of all SentinelPlans in this execution.
656
753
  """
657
- return sum([plan_stats.sum_op_costs() for _, plan_stats in self.sentinel_plan_stats.items()])
754
+ return sum([plan_stats.sum_op_costs() + plan_stats.sum_validation_costs() for _, plan_stats in self.sentinel_plan_stats.items()])
658
755
 
659
756
  def sum_plan_costs(self) -> float:
660
757
  """
@@ -662,6 +759,22 @@ class ExecutionStats:
662
759
  """
663
760
  return sum([plan_stats.sum_op_costs() for _, plan_stats in self.plan_stats.items()])
664
761
 
762
def sum_input_tokens(self) -> int:
    """
    Return the number of input tokens summed over every sentinel plan and
    every plan recorded in this execution.
    """
    # each stats object reports its own total via sum_input_tokens()
    sentinel_total = sum(stats.sum_input_tokens() for stats in self.sentinel_plan_stats.values())
    plan_total = sum(stats.sum_input_tokens() for stats in self.plan_stats.values())
    return plan_total + sentinel_total
769
+
770
def sum_output_tokens(self) -> int:
    """
    Return the number of output tokens summed over every sentinel plan and
    every plan recorded in this execution.
    """
    # each stats object reports its own total via sum_output_tokens()
    sentinel_total = sum(stats.sum_output_tokens() for stats in self.sentinel_plan_stats.values())
    plan_total = sum(stats.sum_output_tokens() for stats in self.plan_stats.values())
    return plan_total + sentinel_total
777
+
665
778
  def add_plan_stats(self, plan_stats: PlanStats | SentinelPlanStats | list[PlanStats] | list[SentinelPlanStats]) -> None:
666
779
  """
667
780
  Add the given PlanStats (or SentinelPlanStats) to this execution's plan stats.
@@ -686,43 +799,17 @@ class ExecutionStats:
686
799
  else:
687
800
  raise TypeError(f"Cannot add {type(plan_stats)} to ExecutionStats")
688
801
 
689
- def clean_json(self, stats: dict):
690
- """
691
- Convert np.int64 and np.float64 to int and float for all values in stats.
692
- """
693
- for key, value in stats.items():
694
- if isinstance(value, dict):
695
- stats[key] = self.clean_json(value)
696
- elif isinstance(value, np.int64):
697
- stats[key] = int(value)
698
- elif isinstance(value, np.float64):
699
- stats[key] = float(value)
700
- return stats
802
def to_json(self, filepath: str | None = None) -> dict | None:
    """
    Serialize these stats with `model_dump(mode="json")`.

    When `filepath` is None the dumped dictionary is returned. Otherwise the
    dictionary is written to `filepath` as JSON and nothing is returned.
    """
    if filepath is not None:
        # the file is opened before the model is dumped, as in the original
        with open(filepath, "w") as out_file:
            json.dump(self.model_dump(mode="json"), out_file)
        return None

    return self.model_dump(mode="json")
720
808
 
721
809
 
722
- @dataclass
723
- class OperatorCostEstimates:
810
+ class OperatorCostEstimates(BaseModel):
724
811
  """
725
- Dataclass for storing estimates of key metrics of interest for each operator.
812
+ Model for storing estimates of key metrics of interest for each operator.
726
813
  """
727
814
 
728
815
  # (estimated) number of records output by this operator
@@ -765,10 +852,10 @@ class OperatorCostEstimates:
765
852
  """
766
853
  Multiply all fields by a scalar.
767
854
  """
768
- dct = {field.name: getattr(self, field.name) * multiplier for field in fields(self)}
855
+ dct = {field_name: getattr(self, field_name) * multiplier for field_name in self.model_fields}
769
856
  return OperatorCostEstimates(**dct)
770
857
 
771
- def __post_init__(self):
858
+ def model_post_init(self, __context: Any) -> None:
772
859
  if self.cardinality_lower_bound is None and self.cardinality_upper_bound is None:
773
860
  self.cardinality_lower_bound = self.cardinality
774
861
  self.cardinality_upper_bound = self.cardinality
@@ -786,10 +873,9 @@ class OperatorCostEstimates:
786
873
  self.quality_upper_bound = self.quality
787
874
 
788
875
 
789
- @dataclass
790
- class PlanCost:
876
+ class PlanCost(BaseModel):
791
877
  """
792
- Dataclass for storing the (cost, time, quality) estimates of (sub)-plans and their upper and lower bounds.
878
+ Model for storing the (cost, time, quality) estimates of (sub)-plans and their upper and lower bounds.
793
879
  """
794
880
 
795
881
  # the expression cost
@@ -825,7 +911,16 @@ class PlanCost:
825
911
def __hash__(self):
    """
    Hash on the cost, time, and quality values only; the lower and upper
    bound fields do not contribute.
    """
    hash_key = f"{self.cost}-{self.time}-{self.quality}"
    return hash(hash_key)
827
913
 
828
- def __post_init__(self):
914
def __eq__(self, other: Any) -> bool:
    """
    Two PlanCost objects are equal when their cost, time, and quality values
    match; the bound fields are not compared. Any non-PlanCost object is
    unequal.
    """
    if isinstance(other, PlanCost):
        return self.cost == other.cost and self.time == other.time and self.quality == other.quality

    return False
922
+
923
+ def model_post_init(self, __context: Any) -> None:
829
924
  if self.time_lower_bound is None and self.time_upper_bound is None:
830
925
  self.time_lower_bound = self.time
831
926
  self.time_upper_bound = self.time
@@ -838,30 +933,71 @@ class PlanCost:
838
933
  self.quality_lower_bound = self.quality
839
934
  self.quality_upper_bound = self.quality
840
935
 
936
def join_add(self, left_plan_cost: PlanCost, right_plan_cost: PlanCost, execution_strategy: str = "parallel") -> PlanCost:
    """
    Combine the PlanCost of a join operator (self) with the PlanCosts of its
    two input plans (left_plan_cost and right_plan_cost) into a new PlanCost.

    - cost fields: operator cost plus both input costs.
    - time fields: operator time plus the maximum of the two input times when
      execution_strategy is "parallel"; for any other strategy (treated as
      sequential) operator time plus the sum of the two input times.
    - quality fields: operator quality times the average of the two input
      qualities.

    A field is only set on the result when the operator value and both input
    values are not None.

    NOTE: we currently assume the updating of the op_estimates are handled by the caller
    as there is not a universally correct meaning of addition of op_estimates.
    """
    inputs_overlap = execution_strategy == "parallel"

    def total_cost(op_value, left_value, right_value):
        return op_value + left_value + right_value

    def total_time(op_value, left_value, right_value):
        if inputs_overlap:
            return op_value + max(left_value, right_value)
        return op_value + left_value + right_value

    def joint_quality(op_value, left_value, right_value):
        return op_value * ((left_value + right_value) / 2.0)

    # one combining rule per family of fields, applied in the same order as before
    rules = (
        (("cost", "cost_lower_bound", "cost_upper_bound"), total_cost),
        (("time", "time_lower_bound", "time_upper_bound"), total_time),
        (("quality", "quality_lower_bound", "quality_upper_bound"), joint_quality),
    )

    combined = {}
    for field_names, combine in rules:
        for field_name in field_names:
            op_value = getattr(self, field_name)
            left_value = getattr(left_plan_cost, field_name)
            right_value = getattr(right_plan_cost, field_name)
            if op_value is None or left_value is None or right_value is None:
                continue
            combined[field_name] = combine(op_value, left_value, right_value)

    return PlanCost(**combined)
976
+
841
977
def __iadd__(self, other: PlanCost) -> PlanCost:
    """
    Accumulate `other` into this PlanCost in place and return self.

    cost and time are summed and quality is multiplied. Each cost/time bound
    is summed and each quality bound is multiplied, but only when the bound is
    not None on both objects; otherwise this object's bound is left as-is.

    NOTE: we currently assume the updating of the op_estimates are handled by the caller
    as there is not a universally correct meaning of addition of op_estimates.
    """
    self.cost += other.cost
    self.time += other.time
    self.quality *= other.quality

    summed_bounds = ("cost_lower_bound", "cost_upper_bound", "time_lower_bound", "time_upper_bound")
    multiplied_bounds = ("quality_lower_bound", "quality_upper_bound")
    merge_rules = (
        (summed_bounds, lambda mine, theirs: mine + theirs),
        (multiplied_bounds, lambda mine, theirs: mine * theirs),
    )

    for bound_names, merge in merge_rules:
        for bound_name in bound_names:
            mine = getattr(self, bound_name)
            theirs = getattr(other, bound_name)
            if mine is None or theirs is None:
                continue
            setattr(self, bound_name, merge(mine, theirs))

    return self
860
996
 
861
997
  def __add__(self, other: PlanCost) -> PlanCost:
862
998
  """
863
999
  NOTE: we currently assume the updating of the op_estimates are handled by the caller
864
- as there is not a universally correct meaning of addition of op_estiamtes.
1000
+ as there is not a universally correct meaning of addition of op_estimates.
865
1001
  """
866
1002
  dct = {
867
1003
  field: getattr(self, field) + getattr(other, field)
@@ -874,7 +1010,7 @@ class PlanCost:
874
1010
  "time_upper_bound",
875
1011
  ]
876
1012
  }
877
- for dataclass_field in ["quality", "quality_lower_bound", "quality_upper_bound"]:
878
- dct[dataclass_field] = getattr(self, dataclass_field) * getattr(other, dataclass_field)
1013
+ for model_field in ["quality", "quality_lower_bound", "quality_upper_bound"]:
1014
+ dct[model_field] = getattr(self, model_field) * getattr(other, model_field)
879
1015
 
880
1016
  return PlanCost(**dct)