palimpzest 0.7.21__py3-none-any.whl → 0.8.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (87):
  1. palimpzest/__init__.py +37 -6
  2. palimpzest/agents/__init__.py +0 -0
  3. palimpzest/agents/compute_agents.py +0 -0
  4. palimpzest/agents/search_agents.py +637 -0
  5. palimpzest/constants.py +259 -197
  6. palimpzest/core/data/context.py +393 -0
  7. palimpzest/core/data/context_manager.py +163 -0
  8. palimpzest/core/data/dataset.py +634 -0
  9. palimpzest/core/data/{datareaders.py → iter_dataset.py} +202 -126
  10. palimpzest/core/elements/groupbysig.py +16 -13
  11. palimpzest/core/elements/records.py +166 -75
  12. palimpzest/core/lib/schemas.py +152 -390
  13. palimpzest/core/{data/dataclasses.py → models.py} +306 -170
  14. palimpzest/policy.py +2 -27
  15. palimpzest/prompts/__init__.py +35 -5
  16. palimpzest/prompts/agent_prompts.py +357 -0
  17. palimpzest/prompts/context_search.py +9 -0
  18. palimpzest/prompts/convert_prompts.py +61 -5
  19. palimpzest/prompts/filter_prompts.py +50 -5
  20. palimpzest/prompts/join_prompts.py +163 -0
  21. palimpzest/prompts/moa_proposer_convert_prompts.py +5 -5
  22. palimpzest/prompts/prompt_factory.py +358 -46
  23. palimpzest/prompts/validator.py +239 -0
  24. palimpzest/query/execution/all_sample_execution_strategy.py +134 -76
  25. palimpzest/query/execution/execution_strategy.py +210 -317
  26. palimpzest/query/execution/execution_strategy_type.py +5 -7
  27. palimpzest/query/execution/mab_execution_strategy.py +249 -136
  28. palimpzest/query/execution/parallel_execution_strategy.py +153 -244
  29. palimpzest/query/execution/single_threaded_execution_strategy.py +107 -64
  30. palimpzest/query/generators/generators.py +157 -330
  31. palimpzest/query/operators/__init__.py +15 -5
  32. palimpzest/query/operators/aggregate.py +50 -33
  33. palimpzest/query/operators/compute.py +201 -0
  34. palimpzest/query/operators/convert.py +27 -21
  35. palimpzest/query/operators/critique_and_refine_convert.py +7 -5
  36. palimpzest/query/operators/distinct.py +62 -0
  37. palimpzest/query/operators/filter.py +22 -13
  38. palimpzest/query/operators/join.py +402 -0
  39. palimpzest/query/operators/limit.py +3 -3
  40. palimpzest/query/operators/logical.py +198 -80
  41. palimpzest/query/operators/mixture_of_agents_convert.py +10 -8
  42. palimpzest/query/operators/physical.py +27 -21
  43. palimpzest/query/operators/project.py +3 -3
  44. palimpzest/query/operators/rag_convert.py +7 -7
  45. palimpzest/query/operators/retrieve.py +9 -9
  46. palimpzest/query/operators/scan.py +81 -42
  47. palimpzest/query/operators/search.py +524 -0
  48. palimpzest/query/operators/split_convert.py +10 -8
  49. palimpzest/query/optimizer/__init__.py +7 -9
  50. palimpzest/query/optimizer/cost_model.py +108 -441
  51. palimpzest/query/optimizer/optimizer.py +123 -181
  52. palimpzest/query/optimizer/optimizer_strategy.py +66 -61
  53. palimpzest/query/optimizer/plan.py +352 -67
  54. palimpzest/query/optimizer/primitives.py +43 -19
  55. palimpzest/query/optimizer/rules.py +484 -646
  56. palimpzest/query/optimizer/tasks.py +127 -58
  57. palimpzest/query/processor/config.py +41 -76
  58. palimpzest/query/processor/query_processor.py +73 -18
  59. palimpzest/query/processor/query_processor_factory.py +46 -38
  60. palimpzest/schemabuilder/schema_builder.py +15 -28
  61. palimpzest/utils/model_helpers.py +27 -77
  62. palimpzest/utils/progress.py +114 -102
  63. palimpzest/validator/__init__.py +0 -0
  64. palimpzest/validator/validator.py +306 -0
  65. {palimpzest-0.7.21.dist-info → palimpzest-0.8.0.dist-info}/METADATA +6 -1
  66. palimpzest-0.8.0.dist-info/RECORD +95 -0
  67. palimpzest/core/lib/fields.py +0 -141
  68. palimpzest/prompts/code_synthesis_prompts.py +0 -28
  69. palimpzest/query/execution/random_sampling_execution_strategy.py +0 -240
  70. palimpzest/query/generators/api_client_factory.py +0 -30
  71. palimpzest/query/operators/code_synthesis_convert.py +0 -488
  72. palimpzest/query/operators/map.py +0 -130
  73. palimpzest/query/processor/nosentinel_processor.py +0 -33
  74. palimpzest/query/processor/processing_strategy_type.py +0 -28
  75. palimpzest/query/processor/sentinel_processor.py +0 -88
  76. palimpzest/query/processor/streaming_processor.py +0 -149
  77. palimpzest/sets.py +0 -405
  78. palimpzest/utils/datareader_helpers.py +0 -61
  79. palimpzest/utils/demo_helpers.py +0 -75
  80. palimpzest/utils/field_helpers.py +0 -69
  81. palimpzest/utils/generation_helpers.py +0 -69
  82. palimpzest/utils/sandbox.py +0 -183
  83. palimpzest-0.7.21.dist-info/RECORD +0 -95
  84. /palimpzest/core/{elements/index.py → data/index_dataset.py} +0 -0
  85. {palimpzest-0.7.21.dist-info → palimpzest-0.8.0.dist-info}/WHEEL +0 -0
  86. {palimpzest-0.7.21.dist-info → palimpzest-0.8.0.dist-info}/licenses/LICENSE +0 -0
  87. {palimpzest-0.7.21.dist-info → palimpzest-0.8.0.dist-info}/top_level.txt +0 -0
palimpzest/query/operators/physical.py

@@ -2,9 +2,10 @@ from __future__ import annotations
 
 import json
 
-from palimpzest.core.data.dataclasses import OperatorCostEstimates
+from pydantic import BaseModel
+
 from palimpzest.core.elements.records import DataRecord, DataRecordSet
-from palimpzest.core.lib.schemas import Schema
+from palimpzest.core.models import OperatorCostEstimates
 from palimpzest.utils.hash_helpers import hash_for_id
 
 
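
Note: the common thread in this file's hunks is that palimpzest's custom Schema class is replaced by pydantic models, so field enumeration and metadata lookups move to pydantic's model_fields mapping. A minimal sketch of that API, with an illustrative model that is not part of the package:

    from pydantic import BaseModel, Field

    class Email(BaseModel):  # illustrative stand-in for a user-defined schema
        sender: str = Field(description="the email's sender address")
        subject: str = Field(description="the email's subject line")

    # `model_fields` maps field names to FieldInfo objects, replacing the
    # 0.7.x `Schema.field_names()` / `field_desc_map()` helpers.
    print(list(Email.model_fields))                   # ['sender', 'subject']
    print(Email.model_fields["subject"].description)  # the email's subject line
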
@@ -17,12 +18,13 @@ class PhysicalOperator:
 
     def __init__(
         self,
-        output_schema: Schema,
-        input_schema: Schema | None = None,
+        output_schema: BaseModel,
+        input_schema: BaseModel | None = None,
         depends_on: list[str] | None = None,
         logical_op_id: str | None = None,
+        unique_logical_op_id: str | None = None,
         logical_op_name: str | None = None,
-        target_cache_id: str | None = None,
+        api_base: str | None = None,
         verbose: bool = False,
         *args,
         **kwargs,
@@ -31,16 +33,17 @@ class PhysicalOperator:
         self.input_schema = input_schema
         self.depends_on = depends_on if depends_on is None else sorted(depends_on)
         self.logical_op_id = logical_op_id
+        self.unique_logical_op_id = unique_logical_op_id
         self.logical_op_name = logical_op_name
-        self.target_cache_id = target_cache_id
+        self.api_base = api_base
         self.verbose = verbose
         self.op_id = None
 
         # compute the fields generated by this physical operator
-        input_field_names = self.input_schema.field_names() if self.input_schema is not None else []
+        input_field_names = list(self.input_schema.model_fields) if self.input_schema is not None else []
         self.generated_fields = sorted([
             field_name
-            for field_name in self.output_schema.field_names()
+            for field_name in self.output_schema.model_fields
             if field_name not in input_field_names
         ])
 
@@ -50,16 +53,18 @@ class PhysicalOperator:
         self.__class__.__hash__ = PhysicalOperator.__hash__
 
     def __str__(self):
-        op = f"{self.input_schema.class_name()} -> {self.op_name()} -> {self.output_schema.class_name()}\n"
-        op += f" ({', '.join(self.input_schema.field_names())[:30]}) "
-        op += f"-> ({', '.join(self.output_schema.field_names())[:30]})\n"
+        op = f"{self.input_schema.__name__} -> {self.op_name()} -> {self.output_schema.__name__}\n"
+        op += f" ({', '.join(sorted(self.input_schema.model_fields))[:30]}) "
+        op += f"-> ({', '.join(sorted(self.output_schema.model_fields))[:30]})\n"
         if getattr(self, "model", None):
             op += f" Model: {self.model}\n"
         return op
 
+    # def __eq__(self, other) -> bool:
+    #     all_op_params_match = all(value == getattr(other, key) for key, value in self.get_op_params().items())
+    #     return isinstance(other, self.__class__) and all_op_params_match
     def __eq__(self, other) -> bool:
-        all_op_params_match = all(value == getattr(other, key) for key, value in self.get_op_params().items())
-        return isinstance(other, self.__class__) and all_op_params_match
+        return isinstance(other, self.__class__) and self.get_full_op_id() == other.get_full_op_id()
 
     def copy(self) -> PhysicalOperator:
         return self.__class__(**self.get_op_params())
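
Note: operator equality now compares stable operator ids rather than walking every op param (the old comparison survives as a comment). A toy illustration of the id-based pattern, not package code:

    class Op:  # illustrative class, not palimpzest's PhysicalOperator
        def __init__(self, op_id: str):
            self.op_id = op_id

        def get_full_op_id(self) -> str:
            return self.op_id

        def __eq__(self, other) -> bool:
            # id-based equality: cheap, and consistent with hashing on the id
            return isinstance(other, self.__class__) and self.get_full_op_id() == other.get_full_op_id()

    print(Op("logical-1-abc") == Op("logical-1-abc"))  # True
    print(Op("logical-1-abc") == Op("logical-2-def"))  # False
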
@@ -94,8 +99,9 @@ class PhysicalOperator:
             "input_schema": self.input_schema,
             "depends_on": self.depends_on,
             "logical_op_id": self.logical_op_id,
+            "unique_logical_op_id": self.unique_logical_op_id,
             "logical_op_name": self.logical_op_name,
-            "target_cache_id": self.target_cache_id,
+            "api_base": self.api_base,
             "verbose": self.verbose,
         }
 
@@ -116,10 +122,7 @@ class PhysicalOperator:
         # get op name and op parameters which are relevant for computing the id
         op_name = self.op_name()
         id_params = self.get_id_params()
-        id_params = {
-            k: str(v) if k != "output_schema" else sorted(v.field_names())
-            for k, v in id_params.items()
-        }
+        id_params = {k: str(v) for k, v in id_params.items()}
 
         # compute, set, and return the op_id
         hash_str = json.dumps({"op_name": op_name, **id_params}, sort_keys=True)
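
Note: with output_schema now a pydantic class, every id param is stringified uniformly before hashing. A standalone sketch of the pattern; this hash_for_id stand-in is illustrative, the package's real helper lives in palimpzest.utils.hash_helpers:

    import hashlib
    import json

    def hash_for_id(s: str) -> str:  # illustrative stand-in, not the package's implementation
        return hashlib.sha256(s.encode("utf-8")).hexdigest()[:16]

    id_params = {"model": "gpt-4o-mini", "output_schema": "Email"}  # assumed example params
    id_params = {k: str(v) for k, v in id_params.items()}  # uniform stringification
    # sort_keys makes the dump deterministic, so equal params yield equal op ids
    hash_str = json.dumps({"op_name": "LLMConvert", **id_params}, sort_keys=True)
    print(hash_for_id(hash_str))
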
@@ -127,9 +130,12 @@
 
         return self.op_id
 
-    def get_logical_op_id(self) -> str | None:
+    def get_logical_op_id(self) -> str:
         return self.logical_op_id
 
+    def get_unique_logical_op_id(self) -> str:
+        return self.unique_logical_op_id
+
     def get_full_op_id(self):
         return f"{self.get_logical_op_id()}-{self.get_op_id()}"
 
@@ -148,9 +154,9 @@ class PhysicalOperator:
             else None
         )
         input_fields = (
-            self.input_schema.field_names()
+            list(self.input_schema.model_fields)
             if depends_on_fields is None
-            else [field for field in self.input_schema.field_names() if field in depends_on_fields]
+            else [field for field in self.input_schema.model_fields if field in depends_on_fields]
         )
 
         return input_fields
palimpzest/query/operators/project.py

@@ -1,7 +1,7 @@
 from __future__ import annotations
 
-from palimpzest.core.data.dataclasses import OperatorCostEstimates, RecordOpStats
 from palimpzest.core.elements.records import DataRecord, DataRecordSet
+from palimpzest.core.models import OperatorCostEstimates, RecordOpStats
 from palimpzest.query.operators.physical import PhysicalOperator
 
 
@@ -39,8 +39,8 @@ class ProjectOp(PhysicalOperator):
         # create RecordOpStats object
         record_op_stats = RecordOpStats(
             record_id=dr.id,
-            record_parent_id=dr.parent_id,
-            record_source_idx=dr.source_idx,
+            record_parent_ids=dr.parent_ids,
+            record_source_indices=dr.source_indices,
             record_state=dr.to_dict(include_bytes=False),
             full_op_id=self.get_full_op_id(),
             logical_op_id=self.logical_op_id,
palimpzest/query/operators/rag_convert.py

@@ -5,15 +5,15 @@ import time
 from numpy import dot
 from numpy.linalg import norm
 from openai import OpenAI
+from pydantic.fields import FieldInfo
 
 from palimpzest.constants import (
     MODEL_CARDS,
     NAIVE_EST_NUM_OUTPUT_TOKENS,
     Model,
 )
-from palimpzest.core.data.dataclasses import GenerationStats, OperatorCostEstimates
 from palimpzest.core.elements.records import DataRecord
-from palimpzest.core.lib.fields import Field, StringField
+from palimpzest.core.models import GenerationStats, OperatorCostEstimates
 from palimpzest.query.operators.convert import LLMConvert
 
 
@@ -143,9 +143,9 @@ class RAGConvert(LLMConvert):
 
         # compute embedding for output fields
         output_fields_desc = ""
-        field_desc_map = self.output_schema.field_desc_map()
         for field_name in output_fields:
-            output_fields_desc += f"- {field_name}: {field_desc_map[field_name]}\n"
+            desc = self.output_schema.model_fields[field_name].description
+            output_fields_desc += f"- {field_name}: {'no description available' if desc is None else desc}\n"
         query_embedding, query_embed_stats = self.compute_embedding(output_fields_desc)
 
         # add cost of embedding the query to embed_stats
@@ -156,8 +156,8 @@ class RAGConvert(LLMConvert):
             field = candidate.get_field_type(field_name)
 
             # skip this field if it is not a string or a list of strings
-            is_string_field = isinstance(field, StringField)
-            is_list_string_field = hasattr(field, "element_type") and isinstance(field.element_type, StringField)
+            is_string_field = field.annotation in [str, str | None]
+            is_list_string_field = field.annotation in [list[str], list[str] | None]
             if not (is_string_field or is_list_string_field):
                 continue
 
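
Note: string detection now compares a pydantic FieldInfo.annotation against concrete types instead of isinstance-checking the removed custom Field classes. A self-contained sketch (the Doc model is illustrative):

    from pydantic import BaseModel

    class Doc(BaseModel):  # illustrative model
        title: str
        tags: list[str] | None = None
        pages: int = 0

    for name, info in Doc.model_fields.items():
        is_string_field = info.annotation in [str, str | None]
        is_list_string_field = info.annotation in [list[str], list[str] | None]
        print(name, is_string_field or is_list_string_field)
    # title True / tags True / pages False
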
@@ -193,7 +193,7 @@
 
         return candidate, embed_stats
 
-    def convert(self, candidate: DataRecord, fields: dict[str, Field]) -> tuple[dict[str, list], GenerationStats]:
+    def convert(self, candidate: DataRecord, fields: dict[str, FieldInfo]) -> tuple[dict[str, list], GenerationStats]:
         # set client
         self.client = OpenAI() if self.client is None else self.client
 
palimpzest/query/operators/retrieve.py

@@ -8,12 +8,12 @@ from chromadb.api.models.Collection import Collection
 from chromadb.utils.embedding_functions import SentenceTransformerEmbeddingFunction
 from chromadb.utils.embedding_functions.openai_embedding_function import OpenAIEmbeddingFunction
 from openai import OpenAI
+from pydantic import BaseModel
 from sentence_transformers import SentenceTransformer
 
 from palimpzest.constants import MODEL_CARDS, Model
-from palimpzest.core.data.dataclasses import GenerationStats, OperatorCostEstimates, RecordOpStats
 from palimpzest.core.elements.records import DataRecord, DataRecordSet
-from palimpzest.core.lib.schemas import Schema
+from palimpzest.core.models import GenerationStats, OperatorCostEstimates, RecordOpStats
 from palimpzest.query.operators.physical import PhysicalOperator
 
 
@@ -22,7 +22,7 @@ class RetrieveOp(PhysicalOperator):
         self,
         index: Collection,
         search_attr: str,
-        output_attrs: list[dict] | type[Schema],
+        output_attrs: list[dict] | type[BaseModel],
         search_func: Callable | None,
         k: int,
         *args,
@@ -41,12 +41,12 @@
         super().__init__(*args, **kwargs)
 
         # extract the field names from the output_attrs
-        if isinstance(output_attrs, Schema):
-            self.output_field_names = output_attrs.field_names()
+        if issubclass(output_attrs, BaseModel):
+            self.output_field_names = list(output_attrs.model_fields)
         elif isinstance(output_attrs, list):
            self.output_field_names = [attr["name"] for attr in output_attrs]
         else:
-            raise ValueError("`output_attrs` must be a list of dicts or a Schema object.")
+            raise ValueError("`output_attrs` must be a list of dicts or a `pydantic.BaseModel` object.")
 
         if len(self.output_field_names) != 1 and search_func is None:
             raise ValueError("If `search_func` is None, `output_attrs` must have a single field.")
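
Note: RetrieveOp still accepts two shapes for output_attrs; only the class-based branch changed. A hedged sketch of the dispatch, with an illustrative model and an extra isinstance guard added for clarity:

    from pydantic import BaseModel

    class RetrievedChunk(BaseModel):  # illustrative output schema
        chunk_text: str

    def extract_field_names(output_attrs) -> list[str]:
        # mirrors the branch above: a pydantic model class or a list of attribute dicts
        if isinstance(output_attrs, type) and issubclass(output_attrs, BaseModel):
            return list(output_attrs.model_fields)
        if isinstance(output_attrs, list):
            return [attr["name"] for attr in output_attrs]
        raise ValueError("`output_attrs` must be a list of dicts or a `pydantic.BaseModel` object.")

    print(extract_field_names(RetrievedChunk))                           # ['chunk_text']
    print(extract_field_names([{"name": "chunk_text", "desc": "..."}]))  # ['chunk_text']
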
@@ -160,8 +160,8 @@
         # construct the RecordOpStats object
         record_op_stats = RecordOpStats(
             record_id=output_dr.id,
-            record_parent_id=output_dr.parent_id,
-            record_source_idx=output_dr.source_idx,
+            record_parent_ids=output_dr.parent_ids,
+            record_source_indices=output_dr.source_indices,
             record_state=record_state,
             full_op_id=self.get_full_op_id(),
             logical_op_id=self.logical_op_id,
@@ -169,7 +169,7 @@
             time_per_record=total_time,
             cost_per_record=generation_stats.cost_per_record,
             answer=answer,
-            input_fields=self.input_schema.field_names(),
+            input_fields=list(self.input_schema.model_fields),
             generated_fields=generated_fields,
             fn_call_duration_secs=total_time - generation_stats.llm_call_duration_secs,
             llm_call_duration_secs=generation_stats.llm_call_duration_secs,
palimpzest/query/operators/scan.py

@@ -2,46 +2,43 @@ from __future__ import annotations
 
 import time
 from abc import ABC, abstractmethod
+from typing import Any
 
-from palimpzest.constants import (
-    LOCAL_SCAN_TIME_PER_KB,
-    MEMORY_SCAN_TIME_PER_KB,
-    Cardinality,
-)
-from palimpzest.core.data.dataclasses import OperatorCostEstimates, RecordOpStats
-from palimpzest.core.data.datareaders import DataReader, DirectoryReader, FileReader
+from palimpzest.constants import LOCAL_SCAN_TIME_PER_KB
+from palimpzest.core.data import context
 from palimpzest.core.elements.records import DataRecord, DataRecordSet
+from palimpzest.core.models import OperatorCostEstimates, RecordOpStats
 from palimpzest.query.operators.physical import PhysicalOperator
 
 
 class ScanPhysicalOp(PhysicalOperator, ABC):
     """
-    Physical operators which implement DataReaders require slightly more information
+    Physical operators which implement root Datasets require slightly more information
     in order to accurately compute naive cost estimates. Thus, we use a slightly
     modified abstract base class for these operators.
     """
-
-    def __init__(self, datareader: DataReader, *args, **kwargs):
+    # datasource: IterDataset
+    def __init__(self, datasource: Any, *args, **kwargs):
         super().__init__(*args, **kwargs)
-        self.datareader = datareader
+        self.datasource = datasource
 
     def __str__(self):
-        op = f"{self.op_name()}({self.datareader}) -> {self.output_schema}\n"
-        op += f" ({', '.join(self.output_schema.field_names())[:30]})\n"
+        op = f"{self.op_name()}({self.datasource}) -> {self.output_schema}\n"
+        op += f" ({', '.join(list(self.output_schema.model_fields))[:30]})\n"
         return op
 
     def get_id_params(self):
-        return super().get_id_params()
+        id_params = super().get_id_params()
+        return {"datasource_id": self.datasource.id, **id_params}
 
     def get_op_params(self):
         op_params = super().get_op_params()
-        return {"datareader": self.datareader, **op_params}
+        return {"datasource": self.datasource, **op_params}
 
     @abstractmethod
     def naive_cost_estimates(
         self,
         source_op_cost_estimates: OperatorCostEstimates,
-        input_cardinality: Cardinality,
         input_record_size_in_bytes: int | float,
     ) -> OperatorCostEstimates:
         """
@@ -62,30 +59,27 @@ class ScanPhysicalOp(PhysicalOperator, ABC):
 
     def __call__(self, idx: int) -> DataRecordSet:
         """
-        This function invokes `self.datareader.__getitem__` on the given `idx` to retrieve the next data item.
+        This function invokes `self.datasource.__getitem__` on the given `idx` to retrieve the next data item.
         It then returns this item as a DataRecord wrapped in a DataRecordSet.
         """
         start_time = time.time()
-        item = self.datareader[idx]
+        item = self.datasource[idx]
         end_time = time.time()
 
-        # TODO: remove once validation data is refactored
-        item_field_dict = item.get("fields", item)
-
         # check that item covers fields in output schema
-        output_field_names = self.output_schema.field_names()
-        assert all([field in item_field_dict for field in output_field_names]), f"Some fields in DataReader schema not present in item!\n - DataReader fields: {output_field_names}\n - Item fields: {list(item.keys())}"
+        output_field_names = list(self.output_schema.model_fields)
+        assert all([field in item for field in output_field_names]), f"Some fields in Dataset schema not present in item!\n - Dataset fields: {output_field_names}\n - Item fields: {list(item.keys())}"
 
         # construct a DataRecord from the item
-        dr = DataRecord(self.output_schema, source_idx=idx)
+        dr = DataRecord(self.output_schema, source_indices=[f"{self.datasource.id}-{idx}"])
         for field in output_field_names:
-            setattr(dr, field, item_field_dict[field])
+            setattr(dr, field, item[field])
 
         # create RecordOpStats objects
         record_op_stats = RecordOpStats(
             record_id=dr.id,
-            record_parent_id=dr.parent_id,
-            record_source_idx=dr.source_idx,
+            record_parent_ids=dr.parent_ids,
+            record_source_indices=dr.source_indices,
             record_state=dr.to_dict(include_bytes=False),
             full_op_id=self.get_full_op_id(),
             logical_op_id=self.logical_op_id,
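
Note: record lineage becomes plural in this release: parent_id/source_idx turn into parent_ids/source_indices, and each scan tags records as "<datasource_id>-<idx>". Presumably this lets a record carry several parents once the new join operator (palimpzest/query/operators/join.py, added in this release) fuses inputs. A toy illustration of the tagging scheme, not package code:

    datasource_id = "tiny-source"
    source_indices = [f"{datasource_id}-{idx}" for idx in range(3)]
    print(source_indices)  # ['tiny-source-0', 'tiny-source-1', 'tiny-source-2']

    # a record produced from two parents could union their lineage, e.g.:
    left, right = ["rec-a"], ["rec-b"]
    print(left + right)    # ['rec-a', 'rec-b']
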
@@ -97,7 +91,7 @@ class ScanPhysicalOp(PhysicalOperator, ABC):
 
         # construct and return DataRecordSet object
         return DataRecordSet([dr], [record_op_stats])
-
+
 
 
 class MarshalAndScanDataOp(ScanPhysicalOp):
@@ -110,11 +104,14 @@ class MarshalAndScanDataOp(ScanPhysicalOp):
 
         # estimate time spent reading each record
         per_record_size_kb = input_record_size_in_bytes / 1024.0
-        time_per_record = (
-            LOCAL_SCAN_TIME_PER_KB * per_record_size_kb
-            if isinstance(self.datareader, (DirectoryReader, FileReader))
-            else MEMORY_SCAN_TIME_PER_KB * per_record_size_kb
-        )
+
+        # TODO: cannot do the first computation b/c we cannot import iter_dataset; possibly revisit
+        # time_per_record = (
+        #     MEMORY_SCAN_TIME_PER_KB * per_record_size_kb
+        #     if isinstance(self.datasource, (iter_dataset.MemoryDataset))
+        #     else LOCAL_SCAN_TIME_PER_KB * per_record_size_kb
+        # )
+        time_per_record = LOCAL_SCAN_TIME_PER_KB * per_record_size_kb
 
         # estimate output cardinality
         cardinality = source_op_cost_estimates.cardinality
@@ -128,26 +125,68 @@
         )
 
 
-class CacheScanDataOp(ScanPhysicalOp):
+class ContextScanOp(PhysicalOperator):
+    """
+    Physical operator which facillitates the loading of a Context for processing.
+    """
+
+    def __init__(self, context: context.Context, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.context = context
+
+    def __str__(self):
+        op = f"{self.op_name()}({self.context}) -> {self.output_schema}\n"
+        op += f" ({', '.join(list(self.output_schema.model_fields))[:30]})\n"
+        return op
+
+    def get_id_params(self):
+        return super().get_id_params()
+
+    def get_op_params(self):
+        op_params = super().get_op_params()
+        return {"context": self.context, **op_params}
+
     def naive_cost_estimates(
         self,
         source_op_cost_estimates: OperatorCostEstimates,
-        input_record_size_in_bytes: int | float,
     ):
         # get inputs needed for naive cost estimation
         # TODO: we should rename cardinality --> "multiplier" or "selectivity" one-to-one / one-to-many
 
         # estimate time spent reading each record
-        per_record_size_kb = input_record_size_in_bytes / 1024.0
-        time_per_record = LOCAL_SCAN_TIME_PER_KB * per_record_size_kb
+        time_per_record = LOCAL_SCAN_TIME_PER_KB * 1.0
 
-        # estimate output cardinality
-        cardinality = source_op_cost_estimates.cardinality
-
-        # for now, assume no cost per record for reading from cache
+        # for now, assume no cost per record for reading data
         return OperatorCostEstimates(
-            cardinality=cardinality,
+            cardinality=1.0,
             time_per_record=time_per_record,
             cost_per_record=0,
             quality=1.0,
         )
+
+    def __call__(self, *args, **kwargs) -> DataRecordSet:
+        """
+        This function returns the context as a DataRecord wrapped in a DataRecordSet.
+        """
+        # construct a DataRecord from the context
+        start_time = time.time()
+        dr = DataRecord(self.output_schema, source_indices=[f"{self.context.id}-{0}"])
+        dr.context = self.context
+        end_time = time.time()
+
+        # create RecordOpStats objects
+        record_op_stats = RecordOpStats(
+            record_id=dr.id,
+            record_parent_ids=dr.parent_ids,
+            record_source_indices=dr.source_indices,
+            record_state=dr.to_dict(include_bytes=False),
+            full_op_id=self.get_full_op_id(),
+            logical_op_id=self.logical_op_id,
+            op_name=self.op_name(),
+            time_per_record=(end_time - start_time),
+            cost_per_record=0.0,
+            op_details={k: str(v) for k, v in self.get_id_params().items()},
+        )
+
+        # construct and return DataRecordSet object
+        return DataRecordSet([dr], [record_op_stats])
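
Note: unlike the removed CacheScanDataOp, ContextScanOp always emits exactly one record (the Context itself), so its naive estimate pins cardinality to 1.0 and drops the record-size argument. A hedged mirror of that estimate; the dataclass and the constant's value below are illustrative, not the package's definitions:

    from dataclasses import dataclass

    @dataclass
    class OperatorCostEstimates:  # illustrative mirror of palimpzest.core.models
        cardinality: float
        time_per_record: float
        cost_per_record: float
        quality: float

    LOCAL_SCAN_TIME_PER_KB = 0.01  # assumed value; the real constant is in palimpzest.constants

    # one Context in, one record out: fixed cardinality, negligible cost
    est = OperatorCostEstimates(cardinality=1.0, time_per_record=LOCAL_SCAN_TIME_PER_KB * 1.0, cost_per_record=0.0, quality=1.0)
    print(est)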