accelforge-0.0.1-py3-none-any.whl

This diff shows the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the packages as they appear in their respective public registries.

Potentially problematic release.


This version of accelforge might be problematic.

Files changed (258)
  1. accelforge/__init__.py +21 -0
  2. accelforge/_accelerated_imports.py +16 -0
  3. accelforge/_deprecate/_simanneal/evalmapping.py +271 -0
  4. accelforge/_deprecate/_simanneal/mapspaceglobals.py +298 -0
  5. accelforge/_deprecate/_simanneal/simanneal.py +666 -0
  6. accelforge/_deprecate/_simanneal/tracking.py +105 -0
  7. accelforge/_deprecate/_simanneal/wrappers.py +218 -0
  8. accelforge/_deprecate/_simanneal2/__init__.py +7 -0
  9. accelforge/_deprecate/_simanneal2/simanneal.py +493 -0
  10. accelforge/_deprecate/_simanneal2/tracking.py +116 -0
  11. accelforge/_deprecate/compatibility_util.py +181 -0
  12. accelforge/_deprecate/layerdeduplication/__init__.py +2 -0
  13. accelforge/_deprecate/layerdeduplication/group_similar_einsums.py +160 -0
  14. accelforge/_deprecate/layerdeduplication/grouped_einsums.py +84 -0
  15. accelforge/_deprecate/mapping_filter_tags/__init__.py +2 -0
  16. accelforge/_deprecate/mapping_filter_tags/ffmt.py +212 -0
  17. accelforge/_deprecate/mapping_filter_tags/onesplit.py +24 -0
  18. accelforge/_deprecate/mapping_filter_tags/util.py +24 -0
  19. accelforge/_deprecate/tags.py +69 -0
  20. accelforge/_deprecate/viz/__init__.py +0 -0
  21. accelforge/_deprecate/viz/interactive.py +159 -0
  22. accelforge/_deprecate/viz/reservationtree.py +307 -0
  23. accelforge/_deprecate/viz/ski_slope.py +88 -0
  24. accelforge/_version.py +15 -0
  25. accelforge/examples.py +39 -0
  26. accelforge/frontend/__init__.py +10 -0
  27. accelforge/frontend/_binding.py +129 -0
  28. accelforge/frontend/_workload_isl/__init__.py +2 -0
  29. accelforge/frontend/_workload_isl/_isl.py +149 -0
  30. accelforge/frontend/_workload_isl/_symbolic.py +141 -0
  31. accelforge/frontend/arch copy.py +1544 -0
  32. accelforge/frontend/arch.py +1642 -0
  33. accelforge/frontend/config.py +63 -0
  34. accelforge/frontend/mapper/__init__.py +5 -0
  35. accelforge/frontend/mapper/ffm.py +126 -0
  36. accelforge/frontend/mapper/mapper.py +7 -0
  37. accelforge/frontend/mapper/metrics.py +30 -0
  38. accelforge/frontend/mapping/__init__.py +1 -0
  39. accelforge/frontend/mapping/mapping.py +1736 -0
  40. accelforge/frontend/model.py +14 -0
  41. accelforge/frontend/renames.py +150 -0
  42. accelforge/frontend/spec copy.py +230 -0
  43. accelforge/frontend/spec.py +301 -0
  44. accelforge/frontend/variables.py +12 -0
  45. accelforge/frontend/workload.py +952 -0
  46. accelforge/mapper/FFM/__init__.py +9 -0
  47. accelforge/mapper/FFM/_join_pmappings/__init__.py +0 -0
  48. accelforge/mapper/FFM/_join_pmappings/compatibility.py +653 -0
  49. accelforge/mapper/FFM/_join_pmappings/compress_pmappings.py +140 -0
  50. accelforge/mapper/FFM/_join_pmappings/join_pmappings.py +703 -0
  51. accelforge/mapper/FFM/_join_pmappings/pmapping_dataframe.py +901 -0
  52. accelforge/mapper/FFM/_join_pmappings/pmapping_group.py +337 -0
  53. accelforge/mapper/FFM/_make_pmappings/contraints/__init__.py +0 -0
  54. accelforge/mapper/FFM/_make_pmappings/contraints/constraints.py +360 -0
  55. accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/__init__.py +1 -0
  56. accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_loops.py +373 -0
  57. accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_pmapping_templates.py +463 -0
  58. accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_reservations.py +95 -0
  59. accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_storage_order.py +382 -0
  60. accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_storages.py +155 -0
  61. accelforge/mapper/FFM/_make_pmappings/make_pmappings.py +411 -0
  62. accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/__init__.py +1 -0
  63. accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/make_pmappings_from_templates.py +407 -0
  64. accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/make_tile_shapes.py +1681 -0
  65. accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/run_model.py +170 -0
  66. accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/symbol_relations.py +174 -0
  67. accelforge/mapper/FFM/_make_pmappings/pmapper_job.py +282 -0
  68. accelforge/mapper/FFM/_pareto_df/df_convention.py +273 -0
  69. accelforge/mapper/FFM/_pareto_df/pareto copy.py +836 -0
  70. accelforge/mapper/FFM/_pareto_df/pareto.py +508 -0
  71. accelforge/mapper/FFM/data.py +61 -0
  72. accelforge/mapper/FFM/main copy.py +236 -0
  73. accelforge/mapper/FFM/main.py +208 -0
  74. accelforge/mapper/FFM/mappings.py +510 -0
  75. accelforge/mapper/FFM/pmappings.py +310 -0
  76. accelforge/mapper/__init__.py +4 -0
  77. accelforge/mapper.py +0 -0
  78. accelforge/model/__init__.py +1 -0
  79. accelforge/model/_looptree/__init__.py +0 -0
  80. accelforge/model/_looptree/accesses.py +335 -0
  81. accelforge/model/_looptree/capacity/__init__.py +1 -0
  82. accelforge/model/_looptree/capacity/aggregators.py +36 -0
  83. accelforge/model/_looptree/capacity/capacity.py +47 -0
  84. accelforge/model/_looptree/energy.py +150 -0
  85. accelforge/model/_looptree/equivalent_ranks.py +29 -0
  86. accelforge/model/_looptree/latency/__init__.py +1 -0
  87. accelforge/model/_looptree/latency/latency.py +98 -0
  88. accelforge/model/_looptree/latency/memory.py +120 -0
  89. accelforge/model/_looptree/latency/processors.py +92 -0
  90. accelforge/model/_looptree/mapping_utilities.py +71 -0
  91. accelforge/model/_looptree/reuse/__init__.py +4 -0
  92. accelforge/model/_looptree/reuse/isl/__init__.py +1 -0
  93. accelforge/model/_looptree/reuse/isl/des.py +59 -0
  94. accelforge/model/_looptree/reuse/isl/isl_functions.py +374 -0
  95. accelforge/model/_looptree/reuse/isl/mapping_to_isl/__init__.py +4 -0
  96. accelforge/model/_looptree/reuse/isl/mapping_to_isl/analyze_mapping.py +297 -0
  97. accelforge/model/_looptree/reuse/isl/mapping_to_isl/skews_from_mapping.py +236 -0
  98. accelforge/model/_looptree/reuse/isl/mapping_to_isl/tiling.py +685 -0
  99. accelforge/model/_looptree/reuse/isl/mapping_to_isl/types.py +188 -0
  100. accelforge/model/_looptree/reuse/isl/spatial.py +260 -0
  101. accelforge/model/_looptree/reuse/isl/temporal.py +182 -0
  102. accelforge/model/_looptree/reuse/symbolic/__init__.py +1 -0
  103. accelforge/model/_looptree/reuse/symbolic/symbolic copy 2.py +1346 -0
  104. accelforge/model/_looptree/reuse/symbolic/symbolic copy.py +1408 -0
  105. accelforge/model/_looptree/reuse/symbolic/symbolic.py +1396 -0
  106. accelforge/model/_looptree/run.py +122 -0
  107. accelforge/model/_looptree/types.py +26 -0
  108. accelforge/model/_looptree/visualization/__init__.py +0 -0
  109. accelforge/model/_looptree/visualization/occupancy.py +11 -0
  110. accelforge/model/main.py +222 -0
  111. accelforge/plotting/__init__.py +2 -0
  112. accelforge/plotting/mappings.py +219 -0
  113. accelforge/plotting/specs.py +57 -0
  114. accelforge/util/__init__.py +4 -0
  115. accelforge/util/_base_analysis_types.py +24 -0
  116. accelforge/util/_basetypes.py +1089 -0
  117. accelforge/util/_frozenset.py +36 -0
  118. accelforge/util/_isl.py +29 -0
  119. accelforge/util/_itertools.py +14 -0
  120. accelforge/util/_mathfuncs.py +57 -0
  121. accelforge/util/_parse_expressions.py +339 -0
  122. accelforge/util/_picklecache.py +32 -0
  123. accelforge/util/_setexpressions.py +268 -0
  124. accelforge/util/_sympy/__init__.py +0 -0
  125. accelforge/util/_sympy/broadcast_max.py +18 -0
  126. accelforge/util/_visualization.py +112 -0
  127. accelforge/util/_yaml.py +579 -0
  128. accelforge/util/parallel.py +193 -0
  129. accelforge-0.0.1.dist-info/METADATA +64 -0
  130. accelforge-0.0.1.dist-info/RECORD +258 -0
  131. accelforge-0.0.1.dist-info/WHEEL +5 -0
  132. accelforge-0.0.1.dist-info/licenses/LICENSE +19 -0
  133. accelforge-0.0.1.dist-info/top_level.txt +5 -0
  134. docs/_build/html/_sources/fastfusion.frontend.mapper.rst.txt +37 -0
  135. docs/_build/html/_sources/fastfusion.frontend.rst.txt +70 -0
  136. docs/_build/html/_sources/fastfusion.frontend.workload.rst.txt +21 -0
  137. docs/_build/html/_sources/fastfusion.mapper.FFM.rst.txt +37 -0
  138. docs/_build/html/_sources/fastfusion.mapper.rst.txt +18 -0
  139. docs/_build/html/_sources/fastfusion.rst.txt +20 -0
  140. docs/_build/html/_sources/fastfusion.util.rst.txt +21 -0
  141. docs/_build/html/_sources/index.rst.txt +87 -0
  142. docs/_build/html/_sources/modules.rst.txt +7 -0
  143. docs/_build/html/_sources/notes/citation.rst.txt +45 -0
  144. docs/_build/html/_sources/notes/definitions.rst.txt +43 -0
  145. docs/_build/html/_sources/notes/faqs.rst.txt +39 -0
  146. docs/_build/html/_sources/notes/modeling/accelerator_energy_latency.rst.txt +72 -0
  147. docs/_build/html/_sources/notes/modeling/component_energy_area.rst.txt +96 -0
  148. docs/_build/html/_sources/notes/modeling/mapping.rst.txt +100 -0
  149. docs/_build/html/_sources/notes/modeling.rst.txt +33 -0
  150. docs/_build/html/_sources/notes/parsing/arithmetic_parsing.rst.txt +136 -0
  151. docs/_build/html/_sources/notes/parsing/setexpressions.rst.txt +63 -0
  152. docs/_build/html/_sources/notes/parsing/yaml_parsing.rst.txt +176 -0
  153. docs/_build/html/_sources/notes/quickstart_and_installation.rst.txt +9 -0
  154. docs/_build/html/_sources/notes/spec/architecture.rst.txt +133 -0
  155. docs/_build/html/_sources/notes/spec/mapping.rst.txt +12 -0
  156. docs/_build/html/_sources/notes/spec/workload.rst.txt +83 -0
  157. docs/_build/html/_sources/notes/spec.rst.txt +36 -0
  158. docs/source/_ext/include_attrs.py +213 -0
  159. docs/source/_ext/include_docstring.py +364 -0
  160. docs/source/_ext/include_functions.py +154 -0
  161. docs/source/_ext/include_notebook.py +131 -0
  162. docs/source/_ext/include_yaml.py +119 -0
  163. docs/source/_ext/inherited_attributes.py +222 -0
  164. docs/source/_ext/paths.py +4 -0
  165. docs/source/conf.py +79 -0
  166. examples/arches/compute_in_memory/_include.yaml +74 -0
  167. examples/arches/compute_in_memory/_include_functions.py +229 -0
  168. examples/arches/compute_in_memory/_load_spec.py +57 -0
  169. examples/arches/compute_in_memory/components/c2c_multiplier.py +181 -0
  170. examples/arches/compute_in_memory/components/dac_c2c_r2r.py +605 -0
  171. examples/arches/compute_in_memory/components/misc.py +195 -0
  172. examples/arches/compute_in_memory/components/util/bit_functions.py +51 -0
  173. examples/arches/compute_in_memory/components/zero_comparator.py +92 -0
  174. examples/arches/compute_in_memory/isaac.yaml +233 -0
  175. examples/arches/compute_in_memory/memory_cells/ecram_demo.yaml +63 -0
  176. examples/arches/compute_in_memory/memory_cells/rram_example.yaml +63 -0
  177. examples/arches/compute_in_memory/memory_cells/rram_isaac_isca_2016.yaml +64 -0
  178. examples/arches/compute_in_memory/memory_cells/rram_neurosim_default.yaml +63 -0
  179. examples/arches/compute_in_memory/memory_cells/rram_raella_isca_2023.yaml +70 -0
  180. examples/arches/compute_in_memory/memory_cells/rram_wan_nature_2022.yaml +63 -0
  181. examples/arches/compute_in_memory/memory_cells/sram_colonnade_jssc_2021.yaml +63 -0
  182. examples/arches/compute_in_memory/memory_cells/sram_example.yaml +63 -0
  183. examples/arches/compute_in_memory/memory_cells/sram_jia_jssc_2020.yaml +63 -0
  184. examples/arches/compute_in_memory/memory_cells/sram_sinangil_jssc_2021.yaml +63 -0
  185. examples/arches/compute_in_memory/memory_cells/sram_wang_vlsi_2022.yaml +63 -0
  186. examples/arches/compute_in_memory/wang_vlsi_2022.yaml +289 -0
  187. examples/arches/eyeriss.yaml +68 -0
  188. examples/arches/fanout_variations/at_glb.yaml +31 -0
  189. examples/arches/fanout_variations/at_glb_with_fanout_node.yaml +34 -0
  190. examples/arches/fanout_variations/at_mac.yaml +31 -0
  191. examples/arches/fanout_variations/at_mac_with_constraints.yaml +38 -0
  192. examples/arches/fanout_variations/at_mac_with_fanout_node.yaml +34 -0
  193. examples/arches/nvdla.yaml +47 -0
  194. examples/arches/simple.yaml +28 -0
  195. examples/arches/tpu_v4i.yaml +67 -0
  196. examples/mappings/unfused_matmuls_to_simple.yaml +33 -0
  197. examples/misc/component_annotated.yaml +33 -0
  198. examples/workloads/gpt3_6.7B.yaml +124 -0
  199. examples/workloads/matmuls.yaml +20 -0
  200. examples/workloads/mobilenet_28.yaml +81 -0
  201. examples/workloads/mobilenet_various_separate.yaml +106 -0
  202. examples/workloads/three_matmuls_annotated.yaml +59 -0
  203. notebooks/.ipynb_checkpoints/fastfusion_arch_study_michael-checkpoint.ipynb +359 -0
  204. notebooks/compute_in_memory/_scripts.py +339 -0
  205. notebooks/compute_in_memory/isaac.guide.ipynb +270 -0
  206. notebooks/compute_in_memory/wang_vlsi_2022.ipynb +602 -0
  207. notebooks/paths.py +4 -0
  208. notebooks/tutorials/.ipynb_checkpoints/1_FFM-checkpoint.ipynb +3110 -0
  209. notebooks/tutorials/FFM.ipynb +3498 -0
  210. notebooks/tutorials/_include.py +48 -0
  211. notebooks/tutorials/component_energy_area.ipynb +363 -0
  212. tests/Q_mapping.yaml +38 -0
  213. tests/__init__.py +0 -0
  214. tests/conv.mapping.yaml +27 -0
  215. tests/conv.workload.yaml +13 -0
  216. tests/conv_sym.mapping.yaml +43 -0
  217. tests/copy.mapping.yaml +35 -0
  218. tests/copy.workload.yaml +15 -0
  219. tests/distribuffers/__init__.py +0 -0
  220. tests/distribuffers/multicast/test_cases.yaml +482 -0
  221. tests/distribuffers/spec/binding/valid_bindings.yaml +97 -0
  222. tests/distribuffers/spec/distributed.yaml +100 -0
  223. tests/distribuffers/spec/logical_arch.yaml +32 -0
  224. tests/distribuffers/spec/physical_arch.yaml +69 -0
  225. tests/distribuffers/test_binding.py +48 -0
  226. tests/frontend/__init__.py +0 -0
  227. tests/frontend/test_mapping_viz.py +52 -0
  228. tests/mapper/__init__.py +0 -0
  229. tests/mapper/configs/conv1d/conv1d.mapping.yaml +31 -0
  230. tests/mapper/configs/conv1d/conv1d.workload.yaml +11 -0
  231. tests/mapper/configs/two_conv1d/two_conv1d.expected.yaml +38 -0
  232. tests/mapper/configs/two_conv1d/two_conv1d.mapping.yaml +54 -0
  233. tests/mapper/configs/two_conv1d/two_conv1d.workload.yaml +19 -0
  234. tests/mapper/test_mapping_to_isl.py +90 -0
  235. tests/mapper/test_spatial_reuse_analysis.py +67 -0
  236. tests/mapper/test_temporal_reuse_analysis.py +56 -0
  237. tests/mapper/util.py +58 -0
  238. tests/matmul.mapping.yaml +29 -0
  239. tests/matmul.workload.yaml +12 -0
  240. tests/matmul_spatial.mapping.yaml +44 -0
  241. tests/mha.renames.yaml +65 -0
  242. tests/mha.workload.yaml +67 -0
  243. tests/mha.yaml +59 -0
  244. tests/mha_full.workload.yaml +67 -0
  245. tests/mobilenet.workload.yaml +35 -0
  246. tests/mobilenet_long.workload.yaml +64 -0
  247. tests/pmappingcache.py +24 -0
  248. tests/processing_stage.arch.yaml +40 -0
  249. tests/snowcat.arch.yaml +36 -0
  250. tests/test_ffm_join_pmappings.py +106 -0
  251. tests/test_ffm_make_pmappings.py +82 -0
  252. tests/test_ffm_make_tile_shapes.py +49 -0
  253. tests/test_mapper.py +100 -0
  254. tests/test_model.py +37 -0
  255. tests/test_plotting.py +72 -0
  256. tests/test_processing_stage.py +46 -0
  257. tests/test_symbolic_model.py +248 -0
  258. tests/test_workload.py +141 -0
accelforge/mapper/FFM/_pareto_df/pareto copy.py (new file)
@@ -0,0 +1,836 @@
+import functools
+from math import prod
+import time
+
+import numba
+import pandas as pd
+
+from paretoset import paretoset
+from joblib import delayed
+from sympy import factorint
+
+from accelforge._accelerated_imports import np
+from accelforge.util.parallel import parallel
+
+from accelforge.mapper.FFM._pareto_df.df_convention import (
+    col_used_in_pareto,
+    is_fused_loop_col,
+    is_n_iterations_col,
+    is_objective_col,
+)
+
+from paretoset.algorithms_numba import any_jitted
+
+
+def dominates(a: pd.Series, b: pd.Series) -> bool:
+    return all(a[i] <= b[i] for i in range(len(a)))
+
+
+def check_dominance(df: pd.DataFrame, n_optimal: int):
+    # mask = np.zeros(len(df), dtype=bool)
+    # mask[:new_point] = True
+    mask = np.zeros(len(df) - n_optimal, dtype=bool)
+    for col in df.columns:
+        compare = df.iloc[n_optimal - 1][col]
+        mask = mask | (df[col].iloc[n_optimal:] < compare)
+    return np.concatenate([np.ones(n_optimal, dtype=bool), mask])
+
+
+def quickpareto(df: pd.DataFrame) -> pd.DataFrame:
+    # Step 1: Sort by the column with the most unique values
+    # Step 2: Extract the first row. Add it to the pareto set
+    # Step 3: Remove all dominated points
+    # Step 4: Repeat until no more points to add
+
+    # Step 1: Sort by the column with the most unique values
+    original_len = len(df)
+    col_to_sort = max(df.columns, key=lambda c: df[c].nunique())
+    df = df.sort_values(by=col_to_sort).drop(columns=[col_to_sort])
+
+    new_point = 0
+    while new_point < len(df):
+        mask = check_dominance(df, new_point + 1)
+        df = df[mask]
+        new_point += 1
+
+    # Turn the index into a mask
+    mask = np.zeros(original_len, dtype=bool)
+    mask[df.index] = True
+    return mask
+
+
+def makepareto_quick2(mappings: pd.DataFrame, columns: list[str]) -> pd.DataFrame:
+    from fast_pareto import is_pareto_front
+
+    m2 = mappings[columns]
+    m2 = m2[is_pareto_front(m2.to_numpy())].drop_duplicates()
+    return mappings.loc[m2.index]
+
+
+def makepareto_quick(mappings: pd.DataFrame, columns: list[str]) -> pd.DataFrame:
+    return mappings[quickpareto(mappings[columns])]
+
+
+def paretofy_chunk(chunk, sense: list[str]):
+    return paretoset(chunk, sense=sense)
+
+
+def makepareto_merge(
+    mappings: pd.DataFrame,
+    columns: list[str],
+    parallelize: bool = False,
+    split_by_cols: list[str] = (),
+) -> pd.DataFrame:
+    chunk_size = 10000
+    if len(mappings) <= 1:
+        return mappings
+
+    sense = ["min"] * len(columns) + ["diff"] * len(split_by_cols)
+
+    to_chunk = mappings[columns + list(split_by_cols)]
+    chunks = parallel(
+        [
+            delayed(paretofy_chunk)(chunk, sense)
+            for chunk in [
+                to_chunk[i : i + chunk_size]
+                for i in range(0, len(to_chunk), chunk_size)
+            ]
+        ],
+        n_jobs=1 if parallelize else None,
+    )
+    mappings = mappings[np.concatenate(chunks)]
+    return mappings[paretoset(mappings[columns + list(split_by_cols)], sense=sense)]
+
+
+def makepareto_time_compare(mappings: pd.DataFrame, columns: list[str]) -> pd.DataFrame:
+    t0 = time.time()
+    pareto = makepareto_merge(mappings, columns)
+    t1 = time.time()
+    merge_time = t1 - t0
+    print(
+        f"Time to make pareto with merge: {t1 - t0: .2f}. Number of pareto points: {len(pareto)}"
+    )
+
+    t0 = time.time()
+    pareto2 = makepareto_quick2(mappings, columns)
+    t1 = time.time()
+    print(
+        f"Time to make pareto with quick: {t1 - t0: .2f}. Number of pareto points: {len(pareto2)}"
+    )
+    quick_time = t1 - t0
+
+    print(f"Quick is {quick_time / merge_time: .2f}x slower")
+
+    if len(pareto) != len(pareto2):
+        print(f"mismatch: {len(pareto)} != {len(pareto2)}")
+        makepareto_quick2(mappings)
+
+    return pareto2
+
+
+# 2d. Blockwise vectorized CuPy Pareto front with sorting by one objective (full check)
+# 2c. Fully vectorized CuPy brute-force Pareto front
+# (returns numpy mask for compatibility)
+def pareto_front_cupy_vectorized(X):
+    # if len(X) > 1000:
+    #     return X[paretoset(X.get(), sense=["min"] * X.shape[1])]
+
+    # Broadcast X_gpu to (n, n, m) for all-pairs comparison
+    A = X[:, None, :]  # shape (n, 1, m)
+    B = X[None, :, :]  # shape (1, n, m)
+    less_equal = (B <= A).all(axis=2)  # shape (n, n)
+    strictly_less = (B < A).any(axis=2)  # shape (n, n)
+    dominated = less_equal & strictly_less  # shape (n, n)
+    is_pareto = ~dominated.any(axis=1)
+    return is_pareto
+
+
+# 2d. Recursive blockwise merge CuPy Pareto front with sorting by one objective
+def pareto_front_cupy_blockwise_sorted_recursive(X, block_size=2000):
+    N = X.shape[0]
+    if N <= block_size:
+        # Base case: just compute Pareto front directly
+        mask = pareto_front_cupy_vectorized(X)
+        return mask
+    # Split into two halves
+    mid = N // 2
+    a, b = X[:mid], X[mid:]
+    mask_a = pareto_front_cupy_blockwise_sorted_recursive(a, block_size)
+    mask_b = pareto_front_cupy_blockwise_sorted_recursive(b, block_size)
+    # Get Pareto-optimal points from both halves
+    pareto_points_a = a[mask_a]
+    pareto_points_b = b[mask_b]
+    merged_points = np.vstack([pareto_points_a, pareto_points_b])
+    # Compute Pareto front of the merged set
+    merged_mask = pareto_front_cupy_vectorized(merged_points)
+    merged_indices = np.where(merged_mask)[0]
+    # Map merged_indices back to the original indices in X
+    # First, get the indices in X for the merged points
+    indices_a = np.where(mask_a)[0]
+    indices_b = np.where(mask_b)[0] + mid
+    all_indices = np.concatenate([indices_a, indices_b])
+    merged_indices_in_X = all_indices[merged_indices]
+    # Build the final mask for X
+    mask = np.zeros(N, dtype=bool)
+    mask[merged_indices_in_X] = True
+    return mask
+
+
+# def makepareto(
+#     mappings: pd.DataFrame,
+#     columns: list[str] = None,
+#     parallelize: bool = False,
+#     split_by_cols: list[str] = (),
+# ) -> pd.DataFrame:
+#     # return makepareto_time_compare(mappings)
+#     if columns is None:
+#         columns = [c for c in mappings.columns if col_used_in_pareto(c)]
+#     if _accelerated_imports.ACCELERATED:
+#         mask = pareto_front_cupy_blockwise_sorted_recursive(mappings[columns].to_cupy())
+#         return mappings[mask]
+
+
+TOLERANCE = 0.0
+
+
+def logify(x: pd.Series) -> pd.Series:
+    if 0 < TOLERANCE < 1:
+        pass
+    else:
+        assert (
+            TOLERANCE == 0
+        ), f"Tolerance must be between 0 and 1. Tolerance {TOLERANCE} is invalid."
+        return x
+
+    if x.min() <= 0:
+        return x
+
+    logged = np.log(x)
+
+    return np.round(logged / TOLERANCE) * TOLERANCE
+
+
+def makepareto(
+    mappings: pd.DataFrame,
+    columns: list[str] = None,
+    parallelize: bool = False,
+    split_by_cols: list[str] = (),
+) -> pd.DataFrame:
+    # return makepareto_time_compare(mappings)
+    if columns is None:
+        columns = [c for c in mappings.columns if col_used_in_pareto(c)]
+
+    # Number of iterations is derived from the tile shapes, so we don't need to use it,
+    # since any row with the same tile shapes will have the same number of iterations.
+    split_by_cols = list(split_by_cols) + [
+        c
+        for c in mappings.columns
+        if is_fused_loop_col(c) and not is_n_iterations_col(c)
+    ]
+
+    goals = []
+    to_pareto = []
+    pareto_cols = []
+    for c in mappings.columns:
+        if mappings[c].nunique() <= 1:
+            continue
+
+        if c in columns and is_objective_col(c):  # or col_used_in_pareto(c)):
+            to_pareto.append(logify(mappings[c]))
+            pareto_cols.append(c)
+            goals += ["min"]
+        elif c in split_by_cols:
+            to_pareto.append(mappings[c])
+            pareto_cols.append(c)
+            goals.append("diff")
+        elif c in columns:
+            to_pareto.append(mappings[c])
+            pareto_cols.append(c)
+            goals.append("min")
+
+    if not to_pareto:
+        return mappings.iloc[0:1]
+
+    return mappings[paretoset(pd.concat(to_pareto, axis=1), sense=goals)]
+
+    f = pd.concat(to_pareto, axis=1)
+    x = list(f.groupby([c for c, d in zip(pareto_cols, goals) if d == "diff"]))
+    print(x)
+
+
+@functools.lru_cache(maxsize=10000)
+def _factorint_cached(x: int):
+    return factorint(x)
+
+
+def prime_factor_counts(arr: np.ndarray) -> np.ndarray:
+    if isinstance(arr, tuple):
+        return tuple(prime_factor_counts(a) for a in arr)
+
+    arr = np.asarray(arr, dtype=int)
+    unique_vals = np.unique(arr)
+    factorizations = {x: _factorint_cached(x) for x in unique_vals}
+
+    # Gather all unique primes
+    all_primes = sorted({p for f in factorizations.values() for p in f})
+
+    # Build result matrix
+    result = np.zeros((len(arr), len(all_primes)), dtype=int)
+    prime_index = {p: j for j, p in enumerate(all_primes)}
+
+    for i, x in enumerate(arr):
+        for p, exp in factorizations[x].items():
+            result[i, prime_index[p]] = exp
+
+    return result
+
+
+def paretoset_grouped_dirty(df: pd.DataFrame, sense: list[str]):
+    # return paretoset(df, sense=sense)
+
+    assert all(i == c for i, c in enumerate(df.columns))
+    assert len(sense) == len(df.columns)
+
+    from paretoset.algorithms_numba import paretoset_jit
+    from paretoset.algorithms_numba import BNL
+
+    for c in df.columns:
+        if sense[c] == "max":
+            df[c] = -df[c]
+            sense[c] = "min"
+
+    GROUP_SIZE = 128
+
+    group_by = [c for c in df.columns if sense[c] == "diff"]
+    n_groups = prod(len(df[c].unique()) for c in group_by)
+
+    if len(df) / n_groups < GROUP_SIZE:
+        return paretoset(df, sense=sense)
+
+    c2unique = {c: len(df[c].unique()) for c in df.columns if c not in group_by}
+    while c2unique:
+        col, n = min(c2unique.items(), key=lambda x: x[1])
+        c2unique.pop(col)
+        n_groups *= n
+        if len(df) / n_groups < GROUP_SIZE:
+            break
+        group_by.append(col)
+
+    n_diffs = sum(x == "diff" for x in sense)
+    if len(group_by) < 2 or len(group_by) == n_diffs:
+        return paretoset(df, sense=sense)
+
+    def _row_from_group(mins, group):
+        per_col_mins = group.min(axis=0)
+        per_col_maxs = group.max(axis=0)
+        good_row = group.iloc[
+            np.argmin((group ** (1 / len(group.columns))).prod(axis=1))
+        ]
+        return [mins, per_col_mins, per_col_maxs, good_row, group]
+
+    groups = list(df.groupby(group_by))
+    groups_by_diff = {}
+    keepcols = [c for c in df.columns if c not in group_by]
+    for x, group in groups:
+        diffs, mins = x[:n_diffs], x[n_diffs:]
+        group = group[keepcols]
+        groups_by_diff.setdefault(diffs, []).append(_row_from_group(mins, group))
+
+    # print(f'Grouped into {len(groups)} groups using {len(group_by)} columns')
+    # orig_size = len(df)
+    # n_groups = len(groups)
+    # n_cols = len(keepcols)
+    # new_size = sum(len(g2) for g in groups_by_diff.values() for _, _, _, g2 in g)
+    # print(f'Grouped into {n_groups} groups, {orig_size} -> {new_size} rows, {n_cols} columns. Remaining {len(keepcols)} columns')
+
+    for groups in groups_by_diff.values():
+        for i, (
+            mins_a,
+            per_col_mins_a,
+            per_col_maxs_a,
+            good_row_a,
+            group_a,
+        ) in enumerate(groups):
+            if group_a is None:
+                continue
+
+            for j, (
+                mins_b,
+                per_col_mins_b,
+                per_col_maxs_b,
+                good_row_b,
+                group_b,
+            ) in enumerate(groups):
+                if group_b is None or i == j:
+                    continue
+
+                if all(a <= b for a, b in zip(good_row_a, per_col_mins_b)):
+                    groups[j][-1] = None
+                    continue
+
+                if all(a <= b for a, b in zip(good_row_a, good_row_b)):
+                    # The good row of a dominates the good row of b. It'll likely
+                    # dominate many b!
+                    group_b = group_b[(group_b < good_row_a).any(axis=1)]
+                    if len(group_b) == 0:
+                        groups[j][-1] = None
+                        continue
+                    groups[j].clear()
+                    groups[j].extend(_row_from_group(mins_b, group_b))
+
+                # # a can only dominate b if all of the min columns dominate
+                # if not all(a <= b for a, b in zip(mins_a, mins_b)):
+                #     continue
+
+                # # Check if any b beats all a. If so, continue.
+                # if any(a > b for a, b in zip(per_col_mins_a, per_col_maxs_b)):
+                #     continue
+
+                # # # Check if any a beats every b. If so, get rid of b.
+                # # a_doms = all(a <= b for a, b in zip(per_col_maxs_a, per_col_mins_b))
+                # # if a_doms:
+                # #     groups[j][-1] = None
+                # #     # print(f'Dropping dominated group {j}')
+                # #     continue
+
+                # row_a = group_a.iloc[np.random.randint(len(group_a))]
+                # if all(a <= b for a, b in zip(row_w_min_first_obj_b, per_col_mins_b)):
+                #     groups[j][-1] = None
+
+                # Everything below just ended up making things slower
+
+                # if any(a > b for a, b in zip(row_a, per_col_maxs_b)):
+                #     continue
+
+                # continue
+
+                # # Grab a random a. Get rid of all b that are dominated by it.
+                # a_lt_b_maxes = group_a.iloc[
+                #     np.where(np.all(group_a <= per_col_maxs_b, axis=1))[0]
+                # ]
+                # if len(a_lt_b_maxes) == 0:
+                #     continue
+
+                # row_a = a_lt_b_maxes.iloc[np.random.randint(len(a_lt_b_maxes))]
+
+                # b_idx = np.where(np.any(group_b < row_a, axis=1))[0]
+                # if len(b_idx) == 0:
+                #     groups[j][-1] = None
+                # else:
+                #     groups[j][-1] = group_b.iloc[b_idx]
+                #     groups[j][1] = group_b.iloc[b_idx].min(axis=0)
+                #     groups[j][2] = group_b.iloc[b_idx].max(axis=0)
+
+                # # Now we're in a case where a may dominate b. Update b.
+                # catted = pd.concat([group_a, group_b], axis=0)
+                # mask = np.concatenate([
+                #     np.zeros(len(group_a), dtype=bool),
+                #     np.ones(len(group_b), dtype=bool)
+                # ])
+                # catted = catted[paretoset_jit(catted.to_numpy()) & mask]
+                # groups[j][1] = catted.min(axis=0)
+                # groups[j][2] = catted.max(axis=0)
+                # groups[j][3] = catted
+
+    result = np.zeros(len(df), dtype=bool)
+    for group in groups_by_diff.values():
+        for _, _, _, _, group in group:
+            if group is not None:
+                result[group[paretoset_jit(group.to_numpy())].index] = True
+
+    return result
+
+
+def makepareto_numpy(
+    mappings: np.ndarray,
+    goals: list[str],
+    dirty: bool = False,
+) -> pd.DataFrame:
+
+    to_pareto = []
+    new_goals = []
+    assert len(goals) == mappings.shape[1]
+    for c in range(mappings.shape[1]):
+        if len(np.unique(mappings[:, c])) <= 1:
+            continue
+
+        goal = goals[c]
+        # if goal != "diff" and dirty and len(np.unique(mappings[:, c])) < np.log2(mappings.shape[0]):
+        #     # print(f"Changed {goal} to diff because there are {len(np.unique(mappings[:, c]))} unique values for {mappings.shape[0]} rows")
+        #     goal = "diff"
+
+        if goal in ["min", "max"]:
+            l = logify(mappings[:, c].reshape((-1, 1)))
+            to_pareto.append(l if goal == "min" else -l)
+            new_goals.append("min")
+        elif goal == "diff":
+            to_pareto.append(mappings[:, c].reshape((-1, 1)))
+            new_goals.append("diff")
+        elif goal == "min_per_prime_factor":
+            if not dirty:
+                # Paretoset tends to be faster with these as diffs. Tanner tried for a
+                # long time to get min_per_prime_factor to be faster, but it
+                # didn't work. What it would do is say that if one choice for an inner
+                # loop has used up fewer of every prime factor than another choice, then
+                # the latter would give a superset of options for outer loops.
+                # Intuitively, we could enable more pruning by doing this instead of
+                # "diff", which is overconservative. Likewise, we could do "min" for
+                # imperfect instead of "diff". However, this ultimately made things
+                # slower because it didn't get much Pareto pruning, but caused many more
+                # Pareto comparisons ("diff" partitioning into N partitions --> N^2
+                # improvement). I hypothesize that the reason that it doesn't improve
+                # pruning much is that when we've enumerated a loop but not the loop
+                # above it, the given loop is almost always trading off tile shape for
+                # accesses, leading to no point being dominated by another point.
+                to_pareto.append(mappings[:, c].reshape((-1, 1)))
+                new_goals.append("diff")
+            else:
+                counts = prime_factor_counts(mappings[:, c])
+                for i in range(counts.shape[1]):
+                    to_pareto.append(counts[:, i].reshape((-1, 1)))
+                    new_goals.append("min")
+        elif goal == "max_per_prime_factor":
+            if not dirty:
+                # See above big comment.
+                to_pareto.append(mappings[:, c].reshape((-1, 1)))
+                new_goals.append("diff")
+            else:
+                counts = prime_factor_counts(mappings[:, c])
+                for i in range(counts.shape[1]):
+                    to_pareto.append(counts[:, i].reshape((-1, 1)))
+                    new_goals.append("max")
+        else:
+            raise ValueError(f"Unknown goal: {goal}")
+
+    if not to_pareto:
+        return mappings[:1]
+
+    df = pd.DataFrame(np.concatenate(to_pareto, axis=1), columns=range(len(to_pareto)))
+
+    if dirty:
+        return paretoset_grouped_dirty(df, sense=new_goals)
+    return paretoset(df, sense=new_goals)
+
+
+@numba.jit(nopython=True)
+def paretoset_attack_defend_jit(costs_attack, costs_defend, costs_shared):
+    """
+    Find the pareto-efficient points
+    :param costs: An (n_points, n_costs) array
+    :param return_mask: True to return a mask
+    :return: An array of indices of pareto-efficient points.
+        If return_mask is True, this will be an (n_points, ) boolean array
+        Otherwise it will be a (n_efficient_points, ) integer array of indices.
+    """
+    # https://stackoverflow.com/questions/32791911/fast-calculation-of-pareto-front-in-python
+
+    is_efficient = np.arange(costs_attack.shape[0])
+    n_points = costs_attack.shape[0]
+
+    next_point_index = 0  # Next index in the is_efficient array to search for
+    while next_point_index < len(costs_attack):
+        this_cost_attack = costs_attack[next_point_index]
+        this_cost_shared = costs_shared[next_point_index]
+
+        # Keep any point with a lower cost
+        current_efficient_points = any_jitted(costs_defend, this_cost_attack)
+        current_efficient_points |= any_jitted(costs_shared, this_cost_shared)
+
+        # np.any(costs < costs[next_point_index], axis=1)
+        current_efficient_points[next_point_index] = True  # And keep self
+
+        # Remove dominated points
+        is_efficient = is_efficient[current_efficient_points]
+        costs_attack = costs_attack[current_efficient_points]
+        costs_defend = costs_defend[current_efficient_points]
+        costs_shared = costs_shared[current_efficient_points]
+
+        # Re-adjust the index
+        next_point_index = np.sum(current_efficient_points[:next_point_index]) + 1
+
+    is_efficient_mask = np.zeros(shape=n_points, dtype=np.bool_)
+    is_efficient_mask[is_efficient] = True
+    return is_efficient_mask
+
+
+class Group:
+    def __init__(
+        self,
+        mins: np.ndarray,
+        group_shared: pd.DataFrame,
+        group_attack: pd.DataFrame,
+        group_defend: pd.DataFrame,
+    ):
+        self.mins = mins
+        self.group_shared = group_shared
+        self.group_attack = group_attack
+        self.group_defend = group_defend
+
+        scaleby = 1 / (
+            len(group_attack.columns) + len(group_shared.columns)
+        )  # Prevent overflow
+        row_attack_scores = (group_attack**scaleby).prod(axis=1)
+        row_shared_scores = (group_shared**scaleby).prod(axis=1)
+        good_row = np.argmin(row_attack_scores * row_shared_scores)
+
+        self.good_row_attack = group_attack.iloc[good_row]
+        self.good_row_shared = group_shared.iloc[good_row]
+
+        assert (
+            len(self.group_shared) == len(self.group_attack) == len(self.group_defend)
+        )
+
+    def __bool__(self):
+        return len(self.group_attack) > 0
+
+    def attack_with(self, other: "Group"):
+        if all(o <= s for o, s in zip(other.mins, self.mins)):
+            mask_defend = np.array(
+                (self.group_defend < other.good_row_attack).any(axis=1), dtype=bool
+            )
+            mask_shared = np.array(
+                (self.group_shared < other.good_row_shared).any(axis=1), dtype=bool
+            )
+            mask = mask_defend | mask_shared
+            self.group_attack = self.group_attack[mask]
+            self.group_defend = self.group_defend[mask]
+            self.group_shared = self.group_shared[mask]
+
+    def paretofy(self):
+        mask = paretoset_attack_defend_jit(
+            self.group_attack.to_numpy(),
+            self.group_defend.to_numpy(),
+            self.group_shared.to_numpy(),
+        )
+        self.group_attack = self.group_attack[mask]
+        self.group_defend = self.group_defend[mask]
+        self.group_shared = self.group_shared[mask]
+
+    def get_pareto_index(self):
+        mask = paretoset_attack_defend_jit(
+            self.group_attack.to_numpy(),
+            self.group_defend.to_numpy(),
+            self.group_shared.to_numpy(),
+        )
+        return self.group_shared.index[mask]
+
+
+def paretoset_attack_defend_grouped_dirty(
+    attack: pd.DataFrame,
+    defend: pd.DataFrame,
+    shared: pd.DataFrame,
+    sense_shared: list[str],
+    sense_attack_defend: list[str],
+):
+    GROUP_SIZE = 128
+    assert all(i == c for i, c in enumerate(attack.columns))
+    assert all(i == c for i, c in enumerate(defend.columns))
+    assert all(i == c for i, c in enumerate(shared.columns))
+    assert len(sense_attack_defend) == len(attack.columns)
+    assert len(sense_attack_defend) == len(defend.columns)
+    assert len(sense_shared) == len(shared.columns)
+
+    assert all(x in ["min"] for x in sense_attack_defend)
+    assert all(x in ["min", "diff"] for x in sense_shared)
+
+    group_by = [c for c in shared.columns if sense_shared[c] == "diff"]
+    n_groups = prod(len(shared[c].unique()) for c in group_by)
+    c2unique = {c: len(shared[c].unique()) for c in shared.columns if c not in group_by}
+    while c2unique:
+        col, n = min(c2unique.items(), key=lambda x: x[1])
+        c2unique.pop(col)
+        n_groups *= n
+        if len(shared) / n_groups < GROUP_SIZE:
+            break
+        group_by.append(col)
+    n_diffs = sum(x == "diff" for x in sense_shared)
+
+    groups_shared = list(shared.groupby(group_by)) if group_by else [([], shared)]
+
+    groups_by_diff = {}
+    keepcols = [c for c in shared.columns if c not in group_by]
+    for x_shared, group_shared in groups_shared:
+        diffs, mins = x_shared[:n_diffs], x_shared[n_diffs:]
+        group_attack = attack.iloc[group_shared.index]
+        group_defend = defend.iloc[group_shared.index]
+        group_obj = Group(
+            mins,
+            group_shared,
+            group_attack,
+            group_defend,
+        )
+        groups_by_diff.setdefault(tuple(diffs), []).append(group_obj)
+
+    # print(f'Grouped into {len(groups)} groups using {len(group_by)} columns')
+    # orig_size = len(df)
+    # n_groups = len(groups)
+    # n_cols = len(keepcols)
+    # new_size = sum(len(g2) for g in groups_by_diff.values() for _, _, _, g2 in g)
+    # print(f'Grouped into {n_groups} groups, {orig_size} -> {new_size} rows, {n_cols} columns. Remaining {len(keepcols)} columns')
+
+    for groups in groups_by_diff.values():
+        for i, group_a in enumerate(groups):
+            if not group_a:
+                continue
+
+            for j, group_b in enumerate(groups):
+                if not group_b or i == j:
+                    continue
+
+                group_a.attack_with(group_b)
+
+                # # a can only dominate b if all of the min columns dominate
+                # if not all(a <= b for a, b in zip(mins_a, mins_b)):
+                #     continue
+
+                # # Check if any b beats all a. If so, continue.
+                # if any(a > b for a, b in zip(per_col_mins_a, per_col_maxs_b)):
+                #     continue
+
+                # # # Check if any a beats every b. If so, get rid of b.
+                # # a_doms = all(a <= b for a, b in zip(per_col_maxs_a, per_col_mins_b))
+                # # if a_doms:
+                # #     groups[j][-1] = None
+                # #     # print(f'Dropping dominated group {j}')
+                # #     continue
+
+                # row_a = group_a.iloc[np.random.randint(len(group_a))]
+                # if all(a <= b for a, b in zip(row_w_min_first_obj_b, per_col_mins_b)):
+                #     groups[j][-1] = None
+
+                # Everything below just ended up making things slower
+
+                # if any(a > b for a, b in zip(row_a, per_col_maxs_b)):
+                #     continue
+
+                # continue
+
+                # # Grab a random a. Get rid of all b that are dominated by it.
+                # a_lt_b_maxes = group_a.iloc[
+                #     np.where(np.all(group_a <= per_col_maxs_b, axis=1))[0]
+                # ]
+                # if len(a_lt_b_maxes) == 0:
+                #     continue
+
+                # row_a = a_lt_b_maxes.iloc[np.random.randint(len(a_lt_b_maxes))]
+
+                # b_idx = np.where(np.any(group_b < row_a, axis=1))[0]
+                # if len(b_idx) == 0:
+                #     groups[j][-1] = None
+                # else:
+                #     groups[j][-1] = group_b.iloc[b_idx]
+                #     groups[j][1] = group_b.iloc[b_idx].min(axis=0)
+                #     groups[j][2] = group_b.iloc[b_idx].max(axis=0)
+
+                # # Now we're in a case where a may dominate b. Update b.
+                # catted = pd.concat([group_a, group_b], axis=0)
+                # mask = np.concatenate([
+                #     np.zeros(len(group_a), dtype=bool),
+                #     np.ones(len(group_b), dtype=bool)
+                # ])
+                # catted = catted[paretoset_jit(catted.to_numpy()) & mask]
+                # groups[j][1] = catted.min(axis=0)
+                # groups[j][2] = catted.max(axis=0)
+                # groups[j][3] = catted
+
+    result = np.zeros(len(attack), dtype=bool)
+    total, kept = 0, 0
+    for groups in groups_by_diff.values():
+        for group in groups:
+            if group:
+                idx = group.get_pareto_index()
+                total += len(group.group_shared)
+                kept += len(idx)
+                result[idx] = True
+    return result
+
+
+def makepareto_attack_defend_dirty(
+    objectives: list[tuple[np.ndarray, np.ndarray] | np.ndarray],
+    goals: list[str],
+) -> np.ndarray:
+    attack = []
+    defend = []
+    shared = []
+    sense_attack_defend = []
+    sense_shared = []
+    for objective, goal in zip(objectives, goals):
+        if isinstance(objective, tuple):
+            if goal == "min":
+                attack.append(objective[0])
+                defend.append(objective[1])
+                sense_attack_defend.append("min")
+            elif goal == "max":
+                attack.append(-objective[0])
+                defend.append(-objective[1])
+                sense_attack_defend.append("min")
+            elif goal in ["diff", "min_per_prime_factor", "max_per_prime_factor"]:
+                attack.append(objective[0])
+                defend.append(objective[1])
+                sense_attack_defend.append("diff")
+            elif goal == "min_per_prime_factor":
+                counts = prime_factor_counts(objective)
+                for i in range(counts[0].shape[1]):
+                    attack.append(counts[0][:, i].reshape((-1, 1)))
+                    defend.append(counts[1][:, i].reshape((-1, 1)))
+                    sense_attack_defend.append("min")
+                    sense_attack_defend.append("min")
+            elif goal == "max_per_prime_factor":
+                counts = prime_factor_counts(objective)
+                for i in range(counts[0].shape[1]):
+                    attack.append(-counts[0][:, i].reshape((-1, 1)))
+                    defend.append(-counts[1][:, i].reshape((-1, 1)))
+                    sense_attack_defend.append("min")
+            else:
+                raise ValueError(f"Unknown goal: {goal}")
+
+        if isinstance(objective, np.ndarray):
+            if goal == "min":
+                shared.append(objective)
+                sense_shared.append("min")
+            elif goal == "max":
+                shared.append(-objective)
+                sense_shared.append("min")
+            elif goal == "diff":
+                shared.append(objective)
+                sense_shared.append("diff")
+            elif goal == "min_per_prime_factor":
+                counts = prime_factor_counts(objective)
+                for i in range(counts.shape[1]):
+                    shared.append(counts[:, i].reshape((-1, 1)))
+                    sense_shared.append("min")
+            elif goal == "max_per_prime_factor":
+                counts = prime_factor_counts(objective)
+                for i in range(counts.shape[1]):
+                    shared.append(-counts[:, i].reshape((-1, 1)))
+                    sense_shared.append("min")
+            else:
+                raise ValueError(f"Unknown goal: {goal}")
+
+    index_size = max(
+        x.size if isinstance(x, np.ndarray) else max(len(y) for y in x)
+        for x in attack + defend + shared
+    )
+
+    def stack(x: list[np.ndarray]) -> pd.DataFrame:
+        if not x:
+            return pd.DataFrame(columns=[], index=range(index_size))
+        x = [y.reshape(-1, 1) for y in x]
+        return pd.DataFrame(np.concatenate(x, axis=1), columns=range(len(x)))
+
+    if (
+        not attack
+        and not defend
+        and not any(
+            x in sense_shared for x in ["min_per_prime_factor", "max_per_prime_factor"]
+        )
+    ):
+        return paretoset(stack(shared), sense=sense_shared)
+
+    return paretoset_attack_defend_grouped_dirty(
+        shared=stack(shared),
+        attack=stack(attack),
+        defend=stack(defend),
+        sense_shared=sense_shared,
+        sense_attack_defend=sense_attack_defend,
+    )
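
For orientation, below is a minimal usage sketch of the makepareto entry point added in this file. It is not part of the package's documented API: the import path is inferred from the file listing above, the column names are invented for illustration, and it is assumed that the df_convention helpers (is_objective_col, is_fused_loop_col) do not treat these names specially, so both columns are handled as plain "min" objectives.

    import pandas as pd
    from accelforge.mapper.FFM._pareto_df.pareto import makepareto  # path assumed from the listing

    # Four candidate mappings with two objectives, both minimized.
    mappings = pd.DataFrame({
        "energy":  [1.0, 2.0, 2.5, 3.0],
        "latency": [4.0, 1.0, 3.0, 0.5],
    })

    # Passing columns explicitly selects the objectives to Pareto-filter on.
    pareto = makepareto(mappings, columns=["energy", "latency"])
    print(pareto)  # the (2.5, 3.0) row is dominated by (2.0, 1.0) and is dropped

With the default TOLERANCE of 0.0, logify leaves the objective values unchanged, so this reduces to a plain Pareto filter over the selected columns via the paretoset package.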