accelforge 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (258) hide show
  1. accelforge/__init__.py +21 -0
  2. accelforge/_accelerated_imports.py +16 -0
  3. accelforge/_deprecate/_simanneal/evalmapping.py +271 -0
  4. accelforge/_deprecate/_simanneal/mapspaceglobals.py +298 -0
  5. accelforge/_deprecate/_simanneal/simanneal.py +666 -0
  6. accelforge/_deprecate/_simanneal/tracking.py +105 -0
  7. accelforge/_deprecate/_simanneal/wrappers.py +218 -0
  8. accelforge/_deprecate/_simanneal2/__init__.py +7 -0
  9. accelforge/_deprecate/_simanneal2/simanneal.py +493 -0
  10. accelforge/_deprecate/_simanneal2/tracking.py +116 -0
  11. accelforge/_deprecate/compatibility_util.py +181 -0
  12. accelforge/_deprecate/layerdeduplication/__init__.py +2 -0
  13. accelforge/_deprecate/layerdeduplication/group_similar_einsums.py +160 -0
  14. accelforge/_deprecate/layerdeduplication/grouped_einsums.py +84 -0
  15. accelforge/_deprecate/mapping_filter_tags/__init__.py +2 -0
  16. accelforge/_deprecate/mapping_filter_tags/ffmt.py +212 -0
  17. accelforge/_deprecate/mapping_filter_tags/onesplit.py +24 -0
  18. accelforge/_deprecate/mapping_filter_tags/util.py +24 -0
  19. accelforge/_deprecate/tags.py +69 -0
  20. accelforge/_deprecate/viz/__init__.py +0 -0
  21. accelforge/_deprecate/viz/interactive.py +159 -0
  22. accelforge/_deprecate/viz/reservationtree.py +307 -0
  23. accelforge/_deprecate/viz/ski_slope.py +88 -0
  24. accelforge/_version.py +15 -0
  25. accelforge/examples.py +39 -0
  26. accelforge/frontend/__init__.py +10 -0
  27. accelforge/frontend/_binding.py +129 -0
  28. accelforge/frontend/_workload_isl/__init__.py +2 -0
  29. accelforge/frontend/_workload_isl/_isl.py +149 -0
  30. accelforge/frontend/_workload_isl/_symbolic.py +141 -0
  31. accelforge/frontend/arch copy.py +1544 -0
  32. accelforge/frontend/arch.py +1642 -0
  33. accelforge/frontend/config.py +63 -0
  34. accelforge/frontend/mapper/__init__.py +5 -0
  35. accelforge/frontend/mapper/ffm.py +126 -0
  36. accelforge/frontend/mapper/mapper.py +7 -0
  37. accelforge/frontend/mapper/metrics.py +30 -0
  38. accelforge/frontend/mapping/__init__.py +1 -0
  39. accelforge/frontend/mapping/mapping.py +1736 -0
  40. accelforge/frontend/model.py +14 -0
  41. accelforge/frontend/renames.py +150 -0
  42. accelforge/frontend/spec copy.py +230 -0
  43. accelforge/frontend/spec.py +301 -0
  44. accelforge/frontend/variables.py +12 -0
  45. accelforge/frontend/workload.py +952 -0
  46. accelforge/mapper/FFM/__init__.py +9 -0
  47. accelforge/mapper/FFM/_join_pmappings/__init__.py +0 -0
  48. accelforge/mapper/FFM/_join_pmappings/compatibility.py +653 -0
  49. accelforge/mapper/FFM/_join_pmappings/compress_pmappings.py +140 -0
  50. accelforge/mapper/FFM/_join_pmappings/join_pmappings.py +703 -0
  51. accelforge/mapper/FFM/_join_pmappings/pmapping_dataframe.py +901 -0
  52. accelforge/mapper/FFM/_join_pmappings/pmapping_group.py +337 -0
  53. accelforge/mapper/FFM/_make_pmappings/contraints/__init__.py +0 -0
  54. accelforge/mapper/FFM/_make_pmappings/contraints/constraints.py +360 -0
  55. accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/__init__.py +1 -0
  56. accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_loops.py +373 -0
  57. accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_pmapping_templates.py +463 -0
  58. accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_reservations.py +95 -0
  59. accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_storage_order.py +382 -0
  60. accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_storages.py +155 -0
  61. accelforge/mapper/FFM/_make_pmappings/make_pmappings.py +411 -0
  62. accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/__init__.py +1 -0
  63. accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/make_pmappings_from_templates.py +407 -0
  64. accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/make_tile_shapes.py +1681 -0
  65. accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/run_model.py +170 -0
  66. accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/symbol_relations.py +174 -0
  67. accelforge/mapper/FFM/_make_pmappings/pmapper_job.py +282 -0
  68. accelforge/mapper/FFM/_pareto_df/df_convention.py +273 -0
  69. accelforge/mapper/FFM/_pareto_df/pareto copy.py +836 -0
  70. accelforge/mapper/FFM/_pareto_df/pareto.py +508 -0
  71. accelforge/mapper/FFM/data.py +61 -0
  72. accelforge/mapper/FFM/main copy.py +236 -0
  73. accelforge/mapper/FFM/main.py +208 -0
  74. accelforge/mapper/FFM/mappings.py +510 -0
  75. accelforge/mapper/FFM/pmappings.py +310 -0
  76. accelforge/mapper/__init__.py +4 -0
  77. accelforge/mapper.py +0 -0
  78. accelforge/model/__init__.py +1 -0
  79. accelforge/model/_looptree/__init__.py +0 -0
  80. accelforge/model/_looptree/accesses.py +335 -0
  81. accelforge/model/_looptree/capacity/__init__.py +1 -0
  82. accelforge/model/_looptree/capacity/aggregators.py +36 -0
  83. accelforge/model/_looptree/capacity/capacity.py +47 -0
  84. accelforge/model/_looptree/energy.py +150 -0
  85. accelforge/model/_looptree/equivalent_ranks.py +29 -0
  86. accelforge/model/_looptree/latency/__init__.py +1 -0
  87. accelforge/model/_looptree/latency/latency.py +98 -0
  88. accelforge/model/_looptree/latency/memory.py +120 -0
  89. accelforge/model/_looptree/latency/processors.py +92 -0
  90. accelforge/model/_looptree/mapping_utilities.py +71 -0
  91. accelforge/model/_looptree/reuse/__init__.py +4 -0
  92. accelforge/model/_looptree/reuse/isl/__init__.py +1 -0
  93. accelforge/model/_looptree/reuse/isl/des.py +59 -0
  94. accelforge/model/_looptree/reuse/isl/isl_functions.py +374 -0
  95. accelforge/model/_looptree/reuse/isl/mapping_to_isl/__init__.py +4 -0
  96. accelforge/model/_looptree/reuse/isl/mapping_to_isl/analyze_mapping.py +297 -0
  97. accelforge/model/_looptree/reuse/isl/mapping_to_isl/skews_from_mapping.py +236 -0
  98. accelforge/model/_looptree/reuse/isl/mapping_to_isl/tiling.py +685 -0
  99. accelforge/model/_looptree/reuse/isl/mapping_to_isl/types.py +188 -0
  100. accelforge/model/_looptree/reuse/isl/spatial.py +260 -0
  101. accelforge/model/_looptree/reuse/isl/temporal.py +182 -0
  102. accelforge/model/_looptree/reuse/symbolic/__init__.py +1 -0
  103. accelforge/model/_looptree/reuse/symbolic/symbolic copy 2.py +1346 -0
  104. accelforge/model/_looptree/reuse/symbolic/symbolic copy.py +1408 -0
  105. accelforge/model/_looptree/reuse/symbolic/symbolic.py +1396 -0
  106. accelforge/model/_looptree/run.py +122 -0
  107. accelforge/model/_looptree/types.py +26 -0
  108. accelforge/model/_looptree/visualization/__init__.py +0 -0
  109. accelforge/model/_looptree/visualization/occupancy.py +11 -0
  110. accelforge/model/main.py +222 -0
  111. accelforge/plotting/__init__.py +2 -0
  112. accelforge/plotting/mappings.py +219 -0
  113. accelforge/plotting/specs.py +57 -0
  114. accelforge/util/__init__.py +4 -0
  115. accelforge/util/_base_analysis_types.py +24 -0
  116. accelforge/util/_basetypes.py +1089 -0
  117. accelforge/util/_frozenset.py +36 -0
  118. accelforge/util/_isl.py +29 -0
  119. accelforge/util/_itertools.py +14 -0
  120. accelforge/util/_mathfuncs.py +57 -0
  121. accelforge/util/_parse_expressions.py +339 -0
  122. accelforge/util/_picklecache.py +32 -0
  123. accelforge/util/_setexpressions.py +268 -0
  124. accelforge/util/_sympy/__init__.py +0 -0
  125. accelforge/util/_sympy/broadcast_max.py +18 -0
  126. accelforge/util/_visualization.py +112 -0
  127. accelforge/util/_yaml.py +579 -0
  128. accelforge/util/parallel.py +193 -0
  129. accelforge-0.0.1.dist-info/METADATA +64 -0
  130. accelforge-0.0.1.dist-info/RECORD +258 -0
  131. accelforge-0.0.1.dist-info/WHEEL +5 -0
  132. accelforge-0.0.1.dist-info/licenses/LICENSE +19 -0
  133. accelforge-0.0.1.dist-info/top_level.txt +5 -0
  134. docs/_build/html/_sources/fastfusion.frontend.mapper.rst.txt +37 -0
  135. docs/_build/html/_sources/fastfusion.frontend.rst.txt +70 -0
  136. docs/_build/html/_sources/fastfusion.frontend.workload.rst.txt +21 -0
  137. docs/_build/html/_sources/fastfusion.mapper.FFM.rst.txt +37 -0
  138. docs/_build/html/_sources/fastfusion.mapper.rst.txt +18 -0
  139. docs/_build/html/_sources/fastfusion.rst.txt +20 -0
  140. docs/_build/html/_sources/fastfusion.util.rst.txt +21 -0
  141. docs/_build/html/_sources/index.rst.txt +87 -0
  142. docs/_build/html/_sources/modules.rst.txt +7 -0
  143. docs/_build/html/_sources/notes/citation.rst.txt +45 -0
  144. docs/_build/html/_sources/notes/definitions.rst.txt +43 -0
  145. docs/_build/html/_sources/notes/faqs.rst.txt +39 -0
  146. docs/_build/html/_sources/notes/modeling/accelerator_energy_latency.rst.txt +72 -0
  147. docs/_build/html/_sources/notes/modeling/component_energy_area.rst.txt +96 -0
  148. docs/_build/html/_sources/notes/modeling/mapping.rst.txt +100 -0
  149. docs/_build/html/_sources/notes/modeling.rst.txt +33 -0
  150. docs/_build/html/_sources/notes/parsing/arithmetic_parsing.rst.txt +136 -0
  151. docs/_build/html/_sources/notes/parsing/setexpressions.rst.txt +63 -0
  152. docs/_build/html/_sources/notes/parsing/yaml_parsing.rst.txt +176 -0
  153. docs/_build/html/_sources/notes/quickstart_and_installation.rst.txt +9 -0
  154. docs/_build/html/_sources/notes/spec/architecture.rst.txt +133 -0
  155. docs/_build/html/_sources/notes/spec/mapping.rst.txt +12 -0
  156. docs/_build/html/_sources/notes/spec/workload.rst.txt +83 -0
  157. docs/_build/html/_sources/notes/spec.rst.txt +36 -0
  158. docs/source/_ext/include_attrs.py +213 -0
  159. docs/source/_ext/include_docstring.py +364 -0
  160. docs/source/_ext/include_functions.py +154 -0
  161. docs/source/_ext/include_notebook.py +131 -0
  162. docs/source/_ext/include_yaml.py +119 -0
  163. docs/source/_ext/inherited_attributes.py +222 -0
  164. docs/source/_ext/paths.py +4 -0
  165. docs/source/conf.py +79 -0
  166. examples/arches/compute_in_memory/_include.yaml +74 -0
  167. examples/arches/compute_in_memory/_include_functions.py +229 -0
  168. examples/arches/compute_in_memory/_load_spec.py +57 -0
  169. examples/arches/compute_in_memory/components/c2c_multiplier.py +181 -0
  170. examples/arches/compute_in_memory/components/dac_c2c_r2r.py +605 -0
  171. examples/arches/compute_in_memory/components/misc.py +195 -0
  172. examples/arches/compute_in_memory/components/util/bit_functions.py +51 -0
  173. examples/arches/compute_in_memory/components/zero_comparator.py +92 -0
  174. examples/arches/compute_in_memory/isaac.yaml +233 -0
  175. examples/arches/compute_in_memory/memory_cells/ecram_demo.yaml +63 -0
  176. examples/arches/compute_in_memory/memory_cells/rram_example.yaml +63 -0
  177. examples/arches/compute_in_memory/memory_cells/rram_isaac_isca_2016.yaml +64 -0
  178. examples/arches/compute_in_memory/memory_cells/rram_neurosim_default.yaml +63 -0
  179. examples/arches/compute_in_memory/memory_cells/rram_raella_isca_2023.yaml +70 -0
  180. examples/arches/compute_in_memory/memory_cells/rram_wan_nature_2022.yaml +63 -0
  181. examples/arches/compute_in_memory/memory_cells/sram_colonnade_jssc_2021.yaml +63 -0
  182. examples/arches/compute_in_memory/memory_cells/sram_example.yaml +63 -0
  183. examples/arches/compute_in_memory/memory_cells/sram_jia_jssc_2020.yaml +63 -0
  184. examples/arches/compute_in_memory/memory_cells/sram_sinangil_jssc_2021.yaml +63 -0
  185. examples/arches/compute_in_memory/memory_cells/sram_wang_vlsi_2022.yaml +63 -0
  186. examples/arches/compute_in_memory/wang_vlsi_2022.yaml +289 -0
  187. examples/arches/eyeriss.yaml +68 -0
  188. examples/arches/fanout_variations/at_glb.yaml +31 -0
  189. examples/arches/fanout_variations/at_glb_with_fanout_node.yaml +34 -0
  190. examples/arches/fanout_variations/at_mac.yaml +31 -0
  191. examples/arches/fanout_variations/at_mac_with_constraints.yaml +38 -0
  192. examples/arches/fanout_variations/at_mac_with_fanout_node.yaml +34 -0
  193. examples/arches/nvdla.yaml +47 -0
  194. examples/arches/simple.yaml +28 -0
  195. examples/arches/tpu_v4i.yaml +67 -0
  196. examples/mappings/unfused_matmuls_to_simple.yaml +33 -0
  197. examples/misc/component_annotated.yaml +33 -0
  198. examples/workloads/gpt3_6.7B.yaml +124 -0
  199. examples/workloads/matmuls.yaml +20 -0
  200. examples/workloads/mobilenet_28.yaml +81 -0
  201. examples/workloads/mobilenet_various_separate.yaml +106 -0
  202. examples/workloads/three_matmuls_annotated.yaml +59 -0
  203. notebooks/.ipynb_checkpoints/fastfusion_arch_study_michael-checkpoint.ipynb +359 -0
  204. notebooks/compute_in_memory/_scripts.py +339 -0
  205. notebooks/compute_in_memory/isaac.guide.ipynb +270 -0
  206. notebooks/compute_in_memory/wang_vlsi_2022.ipynb +602 -0
  207. notebooks/paths.py +4 -0
  208. notebooks/tutorials/.ipynb_checkpoints/1_FFM-checkpoint.ipynb +3110 -0
  209. notebooks/tutorials/FFM.ipynb +3498 -0
  210. notebooks/tutorials/_include.py +48 -0
  211. notebooks/tutorials/component_energy_area.ipynb +363 -0
  212. tests/Q_mapping.yaml +38 -0
  213. tests/__init__.py +0 -0
  214. tests/conv.mapping.yaml +27 -0
  215. tests/conv.workload.yaml +13 -0
  216. tests/conv_sym.mapping.yaml +43 -0
  217. tests/copy.mapping.yaml +35 -0
  218. tests/copy.workload.yaml +15 -0
  219. tests/distribuffers/__init__.py +0 -0
  220. tests/distribuffers/multicast/test_cases.yaml +482 -0
  221. tests/distribuffers/spec/binding/valid_bindings.yaml +97 -0
  222. tests/distribuffers/spec/distributed.yaml +100 -0
  223. tests/distribuffers/spec/logical_arch.yaml +32 -0
  224. tests/distribuffers/spec/physical_arch.yaml +69 -0
  225. tests/distribuffers/test_binding.py +48 -0
  226. tests/frontend/__init__.py +0 -0
  227. tests/frontend/test_mapping_viz.py +52 -0
  228. tests/mapper/__init__.py +0 -0
  229. tests/mapper/configs/conv1d/conv1d.mapping.yaml +31 -0
  230. tests/mapper/configs/conv1d/conv1d.workload.yaml +11 -0
  231. tests/mapper/configs/two_conv1d/two_conv1d.expected.yaml +38 -0
  232. tests/mapper/configs/two_conv1d/two_conv1d.mapping.yaml +54 -0
  233. tests/mapper/configs/two_conv1d/two_conv1d.workload.yaml +19 -0
  234. tests/mapper/test_mapping_to_isl.py +90 -0
  235. tests/mapper/test_spatial_reuse_analysis.py +67 -0
  236. tests/mapper/test_temporal_reuse_analysis.py +56 -0
  237. tests/mapper/util.py +58 -0
  238. tests/matmul.mapping.yaml +29 -0
  239. tests/matmul.workload.yaml +12 -0
  240. tests/matmul_spatial.mapping.yaml +44 -0
  241. tests/mha.renames.yaml +65 -0
  242. tests/mha.workload.yaml +67 -0
  243. tests/mha.yaml +59 -0
  244. tests/mha_full.workload.yaml +67 -0
  245. tests/mobilenet.workload.yaml +35 -0
  246. tests/mobilenet_long.workload.yaml +64 -0
  247. tests/pmappingcache.py +24 -0
  248. tests/processing_stage.arch.yaml +40 -0
  249. tests/snowcat.arch.yaml +36 -0
  250. tests/test_ffm_join_pmappings.py +106 -0
  251. tests/test_ffm_make_pmappings.py +82 -0
  252. tests/test_ffm_make_tile_shapes.py +49 -0
  253. tests/test_mapper.py +100 -0
  254. tests/test_model.py +37 -0
  255. tests/test_plotting.py +72 -0
  256. tests/test_processing_stage.py +46 -0
  257. tests/test_symbolic_model.py +248 -0
  258. tests/test_workload.py +141 -0
@@ -0,0 +1,1544 @@
1
+ import copy
2
+ import math
3
+ from numbers import Number
4
+ import re
5
+ from typing import (
6
+ Any,
7
+ Callable,
8
+ Iterator,
9
+ List,
10
+ Literal,
11
+ Optional,
12
+ TypeVar,
13
+ Annotated,
14
+ Type,
15
+ Union,
16
+ )
17
+ from pydantic import ConfigDict, Tag
18
+ import pydantic
19
+ from hwcomponents import (
20
+ ComponentModel,
21
+ get_models,
22
+ get_model,
23
+ )
24
+ import pydot
25
+
26
+ from accelforge.util._basetypes import (
27
+ ParsableModel,
28
+ ParsableList,
29
+ ParseExtras,
30
+ ParsesTo,
31
+ _PostCall,
32
+ _get_tag,
33
+ )
34
+ import numpy as np
35
+
36
+ from accelforge.util._parse_expressions import ParseError, parse_expression
37
+ from accelforge.util._setexpressions import InvertibleSet, eval_set_expression
38
+ from accelforge.frontend.renames import RankVariable, TensorName
39
+
40
+ from accelforge._version import assert_version, __version__
41
+ from pydantic import Discriminator
42
+ from accelforge.util._basetypes import _uninstantiable
43
+ from accelforge.util.parallel import _SVGJupyterRender, _pydot_graph
44
+
45
+ T = TypeVar("T", bound="ArchNode")
46
+
47
+
48
class ArchNode(ParsableModel):
    """A node in the architecture."""

    def find(self, name: str) -> "Leaf":
        """
        Locate the `Leaf` called ``name`` in the subtree rooted at this node.

        :param name: The name of the `Leaf` to look up.
        :returns: This node when it is a `Leaf` whose name matches; otherwise the
            first match found while searching ``self.nodes`` in order (only done
            when this node is a `Branch`).
        :raises ValueError: If the `Leaf` node with the given name is not found.
        """
        matches_self = isinstance(self, Leaf) and getattr(self, "name", None) == name
        if matches_self:
            return self

        # Only a Branch has children to descend into.
        children = self.nodes if isinstance(self, Branch) else ()
        for child in children:
            try:
                found = child.find(name)
            except (AttributeError, ValueError):
                # Not under this child (or the child has no usable `find`);
                # keep looking in the remaining siblings.
                continue
            return found

        raise ValueError(f"Leaf {name} not found in {self}")
66
+
67
+
68
class ArchNodes(ParsableList):
    """A list of `ArchNode`s."""

    def __repr__(self):
        """Return the parent class's repr wrapped in this class's name."""
        inner = super().__repr__()
        return f"{type(self).__name__}({inner})"

    def _parse_expressions(self, *args, **kwargs):
        """
        Delegate to the parent parser, registering a post-call hook that copies a
        parsed `Container`'s ``attributes`` into the symbol table.

        All positional and keyword arguments are forwarded unchanged.
        """

        class PostCallArchNode(_PostCall):
            def __call__(self, field, value, parsed, symbol_table):
                # Anything that is not a Container passes through untouched.
                if not isinstance(parsed, Container):
                    return parsed
                symbol_table.update(parsed.attributes)
                return parsed

        hooks = (PostCallArchNode(),)
        return super()._parse_expressions(*args, **kwargs, post_calls=hooks)
84
+
85
+
86
class Comparison(ParsableModel):
    """
    A comparison between a rank variable's bound and a value. A comparison is performed
    for each rank variable.

    The LHS of each comparison is the loop bound of a loop that affects this rank
    variable. The RHS is the given value.

    For example, if the expression resolves to [a, b], the operator is "<=", and the
    value is 10, and we have loops "for a0 in [0..A0)" and "for b0 in [0..B0)", then a
    mapping is only valid if A0 <= 10 and B0 <= 10.
    """

    expression: str | InvertibleSet[RankVariable] | set[RankVariable]
    """ The expression to compare. This expression should resolve to a set of rank
    variables. A comparison is performed for each rank variable independently, and the
    result passes if and only if all comparisons pass. The LHS of each comparison is the
    loop bound of a loop that affects this rank variable. The RHS is the given value.
    """

    operator: str
    """ The operator to use for the comparison. Supported operators are:
    - == (equal to)
    - <= (less than or equal to)
    - >= (greater than or equal to)
    - < (less than)
    - > (greater than)
    - product== (product of all loop bounds is equal to)
    - product<= (product of all loop bounds is less than or equal to)
    - product>= (product of all loop bounds is greater than or equal to)
    - product< (product of all loop bounds is less than)
    - product> (product of all loop bounds is greater than)
    """

    value: ParsesTo[int]
    """ The value to compare against. """

    def _parse(self, symbol_table: dict[str, Any], location: str):
        """
        Build a new comparison whose ``expression`` has been evaluated as a set
        expression over rank variables.

        :param symbol_table: Symbols made available to the set expression.
        :param location: Location string forwarded to ``eval_set_expression``.
        :returns: A new instance of this class; ``self`` is not modified.
        """
        new = type(self)(
            expression=eval_set_expression(
                self.expression, symbol_table, "rank_variables", location
            ),
            operator=self.operator,
            value=self.value,
        )
        # With a single rank variable the product of the bounds is just that
        # bound, so the "product" prefix is dropped from the operator.
        if len(new.expression) == 1 and "product" in new.operator:
            new.operator = new.operator.replace("product", "")
        return new

    def _constrained_to_one(self) -> bool:
        """Return True if this comparison has value 1 with an ``==``/``<=``
        operator (plain or ``product`` form)."""
        return self.value == 1 and self.operator in [
            "==",
            "<=",
            "product==",
            "product<=",
        ]

    def _split_expression(self) -> List[set[RankVariable]]:
        """
        Split the expression into the groups that are each checked as a unit.

        ``product`` operators check all rank variables together, so the whole
        expression is returned as one group. Other operators yield one singleton
        set per rank variable.
        """
        if "product" in self.operator:
            return [self.expression]
        # Sort the elements, not the sets: `<` on sets is the proper-subset test,
        # so sorting distinct singleton sets never reorders them and the result
        # would follow set iteration order. Sorting by string form gives a
        # deterministic order.
        return [{rank_var} for rank_var in sorted(self.expression, key=str)]

    def _to_constraint_lambda(
        self,
        increasing_sizes: bool,
    ) -> Callable[[bool, np.ndarray], bool | np.ndarray]:
        """
        Return a callable ``(final, sizes)`` that evaluates this comparison.

        ``sizes`` is reduced along ``axis=1``, so it is expected to be 2-D
        (presumably one row per candidate and one column per loop bound; confirm
        against the caller). ``final`` says whether all sizes are known.

        :param increasing_sizes: Whether sizes only grow as more become known.
            This decides which inequalities can be checked before ``final`` is
            true; a check that cannot be decided yet returns ``True``.
        :raises KeyError: If ``self.operator`` is not a supported operator.
        """
        # Equal operators can only evaluate when all sizes are known. Before that
        # they are relaxed to the inequality that can still become an equality.
        eq_op = lambda final: (
            np.equal
            if final
            else (np.less_equal if increasing_sizes else np.greater_equal)
        )

        # If we're increasing, we can evaluate leq immediately. If we're
        # decreasing, we can evaluate geq immediately. The other must wait
        # until all sizes are known.
        le_wrapper = lambda op: lambda final, sizes: (
            op(sizes) if final or increasing_sizes else True
        )
        ge_wrapper = lambda op: lambda final, sizes: (
            op(sizes) if final or not increasing_sizes else True
        )

        _all = lambda mask: np.all(mask, axis=1)
        _prod = lambda sizes: np.prod(sizes, axis=1)

        # The elementwise operators compare every size against the value first and
        # then require all comparisons in a row to hold. Reducing first would
        # compare a boolean against the value instead of the sizes.
        # fmt: off
        operator_to_wrapper = {
            "==":        lambda final, sizes: _all(eq_op(final)(sizes, self.value)),
            "product==": lambda final, sizes: eq_op(final)(_prod(sizes), self.value),
            "<=":        le_wrapper(lambda sizes: _all(sizes <= self.value)),
            ">=":        ge_wrapper(lambda sizes: _all(sizes >= self.value)),
            "<":         le_wrapper(lambda sizes: _all(sizes < self.value)),
            ">":         ge_wrapper(lambda sizes: _all(sizes > self.value)),
            "product<=": le_wrapper(lambda sizes: _prod(sizes) <= self.value),
            "product>=": ge_wrapper(lambda sizes: _prod(sizes) >= self.value),
            "product<":  le_wrapper(lambda sizes: _prod(sizes) < self.value),
            "product>":  ge_wrapper(lambda sizes: _prod(sizes) > self.value),
        }
        # fmt: on

        if self.operator in operator_to_wrapper:
            return operator_to_wrapper[self.operator]
        raise KeyError(
            f"Unknown operator: {self.operator}. Known operators: {list(operator_to_wrapper.keys())}"
        )
194
+
195
+
196
class Spatial(ParsableModel):
    """A one-dimensional spatial fanout in the architecture."""

    name: str
    """
    The name of the dimension over which this spatial fanout is occurring (e.g., X or Y).
    """

    fanout: ParsesTo[int]
    """ The size of this fanout. """

    may_reuse: ParsesTo[InvertibleSet[TensorName] | set[TensorName]] = "All"
    """ The tensors that can be reused spatially across instances of this fanout. This
    expression will be parsed for each mapping template. """

    loop_bounds: ParsableList[Comparison] = ParsableList()
    """ Bounds for loops over this dimension. This is a list of :class:`~.Comparison`
    objects, all of which must be satisfied by the loops to which this constraint
    applies.
    """

    min_usage: int | float | str = 0.0
    """ The minimum utilization of spatial instances, as a value from 0 to 1. A mapping
    is invalid if less than this proportion of this dimension's fanout is utilized.
    Mappers that support it (e.g., FFM) may, if no mappings satisfy this constraint,
    return the highest-utilization mappings.
    """

    reuse: ParsesTo[InvertibleSet[TensorName] | set[TensorName]] = "Nothing"
    """ A set of tensors or a set expression representing tensors that must be reused
    across spatial iterations. Spatial loops may only be placed that reuse ALL tensors
    given here.
    """

    usage_scale: ParsesTo[int | float | str] = 1
    """
    This factor scales the usage in this dimension. For example, if usage_scale is 2 and
    10/20 spatial instances are used, then the usage will be scaled to 20/20.
    """

    def _parse(self, symbol_table: dict[str, Any], location: str):
        """
        Build a new `Spatial` with its set expressions and ``min_usage`` evaluated.

        ``may_reuse`` and ``reuse`` are evaluated as set expressions over tensors,
        each entry of ``loop_bounds`` is parsed with its own ``_parse``, and
        ``min_usage`` is evaluated with ``parse_expression``. ``name``, ``fanout``
        and ``usage_scale`` are carried over unchanged.

        :param symbol_table: Symbols made available to the expressions.
        :param location: Prefix for the per-field location strings passed to the
            evaluators.
        :returns: A new instance of this class; ``self`` is not modified.
        """
        return type(self)(
            name=self.name,
            fanout=self.fanout,
            may_reuse=set(
                eval_set_expression(
                    self.may_reuse,
                    symbol_table,
                    expected_space_name="tensors",
                    location=location + ".may_reuse",
                )
            ),
            loop_bounds=[
                x._parse(symbol_table, location + ".loop_bounds")
                for x in self.loop_bounds
            ],
            min_usage=parse_expression(
                self.min_usage,
                symbol_table,
                "min_usage",
                location + ".min_usage",
            ),
            reuse=eval_set_expression(
                self.reuse,
                symbol_table,
                "tensors",
                location + ".reuse",
            ),
            # Carry the configured scale over; omitting it would silently reset
            # the rebuilt object to the default of 1.
            usage_scale=self.usage_scale,
        )
265
+
266
+
267
# Base attribute container used as the type of `Leaf.attributes`. It declares
# no fields of its own.
class LeafAttributes(ParsableModel):
    pass
269
+
270
+
271
# Attribute base class built on `ParseExtras`; it declares no fields of its own.
# NOTE(review): presumably `ParseExtras` is what lets extra, undeclared
# attributes be parsed -- confirm in accelforge.util._basetypes.
class AttributesWithExtras(ParseExtras):
    pass
273
+
274
+
275
# Attributes shared by anything that carries an energy value plus energy and
# latency scale factors. Subclasses in this file re-declare these fields to
# attach documentation to them.
class AttributesWithEnergyLatency(AttributesWithExtras):
    # Energy value; None (the default) means it has not been given.
    energy: ParsesTo[int | float | None] = None
    # Scale factor for energy; defaults to 1.
    energy_scale: ParsesTo[int | float] = 1
    # Scale factor for latency; defaults to 1.
    latency_scale: ParsesTo[int | float] = 1
279
+
280
+
281
# Attributes of a component: area, leak power, and the energy/latency scale
# factors, plus the expression used to compute the component's total latency.
class ComponentAttributes(AttributesWithEnergyLatency):
    area: ParsesTo[int | float | None] = None
    """
    The area of a single instance of this component in m^2. If set, area calculations
    will use this value.
    """
    total_area: ParsesTo[int | float | None] = None
    """
    The total area of all instances of this component in m^2. Do not set this value. It
    is calculated when the architecture's area is calculated.
    """
    area_scale: ParsesTo[int | float] = 1
    """
    The scale factor for the area of this component. This is used to scale the area of
    this component. For example, if the area is 1 m^2 and the scale factor is 2, then
    the area is 2 m^2.
    """

    leak_power: ParsesTo[int | float | None] = None
    """
    The leak power of a single instance of this component in W. If set, leak power
    calculations will use this value.
    """
    total_leak_power: ParsesTo[int | float | None] = None
    """
    The total leak power of all instances of this component in W. Do not set this value.
    It is calculated when the architecture's leak power is calculated.
    """
    leak_power_scale: ParsesTo[int | float] = 1
    """
    The scale factor for the leak power of this component. This is used to scale the
    leak power of this component. For example, if the leak power is 1 W and the scale
    factor is 2, then the leak power is 2 W.
    """

    # Re-declared with the same type and default as the parent class so that the
    # documentation below attaches to this class.
    energy_scale: ParsesTo[int | float] = 1
    """
    The scale factor for dynamic energy of this component. Multiplies the calculated
    energy of each action.
    """

    # The default is an expression string, not a number.
    total_latency: str | int | float = (
        "sum(*action2latency.values()) / n_parallel_instances"
    )
    """
    An expression representing the total latency of this component in seconds. This is
    used to calculate the latency of a given Einsum. Special variables available are the
    following:

    - `min`: The minimum value of all arguments to the expression.
    - `max`: The maximum value of all arguments to the expression.
    - `sum`: The sum of all arguments to the expression.
    - `X_actions`: The number of times action `X` is performed. For example,
      `read_actions` is the number of times the read action is performed.
    - `X_latency`: The total latency of all actions of type `X`. For example,
      `read_latency` is the total latency of all read actions. It is equal to the
      per-read latency multiplied by the number of read actions.
    - `action2latency`: A dictionary of action names to their latency.

    Additionally, all component attributes are available as variables, and all other
    functions generally available in parsing. Note this expression is parsed after other
    component attributes are parsed.

    For example, the following expression calculates latency assuming that each read or
    write action takes 1ns: ``1e-9 * (read_actions + write_actions)``.
    """

    # Re-declared with the same type and default as the parent class so that the
    # documentation below attaches to this class.
    latency_scale: ParsesTo[int | float] = 1
    """
    The scale factor for the latency of this component. This is used to scale the
    latency of this component. For example, if the latency is 1 ns and the scale factor
    is 2, then the latency is 2 ns. Multiplies the calculated latency of each action.
    """

    n_parallel_instances: ParsesTo[int | float] = 1
    """
    The number of parallel instances of this component. Increasing parallel instances
    will proportionally increase area and leakage, while reducing latency (unless
    latency calculation is overridden).
    """
361
+
362
+
363
# Attribute container for fanout nodes. It adds no fields to `LeafAttributes`.
class FanoutAttributes(LeafAttributes):
    # extra="forbid" makes pydantic reject any field that is not declared.
    model_config = ConfigDict(extra="forbid")
365
+
366
+
367
class ActionArguments(AttributesWithEnergyLatency):
    """
    Arguments for an action of a component.
    """

    # `energy`, `energy_scale` and `latency_scale` repeat the parent class's
    # declarations (same types and defaults); `latency` is new in this class.

    energy: ParsesTo[int | float | None] = None
    """
    Dynamic energy of this action. Per-action energy is multiplied by the component's
    attributes.energy_scale and the action's arguments.energy_scale.
    """
    energy_scale: ParsesTo[int | float] = 1
    """
    The scale factor for dynamic energy of this action. Multiplies this action's energy
    by this value.
    """
    latency: ParsesTo[int | float | None] = None
    """
    Latency of this action. Per-action latency is multiplied by the component's
    attributes.latency_scale and the action's arguments.latency_scale.
    """
    latency_scale: ParsesTo[int | float] = 1
    """
    The scale factor for dynamic latency of this action. Multiplies this action's
    latency by this value.
    """
392
+
393
+
394
# Action arguments for tensor-holding components: everything in
# `ActionArguments` plus the number of bits moved per action.
class TensorHolderActionArguments(ActionArguments):
    # The default is an expression string, not a number: it yields 1 when
    # `attributes.bits_per_action` is None and that attribute's value otherwise.
    # NOTE(review): presumably evaluated by the ParsesTo machinery with the
    # component's `attributes` in scope -- confirm.
    bits_per_action: ParsesTo[int | float] = (
        "1 if attributes.bits_per_action is None else attributes.bits_per_action"
    )
    """ The number of bits accessed in this action. For example, setting bits_per_action
    to 16 means that each call to this action yields 16 bits. """
400
+
401
+
402
# A named action of a component, together with the arguments used to model it.
class Action(ParsableModel):
    name: str
    """ The name of this action. """

    # Defaults to an `ActionArguments` built with all of its field defaults.
    arguments: ActionArguments = ActionArguments()
    """
    The arguments for this action. Passed to the component's model to calculate the
    energy and latency of the action.
    """
411
+
412
+
413
# An `Action` whose arguments are narrowed to `TensorHolderActionArguments`,
# which adds `bits_per_action`.
class TensorHolderAction(Action):
    arguments: TensorHolderActionArguments = TensorHolderActionArguments()
    """
    The arguments for this action. Passed to the component's model to calculate the
    energy and latency of the action.
    """
419
+
420
+
421
@_uninstantiable
class Leaf(ArchNode):
    """A leaf node in the architecture. This is an abstract class that represents any
    node that is not a `Branch`."""

    name: str
    """ The name of this `Leaf`. """

    attributes: LeafAttributes = LeafAttributes()
    """ The attributes of this `Leaf`. """

    spatial: ParsableList[Spatial] = ParsableList()
    """
    The spatial fanouts of this `Leaf`.

    Spatial fanouts describe the spatial organization of components in the architecture.
    A spatial fanout of size N for this node means that there are N instances of this
    node. Multiple spatial fanouts lead to a multi-dimensional fanout. Spatial
    constraints apply to the data exchange across these instances. Spatial fanouts
    specified at this level also apply to lower-level `Leaf` nodes in the architecture.
    """

    _fields_for_energy_area_latency_leak_calculation: tuple[str, ...] = (
        "name",
        "attributes",
    )
    """
    The fields that are used to calculate the energy, area, latency, and leak power of
    this `Leaf`.
    """

    def _parse_expressions(self, symbol_table: dict[str, Any], *args, **kwargs):
        """
        Parse this leaf's fields, with ``attributes`` listed first in the order.

        Once ``attributes`` is parsed, its dumped contents are merged into the
        symbol table by a post-call hook. When the symbol table contains the
        ``_parsing_attributes_only_`` key, parsing is restricted to the fields
        used for energy, area, latency and leak calculation.

        :returns: ``(parsed, symbol_table)`` as returned by the parent parser,
            with this node registered in the symbol table under its own name.
        """

        class PostCallLeaf(_PostCall):
            def __call__(self, field, value, parsed, symbol_table):
                # Only the attributes field contributes symbols.
                if field != "attributes":
                    return parsed
                symbol_table.update(parsed.model_dump())
                return parsed

        attributes_only = "_parsing_attributes_only_" in symbol_table
        if attributes_only:
            kwargs["fields"] = self._fields_for_energy_area_latency_leak_calculation

        hooks = (PostCallLeaf(),)
        parsed, updated_table = super()._parse_expressions(
            symbol_table,
            *args,
            **kwargs,
            post_calls=hooks,
            order=("attributes",),
        )
        updated_table[self.name] = self
        return parsed, updated_table

    def get_fanout(self) -> int:
        """The spatial fanout of this node: the product of the ``fanout`` of every
        entry in ``spatial``, converted to ``int`` (1 when ``spatial`` is empty)."""
        per_dimension = (dimension.fanout for dimension in self.spatial)
        total = math.prod(per_dimension)
        return int(total)
475
+
476
+
477
@_uninstantiable
class Component(Leaf):
    """A component object in the architecture. This is overridden by different
    component types, such as `Memory` and `Compute`."""

    name: str
    """ The name of this `Component`. """

    component_class: Optional[str] = None
    """ The class of this `Component`. Used if an energy or area model needs to be
    called for this `Component`. """

    component_model: ComponentModel | None = None
    """ The model to use for this `Component`. If not set, the model will be found with
    `hwcomponents.get_models()`. If set, the `component_class` will be ignored. """

    component_modeling_log: list[str] = []
    """ A log of the energy and area calculations for this `Component`. """

    actions: ParsableList[Action]
    """ The actions that this `Component` can perform. """

    attributes: ComponentAttributes = ComponentAttributes()
    """ The attributes of this `Component`. """

    model_config = ConfigDict(arbitrary_types_allowed=True)

    enabled: str | bool = True
    """ Whether this component is enabled. If the expression resolves to False, then
    the component is disabled. This is parsed per-pmapping-template, so it is a function
    of the tensors in the current Einsum. For example, you may say `len(All) >= 3` and
    the component will only be enabled with Einsums with three or more tensors.
    """

    _fields_for_energy_area_latency_leak_calculation: tuple[str, ...] = (
        "actions",
        "component_class",
        "component_model",
        "component_modeling_log",
    ) + Leaf._fields_for_energy_area_latency_leak_calculation

    def _update_actions(self, new_actions: ParsableList[Action]):
        """
        Append every action of ``new_actions`` whose name is not already present in
        ``self.actions``.

        Appended actions are copies (including their ``arguments``). The callers in
        this file pass module-level default lists, and the ``calculate_*`` methods
        write results into ``action.arguments``; appending the shared objects
        themselves would let one component's in-place results leak into another's.
        """
        has_actions = {x.name for x in self.actions}
        for action in new_actions:
            if action.name in has_actions:
                continue
            added = action.model_copy()
            added.arguments = action.arguments.model_copy()
            self.actions.append(added)

    def _copy_for_modeling(self: T) -> T:
        """
        Return a copy of this component whose ``attributes``, ``actions``, per-action
        ``arguments`` and ``component_modeling_log`` are independent of the original,
        so that the ``calculate_*`` methods can mutate the copy freely.
        """
        copied = self.model_copy()
        copied.attributes = self.attributes.model_copy()
        copied.actions = type(self.actions)([a.model_copy() for a in self.actions])
        for action in copied.actions:
            action.arguments = action.arguments.model_copy()
        # ``model_copy()`` is shallow: without this the copy and the original would
        # append log messages to the very same list object.
        copied.component_modeling_log = list(self.component_modeling_log)
        return copied

    def get_component_class(self, trying_to_calculate: str | None = None) -> str:
        """Returns the class of this `Component`.

        Parameters
        ----------
        trying_to_calculate : str, optional
            What was trying to be calculated using this component. If provided, the
            error message will be more specific.

        :raises ParseError: If the `component_class` is not set.
        """
        if self.component_class is None:
            extra_info = ""
            if trying_to_calculate is not None:
                extra_info = (
                    f" Occurred while trying to calculate {trying_to_calculate}."
                )
            raise ParseError(
                f"component_class must be set to a valid string. "
                f"Got {self.component_class}. This occurred because the model tried to "
                "talk to hwcomponents, but was missing necessary attributes. If you do "
                "not want to use hwcomponents models, ensure that attributes.area and "
                "attributes.leak_power are set, as well as, for each action, "
                f"arguments.energy and arguments.latency are set.{extra_info}",
                source_field=f"{self.name}.component_class",
            )
        return self.component_class

    def populate_component_model(
        self: T,
        models: list[ComponentModel] | None = None,
        in_place: bool = False,
        trying_to_calculate: str | None = None,
    ) -> T:
        """
        Populates the ``component_model`` attribute with the model for this component.
        Extends the ``component_modeling_log`` field with log messages. If
        ``component_model`` is already set, nothing is looked up. Otherwise the
        ``component_class`` attribute, the dumped attributes and the action names are
        passed to ``get_model()`` to find the model.

        Parameters
        ----------
        models : list[ComponentModel] | None
            The models to search. If not provided, the models will be found with
            ``get_models()``.
        in_place : bool
            If True, the component will be modified in place. Otherwise, a copy will be
            modified and returned.
        trying_to_calculate : str, optional
            What was trying to be calculated using this component. If provided, the
            error messages for missing component_class will be more specific.

        Returns
        -------
        T
            The component (``self`` if ``in_place``, otherwise a copy) with the
            populated ``component_model`` attribute.

        :raises ParseError: If a model must be looked up and `component_class` is not
            set.
        """
        if not in_place:
            self = self._copy_for_modeling()

        if self.component_model is None:
            if models is None:
                models = get_models()
            estimation = get_model(
                self.get_component_class(trying_to_calculate=trying_to_calculate),
                self.attributes.model_dump(),
                required_actions=[x.name for x in self.actions],
                models=models,
                _return_estimation_object=True,
            )
            self.component_model = estimation.value
            self.component_modeling_log.extend(estimation.messages)
        return self

    def calculate_action_energy(
        self: T,
        models: list[ComponentModel] | None = None,
        in_place: bool = False,
    ) -> T:
        """
        Calculates energy for each action of this component. If
        ``<action>.arguments.energy`` is already set, that value is used. Otherwise,
        the energy is obtained from the component model (populated on demand via
        ``populate_component_model``). In both cases the value is then multiplied by
        ``attributes.energy_scale`` and ``<action>.arguments.energy_scale`` and stored
        in ``<action>.arguments.energy``. Extends the ``component_modeling_log`` field
        with log messages.

        Note that these methods will be called by the Spec when calculating energy and
        area. If you call them yourself, note that string expressions may not be parsed
        because they need the Spec's global scope. If you are sure that all necessary
        values are present and not a result of an expression, you can call these
        directly. Otherwise, you can call the
        ``Spec.calculate_component_area_energy_latency_leak`` and then grab components
        from the returned ``Spec``.

        NOTE(review): the scaled result is written back to the field that is read as
        the predefined value, so calling this twice on the same object applies the
        scale factors twice.

        Parameters
        ----------
        models : list[ComponentModel] | None
            The models to use for energy calculation. If not provided, the models will
            be found with ``get_models()``.
        in_place : bool
            If True, the component will be modified in place. Otherwise, a copy will be
            modified and returned.

        Returns
        -------
        T
            The component (``self`` if ``in_place``, otherwise a copy) with the
            calculated energy for each action.
        """
        if not in_place:
            self = self._copy_for_modeling()

        messages = self.component_modeling_log
        attributes = self.attributes
        for action in self.actions:
            messages.append(f"Calculating energy for {self.name} action {action.name}.")
            args = action.arguments
            if args.energy is not None:
                energy = args.energy
                messages.append(f"Setting {self.name} energy to {args.energy=}")
            else:
                self.populate_component_model(
                    models,
                    in_place=True,
                    trying_to_calculate=f"arguments.energy for action {action.name}",
                )
                # Action arguments override same-named component attributes.
                estimation = self.component_model.try_call_arbitrary_action(
                    action_name=action.name,
                    _return_estimation_object=True,
                    **{**attributes.model_dump(), **args.model_dump()},
                )
                messages.extend(estimation.messages)
                # Element 0 of the result is used as energy; the latency calculation
                # reads element 1 of the same kind of result.
                energy = estimation.value[0]
            if attributes.energy_scale != 1:
                energy *= attributes.energy_scale
                messages.append(
                    f"Scaling {self.name} energy by {attributes.energy_scale=}"
                )
            if args.energy_scale != 1:
                energy *= args.energy_scale
                messages.append(f"Scaling {self.name} energy by {args.energy_scale=}")
            action.arguments.energy = energy
        return self

    def calculate_leak_power(
        self: T,
        models: list[ComponentModel] | None = None,
        in_place: bool = False,
    ) -> T:
        """
        Calculates the leak power for this component. If ``attributes.leak_power`` is
        already set, that value is used. Otherwise, the leak power is read from the
        component model (populated on demand via ``populate_component_model``). In
        both cases the value is then multiplied by ``attributes.leak_power_scale`` and
        ``attributes.n_parallel_instances`` and stored in ``attributes.leak_power``.
        Extends the ``component_modeling_log`` field with log messages.

        Note that these methods will be called by the Spec when calculating energy and
        area. If you call them yourself, note that string expressions may not be parsed
        because they need the Spec's global scope. If you are sure that all necessary
        values are present and not a result of an expression, you can call these
        directly. Otherwise, you can call the
        ``Spec.calculate_component_area_energy_latency_leak`` and then grab components
        from the returned ``Spec``.

        NOTE(review): the scaled result is written back to ``attributes.leak_power``,
        so calling this twice on the same object applies the scale factors twice.

        Parameters
        ----------
        models : list[ComponentModel] | None
            The models to use for leak power calculation. If not provided, the models
            will be found with ``get_models()``.
        in_place : bool
            If True, the component will be modified in place. Otherwise, a copy will be
            modified and returned.

        Returns
        -------
        T
            The component (``self`` if ``in_place``, otherwise a copy) with the
            calculated leak power.
        """
        if not in_place:
            self = self._copy_for_modeling()

        attributes = self.attributes
        messages = self.component_modeling_log
        if attributes.leak_power is not None:
            leak_power = attributes.leak_power
            messages.append(
                f"Using predefined leak power value {attributes.leak_power=}"
            )
        else:
            self.populate_component_model(
                models,
                in_place=True,
                trying_to_calculate="attributes.leak_power",
            )
            leak_power = self.component_model.leak_power
        if attributes.leak_power_scale != 1:
            leak_power *= attributes.leak_power_scale
            messages.append(f"Scaling leak power by {attributes.leak_power_scale=}")
        if attributes.n_parallel_instances != 1:
            leak_power *= attributes.n_parallel_instances
            messages.append(f"Scaling leak power by {attributes.n_parallel_instances=}")
        self.attributes.leak_power = leak_power
        return self

    def calculate_area(
        self: T,
        models: list[ComponentModel] | None = None,
        in_place: bool = False,
    ) -> T:
        """
        Calculates the area for this component. If ``attributes.area`` is already set,
        that value is used. Otherwise, the area is read from the component model
        (populated on demand via ``populate_component_model``). In both cases the
        value is then multiplied by ``attributes.area_scale`` and
        ``attributes.n_parallel_instances`` and stored in ``attributes.area``. Extends
        the ``component_modeling_log`` field with log messages.

        Note that these methods will be called by the Spec when calculating
        energy and area. If you call them yourself, note that string expressions may not
        be parsed because they need the Spec's global scope. If you are sure
        that all necessary values are present and not a result of an expression, you can
        call these directly. Otherwise, you can call the
        ``Spec.calculate_component_area_energy_latency_leak`` and then grab components from
        the returned ``Spec``.

        NOTE(review): the scaled result is written back to ``attributes.area``, so
        calling this twice on the same object applies the scale factors twice.

        Parameters
        ----------
        models : list[ComponentModel] | None
            The models to use for area calculation. If not provided, the models will be
            found with ``get_models()``.
        in_place : bool
            If True, the component will be modified in place. Otherwise, a copy will be
            modified and returned.

        Returns
        -------
        T
            The component (``self`` if ``in_place``, otherwise a copy) with the
            calculated area.
        """
        if not in_place:
            self = self._copy_for_modeling()

        attributes = self.attributes
        messages = self.component_modeling_log
        if attributes.area is not None:
            area = attributes.area
            messages.append(f"Using predefined area value {attributes.area=}")
        else:
            self.populate_component_model(
                models,
                in_place=True,
                trying_to_calculate="attributes.area",
            )
            area = self.component_model.area
        if attributes.area_scale != 1:
            area *= attributes.area_scale
            messages.append(f"Scaling area by {attributes.area_scale=}")
        if attributes.n_parallel_instances != 1:
            area *= attributes.n_parallel_instances
            messages.append(f"Scaling area by {attributes.n_parallel_instances=}")
        self.attributes.area = area
        return self

    def calculate_action_latency(
        self: T,
        models: list[ComponentModel] | None = None,
        in_place: bool = False,
    ) -> T:
        """
        Calculates the latency for each action of this component. If
        ``<action>.arguments.latency`` is already set, that value is used. Otherwise,
        the latency is obtained from the component model (populated on demand via
        ``populate_component_model``). In both cases the value is then multiplied by
        ``attributes.latency_scale`` and ``<action>.arguments.latency_scale``, divided
        by ``attributes.n_parallel_instances``, and stored in
        ``<action>.arguments.latency``. Extends the ``component_modeling_log`` field
        with log messages.

        NOTE(review): the scaled result is written back to the field that is read as
        the predefined value, so calling this twice on the same object applies the
        scale factors twice.

        Parameters
        ----------
        models : list[ComponentModel] | None
            The models to use for latency calculation. If not provided, the models will
            be found with ``get_models()``.
        in_place : bool
            If True, the component will be modified in place. Otherwise, a copy will be
            modified and returned.

        Returns
        -------
        T
            The component (``self`` if ``in_place``, otherwise a copy) with the
            calculated latency for each action.
        """
        if not in_place:
            self = self._copy_for_modeling()

        messages = self.component_modeling_log
        attributes = self.attributes
        for action in self.actions:
            messages.append(
                f"Calculating latency for {self.name} action {action.name}."
            )
            args = action.arguments
            if args.latency is not None:
                latency = args.latency
                messages.append(f"Setting {self.name} latency to {args.latency=}")
            else:
                self.populate_component_model(
                    models,
                    in_place=True,
                    trying_to_calculate=f"arguments.latency for action {action.name}",
                )
                # Action arguments override same-named component attributes.
                estimation = self.component_model.try_call_arbitrary_action(
                    action_name=action.name,
                    _return_estimation_object=True,
                    **{**attributes.model_dump(), **args.model_dump()},
                )
                messages.extend(estimation.messages)
                # Element 1 of the result is used as latency; the energy calculation
                # reads element 0 of the same kind of result.
                latency = estimation.value[1]
            if attributes.latency_scale != 1:
                latency *= attributes.latency_scale
                messages.append(
                    f"Scaling {self.name} latency by {attributes.latency_scale=}"
                )
            if args.latency_scale != 1:
                latency *= args.latency_scale
                messages.append(f"Scaling {self.name} latency by {args.latency_scale=}")
            if attributes.n_parallel_instances != 1:
                latency /= attributes.n_parallel_instances
                messages.append(
                    f"Dividing {self.name} latency by {attributes.n_parallel_instances=}"
                )
            action.arguments.latency = latency
        return self

    def calculate_area_energy_latency_leak(
        self: T, models: list[ComponentModel] | None = None, in_place: bool = False
    ) -> T:
        """
        Calculates the area, energy, latency, and leak power for this component by
        calling, in order, ``calculate_area``, ``calculate_action_energy``,
        ``calculate_action_latency`` and ``calculate_leak_power``. Those calls
        populate ``attributes.area``, ``attributes.leak_power`` and, for each action,
        ``<action>.arguments.energy`` and ``<action>.arguments.latency``, and extend
        the ``component_modeling_log`` field with log messages.

        Note that these methods will be called by the Spec when calculating energy and
        area. If you call them yourself, note that string expressions may not be parsed
        because they need the Spec's global scope. If you are sure that all necessary
        values are present and not a result of an expression, you can call these
        directly. Otherwise, you can call the
        ``Spec.calculate_component_area_energy_latency_leak`` and then grab components
        from the returned ``Spec``.

        Parameters
        ----------
        models : list[ComponentModel] | None
            The models to use for the calculations. If not provided, the models will
            be found with ``get_models()``.
        in_place : bool
            If True, the component will be modified in place. Otherwise, a copy will be
            modified and returned.

        Returns
        -------
        T
            The component (``self`` if ``in_place``, otherwise a copy) with the
            calculated area, energy, latency, and leak power.
        """
        if not in_place:
            self = self._copy_for_modeling()
        # The copy (if any) was made above, so every step can work in place.
        self.calculate_area(models, in_place=True)
        self.calculate_action_energy(models, in_place=True)
        self.calculate_action_latency(models, in_place=True)
        self.calculate_leak_power(models, in_place=True)
        return self
924
+
925
+
926
class Container(Leaf):
    """A `Container` is an abstract node in the architecture that contains other nodes.
    For example, a `Container` may contain `Memory`s and `Compute` units.
    """

    # No fields or behavior of its own; everything is inherited from `Leaf`.
    pass
932
+
933
+
934
# Default actions for `TensorHolder` and `Memory`: "read" and "write".
MEMORY_ACTIONS = ParsableList[TensorHolderAction](
    [
        TensorHolderAction(name="read"),
        TensorHolderAction(name="write"),
    ]
)


# Default actions for `ProcessingStage`: only "read".
PROCESSING_STAGE_ACTIONS = ParsableList[TensorHolderAction](
    [
        TensorHolderAction(name="read"),
    ]
)

# Default actions for `Compute`: a single "compute" action.
COMPUTE_ACTIONS = ParsableList(
    [
        Action(name="compute"),
    ]
)
953
+
954
+
955
+ def _parse_tensor2bits(
956
+ to_parse: dict[str, Any], location: str, symbol_table: dict[str, Any]
957
+ ) -> dict[str, Any]:
958
+ result = {}
959
+ for key, value in to_parse.items():
960
+ if isinstance(value, Number):
961
+ result[key] = value
962
+ continue
963
+ result[key] = parse_expression(
964
+ expression=value,
965
+ symbol_table=symbol_table,
966
+ attr_name=key,
967
+ location=location,
968
+ )
969
+ return result
970
+
971
+
972
class TensorHolderAttributes(ComponentAttributes):
    """
    Attributes for a `TensorHolder`. `TensorHolder`s are components that hold tensors
    (usually `Memory`s). When specifying these attributes, it is recommended to
    underscore-prefix attribute names. See `TODO: UNDERSCORE_PREFIX_DISCUSSION`.
    """

    bits_per_value_scale: ParsesTo[dict | int | float] = {"All": 1}
    """
    A scaling factor for the bits per value of the tensors in this `TensorHolder`. If
    this is a dictionary, keys in the dictionary are parsed as expressions and may
    reference one or more tensors.
    """

    bits_per_action: ParsesTo[int | float | None] = None
    """
    The number of bits accessed in each of this component's actions. Overridden by
    bits_per_action in the action arguments. If set here, acts as a default value for
    the bits_per_action of all actions of this component.
    """

    def model_post_init(self, __context__=None) -> None:
        """Wrap a non-dictionary ``bits_per_value_scale`` as ``{"All": value}``."""
        scale = self.bits_per_value_scale
        if isinstance(scale, dict):
            return
        self.bits_per_value_scale = {"All": scale}

    def _parse_expressions(self, *args, **kwargs):
        """
        Parse through the parent implementation, additionally running the parsed
        ``bits_per_value_scale`` dictionary through ``_parse_tensor2bits``.
        """

        class MyPostCall(_PostCall):
            def __call__(self, field, value, parsed, symbol_table):
                if field != "bits_per_value_scale":
                    return parsed
                return _parse_tensor2bits(
                    parsed,
                    location="bits_per_value_scale",
                    symbol_table=symbol_table,
                )

        post_calls = (MyPostCall(),)
        return super()._parse_expressions(*args, post_calls=post_calls, **kwargs)
1009
+
1010
+
1011
class MemoryAttributes(TensorHolderAttributes):
    """Attributes for a `Memory`."""

    # Declared without a default value.
    size: ParsesTo[int | float]
    """ The size of this `Memory` in bits. """
1016
+
1017
+
1018
class Tensors(ParsableModel):
    """
    Fields that control which tensor(s) are kept in a :py:class:`~.TensorHolder` and in
    what order their nodes may appear in the mapping.
    """

    keep: ParsesTo[InvertibleSet[TensorName] | set[TensorName]] = (
        "<Defaults to Nothing>"
    )
    """
    A set expression describing which tensors must be kept in this
    :class:`accelforge.frontend.arch.TensorHolder`. If this is not defined, it
    defaults to ``Nothing``, i.e. no tensor is required to be kept.
    """

    may_keep: ParsesTo[InvertibleSet[TensorName] | set[TensorName]] = (
        "<Nothing if keep is defined, else All>"
    )
    """
    A set expression describing which tensors may optionally be kept in this
    :class:`accelforge.frontend.arch.TensorHolder`. The mapper will explore both keeping
    and not keeping each of these tensors. If this is not defined, it defaults to
    ``All`` when ``keep`` is not defined either, and to ``~All`` (nothing) when
    ``keep`` is defined.
    """

    tile_shape: ParsableList[Comparison] = []
    """
    The tile shape for each rank variable. This is given as a list of
    :class:`~.Comparison` objects, where each comparison must evaluate to True for a
    valid mapping.
    """

    no_refetch_from_above: ParsesTo[InvertibleSet[TensorName] | set[TensorName]] = (
        "~All"
    )
    """
    The tensors that are not allowed to be refetched from above. This is given as a set
    of :class:`~.TensorName` objects or a set expression that resolves to them. These
    tensors must be fetched at most one time from above memories, and may not be
    refetched across any temporal or spatial loop iterations. Tensors may be fetched in
    pieces (if they do not cause re-fetches of any piece).
    """

    tensor_order_options: ParsableList[
        ParsableList[ParsesTo[InvertibleSet[TensorName] | set[TensorName]]]
    ] = ParsableList()
    """
    Options for the order of tensor storage nodes in the mapping. This is given as a
    list-of-lists-of-sets. Each list-of-sets is a valid order of tensor storage nodes.
    Order is given from highest in the mapping to lowest.

    For example, an option could be [input | output, weight], which means that there is
    no relative ordering required between input and output, but weight must be below
    both.
    """

    force_memory_hierarchy_order: bool = True
    """
    If set to true, storage nodes for lower-level memories must be placed below storage
    nodes for higher-level memories. For example, all MainMemory storage nodes must go
    above all LocalBuffer storage nodes.

    This constraint always applies to same-tensor storage nodes (e.g., MainMemory
    reusing Output must go above LocalBuffer reusing Output); turning it off will permit
    things like MainMemory reusing Output going above LocalBuffer reusing Input.

    This is identical to the `force_memory_hierarchy_order` field in the `FFM` class,
    but only applies to this tensor holder.
    """

    def _parse_tensor_order_options(
        self, symbol_table: dict[str, Any], location: str
    ) -> "Tensors":
        """
        Evaluate every set expression in ``tensor_order_options``.

        Returns a new instance in which only ``tensor_order_options`` is set (all
        other fields keep their defaults).

        :raises ValueError: If two different entries of the same order option
            intersect.
        """
        result = type(self)(
            tensor_order_options=[
                [
                    eval_set_expression(x, symbol_table, "tensors", location)
                    for x in order_choice
                ]
                for order_choice in self.tensor_order_options
            ],
        )
        # Within one order option every tensor may appear in at most one entry.
        for order in result.tensor_order_options:
            for i, s0 in enumerate(order):
                for j, s1 in enumerate(order):
                    if i != j and s0 & s1:
                        raise ValueError(
                            f"Intersecting entries in dataflow constraint: {s0} and {s1}"
                        )
        return result

    def _parse_keep(self, symbol_table: dict[str, Any], location: str) -> "Tensors":
        """
        Resolve the ``keep`` / ``may_keep`` defaults and evaluate both expressions.

        Either expression may reference the other by name; the referenced one is
        evaluated first and made available to the second through a copy of
        ``symbol_table``. Returns a new instance in which only ``keep`` and
        ``may_keep`` are set (all other fields keep their defaults).

        :raises ValueError: If ``keep`` and ``may_keep`` reference each other.
        """
        keep, may_keep = self.keep, self.may_keep
        # Replace the sentinel defaults. ``may_keep`` must be resolved first because
        # its default depends on whether ``keep`` was left at its sentinel.
        if may_keep == "<Nothing if keep is defined, else All>":
            may_keep = "All" if keep == "<Defaults to Nothing>" else "~All"
        if keep == "<Defaults to Nothing>":
            keep = "Nothing"

        may_keep_first = isinstance(keep, str) and re.search(r"\bmay_keep\b", keep)
        keep_first = isinstance(may_keep, str) and re.search(r"\bkeep\b", may_keep)
        if keep_first and may_keep_first:
            raise ValueError(
                f"Keep and may_keep reference each other: {keep} and {may_keep}"
            )

        if may_keep_first:
            may_keep = eval_set_expression(may_keep, symbol_table, "tensors", location)
            # Copy so the caller's symbol table is left untouched.
            scope = copy.copy(symbol_table)
            scope["may_keep"] = may_keep
            keep = eval_set_expression(keep, scope, "tensors", location)
        else:
            keep = eval_set_expression(keep, symbol_table, "tensors", location)
            scope = copy.copy(symbol_table)
            scope["keep"] = keep
            may_keep = eval_set_expression(may_keep, scope, "tensors", location)
        return type(self)(keep=keep, may_keep=may_keep)

    def _parse_non_keep(self, symbol_table: dict[str, Any], location: str) -> "Tensors":
        """
        Parse ``tile_shape``, ``no_refetch_from_above`` and
        ``force_memory_hierarchy_order``.

        Returns a new instance in which only those three fields are set (all other
        fields keep their defaults).
        """
        return type(self)(
            tile_shape=[x._parse(symbol_table, location) for x in self.tile_shape],
            no_refetch_from_above=eval_set_expression(
                self.no_refetch_from_above, symbol_table, "tensors", location
            ),
            force_memory_hierarchy_order=parse_expression(
                self.force_memory_hierarchy_order,
                symbol_table,
                "force_memory_hierarchy_order",
                location,
            ),
        )
1152
+
1153
+
1154
@_uninstantiable
class TensorHolder(Component):
    """
    A `TensorHolder` is a component that holds tensors. These are usually `Memory`s,
    but can also be `ProcessingStage`s.
    """

    actions: ParsableList[TensorHolderAction] = MEMORY_ACTIONS
    """ The actions that this `TensorHolder` can perform. """

    attributes: TensorHolderAttributes = pydantic.Field(
        default_factory=TensorHolderAttributes
    )
    """ The `TensorHolderAttributes` that describe this `TensorHolder`. """

    tensors: Tensors = Tensors()
    """
    Fields that control which tensor(s) are kept in this `TensorHolder` and in what
    order their nodes may appear in the mapping.
    """

    def model_post_init(self, __context__=None) -> None:
        """Add any `MEMORY_ACTIONS` entry whose name is missing from ``actions``."""
        self._update_actions(MEMORY_ACTIONS)
1177
+
1178
+
1179
class Fanout(Leaf):
    """
    Creates a spatial fanout, and doesn't do anything else.
    """

    # A fresh `FanoutAttributes` is built for each instance when none is given.
    attributes: FanoutAttributes = pydantic.Field(default_factory=FanoutAttributes)
    """ Fanout attributes. Zero energy, leak power, area, and latency. """
1186
+
1187
+
1188
class Memory(TensorHolder):
    """A `Memory` is a `TensorHolder` that stores data over time, allowing for temporal
    reuse."""

    # NOTE(review): `MemoryAttributes.size` is declared without a default, so this
    # default factory presumably fails validation when `attributes` is omitted —
    # confirm.
    attributes: "MemoryAttributes" = pydantic.Field(default_factory=MemoryAttributes)
    """ The attributes of this `Memory`. """

    actions: ParsableList[TensorHolderAction] = MEMORY_ACTIONS
    """ The actions that this `Memory` can perform. """
1197
+
1198
+
1199
class ProcessingStageAttributes(TensorHolderAttributes):
    """Attributes for a `ProcessingStage`."""

    # Declared without a default value; one of "up", "down" or "up_and_down".
    direction: Literal["up", "down", "up_and_down"]
    """
    The direction in which data flows through this `ProcessingStage`. If "up", then data
    flows from below `TensorHolder`, through this `ProcessingStage` (plus paying
    associated costs), and then to the next `TensorHolder` above it. Other data
    movements are assumed to avoid this ProcessingStage.
    """
1209
+
1210
+
1211
class ProcessingStage(TensorHolder):
    """A `ProcessingStage` is a `TensorHolder` that does not store data over time, and
    therefore does not allow for temporal reuse. Use this as a toll that charges reads
    and writes every time a piece of data moves through it.

    Every write to a `ProcessingStage` is immediately written to the next `Memory`
    (which may be above or below depending on where the write came from), and same for
    reads.

    The access counts of a `ProcessingStage` are only included in the "read" action.
    Each traversal through the `ProcessingStage` is counted as a read. Writes are always
    zero.
    """

    attributes: ProcessingStageAttributes = pydantic.Field(
        default_factory=ProcessingStageAttributes
    )
    """ The attributes of this `ProcessingStage`. """

    actions: ParsableList[TensorHolderAction] = PROCESSING_STAGE_ACTIONS
    """ The actions that this `ProcessingStage` can perform. """

    def model_post_init(self, __context__=None) -> None:
        """
        Add any `PROCESSING_STAGE_ACTIONS` entry whose name is missing from
        ``actions``. This replaces the `TensorHolder` hook, so `MEMORY_ACTIONS` are
        not added here.
        """
        self._update_actions(PROCESSING_STAGE_ACTIONS)
1235
+
1236
+
1237
class ComputeAttributes(ComponentAttributes):
    """Attributes for a `Compute`."""

    # No compute-specific fields; everything is inherited from `ComponentAttributes`.
    pass
1241
+
1242
+
1243
class Compute(Component):
    """A `Component` whose default action list is `COMPUTE_ACTIONS`."""

    actions: ParsableList[Action] = COMPUTE_ACTIONS
    """ The actions that this `Compute` can perform. """

    attributes: ComputeAttributes = pydantic.Field(default_factory=ComputeAttributes)
    """ The attributes of this `Compute`. """

    def model_post_init(self, __context__=None) -> None:
        """Add any `COMPUTE_ACTIONS` entry whose name is missing from ``actions``."""
        self._update_actions(COMPUTE_ACTIONS)
1252
+
1253
+
1254
# Generic type variable. `Branch.get_nodes_of_type` uses it to tie the requested
# type(s) to the type of the yielded nodes.
# NOTE(review): `T` also appears in the `self: T` annotations of `Component`
# methods defined above this line; presumably annotation evaluation is deferred or
# `T` is also defined earlier in the file — confirm.
T = TypeVar("T")
1255
+
1256
+
1257
@_uninstantiable
class Branch(ArchNode):
    """An architecture node that holds a list of child nodes in ``nodes``."""

    # nodes: ArchNodes[_InferFromTag[Compute, Memory, "Hierarchical"]] = ArchNodes()
    nodes: ArchNodes[
        Annotated[
            Union[
                Annotated[Compute, Tag("Compute")],
                Annotated[Memory, Tag("Memory")],
                Annotated[ProcessingStage, Tag("ProcessingStage")],
                Annotated[Fanout, Tag("Fanout")],
                Annotated["Parallel", Tag("Parallel")],
                Annotated["Hierarchical", Tag("Hierarchical")],
            ],
            Discriminator(_get_tag),
        ]
    ] = ArchNodes()

    def get_nodes_of_type(self, types: Type[T] | tuple[Type[T], ...]) -> Iterator[T]:
        """
        Yield, in order, every node under this branch that is an instance of
        ``types``.

        A child that matches is yielded as-is and is not searched further; a child
        that does not match but is itself a `Branch` is searched recursively.
        """
        for child in self.nodes:
            if isinstance(child, types):
                yield child
                continue
            if isinstance(child, Branch):
                yield from child.get_nodes_of_type(types)
1280
+
1281
+
1282
class Parallel(Branch):
    """A `Branch` that is flattened by following only one of its children."""

    def _flatten(
        self,
        attributes: dict,
        compute_node: str,
        fanout: int = 1,
        return_fanout: bool = False,
    ):
        """Flatten the first child that leads to `compute_node`.

        Children are scanned in order. The scan stops at the first child that is
        either the `Compute` named `compute_node`, or a `Branch` containing a
        `Compute` with that name (which is flattened recursively).

        Parameters
        ----------
        attributes
            Inherited attributes. Annotated as `dict`, but `model_dump()` is
            called on it, so it must be an object that provides that method.
        compute_node
            Name of the `Compute` to flatten towards.
        fanout
            Fanout accumulated by the caller before reaching this node.
        return_fanout
            If true, return `(nodes, fanout)`; otherwise return only `nodes`.

        Returns
        -------
        list or tuple
            The flattened node copies, plus the accumulated fanout when
            `return_fanout` is true.

        Raises
        ------
        ParseError
            If no child is or contains the requested compute node.
        """
        nodes = []

        def _parse_node(node: Leaf, fanout: int):
            # Work on a copy so the architecture's own node is left untouched.
            # The node's own attribute values override the inherited ones.
            fanout *= node.get_fanout()
            node2 = node.model_copy()
            node2.attributes = type(node.attributes)(
                **{**attributes.model_dump(), **node.attributes.model_dump()}
            )
            nodes.append(node2)
            return fanout

        for node in self.nodes:
            if isinstance(node, Compute) and node.name == compute_node:
                fanout = _parse_node(node, fanout)
                break
            if isinstance(node, Branch) and any(
                c.name == compute_node for c in node.get_nodes_of_type(Compute)
            ):
                new_nodes, new_fanout = node._flatten(
                    attributes, compute_node, fanout, return_fanout=True
                )
                nodes.extend(new_nodes)
                # The child was seeded with our running fanout, so the value it
                # returns is already cumulative. Multiplying here would count
                # the fanout above this node twice.
                fanout = new_fanout
                break
        else:
            raise ParseError(f"Compute node {compute_node} not found in parallel node")

        # Written as an explicit branch: `return nodes, fanout if c else nodes`
        # parses as `(nodes, (fanout if c else nodes))` and always yields a tuple.
        if return_fanout:
            return nodes, fanout
        return nodes
1318
+
1319
+
1320
class Hierarchical(Branch):
    """A `Branch` whose children are walked in order when it is flattened."""

    def _flatten(
        self,
        attributes: dict,
        compute_node: str,
        fanout: int = 1,
        return_fanout: bool = False,
    ):
        """Flatten this hierarchy into the leaves leading to `compute_node`.

        Children are visited in order and handled by type:

        * `Hierarchical` / `Parallel`: flattened recursively; the walk stops if
          the result contains the requested `Compute`.
        * `Compute`: appended only if its name is `compute_node`, in which case
          the walk stops. Other `Compute` nodes are skipped.
        * `Container`: contributes its fanout but is not appended.
        * any other `Leaf`: a copy is appended and its fanout is applied.

        Parameters
        ----------
        attributes
            Inherited attributes. Annotated as `dict`, but `model_dump()` is
            called on it, so it must be an object that provides that method.
        compute_node
            Name of the `Compute` at which flattening stops.
        fanout
            Fanout accumulated by the caller before reaching this node.
        return_fanout
            If true, return `(nodes, fanout)`; otherwise return only `nodes`.

        Returns
        -------
        list or tuple
            The flattened node copies, plus the accumulated fanout when
            `return_fanout` is true.

        Raises
        ------
        ParseError
            If a `Parallel` child is not the last child. Every `ParseError`
            raised while handling a child gets that child attached through
            `add_field` before it is re-raised.
        TypeError
            If a child is of a type that cannot be flattened.
        """
        nodes = []

        def _parse_node(node: Leaf, fanout: int):
            # Work on a copy so the architecture's own node is left untouched.
            fanout *= node.get_fanout()
            node2 = node.model_copy()
            attrs = {**node.attributes.model_dump()}
            # Inherited attributes are merged in only for `AttributesWithExtras`;
            # the node's own values take priority over the inherited ones.
            if isinstance(node.attributes, AttributesWithExtras):
                attrs = {**attributes.model_dump(), **attrs}
            node2.attributes = type(node.attributes)(**attrs)
            nodes.append(node2)
            return fanout

        last_index = len(self.nodes) - 1
        for i, node in enumerate(self.nodes):
            try:
                if isinstance(node, (Hierarchical, Parallel)):
                    if isinstance(node, Parallel) and i < last_index:
                        raise ParseError(
                            f"Parallel node {node.name} must be the last node in a "
                            "hierarchical node"
                        )
                    new_nodes, new_fanout = node._flatten(
                        attributes, compute_node, fanout, return_fanout=True
                    )
                    nodes.extend(new_nodes)
                    # The child was seeded with our running fanout, so the value
                    # it returns is already cumulative. Multiplying here would
                    # count the fanout above this node twice.
                    fanout = new_fanout
                    if any(
                        isinstance(n, Compute) and n.name == compute_node
                        for n in new_nodes
                    ):
                        break
                elif isinstance(node, Compute):
                    if node.name == compute_node:
                        fanout = _parse_node(node, fanout)
                        break
                elif isinstance(node, Leaf) and not isinstance(node, Container):
                    fanout = _parse_node(node, fanout)
                elif isinstance(node, Container):
                    fanout *= node.get_fanout()
                else:
                    raise TypeError(f"Can't flatten {node}")
            except ParseError as e:
                # Record which child failed, then let the same exception go on.
                e.add_field(node)
                raise

        if return_fanout:
            return nodes, fanout
        return nodes

    def render(self) -> str:
        """Render a placeholder graph (one box labelled "TODO: Arch Render") as SVG.

        The SVG text produced by Graphviz `dot` is wrapped in `_SVGJupyterRender`.
        """
        graph = _pydot_graph()
        graph.add_node(pydot.Node("root", shape="box", label="TODO: Arch Render"))
        return _SVGJupyterRender(graph.create_svg(prog="dot").decode("utf-8"))

    def _repr_svg_(self) -> str:
        """Return `render()` so that notebooks can display this node as SVG."""
        return self.render()
1383
+
1384
+
1385
class _ConstraintLambda:
    """Callable that pairs a `Comparison` with the mapping nodes it targets.

    `_target_node_indices` and `_target_loop_indices` start out as `None`; this
    class never assigns them, so they are presumably filled in by other code.
    """

    def __init__(
        self,
        constraint: Comparison,
        target_mapping_nodes: list[Spatial],
        rank_variables: set[str],
    ):
        self.constraint = constraint
        # Subclasses may pass `None`, in which case there is nothing to compile.
        if constraint is None:
            self.constraint_lambda = None
        else:
            self.constraint_lambda = constraint._to_constraint_lambda(True)
        self.target_mapping_nodes = target_mapping_nodes
        self.rank_variables = rank_variables
        self._target_node_indices = None
        self._target_loop_indices = None

    def __call__(self, rank_variables: set[RankVariable], sizes: np.ndarray) -> bool:
        """Evaluate the compiled constraint on `sizes`.

        The first argument handed to the compiled constraint is true when every
        rank variable of this constraint is contained in `rank_variables`.
        """
        all_present = self.rank_variables.issubset(rank_variables)
        return self.constraint_lambda(all_present, sizes)

    def _constrained_node_str(self) -> str:
        """Return the "constrains ..." suffix shared by the `pretty_str` methods."""
        indices = self._target_node_indices
        return f"constrains {indices}"

    def __bool__(self) -> bool:
        """A constraint is truthy only if it targets at least one mapping node."""
        return True if self.target_mapping_nodes else False
1410
+
1411
+
1412
class _TileShapeConstraintLambda(_ConstraintLambda):
    """`_ConstraintLambda` whose description is labelled "Tile shape"."""

    def pretty_str(self) -> str:
        """Return a one-line description of this constraint."""
        comparison = self.constraint
        operator, value = comparison.operator, comparison.value
        suffix = self._constrained_node_str()
        return f"Tile shape {operator} {value} {suffix}"
1415
+
1416
+
1417
class _LoopBoundsConstraintLambda(_ConstraintLambda):
    """`_ConstraintLambda` whose description is labelled "Loop bounds"."""

    def pretty_str(self) -> str:
        """Return a one-line description of this constraint."""
        comparison = self.constraint
        operator, value = comparison.operator, comparison.value
        suffix = self._constrained_node_str()
        return f"Loop bounds {operator} {value} {suffix}"
1420
+
1421
+
1422
class _MinUtilizationConstraintLambda(_ConstraintLambda):
    """Constraint that keeps candidates whose utilization reaches `min_usage`.

    No `Comparison` is involved: the base class is initialised with `None` as
    its constraint and the filtering is done directly in `__call__`.
    """

    def __init__(
        self,
        target_mapping_nodes: list[Spatial],
        rank_variables: set[str],
        min_usage: float,
    ):
        super().__init__(None, target_mapping_nodes, rank_variables)
        self.min_usage = min_usage

    def __call__(
        self, complete_indices: list[int], utilizations: np.ndarray
    ) -> np.ndarray:
        """Return a boolean mask selecting the utilizations to keep.

        Parameters
        ----------
        complete_indices
            Loop indices that have been completed so far.
        utilizations
            Utilization values to filter.

        Returns
        -------
        np.ndarray
            * All true (one entry per row of `utilizations`) while some target
              loop index is still missing from `complete_indices`.
            * Otherwise, true where `utilizations >= min_usage`.
            * If nothing reaches `min_usage`, true where the utilization equals
              the maximum along axis 0, i.e. the best available candidates.
        """
        # NOTE(review): `_target_loop_indices` is `None` after `__init__`; it
        # must have been assigned elsewhere before the first call.
        final = set(self._target_loop_indices).issubset(set(complete_indices))
        if not final:
            # The builtin `bool` is used because the `np.bool` alias does not
            # exist on NumPy 1.24-1.26.
            return np.ones(utilizations.shape[0], dtype=bool)

        # Some utilizations are already above the minimum. Return those.
        result = utilizations >= self.min_usage
        if np.any(result):
            return result

        # Nothing to rank: `np.max` raises on an empty array, so hand back the
        # (empty) mask instead.
        if np.size(utilizations) == 0:
            return result

        # Nobody is above the minimum. Return the best we can do.
        max_utilization = np.max(utilizations, axis=0)
        return utilizations == max_utilization

    def pretty_str(self) -> str:
        """Return a one-line description of this constraint."""
        return f"Min utilization {self.min_usage} {self._constrained_node_str()}"
1449
+
1450
+
1451
class Arch(Hierarchical):
    """Top-level `Hierarchical` node of an architecture.

    Adds area and leak-power summaries over all `Component` nodes, lookup of a
    leaf by name, and a check that leaf names are unique.
    """

    # version: Annotated[str, assert_version] = __version__
    # """ The version of the architecture spec. """

    @property
    def total_area(self) -> float:
        """
        Returns the total area of the architecture in m^2.

        Returns
        -------
        float
            The total area of the architecture in m^2.
        """
        return sum(self.per_component_total_area.values())

    @property
    def total_leak_power(self) -> float:
        """
        Returns the total leak power of the architecture in W.

        Returns
        -------
        float
            The total leak power of the architecture in W.
        """
        return sum(self.per_component_total_leak_power.values())

    @property
    def per_component_total_area(self) -> dict[str, float]:
        """
        Returns the total area used by each component in the architecture in m^2.

        Returns
        -------
        dict[str, float]
            A dictionary of component names to their total area in m^2.

        Raises
        ------
        ValueError
            If the area of any component is still `None`.
        """
        area = {
            node.name: node.attributes.total_area
            for node in self.get_nodes_of_type(Component)
        }
        for name, value in area.items():
            if value is None:
                raise ValueError(
                    f"Area of {name} is not set. Please call the Spec's "
                    "`calculate_component_area_energy_latency_leak` method before accessing this "
                    "property."
                )
        return area

    @property
    def per_component_total_leak_power(self) -> dict[str, float]:
        """
        Returns the total leak power of each component in the architecture in W.

        Returns
        -------
        dict[str, float]
            A dictionary of component names to their total leak power in W.

        Raises
        ------
        ValueError
            If the leak power of any component is still `None`.
        """
        leak_power = {
            node.name: node.attributes.total_leak_power
            for node in self.get_nodes_of_type(Component)
        }
        for name, value in leak_power.items():
            if value is None:
                raise ValueError(
                    f"Leak power of {name} is not set. Please call the Spec's "
                    "`calculate_component_area_energy_latency_leak` method before accessing this "
                    "property."
                )
        return leak_power

    def _parse_expressions(self, *args, **kwargs):
        """Register every leaf in the symbol table, then defer to the parent.

        `symbol_table` must be passed as a keyword argument. Each `Leaf` is
        stored in it under the leaf's name before parsing continues.
        """
        symbol_table = kwargs["symbol_table"]
        for node in self.get_nodes_of_type(Leaf):
            symbol_table[node.name] = node
        return super()._parse_expressions(*args, **kwargs)

    def __getitem__(self, name: str) -> Leaf:
        """Look up a leaf by name via `name2leaf`."""
        return self.name2leaf(name)

    def model_post_init(self, __context__=None) -> None:
        """Check that no two distinct leaves share a name.

        Raises
        ------
        AssertionError
            If two different leaf objects have the same name.
        """
        # Raised explicitly instead of through `assert`, which `python -O`
        # removes and would silently disable this check.
        first_seen = {}
        for leaf in self.get_nodes_of_type(Leaf):
            name = leaf.name
            if first_seen.setdefault(name, leaf) is not leaf:
                raise AssertionError(f"Duplicate name {name} found in architecture")
1541
+
1542
+
1543
# `Branch.nodes` refers to "Parallel" and "Hierarchical" by name before either
# class is defined. Now that both exist, rebuild the model so that those
# forward references can be resolved.
Branch.model_rebuild()