libcirkit 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cirkit/__init__.py +0 -0
- cirkit/backend/__init__.py +0 -0
- cirkit/backend/base.py +199 -0
- cirkit/backend/compiler.py +213 -0
- cirkit/backend/registry.py +53 -0
- cirkit/backend/torch/__init__.py +0 -0
- cirkit/backend/torch/circuits.py +217 -0
- cirkit/backend/torch/compiler.py +592 -0
- cirkit/backend/torch/graph/__init__.py +0 -0
- cirkit/backend/torch/graph/folding.py +230 -0
- cirkit/backend/torch/graph/modules.py +276 -0
- cirkit/backend/torch/graph/optimize.py +258 -0
- cirkit/backend/torch/initializers.py +49 -0
- cirkit/backend/torch/layers/__init__.py +16 -0
- cirkit/backend/torch/layers/base.py +119 -0
- cirkit/backend/torch/layers/inner.py +335 -0
- cirkit/backend/torch/layers/input.py +746 -0
- cirkit/backend/torch/layers/optimized.py +241 -0
- cirkit/backend/torch/optimization/__init__.py +0 -0
- cirkit/backend/torch/optimization/layers.py +166 -0
- cirkit/backend/torch/optimization/parameters.py +67 -0
- cirkit/backend/torch/optimization/registry.py +81 -0
- cirkit/backend/torch/parameters/__init__.py +0 -0
- cirkit/backend/torch/parameters/nodes.py +828 -0
- cirkit/backend/torch/parameters/parameter.py +117 -0
- cirkit/backend/torch/parameters/pic.py +418 -0
- cirkit/backend/torch/queries.py +178 -0
- cirkit/backend/torch/rules/__init__.py +3 -0
- cirkit/backend/torch/rules/initializers.py +53 -0
- cirkit/backend/torch/rules/layers.py +184 -0
- cirkit/backend/torch/rules/parameters.py +280 -0
- cirkit/backend/torch/semiring.py +492 -0
- cirkit/backend/torch/utils.py +102 -0
- cirkit/pipeline.py +355 -0
- cirkit/symbolic/__init__.py +0 -0
- cirkit/symbolic/circuit.py +938 -0
- cirkit/symbolic/dtypes.py +45 -0
- cirkit/symbolic/functional.py +674 -0
- cirkit/symbolic/initializers.py +121 -0
- cirkit/symbolic/layers.py +788 -0
- cirkit/symbolic/operators.py +384 -0
- cirkit/symbolic/parameters.py +921 -0
- cirkit/symbolic/registry.py +119 -0
- cirkit/templates/__init__.py +0 -0
- cirkit/templates/circuit_templates/__init__.py +2 -0
- cirkit/templates/circuit_templates/data.py +107 -0
- cirkit/templates/circuit_templates/utils.py +287 -0
- cirkit/templates/region_graph/__init__.py +11 -0
- cirkit/templates/region_graph/algorithms/__init__.py +9 -0
- cirkit/templates/region_graph/algorithms/chow_liu.py +141 -0
- cirkit/templates/region_graph/algorithms/factorized.py +43 -0
- cirkit/templates/region_graph/algorithms/linear.py +77 -0
- cirkit/templates/region_graph/algorithms/poon_domingos.py +203 -0
- cirkit/templates/region_graph/algorithms/quad.py +179 -0
- cirkit/templates/region_graph/algorithms/random.py +110 -0
- cirkit/templates/region_graph/algorithms/utils.py +124 -0
- cirkit/templates/region_graph/graph.py +335 -0
- cirkit/utils/__init__.py +0 -0
- cirkit/utils/algorithms.py +218 -0
- cirkit/utils/scope.py +13 -0
- libcirkit-0.1.0.dist-info/LICENSE +674 -0
- libcirkit-0.1.0.dist-info/METADATA +200 -0
- libcirkit-0.1.0.dist-info/RECORD +65 -0
- libcirkit-0.1.0.dist-info/WHEEL +5 -0
- libcirkit-0.1.0.dist-info/top_level.txt +1 -0
cirkit/__init__.py
ADDED
|
File without changes
|
|
File without changes
|
cirkit/backend/base.py
ADDED
|
@@ -0,0 +1,199 @@
|
|
|
1
|
+
import os
|
|
2
|
+
from abc import ABC, abstractmethod
|
|
3
|
+
from collections.abc import Callable
|
|
4
|
+
from typing import IO, Any, Protocol, TypeVar
|
|
5
|
+
|
|
6
|
+
from cirkit.symbolic.circuit import Circuit
|
|
7
|
+
from cirkit.symbolic.initializers import Initializer
|
|
8
|
+
from cirkit.symbolic.layers import Layer
|
|
9
|
+
from cirkit.symbolic.parameters import ParameterNode
|
|
10
|
+
from cirkit.utils.algorithms import BiMap
|
|
11
|
+
|
|
12
|
+
CompiledCircuit = TypeVar("CompiledCircuit")
|
|
13
|
+
LayerCompilationSign = type[Layer]
|
|
14
|
+
ParameterCompilationSign = type[ParameterNode]
|
|
15
|
+
InitializerCompilationSign = type[Initializer]
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class LayerCompilationFunc(Protocol):
    """Structural type of a function that compiles a symbolic layer."""

    def __call__(self, compiler: "AbstractCompiler", sl: Layer, **kwargs) -> Any:
        """Compile the symbolic layer ``sl`` on behalf of ``compiler``.

        Args:
            compiler: The compiler invoking the rule.
            sl: The symbolic layer to compile.
            **kwargs: Optional arguments for the compilation.

        Returns:
            The compiled layer; its concrete type is not fixed by this protocol.
        """
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class ParameterCompilationFunc(Protocol):
    """Structural type of a function that compiles a symbolic parameter node."""

    def __call__(self, compiler: "AbstractCompiler", p: ParameterNode, **kwargs) -> Any:
        """Compile the symbolic parameter node ``p`` on behalf of ``compiler``.

        Args:
            compiler: The compiler invoking the rule.
            p: The symbolic parameter node to compile.
            **kwargs: Optional arguments for the compilation.

        Returns:
            The compiled parameter node; its concrete type is not fixed by this protocol.
        """
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class InitializerCompilationFunc(Protocol):
    """Structural type of a function that compiles a symbolic initializer."""

    def __call__(self, compiler: "AbstractCompiler", init: Initializer, **kwargs) -> Any:
        """Compile the symbolic initializer ``init`` on behalf of ``compiler``.

        Args:
            compiler: The compiler invoking the rule.
            init: The symbolic initializer to compile.
            **kwargs: Optional arguments for the compilation.

        Returns:
            The compiled initializer; its concrete type is not fixed by this protocol.
        """
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
class CompilationRuleNotFound(Exception):
    """Raised when a compilation rule is requested for an unregistered signature."""

    def __init__(self, msg: str):
        """Create the exception.

        Args:
            msg: The message describing which rule could not be found.
        """
        super().__init__(msg)
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
SUPPORTED_BACKENDS = ["torch"]
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
class CompiledCircuitsMap:
    """Two-way association between symbolic circuits and compiled circuits.

    All operations delegate to a ``BiMap`` whose left side holds symbolic circuits
    and whose right side holds compiled circuits.
    """

    def __init__(self) -> None:
        """Create an empty map."""
        self._bimap = BiMap[Circuit, CompiledCircuit]()

    def is_compiled(self, sc: Circuit) -> bool:
        """Return whether the symbolic circuit ``sc`` is present on the left side."""
        return self._bimap.has_left(sc)

    def has_symbolic(self, cc: CompiledCircuit) -> bool:
        """Return whether the compiled circuit ``cc`` is present on the right side."""
        return self._bimap.has_right(cc)

    def get_compiled_circuit(self, sc: Circuit) -> CompiledCircuit:
        """Look up the compiled circuit registered for the symbolic circuit ``sc``."""
        # NOTE(review): behavior for an unregistered ``sc`` depends on BiMap.get_left.
        return self._bimap.get_left(sc)

    def get_symbolic_circuit(self, cc: CompiledCircuit) -> Circuit:
        """Look up the symbolic circuit registered for the compiled circuit ``cc``."""
        # NOTE(review): behavior for an unregistered ``cc`` depends on BiMap.get_right.
        return self._bimap.get_right(cc)

    def register_compiled_circuit(self, sc: Circuit, cc: CompiledCircuit) -> None:
        """Associate the symbolic circuit ``sc`` with the compiled circuit ``cc``."""
        self._bimap.add(sc, cc)
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
class CompilerRegistry:
    """Registry of compilation rules for symbolic layers, parameters and initializers.

    Each kind of rule lives in its own dictionary, keyed by the symbolic class the
    rule compiles. The key is inferred from the type annotations of the rule function.
    """

    def __init__(
        self,
        layer_rules: dict[LayerCompilationSign, LayerCompilationFunc] | None = None,
        parameter_rules: dict[ParameterCompilationSign, ParameterCompilationFunc] | None = None,
        initializer_rules: dict[InitializerCompilationSign, InitializerCompilationFunc]
        | None = None,
    ):
        """Initialize the registry.

        Args:
            layer_rules: Initial layer rules; an empty dictionary is used if None.
            parameter_rules: Initial parameter rules; an empty dictionary is used if None.
            initializer_rules: Initial initializer rules; an empty dictionary is used
                if None.

        Note:
            The given dictionaries are stored as-is (not copied), so rules added later
            are also visible through the caller's dictionaries.
        """
        self._layer_rules = {} if layer_rules is None else layer_rules
        self._parameter_rules = {} if parameter_rules is None else parameter_rules
        self._initializer_rules = {} if initializer_rules is None else initializer_rules

    @staticmethod
    def _validate_rule_sign(func: Callable, sym_cls: type) -> type | None:
        """Infer the symbolic class a rule function compiles from its annotations.

        A valid rule has exactly three annotations: ``compiler`` (a subclass of
        ``AbstractCompiler``), one symbolic argument (a subclass of ``sym_cls``)
        and ``return``.

        Args:
            func: The candidate rule function.
            sym_cls: The base symbolic class the rule argument must derive from.

        Returns:
            The annotated symbolic class, or None if the function is not a valid rule.
        """
        args = func.__annotations__
        if "return" not in args or "compiler" not in args or len(args) != 3:
            return None
        # Exactly one annotated name remains once "return" and "compiler" are excluded.
        sym_arg = next(a for a in args if a not in ("return", "compiler"))
        found_sym_cls = args[sym_arg]
        try:
            if not issubclass(args["compiler"], AbstractCompiler):
                return None
            if not issubclass(found_sym_cls, sym_cls):
                return None
        except TypeError:
            # issubclass() rejects annotations that are not classes (e.g. string
            # forward references or generic aliases); such a function is simply not
            # a valid rule, so report it like any other invalid signature.
            return None
        return found_sym_cls

    @staticmethod
    def _lookup_rule(rules: dict, signature: type, kind: str):
        """Fetch the rule registered under ``signature`` or raise a descriptive error.

        Args:
            rules: The rules dictionary to search.
            signature: The symbolic class used as key.
            kind: The capitalized rule kind used in the error message.

        Raises:
            CompilationRuleNotFound: If no rule is registered under ``signature``.
        """
        if signature not in rules:
            raise CompilationRuleNotFound(
                f"{kind} compilation rule for signature '{signature}' not found"
            )
        return rules[signature]

    def add_layer_rule(self, func: LayerCompilationFunc):
        """Register a layer compilation rule, replacing any rule with the same signature.

        Raises:
            ValueError: If ``func`` is not annotated as a layer compilation rule.
        """
        layer_cls: type[Layer] | None = self._validate_rule_sign(func, Layer)
        if layer_cls is None:
            raise ValueError("The function is not a symbolic layer compilation rule")
        self._layer_rules[layer_cls] = func

    def add_parameter_rule(self, func: ParameterCompilationFunc):
        """Register a parameter compilation rule, replacing any rule with the same signature.

        Raises:
            ValueError: If ``func`` is not annotated as a parameter compilation rule.
        """
        param_cls: type[ParameterNode] | None = self._validate_rule_sign(func, ParameterNode)
        if param_cls is None:
            raise ValueError("The function is not a symbolic parameter compilation rule")
        self._parameter_rules[param_cls] = func

    def add_initializer_rule(self, func: InitializerCompilationFunc):
        """Register an initializer compilation rule, replacing any rule with the same signature.

        Raises:
            ValueError: If ``func`` is not annotated as an initializer compilation rule.
        """
        init_cls: type[Initializer] | None = self._validate_rule_sign(func, Initializer)
        if init_cls is None:
            raise ValueError("The function is not a symbolic initializer compilation rule")
        self._initializer_rules[init_cls] = func

    def retrieve_layer_rule(self, signature: LayerCompilationSign) -> LayerCompilationFunc:
        """Return the layer rule registered for ``signature``.

        Raises:
            CompilationRuleNotFound: If no such rule is registered.
        """
        return self._lookup_rule(self._layer_rules, signature, "Layer")

    def retrieve_parameter_rule(
        self, signature: ParameterCompilationSign
    ) -> ParameterCompilationFunc:
        """Return the parameter rule registered for ``signature``.

        Raises:
            CompilationRuleNotFound: If no such rule is registered.
        """
        return self._lookup_rule(self._parameter_rules, signature, "Parameter")

    def retrieve_initializer_rule(
        self, signature: InitializerCompilationSign
    ) -> InitializerCompilationFunc:
        """Return the initializer rule registered for ``signature``.

        Raises:
            CompilationRuleNotFound: If no such rule is registered.
        """
        return self._lookup_rule(self._initializer_rules, signature, "Initializer")
|
|
128
|
+
|
|
129
|
+
|
|
130
|
+
class AbstractCompiler(ABC):
    """Base class of compilers that turn symbolic circuits into compiled circuits.

    It holds a registry of compilation rules, the compilation flags, and a map of the
    circuits compiled so far. Subclasses provide the abstract methods.
    """

    def __init__(self, registry: CompilerRegistry, **flags):
        """Initialize the compiler.

        Args:
            registry: The registry holding the compilation rules.
            **flags: Compilation flags, stored as given.
        """
        self._registry = registry
        self._flags = flags
        self._compiled_circuits = CompiledCircuitsMap()

    # ---- Bookkeeping of compiled circuits (delegates to the circuits map) ----

    def is_compiled(self, sc: Circuit) -> bool:
        """Return whether the symbolic circuit ``sc`` has already been compiled."""
        return self._compiled_circuits.is_compiled(sc)

    def has_symbolic(self, cc: CompiledCircuit) -> bool:
        """Return whether the compiled circuit ``cc`` has a registered symbolic circuit."""
        return self._compiled_circuits.has_symbolic(cc)

    def get_compiled_circuit(self, sc: Circuit) -> CompiledCircuit:
        """Return the compiled circuit registered for the symbolic circuit ``sc``."""
        return self._compiled_circuits.get_compiled_circuit(sc)

    def get_symbolic_circuit(self, cc: CompiledCircuit) -> Circuit:
        """Return the symbolic circuit registered for the compiled circuit ``cc``."""
        return self._compiled_circuits.get_symbolic_circuit(cc)

    def register_compiled_circuit(self, sc: Circuit, cc: CompiledCircuit):
        """Record that the symbolic circuit ``sc`` compiles to ``cc``."""
        self._compiled_circuits.register_compiled_circuit(sc, cc)

    # ---- Compilation rules (delegates to the registry) ----

    def add_layer_rule(self, func: LayerCompilationFunc):
        """Register a layer compilation rule in the registry."""
        self._registry.add_layer_rule(func)

    def add_parameter_rule(self, func: ParameterCompilationFunc):
        """Register a parameter compilation rule in the registry."""
        self._registry.add_parameter_rule(func)

    def add_initializer_rule(self, func: InitializerCompilationFunc):
        """Register an initializer compilation rule in the registry."""
        self._registry.add_initializer_rule(func)

    def retrieve_layer_rule(self, signature: LayerCompilationSign) -> LayerCompilationFunc:
        """Return the layer compilation rule registered for ``signature``."""
        return self._registry.retrieve_layer_rule(signature)

    def retrieve_parameter_rule(
        self, signature: ParameterCompilationSign
    ) -> ParameterCompilationFunc:
        """Return the parameter compilation rule registered for ``signature``."""
        return self._registry.retrieve_parameter_rule(signature)

    def retrieve_initializer_rule(
        self, signature: InitializerCompilationSign
    ) -> InitializerCompilationFunc:
        """Return the initializer compilation rule registered for ``signature``."""
        return self._registry.retrieve_initializer_rule(signature)

    # ---- Compilation entry points ----

    def compile(self, sc: Circuit) -> CompiledCircuit:
        """Compile a symbolic circuit, reusing a previous result when available.

        Args:
            sc: The symbolic circuit.

        Returns:
            The already registered compiled circuit if ``sc`` was compiled before,
            otherwise the result of ``compile_pipeline``.
        """
        if not self.is_compiled(sc):
            return self.compile_pipeline(sc)
        return self.get_compiled_circuit(sc)

    @abstractmethod
    def compile_layer(self, sl: Layer) -> Any:
        """Compile a single symbolic layer (backend-specific)."""
        ...

    @abstractmethod
    def compile_pipeline(self, sc: Circuit) -> CompiledCircuit:
        """Compile a symbolic circuit that has not been compiled yet (backend-specific)."""
        ...

    @abstractmethod
    def save(
        self,
        sym_filepath: IO | os.PathLike | str,
        compiled_filepath: IO | os.PathLike | str,
    ):
        """Save the compiler state (backend-specific).

        Args:
            sym_filepath: Destination for the symbolic part, as a file object or path.
            compiled_filepath: Destination for the compiled part, as a file object or path.
        """
        ...

    @staticmethod
    @abstractmethod
    def load(
        sym_filepath: IO | os.PathLike | str, tens_filepath: IO | os.PathLike | str
    ) -> "AbstractCompiler":
        """Load a compiler from files (backend-specific).

        Args:
            sym_filepath: Source of the symbolic part, as a file object or path.
            tens_filepath: Source of the compiled part, as a file object or path.
                NOTE(review): presumably the counterpart of ``compiled_filepath``
                in ``save`` - confirm.

        Returns:
            The loaded compiler.
        """
        ...
|
|
@@ -0,0 +1,213 @@
|
|
|
1
|
+
from abc import ABC, abstractmethod
|
|
2
|
+
from typing import Any, Protocol, TypeVar, cast
|
|
3
|
+
|
|
4
|
+
from cirkit.backend.registry import CompilerRegistry
|
|
5
|
+
from cirkit.symbolic.circuit import Circuit
|
|
6
|
+
from cirkit.symbolic.initializers import Initializer
|
|
7
|
+
from cirkit.symbolic.layers import Layer
|
|
8
|
+
from cirkit.symbolic.parameters import ParameterNode
|
|
9
|
+
from cirkit.utils.algorithms import BiMap
|
|
10
|
+
|
|
11
|
+
SUPPORTED_BACKENDS = ["torch"]
|
|
12
|
+
|
|
13
|
+
CompiledCircuit = TypeVar("CompiledCircuit")
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class CompiledCircuitsMap:
    """Two-way association between symbolic circuits and compiled circuits.

    All operations delegate to a ``BiMap`` whose left side holds symbolic circuits
    and whose right side holds compiled circuits.
    """

    def __init__(self) -> None:
        """Create an empty map."""
        self._bimap = BiMap[Circuit, CompiledCircuit]()

    def is_compiled(self, sc: Circuit) -> bool:
        """Return whether the symbolic circuit ``sc`` is present on the left side."""
        return self._bimap.has_left(sc)

    def has_symbolic(self, cc: CompiledCircuit) -> bool:
        """Return whether the compiled circuit ``cc`` is present on the right side."""
        return self._bimap.has_right(cc)

    def get_compiled_circuit(self, sc: Circuit) -> CompiledCircuit:
        """Look up the compiled circuit registered for the symbolic circuit ``sc``."""
        # NOTE(review): behavior for an unregistered ``sc`` depends on BiMap.get_left.
        return self._bimap.get_left(sc)

    def get_symbolic_circuit(self, cc: CompiledCircuit) -> Circuit:
        """Look up the symbolic circuit registered for the compiled circuit ``cc``."""
        # NOTE(review): behavior for an unregistered ``cc`` depends on BiMap.get_right.
        return self._bimap.get_right(cc)

    def register_compiled_circuit(self, sc: Circuit, cc: CompiledCircuit) -> None:
        """Associate the symbolic circuit ``sc`` with the compiled circuit ``cc``."""
        self._bimap.add(sc, cc)
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
LayerCompilationSign = type[Layer]
|
|
37
|
+
ParameterCompilationSign = type[ParameterNode]
|
|
38
|
+
InitializerCompilationSign = type[Initializer]
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
class LayerCompilationFunc(Protocol):
    """Protocol implemented by the functions that compile symbolic layers."""

    def __call__(self, compiler: "AbstractCompiler", sl: Layer, **kwargs) -> Any:
        """Compile a symbolic layer on behalf of a compiler.

        Args:
            compiler: The compiler invoking the rule.
            sl: The symbolic layer to compile.
            **kwargs: The optional arguments for the compilation.

        Returns:
            A representation of the compiled layer, which depends on the chosen
            compilation backend.
        """
        ...
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
class ParameterCompilationFunc(Protocol):
    """Protocol implemented by the functions that compile symbolic parameter nodes."""

    def __call__(self, compiler: "AbstractCompiler", p: ParameterNode, **kwargs) -> Any:
        """Compile a symbolic parameter node on behalf of a compiler.

        Args:
            compiler: The compiler invoking the rule.
            p: The symbolic parameter node to compile.
            **kwargs: The optional arguments for the compilation.

        Returns:
            A representation of the compiled parameter node, which depends on the
            chosen compilation backend.
        """
        ...
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
class InitializerCompilationFunc(Protocol):
    """Protocol implemented by the functions that compile symbolic initializers."""

    def __call__(self, compiler: "AbstractCompiler", init: Initializer, **kwargs) -> Any:
        """Compile a symbolic initializer on behalf of a compiler.

        Args:
            compiler: The compiler invoking the rule.
            init: The symbolic initializer to compile.
            **kwargs: The optional arguments for the compilation.

        Returns:
            A representation of the compiled initializer, which depends on the
            chosen compilation backend.
        """
        ...
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
class CompilationRuleNotFound(Exception):
    """Exception signalling that a requested compilation rule does not exist."""

    def __init__(self, msg: str):
        """Create the exception.

        Args:
            msg: The message of the exception.
        """
        super().__init__(msg)
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
class CompilerLayerRegistry(CompilerRegistry[LayerCompilationSign, LayerCompilationFunc]):
    """Registry of the rules that compile symbolic layers.

    The signature of a rule is the annotation of its last annotated parameter
    (the return annotation excluded), which must be a subclass of ``Layer``.
    """

    @staticmethod
    def _last_param_annotation(func: LayerCompilationFunc) -> Any:
        """Return the annotation of the last annotated parameter of ``func``.

        Returns None when ``func`` has no annotated parameter. A missing return
        annotation is tolerated instead of raising ``KeyError``.
        """
        ann = func.__annotations__
        params = [name for name in ann if name != "return"]
        return ann[params[-1]] if params else None

    @classmethod
    def _validate_rule_function(cls, func: LayerCompilationFunc) -> bool:
        """Return whether the last annotated parameter of ``func`` is a ``Layer`` subclass."""
        try:
            return issubclass(cls._last_param_annotation(func), Layer)
        except TypeError:
            # The annotation is missing or not a class (e.g. a string forward
            # reference): the function is not a valid rule.
            return False

    @classmethod
    def _retrieve_signature(cls, func: LayerCompilationFunc) -> LayerCompilationSign:
        """Return the symbolic layer class that ``func`` compiles.

        When no parameter is annotated, the base class implementation is used,
        which reports the function as having no extractable signature.
        """
        sign = cls._last_param_annotation(func)
        if sign is None:
            return super()._retrieve_signature(func)
        return cast(LayerCompilationSign, sign)
|
|
117
|
+
|
|
118
|
+
|
|
119
|
+
class CompilerParameterRegistry(
    CompilerRegistry[ParameterCompilationSign, ParameterCompilationFunc]
):
    """Registry of the rules that compile symbolic parameter nodes.

    The signature of a rule is the annotation of its last annotated parameter
    (the return annotation excluded), which must be a subclass of ``ParameterNode``.
    """

    @staticmethod
    def _last_param_annotation(func: ParameterCompilationFunc) -> Any:
        """Return the annotation of the last annotated parameter of ``func``.

        Returns None when ``func`` has no annotated parameter. A missing return
        annotation is tolerated instead of raising ``KeyError``.
        """
        ann = func.__annotations__
        params = [name for name in ann if name != "return"]
        return ann[params[-1]] if params else None

    @classmethod
    def _validate_rule_function(cls, func: ParameterCompilationFunc) -> bool:
        """Return whether the last annotated parameter of ``func`` is a ``ParameterNode`` subclass."""
        try:
            return issubclass(cls._last_param_annotation(func), ParameterNode)
        except TypeError:
            # The annotation is missing or not a class (e.g. a string forward
            # reference): the function is not a valid rule.
            return False

    @classmethod
    def _retrieve_signature(cls, func: ParameterCompilationFunc) -> ParameterCompilationSign:
        """Return the symbolic parameter node class that ``func`` compiles.

        When no parameter is annotated, the base class implementation is used,
        which reports the function as having no extractable signature.
        """
        sign = cls._last_param_annotation(func)
        if sign is None:
            return super()._retrieve_signature(func)
        return cast(ParameterCompilationSign, sign)
|
|
135
|
+
|
|
136
|
+
|
|
137
|
+
class CompilerInitializerRegistry(
    CompilerRegistry[InitializerCompilationSign, InitializerCompilationFunc]
):
    """Registry of the rules that compile symbolic initializers.

    The signature of a rule is the annotation of its last annotated parameter
    (the return annotation excluded), which must be a subclass of ``Initializer``.
    """

    @staticmethod
    def _last_param_annotation(func: InitializerCompilationFunc) -> Any:
        """Return the annotation of the last annotated parameter of ``func``.

        Returns None when ``func`` has no annotated parameter. A missing return
        annotation is tolerated instead of raising ``KeyError``.
        """
        ann = func.__annotations__
        params = [name for name in ann if name != "return"]
        return ann[params[-1]] if params else None

    @classmethod
    def _validate_rule_function(cls, func: InitializerCompilationFunc) -> bool:
        """Return whether the last annotated parameter of ``func`` is an ``Initializer`` subclass."""
        try:
            return issubclass(cls._last_param_annotation(func), Initializer)
        except TypeError:
            # The annotation is missing or not a class (e.g. a string forward
            # reference): the function is not a valid rule.
            return False

    @classmethod
    def _retrieve_signature(cls, func: InitializerCompilationFunc) -> InitializerCompilationSign:
        """Return the symbolic initializer class that ``func`` compiles.

        When no parameter is annotated, the base class implementation is used,
        which reports the function as having no extractable signature.
        """
        sign = cls._last_param_annotation(func)
        if sign is None:
            return super()._retrieve_signature(func)
        return cast(InitializerCompilationSign, sign)
|
|
153
|
+
|
|
154
|
+
|
|
155
|
+
class AbstractCompiler(ABC):
    """Base class of compilers that turn symbolic circuits into compiled circuits.

    It holds one registry per kind of compilation rule (layers, parameters and
    initializers), the compilation flags, and a map of the circuits compiled so far.
    """

    def __init__(
        self,
        layers_registry: CompilerLayerRegistry,
        parameters_registry: CompilerParameterRegistry,
        initializers_registry: CompilerInitializerRegistry,
        **flags,
    ):
        """Initialize the compiler.

        Args:
            layers_registry: The registry of layer compilation rules.
            parameters_registry: The registry of parameter compilation rules.
            initializers_registry: The registry of initializer compilation rules.
            **flags: Compilation flags, stored as given.
        """
        self._layers_registry = layers_registry
        self._parameters_registry = parameters_registry
        self._initializers_registry = initializers_registry
        self._flags = flags
        self._compiled_circuits = CompiledCircuitsMap()

    # ---- Bookkeeping of compiled circuits (delegates to the circuits map) ----

    def is_compiled(self, sc: Circuit) -> bool:
        """Return whether the symbolic circuit ``sc`` has already been compiled."""
        return self._compiled_circuits.is_compiled(sc)

    def has_symbolic(self, cc: CompiledCircuit) -> bool:
        """Return whether the compiled circuit ``cc`` has a registered symbolic circuit."""
        return self._compiled_circuits.has_symbolic(cc)

    def get_compiled_circuit(self, sc: Circuit) -> CompiledCircuit:
        """Return the compiled circuit registered for the symbolic circuit ``sc``."""
        return self._compiled_circuits.get_compiled_circuit(sc)

    def get_symbolic_circuit(self, cc: CompiledCircuit) -> Circuit:
        """Return the symbolic circuit registered for the compiled circuit ``cc``."""
        return self._compiled_circuits.get_symbolic_circuit(cc)

    def register_compiled_circuit(self, sc: Circuit, cc: CompiledCircuit):
        """Record that the symbolic circuit ``sc`` compiles to ``cc``."""
        self._compiled_circuits.register_compiled_circuit(sc, cc)

    # ---- Compilation rules (delegate to the per-kind registries) ----

    def add_layer_rule(self, func: LayerCompilationFunc):
        """Register a layer compilation rule."""
        self._layers_registry.add_rule(func)

    def add_parameter_rule(self, func: ParameterCompilationFunc):
        """Register a parameter compilation rule."""
        self._parameters_registry.add_rule(func)

    def add_initializer_rule(self, func: InitializerCompilationFunc):
        """Register an initializer compilation rule."""
        self._initializers_registry.add_rule(func)

    def retrieve_layer_rule(self, signature: LayerCompilationSign) -> LayerCompilationFunc:
        """Return the layer compilation rule registered for ``signature``."""
        return self._layers_registry.retrieve_rule(signature)

    def retrieve_parameter_rule(
        self, signature: ParameterCompilationSign
    ) -> ParameterCompilationFunc:
        """Return the parameter compilation rule registered for ``signature``."""
        return self._parameters_registry.retrieve_rule(signature)

    def retrieve_initializer_rule(
        self, signature: InitializerCompilationSign
    ) -> InitializerCompilationFunc:
        """Return the initializer compilation rule registered for ``signature``."""
        return self._initializers_registry.retrieve_rule(signature)

    # ---- Compilation entry points ----

    def compile(self, sc: Circuit) -> CompiledCircuit:
        """Compile a symbolic circuit, reusing a previous result when available.

        Args:
            sc: The symbolic circuit.

        Returns:
            The already registered compiled circuit if ``sc`` was compiled before,
            otherwise the result of ``compile_pipeline``.
        """
        if not self.is_compiled(sc):
            return self.compile_pipeline(sc)
        return self.get_compiled_circuit(sc)

    @abstractmethod
    def compile_pipeline(self, sc: Circuit) -> CompiledCircuit:
        """Compile a symbolic circuit that has not been compiled yet (backend-specific)."""
        ...
|
|
@@ -0,0 +1,53 @@
|
|
|
1
|
+
from abc import ABC, abstractmethod
|
|
2
|
+
from typing import Generic, TypeVar
|
|
3
|
+
|
|
4
|
+
RegistrySign = TypeVar("RegistrySign")
|
|
5
|
+
RegistryFunc = TypeVar("RegistryFunc")
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class InvalidRuleSign(Exception):
    """Raised when a rule signature cannot be extracted from a function's annotations."""

    def __init__(self, annotations: dict[str, type]):
        """Create the exception.

        Args:
            annotations: The annotations of the offending function, embedded in
                the error message.
        """
        # The closing quote after the annotations was missing in the message.
        super().__init__(
            f"Cannot extract rule signature from function with annotations '{annotations}'"
        )
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class InvalidRuleFunction(Exception):
    """Raised when a function is rejected as a compilation rule."""

    def __init__(self, annotations: dict[str, type]):
        """Create the exception.

        Args:
            annotations: The annotations of the rejected function, embedded in
                the error message.
        """
        message = f"Invalid Compilation rule function with annotations '{annotations}'"
        super().__init__(message)
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class CompilationRuleNotFound(Exception):
    """Raised when no compilation rule is registered for a signature."""

    def __init__(self, signature: RegistrySign):
        """Create the exception.

        Args:
            signature: The signature that has no registered rule, embedded in
                the error message.
        """
        message = f"Compilation rule for signature '{signature}' not found"
        super().__init__(message)
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class CompilerRegistry(Generic[RegistrySign, RegistryFunc], ABC):
    """Generic registry mapping rule signatures to compilation rule functions.

    Subclasses decide which functions are valid rules and how a signature is
    extracted from a function.
    """

    def __init__(self, rules: dict[RegistrySign, RegistryFunc] | None = None):
        """Initialize the registry.

        Args:
            rules: The initial rules. If None, the registry starts empty. The
                dictionary is stored as-is (not copied).
        """
        self._rules = rules if rules is not None else {}

    @classmethod
    @abstractmethod
    def _validate_rule_function(cls, func: RegistryFunc) -> bool:
        """Return whether ``func`` is acceptable as a rule for this registry."""
        ...

    @classmethod
    def _retrieve_signature(cls, func: RegistryFunc) -> RegistrySign:
        """Extract the signature of ``func``.

        The base implementation cannot extract anything and always raises.

        Raises:
            InvalidRuleSign: Always, unless overridden by a subclass.
        """
        raise InvalidRuleSign(func.__annotations__)

    @property
    def signatures(self) -> list[RegistrySign]:
        """The signatures that currently have a registered rule."""
        return [*self._rules]

    def add_rule(self, func: RegistryFunc, *, signature: RegistrySign | None = None) -> None:
        """Register a rule, replacing any rule already stored under the same signature.

        Args:
            func: The rule function.
            signature: The signature to register the rule under. If None, it is
                extracted from ``func``.

        Raises:
            InvalidRuleFunction: If ``func`` fails validation.
        """
        if not self._validate_rule_function(func):
            raise InvalidRuleFunction(func.__annotations__)
        key = self._retrieve_signature(func) if signature is None else signature
        self._rules[key] = func

    def retrieve_rule(self, signature: RegistrySign) -> RegistryFunc:
        """Return the rule registered under ``signature``.

        Raises:
            CompilationRuleNotFound: If no rule is stored under ``signature``
                (or the stored value is None).
        """
        rule = self._rules.get(signature)
        if rule is None:
            raise CompilationRuleNotFound(signature)
        return rule
|
|
File without changes
|
|
@@ -0,0 +1,217 @@
|
|
|
1
|
+
from collections.abc import Callable, Iterator, Sequence
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
from torch import Tensor
|
|
5
|
+
|
|
6
|
+
from cirkit.backend.torch.graph.folding import (
|
|
7
|
+
build_address_book_stacked_entry,
|
|
8
|
+
build_unfold_index_info,
|
|
9
|
+
)
|
|
10
|
+
from cirkit.backend.torch.graph.modules import (
|
|
11
|
+
AddressBook,
|
|
12
|
+
AddressBookEntry,
|
|
13
|
+
FoldIndexInfo,
|
|
14
|
+
TorchDiAcyclicGraph,
|
|
15
|
+
)
|
|
16
|
+
from cirkit.backend.torch.layers import TorchInputLayer, TorchLayer
|
|
17
|
+
from cirkit.symbolic.circuit import StructuralProperties
|
|
18
|
+
from cirkit.utils.scope import Scope
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class LayerAddressBook(AddressBook):
    """Address book that gathers the inputs of each layer from previous outputs."""

    def __init__(self, entries: list[AddressBookEntry]):
        """Initialize the address book from its entries."""
        super().__init__(entries)

    def lookup(
        self, module_outputs: list[Tensor], *, in_graph: Tensor | None = None
    ) -> Iterator[tuple[TorchLayer | None, tuple[Tensor, ...]]]:
        """Yield, for each entry, the module and the tuple of its input tensors.

        Args:
            module_outputs: The outputs of the modules evaluated so far, indexed by
                module id.
            in_graph: The optional input of the whole graph, used by input layers.

        Yields:
            Pairs of the entry's module and a tuple with its inputs; the tuple is
            empty for an input layer when ``in_graph`` is None.
        """
        for entry in self._entries:
            if not entry.in_module_ids:
                # No inputs coming from other modules: we are gathering the inputs
                # of an input layer.
                assert isinstance(entry.module, TorchInputLayer)
                if in_graph is None:
                    yield entry.module, ()
                else:
                    # in_graph: An input batch (assignments to variables) of shape (B, C, D)
                    # scope_idx: The scope of the layers in each fold, shape (F, D'), D' < D
                    # x: (B, C, D) -> (B, C, F, D') -> (F, C, B, D')
                    x = in_graph[..., entry.module.scope_idx].permute(2, 1, 0, 3)
                    yield entry.module, (x,)
                continue

            # Some inputs come from other modules: exactly one fold index and one
            # list of input module ids are expected.
            (in_fold_idx,) = entry.in_fold_idx
            (in_module_ids,) = entry.in_module_ids
            if len(in_module_ids) == 1:
                stacked = module_outputs[in_module_ids[0]]
            else:
                stacked = torch.cat([module_outputs[mid] for mid in in_module_ids], dim=0)
            yield entry.module, (stacked[in_fold_idx],)

    @classmethod
    def from_index_info(
        cls,
        fold_idx_info: FoldIndexInfo,
        *,
        incomings_fn: Callable[[TorchLayer], Sequence[TorchLayer]],
    ) -> "LayerAddressBook":
        """Build the address book from the fold index information.

        Args:
            fold_idx_info: The fold index information (ordering of the modules and
                their input/output fold indices).
            incomings_fn: A function returning the input layers of a layer.

        Returns:
            The address book, with one entry per module in the ordering followed by
            a final entry used to compute the output tensor.
        """
        entries: list[AddressBookEntry] = []
        # Maps each module id to its number of folds
        num_folds: dict[int, int] = {}

        # Follow the topological ordering of the modules
        for mid, module in enumerate(fold_idx_info.ordering):
            in_modules_fold_idx = fold_idx_info.in_fold_idx[mid]
            if incomings_fn(module):
                # A folded module having the output of other modules as input
                entry = build_address_book_stacked_entry(
                    module, in_modules_fold_idx, num_folds=num_folds
                )
            else:
                # A folded module having the input of the network as input,
                # i.e., an input layer
                entry = AddressBookEntry(module, [], [])
            num_folds[mid] = module.num_folds
            entries.append(entry)

        # The last entry carries the information to compute the output tensor
        entries.append(
            build_address_book_stacked_entry(
                None, [fold_idx_info.out_fold_idx], num_folds=num_folds, output=True
            )
        )
        return LayerAddressBook(entries)
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
class AbstractTorchCircuit(TorchDiAcyclicGraph[TorchLayer]):
    """Base class of torch circuits, i.e., directed acyclic graphs of torch layers."""

    def __init__(
        self,
        scope: Scope,
        num_channels: int,
        layers: Sequence[TorchLayer],
        in_layers: dict[TorchLayer, Sequence[TorchLayer]],
        outputs: Sequence[TorchLayer],
        *,
        properties: StructuralProperties,
        fold_idx_info: FoldIndexInfo | None = None,
    ) -> None:
        """Initialize the circuit.

        Args:
            scope: The scope of the circuit.
            num_channels: The number of channels.
            layers: The layers of the circuit.
            in_layers: A dictionary mapping each layer to its input layers.
            outputs: The output layers.
            properties: The structural properties of the circuit.
            fold_idx_info: The optional fold index information, forwarded to the
                base graph class.
        """
        super().__init__(layers, in_layers, outputs, fold_idx_info=fold_idx_info)
        self._scope = scope
        self._num_channels = num_channels
        self._properties = properties

    @property
    def scope(self) -> Scope:
        """The scope of the circuit."""
        return self._scope

    @property
    def num_variables(self) -> int:
        """The number of variables, i.e., the length of the scope."""
        return len(self.scope)

    @property
    def num_channels(self) -> int:
        """The number of channels."""
        return self._num_channels

    @property
    def properties(self) -> StructuralProperties:
        """The structural properties of the circuit."""
        return self._properties

    @property
    def layers(self) -> Sequence[TorchLayer]:
        """The layers of the circuit (the nodes of the graph)."""
        return self.nodes

    def layer_inputs(self, l: TorchLayer) -> Sequence[TorchLayer]:
        """Return the input layers of the layer ``l``."""
        return self.node_inputs(l)

    def layer_outputs(self, l: TorchLayer) -> Sequence[TorchLayer]:
        """Return the output layers of the layer ``l``."""
        return self.node_outputs(l)

    @property
    def layers_inputs(self) -> dict[TorchLayer, Sequence[TorchLayer]]:
        """The dictionary mapping each layer to its input layers."""
        return self.nodes_inputs

    @property
    def layers_outputs(self) -> dict[TorchLayer, Sequence[TorchLayer]]:
        """The dictionary mapping each layer to its output layers."""
        return self.nodes_outputs

    def reset_parameters(self) -> None:
        """Reset every parameter of every layer, if any."""
        for layer in self.layers:
            for param in layer.params.values():
                param.reset_parameters()

    def _set_device(self, device: str | torch.device | int) -> None:
        """Propagate the device to the parameters of each layer, then to the graph."""
        for layer in self.layers:
            for param in layer.params.values():
                param._set_device(device)
        super()._set_device(device)

    def _build_unfold_index_info(self) -> FoldIndexInfo:
        """Build the fold index information from the topological ordering of the layers."""
        return build_unfold_index_info(
            self.topological_ordering(), outputs=self.outputs, incomings_fn=self.node_inputs
        )

    def _build_address_book(self, fold_idx_info: FoldIndexInfo) -> LayerAddressBook:
        """Build the layer address book from the given fold index information."""
        return LayerAddressBook.from_index_info(fold_idx_info, incomings_fn=self.layer_inputs)

    def _evaluate_layers(self, x: Tensor) -> Tensor:
        """Evaluate the layers on the input ``x`` and swap the first two dimensions."""
        outputs = self.evaluate(x)  # (O, B, K)
        return outputs.transpose(0, 1)  # (B, O, K)
|
|
174
|
+
|
|
175
|
+
|
|
176
|
+
class TorchCircuit(AbstractTorchCircuit):
    """The tensorized circuit with concrete computational graph in PyTorch."""

    def __call__(self, x: Tensor) -> Tensor:
        """Invoke the forward function.

        Args:
            x (Tensor): The input of the circuit, shape (B, C, D).

        Returns:
            Tensor: The output of the circuit, shape (B, num_out, num_cls).
        """  # TODO: single letter name?
        # IGNORE: Idiom for nn.Module.__call__.
        return super().__call__(x)  # type: ignore[no-any-return,misc]

    def forward(self, x: Tensor) -> Tensor:
        """Evaluate the layers of the circuit on the input ``x``."""
        return self._evaluate_layers(x)
|
|
196
|
+
|
|
197
|
+
|
|
198
|
+
class TorchConstantCircuit(AbstractTorchCircuit):
    """The tensorized circuit with concrete computational graph in PyTorch.

    Unlike ``TorchCircuit``, it is called without any input: the layers are evaluated
    on a dummy input and the resulting batch dimension is squeezed out.
    """

    def __call__(self) -> Tensor:
        """Invoke the forward function.

        Returns:
            Tensor: The output of the circuit; per the shape comments in ``forward``,
                shape (O, K), i.e., without a batch dimension.
        """  # TODO: single letter name?
        # IGNORE: Idiom for nn.Module.__call__.
        return super().__call__()  # type: ignore[no-any-return,misc]

    def forward(self) -> Tensor:
        """Evaluate the layers on a dummy input and drop the batch dimension."""
        # Evaluate the layers using some dummy input, with a batch of size one
        dummy = torch.empty(size=(1, self.num_channels, self.num_variables), device=self.device)
        y = self._evaluate_layers(dummy)  # (B, O, K)
        return y.squeeze(dim=0)  # (O, K)
|