accelforge 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (258)
  1. accelforge/__init__.py +21 -0
  2. accelforge/_accelerated_imports.py +16 -0
  3. accelforge/_deprecate/_simanneal/evalmapping.py +271 -0
  4. accelforge/_deprecate/_simanneal/mapspaceglobals.py +298 -0
  5. accelforge/_deprecate/_simanneal/simanneal.py +666 -0
  6. accelforge/_deprecate/_simanneal/tracking.py +105 -0
  7. accelforge/_deprecate/_simanneal/wrappers.py +218 -0
  8. accelforge/_deprecate/_simanneal2/__init__.py +7 -0
  9. accelforge/_deprecate/_simanneal2/simanneal.py +493 -0
  10. accelforge/_deprecate/_simanneal2/tracking.py +116 -0
  11. accelforge/_deprecate/compatibility_util.py +181 -0
  12. accelforge/_deprecate/layerdeduplication/__init__.py +2 -0
  13. accelforge/_deprecate/layerdeduplication/group_similar_einsums.py +160 -0
  14. accelforge/_deprecate/layerdeduplication/grouped_einsums.py +84 -0
  15. accelforge/_deprecate/mapping_filter_tags/__init__.py +2 -0
  16. accelforge/_deprecate/mapping_filter_tags/ffmt.py +212 -0
  17. accelforge/_deprecate/mapping_filter_tags/onesplit.py +24 -0
  18. accelforge/_deprecate/mapping_filter_tags/util.py +24 -0
  19. accelforge/_deprecate/tags.py +69 -0
  20. accelforge/_deprecate/viz/__init__.py +0 -0
  21. accelforge/_deprecate/viz/interactive.py +159 -0
  22. accelforge/_deprecate/viz/reservationtree.py +307 -0
  23. accelforge/_deprecate/viz/ski_slope.py +88 -0
  24. accelforge/_version.py +15 -0
  25. accelforge/examples.py +39 -0
  26. accelforge/frontend/__init__.py +10 -0
  27. accelforge/frontend/_binding.py +129 -0
  28. accelforge/frontend/_workload_isl/__init__.py +2 -0
  29. accelforge/frontend/_workload_isl/_isl.py +149 -0
  30. accelforge/frontend/_workload_isl/_symbolic.py +141 -0
  31. accelforge/frontend/arch copy.py +1544 -0
  32. accelforge/frontend/arch.py +1642 -0
  33. accelforge/frontend/config.py +63 -0
  34. accelforge/frontend/mapper/__init__.py +5 -0
  35. accelforge/frontend/mapper/ffm.py +126 -0
  36. accelforge/frontend/mapper/mapper.py +7 -0
  37. accelforge/frontend/mapper/metrics.py +30 -0
  38. accelforge/frontend/mapping/__init__.py +1 -0
  39. accelforge/frontend/mapping/mapping.py +1736 -0
  40. accelforge/frontend/model.py +14 -0
  41. accelforge/frontend/renames.py +150 -0
  42. accelforge/frontend/spec copy.py +230 -0
  43. accelforge/frontend/spec.py +301 -0
  44. accelforge/frontend/variables.py +12 -0
  45. accelforge/frontend/workload.py +952 -0
  46. accelforge/mapper/FFM/__init__.py +9 -0
  47. accelforge/mapper/FFM/_join_pmappings/__init__.py +0 -0
  48. accelforge/mapper/FFM/_join_pmappings/compatibility.py +653 -0
  49. accelforge/mapper/FFM/_join_pmappings/compress_pmappings.py +140 -0
  50. accelforge/mapper/FFM/_join_pmappings/join_pmappings.py +703 -0
  51. accelforge/mapper/FFM/_join_pmappings/pmapping_dataframe.py +901 -0
  52. accelforge/mapper/FFM/_join_pmappings/pmapping_group.py +337 -0
  53. accelforge/mapper/FFM/_make_pmappings/contraints/__init__.py +0 -0
  54. accelforge/mapper/FFM/_make_pmappings/contraints/constraints.py +360 -0
  55. accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/__init__.py +1 -0
  56. accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_loops.py +373 -0
  57. accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_pmapping_templates.py +463 -0
  58. accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_reservations.py +95 -0
  59. accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_storage_order.py +382 -0
  60. accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_storages.py +155 -0
  61. accelforge/mapper/FFM/_make_pmappings/make_pmappings.py +411 -0
  62. accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/__init__.py +1 -0
  63. accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/make_pmappings_from_templates.py +407 -0
  64. accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/make_tile_shapes.py +1681 -0
  65. accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/run_model.py +170 -0
  66. accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/symbol_relations.py +174 -0
  67. accelforge/mapper/FFM/_make_pmappings/pmapper_job.py +282 -0
  68. accelforge/mapper/FFM/_pareto_df/df_convention.py +273 -0
  69. accelforge/mapper/FFM/_pareto_df/pareto copy.py +836 -0
  70. accelforge/mapper/FFM/_pareto_df/pareto.py +508 -0
  71. accelforge/mapper/FFM/data.py +61 -0
  72. accelforge/mapper/FFM/main copy.py +236 -0
  73. accelforge/mapper/FFM/main.py +208 -0
  74. accelforge/mapper/FFM/mappings.py +510 -0
  75. accelforge/mapper/FFM/pmappings.py +310 -0
  76. accelforge/mapper/__init__.py +4 -0
  77. accelforge/mapper.py +0 -0
  78. accelforge/model/__init__.py +1 -0
  79. accelforge/model/_looptree/__init__.py +0 -0
  80. accelforge/model/_looptree/accesses.py +335 -0
  81. accelforge/model/_looptree/capacity/__init__.py +1 -0
  82. accelforge/model/_looptree/capacity/aggregators.py +36 -0
  83. accelforge/model/_looptree/capacity/capacity.py +47 -0
  84. accelforge/model/_looptree/energy.py +150 -0
  85. accelforge/model/_looptree/equivalent_ranks.py +29 -0
  86. accelforge/model/_looptree/latency/__init__.py +1 -0
  87. accelforge/model/_looptree/latency/latency.py +98 -0
  88. accelforge/model/_looptree/latency/memory.py +120 -0
  89. accelforge/model/_looptree/latency/processors.py +92 -0
  90. accelforge/model/_looptree/mapping_utilities.py +71 -0
  91. accelforge/model/_looptree/reuse/__init__.py +4 -0
  92. accelforge/model/_looptree/reuse/isl/__init__.py +1 -0
  93. accelforge/model/_looptree/reuse/isl/des.py +59 -0
  94. accelforge/model/_looptree/reuse/isl/isl_functions.py +374 -0
  95. accelforge/model/_looptree/reuse/isl/mapping_to_isl/__init__.py +4 -0
  96. accelforge/model/_looptree/reuse/isl/mapping_to_isl/analyze_mapping.py +297 -0
  97. accelforge/model/_looptree/reuse/isl/mapping_to_isl/skews_from_mapping.py +236 -0
  98. accelforge/model/_looptree/reuse/isl/mapping_to_isl/tiling.py +685 -0
  99. accelforge/model/_looptree/reuse/isl/mapping_to_isl/types.py +188 -0
  100. accelforge/model/_looptree/reuse/isl/spatial.py +260 -0
  101. accelforge/model/_looptree/reuse/isl/temporal.py +182 -0
  102. accelforge/model/_looptree/reuse/symbolic/__init__.py +1 -0
  103. accelforge/model/_looptree/reuse/symbolic/symbolic copy 2.py +1346 -0
  104. accelforge/model/_looptree/reuse/symbolic/symbolic copy.py +1408 -0
  105. accelforge/model/_looptree/reuse/symbolic/symbolic.py +1396 -0
  106. accelforge/model/_looptree/run.py +122 -0
  107. accelforge/model/_looptree/types.py +26 -0
  108. accelforge/model/_looptree/visualization/__init__.py +0 -0
  109. accelforge/model/_looptree/visualization/occupancy.py +11 -0
  110. accelforge/model/main.py +222 -0
  111. accelforge/plotting/__init__.py +2 -0
  112. accelforge/plotting/mappings.py +219 -0
  113. accelforge/plotting/specs.py +57 -0
  114. accelforge/util/__init__.py +4 -0
  115. accelforge/util/_base_analysis_types.py +24 -0
  116. accelforge/util/_basetypes.py +1089 -0
  117. accelforge/util/_frozenset.py +36 -0
  118. accelforge/util/_isl.py +29 -0
  119. accelforge/util/_itertools.py +14 -0
  120. accelforge/util/_mathfuncs.py +57 -0
  121. accelforge/util/_parse_expressions.py +339 -0
  122. accelforge/util/_picklecache.py +32 -0
  123. accelforge/util/_setexpressions.py +268 -0
  124. accelforge/util/_sympy/__init__.py +0 -0
  125. accelforge/util/_sympy/broadcast_max.py +18 -0
  126. accelforge/util/_visualization.py +112 -0
  127. accelforge/util/_yaml.py +579 -0
  128. accelforge/util/parallel.py +193 -0
  129. accelforge-0.0.1.dist-info/METADATA +64 -0
  130. accelforge-0.0.1.dist-info/RECORD +258 -0
  131. accelforge-0.0.1.dist-info/WHEEL +5 -0
  132. accelforge-0.0.1.dist-info/licenses/LICENSE +19 -0
  133. accelforge-0.0.1.dist-info/top_level.txt +5 -0
  134. docs/_build/html/_sources/fastfusion.frontend.mapper.rst.txt +37 -0
  135. docs/_build/html/_sources/fastfusion.frontend.rst.txt +70 -0
  136. docs/_build/html/_sources/fastfusion.frontend.workload.rst.txt +21 -0
  137. docs/_build/html/_sources/fastfusion.mapper.FFM.rst.txt +37 -0
  138. docs/_build/html/_sources/fastfusion.mapper.rst.txt +18 -0
  139. docs/_build/html/_sources/fastfusion.rst.txt +20 -0
  140. docs/_build/html/_sources/fastfusion.util.rst.txt +21 -0
  141. docs/_build/html/_sources/index.rst.txt +87 -0
  142. docs/_build/html/_sources/modules.rst.txt +7 -0
  143. docs/_build/html/_sources/notes/citation.rst.txt +45 -0
  144. docs/_build/html/_sources/notes/definitions.rst.txt +43 -0
  145. docs/_build/html/_sources/notes/faqs.rst.txt +39 -0
  146. docs/_build/html/_sources/notes/modeling/accelerator_energy_latency.rst.txt +72 -0
  147. docs/_build/html/_sources/notes/modeling/component_energy_area.rst.txt +96 -0
  148. docs/_build/html/_sources/notes/modeling/mapping.rst.txt +100 -0
  149. docs/_build/html/_sources/notes/modeling.rst.txt +33 -0
  150. docs/_build/html/_sources/notes/parsing/arithmetic_parsing.rst.txt +136 -0
  151. docs/_build/html/_sources/notes/parsing/setexpressions.rst.txt +63 -0
  152. docs/_build/html/_sources/notes/parsing/yaml_parsing.rst.txt +176 -0
  153. docs/_build/html/_sources/notes/quickstart_and_installation.rst.txt +9 -0
  154. docs/_build/html/_sources/notes/spec/architecture.rst.txt +133 -0
  155. docs/_build/html/_sources/notes/spec/mapping.rst.txt +12 -0
  156. docs/_build/html/_sources/notes/spec/workload.rst.txt +83 -0
  157. docs/_build/html/_sources/notes/spec.rst.txt +36 -0
  158. docs/source/_ext/include_attrs.py +213 -0
  159. docs/source/_ext/include_docstring.py +364 -0
  160. docs/source/_ext/include_functions.py +154 -0
  161. docs/source/_ext/include_notebook.py +131 -0
  162. docs/source/_ext/include_yaml.py +119 -0
  163. docs/source/_ext/inherited_attributes.py +222 -0
  164. docs/source/_ext/paths.py +4 -0
  165. docs/source/conf.py +79 -0
  166. examples/arches/compute_in_memory/_include.yaml +74 -0
  167. examples/arches/compute_in_memory/_include_functions.py +229 -0
  168. examples/arches/compute_in_memory/_load_spec.py +57 -0
  169. examples/arches/compute_in_memory/components/c2c_multiplier.py +181 -0
  170. examples/arches/compute_in_memory/components/dac_c2c_r2r.py +605 -0
  171. examples/arches/compute_in_memory/components/misc.py +195 -0
  172. examples/arches/compute_in_memory/components/util/bit_functions.py +51 -0
  173. examples/arches/compute_in_memory/components/zero_comparator.py +92 -0
  174. examples/arches/compute_in_memory/isaac.yaml +233 -0
  175. examples/arches/compute_in_memory/memory_cells/ecram_demo.yaml +63 -0
  176. examples/arches/compute_in_memory/memory_cells/rram_example.yaml +63 -0
  177. examples/arches/compute_in_memory/memory_cells/rram_isaac_isca_2016.yaml +64 -0
  178. examples/arches/compute_in_memory/memory_cells/rram_neurosim_default.yaml +63 -0
  179. examples/arches/compute_in_memory/memory_cells/rram_raella_isca_2023.yaml +70 -0
  180. examples/arches/compute_in_memory/memory_cells/rram_wan_nature_2022.yaml +63 -0
  181. examples/arches/compute_in_memory/memory_cells/sram_colonnade_jssc_2021.yaml +63 -0
  182. examples/arches/compute_in_memory/memory_cells/sram_example.yaml +63 -0
  183. examples/arches/compute_in_memory/memory_cells/sram_jia_jssc_2020.yaml +63 -0
  184. examples/arches/compute_in_memory/memory_cells/sram_sinangil_jssc_2021.yaml +63 -0
  185. examples/arches/compute_in_memory/memory_cells/sram_wang_vlsi_2022.yaml +63 -0
  186. examples/arches/compute_in_memory/wang_vlsi_2022.yaml +289 -0
  187. examples/arches/eyeriss.yaml +68 -0
  188. examples/arches/fanout_variations/at_glb.yaml +31 -0
  189. examples/arches/fanout_variations/at_glb_with_fanout_node.yaml +34 -0
  190. examples/arches/fanout_variations/at_mac.yaml +31 -0
  191. examples/arches/fanout_variations/at_mac_with_constraints.yaml +38 -0
  192. examples/arches/fanout_variations/at_mac_with_fanout_node.yaml +34 -0
  193. examples/arches/nvdla.yaml +47 -0
  194. examples/arches/simple.yaml +28 -0
  195. examples/arches/tpu_v4i.yaml +67 -0
  196. examples/mappings/unfused_matmuls_to_simple.yaml +33 -0
  197. examples/misc/component_annotated.yaml +33 -0
  198. examples/workloads/gpt3_6.7B.yaml +124 -0
  199. examples/workloads/matmuls.yaml +20 -0
  200. examples/workloads/mobilenet_28.yaml +81 -0
  201. examples/workloads/mobilenet_various_separate.yaml +106 -0
  202. examples/workloads/three_matmuls_annotated.yaml +59 -0
  203. notebooks/.ipynb_checkpoints/fastfusion_arch_study_michael-checkpoint.ipynb +359 -0
  204. notebooks/compute_in_memory/_scripts.py +339 -0
  205. notebooks/compute_in_memory/isaac.guide.ipynb +270 -0
  206. notebooks/compute_in_memory/wang_vlsi_2022.ipynb +602 -0
  207. notebooks/paths.py +4 -0
  208. notebooks/tutorials/.ipynb_checkpoints/1_FFM-checkpoint.ipynb +3110 -0
  209. notebooks/tutorials/FFM.ipynb +3498 -0
  210. notebooks/tutorials/_include.py +48 -0
  211. notebooks/tutorials/component_energy_area.ipynb +363 -0
  212. tests/Q_mapping.yaml +38 -0
  213. tests/__init__.py +0 -0
  214. tests/conv.mapping.yaml +27 -0
  215. tests/conv.workload.yaml +13 -0
  216. tests/conv_sym.mapping.yaml +43 -0
  217. tests/copy.mapping.yaml +35 -0
  218. tests/copy.workload.yaml +15 -0
  219. tests/distribuffers/__init__.py +0 -0
  220. tests/distribuffers/multicast/test_cases.yaml +482 -0
  221. tests/distribuffers/spec/binding/valid_bindings.yaml +97 -0
  222. tests/distribuffers/spec/distributed.yaml +100 -0
  223. tests/distribuffers/spec/logical_arch.yaml +32 -0
  224. tests/distribuffers/spec/physical_arch.yaml +69 -0
  225. tests/distribuffers/test_binding.py +48 -0
  226. tests/frontend/__init__.py +0 -0
  227. tests/frontend/test_mapping_viz.py +52 -0
  228. tests/mapper/__init__.py +0 -0
  229. tests/mapper/configs/conv1d/conv1d.mapping.yaml +31 -0
  230. tests/mapper/configs/conv1d/conv1d.workload.yaml +11 -0
  231. tests/mapper/configs/two_conv1d/two_conv1d.expected.yaml +38 -0
  232. tests/mapper/configs/two_conv1d/two_conv1d.mapping.yaml +54 -0
  233. tests/mapper/configs/two_conv1d/two_conv1d.workload.yaml +19 -0
  234. tests/mapper/test_mapping_to_isl.py +90 -0
  235. tests/mapper/test_spatial_reuse_analysis.py +67 -0
  236. tests/mapper/test_temporal_reuse_analysis.py +56 -0
  237. tests/mapper/util.py +58 -0
  238. tests/matmul.mapping.yaml +29 -0
  239. tests/matmul.workload.yaml +12 -0
  240. tests/matmul_spatial.mapping.yaml +44 -0
  241. tests/mha.renames.yaml +65 -0
  242. tests/mha.workload.yaml +67 -0
  243. tests/mha.yaml +59 -0
  244. tests/mha_full.workload.yaml +67 -0
  245. tests/mobilenet.workload.yaml +35 -0
  246. tests/mobilenet_long.workload.yaml +64 -0
  247. tests/pmappingcache.py +24 -0
  248. tests/processing_stage.arch.yaml +40 -0
  249. tests/snowcat.arch.yaml +36 -0
  250. tests/test_ffm_join_pmappings.py +106 -0
  251. tests/test_ffm_make_pmappings.py +82 -0
  252. tests/test_ffm_make_tile_shapes.py +49 -0
  253. tests/test_mapper.py +100 -0
  254. tests/test_model.py +37 -0
  255. tests/test_plotting.py +72 -0
  256. tests/test_processing_stage.py +46 -0
  257. tests/test_symbolic_model.py +248 -0
  258. tests/test_workload.py +141 -0
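
The hunk below is the full contents of a single added file; its 685 added lines match entry 98 above (accelforge/model/_looptree/reuse/isl/mapping_to_isl/tiling.py), so it appears to be that file.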
@@ -0,0 +1,685 @@
"""
File for all the functions that conduct tiling analysis for the overall mapping
analysis.
"""

from collections import defaultdict, deque
from typing import List, Tuple, Optional

from pprint import pformat

import islpy as isl

from accelforge.frontend.mapping import (
    # Types
    MappingNode,
    # Mapping objects
    Mapping,
    MappingNodeWithChildren,
    Nested,
    # Physical object types in Mappings.
    Compute,
    Storage,
    # Logical object types in Mappings.
    Loop,
    Spatial,
    Temporal,
    Split,
)
from accelforge.frontend.workload import (
    # Workload class for all of AccelForge.
    Workload,
)
from accelforge.frontend._workload_isl._isl import (
    get_einsum_operation_space,
    get_projection_map,
)
from accelforge.frontend.mapping import TensorName
from accelforge.model._looptree.reuse.isl.isl_functions import (
    add_dims_preserve_name_map,
    insert_dims_preserve_name_map,
    map_to_prior_coordinate,
)
from accelforge.model._looptree.reuse.isl.mapping_to_isl import DUMP_ISL_IR
from accelforge.model._looptree.reuse.isl.mapping_to_isl.types import (
    EinsumName,
    Tiling,
    BranchTiling,
)


def get_mapping_group_einsums(
    mapping: Mapping,
) -> defaultdict[MappingNode, set[EinsumName]]:
    """
    From a mapping, get the group of einsums for a given node.

    Parameters
    ----------
    mapping:
        The mapping we are getting the grouped einsums for.

    Returns
    -------
    A dictionary relating a MappingNode to a set of einsums.
    """
    # Each pair is a (current_node, last_non_branch_node).
    dfs_stack: deque[Tuple[MappingNode, MappingNode]] = deque()
    # Each pair is a (last_non_branch_node, set_of_children_nodes).
    child_stack: deque[Tuple[MappingNode, set[MappingNode]]] = deque()
    result: defaultdict[MappingNode, set[EinsumName]] = defaultdict(set)

    # Start the DFS hierarchical search from the root.
    dfs_stack.append((mapping, mapping))

    # Exhaustive DFS search.
    while dfs_stack:
        # Grabs the latest node to search.
        node, last_non_branch = dfs_stack.pop()

        # Differentiates behavior by number of child nodes.
        match node:
            case MappingNodeWithChildren():
                match len(node.nodes):
                    # No children, log as a folded result.
                    case 0:
                        # Note: check necessary in case Distribuffers elides
                        # computes into one large unit.
                        if isinstance(node, Compute):
                            result[last_non_branch].add(node.einsum)
                        else:
                            raise TypeError(
                                f"The following node should be of class "
                                f"Compute as it has no children:\n---\n{node}"
                            )
                    # Explore the children further.
                    case 1:
                        dfs_stack.append((node.nodes[0], last_non_branch))
                    # Log all branching children and explore all children.
                    case _:
                        children: set[MappingNode] = set(node.nodes)
                        child_stack.append((last_non_branch, children))
                        dfs_stack.extend((child, child) for child in children)
            # Assumed no children, log as a folded result.
            case Compute():
                result[last_non_branch].add(node.einsum)
            # These had children in Timeloop that we had to add to the DFS, but
            # because of our extension of dfs_stack we can just skip this node.
            case Spatial() | Temporal() | Storage():
                continue
            case _:
                raise AttributeError(
                    f"The following node of class {type(node)} has an "
                    f"indeterminate number of children:\n---\n"
                    f"{node}"
                )

    # Push einsums up to parents.
    for node, children in reversed(child_stack):
        node_einsum_set: set[EinsumName] = result[node]
        for child in children:
            node_einsum_set.update(result[child])

    return result


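For intuition, here is a minimal sketch of the same grouping idea on a plain tuple tree, assuming nothing from accelforge (the node layout and names are illustrative only). Chains collapse into their nearest branching ancestor's group, exactly as the last_non_branch bookkeeping above does iteratively.

from collections import defaultdict

# Each node is ("name", (child, ...)); a leaf stands in for a Compute/einsum.
def group_einsums(node, group_root, result):
    name, children = node
    if not children:                 # leaf: record the einsum under its group
        result[group_root].add(name)
    elif len(children) == 1:         # chain: stay in the same group
        group_einsums(children[0], group_root, result)
    else:                            # branch: each child starts its own group
        for child in children:
            group_einsums(child, child[0], result)
            result[group_root] |= result[child[0]]

tree = ("root", (("A", ()), ("chain", (("B", ()),))))
groups = defaultdict(set)
group_einsums(tree, "root", groups)
assert groups["root"] == {"A", "B"} and groups["chain"] == {"B"}
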
def get_head_among_einsums(
    einsum_set: set[EinsumName], workload: Workload
) -> set[EinsumName]:
    """
    Gets the head einsums: those whose outputs are not consumed by any other
    einsum in the set (i.e., sink einsums).

    Parameters
    ----------
    einsum_set:
        Set of einsums to consider.
    workload:
        The workload context the einsums exist in.

    Returns
    -------
    The set of all head einsums.
    """
    # Returns the set of einsums whose outputs no einsum in the set reads.
    return {
        einsum
        for einsum in einsum_set
        if all(
            not any(
                consumer.name in einsum_set
                for consumer in workload.einsums_with_tensor_as_input(output_tensor)
            )
            for output_tensor in workload.einsums[einsum].output_tensor_names
        )
    }


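A toy check of the head condition, with a plain dict standing in for the Workload (the names here are made up for illustration):

from typing import Dict, Set, Tuple

# einsum -> (input tensors, output tensors)
einsums: Dict[str, Tuple[Set[str], Set[str]]] = {
    "E1": ({"A"}, {"T"}),  # produces the intermediate T
    "E2": ({"T"}, {"Z"}),  # consumes T; nothing in the set reads Z
}

def heads(einsum_set: Set[str]) -> Set[str]:
    return {
        e
        for e in einsum_set
        if all(
            not any(out in einsums[o][0] for o in einsum_set)
            for out in einsums[e][1]
        )
    }

assert heads({"E1", "E2"}) == {"E2"}  # E2's outputs feed nothing in the set
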
def add_new_tile_dim(
    old_tiling: Tiling, dim_idx: int, tile_size: int, rank_var: Optional[str] = None
) -> Tiling:
    """
    Given a tiling, add a new dimension to the tiling.

    Parameters
    ----------
    old_tiling:
        The previous tiling the mapper proposed.
    dim_idx:
        The index of the dimension being tiled.
    tile_size:
        The size of the tiling on dim_idx.
    rank_var:
        Rank variable name to assign to the new input dimension, if provided.

    Returns
    -------
    The new Tiling with the tiled dimension at dim_idx.
    """
    # new_tiling has one extra dimension at the end compared to old_tiling.
    new_tiling = insert_dims_preserve_name_map(
        old_tiling, isl.dim_type.in_, old_tiling.dim(isl.dim_type.in_), 1
    )
    if rank_var:
        new_tiling = new_tiling.set_dim_name(
            isl.dim_type.in_, old_tiling.dim(isl.dim_type.in_), rank_var
        )

    # Min and max of the dimension being tiled (dim_idx) as a function of the
    # tiled dimensions.
    dim_min: isl.PwAff = new_tiling.dim_min(dim_idx)
    dim_max: isl.PwAff = new_tiling.dim_max(dim_idx)

    # Aff from the tiled-dimensions space to the value of the newest dim.
    new_dim_id: isl.Aff = isl.Aff.var_on_domain(
        dim_min.get_domain_space().to_local_space(),
        isl.dim_type.set,
        dim_min.dim(isl.dim_type.in_) - 1,
    )

    # Aff from the tiled-dimensions space to the tile size constant.
    tile_size_aff: isl.Aff = isl.Aff.val_on_domain_space(
        dim_min.get_domain_space(), isl.Val.int_from_ui(isl.DEFAULT_CONTEXT, tile_size)
    )

    # PwAff from the tiled-dimensions space to tile_size * newest_dim.
    tile_translate: isl.PwAff = isl.PwAff.from_aff(new_dim_id.mul(tile_size_aff))

    # What dim_min should be given the new tiling.
    new_dim_min: isl.PwAff = dim_min.add(tile_translate)

    # What dim_max should be given the new tiling.
    new_dim_max: isl.PwAff = new_dim_min.add(
        isl.PwAff.from_aff(tile_size_aff.add_constant_val(-1))
    )

    # TODO: Might be logically equivalent to new_dim_id:
    # https://github.com/NVlabs/timeloop/blob/32370826fdf1aa3c8deb0c93e6b2a2fc7cf053aa/src/loop-analysis/mapping-to-isl/tiling.cpp#L52-L59
    new_iter_id: isl.PwAff = isl.PwAff.from_aff(
        isl.Aff.var_on_domain(
            new_tiling.get_space().domain(),
            isl.dim_type.set,
            old_tiling.dim(isl.dim_type.in_),
        )
    )

    # The set of valid values of the new tiled dimensions.
    iter_set: isl.Set = new_tiling.domain()
    iter_set = iter_set.intersect(new_iter_id.le_set(dim_max.div(tile_size_aff).ceil()))
    iter_set = iter_set.intersect(new_dim_min.ge_set(dim_min))

    # The value of iter dims cannot exceed what was available before tiling.
    new_tiling = new_tiling.intersect_domain(iter_set)

    # The set of operations needs to follow the new tile bounds.
    identity: isl.PwAff = isl.PwAff.from_aff(
        isl.Aff.var_on_domain(new_tiling.get_space().range(), isl.dim_type.set, dim_idx)
    )
    new_tiling = new_tiling.intersect(new_dim_min.le_map(identity))
    new_tiling = new_tiling.intersect(new_dim_max.ge_map(identity))

    return new_tiling


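The bound manipulation above is easier to see on a toy islpy map. This sketch assumes only stock islpy and is not part of the package:

import islpy as isl

# A 1-D iteration space of 12 operations split into tiles of size 4:
# tile t covers operations 4t .. 4t + 3.
tiling = isl.Map("{ [t] -> [op] : 4t <= op <= 4t + 3 and 0 <= op <= 11 }")

print(tiling.dim_min(0))  # per-tile lower bound of op, as a PwAff of t: 4t
print(tiling.dim_max(0))  # per-tile upper bound of op: 4t + 3
print(tiling.domain())    # valid tile indices: { [t] : 0 <= t <= 2 }
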
def shared_input_based_tile_shape_inference(
    workload: Workload,
    tiling_info: defaultdict[EinsumName, Tiling],
    einsums: set[EinsumName],
    shared_input_tensor: TensorName,
    tiled_einsum: EinsumName,
) -> None:
    """
    Given a `tiled_einsum` in a `workload`, restrict the other `einsums`'
    execution in this tiling to one in which the data is shared with the
    `tiled_einsum`. This is because, when tiled, data is multicast, so the
    other einsums being tiled together must share data.

    Parameters
    ----------
    workload:
        The workload context the tiling is occurring in.
    tiling_info:
        Relation of `EinsumName` and its viable tiling on hardware.
    einsums:
        The set of einsums being tiled together.
    shared_input_tensor:
        The singular tensor `einsums` all read from.
    tiled_einsum:
        The einsum being tiled.

    Returns
    -------
    None

    Postconditions
    --------------
    `tiling_info` is updated such that each Tiling contains only tilings
    compatible with `tiled_einsum`.
    """
    # Gets the data tiled_einsum reads from shared_input_tensor.
    tiled_einsum_read_accesses: isl.Map = get_projection_map(
        workload.einsums[tiled_einsum], shared_input_tensor
    )
    read_data: isl.Map = tiling_info[tiled_einsum].apply_range(
        tiled_einsum_read_accesses
    )

    # Goes through all other einsums and restricts their tilings to only the
    # operations executable after one of the einsums is tiled.
    for einsum in einsums:
        if einsum == tiled_einsum:
            continue

        read_accesses: isl.Map = get_projection_map(
            workload.einsums[einsum], shared_input_tensor
        )
        executable_operations: isl.Map = read_data.apply_range(read_accesses.reverse())
        executable_operations = executable_operations.intersect_range(
            get_einsum_operation_space(workload, einsum)
        )

        tiling_info[einsum] = tiling_info[einsum].intersect(executable_operations)


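The composition pattern above (tile, to data it reads, to the other einsum's operations on that data) looks like this on toy islpy maps; the tensor and einsum shapes are invented for illustration:

import islpy as isl

tiling = isl.Map("{ [t] -> [op] : 2t <= op <= 2t + 1 and 0 <= op <= 3 }")
reads_a = isl.Map("{ [op] -> [d] : d = op }")        # tiled einsum reads element op
reads_b = isl.Map("{ [op2] -> [d] : d = 3 - op2 }")  # other einsum reads mirrored

read_data = tiling.apply_range(reads_a)           # tile -> data it touches
ops_b = read_data.apply_range(reads_b.reverse())  # tile -> co-tiled ops of B
print(ops_b)  # tile t of A runs alongside ops 2 - 2t .. 3 - 2t of B
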
def consumer_based_tile_shape_inference(
    workload: Workload,
    tiling_info: defaultdict[EinsumName, Tiling],
    tensor_to_reuse_level: defaultdict[TensorName, int],
    einsums: set[EinsumName],
    tiled_einsum: EinsumName,
):
    """
    Given a `tiled_einsum` in a `workload`, restrict the other `einsums`'
    execution in this tiling to one in which the data is required for the
    tensors read by `tiled_einsum`. This is because, when tiled, data is
    multicast, so the other einsums being tiled together must share data.

    Parameters
    ----------
    workload:
        The workload context the tiling is occurring in.
    tiling_info:
        Relation of `EinsumName` and its viable tiling on hardware.
    tensor_to_reuse_level:
        A relation between a tensor and the amount of reuse occurring.
    einsums:
        The set of einsums being tiled together.
    tiled_einsum:
        The einsum being tiled.

    Returns
    -------
    None

    Postconditions
    --------------
    `tiling_info` is updated such that each Tiling contains only tilings
    compatible with `tiled_einsum`.
    """
    # Goes recursively through tensor dependencies (read tensors) and tiles them.
    queue: deque[EinsumName] = deque([tiled_einsum])
    while queue:
        einsum: EinsumName = queue.popleft()
        tiling: Tiling = tiling_info[einsum]

        # For each tensor read by this einsum, tile that tensor's producers.
        for tensor in workload.einsums[einsum].input_tensor_names:
            # Producer einsums are those that write `tensor`.
            producer_einsums: set[EinsumName] = {
                name
                for name, producer in workload.einsums.items()
                if tensor in producer.output_tensor_names
            }
            if len(producer_einsums) > 1:
                raise NotImplementedError(
                    "Tile shape inference cannot handle multiple einsums "
                    "writing the same tensor."
                )

            # Not an intermediate tensor.
            if not producer_einsums:
                continue

            producer_einsums.intersection_update(einsums)
            # No producer einsum in this fusion set.
            if not producer_einsums:
                continue

            # Collates the consumer einsum's read accesses.
            producer_einsum: EinsumName = next(iter(producer_einsums))
            read_accesses: isl.Map = get_projection_map(
                workload.einsums[einsum], tensor
            )
            # Required data of the tiling as a mapping of read accesses.
            required_data: isl.Map = tiling.apply_range(read_accesses)

            # Calculates the data computed by the producer einsums.
            computed_data: isl.Map = required_data
            if tensor in tensor_to_reuse_level:
                reuse_level: int = tensor_to_reuse_level[tensor]
                shifter: isl.Map = map_to_prior_coordinate(
                    tiling.dim(isl.dim_type.in_),
                    reuse_level,
                    tiling.get_tuple_name(isl.dim_type.in_),
                )
                buffered_data: isl.Map = shifter.apply_range(required_data)
                computed_data = computed_data.subtract(buffered_data).coalesce()

            # Grabs the elements this tensor relies on from producer_einsum.
            producer_write_dependency: isl.Map = get_projection_map(
                workload.einsums[producer_einsum], tensor
            )
            # Gets the operations required to produce the current tensor.
            required_operations: isl.Map = computed_data.apply_range(
                producer_write_dependency.reverse()
            )
            required_operations = required_operations.intersect_range(
                get_einsum_operation_space(workload, producer_einsum)
            )

            # Mutates the tilings of producer einsums.
            # TODO: Deal with fusing naming better (perhaps mix the names?)
            tiling_info[producer_einsum] = tiling_info[producer_einsum].intersect(
                required_operations.set_tuple_name(
                    isl.dim_type.in_,
                    tiling_info[producer_einsum].get_tuple_name(isl.dim_type.in_),
                )
            )

            queue.append(producer_einsum)


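The reuse adjustment above (subtract what an earlier iteration already buffered) has this shape in plain islpy; the shift map stands in for map_to_prior_coordinate, whose exact behavior is internal to accelforge:

import islpy as isl

required = isl.Map("{ [t] -> [d] : t <= d <= t + 2 and 0 <= t <= 3 }")
shift = isl.Map("{ [t] -> [s] : s = t - 1 }")   # current step -> prior step
buffered = shift.apply_range(required)          # data step t - 1 already held
fresh = required.subtract(buffered).coalesce()
print(fresh)  # at t = 0 everything is new; for t >= 1 only d = t + 2 is new
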
def detect_shared_input_tensor(
    fused_set: set[EinsumName], workload: Workload
) -> List[TensorName]:
    """
    Given a set of fused einsums on a workload, detect the input tensor that
    they all depend on, if it exists.

    Parameters
    ----------
    fused_set:
        The set of fused einsums being analyzed.
    workload:
        The workload context the einsums exist in.

    Returns
    -------
    The list of input tensors shared by all the einsums. Because we default to
    consumer-based analysis if there is more than one shared input, we only
    return lists of length 0, 1, or 2.
    """
    n_einsums: int = 0
    tensor_read_counts: defaultdict[TensorName, int] = defaultdict(int)

    # Counts the number of times a tensor is read by an einsum.
    for einsum in fused_set:
        for tensor in workload.einsums[einsum].input_tensor_names:
            tensor_read_counts[tensor] += 1
        n_einsums += 1

    shared_input_tensors: List[TensorName] = []
    for tensor, count in tensor_read_counts.items():
        # Tensor is shared by all einsums.
        if count == n_einsums:
            shared_input_tensors.append(tensor)
        # Caller should resort to consumer-based fusing methods.
        if len(shared_input_tensors) > 1:
            return shared_input_tensors

    return shared_input_tensors


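A toy run of the shared-input count, with a plain dict standing in for the Workload (tensor names invented):

from collections import defaultdict

inputs = {"E1": ["A", "W"], "E2": ["B", "W"]}  # einsum -> tensors it reads
counts: defaultdict[str, int] = defaultdict(int)
for tensors in inputs.values():
    for tensor in tensors:
        counts[tensor] += 1

shared = [t for t, c in counts.items() if c == len(inputs)]
assert shared == ["W"]  # W is read by every einsum in the set
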
def tiling_from_mapping(mapping: Mapping, workload: Workload) -> BranchTiling:
    """
    Given a mapping and a workload, generate a tiling.

    Parameters
    ----------
    mapping:
        A mapping of data to hardware.
    workload:
        The problem being solved.

    Returns
    -------
    BranchTiling associating each compute node with its tiling.
    """
    result: BranchTiling = BranchTiling()
    # Groups the einsums under each node and grabs each group's head einsums.
    mapping_groups: defaultdict[MappingNode, set[EinsumName]] = (
        get_mapping_group_einsums(mapping)
    )
    mapping_group_heads: defaultdict[MappingNode, set[EinsumName]] = defaultdict(
        set,
        {
            node: get_head_among_einsums(group, workload)
            for node, group in mapping_groups.items()
        },
    )

    # No default factory: membership is checked before every read.
    tensor_to_reuse_level: defaultdict[TensorName, int] = defaultdict()
    dfs_stack: deque[MappingNode] = deque([mapping])  # DFS starts at the mapping root.

    # Maps each last non-branch node to the tiling of each einsum in its group.
    tiling_info: defaultdict[MappingNode, defaultdict[EinsumName, Tiling]] = (
        defaultdict(defaultdict)
    )

    # Appends info for the root.
    for einsum_name in workload.einsum_names:
        tiling_info[mapping][einsum_name] = isl.Map.from_range(
            get_einsum_operation_space(workload, einsum_name)
        ).set_tuple_name(isl.dim_type.in_, f"{einsum_name}_tiled_iteration")

    # Tracks each rank_var's next partitioned_rank_var index, as traversal
    # in tiling goes down the partition.
    rank_var_partitions: defaultdict[str, int] = defaultdict(lambda: 0)

    def _get_rank_var_partition(rank_var: str) -> str:
        """
        Given a rank_var, get the partition at the current point in execution
        and increment for the next retrieval.
        """
        nonlocal rank_var_partitions
        rank_var_partition: str = f"{rank_var}{rank_var_partitions[rank_var]}"
        rank_var_partitions[rank_var] += 1
        return rank_var_partition

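    # For example, two Loop nodes over rank "M" would be assigned the
    # partition names "M0" and then "M1" by successive calls.
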
    def _tile_branch(heads: set[EinsumName], fusing_node: MappingNode):
        """
        Given a set of `heads` to fuse at `fusing_node`, fuse as much as
        possible in this branch.

        Parameters
        ----------
        heads:
            The heads being fused.
        fusing_node:
            The node in the mapping at which the fusing is happening.

        Preconditions
        -------------
        1. `dfs_stack`: initialized with tiles to proceed to explore.
        2. `tiling_info`: initially populated.
        3. `tensor_to_reuse_level`: initialized and unmutated since the last
           time this function was run.

        Postconditions
        --------------
        1. `dfs_stack`: progressed to the next node to tile at.
        2. `tiling_info`: updated to include the fusing and tiling.
        3. `tensor_to_reuse_level`: populated if information has changed from
           tiling.
        """
        nonlocal dfs_stack
        nonlocal tiling_info
        nonlocal tensor_to_reuse_level

        current_node: MappingNode = fusing_node
        while True:
            # Fuses current_node to one of the heads.
            match current_node:
                # For or Par-For loop handling.
                case Loop():
                    if len(heads) != 1:
                        raise ValueError(
                            f"Cannot fuse tiled set with {len(heads)} heads.\n"
                        )

                    # Tiles `current_node.rank_variable` at `head`.
                    head = next(iter(heads))
                    tiling: Tiling = tiling_info[fusing_node][head]
                    # `next(iter(heads))` is stable here: `heads` is a concrete
                    # set, not an AbstractSet.
                    idx: int = tuple(workload.einsums[head].rank_variables).index(
                        current_node.rank_variable
                    )

                    # Adds a new tile dim to the old tiling.
                    # TODO: Handle stride.
                    initial_tile_shape = current_node.tile_pattern.initial_tile_shape
                    if (
                        isinstance(initial_tile_shape, int)
                        and initial_tile_shape != 0
                        and initial_tile_shape == current_node.tile_pattern.tile_shape
                    ):
                        tiling = add_new_tile_dim(
                            tiling,
                            idx,
                            initial_tile_shape,
                            _get_rank_var_partition(current_node.rank_variable),
                        )
                    else:
                        raise NotImplementedError(
                            f"Tile size analysis not implemented for type "
                            f"{type(current_node)} with tile shape "
                            f"{initial_tile_shape}"
                        )

                    # Saves the fused tiling.
                    tiling_info[fusing_node][head] = tiling

                    # Adds the ranks to the tiling isl.Map.
                    iteration_set: isl.Set = tiling.domain()
                    for einsum in mapping_groups[fusing_node] - {head}:
                        tiling = tiling_info[fusing_node][einsum]
                        # Index variables for the branch.
                        tiling = insert_dims_preserve_name_map(
                            tiling, isl.dim_type.in_, tiling.dim(isl.dim_type.in_), 1
                        )
                        tiling = tiling.set_dim_name(
                            isl.dim_type.in_,
                            tiling.dim(isl.dim_type.in_) - 1,
                            _get_rank_var_partition(current_node.rank_variable),
                        )
                        # TODO: Figure out if this intersection is correct.
                        tiling = tiling.intersect_domain(
                            iteration_set.set_tuple_name(
                                tiling.get_tuple_name(isl.dim_type.in_)
                            )
                        )
                        tiling_info[fusing_node][einsum] = tiling

                    current_node = dfs_stack.pop()
                # Notes what reuse level the tensor is on.
                case Storage():
                    # Only the first (highest) Storage level seen for a tensor
                    # determines its reuse level.
                    for tensor in current_node.tensors:
                        if tensor not in tensor_to_reuse_level:
                            random_einsum: EinsumName = next(
                                iter(mapping_groups[fusing_node])
                            )
                            tiling: Tiling = tiling_info[fusing_node][random_einsum]
                            tensor_to_reuse_level[tensor] = tiling.dim(
                                isl.dim_type.in_
                            )

                    current_node = dfs_stack.pop()
                # If we are at the Mapping root, just go to the actual nodes.
                case Mapping():
                    dfs_stack.extend(reversed(current_node.nodes))
                    current_node = dfs_stack.pop()
                # If we hit the compute node, we've finished tiling; end!
                case Compute():
                    result[current_node] = tiling_info[fusing_node][
                        current_node.einsum
                    ]
                    return
                case Split():
                    fused_set: set[EinsumName] = mapping_groups[fusing_node]
                    if len(heads) != 1:
                        # There can't be a tiling, so no inference to be done.
                        break

                    random_head = next(iter(heads))
                    shared_inputs = detect_shared_input_tensor(fused_set, workload)
                    if len(shared_inputs) == 1:
                        shared_input_based_tile_shape_inference(
                            workload,
                            tiling_info[fusing_node],
                            fused_set,
                            shared_inputs[0],
                            random_head,
                        )
                    else:
                        consumer_based_tile_shape_inference(
                            workload,
                            tiling_info[fusing_node],
                            tensor_to_reuse_level,
                            fused_set,
                            random_head,
                        )

                    # Goes through each child node of the current node and
                    # propagates the tiling updates.
                    for idx, child in enumerate(current_node.nodes):
                        # Each child needs tilings for all einsums in its group.
                        group: set[EinsumName] = mapping_groups[child]
                        tilings: defaultdict[EinsumName, Tiling] = defaultdict()

                        # For all einsums the child is involved in, update
                        # their tilings.
                        for einsum in group:
                            tiling: Tiling = tiling_info[fusing_node][einsum]
                            # Add a dimension that iterates over branches.
                            new_tiling: Tiling = add_dims_preserve_name_map(
                                tiling, isl.dim_type.in_, 1
                            )

                            tilings[einsum] = new_tiling.fix_input_si(
                                new_tiling.dim(isl.dim_type.in_) - 1, idx
                            )

                        # Update the tiling info for the child.
                        tiling_info[child] = tilings
                        # DFS tile on the child.
                        dfs_stack.append(child)

                    return
                case Nested():
                    dfs_stack.extend(reversed(current_node.nodes))
                    current_node = dfs_stack.pop()
                case _:
                    raise NotImplementedError(
                        f"Type {type(current_node)} not handled.\n"
                        f"---\n"
                        f"node={pformat(current_node)}"
                    )

    while dfs_stack:
        fusing_node = dfs_stack.pop()
        if DUMP_ISL_IR:
            print(f"New Tiling Root: {pformat(fusing_node)}")
        _tile_branch(mapping_group_heads[fusing_node], fusing_node)

    return result
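
A hypothetical use of the entry point, assuming the frontend supplies the mapping and workload and that BranchTiling is dict-like (neither is confirmed by this diff):

# branch_tiling = tiling_from_mapping(mapping, workload)
# for compute_node, tiling in branch_tiling.items():
#     print(compute_node.einsum, tiling)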