accelforge 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (258)
  1. accelforge/__init__.py +21 -0
  2. accelforge/_accelerated_imports.py +16 -0
  3. accelforge/_deprecate/_simanneal/evalmapping.py +271 -0
  4. accelforge/_deprecate/_simanneal/mapspaceglobals.py +298 -0
  5. accelforge/_deprecate/_simanneal/simanneal.py +666 -0
  6. accelforge/_deprecate/_simanneal/tracking.py +105 -0
  7. accelforge/_deprecate/_simanneal/wrappers.py +218 -0
  8. accelforge/_deprecate/_simanneal2/__init__.py +7 -0
  9. accelforge/_deprecate/_simanneal2/simanneal.py +493 -0
  10. accelforge/_deprecate/_simanneal2/tracking.py +116 -0
  11. accelforge/_deprecate/compatibility_util.py +181 -0
  12. accelforge/_deprecate/layerdeduplication/__init__.py +2 -0
  13. accelforge/_deprecate/layerdeduplication/group_similar_einsums.py +160 -0
  14. accelforge/_deprecate/layerdeduplication/grouped_einsums.py +84 -0
  15. accelforge/_deprecate/mapping_filter_tags/__init__.py +2 -0
  16. accelforge/_deprecate/mapping_filter_tags/ffmt.py +212 -0
  17. accelforge/_deprecate/mapping_filter_tags/onesplit.py +24 -0
  18. accelforge/_deprecate/mapping_filter_tags/util.py +24 -0
  19. accelforge/_deprecate/tags.py +69 -0
  20. accelforge/_deprecate/viz/__init__.py +0 -0
  21. accelforge/_deprecate/viz/interactive.py +159 -0
  22. accelforge/_deprecate/viz/reservationtree.py +307 -0
  23. accelforge/_deprecate/viz/ski_slope.py +88 -0
  24. accelforge/_version.py +15 -0
  25. accelforge/examples.py +39 -0
  26. accelforge/frontend/__init__.py +10 -0
  27. accelforge/frontend/_binding.py +129 -0
  28. accelforge/frontend/_workload_isl/__init__.py +2 -0
  29. accelforge/frontend/_workload_isl/_isl.py +149 -0
  30. accelforge/frontend/_workload_isl/_symbolic.py +141 -0
  31. accelforge/frontend/arch copy.py +1544 -0
  32. accelforge/frontend/arch.py +1642 -0
  33. accelforge/frontend/config.py +63 -0
  34. accelforge/frontend/mapper/__init__.py +5 -0
  35. accelforge/frontend/mapper/ffm.py +126 -0
  36. accelforge/frontend/mapper/mapper.py +7 -0
  37. accelforge/frontend/mapper/metrics.py +30 -0
  38. accelforge/frontend/mapping/__init__.py +1 -0
  39. accelforge/frontend/mapping/mapping.py +1736 -0
  40. accelforge/frontend/model.py +14 -0
  41. accelforge/frontend/renames.py +150 -0
  42. accelforge/frontend/spec copy.py +230 -0
  43. accelforge/frontend/spec.py +301 -0
  44. accelforge/frontend/variables.py +12 -0
  45. accelforge/frontend/workload.py +952 -0
  46. accelforge/mapper/FFM/__init__.py +9 -0
  47. accelforge/mapper/FFM/_join_pmappings/__init__.py +0 -0
  48. accelforge/mapper/FFM/_join_pmappings/compatibility.py +653 -0
  49. accelforge/mapper/FFM/_join_pmappings/compress_pmappings.py +140 -0
  50. accelforge/mapper/FFM/_join_pmappings/join_pmappings.py +703 -0
  51. accelforge/mapper/FFM/_join_pmappings/pmapping_dataframe.py +901 -0
  52. accelforge/mapper/FFM/_join_pmappings/pmapping_group.py +337 -0
  53. accelforge/mapper/FFM/_make_pmappings/contraints/__init__.py +0 -0
  54. accelforge/mapper/FFM/_make_pmappings/contraints/constraints.py +360 -0
  55. accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/__init__.py +1 -0
  56. accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_loops.py +373 -0
  57. accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_pmapping_templates.py +463 -0
  58. accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_reservations.py +95 -0
  59. accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_storage_order.py +382 -0
  60. accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_storages.py +155 -0
  61. accelforge/mapper/FFM/_make_pmappings/make_pmappings.py +411 -0
  62. accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/__init__.py +1 -0
  63. accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/make_pmappings_from_templates.py +407 -0
  64. accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/make_tile_shapes.py +1681 -0
  65. accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/run_model.py +170 -0
  66. accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/symbol_relations.py +174 -0
  67. accelforge/mapper/FFM/_make_pmappings/pmapper_job.py +282 -0
  68. accelforge/mapper/FFM/_pareto_df/df_convention.py +273 -0
  69. accelforge/mapper/FFM/_pareto_df/pareto copy.py +836 -0
  70. accelforge/mapper/FFM/_pareto_df/pareto.py +508 -0
  71. accelforge/mapper/FFM/data.py +61 -0
  72. accelforge/mapper/FFM/main copy.py +236 -0
  73. accelforge/mapper/FFM/main.py +208 -0
  74. accelforge/mapper/FFM/mappings.py +510 -0
  75. accelforge/mapper/FFM/pmappings.py +310 -0
  76. accelforge/mapper/__init__.py +4 -0
  77. accelforge/mapper.py +0 -0
  78. accelforge/model/__init__.py +1 -0
  79. accelforge/model/_looptree/__init__.py +0 -0
  80. accelforge/model/_looptree/accesses.py +335 -0
  81. accelforge/model/_looptree/capacity/__init__.py +1 -0
  82. accelforge/model/_looptree/capacity/aggregators.py +36 -0
  83. accelforge/model/_looptree/capacity/capacity.py +47 -0
  84. accelforge/model/_looptree/energy.py +150 -0
  85. accelforge/model/_looptree/equivalent_ranks.py +29 -0
  86. accelforge/model/_looptree/latency/__init__.py +1 -0
  87. accelforge/model/_looptree/latency/latency.py +98 -0
  88. accelforge/model/_looptree/latency/memory.py +120 -0
  89. accelforge/model/_looptree/latency/processors.py +92 -0
  90. accelforge/model/_looptree/mapping_utilities.py +71 -0
  91. accelforge/model/_looptree/reuse/__init__.py +4 -0
  92. accelforge/model/_looptree/reuse/isl/__init__.py +1 -0
  93. accelforge/model/_looptree/reuse/isl/des.py +59 -0
  94. accelforge/model/_looptree/reuse/isl/isl_functions.py +374 -0
  95. accelforge/model/_looptree/reuse/isl/mapping_to_isl/__init__.py +4 -0
  96. accelforge/model/_looptree/reuse/isl/mapping_to_isl/analyze_mapping.py +297 -0
  97. accelforge/model/_looptree/reuse/isl/mapping_to_isl/skews_from_mapping.py +236 -0
  98. accelforge/model/_looptree/reuse/isl/mapping_to_isl/tiling.py +685 -0
  99. accelforge/model/_looptree/reuse/isl/mapping_to_isl/types.py +188 -0
  100. accelforge/model/_looptree/reuse/isl/spatial.py +260 -0
  101. accelforge/model/_looptree/reuse/isl/temporal.py +182 -0
  102. accelforge/model/_looptree/reuse/symbolic/__init__.py +1 -0
  103. accelforge/model/_looptree/reuse/symbolic/symbolic copy 2.py +1346 -0
  104. accelforge/model/_looptree/reuse/symbolic/symbolic copy.py +1408 -0
  105. accelforge/model/_looptree/reuse/symbolic/symbolic.py +1396 -0
  106. accelforge/model/_looptree/run.py +122 -0
  107. accelforge/model/_looptree/types.py +26 -0
  108. accelforge/model/_looptree/visualization/__init__.py +0 -0
  109. accelforge/model/_looptree/visualization/occupancy.py +11 -0
  110. accelforge/model/main.py +222 -0
  111. accelforge/plotting/__init__.py +2 -0
  112. accelforge/plotting/mappings.py +219 -0
  113. accelforge/plotting/specs.py +57 -0
  114. accelforge/util/__init__.py +4 -0
  115. accelforge/util/_base_analysis_types.py +24 -0
  116. accelforge/util/_basetypes.py +1089 -0
  117. accelforge/util/_frozenset.py +36 -0
  118. accelforge/util/_isl.py +29 -0
  119. accelforge/util/_itertools.py +14 -0
  120. accelforge/util/_mathfuncs.py +57 -0
  121. accelforge/util/_parse_expressions.py +339 -0
  122. accelforge/util/_picklecache.py +32 -0
  123. accelforge/util/_setexpressions.py +268 -0
  124. accelforge/util/_sympy/__init__.py +0 -0
  125. accelforge/util/_sympy/broadcast_max.py +18 -0
  126. accelforge/util/_visualization.py +112 -0
  127. accelforge/util/_yaml.py +579 -0
  128. accelforge/util/parallel.py +193 -0
  129. accelforge-0.0.1.dist-info/METADATA +64 -0
  130. accelforge-0.0.1.dist-info/RECORD +258 -0
  131. accelforge-0.0.1.dist-info/WHEEL +5 -0
  132. accelforge-0.0.1.dist-info/licenses/LICENSE +19 -0
  133. accelforge-0.0.1.dist-info/top_level.txt +5 -0
  134. docs/_build/html/_sources/fastfusion.frontend.mapper.rst.txt +37 -0
  135. docs/_build/html/_sources/fastfusion.frontend.rst.txt +70 -0
  136. docs/_build/html/_sources/fastfusion.frontend.workload.rst.txt +21 -0
  137. docs/_build/html/_sources/fastfusion.mapper.FFM.rst.txt +37 -0
  138. docs/_build/html/_sources/fastfusion.mapper.rst.txt +18 -0
  139. docs/_build/html/_sources/fastfusion.rst.txt +20 -0
  140. docs/_build/html/_sources/fastfusion.util.rst.txt +21 -0
  141. docs/_build/html/_sources/index.rst.txt +87 -0
  142. docs/_build/html/_sources/modules.rst.txt +7 -0
  143. docs/_build/html/_sources/notes/citation.rst.txt +45 -0
  144. docs/_build/html/_sources/notes/definitions.rst.txt +43 -0
  145. docs/_build/html/_sources/notes/faqs.rst.txt +39 -0
  146. docs/_build/html/_sources/notes/modeling/accelerator_energy_latency.rst.txt +72 -0
  147. docs/_build/html/_sources/notes/modeling/component_energy_area.rst.txt +96 -0
  148. docs/_build/html/_sources/notes/modeling/mapping.rst.txt +100 -0
  149. docs/_build/html/_sources/notes/modeling.rst.txt +33 -0
  150. docs/_build/html/_sources/notes/parsing/arithmetic_parsing.rst.txt +136 -0
  151. docs/_build/html/_sources/notes/parsing/setexpressions.rst.txt +63 -0
  152. docs/_build/html/_sources/notes/parsing/yaml_parsing.rst.txt +176 -0
  153. docs/_build/html/_sources/notes/quickstart_and_installation.rst.txt +9 -0
  154. docs/_build/html/_sources/notes/spec/architecture.rst.txt +133 -0
  155. docs/_build/html/_sources/notes/spec/mapping.rst.txt +12 -0
  156. docs/_build/html/_sources/notes/spec/workload.rst.txt +83 -0
  157. docs/_build/html/_sources/notes/spec.rst.txt +36 -0
  158. docs/source/_ext/include_attrs.py +213 -0
  159. docs/source/_ext/include_docstring.py +364 -0
  160. docs/source/_ext/include_functions.py +154 -0
  161. docs/source/_ext/include_notebook.py +131 -0
  162. docs/source/_ext/include_yaml.py +119 -0
  163. docs/source/_ext/inherited_attributes.py +222 -0
  164. docs/source/_ext/paths.py +4 -0
  165. docs/source/conf.py +79 -0
  166. examples/arches/compute_in_memory/_include.yaml +74 -0
  167. examples/arches/compute_in_memory/_include_functions.py +229 -0
  168. examples/arches/compute_in_memory/_load_spec.py +57 -0
  169. examples/arches/compute_in_memory/components/c2c_multiplier.py +181 -0
  170. examples/arches/compute_in_memory/components/dac_c2c_r2r.py +605 -0
  171. examples/arches/compute_in_memory/components/misc.py +195 -0
  172. examples/arches/compute_in_memory/components/util/bit_functions.py +51 -0
  173. examples/arches/compute_in_memory/components/zero_comparator.py +92 -0
  174. examples/arches/compute_in_memory/isaac.yaml +233 -0
  175. examples/arches/compute_in_memory/memory_cells/ecram_demo.yaml +63 -0
  176. examples/arches/compute_in_memory/memory_cells/rram_example.yaml +63 -0
  177. examples/arches/compute_in_memory/memory_cells/rram_isaac_isca_2016.yaml +64 -0
  178. examples/arches/compute_in_memory/memory_cells/rram_neurosim_default.yaml +63 -0
  179. examples/arches/compute_in_memory/memory_cells/rram_raella_isca_2023.yaml +70 -0
  180. examples/arches/compute_in_memory/memory_cells/rram_wan_nature_2022.yaml +63 -0
  181. examples/arches/compute_in_memory/memory_cells/sram_colonnade_jssc_2021.yaml +63 -0
  182. examples/arches/compute_in_memory/memory_cells/sram_example.yaml +63 -0
  183. examples/arches/compute_in_memory/memory_cells/sram_jia_jssc_2020.yaml +63 -0
  184. examples/arches/compute_in_memory/memory_cells/sram_sinangil_jssc_2021.yaml +63 -0
  185. examples/arches/compute_in_memory/memory_cells/sram_wang_vlsi_2022.yaml +63 -0
  186. examples/arches/compute_in_memory/wang_vlsi_2022.yaml +289 -0
  187. examples/arches/eyeriss.yaml +68 -0
  188. examples/arches/fanout_variations/at_glb.yaml +31 -0
  189. examples/arches/fanout_variations/at_glb_with_fanout_node.yaml +34 -0
  190. examples/arches/fanout_variations/at_mac.yaml +31 -0
  191. examples/arches/fanout_variations/at_mac_with_constraints.yaml +38 -0
  192. examples/arches/fanout_variations/at_mac_with_fanout_node.yaml +34 -0
  193. examples/arches/nvdla.yaml +47 -0
  194. examples/arches/simple.yaml +28 -0
  195. examples/arches/tpu_v4i.yaml +67 -0
  196. examples/mappings/unfused_matmuls_to_simple.yaml +33 -0
  197. examples/misc/component_annotated.yaml +33 -0
  198. examples/workloads/gpt3_6.7B.yaml +124 -0
  199. examples/workloads/matmuls.yaml +20 -0
  200. examples/workloads/mobilenet_28.yaml +81 -0
  201. examples/workloads/mobilenet_various_separate.yaml +106 -0
  202. examples/workloads/three_matmuls_annotated.yaml +59 -0
  203. notebooks/.ipynb_checkpoints/fastfusion_arch_study_michael-checkpoint.ipynb +359 -0
  204. notebooks/compute_in_memory/_scripts.py +339 -0
  205. notebooks/compute_in_memory/isaac.guide.ipynb +270 -0
  206. notebooks/compute_in_memory/wang_vlsi_2022.ipynb +602 -0
  207. notebooks/paths.py +4 -0
  208. notebooks/tutorials/.ipynb_checkpoints/1_FFM-checkpoint.ipynb +3110 -0
  209. notebooks/tutorials/FFM.ipynb +3498 -0
  210. notebooks/tutorials/_include.py +48 -0
  211. notebooks/tutorials/component_energy_area.ipynb +363 -0
  212. tests/Q_mapping.yaml +38 -0
  213. tests/__init__.py +0 -0
  214. tests/conv.mapping.yaml +27 -0
  215. tests/conv.workload.yaml +13 -0
  216. tests/conv_sym.mapping.yaml +43 -0
  217. tests/copy.mapping.yaml +35 -0
  218. tests/copy.workload.yaml +15 -0
  219. tests/distribuffers/__init__.py +0 -0
  220. tests/distribuffers/multicast/test_cases.yaml +482 -0
  221. tests/distribuffers/spec/binding/valid_bindings.yaml +97 -0
  222. tests/distribuffers/spec/distributed.yaml +100 -0
  223. tests/distribuffers/spec/logical_arch.yaml +32 -0
  224. tests/distribuffers/spec/physical_arch.yaml +69 -0
  225. tests/distribuffers/test_binding.py +48 -0
  226. tests/frontend/__init__.py +0 -0
  227. tests/frontend/test_mapping_viz.py +52 -0
  228. tests/mapper/__init__.py +0 -0
  229. tests/mapper/configs/conv1d/conv1d.mapping.yaml +31 -0
  230. tests/mapper/configs/conv1d/conv1d.workload.yaml +11 -0
  231. tests/mapper/configs/two_conv1d/two_conv1d.expected.yaml +38 -0
  232. tests/mapper/configs/two_conv1d/two_conv1d.mapping.yaml +54 -0
  233. tests/mapper/configs/two_conv1d/two_conv1d.workload.yaml +19 -0
  234. tests/mapper/test_mapping_to_isl.py +90 -0
  235. tests/mapper/test_spatial_reuse_analysis.py +67 -0
  236. tests/mapper/test_temporal_reuse_analysis.py +56 -0
  237. tests/mapper/util.py +58 -0
  238. tests/matmul.mapping.yaml +29 -0
  239. tests/matmul.workload.yaml +12 -0
  240. tests/matmul_spatial.mapping.yaml +44 -0
  241. tests/mha.renames.yaml +65 -0
  242. tests/mha.workload.yaml +67 -0
  243. tests/mha.yaml +59 -0
  244. tests/mha_full.workload.yaml +67 -0
  245. tests/mobilenet.workload.yaml +35 -0
  246. tests/mobilenet_long.workload.yaml +64 -0
  247. tests/pmappingcache.py +24 -0
  248. tests/processing_stage.arch.yaml +40 -0
  249. tests/snowcat.arch.yaml +36 -0
  250. tests/test_ffm_join_pmappings.py +106 -0
  251. tests/test_ffm_make_pmappings.py +82 -0
  252. tests/test_ffm_make_tile_shapes.py +49 -0
  253. tests/test_mapper.py +100 -0
  254. tests/test_model.py +37 -0
  255. tests/test_plotting.py +72 -0
  256. tests/test_processing_stage.py +46 -0
  257. tests/test_symbolic_model.py +248 -0
  258. tests/test_workload.py +141 -0
@@ -0,0 +1,105 @@
1
+ import time
2
+
3
+
4
class EvaluationsScoreTracker:
    """Track how many mappings have been evaluated and the best score so far.

    ``history`` is a list of ``(evaluations, best_score)`` pairs that starts
    with the sentinel ``(0, inf)``. Runs of equal scores are collapsed so only
    the first and last entry of each run are kept.

    Two optional stopping criteria are enforced by :meth:`add_evaluation`: a
    maximum number of evaluations and a target score. When either is met,
    ``StopIteration`` is raised.
    """

    def __init__(
        self, max_evaluations: int, stop_at_score: float, print_period: int = 10
    ):
        """
        Args:
            max_evaluations: Stop once the evaluation count exceeds this value.
                ``None`` disables the limit.
            stop_at_score: Stop once the best score drops strictly below this
                value. ``None`` disables the target.
            print_period: Minimum number of seconds between progress prints.
        """
        self.max_evaluations = max_evaluations
        self.stop_at_score = stop_at_score
        self.evaluations = 0
        self.score = float("inf")
        self.history = [(0, float("inf"))]
        # Factor applied to every count passed to add_evaluation.
        self._scale_by = 1
        self.print_period = print_period
        self.prev_print_time = None
        # Ensures the "Stopping due to ..." message is printed only once.
        self.print_stopped_text = False
        # Not read anywhere in this class; presumably filled in by callers.
        self.n_mappings = {}
        self.runtime = {}

    def add_evaluation(self, n_evaluations: int, best_score: float):
        """Record ``n_evaluations`` new evaluations whose best score is ``best_score``.

        The count is multiplied by the current scale factor (see
        :meth:`multiply_scale_by`). Progress is printed at most once every
        ``print_period`` seconds.

        Returns:
            False when the search should continue.

        Raises:
            StopIteration: If the evaluation count exceeds ``max_evaluations``
                or the best score is below ``stop_at_score``.
        """
        self.evaluations += n_evaluations * self._scale_by
        self.score = min(self.score, best_score)
        # self.score never increases, so if the entry two back already holds
        # the current score, the last entry is an interior point of a run of
        # equal scores and can be dropped.
        if len(self.history) > 2 and self.history[-2][1] == self.score:
            self.history.pop(-1)
        self.history.append((self.evaluations, self.score))

        cur_time = time.time()
        if (
            self.prev_print_time is None
            or cur_time - self.prev_print_time > self.print_period
        ):
            self.prev_print_time = cur_time
            print(f"Evaluations: {self.evaluations}, Score: {self.score}")

        if self.max_evaluations is not None and self.evaluations > self.max_evaluations:
            self._stop(
                f"Stopping due to evaluations {self.evaluations} > {self.max_evaluations}"
            )
        if self.stop_at_score is not None and self.score < self.stop_at_score:
            self._stop(f"Stopping due to score {self.score} < {self.stop_at_score}")
        return False

    def _stop(self, message: str):
        """Compact the history, print ``message`` once, and raise StopIteration."""
        self.clean_history()
        if not self.print_stopped_text:
            print(message)
        self.print_stopped_text = True
        raise StopIteration

    def multiply_scale_by(self, scale_by: float):
        """Multiply the factor applied to future ``add_evaluation`` counts."""
        self._scale_by *= scale_by

    def __repr__(self):
        return f"Evaluations: {self.evaluations}, Score: {self.score}"

    def __str__(self):
        return f"Evaluations: {self.evaluations}, Score: {self.score}"

    def clean_history(self):
        """Drop the interior entries of each run of equal scores in ``history``.

        The first and last entries are always kept, as is every entry whose
        score differs from either neighbour's.
        """
        # With two or fewer entries there is nothing interior to drop. The
        # general path below would otherwise duplicate a lone entry (index 0
        # is both first and last) or index into an empty list.
        if len(self.history) <= 2:
            return
        history = self.history
        interior = [
            history[i]
            for i in range(1, len(history) - 1)
            if history[i][1] != history[i - 1][1]
            or history[i][1] != history[i + 1][1]
        ]
        self.history = [history[0], *interior, history[-1]]

    def merge_with(self, other: "EvaluationsScoreTracker"):
        """Merge ``other`` into this tracker as if both had run concurrently.

        Evaluation counts are summed and the best score becomes the minimum of
        the two. The histories are interleaved by each tracker's own
        evaluation count (ties take ``other`` first). Each step adds that
        entry's evaluation delta to a running total and keeps the running
        minimum score.
        """
        self.score = min(self.score, other.score)
        self.evaluations += other.evaluations

        i, j = 1, 1
        history = [(0, float("inf"))]
        cur_score = float("inf")
        cur_evaluations = 0
        while i < len(self.history) or j < len(other.history):
            # Grab whichever history has the lowest evaluations next.
            take_self = i < len(self.history) and (
                j == len(other.history) or self.history[i][0] < other.history[j][0]
            )
            if take_self:
                source, k = self.history, i
                i += 1
            else:
                source, k = other.history, j
                j += 1
            cur_evaluations += source[k][0] - source[k - 1][0]
            cur_score = min(cur_score, source[k][1])
            history.append((cur_evaluations, cur_score))
        self.history = history
        self.clean_history()

    def increase_all_evaluations(self, n_evaluations: int):
        """Shift the evaluation count and every history entry by ``n_evaluations``."""
        self.evaluations += n_evaluations
        self.history = [(e + n_evaluations, s) for e, s in self.history]
@@ -0,0 +1,218 @@
1
+ from collections import defaultdict
2
+ import copy
3
+ import itertools
4
+ import time
5
+ from joblib import delayed
6
+ from fastfusion._accelerated_imports import pd
7
+ from fastfusion.frontend import arch
8
+ from fastfusion.frontend.spec import Spec
9
+ from fastfusion.mapper.FFM._join_pmappings.sim import PmappingGroup, Loop, Compatibility
10
+ from fastfusion.mapper.FFM._join_pmappings.pmapping_group import (
11
+ PmappingDataframe,
12
+ is_reservation_col,
13
+ )
14
+ from fastfusion.mapper.simanneal.simanneal import MapspaceGlobals, _fuse_sims
15
+ from fastfusion.mapper.simanneal.tracking import EvaluationsScoreTracker
16
+ from fastfusion.util._frozenset import fzs
17
+ from fastfusion.util.parallel import parallel, util
18
+
19
+
20
def mapping2sims(einsum_to_result: Compatibility):
    """Turn per-Einsum result dictionaries into lists of ``PmappingGroup``.

    NOTE(review): annotated as ``Compatibility`` but iterated as a mapping of
    einsum name -> {key: mapping data}; confirm the intended type.

    Returns:
        One list per Einsum, in the iteration order of ``einsum_to_result``.
        Each inner list has one ``paretofy(key, data)`` result per entry.
    """
    groups_by_einsum = {
        einsum_name: [paretofy(key, data) for key, data in compat_dict.items()]
        for einsum_name, compat_dict in einsum_to_result.items()
    }
    return list(groups_by_einsum.values())
25
+
26
+
27
def paretofy(k, v):
    """Wrap raw mapping data ``v`` in a ``PmappingGroup`` keyed by ``k``.

    ``v`` is turned into a DataFrame with NaNs replaced by 0, then wrapped in a
    ``PmappingDataframe``. NOTE(review): no Pareto filtering happens here
    unless ``PmappingDataframe`` does it itself; confirm.
    """
    frame = pd.DataFrame(v).fillna(0)
    wrapped = PmappingDataframe(frame)
    return PmappingGroup(k, wrapped)
29
+
30
+
31
def get_possible_translations(
    t: Compatibility,
    pairwise_equivalent_rank_variables: dict[str, set[str]],
    full_equivalent_rank_variables: dict[str, set[str]],
    right_rank_variables: set[str],
):
    """Yield every way to re-express ``t``'s loops in the next Einsum's rank variables.

    Fused ranks should be transitive. However, if a fused loop indexes into two
    different ranks in the next Einsum, we can't fuse, because the loop would
    tile in multiple directions.

    For each loop:
    - The union over ``full_equivalent_rank_variables`` gives the rank
      variables the loop CAN fuse with in the next Einsum.
    - The union over ``pairwise_equivalent_rank_variables`` gives the rank
      variables it MUST index into.

    If a loop must index into more than one of ``right_rank_variables``, it
    aliases into multiple ranks and cannot be translated; then nothing is
    yielded at all. Otherwise each compatible rank variable is tried, and the
    cartesian product over all loops is yielded as ``t.update(loops=...)``.

    Args:
        t: Compatibility whose ``loops`` are translated.
        pairwise_equivalent_rank_variables: Direct rank-variable equivalences.
        full_equivalent_rank_variables: Transitive closure of the above.
        right_rank_variables: Rank variables of the next Einsum.

    Yields:
        Updated copies of ``t`` with translated loops.
    """

    def translate_loop(loop: Loop):
        # set().union(...) rather than set.union(...): the unbound form raises
        # TypeError when the loop has no rank variable names.
        compatible_rank_variables = (
            set().union(
                *(full_equivalent_rank_variables[n] for n in loop.rank_variable_names)
            )
            & right_rank_variables
        )
        pairwise_compatible_rank_variables = (
            set().union(
                *(
                    pairwise_equivalent_rank_variables[n]
                    for n in loop.rank_variable_names
                )
            )
            & right_rank_variables
        )
        if len(pairwise_compatible_rank_variables) > 1:
            # Aliases into several ranks: yield nothing, which empties the
            # product below.
            return
        for n in compatible_rank_variables:
            yield Loop(fzs((n,)), loop.bound, loop.is_spatial)

    for loops in itertools.product(*map(translate_loop, t.loops)):
        yield t.update(loops=loops)
66
+
67
+
68
# Module-level timing state used by init_print_time / print_time /
# print_total_time: the timestamp of the last checkpoint, and the seconds
# accumulated per label.
prev_time, total_time = 0, defaultdict(int)
70
+
71
+
72
def init_print_time():
    """Reset the module timers: start a checkpoint now and clear all totals."""
    global prev_time
    global total_time
    prev_time = time.time()
    total_time = defaultdict(int)
76
+
77
+
78
def print_time(what: str):
    """Print and accumulate the seconds elapsed since the previous checkpoint.

    The elapsed time is added to ``total_time[what]``. The checkpoint is reset
    only after printing, so the time spent printing is not charged to the
    next label.
    """
    global prev_time
    elapsed = time.time() - prev_time
    print(f"{what}: {elapsed:.2f} seconds")
    total_time[what] += elapsed
    prev_time = time.time()
84
+
85
+
86
def print_total_time():
    """Print each accumulated per-label time followed by the grand total.

    When the total exceeds 60 seconds it is also shown in minutes.
    """
    print("\n======== Total time ========")
    for label, seconds in total_time.items():
        print(f"{label}: {seconds:.2f} seconds")
    total = sum(total_time.values())
    in_minutes = f" ({total/60:.2f} minutes)" if total > 60 else ""
    print(f"\nTotal: {total:.2f} seconds{in_minutes}")
    print("============================\n")
96
+
97
+
98
class PmappingsOneEinsum:
    """The pmapping groups generated for a single Einsum.

    ``obj[i]`` returns ``pmapping_groups[i]``. ``tensor_names`` is copied from
    the first group, so construction raises IndexError for an empty list.
    """

    def __init__(self, einsum_name: str, pm_group_list: list[PmappingGroup]):
        self.einsum_name: str = einsum_name
        self.pmapping_groups: list[PmappingGroup] = pm_group_list
        # NOTE(review): assumes every group shares the first group's tensors.
        first_group = pm_group_list[0]
        self.tensor_names: set[str] = set(first_group.tensor_names)

    def __getitem__(self, i):
        """Return the pmapping group at index ``i``."""
        return self.pmapping_groups[i]
106
+
107
+
108
def make_full_equivalent_rank_variables(pairwise_equivalent_rank_variables):
    """Return the transitive closure of the pairwise rank-variable equivalences.

    For every key, the result holds every rank variable reachable through one
    or more pairwise steps. The input mapping and its sets are not modified;
    key order is preserved.

    Raises:
        KeyError: If a reachable rank variable is not itself a key.
    """
    closure = {
        name: set(equivalents)
        for name, equivalents in pairwise_equivalent_rank_variables.items()
    }
    # Keep folding neighbours-of-neighbours in until a full pass adds nothing.
    while True:
        grew = False
        for name, equivalents in closure.items():
            for neighbor in list(equivalents):
                missing = closure[neighbor] - equivalents
                if missing:
                    equivalents |= missing
                    grew = True
        if not grew:
            return closure
123
+
124
+
125
def get_pmappings_data(
    pmapping_groups: dict[str, list[PmappingGroup]],
    evaluations_tracker,
    spec: Spec = None,
    flattened_architecture: list[arch.Leaf] = None,
):
    """Gather the inputs needed for joining pmappings into one tuple.

    Although ``spec`` defaults to ``None``, it is effectively required: its
    workload is always queried, so ``None`` leads to an AttributeError. When
    ``flattened_architecture`` is not given (or is falsy), it is taken from
    ``spec.get_flattened_architecture()``.

    Returns:
        ``(pmapping_groups, evaluations_tracker, spec, flattened_architecture,
        resource2capacity, pairwise_equivalent_rank_variables, aliased_tensors,
        full_equivalent_rank_variables)``. Here ``resource2capacity`` maps each
        ``arch.Memory`` node's name to its ``size``.
    """
    if not flattened_architecture:
        flattened_architecture = spec.get_flattened_architecture()
    resource2capacity = {
        node.name: node.size
        for node in flattened_architecture
        if isinstance(node, arch.Memory)
    }

    pairwise_equivalent_rank_variables = (
        spec.workload.get_pairwise_equivalent_rank_variables()
    )
    aliased_tensors = spec.workload.get_tensor_copies()
    full_equivalent_rank_variables = make_full_equivalent_rank_variables(
        pairwise_equivalent_rank_variables
    )

    return (
        pmapping_groups,
        evaluations_tracker,
        spec,
        flattened_architecture,
        resource2capacity,
        pairwise_equivalent_rank_variables,
        aliased_tensors,
        full_equivalent_rank_variables,
    )
157
+
158
+
159
def join_pmappings(
    pmapping_groups: dict[str, list[PmappingGroup]],
    evaluations_tracker: EvaluationsScoreTracker,
    algorithm: str,
    spec: Spec = None,
    flattened_architecture: list[arch.Leaf] = None,
) -> PmappingDataframe:
    """Join per-Einsum pmappings by running ``_fuse_sims`` workers in parallel.

    Every pmapping group is first trimmed IN PLACE (``sim.mappings._data`` is
    reassigned). Only the objective columns are kept: the columns of the
    first group whose name contains ``"Total"``, added with value 0 where
    missing. Each group also keeps its own reservation columns.

    ``_fuse_sims`` is then launched ``n_threads`` times, each with its own
    deep copy of ``evaluations_tracker``. If the run raises ``OSError``, the
    worker count is halved and the run retried.

    Args:
        pmapping_groups: Pmapping groups keyed by Einsum name. Must be
            non-empty; the first group of the first entry supplies the column
            names.
        evaluations_tracker: Updated in place by merging every worker's tracker.
        algorithm: Forwarded to ``_fuse_sims``.
        spec: Forwarded to ``MapspaceGlobals``.
        flattened_architecture: Forwarded to ``MapspaceGlobals``.

    Returns:
        ``pd.concat`` of the first element of every worker's result.
        NOTE(review): annotated as ``PmappingDataframe`` but the value comes
        straight from ``pd.concat``; confirm the intended type.

    Raises:
        OSError: If fusing fails even with a single worker.
    """
    first_group = next(iter(pmapping_groups.values()))[0]
    objective_function_cols = [
        c for c in first_group.mappings.data.columns if "Total" in c
    ]

    for pm_group_list in pmapping_groups.values():
        for sim in pm_group_list:
            for col in objective_function_cols:
                if col not in sim.mappings.data.columns:
                    sim.mappings.data[col] = 0
            reservations = [
                c for c in sim.mappings.data.columns if is_reservation_col(c)
            ]
            sim.mappings._data = sim.mappings.data[
                objective_function_cols + reservations
            ]

    mapspace_globals = MapspaceGlobals(
        pmapping_groups,
        spec,
        objective_function_cols,
        flattened_architecture,
    )

    # Clamp to at least one worker. With a non-positive configured count the
    # retry loop would never run and results_and_trackers would be unbound.
    # Halving from >= 2 keeps n_threads >= 1, so the loop ends by break/raise.
    n_threads = max(1, util.N_PARALLEL_PROCESSES)
    while True:
        try:
            results_and_trackers = parallel(
                [
                    delayed(_fuse_sims)(
                        mapspace_globals,
                        n_threads=n_threads,
                        evaluations_tracker=copy.deepcopy(evaluations_tracker),
                        algorithm=algorithm,
                    )
                    for _ in range(n_threads)
                ],
                n_jobs=n_threads if util.PARALLELIZE else 1,
            )
            results = pd.concat([r[0] for r in results_and_trackers])
            break
        except OSError as e:
            if n_threads == 1:
                raise OSError("Failed to fuse pmapping_groups with 1 thread") from e
            print(
                f"Failed to fuse pmapping_groups with {n_threads} threads, trying with {n_threads // 2}"
            )
            n_threads //= 2

    for worker_result in results_and_trackers:
        evaluations_tracker.merge_with(worker_result[1])
    return results
@@ -0,0 +1,7 @@
1
+ from fastfusion.mapper.FFM.main import (
2
+ make_pmappings,
3
+ MultiEinsumPmappings,
4
+ Mappings,
5
+ )
6
+ from fastfusion.frontend.mapper.metrics import Metrics
7
+ from .simanneal import join_pmappings