accelforge 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (258) hide show
  1. accelforge/__init__.py +21 -0
  2. accelforge/_accelerated_imports.py +16 -0
  3. accelforge/_deprecate/_simanneal/evalmapping.py +271 -0
  4. accelforge/_deprecate/_simanneal/mapspaceglobals.py +298 -0
  5. accelforge/_deprecate/_simanneal/simanneal.py +666 -0
  6. accelforge/_deprecate/_simanneal/tracking.py +105 -0
  7. accelforge/_deprecate/_simanneal/wrappers.py +218 -0
  8. accelforge/_deprecate/_simanneal2/__init__.py +7 -0
  9. accelforge/_deprecate/_simanneal2/simanneal.py +493 -0
  10. accelforge/_deprecate/_simanneal2/tracking.py +116 -0
  11. accelforge/_deprecate/compatibility_util.py +181 -0
  12. accelforge/_deprecate/layerdeduplication/__init__.py +2 -0
  13. accelforge/_deprecate/layerdeduplication/group_similar_einsums.py +160 -0
  14. accelforge/_deprecate/layerdeduplication/grouped_einsums.py +84 -0
  15. accelforge/_deprecate/mapping_filter_tags/__init__.py +2 -0
  16. accelforge/_deprecate/mapping_filter_tags/ffmt.py +212 -0
  17. accelforge/_deprecate/mapping_filter_tags/onesplit.py +24 -0
  18. accelforge/_deprecate/mapping_filter_tags/util.py +24 -0
  19. accelforge/_deprecate/tags.py +69 -0
  20. accelforge/_deprecate/viz/__init__.py +0 -0
  21. accelforge/_deprecate/viz/interactive.py +159 -0
  22. accelforge/_deprecate/viz/reservationtree.py +307 -0
  23. accelforge/_deprecate/viz/ski_slope.py +88 -0
  24. accelforge/_version.py +15 -0
  25. accelforge/examples.py +39 -0
  26. accelforge/frontend/__init__.py +10 -0
  27. accelforge/frontend/_binding.py +129 -0
  28. accelforge/frontend/_workload_isl/__init__.py +2 -0
  29. accelforge/frontend/_workload_isl/_isl.py +149 -0
  30. accelforge/frontend/_workload_isl/_symbolic.py +141 -0
  31. accelforge/frontend/arch copy.py +1544 -0
  32. accelforge/frontend/arch.py +1642 -0
  33. accelforge/frontend/config.py +63 -0
  34. accelforge/frontend/mapper/__init__.py +5 -0
  35. accelforge/frontend/mapper/ffm.py +126 -0
  36. accelforge/frontend/mapper/mapper.py +7 -0
  37. accelforge/frontend/mapper/metrics.py +30 -0
  38. accelforge/frontend/mapping/__init__.py +1 -0
  39. accelforge/frontend/mapping/mapping.py +1736 -0
  40. accelforge/frontend/model.py +14 -0
  41. accelforge/frontend/renames.py +150 -0
  42. accelforge/frontend/spec copy.py +230 -0
  43. accelforge/frontend/spec.py +301 -0
  44. accelforge/frontend/variables.py +12 -0
  45. accelforge/frontend/workload.py +952 -0
  46. accelforge/mapper/FFM/__init__.py +9 -0
  47. accelforge/mapper/FFM/_join_pmappings/__init__.py +0 -0
  48. accelforge/mapper/FFM/_join_pmappings/compatibility.py +653 -0
  49. accelforge/mapper/FFM/_join_pmappings/compress_pmappings.py +140 -0
  50. accelforge/mapper/FFM/_join_pmappings/join_pmappings.py +703 -0
  51. accelforge/mapper/FFM/_join_pmappings/pmapping_dataframe.py +901 -0
  52. accelforge/mapper/FFM/_join_pmappings/pmapping_group.py +337 -0
  53. accelforge/mapper/FFM/_make_pmappings/contraints/__init__.py +0 -0
  54. accelforge/mapper/FFM/_make_pmappings/contraints/constraints.py +360 -0
  55. accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/__init__.py +1 -0
  56. accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_loops.py +373 -0
  57. accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_pmapping_templates.py +463 -0
  58. accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_reservations.py +95 -0
  59. accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_storage_order.py +382 -0
  60. accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_storages.py +155 -0
  61. accelforge/mapper/FFM/_make_pmappings/make_pmappings.py +411 -0
  62. accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/__init__.py +1 -0
  63. accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/make_pmappings_from_templates.py +407 -0
  64. accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/make_tile_shapes.py +1681 -0
  65. accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/run_model.py +170 -0
  66. accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/symbol_relations.py +174 -0
  67. accelforge/mapper/FFM/_make_pmappings/pmapper_job.py +282 -0
  68. accelforge/mapper/FFM/_pareto_df/df_convention.py +273 -0
  69. accelforge/mapper/FFM/_pareto_df/pareto copy.py +836 -0
  70. accelforge/mapper/FFM/_pareto_df/pareto.py +508 -0
  71. accelforge/mapper/FFM/data.py +61 -0
  72. accelforge/mapper/FFM/main copy.py +236 -0
  73. accelforge/mapper/FFM/main.py +208 -0
  74. accelforge/mapper/FFM/mappings.py +510 -0
  75. accelforge/mapper/FFM/pmappings.py +310 -0
  76. accelforge/mapper/__init__.py +4 -0
  77. accelforge/mapper.py +0 -0
  78. accelforge/model/__init__.py +1 -0
  79. accelforge/model/_looptree/__init__.py +0 -0
  80. accelforge/model/_looptree/accesses.py +335 -0
  81. accelforge/model/_looptree/capacity/__init__.py +1 -0
  82. accelforge/model/_looptree/capacity/aggregators.py +36 -0
  83. accelforge/model/_looptree/capacity/capacity.py +47 -0
  84. accelforge/model/_looptree/energy.py +150 -0
  85. accelforge/model/_looptree/equivalent_ranks.py +29 -0
  86. accelforge/model/_looptree/latency/__init__.py +1 -0
  87. accelforge/model/_looptree/latency/latency.py +98 -0
  88. accelforge/model/_looptree/latency/memory.py +120 -0
  89. accelforge/model/_looptree/latency/processors.py +92 -0
  90. accelforge/model/_looptree/mapping_utilities.py +71 -0
  91. accelforge/model/_looptree/reuse/__init__.py +4 -0
  92. accelforge/model/_looptree/reuse/isl/__init__.py +1 -0
  93. accelforge/model/_looptree/reuse/isl/des.py +59 -0
  94. accelforge/model/_looptree/reuse/isl/isl_functions.py +374 -0
  95. accelforge/model/_looptree/reuse/isl/mapping_to_isl/__init__.py +4 -0
  96. accelforge/model/_looptree/reuse/isl/mapping_to_isl/analyze_mapping.py +297 -0
  97. accelforge/model/_looptree/reuse/isl/mapping_to_isl/skews_from_mapping.py +236 -0
  98. accelforge/model/_looptree/reuse/isl/mapping_to_isl/tiling.py +685 -0
  99. accelforge/model/_looptree/reuse/isl/mapping_to_isl/types.py +188 -0
  100. accelforge/model/_looptree/reuse/isl/spatial.py +260 -0
  101. accelforge/model/_looptree/reuse/isl/temporal.py +182 -0
  102. accelforge/model/_looptree/reuse/symbolic/__init__.py +1 -0
  103. accelforge/model/_looptree/reuse/symbolic/symbolic copy 2.py +1346 -0
  104. accelforge/model/_looptree/reuse/symbolic/symbolic copy.py +1408 -0
  105. accelforge/model/_looptree/reuse/symbolic/symbolic.py +1396 -0
  106. accelforge/model/_looptree/run.py +122 -0
  107. accelforge/model/_looptree/types.py +26 -0
  108. accelforge/model/_looptree/visualization/__init__.py +0 -0
  109. accelforge/model/_looptree/visualization/occupancy.py +11 -0
  110. accelforge/model/main.py +222 -0
  111. accelforge/plotting/__init__.py +2 -0
  112. accelforge/plotting/mappings.py +219 -0
  113. accelforge/plotting/specs.py +57 -0
  114. accelforge/util/__init__.py +4 -0
  115. accelforge/util/_base_analysis_types.py +24 -0
  116. accelforge/util/_basetypes.py +1089 -0
  117. accelforge/util/_frozenset.py +36 -0
  118. accelforge/util/_isl.py +29 -0
  119. accelforge/util/_itertools.py +14 -0
  120. accelforge/util/_mathfuncs.py +57 -0
  121. accelforge/util/_parse_expressions.py +339 -0
  122. accelforge/util/_picklecache.py +32 -0
  123. accelforge/util/_setexpressions.py +268 -0
  124. accelforge/util/_sympy/__init__.py +0 -0
  125. accelforge/util/_sympy/broadcast_max.py +18 -0
  126. accelforge/util/_visualization.py +112 -0
  127. accelforge/util/_yaml.py +579 -0
  128. accelforge/util/parallel.py +193 -0
  129. accelforge-0.0.1.dist-info/METADATA +64 -0
  130. accelforge-0.0.1.dist-info/RECORD +258 -0
  131. accelforge-0.0.1.dist-info/WHEEL +5 -0
  132. accelforge-0.0.1.dist-info/licenses/LICENSE +19 -0
  133. accelforge-0.0.1.dist-info/top_level.txt +5 -0
  134. docs/_build/html/_sources/fastfusion.frontend.mapper.rst.txt +37 -0
  135. docs/_build/html/_sources/fastfusion.frontend.rst.txt +70 -0
  136. docs/_build/html/_sources/fastfusion.frontend.workload.rst.txt +21 -0
  137. docs/_build/html/_sources/fastfusion.mapper.FFM.rst.txt +37 -0
  138. docs/_build/html/_sources/fastfusion.mapper.rst.txt +18 -0
  139. docs/_build/html/_sources/fastfusion.rst.txt +20 -0
  140. docs/_build/html/_sources/fastfusion.util.rst.txt +21 -0
  141. docs/_build/html/_sources/index.rst.txt +87 -0
  142. docs/_build/html/_sources/modules.rst.txt +7 -0
  143. docs/_build/html/_sources/notes/citation.rst.txt +45 -0
  144. docs/_build/html/_sources/notes/definitions.rst.txt +43 -0
  145. docs/_build/html/_sources/notes/faqs.rst.txt +39 -0
  146. docs/_build/html/_sources/notes/modeling/accelerator_energy_latency.rst.txt +72 -0
  147. docs/_build/html/_sources/notes/modeling/component_energy_area.rst.txt +96 -0
  148. docs/_build/html/_sources/notes/modeling/mapping.rst.txt +100 -0
  149. docs/_build/html/_sources/notes/modeling.rst.txt +33 -0
  150. docs/_build/html/_sources/notes/parsing/arithmetic_parsing.rst.txt +136 -0
  151. docs/_build/html/_sources/notes/parsing/setexpressions.rst.txt +63 -0
  152. docs/_build/html/_sources/notes/parsing/yaml_parsing.rst.txt +176 -0
  153. docs/_build/html/_sources/notes/quickstart_and_installation.rst.txt +9 -0
  154. docs/_build/html/_sources/notes/spec/architecture.rst.txt +133 -0
  155. docs/_build/html/_sources/notes/spec/mapping.rst.txt +12 -0
  156. docs/_build/html/_sources/notes/spec/workload.rst.txt +83 -0
  157. docs/_build/html/_sources/notes/spec.rst.txt +36 -0
  158. docs/source/_ext/include_attrs.py +213 -0
  159. docs/source/_ext/include_docstring.py +364 -0
  160. docs/source/_ext/include_functions.py +154 -0
  161. docs/source/_ext/include_notebook.py +131 -0
  162. docs/source/_ext/include_yaml.py +119 -0
  163. docs/source/_ext/inherited_attributes.py +222 -0
  164. docs/source/_ext/paths.py +4 -0
  165. docs/source/conf.py +79 -0
  166. examples/arches/compute_in_memory/_include.yaml +74 -0
  167. examples/arches/compute_in_memory/_include_functions.py +229 -0
  168. examples/arches/compute_in_memory/_load_spec.py +57 -0
  169. examples/arches/compute_in_memory/components/c2c_multiplier.py +181 -0
  170. examples/arches/compute_in_memory/components/dac_c2c_r2r.py +605 -0
  171. examples/arches/compute_in_memory/components/misc.py +195 -0
  172. examples/arches/compute_in_memory/components/util/bit_functions.py +51 -0
  173. examples/arches/compute_in_memory/components/zero_comparator.py +92 -0
  174. examples/arches/compute_in_memory/isaac.yaml +233 -0
  175. examples/arches/compute_in_memory/memory_cells/ecram_demo.yaml +63 -0
  176. examples/arches/compute_in_memory/memory_cells/rram_example.yaml +63 -0
  177. examples/arches/compute_in_memory/memory_cells/rram_isaac_isca_2016.yaml +64 -0
  178. examples/arches/compute_in_memory/memory_cells/rram_neurosim_default.yaml +63 -0
  179. examples/arches/compute_in_memory/memory_cells/rram_raella_isca_2023.yaml +70 -0
  180. examples/arches/compute_in_memory/memory_cells/rram_wan_nature_2022.yaml +63 -0
  181. examples/arches/compute_in_memory/memory_cells/sram_colonnade_jssc_2021.yaml +63 -0
  182. examples/arches/compute_in_memory/memory_cells/sram_example.yaml +63 -0
  183. examples/arches/compute_in_memory/memory_cells/sram_jia_jssc_2020.yaml +63 -0
  184. examples/arches/compute_in_memory/memory_cells/sram_sinangil_jssc_2021.yaml +63 -0
  185. examples/arches/compute_in_memory/memory_cells/sram_wang_vlsi_2022.yaml +63 -0
  186. examples/arches/compute_in_memory/wang_vlsi_2022.yaml +289 -0
  187. examples/arches/eyeriss.yaml +68 -0
  188. examples/arches/fanout_variations/at_glb.yaml +31 -0
  189. examples/arches/fanout_variations/at_glb_with_fanout_node.yaml +34 -0
  190. examples/arches/fanout_variations/at_mac.yaml +31 -0
  191. examples/arches/fanout_variations/at_mac_with_constraints.yaml +38 -0
  192. examples/arches/fanout_variations/at_mac_with_fanout_node.yaml +34 -0
  193. examples/arches/nvdla.yaml +47 -0
  194. examples/arches/simple.yaml +28 -0
  195. examples/arches/tpu_v4i.yaml +67 -0
  196. examples/mappings/unfused_matmuls_to_simple.yaml +33 -0
  197. examples/misc/component_annotated.yaml +33 -0
  198. examples/workloads/gpt3_6.7B.yaml +124 -0
  199. examples/workloads/matmuls.yaml +20 -0
  200. examples/workloads/mobilenet_28.yaml +81 -0
  201. examples/workloads/mobilenet_various_separate.yaml +106 -0
  202. examples/workloads/three_matmuls_annotated.yaml +59 -0
  203. notebooks/.ipynb_checkpoints/fastfusion_arch_study_michael-checkpoint.ipynb +359 -0
  204. notebooks/compute_in_memory/_scripts.py +339 -0
  205. notebooks/compute_in_memory/isaac.guide.ipynb +270 -0
  206. notebooks/compute_in_memory/wang_vlsi_2022.ipynb +602 -0
  207. notebooks/paths.py +4 -0
  208. notebooks/tutorials/.ipynb_checkpoints/1_FFM-checkpoint.ipynb +3110 -0
  209. notebooks/tutorials/FFM.ipynb +3498 -0
  210. notebooks/tutorials/_include.py +48 -0
  211. notebooks/tutorials/component_energy_area.ipynb +363 -0
  212. tests/Q_mapping.yaml +38 -0
  213. tests/__init__.py +0 -0
  214. tests/conv.mapping.yaml +27 -0
  215. tests/conv.workload.yaml +13 -0
  216. tests/conv_sym.mapping.yaml +43 -0
  217. tests/copy.mapping.yaml +35 -0
  218. tests/copy.workload.yaml +15 -0
  219. tests/distribuffers/__init__.py +0 -0
  220. tests/distribuffers/multicast/test_cases.yaml +482 -0
  221. tests/distribuffers/spec/binding/valid_bindings.yaml +97 -0
  222. tests/distribuffers/spec/distributed.yaml +100 -0
  223. tests/distribuffers/spec/logical_arch.yaml +32 -0
  224. tests/distribuffers/spec/physical_arch.yaml +69 -0
  225. tests/distribuffers/test_binding.py +48 -0
  226. tests/frontend/__init__.py +0 -0
  227. tests/frontend/test_mapping_viz.py +52 -0
  228. tests/mapper/__init__.py +0 -0
  229. tests/mapper/configs/conv1d/conv1d.mapping.yaml +31 -0
  230. tests/mapper/configs/conv1d/conv1d.workload.yaml +11 -0
  231. tests/mapper/configs/two_conv1d/two_conv1d.expected.yaml +38 -0
  232. tests/mapper/configs/two_conv1d/two_conv1d.mapping.yaml +54 -0
  233. tests/mapper/configs/two_conv1d/two_conv1d.workload.yaml +19 -0
  234. tests/mapper/test_mapping_to_isl.py +90 -0
  235. tests/mapper/test_spatial_reuse_analysis.py +67 -0
  236. tests/mapper/test_temporal_reuse_analysis.py +56 -0
  237. tests/mapper/util.py +58 -0
  238. tests/matmul.mapping.yaml +29 -0
  239. tests/matmul.workload.yaml +12 -0
  240. tests/matmul_spatial.mapping.yaml +44 -0
  241. tests/mha.renames.yaml +65 -0
  242. tests/mha.workload.yaml +67 -0
  243. tests/mha.yaml +59 -0
  244. tests/mha_full.workload.yaml +67 -0
  245. tests/mobilenet.workload.yaml +35 -0
  246. tests/mobilenet_long.workload.yaml +64 -0
  247. tests/pmappingcache.py +24 -0
  248. tests/processing_stage.arch.yaml +40 -0
  249. tests/snowcat.arch.yaml +36 -0
  250. tests/test_ffm_join_pmappings.py +106 -0
  251. tests/test_ffm_make_pmappings.py +82 -0
  252. tests/test_ffm_make_tile_shapes.py +49 -0
  253. tests/test_mapper.py +100 -0
  254. tests/test_model.py +37 -0
  255. tests/test_plotting.py +72 -0
  256. tests/test_processing_stage.py +46 -0
  257. tests/test_symbolic_model.py +248 -0
  258. tests/test_workload.py +141 -0
@@ -0,0 +1,9 @@
1
+ from accelforge.mapper.FFM.main import (
2
+ map_workload_to_arch,
3
+ make_pmappings,
4
+ join_pmappings,
5
+ MultiEinsumPmappings,
6
+ Mappings,
7
+ )
8
+ from accelforge.frontend.mapper.metrics import Metrics
9
+ from accelforge.mapper.FFM._join_pmappings.pmapping_group import PmappingGroup
File without changes
@@ -0,0 +1,653 @@
1
+ from collections import defaultdict
2
+ from enum import Enum
3
+ from dataclasses import dataclass, replace
4
+ import itertools
5
+ from numbers import Number
6
+ from typing import Literal, TypeVar
7
+
8
+ from accelforge.frontend.workload import TensorAccess, Workload
9
+ from accelforge.frontend.mapping import (
10
+ Compute,
11
+ Loop,
12
+ Mapping,
13
+ Spatial,
14
+ TensorHolder,
15
+ Reservation as MappingReservation,
16
+ Split as MappingSplit,
17
+ TilePattern,
18
+ Loop as MappingLoop,
19
+ )
20
+ from accelforge.frontend.renames import Rank, RankVariable, TensorName
21
+ from accelforge.mapper.FFM._pareto_df.df_convention import (
22
+ make_fused_loop_col,
23
+ stride2col,
24
+ initial2col,
25
+ iterations2col,
26
+ )
27
+
28
+ from accelforge.util import _expfmt, fzs
29
+
30
+ # Abstractions:
31
+ # 1. Each tensor is stored above some loop index. 0 is the outermost loop, 1 the
32
+ # next-innermost...
33
+ # 2. All loops above any shared tensor are co-tiled and must match between PmappingGroups.
34
+
35
T = TypeVar("T", bound="Updatable")


class Updatable:
    """Mixin giving frozen dataclasses a copy-with-changes method."""

    def update(self: T, **kwargs) -> T:
        # Thin wrapper around dataclasses.replace; preserves the subclass type.
        return replace(self, **kwargs)
41
+
42
+
43
+ def _update_rename_dict(
44
+ renames: dict[str, str],
45
+ new_renames: dict[str, str],
46
+ ):
47
+ for mine, other in new_renames.items():
48
+ if mine not in renames:
49
+ renames[mine] = other
50
+ elif renames[mine] != other:
51
+ raise ValueError(
52
+ f"Renaming {mine} to {other} conflicts with {renames[mine]}"
53
+ )
54
+
55
+
56
@dataclass(frozen=True, order=True, eq=True)
class Loop(Updatable):
    """A single loop level (temporal or spatial) over one rank, together with
    the tile pattern it produces.

    Instances are immutable; all "mutators" return updated copies via
    ``Updatable.update``.
    """

    rank_name: Rank                    # rank this loop iterates over
    tile_pattern: TilePattern | None   # tile-shape info; may contain symbols
    is_spatial: bool                   # spatial (parallel) vs. temporal loop

    def __post_init__(self):
        assert isinstance(self.rank_name, Rank)
        assert isinstance(self.tile_pattern, Number | TilePattern | str | None)
        assert isinstance(self.is_spatial, bool)
        # NOTE(review): the attribute accesses below assume tile_pattern is a
        # TilePattern, although the assert above also admits Number/str/None —
        # confirm callers never construct a Loop with those.
        assert isinstance(
            self.tile_pattern.initial_tile_shape,
            Number | str | None,
        )
        assert isinstance(
            self.tile_pattern.tile_shape,
            Number | str | None,
        )

    def __repr__(self):
        return (
            f"Loop({self.rank_name.__repr__()}, {self.tile_pattern}, {self.is_spatial})"
        )

    def __str__(self):
        # Spatial loops are prefixed with "S-".
        return (
            "S-" if self.is_spatial else ""
        ) + f"{self.rank_name}-{self.tile_pattern}"

    def pydot_str(self):
        """Label used when rendering this loop in a pydot graph."""
        if self.is_spatial:
            return f"S-for R{self.rank_name} size {_expfmt(self.tile_pattern)}"
        return f"for {self.rank_name} size {_expfmt(self.tile_pattern)}"

    def to_yaml(self):
        """Serialize to a YAML-friendly dict tagged with type "loop"."""
        return {"type": "loop", **self.__dict__}

    def merge_next(self, right: "Loop") -> "Loop":
        """Merge with the next loop ``right``; tile patterns must match.

        NOTE(review): rank names are combined with ``|`` — presumably Rank
        supports a set-like union; confirm.
        """
        assert self.tile_pattern == right.tile_pattern
        return Loop(
            self.rank_name | right.rank_name,
            right.tile_pattern,
            self.is_spatial,
        )

    def clear_loop_bound(self, value=0):
        """Return a copy with the tile pattern replaced by ``value`` (default 0)."""
        return self.update(tile_pattern=value)

    def populate(self, nloop: int) -> "Loop":
        """Return a copy whose tile pattern references the dataframe columns
        (stride / initial / iteration-count) for loop index ``nloop``."""
        tile_pattern = TilePattern(
            tile_shape=stride2col(self.rank_name, nloop),
            initial_tile_shape=initial2col(self.rank_name, nloop),
            calculated_n_iterations=iterations2col(nloop),
        )
        return self.update(tile_pattern=tile_pattern)

    def _prepend_symbols(self, prepend: str) -> "Loop":
        # Prefix all symbolic names inside the tile pattern.
        return self.update(tile_pattern=self.tile_pattern._prepend_symbols(prepend))

    def clear_symbolic_tile_patterns(self) -> "Loop":
        # Strip symbolic entries from the tile pattern.
        return self.update(tile_pattern=self.tile_pattern._clear_symbols())

    def make_fused_loop_symbols(self, prefix: str) -> tuple[dict[str, str], "Loop"]:
        """Rewrite string-valued tile-pattern attributes into fused-loop column
        names under ``prefix``; returns (old->new renames, new Loop)."""
        r = {}
        new = self

        # NOTE: this inner function shadows dataclasses.replace in this scope.
        def replace(attr, new):
            g = getattr(self.tile_pattern, attr)
            if not isinstance(g, str):
                return new
            g2 = make_fused_loop_col(f"{prefix}<SEP>{g}")
            r[g] = g2
            return new.update(tile_pattern=new.tile_pattern.update(**{attr: g2}))

        for s in new.tile_pattern._symbol_attrs():
            new = replace(s, new)

        return r, new

    def add_n_iteration_symbols(self) -> "Loop":
        # Delegates to TilePattern; see its implementation for semantics.
        return self.update(tile_pattern=self.tile_pattern.add_n_iteration_symbols())

    def _rename_to_match(self, other: "Loop") -> tuple["Loop", dict[str, str]]:
        """Rename my tile-pattern symbols to match ``other``'s (also adopting
        its rank name). Returns (new Loop, symbol renames)."""
        new_tp, renames = self.tile_pattern._rename_to_match(other.tile_pattern)
        return self.update(rank_name=other.rank_name, tile_pattern=new_tp), renames
141
+
142
+
143
@dataclass(frozen=True, eq=True, order=True)
class TensorReservation(Updatable):
    """Space reserved for one tensor on one resource, positioned below the
    loops in ``loops``. Immutable; mutators return copies via
    ``Updatable.update``.
    """

    # This order is important. Above loop index should be before resource name
    # so when we sort reservations for tensors then the backing tensor holder comes
    # first.
    # Size is not included in hash or equality functions. This is because there
    # may be floating point rounding errors in reservation sizes. The other
    # attributes are sufficient to determine equality.
    # NOTE(review): no `size` field is declared on this class in this file;
    # the comment above presumably refers to a subclass or earlier revision.
    loops: tuple[Loop, ...]
    name: TensorName
    resource_name: str
    persistent: bool = False

    def __post_init__(self):
        if self.persistent:
            # Fixed grammar in the message ("tensors be" -> "tensors must be").
            assert len(self.loops) == 0, "Persistent tensors must be above all loops"

    @property
    def above_loop_index(self) -> int:
        """Loop index this reservation sits above; -1 for persistent tensors."""
        return -1 if self.persistent else len(self.loops)

    def __str__(self):
        return f"[{self.resource_name}] {self.name} below {self.loops}"

    def __repr__(self):
        # NOTE(review): repr says "Reservation" rather than "TensorReservation";
        # kept byte-identical in case downstream tooling parses it.
        return f"Reservation({repr(self.name)}, {repr(self.loops)}, {repr(self.resource_name)})"

    def pydot_str(self):
        """Label used when rendering this reservation in a pydot graph."""
        return f"[{self.resource_name}] {self.name}"

    def permute(self, permutation) -> "TensorReservation":
        """Reorder loops so new position ``i`` holds old loop ``permutation[i]``.

        (Return annotation fixed: it previously named the undefined type
        "Reservation".)
        """
        new_loops = [self.loops[permutation[i]] for i in range(len(self.loops))]
        return self.update(loops=tuple(new_loops))

    def clear_loop_bounds(self) -> "TensorReservation":
        """Return a copy with every loop's tile pattern reset to its default."""
        return self.update(loops=tuple(loop.clear_loop_bound() for loop in self.loops))

    def populate_loops(self) -> "TensorReservation":
        """Return a copy whose loops reference per-index dataframe columns."""
        return self.update(
            loops=tuple(loop.populate(nloop) for nloop, loop in enumerate(self.loops))
        )

    @staticmethod
    def get_backing_tensors(
        all_tensors: set["TensorReservation"],
    ) -> list["TensorReservation"]:
        """For each tensor name, pick its minimal (backing) reservation and
        return them all, sorted. Relies on the dataclass field order so the
        reservation with the fewest loops sorts first."""
        id2tensor = defaultdict(list)  # idiomatic: `list` over `lambda: []`
        for t in all_tensors:
            id2tensor[t.name].append(t)
        return sorted(sorted(v)[0] for v in id2tensor.values())

    def drop_loop_indices(self, loop_indices: set[int]) -> "TensorReservation":
        """Return a copy without the loops at positions in ``loop_indices``."""
        loops = tuple(l for i, l in enumerate(self.loops) if i not in loop_indices)
        return self.update(loops=loops)

    def _prepend_symbols(self, prepend: str) -> "TensorReservation":
        # Prefix all symbolic names inside every loop's tile pattern.
        return self.update(
            loops=tuple(l._prepend_symbols(prepend) for l in self.loops),
        )

    def clear_symbolic_tile_patterns(self) -> "TensorReservation":
        # Strip symbolic entries from every loop's tile pattern.
        return self.update(
            loops=tuple(l.clear_symbolic_tile_patterns() for l in self.loops),
        )

    def make_fused_loop_symbols(
        self, prefix: str
    ) -> tuple[dict[str, str], "TensorReservation"]:
        """Rewrite each loop's symbols into fused-loop column names; returns
        (old->new renames, new reservation)."""
        result = {}
        loops = []
        for l in self.loops:
            r, l = l.make_fused_loop_symbols(prefix)
            result.update(r)
            loops.append(l)
        return result, self.update(loops=tuple(loops))

    def add_n_iteration_symbols(self) -> "TensorReservation":
        # Delegates to each loop; see Loop.add_n_iteration_symbols.
        return self.update(
            loops=tuple(l.add_n_iteration_symbols() for l in self.loops),
        )

    def _rename_to_match(
        self, other: "TensorReservation"
    ) -> tuple["TensorReservation", dict[str, str]]:
        """Rename my loops' symbols pairwise to match ``other``'s loops.

        Raises ValueError (via ``_update_rename_dict``) on conflicting renames.
        """
        renames = {}
        new_loops = []
        for l_mine, l_other in zip(self.loops, other.loops):
            l_mine, new_renames = l_mine._rename_to_match(l_other)
            _update_rename_dict(renames, new_renames)
            new_loops.append(l_mine)
        return self.update(loops=tuple(new_loops)), renames
234
+
235
+
236
class SplitKind(Enum):
    """Kind of split between sub-mappings (sequential vs. pipelined)."""

    SEQUENTIAL = 0
    PIPELINE = 1
239
+
240
+
241
@dataclass(frozen=True, order=True, eq=True)
class Split:
    """A mapping Split node together with the loop index it sits above."""

    split: MappingSplit       # the split node from the frontend mapping
    above_loop_index: int     # loop index (0 = outermost) the split sits above
245
+
246
+
247
+ @dataclass(frozen=True)
248
+ class Compatibility(Updatable):
249
+ tensors: fzs[TensorReservation]
250
+ splits: fzs[Split] = fzs()
251
+ reservation_indices: fzs[int] = fzs()
252
+ check_reservation_indices: bool = True
253
+
254
+ @property
255
+ def n_loops(self) -> int:
256
+ return max([len(s.loops) for s in self.tensors], default=0)
257
+
258
+ @property
259
+ def loops(self) -> tuple[Loop, ...]:
260
+ return max([t.loops for t in self.tensors], key=len) if self.tensors else ()
261
+
262
    def _get_hash_tuple(self):
        # Canonical tuple used by __hash__/__eq__/__lt__. Note that `splits`
        # and `check_reservation_indices` are deliberately excluded.
        return self.n_loops, self.tensors, self.reservation_indices
264
+
265
    def __hash__(self):
        # Hash on the same tuple __eq__ uses, keeping hash/eq consistent.
        return hash(self._get_hash_tuple())
267
+
268
+ def __eq__(self, other):
269
+ return self._get_hash_tuple() == other._get_hash_tuple()
270
+
271
    def __post_init__(self):
        """Validate invariants: field types, and (optionally) that every split
        and tensor loop count appears in ``reservation_indices``."""
        assert isinstance(self.n_loops, int)
        assert isinstance(self.tensors, fzs)
        assert isinstance(self.splits, fzs)
        assert isinstance(self.reservation_indices, fzs)
        # Reservation indices may not point past the deepest loop.
        assert (
            max(self.reservation_indices, default=-1) <= self.n_loops
        ), f"Extra reservation indices {self.reservation_indices} are greater than n_loops {self.n_loops}"
        if self.check_reservation_indices:
            p = f"are not in reservation indices {self.reservation_indices}"
            assert all(
                i >= 0 for i in self.reservation_indices
            ), f"Reservation indices {self.reservation_indices} are not all >= 0"
            # Every split and every tensor's loop depth must sit at a
            # registered reservation point.
            assert all(
                s.above_loop_index in self.reservation_indices for s in self.splits
            ), f"Split above loop indices {self.splits} {p}"
            assert all(
                len(s.loops) in self.reservation_indices for s in self.tensors
            ), f"Tensor loops {self.tensors} {p}"
290
+
291
+ def get_backing_levels(self) -> dict[str, int]:
292
+ backings = {}
293
+ for t in self.tensors:
294
+ prev = backings.get(t.name, t.above_loop_index)
295
+ backings[t.name] = min(prev, t.above_loop_index)
296
+ return backings
297
+
298
+ @property
299
+ def tensor_names(self) -> set[str]:
300
+ return {t.name for t in self.tensors}
301
+
302
+ @property
303
+ def max_above_loop_index(self) -> int:
304
+ if len(self.tensors) == 0:
305
+ return 0
306
+ return max(s.above_loop_index for s in self.tensors)
307
+
308
+ def shared_loop_index(self, live_tensors: set[str]) -> int:
309
+ n = [l for t, l in self.get_backing_levels().items() if t in live_tensors]
310
+ return max(n) - 1 if n else -1
311
+
312
    def __len__(self) -> int:
        # NOTE(review): len(compatibility) is the max above-loop index, not the
        # number of tensors — slightly surprising; confirm callers expect this.
        return self.max_above_loop_index
314
+
315
    def _rename_to_match(
        self, other: "Compatibility"
    ) -> tuple["Compatibility", dict[str, str]]:
        """Rename my tensors' tile-pattern symbols to match ``other``'s,
        tensor-by-tensor. Returns (renamed Compatibility, symbol renames).

        Precondition: self and other are identical once symbols are cleared.
        NOTE(review): ``clear_symbolic_tile_patterns`` on Compatibility is not
        visible in this chunk — presumably defined elsewhere; confirm.
        """
        renames = {}
        assert (
            self.clear_symbolic_tile_patterns() == other.clear_symbolic_tile_patterns()
        )
        tensors = []
        for t in self.tensors:
            other_t = other.get_tensor_by_name(t.name)
            t, new_renames = t._rename_to_match(other_t)
            tensors.append(t)
            # Raises ValueError on conflicting renames across tensors.
            _update_rename_dict(renames, new_renames)

        return (
            Compatibility(
                tensors=fzs(tensors),
                splits=self.splits,
                reservation_indices=self.reservation_indices,
            ),
            renames,
        )
337
+
338
    def clear_dead_tensors(
        self,
        live_tensors: set[str] | Literal["All"],
    ) -> "Compatibility":
        """
        Return a new compatibility with "dead" tensors removed by:
        1. keeping only tensors whose names appear in `live_tensors` and
        2. trimming splits and reservation indices to the loops that remain.

        Passing the literal string "All" keeps every tensor.
        """
        # (Docstring fixed: it previously documented nonexistent `keep_loops`
        # and `keep_tensors` parameters.)
        if live_tensors == "All":
            live_tensors = self.tensor_names

        remaining_tensors = fzs(s for s in self.tensors if s.name in live_tensors)
        new_n_loops = max((len(s.loops) for s in remaining_tensors), default=0)
        new_splits = fzs(
            split for split in self.splits if split.above_loop_index < new_n_loops
        )
        # Clamp indices to the reduced loop count, then drop negatives.
        reservation_indices = fzs(
            {min(i, new_n_loops) for i in self.reservation_indices}
        )
        reservation_indices = fzs(x for x in reservation_indices if x >= 0)

        return self.update(
            tensors=remaining_tensors,
            splits=new_splits,
            reservation_indices=reservation_indices,
        )
368
+
369
    def __lt__(self, other):
        # Order on the same tuple as __eq__/__hash__ (splits excluded).
        return self._get_hash_tuple() < other._get_hash_tuple()
371
+
372
+ def __str__(self):
373
+ return self.__repr__()
374
+
375
+ def __repr__(self):
376
+ return f"Compatibility(n_loops={self.n_loops}, tensors={repr(self.tensors)}), splits={repr(self.splits)}"
377
+
378
    def _and_tensors_with_names(self, names: set[str]) -> fzs:
        # Subset of tensor reservations whose names are in `names`.
        # (Annotation fixed: this returns a frozenset of TensorReservation,
        # not a Compatibility as previously annotated.)
        return fzs(s for s in self.tensors if s.name in names)
380
+
381
    def merge_next(
        self,
        right: "Compatibility",
        live_tensors: set[str],
        mixable_ranks: dict[Rank, set[Rank]],
    ) -> "Compatibility":
        """Join this compatibility with the next one (``right``), keeping only
        ``live_tensors``. Raises ValueError if the merge is impossible.

        NOTE(review): ``_is_valid`` is not visible in this chunk — presumably
        defined elsewhere on Compatibility; confirm.
        """
        self_freed = self.clear_dead_tensors(live_tensors)
        right_freed = right.clear_dead_tensors(live_tensors)
        if self_freed.n_loops > right_freed.n_loops:
            # This can be relaxed if we have a way to do order-independent joining
            # and/or non-looptree mappings.
            raise ValueError(
                f"Can't merge. I have more loops than the next, so my dataflow can't "
                f"be carried through a LoopTree to where it's needed."
            )

        # My reservations take precedence; `right` contributes only tensors
        # I don't already hold.
        live_minus_mine = live_tensors - {s.name for s in self.tensors}
        tensors_a = self._and_tensors_with_names(live_tensors)
        tensors_b = right._and_tensors_with_names(live_minus_mine)

        # TODO: split handling?
        joined = Compatibility(
            tensors=tensors_a | tensors_b,
            reservation_indices=self_freed.reservation_indices
            | right_freed.reservation_indices,
        )

        if mixable_ranks is not None and not joined._is_valid(mixable_ranks):
            raise ValueError(f"Invalid rank mixing.")

        return joined
412
+
413
+ def has_tensor(self, *tensors: TensorReservation) -> bool:
414
+ return all(any(s == t for s in self.tensors) for t in tensors)
415
+
416
+ def _permute_stops(self):
417
+ stops = set(len(s.loops) for s in self.tensors)
418
+ stops |= self.reservation_indices
419
+ stops |= set(s.above_loop_index for s in self.splits)
420
+ return stops
421
+
422
    def permute(
        self,
        loop_changes: list[int],
    ) -> "Compatibility":
        """Permute loops so that new position ``i`` holds old loop
        ``loop_changes[i]``. Asserts the permutation never crosses a
        reservation/split boundary."""
        assert len(loop_changes) <= self.n_loops
        assert set(loop_changes) == set(
            range(len(loop_changes))
        ), f"Loop changes {loop_changes} are not a permutation of {range(len(loop_changes))}"
        # Extend a partial permutation with the identity on the tail.
        if len(loop_changes) < len(self.loops):
            loop_changes = loop_changes + list(
                range(len(loop_changes), len(self.loops))
            )

        # A loop may not move from one side of a stop to the other.
        permute_stops = self._permute_stops()
        for i, c in enumerate(loop_changes):
            for r in permute_stops:
                assert (i < r) == (
                    c < r
                ), f"Loop changes {loop_changes} cross reservation {r}"
        new_tensors = fzs(s.permute(loop_changes) for s in self.tensors)
        return self.update(tensors=new_tensors)
443
+
444
    def make_equivalent_permutations(self) -> list[tuple["Compatibility", list[int]]]:
        """Enumerate every legal loop permutation (those that do not cross a
        reservation/split boundary). Returns (permuted compatibility,
        permutation) pairs; the identity permutation is included."""
        # Get contiguous blocks of loops with no tensor reservation between them
        blocks = []
        current_block = []
        permute_stops = self._permute_stops()
        for i in range(self.n_loops):
            # Can't permute loops if there's a reservation between them
            if i in permute_stops:
                blocks.append(current_block)
                current_block = []
            current_block.append(i)
        if current_block:
            blocks.append(current_block)

        # Permute freely within each block, then take the cross product of
        # per-block choices and flatten each into a full permutation.
        per_block_permutations = [
            list(itertools.permutations(block)) for block in blocks
        ]
        all_permutations = list(itertools.product(*per_block_permutations))
        all_permutations = [
            list(itertools.chain(*loop_changes)) for loop_changes in all_permutations
        ]
        return [(self.permute(p), p) for p in all_permutations]
466
+
467
+ def get_tensor_by_name(self, tensor: str) -> TensorReservation:
468
+ for s in self.tensors:
469
+ if s.name == tensor:
470
+ return s
471
+ raise ValueError(f"No reservation found for {tensor}")
472
+
473
+ def per_tensor_compatibility(self) -> dict[str, "Compatibility"]:
474
+ result = {}
475
+ for s in self.tensors:
476
+ result[s.name] = self.clear_dead_tensors(set([s.name]))
477
+ return result
478
+
479
+ def clear_loop_bounds(self) -> "Compatibility":
480
+ return self.update(tensors=fzs(t.clear_loop_bounds() for t in self.tensors))
481
+
482
    def compatible_with(self, other: "Compatibility") -> bool:
        """Report whether this compatibility can coexist with `other`.

        NOTE(review): the check is currently disabled -- this always
        returns True. The commented-out code below compared per-tensor
        loop prefixes; confirm whether it should be re-enabled before
        relying on this method to reject anything.
        """
        return True
        # for a in self.tensors:
        #     a = a.loops
        #     for b in other.tensors:
        #         b = b.loops
        #         if a[:len(b)] != b[:len(a)]:
        #             return False
        # return True
491
+
492
+ def populate_loops(self):
493
+ return self.update(
494
+ tensors=fzs(t.populate_loops() for t in self.tensors),
495
+ )
496
+
497
+ @classmethod
498
+ def from_mapping(
499
+ cls,
500
+ mapping: Mapping,
501
+ tensors: set[TensorName],
502
+ rank_variable_to_ranks: dict[TensorName, dict[RankVariable, Rank]],
503
+ ) -> "Compatibility":
504
+ # TODO: update compatibility to handle spatial-for loop per-tensor update
505
+ tensor_indices = []
506
+ split_above_loop_indices = []
507
+ reservation_indices = []
508
+ backing_remaining = set(tensors)
509
+ n_seen_loops = 0
510
+ n_fused_loops = 0
511
+ for i, n in enumerate(mapping.nodes):
512
+ if isinstance(n, MappingReservation):
513
+ reservation_indices.append(n_seen_loops)
514
+ if not (backing := set(n.purposes) & backing_remaining):
515
+ continue
516
+ backing_remaining -= backing
517
+ assert (
518
+ len(n.purposes) == 1
519
+ ), "Backing reservations should have only one purpose"
520
+ tensor_indices.append(i)
521
+ elif isinstance(n, MappingSplit):
522
+ split_above_loop_indices.append(n_seen_loops)
523
+ elif isinstance(n, MappingLoop):
524
+ n_seen_loops += 1
525
+ n_fused_loops += bool(backing_remaining)
526
+ elif isinstance(n, TensorHolder):
527
+ reservation_indices.append(n_seen_loops)
528
+
529
+ reservation_indices = fzs(min(n, n_fused_loops) for n in reservation_indices)
530
+ reservation_indices = fzs(x for x in reservation_indices if x >= 0)
531
+
532
+ assert (
533
+ not backing_remaining
534
+ ), f"Tensors {backing_remaining} not found in mapping"
535
+
536
+ def get_rank(rank_variable, tensor):
537
+ rv = rank_variable_to_ranks[tensor].get(rank_variable, set())
538
+ assert (
539
+ len(rv) <= 1
540
+ ), f"Rank variable {rank_variable} indexes into multiple ranks {rv} for tensor {tensor} "
541
+ return next(iter(rv), Rank("NO RANK. RECOMPUTED."))
542
+
543
+ def make_loops(above_index: int, tensor_name: TensorName) -> list[MappingLoop]:
544
+ loops = [
545
+ n for n in mapping.nodes[:above_index] if isinstance(n, MappingLoop)
546
+ ]
547
+ loops = [
548
+ Loop(
549
+ rank_name=get_rank(n.rank_variable, tensor_name),
550
+ tile_pattern=n.tile_pattern._symbol2str(),
551
+ is_spatial=isinstance(n, Spatial),
552
+ )
553
+ for n in loops
554
+ ]
555
+ return tuple(loops)
556
+
557
+ return cls(
558
+ tensors=fzs(
559
+ TensorReservation(
560
+ name=mapping.nodes[i].purpose,
561
+ loops=make_loops(i, mapping.nodes[i].purpose),
562
+ resource_name=mapping.nodes[i].resource,
563
+ persistent=mapping.nodes[i].persistent,
564
+ )
565
+ for i in tensor_indices
566
+ ),
567
+ splits=fzs(
568
+ Split(split=n, above_loop_index=i) for i in split_above_loop_indices
569
+ ),
570
+ reservation_indices=fzs(reservation_indices),
571
+ )
572
+
573
+ def symbols(self) -> list[str]:
574
+ symbols = []
575
+
576
+ def add(x: str):
577
+ if isinstance(x, str) and x not in symbols:
578
+ symbols.append(x)
579
+
580
+ for t in self.tensors:
581
+ for l in t.loops:
582
+ add(l.tile_pattern.initial_tile_shape)
583
+ add(l.tile_pattern.tile_shape)
584
+ add(l.tile_pattern.calculated_n_iterations)
585
+ return symbols
586
+
587
+ def drop_loop_indices(self, loop_indices: set[int]) -> "Compatibility":
588
+ loop_indices = set(loop_indices)
589
+ tensors = fzs(t.drop_loop_indices(loop_indices) for t in self.tensors)
590
+ splits = fzs(s for s in self.splits if s.above_loop_index not in loop_indices)
591
+
592
+ def adjust(i: int) -> int:
593
+ return i - sum(x < i for x in loop_indices)
594
+
595
+ reservation_indices = fzs(adjust(i) for i in self.reservation_indices)
596
+ reservation_indices = fzs(x for x in reservation_indices if x >= 0)
597
+
598
+ splits = fzs(
599
+ s.update(above_loop_index=adjust(s.above_loop_index)) for s in self.splits
600
+ )
601
+ return Compatibility(
602
+ tensors=tensors,
603
+ splits=splits,
604
+ reservation_indices=reservation_indices,
605
+ )
606
+
607
+ def _prepend_symbols(self, prepend: str) -> "Compatibility":
608
+ return self.update(
609
+ tensors=fzs(t._prepend_symbols(prepend) for t in self.tensors)
610
+ )
611
+
612
+ def clear_tile_patterns_and_reservation_indices(self) -> "Compatibility":
613
+ return self.update(
614
+ tensors=fzs(t.clear_symbolic_tile_patterns() for t in self.tensors),
615
+ reservation_indices=fzs(),
616
+ check_reservation_indices=False,
617
+ )
618
+
619
+ def clear_symbolic_tile_patterns(self) -> "Compatibility":
620
+ return self.update(
621
+ tensors=fzs(t.clear_symbolic_tile_patterns() for t in self.tensors)
622
+ )
623
+
624
+ def make_fused_loop_symbols(
625
+ self, prefix: str
626
+ ) -> tuple[dict[str, str], "Compatibility"]:
627
+ result = {}
628
+ tensors = []
629
+ for t in self.tensors:
630
+ r, t = t.make_fused_loop_symbols(prefix)
631
+ tensors.append(t)
632
+ result.update(r)
633
+
634
+ return result, self.update(tensors=fzs(tensors))
635
+
636
+ def add_n_iteration_symbols(self) -> "Compatibility":
637
+ return self.update(
638
+ tensors=fzs(t.add_n_iteration_symbols() for t in self.tensors)
639
+ )
640
+
641
+ def _is_valid(self, mixable_ranks: dict[Rank, set[Rank]]) -> bool:
642
+ # Mixable ranks: Ranks that may be co-iterated by a single loop.
643
+ ranks_at_each_loop_index = []
644
+ for i in range(self.n_loops):
645
+ ranks_at_each_loop_index.append(
646
+ set(t.loops[i].rank_name for t in self.tensors if i < len(t.loops))
647
+ )
648
+
649
+ for ranks in ranks_at_each_loop_index:
650
+ for r1, r2 in itertools.combinations(ranks, 2):
651
+ if r1 not in mixable_ranks[r2]:
652
+ return False
653
+ return True