accelforge-0.0.1-py3-none-any.whl

This diff shows the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (258)
  1. accelforge/__init__.py +21 -0
  2. accelforge/_accelerated_imports.py +16 -0
  3. accelforge/_deprecate/_simanneal/evalmapping.py +271 -0
  4. accelforge/_deprecate/_simanneal/mapspaceglobals.py +298 -0
  5. accelforge/_deprecate/_simanneal/simanneal.py +666 -0
  6. accelforge/_deprecate/_simanneal/tracking.py +105 -0
  7. accelforge/_deprecate/_simanneal/wrappers.py +218 -0
  8. accelforge/_deprecate/_simanneal2/__init__.py +7 -0
  9. accelforge/_deprecate/_simanneal2/simanneal.py +493 -0
  10. accelforge/_deprecate/_simanneal2/tracking.py +116 -0
  11. accelforge/_deprecate/compatibility_util.py +181 -0
  12. accelforge/_deprecate/layerdeduplication/__init__.py +2 -0
  13. accelforge/_deprecate/layerdeduplication/group_similar_einsums.py +160 -0
  14. accelforge/_deprecate/layerdeduplication/grouped_einsums.py +84 -0
  15. accelforge/_deprecate/mapping_filter_tags/__init__.py +2 -0
  16. accelforge/_deprecate/mapping_filter_tags/ffmt.py +212 -0
  17. accelforge/_deprecate/mapping_filter_tags/onesplit.py +24 -0
  18. accelforge/_deprecate/mapping_filter_tags/util.py +24 -0
  19. accelforge/_deprecate/tags.py +69 -0
  20. accelforge/_deprecate/viz/__init__.py +0 -0
  21. accelforge/_deprecate/viz/interactive.py +159 -0
  22. accelforge/_deprecate/viz/reservationtree.py +307 -0
  23. accelforge/_deprecate/viz/ski_slope.py +88 -0
  24. accelforge/_version.py +15 -0
  25. accelforge/examples.py +39 -0
  26. accelforge/frontend/__init__.py +10 -0
  27. accelforge/frontend/_binding.py +129 -0
  28. accelforge/frontend/_workload_isl/__init__.py +2 -0
  29. accelforge/frontend/_workload_isl/_isl.py +149 -0
  30. accelforge/frontend/_workload_isl/_symbolic.py +141 -0
  31. accelforge/frontend/arch copy.py +1544 -0
  32. accelforge/frontend/arch.py +1642 -0
  33. accelforge/frontend/config.py +63 -0
  34. accelforge/frontend/mapper/__init__.py +5 -0
  35. accelforge/frontend/mapper/ffm.py +126 -0
  36. accelforge/frontend/mapper/mapper.py +7 -0
  37. accelforge/frontend/mapper/metrics.py +30 -0
  38. accelforge/frontend/mapping/__init__.py +1 -0
  39. accelforge/frontend/mapping/mapping.py +1736 -0
  40. accelforge/frontend/model.py +14 -0
  41. accelforge/frontend/renames.py +150 -0
  42. accelforge/frontend/spec copy.py +230 -0
  43. accelforge/frontend/spec.py +301 -0
  44. accelforge/frontend/variables.py +12 -0
  45. accelforge/frontend/workload.py +952 -0
  46. accelforge/mapper/FFM/__init__.py +9 -0
  47. accelforge/mapper/FFM/_join_pmappings/__init__.py +0 -0
  48. accelforge/mapper/FFM/_join_pmappings/compatibility.py +653 -0
  49. accelforge/mapper/FFM/_join_pmappings/compress_pmappings.py +140 -0
  50. accelforge/mapper/FFM/_join_pmappings/join_pmappings.py +703 -0
  51. accelforge/mapper/FFM/_join_pmappings/pmapping_dataframe.py +901 -0
  52. accelforge/mapper/FFM/_join_pmappings/pmapping_group.py +337 -0
  53. accelforge/mapper/FFM/_make_pmappings/contraints/__init__.py +0 -0
  54. accelforge/mapper/FFM/_make_pmappings/contraints/constraints.py +360 -0
  55. accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/__init__.py +1 -0
  56. accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_loops.py +373 -0
  57. accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_pmapping_templates.py +463 -0
  58. accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_reservations.py +95 -0
  59. accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_storage_order.py +382 -0
  60. accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_storages.py +155 -0
  61. accelforge/mapper/FFM/_make_pmappings/make_pmappings.py +411 -0
  62. accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/__init__.py +1 -0
  63. accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/make_pmappings_from_templates.py +407 -0
  64. accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/make_tile_shapes.py +1681 -0
  65. accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/run_model.py +170 -0
  66. accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/symbol_relations.py +174 -0
  67. accelforge/mapper/FFM/_make_pmappings/pmapper_job.py +282 -0
  68. accelforge/mapper/FFM/_pareto_df/df_convention.py +273 -0
  69. accelforge/mapper/FFM/_pareto_df/pareto copy.py +836 -0
  70. accelforge/mapper/FFM/_pareto_df/pareto.py +508 -0
  71. accelforge/mapper/FFM/data.py +61 -0
  72. accelforge/mapper/FFM/main copy.py +236 -0
  73. accelforge/mapper/FFM/main.py +208 -0
  74. accelforge/mapper/FFM/mappings.py +510 -0
  75. accelforge/mapper/FFM/pmappings.py +310 -0
  76. accelforge/mapper/__init__.py +4 -0
  77. accelforge/mapper.py +0 -0
  78. accelforge/model/__init__.py +1 -0
  79. accelforge/model/_looptree/__init__.py +0 -0
  80. accelforge/model/_looptree/accesses.py +335 -0
  81. accelforge/model/_looptree/capacity/__init__.py +1 -0
  82. accelforge/model/_looptree/capacity/aggregators.py +36 -0
  83. accelforge/model/_looptree/capacity/capacity.py +47 -0
  84. accelforge/model/_looptree/energy.py +150 -0
  85. accelforge/model/_looptree/equivalent_ranks.py +29 -0
  86. accelforge/model/_looptree/latency/__init__.py +1 -0
  87. accelforge/model/_looptree/latency/latency.py +98 -0
  88. accelforge/model/_looptree/latency/memory.py +120 -0
  89. accelforge/model/_looptree/latency/processors.py +92 -0
  90. accelforge/model/_looptree/mapping_utilities.py +71 -0
  91. accelforge/model/_looptree/reuse/__init__.py +4 -0
  92. accelforge/model/_looptree/reuse/isl/__init__.py +1 -0
  93. accelforge/model/_looptree/reuse/isl/des.py +59 -0
  94. accelforge/model/_looptree/reuse/isl/isl_functions.py +374 -0
  95. accelforge/model/_looptree/reuse/isl/mapping_to_isl/__init__.py +4 -0
  96. accelforge/model/_looptree/reuse/isl/mapping_to_isl/analyze_mapping.py +297 -0
  97. accelforge/model/_looptree/reuse/isl/mapping_to_isl/skews_from_mapping.py +236 -0
  98. accelforge/model/_looptree/reuse/isl/mapping_to_isl/tiling.py +685 -0
  99. accelforge/model/_looptree/reuse/isl/mapping_to_isl/types.py +188 -0
  100. accelforge/model/_looptree/reuse/isl/spatial.py +260 -0
  101. accelforge/model/_looptree/reuse/isl/temporal.py +182 -0
  102. accelforge/model/_looptree/reuse/symbolic/__init__.py +1 -0
  103. accelforge/model/_looptree/reuse/symbolic/symbolic copy 2.py +1346 -0
  104. accelforge/model/_looptree/reuse/symbolic/symbolic copy.py +1408 -0
  105. accelforge/model/_looptree/reuse/symbolic/symbolic.py +1396 -0
  106. accelforge/model/_looptree/run.py +122 -0
  107. accelforge/model/_looptree/types.py +26 -0
  108. accelforge/model/_looptree/visualization/__init__.py +0 -0
  109. accelforge/model/_looptree/visualization/occupancy.py +11 -0
  110. accelforge/model/main.py +222 -0
  111. accelforge/plotting/__init__.py +2 -0
  112. accelforge/plotting/mappings.py +219 -0
  113. accelforge/plotting/specs.py +57 -0
  114. accelforge/util/__init__.py +4 -0
  115. accelforge/util/_base_analysis_types.py +24 -0
  116. accelforge/util/_basetypes.py +1089 -0
  117. accelforge/util/_frozenset.py +36 -0
  118. accelforge/util/_isl.py +29 -0
  119. accelforge/util/_itertools.py +14 -0
  120. accelforge/util/_mathfuncs.py +57 -0
  121. accelforge/util/_parse_expressions.py +339 -0
  122. accelforge/util/_picklecache.py +32 -0
  123. accelforge/util/_setexpressions.py +268 -0
  124. accelforge/util/_sympy/__init__.py +0 -0
  125. accelforge/util/_sympy/broadcast_max.py +18 -0
  126. accelforge/util/_visualization.py +112 -0
  127. accelforge/util/_yaml.py +579 -0
  128. accelforge/util/parallel.py +193 -0
  129. accelforge-0.0.1.dist-info/METADATA +64 -0
  130. accelforge-0.0.1.dist-info/RECORD +258 -0
  131. accelforge-0.0.1.dist-info/WHEEL +5 -0
  132. accelforge-0.0.1.dist-info/licenses/LICENSE +19 -0
  133. accelforge-0.0.1.dist-info/top_level.txt +5 -0
  134. docs/_build/html/_sources/fastfusion.frontend.mapper.rst.txt +37 -0
  135. docs/_build/html/_sources/fastfusion.frontend.rst.txt +70 -0
  136. docs/_build/html/_sources/fastfusion.frontend.workload.rst.txt +21 -0
  137. docs/_build/html/_sources/fastfusion.mapper.FFM.rst.txt +37 -0
  138. docs/_build/html/_sources/fastfusion.mapper.rst.txt +18 -0
  139. docs/_build/html/_sources/fastfusion.rst.txt +20 -0
  140. docs/_build/html/_sources/fastfusion.util.rst.txt +21 -0
  141. docs/_build/html/_sources/index.rst.txt +87 -0
  142. docs/_build/html/_sources/modules.rst.txt +7 -0
  143. docs/_build/html/_sources/notes/citation.rst.txt +45 -0
  144. docs/_build/html/_sources/notes/definitions.rst.txt +43 -0
  145. docs/_build/html/_sources/notes/faqs.rst.txt +39 -0
  146. docs/_build/html/_sources/notes/modeling/accelerator_energy_latency.rst.txt +72 -0
  147. docs/_build/html/_sources/notes/modeling/component_energy_area.rst.txt +96 -0
  148. docs/_build/html/_sources/notes/modeling/mapping.rst.txt +100 -0
  149. docs/_build/html/_sources/notes/modeling.rst.txt +33 -0
  150. docs/_build/html/_sources/notes/parsing/arithmetic_parsing.rst.txt +136 -0
  151. docs/_build/html/_sources/notes/parsing/setexpressions.rst.txt +63 -0
  152. docs/_build/html/_sources/notes/parsing/yaml_parsing.rst.txt +176 -0
  153. docs/_build/html/_sources/notes/quickstart_and_installation.rst.txt +9 -0
  154. docs/_build/html/_sources/notes/spec/architecture.rst.txt +133 -0
  155. docs/_build/html/_sources/notes/spec/mapping.rst.txt +12 -0
  156. docs/_build/html/_sources/notes/spec/workload.rst.txt +83 -0
  157. docs/_build/html/_sources/notes/spec.rst.txt +36 -0
  158. docs/source/_ext/include_attrs.py +213 -0
  159. docs/source/_ext/include_docstring.py +364 -0
  160. docs/source/_ext/include_functions.py +154 -0
  161. docs/source/_ext/include_notebook.py +131 -0
  162. docs/source/_ext/include_yaml.py +119 -0
  163. docs/source/_ext/inherited_attributes.py +222 -0
  164. docs/source/_ext/paths.py +4 -0
  165. docs/source/conf.py +79 -0
  166. examples/arches/compute_in_memory/_include.yaml +74 -0
  167. examples/arches/compute_in_memory/_include_functions.py +229 -0
  168. examples/arches/compute_in_memory/_load_spec.py +57 -0
  169. examples/arches/compute_in_memory/components/c2c_multiplier.py +181 -0
  170. examples/arches/compute_in_memory/components/dac_c2c_r2r.py +605 -0
  171. examples/arches/compute_in_memory/components/misc.py +195 -0
  172. examples/arches/compute_in_memory/components/util/bit_functions.py +51 -0
  173. examples/arches/compute_in_memory/components/zero_comparator.py +92 -0
  174. examples/arches/compute_in_memory/isaac.yaml +233 -0
  175. examples/arches/compute_in_memory/memory_cells/ecram_demo.yaml +63 -0
  176. examples/arches/compute_in_memory/memory_cells/rram_example.yaml +63 -0
  177. examples/arches/compute_in_memory/memory_cells/rram_isaac_isca_2016.yaml +64 -0
  178. examples/arches/compute_in_memory/memory_cells/rram_neurosim_default.yaml +63 -0
  179. examples/arches/compute_in_memory/memory_cells/rram_raella_isca_2023.yaml +70 -0
  180. examples/arches/compute_in_memory/memory_cells/rram_wan_nature_2022.yaml +63 -0
  181. examples/arches/compute_in_memory/memory_cells/sram_colonnade_jssc_2021.yaml +63 -0
  182. examples/arches/compute_in_memory/memory_cells/sram_example.yaml +63 -0
  183. examples/arches/compute_in_memory/memory_cells/sram_jia_jssc_2020.yaml +63 -0
  184. examples/arches/compute_in_memory/memory_cells/sram_sinangil_jssc_2021.yaml +63 -0
  185. examples/arches/compute_in_memory/memory_cells/sram_wang_vlsi_2022.yaml +63 -0
  186. examples/arches/compute_in_memory/wang_vlsi_2022.yaml +289 -0
  187. examples/arches/eyeriss.yaml +68 -0
  188. examples/arches/fanout_variations/at_glb.yaml +31 -0
  189. examples/arches/fanout_variations/at_glb_with_fanout_node.yaml +34 -0
  190. examples/arches/fanout_variations/at_mac.yaml +31 -0
  191. examples/arches/fanout_variations/at_mac_with_constraints.yaml +38 -0
  192. examples/arches/fanout_variations/at_mac_with_fanout_node.yaml +34 -0
  193. examples/arches/nvdla.yaml +47 -0
  194. examples/arches/simple.yaml +28 -0
  195. examples/arches/tpu_v4i.yaml +67 -0
  196. examples/mappings/unfused_matmuls_to_simple.yaml +33 -0
  197. examples/misc/component_annotated.yaml +33 -0
  198. examples/workloads/gpt3_6.7B.yaml +124 -0
  199. examples/workloads/matmuls.yaml +20 -0
  200. examples/workloads/mobilenet_28.yaml +81 -0
  201. examples/workloads/mobilenet_various_separate.yaml +106 -0
  202. examples/workloads/three_matmuls_annotated.yaml +59 -0
  203. notebooks/.ipynb_checkpoints/fastfusion_arch_study_michael-checkpoint.ipynb +359 -0
  204. notebooks/compute_in_memory/_scripts.py +339 -0
  205. notebooks/compute_in_memory/isaac.guide.ipynb +270 -0
  206. notebooks/compute_in_memory/wang_vlsi_2022.ipynb +602 -0
  207. notebooks/paths.py +4 -0
  208. notebooks/tutorials/.ipynb_checkpoints/1_FFM-checkpoint.ipynb +3110 -0
  209. notebooks/tutorials/FFM.ipynb +3498 -0
  210. notebooks/tutorials/_include.py +48 -0
  211. notebooks/tutorials/component_energy_area.ipynb +363 -0
  212. tests/Q_mapping.yaml +38 -0
  213. tests/__init__.py +0 -0
  214. tests/conv.mapping.yaml +27 -0
  215. tests/conv.workload.yaml +13 -0
  216. tests/conv_sym.mapping.yaml +43 -0
  217. tests/copy.mapping.yaml +35 -0
  218. tests/copy.workload.yaml +15 -0
  219. tests/distribuffers/__init__.py +0 -0
  220. tests/distribuffers/multicast/test_cases.yaml +482 -0
  221. tests/distribuffers/spec/binding/valid_bindings.yaml +97 -0
  222. tests/distribuffers/spec/distributed.yaml +100 -0
  223. tests/distribuffers/spec/logical_arch.yaml +32 -0
  224. tests/distribuffers/spec/physical_arch.yaml +69 -0
  225. tests/distribuffers/test_binding.py +48 -0
  226. tests/frontend/__init__.py +0 -0
  227. tests/frontend/test_mapping_viz.py +52 -0
  228. tests/mapper/__init__.py +0 -0
  229. tests/mapper/configs/conv1d/conv1d.mapping.yaml +31 -0
  230. tests/mapper/configs/conv1d/conv1d.workload.yaml +11 -0
  231. tests/mapper/configs/two_conv1d/two_conv1d.expected.yaml +38 -0
  232. tests/mapper/configs/two_conv1d/two_conv1d.mapping.yaml +54 -0
  233. tests/mapper/configs/two_conv1d/two_conv1d.workload.yaml +19 -0
  234. tests/mapper/test_mapping_to_isl.py +90 -0
  235. tests/mapper/test_spatial_reuse_analysis.py +67 -0
  236. tests/mapper/test_temporal_reuse_analysis.py +56 -0
  237. tests/mapper/util.py +58 -0
  238. tests/matmul.mapping.yaml +29 -0
  239. tests/matmul.workload.yaml +12 -0
  240. tests/matmul_spatial.mapping.yaml +44 -0
  241. tests/mha.renames.yaml +65 -0
  242. tests/mha.workload.yaml +67 -0
  243. tests/mha.yaml +59 -0
  244. tests/mha_full.workload.yaml +67 -0
  245. tests/mobilenet.workload.yaml +35 -0
  246. tests/mobilenet_long.workload.yaml +64 -0
  247. tests/pmappingcache.py +24 -0
  248. tests/processing_stage.arch.yaml +40 -0
  249. tests/snowcat.arch.yaml +36 -0
  250. tests/test_ffm_join_pmappings.py +106 -0
  251. tests/test_ffm_make_pmappings.py +82 -0
  252. tests/test_ffm_make_tile_shapes.py +49 -0
  253. tests/test_mapper.py +100 -0
  254. tests/test_model.py +37 -0
  255. tests/test_plotting.py +72 -0
  256. tests/test_processing_stage.py +46 -0
  257. tests/test_symbolic_model.py +248 -0
  258. tests/test_workload.py +141 -0
accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_storage_order.py
@@ -0,0 +1,382 @@
+ from collections.abc import Collection, Generator, Sequence
+ from dataclasses import dataclass
+ from itertools import product
+ import logging
+ from typing import Any
+
+ import accelforge.frontend.arch as arch
+ from accelforge.frontend.mapping import MappingNode, ProcessingStage, TensorHolder
+ from accelforge.frontend.spec import Spec
+ from accelforge.frontend.workload import TensorName, SymbolTable
+ from accelforge.util._parse_expressions import MATH_FUNCS
+ from accelforge.util._setexpressions import eval_set_expression
+
+ from accelforge.mapper.FFM._make_pmappings.make_pmapping_templates.make_storages import (
+     make_storage_choices_all_levels,
+ )
+ from accelforge.frontend.workload import EinsumName
+
+
+ def get_tensor_choices(
+     einsum_name: EinsumName,
+     nodes: list[arch.Memory],
+     symbol_table: SymbolTable,
+     spec: Spec,
+     first_memory: arch.Memory,
+     fusable_tensors: set[TensorName],
+ ) -> Generator[tuple[list[TensorHolder], SymbolTable, arch.Compute], None, None]:
+     nodes, compute = nodes[:-1], nodes[-1]
+     while True:
+         if not nodes:
+             return
+         if not isinstance(nodes[0], arch.Memory):
+             nodes = nodes[1:]
+             continue
+         assert isinstance(nodes[0].enabled, bool)
+         if not nodes[0].enabled:
+             nodes = nodes[1:]
+             continue
+         break
+
+     tensors = spec.workload.einsums[einsum_name].tensor_names
+     is_copy_op = spec.workload.einsums[einsum_name].is_copy_operation
+     persistent_tensors = {
+         t.name
+         for t in spec.workload.einsums[einsum_name].tensor_accesses
+         if t.persistent
+     }
+
+     for choice, symbol_table in make_storage_choices_all_levels(
+         nodes=nodes,
+         symbol_table=symbol_table,
+         is_copy_op=is_copy_op,
+         persistent_tensors=persistent_tensors,
+         seen_tensors=set(),
+         einsum_name=einsum_name,
+     ):
+         x = [y for z in choice.values() for y in z]
+         logging.info(
+             f"\t\tUnordered storage choice: {', '.join(n.compact_str() for n in x)}"
+         )
+         all_tensor_holders = [v2 for v in choice.values() for v2 in v]
+
+         # Start out the mapping with the outermost memory name
+         base_mapping = []
+         # for node in list(all_tensor_holders[::-1]):
+         #     if node.component == first_tensor_holder.name:
+         #         all_tensor_holders.remove(node)
+         #         base_mapping.append(node)
+
+         # Get the dataflow constraints for the mapping
+         required_order = get_tensor_order_constraint(nodes, symbol_table, tensors)
+
+         symbol_table["arch_attributes"] = {}
+         cur_compute = compute._parse_expressions(
+             symbol_table,
+             location=f"arch.{compute.name}",
+             must_parse_try_parse_to=True,
+             must_copy=False,
+         )[0]
+         assert isinstance(cur_compute.enabled, bool)
+         if not cur_compute.enabled:
+             continue
+
+         for mapping in recursive_order_tensor_choices(
+             einsum_name,
+             tensors,
+             base_mapping,
+             nodes,
+             all_tensor_holders,
+             required_order,
+             spec,
+             is_copy_op,
+             first_memory,
+             fusable_tensors,
+         ):
+             yield mapping, symbol_table, cur_compute
+
+
+ def get_tensor_order_constraint(nodes, symbol_table, tensors):
+     required_order: dict[str, list[Order]] = {}
+     for node in nodes:
+         if isinstance(node, arch.Fanout):
+             continue
+         for order_constraint in node.tensors.tensor_order_options:
+             order = Order()
+             for together_tensors in order_constraint:
+                 in_mapping_together_tensors = [
+                     tensor for tensor in together_tensors if tensor in tensors
+                 ]
+                 if len(in_mapping_together_tensors) == 1:
+                     only_tensor = in_mapping_together_tensors[0]
+                     order.add_tensor(only_tensor)
+                 elif len(in_mapping_together_tensors) > 1:
+                     order.add_together_tensors(in_mapping_together_tensors)
+             if order.order:
+                 required_order.setdefault(node.name, []).append(order)
+     return required_order
+
+
+ def recursive_order_tensor_choices(
+     einsum_name: EinsumName,
+     tensors: set[TensorName],
+     mapping: Sequence[MappingNode],
+     nodes: list[arch.Memory],
+     remaining_choices: list,
+     required_order: dict[str, list["Order"]],
+     spec: Spec,
+     is_copy_op: bool,
+     first_memory: arch.Memory,
+     fusable_tensors: set[TensorName],
+ ) -> Generator[list[MappingNode], None, None]:
+     def check_has_tensors(mapping: list[MappingNode]):
+         tensor_holders = [node for node in mapping if isinstance(node, TensorHolder)]
+         tensors_in_mapping = {
+             tensor
+             for tensor_holder in tensor_holders
+             for tensor in tensor_holder.tensors
+         }
+         if tensors_in_mapping != tensors:
+             raise ValueError(
+                 f"Einsum {einsum_name} has a pmapping template that is missing tensors. Ensure "
+                 f"that there is a storage node storing each tensor in the Einsum. Missing "
+                 f"tensors: {tensors - tensors_in_mapping}. Pmapping template:\n\t"
+                 + "\n\t".join(m.compact_str() for m in mapping)
+             )
+
+     mapping = list(mapping)
+     if not remaining_choices:
+         check_has_tensors(mapping)
+         yield mapping
+         return
+
+     # If it's a copy op and we have the backing storage for every tensor, return
+     # immediately
+     if is_copy_op:
+         tensor_holders = [node for node in mapping if isinstance(node, TensorHolder)]
+         if set().union(*[t._backing for t in tensor_holders]) == tensors:
+             check_has_tensors(mapping)
+             yield mapping
+             return
+
+     for choice in sorted(remaining_choices, key=lambda x: x.compact_str()):
+         mapping.append(choice)
+         new_remaining = [c for c in remaining_choices if c != choice]
+         valid, reason = valid_tensor_holder_order(
+             mapping,
+             [n.name for n in nodes],
+             required_order,
+             spec,
+             first_memory,
+             fusable_tensors,
+         )
+         if valid:
+             yield from recursive_order_tensor_choices(
+                 einsum_name,
+                 tensors,
+                 mapping,
+                 nodes,
+                 new_remaining,
+                 required_order,
+                 spec,
+                 is_copy_op,
+                 first_memory,
+                 fusable_tensors,
+             )
+         else:
+             logging.info(
+                 "\t\t"
+                 + " " * len(mapping)
+                 + f"Invalid tensor holder order: {', '.join(n.compact_str() for n in mapping)}: {reason}"
+             )
+         mapping.pop()
+
+
+ def valid_tensor_holder_order(
+     mapping: Sequence[TensorHolder],
+     node_names: list[str],
+     required_orders: dict[str, list["Order"]],
+     spec: Spec,
+     first_memory: arch.Memory,
+     fusable_tensors: set[TensorName],
+ ):
+     memory_to_satisfied_constraints: dict[str, set] = {}
+     for i, m0 in enumerate(mapping):
+         for j, m1 in enumerate(mapping[i:]):
+             j += i
+
+             s1, s2 = m0.component, m1.component
+             s1_idx, s2_idx = node_names.index(s1), node_names.index(s2)
+             s1_persistent, s2_persistent = m0.persistent, m1.persistent
+             either_persistent = s1_persistent or s2_persistent
+
+             assert len(m0.tensors) == 1
+             assert len(m1.tensors) == 1
+
+             # If they're persistent they're forced to be at the top.
+             force_order = (
+                 spec.mapper.ffm.force_memory_hierarchy_order and not either_persistent
+             )
+             force_order &= m0.component_object.tensors.force_memory_hierarchy_order
+             force_order &= m1.component_object.tensors.force_memory_hierarchy_order
+
+             # Ctrl-F for CONTIGUOUS_ITERATION_SPACE_DISCUSSION: The following line does
+             # not let backing storage be above in the mapping anything that is below it
+             # in the memory hierarchy. THIS IS NOT FUNDAMENTAL. If we remove this
+             # constraint, then the fused loops may be different across different backing
+             # storages, so we would need to update make_pmappings_from_templates.py to
+             # make compatibility from the mapping for each tensor.
+             force_order |= bool(m0._backing & fusable_tensors)
+
+             if force_order and i < j and s2_idx < s1_idx:
+                 return (
+                     False,
+                     f"Memory {s1} is below memory {s2}, violating memory hierarchy order.",
+                 )
+
+             s1_outermost = s1_persistent
+             s2_outermost = s2_persistent
+             if not spec.mapper.ffm._can_lower_outermost_memory:
+                 s1_outermost |= s1 == first_memory.name
+                 s2_outermost |= s2 == first_memory.name
+
+             # Persistent tensors must be at the top of the hierarchy
+             if s2_outermost and not s1_outermost and i < j:
+                 return (
+                     False,
+                     f"Outermost {m0.compact_str()}, persistent {s1_persistent} is below non-outermost {m1.compact_str()}, persistent {s2_persistent}.",
+                 )
+
+             # We don't really care about processing stage order, so just make it follow
+             # the regular memory hierarchy order. For processing stages at a given
+             # level, make them alphabetical.
+             if (
+                 isinstance(m0, ProcessingStage)
+                 and m0.component == m1.component
+                 and m0.tensor < m1.tensor
+             ):
+                 return (
+                     False,
+                     f"Processing stage {m0} is not ordered alphabetically by tensor; has tensor {m0.tensor} before {m1.tensor}",
+                 )
+
+             # If there is a processing stage, don't explore order. If there's two
+             # back-to-back nodes and one is a processing stage, make them follow the
+             # memory hierarchy order.
+             if isinstance(m0, ProcessingStage) and s2_idx < s1_idx and i == j - 1:
+                 return False, f"Processing stage {m0} is directly above {m1}"
+             if isinstance(m1, ProcessingStage) and s2_idx < s1_idx and i == j - 1:
+                 return False, f"Processing stage {m1} is directly above {m0}"
+
+             if s1 == s2 and s1 in required_orders and i != j:
+                 if s1 not in memory_to_satisfied_constraints:
+                     memory_to_satisfied_constraints[s1] = {
+                         i for i in range(len(required_orders[s1]))
+                     }
+
+                 good = True
+                 for order_idx, order_choice in enumerate(required_orders[s1]):
+                     if order_idx not in memory_to_satisfied_constraints[s1]:
+                         continue
+
+                     good = True
+                     for t1, t2 in product(mapping[i].tensors, mapping[j].tensors):
+                         idx_of_i_in_order = order_choice.index(t1)
+                         idx_of_j_in_order = order_choice.index(t2)
+
+                         if idx_of_i_in_order is None or idx_of_j_in_order is None:
+                             continue
+
+                         if idx_of_i_in_order > idx_of_j_in_order:
+                             good = False
+                             reason = f"Tensor {t1} is before tensor {t2} in the order {order_choice}"
+                             break
+                     if not good:
+                         memory_to_satisfied_constraints[s1].remove(order_idx)
+
+                 if len(memory_to_satisfied_constraints[s1]) == 0:
+                     return False, reason
+
+             if not (set(m0.tensors) & set(m1.tensors)):
+                 continue
+
+             if i < j and s2_idx < s1_idx:
+                 return False, f"{m0.compact_str()} is below {m1.compact_str()}"
+
+             # If a tensor is stored in two levels back-to-back, then we should have
+             # bypassed the outer TensorHolder if possible.
+             either_backing = m0._backing & m1._backing
+             if (
+                 "redundant_dataplacements"
+                 not in spec.mapper.ffm._count_option_for_mapsapce_size_evaluation
+             ):
+                 if i == j or i == j - 1:
+                     if s1_idx < s2_idx and not (
+                         (set(m0._must_keep_tensors) & set(m1.tensors)) or either_backing
+                     ):
+                         shared = set(m0._must_keep_tensors) & set(m1.tensors)
+                         return (
+                             False,
+                             f"{shared} stored in back-to-back storage nodes, and could have bypassed the outer one.",
+                         )
+                     if s2_idx < s1_idx and not (
+                         (set(m1._must_keep_tensors) & set(m0.tensors)) or either_backing
+                     ):
+                         shared = set(m1._must_keep_tensors) & set(m0.tensors)
+                         return (
+                             False,
+                             f"{shared} is stored in back-to-back storage nodes, and could have bypassed the outer one.",
+                         )
+
+     for i, m0 in enumerate(mapping):
+         for j, m1 in enumerate(mapping[i:]):
+             s1, s2 = m0.component, m1.component
+             if s1 != s2 or s1 not in memory_to_satisfied_constraints or i == j:
+                 continue
+
+             satisfied_orders = memory_to_satisfied_constraints[s1]
+             assert len(satisfied_orders) > 0
+
+             for order_idx in satisfied_orders:
+                 order = required_orders[s1][order_idx]
+                 for tensor_i in m0.tensors:
+                     for tensor_j in m1.tensors:
+                         if order.index(tensor_i) != order.index(tensor_j):
+                             continue
+                         break
+
+     return True, ""
+
+
+ @dataclass(frozen=True)
+ class Alone:
+     tensor: Any
+
+
+ @dataclass(frozen=True)
+ class Together:
+     tensors: Collection[Any]
+
+
+ class Order:
+     """An ordering of tensors."""
+
+     def __init__(self):
+         self.order = []
+
+     def __repr__(self):
+         return f"Order({self.order})"
+
+     def add_tensor(self, tensor):
+         self.order.append(Alone(tensor))
+
+     def add_together_tensors(self, together_tensors):
+         self.order.append(Together(together_tensors))
+
+     def index(self, tensor):
+         for i, order_term in enumerate(self.order):
+             if (isinstance(order_term, Alone) and order_term.tensor == tensor) or (
+                 isinstance(order_term, Together) and tensor in order_term.tensors
+             ):
+                 return i
+         return None
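The Order class closing the hunk above encodes one tensor-order constraint as a sequence of Alone and Together terms: Alone pins a single tensor to a slot, Together lets several tensors share one slot, and index() returns a tensor's slot or None when the constraint says nothing about it. A minimal sketch of that behavior (the import path assumes this +382 hunk is make_storage_order.py, the only 382-line file in the listing above, which also matches its make_storage_choices_all_levels import; the tensor names are hypothetical):

    from accelforge.mapper.FFM._make_pmappings.make_pmapping_templates.make_storage_order import (
        Order,
    )

    order = Order()
    order.add_tensor("Weights")                 # slot 0: Weights alone
    order.add_together_tensors({"In", "Out"})   # slot 1: In and Out share a slot

    assert order.index("Weights") == 0
    assert order.index("In") == order.index("Out") == 1
    assert order.index("Bias") is None          # tensor the constraint does not mention

valid_tensor_holder_order treats each Order in required_orders as one allowed dataflow: a pair of tensor holders on the same memory stays valid as long as at least one Order choice places their tensors in non-decreasing slots.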
accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_storages.py
@@ -0,0 +1,155 @@
+ import copy
+ from collections.abc import Generator
+ from itertools import chain, combinations
+ import logging
+
+ import accelforge.frontend.arch as arch
+ from accelforge.frontend.mapping import Storage, TensorHolder, ProcessingStage
+ from accelforge.frontend.workload import TensorName, SymbolTable
+
+ from accelforge.util._parse_expressions import ParseError
+ from accelforge.util._setexpressions import InvertibleSet
+
+
+ def make_tensor_choices_one_level(
+     node: arch.Leaf,
+     symbol_table: dict[str, InvertibleSet],
+     persistent_tensors: set[TensorName],
+     seen_tensors: set[TensorName] = frozenset(),
+     is_copy_op: bool = False,
+     einsum_name: str = None,
+ ) -> Generator[tuple[list[TensorHolder], SymbolTable, set[TensorName]], None, None]:
+     """
+     Generate combinations of TensorHolder nodes based on keep and bypass
+     constraints.
+
+     Each generated list contains TensorHolder nodes for single tensors.
+     """
+     assert "All" in symbol_table
+     tensors = symbol_table["All"]
+
+     if not isinstance(node, arch.TensorHolder):
+         yield [], symbol_table, set(seen_tensors)
+         return
+
+     if isinstance(node, arch.Memory):
+         target_type = Storage
+     elif isinstance(node, arch.ProcessingStage):
+         target_type = ProcessingStage
+     elif isinstance(node, arch.Dummy):
+         yield [], symbol_table, set(seen_tensors)
+         return
+     else:
+         raise ValueError(f"Unexpected tensor holder type: {type(node)}")
+
+     new_symbol_table = copy.copy(symbol_table)
+
+     node = copy.copy(node)
+     try:
+         node.tensors: arch.Tensors = node.tensors._parse_expressions(
+             symbol_table=symbol_table,
+             must_parse_try_parse_to=True,
+             must_copy=False,
+             location=f"arch.{node.name}.tensors",
+         )[0]
+     except ParseError as e:
+         e.add_field(f"Einsum {einsum_name} arch.{node.name}.tensors")
+         raise e
+
+     must_keep = tensors.to_my_space(node.tensors.keep | node.tensors.back)
+     may_keep = tensors.to_my_space(node.tensors.may_keep)
+     may_keep -= must_keep
+
+     if seen_tensors & set(node.tensors.back):
+         return
+
+     if must_keep - tensors:
+         raise KeyError(
+             f"Keep constraint for {node.name} includes tensors that are "
+             f"not in the workload: {must_keep - new_symbol_table['All']}"
+         )
+     if may_keep - tensors:
+         raise KeyError(
+             f"Bypass constraint for {node.name} includes tensors that are "
+             f"not in the workload: {may_keep - tensors.full_space}"
+         )
+
+     logging.info(
+         f"\t\t{node.name} must keep {sorted(must_keep)}, may keep {sorted(may_keep)}"
+     )
+
+     # No reuse in copy operations, so no need to keep tensors in more places
+     if is_copy_op:
+         may_keep -= tensors.to_my_space(seen_tensors)
+
+     for subset in powerset(sorted(may_keep, key=str)):
+         # Make keep choice & update symbol table
+         subset = tensors.to_my_space(set(subset))
+         keep_choice = tensors.to_my_space(subset | must_keep)
+         # Below line is so users can do MainMemory().tensors() or MainMemory.tensors
+         new_symbol_table[node.name] = keep_choice
+         new_symbol_table["Above"] |= keep_choice
+         new_seen_tensors = seen_tensors | set(keep_choice)
+
+         # Make sure they're all tensors
+         assert all(isinstance(k, TensorName) for k in keep_choice)
+         keep_choice = keep_choice.to_my_space({copy.copy(t) for t in keep_choice})
+         nodes = []
+
+         # Create storage nodes. Sort them to keep this deterministic. Ordering is done
+         # later.
+         for t in sorted(keep_choice, key=str):
+             nodes.append(
+                 target_type(tensors=[t], component=node.name, component_object=node)
+             )
+             if t not in seen_tensors:
+                 nodes[-1]._backing.add(t)
+                 nodes[-1]._must_keep_tensors = [t]
+                 nodes[-1].persistent = t in persistent_tensors
+             elif t in must_keep:
+                 nodes[-1]._must_keep_tensors = [t]
+
+         yield nodes, new_symbol_table, new_seen_tensors
+
+
+ def make_storage_choices_all_levels(
+     nodes: list[TensorHolder],
+     symbol_table: dict[str, InvertibleSet],
+     persistent_tensors: set[TensorName],
+     seen_tensors: set[TensorName] = None,
+     is_copy_op: bool = False,
+     einsum_name: str = None,
+ ) -> Generator[tuple[dict[str, list[TensorHolder]], SymbolTable], None, None]:
+     """
+     Generate combinations of TensorHolder nodes based on keep and bypass
+     constraints.
+
+     Each generated dict maps memory name to a list of TensorHolder nodes for
+     single tensors.
+     """
+     seen_tensors = set() if seen_tensors is None else seen_tensors
+     if len(nodes) == 0:
+         yield dict(), symbol_table
+         return
+     for choice, symbol_table, new_seen_tensors in make_tensor_choices_one_level(
+         node=nodes[0],
+         symbol_table=symbol_table,
+         persistent_tensors=persistent_tensors,
+         seen_tensors=seen_tensors,
+         is_copy_op=is_copy_op,
+         einsum_name=einsum_name,
+     ):
+         for subchoices, symbol_table in make_storage_choices_all_levels(
+             nodes=nodes[1:],
+             symbol_table=symbol_table,
+             persistent_tensors=persistent_tensors,
+             seen_tensors=new_seen_tensors,
+             is_copy_op=is_copy_op,
+             einsum_name=einsum_name,
+         ):
+             yield {**subchoices, nodes[0].name: choice}, symbol_table
+
+
+ def powerset(iterable):
+     s = list(iterable)
+     return chain.from_iterable(combinations(s, r) for r in range(len(s) + 1))
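powerset here is the standard itertools recipe, and it is what gives make_tensor_choices_one_level its branching: each level yields one storage choice per subset of its optional (may_keep) tensors, with must-keep tensors added to every choice, while make_storage_choices_all_levels recursively combines the per-level choices, threading the symbol table and the set of already-seen tensors down the hierarchy. A self-contained sketch of the enumeration (tensor names hypothetical):

    from itertools import chain, combinations

    def powerset(iterable):
        s = list(iterable)
        return chain.from_iterable(combinations(s, r) for r in range(len(s) + 1))

    may_keep = ["In", "Out"]
    print(list(powerset(may_keep)))
    # [(), ('In',), ('Out',), ('In', 'Out')]
    # With must_keep = {"Weights"}, this level's keep choices become
    # {Weights}, {Weights, In}, {Weights, Out}, and {Weights, In, Out}.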