libcirkit 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (65) hide show
  1. cirkit/__init__.py +0 -0
  2. cirkit/backend/__init__.py +0 -0
  3. cirkit/backend/base.py +199 -0
  4. cirkit/backend/compiler.py +213 -0
  5. cirkit/backend/registry.py +53 -0
  6. cirkit/backend/torch/__init__.py +0 -0
  7. cirkit/backend/torch/circuits.py +217 -0
  8. cirkit/backend/torch/compiler.py +592 -0
  9. cirkit/backend/torch/graph/__init__.py +0 -0
  10. cirkit/backend/torch/graph/folding.py +230 -0
  11. cirkit/backend/torch/graph/modules.py +276 -0
  12. cirkit/backend/torch/graph/optimize.py +258 -0
  13. cirkit/backend/torch/initializers.py +49 -0
  14. cirkit/backend/torch/layers/__init__.py +16 -0
  15. cirkit/backend/torch/layers/base.py +119 -0
  16. cirkit/backend/torch/layers/inner.py +335 -0
  17. cirkit/backend/torch/layers/input.py +746 -0
  18. cirkit/backend/torch/layers/optimized.py +241 -0
  19. cirkit/backend/torch/optimization/__init__.py +0 -0
  20. cirkit/backend/torch/optimization/layers.py +166 -0
  21. cirkit/backend/torch/optimization/parameters.py +67 -0
  22. cirkit/backend/torch/optimization/registry.py +81 -0
  23. cirkit/backend/torch/parameters/__init__.py +0 -0
  24. cirkit/backend/torch/parameters/nodes.py +828 -0
  25. cirkit/backend/torch/parameters/parameter.py +117 -0
  26. cirkit/backend/torch/parameters/pic.py +418 -0
  27. cirkit/backend/torch/queries.py +178 -0
  28. cirkit/backend/torch/rules/__init__.py +3 -0
  29. cirkit/backend/torch/rules/initializers.py +53 -0
  30. cirkit/backend/torch/rules/layers.py +184 -0
  31. cirkit/backend/torch/rules/parameters.py +280 -0
  32. cirkit/backend/torch/semiring.py +492 -0
  33. cirkit/backend/torch/utils.py +102 -0
  34. cirkit/pipeline.py +355 -0
  35. cirkit/symbolic/__init__.py +0 -0
  36. cirkit/symbolic/circuit.py +938 -0
  37. cirkit/symbolic/dtypes.py +45 -0
  38. cirkit/symbolic/functional.py +674 -0
  39. cirkit/symbolic/initializers.py +121 -0
  40. cirkit/symbolic/layers.py +788 -0
  41. cirkit/symbolic/operators.py +384 -0
  42. cirkit/symbolic/parameters.py +921 -0
  43. cirkit/symbolic/registry.py +119 -0
  44. cirkit/templates/__init__.py +0 -0
  45. cirkit/templates/circuit_templates/__init__.py +2 -0
  46. cirkit/templates/circuit_templates/data.py +107 -0
  47. cirkit/templates/circuit_templates/utils.py +287 -0
  48. cirkit/templates/region_graph/__init__.py +11 -0
  49. cirkit/templates/region_graph/algorithms/__init__.py +9 -0
  50. cirkit/templates/region_graph/algorithms/chow_liu.py +141 -0
  51. cirkit/templates/region_graph/algorithms/factorized.py +43 -0
  52. cirkit/templates/region_graph/algorithms/linear.py +77 -0
  53. cirkit/templates/region_graph/algorithms/poon_domingos.py +203 -0
  54. cirkit/templates/region_graph/algorithms/quad.py +179 -0
  55. cirkit/templates/region_graph/algorithms/random.py +110 -0
  56. cirkit/templates/region_graph/algorithms/utils.py +124 -0
  57. cirkit/templates/region_graph/graph.py +335 -0
  58. cirkit/utils/__init__.py +0 -0
  59. cirkit/utils/algorithms.py +218 -0
  60. cirkit/utils/scope.py +13 -0
  61. libcirkit-0.1.0.dist-info/LICENSE +674 -0
  62. libcirkit-0.1.0.dist-info/METADATA +200 -0
  63. libcirkit-0.1.0.dist-info/RECORD +65 -0
  64. libcirkit-0.1.0.dist-info/WHEEL +5 -0
  65. libcirkit-0.1.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,592 @@
1
+ import functools
2
+ from collections import defaultdict
3
+ from collections.abc import Callable, Sequence
4
+ from itertools import chain
5
+ from typing import cast
6
+
7
+ import torch
8
+ from torch import Tensor
9
+
10
+ from cirkit.backend.compiler import (
11
+ AbstractCompiler,
12
+ CompilerInitializerRegistry,
13
+ CompilerLayerRegistry,
14
+ CompilerParameterRegistry,
15
+ )
16
+ from cirkit.backend.registry import CompilerRegistry
17
+ from cirkit.backend.torch.circuits import AbstractTorchCircuit, TorchCircuit, TorchConstantCircuit
18
+ from cirkit.backend.torch.graph.folding import build_folded_graph
19
+ from cirkit.backend.torch.graph.optimize import (
20
+ GraphOptPattern,
21
+ match_optimization_patterns,
22
+ optimize_graph,
23
+ )
24
+ from cirkit.backend.torch.initializers import stacked_initializer_
25
+ from cirkit.backend.torch.layers import TorchInputLayer, TorchLayer
26
+ from cirkit.backend.torch.layers.input import TorchConstantLayer
27
+ from cirkit.backend.torch.optimization.layers import (
28
+ DEFAULT_LAYER_FUSE_OPT_RULES,
29
+ DEFAULT_LAYER_SHATTER_OPT_RULES,
30
+ )
31
+ from cirkit.backend.torch.optimization.parameters import DEFAULT_PARAMETER_OPT_RULES
32
+ from cirkit.backend.torch.optimization.registry import (
33
+ LayerOptApplyFunc,
34
+ LayerOptMatch,
35
+ LayerOptPattern,
36
+ LayerOptRegistry,
37
+ ParameterOptApplyFunc,
38
+ ParameterOptMatch,
39
+ ParameterOptPattern,
40
+ ParameterOptRegistry,
41
+ )
42
+ from cirkit.backend.torch.parameters.nodes import (
43
+ TorchParameterNode,
44
+ TorchParameterOp,
45
+ TorchPointerParameter,
46
+ TorchTensorParameter,
47
+ )
48
+ from cirkit.backend.torch.parameters.parameter import TorchParameter
49
+ from cirkit.backend.torch.rules import (
50
+ DEFAULT_INITIALIZER_COMPILATION_RULES,
51
+ DEFAULT_LAYER_COMPILATION_RULES,
52
+ DEFAULT_PARAMETER_COMPILATION_RULES,
53
+ )
54
+ from cirkit.backend.torch.semiring import Semiring, SemiringImpl
55
+ from cirkit.symbolic.circuit import Circuit, pipeline_topological_ordering
56
+ from cirkit.symbolic.initializers import Initializer
57
+ from cirkit.symbolic.layers import Layer
58
+ from cirkit.symbolic.parameters import Parameter, ParameterNode, TensorParameter
59
+
60
+
61
+ class TorchCompilerState:
62
+ def __init__(self):
63
+ # A map from symbolic parameter tensors to a tuple containing the compiled parameter tensor,
64
+ # and the slice index, which is 0 if the compiled parameter tensor is unfolded.
65
+ # If the compiled parameter tensor is folded, then the slice index can be non-zero.
66
+ self._compiled_parameters: dict[TensorParameter, tuple[TorchTensorParameter, int]] = {}
67
+
68
+ # We keep a reverse map from compiled and unfolded parameter tensors
69
+ # to the corresponding symbolic parameter tensors.
70
+ # This is useful to update the map from symbolic to compiled parameter tensors above
71
+ # after we fold the tensor parameters within a circuit.
72
+ # Since this is useful only for folding, it will be cleared after each circuit compilation.
73
+ self._symbolic_parameters: dict[TorchTensorParameter, TensorParameter] = {}
74
+
75
+ def finish_compilation(self) -> None:
76
+ # Clear the map from (unfolded) compiled parameter tensors to symbolic ones
77
+ self._symbolic_parameters.clear()
78
+
79
+ def has_compiled_parameter(self, p: TensorParameter) -> bool:
80
+ # Retrieve whether a tensor parameter has already been compiled
81
+ return p in self._compiled_parameters
82
+
83
+ def retrieve_compiled_parameter(self, p: TensorParameter) -> tuple[TorchTensorParameter, int]:
84
+ # Retrieve the compiled parameter: we return the fold index as well.
85
+ return self._compiled_parameters[p]
86
+
87
+ def retrieve_symbolic_parameter(self, p: TorchTensorParameter) -> TensorParameter:
88
+ # Retrieve the symbolic parameter tensor associated to the compiled one (which is unfolded)
89
+ return self._symbolic_parameters[p]
90
+
91
+ def register_compiled_parameter(
92
+ self, sp: TensorParameter, cp: TorchTensorParameter, *, fold_idx: int | None = None
93
+ ) -> None:
94
+ # Register a link from a symbolic parameter tensor to a compiled parameter tensor.
95
+ if fold_idx is None:
96
+ # We are registering an unfolded compiled parameter tensor
97
+ # So, we can also register the reverse map (i.e., compiled to symbolic)
98
+ self._compiled_parameters[sp] = (cp, 0)
99
+ self._symbolic_parameters[cp] = sp
100
+
101
+ # We are registering a folded compiled parameter tensor
102
+ # So, we associate the symbolic parameter tensor to a particular slice of the
103
+ # folded compiled parameter tensor, which is specified by the 'fold_idx'.
104
+ self._compiled_parameters[sp] = (cp, fold_idx)
105
+
106
+
107
class TorchCompiler(AbstractCompiler):
    """Compiler from symbolic circuits to torch circuits.

    It is set up with the default layer, parameter and initializer compilation rules,
    a semiring selected by name, and the default optimization rules.
    """

    def __init__(self, semiring: str = "sum-product", fold: bool = False, optimize: bool = False):
        """Initialize the compiler.

        Args:
            semiring: The name of the semiring, resolved via ``SemiringImpl.from_name``.
            fold: The 'fold' flag forwarded to the base compiler.
            optimize: The 'optimize' flag forwarded to the base compiler.
        """
        layer_rules = CompilerLayerRegistry(DEFAULT_LAYER_COMPILATION_RULES)
        parameter_rules = CompilerParameterRegistry(DEFAULT_PARAMETER_COMPILATION_RULES)
        initializer_rules = CompilerInitializerRegistry(DEFAULT_INITIALIZER_COMPILATION_RULES)
        super().__init__(
            layer_rules,
            parameter_rules,
            initializer_rules,
            fold=fold,
            optimize=optimize,
        )

        # The semiring in use at compile time, and the compiler state
        self._semiring: Semiring = SemiringImpl.from_name(semiring)
        self._state = TorchCompilerState()

        # One registry of optimization rules per kind of optimization
        self._optimization_registry = {
            "parameter": ParameterOptRegistry(DEFAULT_PARAMETER_OPT_RULES),
            "layer_fuse": LayerOptRegistry(DEFAULT_LAYER_FUSE_OPT_RULES),
            "layer_shatter": LayerOptRegistry(DEFAULT_LAYER_SHATTER_OPT_RULES),
        }

    def compile_pipeline(self, sc: Circuit) -> AbstractTorchCircuit:
        """Compile every not-yet-compiled circuit of the pipeline ending in ``sc``.

        Returns:
            The compiled circuit registered for ``sc``.
        """
        # Walk the pipeline in topological order, skipping what is already compiled
        for sci in pipeline_topological_ordering([sc]):
            if not self.is_compiled(sci):
                self._compile_circuit(sci)

        # The output of the pipeline is the compiled counterpart of 'sc'
        return self.get_compiled_circuit(sc)

    @property
    def semiring(self) -> Semiring:
        """The semiring selected at construction."""
        return self._semiring

    @property
    def is_fold_enabled(self) -> bool:
        """The value of the 'fold' entry of the compiler flags."""
        return self._flags["fold"]

    @property
    def is_optimize_enabled(self) -> bool:
        """The value of the 'optimize' entry of the compiler flags."""
        return self._flags["optimize"]

    @property
    def state(self) -> TorchCompilerState:
        """The compiler state."""
        return self._state

    def compile_layer(self, layer: Layer) -> TorchLayer:
        """Compile a symbolic layer using the rule registered for its type."""
        compile_rule = self.retrieve_layer_rule(type(layer))
        return cast(TorchLayer, compile_rule(self, layer))

    def compile_parameter(self, parameter: Parameter) -> TorchParameter:
        """Compile a symbolic parameter graph into a ``TorchParameter``.

        The nodes are compiled one by one in topological order, and each compiled
        node is wired to the compiled counterparts of its symbolic inputs.
        """
        # Symbolic node -> compiled node
        sym_to_compiled: dict[ParameterNode, TorchParameterNode] = {}
        # The compiled nodes, in compilation order, and the inputs of each of them
        compiled_nodes: list[TorchParameterNode] = []
        compiled_inputs: dict[TorchParameterNode, list[TorchParameterNode]] = {}

        for sym_node in parameter.topological_ordering():
            torch_node = self._compile_parameter_node(sym_node)
            compiled_inputs[torch_node] = [
                sym_to_compiled[sym_in] for sym_in in parameter.node_inputs(sym_node)
            ]
            sym_to_compiled[sym_node] = torch_node
            compiled_nodes.append(torch_node)

        # The graph has a single output: the compiled counterpart of the symbolic output
        return TorchParameter(
            compiled_nodes, compiled_inputs, [sym_to_compiled[parameter.output]]
        )

    def compile_initializer(self, initializer: Initializer) -> Callable[[Tensor], Tensor]:
        """Compile a symbolic initializer using the rule registered for its type."""
        compile_rule = self.retrieve_initializer_rule(type(initializer))
        return cast(Callable[[Tensor], Tensor], compile_rule(self, initializer))

    def retrieve_optimization_registry(self, kind: str) -> CompilerRegistry:
        """Return the optimization registry stored under ``kind``.

        Raises:
            KeyError: If ``kind`` is not one of the registered kinds.
        """
        return cast(CompilerRegistry, self._optimization_registry[kind])

    def retrieve_optimization_rule(self, kind: str, pattern: GraphOptPattern) -> Callable:
        """Return the rule that the registry of the given kind holds for ``pattern``."""
        return self.retrieve_optimization_registry(kind).retrieve_rule(pattern)

    def _compile_parameter_node(self, node: ParameterNode) -> TorchParameterNode:
        """Compile a symbolic parameter node using the rule registered for its type."""
        compile_rule = self.retrieve_parameter_rule(type(node))
        return cast(TorchParameterNode, compile_rule(self, node))

    def _compile_circuit(self, sc: Circuit) -> AbstractTorchCircuit:
        """Compile one symbolic circuit, post-process it, and register the result."""
        # Symbolic layer -> compiled layer, and the inputs of each compiled layer
        sym_to_compiled: dict[Layer, TorchLayer] = {}
        compiled_inputs: dict[TorchLayer, list[TorchLayer]] = {}

        # Compile the layers in topological order, wiring each one to its compiled inputs
        for sym_layer in sc.topological_ordering():
            torch_layer = self.compile_layer(sym_layer)
            compiled_inputs[torch_layer] = [
                sym_to_compiled[sym_in] for sym_in in sc.layer_inputs(sym_layer)
            ]
            sym_to_compiled[sym_layer] = torch_layer

        # A symbolic circuit with empty scope is compiled to a 'constant circuit',
        # whose interface does not require inputs
        if sc.scope:
            circuit_cls = TorchCircuit
        else:
            circuit_cls = TorchConstantCircuit

        # Build the tensorized circuit out of the compiled layers
        output_layers = [sym_to_compiled[sym_out] for sym_out in sc.outputs]
        compiled = circuit_cls(
            sc.scope,
            sc.num_channels,
            layers=list(sym_to_compiled.values()),
            in_layers=compiled_inputs,
            outputs=output_layers,
            properties=sc.properties,
        )

        # Optionally optimize and fold, then allocate & initialize the parameters
        compiled = self._post_process_circuit(compiled)
        compiled.reset_parameters()

        # Register the result, and signal the end of the compilation to the state
        self.register_compiled_circuit(sc, compiled)
        self._state.finish_compilation()
        return compiled

    def _post_process_circuit(self, cc: AbstractTorchCircuit) -> AbstractTorchCircuit:
        """Apply graph optimization and then folding, each only if enabled."""
        if self.is_optimize_enabled:
            # Optimize the circuit computational graph
            cc = _optimize_circuit(self, cc, max_opt_steps=5)
        if self.is_fold_enabled:
            # Optimize the circuit by folding it
            cc = _fold_circuit(self, cc)
        return cc
264
+
265
+
266
def _fold_circuit(compiler: TorchCompiler, cc: AbstractTorchCircuit) -> AbstractTorchCircuit:
    """Build a folded circuit of the same class as ``cc``.

    The layers are grouped by ``build_folded_graph`` following the layer-wise
    topological ordering of ``cc``; each group is folded by ``_fold_layers_group``.
    """
    fold_group = functools.partial(_fold_layers_group, compiler=compiler)
    folded_graph = build_folded_graph(
        cc.layerwise_topological_ordering(),
        outputs=cc.outputs,
        incomings_fn=cc.layer_inputs,
        fold_group_fn=fold_group,
    )
    folded_layers, folded_inputs, folded_outputs, fold_idx_info = folded_graph

    # Re-instantiate the circuit class over the folded graph
    circuit_cls = type(cc)
    return circuit_cls(
        cc.scope,
        cc.num_channels,
        folded_layers,
        folded_inputs,
        folded_outputs,
        properties=cc.properties,
        fold_idx_info=fold_idx_info,
    )
285
+
286
+
287
def _fold_layers_group(layers: list[TorchLayer], *, compiler: TorchCompiler) -> TorchLayer:
    """Fold a group of layers into a single layer of the class of the first one.

    The folded layer is instantiated from the configuration of the first layer,
    the folded sub-module layers, the folded parameters, the compiler semiring, and:

    * for input layers that are not constant layers, the concatenated ``scope_idx``;
    * for non-input layers, the total number of folds.
    """
    first_layer = layers[0]
    layer_cls = type(first_layer)
    layer_conf = first_layer.config

    extra_kwargs = {}
    if not issubclass(layer_cls, TorchInputLayer):
        # We are folding sum or product layers, so simply set the number of folds
        extra_kwargs["num_folds"] = sum(layer.num_folds for layer in layers)
    elif not issubclass(layer_cls, TorchConstantLayer):
        # We are folding input layers: concatenate the variables scope index tensors
        extra_kwargs["scope_idx"] = torch.cat([layer.scope_idx for layer in layers])

    # Collect, name by name, the parameters and the sub-module layers of every layer
    params_by_name: dict[str, list[TorchParameter]] = defaultdict(list)
    submodules_by_name: dict[str, list[TorchLayer]] = defaultdict(list)
    for layer in layers:
        for name, param in layer.params.items():
            params_by_name[name].append(param)
        for name, submodule in layer.sub_modules.items():
            submodules_by_name[name].append(submodule)

    # Fold the parameters, if the layers have any
    folded_params: dict[str, TorchParameter] = {}
    for name, params in params_by_name.items():
        folded_params[name] = _fold_parameters(compiler, params)

    # Fold (recursively) the sub-module layers, if the layers have any
    folded_submodules: dict[str, TorchLayer] = {}
    for name, submodules in submodules_by_name.items():
        folded_submodules[name] = _fold_layers_group(submodules, compiler=compiler)

    return layer_cls(
        **layer_conf,
        **folded_submodules,
        **folded_params,
        semiring=compiler.semiring,
        **extra_kwargs,
    )
329
+
330
+
331
def _fold_parameters(compiler: TorchCompiler, parameters: list[TorchParameter]) -> TorchParameter:
    """Fold several parameter computational graphs into a single one.

    The node-input maps of all the graphs are merged, and their layer-wise
    (aka bottom-up) topological orderings are merged frontier by frontier. The
    merged graph is then folded by ``build_folded_graph``, with each group of
    nodes folded by ``_fold_parameter_nodes_group``.
    """
    # Merge the inputs of the nodes of all the graphs
    merged_inputs: dict[TorchParameterNode, Sequence[TorchParameterNode]] = {}
    for param in parameters:
        merged_inputs.update(param.nodes_inputs)

    # Merge the layer-wise orderings: the i-th frontiers of all graphs end up together
    merged_ordering: list[list[TorchParameterNode]] = []
    for param in parameters:
        for depth, frontier in enumerate(param.layerwise_topological_ordering()):
            if depth >= len(merged_ordering):
                merged_ordering.append(frontier)
            else:
                merged_ordering[depth].extend(frontier)

    # Fold the merged graph by following the merged layer-wise ordering
    all_outputs = chain.from_iterable(param.outputs for param in parameters)
    fold_group = functools.partial(_fold_parameter_nodes_group, compiler=compiler)
    folded_nodes, folded_inputs, folded_outputs, fold_idx_info = build_folded_graph(
        merged_ordering,
        outputs=all_outputs,
        incomings_fn=merged_inputs.get,
        fold_group_fn=fold_group,
    )

    # Construct the folded parameter's computational graph
    return TorchParameter(
        folded_nodes, folded_inputs, folded_outputs, fold_idx_info=fold_idx_info
    )
357
+
358
+
359
def _fold_parameter_nodes_group(
    group: list[TorchParameterNode], *, compiler: TorchCompiler
) -> TorchParameterNode:
    """Fold a group of parameter nodes, dispatching on the class of the first one.

    * Tensor parameters are folded into a new tensor parameter with one fold per
      group member, and the compiler state is updated so that each symbolic
      parameter points to its slice of the folded tensor.
    * Pointer parameters are folded into a pointer over the concatenated fold
      indices (a group of one is returned as is).
    * Any other node is treated as an operator and re-instantiated from the
      configuration of the first node with the number of folds set.
    """
    head = group[0]
    node_cls = type(head)

    if issubclass(node_cls, TorchTensorParameter):
        # Folding tensor parameters: set the number of folds, copy the shape and the
        # relevant flags, and stack the initialization functions together
        assert all(isinstance(p, TorchTensorParameter) for p in group)
        stacked_init = functools.partial(
            stacked_initializer_, initializers=[p.initializer for p in group]
        )
        folded = TorchTensorParameter(
            *head.shape,
            num_folds=len(group),
            requires_grad=head.requires_grad,
            initializer_=stacked_init,
            dtype=head.dtype,
        )
        # Keep the mapping between symbolic parameter leaves (which are unfolded) and
        # slices within the folded compiled parameter leaf up to date
        for slice_idx, unfolded in enumerate(group):
            symbolic = compiler.state.retrieve_symbolic_parameter(unfolded)
            compiler.state.register_compiled_parameter(symbolic, folded, fold_idx=slice_idx)
        return folded

    if issubclass(node_cls, TorchPointerParameter):
        # Folding parameters obtained via slicing
        assert all(isinstance(p, TorchPointerParameter) for p in group)
        if len(group) == 1:
            # Nothing to fold together: keep the single slice as it is
            return head
        # Point to the node referenced by the first pointer, over all the fold indices
        # of the group; a pointer without fold indices contributes all of its folds
        target = head.deref()
        fold_indices: list[int] = []
        for pointer in group:
            if pointer.fold_idx is None:
                fold_indices.extend(range(pointer.num_folds))
            else:
                fold_indices.extend(pointer.fold_idx)
        return TorchPointerParameter(target, fold_idx=fold_indices)

    # Folding an operator: just set the number of folds and copy the configuration
    assert all(isinstance(p, TorchParameterOp) for p in group)
    return node_cls(**head.config, num_folds=len(group))
404
+
405
+
406
def _optimize_circuit(
    compiler: TorchCompiler, cc: AbstractTorchCircuit, *, max_opt_steps: int = 5
) -> AbstractTorchCircuit:
    """Repeatedly optimize a circuit until a fixed point or a step budget is reached.

    Each step runs, in order: parameter-node optimization, layer shattering, and
    layer fusion. The loop stops as soon as a step changes nothing, or after
    ``max_opt_steps`` steps.
    """
    assert max_opt_steps > 0

    remaining_steps = max_opt_steps
    while remaining_steps > 0:
        # (1) optimize the parameter nodes of the parameter graphs of each layer
        cc, params_changed = _optimize_parameter_nodes(compiler, cc)
        # (2) shatter layers into multiple more efficient ones
        cc, layers_shattered = _optimize_layers(compiler, cc, shatter=True)
        # (3) fuse multiple layers into a single more efficient one
        cc, layers_fused = _optimize_layers(compiler, cc, shatter=False)

        remaining_steps -= 1
        if not (params_changed or layers_shattered or layers_fused):
            # No further optimization can be performed
            break

    return cc
437
+
438
+
439
def _optimize_parameter_nodes(
    compiler: TorchCompiler, cc: AbstractTorchCircuit
) -> tuple[AbstractTorchCircuit, bool]:
    """Optimize the parameter computational graphs of every layer of a circuit.

    Each optimized graph replaces, via ``setattr``, the attribute of the layer
    named after the parameter.

    Returns:
        The same circuit object, and whether at least one graph was optimized.
    """

    def apply_match(match: ParameterOptMatch) -> tuple[TorchParameterNode, ...]:
        # Look up the rule registered for the matched pattern and apply it
        apply_fn = cast(
            ParameterOptApplyFunc,
            compiler.retrieve_optimization_rule("parameter", match.pattern),
        )
        return apply_fn(compiler, match)

    patterns = compiler.retrieve_optimization_registry("parameter").signatures
    any_optimized = False
    for layer in cc.layers:
        for pname, pgraph in layer.params.items():
            result = optimize_graph(
                pgraph.topological_ordering(),
                pgraph.outputs,
                patterns,
                incomings_fn=pgraph.node_inputs,
                outcomings_fn=pgraph.node_outputs,
                pattern_matcher_fn=_match_parameter_nodes_pattern,
                match_optimizer_fn=apply_match,
            )
            if result is not None:
                # Rebuild the graph from the optimized nodes and assign it to the layer
                opt_nodes, opt_inputs, opt_outputs = result
                optimized_graph = type(pgraph)(opt_nodes, opt_inputs, opt_outputs)
                assert hasattr(layer, pname)
                setattr(layer, pname, optimized_graph)
                any_optimized = True

    return cc, any_optimized
481
+
482
+
483
def _optimize_layers(
    compiler: TorchCompiler, cc: AbstractTorchCircuit, *, shatter: bool = False
) -> tuple[AbstractTorchCircuit, bool]:
    """Run one pass of layer optimization over a circuit.

    Args:
        compiler: The compiler holding the optimization registries.
        cc: The circuit to optimize.
        shatter: If true, use the "layer_shatter" rules; otherwise the "layer_fuse" ones.

    Returns:
        The optimized circuit and True, or ``cc`` itself and False if
        ``optimize_graph`` returned None.
    """
    kind = "layer_shatter" if shatter else "layer_fuse"

    def apply_match(match: LayerOptMatch) -> tuple[TorchLayer, ...]:
        # Look up the rule registered for the matched pattern and apply it
        apply_fn = cast(LayerOptApplyFunc, compiler.retrieve_optimization_rule(kind, match.pattern))
        return apply_fn(compiler, match)

    patterns = compiler.retrieve_optimization_registry(kind).signatures
    result = optimize_graph(
        cc.topological_ordering(),
        cc.outputs,
        patterns,
        incomings_fn=cc.layer_inputs,
        outcomings_fn=cc.layer_outputs,
        pattern_matcher_fn=_match_layer_pattern,
        match_optimizer_fn=apply_match,
    )
    if result is None:
        return cc, False

    # Re-instantiate the circuit class over the optimized graph
    opt_layers, opt_inputs, opt_outputs = result
    circuit_cls = type(cc)
    optimized = circuit_cls(
        cc.scope, cc.num_channels, opt_layers, opt_inputs, opt_outputs, properties=cc.properties
    )
    return optimized, True
512
+
513
+
514
+ def _match_parameter_nodes_pattern(
515
+ node: TorchParameterNode,
516
+ pattern: ParameterOptPattern,
517
+ *,
518
+ incomings_fn: Callable[[TorchParameterNode], Sequence[TorchParameterNode]],
519
+ outcomings_fn: Callable[[TorchParameterNode], Sequence[TorchParameterNode]],
520
+ ) -> ParameterOptMatch | None:
521
+ pattern_entries = pattern.entries()
522
+ num_entries = len(pattern_entries)
523
+ matched_nodes = []
524
+
525
+ # Start matching the pattern from the root
526
+ # TODO: generalize to match DAGs or binary trees
527
+ for nid in range(num_entries):
528
+ if not isinstance(node, pattern_entries[nid]):
529
+ return None
530
+ in_nodes = incomings_fn(node)
531
+ if len(in_nodes) > 1 and nid != num_entries - 1:
532
+ return None
533
+ out_nodes = outcomings_fn(node)
534
+ if len(out_nodes) > 1 and nid != 0:
535
+ return None
536
+ matched_nodes.append(node)
537
+ if nid != num_entries - 1:
538
+ (node,) = in_nodes
539
+
540
+ return ParameterOptMatch(pattern, matched_nodes)
541
+
542
+
543
+ def _match_layer_pattern(
544
+ layer: TorchLayer,
545
+ pattern: LayerOptPattern,
546
+ *,
547
+ incomings_fn: Callable[[TorchLayer], Sequence[TorchLayer]],
548
+ outcomings_fn: Callable[[TorchLayer], Sequence[TorchLayer]],
549
+ ) -> LayerOptMatch | None:
550
+ ppatterns = pattern.ppatterns()
551
+ pattern_entries = pattern.entries()
552
+ num_entries = len(pattern_entries)
553
+ matched_layers = []
554
+ matched_parameters = []
555
+
556
+ # Start matching the pattern from the root
557
+ # TODO: generalize to match DAGs or trees
558
+ for lid in range(num_entries):
559
+ # First, attempt to match the layer
560
+ if not isinstance(layer, pattern_entries[lid]):
561
+ return None
562
+ in_nodes = incomings_fn(layer)
563
+ if len(in_nodes) > 1 and lid != num_entries - 1:
564
+ return None
565
+ out_nodes = outcomings_fn(layer)
566
+ if len(out_nodes) > 1 and lid != 0:
567
+ return None
568
+
569
+ # Second, attempt to match the patterns specified for its parameters
570
+ lpmatches = {}
571
+ for pname, ppattern in ppatterns[lid].items():
572
+ pgraph = layer.params[pname]
573
+ matches, _ = match_optimization_patterns(
574
+ pgraph.topological_ordering(),
575
+ pgraph.outputs,
576
+ [ppattern],
577
+ incomings_fn=pgraph.node_inputs,
578
+ outcomings_fn=pgraph.node_outputs,
579
+ pattern_matcher_fn=_match_parameter_nodes_pattern,
580
+ )
581
+ if not matches:
582
+ return None
583
+ lpmatches[pname] = matches
584
+ matched_parameters.append(lpmatches)
585
+
586
+ # We got a match with the layer and its parameters.
587
+ # Next, try to match its input sub-graph.
588
+ matched_layers.append(layer)
589
+ if lid != num_entries - 1:
590
+ (layer,) = in_nodes
591
+
592
+ return LayerOptMatch(pattern, matched_layers, matched_parameters)
File without changes