accelforge 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (258) hide show
  1. accelforge/__init__.py +21 -0
  2. accelforge/_accelerated_imports.py +16 -0
  3. accelforge/_deprecate/_simanneal/evalmapping.py +271 -0
  4. accelforge/_deprecate/_simanneal/mapspaceglobals.py +298 -0
  5. accelforge/_deprecate/_simanneal/simanneal.py +666 -0
  6. accelforge/_deprecate/_simanneal/tracking.py +105 -0
  7. accelforge/_deprecate/_simanneal/wrappers.py +218 -0
  8. accelforge/_deprecate/_simanneal2/__init__.py +7 -0
  9. accelforge/_deprecate/_simanneal2/simanneal.py +493 -0
  10. accelforge/_deprecate/_simanneal2/tracking.py +116 -0
  11. accelforge/_deprecate/compatibility_util.py +181 -0
  12. accelforge/_deprecate/layerdeduplication/__init__.py +2 -0
  13. accelforge/_deprecate/layerdeduplication/group_similar_einsums.py +160 -0
  14. accelforge/_deprecate/layerdeduplication/grouped_einsums.py +84 -0
  15. accelforge/_deprecate/mapping_filter_tags/__init__.py +2 -0
  16. accelforge/_deprecate/mapping_filter_tags/ffmt.py +212 -0
  17. accelforge/_deprecate/mapping_filter_tags/onesplit.py +24 -0
  18. accelforge/_deprecate/mapping_filter_tags/util.py +24 -0
  19. accelforge/_deprecate/tags.py +69 -0
  20. accelforge/_deprecate/viz/__init__.py +0 -0
  21. accelforge/_deprecate/viz/interactive.py +159 -0
  22. accelforge/_deprecate/viz/reservationtree.py +307 -0
  23. accelforge/_deprecate/viz/ski_slope.py +88 -0
  24. accelforge/_version.py +15 -0
  25. accelforge/examples.py +39 -0
  26. accelforge/frontend/__init__.py +10 -0
  27. accelforge/frontend/_binding.py +129 -0
  28. accelforge/frontend/_workload_isl/__init__.py +2 -0
  29. accelforge/frontend/_workload_isl/_isl.py +149 -0
  30. accelforge/frontend/_workload_isl/_symbolic.py +141 -0
  31. accelforge/frontend/arch copy.py +1544 -0
  32. accelforge/frontend/arch.py +1642 -0
  33. accelforge/frontend/config.py +63 -0
  34. accelforge/frontend/mapper/__init__.py +5 -0
  35. accelforge/frontend/mapper/ffm.py +126 -0
  36. accelforge/frontend/mapper/mapper.py +7 -0
  37. accelforge/frontend/mapper/metrics.py +30 -0
  38. accelforge/frontend/mapping/__init__.py +1 -0
  39. accelforge/frontend/mapping/mapping.py +1736 -0
  40. accelforge/frontend/model.py +14 -0
  41. accelforge/frontend/renames.py +150 -0
  42. accelforge/frontend/spec copy.py +230 -0
  43. accelforge/frontend/spec.py +301 -0
  44. accelforge/frontend/variables.py +12 -0
  45. accelforge/frontend/workload.py +952 -0
  46. accelforge/mapper/FFM/__init__.py +9 -0
  47. accelforge/mapper/FFM/_join_pmappings/__init__.py +0 -0
  48. accelforge/mapper/FFM/_join_pmappings/compatibility.py +653 -0
  49. accelforge/mapper/FFM/_join_pmappings/compress_pmappings.py +140 -0
  50. accelforge/mapper/FFM/_join_pmappings/join_pmappings.py +703 -0
  51. accelforge/mapper/FFM/_join_pmappings/pmapping_dataframe.py +901 -0
  52. accelforge/mapper/FFM/_join_pmappings/pmapping_group.py +337 -0
  53. accelforge/mapper/FFM/_make_pmappings/contraints/__init__.py +0 -0
  54. accelforge/mapper/FFM/_make_pmappings/contraints/constraints.py +360 -0
  55. accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/__init__.py +1 -0
  56. accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_loops.py +373 -0
  57. accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_pmapping_templates.py +463 -0
  58. accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_reservations.py +95 -0
  59. accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_storage_order.py +382 -0
  60. accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_storages.py +155 -0
  61. accelforge/mapper/FFM/_make_pmappings/make_pmappings.py +411 -0
  62. accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/__init__.py +1 -0
  63. accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/make_pmappings_from_templates.py +407 -0
  64. accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/make_tile_shapes.py +1681 -0
  65. accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/run_model.py +170 -0
  66. accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/symbol_relations.py +174 -0
  67. accelforge/mapper/FFM/_make_pmappings/pmapper_job.py +282 -0
  68. accelforge/mapper/FFM/_pareto_df/df_convention.py +273 -0
  69. accelforge/mapper/FFM/_pareto_df/pareto copy.py +836 -0
  70. accelforge/mapper/FFM/_pareto_df/pareto.py +508 -0
  71. accelforge/mapper/FFM/data.py +61 -0
  72. accelforge/mapper/FFM/main copy.py +236 -0
  73. accelforge/mapper/FFM/main.py +208 -0
  74. accelforge/mapper/FFM/mappings.py +510 -0
  75. accelforge/mapper/FFM/pmappings.py +310 -0
  76. accelforge/mapper/__init__.py +4 -0
  77. accelforge/mapper.py +0 -0
  78. accelforge/model/__init__.py +1 -0
  79. accelforge/model/_looptree/__init__.py +0 -0
  80. accelforge/model/_looptree/accesses.py +335 -0
  81. accelforge/model/_looptree/capacity/__init__.py +1 -0
  82. accelforge/model/_looptree/capacity/aggregators.py +36 -0
  83. accelforge/model/_looptree/capacity/capacity.py +47 -0
  84. accelforge/model/_looptree/energy.py +150 -0
  85. accelforge/model/_looptree/equivalent_ranks.py +29 -0
  86. accelforge/model/_looptree/latency/__init__.py +1 -0
  87. accelforge/model/_looptree/latency/latency.py +98 -0
  88. accelforge/model/_looptree/latency/memory.py +120 -0
  89. accelforge/model/_looptree/latency/processors.py +92 -0
  90. accelforge/model/_looptree/mapping_utilities.py +71 -0
  91. accelforge/model/_looptree/reuse/__init__.py +4 -0
  92. accelforge/model/_looptree/reuse/isl/__init__.py +1 -0
  93. accelforge/model/_looptree/reuse/isl/des.py +59 -0
  94. accelforge/model/_looptree/reuse/isl/isl_functions.py +374 -0
  95. accelforge/model/_looptree/reuse/isl/mapping_to_isl/__init__.py +4 -0
  96. accelforge/model/_looptree/reuse/isl/mapping_to_isl/analyze_mapping.py +297 -0
  97. accelforge/model/_looptree/reuse/isl/mapping_to_isl/skews_from_mapping.py +236 -0
  98. accelforge/model/_looptree/reuse/isl/mapping_to_isl/tiling.py +685 -0
  99. accelforge/model/_looptree/reuse/isl/mapping_to_isl/types.py +188 -0
  100. accelforge/model/_looptree/reuse/isl/spatial.py +260 -0
  101. accelforge/model/_looptree/reuse/isl/temporal.py +182 -0
  102. accelforge/model/_looptree/reuse/symbolic/__init__.py +1 -0
  103. accelforge/model/_looptree/reuse/symbolic/symbolic copy 2.py +1346 -0
  104. accelforge/model/_looptree/reuse/symbolic/symbolic copy.py +1408 -0
  105. accelforge/model/_looptree/reuse/symbolic/symbolic.py +1396 -0
  106. accelforge/model/_looptree/run.py +122 -0
  107. accelforge/model/_looptree/types.py +26 -0
  108. accelforge/model/_looptree/visualization/__init__.py +0 -0
  109. accelforge/model/_looptree/visualization/occupancy.py +11 -0
  110. accelforge/model/main.py +222 -0
  111. accelforge/plotting/__init__.py +2 -0
  112. accelforge/plotting/mappings.py +219 -0
  113. accelforge/plotting/specs.py +57 -0
  114. accelforge/util/__init__.py +4 -0
  115. accelforge/util/_base_analysis_types.py +24 -0
  116. accelforge/util/_basetypes.py +1089 -0
  117. accelforge/util/_frozenset.py +36 -0
  118. accelforge/util/_isl.py +29 -0
  119. accelforge/util/_itertools.py +14 -0
  120. accelforge/util/_mathfuncs.py +57 -0
  121. accelforge/util/_parse_expressions.py +339 -0
  122. accelforge/util/_picklecache.py +32 -0
  123. accelforge/util/_setexpressions.py +268 -0
  124. accelforge/util/_sympy/__init__.py +0 -0
  125. accelforge/util/_sympy/broadcast_max.py +18 -0
  126. accelforge/util/_visualization.py +112 -0
  127. accelforge/util/_yaml.py +579 -0
  128. accelforge/util/parallel.py +193 -0
  129. accelforge-0.0.1.dist-info/METADATA +64 -0
  130. accelforge-0.0.1.dist-info/RECORD +258 -0
  131. accelforge-0.0.1.dist-info/WHEEL +5 -0
  132. accelforge-0.0.1.dist-info/licenses/LICENSE +19 -0
  133. accelforge-0.0.1.dist-info/top_level.txt +5 -0
  134. docs/_build/html/_sources/fastfusion.frontend.mapper.rst.txt +37 -0
  135. docs/_build/html/_sources/fastfusion.frontend.rst.txt +70 -0
  136. docs/_build/html/_sources/fastfusion.frontend.workload.rst.txt +21 -0
  137. docs/_build/html/_sources/fastfusion.mapper.FFM.rst.txt +37 -0
  138. docs/_build/html/_sources/fastfusion.mapper.rst.txt +18 -0
  139. docs/_build/html/_sources/fastfusion.rst.txt +20 -0
  140. docs/_build/html/_sources/fastfusion.util.rst.txt +21 -0
  141. docs/_build/html/_sources/index.rst.txt +87 -0
  142. docs/_build/html/_sources/modules.rst.txt +7 -0
  143. docs/_build/html/_sources/notes/citation.rst.txt +45 -0
  144. docs/_build/html/_sources/notes/definitions.rst.txt +43 -0
  145. docs/_build/html/_sources/notes/faqs.rst.txt +39 -0
  146. docs/_build/html/_sources/notes/modeling/accelerator_energy_latency.rst.txt +72 -0
  147. docs/_build/html/_sources/notes/modeling/component_energy_area.rst.txt +96 -0
  148. docs/_build/html/_sources/notes/modeling/mapping.rst.txt +100 -0
  149. docs/_build/html/_sources/notes/modeling.rst.txt +33 -0
  150. docs/_build/html/_sources/notes/parsing/arithmetic_parsing.rst.txt +136 -0
  151. docs/_build/html/_sources/notes/parsing/setexpressions.rst.txt +63 -0
  152. docs/_build/html/_sources/notes/parsing/yaml_parsing.rst.txt +176 -0
  153. docs/_build/html/_sources/notes/quickstart_and_installation.rst.txt +9 -0
  154. docs/_build/html/_sources/notes/spec/architecture.rst.txt +133 -0
  155. docs/_build/html/_sources/notes/spec/mapping.rst.txt +12 -0
  156. docs/_build/html/_sources/notes/spec/workload.rst.txt +83 -0
  157. docs/_build/html/_sources/notes/spec.rst.txt +36 -0
  158. docs/source/_ext/include_attrs.py +213 -0
  159. docs/source/_ext/include_docstring.py +364 -0
  160. docs/source/_ext/include_functions.py +154 -0
  161. docs/source/_ext/include_notebook.py +131 -0
  162. docs/source/_ext/include_yaml.py +119 -0
  163. docs/source/_ext/inherited_attributes.py +222 -0
  164. docs/source/_ext/paths.py +4 -0
  165. docs/source/conf.py +79 -0
  166. examples/arches/compute_in_memory/_include.yaml +74 -0
  167. examples/arches/compute_in_memory/_include_functions.py +229 -0
  168. examples/arches/compute_in_memory/_load_spec.py +57 -0
  169. examples/arches/compute_in_memory/components/c2c_multiplier.py +181 -0
  170. examples/arches/compute_in_memory/components/dac_c2c_r2r.py +605 -0
  171. examples/arches/compute_in_memory/components/misc.py +195 -0
  172. examples/arches/compute_in_memory/components/util/bit_functions.py +51 -0
  173. examples/arches/compute_in_memory/components/zero_comparator.py +92 -0
  174. examples/arches/compute_in_memory/isaac.yaml +233 -0
  175. examples/arches/compute_in_memory/memory_cells/ecram_demo.yaml +63 -0
  176. examples/arches/compute_in_memory/memory_cells/rram_example.yaml +63 -0
  177. examples/arches/compute_in_memory/memory_cells/rram_isaac_isca_2016.yaml +64 -0
  178. examples/arches/compute_in_memory/memory_cells/rram_neurosim_default.yaml +63 -0
  179. examples/arches/compute_in_memory/memory_cells/rram_raella_isca_2023.yaml +70 -0
  180. examples/arches/compute_in_memory/memory_cells/rram_wan_nature_2022.yaml +63 -0
  181. examples/arches/compute_in_memory/memory_cells/sram_colonnade_jssc_2021.yaml +63 -0
  182. examples/arches/compute_in_memory/memory_cells/sram_example.yaml +63 -0
  183. examples/arches/compute_in_memory/memory_cells/sram_jia_jssc_2020.yaml +63 -0
  184. examples/arches/compute_in_memory/memory_cells/sram_sinangil_jssc_2021.yaml +63 -0
  185. examples/arches/compute_in_memory/memory_cells/sram_wang_vlsi_2022.yaml +63 -0
  186. examples/arches/compute_in_memory/wang_vlsi_2022.yaml +289 -0
  187. examples/arches/eyeriss.yaml +68 -0
  188. examples/arches/fanout_variations/at_glb.yaml +31 -0
  189. examples/arches/fanout_variations/at_glb_with_fanout_node.yaml +34 -0
  190. examples/arches/fanout_variations/at_mac.yaml +31 -0
  191. examples/arches/fanout_variations/at_mac_with_constraints.yaml +38 -0
  192. examples/arches/fanout_variations/at_mac_with_fanout_node.yaml +34 -0
  193. examples/arches/nvdla.yaml +47 -0
  194. examples/arches/simple.yaml +28 -0
  195. examples/arches/tpu_v4i.yaml +67 -0
  196. examples/mappings/unfused_matmuls_to_simple.yaml +33 -0
  197. examples/misc/component_annotated.yaml +33 -0
  198. examples/workloads/gpt3_6.7B.yaml +124 -0
  199. examples/workloads/matmuls.yaml +20 -0
  200. examples/workloads/mobilenet_28.yaml +81 -0
  201. examples/workloads/mobilenet_various_separate.yaml +106 -0
  202. examples/workloads/three_matmuls_annotated.yaml +59 -0
  203. notebooks/.ipynb_checkpoints/fastfusion_arch_study_michael-checkpoint.ipynb +359 -0
  204. notebooks/compute_in_memory/_scripts.py +339 -0
  205. notebooks/compute_in_memory/isaac.guide.ipynb +270 -0
  206. notebooks/compute_in_memory/wang_vlsi_2022.ipynb +602 -0
  207. notebooks/paths.py +4 -0
  208. notebooks/tutorials/.ipynb_checkpoints/1_FFM-checkpoint.ipynb +3110 -0
  209. notebooks/tutorials/FFM.ipynb +3498 -0
  210. notebooks/tutorials/_include.py +48 -0
  211. notebooks/tutorials/component_energy_area.ipynb +363 -0
  212. tests/Q_mapping.yaml +38 -0
  213. tests/__init__.py +0 -0
  214. tests/conv.mapping.yaml +27 -0
  215. tests/conv.workload.yaml +13 -0
  216. tests/conv_sym.mapping.yaml +43 -0
  217. tests/copy.mapping.yaml +35 -0
  218. tests/copy.workload.yaml +15 -0
  219. tests/distribuffers/__init__.py +0 -0
  220. tests/distribuffers/multicast/test_cases.yaml +482 -0
  221. tests/distribuffers/spec/binding/valid_bindings.yaml +97 -0
  222. tests/distribuffers/spec/distributed.yaml +100 -0
  223. tests/distribuffers/spec/logical_arch.yaml +32 -0
  224. tests/distribuffers/spec/physical_arch.yaml +69 -0
  225. tests/distribuffers/test_binding.py +48 -0
  226. tests/frontend/__init__.py +0 -0
  227. tests/frontend/test_mapping_viz.py +52 -0
  228. tests/mapper/__init__.py +0 -0
  229. tests/mapper/configs/conv1d/conv1d.mapping.yaml +31 -0
  230. tests/mapper/configs/conv1d/conv1d.workload.yaml +11 -0
  231. tests/mapper/configs/two_conv1d/two_conv1d.expected.yaml +38 -0
  232. tests/mapper/configs/two_conv1d/two_conv1d.mapping.yaml +54 -0
  233. tests/mapper/configs/two_conv1d/two_conv1d.workload.yaml +19 -0
  234. tests/mapper/test_mapping_to_isl.py +90 -0
  235. tests/mapper/test_spatial_reuse_analysis.py +67 -0
  236. tests/mapper/test_temporal_reuse_analysis.py +56 -0
  237. tests/mapper/util.py +58 -0
  238. tests/matmul.mapping.yaml +29 -0
  239. tests/matmul.workload.yaml +12 -0
  240. tests/matmul_spatial.mapping.yaml +44 -0
  241. tests/mha.renames.yaml +65 -0
  242. tests/mha.workload.yaml +67 -0
  243. tests/mha.yaml +59 -0
  244. tests/mha_full.workload.yaml +67 -0
  245. tests/mobilenet.workload.yaml +35 -0
  246. tests/mobilenet_long.workload.yaml +64 -0
  247. tests/pmappingcache.py +24 -0
  248. tests/processing_stage.arch.yaml +40 -0
  249. tests/snowcat.arch.yaml +36 -0
  250. tests/test_ffm_join_pmappings.py +106 -0
  251. tests/test_ffm_make_pmappings.py +82 -0
  252. tests/test_ffm_make_tile_shapes.py +49 -0
  253. tests/test_mapper.py +100 -0
  254. tests/test_model.py +37 -0
  255. tests/test_plotting.py +72 -0
  256. tests/test_processing_stage.py +46 -0
  257. tests/test_symbolic_model.py +248 -0
  258. tests/test_workload.py +141 -0
@@ -0,0 +1,1681 @@
1
+ from enum import Enum
2
+ from functools import lru_cache
3
+ from math import ceil, log2, prod
4
+ import copy
5
+ import re
6
+ import resource
7
+ import time
8
+ from typing import Callable, Iterator, Optional
9
+ from sympy import Expr, Symbol, factorint, lambdify
10
+ from accelforge import util
11
+ from accelforge._accelerated_imports import np
12
+ from accelforge._accelerated_imports import pd
13
+ import accelforge.frontend.arch as arch
14
+ from accelforge.frontend._workload_isl._isl import get_rank_variable_bounds
15
+ from accelforge.frontend._workload_isl._symbolic import get_projection_expr
16
+ from accelforge.frontend.workload import Einsum
17
+ from accelforge.frontend.mapping import (
18
+ Loop,
19
+ Mapping,
20
+ Temporal,
21
+ Spatial,
22
+ TensorHolder,
23
+ )
24
+ from accelforge.mapper.FFM._make_pmappings.pmapper_job import Job
25
+ from accelforge.mapper.FFM._pareto_df.df_convention import (
26
+ stride2col,
27
+ initial2col,
28
+ iterations2col,
29
+ )
30
+ from accelforge.mapper.FFM._pareto_df.pareto import makepareto_numpy
31
+ from accelforge.model._looptree.reuse.symbolic import IMPERFECT
32
+ from accelforge.mapper.FFM._join_pmappings.pmapping_dataframe import (
33
+ nameloop2col,
34
+ tensor2col,
35
+ firstlatency2col,
36
+ )
37
+ from accelforge.frontend.mapper.metrics import Metrics
38
+ from accelforge.util._frozenset import fzs
39
+ import math
40
+ import sympy
41
+ import numpy as np
42
+ from numbers import Number
43
+
44
+ from accelforge.mapper.FFM._make_pmappings.make_pmappings_from_templates.symbol_relations import (
45
+ SymbolRelations,
46
+ )
47
+ from accelforge.util._sympy.broadcast_max import Max
48
+ from accelforge.mapper.FFM._make_pmappings.make_pmappings_from_templates.run_model import (
49
+ run_model,
50
+ )
51
+
52
+
53
class ComparisonResult(Enum):
    """Classification of an expression's sign relative to zero.

    Two results can be merged with ``|``; see ``__or__`` for the rules.
    """

    ALWAYS_GEQ_THAN_ZERO = "ALWAYS_GEQ_THAN_ZERO"
    ALWAYS_LEQ_THAN_ZERO = "ALWAYS_LEQ_THAN_ZERO"
    ALWAYS_EQUAL_TO_ZERO = "ALWAYS_EQUAL_TO_ZERO"
    UNKNOWN = "unknown"

    def __or__(self, other: "ComparisonResult"):
        """Merge two results.

        Equal operands merge to themselves, ``ALWAYS_EQUAL_TO_ZERO`` gives way
        to the other operand, and every other pairing is ``UNKNOWN``.
        """
        zero = ComparisonResult.ALWAYS_EQUAL_TO_ZERO
        if self == other or other == zero:
            return self
        if self == zero:
            return other
        return ComparisonResult.UNKNOWN
67
+
68
+
69
@lru_cache(maxsize=10000)
def diff(f: Expr, s: Symbol):
    """Return ``sympy.diff(f, s)``; results are memoized on ``(f, s)``."""
    derivative = sympy.diff(f, s)
    return derivative
72
+
73
+
74
@lru_cache(maxsize=10000)
def diff_geq_leq_zero(f: Expr, s: Symbol, bounds: tuple[tuple[Symbol, int, int], ...]):
    """Classify the sign of the derivative of ``f`` with respect to ``s``.

    When ``f`` is a sympy expression, every ``ceiling(x)`` inside it is replaced
    by ``x`` before differentiating. This assumes the ceiling does not flip the
    sign of the derivative; going from positive or negative to zero is
    acceptable and is not counted as a sign change.

    Returns whatever ``geq_leq_zero`` returns for the derivative.
    """

    def _is_ceiling(node):
        return node.is_Function and node.func == sympy.ceiling

    def _unwrap(node):
        return node.args[0]

    stripped = f.replace(_is_ceiling, _unwrap) if isinstance(f, sympy.Expr) else f
    return geq_leq_zero(diff(stripped, s), bounds)
84
+
85
+
86
@lru_cache(maxsize=10000)
def function_range(f: Expr, s: Symbol, lo: int, hi: int):
    """Memoized ``sympy.calculus.util.function_range`` of ``f`` in ``s`` over ``Interval(lo, hi)``."""
    domain = sympy.Interval(lo, hi)
    return sympy.calculus.util.function_range(f, s, domain=domain)
89
+
90
+
91
def expr_replace(f: Expr, old: sympy.Function, new: Expr) -> Expr:
    """Return ``f`` with every application of the function ``old`` replaced by ``new``."""

    def _matches(node):
        return node.is_Function and node.func == old

    def _substitute(node):
        return new

    return f.replace(_matches, _substitute)
96
+
97
+
98
def partition_heaviside(f: Expr) -> tuple[Expr, ...]:
    """Split ``f`` on its Heaviside terms.

    If ``f`` contains ``sympy.Heaviside``, return two expressions: ``f`` with
    every Heaviside application replaced by 1, then with every one replaced
    by 0. Otherwise return ``(f,)``.
    """
    if not f.has(sympy.Heaviside):
        return (f,)
    return tuple(expr_replace(f, sympy.Heaviside, value) for value in (1, 0))
102
+
103
+
104
+ # @lru_cache(maxsize=10000)
105
+ # def _get_function_range(
106
+ # f: Expr,
107
+ # check_symbols: tuple[Symbol, ...],
108
+ # bounds: tuple[tuple[Symbol, int, int], ...],
109
+ # return_min: bool,
110
+ # ) -> list:
111
+ # if isinstance(f, sympy.Expr):
112
+ # f = f.replace(
113
+ # lambda expr: expr.is_Function and expr.func == sympy.ceiling,
114
+ # lambda expr: expr.args[0],
115
+ # )
116
+ # fs = list(partition_heaviside(f))
117
+ # else:
118
+ # fs = [f]
119
+
120
+ # if len(fs) > 1:
121
+ # return [f3 for f2 in fs for f3 in _get_function_range(f2, check_symbols, bounds, return_min)]
122
+
123
+ # f = fs[0]
124
+ # check_symbol = check_symbols[0]
125
+ # check_symbols = check_symbols[1:]
126
+ # bounds = None
127
+ # for s, lo, hi in bounds:
128
+ # if s == check_symbol:
129
+ # bounds = (s, lo, hi)
130
+ # break
131
+ # else:
132
+ # raise ValueError(f"Symbol {check_symbol} not found in bounds")
133
+
134
+ # f_range = sympy.calculus.util.function_range(f, check_symbol, domain=sympy.Interval(lo, hi))
135
+
136
+ # if isinstance(f_range, sympy.FiniteSet):
137
+ # return [f3 for f2 in f_range for f3 in _get_function_range(f2, check_symbols, bounds, return_min)]
138
+
139
+ # target = f_range.left if return_min else f_range.right
140
+ # return _get_function_range(target, check_symbols, bounds, return_min)
141
+
142
+
143
@lru_cache(maxsize=10000)
def _compare_to_zero(
    f: Expr, bounds: tuple[tuple[Symbol, int, int], ...], check_lt_zero: bool
) -> bool:
    """
    Return True if ``f`` may possibly be on the checked side of zero.

    If check_lt_zero is True, then we're checking if the function may possibly be less
    than zero. Otherwise, we're checking if the function may possibly be greater than
    zero.

    If we can't tell, then conservatively return True.

    Args:
        f: Expression (or plain number) to test. Results are memoized, so it
            must be hashable.
        bounds: ``(symbol, lo, hi)`` triples. The ``lo``/``hi`` of the symbol
            chosen for elimination are passed to ``function_range``.
        check_lt_zero: Which side of zero to test for (see above).

    Raises:
        ValueError: If the symbol chosen for elimination has no entry in
            ``bounds``.
    """
    if isinstance(f, sympy.Expr):
        # Drop ceiling(): analyse its argument instead.
        f = f.replace(
            lambda expr: expr.is_Function and expr.func == sympy.ceiling,
            lambda expr: expr.args[0],
        )
        # Heaviside terms are split into two cases (replaced by 1 and by 0).
        fs = list(partition_heaviside(f))
    else:
        fs = [f]

    if len(fs) > 1:
        # Possible if it is possible under either Heaviside case.
        return any(_compare_to_zero(f2, bounds, check_lt_zero) for f2 in fs)

    f = fs[0]
    try:
        # Taking `not` of a comparison sympy cannot decide raises TypeError,
        # which sends us on to the symbolic analysis below.
        if check_lt_zero:
            # Less than zero anywhere == NOT geq zero everywhere
            return not f >= 0
        else:
            # Greater than zero anywhere == NOT leq zero everywhere
            return not f <= 0
    except TypeError:
        pass

    # Min/Max are resolved through their arguments. For "< 0": any argument of
    # a Min, or all arguments of a Max. For "> 0" the quantifiers swap.
    min_check, max_check = (any, all) if check_lt_zero else (all, any)
    if isinstance(f, sympy.Min):
        return min_check(_compare_to_zero(g, bounds, check_lt_zero) for g in f.args)
    if isinstance(f, sympy.Max):
        return max_check(_compare_to_zero(g, bounds, check_lt_zero) for g in f.args)

    # Tried this on one workload and had marginally faster speeds with choosing the
    # symbol that appears the least times. Also tried the symbol that appears the most
    # times and the symbol that appears first in the bounds list. They had equivalent
    # speeds, approx. 3% slower overall tile shape exploration than min.
    chosen_s = min(f.free_symbols, key=lambda s: f.count(s))
    # Find the bounds of the chosen symbol; s, lo, hi stay bound after `break`.
    for s, lo, hi in bounds:
        if s == chosen_s:
            break
    else:
        raise ValueError(f"Symbol {chosen_s} not found in bounds")

    try:
        f_range = function_range(f, s, lo, hi)
    except (NotImplementedError, TypeError):
        # The range could not be computed: be conservative.
        return True

    if isinstance(f_range, sympy.FiniteSet):
        # Finite set of candidate values: possible if any of them qualifies.
        return any(_compare_to_zero(f2, bounds, check_lt_zero) for f2 in f_range)
    else:
        # Otherwise recurse on the relevant end of the range: the left end for
        # "< 0", the right end for "> 0".
        return _compare_to_zero(
            f_range.left if check_lt_zero else f_range.right,
            bounds,
            check_lt_zero,
        )
209
+
210
+
211
@lru_cache(maxsize=10000)
def geq_leq_zero(
    f: Expr,
    bounds: tuple[tuple[Symbol, int, int], ...],
):
    """Classify ``f`` relative to zero as a ``ComparisonResult``.

    Asks ``_compare_to_zero`` whether ``f`` may be negative and whether it may
    be positive: both gives ``UNKNOWN``, only negative gives
    ``ALWAYS_LEQ_THAN_ZERO``, only positive gives ``ALWAYS_GEQ_THAN_ZERO``, and
    neither gives ``ALWAYS_EQUAL_TO_ZERO``.
    """
    may_be_negative = _compare_to_zero(f, bounds, check_lt_zero=True)
    may_be_positive = _compare_to_zero(f, bounds, check_lt_zero=False)

    if may_be_negative:
        if may_be_positive:
            return ComparisonResult.UNKNOWN
        return ComparisonResult.ALWAYS_LEQ_THAN_ZERO
    if may_be_positive:
        return ComparisonResult.ALWAYS_GEQ_THAN_ZERO
    return ComparisonResult.ALWAYS_EQUAL_TO_ZERO
227
+
228
+
229
def compile_dict(symbols, dictionary):
    """Return a dict with the same keys as ``dictionary``.

    Each value is replaced by ``util._lambdify_type_check(symbols, value)``.
    """
    compiled = {}
    for key, value in dictionary.items():
        compiled[key] = util._lambdify_type_check(symbols, value)
    return compiled
235
+
236
+
237
class Goal:
    """
    Optimization goal for a symbol, used when deciding what may be pruned.

    ``goal`` is either ``None`` (no goal yet) or one of the strings handled
    below: "min", "max", "min_per_prime_factor", "max_per_prime_factor", "diff".

    "X subset Y" means that Y blocks pruning in every case where X does:

    - min is a subset of min_per_prime_factor is a subset of diff
    - max is a subset of max_per_prime_factor is a subset of diff

    Combining two goals that disagree yields the larger space.
    """

    def __init__(
        self,
        goal: str = None,
        max_value: Optional[float] = None,
        only_care_if_valid: bool = False,
    ):
        self.goal = goal
        self.max_value = max_value
        self.only_care_if_valid = only_care_if_valid

    def __or__(self, other: "Goal"):
        """Combine two goals into the smallest goal that covers both.

        A goal whose ``goal`` is ``None`` acts as the identity: a shallow copy
        of the other operand is returned. Otherwise both operands must agree
        on ``max_value`` and ``only_care_if_valid`` (checked with ``assert``).
        """
        if self.goal is None:
            return copy.copy(other)
        if other.goal is None:
            return copy.copy(self)
        assert self.max_value == other.max_value
        assert self.only_care_if_valid == other.only_care_if_valid

        if self.goal == other.goal:
            # Same goal on both sides: the space doesn't change.
            merged = self.goal
        else:
            # Any disagreement widens to "diff", unless one side is simply the
            # per-prime-factor superset of the other.
            merged = "diff"
            for narrow, wide in (
                ("min", "min_per_prime_factor"),
                ("max", "max_per_prime_factor"),
            ):
                if {self.goal, other.goal} == {narrow, wide}:
                    merged = wide
                    break

        return Goal(
            merged,
            max_value=self.max_value,
            only_care_if_valid=self.only_care_if_valid or other.only_care_if_valid,
        )

    def __str__(self):
        return f"{self.goal} {self.max_value} {self.only_care_if_valid}"

    def __repr__(self):
        return f"Goal({self.goal}, {self.max_value}, {self.only_care_if_valid})"

    def __invert__(self):
        """Swap "min" and "max"; other goals come back as a shallow copy.

        Raises:
            ValueError: For the per-prime-factor goals, which cannot be inverted.
        """
        if self.goal == "min_per_prime_factor":
            raise ValueError("Can't invert min_per_prime_factor")
        if self.goal == "max_per_prime_factor":
            raise ValueError("Can't invert max_per_prime_factor")
        for current, opposite in (("min", "max"), ("max", "min")):
            if self.goal == current:
                return Goal(opposite, self.max_value, self.only_care_if_valid)
        return copy.copy(self)

    def __eq__(self, other: "Goal"):
        """Goals are equal when all three attributes are equal."""
        if not isinstance(other, Goal):
            return False
        return (
            self.goal == other.goal
            and self.max_value == other.max_value
            and self.only_care_if_valid == other.only_care_if_valid
        )
307
+
308
+
309
class Objective:
    """A named formula together with optional limits on its value.

    Attributes (all set in ``__init__``):
        name: Identifier of the objective.
        formula: The formula after ``simplify``; plain numbers are first
            wrapped with ``sympy.Number``.
        max_value, min_value: Optional limits on the formula's value.
        only_care_if_valid: Requires at least one of ``max_value`` /
            ``min_value`` to be given (checked with ``assert``).
        inclusive, try_best_if_none_reaches_min: Stored as given.
    """

    def __init__(
        self,
        name: str,
        formula: Expr | Number,
        max_value: float = None,
        symbols: list[str] = None,
        only_care_if_valid: bool = False,
        min_value: float = None,
        inclusive: bool = True,
        try_best_if_none_reaches_min: bool = False,
    ):
        # Plain Python numbers become sympy numbers so they can be simplified.
        as_expr = sympy.Number(formula) if isinstance(formula, Number) else formula

        self.name: str = name
        self.formula: Expr = simplify(as_expr)
        self._symbols: list[str] = symbols
        self.max_value: float = max_value
        self.min_value: float = min_value
        self.only_care_if_valid: bool = only_care_if_valid
        if only_care_if_valid:
            has_limit = max_value is not None or min_value is not None
            assert has_limit
        self.inclusive: bool = inclusive
        self.try_best_if_none_reaches_min: bool = try_best_if_none_reaches_min
333
+
334
+
335
def is_constant(f: Expr) -> bool:
    """Return ``f.is_constant()``.

    If that call raises ``ValueError``, fall back to requiring every argument
    of ``f`` to be constant (checked recursively).
    """
    try:
        verdict = f.is_constant()
    except ValueError:
        for arg in f.args:
            if not is_constant(arg):
                return False
        return True
    return verdict
340
+
341
+
342
@lru_cache(maxsize=10000)
def _try_replace_single_term(
    t: Expr,
    symbols_enumerated: fzs[Symbol],
    bounds: tuple[tuple[Symbol, int, int], ...],
):
    """Try to reduce a term to the single enumerated symbol it depends on.

    When ``t`` contains exactly one symbol from ``symbols_enumerated``, the
    sign of the derivative of ``t`` with respect to that symbol selects a goal:
    non-negative gives ``Goal("min")``, non-positive gives ``Goal("max")``,
    unknown gives ``Goal("diff")``, and identically zero gives no goal.

    Args:
        t: Term to analyse.
        symbols_enumerated: Enumerated symbols. Must be hashable because the
            result is memoized.
        bounds: ``(symbol, lo, hi)`` triples forwarded to ``diff_geq_leq_zero``.

    Returns:
        ``(symbol, goal)`` when the term was reduced (``goal`` may be ``None``),
        otherwise ``(t, None)``.

    Raises:
        ValueError: If ``diff_geq_leq_zero`` returns something that is not a
            known ``ComparisonResult``.
    """
    shared = t.free_symbols & symbols_enumerated
    if len(shared) != 1:
        return t, None

    s = next(iter(shared))
    # Only the derivative analysis is best-effort. The try is kept this narrow
    # so the sanity-check ValueError below is not swallowed by our own handler.
    try:
        diff_result = diff_geq_leq_zero(t, s, bounds)
    except (TypeError, ValueError):
        return t, None

    if diff_result == ComparisonResult.ALWAYS_GEQ_THAN_ZERO:
        goal = Goal("min")
    elif diff_result == ComparisonResult.ALWAYS_LEQ_THAN_ZERO:
        goal = Goal("max")
    elif diff_result == ComparisonResult.UNKNOWN:
        goal = Goal("diff")
    elif diff_result == ComparisonResult.ALWAYS_EQUAL_TO_ZERO:
        goal = None
    else:
        raise ValueError(
            f"Comparison result {diff_result} is not a valid comparison result"
        )
    return s, goal
369
+
370
+
371
def try_replace_single_term(
    t: Expr,
    symbols_enumerated: fzs[Symbol],
    bounds: tuple[tuple[Symbol, int, int], ...],
):
    """Call ``_try_replace_single_term`` with only the enumerated symbols that occur in ``t``."""
    relevant = symbols_enumerated & t.free_symbols
    return _try_replace_single_term(t, relevant, bounds)
377
+
378
+
379
@lru_cache(maxsize=10000)
def _partition_formula(
    f: Expr,
    symbols_enumerated: set[Symbol],
    bounds: tuple[tuple[Symbol, int, int], ...],
) -> dict[Symbol, Goal]:
    """Break ``f`` into terms and assign a ``Goal`` to each resulting key.

    Args:
        f: Formula to partition.
        symbols_enumerated: Enumerated symbols. Annotated as ``set`` but it
            must be hashable because this function is memoized;
            ``partition_formula`` wraps it in ``fzs`` before calling.
        bounds: ``(symbol, lo, hi)`` triples forwarded to ``geq_leq_zero`` and
            ``try_replace_single_term``.

    Returns:
        Mapping from a symbol (or a term that contains only enumerated
        symbols) to its combined ``Goal``. Empty if ``f`` has no enumerated
        symbols.
    """
    goals: dict[Symbol, Goal] = {}

    # Merge a new goal into whatever is already recorded for `symbol`.
    # NOTE(review): **kwargs is accepted but never used.
    def update_goal(symbol: Symbol, goal: str, **kwargs):
        goals[symbol] = Goal(goal) | goals.get(symbol, Goal())

    # Sign tracking for products: False = keep goals, True = invert goals,
    # None = sign unknown (every goal becomes "diff").
    negate = False

    if not f.free_symbols & symbols_enumerated:
        return goals

    # Substitute 1 for every non-enumerated symbol that affects_comparison()
    # reports as not affecting the comparison.
    def _try_replace_unknowns(t: Expr):
        for s in t.free_symbols - symbols_enumerated:
            if not affects_comparison(t, s, symbols_enumerated):
                t = t.subs(s, 1)
        return t

    # Group the arguments of f. Terms made only of enumerated symbols are
    # recombined into one expression of f's own type; terms that mix in other
    # symbols are passed through individually.
    def _recombine_terms(terms: list[Expr]):
        can_evaluate = []
        no_relation = []
        others = {}
        for t in terms:
            t = _try_replace_unknowns(t)
            try:
                # Terms without enumerated symbols are dropped.
                if not t.free_symbols & symbols_enumerated:
                    continue
            except (TypeError, ValueError):
                pass
            if t.free_symbols.issubset(symbols_enumerated):
                can_evaluate.append(t)
            elif t.free_symbols.isdisjoint(symbols_enumerated):
                # NOTE(review): no_relation is collected but never read.
                no_relation.append(t)
            else:
                # Bucket by the set of non-enumerated symbols in the term.
                others.setdefault(fzs(t.free_symbols - symbols_enumerated), []).append(
                    t
                )

        # Grab the terms that we can evaluate directly first
        chosen = []
        if can_evaluate:
            chosen.append(type(f)(*can_evaluate))
        # Ignore no relation
        chosen.extend([x for v in others.values() for x in v])

        return chosen

    if isinstance(f, (sympy.Max, sympy.Min, sympy.Add, sympy.ceiling)):
        terms = _recombine_terms(f.args)
    elif isinstance(f, sympy.Mul):
        terms = _recombine_terms(f.args)
        # If the formula is a product:
        # - Divide the max value by the constant factors
        # - For non-constant factors, if they're >1 then we can keep the max.
        #   Otherwise we have to drop it.
        # NOTE(review): the loop below only tracks the sign of each factor; no
        # max-value handling is visible here.
        for t in f.args:
            geq_result = geq_leq_zero(t, bounds)
            if geq_result == ComparisonResult.ALWAYS_LEQ_THAN_ZERO:
                # A non-positive factor flips the goals of the other factors.
                negate = not negate
            elif geq_result == ComparisonResult.UNKNOWN:
                # Unknown sign: the direction cannot be determined.
                negate = None
                break
            elif geq_result == ComparisonResult.ALWAYS_GEQ_THAN_ZERO:
                pass
            elif geq_result == ComparisonResult.ALWAYS_EQUAL_TO_ZERO:
                pass
            else:
                raise ValueError(
                    f"Comparison result {geq_result} is not a valid comparison result"
                )
    else:
        terms = [_try_replace_unknowns(f)]

    for term in terms:
        # May rebind `term` to the single enumerated symbol it depends on.
        term, goal = try_replace_single_term(term, fzs(symbols_enumerated), bounds)
        if goal is not None:
            update_goal(term, goal.goal)
            continue

        # Constant! Don't care
        if len(term.free_symbols & symbols_enumerated) == 0:
            continue

        # Only enumerated symbols remain: record the whole term with goal "min".
        if term.free_symbols.issubset(symbols_enumerated):
            update_goal(term, "min")
            continue

        # Don't recurse with the same formula. If we got here without simplifying it,
        # give up and mark everything "diff".
        if term == f:
            for symbol in term.free_symbols:
                update_goal(symbol, "diff")
        else:
            for subterm, subgoal in partition_formula(
                term, symbols_enumerated, bounds
            ).items():
                goals[subterm] = subgoal | goals.get(subterm, Goal())

    # Apply the product sign computed above. Reassigning existing keys while
    # iterating is safe because the dict's size does not change.
    for k, v in goals.items():
        if negate:
            goals[k] = ~v
        if negate is None:
            v.goal = "diff"

    return goals
488
+
489
+
490
@lru_cache(maxsize=10000)
def _get_n_prime_factors(n: int) -> int:
    """Return the number of entries in the prime factorization of ``n``.

    The result is the size of the mapping returned by ``factorint`` (presumably
    SymPy's ``factorint``, which has one key per distinct prime -- confirm
    against the file's imports). Results are memoized.
    """
    factorization = factorint(n)
    return len(factorization)
493
+
494
+
495
def partition_formula(
    f: Expr,
    symbols_enumerated: set[Symbol],
    bounds: tuple[tuple[Symbol, int, int], ...],
) -> dict[Symbol, Goal]:
    """Partition ``f`` into goals, looking only at enumerated symbols it uses.

    ``symbols_enumerated`` is narrowed to the symbols that appear in ``f`` and
    wrapped with ``fzs`` before being handed, together with ``bounds``, to
    ``_partition_formula``.
    """
    relevant_symbols = symbols_enumerated & f.free_symbols
    return _partition_formula(f, fzs(relevant_symbols), bounds)
501
+
502
+
503
def get_possible_factor_sizes(n: int, imperfect: bool = False) -> list[int]:
    """List candidate factor sizes of ``n`` in ascending order, without repeats.

    Every ``i`` from 1 up to ``ceil(sqrt(n))`` is paired with ``ceil(n / i)``.
    With ``imperfect=False`` only values of ``i`` that divide ``n`` exactly are
    used, so the result is the set of divisors of ``n``. With
    ``imperfect=True`` every ``i`` is used, so sizes that do not divide ``n``
    evenly are included as well.
    """
    upper = math.ceil(n**0.5)
    sizes = set()
    for candidate in range(1, upper + 1):
        if imperfect or n % candidate == 0:
            # Record the candidate together with its (rounded-up) cofactor.
            sizes.update((candidate, math.ceil(n / candidate)))
    return sorted(sizes)
511
+
512
+
513
def append_vector(matrix: np.ndarray, vector: np.ndarray):
    """Append every value of ``vector`` to every row of ``matrix``.

    The result has ``matrix.shape[0] * vector.shape[0]`` rows: each row of
    ``matrix`` is repeated once per entry of ``vector`` and gets that entry as
    an extra last column. When ``matrix`` is ``None`` the vector is returned as
    a single column.
    """
    column = vector.reshape(-1, 1)
    if matrix is None:
        return column

    n_values = vector.shape[0]
    n_rows = matrix.shape[0]
    # Row r of `matrix` occupies rows r * n_values .. (r + 1) * n_values - 1.
    repeated_rows = np.repeat(matrix, n_values, axis=0)
    # The vector cycles once per original row so it lines up with the repeats.
    cycled_column = np.tile(column, (n_rows, 1))
    return np.concatenate((repeated_rows, cycled_column), axis=1)
523
+
524
+
525
@lru_cache(maxsize=10000)
def simplify(f: Expr):
    """Return ``f.simplify()``, memoized on ``f`` (which must be hashable)."""
    simplified = f.simplify()
    return simplified
528
+
529
+
530
def symbol2int(symbol: Symbol):
    """Return the first run of digits in ``symbol.name`` as an ``int``.

    Raises ``IndexError`` when the name contains no digits.
    """
    digit_runs = re.findall(r"(\d+)", symbol.name)
    return int(digit_runs[0])
532
+
533
+
534
@lru_cache(maxsize=10000)
def f_minus_other_f(f: Expr, symbols_enumerated: set[Symbol]):
    """Build the relation ``f' - f > 0`` for a second, independent choice.

    ``f'`` is ``f`` with every enumerated symbol ``s`` it uses replaced by a
    fresh positive-integer symbol named ``"<s>_2"``. Symbols outside
    ``symbols_enumerated`` are shared between ``f`` and ``f'``.

    Because of the ``lru_cache`` decorator both arguments must be hashable, so
    ``symbols_enumerated`` has to be a frozen set in practice.
    """
    other = f
    shared_symbols = f.free_symbols & symbols_enumerated
    for original in shared_symbols:
        twin = sympy.Symbol(f"{original}_2", integer=True, positive=True)
        other = other.subs(original, twin)
    return other - f > 0
540
+
541
+
542
@lru_cache(maxsize=10000)
def affects_comparison(f: Expr, s: Symbol, symbols_enumerated: set[Symbol]):
    """Tell whether ``s`` still appears when comparing two choices of ``f``.

    The comparison is the object built by ``f_minus_other_f``. ``s`` is
    reported as affecting it only if it is a free symbol of that object both
    before and after ``simplify``. Non-``Expr`` inputs always yield ``False``.

    All arguments must be hashable because the result is memoized.
    """
    if not isinstance(f, sympy.Expr):
        return False

    delta = f_minus_other_f(f, symbols_enumerated)
    # NOTE(review): f_minus_other_f returns a relational (`... > 0`). Recent
    # SymPy releases reportedly no longer treat relationals as `Expr`, in which
    # case this guard would always fail and the function would always return
    # False -- confirm against the SymPy version in use.
    if isinstance(delta, sympy.Expr) and s in delta.free_symbols:
        # Simplification may cancel `s` out; only then is it irrelevant.
        return s in simplify(delta).free_symbols
    return False
555
+
556
+
557
def get_padded_choices(
    symbols_enumerated: list[Symbol],
    symbols_non_enumerated_set: set[Symbol],
    choices_enumerated: np.ndarray,
    what_tiles_symbol: SymbolRelations,
    minimize_formula: Expr = None,
    maximize_formula: Expr = None,
):
    """Map every symbol to a column of values, padding non-enumerated ones.

    Enumerated symbols get their column of ``choices_enumerated``. Each
    non-enumerated symbol gets a column of ones, unless a formula to minimize
    or maximize is given: then the sign of the derivative of the (signed)
    formula with respect to the symbol, as reported by ``diff_geq_leq_zero``,
    decides between ones and ``what_tiles_symbol.get_max_size(symbol)``.

    Raises:
        ValueError: if both formulas are given (raised only once a
            non-enumerated symbol is processed), if the derivative's sign is
            unknown, or if ``diff_geq_leq_zero`` returns an unexpected value.
    """
    n_rows = choices_enumerated.shape[0]
    unit_column = np.ones(n_rows, choices_enumerated.dtype)

    padded = {
        sym: choices_enumerated[:, symbols_enumerated.index(sym)]
        for sym in symbols_enumerated
    }

    # Work out which formula (if any) steers the padding. A formula to maximize
    # is negated so that both cases read as "minimize sign * formula".
    both_given = minimize_formula is not None and maximize_formula is not None
    if minimize_formula is not None:
        sign, formula = 1, minimize_formula
    elif maximize_formula is not None:
        sign, formula = -1, maximize_formula
    else:
        sign, formula = None, None

    no_effect = (
        ComparisonResult.ALWAYS_GEQ_THAN_ZERO,
        ComparisonResult.ALWAYS_EQUAL_TO_ZERO,
    )
    for symbol in symbols_non_enumerated_set:
        padded[symbol] = unit_column
        if formula is None:
            continue
        if both_given:
            raise ValueError(
                "Both minimize_formula and maximize_formula are not None"
            )

        diff_result = diff_geq_leq_zero(
            sign * formula, symbol, what_tiles_symbol.bounds
        )
        if diff_result == ComparisonResult.ALWAYS_LEQ_THAN_ZERO:
            # Non-increasing in this symbol: the largest size is the best case.
            padded[symbol] = unit_column * what_tiles_symbol.get_max_size(symbol)
        elif diff_result in no_effect:
            continue
        elif diff_result == ComparisonResult.UNKNOWN:
            raise ValueError(f"Can't tell if {symbol} is increasing or decreasing")
        else:
            raise ValueError(
                f"Comparison result {diff_result} is not a valid comparison result"
            )

    return padded
599
+
600
+
601
def check_loops(
    symbols_enumerated: list[Symbol],
    choices_enumerated: np.ndarray,
    max_loop_check_groups: list[tuple[Number, list[Symbol]]],
    what_tiles_symbol: SymbolRelations,
):
    """Drop rows of ``choices_enumerated`` that exceed a group's loop limit.

    For every ``(limit, group)`` with more members than ``limit``, the number
    of checkable members whose size differs from the size of what
    ``what_tiles_symbol.get_inner_tiles`` returns for them is counted per row.
    Rows whose count exceeds ``limit`` are removed. Members that are symbols
    not yet enumerated are ignored.

    Returns the filtered choices array.
    """

    def size_of(item: Symbol | int):
        # Non-symbols are already sizes.
        if not isinstance(item, Symbol):
            return item
        if item in symbols_enumerated:
            return choices_enumerated[:, symbols_enumerated.index(item)]
        return what_tiles_symbol.get_max_size(item)

    def differs_from_inner(item: Symbol | int):
        inner_size = size_of(what_tiles_symbol.get_inner_tiles(item))
        own_size = size_of(item)
        return inner_size != own_size

    def is_checkable(item: Symbol | int):
        return not (isinstance(item, Symbol) and item not in symbols_enumerated)

    for limit, group in max_loop_check_groups:
        # A group that cannot exceed its limit needs no filtering.
        if len(group) <= limit:
            continue

        n_differing = 0
        for member in group:
            if is_checkable(member):
                n_differing += differs_from_inner(member)

        if isinstance(n_differing, np.ndarray):
            # Per-row counts: keep only rows within the limit.
            choices_enumerated = choices_enumerated[n_differing <= limit]
        elif n_differing > limit:
            # The count is the same for every row, so every row is rejected.
            choices_enumerated = choices_enumerated[0:0, :]

    return choices_enumerated
644
+
645
+
646
def coalesce_symbols(
    update_symbol2goal: Callable,
    symbols_enumerated: list[Symbol],
    symbol2goal: dict[Symbol, Goal],
    log_message: Callable,
    bounds: tuple[tuple[Symbol, int, int], ...],
):
    """Simplify the tracked ``formula -> goal`` mapping until it stops changing.

    Each pass rebuilds the mapping from scratch. For every formula it, in
    order: drops it if it uses no enumerated symbol; keeps it untouched if it
    is itself an enumerated symbol; strips constant terms/factors; substitutes
    1 for symbols that ``affects_comparison`` reports as irrelevant; tries
    ``try_replace_single_term`` when only one enumerated symbol is left; flips
    pure reciprocals; and finally drops the formula if its direction in every
    symbol is covered by existing "min"/"max" goals (marking symbols it
    disagrees with as "diff").

    Args:
        update_symbol2goal: Called as ``update_symbol2goal(key, goal, mapping)``
            to record a goal into the mapping being built.
        symbols_enumerated: Symbols that have concrete choices so far.
        symbol2goal: Mapping of formulas/symbols to their current goals.
        log_message: Logging callback; called with one or two strings here.
        bounds: Passed through to ``try_replace_single_term`` and
            ``diff_geq_leq_zero``.

    Returns:
        The mapping produced by the last pass.
    """
    sym_enumerated_set = fzs(symbols_enumerated)
    new_symbol2goal = {}

    log_message("coalesce symbols", f"initial")
    for s, g in symbol2goal.items():
        log_message(f"\t{g.goal}: {s}")

    # Iterate to a fixed point: stop once a pass reproduces its input mapping.
    changed = True
    while changed:
        new_symbol2goal = {}

        def latest(s=None):
            # With no argument: the previous mapping overlaid with this pass's
            # entries. With a key: this pass's goal if present, else the
            # previous pass's goal (KeyError if it is in neither).
            if s is None:
                x = dict(symbol2goal)
                x.update(new_symbol2goal)
                return x
            return new_symbol2goal[s] if s in new_symbol2goal else symbol2goal[s]

        for formula, goal in list(symbol2goal.items()):
            # Not dependent on any enumerated symbols, so drop it
            if not formula.free_symbols & sym_enumerated_set:
                log_message("coalesce symbols", f"dropping constant: {formula}")
                continue

            # It is an enumerated symbol, so just keep it
            if formula in symbols_enumerated:
                update_symbol2goal(formula, goal, new_symbol2goal)
                continue

            # If it's a sum, remove any terms that are constant
            if isinstance(formula, sympy.Add):
                # The loop walks the args of the original sum while `formula`
                # is rewritten underneath it.
                for term in formula.args:
                    if len(term.free_symbols) == 0:
                        formula = formula.subs(term, 0)
                        log_message("coalesce symbols", f"dropping constant: {term}")
                        continue
                # Unwrap whatever is left if it has exactly one argument.
                if len(formula.args) == 1:
                    formula = formula.args[0]

            # If it's a product, remove any terms that are constant
            if isinstance(formula, sympy.Mul):
                for term in formula.args:
                    if len(term.free_symbols) == 0:
                        formula = formula.subs(term, 1)
                        # Dropping a negative factor reverses the ordering, so
                        # the goal is inverted to compensate.
                        if term < 0:
                            goal = ~goal
                        log_message("coalesce symbols", f"dropping constant: {term}")
                        continue
                if len(formula.args) == 1:
                    formula = formula.args[0]

            # If it's a function of a non-enumerated symbol or a symbol that we can't
            # compare and it won't affect comparisons, then we can drop it.

            # If it's a function of a non-enumerated symbol &
            for s in formula.free_symbols:
                # Enumerated symbols whose goal is not "diff" are comparable
                # and therefore left alone.
                if s in symbols_enumerated and latest().get(s, Goal()).goal != "diff":
                    continue

                if not affects_comparison(formula, s, sym_enumerated_set):
                    formula = formula.subs(s, 1)
                    log_message(
                        "coalesce symbols",
                        f"dropping non-comparable symbol that does not affect comparison {s}: {formula}",
                    )
                    continue
                else:
                    log_message(
                        "coalesce symbols",
                        f"keeping dropping symbol that affects comparison {s}: {formula}",
                    )

            # If there's only one symbol in the formula, we can try to replace it with
            # just the symbol.
            if len(formula.free_symbols & sym_enumerated_set) == 1:
                formula, new_goal = try_replace_single_term(
                    formula, sym_enumerated_set, bounds
                )
                if new_goal is not None:
                    log_message("coalesce symbols", f"replacing single term: {formula}")
                    update_symbol2goal(formula, new_goal, new_symbol2goal)
                    # NOTE(review): there is no `continue` here, so the rewritten
                    # formula also runs through the checks below and may be
                    # recorded again with the original `goal` -- confirm this
                    # is intended.

            # If we're a fraction and all of our symbols are in the denominator, replace
            # it with the reciprocal and change the goal
            if isinstance(formula, sympy.Mul):
                for term in formula.args:
                    if len(term.free_symbols) == 0:
                        continue
                    if isinstance(term, sympy.Pow) and term.args[1] == -1:
                        continue
                    break
                else:
                    # Every non-constant factor was a power of -1.
                    log_message("coalesce symbols", f"replacing reciprocal: {formula}")
                    formula = 1 / formula
                    goal = ~goal

            # # If a symbol does not affect the formula, we can remove it
            # for s in formula.free_symbols:
            #     diff_result = diff_geq_leq_zero(formula, s, bounds)
            #     if diff_result == ComparisonResult.ALWAYS_EQUAL_TO_ZERO:
            #         formula = formula.subs(s, 1)
            #         log_message("coalesce symbols", f"dropping symbol based on derivative == 0: {s}: {formula}")
            #         continue
            #     else:
            #         log_message("coalesce symbols", f"not dropping symbol based on derivative == 0: {s}: {formula}")

            # If a formula agrees entirely with other goals, then we can remove it
            disagrees = []
            for s in formula.free_symbols:
                g = latest(s).goal if s in latest() else None
                if g in ["min", "max"]:
                    # Translate the formula's goal into the goal it implies for
                    # `s`, based on the sign of the derivative with respect to `s`.
                    diff_result = diff_geq_leq_zero(formula, s, bounds)
                    if diff_result == ComparisonResult.ALWAYS_LEQ_THAN_ZERO:
                        this_goal = (~goal).goal
                    elif diff_result == ComparisonResult.ALWAYS_GEQ_THAN_ZERO:
                        this_goal = (goal).goal
                    elif diff_result == ComparisonResult.UNKNOWN:
                        break
                    elif diff_result == ComparisonResult.ALWAYS_EQUAL_TO_ZERO:
                        this_goal = g  # Make it agree
                    else:
                        diff_geq_leq_zero(formula, s, bounds)
                        raise ValueError(
                            f"Comparison result {diff_result} is not a valid comparison result"
                        )
                    if g != this_goal:
                        disagrees.append(s)
                    continue
                # `s` has no "min"/"max" goal, so the formula has to be kept.
                break
            else:
                # We didn't break! This formula agrees with all other goals, so we can
                # remove it.
                log_message(
                    "coalesce symbols",
                    f"removing formula that agrees with all other goals: {formula}",
                )
                for s in disagrees:
                    log_message(
                        "coalesce symbols",
                        f"previous formula disagreed with {s}. Changing goal to diff",
                    )
                    update_symbol2goal(s, Goal("diff"), new_symbol2goal)
                continue
            update_symbol2goal(formula, goal, new_symbol2goal)

        changed = symbol2goal != new_symbol2goal
        symbol2goal = new_symbol2goal

    log_message("coalesce symbols", f"final")
    for s, g in symbol2goal.items():
        log_message(f"\t{g.goal}: {s}")

    return symbol2goal
806
+
807
+
808
def get_tile_shape_choices(
    objectives: list[Objective],
    symbols: list[Symbol],
    what_tiles_symbol: SymbolRelations,
    job: "Job",
    keep_symbols: list[Symbol] = (),
    max_loop_check_groups: list[tuple[Number, list[Symbol]]] = (),
):
    """Enumerate tile-shape choices for ``symbols``, pruning after every symbol.

    Symbols are enumerated one at a time. After each new symbol the candidate
    rows are filtered by ``check_loops``, then by each objective's
    ``max_value`` / ``min_value`` bound, and, once at least 100 rows remain, by
    a Pareto filter over goals derived from the objectives and coalesced with
    ``coalesce_symbols``.

    Args:
        objectives: Objectives to prune with. They are deep-copied, so the
            caller's objects are not modified.
        symbols: All tile-shape symbols; also defines the output column order.
        what_tiles_symbol: Relations between symbols (inner/outer tiles,
            strides, maximum sizes, bounds).
        job: Receives log messages and the ``n_total_pmappings`` /
            ``n_evaluated_pmappings`` counter updates.
        keep_symbols: Symbols whose every choice must be preserved (their goal
            is forced to "diff").
        max_loop_check_groups: ``(limit, group)`` pairs forwarded to
            ``check_loops``.

    Returns:
        A 2-D array with one row per surviving choice and one column per entry
        of ``symbols``; an empty array when nothing survives or when there was
        no symbol to enumerate.

    Raises:
        RuntimeError: on internal inconsistencies (messages start with "BUG:").
    """
    objectives = [copy.deepcopy(o) for o in objectives]

    import time

    symbols_enumerated: list[Symbol] = []
    choices_enumerated: np.ndarray = None

    symbols_remaining = list(symbols)

    imperfect = IMPERFECT

    # Inner to outer faster if there's symbols to keep because those symbols end up in
    # the outer loops, so it does those symbols (which end up multiplying our choices)
    # last. Outer to inner is faster if there's no symbols to keep because that's what
    # happened on exactly one workload that Tanner tested.
    # TILE_SHAPE_ORDER = "inner_to_outer_one_rv_at_a_time" if keep_symbols else "outer_to_inner_one_rv_at_a_time"
    TILE_SHAPE_ORDER = "inner_to_outer_one_rv_at_a_time"
    # TILE_SHAPE_ORDER = "inner_to_outer"

    # For imperfect, we make inner tile shapes, then create outer tile shapes that are
    # multiples of the non-residual part of the inner tile shape. This way, the very last
    # iteration of the outer tile shape fully contains the residual part of the inner tile
    # shape, and we don't have any cases where there are residuals stacking across multiple
    # loop levels.
    if IMPERFECT:
        assert TILE_SHAPE_ORDER == "inner_to_outer_one_rv_at_a_time"

    # Goal sets that have already been used for a Pareto pass.
    paretoed_by = []

    prev_time, start_time = time.time(), time.time()
    times = {}

    def time_end(s):
        # Charge the time elapsed since the previous log call to bucket `s`.
        nonlocal prev_time
        cur_time = time.time()
        times.setdefault(s, 0)
        times[s] += cur_time - prev_time
        prev_time = cur_time

    def log_message(message: str, *args: str):
        # Steps that took longer than one second are flagged with "**".
        t = time.time() - prev_time
        s = "**" if t > 1 else ""
        job.log_message(f"{s}{t:.2f}s: {message} {' '.join(args)}")
        # print(f"{time.time() - prev_time:.2f}s: {message} {' '.join(args)}")
        time_end(message)

    log_message("init")

    def eval_objective(
        formula: Expr | Objective,
        choices: np.ndarray,
        minimize_formula: Expr = None,
        maximize_formula: Expr = None,
    ):
        # Evaluate `formula` for every row of `choices`. Non-enumerated symbols
        # are filled in by get_padded_choices.
        if isinstance(formula, Objective):
            formula = formula.formula
        if formula in symbols_enumerated:
            return choices[:, symbols_enumerated.index(formula)]

        padded_choices = get_padded_choices(
            symbols_enumerated=symbols_enumerated,
            symbols_non_enumerated_set=symbols_non_enumerated_set,
            choices_enumerated=choices,
            what_tiles_symbol=what_tiles_symbol,
            minimize_formula=minimize_formula,
            maximize_formula=maximize_formula,
        )
        return util._lambdify_type_check(symbols, formula)(
            **{str(k): v for k, v in padded_choices.items()},
        )

    def grab_symbol(prev_symbol: Symbol = None):
        # TODO: Maybe start with a symbol that would result in more pruning up front?
        # Maximize the # of choices that can be resolved easily
        if TILE_SHAPE_ORDER == "inner_to_outer":
            return symbols_remaining.pop(-1)
        if TILE_SHAPE_ORDER == "outer_to_inner":
            return symbols_remaining.pop(0)

        if TILE_SHAPE_ORDER == "inner_to_outer_one_rv_at_a_time":
            # Continue with a symbol representing the parent tile of the last symbol
            # if possible. Otherwise (see return), just grab any symbol.
            choice = what_tiles_symbol.get_outer_tiles(prev_symbol, none_if_fail=True)
            if choice is not None and choice in symbols_remaining:
                symbols_remaining.remove(choice)
                return choice
            # Pick a symbol that has:
            # - Nobody tiling it
            # - The smallest maximum size
            strides = [s for s in symbols_remaining if what_tiles_symbol.is_stride(s)]
            choice = -1
            if strides:
                max_size = what_tiles_symbol.get_max_size(strides[choice])
                for i, s in enumerate(strides):
                    if what_tiles_symbol.get_inner_tiles(s, none_if_fail=True) is None:
                        if what_tiles_symbol.get_max_size(s) < max_size:
                            choice = i
                            max_size = what_tiles_symbol.get_max_size(s)
                choice = symbols_remaining.index(strides[choice])
            return symbols_remaining.pop(choice)
        elif TILE_SHAPE_ORDER == "outer_to_inner_one_rv_at_a_time":
            # Continue with a symbol representing the child tile of the last symbol
            # if possible. Otherwise (see return), just grab any symbol.
            choice = what_tiles_symbol.get_inner_tiles(prev_symbol, none_if_fail=True)
            if choice is not None and choice in symbols_remaining:
                symbols_remaining.remove(choice)
                return choice
            # Pick a symbol that has:
            # - Tiles nobody
            # - The smallest maximum size
            strides = [s for s in symbols_remaining if what_tiles_symbol.is_stride(s)]
            choice = 0
            if strides:
                max_size = what_tiles_symbol.get_max_size(strides[choice])
                for i, s in enumerate(strides):
                    if what_tiles_symbol.get_outer_tiles(s, none_if_fail=True) is None:
                        if what_tiles_symbol.get_max_size(s) < max_size:
                            choice = i
                            max_size = what_tiles_symbol.get_max_size(s)
                choice = symbols_remaining.index(strides[choice])
            return symbols_remaining.pop(choice)
        else:
            raise RuntimeError(f"BUG: invalid TILE_SHAPE_ORDER: {TILE_SHAPE_ORDER}")

    last_stride_symbol = None  # track the last stride symbol to select next symbol
    symbol = None
    while symbols_remaining:
        # ==============================================================================
        # Enumerate choices for a new symbol
        # ==============================================================================
        symbol = grab_symbol(last_stride_symbol)

        choices = []
        if what_tiles_symbol.is_stride(symbol):
            last_stride_symbol = symbol
            inner_tiles = what_tiles_symbol.get_inner_tiles(symbol, none_if_fail=True)
            outer_tiles = what_tiles_symbol.get_outer_tiles(symbol, none_if_fail=True)

            # Figure out inner size and outer size
            if inner_tiles in symbols_enumerated:
                inner_tiles_type = "enumerated"
                inner_size = None
            elif isinstance(inner_tiles, int):
                inner_tiles_type = "set"
                inner_size = inner_tiles
            else:
                inner_tiles_type = "unknown"
                inner_size = 1

            if outer_tiles in symbols_enumerated:
                outer_tiles_type = "enumerated"
                outer_size = None
            elif isinstance(outer_tiles, int):
                outer_tiles_type = "set"
                outer_size = outer_tiles
            else:
                outer_tiles_type = "unknown"
                outer_size = what_tiles_symbol.get_max_size(outer_tiles)

            if inner_tiles_type == "enumerated" and outer_tiles_type == "enumerated":
                raise RuntimeError(
                    f"BUG: both inner, {inner_tiles}, and outer, {outer_tiles}, "
                    f"tiles of {symbol} are enumerated (thus far: {symbols_enumerated})"
                )
            if inner_tiles_type == "unknown" and outer_tiles_type == "unknown":
                raise RuntimeError("BUG: both inner and outer tiles are unknown")

            # Use inner size and outer size to generate choices
            if inner_tiles_type in {"set", "unknown"} and outer_tiles_type in {
                "set",
                "unknown",
            }:
                factorize = math.ceil(outer_size / inner_size)
                factors = list(get_possible_factor_sizes(factorize, imperfect))
                scaled = np.array(factors) * inner_size
                choices.append(append_vector(choices_enumerated, scaled))
            elif inner_tiles_type == "enumerated":
                assert isinstance(outer_size, int)
                i = symbols_enumerated.index(inner_tiles)
                # The valid sizes depend on the inner tile's value, so rows are
                # grouped by that value and extended group by group.
                for inner_choice in np.unique(choices_enumerated[:, i]):
                    partition = choices_enumerated[
                        np.where(choices_enumerated[:, i] == inner_choice)
                    ]
                    factorize = math.ceil(outer_size / inner_choice)
                    factors = list(get_possible_factor_sizes(factorize, imperfect))
                    scaled = np.array(factors) * inner_choice
                    choices.append(append_vector(partition, scaled))
            else:
                assert outer_tiles_type == "enumerated"
                assert isinstance(inner_size, int)
                i = symbols_enumerated.index(outer_tiles)
                for outer_choice in np.unique(choices_enumerated[:, i]):
                    partition = choices_enumerated[
                        np.where(choices_enumerated[:, i] == outer_choice)
                    ]
                    factorize = math.ceil(outer_choice / inner_size)
                    factors = list(get_possible_factor_sizes(factorize, imperfect))
                    scaled = np.array(factors) * inner_size
                    choices.append(append_vector(partition, scaled))
        elif what_tiles_symbol.is_initial_tile_shape(symbol):
            stride = what_tiles_symbol.get_stride(symbol)
            delta_choices = np.array(list(what_tiles_symbol.get_delta_choices(symbol)))

            outer_stride = what_tiles_symbol.get_outer_tiles(stride, none_if_fail=True)
            assert outer_stride is None or isinstance(
                outer_stride, int
            ), f"outer stride is symbol {outer_stride}"
            if outer_stride is None:
                outer_size = what_tiles_symbol.get_max_size(stride)
            else:
                outer_size = outer_stride

            if not stride in symbols_enumerated and not isinstance(stride, int):
                raise RuntimeError(
                    f"BUG: stride {stride} of initial tile shape "
                    f"{symbol} is neither enumerated nor a specified value"
                )

            if isinstance(stride, int):
                initial_choices = delta_choices + stride
                initial_choices = initial_choices[initial_choices <= outer_size]
                choices.append(append_vector(choices_enumerated, initial_choices))
            else:
                i = symbols_enumerated.index(stride)
                for stride_choice in np.unique(choices_enumerated[:, i]):
                    partition = choices_enumerated[
                        np.where(choices_enumerated[:, i] == stride_choice)
                    ]
                    initial_choices = delta_choices + stride_choice
                    initial_choices = initial_choices[initial_choices <= outer_size]
                    choices.append(append_vector(partition, initial_choices))
        else:
            raise RuntimeError(
                f"BUG: symbol {symbol} is neither stride nor initial tile shape"
            )

        # if not partitions:
        #     return np.array([]).reshape(-1, len(symbols))

        prev_size = choices_enumerated.shape[0] if choices_enumerated is not None else 1
        choices_enumerated = np.concatenate(choices, axis=0)
        job.n_total_pmappings *= choices_enumerated.shape[0] / max(1, prev_size)
        symbols_enumerated.append(symbol)
        log_message("enumerate", f"{symbol}", f"size={choices_enumerated.shape[0]}")

        # ==============================================================================
        # Max fused loops per rank check
        # ==============================================================================

        prev_size = choices_enumerated.shape[0]
        choices_enumerated = check_loops(
            symbols_enumerated,
            choices_enumerated,
            max_loop_check_groups,
            what_tiles_symbol,
        )
        job.log_porp_pmappings_kept(
            f"max_fused_loops_per_rank_variable",
            choices_enumerated.shape[0] / max(1, prev_size),
        )
        log_message(
            "max_fused_loops_per_rank_variable", f"size={choices_enumerated.shape[0]}"
        )

        # ==============================================================================
        # Create initial Pareto-finding goals
        # ==============================================================================
        symbol2goal = {}

        def update_symbol2goal(
            symbol: Symbol, goal: Goal, s2g: dict[Symbol, Goal] = None
        ):
            # Merge `goal` into whatever goal the mapping already holds.
            if s2g is None:
                s2g = symbol2goal
            s2g[symbol] = goal | s2g.get(symbol, Goal())

        # If we're a symbol and a non-enumerated outer loop depends on us, then we need
        # to track this loop. Minimize it if we're imperfect (giving the outer the most
        # choices possible), or diff if we're perfect (since perfect constrains choices
        # so we can't just min).
        for s in symbols_enumerated:
            per_prime_factor = not (
                IMPERFECT
                or _get_n_prime_factors(what_tiles_symbol.get_max_size(s)) == 1
            )
            tiles = what_tiles_symbol.get_outer_tiles(s, none_if_fail=True)
            if isinstance(tiles, Symbol) and tiles not in symbols_enumerated:
                update_symbol2goal(
                    s, Goal("min_per_prime_factor" if per_prime_factor else "min")
                )

            # Same for inner loops depending on us, but maximize if we're imperfect
            tiled_by = what_tiles_symbol.get_inner_tiles(s, none_if_fail=True)
            if isinstance(tiled_by, Symbol) and tiled_by not in symbols_enumerated:
                update_symbol2goal(
                    s, Goal("max_per_prime_factor" if per_prime_factor else "max")
                )

        # If we need to keep this symbol, must preserve all choices for it
        for s in set(symbols_enumerated) & set(keep_symbols):
            update_symbol2goal(s, Goal("diff"))

        symbols_non_enumerated_set = set(symbols) - set(symbols_enumerated)
        sym_enumerated_set = set(symbols_enumerated)

        # When this option is set, keep a single row and skip all pruning; only
        # the counters updated above matter in that case.
        if job.spec.mapper.ffm._count_option_for_mapsapce_size_evaluation != ():
            choices_enumerated = choices_enumerated[:1, :]
            continue

        choices_enumerated_float = choices_enumerated.astype(util.NUMPY_FLOAT_TYPE)

        # ==============================================================================
        # Create functions to Pareto using objectives
        # ==============================================================================
        for objective in list(objectives):
            goals = partition_formula(
                objective.formula, sym_enumerated_set, what_tiles_symbol.bounds
            )
            # If some goal came out as "diff", retry on the expanded formula
            # and keep whichever partition has fewer "diff" goals.
            if any(g.goal == "diff" for g in goals.values()):
                goals2 = partition_formula(
                    sympy.expand(objective.formula),
                    sym_enumerated_set,
                    what_tiles_symbol.bounds,
                )
                goals = min(
                    (goals, goals2),
                    key=lambda x: sum(g.goal == "diff" for g in x.values()),
                )

            # ==========================================================================
            # If there's a max value, then check for validity
            # ==========================================================================
            complete = objective.formula.free_symbols.issubset(sym_enumerated_set)
            prev_size = choices_enumerated.shape[0]
            if objective.max_value is not None:
                try:
                    # minimize_for_objective may raise a TypeError if there's unknown
                    # symbols
                    result = eval_objective(
                        objective.formula,
                        choices_enumerated_float,
                        minimize_formula=objective.formula,
                    )
                    if objective.inclusive:
                        valid = result <= objective.max_value
                    else:
                        valid = result < objective.max_value
                    # A scalar result is broadcast to one flag per row.
                    if not isinstance(valid, np.ndarray):
                        valid = (
                            np.zeros(choices_enumerated.shape[0], dtype=bool) + valid
                        )
                    choices_enumerated = choices_enumerated[valid]
                    choices_enumerated_float = choices_enumerated_float[valid]
                except (TypeError, ValueError):
                    # Best effort: the bound cannot be evaluated yet, so nothing
                    # is filtered on this pass.
                    pass
            if objective.min_value is not None:
                try:
                    # minimize_for_objective may raise a TypeError if there's unknown
                    # symbols
                    result = eval_objective(
                        objective.formula,
                        choices_enumerated_float,
                        maximize_formula=objective.formula,
                    )
                    if objective.inclusive:
                        valid = result >= objective.min_value
                    else:
                        valid = result > objective.min_value
                    if not isinstance(valid, np.ndarray):
                        valid = (
                            np.zeros(choices_enumerated.shape[0], dtype=bool) + valid
                        )

                    if not objective.try_best_if_none_reaches_min:
                        choices_enumerated = choices_enumerated[valid]
                        choices_enumerated_float = choices_enumerated_float[valid]
                    else:
                        if valid.any():
                            choices_enumerated = choices_enumerated[valid]
                            choices_enumerated_float = choices_enumerated_float[valid]
                        elif complete:
                            # NOTE(review): this keeps the rows with the
                            # smallest result although the bound is a minimum;
                            # confirm `min()` (rather than `max()`) is intended.
                            valid |= result == result.min()
                            choices_enumerated = choices_enumerated[valid]
                            choices_enumerated_float = choices_enumerated_float[valid]
                except (TypeError, ValueError):
                    pass

            # Fraction of rows that survived this objective's bound checks. It is
            # derived from row counts rather than from the last `valid` mask: the
            # mask is unset when no bound was checked (or evaluation raised) on
            # the first pass, and left over from an earlier check afterwards.
            porp = choices_enumerated.shape[0] / max(1, prev_size)
            job.log_porp_pmappings_kept(
                f"{objective.name}",
                porp,
            )
            log_message(f"Valid check", f"{objective.name}", f"porp={porp:.2%}")
            if complete:
                objective.max_value = None  # We don't care anymore
                if objective.only_care_if_valid:
                    objectives.remove(objective)
                    log_message(f"Removed {objective.name} because it is always valid")
                    goals.clear()

            log_message(f"formula", f"{objective.formula}", f"{goals}")

            for symbol, goal in goals.items():
                update_symbol2goal(symbol, goal)

        job.n_evaluated_pmappings += choices_enumerated.shape[0]
        if not choices_enumerated.shape[0]:
            return np.array([]).reshape(-1, len(symbols))

        # Small candidate sets are not worth a Pareto pass.
        if choices_enumerated.shape[0] < 100:
            continue

        # ==============================================================================
        # Coalesce symbols. This simplifies our tracked goals. It also breaks down
        # partially-unknown goals into fully-known and/or fully-unknown goals.
        # ==============================================================================
        symbol2goal = coalesce_symbols(
            symbols_enumerated=symbols_enumerated,
            symbol2goal=symbol2goal,
            update_symbol2goal=update_symbol2goal,
            log_message=log_message,
            bounds=what_tiles_symbol.bounds,
        )

        log_message("coalesce symbols", f"{symbol2goal}")

        # Skip the Pareto pass if an earlier pass already used a subset of
        # these (formula, goal) pairs.
        paretoed_by_key = fzs((f, g.goal) for f, g in symbol2goal.items())
        if any(p.issubset(paretoed_by_key) for p in paretoed_by):
            job.log_message(
                "Skipping Pareto because we've already found a Pareto with these objectives."
            )
            continue
        paretoed_by.append(paretoed_by_key)

        objective_values = {}
        for formula, goal in list(symbol2goal.items()):
            objective_values[formula] = eval_objective(
                formula, choices_enumerated_float
            )
            symbol2goal[formula] = goal
            log_message("eval", f"{goal.goal}", f"{formula}")

        if not objective_values:
            # Objective values don't depend on tile shapes
            choices_enumerated = choices_enumerated[:1, :]
            choices_enumerated_float = choices_enumerated_float[:1, :]

        elif not all(
            symbol2goal.get(s, None) == Goal("diff") for s in symbols_enumerated
        ):
            to_pareto = np.concatenate(
                [v.reshape(-1, 1) for v in objective_values.values()], axis=1
            )
            log_message("Pareto", f"size {to_pareto.shape[0]}", "with objectives:")
            for obj in objectives:
                log_message(f"\t{obj.name}: {obj.formula}")
            log_message("Formulas:")
            for formula, goal in symbol2goal.items():
                log_message(f"\t{goal.goal}: {formula}")

            drop_cols = []
            pareto_goals = []
            for i, (formula, goal) in enumerate(objective_values.items()):
                goal = symbol2goal[formula]
                if i not in drop_cols:
                    pareto_goals.append(goal.goal)
            to_pareto = to_pareto[
                :, [i for i in range(to_pareto.shape[1]) if i not in drop_cols]
            ]
            keep = makepareto_numpy(to_pareto, pareto_goals, dirty=True)
            prev_size = choices_enumerated.shape[0]
            choices_enumerated = choices_enumerated[keep]
            # Report survivors relative to the size before filtering; dividing by
            # the filtered size would always report 100%.
            job.log_porp_pmappings_kept(
                f"Pareto", choices_enumerated.shape[0] / max(1, prev_size)
            )
            log_message("pareto", f"size {prev_size} -> {choices_enumerated.shape[0]}")

    # ==================================================================================
    # Return the choices
    # ==================================================================================
    t = time.time() - start_time
    if t > 60:
        a = [
            f"Total time: {t:.2f}s",
            f"Pmapping: {job.mapping.compact_str()}",
        ]
        print("\n\t" + f"\n\t".join(a + job.messages))

    # Rearrange in tile shape order
    if choices_enumerated is None:
        return np.array([])
    return choices_enumerated[:, [symbols_enumerated.index(s) for s in symbols]]
1310
+
1311
+
1312
def makesymbol(name: str):
    """Create a ``Symbol`` called ``name`` with positive-integer assumptions."""
    # TODO: Do the solve() calls work with integer=True?
    assumptions = {"positive": True, "integer": True}
    return Symbol(name, **assumptions)
1315
+
1316
+
1317
def make_keep_symbols(pmapping: Mapping) -> set[Symbol]:
    """Collect the symbolic tile shapes of all fused loops in ``pmapping``.

    For every node that is a ``Loop`` with a truthy ``_fused`` flag, its
    ``initial_tile_shape`` and ``tile_shape`` are included when they are
    ``Symbol`` instances.
    """
    fused_loops = (
        node for node in pmapping.nodes if isinstance(node, Loop) and node._fused
    )
    return {
        shape
        for loop in fused_loops
        for shape in (loop.initial_tile_shape, loop.tile_shape)
        if isinstance(shape, Symbol)
    }
1326
+
1327
+
1328
def get_rank_var_to_fused_loops(
    pmapping: Mapping, shape: dict[str, int]
) -> dict[str, list[Symbol]]:
    """Group the tile shapes of fused loops by the rank variable they iterate.

    Nodes are visited in ``pmapping.nodes`` order, so each list keeps that order.
    ``shape`` is accepted for interface compatibility but is not read here.
    """
    grouped: dict[str, list[Symbol]] = {}
    for node in pmapping.nodes:
        if not isinstance(node, Loop) or not node._fused:
            continue
        rank_variable = node.rank_variable
        if rank_variable not in grouped:
            grouped[rank_variable] = []
        grouped[rank_variable].append(node.tile_shape)
    return grouped
1337
+
1338
+
1339
def set_last_tile_shape_to_one(pmapping):
    """Pin the last loop of every rank variable to a tile shape of 1.

    For each rank variable, the last ``Temporal`` or ``Spatial`` node found in
    ``pmapping.nodes`` gets ``initial_tile_shape = None`` and ``tile_shape = 1``.
    The nodes are modified in place; nothing is returned.
    """
    # Later nodes overwrite earlier ones, leaving the last node per rank variable.
    last_node_by_rank_var = {
        node.rank_variable: node
        for node in pmapping.nodes
        if isinstance(node, (Temporal, Spatial))
    }

    for node in last_node_by_rank_var.values():
        node.initial_tile_shape = None
        node.tile_shape = 1
1350
+
1351
+
1352
# Kept as its own function only so that the time spent here can be counted.
def call_compiled_objective(f, *args):
    """Call ``f`` with the positional ``args`` and hand back whatever it returns."""
    result = f(*args)
    return result
1355
+
1356
+
1357
def _make_tile_shapes(job: "Job"):
    """Enumerate tile-shape choices for ``job`` and evaluate the model on them.

    The function proceeds in this order:

    1. Assigns loop indices to the job's constraints, forces the last loop of
       every rank variable to tile shape 1, and calls ``run_model(job)`` to get
       the symbols and the symbolic expressions.
    2. Builds a list of ``Objective`` objects from the loop-bound constraints,
       the memory/usage expressions, the min-usage constraints and the
       ``"Total"`` entries of ``symbolic_df``.
    3. Calls ``get_tile_shape_choices`` to obtain ``choices_enumerated``, which
       is indexed with one column per entry of ``symbols``.
    4. Compiles ``symbolic_df`` and evaluates every compiled expression on the
       enumerated choices, then adds per-fused-loop iteration columns and
       per-rank stride/initial columns.

    Returns:
        A tuple ``(df, tensor2mapping)``. ``df`` is a ``pandas.DataFrame`` with
        one column per symbol, per compiled expression and per derived column;
        ``tensor2mapping`` is passed through unchanged from ``run_model``.

    Raises:
        ValueError: if a latency column (other than ``first_latency``) or an
            energy column contains a negative value.
        RuntimeError: if a ``Total<SEP>energy`` column of the final frame
            contains a negative value.
        AssertionError: if the final frame contains NaN.
        Exception: whatever ``compile_dict`` raises is re-raised after printing
            the mapping and attaching a note.

    Side effects:
        Mutates the nodes of ``job.mapping`` (via ``set_last_tile_shape_to_one``),
        calls ``job.constraints.set_loop_indices`` and writes
        ``job.n_valid_pmappings``.
    """
    # We're going to convert the job into a list of symbols and objectives
    pmapping = job.mapping
    constraints = job.constraints
    constraints.set_loop_indices(pmapping.nodes)
    set_last_tile_shape_to_one(pmapping)
    t0 = time.time()
    (
        symbols,
        symbolic_df,
        per_memory_usage_df,
        usage_df,
        tensor2mapping,
    ) = run_model(job)

    # NOTE(review): model_time is computed but not read again in this function.
    model_time = time.time() - t0
    shape = job.rank_variable_bounds
    what_tiles_symbol = SymbolRelations.from_pmapping_and_shape(
        pmapping, shape, job.spec.workload
    )
    keep_symbols = make_keep_symbols(pmapping)
    rank_var_to_fused_loops = get_rank_var_to_fused_loops(pmapping, shape)
    # Flatten the per-rank-variable lists into a single set of fused-loop tile shapes.
    all_fused_loops = set(sum(rank_var_to_fused_loops.values(), []))

    objectives = []

    # ==================================================================================
    # Loop bounds constraints. Put these before the other objectives so that hopefully
    # if 100% of the pmappings are pruned, then we're given the actual architecture
    # component that caused it and not the loop bound constraint.
    # ==================================================================================
    loops = [n for n in pmapping.nodes if isinstance(n, Loop)]
    for c in constraints.loop_bounds_constraints:
        min_value, max_value, inclusive = None, None, True
        # An operator containing "product" constrains the product of all targets
        # instead of each target individually.
        is_product = "product" in c.constraint.operator
        operator = c.constraint.operator.replace("product", "")
        # "==" sets both bounds; strict comparisons make the bound exclusive.
        if operator in ["==", "<=", "<"]:
            max_value = c.constraint.value
        if operator in [">=", ">", "=="]:
            min_value = c.constraint.value
        if operator in ["<", ">"]:
            inclusive = False

        targets = []
        for i in c._target_loop_indices:
            n = loops[i]
            # Each target is (enclosing size) / (this loop's tile shape); the
            # enclosing size falls back to get_max_size when get_outer_tiles
            # returns None.
            size = what_tiles_symbol.get_outer_tiles(n.tile_shape, none_if_fail=True)
            if size is None:
                size = what_tiles_symbol.get_max_size(n.tile_shape)
            targets.append(size / n.tile_shape)

        # targets = [loops[i]._calculated_n_iterations for i in c._target_loop_indices]
        if not targets:
            continue

        if is_product:
            targets = [sympy.Mul(*targets)]

        # A lower-bound-only constraint is rewritten as an upper bound by negating
        # both the bound and the targets (x >= m  <=>  -x <= -m).
        if max_value is None and min_value is not None:
            max_value = -min_value
            targets = [-target for target in targets]
            min_value = None

        for target in targets:
            objectives.append(
                Objective(
                    name=f"loop_bounds_{c.constraint}",
                    formula=target,
                    symbols=symbols,
                    only_care_if_valid=True,
                    max_value=max_value,
                    min_value=min_value,
                    inclusive=inclusive,
                )
            )

    # ==================================================================================
    # Memory usage and usage constraints.
    # ==================================================================================
    for k, v in {**per_memory_usage_df, **usage_df}.items():
        # If we only track for pmappings, we only care if it's valid. If we track for
        # all, we care about the value too.

        only_care_if_valid = False
        if k in job.memories_track_pmappings_only:
            only_care_if_valid = True

        # TODO: Update check to see if we may be sharing usage with other
        # pmappings in parallel/pipeline.
        if k in usage_df:
            only_care_if_valid = True

        objectives.append(
            Objective(
                name=k,
                formula=v,
                symbols=symbols,
                only_care_if_valid=only_care_if_valid,
                max_value=1,
            )
        )

    # ==================================================================================
    # Min usage constraints. Put this last because it has some try best if none reach
    # min logic.
    # ==================================================================================
    # NOTE(review): `v` below is not bound by this loop. It is whatever the
    # previous `for k, v in ...` loop left behind (and is undefined if that loop
    # ran zero times). This looks like it should be the usage expression for
    # (component_name, name) -- confirm the intended formula.
    for (
        component_name,
        name,
    ), constraint in job.constraints.min_usage_constraints.items():
        objectives.append(
            Objective(
                name=f"min_usage_{component_name}_{name}",
                formula=v,
                symbols=symbols,
                only_care_if_valid=True,
                min_value=constraint.min_usage,
                try_best_if_none_reaches_min=True,
            )
        )

    # Only the entries whose key contains "Total" become objectives.
    for k, v in symbolic_df.items():
        if "Total" not in k:
            continue

        objectives.append(
            Objective(
                name=k,
                formula=v,
                symbols=symbols,
            )
        )

    # NOTE(review): rank2symbols is built here but not read again in this function.
    rank2symbols = {}
    for node in pmapping.nodes:
        if isinstance(node, (Temporal, Spatial)):
            if node.tile_shape in symbols:
                rank2symbols.setdefault(node.rank_variable, []).append(node.tile_shape)

    # Pairs of (limit, group of fused-loop tile shapes): one for all fused loops
    # together and one per rank variable.
    max_loop_check_groups = [
        (job.spec.mapper.ffm.max_fused_loops, all_fused_loops),
        *[
            (job.spec.mapper.ffm.max_fused_loops_per_rank_variable, x)
            for x in rank_var_to_fused_loops.values()
        ],
    ]

    # Drop groups that contain no loops.
    max_loop_check_groups = [g for g in max_loop_check_groups if g[1]]

    choices_enumerated = get_tile_shape_choices(
        objectives=objectives,
        symbols=symbols,
        what_tiles_symbol=what_tiles_symbol,
        job=job,
        keep_symbols=keep_symbols,
        max_loop_check_groups=max_loop_check_groups,
    )

    # NOTE(review): compiled_per_memory_usage_df and compiled_usage_df are not
    # read again in this function; only compiled_df is evaluated below.
    try:
        compiled_df = compile_dict(symbols, symbolic_df)
        compiled_per_memory_usage_df = compile_dict(symbols, per_memory_usage_df)
        compiled_usage_df = compile_dict(symbols, usage_df)
    except Exception as e:
        print("Compilation failed for this mapping:")
        for node in pmapping.nodes:
            if hasattr(node, "compact_str"):
                print(node.compact_str())
        print(symbolic_df)
        # BaseException.add_note is available from Python 3.11 onwards.
        e.add_note("Compilation failed")
        raise

    choices_float = choices_enumerated.astype(util.NUMPY_FLOAT_TYPE)
    # choices_float = np.tile(choices_float, (1000000, 1))
    # choices_enumerated = np.tile(choices_enumerated, (1000000, 1))

    # `df` starts as a plain dict of columns and becomes a DataFrame further down.
    df = {}
    for i, symbol in enumerate(symbols):
        df[symbol.name] = choices_enumerated[:, i]

    t0 = time.time()
    for key in compiled_df:
        # Each compiled expression receives one positional argument per symbol column.
        df[key] = call_compiled_objective(compiled_df[key], *choices_float.T)
        # A compiled expression may return a single number instead of an array;
        # wrap it so the sign checks below can iterate either way.
        if "latency" in key and "first_latency" not in key:
            val = [df[key]] if isinstance(df[key], Number) else df[key]
            if any(l < 0 for l in val):
                raise ValueError(f"Negative latency for {key}: {val}")
        if "energy" in key:
            val = [df[key]] if isinstance(df[key], Number) else df[key]
            if any(l < 0 for l in val):
                raise ValueError(f"Negative energy for {key}: {val}")

    # Some initial tile shapes are invalid
    # `nloops` is the running index of the fused loop; it is used to name the
    # derived columns.
    for nloops, n in enumerate(
        node for node in job.mapping.nodes if isinstance(node, Loop) and node._fused
    ):
        stride = n.tile_pattern.tile_shape
        # Without an explicit initial tile shape, the initial tile equals the stride.
        initial = (
            n.tile_pattern.initial_tile_shape
            if n.tile_pattern.initial_tile_shape is not None
            else stride
        )
        outer_stride = what_tiles_symbol.get_outer_tiles(stride)
        outer_initial = what_tiles_symbol.get_initial(outer_stride, none_if_fail=True)
        # Symbols are replaced by their enumerated column; anything else is used as is.
        outer_stride = (
            df[outer_stride.name] if isinstance(outer_stride, Symbol) else outer_stride
        )

        # NOTE(review): any non-Symbol outer_initial (including the None that
        # none_if_fail=True can presumably yield) falls back to the already
        # resolved outer_stride -- confirm this is intended for numeric values.
        outer_initial = (
            df[outer_initial.name]
            if isinstance(outer_initial, Symbol)
            else outer_stride
        )

        rank_var_stride = df[stride.name] if isinstance(stride, Symbol) else stride
        rank_var_initial = df[initial.name] if isinstance(initial, Symbol) else initial

        # NOTE: The concept of having one "n_iterations" is precarious when imperfect factorization is involved
        df[iterations2col(nloops)] = np.ceil(
            (outer_initial - rank_var_initial) / rank_var_stride + 1
        )
        df[f"lower_iterations<SEP>{nloops}"] = outer_stride - rank_var_initial

        # Generate rank columns
        einsum: Einsum = job.spec.workload.einsums[job.einsum_name]
        for tensor_access in einsum.tensor_accesses:
            tensor = tensor_access.name
            projections = get_projection_expr(einsum, tensor)
            for rank, expr in projections.items():
                free_symbols = tuple(expr.free_symbols)
                free_symbols_str = tuple(symbol.name for symbol in free_symbols)
                # Skip ranks whose projection does not involve this loop's rank variable.
                if n.rank_variable not in free_symbols_str:
                    continue

                rank_stride = expr.coeff(n.rank_variable) * rank_var_stride

                # Evaluate the projection at this loop's initial tile shape, with
                # every other rank variable set to its bound from `shape`.
                args = []
                for free_rank_var in free_symbols:
                    if free_rank_var.name == n.rank_variable:
                        args.append(rank_var_initial)
                    else:
                        args.append(shape[free_rank_var.name])
                rank_initial = lambdify(free_symbols, expr)(*args)

                df[stride2col(rank, nloops)] = rank_stride
                df[initial2col(rank, nloops)] = rank_initial

    # Presumably the ValueError path is hit when every column is a scalar, in
    # which case pandas needs an explicit index -- TODO confirm.
    try:
        df = pd.DataFrame(df, columns=df.keys())
    except ValueError as e:
        df = pd.DataFrame(df, columns=df.keys(), index=[0])
    assert not df.isna().any().any()

    energy_cols = [c for c in df.columns if "Total<SEP>energy" in c]
    if (df[energy_cols] < 0).any(axis=None):
        # Dump every offending row, column by column, into the error message.
        mapping_with_negative_energy = df[(df[energy_cols] < 0).any(axis=1)]
        print(df.columns)
        msg = ""
        for _, row in mapping_with_negative_energy.iterrows():
            for k, v in row.items():
                msg += f"{k}: {v}\n"
            msg += "\n"
        raise RuntimeError(f"negative energy:\n{msg}")

    # Scale the total pmapping count by the product of all recorded keep rates.
    job.n_valid_pmappings = job.n_total_pmappings * prod(
        job.pmapping_keep_rates.values()
    )
    return df, tensor2mapping
1624
+
1625
+
1626
def make_tile_shapes(job: "Job"):
    """Run ``_make_tile_shapes(job)`` under the job's memory and CPU-time limits.

    When ``job.memory_limit`` / ``job.time_limit`` are finite, the soft
    ``RLIMIT_AS`` / ``RLIMIT_CPU`` limits of the current process are lowered for
    the duration of the call and restored to their previous values afterwards.
    Only the soft limit is touched: an unprivileged process cannot raise a hard
    limit again once it has been lowered, so lowering it would make the limit
    stick for the rest of the process. Failing to set or restore a limit is
    ignored (best effort).

    Args:
        job: The job to explore. ``memory_limit``, ``time_limit``,
            ``log_message`` and ``pretty_str`` are used here; the job is then
            handed to ``_make_tile_shapes``.

    Returns:
        Whatever ``_make_tile_shapes(job)`` returns.

    Raises:
        RuntimeError: if the exploration raised ``MemoryError`` or
            ``TimeoutError``; the original exception is chained as the cause.
    """
    no_limit = float("inf")

    # ``float("inf") // 8`` evaluates to NaN, which would make the report below
    # read "nan B" instead of "infinite", so the unlimited case is kept as-is.
    # NOTE(review): the original comment said "Bytes -> bits", but ``// 8`` is the
    # bits -> bytes direction, and the undivided value is what gets passed to
    # RLIMIT_AS (which is in bytes). Confirm the unit of ``job.memory_limit``.
    if job.memory_limit == no_limit:
        memory_limit = no_limit
    else:
        memory_limit = job.memory_limit // 8

    # (resource id, previous (soft, hard)) for every limit actually changed.
    saved_limits = []

    def lower_soft_limit(which: int, value: int) -> None:
        """Lower the soft limit ``which`` to ``value``, remembering the old one."""
        try:
            previous = resource.getrlimit(which)
            resource.setrlimit(which, (value, previous[1]))
        except (ValueError, OSError):
            # Ignore permission errors when trying to set limits
            return
        saved_limits.append((which, previous))

    if job.memory_limit != no_limit:
        lower_soft_limit(resource.RLIMIT_AS, int(job.memory_limit))

    # NOTE(review): RLIMIT_CPU counts CPU time consumed by the whole process, not
    # time since this call started -- confirm that is intended if a worker
    # process handles more than one job.
    if job.time_limit != no_limit:
        lower_soft_limit(resource.RLIMIT_CPU, ceil(job.time_limit))

    def format_memory_limit() -> str:
        """Render ``memory_limit`` as a human-readable size."""
        if memory_limit == no_limit:
            return "infinite"
        if memory_limit > 1024 * 1024 * 1024:
            return f"{memory_limit / (1024 * 1024 * 1024):.2f} GB"
        elif memory_limit > 1024 * 1024:
            return f"{memory_limit / (1024 * 1024):.2f} MB"
        elif memory_limit > 1024:
            return f"{memory_limit / 1024:.2f} KB"
        else:
            return f"{memory_limit:.2f} B"

    try:
        return _make_tile_shapes(job)
    except MemoryError as e:
        s = f"Job ran out of memory with memory limit {format_memory_limit()}"
        job.log_message(f"Tile shape exploration failed: {s}")
        raise RuntimeError(job.pretty_str()) from e
    except TimeoutError as e:
        s = f"Job timed out with time limit {job.time_limit:.2f} seconds"
        job.log_message(f"Tile shape exploration failed: {s}")
        raise RuntimeError(job.pretty_str()) from e
    finally:
        # Put back exactly what was there before, newest change first.
        for which, previous in reversed(saved_limits):
            try:
                resource.setrlimit(which, previous)
            except (ValueError, OSError):
                # Ignore permission errors when trying to restore limits
                pass