libcirkit 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (65)
  1. cirkit/__init__.py +0 -0
  2. cirkit/backend/__init__.py +0 -0
  3. cirkit/backend/base.py +199 -0
  4. cirkit/backend/compiler.py +213 -0
  5. cirkit/backend/registry.py +53 -0
  6. cirkit/backend/torch/__init__.py +0 -0
  7. cirkit/backend/torch/circuits.py +217 -0
  8. cirkit/backend/torch/compiler.py +592 -0
  9. cirkit/backend/torch/graph/__init__.py +0 -0
  10. cirkit/backend/torch/graph/folding.py +230 -0
  11. cirkit/backend/torch/graph/modules.py +276 -0
  12. cirkit/backend/torch/graph/optimize.py +258 -0
  13. cirkit/backend/torch/initializers.py +49 -0
  14. cirkit/backend/torch/layers/__init__.py +16 -0
  15. cirkit/backend/torch/layers/base.py +119 -0
  16. cirkit/backend/torch/layers/inner.py +335 -0
  17. cirkit/backend/torch/layers/input.py +746 -0
  18. cirkit/backend/torch/layers/optimized.py +241 -0
  19. cirkit/backend/torch/optimization/__init__.py +0 -0
  20. cirkit/backend/torch/optimization/layers.py +166 -0
  21. cirkit/backend/torch/optimization/parameters.py +67 -0
  22. cirkit/backend/torch/optimization/registry.py +81 -0
  23. cirkit/backend/torch/parameters/__init__.py +0 -0
  24. cirkit/backend/torch/parameters/nodes.py +828 -0
  25. cirkit/backend/torch/parameters/parameter.py +117 -0
  26. cirkit/backend/torch/parameters/pic.py +418 -0
  27. cirkit/backend/torch/queries.py +178 -0
  28. cirkit/backend/torch/rules/__init__.py +3 -0
  29. cirkit/backend/torch/rules/initializers.py +53 -0
  30. cirkit/backend/torch/rules/layers.py +184 -0
  31. cirkit/backend/torch/rules/parameters.py +280 -0
  32. cirkit/backend/torch/semiring.py +492 -0
  33. cirkit/backend/torch/utils.py +102 -0
  34. cirkit/pipeline.py +355 -0
  35. cirkit/symbolic/__init__.py +0 -0
  36. cirkit/symbolic/circuit.py +938 -0
  37. cirkit/symbolic/dtypes.py +45 -0
  38. cirkit/symbolic/functional.py +674 -0
  39. cirkit/symbolic/initializers.py +121 -0
  40. cirkit/symbolic/layers.py +788 -0
  41. cirkit/symbolic/operators.py +384 -0
  42. cirkit/symbolic/parameters.py +921 -0
  43. cirkit/symbolic/registry.py +119 -0
  44. cirkit/templates/__init__.py +0 -0
  45. cirkit/templates/circuit_templates/__init__.py +2 -0
  46. cirkit/templates/circuit_templates/data.py +107 -0
  47. cirkit/templates/circuit_templates/utils.py +287 -0
  48. cirkit/templates/region_graph/__init__.py +11 -0
  49. cirkit/templates/region_graph/algorithms/__init__.py +9 -0
  50. cirkit/templates/region_graph/algorithms/chow_liu.py +141 -0
  51. cirkit/templates/region_graph/algorithms/factorized.py +43 -0
  52. cirkit/templates/region_graph/algorithms/linear.py +77 -0
  53. cirkit/templates/region_graph/algorithms/poon_domingos.py +203 -0
  54. cirkit/templates/region_graph/algorithms/quad.py +179 -0
  55. cirkit/templates/region_graph/algorithms/random.py +110 -0
  56. cirkit/templates/region_graph/algorithms/utils.py +124 -0
  57. cirkit/templates/region_graph/graph.py +335 -0
  58. cirkit/utils/__init__.py +0 -0
  59. cirkit/utils/algorithms.py +218 -0
  60. cirkit/utils/scope.py +13 -0
  61. libcirkit-0.1.0.dist-info/LICENSE +674 -0
  62. libcirkit-0.1.0.dist-info/METADATA +200 -0
  63. libcirkit-0.1.0.dist-info/RECORD +65 -0
  64. libcirkit-0.1.0.dist-info/WHEEL +5 -0
  65. libcirkit-0.1.0.dist-info/top_level.txt +1 -0
cirkit/__init__.py ADDED
File without changes
cirkit/backend/__init__.py ADDED
File without changes
cirkit/backend/base.py ADDED
@@ -0,0 +1,199 @@
1
+ import os
2
+ from abc import ABC, abstractmethod
3
+ from collections.abc import Callable
4
+ from typing import IO, Any, Protocol, TypeVar
5
+
6
+ from cirkit.symbolic.circuit import Circuit
7
+ from cirkit.symbolic.initializers import Initializer
8
+ from cirkit.symbolic.layers import Layer
9
+ from cirkit.symbolic.parameters import ParameterNode
10
+ from cirkit.utils.algorithms import BiMap
11
+
12
# Placeholder type for whatever object a backend produces from a symbolic circuit.
CompiledCircuit = TypeVar("CompiledCircuit")

# Rules are registered and retrieved by the class of the symbolic object they compile.
LayerCompilationSign = type[Layer]
ParameterCompilationSign = type[ParameterNode]
InitializerCompilationSign = type[Initializer]


class LayerCompilationFunc(Protocol):
    """Callable shape of a rule that compiles a symbolic layer."""

    def __call__(self, compiler: "AbstractCompiler", sl: Layer, **kwargs) -> Any:
        """Compile the symbolic layer ``sl`` with ``compiler``.

        The returned object depends on the compilation backend.
        """


class ParameterCompilationFunc(Protocol):
    """Callable shape of a rule that compiles a symbolic parameter node."""

    def __call__(self, compiler: "AbstractCompiler", p: ParameterNode, **kwargs) -> Any:
        """Compile the symbolic parameter node ``p`` with ``compiler``.

        The returned object depends on the compilation backend.
        """


class InitializerCompilationFunc(Protocol):
    """Callable shape of a rule that compiles a symbolic initializer."""

    def __call__(self, compiler: "AbstractCompiler", init: Initializer, **kwargs) -> Any:
        """Compile the symbolic initializer ``init`` with ``compiler``.

        The returned object depends on the compilation backend.
        """
31
+
32
+
33
class CompilationRuleNotFound(Exception):
    """Signals that a registry holds no compilation rule for a requested signature."""

    def __init__(self, msg: str):
        """Create the exception.

        Args:
            msg: Description of the missing rule; it becomes the exception message.
        """
        super().__init__(msg)


# Names of the supported compilation backends.
SUPPORTED_BACKENDS = ["torch"]
39
+
40
+
41
class CompiledCircuitsMap:
    """Two-way lookup between symbolic circuits and the circuits compiled from them.

    Symbolic circuits are queried through the left side of the underlying ``BiMap`` and
    compiled circuits through its right side.
    """

    def __init__(self):
        self._bimap = BiMap[Circuit, CompiledCircuit]()

    def register_compiled_circuit(self, sc: Circuit, cc: CompiledCircuit):
        """Record ``cc`` as the compiled counterpart of the symbolic circuit ``sc``."""
        self._bimap.add(sc, cc)

    def is_compiled(self, sc: Circuit) -> bool:
        """Whether a compiled circuit has been registered for ``sc``."""
        return self._bimap.has_left(sc)

    def has_symbolic(self, cc: CompiledCircuit) -> bool:
        """Whether ``cc`` has been registered as the counterpart of some symbolic circuit."""
        return self._bimap.has_right(cc)

    def get_compiled_circuit(self, sc: Circuit) -> CompiledCircuit:
        """Return the compiled circuit registered for ``sc``."""
        return self._bimap.get_left(sc)

    def get_symbolic_circuit(self, cc: CompiledCircuit) -> Circuit:
        """Return the symbolic circuit that ``cc`` was registered for."""
        return self._bimap.get_right(cc)
59
+
60
+
61
class CompilerRegistry:
    """Registry of compilation rules for symbolic layers, parameters and initializers.

    Each kind of rule is kept in its own dictionary, keyed by the symbolic class the rule
    compiles (its *signature*).
    """

    def __init__(
        self,
        layer_rules: dict[LayerCompilationSign, LayerCompilationFunc] | None = None,
        parameter_rules: dict[ParameterCompilationSign, ParameterCompilationFunc] | None = None,
        initializer_rules: (
            dict[InitializerCompilationSign, InitializerCompilationFunc] | None
        ) = None,
    ):
        """Initialize the registry, optionally pre-populated with rules.

        The given dictionaries are copied. Rules added later through this registry therefore
        do not leak into (possibly shared) dictionaries owned by the caller.

        Args:
            layer_rules: Initial layer compilation rules.
            parameter_rules: Initial parameter compilation rules.
            initializer_rules: Initial initializer compilation rules.
        """
        self._layer_rules = {} if layer_rules is None else dict(layer_rules)
        self._parameter_rules = {} if parameter_rules is None else dict(parameter_rules)
        self._initializer_rules = {} if initializer_rules is None else dict(initializer_rules)

    @staticmethod
    def _is_subclass_annotation(annotation: Any, base: type) -> bool:
        """Whether ``annotation`` is a class deriving from ``base``.

        Annotations may be strings (forward references) or typing constructs, on which
        ``issubclass`` raises ``TypeError``; those are reported as non-matching instead.
        """
        try:
            return isinstance(annotation, type) and issubclass(annotation, base)
        except TypeError:
            return False

    @staticmethod
    def _validate_rule_sign(func: Callable, sym_cls: type) -> type | None:
        """Extract the symbolic class a compilation rule is written for.

        A valid rule has exactly three annotations: ``compiler`` (a subclass of
        ``AbstractCompiler``), one symbolic argument (a subclass of ``sym_cls``), and
        ``return``.

        Args:
            func: The candidate compilation rule.
            sym_cls: The symbolic base class that the rule's argument must derive from.

        Returns:
            The annotated symbolic class, or None if ``func`` does not have a valid signature.
        """
        args = getattr(func, "__annotations__", None) or {}
        if "return" not in args or "compiler" not in args or len(args) != 3:
            return None
        if not CompilerRegistry._is_subclass_annotation(args["compiler"], AbstractCompiler):
            return None
        # With exactly three annotations including "compiler" and "return", one remains.
        (arg_name,) = (a for a in args if a not in ("return", "compiler"))
        found_sym_cls = args[arg_name]
        if not CompilerRegistry._is_subclass_annotation(found_sym_cls, sym_cls):
            return None
        return found_sym_cls

    def add_layer_rule(self, func: LayerCompilationFunc):
        """Register ``func`` as the compilation rule of the layer class it is annotated with.

        Raises:
            ValueError: If ``func`` does not have a valid layer rule signature.
        """
        layer_cls: type[Layer] | None = self._validate_rule_sign(func, Layer)
        if layer_cls is None:
            raise ValueError("The function is not a symbolic layer compilation rule")
        self._layer_rules[layer_cls] = func

    def add_parameter_rule(self, func: ParameterCompilationFunc):
        """Register ``func`` as the rule of the parameter node class it is annotated with.

        Raises:
            ValueError: If ``func`` does not have a valid parameter rule signature.
        """
        param_cls: type[ParameterNode] | None = self._validate_rule_sign(func, ParameterNode)
        if param_cls is None:
            raise ValueError("The function is not a symbolic parameter compilation rule")
        self._parameter_rules[param_cls] = func

    def add_initializer_rule(self, func: InitializerCompilationFunc):
        """Register ``func`` as the rule of the initializer class it is annotated with.

        Raises:
            ValueError: If ``func`` does not have a valid initializer rule signature.
        """
        init_cls: type[Initializer] | None = self._validate_rule_sign(func, Initializer)
        if init_cls is None:
            raise ValueError("The function is not a symbolic initializer compilation rule")
        self._initializer_rules[init_cls] = func

    def retrieve_layer_rule(self, signature: LayerCompilationSign) -> LayerCompilationFunc:
        """Return the layer rule registered for ``signature``.

        Raises:
            CompilationRuleNotFound: If no rule is registered for ``signature``.
        """
        try:
            return self._layer_rules[signature]
        except KeyError:
            raise CompilationRuleNotFound(
                f"Layer compilation rule for signature '{signature}' not found"
            ) from None

    def retrieve_parameter_rule(
        self, signature: ParameterCompilationSign
    ) -> ParameterCompilationFunc:
        """Return the parameter rule registered for ``signature``.

        Raises:
            CompilationRuleNotFound: If no rule is registered for ``signature``.
        """
        try:
            return self._parameter_rules[signature]
        except KeyError:
            raise CompilationRuleNotFound(
                f"Parameter compilation rule for signature '{signature}' not found"
            ) from None

    def retrieve_initializer_rule(
        self, signature: InitializerCompilationSign
    ) -> InitializerCompilationFunc:
        """Return the initializer rule registered for ``signature``.

        Raises:
            CompilationRuleNotFound: If no rule is registered for ``signature``.
        """
        try:
            return self._initializer_rules[signature]
        except KeyError:
            raise CompilationRuleNotFound(
                f"Initializer compilation rule for signature '{signature}' not found"
            ) from None
128
+
129
+
130
class AbstractCompiler(ABC):
    """Base class for compilers that turn symbolic circuits into backend objects.

    A compiler owns a ``CompilerRegistry`` of compilation rules, a dictionary of keyword
    flags, and a ``CompiledCircuitsMap`` of the circuits registered as compiled.
    """

    def __init__(self, registry: CompilerRegistry, **flags):
        """Initialize the compiler.

        Args:
            registry: The registry that compilation rules are added to and retrieved from.
            **flags: Additional compilation flags, stored as given.
        """
        self._registry = registry
        self._flags = flags
        self._compiled_circuits = CompiledCircuitsMap()

    # --- Bookkeeping of compiled circuits (delegated to the CompiledCircuitsMap) ---

    def is_compiled(self, sc: Circuit) -> bool:
        """Whether a compiled circuit has been registered for the symbolic circuit ``sc``."""
        return self._compiled_circuits.is_compiled(sc)

    def has_symbolic(self, cc: CompiledCircuit) -> bool:
        """Whether the compiled circuit ``cc`` is registered against a symbolic circuit."""
        return self._compiled_circuits.has_symbolic(cc)

    def get_compiled_circuit(self, sc: Circuit) -> CompiledCircuit:
        """Return the compiled circuit registered for ``sc``."""
        return self._compiled_circuits.get_compiled_circuit(sc)

    def get_symbolic_circuit(self, cc: CompiledCircuit) -> Circuit:
        """Return the symbolic circuit that ``cc`` was registered for."""
        return self._compiled_circuits.get_symbolic_circuit(cc)

    def register_compiled_circuit(self, sc: Circuit, cc: CompiledCircuit):
        """Record ``cc`` as the compiled counterpart of ``sc``."""
        self._compiled_circuits.register_compiled_circuit(sc, cc)

    # --- Compilation rules (delegated to the registry) ---

    def add_layer_rule(self, func: LayerCompilationFunc):
        """Register a layer compilation rule."""
        self._registry.add_layer_rule(func)

    def add_parameter_rule(self, func: ParameterCompilationFunc):
        """Register a parameter compilation rule."""
        self._registry.add_parameter_rule(func)

    def add_initializer_rule(self, func: InitializerCompilationFunc):
        """Register an initializer compilation rule."""
        self._registry.add_initializer_rule(func)

    def retrieve_layer_rule(self, signature: LayerCompilationSign) -> LayerCompilationFunc:
        """Return the layer rule registered for ``signature``."""
        return self._registry.retrieve_layer_rule(signature)

    def retrieve_parameter_rule(
        self, signature: ParameterCompilationSign
    ) -> ParameterCompilationFunc:
        """Return the parameter rule registered for ``signature``."""
        return self._registry.retrieve_parameter_rule(signature)

    def retrieve_initializer_rule(
        self, signature: InitializerCompilationSign
    ) -> InitializerCompilationFunc:
        """Return the initializer rule registered for ``signature``."""
        return self._registry.retrieve_initializer_rule(signature)

    # --- Compilation entry points ---

    def compile(self, sc: Circuit) -> CompiledCircuit:
        """Compile ``sc``, reusing the compiled circuit already registered for it if any.

        Otherwise, the work is delegated to ``compile_pipeline``.
        """
        if not self.is_compiled(sc):
            return self.compile_pipeline(sc)
        return self.get_compiled_circuit(sc)

    @abstractmethod
    def compile_layer(self, sl: Layer) -> Any:
        """Compile a single symbolic layer; implemented by each backend."""

    @abstractmethod
    def compile_pipeline(self, sc: Circuit) -> CompiledCircuit:
        """Compile a symbolic circuit that is not yet registered; implemented by each backend."""

    @abstractmethod
    def save(
        self,
        sym_filepath: IO | os.PathLike | str,
        compiled_filepath: IO | os.PathLike | str,
    ):
        """Save to ``sym_filepath`` and ``compiled_filepath``; implemented by each backend."""

    @staticmethod
    @abstractmethod
    def load(
        sym_filepath: IO | os.PathLike | str, tens_filepath: IO | os.PathLike | str
    ) -> "AbstractCompiler":
        """Load and return a compiler from ``sym_filepath`` and ``tens_filepath``."""
cirkit/backend/compiler.py ADDED
@@ -0,0 +1,213 @@
1
+ from abc import ABC, abstractmethod
2
+ from typing import Any, Protocol, TypeVar, cast
3
+
4
+ from cirkit.backend.registry import CompilerRegistry
5
+ from cirkit.symbolic.circuit import Circuit
6
+ from cirkit.symbolic.initializers import Initializer
7
+ from cirkit.symbolic.layers import Layer
8
+ from cirkit.symbolic.parameters import ParameterNode
9
+ from cirkit.utils.algorithms import BiMap
10
+
11
# Type variable for the backend-specific object that a symbolic circuit compiles to.
CompiledCircuit = TypeVar("CompiledCircuit")

# Names of the supported compilation backends.
SUPPORTED_BACKENDS = ["torch"]
14
+
15
+
16
class CompiledCircuitsMap:
    """Bidirectional association between symbolic circuits and compiled circuits.

    Lookups by symbolic circuit go through the left side of the underlying ``BiMap``;
    lookups by compiled circuit go through its right side.
    """

    def __init__(self):
        self._bimap = BiMap[Circuit, CompiledCircuit]()

    def register_compiled_circuit(self, sc: Circuit, cc: CompiledCircuit):
        """Associate the symbolic circuit ``sc`` with its compiled circuit ``cc``."""
        self._bimap.add(sc, cc)

    def is_compiled(self, sc: Circuit) -> bool:
        """Check whether ``sc`` has an associated compiled circuit."""
        return self._bimap.has_left(sc)

    def has_symbolic(self, cc: CompiledCircuit) -> bool:
        """Check whether ``cc`` has an associated symbolic circuit."""
        return self._bimap.has_right(cc)

    def get_compiled_circuit(self, sc: Circuit) -> CompiledCircuit:
        """Look up the compiled circuit associated with ``sc``."""
        return self._bimap.get_left(sc)

    def get_symbolic_circuit(self, cc: CompiledCircuit) -> Circuit:
        """Look up the symbolic circuit associated with ``cc``."""
        return self._bimap.get_right(cc)
34
+
35
+
36
# Compilation rules are keyed by the class of the symbolic object they compile.
LayerCompilationSign = type[Layer]
ParameterCompilationSign = type[ParameterNode]
InitializerCompilationSign = type[Initializer]


class LayerCompilationFunc(Protocol):
    """Protocol of the functions that compile symbolic layers."""

    def __call__(self, compiler: "AbstractCompiler", sl: Layer, **kwargs) -> Any:
        """Turn a symbolic layer into its backend representation.

        Args:
            compiler: The compiler performing the compilation.
            sl: The symbolic layer to compile.
            **kwargs: Optional compilation arguments.

        Returns:
            The compiled layer; its type is determined by the compilation backend.
        """


class ParameterCompilationFunc(Protocol):
    """Protocol of the functions that compile symbolic parameter nodes."""

    def __call__(self, compiler: "AbstractCompiler", p: ParameterNode, **kwargs) -> Any:
        """Turn a symbolic parameter node into its backend representation.

        Args:
            compiler: The compiler performing the compilation.
            p: The symbolic parameter node to compile.
            **kwargs: Optional compilation arguments.

        Returns:
            The compiled parameter node; its type is determined by the compilation backend.
        """


class InitializerCompilationFunc(Protocol):
    """Protocol of the functions that compile symbolic initialization methods."""

    def __call__(self, compiler: "AbstractCompiler", init: Initializer, **kwargs) -> Any:
        """Turn a symbolic initializer into its backend representation.

        Args:
            compiler: The compiler performing the compilation.
            init: The symbolic initializer to compile.
            **kwargs: Optional compilation arguments.

        Returns:
            The compiled initializer; its type is determined by the compilation backend.
        """
89
+
90
+
91
class CompilationRuleNotFound(Exception):
    """Error raised when no compilation rule matches a requested signature."""

    def __init__(self, msg: str):
        """Create the error.

        Args:
            msg: Text of the error message.
        """
        super().__init__(msg)
101
+
102
+
103
def _rule_symbolic_annotation(func: object) -> Any:
    """Return the annotation of the symbolic argument of a compilation rule.

    The symbolic argument is the last annotated parameter of ``func``. Rules must also
    carry a ``return`` annotation.

    Args:
        func: The candidate compilation rule.

    Returns:
        The annotation of the last annotated parameter. Returns None if ``func`` has no
        ``return`` annotation or no annotated parameters.
    """
    annotations = getattr(func, "__annotations__", None) or {}
    if "return" not in annotations:
        return None
    params = [ann for name, ann in annotations.items() if name != "return"]
    return params[-1] if params else None


def _is_subclass_annotation(annotation: Any, base: type) -> bool:
    """Whether ``annotation`` is a class deriving from ``base``.

    String (forward-reference) annotations and typing constructs make ``issubclass`` raise
    ``TypeError``; they are reported as non-matching instead.
    """
    try:
        return isinstance(annotation, type) and issubclass(annotation, base)
    except TypeError:
        return False


class CompilerLayerRegistry(CompilerRegistry[LayerCompilationSign, LayerCompilationFunc]):
    """Registry of layer compilation rules, keyed by symbolic layer class."""

    @classmethod
    def _validate_rule_function(cls, func: LayerCompilationFunc) -> bool:
        """Whether the last annotated parameter of ``func`` is a ``Layer`` subclass."""
        return _is_subclass_annotation(_rule_symbolic_annotation(func), Layer)

    @classmethod
    def _retrieve_signature(cls, func: LayerCompilationFunc) -> LayerCompilationSign:
        """Return the symbolic layer class that ``func`` is annotated to compile."""
        annotation = _rule_symbolic_annotation(func)
        if annotation is None:
            # Let the base implementation report the failure (it raises InvalidRuleSign).
            return super()._retrieve_signature(func)
        return cast(LayerCompilationSign, annotation)


class CompilerParameterRegistry(
    CompilerRegistry[ParameterCompilationSign, ParameterCompilationFunc]
):
    """Registry of parameter node compilation rules, keyed by symbolic parameter class."""

    @classmethod
    def _validate_rule_function(cls, func: ParameterCompilationFunc) -> bool:
        """Whether the last annotated parameter of ``func`` is a ``ParameterNode`` subclass."""
        return _is_subclass_annotation(_rule_symbolic_annotation(func), ParameterNode)

    @classmethod
    def _retrieve_signature(cls, func: ParameterCompilationFunc) -> ParameterCompilationSign:
        """Return the symbolic parameter node class that ``func`` is annotated to compile."""
        annotation = _rule_symbolic_annotation(func)
        if annotation is None:
            # Let the base implementation report the failure (it raises InvalidRuleSign).
            return super()._retrieve_signature(func)
        return cast(ParameterCompilationSign, annotation)


class CompilerInitializerRegistry(
    CompilerRegistry[InitializerCompilationSign, InitializerCompilationFunc]
):
    """Registry of initializer compilation rules, keyed by symbolic initializer class."""

    @classmethod
    def _validate_rule_function(cls, func: InitializerCompilationFunc) -> bool:
        """Whether the last annotated parameter of ``func`` is an ``Initializer`` subclass."""
        return _is_subclass_annotation(_rule_symbolic_annotation(func), Initializer)

    @classmethod
    def _retrieve_signature(
        cls, func: InitializerCompilationFunc
    ) -> InitializerCompilationSign:
        """Return the symbolic initializer class that ``func`` is annotated to compile."""
        annotation = _rule_symbolic_annotation(func)
        if annotation is None:
            # Let the base implementation report the failure (it raises InvalidRuleSign).
            return super()._retrieve_signature(func)
        return cast(InitializerCompilationSign, annotation)
153
+
154
+
155
class AbstractCompiler(ABC):
    """Base class of compilers from symbolic circuits to backend-specific circuits.

    A compiler holds three rule registries (layers, parameters, initializers), a dictionary
    of keyword flags, and a map of the circuits registered as compiled.
    """

    def __init__(
        self,
        layers_registry: CompilerLayerRegistry,
        parameters_registry: CompilerParameterRegistry,
        initializers_registry: CompilerInitializerRegistry,
        **flags,
    ):
        """Initialize the compiler.

        Args:
            layers_registry: Registry of layer compilation rules.
            parameters_registry: Registry of parameter node compilation rules.
            initializers_registry: Registry of initializer compilation rules.
            **flags: Additional compilation flags, stored as given.
        """
        self._layers_registry = layers_registry
        self._parameters_registry = parameters_registry
        self._initializers_registry = initializers_registry
        self._flags = flags
        self._compiled_circuits = CompiledCircuitsMap()

    # --- Bookkeeping of compiled circuits ---

    def is_compiled(self, sc: Circuit) -> bool:
        """Whether a compiled circuit has been registered for ``sc``."""
        return self._compiled_circuits.is_compiled(sc)

    def has_symbolic(self, cc: CompiledCircuit) -> bool:
        """Whether ``cc`` has been registered against a symbolic circuit."""
        return self._compiled_circuits.has_symbolic(cc)

    def get_compiled_circuit(self, sc: Circuit) -> CompiledCircuit:
        """Return the compiled circuit registered for ``sc``."""
        return self._compiled_circuits.get_compiled_circuit(sc)

    def get_symbolic_circuit(self, cc: CompiledCircuit) -> Circuit:
        """Return the symbolic circuit that ``cc`` was registered for."""
        return self._compiled_circuits.get_symbolic_circuit(cc)

    def register_compiled_circuit(self, sc: Circuit, cc: CompiledCircuit):
        """Record ``cc`` as the compiled counterpart of ``sc``."""
        self._compiled_circuits.register_compiled_circuit(sc, cc)

    # --- Rule registration ---

    def add_layer_rule(self, func: LayerCompilationFunc):
        """Add a rule to the layers registry."""
        self._layers_registry.add_rule(func)

    def add_parameter_rule(self, func: ParameterCompilationFunc):
        """Add a rule to the parameters registry."""
        self._parameters_registry.add_rule(func)

    def add_initializer_rule(self, func: InitializerCompilationFunc):
        """Add a rule to the initializers registry."""
        self._initializers_registry.add_rule(func)

    # --- Rule retrieval ---

    def retrieve_layer_rule(self, signature: LayerCompilationSign) -> LayerCompilationFunc:
        """Return the layer rule registered for ``signature``."""
        return self._layers_registry.retrieve_rule(signature)

    def retrieve_parameter_rule(
        self, signature: ParameterCompilationSign
    ) -> ParameterCompilationFunc:
        """Return the parameter rule registered for ``signature``."""
        return self._parameters_registry.retrieve_rule(signature)

    def retrieve_initializer_rule(
        self, signature: InitializerCompilationSign
    ) -> InitializerCompilationFunc:
        """Return the initializer rule registered for ``signature``."""
        return self._initializers_registry.retrieve_rule(signature)

    # --- Compilation ---

    def compile(self, sc: Circuit) -> CompiledCircuit:
        """Compile ``sc``, reusing the compiled circuit already registered for it if any.

        Otherwise, the work is delegated to ``compile_pipeline``.
        """
        if not self.is_compiled(sc):
            return self.compile_pipeline(sc)
        return self.get_compiled_circuit(sc)

    @abstractmethod
    def compile_pipeline(self, sc: Circuit) -> CompiledCircuit:
        """Compile a symbolic circuit that is not yet registered; implemented by each backend."""
cirkit/backend/registry.py ADDED
@@ -0,0 +1,53 @@
1
+ from abc import ABC, abstractmethod
2
+ from typing import Generic, TypeVar
3
+
4
+ RegistrySign = TypeVar("RegistrySign")
5
+ RegistryFunc = TypeVar("RegistryFunc")
6
+
7
+
8
+ class InvalidRuleSign(Exception):
9
+ def __init__(self, annotations: dict[str, type]):
10
+ super().__init__(
11
+ f"Cannot extract rule signature from function with annotations '{annotations}"
12
+ )
13
+
14
+
15
+ class InvalidRuleFunction(Exception):
16
+ def __init__(self, annotations: dict[str, type]):
17
+ super().__init__(f"Invalid Compilation rule function with annotations '{annotations}'")
18
+
19
+
20
+ class CompilationRuleNotFound(Exception):
21
+ def __init__(self, signature: RegistrySign):
22
+ super().__init__(f"Compilation rule for signature '{signature}' not found")
23
+
24
+
25
+ class CompilerRegistry(Generic[RegistrySign, RegistryFunc], ABC):
26
+ def __init__(self, rules: dict[RegistrySign, RegistryFunc] | None = None):
27
+ self._rules = {} if rules is None else rules
28
+
29
+ @classmethod
30
+ @abstractmethod
31
+ def _validate_rule_function(cls, func: RegistryFunc) -> bool:
32
+ ...
33
+
34
+ @classmethod
35
+ def _retrieve_signature(cls, func: RegistryFunc) -> RegistrySign:
36
+ raise InvalidRuleSign(func.__annotations__)
37
+
38
+ @property
39
+ def signatures(self) -> list[RegistrySign]:
40
+ return list(self._rules)
41
+
42
+ def add_rule(self, func: RegistryFunc, *, signature: RegistrySign | None = None) -> None:
43
+ if not self._validate_rule_function(func):
44
+ raise InvalidRuleFunction(func.__annotations__)
45
+ if signature is None:
46
+ signature = self._retrieve_signature(func)
47
+ self._rules[signature] = func
48
+
49
+ def retrieve_rule(self, signature: RegistrySign) -> RegistryFunc:
50
+ func = self._rules.get(signature, None)
51
+ if func is not None:
52
+ return func
53
+ raise CompilationRuleNotFound(signature)
cirkit/backend/torch/__init__.py ADDED
File without changes
cirkit/backend/torch/circuits.py ADDED
@@ -0,0 +1,217 @@
1
+ from collections.abc import Callable, Iterator, Sequence
2
+
3
+ import torch
4
+ from torch import Tensor
5
+
6
+ from cirkit.backend.torch.graph.folding import (
7
+ build_address_book_stacked_entry,
8
+ build_unfold_index_info,
9
+ )
10
+ from cirkit.backend.torch.graph.modules import (
11
+ AddressBook,
12
+ AddressBookEntry,
13
+ FoldIndexInfo,
14
+ TorchDiAcyclicGraph,
15
+ )
16
+ from cirkit.backend.torch.layers import TorchInputLayer, TorchLayer
17
+ from cirkit.symbolic.circuit import StructuralProperties
18
+ from cirkit.utils.scope import Scope
19
+
20
+
21
class LayerAddressBook(AddressBook):
    """Address book that routes tensors to the layers of a (possibly folded) circuit.

    It holds one entry per module, in the order given by the fold index information, followed
    by a final entry that gathers the circuit output.
    """

    def __init__(self, entries: list[AddressBookEntry]):
        super().__init__(entries)

    def lookup(
        self, module_outputs: list[Tensor], *, in_graph: Tensor | None = None
    ) -> Iterator[tuple[TorchLayer | None, tuple[Tensor, ...]]]:
        """Yield each entry's module together with the tensors it takes as input.

        Args:
            module_outputs: The outputs of the modules, indexed by module id.
            in_graph: The input batch of the circuit, of shape (B, C, D), or None.

        Yields:
            Pairs ``(module, inputs)``. Input layers receive an empty tuple when ``in_graph``
            is None, and otherwise the part of ``in_graph`` selected by their scope.
        """
        for entry in self._entries:
            # Case: some inputs come from the outputs of other modules
            if entry.in_module_ids:
                (in_fold_idx,) = entry.in_fold_idx
                (in_module_ids,) = entry.in_module_ids
                if len(in_module_ids) == 1:
                    x = module_outputs[in_module_ids[0]]
                else:
                    x = torch.cat([module_outputs[mid] for mid in in_module_ids], dim=0)
                x = x[in_fold_idx]
                yield entry.module, (x,)
                continue

            # Case: no inputs from other modules, i.e., an input layer reading the circuit input
            assert isinstance(entry.module, TorchInputLayer)
            if in_graph is None:
                yield entry.module, ()
            else:
                # in_graph: An input batch (assignments to variables) of shape (B, C, D)
                # scope_idx: The scope of the layers in each fold, a tensor of shape (F, D'), D' < D
                # x: (B, C, D) -> (B, C, F, D') -> (F, C, B, D')
                x = in_graph[..., entry.module.scope_idx].permute(2, 1, 0, 3)
                yield entry.module, (x,)

    @classmethod
    def from_index_info(
        cls,
        fold_idx_info: FoldIndexInfo,
        *,
        incomings_fn: Callable[[TorchLayer], Sequence[TorchLayer]],
    ) -> "LayerAddressBook":
        """Build an address book from fold index information.

        Args:
            fold_idx_info: The fold index information. It provides the module ordering, the
                fold indices of each module's inputs, and the output fold indices.
            incomings_fn: Returns the input layers of a given layer.

        Returns:
            An instance of ``cls`` (so subclasses build instances of themselves) with one
            entry per module of ``fold_idx_info.ordering``, plus a final output entry.
        """
        entries: list[AddressBookEntry] = []

        # Number of folds of each module visited so far, keyed by module id
        num_folds: dict[int, int] = {}

        # Build the bookkeeping data structure by following the given ordering
        for mid, m in enumerate(fold_idx_info.ordering):
            in_modules_fold_idx = fold_idx_info.in_fold_idx[mid]
            if incomings_fn(m):
                # The module takes the outputs of other modules as input
                entry = build_address_book_stacked_entry(
                    m, in_modules_fold_idx, num_folds=num_folds
                )
            else:
                # The module takes the circuit input, i.e., it is an input layer
                entry = AddressBookEntry(m, [], [])
            num_folds[mid] = m.num_folds
            entries.append(entry)

        # Last entry: the information needed to compute the output tensor
        entries.append(
            build_address_book_stacked_entry(
                None, [fold_idx_info.out_fold_idx], num_folds=num_folds, output=True
            )
        )
        return cls(entries)
92
+
93
+
94
class AbstractTorchCircuit(TorchDiAcyclicGraph[TorchLayer]):
    """Common base of the PyTorch circuits: a directed acyclic graph of ``TorchLayer``.

    Besides the graph, it stores the circuit scope, the number of channels and the structural
    properties given at construction.
    """

    def __init__(
        self,
        scope: Scope,
        num_channels: int,
        layers: Sequence[TorchLayer],
        in_layers: dict[TorchLayer, Sequence[TorchLayer]],
        outputs: Sequence[TorchLayer],
        *,
        properties: StructuralProperties,
        fold_idx_info: FoldIndexInfo | None = None,
    ) -> None:
        """Initialize the circuit.

        Args:
            scope: The scope of the circuit.
            num_channels: The number of channels.
            layers: The layers of the circuit.
            in_layers: Maps each layer to its input layers.
            outputs: The output layers.
            properties: The structural properties of the circuit.
            fold_idx_info: Optional fold index information, forwarded to the graph.
        """
        super().__init__(layers, in_layers, outputs, fold_idx_info=fold_idx_info)
        self._scope = scope
        self._num_channels = num_channels
        self._properties = properties

    @property
    def scope(self) -> Scope:
        """The scope of the circuit."""
        return self._scope

    @property
    def num_variables(self) -> int:
        """The number of variables, i.e., the size of the scope."""
        return len(self.scope)

    @property
    def num_channels(self) -> int:
        """The number of channels."""
        return self._num_channels

    @property
    def properties(self) -> StructuralProperties:
        """The structural properties of the circuit."""
        return self._properties

    @property
    def layers(self) -> Sequence[TorchLayer]:
        """The layers of the circuit (the nodes of the graph)."""
        return self.nodes

    def layer_inputs(self, l: TorchLayer) -> Sequence[TorchLayer]:
        """The input layers of ``l``."""
        return self.node_inputs(l)

    def layer_outputs(self, l: TorchLayer) -> Sequence[TorchLayer]:
        """The layers taking ``l`` as input."""
        return self.node_outputs(l)

    @property
    def layers_inputs(self) -> dict[TorchLayer, Sequence[TorchLayer]]:
        """Map from each layer to its input layers."""
        return self.nodes_inputs

    @property
    def layers_outputs(self) -> dict[TorchLayer, Sequence[TorchLayer]]:
        """Map from each layer to the layers taking it as input."""
        return self.nodes_outputs

    def reset_parameters(self) -> None:
        """Reset every parameter of every layer, layer by layer."""
        all_params = (param for layer in self.layers for param in layer.params.values())
        for param in all_params:
            param.reset_parameters()

    def _set_device(self, device: str | torch.device | int) -> None:
        # Move the layer parameters first, then let the graph handle the rest.
        for layer in self.layers:
            for param in layer.params.values():
                param._set_device(device)
        super()._set_device(device)

    def _build_unfold_index_info(self) -> FoldIndexInfo:
        ordering = self.topological_ordering()
        return build_unfold_index_info(
            ordering, outputs=self.outputs, incomings_fn=self.node_inputs
        )

    def _build_address_book(self, fold_idx_info: FoldIndexInfo) -> LayerAddressBook:
        return LayerAddressBook.from_index_info(fold_idx_info, incomings_fn=self.layer_inputs)

    def _evaluate_layers(self, x: Tensor) -> Tensor:
        outputs = self.evaluate(x)  # (O, B, K)
        # Put the batch dimension first
        return outputs.transpose(0, 1)  # (B, O, K)
174
+
175
+
176
class TorchCircuit(AbstractTorchCircuit):
    """A compiled PyTorch circuit that is evaluated on a batch of inputs.

    Calling the circuit evaluates all of its layers on the given input batch.
    """

    def __call__(self, x: Tensor) -> Tensor:
        """Invoke the forward function.

        Args:
            x (Tensor): The input of the circuit, of shape (B, C, D), i.e., batch size,
                number of channels and number of variables.

        Returns:
            Tensor: The output of the circuit, of shape (B, O, K), where O indexes the
                circuit outputs.
        """
        # IGNORE: Idiom for nn.Module.__call__.
        return super().__call__(x)  # type: ignore[no-any-return,misc]

    def forward(self, x: Tensor) -> Tensor:
        """Evaluate the layers on ``x`` and return the outputs with the batch dimension first."""
        return self._evaluate_layers(x)
196
+
197
+
198
class TorchConstantCircuit(AbstractTorchCircuit):
    """A compiled PyTorch circuit that is evaluated without any input batch.

    Calling the circuit evaluates its layers on a dummy input and returns the outputs with
    the batch dimension removed.
    """

    def __call__(self) -> Tensor:
        """Invoke the forward function.

        Returns:
            Tensor: The output of the circuit, of shape (O, K).
        """
        # IGNORE: Idiom for nn.Module.__call__.
        return super().__call__()  # type: ignore[no-any-return,misc]

    def forward(self) -> Tensor:
        """Evaluate the layers on a dummy input of shape (1, C, D) and drop the batch dimension.

        NOTE(review): the dummy input comes from ``torch.empty``, so its values are
        uninitialized. This presumably relies on the layers of a constant circuit not reading
        them; confirm.

        Returns:
            Tensor: The output of the circuit, of shape (O, K).
        """
        x = torch.empty(size=(1, self.num_channels, self.num_variables), device=self.device)
        x = self._evaluate_layers(x)  # (B, O, K) with B = 1
        return x.squeeze(dim=0)  # (O, K)