pineapple-pine 0.7.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
pine/config.py ADDED
@@ -0,0 +1,304 @@
1
+ from __future__ import annotations
2
+
3
+ import json
4
+ from dataclasses import dataclass, field
5
+ from typing import Any
6
+
7
+ from pine.errors import ConfigError
8
+
9
# Keys of an operator's JSON node that the engine interprets itself.
# `_parse_operator_config` copies every key NOT in this set into
# `OperatorConfig.raw_params`.
RESERVED_KEYS = frozenset(
    (
        # identity / annotations
        "type_name",
        "$metadata",
        "$code_info",
        # scheduling and control
        "skip",
        "recall",
        "sources",
        "debug",
        "for_branch_control",
        "data_parallel",
        # row-set access declarations
        "consumes_row_set",
        "mutates_row_set",
        "additive_writes_row_set",
        # per-field defaults (consumed by InputFieldSpec.compute)
        "common_defaults",
        "item_defaults",
    )
)
15
+
16
+
17
@dataclass
class Metadata:
    """Declared data dependencies of one operator (its ``$metadata`` node).

    Every list holds field names. ``*_input`` lists name fields the operator
    reads and ``*_output`` lists name fields it writes; the ``common_*`` and
    ``item_*`` variants are tracked as separate scopes when DAG edges are
    derived.
    """

    # Field order defines the positional __init__ signature -- keep it stable.
    common_input: list[str] = field(default_factory=list)
    common_output: list[str] = field(default_factory=list)
    item_input: list[str] = field(default_factory=list)
    item_output: list[str] = field(default_factory=list)
23
+
24
+
25
@dataclass
class OperatorConfig:
    """Per-operator configuration within a pipeline.

    Filled from one entry of ``pipeline_config.operators`` by
    ``_parse_operator_config``: keys in ``RESERVED_KEYS`` map onto the typed
    fields below, every remaining key is stored in ``raw_params``.
    """

    # NOTE: field order defines the positional __init__ signature.

    # Operator implementation type; _validate rejects an empty value.
    type_name: str = ""
    # Declared read/write fields (the "$metadata" node).
    metadata: Metadata = field(default_factory=Metadata)
    # Control-field names; each must start with "_" and also be listed in
    # metadata.common_input (enforced by _validate).
    skip: list[str] = field(default_factory=list)
    recall: bool = False
    # Operators this one must run after (DAG.build adds a src -> op edge).
    sources: list[str] = field(default_factory=list)
    debug: bool = False

    # Row-set access declarations, used when deriving item-scope DAG edges.
    consumes_row_set: bool = False
    mutates_row_set: bool = False
    additive_writes_row_set: bool = False

    # NOTE(review): semantics of these two are not visible in this module.
    for_branch_control: bool = False
    data_parallel: int = 1

    # Per-field default values, consumed by InputFieldSpec.compute.
    common_defaults: dict[str, Any] = field(default_factory=dict)
    item_defaults: dict[str, Any] = field(default_factory=dict)
    # All non-reserved keys of the operator's JSON node.
    raw_params: dict[str, Any] = field(default_factory=dict)

    # Not populated by _parse_operator_config; presumably filled in later by
    # the engine -- verify against callers.
    operator_type: str = ""
    input_spec: InputFieldSpec | None = None
45
+
46
+
47
class DefaultedField:
    """An input field name paired with the default value configured for it.

    Produced by ``InputFieldSpec.compute`` for input fields that have an
    entry in ``common_defaults`` / ``item_defaults``.
    """

    __slots__ = ("name", "default")

    def __init__(self, name: str, default: Any):
        self.name = name
        self.default = default

    def __repr__(self) -> str:
        # The inherited object repr shows only a memory address; expose both
        # fields so resolved input specs are readable in logs and debuggers.
        return f"{type(self).__name__}(name={self.name!r}, default={self.default!r})"
53
+
54
+
55
class InputFieldSpec:
    """Input fields of one operator, partitioned into strict and defaulted.

    A field is *defaulted* when the operator config supplies a default value
    for it and *strict* otherwise. Common-scope fields listed in ``skip`` are
    left out entirely; item-scope fields are never skipped.
    """

    __slots__ = ("strict_common", "defaulted_common", "strict_item", "defaulted_item")

    def __init__(
        self,
        strict_common: list[str],
        defaulted_common: list[DefaultedField],
        strict_item: list[str],
        defaulted_item: list[DefaultedField],
    ):
        # common-scope fields
        self.strict_common = strict_common
        self.defaulted_common = defaulted_common
        # item-scope fields
        self.strict_item = strict_item
        self.defaulted_item = defaulted_item

    @staticmethod
    def _split(names, defaults: dict[str, Any]):
        """Partition *names* (order kept) into (strict, defaulted) lists."""
        strict: list[str] = []
        defaulted: list[DefaultedField] = []
        for name in names:
            if name in defaults:
                defaulted.append(DefaultedField(name, defaults[name]))
            else:
                strict.append(name)
        return strict, defaulted

    @staticmethod
    def compute(metadata: "Metadata", common_defaults: dict[str, Any],
                item_defaults: dict[str, Any], skip: list[str]) -> "InputFieldSpec":
        """Build the spec from declared inputs, configured defaults and skip list."""
        excluded = set(skip)
        kept_common = [name for name in metadata.common_input if name not in excluded]
        strict_common, defaulted_common = InputFieldSpec._split(kept_common, common_defaults)
        strict_item, defaulted_item = InputFieldSpec._split(metadata.item_input, item_defaults)
        return InputFieldSpec(strict_common, defaulted_common, strict_item, defaulted_item)
95
+
96
+
97
class SubFlowRef:
    """Reference to a named pipeline: an ordered list of entry names.

    Entries may name operators or other sub-flows (see
    ``Config._expand_entries``).
    """

    def __init__(self, pipeline: list[str] | None = None):
        # Any falsy argument (None or an empty list) is replaced by a fresh
        # list; a non-empty list is stored as-is, not copied.
        self.pipeline: list[str] = pipeline if pipeline else []
100
+
101
+
102
class FlowContract:
    """Declares the pipeline's required inputs and guaranteed outputs.

    Each attribute is a list of field names; ``Config.load`` fills them from
    the top-level ``flow_contract`` JSON object. All start empty.
    """

    def __init__(self):
        # inputs the pipeline declares as required
        self.common_input: list[str] = []
        self.item_input: list[str] = []
        # outputs the pipeline declares as guaranteed
        self.common_output: list[str] = []
        self.item_output: list[str] = []
110
+
111
+
112
class PipelineConfig:
    """Holds operator configs and the sub-flow pipeline map.

    ``operators`` maps operator name to its OperatorConfig and
    ``pipeline_map`` maps sub-flow name to its SubFlowRef. Both start empty
    and are filled by ``Config.load``.
    """

    def __init__(self):
        self.operators: dict[str, OperatorConfig] = {}
        self.pipeline_map: dict[str, SubFlowRef] = {}
118
+
119
+
120
class ExpandResult:
    """Output of ``Config.expand_operator_sequence_with_sub_flows``.

    ``sequence`` is the flattened, ordered list of operator names and
    ``op_to_sub_flow`` maps each operator to the sub-flow that listed it
    directly ("" for the top-level pipeline group).
    """

    def __init__(self, sequence: list[str], op_to_sub_flow: dict[str, str]):
        self.sequence, self.op_to_sub_flow = sequence, op_to_sub_flow
124
+
125
+
126
class Config:
    """Pipeline configuration parsed from JSON.

    Create instances with :meth:`load`; a directly constructed ``Config``
    holds empty defaults.
    """

    def __init__(self):
        self.pineapple_version: str = ""
        self.log_prefix: str = ""
        self.debug: bool = False
        self.storage_mode: str = "row"
        self.pipeline_config: PipelineConfig = PipelineConfig()
        self.pipeline_group: dict[str, SubFlowRef] = {}
        self.flow_contract: FlowContract = FlowContract()

    @classmethod
    def load(cls, json_data: bytes) -> "Config":
        """Parse and validate a pipeline configuration document.

        Args:
            json_data: The JSON document (anything ``json.loads`` accepts:
                bytes, bytearray or str).

        Returns:
            A populated, validated instance of ``cls``.

        Raises:
            ConfigError: if the document is not valid JSON, its top level is
                not a JSON object, or the parsed config fails validation.
        """
        try:
            root = json.loads(json_data)
        except (TypeError, ValueError, RecursionError) as e:
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError;
            # RecursionError is raised for pathologically deep nesting.
            raise ConfigError(f"failed to parse config JSON: {e}") from e
        if not isinstance(root, dict):
            # Everything below uses dict access; a top-level list/str/number
            # would otherwise surface as an unrelated AttributeError.
            raise ConfigError(
                f"config JSON must be an object, got {type(root).__name__}"
            )

        cfg = cls()
        cfg.pineapple_version = root.get("_PINEAPPLE_VERSION", "")
        cfg.log_prefix = root.get("log_prefix", "")
        cfg.debug = root.get("debug", False)
        cfg.storage_mode = root.get("storage_mode", "row")

        fc = root.get("flow_contract", {})
        cfg.flow_contract.common_input = fc.get("common_input", [])
        cfg.flow_contract.item_input = fc.get("item_input", [])
        cfg.flow_contract.common_output = fc.get("common_output", [])
        cfg.flow_contract.item_output = fc.get("item_output", [])

        pg = root.get("pipeline_group", {})
        for name, val in pg.items():
            cfg.pipeline_group[name] = SubFlowRef(val.get("pipeline", []))

        pc = root.get("pipeline_config", {})
        pm = pc.get("pipeline_map", {})
        for name, val in pm.items():
            cfg.pipeline_config.pipeline_map[name] = SubFlowRef(val.get("pipeline", []))

        ops = pc.get("operators", {})
        for name, op_node in ops.items():
            cfg.pipeline_config.operators[name] = _parse_operator_config(op_node)

        _validate(cfg)
        return cfg

    def expand_operator_sequence_with_sub_flows(self) -> ExpandResult:
        """Flatten the entry pipeline group into an ordered operator sequence.

        The entry group is ``pipeline_group["main"]`` when present, otherwise
        the single group. Sub-flow references are expanded recursively,
        depth-first, in listing order.

        Returns:
            ExpandResult with the operator sequence and, per operator, the
            sub-flow that listed it ("" for the entry group).

        Raises:
            ConfigError: if no unambiguous entry group exists, a name is both
                an operator and a sub-flow, an operator is reached twice,
                sub-flows reference each other cyclically, or an entry names
                neither an operator nor a sub-flow.
        """
        if "main" in self.pipeline_group:
            group = self.pipeline_group["main"]
        elif len(self.pipeline_group) == 1:
            group = next(iter(self.pipeline_group.values()))
        else:
            raise ConfigError(
                'pipeline_group must contain a "main" entry or exactly one entry'
            )

        for name in self.pipeline_config.operators:
            if name in self.pipeline_config.pipeline_map:
                raise ConfigError(
                    f'name "{name}" exists in both operators and pipeline_map'
                )

        sequence: list[str] = []
        op_to_sub_flow: dict[str, str] = {}
        visiting: set[str] = set()
        seen: set[str] = set()

        self._expand_entries(
            group.pipeline, "", sequence, op_to_sub_flow, visiting, seen
        )
        return ExpandResult(sequence, op_to_sub_flow)

    def _expand_entries(
        self,
        entries: list[str],
        parent_path: str,
        sequence: list[str],
        op_to_sub_flow: dict[str, str],
        visiting: set[str],
        seen: set[str],
    ):
        """Append the operators reachable from *entries* to *sequence*.

        Args:
            entries: Names from one pipeline list.
            parent_path: Name of the sub-flow that owns *entries* ("" at the
                top level); recorded for every operator appended directly.
            sequence: Output list of operator names, appended in order.
            op_to_sub_flow: Output map operator name -> owning sub-flow.
            visiting: Sub-flows on the current recursion stack (cycle check).
            seen: Operators already appended (duplicate check).
        """
        for entry in entries:
            if entry in self.pipeline_config.operators:
                if entry in seen:
                    raise ConfigError(
                        f'operator "{entry}" referenced more than once in pipeline tree'
                    )
                seen.add(entry)
                sequence.append(entry)
                op_to_sub_flow[entry] = parent_path
            elif entry in self.pipeline_config.pipeline_map:
                if entry in visiting:
                    raise ConfigError(
                        f'cycle detected in sub-flow expansion: "{entry}"'
                    )
                visiting.add(entry)
                self._expand_entries(
                    self.pipeline_config.pipeline_map[entry].pipeline,
                    entry,
                    sequence,
                    op_to_sub_flow,
                    visiting,
                    seen,
                )
                # Leaving the sub-flow: it may legitimately appear again on a
                # different branch (its operators would then trip `seen`).
                visiting.discard(entry)
            else:
                raise ConfigError(
                    f'pipeline entry "{entry}" is neither an operator nor a sub-flow'
                )
237
+
238
+
239
def _parse_operator_config(node: dict[str, Any]) -> OperatorConfig:
    """Build an OperatorConfig from one ``pipeline_config.operators`` entry.

    Keys in RESERVED_KEYS populate the typed fields (missing keys fall back
    to the same defaults OperatorConfig declares); every other key is copied
    into ``raw_params``.

    ``skip`` and ``sources`` each accept a single name or a list of names;
    an empty string or ``null`` means "none".

    Raises:
        ConfigError: if ``sources`` is not null, a string, or a list.
    """
    op_cfg = OperatorConfig()
    op_cfg.type_name = node.get("type_name", "")
    op_cfg.recall = node.get("recall", False)
    op_cfg.debug = node.get("debug", False)
    op_cfg.consumes_row_set = node.get("consumes_row_set", False)
    op_cfg.mutates_row_set = node.get("mutates_row_set", False)
    op_cfg.additive_writes_row_set = node.get("additive_writes_row_set", False)
    op_cfg.for_branch_control = node.get("for_branch_control", False)
    op_cfg.data_parallel = node.get("data_parallel", 1)

    sources = node.get("sources")
    if sources is None:
        op_cfg.sources = []
    elif isinstance(sources, str):
        # A bare string names one upstream operator (mirrors `skip`); it must
        # not be iterated character by character.
        op_cfg.sources = [sources] if sources else []
    elif isinstance(sources, list):
        op_cfg.sources = list(sources)
    else:
        # Silently ignoring a malformed `sources` would drop ordering edges.
        raise ConfigError(
            '"sources" must be an operator name or a list of operator names, '
            f"got {type(sources).__name__}"
        )

    skip = node.get("skip", [])
    if isinstance(skip, str):
        op_cfg.skip = [skip] if skip else []
    elif isinstance(skip, list):
        op_cfg.skip = [str(s) for s in skip]
    else:
        op_cfg.skip = []

    # `"$metadata": null` is treated like an absent key.
    meta = node.get("$metadata") or {}
    op_cfg.metadata.common_input = meta.get("common_input", [])
    op_cfg.metadata.common_output = meta.get("common_output", [])
    op_cfg.metadata.item_input = meta.get("item_input", [])
    op_cfg.metadata.item_output = meta.get("item_output", [])

    op_cfg.common_defaults = node.get("common_defaults", {})
    op_cfg.item_defaults = node.get("item_defaults", {})

    op_cfg.raw_params = {
        key: value for key, value in node.items() if key not in RESERVED_KEYS
    }

    return op_cfg
274
+
275
+
276
def _validate(cfg: Config):
    """Check structural invariants of a freshly parsed config.

    Checks run as separate passes over all operators, in this order: every
    operator has a ``type_name``; every ``sources`` entry names a defined
    operator; every ``skip`` field starts with ``_`` and is also declared in
    ``$metadata.common_input``.

    Raises:
        ConfigError: on the first violation found.
    """
    operators = cfg.pipeline_config.operators
    if not operators:
        raise ConfigError("pipeline_config.operators is empty")
    if not cfg.pipeline_group:
        raise ConfigError("pipeline_group is empty")

    # Pass 1: implementation type is mandatory.
    for op_name, op in operators.items():
        if not op.type_name:
            raise ConfigError(f'operator "{op_name}": missing type_name')

    # Pass 2: explicit dependencies must point at defined operators.
    for op_name, op in operators.items():
        for upstream in op.sources:
            if upstream not in operators:
                raise ConfigError(
                    f'operator "{op_name}": sources references undefined operator "{upstream}"'
                )

    # Pass 3: skip fields are engine-internal control fields that must also
    # be declared inputs so the DAG orders the operator after their writer.
    for op_name, op in operators.items():
        declared_inputs = op.metadata.common_input
        for skip_field in op.skip:
            if not skip_field.startswith("_"):
                raise ConfigError(
                    f'operator "{op_name}": skip field "{skip_field}" must start with '
                    "'_' (control fields are engine-internal)"
                )
            if skip_field not in declared_inputs:
                raise ConfigError(
                    f'operator "{op_name}": skip field "{skip_field}" must also appear '
                    "in $metadata.common_input to ensure correct DAG ordering"
                )
pine/dag.py ADDED
@@ -0,0 +1,218 @@
1
+ from __future__ import annotations
2
+
3
+ from collections import deque
4
+
5
+ from pine.config import OperatorConfig
6
+ from pine.errors import ConfigError
7
+
8
# Pseudo field name standing in for the row set itself, so that row-set
# accesses (consumes_row_set / additive_writes_row_set / mutates_row_set)
# go through the same dependency tracking as ordinary item fields.
_ROW_SET_SENTINEL = "_row_set_"
9
+
10
+
11
class _FieldTracker:
    """Access history of one field while DAG edges are being derived.

    All stored values are operator indices into the pipeline sequence.
    """

    def __init__(self):
        # most recent operator that overwrote the field; -1 means none yet
        self.last_mut_writer: int = -1
        # additive writers seen since the last overwrite
        self.additive_writers: list[int] = []
        # readers seen since the last overwrite
        self.active_readers: list[int] = []
16
+
17
+
18
class Node:
    """Single operator node within the DAG.

    ``index`` is the operator's position in the expanded sequence;
    ``preds`` / ``succs`` hold indices of neighbouring nodes and start empty.
    """

    def __init__(self, name: str, index: int, sub_flow: str, config: OperatorConfig):
        self.name = name
        self.index = index
        # sub-flow that listed this operator ("" for the entry group)
        self.sub_flow = sub_flow
        self.config = config
        # adjacency lists, populated by _add_edge
        self.preds: list[int] = []
        self.succs: list[int] = []
28
+
29
+
30
class DAG:
    """Directed acyclic graph encoding operator execution dependencies.

    ``nodes[i]`` is the operator at position ``i`` of the expanded pipeline
    sequence and ``name_to_index`` maps each operator name to that position.
    Edges are stored as index lists on each ``Node``.
    """

    def __init__(self, nodes: list[Node], name_to_index: dict[str, int]):
        self.nodes = nodes
        self.name_to_index = name_to_index

    @classmethod
    def build(
        cls,
        sequence: list[str],
        operators: dict[str, OperatorConfig],
        op_to_sub_flow: dict[str, str],
    ) -> "DAG":
        """Build the dependency DAG for an expanded operator sequence.

        Edges come from common-field accesses, item-field (and row-set)
        accesses, and each operator's explicit ``sources``. The graph is
        checked for cycles and then transitively reduced.

        Args:
            sequence: Operator names in pipeline order.
            operators: Config for every operator name.
            op_to_sub_flow: Owning sub-flow per operator ("" when absent).

        Raises:
            ConfigError: if a sequenced operator has no config, a ``sources``
                entry names an operator not in *sequence*, or the
                dependencies form a cycle.
        """
        nodes: list[Node] = []
        name_to_index: dict[str, int] = {}

        for i, name in enumerate(sequence):
            op_cfg = operators.get(name)
            if op_cfg is None:
                raise ConfigError(f'operator "{name}" not found')
            nodes.append(Node(name, i, op_to_sub_flow.get(name, ""), op_cfg))
            name_to_index[name] = i

        g = cls(nodes, name_to_index)

        _add_edges(g, sequence, operators, is_common=True)
        _add_edges(g, sequence, operators, is_common=False)

        # Explicit ordering constraints. A source may appear later in the
        # sequence, so these are the only edges that can point backwards
        # and therefore the only way to introduce a cycle.
        for i, name in enumerate(sequence):
            for src in operators[name].sources:
                src_idx = name_to_index.get(src)
                if src_idx is None:
                    raise ConfigError(
                        f'operator "{name}" sources references unknown operator "{src}"'
                    )
                _add_edge(g, src_idx, i)

        # Cycle detection must run on the complete edge set, BEFORE reducing:
        # _reduce drops u->v whenever v is reachable from u another way, and
        # on a cyclic graph those simultaneous removals can erase the cycle
        # (a complete 3-node digraph loses every edge), which would leave
        # mutually dependent operators silently unordered.
        _topological_sort(g)
        _reduce(g)

        return g

    def topological_order(self) -> list[int]:
        """Return node indices in a dependency-respecting order.

        Raises:
            ConfigError: if the graph contains a cycle.
        """
        return _topological_sort(self)
77
+
78
+
79
def _add_edge(g: DAG, from_idx: int, to_idx: int):
    """Add the edge from_idx -> to_idx, ignoring self-loops and duplicates."""
    if from_idx != to_idx:
        origin = g.nodes[from_idx]
        if to_idx not in origin.succs:
            origin.succs.append(to_idx)
            g.nodes[to_idx].preds.append(from_idx)
86
+
87
+
88
def _note_read(g: DAG, tracker: _FieldTracker, op_idx: int):
    """Record a read: order it after the last overwrite and additive writers."""
    if tracker.last_mut_writer >= 0:
        _add_edge(g, tracker.last_mut_writer, op_idx)
    for writer in tracker.additive_writers:
        _add_edge(g, writer, op_idx)
    tracker.active_readers.append(op_idx)


def _note_additive_write(g: DAG, tracker: _FieldTracker, op_idx: int):
    """Record an additive write: after the last overwrite and pending readers."""
    if tracker.last_mut_writer >= 0:
        _add_edge(g, tracker.last_mut_writer, op_idx)
    for reader in tracker.active_readers:
        if reader != op_idx:
            _add_edge(g, reader, op_idx)
    tracker.additive_writers.append(op_idx)


def _note_mutating_write(g: DAG, tracker: _FieldTracker, op_idx: int):
    """Record an overwrite: after all accesses since the previous one, then reset."""
    if tracker.last_mut_writer >= 0:
        _add_edge(g, tracker.last_mut_writer, op_idx)
    for writer in tracker.additive_writers:
        _add_edge(g, writer, op_idx)
    for reader in tracker.active_readers:
        if reader != op_idx:
            _add_edge(g, reader, op_idx)
    tracker.last_mut_writer = op_idx
    tracker.additive_writers.clear()
    tracker.active_readers.clear()


def _add_edges(
    g: DAG,
    sequence: list[str],
    operators: dict[str, OperatorConfig],
    is_common: bool,
):
    """Add data-dependency (RAW / WAR / WAW) edges for one field scope.

    With *is_common* the common_input/common_output lists are used; otherwise
    the item lists are used and the row set is modelled as the pseudo field
    ``_ROW_SET_SENTINEL`` (read by consumes_row_set, additively written by
    additive_writes_row_set, overwritten by mutates_row_set). For an
    additive operator every item-scope write is treated as additive.
    """
    trackers: dict[str, _FieldTracker] = {}

    def tracker_for(field_name: str) -> _FieldTracker:
        found = trackers.get(field_name)
        if found is None:
            found = trackers[field_name] = _FieldTracker()
        return found

    for op_idx, op_name in enumerate(sequence):
        op = operators[op_name]
        meta = op.metadata

        if is_common:
            reads = list(meta.common_input)
            writes = list(meta.common_output)
            additive = False
        else:
            reads = list(meta.item_input)
            writes = list(meta.item_output)
            additive = op.additive_writes_row_set
            if additive:
                writes.append(_ROW_SET_SENTINEL)
            if op.consumes_row_set:
                reads.append(_ROW_SET_SENTINEL)

        # All reads of this operator are recorded before any of its writes.
        for field_name in reads:
            _note_read(g, tracker_for(field_name), op_idx)

        record_write = _note_additive_write if additive else _note_mutating_write
        for field_name in writes:
            record_write(g, tracker_for(field_name), op_idx)

        if not is_common and op.mutates_row_set:
            _note_mutating_write(g, tracker_for(_ROW_SET_SENTINEL), op_idx)
154
+
155
+
156
def _topological_sort(g: DAG) -> list[int]:
    """Return node indices in topological order (Kahn's algorithm, FIFO).

    Roots are taken in index order and successors in adjacency-list order.

    Raises:
        ConfigError: if some nodes can never be scheduled (a cycle exists).
    """
    remaining = [len(node.preds) for node in g.nodes]
    ready: deque[int] = deque(idx for idx, deg in enumerate(remaining) if deg == 0)

    order: list[int] = []
    while ready:
        idx = ready.popleft()
        order.append(idx)
        for nxt in g.nodes[idx].succs:
            remaining[nxt] -= 1
            if remaining[nxt] == 0:
                ready.append(nxt)

    if len(order) != len(g.nodes):
        stuck = [g.nodes[idx].name for idx, deg in enumerate(remaining) if deg > 0]
        raise ConfigError(
            f"DAG contains a cycle involving operators: {stuck}"
        )
    return order
179
+
180
+
181
def _reduce(g: DAG):
    """Transitively reduce *g* in place.

    An edge u -> v is dropped when v is still reachable from u without it.
    All decisions are made against the unmodified graph, then the adjacency
    lists are rebuilt from the surviving edges.
    """
    survivors = [
        (u, v)
        for u, node in enumerate(g.nodes)
        for v in node.succs
        if not _reachable_without(g, u, v)
    ]

    for node in g.nodes:
        node.preds.clear()
        node.succs.clear()
    for u, v in survivors:
        g.nodes[u].succs.append(v)
        g.nodes[v].preds.append(u)
196
+
197
+
198
def _reachable_without(g: DAG, src: int, dst: int) -> bool:
    """Return True if *dst* is reachable from *src* without the direct src->dst edge.

    *src* is marked visited up front, so paths that loop back through it are
    not explored.
    """
    seen = {src}
    stack: list[int] = []
    for first in g.nodes[src].succs:
        # skip the direct edge being tested
        if first != dst and first not in seen:
            seen.add(first)
            stack.append(first)

    while stack:
        node = stack.pop()
        if node == dst:
            return True
        for nxt in g.nodes[node].succs:
            if nxt not in seen:
                seen.add(nxt)
                stack.append(nxt)
    return False