accelforge 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (258)
  1. accelforge/__init__.py +21 -0
  2. accelforge/_accelerated_imports.py +16 -0
  3. accelforge/_deprecate/_simanneal/evalmapping.py +271 -0
  4. accelforge/_deprecate/_simanneal/mapspaceglobals.py +298 -0
  5. accelforge/_deprecate/_simanneal/simanneal.py +666 -0
  6. accelforge/_deprecate/_simanneal/tracking.py +105 -0
  7. accelforge/_deprecate/_simanneal/wrappers.py +218 -0
  8. accelforge/_deprecate/_simanneal2/__init__.py +7 -0
  9. accelforge/_deprecate/_simanneal2/simanneal.py +493 -0
  10. accelforge/_deprecate/_simanneal2/tracking.py +116 -0
  11. accelforge/_deprecate/compatibility_util.py +181 -0
  12. accelforge/_deprecate/layerdeduplication/__init__.py +2 -0
  13. accelforge/_deprecate/layerdeduplication/group_similar_einsums.py +160 -0
  14. accelforge/_deprecate/layerdeduplication/grouped_einsums.py +84 -0
  15. accelforge/_deprecate/mapping_filter_tags/__init__.py +2 -0
  16. accelforge/_deprecate/mapping_filter_tags/ffmt.py +212 -0
  17. accelforge/_deprecate/mapping_filter_tags/onesplit.py +24 -0
  18. accelforge/_deprecate/mapping_filter_tags/util.py +24 -0
  19. accelforge/_deprecate/tags.py +69 -0
  20. accelforge/_deprecate/viz/__init__.py +0 -0
  21. accelforge/_deprecate/viz/interactive.py +159 -0
  22. accelforge/_deprecate/viz/reservationtree.py +307 -0
  23. accelforge/_deprecate/viz/ski_slope.py +88 -0
  24. accelforge/_version.py +15 -0
  25. accelforge/examples.py +39 -0
  26. accelforge/frontend/__init__.py +10 -0
  27. accelforge/frontend/_binding.py +129 -0
  28. accelforge/frontend/_workload_isl/__init__.py +2 -0
  29. accelforge/frontend/_workload_isl/_isl.py +149 -0
  30. accelforge/frontend/_workload_isl/_symbolic.py +141 -0
  31. accelforge/frontend/arch copy.py +1544 -0
  32. accelforge/frontend/arch.py +1642 -0
  33. accelforge/frontend/config.py +63 -0
  34. accelforge/frontend/mapper/__init__.py +5 -0
  35. accelforge/frontend/mapper/ffm.py +126 -0
  36. accelforge/frontend/mapper/mapper.py +7 -0
  37. accelforge/frontend/mapper/metrics.py +30 -0
  38. accelforge/frontend/mapping/__init__.py +1 -0
  39. accelforge/frontend/mapping/mapping.py +1736 -0
  40. accelforge/frontend/model.py +14 -0
  41. accelforge/frontend/renames.py +150 -0
  42. accelforge/frontend/spec copy.py +230 -0
  43. accelforge/frontend/spec.py +301 -0
  44. accelforge/frontend/variables.py +12 -0
  45. accelforge/frontend/workload.py +952 -0
  46. accelforge/mapper/FFM/__init__.py +9 -0
  47. accelforge/mapper/FFM/_join_pmappings/__init__.py +0 -0
  48. accelforge/mapper/FFM/_join_pmappings/compatibility.py +653 -0
  49. accelforge/mapper/FFM/_join_pmappings/compress_pmappings.py +140 -0
  50. accelforge/mapper/FFM/_join_pmappings/join_pmappings.py +703 -0
  51. accelforge/mapper/FFM/_join_pmappings/pmapping_dataframe.py +901 -0
  52. accelforge/mapper/FFM/_join_pmappings/pmapping_group.py +337 -0
  53. accelforge/mapper/FFM/_make_pmappings/contraints/__init__.py +0 -0
  54. accelforge/mapper/FFM/_make_pmappings/contraints/constraints.py +360 -0
  55. accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/__init__.py +1 -0
  56. accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_loops.py +373 -0
  57. accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_pmapping_templates.py +463 -0
  58. accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_reservations.py +95 -0
  59. accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_storage_order.py +382 -0
  60. accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_storages.py +155 -0
  61. accelforge/mapper/FFM/_make_pmappings/make_pmappings.py +411 -0
  62. accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/__init__.py +1 -0
  63. accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/make_pmappings_from_templates.py +407 -0
  64. accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/make_tile_shapes.py +1681 -0
  65. accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/run_model.py +170 -0
  66. accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/symbol_relations.py +174 -0
  67. accelforge/mapper/FFM/_make_pmappings/pmapper_job.py +282 -0
  68. accelforge/mapper/FFM/_pareto_df/df_convention.py +273 -0
  69. accelforge/mapper/FFM/_pareto_df/pareto copy.py +836 -0
  70. accelforge/mapper/FFM/_pareto_df/pareto.py +508 -0
  71. accelforge/mapper/FFM/data.py +61 -0
  72. accelforge/mapper/FFM/main copy.py +236 -0
  73. accelforge/mapper/FFM/main.py +208 -0
  74. accelforge/mapper/FFM/mappings.py +510 -0
  75. accelforge/mapper/FFM/pmappings.py +310 -0
  76. accelforge/mapper/__init__.py +4 -0
  77. accelforge/mapper.py +0 -0
  78. accelforge/model/__init__.py +1 -0
  79. accelforge/model/_looptree/__init__.py +0 -0
  80. accelforge/model/_looptree/accesses.py +335 -0
  81. accelforge/model/_looptree/capacity/__init__.py +1 -0
  82. accelforge/model/_looptree/capacity/aggregators.py +36 -0
  83. accelforge/model/_looptree/capacity/capacity.py +47 -0
  84. accelforge/model/_looptree/energy.py +150 -0
  85. accelforge/model/_looptree/equivalent_ranks.py +29 -0
  86. accelforge/model/_looptree/latency/__init__.py +1 -0
  87. accelforge/model/_looptree/latency/latency.py +98 -0
  88. accelforge/model/_looptree/latency/memory.py +120 -0
  89. accelforge/model/_looptree/latency/processors.py +92 -0
  90. accelforge/model/_looptree/mapping_utilities.py +71 -0
  91. accelforge/model/_looptree/reuse/__init__.py +4 -0
  92. accelforge/model/_looptree/reuse/isl/__init__.py +1 -0
  93. accelforge/model/_looptree/reuse/isl/des.py +59 -0
  94. accelforge/model/_looptree/reuse/isl/isl_functions.py +374 -0
  95. accelforge/model/_looptree/reuse/isl/mapping_to_isl/__init__.py +4 -0
  96. accelforge/model/_looptree/reuse/isl/mapping_to_isl/analyze_mapping.py +297 -0
  97. accelforge/model/_looptree/reuse/isl/mapping_to_isl/skews_from_mapping.py +236 -0
  98. accelforge/model/_looptree/reuse/isl/mapping_to_isl/tiling.py +685 -0
  99. accelforge/model/_looptree/reuse/isl/mapping_to_isl/types.py +188 -0
  100. accelforge/model/_looptree/reuse/isl/spatial.py +260 -0
  101. accelforge/model/_looptree/reuse/isl/temporal.py +182 -0
  102. accelforge/model/_looptree/reuse/symbolic/__init__.py +1 -0
  103. accelforge/model/_looptree/reuse/symbolic/symbolic copy 2.py +1346 -0
  104. accelforge/model/_looptree/reuse/symbolic/symbolic copy.py +1408 -0
  105. accelforge/model/_looptree/reuse/symbolic/symbolic.py +1396 -0
  106. accelforge/model/_looptree/run.py +122 -0
  107. accelforge/model/_looptree/types.py +26 -0
  108. accelforge/model/_looptree/visualization/__init__.py +0 -0
  109. accelforge/model/_looptree/visualization/occupancy.py +11 -0
  110. accelforge/model/main.py +222 -0
  111. accelforge/plotting/__init__.py +2 -0
  112. accelforge/plotting/mappings.py +219 -0
  113. accelforge/plotting/specs.py +57 -0
  114. accelforge/util/__init__.py +4 -0
  115. accelforge/util/_base_analysis_types.py +24 -0
  116. accelforge/util/_basetypes.py +1089 -0
  117. accelforge/util/_frozenset.py +36 -0
  118. accelforge/util/_isl.py +29 -0
  119. accelforge/util/_itertools.py +14 -0
  120. accelforge/util/_mathfuncs.py +57 -0
  121. accelforge/util/_parse_expressions.py +339 -0
  122. accelforge/util/_picklecache.py +32 -0
  123. accelforge/util/_setexpressions.py +268 -0
  124. accelforge/util/_sympy/__init__.py +0 -0
  125. accelforge/util/_sympy/broadcast_max.py +18 -0
  126. accelforge/util/_visualization.py +112 -0
  127. accelforge/util/_yaml.py +579 -0
  128. accelforge/util/parallel.py +193 -0
  129. accelforge-0.0.1.dist-info/METADATA +64 -0
  130. accelforge-0.0.1.dist-info/RECORD +258 -0
  131. accelforge-0.0.1.dist-info/WHEEL +5 -0
  132. accelforge-0.0.1.dist-info/licenses/LICENSE +19 -0
  133. accelforge-0.0.1.dist-info/top_level.txt +5 -0
  134. docs/_build/html/_sources/fastfusion.frontend.mapper.rst.txt +37 -0
  135. docs/_build/html/_sources/fastfusion.frontend.rst.txt +70 -0
  136. docs/_build/html/_sources/fastfusion.frontend.workload.rst.txt +21 -0
  137. docs/_build/html/_sources/fastfusion.mapper.FFM.rst.txt +37 -0
  138. docs/_build/html/_sources/fastfusion.mapper.rst.txt +18 -0
  139. docs/_build/html/_sources/fastfusion.rst.txt +20 -0
  140. docs/_build/html/_sources/fastfusion.util.rst.txt +21 -0
  141. docs/_build/html/_sources/index.rst.txt +87 -0
  142. docs/_build/html/_sources/modules.rst.txt +7 -0
  143. docs/_build/html/_sources/notes/citation.rst.txt +45 -0
  144. docs/_build/html/_sources/notes/definitions.rst.txt +43 -0
  145. docs/_build/html/_sources/notes/faqs.rst.txt +39 -0
  146. docs/_build/html/_sources/notes/modeling/accelerator_energy_latency.rst.txt +72 -0
  147. docs/_build/html/_sources/notes/modeling/component_energy_area.rst.txt +96 -0
  148. docs/_build/html/_sources/notes/modeling/mapping.rst.txt +100 -0
  149. docs/_build/html/_sources/notes/modeling.rst.txt +33 -0
  150. docs/_build/html/_sources/notes/parsing/arithmetic_parsing.rst.txt +136 -0
  151. docs/_build/html/_sources/notes/parsing/setexpressions.rst.txt +63 -0
  152. docs/_build/html/_sources/notes/parsing/yaml_parsing.rst.txt +176 -0
  153. docs/_build/html/_sources/notes/quickstart_and_installation.rst.txt +9 -0
  154. docs/_build/html/_sources/notes/spec/architecture.rst.txt +133 -0
  155. docs/_build/html/_sources/notes/spec/mapping.rst.txt +12 -0
  156. docs/_build/html/_sources/notes/spec/workload.rst.txt +83 -0
  157. docs/_build/html/_sources/notes/spec.rst.txt +36 -0
  158. docs/source/_ext/include_attrs.py +213 -0
  159. docs/source/_ext/include_docstring.py +364 -0
  160. docs/source/_ext/include_functions.py +154 -0
  161. docs/source/_ext/include_notebook.py +131 -0
  162. docs/source/_ext/include_yaml.py +119 -0
  163. docs/source/_ext/inherited_attributes.py +222 -0
  164. docs/source/_ext/paths.py +4 -0
  165. docs/source/conf.py +79 -0
  166. examples/arches/compute_in_memory/_include.yaml +74 -0
  167. examples/arches/compute_in_memory/_include_functions.py +229 -0
  168. examples/arches/compute_in_memory/_load_spec.py +57 -0
  169. examples/arches/compute_in_memory/components/c2c_multiplier.py +181 -0
  170. examples/arches/compute_in_memory/components/dac_c2c_r2r.py +605 -0
  171. examples/arches/compute_in_memory/components/misc.py +195 -0
  172. examples/arches/compute_in_memory/components/util/bit_functions.py +51 -0
  173. examples/arches/compute_in_memory/components/zero_comparator.py +92 -0
  174. examples/arches/compute_in_memory/isaac.yaml +233 -0
  175. examples/arches/compute_in_memory/memory_cells/ecram_demo.yaml +63 -0
  176. examples/arches/compute_in_memory/memory_cells/rram_example.yaml +63 -0
  177. examples/arches/compute_in_memory/memory_cells/rram_isaac_isca_2016.yaml +64 -0
  178. examples/arches/compute_in_memory/memory_cells/rram_neurosim_default.yaml +63 -0
  179. examples/arches/compute_in_memory/memory_cells/rram_raella_isca_2023.yaml +70 -0
  180. examples/arches/compute_in_memory/memory_cells/rram_wan_nature_2022.yaml +63 -0
  181. examples/arches/compute_in_memory/memory_cells/sram_colonnade_jssc_2021.yaml +63 -0
  182. examples/arches/compute_in_memory/memory_cells/sram_example.yaml +63 -0
  183. examples/arches/compute_in_memory/memory_cells/sram_jia_jssc_2020.yaml +63 -0
  184. examples/arches/compute_in_memory/memory_cells/sram_sinangil_jssc_2021.yaml +63 -0
  185. examples/arches/compute_in_memory/memory_cells/sram_wang_vlsi_2022.yaml +63 -0
  186. examples/arches/compute_in_memory/wang_vlsi_2022.yaml +289 -0
  187. examples/arches/eyeriss.yaml +68 -0
  188. examples/arches/fanout_variations/at_glb.yaml +31 -0
  189. examples/arches/fanout_variations/at_glb_with_fanout_node.yaml +34 -0
  190. examples/arches/fanout_variations/at_mac.yaml +31 -0
  191. examples/arches/fanout_variations/at_mac_with_constraints.yaml +38 -0
  192. examples/arches/fanout_variations/at_mac_with_fanout_node.yaml +34 -0
  193. examples/arches/nvdla.yaml +47 -0
  194. examples/arches/simple.yaml +28 -0
  195. examples/arches/tpu_v4i.yaml +67 -0
  196. examples/mappings/unfused_matmuls_to_simple.yaml +33 -0
  197. examples/misc/component_annotated.yaml +33 -0
  198. examples/workloads/gpt3_6.7B.yaml +124 -0
  199. examples/workloads/matmuls.yaml +20 -0
  200. examples/workloads/mobilenet_28.yaml +81 -0
  201. examples/workloads/mobilenet_various_separate.yaml +106 -0
  202. examples/workloads/three_matmuls_annotated.yaml +59 -0
  203. notebooks/.ipynb_checkpoints/fastfusion_arch_study_michael-checkpoint.ipynb +359 -0
  204. notebooks/compute_in_memory/_scripts.py +339 -0
  205. notebooks/compute_in_memory/isaac.guide.ipynb +270 -0
  206. notebooks/compute_in_memory/wang_vlsi_2022.ipynb +602 -0
  207. notebooks/paths.py +4 -0
  208. notebooks/tutorials/.ipynb_checkpoints/1_FFM-checkpoint.ipynb +3110 -0
  209. notebooks/tutorials/FFM.ipynb +3498 -0
  210. notebooks/tutorials/_include.py +48 -0
  211. notebooks/tutorials/component_energy_area.ipynb +363 -0
  212. tests/Q_mapping.yaml +38 -0
  213. tests/__init__.py +0 -0
  214. tests/conv.mapping.yaml +27 -0
  215. tests/conv.workload.yaml +13 -0
  216. tests/conv_sym.mapping.yaml +43 -0
  217. tests/copy.mapping.yaml +35 -0
  218. tests/copy.workload.yaml +15 -0
  219. tests/distribuffers/__init__.py +0 -0
  220. tests/distribuffers/multicast/test_cases.yaml +482 -0
  221. tests/distribuffers/spec/binding/valid_bindings.yaml +97 -0
  222. tests/distribuffers/spec/distributed.yaml +100 -0
  223. tests/distribuffers/spec/logical_arch.yaml +32 -0
  224. tests/distribuffers/spec/physical_arch.yaml +69 -0
  225. tests/distribuffers/test_binding.py +48 -0
  226. tests/frontend/__init__.py +0 -0
  227. tests/frontend/test_mapping_viz.py +52 -0
  228. tests/mapper/__init__.py +0 -0
  229. tests/mapper/configs/conv1d/conv1d.mapping.yaml +31 -0
  230. tests/mapper/configs/conv1d/conv1d.workload.yaml +11 -0
  231. tests/mapper/configs/two_conv1d/two_conv1d.expected.yaml +38 -0
  232. tests/mapper/configs/two_conv1d/two_conv1d.mapping.yaml +54 -0
  233. tests/mapper/configs/two_conv1d/two_conv1d.workload.yaml +19 -0
  234. tests/mapper/test_mapping_to_isl.py +90 -0
  235. tests/mapper/test_spatial_reuse_analysis.py +67 -0
  236. tests/mapper/test_temporal_reuse_analysis.py +56 -0
  237. tests/mapper/util.py +58 -0
  238. tests/matmul.mapping.yaml +29 -0
  239. tests/matmul.workload.yaml +12 -0
  240. tests/matmul_spatial.mapping.yaml +44 -0
  241. tests/mha.renames.yaml +65 -0
  242. tests/mha.workload.yaml +67 -0
  243. tests/mha.yaml +59 -0
  244. tests/mha_full.workload.yaml +67 -0
  245. tests/mobilenet.workload.yaml +35 -0
  246. tests/mobilenet_long.workload.yaml +64 -0
  247. tests/pmappingcache.py +24 -0
  248. tests/processing_stage.arch.yaml +40 -0
  249. tests/snowcat.arch.yaml +36 -0
  250. tests/test_ffm_join_pmappings.py +106 -0
  251. tests/test_ffm_make_pmappings.py +82 -0
  252. tests/test_ffm_make_tile_shapes.py +49 -0
  253. tests/test_mapper.py +100 -0
  254. tests/test_model.py +37 -0
  255. tests/test_plotting.py +72 -0
  256. tests/test_processing_stage.py +46 -0
  257. tests/test_symbolic_model.py +248 -0
  258. tests/test_workload.py +141 -0
@@ -0,0 +1,337 @@
1
+ from collections import defaultdict
2
+ from functools import cached_property
3
+ from typing import Any, Callable, Iterable
4
+ import pandas as pd
5
+ from joblib import delayed
6
+
7
+ from accelforge.mapper.FFM._join_pmappings.pmapping_dataframe import PmappingDataframe
8
+
9
+ from accelforge.mapper.FFM._join_pmappings.compatibility import *
10
+ from accelforge.util import parallel
11
+
12
+
13
class PmappingGroup:
    """A group of pmappings that all share one ``Compatibility``.

    Attributes:
        compatibility: The Compatibility shared by every row of ``mappings``.
        mappings: The PmappingDataframe holding the pmapping rows. After
            ``merge_next(..., delay=True)`` it is instead the
            ``(function, args, kwargs)`` tuple produced by joblib's ``delayed``.
        tensors: Tensor name -> reservation. Initialised from
            ``compatibility.tensors`` but updated afterwards by ``merge_next``,
            ``_right_consolidate`` and ``remove_dead_tensors``.
        n_pre_prune_mappings: Size of the cross product ``merge_next`` formed to
            produce this group; 0 for groups constructed directly.
    """

    def __init__(self, compatibility: Compatibility, mappings: PmappingDataframe):
        self.compatibility: Compatibility = compatibility
        self.mappings: PmappingDataframe = mappings
        self.tensors: dict[str, TensorReservation] = {
            t.name: t for t in self.compatibility.tensors
        }
        self.n_pre_prune_mappings = 0

    def compatibility_str(self):
        """Return the compatibility's tensors, then ``||``, then ``self.tensors``."""
        compatibility = ",".join(str(t) for t in self.compatibility.tensors)
        compatibility += " || " + ", ".join(str(t) for t in self.tensors.values())
        return compatibility

    @property
    def tensor_names(self) -> set[str]:
        """Names of the tensors currently tracked in ``self.tensors``.

        Recomputed on every access: ``self.tensors`` is mutated after
        construction (``merge_next``, ``_right_consolidate``,
        ``remove_dead_tensors``), so a cached value could go stale.
        """
        return set(self.tensors)

    def copy(self) -> "PmappingGroup":
        """Return a copy with a copied dataframe and an independent tensors dict.

        The compatibility object is shared with the original, not copied.
        """
        new = PmappingGroup(self.compatibility, self.mappings.copy())
        # Carry over the current tensor map instead of rebuilding it from the
        # compatibility, which would undo any additions or removals made since
        # construction.
        new.tensors = dict(self.tensors)
        new.n_pre_prune_mappings = self.n_pre_prune_mappings
        return new

    def __len__(self) -> int:
        return len(self.mappings)

    def merge_next(
        self,
        right: "PmappingGroup",
        live_tensors: set[str],
        live_tensors_with_right: set[str],
        aliased_tensors: dict[str, set[str]],
        compatibility_joined: Compatibility,
        ignored_resources: set[str],
        drop_valid_reservations: bool = True,
        delay: bool = False,
        _pmapping_row_filter_function: Callable[[pd.Series], bool] | None = None,
    ) -> "PmappingGroup":
        """Join this group with ``right`` into a group with ``compatibility_joined``.

        Args:
            right: The group joined on the right-hand side.
            live_tensors: Tensor names treated as live; used to choose the
                shared loop indices and which of our reservations stay live.
            live_tensors_with_right: Forwarded to ``PmappingDataframe.merge_next``.
            aliased_tensors: Tensor name -> names of tensors aliased to it.
            compatibility_joined: Compatibility of the resulting group.
            ignored_resources: Forwarded to ``PmappingDataframe.merge_next``.
            drop_valid_reservations: Forwarded to ``PmappingDataframe.merge_next``.
            delay: If True the dataframe merge is not executed; the returned
                group's ``mappings`` is the joblib ``delayed`` task tuple.
            _pmapping_row_filter_function: Optional row filter forwarded to
                ``PmappingDataframe.merge_next``.

        Returns:
            A new group whose ``tensors`` is the union of both inputs' tensors
            (entries from ``self`` win on name clashes).

        Raises:
            AssertionError: if ``compatibility_joined.max_above_loop_index`` is
                not one past the shared loop index of ``live_tensors``.
        """
        shared_loop_index = self.compatibility.shared_loop_index(
            right.compatibility.tensor_names | live_tensors
        )
        next_shared_loop_index = compatibility_joined.shared_loop_index(live_tensors)

        # Our reservations that remain live and that ``right`` does not hold.
        still_live_reservations = [
            r
            for r in self.compatibility.tensors
            if r.name in live_tensors and r.name not in right.compatibility.tensor_names
        ]

        # Names of ``right``'s tensors that alias one of ours and are reserved
        # in the same resource.
        duplicated_aliased_tensors = set()
        for name, my_tensor in self.tensors.items():
            for alias_name in aliased_tensors.get(name, set()):
                right_tensor = right.tensors.get(alias_name)
                if right_tensor is None:
                    continue
                if my_tensor.resource_name == right_tensor.resource_name:
                    duplicated_aliased_tensors.add(right_tensor.name)

        mapping = delayed(self.mappings.merge_next)(
            right.mappings,
            shared_loop_index,
            next_shared_loop_index,
            live_tensors_with_right,
            still_live_reservations,
            duplicated_aliased_tensors,
            compatibility_left=self.compatibility,
            compatibility_right=right.compatibility,
            compatibility_joined=compatibility_joined,
            drop_valid_reservations=drop_valid_reservations,
            _pmapping_row_filter_function=_pmapping_row_filter_function,
            ignored_resources=ignored_resources,
        )

        # ``delayed`` yields a (function, args, kwargs) tuple; run it now unless
        # the caller asked to defer the work.
        if not delay:
            func, args, kwargs = mapping
            mapping = func(*args, **kwargs)

        s = PmappingGroup(compatibility_joined, mapping)
        assert (
            compatibility_joined.max_above_loop_index == next_shared_loop_index + 1
        ), f"{self.compatibility} {right.compatibility} {next_shared_loop_index + 1} -> {compatibility_joined} {len(compatibility_joined.loops)}"
        s.tensors.update(right.tensors)
        s.tensors.update(self.tensors)
        s.n_pre_prune_mappings = len(self.mappings.data) * len(right.mappings.data)
        return s

    def get_shared_loop_index(self, live_tensors: set[str]) -> int:
        """Return the compatibility's shared loop index for this group's tensors
        together with ``live_tensors``."""
        # Set union, matching how merge_next builds its argument. The previous
        # ``list(names) + [live_tensors]`` appended the whole set as a single
        # element, so the live tensor names were never actually included.
        tensors_to_check = set(self.compatibility.tensor_names) | set(live_tensors)
        return self.compatibility.shared_loop_index(tensors_to_check)

    def _right_consolidate(
        self,
        live_tensors: set[str] | None = None,
        shared_tensors: set[str] | None = None,
    ):
        """Drop dead tensors and free the dataframe down to the shared loop index.

        If ``free_to_loop_index`` reports a change, the dataframe is re-pruned
        with ``make_pareto``. Mutates and returns ``self``.
        """
        live = live_tensors or set()
        dead_tensors = set(self.tensors) - live
        check_tensors = (shared_tensors or set()) | live
        shared_loop_index = self.compatibility.shared_loop_index(check_tensors)
        for name in dead_tensors:
            self.tensors.pop(name)
        if self.mappings.free_to_loop_index(
            shared_loop_index, live_tensors=live_tensors
        ):
            self.mappings.make_pareto()
        return self

    def _left_consolidate(self, live_tensors: set[str] | None = None):
        """Free the dataframe down to the live tensors' shared loop index.

        Unlike ``_right_consolidate`` this neither drops tensors nor re-runs
        Pareto pruning. Mutates and returns ``self``.
        """
        check_tensors = live_tensors or set()
        shared_loop_index = self.compatibility.shared_loop_index(check_tensors)
        self.mappings.free_to_loop_index(shared_loop_index, live_tensors=live_tensors)
        return self

    @staticmethod
    def right_consolidate(
        pmapping_groups: list["PmappingGroup"],
        live_tensors: set[str],
        shared_tensors: set[str] | None = None,
        pbar: str | None = None,
        parallelize: bool = True,
    ) -> list["PmappingGroup"]:
        """Apply ``_right_consolidate`` to every group, optionally in parallel."""

        def job(s):
            return s._right_consolidate(live_tensors, shared_tensors)

        if not parallelize:
            return [
                s._right_consolidate(live_tensors, shared_tensors)
                for s in pmapping_groups
            ]

        return parallel([delayed(job)(s) for s in pmapping_groups], pbar=pbar)

    @staticmethod
    def left_consolidate(
        pmapping_groups: list["PmappingGroup"],
        live_tensors: set[str],
        pbar: str | None = None,
        parallelize: bool = True,
    ) -> list["PmappingGroup"]:
        """Apply ``_left_consolidate`` to every group, optionally in parallel."""

        def job(s):
            return s._left_consolidate(live_tensors)

        if not parallelize:
            return [s._left_consolidate(live_tensors) for s in pmapping_groups]

        return parallel([delayed(job)(s) for s in pmapping_groups], pbar=pbar)

    def _hashable_attrs(self):
        return self.mappings, fzs(self.tensors.items())

    @staticmethod
    def concat(
        pmapping_groups: Iterable["PmappingGroup"],
        allow_different_compatibilies: bool = False,
    ) -> "PmappingGroup":
        """Concatenate several groups into one group.

        Every group after the first is renamed to the first group's
        compatibility before the dataframes are concatenated.

        Raises:
            AssertionError: if ``pmapping_groups`` is empty, or if
                ``allow_different_compatibilies`` is False and the
                compatibilities differ after clearing symbolic tile patterns.
        """
        pmapping_groups = list(pmapping_groups)
        assert len(pmapping_groups) > 0, "Cannot concat empty list of PmappingGroups"
        if not allow_different_compatibilies:
            # A leftover debug path used to call combine_combineable here and
            # then assert identity equality of two distinct groups, which
            # always failed with an unhelpful message. Report the differing
            # compatibilities directly instead.
            distinct = {
                g.compatibility.clear_symbolic_tile_patterns() for g in pmapping_groups
            }
            assert len(distinct) == 1, (
                f"Cannot concat PmappingGroups with different compatibilies:\n\t"
                + "\n\t".join(str(c) for c in distinct)
            )

        c0 = pmapping_groups[0].compatibility
        to_concat = [pmapping_groups[0]] + [
            s.rename_compatibility(c0) for s in pmapping_groups[1:]
        ]
        return PmappingGroup(
            c0, PmappingDataframe.concat([s.mappings for s in to_concat])
        )

    def rename_compatibility(self, new_c: Compatibility) -> "PmappingGroup":
        """Return a new group whose compatibility (and dataframe columns) are
        renamed to match ``new_c``."""
        c, renamed = self.compatibility._rename_to_match(new_c)
        return PmappingGroup(c, self.mappings.rename(renamed))

    @staticmethod
    def _group(
        pmapping_groups: list["PmappingGroup"],
        live_tensors: set[str] | Literal["All"],
        clear_tile_patterns_and_reservation_indices: bool = False,
        include_permutations: bool = False,
        clear_symbolic_tile_patterns: bool = False,
        try_permute_into_equivalent: bool = False,
    ) -> (
        dict[Compatibility, list["PmappingGroup"]]
        | dict[Compatibility, list[tuple["PmappingGroup", list[int]]]]
    ):
        """Clear dead tensors (loops may be kept), then group PmappingGroups by
        compatibility.

        With ``include_permutations`` each group is listed under every
        equivalent loop permutation as a ``(group, loop_changes)`` pair. With
        ``try_permute_into_equivalent`` each group is instead assigned to exactly
        one key, preferring keys shared by the most groups, and is returned
        already permuted into that key's loop order.
        """
        grouped = defaultdict(list)

        def clear(c: Compatibility):
            if clear_symbolic_tile_patterns:
                c = c.clear_symbolic_tile_patterns()
            if clear_tile_patterns_and_reservation_indices:
                return c.clear_tile_patterns_and_reservation_indices()
            return c

        for s in pmapping_groups:
            compatibility = s.compatibility.clear_dead_tensors(live_tensors)

            if include_permutations or try_permute_into_equivalent:
                keys = compatibility.make_equivalent_permutations()
                for t, loop_changes in keys:
                    # Line below DOES NOT MUTATE. It checks that the permutation works.
                    s.compatibility.permute(loop_changes)
                    grouped[clear(t)].append((s, loop_changes))
            else:
                grouped[clear(compatibility)].append(s)

        if clear_tile_patterns_and_reservation_indices:
            for k in grouped:
                assert (
                    len(k.reservation_indices) == 0
                ), f"Extra reservation indices are not empty: {k.reservation_indices}"

        if try_permute_into_equivalent:
            assert not include_permutations
            new_grouped = {}
            pmgroups_remaining = {id(s) for s in pmapping_groups}
            # Greedy: biggest candidate groups first; each pmapping group is
            # claimed by the first key that still contains it.
            for c, g in sorted(grouped.items(), key=lambda x: len(x[1]), reverse=True):
                if not pmgroups_remaining:
                    break
                g = [
                    (s, loop_changes)
                    for s, loop_changes in g
                    if id(s) in pmgroups_remaining
                ]
                if g:
                    pmgroups_remaining -= {id(s) for s, _ in g}
                    permuted = [
                        PmappingGroup(s.compatibility.permute(lc), s.mappings)
                        for s, lc in g
                    ]
                    new_grouped[c] = permuted
            grouped = new_grouped

        return grouped

    @staticmethod
    def combine_combineable(
        pmapping_groups: list["PmappingGroup"],
        live_tensors: set[str] | Literal["All"],
        allow_different_compatibilies: bool = False,
        combine_reservations: bool = True,
        pbar_postfix: str = "",
    ) -> list["PmappingGroup"]:
        """Merge groups whose compatibilities match.

        Matching is done after clearing symbolic tile patterns and permuting
        loops into an equivalent order. Empty groups are dropped. If
        ``combine_reservations`` is False, groups whose dataframe reports
        reservations are passed through without being combined.
        """
        pmapping_groups = [s for s in pmapping_groups if len(s.mappings.data) > 0]
        no_combine = []
        if not combine_reservations:
            has_reservations = [s.mappings.has_reservations() for s in pmapping_groups]
            no_combine = [s for s, h in zip(pmapping_groups, has_reservations) if h]
            pmapping_groups = [
                s for s, h in zip(pmapping_groups, has_reservations) if not h
            ]
        groups = list(
            PmappingGroup._group(
                pmapping_groups,
                live_tensors,
                clear_symbolic_tile_patterns=True,
                try_permute_into_equivalent=True,
            ).values()
        )
        groups_with_one = [g[0] for g in groups if len(g) == 1]
        if len(groups_with_one) == len(groups):
            return groups_with_one + no_combine

        others = parallel(
            [
                delayed(PmappingGroup.concat)(g, allow_different_compatibilies)
                for g in groups
                if len(g) > 1
            ],
            pbar=f"Grouping pmappings{pbar_postfix}",
        )
        return groups_with_one + others + no_combine

    @staticmethod
    def filter_by_tensors(
        pmapping_groups: list["PmappingGroup"] | dict[Compatibility, Any],
        tensors: Iterable["TensorReservation"],
    ) -> list["PmappingGroup"] | dict[Compatibility, Any]:
        """Keep only groups (or dict keys) whose tensors agree with ``tensors``.

        A tensor conflicts with an entry of ``tensors`` when that entry's name
        is ``"*"`` or equals the tensor's name, and the two compare unequal.
        Items with any conflict are dropped.

        Raises:
            ValueError: if ``pmapping_groups`` is neither a list nor a dict.
        """

        def check(tensors_to_check):
            for t in tensors_to_check:
                for t2 in tensors:
                    if (t2.name == "*" or t.name == t2.name) and t != t2:
                        return False
            return True

        tensors = set(tensors)
        if isinstance(pmapping_groups, list):
            return [s for s in pmapping_groups if check(s.compatibility.tensors)]
        if isinstance(pmapping_groups, dict):
            return {k: v for k, v in pmapping_groups.items() if check(k.tensors)}
        raise ValueError(f"Invalid type {type(pmapping_groups)}")

    @staticmethod
    def group(
        pmapping_groups: list["PmappingGroup"], live_tensors: set[str]
    ) -> dict[Compatibility, list[tuple["PmappingGroup", list[int]]]]:
        """Group by compatibility over all equivalent loop permutations, with
        tile patterns and reservation indices cleared from the keys."""
        return PmappingGroup._group(
            pmapping_groups,
            live_tensors,
            clear_tile_patterns_and_reservation_indices=True,
            include_permutations=True,
        )

    @staticmethod
    def remove_dead_tensors(
        pmapping_groups: list["PmappingGroup"], live_tensors: set[str]
    ):
        """Delete, in place, every ``tensors`` entry not named in ``live_tensors``."""
        for s in pmapping_groups:
            for t in list(s.tensors):
                if t not in live_tensors:
                    del s.tensors[t]
@@ -0,0 +1,360 @@
1
+ from collections import defaultdict
2
+ import logging
3
+ from typing import List
4
+ from accelforge._accelerated_imports import np
5
+ from accelforge.frontend._workload_isl._symbolic import PartiallyRelevant, Relevant
6
+ import accelforge.frontend.arch as arch
7
+ from accelforge.frontend.arch import (
8
+ Comparison,
9
+ _MinUsageConstraintLambda,
10
+ _TileShapeConstraintLambda,
11
+ _LoopBoundsConstraintLambda,
12
+ _ConstraintLambda,
13
+ )
14
+ from accelforge.frontend.mapping import (
15
+ Loop,
16
+ MappingNode,
17
+ TensorHolder,
18
+ Temporal,
19
+ Spatial,
20
+ )
21
+ from accelforge.frontend.renames import TensorName
22
+ from accelforge.frontend.workload import EinsumName, RankVariable
23
+ from accelforge.util._setexpressions import InvertibleSet
24
+ from accelforge.util._frozenset import fzs
25
+
26
+
27
+ # =================================================================================================
28
+ # Attach constraints to mapping
29
+ # =================================================================================================
30
class MappingConstraints:
    """Constraint lambdas attached to one mapping, grouped by kind.

    Attributes:
        tile_shape_constraints: ``_TileShapeConstraintLambda`` instances.
        loop_bounds_constraints: ``_LoopBoundsConstraintLambda`` instances.
        min_usage_constraints: ``_MinUsageConstraintLambda`` instances keyed
            by ``(component_name, name)``, which is the key that
            ``check_min_usage_constraints`` looks up.
    """

    def __init__(self):
        self.tile_shape_constraints: list[_TileShapeConstraintLambda] = []
        self.loop_bounds_constraints: list[_LoopBoundsConstraintLambda] = []
        self.min_usage_constraints: dict[
            tuple[str, str], _MinUsageConstraintLambda
        ] = {}

    def get_all_constraints(self) -> list[_ConstraintLambda]:
        """Return a new list holding every constraint.

        The order is tile-shape, then loop-bounds, then min-usage (in dict
        order). ``pretty_str`` numbers constraints by their position here.
        """
        return (
            self.tile_shape_constraints
            + self.loop_bounds_constraints
            + list(self.min_usage_constraints.values())
        )

    def check_tile_shape_constraints(
        self, tile_shapes: np.ndarray, complete_indices: list[int]
    ):
        """Return a boolean mask over the rows of ``tile_shapes``.

        A row stays ``True`` only if every tile-shape constraint accepts it.
        Each constraint is called as
        ``c(complete_indices, tile_shapes[:, c._target_loop_indices])``, so
        ``set_loop_indices`` must have been called first.

        Args:
            tile_shapes: Array indexed as ``[row, loop_position]``, where
                loop positions are those assigned by ``set_loop_indices``
                (presumably one row per candidate; confirm against callers).
            complete_indices: Passed unchanged to every constraint.
        """
        # ``dtype=bool`` rather than ``np.bool``: NumPy removed the ``np.bool``
        # alias in 1.24 and only re-added it in 2.0, so ``np.bool`` raises
        # AttributeError on NumPy 1.24-1.26. ``bool`` works on all versions.
        mask = np.ones(tile_shapes.shape[0], dtype=bool)
        for c in self.tile_shape_constraints:
            mask = mask & c(complete_indices, tile_shapes[:, c._target_loop_indices])
        return mask

    def check_min_usage_constraints(
        self,
        component_name: str,
        name: str,
        usage: np.ndarray,
        complete_indices: list[int],
    ):
        """Evaluate the min-usage constraint registered for ``(component_name, name)``.

        Returns:
            The constraint's result for ``(complete_indices, usage)``, or an
            all-``True`` mask of length ``usage.shape[0]`` when no constraint
            is registered under that key.
        """
        key = (component_name, name)
        if key not in self.min_usage_constraints:
            # ``bool`` rather than ``np.bool`` (the alias is absent in NumPy
            # 1.24-1.26).
            return np.ones(usage.shape[0], dtype=bool)
        return self.min_usage_constraints[key](complete_indices, usage)

    def set_loop_indices(self, nodes: list[MappingNode]):
        """Record where each constraint's targets sit in ``nodes``.

        Sets ``_target_node_indices`` (positions in ``nodes``) and
        ``_target_loop_indices`` (positions among the ``Loop`` nodes only) on
        every constraint. Each min-usage constraint additionally gets, for
        each target, the index of the nearest loop above it whose rank
        variable is one of that constraint's target rank variables.

        NOTE(review): ``list.index`` matches by ``==``. If distinct mapping
        nodes can compare equal, the index of the first equal node is used.
        ``clear_constrained_to_one`` tracks nodes by ``id()`` instead.
        """
        loops = [n for n in nodes if isinstance(n, Loop)]
        for c in self.get_all_constraints():
            c._target_node_indices = [nodes.index(t) for t in c.target_mapping_nodes]
            c._target_loop_indices = [loops.index(t) for t in c.target_mapping_nodes]

        # Min usage constraints also depend on the loop ABOVE the target loop
        # because the loop above determines the number of tiles
        for c in self.min_usage_constraints.values():
            # Rank variables must be unique between mapping nodes
            rank_variables = set(t.rank_variable for t in c.target_mapping_nodes)
            assert len(rank_variables) == len(
                c.target_mapping_nodes
            ), "Rank variables must be unique between mapping nodes"

            for target_mapping_node in c.target_mapping_nodes:
                assert isinstance(target_mapping_node, Spatial)
                # Walk upward from the loop just above the target and stop at
                # the first loop over any of this constraint's rank variables.
                loop_index = loops.index(target_mapping_node) - 1
                while loop_index >= 0:
                    loop = loops[loop_index]
                    if loop.rank_variable in rank_variables:
                        c._target_loop_indices.append(loop_index)
                        c._target_node_indices.append(nodes.index(loop))
                        break
                    loop_index -= 1

    def clear_constrained_to_one(
        self, mapping: list["MappingNode"], einsum_name: EinsumName
    ) -> list["MappingNode"]:
        """Remove loops that a loop-bounds constraint pins to one.

        Loops targeted by a loop-bounds constraint whose
        ``constraint._constrained_to_one()`` is true are dropped from the
        mapping and from every remaining constraint's targets. The exception
        is a loop that a tile-shape constraint or a non-constrained-to-one
        loop-bounds constraint also targets. Such a loop is kept and a
        warning is logged. Afterwards, all constrained-to-one loop-bounds
        constraints are discarded.

        Args:
            mapping: Mapping nodes to filter. The list itself is not modified.
            einsum_name: Used only in the warning message.

        Returns:
            A new list containing the nodes of ``mapping`` that were not removed.
        """
        # Not constrained to one --> Can't remove
        node2constraints = defaultdict(list)
        do_not_remove = set()
        for c in self.tile_shape_constraints:
            for t in c.target_mapping_nodes:
                node2constraints[id(t)].append(c)
                do_not_remove.add(id(t))
        for c in self.loop_bounds_constraints:
            if not c.constraint._constrained_to_one():
                for t in c.target_mapping_nodes:
                    node2constraints[id(t)].append(c)
                    do_not_remove.add(id(t))

        # Constrained to one --> remove iff not in do_not_remove
        to_remove = set()
        for c in self.loop_bounds_constraints:
            if not c.constraint._constrained_to_one():
                continue
            my_remove = {id(t) for t in c.target_mapping_nodes}
            conflicting = my_remove & do_not_remove
            if conflicting:
                kept = [n for n in mapping if id(n) in conflicting]
                p = len(kept) == 1
                loops = ", ".join(n.compact_str() for n in kept).strip()
                isare = "is" if p else "are"
                # Gather the other constraints of the *conflicting* loops,
                # deduplicated, in mapping order. The previous code read a
                # stale loop variable here and reported unrelated constraints.
                others = []
                seen = set()
                for n in kept:
                    for c2 in node2constraints[id(n)]:
                        if c2 != c and id(c2) not in seen:
                            seen.add(id(c2))
                            others.append(c2)
                all_others = ", ".join(str(c2) for c2 in others)
                logging.warning(
                    f"For Einsum {einsum_name}, loop{'s' * (not p)} {loops} "
                    f"{isare} set to be removed by {c} and also appear{'s' * p} in "
                    f"{all_others}. The loop{'s' * (not p)} will not be removed "
                    f"from the mapping, but it may be subject to conflicting "
                    f"constraints."
                )
            # Only remove loops no other constraint depends on, as the comment
            # above and the warning state. Previously protected loops were
            # removed as well, which stripped them from tile-shape constraints.
            to_remove.update(my_remove - do_not_remove)

        # NOTE(review): the "== 1" constraints are dropped wholesale, so loops
        # kept above are no longer bound to one. Confirm this is intended.
        self.loop_bounds_constraints = [
            c
            for c in self.loop_bounds_constraints
            if not c.constraint._constrained_to_one()
        ]

        for c in self.get_all_constraints():
            c.target_mapping_nodes = [
                n for n in c.target_mapping_nodes if id(n) not in to_remove
            ]

        return [m for m in mapping if id(m) not in to_remove]

    def pretty_str(self) -> str:
        """Return a multi-line listing of all constraints grouped by kind.

        Each constraint is prefixed with its position in
        ``get_all_constraints()``.
        """
        s = ""
        all_constraints = self.get_all_constraints()
        s += "Tile shape constraints:\n"
        for c in self.tile_shape_constraints:
            s += f"\t{all_constraints.index(c)} {c.pretty_str()}\n"
        s += "Loop bounds constraints:\n"
        for c in self.loop_bounds_constraints:
            s += f"\t{all_constraints.index(c)} {c.pretty_str()}\n"
        s += "Min usage constraints:\n"
        for c in self.min_usage_constraints.values():
            s += f"\t{all_constraints.index(c)} {c.pretty_str()}\n"
        return s

    def remove_missing_targets(self, mapping: list[MappingNode]):
        """Drop targets that are absent from ``mapping``, then drop falsy constraints.

        Membership is tested with ``in``, which uses ``==``. A constraint is
        kept only if it is truthy. Presumably ``_ConstraintLambda`` derives
        its truthiness from its targets; confirm.
        """
        for c in self.get_all_constraints():
            c.target_mapping_nodes = [n for n in c.target_mapping_nodes if n in mapping]

        self.tile_shape_constraints = [c for c in self.tile_shape_constraints if c]
        self.loop_bounds_constraints = [c for c in self.loop_bounds_constraints if c]
        self.min_usage_constraints = {
            k: c for k, c in self.min_usage_constraints.items() if c
        }
170
+
171
+
172
+ def first_tensor_holder_index(mapping: list["MappingNode"], memory_name: str) -> int:
173
+ for i, m in enumerate(mapping):
174
+ if isinstance(m, TensorHolder) and m.component == memory_name:
175
+ return i
176
+ return None
177
+
178
+
179
def constrained_loops(
    mapping: list["MappingNode"],
    rank_variables: set[RankVariable],
    start_index: int | None = None,
    look_behind: bool = False,
    component: str | None = None,
    one_loop_per_rank_variable: bool = True,
) -> list[Loop]:
    """Collect the loops in ``mapping`` that iterate over ``rank_variables``.

    Args:
        mapping: Mapping nodes to scan.
        rank_variables: Rank variables of interest (copied into a set).
        start_index: Inclusive bound on the scanned indices. With
            ``look_behind`` only indices ``<= start_index`` are scanned;
            otherwise only indices ``>= start_index``. ``None`` scans all of
            ``mapping``.
        look_behind: If true, scan from the end of the range toward index 0
            (nearest-above first) instead of front to back.
        component: If given, only ``Spatial`` loops whose ``component``
            equals this value are considered.
        one_loop_per_rank_variable: If true, keep only the first loop found
            (in scan order) for each rank variable.

    Returns:
        The matching loops, in scan order.
    """
    # Strip the enumerate() indices on every path. Previously, with
    # look_behind=True and start_index=None, (index, node) tuples were left
    # in place. No tuple is a Loop, so the result was always empty.
    if look_behind:
        to_check = [
            node
            for i, node in reversed(list(enumerate(mapping)))
            if start_index is None or i <= start_index
        ]
    else:
        to_check = [
            node
            for i, node in enumerate(mapping)
            if start_index is None or i >= start_index
        ]

    nodes = []
    remaining_rank_variables = set(rank_variables)
    for node in to_check:
        if not isinstance(node, Loop):
            continue
        if component is not None and (
            not isinstance(node, Spatial) or node.component != component
        ):
            continue
        assert isinstance(node.rank_variable, RankVariable)
        if node.rank_variable in remaining_rank_variables:
            nodes.append(node)
            if one_loop_per_rank_variable:
                remaining_rank_variables.discard(node.rank_variable)
    # TODO: a commented-out assertion once checked that, when `component` is
    # given, every requested rank variable has a spatial loop. Its intent
    # was left unresolved by the original author.
    return nodes
219
+
220
+
221
def get_constraints(
    flattened_arch: list[arch.Leaf],
    mapping: List[MappingNode],
    symbol_table: dict[str, InvertibleSet],
    einsum_name: EinsumName,
    tensor_to_relevancy: dict[
        TensorName, dict[RankVariable, Relevant | PartiallyRelevant]
    ],
) -> tuple[List[MappingNode], MappingConstraints]:
    """Collect the constraints that apply to ``mapping``, then prune it.

    Constraints are gathered in three passes, in this order:

    1. For each ``arch.Memory`` that has a ``TensorHolder`` in ``mapping``:
       - its ``tensors.tile_shape`` constraints, bound to loops found by
         ``constrained_loops`` looking backwards from just before that
         memory's first ``TensorHolder``;
       - one ``== 1`` loop-bounds constraint over the loops yielded by
         ``_no_refetch_loops`` for each tensor in
         ``symbol_table[memory.name] & memory.tensors.no_refetch_from_above``.
    2. For each spatial dimension of every ``arch.Memory``/``arch.Fanout``:
       - its ``loop_bounds``, plus a ``== 1`` bound on the ``dim.reuse``
         rank variables when ``dim.reuse`` is set;
       - a min-usage constraint when ``dim.min_usage > 0``;
       - ``dim.may_reuse`` copied onto the matching ``Spatial`` loops.
    3. A ``== 1`` loop-bounds constraint for each ``Spatial`` loop whose
       ``_constrained_to_one`` flag is set.

    Returns:
        ``(mapping, constraints)``, where ``mapping`` is the result of
        ``constraints.clear_constrained_to_one``.
    """
    constraints = MappingConstraints()
    _add_memory_constraints(
        constraints, flattened_arch, mapping, symbol_table, tensor_to_relevancy
    )
    _add_spatial_constraints(constraints, flattened_arch, mapping)
    _add_constrained_to_one_spatial_constraints(constraints, mapping)
    return (
        constraints.clear_constrained_to_one(mapping, einsum_name),
        constraints,
    )


def _add_memory_constraints(
    constraints, flattened_arch, mapping, symbol_table, tensor_to_relevancy
):
    """Append tile-shape and no-refetch constraints for every memory in use."""
    memories = (leaf for leaf in flattened_arch if isinstance(leaf, arch.Memory))
    for memory in memories:
        holder_index = first_tensor_holder_index(mapping, memory.name)
        if holder_index is None:
            # This memory holds no tensor in the mapping.
            continue

        for tile_shape in memory.tensors.tile_shape:
            candidates = constrained_loops(
                mapping, tile_shape.expression, holder_index - 1, look_behind=True
            )
            for part in tile_shape._split_expression():
                constraints.tile_shape_constraints.append(
                    _TileShapeConstraintLambda(
                        tile_shape,
                        [loop for loop in candidates if loop.rank_variable in part],
                        part,
                    )
                )

        no_refetch = symbol_table[memory.name] & memory.tensors.no_refetch_from_above
        pinned_loops = []
        for singleton in no_refetch.iter_one_element_sets():
            tensor = next(iter(singleton))
            for loop in _no_refetch_loops(
                mapping, memory.name, tensor, tensor_to_relevancy
            ):
                if loop not in pinned_loops:
                    pinned_loops.append(loop)

        if pinned_loops:
            constraints.loop_bounds_constraints.append(
                _LoopBoundsConstraintLambda(
                    Comparison(expression=no_refetch, operator="==", value=1),
                    pinned_loops,
                    no_refetch,
                )
            )


def _no_refetch_loops(mapping, memory_name, tensor, tensor_to_relevancy):
    """Yield the loops between two holders of ``tensor`` that are not ``Relevant``.

    The range starts at the first ``TensorHolder`` holding ``tensor`` and
    stops before the first ``TensorHolder`` at or after it that holds
    ``tensor`` and whose ``component`` is ``memory_name``. Either bound
    falls back to ``len(mapping)`` when no such holder exists. A loop is
    yielded when ``tensor_to_relevancy[tensor][loop.rank_variable]`` is not
    a ``Relevant``.
    """
    end_of_mapping = len(mapping)
    first_holder = next(
        (
            i
            for i, node in enumerate(mapping)
            if isinstance(node, TensorHolder) and tensor in node.tensors
        ),
        end_of_mapping,
    )
    memory_holder = next(
        (
            i
            for i in range(first_holder, end_of_mapping)
            if isinstance(mapping[i], TensorHolder)
            and tensor in mapping[i].tensors
            and mapping[i].component == memory_name
        ),
        end_of_mapping,
    )
    for node in mapping[first_holder:memory_holder]:
        if isinstance(node, Loop) and not isinstance(
            tensor_to_relevancy[tensor][node.rank_variable], Relevant
        ):
            yield node


def _add_spatial_constraints(constraints, flattened_arch, mapping):
    """Append loop-bounds and min-usage constraints for every spatial dimension."""
    for leaf in flattened_arch:
        if not isinstance(leaf, (arch.Memory, arch.Fanout)):
            continue

        for dim in leaf.spatial:
            dim_loops = [
                node
                for node in mapping
                if isinstance(node, Spatial)
                and (node.component, node.name) == (leaf.name, dim.name)
            ]

            bounds = list(dim.loop_bounds)
            if dim.reuse:
                reuse_bound = Comparison(
                    expression=dim.reuse.rank_variables,
                    operator="==",
                    value=1,
                )
                bounds.append(reuse_bound)
                reuse_bound._str_repr = f"reuse {set(dim.reuse)}"

            # Loop bounds constraints
            for bound in bounds:
                # NOTE(review): the result of this call is discarded; targets
                # below come from `dim_loops`. Apart from the isinstance
                # assertion inside constrained_loops it has no effect.
                # Confirm which list was intended.
                constrained_loops(dim_loops, bound.expression, component=leaf.name)
                for part in bound._split_expression():
                    constraints.loop_bounds_constraints.append(
                        _LoopBoundsConstraintLambda(
                            bound,
                            [loop for loop in dim_loops if loop.rank_variable in part],
                            part,
                        )
                    )

            # Min usage constraints
            usage_targets = [
                node
                for node in mapping
                if isinstance(node, Spatial)
                and node.component == leaf.name
                and node.name == dim.name
            ]
            if dim.min_usage > 0 and usage_targets:
                constraints.min_usage_constraints[(leaf.name, dim.name)] = (
                    _MinUsageConstraintLambda(
                        usage_targets,
                        {node.rank_variable for node in usage_targets},
                        dim.min_usage,
                    )
                )

            for node in usage_targets:
                node._may_reuse = dim.may_reuse


def _add_constrained_to_one_spatial_constraints(constraints, mapping):
    """Add a ``== 1`` bound for every ``Spatial`` loop flagged ``_constrained_to_one``."""
    flagged = (
        node
        for node in mapping
        if isinstance(node, Spatial) and node._constrained_to_one
    )
    for spatial in flagged:
        constraints.loop_bounds_constraints.append(
            _LoopBoundsConstraintLambda(
                Comparison(expression=spatial.rank_variable, operator="==", value=1),
                [spatial],
                spatial.rank_variable,
            )
        )
@@ -0,0 +1 @@
1
+ from .make_pmapping_templates import make_pmapping_templates