libcirkit 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cirkit/__init__.py +0 -0
- cirkit/backend/__init__.py +0 -0
- cirkit/backend/base.py +199 -0
- cirkit/backend/compiler.py +213 -0
- cirkit/backend/registry.py +53 -0
- cirkit/backend/torch/__init__.py +0 -0
- cirkit/backend/torch/circuits.py +217 -0
- cirkit/backend/torch/compiler.py +592 -0
- cirkit/backend/torch/graph/__init__.py +0 -0
- cirkit/backend/torch/graph/folding.py +230 -0
- cirkit/backend/torch/graph/modules.py +276 -0
- cirkit/backend/torch/graph/optimize.py +258 -0
- cirkit/backend/torch/initializers.py +49 -0
- cirkit/backend/torch/layers/__init__.py +16 -0
- cirkit/backend/torch/layers/base.py +119 -0
- cirkit/backend/torch/layers/inner.py +335 -0
- cirkit/backend/torch/layers/input.py +746 -0
- cirkit/backend/torch/layers/optimized.py +241 -0
- cirkit/backend/torch/optimization/__init__.py +0 -0
- cirkit/backend/torch/optimization/layers.py +166 -0
- cirkit/backend/torch/optimization/parameters.py +67 -0
- cirkit/backend/torch/optimization/registry.py +81 -0
- cirkit/backend/torch/parameters/__init__.py +0 -0
- cirkit/backend/torch/parameters/nodes.py +828 -0
- cirkit/backend/torch/parameters/parameter.py +117 -0
- cirkit/backend/torch/parameters/pic.py +418 -0
- cirkit/backend/torch/queries.py +178 -0
- cirkit/backend/torch/rules/__init__.py +3 -0
- cirkit/backend/torch/rules/initializers.py +53 -0
- cirkit/backend/torch/rules/layers.py +184 -0
- cirkit/backend/torch/rules/parameters.py +280 -0
- cirkit/backend/torch/semiring.py +492 -0
- cirkit/backend/torch/utils.py +102 -0
- cirkit/pipeline.py +355 -0
- cirkit/symbolic/__init__.py +0 -0
- cirkit/symbolic/circuit.py +938 -0
- cirkit/symbolic/dtypes.py +45 -0
- cirkit/symbolic/functional.py +674 -0
- cirkit/symbolic/initializers.py +121 -0
- cirkit/symbolic/layers.py +788 -0
- cirkit/symbolic/operators.py +384 -0
- cirkit/symbolic/parameters.py +921 -0
- cirkit/symbolic/registry.py +119 -0
- cirkit/templates/__init__.py +0 -0
- cirkit/templates/circuit_templates/__init__.py +2 -0
- cirkit/templates/circuit_templates/data.py +107 -0
- cirkit/templates/circuit_templates/utils.py +287 -0
- cirkit/templates/region_graph/__init__.py +11 -0
- cirkit/templates/region_graph/algorithms/__init__.py +9 -0
- cirkit/templates/region_graph/algorithms/chow_liu.py +141 -0
- cirkit/templates/region_graph/algorithms/factorized.py +43 -0
- cirkit/templates/region_graph/algorithms/linear.py +77 -0
- cirkit/templates/region_graph/algorithms/poon_domingos.py +203 -0
- cirkit/templates/region_graph/algorithms/quad.py +179 -0
- cirkit/templates/region_graph/algorithms/random.py +110 -0
- cirkit/templates/region_graph/algorithms/utils.py +124 -0
- cirkit/templates/region_graph/graph.py +335 -0
- cirkit/utils/__init__.py +0 -0
- cirkit/utils/algorithms.py +218 -0
- cirkit/utils/scope.py +13 -0
- libcirkit-0.1.0.dist-info/LICENSE +674 -0
- libcirkit-0.1.0.dist-info/METADATA +200 -0
- libcirkit-0.1.0.dist-info/RECORD +65 -0
- libcirkit-0.1.0.dist-info/WHEEL +5 -0
- libcirkit-0.1.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,592 @@
|
|
|
1
|
+
import functools
|
|
2
|
+
from collections import defaultdict
|
|
3
|
+
from collections.abc import Callable, Sequence
|
|
4
|
+
from itertools import chain
|
|
5
|
+
from typing import cast
|
|
6
|
+
|
|
7
|
+
import torch
|
|
8
|
+
from torch import Tensor
|
|
9
|
+
|
|
10
|
+
from cirkit.backend.compiler import (
|
|
11
|
+
AbstractCompiler,
|
|
12
|
+
CompilerInitializerRegistry,
|
|
13
|
+
CompilerLayerRegistry,
|
|
14
|
+
CompilerParameterRegistry,
|
|
15
|
+
)
|
|
16
|
+
from cirkit.backend.registry import CompilerRegistry
|
|
17
|
+
from cirkit.backend.torch.circuits import AbstractTorchCircuit, TorchCircuit, TorchConstantCircuit
|
|
18
|
+
from cirkit.backend.torch.graph.folding import build_folded_graph
|
|
19
|
+
from cirkit.backend.torch.graph.optimize import (
|
|
20
|
+
GraphOptPattern,
|
|
21
|
+
match_optimization_patterns,
|
|
22
|
+
optimize_graph,
|
|
23
|
+
)
|
|
24
|
+
from cirkit.backend.torch.initializers import stacked_initializer_
|
|
25
|
+
from cirkit.backend.torch.layers import TorchInputLayer, TorchLayer
|
|
26
|
+
from cirkit.backend.torch.layers.input import TorchConstantLayer
|
|
27
|
+
from cirkit.backend.torch.optimization.layers import (
|
|
28
|
+
DEFAULT_LAYER_FUSE_OPT_RULES,
|
|
29
|
+
DEFAULT_LAYER_SHATTER_OPT_RULES,
|
|
30
|
+
)
|
|
31
|
+
from cirkit.backend.torch.optimization.parameters import DEFAULT_PARAMETER_OPT_RULES
|
|
32
|
+
from cirkit.backend.torch.optimization.registry import (
|
|
33
|
+
LayerOptApplyFunc,
|
|
34
|
+
LayerOptMatch,
|
|
35
|
+
LayerOptPattern,
|
|
36
|
+
LayerOptRegistry,
|
|
37
|
+
ParameterOptApplyFunc,
|
|
38
|
+
ParameterOptMatch,
|
|
39
|
+
ParameterOptPattern,
|
|
40
|
+
ParameterOptRegistry,
|
|
41
|
+
)
|
|
42
|
+
from cirkit.backend.torch.parameters.nodes import (
|
|
43
|
+
TorchParameterNode,
|
|
44
|
+
TorchParameterOp,
|
|
45
|
+
TorchPointerParameter,
|
|
46
|
+
TorchTensorParameter,
|
|
47
|
+
)
|
|
48
|
+
from cirkit.backend.torch.parameters.parameter import TorchParameter
|
|
49
|
+
from cirkit.backend.torch.rules import (
|
|
50
|
+
DEFAULT_INITIALIZER_COMPILATION_RULES,
|
|
51
|
+
DEFAULT_LAYER_COMPILATION_RULES,
|
|
52
|
+
DEFAULT_PARAMETER_COMPILATION_RULES,
|
|
53
|
+
)
|
|
54
|
+
from cirkit.backend.torch.semiring import Semiring, SemiringImpl
|
|
55
|
+
from cirkit.symbolic.circuit import Circuit, pipeline_topological_ordering
|
|
56
|
+
from cirkit.symbolic.initializers import Initializer
|
|
57
|
+
from cirkit.symbolic.layers import Layer
|
|
58
|
+
from cirkit.symbolic.parameters import Parameter, ParameterNode, TensorParameter
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
class TorchCompilerState:
    """Mutable bookkeeping used by the Torch compiler while compiling circuits.

    It links symbolic tensor parameters to their compiled counterparts, and it keeps the
    reverse link for unfolded compiled tensors. The reverse link is needed to re-map symbolic
    tensor parameters to slices of the folded tensors after folding.
    """

    def __init__(self):
        # A map from symbolic parameter tensors to a tuple containing the compiled parameter
        # tensor and the slice index, which is 0 if the compiled parameter tensor is unfolded.
        # If the compiled parameter tensor is folded, then the slice index can be non-zero.
        self._compiled_parameters: dict[TensorParameter, tuple[TorchTensorParameter, int]] = {}

        # A reverse map from compiled and unfolded parameter tensors to the corresponding
        # symbolic parameter tensors. It is used to update the map above after the tensor
        # parameters within a circuit are folded. Since it is only useful for folding,
        # it is cleared after each circuit compilation (see finish_compilation).
        self._symbolic_parameters: dict[TorchTensorParameter, TensorParameter] = {}

    def finish_compilation(self) -> None:
        """Signal the end of a circuit compilation.

        Clears the reverse map from (unfolded) compiled parameter tensors to symbolic ones.
        The forward map from symbolic to compiled parameter tensors is kept.
        """
        self._symbolic_parameters.clear()

    def has_compiled_parameter(self, p: TensorParameter) -> bool:
        """Return whether the symbolic tensor parameter ``p`` has already been compiled."""
        return p in self._compiled_parameters

    def retrieve_compiled_parameter(self, p: TensorParameter) -> tuple[TorchTensorParameter, int]:
        """Return the compiled tensor parameter of ``p`` together with its fold (slice) index.

        Raises:
            KeyError: If ``p`` has not been registered.
        """
        return self._compiled_parameters[p]

    def retrieve_symbolic_parameter(self, p: TorchTensorParameter) -> TensorParameter:
        """Return the symbolic tensor parameter associated to the unfolded compiled one ``p``.

        Raises:
            KeyError: If ``p`` has not been registered as unfolded since the last call to
                ``finish_compilation``.
        """
        return self._symbolic_parameters[p]

    def register_compiled_parameter(
        self, sp: TensorParameter, cp: TorchTensorParameter, *, fold_idx: int | None = None
    ) -> None:
        """Register a link from a symbolic tensor parameter to a compiled tensor parameter.

        Args:
            sp: The symbolic tensor parameter.
            cp: The compiled tensor parameter.
            fold_idx: ``None`` if ``cp`` is unfolded. In that case ``sp`` is linked to slice 0
                of ``cp``, and the reverse link (compiled to symbolic) is registered as well.
                Otherwise, the index of the slice of the folded ``cp`` that ``sp`` maps to.
        """
        if fold_idx is None:
            # Unfolded compiled parameter tensor: the slice index is 0, and we can also
            # register the reverse map (i.e., compiled to symbolic).
            self._compiled_parameters[sp] = (cp, 0)
            self._symbolic_parameters[cp] = sp
            # Return early: falling through would overwrite the entry with (cp, None)
            return
        # Folded compiled parameter tensor: associate the symbolic parameter tensor to
        # the particular slice of the folded compiled tensor specified by 'fold_idx'.
        self._compiled_parameters[sp] = (cp, fold_idx)
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
class TorchCompiler(AbstractCompiler):
    """Compiler of symbolic circuits into PyTorch circuits.

    Layers, parameter nodes and initializers are compiled using the default registries of
    compilation rules. Compiled circuits can optionally be optimized and then folded.
    """

    def __init__(self, semiring: str = "sum-product", fold: bool = False, optimize: bool = False):
        """Initialize the compiler.

        Args:
            semiring: The name of the semiring compiled circuits are evaluated with, resolved
                via ``SemiringImpl.from_name``.
            fold: Whether to fold compiled circuits.
            optimize: Whether to optimize compiled circuits (optimization happens before folding).
        """
        super().__init__(
            CompilerLayerRegistry(DEFAULT_LAYER_COMPILATION_RULES),
            CompilerParameterRegistry(DEFAULT_PARAMETER_COMPILATION_RULES),
            CompilerInitializerRegistry(DEFAULT_INITIALIZER_COMPILATION_RULES),
            fold=fold,
            optimize=optimize,
        )

        # The semiring being used at compile time
        self._semiring: Semiring = SemiringImpl.from_name(semiring)

        # The state of the compiler
        self._state = TorchCompilerState()

        # The registries of optimization rules, indexed by the kind of optimization
        self._optimization_registry = {
            "parameter": ParameterOptRegistry(DEFAULT_PARAMETER_OPT_RULES),
            "layer_fuse": LayerOptRegistry(DEFAULT_LAYER_FUSE_OPT_RULES),
            "layer_shatter": LayerOptRegistry(DEFAULT_LAYER_SHATTER_OPT_RULES),
        }

    def compile_pipeline(self, sc: Circuit) -> AbstractTorchCircuit:
        """Compile ``sc`` together with the circuits of its pipeline.

        Circuits are compiled following the topological ordering of the pipeline, and those
        that have already been compiled are skipped.

        Returns:
            The compiled circuit associated to ``sc`` (i.e., the output of the pipeline).
        """
        for sci in pipeline_topological_ordering([sc]):
            if not self.is_compiled(sci):
                self._compile_circuit(sci)
        return self.get_compiled_circuit(sc)

    @property
    def semiring(self) -> Semiring:
        """The semiring being used at compile time."""
        return self._semiring

    @property
    def is_fold_enabled(self) -> bool:
        """Whether compiled circuits are folded."""
        return self._flags["fold"]

    @property
    def is_optimize_enabled(self) -> bool:
        """Whether compiled circuits are optimized."""
        return self._flags["optimize"]

    @property
    def state(self) -> TorchCompilerState:
        """The mutable state of the compiler."""
        return self._state

    def compile_layer(self, layer: Layer) -> TorchLayer:
        """Compile a symbolic layer using the rule registered for its type."""
        signature = type(layer)
        rule = self.retrieve_layer_rule(signature)
        return cast(TorchLayer, rule(self, layer))

    def compile_parameter(self, parameter: Parameter) -> TorchParameter:
        """Compile a symbolic parameter computational graph, node by node.

        Nodes are compiled following the topological ordering of ``parameter``, so the inputs
        of each node are always compiled before the node itself.
        """
        # A map from symbolic to compiled parameter nodes
        compiled_nodes_map: dict[ParameterNode, TorchParameterNode] = {}

        # The compiled parameter nodes, and their inputs
        nodes: list[TorchParameterNode] = []
        in_nodes: dict[TorchParameterNode, list[TorchParameterNode]] = {}

        for p in parameter.topological_ordering():
            # Compile the parameter node and make the connections
            compiled_p = self._compile_parameter_node(p)
            in_nodes[compiled_p] = [compiled_nodes_map[pi] for pi in parameter.node_inputs(p)]
            compiled_nodes_map[p] = compiled_p
            nodes.append(compiled_p)

        # Build the parameter's computational graph
        outputs = [compiled_nodes_map[parameter.output]]
        return TorchParameter(nodes, in_nodes, outputs)

    def compile_initializer(self, initializer: Initializer) -> Callable[[Tensor], Tensor]:
        """Compile a symbolic initializer using the rule registered for its type."""
        signature = type(initializer)
        rule = self.retrieve_initializer_rule(signature)
        return cast(Callable[[Tensor], Tensor], rule(self, initializer))

    def retrieve_optimization_registry(self, kind: str) -> CompilerRegistry:
        """Return the registry of optimization rules of the given kind.

        Raises:
            KeyError: If ``kind`` is not one of 'parameter', 'layer_fuse' or 'layer_shatter'.
        """
        return cast(CompilerRegistry, self._optimization_registry[kind])

    def retrieve_optimization_rule(self, kind: str, pattern: GraphOptPattern) -> Callable:
        """Return the optimization rule of the given kind registered for ``pattern``."""
        registry = self.retrieve_optimization_registry(kind)
        return registry.retrieve_rule(pattern)

    def _compile_parameter_node(self, node: ParameterNode) -> TorchParameterNode:
        # Compile a single parameter node using the rule registered for its type
        signature = type(node)
        rule = self.retrieve_parameter_rule(signature)
        return cast(TorchParameterNode, rule(self, node))

    def _compile_circuit(self, sc: Circuit) -> AbstractTorchCircuit:
        """Compile a single symbolic circuit and register the compiled circuit.

        Layers are compiled following the topological ordering of ``sc``. The compiled circuit
        is then post-processed (optionally optimized and folded), its parameters are allocated
        and initialized, and it is registered as the compilation of ``sc``.
        The per-circuit compiler state is reset afterwards, even if compilation fails.
        """
        try:
            # A map from symbolic to compiled layers, and the inputs of each compiled layer
            compiled_layers_map: dict[Layer, TorchLayer] = {}
            in_layers: dict[TorchLayer, list[TorchLayer]] = {}

            # Compile layers by following the topological ordering
            for sl in sc.topological_ordering():
                layer = self.compile_layer(sl)
                in_layers[layer] = [compiled_layers_map[sli] for sli in sc.layer_inputs(sl)]
                compiled_layers_map[sl] = layer

            # If the symbolic circuit being compiled has empty scope,
            # then build a 'constant circuit' whose interface does not require inputs
            cc_cls = TorchCircuit if sc.scope else TorchConstantCircuit

            # Construct the tensorized circuit
            outputs = [compiled_layers_map[sl] for sl in sc.outputs]
            cc = cc_cls(
                sc.scope,
                sc.num_channels,
                layers=list(compiled_layers_map.values()),
                in_layers=in_layers,
                outputs=outputs,
                properties=sc.properties,
            )

            # Optionally optimize the compiled circuit, and then fold it
            cc = self._post_process_circuit(cc)

            # Allocate & initialize the parameters
            cc.reset_parameters()

            # Register the compiled circuit
            self.register_compiled_circuit(sc, cc)
        finally:
            # Signal the end of the circuit compilation to the state; done even on failure
            # so that stale reverse-map entries do not leak into later compilations
            self._state.finish_compilation()
        return cc

    def _post_process_circuit(self, cc: AbstractTorchCircuit) -> AbstractTorchCircuit:
        """Optionally optimize and then fold a compiled circuit, according to the flags."""
        if self.is_optimize_enabled:
            # Optimize the circuit computational graph
            cc = _optimize_circuit(self, cc, max_opt_steps=5)
        if self.is_fold_enabled:
            # Optimize the circuit by folding it
            cc = _fold_circuit(self, cc)
        return cc
|
|
264
|
+
|
|
265
|
+
|
|
266
|
+
def _fold_circuit(compiler: TorchCompiler, cc: AbstractTorchCircuit) -> AbstractTorchCircuit:
    """Build a folded version of the compiled circuit ``cc``.

    Layers are grouped by following the layer-wise topological ordering of ``cc``. Each group
    is merged into one folded layer by ``_fold_layers_group``. The result is an instance of
    the same circuit class, carrying the fold index information returned by the folding.
    """
    fold_group = functools.partial(_fold_layers_group, compiler=compiler)
    folded_graph = build_folded_graph(
        cc.layerwise_topological_ordering(),
        outputs=cc.outputs,
        incomings_fn=cc.layer_inputs,
        fold_group_fn=fold_group,
    )
    folded_layers, folded_in_layers, folded_outputs, fold_idx_info = folded_graph

    # Re-use the class and the metadata of the unfolded circuit
    circuit_cls = type(cc)
    return circuit_cls(
        cc.scope,
        cc.num_channels,
        folded_layers,
        folded_in_layers,
        folded_outputs,
        properties=cc.properties,
        fold_idx_info=fold_idx_info,
    )
|
|
285
|
+
|
|
286
|
+
|
|
287
|
+
def _fold_layers_group(layers: list[TorchLayer], *, compiler: TorchCompiler) -> TorchLayer:
    """Merge a group of compiled layers of the same class into one folded layer.

    The folded layer receives the following:
    - the configuration of the first layer of the group;
    - the recursively folded sub-module layers;
    - the folded parameters;
    - the compiler's semiring.

    Non-input layers also receive the total number of folds. Input layers, other than
    constant ones, receive the concatenation of the scope index tensors of the group.
    """
    head = layers[0]
    layer_cls = type(head)
    layer_conf = head.config

    extra_kwargs = {}
    if not issubclass(layer_cls, TorchInputLayer):
        # Sum or product layers: the folded layer stacks the folds of every layer in the group
        extra_kwargs["num_folds"] = sum(lyr.num_folds for lyr in layers)
    elif not issubclass(layer_cls, TorchConstantLayer):
        # Input layers: concatenate the variables scope index tensors
        extra_kwargs["scope_idx"] = torch.cat([lyr.scope_idx for lyr in layers])

    # Group the parameters and the sub-module layers of every layer by their name
    grouped_params: dict[str, list[TorchParameter]] = defaultdict(list)
    grouped_submodules: dict[str, list[TorchLayer]] = defaultdict(list)
    for lyr in layers:
        for name, param in lyr.params.items():
            grouped_params[name].append(param)
        for name, sub_layer in lyr.sub_modules.items():
            grouped_submodules[name].append(sub_layer)

    # Fold the parameters first, and then the sub-module layers
    folded_params: dict[str, TorchParameter] = {}
    for name, params_group in grouped_params.items():
        folded_params[name] = _fold_parameters(compiler, params_group)
    folded_submodules: dict[str, TorchLayer] = {}
    for name, sub_group in grouped_submodules.items():
        folded_submodules[name] = _fold_layers_group(sub_group, compiler=compiler)

    return layer_cls(
        **layer_conf,
        **folded_submodules,
        **folded_params,
        semiring=compiler.semiring,
        **extra_kwargs,
    )
|
|
329
|
+
|
|
330
|
+
|
|
331
|
+
def _fold_parameters(compiler: TorchCompiler, parameters: list[TorchParameter]) -> TorchParameter:
    """Fold a group of compiled parameter graphs into a single folded parameter graph.

    The graphs are merged in two ways. First, the union of their node-to-inputs maps is taken.
    Second, their layer-wise (aka bottom-up) topological orderings are concatenated frontier
    by frontier. The merged graph is then folded with ``build_folded_graph``, which uses
    ``_fold_parameter_nodes_group`` to fold each group of nodes.

    Args:
        compiler: The compiler. Its state is updated when tensor parameters get folded.
        parameters: The parameter graphs to fold.

    Returns:
        The folded parameter graph.
    """
    # (i) The parameter nodes of all the graphs, and the inputs to each node
    merged_in_nodes: dict[TorchParameterNode, Sequence[TorchParameterNode]] = {}
    for param in parameters:
        merged_in_nodes.update(param.nodes_inputs)

    # (ii) The merged layer-wise topological ordering. Frontiers are copied before being
    # stored, so that extending them never mutates lists owned by the parameter graphs.
    ordering: list[list[TorchParameterNode]] = []
    for param in parameters:
        for depth, frontier in enumerate(param.layerwise_topological_ordering()):
            if depth < len(ordering):
                ordering[depth].extend(frontier)
            else:
                ordering.append(list(frontier))

    # The outputs of the merged graph are the outputs of every parameter graph
    merged_outputs = list(chain.from_iterable(param.outputs for param in parameters))

    # Fold the nodes in the merged parameter computational graphs,
    # by following the layer-wise topological ordering
    nodes, in_nodes, outputs, fold_idx_info = build_folded_graph(
        ordering,
        outputs=merged_outputs,
        incomings_fn=merged_in_nodes.get,
        fold_group_fn=functools.partial(_fold_parameter_nodes_group, compiler=compiler),
    )

    # Construct the folded parameter's computational graph
    return TorchParameter(nodes, in_nodes, outputs, fold_idx_info=fold_idx_info)
|
|
357
|
+
|
|
358
|
+
|
|
359
|
+
def _fold_parameter_nodes_group(
    group: list[TorchParameterNode], *, compiler: TorchCompiler
) -> TorchParameterNode:
    """Merge a group of compiled parameter nodes of the same class into one folded node.

    Three kinds of groups are handled:
    - tensor parameters: a new tensor parameter is created with one fold per node and an
      initializer stacking the initializers of the group. The compiler state is then updated
      so that each symbolic tensor parameter refers to its slice of the folded tensor;
    - pointer parameters: a lone pointer is returned unchanged; otherwise a new pointer to
      the node dereferenced from the first pointer is built, gathering the fold indices of
      all pointers (every fold when a pointer has no fold index);
    - operators: a new operator with the first node's configuration and one fold per node.
    """
    head = group[0]
    node_cls = type(head)

    if issubclass(node_cls, TorchTensorParameter):
        assert all(isinstance(p, TorchTensorParameter) for p in group)
        stacked_init = functools.partial(
            stacked_initializer_, initializers=[p.initializer for p in group]
        )
        folded_tensor = TorchTensorParameter(
            *head.shape,
            num_folds=len(group),
            requires_grad=head.requires_grad,
            initializer_=stacked_init,
            dtype=head.dtype,
        )
        # Re-link every symbolic parameter leaf (which is unfolded) to its slice
        # within the folded compiled parameter leaf
        for slice_idx, node in enumerate(group):
            symbolic = compiler.state.retrieve_symbolic_parameter(node)
            compiler.state.register_compiled_parameter(symbolic, folded_tensor, fold_idx=slice_idx)
        return folded_tensor

    if issubclass(node_cls, TorchPointerParameter):
        # Parameters obtained via slicing, e.g., when operating over folded circuits
        assert all(isinstance(p, TorchPointerParameter) for p in group)
        if len(group) == 1:
            # A single slicing operation cannot be folded with others: keep it as it is
            return head
        target = head.deref()
        gathered_idx: list[int] = []
        for ptr in group:
            if ptr.fold_idx is None:
                gathered_idx.extend(range(ptr.num_folds))
            else:
                gathered_idx.extend(ptr.fold_idx)
        return TorchPointerParameter(target, fold_idx=gathered_idx)

    # Operators: copy the configuration and set the number of folds
    assert all(isinstance(p, TorchParameterOp) for p in group)
    return node_cls(**head.config, num_folds=len(group))
|
|
404
|
+
|
|
405
|
+
|
|
406
|
+
def _optimize_circuit(
    compiler: TorchCompiler, cc: AbstractTorchCircuit, *, max_opt_steps: int = 5
) -> AbstractTorchCircuit:
    """Repeatedly optimize a compiled circuit.

    Each step applies three optimizations in order:
    (i) optimize the parameter graphs of every layer;
    (ii) shatter layers into multiple more efficient ones;
    (iii) fuse multiple layers into single more efficient ones.

    Iteration stops as soon as a step performs no optimization at all, or after
    ``max_opt_steps`` steps.
    """
    assert max_opt_steps > 0

    for _ in range(max_opt_steps):
        cc, params_changed = _optimize_parameter_nodes(compiler, cc)
        cc, shattered = _optimize_layers(compiler, cc, shatter=True)
        cc, fused = _optimize_layers(compiler, cc, shatter=False)
        if not (params_changed or shattered or fused):
            # Fixed point reached: nothing else can be optimized
            break

    return cc
|
|
437
|
+
|
|
438
|
+
|
|
439
|
+
def _optimize_parameter_nodes(
    compiler: TorchCompiler, cc: AbstractTorchCircuit
) -> tuple[AbstractTorchCircuit, bool]:
    """Optimize the parameter computational graphs of every layer of ``cc``, in place.

    Each optimized graph is assigned back to its layer via ``setattr``.

    Returns:
        The circuit ``cc`` itself, and whether at least one parameter graph was optimized.
    """

    def match_optimizer(match: ParameterOptMatch) -> tuple[TorchParameterNode, ...]:
        # Apply the parameter optimization rule registered for the matched pattern
        rule = compiler.retrieve_optimization_rule("parameter", match.pattern)
        func = cast(ParameterOptApplyFunc, rule)
        return func(compiler, match)

    has_been_optimized = False
    patterns = compiler.retrieve_optimization_registry("parameter").signatures
    for layer in cc.layers:
        # Iterate over a snapshot, since the layer's parameters are re-assigned below
        for pname, pgraph in list(layer.params.items()):
            optimize_result = optimize_graph(
                pgraph.topological_ordering(),
                pgraph.outputs,
                patterns,
                incomings_fn=pgraph.node_inputs,
                outcomings_fn=pgraph.node_outputs,
                pattern_matcher_fn=_match_parameter_nodes_pattern,
                match_optimizer_fn=match_optimizer,
            )

            # Skip parameter graphs where no optimization is possible
            if optimize_result is None:
                continue
            nodes, in_nodes, outputs = optimize_result

            # Build the optimized computational graph, and assign it to the layer
            opt_pgraph = type(pgraph)(nodes, in_nodes, outputs)
            assert hasattr(layer, pname)
            setattr(layer, pname, opt_pgraph)
            has_been_optimized = True

    return cc, has_been_optimized
|
|
481
|
+
|
|
482
|
+
|
|
483
|
+
def _optimize_layers(
    compiler: TorchCompiler, cc: AbstractTorchCircuit, *, shatter: bool = False
) -> tuple[AbstractTorchCircuit, bool]:
    """Apply one round of layer optimizations to the compiled circuit ``cc``.

    Args:
        compiler: The compiler holding the registries of optimization rules.
        cc: The compiled circuit to optimize.
        shatter: If True, use the 'layer_shatter' rules (split layers into multiple more
            efficient ones). Otherwise use the 'layer_fuse' rules (fuse layers together).

    Returns:
        A pair. If some optimization applied, it is a new circuit of the same class as ``cc``
        and True. Otherwise it is ``cc`` itself and False.
    """
    # Both kinds of layer optimizations only differ by the registry of rules they use
    kind = "layer_shatter" if shatter else "layer_fuse"

    def match_optimizer(match: LayerOptMatch) -> tuple[TorchLayer, ...]:
        # Apply the layer optimization rule registered for the matched pattern
        rule = compiler.retrieve_optimization_rule(kind, match.pattern)
        func = cast(LayerOptApplyFunc, rule)
        return func(compiler, match)

    registry = compiler.retrieve_optimization_registry(kind)
    optimize_result = optimize_graph(
        cc.topological_ordering(),
        cc.outputs,
        registry.signatures,
        incomings_fn=cc.layer_inputs,
        outcomings_fn=cc.layer_outputs,
        pattern_matcher_fn=_match_layer_pattern,
        match_optimizer_fn=match_optimizer,
    )
    if optimize_result is None:
        return cc, False
    layers, in_layers, outputs = optimize_result
    opt_cc = type(cc)(cc.scope, cc.num_channels, layers, in_layers, outputs, properties=cc.properties)
    return opt_cc, True
|
|
512
|
+
|
|
513
|
+
|
|
514
|
+
def _match_parameter_nodes_pattern(
    node: TorchParameterNode,
    pattern: ParameterOptPattern,
    *,
    incomings_fn: Callable[[TorchParameterNode], Sequence[TorchParameterNode]],
    outcomings_fn: Callable[[TorchParameterNode], Sequence[TorchParameterNode]],
) -> ParameterOptMatch | None:
    """Try to match a chain-shaped parameter optimization pattern rooted at ``node``.

    The i-th pattern entry must be matched (via ``isinstance``) by the i-th node of the path
    that starts at ``node`` and follows node inputs. The path is constrained as follows:
    - every matched node except the last must have exactly one input;
    - every matched node except the root must have at most one output.

    Returns:
        The match, or None if the pattern does not match at ``node``.
    """
    pattern_entries = pattern.entries()
    last = len(pattern_entries) - 1
    matched_nodes = []

    # Start matching the pattern from the root
    # TODO: generalize to match DAGs or binary trees
    for nid, entry in enumerate(pattern_entries):
        if not isinstance(node, entry):
            return None
        in_nodes = incomings_fn(node)
        # A non-last node must have exactly one input: with more it is not a chain, and
        # with none there is no node left to match the next entry (unpacking would fail)
        if nid != last and len(in_nodes) != 1:
            return None
        out_nodes = outcomings_fn(node)
        if len(out_nodes) > 1 and nid != 0:
            return None
        matched_nodes.append(node)
        if nid != last:
            (node,) = in_nodes

    return ParameterOptMatch(pattern, matched_nodes)
|
|
541
|
+
|
|
542
|
+
|
|
543
|
+
def _match_layer_pattern(
    layer: TorchLayer,
    pattern: LayerOptPattern,
    *,
    incomings_fn: Callable[[TorchLayer], Sequence[TorchLayer]],
    outcomings_fn: Callable[[TorchLayer], Sequence[TorchLayer]],
) -> LayerOptMatch | None:
    """Try to match a chain-shaped layer optimization pattern rooted at ``layer``.

    The i-th pattern entry must be matched (via ``isinstance``) by the i-th layer of the path
    that starts at ``layer`` and follows layer inputs. The path is constrained as follows:
    - every matched layer except the last must have exactly one input;
    - every matched layer except the root must have at most one output.

    In addition, for each matched layer, every parameter pattern that ``pattern.ppatterns()``
    specifies for that entry must match at least once in the corresponding parameter graph.

    Returns:
        The match, including the parameter matches of each layer, or None if there is no match.
    """
    ppatterns = pattern.ppatterns()
    pattern_entries = pattern.entries()
    last = len(pattern_entries) - 1
    matched_layers = []
    matched_parameters = []

    # Start matching the pattern from the root
    # TODO: generalize to match DAGs or trees
    for lid, entry in enumerate(pattern_entries):
        # First, attempt to match the layer
        if not isinstance(layer, entry):
            return None
        in_layers = incomings_fn(layer)
        # A non-last layer must have exactly one input: with more it is not a chain, and
        # with none (e.g., an input layer) there is nothing left to match (unpacking would fail)
        if lid != last and len(in_layers) != 1:
            return None
        out_layers = outcomings_fn(layer)
        if len(out_layers) > 1 and lid != 0:
            return None

        # Second, attempt to match the patterns specified for its parameters
        lpmatches = {}
        for pname, ppattern in ppatterns[lid].items():
            pgraph = layer.params[pname]
            matches, _ = match_optimization_patterns(
                pgraph.topological_ordering(),
                pgraph.outputs,
                [ppattern],
                incomings_fn=pgraph.node_inputs,
                outcomings_fn=pgraph.node_outputs,
                pattern_matcher_fn=_match_parameter_nodes_pattern,
            )
            if not matches:
                return None
            lpmatches[pname] = matches
        matched_parameters.append(lpmatches)

        # We got a match with the layer and its parameters.
        # Next, try to match its input sub-graph.
        matched_layers.append(layer)
        if lid != last:
            (layer,) = in_layers

    return LayerOptMatch(pattern, matched_layers, matched_parameters)
|
|
File without changes
|