accelforge 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (258) hide show
  1. accelforge/__init__.py +21 -0
  2. accelforge/_accelerated_imports.py +16 -0
  3. accelforge/_deprecate/_simanneal/evalmapping.py +271 -0
  4. accelforge/_deprecate/_simanneal/mapspaceglobals.py +298 -0
  5. accelforge/_deprecate/_simanneal/simanneal.py +666 -0
  6. accelforge/_deprecate/_simanneal/tracking.py +105 -0
  7. accelforge/_deprecate/_simanneal/wrappers.py +218 -0
  8. accelforge/_deprecate/_simanneal2/__init__.py +7 -0
  9. accelforge/_deprecate/_simanneal2/simanneal.py +493 -0
  10. accelforge/_deprecate/_simanneal2/tracking.py +116 -0
  11. accelforge/_deprecate/compatibility_util.py +181 -0
  12. accelforge/_deprecate/layerdeduplication/__init__.py +2 -0
  13. accelforge/_deprecate/layerdeduplication/group_similar_einsums.py +160 -0
  14. accelforge/_deprecate/layerdeduplication/grouped_einsums.py +84 -0
  15. accelforge/_deprecate/mapping_filter_tags/__init__.py +2 -0
  16. accelforge/_deprecate/mapping_filter_tags/ffmt.py +212 -0
  17. accelforge/_deprecate/mapping_filter_tags/onesplit.py +24 -0
  18. accelforge/_deprecate/mapping_filter_tags/util.py +24 -0
  19. accelforge/_deprecate/tags.py +69 -0
  20. accelforge/_deprecate/viz/__init__.py +0 -0
  21. accelforge/_deprecate/viz/interactive.py +159 -0
  22. accelforge/_deprecate/viz/reservationtree.py +307 -0
  23. accelforge/_deprecate/viz/ski_slope.py +88 -0
  24. accelforge/_version.py +15 -0
  25. accelforge/examples.py +39 -0
  26. accelforge/frontend/__init__.py +10 -0
  27. accelforge/frontend/_binding.py +129 -0
  28. accelforge/frontend/_workload_isl/__init__.py +2 -0
  29. accelforge/frontend/_workload_isl/_isl.py +149 -0
  30. accelforge/frontend/_workload_isl/_symbolic.py +141 -0
  31. accelforge/frontend/arch copy.py +1544 -0
  32. accelforge/frontend/arch.py +1642 -0
  33. accelforge/frontend/config.py +63 -0
  34. accelforge/frontend/mapper/__init__.py +5 -0
  35. accelforge/frontend/mapper/ffm.py +126 -0
  36. accelforge/frontend/mapper/mapper.py +7 -0
  37. accelforge/frontend/mapper/metrics.py +30 -0
  38. accelforge/frontend/mapping/__init__.py +1 -0
  39. accelforge/frontend/mapping/mapping.py +1736 -0
  40. accelforge/frontend/model.py +14 -0
  41. accelforge/frontend/renames.py +150 -0
  42. accelforge/frontend/spec copy.py +230 -0
  43. accelforge/frontend/spec.py +301 -0
  44. accelforge/frontend/variables.py +12 -0
  45. accelforge/frontend/workload.py +952 -0
  46. accelforge/mapper/FFM/__init__.py +9 -0
  47. accelforge/mapper/FFM/_join_pmappings/__init__.py +0 -0
  48. accelforge/mapper/FFM/_join_pmappings/compatibility.py +653 -0
  49. accelforge/mapper/FFM/_join_pmappings/compress_pmappings.py +140 -0
  50. accelforge/mapper/FFM/_join_pmappings/join_pmappings.py +703 -0
  51. accelforge/mapper/FFM/_join_pmappings/pmapping_dataframe.py +901 -0
  52. accelforge/mapper/FFM/_join_pmappings/pmapping_group.py +337 -0
  53. accelforge/mapper/FFM/_make_pmappings/contraints/__init__.py +0 -0
  54. accelforge/mapper/FFM/_make_pmappings/contraints/constraints.py +360 -0
  55. accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/__init__.py +1 -0
  56. accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_loops.py +373 -0
  57. accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_pmapping_templates.py +463 -0
  58. accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_reservations.py +95 -0
  59. accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_storage_order.py +382 -0
  60. accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_storages.py +155 -0
  61. accelforge/mapper/FFM/_make_pmappings/make_pmappings.py +411 -0
  62. accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/__init__.py +1 -0
  63. accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/make_pmappings_from_templates.py +407 -0
  64. accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/make_tile_shapes.py +1681 -0
  65. accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/run_model.py +170 -0
  66. accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/symbol_relations.py +174 -0
  67. accelforge/mapper/FFM/_make_pmappings/pmapper_job.py +282 -0
  68. accelforge/mapper/FFM/_pareto_df/df_convention.py +273 -0
  69. accelforge/mapper/FFM/_pareto_df/pareto copy.py +836 -0
  70. accelforge/mapper/FFM/_pareto_df/pareto.py +508 -0
  71. accelforge/mapper/FFM/data.py +61 -0
  72. accelforge/mapper/FFM/main copy.py +236 -0
  73. accelforge/mapper/FFM/main.py +208 -0
  74. accelforge/mapper/FFM/mappings.py +510 -0
  75. accelforge/mapper/FFM/pmappings.py +310 -0
  76. accelforge/mapper/__init__.py +4 -0
  77. accelforge/mapper.py +0 -0
  78. accelforge/model/__init__.py +1 -0
  79. accelforge/model/_looptree/__init__.py +0 -0
  80. accelforge/model/_looptree/accesses.py +335 -0
  81. accelforge/model/_looptree/capacity/__init__.py +1 -0
  82. accelforge/model/_looptree/capacity/aggregators.py +36 -0
  83. accelforge/model/_looptree/capacity/capacity.py +47 -0
  84. accelforge/model/_looptree/energy.py +150 -0
  85. accelforge/model/_looptree/equivalent_ranks.py +29 -0
  86. accelforge/model/_looptree/latency/__init__.py +1 -0
  87. accelforge/model/_looptree/latency/latency.py +98 -0
  88. accelforge/model/_looptree/latency/memory.py +120 -0
  89. accelforge/model/_looptree/latency/processors.py +92 -0
  90. accelforge/model/_looptree/mapping_utilities.py +71 -0
  91. accelforge/model/_looptree/reuse/__init__.py +4 -0
  92. accelforge/model/_looptree/reuse/isl/__init__.py +1 -0
  93. accelforge/model/_looptree/reuse/isl/des.py +59 -0
  94. accelforge/model/_looptree/reuse/isl/isl_functions.py +374 -0
  95. accelforge/model/_looptree/reuse/isl/mapping_to_isl/__init__.py +4 -0
  96. accelforge/model/_looptree/reuse/isl/mapping_to_isl/analyze_mapping.py +297 -0
  97. accelforge/model/_looptree/reuse/isl/mapping_to_isl/skews_from_mapping.py +236 -0
  98. accelforge/model/_looptree/reuse/isl/mapping_to_isl/tiling.py +685 -0
  99. accelforge/model/_looptree/reuse/isl/mapping_to_isl/types.py +188 -0
  100. accelforge/model/_looptree/reuse/isl/spatial.py +260 -0
  101. accelforge/model/_looptree/reuse/isl/temporal.py +182 -0
  102. accelforge/model/_looptree/reuse/symbolic/__init__.py +1 -0
  103. accelforge/model/_looptree/reuse/symbolic/symbolic copy 2.py +1346 -0
  104. accelforge/model/_looptree/reuse/symbolic/symbolic copy.py +1408 -0
  105. accelforge/model/_looptree/reuse/symbolic/symbolic.py +1396 -0
  106. accelforge/model/_looptree/run.py +122 -0
  107. accelforge/model/_looptree/types.py +26 -0
  108. accelforge/model/_looptree/visualization/__init__.py +0 -0
  109. accelforge/model/_looptree/visualization/occupancy.py +11 -0
  110. accelforge/model/main.py +222 -0
  111. accelforge/plotting/__init__.py +2 -0
  112. accelforge/plotting/mappings.py +219 -0
  113. accelforge/plotting/specs.py +57 -0
  114. accelforge/util/__init__.py +4 -0
  115. accelforge/util/_base_analysis_types.py +24 -0
  116. accelforge/util/_basetypes.py +1089 -0
  117. accelforge/util/_frozenset.py +36 -0
  118. accelforge/util/_isl.py +29 -0
  119. accelforge/util/_itertools.py +14 -0
  120. accelforge/util/_mathfuncs.py +57 -0
  121. accelforge/util/_parse_expressions.py +339 -0
  122. accelforge/util/_picklecache.py +32 -0
  123. accelforge/util/_setexpressions.py +268 -0
  124. accelforge/util/_sympy/__init__.py +0 -0
  125. accelforge/util/_sympy/broadcast_max.py +18 -0
  126. accelforge/util/_visualization.py +112 -0
  127. accelforge/util/_yaml.py +579 -0
  128. accelforge/util/parallel.py +193 -0
  129. accelforge-0.0.1.dist-info/METADATA +64 -0
  130. accelforge-0.0.1.dist-info/RECORD +258 -0
  131. accelforge-0.0.1.dist-info/WHEEL +5 -0
  132. accelforge-0.0.1.dist-info/licenses/LICENSE +19 -0
  133. accelforge-0.0.1.dist-info/top_level.txt +5 -0
  134. docs/_build/html/_sources/fastfusion.frontend.mapper.rst.txt +37 -0
  135. docs/_build/html/_sources/fastfusion.frontend.rst.txt +70 -0
  136. docs/_build/html/_sources/fastfusion.frontend.workload.rst.txt +21 -0
  137. docs/_build/html/_sources/fastfusion.mapper.FFM.rst.txt +37 -0
  138. docs/_build/html/_sources/fastfusion.mapper.rst.txt +18 -0
  139. docs/_build/html/_sources/fastfusion.rst.txt +20 -0
  140. docs/_build/html/_sources/fastfusion.util.rst.txt +21 -0
  141. docs/_build/html/_sources/index.rst.txt +87 -0
  142. docs/_build/html/_sources/modules.rst.txt +7 -0
  143. docs/_build/html/_sources/notes/citation.rst.txt +45 -0
  144. docs/_build/html/_sources/notes/definitions.rst.txt +43 -0
  145. docs/_build/html/_sources/notes/faqs.rst.txt +39 -0
  146. docs/_build/html/_sources/notes/modeling/accelerator_energy_latency.rst.txt +72 -0
  147. docs/_build/html/_sources/notes/modeling/component_energy_area.rst.txt +96 -0
  148. docs/_build/html/_sources/notes/modeling/mapping.rst.txt +100 -0
  149. docs/_build/html/_sources/notes/modeling.rst.txt +33 -0
  150. docs/_build/html/_sources/notes/parsing/arithmetic_parsing.rst.txt +136 -0
  151. docs/_build/html/_sources/notes/parsing/setexpressions.rst.txt +63 -0
  152. docs/_build/html/_sources/notes/parsing/yaml_parsing.rst.txt +176 -0
  153. docs/_build/html/_sources/notes/quickstart_and_installation.rst.txt +9 -0
  154. docs/_build/html/_sources/notes/spec/architecture.rst.txt +133 -0
  155. docs/_build/html/_sources/notes/spec/mapping.rst.txt +12 -0
  156. docs/_build/html/_sources/notes/spec/workload.rst.txt +83 -0
  157. docs/_build/html/_sources/notes/spec.rst.txt +36 -0
  158. docs/source/_ext/include_attrs.py +213 -0
  159. docs/source/_ext/include_docstring.py +364 -0
  160. docs/source/_ext/include_functions.py +154 -0
  161. docs/source/_ext/include_notebook.py +131 -0
  162. docs/source/_ext/include_yaml.py +119 -0
  163. docs/source/_ext/inherited_attributes.py +222 -0
  164. docs/source/_ext/paths.py +4 -0
  165. docs/source/conf.py +79 -0
  166. examples/arches/compute_in_memory/_include.yaml +74 -0
  167. examples/arches/compute_in_memory/_include_functions.py +229 -0
  168. examples/arches/compute_in_memory/_load_spec.py +57 -0
  169. examples/arches/compute_in_memory/components/c2c_multiplier.py +181 -0
  170. examples/arches/compute_in_memory/components/dac_c2c_r2r.py +605 -0
  171. examples/arches/compute_in_memory/components/misc.py +195 -0
  172. examples/arches/compute_in_memory/components/util/bit_functions.py +51 -0
  173. examples/arches/compute_in_memory/components/zero_comparator.py +92 -0
  174. examples/arches/compute_in_memory/isaac.yaml +233 -0
  175. examples/arches/compute_in_memory/memory_cells/ecram_demo.yaml +63 -0
  176. examples/arches/compute_in_memory/memory_cells/rram_example.yaml +63 -0
  177. examples/arches/compute_in_memory/memory_cells/rram_isaac_isca_2016.yaml +64 -0
  178. examples/arches/compute_in_memory/memory_cells/rram_neurosim_default.yaml +63 -0
  179. examples/arches/compute_in_memory/memory_cells/rram_raella_isca_2023.yaml +70 -0
  180. examples/arches/compute_in_memory/memory_cells/rram_wan_nature_2022.yaml +63 -0
  181. examples/arches/compute_in_memory/memory_cells/sram_colonnade_jssc_2021.yaml +63 -0
  182. examples/arches/compute_in_memory/memory_cells/sram_example.yaml +63 -0
  183. examples/arches/compute_in_memory/memory_cells/sram_jia_jssc_2020.yaml +63 -0
  184. examples/arches/compute_in_memory/memory_cells/sram_sinangil_jssc_2021.yaml +63 -0
  185. examples/arches/compute_in_memory/memory_cells/sram_wang_vlsi_2022.yaml +63 -0
  186. examples/arches/compute_in_memory/wang_vlsi_2022.yaml +289 -0
  187. examples/arches/eyeriss.yaml +68 -0
  188. examples/arches/fanout_variations/at_glb.yaml +31 -0
  189. examples/arches/fanout_variations/at_glb_with_fanout_node.yaml +34 -0
  190. examples/arches/fanout_variations/at_mac.yaml +31 -0
  191. examples/arches/fanout_variations/at_mac_with_constraints.yaml +38 -0
  192. examples/arches/fanout_variations/at_mac_with_fanout_node.yaml +34 -0
  193. examples/arches/nvdla.yaml +47 -0
  194. examples/arches/simple.yaml +28 -0
  195. examples/arches/tpu_v4i.yaml +67 -0
  196. examples/mappings/unfused_matmuls_to_simple.yaml +33 -0
  197. examples/misc/component_annotated.yaml +33 -0
  198. examples/workloads/gpt3_6.7B.yaml +124 -0
  199. examples/workloads/matmuls.yaml +20 -0
  200. examples/workloads/mobilenet_28.yaml +81 -0
  201. examples/workloads/mobilenet_various_separate.yaml +106 -0
  202. examples/workloads/three_matmuls_annotated.yaml +59 -0
  203. notebooks/.ipynb_checkpoints/fastfusion_arch_study_michael-checkpoint.ipynb +359 -0
  204. notebooks/compute_in_memory/_scripts.py +339 -0
  205. notebooks/compute_in_memory/isaac.guide.ipynb +270 -0
  206. notebooks/compute_in_memory/wang_vlsi_2022.ipynb +602 -0
  207. notebooks/paths.py +4 -0
  208. notebooks/tutorials/.ipynb_checkpoints/1_FFM-checkpoint.ipynb +3110 -0
  209. notebooks/tutorials/FFM.ipynb +3498 -0
  210. notebooks/tutorials/_include.py +48 -0
  211. notebooks/tutorials/component_energy_area.ipynb +363 -0
  212. tests/Q_mapping.yaml +38 -0
  213. tests/__init__.py +0 -0
  214. tests/conv.mapping.yaml +27 -0
  215. tests/conv.workload.yaml +13 -0
  216. tests/conv_sym.mapping.yaml +43 -0
  217. tests/copy.mapping.yaml +35 -0
  218. tests/copy.workload.yaml +15 -0
  219. tests/distribuffers/__init__.py +0 -0
  220. tests/distribuffers/multicast/test_cases.yaml +482 -0
  221. tests/distribuffers/spec/binding/valid_bindings.yaml +97 -0
  222. tests/distribuffers/spec/distributed.yaml +100 -0
  223. tests/distribuffers/spec/logical_arch.yaml +32 -0
  224. tests/distribuffers/spec/physical_arch.yaml +69 -0
  225. tests/distribuffers/test_binding.py +48 -0
  226. tests/frontend/__init__.py +0 -0
  227. tests/frontend/test_mapping_viz.py +52 -0
  228. tests/mapper/__init__.py +0 -0
  229. tests/mapper/configs/conv1d/conv1d.mapping.yaml +31 -0
  230. tests/mapper/configs/conv1d/conv1d.workload.yaml +11 -0
  231. tests/mapper/configs/two_conv1d/two_conv1d.expected.yaml +38 -0
  232. tests/mapper/configs/two_conv1d/two_conv1d.mapping.yaml +54 -0
  233. tests/mapper/configs/two_conv1d/two_conv1d.workload.yaml +19 -0
  234. tests/mapper/test_mapping_to_isl.py +90 -0
  235. tests/mapper/test_spatial_reuse_analysis.py +67 -0
  236. tests/mapper/test_temporal_reuse_analysis.py +56 -0
  237. tests/mapper/util.py +58 -0
  238. tests/matmul.mapping.yaml +29 -0
  239. tests/matmul.workload.yaml +12 -0
  240. tests/matmul_spatial.mapping.yaml +44 -0
  241. tests/mha.renames.yaml +65 -0
  242. tests/mha.workload.yaml +67 -0
  243. tests/mha.yaml +59 -0
  244. tests/mha_full.workload.yaml +67 -0
  245. tests/mobilenet.workload.yaml +35 -0
  246. tests/mobilenet_long.workload.yaml +64 -0
  247. tests/pmappingcache.py +24 -0
  248. tests/processing_stage.arch.yaml +40 -0
  249. tests/snowcat.arch.yaml +36 -0
  250. tests/test_ffm_join_pmappings.py +106 -0
  251. tests/test_ffm_make_pmappings.py +82 -0
  252. tests/test_ffm_make_tile_shapes.py +49 -0
  253. tests/test_mapper.py +100 -0
  254. tests/test_model.py +37 -0
  255. tests/test_plotting.py +72 -0
  256. tests/test_processing_stage.py +46 -0
  257. tests/test_symbolic_model.py +248 -0
  258. tests/test_workload.py +141 -0
accelforge/__init__.py ADDED
@@ -0,0 +1,21 @@
1
+ from accelforge.frontend import arch
2
+ from accelforge.frontend import config
3
+ from accelforge.frontend import mapping
4
+ from accelforge.frontend import renames
5
+ from accelforge.frontend import spec
6
+ from accelforge.frontend import variables
7
+ from accelforge.frontend import workload
8
+ from accelforge.frontend.spec import Spec, Spec
9
+ from accelforge.mapper.FFM import Metrics
10
+ from accelforge.util import set_n_parallel_jobs
11
+ from accelforge.util import LiteralString
12
+ import accelforge.mapper as mapper
13
+ from accelforge.examples import examples
14
+
15
+ from accelforge.frontend.variables import Variables
16
+ from accelforge.frontend.arch import Arch
17
+ from accelforge.frontend.config import Config
18
+ from accelforge.frontend.mapping import Mapping
19
+ from accelforge.frontend.renames import Renames
20
+ from accelforge.frontend.spec import Spec
21
+ from accelforge.frontend.workload import Workload
@@ -0,0 +1,16 @@
1
import os

# NOTE(review): the flag is overwritten with "0" on every import, so whatever
# value the caller exported is discarded and the GPU branch below can never be
# taken. This looks like a deliberate kill switch; confirm before relaxing it.
os.environ["ACCELFORGE_ACCELERATED_IMPORTS"] = "0"

# True only when the (currently always-"0") flag asks for GPU backends.
ACCELERATED = os.environ.get("ACCELFORGE_ACCELERATED_IMPORTS", "0") == "1"

if not ACCELERATED:
    # Regular CPU libraries.
    import pandas as pd
    import numpy as np
    import scipy
else:
    # GPU libraries bound to the same names the rest of the package imports.
    import cudf as pd
    import cupy as np
    import cupy as scipy
@@ -0,0 +1,271 @@
1
+ from collections import defaultdict
2
+ import itertools
3
+ import time
4
+ from fastfusion._accelerated_imports import pd
5
+ from fastfusion.mapper.FFM._join_pmappings.sim import PmappingGroup, Loop, Compatibility
6
+ from fastfusion.mapper.FFM._join_pmappings.pmapping_group import PmappingDataframe
7
+ from fastfusion.mapper.simanneal.mapspaceglobals import MapspaceGlobals
8
+ from fastfusion.util._frozenset import fzs
9
+
10
+
11
def mapping2sims(einsum_to_result: Compatibility):
    """Apply ``paretofy`` to every (key, value) pair of every Einsum's mapping.

    ``einsum_to_result`` maps each Einsum name to an inner mapping. The result
    is a list with one entry per Einsum, in the outer mapping's iteration
    order; each entry is the list of ``paretofy(key, value)`` results for that
    Einsum's inner mapping.

    NOTE(review): ``paretofy`` is neither defined nor imported in the visible
    part of this module; confirm where it is meant to come from.
    """
    return [
        [paretofy(compat, payload) for compat, payload in buckets.items()]
        for buckets in einsum_to_result.values()
    ]
16
+
17
+
18
def get_possible_translations(
    t: Compatibility,
    pairwise_equivalent_rank_variables: dict[str, set[str]],
    full_equivalent_rank_variables: dict[str, set[str]],
    right_rank_variables: set[str],
):
    """Yield copies of ``t`` whose loops are renamed to the next Einsum's ranks.

    Args:
        t: Compatibility whose ``loops`` are translated.
        pairwise_equivalent_rank_variables: For each rank variable, the rank
            variables it is directly equivalent to.
        full_equivalent_rank_variables: For each rank variable, the rank
            variables it is (transitively) equivalent to.
        right_rank_variables: Rank variables of the next (right) Einsum.

    Yields:
        ``t.update(loops=...)`` for every combination of per-loop
        translations. Nothing is yielded if any loop has no translation.
    """

    # Fused ranks should be transitive, but if a fused loop indexes into two
    # different ranks in the next Einsum, we can't fuse because it will tile in
    # multiple directions.
    #
    # The first union checks what loops we CAN fuse with in the next Einsum. The
    # second union checks what loops MUST index into in the next Einsum. If we
    # alias into multiple ranks, we can't fuse. Otherwise, try out each
    # possible rank.
    def translate_loop(l: Loop):
        names = l.rank_variable_names
        # set().union(...) rather than set.union(...): the latter raises
        # TypeError when there are no operands or when the first operand is
        # not exactly a ``set`` (e.g. a frozenset).
        compatible_rank_variables = (
            set().union(*(full_equivalent_rank_variables[n] for n in names))
            & right_rank_variables
        )
        pairwise_compatible_rank_variables = (
            set().union(*(pairwise_equivalent_rank_variables[n] for n in names))
            & right_rank_variables
        )
        if len(pairwise_compatible_rank_variables) > 1:
            # Ambiguous: the loop would have to index several ranks at once.
            return
        for n in compatible_rank_variables:
            yield Loop(fzs((n,)), l.bound, l.is_spatial)

    # Cartesian product over the per-loop candidates; empty as soon as one loop
    # has no candidate.
    for loops in itertools.product(*map(translate_loop, t.loops)):
        yield t.update(loops=loops)
53
+
54
+
55
# Timestamp (from time.time()) of the most recent timing checkpoint; reset by
# init_print_time() and advanced by print_time().
prev_time = 0
# Seconds accumulated per label by print_time(); missing labels start at 0.
total_time = defaultdict(int)
57
+
58
+
59
def init_print_time():
    """Restart the module-level timing state.

    Discards all accumulated per-label totals and sets the checkpoint
    timestamp to the current ``time.time()``.
    """
    global prev_time, total_time
    total_time = defaultdict(int)
    prev_time = time.time()
63
+
64
+
65
def print_time(what: str):
    """Print and record the time elapsed since the previous checkpoint.

    The elapsed seconds are printed under the label ``what``, added to that
    label's running total in ``total_time``, and the checkpoint is then moved
    to the current time.
    """
    global prev_time
    elapsed = time.time() - prev_time
    print(f"{what}: {elapsed:.2f} seconds")
    total_time[what] += elapsed
    # Re-read the clock so the print/bookkeeping above is not charged to the
    # next measured step.
    prev_time = time.time()
71
+
72
+
73
def print_total_time():
    """Print every label's accumulated time followed by the grand total.

    The grand total is additionally shown in minutes when it exceeds 60
    seconds.
    """
    print("\n======== Total time ========")
    for label, seconds in total_time.items():
        print(f"{label}: {seconds:.2f} seconds")

    total = sum(total_time.values())
    summary = f"\nTotal: {total:.2f} seconds"
    if total > 60:
        summary += f" ({total/60:.2f} minutes)"
    print(summary)
    print("============================\n")
83
+
84
+
85
class PmappingsOneEinsum:
    """The pmapping groups of one Einsum, together with its tensor names.

    Indexing an instance indexes straight into ``pmapping_groups``.
    """

    def __init__(self, einsum_name: str, pm_group_list: list[PmappingGroup]):
        self.einsum_name: str = einsum_name
        self.pmapping_groups: list[PmappingGroup] = pm_group_list
        # Tensor names are copied from the first group only, so an empty
        # ``pm_group_list`` raises IndexError here.
        first_group = pm_group_list[0]
        self.tensor_names: set[str] = set(first_group.tensor_names)

    def __getitem__(self, i):
        groups = self.pmapping_groups
        return groups[i]
93
+
94
+
95
def make_full_equivalent_rank_variables(pairwise_equivalent_rank_variables):
    """Expand pairwise rank-variable equivalences into their transitive closure.

    Args:
        pairwise_equivalent_rank_variables: Maps each rank variable to an
            iterable of rank variables it is directly equivalent to. Every
            rank variable that appears as a value must also be a key
            (otherwise KeyError is raised).

    Returns:
        A new dict of sets: for each key, everything reachable by following
        equivalences any number of times. The input mapping is not modified.
    """
    closure = {
        rank: set(equivalents)
        for rank, equivalents in pairwise_equivalent_rank_variables.items()
    }
    grew = True
    # Repeat full passes until one of them adds nothing.
    while grew:
        grew = False
        for members in closure.values():
            # Snapshot: ``members`` grows while we walk it.
            for neighbor in tuple(members):
                missing = closure[neighbor] - members
                if missing:
                    members |= missing
                    grew = True
    return closure
110
+
111
+
112
def quick_join(
    pmapping_groups: dict[str, list[PmappingGroup]],
    mapspace_globals: MapspaceGlobals,
):
    """Join the pmapping groups of all Einsums, left to right, into mappings.

    The Einsums are processed in the iteration order of ``pmapping_groups``.
    Each Einsum's groups are first consolidated and grouped on their own; the
    running "left" result is then merged with each following Einsum in turn.

    Args:
        pmapping_groups: Maps each Einsum name to its list of PmappingGroups.
        mapspace_globals: Supplies ``resource2capacity``,
            ``pairwise_equivalent_ranks``, ``full_equivalent_ranks``,
            ``aliased_tensors`` and ``einsum2ranks``.

    Returns:
        The ``mappings`` attribute of the single group left after the final
        consolidation.

    Raises:
        ValueError: If ``pmapping_groups`` is empty, or if at some step no
            left group can be merged with any right group.
        AssertionError: If the final combination does not produce exactly one
            group.
    """
    resource2capacity = mapspace_globals.resource2capacity
    pairwise_equivalent_rank_variables = mapspace_globals.pairwise_equivalent_ranks
    aliased_tensors = mapspace_globals.aliased_tensors
    full_equivalent_rank_variables = mapspace_globals.full_equivalent_ranks

    einsum_items = list(pmapping_groups.items())

    # Resets the module-level timing state (kept for its side effect).
    init_print_time()

    holders = [PmappingsOneEinsum(name, groups) for name, groups in einsum_items]

    if not holders:
        raise ValueError("No PmappingGroups to join")

    # ======================================================================
    # Initial consolidate and group all PmappingGroups
    # ======================================================================
    for i, sim_holder in enumerate(holders):
        # Tensors of every Einsum that comes after this one.
        right_tensors = set().union(*[s.tensor_names for s in holders[i + 1 :]])
        if i == 0:
            sim_holder.pmapping_groups = PmappingGroup.left_consolidate(
                sim_holder.pmapping_groups,
                right_tensors,
            )
            continue
        # Tensors of every Einsum that comes before this one.
        left_tensors = set().union(*[s.tensor_names for s in holders[:i]])
        shared_tensors = left_tensors & sim_holder.tensor_names
        sim_holder.pmapping_groups = sorted(
            sim_holder.pmapping_groups, key=lambda x: len(x.mappings.data), reverse=True
        )
        sim_holder.pmapping_groups = PmappingGroup.right_consolidate(
            sim_holder.pmapping_groups,
            right_tensors,
            shared_tensors,
        )
        sim_holder.pmapping_groups = PmappingGroup.combine_combineable(
            sim_holder.pmapping_groups,
            left_tensors | right_tensors,
        )
        sim_holder.pmapping_groups = PmappingGroup.group_right(
            sim_holder.pmapping_groups, left_tensors, drop_tags=True
        )

    # The first Einsum seeds the running left-hand side.
    left = holders.pop(0).pmapping_groups

    while holders:
        # ======================================================================
        # Grab new Einsum from the right and find the tensors that will still
        # be live after this Einsum.
        # ======================================================================
        holder = holders.pop(0)
        right = holder.pmapping_groups
        right_tensors = holder.tensor_names
        right_rank_variables = mapspace_globals.einsum2ranks[holder.einsum_name]

        live_tensors = set().union(*[s.tensor_names for s in holders])
        live_tensors_with_right = live_tensors | right_tensors

        # ======================================================================
        # Clean up the previously-combined PmappingGroups. Consolidate, combine,
        # group them into buckets.
        # ======================================================================
        left = PmappingGroup.combine_combineable(
            left,
            live_tensors | right_tensors,
        )

        # Group left into buckets keyed by compatibility (right was grouped in
        # the initial pass above).
        left = PmappingGroup.group_left(left, right_tensors, drop_tags=True)

        # ======================================================================
        # Remove dead tensors from left and right. This happens after grouping
        # because we only reserve space for shared tensors after it's dead. This
        # is in case the tensor lifetime extends beyond the Einsums for which it
        # is used.
        # ======================================================================
        PmappingGroup.remove_dead_tensors(
            [s for lr in [left, right] for v in lr.values() for s in v], live_tensors
        )

        # ======================================================================
        # Merge the left and right buckets.
        # ======================================================================
        combined: list[PmappingGroup] = []
        for k in left:
            for k_translated in get_possible_translations(
                k,
                pairwise_equivalent_rank_variables,
                full_equivalent_rank_variables,
                right_rank_variables,
            ):
                for a, b in itertools.product(left[k], right.get(k_translated, [])):
                    if a.compatibility.tags.are_compatible_with(b.compatibility.tags):
                        combined.append(
                            a.merge_next(
                                b,
                                live_tensors,
                                live_tensors_with_right,
                                aliased_tensors,
                                resource2capacity,
                                delay=False,
                            )
                        )

        if not combined:
            raise ValueError("No match found for any group")

        # ======================================================================
        # Update left for the next iteration.
        # ======================================================================
        left = combined

    # ======================================================================
    # Final consolidate and group
    # ======================================================================
    left = PmappingGroup.left_consolidate(left, None)
    s_final = PmappingGroup.combine_combineable(left, set(), drop_tags=True)
    assert len(s_final) == 1

    return s_final[0].mappings
@@ -0,0 +1,298 @@
1
+ from collections import defaultdict
2
+ import itertools
3
+
4
+ from fastfusion.frontend import arch
5
+ from fastfusion.frontend.spec import Spec
6
+ from fastfusion.mapper.FFM._join_pmappings.join_pmappings import PmappingGroup
7
+ from fastfusion.mapper.FFM._join_pmappings.compatibility import Loop, Compatibility
8
+ from fastfusion.util._frozenset import fzs
9
+ from fastfusion.mapper.FFM._join_pmappings.join_pmappings import (
10
+ make_full_equivalent_rank_variables,
11
+ )
12
+
13
+
14
+ class MapspaceGlobals:
15
def __init__(
    self,
    pmapping_groups: dict[str, list[PmappingGroup]],
    spec: Spec,
    objective_function_cols: list[str] = None,
    flattened_architecture: list[arch.Leaf] = None,
):
    """Precompute the lookup tables shared by the mapspace search.

    Args:
        pmapping_groups: Maps each Einsum name to its list of PmappingGroups.
            NOTE: the lists are pruned IN PLACE below (groups that cannot be
            matched with another Einsum are removed from the caller's lists).
        spec: Source of the workload description and, unless
            ``flattened_architecture`` is given, of the flattened architecture.
        objective_function_cols: Stored as-is on the instance.
        flattened_architecture: Optional pre-flattened architecture; when
            falsy, ``spec.get_flattened_architecture()`` is used instead.

    Raises:
        AssertionError: If pruning leaves one side of an Einsum pair with no
            groups (or a side had none to begin with).
        ZeroDivisionError: If the summed ``len(s.mappings.data)`` over all
            remaining groups is 0.
    """
    self.pmapping_groups = pmapping_groups
    self.einsum_names = spec.workload.einsum_names
    # Per-Einsum rank variables and tensor names, keyed by Einsum name.
    self.einsum2ranks = {
        einsum_name: spec.workload.einsums[einsum_name].rank_variables
        for einsum_name in self.einsum_names
    }
    self.einsum2tensors = {
        einsum_name: spec.workload.einsums[einsum_name].tensor_names
        for einsum_name in self.einsum_names
    }
    # Every tensor name used by any Einsum.
    self.tensor_names = set().union(
        *(self.einsum2tensors[e] for e in self.einsum_names)
    )
    self.tensor_names_used_in_multiple_einsums = (
        spec.workload.tensor_names_used_in_multiple_einsums
    )
    self.pairwise_equivalent_ranks = (
        spec.workload.get_pairwise_equivalent_rank_variables()
    )
    self.full_equivalent_ranks = make_full_equivalent_rank_variables(
        self.pairwise_equivalent_ranks
    )

    # Memory name -> size, for every arch.Memory node in the flattened arch.
    self.resource2capacity = {}
    flattened_architecture = (
        flattened_architecture or spec.get_flattened_architecture()
    )
    for l in flattened_architecture:
        if isinstance(l, arch.Memory):
            self.resource2capacity[l.name] = l.size

    self.objective_function_cols = objective_function_cols
    self.rank_translations = self._create_rank_translations(self.einsum2ranks)

    # Prune, for every ordered pair of Einsums (i < j) that have tensors in
    # common, the groups on either side that have no matching counterpart on
    # the other side.
    for i, (left_id, left_sims) in enumerate(pmapping_groups.items()):
        for j, (right_id, right_sims) in enumerate(pmapping_groups.items()):
            if i >= j:
                continue

            # Tensors of Einsums 0..i and of Einsums j..end, respectively.
            left_live = self.get_live_tensors(*self.einsum_names[: i + 1])
            right_live = self.get_live_tensors(*self.einsum_names[j:])
            left_tensors = self.get_tensors(self.einsum_names[i])
            right_tensors = self.get_tensors(self.einsum_names[j])

            # Nothing shared between the two sides: nothing to check.
            if not (left_live & right_live):
                continue
            print(f"Checking {left_id} {right_id}")

            # Right-side compatibilities reduced to what the left side can see.
            right_tilings = {
                s.compatibility.clear_dead_tensors(
                    live_tensors=left_live
                ).clear_dead_tensors(left_tensors, keep_loops=True)
                for s in right_sims
            }
            assert right_tilings, f"R {left_id} {right_id}"
            # Iterate over a copy because non-matching groups are removed from
            # left_sims inside the loop.
            for s in list(left_sims):
                for t in self.get_possible_translations(s.compatibility, right_id):
                    t = t.clear_dead_tensors(live_tensors=right_live)
                    t = t.clear_dead_tensors(
                        live_tensors=right_tensors, keep_loops=True
                    )
                    if t in right_tilings:
                        break
                else:
                    # for/else: no translation of s matched any right tiling.
                    left_sims.remove(s)
            assert (
                left_sims
            ), f"Removed all of left {left_id} while checking right {right_id}"

            # Same check in the other direction, using the already-pruned left.
            left_tilings = {
                s.compatibility.clear_dead_tensors(
                    live_tensors=right_live
                ).clear_dead_tensors(right_tensors, keep_loops=True)
                for s in left_sims
            }
            assert left_tilings, f"L {left_id} {right_id}"
            for s in list(right_sims):
                for t in self.get_possible_translations(s.compatibility, left_id):
                    t = t.clear_dead_tensors(live_tensors=left_live)
                    t = t.clear_dead_tensors(
                        live_tensors=left_tensors, keep_loops=True
                    )
                    if t in left_tilings:
                        break
                else:
                    # for/else: no translation of s matched any left tiling.
                    right_sims.remove(s)
            assert (
                right_sims
            ), f"Removed all of right {right_id} while checking left {left_id}"

    # Derived lookup tables, built from the pruned groups.
    self.tensor2possible_loops_above = self._create_tensor2possible_loops_above()
    # Same nested mapping, with the innermost values converted to sets.
    self.tensor2possible_loops_above_set = {
        k: {k2: set(v2) for k2, v2 in v.items()}
        for k, v in self.tensor2possible_loops_above.items()
    }
    self.tensor2memories = self._create_tensor2memories()
    self.einsum_tiling_2_sim = self._create_einsum_tiling_2_sim()
    self.einsum_rank_index_to_loops = self._create_einsum_rank_index_to_loops()
    (
        self.compatibility2leftcompatibility,
        self.compatibility2rightcompatibility,
        self.leftcompatibility2tiling,
        self.rightcompatibility2tiling,
    ) = self._create_compatibility()
    # Number of Einsums.
    self.size_scale = len(self.einsum2ranks)
    n_optimal = sum(
        len(s.mappings.data)
        for simlist in self.pmapping_groups.values()
        for s in simlist
    )
    n_pmappings = sum(
        s.mappings.n_pmappings
        for simlist in self.pmapping_groups.values()
        for s in simlist
    )
    # Ratio of total n_pmappings to total len(mappings.data) across all groups.
    self.find_pmapping_scale = n_pmappings / n_optimal
    self.aliased_tensors = spec.workload.get_tensor_copies()
139
+
140
def get_live_tensors(self, *einsums: str):
    """Return the union of the tensor names of the given Einsums.

    Args:
        *einsums: Einsum names; each must be a key of ``self.einsum2tensors``
            (otherwise KeyError is raised).

    Returns:
        A new ``set`` with every tensor name of the listed Einsums; an empty
        set when no Einsum is given.
    """
    # set().union(...) instead of set.union(...): works with zero operands and
    # with any iterable of names, and always returns a fresh set.
    return set().union(*(self.einsum2tensors[e] for e in einsums))
142
+
143
+ def _create_compatibility(self):
144
+ tiling2leftcompatibility = {}
145
+ tiling2rightcompatibility = {}
146
+
147
+ def tilings2compatibility(tilings: list[Compatibility], live_tensors: set[str]):
148
+ return {t: t.clear_dead_tensors(live_tensors=live_tensors) for t in tilings}
149
+
150
+ for i, (einsum_name, pm_group_list) in enumerate(self.pmapping_groups.items()):
151
+ if i > 0:
152
+ prev_live = self.get_live_tensors(*self.einsum_names[:i])
153
+ tiling2leftcompatibility[einsum_name] = tilings2compatibility(
154
+ [s.compatibility for s in pm_group_list],
155
+ prev_live,
156
+ )
157
+ if i < len(self.pmapping_groups) - 1:
158
+ next_live = self.get_live_tensors(*self.einsum_names[i + 1 :])
159
+ tiling2rightcompatibility[einsum_name] = tilings2compatibility(
160
+ [s.compatibility for s in pm_group_list],
161
+ next_live,
162
+ )
163
+
164
+ leftcompatibility2tiling = {}
165
+ rightcompatibility2tiling = {}
166
+ for einsum_name in self.einsum_names:
167
+ for src, dst in (
168
+ (tiling2leftcompatibility, leftcompatibility2tiling),
169
+ (tiling2rightcompatibility, rightcompatibility2tiling),
170
+ ):
171
+ if einsum_name not in src:
172
+ continue
173
+ dst = dst.setdefault(einsum_name, {})
174
+ for k, v in src[einsum_name].items():
175
+ dst.setdefault(v, []).append(k)
176
+ return (
177
+ tiling2leftcompatibility,
178
+ tiling2rightcompatibility,
179
+ leftcompatibility2tiling,
180
+ rightcompatibility2tiling,
181
+ )
182
+
183
def _create_einsum_tiling_2_sim(self):
    """Merge, per einsum, all groups that share the same compatibility.

    Returns:
        ``{einsum_name: {compatibility: merged_group}}`` where
        ``merged_group`` is ``PmappingGroup.concat`` applied to the list of
        groups of that einsum whose ``compatibility`` is equal, in the
        order they appear in ``self.pmapping_groups[einsum_name]``.
    """
    merged_by_einsum = {}
    for einsum_name, groups in self.pmapping_groups.items():
        # Bucket the groups by compatibility, keeping first-seen order.
        buckets = defaultdict(list)
        for group in groups:
            buckets[group.compatibility].append(group)
        merged_by_einsum[einsum_name] = {
            compatibility: PmappingGroup.concat(members)
            for compatibility, members in buckets.items()
        }
    return merged_by_einsum
194
+
195
+ def _create_tensor2possible_loops_above(self):
196
+ tensor2possible_loops_above = {}
197
+ for einsum_name, pm_group_list in self.pmapping_groups.items():
198
+ tensor2possible_loops_above[einsum_name] = defaultdict(set)
199
+ for sim in pm_group_list:
200
+ for tensor in sim.compatibility.tensors:
201
+ tensor2possible_loops_above[einsum_name][tensor] |= set(
202
+ sim.compatibility.loops[: tensor.above_loop_index]
203
+ )
204
+ return {
205
+ e: {s: list(l) for s, l in d.items()}
206
+ for e, d in tensor2possible_loops_above.items()
207
+ }
208
+
209
+ def _create_tensor2memories(self):
210
+ tensor2memories = {}
211
+ for t in self.tensor_names_used_in_multiple_einsums:
212
+ possible_memories = []
213
+ for einsum_name, pm_group_list in self.pmapping_groups.items():
214
+ cur_memories = set()
215
+ if t not in pm_group_list[0].tensor_names:
216
+ continue
217
+ for sim in pm_group_list:
218
+ tensor = sim.compatibility.get_tensor_by_name(t)
219
+ cur_memories.add(tensor)
220
+ possible_memories.append(cur_memories)
221
+ if possible_memories:
222
+ tensor2memories[t] = list(set.intersection(*possible_memories))
223
+ else:
224
+ raise ValueError(f"No memories for {t}")
225
+ return tensor2memories
226
+
227
+ def _create_rank_translations(self, einsum2ranks: dict[str, set[str]]):
228
+ rank_translations = {}
229
+ for einsum_name, ranks in einsum2ranks.items():
230
+ translations = {einsum_name2: {} for einsum_name2 in self.einsum_names}
231
+ for einsum_name2, ranks2 in einsum2ranks.items():
232
+ for rank in ranks:
233
+ equiv = self.full_equivalent_ranks[rank] & ranks2
234
+ translations[einsum_name2][rank] = equiv
235
+ rank_translations[einsum_name] = {
236
+ k: {k2: list(v2) for k2, v2 in v.items()}
237
+ for k, v in translations.items()
238
+ }
239
+ return rank_translations
240
+
241
+ def _create_full_equivalent_ranks(
242
+ self, pairwise_equivalent_ranks: dict[str, set[str]]
243
+ ):
244
+ full_equivalent_ranks = {
245
+ k: set(v) for k, v in pairwise_equivalent_ranks.items()
246
+ }
247
+ changed = True
248
+ while changed:
249
+ changed = False
250
+ for r in full_equivalent_ranks:
251
+ for r2 in list(full_equivalent_ranks[r]):
252
+ for r3 in list(full_equivalent_ranks[r2]):
253
+ if r3 in full_equivalent_ranks[r]:
254
+ continue
255
+ changed = True
256
+ full_equivalent_ranks[r].add(r3)
257
+ return full_equivalent_ranks
258
+
259
def _create_einsum_rank_index_to_loops(
    self,
) -> dict[str, dict[str, dict[int, list[Loop]]]]:
    """Index every loop by einsum, rank variable name and loop position.

    Returns:
        ``{einsum_name: {rank_variable_name: {position: [loop, ...]}}}``
        where ``position`` is the index of the loop inside
        ``compatibility.loops`` and the list gathers the loops found at
        that position across all groups of the einsum (duplicates are
        kept).
    """
    index = {}
    for einsum_name, groups in self.pmapping_groups.items():
        by_rank_variable = {}
        index[einsum_name] = by_rank_variable
        for group in groups:
            for position, loop in enumerate(group.compatibility.loops):
                by_position = by_rank_variable.setdefault(
                    loop.rank_variable_name, {}
                )
                by_position.setdefault(position, []).append(loop)
    return index
272
+
273
def get_tensors(self, *einsums: str):
    """Return the union of the tensor collections of the given einsums.

    Args:
        *einsums: Keys into ``self.einsum2tensors``.

    Returns:
        A new ``set`` containing every element of
        ``self.einsum2tensors[e]`` for each ``e`` in ``einsums``. The
        result is an empty set when no einsum is given.

    Raises:
        Whatever ``self.einsum2tensors[e]`` raises for an unknown name
        (``KeyError`` if it is a plain dict).
    """
    # ``set().union(...)`` instead of the unbound ``set.union(...)``: the
    # unbound form raises TypeError for zero arguments and when the first
    # collection is not a ``set`` instance (e.g. a frozenset).
    return set().union(*(self.einsum2tensors[e] for e in einsums))
275
+
276
def get_possible_translations(self, t: Compatibility, to_einsum: str):
    """Yield every rewrite of ``t`` whose loops use ranks of ``to_einsum``.

    Each loop of ``t`` is replaced by a single-rank loop (same ``bound``
    and ``is_spatial``) for every rank of ``to_einsum`` found in
    ``self.full_equivalent_ranks`` of the loop's rank variables. One
    ``Compatibility`` is yielded per combination of per-loop choices, with
    the tensors and tags of ``t`` unchanged.

    If any loop has no candidate, nothing is yielded. That includes a loop
    whose rank variables are pairwise-equivalent to more than one rank of
    ``to_einsum``, which is given no candidates at all.

    Args:
        t: Compatibility to translate.
        to_einsum: Key into ``self.einsum2ranks`` naming the target einsum.

    Yields:
        ``Compatibility`` objects built from the translated loops.
    """
    pairwise = self.pairwise_equivalent_ranks
    transitive = self.full_equivalent_ranks
    target_ranks = self.einsum2ranks[to_einsum]

    def candidates(loop: Loop):
        names = loop.rank_variable_names
        reachable = set.union(*(transitive[name] for name in names)) & target_ranks
        direct = set.union(*(pairwise[name] for name in names)) & target_ranks
        # More than one direct match: the loop gets no candidates.
        if len(direct) <= 1:
            yield from (
                Loop(fzs((rank,)), loop.bound, loop.is_spatial)
                for rank in reachable
            )

    per_loop = [candidates(loop) for loop in t.loops]
    for translated in itertools.product(*per_loop):
        yield Compatibility(translated, t.tensors, t.tags)