accelforge 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (258)
  1. accelforge/__init__.py +21 -0
  2. accelforge/_accelerated_imports.py +16 -0
  3. accelforge/_deprecate/_simanneal/evalmapping.py +271 -0
  4. accelforge/_deprecate/_simanneal/mapspaceglobals.py +298 -0
  5. accelforge/_deprecate/_simanneal/simanneal.py +666 -0
  6. accelforge/_deprecate/_simanneal/tracking.py +105 -0
  7. accelforge/_deprecate/_simanneal/wrappers.py +218 -0
  8. accelforge/_deprecate/_simanneal2/__init__.py +7 -0
  9. accelforge/_deprecate/_simanneal2/simanneal.py +493 -0
  10. accelforge/_deprecate/_simanneal2/tracking.py +116 -0
  11. accelforge/_deprecate/compatibility_util.py +181 -0
  12. accelforge/_deprecate/layerdeduplication/__init__.py +2 -0
  13. accelforge/_deprecate/layerdeduplication/group_similar_einsums.py +160 -0
  14. accelforge/_deprecate/layerdeduplication/grouped_einsums.py +84 -0
  15. accelforge/_deprecate/mapping_filter_tags/__init__.py +2 -0
  16. accelforge/_deprecate/mapping_filter_tags/ffmt.py +212 -0
  17. accelforge/_deprecate/mapping_filter_tags/onesplit.py +24 -0
  18. accelforge/_deprecate/mapping_filter_tags/util.py +24 -0
  19. accelforge/_deprecate/tags.py +69 -0
  20. accelforge/_deprecate/viz/__init__.py +0 -0
  21. accelforge/_deprecate/viz/interactive.py +159 -0
  22. accelforge/_deprecate/viz/reservationtree.py +307 -0
  23. accelforge/_deprecate/viz/ski_slope.py +88 -0
  24. accelforge/_version.py +15 -0
  25. accelforge/examples.py +39 -0
  26. accelforge/frontend/__init__.py +10 -0
  27. accelforge/frontend/_binding.py +129 -0
  28. accelforge/frontend/_workload_isl/__init__.py +2 -0
  29. accelforge/frontend/_workload_isl/_isl.py +149 -0
  30. accelforge/frontend/_workload_isl/_symbolic.py +141 -0
  31. accelforge/frontend/arch copy.py +1544 -0
  32. accelforge/frontend/arch.py +1642 -0
  33. accelforge/frontend/config.py +63 -0
  34. accelforge/frontend/mapper/__init__.py +5 -0
  35. accelforge/frontend/mapper/ffm.py +126 -0
  36. accelforge/frontend/mapper/mapper.py +7 -0
  37. accelforge/frontend/mapper/metrics.py +30 -0
  38. accelforge/frontend/mapping/__init__.py +1 -0
  39. accelforge/frontend/mapping/mapping.py +1736 -0
  40. accelforge/frontend/model.py +14 -0
  41. accelforge/frontend/renames.py +150 -0
  42. accelforge/frontend/spec copy.py +230 -0
  43. accelforge/frontend/spec.py +301 -0
  44. accelforge/frontend/variables.py +12 -0
  45. accelforge/frontend/workload.py +952 -0
  46. accelforge/mapper/FFM/__init__.py +9 -0
  47. accelforge/mapper/FFM/_join_pmappings/__init__.py +0 -0
  48. accelforge/mapper/FFM/_join_pmappings/compatibility.py +653 -0
  49. accelforge/mapper/FFM/_join_pmappings/compress_pmappings.py +140 -0
  50. accelforge/mapper/FFM/_join_pmappings/join_pmappings.py +703 -0
  51. accelforge/mapper/FFM/_join_pmappings/pmapping_dataframe.py +901 -0
  52. accelforge/mapper/FFM/_join_pmappings/pmapping_group.py +337 -0
  53. accelforge/mapper/FFM/_make_pmappings/contraints/__init__.py +0 -0
  54. accelforge/mapper/FFM/_make_pmappings/contraints/constraints.py +360 -0
  55. accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/__init__.py +1 -0
  56. accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_loops.py +373 -0
  57. accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_pmapping_templates.py +463 -0
  58. accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_reservations.py +95 -0
  59. accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_storage_order.py +382 -0
  60. accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_storages.py +155 -0
  61. accelforge/mapper/FFM/_make_pmappings/make_pmappings.py +411 -0
  62. accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/__init__.py +1 -0
  63. accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/make_pmappings_from_templates.py +407 -0
  64. accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/make_tile_shapes.py +1681 -0
  65. accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/run_model.py +170 -0
  66. accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/symbol_relations.py +174 -0
  67. accelforge/mapper/FFM/_make_pmappings/pmapper_job.py +282 -0
  68. accelforge/mapper/FFM/_pareto_df/df_convention.py +273 -0
  69. accelforge/mapper/FFM/_pareto_df/pareto copy.py +836 -0
  70. accelforge/mapper/FFM/_pareto_df/pareto.py +508 -0
  71. accelforge/mapper/FFM/data.py +61 -0
  72. accelforge/mapper/FFM/main copy.py +236 -0
  73. accelforge/mapper/FFM/main.py +208 -0
  74. accelforge/mapper/FFM/mappings.py +510 -0
  75. accelforge/mapper/FFM/pmappings.py +310 -0
  76. accelforge/mapper/__init__.py +4 -0
  77. accelforge/mapper.py +0 -0
  78. accelforge/model/__init__.py +1 -0
  79. accelforge/model/_looptree/__init__.py +0 -0
  80. accelforge/model/_looptree/accesses.py +335 -0
  81. accelforge/model/_looptree/capacity/__init__.py +1 -0
  82. accelforge/model/_looptree/capacity/aggregators.py +36 -0
  83. accelforge/model/_looptree/capacity/capacity.py +47 -0
  84. accelforge/model/_looptree/energy.py +150 -0
  85. accelforge/model/_looptree/equivalent_ranks.py +29 -0
  86. accelforge/model/_looptree/latency/__init__.py +1 -0
  87. accelforge/model/_looptree/latency/latency.py +98 -0
  88. accelforge/model/_looptree/latency/memory.py +120 -0
  89. accelforge/model/_looptree/latency/processors.py +92 -0
  90. accelforge/model/_looptree/mapping_utilities.py +71 -0
  91. accelforge/model/_looptree/reuse/__init__.py +4 -0
  92. accelforge/model/_looptree/reuse/isl/__init__.py +1 -0
  93. accelforge/model/_looptree/reuse/isl/des.py +59 -0
  94. accelforge/model/_looptree/reuse/isl/isl_functions.py +374 -0
  95. accelforge/model/_looptree/reuse/isl/mapping_to_isl/__init__.py +4 -0
  96. accelforge/model/_looptree/reuse/isl/mapping_to_isl/analyze_mapping.py +297 -0
  97. accelforge/model/_looptree/reuse/isl/mapping_to_isl/skews_from_mapping.py +236 -0
  98. accelforge/model/_looptree/reuse/isl/mapping_to_isl/tiling.py +685 -0
  99. accelforge/model/_looptree/reuse/isl/mapping_to_isl/types.py +188 -0
  100. accelforge/model/_looptree/reuse/isl/spatial.py +260 -0
  101. accelforge/model/_looptree/reuse/isl/temporal.py +182 -0
  102. accelforge/model/_looptree/reuse/symbolic/__init__.py +1 -0
  103. accelforge/model/_looptree/reuse/symbolic/symbolic copy 2.py +1346 -0
  104. accelforge/model/_looptree/reuse/symbolic/symbolic copy.py +1408 -0
  105. accelforge/model/_looptree/reuse/symbolic/symbolic.py +1396 -0
  106. accelforge/model/_looptree/run.py +122 -0
  107. accelforge/model/_looptree/types.py +26 -0
  108. accelforge/model/_looptree/visualization/__init__.py +0 -0
  109. accelforge/model/_looptree/visualization/occupancy.py +11 -0
  110. accelforge/model/main.py +222 -0
  111. accelforge/plotting/__init__.py +2 -0
  112. accelforge/plotting/mappings.py +219 -0
  113. accelforge/plotting/specs.py +57 -0
  114. accelforge/util/__init__.py +4 -0
  115. accelforge/util/_base_analysis_types.py +24 -0
  116. accelforge/util/_basetypes.py +1089 -0
  117. accelforge/util/_frozenset.py +36 -0
  118. accelforge/util/_isl.py +29 -0
  119. accelforge/util/_itertools.py +14 -0
  120. accelforge/util/_mathfuncs.py +57 -0
  121. accelforge/util/_parse_expressions.py +339 -0
  122. accelforge/util/_picklecache.py +32 -0
  123. accelforge/util/_setexpressions.py +268 -0
  124. accelforge/util/_sympy/__init__.py +0 -0
  125. accelforge/util/_sympy/broadcast_max.py +18 -0
  126. accelforge/util/_visualization.py +112 -0
  127. accelforge/util/_yaml.py +579 -0
  128. accelforge/util/parallel.py +193 -0
  129. accelforge-0.0.1.dist-info/METADATA +64 -0
  130. accelforge-0.0.1.dist-info/RECORD +258 -0
  131. accelforge-0.0.1.dist-info/WHEEL +5 -0
  132. accelforge-0.0.1.dist-info/licenses/LICENSE +19 -0
  133. accelforge-0.0.1.dist-info/top_level.txt +5 -0
  134. docs/_build/html/_sources/fastfusion.frontend.mapper.rst.txt +37 -0
  135. docs/_build/html/_sources/fastfusion.frontend.rst.txt +70 -0
  136. docs/_build/html/_sources/fastfusion.frontend.workload.rst.txt +21 -0
  137. docs/_build/html/_sources/fastfusion.mapper.FFM.rst.txt +37 -0
  138. docs/_build/html/_sources/fastfusion.mapper.rst.txt +18 -0
  139. docs/_build/html/_sources/fastfusion.rst.txt +20 -0
  140. docs/_build/html/_sources/fastfusion.util.rst.txt +21 -0
  141. docs/_build/html/_sources/index.rst.txt +87 -0
  142. docs/_build/html/_sources/modules.rst.txt +7 -0
  143. docs/_build/html/_sources/notes/citation.rst.txt +45 -0
  144. docs/_build/html/_sources/notes/definitions.rst.txt +43 -0
  145. docs/_build/html/_sources/notes/faqs.rst.txt +39 -0
  146. docs/_build/html/_sources/notes/modeling/accelerator_energy_latency.rst.txt +72 -0
  147. docs/_build/html/_sources/notes/modeling/component_energy_area.rst.txt +96 -0
  148. docs/_build/html/_sources/notes/modeling/mapping.rst.txt +100 -0
  149. docs/_build/html/_sources/notes/modeling.rst.txt +33 -0
  150. docs/_build/html/_sources/notes/parsing/arithmetic_parsing.rst.txt +136 -0
  151. docs/_build/html/_sources/notes/parsing/setexpressions.rst.txt +63 -0
  152. docs/_build/html/_sources/notes/parsing/yaml_parsing.rst.txt +176 -0
  153. docs/_build/html/_sources/notes/quickstart_and_installation.rst.txt +9 -0
  154. docs/_build/html/_sources/notes/spec/architecture.rst.txt +133 -0
  155. docs/_build/html/_sources/notes/spec/mapping.rst.txt +12 -0
  156. docs/_build/html/_sources/notes/spec/workload.rst.txt +83 -0
  157. docs/_build/html/_sources/notes/spec.rst.txt +36 -0
  158. docs/source/_ext/include_attrs.py +213 -0
  159. docs/source/_ext/include_docstring.py +364 -0
  160. docs/source/_ext/include_functions.py +154 -0
  161. docs/source/_ext/include_notebook.py +131 -0
  162. docs/source/_ext/include_yaml.py +119 -0
  163. docs/source/_ext/inherited_attributes.py +222 -0
  164. docs/source/_ext/paths.py +4 -0
  165. docs/source/conf.py +79 -0
  166. examples/arches/compute_in_memory/_include.yaml +74 -0
  167. examples/arches/compute_in_memory/_include_functions.py +229 -0
  168. examples/arches/compute_in_memory/_load_spec.py +57 -0
  169. examples/arches/compute_in_memory/components/c2c_multiplier.py +181 -0
  170. examples/arches/compute_in_memory/components/dac_c2c_r2r.py +605 -0
  171. examples/arches/compute_in_memory/components/misc.py +195 -0
  172. examples/arches/compute_in_memory/components/util/bit_functions.py +51 -0
  173. examples/arches/compute_in_memory/components/zero_comparator.py +92 -0
  174. examples/arches/compute_in_memory/isaac.yaml +233 -0
  175. examples/arches/compute_in_memory/memory_cells/ecram_demo.yaml +63 -0
  176. examples/arches/compute_in_memory/memory_cells/rram_example.yaml +63 -0
  177. examples/arches/compute_in_memory/memory_cells/rram_isaac_isca_2016.yaml +64 -0
  178. examples/arches/compute_in_memory/memory_cells/rram_neurosim_default.yaml +63 -0
  179. examples/arches/compute_in_memory/memory_cells/rram_raella_isca_2023.yaml +70 -0
  180. examples/arches/compute_in_memory/memory_cells/rram_wan_nature_2022.yaml +63 -0
  181. examples/arches/compute_in_memory/memory_cells/sram_colonnade_jssc_2021.yaml +63 -0
  182. examples/arches/compute_in_memory/memory_cells/sram_example.yaml +63 -0
  183. examples/arches/compute_in_memory/memory_cells/sram_jia_jssc_2020.yaml +63 -0
  184. examples/arches/compute_in_memory/memory_cells/sram_sinangil_jssc_2021.yaml +63 -0
  185. examples/arches/compute_in_memory/memory_cells/sram_wang_vlsi_2022.yaml +63 -0
  186. examples/arches/compute_in_memory/wang_vlsi_2022.yaml +289 -0
  187. examples/arches/eyeriss.yaml +68 -0
  188. examples/arches/fanout_variations/at_glb.yaml +31 -0
  189. examples/arches/fanout_variations/at_glb_with_fanout_node.yaml +34 -0
  190. examples/arches/fanout_variations/at_mac.yaml +31 -0
  191. examples/arches/fanout_variations/at_mac_with_constraints.yaml +38 -0
  192. examples/arches/fanout_variations/at_mac_with_fanout_node.yaml +34 -0
  193. examples/arches/nvdla.yaml +47 -0
  194. examples/arches/simple.yaml +28 -0
  195. examples/arches/tpu_v4i.yaml +67 -0
  196. examples/mappings/unfused_matmuls_to_simple.yaml +33 -0
  197. examples/misc/component_annotated.yaml +33 -0
  198. examples/workloads/gpt3_6.7B.yaml +124 -0
  199. examples/workloads/matmuls.yaml +20 -0
  200. examples/workloads/mobilenet_28.yaml +81 -0
  201. examples/workloads/mobilenet_various_separate.yaml +106 -0
  202. examples/workloads/three_matmuls_annotated.yaml +59 -0
  203. notebooks/.ipynb_checkpoints/fastfusion_arch_study_michael-checkpoint.ipynb +359 -0
  204. notebooks/compute_in_memory/_scripts.py +339 -0
  205. notebooks/compute_in_memory/isaac.guide.ipynb +270 -0
  206. notebooks/compute_in_memory/wang_vlsi_2022.ipynb +602 -0
  207. notebooks/paths.py +4 -0
  208. notebooks/tutorials/.ipynb_checkpoints/1_FFM-checkpoint.ipynb +3110 -0
  209. notebooks/tutorials/FFM.ipynb +3498 -0
  210. notebooks/tutorials/_include.py +48 -0
  211. notebooks/tutorials/component_energy_area.ipynb +363 -0
  212. tests/Q_mapping.yaml +38 -0
  213. tests/__init__.py +0 -0
  214. tests/conv.mapping.yaml +27 -0
  215. tests/conv.workload.yaml +13 -0
  216. tests/conv_sym.mapping.yaml +43 -0
  217. tests/copy.mapping.yaml +35 -0
  218. tests/copy.workload.yaml +15 -0
  219. tests/distribuffers/__init__.py +0 -0
  220. tests/distribuffers/multicast/test_cases.yaml +482 -0
  221. tests/distribuffers/spec/binding/valid_bindings.yaml +97 -0
  222. tests/distribuffers/spec/distributed.yaml +100 -0
  223. tests/distribuffers/spec/logical_arch.yaml +32 -0
  224. tests/distribuffers/spec/physical_arch.yaml +69 -0
  225. tests/distribuffers/test_binding.py +48 -0
  226. tests/frontend/__init__.py +0 -0
  227. tests/frontend/test_mapping_viz.py +52 -0
  228. tests/mapper/__init__.py +0 -0
  229. tests/mapper/configs/conv1d/conv1d.mapping.yaml +31 -0
  230. tests/mapper/configs/conv1d/conv1d.workload.yaml +11 -0
  231. tests/mapper/configs/two_conv1d/two_conv1d.expected.yaml +38 -0
  232. tests/mapper/configs/two_conv1d/two_conv1d.mapping.yaml +54 -0
  233. tests/mapper/configs/two_conv1d/two_conv1d.workload.yaml +19 -0
  234. tests/mapper/test_mapping_to_isl.py +90 -0
  235. tests/mapper/test_spatial_reuse_analysis.py +67 -0
  236. tests/mapper/test_temporal_reuse_analysis.py +56 -0
  237. tests/mapper/util.py +58 -0
  238. tests/matmul.mapping.yaml +29 -0
  239. tests/matmul.workload.yaml +12 -0
  240. tests/matmul_spatial.mapping.yaml +44 -0
  241. tests/mha.renames.yaml +65 -0
  242. tests/mha.workload.yaml +67 -0
  243. tests/mha.yaml +59 -0
  244. tests/mha_full.workload.yaml +67 -0
  245. tests/mobilenet.workload.yaml +35 -0
  246. tests/mobilenet_long.workload.yaml +64 -0
  247. tests/pmappingcache.py +24 -0
  248. tests/processing_stage.arch.yaml +40 -0
  249. tests/snowcat.arch.yaml +36 -0
  250. tests/test_ffm_join_pmappings.py +106 -0
  251. tests/test_ffm_make_pmappings.py +82 -0
  252. tests/test_ffm_make_tile_shapes.py +49 -0
  253. tests/test_mapper.py +100 -0
  254. tests/test_model.py +37 -0
  255. tests/test_plotting.py +72 -0
  256. tests/test_processing_stage.py +46 -0
  257. tests/test_symbolic_model.py +248 -0
  258. tests/test_workload.py +141 -0
@@ -0,0 +1,901 @@
1
+ from collections import defaultdict
2
+ import copy
3
+ import functools
4
+ import itertools
5
+
6
+ from typing import Callable, Iterable, Optional
7
+
8
+ import sympy
9
+
10
+ from accelforge.frontend.mapping import Nested, TilePattern
11
+ from accelforge.frontend.mapping import Loop as MappingLoop
12
+ from accelforge.mapper.FFM._join_pmappings.compatibility import (
13
+ Compatibility,
14
+ Loop,
15
+ TensorReservation,
16
+ )
17
+ from accelforge.util._frozenset import fzs
18
+
19
+ from accelforge._accelerated_imports import pd
20
+
21
+ from accelforge.mapper.FFM._pareto_df.df_convention import *
22
+ from accelforge.mapper.FFM._pareto_df.pareto import makepareto
23
+
24
+
25
+ CHECK_CORRECTNESS = False
26
+
27
+
28
def error_check_wrapper(func):
    """Debug decorator: on exception, report it, mark the failing dataframe, and re-run.

    When ``CHECK_CORRECTNESS`` is False (the default) this returns ``func``
    unwrapped, so decorated methods carry zero overhead in normal runs.

    When enabled, the wrapper deep-copies the arguments before calling
    ``func``. If ``func`` raises, the wrapper prints the exception, recovers
    the ``live_tensors`` argument (keyword or positional), calls ``.fail(0,
    live_tensors)`` on the first pre-call ``PmappingDataframe`` snapshot it
    finds among the arguments, then re-invokes ``func`` so a debugger can step
    into the reproduced failure (that second call's exception propagates).
    """
    if not CHECK_CORRECTNESS:
        return func

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        # Snapshot BEFORE entering the try block: if deepcopy itself fails we
        # propagate that error directly instead of hitting an unbound-local
        # NameError on ``prev_args`` inside the except handler (the original
        # code deep-copied inside the try).
        prev_args, prev_kwargs = copy.deepcopy(args), copy.deepcopy(kwargs)
        try:
            return func(*args, **kwargs)
        except Exception as e:
            print(f"EXCEPTION: {e}")
            # Recover live_tensors whether it was passed by keyword or by
            # position (looked up via the function's own argument names).
            live_tensors = set()
            if "live_tensors" in kwargs:
                live_tensors = kwargs["live_tensors"]
            else:
                argnames = func.__code__.co_varnames[: func.__code__.co_argcount]
                if "live_tensors" in argnames:
                    idx = argnames.index("live_tensors")
                    if idx < len(args):
                        live_tensors = args[idx]
            # Mark the first PmappingDataframe argument (pre-call snapshot) as failed.
            for prev_arg in itertools.chain(prev_args, prev_kwargs.values()):
                if isinstance(prev_arg, PmappingDataframe):
                    prev_arg.fail(0, live_tensors)
                    break
            # Re-run so the failure can be stepped through in a debugger; if it
            # reproduces, the exception propagates from here.
            func(*args, **kwargs)  # For debugging

    return wrapper
55
+
56
+
57
class PmappingDataframe:
    """A table of partial mappings ("pmappings") with per-resource reservation indices.

    Wraps a pandas DataFrame whose columns encode metrics plus resource
    reservations at particular loop levels (parsed via ``col2nameloop`` /
    ``is_left_col``). Tracks which (resource, loop-level) reservation columns
    exist on the "left" and "right" sides and keeps those indices consistent
    with the DataFrame's columns.
    """

    def __init__(
        self,
        data: pd.DataFrame,
        n_total_pmappings: float,
        n_valid_pmappings: float,
        skip_pareto: bool = False,
        fill_reservation_cols: set | str = fzs(),
        check_above_subset_below: bool = CHECK_CORRECTNESS,
        max_right_to_left: bool = False,
        next_shared_loop_index: int | None = None,
        parallelize_pareto: bool = False,
        limit_capacity_drop_valid_reservations: bool = True,
        ignored_resources: set[str] | None = None,
    ):
        """Build the dataframe wrapper and (optionally) normalize + Pareto-filter it.

        Parameters
        ----------
        data:
            DataFrame with metric columns and reservation columns.
        n_total_pmappings / n_valid_pmappings:
            Bookkeeping counts of how many pmappings this table represents.
        skip_pareto:
            If True, do not run ``make_pareto`` on construction.
        fill_reservation_cols:
            Columns to fill via ``fill_reservation_cols``; ``"auto"`` fills all.
        check_above_subset_below:
            Run the ``check_above_subset_below`` invariant at each stage.
        max_right_to_left:
            Fold right reservations into matching left ones before Pareto.
        next_shared_loop_index:
            If given, free reservations down to this loop index and apply
            ``limit_capacity``; requires ``ignored_resources``.
        parallelize_pareto:
            Forwarded to ``make_pareto``.
        limit_capacity_drop_valid_reservations / ignored_resources:
            Forwarded to ``limit_capacity``.
        """
        self._data: pd.DataFrame = data
        # Reservation indices: resource name -> set of loop levels. Rebuilt
        # from the DataFrame columns by _make_reservations().
        self.right_reservations: dict[str, set[int]] | None = None
        self.left_reservations: dict[str, set[int]] | None = None
        # Memo for free_to_loop_index() so repeated calls at the same index no-op.
        self._prev_free_to_loop_index = None
        self._parallelize_pareto = parallelize_pareto
        self._make_reservations()
        self.n_total_pmappings: float = n_total_pmappings
        self.n_valid_pmappings: float = n_valid_pmappings

        if next_shared_loop_index is not None:
            assert (
                ignored_resources is not None
            ), "ignored_resources must be set if next_shared_loop_index is set"
            self.free_to_loop_index(loop_index=next_shared_loop_index)
            self.limit_capacity(
                next_shared_loop_index=next_shared_loop_index,
                drop_valid_reservations=limit_capacity_drop_valid_reservations,
                ignored_resources=ignored_resources,
            )
            self._check_reservations()

        # Order matters below: each step mutates the DataFrame columns, so the
        # invariant check is repeated after every mutating stage.
        if fill_reservation_cols:  # Affects PmappingDataframe so must go before
            self.fill_reservation_cols(fill_reservation_cols)
        if check_above_subset_below:
            self.check_above_subset_below()
        if max_right_to_left:  # Affects PmappingDataframe so must go before
            self.max_right_to_left()
        if check_above_subset_below:
            self.check_above_subset_below()

        if not skip_pareto:
            self.make_pareto(parallelize=parallelize_pareto)

        if check_above_subset_below:
            self.check_above_subset_below()

        self._check_reservations()

        # Duplicate columns would silently corrupt later merges/maxes.
        assert len(self.data.columns) == len(
            set(self.data.columns)
        ), f"Duplicate columns: {self.data.columns}"
113
+
114
+ def all_reservation_levels(self):
115
+ return set().union(
116
+ set(),
117
+ *self.left_reservations.values(),
118
+ *self.right_reservations.values(),
119
+ )
120
+
121
+ def rename(self, renames: dict[str, str]) -> "PmappingDataframe":
122
+ new = self.copy()
123
+ new.data.rename(columns=renames, inplace=True)
124
+ return new
125
+
126
    @error_check_wrapper
    def fill_reservation_cols(self, columns: set | str):
        """Propagate parent reservation values down into the given columns via max.

        For each target column, finds the nearest existing reservation at a
        shallower level (``get_reservation_or_parent``) and maxes it into the
        target with ``max_to_col``. With ``columns == "auto"`` every left and
        right reservation column gets this treatment; otherwise ``columns``
        must be an iterable of valid reservation column names.

        Raises
        ------
        ValueError
            If an explicit column name does not parse as a reservation column.
        """
        self._check_reservations()
        targets = []
        if columns == "auto":
            for left, reservations_dict in [
                (True, self.left_reservations),
                (False, self.right_reservations),
            ]:
                for resource, reservations in reservations_dict.items():
                    for r in sorted(reservations):
                        # Parent reservation one level up (if any) feeds this one.
                        above = self.get_reservation_or_parent(resource, r - 1)
                        if above is not None:
                            below = nameloop2col(resource, r, left=left)
                            targets.append((r, above, below))
        else:
            for below in columns:
                if (name_nloops := col2nameloop(below)) is None:
                    raise ValueError(f"{below} is not a valid reservation column")
                name, nloops = name_nloops
                above = self.get_reservation_or_parent(name, nloops - 1)
                if above is not None:
                    targets.append((nloops, above, below))

        # Sort so we go from top to bottom. Needed in case we have to max 0->1
        # then 1->2
        for _, above, below in sorted(targets, key=lambda x: x[0]):
            assert (
                above in self.data.columns
            ), f"Missing column {above}. Have columns:\n\t" + "\n\t".join(
                list(self.data.columns)
            )
            assert (
                below in self.data.columns
            ), f"Missing column {below}. Have columns:\n\t" + "\n\t".join(
                list(self.data.columns)
            )
            max_to_col(self.data, below, above)

        self._check_reservations()
166
+
167
+ @error_check_wrapper
168
+ def max_right_to_left(self):
169
+ for resource, reservations in self.left_reservations.items():
170
+ for r in reservations:
171
+ if r in self.right_reservations.get(resource, set()):
172
+ source = nameloop2col(resource, r)
173
+ target = nameloop2col(resource, r, left=True)
174
+ max_to_col(self.data, target, source)
175
+ self._make_reservations()
176
+ self._check_reservations()
177
+
178
    @property
    def data(self) -> pd.DataFrame:
        # The backing DataFrame. Note that several methods mutate it in place
        # (rename/drop with inplace=True), so this is not a defensive copy.
        return self._data
181
+
182
    @error_check_wrapper
    def _make_reservations(self):
        """
        Rebuild the reservation indices from the current DataFrame columns.

        Populates ``self.left_reservations`` and ``self.right_reservations``:
        dictionaries keyed by resource name whose values are sets of loop
        levels, one entry per reservation column found in ``self.data``
        (columns are parsed with ``col2nameloop``; side is determined by
        ``is_left_col``). Non-reservation columns are ignored.
        """
        self.left_reservations, self.right_reservations = {}, {}
        for c in self.data.columns:
            if (name_nloops := col2nameloop(c)) is not None:
                name, nloops = name_nloops
                # Route the entry to the left or right index based on the column.
                target = (
                    self.left_reservations
                    if is_left_col(c)
                    else self.right_reservations
                )
                target.setdefault(name, set()).add(nloops)
                # -1 is the shallowest legal level (above all loops).
                assert nloops >= -1
200
+
201
+ def _check_reservations(self):
202
+ prev_left, prev_right = self.left_reservations, self.right_reservations
203
+ self._make_reservations()
204
+ assert (
205
+ self.left_reservations == prev_left
206
+ ), f"Left reservations changed: {self.left_reservations} != {prev_left}"
207
+ assert (
208
+ self.right_reservations == prev_right
209
+ ), f"Right reservations changed: {self.right_reservations} != {prev_right}"
210
+
211
    @error_check_wrapper
    def free_to_loop_index(
        self,
        loop_index: int,
        live_tensors: set[int] | None = None,
        check_correctness: bool = CHECK_CORRECTNESS,
    ) -> bool:
        """
        Collapse all reservations below ``loop_index`` into a single reservation
        at level ``loop_index + 1`` (per resource), taking the max of the
        collapsed columns. Returns True if any columns were dropped. Calling
        twice with the same ``loop_index`` is a no-op (memoized via
        ``_prev_free_to_loop_index``).

            A B
            / | --- 0
            C D
            / | --- 1 < Shared Loop Index
            E F
            / | --- 2
            G H
            ->
            A B
            / | --- 0
            C D
            | --- 1 < Shared Loop Index
            max(E,G,H)
        We skip incorporating E into the max because its reservations are
        already incorporated into F and G.
        """
        if loop_index == self._prev_free_to_loop_index:
            return False
        self._prev_free_to_loop_index = loop_index

        drop_columns = []
        for resource in set(self.left_reservations) | set(self.right_reservations):
            max_columns = []
            left_reservations = self.left_reservations.get(resource, set())
            right_reservations = self.right_reservations.get(resource, set())
            # Levels deep enough to be collapsed. The right threshold is one
            # deeper because level loop_index + 1 is the collapse target itself.
            left_big_enough = [l for l in left_reservations if l >= loop_index + 1]
            right_big_enough = [
                r for r in right_reservations if r >= loop_index + 2
            ]  # + 1 is target

            if len(right_big_enough) > 1:  # All ones above the last are subsets
                # Only the deepest right reservation matters; shallower ones
                # are already folded into it, so they can be dropped outright.
                right_biggest = max(right_big_enough)
                right_big_enough.remove(right_biggest)
                drop_columns += [nameloop2col(resource, r) for r in right_big_enough]
                right_big_enough = [right_biggest]

            max_columns = [nameloop2col(resource, r) for r in right_big_enough] + [
                nameloop2col(resource, l, left=True) for l in left_big_enough
            ]

            if not max_columns:
                continue

            target = nameloop2col(resource, loop_index + 1)
            if target in self.data:
                max_columns.append(target)

            if len(max_columns) == 1:
                # Single source: a rename is cheaper than max-then-drop.
                self.data.rename(columns={max_columns[0]: target}, inplace=True)
            else:
                for c in max_columns:
                    max_to_col(self.data, target, c)
                drop_columns += [m for m in max_columns if m != target]
        self.data.drop(columns=drop_columns, inplace=True)
        self._make_reservations()

        if check_correctness and live_tensors is not None:
            # Check on a copy so the verification can't mutate our state.
            self.copy().check_reservations(live_tensors=live_tensors)
        self._check_reservations()
        return len(drop_columns) != 0
279
+
280
    @error_check_wrapper
    def get_reservation_or_parent(
        self,
        name: str,
        level: int,
        left: bool = False,
        return_name_level_left: bool = False,
    ) -> str | tuple[str, int, bool] | None:
        """Find the reservation for ``name`` at ``level``, or its nearest ancestor.

        Walks upward from ``level`` toward -1. The first probe uses the side
        selected by ``left``; after one miss the walk continues on the right
        side only (the parent of a left node is a right node). Returns the
        column name (or the ``(name, level, left)`` triple when
        ``return_name_level_left`` is set), or None if nothing is found.

        NOTE(review): if ``left`` is True but ``name`` has no left
        reservations at all, this returns None without consulting the right
        side — presumably intentional; confirm against callers.
        """
        reservations = self.left_reservations if left else self.right_reservations
        if (reservations := reservations.get(name, None)) is not None:
            while level >= -1:
                if level in reservations:
                    if return_name_level_left:
                        return name, level, left
                    return nameloop2col(name, level, left)
                # The parent of left nodes are right nodes, so if we don't find a
                # left node immediately then we're back on the right nodes
                reservations = self.right_reservations.get(name, set())
                left = False
                level -= 1
        return None
301
+
302
    @error_check_wrapper
    def shift_bottom_reservation_left(self, shared_loop_index: int):
        """
        Shifts the bottom reservation from right to left.

        For every resource reserved (on the right) at level
        ``shared_loop_index + 1``, moves that reservation to the left side:
        the index entry is moved, and the column is either maxed into an
        existing left column (then dropped) or simply renamed.

        Example:
        Before:        After:
        A   B          A   B
        / | --- 0      / | --- 0
        C   D          C   D
            | --- 1    /     --- 1
            E          E
        """
        for resource in self.right_reservations:
            if shared_loop_index + 1 not in self.right_reservations[resource]:
                continue
            self.left_reservations.setdefault(resource, set())
            self.right_reservations[resource].remove(shared_loop_index + 1)
            self.left_reservations[resource].add(shared_loop_index + 1)
            source = nameloop2col(resource, shared_loop_index + 1)
            target = nameloop2col(resource, shared_loop_index + 1, left=True)
            if target in self.data:
                # A left column already exists: combine via max, then drop the source.
                max_to_col(self.data, target, source)
                self.data.drop(columns=[source], inplace=True)
            else:
                self.data.rename(columns={source: target}, inplace=True)
        self._make_reservations()
        self._check_reservations()
329
+
330
+ @staticmethod
331
+ def _get_target_path(suffix: str = None) -> str:
332
+ import os
333
+
334
+ f = "./images"
335
+ os.makedirs(f, exist_ok=True)
336
+ suffix = "" if suffix is None else f".{suffix}"
337
+ i = 0
338
+ while os.path.exists(os.path.join(f, f"test_{i}{suffix}.png")):
339
+ i += 1
340
+ return os.path.join(f, f"test_{i}{suffix}.png")
341
+
342
+ def get_max_loop_index(self):
343
+ return max(
344
+ max(
345
+ (max(r, default=-1) for r in self.right_reservations.values()),
346
+ default=-1,
347
+ ),
348
+ max(
349
+ (max(r, default=-1) for r in self.left_reservations.values()),
350
+ default=-1,
351
+ ),
352
+ )
353
+
354
+ def get_min_loop_index(self):
355
+ return min(
356
+ min(
357
+ (min(r, default=1000000) for r in self.right_reservations.values()),
358
+ default=1000000,
359
+ ),
360
+ min(
361
+ (min(r, default=1000000) for r in self.left_reservations.values()),
362
+ default=1000000,
363
+ ),
364
+ )
365
+
366
+ @error_check_wrapper
367
+ def merge_next(
368
+ self,
369
+ right: "PmappingDataframe",
370
+ shared_loop_index: int,
371
+ next_shared_loop_index: int,
372
+ live_tensors: set[int],
373
+ still_live_reservations: set[TensorReservation],
374
+ duplicated_aliased_tensors: set[TensorReservation],
375
+ compatibility_left: Compatibility,
376
+ compatibility_right: Compatibility,
377
+ compatibility_joined: Compatibility,
378
+ ignored_resources: set[str],
379
+ drop_valid_reservations: bool = True,
380
+ _pmapping_row_filter_function: Callable[[pd.Series], bool] | None = None,
381
+ ) -> "PmappingDataframe":
382
+ """
383
+ A B A2
384
+ / | --- 0 |
385
+ C D C2
386
+ | --- 1 | < Shared Loop Index
387
+ E E2
388
+ |
389
+ F2
390
+ ->
391
+ A A+A2
392
+ / | --- 0
393
+ C+A2 C+C2
394
+ / | --- 1 < Shared Loop Index
395
+ E+C2 E2+D
396
+ |
397
+ F2+D
398
+ """
399
+ self._check_reservations()
400
+ right._check_reservations()
401
+ self.free_to_loop_index(shared_loop_index, live_tensors=live_tensors)
402
+ self.shift_bottom_reservation_left(shared_loop_index)
403
+
404
+ shared_tensor_names = (
405
+ compatibility_left.tensor_names & compatibility_right.tensor_names
406
+ )
407
+ shared_tensors = [
408
+ compatibility_left.get_tensor_by_name(s) for s in shared_tensor_names
409
+ ]
410
+ left_match, right_match = [], []
411
+ make_empty_result = False
412
+
413
+ def check_match(la: Loop, lb: Loop, param: str):
414
+ a, b = getattr(la.tile_pattern, param), getattr(lb.tile_pattern, param)
415
+ if isinstance(a, str) or isinstance(b, str):
416
+ left_match.append(a)
417
+ right_match.append(b)
418
+ elif a != b:
419
+ raise ValueError(f"Mismatch in {param}: {a} != {b}")
420
+
421
+ try:
422
+ for s in shared_tensor_names:
423
+ ta = compatibility_left.get_tensor_by_name(s)
424
+ tb = compatibility_right.get_tensor_by_name(s)
425
+ for la, lb in zip(ta.loops, tb.loops):
426
+ check_match(la, lb, "initial_tile_shape")
427
+ check_match(la, lb, "tile_shape")
428
+
429
+ for la, lb in zip(compatibility_left.loops, compatibility_right.loops):
430
+ check_match(la, lb, "calculated_n_iterations")
431
+
432
+ except ValueError as e:
433
+ make_empty_result = True
434
+
435
+ assert not right.left_reservations, f"{right.left_reservations} is not None"
436
+
437
+ for resource, reservations in self.right_reservations.items():
438
+ n_reservations = max(reservations, default=-1)
439
+ assert (
440
+ n_reservations <= shared_loop_index
441
+ ), f"{resource}: {reservations} > {shared_loop_index}"
442
+
443
+ for resource, reservations in self.left_reservations.items():
444
+ n_reservations = max(reservations, default=-1)
445
+ assert (
446
+ n_reservations <= shared_loop_index + 1
447
+ ), f"{resource}: {reservations} > {shared_loop_index}"
448
+
449
+ max_nloops = max(
450
+ shared_loop_index, self.get_max_loop_index(), right.get_max_loop_index()
451
+ )
452
+ min_nloops = min(self.get_min_loop_index(), right.get_min_loop_index())
453
+
454
+ sd, rd = self.data, right.data
455
+ if make_empty_result:
456
+ sd = sd.iloc[0:0]
457
+ rd = rd.iloc[0:0]
458
+
459
+ if left_match:
460
+ df = pd.merge(
461
+ sd,
462
+ rd,
463
+ how="inner",
464
+ left_on=left_match,
465
+ right_on=right_match,
466
+ suffixes=["", "_RIGHT_MERGE"],
467
+ )
468
+ else:
469
+ df = pd.merge(sd, rd, how="cross", suffixes=["", "_RIGHT_MERGE"])
470
+
471
+ # Drop all fused loop columns that are not used anymore
472
+ remaining_symbols = compatibility_joined.symbols()
473
+ dropcols = [
474
+ c for c in df.columns if is_fused_loop_col(c) and c not in remaining_symbols
475
+ ]
476
+ df = df.drop(columns=dropcols)
477
+
478
+ # Number of combinations
479
+ n_total_pmappings = self.n_total_pmappings * right.n_total_pmappings
480
+ n_valid_pmappings = self.n_valid_pmappings * right.n_valid_pmappings
481
+ scale_by = len(df) / max(1, len(self.data) * len(right.data))
482
+ n_total_pmappings *= scale_by
483
+ n_valid_pmappings *= scale_by
484
+
485
+ # Make sure everything is done in increasing loop order so we don't have
486
+ # read-after-write hazards
487
+ for nloops in range(max_nloops, min_nloops - 1, -1):
488
+
489
+ def iter_reservations(reservations_dict):
490
+ for resource in reservations_dict:
491
+ if nloops in reservations_dict[resource]:
492
+ yield resource
493
+
494
+ # For the RIGHT tree, RIGHT reservations: If there is no matching node in the left
495
+ # tree, add the above-this-level reservation from the left tree. If there is a matching
496
+ # node in the left tree, then we'll add this node to it in the next step.
497
+ for resource in iter_reservations(right.right_reservations):
498
+ if (
499
+ source := self.get_reservation_or_parent(resource, nloops - 1)
500
+ ) is None:
501
+ continue
502
+ target = nameloop2col(resource, nloops)
503
+ # If there's a merged version column, then it's in both trees
504
+ if target + "_RIGHT_MERGE" in df:
505
+ continue
506
+ add_to_col(df, target, source)
507
+ # For LEFT tree, LEFT reservations: Add the immediately-above
508
+ # reservation from the right tree.
509
+ for resource in iter_reservations(self.left_reservations):
510
+ if (
511
+ source := right.get_reservation_or_parent(resource, nloops - 1)
512
+ ) is None:
513
+ continue
514
+ right_merge_source = source + "_RIGHT_MERGE"
515
+ target = nameloop2col(resource, nloops, left=True)
516
+ if source is not None:
517
+ add_to_col(
518
+ df,
519
+ target,
520
+ right_merge_source if right_merge_source in df else source,
521
+ )
522
+ # For LEFT tree, RIGHT reservations: Add the same-level reservation from
523
+ # the right tree. This will double-count reservations that are in both branches,
524
+ # so we remove them later.
525
+ for resource in iter_reservations(self.right_reservations):
526
+ if (
527
+ source := right.get_reservation_or_parent(resource, nloops)
528
+ ) is None:
529
+ continue
530
+ right_merge_source = source + "_RIGHT_MERGE"
531
+ target = nameloop2col(resource, nloops)
532
+ if source is not None:
533
+ add_to_col(
534
+ df,
535
+ target,
536
+ right_merge_source if right_merge_source in df else source,
537
+ )
538
+
539
+ # For everything else: Simple add
540
+ dropcols = [c for c in df.columns if c.endswith("_RIGHT_MERGE")]
541
+ for source in dropcols:
542
+ target = source[: -len("_RIGHT_MERGE")]
543
+ if is_tensor_col(target):
544
+ continue
545
+ if not col_used_in_pareto(target):
546
+ raise ValueError(f"{target} is not used in pareto")
547
+ if col2nameloop(target) is None:
548
+ add_to_col(df, target, source)
549
+
550
+ df = df.drop(columns=dropcols)
551
+ result = PmappingDataframe(
552
+ df,
553
+ skip_pareto=True,
554
+ check_above_subset_below=False,
555
+ n_total_pmappings=n_total_pmappings,
556
+ n_valid_pmappings=n_valid_pmappings,
557
+ )
558
+ # Remove tensors that were allocated in both branches and got added
559
+ # together.
560
+ shared_to_free = [
561
+ s for s in shared_tensors if s.above_loop_index <= shared_loop_index
562
+ ]
563
+ live_to_alloc = [
564
+ s for s in still_live_reservations if s.above_loop_index > shared_loop_index
565
+ ]
566
+ result.adjust_reservations(
567
+ alloc=live_to_alloc,
568
+ free=list(itertools.chain(shared_to_free, duplicated_aliased_tensors)),
569
+ ignored_resources=ignored_resources,
570
+ )
571
+
572
+ if CHECK_CORRECTNESS:
573
+ result.check_above_subset_below(live_tensors)
574
+ result.check_reservations(live_tensors)
575
+
576
+ result.free_to_loop_index(next_shared_loop_index, live_tensors=live_tensors)
577
+ if not CHECK_CORRECTNESS:
578
+ result.limit_capacity(
579
+ next_shared_loop_index,
580
+ drop_valid_reservations,
581
+ ignored_resources=ignored_resources,
582
+ )
583
+ result.max_right_to_left()
584
+ if _pmapping_row_filter_function is not None:
585
+ result = result.filter_rows(_pmapping_row_filter_function)
586
+ result.make_pareto()
587
+ result._check_reservations()
588
+
589
+ return result
590
+
591
+ @error_check_wrapper
592
+ def _adjust_reservations_one_resource(
593
+ self,
594
+ resource: str,
595
+ alloc: Iterable[TensorReservation],
596
+ free: Iterable[TensorReservation],
597
+ ):
598
+ alloc, free = list(alloc), list(free)
599
+ # Iterate through each reservation and level
600
+ targets = defaultdict(int)
601
+
602
+ # Must allocate at the above_loop_index level
603
+ for t in itertools.chain(alloc, free):
604
+ self.right_reservations.setdefault(resource, set()).add(t.above_loop_index)
605
+
606
+ for t, negate in [(t, False) for t in alloc] + [(t, True) for t in free]:
607
+ size = self.data[tensor2col(t.name)]
608
+ size = -size if negate else size
609
+ targets[t.above_loop_index, False] += size
610
+ # Allocate at any levels below the above_loop_index level
611
+ for level in self.right_reservations[resource]:
612
+ if level > t.above_loop_index:
613
+ targets[level, False] += size
614
+ for level in self.left_reservations.get(resource, set()):
615
+ if level > t.above_loop_index:
616
+ targets[level, True] += size
617
+
618
+ # Now apply the allocations. Sort so we go from top to bottom in case
619
+ # there are maxes that propagate down.
620
+ for (level, left), size in sorted(
621
+ targets.items(), key=lambda x: x[0], reverse=True
622
+ ):
623
+ target = nameloop2col(resource, level, left=left)
624
+ if target in self.data:
625
+ self.data.loc[:, target] += size
626
+ continue
627
+
628
+ # We're creating a new column, so copy allocations from any parents
629
+ source = self.get_reservation_or_parent(resource, level - 1)
630
+ try:
631
+ self.data[target] = size + (self.data[source] if source else 0)
632
+ except:
633
+ source = self.get_reservation_or_parent(resource, level - 1)
634
+ self.data[target] = size + (self.data[source] if source else 0)
635
+
636
+ # Assert all reservations are >= 0
637
+ assert (self.data[target] >= 0).all(), f"Negative reservation: {target}"
638
+
639
+ @error_check_wrapper
640
+ def adjust_reservations(
641
+ self,
642
+ alloc: Iterable[TensorReservation],
643
+ free: Iterable[TensorReservation],
644
+ ignored_resources: set[str] = set(),
645
+ ):
646
+ alloc, free = list(alloc), list(free)
647
+ all_resources = {t.resource_name for t in alloc} | {
648
+ t.resource_name for t in free
649
+ }
650
+ # Handle each resource separately
651
+ for resource in all_resources:
652
+ if resource in ignored_resources:
653
+ continue
654
+ cur_alloc = [t for t in alloc if t.resource_name == resource]
655
+ cur_free = [t for t in free if t.resource_name == resource]
656
+ if cur_alloc or cur_free:
657
+ self._adjust_reservations_one_resource(resource, cur_alloc, cur_free)
658
+
659
+ @staticmethod
660
+ def concat(
661
+ paretos: list["PmappingDataframe"], skip_pareto: bool = False
662
+ ) -> "PmappingDataframe":
663
+ if len(paretos) == 0:
664
+ raise ValueError("No paretos to concatenate")
665
+ if len(paretos) == 1:
666
+ return paretos[0]
667
+
668
+ required_cols = set.union(*[set(p.data.columns) for p in paretos])
669
+ shared_cols = set.intersection(*[set(p.data.columns) for p in paretos])
670
+ fill_cols = required_cols - shared_cols
671
+ fill_cols = [c for c in fill_cols if col_used_in_pareto(c)]
672
+
673
+ concatenated = pd.concat([p.data for p in paretos]).reset_index(drop=True)
674
+
675
+ p = PmappingDataframe(
676
+ concatenated.fillna(0),
677
+ skip_pareto=len(paretos) == 1 or skip_pareto,
678
+ fill_reservation_cols=fill_cols,
679
+ n_total_pmappings=sum(p.n_total_pmappings for p in paretos),
680
+ n_valid_pmappings=sum(p.n_valid_pmappings for p in paretos),
681
+ )
682
+ return p
683
+
684
+ def update(
685
+ self,
686
+ skip_pareto: bool,
687
+ **kwargs,
688
+ ) -> "PmappingDataframe":
689
+ args = dict(
690
+ data=self.data,
691
+ skip_pareto=skip_pareto,
692
+ check_above_subset_below=False,
693
+ n_total_pmappings=self.n_total_pmappings,
694
+ n_valid_pmappings=self.n_valid_pmappings,
695
+ )
696
+ args.update(kwargs)
697
+ return PmappingDataframe(**args)
698
+
699
+ def copy(self) -> "PmappingDataframe":
700
+ return self.update(
701
+ data=self.data.copy(),
702
+ skip_pareto=True,
703
+ check_above_subset_below=False,
704
+ )
705
+ return p
706
+
707
    def limit_capacity(
        self,
        next_shared_loop_index: int | None = None,
        drop_valid_reservations: bool = True,
        ignored_resources: set[str] = set(),
    ):
        """Filter rows whose reservations exceed capacity and drop columns
        that are no longer needed.

        For every resource that has reservations, rows where the relevant
        reservation column exceeds 1 are removed (reservations are presumably
        normalized to capacity — TODO confirm against the column producers).
        Fully-resolved reservation columns (loop level 0, once no shared loops
        remain) are dropped when ``drop_valid_reservations`` is set and the
        resource is not ignored. Finally the reservation bookkeeping is
        rebuilt via ``self._make_reservations()``.

        Args:
            next_shared_loop_index: Index of the next loop shared with a
                neighboring branch; ``-1`` means no shared loops remain.
            drop_valid_reservations: Whether level-0 reservation columns may
                be dropped once validated.
            ignored_resources: Resources whose columns are never dropped.
                NOTE(review): mutable default ``set()`` — only read here, but
                a frozenset default would be safer.
        """
        dropcols = []
        for resource in sorted(
            set(self.right_reservations) | set(self.left_reservations)
        ):
            # Right reservations: Only check the greatest-index level. If a loop
            # is 0 and the next shared loop index is -1, then we can drop the
            # column.
            right_loops = self.right_reservations.get(resource, set())
            if right_loops:
                n = max(right_loops)
                col = nameloop2col(resource, n)
                # Keep only rows that fit within capacity at the deepest level.
                self._data = self.data[self.data[col] <= 1]
                for l in list(right_loops):
                    if (
                        l == 0
                        and next_shared_loop_index == -1
                        and drop_valid_reservations
                        and resource not in ignored_resources
                    ):
                        # NOTE(review): `col` here is the max-level column, not
                        # the level-`l` column. This only matters if right_loops
                        # can contain level 0 alongside deeper levels when
                        # next_shared_loop_index == -1 — confirm that
                        # free_to_loop_index has collapsed levels by this point.
                        right_loops.discard(l)
                        dropcols.append(col)

            # Left reservations: Check all levels. If a loop is 0,
            # then we can drop the column.
            left_loops = self.left_reservations.get(resource, set())
            for l in list(left_loops):
                col = nameloop2col(resource, l, left=True)
                self._data = self.data[self.data[col] <= 1]
                if (
                    l == 0
                    and drop_valid_reservations
                    and resource not in ignored_resources
                ):
                    left_loops.discard(l)
                    dropcols.append(col)

        self._data = self.data.drop(columns=dropcols)
        self._make_reservations()
751
+
752
+ def make_pareto(self, columns: list[str] = None, parallelize: bool = False):
753
+ self._check_reservations()
754
+ self._data = makepareto(self.data, columns, parallelize=parallelize)
755
+ self._check_reservations()
756
+
757
+ def has_reservations(self):
758
+ return any(col2nameloop(c) is not None for c in self.data.columns)
759
+
760
+ # ============================================================================
761
+ # Checking functions
762
+ # ============================================================================
763
+ def check_above_subset_below(self, live_tensors: set[str] = fzs()):
764
+ assert not self.data.isnull().values.any(), f"NaN in {self.data}"
765
+ targets = []
766
+ for left, reservations_dict in [
767
+ (True, self.left_reservations),
768
+ (False, self.right_reservations),
769
+ ]:
770
+ for resource, reservations in reservations_dict.items():
771
+ for r in reservations:
772
+ above = self.get_reservation_or_parent(resource, r - 1)
773
+ if above is not None:
774
+ below = nameloop2col(resource, r, left=left)
775
+ targets.append((above, below))
776
+
777
+ for above, below in targets:
778
+ if (self.data[below] < self.data[above]).any():
779
+ first_failing_index = (self.data[below] < self.data[above]).idxmax()
780
+ fail_row = self.data.iloc[first_failing_index]
781
+ error = f"""
782
+ {below} column is less than {above} column. A reservation at
783
+ a level should include all reservations above it. There were {len(fail_row)} rows
784
+ with this error. One example: {fail_row}
785
+ """
786
+ self.fail(first_failing_index, live_tensors)
787
+ raise ValueError(error)
788
+
789
+ def filter_rows(
790
+ self, _pmapping_row_filter_function: Callable[[pd.Series], bool] | None = None
791
+ ) -> "PmappingDataframe":
792
+ if _pmapping_row_filter_function is None:
793
+ return self.copy()
794
+
795
+ # s = _pmapping_row_filter_function(self._data)
796
+ # if s.sum() > 0:
797
+ # print(f"Filter rate: {s.sum() / len(s):.2%}")
798
+ return self.update(
799
+ data=self._data[_pmapping_row_filter_function(self._data)].copy(),
800
+ skip_pareto=True,
801
+ )
802
+
803
    def __len__(self) -> int:
        """Number of pmapping rows currently held in the backing DataFrame."""
        return len(self._data)
805
+
806
+ # @error_check_wrapper
807
+ # def check_reservations(self, live_tensors: set[int]):
808
+ # from accelforge.visualization.reservationtree import mappings2reservationtree
809
+ # assert not self.data.isnull().values.any(), f"NaN in {self.data}"
810
+
811
+ # self = self.copy()
812
+
813
+ # self.free_to_loop_index(-1, check_correctness=False)
814
+ # self.shift_bottom_reservation_left(-1)
815
+
816
+ # for i, r in self.data.iterrows():
817
+ # looptree = mappings2reservationtree(
818
+ # r[MAPPING_COLUMN],
819
+ # r.get(STATS, None),
820
+ # still_live_tensors=live_tensors
821
+ # )
822
+ # reservations = dict(looptree.get_reservations())
823
+
824
+ # # If r doesn't have any columns, continue. It's a copy Einsum so it has no
825
+ # # stats.
826
+ # if r.empty:
827
+ # continue
828
+
829
+ # for k, v in reservations.items():
830
+ # col = self.get_reservation_or_parent(k, 0, left=True)
831
+ # if str(k) == "0":
832
+ # continue
833
+ # if col not in self.data.columns:
834
+ # got = r[[c for c in self.data.columns if col2nameloop(c) is not None]]
835
+ # self.fail(i, live_tensors)
836
+ # raise ValueError(f"Missing {k}: Expected {reservations}. Got: {got}")
837
+ # if r[col] != v:
838
+ # got = r[[c for c in self.data.columns if col2nameloop(c) is not None]]
839
+ # self.fail(i, live_tensors)
840
+ # looptree = mappings2reservationtree(
841
+ # r[MAPPING_COLUMN],
842
+ # r.get(STATS, None),
843
+ # # skip_backing_tensors_in_right_branch=live_tensors,
844
+ # still_live_tensors=live_tensors,
845
+ # )
846
+ # raise ValueError(
847
+ # f"Mismatched {k}: {v} != {r[col]}. Expected {reservations}. Got: {got}"
848
+ # )
849
+
850
+ # def fail(self, index, live_tensors):
851
+ # from accelforge.mapper.FFM._join_pmappings.pmapping_group import TensorReservation
852
+ # r = self.data.iloc[index]
853
+ # assert not self.data.isnull().values.any(), f"NaN in {self.data}"
854
+ # self = self.copy()
855
+ # self._draw_index(index, live_tensors, self._get_target_path(suffix="fail"))
856
+ # all_tensors = set(t for tn in r[MAPPING_COLUMN].values() for t in tn.tensors)
857
+ # all_tensors = TensorReservation.get_backing_tensors(all_tensors)
858
+ # for t in sorted(all_tensors):
859
+ # print(f"{t.__repr__()},")
860
+
861
+ # def _draw_index(self, index: int, live_tensors, to_file: str = "test.png"):
862
+ # from accelforge.visualization.reservationtree import mappings2reservationtree
863
+ # import pydot
864
+ # looptree = mappings2reservationtree(
865
+ # self.data.iloc[index][MAPPING_COLUMN],
866
+ # self.data.iloc[index].get(STATS, None),
867
+ # still_live_tensors=live_tensors,
868
+ # )
869
+ # graph = pydot.Dot(graph_type="digraph", ranksep="0.2", nodesep="0.2")
870
+ # looptree.to_pydot(graph)
871
+ # row = self.data.iloc[index]
872
+ # all_data = sorted(f"{k}: {v}" for k, v in row.items() if k not in DICT_COLUMNS)
873
+ # data_str = "\n".join(all_data)
874
+ # graph.add_node(pydot.Node("data", label=data_str, shape="plaintext"))
875
+ # with open(to_file, "wb") as f:
876
+ # f.write(graph.create_png())
877
+
878
+
879
def row2pmappings(
    row: pd.Series,
    einsum_names: list[str],
    rank_variable_bounds: dict[str, dict[str, int]],
) -> list[Nested]:
    """Reconstruct one concrete pmapping per Einsum from a joined result row.

    For each Einsum, the stored mapping is deep-copied and every MappingLoop's
    tile pattern has its symbolic shapes resolved against the row's columns;
    loops are then beautified with the given rank bounds.

    Args:
        row: A joined-dataframe row holding per-Einsum mapping and shape columns.
        einsum_names: Einsums to reconstruct, in order.
        rank_variable_bounds: Rank-variable bounds passed to loop beautification.

    Returns:
        The reconstructed pmappings, one per entry of ``einsum_names``.
    """
    results: list[Nested] = []
    for einsum_name in einsum_names:
        pmapping: Nested = copy.deepcopy(row[f"{einsum_name}<SEP>{MAPPING_COLUMN}"])

        def resolve(s: str | None | int):
            # Symbols become their name, then names are looked up in the row;
            # plain ints (and None) pass through untouched.
            s = s.name if isinstance(s, sympy.Symbol) else s
            return row[f"{einsum_name}<SEP>{s}"] if isinstance(s, str) else s

        for node in pmapping.nodes:
            if isinstance(node, MappingLoop):
                pattern: TilePattern = node.tile_pattern
                node.tile_pattern = pattern.update(
                    initial_tile_shape=resolve(pattern.initial_tile_shape),
                    tile_shape=resolve(pattern.tile_shape),
                )

        results.append(pmapping)
        pmapping._beautify_loops(rank_variable_bounds)
    return results