pineapple-pine 0.7.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pine/__init__.py +8 -0
- pine/cancellation.py +20 -0
- pine/cli/__init__.py +0 -0
- pine/cli/codegen.py +67 -0
- pine/cli/dag.py +62 -0
- pine/cli/run.py +97 -0
- pine/cli/server.py +346 -0
- pine/config.py +304 -0
- pine/dag.py +218 -0
- pine/engine.py +681 -0
- pine/errors.py +31 -0
- pine/frame.py +237 -0
- pine/go_format.py +181 -0
- pine/operator.py +346 -0
- pine/operators/__init__.py +412 -0
- pine/operators/filter_condition.py +29 -0
- pine/operators/filter_paginate.py +48 -0
- pine/operators/filter_truncate.py +35 -0
- pine/operators/merge_dedup.py +38 -0
- pine/operators/observe_log.py +53 -0
- pine/operators/recall_resource.py +53 -0
- pine/operators/recall_static.py +38 -0
- pine/operators/reorder_shuffle.py +92 -0
- pine/operators/reorder_sort.py +62 -0
- pine/operators/transform_by_lua.py +308 -0
- pine/operators/transform_copy.py +58 -0
- pine/operators/transform_dispatch.py +29 -0
- pine/operators/transform_normalize.py +59 -0
- pine/operators/transform_redis_get.py +138 -0
- pine/operators/transform_redis_set.py +147 -0
- pine/operators/transform_remote_pineapple.py +224 -0
- pine/operators/transform_resource_lookup.py +87 -0
- pine/operators/transform_size.py +25 -0
- pine/parallel.py +88 -0
- pine/py.typed +0 -0
- pine/registry.py +88 -0
- pine/result.py +32 -0
- pine/stats.py +98 -0
- pine/visualize.py +184 -0
- pineapple_pine-0.7.0.dist-info/METADATA +78 -0
- pineapple_pine-0.7.0.dist-info/RECORD +43 -0
- pineapple_pine-0.7.0.dist-info/WHEEL +5 -0
- pineapple_pine-0.7.0.dist-info/top_level.txt +1 -0
pine/config.py
ADDED
|
@@ -0,0 +1,304 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import json
|
|
4
|
+
from dataclasses import dataclass, field
|
|
5
|
+
from typing import Any
|
|
6
|
+
|
|
7
|
+
from pine.errors import ConfigError
|
|
8
|
+
|
|
9
|
+
# Operator-entry keys interpreted by the config parser itself.
# _parse_operator_config copies every key NOT listed here into
# OperatorConfig.raw_params.
RESERVED_KEYS = frozenset(
    (
        "type_name",
        "$metadata",
        "$code_info",
        "skip",
        "recall",
        "sources",
        "debug",
        "consumes_row_set",
        "mutates_row_set",
        "additive_writes_row_set",
        "common_defaults",
        "item_defaults",
        "for_branch_control",
        "data_parallel",
    )
)
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
@dataclass
class Metadata:
    """Field read/write declarations of one operator (its ``$metadata`` key).

    ``*_input`` lists name fields the operator reads and ``*_output`` lists
    name fields it writes. The DAG builder tracks ``common_*`` and ``item_*``
    fields in separate passes. Field order is significant: it defines the
    positional order of the generated ``__init__``.
    """

    # Each default_factory call yields a fresh list, so instances never
    # share mutable state.
    common_input: list[str] = field(default_factory=list)
    common_output: list[str] = field(default_factory=list)
    item_input: list[str] = field(default_factory=list)
    item_output: list[str] = field(default_factory=list)
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
@dataclass
class OperatorConfig:
    """Per-operator configuration within a pipeline.

    Populated from one ``pipeline_config.operators`` entry by
    ``_parse_operator_config``. Field order defines the positional order of
    the generated ``__init__`` and must not change.
    """

    # Implementation type; _validate rejects an empty value.
    type_name: str = ""
    # Parsed ``$metadata`` read/write declarations.
    metadata: Metadata = field(default_factory=Metadata)
    # ``skip`` field names; _validate requires each to start with "_" and to
    # appear in metadata.common_input.
    skip: list[str] = field(default_factory=list)
    # Copied from the ``recall`` key.
    recall: bool = False
    # Names of operators that must run before this one (explicit DAG edges).
    sources: list[str] = field(default_factory=list)
    # Copied from the ``debug`` key.
    debug: bool = False
    # Row-set flags consumed by the DAG builder: read / mutating write /
    # additive write of the row-set pseudo field.
    consumes_row_set: bool = False
    mutates_row_set: bool = False
    additive_writes_row_set: bool = False
    # Copied from the ``for_branch_control`` key.
    for_branch_control: bool = False
    # Copied from the ``data_parallel`` key (default 1).
    data_parallel: int = 1
    # Per-field default values for common / item inputs.
    common_defaults: dict[str, Any] = field(default_factory=dict)
    item_defaults: dict[str, Any] = field(default_factory=dict)
    # Every key of the operator entry that is not in RESERVED_KEYS.
    raw_params: dict[str, Any] = field(default_factory=dict)
    # NOTE(review): operator_type and input_spec are not set by config
    # parsing; presumably filled in later by the engine — verify.
    operator_type: str = ""
    input_spec: InputFieldSpec | None = None
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
class DefaultedField:
    """An input field name paired with its configured default value.

    Instances are produced by ``InputFieldSpec.compute`` from the
    ``common_defaults`` / ``item_defaults`` mappings.
    """

    __slots__ = ("name", "default")

    def __init__(self, name: str, default: Any):
        # Assigned left to right: name first, then default.
        self.name, self.default = name, default
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
class InputFieldSpec:
    """An operator's input fields, split into strict and defaulted groups.

    ``*_common`` lists derive from ``metadata.common_input`` and ``*_item``
    lists from ``metadata.item_input``. A field is "defaulted" when the
    matching defaults mapping has an entry for it and "strict" otherwise.
    """

    __slots__ = ("strict_common", "defaulted_common", "strict_item", "defaulted_item")

    def __init__(
        self,
        strict_common: list[str],
        defaulted_common: list[DefaultedField],
        strict_item: list[str],
        defaulted_item: list[DefaultedField],
    ):
        self.strict_common = strict_common
        self.defaulted_common = defaulted_common
        self.strict_item = strict_item
        self.defaulted_item = defaulted_item

    @staticmethod
    def compute(metadata: "Metadata", common_defaults: dict[str, Any],
                item_defaults: dict[str, Any], skip: list[str]) -> "InputFieldSpec":
        """Classify ``metadata``'s inputs against the configured defaults.

        Args:
            metadata: object exposing ``common_input`` and ``item_input``.
            common_defaults: default values keyed by common field name.
            item_defaults: default values keyed by item field name.
            skip: field names removed from the *common* inputs only; item
                inputs are never filtered by ``skip``.

        Returns:
            A new ``InputFieldSpec``; every list preserves input order.
        """
        excluded = set(skip)

        def partition(names, defaults, ignore):
            # Split names into (strict, defaulted), dropping those in ignore.
            strict: list[str] = []
            defaulted: list[DefaultedField] = []
            for name in names:
                if name in ignore:
                    continue
                if name in defaults:
                    defaulted.append(DefaultedField(name, defaults[name]))
                else:
                    strict.append(name)
            return strict, defaulted

        strict_common, defaulted_common = partition(
            metadata.common_input, common_defaults, excluded
        )
        strict_item, defaulted_item = partition(metadata.item_input, item_defaults, ())
        return InputFieldSpec(strict_common, defaulted_common, strict_item, defaulted_item)
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
class SubFlowRef:
    """A named pipeline: an ordered list of operator and/or sub-flow names."""

    def __init__(self, pipeline: list[str] | None = None):
        # A falsy argument (None or an empty list) is replaced by a new list.
        self.pipeline: list[str] = pipeline if pipeline else []
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
class FlowContract:
    """Pipeline-level input/output field declarations.

    Filled by ``Config.load`` from the top-level ``flow_contract`` key; each
    list starts empty.
    """

    def __init__(self):
        # Fields the pipeline expects as input.
        self.common_input: list[str] = []
        self.item_input: list[str] = []
        # Fields the pipeline declares as output.
        self.common_output: list[str] = []
        self.item_output: list[str] = []
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
class PipelineConfig:
    """Container for operator configs and the sub-flow pipeline map.

    Mirrors the ``pipeline_config`` section: ``operators`` maps operator name
    to its parsed config, and ``pipeline_map`` maps sub-flow name to its entry
    list.
    """

    def __init__(self):
        self.operators: dict[str, OperatorConfig] = {}
        self.pipeline_map: dict[str, SubFlowRef] = {}
|
|
118
|
+
|
|
119
|
+
|
|
120
|
+
class ExpandResult:
    """Result of flattening the entry pipeline into operators.

    Attributes:
        sequence: operator names in expansion order.
        op_to_sub_flow: operator name -> name of the sub-flow whose pipeline
            listed it ("" for operators listed directly in the entry
            pipeline).
    """

    def __init__(self, sequence: list[str], op_to_sub_flow: dict[str, str]):
        self.sequence, self.op_to_sub_flow = sequence, op_to_sub_flow
|
|
124
|
+
|
|
125
|
+
|
|
126
|
+
class Config:
    """Pipeline configuration parsed from JSON.

    Create instances with :meth:`load`; the constructor only installs empty
    defaults.
    """

    def __init__(self):
        # Top-level scalar settings (JSON keys are listed in ``load``).
        self.pineapple_version: str = ""
        self.log_prefix: str = ""
        self.debug: bool = False
        self.storage_mode: str = "row"
        # Operator definitions plus the sub-flow ``pipeline_map``.
        self.pipeline_config: PipelineConfig = PipelineConfig()
        # Named top-level pipelines; "main" (or the only entry) is expanded.
        self.pipeline_group: dict[str, SubFlowRef] = {}
        self.flow_contract: FlowContract = FlowContract()

    @staticmethod
    def _json_object(value: Any, where: str) -> dict[str, Any]:
        """Return ``value`` as a JSON object, treating ``null`` as ``{}``.

        Args:
            value: a decoded JSON value.
            where: human-readable location used in the error message.

        Raises:
            ConfigError: if ``value`` is neither ``None`` nor a ``dict``.
        """
        if value is None:
            return {}
        if not isinstance(value, dict):
            raise ConfigError(
                f"{where} must be a JSON object, got {type(value).__name__}"
            )
        return value

    @classmethod
    def load(cls, json_data: bytes) -> "Config":
        """Parse and validate a pipeline config document.

        Reads ``_PINEAPPLE_VERSION``, ``log_prefix``, ``debug``,
        ``storage_mode``, ``flow_contract``, ``pipeline_group`` and
        ``pipeline_config`` (``pipeline_map`` and ``operators``). Missing or
        ``null`` object-valued sections are treated as empty.

        Args:
            json_data: the JSON document (anything ``json.loads`` accepts).

        Returns:
            A config that has passed ``_validate``.

        Raises:
            ConfigError: if the payload is not valid JSON, if the document or
                one of its object-valued sections is not a JSON object, or if
                validation fails.
        """
        try:
            root = json.loads(json_data)
        except (TypeError, ValueError, RecursionError) as e:
            # JSONDecodeError and UnicodeDecodeError subclass ValueError;
            # TypeError covers inputs that are not str/bytes/bytearray.
            raise ConfigError(f"failed to parse config JSON: {e}") from e

        # Without this, a top-level list/number would crash on ``.get``.
        root = cls._json_object(root, "config")

        cfg = cls()
        cfg.pineapple_version = root.get("_PINEAPPLE_VERSION", "")
        cfg.log_prefix = root.get("log_prefix", "")
        cfg.debug = root.get("debug", False)
        cfg.storage_mode = root.get("storage_mode", "row")

        fc = cls._json_object(root.get("flow_contract"), "flow_contract")
        cfg.flow_contract.common_input = fc.get("common_input", [])
        cfg.flow_contract.item_input = fc.get("item_input", [])
        cfg.flow_contract.common_output = fc.get("common_output", [])
        cfg.flow_contract.item_output = fc.get("item_output", [])

        pg = cls._json_object(root.get("pipeline_group"), "pipeline_group")
        for name, val in pg.items():
            entry = cls._json_object(val, f'pipeline_group "{name}"')
            cfg.pipeline_group[name] = SubFlowRef(entry.get("pipeline", []))

        pc = cls._json_object(root.get("pipeline_config"), "pipeline_config")
        pm = cls._json_object(pc.get("pipeline_map"), "pipeline_config.pipeline_map")
        for name, val in pm.items():
            entry = cls._json_object(val, f'pipeline_map "{name}"')
            cfg.pipeline_config.pipeline_map[name] = SubFlowRef(entry.get("pipeline", []))

        ops = cls._json_object(pc.get("operators"), "pipeline_config.operators")
        for name, op_node in ops.items():
            node = cls._json_object(op_node, f'operator "{name}"')
            cfg.pipeline_config.operators[name] = _parse_operator_config(node)

        _validate(cfg)
        return cfg

    def expand_operator_sequence_with_sub_flows(self) -> ExpandResult:
        """Flatten the entry pipeline into a linear operator sequence.

        The entry pipeline is ``pipeline_group["main"]`` or, when there is no
        "main", the single ``pipeline_group`` entry. Sub-flow names are
        expanded recursively (depth-first, in listed order) via
        ``pipeline_config.pipeline_map``.

        Returns:
            ``ExpandResult`` whose ``op_to_sub_flow`` maps each operator to
            the name of the sub-flow that directly listed it ("" for the
            entry pipeline).

        Raises:
            ConfigError: if the entry pipeline is ambiguous, a name is both an
                operator and a sub-flow, an operator appears more than once
                in the expanded tree (this includes a sub-flow with operators
                being referenced twice), sub-flows reference each other
                cyclically, or an entry is unknown.
        """
        if "main" in self.pipeline_group:
            group = self.pipeline_group["main"]
        elif len(self.pipeline_group) == 1:
            group = next(iter(self.pipeline_group.values()))
        else:
            raise ConfigError(
                'pipeline_group must contain a "main" entry or exactly one entry'
            )

        for name in self.pipeline_config.operators:
            if name in self.pipeline_config.pipeline_map:
                raise ConfigError(
                    f'name "{name}" exists in both operators and pipeline_map'
                )

        sequence: list[str] = []
        op_to_sub_flow: dict[str, str] = {}
        visiting: set[str] = set()  # sub-flows on the current recursion path
        seen: set[str] = set()  # operators already emitted

        self._expand_entries(
            group.pipeline, "", sequence, op_to_sub_flow, visiting, seen
        )
        return ExpandResult(sequence, op_to_sub_flow)

    def _expand_entries(
        self,
        entries: list[str],
        parent_path: str,
        sequence: list[str],
        op_to_sub_flow: dict[str, str],
        visiting: set[str],
        seen: set[str],
    ):
        """Append the operators reachable from ``entries`` to ``sequence``.

        ``parent_path`` is the name of the sub-flow owning ``entries`` ("" at
        the top level); ``sequence``, ``op_to_sub_flow``, ``visiting`` and
        ``seen`` are updated in place.
        """
        for entry in entries:
            if entry in self.pipeline_config.operators:
                if entry in seen:
                    raise ConfigError(
                        f'operator "{entry}" referenced more than once in pipeline tree'
                    )
                seen.add(entry)
                sequence.append(entry)
                op_to_sub_flow[entry] = parent_path
            elif entry in self.pipeline_config.pipeline_map:
                # A sub-flow already on the recursion path means a cycle.
                if entry in visiting:
                    raise ConfigError(
                        f'cycle detected in sub-flow expansion: "{entry}"'
                    )
                visiting.add(entry)
                self._expand_entries(
                    self.pipeline_config.pipeline_map[entry].pipeline,
                    entry,
                    sequence,
                    op_to_sub_flow,
                    visiting,
                    seen,
                )
                visiting.discard(entry)
            else:
                raise ConfigError(
                    f'pipeline entry "{entry}" is neither an operator nor a sub-flow'
                )
|
|
237
|
+
|
|
238
|
+
|
|
239
|
+
def _parse_operator_config(node: dict[str, Any]) -> OperatorConfig:
    """Build an ``OperatorConfig`` from one ``pipeline_config.operators`` entry.

    Keys in ``RESERVED_KEYS`` populate the typed fields; every other key is
    copied unchanged into ``raw_params``.

    An explicit JSON ``null`` for ``sources``, ``$metadata``, any of its
    input/output lists, ``common_defaults`` or ``item_defaults`` is treated
    like an absent key. This stops ``None`` from leaking into fields that
    later code iterates or tests membership on (``_validate``,
    ``InputFieldSpec.compute``, the DAG builder).

    ``skip`` accepts a single string (``""`` meaning none) or a list whose
    items are stringified; any other value is ignored.
    """
    op_cfg = OperatorConfig()
    op_cfg.type_name = node.get("type_name", "")
    op_cfg.recall = node.get("recall", False)
    op_cfg.debug = node.get("debug", False)
    op_cfg.consumes_row_set = node.get("consumes_row_set", False)
    op_cfg.mutates_row_set = node.get("mutates_row_set", False)
    op_cfg.additive_writes_row_set = node.get("additive_writes_row_set", False)
    op_cfg.for_branch_control = node.get("for_branch_control", False)
    op_cfg.data_parallel = node.get("data_parallel", 1)
    op_cfg.sources = node.get("sources") or []

    skip = node.get("skip", [])
    if isinstance(skip, str):
        op_cfg.skip = [skip] if skip else []
    elif isinstance(skip, list):
        op_cfg.skip = [str(s) for s in skip]
    else:
        op_cfg.skip = []

    meta = node.get("$metadata") or {}
    op_cfg.metadata.common_input = meta.get("common_input") or []
    op_cfg.metadata.common_output = meta.get("common_output") or []
    op_cfg.metadata.item_input = meta.get("item_input") or []
    op_cfg.metadata.item_output = meta.get("item_output") or []

    op_cfg.common_defaults = node.get("common_defaults") or {}
    op_cfg.item_defaults = node.get("item_defaults") or {}

    # Everything the parser does not interpret is handed to the operator.
    op_cfg.raw_params = {
        key: value for key, value in node.items() if key not in RESERVED_KEYS
    }
    return op_cfg
|
|
274
|
+
|
|
275
|
+
|
|
276
|
+
def _validate(cfg: Config):
|
|
277
|
+
if not cfg.pipeline_config.operators:
|
|
278
|
+
raise ConfigError("pipeline_config.operators is empty")
|
|
279
|
+
if not cfg.pipeline_group:
|
|
280
|
+
raise ConfigError("pipeline_group is empty")
|
|
281
|
+
|
|
282
|
+
for name, op_cfg in cfg.pipeline_config.operators.items():
|
|
283
|
+
if not op_cfg.type_name:
|
|
284
|
+
raise ConfigError(f'operator "{name}": missing type_name')
|
|
285
|
+
|
|
286
|
+
for name, op_cfg in cfg.pipeline_config.operators.items():
|
|
287
|
+
for src in op_cfg.sources:
|
|
288
|
+
if src not in cfg.pipeline_config.operators:
|
|
289
|
+
raise ConfigError(
|
|
290
|
+
f'operator "{name}": sources references undefined operator "{src}"'
|
|
291
|
+
)
|
|
292
|
+
|
|
293
|
+
for name, op_cfg in cfg.pipeline_config.operators.items():
|
|
294
|
+
for skip_field in op_cfg.skip:
|
|
295
|
+
if not skip_field.startswith("_"):
|
|
296
|
+
raise ConfigError(
|
|
297
|
+
f'operator "{name}": skip field "{skip_field}" must start with '
|
|
298
|
+
"'_' (control fields are engine-internal)"
|
|
299
|
+
)
|
|
300
|
+
if skip_field not in op_cfg.metadata.common_input:
|
|
301
|
+
raise ConfigError(
|
|
302
|
+
f'operator "{name}": skip field "{skip_field}" must also appear '
|
|
303
|
+
"in $metadata.common_input to ensure correct DAG ordering"
|
|
304
|
+
)
|
pine/dag.py
ADDED
|
@@ -0,0 +1,218 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from collections import deque
|
|
4
|
+
|
|
5
|
+
from pine.config import OperatorConfig
|
|
6
|
+
from pine.errors import ConfigError
|
|
7
|
+
|
|
8
|
+
_ROW_SET_SENTINEL = "_row_set_"
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class _FieldTracker:
|
|
12
|
+
def __init__(self):
|
|
13
|
+
self.last_mut_writer: int = -1
|
|
14
|
+
self.additive_writers: list[int] = []
|
|
15
|
+
self.active_readers: list[int] = []
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class Node:
    """One operator in the DAG; adjacency is stored as node indices."""

    def __init__(self, name: str, index: int, sub_flow: str, config: OperatorConfig):
        # Tuple assignment binds left to right (name, then index).
        self.name, self.index = name, index
        self.sub_flow = sub_flow
        self.config = config
        self.preds: list[int] = []  # nodes that must be ordered before this one
        self.succs: list[int] = []  # nodes that must be ordered after this one
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class DAG:
    """Operator dependency graph.

    ``nodes[i]`` is the operator at position ``i`` of the sequence passed to
    :meth:`build`, and ``name_to_index`` maps operator names to those
    positions.
    """

    def __init__(self, nodes: list[Node], name_to_index: dict[str, int]):
        self.nodes = nodes
        self.name_to_index = name_to_index

    @classmethod
    def build(
        cls,
        sequence: list[str],
        operators: dict[str, OperatorConfig],
        op_to_sub_flow: dict[str, str],
    ) -> "DAG":
        """Build the dependency graph for ``sequence``.

        Edges are added in this order: common-field conflicts, item-field
        (and row-set) conflicts, then explicit ``sources`` declarations. The
        graph is then reduced with ``_reduce`` and sorted with
        ``_topological_sort``, which raises on a cycle.

        Raises:
            ConfigError: if a sequence entry or a ``sources`` reference is not
                a known operator, or if the dependencies form a cycle.
        """
        nodes: list[Node] = []
        name_to_index: dict[str, int] = {}
        for position, op_name in enumerate(sequence):
            op_cfg = operators.get(op_name)
            if op_cfg is None:
                raise ConfigError(f'operator "{op_name}" not found')
            nodes.append(Node(op_name, position, op_to_sub_flow.get(op_name, ""), op_cfg))
            name_to_index[op_name] = position

        graph = cls(nodes, name_to_index)

        # Field-derived edges: the common pass runs before the item pass.
        for is_common in (True, False):
            _add_edges(graph, sequence, operators, is_common=is_common)

        # Explicit ordering requested through ``sources``.
        for position, op_name in enumerate(sequence):
            for src_name in operators[op_name].sources:
                src_position = name_to_index.get(src_name)
                if src_position is None:
                    raise ConfigError(
                        f'operator "{op_name}" sources references unknown operator "{src_name}"'
                    )
                _add_edge(graph, src_position, position)

        _reduce(graph)
        _topological_sort(graph)
        return graph

    def topological_order(self) -> list[int]:
        """Return node indices in dependency order (recomputed on each call)."""
        return _topological_sort(self)
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
def _add_edge(g: DAG, from_idx: int, to_idx: int):
|
|
80
|
+
if from_idx == to_idx:
|
|
81
|
+
return
|
|
82
|
+
if to_idx in g.nodes[from_idx].succs:
|
|
83
|
+
return
|
|
84
|
+
g.nodes[from_idx].succs.append(to_idx)
|
|
85
|
+
g.nodes[to_idx].preds.append(from_idx)
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
def _add_edges(
    g: DAG,
    sequence: list[str],
    operators: dict[str, OperatorConfig],
    is_common: bool,
):
    """Add dependency edges implied by field accesses of one kind.

    Operators are scanned in ``sequence`` order, looking at common fields when
    ``is_common`` is true and item fields otherwise. For each field a
    ``_FieldTracker`` remembers the last mutating writer plus the additive
    writers and readers seen since. The edges added are:

    * read: from the last mutating writer and every additive writer since;
    * additive write: from the last mutating writer and every reader since
      (additive writers are not ordered among themselves);
    * mutating write: from the last mutating writer, every additive writer
      and every reader since; the tracker then restarts from this operator.

    For item fields the row set is modelled as ``_ROW_SET_SENTINEL``:
    ``consumes_row_set`` reads it, and ``additive_writes_row_set`` writes it
    additively and also makes all of the operator's item writes additive.
    ``mutates_row_set`` writes it mutatingly after the operator's other
    writes. Edges are added in a fixed order because ``succs`` order decides
    tie-breaking in the topological sort.
    """
    trackers: dict[str, _FieldTracker] = {}

    def tracker(field_name: str) -> _FieldTracker:
        return trackers.setdefault(field_name, _FieldTracker())

    def link_mutating_write(ft: _FieldTracker, idx: int):
        # Order after every earlier access, then restart the history.
        if ft.last_mut_writer >= 0:
            _add_edge(g, ft.last_mut_writer, idx)
        for writer in ft.additive_writers:
            _add_edge(g, writer, idx)
        for reader in ft.active_readers:
            if reader != idx:
                _add_edge(g, reader, idx)
        ft.last_mut_writer = idx
        ft.additive_writers.clear()
        ft.active_readers.clear()

    for idx, op_name in enumerate(sequence):
        cfg = operators[op_name]
        meta = cfg.metadata

        if is_common:
            reads = list(meta.common_input)
            writes = list(meta.common_output)
            additive = False
        else:
            reads = list(meta.item_input)
            writes = list(meta.item_output)
            additive = cfg.additive_writes_row_set
            if additive:
                writes.append(_ROW_SET_SENTINEL)
            if cfg.consumes_row_set:
                reads.append(_ROW_SET_SENTINEL)

        # Read-after-write dependencies.
        for field_name in reads:
            ft = tracker(field_name)
            if ft.last_mut_writer >= 0:
                _add_edge(g, ft.last_mut_writer, idx)
            for writer in ft.additive_writers:
                _add_edge(g, writer, idx)
            ft.active_readers.append(idx)

        # Write-after-read / write-after-write dependencies.
        for field_name in writes:
            ft = tracker(field_name)
            if not additive:
                link_mutating_write(ft, idx)
                continue
            if ft.last_mut_writer >= 0:
                _add_edge(g, ft.last_mut_writer, idx)
            for reader in ft.active_readers:
                if reader != idx:
                    _add_edge(g, reader, idx)
            ft.additive_writers.append(idx)

        # mutates_row_set: a mutating write of the row-set pseudo field.
        if not is_common and cfg.mutates_row_set:
            link_mutating_write(tracker(_ROW_SET_SENTINEL), idx)
|
|
154
|
+
|
|
155
|
+
|
|
156
|
+
def _topological_sort(g: DAG) -> list[int]:
|
|
157
|
+
n = len(g.nodes)
|
|
158
|
+
in_degree = [len(node.preds) for node in g.nodes]
|
|
159
|
+
queue: deque[int] = deque()
|
|
160
|
+
for i in range(n):
|
|
161
|
+
if in_degree[i] == 0:
|
|
162
|
+
queue.append(i)
|
|
163
|
+
|
|
164
|
+
order: list[int] = []
|
|
165
|
+
while queue:
|
|
166
|
+
curr = queue.popleft()
|
|
167
|
+
order.append(curr)
|
|
168
|
+
for succ in g.nodes[curr].succs:
|
|
169
|
+
in_degree[succ] -= 1
|
|
170
|
+
if in_degree[succ] == 0:
|
|
171
|
+
queue.append(succ)
|
|
172
|
+
|
|
173
|
+
if len(order) != n:
|
|
174
|
+
cycle_nodes = [g.nodes[i].name for i in range(n) if in_degree[i] > 0]
|
|
175
|
+
raise ConfigError(
|
|
176
|
+
f"DAG contains a cycle involving operators: {cycle_nodes}"
|
|
177
|
+
)
|
|
178
|
+
return order
|
|
179
|
+
|
|
180
|
+
|
|
181
|
+
def _reduce(g: DAG):
|
|
182
|
+
n = len(g.nodes)
|
|
183
|
+
kept: list[tuple[int, int]] = []
|
|
184
|
+
|
|
185
|
+
for u in range(n):
|
|
186
|
+
for v in g.nodes[u].succs:
|
|
187
|
+
if not _reachable_without(g, u, v):
|
|
188
|
+
kept.append((u, v))
|
|
189
|
+
|
|
190
|
+
for node in g.nodes:
|
|
191
|
+
node.preds.clear()
|
|
192
|
+
node.succs.clear()
|
|
193
|
+
for u, v in kept:
|
|
194
|
+
g.nodes[u].succs.append(v)
|
|
195
|
+
g.nodes[v].preds.append(u)
|
|
196
|
+
|
|
197
|
+
|
|
198
|
+
def _reachable_without(g: DAG, src: int, dst: int) -> bool:
|
|
199
|
+
visited = set()
|
|
200
|
+
visited.add(src)
|
|
201
|
+
queue: deque[int] = deque()
|
|
202
|
+
|
|
203
|
+
for next_node in g.nodes[src].succs:
|
|
204
|
+
if next_node == dst:
|
|
205
|
+
continue
|
|
206
|
+
if next_node not in visited:
|
|
207
|
+
visited.add(next_node)
|
|
208
|
+
queue.append(next_node)
|
|
209
|
+
|
|
210
|
+
while queue:
|
|
211
|
+
cur = queue.popleft()
|
|
212
|
+
if cur == dst:
|
|
213
|
+
return True
|
|
214
|
+
for next_node in g.nodes[cur].succs:
|
|
215
|
+
if next_node not in visited:
|
|
216
|
+
visited.add(next_node)
|
|
217
|
+
queue.append(next_node)
|
|
218
|
+
return False
|