palimpzest 0.7.7__py3-none-any.whl → 0.7.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39)
  1. palimpzest/constants.py +113 -75
  2. palimpzest/core/data/dataclasses.py +55 -38
  3. palimpzest/core/elements/index.py +5 -15
  4. palimpzest/core/elements/records.py +1 -1
  5. palimpzest/prompts/prompt_factory.py +1 -1
  6. palimpzest/query/execution/all_sample_execution_strategy.py +216 -0
  7. palimpzest/query/execution/execution_strategy.py +4 -4
  8. palimpzest/query/execution/execution_strategy_type.py +7 -1
  9. palimpzest/query/execution/mab_execution_strategy.py +184 -72
  10. palimpzest/query/execution/parallel_execution_strategy.py +182 -15
  11. palimpzest/query/execution/single_threaded_execution_strategy.py +21 -21
  12. palimpzest/query/generators/api_client_factory.py +6 -7
  13. palimpzest/query/generators/generators.py +5 -8
  14. palimpzest/query/operators/aggregate.py +4 -3
  15. palimpzest/query/operators/convert.py +1 -1
  16. palimpzest/query/operators/filter.py +1 -1
  17. palimpzest/query/operators/limit.py +1 -1
  18. palimpzest/query/operators/map.py +1 -1
  19. palimpzest/query/operators/physical.py +8 -4
  20. palimpzest/query/operators/project.py +1 -1
  21. palimpzest/query/operators/retrieve.py +7 -23
  22. palimpzest/query/operators/scan.py +1 -1
  23. palimpzest/query/optimizer/cost_model.py +54 -62
  24. palimpzest/query/optimizer/optimizer.py +2 -6
  25. palimpzest/query/optimizer/plan.py +4 -4
  26. palimpzest/query/optimizer/primitives.py +1 -1
  27. palimpzest/query/optimizer/rules.py +8 -26
  28. palimpzest/query/optimizer/tasks.py +3 -3
  29. palimpzest/query/processor/processing_strategy_type.py +2 -2
  30. palimpzest/query/processor/sentinel_processor.py +0 -2
  31. palimpzest/sets.py +2 -3
  32. palimpzest/utils/generation_helpers.py +1 -1
  33. palimpzest/utils/model_helpers.py +27 -9
  34. palimpzest/utils/progress.py +81 -72
  35. {palimpzest-0.7.7.dist-info → palimpzest-0.7.9.dist-info}/METADATA +4 -2
  36. {palimpzest-0.7.7.dist-info → palimpzest-0.7.9.dist-info}/RECORD +39 -38
  37. {palimpzest-0.7.7.dist-info → palimpzest-0.7.9.dist-info}/WHEEL +1 -1
  38. {palimpzest-0.7.7.dist-info → palimpzest-0.7.9.dist-info}/licenses/LICENSE +0 -0
  39. {palimpzest-0.7.7.dist-info → palimpzest-0.7.9.dist-info}/top_level.txt +0 -0
palimpzest/constants.py CHANGED
@@ -10,21 +10,69 @@ class Model(str, Enum):
     which requires invoking an LLM. It does NOT specify whether the model need be executed
     remotely or locally (if applicable).
     """
-    # LLAMA3 = "meta-llama/Llama-3-8b-chat-hf"
-    LLAMA3 = "meta-llama/Llama-3.3-70B-Instruct-Turbo"
-    LLAMA3_V = "meta-llama/Llama-3.2-90B-Vision-Instruct-Turbo"
+    LLAMA3_2_3B = "meta-llama/Llama-3.2-3B-Instruct-Turbo"
+    LLAMA3_1_8B = "meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo"
+    LLAMA3_3_70B = "meta-llama/Llama-3.3-70B-Instruct-Turbo"
+    LLAMA3_2_90B_V = "meta-llama/Llama-3.2-90B-Vision-Instruct-Turbo"
     MIXTRAL = "mistralai/Mixtral-8x7B-Instruct-v0.1"
-    DEEPSEEK = "deepseek-ai/DeepSeek-V3"
+    DEEPSEEK_V3 = "deepseek-ai/DeepSeek-V3"
+    DEEPSEEK_R1_DISTILL_QWEN_1_5B = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
     GPT_4o = "gpt-4o-2024-08-06"
-    GPT_4o_V = "gpt-4o-2024-08-06"
     GPT_4o_MINI = "gpt-4o-mini-2024-07-18"
-    GPT_4o_MINI_V = "gpt-4o-mini-2024-07-18"
     TEXT_EMBEDDING_3_SMALL = "text-embedding-3-small"
     CLIP_VIT_B_32 = "clip-ViT-B-32"
+    # o1 = "o1-2024-12-17"
 
     def __repr__(self):
         return f"{self.name}"
 
+    def is_deepseek_model(self):
+        return "deepseek" in self.value.lower()
+
+    def is_llama_model(self):
+        return "llama" in self.value.lower()
+
+    def is_mixtral_model(self):
+        return "mixtral" in self.value.lower()
+
+    def is_clip_model(self):
+        return "clip" in self.value.lower()
+
+    def is_together_model(self):
+        is_llama_model = self.is_llama_model()
+        is_mixtral_model = self.is_mixtral_model()
+        is_deepseek_model = self.is_deepseek_model()
+        is_clip_model = self.is_clip_model()
+        return is_llama_model or is_mixtral_model or is_deepseek_model or is_clip_model
+
+    def is_gpt_4o_model(self):
+        return "gpt-4o" in self.value.lower()
+
+    def is_o1_model(self):
+        return "o1" in self.value.lower()
+
+    def is_text_embedding_model(self):
+        return "text-embedding" in self.value.lower()
+
+    def is_openai_model(self):
+        is_gpt4_model = self.is_gpt_4o_model()
+        is_o1_model = self.is_o1_model()
+        is_text_embedding_model = self.is_text_embedding_model()
+        return is_gpt4_model or is_o1_model or is_text_embedding_model
+
+    def is_vision_model(self):
+        vision_models = [
+            "meta-llama/Llama-3.2-90B-Vision-Instruct-Turbo",
+            "gpt-4o-2024-08-06",
+            "gpt-4o-mini-2024-07-18",
+            "o1-2024-12-17",
+        ]
+        return self.value in vision_models
+
+    def is_embedding_model(self):
+        is_clip_model = self.is_clip_model()
+        is_text_embedding_model = self.is_text_embedding_model()
+        return is_clip_model or is_text_embedding_model
 
 
 class APIClient(str, Enum):
     """
@@ -194,23 +242,10 @@ LOG_LLM_OUTPUT = False
 
 
 #### MODEL PERFORMANCE & COST METRICS ####
-# I've looked across models and grouped knowledge into commonly used categories:
-# - Agg. Benchmark (we only use MMLU for this)
-# - Commonsense Reasoning
-# - World Knowledge
-# - Reading Comprehension
-# - Code
-# - Math
-#
-# We don't have global overlap on the World Knowledge and/or Reading Comprehension
-# datasets. Thus, we include these categories results where we have them, but they
-# are completely omitted for now.
-#
-# Within each category only certain models have overlapping results on the same
-# individual datasets; in order to have consistent evaluations I have computed
-# the average result for each category using only the shared sets of datasets within
-# that category. All datasets for which we have results will be shown but commented
-# with ###; datasets which are used in our category averages will have a ^.
+# Overall model quality is computed using MMLU-Pro; multi-modal models currently use the same score for vision
+# - in the future we should split quality for vision vs. multi-modal vs. text
+# - code quality was computed using HumanEval, but that benchmark is too easy and should be replaced.
+# - https://huggingface.co/spaces/TIGER-Lab/MMLU-Pro
 #
 # Cost is presented in terms of USD / token for input tokens and USD / token for
 # generated tokens.
@@ -220,17 +255,28 @@ LOG_LLM_OUTPUT = False
 # values more precisely:
 # - https://artificialanalysis.ai/models/llama-3-1-instruct-8b
 #
-# LLAMA3_8B_MODEL_CARD = {
-#     ##### Cost in USD #####
-#     "usd_per_input_token": 0.18 / 1E6,
-#     "usd_per_output_token": 0.18 / 1E6,
-#     ##### Time #####
-#     "seconds_per_output_token": 0.0061,
-#     ##### Agg. Benchmark #####
-#     "overall": 71.0,
-#     ##### Code #####
-#     "code": 64.0,
-# }
+LLAMA3_2_3B_INSTRUCT_MODEL_CARD = {
+    ##### Cost in USD #####
+    "usd_per_input_token": 0.06 / 1e6,
+    "usd_per_output_token": 0.06 / 1e6,
+    ##### Time #####
+    "seconds_per_output_token": 0.0064,
+    ##### Agg. Benchmark #####
+    "overall": 36.50,  # https://huggingface.co/meta-llama/Llama-3.2-3B-Instruct/discussions/13
+    ##### Code #####
+    "code": 0.0,
+}
+LLAMA3_1_8B_INSTRUCT_MODEL_CARD = {
+    ##### Cost in USD #####
+    "usd_per_input_token": 0.18 / 1e6,
+    "usd_per_output_token": 0.18 / 1e6,
+    ##### Time #####
+    "seconds_per_output_token": 0.0059,
+    ##### Agg. Benchmark #####
+    "overall": 44.25,
+    ##### Code #####
+    "code": 72.6,
+}
 LLAMA3_3_70B_INSTRUCT_MODEL_CARD = {
     ##### Cost in USD #####
     "usd_per_input_token": 0.88 / 1e6,
@@ -238,19 +284,10 @@ LLAMA3_3_70B_INSTRUCT_MODEL_CARD = {
     ##### Time #####
     "seconds_per_output_token": 0.0139,
     ##### Agg. Benchmark #####
-    "overall": 86.0,
+    "overall": 65.92,
     ##### Code #####
     "code": 88.4,
 }
-# LLAMA3_2_11B_V_MODEL_CARD = {
-#     ##### Cost in USD #####
-#     "usd_per_input_token": 0.18 / 1E6,
-#     "usd_per_output_token": 0.18 / 1E6,
-#     ##### Time #####
-#     "seconds_per_output_token": 0.0061,
-#     ##### Agg. Benchmark #####
-#     "overall": 71.0,
-# }
 LLAMA3_2_90B_V_MODEL_CARD = {
     ##### Cost in USD #####
     "usd_per_input_token": 1.2 / 1e6,
@@ -258,7 +295,7 @@ LLAMA3_2_90B_V_MODEL_CARD = {
     ##### Time #####
     "seconds_per_output_token": 0.0222,
     ##### Agg. Benchmark #####
-    "overall": 84.0,
+    "overall": 65.00,  # set to be slightly higher than gpt-4o-mini
 }
 MIXTRAL_8X_7B_MODEL_CARD = {
     ##### Cost in USD #####
@@ -267,7 +304,7 @@ MIXTRAL_8X_7B_MODEL_CARD = {
     ##### Time #####
     "seconds_per_output_token": 0.0112,
     ##### Agg. Benchmark #####
-    "overall": 63.0,
+    "overall": 43.27,
     ##### Code #####
     "code": 40.0,
 }
@@ -278,51 +315,56 @@ DEEPSEEK_V3_MODEL_CARD = {
     ##### Time #####
     "seconds_per_output_token": 0.0769,
     ##### Agg. Benchmark #####
-    "overall": 87.0,
+    "overall": 75.87,
     ##### Code #####
     "code": 92.0,
 }
-GPT_4o_MODEL_CARD = {
+DEEPSEEK_R1_DISTILL_QWEN_1_5B_MODEL_CARD = {
     ##### Cost in USD #####
-    "usd_per_input_token": 2.5 / 1e6,
-    "usd_per_output_token": 10.0 / 1e6,
+    "usd_per_input_token": 0.18 / 1E6,
+    "usd_per_output_token": 0.18 / 1E6,
     ##### Time #####
-    "seconds_per_output_token": 0.0079,
+    "seconds_per_output_token": 0.0026,
     ##### Agg. Benchmark #####
-    "overall": 89.0,
+    "overall": 39.90,  # https://www.reddit.com/r/LocalLLaMA/comments/1iserf9/deepseek_r1_distilled_models_mmlu_pro_benchmarks/
     ##### Code #####
-    "code": 90.0,
+    "code": 0.0,
 }
-GPT_4o_V_MODEL_CARD = {
-    # NOTE: it is unclear if the same ($ / token) costs can be applied, or if we have to calculate this ourselves
+GPT_4o_MODEL_CARD = {
+    # NOTE: it is unclear if the same ($ / token) costs can be applied for vision, or if we have to calculate this ourselves
     ##### Cost in USD #####
     "usd_per_input_token": 2.5 / 1e6,
     "usd_per_output_token": 10.0 / 1e6,
     ##### Time #####
     "seconds_per_output_token": 0.0079,
     ##### Agg. Benchmark #####
-    "overall": 89.0,
+    "overall": 74.68,
+    ##### Code #####
+    "code": 90.0,
 }
 GPT_4o_MINI_MODEL_CARD = {
+    # NOTE: it is unclear if the same ($ / token) costs can be applied for vision, or if we have to calculate this ourselves
     ##### Cost in USD #####
     "usd_per_input_token": 0.15 / 1e6,
     "usd_per_output_token": 0.6 / 1e6,
     ##### Time #####
     "seconds_per_output_token": 0.0098,
     ##### Agg. Benchmark #####
-    "overall": 82.0,
+    "overall": 63.09,
     ##### Code #####
     "code": 86.0,
 }
-GPT_4o_MINI_V_MODEL_CARD = {
-    # NOTE: it is unclear if the same ($ / token) costs can be applied, or if we have to calculate this ourselves
+o1_MODEL_CARD = {  # noqa: N816
+    # NOTE: it is unclear if the same ($ / token) costs can be applied for vision, or if we have to calculate this ourselves
     ##### Cost in USD #####
-    "usd_per_input_token": 0.15 / 1e6,
-    "usd_per_output_token": 0.6 / 1e6,
+    "usd_per_input_token": 15 / 1e6,
+    "usd_per_output_token": 60 / 1e6,
     ##### Time #####
-    "seconds_per_output_token": 0.0098,
+    "seconds_per_output_token": 0.0110,
     ##### Agg. Benchmark #####
-    "overall": 82.0,
+    "overall": 89.30,
+    ##### Code #####
+    "code": 92.3,  # NOTE: just copying MMLU score for now
 }
 TEXT_EMBEDDING_3_SMALL_MODEL_CARD = {
     ##### Cost in USD #####
@@ -331,7 +373,7 @@ TEXT_EMBEDDING_3_SMALL_MODEL_CARD = {
     ##### Time #####
     "seconds_per_output_token": 0.0098,  # NOTE: just copying GPT_4o_MINI_MODEL_CARD for now
     ##### Agg. Benchmark #####
-    "overall": 82.0,  # NOTE: just copying GPT_4o_MINI_MODEL_CARD for now
+    "overall": 63.09,  # NOTE: just copying GPT_4o_MINI_MODEL_CARD for now
 }
 CLIP_VIT_B_32_MODEL_CARD = {
     ##### Cost in USD #####
@@ -345,22 +387,18 @@ CLIP_VIT_B_32_MODEL_CARD = {
 
 
 MODEL_CARDS = {
-    Model.LLAMA3.value: LLAMA3_3_70B_INSTRUCT_MODEL_CARD,
-    Model.LLAMA3_V.value: LLAMA3_2_90B_V_MODEL_CARD,
-    Model.DEEPSEEK.value: DEEPSEEK_V3_MODEL_CARD,
+    Model.LLAMA3_2_3B.value: LLAMA3_2_3B_INSTRUCT_MODEL_CARD,
+    Model.LLAMA3_1_8B.value: LLAMA3_1_8B_INSTRUCT_MODEL_CARD,
+    Model.LLAMA3_3_70B.value: LLAMA3_3_70B_INSTRUCT_MODEL_CARD,
+    Model.LLAMA3_2_90B_V.value: LLAMA3_2_90B_V_MODEL_CARD,
+    Model.DEEPSEEK_V3.value: DEEPSEEK_V3_MODEL_CARD,
+    Model.DEEPSEEK_R1_DISTILL_QWEN_1_5B.value: DEEPSEEK_R1_DISTILL_QWEN_1_5B_MODEL_CARD,
     Model.MIXTRAL.value: MIXTRAL_8X_7B_MODEL_CARD,
     Model.GPT_4o.value: GPT_4o_MODEL_CARD,
-    Model.GPT_4o_V.value: GPT_4o_V_MODEL_CARD,
     Model.GPT_4o_MINI.value: GPT_4o_MINI_MODEL_CARD,
-    Model.GPT_4o_MINI_V.value: GPT_4o_MINI_V_MODEL_CARD,
+    # Model.o1.value: o1_MODEL_CARD,
     Model.TEXT_EMBEDDING_3_SMALL.value: TEXT_EMBEDDING_3_SMALL_MODEL_CARD,
     Model.CLIP_VIT_B_32.value: CLIP_VIT_B_32_MODEL_CARD,
-    ###
-    # Model.GPT_3_5.value: GPT_3_5_MODEL_CARD,
-    # Model.GPT_4.value: GPT_4_MODEL_CARD,
-    # Model.GPT_4V.value: GPT_4V_MODEL_CARD,
-    # Model.GEMINI_1.value: GEMINI_1_MODEL_CARD,
-    # Model.GEMINI_1V.value: GEMINI_1V_MODEL_CARD,
 }
 
palimpzest/core/data/dataclasses.py CHANGED
@@ -5,6 +5,8 @@ from abc import abstractmethod
 from dataclasses import dataclass, field, fields
 from typing import Any
 
+import numpy as np
+
 
 @dataclass
 class GenerationStats:
@@ -148,7 +150,7 @@ class RecordOpStats:
     record_state: dict[str, Any]
 
     # operation id; an identifier for this operation's physical op id
-    op_id: str
+    full_op_id: str
 
     # logical operation id; the logical op id for this physical op
     logical_op_id: str
@@ -164,7 +166,7 @@ class RecordOpStats:
 
     ##### NOT-OPTIONAL, BUT FILLED BY EXECUTION CLASS AFTER CONSTRUCTOR CALL #####
     # the ID of the physical operation which produced the input record for this record at this operation
-    source_op_id: str | None = None
+    source_full_op_id: str | None = None
 
     # the ID of the physical plan which produced this record at this operation
     plan_id: str = ""
@@ -240,8 +242,8 @@ class OperatorStats:
     Dataclass for storing statistics captured within a given operator.
     """
 
-    # the ID of the physical operation in which these stats were collected
-    op_id: str
+    # the full ID of the physical operation in which these stats were collected
+    full_op_id: str
 
     # the name of the physical operation in which these stats were collected
     op_name: str
@@ -255,8 +257,8 @@ class OperatorStats:
     # a list of RecordOpStats processed by the operation
     record_op_stats_lst: list[RecordOpStats] = field(default_factory=list)
 
-    # the ID of the physical operator which precedes this one
-    source_op_id: str | None = None
+    # the full ID of the physical operator which precedes this one
+    source_full_op_id: str | None = None
 
     # the ID of the physical plan which this operator is part of
     plan_id: str = ""
@@ -273,7 +275,7 @@
 
         NOTE: in case (1.) we assume the execution layer guarantees that `stats` is
         generated by the same operator in the same plan. Thus, we assume the
-        op_ids, op_name, source_op_id, etc. do not need to be updated.
+        full_op_ids, op_name, source_op_id, etc. do not need to be updated.
         """
         if isinstance(stats, OperatorStats):
             self.total_op_time += stats.total_op_time
@@ -281,7 +283,7 @@
             self.record_op_stats_lst.extend(stats.record_op_stats_lst)
 
         elif isinstance(stats, RecordOpStats):
-            stats.source_op_id = self.source_op_id
+            stats.source_full_op_id = self.source_full_op_id
            stats.plan_id = self.plan_id
            self.record_op_stats_lst.append(stats)
            self.total_op_time += stats.time_per_record
@@ -294,7 +296,7 @@
 
     def to_json(self):
         return {
-            "op_id": self.op_id,
+            "full_op_id": self.full_op_id,
             "op_name": self.op_name,
             "total_op_time": self.total_op_time,
             "total_op_cost": self.total_op_cost,
@@ -327,8 +329,8 @@ class BasePlanStats:
     plan_str: str | None = None
 
     # dictionary whose values are OperatorStats objects;
-    # PlanStats maps {physical_op_id -> OperatorStats}
-    # SentinelPlanStats maps {logical_op_id -> {physical_op_id -> OperatorStats}}
+    # PlanStats maps {full_op_id -> OperatorStats}
+    # SentinelPlanStats maps {logical_op_id -> {full_op_id -> OperatorStats}}
     operator_stats: dict = field(default_factory=dict)
 
     # total runtime for the plan measured from the start to the end of PhysicalPlan.execute()
@@ -406,11 +408,11 @@ class PlanStats(BasePlanStats):
        """
        operator_stats = {}
        for op_idx, op in enumerate(plan.operators):
-            op_id = op.get_op_id()
-            operator_stats[op_id] = OperatorStats(
-                op_id=op_id,
+            full_op_id = op.get_full_op_id()
+            operator_stats[full_op_id] = OperatorStats(
+                full_op_id=full_op_id,
                op_name=op.op_name(),
-                source_op_id=None if op_idx == 0 else plan.operators[op_idx - 1].get_op_id(),
+                source_full_op_id=None if op_idx == 0 else plan.operators[op_idx - 1].get_full_op_id(),
                plan_id=plan.plan_id,
                op_details={k: str(v) for k, v in op.get_id_params().items()},
            )
@@ -432,11 +434,11 @@
 
        # update operator stats
        for record_op_stats in record_op_stats_lst:
-            op_id = record_op_stats.op_id
-            if op_id in self.operator_stats:
-                self.operator_stats[op_id] += record_op_stats
+            full_op_id = record_op_stats.full_op_id
+            if full_op_id in self.operator_stats:
+                self.operator_stats[full_op_id] += record_op_stats
            else:
-                raise ValueError(f"RecordOpStats with physical_op_id {op_id} not found in PlanStats")
+                raise ValueError(f"RecordOpStats with full_op_id {full_op_id} not found in PlanStats")
 
    def __iadd__(self, plan_stats: PlanStats) -> None:
        """
@@ -448,11 +450,11 @@
        """
        self.total_plan_time += plan_stats.total_plan_time
        self.total_plan_cost += plan_stats.total_plan_cost
-        for op_id, op_stats in plan_stats.operator_stats.items():
-            if op_id in self.operator_stats:
-                self.operator_stats[op_id] += op_stats
+        for full_op_id, op_stats in plan_stats.operator_stats.items():
+            if full_op_id in self.operator_stats:
+                self.operator_stats[full_op_id] += op_stats
            else:
-                self.operator_stats[op_id] = op_stats
+                self.operator_stats[full_op_id] = op_stats
 
    def __str__(self) -> str:
        stats = f"total_plan_time={self.total_plan_time} \n"
@@ -465,7 +467,7 @@
        return {
            "plan_id": self.plan_id,
            "plan_str": self.plan_str,
-            "operator_stats": {op_id: op_stats.to_json() for op_id, op_stats in self.operator_stats.items()},
+            "operator_stats": {full_op_id: op_stats.to_json() for full_op_id, op_stats in self.operator_stats.items()},
            "total_plan_time": self.total_plan_time,
            "total_plan_cost": self.total_plan_cost,
        }
@@ -485,11 +487,11 @@ class SentinelPlanStats(BasePlanStats):
        for op_set_idx, (logical_op_id, op_set) in enumerate(plan):
            operator_stats[logical_op_id] = {}
            for physical_op in op_set:
-                op_id = physical_op.get_op_id()
-                operator_stats[logical_op_id][op_id] = OperatorStats(
-                    op_id=op_id,
+                full_op_id = physical_op.get_full_op_id()
+                operator_stats[logical_op_id][full_op_id] = OperatorStats(
+                    full_op_id=full_op_id,
                    op_name=physical_op.op_name(),
-                    source_op_id=None if op_set_idx == 0 else plan.logical_op_ids[op_set_idx - 1],
+                    source_full_op_id=None if op_set_idx == 0 else plan.logical_op_ids[op_set_idx - 1],  # NOTE: this may be a reason to keep `source_op_id` instead of `source_full_op_id`
                    plan_id=plan.plan_id,
                    op_details={k: str(v) for k, v in physical_op.get_id_params().items()},
                )
@@ -512,12 +514,12 @@
        # update operator stats
        for record_op_stats in record_op_stats_lst:
            logical_op_id = record_op_stats.logical_op_id
-            physical_op_id = record_op_stats.op_id
+            full_op_id = record_op_stats.full_op_id
            if logical_op_id in self.operator_stats:
-                if physical_op_id in self.operator_stats[logical_op_id]:
-                    self.operator_stats[logical_op_id][physical_op_id] += record_op_stats
+                if full_op_id in self.operator_stats[logical_op_id]:
+                    self.operator_stats[logical_op_id][full_op_id] += record_op_stats
                else:
-                    raise ValueError(f"RecordOpStats with physical_op_id {physical_op_id} not found in SentinelPlanStats")
+                    raise ValueError(f"RecordOpStats with full_op_id {full_op_id} not found in SentinelPlanStats")
            else:
                raise ValueError(f"RecordOpStats with logical_op_id {logical_op_id} not found in SentinelPlanStats")
 
@@ -532,12 +534,12 @@
        self.total_plan_time += plan_stats.total_plan_time
        self.total_plan_cost += plan_stats.total_plan_cost
        for logical_op_id, physical_op_stats in plan_stats.operator_stats.items():
-            for physical_op_id, op_stats in physical_op_stats.items():
+            for full_op_id, op_stats in physical_op_stats.items():
                if logical_op_id in self.operator_stats:
-                    if physical_op_id in self.operator_stats[logical_op_id]:
-                        self.operator_stats[logical_op_id][physical_op_id] += op_stats
+                    if full_op_id in self.operator_stats[logical_op_id]:
+                        self.operator_stats[logical_op_id][full_op_id] += op_stats
                    else:
-                        self.operator_stats[logical_op_id][physical_op_id] = op_stats
+                        self.operator_stats[logical_op_id][full_op_id] = op_stats
                else:
                    self.operator_stats[logical_op_id] = physical_op_stats
 
@@ -557,7 +559,7 @@
            "plan_id": self.plan_id,
            "plan_str": self.plan_str,
            "operator_stats": {
-                logical_op_id: {physical_op_id: op_stats.to_json() for physical_op_id, op_stats in physical_op_stats.items()}
+                logical_op_id: {full_op_id: op_stats.to_json() for full_op_id, op_stats in physical_op_stats.items()}
                for logical_op_id, physical_op_stats in self.operator_stats.items()
            },
            "total_plan_time": self.total_plan_time,
@@ -684,8 +686,21 @@ class ExecutionStats:
        else:
            raise TypeError(f"Cannot add {type(plan_stats)} to ExecutionStats")
 
+    def clean_json(self, stats: dict):
+        """
+        Convert np.int64 and np.float64 to int and float for all values in stats.
+        """
+        for key, value in stats.items():
+            if isinstance(value, dict):
+                stats[key] = self.clean_json(value)
+            elif isinstance(value, np.int64):
+                stats[key] = int(value)
+            elif isinstance(value, np.float64):
+                stats[key] = float(value)
+        return stats
+
    def to_json(self):
-        return {
+        stats = {
            "execution_id": self.execution_id,
            "sentinel_plan_stats": {
                plan_id: plan_stats.to_json() for plan_id, plan_stats in self.sentinel_plan_stats.items()
@@ -700,6 +715,8 @@
            "sentinel_plan_strs": self.sentinel_plan_strs,
            "plan_strs": self.plan_strs,
        }
+        stats = self.clean_json(stats)
+        return stats
 
 
 @dataclass
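`ExecutionStats.to_json()` now routes through `clean_json` because aggregated stats can pick up NumPy scalars (presumably from NumPy-based aggregation elsewhere in the stats pipeline), and `np.int64` in particular is not JSON-serializable. A standalone sketch of the failure and the conversion; unlike the package's in-place `clean_json`, this version is pure, and the dict contents are illustrative:

```python
import json

import numpy as np

stats = {"total": {"num_records": np.int64(42), "cost": np.float64(0.0123)}}
try:
    json.dumps(stats)
except TypeError as err:
    print(err)  # Object of type int64 is not JSON serializable

def clean(d: dict) -> dict:
    # recursively convert NumPy scalars to builtin int/float, mirroring clean_json
    return {
        k: clean(v) if isinstance(v, dict)
        else int(v) if isinstance(v, np.int64)
        else float(v) if isinstance(v, np.float64)
        else v
        for k, v in d.items()
    }

print(json.dumps(clean(stats)))  # {"total": {"num_records": 42, "cost": 0.0123}}
```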
palimpzest/core/elements/index.py CHANGED
@@ -3,30 +3,27 @@ from __future__ import annotations
 from abc import ABC, abstractmethod
 
 from chromadb.api.models.Collection import Collection
-from ragatouille.RAGPretrainedModel import RAGPretrainedModel
 
 
-def index_factory(index: Collection | RAGPretrainedModel) -> PZIndex:
+def index_factory(index: Collection) -> PZIndex:
     """
     Factory function to create a PZ index based on the type of the provided index.
 
     Args:
-        index (Collection | RAGPretrainedModel): The index provided by the user.
+        index (Collection): The index provided by the user.
 
     Returns:
         PZIndex: The PZ wrapped Index.
     """
     if isinstance(index, Collection):
         return ChromaIndex(index)
-    elif isinstance(index, RAGPretrainedModel):
-        return RagatouilleIndex(index)
     else:
-        raise TypeError(f"Unsupported index type: {type(index)}\nindex must be a `chromadb.api.models.Collection.Collection` or `ragatouille.RAGPretrainedModel.RAGPretrainedModel`")
+        raise TypeError(f"Unsupported index type: {type(index)}\nindex must be a `chromadb.api.models.Collection.Collection`")
 
 
 class BaseIndex(ABC):
 
-    def __init__(self, index: Collection | RAGPretrainedModel):
+    def __init__(self, index: Collection):
         self.index = index
 
     def __str__(self):
@@ -59,12 +56,5 @@ class ChromaIndex(BaseIndex):
         super().__init__(index)
 
 
-
-class RagatouilleIndex(BaseIndex):
-    def __init__(self, index: RAGPretrainedModel):
-        assert isinstance(index, RAGPretrainedModel), "RagatouilleIndex input must be a `ragatouille.RAGPretrainedModel.RAGPretrainedModel`"
-        super().__init__(index)
-
-
 # define type for PZIndex
-PZIndex = ChromaIndex | RagatouilleIndex
+PZIndex = ChromaIndex
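With Ragatouille support removed, `index_factory` accepts only a Chroma `Collection` and `PZIndex` collapses to `ChromaIndex`. A usage sketch, assuming `chromadb` is installed (the collection name is illustrative):

```python
import chromadb

from palimpzest.core.elements.index import index_factory

client = chromadb.Client()  # in-memory client
collection = client.create_collection(name="docs")

pz_index = index_factory(collection)  # returns a ChromaIndex
print(pz_index)

try:
    index_factory("not-an-index")  # anything but a Collection now fails fast
except TypeError as err:
    print(err)  # Unsupported index type: <class 'str'> ...
```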
palimpzest/core/elements/records.py CHANGED
@@ -35,7 +35,7 @@ class DataRecord:
         self.field_values: dict[str, Any] = {}
 
         # the index in the DataReader from which this DataRecord is derived
-        self.source_idx = source_idx
+        self.source_idx = int(source_idx)
 
         # the id of the parent record(s) from which this DataRecord is derived
         self.parent_id = parent_id
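The `int(source_idx)` cast is small but consequential: indices produced by NumPy-based code arrive as `np.int64`, and coercing them at construction keeps `DataRecord` metadata (and anything serialized from it, cf. `clean_json` above) in builtin types. A sketch of the normalization (the values are illustrative):

```python
import numpy as np

for raw_idx in [3, np.int64(3)]:
    source_idx = int(raw_idx)       # what DataRecord.__init__ now does
    assert type(source_idx) is int  # np.int64 would otherwise leak into stats/JSON
```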
palimpzest/prompts/prompt_factory.py CHANGED
@@ -155,7 +155,7 @@ class PromptFactory:
         # TODO: this does not work for image prompts
         # TODO: this ignores the size of the `orignal_messages` in critique and refine prompts
         # cut down on context based on window length
-        if self.model in [Model.LLAMA3, Model.MIXTRAL]:
+        if self.model.is_llama_model() or self.model.is_mixtral_model():
             total_context_len = len(json.dumps(context, indent=2))
 
             # sort fields by length and progressively strip from the longest field until it is short enough;
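The old membership test `self.model in [Model.LLAMA3, Model.MIXTRAL]` would have silently stopped matching once `LLAMA3` was split into four members; the predicate form covers every variant. A quick check against the new enum, assuming palimpzest 0.7.9:

```python
from palimpzest.constants import Model

llama_variants = [Model.LLAMA3_2_3B, Model.LLAMA3_1_8B, Model.LLAMA3_3_70B, Model.LLAMA3_2_90B_V]
assert all(m.is_llama_model() for m in llama_variants)  # all now hit the context-trimming branch
assert Model.MIXTRAL.is_mixtral_model()
assert not Model.GPT_4o.is_llama_model()  # OpenAI models skip the trim
```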