palimpzest 1.1.0__tar.gz → 1.2.0__tar.gz

This diff compares the contents of two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (101)
  1. {palimpzest-1.1.0/src/palimpzest.egg-info → palimpzest-1.2.0}/PKG-INFO +2 -2
  2. {palimpzest-1.1.0 → palimpzest-1.2.0}/pyproject.toml +2 -2
  3. {palimpzest-1.1.0 → palimpzest-1.2.0}/src/palimpzest/core/models.py +71 -1
  4. {palimpzest-1.1.0 → palimpzest-1.2.0}/src/palimpzest/prompts/prompt_factory.py +15 -5
  5. {palimpzest-1.1.0 → palimpzest-1.2.0}/src/palimpzest/query/execution/execution_strategy.py +2 -0
  6. {palimpzest-1.1.0 → palimpzest-1.2.0}/src/palimpzest/query/execution/mab_execution_strategy.py +12 -5
  7. {palimpzest-1.1.0 → palimpzest-1.2.0}/src/palimpzest/query/operators/convert.py +2 -0
  8. {palimpzest-1.1.0 → palimpzest-1.2.0}/src/palimpzest/query/operators/filter.py +2 -0
  9. {palimpzest-1.1.0 → palimpzest-1.2.0}/src/palimpzest/query/operators/join.py +104 -69
  10. {palimpzest-1.1.0 → palimpzest-1.2.0}/src/palimpzest/query/operators/rag.py +15 -11
  11. {palimpzest-1.1.0 → palimpzest-1.2.0}/src/palimpzest/query/operators/topk.py +24 -5
  12. {palimpzest-1.1.0 → palimpzest-1.2.0}/src/palimpzest/query/optimizer/cost_model.py +9 -4
  13. {palimpzest-1.1.0 → palimpzest-1.2.0}/src/palimpzest/query/processor/config.py +1 -0
  14. {palimpzest-1.1.0 → palimpzest-1.2.0}/src/palimpzest/query/processor/query_processor_factory.py +22 -0
  15. {palimpzest-1.1.0 → palimpzest-1.2.0/src/palimpzest.egg-info}/PKG-INFO +2 -2
  16. {palimpzest-1.1.0 → palimpzest-1.2.0}/src/palimpzest.egg-info/requires.txt +1 -1
  17. {palimpzest-1.1.0 → palimpzest-1.2.0}/LICENSE +0 -0
  18. {palimpzest-1.1.0 → palimpzest-1.2.0}/README.md +0 -0
  19. {palimpzest-1.1.0 → palimpzest-1.2.0}/setup.cfg +0 -0
  20. {palimpzest-1.1.0 → palimpzest-1.2.0}/src/palimpzest/__init__.py +0 -0
  21. {palimpzest-1.1.0 → palimpzest-1.2.0}/src/palimpzest/agents/__init__.py +0 -0
  22. {palimpzest-1.1.0 → palimpzest-1.2.0}/src/palimpzest/agents/compute_agents.py +0 -0
  23. {palimpzest-1.1.0 → palimpzest-1.2.0}/src/palimpzest/agents/search_agents.py +0 -0
  24. {palimpzest-1.1.0 → palimpzest-1.2.0}/src/palimpzest/constants.py +0 -0
  25. {palimpzest-1.1.0 → palimpzest-1.2.0}/src/palimpzest/core/__init__.py +0 -0
  26. {palimpzest-1.1.0 → palimpzest-1.2.0}/src/palimpzest/core/data/__init__.py +0 -0
  27. {palimpzest-1.1.0 → palimpzest-1.2.0}/src/palimpzest/core/data/context.py +0 -0
  28. {palimpzest-1.1.0 → palimpzest-1.2.0}/src/palimpzest/core/data/context_manager.py +0 -0
  29. {palimpzest-1.1.0 → palimpzest-1.2.0}/src/palimpzest/core/data/dataset.py +0 -0
  30. {palimpzest-1.1.0 → palimpzest-1.2.0}/src/palimpzest/core/data/index_dataset.py +0 -0
  31. {palimpzest-1.1.0 → palimpzest-1.2.0}/src/palimpzest/core/data/iter_dataset.py +0 -0
  32. {palimpzest-1.1.0 → palimpzest-1.2.0}/src/palimpzest/core/elements/__init__.py +0 -0
  33. {palimpzest-1.1.0 → palimpzest-1.2.0}/src/palimpzest/core/elements/filters.py +0 -0
  34. {palimpzest-1.1.0 → palimpzest-1.2.0}/src/palimpzest/core/elements/groupbysig.py +0 -0
  35. {palimpzest-1.1.0 → palimpzest-1.2.0}/src/palimpzest/core/elements/records.py +0 -0
  36. {palimpzest-1.1.0 → palimpzest-1.2.0}/src/palimpzest/core/lib/__init__.py +0 -0
  37. {palimpzest-1.1.0 → palimpzest-1.2.0}/src/palimpzest/core/lib/schemas.py +0 -0
  38. {palimpzest-1.1.0 → palimpzest-1.2.0}/src/palimpzest/policy.py +0 -0
  39. {palimpzest-1.1.0 → palimpzest-1.2.0}/src/palimpzest/prompts/__init__.py +0 -0
  40. {palimpzest-1.1.0 → palimpzest-1.2.0}/src/palimpzest/prompts/agent_prompts.py +0 -0
  41. {palimpzest-1.1.0 → palimpzest-1.2.0}/src/palimpzest/prompts/aggregate_prompts.py +0 -0
  42. {palimpzest-1.1.0 → palimpzest-1.2.0}/src/palimpzest/prompts/context_search.py +0 -0
  43. {palimpzest-1.1.0 → palimpzest-1.2.0}/src/palimpzest/prompts/convert_prompts.py +0 -0
  44. {palimpzest-1.1.0 → palimpzest-1.2.0}/src/palimpzest/prompts/critique_and_refine_prompts.py +0 -0
  45. {palimpzest-1.1.0 → palimpzest-1.2.0}/src/palimpzest/prompts/filter_prompts.py +0 -0
  46. {palimpzest-1.1.0 → palimpzest-1.2.0}/src/palimpzest/prompts/join_prompts.py +0 -0
  47. {palimpzest-1.1.0 → palimpzest-1.2.0}/src/palimpzest/prompts/moa_aggregator_prompts.py +0 -0
  48. {palimpzest-1.1.0 → palimpzest-1.2.0}/src/palimpzest/prompts/moa_proposer_prompts.py +0 -0
  49. {palimpzest-1.1.0 → palimpzest-1.2.0}/src/palimpzest/prompts/split_merge_prompts.py +0 -0
  50. {palimpzest-1.1.0 → palimpzest-1.2.0}/src/palimpzest/prompts/split_proposer_prompts.py +0 -0
  51. {palimpzest-1.1.0 → palimpzest-1.2.0}/src/palimpzest/prompts/utils.py +0 -0
  52. {palimpzest-1.1.0 → palimpzest-1.2.0}/src/palimpzest/prompts/validator.py +0 -0
  53. {palimpzest-1.1.0 → palimpzest-1.2.0}/src/palimpzest/query/__init__.py +0 -0
  54. {palimpzest-1.1.0 → palimpzest-1.2.0}/src/palimpzest/query/execution/__init__.py +0 -0
  55. {palimpzest-1.1.0 → palimpzest-1.2.0}/src/palimpzest/query/execution/all_sample_execution_strategy.py +0 -0
  56. {palimpzest-1.1.0 → palimpzest-1.2.0}/src/palimpzest/query/execution/execution_strategy_type.py +0 -0
  57. {palimpzest-1.1.0 → palimpzest-1.2.0}/src/palimpzest/query/execution/parallel_execution_strategy.py +0 -0
  58. {palimpzest-1.1.0 → palimpzest-1.2.0}/src/palimpzest/query/execution/single_threaded_execution_strategy.py +0 -0
  59. {palimpzest-1.1.0 → palimpzest-1.2.0}/src/palimpzest/query/generators/__init__.py +0 -0
  60. {palimpzest-1.1.0 → palimpzest-1.2.0}/src/palimpzest/query/generators/generators.py +0 -0
  61. {palimpzest-1.1.0 → palimpzest-1.2.0}/src/palimpzest/query/operators/__init__.py +0 -0
  62. {palimpzest-1.1.0 → palimpzest-1.2.0}/src/palimpzest/query/operators/aggregate.py +0 -0
  63. {palimpzest-1.1.0 → palimpzest-1.2.0}/src/palimpzest/query/operators/compute.py +0 -0
  64. {palimpzest-1.1.0 → palimpzest-1.2.0}/src/palimpzest/query/operators/critique_and_refine.py +0 -0
  65. {palimpzest-1.1.0 → palimpzest-1.2.0}/src/palimpzest/query/operators/distinct.py +0 -0
  66. {palimpzest-1.1.0 → palimpzest-1.2.0}/src/palimpzest/query/operators/limit.py +0 -0
  67. {palimpzest-1.1.0 → palimpzest-1.2.0}/src/palimpzest/query/operators/logical.py +0 -0
  68. {palimpzest-1.1.0 → palimpzest-1.2.0}/src/palimpzest/query/operators/mixture_of_agents.py +0 -0
  69. {palimpzest-1.1.0 → palimpzest-1.2.0}/src/palimpzest/query/operators/physical.py +0 -0
  70. {palimpzest-1.1.0 → palimpzest-1.2.0}/src/palimpzest/query/operators/project.py +0 -0
  71. {palimpzest-1.1.0 → palimpzest-1.2.0}/src/palimpzest/query/operators/scan.py +0 -0
  72. {palimpzest-1.1.0 → palimpzest-1.2.0}/src/palimpzest/query/operators/search.py +0 -0
  73. {palimpzest-1.1.0 → palimpzest-1.2.0}/src/palimpzest/query/operators/split.py +0 -0
  74. {palimpzest-1.1.0 → palimpzest-1.2.0}/src/palimpzest/query/optimizer/__init__.py +0 -0
  75. {palimpzest-1.1.0 → palimpzest-1.2.0}/src/palimpzest/query/optimizer/optimizer.py +0 -0
  76. {palimpzest-1.1.0 → palimpzest-1.2.0}/src/palimpzest/query/optimizer/optimizer_strategy.py +0 -0
  77. {palimpzest-1.1.0 → palimpzest-1.2.0}/src/palimpzest/query/optimizer/optimizer_strategy_type.py +0 -0
  78. {palimpzest-1.1.0 → palimpzest-1.2.0}/src/palimpzest/query/optimizer/plan.py +0 -0
  79. {palimpzest-1.1.0 → palimpzest-1.2.0}/src/palimpzest/query/optimizer/primitives.py +0 -0
  80. {palimpzest-1.1.0 → palimpzest-1.2.0}/src/palimpzest/query/optimizer/rules.py +0 -0
  81. {palimpzest-1.1.0 → palimpzest-1.2.0}/src/palimpzest/query/optimizer/tasks.py +0 -0
  82. {palimpzest-1.1.0 → palimpzest-1.2.0}/src/palimpzest/query/processor/__init__.py +0 -0
  83. {palimpzest-1.1.0 → palimpzest-1.2.0}/src/palimpzest/query/processor/query_processor.py +0 -0
  84. {palimpzest-1.1.0 → palimpzest-1.2.0}/src/palimpzest/schemabuilder/__init__.py +0 -0
  85. {palimpzest-1.1.0 → palimpzest-1.2.0}/src/palimpzest/schemabuilder/schema_builder.py +0 -0
  86. {palimpzest-1.1.0 → palimpzest-1.2.0}/src/palimpzest/tools/README.md +0 -0
  87. {palimpzest-1.1.0 → palimpzest-1.2.0}/src/palimpzest/tools/__init__.py +0 -0
  88. {palimpzest-1.1.0 → palimpzest-1.2.0}/src/palimpzest/tools/allenpdf.py +0 -0
  89. {palimpzest-1.1.0 → palimpzest-1.2.0}/src/palimpzest/tools/pdfparser.py +0 -0
  90. {palimpzest-1.1.0 → palimpzest-1.2.0}/src/palimpzest/tools/skema_tools.py +0 -0
  91. {palimpzest-1.1.0 → palimpzest-1.2.0}/src/palimpzest/utils/__init__.py +0 -0
  92. {palimpzest-1.1.0 → palimpzest-1.2.0}/src/palimpzest/utils/env_helpers.py +0 -0
  93. {palimpzest-1.1.0 → palimpzest-1.2.0}/src/palimpzest/utils/hash_helpers.py +0 -0
  94. {palimpzest-1.1.0 → palimpzest-1.2.0}/src/palimpzest/utils/model_helpers.py +0 -0
  95. {palimpzest-1.1.0 → palimpzest-1.2.0}/src/palimpzest/utils/progress.py +0 -0
  96. {palimpzest-1.1.0 → palimpzest-1.2.0}/src/palimpzest/utils/udfs.py +0 -0
  97. {palimpzest-1.1.0 → palimpzest-1.2.0}/src/palimpzest/validator/__init__.py +0 -0
  98. {palimpzest-1.1.0 → palimpzest-1.2.0}/src/palimpzest/validator/validator.py +0 -0
  99. {palimpzest-1.1.0 → palimpzest-1.2.0}/src/palimpzest.egg-info/SOURCES.txt +0 -0
  100. {palimpzest-1.1.0 → palimpzest-1.2.0}/src/palimpzest.egg-info/dependency_links.txt +0 -0
  101. {palimpzest-1.1.0 → palimpzest-1.2.0}/src/palimpzest.egg-info/top_level.txt +0 -0
{palimpzest-1.1.0/src/palimpzest.egg-info → palimpzest-1.2.0}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: palimpzest
-Version: 1.1.0
+Version: 1.2.0
 Summary: Palimpzest is a system which enables anyone to process AI-powered analytical queries simply by defining them in a declarative language
 Author-email: MIT DSG Semantic Management Lab <michjc@csail.mit.edu>
 Project-URL: homepage, https://palimpzest.org
@@ -31,7 +31,7 @@ Requires-Dist: pillow>=11.3.0
 Requires-Dist: prettytable>=3.9.0
 Requires-Dist: psutil==5.9.5
 Requires-Dist: PyLD>=2.0.4
-Requires-Dist: pyarrow==20.0.0
+Requires-Dist: pyarrow>=20.0.0
 Requires-Dist: pypdf>=5.1.0
 Requires-Dist: pytest-mock>=3.14.0
 Requires-Dist: pyyaml>=6.0.1
{palimpzest-1.1.0 → palimpzest-1.2.0}/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "palimpzest"
-version = "1.1.0"
+version = "1.2.0"
 description = "Palimpzest is a system which enables anyone to process AI-powered analytical queries simply by defining them in a declarative language"
 readme = "README.md"
 requires-python = ">=3.12"
@@ -25,7 +25,7 @@ dependencies = [
     "prettytable>=3.9.0",
     "psutil==5.9.5",
     "PyLD>=2.0.4",
-    "pyarrow==20.0.0",
+    "pyarrow>=20.0.0",
     "pypdf>=5.1.0",
     "pytest-mock>=3.14.0",
     "pyyaml>=6.0.1",
{palimpzest-1.1.0 → palimpzest-1.2.0}/src/palimpzest/core/models.py
@@ -35,12 +35,18 @@ class GenerationStats(BaseModel):
     # typed as a float because GenerationStats may be amortized (i.e. divided) across a number of output records
     total_output_tokens: float = 0.0
 
+    # the total number of input tokens processed by embedding models
+    total_embedding_input_tokens: float = 0.0
+
     # the total cost of processing the input tokens; None if this operation did not use an LLM
     total_input_cost: float = 0.0
 
     # the total cost of processing the output tokens; None if this operation did not use an LLM
     total_output_cost: float = 0.0
 
+    # the total cost of processing input tokens for embedding models
+    total_embedding_cost: float = 0.0
+
     # the total cost of processing the input and output tokens; None if this operation did not use an LLM
     cost_per_record: float = 0.0
 
@@ -68,6 +74,9 @@ class GenerationStats(BaseModel):
             "fn_call_duration_secs",
             "total_llm_calls",
             "total_embedding_llm_calls",
+            "total_embedding_input_tokens",
+            "total_embedding_cost"
+
         ]:
             setattr(self, model_field, getattr(self, model_field) + getattr(other, model_field))
         return self
@@ -85,6 +94,8 @@ class GenerationStats(BaseModel):
                 "cost_per_record",
                 "total_llm_calls",
                 "total_embedding_llm_calls",
+                "total_embedding_input_tokens",
+                "total_embedding_cost"
             ]
         }
         # dct['raw_answers'] = self.raw_answers + other.raw_answers
@@ -107,6 +118,8 @@ class GenerationStats(BaseModel):
             "fn_call_duration_secs",
             "total_llm_calls",
             "total_embedding_llm_calls",
+            "total_embedding_input_tokens",
+            "total_embedding_cost"
         ]:
             setattr(self, model_field, getattr(self, model_field) / quotient)
         return self
@@ -128,6 +141,8 @@ class GenerationStats(BaseModel):
                 "total_llm_calls",
                 "total_embedding_llm_calls",
                 "cost_per_record",
+                "total_embedding_input_tokens",
+                "total_embedding_cost"
             ]
         }
         dct["model_name"] = self.model_name
src/palimpzest/core/models.py (continued)
@@ -217,6 +232,10 @@ class RecordOpStats(BaseModel):
     # typed as a float because GenerationStats may be amortized (i.e. divided) across a number of output records
     total_output_tokens: float = 0.0
 
+    # the total number of input tokens processed by embedding models
+    # typed as a float because GenerationStats may be amortized (i.e. divided) across a number of output records
+    total_embedding_input_tokens: float = 0.0
+
     # the total cost of processing the input tokens; None if this operation did not use an LLM
     total_input_cost: float = 0.0
 
@@ -278,6 +297,9 @@ class OperatorStats(BaseModel):
     # the total output tokens processed by this operation
     total_output_tokens: int = 0
 
+    # the total embedding input tokens processed by this operation
+    total_embedding_input_tokens: int = 0
+
     # a list of RecordOpStats processed by the operation
     record_op_stats_lst: list[RecordOpStats] = Field(default_factory=list)
 
@@ -309,6 +331,7 @@ class OperatorStats(BaseModel):
             self.total_op_cost += stats.total_op_cost
             self.total_input_tokens += stats.total_input_tokens
             self.total_output_tokens += stats.total_output_tokens
+            self.total_embedding_input_tokens += stats.total_embedding_input_tokens
             self.record_op_stats_lst.extend(stats.record_op_stats_lst)
 
         elif isinstance(stats, RecordOpStats):
@@ -319,6 +342,7 @@ class OperatorStats(BaseModel):
             self.total_op_cost += stats.cost_per_record
             self.total_input_tokens += stats.total_input_tokens
             self.total_output_tokens += stats.total_output_tokens
+            self.total_embedding_input_tokens += stats.total_embedding_input_tokens
 
         else:
             raise TypeError(f"Cannot add {type(stats)} to OperatorStats")
src/palimpzest/core/models.py (continued)
@@ -370,6 +394,9 @@ class BasePlanStats(BaseModel):
     # total output tokens processed by this plan
     total_output_tokens: int = 0
 
+    # total embedding input tokens processed by this plan
+    total_embedding_input_tokens: int = 0
+
     # start time for the plan execution; should be set by calling PlanStats.start()
     start_time: float | None = None
 
@@ -385,6 +412,7 @@ class BasePlanStats(BaseModel):
         self.total_plan_cost = self.sum_op_costs() + self.sum_validation_costs()
         self.total_input_tokens = self.sum_input_tokens() + self.sum_validation_input_tokens()
         self.total_output_tokens = self.sum_output_tokens() + self.sum_validation_output_tokens()
+        self.total_embedding_input_tokens = self.sum_embedding_input_tokens() + self.sum_validation_embedding_input_tokens()
 
     @staticmethod
     @abstractmethod
@@ -415,6 +443,13 @@ class BasePlanStats(BaseModel):
         """
         pass
 
+    @abstractmethod
+    def sum_embedding_input_tokens(self) -> int:
+        """
+        Sum the input embedding tokens processed by all operators in this plan.
+        """
+        pass
+
     @abstractmethod
     def add_record_op_stats(self, unique_full_op_id: str, record_op_stats: RecordOpStats | list[RecordOpStats]) -> None:
         """
@@ -453,6 +488,12 @@ class BasePlanStats(BaseModel):
         Sum the output tokens processed by all validation generations in this plan.
         """
         return sum([gen_stats.total_output_tokens for _, gen_stats in self.validation_gen_stats.items()])
+
+    def sum_validation_embedding_input_tokens(self) -> int:
+        """
+        Sum the input embedding tokens processed by all validation generations in this plan.
+        """
+        return sum([gen_stats.total_embedding_input_tokens for _, gen_stats in self.validation_gen_stats.items()])
 
     def get_total_cost_so_far(self) -> float:
         """
@@ -501,6 +542,12 @@ class PlanStats(BasePlanStats):
         Sum the output tokens processed by all operators in this plan.
         """
         return sum([op_stats.total_output_tokens for _, op_stats in self.operator_stats.items()])
+
+    def sum_embedding_input_tokens(self) -> int:
+        """
+        Sum the input embedding tokens processed by all operators in this plan.
+        """
+        return sum([op_stats.total_embedding_input_tokens for _, op_stats in self.operator_stats.items()])
 
     def add_record_op_stats(self, unique_full_op_id: str, record_op_stats: RecordOpStats | list[RecordOpStats]) -> None:
         """
@@ -528,6 +575,7 @@ class PlanStats(BasePlanStats):
         self.total_plan_cost += plan_stats.total_plan_cost
         self.total_input_tokens += plan_stats.total_input_tokens
         self.total_output_tokens += plan_stats.total_output_tokens
+        self.total_embedding_input_tokens += plan_stats.total_embedding_input_tokens
         for unique_full_op_id, op_stats in plan_stats.operator_stats.items():
             if unique_full_op_id in self.operator_stats:
                 self.operator_stats[unique_full_op_id] += op_stats
@@ -539,6 +587,7 @@ class PlanStats(BasePlanStats):
         stats += f"total_plan_cost={self.total_plan_cost} \n"
         stats += f"total_input_tokens={self.total_input_tokens} \n"
         stats += f"total_output_tokens={self.total_output_tokens} \n"
+        stats += f"total_embedding_input_tokens={self.total_embedding_input_tokens} \n"
         for idx, op_stats in enumerate(self.operator_stats.values()):
             stats += f"{idx}. {op_stats.op_name} time={op_stats.total_op_time} cost={op_stats.total_op_cost} \n"
         return stats
@@ -586,6 +635,12 @@ class SentinelPlanStats(BasePlanStats):
         Sum the output tokens processed by all operators in this plan.
         """
         return sum(sum([op_stats.total_output_tokens for _, op_stats in phys_op_stats.items()]) for _, phys_op_stats in self.operator_stats.items())
+
+    def sum_embedding_input_tokens(self) -> int:
+        """
+        Sum the input embedding tokens processed by all operators in this plan.
+        """
+        return sum(sum([op_stats.total_embedding_input_tokens for _, op_stats in phys_op_stats.items()]) for _, phys_op_stats in self.operator_stats.items())
 
     def add_record_op_stats(self, unique_logical_op_id: str, record_op_stats: RecordOpStats | list[RecordOpStats]) -> None:
         """
@@ -627,6 +682,7 @@ class SentinelPlanStats(BasePlanStats):
         self.total_plan_cost += plan_stats.total_plan_cost
         self.total_input_tokens += plan_stats.total_input_tokens
         self.total_output_tokens += plan_stats.total_output_tokens
+        self.total_embedding_input_tokens += plan_stats.total_embedding_input_tokens
         for unique_logical_op_id, physical_op_stats in plan_stats.operator_stats.items():
             for full_op_id, op_stats in physical_op_stats.items():
                 if unique_logical_op_id in self.operator_stats:
@@ -648,6 +704,7 @@ class SentinelPlanStats(BasePlanStats):
         stats += f"total_plan_cost={self.total_plan_cost} \n"
         stats += f"total_input_tokens={self.total_input_tokens} \n"
         stats += f"total_output_tokens={self.total_output_tokens} \n"
+        stats += f"total_embedding_input_tokens={self.total_embedding_input_tokens} \n"
         for outer_idx, physical_op_stats in enumerate(self.operator_stats.values()):
             total_time = sum([op_stats.total_op_time for op_stats in physical_op_stats.values()])
             total_cost = sum([op_stats.total_op_cost for op_stats in physical_op_stats.values()])
@@ -695,6 +752,9 @@ class ExecutionStats(BaseModel):
     # total number of output tokens processed
     total_output_tokens: int = 0
 
+    # total number of embedding input tokens processed
+    total_embedding_input_tokens: int = 0
+
     # total number of tokens processed
     total_tokens: int = 0
 
@@ -748,7 +808,8 @@ class ExecutionStats(BaseModel):
         # compute the tokens for total execution
         self.total_input_tokens = self.sum_input_tokens()
         self.total_output_tokens = self.sum_output_tokens()
-        self.total_tokens = self.total_input_tokens + self.total_output_tokens
+        self.total_embedding_input_tokens = self.sum_embedding_input_tokens()
+        self.total_tokens = self.total_input_tokens + self.total_output_tokens + self.total_embedding_input_tokens
 
         # compute plan_strs
         self.plan_strs = {plan_id: plan_stats.plan_str for plan_id, plan_stats in self.plan_stats.items()}
@@ -780,6 +841,15 @@ class ExecutionStats(BaseModel):
         sentinel_plan_output_tokens = sum([plan_stats.sum_output_tokens() for _, plan_stats in self.sentinel_plan_stats.items()])
         plan_output_tokens = sum([plan_stats.sum_output_tokens() for _, plan_stats in self.plan_stats.items()])
         return plan_output_tokens + sentinel_plan_output_tokens
+
+
+    def sum_embedding_input_tokens(self) -> int:
+        """
+        Sum the embedding input tokens processed in this execution
+        """
+        sentinel_plan_embedding_input_tokens = sum([plan_stats.sum_embedding_input_tokens() for _, plan_stats in self.sentinel_plan_stats.items()])
+        plan_embedding_input_tokens = sum([plan_stats.sum_embedding_input_tokens() for _, plan_stats in self.plan_stats.items()])
+        return plan_embedding_input_tokens + sentinel_plan_embedding_input_tokens
 
     def add_plan_stats(self, plan_stats: PlanStats | SentinelPlanStats | list[PlanStats] | list[SentinelPlanStats]) -> None:
         """
{palimpzest-1.1.0 → palimpzest-1.2.0}/src/palimpzest/prompts/prompt_factory.py
@@ -830,7 +830,7 @@ class PromptFactory:
             field_type = dr.get_field_type(field_name)
 
             # audio filepath (or list of audio filepaths)
-            if field_type.annotation in [AudioFilepath, AudioFilepath | None, AudioFilepath | Any]:
+            if field_type.annotation in [AudioFilepath, AudioFilepath | None, AudioFilepath | Any] and field_value is not None:
                 with open(field_value, "rb") as f:
                     base64_audio_str = base64.b64encode(f.read()).decode("utf-8")
                 audio_content.append(
@@ -839,6 +839,8 @@ class PromptFactory:
 
             elif field_type.annotation in [list[AudioFilepath], list[AudioFilepath] | None, list[AudioFilepath] | Any]:
                 for audio_filepath in field_value:
+                    if audio_filepath is None:
+                        continue
                     with open(audio_filepath, "rb") as f:
                         base64_audio_str = base64.b64encode(f.read()).decode("utf-8")
                     audio_content.append(
@@ -846,13 +848,15 @@ class PromptFactory:
                     )
 
             # pre-encoded audio (or list of pre-encoded audio)
-            elif field_type.annotation in [AudioBase64, AudioBase64 | None, AudioBase64 | Any]:
+            elif field_type.annotation in [AudioBase64, AudioBase64 | None, AudioBase64 | Any] and field_value is not None:
                 audio_content.append(
                     {"type": "input_audio", "input_audio": {"data": field_value, "format": "wav"}}
                 )
 
             elif field_type.annotation in [list[AudioBase64], list[AudioBase64] | None, list[AudioBase64] | Any]:
                 for base64_audio in field_value:
+                    if base64_audio is None:
+                        continue
                     audio_content.append(
                         {"type": "input_audio", "input_audio": {"data": base64_audio, "format": "wav"}}
                     )
@@ -882,7 +886,7 @@ class PromptFactory:
             field_type = dr.get_field_type(field_name)
 
             # image filepath (or list of image filepaths)
-            if field_type.annotation in [ImageFilepath, ImageFilepath | None, ImageFilepath | Any]:
+            if field_type.annotation in [ImageFilepath, ImageFilepath | None, ImageFilepath | Any] and field_value is not None:
                 with open(field_value, "rb") as f:
                     base64_image_str = base64.b64encode(f.read()).decode("utf-8")
                 image_content.append(
@@ -891,6 +895,8 @@ class PromptFactory:
 
             elif field_type.annotation in [list[ImageFilepath], list[ImageFilepath] | None, list[ImageFilepath] | Any]:
                 for image_filepath in field_value:
+                    if image_filepath is None:
+                        continue
                     with open(image_filepath, "rb") as f:
                         base64_image_str = base64.b64encode(f.read()).decode("utf-8")
                     image_content.append(
@@ -898,21 +904,25 @@ class PromptFactory:
                     )
 
             # image url (or list of image urls)
-            elif field_type.annotation in [ImageURL, ImageURL | None, ImageURL | Any]:
+            elif field_type.annotation in [ImageURL, ImageURL | None, ImageURL | Any] and field_value is not None:
                 image_content.append({"type": "image_url", "image_url": {"url": field_value}})
 
             elif field_type.annotation in [list[ImageURL], list[ImageURL] | None, list[ImageURL] | Any]:
                 for image_url in field_value:
+                    if image_url is None:
+                        continue
                     image_content.append({"type": "image_url", "image_url": {"url": image_url}})
 
             # pre-encoded images (or list of pre-encoded images)
-            elif field_type.annotation in [ImageBase64, ImageBase64 | None, ImageBase64 | Any]:
+            elif field_type.annotation in [ImageBase64, ImageBase64 | None, ImageBase64 | Any] and field_value is not None:
                 image_content.append(
                     {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{field_value}"}}
                 )
 
             elif field_type.annotation in [list[ImageBase64], list[ImageBase64] | None, list[ImageBase64] | Any]:
                 for base64_image in field_value:
+                    if base64_image is None:
+                        continue
                     image_content.append(
                         {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"}}
                     )
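All of these hunks apply one defensive pattern: fields typed `X | None` may legitimately hold None (or a list containing None), and 1.1.0 passed such values straight into open() or the message payload, raising a TypeError. A standalone sketch of the guard (not palimpzest code):

# Standalone sketch of the None-guard pattern added in 1.2.0.
import base64

def encode_images(image_filepaths: list[str | None]) -> list[dict]:
    image_content = []
    for image_filepath in image_filepaths:
        if image_filepath is None:  # 1.1.0 would call open(None) here
            continue
        with open(image_filepath, "rb") as f:
            b64 = base64.b64encode(f.read()).decode("utf-8")
        image_content.append({"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{b64}"}})
    return image_content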
{palimpzest-1.1.0 → palimpzest-1.2.0}/src/palimpzest/query/execution/execution_strategy.py
@@ -91,6 +91,7 @@ class SentinelExecutionStrategy(BaseExecutionStrategy, ABC):
         use_final_op_quality: bool = False,
         seed: int = 42,
         exp_name: str | None = None,
+        dont_use_priors: bool = False,
         *args,
         **kwargs,
     ):
@@ -105,6 +106,7 @@ class SentinelExecutionStrategy(BaseExecutionStrategy, ABC):
         self.seed = seed
         self.rng = np.random.default_rng(seed=seed)
         self.exp_name = exp_name
+        self.dont_use_priors = dont_use_priors
 
         # general cache which maps hash(logical_op_id, phys_op_id, hash(input)) --> record_set
         self.cache: dict[int, DataRecordSet] = {}
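Sentinel execution strategies now accept a dont_use_priors flag. The one-line change to query/processor/config.py and the additions to query_processor_factory.py (hunks not shown in this diff) suggest the flag is surfaced through the processor config; treat the sketch below as an assumption on that basis:

# Assumption: config.py's one added line exposes dont_use_priors on
# QueryProcessorConfig; that hunk is not included in this diff.
from palimpzest.query.processor.config import QueryProcessorConfig

config = QueryProcessorConfig(dont_use_priors=True)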
{palimpzest-1.1.0 → palimpzest-1.2.0}/src/palimpzest/query/execution/mab_execution_strategy.py
@@ -44,6 +44,7 @@ class OpFrontier:
         seed: int,
         policy: Policy,
         priors: dict | None = None,
+        dont_use_priors: bool = False,
     ):
         # set k and j, which are the initial number of operators in the frontier and the
         # initial number of records to sample for each frontier operator
@@ -51,6 +52,7 @@ class OpFrontier:
         self.j = j
         self.source_indices = source_indices
         self.root_dataset_ids = root_dataset_ids
+        self.dont_use_priors = dont_use_priors
 
         # store the policy that we are optimizing under
         self.policy = policy
@@ -68,6 +70,7 @@ class OpFrontier:
         is_llm_filter = isinstance(sample_op, LLMFilter)
         is_llm_topk = isinstance(sample_op, TopKOp) and isinstance(sample_op.index, Collection)
         self.is_llm_op = is_llm_convert or is_llm_filter or is_llm_topk or self.is_llm_join
+        self.is_llm_convert = is_llm_convert
 
         # get order in which we will sample physical operators for this logical operator
         sample_op_indices = self._get_op_index_order(op_set, seed)
@@ -190,7 +193,9 @@ class OpFrontier:
         Returns a list of indices for the operators in the op_set.
         """
         # if this is not an llm-operator, we simply return the indices in random order
-        if not self.is_llm_op:
+        if not self.is_llm_op or self.dont_use_priors:
+            if self.is_llm_convert:
+                print("Using NO PRIORS for operator sampling order")
             rng = np.random.default_rng(seed=seed)
             op_indices = np.arange(len(op_set))
             rng.shuffle(op_indices)
@@ -198,6 +203,8 @@ class OpFrontier:
 
         # if this is an llm-operator, but we do not have priors, we first compute naive priors
         if self.priors is None or any([op_id not in self.priors for op_id in map(lambda op: op.get_op_id(), op_set)]):
+            if self.is_llm_convert:
+                print("Using NAIVE PRIORS for operator sampling order")
             self.priors = self._compute_naive_priors(op_set)
 
         # NOTE: self.priors is a dictionary with format:
@@ -770,7 +777,7 @@ class MABExecutionStrategy(SentinelExecutionStrategy):
 
         # if the operator is a non-llm filter which has filtered out records, remove those records from
        # all downstream operators' full_op_id_to_sources_not_processed
-        if isinstance(op_set[0], NonLLMFilter):
+        if isinstance(op_set[0], NonLLMFilter) and next_unique_logical_op_id is not None:
            self._remove_filtered_records_from_downstream_ops(topo_idx, plan, op_frontiers, source_indices_to_all_record_sets)
 
         # finalize plan stats
@@ -805,7 +812,7 @@ class MABExecutionStrategy(SentinelExecutionStrategy):
             assert len(root_dataset_ids) == 1, f"Scan for {sample_op} has {len(root_dataset_ids)} > 1 root dataset ids"
             root_dataset_id = root_dataset_ids[0]
             source_indices = dataset_id_to_shuffled_source_indices[root_dataset_id]
-            op_frontiers[unique_logical_op_id] = OpFrontier(op_set, source_unique_logical_op_ids, root_dataset_ids, source_indices, self.k, self.j, self.seed, self.policy, self.priors)
+            op_frontiers[unique_logical_op_id] = OpFrontier(op_set, source_unique_logical_op_ids, root_dataset_ids, source_indices, self.k, self.j, self.seed, self.policy, self.priors, self.dont_use_priors)
         elif isinstance(sample_op, JoinOp):
             assert len(source_unique_logical_op_ids) == 2, f"Join for {sample_op} has {len(source_unique_logical_op_ids)} != 2 source logical operators"
             left_source_indices = op_frontiers[source_unique_logical_op_ids[0]].source_indices
@@ -814,10 +821,10 @@ class MABExecutionStrategy(SentinelExecutionStrategy):
             for left_source_idx in left_source_indices:
                 for right_source_idx in right_source_indices:
                     source_indices.append((left_source_idx, right_source_idx))
-            op_frontiers[unique_logical_op_id] = OpFrontier(op_set, source_unique_logical_op_ids, root_dataset_ids, source_indices, self.k, self.j, self.seed, self.policy, self.priors)
+            op_frontiers[unique_logical_op_id] = OpFrontier(op_set, source_unique_logical_op_ids, root_dataset_ids, source_indices, self.k, self.j, self.seed, self.policy, self.priors, self.dont_use_priors)
         else:
             source_indices = op_frontiers[source_unique_logical_op_ids[0]].source_indices
-            op_frontiers[unique_logical_op_id] = OpFrontier(op_set, source_unique_logical_op_ids, root_dataset_ids, source_indices, self.k, self.j, self.seed, self.policy, self.priors)
+            op_frontiers[unique_logical_op_id] = OpFrontier(op_set, source_unique_logical_op_ids, root_dataset_ids, source_indices, self.k, self.j, self.seed, self.policy, self.priors, self.dont_use_priors)
 
         # initialize and start the progress manager
         self.progress_manager = create_progress_manager(plan, sample_budget=self.sample_budget, sample_cost_budget=self.sample_cost_budget, progress=self.progress)
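The net effect: with dont_use_priors=True, LLM operators fall back to the seeded random ordering that non-LLM operators always used, instead of ordering by naive or supplied priors. A self-contained sketch of that fallback, mirroring the shuffle in the hunk above:

# Self-contained sketch of the fallback operator ordering.
import numpy as np

def sample_order(num_ops: int, seed: int, is_llm_op: bool, dont_use_priors: bool) -> list[int]:
    if not is_llm_op or dont_use_priors:
        rng = np.random.default_rng(seed=seed)
        op_indices = np.arange(num_ops)
        rng.shuffle(op_indices)
        return op_indices.tolist()  # deterministic permutation for a given seed
    raise NotImplementedError("prior-based ordering is omitted from this sketch")

print(sample_order(4, seed=42, is_llm_op=True, dont_use_priors=True))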
{palimpzest-1.1.0 → palimpzest-1.2.0}/src/palimpzest/query/operators/convert.py
@@ -121,8 +121,10 @@ class ConvertOp(PhysicalOperator, ABC):
             generated_fields=field_names,
             total_input_tokens=per_record_stats.total_input_tokens,
             total_output_tokens=per_record_stats.total_output_tokens,
+            total_embedding_input_tokens=per_record_stats.total_embedding_input_tokens,
             total_input_cost=per_record_stats.total_input_cost,
             total_output_cost=per_record_stats.total_output_cost,
+            total_embedding_cost=per_record_stats.total_embedding_cost,
             llm_call_duration_secs=per_record_stats.llm_call_duration_secs,
             fn_call_duration_secs=per_record_stats.fn_call_duration_secs,
             total_llm_calls=per_record_stats.total_llm_calls,
{palimpzest-1.1.0 → palimpzest-1.2.0}/src/palimpzest/query/operators/filter.py
@@ -89,8 +89,10 @@ class FilterOp(PhysicalOperator, ABC):
             filter_str=self.filter_obj.get_filter_str(),
             total_input_tokens=generation_stats.total_input_tokens,
             total_output_tokens=generation_stats.total_output_tokens,
+            total_embedding_input_tokens=generation_stats.total_embedding_input_tokens,
             total_input_cost=generation_stats.total_input_cost,
             total_output_cost=generation_stats.total_output_cost,
+            total_embedding_cost=generation_stats.total_embedding_cost,
             llm_call_duration_secs=generation_stats.llm_call_duration_secs,
             fn_call_duration_secs=generation_stats.fn_call_duration_secs,
             total_llm_calls=generation_stats.total_llm_calls,
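Convert and filter operators now copy the embedding fields from their generation stats into each record's stats, so embedding spend can be attributed record by record. An illustrative helper (it assumes RecordOpStats carries a record_id field, which this diff does not show):

# Illustrative: attribute embedding input tokens per record.
from palimpzest.core.models import RecordOpStats

def embedding_tokens_by_record(record_op_stats: list[RecordOpStats]) -> dict[str, float]:
    # record_id is an assumption; total_embedding_input_tokens comes from the hunks above
    return {r.record_id: r.total_embedding_input_tokens for r in record_op_stats}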