palimpzest 0.7.21__py3-none-any.whl → 0.8.1__py3-none-any.whl

This diff shows the contents of publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (89)
  1. palimpzest/__init__.py +37 -6
  2. palimpzest/agents/__init__.py +0 -0
  3. palimpzest/agents/compute_agents.py +0 -0
  4. palimpzest/agents/search_agents.py +637 -0
  5. palimpzest/constants.py +343 -209
  6. palimpzest/core/data/context.py +393 -0
  7. palimpzest/core/data/context_manager.py +163 -0
  8. palimpzest/core/data/dataset.py +639 -0
  9. palimpzest/core/data/{datareaders.py → iter_dataset.py} +202 -126
  10. palimpzest/core/elements/groupbysig.py +16 -13
  11. palimpzest/core/elements/records.py +166 -75
  12. palimpzest/core/lib/schemas.py +152 -390
  13. palimpzest/core/{data/dataclasses.py → models.py} +306 -170
  14. palimpzest/policy.py +2 -27
  15. palimpzest/prompts/__init__.py +35 -5
  16. palimpzest/prompts/agent_prompts.py +357 -0
  17. palimpzest/prompts/context_search.py +9 -0
  18. palimpzest/prompts/convert_prompts.py +62 -6
  19. palimpzest/prompts/filter_prompts.py +51 -6
  20. palimpzest/prompts/join_prompts.py +163 -0
  21. palimpzest/prompts/moa_proposer_convert_prompts.py +6 -6
  22. palimpzest/prompts/prompt_factory.py +375 -47
  23. palimpzest/prompts/split_proposer_prompts.py +1 -1
  24. palimpzest/prompts/util_phrases.py +5 -0
  25. palimpzest/prompts/validator.py +239 -0
  26. palimpzest/query/execution/all_sample_execution_strategy.py +134 -76
  27. palimpzest/query/execution/execution_strategy.py +210 -317
  28. palimpzest/query/execution/execution_strategy_type.py +5 -7
  29. palimpzest/query/execution/mab_execution_strategy.py +249 -136
  30. palimpzest/query/execution/parallel_execution_strategy.py +153 -244
  31. palimpzest/query/execution/single_threaded_execution_strategy.py +107 -64
  32. palimpzest/query/generators/generators.py +160 -331
  33. palimpzest/query/operators/__init__.py +15 -5
  34. palimpzest/query/operators/aggregate.py +50 -33
  35. palimpzest/query/operators/compute.py +201 -0
  36. palimpzest/query/operators/convert.py +33 -19
  37. palimpzest/query/operators/critique_and_refine_convert.py +7 -5
  38. palimpzest/query/operators/distinct.py +62 -0
  39. palimpzest/query/operators/filter.py +26 -16
  40. palimpzest/query/operators/join.py +403 -0
  41. palimpzest/query/operators/limit.py +3 -3
  42. palimpzest/query/operators/logical.py +205 -77
  43. palimpzest/query/operators/mixture_of_agents_convert.py +10 -8
  44. palimpzest/query/operators/physical.py +27 -21
  45. palimpzest/query/operators/project.py +3 -3
  46. palimpzest/query/operators/rag_convert.py +7 -7
  47. palimpzest/query/operators/retrieve.py +9 -9
  48. palimpzest/query/operators/scan.py +81 -42
  49. palimpzest/query/operators/search.py +524 -0
  50. palimpzest/query/operators/split_convert.py +10 -8
  51. palimpzest/query/optimizer/__init__.py +7 -9
  52. palimpzest/query/optimizer/cost_model.py +108 -441
  53. palimpzest/query/optimizer/optimizer.py +123 -181
  54. palimpzest/query/optimizer/optimizer_strategy.py +66 -61
  55. palimpzest/query/optimizer/plan.py +352 -67
  56. palimpzest/query/optimizer/primitives.py +43 -19
  57. palimpzest/query/optimizer/rules.py +484 -646
  58. palimpzest/query/optimizer/tasks.py +127 -58
  59. palimpzest/query/processor/config.py +42 -76
  60. palimpzest/query/processor/query_processor.py +73 -18
  61. palimpzest/query/processor/query_processor_factory.py +46 -38
  62. palimpzest/schemabuilder/schema_builder.py +15 -28
  63. palimpzest/utils/model_helpers.py +32 -77
  64. palimpzest/utils/progress.py +114 -102
  65. palimpzest/validator/__init__.py +0 -0
  66. palimpzest/validator/validator.py +306 -0
  67. {palimpzest-0.7.21.dist-info → palimpzest-0.8.1.dist-info}/METADATA +6 -1
  68. palimpzest-0.8.1.dist-info/RECORD +95 -0
  69. palimpzest/core/lib/fields.py +0 -141
  70. palimpzest/prompts/code_synthesis_prompts.py +0 -28
  71. palimpzest/query/execution/random_sampling_execution_strategy.py +0 -240
  72. palimpzest/query/generators/api_client_factory.py +0 -30
  73. palimpzest/query/operators/code_synthesis_convert.py +0 -488
  74. palimpzest/query/operators/map.py +0 -130
  75. palimpzest/query/processor/nosentinel_processor.py +0 -33
  76. palimpzest/query/processor/processing_strategy_type.py +0 -28
  77. palimpzest/query/processor/sentinel_processor.py +0 -88
  78. palimpzest/query/processor/streaming_processor.py +0 -149
  79. palimpzest/sets.py +0 -405
  80. palimpzest/utils/datareader_helpers.py +0 -61
  81. palimpzest/utils/demo_helpers.py +0 -75
  82. palimpzest/utils/field_helpers.py +0 -69
  83. palimpzest/utils/generation_helpers.py +0 -69
  84. palimpzest/utils/sandbox.py +0 -183
  85. palimpzest-0.7.21.dist-info/RECORD +0 -95
  86. /palimpzest/core/{elements/index.py → data/index_dataset.py} +0 -0
  87. {palimpzest-0.7.21.dist-info → palimpzest-0.8.1.dist-info}/WHEEL +0 -0
  88. {palimpzest-0.7.21.dist-info → palimpzest-0.8.1.dist-info}/licenses/LICENSE +0 -0
  89. {palimpzest-0.7.21.dist-info → palimpzest-0.8.1.dist-info}/top_level.txt +0 -0
@@ -3,32 +3,32 @@ from __future__ import annotations
 import logging
 from copy import deepcopy
 
+from pydantic.fields import FieldInfo
+
 from palimpzest.constants import Model
-from palimpzest.core.data.datareaders import DataReader
-from palimpzest.core.lib.fields import Field
+from palimpzest.core.data.dataset import Dataset
+from palimpzest.core.lib.schemas import get_schema_field_names
 from palimpzest.policy import Policy
+from palimpzest.query.execution.execution_strategy_type import ExecutionStrategyType
 from palimpzest.query.operators.logical import (
-    Aggregate,
-    BaseScan,
+    ComputeOperator,
     ConvertScan,
+    Distinct,
     FilteredScan,
-    GroupByAggregate,
+    JoinOp,
     LimitScan,
-    LogicalOperator,
-    MapScan,
     Project,
-    RetrieveScan,
+    SearchOperator,
 )
 from palimpzest.query.optimizer import (
     IMPLEMENTATION_RULES,
     TRANSFORMATION_RULES,
 )
-from palimpzest.query.optimizer.cost_model import CostModel
+from palimpzest.query.optimizer.cost_model import BaseCostModel, SampleBasedCostModel
 from palimpzest.query.optimizer.optimizer_strategy_type import OptimizationStrategyType
 from palimpzest.query.optimizer.plan import PhysicalPlan
 from palimpzest.query.optimizer.primitives import Group, LogicalExpression
 from palimpzest.query.optimizer.rules import (
-    CodeSynthesisConvertRule,
     CriticAndRefineConvertRule,
     LLMConvertBondedRule,
     MixtureOfAgentsConvertRule,
@@ -42,21 +42,10 @@ from palimpzest.query.optimizer.tasks import (
     OptimizeLogicalExpression,
     OptimizePhysicalExpression,
 )
-from palimpzest.sets import Dataset, Set
-from palimpzest.utils.hash_helpers import hash_for_serialized_dict
-from palimpzest.utils.model_helpers import get_champion_model, get_code_champion_model, get_fallback_model
 
 logger = logging.getLogger(__name__)
 
 
-def get_node_uid(node: Dataset | DataReader) -> str:
-    """Helper function to compute the universal identifier for a node in the query plan."""
-    # NOTE: technically, hash_for_serialized_dict(node.serialize()) would be valid for both DataReader and Dataset;
-    # for the moment, I want to be explicit in Dataset about what constitutes a unique Dataset object, but
-    # in ther future we may be able to remove universal_identifier() from Dataset and just use this function
-    return node.universal_identifier() if isinstance(node, Dataset) else hash_for_serialized_dict(node.serialize())
-
-
 class Optimizer:
     """
     The optimizer is responsible for searching the space of possible physical plans
@@ -83,17 +72,19 @@ class Optimizer:
     def __init__(
         self,
         policy: Policy,
-        cost_model: CostModel,
+        cost_model: BaseCostModel,
         available_models: list[Model],
-        cache: bool = False,
+        join_parallelism: int = 64,
+        reasoning_effort: str | None = None,
+        api_base: str | None = None,
         verbose: bool = False,
         allow_bonded_query: bool = True,
-        allow_code_synth: bool = False,
         allow_rag_reduction: bool = False,
         allow_mixtures: bool = True,
         allow_critic: bool = False,
         allow_split_merge: bool = False,
         optimizer_strategy: OptimizationStrategyType = OptimizationStrategyType.PARETO,
+        execution_strategy: ExecutionStrategyType = ExecutionStrategyType.PARALLEL,
         use_final_op_quality: bool = False,  # TODO: make this func(plan) -> final_quality
         **kwargs,
     ):
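
For reference, a minimal sketch of how the new constructor surface might be exercised directly in 0.8.1 (in practice the query processor builds the Optimizer; the MaxQuality policy and Model.GPT_4o choices below are illustrative placeholders, not values taken from this diff):

    from palimpzest.constants import Model
    from palimpzest.policy import MaxQuality
    from palimpzest.query.execution.execution_strategy_type import ExecutionStrategyType
    from palimpzest.query.optimizer.cost_model import SampleBasedCostModel
    from palimpzest.query.optimizer.optimizer import Optimizer
    from palimpzest.query.optimizer.optimizer_strategy_type import OptimizationStrategyType

    # cache= and allow_code_synth= are gone in 0.8.1; join_parallelism, reasoning_effort,
    # api_base, and execution_strategy are new.
    optimizer = Optimizer(
        policy=MaxQuality(),                                  # placeholder policy
        cost_model=SampleBasedCostModel(),
        available_models=[Model.GPT_4o],                      # placeholder model
        join_parallelism=64,
        reasoning_effort=None,
        api_base=None,
        optimizer_strategy=OptimizationStrategyType.PARETO,
        execution_strategy=ExecutionStrategyType.PARALLEL,
    )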
@@ -128,7 +119,6 @@ class Optimizer:
         # and remove all optimizations (except for bonded queries)
         if optimizer_strategy == OptimizationStrategyType.NONE:
             self.allow_bonded_query = True
-            self.allow_code_synth = False
             self.allow_rag_reduction = False
             self.allow_mixtures = False
             self.allow_critic = False
@@ -136,16 +126,18 @@ class Optimizer:
             self.available_models = [available_models[0]]
 
         # store optimization hyperparameters
-        self.cache = cache
         self.verbose = verbose
         self.available_models = available_models
+        self.join_parallelism = join_parallelism
+        self.reasoning_effort = reasoning_effort
+        self.api_base = api_base
         self.allow_bonded_query = allow_bonded_query
-        self.allow_code_synth = allow_code_synth
         self.allow_rag_reduction = allow_rag_reduction
         self.allow_mixtures = allow_mixtures
         self.allow_critic = allow_critic
         self.allow_split_merge = allow_split_merge
         self.optimizer_strategy = optimizer_strategy
+        self.execution_strategy = execution_strategy
         self.use_final_op_quality = use_final_op_quality
 
         # prune implementation rules based on boolean flags
@@ -156,11 +148,6 @@ class Optimizer:
                 if rule not in [LLMConvertBondedRule]
             ]
 
-        if not self.allow_code_synth:
-            self.implementation_rules = [
-                rule for rule in self.implementation_rules if not issubclass(rule, CodeSynthesisConvertRule)
-            ]
-
         if not self.allow_rag_reduction:
             self.implementation_rules = [
                 rule for rule in self.implementation_rules if not issubclass(rule, RAGConvertRule)
@@ -184,32 +171,34 @@ class Optimizer:
         logger.info(f"Initialized Optimizer with verbose={self.verbose}")
         logger.debug(f"Initialized Optimizer with params: {self.__dict__}")
 
-    def update_cost_model(self, cost_model: CostModel):
+    def update_cost_model(self, cost_model: BaseCostModel):
         self.cost_model = cost_model
 
     def get_physical_op_params(self):
         return {
             "verbose": self.verbose,
             "available_models": self.available_models,
-            "champion_model": get_champion_model(self.available_models),
-            "code_champion_model": get_code_champion_model(self.available_models),
-            "fallback_model": get_fallback_model(self.available_models),
+            "join_parallelism": self.join_parallelism,
+            "reasoning_effort": self.reasoning_effort,
+            "api_base": self.api_base,
         }
 
     def deepcopy_clean(self):
         optimizer = Optimizer(
             policy=self.policy,
-            cost_model=CostModel(),
-            cache=self.cache,
+            cost_model=SampleBasedCostModel(),
             verbose=self.verbose,
             available_models=self.available_models,
+            join_parallelism=self.join_parallelism,
+            reasoning_effort=self.reasoning_effort,
+            api_base=self.api_base,
             allow_bonded_query=self.allow_bonded_query,
-            allow_code_synth=self.allow_code_synth,
             allow_rag_reduction=self.allow_rag_reduction,
             allow_mixtures=self.allow_mixtures,
             allow_critic=self.allow_critic,
             allow_split_merge=self.allow_split_merge,
             optimizer_strategy=self.optimizer_strategy,
+            execution_strategy=self.execution_strategy,
             use_final_op_quality=self.use_final_op_quality,
         )
         return optimizer
@@ -219,121 +208,65 @@ class Optimizer:
         optimizer_strategy_cls = optimizer_strategy.value
         self.strategy = optimizer_strategy_cls()
 
-    def construct_group_tree(self, dataset_nodes: list[Set]) -> tuple[list[int], dict[str, Field], dict[str, set[str]]]:
-        # get node, output_schema, and input_schema (if applicable)
-        logger.debug(f"Constructing group tree for dataset_nodes: {dataset_nodes}")
-
-        node = dataset_nodes[-1]
-        output_schema = node.schema
-        input_schema = dataset_nodes[-2].schema if len(dataset_nodes) > 1 else None
-
+    def construct_group_tree(self, dataset: Dataset) -> tuple[int, dict[str, FieldInfo], dict[str, set[str]]]:
+        logger.debug(f"Constructing group tree for dataset: {dataset}")
         ### convert node --> Group ###
-        uid = get_node_uid(node)
-
         # create the op for the given node
-        op: LogicalOperator | None = None
-
-        # TODO: add cache scan when we add caching back to PZ
-        # if self.cache:
-        #     op = CacheScan(datareader=node, output_schema=output_schema)
-        if isinstance(node, DataReader):
-            op = BaseScan(datareader=node, output_schema=output_schema)
-        elif node._filter is not None:
-            op = FilteredScan(
-                input_schema=input_schema,
-                output_schema=output_schema,
-                filter=node._filter,
-                depends_on=node._depends_on,
-                target_cache_id=uid,
-            )
-        elif node._group_by is not None:
-            op = GroupByAggregate(
-                input_schema=input_schema,
-                output_schema=output_schema,
-                group_by_sig=node._group_by,
-                target_cache_id=uid,
-            )
-        elif node._agg_func is not None:
-            op = Aggregate(
-                input_schema=input_schema,
-                output_schema=output_schema,
-                agg_func=node._agg_func,
-                target_cache_id=uid,
-            )
-        elif node._limit is not None:
-            op = LimitScan(
-                input_schema=input_schema,
-                output_schema=output_schema,
-                limit=node._limit,
-                target_cache_id=uid,
-            )
-        elif node._project_cols is not None:
-            op = Project(
-                input_schema=input_schema,
-                output_schema=output_schema,
-                project_cols=node._project_cols,
-                target_cache_id=uid,
-            )
-        elif node._index is not None:
-            op = RetrieveScan(
-                input_schema=input_schema,
-                output_schema=output_schema,
-                index=node._index,
-                search_func=node._search_func,
-                search_attr=node._search_attr,
-                output_attrs=node._output_attrs,
-                k=node._k,
-                target_cache_id=uid,
-            )
-        elif output_schema != input_schema:
-            op = ConvertScan(
-                input_schema=input_schema,
-                output_schema=output_schema,
-                cardinality=node._cardinality,
-                udf=node._udf,
-                depends_on=node._depends_on,
-                target_cache_id=uid,
-            )
-        elif output_schema == input_schema and node._udf is not None:
-            op = MapScan(
-                input_schema=input_schema,
-                output_schema=output_schema,
-                udf=node._udf,
-                target_cache_id=uid,
-            )
-        # some legacy plans may have a useless convert; for now we simply skip it
-        elif output_schema == input_schema:
-            return self.construct_group_tree(dataset_nodes[:-1]) if len(dataset_nodes) > 1 else ([], {}, {})
+        op = dataset._operator
+
+        # compute the input group id(s) and field(s) for this node
+        if len(dataset._sources) == 0:
+            input_group_ids, input_group_fields, input_group_properties = ([], {}, {})
+        elif len(dataset._sources) == 1:
+            input_group_id, input_group_fields, input_group_properties = self.construct_group_tree(dataset._sources[0])
+            input_group_ids = [input_group_id]
+        elif len(dataset._sources) == 2:
+            left_input_group_id, left_input_group_fields, left_input_group_properties = self.construct_group_tree(dataset._sources[0])
+            right_input_group_id, right_input_group_fields, right_input_group_properties = self.construct_group_tree(dataset._sources[1])
+            input_group_ids = [left_input_group_id, right_input_group_id]
+            input_group_fields = {**left_input_group_fields, **right_input_group_fields}
+            input_group_properties = deepcopy(left_input_group_properties)
+            for k, v in right_input_group_properties.items():
+                if k in input_group_properties:
+                    input_group_properties[k].update(v)
+                else:
+                    input_group_properties[k] = deepcopy(v)
         else:
-            raise NotImplementedError(
-                f"""No logical operator exists for the specified dataset construction.
-                {input_schema}->{output_schema} {"with filter:'" + node._filter + "'" if node._filter is not None else ""}"""
-            )
-
-        # compute the input group ids and fields for this node
-        input_group_ids, input_group_fields, input_group_properties = (
-            self.construct_group_tree(dataset_nodes[:-1]) if len(dataset_nodes) > 1 else ([], {}, {})
-        )
+            raise NotImplementedError("Constructing group trees for datasets with more than 2 sources is not supported.")
 
         # compute the fields added by this operation and all fields
         input_group_short_field_names = list(
             map(lambda full_field: full_field.split(".")[-1], input_group_fields.keys())
         )
         new_fields = {
-            field_name: field
-            for field_name, field in op.output_schema.field_map(unique=True, id=uid).items()
-            if (field_name.split(".")[-1] not in input_group_short_field_names) or (node._udf is not None)
+            field_name: op.output_schema.model_fields[field_name.split(".")[-1]]
+            for field_name in get_schema_field_names(op.output_schema, id=dataset.id)
+            if (field_name not in input_group_short_field_names) or (hasattr(op, "udf") and op.udf is not None)
         }
         all_fields = {**input_group_fields, **new_fields}
 
         # compute the set of (short) field names this operation depends on
         depends_on_field_names = (
-            {} if isinstance(node, DataReader) else {field_name.split(".")[-1] for field_name in node._depends_on}
+            {} if dataset.is_root else {field_name.split(".")[-1] for field_name in op.depends_on}
        )
 
+        # NOTE: group_id is computed as the unique (sorted) set of fields and properties;
+        # If an operation does not modify the fields (or modifies them in a way that
+        # can create an idential field set to an earlier group) then we must add an
+        # id from the operator to disambiguate the two groups.
         # compute all properties including this operations'
         all_properties = deepcopy(input_group_properties)
-        if isinstance(op, FilteredScan):
+        if isinstance(op, ConvertScan) and sorted(op.input_schema.model_fields.keys()) == sorted(op.output_schema.model_fields.keys()):
+            model_fields_dict = {
+                k: {"annotation": v.annotation, "default": v.default, "description": v.description}
+                for k, v in op.output_schema.model_fields.items()
+            }
+            if "maps" in all_properties:
+                all_properties["maps"].add(model_fields_dict)
+            else:
+                all_properties["maps"] = set([model_fields_dict])
+
+        elif isinstance(op, FilteredScan):
             # NOTE: we could use op.get_full_op_id() here, but storing filter strings makes
             # debugging a bit easier as you can read which filters are in the Group
             op_filter_str = op.filter.get_filter_str()
@@ -342,6 +275,12 @@ class Optimizer:
             else:
                 all_properties["filters"] = set([op_filter_str])
 
+        elif isinstance(op, JoinOp):
+            if "joins" in all_properties:
+                all_properties["joins"].add(op.condition)
+            else:
+                all_properties["joins"] = set([op.condition])
+
         elif isinstance(op, LimitScan):
             op_limit_str = op.get_logical_op_id()
             if "limits" in all_properties:
@@ -356,12 +295,27 @@ class Optimizer:
             else:
                 all_properties["projects"] = set([op_project_str])
 
-        elif isinstance(op, MapScan):
-            op_udf_str = op.udf.__name__
-            if "udfs" in all_properties:
-                all_properties["udfs"].add(op_udf_str)
+        elif isinstance(op, Distinct):
+            op_distinct_str = op.get_logical_op_id()
+            if "distincts" in all_properties:
+                all_properties["distincts"].add(op_distinct_str)
+            else:
+                all_properties["distincts"] = set([op_distinct_str])
+
+        # TODO: temporary fix; perhaps use op_ids to identify group?
+        elif isinstance(op, ComputeOperator):
+            op_instruction = op.instruction
+            if "instructions" in all_properties:
+                all_properties["instructions"].add(op_instruction)
             else:
-                all_properties["udfs"] = set([op_udf_str])
+                all_properties["instructions"] = set([op_instruction])
+
+        elif isinstance(op, SearchOperator):
+            op_search_query = op.search_query
+            if "search_queries" in all_properties:
+                all_properties["search_queries"].add(op_search_query)
+            else:
+                all_properties["search_queries"] = set([op_search_query])
 
         # construct the logical expression and group
         logical_expression = LogicalExpression(
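
The properties bookkeeping above is easiest to see with a toy example outside of Palimpzest: group properties are a plain dict of sets, so two groups with identical fields are still kept distinct by the filters, joins, limits, instructions, and search queries recorded along the way, and combining the properties of two join inputs is a key-wise set union (a standalone illustration with made-up keys and values, not package code):

    # Standalone illustration of the dict-of-sets "properties" merge performed when a
    # group has two input groups (e.g., under a JoinOp). All strings here are invented.
    left = {"filters": {"mentions a clinical trial"}}
    right = {"filters": {"published after 2020"}, "joins": {"paper.id == review.paper_id"}}

    merged = {k: set(v) for k, v in left.items()}
    for key, values in right.items():
        merged.setdefault(key, set()).update(values)

    print(merged["filters"])  # both filter strings survive as one set
    print(merged["joins"])    # {'paper.id == review.paper_id'}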
@@ -380,62 +334,50 @@ class Optimizer:
         logical_expression.set_group_id(group.group_id)
 
         # add the expression and group to the optimizer's expressions and groups and return
-        self.expressions[logical_expression.get_expr_id()] = logical_expression
+        self.expressions[logical_expression.expr_id] = logical_expression
         self.groups[group.group_id] = group
-        logger.debug(f"Constructed group tree for dataset_nodes: {dataset_nodes}")
+        logger.debug(f"Constructed group tree for dataset: {dataset}")
         logger.debug(f"Group: {group.group_id}, {all_fields}, {all_properties}")
 
-        return [group.group_id], all_fields, all_properties
-
-    def convert_query_plan_to_group_tree(self, query_plan: Dataset) -> str:
-        logger.debug(f"Converting query plan to group tree for query_plan: {query_plan}")
-        # Obtain ordered list of datasets
-        dataset_nodes: list[Dataset | DataReader] = []
-        node = query_plan.copy()
+        return group.group_id, all_fields, all_properties
 
-        # NOTE: the very first node will be a DataReader; the rest will be Dataset
-        while isinstance(node, Dataset):
-            dataset_nodes.append(node)
-            node = node._source
-        dataset_nodes.append(node)
-        dataset_nodes = list(reversed(dataset_nodes))
+    def convert_query_plan_to_group_tree(self, dataset: Dataset) -> str:
+        logger.debug(f"Converting query plan to group tree for dataset: {dataset}")
 
         # compute depends_on field for every node
         short_to_full_field_name = {}
-        for node_idx, node in enumerate(dataset_nodes):
+        for node in dataset:
             # update mapping from short to full field names
-            short_field_names = node.schema.field_names()
-            full_field_names = node.schema.field_names(unique=True, id=get_node_uid(node))
+            short_field_names = get_schema_field_names(node.schema)
+            full_field_names = get_schema_field_names(node.schema, id=node.id)
             for short_field_name, full_field_name in zip(short_field_names, full_field_names):
                 # set mapping automatically if this is a new field
-                if short_field_name not in short_to_full_field_name or (
-                    node_idx > 0 and dataset_nodes[node_idx - 1].schema != node.schema and node._udf is not None
-                ):
+                if short_field_name not in short_to_full_field_name or (hasattr(node._operator, "udf") and node._operator.udf is not None):
                     short_to_full_field_name[short_field_name] = full_field_name
 
-            # if the node is a data source, then skip
-            if isinstance(node, DataReader):
+            # if the node is a root Dataset, then skip
+            if node.is_root:
                 continue
 
             # If the node already has depends_on specified, then resolve each field name to a full (unique) field name
-            if len(node._depends_on) > 0:
-                node._depends_on = list(map(lambda field: short_to_full_field_name[field], node._depends_on))
+            if len(node._operator.depends_on) > 0:
+                node._operator.depends_on = list(map(lambda field: short_to_full_field_name[field], node._operator.depends_on))
                 continue
 
             # otherwise, make the node depend on all upstream nodes
-            node._depends_on = set()
-            for upstream_node in dataset_nodes[:node_idx]:
-                node._depends_on.update(upstream_node.schema.field_names(unique=True, id=get_node_uid(upstream_node)))
-            node._depends_on = list(node._depends_on)
+            node._operator.depends_on = set()
+            upstream_nodes = node.get_upstream_datasets()
+            for upstream_node in upstream_nodes:
+                upstream_field_names = get_schema_field_names(upstream_node.schema, id=upstream_node.id)
+                node._operator.depends_on.update(upstream_field_names)
+            node._operator.depends_on = list(node._operator.depends_on)
 
         # construct tree of groups
-        final_group_id, _, _ = self.construct_group_tree(dataset_nodes)
+        final_group_id, _, _ = self.construct_group_tree(dataset)
 
-        # check that final_group_id is a singleton
-        assert len(final_group_id) == 1
-        final_group_id = final_group_id[0]
-        logger.debug(f"Converted query plan to group tree for query_plan: {query_plan}")
+        logger.debug(f"Converted query plan to group tree for dataset: {dataset}")
         logger.debug(f"Final group id: {final_group_id}")
+
         return final_group_id
 
     def heuristic_optimization(self, group_id: int) -> None:
@@ -462,24 +404,24 @@ class Optimizer:
             elif isinstance(task, ApplyRule):
                 context = {"costed_full_op_ids": self.cost_model.get_costed_full_op_ids()}
                 new_tasks = task.perform(
-                    self.groups, self.expressions, context=context, **self.get_physical_op_params()
+                    self.groups, self.expressions, context=context, **self.get_physical_op_params(),
                 )
             elif isinstance(task, OptimizePhysicalExpression):
-                context = {"optimizer_strategy": self.optimizer_strategy}
+                context = {"optimizer_strategy": self.optimizer_strategy, "execution_strategy": self.execution_strategy}
                 new_tasks = task.perform(self.cost_model, self.groups, self.policy, context=context)
-
             self.tasks_stack.extend(new_tasks)
 
         logger.debug(f"Done searching optimization space for group_id: {group_id}")
 
-    def optimize(self, query_plan: Dataset) -> list[PhysicalPlan]:
+    def optimize(self, dataset: Dataset) -> list[PhysicalPlan]:
         """
         The optimize function takes in an initial query plan and searches the space of
         logical and physical plans in order to cost and produce a (near) optimal physical plan.
         """
-        logger.info(f"Optimizing query plan: {query_plan}")
+        logger.info(f"Optimizing query plan: {dataset}")
         # compute the initial group tree for the user plan
-        final_group_id = self.convert_query_plan_to_group_tree(query_plan)
+        dataset_copy = dataset.copy()
+        final_group_id = self.convert_query_plan_to_group_tree(dataset_copy)
 
         # TODO
         # # do heuristic based pre-optimization