accelforge 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (258)
  1. accelforge/__init__.py +21 -0
  2. accelforge/_accelerated_imports.py +16 -0
  3. accelforge/_deprecate/_simanneal/evalmapping.py +271 -0
  4. accelforge/_deprecate/_simanneal/mapspaceglobals.py +298 -0
  5. accelforge/_deprecate/_simanneal/simanneal.py +666 -0
  6. accelforge/_deprecate/_simanneal/tracking.py +105 -0
  7. accelforge/_deprecate/_simanneal/wrappers.py +218 -0
  8. accelforge/_deprecate/_simanneal2/__init__.py +7 -0
  9. accelforge/_deprecate/_simanneal2/simanneal.py +493 -0
  10. accelforge/_deprecate/_simanneal2/tracking.py +116 -0
  11. accelforge/_deprecate/compatibility_util.py +181 -0
  12. accelforge/_deprecate/layerdeduplication/__init__.py +2 -0
  13. accelforge/_deprecate/layerdeduplication/group_similar_einsums.py +160 -0
  14. accelforge/_deprecate/layerdeduplication/grouped_einsums.py +84 -0
  15. accelforge/_deprecate/mapping_filter_tags/__init__.py +2 -0
  16. accelforge/_deprecate/mapping_filter_tags/ffmt.py +212 -0
  17. accelforge/_deprecate/mapping_filter_tags/onesplit.py +24 -0
  18. accelforge/_deprecate/mapping_filter_tags/util.py +24 -0
  19. accelforge/_deprecate/tags.py +69 -0
  20. accelforge/_deprecate/viz/__init__.py +0 -0
  21. accelforge/_deprecate/viz/interactive.py +159 -0
  22. accelforge/_deprecate/viz/reservationtree.py +307 -0
  23. accelforge/_deprecate/viz/ski_slope.py +88 -0
  24. accelforge/_version.py +15 -0
  25. accelforge/examples.py +39 -0
  26. accelforge/frontend/__init__.py +10 -0
  27. accelforge/frontend/_binding.py +129 -0
  28. accelforge/frontend/_workload_isl/__init__.py +2 -0
  29. accelforge/frontend/_workload_isl/_isl.py +149 -0
  30. accelforge/frontend/_workload_isl/_symbolic.py +141 -0
  31. accelforge/frontend/arch copy.py +1544 -0
  32. accelforge/frontend/arch.py +1642 -0
  33. accelforge/frontend/config.py +63 -0
  34. accelforge/frontend/mapper/__init__.py +5 -0
  35. accelforge/frontend/mapper/ffm.py +126 -0
  36. accelforge/frontend/mapper/mapper.py +7 -0
  37. accelforge/frontend/mapper/metrics.py +30 -0
  38. accelforge/frontend/mapping/__init__.py +1 -0
  39. accelforge/frontend/mapping/mapping.py +1736 -0
  40. accelforge/frontend/model.py +14 -0
  41. accelforge/frontend/renames.py +150 -0
  42. accelforge/frontend/spec copy.py +230 -0
  43. accelforge/frontend/spec.py +301 -0
  44. accelforge/frontend/variables.py +12 -0
  45. accelforge/frontend/workload.py +952 -0
  46. accelforge/mapper/FFM/__init__.py +9 -0
  47. accelforge/mapper/FFM/_join_pmappings/__init__.py +0 -0
  48. accelforge/mapper/FFM/_join_pmappings/compatibility.py +653 -0
  49. accelforge/mapper/FFM/_join_pmappings/compress_pmappings.py +140 -0
  50. accelforge/mapper/FFM/_join_pmappings/join_pmappings.py +703 -0
  51. accelforge/mapper/FFM/_join_pmappings/pmapping_dataframe.py +901 -0
  52. accelforge/mapper/FFM/_join_pmappings/pmapping_group.py +337 -0
  53. accelforge/mapper/FFM/_make_pmappings/contraints/__init__.py +0 -0
  54. accelforge/mapper/FFM/_make_pmappings/contraints/constraints.py +360 -0
  55. accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/__init__.py +1 -0
  56. accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_loops.py +373 -0
  57. accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_pmapping_templates.py +463 -0
  58. accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_reservations.py +95 -0
  59. accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_storage_order.py +382 -0
  60. accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_storages.py +155 -0
  61. accelforge/mapper/FFM/_make_pmappings/make_pmappings.py +411 -0
  62. accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/__init__.py +1 -0
  63. accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/make_pmappings_from_templates.py +407 -0
  64. accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/make_tile_shapes.py +1681 -0
  65. accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/run_model.py +170 -0
  66. accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/symbol_relations.py +174 -0
  67. accelforge/mapper/FFM/_make_pmappings/pmapper_job.py +282 -0
  68. accelforge/mapper/FFM/_pareto_df/df_convention.py +273 -0
  69. accelforge/mapper/FFM/_pareto_df/pareto copy.py +836 -0
  70. accelforge/mapper/FFM/_pareto_df/pareto.py +508 -0
  71. accelforge/mapper/FFM/data.py +61 -0
  72. accelforge/mapper/FFM/main copy.py +236 -0
  73. accelforge/mapper/FFM/main.py +208 -0
  74. accelforge/mapper/FFM/mappings.py +510 -0
  75. accelforge/mapper/FFM/pmappings.py +310 -0
  76. accelforge/mapper/__init__.py +4 -0
  77. accelforge/mapper.py +0 -0
  78. accelforge/model/__init__.py +1 -0
  79. accelforge/model/_looptree/__init__.py +0 -0
  80. accelforge/model/_looptree/accesses.py +335 -0
  81. accelforge/model/_looptree/capacity/__init__.py +1 -0
  82. accelforge/model/_looptree/capacity/aggregators.py +36 -0
  83. accelforge/model/_looptree/capacity/capacity.py +47 -0
  84. accelforge/model/_looptree/energy.py +150 -0
  85. accelforge/model/_looptree/equivalent_ranks.py +29 -0
  86. accelforge/model/_looptree/latency/__init__.py +1 -0
  87. accelforge/model/_looptree/latency/latency.py +98 -0
  88. accelforge/model/_looptree/latency/memory.py +120 -0
  89. accelforge/model/_looptree/latency/processors.py +92 -0
  90. accelforge/model/_looptree/mapping_utilities.py +71 -0
  91. accelforge/model/_looptree/reuse/__init__.py +4 -0
  92. accelforge/model/_looptree/reuse/isl/__init__.py +1 -0
  93. accelforge/model/_looptree/reuse/isl/des.py +59 -0
  94. accelforge/model/_looptree/reuse/isl/isl_functions.py +374 -0
  95. accelforge/model/_looptree/reuse/isl/mapping_to_isl/__init__.py +4 -0
  96. accelforge/model/_looptree/reuse/isl/mapping_to_isl/analyze_mapping.py +297 -0
  97. accelforge/model/_looptree/reuse/isl/mapping_to_isl/skews_from_mapping.py +236 -0
  98. accelforge/model/_looptree/reuse/isl/mapping_to_isl/tiling.py +685 -0
  99. accelforge/model/_looptree/reuse/isl/mapping_to_isl/types.py +188 -0
  100. accelforge/model/_looptree/reuse/isl/spatial.py +260 -0
  101. accelforge/model/_looptree/reuse/isl/temporal.py +182 -0
  102. accelforge/model/_looptree/reuse/symbolic/__init__.py +1 -0
  103. accelforge/model/_looptree/reuse/symbolic/symbolic copy 2.py +1346 -0
  104. accelforge/model/_looptree/reuse/symbolic/symbolic copy.py +1408 -0
  105. accelforge/model/_looptree/reuse/symbolic/symbolic.py +1396 -0
  106. accelforge/model/_looptree/run.py +122 -0
  107. accelforge/model/_looptree/types.py +26 -0
  108. accelforge/model/_looptree/visualization/__init__.py +0 -0
  109. accelforge/model/_looptree/visualization/occupancy.py +11 -0
  110. accelforge/model/main.py +222 -0
  111. accelforge/plotting/__init__.py +2 -0
  112. accelforge/plotting/mappings.py +219 -0
  113. accelforge/plotting/specs.py +57 -0
  114. accelforge/util/__init__.py +4 -0
  115. accelforge/util/_base_analysis_types.py +24 -0
  116. accelforge/util/_basetypes.py +1089 -0
  117. accelforge/util/_frozenset.py +36 -0
  118. accelforge/util/_isl.py +29 -0
  119. accelforge/util/_itertools.py +14 -0
  120. accelforge/util/_mathfuncs.py +57 -0
  121. accelforge/util/_parse_expressions.py +339 -0
  122. accelforge/util/_picklecache.py +32 -0
  123. accelforge/util/_setexpressions.py +268 -0
  124. accelforge/util/_sympy/__init__.py +0 -0
  125. accelforge/util/_sympy/broadcast_max.py +18 -0
  126. accelforge/util/_visualization.py +112 -0
  127. accelforge/util/_yaml.py +579 -0
  128. accelforge/util/parallel.py +193 -0
  129. accelforge-0.0.1.dist-info/METADATA +64 -0
  130. accelforge-0.0.1.dist-info/RECORD +258 -0
  131. accelforge-0.0.1.dist-info/WHEEL +5 -0
  132. accelforge-0.0.1.dist-info/licenses/LICENSE +19 -0
  133. accelforge-0.0.1.dist-info/top_level.txt +5 -0
  134. docs/_build/html/_sources/fastfusion.frontend.mapper.rst.txt +37 -0
  135. docs/_build/html/_sources/fastfusion.frontend.rst.txt +70 -0
  136. docs/_build/html/_sources/fastfusion.frontend.workload.rst.txt +21 -0
  137. docs/_build/html/_sources/fastfusion.mapper.FFM.rst.txt +37 -0
  138. docs/_build/html/_sources/fastfusion.mapper.rst.txt +18 -0
  139. docs/_build/html/_sources/fastfusion.rst.txt +20 -0
  140. docs/_build/html/_sources/fastfusion.util.rst.txt +21 -0
  141. docs/_build/html/_sources/index.rst.txt +87 -0
  142. docs/_build/html/_sources/modules.rst.txt +7 -0
  143. docs/_build/html/_sources/notes/citation.rst.txt +45 -0
  144. docs/_build/html/_sources/notes/definitions.rst.txt +43 -0
  145. docs/_build/html/_sources/notes/faqs.rst.txt +39 -0
  146. docs/_build/html/_sources/notes/modeling/accelerator_energy_latency.rst.txt +72 -0
  147. docs/_build/html/_sources/notes/modeling/component_energy_area.rst.txt +96 -0
  148. docs/_build/html/_sources/notes/modeling/mapping.rst.txt +100 -0
  149. docs/_build/html/_sources/notes/modeling.rst.txt +33 -0
  150. docs/_build/html/_sources/notes/parsing/arithmetic_parsing.rst.txt +136 -0
  151. docs/_build/html/_sources/notes/parsing/setexpressions.rst.txt +63 -0
  152. docs/_build/html/_sources/notes/parsing/yaml_parsing.rst.txt +176 -0
  153. docs/_build/html/_sources/notes/quickstart_and_installation.rst.txt +9 -0
  154. docs/_build/html/_sources/notes/spec/architecture.rst.txt +133 -0
  155. docs/_build/html/_sources/notes/spec/mapping.rst.txt +12 -0
  156. docs/_build/html/_sources/notes/spec/workload.rst.txt +83 -0
  157. docs/_build/html/_sources/notes/spec.rst.txt +36 -0
  158. docs/source/_ext/include_attrs.py +213 -0
  159. docs/source/_ext/include_docstring.py +364 -0
  160. docs/source/_ext/include_functions.py +154 -0
  161. docs/source/_ext/include_notebook.py +131 -0
  162. docs/source/_ext/include_yaml.py +119 -0
  163. docs/source/_ext/inherited_attributes.py +222 -0
  164. docs/source/_ext/paths.py +4 -0
  165. docs/source/conf.py +79 -0
  166. examples/arches/compute_in_memory/_include.yaml +74 -0
  167. examples/arches/compute_in_memory/_include_functions.py +229 -0
  168. examples/arches/compute_in_memory/_load_spec.py +57 -0
  169. examples/arches/compute_in_memory/components/c2c_multiplier.py +181 -0
  170. examples/arches/compute_in_memory/components/dac_c2c_r2r.py +605 -0
  171. examples/arches/compute_in_memory/components/misc.py +195 -0
  172. examples/arches/compute_in_memory/components/util/bit_functions.py +51 -0
  173. examples/arches/compute_in_memory/components/zero_comparator.py +92 -0
  174. examples/arches/compute_in_memory/isaac.yaml +233 -0
  175. examples/arches/compute_in_memory/memory_cells/ecram_demo.yaml +63 -0
  176. examples/arches/compute_in_memory/memory_cells/rram_example.yaml +63 -0
  177. examples/arches/compute_in_memory/memory_cells/rram_isaac_isca_2016.yaml +64 -0
  178. examples/arches/compute_in_memory/memory_cells/rram_neurosim_default.yaml +63 -0
  179. examples/arches/compute_in_memory/memory_cells/rram_raella_isca_2023.yaml +70 -0
  180. examples/arches/compute_in_memory/memory_cells/rram_wan_nature_2022.yaml +63 -0
  181. examples/arches/compute_in_memory/memory_cells/sram_colonnade_jssc_2021.yaml +63 -0
  182. examples/arches/compute_in_memory/memory_cells/sram_example.yaml +63 -0
  183. examples/arches/compute_in_memory/memory_cells/sram_jia_jssc_2020.yaml +63 -0
  184. examples/arches/compute_in_memory/memory_cells/sram_sinangil_jssc_2021.yaml +63 -0
  185. examples/arches/compute_in_memory/memory_cells/sram_wang_vlsi_2022.yaml +63 -0
  186. examples/arches/compute_in_memory/wang_vlsi_2022.yaml +289 -0
  187. examples/arches/eyeriss.yaml +68 -0
  188. examples/arches/fanout_variations/at_glb.yaml +31 -0
  189. examples/arches/fanout_variations/at_glb_with_fanout_node.yaml +34 -0
  190. examples/arches/fanout_variations/at_mac.yaml +31 -0
  191. examples/arches/fanout_variations/at_mac_with_constraints.yaml +38 -0
  192. examples/arches/fanout_variations/at_mac_with_fanout_node.yaml +34 -0
  193. examples/arches/nvdla.yaml +47 -0
  194. examples/arches/simple.yaml +28 -0
  195. examples/arches/tpu_v4i.yaml +67 -0
  196. examples/mappings/unfused_matmuls_to_simple.yaml +33 -0
  197. examples/misc/component_annotated.yaml +33 -0
  198. examples/workloads/gpt3_6.7B.yaml +124 -0
  199. examples/workloads/matmuls.yaml +20 -0
  200. examples/workloads/mobilenet_28.yaml +81 -0
  201. examples/workloads/mobilenet_various_separate.yaml +106 -0
  202. examples/workloads/three_matmuls_annotated.yaml +59 -0
  203. notebooks/.ipynb_checkpoints/fastfusion_arch_study_michael-checkpoint.ipynb +359 -0
  204. notebooks/compute_in_memory/_scripts.py +339 -0
  205. notebooks/compute_in_memory/isaac.guide.ipynb +270 -0
  206. notebooks/compute_in_memory/wang_vlsi_2022.ipynb +602 -0
  207. notebooks/paths.py +4 -0
  208. notebooks/tutorials/.ipynb_checkpoints/1_FFM-checkpoint.ipynb +3110 -0
  209. notebooks/tutorials/FFM.ipynb +3498 -0
  210. notebooks/tutorials/_include.py +48 -0
  211. notebooks/tutorials/component_energy_area.ipynb +363 -0
  212. tests/Q_mapping.yaml +38 -0
  213. tests/__init__.py +0 -0
  214. tests/conv.mapping.yaml +27 -0
  215. tests/conv.workload.yaml +13 -0
  216. tests/conv_sym.mapping.yaml +43 -0
  217. tests/copy.mapping.yaml +35 -0
  218. tests/copy.workload.yaml +15 -0
  219. tests/distribuffers/__init__.py +0 -0
  220. tests/distribuffers/multicast/test_cases.yaml +482 -0
  221. tests/distribuffers/spec/binding/valid_bindings.yaml +97 -0
  222. tests/distribuffers/spec/distributed.yaml +100 -0
  223. tests/distribuffers/spec/logical_arch.yaml +32 -0
  224. tests/distribuffers/spec/physical_arch.yaml +69 -0
  225. tests/distribuffers/test_binding.py +48 -0
  226. tests/frontend/__init__.py +0 -0
  227. tests/frontend/test_mapping_viz.py +52 -0
  228. tests/mapper/__init__.py +0 -0
  229. tests/mapper/configs/conv1d/conv1d.mapping.yaml +31 -0
  230. tests/mapper/configs/conv1d/conv1d.workload.yaml +11 -0
  231. tests/mapper/configs/two_conv1d/two_conv1d.expected.yaml +38 -0
  232. tests/mapper/configs/two_conv1d/two_conv1d.mapping.yaml +54 -0
  233. tests/mapper/configs/two_conv1d/two_conv1d.workload.yaml +19 -0
  234. tests/mapper/test_mapping_to_isl.py +90 -0
  235. tests/mapper/test_spatial_reuse_analysis.py +67 -0
  236. tests/mapper/test_temporal_reuse_analysis.py +56 -0
  237. tests/mapper/util.py +58 -0
  238. tests/matmul.mapping.yaml +29 -0
  239. tests/matmul.workload.yaml +12 -0
  240. tests/matmul_spatial.mapping.yaml +44 -0
  241. tests/mha.renames.yaml +65 -0
  242. tests/mha.workload.yaml +67 -0
  243. tests/mha.yaml +59 -0
  244. tests/mha_full.workload.yaml +67 -0
  245. tests/mobilenet.workload.yaml +35 -0
  246. tests/mobilenet_long.workload.yaml +64 -0
  247. tests/pmappingcache.py +24 -0
  248. tests/processing_stage.arch.yaml +40 -0
  249. tests/snowcat.arch.yaml +36 -0
  250. tests/test_ffm_join_pmappings.py +106 -0
  251. tests/test_ffm_make_pmappings.py +82 -0
  252. tests/test_ffm_make_tile_shapes.py +49 -0
  253. tests/test_mapper.py +100 -0
  254. tests/test_model.py +37 -0
  255. tests/test_plotting.py +72 -0
  256. tests/test_processing_stage.py +46 -0
  257. tests/test_symbolic_model.py +248 -0
  258. tests/test_workload.py +141 -0
@@ -0,0 +1,1736 @@
+"""
+A module containing the visualization and types needed to run mapspace exploration
+in AccelForge.
+"""
+
+import copy
+from dataclasses import dataclass
+import inspect
+import itertools
+import pydot
+
+from typing import (
+    # Collections
+    Any,
+    List,
+    # Object definitions
+    Annotated,
+    Callable,
+    Literal,
+    Self,
+    # Type constructions
+    Type,
+    TypeVar,
+    TypeAlias,
+    # Variable meta-mandates
+    Optional,
+    Union,
+    override,
+)
+from collections.abc import Set
+from pydantic import ConfigDict, Discriminator, Tag, computed_field
+import sympy
+
+from accelforge.util._basetypes import (
+    # Parsing helpers for the input files.
+    ParsableModel,
+    ParsableList,
+    ParsesTo,
+    # Retrieves information from YAML tags.
+    _get_tag,
+    _uninstantiable,
+    NoParse,
+)
+from accelforge.frontend.workload import RankVariable, TensorName
+from accelforge.util._visualization import ColorMap, _pydot_graph
+from accelforge.util.parallel import _SVGJupyterRender
+from accelforge._version import __version__
+from accelforge.frontend import arch
+
+T = TypeVar("T", bound="MappingNode")
+"""TypeVar T: Restricts the allowable types to types of MappingNodes."""
+
+NodeList: TypeAlias = ParsableList[
+    Annotated[
+        Union[
+            Annotated["Split", Tag("Split")],
+            Annotated["Compute", Tag("Compute")],
+            Annotated["Storage", Tag("Storage")],
+            Annotated["Temporal", Tag("Temporal")],
+            Annotated["Spatial", Tag("Spatial")],
+            Annotated["Sequential", Tag("Sequential")],
+            Annotated["Pipeline", Tag("Pipeline")],
+            Annotated["Nested", Tag("Nested")],
+            Annotated["Reservation", Tag("Reservation")],
+            Annotated["Mapping", Tag("Mapping")],
+            Annotated["ProcessingStage", Tag("ProcessingStage")],
+        ],
+        Discriminator(_get_tag),
+    ]
+]
+"""
+TypeAlias NodeList: ParsableList that can contain and discriminate between
+MappingNodes of different types.
+"""
+
+_NO_JOIN_MAPPING_VISUALIZATION = False
+
+# =============================================================================
+# LoopTree Mapping Nodes
+# =============================================================================
+
+
+@_uninstantiable
+class MappingNode(ParsableModel):
+    """
+    Represents a Node in the Mapping, which can be a loop, a storage node, a
+    compute node, etc.
+    """
+
+    _constraint_lambdas: List[Callable[[], bool]] = []
+    """ Constraints that apply to this node. """
+
+    _must_be_here: bool = False
+    """ If True, the mapper may not move this node. """
+
+    _required: bool = False
+    """ If True, the mapper may not remove this node. """
+
99
+ def _render_node_name(self) -> str:
100
+ """The name for a Pydot node."""
101
+ return f"{self.__class__.__name__}_{id(self)}"
102
+
103
+ def _render_node_label(self, **kwargs) -> str:
104
+ """The label for a Pydot node."""
105
+ return self.__str__()
106
+
107
+ def _render_node_shape(self) -> str:
108
+ """The shape for a Pydot node."""
109
+ return "box"
110
+
111
+ def _render_node(self, **kwargs) -> str:
112
+ """Render this node using Pydot."""
113
+ return pydot.Node(
114
+ self._render_node_name(),
115
+ label=self._render_node_label(**kwargs),
116
+ shape=self._render_node_shape(),
117
+ style="filled",
118
+ fillcolor=self._render_node_color(),
119
+ margin=0,
120
+ )
121
+
122
+ def _parent2next(self) -> "MappingNode":
123
+ """
124
+ Return the parent to the next node in the tree. This is used for nodes that
125
+ don't appear in the tree, like Nested nodes.
126
+ """
127
+ return self
128
+
129
+ def _parent2child(
130
+ self, parent: "MappingNode"
131
+ ) -> list[tuple["MappingNode", "MappingNode"]]:
132
+ """
133
+ Returns a list of tuples, each one being a parent and child node.
134
+ """
135
+ return []
136
+
137
+ def _render_make_children(self, **kwargs) -> list[str]:
138
+ """
139
+ Renders the children of this node and returns them as a list of strings.
140
+ """
141
+ return []
142
+
143
+    def get_nodes_of_type(self, types: Type[T] | tuple[Type[T], ...]) -> List[T]:
+        """
+        Returns all sub-nodes, including this one, that match the given types.
+        """
+        nodes: List[T] = []
+        if isinstance(self, types):
+            nodes.append(self)
+        if isinstance(self, MappingNodeWithChildren):
+            for node in self.nodes:
+                # The recursive call already includes the child itself if it
+                # matches, so no separate append is needed (avoids duplicates).
+                nodes.extend(node.get_nodes_of_type(types))
+        return nodes
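+
+    # Illustrative sketch (editor's note, not part of the released module):
+    # assuming a small, hypothetical tree, get_nodes_of_type collects matching
+    # nodes recursively:
+    #
+    #     mapping = Nested(nodes=[
+    #         Temporal(rank_variable="m", tile_shape=2),
+    #         Storage(tensors=["A"], component="Buffer"),
+    #         Compute(einsum="E0", component="MAC"),
+    #     ])
+    #     loops = mapping.get_nodes_of_type(Loop)         # -> [the Temporal node]
+    #     holders = mapping.get_nodes_of_type((Storage,)) # -> [the Storage node]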
+
+    def _flatten(self) -> list["MappingNode"]:
+        if isinstance(self, MappingNodeWithChildren):
+            result = [self]
+            for node in self.nodes:
+                result.extend(node._flatten())
+            return result
+        return [self]
+
+    def _render_node_color(self) -> str:
+        """The color for a Pydot node."""
+        return "white"
+
+    def __hash__(self):
+        """
+        Identity-based hash so that nodes can be used as keys in mappings to
+        other objects.
+        """
+        return id(self)
+
+    def __eq__(self, other: Any):
+        return self is other
+
+    def __init_subclass__(cls, **kwargs):
+        # Let Pydantic build the subclass first.
+        super().__init_subclass__(**kwargs)
+        # Read the *raw* attribute without descriptor binding.
+        h = inspect.getattr_static(cls, "__hash__", None)
+        # Replace if unhashable (None) or if it's just BaseModel's default.
+        if h is None or h is ParsableModel.__hash__:
+            cls.__hash__ = MappingNode.__hash__
+
+    def compact_str(self) -> str:
+        """Returns a compact string representation of this node."""
+        return self.__str__()
+
+
+@dataclass(frozen=True)
+class TilePattern:
+    tile_shape: ParsesTo[
+        Literal["symbol"] | sympy.Symbol | int | str | None | sympy.Expr
+    ] = "symbol"
+    """
+    The common tile shape of the pattern. This is the number of indices by which
+    the tile moves each iteration.
+    """
+
+    initial_tile_shape: ParsesTo[
+        Literal["symbol"] | sympy.Symbol | int | None | str | sympy.Expr
+    ] = "symbol"
+    """
+    The initial tile shape. This is the shape of the tile at the first iteration.
+    Subsequent iterations may be smaller if they overlap previous iterations.
+    """
+
+    calculated_n_iterations: (
+        Literal["symbol"] | sympy.Symbol | int | None | str | sympy.Expr
+    ) = None
+    """ The number of iterations in the pattern. Do not set this! Used internally
+    by the mapper. """
+
+    def _symbol_attrs(self) -> tuple[str, ...]:
+        """The attributes that may be symbols."""
+        return ("tile_shape", "initial_tile_shape", "calculated_n_iterations")
+    def __str__(self) -> str:
+        return self.as_str()
+
+    def as_str(self, with_initial_tile_shape=True, with_tile_shape=True):
+        s = []
+        if self.calculated_n_iterations not in (None, "symbol"):
+            s.append(f"in [0..{self.calculated_n_iterations})")
+        if with_initial_tile_shape and (
+            self.initial_tile_shape not in (None, "symbol")
+        ):
+            s.append(f"initial={self.initial_tile_shape}")
+        if with_tile_shape and (self.tile_shape not in (None, "symbol")):
+            s.append(f"tile_shape={self.tile_shape}")
+        return " ".join(s)
+
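+    # Illustrative sketch (editor's note, not part of the released module):
+    # a pattern with initial tile shape 3 and common tile shape 2 tiles
+    # range(7) as [0, 1, 2], [3, 4], [5, 6], and renders as:
+    #
+    #     p = TilePattern(tile_shape=2, initial_tile_shape=3)
+    #     str(p)  # -> "initial=3 tile_shape=2"
+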
+    def update(self, **kwargs) -> "TilePattern":
+        """Update the TilePattern with the given keyword arguments."""
+        return type(self)(**{**self.__dict__, **kwargs})
+
+    def _symbol2str(self) -> "TilePattern":
+        """
+        Convert the symbols in the TilePattern to strings, and return a new
+        TilePattern with the symbols replaced by their names.
+        """
+
+        def _symbol2str(x: sympy.Symbol | int | None) -> str | int | None:
+            return x.name if isinstance(x, sympy.Symbol) else x
+
+        return type(self)(
+            **{x: _symbol2str(getattr(self, x)) for x in self._symbol_attrs()}
+        )
+
+    def _prepend_symbols(self, prepend: str) -> "TilePattern":
+        def _prepend(x: sympy.Symbol | int | None) -> str | int | None:
+            if isinstance(x, sympy.Symbol):
+                x = x.name
+            return prepend + x if isinstance(x, str) else x
+
+        # update() takes keyword arguments, so the dict must be unpacked.
+        return self.update(
+            **{x: _prepend(getattr(self, x)) for x in self._symbol_attrs()}
+        )
+
+    def __eq__(self, other: Any) -> bool:
+        if not isinstance(other, TilePattern):
+            return False
+        return all(getattr(self, x) == getattr(other, x) for x in self._symbol_attrs())
+
+    def __hash__(self) -> int:
+        return hash((self.initial_tile_shape, self.tile_shape))
+
+    def _rename_to_match(
+        self, other: "TilePattern"
+    ) -> tuple["TilePattern", dict[str, str]]:
+        """
+        Changes the symbols in this TilePattern to match the other TilePattern.
+
+        Parameters
+        ----------
+        other:
+            The TilePattern to match.
+
+        Returns
+        -------
+        A tuple containing the updated TilePattern and a dictionary of
+        source->target symbol renames.
+        """
+        renames = {}
+        setattrs = {}
+        for x in self._symbol_attrs():
+            if getattr(self, x) != getattr(other, x):
+                renames[getattr(self, x)] = getattr(other, x)
+                setattrs[x] = getattr(other, x)
+        return self.update(**setattrs), renames
+
+    def _clear_symbols(self) -> "TilePattern":
+        """
+        Clears the symbols in this TilePattern, replacing them with None.
+        """
+
+        def desymbol(x: str | sympy.Symbol | int | None) -> str | int | None:
+            if isinstance(x, (str, sympy.Symbol)):
+                return None
+            return x
+
+        return self.update(
+            **{x: desymbol(getattr(self, x)) for x in self._symbol_attrs()}
+        )
+
+
+@_uninstantiable
+class Loop(MappingNode):
+    """
+    A bounded loop over a rank variable with a given shape and/or pattern.
+
+    Do not instantiate directly; inherited by :class:`~.Temporal` and
+    :class:`~.Spatial`.
+    """
+
+    rank_variable: set[RankVariable] | RankVariable
+    """ The rank variable(s) iterated over in this loop. This may be a
+    single rank variable, or a set of rank variables if the loop is shared between
+    multiple Einsums. """
+
+    tile_shape: ParsesTo[sympy.Symbol | sympy.Expr | int | str] = "symbol"
+    """
+    The (common) tile shape of the iteration. For example, if the iteration
+    space is range(6) and the tile shape is 3, then we create and iterate over
+    two tiles [0, 1, 2] and [3, 4, 5].
+
+    This attribute specifies the *common* tile shape because
+    `initial_tile_shape` may be specified.
+
+    For users writing YAML, the value should be an integer.
+
+    For those developing the mapper, the literal string "symbol" is often used
+    to tell the model to create a sympy symbol to use as the tile shape. Any
+    other string may be specified to explicitly request a variable name (later
+    converted to a sympy variable).
+    """
+
+    initial_tile_shape: ParsesTo[sympy.Symbol | sympy.Expr | int | str | None] = None
+    """
+    The shape of the first tile. This attribute is optional. If not specified,
+    all tiles have the same shape.
+
+    If specified, the initial tile shape may differ. For example, an initial
+    tile shape of 3 and tile shape of 2 creates the following tiles in the
+    iteration space: [0, 1, 2], [3, 4], [5, 6], ...
+
+    Similarly to tile shape, this value should be an integer when writing a
+    YAML input.
+
+    For those developing the mapper, this attribute can be a string. See
+    tile_shape for details.
+    """
+
+    _calculated_n_iterations: (
+        Literal["symbol"] | sympy.Symbol | sympy.Expr | int | str | None
+    ) = None
+
+    _assume_perfect_factor: bool = True
+    """ Whether the Mapper assumes that tile shapes perfectly divide tensor shapes
+    and parent tile shapes. """
+
+    _fused: Optional[bool] = None
+    """ Whether this Loop is shared with another Einsum. """
+
+    model_config = ConfigDict(arbitrary_types_allowed=True)
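+
+    # Illustrative sketch (editor's note, not part of the released module):
+    # with a rank of extent 7, tile_shape=2 and initial_tile_shape=3 describe
+    # the tiles [0, 1, 2], [3, 4], [5, 6] from the docstrings above, and the
+    # tile_pattern property below bundles them into a TilePattern:
+    #
+    #     loop = Temporal(rank_variable="m", tile_shape=2, initial_tile_shape=3)
+    #     loop.tile_pattern  # -> TilePattern(tile_shape=2, initial_tile_shape=3)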
+
+    def __str__(self) -> str:
+        return f"for {self.rank_variable} {self.tile_pattern}"
+
+    def __eq__(self, other: Any) -> bool:
+        return (
+            isinstance(other, Loop)
+            and self.rank_variable == other.rank_variable
+            and self.tile_pattern == other.tile_pattern
+        )
+
+    def _render_node_shape(self) -> str:
+        return "box"
+
+    def _render_node_color(self) -> str:
+        return "#FCC2FC"
+
+    @override
+    def compact_str(self) -> str:
+        """Returns a compact string representation of this Loop."""
+        rv = self.rank_variable
+        if isinstance(rv, (set, frozenset)):
+            rv = ",".join(sorted(rv))
+        return f"{rv} {self.tile_pattern}"
+
+    def _merge(self, other: "Loop", **kwargs) -> "Loop":
+        """Merge this Loop with another Loop, returning the result."""
+        if not isinstance(other, Loop):
+            raise ValueError(f"Expected Loop, got {type(other)}")
+        if self.tile_pattern != other.tile_pattern:
+            raise ValueError(
+                f"Tile patterns do not match: {self.tile_pattern} != {other.tile_pattern}"
+            )
+
+        my_rv, other_rv = self.rank_variable, other.rank_variable
+        my_rv = my_rv if isinstance(my_rv, (set, frozenset)) else set((my_rv,))
+        other_rv = (
+            other_rv if isinstance(other_rv, (set, frozenset)) else set((other_rv,))
+        )
+        return type(self)(
+            rank_variable=my_rv | other_rv,
+            tile_pattern=self.tile_pattern,
+            _assume_perfect_factor=self._assume_perfect_factor,
+            **kwargs,
+        )
+
+    @property
+    def tile_pattern(self) -> TilePattern:
+        return TilePattern(
+            tile_shape=self.tile_shape,
+            initial_tile_shape=self.initial_tile_shape,
+            calculated_n_iterations=self.calculated_n_iterations,
+        )
+
+    @tile_pattern.setter
+    def tile_pattern(self, value: TilePattern):
+        self.tile_shape = value.tile_shape
+        self.initial_tile_shape = value.initial_tile_shape
+        self.calculated_n_iterations = value.calculated_n_iterations
+
+    @property
+    def calculated_n_iterations(self) -> int:
+        """The number of iterations performed by this loop."""
+        return self._calculated_n_iterations
+
+    @calculated_n_iterations.setter
+    def calculated_n_iterations(self, value: int) -> None:
+        """Set the number of iterations performed by this loop. Do not set this!
+        This is calculated by the Mapper."""
+        self._calculated_n_iterations = value
+
+
+class Temporal(Loop):
+    """A Temporal :class:`~.Loop`."""
+
+    @override
+    def compact_str(self) -> str:
+        return f"T-{super().compact_str()}"
+
+    def __eq__(self, other: "Temporal") -> bool:
+        return isinstance(other, Temporal) and super().__eq__(other)
+
+    def _merge(self, other: "Temporal") -> "Temporal":
+        if not isinstance(other, Temporal):
+            raise ValueError(f"Expected Temporal, got {type(other)}")
+        return super()._merge(other)
+
+    def _render_node_label(self, **kwargs) -> str:
+        with_initial_tile_shape = True
+        with_tile_shape = kwargs.get("with_tile_shape", True)
+        return (
+            f"for {self.rank_variable} "
+            f"{self.tile_pattern.as_str(with_initial_tile_shape, with_tile_shape)}"
+        )
+
+
+class Spatial(Loop):
+    """A spatial :class:`~.Loop`."""
+
+    name: int | str
+    """ The dimension over which the spatial is occurring. """
+
+    component: str
+    """ The component name across which different spatial iterations occur. """
+
+    component_object: NoParse[arch.Leaf] = None
+    """ The component object across which different spatial iterations occur. """
+
+    _constrained_to_one: bool = False
+    """ Whether this Spatial loop is constrained to one iteration. Do not set this;
+    used internally by the Mapper."""
+
+    @override
+    def compact_str(self) -> str:
+        return f"S-{self.name}-{super().compact_str()}"
+
+    def __str__(self) -> str:
+        return f"S-{self.name} " + super().__str__()
+
+    def __eq__(self, other: "Spatial") -> bool:
+        return (
+            isinstance(other, Spatial)
+            and super().__eq__(other)
+            and self.name == other.name
+            and self.component == other.component
+            and self.component_object == other.component_object
+        )
+
+    def _merge(self, other: "Spatial") -> "Spatial":
+        if not isinstance(other, Spatial):
+            raise ValueError(f"Expected Spatial, got {type(other)}")
+        if self.name != other.name:
+            raise ValueError(f"Names do not match: {self.name} != {other.name}")
+        if self.component != other.component:
+            raise ValueError(
+                f"Components do not match: {self.component} != {other.component}"
+            )
+        return super()._merge(
+            other,
+            name=self.name,
+            component=self.component,
+            component_object=self.component_object,
+        )
+
+    def _render_node_label(self, **kwargs) -> str:
+        with_initial_tile_shape = kwargs.get("with_initial_tile_shape", True)
+        with_tile_shape = kwargs.get("with_tile_shape", True)
+        return (
+            f"S-{self.name}-for {self.rank_variable} "
+            f"{self.tile_pattern.as_str(with_initial_tile_shape, with_tile_shape)}"
+        )
+
+
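+# Illustrative sketch (editor's note, not part of the released module): the
+# compact forms render as "T-<rank> ..." and "S-<dim>-<rank> ...", e.g.:
+#
+#     Temporal(rank_variable="m", tile_shape=4).compact_str()
+#     # -> "T-m tile_shape=4"
+#     Spatial(name="X", component="PE", rank_variable="n", tile_shape=1).compact_str()
+#     # -> "S-X-n tile_shape=1"
+
+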
+class TensorHolder(MappingNode):
+    """A node that represents a hardware Component holding a set of tensors."""
+
+    tensors: ParsableList[TensorName]
+    """ The names of the tensors being held in this node. """
+
+    component: str
+    """ The name of the component holding the tensors. """
+
+    component_object: NoParse[arch.Component] = None
+    """ The component object holding the tensors. """
+
+    _must_keep_tensors: ParsableList[TensorName] = ParsableList()
+    """ Which tensor(s) the Mapper must keep here. Do not set this! Used internally
+    by the Mapper."""
+
+    _backing: Set[TensorName] = set()
+    """ Which tensor(s) are backed by this node. Do not set this! Used internally by
+    the Mapper."""
+
+    _lower: bool = True
+    """ Whether this tensor holder can be lowered. Do not set this! Used internally
+    by the Mapper."""
+
+    persistent: bool = False
+    """
+    Whether this tensor holder is persistent. Persistent tensors can't be tiled and
+    must be kept in backing storage for the full duration of the workload's
+    execution.
+    """
+
+    def __eq__(self, other: Any) -> bool:
+        return (
+            isinstance(other, TensorHolder)
+            and set(self.tensors) == set(other.tensors)
+            and self.component == other.component
+        )
+
+    @override
+    def compact_str(self) -> str:
+        tname = ",".join(self.tensors)
+        return f"[{tname} in {self.component}]"
+
+    def __str__(self, color_map: ColorMap = None) -> str:
+        tensors = self.tensors
+        if color_map is not None:
+            format_list = [f"{self.component} reuses"] + list(tensors)
+            return color_map.format_list(format_list)
+        return f"{self.component} reuses {', '.join(tensors)}"
+
+    @property
+    def tensor(self) -> TensorName:
+        """If there is one tensor held in this tensor holder, returns its name.
+        Otherwise, raises an error."""
+        if len(self.tensors) != 1:
+            raise ValueError(
+                f"TensorHolder node {repr(self)} has {len(self.tensors)} tensors. "
+                f"Access the tensors property instead."
+            )
+        return self.tensors[0]
+
+    def _render_node_shape(self) -> str:
+        return "cylinder"
+
+    def _render_node_color(self) -> str:
+        return "#D7FCD7"
+
+    def _merge(self, other: "TensorHolder") -> "TensorHolder":
+        if not isinstance(other, TensorHolder):
+            raise ValueError(f"Expected TensorHolder, got {type(other)}")
+
+        if self.component != other.component:
+            raise ValueError(
+                f"Components do not match: {self.component} != {other.component}"
+            )
+
+        new = type(self)(
+            tensors=self.tensors + other.tensors,
+            component=self.component,
+            component_object=self.component_object,
+        )
+        return new
+
+
+class Storage(TensorHolder):
+    """
+    A Storage :class:`~.TensorHolder` that can hold tensors for reuse.
+    """
+
+    def _merge(self, other: "Storage") -> "Storage":
+        if not isinstance(other, Storage):
+            raise ValueError(f"Expected Storage, got {type(other)}")
+        return super()._merge(other)
+
+
+class ProcessingStage(TensorHolder):
+    """
+    A ProcessingStage :class:`~.TensorHolder` that acts as a pass-through, where
+    data is not reused but incurs accesses into this ProcessingStage.
+    """
+
+    def _render_node_shape(self) -> str:
+        return "rarrow"
+
+    def _render_node_color(self) -> str:
+        return "#FFCC99"
+
+    def __str__(self, color_map: ColorMap = None) -> str:
+        tensors = self.tensors
+        if color_map is not None:
+            format_list = [f"{self.component} processes"] + list(tensors)
+            return color_map.format_list(format_list)
+        return f"{self.component} processes {', '.join(tensors)}"
+
+
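+# Illustrative sketch (editor's note, not part of the released module):
+# Storage and ProcessingStage differ in reuse semantics and rendering, e.g.:
+#
+#     Storage(tensors=["A"], component="Buf").compact_str()   # -> "[A in Buf]"
+#     str(ProcessingStage(tensors=["A"], component="NoC"))    # -> "NoC processes A"
+
+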
+class Compute(MappingNode):
+    """A node that represents a compute operation. These nodes are the leaves of
+    the LoopTree."""
+
+    einsum: str
+    """ The Einsum being computed. """
+
+    component: str
+    """ The name of the compute component performing the computation. """
+
+    component_object: NoParse[arch.Compute | None] = None
+    """ The :class:`~accelforge.frontend.arch.Compute` object performing the
+    computation. """
+
+    @override
+    def compact_str(self) -> str:
+        return f"{self.component} computes {self.einsum}"
+
+    def __str__(self) -> str:
+        return f"{self.component} computes {self.einsum}"
+
+    def _render_node_shape(self) -> str:
+        return "ellipse"
+
+    def _render_node_color(self) -> str:
+        return "#E0EEFF"
+
+
+class MappingNodeWithChildren(MappingNode):
+    """
+    A :class:`~.MappingNode` that also has child nodes.
+    """
+
+    nodes: NodeList = ParsableList()
+    """ The child nodes. """
+
+    @override
+    def _parent2child(
+        self, parent: MappingNode
+    ) -> list[tuple[MappingNode, MappingNode]]:
+        mine = [(self, node) for node in self.nodes]
+        for child in self.nodes:
+            mine.extend(child._parent2child(self))
+        return mine
+
+    @override
+    def _parent2next(self) -> MappingNode:
+        return None
+
+    @override
+    def _render_make_children(self, **kwargs) -> list[pydot.Node]:
+        exclude_types = kwargs.get("exclude_types", tuple())
+        lines = []
+        for child in self.nodes:
+            if not isinstance(child, exclude_types):
+                lines.append(child._render_node(**kwargs))
+                lines.extend(child._render_make_children(**kwargs))
+        return lines
+
+    def _get_backers(self) -> list[TensorHolder]:
+        backing = []
+        for child in self.nodes:
+            if isinstance(child, TensorHolder) and child._backing:
+                backing.append(child)
+            elif isinstance(child, MappingNodeWithChildren):
+                backing.extend(child._get_backers())
+        return backing
+
+    def clear_nodes_of_type(self, types: type | tuple[type]) -> None:
+        """Clears all child nodes that match the given type(s)."""
+        new_nodes = []
+        for node in self.nodes:
+            if isinstance(node, types):
+                continue
+            if isinstance(node, MappingNodeWithChildren):
+                node.clear_nodes_of_type(types)
+            new_nodes.append(node)
+        self.nodes = ParsableList(new_nodes)
+
+    def clear_nodes(self, *nodes: MappingNode) -> None:
+        """Removes nodes that equal any of the given nodes."""
+        new_nodes: list[MappingNode] = []
+        for node in self.nodes:
+            if any(n == node for n in nodes):
+                continue
+            if isinstance(node, MappingNodeWithChildren):
+                node.clear_nodes(*nodes)
+            new_nodes.append(node)
+        self.nodes = ParsableList(new_nodes)
+
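+    # Illustrative sketch (editor's note, not part of the released module):
+    # clear_nodes_of_type removes matching nodes recursively, e.g. dropping all
+    # reservation bookkeeping or all tensor holders from a mapping tree:
+    #
+    #     mapping.clear_nodes_of_type(Reservation)
+    #     mapping.clear_nodes_of_type((Storage, ProcessingStage))
+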
+    def _consolidate_tensor_holders(self) -> None:
+        new_nodes = []
+        for node in self.nodes:
+            if isinstance(node, TensorHolder):
+                found = False
+                for n in new_nodes[::-1]:
+                    if isinstance(n, TensorHolder) and n.component == node.component:
+                        n.tensors.extend(
+                            n2 for n2 in node.tensors if n2 not in n.tensors
+                        )
+                        found = True
+                        break
+                    if isinstance(n, Loop):
+                        break
+                if not found:
+                    new_nodes.append(node)
+            else:
+                new_nodes.append(node)
+            if isinstance(node, MappingNodeWithChildren):
+                node._consolidate_tensor_holders()
+        assert new_nodes, "BUG"
+        self.nodes = ParsableList(new_nodes)
+
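+    # Illustrative sketch (editor's note, not part of the released module):
+    # holders on the same component with no Loop in between are merged into one.
+    # For example, the child list
+    #
+    #     [Storage(tensors=["A"], component="Buf"),
+    #      Storage(tensors=["B"], component="Buf")]
+    #
+    # becomes [Storage(tensors=["A", "B"], component="Buf")].
+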
+    def _consolidate_reservations(self) -> None:
+        new_nodes = []
+        for node in self.nodes:
+            if isinstance(node, Reservation):
+                found = False
+                for n in new_nodes[::-1]:
+                    if isinstance(n, Reservation) and n.resource == node.resource:
+                        n.purposes.extend(node.purposes)
+                        found = True
+                        break
+                    if isinstance(n, Loop):
+                        break
+                if not found:
+                    new_nodes.append(node)
+            else:
+                new_nodes.append(node)
+            if isinstance(node, MappingNodeWithChildren):
+                node._consolidate_reservations()
+        assert new_nodes, "BUG"
+        self.nodes = ParsableList(new_nodes)
+
+    def _elevate_persistent_nodes_above_splits(self) -> None:
+        new_nodes: list[MappingNode] = []
+        for node in self.nodes:
+            if isinstance(node, Split):
+                persistent_nodes = node._get_persistent_nodes()
+                new_nodes.extend(persistent_nodes)
+                node.clear_nodes(*persistent_nodes)
+            if isinstance(node, MappingNodeWithChildren):
+                node._elevate_persistent_nodes_above_splits()
+            new_nodes.append(node)
+        self.nodes = ParsableList(new_nodes)
+
+    def _elevate_tensor_holders_above_splits(self) -> None:
+        new_nodes: list[MappingNode] = []
+        for node in self.nodes:
+            if isinstance(node, Split):
+                shared_tensor_holders = node._get_shared_tensor_holders()
+                new_nodes.extend(shared_tensor_holders)
+                node.clear_nodes(*shared_tensor_holders)
+            if isinstance(node, MappingNodeWithChildren):
+                node._elevate_tensor_holders_above_splits()
+            new_nodes.append(node)
+        self.nodes = ParsableList(new_nodes)
+
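+    # Illustrative sketch (editor's note, not part of the released module):
+    # hoisting pulls persistent or shared holders out of a Split so that they
+    # appear once, above the branches:
+    #
+    #     Split[ Nested[holder_A, ...], Nested[holder_A, ...] ]
+    #       -> holder_A, Split[ Nested[...], Nested[...] ]
+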
+    def _propagate_reservations_between_splits(self) -> None:
+        for node in self.nodes:
+            if isinstance(node, MappingNodeWithChildren):
+                node._propagate_reservations_between_splits()
+
+        if not isinstance(self, Split):
+            return
+
+        for i, node1 in enumerate(self.nodes):
+            for j in range(i + 2, len(self.nodes)):
+                node2 = self.nodes[j]
+                reservations1 = node1.get_nodes_of_type(Reservation)
+                reservations2 = node2.get_nodes_of_type(Reservation)
+
+                shared_reservations = []
+                for reservation1 in reservations1:
+                    for reservation2 in reservations2:
+                        if reservation1 == reservation2:
+                            shared_reservations.append(reservation1)
+                            break
+
+                for s in shared_reservations:
+                    for k in range(i + 1, j):
+                        node3 = self.nodes[k]
+                        if not isinstance(node3, Nested):
+                            raise ValueError(f"Expected Nested node, got {type(node3)}")
+                        reservations3 = node3.get_nodes_of_type(Reservation)
+                        if s not in reservations3:
+                            node3.nodes.insert(0, copy.deepcopy(s))
+
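+    # Illustrative sketch (editor's note, not part of the released module): if
+    # branches 0 and 2 of a Split reserve the same resource, a copy of that
+    # reservation is inserted into branch 1, so the resource stays reserved
+    # across the whole span between the two uses.
+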
+    def _move_tensor_holders_above_reservations(self) -> None:
+        groups = []
+        cur_group = []
+        for node in self.nodes:
+            if isinstance(node, MappingNodeWithChildren):
+                node._move_tensor_holders_above_reservations()
+            if not isinstance(node, (TensorHolder, Reservation)):
+                groups.append(cur_group)
+                cur_group = []
+            cur_group.append(node)
+        groups.append(cur_group)
+        groups = [g for g in groups if g]
+
+        # Within each group, order: other nodes first, then TensorHolders, then
+        # Reservations.
+        groups = [
+            [x for x in g if not isinstance(x, (TensorHolder, Reservation))]
+            + [x for x in g if isinstance(x, (TensorHolder))]
+            + [x for x in g if isinstance(x, (Reservation))]
+            for g in groups
+        ]
+        self.nodes = ParsableList([x for g in groups for x in g])
+
+    def _remove_reservations_for_processing_stages(self) -> None:
+        processing_stages = self.get_nodes_of_type(ProcessingStage)
+        processing_stage_names = set(ps.component for ps in processing_stages)
+        reservations = self.get_nodes_of_type(Reservation)
+        remove = [r for r in reservations if r.resource in processing_stage_names]
+        self.clear_nodes(*remove)
+
+
+class Split(MappingNodeWithChildren):
+    """
+    A :class:`~.MappingNodeWithChildren` that splits the tree into multiple
+    branches, each applying to different Einsums.
+    """
+
+    def __str__(self) -> str:
+        return "Split"
+
+    def _render_node_shape(self) -> str:
+        return "hexagon"
+
+    def _get_persistent_nodes(self) -> list[MappingNode]:
+        nodes = []
+        for n in self.nodes:
+            nodes.extend(n.get_nodes_of_type(TensorHolder))
+            nodes.extend(n.get_nodes_of_type(Reservation))
+        return [n for n in nodes if n.persistent]
+
+    def _get_shared_tensor_holders(self) -> list[TensorHolder]:
+        tensor_holders = [n.get_nodes_of_type(TensorHolder) for n in self.nodes]
+        shared_tensor_holders = []
+        for i in range(len(tensor_holders)):
+            for j in range(i + 1, len(tensor_holders)):
+                for a in tensor_holders[i]:
+                    for b in tensor_holders[j]:
+                        if a._backing & b._backing and a not in shared_tensor_holders:
+                            assert len(a.tensors) == 1 and len(b.tensors) == 1, "BUG"
+                            shared_tensor_holders.append(a)
+                            break
+        return shared_tensor_holders
+
+    def _render_node_color(self) -> str:
+        return "#FFFFE0"
+
+
+class Nested(MappingNodeWithChildren):
+    """
+    A :class:`~.MappingNodeWithChildren` that represents a nested set of nodes.
+    Each node is the parent of the next node.
+    """
+
+    def model_post_init(self, __context__=None) -> None:
+        for node in list(self.nodes)[:-1]:
+            assert not isinstance(
+                node, MappingNodeWithChildren
+            ), "Nested node has a child with children. Only the last child can have children."
+
+    def _parent2child(
+        self, parent: MappingNode
+    ) -> list[tuple[MappingNode, MappingNode]]:
+        parent2child = []
+        for node in self.nodes:
+            parent2child.append((parent, node))
+            parent2child.extend(node._parent2child(parent))
+            parent = node._parent2next()
+        return parent2child
+
+    def _parent2next(self) -> MappingNode:
+        if not self.nodes:
+            raise ValueError("Nested node has no children")
+        return self.nodes[-1]._parent2next()
+
+    # def _render_connect_children(self, names_lines: list[tuple[str, str]], parent_name: str = None) -> list[str]:
+    #     return super()._render_connect_children(names_lines)
+
+    def _render_node_label(self, **kwargs) -> str:
+        if not self.nodes:
+            raise ValueError("Nested node has no children")
+        return self.nodes[0]._render_node_label(**kwargs)
+
+    def _render_node_name(self) -> str:
+        if not self.nodes:
+            raise ValueError("Nested node has no children")
+        return self.nodes[0]._render_node_name()
+
+    def _get_n_shared_loops(self, other: "Nested") -> int:
+        my_backing = set(
+            (t, s.component) for s in self._get_backers() for t in s._backing
+        )
+        other_backing = set(
+            (t, s.component) for s in other._get_backers() for t in s._backing
+        )
+        shared_backing = my_backing & other_backing
+
+        if not shared_backing:
+            return 0
+
+        n_shared_loops = 0
+        for i, node in enumerate(self.nodes):
+            if isinstance(node, Loop):
+                n_shared_loops += 1
+            if (
+                isinstance(node, Reservation)
+                and (node.purpose, node.resource) in shared_backing
+            ):
+                return n_shared_loops
+            if isinstance(node, Split):
+                # Initialize the accumulator once, outside the loop, so the
+                # maximum is taken over all children rather than only the last.
+                max_child_n_shared_loops = 0
+                for child in node.nodes:
+                    try:
+                        max_child_n_shared_loops = max(
+                            max_child_n_shared_loops, child._get_n_shared_loops(other)
+                        )
+                    except ValueError:
+                        pass
+                return max_child_n_shared_loops + n_shared_loops
+
+        raise ValueError("BUG")
+
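+    # Illustrative sketch (editor's note, not part of the released module):
+    # the returned count is the number of loops above the first Reservation
+    # whose (purpose, resource) pair is backed in both mappings; _merge below
+    # uses it to decide how many outer loops the two Nested trees must share.
+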
+    def _break_into_reorderable_groups(
+        self, stop_at_n_loops: int
+    ) -> list[list[MappingNode]]:
+        # We can reorder loops relative to each other.
+        groups = []
+        cur_group = None
+
+        seen_loops = 0
+
+        if stop_at_n_loops == 0 and not any(
+            isinstance(node, Loop) for node in self.nodes
+        ):
+            return [list(self.nodes)]
+
+        i = 0
+        for i, node in enumerate(self.nodes):
+            if seen_loops >= stop_at_n_loops:
+                break
+            is_iteration = isinstance(node, Loop)
+            if cur_group is None:
+                cur_group = []
+            elif (
+                is_iteration and not all(isinstance(x, Loop) for x in cur_group)
+            ) or (not is_iteration and any(isinstance(x, Loop) for x in cur_group)):
+                groups.append(cur_group)
+                cur_group = []
+            cur_group.append(node)
+            assert not isinstance(node, Sequential) or i == len(self.nodes) - 1, "BUG"
+            if isinstance(node, Loop):
+                seen_loops += 1
+
+        if cur_group:
+            groups.append(cur_group)
+
+        final_group = self.nodes[i:]
+        groups.append(final_group)
+
+        if seen_loops < stop_at_n_loops:
+            raise ValueError(
+                f"Expected {stop_at_n_loops} loops, but only found {seen_loops}"
+            )
+
+        # Lower reservations. If reservations are in the second-to-last
+        # non-iteration group, lower them to the last group.
+        # if len(groups) > 3:
+        #     assert not any(isinstance(x, Loop) for x in groups[-1]), "BUG"
+        #     assert not any(isinstance(x, Loop) for x in groups[-3]), "BUG"
+        #     reservations = [x for x in groups[-2] if isinstance(x, Reservation)]
+        #     groups[-1].extend(reservations)
+        #     groups[-3] = [x for x in groups[-3] if x not in reservations]
+
+        return groups
+
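+    # Illustrative sketch (editor's note, not part of the released module):
+    # with stop_at_n_loops=2, a child list
+    #
+    #     [loop_a, loop_b, storage, loop_c, compute]
+    #
+    # breaks into [[loop_a, loop_b]] plus [storage, loop_c, compute] as the
+    # final group; runs of loops and runs of non-loops never mix within a
+    # reorderable group.
+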
+ def _merge(self, other: "Nested", n_shared_loops: int) -> "Nested":
1020
+
1021
+ # Break up the nodes above the indices. We need to have them in the format of
1022
+ # [(loop, other stuff...), (loop, other stuff...), ...]
1023
+ my_groups = self._break_into_reorderable_groups(stop_at_n_loops=n_shared_loops)
1024
+ my_remaining = my_groups.pop(-1)
1025
+ other_groups = other._break_into_reorderable_groups(
1026
+ stop_at_n_loops=n_shared_loops
1027
+ )
1028
+ other_remaining = other_groups.pop(-1)
1029
+
1030
+ # Reorder so that the loops are in the same order. We can't reorder groups that
1031
+ # have other stuff in them because that'll change the behavior of the mapping.
1032
+ zipped_groups = []
1033
+
1034
+ def _pop_loop_group(groups: list[list[MappingNode]]) -> list[MappingNode]:
1035
+ while groups and not any(isinstance(x, Loop) for x in groups[0]):
1036
+ zipped_groups.append(groups.pop(0))
1037
+ return groups.pop(0) if groups else []
1038
+
1039
+ my_loop_group = _pop_loop_group(my_groups)
1040
+ other_loop_group = _pop_loop_group(other_groups)
1041
+ while (my_groups or my_loop_group) and (other_groups or other_loop_group):
1042
+ if not my_loop_group:
1043
+ my_loop_group = _pop_loop_group(my_groups)
1044
+ continue
1045
+ if not other_loop_group:
1046
+ other_loop_group = _pop_loop_group(other_groups)
1047
+ continue
1048
+
1049
+ # Add matching loops from the two groups. If we can't find a match, raise an
1050
+ # error.
1051
+ to_add = None
1052
+ for i, a in enumerate(my_loop_group):
1053
+ for j, b in enumerate(other_loop_group):
1054
+ if a == b:
1055
+ to_add = [a]
1056
+ my_loop_group.pop(i)
1057
+ other_loop_group.pop(j)
1058
+ break
1059
+
1060
+ if to_add is None:
1061
+ # TODO: This check for one is only to early catch bugs coming here. The
1062
+ # code below says that if we couldn't find a match, then ignore rank
1063
+ # variables and assume that rank variable translation would fix it.
1064
+ assert len(my_loop_group) == 1 or len(other_loop_group) == 1
1065
+ has_one, may_not_have_one = my_loop_group, other_loop_group
1066
+ if len(has_one) != 1:
1067
+ has_one, may_not_have_one = other_loop_group, my_loop_group
1068
+
1069
+ l = copy.deepcopy(has_one.pop(0))
1070
+ l.rank_variable = (
1071
+ l.rank_variable
1072
+ if isinstance(l.rank_variable, set)
1073
+ else set([l.rank_variable])
1074
+ )
1075
+ for l2 in may_not_have_one:
1076
+ if l2.calculated_n_iterations == l.calculated_n_iterations:
1077
+ break
1078
+ else:
1079
+ raise ValueError(
1080
+ f"No matching loop found for {my_loop_group} and {other_loop_group}"
1081
+ )
1082
+ print(
1083
+ f"Warning. Matching loops {l} and {l2}. Need rank variable translation here."
1084
+ )
1085
+
1086
+ may_not_have_one.remove(l2)
1087
+ rv = l2.rank_variable
1088
+ rv = rv if isinstance(rv, set) else set([rv])
1089
+ l.rank_variable = l.rank_variable | rv
1090
+ to_add = [l]
1091
+
1092
+ zipped_groups.append(to_add)
1093
+
1094
+ assert not my_loop_group and not other_loop_group, "BUG"
1095
+
1096
+ zipped_groups.extend(my_groups)
1097
+ zipped_groups.extend(other_groups)
1098
+
1099
+ flattened = list(x for group in zipped_groups for x in group)
1100
+ new_nodes = [x for x in flattened if not isinstance(x, Sequential)]
1101
+ new_nodes.extend([x for x in flattened if isinstance(x, Sequential)])
1102
+
1103
+ if isinstance(my_remaining[0], Sequential) and isinstance(
1104
+ other_remaining[0], Sequential
1105
+ ):
1106
+ my_remaining[0].nodes.extend(other_remaining[0].nodes)
1107
+ assert len(my_remaining) == 1 and len(other_remaining) == 1, "BUG"
1108
+ new_nodes.append(my_remaining[0])
1109
+ elif isinstance(my_remaining[0], Sequential):
1110
+ my_remaining[0].nodes.append(Nested(nodes=other_remaining))
1111
+ assert len(my_remaining) == 1, "BUG"
1112
+ new_nodes.append(my_remaining[0])
1113
+ elif isinstance(other_remaining[0], Sequential):
1114
+ other_remaining[0].nodes.append(Nested(nodes=my_remaining))
1115
+ assert len(other_remaining) == 1, "BUG"
1116
+ new_nodes.append(other_remaining[0])
1117
+ else:
1118
+ new_nodes.append(
1119
+ Sequential(
1120
+ nodes=[Nested(nodes=my_remaining), Nested(nodes=other_remaining)]
1121
+ )
1122
+ )
1123
+
1124
+ return Nested(nodes=new_nodes)
1125
+
1126
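+    # Illustrative sketch (editor's note, not part of the released module):
+    # merging two single-Einsum trees that share their outer loops yields one
+    # tree with the shared loops on top and a Sequential over the two bodies:
+    #
+    #     merged = a._merge(b, n_shared_loops=a._get_n_shared_loops(b))
+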
+    def _beautify_loops(
+        self, rank_variable_bounds: Optional[dict[str, dict[str, int]]] = None
+    ):
+        to_remove = []
+        rank_variable_bounds = rank_variable_bounds or {}
+
+        for i, node in enumerate(self.nodes):
+            if not isinstance(node, Loop):
+                continue
+            prev_tile_shape = None
+            for j in range(i - 1, -1, -1):
+                node2 = self.nodes[j]
+                if not isinstance(node2, Loop):
+                    continue
+                if node2.tile_shape is None:
+                    continue
+                if node2.rank_variable != node.rank_variable:
+                    continue
+                prev_tile_shape = node2.tile_shape
+                break
+            if prev_tile_shape is None:
+                prev_tile_shape = rank_variable_bounds.get(node.rank_variable, None)
+            if prev_tile_shape is not None:
+                if node.tile_shape == prev_tile_shape:
+                    to_remove.append(i)
+                    continue
+                elif node.tile_shape is not None:
+                    node.tile_pattern = node.tile_pattern.update(
+                        calculated_n_iterations=prev_tile_shape / node.tile_shape,
+                    )
+
+        def safe_int_cast(x: int | float | None) -> int | float | None:
+            try:
+                int_x = int(x)
+                return int_x if int_x == x else x
+            except (TypeError, ValueError):
+                pass
+            return x
+
+        for i, node in enumerate(self.nodes):
+            if not isinstance(node, Loop):
+                continue
+            node.tile_pattern = node.tile_pattern.update(
+                initial_tile_shape=safe_int_cast(node.tile_pattern.initial_tile_shape),
+                tile_shape=safe_int_cast(node.tile_pattern.tile_shape),
+            )
+
+        self.nodes = [node for i, node in enumerate(self.nodes) if i not in to_remove]
+
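+    # Illustrative sketch (editor's note, not part of the released module):
+    # a loop whose tile shape equals the next-outer tile shape over the same
+    # rank variable iterates exactly once, so _beautify_loops drops it; e.g.
+    # "for m tile_shape=4 / for m tile_shape=4 / for m tile_shape=2" becomes
+    # "for m tile_shape=4 / for m in [0..2) tile_shape=2".
+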
+    @override
+    def compact_str(self) -> str:
+        result = []
+        prev = None
+        for node in self.nodes:
+            try:
+                prev = prev._merge(node)
+            except (AttributeError, ValueError):
+                if prev is not None:
+                    result.append(prev)
+                prev = node
+        if prev is not None:
+            result.append(prev)
+
+        return " ".join(node.compact_str() for node in result)
+
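+    # Illustrative sketch (editor's note, not part of the released module):
+    # consecutive mergeable nodes collapse in the compact form, so two Temporal
+    # loops with identical tile patterns over ranks "m" and "n" print once as
+    # "T-m,n tile_shape=2" rather than twice.
+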
+     def _get_single_tensor_mapping(
+         self,
+         tensor_name: TensorName,
+         flattened_arch: list[arch.Leaf],
+         indexing_expressions: set[str],
+     ) -> Self:
+         """
+         Ctrl-F for CONTIGUOUS_ITERATION_SPACE_DISCUSSION
+
+         Returns this Nested node with only the nodes associated with the given tensor.
+
+         Includes loops and compute nodes, plus any tensor holders and reservations that
+         are associated with the given tensor.
+
+         Puts spatials as high as they can go while being below any node that is above
+         them in the memory hierarchy. Between two tensor holders, generally puts spatial
+         loops at the bottom, but may put them above temporal loops if that better lines
+         up with the original order. When memory hierarchy order is followed globally
+         (e.g., output in buffer must be above input in reg), the loop order going into
+         this function will always match that going out.
+
+         This function expects, as input, all spatials to be placed as low as they can
+         go, but above their respective fanouts.
+
+         When memory hierarchy order is only followed per-tensor (e.g., output in buffer
+         must be above output in reg, but can be below input in reg), things may be more
+         complicated. We discuss this in more detail using the following example:
+
+         Hierarchy:
+
+         - Buffer
+         - 2x fanout
+         - Reg
+
+         Mapping:
+             S-reg for m1 in [0, 2):
+                 [Reg reuses input]
+                 for m0 in [0, 2):
+                     [Buffer reuses output]
+
+         When given a mapping and architecture like the above, this function may reorder
+         the spatial and temporal loops, yielding the following:
+
+         Mapping for input:
+             S-reg for m1 in [0, 2):
+                 [Reg reuses input]
+                 for m0 in [0, 2):
+
+         Mapping for output:
+             for m0 in [0, 2):
+                 [Buffer reuses output]
+                 S-reg for m1 in [0, 2):
+
+         Unfortunately, such reordering is inevitable given our assumptions of an
+         inclusive memory hierarchy (because any tile stored in the reg must be stored in
+         the buffer), and our desire to place the reg storage node higher. It's also a
+         symptom of the following other issues:
+
+         - In cases like these, storage nodes may need to keep non-contiguous chunks of
+           the iteration space. For example, if the spatial loop is on top, then one reg
+           holds [0, 1] while the other holds [2, 3]. Meanwhile, in the first temporal
+           iteration, the buffer holds [0, 2] and in the second temporal iteration, the
+           buffer holds [2, 4].
+         - We get weird dependencies between loop order and compatibility for fusion
+           because loop order affects the iteration space tiles that are stored.
+
+         To prevent these problems from occurring, we raise an error if there are any
+         temporal loops in between that affect the same indexing expressions as the
+         spatial loops. I tried to make it work with our model by constraining the
+         temporal loops to be null (have the same tile shape as their outer loop), but
+         when we run it per-tensor and reorder, the loop above the temporal changes, so
+         the model returns inconsistent results for each tensor as the tile shape is
+         different. With this constraint, we'll never reorder spatial and temporal loops
+         that affect one another.
+
+         I haven't thought through how this will work with more complex rank variable
+         expressions, so to be safe, we require that there cannot be a temporal and a
+         spatial loop that affect the same indexing expression or each other's loop
+         bounds.
+
+         These problems also aren't necessarily impossible to solve; I just haven't
+         thought them through. If we do think them through, a good place to start would
+         be to update the model to support non-contiguous chunks of the iteration space,
+         then come up with some way to explore mappings and fusion while using
+         non-contiguous chunks of the iteration space.
+
+         TODO: The mapper then also needs to explore swapping temporal/spatial loops.
+         """
+         spatials = [n for n in self.nodes if isinstance(n, Spatial)]
+         tensor_holders = [
+             n for n in self.nodes if isinstance(n, (TensorHolder, Compute))
+         ]
+         others = [
+             n
+             for n in self.nodes
+             if not isinstance(n, (TensorHolder, Reservation, Spatial))
+             or (isinstance(n, TensorHolder) and n.tensor == tensor_name)
+             or (isinstance(n, Reservation) and n.purpose == tensor_name)
+         ]
+         assert not any(isinstance(n, MappingNodeWithChildren) for n in others), "BUG"
+
+         def arch_idx(node: MappingNode) -> int:
+             for i, n in enumerate(flattened_arch):
+                 if n.name == node.component:
+                     return i
+             raise ValueError(f"Component {node.component} not found in flattened arch")
+
+         spatials_above = {
+             id(node): [s for s in spatials if arch_idx(s) <= arch_idx(node)]
+             for node in tensor_holders
+         }
+         spatials_below = {
+             id(node): [s for s in spatials if arch_idx(s) >= arch_idx(node)]
+             for node in tensor_holders
+         }
+
+         mapping = []
+         for to_add in others:
+             if isinstance(to_add, (TensorHolder, Compute)):
+                 cur_spatials_above = [
+                     s for s in spatials if s in spatials_above[id(to_add)]
+                 ]
+                 spatials = [s for s in spatials if s not in cur_spatials_above]
+                 mapping.extend(cur_spatials_above)
+             mapping.append(to_add)
+
+         mapping.extend(spatials)
+
+         # Check that spatials are always above their respective fanouts
+         for i, node in enumerate(mapping):
+             if not isinstance(node, (TensorHolder, Compute)):
+                 continue
+             for node2 in mapping[i + 1 :]:
+                 if not isinstance(node2, Spatial):
+                     continue
+                 assert node2 in spatials_below[id(node)], "BUG"
+                 assert node2 not in spatials_above[id(node)], "BUG"
+
+         # Split the mapping into groups of tensor holders and sequential loops
+         id2idx = {id(node): i for i, node in enumerate(self.nodes)}
+         groups = []
+         for node in mapping:
+             if (
+                 isinstance(node, Loop)
+                 and len(groups) > 0
+                 and isinstance(groups[-1][0], Loop)
+             ):
+                 groups[-1].append(node)
+             else:
+                 groups.append([node])
+
+         groups = [sorted(g, key=lambda x: id2idx[id(x)]) for g in groups]
+         mapping = [x for g in groups for x in g]
+
+         # Check that all storage-temporal relations are held from before
+         node2idx = {id(node): i for i, node in enumerate(mapping)}
+         prev_node2idx = {id(node): i for i, node in enumerate(self.nodes)}
+         for node, node2 in itertools.combinations(mapping, 2):
+             idx1 = node2idx[id(node)]
+             idx2 = node2idx[id(node2)]
+             prev_idx1 = prev_node2idx[id(node)]
+             prev_idx2 = prev_node2idx[id(node2)]
+             if isinstance(node, TensorHolder) and isinstance(node2, TensorHolder):
+                 assert (idx1 > idx2) == (prev_idx1 > prev_idx2), "BUG"
+             # Because of the reordering above, we may lower loops beneath tensor
+             # holders and temporal loops in order to place them as low as possible
+             # above the fanout.
+             # elif isinstance(node, TensorHolder) and isinstance(node2, Spatial):
+             #     assert (idx1 > idx2) == (prev_idx1 > prev_idx2), "BUG"
+             # elif isinstance(node, Spatial) and isinstance(node2, TensorHolder):
+             #     assert (idx1 > idx2) == (prev_idx1 > prev_idx2), "BUG"
+             elif isinstance(node, Spatial) and isinstance(node2, Spatial):
+                 assert (idx1 > idx2) == (prev_idx1 > prev_idx2), "BUG"
+
+         # for m in mapping:
+         #     print(m.compact_str())
+         # for n in self.nodes:
+         #     print(n.compact_str())
+
+         # Check for spatial/temporal loops that have been reordered. These cannot
+         # co-exist because the tiling is inconsistent.
+         # Ctrl-F for CONTIGUOUS_ITERATION_SPACE_DISCUSSION
+         import textwrap
+
+         from accelforge.frontend.workload import isl_expression_has_variable
+
+         node2idx = {id(node): i for i, node in enumerate(self.nodes)}
+         for node1, node2 in itertools.combinations(mapping, 2):
+             # Both must be loops
+             if not isinstance(node1, Loop) or not isinstance(node2, Loop):
+                 continue
+             # Must have been reordered
+             if node2idx[id(node1)] <= node2idx[id(node2)]:
+                 continue
+             # Must affect the same rank variable expression
+             for expr in indexing_expressions:
+                 if not isl_expression_has_variable(expr, node1.rank_variable):
+                     continue
+                 if not isl_expression_has_variable(expr, node2.rank_variable):
+                     continue
+
+                 # Strip the indentation that the triple-quoted literal carries.
+                 s = textwrap.dedent(
+                     """
+                     In the given mapping, there exists (potentially with other nodes in
+                     between) a spatial loop above a temporal loop above a storage node,
+                     where the loops index into the same indexing expression, and the
+                     storage node is not fanned out by the spatial loop. This is not
+                     allowed.
+
+                     Mapping:
+                     """
+                 )
+
+                 to_add = []
+                 for n in self.nodes:
+                     if id(n) == id(node1) or id(n) == id(node2):
+                         to_add.append(f"\t{n.compact_str()} <-- Offending Loop")
+                     else:
+                         to_add.append(f"\t{n.compact_str()}")
+
+                 raise ValueError(s + "\n".join(to_add))
+
+         return type(self)(nodes=mapping)
+
+
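The group-and-resort step above (restoring each run of consecutive loops to its original relative order via `id2idx`) can be isolated into a small sketch. The string items and the `is_loop` predicate below are stand-ins for mapping nodes and the `isinstance(node, Loop)` check:

# Sketch: split a sequence into runs of consecutive "loops", then restore each
# run's original relative order -- the same id2idx-based stable sort as above.
def restore_run_order(reordered: list[str], original: list[str]) -> list[str]:
    orig_idx = {x: i for i, x in enumerate(original)}

    def is_loop(x: str) -> bool:
        return x.startswith("loop")

    groups: list[list[str]] = []
    for x in reordered:
        if is_loop(x) and groups and is_loop(groups[-1][0]):
            groups[-1].append(x)  # extend the current run of loops
        else:
            groups.append([x])  # anything else starts a new group

    groups = [sorted(g, key=lambda x: orig_idx[x]) for g in groups]
    return [x for g in groups for x in g]

original = ["loop_a", "loop_b", "hold_X", "loop_c"]
print(restore_run_order(["loop_b", "loop_a", "hold_X", "loop_c"], original))
# -> ['loop_a', 'loop_b', 'hold_X', 'loop_c']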
+ class Parallel(Split):
+     """
+     A :class:`~.Split` where each branch operates at the same time in different
+     spatially-organized hardware.
+     """
+
+     pass
+
+
+ class Pipeline(Split):
+     """
+     A :class:`~.Split` where branches operate at the same time in different
+     spatially-organized hardware, with data flowing from each branch to the next
+     in pipelined fashion.
+     """
+
+     pass
+
+
+ class Sequential(Split):
+     """
+     A :class:`~.Split` where branches are processed one-after-another.
+     """
+
+     pass
+
+
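Since the three Split flavors above differ only in when their branches execute, a toy timing model may help make the distinction concrete. Everything here is illustrative and assumed (a single latency number per branch, the helper names); it is not accelforge API:

# Toy latency model for the three Split flavors.
def sequential_latency(branch_latencies: list[float]) -> float:
    return sum(branch_latencies)  # branches run one-after-another

def parallel_latency(branch_latencies: list[float]) -> float:
    return max(branch_latencies)  # branches run at the same time on separate hardware

def pipeline_latency(branch_latencies: list[float], n_batches: int) -> float:
    # Branches run concurrently on separate hardware, but each batch must pass
    # through every stage: fill time plus steady state set by the slowest stage.
    return sum(branch_latencies) + (n_batches - 1) * max(branch_latencies)

lat = [2.0, 3.0, 1.0]
print(sequential_latency(lat), parallel_latency(lat), pipeline_latency(lat, n_batches=4))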
+ # =============================================================================
+ # Nodes That May Only be Inserted by the Model
+ # =============================================================================
+
+
+ class Reservation(MappingNode):
+     """A node that reserves a hardware resource for a specific task."""
+
+     purposes: ParsableList[str]
+     """ The reasons for reserving the resource. """
+
+     resource: str
+     """ The resource being reserved. """
+
+     _backing: Set[str] = set()
+     """ Tensors for which this reservation is reserving the tensor's backing storage.
+     """
+
+     persistent: bool = False
+     """
+     Whether this reservation is persistent. Persistent reservations can't be tiled and
+     must be kept in backing storage for the full duration of the workload's execution.
+     """
+
+     @override
+     def compact_str(self) -> str:
+         return f'{",".join(self.purposes)} reserves {self.resource}'
+
+     def __str__(self, color_map: ColorMap = None) -> str:
+         purposes = self.purposes
+         if color_map is not None:
+             format_list = [f"{self.resource} reserved for"] + list(purposes)
+             return color_map.format_list(format_list)
+         return f'{self.resource} reserved for {",".join(purposes)}'
+
+     def _render_node_shape(self) -> str:
+         return "component"
+
+     @property
+     def purpose(self) -> str:
+         if len(self.purposes) == 1:
+             return self.purposes[0]
+         raise ValueError(f"Reservation has multiple purposes: {self.purposes}")
+
+     def __eq__(self, other: "Reservation") -> bool:
+         return (
+             isinstance(other, Reservation)
+             and self.purposes == other.purposes
+             and self.resource == other.resource
+         )
+
+     def _render_node_color(self) -> str:
+         return "#E8E8E8"  # Light gray
+
+
+ # =============================================================================
+ # Top-level Mapping
+ # =============================================================================
+
+ MappingNodeTypes: TypeAlias = Union[
+     Temporal,
+     Spatial,
+     Storage,
+     Pipeline,
+     Sequential,
+     Compute,
+     Reservation,
+     # Fill,
+     TensorHolder,
+ ]
+ """TypeAlias MappingNodeTypes: The types of MappingNodes possible."""
+
+
+ class Mapping(Nested):
+     """A Mapping of a workload onto a hardware architecture."""
+
+     # version: Annotated[str, assert_version] = __version__
+
+     _n_loop_orders: int | None = None
+     """ Used for counting the number of unique mappings. Do not touch. """
+
+     def remove_reservations(self):
+         self.nodes = [n for n in self.nodes if not isinstance(n, Reservation)]
+
+     def split_loop_with_multiple_rank_variables(self):
+         new_nodes = []
+         for node in self.nodes:
+             if isinstance(node, Loop) and isinstance(node.rank_variable, set):
+                 for rank_variable in node.rank_variable:
+                     new_node = copy.copy(node)
+                     new_node.rank_variable = rank_variable
+                     new_nodes.append(new_node)
+             else:
+                 new_nodes.append(node)
+         self.nodes = new_nodes
+
+     def split_tensor_holders_with_multiple_tensors(self):
+         new_nodes = []
+         for node in self.nodes:
+             if isinstance(node, TensorHolder) and len(node.tensors) > 1:
+                 for tensor in node.tensors:
+                     new_node = copy.copy(node)
+                     new_node.tensors = [tensor]
+                     new_nodes.append(new_node)
+             else:
+                 new_nodes.append(node)
+         self.nodes = new_nodes
+
+     def _get_fused_slice(self, fusable_tensors: set[TensorName]) -> "Mapping":
+         """
+         Return a mapping with:
+
+         - All backing reservation nodes for intermediate tensors
+         - Loop nodes above any backing reservation nodes
+         """
+         # All intermediate tensors that can be found in this mapping.
+         # Note: `fusable_tensors` may cover the **whole workload**, not just this
+         # mapping.
+         relevant_intermediate_tensors = set()
+         for node in self.nodes:
+             if isinstance(node, Reservation):
+                 if node.purpose in fusable_tensors:
+                     relevant_intermediate_tensors.add(node.purpose)
+
+         fused_slice = Mapping(nodes=[])
+         to_add = []
+         for node in self.nodes:
+             node = copy.copy(node)
+             if isinstance(node, Reservation):
+                 if node.purpose not in relevant_intermediate_tensors:
+                     continue
+                 fused_slice.nodes.extend(to_add + [node])
+                 to_add = []
+                 relevant_intermediate_tensors.remove(node.purpose)
+                 if len(relevant_intermediate_tensors) == 0:
+                     break
+             elif isinstance(node, Loop):
+                 to_add.append(node)
+         return fused_slice
+
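To make the slice rule in `_get_fused_slice` concrete, here is a toy run using (kind, name) tuples in place of mapping nodes; the `fused_slice` helper below is a hypothetical simplification, not the package function. It keeps each relevant reservation plus all loop nodes that appeared above it, and stops once every relevant tensor has been seen:

# Toy version of the fused-slice rule.
def fused_slice(nodes: list[tuple[str, str]], fusable: set[str]) -> list[tuple[str, str]]:
    relevant = {name for kind, name in nodes if kind == "reservation" and name in fusable}
    out, pending_loops = [], []
    for kind, name in nodes:
        if kind == "reservation":
            if name not in relevant:
                continue
            out.extend(pending_loops + [(kind, name)])
            pending_loops = []
            relevant.remove(name)
            if not relevant:
                break
        elif kind == "loop":
            pending_loops.append((kind, name))
    return out

nodes = [("loop", "m1"), ("reservation", "T0"), ("loop", "m0"), ("compute", "E0")]
print(fused_slice(nodes, fusable={"T0"}))  # [('loop', 'm1'), ('reservation', 'T0')]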
+     @property
+     def loops(self) -> list[Loop]:
+         """Returns all :class:`~.Loop` nodes in the Mapping."""
+         return self.get_nodes_of_type(Loop)
+
+     def _render_node_label(self, **kwargs) -> str:
+         return "Root"
+
+     def _repr_svg_(self) -> str:
+         return self.render()
+
+     def render_pydot(self, with_reservations=True, with_tile_shape=True) -> pydot.Dot:
+         """Renders the mapping as a :class:`pydot.Dot` graph."""
+         graph = _pydot_graph()
+         # Enable HTML-like labels for color support
+         graph.set_node_defaults(label="")
+         if not with_reservations:
+             exclude_types = (Reservation,)
+         else:
+             exclude_types = tuple()
+         for node in self._render_make_children(
+             exclude_types=exclude_types, with_tile_shape=with_tile_shape
+         ):
+             graph.add_node(node)
+
+         color_keys = set()
+         all_nodes = self._flatten()
+         for node in all_nodes:
+             if isinstance(node, TensorHolder):
+                 color_keys.update(node.tensors)
+             if isinstance(node, Reservation):
+                 color_keys.update(node.purposes)
+
+         color_map = ColorMap(sorted(color_keys))
+
+         for node in all_nodes:
+             if isinstance(node, (TensorHolder, Reservation)):
+                 graph_nodes = graph.get_node(node._render_node_name())
+                 for graph_node in graph_nodes:
+                     # Set HTML-like label for color support
+                     new_label = node.__str__(color_map)
+                     graph_node.set_label(new_label)
+                     # graph_node.set_fillcolor(color_map[node._color_key()])
+                     # graph_node.set_style('filled')
+
+         added_edges = set()
+         child2included_parent = {}
+         for parent, child in self._parent2child(None):
+             parent_name = parent._render_node_name() if parent is not None else None
+             child_name = child._render_node_name()
+             if isinstance(parent, exclude_types):
+                 parent_name = child2included_parent.get(parent_name, None)
+             child2included_parent[child_name] = parent_name
+             if not isinstance(child, exclude_types):
+                 added_edges.add((parent_name, child_name))
+         for parent_name, child_name in added_edges:
+             if parent_name is not None:
+                 graph.add_edge(pydot.Edge(parent_name, child_name))
+         return graph
+
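The edge-building loop above re-parents children of excluded nodes to the nearest included ancestor. A graph-free sketch of that contraction, assuming (as `_parent2child` appears to guarantee) that parent-child pairs arrive in top-down order; `contract_edges` and the string node names are illustrative:

# Sketch: contract excluded nodes out of a parent->child edge list by routing
# each child to its nearest non-excluded ancestor.
def contract_edges(parent_child_pairs, excluded: set[str]):
    edges = set()
    nearest_included_parent: dict[str, str | None] = {}
    for parent, child in parent_child_pairs:
        if parent in excluded:
            parent = nearest_included_parent.get(parent)
        nearest_included_parent[child] = parent
        if child not in excluded and parent is not None:
            edges.add((parent, child))
    return sorted(edges)

pairs = [("root", "resv"), ("resv", "loop"), ("loop", "compute")]
print(contract_edges(pairs, excluded={"resv"}))  # [('loop', 'compute'), ('root', 'loop')]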
+     def render(self) -> _SVGJupyterRender:
+         """Renders the mapping and returns an SVG string for display in Jupyter."""
+         graph = self.render_pydot()
+         return _SVGJupyterRender(graph.create_svg(prog="dot").decode("utf-8"))
+
+     @classmethod
+     def _from_pmappings(
+         cls,
+         pmappings: list[Nested],
+         rank_variable_bounds: Optional[dict[str, dict[str, int]]] = None,
+     ) -> "Mapping":
+         pmappings = list(copy.deepcopy(pmappings))
+         for pmapping in pmappings:
+             pmapping._beautify_loops(rank_variable_bounds)
+
+         while len(pmappings) > 1:
+             highest_n_shared_loops = 0
+             highest_shared_pmapping_index = 0
+             for i, pmapping in enumerate(pmappings):
+                 shared_index = 0
+                 for j in range(i + 1, len(pmappings)):
+                     shared_index = max(
+                         pmapping._get_n_shared_loops(pmappings[j]), shared_index
+                     )
+                 if shared_index > highest_n_shared_loops:
+                     highest_n_shared_loops = shared_index
+                     highest_shared_pmapping_index = i
+
+             def einsum_names(pmapping: Nested) -> str:
+                 return ",".join(n.einsum for n in pmapping.get_nodes_of_type(Compute))
+
+             names_a = einsum_names(pmappings[highest_shared_pmapping_index])
+             names_b = einsum_names(pmappings[highest_shared_pmapping_index + 1])
+             # print(
+             #     f"Merging with shared loops {highest_n_shared_loops}: {names_a} <--> {names_b}."
+             # )
+             # print(pmappings[highest_shared_pmapping_index]._get_n_shared_loops(pmappings[highest_shared_pmapping_index + 1]))
+             pmappings[highest_shared_pmapping_index] = pmappings[
+                 highest_shared_pmapping_index
+             ]._merge(
+                 pmappings.pop(highest_shared_pmapping_index + 1),
+                 0 if _NO_JOIN_MAPPING_VISUALIZATION else highest_n_shared_loops,
+             )
+
+         mapping: Mapping = cls(nodes=pmappings[0].nodes)
+         mapping._elevate_persistent_nodes_above_splits()
+         mapping._elevate_tensor_holders_above_splits()
+         mapping._propagate_reservations_between_splits()
+         mapping._consolidate_tensor_holders()
+         mapping._consolidate_reservations()
+         mapping._move_tensor_holders_above_reservations()
+         mapping._remove_reservations_for_processing_stages()
+         return mapping
+
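The while-loop in `_from_pmappings` greedily merges the pmapping list down to a single mapping, scoring candidates by shared-loop count and merging the best-scoring item with its right-hand neighbor. The same greedy scheme, sketched with shared-prefix length as the score and string concatenation as the merge (all names here are illustrative):

# Greedy pairwise merging, as in _from_pmappings.
def shared_prefix_len(a: str, b: str) -> int:
    n = 0
    while n < min(len(a), len(b)) and a[n] == b[n]:
        n += 1
    return n

def greedy_merge(items: list[str]) -> str:
    items = list(items)
    while len(items) > 1:
        best_score, best_i = 0, 0
        for i, item in enumerate(items):
            score = max(
                (shared_prefix_len(item, items[j]) for j in range(i + 1, len(items))),
                default=0,
            )
            if score > best_score:
                best_score, best_i = score, i
        # Merge with the right-hand neighbor, mirroring the pop(index + 1) above.
        items[best_i] = items[best_i] + "|" + items.pop(best_i + 1)
    return items[0]

print(greedy_merge(["abc", "abd", "xyz"]))  # 'abc|abd|xyz'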
+     # import mermaid as md
+     # from mermaid.graph import Graph
+     # lines = []
+     # lines = [
+     #     "graph TD",
+     #     "%%{init: {'flowchart': {'nodeSpacing': 30, 'rankSpacing': 30, 'padding': 2}, 'themeVariables': {'fontFamily': 'Arial, sans-serif'}}}%%"
+     # ]
+     # lines.extend(self._render_make_children())
+     # for parent, child in self._parent2child(None):
+     #     if parent is not None:
+     #         lines.append(f"{parent._render_node_name()} --> {child._render_node_name()}")
+     # # if _is_root:
+     # #     lines.extend([
+     # #         "",
+     # #         "classDef default fill:#fff,stroke:#000,stroke-width:1px,color:#000,font-family:Arial,font-size:12px,padding:2px;",
+     # #         "classDef compact fill:#fff,stroke:#000,stroke-width:1px,color:#000,font-family:Arial,font-size:12px,padding:2px;"
+     # #     ])
+
+     # # Create the graph with the flowchart script
+     # flowchart_script = "\n".join(lines)
+     # graph = Graph('Flowchart', flowchart_script)
+
+     # # Set the configuration for compact layout
+     # config = md.Config()
+     # config.theme = 'base'
+     # # config.theme_variables = {
+     # #     'primaryColor': '#ffffff',
+     # #     'primaryTextColor': '#000000',
+     # #     'primaryBorderColor': '#000000',
+     # #     'lineColor': '#000000',
+     # #     'fontSize': '12px'
+     # # }
+     # # config.flowchart = {
+     # #     'nodeSpacing': 20,
+     # #     'rankSpacing': 10,
+     # #     'curve': 'linear'
+     # # }
+     # graph.config = config
+
+     # return md.Mermaid(graph)
+
+
+ Split.model_rebuild()
+ Nested.model_rebuild()
+ Mapping.model_rebuild()