accelforge-0.0.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (258)
  1. accelforge/__init__.py +21 -0
  2. accelforge/_accelerated_imports.py +16 -0
  3. accelforge/_deprecate/_simanneal/evalmapping.py +271 -0
  4. accelforge/_deprecate/_simanneal/mapspaceglobals.py +298 -0
  5. accelforge/_deprecate/_simanneal/simanneal.py +666 -0
  6. accelforge/_deprecate/_simanneal/tracking.py +105 -0
  7. accelforge/_deprecate/_simanneal/wrappers.py +218 -0
  8. accelforge/_deprecate/_simanneal2/__init__.py +7 -0
  9. accelforge/_deprecate/_simanneal2/simanneal.py +493 -0
  10. accelforge/_deprecate/_simanneal2/tracking.py +116 -0
  11. accelforge/_deprecate/compatibility_util.py +181 -0
  12. accelforge/_deprecate/layerdeduplication/__init__.py +2 -0
  13. accelforge/_deprecate/layerdeduplication/group_similar_einsums.py +160 -0
  14. accelforge/_deprecate/layerdeduplication/grouped_einsums.py +84 -0
  15. accelforge/_deprecate/mapping_filter_tags/__init__.py +2 -0
  16. accelforge/_deprecate/mapping_filter_tags/ffmt.py +212 -0
  17. accelforge/_deprecate/mapping_filter_tags/onesplit.py +24 -0
  18. accelforge/_deprecate/mapping_filter_tags/util.py +24 -0
  19. accelforge/_deprecate/tags.py +69 -0
  20. accelforge/_deprecate/viz/__init__.py +0 -0
  21. accelforge/_deprecate/viz/interactive.py +159 -0
  22. accelforge/_deprecate/viz/reservationtree.py +307 -0
  23. accelforge/_deprecate/viz/ski_slope.py +88 -0
  24. accelforge/_version.py +15 -0
  25. accelforge/examples.py +39 -0
  26. accelforge/frontend/__init__.py +10 -0
  27. accelforge/frontend/_binding.py +129 -0
  28. accelforge/frontend/_workload_isl/__init__.py +2 -0
  29. accelforge/frontend/_workload_isl/_isl.py +149 -0
  30. accelforge/frontend/_workload_isl/_symbolic.py +141 -0
  31. accelforge/frontend/arch copy.py +1544 -0
  32. accelforge/frontend/arch.py +1642 -0
  33. accelforge/frontend/config.py +63 -0
  34. accelforge/frontend/mapper/__init__.py +5 -0
  35. accelforge/frontend/mapper/ffm.py +126 -0
  36. accelforge/frontend/mapper/mapper.py +7 -0
  37. accelforge/frontend/mapper/metrics.py +30 -0
  38. accelforge/frontend/mapping/__init__.py +1 -0
  39. accelforge/frontend/mapping/mapping.py +1736 -0
  40. accelforge/frontend/model.py +14 -0
  41. accelforge/frontend/renames.py +150 -0
  42. accelforge/frontend/spec copy.py +230 -0
  43. accelforge/frontend/spec.py +301 -0
  44. accelforge/frontend/variables.py +12 -0
  45. accelforge/frontend/workload.py +952 -0
  46. accelforge/mapper/FFM/__init__.py +9 -0
  47. accelforge/mapper/FFM/_join_pmappings/__init__.py +0 -0
  48. accelforge/mapper/FFM/_join_pmappings/compatibility.py +653 -0
  49. accelforge/mapper/FFM/_join_pmappings/compress_pmappings.py +140 -0
  50. accelforge/mapper/FFM/_join_pmappings/join_pmappings.py +703 -0
  51. accelforge/mapper/FFM/_join_pmappings/pmapping_dataframe.py +901 -0
  52. accelforge/mapper/FFM/_join_pmappings/pmapping_group.py +337 -0
  53. accelforge/mapper/FFM/_make_pmappings/contraints/__init__.py +0 -0
  54. accelforge/mapper/FFM/_make_pmappings/contraints/constraints.py +360 -0
  55. accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/__init__.py +1 -0
  56. accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_loops.py +373 -0
  57. accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_pmapping_templates.py +463 -0
  58. accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_reservations.py +95 -0
  59. accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_storage_order.py +382 -0
  60. accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_storages.py +155 -0
  61. accelforge/mapper/FFM/_make_pmappings/make_pmappings.py +411 -0
  62. accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/__init__.py +1 -0
  63. accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/make_pmappings_from_templates.py +407 -0
  64. accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/make_tile_shapes.py +1681 -0
  65. accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/run_model.py +170 -0
  66. accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/symbol_relations.py +174 -0
  67. accelforge/mapper/FFM/_make_pmappings/pmapper_job.py +282 -0
  68. accelforge/mapper/FFM/_pareto_df/df_convention.py +273 -0
  69. accelforge/mapper/FFM/_pareto_df/pareto copy.py +836 -0
  70. accelforge/mapper/FFM/_pareto_df/pareto.py +508 -0
  71. accelforge/mapper/FFM/data.py +61 -0
  72. accelforge/mapper/FFM/main copy.py +236 -0
  73. accelforge/mapper/FFM/main.py +208 -0
  74. accelforge/mapper/FFM/mappings.py +510 -0
  75. accelforge/mapper/FFM/pmappings.py +310 -0
  76. accelforge/mapper/__init__.py +4 -0
  77. accelforge/mapper.py +0 -0
  78. accelforge/model/__init__.py +1 -0
  79. accelforge/model/_looptree/__init__.py +0 -0
  80. accelforge/model/_looptree/accesses.py +335 -0
  81. accelforge/model/_looptree/capacity/__init__.py +1 -0
  82. accelforge/model/_looptree/capacity/aggregators.py +36 -0
  83. accelforge/model/_looptree/capacity/capacity.py +47 -0
  84. accelforge/model/_looptree/energy.py +150 -0
  85. accelforge/model/_looptree/equivalent_ranks.py +29 -0
  86. accelforge/model/_looptree/latency/__init__.py +1 -0
  87. accelforge/model/_looptree/latency/latency.py +98 -0
  88. accelforge/model/_looptree/latency/memory.py +120 -0
  89. accelforge/model/_looptree/latency/processors.py +92 -0
  90. accelforge/model/_looptree/mapping_utilities.py +71 -0
  91. accelforge/model/_looptree/reuse/__init__.py +4 -0
  92. accelforge/model/_looptree/reuse/isl/__init__.py +1 -0
  93. accelforge/model/_looptree/reuse/isl/des.py +59 -0
  94. accelforge/model/_looptree/reuse/isl/isl_functions.py +374 -0
  95. accelforge/model/_looptree/reuse/isl/mapping_to_isl/__init__.py +4 -0
  96. accelforge/model/_looptree/reuse/isl/mapping_to_isl/analyze_mapping.py +297 -0
  97. accelforge/model/_looptree/reuse/isl/mapping_to_isl/skews_from_mapping.py +236 -0
  98. accelforge/model/_looptree/reuse/isl/mapping_to_isl/tiling.py +685 -0
  99. accelforge/model/_looptree/reuse/isl/mapping_to_isl/types.py +188 -0
  100. accelforge/model/_looptree/reuse/isl/spatial.py +260 -0
  101. accelforge/model/_looptree/reuse/isl/temporal.py +182 -0
  102. accelforge/model/_looptree/reuse/symbolic/__init__.py +1 -0
  103. accelforge/model/_looptree/reuse/symbolic/symbolic copy 2.py +1346 -0
  104. accelforge/model/_looptree/reuse/symbolic/symbolic copy.py +1408 -0
  105. accelforge/model/_looptree/reuse/symbolic/symbolic.py +1396 -0
  106. accelforge/model/_looptree/run.py +122 -0
  107. accelforge/model/_looptree/types.py +26 -0
  108. accelforge/model/_looptree/visualization/__init__.py +0 -0
  109. accelforge/model/_looptree/visualization/occupancy.py +11 -0
  110. accelforge/model/main.py +222 -0
  111. accelforge/plotting/__init__.py +2 -0
  112. accelforge/plotting/mappings.py +219 -0
  113. accelforge/plotting/specs.py +57 -0
  114. accelforge/util/__init__.py +4 -0
  115. accelforge/util/_base_analysis_types.py +24 -0
  116. accelforge/util/_basetypes.py +1089 -0
  117. accelforge/util/_frozenset.py +36 -0
  118. accelforge/util/_isl.py +29 -0
  119. accelforge/util/_itertools.py +14 -0
  120. accelforge/util/_mathfuncs.py +57 -0
  121. accelforge/util/_parse_expressions.py +339 -0
  122. accelforge/util/_picklecache.py +32 -0
  123. accelforge/util/_setexpressions.py +268 -0
  124. accelforge/util/_sympy/__init__.py +0 -0
  125. accelforge/util/_sympy/broadcast_max.py +18 -0
  126. accelforge/util/_visualization.py +112 -0
  127. accelforge/util/_yaml.py +579 -0
  128. accelforge/util/parallel.py +193 -0
  129. accelforge-0.0.1.dist-info/METADATA +64 -0
  130. accelforge-0.0.1.dist-info/RECORD +258 -0
  131. accelforge-0.0.1.dist-info/WHEEL +5 -0
  132. accelforge-0.0.1.dist-info/licenses/LICENSE +19 -0
  133. accelforge-0.0.1.dist-info/top_level.txt +5 -0
  134. docs/_build/html/_sources/fastfusion.frontend.mapper.rst.txt +37 -0
  135. docs/_build/html/_sources/fastfusion.frontend.rst.txt +70 -0
  136. docs/_build/html/_sources/fastfusion.frontend.workload.rst.txt +21 -0
  137. docs/_build/html/_sources/fastfusion.mapper.FFM.rst.txt +37 -0
  138. docs/_build/html/_sources/fastfusion.mapper.rst.txt +18 -0
  139. docs/_build/html/_sources/fastfusion.rst.txt +20 -0
  140. docs/_build/html/_sources/fastfusion.util.rst.txt +21 -0
  141. docs/_build/html/_sources/index.rst.txt +87 -0
  142. docs/_build/html/_sources/modules.rst.txt +7 -0
  143. docs/_build/html/_sources/notes/citation.rst.txt +45 -0
  144. docs/_build/html/_sources/notes/definitions.rst.txt +43 -0
  145. docs/_build/html/_sources/notes/faqs.rst.txt +39 -0
  146. docs/_build/html/_sources/notes/modeling/accelerator_energy_latency.rst.txt +72 -0
  147. docs/_build/html/_sources/notes/modeling/component_energy_area.rst.txt +96 -0
  148. docs/_build/html/_sources/notes/modeling/mapping.rst.txt +100 -0
  149. docs/_build/html/_sources/notes/modeling.rst.txt +33 -0
  150. docs/_build/html/_sources/notes/parsing/arithmetic_parsing.rst.txt +136 -0
  151. docs/_build/html/_sources/notes/parsing/setexpressions.rst.txt +63 -0
  152. docs/_build/html/_sources/notes/parsing/yaml_parsing.rst.txt +176 -0
  153. docs/_build/html/_sources/notes/quickstart_and_installation.rst.txt +9 -0
  154. docs/_build/html/_sources/notes/spec/architecture.rst.txt +133 -0
  155. docs/_build/html/_sources/notes/spec/mapping.rst.txt +12 -0
  156. docs/_build/html/_sources/notes/spec/workload.rst.txt +83 -0
  157. docs/_build/html/_sources/notes/spec.rst.txt +36 -0
  158. docs/source/_ext/include_attrs.py +213 -0
  159. docs/source/_ext/include_docstring.py +364 -0
  160. docs/source/_ext/include_functions.py +154 -0
  161. docs/source/_ext/include_notebook.py +131 -0
  162. docs/source/_ext/include_yaml.py +119 -0
  163. docs/source/_ext/inherited_attributes.py +222 -0
  164. docs/source/_ext/paths.py +4 -0
  165. docs/source/conf.py +79 -0
  166. examples/arches/compute_in_memory/_include.yaml +74 -0
  167. examples/arches/compute_in_memory/_include_functions.py +229 -0
  168. examples/arches/compute_in_memory/_load_spec.py +57 -0
  169. examples/arches/compute_in_memory/components/c2c_multiplier.py +181 -0
  170. examples/arches/compute_in_memory/components/dac_c2c_r2r.py +605 -0
  171. examples/arches/compute_in_memory/components/misc.py +195 -0
  172. examples/arches/compute_in_memory/components/util/bit_functions.py +51 -0
  173. examples/arches/compute_in_memory/components/zero_comparator.py +92 -0
  174. examples/arches/compute_in_memory/isaac.yaml +233 -0
  175. examples/arches/compute_in_memory/memory_cells/ecram_demo.yaml +63 -0
  176. examples/arches/compute_in_memory/memory_cells/rram_example.yaml +63 -0
  177. examples/arches/compute_in_memory/memory_cells/rram_isaac_isca_2016.yaml +64 -0
  178. examples/arches/compute_in_memory/memory_cells/rram_neurosim_default.yaml +63 -0
  179. examples/arches/compute_in_memory/memory_cells/rram_raella_isca_2023.yaml +70 -0
  180. examples/arches/compute_in_memory/memory_cells/rram_wan_nature_2022.yaml +63 -0
  181. examples/arches/compute_in_memory/memory_cells/sram_colonnade_jssc_2021.yaml +63 -0
  182. examples/arches/compute_in_memory/memory_cells/sram_example.yaml +63 -0
  183. examples/arches/compute_in_memory/memory_cells/sram_jia_jssc_2020.yaml +63 -0
  184. examples/arches/compute_in_memory/memory_cells/sram_sinangil_jssc_2021.yaml +63 -0
  185. examples/arches/compute_in_memory/memory_cells/sram_wang_vlsi_2022.yaml +63 -0
  186. examples/arches/compute_in_memory/wang_vlsi_2022.yaml +289 -0
  187. examples/arches/eyeriss.yaml +68 -0
  188. examples/arches/fanout_variations/at_glb.yaml +31 -0
  189. examples/arches/fanout_variations/at_glb_with_fanout_node.yaml +34 -0
  190. examples/arches/fanout_variations/at_mac.yaml +31 -0
  191. examples/arches/fanout_variations/at_mac_with_constraints.yaml +38 -0
  192. examples/arches/fanout_variations/at_mac_with_fanout_node.yaml +34 -0
  193. examples/arches/nvdla.yaml +47 -0
  194. examples/arches/simple.yaml +28 -0
  195. examples/arches/tpu_v4i.yaml +67 -0
  196. examples/mappings/unfused_matmuls_to_simple.yaml +33 -0
  197. examples/misc/component_annotated.yaml +33 -0
  198. examples/workloads/gpt3_6.7B.yaml +124 -0
  199. examples/workloads/matmuls.yaml +20 -0
  200. examples/workloads/mobilenet_28.yaml +81 -0
  201. examples/workloads/mobilenet_various_separate.yaml +106 -0
  202. examples/workloads/three_matmuls_annotated.yaml +59 -0
  203. notebooks/.ipynb_checkpoints/fastfusion_arch_study_michael-checkpoint.ipynb +359 -0
  204. notebooks/compute_in_memory/_scripts.py +339 -0
  205. notebooks/compute_in_memory/isaac.guide.ipynb +270 -0
  206. notebooks/compute_in_memory/wang_vlsi_2022.ipynb +602 -0
  207. notebooks/paths.py +4 -0
  208. notebooks/tutorials/.ipynb_checkpoints/1_FFM-checkpoint.ipynb +3110 -0
  209. notebooks/tutorials/FFM.ipynb +3498 -0
  210. notebooks/tutorials/_include.py +48 -0
  211. notebooks/tutorials/component_energy_area.ipynb +363 -0
  212. tests/Q_mapping.yaml +38 -0
  213. tests/__init__.py +0 -0
  214. tests/conv.mapping.yaml +27 -0
  215. tests/conv.workload.yaml +13 -0
  216. tests/conv_sym.mapping.yaml +43 -0
  217. tests/copy.mapping.yaml +35 -0
  218. tests/copy.workload.yaml +15 -0
  219. tests/distribuffers/__init__.py +0 -0
  220. tests/distribuffers/multicast/test_cases.yaml +482 -0
  221. tests/distribuffers/spec/binding/valid_bindings.yaml +97 -0
  222. tests/distribuffers/spec/distributed.yaml +100 -0
  223. tests/distribuffers/spec/logical_arch.yaml +32 -0
  224. tests/distribuffers/spec/physical_arch.yaml +69 -0
  225. tests/distribuffers/test_binding.py +48 -0
  226. tests/frontend/__init__.py +0 -0
  227. tests/frontend/test_mapping_viz.py +52 -0
  228. tests/mapper/__init__.py +0 -0
  229. tests/mapper/configs/conv1d/conv1d.mapping.yaml +31 -0
  230. tests/mapper/configs/conv1d/conv1d.workload.yaml +11 -0
  231. tests/mapper/configs/two_conv1d/two_conv1d.expected.yaml +38 -0
  232. tests/mapper/configs/two_conv1d/two_conv1d.mapping.yaml +54 -0
  233. tests/mapper/configs/two_conv1d/two_conv1d.workload.yaml +19 -0
  234. tests/mapper/test_mapping_to_isl.py +90 -0
  235. tests/mapper/test_spatial_reuse_analysis.py +67 -0
  236. tests/mapper/test_temporal_reuse_analysis.py +56 -0
  237. tests/mapper/util.py +58 -0
  238. tests/matmul.mapping.yaml +29 -0
  239. tests/matmul.workload.yaml +12 -0
  240. tests/matmul_spatial.mapping.yaml +44 -0
  241. tests/mha.renames.yaml +65 -0
  242. tests/mha.workload.yaml +67 -0
  243. tests/mha.yaml +59 -0
  244. tests/mha_full.workload.yaml +67 -0
  245. tests/mobilenet.workload.yaml +35 -0
  246. tests/mobilenet_long.workload.yaml +64 -0
  247. tests/pmappingcache.py +24 -0
  248. tests/processing_stage.arch.yaml +40 -0
  249. tests/snowcat.arch.yaml +36 -0
  250. tests/test_ffm_join_pmappings.py +106 -0
  251. tests/test_ffm_make_pmappings.py +82 -0
  252. tests/test_ffm_make_tile_shapes.py +49 -0
  253. tests/test_mapper.py +100 -0
  254. tests/test_model.py +37 -0
  255. tests/test_plotting.py +72 -0
  256. tests/test_processing_stage.py +46 -0
  257. tests/test_symbolic_model.py +248 -0
  258. tests/test_workload.py +141 -0
accelforge/model/_looptree/reuse/symbolic/symbolic copy.py
@@ -0,0 +1,1408 @@
1
+ import copy
2
+ from dataclasses import dataclass, field
3
+ import itertools
4
+ from accelforge.frontend.mapping import (
5
+ Compute,
6
+ Mapping,
7
+ Nested,
8
+ Pipeline,
9
+ ProcessingStage,
10
+ Reservation,
11
+ Sequential,
12
+ Spatial,
13
+ Split,
14
+ Storage,
15
+ Temporal,
16
+ )
17
+ from typing import Any
18
+
19
+ from accelforge.frontend import arch
20
+ import accelforge.frontend.mapping as mapping_spec
21
+ from accelforge.frontend.mapping import (
22
+ Mapping,
23
+ MappingNode,
24
+ Nested,
25
+ Spatial,
26
+ Temporal,
27
+ Storage,
28
+ Reservation,
29
+ Loop,
30
+ TensorHolder,
31
+ ProcessingStage,
32
+ )
33
+ from accelforge.frontend.workload import (
34
+ Workload,
35
+ TensorName,
36
+ isl_expression_has_variable,
37
+ )
38
+ from accelforge.frontend._workload_isl._isl import get_rank_variable_bounds
39
+ from accelforge.frontend._workload_isl._symbolic import (
40
+ get_projection_expr,
41
+ get_rank_variable_relevancy,
42
+ compute_dense_tile_occupancy,
43
+ Irrelevant,
44
+ Relevant,
45
+ PartiallyRelevant,
46
+ )
47
+
48
+ from accelforge.model._looptree.types import Buffet
49
+
50
+ from accelforge.mapper.FFM._make_pmappings.pmapper_job import Job
51
+ from accelforge.util._sympy.broadcast_max import Min, Max
52
+
53
+ import sympy
54
+
55
+
56
+ SYMBOL = "symbol"
57
+ IMPERFECT = False
58
+
59
+
60
+ @dataclass(eq=True, frozen=True)
61
+ class Compute:
62
+ einsum: str
63
+ level: str
64
+
65
+
66
+ class Uninitialized:
67
+ def __init__(self):
68
+ pass
69
+
70
+ def __str__(self):
71
+ return "Uninitialized"
72
+
73
+ def __repr__(self):
74
+ return "Uninitialized()"
75
+
76
+ def __rmul__(self, other):
77
+ return self * other
78
+
79
+ def __mul__(self, other):
80
+ return self
81
+
82
+ def __radd__(self, other):
83
+ return self + other
84
+
85
+ def __add__(self, other):
86
+ return self
87
+
88
+
89
+ # TODO: unsure if this is needed. If the sympy symbol is created with the
90
+ # correct assumption (e.g., positive), this should be automatic.
91
+ def min_nonzero(a: Any, b: Any) -> Any:
92
+ if a == 0:
93
+ return b
94
+ if b == 0:
95
+ return a
96
+ return Min(a, b)
97
+
98
+
99
+ def max_dict(a: dict[Any, Any], b: dict[Any, Any]) -> dict[Any, Any]:
100
+ new = {**a}
101
+ for key, value in b.items():
102
+ new[key] = Max(new[key], value) if key in new else value
103
+ assert isinstance(new, dict)
104
+ return new
105
+
106
+
107
+ @dataclass
108
+ class BuffetStats:
109
+ total_reads_to_parent: Any = field(default=0)
110
+ total_writes_to_parent: Any = field(default=0)
111
+ max_per_parent_reads_to_parent: Any = field(default=0)
112
+ max_per_parent_writes_to_parent: Any = field(default=0)
113
+
114
+ total_reads_to_peer: Any = field(default=0)
115
+ total_writes_to_peer: Any = field(default=0)
116
+ max_per_unit_reads_to_peer: Any = field(default=0)
117
+ max_per_unit_writes_to_peer: Any = field(default=0)
118
+
119
+ total_writes_to_child: Any = field(default=0)
120
+ total_reads_to_child: Any = field(default=0)
121
+ max_per_unit_writes_to_child: Any = field(default=0)
122
+ max_per_unit_reads_to_child: Any = field(default=0)
123
+
124
+ # Skip the first iteration of temporal loops for data that is written
125
+ total_skipped_first_reads_to_parent: Any = field(default=0)
126
+ total_skipped_first_reads_to_peer: Any = field(default=0)
127
+ total_skipped_first_writes_to_child: Any = field(default=0)
128
+ min_per_parent_skipped_first_reads_to_parent: Any = field(default=0)
129
+ min_per_unit_skipped_first_writes_to_peer: Any = field(default=0)
130
+ min_per_unit_skipped_first_writes_to_child: Any = field(default=0)
131
+
132
+ max_occupancy: Any = field(default=0)
133
+ _n_loops_above: int = field(default=1)
134
+
135
+ persistent: bool = field(default=False)
136
+
137
+ _write_scale: float = field(default=None)
138
+ _read_scale: float = field(default=None)
139
+ _count_upward_movement: bool = field(default=None)
140
+ _count_downward_movement: bool = field(default=None)
141
+
142
+ @property
143
+ def write_scale(self) -> Any:
144
+ return self._write_scale
145
+
146
+ @write_scale.setter
147
+ def write_scale(self, value: Any):
148
+ assert self._write_scale is None or self._write_scale == value, "BUG"
149
+ self._write_scale = value
150
+
151
+ @property
152
+ def read_scale(self) -> Any:
153
+ return self._read_scale
154
+
155
+ @read_scale.setter
156
+ def read_scale(self, value: Any):
157
+ assert self._read_scale is None or self._read_scale == value, "BUG"
158
+ self._read_scale = value
159
+
160
+ @property
161
+ def count_upward_movement(self) -> bool:
162
+ return self._count_upward_movement
163
+
164
+ @count_upward_movement.setter
165
+ def count_upward_movement(self, value: bool):
166
+ assert (
167
+ self._count_upward_movement is None or self._count_upward_movement == value
168
+ ), "BUG"
169
+ self._count_upward_movement = value
170
+
171
+ @property
172
+ def count_downward_movement(self) -> bool:
173
+ return self._count_downward_movement
174
+
175
+ @count_downward_movement.setter
176
+ def count_downward_movement(self, value: bool):
177
+ assert (
178
+ self._count_downward_movement is None
179
+ or self._count_downward_movement == value
180
+ ), "BUG"
181
+ self._count_downward_movement = value
182
+
183
+ @property
184
+ def n_loops_above(self) -> int:
185
+ if self.persistent:
186
+ return -1
187
+ return self._n_loops_above
188
+
189
+ @n_loops_above.setter
190
+ def n_loops_above(self, value: int):
191
+ self._n_loops_above = value
192
+
193
+ def repeat_temporal(self, factor: int, is_fully_relevant: bool) -> "BuffetStats":
194
+ new = copy.copy(self)
195
+ for attr in self.__dict__:
196
+ if not attr.startswith(("total_", "max_", "min_")):
197
+ continue
198
+ if "skipped_first" in attr and not is_fully_relevant:
199
+ continue # First actions occur once per relevant iteration.
200
+ if attr == "max_occupancy":
201
+ continue # Max occupancy is not affected by temporal loops above
202
+ setattr(new, attr, getattr(new, attr) * factor)
203
+ return new
204
+
205
+ def repeat_spatial(self, factor: int, reuse_parent_accesses: bool) -> "BuffetStats":
206
+ new = copy.copy(self)
207
+ for attr in self.__dict__:
208
+ if not attr.startswith(("total_", "max_", "min_")):
209
+ continue
210
+ if "parent" in attr and reuse_parent_accesses:
211
+ continue # If parent accesses are reused, no need to multiply
212
+ if "per_unit" in attr:
213
+ continue # Spatial fanout doesn't affect per-unit stats
214
+ if attr == "max_occupancy":
215
+ continue # Max occupancy is not affected by temporal loops above
216
+ setattr(new, attr, getattr(new, attr) * factor)
217
+ return new
218
+
219
+ def max(self, **kwargs: Any):
220
+ for key, value in kwargs.items():
221
+ setattr(self, key, Max(getattr(self, key), value))
222
+
223
+ def min(self, **kwargs: Any):
224
+ for key, value in kwargs.items():
225
+ setattr(self, key, Min(getattr(self, key), value))
226
+
227
+ def __add__(self, other: "BuffetStats") -> "BuffetStats":
228
+ new = copy.copy(self)
229
+ for attr in self.__dict__:
230
+ if attr.startswith("min_"):
231
+ setattr(
232
+ new, attr, min_nonzero(getattr(self, attr), getattr(other, attr))
233
+ )
234
+ elif attr.startswith("max_"):
235
+ setattr(new, attr, Max(getattr(self, attr), getattr(other, attr)))
236
+ elif attr.startswith("total_"):
237
+ setattr(new, attr, getattr(self, attr) + getattr(other, attr))
238
+ elif getattr(self, attr) is None:
239
+ setattr(new, attr, getattr(other, attr))
240
+ elif getattr(other, attr) is None:
241
+ setattr(new, attr, getattr(self, attr))
242
+ else:
243
+ assert getattr(self, attr) == getattr(
244
+ other, attr
245
+ ), f"BUG: {attr} is different. self: {getattr(self, attr)} other: {getattr(other, attr)}"
246
+ return new
247
+
248
+ def __iadd__(self, other: "BuffetStats") -> "BuffetStats":
249
+ new = self + other
250
+ for key, value in new.__dict__.items():
251
+ setattr(self, key, value)
252
+ return self
253
+
254
+ def net_total_read_actions(self) -> Any:
255
+ return self.total_read_actions - self.total_skipped_first_read_actions
256
+
257
+ def net_total_write_actions(self) -> Any:
258
+ return self.total_write_actions - self.total_skipped_first_write_actions
259
+
260
+ def net_max_per_unit_read_actions(self) -> Any:
261
+ return (
262
+ self.max_per_unit_read_actions
263
+ - self.min_per_unit_skipped_first_read_actions
264
+ )
265
+
266
+ def net_max_per_unit_write_actions(self) -> Any:
267
+ return (
268
+ self.max_per_unit_write_actions
269
+ - self.min_per_unit_skipped_first_write_actions
270
+ )
271
+
272
+ def _get_actions(
273
+ self,
274
+ prefix: str,
275
+ ) -> float | sympy.Expr:
276
+ # My reads to parent go down (parent->me), my reads to child go up (child->me)
277
+ if "read" in prefix:
278
+ parent, child = self.count_downward_movement, self.count_upward_movement
279
+ scale = self.write_scale # Write to other = read to self
280
+ # My writes to parent go up (me->parent), my writes to child go down (me->child)
281
+ elif "write" in prefix:
282
+ parent, child = self.count_upward_movement, self.count_downward_movement
283
+ scale = self.read_scale # Write to other = read to self
284
+ else:
285
+ raise ValueError(f"Invalid prefix: {prefix}")
286
+
287
+ total = getattr(self, f"{prefix.replace('write', 'read')}s_to_peer", 0)
288
+ total += getattr(self, f"{prefix}s_to_parent", 0) if parent else 0
289
+ total += getattr(self, f"{prefix}s_to_child", 0) if child else 0
290
+
291
+ return total * scale
292
+
293
+ @property
294
+ def total_write_actions(self):
295
+ return self._get_actions("total_read") # Read to other = write to self
296
+
297
+ @property
298
+ def total_read_actions(self):
299
+ return self._get_actions("total_write") # Write to other = read to self
300
+
301
+ @property
302
+ def max_per_unit_write_actions(self):
303
+ return self._get_actions("max_per_unit_read") # Read to other = write to self
304
+
305
+ @property
306
+ def max_per_unit_read_actions(self):
307
+ return self._get_actions("max_per_unit_write") # Write to other = read to self
308
+
309
+ @property
310
+ def total_skipped_first_write_actions(self):
311
+ # Read to other = write to self
312
+ return self._get_actions("total_skipped_first_read")
313
+
314
+ @property
315
+ def min_per_unit_skipped_first_write_actions(self):
316
+ # Read to other = write to self
317
+ return self._get_actions("min_per_unit_skipped_first_read")
318
+
319
+ @property
320
+ def total_skipped_first_read_actions(self):
321
+ # Write to other = read to self
322
+ return self._get_actions("total_skipped_first_write")
323
+
324
+ @property
325
+ def min_per_unit_skipped_first_read_actions(self):
326
+ # Write to other = read to self
327
+ return self._get_actions("min_per_unit_skipped_first_write")
328
+
329
+
330
+ def blank_buffet_stats() -> BuffetStats:
331
+ stats = BuffetStats()
332
+ stats.n_loops_above = None # Inherit from whoever is added to this
333
+ return stats
334
+
335
+
336
+ @dataclass
337
+ class ComputeStats:
338
+ total_ops: Any = field(default=0)
339
+ max_per_unit_ops: Any = field(default=0)
340
+ # "max" below refers to the longest latency of any iteration
341
+ max_latency: Any = field(default=0)
342
+ # Mapping from the loop-index (0 at top) to the latency of the first
343
+ # iteration of that loop. "Max" because we may have loops above that and we
344
+ # will take the maximum of the firsts.
345
+ max_first_latency: dict[int, Any] = field(default_factory=dict)
346
+
347
+ def repeat_temporal(self, factor: int) -> "ComputeStats":
348
+ new = copy.copy(self)
349
+ new.total_ops *= factor
350
+ new.max_per_unit_ops *= factor
351
+ new.max_latency *= factor
352
+ # NOTE: max_first_latency does not change
353
+ return new
354
+
355
+ def repeat_spatial(self, factor: int) -> "ComputeStats":
356
+ new = copy.copy(self)
357
+ new.total_ops *= factor
358
+ return new
359
+
360
+ def __add__(self, other: "ComputeStats") -> "ComputeStats":
361
+ new = copy.copy(self)
362
+ new.total_ops += other.total_ops
363
+ new.max_per_unit_ops += other.max_per_unit_ops
364
+ new.max_latency += other.max_latency
365
+ # max_first_latency is only ever updated across loops ABOVE the loop
366
+ # for which we calculated that first latency, so we should MAX
367
+ new.max_first_latency = max_dict(
368
+ self.max_first_latency, other.max_first_latency
369
+ ) # FIRST LATENCY
370
+ return new
371
+
372
+ def combine_temporal(self, other: "ComputeStats"):
373
+ self.total_ops += other.total_ops
374
+ self.max_per_unit_ops += other.max_per_unit_ops
375
+ self.max_latency += other.max_latency
376
+ # max_first_latency is only ever updated across loops ABOVE the loop
377
+ # for which we calculated that first latency, so we should MAX
378
+ self.max_first_latency = max_dict(
379
+ self.max_first_latency, other.max_first_latency
380
+ ) # FIRST LATENCY
381
+
382
+ def combine_spatial(self, other: "ComputeStats"):
383
+ self.total_ops += other.total_ops
384
+ self.max_per_unit_ops = Max(self.max_per_unit_ops, other.max_per_unit_ops)
385
+ self.max_latency = Max(self.max_latency, other.max_latency)
386
+ # max_first_latency is only ever updated across loops ABOVE the loop
387
+ # for which we calculated that first latency, so we should MAX
388
+ self.max_first_latency = max_dict(
389
+ self.max_first_latency, other.max_first_latency
390
+ ) # FIRST LATENCY
391
+
392
+
393
+ @dataclass
394
+ class SymbolicAnalysisOutput:
395
+ compute_stats: dict[Compute, ComputeStats] = field(default_factory=dict)
396
+
397
+ buffet_stats: dict[Buffet, BuffetStats] = field(default_factory=dict)
398
+
399
+ # Mapping [level, einsum] to the fanout
400
+ fanout: dict[(Buffet, str), int] = field(default_factory=dict)
401
+
402
+ # Mapping [einsum] to the number of temporal steps
403
+ temporal_steps: dict[str, int] = field(default_factory=dict)
404
+
405
+ symbols: list[sympy.Symbol] = field(default_factory=list)
406
+
407
+ # tensor to the mapping for that particular tensor
408
+ tensor2mapping: dict[TensorName, Mapping] = field(default_factory=dict)
409
+
410
+ def get_buffet_for_tensor(self, tensor: TensorName) -> Buffet:
411
+ for buffet in self.buffet_stats:
412
+ if buffet.tensor == tensor:
413
+ return buffet
414
+ raise ValueError(f"Buffet for tensor {tensor} not found")
415
+
416
+ def max(self, **kwargs: Any):
417
+ for key, value in kwargs.items():
418
+ assert key in [
419
+ "compute_stats",
420
+ "stats",
421
+ "fanout",
422
+ "temporal_steps",
423
+ ]
424
+ previous = getattr(self, key)
425
+ for k, v in value.items():
426
+ previous.setdefault(k, {})
427
+ for k2, v2 in v.items():
428
+ if k2 in previous[k]:
429
+ previous[k][k2] = Max(previous[k][k2], v2)
430
+ else:
431
+ previous[k][k2] = v2
432
+
433
+ def get_child_buffet_stats(self, buffet: Buffet) -> BuffetStats:
434
+ seen = False
435
+ for child_buffet, child_stats in reversed(self.buffet_stats.items()):
436
+ if not seen:
437
+ seen = child_buffet == buffet
438
+ continue
439
+ if child_buffet.tensor == buffet.tensor:
440
+ return child_stats
441
+ return None
442
+
443
+ def sum_buffet_stats_per_level(self) -> dict[str, BuffetStats]:
444
+ result: dict[str, BuffetStats] = {}
445
+ for buffet, stats in self.buffet_stats.items():
446
+ result.setdefault(buffet.level, blank_buffet_stats())
447
+ result[buffet.level] += stats
448
+ return result
449
+
450
+ def add_buffet_stats_and_symbols(self, other: "SymbolicAnalysisOutput"):
451
+ assert not (set(self.buffet_stats) & set(other.buffet_stats)), "BUG"
452
+ self.buffet_stats.update(other.buffet_stats)
453
+ # if self.temporal_steps != other.temporal_steps:
454
+ # print(f'Temporal steps are different.')
455
+ # print(f'\tmine: {self.temporal_steps}')
456
+ # print(f'\tother: {other.temporal_steps}')
457
+ # assert self.temporal_steps == other.temporal_steps, "BUG"
458
+ self.temporal_steps.update(other.temporal_steps)
459
+ self.symbols.extend([s for s in other.symbols if s not in self.symbols])
460
+ # Assert compute stats are the same
461
+ # assert self.compute_stats == other.compute_stats, "BUG"
462
+
463
+
464
+ @dataclass
465
+ class AnalysisInfo:
466
+ """Information needed within the analysis step by multiple functions that
467
+ can be computed once at the beginning.
468
+ """
469
+
470
+ mapping: Mapping
471
+ workload: Workload
472
+ full_rank_variable_shapes: dict
473
+ all_tensors: set
474
+
475
+ einsum_tensor_to_projection: dict
476
+ tensor_to_relevancy: dict
477
+ tensor_to_backer_id: dict[TensorName, int]
478
+
479
+ is_copy_operation: TensorName | None
480
+
481
+ job: Job
482
+
483
+ tensor_to_reservation_backer_id: dict[TensorName, int] = field(default_factory=dict)
484
+
485
+ # We track first latency for these nodes (should be Temporal)
486
+ last_temporal_node_idx: int = None
487
+ """
488
+ node idx of the last (above) temporal node
489
+ """
490
+ idxs_to_track_first_latency: set[int] = field(default_factory=set)
491
+ """
492
+ node idxs for which we track first latency
493
+ """
494
+
495
+
496
+ def quick_insert_reservation_nodes(job: Job) -> list[MappingNode]:
497
+ mapping = list(job.mapping.nodes)
498
+ workload = job.spec.workload
499
+
500
+ # TODO: Subclass reservation with TensorReservation or something so that we can
501
+ # track which are for tensors and which are for non-tensor resources.
502
+
503
+ info = AnalysisInfo(
504
+ mapping=None,
505
+ workload=workload,
506
+ full_rank_variable_shapes=None,
507
+ all_tensors=None,
508
+ einsum_tensor_to_projection=None,
509
+ tensor_to_relevancy=job.tensor_to_relevancy,
510
+ tensor_to_backer_id=None,
511
+ is_copy_operation=None,
512
+ job=None,
513
+ )
514
+ insert_reservation_nodes(mapping, info)
515
+ m = Mapping(nodes=mapping)
516
+ m._n_loop_orders = job.mapping._n_loop_orders
517
+ return m
518
+
519
+
520
+ def convert_to_copy(
521
+ mapping: list[MappingNode], workload: Workload
522
+ ) -> tuple[list[MappingNode], dict[TensorName, int]]:
523
+ mapping = copy.deepcopy(mapping)
524
+
525
+ # Calculate this BEFORE we modify the mapping. We're going to have the copy source
526
+ # tensor moving upward sometimes, and we don't want the backing tensor holder
527
+ tensor_to_backer_id = get_tensor_to_backer_id(mapping)
528
+
529
+ first_input_tensor = workload.einsums[mapping[-1].einsum].copy_source_tensor()
530
+
531
+ for node in mapping:
532
+ if isinstance(node, TensorHolder):
533
+ if node.tensors:
534
+ node.tensors = [first_input_tensor]
535
+ node._lower = False
536
+
537
+ to_remove = []
538
+ i = 0
539
+ while i < len(mapping):
540
+ node = mapping[i]
541
+ if isinstance(node, TensorHolder):
542
+ j = i + 1
543
+ while j < len(mapping):
544
+ node2 = mapping[j]
545
+ if (
546
+ isinstance(node2, TensorHolder)
547
+ and node.component == node2.component
548
+ ):
549
+ mapping.pop(j)
550
+ else:
551
+ j += 1
552
+ i += 1
553
+ mapping = [node for node in mapping if node not in to_remove]
554
+
555
+ return mapping, tensor_to_backer_id
556
+
557
+
558
+ def analyze_reuse_and_add_reservations_to_mapping(
559
+ job: Job,
560
+ ) -> SymbolicAnalysisOutput:
561
+ mapping = job.mapping.nodes
562
+ workload = job.spec.workload
563
+ einsum_name = mapping[-1].einsum
564
+
565
+ is_copy_operation = workload.einsums[einsum_name].is_copy_operation
566
+ symbols = insert_sympy_symbols(job.mapping.nodes, job)
567
+
568
+ if is_copy_operation:
569
+ mapping, tensor_to_backer_id = convert_to_copy(mapping, workload)
570
+ else:
571
+ tensor_to_backer_id = get_tensor_to_backer_id(mapping)
572
+
573
+ job.mapping = quick_insert_reservation_nodes(job)
574
+ # print(f'Job mapping: {job.mapping.compact_str()}')
575
+ # for n in job.mapping.nodes:
576
+ # print(f'\t{n.compact_str()}')
577
+
578
+ einsum_tensor_to_projection = {}
579
+ einsum = workload.einsums[einsum_name]
580
+ all_tensors = einsum.tensor_names
581
+ for tensor in all_tensors:
582
+ einsum_tensor_to_projection[(einsum_name, tensor)] = get_projection_expr(
583
+ einsum, tensor
584
+ )
585
+ tensor_to_relevancy = {
586
+ tensor: get_rank_variable_relevancy(einsum, tensor) for tensor in all_tensors
587
+ }
588
+ assert all_tensors, f"Einsum {einsum_name} has no tensors"
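A quick usage note (editorial illustration, not part of the accelforge source): any arithmetic involving an Uninitialized value stays Uninitialized, so an unset statistic propagates through sums and products instead of silently turning into a number.

    u = Uninitialized()
    print(3 * u)    # Uninitialized -- int.__mul__ defers to __rmul__, which returns self
    print(u + 7.5)  # Uninitialized -- __add__ ignores the other operand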
589
+
590
+ """
591
+ Note for how this works.
592
+
593
+ Spatial loops are weird, because they don't belong at a single point in the loop
594
+ nest. For example:
595
+
596
+ - DRAM keep A, B
597
+ - *
598
+ - Reg keep A
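An illustration of the zero-as-unset convention the TODO above refers to (editorial sketch, not part of the source; sympy's own Min stands in for the broadcast Min imported at the top of the file):

    import sympy

    x = sympy.Symbol("x", positive=True)
    print(sympy.Min(0, x))    # 0 -- a plain Min lets the zero "unset" default win
    print(min_nonzero(0, x))  # x -- the zero sentinel is ignored, the real value survives
    print(min_nonzero(x, 3))  # falls back to Min(x, 3) when both values are real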
599
+ - for n in [0..N)
600
+ - GLB keep B
601
+ - *
602
+ - Compute
603
+
604
+ A loop spatial-for (Reg) k in [0..K) would affect the register at the point of the
605
+ first asterisk, but the global buffer at the point of the second asterisk.
606
+
607
+ To handle this, we make a separate mapping for each tensor, analyze each, and
608
+ combine the results.
609
+
610
+ To anyone who would like to create behavior that simultaneously looks at multiple
611
+ storage nodes for a given memory, note that there will be two challenges to address:
612
+
613
+ 1. The code currently analyzes one tensor at a time. This could be fixed by
614
+ processing all mapping(s) together, applying loop(s) from each to only the
615
+ appropriate nodes.
616
+ 2. The code must analyze one storage node at a time, and there may be temporal and
617
+ spatial nodes between two storage nodes for a given memory, which would separate
618
+ the analysis steps for the storage nodes. This may be addressed by only
619
+ performing such analysis until the outermost storage node for a particular memory
620
+ has been analyzed.
621
+ """
622
+ result = None
623
+
624
+ tensor2mapping = {}
625
+ index_expressions = set(einsum.indexing_expressions)
626
+ for k, v in job.rank_variable_bounds.items():
627
+ index_expressions.add(f"0 < {k} <= {v}")
628
+ for tensor in all_tensors:
629
+ cur_mapping = job.mapping._get_single_tensor_mapping(
630
+ tensor, job.flattened_arch, index_expressions
631
+ )
632
+ info = AnalysisInfo(
633
+ mapping=cur_mapping.nodes,
634
+ workload=workload,
635
+ full_rank_variable_shapes=job.rank_variable_bounds,
636
+ all_tensors=set([tensor]),
637
+ einsum_tensor_to_projection=einsum_tensor_to_projection,
638
+ tensor_to_relevancy=tensor_to_relevancy,
639
+ tensor_to_backer_id=tensor_to_backer_id,
640
+ is_copy_operation=is_copy_operation,
641
+ job=job,
642
+ )
643
+ cur_result = analyze_node(0, job.rank_variable_bounds, info)
644
+ if result is None:
645
+ result = cur_result
646
+ else:
647
+ result.add_buffet_stats_and_symbols(cur_result)
648
+ tensor2mapping[tensor] = cur_mapping
649
+
650
+ result.symbols = symbols
651
+ result.tensor2mapping = tensor2mapping
652
+ return result
653
+
654
+
655
+ def get_tensor_to_backer_id(mapping: Mapping):
656
+ tensor_to_ids: dict[TensorName, int] = {}
657
+ for node in mapping:
658
+ if isinstance(node, TensorHolder):
659
+ for tensor in node.tensors:
660
+ if tensor in tensor_to_ids:
661
+ continue
662
+ tensor_to_ids[tensor] = id(node)
663
+ return tensor_to_ids
664
+
665
+
666
+ class ReservationAnalysisTracker:
667
+ def __init__(self, buffet, node):
668
+ self.buffet: Buffet = buffet
669
+ self.node: TensorHolder = node
670
+
671
+ # These are interface (TODO: should be property)
672
+ self.is_fill_level = False
673
+ self.should_stop = False
674
+ self.insert_reservation_under = False
675
+ self.insert_fill_under = False
676
+
677
+ # Temporary values
678
+ self.has_filled = False
679
+
680
+ def track_temporal_loop(self, relevancy, node):
681
+ self.is_fill_level = False
682
+ self.insert_reservation_under = False
683
+ self.insert_fill_under = False
684
+
685
+ if isinstance(relevancy, Irrelevant):
686
+ if not self.has_filled:
687
+ self.is_fill_level = True
688
+ self.has_filled = True
689
+
690
+ self.should_stop = True
691
+ elif isinstance(relevancy, Relevant):
692
+ self.should_stop = False
693
+ elif isinstance(relevancy, PartiallyRelevant):
694
+ self.last = True
695
+
696
+ if not self.has_filled:
697
+ self.is_fill_level = True
698
+ self.has_filled = True
699
+
700
+ self.should_stop = True
701
+ self.insert_reservation_under = True
702
+ else:
703
+ raise ValueError(f"Unknown relevancy {relevancy}")
704
+
705
+ def track_compute(self):
706
+ self.should_stop = True
707
+ if not self.has_filled:
708
+ self.is_fill_level = True
709
+ self.has_filled = True
710
+
711
+ def track_spatial_loop(self, relevancy, node):
712
+ if node.component != self.buffet.level:
713
+ self.should_stop = True
714
+ if not self.has_filled:
715
+ self.is_fill_level = True
716
+ self.has_filled = True
717
+ return
718
+
719
+ self.is_fill_level = False
720
+ self.should_stop = False
721
+
722
+
723
+ def insert_reservation_nodes(mapping, info: AnalysisInfo):
724
+ trackers: list[ReservationAnalysisTracker] = []
725
+ einsum = info.workload.einsums[mapping[-1].einsum]
726
+ non_intermediate_tensors = (
727
+ einsum.tensor_names - info.workload.tensor_names_used_in_multiple_einsums
728
+ )
729
+ seen_tensors = set() # reservation for top-level buffets cannot be lowered
730
+
731
+ n_nodes = len(mapping)
732
+ i = 0
733
+ while i < n_nodes:
734
+ node = mapping[i]
735
+ to_remove = []
736
+ if isinstance(node, Reservation):
737
+ pass
738
+ elif isinstance(node, Temporal):
739
+ rank = node.rank_variable
740
+ for tracker in trackers:
741
+ relevancy = info.tensor_to_relevancy[tracker.buffet.tensor]
742
+ tracker.track_temporal_loop(relevancy[rank], node)
743
+ elif isinstance(node, Spatial):
744
+ rank = node.rank_variable
745
+ for tracker in trackers:
746
+ relevancy = info.tensor_to_relevancy[tracker.buffet.tensor]
747
+ tracker.track_spatial_loop(relevancy[rank], node)
748
+ elif isinstance(node, TensorHolder):
749
+ for tracker in trackers:
750
+ tracker.should_stop = True
751
+ tracker.insert_reservation_under = False
752
+ for tensor in node.tensors:
753
+ tensor = TensorName(tensor)
754
+ buffet = Buffet(tensor, mapping[-1].einsum, node.component)
755
+ trackers.append(ReservationAnalysisTracker(buffet, node))
756
+ if not node._lower or (
757
+ tensor not in seen_tensors and tensor in non_intermediate_tensors
758
+ ):
759
+ seen_tensors.add(tensor)
760
+ trackers[-1].is_fill_level = True
761
+ trackers[-1].insert_reservation_under = True
762
+ trackers[-1].insert_fill_under = True
763
+ trackers[-1].should_stop = True
764
+ elif isinstance(node, mapping_spec.Compute):
765
+ for tracker in trackers:
766
+ tracker.track_compute()
767
+ tracker.insert_reservation_under = False
768
+ else:
769
+ raise NotImplementedError(f"Unknown node type {type(node)}")
770
+
771
+ reservation_insert_below = []
772
+ reservation_insert_above = []
773
+ for j in range(len(trackers) - 1, -1, -1):
774
+ if not trackers[j].should_stop:
775
+ continue
776
+ tracker = trackers.pop(j)
777
+ buffet = tracker.buffet
778
+ node = Reservation(purposes=[buffet.tensor], resource=buffet.level)
779
+ node.persistent = tracker.node.persistent
780
+ node._backing = tracker.node._backing
781
+
782
+ if (
783
+ buffet.tensor not in info.tensor_to_reservation_backer_id
784
+ and buffet.tensor in info.workload.tensor_names_used_in_multiple_einsums
785
+ ):
786
+ info.tensor_to_reservation_backer_id[buffet.tensor] = id(node)
787
+
788
+ if tracker.insert_reservation_under:
789
+ reservation_insert_below.append(node)
790
+ else:
791
+ reservation_insert_above.append(node)
792
+
793
+ # The order of these for loops is important. Reservation must be below fill.
794
+ for node in reservation_insert_below:
795
+ mapping.insert(i + 1, node)
796
+ i += 1
797
+ for node in reservation_insert_above:
798
+ mapping.insert(i, node)
799
+ i += 1
800
+
801
+ i += 1
802
+ n_nodes = len(mapping)
803
+
804
+ label_fused_loops(mapping)
805
+
806
+
807
+ def label_fused_loops(mapping: list[MappingNode]):
808
+ last_backer = None
809
+ for i, node in enumerate(mapping):
810
+ if isinstance(node, Reservation) and node._backing:
811
+ last_backer = i
812
+ if last_backer is None:
813
+ raise ValueError(
814
+ f"No backing TensorHolder found in mapping {", ".join(m.compact_str() for m in mapping)}"
815
+ )
816
+
817
+ for i, node in enumerate(mapping):
818
+ if isinstance(node, Loop):
819
+ node._fused = i < last_backer
820
+ return mapping
821
+
822
+
823
+ def analyze_node(node_idx, current_shape, info: AnalysisInfo) -> SymbolicAnalysisOutput:
824
+ node = info.mapping[node_idx]
825
+ class2analysis_function = {
826
+ Temporal: analyze_temporal,
827
+ Spatial: analyze_spatial,
828
+ Storage: analyze_storage,
829
+ Reservation: analyze_reservation,
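A small worked example (editorial, not part of the source) of how temporal repetition scales these stats: traffic totals multiply by the trip count, while peak occupancy and the skipped-first counters of a not-fully-relevant loop are left alone.

    s = BuffetStats(
        total_reads_to_parent=6,
        total_skipped_first_reads_to_parent=2,
        max_occupancy=10,
    )
    r = s.repeat_temporal(3, is_fully_relevant=False)
    print(r.total_reads_to_parent)                # 18 -- scaled by the trip count
    print(r.total_skipped_first_reads_to_parent)  # 2  -- first-iteration skips don't repeat here
    print(r.max_occupancy)                        # 10 -- loops above don't change peak occupancy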
830
+ mapping_spec.Compute: analyze_compute,
831
+ ProcessingStage: analyze_processing_stage,
832
+ }
833
+ if type(node) not in class2analysis_function:
834
+ raise TypeError(f"Unknown node type {type(node)}")
835
+ return class2analysis_function[type(node)](node_idx, current_shape, info)
836
+
837
+
838
+ def analyze_temporal(
839
+ node_idx, current_shape, info: AnalysisInfo
840
+ ) -> SymbolicAnalysisOutput:
841
+ mapping = info.mapping
842
+ node = mapping[node_idx]
843
+ stride_and_shape = get_stride_and_tile_shape(node, current_shape, node_idx, info)
844
+
845
+ result_accumulator = SymbolicAnalysisOutput()
846
+
847
+ first_latency = None
848
+
849
+ def handle_repeated_value(repeated_shape):
850
+ nonlocal first_latency
851
+ shape_value = repeated_shape.value
852
+ shape_repeats = repeated_shape.repeats
853
+
854
+ child_shape = current_shape.copy()
855
+ child_shape[node.rank_variable] = shape_value
856
+
857
+ child_result = analyze_node(node_idx + 1, child_shape, info)
858
+
859
+ accumulated_buffet_stats = result_accumulator.buffet_stats
860
+ for buffet, stats in child_result.buffet_stats.items():
861
+ relevancy = info.tensor_to_relevancy[buffet.tensor][node.rank_variable]
862
+ is_fully_relevant = isinstance(relevancy, Relevant)
863
+ accumulated_stats = accumulated_buffet_stats.setdefault(
864
+ buffet, blank_buffet_stats()
865
+ )
866
+ accumulated_stats += stats.repeat_temporal(
867
+ shape_repeats, is_fully_relevant=is_fully_relevant
868
+ )
869
+ accumulated_stats.n_loops_above = stats.n_loops_above + 1
870
+
871
+ for einsum, child_steps in child_result.temporal_steps.items():
872
+ if einsum not in result_accumulator.temporal_steps:
873
+ result_accumulator.temporal_steps[einsum] = 0
874
+ result_accumulator.temporal_steps[einsum] += child_steps * shape_repeats
875
+
876
+ result_accumulator.max(fanout=child_result.fanout)
877
+
878
+ for key in child_result.compute_stats:
879
+ if first_latency is None:
880
+ first_latency = child_result.compute_stats[key].max_latency
881
+
882
+ compute_stats = result_accumulator.compute_stats.setdefault(
883
+ key, ComputeStats()
884
+ )
885
+ compute_stats += child_result.compute_stats[key].repeat_temporal(
886
+ shape_repeats
887
+ )
888
+ result_accumulator.compute_stats[key] = compute_stats
889
+
890
+ info.last_temporal_node_idx = node_idx
891
+
892
+ shape = stride_and_shape.shape
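To make the temporal/spatial distinction concrete (editorial sketch, not part of the source; it assumes the broadcast Max reduces to the ordinary maximum for plain numbers): sequential iterations add their latencies, parallel units take the maximum.

    a = ComputeStats(total_ops=10, max_per_unit_ops=10, max_latency=5)
    b = ComputeStats(total_ops=10, max_per_unit_ops=10, max_latency=7)

    a.combine_temporal(b)
    print(a.total_ops, a.max_latency)  # 20 12 -- ops and latency accumulate in time

    c = ComputeStats(total_ops=10, max_per_unit_ops=10, max_latency=5)
    c.combine_spatial(b)
    print(c.total_ops, c.max_per_unit_ops, c.max_latency)  # 20 10 7 -- totals add, per-unit and latency take the max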
893
+ if isinstance(shape, SequenceOfRepatedvalues):
894
+ for repeated_shape in shape.sequence:
895
+ assert isinstance(repeated_shape, RepeatedValue)
896
+ handle_repeated_value(repeated_shape)
897
+ elif isinstance(shape, RepeatedValue):
898
+ handle_repeated_value(shape)
899
+
900
+ if node_idx in info.idxs_to_track_first_latency:
901
+ for compute_stat in result_accumulator.compute_stats.values():
902
+ # Should be the first time we store this value
903
+ assert node_idx not in compute_stat.max_first_latency
904
+ compute_stat.max_first_latency[node_idx] = first_latency
905
+
906
+ return result_accumulator
907
+
908
+
909
+ def analyze_spatial(node_idx, current_shape, info: AnalysisInfo):
910
+ mapping = info.mapping
911
+ einsum_name = mapping[-1].einsum
912
+ node: Spatial = mapping[node_idx]
913
+ rank_var = node.rank_variable
914
+ node_dim = node.name
915
+ stride_and_shape = get_stride_and_tile_shape(node, current_shape, node_idx, info)
916
+
917
+ result_accumulator = SymbolicAnalysisOutput()
918
+
919
+ def handle_repeated_value(repeated_shape):
920
+ shape_value = repeated_shape.value
921
+ shape_repeats = repeated_shape.repeats
922
+
923
+ child_shape = current_shape.copy()
924
+ child_shape[node.rank_variable] = shape_value
925
+
926
+ child_result = analyze_node(node_idx + 1, child_shape, info)
927
+
928
+ accumulated_buffet_stats = result_accumulator.buffet_stats
929
+ child_stats = list(child_result.buffet_stats.items())
930
+ for i, (buffet, buffet_stats) in enumerate(child_stats):
931
+ stats = buffet_stats
932
+ accumulated_stats = accumulated_buffet_stats.setdefault(
933
+ buffet, blank_buffet_stats()
934
+ )
935
+ relevancy = info.tensor_to_relevancy[buffet.tensor][rank_var]
936
+
937
+ # Reuse parent accesses only:
938
+ # - Irrelevant loops
939
+ # - The outermost level that holds the tensor (the one whose parent accesses
940
+ # will be going through the network)
941
+ last_buffet = True
942
+ for other_buffet, _ in child_stats[i + 1 :]:
943
+ if other_buffet.tensor == buffet.tensor:
944
+ last_buffet = False
945
+ break
946
+
947
+ reuse_parent_accesses = (
948
+ last_buffet
949
+ and isinstance(relevancy, Irrelevant)
950
+ and buffet.tensor in node._may_reuse
951
+ )
952
+
953
+ accumulated_stats += stats.repeat_spatial(
954
+ shape_repeats, reuse_parent_accesses=reuse_parent_accesses
955
+ )
956
+ accumulated_stats.n_loops_above = stats.n_loops_above + 1
957
+
958
+ for einsum, child_steps in child_result.temporal_steps.items():
959
+ if einsum not in result_accumulator.temporal_steps:
960
+ result_accumulator.temporal_steps[einsum] = child_steps
961
+ else:
962
+ result_accumulator.temporal_steps[einsum] = Max(
963
+ result_accumulator.temporal_steps[einsum], child_steps
964
+ )
965
+
966
+ my_key = (node.component, einsum_name)
967
+ child_result.fanout.setdefault(my_key, {})
968
+
969
+ # Propagate up everything except the current level and dimension
970
+ child_fanout = copy.deepcopy(child_result.fanout)
971
+ target_fanout = child_fanout[my_key].pop(node_dim, 1)
972
+ result_accumulator.max(fanout=child_fanout)
973
+
974
+ # Propagate current level and dimension * shape_repeats
975
+ child_fanout = child_result.fanout[my_key]
976
+ fanout = result_accumulator.fanout.setdefault(my_key, {})
977
+ fanout.setdefault(node_dim, 0) # TODO: Assume sympy can just take in 0
978
+ # TODO: If node_dim was missing, the original code would have omitted
979
+ # shape_repeats. Is this correct?
980
+ fanout[node_dim] += target_fanout * shape_repeats
981
+
982
+ for key in child_result.compute_stats:
983
+ # TODO: ensure that `ComputeStats()`, which is initialized ONCE, is okay to use here
984
+ compute_stats = result_accumulator.compute_stats.setdefault(
985
+ key, ComputeStats()
986
+ )
987
+ # TODO: An `if` check from the original code is omitted here; check history if needed.
988
+ compute_stats.combine_spatial(
989
+ child_result.compute_stats[key].repeat_spatial(shape_repeats)
990
+ )
991
+
992
+ shape = stride_and_shape.shape
993
+ if isinstance(shape, SequenceOfRepatedvalues):
994
+ for repeated_shape in shape.sequence:
995
+ assert isinstance(repeated_shape, RepeatedValue)
996
+ handle_repeated_value(repeated_shape)
997
+ elif isinstance(shape, RepeatedValue):
998
+ handle_repeated_value(shape)
999
+
1000
+ return result_accumulator
1001
+
1002
+
1003
+ def reduce_dicts(dict1: dict, dict2: dict, reduce_op):
1004
+ for key in dict1:
1005
+ if key not in dict2:
1006
+ dict2[key] = dict1[key]
1007
+ else:
1008
+ dict2[key] = reduce_op(dict1[key], dict2[key])
1009
+
1010
+
1011
+ def get_total_to_per_unit(total, max_per_unit):
1012
+ if total == 0 and max_per_unit != 0:
1013
+ raise ValueError(f"total is 0 but max_per_unit is {max_per_unit}")
1014
+ if total == 0:
1015
+ return 1
1016
+ return max_per_unit / total
1017
+
1018
+
1019
+ def has_parent_tensor_holder(
1020
+ tensor: TensorName, node_idx: int, info: AnalysisInfo
1021
+ ) -> bool:
1022
+ for node in info.mapping[:node_idx]:
1023
+ if isinstance(node, TensorHolder) and tensor in node.tensors:
1024
+ return True
1025
+ return False
1026
+
1027
+
1028
+ def find_component_object(
1029
+ component: str, flattened_arch: list[arch.Leaf]
1030
+ ) -> arch.TensorHolder:
1031
+ for node in flattened_arch:
1032
+ if node.name == component:
1033
+ return node
1034
+ raise ValueError(f"Component {component} not found in flattened arch")
1035
+
1036
+
1037
+ def analyze_storage(
1038
+ node_idx: int,
1039
+ current_shape: dict[str, int],
1040
+ info: AnalysisInfo,
1041
+ propagate_child_results: bool = False,
1042
+ count_writes: bool = True,
1043
+ ):
1044
+ mapping = info.mapping
1045
+ einsum_name = mapping[-1].einsum
1046
+ node: TensorHolder = mapping[node_idx]
1047
+
1048
+ child_result = analyze_node(node_idx + 1, current_shape, info)
1049
+
1050
+ for tensor in node.tensors:
1051
+ tensor = TensorName(tensor)
1052
+ buffet = Buffet(tensor, einsum_name, node.component)
1053
+
1054
+ # Reservations make these, and they go below the storage node, so the buffet
1055
+ # stats are already made at this point
1056
+ stats = child_result.buffet_stats[buffet]
1057
+ backer_id = info.tensor_to_backer_id[tensor]
1058
+ is_backing = backer_id == id(node)
1059
+ if node.persistent:
1060
+ stats.persistent = True
1061
+ below_backing = backer_id in [id(m) for m in mapping[:node_idx]]
1062
+
1063
+ projection = info.einsum_tensor_to_projection[(einsum_name, tensor)]
1064
+
1065
+ fills = compute_dense_tile_occupancy(projection, current_shape)
1066
+
1067
+ child = child_result.get_child_buffet_stats(buffet)
1068
+ inherit_from_child = propagate_child_results and child is not None
1069
+
1070
+ # ==============================================================================
1071
+ # Calculate the total fills and reads to parent. These propagate upward.
1072
+ # ==============================================================================
1073
+
1074
+ def inherit_add(attr: str, default_value: Any = fills) -> Any:
1075
+ val = getattr(child, attr) if inherit_from_child else default_value
1076
+ setattr(stats, attr, val + getattr(stats, attr))
1077
+
1078
+ if has_parent_tensor_holder(tensor, node_idx, info):
1079
+ # Initial fetch: If we're below the backing storage, fetch data from above
1080
+ # at the beginning.
1081
+ if not is_backing and below_backing:
1082
+ inherit_add("total_reads_to_parent", fills)
1083
+ inherit_add("max_per_parent_reads_to_parent", fills)
1084
+
1085
+ # Data writeback. Do not writeback if it's a copy operation and we're below
1086
+ # the backing storage; data only flows upward.
1087
+
1088
+ # Writeback occurs in two cases:
1089
+ # - We're at or above the backing storage, so we need to propagate our
1090
+ # results upward to any storage nodes that will need this data.
1091
+ # - This is a written tensor, so we need to write back the written data.
1092
+ if (
1093
+ tensor in info.workload.einsums[einsum_name].output_tensor_names
1094
+ or not below_backing
1095
+ ):
1096
+ inherit_add("total_writes_to_parent")
1097
+ inherit_add("max_per_parent_writes_to_parent")
1098
+
1099
+ # For read+write tensors, we skip the first fill because the data will be
1100
+ # initialized with a zero value.
1101
+ if tensor in info.workload.einsums[einsum_name].output_tensor_names:
1102
+ inherit_add("total_skipped_first_reads_to_parent")
1103
+ inherit_add("min_per_parent_skipped_first_reads_to_parent")
1104
+
1105
+ # =========================
1106
+ # Data exchanges with child
1107
+ if child is not None:
1108
+ stats.total_writes_to_child += child.total_reads_to_parent
1109
+ stats.max_per_unit_writes_to_child += child.max_per_parent_reads_to_parent
1110
+ # Skip first read
1111
+ stats.total_skipped_first_writes_to_child += (
1112
+ child.total_skipped_first_reads_to_parent
1113
+ )
1114
+ stats.min_per_unit_skipped_first_writes_to_child += (
1115
+ child.min_per_parent_skipped_first_reads_to_parent
1116
+ )
1117
+
1118
+ stats.total_reads_to_child += child.total_writes_to_parent
1119
+ stats.max_per_unit_reads_to_child += child.max_per_parent_writes_to_parent
1120
+
1121
+ component_object = find_component_object(
1122
+ node.component, info.job.flattened_arch
1123
+ )
1124
+ bits_per_value_scale = component_object.attributes.bits_per_value_scale[tensor]
1125
+ bits_per_value = bits_per_value_scale * info.job.bits_per_value[tensor]
1126
+ read_bits_per_action = component_object.actions[
1127
+ "read"
1128
+ ].arguments.bits_per_action
1129
+ stats.read_scale = bits_per_value / read_bits_per_action
1130
+ if count_writes:
1131
+ write_bits_per_action = component_object.actions[
1132
+ "write"
1133
+ ].arguments.bits_per_action
1134
+ stats.write_scale = bits_per_value / write_bits_per_action
1135
+ else:
1136
+ stats.write_scale = 0
1137
+
1138
+ return child_result
1139
+
1140
+
1141
+ def analyze_processing_stage(node_idx, current_shape, info: AnalysisInfo):
+     mapping = info.mapping
+     einsum_name = mapping[-1].einsum
+     node = mapping[node_idx]
+     component_object = find_component_object(node.component, info.job.flattened_arch)
+     storage_result = analyze_storage(
+         node_idx,
+         current_shape,
+         info,
+         propagate_child_results=True,
+         count_writes=False,
+     )
+     for tensor in node.tensors:
+         buffet = Buffet(tensor, einsum_name, node.component)
+         stats = storage_result.buffet_stats[buffet]
+         stats.max_occupancy = 0
+         stats.count_downward_movement = component_object.attributes.direction != "up"
+         stats.count_upward_movement = component_object.attributes.direction != "down"
+         assert stats.total_write_actions == 0
+     return storage_result
+
+
1163
+ def analyze_reservation(node_idx, current_shape, info: AnalysisInfo):
+     mapping = info.mapping
+     einsum_name = mapping[-1].einsum
+     node = mapping[node_idx]
+     tensor = TensorName(node.purpose)
+
+     if info.last_temporal_node_idx is not None and id(
+         node
+     ) == info.tensor_to_reservation_backer_id.get(node.purpose, None):
+         info.idxs_to_track_first_latency.add(info.last_temporal_node_idx)
+
+     child_result = analyze_node(node_idx + 1, current_shape, info)
+
+     buffet = Buffet(tensor, einsum_name, node.resource)
+
+     # Reservation nodes are the first to produce stats for a buffet
+     assert buffet not in child_result.buffet_stats
+
+     stats = BuffetStats()
+     projection = info.einsum_tensor_to_projection[(einsum_name, tensor)]
+     component_object = find_component_object(node.resource, info.job.flattened_arch)
+     bits_per_value_scale = component_object.attributes.bits_per_value_scale[tensor]
+     bits_per_value = bits_per_value_scale * info.job.bits_per_value[tensor]
+     stats.max_occupancy = (
+         compute_dense_tile_occupancy(projection, current_shape) * bits_per_value
+     )
+     child_result.buffet_stats[buffet] = stats
+
+     fanout_key = (node.resource, einsum_name)
+     if fanout_key not in child_result.fanout:
+         child_result.fanout[fanout_key] = {}
+
+     return child_result
+
+
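The reservation size above is a dense-tile occupancy: the number of values the tensor's projection covers in the current tile shape, times the scaled bits per value. A rough sketch of that product, assuming for illustration that compute_dense_tile_occupancy multiplies the projected rank extents (an assumption; the real helper is defined elsewhere in the package):

    import math

    # Hypothetical projection: the tensor is indexed by ranks "m" and "k".
    projection = ("m", "k")
    current_shape = {"m": 32, "k": 16, "n": 8}   # made-up tile shape

    def dense_tile_occupancy(projection, shape):
        # Assumed behavior: product of the tile extents the tensor projects onto.
        return math.prod(shape[r] for r in projection)

    bits_per_value = 8                            # after bits_per_value_scale
    max_occupancy = dense_tile_occupancy(projection, current_shape) * bits_per_value
    print(max_occupancy)                          # 32 * 16 * 8 = 4096 bits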
1198
+ def analyze_compute(
+     node_idx, current_shape, info: AnalysisInfo
+ ) -> SymbolicAnalysisOutput:
+     einsum = info.mapping[-1].einsum
+     node = info.mapping[node_idx]
+
+     computes = 0 if info.is_copy_operation else 1
+
+     result_accumulator = SymbolicAnalysisOutput()
+
+     result_accumulator.temporal_steps[einsum] = computes
+     result_accumulator.compute_stats[Compute(einsum, node.component)] = ComputeStats(
+         computes,
+         computes,
+         1,
+     )
+
+     if info.is_copy_operation:
+         return result_accumulator
+
+     for tensor in info.all_tensors:
+         buffet = Buffet(tensor, einsum, node.component)
+         stats = BuffetStats()
+         stats.total_reads_to_parent = 1
+         stats.max_per_parent_reads_to_parent = 1
+         if tensor in info.workload.einsums[einsum].output_tensor_names:
+             stats.total_writes_to_parent = 1
+             stats.max_per_parent_writes_to_parent = 1
+             stats.total_skipped_first_reads_to_parent = 1
+             stats.min_per_parent_skipped_first_reads_to_parent = 1
+         stats.max_occupancy = 1
+         result_accumulator.buffet_stats[buffet] = stats
+
+     return result_accumulator
+
+
1234
+ @dataclass
+ class RepeatedValue[T]:
+     value: T
+     repeats: int
+
+
+ @dataclass
+ class SequenceOfRepatedvalues[T]:
+     sequence: list[RepeatedValue[T]]
+
+
+ @dataclass
+ class StrideAndShape:
+     stride: Any
+     shape: Any
+
+
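These three containers describe how a loop carves a rank into tiles: a RepeatedValue is one tile shape repeated some number of times, a SequenceOfRepatedvalues strings several such runs together (e.g. an initial tile followed by stride-sized tiles), and StrideAndShape pairs the stride with whichever of the two describes the tile shapes. A small illustration using the dataclasses above, with concrete numbers chosen purely for demonstration:

    # A rank of size 10 tiled as an initial tile of 4 followed by stride-2 tiles:
    # tile shapes are 4, 2, 2, 2.
    shapes = SequenceOfRepatedvalues(
        sequence=[
            RepeatedValue(value=4, repeats=1),   # the initial tile
            RepeatedValue(value=2, repeats=3),   # three stride-sized tiles
        ]
    )
    result = StrideAndShape(stride=2, shape=shapes)
    assert sum(rv.value * rv.repeats for rv in result.shape.sequence) == 10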
1251
+ def get_stride_and_tile_shape(node: Loop, full_shape, n: int, info: AnalysisInfo):
+     rank = node.rank_variable
+     rank_shape = full_shape[rank]
+
+     stride = node.stride
+     initial_tile_shape = node.initial_tile_shape
+
+     # PERFECT:
+     # - Node shape = stride
+     # - # Iterations = total shape / stride
+     # IMPERFECT:
+     # - Node shape = stride
+     # - # Iterations = ceil(total shape / stride)
+     if IMPERFECT and initial_tile_shape is None:
+         factor = sympy.ceiling(rank_shape / stride)
+         stride_avg = stride / sympy.ceiling(rank_shape / stride)
+         return StrideAndShape(stride_avg, RepeatedValue(stride, factor))
+
+     if initial_tile_shape is None:
+         if node._assume_perfect_factor or known_perfect_factor(stride, rank_shape):
+             factor = rank_shape / stride
+             return StrideAndShape(stride, RepeatedValue(stride, factor))
+         else:
+             factor = sympy.ceiling(rank_shape / sympy.Min(stride, rank_shape))
+             return make_possibly_different_last(stride, factor, rank_shape)
+
+     middle_shape_factor = sympy.ceiling((rank_shape - initial_tile_shape) / stride)
+     # TODO: sometimes last_shape is 0, causing numerical instability
+     # Currently, we are sometimes rounding up last shape.
+     # last_shape = rank_shape - initial_tile_shape - stride*middle_shape_factor
+     # has_last_shape = sympy.ceiling(last_shape/(last_shape+1))
+     return StrideAndShape(
+         stride,
+         SequenceOfRepatedvalues(
+             [
+                 RepeatedValue(initial_tile_shape, 1),
+                 RepeatedValue(stride, middle_shape_factor),
+                 # RepeatedValue(last_shape+0.01, has_last_shape)
+             ]
+         ),
+     )
+     # if node.tile_shape is not None:
+     #     tile_shape = node.tile_shape
+
+     #     if node._assume_perfect_factor or known_perfect_factor(tile_shape, rank_shape):
+     #         factor = rank_shape / tile_shape
+     #         return StrideAndShape(tile_shape, RepeatedValue(tile_shape, factor))
+     #     else:
+     #         factor = sympy.ceiling(rank_shape / sympy.Min(tile_shape, rank_shape))
+     #         return make_possibly_different_last(tile_shape, factor, rank_shape)
+     # elif node.loop_bound is not None:
+     #     factor = node.loop_bound
+
+     #     if node._assume_perfect_factor or known_perfect_factor(factor, rank_shape):
+     #         tile_shape = rank_shape / factor
+     #         return StrideAndShape(tile_shape, RepeatedValue(tile_shape, factor))
+     #     else:
+     #         tile_shape = sympy.ceiling(rank_shape / sympy.Min(rank_shape, factor))
+     #         return make_possibly_different_last(tile_shape, factor, rank_shape)
+
+     # elif node.tile_pattern is not None:
+     #     stride = node.tile_pattern.stride
+     #     initial_tile_shape = node.tile_pattern.initial_tile_shape
+     #     tile_shape = node.tile_pattern.tile_shape
+
+     #     if initial_tile_shape is not None:
+     #         middle_shape_factor = sympy.ceiling((rank_shape - initial_tile_shape)/stride)
+     #         # TODO: sometimes last_shape is 0, causing numerical instability
+     #         # Currently, we are sometimes rounding up last shape.
+     #         # last_shape = rank_shape - initial_tile_shape - stride*middle_shape_factor
+     #         # has_last_shape = sympy.ceiling(last_shape/(last_shape+1))
+     #         return StrideAndShape(
+     #             stride,
+     #             SequenceOfRepatedvalues([
+     #                 RepeatedValue(initial_tile_shape, 1),
+     #                 RepeatedValue(stride, middle_shape_factor),
+     #                 # RepeatedValue(last_shape+0.01, has_last_shape)
+     #             ])
+     #         )
+
+
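The PERFECT/IMPERFECT distinction above comes down to whether the stride divides the rank evenly: a perfect factor yields rank_shape / stride iterations of identical tiles, while an imperfect one yields ceil(rank_shape / stride) iterations with a smaller last tile. A small symbolic check of that arithmetic with sympy, using illustrative values only:

    import sympy

    rank_shape, stride = 100, 12                 # 12 does not divide 100
    factor = sympy.ceiling(sympy.Rational(rank_shape) / stride)
    print(factor)                                # 9 iterations
    last_shape = rank_shape - stride * (factor - 1)
    print(last_shape)                            # the final tile holds only 4 elements

    # With symbolic operands the expression stays unevaluated until substitution.
    R, S = sympy.symbols("R S", positive=True, integer=True)
    n_iterations = sympy.ceiling(R / S)
    print(n_iterations.subs({R: 100, S: 12}))    # 9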
1332
+ def known_perfect_factor(divisor, full_shape):
+     return (
+         isinstance(divisor, int)
+         and isinstance(full_shape, int)
+         and full_shape % divisor == 0
+     )
+
+
1340
+ def make_possibly_different_last(common_tile_shape, factor, full_shape):
+     last_shape = full_shape - common_tile_shape * (factor - 1)
+     all_shapes = SequenceOfRepatedvalues(
+         [RepeatedValue(common_tile_shape, factor - 1), RepeatedValue(last_shape, 1)]
+     )
+     return StrideAndShape(common_tile_shape, all_shapes)
+
+
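make_possibly_different_last is the imperfect-factor fallback: every tile but the last has the common shape, and the last tile absorbs the remainder. A quick usage sketch with the definitions above (the numbers are chosen here for illustration, not taken from any mapping):

    # common_tile_shape=3, factor=4, full_shape=10 -> tiles of 3, 3, 3 and a last tile of 1.
    result = make_possibly_different_last(common_tile_shape=3, factor=4, full_shape=10)
    print(result.stride)                                             # 3
    print([(rv.value, rv.repeats) for rv in result.shape.sequence])  # [(3, 3), (1, 1)]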
1348
+ def insert_sympy_symbols(mapping: list[MappingNode], job: Job):
+     loop_idx = 0
+     symbols = []
+     rank_var_with_initial = set()
+     for i, node in enumerate(mapping):
+         if not isinstance(node, Loop):
+             continue
+
+         stride_halos = set()
+         for t in job.spec.workload.einsums[job.einsum_name].tensor_names:
+             for (rank, rank_variable), (stride, halo) in job.stride_and_halo[t].items():
+                 if rank_variable == node.rank_variable:
+                     stride_halos.add((stride, halo))
+
+         if len(stride_halos) == 0:
+             raise RuntimeError(
+                 f"{repr(node.rank_variable)} not found in {job.stride_and_halo}"
+             )
+
+         # We only explore imperfect factorization for the outermost fused loops.
+         simple = (
+             (len(stride_halos) <= 1 and next(iter(stride_halos)) == (1, 0))
+             or node.rank_variable in rank_var_with_initial
+             or not node._fused
+         )
+
+         # NOTE: initial_tile_shape must be inserted into `symbols` before `stride`
+         # because of the order of tile shape exploration.
+         # TODO: there has to be a better way to do this.
+         if simple:  # Just use the stride!
+             node.initial_tile_shape = None
+         elif node.initial_tile_shape == SYMBOL:
+             rank_var_with_initial.add(node.rank_variable)
+             initial_tile_shape = sympy.symbols(
+                 f"initial{loop_idx}", positive=True, integer=True
+             )
+             symbols.append(initial_tile_shape)
+             node.initial_tile_shape = initial_tile_shape
+
+         # TODO: Check for 0 < shape < 1 for loop bound target
+         if job.rank_variable_bounds[node.rank_variable] == 1:
+             node.stride = 1
+         elif node.stride == SYMBOL:
+             stride = sympy.symbols(f"stride{loop_idx}", positive=True, integer=True)
+             symbols.append(stride)
+             node.stride = stride
+
+         # TODO: sometimes, a mapping is passed into the model twice.
+         # E.g., after calling mapper, the model is called again for more
+         # details.
+         #
+         # assert (
+         #     node.calculated_n_iterations is None
+         # ), "Number of iterations is derived from the model. Do not set it!"
+         node.calculated_n_iterations = sympy.symbols(
+             f"n_iterations{loop_idx}", positive=True, integer=True
+         )
+
+         loop_idx += 1
+
+     return symbols
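insert_sympy_symbols leaves each loop's stride and iteration count as sympy symbols (flagged positive and integer) so the analysis can be carried out once symbolically and then re-evaluated for many candidate tile sizes by substitution. A sketch of that pattern, independent of the surrounding classes and with made-up names and values:

    import sympy

    R = sympy.Symbol("R", positive=True, integer=True)       # full rank shape
    stride0 = sympy.symbols("stride0", positive=True, integer=True)
    n_iterations0 = sympy.ceiling(R / stride0)

    # A downstream cost expression built from the symbols...
    reads = 64 * n_iterations0

    # ...evaluated later for concrete tile-size choices without redoing the analysis.
    for candidate in (4, 8, 16):
        print(candidate, reads.subs({R: 128, stride0: candidate}))  # 2048, 1024, 512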