accelforge 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (258)
  1. accelforge/__init__.py +21 -0
  2. accelforge/_accelerated_imports.py +16 -0
  3. accelforge/_deprecate/_simanneal/evalmapping.py +271 -0
  4. accelforge/_deprecate/_simanneal/mapspaceglobals.py +298 -0
  5. accelforge/_deprecate/_simanneal/simanneal.py +666 -0
  6. accelforge/_deprecate/_simanneal/tracking.py +105 -0
  7. accelforge/_deprecate/_simanneal/wrappers.py +218 -0
  8. accelforge/_deprecate/_simanneal2/__init__.py +7 -0
  9. accelforge/_deprecate/_simanneal2/simanneal.py +493 -0
  10. accelforge/_deprecate/_simanneal2/tracking.py +116 -0
  11. accelforge/_deprecate/compatibility_util.py +181 -0
  12. accelforge/_deprecate/layerdeduplication/__init__.py +2 -0
  13. accelforge/_deprecate/layerdeduplication/group_similar_einsums.py +160 -0
  14. accelforge/_deprecate/layerdeduplication/grouped_einsums.py +84 -0
  15. accelforge/_deprecate/mapping_filter_tags/__init__.py +2 -0
  16. accelforge/_deprecate/mapping_filter_tags/ffmt.py +212 -0
  17. accelforge/_deprecate/mapping_filter_tags/onesplit.py +24 -0
  18. accelforge/_deprecate/mapping_filter_tags/util.py +24 -0
  19. accelforge/_deprecate/tags.py +69 -0
  20. accelforge/_deprecate/viz/__init__.py +0 -0
  21. accelforge/_deprecate/viz/interactive.py +159 -0
  22. accelforge/_deprecate/viz/reservationtree.py +307 -0
  23. accelforge/_deprecate/viz/ski_slope.py +88 -0
  24. accelforge/_version.py +15 -0
  25. accelforge/examples.py +39 -0
  26. accelforge/frontend/__init__.py +10 -0
  27. accelforge/frontend/_binding.py +129 -0
  28. accelforge/frontend/_workload_isl/__init__.py +2 -0
  29. accelforge/frontend/_workload_isl/_isl.py +149 -0
  30. accelforge/frontend/_workload_isl/_symbolic.py +141 -0
  31. accelforge/frontend/arch copy.py +1544 -0
  32. accelforge/frontend/arch.py +1642 -0
  33. accelforge/frontend/config.py +63 -0
  34. accelforge/frontend/mapper/__init__.py +5 -0
  35. accelforge/frontend/mapper/ffm.py +126 -0
  36. accelforge/frontend/mapper/mapper.py +7 -0
  37. accelforge/frontend/mapper/metrics.py +30 -0
  38. accelforge/frontend/mapping/__init__.py +1 -0
  39. accelforge/frontend/mapping/mapping.py +1736 -0
  40. accelforge/frontend/model.py +14 -0
  41. accelforge/frontend/renames.py +150 -0
  42. accelforge/frontend/spec copy.py +230 -0
  43. accelforge/frontend/spec.py +301 -0
  44. accelforge/frontend/variables.py +12 -0
  45. accelforge/frontend/workload.py +952 -0
  46. accelforge/mapper/FFM/__init__.py +9 -0
  47. accelforge/mapper/FFM/_join_pmappings/__init__.py +0 -0
  48. accelforge/mapper/FFM/_join_pmappings/compatibility.py +653 -0
  49. accelforge/mapper/FFM/_join_pmappings/compress_pmappings.py +140 -0
  50. accelforge/mapper/FFM/_join_pmappings/join_pmappings.py +703 -0
  51. accelforge/mapper/FFM/_join_pmappings/pmapping_dataframe.py +901 -0
  52. accelforge/mapper/FFM/_join_pmappings/pmapping_group.py +337 -0
  53. accelforge/mapper/FFM/_make_pmappings/contraints/__init__.py +0 -0
  54. accelforge/mapper/FFM/_make_pmappings/contraints/constraints.py +360 -0
  55. accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/__init__.py +1 -0
  56. accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_loops.py +373 -0
  57. accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_pmapping_templates.py +463 -0
  58. accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_reservations.py +95 -0
  59. accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_storage_order.py +382 -0
  60. accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_storages.py +155 -0
  61. accelforge/mapper/FFM/_make_pmappings/make_pmappings.py +411 -0
  62. accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/__init__.py +1 -0
  63. accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/make_pmappings_from_templates.py +407 -0
  64. accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/make_tile_shapes.py +1681 -0
  65. accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/run_model.py +170 -0
  66. accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/symbol_relations.py +174 -0
  67. accelforge/mapper/FFM/_make_pmappings/pmapper_job.py +282 -0
  68. accelforge/mapper/FFM/_pareto_df/df_convention.py +273 -0
  69. accelforge/mapper/FFM/_pareto_df/pareto copy.py +836 -0
  70. accelforge/mapper/FFM/_pareto_df/pareto.py +508 -0
  71. accelforge/mapper/FFM/data.py +61 -0
  72. accelforge/mapper/FFM/main copy.py +236 -0
  73. accelforge/mapper/FFM/main.py +208 -0
  74. accelforge/mapper/FFM/mappings.py +510 -0
  75. accelforge/mapper/FFM/pmappings.py +310 -0
  76. accelforge/mapper/__init__.py +4 -0
  77. accelforge/mapper.py +0 -0
  78. accelforge/model/__init__.py +1 -0
  79. accelforge/model/_looptree/__init__.py +0 -0
  80. accelforge/model/_looptree/accesses.py +335 -0
  81. accelforge/model/_looptree/capacity/__init__.py +1 -0
  82. accelforge/model/_looptree/capacity/aggregators.py +36 -0
  83. accelforge/model/_looptree/capacity/capacity.py +47 -0
  84. accelforge/model/_looptree/energy.py +150 -0
  85. accelforge/model/_looptree/equivalent_ranks.py +29 -0
  86. accelforge/model/_looptree/latency/__init__.py +1 -0
  87. accelforge/model/_looptree/latency/latency.py +98 -0
  88. accelforge/model/_looptree/latency/memory.py +120 -0
  89. accelforge/model/_looptree/latency/processors.py +92 -0
  90. accelforge/model/_looptree/mapping_utilities.py +71 -0
  91. accelforge/model/_looptree/reuse/__init__.py +4 -0
  92. accelforge/model/_looptree/reuse/isl/__init__.py +1 -0
  93. accelforge/model/_looptree/reuse/isl/des.py +59 -0
  94. accelforge/model/_looptree/reuse/isl/isl_functions.py +374 -0
  95. accelforge/model/_looptree/reuse/isl/mapping_to_isl/__init__.py +4 -0
  96. accelforge/model/_looptree/reuse/isl/mapping_to_isl/analyze_mapping.py +297 -0
  97. accelforge/model/_looptree/reuse/isl/mapping_to_isl/skews_from_mapping.py +236 -0
  98. accelforge/model/_looptree/reuse/isl/mapping_to_isl/tiling.py +685 -0
  99. accelforge/model/_looptree/reuse/isl/mapping_to_isl/types.py +188 -0
  100. accelforge/model/_looptree/reuse/isl/spatial.py +260 -0
  101. accelforge/model/_looptree/reuse/isl/temporal.py +182 -0
  102. accelforge/model/_looptree/reuse/symbolic/__init__.py +1 -0
  103. accelforge/model/_looptree/reuse/symbolic/symbolic copy 2.py +1346 -0
  104. accelforge/model/_looptree/reuse/symbolic/symbolic copy.py +1408 -0
  105. accelforge/model/_looptree/reuse/symbolic/symbolic.py +1396 -0
  106. accelforge/model/_looptree/run.py +122 -0
  107. accelforge/model/_looptree/types.py +26 -0
  108. accelforge/model/_looptree/visualization/__init__.py +0 -0
  109. accelforge/model/_looptree/visualization/occupancy.py +11 -0
  110. accelforge/model/main.py +222 -0
  111. accelforge/plotting/__init__.py +2 -0
  112. accelforge/plotting/mappings.py +219 -0
  113. accelforge/plotting/specs.py +57 -0
  114. accelforge/util/__init__.py +4 -0
  115. accelforge/util/_base_analysis_types.py +24 -0
  116. accelforge/util/_basetypes.py +1089 -0
  117. accelforge/util/_frozenset.py +36 -0
  118. accelforge/util/_isl.py +29 -0
  119. accelforge/util/_itertools.py +14 -0
  120. accelforge/util/_mathfuncs.py +57 -0
  121. accelforge/util/_parse_expressions.py +339 -0
  122. accelforge/util/_picklecache.py +32 -0
  123. accelforge/util/_setexpressions.py +268 -0
  124. accelforge/util/_sympy/__init__.py +0 -0
  125. accelforge/util/_sympy/broadcast_max.py +18 -0
  126. accelforge/util/_visualization.py +112 -0
  127. accelforge/util/_yaml.py +579 -0
  128. accelforge/util/parallel.py +193 -0
  129. accelforge-0.0.1.dist-info/METADATA +64 -0
  130. accelforge-0.0.1.dist-info/RECORD +258 -0
  131. accelforge-0.0.1.dist-info/WHEEL +5 -0
  132. accelforge-0.0.1.dist-info/licenses/LICENSE +19 -0
  133. accelforge-0.0.1.dist-info/top_level.txt +5 -0
  134. docs/_build/html/_sources/fastfusion.frontend.mapper.rst.txt +37 -0
  135. docs/_build/html/_sources/fastfusion.frontend.rst.txt +70 -0
  136. docs/_build/html/_sources/fastfusion.frontend.workload.rst.txt +21 -0
  137. docs/_build/html/_sources/fastfusion.mapper.FFM.rst.txt +37 -0
  138. docs/_build/html/_sources/fastfusion.mapper.rst.txt +18 -0
  139. docs/_build/html/_sources/fastfusion.rst.txt +20 -0
  140. docs/_build/html/_sources/fastfusion.util.rst.txt +21 -0
  141. docs/_build/html/_sources/index.rst.txt +87 -0
  142. docs/_build/html/_sources/modules.rst.txt +7 -0
  143. docs/_build/html/_sources/notes/citation.rst.txt +45 -0
  144. docs/_build/html/_sources/notes/definitions.rst.txt +43 -0
  145. docs/_build/html/_sources/notes/faqs.rst.txt +39 -0
  146. docs/_build/html/_sources/notes/modeling/accelerator_energy_latency.rst.txt +72 -0
  147. docs/_build/html/_sources/notes/modeling/component_energy_area.rst.txt +96 -0
  148. docs/_build/html/_sources/notes/modeling/mapping.rst.txt +100 -0
  149. docs/_build/html/_sources/notes/modeling.rst.txt +33 -0
  150. docs/_build/html/_sources/notes/parsing/arithmetic_parsing.rst.txt +136 -0
  151. docs/_build/html/_sources/notes/parsing/setexpressions.rst.txt +63 -0
  152. docs/_build/html/_sources/notes/parsing/yaml_parsing.rst.txt +176 -0
  153. docs/_build/html/_sources/notes/quickstart_and_installation.rst.txt +9 -0
  154. docs/_build/html/_sources/notes/spec/architecture.rst.txt +133 -0
  155. docs/_build/html/_sources/notes/spec/mapping.rst.txt +12 -0
  156. docs/_build/html/_sources/notes/spec/workload.rst.txt +83 -0
  157. docs/_build/html/_sources/notes/spec.rst.txt +36 -0
  158. docs/source/_ext/include_attrs.py +213 -0
  159. docs/source/_ext/include_docstring.py +364 -0
  160. docs/source/_ext/include_functions.py +154 -0
  161. docs/source/_ext/include_notebook.py +131 -0
  162. docs/source/_ext/include_yaml.py +119 -0
  163. docs/source/_ext/inherited_attributes.py +222 -0
  164. docs/source/_ext/paths.py +4 -0
  165. docs/source/conf.py +79 -0
  166. examples/arches/compute_in_memory/_include.yaml +74 -0
  167. examples/arches/compute_in_memory/_include_functions.py +229 -0
  168. examples/arches/compute_in_memory/_load_spec.py +57 -0
  169. examples/arches/compute_in_memory/components/c2c_multiplier.py +181 -0
  170. examples/arches/compute_in_memory/components/dac_c2c_r2r.py +605 -0
  171. examples/arches/compute_in_memory/components/misc.py +195 -0
  172. examples/arches/compute_in_memory/components/util/bit_functions.py +51 -0
  173. examples/arches/compute_in_memory/components/zero_comparator.py +92 -0
  174. examples/arches/compute_in_memory/isaac.yaml +233 -0
  175. examples/arches/compute_in_memory/memory_cells/ecram_demo.yaml +63 -0
  176. examples/arches/compute_in_memory/memory_cells/rram_example.yaml +63 -0
  177. examples/arches/compute_in_memory/memory_cells/rram_isaac_isca_2016.yaml +64 -0
  178. examples/arches/compute_in_memory/memory_cells/rram_neurosim_default.yaml +63 -0
  179. examples/arches/compute_in_memory/memory_cells/rram_raella_isca_2023.yaml +70 -0
  180. examples/arches/compute_in_memory/memory_cells/rram_wan_nature_2022.yaml +63 -0
  181. examples/arches/compute_in_memory/memory_cells/sram_colonnade_jssc_2021.yaml +63 -0
  182. examples/arches/compute_in_memory/memory_cells/sram_example.yaml +63 -0
  183. examples/arches/compute_in_memory/memory_cells/sram_jia_jssc_2020.yaml +63 -0
  184. examples/arches/compute_in_memory/memory_cells/sram_sinangil_jssc_2021.yaml +63 -0
  185. examples/arches/compute_in_memory/memory_cells/sram_wang_vlsi_2022.yaml +63 -0
  186. examples/arches/compute_in_memory/wang_vlsi_2022.yaml +289 -0
  187. examples/arches/eyeriss.yaml +68 -0
  188. examples/arches/fanout_variations/at_glb.yaml +31 -0
  189. examples/arches/fanout_variations/at_glb_with_fanout_node.yaml +34 -0
  190. examples/arches/fanout_variations/at_mac.yaml +31 -0
  191. examples/arches/fanout_variations/at_mac_with_constraints.yaml +38 -0
  192. examples/arches/fanout_variations/at_mac_with_fanout_node.yaml +34 -0
  193. examples/arches/nvdla.yaml +47 -0
  194. examples/arches/simple.yaml +28 -0
  195. examples/arches/tpu_v4i.yaml +67 -0
  196. examples/mappings/unfused_matmuls_to_simple.yaml +33 -0
  197. examples/misc/component_annotated.yaml +33 -0
  198. examples/workloads/gpt3_6.7B.yaml +124 -0
  199. examples/workloads/matmuls.yaml +20 -0
  200. examples/workloads/mobilenet_28.yaml +81 -0
  201. examples/workloads/mobilenet_various_separate.yaml +106 -0
  202. examples/workloads/three_matmuls_annotated.yaml +59 -0
  203. notebooks/.ipynb_checkpoints/fastfusion_arch_study_michael-checkpoint.ipynb +359 -0
  204. notebooks/compute_in_memory/_scripts.py +339 -0
  205. notebooks/compute_in_memory/isaac.guide.ipynb +270 -0
  206. notebooks/compute_in_memory/wang_vlsi_2022.ipynb +602 -0
  207. notebooks/paths.py +4 -0
  208. notebooks/tutorials/.ipynb_checkpoints/1_FFM-checkpoint.ipynb +3110 -0
  209. notebooks/tutorials/FFM.ipynb +3498 -0
  210. notebooks/tutorials/_include.py +48 -0
  211. notebooks/tutorials/component_energy_area.ipynb +363 -0
  212. tests/Q_mapping.yaml +38 -0
  213. tests/__init__.py +0 -0
  214. tests/conv.mapping.yaml +27 -0
  215. tests/conv.workload.yaml +13 -0
  216. tests/conv_sym.mapping.yaml +43 -0
  217. tests/copy.mapping.yaml +35 -0
  218. tests/copy.workload.yaml +15 -0
  219. tests/distribuffers/__init__.py +0 -0
  220. tests/distribuffers/multicast/test_cases.yaml +482 -0
  221. tests/distribuffers/spec/binding/valid_bindings.yaml +97 -0
  222. tests/distribuffers/spec/distributed.yaml +100 -0
  223. tests/distribuffers/spec/logical_arch.yaml +32 -0
  224. tests/distribuffers/spec/physical_arch.yaml +69 -0
  225. tests/distribuffers/test_binding.py +48 -0
  226. tests/frontend/__init__.py +0 -0
  227. tests/frontend/test_mapping_viz.py +52 -0
  228. tests/mapper/__init__.py +0 -0
  229. tests/mapper/configs/conv1d/conv1d.mapping.yaml +31 -0
  230. tests/mapper/configs/conv1d/conv1d.workload.yaml +11 -0
  231. tests/mapper/configs/two_conv1d/two_conv1d.expected.yaml +38 -0
  232. tests/mapper/configs/two_conv1d/two_conv1d.mapping.yaml +54 -0
  233. tests/mapper/configs/two_conv1d/two_conv1d.workload.yaml +19 -0
  234. tests/mapper/test_mapping_to_isl.py +90 -0
  235. tests/mapper/test_spatial_reuse_analysis.py +67 -0
  236. tests/mapper/test_temporal_reuse_analysis.py +56 -0
  237. tests/mapper/util.py +58 -0
  238. tests/matmul.mapping.yaml +29 -0
  239. tests/matmul.workload.yaml +12 -0
  240. tests/matmul_spatial.mapping.yaml +44 -0
  241. tests/mha.renames.yaml +65 -0
  242. tests/mha.workload.yaml +67 -0
  243. tests/mha.yaml +59 -0
  244. tests/mha_full.workload.yaml +67 -0
  245. tests/mobilenet.workload.yaml +35 -0
  246. tests/mobilenet_long.workload.yaml +64 -0
  247. tests/pmappingcache.py +24 -0
  248. tests/processing_stage.arch.yaml +40 -0
  249. tests/snowcat.arch.yaml +36 -0
  250. tests/test_ffm_join_pmappings.py +106 -0
  251. tests/test_ffm_make_pmappings.py +82 -0
  252. tests/test_ffm_make_tile_shapes.py +49 -0
  253. tests/test_mapper.py +100 -0
  254. tests/test_model.py +37 -0
  255. tests/test_plotting.py +72 -0
  256. tests/test_processing_stage.py +46 -0
  257. tests/test_symbolic_model.py +248 -0
  258. tests/test_workload.py +141 -0
accelforge/model/_looptree/reuse/symbolic/symbolic.py
@@ -0,0 +1,1396 @@
1
+ import copy
2
+ from dataclasses import dataclass, field
3
+ import itertools
4
+ from accelforge.frontend.mapping import (
5
+ Compute,
6
+ Mapping,
7
+ Nested,
8
+ Pipeline,
9
+ ProcessingStage,
10
+ Reservation,
11
+ Sequential,
12
+ Spatial,
13
+ Split,
14
+ Storage,
15
+ Temporal,
16
+ )
17
+ from typing import Any, List
18
+
19
+ from accelforge.frontend import arch
20
+ import accelforge.frontend.mapping as mapping_spec
21
+ from accelforge.frontend.mapping import (
22
+ Mapping,
23
+ MappingNode,
24
+ Nested,
25
+ Spatial,
26
+ Temporal,
27
+ Storage,
28
+ Reservation,
29
+ Loop,
30
+ TensorHolder,
31
+ ProcessingStage,
32
+ )
33
+ from accelforge.frontend.workload import (
34
+ Workload,
35
+ TensorName,
36
+ isl_expression_has_variable,
37
+ )
38
+ from accelforge.frontend._workload_isl._isl import get_rank_variable_bounds
39
+ from accelforge.frontend._workload_isl._symbolic import (
40
+ get_projection_expr,
41
+ get_rank_variable_relevancy,
42
+ compute_dense_tile_occupancy,
43
+ Irrelevant,
44
+ Relevant,
45
+ PartiallyRelevant,
46
+ )
47
+
48
+ from accelforge.model._looptree.types import Buffet
49
+
50
+ from accelforge.mapper.FFM._make_pmappings.pmapper_job import Job
51
+ from accelforge.util._sympy.broadcast_max import Min, Max
52
+
53
+ import sympy
54
+
55
+
56
+ SYMBOL = "symbol"
57
+ IMPERFECT = False
58
+
59
+ PRINT_FORMULAS = False
60
+
61
+
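+ # NOTE: this Compute is a stats key (einsum name, compute level). It shadows the
+ # Compute node class imported from accelforge.frontend.mapping, which is referred
+ # to below as mapping_spec.Compute.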
62
+ @dataclass(eq=True, frozen=True)
63
+ class Compute:
64
+ einsum: str
65
+ level: str
66
+
67
+
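+ # Sentinel that absorbs addition and multiplication, so any expression built from
+ # uninitialized stats stays Uninitialized.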
68
+ class Uninitialized:
69
+ def __init__(self):
70
+ pass
71
+
72
+ def __str__(self):
73
+ return "Uninitialized"
74
+
75
+ def __repr__(self):
76
+ return "Uninitialized()"
77
+
78
+ def __rmul__(self, other):
79
+ return self * other
80
+
81
+ def __mul__(self, other):
82
+ return self
83
+
84
+ def __radd__(self, other):
85
+ return self + other
86
+
87
+ def __add__(self, other):
88
+ return self
89
+
90
+
91
+ # TODO: unsure if this is needed. If the sympy symbol is created with the
92
+ # correct assumption (e.g., positive), this should be automatic.
93
+ def min_nonzero(a: Any, b: Any) -> Any:
94
+ if a == 0:
95
+ return b
96
+ if b == 0:
97
+ return a
98
+ return Min(a, b)
99
+
100
+
101
+ def max_dict(a: dict[Any, Any], b: dict[Any, Any]) -> dict[Any, Any]:
102
+ new = {**a}
103
+ for key, value in b.items():
104
+ new[key] = Max(new[key], value) if key in new else value
105
+ assert isinstance(new, dict)
106
+ return new
107
+
108
+
109
+ @dataclass
110
+ class BuffetStats:
111
+ total_reads_to_parent: Any = field(default=0)
112
+ total_writes_to_parent: Any = field(default=0)
113
+ max_per_parent_reads_to_parent: Any = field(default=0)
114
+ max_per_parent_writes_to_parent: Any = field(default=0)
115
+
116
+ total_reads_to_peer: Any = field(default=0)
117
+ total_writes_to_peer: Any = field(default=0)
118
+ max_per_unit_reads_to_peer: Any = field(default=0)
119
+ max_per_unit_writes_to_peer: Any = field(default=0)
120
+
121
+ # Skip the first iteration of temporal loops for data that is written
122
+ total_skipped_first_reads_to_parent: Any = field(default=0)
123
+ total_skipped_first_reads_to_peer: Any = field(default=0)
124
+ min_per_parent_skipped_first_reads_to_parent: Any = field(default=0)
125
+ min_per_unit_skipped_first_writes_to_peer: Any = field(default=0)
126
+
127
+ max_occupancy: Any = field(default=0)
128
+ _n_loops_above: int = field(default=0)
129
+
130
+ # These are used to calculate energy and latency
131
+ total_write_actions: Any = field(default=0)
132
+ max_per_unit_write_actions: Any = field(default=0)
133
+ total_read_actions: Any = field(default=0)
134
+ max_per_unit_read_actions: Any = field(default=0)
135
+
136
+ total_skipped_first_write_actions: Any = field(default=0)
137
+ min_per_unit_skipped_first_write_actions: Any = field(default=0)
138
+ total_skipped_first_read_actions: Any = field(default=0)
139
+ min_per_unit_skipped_first_read_actions: Any = field(default=0)
140
+
141
+ persistent: bool = field(default=False)
142
+
143
+ @property
144
+ def n_loops_above(self) -> int:
145
+ if self.persistent:
146
+ return -1
147
+ return self._n_loops_above
148
+
149
+ @n_loops_above.setter
150
+ def n_loops_above(self, value: int):
151
+ self._n_loops_above = value
152
+
153
+ def repeat_temporal(self, factor: int, is_fully_relevant: bool) -> "BuffetStats":
154
+ new = copy.copy(self)
155
+ for attr in self.__dict__:
156
+ if not attr.startswith(("total_", "max_", "min_")):
157
+ continue
158
+ if "skipped_first" in attr and not is_fully_relevant:
159
+ continue # First actions occur once per relevant iteration.
160
+ if attr == "max_occupancy":
161
+ continue # Max occupancy is not affected by temporal loops above
162
+ setattr(new, attr, getattr(new, attr) * factor)
163
+ return new
164
+
165
+ def repeat_spatial(self, factor: int, reuse_parent_accesses: bool) -> "BuffetStats":
166
+ """
167
+ Duplicate attributes to account for spatial fanout.
168
+
169
+ Parameters
170
+ ----------
171
+ factor:
172
+ duplication factor
173
+ reuse_parent_accesses:
174
+ whether to reuse accesses to parent (if True, multicast is assumed
175
+ and accesses to parents are not duplicated).
176
+ """
177
+ new = copy.copy(self)
178
+ for attr in self.__dict__:
179
+ if not attr.startswith(("total_", "max_", "min_")):
180
+ continue
181
+ if "parent" in attr and reuse_parent_accesses:
182
+ continue # If parent accesses are reused, no need to multiply
183
+ if "per_unit" in attr:
184
+ continue # Spatial fanout doesn't affect per-unit stats
185
+ if attr == "max_occupancy":
186
+ continue # Max occupancy (per unit) is not affected by spatial fanout
187
+ setattr(new, attr, getattr(new, attr) * factor)
188
+ return new
189
+
190
+ def max(self, **kwargs: Any):
191
+ for key, value in kwargs.items():
192
+ setattr(self, key, Max(getattr(self, key), value))
193
+
194
+ def min(self, **kwargs: Any):
195
+ for key, value in kwargs.items():
196
+ setattr(self, key, Min(getattr(self, key), value))
197
+
198
+ def __add__(self, other: "BuffetStats") -> "BuffetStats":
199
+ new = copy.copy(self)
200
+ for attr in self.__dict__:
201
+ if attr.startswith("min_"):
202
+ setattr(
203
+ new, attr, min_nonzero(getattr(self, attr), getattr(other, attr))
204
+ )
205
+ elif attr.startswith("max_"):
206
+ setattr(new, attr, Max(getattr(self, attr), getattr(other, attr)))
207
+ elif attr.startswith("total_"):
208
+ setattr(new, attr, getattr(self, attr) + getattr(other, attr))
209
+ elif getattr(self, attr) is None:
210
+ setattr(new, attr, getattr(other, attr))
211
+ elif getattr(other, attr) is None:
212
+ setattr(new, attr, getattr(self, attr))
213
+ else:
214
+ assert getattr(self, attr) == getattr(
215
+ other, attr
216
+ ), f"BUG: {attr} is different. self: {getattr(self, attr)} other: {getattr(other, attr)}"
217
+ return new
218
+
219
+ def __iadd__(self, other: "BuffetStats") -> "BuffetStats":
220
+ new = self + other
221
+ for key, value in new.__dict__.items():
222
+ setattr(self, key, value)
223
+ return self
224
+
225
+ def net_total_read_actions(self) -> Any:
226
+ return self.total_read_actions - self.total_skipped_first_read_actions
227
+
228
+ def net_total_write_actions(self) -> Any:
229
+ return self.total_write_actions - self.total_skipped_first_write_actions
230
+
231
+ def net_max_per_unit_read_actions(self) -> Any:
232
+ return (
233
+ self.max_per_unit_read_actions
234
+ - self.min_per_unit_skipped_first_read_actions
235
+ )
236
+
237
+ def net_max_per_unit_write_actions(self) -> Any:
238
+ return (
239
+ self.max_per_unit_write_actions
240
+ - self.min_per_unit_skipped_first_write_actions
241
+ )
242
+
243
+
244
+ def blank_buffet_stats() -> BuffetStats:
245
+ stats = BuffetStats()
246
+ stats.n_loops_above = None # Inherit from whoever is added to this
247
+ return stats
248
+
249
+
250
+ @dataclass
251
+ class ComputeStats:
252
+ total_ops: Any = field(default=0)
253
+ max_per_unit_ops: Any = field(default=0)
254
+ # "max" below refers to the longest latency of any iteration
255
+ max_latency: Any = field(default=0)
256
+ # Mapping from the loop-index (0 at top) to the latency of the first
257
+ # iteration of that loop. "Max" because we may have loops above that and we
258
+ # will take the maximum of the firsts.
259
+ max_first_latency: dict[int, Any] = field(default_factory=dict)
260
+
261
+ def repeat_temporal(self, factor: int) -> "ComputeStats":
262
+ new = copy.copy(self)
263
+ new.total_ops *= factor
264
+ new.max_per_unit_ops *= factor
265
+ new.max_latency *= factor
266
+ # NOTE: max_first_latency does not change
267
+ return new
268
+
269
+ def repeat_spatial(self, factor: int) -> "ComputeStats":
270
+ new = copy.copy(self)
271
+ new.total_ops *= factor
272
+ return new
273
+
274
+ def __add__(self, other: "ComputeStats") -> "ComputeStats":
275
+ new = copy.copy(self)
276
+ new.total_ops += other.total_ops
277
+ new.max_per_unit_ops += other.max_per_unit_ops
278
+ new.max_latency += other.max_latency
279
+ # max_first_latency is only ever updated across loops ABOVE the loop
280
+ # for which we calculated that first latency, so we should MAX
281
+ new.max_first_latency = max_dict(
282
+ self.max_first_latency, other.max_first_latency
283
+ ) # FIRST LATENCY
284
+ return new
285
+
286
+ def combine_temporal(self, other: "ComputeStats"):
287
+ self.total_ops += other.total_ops
288
+ self.max_per_unit_ops += other.max_per_unit_ops
289
+ self.max_latency += other.max_latency
290
+ # max_first_latency is only ever updated across loops ABOVE the loop
291
+ # for which we calculated that first latency, so we should MAX
292
+ self.max_first_latency = max_dict(
293
+ self.max_first_latency, other.max_first_latency
294
+ ) # FIRST LATENCY
295
+
296
+ def combine_spatial(self, other: "ComputeStats"):
297
+ self.total_ops += other.total_ops
298
+ self.max_per_unit_ops = Max(self.max_per_unit_ops, other.max_per_unit_ops)
299
+ self.max_latency = Max(self.max_latency, other.max_latency)
300
+ # max_first_latency is only ever updated across loops ABOVE the loop
301
+ # for which we calculated that first latency, so we should MAX
302
+ self.max_first_latency = max_dict(
303
+ self.max_first_latency, other.max_first_latency
304
+ ) # FIRST LATENCY
305
+
306
+
307
+ @dataclass
308
+ class SymbolicAnalysisOutput:
309
+ compute_stats: dict[Compute, ComputeStats] = field(default_factory=dict)
310
+
311
+ buffet_stats: dict[Buffet, BuffetStats] = field(default_factory=dict)
312
+
313
+ # Mapping (level, einsum) to the per-dimension fanout
314
+ fanout: dict[tuple[str, str], dict[str, Any]] = field(default_factory=dict)
315
+
316
+ # Mapping [einsum] to the number of temporal steps
317
+ temporal_steps: dict[str, int] = field(default_factory=dict)
318
+
319
+ symbols: list[sympy.Symbol] = field(default_factory=list)
320
+
321
+ # tensor to the mapping for that particular tensor
322
+ tensor2mapping: dict[TensorName, Mapping] = field(default_factory=dict)
323
+
324
+ def get_buffet_for_tensor(self, tensor: TensorName) -> Buffet:
325
+ for buffet in self.buffet_stats:
326
+ if buffet.tensor == tensor:
327
+ return buffet
328
+ raise ValueError(f"Buffet for tensor {tensor} not found")
329
+
330
+ def max(self, **kwargs: Any):
331
+ for key, value in kwargs.items():
332
+ assert key in [
333
+ "compute_stats",
334
+ "stats",
335
+ "fanout",
336
+ "temporal_steps",
337
+ ]
338
+ previous = getattr(self, key)
339
+ for k, v in value.items():
340
+ previous.setdefault(k, {})
341
+ for k2, v2 in v.items():
342
+ if k2 in previous[k]:
343
+ previous[k][k2] = Max(previous[k][k2], v2)
344
+ else:
345
+ previous[k][k2] = v2
346
+
347
+ def get_child_buffet_stats(self, buffet: Buffet) -> BuffetStats:
348
+ seen = False
349
+ for child_buffet, child_stats in reversed(self.buffet_stats.items()):
350
+ if not seen:
351
+ seen = child_buffet == buffet
352
+ continue
353
+ if child_buffet.tensor == buffet.tensor:
354
+ return child_stats
355
+ return None
356
+
357
+ def sum_buffet_stats_per_level(self) -> dict[str, BuffetStats]:
358
+ result: dict[str, BuffetStats] = {}
359
+ for buffet, stats in self.buffet_stats.items():
360
+ result.setdefault(buffet.level, blank_buffet_stats())
361
+ result[buffet.level] += stats
362
+ return result
363
+
364
+ def add_buffet_stats_and_symbols(self, other: "SymbolicAnalysisOutput"):
365
+ assert not (set(self.buffet_stats) & set(other.buffet_stats)), "BUG"
366
+ self.buffet_stats.update(other.buffet_stats)
367
+ # if self.temporal_steps != other.temporal_steps:
368
+ # print(f'Temporal steps are different.')
369
+ # print(f'\tmine: {self.temporal_steps}')
370
+ # print(f'\tother: {other.temporal_steps}')
371
+ # assert self.temporal_steps == other.temporal_steps, "BUG"
372
+ self.temporal_steps.update(other.temporal_steps)
373
+ self.symbols.extend([s for s in other.symbols if s not in self.symbols])
374
+ # Assert compute stats are the same
375
+ # assert self.compute_stats == other.compute_stats, "BUG"
376
+
377
+
378
+ @dataclass
379
+ class AnalysisInfo:
380
+ """Information needed within the analysis step by multiple functions that
381
+ can be computed once at the beginning.
382
+ """
383
+
384
+ mapping: Mapping
385
+ workload: Workload
386
+ full_rank_variable_shapes: dict
387
+ all_tensors: set
388
+
389
+ einsum_tensor_to_projection: dict
390
+ tensor_to_relevancy: dict
391
+ tensor_to_backer_id: dict[TensorName, int]
392
+
393
+ is_copy_operation: TensorName | None
394
+
395
+ job: Job
396
+
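+ # Maps tensors used by multiple einsums to the id() of their backing Reservation
+ # node; analyze_reservation uses it to decide which temporal loop's first-iteration
+ # latency to track.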
397
+ tensor_to_reservation_backer_id: dict[TensorName, int] = field(default_factory=dict)
398
+
399
+ # We track first latency for these nodes (should be Temporal)
400
+ last_temporal_node_idx: int | None = None
401
+ """
402
+ node idx of the last (above) temporal node
403
+ """
404
+ idxs_to_track_first_latency: set[int] = field(default_factory=set)
405
+ """
406
+ node idxs for which we track first latency
407
+ """
408
+
409
+
410
+ def quick_insert_reservation_nodes(job: Job) -> Mapping:
411
+ mapping = list(job.mapping.nodes)
412
+ workload = job.spec.workload
413
+
414
+ # TODO: Subclass reservation with TensorReservation or something so that we can
415
+ # track which are for tensors and which are for non-tensor resources.
416
+
417
+ info = AnalysisInfo(
418
+ mapping=None,
419
+ workload=workload,
420
+ full_rank_variable_shapes=None,
421
+ all_tensors=None,
422
+ einsum_tensor_to_projection=None,
423
+ tensor_to_relevancy=job.tensor_to_relevancy,
424
+ tensor_to_backer_id=None,
425
+ is_copy_operation=None,
426
+ job=None,
427
+ )
428
+ insert_reservation_nodes(mapping, info, job.fusable_tensors)
429
+ m = Mapping(nodes=mapping)
430
+ m._n_loop_orders = job.mapping._n_loop_orders
431
+ return m
432
+
433
+
434
+ def convert_to_copy(
435
+ mapping: list[MappingNode], workload: Workload
436
+ ) -> tuple[list[MappingNode], dict[TensorName, int]]:
437
+ mapping = copy.deepcopy(mapping)
438
+
439
+ # Calculate this BEFORE we modify the mapping. We're going to have the copy source
440
+ # tensor moving upward sometimes, and we don't want the backing tensor holder
441
+ tensor_to_backer_id = get_tensor_to_backer_id(mapping)
442
+
443
+ first_input_tensor = workload.einsums[mapping[-1].einsum].copy_source_tensor()
444
+
445
+ for node in mapping:
446
+ if isinstance(node, TensorHolder):
447
+ if node.tensors:
448
+ node.tensors = [first_input_tensor]
449
+ node._lower = False
450
+
451
+ to_remove = []
452
+ i = 0
453
+ while i < len(mapping):
454
+ node = mapping[i]
455
+ if isinstance(node, TensorHolder):
456
+ j = i + 1
457
+ while j < len(mapping):
458
+ node2 = mapping[j]
459
+ if (
460
+ isinstance(node2, TensorHolder)
461
+ and node.component == node2.component
462
+ ):
463
+ mapping.pop(j)
464
+ else:
465
+ j += 1
466
+ i += 1
467
+ mapping = [node for node in mapping if node not in to_remove]
468
+
469
+ return mapping, tensor_to_backer_id
470
+
471
+
472
+ def analyze_reuse_and_add_reservations_to_mapping(
473
+ job: Job,
474
+ ) -> SymbolicAnalysisOutput:
475
+ mapping = job.mapping.nodes
476
+ workload = job.spec.workload
477
+ einsum_name = mapping[-1].einsum
478
+
479
+ is_copy_operation = workload.einsums[einsum_name].is_copy_operation
480
+ symbols = insert_sympy_symbols(job.mapping.nodes, job)
481
+
482
+ if is_copy_operation:
483
+ mapping, tensor_to_backer_id = convert_to_copy(mapping, workload)
484
+ else:
485
+ tensor_to_backer_id = get_tensor_to_backer_id(mapping)
486
+
487
+ job.mapping = quick_insert_reservation_nodes(job)
488
+ # print(f'Job mapping: {job.mapping.compact_str()}')
489
+ # for n in job.mapping.nodes:
490
+ # print(f'\t{n.compact_str()}')
491
+
492
+ einsum_tensor_to_projection = {}
493
+ einsum = workload.einsums[einsum_name]
494
+ all_tensors = einsum.tensor_names
495
+ for tensor in all_tensors:
496
+ einsum_tensor_to_projection[(einsum_name, tensor)] = get_projection_expr(
497
+ einsum, tensor
498
+ )
499
+ tensor_to_relevancy = {
500
+ tensor: get_rank_variable_relevancy(einsum, tensor) for tensor in all_tensors
501
+ }
502
+ assert all_tensors, f"Einsum {einsum_name} has no tensors"
503
+
504
+ """
505
+ Note on how this works.
506
+
507
+ Spatial loops are weird, because they don't belong at a single point in the loop
508
+ nest. For example:
509
+
510
+ - DRAM keep A, B
511
+ - *
512
+ - Reg keep A
513
+ - for n in [0..N)
514
+ - GLB keep B
515
+ - *
516
+ - Compute
517
+
518
+ A spatial-for (Reg) k in [0..K) loop would affect the register at the point of the
519
+ first asterisk, but the global buffer at the point of the second asterisk.
520
+
521
+ To handle this, we make a separate mapping for each tensor, analyze each, and
522
+ combine the results.
523
+
524
+ To anyone who would like to create behavior that simultaneously looks at multiple
525
+ storage nodes for a given memory, note that there will be two challenges to address:
526
+
527
+ 1. The code currently analyzes one tensor at a time. This could be fixed by
528
+ processing all mapping(s) together, applying loop(s) from each to only the
529
+ appropriate nodes.
530
+ 2. The code must analyze one storage node at a time, and there may be temporal and
531
+ spatial nodes between two storage nodes for a given memory, which would separate
532
+ the analysis steps for the storage nodes. This may be addressed by only
533
+ performing such analysis until the outermost storage node for a particular memory
534
+ has been analyzed.
535
+ """
536
+ result = None
537
+
538
+ tensor2mapping = {}
539
+ index_expressions = set(einsum.indexing_expressions)
540
+ for k, v in job.rank_variable_bounds.items():
541
+ index_expressions.add(f"0 < {k} <= {v}")
542
+ for tensor in all_tensors:
543
+ cur_mapping = job.mapping._get_single_tensor_mapping(
544
+ tensor, job.flattened_arch, index_expressions
545
+ )
546
+ info = AnalysisInfo(
547
+ mapping=cur_mapping.nodes,
548
+ workload=workload,
549
+ full_rank_variable_shapes=job.rank_variable_bounds,
550
+ all_tensors=set([tensor]),
551
+ einsum_tensor_to_projection=einsum_tensor_to_projection,
552
+ tensor_to_relevancy=tensor_to_relevancy,
553
+ tensor_to_backer_id=tensor_to_backer_id,
554
+ is_copy_operation=is_copy_operation,
555
+ job=job,
556
+ )
557
+ cur_result = analyze_node(0, job.rank_variable_bounds, info)
558
+ if result is None:
559
+ result = cur_result
560
+ else:
561
+ result.add_buffet_stats_and_symbols(cur_result)
562
+ tensor2mapping[tensor] = cur_mapping
563
+
564
+ result.symbols = symbols
565
+ result.tensor2mapping = tensor2mapping
566
+
567
+ if PRINT_FORMULAS:
568
+ print(f"Mapping:")
569
+ for node in mapping:
570
+ print(f"\t{node.compact_str()}")
571
+
572
+ for buffet, stats in result.buffet_stats.items():
573
+ print(f"Einsum {buffet.einsum} tensor {buffet.tensor} level {buffet.level}")
574
+ print(f"\tTotal reads to parent: {stats.total_reads_to_parent}")
575
+ print(f"\tTotal writes to parent: {stats.total_writes_to_parent}")
576
+ print(
577
+ f"\tMax per parent reads to parent: {stats.max_per_parent_reads_to_parent}"
578
+ )
579
+ print(
580
+ f"\tMax per parent writes to parent: {stats.max_per_parent_writes_to_parent}"
581
+ )
582
+ print(f"\tTotal reads to peer: {stats.total_reads_to_peer}")
583
+ print(f"\tTotal writes to peer: {stats.total_writes_to_peer}")
584
+ print(f"\tMax per unit reads to peer: {stats.max_per_unit_reads_to_peer}")
585
+ print(f"\tMax per unit writes to peer: {stats.max_per_unit_writes_to_peer}")
586
+
587
+ return result
588
+
589
+
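+ # Map each tensor to the id() of the outermost TensorHolder that holds it (its
+ # backing storage).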
590
+ def get_tensor_to_backer_id(mapping: list[MappingNode]) -> dict[TensorName, int]:
591
+ tensor_to_ids: dict[TensorName, int] = {}
592
+ for node in mapping:
593
+ if isinstance(node, TensorHolder):
594
+ for tensor in node.tensors:
595
+ if tensor in tensor_to_ids:
596
+ continue
597
+ tensor_to_ids[tensor] = id(node)
598
+ return tensor_to_ids
599
+
600
+
601
+ class ReservationAnalysisTracker:
602
+ def __init__(self, buffet, node):
603
+ self.buffet: Buffet = buffet
604
+ self.node: TensorHolder = node
605
+
606
+ # These are the interface (TODO: should be properties)
607
+ self.is_fill_level = False
608
+ self.should_stop = False
609
+ self.insert_reservation_under = False
610
+ self.insert_fill_under = False
611
+
612
+ # Temporary values
613
+ self.has_filled = False
614
+
615
+ def track_temporal_loop(self, relevancy, node):
616
+ self.is_fill_level = False
617
+ self.insert_reservation_under = False
618
+ self.insert_fill_under = False
619
+
620
+ if isinstance(relevancy, Irrelevant):
621
+ if not self.has_filled:
622
+ self.is_fill_level = True
623
+ self.has_filled = True
624
+
625
+ self.should_stop = True
626
+ elif isinstance(relevancy, Relevant):
627
+ self.should_stop = False
628
+ elif isinstance(relevancy, PartiallyRelevant):
629
+ self.last = True
630
+
631
+ if not self.has_filled:
632
+ self.is_fill_level = True
633
+ self.has_filled = True
634
+
635
+ self.should_stop = True
636
+ self.insert_reservation_under = True
637
+ else:
638
+ raise ValueError(f"Unknown relevancy {relevancy}")
639
+
640
+ def track_compute(self):
641
+ self.should_stop = True
642
+ if not self.has_filled:
643
+ self.is_fill_level = True
644
+ self.has_filled = True
645
+
646
+ def track_spatial_loop(self, relevancy, node):
647
+ if node.component != self.buffet.level:
648
+ self.should_stop = True
649
+ if not self.has_filled:
650
+ self.is_fill_level = True
651
+ self.has_filled = True
652
+ return
653
+
654
+ self.is_fill_level = False
655
+ self.should_stop = False
656
+
657
+
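+ # Walk the mapping top-down, tracking each held tensor until the point where its
+ # reservation belongs, then insert a Reservation node there (below or above the
+ # stopping node; reservations are kept below fills).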
658
+ def insert_reservation_nodes(
659
+ mapping, info: AnalysisInfo, fusable_tensors: set[TensorName]
660
+ ):
661
+ trackers: list[ReservationAnalysisTracker] = []
662
+ einsum = info.workload.einsums[mapping[-1].einsum]
663
+ non_intermediate_tensors = (
664
+ einsum.tensor_names - info.workload.tensor_names_used_in_multiple_einsums
665
+ )
666
+ seen_tensors = set() # reservation for top-level buffets cannot be lowered
667
+
668
+ n_nodes = len(mapping)
669
+ i = 0
670
+ while i < n_nodes:
671
+ node = mapping[i]
672
+ to_remove = []
673
+ if isinstance(node, Reservation):
674
+ pass
675
+ elif isinstance(node, Temporal):
676
+ rank = node.rank_variable
677
+ for tracker in trackers:
678
+ relevancy = info.tensor_to_relevancy[tracker.buffet.tensor]
679
+ tracker.track_temporal_loop(relevancy[rank], node)
680
+ elif isinstance(node, Spatial):
681
+ rank = node.rank_variable
682
+ for tracker in trackers:
683
+ relevancy = info.tensor_to_relevancy[tracker.buffet.tensor]
684
+ tracker.track_spatial_loop(relevancy[rank], node)
685
+ elif isinstance(node, TensorHolder):
686
+ for tracker in trackers:
687
+ tracker.should_stop = True
688
+ tracker.insert_reservation_under = False
689
+ for tensor in node.tensors:
690
+ tensor = TensorName(tensor)
691
+ buffet = Buffet(tensor, mapping[-1].einsum, node.component)
692
+ trackers.append(ReservationAnalysisTracker(buffet, node))
693
+ if not node._lower or (
694
+ tensor not in seen_tensors and tensor in non_intermediate_tensors
695
+ ):
696
+ seen_tensors.add(tensor)
697
+ trackers[-1].is_fill_level = True
698
+ trackers[-1].insert_reservation_under = True
699
+ trackers[-1].insert_fill_under = True
700
+ trackers[-1].should_stop = True
701
+ elif isinstance(node, mapping_spec.Compute):
702
+ for tracker in trackers:
703
+ tracker.track_compute()
704
+ tracker.insert_reservation_under = False
705
+ else:
706
+ raise NotImplementedError(f"Unknown node type {type(node)}")
707
+
708
+ reservation_insert_below = []
709
+ reservation_insert_above = []
710
+ for j in range(len(trackers) - 1, -1, -1):
711
+ if not trackers[j].should_stop:
712
+ continue
713
+ tracker = trackers.pop(j)
714
+ buffet = tracker.buffet
715
+ node = Reservation(purposes=[buffet.tensor], resource=buffet.level)
716
+ node.persistent = tracker.node.persistent
717
+ node._backing = tracker.node._backing
718
+
719
+ if (
720
+ buffet.tensor not in info.tensor_to_reservation_backer_id
721
+ and buffet.tensor in info.workload.tensor_names_used_in_multiple_einsums
722
+ ):
723
+ info.tensor_to_reservation_backer_id[buffet.tensor] = id(node)
724
+
725
+ if tracker.insert_reservation_under:
726
+ reservation_insert_below.append(node)
727
+ else:
728
+ reservation_insert_above.append(node)
729
+
730
+ # The order of these for loops is important. Reservation must be below fill.
731
+ for node in reservation_insert_below:
732
+ mapping.insert(i + 1, node)
733
+ i += 1
734
+ for node in reservation_insert_above:
735
+ mapping.insert(i, node)
736
+ i += 1
737
+
738
+ i += 1
739
+ n_nodes = len(mapping)
740
+
741
+ label_fused_loops(mapping, fusable_tensors)
742
+
743
+
744
+ def label_fused_loops(mapping: List[MappingNode], fusable_tensors: set[TensorName]):
745
+ last_backer = -1
746
+ assert isinstance(
747
+ fusable_tensors, set
748
+ ), f"Fusable tensors must be a set, got {type(fusable_tensors)}"
749
+
750
+ for i, node in enumerate(mapping):
751
+ if not isinstance(node, TensorHolder) or isinstance(node, Reservation):
752
+ continue
753
+ assert isinstance(node._backing, set)
754
+ if node._backing & fusable_tensors:
755
+ last_backer = i
756
+ # We may lower this node through at most one additional loop
757
+ if node._lower:
758
+ last_backer += 1
759
+ if last_backer == -1 and fusable_tensors:
760
+ raise ValueError(
761
+ f"No backing TensorHolder found in mapping {", ".join(m.compact_str() for m in mapping)}"
762
+ )
763
+
764
+ for i, node in enumerate(mapping):
765
+ if isinstance(node, Loop):
766
+ node._fused = i < last_backer
767
+
768
+ return mapping
769
+
770
+
771
+ def analyze_node(node_idx, current_shape, info: AnalysisInfo) -> SymbolicAnalysisOutput:
772
+ node = info.mapping[node_idx]
773
+ class2analysis_function = {
774
+ Temporal: analyze_temporal,
775
+ Spatial: analyze_spatial,
776
+ Storage: analyze_storage,
777
+ Reservation: analyze_reservation,
778
+ mapping_spec.Compute: analyze_compute,
779
+ ProcessingStage: analyze_processing_stage,
780
+ }
781
+ if type(node) not in class2analysis_function:
782
+ raise TypeError(f"Unknown node type {type(node)}")
783
+ return class2analysis_function[type(node)](node_idx, current_shape, info)
784
+
785
+
786
+ def analyze_temporal(
787
+ node_idx, current_shape, info: AnalysisInfo
788
+ ) -> SymbolicAnalysisOutput:
789
+ mapping = info.mapping
790
+ node = mapping[node_idx]
791
+ stride_and_shape = get_stride_and_tile_shape(node, current_shape, node_idx, info)
792
+
793
+ result_accumulator = SymbolicAnalysisOutput()
794
+
795
+ first_latency = None
796
+
797
+ def handle_repeated_value(repeated_shape):
798
+ nonlocal first_latency
799
+ shape_value = repeated_shape.value
800
+ shape_repeats = repeated_shape.repeats
801
+
802
+ child_shape = current_shape.copy()
803
+ child_shape[node.rank_variable] = shape_value
804
+
805
+ child_result = analyze_node(node_idx + 1, child_shape, info)
806
+
807
+ accumulated_buffet_stats = result_accumulator.buffet_stats
808
+ for buffet, stats in child_result.buffet_stats.items():
809
+ relevancy = info.tensor_to_relevancy[buffet.tensor][node.rank_variable]
810
+ is_fully_relevant = isinstance(relevancy, Relevant)
811
+ accumulated_stats = accumulated_buffet_stats.setdefault(
812
+ buffet, blank_buffet_stats()
813
+ )
814
+ accumulated_stats += stats.repeat_temporal(
815
+ shape_repeats, is_fully_relevant=is_fully_relevant
816
+ )
817
+ accumulated_stats.n_loops_above = stats.n_loops_above + 1
818
+
819
+ for einsum, child_steps in child_result.temporal_steps.items():
820
+ if einsum not in result_accumulator.temporal_steps:
821
+ result_accumulator.temporal_steps[einsum] = 0
822
+ result_accumulator.temporal_steps[einsum] += child_steps * shape_repeats
823
+
824
+ result_accumulator.max(fanout=child_result.fanout)
825
+
826
+ for key in child_result.compute_stats:
827
+ if first_latency is None:
828
+ first_latency = child_result.compute_stats[key].max_latency
829
+
830
+ compute_stats = result_accumulator.compute_stats.setdefault(
831
+ key, ComputeStats()
832
+ )
833
+ compute_stats += child_result.compute_stats[key].repeat_temporal(
834
+ shape_repeats
835
+ )
836
+ result_accumulator.compute_stats[key] = compute_stats
837
+
838
+ info.last_temporal_node_idx = node_idx
839
+
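+ # The tile shape may be a single RepeatedValue or a SequenceOfRepatedvalues of
+ # differing tile sizes; handle both forms.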
840
+ shape = stride_and_shape.shape
841
+ if isinstance(shape, SequenceOfRepatedvalues):
842
+ for repeated_shape in shape.sequence:
843
+ assert isinstance(repeated_shape, RepeatedValue)
844
+ handle_repeated_value(repeated_shape)
845
+ elif isinstance(shape, RepeatedValue):
846
+ handle_repeated_value(shape)
847
+
848
+ if node_idx in info.idxs_to_track_first_latency:
849
+ for compute_stat in result_accumulator.compute_stats.values():
850
+ # Should be the first time we store this value
851
+ assert node_idx not in compute_stat.max_first_latency
852
+ compute_stat.max_first_latency[node_idx] = first_latency
853
+
854
+ return result_accumulator
855
+
856
+
857
+ def analyze_spatial(node_idx, current_shape, info: AnalysisInfo):
858
+ mapping = info.mapping
859
+ einsum_name = mapping[-1].einsum
860
+ node: Spatial = mapping[node_idx]
861
+ rank_var = node.rank_variable
862
+ node_dim = node.name
863
+ spatial_component = node.component_object
864
+ component_spatial_dim = spatial_component.spatial[node_dim]
865
+ stride_and_shape = get_stride_and_tile_shape(node, current_shape, node_idx, info)
866
+
867
+ result_accumulator = SymbolicAnalysisOutput()
868
+
869
+ def handle_repeated_value(repeated_shape):
870
+ shape_value = repeated_shape.value
871
+ shape_repeats = repeated_shape.repeats
872
+
873
+ child_shape = current_shape.copy()
874
+ child_shape[node.rank_variable] = shape_value
875
+
876
+ child_result = analyze_node(node_idx + 1, child_shape, info)
877
+
878
+ accumulated_buffet_stats = result_accumulator.buffet_stats
879
+ child_stats = list(child_result.buffet_stats.items())
880
+ for i, (buffet, buffet_stats) in enumerate(child_stats):
881
+ stats = buffet_stats
882
+ accumulated_stats = accumulated_buffet_stats.setdefault(
883
+ buffet, blank_buffet_stats()
884
+ )
885
+ relevancy = info.tensor_to_relevancy[buffet.tensor][rank_var]
886
+
887
+ # Reuse parent accesses only:
888
+ # - Irrelevant loops
889
+ # - The outermost level that holds the tensor (the one whose parent accesses
890
+ # will be going through the network)
891
+ last_buffet = True
892
+ for other_buffet, _ in child_stats[i + 1 :]:
893
+ if other_buffet.tensor == buffet.tensor:
894
+ last_buffet = False
895
+ break
896
+
897
+ reuse_parent_accesses = (
898
+ last_buffet
899
+ and isinstance(relevancy, Irrelevant)
900
+ and buffet.tensor in component_spatial_dim.may_reuse
901
+ )
902
+
903
+ accumulated_stats += stats.repeat_spatial(
904
+ shape_repeats, reuse_parent_accesses=reuse_parent_accesses
905
+ )
906
+ accumulated_stats.n_loops_above = stats.n_loops_above + 1
907
+
908
+ for einsum, child_steps in child_result.temporal_steps.items():
909
+ if einsum not in result_accumulator.temporal_steps:
910
+ result_accumulator.temporal_steps[einsum] = child_steps
911
+ else:
912
+ result_accumulator.temporal_steps[einsum] = Max(
913
+ result_accumulator.temporal_steps[einsum], child_steps
914
+ )
915
+
916
+ my_key = (node.component, einsum_name)
917
+ child_result.fanout.setdefault(my_key, {})
918
+
919
+ # Propagate up everything except the current level and dimension
920
+ child_fanout = copy.deepcopy(child_result.fanout)
921
+ target_fanout = child_fanout[my_key].pop(node_dim, 1)
922
+ result_accumulator.max(fanout=child_fanout)
923
+
924
+ # Propagate current level and dimension * shape_repeats
925
+ child_fanout = child_result.fanout[my_key]
926
+ fanout = result_accumulator.fanout.setdefault(my_key, {})
927
+ fanout.setdefault(node_dim, 0) # TODO: Assume sympy can just take in 0
928
+ # TODO: If node_dim was missing, the original code would have omitted
929
+ # shape_repeats. Is this correct?
930
+ fanout[node_dim] += target_fanout * shape_repeats
931
+
932
+ for key in child_result.compute_stats:
933
+ # TODO: ensure that `ComputeStats()`, which is initialized ONCE, is okay to use here
934
+ compute_stats = result_accumulator.compute_stats.setdefault(
935
+ key, ComputeStats()
936
+ )
937
+ # TODO: An if-check from the original code was omitted here; check history if needed.
938
+ compute_stats.combine_spatial(
939
+ child_result.compute_stats[key].repeat_spatial(shape_repeats)
940
+ )
941
+
942
+ shape = stride_and_shape.shape
943
+ if isinstance(shape, SequenceOfRepatedvalues):
944
+ for repeated_shape in shape.sequence:
945
+ assert isinstance(repeated_shape, RepeatedValue)
946
+ handle_repeated_value(repeated_shape)
947
+ elif isinstance(shape, RepeatedValue):
948
+ handle_repeated_value(shape)
949
+
950
+ return result_accumulator
951
+
952
+
953
+ def reduce_dicts(dict1: dict, dict2: dict, reduce_op):
954
+ for key in dict1:
955
+ if key not in dict2:
956
+ dict2[key] = dict1[key]
957
+ else:
958
+ dict2[key] = reduce_op(dict1[key], dict2[key])
959
+
960
+
961
+ def get_total_to_per_unit(total, max_per_unit):
962
+ if total == 0 and max_per_unit != 0:
963
+ raise ValueError(f"total is 0 but max_per_unit is {max_per_unit}")
964
+ if total == 0:
965
+ return 1
966
+ return max_per_unit / total
967
+
968
+
969
+ def has_parent_tensor_holder(
970
+ tensor: TensorName, node_idx: int, info: AnalysisInfo
971
+ ) -> bool:
972
+ for node in info.mapping[:node_idx]:
973
+ if isinstance(node, TensorHolder) and tensor in node.tensors:
974
+ return True
975
+ return False
976
+
977
+
978
+ def find_component_object(
979
+ component: str, flattened_arch: list[arch.Leaf]
980
+ ) -> arch.TensorHolder:
981
+ for node in flattened_arch:
982
+ if node.name == component:
983
+ return node
984
+ raise ValueError(f"Component {component} not found in flattened arch")
985
+
986
+
987
+ def analyze_storage(
988
+ node_idx: int,
989
+ current_shape: dict[str, int],
990
+ info: AnalysisInfo,
991
+ propagate_child_results: bool = False,
992
+ count_upward_movement: bool = True,
993
+ count_downward_movement: bool = True,
994
+ count_writes: bool = True,
995
+ ):
996
+ mapping = info.mapping
997
+ einsum_name = mapping[-1].einsum
998
+ node: TensorHolder = mapping[node_idx]
999
+
1000
+ child_result = analyze_node(node_idx + 1, current_shape, info)
1001
+
1002
+ for tensor in node.tensors:
1003
+ tensor = TensorName(tensor)
1004
+ buffet = Buffet(tensor, einsum_name, node.component)
1005
+
1006
+ # Reservations make these, and they go below the storage node, so the buffet
1007
+ # stats are already made at this point
1008
+ stats = child_result.buffet_stats[buffet]
1009
+ backer_id = info.tensor_to_backer_id[tensor]
1010
+ is_backing = backer_id == id(node)
1011
+ if node.persistent:
1012
+ stats.persistent = True
1013
+ below_backing = backer_id in [id(m) for m in mapping[:node_idx]]
1014
+
1015
+ projection = info.einsum_tensor_to_projection[(einsum_name, tensor)]
1016
+
1017
+ fills = compute_dense_tile_occupancy(projection, current_shape)
1018
+
1019
+ child = child_result.get_child_buffet_stats(buffet)
1020
+ inherit_from_child = propagate_child_results and child is not None
1021
+
1022
+ # ==============================================================================
1023
+ # Calculate the total fills and reads to parent. These propagate upward.
1024
+ # ==============================================================================
1025
+
1026
+ def inherit_add(attr: str, default_value: Any = fills) -> Any:
1027
+ val = getattr(child, attr) if inherit_from_child else default_value
1028
+ setattr(stats, attr, val + getattr(stats, attr))
1029
+
1030
+ if has_parent_tensor_holder(tensor, node_idx, info):
1031
+ # Initial fetch: If we're below the backing storage, fetch data from above
1032
+ # at the beginning.
1033
+ if not is_backing and below_backing:
1034
+ inherit_add("total_reads_to_parent", fills)
1035
+ inherit_add("max_per_parent_reads_to_parent", fills)
1036
+
1037
+ # Data writeback. Do not writeback if it's a copy operation and we're below
1038
+ # the backing storage; data only flows upward.
1039
+
1040
+ # Writeback occurs in two cases:
1041
+ # - We're at or above the backing storage, so we need to propagate our
1042
+ # results upward to any storage nodes that will need this data.
1043
+ # - This is a written tensor, so we need to write back the written data.
1044
+ if (
1045
+ tensor in info.workload.einsums[einsum_name].output_tensor_names
1046
+ or not below_backing
1047
+ ):
1048
+ inherit_add("total_writes_to_parent")
1049
+ inherit_add("max_per_parent_writes_to_parent")
1050
+
1051
+ # For read+write tensors, we skip the first fill because the data will be
1052
+ # initialized with a zero value.
1053
+ if tensor in info.workload.einsums[einsum_name].output_tensor_names:
1054
+ inherit_add("total_skipped_first_reads_to_parent")
1055
+ inherit_add("min_per_parent_skipped_first_reads_to_parent")
1056
+
1057
+ # ==============================================================================
1058
+ # Convert to actions. These are not used upward; they are used to get
1059
+ # energy and latency.
1060
+ # ==============================================================================
1061
+ component_object = find_component_object(
1062
+ node.component, info.job.flattened_arch
1063
+ )
1064
+ bits_per_value_scale = component_object.bits_per_value_scale[tensor]
1065
+ bits_per_value = (
1066
+ bits_per_value_scale
1067
+ * info.job.einsum.tensor_accesses[tensor].bits_per_value
1068
+ )
1069
+ read_bits_per_action = component_object.actions["read"].bits_per_action
1070
+ read_scale = bits_per_value / read_bits_per_action
1071
+ if count_writes:
1072
+ write_bits_per_action = component_object.actions["write"].bits_per_action
1073
+ write_scale = bits_per_value / write_bits_per_action
1074
+ else:
1075
+ write_scale = 0
1076
+
1077
+ # ==========================
1078
+ # Data exchanges with parent
1079
+ if count_downward_movement: # Parent -> Me
1080
+ stats.total_write_actions += stats.total_reads_to_parent * write_scale
1081
+ stats.max_per_unit_write_actions += (
1082
+ stats.total_reads_to_parent * write_scale
1083
+ )
1084
+ stats.total_skipped_first_write_actions += (
1085
+ stats.total_skipped_first_reads_to_parent * write_scale
1086
+ )
1087
+ stats.min_per_unit_skipped_first_write_actions += (
1088
+ stats.min_per_parent_skipped_first_reads_to_parent * write_scale
1089
+ )
1090
+
1091
+ if count_upward_movement: # Me -> Parent
1092
+ # Comment this to have the final writeback to a buffer hit both that buffer and
1093
+ # go directly to the parent without incurring another read from the buffer.
1094
+ stats.total_read_actions += stats.total_writes_to_parent * read_scale
1095
+ stats.max_per_unit_read_actions += stats.total_writes_to_parent * read_scale
1096
+
1097
+ # ========================
1098
+ # Data exchanges with peer
1099
+ stats.total_read_actions += stats.total_reads_to_peer * read_scale
1100
+ stats.total_write_actions += stats.total_reads_to_peer * write_scale
1101
+
1102
+ # =========================
1103
+ # Data exchanges with child
1104
+ if child is not None:
1105
+ if count_downward_movement: # Me -> Child
1106
+ stats.total_read_actions += child.total_reads_to_parent * read_scale
1107
+ stats.max_per_unit_read_actions += (
1108
+ child.max_per_parent_reads_to_parent * read_scale
1109
+ )
1110
+ # Skip first read
1111
+ stats.total_skipped_first_read_actions += (
1112
+ child.total_skipped_first_reads_to_parent * read_scale
1113
+ )
1114
+ stats.min_per_unit_skipped_first_read_actions += (
1115
+ child.min_per_parent_skipped_first_reads_to_parent * read_scale
1116
+ )
1117
+
1118
+ if count_upward_movement: # Child -> Me
1119
+ stats.total_write_actions += child.total_writes_to_parent * write_scale
1120
+ stats.max_per_unit_write_actions += (
1121
+ child.max_per_parent_writes_to_parent * write_scale
1122
+ )
1123
+
1124
+ return child_result
1125
+
1126
+
1127
+ def analyze_processing_stage(node_idx, current_shape, info: AnalysisInfo):
1128
+ mapping = info.mapping
1129
+ einsum_name = mapping[-1].einsum
1130
+ node = mapping[node_idx]
1131
+ component_object = find_component_object(node.component, info.job.flattened_arch)
1132
+ storage_result = analyze_storage(
1133
+ node_idx,
1134
+ current_shape,
1135
+ info,
1136
+ propagate_child_results=True,
1137
+ count_upward_movement=component_object.direction != "down",
1138
+ count_downward_movement=component_object.direction != "up",
1139
+ count_writes=False,
1140
+ )
1141
+ for tensor in node.tensors:
1142
+ buffet = Buffet(tensor, einsum_name, node.component)
1143
+ stats = storage_result.buffet_stats[buffet]
1144
+ stats.max_occupancy = 0
1145
+ assert stats.total_write_actions == 0
1146
+ return storage_result
1147
+
1148
+
1149
+ def analyze_reservation(node_idx, current_shape, info: AnalysisInfo):
1150
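+ # A reservation node only claims capacity: it creates fresh BuffetStats whose
+ # max_occupancy is the dense tile footprint times bits per value, without
+ # adding any read or write traffic.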
+ mapping = info.mapping
1151
+ einsum_name = mapping[-1].einsum
1152
+ node = mapping[node_idx]
1153
+ tensor = TensorName(node.purpose)
1154
+
1155
+ if info.last_temporal_node_idx is not None and id(
1156
+ node
1157
+ ) == info.tensor_to_reservation_backer_id.get(node.purpose, None):
1158
+ info.idxs_to_track_first_latency.add(info.last_temporal_node_idx)
1159
+
1160
+ child_result = analyze_node(node_idx + 1, current_shape, info)
1161
+
1162
+ buffet = Buffet(tensor, einsum_name, node.resource)
1163
+
1164
+ # Reservation nodes are the first to produce stats for a buffet
1165
+ assert buffet not in child_result.buffet_stats
1166
+
1167
+ stats = BuffetStats()
1168
+ projection = info.einsum_tensor_to_projection[(einsum_name, tensor)]
1169
+ component_object = find_component_object(node.resource, info.job.flattened_arch)
1170
+ bits_per_value_scale = component_object.bits_per_value_scale[tensor]
1171
+ bits_per_value = (
1172
+ bits_per_value_scale * info.job.einsum.tensor_accesses[tensor].bits_per_value
1173
+ )
1174
+ stats.max_occupancy = (
1175
+ compute_dense_tile_occupancy(projection, current_shape) * bits_per_value
1176
+ )
1177
+ child_result.buffet_stats[buffet] = stats
1178
+
1179
+ fanout_key = (node.resource, einsum_name)
1180
+ if fanout_key not in child_result.fanout:
1181
+ child_result.fanout[fanout_key] = {}
1182
+
1183
+ return child_result
1184
+
1185
+
1186
+ def analyze_compute(
1187
+ node_idx, current_shape, info: AnalysisInfo
1188
+ ) -> SymbolicAnalysisOutput:
1189
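+ # The compute leaf seeds the analysis: one temporal step and, per tensor, one
+ # read to the parent (plus one write and one skipped-first read for output
+ # tensors). Enclosing loop and storage nodes presumably scale these unit counts.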
+ einsum = info.mapping[-1].einsum
1190
+ node = info.mapping[node_idx]
1191
+
1192
+ computes = 0 if info.is_copy_operation else 1
1193
+
1194
+ result_accumulator = SymbolicAnalysisOutput()
1195
+
1196
+ result_accumulator.temporal_steps[einsum] = computes
1197
+ result_accumulator.compute_stats[Compute(einsum, node.component)] = ComputeStats(
1198
+ computes,
1199
+ computes,
1200
+ 1,
1201
+ )
1202
+
1203
+ if info.is_copy_operation:
1204
+ return result_accumulator
1205
+
1206
+ for tensor in info.all_tensors:
1207
+ buffet = Buffet(tensor, einsum, node.component)
1208
+ stats = BuffetStats()
1209
+ stats.total_reads_to_parent = 1
1210
+ stats.max_per_parent_reads_to_parent = 1
1211
+ if tensor in info.workload.einsums[einsum].output_tensor_names:
1212
+ stats.total_writes_to_parent = 1
1213
+ stats.max_per_parent_writes_to_parent = 1
1214
+ stats.total_skipped_first_reads_to_parent = 1
1215
+ stats.min_per_parent_skipped_first_reads_to_parent = 1
1216
+ stats.max_occupancy = 1
1217
+ result_accumulator.buffet_stats[buffet] = stats
1218
+
1219
+ return result_accumulator
1220
+
1221
+
1222
+ @dataclass
1223
+ class RepeatedValue[T]:
1224
+ value: T
1225
+ repeats: int
1226
+
1227
+
1228
+ @dataclass
1229
+ class SequenceOfRepatedvalues[T]:
1230
+ sequence: list[RepeatedValue[T]]
1231
+
1232
+
1233
+ @dataclass
1234
+ class StrideAndShape:
1235
+ stride: any
1236
+ shape: any
1237
+
1238
+
1239
+ def get_stride_and_tile_shape(node: Loop, full_shape, n: int, info: AnalysisInfo):
1240
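+ # Returns a StrideAndShape whose shape field is either a single RepeatedValue
+ # (all tiles share one shape) or a SequenceOfRepatedvalues (an initial tile
+ # and/or a differently-sized last tile), as read from the returns below.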
+ rank = node.rank_variable
1241
+ rank_shape = full_shape[rank]
1242
+
1243
+ stride = node.tile_shape
1244
+ initial_tile_shape = node.initial_tile_shape
1245
+
1246
+ # PERFECT:
1247
+ # - Node shape = stride
1248
+ # - # Iterations = total shape / stride
1249
+ # IMPERFECT:
1250
+ # - Node shape = stride
1251
+ # - # Iterations = ceil(total shape / stride)
1252
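+ # Worked example (illustrative): with rank_shape = 10 and stride = 4, the loop
+ # runs ceil(10 / 4) = 3 iterations; the last tile only covers 2 elements, but
+ # the IMPERFECT model above treats every tile as having the full stride of 4.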
+ if IMPERFECT and initial_tile_shape is None:
1253
+ factor = sympy.ceiling(rank_shape / stride)
1254
+ stride_avg = stride / sympy.ceiling(rank_shape / stride)
1255
+ return StrideAndShape(stride_avg, RepeatedValue(stride, factor))
1256
+
1257
+ if initial_tile_shape is None:
1258
+ if node._assume_perfect_factor or known_perfect_factor(stride, rank_shape):
1259
+ factor = rank_shape / stride
1260
+ return StrideAndShape(stride, RepeatedValue(stride, factor))
1261
+ else:
1262
+ factor = sympy.ceiling(rank_shape / sympy.Min(stride, rank_shape))
1263
+ return make_possibly_different_last(stride, factor, rank_shape)
1264
+
1265
+ middle_shape_factor = sympy.ceiling((rank_shape - initial_tile_shape) / stride)
1266
+ # TODO: sometimes last_shape is 0, causing numerical instability
1267
+ # Currently, we are sometimes rounding up last shape.
1268
+ # last_shape = rank_shape - initial_tile_shape - stride*middle_shape_factor
1269
+ # has_last_shape = sympy.ceiling(last_shape/(last_shape+1))
1270
+ return StrideAndShape(
1271
+ stride,
1272
+ SequenceOfRepatedvalues(
1273
+ [
1274
+ RepeatedValue(initial_tile_shape, 1),
1275
+ RepeatedValue(stride, middle_shape_factor),
1276
+ # RepeatedValue(last_shape+0.01, has_last_shape)
1277
+ ]
1278
+ ),
1279
+ )
1280
+ # if node.tile_shape is not None:
1281
+ # tile_shape = node.tile_shape
1282
+
1283
+ # if node._assume_perfect_factor or known_perfect_factor(tile_shape, rank_shape):
1284
+ # factor = rank_shape / tile_shape
1285
+ # return StrideAndShape(tile_shape, RepeatedValue(tile_shape, factor))
1286
+ # else:
1287
+ # factor = sympy.ceiling(rank_shape / sympy.Min(tile_shape, rank_shape))
1288
+ # return make_possibly_different_last(tile_shape, factor, rank_shape)
1289
+ # elif node.loop_bound is not None:
1290
+ # factor = node.loop_bound
1291
+
1292
+ # if node._assume_perfect_factor or known_perfect_factor(factor, rank_shape):
1293
+ # tile_shape = rank_shape / factor
1294
+ # return StrideAndShape(tile_shape, RepeatedValue(tile_shape, factor))
1295
+ # else:
1296
+ # tile_shape = sympy.ceiling(rank_shape / sympy.Min(rank_shape, factor))
1297
+ # return make_possibly_different_last(tile_shape, factor, rank_shape)
1298
+
1299
+ # elif node.tile_pattern is not None:
1300
+ # stride = node.tile_pattern.tile_shape
1301
+ # initial_tile_shape = node.tile_pattern.initial_tile_shape
1302
+ # tile_shape = node.tile_pattern.tile_shape
1303
+
1304
+ # if initial_tile_shape is not None:
1305
+ # middle_shape_factor = sympy.ceiling((rank_shape - initial_tile_shape)/stride)
1306
+ # # TODO: sometimes last_shape is 0, causing numerical instability
1307
+ # # Currently, we are sometimes rounding up last shape.
1308
+ # # last_shape = rank_shape - initial_tile_shape - stride*middle_shape_factor
1309
+ # # has_last_shape = sympy.ceiling(last_shape/(last_shape+1))
1310
+ # return StrideAndShape(
1311
+ # stride,
1312
+ # SequenceOfRepatedvalues([
1313
+ # RepeatedValue(initial_tile_shape, 1),
1314
+ # RepeatedValue(stride, middle_shape_factor),
1315
+ # # RepeatedValue(last_shape+0.01, has_last_shape)
1316
+ # ])
1317
+ # )
1318
+
1319
+
1320
+ def known_perfect_factor(divisor, full_shape):
1321
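+ # A divisor is a known perfect factor only when both values are concrete ints
+ # and the divisor divides the shape exactly (remainder 0).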
+ return (
1322
+ isinstance(divisor, int)
1323
+ and isinstance(full_shape, int)
1324
+ and full_shape % divisor == 0
1325
+ )
1326
+
1327
+
1328
+ def make_possibly_different_last(common_tile_shape, factor, full_shape):
1329
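+ # Splits the iterations into (factor - 1) full tiles plus one possibly smaller
+ # final tile, e.g. full_shape = 10 with common_tile_shape = 4 and factor = 3
+ # yields tiles 4, 4, 2.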
+ last_shape = full_shape - common_tile_shape * (factor - 1)
1330
+ all_shapes = SequenceOfRepatedvalues(
1331
+ [RepeatedValue(common_tile_shape, factor - 1), RepeatedValue(last_shape, 1)]
1332
+ )
1333
+ return StrideAndShape(common_tile_shape, all_shapes)
1334
+
1335
+
1336
+ def insert_sympy_symbols(mapping: list[MappingNode], job: Job):
1337
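+ # Replaces SYMBOL placeholders on each Loop with fresh sympy symbols (the
+ # initial tile shape, when explored, is appended before the stride), gives
+ # every loop a symbolic iteration count, and returns the symbols, presumably
+ # for the tile-shape exploration to bind.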
+ loop_idx = 0
1338
+ symbols = []
1339
+ rank_var_with_initial = set()
1340
+ for i, node in enumerate(mapping):
1341
+ if not isinstance(node, Loop):
1342
+ continue
1343
+
1344
+ stride_halos = set()
1345
+ for t in job.spec.workload.einsums[job.einsum_name].tensor_names:
1346
+ for (rank, rank_variable), (stride, halo) in job.stride_and_halo[t].items():
1347
+ if rank_variable == node.rank_variable:
1348
+ stride_halos.add((stride, halo))
1349
+
1350
+ if len(stride_halos) == 0:
1351
+ raise RuntimeError(
1352
+ f"{repr(node.rank_variable)} not found in {job.stride_and_halo}"
1353
+ )
1354
+
1355
+ # We only explore imperfect for the outermost fused loops
1356
+ simple = (
1357
+ (len(stride_halos) <= 1 and next(iter(stride_halos)) == (1, 0))
1358
+ or node.rank_variable in rank_var_with_initial
1359
+ or not node._fused
1360
+ )
1361
+
1362
+ # NOTE: initial_tile_shape must be inserted into `symbols` before `stride`
1363
+ # because of the order of tile shape exploration.
1364
+ # TODO: there has to be a better way to do this.
1365
+ if simple: # Just use the stride!
1366
+ node.initial_tile_shape = None
1367
+ elif node.initial_tile_shape == SYMBOL:
1368
+ rank_var_with_initial.add(node.rank_variable)
1369
+ initial_tile_shape = sympy.symbols(
1370
+ f"initial{loop_idx}", positive=True, integer=True
1371
+ )
1372
+ symbols.append(initial_tile_shape)
1373
+ node.initial_tile_shape = initial_tile_shape
1374
+
1375
+ # TODO: Check for 0 < shape < 1 for loop bound target
1376
+ if job.rank_variable_bounds[node.rank_variable] == 1:
1377
+ node.tile_shape = 1
1378
+ elif node.tile_shape == SYMBOL:
1379
+ stride = sympy.symbols(f"stride{loop_idx}", positive=True, integer=True)
1380
+ symbols.append(stride)
1381
+ node.tile_shape = stride
1382
+
1383
+ # TODO: sometimes, a mapping is passed into the model twice.
1384
+ # E.g., after calling mapper, the model is called again for more
1385
+ # details.
1386
+ #
1387
+ # assert (
1388
+ # node.calculated_n_iterations is None
1389
+ # ), "Number of iterations is derived from the model. Do not set it!"
1390
+ node.calculated_n_iterations = sympy.symbols(
1391
+ f"n_iterations{loop_idx}", positive=True, integer=True
1392
+ )
1393
+
1394
+ loop_idx += 1
1395
+
1396
+ return symbols