accelforge 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (258) hide show
  1. accelforge/__init__.py +21 -0
  2. accelforge/_accelerated_imports.py +16 -0
  3. accelforge/_deprecate/_simanneal/evalmapping.py +271 -0
  4. accelforge/_deprecate/_simanneal/mapspaceglobals.py +298 -0
  5. accelforge/_deprecate/_simanneal/simanneal.py +666 -0
  6. accelforge/_deprecate/_simanneal/tracking.py +105 -0
  7. accelforge/_deprecate/_simanneal/wrappers.py +218 -0
  8. accelforge/_deprecate/_simanneal2/__init__.py +7 -0
  9. accelforge/_deprecate/_simanneal2/simanneal.py +493 -0
  10. accelforge/_deprecate/_simanneal2/tracking.py +116 -0
  11. accelforge/_deprecate/compatibility_util.py +181 -0
  12. accelforge/_deprecate/layerdeduplication/__init__.py +2 -0
  13. accelforge/_deprecate/layerdeduplication/group_similar_einsums.py +160 -0
  14. accelforge/_deprecate/layerdeduplication/grouped_einsums.py +84 -0
  15. accelforge/_deprecate/mapping_filter_tags/__init__.py +2 -0
  16. accelforge/_deprecate/mapping_filter_tags/ffmt.py +212 -0
  17. accelforge/_deprecate/mapping_filter_tags/onesplit.py +24 -0
  18. accelforge/_deprecate/mapping_filter_tags/util.py +24 -0
  19. accelforge/_deprecate/tags.py +69 -0
  20. accelforge/_deprecate/viz/__init__.py +0 -0
  21. accelforge/_deprecate/viz/interactive.py +159 -0
  22. accelforge/_deprecate/viz/reservationtree.py +307 -0
  23. accelforge/_deprecate/viz/ski_slope.py +88 -0
  24. accelforge/_version.py +15 -0
  25. accelforge/examples.py +39 -0
  26. accelforge/frontend/__init__.py +10 -0
  27. accelforge/frontend/_binding.py +129 -0
  28. accelforge/frontend/_workload_isl/__init__.py +2 -0
  29. accelforge/frontend/_workload_isl/_isl.py +149 -0
  30. accelforge/frontend/_workload_isl/_symbolic.py +141 -0
  31. accelforge/frontend/arch copy.py +1544 -0
  32. accelforge/frontend/arch.py +1642 -0
  33. accelforge/frontend/config.py +63 -0
  34. accelforge/frontend/mapper/__init__.py +5 -0
  35. accelforge/frontend/mapper/ffm.py +126 -0
  36. accelforge/frontend/mapper/mapper.py +7 -0
  37. accelforge/frontend/mapper/metrics.py +30 -0
  38. accelforge/frontend/mapping/__init__.py +1 -0
  39. accelforge/frontend/mapping/mapping.py +1736 -0
  40. accelforge/frontend/model.py +14 -0
  41. accelforge/frontend/renames.py +150 -0
  42. accelforge/frontend/spec copy.py +230 -0
  43. accelforge/frontend/spec.py +301 -0
  44. accelforge/frontend/variables.py +12 -0
  45. accelforge/frontend/workload.py +952 -0
  46. accelforge/mapper/FFM/__init__.py +9 -0
  47. accelforge/mapper/FFM/_join_pmappings/__init__.py +0 -0
  48. accelforge/mapper/FFM/_join_pmappings/compatibility.py +653 -0
  49. accelforge/mapper/FFM/_join_pmappings/compress_pmappings.py +140 -0
  50. accelforge/mapper/FFM/_join_pmappings/join_pmappings.py +703 -0
  51. accelforge/mapper/FFM/_join_pmappings/pmapping_dataframe.py +901 -0
  52. accelforge/mapper/FFM/_join_pmappings/pmapping_group.py +337 -0
  53. accelforge/mapper/FFM/_make_pmappings/contraints/__init__.py +0 -0
  54. accelforge/mapper/FFM/_make_pmappings/contraints/constraints.py +360 -0
  55. accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/__init__.py +1 -0
  56. accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_loops.py +373 -0
  57. accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_pmapping_templates.py +463 -0
  58. accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_reservations.py +95 -0
  59. accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_storage_order.py +382 -0
  60. accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_storages.py +155 -0
  61. accelforge/mapper/FFM/_make_pmappings/make_pmappings.py +411 -0
  62. accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/__init__.py +1 -0
  63. accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/make_pmappings_from_templates.py +407 -0
  64. accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/make_tile_shapes.py +1681 -0
  65. accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/run_model.py +170 -0
  66. accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/symbol_relations.py +174 -0
  67. accelforge/mapper/FFM/_make_pmappings/pmapper_job.py +282 -0
  68. accelforge/mapper/FFM/_pareto_df/df_convention.py +273 -0
  69. accelforge/mapper/FFM/_pareto_df/pareto copy.py +836 -0
  70. accelforge/mapper/FFM/_pareto_df/pareto.py +508 -0
  71. accelforge/mapper/FFM/data.py +61 -0
  72. accelforge/mapper/FFM/main copy.py +236 -0
  73. accelforge/mapper/FFM/main.py +208 -0
  74. accelforge/mapper/FFM/mappings.py +510 -0
  75. accelforge/mapper/FFM/pmappings.py +310 -0
  76. accelforge/mapper/__init__.py +4 -0
  77. accelforge/mapper.py +0 -0
  78. accelforge/model/__init__.py +1 -0
  79. accelforge/model/_looptree/__init__.py +0 -0
  80. accelforge/model/_looptree/accesses.py +335 -0
  81. accelforge/model/_looptree/capacity/__init__.py +1 -0
  82. accelforge/model/_looptree/capacity/aggregators.py +36 -0
  83. accelforge/model/_looptree/capacity/capacity.py +47 -0
  84. accelforge/model/_looptree/energy.py +150 -0
  85. accelforge/model/_looptree/equivalent_ranks.py +29 -0
  86. accelforge/model/_looptree/latency/__init__.py +1 -0
  87. accelforge/model/_looptree/latency/latency.py +98 -0
  88. accelforge/model/_looptree/latency/memory.py +120 -0
  89. accelforge/model/_looptree/latency/processors.py +92 -0
  90. accelforge/model/_looptree/mapping_utilities.py +71 -0
  91. accelforge/model/_looptree/reuse/__init__.py +4 -0
  92. accelforge/model/_looptree/reuse/isl/__init__.py +1 -0
  93. accelforge/model/_looptree/reuse/isl/des.py +59 -0
  94. accelforge/model/_looptree/reuse/isl/isl_functions.py +374 -0
  95. accelforge/model/_looptree/reuse/isl/mapping_to_isl/__init__.py +4 -0
  96. accelforge/model/_looptree/reuse/isl/mapping_to_isl/analyze_mapping.py +297 -0
  97. accelforge/model/_looptree/reuse/isl/mapping_to_isl/skews_from_mapping.py +236 -0
  98. accelforge/model/_looptree/reuse/isl/mapping_to_isl/tiling.py +685 -0
  99. accelforge/model/_looptree/reuse/isl/mapping_to_isl/types.py +188 -0
  100. accelforge/model/_looptree/reuse/isl/spatial.py +260 -0
  101. accelforge/model/_looptree/reuse/isl/temporal.py +182 -0
  102. accelforge/model/_looptree/reuse/symbolic/__init__.py +1 -0
  103. accelforge/model/_looptree/reuse/symbolic/symbolic copy 2.py +1346 -0
  104. accelforge/model/_looptree/reuse/symbolic/symbolic copy.py +1408 -0
  105. accelforge/model/_looptree/reuse/symbolic/symbolic.py +1396 -0
  106. accelforge/model/_looptree/run.py +122 -0
  107. accelforge/model/_looptree/types.py +26 -0
  108. accelforge/model/_looptree/visualization/__init__.py +0 -0
  109. accelforge/model/_looptree/visualization/occupancy.py +11 -0
  110. accelforge/model/main.py +222 -0
  111. accelforge/plotting/__init__.py +2 -0
  112. accelforge/plotting/mappings.py +219 -0
  113. accelforge/plotting/specs.py +57 -0
  114. accelforge/util/__init__.py +4 -0
  115. accelforge/util/_base_analysis_types.py +24 -0
  116. accelforge/util/_basetypes.py +1089 -0
  117. accelforge/util/_frozenset.py +36 -0
  118. accelforge/util/_isl.py +29 -0
  119. accelforge/util/_itertools.py +14 -0
  120. accelforge/util/_mathfuncs.py +57 -0
  121. accelforge/util/_parse_expressions.py +339 -0
  122. accelforge/util/_picklecache.py +32 -0
  123. accelforge/util/_setexpressions.py +268 -0
  124. accelforge/util/_sympy/__init__.py +0 -0
  125. accelforge/util/_sympy/broadcast_max.py +18 -0
  126. accelforge/util/_visualization.py +112 -0
  127. accelforge/util/_yaml.py +579 -0
  128. accelforge/util/parallel.py +193 -0
  129. accelforge-0.0.1.dist-info/METADATA +64 -0
  130. accelforge-0.0.1.dist-info/RECORD +258 -0
  131. accelforge-0.0.1.dist-info/WHEEL +5 -0
  132. accelforge-0.0.1.dist-info/licenses/LICENSE +19 -0
  133. accelforge-0.0.1.dist-info/top_level.txt +5 -0
  134. docs/_build/html/_sources/fastfusion.frontend.mapper.rst.txt +37 -0
  135. docs/_build/html/_sources/fastfusion.frontend.rst.txt +70 -0
  136. docs/_build/html/_sources/fastfusion.frontend.workload.rst.txt +21 -0
  137. docs/_build/html/_sources/fastfusion.mapper.FFM.rst.txt +37 -0
  138. docs/_build/html/_sources/fastfusion.mapper.rst.txt +18 -0
  139. docs/_build/html/_sources/fastfusion.rst.txt +20 -0
  140. docs/_build/html/_sources/fastfusion.util.rst.txt +21 -0
  141. docs/_build/html/_sources/index.rst.txt +87 -0
  142. docs/_build/html/_sources/modules.rst.txt +7 -0
  143. docs/_build/html/_sources/notes/citation.rst.txt +45 -0
  144. docs/_build/html/_sources/notes/definitions.rst.txt +43 -0
  145. docs/_build/html/_sources/notes/faqs.rst.txt +39 -0
  146. docs/_build/html/_sources/notes/modeling/accelerator_energy_latency.rst.txt +72 -0
  147. docs/_build/html/_sources/notes/modeling/component_energy_area.rst.txt +96 -0
  148. docs/_build/html/_sources/notes/modeling/mapping.rst.txt +100 -0
  149. docs/_build/html/_sources/notes/modeling.rst.txt +33 -0
  150. docs/_build/html/_sources/notes/parsing/arithmetic_parsing.rst.txt +136 -0
  151. docs/_build/html/_sources/notes/parsing/setexpressions.rst.txt +63 -0
  152. docs/_build/html/_sources/notes/parsing/yaml_parsing.rst.txt +176 -0
  153. docs/_build/html/_sources/notes/quickstart_and_installation.rst.txt +9 -0
  154. docs/_build/html/_sources/notes/spec/architecture.rst.txt +133 -0
  155. docs/_build/html/_sources/notes/spec/mapping.rst.txt +12 -0
  156. docs/_build/html/_sources/notes/spec/workload.rst.txt +83 -0
  157. docs/_build/html/_sources/notes/spec.rst.txt +36 -0
  158. docs/source/_ext/include_attrs.py +213 -0
  159. docs/source/_ext/include_docstring.py +364 -0
  160. docs/source/_ext/include_functions.py +154 -0
  161. docs/source/_ext/include_notebook.py +131 -0
  162. docs/source/_ext/include_yaml.py +119 -0
  163. docs/source/_ext/inherited_attributes.py +222 -0
  164. docs/source/_ext/paths.py +4 -0
  165. docs/source/conf.py +79 -0
  166. examples/arches/compute_in_memory/_include.yaml +74 -0
  167. examples/arches/compute_in_memory/_include_functions.py +229 -0
  168. examples/arches/compute_in_memory/_load_spec.py +57 -0
  169. examples/arches/compute_in_memory/components/c2c_multiplier.py +181 -0
  170. examples/arches/compute_in_memory/components/dac_c2c_r2r.py +605 -0
  171. examples/arches/compute_in_memory/components/misc.py +195 -0
  172. examples/arches/compute_in_memory/components/util/bit_functions.py +51 -0
  173. examples/arches/compute_in_memory/components/zero_comparator.py +92 -0
  174. examples/arches/compute_in_memory/isaac.yaml +233 -0
  175. examples/arches/compute_in_memory/memory_cells/ecram_demo.yaml +63 -0
  176. examples/arches/compute_in_memory/memory_cells/rram_example.yaml +63 -0
  177. examples/arches/compute_in_memory/memory_cells/rram_isaac_isca_2016.yaml +64 -0
  178. examples/arches/compute_in_memory/memory_cells/rram_neurosim_default.yaml +63 -0
  179. examples/arches/compute_in_memory/memory_cells/rram_raella_isca_2023.yaml +70 -0
  180. examples/arches/compute_in_memory/memory_cells/rram_wan_nature_2022.yaml +63 -0
  181. examples/arches/compute_in_memory/memory_cells/sram_colonnade_jssc_2021.yaml +63 -0
  182. examples/arches/compute_in_memory/memory_cells/sram_example.yaml +63 -0
  183. examples/arches/compute_in_memory/memory_cells/sram_jia_jssc_2020.yaml +63 -0
  184. examples/arches/compute_in_memory/memory_cells/sram_sinangil_jssc_2021.yaml +63 -0
  185. examples/arches/compute_in_memory/memory_cells/sram_wang_vlsi_2022.yaml +63 -0
  186. examples/arches/compute_in_memory/wang_vlsi_2022.yaml +289 -0
  187. examples/arches/eyeriss.yaml +68 -0
  188. examples/arches/fanout_variations/at_glb.yaml +31 -0
  189. examples/arches/fanout_variations/at_glb_with_fanout_node.yaml +34 -0
  190. examples/arches/fanout_variations/at_mac.yaml +31 -0
  191. examples/arches/fanout_variations/at_mac_with_constraints.yaml +38 -0
  192. examples/arches/fanout_variations/at_mac_with_fanout_node.yaml +34 -0
  193. examples/arches/nvdla.yaml +47 -0
  194. examples/arches/simple.yaml +28 -0
  195. examples/arches/tpu_v4i.yaml +67 -0
  196. examples/mappings/unfused_matmuls_to_simple.yaml +33 -0
  197. examples/misc/component_annotated.yaml +33 -0
  198. examples/workloads/gpt3_6.7B.yaml +124 -0
  199. examples/workloads/matmuls.yaml +20 -0
  200. examples/workloads/mobilenet_28.yaml +81 -0
  201. examples/workloads/mobilenet_various_separate.yaml +106 -0
  202. examples/workloads/three_matmuls_annotated.yaml +59 -0
  203. notebooks/.ipynb_checkpoints/fastfusion_arch_study_michael-checkpoint.ipynb +359 -0
  204. notebooks/compute_in_memory/_scripts.py +339 -0
  205. notebooks/compute_in_memory/isaac.guide.ipynb +270 -0
  206. notebooks/compute_in_memory/wang_vlsi_2022.ipynb +602 -0
  207. notebooks/paths.py +4 -0
  208. notebooks/tutorials/.ipynb_checkpoints/1_FFM-checkpoint.ipynb +3110 -0
  209. notebooks/tutorials/FFM.ipynb +3498 -0
  210. notebooks/tutorials/_include.py +48 -0
  211. notebooks/tutorials/component_energy_area.ipynb +363 -0
  212. tests/Q_mapping.yaml +38 -0
  213. tests/__init__.py +0 -0
  214. tests/conv.mapping.yaml +27 -0
  215. tests/conv.workload.yaml +13 -0
  216. tests/conv_sym.mapping.yaml +43 -0
  217. tests/copy.mapping.yaml +35 -0
  218. tests/copy.workload.yaml +15 -0
  219. tests/distribuffers/__init__.py +0 -0
  220. tests/distribuffers/multicast/test_cases.yaml +482 -0
  221. tests/distribuffers/spec/binding/valid_bindings.yaml +97 -0
  222. tests/distribuffers/spec/distributed.yaml +100 -0
  223. tests/distribuffers/spec/logical_arch.yaml +32 -0
  224. tests/distribuffers/spec/physical_arch.yaml +69 -0
  225. tests/distribuffers/test_binding.py +48 -0
  226. tests/frontend/__init__.py +0 -0
  227. tests/frontend/test_mapping_viz.py +52 -0
  228. tests/mapper/__init__.py +0 -0
  229. tests/mapper/configs/conv1d/conv1d.mapping.yaml +31 -0
  230. tests/mapper/configs/conv1d/conv1d.workload.yaml +11 -0
  231. tests/mapper/configs/two_conv1d/two_conv1d.expected.yaml +38 -0
  232. tests/mapper/configs/two_conv1d/two_conv1d.mapping.yaml +54 -0
  233. tests/mapper/configs/two_conv1d/two_conv1d.workload.yaml +19 -0
  234. tests/mapper/test_mapping_to_isl.py +90 -0
  235. tests/mapper/test_spatial_reuse_analysis.py +67 -0
  236. tests/mapper/test_temporal_reuse_analysis.py +56 -0
  237. tests/mapper/util.py +58 -0
  238. tests/matmul.mapping.yaml +29 -0
  239. tests/matmul.workload.yaml +12 -0
  240. tests/matmul_spatial.mapping.yaml +44 -0
  241. tests/mha.renames.yaml +65 -0
  242. tests/mha.workload.yaml +67 -0
  243. tests/mha.yaml +59 -0
  244. tests/mha_full.workload.yaml +67 -0
  245. tests/mobilenet.workload.yaml +35 -0
  246. tests/mobilenet_long.workload.yaml +64 -0
  247. tests/pmappingcache.py +24 -0
  248. tests/processing_stage.arch.yaml +40 -0
  249. tests/snowcat.arch.yaml +36 -0
  250. tests/test_ffm_join_pmappings.py +106 -0
  251. tests/test_ffm_make_pmappings.py +82 -0
  252. tests/test_ffm_make_tile_shapes.py +49 -0
  253. tests/test_mapper.py +100 -0
  254. tests/test_model.py +37 -0
  255. tests/test_plotting.py +72 -0
  256. tests/test_processing_stage.py +46 -0
  257. tests/test_symbolic_model.py +248 -0
  258. tests/test_workload.py +141 -0
@@ -0,0 +1,236 @@
1
+ """
2
+ Relevant Name Changes:
3
+ - BufferID -> ComponentName
4
+ """
5
+
6
+ from collections import defaultdict
7
+ from typing import Callable, List, Tuple
8
+ import islpy as isl
9
+
10
+ from accelforge.frontend.mapping import (
11
+ Mapping,
12
+ MappingNode,
13
+ # Iterations
14
+ Loop,
15
+ Spatial,
16
+ Temporal,
17
+ # Splits
18
+ Pipeline,
19
+ Sequential,
20
+ # Logical hardware features
21
+ Storage,
22
+ Compute,
23
+ )
24
+ from accelforge.frontend.workload import TensorName, Workload
25
+
26
+ from accelforge.model._looptree.mapping_utilities import get_paths
27
+ from accelforge.model._looptree.types import ComponentName
28
+ from accelforge.model._looptree.reuse.isl.isl_functions import (
29
+ dim_projector_mask,
30
+ insert_equal_dims_map,
31
+ )
32
+ from accelforge.model._looptree.reuse.isl.mapping_to_isl import DUMP_ISL_IR
33
+
34
+ from .types import (
35
+ # Bookkeeping objects
36
+ BufferTensorEinsum,
37
+ ComputeEinsum,
38
+ EinsumName,
39
+ Skew,
40
+ SkewsInfo,
41
+ # Tags
42
+ Tag,
43
+ TemporalTag,
44
+ SpatialTag,
45
+ PipelineTag,
46
+ SequentialTag,
47
+ )
48
+
49
+
50
+ def skews_from_mapping(mapping: Mapping, workload: Workload) -> SkewsInfo:
51
+ """
52
+ Given a mapping and workload, compute the skew relationships for buffers and
53
+ computes.
54
+ TODO: Fill this in with more accurate description.
55
+
56
+ Parameters
57
+ ----------
58
+ mapping:
59
+ The mapping being analyzed.
60
+ workload:
61
+ The workload being executed.
62
+
63
+ Returns
64
+ -------
65
+ Skew information for buffer-tensor-einsum and compute-einsum combinations.
66
+ """
67
+ compute_einsum_to_skew: dict[ComputeEinsum, Skew] = defaultdict()
68
+ buffer_tensor_einsum_to_skew: dict[BufferTensorEinsum, Skew] = defaultdict()
69
+
70
+ for path in get_paths(mapping):
71
+ leaf: Compute = path[-1]
72
+
73
+ # Get the last storage node in path for a particular buffet.
74
+ buffer_to_last_storage_node: dict[ComponentName, MappingNode] = {}
75
+ buffer_node: List[Tuple[ComponentName, MappingNode]] = []
76
+ all_buffer_tensors: List[Tuple[ComponentName, TensorName]] = []
77
+
78
+ node: MappingNode
79
+ for node in path:
80
+ match node:
81
+ case Storage():
82
+ buffer: ComponentName = node.component
83
+ buffer_to_last_storage_node[buffer] = node
84
+ buffer_node.append((buffer, node))
85
+ # TODO: Check this is correct
86
+ all_buffer_tensors.extend(
87
+ (buffer, tensor) for tensor in node.tensors
88
+ )
89
+ case Compute():
90
+ compute: ComponentName = node.component
91
+ buffer_to_last_storage_node[compute] = node
92
+ buffer_node.append((compute, node))
93
+
94
+ node_to_current_buffer: dict[MappingNode, MappingNode] = {}
95
+ buffer_idx: int = 0
96
+ for node in path:
97
+ _, cur_buf_last_node = buffer_node[buffer_idx]
98
+ node_to_current_buffer[node] = cur_buf_last_node
99
+
100
+ if node == cur_buf_last_node:
101
+ buffer_idx += 1
102
+
103
+ # Generate tags, map, and which dims (and tags) should be removed per buffer.
104
+ tags: List[Tag] = []
105
+ base_space: str = f"{leaf.component}_spacetime"
106
+ removal_map: isl.Map = (
107
+ isl.Map.from_multi_aff(
108
+ isl.MultiAff.identity_on_domain_space(
109
+ isl.Space.alloc(isl.DEFAULT_CONTEXT, 0, 0, 0).domain()
110
+ )
111
+ )
112
+ .set_tuple_name(isl.dim_type.in_, base_space)
113
+ .set_tuple_name(isl.dim_type.out, base_space)
114
+ )
115
+
116
+ buffer_storage_past: set[Tuple[ComponentName, TensorName]] = set()
117
+ buffer_fully_complete: set[ComponentName] = set()
118
+ buffer_to_dim_removal_mask: defaultdict[
119
+ Tuple[ComponentName, TensorName], List[bool]
120
+ ] = defaultdict(list)
121
+
122
+ def add_tag(
123
+ tag: Tag,
124
+ mask_condition: Callable[[ComponentName, TensorName], bool] = (
125
+ lambda b, t: b in buffer_fully_complete
126
+ ),
127
+ ) -> None:
128
+ """
129
+ Performs necessary modifications to removal_map and removal_mask to
130
+ accommodate tagging.
131
+
132
+ Parameters
133
+ ----------
134
+ tag:
135
+ The tag to add.
136
+ mask_condition:
137
+ Boolean resolution for the removal mask.
138
+
139
+ Postconditions
140
+ --------------
141
+ - `tags` has another tag appended to it.
142
+ - `removal_map` has an input and output dimension added that are equal
143
+ to each other.
144
+ - `removal_mask` has a new entry.
145
+ """
146
+ nonlocal tags
147
+ tags.append(tag)
148
+ nonlocal removal_map
149
+ removal_map = insert_equal_dims_map(
150
+ removal_map,
151
+ removal_map.dim(isl.dim_type.in_),
152
+ removal_map.dim(isl.dim_type.out),
153
+ 1,
154
+ )
155
+ if DUMP_ISL_IR:
156
+ print(f"skew removal_map: {removal_map}")
157
+ print(f"tag: {tag}")
158
+
159
+ nonlocal all_buffer_tensors
160
+ nonlocal buffer_to_dim_removal_mask
161
+ for buffer_tensor in all_buffer_tensors:
162
+ removal_mask = buffer_to_dim_removal_mask[buffer_tensor]
163
+ removal_mask.append(mask_condition(*buffer_tensor))
164
+
165
+ for node in path:
166
+ match node:
167
+ case Storage():
168
+ buffer_storage_past.update(
169
+ (node.component, tensor) for tensor in node.tensors
170
+ )
171
+ if node == buffer_to_last_storage_node[node.component]:
172
+ buffer_fully_complete.add(node.component)
173
+ case Loop():
174
+ tag: Tag
175
+ if isinstance(node, Temporal):
176
+ tag: Tag = TemporalTag()
177
+ elif isinstance(node, Spatial):
178
+ tag: Tag = SpatialTag(0, node_to_current_buffer[node])
179
+ else:
180
+ raise ValueError(
181
+ f"Type {type(node)} is an iteration not in space or time."
182
+ )
183
+
184
+ # TODO: Verify logical equivalence to:
185
+ # https://github.com/NVlabs/timeloop/blob/32370826fdf1aa3c8deb0c93e6b2a2fc7cf053aa/src/loop-analysis/mapping-to-isl/fused-mapping-to-isl.cpp#L660-L671
186
+ add_tag(
187
+ tag,
188
+ lambda b, t: (
189
+ (b in buffer_fully_complete)
190
+ or (
191
+ (b, t) in buffer_storage_past
192
+ and isinstance(node, Temporal)
193
+ )
194
+ ),
195
+ )
196
+ case Pipeline():
197
+ add_tag(PipelineTag())
198
+ case Sequential():
199
+ add_tag(SequentialTag())
200
+
201
+ for buffer_tensor in all_buffer_tensors:
202
+ mask: List[bool] = buffer_to_dim_removal_mask[buffer_tensor]
203
+ domain: isl.Set = removal_map.domain()
204
+ projector: isl.Map = dim_projector_mask(domain.get_space(), mask)
205
+ removal_projection: isl.Map = projector.apply_range(removal_map)
206
+ # Attach tuple names per-buffer so downstream occupancy maps keep the spacetime label.
207
+ space_name: str = f"{buffer_tensor[0]}_spacetime"
208
+ removal_projection = removal_projection.set_tuple_name(
209
+ isl.dim_type.in_, space_name
210
+ ).set_tuple_name(isl.dim_type.out, space_name)
211
+
212
+ buffer_tags: List[Tag] = [tag for i, tag in enumerate(tags) if not mask[i]]
213
+
214
+ # TODO: This buffet structure makes no sense in this context:
215
+ # https://github.com/NVlabs/timeloop/blob/32370826fdf1aa3c8deb0c93e6b2a2fc7cf053aa/src/loop-analysis/mapping-to-isl/fused-mapping-to-isl.cpp#L740-L743
216
+ buffer_tensor_einsum_to_skew[BufferTensorEinsum(*buffer_tensor, leaf)] = (
217
+ Skew(buffer_tags, removal_projection)
218
+ )
219
+
220
+ # TODO: Figure out what is actually:
221
+ # https://github.com/NVlabs/timeloop/blob/32370826fdf1aa3c8deb0c93e6b2a2fc7cf053aa/src/loop-analysis/mapping-to-isl/fused-mapping-to-isl.cpp#L746
222
+ compute_einsum_to_skew[ComputeEinsum(leaf.component, leaf)] = Skew(
223
+ tags, removal_map
224
+ )
225
+ einsum: EinsumName = leaf.einsum
226
+ for tensor in workload.einsums[einsum].input_tensor_names:
227
+ buffer_tensor_einsum_to_skew[
228
+ BufferTensorEinsum(leaf.component, tensor, leaf)
229
+ ] = Skew(tags, removal_map)
230
+
231
+ for tensor in workload.einsums[einsum].output_tensor_names:
232
+ buffer_tensor_einsum_to_skew[
233
+ BufferTensorEinsum(leaf.component, tensor, leaf)
234
+ ] = Skew(tags, removal_map)
235
+
236
+ return SkewsInfo(buffer_tensor_einsum_to_skew, compute_einsum_to_skew)