accelforge-0.0.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (258)
  1. accelforge/__init__.py +21 -0
  2. accelforge/_accelerated_imports.py +16 -0
  3. accelforge/_deprecate/_simanneal/evalmapping.py +271 -0
  4. accelforge/_deprecate/_simanneal/mapspaceglobals.py +298 -0
  5. accelforge/_deprecate/_simanneal/simanneal.py +666 -0
  6. accelforge/_deprecate/_simanneal/tracking.py +105 -0
  7. accelforge/_deprecate/_simanneal/wrappers.py +218 -0
  8. accelforge/_deprecate/_simanneal2/__init__.py +7 -0
  9. accelforge/_deprecate/_simanneal2/simanneal.py +493 -0
  10. accelforge/_deprecate/_simanneal2/tracking.py +116 -0
  11. accelforge/_deprecate/compatibility_util.py +181 -0
  12. accelforge/_deprecate/layerdeduplication/__init__.py +2 -0
  13. accelforge/_deprecate/layerdeduplication/group_similar_einsums.py +160 -0
  14. accelforge/_deprecate/layerdeduplication/grouped_einsums.py +84 -0
  15. accelforge/_deprecate/mapping_filter_tags/__init__.py +2 -0
  16. accelforge/_deprecate/mapping_filter_tags/ffmt.py +212 -0
  17. accelforge/_deprecate/mapping_filter_tags/onesplit.py +24 -0
  18. accelforge/_deprecate/mapping_filter_tags/util.py +24 -0
  19. accelforge/_deprecate/tags.py +69 -0
  20. accelforge/_deprecate/viz/__init__.py +0 -0
  21. accelforge/_deprecate/viz/interactive.py +159 -0
  22. accelforge/_deprecate/viz/reservationtree.py +307 -0
  23. accelforge/_deprecate/viz/ski_slope.py +88 -0
  24. accelforge/_version.py +15 -0
  25. accelforge/examples.py +39 -0
  26. accelforge/frontend/__init__.py +10 -0
  27. accelforge/frontend/_binding.py +129 -0
  28. accelforge/frontend/_workload_isl/__init__.py +2 -0
  29. accelforge/frontend/_workload_isl/_isl.py +149 -0
  30. accelforge/frontend/_workload_isl/_symbolic.py +141 -0
  31. accelforge/frontend/arch copy.py +1544 -0
  32. accelforge/frontend/arch.py +1642 -0
  33. accelforge/frontend/config.py +63 -0
  34. accelforge/frontend/mapper/__init__.py +5 -0
  35. accelforge/frontend/mapper/ffm.py +126 -0
  36. accelforge/frontend/mapper/mapper.py +7 -0
  37. accelforge/frontend/mapper/metrics.py +30 -0
  38. accelforge/frontend/mapping/__init__.py +1 -0
  39. accelforge/frontend/mapping/mapping.py +1736 -0
  40. accelforge/frontend/model.py +14 -0
  41. accelforge/frontend/renames.py +150 -0
  42. accelforge/frontend/spec copy.py +230 -0
  43. accelforge/frontend/spec.py +301 -0
  44. accelforge/frontend/variables.py +12 -0
  45. accelforge/frontend/workload.py +952 -0
  46. accelforge/mapper/FFM/__init__.py +9 -0
  47. accelforge/mapper/FFM/_join_pmappings/__init__.py +0 -0
  48. accelforge/mapper/FFM/_join_pmappings/compatibility.py +653 -0
  49. accelforge/mapper/FFM/_join_pmappings/compress_pmappings.py +140 -0
  50. accelforge/mapper/FFM/_join_pmappings/join_pmappings.py +703 -0
  51. accelforge/mapper/FFM/_join_pmappings/pmapping_dataframe.py +901 -0
  52. accelforge/mapper/FFM/_join_pmappings/pmapping_group.py +337 -0
  53. accelforge/mapper/FFM/_make_pmappings/contraints/__init__.py +0 -0
  54. accelforge/mapper/FFM/_make_pmappings/contraints/constraints.py +360 -0
  55. accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/__init__.py +1 -0
  56. accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_loops.py +373 -0
  57. accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_pmapping_templates.py +463 -0
  58. accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_reservations.py +95 -0
  59. accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_storage_order.py +382 -0
  60. accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_storages.py +155 -0
  61. accelforge/mapper/FFM/_make_pmappings/make_pmappings.py +411 -0
  62. accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/__init__.py +1 -0
  63. accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/make_pmappings_from_templates.py +407 -0
  64. accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/make_tile_shapes.py +1681 -0
  65. accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/run_model.py +170 -0
  66. accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/symbol_relations.py +174 -0
  67. accelforge/mapper/FFM/_make_pmappings/pmapper_job.py +282 -0
  68. accelforge/mapper/FFM/_pareto_df/df_convention.py +273 -0
  69. accelforge/mapper/FFM/_pareto_df/pareto copy.py +836 -0
  70. accelforge/mapper/FFM/_pareto_df/pareto.py +508 -0
  71. accelforge/mapper/FFM/data.py +61 -0
  72. accelforge/mapper/FFM/main copy.py +236 -0
  73. accelforge/mapper/FFM/main.py +208 -0
  74. accelforge/mapper/FFM/mappings.py +510 -0
  75. accelforge/mapper/FFM/pmappings.py +310 -0
  76. accelforge/mapper/__init__.py +4 -0
  77. accelforge/mapper.py +0 -0
  78. accelforge/model/__init__.py +1 -0
  79. accelforge/model/_looptree/__init__.py +0 -0
  80. accelforge/model/_looptree/accesses.py +335 -0
  81. accelforge/model/_looptree/capacity/__init__.py +1 -0
  82. accelforge/model/_looptree/capacity/aggregators.py +36 -0
  83. accelforge/model/_looptree/capacity/capacity.py +47 -0
  84. accelforge/model/_looptree/energy.py +150 -0
  85. accelforge/model/_looptree/equivalent_ranks.py +29 -0
  86. accelforge/model/_looptree/latency/__init__.py +1 -0
  87. accelforge/model/_looptree/latency/latency.py +98 -0
  88. accelforge/model/_looptree/latency/memory.py +120 -0
  89. accelforge/model/_looptree/latency/processors.py +92 -0
  90. accelforge/model/_looptree/mapping_utilities.py +71 -0
  91. accelforge/model/_looptree/reuse/__init__.py +4 -0
  92. accelforge/model/_looptree/reuse/isl/__init__.py +1 -0
  93. accelforge/model/_looptree/reuse/isl/des.py +59 -0
  94. accelforge/model/_looptree/reuse/isl/isl_functions.py +374 -0
  95. accelforge/model/_looptree/reuse/isl/mapping_to_isl/__init__.py +4 -0
  96. accelforge/model/_looptree/reuse/isl/mapping_to_isl/analyze_mapping.py +297 -0
  97. accelforge/model/_looptree/reuse/isl/mapping_to_isl/skews_from_mapping.py +236 -0
  98. accelforge/model/_looptree/reuse/isl/mapping_to_isl/tiling.py +685 -0
  99. accelforge/model/_looptree/reuse/isl/mapping_to_isl/types.py +188 -0
  100. accelforge/model/_looptree/reuse/isl/spatial.py +260 -0
  101. accelforge/model/_looptree/reuse/isl/temporal.py +182 -0
  102. accelforge/model/_looptree/reuse/symbolic/__init__.py +1 -0
  103. accelforge/model/_looptree/reuse/symbolic/symbolic copy 2.py +1346 -0
  104. accelforge/model/_looptree/reuse/symbolic/symbolic copy.py +1408 -0
  105. accelforge/model/_looptree/reuse/symbolic/symbolic.py +1396 -0
  106. accelforge/model/_looptree/run.py +122 -0
  107. accelforge/model/_looptree/types.py +26 -0
  108. accelforge/model/_looptree/visualization/__init__.py +0 -0
  109. accelforge/model/_looptree/visualization/occupancy.py +11 -0
  110. accelforge/model/main.py +222 -0
  111. accelforge/plotting/__init__.py +2 -0
  112. accelforge/plotting/mappings.py +219 -0
  113. accelforge/plotting/specs.py +57 -0
  114. accelforge/util/__init__.py +4 -0
  115. accelforge/util/_base_analysis_types.py +24 -0
  116. accelforge/util/_basetypes.py +1089 -0
  117. accelforge/util/_frozenset.py +36 -0
  118. accelforge/util/_isl.py +29 -0
  119. accelforge/util/_itertools.py +14 -0
  120. accelforge/util/_mathfuncs.py +57 -0
  121. accelforge/util/_parse_expressions.py +339 -0
  122. accelforge/util/_picklecache.py +32 -0
  123. accelforge/util/_setexpressions.py +268 -0
  124. accelforge/util/_sympy/__init__.py +0 -0
  125. accelforge/util/_sympy/broadcast_max.py +18 -0
  126. accelforge/util/_visualization.py +112 -0
  127. accelforge/util/_yaml.py +579 -0
  128. accelforge/util/parallel.py +193 -0
  129. accelforge-0.0.1.dist-info/METADATA +64 -0
  130. accelforge-0.0.1.dist-info/RECORD +258 -0
  131. accelforge-0.0.1.dist-info/WHEEL +5 -0
  132. accelforge-0.0.1.dist-info/licenses/LICENSE +19 -0
  133. accelforge-0.0.1.dist-info/top_level.txt +5 -0
  134. docs/_build/html/_sources/fastfusion.frontend.mapper.rst.txt +37 -0
  135. docs/_build/html/_sources/fastfusion.frontend.rst.txt +70 -0
  136. docs/_build/html/_sources/fastfusion.frontend.workload.rst.txt +21 -0
  137. docs/_build/html/_sources/fastfusion.mapper.FFM.rst.txt +37 -0
  138. docs/_build/html/_sources/fastfusion.mapper.rst.txt +18 -0
  139. docs/_build/html/_sources/fastfusion.rst.txt +20 -0
  140. docs/_build/html/_sources/fastfusion.util.rst.txt +21 -0
  141. docs/_build/html/_sources/index.rst.txt +87 -0
  142. docs/_build/html/_sources/modules.rst.txt +7 -0
  143. docs/_build/html/_sources/notes/citation.rst.txt +45 -0
  144. docs/_build/html/_sources/notes/definitions.rst.txt +43 -0
  145. docs/_build/html/_sources/notes/faqs.rst.txt +39 -0
  146. docs/_build/html/_sources/notes/modeling/accelerator_energy_latency.rst.txt +72 -0
  147. docs/_build/html/_sources/notes/modeling/component_energy_area.rst.txt +96 -0
  148. docs/_build/html/_sources/notes/modeling/mapping.rst.txt +100 -0
  149. docs/_build/html/_sources/notes/modeling.rst.txt +33 -0
  150. docs/_build/html/_sources/notes/parsing/arithmetic_parsing.rst.txt +136 -0
  151. docs/_build/html/_sources/notes/parsing/setexpressions.rst.txt +63 -0
  152. docs/_build/html/_sources/notes/parsing/yaml_parsing.rst.txt +176 -0
  153. docs/_build/html/_sources/notes/quickstart_and_installation.rst.txt +9 -0
  154. docs/_build/html/_sources/notes/spec/architecture.rst.txt +133 -0
  155. docs/_build/html/_sources/notes/spec/mapping.rst.txt +12 -0
  156. docs/_build/html/_sources/notes/spec/workload.rst.txt +83 -0
  157. docs/_build/html/_sources/notes/spec.rst.txt +36 -0
  158. docs/source/_ext/include_attrs.py +213 -0
  159. docs/source/_ext/include_docstring.py +364 -0
  160. docs/source/_ext/include_functions.py +154 -0
  161. docs/source/_ext/include_notebook.py +131 -0
  162. docs/source/_ext/include_yaml.py +119 -0
  163. docs/source/_ext/inherited_attributes.py +222 -0
  164. docs/source/_ext/paths.py +4 -0
  165. docs/source/conf.py +79 -0
  166. examples/arches/compute_in_memory/_include.yaml +74 -0
  167. examples/arches/compute_in_memory/_include_functions.py +229 -0
  168. examples/arches/compute_in_memory/_load_spec.py +57 -0
  169. examples/arches/compute_in_memory/components/c2c_multiplier.py +181 -0
  170. examples/arches/compute_in_memory/components/dac_c2c_r2r.py +605 -0
  171. examples/arches/compute_in_memory/components/misc.py +195 -0
  172. examples/arches/compute_in_memory/components/util/bit_functions.py +51 -0
  173. examples/arches/compute_in_memory/components/zero_comparator.py +92 -0
  174. examples/arches/compute_in_memory/isaac.yaml +233 -0
  175. examples/arches/compute_in_memory/memory_cells/ecram_demo.yaml +63 -0
  176. examples/arches/compute_in_memory/memory_cells/rram_example.yaml +63 -0
  177. examples/arches/compute_in_memory/memory_cells/rram_isaac_isca_2016.yaml +64 -0
  178. examples/arches/compute_in_memory/memory_cells/rram_neurosim_default.yaml +63 -0
  179. examples/arches/compute_in_memory/memory_cells/rram_raella_isca_2023.yaml +70 -0
  180. examples/arches/compute_in_memory/memory_cells/rram_wan_nature_2022.yaml +63 -0
  181. examples/arches/compute_in_memory/memory_cells/sram_colonnade_jssc_2021.yaml +63 -0
  182. examples/arches/compute_in_memory/memory_cells/sram_example.yaml +63 -0
  183. examples/arches/compute_in_memory/memory_cells/sram_jia_jssc_2020.yaml +63 -0
  184. examples/arches/compute_in_memory/memory_cells/sram_sinangil_jssc_2021.yaml +63 -0
  185. examples/arches/compute_in_memory/memory_cells/sram_wang_vlsi_2022.yaml +63 -0
  186. examples/arches/compute_in_memory/wang_vlsi_2022.yaml +289 -0
  187. examples/arches/eyeriss.yaml +68 -0
  188. examples/arches/fanout_variations/at_glb.yaml +31 -0
  189. examples/arches/fanout_variations/at_glb_with_fanout_node.yaml +34 -0
  190. examples/arches/fanout_variations/at_mac.yaml +31 -0
  191. examples/arches/fanout_variations/at_mac_with_constraints.yaml +38 -0
  192. examples/arches/fanout_variations/at_mac_with_fanout_node.yaml +34 -0
  193. examples/arches/nvdla.yaml +47 -0
  194. examples/arches/simple.yaml +28 -0
  195. examples/arches/tpu_v4i.yaml +67 -0
  196. examples/mappings/unfused_matmuls_to_simple.yaml +33 -0
  197. examples/misc/component_annotated.yaml +33 -0
  198. examples/workloads/gpt3_6.7B.yaml +124 -0
  199. examples/workloads/matmuls.yaml +20 -0
  200. examples/workloads/mobilenet_28.yaml +81 -0
  201. examples/workloads/mobilenet_various_separate.yaml +106 -0
  202. examples/workloads/three_matmuls_annotated.yaml +59 -0
  203. notebooks/.ipynb_checkpoints/fastfusion_arch_study_michael-checkpoint.ipynb +359 -0
  204. notebooks/compute_in_memory/_scripts.py +339 -0
  205. notebooks/compute_in_memory/isaac.guide.ipynb +270 -0
  206. notebooks/compute_in_memory/wang_vlsi_2022.ipynb +602 -0
  207. notebooks/paths.py +4 -0
  208. notebooks/tutorials/.ipynb_checkpoints/1_FFM-checkpoint.ipynb +3110 -0
  209. notebooks/tutorials/FFM.ipynb +3498 -0
  210. notebooks/tutorials/_include.py +48 -0
  211. notebooks/tutorials/component_energy_area.ipynb +363 -0
  212. tests/Q_mapping.yaml +38 -0
  213. tests/__init__.py +0 -0
  214. tests/conv.mapping.yaml +27 -0
  215. tests/conv.workload.yaml +13 -0
  216. tests/conv_sym.mapping.yaml +43 -0
  217. tests/copy.mapping.yaml +35 -0
  218. tests/copy.workload.yaml +15 -0
  219. tests/distribuffers/__init__.py +0 -0
  220. tests/distribuffers/multicast/test_cases.yaml +482 -0
  221. tests/distribuffers/spec/binding/valid_bindings.yaml +97 -0
  222. tests/distribuffers/spec/distributed.yaml +100 -0
  223. tests/distribuffers/spec/logical_arch.yaml +32 -0
  224. tests/distribuffers/spec/physical_arch.yaml +69 -0
  225. tests/distribuffers/test_binding.py +48 -0
  226. tests/frontend/__init__.py +0 -0
  227. tests/frontend/test_mapping_viz.py +52 -0
  228. tests/mapper/__init__.py +0 -0
  229. tests/mapper/configs/conv1d/conv1d.mapping.yaml +31 -0
  230. tests/mapper/configs/conv1d/conv1d.workload.yaml +11 -0
  231. tests/mapper/configs/two_conv1d/two_conv1d.expected.yaml +38 -0
  232. tests/mapper/configs/two_conv1d/two_conv1d.mapping.yaml +54 -0
  233. tests/mapper/configs/two_conv1d/two_conv1d.workload.yaml +19 -0
  234. tests/mapper/test_mapping_to_isl.py +90 -0
  235. tests/mapper/test_spatial_reuse_analysis.py +67 -0
  236. tests/mapper/test_temporal_reuse_analysis.py +56 -0
  237. tests/mapper/util.py +58 -0
  238. tests/matmul.mapping.yaml +29 -0
  239. tests/matmul.workload.yaml +12 -0
  240. tests/matmul_spatial.mapping.yaml +44 -0
  241. tests/mha.renames.yaml +65 -0
  242. tests/mha.workload.yaml +67 -0
  243. tests/mha.yaml +59 -0
  244. tests/mha_full.workload.yaml +67 -0
  245. tests/mobilenet.workload.yaml +35 -0
  246. tests/mobilenet_long.workload.yaml +64 -0
  247. tests/pmappingcache.py +24 -0
  248. tests/processing_stage.arch.yaml +40 -0
  249. tests/snowcat.arch.yaml +36 -0
  250. tests/test_ffm_join_pmappings.py +106 -0
  251. tests/test_ffm_make_pmappings.py +82 -0
  252. tests/test_ffm_make_tile_shapes.py +49 -0
  253. tests/test_mapper.py +100 -0
  254. tests/test_model.py +37 -0
  255. tests/test_plotting.py +72 -0
  256. tests/test_processing_stage.py +46 -0
  257. tests/test_symbolic_model.py +248 -0
  258. tests/test_workload.py +141 -0
accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_loops.py (new file)
@@ -0,0 +1,373 @@
+from accelforge.frontend.renames import TensorName
+
+
+import itertools
+from enum import Enum
+
+import accelforge.frontend.arch as arch
+from accelforge.frontend.mapping import (
+    MappingNode,
+    ProcessingStage,
+    Temporal,
+    Spatial,
+    TensorHolder,
+)
+from accelforge.frontend.workload import (
+    Einsum,
+    RankVariable,
+    Workload,
+)
+
+
+# =================================================================================================
+# Insert loops
+# =================================================================================================
+
+
+class LowerChoice(Enum):
+    YES = 0
+    NO = 1
+    OPTIONAL = 2
+
+
+def insert_temporal_loops(
+    mapping: list[TensorHolder],
+    einsum: Einsum,
+    first_memory: arch.Memory,
+    rank_variable_bounds: dict[RankVariable, int],
+    ranks_with_tile_pattern: set,
+    workload: Workload,
+    _can_lower_outermost_memory: bool,
+    flattened_arch: list[arch.Leaf],
+    max_fused_loops: int,
+):
+    # First establish insertion points. Insertion points are:
+    # - Below the last instance of the first memory
+    # - Between any two TensorHolder nodes
+    # - After the last TensorHolder node
+
+    # The following logic is really just to make sure that all the storage nodes for the
+    # outermost memory are together at the beginning of the split mapping. After that,
+    # each entry in the split mapping has a single TensorHolder.
+    split_mapping: list[list[TensorHolder]] = [[]]
+    for m in mapping:
+        split_mapping.append([m])
+        if len(split_mapping) > 1 and m.component == first_memory.name:
+            split_mapping[-2].extend(split_mapping.pop(-1))
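+    # Illustrative example (names hypothetical): a mapping
+    #     [MainMemory(A), MainMemory(B), GlobalBuffer(A), GlobalBuffer(B)]
+    # with first_memory == MainMemory is split into
+    #     [[MainMemory(A), MainMemory(B)], [GlobalBuffer(A)], [GlobalBuffer(B)]].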
+    for i, s in enumerate[list[TensorHolder | Spatial]](split_mapping):
+        for m in s:
+            if i == 0 and m.component != first_memory.name:
+                raise ValueError(
+                    "The first TensorHolder in the mapping is not for the outermost "
+                    "memory. This isn't known to be invalid, but the code may not "
+                    "handle it."
+                )
+            elif i > 0 and m.component == first_memory.name:
+                raise ValueError(
+                    "First memory isn't at the top of the hierarchy. This isn't known "
+                    "to be invalid, but the code may not handle it."
+                )
+            elif i == 0 and isinstance(m, Spatial):
+                raise ValueError(
+                    "Found Spatial node before any TensorHolder. This isn't known to "
+                    "be invalid, but the code may not handle it."
+                )
+
+    split_mapping = [m for m in split_mapping if m]
+
+    # These Einsum properties are recalculated on every access since Einsum is mutable,
+    # so we pre-compute and reuse them here for efficiency.
+    tensor2fully_relevant_rank_vars = einsum.tensor2directly_indexing_rank_variables
+    tensor2partially_relevant_rank_vars = (
+        einsum.tensor2expression_indexing_rank_variables
+    )
+    tensor2irrelevant_rank_vars = einsum.tensor2irrelevant_rank_variables
+    tensor2rank_vars = einsum.tensor2rank_variables
+    tensors = einsum.tensor_names
+
+    fusable_tensors = (
+        einsum.tensor_names & workload.tensor_names_used_in_multiple_einsums
+    )
+    is_fused_loops = True
+    seen_tensors = set()
+    choices = []
+    lowering_choices: list[tuple[bool, ...]] = []
+    fanouts = {}
+    fanout = 1
+    for node in flattened_arch:
+        fanouts[node.name] = (fanout := fanout * node.get_fanout())
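+    # Each entry of ``fanouts`` is the cumulative fanout of the architecture down to
+    # and including that node. Illustrative example (names and fanouts hypothetical):
+    # a flattened_arch of [MainMemory (fanout 1), GlobalBuffer (fanout 4), PE (fanout 1)]
+    # gives fanouts == {"MainMemory": 1, "GlobalBuffer": 4, "PE": 4}.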
+
+    def _get_next_storages(i: int, pstage_allowed: bool = False) -> list[TensorHolder]:
+        for j in range(i + 1, len(split_mapping)):
+            assert len(split_mapping[j]) <= 1
+            # We don't add loops before processing stages
+            if isinstance(split_mapping[j][0], ProcessingStage) and not pstage_allowed:
+                continue
+            return split_mapping[j]
+        return []
+
+    prev_fanout = 1
+    someone_elses_spatials_may_be_placed_above = False
+    for i, prev_storages in enumerate(split_mapping):
+        # =============================================================================
+        # Choose what temporal loops to insert between prev_storages and the next
+        # TensorHolder node(s).
+        # =============================================================================
+
+        next_storages = _get_next_storages(i)
+        next_anything = _get_next_storages(i, pstage_allowed=True)
+
+        for s in prev_storages:
+            # Tensor holders must not mix backing and non-backing tensors.
+            assert not s._backing or all(t in s._backing for t in s.tensors)
+            # One tensor per holder
+            assert len(s.tensors) == 1
+
+        rank_variables = einsum.rank_variables
+        # rank_variables = {r for r in rank_variables if rank_variable_bounds[r] > 1}
+        seen_tensors |= set.union(*(set(t.tensors) for t in prev_storages), set())
+        is_fused_loops = is_fused_loops and len(fusable_tensors - seen_tensors) > 0
+        prev_tensors = set.union(set(), *(set(t.tensors) for t in prev_storages))
+        next_persistent = set.union(
+            set(), *(set(t.tensors) for t in next_storages if t.persistent)
+        )
+
+        max_fanout_before = max(
+            [fanouts[s2.component] for s in split_mapping[:i] for s2 in s],
+            default=float("inf"),
+        )
+        min_fanout_after = min(
+            [fanouts[s2.component] for s in split_mapping[i + 1 :] for s2 in s],
+            default=0,
+        )
+        cur_fanout = set(fanouts[s2.component] for s2 in prev_storages)
+        next_fanout = set(fanouts[s2.component] for s2 in next_anything)
+        if len(next_fanout) == 0:
+            next_fanout.add(float("inf"))
+        # Either it's main memory or we have one entry in the list, so there should only
+        # be one
+        assert len(cur_fanout) == 1
+        assert len(next_fanout) == 1
+        cur_fanout = next(iter(cur_fanout))
+        next_fanout = next(iter(next_fanout))
+
+        # Can't have loops above persistent tensor holders
+        if next_persistent:
+            rank_variables &= set()
+
+        # No recomputation: If we haven't seen a tensor yet, we must only iterate over
+        # its fully-relevant rank variables.
+        for t in tensors - seen_tensors:
+            rank_variables &= tensor2fully_relevant_rank_vars[t]
+
+        if max_fused_loops == 0 and (fusable_tensors - seen_tensors):
+            rank_variables &= set()
+
+        # The fanout for a prior node may be placed here, so spatial nodes may be moved
+        # here
+        someone_elses_spatials_may_be_placed_below = (
+            next_fanout > cur_fanout and max_fanout_before > cur_fanout
+        )
+
+        # If the fanout is about to increase, then spatial loops may be placed below the
+        # current node. There may have been constrained temporal loops earlier that need
+        # to be placed here, so we won't prohibit any loops.
+        if someone_elses_spatials_may_be_placed_below:
+            pass
+        else:
+
+            # Optimality-preserving optimization: Loops below processing stages aren't
+            # helpful because there is no storage. Ctrl-F for
+            # CONTIGUOUS_ITERATION_SPACE_DISCUSSION: Can't do this if we may put another
+            # node's spatial loops below this one, because lowering would move the
+            # spatials down, which would constrain the temporals due to spatial-temporal
+            # crossing.
+            if isinstance(prev_storages[0], ProcessingStage):
+                rank_variables &= set()
+
+            # Generally we want to only use rank variables that are irrelevant to the
+            # previous tensors, else we'd just lower those tensors. However, we can't
+            # lower backing TensorHolder nodes because this would add loops to the
+            # compatibility.
+
+            # Optimality-preserving optimization: We can trivially lower non-backing
+            # TensorHolder nodes through fully-relevant loops. Can't do this if the
+            # loops are fused because that'd add loops to the compatibility. Ctrl-F
+            # for CONTIGUOUS_ITERATION_SPACE_DISCUSSION: Can't do this if we may put
+            # another node's spatial loops below this one, because lowering would
+            # move the spatials down, which would constrain the temporals due to
+            # spatial-temporal crossing.
+            for s in prev_storages:
+                for t in s.tensors:
+                    if t not in s._backing and not s._must_be_here:
+                        rank_variables -= tensor2fully_relevant_rank_vars[t]
+
+            # Optimality-preserving optimization: We can trivially raise TensorHolder
+            # nodes through irrelevant unfused loops. Can't do this if the loops are
+            # fused because that'd increase the lifetime of the TensorHolder node. Can't
+            # do this if the irrelevant rank variables are partially-relevant to the
+            # previous tensors, since that affects the permutation. See
+            # CONTIGUOUS_ITERATION_SPACE_DISCUSSION: Can't do this if we may put another
+            # node's spatial loops above this one, because raising would move the
+            # temporals down, which would constrain them due to spatial-temporal
+            # crossing. TODO: CONTIGUOUS_ITERATION_SPACE_DISCUSSION: This causes all
+            # loops to be added, but really we only need to re-add the ones that may
+            # conflict with a spatial loop.
+            if not is_fused_loops:
+                for s in next_storages:
+                    if not s._must_be_here:
+                        for t in s.tensors:
+                            rvs = tensor2irrelevant_rank_vars[t]
+                            for t2 in prev_tensors:
+                                rvs -= tensor2partially_relevant_rank_vars[t2]
+                            rank_variables -= rvs
+
+        # =============================================================================
+        # Determine whether to lower TensorHolder nodes through partially-relevant
+        # loops.
+        # =============================================================================
+        partially_relevant_to_previous = rank_variables & set.union(
+            set(), *(tensor2partially_relevant_rank_vars[t] for t in prev_tensors)
+        )
+        permutable_partially_relevant = set()
+
+        # NOTE: If the lowering logic for backing TensorHolders is updated & we can
+        # lower through >1 loops, then also update label_fused_loops
+        for s in prev_storages:
+            partially_relevant_to_previous = set.union(
+                set(), *(tensor2partially_relevant_rank_vars[t] for t in s.tensors)
+            )
+            partially_relevant_to_previous &= rank_variables
+            lowerable_backing = (
+                _can_lower_outermost_memory or s.component != first_memory.name
+            )
+
+            # Persistent. Must be at the top of the mapping.
+            if s.persistent:
+                lowering_choices.append((False,))
+            # Don't lower our own reservations through someone else's spatial loops.
+            elif someone_elses_spatials_may_be_placed_below:
+                lowering_choices.append((False,))
+            # Processing stage. Lowering doesn't matter. Don't lower.
+            elif isinstance(s, ProcessingStage):
+                lowering_choices.append((False,))
+            # Previous is backing and there are partially-relevant rank variables. May
+            # want to lower to reduce memory footprint, or raise to reduce the number
+            # of fused loops.
+            elif s._backing and lowerable_backing and partially_relevant_to_previous:
+                lowering_choices.append((False, True))
+                permutable_partially_relevant |= partially_relevant_to_previous
+            # No backing in previous. No cost to lowering. Lower all.
+            elif not s._backing:
+                lowering_choices.append((True,))
+                permutable_partially_relevant |= partially_relevant_to_previous
+            # Previous TensorHolder is backing but not lowerable, or there are no
+            # partially-relevant rank vars.
+            else:
+                lowering_choices.append((False,))
+
+        # =============================================================================
+        # Create loop order and lowering choices
+        # =============================================================================
+
+        can_lower = any(any(c) for c in lowering_choices)
+
+        # Create canonical loop orders that avoid repeating reuse patterns.
+        choices.append(
+            list(
+                canonical_loop_orders(
+                    rank_variables, permutable_partially_relevant, can_lower
+                )
+            )
+        )
+        prev_fanout = cur_fanout
+        someone_elses_spatials_may_be_placed_above = (
+            someone_elses_spatials_may_be_placed_below
+        )
+
+    # ==================================================================================
+    # Iterate over all possible mappings
+    # ==================================================================================
+
+    # TODO: Optimization: If we can optionally lower a tensor & the loop below it is
+    # not something through which we can lower for a given permutation, skip options
+    # that lower that tensor because they get the same result as not lowering the
+    # tensor.
+    n_loop_orders = len(list(itertools.product(*choices)))
+    for loop_orders in itertools.product(*choices):
+        full_mapping = []
+        for prev_storages, loop_order in zip(split_mapping, loop_orders):
+            full_mapping.extend(prev_storages)
+            full_mapping.extend(Temporal(rank_variable=r) for r in loop_order)
+
+        storages = [node for node in full_mapping if isinstance(node, TensorHolder)]
+        assert len(lowering_choices) == len(storages)
+        for lowering_choice in itertools.product(*lowering_choices):
+            for lower, node in zip(lowering_choice, storages):
+                node._lower = lower
+
+            yield list(full_mapping), n_loop_orders
+
+
+def insert_spatial_loops(
+    mapping: list[MappingNode],
+    einsum: Einsum,
+    flattened_arch: list[arch.Memory],
+):
+    nodes_with_fanout = [n for n in flattened_arch if n.get_fanout() > 1]
+    arch_node_names = [n.name for n in flattened_arch]
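+    # Spatial loops for each fanout node are inserted just above the highest
+    # TensorHolder whose component is at or below that fanout node in the
+    # flattened architecture (see the index helper below).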
+
+    for node in nodes_with_fanout:
+        insertion_point = _idx_of_highest_tensor_holder_with_component_below_fanout(
+            node, mapping, arch_node_names
+        )
+
+        rv = einsum.rank_variables
+        for fanout_dim in node.spatial:
+            for r in rv:
+                s = Spatial(
+                    rank_variable=r,
+                    name=fanout_dim.name,
+                    component_object=node,
+                    component=node.name,
+                )
+                if insertion_point == len(mapping):
+                    mapping.append(s)
+                else:
+                    mapping.insert(insertion_point, s)
+
+
+def _idx_of_highest_tensor_holder_with_component_below_fanout(
+    fanout_node, mapping, arch_node_names
+):
+    for i in range(len(mapping)):
+        if not isinstance(mapping[i], TensorHolder):
+            continue
+        if arch_node_names.index(mapping[i].component) >= arch_node_names.index(
+            fanout_node.name
+        ):
+            return i
+    return len(mapping)
+
+
+def canonical_loop_orders(
+    rank_variables: set[RankVariable],
+    partially_relevant_to_previous: set[RankVariable],
+    can_lower: bool,
+):
+    """Generate loop orders that result in unique reuse patterns."""
+    # Only the first partially-relevant rank variable is a meaningful choice
+    # because lowering only happens through at most one rank var.
+    if not partially_relevant_to_previous or not can_lower:
+        yield tuple(sorted(rank_variables))
+        return
+
+    for first_rank_var in partially_relevant_to_previous:
+        rest_of_partially_relevant = partially_relevant_to_previous - {first_rank_var}
+        rest_rank_vars = rank_variables - partially_relevant_to_previous
+        # Since order does not matter, we choose alphabetical order as canonical.
+        yield (
+            (first_rank_var,)
+            + tuple(sorted(rest_of_partially_relevant))
+            + tuple(sorted(rest_rank_vars))
+        )
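
As a rough illustration of the canonical-order enumeration above, the sketch below calls canonical_loop_orders directly. It is a minimal, hypothetical usage example: the import path is inferred from the file listing (make_loops.py), and plain strings stand in for RankVariable objects, which the real API may not accept.

# Hypothetical usage sketch; strings stand in for RankVariable objects.
from accelforge.mapper.FFM._make_pmappings.make_pmapping_templates.make_loops import (
    canonical_loop_orders,
)

rank_vars = {"M", "N", "K"}

# No lowering possible: a single canonical (alphabetically sorted) order.
print(list(canonical_loop_orders(rank_vars, set(), can_lower=False)))
# -> [('K', 'M', 'N')]

# With {"M", "N"} partially relevant and lowering possible, one order is produced
# per choice of the leading (lowerable) rank variable.
print(list(canonical_loop_orders(rank_vars, {"M", "N"}, can_lower=True)))
# -> [('M', 'N', 'K'), ('N', 'M', 'K')] (set iteration order may vary)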