accelforge 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (258) hide show
  1. accelforge/__init__.py +21 -0
  2. accelforge/_accelerated_imports.py +16 -0
  3. accelforge/_deprecate/_simanneal/evalmapping.py +271 -0
  4. accelforge/_deprecate/_simanneal/mapspaceglobals.py +298 -0
  5. accelforge/_deprecate/_simanneal/simanneal.py +666 -0
  6. accelforge/_deprecate/_simanneal/tracking.py +105 -0
  7. accelforge/_deprecate/_simanneal/wrappers.py +218 -0
  8. accelforge/_deprecate/_simanneal2/__init__.py +7 -0
  9. accelforge/_deprecate/_simanneal2/simanneal.py +493 -0
  10. accelforge/_deprecate/_simanneal2/tracking.py +116 -0
  11. accelforge/_deprecate/compatibility_util.py +181 -0
  12. accelforge/_deprecate/layerdeduplication/__init__.py +2 -0
  13. accelforge/_deprecate/layerdeduplication/group_similar_einsums.py +160 -0
  14. accelforge/_deprecate/layerdeduplication/grouped_einsums.py +84 -0
  15. accelforge/_deprecate/mapping_filter_tags/__init__.py +2 -0
  16. accelforge/_deprecate/mapping_filter_tags/ffmt.py +212 -0
  17. accelforge/_deprecate/mapping_filter_tags/onesplit.py +24 -0
  18. accelforge/_deprecate/mapping_filter_tags/util.py +24 -0
  19. accelforge/_deprecate/tags.py +69 -0
  20. accelforge/_deprecate/viz/__init__.py +0 -0
  21. accelforge/_deprecate/viz/interactive.py +159 -0
  22. accelforge/_deprecate/viz/reservationtree.py +307 -0
  23. accelforge/_deprecate/viz/ski_slope.py +88 -0
  24. accelforge/_version.py +15 -0
  25. accelforge/examples.py +39 -0
  26. accelforge/frontend/__init__.py +10 -0
  27. accelforge/frontend/_binding.py +129 -0
  28. accelforge/frontend/_workload_isl/__init__.py +2 -0
  29. accelforge/frontend/_workload_isl/_isl.py +149 -0
  30. accelforge/frontend/_workload_isl/_symbolic.py +141 -0
  31. accelforge/frontend/arch copy.py +1544 -0
  32. accelforge/frontend/arch.py +1642 -0
  33. accelforge/frontend/config.py +63 -0
  34. accelforge/frontend/mapper/__init__.py +5 -0
  35. accelforge/frontend/mapper/ffm.py +126 -0
  36. accelforge/frontend/mapper/mapper.py +7 -0
  37. accelforge/frontend/mapper/metrics.py +30 -0
  38. accelforge/frontend/mapping/__init__.py +1 -0
  39. accelforge/frontend/mapping/mapping.py +1736 -0
  40. accelforge/frontend/model.py +14 -0
  41. accelforge/frontend/renames.py +150 -0
  42. accelforge/frontend/spec copy.py +230 -0
  43. accelforge/frontend/spec.py +301 -0
  44. accelforge/frontend/variables.py +12 -0
  45. accelforge/frontend/workload.py +952 -0
  46. accelforge/mapper/FFM/__init__.py +9 -0
  47. accelforge/mapper/FFM/_join_pmappings/__init__.py +0 -0
  48. accelforge/mapper/FFM/_join_pmappings/compatibility.py +653 -0
  49. accelforge/mapper/FFM/_join_pmappings/compress_pmappings.py +140 -0
  50. accelforge/mapper/FFM/_join_pmappings/join_pmappings.py +703 -0
  51. accelforge/mapper/FFM/_join_pmappings/pmapping_dataframe.py +901 -0
  52. accelforge/mapper/FFM/_join_pmappings/pmapping_group.py +337 -0
  53. accelforge/mapper/FFM/_make_pmappings/contraints/__init__.py +0 -0
  54. accelforge/mapper/FFM/_make_pmappings/contraints/constraints.py +360 -0
  55. accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/__init__.py +1 -0
  56. accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_loops.py +373 -0
  57. accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_pmapping_templates.py +463 -0
  58. accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_reservations.py +95 -0
  59. accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_storage_order.py +382 -0
  60. accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_storages.py +155 -0
  61. accelforge/mapper/FFM/_make_pmappings/make_pmappings.py +411 -0
  62. accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/__init__.py +1 -0
  63. accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/make_pmappings_from_templates.py +407 -0
  64. accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/make_tile_shapes.py +1681 -0
  65. accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/run_model.py +170 -0
  66. accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/symbol_relations.py +174 -0
  67. accelforge/mapper/FFM/_make_pmappings/pmapper_job.py +282 -0
  68. accelforge/mapper/FFM/_pareto_df/df_convention.py +273 -0
  69. accelforge/mapper/FFM/_pareto_df/pareto copy.py +836 -0
  70. accelforge/mapper/FFM/_pareto_df/pareto.py +508 -0
  71. accelforge/mapper/FFM/data.py +61 -0
  72. accelforge/mapper/FFM/main copy.py +236 -0
  73. accelforge/mapper/FFM/main.py +208 -0
  74. accelforge/mapper/FFM/mappings.py +510 -0
  75. accelforge/mapper/FFM/pmappings.py +310 -0
  76. accelforge/mapper/__init__.py +4 -0
  77. accelforge/mapper.py +0 -0
  78. accelforge/model/__init__.py +1 -0
  79. accelforge/model/_looptree/__init__.py +0 -0
  80. accelforge/model/_looptree/accesses.py +335 -0
  81. accelforge/model/_looptree/capacity/__init__.py +1 -0
  82. accelforge/model/_looptree/capacity/aggregators.py +36 -0
  83. accelforge/model/_looptree/capacity/capacity.py +47 -0
  84. accelforge/model/_looptree/energy.py +150 -0
  85. accelforge/model/_looptree/equivalent_ranks.py +29 -0
  86. accelforge/model/_looptree/latency/__init__.py +1 -0
  87. accelforge/model/_looptree/latency/latency.py +98 -0
  88. accelforge/model/_looptree/latency/memory.py +120 -0
  89. accelforge/model/_looptree/latency/processors.py +92 -0
  90. accelforge/model/_looptree/mapping_utilities.py +71 -0
  91. accelforge/model/_looptree/reuse/__init__.py +4 -0
  92. accelforge/model/_looptree/reuse/isl/__init__.py +1 -0
  93. accelforge/model/_looptree/reuse/isl/des.py +59 -0
  94. accelforge/model/_looptree/reuse/isl/isl_functions.py +374 -0
  95. accelforge/model/_looptree/reuse/isl/mapping_to_isl/__init__.py +4 -0
  96. accelforge/model/_looptree/reuse/isl/mapping_to_isl/analyze_mapping.py +297 -0
  97. accelforge/model/_looptree/reuse/isl/mapping_to_isl/skews_from_mapping.py +236 -0
  98. accelforge/model/_looptree/reuse/isl/mapping_to_isl/tiling.py +685 -0
  99. accelforge/model/_looptree/reuse/isl/mapping_to_isl/types.py +188 -0
  100. accelforge/model/_looptree/reuse/isl/spatial.py +260 -0
  101. accelforge/model/_looptree/reuse/isl/temporal.py +182 -0
  102. accelforge/model/_looptree/reuse/symbolic/__init__.py +1 -0
  103. accelforge/model/_looptree/reuse/symbolic/symbolic copy 2.py +1346 -0
  104. accelforge/model/_looptree/reuse/symbolic/symbolic copy.py +1408 -0
  105. accelforge/model/_looptree/reuse/symbolic/symbolic.py +1396 -0
  106. accelforge/model/_looptree/run.py +122 -0
  107. accelforge/model/_looptree/types.py +26 -0
  108. accelforge/model/_looptree/visualization/__init__.py +0 -0
  109. accelforge/model/_looptree/visualization/occupancy.py +11 -0
  110. accelforge/model/main.py +222 -0
  111. accelforge/plotting/__init__.py +2 -0
  112. accelforge/plotting/mappings.py +219 -0
  113. accelforge/plotting/specs.py +57 -0
  114. accelforge/util/__init__.py +4 -0
  115. accelforge/util/_base_analysis_types.py +24 -0
  116. accelforge/util/_basetypes.py +1089 -0
  117. accelforge/util/_frozenset.py +36 -0
  118. accelforge/util/_isl.py +29 -0
  119. accelforge/util/_itertools.py +14 -0
  120. accelforge/util/_mathfuncs.py +57 -0
  121. accelforge/util/_parse_expressions.py +339 -0
  122. accelforge/util/_picklecache.py +32 -0
  123. accelforge/util/_setexpressions.py +268 -0
  124. accelforge/util/_sympy/__init__.py +0 -0
  125. accelforge/util/_sympy/broadcast_max.py +18 -0
  126. accelforge/util/_visualization.py +112 -0
  127. accelforge/util/_yaml.py +579 -0
  128. accelforge/util/parallel.py +193 -0
  129. accelforge-0.0.1.dist-info/METADATA +64 -0
  130. accelforge-0.0.1.dist-info/RECORD +258 -0
  131. accelforge-0.0.1.dist-info/WHEEL +5 -0
  132. accelforge-0.0.1.dist-info/licenses/LICENSE +19 -0
  133. accelforge-0.0.1.dist-info/top_level.txt +5 -0
  134. docs/_build/html/_sources/fastfusion.frontend.mapper.rst.txt +37 -0
  135. docs/_build/html/_sources/fastfusion.frontend.rst.txt +70 -0
  136. docs/_build/html/_sources/fastfusion.frontend.workload.rst.txt +21 -0
  137. docs/_build/html/_sources/fastfusion.mapper.FFM.rst.txt +37 -0
  138. docs/_build/html/_sources/fastfusion.mapper.rst.txt +18 -0
  139. docs/_build/html/_sources/fastfusion.rst.txt +20 -0
  140. docs/_build/html/_sources/fastfusion.util.rst.txt +21 -0
  141. docs/_build/html/_sources/index.rst.txt +87 -0
  142. docs/_build/html/_sources/modules.rst.txt +7 -0
  143. docs/_build/html/_sources/notes/citation.rst.txt +45 -0
  144. docs/_build/html/_sources/notes/definitions.rst.txt +43 -0
  145. docs/_build/html/_sources/notes/faqs.rst.txt +39 -0
  146. docs/_build/html/_sources/notes/modeling/accelerator_energy_latency.rst.txt +72 -0
  147. docs/_build/html/_sources/notes/modeling/component_energy_area.rst.txt +96 -0
  148. docs/_build/html/_sources/notes/modeling/mapping.rst.txt +100 -0
  149. docs/_build/html/_sources/notes/modeling.rst.txt +33 -0
  150. docs/_build/html/_sources/notes/parsing/arithmetic_parsing.rst.txt +136 -0
  151. docs/_build/html/_sources/notes/parsing/setexpressions.rst.txt +63 -0
  152. docs/_build/html/_sources/notes/parsing/yaml_parsing.rst.txt +176 -0
  153. docs/_build/html/_sources/notes/quickstart_and_installation.rst.txt +9 -0
  154. docs/_build/html/_sources/notes/spec/architecture.rst.txt +133 -0
  155. docs/_build/html/_sources/notes/spec/mapping.rst.txt +12 -0
  156. docs/_build/html/_sources/notes/spec/workload.rst.txt +83 -0
  157. docs/_build/html/_sources/notes/spec.rst.txt +36 -0
  158. docs/source/_ext/include_attrs.py +213 -0
  159. docs/source/_ext/include_docstring.py +364 -0
  160. docs/source/_ext/include_functions.py +154 -0
  161. docs/source/_ext/include_notebook.py +131 -0
  162. docs/source/_ext/include_yaml.py +119 -0
  163. docs/source/_ext/inherited_attributes.py +222 -0
  164. docs/source/_ext/paths.py +4 -0
  165. docs/source/conf.py +79 -0
  166. examples/arches/compute_in_memory/_include.yaml +74 -0
  167. examples/arches/compute_in_memory/_include_functions.py +229 -0
  168. examples/arches/compute_in_memory/_load_spec.py +57 -0
  169. examples/arches/compute_in_memory/components/c2c_multiplier.py +181 -0
  170. examples/arches/compute_in_memory/components/dac_c2c_r2r.py +605 -0
  171. examples/arches/compute_in_memory/components/misc.py +195 -0
  172. examples/arches/compute_in_memory/components/util/bit_functions.py +51 -0
  173. examples/arches/compute_in_memory/components/zero_comparator.py +92 -0
  174. examples/arches/compute_in_memory/isaac.yaml +233 -0
  175. examples/arches/compute_in_memory/memory_cells/ecram_demo.yaml +63 -0
  176. examples/arches/compute_in_memory/memory_cells/rram_example.yaml +63 -0
  177. examples/arches/compute_in_memory/memory_cells/rram_isaac_isca_2016.yaml +64 -0
  178. examples/arches/compute_in_memory/memory_cells/rram_neurosim_default.yaml +63 -0
  179. examples/arches/compute_in_memory/memory_cells/rram_raella_isca_2023.yaml +70 -0
  180. examples/arches/compute_in_memory/memory_cells/rram_wan_nature_2022.yaml +63 -0
  181. examples/arches/compute_in_memory/memory_cells/sram_colonnade_jssc_2021.yaml +63 -0
  182. examples/arches/compute_in_memory/memory_cells/sram_example.yaml +63 -0
  183. examples/arches/compute_in_memory/memory_cells/sram_jia_jssc_2020.yaml +63 -0
  184. examples/arches/compute_in_memory/memory_cells/sram_sinangil_jssc_2021.yaml +63 -0
  185. examples/arches/compute_in_memory/memory_cells/sram_wang_vlsi_2022.yaml +63 -0
  186. examples/arches/compute_in_memory/wang_vlsi_2022.yaml +289 -0
  187. examples/arches/eyeriss.yaml +68 -0
  188. examples/arches/fanout_variations/at_glb.yaml +31 -0
  189. examples/arches/fanout_variations/at_glb_with_fanout_node.yaml +34 -0
  190. examples/arches/fanout_variations/at_mac.yaml +31 -0
  191. examples/arches/fanout_variations/at_mac_with_constraints.yaml +38 -0
  192. examples/arches/fanout_variations/at_mac_with_fanout_node.yaml +34 -0
  193. examples/arches/nvdla.yaml +47 -0
  194. examples/arches/simple.yaml +28 -0
  195. examples/arches/tpu_v4i.yaml +67 -0
  196. examples/mappings/unfused_matmuls_to_simple.yaml +33 -0
  197. examples/misc/component_annotated.yaml +33 -0
  198. examples/workloads/gpt3_6.7B.yaml +124 -0
  199. examples/workloads/matmuls.yaml +20 -0
  200. examples/workloads/mobilenet_28.yaml +81 -0
  201. examples/workloads/mobilenet_various_separate.yaml +106 -0
  202. examples/workloads/three_matmuls_annotated.yaml +59 -0
  203. notebooks/.ipynb_checkpoints/fastfusion_arch_study_michael-checkpoint.ipynb +359 -0
  204. notebooks/compute_in_memory/_scripts.py +339 -0
  205. notebooks/compute_in_memory/isaac.guide.ipynb +270 -0
  206. notebooks/compute_in_memory/wang_vlsi_2022.ipynb +602 -0
  207. notebooks/paths.py +4 -0
  208. notebooks/tutorials/.ipynb_checkpoints/1_FFM-checkpoint.ipynb +3110 -0
  209. notebooks/tutorials/FFM.ipynb +3498 -0
  210. notebooks/tutorials/_include.py +48 -0
  211. notebooks/tutorials/component_energy_area.ipynb +363 -0
  212. tests/Q_mapping.yaml +38 -0
  213. tests/__init__.py +0 -0
  214. tests/conv.mapping.yaml +27 -0
  215. tests/conv.workload.yaml +13 -0
  216. tests/conv_sym.mapping.yaml +43 -0
  217. tests/copy.mapping.yaml +35 -0
  218. tests/copy.workload.yaml +15 -0
  219. tests/distribuffers/__init__.py +0 -0
  220. tests/distribuffers/multicast/test_cases.yaml +482 -0
  221. tests/distribuffers/spec/binding/valid_bindings.yaml +97 -0
  222. tests/distribuffers/spec/distributed.yaml +100 -0
  223. tests/distribuffers/spec/logical_arch.yaml +32 -0
  224. tests/distribuffers/spec/physical_arch.yaml +69 -0
  225. tests/distribuffers/test_binding.py +48 -0
  226. tests/frontend/__init__.py +0 -0
  227. tests/frontend/test_mapping_viz.py +52 -0
  228. tests/mapper/__init__.py +0 -0
  229. tests/mapper/configs/conv1d/conv1d.mapping.yaml +31 -0
  230. tests/mapper/configs/conv1d/conv1d.workload.yaml +11 -0
  231. tests/mapper/configs/two_conv1d/two_conv1d.expected.yaml +38 -0
  232. tests/mapper/configs/two_conv1d/two_conv1d.mapping.yaml +54 -0
  233. tests/mapper/configs/two_conv1d/two_conv1d.workload.yaml +19 -0
  234. tests/mapper/test_mapping_to_isl.py +90 -0
  235. tests/mapper/test_spatial_reuse_analysis.py +67 -0
  236. tests/mapper/test_temporal_reuse_analysis.py +56 -0
  237. tests/mapper/util.py +58 -0
  238. tests/matmul.mapping.yaml +29 -0
  239. tests/matmul.workload.yaml +12 -0
  240. tests/matmul_spatial.mapping.yaml +44 -0
  241. tests/mha.renames.yaml +65 -0
  242. tests/mha.workload.yaml +67 -0
  243. tests/mha.yaml +59 -0
  244. tests/mha_full.workload.yaml +67 -0
  245. tests/mobilenet.workload.yaml +35 -0
  246. tests/mobilenet_long.workload.yaml +64 -0
  247. tests/pmappingcache.py +24 -0
  248. tests/processing_stage.arch.yaml +40 -0
  249. tests/snowcat.arch.yaml +36 -0
  250. tests/test_ffm_join_pmappings.py +106 -0
  251. tests/test_ffm_make_pmappings.py +82 -0
  252. tests/test_ffm_make_tile_shapes.py +49 -0
  253. tests/test_mapper.py +100 -0
  254. tests/test_model.py +37 -0
  255. tests/test_plotting.py +72 -0
  256. tests/test_processing_stage.py +46 -0
  257. tests/test_symbolic_model.py +248 -0
  258. tests/test_workload.py +141 -0
@@ -0,0 +1,666 @@
1
+ from collections.abc import Mapping
2
+ import copy
3
+ from math import ceil, exp, prod
4
+ import random
5
+ import threading
6
+ import time
7
+ from fastfusion._accelerated_imports import pd
8
+ from fastfusion.mapper.simanneal.evalmapping import quick_join
9
+ from fastfusion.mapper.simanneal.tracking import EvaluationsScoreTracker
10
+ from fastfusion.mapper.FFM._join_pmappings.join_pmappings import PmappingGroup
11
+ from fastfusion.mapper.FFM._join_pmappings.compatibility import (
12
+ TensorReservation,
13
+ Compatibility,
14
+ )
15
+ from fastfusion.mapper.FFM._join_pmappings.pmapping_group import (
16
+ MAPPING_COLUMN,
17
+ PmappingDataframe,
18
+ )
19
+ from fastfusion.util._frozenset import fzs
20
+ from fastfusion.mapper.simanneal.mapspaceglobals import MapspaceGlobals
21
+
22
# Objective selector. Per the original inline note, ``None`` stands for
# "Product". NOTE(review): nothing in the visible part of this module reads
# this constant -- confirm it is still used before relying on it.
OBJECTIVE_COLUMN = None  # None -> Product
23
+
24
+
25
class FailedMutation(Exception):
    """Signal that a mutation or repair step could not produce a valid mapping."""
27
+
28
+
29
class Mapping:
    """One candidate solution explored by the simulated-annealing mapper.

    For every Einsum the candidate stores a tiling (a ``Compatibility``) and,
    optionally, the index of one row among the pmappings available for that
    tiling.

    Attributes set by ``__init__``:
        einsum_names: Einsum names, in the order of ``pmapping_groups``.
        einsum2tiling: Einsum name -> currently selected ``Compatibility``.
        einsum2intra_choice: Einsum name -> chosen pmapping row index, or
            ``None`` when a row still has to be drawn by ``evaluate``.
        history: sink with an ``append`` method; a no-op object by default.
        n_crossovers, n_mutations: counters.
        n_changes: bumped by every mutating method; ``evaluate`` compares it
            with ``prev_eval_at_n_changes`` to reuse ``prev_eval_result``.
    """

    def __init__(self, pmapping_groups: "dict[str, list[PmappingGroup]]"):
        """Start every Einsum with a loop-less tiling over its tensors.

        Args:
            pmapping_groups: Einsum name -> pmapping groups for that Einsum.
                Only the first group of each list is read here, for its
                ``tensor_names``.
        """
        self.einsum_names = list(pmapping_groups.keys())
        self.einsum2intra_choice = {
            einsum_name: None for einsum_name in self.einsum_names
        }
        self.einsum2tiling = {}
        for einsum_name, pm_group_list in pmapping_groups.items():
            tensor_names = pm_group_list[0].tensor_names
            tensors = fzs(TensorReservation(t, 0, 0, 0) for t in tensor_names)
            self.set_einsum2tiling(einsum_name, Compatibility(tuple(), tensors))

        # History recording is disabled: every ``history.append(...)`` call
        # goes to this no-op sink. Swap in a list to record the messages.
        # self.history = []
        class dummy_appender:
            def append(*args, **kwargs):
                pass

        self.history = dummy_appender()
        self.n_crossovers = 0
        self.n_mutations = 0

        self.n_changes = 0
        self.prev_eval_result = float("inf")
        self.prev_eval_at_n_changes = -1

    def set_einsum2tiling(self, einsum_name: str, tiling: "Compatibility"):
        """Select ``tiling`` for ``einsum_name``.

        Does nothing when the tiling is unchanged; otherwise the stored
        pmapping row choice for that Einsum is reset to ``None``.
        """
        prev = self.einsum2tiling.get(einsum_name, None)
        if prev is not None and prev == tiling:
            return
        self.einsum2tiling[einsum_name] = tiling
        self.einsum2intra_choice[einsum_name] = None

    def fix_loops(self, mapspace_globals: "MapspaceGlobals"):
        """Ensure that all tilings have the correct number of loops.

        For each Einsum the number of loops is made equal to the largest
        ``above_loop_index`` among its tensors, and every loop that is not
        allowed above all tensors stored below it is replaced by a randomly
        chosen allowed one.

        Raises:
            FailedMutation: if any repair step fails.
        """
        self.n_changes += 1
        self.history.append("Fixing loops")

        try:
            for einsum in self.einsum_names:
                tiling = self.einsum2tiling[einsum]
                n_loops = max(t.above_loop_index for t in tiling.tensors)

                # If there's too many loops then drop the extra ones
                if n_loops < len(tiling.loops):
                    self.set_einsum2tiling(
                        einsum, tiling.update(loops=tiling.loops[:n_loops])
                    )

                # If there's not enough loops then add some
                if n_loops > len(tiling.loops):
                    for tensor in tiling.tensors:
                        for loop in range(len(tiling.loops), tensor.above_loop_index):
                            self.mutate_loop(mapspace_globals, tensor, loop, einsum)
                            self.force_loop_match(mapspace_globals, loop, einsum)
                assert n_loops == len(self.einsum2tiling[einsum].loops)

                tiling = self.einsum2tiling[einsum]
                tensors = tiling.tensors
                for i in range(len(tiling.loops)):
                    tensors = list(t for t in tensors if t.above_loop_index > i)
                    if not tensors:
                        continue
                    possible_loops = set.intersection(
                        *(
                            mapspace_globals.tensor2possible_loops_above_set[einsum][t]
                            for t in tensors
                        )
                    )
                    if not possible_loops:
                        raise FailedMutation(
                            f"No possible loops above {i} for {einsum}"
                        )
                    if tiling.loops[i] not in possible_loops:
                        new_loop = random.choice(list(possible_loops))
                        self.history.append(
                            f"Fixing loop {i} for {einsum} to {new_loop}"
                        )
                        new_loops = (
                            tiling.loops[:i] + (new_loop,) + tiling.loops[i + 1 :]
                        )
                        # Rebind ``tiling`` so that a later repair builds on
                        # this one instead of overwriting it with a tuple
                        # derived from the pre-repair loops.
                        tiling = tiling.update(loops=new_loops)
                        self.set_einsum2tiling(einsum, tiling)

        except FailedMutation as err:
            self.history.append("Failed to fix loops")
            raise FailedMutation("Failed to fix loops") from err

    def match_loops(
        self, index: int, einsum_name: str, mapspace_globals: "MapspaceGlobals"
    ):
        """Ensure that loops match across Einsums.

        Loops ``0..min(shared_loop_index, index)`` of ``einsum_name`` are
        copied into every other Einsum's tiling, with the rank variable
        replaced by a randomly chosen translation.

        Raises:
            FailedMutation: if a loop has no translation into another Einsum.
        """
        self.n_changes += 1
        tiling = self.einsum2tiling[einsum_name]
        for einsum_name2, tiling2 in self.einsum2tiling.items():
            if einsum_name2 == einsum_name:
                continue
            shared_loop_index = max(
                tiling.shared_loop_index(tiling2.tensor_names),
                tiling2.shared_loop_index(tiling.tensor_names),
            )
            for i in range(min(shared_loop_index, index) + 1):
                # Translate loop from einsum_name to einsum_name2
                loop = tiling.loops[i]
                translations = mapspace_globals.rank_translations[einsum_name][
                    einsum_name2
                ][loop.rank_variable_name]
                if not translations:
                    raise FailedMutation(
                        f"Failed to translate loop {loop} from {einsum_name} to {einsum_name2}"
                    )
                rank_variable_name = random.choice(translations)
                # NOTE(review): the loop is read through ``rank_variable_name``
                # but updated through ``rank_variable_names`` -- confirm this
                # matches the loop type's API.
                new_loops = (
                    tiling2.loops[:i]
                    + (loop.update(rank_variable_names=fzs((rank_variable_name,))),)
                    + tiling2.loops[i + 1 :]
                )
                tiling2 = tiling2.update(loops=new_loops)
                self.set_einsum2tiling(einsum_name2, tiling2)

    def mutate_loop(
        self,
        mapspace_globals: "MapspaceGlobals",
        tensor: "TensorReservation" = None,
        index: int = None,
        einsum_name: str = None,
    ):
        """Replace (or append) one loop of one Einsum's tiling.

        Any argument left as ``None`` is drawn at random. The new loop is
        either a random candidate ("Randomizing") or a candidate over the same
        rank variable with a larger / smaller bound ("Increasing" /
        "Decreasing"). When ``index`` is past the end of the current loops the
        step is always "Randomizing", which appends a loop.

        Raises:
            FailedMutation: if no tensor sits below a loop, or no candidate
                satisfies the chosen direction.
        """
        self.n_changes += 1
        if tensor is None:
            memories = set().union(*(t.tensors for t in self.einsum2tiling.values()))
            memories = [m for m in memories if m.above_loop_index > 0]
            if not memories:
                raise FailedMutation("No memories to mutate")
            tensor = random.choice(list(memories))
        if index is None:
            index = random.randint(0, tensor.above_loop_index - 1)
        if einsum_name is None:
            possible_einsums = [
                e for e, t in self.einsum2tiling.items() if tensor in t.tensors
            ]
            assert possible_einsums
            einsum_name = random.choice(possible_einsums)

        tiling = self.einsum2tiling[einsum_name]

        # The direction is always drawn, even when it is overridden below, so
        # the random stream does not depend on the tiling length.
        choice = random.choice(["Increasing", "Decreasing", "Randomizing"])
        if len(tiling.loops) <= index:
            choice = "Randomizing"

        candidates = mapspace_globals.tensor2possible_loops_above[einsum_name][tensor]
        if choice == "Randomizing":
            new_loop = random.choice(candidates)
        else:
            prev_loop = tiling.loops[index]
            rank, bound = prev_loop.rank_variable_name, prev_loop.bound
            increasing = choice == "Increasing"

            candidates = [
                c
                for c in candidates
                if ((c.bound > bound) if increasing else (c.bound < bound))
                and c.rank_variable_name == rank
            ]
            if not candidates:
                raise FailedMutation(
                    f"{choice} {prev_loop} for {einsum_name} at {index} failed"
                )
            new_loop = random.choice(candidates)

        self.history.append(f"{choice} loop {index} for {einsum_name} to {new_loop}")
        new_loops = tiling.loops[:index] + (new_loop,) + tiling.loops[index + 1 :]
        self.set_einsum2tiling(einsum_name, tiling.update(loops=new_loops))

    def get_shared_loop_index(
        self, mapspace_globals: "MapspaceGlobals", einsum_name0: str, einsum_name1: str
    ):
        """Return the index of the last loop shared by two Einsums.

        For the same Einsum twice this is the index of its last loop (``-1``
        when it has no loops). Otherwise it is the larger of the two tilings'
        ``shared_loop_index`` taken against the tensors on the other side:
        the tensors of every Einsum up to and including the earlier one, and
        of every Einsum from the later one onwards.
        """
        if einsum_name0 == einsum_name1:
            # Use the name directly: the positional indices are only computed
            # further down, so reading them here raised UnboundLocalError.
            return len(self.einsum2tiling[einsum_name0].loops) - 1

        einsum_names = list(self.einsum2tiling.keys())
        einsum_index0 = einsum_names.index(einsum_name0)
        einsum_index1 = einsum_names.index(einsum_name1)

        if einsum_index0 > einsum_index1:
            einsum_index0, einsum_index1 = einsum_index1, einsum_index0

        tiling0 = self.einsum2tiling[einsum_names[einsum_index0]]
        tiling1 = self.einsum2tiling[einsum_names[einsum_index1]]
        left_tensors = mapspace_globals.get_tensors(*einsum_names[: einsum_index0 + 1])
        right_tensors = mapspace_globals.get_tensors(*einsum_names[einsum_index1:])
        return max(
            tiling0.shared_loop_index(right_tensors),
            tiling1.shared_loop_index(left_tensors),
        )

    def force_loop_match(
        self,
        mapspace_globals: "MapspaceGlobals",
        index: int,
        einsum_name: str,
    ):
        """Copy the leading loops of ``einsum_name`` into all other Einsums.

        Same as ``match_loops`` except that the shared loop index comes from
        ``get_shared_loop_index``.

        Raises:
            FailedMutation: if a loop has no translation into another Einsum.
        """
        self.n_changes += 1
        tiling = self.einsum2tiling[einsum_name]
        for einsum_name2, tiling2 in self.einsum2tiling.items():
            if einsum_name2 == einsum_name:
                continue
            shared_loop_index = self.get_shared_loop_index(
                mapspace_globals, einsum_name, einsum_name2
            )
            rank_translations = mapspace_globals.rank_translations[einsum_name][
                einsum_name2
            ]
            for i in range(min(shared_loop_index, index) + 1):
                loop = tiling.loops[i]
                translations = rank_translations[loop.rank_variable_name]
                if not translations:
                    raise FailedMutation(
                        f"Failed to translate loop {loop} from {einsum_name} to {einsum_name2}"
                    )
                rank_variable_name = random.choice(translations)
                # NOTE(review): see match_loops -- ``rank_variable_name`` is
                # read but ``rank_variable_names`` is written; confirm.
                new_loops = (
                    tiling2.loops[:i]
                    + (loop.update(rank_variable_names=fzs((rank_variable_name,))),)
                    + tiling2.loops[i + 1 :]
                )
                tiling2 = tiling2.update(loops=new_loops)
                self.set_einsum2tiling(einsum_name2, tiling2)

    def mutate_backing_tensor(self, mapspace_globals: "MapspaceGlobals"):
        """Move one tensor used by several Einsums to another reservation.

        A tensor name and one of its candidate reservations are drawn at
        random; every tiling holding that tensor gets the new reservation,
        then ``fix_loops`` repairs the loops.

        Raises:
            FailedMutation: if there is nothing to choose from, if the drawn
                reservation is already in use, or if the loop repair fails.
        """
        self.n_changes += 1
        # Report "nothing to choose from" as a failed mutation so callers
        # that catch FailedMutation keep working instead of seeing IndexError.
        shared_tensors = list(mapspace_globals.tensor_names_used_in_multiple_einsums)
        if not shared_tensors:
            raise FailedMutation("No tensors are used in multiple Einsums")
        tensor = random.choice(shared_tensors)
        options = mapspace_globals.tensor2memories[tensor]
        if len(options) == 0:
            raise FailedMutation(f"No memories available for tensor {tensor}")
        reservation = random.choice(options)
        for t in self.einsum2tiling.values():
            if reservation in t.tensors:
                raise FailedMutation(
                    f"Moving tensor {tensor} to tensor {reservation} failed"
                )
        self.history.append(f"Moving tensor {tensor} to tensor {reservation}")
        for einsum, tiling in self.einsum2tiling.items():
            if not any(r.name == tensor for r in tiling.tensors):
                continue
            new_tensors = [reservation] + [r for r in tiling.tensors if r.name != tensor]
            self.set_einsum2tiling(einsum, tiling.update(tensors=fzs(new_tensors)))
        self.fix_loops(mapspace_globals)

    def mutate_order(self, mapspace_globals: "MapspaceGlobals"):
        """Swap the tilings of two Einsums -- currently disabled.

        The leading ``return`` turns this mutation into a no-op; the code
        below it is kept for reference and never runs.
        """
        return
        self.n_changes += 1
        e0, e1 = random.sample(self.einsum_names, 2)
        print(f"Switching {e0} and {e1}")
        self.einsum2tiling[e0], self.einsum2tiling[e1] = (
            self.einsum2tiling[e1],
            self.einsum2tiling[e0],
        )
        self.fix_loops(mapspace_globals)

    def evaluate(self, mapspace_globals: "MapspaceGlobals", return_df=False) -> tuple:
        """Score this mapping.

        Returns:
            A pair ``(result, n_evaluations)``. ``result`` is the product of
            the objective columns of the joined mapping, ``float("inf")`` when
            a tiling is unknown or the join fails, or a one-row DataFrame when
            ``return_df`` is true. When nothing changed since the previous
            call (and ``return_df`` is false) the cached score is returned
            with ``n_evaluations == 1``.
        """
        if self.n_changes == self.prev_eval_at_n_changes and not return_df:
            return self.prev_eval_result, 1

        chosen_mappings = {}
        n_evaluations = (
            mapspace_globals.size_scale * mapspace_globals.find_pmapping_scale
        )

        # Mark this state as evaluated up front so the early "inf" returns
        # below are cached too.
        self.prev_eval_at_n_changes = self.n_changes
        self.prev_eval_result = float("inf")

        for einsum_name, t in self.einsum2tiling.items():
            if t not in mapspace_globals.einsum_tiling_2_sim[einsum_name]:
                assert not return_df
                return float("inf"), n_evaluations

            sim = mapspace_globals.einsum_tiling_2_sim[einsum_name][t]
            intra_mappings = sim.mappings.data

            if self.einsum2intra_choice[einsum_name] is not None:
                mapping = intra_mappings.iloc[
                    self.einsum2intra_choice[einsum_name] % len(intra_mappings)
                ]
                chosen_mappings[einsum_name] = mapping
                continue

            # No row chosen yet: draw one and charge an extra evaluation.
            self.einsum2intra_choice[einsum_name] = random.randint(0, 1000000000000)
            choice = self.einsum2intra_choice[einsum_name] % len(sim.mappings.data)
            self.einsum2intra_choice[einsum_name] = choice
            n_evaluations += (
                mapspace_globals.size_scale * mapspace_globals.find_pmapping_scale
            )
            mapping = intra_mappings.iloc[choice]
            chosen_mappings[einsum_name] = mapping

        try:
            new_sims = {}
            for einsum_name, tiling in self.einsum2tiling.items():
                sim = mapspace_globals.einsum_tiling_2_sim[einsum_name][tiling]
                mapping_index = self.einsum2intra_choice[einsum_name] % len(
                    sim.mappings.data
                )
                # One-row group holding only the chosen pmapping.
                new_sims[einsum_name] = [
                    PmappingGroup(
                        compatibility=sim.compatibility,
                        mappings=PmappingDataframe(
                            sim.mappings.data.iloc[
                                mapping_index : mapping_index + 1
                            ].copy()
                        ),
                    )
                ]
            chosen_mappings = quick_join(new_sims, mapspace_globals)
            assert len(chosen_mappings.data) == 1
            chosen_mappings = chosen_mappings.data.iloc[0]
        except Exception:
            # Deliberately broad: any failure to join is scored as an invalid
            # mapping rather than aborting the search.
            assert not return_df
            return float("inf"), n_evaluations

        obj_cols = mapspace_globals.objective_function_cols
        score = prod(chosen_mappings[col] for col in obj_cols)

        if return_df:
            # NOTE(review): after the join ``chosen_mappings`` is a single row
            # (``.data.iloc[0]``), so ``.values()`` looks like a leftover from
            # when it was a dict, and ``mapping`` is whatever the loop above
            # left behind -- confirm this branch still works as intended.
            d = {col: sum(c[col] for c in chosen_mappings.values()) for col in obj_cols}
            d[MAPPING_COLUMN] = mapping
            self.prev_eval_result = score
            return pd.DataFrame([d]), n_evaluations
        self.prev_eval_result = score
        return score, n_evaluations

    def mutate_intra_mapping(self, mapspace_globals: "MapspaceGlobals"):
        """Forget the pmapping row of one random Einsum.

        The next ``evaluate`` call draws a new row for it.
        """
        self.n_changes += 1
        einsum_name = random.choice(self.einsum_names)
        self.history.append(f"Choosing intra-layer mapping for {einsum_name}")
        self.einsum2intra_choice[einsum_name] = None

    def get_mutation_functions(self):
        """Return the bound mutation methods that ``mutate`` chooses from."""
        return [
            self.mutate_loop,
            self.mutate_backing_tensor,
            self.mutate_order,
            self.mutate_intra_mapping,
        ]

    def crossover(self, other: "Mapping", mapspace_globals: "MapspaceGlobals"):
        """Return a copy of ``other`` with one Einsum taken from ``self``.

        The tiling and row choice of a random Einsum are copied from ``self``
        into a deep copy of ``other``, then the loops are matched and
        repaired. If that fails, an unmodified deep copy of ``other`` is
        returned instead.
        """
        child = copy.deepcopy(other)
        einsum_name = random.choice(child.einsum_names)
        try:
            child.set_einsum2tiling(einsum_name, self.einsum2tiling[einsum_name])
            child.einsum2intra_choice[einsum_name] = self.einsum2intra_choice[
                einsum_name
            ]
            child.n_changes += 1
            for i in range(len(child.einsum2tiling[einsum_name].loops)):
                child.match_loops(i, einsum_name, mapspace_globals)
            child.fix_loops(mapspace_globals)
            child.n_crossovers += 1
        except FailedMutation:
            return copy.deepcopy(other)
        return child

    @staticmethod
    def create_random_mapping(mapspace_globals: "MapspaceGlobals"):
        """Build a random mapping by picking tilings Einsum by Einsum.

        The first Einsum gets a random pmapping group's compatibility. Each
        following Einsum gets a random tiling among those reachable from the
        compatibility accumulated so far.

        Raises:
            FailedMutation: if an Einsum has no tiling compatible with the
                choices made for the Einsums before it.
        """
        mapping = Mapping(mapspace_globals.pmapping_groups)
        prev_compatibility: "Compatibility" = None
        einsum_names = list(mapping.einsum2tiling.keys())
        for i, einsum_name in enumerate(einsum_names):
            pm_group_list = mapspace_globals.pmapping_groups[einsum_name]
            if prev_compatibility is None:
                sim = random.choice(pm_group_list)
                mapping.set_einsum2tiling(einsum_name, sim.compatibility)
                if len(einsum_names) == 1:
                    break
                prev_compatibility = mapspace_globals.compatibility2rightcompatibility[
                    einsum_name
                ][sim.compatibility]
                live_tensors = mapspace_globals.get_live_tensors(*einsum_names[i + 1 :])
                prev_compatibility = prev_compatibility.clear_dead_tensors(
                    live_tensors=live_tensors
                )
                continue

            tilings = []
            compatibility_options = mapspace_globals.leftcompatibility2tiling[
                einsum_name
            ]
            cur_tensors = mapspace_globals.get_tensors(einsum_name)
            for translation in mapspace_globals.get_possible_translations(
                prev_compatibility, einsum_name
            ):
                translation = translation.clear_dead_tensors(
                    live_tensors=cur_tensors, keep_loops=True
                )
                if translation in compatibility_options:
                    tilings.extend(compatibility_options[translation])

            if not tilings:
                raise FailedMutation(
                    f"No tilings for {einsum_name} with {prev_compatibility}"
                )
            sim_choices = [
                mapspace_globals.einsum_tiling_2_sim[einsum_name][t] for t in tilings
            ]
            sim = random.choice(sim_choices)
            tiling = sim.compatibility
            mapping.set_einsum2tiling(einsum_name, tiling)
            if i == len(einsum_names) - 1:
                break
            new_compatibility: "Compatibility" = (
                mapspace_globals.compatibility2rightcompatibility[einsum_name][tiling]
            )
            # Combine prev_compatibility and new_compatibility
            live_tensors = mapspace_globals.get_live_tensors(*einsum_names[i + 1 :])
            prev_compatibility = prev_compatibility.merge_next(
                new_compatibility, live_tensors
            )
        return mapping
455
+
456
+
457
def get_accept_function(temperature, cooling_rate, evaluations_tracker):
    """Return an acceptance predicate for simulated annealing.

    The effective temperature is computed once, when this function is called,
    from the fraction of the evaluation budget already consumed:
    ``temperature * (1 - p) / (1 + cooling_rate * p)`` with
    ``p = evaluations / max_evaluations``.

    Args:
        temperature: Starting temperature of the schedule.
        cooling_rate: Controls how quickly the temperature falls as ``p`` grows.
        evaluations_tracker: Object exposing ``evaluations``,
            ``max_evaluations`` and ``stop_at_score``.

    Returns:
        ``accept(prev_eval_result, new_score) -> bool``. Infinite scores are
        always rejected and scores no worse than the previous one are always
        accepted. A worse score is accepted with probability
        ``exp((prev_eval_result - new_score) / scale)``, where ``scale`` is
        the effective temperature times ``stop_at_score``; when ``scale`` is
        not positive a worse score is never accepted.
    """
    try:
        proportion = (
            evaluations_tracker.evaluations / evaluations_tracker.max_evaluations
        )
    except ZeroDivisionError:
        # A zero budget means there is nothing left to explore: behave as if
        # the schedule had fully cooled instead of crashing.
        proportion = 1.0
    new_temp = temperature * (1 - proportion) / (1 + cooling_rate * proportion)

    # Assume prescient knowledge of the best score with which to scale by
    def accept(prev_eval_result, new_score):
        if new_score == float("inf"):
            return False
        if new_score <= prev_eval_result:
            return True
        # stop_at_score is read on every call so later tracker updates are seen.
        scaleby = new_temp * evaluations_tracker.stop_at_score
        if scaleby > 0 and random.random() < exp(
            (prev_eval_result - new_score) / scaleby
        ):
            return True
        return False

    return accept
475
+
476
+
477
def mutate(
    mapping: Mapping, mapspace_globals: MapspaceGlobals, accept_function: callable
):
    """Apply one random mutation to ``mapping`` and keep it only if accepted.

    A deep copy of ``mapping`` is taken first so the pre-mutation state can be
    handed back whenever the mutation fails, scores infinity, or is rejected
    by ``accept_function``.

    Args:
        mapping: Mapping to mutate in place.
        mapspace_globals: Passed to the chosen mutation and to ``evaluate``.
        accept_function: Called as ``accept_function(previous_score,
            new_score)``; a truthy result keeps the mutated mapping.

    Returns:
        ``(mapping_to_keep, n_evaluations)``. ``n_evaluations`` is ``1`` when
        the mutation itself fails, otherwise the count reported by
        ``mapping.evaluate``.
    """
    snapshot = copy.deepcopy(mapping)
    score_before = mapping.prev_eval_result

    try:
        operators = mapping.get_mutation_functions()
        random.choice(operators)(mapspace_globals)
    except FailedMutation:
        # A failed mutation still counts as a single evaluation.
        return snapshot, 1

    score_after, n_evaluations = mapping.evaluate(mapspace_globals)
    if score_after == float("inf") or not accept_function(score_before, score_after):
        return snapshot, n_evaluations
    return mapping, n_evaluations
494
+
495
+
496
def _fuse_sims(
    mapspace_globals: MapspaceGlobals,
    n_threads: int,
    evaluations_tracker: EvaluationsScoreTracker,
    algorithm: str,
):
    """Run one worker's share of a randomized mapping search.

    The tracker is rescaled for this worker (``print_period`` multiplied and
    ``max_evaluations`` floor-divided by ``n_threads``). An initial population
    of finite-scoring random mappings is then built, and the selected search
    strategy runs until ``evaluations_tracker.add_evaluation`` returns a
    truthy value.

    Args:
        mapspace_globals: Lookup tables and helpers describing the mapspace.
        n_threads: Number of workers sharing the evaluation budget.
        evaluations_tracker: Records evaluations and scores; mutated in place.
        algorithm: ``"genetic"``, ``"simulated_anneal"``, or any string
            containing ``"random"`` (random sampling). If it also contains
            ``"pruned"``, mappings that cannot be constructed are not counted
            as evaluations.

    Returns:
        ``(pandas.DataFrame(), evaluations_tracker)``. The DataFrame is always
        empty: the chosen mappings are deliberately not returned because they
        are not Pareto-pruned and made result files very large.

    Raises:
        ValueError: If ``algorithm`` matches none of the supported strategies.
    """
    random.seed(time.time() + hash(threading.get_ident()))  # Seed with thread ID
    evaluations_tracker.multiply_scale_by(len(mapspace_globals.einsum_names))
    evaluations_tracker.print_period *= n_threads
    evaluations_tracker.max_evaluations //= n_threads

    def anneal_population(population, mapspace_globals: MapspaceGlobals):
        """Mutate every member in turn until the tracker signals a stop."""
        temperature = 0.07
        cooling_rate = 8
        while True:
            # Rebuilt each sweep so the temperature follows the budget used.
            accept_function = get_accept_function(
                temperature, cooling_rate, evaluations_tracker
            )
            for j, mapping in enumerate(population):
                population[j], evaluations = mutate(
                    mapping, mapspace_globals, accept_function
                )
                if evaluations_tracker.add_evaluation(
                    evaluations, population[j].prev_eval_result
                ):
                    return population

    def genetic_algorithm_population(population, mapspace_globals: MapspaceGlobals):
        """Evolve the population until the tracker signals a stop."""
        population_size = len(population)
        crossover_rate = 0.7
        mutation_rate = 0.2

        def crossover(parent1: Mapping, parent2: Mapping):
            if random.random() > crossover_rate:
                return copy.deepcopy(parent1)
            return parent1.crossover(parent2, mapspace_globals)

        def mutate_individual(individual):
            # Always work on a copy: selection may reference the same object
            # more than once.
            individual = copy.deepcopy(individual)
            if random.random() > mutation_rate:
                return individual
            # Backup taken only when a mutation is attempted, since a failed
            # mutation may leave ``individual`` partially modified.
            prev_mapping = copy.deepcopy(individual)
            try:
                mutation_function = random.choice(individual.get_mutation_functions())
                mutation_function(mapspace_globals)
                individual.n_mutations += 1
                return individual
            except FailedMutation:
                return prev_mapping

        best_fitness = float("inf")
        while True:
            # Evaluate fitness
            fitness = [0] * len(population)
            for i, individual in enumerate(population):
                f, evaluations = individual.evaluate(mapspace_globals)
                fitness[i] = f
                best_fitness = min(best_fitness, f)
                if evaluations_tracker.add_evaluation(evaluations, best_fitness):
                    return population

            best_score = min(fitness)
            best_mapping = population[fitness.index(best_score)]

            # Selection (roulette wheel selection); lower fitness is better,
            # so weights are the reciprocals.
            total_fitness = sum(1.0 / (f + 1e-9) for f in fitness)
            probabilities = [(1.0 / (f + 1e-9)) / total_fitness for f in fitness]
            selected_indices = random.choices(
                range(len(population)), probabilities, k=population_size
            )

            # Crossover
            new_population = [population[i] for i in selected_indices]
            for i in range(0, population_size, 2):
                parent1 = population[selected_indices[i]]
                parent2 = population[selected_indices[(i + 1) % population_size]]
                child1 = crossover(parent1, parent2)
                child2 = crossover(parent2, parent1)
                new_population.extend([child1, child2])

            # Mutation
            for i, individual in enumerate(new_population):
                new_population[i] = mutate_individual(individual)

            new_population.append(best_mapping)  # Keep the best mapping around
            population = new_population

    def random_sample_population(
        population, mapspace_globals: MapspaceGlobals, prune=False
    ):
        """Sample random mappings, remembering the best, until told to stop."""
        best_mapping = population[0]
        best_score = float("inf")
        while True:
            try:
                mapping = Mapping.create_random_mapping(mapspace_globals)
            except FailedMutation:
                # Without pruning, an unconstructible mapping still consumes
                # one evaluation from the budget.
                if not prune:
                    if evaluations_tracker.add_evaluation(1, float("inf")):
                        return [best_mapping]
                continue
            score, evaluations = mapping.evaluate(mapspace_globals)
            if score < best_score:
                best_mapping = mapping
                best_score = score
            if evaluations_tracker.add_evaluation(evaluations, score):
                return [best_mapping]

    extra_args = {}
    if algorithm == "genetic":
        population_size = 1000
        callfunc = genetic_algorithm_population
    elif algorithm == "simulated_anneal":
        # At least one member, otherwise annealing would never evaluate
        # anything and never terminate.
        population_size = max(1, 100 // n_threads)
        callfunc = anneal_population
    elif "random" in algorithm:
        population_size = 1
        callfunc = random_sample_population
        extra_args["prune"] = "pruned" in algorithm
    else:
        raise ValueError(
            f"Unknown algorithm {algorithm!r}; expected 'genetic', "
            "'simulated_anneal', or a name containing 'random'"
        )

    # Randomly initialize the population
    population = []
    while len(population) < population_size:
        try:
            mapping = Mapping.create_random_mapping(mapspace_globals)
            score, evaluations = mapping.evaluate(mapspace_globals)
            if evaluations_tracker.add_evaluation(evaluations, score):
                break
            if score == float("inf"):
                raise FailedMutation("Random mapping failed")
            population.append(mapping)
        except FailedMutation:
            if evaluations_tracker.add_evaluation(1, float("inf")):
                break

    # If the tracker stopped initialization before any mapping was found,
    # there is nothing to search from.
    if population:
        callfunc(population, mapspace_globals, **extra_args)

    # Not saving chosen mappings to avoid big files: concatenating their
    # evaluation DataFrames resulted in large files because the set is not
    # Pareto-pruned. Only the tracker carries the search outcome.
    return pd.DataFrame(), evaluations_tracker