accelforge 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (258) hide show
  1. accelforge/__init__.py +21 -0
  2. accelforge/_accelerated_imports.py +16 -0
  3. accelforge/_deprecate/_simanneal/evalmapping.py +271 -0
  4. accelforge/_deprecate/_simanneal/mapspaceglobals.py +298 -0
  5. accelforge/_deprecate/_simanneal/simanneal.py +666 -0
  6. accelforge/_deprecate/_simanneal/tracking.py +105 -0
  7. accelforge/_deprecate/_simanneal/wrappers.py +218 -0
  8. accelforge/_deprecate/_simanneal2/__init__.py +7 -0
  9. accelforge/_deprecate/_simanneal2/simanneal.py +493 -0
  10. accelforge/_deprecate/_simanneal2/tracking.py +116 -0
  11. accelforge/_deprecate/compatibility_util.py +181 -0
  12. accelforge/_deprecate/layerdeduplication/__init__.py +2 -0
  13. accelforge/_deprecate/layerdeduplication/group_similar_einsums.py +160 -0
  14. accelforge/_deprecate/layerdeduplication/grouped_einsums.py +84 -0
  15. accelforge/_deprecate/mapping_filter_tags/__init__.py +2 -0
  16. accelforge/_deprecate/mapping_filter_tags/ffmt.py +212 -0
  17. accelforge/_deprecate/mapping_filter_tags/onesplit.py +24 -0
  18. accelforge/_deprecate/mapping_filter_tags/util.py +24 -0
  19. accelforge/_deprecate/tags.py +69 -0
  20. accelforge/_deprecate/viz/__init__.py +0 -0
  21. accelforge/_deprecate/viz/interactive.py +159 -0
  22. accelforge/_deprecate/viz/reservationtree.py +307 -0
  23. accelforge/_deprecate/viz/ski_slope.py +88 -0
  24. accelforge/_version.py +15 -0
  25. accelforge/examples.py +39 -0
  26. accelforge/frontend/__init__.py +10 -0
  27. accelforge/frontend/_binding.py +129 -0
  28. accelforge/frontend/_workload_isl/__init__.py +2 -0
  29. accelforge/frontend/_workload_isl/_isl.py +149 -0
  30. accelforge/frontend/_workload_isl/_symbolic.py +141 -0
  31. accelforge/frontend/arch copy.py +1544 -0
  32. accelforge/frontend/arch.py +1642 -0
  33. accelforge/frontend/config.py +63 -0
  34. accelforge/frontend/mapper/__init__.py +5 -0
  35. accelforge/frontend/mapper/ffm.py +126 -0
  36. accelforge/frontend/mapper/mapper.py +7 -0
  37. accelforge/frontend/mapper/metrics.py +30 -0
  38. accelforge/frontend/mapping/__init__.py +1 -0
  39. accelforge/frontend/mapping/mapping.py +1736 -0
  40. accelforge/frontend/model.py +14 -0
  41. accelforge/frontend/renames.py +150 -0
  42. accelforge/frontend/spec copy.py +230 -0
  43. accelforge/frontend/spec.py +301 -0
  44. accelforge/frontend/variables.py +12 -0
  45. accelforge/frontend/workload.py +952 -0
  46. accelforge/mapper/FFM/__init__.py +9 -0
  47. accelforge/mapper/FFM/_join_pmappings/__init__.py +0 -0
  48. accelforge/mapper/FFM/_join_pmappings/compatibility.py +653 -0
  49. accelforge/mapper/FFM/_join_pmappings/compress_pmappings.py +140 -0
  50. accelforge/mapper/FFM/_join_pmappings/join_pmappings.py +703 -0
  51. accelforge/mapper/FFM/_join_pmappings/pmapping_dataframe.py +901 -0
  52. accelforge/mapper/FFM/_join_pmappings/pmapping_group.py +337 -0
  53. accelforge/mapper/FFM/_make_pmappings/contraints/__init__.py +0 -0
  54. accelforge/mapper/FFM/_make_pmappings/contraints/constraints.py +360 -0
  55. accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/__init__.py +1 -0
  56. accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_loops.py +373 -0
  57. accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_pmapping_templates.py +463 -0
  58. accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_reservations.py +95 -0
  59. accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_storage_order.py +382 -0
  60. accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_storages.py +155 -0
  61. accelforge/mapper/FFM/_make_pmappings/make_pmappings.py +411 -0
  62. accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/__init__.py +1 -0
  63. accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/make_pmappings_from_templates.py +407 -0
  64. accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/make_tile_shapes.py +1681 -0
  65. accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/run_model.py +170 -0
  66. accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/symbol_relations.py +174 -0
  67. accelforge/mapper/FFM/_make_pmappings/pmapper_job.py +282 -0
  68. accelforge/mapper/FFM/_pareto_df/df_convention.py +273 -0
  69. accelforge/mapper/FFM/_pareto_df/pareto copy.py +836 -0
  70. accelforge/mapper/FFM/_pareto_df/pareto.py +508 -0
  71. accelforge/mapper/FFM/data.py +61 -0
  72. accelforge/mapper/FFM/main copy.py +236 -0
  73. accelforge/mapper/FFM/main.py +208 -0
  74. accelforge/mapper/FFM/mappings.py +510 -0
  75. accelforge/mapper/FFM/pmappings.py +310 -0
  76. accelforge/mapper/__init__.py +4 -0
  77. accelforge/mapper.py +0 -0
  78. accelforge/model/__init__.py +1 -0
  79. accelforge/model/_looptree/__init__.py +0 -0
  80. accelforge/model/_looptree/accesses.py +335 -0
  81. accelforge/model/_looptree/capacity/__init__.py +1 -0
  82. accelforge/model/_looptree/capacity/aggregators.py +36 -0
  83. accelforge/model/_looptree/capacity/capacity.py +47 -0
  84. accelforge/model/_looptree/energy.py +150 -0
  85. accelforge/model/_looptree/equivalent_ranks.py +29 -0
  86. accelforge/model/_looptree/latency/__init__.py +1 -0
  87. accelforge/model/_looptree/latency/latency.py +98 -0
  88. accelforge/model/_looptree/latency/memory.py +120 -0
  89. accelforge/model/_looptree/latency/processors.py +92 -0
  90. accelforge/model/_looptree/mapping_utilities.py +71 -0
  91. accelforge/model/_looptree/reuse/__init__.py +4 -0
  92. accelforge/model/_looptree/reuse/isl/__init__.py +1 -0
  93. accelforge/model/_looptree/reuse/isl/des.py +59 -0
  94. accelforge/model/_looptree/reuse/isl/isl_functions.py +374 -0
  95. accelforge/model/_looptree/reuse/isl/mapping_to_isl/__init__.py +4 -0
  96. accelforge/model/_looptree/reuse/isl/mapping_to_isl/analyze_mapping.py +297 -0
  97. accelforge/model/_looptree/reuse/isl/mapping_to_isl/skews_from_mapping.py +236 -0
  98. accelforge/model/_looptree/reuse/isl/mapping_to_isl/tiling.py +685 -0
  99. accelforge/model/_looptree/reuse/isl/mapping_to_isl/types.py +188 -0
  100. accelforge/model/_looptree/reuse/isl/spatial.py +260 -0
  101. accelforge/model/_looptree/reuse/isl/temporal.py +182 -0
  102. accelforge/model/_looptree/reuse/symbolic/__init__.py +1 -0
  103. accelforge/model/_looptree/reuse/symbolic/symbolic copy 2.py +1346 -0
  104. accelforge/model/_looptree/reuse/symbolic/symbolic copy.py +1408 -0
  105. accelforge/model/_looptree/reuse/symbolic/symbolic.py +1396 -0
  106. accelforge/model/_looptree/run.py +122 -0
  107. accelforge/model/_looptree/types.py +26 -0
  108. accelforge/model/_looptree/visualization/__init__.py +0 -0
  109. accelforge/model/_looptree/visualization/occupancy.py +11 -0
  110. accelforge/model/main.py +222 -0
  111. accelforge/plotting/__init__.py +2 -0
  112. accelforge/plotting/mappings.py +219 -0
  113. accelforge/plotting/specs.py +57 -0
  114. accelforge/util/__init__.py +4 -0
  115. accelforge/util/_base_analysis_types.py +24 -0
  116. accelforge/util/_basetypes.py +1089 -0
  117. accelforge/util/_frozenset.py +36 -0
  118. accelforge/util/_isl.py +29 -0
  119. accelforge/util/_itertools.py +14 -0
  120. accelforge/util/_mathfuncs.py +57 -0
  121. accelforge/util/_parse_expressions.py +339 -0
  122. accelforge/util/_picklecache.py +32 -0
  123. accelforge/util/_setexpressions.py +268 -0
  124. accelforge/util/_sympy/__init__.py +0 -0
  125. accelforge/util/_sympy/broadcast_max.py +18 -0
  126. accelforge/util/_visualization.py +112 -0
  127. accelforge/util/_yaml.py +579 -0
  128. accelforge/util/parallel.py +193 -0
  129. accelforge-0.0.1.dist-info/METADATA +64 -0
  130. accelforge-0.0.1.dist-info/RECORD +258 -0
  131. accelforge-0.0.1.dist-info/WHEEL +5 -0
  132. accelforge-0.0.1.dist-info/licenses/LICENSE +19 -0
  133. accelforge-0.0.1.dist-info/top_level.txt +5 -0
  134. docs/_build/html/_sources/fastfusion.frontend.mapper.rst.txt +37 -0
  135. docs/_build/html/_sources/fastfusion.frontend.rst.txt +70 -0
  136. docs/_build/html/_sources/fastfusion.frontend.workload.rst.txt +21 -0
  137. docs/_build/html/_sources/fastfusion.mapper.FFM.rst.txt +37 -0
  138. docs/_build/html/_sources/fastfusion.mapper.rst.txt +18 -0
  139. docs/_build/html/_sources/fastfusion.rst.txt +20 -0
  140. docs/_build/html/_sources/fastfusion.util.rst.txt +21 -0
  141. docs/_build/html/_sources/index.rst.txt +87 -0
  142. docs/_build/html/_sources/modules.rst.txt +7 -0
  143. docs/_build/html/_sources/notes/citation.rst.txt +45 -0
  144. docs/_build/html/_sources/notes/definitions.rst.txt +43 -0
  145. docs/_build/html/_sources/notes/faqs.rst.txt +39 -0
  146. docs/_build/html/_sources/notes/modeling/accelerator_energy_latency.rst.txt +72 -0
  147. docs/_build/html/_sources/notes/modeling/component_energy_area.rst.txt +96 -0
  148. docs/_build/html/_sources/notes/modeling/mapping.rst.txt +100 -0
  149. docs/_build/html/_sources/notes/modeling.rst.txt +33 -0
  150. docs/_build/html/_sources/notes/parsing/arithmetic_parsing.rst.txt +136 -0
  151. docs/_build/html/_sources/notes/parsing/setexpressions.rst.txt +63 -0
  152. docs/_build/html/_sources/notes/parsing/yaml_parsing.rst.txt +176 -0
  153. docs/_build/html/_sources/notes/quickstart_and_installation.rst.txt +9 -0
  154. docs/_build/html/_sources/notes/spec/architecture.rst.txt +133 -0
  155. docs/_build/html/_sources/notes/spec/mapping.rst.txt +12 -0
  156. docs/_build/html/_sources/notes/spec/workload.rst.txt +83 -0
  157. docs/_build/html/_sources/notes/spec.rst.txt +36 -0
  158. docs/source/_ext/include_attrs.py +213 -0
  159. docs/source/_ext/include_docstring.py +364 -0
  160. docs/source/_ext/include_functions.py +154 -0
  161. docs/source/_ext/include_notebook.py +131 -0
  162. docs/source/_ext/include_yaml.py +119 -0
  163. docs/source/_ext/inherited_attributes.py +222 -0
  164. docs/source/_ext/paths.py +4 -0
  165. docs/source/conf.py +79 -0
  166. examples/arches/compute_in_memory/_include.yaml +74 -0
  167. examples/arches/compute_in_memory/_include_functions.py +229 -0
  168. examples/arches/compute_in_memory/_load_spec.py +57 -0
  169. examples/arches/compute_in_memory/components/c2c_multiplier.py +181 -0
  170. examples/arches/compute_in_memory/components/dac_c2c_r2r.py +605 -0
  171. examples/arches/compute_in_memory/components/misc.py +195 -0
  172. examples/arches/compute_in_memory/components/util/bit_functions.py +51 -0
  173. examples/arches/compute_in_memory/components/zero_comparator.py +92 -0
  174. examples/arches/compute_in_memory/isaac.yaml +233 -0
  175. examples/arches/compute_in_memory/memory_cells/ecram_demo.yaml +63 -0
  176. examples/arches/compute_in_memory/memory_cells/rram_example.yaml +63 -0
  177. examples/arches/compute_in_memory/memory_cells/rram_isaac_isca_2016.yaml +64 -0
  178. examples/arches/compute_in_memory/memory_cells/rram_neurosim_default.yaml +63 -0
  179. examples/arches/compute_in_memory/memory_cells/rram_raella_isca_2023.yaml +70 -0
  180. examples/arches/compute_in_memory/memory_cells/rram_wan_nature_2022.yaml +63 -0
  181. examples/arches/compute_in_memory/memory_cells/sram_colonnade_jssc_2021.yaml +63 -0
  182. examples/arches/compute_in_memory/memory_cells/sram_example.yaml +63 -0
  183. examples/arches/compute_in_memory/memory_cells/sram_jia_jssc_2020.yaml +63 -0
  184. examples/arches/compute_in_memory/memory_cells/sram_sinangil_jssc_2021.yaml +63 -0
  185. examples/arches/compute_in_memory/memory_cells/sram_wang_vlsi_2022.yaml +63 -0
  186. examples/arches/compute_in_memory/wang_vlsi_2022.yaml +289 -0
  187. examples/arches/eyeriss.yaml +68 -0
  188. examples/arches/fanout_variations/at_glb.yaml +31 -0
  189. examples/arches/fanout_variations/at_glb_with_fanout_node.yaml +34 -0
  190. examples/arches/fanout_variations/at_mac.yaml +31 -0
  191. examples/arches/fanout_variations/at_mac_with_constraints.yaml +38 -0
  192. examples/arches/fanout_variations/at_mac_with_fanout_node.yaml +34 -0
  193. examples/arches/nvdla.yaml +47 -0
  194. examples/arches/simple.yaml +28 -0
  195. examples/arches/tpu_v4i.yaml +67 -0
  196. examples/mappings/unfused_matmuls_to_simple.yaml +33 -0
  197. examples/misc/component_annotated.yaml +33 -0
  198. examples/workloads/gpt3_6.7B.yaml +124 -0
  199. examples/workloads/matmuls.yaml +20 -0
  200. examples/workloads/mobilenet_28.yaml +81 -0
  201. examples/workloads/mobilenet_various_separate.yaml +106 -0
  202. examples/workloads/three_matmuls_annotated.yaml +59 -0
  203. notebooks/.ipynb_checkpoints/fastfusion_arch_study_michael-checkpoint.ipynb +359 -0
  204. notebooks/compute_in_memory/_scripts.py +339 -0
  205. notebooks/compute_in_memory/isaac.guide.ipynb +270 -0
  206. notebooks/compute_in_memory/wang_vlsi_2022.ipynb +602 -0
  207. notebooks/paths.py +4 -0
  208. notebooks/tutorials/.ipynb_checkpoints/1_FFM-checkpoint.ipynb +3110 -0
  209. notebooks/tutorials/FFM.ipynb +3498 -0
  210. notebooks/tutorials/_include.py +48 -0
  211. notebooks/tutorials/component_energy_area.ipynb +363 -0
  212. tests/Q_mapping.yaml +38 -0
  213. tests/__init__.py +0 -0
  214. tests/conv.mapping.yaml +27 -0
  215. tests/conv.workload.yaml +13 -0
  216. tests/conv_sym.mapping.yaml +43 -0
  217. tests/copy.mapping.yaml +35 -0
  218. tests/copy.workload.yaml +15 -0
  219. tests/distribuffers/__init__.py +0 -0
  220. tests/distribuffers/multicast/test_cases.yaml +482 -0
  221. tests/distribuffers/spec/binding/valid_bindings.yaml +97 -0
  222. tests/distribuffers/spec/distributed.yaml +100 -0
  223. tests/distribuffers/spec/logical_arch.yaml +32 -0
  224. tests/distribuffers/spec/physical_arch.yaml +69 -0
  225. tests/distribuffers/test_binding.py +48 -0
  226. tests/frontend/__init__.py +0 -0
  227. tests/frontend/test_mapping_viz.py +52 -0
  228. tests/mapper/__init__.py +0 -0
  229. tests/mapper/configs/conv1d/conv1d.mapping.yaml +31 -0
  230. tests/mapper/configs/conv1d/conv1d.workload.yaml +11 -0
  231. tests/mapper/configs/two_conv1d/two_conv1d.expected.yaml +38 -0
  232. tests/mapper/configs/two_conv1d/two_conv1d.mapping.yaml +54 -0
  233. tests/mapper/configs/two_conv1d/two_conv1d.workload.yaml +19 -0
  234. tests/mapper/test_mapping_to_isl.py +90 -0
  235. tests/mapper/test_spatial_reuse_analysis.py +67 -0
  236. tests/mapper/test_temporal_reuse_analysis.py +56 -0
  237. tests/mapper/util.py +58 -0
  238. tests/matmul.mapping.yaml +29 -0
  239. tests/matmul.workload.yaml +12 -0
  240. tests/matmul_spatial.mapping.yaml +44 -0
  241. tests/mha.renames.yaml +65 -0
  242. tests/mha.workload.yaml +67 -0
  243. tests/mha.yaml +59 -0
  244. tests/mha_full.workload.yaml +67 -0
  245. tests/mobilenet.workload.yaml +35 -0
  246. tests/mobilenet_long.workload.yaml +64 -0
  247. tests/pmappingcache.py +24 -0
  248. tests/processing_stage.arch.yaml +40 -0
  249. tests/snowcat.arch.yaml +36 -0
  250. tests/test_ffm_join_pmappings.py +106 -0
  251. tests/test_ffm_make_pmappings.py +82 -0
  252. tests/test_ffm_make_tile_shapes.py +49 -0
  253. tests/test_mapper.py +100 -0
  254. tests/test_model.py +37 -0
  255. tests/test_plotting.py +72 -0
  256. tests/test_processing_stage.py +46 -0
  257. tests/test_symbolic_model.py +248 -0
  258. tests/test_workload.py +141 -0
@@ -0,0 +1,1346 @@
1
+ import copy
2
+ from dataclasses import dataclass, field
3
+ import itertools
4
+ from accelforge.frontend.mapping import (
5
+ Compute,
6
+ Mapping,
7
+ Nested,
8
+ Pipeline,
9
+ ProcessingStage,
10
+ Reservation,
11
+ Sequential,
12
+ Spatial,
13
+ Split,
14
+ Storage,
15
+ Temporal,
16
+ )
17
+ from typing import Any
18
+
19
+ from accelforge.frontend import arch
20
+ import accelforge.frontend.mapping as mapping_spec
21
+ from accelforge.frontend.mapping import (
22
+ Mapping,
23
+ MappingNode,
24
+ Nested,
25
+ Spatial,
26
+ Temporal,
27
+ Storage,
28
+ Reservation,
29
+ Loop,
30
+ TensorHolder,
31
+ ProcessingStage,
32
+ )
33
+ from accelforge.frontend.workload import (
34
+ Workload,
35
+ TensorName,
36
+ isl_expression_has_variable,
37
+ )
38
+ from accelforge.frontend._workload_isl._isl import get_rank_variable_bounds
39
+ from accelforge.frontend._workload_isl._symbolic import (
40
+ get_projection_expr,
41
+ get_rank_variable_relevancy,
42
+ compute_dense_tile_occupancy,
43
+ Irrelevant,
44
+ Relevant,
45
+ PartiallyRelevant,
46
+ )
47
+
48
+ from accelforge.model._looptree.types import Buffet
49
+
50
+ from accelforge.mapper.FFM._make_pmappings.pmapper_job import Job
51
+ from accelforge.util._sympy.broadcast_max import Min, Max
52
+
53
+ import sympy
54
+
55
+
56
+ SYMBOL = "symbol"
57
+ IMPERFECT = False
58
+
59
+
60
+ @dataclass(eq=True, frozen=True)
61
+ class Compute:
62
+ einsum: str
63
+ level: str
64
+
65
+
66
+ class Uninitialized:
67
+ def __init__(self):
68
+ pass
69
+
70
+ def __str__(self):
71
+ return "Uninitialized"
72
+
73
+ def __repr__(self):
74
+ return "Uninitialized()"
75
+
76
+ def __rmul__(self, other):
77
+ return self * other
78
+
79
+ def __mul__(self, other):
80
+ return self
81
+
82
+ def __radd__(self, other):
83
+ return self + other
84
+
85
+ def __add__(self, other):
86
+ return self
87
+
88
+
89
+ # TODO: unsure if this is needed. If the sympy symbol is created with the
90
+ # correct assumption (e.g., positive), this should be automatic.
91
+ def min_nonzero(a: Any, b: Any) -> Any:
92
+ if a == 0:
93
+ return b
94
+ if b == 0:
95
+ return a
96
+ return Min(a, b)
97
+
98
+
99
+ def max_dict(a: dict[Any, Any], b: dict[Any, Any]) -> dict[Any, Any]:
100
+ new = {**a}
101
+ for key, value in b.items():
102
+ new[key] = Max(new[key], value) if key in new else value
103
+ assert isinstance(new, dict)
104
+ return new
105
+
106
+
107
+ @dataclass
108
+ class BuffetStats:
109
+ total_reads_to_parent: Any = field(default=0)
110
+ total_writes_to_parent: Any = field(default=0)
111
+ max_per_parent_reads_to_parent: Any = field(default=0)
112
+ max_per_parent_writes_to_parent: Any = field(default=0)
113
+
114
+ total_reads_to_peer: Any = field(default=0)
115
+ total_writes_to_peer: Any = field(default=0)
116
+ max_per_unit_reads_to_peer: Any = field(default=0)
117
+ max_per_unit_writes_to_peer: Any = field(default=0)
118
+
119
+ # Skip the first iteration of temporal loops for data that is written
120
+ total_skipped_first_reads_to_parent: Any = field(default=0)
121
+ total_skipped_first_reads_to_peer: Any = field(default=0)
122
+ min_per_parent_skipped_first_reads_to_parent: Any = field(default=0)
123
+ min_per_unit_skipped_first_writes_to_peer: Any = field(default=0)
124
+
125
+ max_occupancy: Any = field(default=0)
126
+ _n_loops_above: int = field(default=0)
127
+
128
+ # These are used to calculate energy and latency
129
+ total_write_actions: Any = field(default=0)
130
+ max_per_unit_write_actions: Any = field(default=0)
131
+ total_read_actions: Any = field(default=0)
132
+ max_per_unit_read_actions: Any = field(default=0)
133
+
134
+ total_skipped_first_write_actions: Any = field(default=0)
135
+ min_per_unit_skipped_first_write_actions: Any = field(default=0)
136
+ total_skipped_first_read_actions: Any = field(default=0)
137
+ min_per_unit_skipped_first_read_actions: Any = field(default=0)
138
+
139
+ persistent: bool = field(default=False)
140
+
141
+ @property
142
+ def n_loops_above(self) -> int:
143
+ if self.persistent:
144
+ return -1
145
+ return self._n_loops_above
146
+
147
+ @n_loops_above.setter
148
+ def n_loops_above(self, value: int):
149
+ self._n_loops_above = value
150
+
151
+ def repeat_temporal(self, factor: int, is_fully_relevant: bool) -> "BuffetStats":
152
+ new = copy.copy(self)
153
+ for attr in self.__dict__:
154
+ if not attr.startswith(("total_", "max_", "min_")):
155
+ continue
156
+ if "skipped_first" in attr and not is_fully_relevant:
157
+ continue # First actions occur once per relevant iteration.
158
+ if attr == "max_occupancy":
159
+ continue # Max occupancy is not affected by temporal loops above
160
+ setattr(new, attr, getattr(new, attr) * factor)
161
+ return new
162
+
163
+ def repeat_spatial(self, factor: int, reuse_parent_accesses: bool) -> "BuffetStats":
164
+ new = copy.copy(self)
165
+ for attr in self.__dict__:
166
+ if not attr.startswith(("total_", "max_", "min_")):
167
+ continue
168
+ if "parent" in attr and reuse_parent_accesses:
169
+ continue # If parent accesses are reused, no need to multiply
170
+ if "per_unit" in attr:
171
+ continue # Spatial fanout doesn't affect per-unit stats
172
+ if attr == "max_occupancy":
173
+ continue # Max occupancy is not affected by temporal loops above
174
+ setattr(new, attr, getattr(new, attr) * factor)
175
+ return new
176
+
177
+ def max(self, **kwargs: Any):
178
+ for key, value in kwargs.items():
179
+ setattr(self, key, Max(getattr(self, key), value))
180
+
181
+ def min(self, **kwargs: Any):
182
+ for key, value in kwargs.items():
183
+ setattr(self, key, Min(getattr(self, key), value))
184
+
185
+ def __add__(self, other: "BuffetStats") -> "BuffetStats":
186
+ new = copy.copy(self)
187
+ for attr in self.__dict__:
188
+ if attr.startswith("min_"):
189
+ setattr(
190
+ new, attr, min_nonzero(getattr(self, attr), getattr(other, attr))
191
+ )
192
+ elif attr.startswith("max_"):
193
+ setattr(new, attr, Max(getattr(self, attr), getattr(other, attr)))
194
+ elif attr.startswith("total_"):
195
+ setattr(new, attr, getattr(self, attr) + getattr(other, attr))
196
+ elif getattr(self, attr) is None:
197
+ setattr(new, attr, getattr(other, attr))
198
+ elif getattr(other, attr) is None:
199
+ setattr(new, attr, getattr(self, attr))
200
+ else:
201
+ assert getattr(self, attr) == getattr(
202
+ other, attr
203
+ ), f"BUG: {attr} is different. self: {getattr(self, attr)} other: {getattr(other, attr)}"
204
+ return new
205
+
206
+ def __iadd__(self, other: "BuffetStats") -> "BuffetStats":
207
+ new = self + other
208
+ for key, value in new.__dict__.items():
209
+ setattr(self, key, value)
210
+ return self
211
+
212
+ def net_total_read_actions(self) -> Any:
213
+ return self.total_read_actions - self.total_skipped_first_read_actions
214
+
215
+ def net_total_write_actions(self) -> Any:
216
+ return self.total_write_actions - self.total_skipped_first_write_actions
217
+
218
+ def net_max_per_unit_read_actions(self) -> Any:
219
+ return (
220
+ self.max_per_unit_read_actions
221
+ - self.min_per_unit_skipped_first_read_actions
222
+ )
223
+
224
+ def net_max_per_unit_write_actions(self) -> Any:
225
+ return (
226
+ self.max_per_unit_write_actions
227
+ - self.min_per_unit_skipped_first_write_actions
228
+ )
229
+
230
+
231
+ def blank_buffet_stats() -> BuffetStats:
232
+ stats = BuffetStats()
233
+ stats.n_loops_above = None # Inherit from whoever is added to this
234
+ return stats
235
+
236
+
237
+ @dataclass
238
+ class ComputeStats:
239
+ total_ops: Any = field(default=0)
240
+ max_per_unit_ops: Any = field(default=0)
241
+ # "max" below refers to the longest latency of any iteration
242
+ max_latency: Any = field(default=0)
243
+ # Mapping from the loop-index (0 at top) to the latency of the first
244
+ # iteration of that loop. "Max" because we may have loops above that and we
245
+ # will take the maximum of the firsts.
246
+ max_first_latency: dict[int, Any] = field(default_factory=dict)
247
+
248
+ def repeat_temporal(self, factor: int) -> "ComputeStats":
249
+ new = copy.copy(self)
250
+ new.total_ops *= factor
251
+ new.max_per_unit_ops *= factor
252
+ new.max_latency *= factor
253
+ # NOTE: max_first_latency does not change
254
+ return new
255
+
256
+ def repeat_spatial(self, factor: int) -> "ComputeStats":
257
+ new = copy.copy(self)
258
+ new.total_ops *= factor
259
+ return new
260
+
261
+ def __add__(self, other: "ComputeStats") -> "ComputeStats":
262
+ new = copy.copy(self)
263
+ new.total_ops += other.total_ops
264
+ new.max_per_unit_ops += other.max_per_unit_ops
265
+ new.max_latency += other.max_latency
266
+ # max_first_latency is only ever updated across loops ABOVE the loop
267
+ # for which we calculated that first latency, so we should MAX
268
+ new.max_first_latency = max_dict(
269
+ self.max_first_latency, other.max_first_latency
270
+ ) # FIRST LATENCY
271
+ return new
272
+
273
+ def combine_temporal(self, other: "ComputeStats"):
274
+ self.total_ops += other.total_ops
275
+ self.max_per_unit_ops += other.max_per_unit_ops
276
+ self.max_latency += other.max_latency
277
+ # max_first_latency is only ever updated across loops ABOVE the loop
278
+ # for which we calculated that first latency, so we should MAX
279
+ self.max_first_latency = max_dict(
280
+ self.max_first_latency, other.max_first_latency
281
+ ) # FIRST LATENCY
282
+
283
+ def combine_spatial(self, other: "ComputeStats"):
284
+ self.total_ops += other.total_ops
285
+ self.max_per_unit_ops = Max(self.max_per_unit_ops, other.max_per_unit_ops)
286
+ self.max_latency = Max(self.max_latency, other.max_latency)
287
+ # max_first_latency is only ever updated across loops ABOVE the loop
288
+ # for which we calculated that first latency, so we should MAX
289
+ self.max_first_latency = max_dict(
290
+ self.max_first_latency, other.max_first_latency
291
+ ) # FIRST LATENCY
292
+
293
+
294
+ @dataclass
295
+ class SymbolicAnalysisOutput:
296
+ compute_stats: dict[Compute, ComputeStats] = field(default_factory=dict)
297
+
298
+ buffet_stats: dict[Buffet, BuffetStats] = field(default_factory=dict)
299
+
300
+ # Mapping [level, einsum] to the fanout
301
+ fanout: dict[(Buffet, str), int] = field(default_factory=dict)
302
+
303
+ # Mapping [einsum] to the number of temporal steps
304
+ temporal_steps: dict[str, int] = field(default_factory=dict)
305
+
306
+ symbols: list[sympy.Symbol] = field(default_factory=list)
307
+
308
+ # tensor to the mapping for that particular tensor
309
+ tensor2mapping: dict[TensorName, Mapping] = field(default_factory=dict)
310
+
311
+ def get_buffet_for_tensor(self, tensor: TensorName) -> Buffet:
312
+ for buffet in self.buffet_stats:
313
+ if buffet.tensor == tensor:
314
+ return buffet
315
+ raise ValueError(f"Buffet for tensor {tensor} not found")
316
+
317
+ def max(self, **kwargs: Any):
318
+ for key, value in kwargs.items():
319
+ assert key in [
320
+ "compute_stats",
321
+ "stats",
322
+ "fanout",
323
+ "temporal_steps",
324
+ ]
325
+ previous = getattr(self, key)
326
+ for k, v in value.items():
327
+ previous.setdefault(k, {})
328
+ for k2, v2 in v.items():
329
+ if k2 in previous[k]:
330
+ previous[k][k2] = Max(previous[k][k2], v2)
331
+ else:
332
+ previous[k][k2] = v2
333
+
334
+ def get_child_buffet_stats(self, buffet: Buffet) -> BuffetStats:
335
+ seen = False
336
+ for child_buffet, child_stats in reversed(self.buffet_stats.items()):
337
+ if not seen:
338
+ seen = child_buffet == buffet
339
+ continue
340
+ if child_buffet.tensor == buffet.tensor:
341
+ return child_stats
342
+ return None
343
+
344
+ def sum_buffet_stats_per_level(self) -> dict[str, BuffetStats]:
345
+ result: dict[str, BuffetStats] = {}
346
+ for buffet, stats in self.buffet_stats.items():
347
+ result.setdefault(buffet.level, blank_buffet_stats())
348
+ result[buffet.level] += stats
349
+ return result
350
+
351
+ def add_buffet_stats_and_symbols(self, other: "SymbolicAnalysisOutput"):
352
+ assert not (set(self.buffet_stats) & set(other.buffet_stats)), "BUG"
353
+ self.buffet_stats.update(other.buffet_stats)
354
+ # if self.temporal_steps != other.temporal_steps:
355
+ # print(f'Temporal steps are different.')
356
+ # print(f'\tmine: {self.temporal_steps}')
357
+ # print(f'\tother: {other.temporal_steps}')
358
+ # assert self.temporal_steps == other.temporal_steps, "BUG"
359
+ self.temporal_steps.update(other.temporal_steps)
360
+ self.symbols.extend([s for s in other.symbols if s not in self.symbols])
361
+ # Assert compute stats are the same
362
+ # assert self.compute_stats == other.compute_stats, "BUG"
363
+
364
+
365
+ @dataclass
366
+ class AnalysisInfo:
367
+ """Information needed within the analysis step by multiple functions that
368
+ can be computed once at the beginning.
369
+ """
370
+
371
+ mapping: Mapping
372
+ workload: Workload
373
+ full_rank_variable_shapes: dict
374
+ all_tensors: set
375
+
376
+ einsum_tensor_to_projection: dict
377
+ tensor_to_relevancy: dict
378
+ tensor_to_backer_id: dict[TensorName, int]
379
+
380
+ is_copy_operation: TensorName | None
381
+
382
+ job: Job
383
+
384
+ tensor_to_reservation_backer_id: dict[TensorName, int] = field(default_factory=dict)
385
+
386
+ # We track first latency for these nodes (should be Temporal)
387
+ last_temporal_node_idx: int = None
388
+ """
389
+ node idx of the last (above) temporal node
390
+ """
391
+ idxs_to_track_first_latency: set[int] = field(default_factory=set)
392
+ """
393
+ node idxs for which we track first latency
394
+ """
395
+
396
+
397
+ def quick_insert_reservation_nodes(job: Job) -> list[MappingNode]:
398
+ mapping = list(job.mapping.nodes)
399
+ workload = job.spec.workload
400
+
401
+ # TODO: Subclass reservation with TensorReservation or something so that we can
402
+ # track which are for tensors and which are for non-tensor resources.
403
+
404
+ info = AnalysisInfo(
405
+ mapping=None,
406
+ workload=workload,
407
+ full_rank_variable_shapes=None,
408
+ all_tensors=None,
409
+ einsum_tensor_to_projection=None,
410
+ tensor_to_relevancy=job.tensor_to_relevancy,
411
+ tensor_to_backer_id=None,
412
+ is_copy_operation=None,
413
+ job=None,
414
+ )
415
+ insert_reservation_nodes(mapping, info)
416
+ m = Mapping(nodes=mapping)
417
+ m._n_loop_orders = job.mapping._n_loop_orders
418
+ return m
419
+
420
+
421
+ def convert_to_copy(
422
+ mapping: list[MappingNode], workload: Workload
423
+ ) -> tuple[list[MappingNode], dict[TensorName, int]]:
424
+ mapping = copy.deepcopy(mapping)
425
+
426
+ # Calculate this BEFORE we modify the mapping. We're going to have the copy source
427
+ # tensor moving upward sometimes, and we don't want the backing tensor holder
428
+ tensor_to_backer_id = get_tensor_to_backer_id(mapping)
429
+
430
+ first_input_tensor = workload.einsums[mapping[-1].einsum].copy_source_tensor()
431
+
432
+ for node in mapping:
433
+ if isinstance(node, TensorHolder):
434
+ if node.tensors:
435
+ node.tensors = [first_input_tensor]
436
+ node._lower = False
437
+
438
+ to_remove = []
439
+ i = 0
440
+ while i < len(mapping):
441
+ node = mapping[i]
442
+ if isinstance(node, TensorHolder):
443
+ j = i + 1
444
+ while j < len(mapping):
445
+ node2 = mapping[j]
446
+ if (
447
+ isinstance(node2, TensorHolder)
448
+ and node.component == node2.component
449
+ ):
450
+ mapping.pop(j)
451
+ else:
452
+ j += 1
453
+ i += 1
454
+ mapping = [node for node in mapping if node not in to_remove]
455
+
456
+ return mapping, tensor_to_backer_id
457
+
458
+
459
+ def analyze_reuse_and_add_reservations_to_mapping(
460
+ job: Job,
461
+ ) -> SymbolicAnalysisOutput:
462
+ mapping = job.mapping.nodes
463
+ workload = job.spec.workload
464
+ einsum_name = mapping[-1].einsum
465
+
466
+ is_copy_operation = workload.einsums[einsum_name].is_copy_operation
467
+ symbols = insert_sympy_symbols(job.mapping.nodes, job)
468
+
469
+ if is_copy_operation:
470
+ mapping, tensor_to_backer_id = convert_to_copy(mapping, workload)
471
+ else:
472
+ tensor_to_backer_id = get_tensor_to_backer_id(mapping)
473
+
474
+ job.mapping = quick_insert_reservation_nodes(job)
475
+ # print(f'Job mapping: {job.mapping.compact_str()}')
476
+ # for n in job.mapping.nodes:
477
+ # print(f'\t{n.compact_str()}')
478
+
479
+ einsum_tensor_to_projection = {}
480
+ einsum = workload.einsums[einsum_name]
481
+ all_tensors = einsum.tensor_names
482
+ for tensor in all_tensors:
483
+ einsum_tensor_to_projection[(einsum_name, tensor)] = get_projection_expr(
484
+ einsum, tensor
485
+ )
486
+ tensor_to_relevancy = {
487
+ tensor: get_rank_variable_relevancy(einsum, tensor) for tensor in all_tensors
488
+ }
489
+ assert all_tensors, f"Einsum {einsum_name} has no tensors"
490
+
491
+ """
492
+ Note for how this works.
493
+
494
+ Spatial loops are weird, because they don't belong at a single point in the loop
495
+ nest. For example:
496
+
497
+ - DRAM keep A, B
498
+ - *
499
+ - Reg keep A
500
+ - for n in [0..N)
501
+ - GLB keep B
502
+ - *
503
+ - Compute
504
+
505
+ A loop spatial-for (Reg) k in [0..K) would affect the register at the point of the
506
+ first asterisk, but the global buffer at the point of the second asterisk.
507
+
508
+ To handle this, we make a separate mapping for each tensor, analyze each, and
509
+ combine the results.
510
+
511
+ To anyone who would like to create behavior that simultaneously looks at multiple
512
+ storage nodes for a given memory, note that there will be two challenges to address:
513
+
514
+ 1. The code currently analyzes one tensor at a time. This could be fixed by
515
+ processing all mapping(s) together, applying loop(s) from each to only the
516
+ appropriate nodes.
517
+ 2. The code must analyze one storage node at a time, and there may be temporal and
518
+ spatial nodes between two storage nodes for a given memory, which would separate
519
+ the analysis steps for the storage nodes. This may be addressed by only
520
+ performing such analysis until the outermost storage node for a particular memory
521
+ has been analyzed.
522
+ """
523
+ result = None
524
+
525
+ tensor2mapping = {}
526
+ index_expressions = set(einsum.indexing_expressions)
527
+ for k, v in job.rank_variable_bounds.items():
528
+ index_expressions.add(f"0 < {k} <= {v}")
529
+ for tensor in all_tensors:
530
+ cur_mapping = job.mapping._get_single_tensor_mapping(
531
+ tensor, job.flattened_arch, index_expressions
532
+ )
533
+ info = AnalysisInfo(
534
+ mapping=cur_mapping.nodes,
535
+ workload=workload,
536
+ full_rank_variable_shapes=job.rank_variable_bounds,
537
+ all_tensors=set([tensor]),
538
+ einsum_tensor_to_projection=einsum_tensor_to_projection,
539
+ tensor_to_relevancy=tensor_to_relevancy,
540
+ tensor_to_backer_id=tensor_to_backer_id,
541
+ is_copy_operation=is_copy_operation,
542
+ job=job,
543
+ )
544
+ cur_result = analyze_node(0, job.rank_variable_bounds, info)
545
+ if result is None:
546
+ result = cur_result
547
+ else:
548
+ result.add_buffet_stats_and_symbols(cur_result)
549
+ tensor2mapping[tensor] = cur_mapping
550
+
551
+ result.symbols = symbols
552
+ result.tensor2mapping = tensor2mapping
553
+ return result
554
+
555
+
556
+ def get_tensor_to_backer_id(mapping: Mapping):
557
+ tensor_to_ids: dict[TensorName, set[int]] = {}
558
+ for node in mapping:
559
+ if isinstance(node, TensorHolder):
560
+ for tensor in node.tensors:
561
+ if tensor in tensor_to_ids:
562
+ continue
563
+ tensor_to_ids[tensor] = id(node)
564
+ return tensor_to_ids
565
+
566
+
567
+ class ReservationAnalysisTracker:
568
+ def __init__(self, buffet, node):
569
+ self.buffet: Buffet = buffet
570
+ self.node: TensorHolder = node
571
+
572
+ # These are interface (TODO: should be property)
573
+ self.is_fill_level = False
574
+ self.should_stop = False
575
+ self.insert_reservation_under = False
576
+ self.insert_fill_under = False
577
+
578
+ # Temporary values
579
+ self.has_filled = False
580
+
581
+ def track_temporal_loop(self, relevancy, node):
582
+ self.is_fill_level = False
583
+ self.insert_reservation_under = False
584
+ self.insert_fill_under = False
585
+
586
+ if isinstance(relevancy, Irrelevant):
587
+ if not self.has_filled:
588
+ self.is_fill_level = True
589
+ self.has_filled = True
590
+
591
+ self.should_stop = True
592
+ elif isinstance(relevancy, Relevant):
593
+ self.should_stop = False
594
+ elif isinstance(relevancy, PartiallyRelevant):
595
+ self.last = True
596
+
597
+ if not self.has_filled:
598
+ self.is_fill_level = True
599
+ self.has_filled = True
600
+
601
+ self.should_stop = True
602
+ self.insert_reservation_under = True
603
+ else:
604
+ raise ValueError(f"Unknown relevancy {relevancy}")
605
+
606
+ def track_compute(self):
607
+ self.should_stop = True
608
+ if not self.has_filled:
609
+ self.is_fill_level = True
610
+ self.has_filled = True
611
+
612
+ def track_spatial_loop(self, relevancy, node):
613
+ if node.component != self.buffet.level:
614
+ self.should_stop = True
615
+ if not self.has_filled:
616
+ self.is_fill_level = True
617
+ self.has_filled = True
618
+ return
619
+
620
+ self.is_fill_level = False
621
+ self.should_stop = False
622
+
623
+
624
+ def insert_reservation_nodes(mapping, info: AnalysisInfo):
625
+ trackers: list[ReservationAnalysisTracker] = []
626
+ einsum = info.workload.einsums[mapping[-1].einsum]
627
+ non_intermediate_tensors = (
628
+ einsum.tensor_names - info.workload.tensor_names_used_in_multiple_einsums
629
+ )
630
+ seen_tensors = set() # reservation for top-level buffets cannot be lowered
631
+
632
+ n_nodes = len(mapping)
633
+ i = 0
634
+ while i < n_nodes:
635
+ node = mapping[i]
636
+ to_remove = []
637
+ if isinstance(node, Reservation):
638
+ pass
639
+ elif isinstance(node, Temporal):
640
+ rank = node.rank_variable
641
+ for tracker in trackers:
642
+ relevancy = info.tensor_to_relevancy[tracker.buffet.tensor]
643
+ tracker.track_temporal_loop(relevancy[rank], node)
644
+ elif isinstance(node, Spatial):
645
+ rank = node.rank_variable
646
+ for tracker in trackers:
647
+ relevancy = info.tensor_to_relevancy[tracker.buffet.tensor]
648
+ tracker.track_spatial_loop(relevancy[rank], node)
649
+ elif isinstance(node, TensorHolder):
650
+ for tracker in trackers:
651
+ tracker.should_stop = True
652
+ tracker.insert_reservation_under = False
653
+ for tensor in node.tensors:
654
+ tensor = TensorName(tensor)
655
+ buffet = Buffet(tensor, mapping[-1].einsum, node.component)
656
+ trackers.append(ReservationAnalysisTracker(buffet, node))
657
+ if not node._lower or (
658
+ tensor not in seen_tensors and tensor in non_intermediate_tensors
659
+ ):
660
+ seen_tensors.add(tensor)
661
+ trackers[-1].is_fill_level = True
662
+ trackers[-1].insert_reservation_under = True
663
+ trackers[-1].insert_fill_under = True
664
+ trackers[-1].should_stop = True
665
+ elif isinstance(node, mapping_spec.Compute):
666
+ for tracker in trackers:
667
+ tracker.track_compute()
668
+ tracker.insert_reservation_under = False
669
+ else:
670
+ raise NotImplementedError(f"Unknown node type {type(node)}")
671
+
672
+ reservation_insert_below = []
673
+ reservation_insert_above = []
674
+ for j in range(len(trackers) - 1, -1, -1):
675
+ if not trackers[j].should_stop:
676
+ continue
677
+ tracker = trackers.pop(j)
678
+ buffet = tracker.buffet
679
+ node = Reservation(purposes=[buffet.tensor], resource=buffet.level)
680
+ node.persistent = tracker.node.persistent
681
+ node._backing = tracker.node._backing
682
+
683
+ if (
684
+ buffet.tensor not in info.tensor_to_reservation_backer_id
685
+ and buffet.tensor in info.workload.tensor_names_used_in_multiple_einsums
686
+ ):
687
+ info.tensor_to_reservation_backer_id[buffet.tensor] = id(node)
688
+
689
+ if tracker.insert_reservation_under:
690
+ reservation_insert_below.append(node)
691
+ else:
692
+ reservation_insert_above.append(node)
693
+
694
+ # The order of these for loops is important. Reservation must be below fill.
695
+ for node in reservation_insert_below:
696
+ mapping.insert(i + 1, node)
697
+ i += 1
698
+ for node in reservation_insert_above:
699
+ mapping.insert(i, node)
700
+ i += 1
701
+
702
+ i += 1
703
+ n_nodes = len(mapping)
704
+
705
+ label_fused_loops(mapping)
706
+
707
+
708
+ def label_fused_loops(mapping: list[MappingNode]):
709
+ last_backer = None
710
+ for i, node in enumerate(mapping):
711
+ if isinstance(node, Reservation) and node._backing:
712
+ last_backer = i
713
+ if last_backer is None:
714
+ raise ValueError(
715
+ f"No backing TensorHolder found in mapping {", ".join(m.compact_str() for m in mapping)}"
716
+ )
717
+
718
+ for i, node in enumerate(mapping):
719
+ if isinstance(node, Loop):
720
+ node._fused = i < last_backer
721
+ return mapping
722
+
723
+
724
+ def analyze_node(node_idx, current_shape, info: AnalysisInfo) -> SymbolicAnalysisOutput:
725
+ node = info.mapping[node_idx]
726
+ class2analysis_function = {
727
+ Temporal: analyze_temporal,
728
+ Spatial: analyze_spatial,
729
+ Storage: analyze_storage,
730
+ Reservation: analyze_reservation,
731
+ mapping_spec.Compute: analyze_compute,
732
+ ProcessingStage: analyze_processing_stage,
733
+ }
734
+ if type(node) not in class2analysis_function:
735
+ raise TypeError(f"Unknown node type {type(node)}")
736
+ return class2analysis_function[type(node)](node_idx, current_shape, info)
737
+
738
+
739
+ def analyze_temporal(
740
+ node_idx, current_shape, info: AnalysisInfo
741
+ ) -> SymbolicAnalysisOutput:
742
+ mapping = info.mapping
743
+ node = mapping[node_idx]
744
+ stride_and_shape = get_stride_and_tile_shape(node, current_shape, node_idx, info)
745
+
746
+ result_accumulator = SymbolicAnalysisOutput()
747
+
748
+ first_latency = None
749
+
750
+ def handle_repeated_value(repeated_shape):
751
+ nonlocal first_latency
752
+ shape_value = repeated_shape.value
753
+ shape_repeats = repeated_shape.repeats
754
+
755
+ child_shape = current_shape.copy()
756
+ child_shape[node.rank_variable] = shape_value
757
+
758
+ child_result = analyze_node(node_idx + 1, child_shape, info)
759
+
760
+ accumulated_buffet_stats = result_accumulator.buffet_stats
761
+ for buffet, stats in child_result.buffet_stats.items():
762
+ relevancy = info.tensor_to_relevancy[buffet.tensor][node.rank_variable]
763
+ is_fully_relevant = isinstance(relevancy, Relevant)
764
+ accumulated_stats = accumulated_buffet_stats.setdefault(
765
+ buffet, blank_buffet_stats()
766
+ )
767
+ accumulated_stats += stats.repeat_temporal(
768
+ shape_repeats, is_fully_relevant=is_fully_relevant
769
+ )
770
+ accumulated_stats.n_loops_above = stats.n_loops_above + 1
771
+
772
+ for einsum, child_steps in child_result.temporal_steps.items():
773
+ if einsum not in result_accumulator.temporal_steps:
774
+ result_accumulator.temporal_steps[einsum] = 0
775
+ result_accumulator.temporal_steps[einsum] += child_steps * shape_repeats
776
+
777
+ result_accumulator.max(fanout=child_result.fanout)
778
+
779
+ for key in child_result.compute_stats:
780
+ if first_latency is None:
781
+ first_latency = child_result.compute_stats[key].max_latency
782
+
783
+ compute_stats = result_accumulator.compute_stats.setdefault(
784
+ key, ComputeStats()
785
+ )
786
+ compute_stats += child_result.compute_stats[key].repeat_temporal(
787
+ shape_repeats
788
+ )
789
+ result_accumulator.compute_stats[key] = compute_stats
790
+
791
+ info.last_temporal_node_idx = node_idx
792
+
793
+ shape = stride_and_shape.shape
794
+ if isinstance(shape, SequenceOfRepatedvalues):
795
+ for repeated_shape in shape.sequence:
796
+ assert isinstance(repeated_shape, RepeatedValue)
797
+ handle_repeated_value(repeated_shape)
798
+ elif isinstance(shape, RepeatedValue):
799
+ handle_repeated_value(shape)
800
+
801
+ if node_idx in info.idxs_to_track_first_latency:
802
+ for compute_stat in result_accumulator.compute_stats.values():
803
+ # Should be the first time we store this value
804
+ assert node_idx not in compute_stat.max_first_latency
805
+ compute_stat.max_first_latency[node_idx] = first_latency
806
+
807
+ return result_accumulator
808
+
809
+
810
+ def analyze_spatial(node_idx, current_shape, info: AnalysisInfo):
811
+ mapping = info.mapping
812
+ einsum_name = mapping[-1].einsum
813
+ node: Spatial = mapping[node_idx]
814
+ rank_var = node.rank_variable
815
+ node_dim = node.name
816
+ stride_and_shape = get_stride_and_tile_shape(node, current_shape, node_idx, info)
817
+
818
+ result_accumulator = SymbolicAnalysisOutput()
819
+
820
+ def handle_repeated_value(repeated_shape):
821
+ shape_value = repeated_shape.value
822
+ shape_repeats = repeated_shape.repeats
823
+
824
+ child_shape = current_shape.copy()
825
+ child_shape[node.rank_variable] = shape_value
826
+
827
+ child_result = analyze_node(node_idx + 1, child_shape, info)
828
+
829
+ accumulated_buffet_stats = result_accumulator.buffet_stats
830
+ child_stats = list(child_result.buffet_stats.items())
831
+ for i, (buffet, buffet_stats) in enumerate(child_stats):
832
+ stats = buffet_stats
833
+ accumulated_stats = accumulated_buffet_stats.setdefault(
834
+ buffet, blank_buffet_stats()
835
+ )
836
+ relevancy = info.tensor_to_relevancy[buffet.tensor][rank_var]
837
+
838
+ # Reuse parent accesses only:
839
+ # - Irrelevant loops
840
+ # - The outermost level that holds the tensor (the one whose parent accesses
841
+ # will be going through the network)
842
+ last_buffet = True
843
+ for other_buffet, _ in child_stats[i + 1 :]:
844
+ if other_buffet.tensor == buffet.tensor:
845
+ last_buffet = False
846
+ break
847
+
848
+ reuse_parent_accesses = (
849
+ last_buffet
850
+ and isinstance(relevancy, Irrelevant)
851
+ and buffet.tensor in node._may_reuse
852
+ )
853
+
854
+ accumulated_stats += stats.repeat_spatial(
855
+ shape_repeats, reuse_parent_accesses=reuse_parent_accesses
856
+ )
857
+ accumulated_stats.n_loops_above = stats.n_loops_above + 1
858
+
859
+ for einsum, child_steps in child_result.temporal_steps.items():
860
+ if einsum not in result_accumulator.temporal_steps:
861
+ result_accumulator.temporal_steps[einsum] = child_steps
862
+ else:
863
+ result_accumulator.temporal_steps[einsum] = Max(
864
+ result_accumulator.temporal_steps[einsum], child_steps
865
+ )
866
+
867
+ my_key = (node.component, einsum_name)
868
+ child_result.fanout.setdefault(my_key, {})
869
+
870
+ # Propagate up everything except the current level and dimension
871
+ child_fanout = copy.deepcopy(child_result.fanout)
872
+ target_fanout = child_fanout[my_key].pop(node_dim, 1)
873
+ result_accumulator.max(fanout=child_fanout)
874
+
875
+ # Prpoagate current level and dimension * shape_repeats
876
+ child_fanout = child_result.fanout[my_key]
877
+ fanout = result_accumulator.fanout.setdefault(my_key, {})
878
+ fanout.setdefault(node_dim, 0) # TODO: Assume sympy can just take in 0
879
+ # TODO: If node_dim was missing, the original code would have omitted
880
+ # shape_repeats. Is this correct?
881
+ fanout[node_dim] += target_fanout * shape_repeats
882
+
883
+ for key in child_result.compute_stats:
884
+ # TODO: ensure that `ComputeStats()`, which is initialized ONCE, is okay to use here
885
+ compute_stats = result_accumulator.compute_stats.setdefault(
886
+ key, ComputeStats()
887
+ )
888
+ # TODO: If check omitted. This was in the original code, check history if needed.
889
+ compute_stats.combine_spatial(
890
+ child_result.compute_stats[key].repeat_spatial(shape_repeats)
891
+ )
892
+
893
+ shape = stride_and_shape.shape
894
+ if isinstance(shape, SequenceOfRepatedvalues):
895
+ for repeated_shape in shape.sequence:
896
+ assert isinstance(repeated_shape, RepeatedValue)
897
+ handle_repeated_value(repeated_shape)
898
+ elif isinstance(shape, RepeatedValue):
899
+ handle_repeated_value(shape)
900
+
901
+ return result_accumulator
902
+
903
+
904
+ def reduce_dicts(dict1: dict, dict2: dict, reduce_op):
905
+ for key in dict1:
906
+ if key not in dict2:
907
+ dict2[key] = dict1[key]
908
+ else:
909
+ dict2[key] = reduce_op(dict1[key], dict2[key])
910
+
911
+
912
+ def get_total_to_per_unit(total, max_per_unit):
913
+ if total == 0 and max_per_unit != 0:
914
+ raise ValueError(f"total is 0 but max_per_unit is {max_per_unit}")
915
+ if total == 0:
916
+ return 1
917
+ return max_per_unit / total
918
+
919
+
920
+ def has_parent_tensor_holder(
921
+ tensor: TensorName, node_idx: int, info: AnalysisInfo
922
+ ) -> bool:
923
+ for node in info.mapping[:node_idx]:
924
+ if isinstance(node, TensorHolder) and tensor in node.tensors:
925
+ return True
926
+ return False
927
+
928
+
929
+ def find_component_object(
930
+ component: str, flattened_arch: list[arch.Leaf]
931
+ ) -> arch.TensorHolder:
932
+ for node in flattened_arch:
933
+ if node.name == component:
934
+ return node
935
+ raise ValueError(f"Component {component} not found in flattened arch")
936
+
937
+
938
+ def analyze_storage(
939
+ node_idx: int,
940
+ current_shape: dict[str, int],
941
+ info: AnalysisInfo,
942
+ propagate_child_results: bool = False,
943
+ count_upward_movement: bool = True,
944
+ count_downward_movement: bool = True,
945
+ count_writes: bool = True,
946
+ ):
947
+ mapping = info.mapping
948
+ einsum_name = mapping[-1].einsum
949
+ node: TensorHolder = mapping[node_idx]
950
+
951
+ child_result = analyze_node(node_idx + 1, current_shape, info)
952
+
953
+ for tensor in node.tensors:
954
+ tensor = TensorName(tensor)
955
+ buffet = Buffet(tensor, einsum_name, node.component)
956
+
957
+ # Reservations make these, and they go below the storage node, so the buffet
958
+ # stats are already made at this point
959
+ stats = child_result.buffet_stats[buffet]
960
+ backer_id = info.tensor_to_backer_id[tensor]
961
+ is_backing = backer_id == id(node)
962
+ if node.persistent:
963
+ stats.persistent = True
964
+ below_backing = backer_id in [id(m) for m in mapping[:node_idx]]
965
+
966
+ projection = info.einsum_tensor_to_projection[(einsum_name, tensor)]
967
+
968
+ fills = compute_dense_tile_occupancy(projection, current_shape)
969
+
970
+ child = child_result.get_child_buffet_stats(buffet)
971
+ inherit_from_child = propagate_child_results and child is not None
972
+
973
+ # ==============================================================================
974
+ # Calculate the total fills and reads to parent. These propagate upward.
975
+ # ==============================================================================
976
+
977
+ def inherit_add(attr: str, default_value: Any = fills) -> Any:
978
+ val = getattr(child, attr) if inherit_from_child else default_value
979
+ setattr(stats, attr, val + getattr(stats, attr))
980
+
981
+ if has_parent_tensor_holder(tensor, node_idx, info):
982
+ # Initial fetch: If we're below the backing storage, fetch data from above
983
+ # at the beginning.
984
+ if not is_backing and below_backing:
985
+ inherit_add("total_reads_to_parent", fills)
986
+ inherit_add("max_per_parent_reads_to_parent", fills)
987
+
988
+ # Data writeback. Do not writeback if it's a copy operation and we're below
989
+ # the backing storage; data only flows upward.
990
+
991
+ # Writeback occurs in two cases:
992
+ # - We're at or above the backing storage, so we need to propagate our
993
+ # results upward to any storage nodes that will need this data.
994
+ # - This is a written tensor, so we need to write back the written data.
995
+ if (
996
+ tensor in info.workload.einsums[einsum_name].output_tensor_names
997
+ or not below_backing
998
+ ):
999
+ inherit_add("total_writes_to_parent")
1000
+ inherit_add("max_per_parent_writes_to_parent")
1001
+
1002
+ # For read+write tensors, we skip the first fill because the data will be
1003
+ # initialized with a zero value.
1004
+ if tensor in info.workload.einsums[einsum_name].output_tensor_names:
1005
+ inherit_add("total_skipped_first_reads_to_parent")
1006
+ inherit_add("min_per_parent_skipped_first_reads_to_parent")
1007
+
1008
+ # ==============================================================================
1009
+ # Convert to actions. These are not used used upward; they are used to get
1010
+ # energy and latency.
1011
+ # ==============================================================================
1012
+ component_object = find_component_object(
1013
+ node.component, info.job.flattened_arch
1014
+ )
1015
+ bits_per_value_scale = component_object.attributes.bits_per_value_scale[tensor]
1016
+ bits_per_value = bits_per_value_scale * info.job.bits_per_value[tensor]
1017
+ read_bits_per_action = component_object.actions[
1018
+ "read"
1019
+ ].arguments.bits_per_action
1020
+ read_scale = bits_per_value / read_bits_per_action
1021
+ if count_writes:
1022
+ write_bits_per_action = component_object.actions[
1023
+ "write"
1024
+ ].arguments.bits_per_action
1025
+ write_scale = bits_per_value / write_bits_per_action
1026
+ else:
1027
+ write_scale = 0
1028
+
1029
+ # ==========================
1030
+ # Data exchanges with parent
1031
+ if count_downward_movement: # Parent -> Me
1032
+ stats.total_write_actions += stats.total_reads_to_parent * write_scale
1033
+ stats.max_per_unit_write_actions += (
1034
+ stats.total_reads_to_parent * write_scale
1035
+ )
1036
+ stats.total_skipped_first_write_actions += (
1037
+ stats.total_skipped_first_reads_to_parent * write_scale
1038
+ )
1039
+ stats.min_per_unit_skipped_first_write_actions += (
1040
+ stats.min_per_parent_skipped_first_reads_to_parent * write_scale
1041
+ )
1042
+
1043
+ if count_upward_movement: # Me -> Parent
1044
+ # Comment this to have the final writeback to a buffer hit both that buffer and
1045
+ # go directly to the parent without incurring another read from the buffer.
1046
+ stats.total_read_actions += stats.total_writes_to_parent * read_scale
1047
+ stats.max_per_unit_read_actions += stats.total_writes_to_parent * read_scale
1048
+
1049
+ # ========================
1050
+ # Data exchanges with peer
1051
+ stats.total_read_actions += stats.total_reads_to_peer * read_scale
1052
+ stats.total_write_actions += stats.total_reads_to_peer * write_scale
1053
+
1054
+ # =========================
1055
+ # Data exchanges with child
1056
+ if child is not None:
1057
+ if count_downward_movement: # Me -> Child
1058
+ stats.total_read_actions += child.total_reads_to_parent * read_scale
1059
+ stats.max_per_unit_read_actions += (
1060
+ child.max_per_parent_reads_to_parent * read_scale
1061
+ )
1062
+ # Skip first read
1063
+ stats.total_skipped_first_read_actions += (
1064
+ child.total_skipped_first_reads_to_parent * read_scale
1065
+ )
1066
+ stats.min_per_unit_skipped_first_read_actions += (
1067
+ child.min_per_parent_skipped_first_reads_to_parent * read_scale
1068
+ )
1069
+
1070
+ if count_upward_movement: # Child -> Me
1071
+ stats.total_write_actions += child.total_writes_to_parent * write_scale
1072
+ stats.max_per_unit_write_actions += (
1073
+ child.max_per_parent_writes_to_parent * write_scale
1074
+ )
1075
+
1076
+ return child_result
1077
+
1078
+
1079
+ def analyze_processing_stage(node_idx, current_shape, info: AnalysisInfo):
1080
+ mapping = info.mapping
1081
+ einsum_name = mapping[-1].einsum
1082
+ node = mapping[node_idx]
1083
+ component_object = find_component_object(node.component, info.job.flattened_arch)
1084
+ storage_result = analyze_storage(
1085
+ node_idx,
1086
+ current_shape,
1087
+ info,
1088
+ propagate_child_results=True,
1089
+ count_upward_movement=component_object.attributes.direction != "down",
1090
+ count_downward_movement=component_object.attributes.direction != "up",
1091
+ count_writes=False,
1092
+ )
1093
+ for tensor in node.tensors:
1094
+ buffet = Buffet(tensor, einsum_name, node.component)
1095
+ stats = storage_result.buffet_stats[buffet]
1096
+ stats.max_occupancy = 0
1097
+ assert stats.total_write_actions == 0
1098
+ return storage_result
1099
+
1100
+
1101
+ def analyze_reservation(node_idx, current_shape, info: AnalysisInfo):
1102
+ mapping = info.mapping
1103
+ einsum_name = mapping[-1].einsum
1104
+ node = mapping[node_idx]
1105
+ tensor = TensorName(node.purpose)
1106
+
1107
+ if info.last_temporal_node_idx is not None and id(
1108
+ node
1109
+ ) == info.tensor_to_reservation_backer_id.get(node.purpose, None):
1110
+ info.idxs_to_track_first_latency.add(info.last_temporal_node_idx)
1111
+
1112
+ child_result = analyze_node(node_idx + 1, current_shape, info)
1113
+
1114
+ buffet = Buffet(tensor, einsum_name, node.resource)
1115
+
1116
+ # Reservation nodes are the first to produce stats for a buffet
1117
+ assert buffet not in child_result.buffet_stats
1118
+
1119
+ stats = BuffetStats()
1120
+ projection = info.einsum_tensor_to_projection[(einsum_name, tensor)]
1121
+ component_object = find_component_object(node.resource, info.job.flattened_arch)
1122
+ bits_per_value_scale = component_object.attributes.bits_per_value_scale[tensor]
1123
+ bits_per_value = bits_per_value_scale * info.job.bits_per_value[tensor]
1124
+ stats.max_occupancy = (
1125
+ compute_dense_tile_occupancy(projection, current_shape) * bits_per_value
1126
+ )
1127
+ child_result.buffet_stats[buffet] = stats
1128
+
1129
+ fanout_key = (node.resource, einsum_name)
1130
+ if fanout_key not in child_result.fanout:
1131
+ child_result.fanout[fanout_key] = {}
1132
+
1133
+ return child_result
1134
+
1135
+
1136
+ def analyze_compute(
+     node_idx, current_shape, info: AnalysisInfo
+ ) -> SymbolicAnalysisOutput:
+     einsum = info.mapping[-1].einsum
+     node = info.mapping[node_idx]
+
+     computes = 0 if info.is_copy_operation else 1
+
+     result_accumulator = SymbolicAnalysisOutput()
+
+     result_accumulator.temporal_steps[einsum] = computes
+     result_accumulator.compute_stats[Compute(einsum, node.component)] = ComputeStats(
+         computes,
+         computes,
+         1,
+     )
+
+     if info.is_copy_operation:
+         return result_accumulator
+
+     # Each compute reads every operand once from its parent; output tensors are
+     # additionally written back once, and their first read is counted as skippable.
+     for tensor in info.all_tensors:
+         buffet = Buffet(tensor, einsum, node.component)
+         stats = BuffetStats()
+         stats.total_reads_to_parent = 1
+         stats.max_per_parent_reads_to_parent = 1
+         if tensor in info.workload.einsums[einsum].output_tensor_names:
+             stats.total_writes_to_parent = 1
+             stats.max_per_parent_writes_to_parent = 1
+             stats.total_skipped_first_reads_to_parent = 1
+             stats.min_per_parent_skipped_first_reads_to_parent = 1
+         stats.max_occupancy = 1
+         result_accumulator.buffet_stats[buffet] = stats
+
+     return result_accumulator
+
+
+ @dataclass
+ class RepeatedValue[T]:
+     value: T
+     repeats: int
+
+
+ @dataclass
+ class SequenceOfRepatedvalues[T]:
+     sequence: list[RepeatedValue[T]]
+
+
+ @dataclass
+ class StrideAndShape:
+     stride: any
+     shape: any
+
+
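For intuition, a small example (not from the package) of how these containers can describe tile shapes along one rank: a rank of 10 tiled with stride 4 yields two full tiles and one remainder tile, written out by hand below:

    # Hypothetical: rank shape 10, stride 4 -> tiles of shape 4, 4, 2.
    tiles = SequenceOfRepatedvalues(
        [RepeatedValue(value=4, repeats=2), RepeatedValue(value=2, repeats=1)]
    )
    total = sum(rv.value * rv.repeats for rv in tiles.sequence)
    assert total == 10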
+ def get_stride_and_tile_shape(node: Loop, full_shape, n: int, info: AnalysisInfo):
+     rank = node.rank_variable
+     rank_shape = full_shape[rank]
+
+     stride = node.stride
+     initial_tile_shape = node.initial_tile_shape
+
+     # PERFECT:
+     # - Node shape = stride
+     # - # Iterations = total shape / stride
+     # IMPERFECT:
+     # - Node shape = stride
+     # - # Iterations = ceil(total shape / stride)
+     if IMPERFECT and initial_tile_shape is None:
+         factor = sympy.ceiling(rank_shape / stride)
+         stride_avg = stride / sympy.ceiling(rank_shape / stride)
+         return StrideAndShape(stride_avg, RepeatedValue(stride, factor))
+
+     if initial_tile_shape is None:
+         if node._assume_perfect_factor or known_perfect_factor(stride, rank_shape):
+             factor = rank_shape / stride
+             return StrideAndShape(stride, RepeatedValue(stride, factor))
+         else:
+             factor = sympy.ceiling(rank_shape / sympy.Min(stride, rank_shape))
+             return make_possibly_different_last(stride, factor, rank_shape)
+
+     middle_shape_factor = sympy.ceiling((rank_shape - initial_tile_shape) / stride)
+     # TODO: sometimes last_shape is 0, causing numerical instability.
+     # Currently, we are sometimes rounding up last shape.
+     # last_shape = rank_shape - initial_tile_shape - stride*middle_shape_factor
+     # has_last_shape = sympy.ceiling(last_shape/(last_shape+1))
+     return StrideAndShape(
+         stride,
+         SequenceOfRepatedvalues(
+             [
+                 RepeatedValue(initial_tile_shape, 1),
+                 RepeatedValue(stride, middle_shape_factor),
+                 # RepeatedValue(last_shape+0.01, has_last_shape)
+             ]
+         ),
+     )
+     # if node.tile_shape is not None:
+     #     tile_shape = node.tile_shape
+
+     #     if node._assume_perfect_factor or known_perfect_factor(tile_shape, rank_shape):
+     #         factor = rank_shape / tile_shape
+     #         return StrideAndShape(tile_shape, RepeatedValue(tile_shape, factor))
+     #     else:
+     #         factor = sympy.ceiling(rank_shape / sympy.Min(tile_shape, rank_shape))
+     #         return make_possibly_different_last(tile_shape, factor, rank_shape)
+     # elif node.loop_bound is not None:
+     #     factor = node.loop_bound
+
+     #     if node._assume_perfect_factor or known_perfect_factor(factor, rank_shape):
+     #         tile_shape = rank_shape / factor
+     #         return StrideAndShape(tile_shape, RepeatedValue(tile_shape, factor))
+     #     else:
+     #         tile_shape = sympy.ceiling(rank_shape / sympy.Min(rank_shape, factor))
+     #         return make_possibly_different_last(tile_shape, factor, rank_shape)
+
+     # elif node.tile_pattern is not None:
+     #     stride = node.tile_pattern.stride
+     #     initial_tile_shape = node.tile_pattern.initial_tile_shape
+     #     tile_shape = node.tile_pattern.tile_shape
+
+     #     if initial_tile_shape is not None:
+     #         middle_shape_factor = sympy.ceiling((rank_shape - initial_tile_shape)/stride)
+     #         # TODO: sometimes last_shape is 0, causing numerical instability
+     #         # Currently, we are sometimes rounding up last shape.
+     #         # last_shape = rank_shape - initial_tile_shape - stride*middle_shape_factor
+     #         # has_last_shape = sympy.ceiling(last_shape/(last_shape+1))
+     #         return StrideAndShape(
+     #             stride,
+     #             SequenceOfRepatedvalues([
+     #                 RepeatedValue(initial_tile_shape, 1),
+     #                 RepeatedValue(stride, middle_shape_factor),
+     #                 # RepeatedValue(last_shape+0.01, has_last_shape)
+     #             ])
+     #         )
+
+
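As a concrete check (not package code) of the initial-tile branch above: a rank of 10 with an initial tile of 3 followed by stride 4 covers the remaining 7 values in ceil(7 / 4) = 2 strided tiles. The numbers are made up; math.ceil stands in for sympy.ceiling:

    import math

    rank_shape, initial_tile_shape, stride = 10, 3, 4
    middle_shape_factor = math.ceil((rank_shape - initial_tile_shape) / stride)
    assert middle_shape_factor == 2
    # Corresponding tile sequence: [3, 4, 4]; note it can overshoot the rank
    # (3 + 2 * 4 = 11 > 10), which is the rounding the TODO above refers to.
    covered = initial_tile_shape + middle_shape_factor * stride
    assert covered == 11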
+ def known_perfect_factor(divisor, full_shape):
+     # A divisor is a known perfect factor only when both values are concrete
+     # integers and the division leaves no remainder.
+     return (
+         isinstance(divisor, int)
+         and isinstance(full_shape, int)
+         and full_shape % divisor == 0
+     )
+
+
+ def make_possibly_different_last(common_tile_shape, factor, full_shape):
+     last_shape = full_shape - common_tile_shape * (factor - 1)
+     all_shapes = SequenceOfRepatedvalues(
+         [RepeatedValue(common_tile_shape, factor - 1), RepeatedValue(last_shape, 1)]
+     )
+     return StrideAndShape(common_tile_shape, all_shapes)
+
+
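A worked example (not from the package) of the remainder handling above: tiling a rank of 10 into 3 tiles with common shape 4 leaves a smaller final tile of 10 - 4 * 2 = 2:

    result = make_possibly_different_last(common_tile_shape=4, factor=3, full_shape=10)
    shapes = [rv.value for rv in result.shape.sequence for _ in range(rv.repeats)]
    assert shapes == [4, 4, 2]
    assert result.stride == 4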
+ def insert_sympy_symbols(mapping: list[MappingNode], job: Job):
+     loop_idx = 0
+     symbols = []
+     rank_var_with_initial = set()
+     for i, node in enumerate(mapping):
+         if not isinstance(node, Loop):
+             continue
+
+         stride_halos = set()
+         for t in job.spec.workload.einsums[job.einsum_name].tensor_names:
+             for (rank, rank_variable), (stride, halo) in job.stride_and_halo[t].items():
+                 if rank_variable == node.rank_variable:
+                     stride_halos.add((stride, halo))
+
+         if len(stride_halos) == 0:
+             raise RuntimeError(
+                 f"{repr(node.rank_variable)} not found in {job.stride_and_halo}"
+             )
+
+         # We only explore imperfect for the outermost fused loops
+         simple = (
+             (len(stride_halos) <= 1 and next(iter(stride_halos)) == (1, 0))
+             or node.rank_variable in rank_var_with_initial
+             or not node._fused
+         )
+
+         # NOTE: initial_tile_shape must be inserted into `symbols` before `stride`
+         # because of the order of tile shape exploration.
+         # TODO: there has to be a better way to do this.
+         if simple:  # Just use the stride!
+             node.initial_tile_shape = None
+         elif node.initial_tile_shape == SYMBOL:
+             rank_var_with_initial.add(node.rank_variable)
+             initial_tile_shape = sympy.symbols(
+                 f"initial{loop_idx}", positive=True, integer=True
+             )
+             symbols.append(initial_tile_shape)
+             node.initial_tile_shape = initial_tile_shape
+
+         # TODO: Check for 0 < shape < 1 for loop bound target
+         if job.rank_variable_bounds[node.rank_variable] == 1:
+             node.stride = 1
+         elif node.stride == SYMBOL:
+             stride = sympy.symbols(f"stride{loop_idx}", positive=True, integer=True)
+             symbols.append(stride)
+             node.stride = stride
+
+         # TODO: sometimes, a mapping is passed into the model twice.
+         # E.g., after calling mapper, the model is called again for more
+         # details.
+         #
+         # assert (
+         #     node.calculated_n_iterations is None
+         # ), "Number of iterations is derived from the model. Do not set it!"
+         node.calculated_n_iterations = sympy.symbols(
+             f"n_iterations{loop_idx}", positive=True, integer=True
+         )
+
+         loop_idx += 1
+
+     return symbols
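To make the symbol bookkeeping concrete, a small sketch (not package code) of the kind of symbols the function creates for two loops. It uses sympy directly rather than the Loop/Job machinery, and the two-loop scenario and its naming simply mirror the f-strings above:

    import sympy

    # Hypothetically, loop 0 is a fused loop that gets a symbolic initial tile and
    # stride; loop 1 only gets a stride. Every loop also receives an n_iterations
    # symbol, which is derived by the model rather than chosen by the mapper.
    symbols = [
        sympy.symbols("initial0", positive=True, integer=True),
        sympy.symbols("stride0", positive=True, integer=True),
        sympy.symbols("stride1", positive=True, integer=True),
    ]
    n_iterations = [
        sympy.symbols(f"n_iterations{i}", positive=True, integer=True) for i in range(2)
    ]
    assert [str(s) for s in symbols] == ["initial0", "stride0", "stride1"]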