accelforge 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (258) hide show
  1. accelforge/__init__.py +21 -0
  2. accelforge/_accelerated_imports.py +16 -0
  3. accelforge/_deprecate/_simanneal/evalmapping.py +271 -0
  4. accelforge/_deprecate/_simanneal/mapspaceglobals.py +298 -0
  5. accelforge/_deprecate/_simanneal/simanneal.py +666 -0
  6. accelforge/_deprecate/_simanneal/tracking.py +105 -0
  7. accelforge/_deprecate/_simanneal/wrappers.py +218 -0
  8. accelforge/_deprecate/_simanneal2/__init__.py +7 -0
  9. accelforge/_deprecate/_simanneal2/simanneal.py +493 -0
  10. accelforge/_deprecate/_simanneal2/tracking.py +116 -0
  11. accelforge/_deprecate/compatibility_util.py +181 -0
  12. accelforge/_deprecate/layerdeduplication/__init__.py +2 -0
  13. accelforge/_deprecate/layerdeduplication/group_similar_einsums.py +160 -0
  14. accelforge/_deprecate/layerdeduplication/grouped_einsums.py +84 -0
  15. accelforge/_deprecate/mapping_filter_tags/__init__.py +2 -0
  16. accelforge/_deprecate/mapping_filter_tags/ffmt.py +212 -0
  17. accelforge/_deprecate/mapping_filter_tags/onesplit.py +24 -0
  18. accelforge/_deprecate/mapping_filter_tags/util.py +24 -0
  19. accelforge/_deprecate/tags.py +69 -0
  20. accelforge/_deprecate/viz/__init__.py +0 -0
  21. accelforge/_deprecate/viz/interactive.py +159 -0
  22. accelforge/_deprecate/viz/reservationtree.py +307 -0
  23. accelforge/_deprecate/viz/ski_slope.py +88 -0
  24. accelforge/_version.py +15 -0
  25. accelforge/examples.py +39 -0
  26. accelforge/frontend/__init__.py +10 -0
  27. accelforge/frontend/_binding.py +129 -0
  28. accelforge/frontend/_workload_isl/__init__.py +2 -0
  29. accelforge/frontend/_workload_isl/_isl.py +149 -0
  30. accelforge/frontend/_workload_isl/_symbolic.py +141 -0
  31. accelforge/frontend/arch copy.py +1544 -0
  32. accelforge/frontend/arch.py +1642 -0
  33. accelforge/frontend/config.py +63 -0
  34. accelforge/frontend/mapper/__init__.py +5 -0
  35. accelforge/frontend/mapper/ffm.py +126 -0
  36. accelforge/frontend/mapper/mapper.py +7 -0
  37. accelforge/frontend/mapper/metrics.py +30 -0
  38. accelforge/frontend/mapping/__init__.py +1 -0
  39. accelforge/frontend/mapping/mapping.py +1736 -0
  40. accelforge/frontend/model.py +14 -0
  41. accelforge/frontend/renames.py +150 -0
  42. accelforge/frontend/spec copy.py +230 -0
  43. accelforge/frontend/spec.py +301 -0
  44. accelforge/frontend/variables.py +12 -0
  45. accelforge/frontend/workload.py +952 -0
  46. accelforge/mapper/FFM/__init__.py +9 -0
  47. accelforge/mapper/FFM/_join_pmappings/__init__.py +0 -0
  48. accelforge/mapper/FFM/_join_pmappings/compatibility.py +653 -0
  49. accelforge/mapper/FFM/_join_pmappings/compress_pmappings.py +140 -0
  50. accelforge/mapper/FFM/_join_pmappings/join_pmappings.py +703 -0
  51. accelforge/mapper/FFM/_join_pmappings/pmapping_dataframe.py +901 -0
  52. accelforge/mapper/FFM/_join_pmappings/pmapping_group.py +337 -0
  53. accelforge/mapper/FFM/_make_pmappings/contraints/__init__.py +0 -0
  54. accelforge/mapper/FFM/_make_pmappings/contraints/constraints.py +360 -0
  55. accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/__init__.py +1 -0
  56. accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_loops.py +373 -0
  57. accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_pmapping_templates.py +463 -0
  58. accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_reservations.py +95 -0
  59. accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_storage_order.py +382 -0
  60. accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_storages.py +155 -0
  61. accelforge/mapper/FFM/_make_pmappings/make_pmappings.py +411 -0
  62. accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/__init__.py +1 -0
  63. accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/make_pmappings_from_templates.py +407 -0
  64. accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/make_tile_shapes.py +1681 -0
  65. accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/run_model.py +170 -0
  66. accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/symbol_relations.py +174 -0
  67. accelforge/mapper/FFM/_make_pmappings/pmapper_job.py +282 -0
  68. accelforge/mapper/FFM/_pareto_df/df_convention.py +273 -0
  69. accelforge/mapper/FFM/_pareto_df/pareto copy.py +836 -0
  70. accelforge/mapper/FFM/_pareto_df/pareto.py +508 -0
  71. accelforge/mapper/FFM/data.py +61 -0
  72. accelforge/mapper/FFM/main copy.py +236 -0
  73. accelforge/mapper/FFM/main.py +208 -0
  74. accelforge/mapper/FFM/mappings.py +510 -0
  75. accelforge/mapper/FFM/pmappings.py +310 -0
  76. accelforge/mapper/__init__.py +4 -0
  77. accelforge/mapper.py +0 -0
  78. accelforge/model/__init__.py +1 -0
  79. accelforge/model/_looptree/__init__.py +0 -0
  80. accelforge/model/_looptree/accesses.py +335 -0
  81. accelforge/model/_looptree/capacity/__init__.py +1 -0
  82. accelforge/model/_looptree/capacity/aggregators.py +36 -0
  83. accelforge/model/_looptree/capacity/capacity.py +47 -0
  84. accelforge/model/_looptree/energy.py +150 -0
  85. accelforge/model/_looptree/equivalent_ranks.py +29 -0
  86. accelforge/model/_looptree/latency/__init__.py +1 -0
  87. accelforge/model/_looptree/latency/latency.py +98 -0
  88. accelforge/model/_looptree/latency/memory.py +120 -0
  89. accelforge/model/_looptree/latency/processors.py +92 -0
  90. accelforge/model/_looptree/mapping_utilities.py +71 -0
  91. accelforge/model/_looptree/reuse/__init__.py +4 -0
  92. accelforge/model/_looptree/reuse/isl/__init__.py +1 -0
  93. accelforge/model/_looptree/reuse/isl/des.py +59 -0
  94. accelforge/model/_looptree/reuse/isl/isl_functions.py +374 -0
  95. accelforge/model/_looptree/reuse/isl/mapping_to_isl/__init__.py +4 -0
  96. accelforge/model/_looptree/reuse/isl/mapping_to_isl/analyze_mapping.py +297 -0
  97. accelforge/model/_looptree/reuse/isl/mapping_to_isl/skews_from_mapping.py +236 -0
  98. accelforge/model/_looptree/reuse/isl/mapping_to_isl/tiling.py +685 -0
  99. accelforge/model/_looptree/reuse/isl/mapping_to_isl/types.py +188 -0
  100. accelforge/model/_looptree/reuse/isl/spatial.py +260 -0
  101. accelforge/model/_looptree/reuse/isl/temporal.py +182 -0
  102. accelforge/model/_looptree/reuse/symbolic/__init__.py +1 -0
  103. accelforge/model/_looptree/reuse/symbolic/symbolic copy 2.py +1346 -0
  104. accelforge/model/_looptree/reuse/symbolic/symbolic copy.py +1408 -0
  105. accelforge/model/_looptree/reuse/symbolic/symbolic.py +1396 -0
  106. accelforge/model/_looptree/run.py +122 -0
  107. accelforge/model/_looptree/types.py +26 -0
  108. accelforge/model/_looptree/visualization/__init__.py +0 -0
  109. accelforge/model/_looptree/visualization/occupancy.py +11 -0
  110. accelforge/model/main.py +222 -0
  111. accelforge/plotting/__init__.py +2 -0
  112. accelforge/plotting/mappings.py +219 -0
  113. accelforge/plotting/specs.py +57 -0
  114. accelforge/util/__init__.py +4 -0
  115. accelforge/util/_base_analysis_types.py +24 -0
  116. accelforge/util/_basetypes.py +1089 -0
  117. accelforge/util/_frozenset.py +36 -0
  118. accelforge/util/_isl.py +29 -0
  119. accelforge/util/_itertools.py +14 -0
  120. accelforge/util/_mathfuncs.py +57 -0
  121. accelforge/util/_parse_expressions.py +339 -0
  122. accelforge/util/_picklecache.py +32 -0
  123. accelforge/util/_setexpressions.py +268 -0
  124. accelforge/util/_sympy/__init__.py +0 -0
  125. accelforge/util/_sympy/broadcast_max.py +18 -0
  126. accelforge/util/_visualization.py +112 -0
  127. accelforge/util/_yaml.py +579 -0
  128. accelforge/util/parallel.py +193 -0
  129. accelforge-0.0.1.dist-info/METADATA +64 -0
  130. accelforge-0.0.1.dist-info/RECORD +258 -0
  131. accelforge-0.0.1.dist-info/WHEEL +5 -0
  132. accelforge-0.0.1.dist-info/licenses/LICENSE +19 -0
  133. accelforge-0.0.1.dist-info/top_level.txt +5 -0
  134. docs/_build/html/_sources/fastfusion.frontend.mapper.rst.txt +37 -0
  135. docs/_build/html/_sources/fastfusion.frontend.rst.txt +70 -0
  136. docs/_build/html/_sources/fastfusion.frontend.workload.rst.txt +21 -0
  137. docs/_build/html/_sources/fastfusion.mapper.FFM.rst.txt +37 -0
  138. docs/_build/html/_sources/fastfusion.mapper.rst.txt +18 -0
  139. docs/_build/html/_sources/fastfusion.rst.txt +20 -0
  140. docs/_build/html/_sources/fastfusion.util.rst.txt +21 -0
  141. docs/_build/html/_sources/index.rst.txt +87 -0
  142. docs/_build/html/_sources/modules.rst.txt +7 -0
  143. docs/_build/html/_sources/notes/citation.rst.txt +45 -0
  144. docs/_build/html/_sources/notes/definitions.rst.txt +43 -0
  145. docs/_build/html/_sources/notes/faqs.rst.txt +39 -0
  146. docs/_build/html/_sources/notes/modeling/accelerator_energy_latency.rst.txt +72 -0
  147. docs/_build/html/_sources/notes/modeling/component_energy_area.rst.txt +96 -0
  148. docs/_build/html/_sources/notes/modeling/mapping.rst.txt +100 -0
  149. docs/_build/html/_sources/notes/modeling.rst.txt +33 -0
  150. docs/_build/html/_sources/notes/parsing/arithmetic_parsing.rst.txt +136 -0
  151. docs/_build/html/_sources/notes/parsing/setexpressions.rst.txt +63 -0
  152. docs/_build/html/_sources/notes/parsing/yaml_parsing.rst.txt +176 -0
  153. docs/_build/html/_sources/notes/quickstart_and_installation.rst.txt +9 -0
  154. docs/_build/html/_sources/notes/spec/architecture.rst.txt +133 -0
  155. docs/_build/html/_sources/notes/spec/mapping.rst.txt +12 -0
  156. docs/_build/html/_sources/notes/spec/workload.rst.txt +83 -0
  157. docs/_build/html/_sources/notes/spec.rst.txt +36 -0
  158. docs/source/_ext/include_attrs.py +213 -0
  159. docs/source/_ext/include_docstring.py +364 -0
  160. docs/source/_ext/include_functions.py +154 -0
  161. docs/source/_ext/include_notebook.py +131 -0
  162. docs/source/_ext/include_yaml.py +119 -0
  163. docs/source/_ext/inherited_attributes.py +222 -0
  164. docs/source/_ext/paths.py +4 -0
  165. docs/source/conf.py +79 -0
  166. examples/arches/compute_in_memory/_include.yaml +74 -0
  167. examples/arches/compute_in_memory/_include_functions.py +229 -0
  168. examples/arches/compute_in_memory/_load_spec.py +57 -0
  169. examples/arches/compute_in_memory/components/c2c_multiplier.py +181 -0
  170. examples/arches/compute_in_memory/components/dac_c2c_r2r.py +605 -0
  171. examples/arches/compute_in_memory/components/misc.py +195 -0
  172. examples/arches/compute_in_memory/components/util/bit_functions.py +51 -0
  173. examples/arches/compute_in_memory/components/zero_comparator.py +92 -0
  174. examples/arches/compute_in_memory/isaac.yaml +233 -0
  175. examples/arches/compute_in_memory/memory_cells/ecram_demo.yaml +63 -0
  176. examples/arches/compute_in_memory/memory_cells/rram_example.yaml +63 -0
  177. examples/arches/compute_in_memory/memory_cells/rram_isaac_isca_2016.yaml +64 -0
  178. examples/arches/compute_in_memory/memory_cells/rram_neurosim_default.yaml +63 -0
  179. examples/arches/compute_in_memory/memory_cells/rram_raella_isca_2023.yaml +70 -0
  180. examples/arches/compute_in_memory/memory_cells/rram_wan_nature_2022.yaml +63 -0
  181. examples/arches/compute_in_memory/memory_cells/sram_colonnade_jssc_2021.yaml +63 -0
  182. examples/arches/compute_in_memory/memory_cells/sram_example.yaml +63 -0
  183. examples/arches/compute_in_memory/memory_cells/sram_jia_jssc_2020.yaml +63 -0
  184. examples/arches/compute_in_memory/memory_cells/sram_sinangil_jssc_2021.yaml +63 -0
  185. examples/arches/compute_in_memory/memory_cells/sram_wang_vlsi_2022.yaml +63 -0
  186. examples/arches/compute_in_memory/wang_vlsi_2022.yaml +289 -0
  187. examples/arches/eyeriss.yaml +68 -0
  188. examples/arches/fanout_variations/at_glb.yaml +31 -0
  189. examples/arches/fanout_variations/at_glb_with_fanout_node.yaml +34 -0
  190. examples/arches/fanout_variations/at_mac.yaml +31 -0
  191. examples/arches/fanout_variations/at_mac_with_constraints.yaml +38 -0
  192. examples/arches/fanout_variations/at_mac_with_fanout_node.yaml +34 -0
  193. examples/arches/nvdla.yaml +47 -0
  194. examples/arches/simple.yaml +28 -0
  195. examples/arches/tpu_v4i.yaml +67 -0
  196. examples/mappings/unfused_matmuls_to_simple.yaml +33 -0
  197. examples/misc/component_annotated.yaml +33 -0
  198. examples/workloads/gpt3_6.7B.yaml +124 -0
  199. examples/workloads/matmuls.yaml +20 -0
  200. examples/workloads/mobilenet_28.yaml +81 -0
  201. examples/workloads/mobilenet_various_separate.yaml +106 -0
  202. examples/workloads/three_matmuls_annotated.yaml +59 -0
  203. notebooks/.ipynb_checkpoints/fastfusion_arch_study_michael-checkpoint.ipynb +359 -0
  204. notebooks/compute_in_memory/_scripts.py +339 -0
  205. notebooks/compute_in_memory/isaac.guide.ipynb +270 -0
  206. notebooks/compute_in_memory/wang_vlsi_2022.ipynb +602 -0
  207. notebooks/paths.py +4 -0
  208. notebooks/tutorials/.ipynb_checkpoints/1_FFM-checkpoint.ipynb +3110 -0
  209. notebooks/tutorials/FFM.ipynb +3498 -0
  210. notebooks/tutorials/_include.py +48 -0
  211. notebooks/tutorials/component_energy_area.ipynb +363 -0
  212. tests/Q_mapping.yaml +38 -0
  213. tests/__init__.py +0 -0
  214. tests/conv.mapping.yaml +27 -0
  215. tests/conv.workload.yaml +13 -0
  216. tests/conv_sym.mapping.yaml +43 -0
  217. tests/copy.mapping.yaml +35 -0
  218. tests/copy.workload.yaml +15 -0
  219. tests/distribuffers/__init__.py +0 -0
  220. tests/distribuffers/multicast/test_cases.yaml +482 -0
  221. tests/distribuffers/spec/binding/valid_bindings.yaml +97 -0
  222. tests/distribuffers/spec/distributed.yaml +100 -0
  223. tests/distribuffers/spec/logical_arch.yaml +32 -0
  224. tests/distribuffers/spec/physical_arch.yaml +69 -0
  225. tests/distribuffers/test_binding.py +48 -0
  226. tests/frontend/__init__.py +0 -0
  227. tests/frontend/test_mapping_viz.py +52 -0
  228. tests/mapper/__init__.py +0 -0
  229. tests/mapper/configs/conv1d/conv1d.mapping.yaml +31 -0
  230. tests/mapper/configs/conv1d/conv1d.workload.yaml +11 -0
  231. tests/mapper/configs/two_conv1d/two_conv1d.expected.yaml +38 -0
  232. tests/mapper/configs/two_conv1d/two_conv1d.mapping.yaml +54 -0
  233. tests/mapper/configs/two_conv1d/two_conv1d.workload.yaml +19 -0
  234. tests/mapper/test_mapping_to_isl.py +90 -0
  235. tests/mapper/test_spatial_reuse_analysis.py +67 -0
  236. tests/mapper/test_temporal_reuse_analysis.py +56 -0
  237. tests/mapper/util.py +58 -0
  238. tests/matmul.mapping.yaml +29 -0
  239. tests/matmul.workload.yaml +12 -0
  240. tests/matmul_spatial.mapping.yaml +44 -0
  241. tests/mha.renames.yaml +65 -0
  242. tests/mha.workload.yaml +67 -0
  243. tests/mha.yaml +59 -0
  244. tests/mha_full.workload.yaml +67 -0
  245. tests/mobilenet.workload.yaml +35 -0
  246. tests/mobilenet_long.workload.yaml +64 -0
  247. tests/pmappingcache.py +24 -0
  248. tests/processing_stage.arch.yaml +40 -0
  249. tests/snowcat.arch.yaml +36 -0
  250. tests/test_ffm_join_pmappings.py +106 -0
  251. tests/test_ffm_make_pmappings.py +82 -0
  252. tests/test_ffm_make_tile_shapes.py +49 -0
  253. tests/test_mapper.py +100 -0
  254. tests/test_model.py +37 -0
  255. tests/test_plotting.py +72 -0
  256. tests/test_processing_stage.py +46 -0
  257. tests/test_symbolic_model.py +248 -0
  258. tests/test_workload.py +141 -0
@@ -0,0 +1,463 @@
1
+ from accelforge.frontend.mapping.mapping import MappingNode
2
+
3
+
4
+ import copy
5
+ from collections import defaultdict
6
+ import itertools
7
+ import logging
8
+ from typing import Any, Iterator, List
9
+ import uuid
10
+
11
+ from tqdm import tqdm
12
+
13
+ import accelforge.frontend.arch as arch
14
+ from accelforge.frontend.mapping import (
15
+ Compute,
16
+ Loop,
17
+ Mapping,
18
+ MappingNode,
19
+ Spatial,
20
+ TensorHolder,
21
+ Temporal,
22
+ )
23
+ from accelforge.frontend.spec import Spec
24
+ from accelforge.frontend._workload_isl._isl import get_rank_variable_bounds
25
+ from accelforge.frontend._workload_isl._symbolic import (
26
+ Relevant,
27
+ get_rank_variable_relevancy,
28
+ get_stride_and_halo,
29
+ get_stride_and_halo_of_einsum,
30
+ PartiallyRelevant,
31
+ )
32
+ from accelforge.frontend.workload import (
33
+ TensorName,
34
+ Einsum,
35
+ EinsumName,
36
+ RankVariable,
37
+ Workload,
38
+ isl_expression_has_variable,
39
+ SymbolTable,
40
+ )
41
+ from accelforge.mapper.FFM._make_pmappings.make_pmapping_templates.make_storage_order import (
42
+ get_tensor_choices,
43
+ )
44
+ from accelforge.mapper.FFM._make_pmappings.make_pmapping_templates.make_reservations import (
45
+ get_reservation_choices,
46
+ )
47
+ from accelforge.mapper.FFM._make_pmappings.contraints.constraints import (
48
+ MappingConstraints,
49
+ get_constraints,
50
+ )
51
+ from accelforge.mapper.FFM._make_pmappings.make_pmapping_templates.make_loops import (
52
+ insert_temporal_loops,
53
+ insert_spatial_loops,
54
+ )
55
+ from accelforge.mapper.FFM._make_pmappings.pmapper_job import (
56
+ Job,
57
+ SameEinsumJobs,
58
+ )
59
+ from accelforge.model._looptree.reuse.symbolic import label_fused_loops
60
+
61
+
62
def unpack_loops_to_rank_variables(mapping: List[MappingNode]):
    """Expand loops whose ``rank_variable`` is a set into one loop per variable.

    Nodes that are not loops, or whose ``rank_variable`` is already a single
    variable, are passed through unchanged. Returns a new list; ``mapping`` is
    not modified.
    """
    result: List[MappingNode] = []
    for node in mapping:
        if isinstance(node, Loop) and isinstance(node.rank_variable, set):
            # One clone per rank variable (sorted for determinism), keeping
            # every other field of the original loop as-is.
            other_fields = node.model_dump(exclude={"rank_variable"}, recursive=False)
            result.extend(
                type(node)(rank_variable=rv, **other_fields)
                for rv in sorted(node.rank_variable)
            )
        else:
            result.append(node)
    return result
77
+
78
+
79
+ # =================================================================================================
80
+ # Iterate over mappings
81
+ # =================================================================================================
82
def place_missing_temporal_loops(
    mapping: List[MappingNode], einsum: Einsum, flattened_arch: list[arch.Leaf]
):
    """
    Adds temporal loops to the mapping (in place) to fill in any rank variables
    that are missing. This may occur if there are no points where it'd be
    helpful to add a non-fused loop, so we just need to add one somewhere.

    The insert point is right under the last backing TensorHolder and below any
    out-of-order fanouts, then below any temporals already at that point.
    """
    # Work on a copy: if einsum.rank_variables is (backed by) shared state,
    # discarding from it directly would corrupt the workload for later callers.
    rank_variables = set(einsum.rank_variables)
    for m in mapping:
        if isinstance(m, Temporal) and not m._fused:
            rank_variables.discard(m.rank_variable)

    # Cumulative fanout of each arch node, outermost to innermost.
    fanouts = {}
    fanout = 1
    for node in flattened_arch:
        fanouts[node.name] = (fanout := fanout * node.get_fanout())

    # Insert point: right under the last backing & below any out-of-order fanouts
    insert_point = 0
    greatest_previous_fanout = 1
    for i in range(len(mapping)):
        if isinstance(mapping[i], TensorHolder):
            if mapping[i]._backing:
                insert_point = i + 1
            cur_fanout = fanouts[mapping[i].component]
            # An out-of-order fanout (smaller than one already seen): move below it.
            if cur_fanout < greatest_previous_fanout:
                insert_point = i + 1
            greatest_previous_fanout = max(greatest_previous_fanout, cur_fanout)

        # Put it below all the other temporals here in case we're lowering through them
        if isinstance(mapping[i], Temporal) and insert_point == i:
            insert_point = i + 1

    temporals = [Temporal(rank_variable=r) for r in sorted(rank_variables)]

    # Slice assignment inserts the new loops in sorted order regardless of
    # position. The previous code used repeated ``insert`` at a fixed index,
    # which reversed the order relative to the append-at-end path.
    mapping[insert_point:insert_point] = temporals
125
+
126
+
127
def remove_unordered_spatial_temporal_loops(
    mapping: list[MappingNode],
    flattened_arch: list[arch.Leaf],
    einsum: Einsum,
    explore_unordered_spatial_loops: bool = True,
):
    """
    Yield variants of ``mapping`` with conflicting spatial/temporal loop pairs
    resolved by removal.

    A temporal loop conflicts with a spatial loop when it sits between that
    spatial loop and the next tensor holder above the spatial's fanout, and
    both loops index into the same indexing expression of ``einsum``. For each
    conflict, either the spatial loop or the conflicting temporals are removed
    (both alternatives are explored when ``explore_unordered_spatial_loops``;
    otherwise only the temporals are removed). Yields one candidate mapping
    per combination of choices; nodes are identified by ``id()``, so the
    yielded lists share node objects with ``mapping``.
    """
    # Cumulative fanout per arch component, outermost to innermost.
    fanout = 1
    fanouts = {}
    for node in flattened_arch:
        fanouts[node.name] = (fanout := fanout * node.get_fanout())

    index_exprs = einsum.indexing_expressions

    # Remove a temporal loop if:
    # - It's between a spatial loop and a storage node above that fanout in the arch
    # - It indexes into one of the same indexing expressions as the spatial loop

    # Each entry is ({id of spatial}, {ids of conflicting temporals}); exactly
    # one of the two sets will be removed per yielded variant.
    disallowed_combinations: list[tuple[set[int], set[int]]] = []
    for i, node in enumerate(mapping):
        if not isinstance(node, Spatial):
            continue

        last_idx_to_check = _idx_of_lowest_tensor_holder_with_component_above_fanout(
            mapping, i, fanouts, node
        )
        to_check = mapping[i + 1 : last_idx_to_check]
        to_remove = set()
        for n in to_check:
            if isinstance(n, Temporal):
                for expr in index_exprs:
                    # Conflict only when BOTH loops' rank variables appear in
                    # the same indexing expression.
                    if not isl_expression_has_variable(expr, node.rank_variable):
                        continue
                    if not isl_expression_has_variable(expr, n.rank_variable):
                        continue
                    to_remove.add(id(n))
                    break

        if to_remove:
            disallowed_combinations.append((set([id(node)]), to_remove))

    if not explore_unordered_spatial_loops:
        # Drop the "remove the spatial" alternative; always remove temporals.
        disallowed_combinations = [x[1:] for x in disallowed_combinations]

    # Cartesian product: pick one removal set per conflict; with no conflicts
    # this yields the original mapping once.
    for combo in itertools.product(*disallowed_combinations):
        combo = set.union(set(), *combo)
        yield [n for n in mapping if id(n) not in combo]
173
+
174
+
175
def _idx_of_lowest_tensor_holder_with_component_above_fanout(
    mapping, start_idx, fanouts, node
):
    """
    Walk the mapping bottom-up and return the index of the lowest
    ``TensorHolder`` whose component sits above ``node``'s fanout (i.e. has a
    strictly smaller cumulative fanout). If no such holder exists below
    ``start_idx``, return ``start_idx + 1``.
    """
    node_fanout = fanouts[node.component]
    j = len(mapping) - 1
    while j > start_idx:
        candidate = mapping[j]
        if isinstance(candidate, TensorHolder):
            if fanouts[candidate.component] < node_fanout:
                return j
        j -= 1
    return start_idx + 1
190
+
191
+
192
def pad_with_bottom_loops(mapping: list[MappingNode], einsum: Einsum):
    """Append (in place) a bottom-level temporal loop for every rank variable
    of ``einsum`` that appears fewer than twice as a temporal loop in
    ``mapping``."""
    temporal_counts = defaultdict(int)
    for node in mapping:
        if isinstance(node, Temporal):
            temporal_counts[node.rank_variable] += 1

    mapping.extend(
        Temporal(rank_variable=rank_var)
        for rank_var in einsum.rank_variables
        if temporal_counts[rank_var] < 2
    )
202
+
203
+
204
def _timeloop_style_even(mapping: list[MappingNode]):
    """
    Return a deep copy of ``mapping`` where, for each memory with more than
    two ``TensorHolder`` nodes, every holder after the second is moved up to
    sit directly below the second occurrence (Timeloop-style "even" levels).
    Also clears ``_lower`` on every holder, since lowering could re-uneven the
    levels this function establishes.
    """
    # Iterate through the mapping. If there are >2 TensorHolder nodes for the same
    # memory, move all below the 2nd to the same level as the 2nd.
    mapping = copy.deepcopy(mapping)
    memory2indices = defaultdict(list)
    i = 0
    while i < len(mapping):
        node = mapping[i]
        if not isinstance(mapping[i], TensorHolder):
            i += 1
            continue
        node: TensorHolder
        seen = memory2indices[node.component]
        mapping[i]._lower = False  # Lowering might re-uneven the reservations

        if len(seen) <= 1:
            # First or second occurrence: record its index and leave it be.
            seen.append(i)
        else:
            # Third-or-later occurrence: move it right below the 2nd occurrence.
            # After pop(i)/insert(seen[-1]+1), index i holds an already-visited
            # node, so advancing i is safe.
            mapping.insert(seen[-1] + 1, mapping.pop(i))
        i += 1
    return mapping
225
+
226
+
227
def assert_proper_fusion_labeling(
    mapping: list[MappingNode],
    fusable_tensors: set[TensorName],
    check_loops: bool = True,
):
    """
    Validate fusion labeling of ``mapping``.

    For each ``TensorHolder``, any fusable tensor it introduces (not held by an
    earlier holder) must be exactly its fusable backing set, and — when
    ``check_loops`` — every loop above it must be labeled fused. Raises
    ``AssertionError`` with a compact mapping dump on violation.
    """
    seen_tensors: set = set()
    for i, t in enumerate(mapping):
        if not isinstance(t, TensorHolder):
            continue

        # Fusable tensors introduced for the first time at this holder.
        new = (set(t.tensors) - seen_tensors) & fusable_tensors

        if check_loops and new:
            for j in (k for k in range(i) if isinstance(mapping[k], Loop)):
                assert mapping[
                    j
                ]._fused, f"Node {j} is not fused in {' '.join(m.compact_str() for m in mapping)}"

        assert (
            t._backing & fusable_tensors
        ) == new, f"Node {i} backing missing {new - t._backing} in {' '.join(m.compact_str() for m in mapping)}"

        seen_tensors |= new
        seen_tensors |= set(t.tensors)
250
+
251
+
252
def get_initial_delta_choices(einsum_name: str, workload: Workload):
    """
    Compute, per rank variable, the set of initial "delta" choices implied by
    propagating stride/halo relationships backwards along every
    producer→consumer chain that starts at ``einsum_name``.

    Returns a ``defaultdict`` mapping each rank variable to a set of deltas;
    unqueried variables default to ``{0}``.
    """
    stride_and_halo = get_stride_and_halo(workload)
    einsum = workload.einsums[einsum_name]

    choices = defaultdict(lambda: set([0]))
    consumer_chains = []
    # DFS over consumer chains; each chain element is (tensor feeding the
    # einsum, einsum). The root uses (None, einsum).
    stack = [[(None, einsum)]]
    while stack:
        cur_chain = stack.pop()
        last_tensor, last_einsum = cur_chain[-1]
        for tensor in last_einsum.output_tensor_names:
            einsums_with_tensor_as_input = workload.einsums_with_tensor_as_input(tensor)

            # Leaf tensor (no consumers): the chain is complete.
            if len(einsums_with_tensor_as_input) == 0:
                consumer_chains.append(cur_chain)

            for next_einsum in einsums_with_tensor_as_input:
                stack.append(cur_chain + [(tensor, next_einsum)])

    # Walk each chain from the final consumer back towards the root einsum,
    # propagating delta choices producer-ward via stride/halo.
    for chain in consumer_chains:
        for (_, producer), (tensor, consumer) in zip(
            list(reversed(chain))[1:], reversed(chain)
        ):
            rank_stride_and_halo = stride_and_halo[(consumer.name, tensor)]
            # NOTE(review): zip truncates to the producer side, so the root
            # (None, einsum) entry only ever appears as producer — this branch
            # looks unreachable; confirm before relying on it.
            if tensor is None:
                break  # done

            for cons_rank_var in consumer.rank_variables:
                for prod_rank_var in producer.rank_variables:
                    # NOTE(review): iterates choices[cons_rank_var] while
                    # adding to choices[prod_rank_var]; assumes producer and
                    # consumer rank variables never alias — confirm.
                    for cons_choice in choices[cons_rank_var]:
                        if (prod_rank_var, cons_rank_var) not in rank_stride_and_halo:
                            continue
                        stride, halo = rank_stride_and_halo[
                            (prod_rank_var, cons_rank_var)
                        ]
                        # Consumer delta of d implies producer delta d*stride + halo.
                        choices[prod_rank_var].add(cons_choice * stride + halo)

    return choices
290
+
291
+
292
def get_ranks_with_tile_pattern(producer_name: EinsumName, workload: Workload):
    """Return the rank variables of ``producer_name`` that admit more than one
    initial delta choice, i.e. the ranks exhibiting a tile pattern."""
    delta_choices = get_initial_delta_choices(producer_name, workload)
    patterned = set()
    for rank_var in workload.einsums[producer_name].rank_variables:
        if len(delta_choices[rank_var]) > 1:
            patterned.add(rank_var)
    return patterned
299
+
300
+
301
def iterate_mappings_no_constraints(
    spec: Spec,
    einsum_name: str,
    flattened_arch: list[arch.Leaf],
    rank_variable_bounds: dict[RankVariable, int],
    job: Job,
) -> Iterator[tuple[Mapping, SymbolTable, arch.Compute, int]]:
    """
    Enumerate candidate mappings for one einsum, before constraints are applied.

    Pipeline per candidate: tensor-placement choices → temporal loops →
    spatial loops → unpack set-valued loops → (optionally) Timeloop-style
    evening → fill missing temporals → fusion labeling + validation.

    Yields ``(mapping, symbol_table, compute, n_orders)`` where ``n_orders``
    is the number of loop orders the template stands for.

    Raises ``ValueError`` if the architecture contains no memory.
    """
    # Outermost memory in the flattened architecture.
    first_memory = None
    for node in flattened_arch:
        if isinstance(node, arch.Memory):
            first_memory = node
            break
    if first_memory is None:
        raise ValueError("No memory found in architecture")

    ranks_with_tile_pattern = get_ranks_with_tile_pattern(einsum_name, spec.workload)

    einsum = spec.workload.einsums[einsum_name]
    # Initial symbol table from the einsum's renames; get_tensor_choices
    # rebinds it per yielded choice below.
    symbol_table = {r.name: r.source for r in einsum.renames}
    fusable_tensors = job.fusable_tensors

    for mapping, symbol_table, compute in get_tensor_choices(
        einsum_name,
        flattened_arch,
        symbol_table,
        spec,
        first_memory,
        fusable_tensors,
    ):
        logging.info(
            "\tGenerated tensor choices: " + ", ".join(m.compact_str() for m in mapping)
        )
        # Deep-copy so each downstream expansion mutates its own nodes.
        mapping = copy.deepcopy(mapping)
        for mapping, n_orders in insert_temporal_loops(
            mapping,
            einsum,
            first_memory,
            rank_variable_bounds,
            ranks_with_tile_pattern,
            spec.workload,
            spec.mapper.ffm._can_lower_outermost_memory,
            flattened_arch,
            spec.mapper.ffm.max_fused_loops,
        ):
            mapping = copy.deepcopy(mapping)
            insert_spatial_loops(mapping, einsum, flattened_arch)
            mapping = unpack_loops_to_rank_variables(mapping)
            if spec.mapper.ffm._timeloop_style_even:
                mapping = _timeloop_style_even(mapping)

            place_missing_temporal_loops(mapping, einsum, flattened_arch)
            label_fused_loops(mapping, fusable_tensors)
            # Sanity check: fails fast if fusion labels are inconsistent.
            assert_proper_fusion_labeling(mapping, fusable_tensors)
            yield mapping, symbol_table, compute, n_orders
355
+
356
+
357
def iterate_mappings_constraints(
    spec: Spec,
    einsum_names: list[str] | str,
    flattened_arch: list[arch.Leaf],
    rank_variable_bounds: dict[RankVariable, int],
    tensor_to_relevancy: dict[
        TensorName, dict[RankVariable, Relevant | PartiallyRelevant]
    ],
    job: Job,
) -> Iterator[tuple[Mapping, MappingConstraints, dict[str, str]]]:
    """
    Enumerate constrained pmapping templates for one or more einsums.

    Wraps :func:`iterate_mappings_no_constraints`, attaches constraints,
    explores removal of unordered spatial/temporal loop conflicts, appends the
    ``Compute`` node, and freezes loop indices. Yields
    ``(mapping, constraints, symbol_table)`` up to
    ``spec.mapper.ffm.max_pmapping_templates_per_einsum`` times total.
    """
    # The compute unit is the innermost node of the flattened architecture.
    compute_name = flattened_arch[-1].name

    n_yielded = 0

    if isinstance(einsum_names, str):
        einsum_names = [einsum_names]

    for einsum_name in einsum_names:
        logging.info(
            f"Generating pmapping templates for compute {compute_name} Einsums "
            f"{einsum_name}"
        )

        for mapping, symbol_table, compute, n_orders in iterate_mappings_no_constraints(
            spec,
            einsum_name,
            flattened_arch,
            rank_variable_bounds,
            job,
        ):
            mapping, constraints = get_constraints(
                flattened_arch, mapping, symbol_table, einsum_name, tensor_to_relevancy
            )

            # This goes after the constraints because constraints may remove some loops,
            # giving us fewer that may be reordered.
            # NOTE(review): the same ``constraints`` object is reused for every
            # variant yielded below — assumed safe; confirm remove_missing_targets
            # / set_loop_indices are idempotent across variants.
            for mapping in remove_unordered_spatial_temporal_loops(
                mapping,
                flattened_arch,
                spec.workload.einsums[einsum_name],
                spec.mapper.ffm.out_of_order_hierarchy_explore_removing_spatials_for_more_temporals,
            ):
                constraints.remove_missing_targets(mapping)

                mapping.append(
                    Compute(
                        einsum=einsum_name,
                        component=compute_name,
                        component_object=compute,
                    )
                )

                # MAPPING MUST NOT BE MODIFIED AFTER constraints.set_loop_indices
                constraints.set_loop_indices(mapping)

                # Shallow-copy nodes so the Mapping owns its node list.
                mapping = Mapping(nodes=[copy.copy(n) for n in mapping])
                mapping._n_loop_orders = n_orders
                yield mapping, constraints, symbol_table
                n_yielded += 1
                if n_yielded >= spec.mapper.ffm.max_pmapping_templates_per_einsum:
                    return
418
+
419
+
420
+ # =================================================================================================
421
+ # Top level
422
+ # =================================================================================================
423
def make_pmapping_templates(job: Job) -> SameEinsumJobs:
    """
    Top-level entry: expand one pmapper ``job`` into a ``SameEinsumJobs``
    collection, one job per generated pmapping template for the job's einsum.

    Side effect: populates ``job.tensor_to_relevancy`` before enumeration.
    """
    compute_name = job.flattened_arch[-1].name

    # Rank-variable relevancy per tensor of this einsum (needed by constraints).
    job.tensor_to_relevancy = {
        tensor: get_rank_variable_relevancy(
            job.spec.workload.einsums[job.einsum_name], tensor
        )
        for tensor in job.spec.workload.einsums[job.einsum_name].tensor_names
    }

    mappings_constraints = tqdm(
        iterate_mappings_constraints(
            job.spec,
            job.einsum_name,
            job.flattened_arch,
            job.rank_variable_bounds,
            job.tensor_to_relevancy,
            job,
        ),
        desc=f"Generating pmapping templates for compute {compute_name} Einsum {job.einsum_name}",
    )

    stride_and_halo = get_stride_and_halo_of_einsum(
        job.einsum_name, job.spec.workload, job.rank_variable_bounds
    )

    jobs = SameEinsumJobs()
    # Debug hook: when set, keep only the template at this index.
    only_output_pmapping_index = job.spec.mapper.ffm._only_output_pmapping_index
    for i, (mapping, constraints, symbol_table) in enumerate(mappings_constraints):
        if only_output_pmapping_index is not None and i != only_output_pmapping_index:
            continue
        # Shallow copy: the new job shares spec/workload with the original but
        # gets its own mapping, constraints, and id.
        new_job = copy.copy(job)
        new_job.mapping = mapping
        new_job.constraints = constraints
        new_job.job_id = uuid.uuid4()
        new_job.rank_variable_bounds = job.rank_variable_bounds
        new_job.stride_and_halo = stride_and_halo
        # NOTE(review): bare attribute access — presumably forces a cached
        # property to materialize eagerly (e.g. before pickling); confirm, as
        # it otherwise has no effect.
        new_job.compatibility
        jobs.append(new_job)

    return jobs
@@ -0,0 +1,95 @@
1
+ from collections.abc import Generator
2
+ from typing import Any
3
+
4
+ import accelforge.frontend.arch as arch
5
+ from accelforge.frontend.mapping import MappingNode, Reservation, Storage, TensorHolder
6
+ from accelforge.frontend.spec import Spec
7
+
8
+
9
+ def _recursive_iter_fence_positions(
10
+ fence_positions: dict[str, int],
11
+ max_size: int,
12
+ ) -> Generator[tuple[list[TensorHolder], Any], None, None]:
13
+ if not fence_positions:
14
+ yield {}
15
+ mine = next(iter(fence_positions))
16
+ myval = fence_positions[mine]
17
+ following = {k: v for k, v in fence_positions.items() if k != mine}
18
+ for i in range(myval, max_size):
19
+ following = {k: max(v, i) for k, v in fence_positions.items() if k != mine}
20
+ for following in _recursive_iter_fence_positions(following, max_size):
21
+ yield {mine: i, **following}
22
+
23
+
24
def get_reservation_choices(
    mapping: list[TensorHolder],
    flattened_arch: list[arch.Leaf],
) -> Generator[tuple[list[TensorHolder], Any], None, None]:
    """
    Yield variants of ``mapping`` with ``Reservation`` nodes interleaved, one
    variant per choice of "fence" position for each fanout component.

    A holder not below any fanout gets its reservation immediately after it;
    holders below a fanout have their reservations deferred until the mapping
    index reaches that fanout's chosen fence position.
    """
    # Rules:
    # - In general, reservations go right under their storage node
    # - If a storage node is associated with a fanout, explore putting the reservation
    #   below it, below the next storage node, and so on. Stop once we don't have any
    #   more spatial loops to place. Push down all reservations below this fanout
    #   together.

    # Spatial loops:
    # - Must go below all storage nodes associated with something above the fanout.
    #   -> Memories above fanout must serve all fetches across fanout instances.
    # - Must go above all reservations associated with something below the fanout.
    #   -> Memories below fanout must be reserved for each fanout instance.
    # - If below any storage node associated with the fanout, then must be relevant.
    #   -> No peer-to-peer communication

    # Temporal loops:
    # - If between a storage node and a reservation node, the outermost temporal loop
    #   may be partially relevant. All others must be relevant.

    # Design choices here:
    # - Where to put the 'fence' for each fanout

    fanout_nodes = [n for n in flattened_arch if n.get_fanout() > 1]
    fanout_node_names = set[str](n.name for n in fanout_nodes)
    last_seen_fanout = None
    node2lastfanout = {}

    # Minimum fence position per fanout: the index of its first holder.
    # node2lastfanout maps each node (by id) to the innermost fanout seen at
    # or above it (None if no fanout yet).
    fence_positions: dict[str, int] = {}
    for i, node in enumerate(mapping):
        if node.component in fanout_node_names:
            fence_positions.setdefault(node.component, i)
            last_seen_fanout = node.component
        node2lastfanout[id(node)] = last_seen_fanout

    def try_add_reservations(
        new_mapping: list[MappingNode],
        reservations_to_add: list[TensorHolder],
        fence_positions: dict[str, int],
    ):
        # NOTE: reads ``i`` as a closure over the enclosing loop variable —
        # it is the CURRENT index of the loop below at call time.
        for res in list(reservations_to_add):
            add = False
            if node2lastfanout[id(res)] is None:
                # Not below any fanout: reservation is never deferred.
                add = True
            elif i >= fence_positions[node2lastfanout[id(res)]]:
                # We've reached this fanout's fence: emit the reservation now.
                add = True
            if add:
                new_mapping.append(
                    Reservation(
                        purposes=[res.component],
                        resource=res.component,
                        persistent=res.persistent,
                    )
                )
                reservations_to_add.remove(res)

    # Fence positions are indices of storage nodes below which we'll push all the
    # reservations below that fanout
    for fence_positions in _recursive_iter_fence_positions(
        fence_positions, len(mapping)
    ):
        new_mapping = []
        reservations_to_add = []
        for i, node in enumerate(mapping):
            new_mapping.append(node)
            reservations_to_add.append(node)
            try_add_reservations(new_mapping, reservations_to_add, fence_positions)
        # Final flush after the loop (i == last index); fence positions are
        # always < len(mapping), so this drains any remaining reservations.
        try_add_reservations(new_mapping, reservations_to_add, fence_positions)
        yield new_mapping