accelforge-0.0.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (258)
  1. accelforge/__init__.py +21 -0
  2. accelforge/_accelerated_imports.py +16 -0
  3. accelforge/_deprecate/_simanneal/evalmapping.py +271 -0
  4. accelforge/_deprecate/_simanneal/mapspaceglobals.py +298 -0
  5. accelforge/_deprecate/_simanneal/simanneal.py +666 -0
  6. accelforge/_deprecate/_simanneal/tracking.py +105 -0
  7. accelforge/_deprecate/_simanneal/wrappers.py +218 -0
  8. accelforge/_deprecate/_simanneal2/__init__.py +7 -0
  9. accelforge/_deprecate/_simanneal2/simanneal.py +493 -0
  10. accelforge/_deprecate/_simanneal2/tracking.py +116 -0
  11. accelforge/_deprecate/compatibility_util.py +181 -0
  12. accelforge/_deprecate/layerdeduplication/__init__.py +2 -0
  13. accelforge/_deprecate/layerdeduplication/group_similar_einsums.py +160 -0
  14. accelforge/_deprecate/layerdeduplication/grouped_einsums.py +84 -0
  15. accelforge/_deprecate/mapping_filter_tags/__init__.py +2 -0
  16. accelforge/_deprecate/mapping_filter_tags/ffmt.py +212 -0
  17. accelforge/_deprecate/mapping_filter_tags/onesplit.py +24 -0
  18. accelforge/_deprecate/mapping_filter_tags/util.py +24 -0
  19. accelforge/_deprecate/tags.py +69 -0
  20. accelforge/_deprecate/viz/__init__.py +0 -0
  21. accelforge/_deprecate/viz/interactive.py +159 -0
  22. accelforge/_deprecate/viz/reservationtree.py +307 -0
  23. accelforge/_deprecate/viz/ski_slope.py +88 -0
  24. accelforge/_version.py +15 -0
  25. accelforge/examples.py +39 -0
  26. accelforge/frontend/__init__.py +10 -0
  27. accelforge/frontend/_binding.py +129 -0
  28. accelforge/frontend/_workload_isl/__init__.py +2 -0
  29. accelforge/frontend/_workload_isl/_isl.py +149 -0
  30. accelforge/frontend/_workload_isl/_symbolic.py +141 -0
  31. accelforge/frontend/arch copy.py +1544 -0
  32. accelforge/frontend/arch.py +1642 -0
  33. accelforge/frontend/config.py +63 -0
  34. accelforge/frontend/mapper/__init__.py +5 -0
  35. accelforge/frontend/mapper/ffm.py +126 -0
  36. accelforge/frontend/mapper/mapper.py +7 -0
  37. accelforge/frontend/mapper/metrics.py +30 -0
  38. accelforge/frontend/mapping/__init__.py +1 -0
  39. accelforge/frontend/mapping/mapping.py +1736 -0
  40. accelforge/frontend/model.py +14 -0
  41. accelforge/frontend/renames.py +150 -0
  42. accelforge/frontend/spec copy.py +230 -0
  43. accelforge/frontend/spec.py +301 -0
  44. accelforge/frontend/variables.py +12 -0
  45. accelforge/frontend/workload.py +952 -0
  46. accelforge/mapper/FFM/__init__.py +9 -0
  47. accelforge/mapper/FFM/_join_pmappings/__init__.py +0 -0
  48. accelforge/mapper/FFM/_join_pmappings/compatibility.py +653 -0
  49. accelforge/mapper/FFM/_join_pmappings/compress_pmappings.py +140 -0
  50. accelforge/mapper/FFM/_join_pmappings/join_pmappings.py +703 -0
  51. accelforge/mapper/FFM/_join_pmappings/pmapping_dataframe.py +901 -0
  52. accelforge/mapper/FFM/_join_pmappings/pmapping_group.py +337 -0
  53. accelforge/mapper/FFM/_make_pmappings/contraints/__init__.py +0 -0
  54. accelforge/mapper/FFM/_make_pmappings/contraints/constraints.py +360 -0
  55. accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/__init__.py +1 -0
  56. accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_loops.py +373 -0
  57. accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_pmapping_templates.py +463 -0
  58. accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_reservations.py +95 -0
  59. accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_storage_order.py +382 -0
  60. accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_storages.py +155 -0
  61. accelforge/mapper/FFM/_make_pmappings/make_pmappings.py +411 -0
  62. accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/__init__.py +1 -0
  63. accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/make_pmappings_from_templates.py +407 -0
  64. accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/make_tile_shapes.py +1681 -0
  65. accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/run_model.py +170 -0
  66. accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/symbol_relations.py +174 -0
  67. accelforge/mapper/FFM/_make_pmappings/pmapper_job.py +282 -0
  68. accelforge/mapper/FFM/_pareto_df/df_convention.py +273 -0
  69. accelforge/mapper/FFM/_pareto_df/pareto copy.py +836 -0
  70. accelforge/mapper/FFM/_pareto_df/pareto.py +508 -0
  71. accelforge/mapper/FFM/data.py +61 -0
  72. accelforge/mapper/FFM/main copy.py +236 -0
  73. accelforge/mapper/FFM/main.py +208 -0
  74. accelforge/mapper/FFM/mappings.py +510 -0
  75. accelforge/mapper/FFM/pmappings.py +310 -0
  76. accelforge/mapper/__init__.py +4 -0
  77. accelforge/mapper.py +0 -0
  78. accelforge/model/__init__.py +1 -0
  79. accelforge/model/_looptree/__init__.py +0 -0
  80. accelforge/model/_looptree/accesses.py +335 -0
  81. accelforge/model/_looptree/capacity/__init__.py +1 -0
  82. accelforge/model/_looptree/capacity/aggregators.py +36 -0
  83. accelforge/model/_looptree/capacity/capacity.py +47 -0
  84. accelforge/model/_looptree/energy.py +150 -0
  85. accelforge/model/_looptree/equivalent_ranks.py +29 -0
  86. accelforge/model/_looptree/latency/__init__.py +1 -0
  87. accelforge/model/_looptree/latency/latency.py +98 -0
  88. accelforge/model/_looptree/latency/memory.py +120 -0
  89. accelforge/model/_looptree/latency/processors.py +92 -0
  90. accelforge/model/_looptree/mapping_utilities.py +71 -0
  91. accelforge/model/_looptree/reuse/__init__.py +4 -0
  92. accelforge/model/_looptree/reuse/isl/__init__.py +1 -0
  93. accelforge/model/_looptree/reuse/isl/des.py +59 -0
  94. accelforge/model/_looptree/reuse/isl/isl_functions.py +374 -0
  95. accelforge/model/_looptree/reuse/isl/mapping_to_isl/__init__.py +4 -0
  96. accelforge/model/_looptree/reuse/isl/mapping_to_isl/analyze_mapping.py +297 -0
  97. accelforge/model/_looptree/reuse/isl/mapping_to_isl/skews_from_mapping.py +236 -0
  98. accelforge/model/_looptree/reuse/isl/mapping_to_isl/tiling.py +685 -0
  99. accelforge/model/_looptree/reuse/isl/mapping_to_isl/types.py +188 -0
  100. accelforge/model/_looptree/reuse/isl/spatial.py +260 -0
  101. accelforge/model/_looptree/reuse/isl/temporal.py +182 -0
  102. accelforge/model/_looptree/reuse/symbolic/__init__.py +1 -0
  103. accelforge/model/_looptree/reuse/symbolic/symbolic copy 2.py +1346 -0
  104. accelforge/model/_looptree/reuse/symbolic/symbolic copy.py +1408 -0
  105. accelforge/model/_looptree/reuse/symbolic/symbolic.py +1396 -0
  106. accelforge/model/_looptree/run.py +122 -0
  107. accelforge/model/_looptree/types.py +26 -0
  108. accelforge/model/_looptree/visualization/__init__.py +0 -0
  109. accelforge/model/_looptree/visualization/occupancy.py +11 -0
  110. accelforge/model/main.py +222 -0
  111. accelforge/plotting/__init__.py +2 -0
  112. accelforge/plotting/mappings.py +219 -0
  113. accelforge/plotting/specs.py +57 -0
  114. accelforge/util/__init__.py +4 -0
  115. accelforge/util/_base_analysis_types.py +24 -0
  116. accelforge/util/_basetypes.py +1089 -0
  117. accelforge/util/_frozenset.py +36 -0
  118. accelforge/util/_isl.py +29 -0
  119. accelforge/util/_itertools.py +14 -0
  120. accelforge/util/_mathfuncs.py +57 -0
  121. accelforge/util/_parse_expressions.py +339 -0
  122. accelforge/util/_picklecache.py +32 -0
  123. accelforge/util/_setexpressions.py +268 -0
  124. accelforge/util/_sympy/__init__.py +0 -0
  125. accelforge/util/_sympy/broadcast_max.py +18 -0
  126. accelforge/util/_visualization.py +112 -0
  127. accelforge/util/_yaml.py +579 -0
  128. accelforge/util/parallel.py +193 -0
  129. accelforge-0.0.1.dist-info/METADATA +64 -0
  130. accelforge-0.0.1.dist-info/RECORD +258 -0
  131. accelforge-0.0.1.dist-info/WHEEL +5 -0
  132. accelforge-0.0.1.dist-info/licenses/LICENSE +19 -0
  133. accelforge-0.0.1.dist-info/top_level.txt +5 -0
  134. docs/_build/html/_sources/fastfusion.frontend.mapper.rst.txt +37 -0
  135. docs/_build/html/_sources/fastfusion.frontend.rst.txt +70 -0
  136. docs/_build/html/_sources/fastfusion.frontend.workload.rst.txt +21 -0
  137. docs/_build/html/_sources/fastfusion.mapper.FFM.rst.txt +37 -0
  138. docs/_build/html/_sources/fastfusion.mapper.rst.txt +18 -0
  139. docs/_build/html/_sources/fastfusion.rst.txt +20 -0
  140. docs/_build/html/_sources/fastfusion.util.rst.txt +21 -0
  141. docs/_build/html/_sources/index.rst.txt +87 -0
  142. docs/_build/html/_sources/modules.rst.txt +7 -0
  143. docs/_build/html/_sources/notes/citation.rst.txt +45 -0
  144. docs/_build/html/_sources/notes/definitions.rst.txt +43 -0
  145. docs/_build/html/_sources/notes/faqs.rst.txt +39 -0
  146. docs/_build/html/_sources/notes/modeling/accelerator_energy_latency.rst.txt +72 -0
  147. docs/_build/html/_sources/notes/modeling/component_energy_area.rst.txt +96 -0
  148. docs/_build/html/_sources/notes/modeling/mapping.rst.txt +100 -0
  149. docs/_build/html/_sources/notes/modeling.rst.txt +33 -0
  150. docs/_build/html/_sources/notes/parsing/arithmetic_parsing.rst.txt +136 -0
  151. docs/_build/html/_sources/notes/parsing/setexpressions.rst.txt +63 -0
  152. docs/_build/html/_sources/notes/parsing/yaml_parsing.rst.txt +176 -0
  153. docs/_build/html/_sources/notes/quickstart_and_installation.rst.txt +9 -0
  154. docs/_build/html/_sources/notes/spec/architecture.rst.txt +133 -0
  155. docs/_build/html/_sources/notes/spec/mapping.rst.txt +12 -0
  156. docs/_build/html/_sources/notes/spec/workload.rst.txt +83 -0
  157. docs/_build/html/_sources/notes/spec.rst.txt +36 -0
  158. docs/source/_ext/include_attrs.py +213 -0
  159. docs/source/_ext/include_docstring.py +364 -0
  160. docs/source/_ext/include_functions.py +154 -0
  161. docs/source/_ext/include_notebook.py +131 -0
  162. docs/source/_ext/include_yaml.py +119 -0
  163. docs/source/_ext/inherited_attributes.py +222 -0
  164. docs/source/_ext/paths.py +4 -0
  165. docs/source/conf.py +79 -0
  166. examples/arches/compute_in_memory/_include.yaml +74 -0
  167. examples/arches/compute_in_memory/_include_functions.py +229 -0
  168. examples/arches/compute_in_memory/_load_spec.py +57 -0
  169. examples/arches/compute_in_memory/components/c2c_multiplier.py +181 -0
  170. examples/arches/compute_in_memory/components/dac_c2c_r2r.py +605 -0
  171. examples/arches/compute_in_memory/components/misc.py +195 -0
  172. examples/arches/compute_in_memory/components/util/bit_functions.py +51 -0
  173. examples/arches/compute_in_memory/components/zero_comparator.py +92 -0
  174. examples/arches/compute_in_memory/isaac.yaml +233 -0
  175. examples/arches/compute_in_memory/memory_cells/ecram_demo.yaml +63 -0
  176. examples/arches/compute_in_memory/memory_cells/rram_example.yaml +63 -0
  177. examples/arches/compute_in_memory/memory_cells/rram_isaac_isca_2016.yaml +64 -0
  178. examples/arches/compute_in_memory/memory_cells/rram_neurosim_default.yaml +63 -0
  179. examples/arches/compute_in_memory/memory_cells/rram_raella_isca_2023.yaml +70 -0
  180. examples/arches/compute_in_memory/memory_cells/rram_wan_nature_2022.yaml +63 -0
  181. examples/arches/compute_in_memory/memory_cells/sram_colonnade_jssc_2021.yaml +63 -0
  182. examples/arches/compute_in_memory/memory_cells/sram_example.yaml +63 -0
  183. examples/arches/compute_in_memory/memory_cells/sram_jia_jssc_2020.yaml +63 -0
  184. examples/arches/compute_in_memory/memory_cells/sram_sinangil_jssc_2021.yaml +63 -0
  185. examples/arches/compute_in_memory/memory_cells/sram_wang_vlsi_2022.yaml +63 -0
  186. examples/arches/compute_in_memory/wang_vlsi_2022.yaml +289 -0
  187. examples/arches/eyeriss.yaml +68 -0
  188. examples/arches/fanout_variations/at_glb.yaml +31 -0
  189. examples/arches/fanout_variations/at_glb_with_fanout_node.yaml +34 -0
  190. examples/arches/fanout_variations/at_mac.yaml +31 -0
  191. examples/arches/fanout_variations/at_mac_with_constraints.yaml +38 -0
  192. examples/arches/fanout_variations/at_mac_with_fanout_node.yaml +34 -0
  193. examples/arches/nvdla.yaml +47 -0
  194. examples/arches/simple.yaml +28 -0
  195. examples/arches/tpu_v4i.yaml +67 -0
  196. examples/mappings/unfused_matmuls_to_simple.yaml +33 -0
  197. examples/misc/component_annotated.yaml +33 -0
  198. examples/workloads/gpt3_6.7B.yaml +124 -0
  199. examples/workloads/matmuls.yaml +20 -0
  200. examples/workloads/mobilenet_28.yaml +81 -0
  201. examples/workloads/mobilenet_various_separate.yaml +106 -0
  202. examples/workloads/three_matmuls_annotated.yaml +59 -0
  203. notebooks/.ipynb_checkpoints/fastfusion_arch_study_michael-checkpoint.ipynb +359 -0
  204. notebooks/compute_in_memory/_scripts.py +339 -0
  205. notebooks/compute_in_memory/isaac.guide.ipynb +270 -0
  206. notebooks/compute_in_memory/wang_vlsi_2022.ipynb +602 -0
  207. notebooks/paths.py +4 -0
  208. notebooks/tutorials/.ipynb_checkpoints/1_FFM-checkpoint.ipynb +3110 -0
  209. notebooks/tutorials/FFM.ipynb +3498 -0
  210. notebooks/tutorials/_include.py +48 -0
  211. notebooks/tutorials/component_energy_area.ipynb +363 -0
  212. tests/Q_mapping.yaml +38 -0
  213. tests/__init__.py +0 -0
  214. tests/conv.mapping.yaml +27 -0
  215. tests/conv.workload.yaml +13 -0
  216. tests/conv_sym.mapping.yaml +43 -0
  217. tests/copy.mapping.yaml +35 -0
  218. tests/copy.workload.yaml +15 -0
  219. tests/distribuffers/__init__.py +0 -0
  220. tests/distribuffers/multicast/test_cases.yaml +482 -0
  221. tests/distribuffers/spec/binding/valid_bindings.yaml +97 -0
  222. tests/distribuffers/spec/distributed.yaml +100 -0
  223. tests/distribuffers/spec/logical_arch.yaml +32 -0
  224. tests/distribuffers/spec/physical_arch.yaml +69 -0
  225. tests/distribuffers/test_binding.py +48 -0
  226. tests/frontend/__init__.py +0 -0
  227. tests/frontend/test_mapping_viz.py +52 -0
  228. tests/mapper/__init__.py +0 -0
  229. tests/mapper/configs/conv1d/conv1d.mapping.yaml +31 -0
  230. tests/mapper/configs/conv1d/conv1d.workload.yaml +11 -0
  231. tests/mapper/configs/two_conv1d/two_conv1d.expected.yaml +38 -0
  232. tests/mapper/configs/two_conv1d/two_conv1d.mapping.yaml +54 -0
  233. tests/mapper/configs/two_conv1d/two_conv1d.workload.yaml +19 -0
  234. tests/mapper/test_mapping_to_isl.py +90 -0
  235. tests/mapper/test_spatial_reuse_analysis.py +67 -0
  236. tests/mapper/test_temporal_reuse_analysis.py +56 -0
  237. tests/mapper/util.py +58 -0
  238. tests/matmul.mapping.yaml +29 -0
  239. tests/matmul.workload.yaml +12 -0
  240. tests/matmul_spatial.mapping.yaml +44 -0
  241. tests/mha.renames.yaml +65 -0
  242. tests/mha.workload.yaml +67 -0
  243. tests/mha.yaml +59 -0
  244. tests/mha_full.workload.yaml +67 -0
  245. tests/mobilenet.workload.yaml +35 -0
  246. tests/mobilenet_long.workload.yaml +64 -0
  247. tests/pmappingcache.py +24 -0
  248. tests/processing_stage.arch.yaml +40 -0
  249. tests/snowcat.arch.yaml +36 -0
  250. tests/test_ffm_join_pmappings.py +106 -0
  251. tests/test_ffm_make_pmappings.py +82 -0
  252. tests/test_ffm_make_tile_shapes.py +49 -0
  253. tests/test_mapper.py +100 -0
  254. tests/test_model.py +37 -0
  255. tests/test_plotting.py +72 -0
  256. tests/test_processing_stage.py +46 -0
  257. tests/test_symbolic_model.py +248 -0
  258. tests/test_workload.py +141 -0
accelforge/_deprecate/compatibility_util.py
@@ -0,0 +1,181 @@
+ import itertools
+
+ from collections.abc import Iterable, Set
+
+ from fastfusion.frontend.spec import Spec
+ from fastfusion.frontend.workload import EinsumName, TensorName
+ from fastfusion.mapper.FFM._join_pmappings.compatibility import Compatibility
+ from fastfusion.mapper.FFM._join_pmappings.sim import SIM
+
+
+ DO_PRINT = False
+
+
+ def myprint(*args, **kwargs):
+     if DO_PRINT:
+         print(*args, **kwargs)
+
+
+ def sims2untiled_compats(
+     einsum2sims: dict[EinsumName, Iterable[SIM]],
+ ) -> dict[EinsumName, set[Compatibility]]:
+     return {
+         einsum_name: {sim.compatibility.clear_loop_bounds() for sim in sims}
+         for einsum_name, sims in einsum2sims.items()
+     }
+
+
+ def join_compatibilities(
+     einsum2compatibilities: dict[EinsumName, Iterable[Compatibility]],
+     spec: Spec = None,
+ ) -> dict[EinsumName, set[Compatibility]]:
+     """
+     Return dict from Einsum name to compatibilities (without tile shape)
+     that will ever contribute to full mappings.
+
+     CONTRACT FOR MAPPINGS GETTING TO THIS POINT: see `join_pmappings.join_sims`
+     """
+     for einsum_name, compats in einsum2compatibilities.items():
+         if sum(len(c) for c in compats) == 0:
+             raise ValueError(f"No pmappings for {einsum_name}")
+
+     if len(einsum2compatibilities) == 0:
+         raise ValueError("Nothing to join")
+
+     for einsum_name, per_einsum_compats in einsum2compatibilities.items():
+         if not per_einsum_compats:
+             raise ValueError(f"No compatibility for {einsum_name}")
+
+     compat2einsum2original: dict[
+         Compatibility, dict[EinsumName, set[Compatibility]]
+     ] = {}
+     for einsum_name, per_einsum_compats in einsum2compatibilities.items():
+         for compat in per_einsum_compats:
+             einsum2original = compat2einsum2original.setdefault(compat, {})
+             original = einsum2original.setdefault(einsum_name, set())
+             original.add(compat)
+
+     compatibilities = list(einsum2compatibilities.items())
+
+     einsum2tensor_names = {
+         einsum_name: spec.workload.einsums[einsum_name].tensor_names
+         for einsum_name in einsum2compatibilities
+     }
+
+     einsum2important_compatibilities = {}
+
+     # while-loop states
+     assert len(compatibilities) > 0
+     left_einsum, all_left_compats = compatibilities.pop(0)
+     left_tensors = einsum2tensor_names[left_einsum]
+
+     while compatibilities:
+         right_einsum, all_right_compats = compatibilities.pop(0)
+
+         right_tensors = einsum2tensor_names[right_einsum]
+         live_tensors = set.union(
+             set(), *(einsum2tensor_names[e] for e, _ in compatibilities)
+         )
+
+         grouped_left_compats = group_left(all_left_compats, right_tensors)
+         grouped_right_compats = group_right(all_right_compats, left_tensors)
+
+         combined = combine_left_and_right_compats(
+             compat2einsum2original,
+             grouped_left_compats,
+             grouped_right_compats,
+             live_tensors,
+         )
+
+         if DO_PRINT:
+             print_reverse_unmatched(grouped_left_compats, grouped_right_compats)
+
+         if not combined:
+             raise ValueError("No match found for any group")
+
+         # update while-loop states
+         all_left_compats = combined
+         left_einsum = right_einsum
+         left_tensors |= right_tensors
+
+     einsum2important_compatibilities: dict[EinsumName, set[Compatibility]] = {}
+     for compat in combined:
+         for einsum, original in compat2einsum2original[compat].items():
+             important_compats = einsum2important_compatibilities.setdefault(
+                 einsum, set()
+             )
+             important_compats.update(original)
+     return einsum2important_compatibilities
+
+
+ def combine_left_and_right_compats(
+     compat2einsum2original: dict[Compatibility, dict[EinsumName, set[Compatibility]]],
+     grouped_left_compats: dict[Compatibility, Iterable[Compatibility]],
+     grouped_right_compats: dict[Compatibility, Iterable[Compatibility]],
+     live_tensors: set[TensorName],
+ ):
+     combined: list[Compatibility] = []
+     for left_key, left_compats in grouped_left_compats.items():
+         myprint(f"Left key {left_key}")
+
+         compatible_right_compats = grouped_right_compats.get(left_key, [])
+
+         if len(compatible_right_compats) == 0:
+             if DO_PRINT:
+                 for l in left_compats:
+                     print(f"\tNo match for {l}")
+             continue
+
+         for l, r in itertools.product(left_compats, compatible_right_compats):
+             if l.tags.are_compatible_with(r.tags):
+                 merged = l.merge_next(r, live_tensors)
+                 combined.append(merged)
+
+                 einsum2original = compat2einsum2original.setdefault(merged, {})
+
+                 left_einsum2original = compat2einsum2original[l]
+                 right_einsum2original = compat2einsum2original[r]
+
+                 einsums = set(left_einsum2original) | set(right_einsum2original)
+                 for einsum in einsums:
+                     einsum2original.setdefault(einsum, set()).update(
+                         left_einsum2original.get(einsum, set())
+                         | right_einsum2original.get(einsum, set())
+                     )
+
+                 myprint(f"\t{l}\n\t<-->\n\t{r}")
+                 myprint(f"\t-->\n\t{merged}")
+     return combined
+
+
+ def print_reverse_unmatched(
+     grouped_left_compats: dict[Compatibility, Iterable[Compatibility]],
+     grouped_right_compats: dict[Compatibility, Iterable[Compatibility]],
+ ):
+     for right_key, right_compats in grouped_right_compats.items():
+         if right_key not in grouped_left_compats:
+             for r in right_compats:
+                 print(f"\tREVERSE: No match for {r} using {right_key}")
+
+
+ def group_left(
+     left_compatibilities: Iterable[Compatibility],
+     right_tensors: Set[TensorName],
+ ) -> dict[Compatibility, set[Compatibility]]:
+     grouped_compats = {}
+     for compat in left_compatibilities:
+         key = compat.clear_dead_tensors(right_tensors, keep_loops=True, drop_tags=True)
+         grouped_compats.setdefault(key, set()).add(compat)
+     return grouped_compats
+
+
+ def group_right(
+     right_compatibilities: Iterable[Compatibility],
+     left_tensors: Set[TensorName],
+ ) -> dict[Compatibility, set[Compatibility]]:
+     grouped_compats = {}
+     for compat in right_compatibilities:
+         key = compat.clear_dead_tensors(left_tensors, keep_loops=True, drop_tags=True)
+         for per_loop_key in key.all_n_loops():
+             grouped_compats.setdefault(per_loop_key, set()).add(compat)
+     return grouped_compats
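The join in compatibility_util.py above works pairwise: each Einsum's compatibilities are grouped under a key that drops tensors the other Einsum never touches (`group_left` / `group_right`), and every left/right pair that lands in the same group is merged. For illustration only (not part of the package), a minimal standalone sketch of that group-then-join pattern, with plain strings standing in for `Compatibility` objects:

    import itertools

    # Hypothetical stand-ins: each "compatibility" is (shared_key, payload).
    left = [("k1", "L-a"), ("k1", "L-b"), ("k2", "L-c")]
    right = [("k1", "R-x"), ("k3", "R-y")]

    def group(compats):
        grouped = {}
        for key, payload in compats:
            grouped.setdefault(key, set()).add(payload)
        return grouped

    grouped_left, grouped_right = group(left), group(right)

    # Merge every left/right pair that shares a group key, as
    # combine_left_and_right_compats does with the real groupings.
    combined = [
        (l, r)
        for key, lefts in grouped_left.items()
        for l, r in itertools.product(lefts, grouped_right.get(key, []))
    ]
    print(combined)  # only the "k1" group survives the join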
accelforge/_deprecate/layerdeduplication/__init__.py
@@ -0,0 +1,2 @@
+ from .group_similar_einsums import group_similar_einsums
+ from .grouped_einsums import GroupOfSimilarEinsums, Id, Name
accelforge/_deprecate/layerdeduplication/group_similar_einsums.py
@@ -0,0 +1,160 @@
+ from collections import defaultdict
+ from collections.abc import Iterable
+ from itertools import permutations, product
+
+ from pytimeloop.bindings.looptree import LooptreeWorkload, LooptreeDependencyAnalyzer
+
+ from pytimeloop.looptree.mapping_utilities import get_intermediate_tensors
+ from fastfusion.util._frozenset import fzs
+
+ from .grouped_einsums import GroupOfSimilarEinsums, Id
+
+
+ def group_similar_einsums(
+     einsum_ids: Iterable[int],
+     workload: LooptreeWorkload,
+     analyzer: LooptreeDependencyAnalyzer,
+ ) -> list[GroupOfSimilarEinsums[Id]]:
+     """
+     Groups similar Einsums in `einsum_ids`.
+     """
+     grouped_einsums: list[GroupOfSimilarEinsums[Id]] = []
+     for einsum_id in einsum_ids:
+         found = False
+         for einsum_group in grouped_einsums:
+             einsum_ref_id = einsum_group.reference_einsum
+             rank_renaming, tensor_renaming = is_equivalent(
+                 einsum_ref_id, einsum_id, workload, analyzer
+             )
+             if rank_renaming is not None:
+                 einsum_group.add_similar_einsum(
+                     einsum_id, rank_renaming, tensor_renaming
+                 )
+                 found = True
+                 break
+
+         if not found:
+             grouped_einsums.append(GroupOfSimilarEinsums(einsum_id, workload))
+     return grouped_einsums
+
+
+ def is_equivalent(
+     einsum_id1: int,
+     einsum_id2: int,
+     workload: LooptreeWorkload,
+     analyzer: LooptreeDependencyAnalyzer,
+ ) -> tuple[dict[int, int], dict[int, int]]:
+     """
+     Determines whether two Einsums are equivalent in tensor shapes and
+     tensor indexing expressions.
+
+     If the two Einsums are equivalent, the rank and tensor renamings are
+     returned.
+
+     Returns:
+         If the two Einsums are equivalent, the function returns two dicts,
+         `rank_renaming` and `tensor_renaming`, representing how to rename
+         ranks (tensors) of `einsum_id1` to `einsum_id2`.
+
+         Otherwise, a tuple `(None, None)` is returned.
+     """
+     einsum1_ranks = workload.einsum_ospace_dimensions(einsum_id1)
+     einsum2_ranks = workload.einsum_ospace_dimensions(einsum_id2)
+
+     if len(einsum1_ranks) != len(einsum2_ranks):
+         return None, None
+
+     einsum1_input_tensors = workload.tensors_read_by_einsum(einsum_id1)
+     einsum1_output_tensor = workload.tensors_written_by_einsum(einsum_id1)
+     einsum2_input_tensors = workload.tensors_read_by_einsum(einsum_id2)
+     einsum2_output_tensor = workload.tensors_written_by_einsum(einsum_id2)
+
+     if einsum1_output_tensor is None:
+         einsum1_output_tensor = set()
+     if einsum2_output_tensor is None:
+         einsum2_output_tensor = set()
+
+     intermediate_tensors = get_intermediate_tensors(workload)
+
+     all_tensor_properties = []
+     all_tensors = [
+         (einsum1_input_tensors, einsum1_output_tensor),
+         (einsum2_input_tensors, einsum2_output_tensor),
+     ]
+     for input_tensors, output_tensors in all_tensors:
+         tensor_properties = defaultdict(set)
+         for tensor in input_tensors:
+             tensor_properties[tensor].add("input")
+         for tensor in output_tensors:
+             tensor_properties[tensor].add("output")
+         for tensor in tensor_properties:
+             if tensor in intermediate_tensors:
+                 tensor_properties[tensor].add("intermediate")
+         tensor_properties = {
+             tensor: fzs(properties) for tensor, properties in tensor_properties.items()
+         }
+         all_tensor_properties.append(tensor_properties)
+
+     property_to_tensors = defaultdict(lambda: (set(), set()))
+     for i, tensor_properties in enumerate(all_tensor_properties):
+         for tensor, property in tensor_properties.items():
+             tensor_sets = property_to_tensors[property]
+             tensor_sets[i].add(tensor)
+
+     # Check if we can rename tensors in einsum1 to einsum2
+     for tensor_renaming in tensor_renamings(property_to_tensors):
+         # Check if we can rename einsum1 ranks to create einsum2
+         for renamed_ranks in permutations(einsum2_ranks):
+             rank_renaming = {r1: r2 for r1, r2 in zip(einsum1_ranks, renamed_ranks)}
+             if not _shape_is_equivalent(rank_renaming, workload):
+                 continue
+
+             if not _dependency_is_equivalent(
+                 einsum_id1, einsum_id2, rank_renaming, tensor_renaming, analyzer
+             ):
+                 continue
+
+             return rank_renaming, tensor_renaming
+     return None, None
+
+
+ def tensor_renamings(property_to_tensors):
+     for tensors_of_1, tensors_of_2 in property_to_tensors.values():
+         if len(tensors_of_1) != len(tensors_of_2):
+             return
+
+     all_tensors_of_1 = [
+         t for tensors_of_1, _ in property_to_tensors.values() for t in tensors_of_1
+     ]
+     permutations_of_tensor_2_by_property = []
+     for _, tensors_of_2 in property_to_tensors.values():
+         permutations_of_tensor_2_by_property.append(permutations(tensors_of_2))
+     for permutation_of_2 in product(*permutations_of_tensor_2_by_property):
+         permutation_of_2 = tuple(t for tupl in permutation_of_2 for t in tupl)
+         renaming = dict(zip(all_tensors_of_1, permutation_of_2))
+         yield renaming
+
+
+ def _shape_is_equivalent(rank_renaming, workload):
+     for r1, r2 in rank_renaming.items():
+         r1_shape = workload.get_rank_shape(r1)
+         r2_shape = workload.get_rank_shape(r2)
+         if r1_shape != r2_shape:
+             return False
+     return True
+
+
+ def _dependency_is_equivalent(
+     einsum_id1, einsum_id2, rank_renaming, tensor_renaming, analyzer
+ ):
+     for t1, t2 in tensor_renaming.items():
+         for r1, r2 in rank_renaming.items():
+             r1_relevant_to_t1 = analyzer.einsum_dim_is_directly_relevant_to_tensor(
+                 einsum_id1, r1, t1
+             )
+             r2_relevant_to_t2 = analyzer.einsum_dim_is_directly_relevant_to_tensor(
+                 einsum_id2, r2, t2
+             )
+             if r1_relevant_to_t1 != r2_relevant_to_t2:
+                 return False
+     return True
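`is_equivalent` above decides similarity by brute force: it enumerates tensor renamings (grouped by input/output/intermediate role) and rank permutations until one preserves both rank shapes and tensor dependencies. A standalone sketch of the same idea, reduced to matching two Einsums by rank extents alone; the dicts below are made-up stand-ins for the workload queries:

    from itertools import permutations

    # Hypothetical rank extents for two Einsums (rank name -> extent).
    einsum1_ranks = {"m": 128, "k": 64, "n": 256}
    einsum2_ranks = {"a": 64, "b": 256, "c": 128}

    def find_rank_renaming(ranks1, ranks2):
        """Return a shape-preserving renaming of ranks1 onto ranks2, or None."""
        names1 = list(ranks1)
        for names2 in permutations(ranks2):
            renaming = dict(zip(names1, names2))
            if all(ranks1[r1] == ranks2[r2] for r1, r2 in renaming.items()):
                return renaming  # the real code also checks dependencies here
        return None

    print(find_rank_renaming(einsum1_ranks, einsum2_ranks))  # {'m': 'c', 'k': 'a', 'n': 'b'}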
accelforge/_deprecate/layerdeduplication/grouped_einsums.py
@@ -0,0 +1,84 @@
+ from collections.abc import Iterable
+
+ from pytimeloop.bindings.looptree import LooptreeWorkload
+
+
+ type Id = int
+ type Name = str
+
+
+ class GroupOfSimilarEinsums[IdOrName: Id | Name]:
+     def __init__(self, reference_einsum: Id, workload: LooptreeWorkload):
+         self.reference_einsum = reference_einsum
+         self.workload = workload
+         self.similar_einsums_to_renaming = {}
+         self.in_id = True
+
+     def add_similar_einsum(
+         self,
+         similar_einsum: IdOrName,
+         rank_renaming: dict[IdOrName, IdOrName],
+         tensor_renaming: dict[IdOrName, IdOrName],
+     ):
+         self.similar_einsums_to_renaming[similar_einsum] = (
+             rank_renaming,
+             tensor_renaming,
+         )
+
+     @property
+     def similar_einsums(self) -> Iterable[IdOrName]:
+         return self.similar_einsums_to_renaming.keys()
+
+     def get_renaming(
+         self, other_einsum: IdOrName
+     ) -> tuple[dict[IdOrName, IdOrName], dict[IdOrName, IdOrName]]:
+         """Returns the tuple `(rank_renaming, tensor_renaming)` for `other_einsum`."""
+         try:
+             return self.similar_einsums_to_renaming[other_einsum]
+         except Exception as e:
+             e.add_note(f"{other_einsum} not in group of similar Einsums.")
+             raise
+
+     @property
+     def similar_einsums_and_renamings(
+         self,
+     ) -> Iterable[
+         tuple[IdOrName, tuple[dict[IdOrName, IdOrName], dict[IdOrName, IdOrName]]]
+     ]:
+         """
+         Returns an iterable over tuples `(similar_einsum, renaming)`,
+         where `renaming` itself is `(rank_renaming, tensor_renaming)`.
+         """
+         return self.similar_einsums_to_renaming.items()
+
+     def convert_id_to_name(self) -> "GroupOfSimilarEinsums[Name]":
+         einsum_id_to_name = self.workload.EinsumIdToName()
+         tensor_id_to_name = self.workload.DataSpaceIdToName()
+         rank_id_to_name = self.workload.DimensionIdToName()
+
+         grouped_einsums_in_name = GroupOfSimilarEinsums(
+             einsum_id_to_name[self.reference_einsum], self.workload
+         )
+         grouped_einsums_in_name.in_id = False
+
+         similar_einsums_to_renamings = self.similar_einsums_to_renaming
+         for einsum_id, renaming in similar_einsums_to_renamings.items():
+             rank_renaming, tensor_renaming = renaming
+             rank_renaming_in_names = {
+                 rank_id_to_name[k]: rank_id_to_name[v] for k, v in rank_renaming.items()
+             }
+             tensor_renaming_in_names = {
+                 tensor_id_to_name[k]: tensor_id_to_name[v]
+                 for k, v in tensor_renaming.items()
+             }
+             grouped_einsums_in_name.add_similar_einsum(
+                 einsum_id_to_name[einsum_id],
+                 rank_renaming_in_names,
+                 tensor_renaming_in_names,
+             )
+
+         return grouped_einsums_in_name
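`GroupOfSimilarEinsums` is essentially a reference Einsum plus a dict from each similar Einsum to its `(rank_renaming, tensor_renaming)` pair; `convert_id_to_name` rebuilds the same structure with IDs replaced by names. A plain-dict sketch of that shape, using made-up IDs and names:

    # Hypothetical IDs: Einsum 0 is the reference, Einsum 1 is similar to it.
    group_in_ids = {
        "reference_einsum": 0,
        "similar_einsums_to_renaming": {1: ({10: 12, 11: 13}, {100: 102})},
    }

    # What convert_id_to_name produces: the same structure keyed by names.
    id_to_name = {0: "Q", 1: "K", 10: "d0", 11: "e0", 12: "d1", 13: "e1", 100: "I0", 102: "I1"}
    group_in_names = {
        "reference_einsum": id_to_name[0],
        "similar_einsums_to_renaming": {
            id_to_name[1]: (
                {id_to_name[r1]: id_to_name[r2] for r1, r2 in {10: 12, 11: 13}.items()},
                {id_to_name[t1]: id_to_name[t2] for t1, t2 in {100: 102}.items()},
            )
        },
    }
    print(group_in_names)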
accelforge/_deprecate/mapping_filter_tags/__init__.py
@@ -0,0 +1,2 @@
+ from .ffmt import get_ffmt_tag
+ from .onesplit import get_one_split_tag
accelforge/_deprecate/mapping_filter_tags/ffmt.py
@@ -0,0 +1,212 @@
+ from fastfusion.frontend.mapping import Loop, Temporal
+ from fastfusion.mapper.FFM.deprecate_maybe.tags import Tags
+
+ from .util import get_fused_loops_per_tensor
+
+
+ FFMT_VALID = "FFMT_VALID"
+ FFMT_WEIGHT_UNTILED = "FFMT_WEIGHT_UNTILED"
+ FFMT_WEIGHT_TILED = "FFMT_WEIGHT_TILED"
+
+
+ def get_ffmt_tag(compatibility):
+     return get_ffmt_matmul_tag(compatibility)
+     if "Matmul" in einsum_name:
+         return get_ffmt_matmul_tag(compatibility)
+     else:
+         return get_ffmt_mha_tag(compatibility)
+
+
+ def get_ffmt_matmul_tag(compatibility):
+     # FFMT is:
+     # - [input | output, weight]
+     # If there's >1 fused loop, they must be above the same number of loops
+     tensors = [s for s in compatibility.tensors if s.resource_name != "MainMemory"]
+     if len(tensors) <= 1:
+         return Tags((FFMT_VALID,))
+
+     allowed_n_loops = [
+         (0, 0),
+         (1, 1),
+         (1, 2),
+     ]
+
+     # If there's a B or H fused loop, add one to the allowed n_loops
+     for rank_var in "b", "h":
+         if any(rank_var in l.rank_variable for l in compatibility.loops):
+             allowed_n_loops = [(x + 1, y + 1) for x, y in allowed_n_loops]
+
+     if tuple(sorted(s.above_loop_index for s in tensors)) in [
+         (0, 0),
+         (1, 1),
+         (1, 2),
+     ]:
+         return Tags((FFMT_VALID,))
+     raise ValueError()
+
+
+ def get_ffmt_mha_tag(compatibility):
+     tensors = [s for s in compatibility.tensors if s.resource_name != "MainMemory"]
+     if len(compatibility.loops) == 0:
+         return Tags((FFMT_VALID,))
+
+     # Loops have to be in the order (b, h)
+     if len(compatibility.loops) == 1:
+         return Tags((FFMT_INVALID,))
+
+     if len(set(s.above_loop_index for s in tensors)) > 1:
+         raise ValueError()
+     return Tags((FFMT_VALID,))
+
+     for tensors in compatibility.tensors:
+         if tensor.resource_name == "MainMemory":
+             continue
+         unique_loops.add(tensor.above_loop_index)
+
+     if len(unique_loops) == 0:
+         return Tags()  # unfused is compatible with anything
+
+     untiled_fused = len(unique_loops) == 1 and next(iter(unique_loops)) == 0
+     if untiled_fused:
+         return Tags((FFMT_VALID,))
+
+     min_weight_idx, max_weight_idx, max_non_weight_idx = float("inf"), 0, 0
+     max_weight_idx = 0
+     for tensor, n_loops in tensor_to_n_fused_loops.items():
+         is_weight = "Filter" in tensor.name
+         if is_weight:
+             min_weight_idx = min(min_weight_idx, n_loops)
+             max_weight_idx = max(max_weight_idx, n_loops)
+         else:
+             max_non_weight_idx = max(max_non_weight_idx, n_loops)
+
+     weight_untiled = min_weight_idx == 0 and max_weight_idx == 0
+     if weight_untiled:
+         return Tags((FFMT_VALID, FFMT_WEIGHT_UNTILED))
+     elif min_weight_idx >= max_non_weight_idx:
+         return Tags((FFMT_VALID, FFMT_WEIGHT_TILED))
+     raise ValueError()
+
+
+ def get_ffmt_mha_tag(pmapping):
+     einsum_name = pmapping[-1].einsum_name
+     B, H, M, F, P, G, E, D, C, J = "bhmfpgedcj"
+     EINSUM_NAME_TO_REDUCED_RANK_OUTPUT_RANK = {
+         "Q": [D, E],
+         "K": [D, E],
+         "V": [D, F],
+         "QK": [E, P],
+         "AV": [P, F],
+         "Z": [F, G],
+         "FFA": [G, C],
+         "FFB": [C, J],
+     }
+
+     rank_var_permutation = []
+     for node in pmapping:
+         if isinstance(node, Loop):
+             if not isinstance(node, Temporal):
+                 raise RuntimeError(
+                     "get_ffmt_mha_tag should not be used for "
+                     "anything other than Snowcat"
+                 )
+             rank_var_permutation.append(node.rank_variable)
+
+     tensor_to_n_fused_loops = get_fused_loops_per_tensor(
+         pmapping, intermediate_tensors, "MainMemory"
+     )
+     unfused = all(
+         n is None
+         for t, n in tensor_to_n_fused_loops.items()
+         if t in intermediate_tensors
+     )
+     if einsum_name not in EINSUM_NAME_TO_REDUCED_RANK_OUTPUT_RANK:
+         if unfused:
+             return Tags((FFMT_VALID,))
+         raise ValueError()
+
+     reduced_rank, output_rank = EINSUM_NAME_TO_REDUCED_RANK_OUTPUT_RANK[einsum_name]
+
+     EINSUM_NAME_TO_INPUT_OUTPUT_TENSORS = {
+         "Q": ["I_I_to_Q_K_V", "Q_Q_to_QK"],
+         "K": ["I_I_to_Q_K_V", "K_K_to_QK"],
+         "V": ["I_I_to_Q_K_V", "V_V_to_AV"],
+         "QK": ["Q_Q_to_QK", "QK_QK_to_AV"],
+         "AV": ["QK_QK_to_AV", "AV_AV_to_Z"],
+         "Z": ["AV_AV_to_Z", "Z_Z_to_FFA"],
+         "FFA": ["Z_Z_to_FFA", "FFA_FFA_to_FFB"],
+         "FFB": ["FFA_FFA_to_FFB", "FFB_FFB_to_n"],
+     }
+
+     input_tensor, output_tensor = EINSUM_NAME_TO_INPUT_OUTPUT_TENSORS[einsum_name]
+     input_output_tensors = {input_tensor, output_tensor}
+
+     min_weight_idx = float("inf")
+     max_weight_idx = 0
+     max_non_weight_idx = 0
+     first, last = True, True
+     for tensor, n_loops in tensor_to_n_fused_loops.items():
+         if tensor.name == input_tensor and n_loops is not None:
+             first = False
+         if tensor.name == output_tensor and n_loops is not None:
+             last = False
+
+         is_weight = tensor.name not in input_output_tensors
+         if is_weight:
+             min_weight_idx = min(min_weight_idx, n_loops)
+             max_weight_idx = max(max_weight_idx, n_loops)
+         else:
+             max_non_weight_idx = max(max_non_weight_idx, n_loops)
+
+     # Rank variable order and the n_loops for (input, output)
+     prefix_choices = [([B, H], (2, 2))]
+
+     # Rank variable order and the n_loops for (input, output)
+     extra_rank_choices = [
+         ([M], (1, 1)),
+     ]
+     if first:
+         if output_rank is not None:
+             extra_rank_choices.append(([M, output_rank], (1, 2)))
+         if reduced_rank is not None and output_rank is not None:
+             extra_rank_choices.append(([M, output_rank, reduced_rank], (3, 2)))
+         if output_rank is None and reduced_rank is not None:
+             extra_rank_choices.append(([M, reduced_rank], (2, 1)))
+     elif last:
+         if output_rank is not None:
+             extra_rank_choices.append(([M, output_rank], (1, 2)))
+     else:
+         if reduced_rank is not None:
+             extra_rank_choices.append(([M, reduced_rank], (2, 1)))
+
+     for prefix_permutation, prefix_n_loops in prefix_choices:
+         for extra_permutation, extra_n_loops in extra_rank_choices:
+             permutation = prefix_permutation + extra_permutation
+             input_n_loops = prefix_n_loops[0] + extra_n_loops[0]
+             output_n_loops = prefix_n_loops[1] + extra_n_loops[1]
+             untiled_weight_idx = len(prefix_permutation)
+
+             permutation_matches = True
+             for rank_var, ref_rank_var in zip(rank_var_permutation, permutation):
+                 if rank_var != ref_rank_var:
+                     permutation_matches = False
+                     break
+
+             if not permutation_matches:
+                 continue
+
+             if tensor_to_n_fused_loops[input_tensor] != input_n_loops:
+                 continue
+             if tensor_to_n_fused_loops[output_tensor] != output_n_loops:
+                 continue
+
+             weight_untiled = (
+                 min_weight_idx == untiled_weight_idx
+                 and max_weight_idx == untiled_weight_idx
+             )
+             if weight_untiled:
+                 return Tags((FFMT_VALID, FFMT_WEIGHT_UNTILED))
+             elif min_weight_idx >= max_non_weight_idx:
+                 return Tags((FFMT_VALID, FFMT_WEIGHT_TILED))
+
+     raise ValueError()
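`get_ffmt_matmul_tag` above tags a compatibility as FFMT-valid only when the loop counts above its fused tensors fall into a small set of allowed pairs, shifted up by one when a `b` or `h` loop is fused (the comments describe the shift; the shipped code compares against the unshifted literal list). A minimal standalone sketch of that intended check, assuming the shifted `allowed_n_loops` is what should be consulted:

    def ffmt_matmul_ok(above_loop_indices, has_b_or_h_fused_loop):
        """above_loop_indices: loop count above each fused (non-MainMemory) tensor."""
        allowed_n_loops = [(0, 0), (1, 1), (1, 2)]
        if has_b_or_h_fused_loop:
            allowed_n_loops = [(x + 1, y + 1) for x, y in allowed_n_loops]
        return tuple(sorted(above_loop_indices)) in allowed_n_loops

    print(ffmt_matmul_ok((2, 1), has_b_or_h_fused_loop=False))  # True: sorts to (1, 2)
    print(ffmt_matmul_ok((0, 2), has_b_or_h_fused_loop=False))  # False
    print(ffmt_matmul_ok((2, 2), has_b_or_h_fused_loop=True))   # True: (1, 1) shifted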
accelforge/_deprecate/mapping_filter_tags/onesplit.py
@@ -0,0 +1,24 @@
+ from fastfusion.mapper.FFM.deprecate_maybe.tags import Tags
+ from fastfusion.mapper.FFM._join_pmappings.compatibility import Compatibility
+
+
+ ONE_SPLIT = "ONE_SPLIT"
+ NOT_ONE_SPLIT = "NOT_ONE_SPLIT"
+
+
+ def get_one_split_tag(compatibility: Compatibility) -> Tags:
+     # TODO
+     unique_loops = set()
+     for tensor in compatibility.tensors:
+         if tensor.resource_name == "MainMemory":
+             continue
+         unique_loops.add(tensor.above_loop_index)
+
+     if len(unique_loops) == 0:
+         return Tags()  # unfused is compatible with anything
+
+     # Fused with both sides. Make sure that the number of loops is the same.
+     if len(unique_loops) > 1:
+         return Tags(("INVALID",))
+
+     return Tags((ONE_SPLIT, f"FUSED_LOOPS={next(iter(unique_loops))}"))
accelforge/_deprecate/mapping_filter_tags/util.py
@@ -0,0 +1,24 @@
+ from fastfusion.frontend.mapping import Reservation, Loop, Mapping
+
+
+ def get_fused_loops_per_tensor(
+     pmapping: Mapping, intermediate_tensors, non_fused_memory
+ ):
+     """
+     Returns a dictionary mapping tensor to number of fused loops or None
+     if unfused (backed in non_fused_memory).
+     """
+     tensor_to_n_fused_loops = {}
+     n_loops = 0
+     for node in pmapping.nodes:
+         if isinstance(node, Reservation):
+             tensor = node.tensor
+             if tensor not in intermediate_tensors or tensor in tensor_to_n_fused_loops:
+                 continue
+             if node.component == non_fused_memory:
+                 tensor_to_n_fused_loops[tensor] = None
+             else:
+                 tensor_to_n_fused_loops[tensor] = n_loops
+         elif isinstance(node, Loop):
+             n_loops += 1
+     return tensor_to_n_fused_loops
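`get_fused_loops_per_tensor` walks the pmapping top to bottom, counting loops seen so far and recording, for each intermediate tensor, the loop count at its first reservation (or None when the reservation sits in the non-fused memory). A self-contained usage sketch with lightweight stand-in node classes (these are illustrative, not the package's real `Mapping` API):

    from dataclasses import dataclass

    @dataclass
    class Loop:
        rank: str

    @dataclass
    class Reservation:
        tensor: str
        component: str

    def fused_loops_per_tensor(nodes, intermediate_tensors, non_fused_memory="MainMemory"):
        result, n_loops = {}, 0
        for node in nodes:
            if isinstance(node, Reservation):
                if node.tensor not in intermediate_tensors or node.tensor in result:
                    continue
                result[node.tensor] = None if node.component == non_fused_memory else n_loops
            elif isinstance(node, Loop):
                n_loops += 1
        return result

    # Hypothetical pmapping: one tensor kept in main memory, one fused under two loops.
    nodes = [
        Reservation("W", "MainMemory"),
        Loop("b"),
        Loop("m"),
        Reservation("T0", "GlobalBuffer"),
    ]
    print(fused_loops_per_tensor(nodes, {"W", "T0"}))  # {'W': None, 'T0': 2}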