accelforge 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of accelforge might be problematic. Click here for more details.

Files changed (258)
  1. accelforge/__init__.py +21 -0
  2. accelforge/_accelerated_imports.py +16 -0
  3. accelforge/_deprecate/_simanneal/evalmapping.py +271 -0
  4. accelforge/_deprecate/_simanneal/mapspaceglobals.py +298 -0
  5. accelforge/_deprecate/_simanneal/simanneal.py +666 -0
  6. accelforge/_deprecate/_simanneal/tracking.py +105 -0
  7. accelforge/_deprecate/_simanneal/wrappers.py +218 -0
  8. accelforge/_deprecate/_simanneal2/__init__.py +7 -0
  9. accelforge/_deprecate/_simanneal2/simanneal.py +493 -0
  10. accelforge/_deprecate/_simanneal2/tracking.py +116 -0
  11. accelforge/_deprecate/compatibility_util.py +181 -0
  12. accelforge/_deprecate/layerdeduplication/__init__.py +2 -0
  13. accelforge/_deprecate/layerdeduplication/group_similar_einsums.py +160 -0
  14. accelforge/_deprecate/layerdeduplication/grouped_einsums.py +84 -0
  15. accelforge/_deprecate/mapping_filter_tags/__init__.py +2 -0
  16. accelforge/_deprecate/mapping_filter_tags/ffmt.py +212 -0
  17. accelforge/_deprecate/mapping_filter_tags/onesplit.py +24 -0
  18. accelforge/_deprecate/mapping_filter_tags/util.py +24 -0
  19. accelforge/_deprecate/tags.py +69 -0
  20. accelforge/_deprecate/viz/__init__.py +0 -0
  21. accelforge/_deprecate/viz/interactive.py +159 -0
  22. accelforge/_deprecate/viz/reservationtree.py +307 -0
  23. accelforge/_deprecate/viz/ski_slope.py +88 -0
  24. accelforge/_version.py +15 -0
  25. accelforge/examples.py +39 -0
  26. accelforge/frontend/__init__.py +10 -0
  27. accelforge/frontend/_binding.py +129 -0
  28. accelforge/frontend/_workload_isl/__init__.py +2 -0
  29. accelforge/frontend/_workload_isl/_isl.py +149 -0
  30. accelforge/frontend/_workload_isl/_symbolic.py +141 -0
  31. accelforge/frontend/arch copy.py +1544 -0
  32. accelforge/frontend/arch.py +1642 -0
  33. accelforge/frontend/config.py +63 -0
  34. accelforge/frontend/mapper/__init__.py +5 -0
  35. accelforge/frontend/mapper/ffm.py +126 -0
  36. accelforge/frontend/mapper/mapper.py +7 -0
  37. accelforge/frontend/mapper/metrics.py +30 -0
  38. accelforge/frontend/mapping/__init__.py +1 -0
  39. accelforge/frontend/mapping/mapping.py +1736 -0
  40. accelforge/frontend/model.py +14 -0
  41. accelforge/frontend/renames.py +150 -0
  42. accelforge/frontend/spec copy.py +230 -0
  43. accelforge/frontend/spec.py +301 -0
  44. accelforge/frontend/variables.py +12 -0
  45. accelforge/frontend/workload.py +952 -0
  46. accelforge/mapper/FFM/__init__.py +9 -0
  47. accelforge/mapper/FFM/_join_pmappings/__init__.py +0 -0
  48. accelforge/mapper/FFM/_join_pmappings/compatibility.py +653 -0
  49. accelforge/mapper/FFM/_join_pmappings/compress_pmappings.py +140 -0
  50. accelforge/mapper/FFM/_join_pmappings/join_pmappings.py +703 -0
  51. accelforge/mapper/FFM/_join_pmappings/pmapping_dataframe.py +901 -0
  52. accelforge/mapper/FFM/_join_pmappings/pmapping_group.py +337 -0
  53. accelforge/mapper/FFM/_make_pmappings/contraints/__init__.py +0 -0
  54. accelforge/mapper/FFM/_make_pmappings/contraints/constraints.py +360 -0
  55. accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/__init__.py +1 -0
  56. accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_loops.py +373 -0
  57. accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_pmapping_templates.py +463 -0
  58. accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_reservations.py +95 -0
  59. accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_storage_order.py +382 -0
  60. accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_storages.py +155 -0
  61. accelforge/mapper/FFM/_make_pmappings/make_pmappings.py +411 -0
  62. accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/__init__.py +1 -0
  63. accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/make_pmappings_from_templates.py +407 -0
  64. accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/make_tile_shapes.py +1681 -0
  65. accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/run_model.py +170 -0
  66. accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/symbol_relations.py +174 -0
  67. accelforge/mapper/FFM/_make_pmappings/pmapper_job.py +282 -0
  68. accelforge/mapper/FFM/_pareto_df/df_convention.py +273 -0
  69. accelforge/mapper/FFM/_pareto_df/pareto copy.py +836 -0
  70. accelforge/mapper/FFM/_pareto_df/pareto.py +508 -0
  71. accelforge/mapper/FFM/data.py +61 -0
  72. accelforge/mapper/FFM/main copy.py +236 -0
  73. accelforge/mapper/FFM/main.py +208 -0
  74. accelforge/mapper/FFM/mappings.py +510 -0
  75. accelforge/mapper/FFM/pmappings.py +310 -0
  76. accelforge/mapper/__init__.py +4 -0
  77. accelforge/mapper.py +0 -0
  78. accelforge/model/__init__.py +1 -0
  79. accelforge/model/_looptree/__init__.py +0 -0
  80. accelforge/model/_looptree/accesses.py +335 -0
  81. accelforge/model/_looptree/capacity/__init__.py +1 -0
  82. accelforge/model/_looptree/capacity/aggregators.py +36 -0
  83. accelforge/model/_looptree/capacity/capacity.py +47 -0
  84. accelforge/model/_looptree/energy.py +150 -0
  85. accelforge/model/_looptree/equivalent_ranks.py +29 -0
  86. accelforge/model/_looptree/latency/__init__.py +1 -0
  87. accelforge/model/_looptree/latency/latency.py +98 -0
  88. accelforge/model/_looptree/latency/memory.py +120 -0
  89. accelforge/model/_looptree/latency/processors.py +92 -0
  90. accelforge/model/_looptree/mapping_utilities.py +71 -0
  91. accelforge/model/_looptree/reuse/__init__.py +4 -0
  92. accelforge/model/_looptree/reuse/isl/__init__.py +1 -0
  93. accelforge/model/_looptree/reuse/isl/des.py +59 -0
  94. accelforge/model/_looptree/reuse/isl/isl_functions.py +374 -0
  95. accelforge/model/_looptree/reuse/isl/mapping_to_isl/__init__.py +4 -0
  96. accelforge/model/_looptree/reuse/isl/mapping_to_isl/analyze_mapping.py +297 -0
  97. accelforge/model/_looptree/reuse/isl/mapping_to_isl/skews_from_mapping.py +236 -0
  98. accelforge/model/_looptree/reuse/isl/mapping_to_isl/tiling.py +685 -0
  99. accelforge/model/_looptree/reuse/isl/mapping_to_isl/types.py +188 -0
  100. accelforge/model/_looptree/reuse/isl/spatial.py +260 -0
  101. accelforge/model/_looptree/reuse/isl/temporal.py +182 -0
  102. accelforge/model/_looptree/reuse/symbolic/__init__.py +1 -0
  103. accelforge/model/_looptree/reuse/symbolic/symbolic copy 2.py +1346 -0
  104. accelforge/model/_looptree/reuse/symbolic/symbolic copy.py +1408 -0
  105. accelforge/model/_looptree/reuse/symbolic/symbolic.py +1396 -0
  106. accelforge/model/_looptree/run.py +122 -0
  107. accelforge/model/_looptree/types.py +26 -0
  108. accelforge/model/_looptree/visualization/__init__.py +0 -0
  109. accelforge/model/_looptree/visualization/occupancy.py +11 -0
  110. accelforge/model/main.py +222 -0
  111. accelforge/plotting/__init__.py +2 -0
  112. accelforge/plotting/mappings.py +219 -0
  113. accelforge/plotting/specs.py +57 -0
  114. accelforge/util/__init__.py +4 -0
  115. accelforge/util/_base_analysis_types.py +24 -0
  116. accelforge/util/_basetypes.py +1089 -0
  117. accelforge/util/_frozenset.py +36 -0
  118. accelforge/util/_isl.py +29 -0
  119. accelforge/util/_itertools.py +14 -0
  120. accelforge/util/_mathfuncs.py +57 -0
  121. accelforge/util/_parse_expressions.py +339 -0
  122. accelforge/util/_picklecache.py +32 -0
  123. accelforge/util/_setexpressions.py +268 -0
  124. accelforge/util/_sympy/__init__.py +0 -0
  125. accelforge/util/_sympy/broadcast_max.py +18 -0
  126. accelforge/util/_visualization.py +112 -0
  127. accelforge/util/_yaml.py +579 -0
  128. accelforge/util/parallel.py +193 -0
  129. accelforge-0.0.1.dist-info/METADATA +64 -0
  130. accelforge-0.0.1.dist-info/RECORD +258 -0
  131. accelforge-0.0.1.dist-info/WHEEL +5 -0
  132. accelforge-0.0.1.dist-info/licenses/LICENSE +19 -0
  133. accelforge-0.0.1.dist-info/top_level.txt +5 -0
  134. docs/_build/html/_sources/fastfusion.frontend.mapper.rst.txt +37 -0
  135. docs/_build/html/_sources/fastfusion.frontend.rst.txt +70 -0
  136. docs/_build/html/_sources/fastfusion.frontend.workload.rst.txt +21 -0
  137. docs/_build/html/_sources/fastfusion.mapper.FFM.rst.txt +37 -0
  138. docs/_build/html/_sources/fastfusion.mapper.rst.txt +18 -0
  139. docs/_build/html/_sources/fastfusion.rst.txt +20 -0
  140. docs/_build/html/_sources/fastfusion.util.rst.txt +21 -0
  141. docs/_build/html/_sources/index.rst.txt +87 -0
  142. docs/_build/html/_sources/modules.rst.txt +7 -0
  143. docs/_build/html/_sources/notes/citation.rst.txt +45 -0
  144. docs/_build/html/_sources/notes/definitions.rst.txt +43 -0
  145. docs/_build/html/_sources/notes/faqs.rst.txt +39 -0
  146. docs/_build/html/_sources/notes/modeling/accelerator_energy_latency.rst.txt +72 -0
  147. docs/_build/html/_sources/notes/modeling/component_energy_area.rst.txt +96 -0
  148. docs/_build/html/_sources/notes/modeling/mapping.rst.txt +100 -0
  149. docs/_build/html/_sources/notes/modeling.rst.txt +33 -0
  150. docs/_build/html/_sources/notes/parsing/arithmetic_parsing.rst.txt +136 -0
  151. docs/_build/html/_sources/notes/parsing/setexpressions.rst.txt +63 -0
  152. docs/_build/html/_sources/notes/parsing/yaml_parsing.rst.txt +176 -0
  153. docs/_build/html/_sources/notes/quickstart_and_installation.rst.txt +9 -0
  154. docs/_build/html/_sources/notes/spec/architecture.rst.txt +133 -0
  155. docs/_build/html/_sources/notes/spec/mapping.rst.txt +12 -0
  156. docs/_build/html/_sources/notes/spec/workload.rst.txt +83 -0
  157. docs/_build/html/_sources/notes/spec.rst.txt +36 -0
  158. docs/source/_ext/include_attrs.py +213 -0
  159. docs/source/_ext/include_docstring.py +364 -0
  160. docs/source/_ext/include_functions.py +154 -0
  161. docs/source/_ext/include_notebook.py +131 -0
  162. docs/source/_ext/include_yaml.py +119 -0
  163. docs/source/_ext/inherited_attributes.py +222 -0
  164. docs/source/_ext/paths.py +4 -0
  165. docs/source/conf.py +79 -0
  166. examples/arches/compute_in_memory/_include.yaml +74 -0
  167. examples/arches/compute_in_memory/_include_functions.py +229 -0
  168. examples/arches/compute_in_memory/_load_spec.py +57 -0
  169. examples/arches/compute_in_memory/components/c2c_multiplier.py +181 -0
  170. examples/arches/compute_in_memory/components/dac_c2c_r2r.py +605 -0
  171. examples/arches/compute_in_memory/components/misc.py +195 -0
  172. examples/arches/compute_in_memory/components/util/bit_functions.py +51 -0
  173. examples/arches/compute_in_memory/components/zero_comparator.py +92 -0
  174. examples/arches/compute_in_memory/isaac.yaml +233 -0
  175. examples/arches/compute_in_memory/memory_cells/ecram_demo.yaml +63 -0
  176. examples/arches/compute_in_memory/memory_cells/rram_example.yaml +63 -0
  177. examples/arches/compute_in_memory/memory_cells/rram_isaac_isca_2016.yaml +64 -0
  178. examples/arches/compute_in_memory/memory_cells/rram_neurosim_default.yaml +63 -0
  179. examples/arches/compute_in_memory/memory_cells/rram_raella_isca_2023.yaml +70 -0
  180. examples/arches/compute_in_memory/memory_cells/rram_wan_nature_2022.yaml +63 -0
  181. examples/arches/compute_in_memory/memory_cells/sram_colonnade_jssc_2021.yaml +63 -0
  182. examples/arches/compute_in_memory/memory_cells/sram_example.yaml +63 -0
  183. examples/arches/compute_in_memory/memory_cells/sram_jia_jssc_2020.yaml +63 -0
  184. examples/arches/compute_in_memory/memory_cells/sram_sinangil_jssc_2021.yaml +63 -0
  185. examples/arches/compute_in_memory/memory_cells/sram_wang_vlsi_2022.yaml +63 -0
  186. examples/arches/compute_in_memory/wang_vlsi_2022.yaml +289 -0
  187. examples/arches/eyeriss.yaml +68 -0
  188. examples/arches/fanout_variations/at_glb.yaml +31 -0
  189. examples/arches/fanout_variations/at_glb_with_fanout_node.yaml +34 -0
  190. examples/arches/fanout_variations/at_mac.yaml +31 -0
  191. examples/arches/fanout_variations/at_mac_with_constraints.yaml +38 -0
  192. examples/arches/fanout_variations/at_mac_with_fanout_node.yaml +34 -0
  193. examples/arches/nvdla.yaml +47 -0
  194. examples/arches/simple.yaml +28 -0
  195. examples/arches/tpu_v4i.yaml +67 -0
  196. examples/mappings/unfused_matmuls_to_simple.yaml +33 -0
  197. examples/misc/component_annotated.yaml +33 -0
  198. examples/workloads/gpt3_6.7B.yaml +124 -0
  199. examples/workloads/matmuls.yaml +20 -0
  200. examples/workloads/mobilenet_28.yaml +81 -0
  201. examples/workloads/mobilenet_various_separate.yaml +106 -0
  202. examples/workloads/three_matmuls_annotated.yaml +59 -0
  203. notebooks/.ipynb_checkpoints/fastfusion_arch_study_michael-checkpoint.ipynb +359 -0
  204. notebooks/compute_in_memory/_scripts.py +339 -0
  205. notebooks/compute_in_memory/isaac.guide.ipynb +270 -0
  206. notebooks/compute_in_memory/wang_vlsi_2022.ipynb +602 -0
  207. notebooks/paths.py +4 -0
  208. notebooks/tutorials/.ipynb_checkpoints/1_FFM-checkpoint.ipynb +3110 -0
  209. notebooks/tutorials/FFM.ipynb +3498 -0
  210. notebooks/tutorials/_include.py +48 -0
  211. notebooks/tutorials/component_energy_area.ipynb +363 -0
  212. tests/Q_mapping.yaml +38 -0
  213. tests/__init__.py +0 -0
  214. tests/conv.mapping.yaml +27 -0
  215. tests/conv.workload.yaml +13 -0
  216. tests/conv_sym.mapping.yaml +43 -0
  217. tests/copy.mapping.yaml +35 -0
  218. tests/copy.workload.yaml +15 -0
  219. tests/distribuffers/__init__.py +0 -0
  220. tests/distribuffers/multicast/test_cases.yaml +482 -0
  221. tests/distribuffers/spec/binding/valid_bindings.yaml +97 -0
  222. tests/distribuffers/spec/distributed.yaml +100 -0
  223. tests/distribuffers/spec/logical_arch.yaml +32 -0
  224. tests/distribuffers/spec/physical_arch.yaml +69 -0
  225. tests/distribuffers/test_binding.py +48 -0
  226. tests/frontend/__init__.py +0 -0
  227. tests/frontend/test_mapping_viz.py +52 -0
  228. tests/mapper/__init__.py +0 -0
  229. tests/mapper/configs/conv1d/conv1d.mapping.yaml +31 -0
  230. tests/mapper/configs/conv1d/conv1d.workload.yaml +11 -0
  231. tests/mapper/configs/two_conv1d/two_conv1d.expected.yaml +38 -0
  232. tests/mapper/configs/two_conv1d/two_conv1d.mapping.yaml +54 -0
  233. tests/mapper/configs/two_conv1d/two_conv1d.workload.yaml +19 -0
  234. tests/mapper/test_mapping_to_isl.py +90 -0
  235. tests/mapper/test_spatial_reuse_analysis.py +67 -0
  236. tests/mapper/test_temporal_reuse_analysis.py +56 -0
  237. tests/mapper/util.py +58 -0
  238. tests/matmul.mapping.yaml +29 -0
  239. tests/matmul.workload.yaml +12 -0
  240. tests/matmul_spatial.mapping.yaml +44 -0
  241. tests/mha.renames.yaml +65 -0
  242. tests/mha.workload.yaml +67 -0
  243. tests/mha.yaml +59 -0
  244. tests/mha_full.workload.yaml +67 -0
  245. tests/mobilenet.workload.yaml +35 -0
  246. tests/mobilenet_long.workload.yaml +64 -0
  247. tests/pmappingcache.py +24 -0
  248. tests/processing_stage.arch.yaml +40 -0
  249. tests/snowcat.arch.yaml +36 -0
  250. tests/test_ffm_join_pmappings.py +106 -0
  251. tests/test_ffm_make_pmappings.py +82 -0
  252. tests/test_ffm_make_tile_shapes.py +49 -0
  253. tests/test_mapper.py +100 -0
  254. tests/test_model.py +37 -0
  255. tests/test_plotting.py +72 -0
  256. tests/test_processing_stage.py +46 -0
  257. tests/test_symbolic_model.py +248 -0
  258. tests/test_workload.py +141 -0
@@ -0,0 +1,1642 @@
1
+ import copy
2
+ import itertools
3
+ import logging
4
+ import math
5
+ from numbers import Number
6
+ import re
7
+ from typing import (
8
+ Any,
9
+ Callable,
10
+ Iterator,
11
+ List,
12
+ Literal,
13
+ Optional,
14
+ Self,
15
+ TypeVar,
16
+ Annotated,
17
+ Type,
18
+ Union,
19
+ )
20
+ from pydantic import ConfigDict, Tag
21
+ import pydantic
22
+ from hwcomponents import (
23
+ ComponentModel,
24
+ get_models,
25
+ get_model,
26
+ )
27
+ import pydot
28
+
29
+ from accelforge.util._basetypes import (
30
+ ParsableModel,
31
+ ParsableList,
32
+ ParseExtras,
33
+ ParsesTo,
34
+ TryParseTo,
35
+ _PostCall,
36
+ _get_tag,
37
+ )
38
+ import numpy as np
39
+
40
+ from accelforge.util._parse_expressions import ParseError, parse_expression
41
+ from accelforge.util._setexpressions import InvertibleSet, eval_set_expression
42
+ from accelforge.frontend.renames import RankVariable, TensorName
43
+
44
+ from accelforge._version import assert_version, __version__
45
+ from pydantic import Discriminator
46
+ from accelforge.util._basetypes import _uninstantiable
47
+ from accelforge.util.parallel import _SVGJupyterRender
48
+ from accelforge.util._visualization import _pydot_graph
49
+
50
# Type variable ranging over ArchNode subclasses (bound="ArchNode").
T = TypeVar("T", bound="ArchNode")
51
+
52
+
53
class ArchNode(ParsableModel):
    """A node in the architecture."""

    def find(self, name: str) -> "Leaf":
        """
        Return the `Leaf` node with the given name, checking this node first
        and then searching descendants depth-first for `Branch` nodes.

        :raises ValueError: If no `Leaf` with the given name is found.
        """
        # This node itself may be the match.
        if isinstance(self, Leaf):
            if getattr(self, "name", None) == name:
                return self

        # Recurse into children. A child that does not contain the name
        # raises ValueError (or lacks `find`), which we treat as "keep
        # looking" rather than an error.
        if isinstance(self, Branch):
            for child in self.nodes:
                try:
                    return child.find(name)
                except (AttributeError, ValueError):
                    continue

        raise ValueError(f"Leaf {name} not found in {self}")
71
+
72
+
73
class ArchNodes(ParsableList):
    """A list of `ArchNode`s."""

    def __repr__(self):
        inner = super().__repr__()
        return f"{self.__class__.__name__}({inner})"

    def _parse_expressions(self, symbol_table: dict[str, Any], *args, **kwargs):
        class _RegisterParsedNode(_PostCall):
            # After each node is parsed, publish it in the symbol table under
            # its name so subsequent nodes can reference it by name.
            def __call__(self, field, value, parsed, symbol_table):
                symbol_table[parsed.name] = parsed
                return parsed

        # Before parsing, nodes are addressable by their list index.
        for index, unparsed in enumerate(self):
            symbol_table[index] = unparsed

        return super()._parse_expressions(
            symbol_table, *args, **kwargs, post_calls=(_RegisterParsedNode(),)
        )
91
+
92
+
93
class Comparison(ParsableModel):
    """
    A comparison between a rank variable's bound and a value. A comparison is performed
    for each rank variable.

    The LHS of each comparison is the loop bound of a loop that affects this rank
    variable. The RHS is the given value.

    For example, if the expression resolves to [a, b], the operator is "<=", and the
    value is 10, and we have loops "for a0 in [0..A0)" and "for b0 in [0..B0)", then a
    mapping is only valid if A0 <= 10 and B0 <= 10.
    """

    expression: TryParseTo[InvertibleSet[RankVariable]]
    """ The expression to compare. This expression should resolve to a set of rank
    variables. A comparison is performed for each rank variable independently, and the
    result passes if and only if all comparisons pass. The LHS of each comparison is the
    loop bound of a loop that affects this rank variable. The RHS is the given value.
    """

    operator: str
    """ The operator to use for the comparison. Supported operators are:
    - == (equal to)
    - <= (less than or equal to)
    - >= (greater than or equal to)
    - < (less than)
    - > (greater than)
    - product== (product of all loop bounds is equal to)
    - product<= (product of all loop bounds is less than or equal to)
    - product>= (product of all loop bounds is greater than or equal to)
    - product< (product of all loop bounds is less than)
    - product> (product of all loop bounds is greater than)
    """

    value: ParsesTo[int]
    """ The value to compare against. """

    _str_repr: str = None
    """ A string to print for this comparison when __str__ is called. If None, a default
    string will be used. """

    def _parse_expressions(self, *args, **kwargs):
        # Parse fields via the base class, then simplify: a "product" operator
        # over a single rank variable is equivalent to the plain operator.
        result, symbol_table = super()._parse_expressions(*args, **kwargs)
        if len(result.expression) == 1 and "product" in result.operator:
            result.operator = result.operator.replace("product", "")
        return result, symbol_table

    # def _parse(self, symbol_table: dict[str, Any], location: str):
    #     # if len(self) != 3:
    #     #     raise ValueError(f"Comparison can only have 3 elements. got {len(self)}")
    #     new = type(self)(
    #         expression=eval_set_expression(
    #             self.expression, symbol_table, "rank_variables", location
    #         ),
    #         operator=self.operator,
    #         value=self.value,
    #     )
    #     if len(new.expression) == 1 and "product" in new.operator:
    #         new.operator = new.operator.replace("product", "")
    #     return new

    def _constrained_to_one(self) -> bool:
        # True when this comparison forces every affected loop bound to 1:
        # equality with 1, or an upper bound of 1 (assumes loop bounds are
        # >= 1 -- TODO confirm against the mapper's loop-bound invariants).
        return self.value == 1 and self.operator in [
            "==",
            "<=",
            "product==",
            "product<=",
        ]

    def _split_expression(self) -> List[set[RankVariable]]:
        # "product" operators couple all rank variables into one joint
        # comparison; all other operators apply per rank variable, so each
        # variable becomes its own singleton set.
        if "product" in self.operator:
            return [self.expression]
        return sorted(set((x,)) for x in self.expression)

    def _to_constraint_lambda(
        self,
        increasing_sizes: bool,
    ) -> Callable[[bool, np.ndarray], bool | np.ndarray]:
        # Builds a callable (final, sizes) -> bool | np.ndarray, where `final`
        # means all sizes are known. While sizes are still being enumerated
        # (monotonically increasing or decreasing per `increasing_sizes`),
        # only inequalities that can no longer become true are evaluated
        # early; everything else defers until `final`.

        # Equal operators can only evaluate when all sizes are known
        eq_op = lambda final: (
            np.equal
            if final
            else (np.less_equal if increasing_sizes else np.greater_equal)
        )

        # If we're increasing, we can evaluate leq immediately. If we're
        # decreasing, we can evaluate geq immediately. The other must wait
        # until all sizes are known.
        le_wrapper = lambda op: lambda final, sizes: (
            op(sizes) if final or increasing_sizes else True
        )
        ge_wrapper = lambda op: lambda final, sizes: (
            op(sizes) if final or not increasing_sizes else True
        )

        # Reductions across the per-rank-variable axis (axis=1): `_all`
        # requires every element to satisfy the comparison, `_prod` combines
        # them into a single product for the "product" operators.
        _all = lambda sizes: np.all(sizes, axis=1)
        _prod = lambda sizes: np.prod(sizes, axis=1)

        # fmt: off
        operator_to_wrapper = {
            "==": lambda final, sizes: _all(eq_op(final)(sizes, self.value)),
            "product==": lambda final, sizes: eq_op(final)(_prod(sizes), self.value),
            "<=": le_wrapper(lambda sizes: _all(sizes) <= self.value),
            ">=": ge_wrapper(lambda sizes: _all(sizes) >= self.value),
            "<": le_wrapper(lambda sizes: _all(sizes) < self.value),
            ">": ge_wrapper(lambda sizes: _all(sizes) > self.value),
            "product<=": le_wrapper(lambda sizes: _prod(sizes) <= self.value),
            "product>=": ge_wrapper(lambda sizes: _prod(sizes) >= self.value),
            "product<": le_wrapper(lambda sizes: _prod(sizes) < self.value),
            "product>": ge_wrapper(lambda sizes: _prod(sizes) > self.value),
        }
        # fmt: on

        if self.operator in operator_to_wrapper:
            return operator_to_wrapper[self.operator]
        raise KeyError(
            f"Unknown operator: {self.operator}. Known operators: {list(operator_to_wrapper.keys())}"
        )

    def __str__(self) -> str:
        # Prefer the explicit override set by the creator, if any.
        if self._str_repr is not None:
            return self._str_repr
        return f"({sorted(self.expression)}) {self.operator} ({self.value})"
216
+
217
+
218
class Spatial(ParsableModel):
    """A one-dimensional spatial fanout in the architecture."""

    name: str
    """
    The name of the dimension over which this spatial fanout is occurring (e.g., X or Y).
    """

    fanout: ParsesTo[int]
    """ The size of this fanout. """

    may_reuse: TryParseTo[InvertibleSet[TensorName]] = "All"
    """ The tensors that can be reused spatially across instances of this fanout. This
    expression will be parsed for each mapping template. """

    loop_bounds: ParsableList[Comparison] = ParsableList()
    """ Bounds for loops over this dimension. This is a list of :class:`~.Comparison`
    objects, all of which must be satisfied by the loops to which this constraint
    applies.

    Note: Loops may be removed if they are constrained to only one iteration.
    """

    min_usage: int | float | str = 0.0
    """ The minimum usage of spatial instances, as a value from 0 to 1. A mapping
    is invalid if less than this proportion of this dimension's fanout is utilized.
    Mappers that support it (e.g., FFM) may, if no mappings satisfy this constraint,
    return the highest-usage mappings.
    """

    reuse: TryParseTo[InvertibleSet[TensorName]] = "Nothing"
    """ A set of tensors or a set expression representing tensors that must be reused
    across spatial iterations. Spatial loops may only be placed that reuse ALL tensors
    given here.

    Note: Loops may be removed if they do not reuse a tensor given here and they do not
    appear in another loop bound constraint.
    """

    usage_scale: ParsesTo[int | float | str] = 1
    """
    This factor scales the usage in this dimension. For example, if usage_scale is 2 and
    10/20 spatial instances are used, then the usage will be scaled to 20/20.
    """

    power_gateable: ParsesTo[bool] = False
    """
    Whether this spatial fanout has power gating. If True, then unused spatial instances
    will be power gated if not used by a particular Einsum.
    """

    # def _parse(self, symbol_table: dict[str, Any], location: str):
    #     return type(self)(
    #         name=self.name,
    #         fanout=self.fanout,
    #         may_reuse=set(
    #             eval_set_expression(
    #                 self.may_reuse,
    #                 symbol_table,
    #                 expected_space=TensorName,
    #                 location=location + ".may_reuse",
    #             )
    #         ),
    #         loop_bounds=[
    #             x._parse(symbol_table, location + ".loop_bounds")
    #             for x in self.loop_bounds
    #         ],
    #         min_usage=parse_expression(
    #             self.min_usage,
    #             symbol_table,
    #             "min_usage",
    #             location + ".min_usage",
    #         ),
    #         reuse=eval_set_expression(
    #             self.reuse,
    #             symbol_table,
    #             "tensors",
    #             location + ".reuse",
    #         ),
    #     )
298
+
299
+
300
class Action(ParsableModel):
    """
    An action that may be performed by a component.
    """

    name: str
    """ The name of this action. """

    energy: ParsesTo[int | float | None] = None
    """
    Dynamic energy of this action. Per-action energy is multiplied by the component's
    energy_scale and the action's energy_scale.
    """

    energy_scale: ParsesTo[int | float] = 1
    """
    The scale factor for dynamic energy of this action. Multiplies this action's energy
    by this value.
    """

    latency: ParsesTo[int | float | None] = None
    """
    Latency of this action. Per-action latency is multiplied by the component's
    latency_scale and the action's latency_scale.
    """

    latency_scale: ParsesTo[int | float] = 1
    """
    The scale factor for dynamic latency of this action. Multiplies this action's
    latency by this value.
    """

    extra_attributes_for_component_model: ParseExtras = ParseExtras()
    """ Extra attributes to pass to the component model. In addition to all attributes
    of this action, any extra attributes will be passed to the component model as
    arguments to the component model's action. This can be used to define attributes
    that are known to the component model, but not accelforge, such as clock
    frequency."""

    def _attributes_for_component_model(self) -> dict[str, Any]:
        # Merge the action's own fields with its extra attributes. The extra
        # attributes are applied second so they win on key collisions.
        attributes = dict(self.shallow_model_dump())
        attributes.update(
            self.extra_attributes_for_component_model.shallow_model_dump()
        )
        return attributes
344
+
345
+
346
class TensorHolderAction(Action):
    """An `Action` that additionally tracks how many bits each invocation
    accesses (`bits_per_action`). Presumably used for actions on components
    that hold tensor data, such as memories -- TODO confirm with callers."""

    bits_per_action: ParsesTo[int | float] = (
        "1 if bits_per_action is None else bits_per_action"
    )
    """ The number of bits accessed in this action. For example, setting bits_per_action
    to 16 means that each call to this action yields 16 bits. """
352
+
353
+
354
@_uninstantiable
class Leaf(ArchNode):
    """A leaf node in the architecture. This is an abstract class that represents any
    node that is not a `Branch`."""

    name: str
    """ The name of this `Leaf`. """

    spatial: ParsableList[Spatial] = ParsableList()
    """
    The spatial fanouts of this `Leaf`.

    Spatial fanouts describe the spatial organization of components in the architecture.
    A spatial fanout of size N for this node means that there are N instances of this
    node. Multiple spatial fanouts lead to a multi-dimensional fanout. Spatial
    constraints apply to the data exchange across these instances. Spatial fanouts
    specified at this level also apply to lower-level `Leaf` nodes in the architecture.
    """

    def get_fanout(self) -> int:
        """Total spatial fanout of this node: the product of all per-dimension
        fanouts (1 when no spatial fanouts are declared)."""
        total = 1
        for dimension in self.spatial:
            total *= dimension.fanout
        return int(total)
376
+
377
+
378
+ _COMPONENT_MODEL_CACHE: dict[tuple, "Component"] = {}
379
+
380
+
381
+ def _set_component_model_cache(key: tuple, value: "Component"):
382
+ while len(_COMPONENT_MODEL_CACHE) > 1000:
383
+ _COMPONENT_MODEL_CACHE.popitem(last=False)
384
+ _COMPONENT_MODEL_CACHE[key] = value
385
+
386
+
387
class _ExtraAttrs(ParseExtras):
    def _parse_expressions(self, symbol_table: dict[str, Any], *args, **kwargs):
        # Already parsed: no defaults to fill in, just defer to the parent.
        if getattr(self, "_parsed", False):
            return super()._parse_expressions(symbol_table, *args, **kwargs)

        snapshot = dict(symbol_table)
        if "arch_extra_attributes_for_all_component_models" not in snapshot:
            raise ParseError(
                "arch_extra_attributes_for_all_component_models is not set in the symbol "
                "table. Was parsing called from the architecture on top?"
            )

        # Fill in architecture-wide defaults for any attribute this object
        # does not already set (None counts as unset).
        shared_defaults = snapshot["arch_extra_attributes_for_all_component_models"]
        for attr_name, default_value in shared_defaults.items():
            if getattr(self, attr_name, None) is None:
                setattr(self, attr_name, default_value)

        return super()._parse_expressions(symbol_table, *args, **kwargs)
406
+
407
+
408
@_uninstantiable
class Component(Leaf):
    """A component object in the architecture. This is overridden by different
    component types, such as `Memory` and `Compute`."""

    name: str
    """ The name of this `Component`. """

    component_class: Optional[str] = None
    """ The class of this `Component`. Used if an energy or area model needs to be
    called for this `Component`. """

    component_model: ComponentModel | None = None
    """ The model to use for this `Component`. If not set, the model will be found with
    `hwcomponents.get_models()`. If set, the `component_class` will be ignored. """

    component_modeling_log: list[str] = []
    """ A log of the energy and area calculations for this `Component`. """

    actions: ParsableList[Action]
    """ The actions that this `Component` can perform. """

    enabled: TryParseTo[bool] = True
    """ Whether this component is enabled. If the expression resolves to False, then
    the component is disabled. This is parsed per-pmapping-template, so it is a function
    of the tensors in the current Einsum. For example, you may say `len(All) >= 3` and
    the component will only be enabled with Einsums with three or more tensors.
    """

    area: ParsesTo[int | float | None] = None
    """
    The area of a single instance of this component in m^2. If set, area calculations
    will use this value.
    """

    total_area: ParsesTo[int | float | None] = None
    """
    The total area of all instances of this component in m^2. Do not set this value. It
    is calculated when the architecture's area is calculated.
    """

    area_scale: ParsesTo[int | float] = 1
    """
    The scale factor for the area of this component. This is used to scale the area of
    this component. For example, if the area is 1 m^2 and the scale factor is 2, then
    the area is 2 m^2.
    """

    leak_power: ParsesTo[int | float | None] = None
    """
    The leak power of a single instance of this component in W. If set, leak power
    calculations will use this value.
    """

    total_leak_power: ParsesTo[int | float | None] = None
    """
    The total leak power of all instances of this component in W. Do not set this value.
    It is calculated when the architecture's leak power is calculated. If instances are
    power gated, actual leak power may be less than this value.
    """

    leak_power_scale: ParsesTo[int | float] = 1
    """
    The scale factor for the leak power of this component. This is used to scale the
    leak power of this component. For example, if the leak power is 1 W and the scale
    factor is 2, then the leak power is 2 W.
    """

    energy_scale: ParsesTo[int | float] = 1
    """
    The scale factor for dynamic energy of this component. For each action, multiplies
    this action's energy. Multiplies the calculated energy of each action.
    """

    total_latency: str | int | float = "sum(*action2latency.values())"
    """
    An expression representing the total latency of this component in seconds. This is
    used to calculate the latency of a given Einsum. Special variables available are the
    following:

    - `min`: The minimum value of all arguments to the expression.
    - `max`: The maximum value of all arguments to the expression.
    - `sum`: The sum of all arguments to the expression.
    - `X_actions`: The number of times action `X` is performed. For example,
      `read_actions` is the number of times the read action is performed.
    - `X_latency`: The total latency of all actions of type `X`. For example,
      `read_latency` is the total latency of all read actions. It is equal to the
      per-read latency multiplied by the number of read actions.
    - `action2latency`: A dictionary of action names to their latency.

    Additionally, all component attributes are available as variables, and all other
    functions generally available in parsing. Note this expression is parsed after other
    component attributes are parsed.

    For example, the following expression calculates latency assuming that each read or
    write action takes 1ns: ``1e-9 * (read_actions + write_actions)``.
    """

    latency_scale: ParsesTo[int | float] = 1
    """
    The scale factor for the latency of this component. This is used to scale the
    latency of this component. For example, if the latency is 1 ns and the scale factor
    is 2, then the latency is 2 ns. Multiplies the calculated latency of each action.
    """

    n_parallel_instances: ParsesTo[int | float] = 1
    """
    The number of parallel instances of this component. Increasing parallel instances
    will proportionally increase area and leakage, while reducing latency (unless
    latency calculation is overridden).
    """

    extra_attributes_for_component_model: _ExtraAttrs = _ExtraAttrs()
    """ Extra attributes to pass to the component model. In addition to all attributes
    of this component, any extra attributes will be passed to the component model. This
    can be used to define attributes that are known to the component model, but not
    accelforge, such as the technology node."""

    model_config = ConfigDict(arbitrary_types_allowed=True)

    def _update_actions(self, new_actions: ParsableList[Action]):
        """Append each action in ``new_actions`` whose name is not already
        present in ``self.actions``. Existing actions are never replaced."""
        existing_names = set(x.name for x in self.actions)
        for action in new_actions:
            if action.name not in existing_names:
                self.actions.append(action)

    def get_component_class(self, trying_to_calculate: str | None = None) -> str:
        """Returns the class of this `Component`.

        Parameters
        ----------
        trying_to_calculate : str, optional
            What was trying to be calculated using this component. If provided, the
            error message will be more specific.

        :raises ParseError: If the `component_class` is not set.
        """
        extra_info = ""
        if trying_to_calculate is not None:
            extra_info = f" Occurred while trying to calculate {trying_to_calculate}."

        if self.component_class is None:
            raise ParseError(
                f"component_class must be set to a valid string. "
                f"Got {self.component_class}. This occurred because the model tried to "
                "talk to hwcomponents, but was missing necessary attributes. If you do "
                "not want to use hwcomponents component_models, ensure that area and "
                "leak_power are set, as well as, for each action, "
                f"energy and latency are set.{extra_info}",
                source_field=f"{self.name}.component_class",
            )
        return self.component_class

    def _is_dummy(self) -> bool:
        """True if this component is a placeholder with class "Dummy" (case
        insensitive) and no explicit model. Dummy components report zero energy,
        area, latency, and leak power without consulting hwcomponents."""
        return (
            self.component_model is None
            and self.component_class is not None
            and self.component_class.lower() == "dummy"
        )

    def populate_component_model(
        self: T,
        component_models: list[ComponentModel] | None = None,
        in_place: bool = False,
        trying_to_calculate: str | None = None,
    ) -> T:
        """
        Populates the ``component_model`` attribute with the model for this component.
        Extends the ``component_modeling_log`` field with log messages. Uses the
        ``component_class`` attribute to find the model and populate the
        ``component_model`` attribute. Uses the ``hwcomponents.get_model()`` function to
        find the model.

        Parameters
        ----------
        component_models : list[ComponentModel] | None
            The models to use for energy calculation. If not provided, the models will
            be found with `hwcomponents.get_models()`.
        in_place : bool
            If True, the component will be modified in place. Otherwise, a copy will be
            returned.
        trying_to_calculate : str, optional
            What was trying to be calculated using this component. If provided, the
            error messages for missing component_class will be more specific.

        Returns
        -------
        T
            A copy of the component with the populated ``component_model`` attribute.
        """
        if not in_place:
            self: Self = self._copy_for_component_modeling()

        # Dummy components never talk to hwcomponents.
        if self._is_dummy():
            return self

        if self.component_model is None:
            if component_models is None:
                component_models = get_models()
            estimation = get_model(
                self.get_component_class(trying_to_calculate=trying_to_calculate),
                self._attributes_for_component_model(),
                required_actions=list(x.name for x in self.actions),
                models=component_models,
                _return_estimation_object=True,
            )
            self.component_model = estimation.value
            self.component_modeling_log.extend(estimation.messages)
        return self

    def calculate_action_energy(
        self,
        component_models: list[ComponentModel] | None = None,
        in_place: bool = False,
    ) -> Self:
        """
        Calculates energy for each action of this component. If energy is set in the
        action or component (with action taking precedence), that value will be used.
        Otherwise, the energy will be calculated using hwcomponents. Populates, for each
        action, the ``<action>.energy`` field. Extends the
        ``component_modeling_log`` field with log messages.

        Uses the ``component_model`` attribute, or, if not set, the ``component_class``
        attribute to find the model and populate the ``component_model`` attribute.

        Note that these methods will be called by the Spec when calculating energy and
        area. If you call them yourself, note that string expressions may not be parsed
        because they need the Spec's global scope. If you are sure that all necessary
        values are present and not a result of an expression, you can call these
        directly. Otherwise, you can call the
        ``Spec.calculate_component_area_energy_latency_leak`` and then grab components
        from the returned ``Spec``.

        Parameters
        ----------
        component_models : list[ComponentModel] | None
            The models to use for energy calculation. If not provided, the models will
            be found with `hwcomponents.get_models()`.
        in_place : bool
            If True, the component will be modified in place. Otherwise, a copy will be
            returned.

        Returns
        -------
        T
            A copy of the component with the calculated energy.
        """
        if not in_place:
            self: Component = self._copy_for_component_modeling()

        messages = self.component_modeling_log

        for action in self.actions:
            messages.append(f"Calculating energy for {self.name} action {action.name}.")
            if action.energy is not None:
                # Explicit per-action energy takes precedence over the model.
                energy = action.energy
                messages.append(f"Setting {self.name} energy to {action.energy=}")
            elif self._is_dummy():
                energy = 0
                messages.append('Component is "Dummy". Setting energy to 0.')
            else:
                self.populate_component_model(
                    component_models,
                    in_place=True,
                    trying_to_calculate=f"energy for action {action.name}",
                )
                estimation = self.component_model.try_call_arbitrary_action(
                    action_name=action.name,
                    _return_estimation_object=True,
                    **{
                        **self._attributes_for_component_model(),
                        **action._attributes_for_component_model(),
                    },
                )
                messages.extend(estimation.messages)
                # Estimation value is (energy, latency); index 0 is energy.
                energy = estimation.value[0]
            if self.energy_scale != 1:
                energy *= self.energy_scale
                messages.append(f"Scaling {self.name} energy by {self.energy_scale=}")
            if action.energy_scale != 1:
                energy *= action.energy_scale
                messages.append(f"Scaling {self.name} energy by {action.energy_scale=}")
            action.energy = energy
            if action.energy < 0:
                logging.warning(
                    f"Component {self.name} action {action.name} has negative energy: "
                    f"{action.energy=}"
                )
        return self

    def calculate_leak_power(
        self,
        component_models: list[ComponentModel] | None = None,
        in_place: bool = False,
    ) -> Self:
        """
        Calculates the leak power for this component. If leak power is set in the
        component, that value will be used. Otherwise, the leak power will be calculated
        using hwcomponents. Populates ``leak_power`` field. Extends the
        ``component_modeling_log`` field with log messages.

        Uses the ``component_model`` attribute, or, if not set, the ``component_class``
        attribute to find the model and populate the ``component_model`` attribute.

        Note that these methods will be called by the Spec when calculating energy and
        area. If you call them yourself, note that string expressions may not be parsed
        because they need the Spec's global scope. If you are sure that all necessary
        values are present and not a result of an expression, you can call these
        directly. Otherwise, you can call the
        ``Spec.calculate_component_area_energy_latency_leak`` and then grab components
        from the returned ``Spec``.

        Parameters
        ----------
        component_models : list[ComponentModel] | None
            The models to use for energy calculation. If not provided, the models will
            be found with `hwcomponents.get_models()`.
        in_place : bool
            If True, the component will be modified in place. Otherwise, a copy will be
            returned.

        Returns
        -------
        Self
            A copy of the component with the calculated energy.
        """
        if not in_place:
            self: Self = self._copy_for_component_modeling()

        messages = self.component_modeling_log
        if self.leak_power is not None:
            leak_power = self.leak_power
            messages.append(f"Using predefined leak power value {self.leak_power=}")
        elif self._is_dummy():
            leak_power = 0
            messages.append("Component is dummy. Setting leak power to 0.")
        else:
            self.populate_component_model(
                component_models,
                in_place=True,
                trying_to_calculate="leak_power",
            )
            leak_power = self.component_model.leak_power
        if self.leak_power_scale != 1:
            leak_power *= self.leak_power_scale
            messages.append(f"Scaling leak power by {self.leak_power_scale=}")
        if self.n_parallel_instances != 1:
            # Leakage grows with the number of instances.
            leak_power *= self.n_parallel_instances
            messages.append(f"Scaling leak power by {self.n_parallel_instances=}")
        self.leak_power = leak_power
        if self.leak_power < 0:
            logging.warning(
                f"Component {self.name} has negative leak power: {self.leak_power}"
            )
        return self

    def calculate_area(
        self,
        component_models: list[ComponentModel] | None = None,
        in_place: bool = False,
    ) -> Self:
        """
        Calculates the area for this component. If area is set in the component, that
        value will be used. Otherwise, the area will be calculated using the
        hwcomponents library. Populates ``area`` field. Extends the
        ``component_modeling_log`` field with log messages.

        Uses the ``component_model`` attribute, or, if not set, the ``component_class``
        attribute to find the model and populate the ``component_model`` attribute.

        Note that these methods will be called by the Spec when calculating energy and
        area. If you call them yourself, note that string expressions may not be parsed
        because they need the Spec's global scope. If you are sure that all necessary
        values are present and not a result of an expression, you can call these
        directly. Otherwise, you can call the
        ``Spec.calculate_component_area_energy_latency_leak`` and then grab components
        from the returned ``Spec``.

        Parameters
        ----------
        component_models : list[ComponentModel] | None
            The models to use for area calculation. If not provided, the models will be
            found with `hwcomponents.get_models()`.
        in_place : bool
            If True, the component will be modified in place. Otherwise, a copy will be
            returned.

        Returns
        -------
        Self
            A copy of the component with the calculated area.
        """
        if not in_place:
            self: Self = self._copy_for_component_modeling()

        messages = self.component_modeling_log
        if self.area is not None:
            area = self.area
            messages.append(f"Using predefined area value {self.area=}")
        elif self._is_dummy():
            area = 0
            messages.append("Component is dummy. Setting area to 0.")
        else:
            self.populate_component_model(
                component_models,
                in_place=True,
                trying_to_calculate="area",
            )
            area = self.component_model.area
        if self.area_scale != 1:
            area *= self.area_scale
            messages.append(f"Scaling area by {self.area_scale=}")
        if self.n_parallel_instances != 1:
            # Area grows with the number of instances.
            area *= self.n_parallel_instances
            messages.append(f"Scaling area by {self.n_parallel_instances=}")
        self.area = area
        if self.area < 0:
            logging.warning(f"Component {self.name} has negative area: {self.area}")
        return self

    def calculate_action_latency(
        self,
        component_models: list[ComponentModel] | None = None,
        in_place: bool = False,
    ) -> Self:
        """
        Calculates the latency for each action by this component. Populates the
        ``<action>.latency`` field. Extends the ``component_modeling_log`` field with
        log messages.

        Parameters
        ----------
        component_models : list[ComponentModel] | None
            The models to use for latency calculation. If not provided, the models will be
            found with `hwcomponents.get_models()`.
        in_place : bool
            If True, the component will be modified in place. Otherwise, a copy will be
            returned.

        Returns
        -------
        Self
            A copy of the component with the calculated latency for each action.
        """
        if not in_place:
            self: Self = self._copy_for_component_modeling()

        messages = self.component_modeling_log

        for action in self.actions:
            messages.append(
                f"Calculating latency for {self.name} action {action.name}."
            )
            if action.latency is not None:
                # Explicit per-action latency takes precedence over the model.
                latency = action.latency
                messages.append(f"Setting {self.name} latency to {action.latency=}")
            elif self._is_dummy():
                latency = 0
                messages.append("Component is dummy. Setting latency to 0.")
            else:
                self.populate_component_model(
                    component_models,
                    in_place=True,
                    trying_to_calculate=f"latency for action {action.name}",
                )
                estimation = self.component_model.try_call_arbitrary_action(
                    action_name=action.name,
                    _return_estimation_object=True,
                    **{
                        **self._attributes_for_component_model(),
                        **action._attributes_for_component_model(),
                    },
                )
                messages.extend(estimation.messages)
                # Estimation value is (energy, latency); index 1 is latency.
                latency = estimation.value[1]
            if self.latency_scale != 1:
                latency *= self.latency_scale
                messages.append(f"Scaling {self.name} latency by {self.latency_scale=}")
            if action.latency_scale != 1:
                latency *= action.latency_scale
                messages.append(
                    f"Scaling {self.name} latency by {action.latency_scale=}"
                )
            if self.n_parallel_instances != 1:
                # Parallel instances reduce per-action latency.
                latency /= self.n_parallel_instances
                messages.append(
                    f"Dividing {self.name} latency by {self.n_parallel_instances=}"
                )
            action.latency = latency
            if action.latency < 0:
                logging.warning(
                    f"Component {self.name} action {action.name} has negative latency: "
                    f"{action.latency}"
                )
        return self

    def calculate_area_energy_latency_leak(
        self,
        component_models: list[ComponentModel] | None = None,
        in_place: bool = False,
        _use_cache: bool = False,
    ) -> Self:
        """
        Calculates the area, energy, latency, and leak power for this component.
        Populates the ``area``, ``total_area``, ``leak_power``, ``total_leak_power``,
        ``total_latency``, and ``component_modeling_log`` fields of this component.
        Additionally, for each action, populates the ``<action>.area``,
        ``<action>.energy``, ``<action>.latency``, and ``<action>.leak_power`` fields.
        Extends the ``component_modeling_log`` field with log messages.

        Note that these methods will be called by the Spec when calculating energy and
        area. If you call them yourself, note that string expressions may not be parsed
        because they need the Spec's global scope. If you are sure that all necessary
        values are present and not a result of an expression, you can call these
        directly. Otherwise, you can call the
        ``Spec.calculate_component_area_energy_latency_leak`` and then grab components
        from the returned ``Spec``.

        Parameters
        ----------
        component_models : list[ComponentModel] | None
            The models to use for energy calculation. If not provided, the models will
            be found with `hwcomponents.get_models()`.
        in_place : bool
            If True, the component will be modified in place. Otherwise, a copy will be
            returned.
        _use_cache : bool
            If True, the component model will be cached and reused if the same component
            class, attributes, and actions are provided. Note that this may return
            copies of the same object across multiple calls.

        Returns
        -------
        Self
            The component with the calculated energy, area, and leak power.
        """
        if not in_place:
            self: Self = self._copy_for_component_modeling()

        cachekey = None
        if _use_cache:
            # BUGFIX: compute the key whenever caching is requested. Previously
            # the key was computed only when component_model was None, but the
            # cache-store below referenced it whenever _use_cache was True,
            # raising NameError for components with a pre-populated model.
            cachekey = (
                self.component_class,
                self._attributes_for_component_model(),
                self.actions,
            )
            if self.component_model is None and cachekey in _COMPONENT_MODEL_CACHE:
                component = _COMPONENT_MODEL_CACHE[cachekey]
                self.component_model = component.component_model
                self.component_modeling_log = component.component_modeling_log
                self.actions = component.actions
                return self

        self.calculate_area(component_models, in_place=True)
        self.calculate_action_energy(component_models, in_place=True)
        self.calculate_action_latency(component_models, in_place=True)
        self.calculate_leak_power(component_models, in_place=True)
        if _use_cache:
            _set_component_model_cache(cachekey, self)
        return self

    def _attributes_for_component_model(self) -> dict[str, Any]:
        """Merge this component's fields with any extra attributes destined for
        the hwcomponents model. Extra attributes win on key collisions."""
        return {
            **self.shallow_model_dump(),
            **self.extra_attributes_for_component_model.shallow_model_dump(),
        }

    def _copy_for_component_modeling(self) -> Self:
        """Shallow-copy this component, its extra attributes, and each action
        (with that action's extra attributes) so that modeling results can be
        written without mutating the original object."""
        self: Component = self.model_copy()
        self.extra_attributes_for_component_model = (
            self.extra_attributes_for_component_model.model_copy()
        )
        self.actions = type(self.actions)([a.model_copy() for a in self.actions])
        for action in self.actions:
            action.extra_attributes_for_component_model = (
                action.extra_attributes_for_component_model.model_copy()
            )
        return self
985
+
986
+
987
# Default action sets attached to the component classes below. Each component's
# model_post_init() merges the relevant list into its own actions (see
# Component._update_actions), so user-declared actions with the same names take
# precedence over these defaults.

# Memories pay for both reads and writes.
MEMORY_ACTIONS = ParsableList[TensorHolderAction](
    [
        TensorHolderAction(name="read"),
        TensorHolderAction(name="write"),
    ]
)


# Processing stages only track reads; each traversal is counted as a read.
PROCESSING_STAGE_ACTIONS = ParsableList[TensorHolderAction](
    [
        TensorHolderAction(name="read"),
    ]
)

# Compute components perform a single "compute" action.
COMPUTE_ACTIONS = ParsableList(
    [
        Action(name="compute"),
    ]
)
1006
+
1007
+
1008
def _parse_tensor2bits(
    to_parse: dict[str, Any], location: str, symbol_table: dict[str, Any]
) -> dict[str, Any]:
    """Parse a ``{tensor-set expression: value expression}`` mapping into a
    ``{tensor: parsed value}`` mapping with one entry per tensor.

    Each key is evaluated as a set expression over ``TensorName`` and each value
    as a regular expression via ``parse_expression``. The key sets must together
    cover every tensor in the symbol table's ``All`` set, and no two key sets
    may overlap.

    Parameters
    ----------
    to_parse : dict[str, Any]
        Mapping from set expressions to value expressions.
    location : str
        Human-readable location used in error messages.
    symbol_table : dict[str, Any]
        Symbol table for expression evaluation; must contain ``All``.

    Returns
    -------
    dict[str, Any]
        Mapping from each individual tensor to its parsed value.

    :raises ParseError: If any tensor is uncovered or two key sets overlap.
    """
    result = {}
    for key, value in to_parse.items():
        key_parsed = eval_set_expression(
            expression=key,
            symbol_table=symbol_table,
            expected_space=TensorName,
            location=f"{location} key {key}",
        ).instance
        result[key_parsed] = parse_expression(
            expression=value,
            symbol_table=symbol_table,
            attr_name=key,
            location=location,
        )

    # Every tensor in All must be covered by some key set.
    # (Renamed from `all`, which shadowed the builtin.)
    uncovered = symbol_table["All"].instance
    for k in result:
        uncovered -= k

    if uncovered:
        raise ParseError(f"Missing bits_per_value_scale for {uncovered}")

    # No tensor may be covered by more than one key set.
    for a, b in itertools.combinations(result.keys(), 2):
        if a & b:
            raise ParseError(f"bits_per_value_scale for {a} and {b} overlap")

    # Expand each set-keyed entry into one entry per member tensor.
    return {k2: v for k, v in result.items() for k2 in k}
1038
+
1039
+
1040
class Tensors(ParsableModel):
    """
    Fields that control which tensor(s) are kept in a :py:class:`~.TensorHolder` and in
    what order their nodes may appear in the mapping.
    """

    keep: TryParseTo[InvertibleSet[TensorName]] = "<Defaults to Nothing>"
    """
    A set expression describing which tensors must be kept in this
    :class:`accelforge.frontend.arch.TensorHolder`. If this is not defined, then all
    tensors must be kept. Any tensors that are in ``back`` will also be added to
    ``keep``.
    """

    may_keep: TryParseTo[InvertibleSet[TensorName]] = (
        "<Nothing if keep is defined, else All>"
    )
    """
    A set expression describing which tensors may optionally be kept in this
    :class:`accelforge.frontend.arch.TensorHolder`. The mapper will explore both keeping
    and not keeping each of these tensors. If this is not defined, then all tensors may
    be kept.
    """

    back: TryParseTo[InvertibleSet[TensorName]] = "Nothing"
    """
    A set expression describing which tensors must be backed by this
    :class:`accelforge.frontend.arch.TensorHolder`. If this is not defined, then no
    tensors must be backed.
    """

    tile_shape: ParsableList[Comparison] = []
    """
    The tile shape for each rank variable. This is given as a list of
    :class:`~.Comparison` objects, where each comparison must evaluate to True for a
    valid mapping.
    """

    no_refetch_from_above: TryParseTo[InvertibleSet[TensorName]] = "~All"
    """
    The tensors that are not allowed to be refetched from above. This is given as a set
    of :class:`~.TensorName` objects or a set expression that resolves to them. These
    tensors must be fetched at most one time from above memories, and may not be
    refetched across any temporal or spatial loop iterations. Tensors may be fetched in
    pieces (if they do not cause re-fetches of any piece).
    """

    tensor_order_options: ParsableList[
        ParsableList[TryParseTo[InvertibleSet[TensorName]]]
    ] = ParsableList()
    """
    Options for the order of tensor storage nodes in the mapping. This is given as a
    list-of-lists-of-sets. Each list-of-sets is a valid order of tensor storage nodes.
    Order is given from highest in the mapping to lowest.

    For example, an option could be [input | output, weight], which means that there is
    no relative ordering required between input and output, but weight must be below
    both.
    """

    force_memory_hierarchy_order: bool = True
    """
    If set to true, storage nodes for lower-level memories must be placed below storage
    nodes for higher-level memories. For example, all MainMemory storage nodes must go
    above all LocalBuffer storage nodes.

    This constraint always applies to same-tensor storage nodes (e.g., MainMemory
    reusing Output must go above LocalBuffer reusing Output); turning it off will permit
    things like MainMemory reusing Output going above LocalBuffer reusing Input.

    This is identical to the `force_memory_hierarchy_order` field in the `FFM` class,
    but only applies to this tensor holder.
    """

    def _parse_expressions(self, *args, **kwargs):
        """Resolve the sentinel defaults of ``keep``/``may_keep``, delegate
        expression parsing to the base class, fold ``back`` into ``keep``
        (and out of ``may_keep``), and validate that no two entries of any
        ``tensor_order_options`` choice intersect."""
        # Work on a shallow copy so the sentinel substitution below does not
        # mutate the caller's object.
        self = copy.copy(self)
        keep, may_keep = self.keep, self.may_keep
        if may_keep == "<Nothing if keep is defined, else All>":
            self.may_keep = "All" if keep == "<Defaults to Nothing>" else "~All"
        if keep == "<Defaults to Nothing>":
            self.keep = "Nothing"
        # BUGFIX: zero-argument super() binds to the class that defines this
        # method. The previous super(self.__class__, self) form recurses
        # infinitely when invoked on a subclass that inherits this method.
        parsed, symbol_table = super()._parse_expressions(*args, **kwargs)
        if isinstance(parsed.back, InvertibleSet):
            # Backed tensors are implicitly kept, and therefore no longer
            # optional.
            if isinstance(parsed.keep, InvertibleSet):
                parsed.keep |= parsed.back
            if isinstance(parsed.may_keep, InvertibleSet):
                parsed.may_keep -= parsed.back

        # Assert that there are no intersecting sets within any order choice.
        for order in parsed.tensor_order_options:
            for s0, s1 in itertools.combinations(order, 2):
                if not isinstance(s0, InvertibleSet) or not isinstance(
                    s1, InvertibleSet
                ):
                    continue
                if s0 & s1:
                    raise ValueError(
                        f"Intersecting entries in tensor_order_options: {s0} and {s1}"
                    )

        return parsed, symbol_table
1211
+
1212
+
1213
@_uninstantiable
class TensorHolder(Component):
    """
    A `TensorHolder` is a component that holds tensors. These are usually `Memory`s,
    but can also be `ProcessingStage`s.
    """

    actions: ParsableList[TensorHolderAction] = MEMORY_ACTIONS
    """ The actions that this `TensorHolder` can perform. """

    tensors: Tensors = Tensors()
    """
    Fields that control which tensor(s) are kept in this `TensorHolder` and in what
    order their nodes may appear in the mapping.
    """

    bits_per_value_scale: ParsesTo[dict] = {"All": 1}
    """
    A scaling factor for the bits per value of the tensors in this `TensorHolder`. If
    this is a dictionary, keys in the dictionary are parsed as expressions and may
    reference one or more tensors.
    """

    bits_per_action: ParsesTo[int | float | None] = None
    """
    The number of bits accessed in each of this component's actions. Overridden by
    bits_per_action in any action of this component. If set here, acts as a default
    value for the bits_per_action of all actions of this component.
    """

    def model_post_init(self, __context__=None) -> None:
        """Ensure the default read/write actions exist; user-declared actions
        with the same names are kept as-is (see Component._update_actions)."""
        self._update_actions(MEMORY_ACTIONS)

    def _parse_expressions(self, *args, **kwargs):
        """Parse expressions, additionally expanding ``bits_per_value_scale``
        from a set-expression-keyed dict to a per-tensor dict via
        ``_parse_tensor2bits``."""
        # Sometimes the same component object may appear in the mapping and the
        # architecture, in which case we don't want parsing to happen twice.
        if getattr(self, "_parsed", False):
            return super()._parse_expressions(*args, **kwargs)

        class MyPostCall(_PostCall):
            # Hook invoked per field after the base class parses it; only
            # bits_per_value_scale needs the tensor-set expansion.
            def __call__(self, field, value, parsed, symbol_table):
                if field == "bits_per_value_scale":
                    parsed = _parse_tensor2bits(
                        parsed,
                        location="bits_per_value_scale",
                        symbol_table=symbol_table,
                    )
                return parsed

        return super()._parse_expressions(*args, **kwargs, post_calls=(MyPostCall(),))
1263
+
1264
+
1265
class Fanout(Leaf):
    """A `Leaf` whose sole purpose is to introduce a spatial fanout; it has no
    other behavior."""
1269
+
1270
+
1271
class Memory(TensorHolder):
    """A `Memory` is a `TensorHolder` that stores data over time, allowing for temporal
    reuse."""

    # Required field (no default): parsing fails if `size` is omitted.
    size: ParsesTo[int | float]
    """ The size of this `Memory` in bits. """

    actions: ParsableList[TensorHolderAction] = MEMORY_ACTIONS
    """ The actions that this `Memory` can perform. """
1280
+
1281
+
1282
class ProcessingStage(TensorHolder):
    """A `ProcessingStage` is a `TensorHolder` that does not store data over time, and
    therefore does not allow for temporal reuse. Use this as a toll that charges reads
    and writes every time a piece of data moves through it.

    Every write to a `ProcessingStage` is immediately written to the next `Memory`
    (which may be above or below depending on where the write came from), and same for
    reads.

    The access counts of a `ProcessingStage` are only included in the "read" action.
    Each traversal through the `ProcessingStage` is counted as a read. Writes are always
    zero.
    """

    direction: Literal["up", "down", "up_and_down"]
    """
    The direction in which data flows through this `ProcessingStage`. If "up", then data
    flows from below `TensorHolder`, through this `ProcessingStage` (plus paying
    associated costs), and then to the next `TensorHolder` above it. Other data
    movements are assumed to avoid this ProcessingStage.
    """

    actions: ParsableList[TensorHolderAction] = PROCESSING_STAGE_ACTIONS
    """ The actions that this `ProcessingStage` can perform. """

    def model_post_init(self, __context__=None) -> None:
        # Override TensorHolder's hook (which installs MEMORY_ACTIONS) so
        # processing stages get their own default action set instead.
        self._update_actions(PROCESSING_STAGE_ACTIONS)
1309
+
1310
+
1311
class Compute(Component):
    """A `Component` that performs compute actions (see `COMPUTE_ACTIONS`)."""

    actions: ParsableList[Action] = COMPUTE_ACTIONS
    """ The actions that this `Compute` can perform. """

    def model_post_init(self, __context__=None) -> None:
        # Install the default compute action set after pydantic validation.
        self._update_actions(COMPUTE_ACTIONS)
1317
+
1318
+
1319
# Generic type variable used by `Branch.get_nodes_of_type` to type its yield.
T = TypeVar("T")
1320
+
1321
+
1322
@_uninstantiable
class Branch(ArchNode):
    """An architecture node that holds child nodes; base of `Parallel` and
    `Hierarchical`."""

    # nodes: ArchNodes[_InferFromTag[Compute, Memory, "Hierarchical"]] = ArchNodes()
    # Discriminated union: `_get_tag` picks the concrete child class when
    # parsing. "Parallel"/"Hierarchical" are forward references resolved by
    # the `Branch.model_rebuild()` call at the bottom of this file.
    nodes: ArchNodes[
        Annotated[
            Union[
                Annotated[Compute, Tag("Compute")],
                Annotated[Memory, Tag("Memory")],
                Annotated[ProcessingStage, Tag("ProcessingStage")],
                Annotated[Fanout, Tag("Fanout")],
                Annotated["Parallel", Tag("Parallel")],
                Annotated["Hierarchical", Tag("Hierarchical")],
            ],
            Discriminator(_get_tag),
        ]
    ] = ArchNodes()

    def get_nodes_of_type(self, types: Type[T] | tuple[Type[T], ...]) -> Iterator[T]:
        """Yield all nodes matching `types`, recursing into sub-branches.

        Note: a node that matches `types` is yielded but NOT recursed into
        (the `elif`), so matching Branch subclasses terminate the recursion.
        """
        for node in self.nodes:
            if isinstance(node, types):
                yield node
            elif isinstance(node, Branch):
                yield from node.get_nodes_of_type(types)
1345
+
1346
+
1347
class Parallel(Branch):
    """A `Branch` whose children sit side-by-side; `_flatten` selects the single
    child path that contains the requested compute node."""

    def _flatten(
        self,
        compute_node: str,
        fanout: int = 1,
        return_fanout: bool = False,
    ):
        """Flatten to the list of nodes on the path to `compute_node`.

        Parameters
        ----------
        compute_node : str
            Name of the `Compute` node to locate among the children.
        fanout : int
            Fanout accumulated so far from enclosing nodes.
        return_fanout : bool
            If True, return `(nodes, fanout)`; otherwise return just `nodes`.

        Raises
        ------
        ParseError
            If no child (directly or recursively) contains `compute_node`.
        """
        nodes = []

        for node in self.nodes:
            if isinstance(node, Compute) and node.name == compute_node:
                fanout *= node.get_fanout()
                nodes.append(node)
                break
            if isinstance(node, Branch):
                computes = node.get_nodes_of_type(Compute)
                if compute_node in [c.name for c in computes]:
                    new_nodes, new_fanout = node._flatten(
                        compute_node, fanout, return_fanout=True
                    )
                    nodes.extend(new_nodes)
                    # NOTE(review): `new_fanout` was computed starting from the
                    # incoming `fanout`, so multiplying again may double-count
                    # it — confirm intended semantics (Hierarchical._flatten
                    # has the same pattern).
                    fanout *= new_fanout
                    break
        else:
            raise ParseError(f"Compute node {compute_node} not found in parallel node")

        # BUG FIX: the original `return nodes, fanout if return_fanout else nodes`
        # parsed as `return (nodes, (fanout if return_fanout else nodes))` due to
        # conditional-expression precedence, so `return_fanout=False` callers got
        # a 2-tuple `(nodes, nodes)` instead of `nodes`. Match the explicit form
        # used by Hierarchical._flatten.
        if return_fanout:
            return nodes, fanout
        return nodes
1374
+
1375
+
1376
class Hierarchical(Branch):
    """A `Branch` whose children are stacked vertically (a hierarchy)."""

    def _flatten(
        self,
        compute_node: str,
        fanout: int = 1,
        return_fanout: bool = False,
    ):
        # Flatten this (possibly nested) hierarchy into the linear list of
        # nodes on the path down to `compute_node`. Iteration stops once the
        # target compute node has been found.
        nodes = []

        for i, node in enumerate(self.nodes):
            try:
                if isinstance(node, (Hierarchical, Parallel)):
                    # A Parallel child is only legal as the LAST child; nothing
                    # may come below diverging parallel paths.
                    if isinstance(node, Parallel) and i < len(self.nodes) - 1:
                        raise ParseError(
                            f"Parallel node {node.name} must be the last node in a "
                            "hierarchical node"
                        )
                    new_nodes, new_fanout = node._flatten(
                        compute_node, fanout, return_fanout=True
                    )
                    nodes.extend(new_nodes)
                    # NOTE(review): `new_fanout` was computed starting from the
                    # incoming `fanout`, so multiplying again may double-count
                    # it — confirm intended semantics.
                    fanout *= new_fanout
                    if any(
                        isinstance(n, Compute) and n.name == compute_node
                        for n in new_nodes
                    ):
                        break
                elif isinstance(node, Compute):
                    if node.name == compute_node:
                        fanout *= node.get_fanout()
                        nodes.append(node)
                        break
                elif isinstance(node, Leaf):
                    # Memories, fanouts, processing stages, etc. accumulate
                    # onto the path toward the compute node.
                    fanout *= node.get_fanout()
                    nodes.append(node)
                else:
                    raise TypeError(f"Can't flatten {node}")
            except ParseError as e:
                # Attach this node to the error for a more precise message.
                e.add_field(node)
                raise e

        if return_fanout:
            return nodes, fanout
        return nodes

    def render(self) -> str:
        # Placeholder rendering of the architecture as a one-node SVG graph.
        # NOTE(review): actually returns an _SVGJupyterRender wrapper rather
        # than a plain str — the annotation may be inaccurate.
        graph = _pydot_graph()
        graph.add_node(pydot.Node("root", shape="box", label="TODO: Arch Render"))
        return _SVGJupyterRender(graph.create_svg(prog="dot").decode("utf-8"))

    def _repr_svg_(self) -> str:
        # Jupyter rich-display hook.
        return self.render()
1428
+
1429
+
1430
class _ConstraintLambda:
    """Pairs a `Comparison` constraint with the mapping nodes it targets.

    Instances are callable: given the rank variables bound so far and an
    array of candidate sizes, they evaluate the compiled constraint lambda.
    """

    def __init__(
        self,
        constraint: Comparison,
        target_mapping_nodes: list[Spatial],
        rank_variables: set[str],
    ):
        self.constraint = constraint
        # Compile the comparison into a callable once, up front. A None
        # constraint (used by some subclasses) yields no lambda.
        if constraint is None:
            self.constraint_lambda = None
        else:
            self.constraint_lambda = constraint._to_constraint_lambda(True)
        self.target_mapping_nodes = target_mapping_nodes
        self.rank_variables = rank_variables
        # Populated later by external code.
        self._target_node_indices = None
        self._target_loop_indices = None

    def __repr__(self):
        return (
            f"_ConstraintLambda({self.constraint}, "
            f"{self.target_mapping_nodes}, {self.rank_variables})"
        )

    def __call__(self, rank_variables: set[RankVariable], sizes: np.ndarray) -> bool:
        # "final" means every rank variable this constraint cares about has
        # been bound.
        is_final = self.rank_variables.issubset(rank_variables)
        return self.constraint_lambda(is_final, sizes)

    def _constrained_node_str(self) -> str:
        return f"constrains {self._target_node_indices}"

    def __bool__(self) -> bool:
        # Truthy iff this constraint actually targets any mapping nodes.
        return len(self.target_mapping_nodes) > 0

    def __str__(self) -> str:
        return str(self.constraint)
1461
+
1462
+
1463
class _TileShapeConstraintLambda(_ConstraintLambda):
    """A `_ConstraintLambda` that restricts tile shapes."""

    def __str__(self) -> str:
        custom = self.constraint._str_repr
        if custom is None:
            return "tile_shape " + super().__str__()
        return custom

    def pretty_str(self) -> str:
        """Human-readable one-line summary of this constraint."""
        return (
            f"Tile shape {self.constraint.operator} "
            f"{self.constraint.value} {self._constrained_node_str()}"
        )
1471
+
1472
+
1473
class _LoopBoundsConstraintLambda(_ConstraintLambda):
    """A `_ConstraintLambda` that restricts loop bounds."""

    def __str__(self) -> str:
        custom = self.constraint._str_repr
        if custom is None:
            return "loop_bounds " + super().__str__()
        return custom

    def pretty_str(self) -> str:
        """Human-readable one-line summary of this constraint."""
        return (
            f"Loop bounds {self.constraint.operator} "
            f"{self.constraint.value} {self._constrained_node_str()}"
        )
1481
+
1482
+
1483
class _MinUsageConstraintLambda(_ConstraintLambda):
    """A soft constraint keeping only candidates whose usage meets a minimum.

    Unlike the other `_ConstraintLambda`s, this never rejects everything:
    if no candidate reaches `min_usage`, the best-achievable candidates are
    kept instead.
    """

    def __init__(
        self,
        target_mapping_nodes: list[Spatial],
        rank_variables: set[str],
        min_usage: float,
    ):
        # No Comparison object backs this constraint, so pass None through.
        super().__init__(None, target_mapping_nodes, rank_variables)
        self.min_usage = min_usage

    def __call__(self, complete_indices: list[int], usages: np.ndarray) -> bool:
        # final = self.rank_variables.issubset(rank_variables)
        # Only enforce once every targeted loop index has been decided.
        final = set(self._target_loop_indices).issubset(set(complete_indices))
        if not final:
            # Not final yet: accept every candidate.
            # BUG FIX: `np.bool` was removed from NumPy 1.24-1.26 (restored
            # only as an alias in 2.0); the builtin `bool` is a valid dtype
            # on all versions.
            return np.ones(usages.shape[0], dtype=bool)

        # Some usages are already above the minimum. Return those.
        result = usages >= self.min_usage
        if np.sum(result) > 0:
            return result

        # Nobody is above the minimum. Return the best we can do.
        max_usage = np.max(usages, axis=0)
        return usages == max_usage

    def pretty_str(self) -> str:
        """Human-readable one-line summary of this constraint."""
        return f"Min usage {self.min_usage} {self._constrained_node_str()}"
1510
+
1511
+
1512
class Arch(Hierarchical):
    """Top-level architecture specification."""

    arch_globals_dependent_on_workload: ParseExtras = ParseExtras()
    """
    Attributes that are dependent on the workload. This is parsed first in the
    architecture, and symbols here are available to the rest of the architecture.
    """

    extra_attributes_for_all_component_models: ParseExtras = ParseExtras()
    """
    Extra attributes to pass to all component models. This can be used to pass global
    attributes, such as technology node or clock period, to every component model.
    """

    @property
    def total_area(self) -> float:
        """
        Returns the total area of the architecture in m^2.

        Returns
        -------
        float
            The total area of the architecture in m^2.
        """
        return sum(self.per_component_total_area.values())

    @property
    def total_leak_power(self) -> float:
        """
        Returns the total leak power of the architecture in W.

        Returns
        -------
        float
            The total leak power of the architecture in W.
        """
        return sum(self.per_component_total_leak_power.values())

    @property
    def per_component_total_area(self) -> dict[str, float]:
        """
        Returns the total area used by each component in the architecture in m^2.

        Returns
        -------
        dict[str, float]
            A dictionary of component names to their total area in m^2.

        Raises
        ------
        ValueError
            If any component's area has not yet been calculated.
        """
        area = {
            node.name: node.total_area for node in self.get_nodes_of_type(Component)
        }
        for k, v in area.items():
            if v is None:
                raise ValueError(
                    f"Area of {k} is not set. Please call the Spec's "
                    "`calculate_component_area_energy_latency_leak` method before accessing this "
                    "property."
                )
        return area

    @property
    def per_component_total_leak_power(self) -> dict[str, float]:
        """
        Returns the total leak power of each component in the architecture in W.

        Returns
        -------
        dict[str, float]
            A dictionary of component names to their total leak power in W.

        Raises
        ------
        ValueError
            If any component's leak power has not yet been calculated.
        """
        leak_power = {
            node.name: node.total_leak_power
            for node in self.get_nodes_of_type(Component)
        }
        for k, v in leak_power.items():
            if v is None:
                raise ValueError(
                    f"Leak power of {k} is not set. Please call the Spec's "
                    "`calculate_component_area_energy_latency_leak` method before accessing this "
                    "property."
                )
        return leak_power

    def _parse_expressions(self, symbol_table: dict[str, Any], *args, **kwargs):
        # Parse every expression field in the architecture. Workload-dependent
        # globals are parsed first (see `order` below) so their symbols are
        # visible to all later fields.
        # NOTE(review): `outer_st` is assigned but never used — possibly
        # leftover from a refactor.
        outer_st = symbol_table

        class PostCallArch(_PostCall):
            # After a field parses, publish its values into the symbol table so
            # later fields and nested components can reference them.
            def __call__(self, field, value, parsed, symbol_table):
                if field == "arch_globals_dependent_on_workload":
                    parsed_dump = parsed.shallow_model_dump()
                    symbol_table.update(parsed_dump)
                    symbol_table["arch_globals_dependent_on_workload"] = parsed_dump
                if field == "extra_attributes_for_all_component_models":
                    parsed_dump = parsed.shallow_model_dump()
                    symbol_table["arch_extra_attributes_for_all_component_models"] = (
                        parsed_dump
                    )
                return parsed

        # Work on a copy so the caller's symbol table is not polluted with the
        # per-leaf entries added below.
        cur_st = dict(symbol_table)

        # Make every leaf component addressable by name inside expressions.
        for node in self.get_nodes_of_type(Leaf):
            cur_st[node.name] = node

        parsed, _ = super()._parse_expressions(
            cur_st,
            *args,
            **kwargs,
            post_calls=(PostCallArch(),),
            order=(
                "arch_globals_dependent_on_workload",
                "extra_attributes_for_all_component_models",
            ),
        )
        # Return the caller's ORIGINAL symbol table, not the augmented copy.
        return parsed, symbol_table

    def __getitem__(self, name: str) -> Leaf:
        # Allow `arch["name"]` lookup of leaf nodes by name.
        return self.name2leaf(name)

    def model_post_init(self, __context__=None) -> None:
        # Make sure all leaf names are unique
        leaves = {}
        for l in self.get_nodes_of_type(Leaf):
            n = l.name
            leaves.setdefault(n, l)
            assert l is leaves[n], f"Duplicate name {n} found in architecture"
1639
+
1640
+
1641
# We had to reference Hierarchical before it was defined: Branch.nodes'
# discriminated union uses the forward references "Parallel" and
# "Hierarchical", so rebuild the pydantic model now that both classes exist.
Branch.model_rebuild()