accelforge 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- accelforge/__init__.py +21 -0
- accelforge/_accelerated_imports.py +16 -0
- accelforge/_deprecate/_simanneal/evalmapping.py +271 -0
- accelforge/_deprecate/_simanneal/mapspaceglobals.py +298 -0
- accelforge/_deprecate/_simanneal/simanneal.py +666 -0
- accelforge/_deprecate/_simanneal/tracking.py +105 -0
- accelforge/_deprecate/_simanneal/wrappers.py +218 -0
- accelforge/_deprecate/_simanneal2/__init__.py +7 -0
- accelforge/_deprecate/_simanneal2/simanneal.py +493 -0
- accelforge/_deprecate/_simanneal2/tracking.py +116 -0
- accelforge/_deprecate/compatibility_util.py +181 -0
- accelforge/_deprecate/layerdeduplication/__init__.py +2 -0
- accelforge/_deprecate/layerdeduplication/group_similar_einsums.py +160 -0
- accelforge/_deprecate/layerdeduplication/grouped_einsums.py +84 -0
- accelforge/_deprecate/mapping_filter_tags/__init__.py +2 -0
- accelforge/_deprecate/mapping_filter_tags/ffmt.py +212 -0
- accelforge/_deprecate/mapping_filter_tags/onesplit.py +24 -0
- accelforge/_deprecate/mapping_filter_tags/util.py +24 -0
- accelforge/_deprecate/tags.py +69 -0
- accelforge/_deprecate/viz/__init__.py +0 -0
- accelforge/_deprecate/viz/interactive.py +159 -0
- accelforge/_deprecate/viz/reservationtree.py +307 -0
- accelforge/_deprecate/viz/ski_slope.py +88 -0
- accelforge/_version.py +15 -0
- accelforge/examples.py +39 -0
- accelforge/frontend/__init__.py +10 -0
- accelforge/frontend/_binding.py +129 -0
- accelforge/frontend/_workload_isl/__init__.py +2 -0
- accelforge/frontend/_workload_isl/_isl.py +149 -0
- accelforge/frontend/_workload_isl/_symbolic.py +141 -0
- accelforge/frontend/arch copy.py +1544 -0
- accelforge/frontend/arch.py +1642 -0
- accelforge/frontend/config.py +63 -0
- accelforge/frontend/mapper/__init__.py +5 -0
- accelforge/frontend/mapper/ffm.py +126 -0
- accelforge/frontend/mapper/mapper.py +7 -0
- accelforge/frontend/mapper/metrics.py +30 -0
- accelforge/frontend/mapping/__init__.py +1 -0
- accelforge/frontend/mapping/mapping.py +1736 -0
- accelforge/frontend/model.py +14 -0
- accelforge/frontend/renames.py +150 -0
- accelforge/frontend/spec copy.py +230 -0
- accelforge/frontend/spec.py +301 -0
- accelforge/frontend/variables.py +12 -0
- accelforge/frontend/workload.py +952 -0
- accelforge/mapper/FFM/__init__.py +9 -0
- accelforge/mapper/FFM/_join_pmappings/__init__.py +0 -0
- accelforge/mapper/FFM/_join_pmappings/compatibility.py +653 -0
- accelforge/mapper/FFM/_join_pmappings/compress_pmappings.py +140 -0
- accelforge/mapper/FFM/_join_pmappings/join_pmappings.py +703 -0
- accelforge/mapper/FFM/_join_pmappings/pmapping_dataframe.py +901 -0
- accelforge/mapper/FFM/_join_pmappings/pmapping_group.py +337 -0
- accelforge/mapper/FFM/_make_pmappings/contraints/__init__.py +0 -0
- accelforge/mapper/FFM/_make_pmappings/contraints/constraints.py +360 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/__init__.py +1 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_loops.py +373 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_pmapping_templates.py +463 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_reservations.py +95 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_storage_order.py +382 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_storages.py +155 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmappings.py +411 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/__init__.py +1 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/make_pmappings_from_templates.py +407 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/make_tile_shapes.py +1681 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/run_model.py +170 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/symbol_relations.py +174 -0
- accelforge/mapper/FFM/_make_pmappings/pmapper_job.py +282 -0
- accelforge/mapper/FFM/_pareto_df/df_convention.py +273 -0
- accelforge/mapper/FFM/_pareto_df/pareto copy.py +836 -0
- accelforge/mapper/FFM/_pareto_df/pareto.py +508 -0
- accelforge/mapper/FFM/data.py +61 -0
- accelforge/mapper/FFM/main copy.py +236 -0
- accelforge/mapper/FFM/main.py +208 -0
- accelforge/mapper/FFM/mappings.py +510 -0
- accelforge/mapper/FFM/pmappings.py +310 -0
- accelforge/mapper/__init__.py +4 -0
- accelforge/mapper.py +0 -0
- accelforge/model/__init__.py +1 -0
- accelforge/model/_looptree/__init__.py +0 -0
- accelforge/model/_looptree/accesses.py +335 -0
- accelforge/model/_looptree/capacity/__init__.py +1 -0
- accelforge/model/_looptree/capacity/aggregators.py +36 -0
- accelforge/model/_looptree/capacity/capacity.py +47 -0
- accelforge/model/_looptree/energy.py +150 -0
- accelforge/model/_looptree/equivalent_ranks.py +29 -0
- accelforge/model/_looptree/latency/__init__.py +1 -0
- accelforge/model/_looptree/latency/latency.py +98 -0
- accelforge/model/_looptree/latency/memory.py +120 -0
- accelforge/model/_looptree/latency/processors.py +92 -0
- accelforge/model/_looptree/mapping_utilities.py +71 -0
- accelforge/model/_looptree/reuse/__init__.py +4 -0
- accelforge/model/_looptree/reuse/isl/__init__.py +1 -0
- accelforge/model/_looptree/reuse/isl/des.py +59 -0
- accelforge/model/_looptree/reuse/isl/isl_functions.py +374 -0
- accelforge/model/_looptree/reuse/isl/mapping_to_isl/__init__.py +4 -0
- accelforge/model/_looptree/reuse/isl/mapping_to_isl/analyze_mapping.py +297 -0
- accelforge/model/_looptree/reuse/isl/mapping_to_isl/skews_from_mapping.py +236 -0
- accelforge/model/_looptree/reuse/isl/mapping_to_isl/tiling.py +685 -0
- accelforge/model/_looptree/reuse/isl/mapping_to_isl/types.py +188 -0
- accelforge/model/_looptree/reuse/isl/spatial.py +260 -0
- accelforge/model/_looptree/reuse/isl/temporal.py +182 -0
- accelforge/model/_looptree/reuse/symbolic/__init__.py +1 -0
- accelforge/model/_looptree/reuse/symbolic/symbolic copy 2.py +1346 -0
- accelforge/model/_looptree/reuse/symbolic/symbolic copy.py +1408 -0
- accelforge/model/_looptree/reuse/symbolic/symbolic.py +1396 -0
- accelforge/model/_looptree/run.py +122 -0
- accelforge/model/_looptree/types.py +26 -0
- accelforge/model/_looptree/visualization/__init__.py +0 -0
- accelforge/model/_looptree/visualization/occupancy.py +11 -0
- accelforge/model/main.py +222 -0
- accelforge/plotting/__init__.py +2 -0
- accelforge/plotting/mappings.py +219 -0
- accelforge/plotting/specs.py +57 -0
- accelforge/util/__init__.py +4 -0
- accelforge/util/_base_analysis_types.py +24 -0
- accelforge/util/_basetypes.py +1089 -0
- accelforge/util/_frozenset.py +36 -0
- accelforge/util/_isl.py +29 -0
- accelforge/util/_itertools.py +14 -0
- accelforge/util/_mathfuncs.py +57 -0
- accelforge/util/_parse_expressions.py +339 -0
- accelforge/util/_picklecache.py +32 -0
- accelforge/util/_setexpressions.py +268 -0
- accelforge/util/_sympy/__init__.py +0 -0
- accelforge/util/_sympy/broadcast_max.py +18 -0
- accelforge/util/_visualization.py +112 -0
- accelforge/util/_yaml.py +579 -0
- accelforge/util/parallel.py +193 -0
- accelforge-0.0.1.dist-info/METADATA +64 -0
- accelforge-0.0.1.dist-info/RECORD +258 -0
- accelforge-0.0.1.dist-info/WHEEL +5 -0
- accelforge-0.0.1.dist-info/licenses/LICENSE +19 -0
- accelforge-0.0.1.dist-info/top_level.txt +5 -0
- docs/_build/html/_sources/fastfusion.frontend.mapper.rst.txt +37 -0
- docs/_build/html/_sources/fastfusion.frontend.rst.txt +70 -0
- docs/_build/html/_sources/fastfusion.frontend.workload.rst.txt +21 -0
- docs/_build/html/_sources/fastfusion.mapper.FFM.rst.txt +37 -0
- docs/_build/html/_sources/fastfusion.mapper.rst.txt +18 -0
- docs/_build/html/_sources/fastfusion.rst.txt +20 -0
- docs/_build/html/_sources/fastfusion.util.rst.txt +21 -0
- docs/_build/html/_sources/index.rst.txt +87 -0
- docs/_build/html/_sources/modules.rst.txt +7 -0
- docs/_build/html/_sources/notes/citation.rst.txt +45 -0
- docs/_build/html/_sources/notes/definitions.rst.txt +43 -0
- docs/_build/html/_sources/notes/faqs.rst.txt +39 -0
- docs/_build/html/_sources/notes/modeling/accelerator_energy_latency.rst.txt +72 -0
- docs/_build/html/_sources/notes/modeling/component_energy_area.rst.txt +96 -0
- docs/_build/html/_sources/notes/modeling/mapping.rst.txt +100 -0
- docs/_build/html/_sources/notes/modeling.rst.txt +33 -0
- docs/_build/html/_sources/notes/parsing/arithmetic_parsing.rst.txt +136 -0
- docs/_build/html/_sources/notes/parsing/setexpressions.rst.txt +63 -0
- docs/_build/html/_sources/notes/parsing/yaml_parsing.rst.txt +176 -0
- docs/_build/html/_sources/notes/quickstart_and_installation.rst.txt +9 -0
- docs/_build/html/_sources/notes/spec/architecture.rst.txt +133 -0
- docs/_build/html/_sources/notes/spec/mapping.rst.txt +12 -0
- docs/_build/html/_sources/notes/spec/workload.rst.txt +83 -0
- docs/_build/html/_sources/notes/spec.rst.txt +36 -0
- docs/source/_ext/include_attrs.py +213 -0
- docs/source/_ext/include_docstring.py +364 -0
- docs/source/_ext/include_functions.py +154 -0
- docs/source/_ext/include_notebook.py +131 -0
- docs/source/_ext/include_yaml.py +119 -0
- docs/source/_ext/inherited_attributes.py +222 -0
- docs/source/_ext/paths.py +4 -0
- docs/source/conf.py +79 -0
- examples/arches/compute_in_memory/_include.yaml +74 -0
- examples/arches/compute_in_memory/_include_functions.py +229 -0
- examples/arches/compute_in_memory/_load_spec.py +57 -0
- examples/arches/compute_in_memory/components/c2c_multiplier.py +181 -0
- examples/arches/compute_in_memory/components/dac_c2c_r2r.py +605 -0
- examples/arches/compute_in_memory/components/misc.py +195 -0
- examples/arches/compute_in_memory/components/util/bit_functions.py +51 -0
- examples/arches/compute_in_memory/components/zero_comparator.py +92 -0
- examples/arches/compute_in_memory/isaac.yaml +233 -0
- examples/arches/compute_in_memory/memory_cells/ecram_demo.yaml +63 -0
- examples/arches/compute_in_memory/memory_cells/rram_example.yaml +63 -0
- examples/arches/compute_in_memory/memory_cells/rram_isaac_isca_2016.yaml +64 -0
- examples/arches/compute_in_memory/memory_cells/rram_neurosim_default.yaml +63 -0
- examples/arches/compute_in_memory/memory_cells/rram_raella_isca_2023.yaml +70 -0
- examples/arches/compute_in_memory/memory_cells/rram_wan_nature_2022.yaml +63 -0
- examples/arches/compute_in_memory/memory_cells/sram_colonnade_jssc_2021.yaml +63 -0
- examples/arches/compute_in_memory/memory_cells/sram_example.yaml +63 -0
- examples/arches/compute_in_memory/memory_cells/sram_jia_jssc_2020.yaml +63 -0
- examples/arches/compute_in_memory/memory_cells/sram_sinangil_jssc_2021.yaml +63 -0
- examples/arches/compute_in_memory/memory_cells/sram_wang_vlsi_2022.yaml +63 -0
- examples/arches/compute_in_memory/wang_vlsi_2022.yaml +289 -0
- examples/arches/eyeriss.yaml +68 -0
- examples/arches/fanout_variations/at_glb.yaml +31 -0
- examples/arches/fanout_variations/at_glb_with_fanout_node.yaml +34 -0
- examples/arches/fanout_variations/at_mac.yaml +31 -0
- examples/arches/fanout_variations/at_mac_with_constraints.yaml +38 -0
- examples/arches/fanout_variations/at_mac_with_fanout_node.yaml +34 -0
- examples/arches/nvdla.yaml +47 -0
- examples/arches/simple.yaml +28 -0
- examples/arches/tpu_v4i.yaml +67 -0
- examples/mappings/unfused_matmuls_to_simple.yaml +33 -0
- examples/misc/component_annotated.yaml +33 -0
- examples/workloads/gpt3_6.7B.yaml +124 -0
- examples/workloads/matmuls.yaml +20 -0
- examples/workloads/mobilenet_28.yaml +81 -0
- examples/workloads/mobilenet_various_separate.yaml +106 -0
- examples/workloads/three_matmuls_annotated.yaml +59 -0
- notebooks/.ipynb_checkpoints/fastfusion_arch_study_michael-checkpoint.ipynb +359 -0
- notebooks/compute_in_memory/_scripts.py +339 -0
- notebooks/compute_in_memory/isaac.guide.ipynb +270 -0
- notebooks/compute_in_memory/wang_vlsi_2022.ipynb +602 -0
- notebooks/paths.py +4 -0
- notebooks/tutorials/.ipynb_checkpoints/1_FFM-checkpoint.ipynb +3110 -0
- notebooks/tutorials/FFM.ipynb +3498 -0
- notebooks/tutorials/_include.py +48 -0
- notebooks/tutorials/component_energy_area.ipynb +363 -0
- tests/Q_mapping.yaml +38 -0
- tests/__init__.py +0 -0
- tests/conv.mapping.yaml +27 -0
- tests/conv.workload.yaml +13 -0
- tests/conv_sym.mapping.yaml +43 -0
- tests/copy.mapping.yaml +35 -0
- tests/copy.workload.yaml +15 -0
- tests/distribuffers/__init__.py +0 -0
- tests/distribuffers/multicast/test_cases.yaml +482 -0
- tests/distribuffers/spec/binding/valid_bindings.yaml +97 -0
- tests/distribuffers/spec/distributed.yaml +100 -0
- tests/distribuffers/spec/logical_arch.yaml +32 -0
- tests/distribuffers/spec/physical_arch.yaml +69 -0
- tests/distribuffers/test_binding.py +48 -0
- tests/frontend/__init__.py +0 -0
- tests/frontend/test_mapping_viz.py +52 -0
- tests/mapper/__init__.py +0 -0
- tests/mapper/configs/conv1d/conv1d.mapping.yaml +31 -0
- tests/mapper/configs/conv1d/conv1d.workload.yaml +11 -0
- tests/mapper/configs/two_conv1d/two_conv1d.expected.yaml +38 -0
- tests/mapper/configs/two_conv1d/two_conv1d.mapping.yaml +54 -0
- tests/mapper/configs/two_conv1d/two_conv1d.workload.yaml +19 -0
- tests/mapper/test_mapping_to_isl.py +90 -0
- tests/mapper/test_spatial_reuse_analysis.py +67 -0
- tests/mapper/test_temporal_reuse_analysis.py +56 -0
- tests/mapper/util.py +58 -0
- tests/matmul.mapping.yaml +29 -0
- tests/matmul.workload.yaml +12 -0
- tests/matmul_spatial.mapping.yaml +44 -0
- tests/mha.renames.yaml +65 -0
- tests/mha.workload.yaml +67 -0
- tests/mha.yaml +59 -0
- tests/mha_full.workload.yaml +67 -0
- tests/mobilenet.workload.yaml +35 -0
- tests/mobilenet_long.workload.yaml +64 -0
- tests/pmappingcache.py +24 -0
- tests/processing_stage.arch.yaml +40 -0
- tests/snowcat.arch.yaml +36 -0
- tests/test_ffm_join_pmappings.py +106 -0
- tests/test_ffm_make_pmappings.py +82 -0
- tests/test_ffm_make_tile_shapes.py +49 -0
- tests/test_mapper.py +100 -0
- tests/test_model.py +37 -0
- tests/test_plotting.py +72 -0
- tests/test_processing_stage.py +46 -0
- tests/test_symbolic_model.py +248 -0
- tests/test_workload.py +141 -0
|
@@ -0,0 +1,1681 @@
|
|
|
1
|
+
from enum import Enum
|
|
2
|
+
from functools import lru_cache
|
|
3
|
+
from math import ceil, log2, prod
|
|
4
|
+
import copy
|
|
5
|
+
import re
|
|
6
|
+
import resource
|
|
7
|
+
import time
|
|
8
|
+
from typing import Callable, Iterator, Optional
|
|
9
|
+
from sympy import Expr, Symbol, factorint, lambdify
|
|
10
|
+
from accelforge import util
|
|
11
|
+
from accelforge._accelerated_imports import np
|
|
12
|
+
from accelforge._accelerated_imports import pd
|
|
13
|
+
import accelforge.frontend.arch as arch
|
|
14
|
+
from accelforge.frontend._workload_isl._isl import get_rank_variable_bounds
|
|
15
|
+
from accelforge.frontend._workload_isl._symbolic import get_projection_expr
|
|
16
|
+
from accelforge.frontend.workload import Einsum
|
|
17
|
+
from accelforge.frontend.mapping import (
|
|
18
|
+
Loop,
|
|
19
|
+
Mapping,
|
|
20
|
+
Temporal,
|
|
21
|
+
Spatial,
|
|
22
|
+
TensorHolder,
|
|
23
|
+
)
|
|
24
|
+
from accelforge.mapper.FFM._make_pmappings.pmapper_job import Job
|
|
25
|
+
from accelforge.mapper.FFM._pareto_df.df_convention import (
|
|
26
|
+
stride2col,
|
|
27
|
+
initial2col,
|
|
28
|
+
iterations2col,
|
|
29
|
+
)
|
|
30
|
+
from accelforge.mapper.FFM._pareto_df.pareto import makepareto_numpy
|
|
31
|
+
from accelforge.model._looptree.reuse.symbolic import IMPERFECT
|
|
32
|
+
from accelforge.mapper.FFM._join_pmappings.pmapping_dataframe import (
|
|
33
|
+
nameloop2col,
|
|
34
|
+
tensor2col,
|
|
35
|
+
firstlatency2col,
|
|
36
|
+
)
|
|
37
|
+
from accelforge.frontend.mapper.metrics import Metrics
|
|
38
|
+
from accelforge.util._frozenset import fzs
|
|
39
|
+
import math
|
|
40
|
+
import sympy
|
|
41
|
+
import numpy as np
|
|
42
|
+
from numbers import Number
|
|
43
|
+
|
|
44
|
+
from accelforge.mapper.FFM._make_pmappings.make_pmappings_from_templates.symbol_relations import (
|
|
45
|
+
SymbolRelations,
|
|
46
|
+
)
|
|
47
|
+
from accelforge.util._sympy.broadcast_max import Max
|
|
48
|
+
from accelforge.mapper.FFM._make_pmappings.make_pmappings_from_templates.run_model import (
|
|
49
|
+
run_model,
|
|
50
|
+
)
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
class ComparisonResult(Enum):
    """Sign classification of an expression over a bounded domain.

    Combining two results with ``|`` yields a result that is valid for both:
    ``ALWAYS_EQUAL_TO_ZERO`` is the neutral element, identical results are kept,
    and any other combination degrades to ``UNKNOWN``.
    """

    ALWAYS_GEQ_THAN_ZERO = "ALWAYS_GEQ_THAN_ZERO"
    ALWAYS_LEQ_THAN_ZERO = "ALWAYS_LEQ_THAN_ZERO"
    ALWAYS_EQUAL_TO_ZERO = "ALWAYS_EQUAL_TO_ZERO"
    UNKNOWN = "unknown"

    def __or__(self, other: "ComparisonResult"):
        neutral = ComparisonResult.ALWAYS_EQUAL_TO_ZERO
        # Same result on both sides, or the other side is neutral: keep ours.
        if other == neutral or self == other:
            return self
        # We are neutral: the other side wins.
        if self == neutral:
            return other
        # Two different non-neutral signs cannot be reconciled.
        return ComparisonResult.UNKNOWN
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
@lru_cache(maxsize=10000)
def diff(f: Expr, s: Symbol):
    """Memoized ``sympy.diff``: the derivative of ``f`` with respect to ``s``."""
    derivative = sympy.diff(f, s)
    return derivative
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
@lru_cache(maxsize=10000)
def diff_geq_leq_zero(f: Expr, s: Symbol, bounds: tuple[tuple[Symbol, int, int], ...]):
    """Classify the sign of the derivative of ``f`` with respect to ``s``.

    Every ``ceiling(x)`` in ``f`` is replaced by ``x`` before differentiating.
    This assumes the ceiling won't affect the sign of the derivative; a change
    from positive to zero, or from negative to zero, is acceptable and does not
    count as a sign change.

    Args:
        f: Expression to differentiate (non-``sympy.Expr`` values are used as-is).
        s: Symbol to differentiate with respect to.
        bounds: ``(symbol, lo, hi)`` triples forwarded to ``geq_leq_zero``.

    Returns:
        The ``ComparisonResult`` produced by ``geq_leq_zero`` for the derivative.
    """

    def _is_ceiling(node):
        return node.is_Function and node.func == sympy.ceiling

    def _ceiling_argument(node):
        return node.args[0]

    expr = f.replace(_is_ceiling, _ceiling_argument) if isinstance(f, sympy.Expr) else f
    derivative = diff(expr, s)
    return geq_leq_zero(derivative, bounds)
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
@lru_cache(maxsize=10000)
def function_range(f: Expr, s: Symbol, lo: int, hi: int):
    """Memoized range of ``f`` as ``s`` varies over the closed interval [lo, hi].

    Thin wrapper around ``sympy.calculus.util.function_range``.
    """
    domain = sympy.Interval(lo, hi)
    return sympy.calculus.util.function_range(f, s, domain=domain)
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
def expr_replace(f: Expr, old: sympy.Function, new: Expr) -> Expr:
    """Return ``f`` with every application of the function ``old`` replaced by ``new``.

    The whole call node (function and its arguments) is replaced, not just the
    function head.
    """

    def _matches(node):
        return node.is_Function and node.func == old

    def _substitute(_node):
        return new

    return f.replace(_matches, _substitute)
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
def partition_heaviside(f: Expr) -> tuple[Expr, ...]:
    """Split ``f`` on its ``Heaviside`` terms.

    If ``f`` contains no ``Heaviside``, returns ``(f,)``. Otherwise returns two
    variants: one with every ``Heaviside(...)`` replaced by 1, and one with every
    ``Heaviside(...)`` replaced by 0.
    """
    if not f.has(sympy.Heaviside):
        return (f,)
    all_steps_on = expr_replace(f, sympy.Heaviside, 1)
    all_steps_off = expr_replace(f, sympy.Heaviside, 0)
    return all_steps_on, all_steps_off
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
# @lru_cache(maxsize=10000)
|
|
105
|
+
# def _get_function_range(
|
|
106
|
+
# f: Expr,
|
|
107
|
+
# check_symbols: tuple[Symbol, ...],
|
|
108
|
+
# bounds: tuple[tuple[Symbol, int, int], ...],
|
|
109
|
+
# return_min: bool,
|
|
110
|
+
# ) -> list:
|
|
111
|
+
# if isinstance(f, sympy.Expr):
|
|
112
|
+
# f = f.replace(
|
|
113
|
+
# lambda expr: expr.is_Function and expr.func == sympy.ceiling,
|
|
114
|
+
# lambda expr: expr.args[0],
|
|
115
|
+
# )
|
|
116
|
+
# fs = list(partition_heaviside(f))
|
|
117
|
+
# else:
|
|
118
|
+
# fs = [f]
|
|
119
|
+
|
|
120
|
+
# if len(fs) > 1:
|
|
121
|
+
# return [f3 for f2 in fs for f3 in _get_function_range(f2, check_symbols, bounds, return_min)]
|
|
122
|
+
|
|
123
|
+
# f = fs[0]
|
|
124
|
+
# check_symbol = check_symbols[0]
|
|
125
|
+
# check_symbols = check_symbols[1:]
|
|
126
|
+
# bounds = None
|
|
127
|
+
# for s, lo, hi in bounds:
|
|
128
|
+
# if s == check_symbol:
|
|
129
|
+
# bounds = (s, lo, hi)
|
|
130
|
+
# break
|
|
131
|
+
# else:
|
|
132
|
+
# raise ValueError(f"Symbol {check_symbol} not found in bounds")
|
|
133
|
+
|
|
134
|
+
# f_range = sympy.calculus.util.function_range(f, check_symbol, domain=sympy.Interval(lo, hi))
|
|
135
|
+
|
|
136
|
+
# if isinstance(f_range, sympy.FiniteSet):
|
|
137
|
+
# return [f3 for f2 in f_range for f3 in _get_function_range(f2, check_symbols, bounds, return_min)]
|
|
138
|
+
|
|
139
|
+
# target = f_range.left if return_min else f_range.right
|
|
140
|
+
# return _get_function_range(target, check_symbols, bounds, return_min)
|
|
141
|
+
|
|
142
|
+
|
|
143
|
+
@lru_cache(maxsize=10000)
def _compare_to_zero(
    f: Expr, bounds: tuple[tuple[Symbol, int, int], ...], check_lt_zero: bool
) -> bool:
    """
    Returns True if the function may possibly be less than zero or greater than zero.

    If check_lt_zero is True, then we're checking if the function may possibly be less
    than zero. Otherwise, we're checking if the function may possibly be greater than
    zero.

    If we can't tell, then conservatively return True.

    Strategy, in order:

    1. Replace every ``ceiling(x)`` with ``x``. If ``Heaviside`` appears, split
       ``f`` into its all-ones / all-zeros variants via ``partition_heaviside``;
       the answer is True if any variant may cross to the tested side.
    2. Ask sympy directly whether ``f >= 0`` (or ``f <= 0``).
    3. For ``Min`` / ``Max``, recurse into the arguments.
    4. Otherwise pick the free symbol that occurs least often, compute the range
       of ``f`` over that symbol's ``[lo, hi]`` bound, and recurse on the relevant
       endpoint (or on every element of a finite range).

    Args:
        f: Expression to test. Values that are not ``sympy.Expr`` skip step 1.
        bounds: ``(symbol, lo, hi)`` triples for the symbols that may appear in
            ``f``. Must be hashable because this function is memoized.
        check_lt_zero: Which side of zero to test (see above).

    Returns:
        bool: True if ``f`` may lie on the tested side of zero (or if that could
        not be ruled out).

    Raises:
        ValueError: If the chosen free symbol has no entry in ``bounds``. Also
            raised by ``min()`` if the direct comparison fails and ``f`` has no
            free symbols.
    """
    if isinstance(f, sympy.Expr):
        # Assumes dropping ceiling() does not change which side of zero f can reach.
        f = f.replace(
            lambda expr: expr.is_Function and expr.func == sympy.ceiling,
            lambda expr: expr.args[0],
        )
        fs = list(partition_heaviside(f))
    else:
        fs = [f]

    if len(fs) > 1:
        # Heaviside present: f may cross zero if either step variant may.
        return any(_compare_to_zero(f2, bounds, check_lt_zero) for f2 in fs)

    f = fs[0]
    try:
        if check_lt_zero:
            # Less than zero anywhere == NOT geq zero everywhere
            return not f >= 0
        else:
            # Greater than zero anywhere == NOT leq zero everywhere
            return not f <= 0
    except TypeError:
        # Presumably raised when sympy cannot decide the relational (e.g. it still
        # has free symbols); fall through to structural / range analysis.
        pass

    # Min can go below zero if ANY argument can, while Max can only if ALL can.
    # The roles swap when checking for values above zero.
    min_check, max_check = (any, all) if check_lt_zero else (all, any)
    if isinstance(f, sympy.Min):
        return min_check(_compare_to_zero(g, bounds, check_lt_zero) for g in f.args)
    if isinstance(f, sympy.Max):
        return max_check(_compare_to_zero(g, bounds, check_lt_zero) for g in f.args)

    # Tried this on one workload and had marginally faster speeds with choosing the
    # symbol that appears the least times. Also tried the symbol that appears the most
    # times and the symbol that appears first in the bounds list. They had equivalent
    # speeds, approx. 3% slower overall tile shape exploration than min.
    chosen_s = min(f.free_symbols, key=lambda s: f.count(s))
    for s, lo, hi in bounds:
        if s == chosen_s:
            break
    else:
        raise ValueError(f"Symbol {chosen_s} not found in bounds")

    try:
        f_range = function_range(f, s, lo, hi)
    except (NotImplementedError, TypeError):
        # sympy could not compute the range: conservatively assume "may cross".
        return True

    # NOTE(review): assumes function_range returns either a FiniteSet or an
    # Interval-like set exposing .left/.right; other set types (e.g. Union) would
    # raise AttributeError here — confirm.
    if isinstance(f_range, sympy.FiniteSet):
        return any(_compare_to_zero(f2, bounds, check_lt_zero) for f2 in f_range)
    else:
        # The endpoints may still contain other symbols; recursion eliminates them.
        return _compare_to_zero(
            f_range.left if check_lt_zero else f_range.right,
            bounds,
            check_lt_zero,
        )
|
|
209
|
+
|
|
210
|
+
|
|
211
|
+
@lru_cache(maxsize=10000)
def geq_leq_zero(
    f: Expr,
    bounds: tuple[tuple[Symbol, int, int], ...],
):
    """Classify the sign of ``f`` over ``bounds``.

    Runs ``_compare_to_zero`` once for each side of zero and maps the pair of
    "may be negative" / "may be positive" answers to a ``ComparisonResult``:

    - both possible                -> ``UNKNOWN``
    - only negative possible       -> ``ALWAYS_LEQ_THAN_ZERO``
    - only positive possible       -> ``ALWAYS_GEQ_THAN_ZERO``
    - neither possible             -> ``ALWAYS_EQUAL_TO_ZERO``
    """
    may_be_negative = _compare_to_zero(f, bounds, check_lt_zero=True)
    may_be_positive = _compare_to_zero(f, bounds, check_lt_zero=False)

    if may_be_negative:
        if may_be_positive:
            return ComparisonResult.UNKNOWN
        return ComparisonResult.ALWAYS_LEQ_THAN_ZERO
    if may_be_positive:
        return ComparisonResult.ALWAYS_GEQ_THAN_ZERO
    return ComparisonResult.ALWAYS_EQUAL_TO_ZERO
|
|
227
|
+
|
|
228
|
+
|
|
229
|
+
def compile_dict(symbols, dictionary):
    """Compile every value of ``dictionary`` with ``util._lambdify_type_check``.

    Args:
        symbols: Argument symbols forwarded unchanged to
            ``util._lambdify_type_check`` for every value.
        dictionary: Mapping whose values are the expressions to compile.

    Returns:
        dict: A new dict with the same keys, each value replaced by
        ``util._lambdify_type_check(symbols, value)``.
    """
    # Call the helper directly. A nested function named ``lambdify`` would shadow
    # the module-level ``sympy.lambdify`` import.
    return {
        key: util._lambdify_type_check(symbols, value)
        for key, value in dictionary.items()
    }
|
|
235
|
+
|
|
236
|
+
|
|
237
|
+
class Goal:
    """
    Preferred direction for an enumerated symbol ("min", "max",
    "min_per_prime_factor", "max_per_prime_factor", "diff", or None for
    "no preference yet"), together with ``max_value`` / ``only_care_if_valid``.

    X subset Y means that Y will block pruning for all cases that X will block pruning.

    - min is a subset of min_per_prime_factor is a subset of diff
    - max is a subset of max_per_prime_factor is a subset of diff

    If we're combining goals and they disagree, use the larger space.
    """

    def __init__(
        self,
        goal: str = None,
        max_value: Optional[float] = None,
        only_care_if_valid: bool = False,
    ):
        self.goal = goal
        self.max_value = max_value
        self.only_care_if_valid = only_care_if_valid

    def __or__(self, other: "Goal"):
        """Return the smallest goal whose space contains both operands.

        A goal of None acts as the identity and yields a shallow copy of the
        other operand. Otherwise both operands must agree on ``max_value`` and
        ``only_care_if_valid`` (asserted).
        """
        if self.goal is None:
            return copy.copy(other)
        if other.goal is None:
            return copy.copy(self)
        assert self.max_value == other.max_value
        assert self.only_care_if_valid == other.only_care_if_valid

        if self.goal == other.goal:
            # Same goal: the space doesn't change.
            merged = self.goal
        else:
            # A per-prime-factor goal already covers its plain counterpart; any
            # other disagreement can only be covered by "diff".
            covering_goal = {
                frozenset(("min", "min_per_prime_factor")): "min_per_prime_factor",
                frozenset(("max", "max_per_prime_factor")): "max_per_prime_factor",
            }
            merged = covering_goal.get(frozenset((self.goal, other.goal)), "diff")

        return Goal(
            merged,
            max_value=self.max_value,
            only_care_if_valid=self.only_care_if_valid or other.only_care_if_valid,
        )

    def __str__(self):
        return f"{self.goal} {self.max_value} {self.only_care_if_valid}"

    def __repr__(self):
        return f"Goal({self.goal}, {self.max_value}, {self.only_care_if_valid})"

    def __invert__(self):
        """Swap "min" and "max"; per-prime-factor goals cannot be inverted.

        Any other goal (e.g. "diff" or None) is returned as a shallow copy.
        """
        for uninvertible in ("min_per_prime_factor", "max_per_prime_factor"):
            if self.goal == uninvertible:
                raise ValueError(f"Can't invert {uninvertible}")
        for source, target in (("min", "max"), ("max", "min")):
            if self.goal == source:
                return Goal(target, self.max_value, self.only_care_if_valid)
        return copy.copy(self)

    def __eq__(self, other: "Goal"):
        if not isinstance(other, Goal):
            return False
        return (
            self.goal == other.goal
            and self.max_value == other.max_value
            and self.only_care_if_valid == other.only_care_if_valid
        )
|
|
307
|
+
|
|
308
|
+
|
|
309
|
+
class Objective:
    """A named sympy formula together with optional value bounds.

    NOTE(review): the semantics of each field are inferred only from how this
    constructor stores them; confirm against the consumers elsewhere.
    """

    def __init__(
        self,
        name: str,
        formula: Expr | Number,
        max_value: Optional[float] = None,
        symbols: Optional[list[str]] = None,
        only_care_if_valid: bool = False,
        min_value: Optional[float] = None,
        inclusive: bool = True,
        try_best_if_none_reaches_min: bool = False,
    ):
        """
        Args:
            name: Identifier stored as ``self.name``.
            formula: sympy expression. Plain Python numbers are first wrapped with
                ``sympy.Number``; the result is passed through ``simplify`` and
                stored as ``self.formula``.
            max_value: Optional upper bound, stored as ``self.max_value``.
            symbols: Optional list stored as ``self._symbols``.
            only_care_if_valid: Stored as ``self.only_care_if_valid``. When True,
                at least one of ``max_value`` / ``min_value`` must be given.
            min_value: Optional lower bound, stored as ``self.min_value``.
            inclusive: Stored as ``self.inclusive``.
            try_best_if_none_reaches_min: Stored as
                ``self.try_best_if_none_reaches_min``.

        Raises:
            AssertionError: If ``only_care_if_valid`` is True but neither
                ``max_value`` nor ``min_value`` is provided.
        """
        if isinstance(formula, Number):
            formula = sympy.Number(formula)
        self.name: str = name
        # NOTE(review): ``simplify`` is not among this file's visible imports;
        # presumably defined elsewhere in this module — TODO confirm.
        self.formula: Expr = simplify(formula)
        self._symbols: Optional[list[str]] = symbols
        self.max_value: Optional[float] = max_value
        self.min_value: Optional[float] = min_value
        self.only_care_if_valid: bool = only_care_if_valid
        if only_care_if_valid:
            assert max_value is not None or min_value is not None, (
                "only_care_if_valid=True requires max_value or min_value"
            )
        self.inclusive: bool = inclusive
        self.try_best_if_none_reaches_min: bool = try_best_if_none_reaches_min
|
|
333
|
+
|
|
334
|
+
|
|
335
|
+
def is_constant(f: Expr) -> bool:
    """Return ``f.is_constant()``, with a structural fallback.

    If ``f.is_constant()`` raises ``ValueError``, ``f`` is treated as constant
    exactly when every element of ``f.args`` is (recursively) constant.
    """
    try:
        outcome = f.is_constant()
    except ValueError:
        outcome = all(map(is_constant, f.args))
    return outcome
|
|
340
|
+
|
|
341
|
+
|
|
342
|
+
@lru_cache(maxsize=10000)
def _try_replace_single_term(
    t: Expr,
    symbols_enumerated: fzs[Symbol],
    bounds: tuple[tuple[Symbol, int, int], ...],
):
    """Reduce a term that depends on exactly one enumerated symbol to a Goal.

    If ``t`` contains exactly one symbol from ``symbols_enumerated``, the sign of
    ``d t / d s`` over ``bounds`` picks the goal for that symbol ``s``:

    - derivative always >= 0 -> ``Goal("min")``
    - derivative always <= 0 -> ``Goal("max")``
    - sign unknown           -> ``Goal("diff")``
    - derivative always == 0 -> no goal (``None``)

    Args:
        t: Term to analyze.
        symbols_enumerated: Hashable set of enumerated symbols (memoized call).
        bounds: ``(symbol, lo, hi)`` triples forwarded to the sign analysis.

    Returns:
        ``(s, goal)`` when ``t`` depends on exactly one enumerated symbol and the
        sign analysis succeeds; otherwise ``(t, None)``.

    Raises:
        ValueError: If the sign analysis returns something that is not a
            ``ComparisonResult`` member (internal consistency check).
    """
    relevant = t.free_symbols & symbols_enumerated
    if len(relevant) != 1:
        return t, None
    s = next(iter(relevant))

    # Only the sign analysis is allowed to fail softly. The consistency check
    # below must sit outside this try; otherwise its ValueError would be caught
    # and silently turned into (t, None).
    try:
        diff_result = diff_geq_leq_zero(t, s, bounds)
    except (TypeError, ValueError):
        return t, None

    if diff_result == ComparisonResult.ALWAYS_GEQ_THAN_ZERO:
        goal = Goal("min")
    elif diff_result == ComparisonResult.ALWAYS_LEQ_THAN_ZERO:
        goal = Goal("max")
    elif diff_result == ComparisonResult.UNKNOWN:
        goal = Goal("diff")
    elif diff_result == ComparisonResult.ALWAYS_EQUAL_TO_ZERO:
        goal = None
    else:
        raise ValueError(
            f"Comparison result {diff_result} is not a valid comparison result"
        )
    return s, goal
|
|
369
|
+
|
|
370
|
+
|
|
371
|
+
def try_replace_single_term(
    t: Expr,
    symbols_enumerated: fzs[Symbol],
    bounds: tuple[tuple[Symbol, int, int], ...],
):
    """Public entry point for ``_try_replace_single_term``.

    Restricts ``symbols_enumerated`` to the symbols that actually occur in ``t``
    before delegating. Calls that differ only in irrelevant symbols then share
    one entry in the memoized helper's cache.
    """
    relevant_symbols = symbols_enumerated & t.free_symbols
    return _try_replace_single_term(t, relevant_symbols, bounds)
|
|
377
|
+
|
|
378
|
+
|
|
379
|
+
@lru_cache(maxsize=10000)
def _partition_formula(
    f: Expr,
    symbols_enumerated: set[Symbol],
    bounds: tuple[tuple[Symbol, int, int], ...],
) -> dict[Symbol, Goal]:
    """Assign optimization goals to the parts of ``f`` that depend on enumerated symbols.

    ``Max``/``Min``/``Add``/``ceiling``/``Mul`` formulas are split into their
    arguments. Any other formula is treated as a single term. Before
    classification, non-enumerated symbols that do not affect comparisons are
    replaced by 1.

    Each resulting term gets a goal by the first rule that applies:

    - the goal found by ``try_replace_single_term``, keyed by the term it returns;
    - nothing, if the term has no enumerated symbol;
    - ``"min"``, if all of the term's free symbols are enumerated;
    - ``"diff"`` for each of its free symbols, if the term is ``f`` itself
      (it could not be simplified, so recursing would not terminate);
    - otherwise, the goals obtained by recursively partitioning the term.

    For a product, every factor reported ``ALWAYS_LEQ_THAN_ZERO`` by
    ``geq_leq_zero`` flips the sign. An odd number of flips inverts every goal.
    A factor of ``UNKNOWN`` sign turns every goal into ``"diff"``.

    :param f: formula to partition.
    :param symbols_enumerated: hashable set of enumerated symbols (required by the cache).
    :param bounds: ``(symbol, lower, upper)`` triples for the comparison helpers.
    :returns: ``{term_or_symbol: Goal}``. Empty if ``f`` has no enumerated symbol.

    NOTE: this function is ``lru_cache``d. The returned dict is shared by every
    call with equal arguments, so callers must treat it as read-only.
    """
    goals: dict[Symbol, Goal] = {}

    def update_goal(symbol: Symbol, goal: str):
        # Merge with any goal already recorded for this key.
        goals[symbol] = Goal(goal) | goals.get(symbol, Goal())

    # Sign of a product: False = keep goals, True = invert goals,
    # None = unknown sign (every goal becomes "diff").
    negate = False

    if not f.free_symbols & symbols_enumerated:
        return goals

    def _try_replace_unknowns(t: Expr):
        # Substitute 1 for non-enumerated symbols that do not affect comparisons.
        for s in t.free_symbols - symbols_enumerated:
            if not affects_comparison(t, s, symbols_enumerated):
                t = t.subs(s, 1)
        return t

    def _recombine_terms(terms: list[Expr]):
        # Terms that depend only on enumerated symbols. They are recombined with
        # f's own operator into a single term.
        can_evaluate = []
        # Terms that also depend on non-enumerated symbols, grouped by those symbols.
        others = {}
        for t in terms:
            t = _try_replace_unknowns(t)
            if not t.free_symbols & symbols_enumerated:
                # No enumerated symbol left: irrelevant to the comparison, drop it.
                continue
            if t.free_symbols.issubset(symbols_enumerated):
                can_evaluate.append(t)
            else:
                others.setdefault(fzs(t.free_symbols - symbols_enumerated), []).append(
                    t
                )

        # Grab the terms that we can evaluate directly first
        chosen = []
        if can_evaluate:
            chosen.append(type(f)(*can_evaluate))
        chosen.extend(x for v in others.values() for x in v)
        return chosen

    if isinstance(f, (sympy.Max, sympy.Min, sympy.Add, sympy.ceiling)):
        terms = _recombine_terms(f.args)
    elif isinstance(f, sympy.Mul):
        terms = _recombine_terms(f.args)
        # If the formula is a product:
        # - Divide the max value by the constant factors
        # - For non-constant factors, if they're >1 then we can keep the max.
        #   Otherwise we have to drop it.
        for t in f.args:
            geq_result = geq_leq_zero(t, bounds)
            if geq_result == ComparisonResult.ALWAYS_LEQ_THAN_ZERO:
                negate = not negate
            elif geq_result == ComparisonResult.UNKNOWN:
                negate = None
                break
            elif geq_result == ComparisonResult.ALWAYS_GEQ_THAN_ZERO:
                pass
            elif geq_result == ComparisonResult.ALWAYS_EQUAL_TO_ZERO:
                pass
            else:
                raise ValueError(
                    f"Comparison result {geq_result} is not a valid comparison result"
                )
    else:
        terms = [_try_replace_unknowns(f)]

    for term in terms:
        term, goal = try_replace_single_term(term, fzs(symbols_enumerated), bounds)
        if goal is not None:
            update_goal(term, goal.goal)
            continue

        # Constant! Don't care
        if len(term.free_symbols & symbols_enumerated) == 0:
            continue

        if term.free_symbols.issubset(symbols_enumerated):
            update_goal(term, "min")
            continue

        # Don't recurse with the same formula. If we got here without simplifying it,
        # give up and mark everything "diff".
        if term == f:
            for symbol in term.free_symbols:
                update_goal(symbol, "diff")
        else:
            for subterm, subgoal in partition_formula(
                term, symbols_enumerated, bounds
            ).items():
                goals[subterm] = subgoal | goals.get(subterm, Goal())

    # Apply the product sign. Reassigning existing keys while iterating is safe
    # because the dict size does not change.
    for k, v in goals.items():
        if negate:
            goals[k] = ~v
        if negate is None:
            v.goal = "diff"

    return goals
|
|
488
|
+
|
|
489
|
+
|
|
490
|
+
@lru_cache(maxsize=10000)
def _get_n_prime_factors(n: int) -> int:
    """Return the number of distinct prime factors of ``n`` (memoized).

    ``factorint`` returns a ``{prime: multiplicity}`` mapping, so its length
    counts distinct primes. For example, 12 = 2**2 * 3 gives 2.
    """
    prime_to_multiplicity = factorint(n)
    return len(prime_to_multiplicity)
|
|
493
|
+
|
|
494
|
+
|
|
495
|
+
def partition_formula(
    f: Expr,
    symbols_enumerated: set[Symbol],
    bounds: tuple[tuple[Symbol, int, int], ...],
) -> dict[Symbol, Goal]:
    """Public entry point for the cached ``_partition_formula``.

    Keeps only the enumerated symbols that occur in ``f`` and converts them to
    ``fzs`` so they are hashable as an ``lru_cache`` key.

    The returned dict comes from the cache and is shared between calls, so do
    not mutate it.
    """
    relevant = fzs(symbols_enumerated & f.free_symbols)
    return _partition_formula(f, relevant, bounds)
|
|
501
|
+
|
|
502
|
+
|
|
503
|
+
def get_possible_factor_sizes(n: int, imperfect: bool = False) -> list[int]:
    """Return sorted candidate tile sizes for a dimension of size ``n``.

    For every ``i`` in ``1 .. ceil(sqrt(n))`` that is kept, both ``i`` and
    ``ceil(n / i)`` are candidates:

    - ``imperfect=False``: ``i`` is kept only if it divides ``n``, so the result
      is the sorted list of divisors of ``n``;
    - ``imperfect=True``: every ``i`` is kept.

    ``n == 0`` yields ``[]``.
    """
    upper = math.ceil(n**0.5)
    sizes = set()
    for i in range(1, upper + 1):
        if imperfect or n % i == 0:
            sizes.update((i, math.ceil(n / i)))
    return sorted(sizes)
|
|
511
|
+
|
|
512
|
+
|
|
513
|
+
def append_vector(matrix: np.ndarray, vector: np.ndarray):
    """Extend every row of ``matrix`` with every value of ``vector`` (Cartesian product).

    Each row of ``matrix`` is repeated once per element of ``vector`` and
    paired with that element as a new last column. Rows appear in
    ``matrix``-major order. If ``matrix`` is ``None``, the result is
    ``vector`` as a single column.

    :param matrix: existing ``(rows, cols)`` choices, or ``None``.
    :param vector: values for the new column.
    :returns: array of shape ``(rows * len(vector), cols + 1)``.
    """
    column = vector.reshape(-1, 1)
    if matrix is None:
        return column
    n_values = vector.shape[0]
    left = np.repeat(matrix, n_values, axis=0)
    right = np.tile(column, (matrix.shape[0], 1))
    return np.concatenate((left, right), axis=1)
|
|
523
|
+
|
|
524
|
+
|
|
525
|
+
@lru_cache(maxsize=10000)
def simplify(f: Expr):
    """Return ``f.simplify()``, memoized by ``f``.

    ``f`` must be hashable to serve as the cache key.
    """
    simplified = f.simplify()
    return simplified
|
|
528
|
+
|
|
529
|
+
|
|
530
|
+
def symbol2int(symbol: Symbol):
    """Return the first run of decimal digits in ``symbol.name`` as an ``int``.

    For example, a symbol named ``"x12_3"`` yields ``12``.

    :raises IndexError: if the name contains no digit.
    """
    digit_runs = re.findall(r"(\d+)", symbol.name)
    return int(digit_runs[0])
|
|
532
|
+
|
|
533
|
+
|
|
534
|
+
@lru_cache(maxsize=10000)
def f_minus_other_f(f: Expr, symbols_enumerated: set[Symbol]):
    """Build the relation ``f(second choice) - f(first choice) > 0``.

    Each enumerated symbol that occurs in ``f`` is substituted by a fresh symbol
    named ``"<name>_2"`` (``integer=True, positive=True``). That symbol stands
    for an independent second choice of the same variable.

    The result is a sympy comparison object (a relational, or a boolean
    atom if it evaluates), not an ``Expr``.

    ``symbols_enumerated`` must be hashable (e.g. a frozenset) because of the cache.
    """
    other = f
    for sym in f.free_symbols & symbols_enumerated:
        twin = sympy.Symbol(f"{sym}_2", integer=True, positive=True)
        other = other.subs(sym, twin)
    return other - f > 0
|
|
540
|
+
|
|
541
|
+
|
|
542
|
+
@lru_cache(maxsize=10000)
def affects_comparison(f: Expr, s: Symbol, symbols_enumerated: set[Symbol]):
    """Return whether ``s`` can change which of two enumerated choices ``f`` prefers.

    ``f_minus_other_f`` builds ``f(choice 2) - f(choice 1) > 0``. If ``s`` does
    not appear in that relation, either as built or after ``simplify``, its
    value cannot change the comparison's outcome and this returns ``False``.

    Non-``Expr`` inputs return ``False``.

    ``symbols_enumerated`` must be hashable (e.g. a frozenset) because this
    function and ``f_minus_other_f`` are cached.
    """
    if not isinstance(f, sympy.Expr):
        return False
    delta = f_minus_other_f(f, symbols_enumerated)
    # ``delta`` is a sympy Relational (or BooleanTrue/BooleanFalse). These are
    # ``Basic`` but not ``Expr``, so the previous ``isinstance(delta, sympy.Expr)``
    # check was always True-negated and this function always returned False.
    # Boolean atoms have no free symbols, so they still return False below.
    if not isinstance(delta, sympy.Basic) or s not in delta.free_symbols:
        return False

    delta = simplify(delta)
    return s in delta.free_symbols
|
|
555
|
+
|
|
556
|
+
|
|
557
|
+
def get_padded_choices(
    symbols_enumerated: list[Symbol],
    symbols_non_enumerated_set: set[Symbol],
    choices_enumerated: np.ndarray,
    what_tiles_symbol: SymbolRelations,
    minimize_formula: Expr = None,
    maximize_formula: Expr = None,
):
    """Map every symbol to a per-row value vector for vectorized evaluation.

    - Enumerated symbols map to their column of ``choices_enumerated``. If a
      symbol is repeated, its first column is used.
    - Non-enumerated symbols map to a vector of ones. If an objective is given
      (``minimize_formula`` or ``maximize_formula``) and
      ``diff_geq_leq_zero(sign * formula, symbol, what_tiles_symbol.bounds)``
      reports ``ALWAYS_LEQ_THAN_ZERO``, the symbol is padded with
      ``what_tiles_symbol.get_max_size(symbol)`` instead. Here ``sign`` is +1
      when minimizing and -1 when maximizing.

    :raises ValueError: if both formulas are given and there is at least one
        non-enumerated symbol, if the derivative sign is ``UNKNOWN``, or if an
        unexpected comparison result is returned.
    """
    n_rows = choices_enumerated.shape[0]
    ones = np.ones(n_rows, choices_enumerated.dtype)

    padded = {}
    for column, symbol in enumerate(symbols_enumerated):
        # setdefault keeps the first column for repeated symbols.
        padded.setdefault(symbol, choices_enumerated[:, column])

    have_objective = not (minimize_formula is None and maximize_formula is None)
    for symbol in symbols_non_enumerated_set:
        padded[symbol] = ones
        if not have_objective:
            continue

        if minimize_formula is None:
            formula, sign = maximize_formula, -1
        elif maximize_formula is None:
            formula, sign = minimize_formula, 1
        else:
            raise ValueError(
                "Both minimize_formula and maximize_formula are not None"
            )

        result = diff_geq_leq_zero(sign * formula, symbol, what_tiles_symbol.bounds)
        if result == ComparisonResult.ALWAYS_LEQ_THAN_ZERO:
            padded[symbol] = ones * what_tiles_symbol.get_max_size(symbol)
        elif result in (
            ComparisonResult.ALWAYS_GEQ_THAN_ZERO,
            ComparisonResult.ALWAYS_EQUAL_TO_ZERO,
        ):
            pass
        elif result == ComparisonResult.UNKNOWN:
            raise ValueError(f"Can't tell if {symbol} is increasing or decreasing")
        else:
            raise ValueError(
                f"Comparison result {result} is not a valid comparison result"
            )

    return padded
|
|
599
|
+
|
|
600
|
+
|
|
601
|
+
def check_loops(
    symbols_enumerated: list[Symbol],
    choices_enumerated: np.ndarray,
    max_loop_check_groups: list[tuple[Number, list[Symbol]]],
    what_tiles_symbol: SymbolRelations,
):
    """Drop enumerated rows that use more than ``limit`` loops in any checked group.

    For each ``(limit, group)`` in ``max_loop_check_groups``, a member ``g``
    counts as a loop when its size differs from the size of
    ``what_tiles_symbol.get_inner_tiles(g)``. The comparison is per row when
    either side is enumerated.

    Symbols that are not enumerated yet are skipped, since they have no
    per-row values. Groups with at most ``limit`` members can never exceed
    the limit and are skipped entirely.

    :returns: ``choices_enumerated`` with the offending rows removed.
    """

    def get_size(x: Symbol | int):
        # Per-row column for enumerated symbols, max size for other symbols,
        # the value itself for constants. Reads the current (possibly already
        # filtered) ``choices_enumerated`` from the enclosing scope.
        if isinstance(x, Symbol):
            if x in symbols_enumerated:
                return choices_enumerated[:, symbols_enumerated.index(x)]
            return what_tiles_symbol.get_max_size(x)
        return x

    def has_fanout(x: Symbol | int):
        # True (per row) where x differs from the tile nested inside it.
        inner_size = get_size(what_tiles_symbol.get_inner_tiles(x))
        size = get_size(x)
        return inner_size != size

    def can_check(x: Symbol | int):
        # Only enumerated symbols (or constants) have known values so far.
        return not (isinstance(x, Symbol) and x not in symbols_enumerated)

    for limit, group in max_loop_check_groups:
        if len(group) <= limit:
            continue

        # Scalar when nothing checked is enumerated, otherwise a per-row count.
        n_loops = sum(has_fanout(g) for g in group if can_check(g))

        if isinstance(n_loops, np.ndarray):
            choices_enumerated = choices_enumerated[n_loops <= limit]
        elif n_loops > limit:
            choices_enumerated = choices_enumerated[0:0, :]

    return choices_enumerated
|
|
644
|
+
|
|
645
|
+
|
|
646
|
+
def coalesce_symbols(
    update_symbol2goal: Callable,
    symbols_enumerated: list[Symbol],
    symbol2goal: dict[Symbol, Goal],
    log_message: Callable,
    bounds: tuple[tuple[Symbol, int, int], ...],
):
    """Simplify a ``{formula: goal}`` map until a pass leaves it unchanged.

    Each pass rebuilds the map from scratch. For every formula it:

    - drops the formula if it has no enumerated symbol;
    - keeps bare enumerated symbols as they are;
    - drops constant terms of sums and constant factors of products, inverting
      the goal for each negative factor;
    - substitutes 1 for symbols that ``affects_comparison`` says cannot change
      comparisons. This is tried for non-enumerated symbols and for enumerated
      symbols whose current goal is "diff";
    - records a goal via ``try_replace_single_term`` when exactly one
      enumerated symbol remains;
    - inverts products made only of reciprocals and inverts the goal;
    - removes the formula when all of its symbols already have "min"/"max"
      goals, turning any symbol whose goal disagrees into "diff".

    :param update_symbol2goal: ``(symbol, goal, target_dict)`` callback that merges
        ``goal`` into ``target_dict``.
    :param symbols_enumerated: symbols enumerated so far.
    :param symbol2goal: initial ``{formula: goal}`` map. This function only reads it.
    :param log_message: logging callback.
    :param bounds: ``(symbol, lower, upper)`` triples for the comparison helpers.
    :returns: the simplified ``{formula: goal}`` map.
    """
    sym_enumerated_set = fzs(symbols_enumerated)

    log_message("coalesce symbols", "initial")
    for s, g in symbol2goal.items():
        log_message(f"\t{g.goal}: {s}")

    changed = True
    while changed:
        new_symbol2goal = {}

        def latest(s=None):
            # Goals as currently known in this pass: entries already written to
            # new_symbol2goal take precedence over the previous pass.
            if s is None:
                x = dict(symbol2goal)
                x.update(new_symbol2goal)
                return x
            return new_symbol2goal[s] if s in new_symbol2goal else symbol2goal[s]

        for formula, goal in list(symbol2goal.items()):
            # Not dependent on any enumerated symbols, so drop it
            if not formula.free_symbols & sym_enumerated_set:
                log_message("coalesce symbols", f"dropping constant: {formula}")
                continue

            # It is an enumerated symbol, so just keep it
            if formula in symbols_enumerated:
                update_symbol2goal(formula, goal, new_symbol2goal)
                continue

            # If it's a sum, remove any terms that are constant. Rebuild the sum
            # from the remaining terms rather than using subs(term, 0): subs
            # replaces *every* occurrence of that number, including exponents.
            # For example, (x**2 + 2).subs(2, 0) == 1, which loses x.
            if isinstance(formula, sympy.Add):
                kept = []
                for term in formula.args:
                    if len(term.free_symbols) == 0:
                        log_message("coalesce symbols", f"dropping constant: {term}")
                    else:
                        kept.append(term)
                if len(kept) != len(formula.args):
                    formula = sympy.Add(*kept)
                if len(formula.args) == 1:
                    formula = formula.args[0]

            # If it's a product, remove any factors that are constant, inverting
            # the goal for each negative one. Rebuild for the same reason as
            # above: (-x/y).subs(-1, 1) == x*y, which inverts y.
            if isinstance(formula, sympy.Mul):
                kept = []
                for term in formula.args:
                    if len(term.free_symbols) == 0:
                        if term < 0:
                            goal = ~goal
                        log_message("coalesce symbols", f"dropping constant: {term}")
                    else:
                        kept.append(term)
                if len(kept) != len(formula.args):
                    formula = sympy.Mul(*kept)
                if len(formula.args) == 1:
                    formula = formula.args[0]

            # A non-enumerated symbol, or an enumerated symbol whose goal is
            # already "diff", can be dropped (set to 1) if it does not affect
            # comparisons.
            for s in formula.free_symbols:
                if s in symbols_enumerated and latest().get(s, Goal()).goal != "diff":
                    continue

                if not affects_comparison(formula, s, sym_enumerated_set):
                    formula = formula.subs(s, 1)
                    log_message(
                        "coalesce symbols",
                        f"dropping non-comparable symbol that does not affect comparison {s}: {formula}",
                    )
                else:
                    log_message(
                        "coalesce symbols",
                        f"keeping symbol that affects comparison {s}: {formula}",
                    )

            # If there's only one symbol in the formula, we can try to replace it with
            # just the symbol.
            if len(formula.free_symbols & sym_enumerated_set) == 1:
                formula, new_goal = try_replace_single_term(
                    formula, sym_enumerated_set, bounds
                )
                if new_goal is not None:
                    log_message("coalesce symbols", f"replacing single term: {formula}")
                    update_symbol2goal(formula, new_goal, new_symbol2goal)

            # If we're a fraction and all of our symbols are in the denominator, replace
            # it with the reciprocal and change the goal
            if isinstance(formula, sympy.Mul):
                for term in formula.args:
                    if len(term.free_symbols) == 0:
                        continue
                    if isinstance(term, sympy.Pow) and term.args[1] == -1:
                        continue
                    break
                else:
                    log_message("coalesce symbols", f"replacing reciprocal: {formula}")
                    formula = 1 / formula
                    goal = ~goal

            # If a formula agrees entirely with other goals, then we can remove it.
            # Symbols without a "min"/"max" goal, or an unknown derivative sign,
            # stop the check (break) and the formula is kept.
            disagrees = []
            for s in formula.free_symbols:
                current = latest()
                g = current[s].goal if s in current else None
                if g in ["min", "max"]:
                    diff_result = diff_geq_leq_zero(formula, s, bounds)
                    if diff_result == ComparisonResult.ALWAYS_LEQ_THAN_ZERO:
                        this_goal = (~goal).goal
                    elif diff_result == ComparisonResult.ALWAYS_GEQ_THAN_ZERO:
                        this_goal = goal.goal
                    elif diff_result == ComparisonResult.UNKNOWN:
                        break
                    elif diff_result == ComparisonResult.ALWAYS_EQUAL_TO_ZERO:
                        this_goal = g  # Make it agree
                    else:
                        raise ValueError(
                            f"Comparison result {diff_result} is not a valid comparison result"
                        )
                    if g != this_goal:
                        disagrees.append(s)
                    continue
                break
            else:
                # We didn't break! Every symbol already has a min/max goal, so
                # the formula is redundant; symbols it disagrees with become diff.
                log_message(
                    "coalesce symbols",
                    f"removing formula that agrees with all other goals: {formula}",
                )
                for s in disagrees:
                    log_message(
                        "coalesce symbols",
                        f"previous formula disagreed with {s}. Changing goal to diff",
                    )
                    update_symbol2goal(s, Goal("diff"), new_symbol2goal)
                continue
            update_symbol2goal(formula, goal, new_symbol2goal)

        changed = symbol2goal != new_symbol2goal
        symbol2goal = new_symbol2goal

    log_message("coalesce symbols", "final")
    for s, g in symbol2goal.items():
        log_message(f"\t{g.goal}: {s}")

    return symbol2goal
|
|
806
|
+
|
|
807
|
+
|
|
808
|
+
def get_tile_shape_choices(
|
|
809
|
+
objectives: list[Objective],
|
|
810
|
+
symbols: list[Symbol],
|
|
811
|
+
what_tiles_symbol: SymbolRelations,
|
|
812
|
+
job: "Job",
|
|
813
|
+
keep_symbols: list[Symbol] = (),
|
|
814
|
+
max_loop_check_groups: list[tuple[Number, list[Symbol]]] = (),
|
|
815
|
+
):
|
|
816
|
+
objectives = [copy.deepcopy(o) for o in objectives]
|
|
817
|
+
|
|
818
|
+
import time
|
|
819
|
+
|
|
820
|
+
objectives = objectives.copy()
|
|
821
|
+
|
|
822
|
+
symbols_enumerated: list[Symbol] = []
|
|
823
|
+
choices_enumerated: np.ndarray = None
|
|
824
|
+
|
|
825
|
+
symbols_remaining = list(symbols)
|
|
826
|
+
|
|
827
|
+
imperfect = IMPERFECT
|
|
828
|
+
|
|
829
|
+
# Inner to outer faster if there's symbols to keep because those symbols end up in
|
|
830
|
+
# the outer loops, so it does those symbols (which end up multiplying our choices)
|
|
831
|
+
# last. Outer to inner is faster if there's no symbols to keep because that's what
|
|
832
|
+
# happened on exactly one workload that Tanner tested.
|
|
833
|
+
# TILE_SHAPE_ORDER = "inner_to_outer_one_rv_at_a_time" if keep_symbols else "outer_to_inner_one_rv_at_a_time"
|
|
834
|
+
TILE_SHAPE_ORDER = "inner_to_outer_one_rv_at_a_time"
|
|
835
|
+
# TILE_SHAPE_ORDER = "inner_to_outer"
|
|
836
|
+
|
|
837
|
+
# For imperfect, we make inner tile shapes, then create outer tile shapes that are
|
|
838
|
+
# multiples of the non-residual part of the inner tile shape. This way, the very last
|
|
839
|
+
# iteration of the outer tile shape fully contains the reisudal part of the inner tile
|
|
840
|
+
# shape, and we don't have any cases where there are residuals stacking across multiple
|
|
841
|
+
# loop levels.
|
|
842
|
+
if IMPERFECT:
|
|
843
|
+
assert TILE_SHAPE_ORDER == "inner_to_outer_one_rv_at_a_time"
|
|
844
|
+
|
|
845
|
+
paretoed_by = []
|
|
846
|
+
|
|
847
|
+
prev_time, start_time = time.time(), time.time()
|
|
848
|
+
times = {}
|
|
849
|
+
|
|
850
|
+
def time_end(s):
|
|
851
|
+
nonlocal prev_time
|
|
852
|
+
cur_time = time.time()
|
|
853
|
+
times.setdefault(s, 0)
|
|
854
|
+
times[s] += cur_time - prev_time
|
|
855
|
+
prev_time = cur_time
|
|
856
|
+
|
|
857
|
+
def log_message(message: str, *args: str):
|
|
858
|
+
t = time.time() - prev_time
|
|
859
|
+
s = "**" if t > 1 else ""
|
|
860
|
+
job.log_message(f"{s}{t:.2f}s: {message} {' '.join(args)}")
|
|
861
|
+
# print(f"{time.time() - prev_time:.2f}s: {message} {' '.join(args)}")
|
|
862
|
+
time_end(message)
|
|
863
|
+
|
|
864
|
+
log_message("init")
|
|
865
|
+
|
|
866
|
+
def eval_objective(
|
|
867
|
+
formula: Expr | Objective,
|
|
868
|
+
choices: np.ndarray,
|
|
869
|
+
minimize_formula: Expr = None,
|
|
870
|
+
maximize_formula: Expr = None,
|
|
871
|
+
):
|
|
872
|
+
if isinstance(formula, Objective):
|
|
873
|
+
formula = formula.formula
|
|
874
|
+
if formula in symbols_enumerated:
|
|
875
|
+
return choices[:, symbols_enumerated.index(formula)]
|
|
876
|
+
|
|
877
|
+
padded_choices = get_padded_choices(
|
|
878
|
+
symbols_enumerated=symbols_enumerated,
|
|
879
|
+
symbols_non_enumerated_set=symbols_non_enumerated_set,
|
|
880
|
+
choices_enumerated=choices,
|
|
881
|
+
what_tiles_symbol=what_tiles_symbol,
|
|
882
|
+
minimize_formula=minimize_formula,
|
|
883
|
+
maximize_formula=maximize_formula,
|
|
884
|
+
)
|
|
885
|
+
return util._lambdify_type_check(symbols, formula)(
|
|
886
|
+
**{str(k): v for k, v in padded_choices.items()},
|
|
887
|
+
)
|
|
888
|
+
|
|
889
|
+
def grab_symbol(prev_symbol: Symbol = None):
|
|
890
|
+
# TODO: Maybe start with a symbol that would result in more pruning up front?
|
|
891
|
+
# Maximize the # of choices that can be resolved easily
|
|
892
|
+
if TILE_SHAPE_ORDER == "inner_to_outer":
|
|
893
|
+
return symbols_remaining.pop(-1)
|
|
894
|
+
if TILE_SHAPE_ORDER == "outer_to_inner":
|
|
895
|
+
return symbols_remaining.pop(0)
|
|
896
|
+
|
|
897
|
+
if TILE_SHAPE_ORDER == "inner_to_outer_one_rv_at_a_time":
|
|
898
|
+
# Continue with a symbol representing the parent tile of the last symbol
|
|
899
|
+
# if possible. Otherwise (see return), just grab any symbol.
|
|
900
|
+
choice = what_tiles_symbol.get_outer_tiles(prev_symbol, none_if_fail=True)
|
|
901
|
+
if choice is not None and choice in symbols_remaining:
|
|
902
|
+
symbols_remaining.remove(choice)
|
|
903
|
+
return choice
|
|
904
|
+
# Pick a symbol that has:
|
|
905
|
+
# - Nobody tiling it
|
|
906
|
+
# - The smallest maximum size
|
|
907
|
+
strides = [s for s in symbols_remaining if what_tiles_symbol.is_stride(s)]
|
|
908
|
+
choice = -1
|
|
909
|
+
if strides:
|
|
910
|
+
max_size = what_tiles_symbol.get_max_size(strides[choice])
|
|
911
|
+
for i, s in enumerate(strides):
|
|
912
|
+
if what_tiles_symbol.get_inner_tiles(s, none_if_fail=True) is None:
|
|
913
|
+
if what_tiles_symbol.get_max_size(s) < max_size:
|
|
914
|
+
choice = i
|
|
915
|
+
max_size = what_tiles_symbol.get_max_size(s)
|
|
916
|
+
choice = symbols_remaining.index(strides[choice])
|
|
917
|
+
return symbols_remaining.pop(choice)
|
|
918
|
+
elif TILE_SHAPE_ORDER == "outer_to_inner_one_rv_at_a_time":
|
|
919
|
+
# Continue with a symbol representing the child tile of the last symbol
|
|
920
|
+
# if possible. Otherwise (see return), just grab any symbol.
|
|
921
|
+
choice = what_tiles_symbol.get_inner_tiles(prev_symbol, none_if_fail=True)
|
|
922
|
+
if choice is not None and choice in symbols_remaining:
|
|
923
|
+
symbols_remaining.remove(choice)
|
|
924
|
+
return choice
|
|
925
|
+
# Pick a symbol that has:
|
|
926
|
+
# - Tiles nobody
|
|
927
|
+
# - The smallest maximum size
|
|
928
|
+
strides = [s for s in symbols_remaining if what_tiles_symbol.is_stride(s)]
|
|
929
|
+
choice = 0
|
|
930
|
+
if strides:
|
|
931
|
+
max_size = what_tiles_symbol.get_max_size(strides[choice])
|
|
932
|
+
for i, s in enumerate(strides):
|
|
933
|
+
if what_tiles_symbol.get_outer_tiles(s, none_if_fail=True) is None:
|
|
934
|
+
if what_tiles_symbol.get_max_size(s) < max_size:
|
|
935
|
+
choice = i
|
|
936
|
+
max_size = what_tiles_symbol.get_max_size(s)
|
|
937
|
+
choice = symbols_remaining.index(strides[choice])
|
|
938
|
+
return symbols_remaining.pop(choice)
|
|
939
|
+
else:
|
|
940
|
+
raise RuntimeError(f"BUG: invalid TILE_SHAPE_ORDER: {TILE_SHAPE_ORDER}")
|
|
941
|
+
|
|
942
|
+
last_stride_symbol = None # track the last stride symbol to select next symbol
|
|
943
|
+
symbol = None
|
|
944
|
+
while symbols_remaining:
|
|
945
|
+
# ==============================================================================
|
|
946
|
+
# Enumerate choices for a new symbol
|
|
947
|
+
# ==============================================================================
|
|
948
|
+
symbol = grab_symbol(last_stride_symbol)
|
|
949
|
+
|
|
950
|
+
choices = []
|
|
951
|
+
if what_tiles_symbol.is_stride(symbol):
|
|
952
|
+
last_stride_symbol = symbol
|
|
953
|
+
inner_tiles = what_tiles_symbol.get_inner_tiles(symbol, none_if_fail=True)
|
|
954
|
+
outer_tiles = what_tiles_symbol.get_outer_tiles(symbol, none_if_fail=True)
|
|
955
|
+
|
|
956
|
+
# Figure out inner size and outer size
|
|
957
|
+
if inner_tiles in symbols_enumerated:
|
|
958
|
+
inner_tiles_type = "enumerated"
|
|
959
|
+
inner_size = None
|
|
960
|
+
elif isinstance(inner_tiles, int):
|
|
961
|
+
inner_tiles_type = "set"
|
|
962
|
+
inner_size = inner_tiles
|
|
963
|
+
else:
|
|
964
|
+
inner_tiles_type = "unknown"
|
|
965
|
+
inner_size = 1
|
|
966
|
+
|
|
967
|
+
if outer_tiles in symbols_enumerated:
|
|
968
|
+
outer_tiles_type = "enumerated"
|
|
969
|
+
outer_size = None
|
|
970
|
+
elif isinstance(outer_tiles, int):
|
|
971
|
+
outer_tiles_type = "set"
|
|
972
|
+
outer_size = outer_tiles
|
|
973
|
+
else:
|
|
974
|
+
outer_tiles_type = "unknown"
|
|
975
|
+
outer_size = what_tiles_symbol.get_max_size(outer_tiles)
|
|
976
|
+
|
|
977
|
+
if inner_tiles_type == "enumerated" and outer_tiles_type == "enumerated":
|
|
978
|
+
raise RuntimeError(
|
|
979
|
+
f"BUG: both inner, {inner_tiles}, and outer, {outer_tiles},"
|
|
980
|
+
f"tiles of {symbol} are enumerated (thus far: {symbols_enumerated})"
|
|
981
|
+
)
|
|
982
|
+
if inner_tiles_type == "unknown" and outer_tiles_type == "unknown":
|
|
983
|
+
raise RuntimeError("BUG: both inner and outer tiles are unknown")
|
|
984
|
+
|
|
985
|
+
# Use inner size and outer size to generate choices
|
|
986
|
+
if inner_tiles_type in {"set", "unknown"} and outer_tiles_type in {
|
|
987
|
+
"set",
|
|
988
|
+
"unknown",
|
|
989
|
+
}:
|
|
990
|
+
factorize = math.ceil(outer_size / inner_size)
|
|
991
|
+
factors = list(get_possible_factor_sizes(factorize, imperfect))
|
|
992
|
+
scaled = np.array(factors) * inner_size
|
|
993
|
+
choices.append(append_vector(choices_enumerated, scaled))
|
|
994
|
+
elif inner_tiles_type == "enumerated":
|
|
995
|
+
assert isinstance(outer_size, int)
|
|
996
|
+
i = symbols_enumerated.index(inner_tiles)
|
|
997
|
+
for inner_choice in np.unique(choices_enumerated[:, i]):
|
|
998
|
+
partition = choices_enumerated[
|
|
999
|
+
np.where(choices_enumerated[:, i] == inner_choice)
|
|
1000
|
+
]
|
|
1001
|
+
factorize = math.ceil(outer_size / inner_choice)
|
|
1002
|
+
factors = list(get_possible_factor_sizes(factorize, imperfect))
|
|
1003
|
+
scaled = np.array(factors) * inner_choice
|
|
1004
|
+
choices.append(append_vector(partition, scaled))
|
|
1005
|
+
else:
|
|
1006
|
+
assert outer_tiles_type == "enumerated"
|
|
1007
|
+
assert isinstance(inner_size, int)
|
|
1008
|
+
i = symbols_enumerated.index(outer_tiles)
|
|
1009
|
+
for outer_choice in np.unique(choices_enumerated[:, i]):
|
|
1010
|
+
partition = choices_enumerated[
|
|
1011
|
+
np.where(choices_enumerated[:, i] == outer_choice)
|
|
1012
|
+
]
|
|
1013
|
+
factorize = math.ceil(outer_choice / inner_size)
|
|
1014
|
+
factors = list(get_possible_factor_sizes(factorize, imperfect))
|
|
1015
|
+
scaled = np.array(factors) * inner_size
|
|
1016
|
+
choices.append(append_vector(partition, scaled))
|
|
1017
|
+
elif what_tiles_symbol.is_initial_tile_shape(symbol):
|
|
1018
|
+
stride = what_tiles_symbol.get_stride(symbol)
|
|
1019
|
+
delta_choices = np.array(list(what_tiles_symbol.get_delta_choices(symbol)))
|
|
1020
|
+
|
|
1021
|
+
outer_stride = what_tiles_symbol.get_outer_tiles(stride, none_if_fail=True)
|
|
1022
|
+
assert outer_stride is None or isinstance(
|
|
1023
|
+
outer_stride, int
|
|
1024
|
+
), f"outer stride is symbol {outer_stride}"
|
|
1025
|
+
if outer_stride is None:
|
|
1026
|
+
outer_size = what_tiles_symbol.get_max_size(stride)
|
|
1027
|
+
else:
|
|
1028
|
+
outer_size = outer_stride
|
|
1029
|
+
|
|
1030
|
+
if not stride in symbols_enumerated and not isinstance(stride, int):
|
|
1031
|
+
raise RuntimeError(
|
|
1032
|
+
f"BUG: stride {stride} of initial tile shape "
|
|
1033
|
+
f"{symbol} is neither enumerated nor a specified value"
|
|
1034
|
+
)
|
|
1035
|
+
|
|
1036
|
+
if isinstance(stride, int):
|
|
1037
|
+
initial_choices = delta_choices + stride
|
|
1038
|
+
initial_choices = initial_choices[initial_choices <= outer_size]
|
|
1039
|
+
choices.append(append_vector(choices_enumerated, initial_choices))
|
|
1040
|
+
else:
|
|
1041
|
+
i = symbols_enumerated.index(stride)
|
|
1042
|
+
for stride_choice in np.unique(choices_enumerated[:, i]):
|
|
1043
|
+
partition = choices_enumerated[
|
|
1044
|
+
np.where(choices_enumerated[:, i] == stride_choice)
|
|
1045
|
+
]
|
|
1046
|
+
initial_choices = delta_choices + stride_choice
|
|
1047
|
+
initial_choices = initial_choices[initial_choices <= outer_size]
|
|
1048
|
+
choices.append(append_vector(partition, initial_choices))
|
|
1049
|
+
else:
|
|
1050
|
+
raise RuntimeError(
|
|
1051
|
+
f"BUG: symbol {symbol} is neither stride nor initial tile shape"
|
|
1052
|
+
)
|
|
1053
|
+
|
|
1054
|
+
# if not partitions:
|
|
1055
|
+
# return np.array([]).reshape(-1, len(symbols))
|
|
1056
|
+
|
|
1057
|
+
prev_size = choices_enumerated.shape[0] if choices_enumerated is not None else 1
|
|
1058
|
+
choices_enumerated = np.concatenate(choices, axis=0)
|
|
1059
|
+
job.n_total_pmappings *= choices_enumerated.shape[0] / max(1, prev_size)
|
|
1060
|
+
symbols_enumerated.append(symbol)
|
|
1061
|
+
log_message("enumerate", f"{symbol}", f"size={choices_enumerated.shape[0]}")
|
|
1062
|
+
|
|
1063
|
+
# ==============================================================================
|
|
1064
|
+
# Max fused loops per rank check
|
|
1065
|
+
# ==============================================================================
|
|
1066
|
+
|
|
1067
|
+
prev_size = choices_enumerated.shape[0]
|
|
1068
|
+
choices_enumerated = check_loops(
|
|
1069
|
+
symbols_enumerated,
|
|
1070
|
+
choices_enumerated,
|
|
1071
|
+
max_loop_check_groups,
|
|
1072
|
+
what_tiles_symbol,
|
|
1073
|
+
)
|
|
1074
|
+
job.log_porp_pmappings_kept(
|
|
1075
|
+
f"max_fused_loops_per_rank_variable",
|
|
1076
|
+
choices_enumerated.shape[0] / max(1, prev_size),
|
|
1077
|
+
)
|
|
1078
|
+
log_message(
|
|
1079
|
+
"max_fused_loops_per_rank_variable", f"size={choices_enumerated.shape[0]}"
|
|
1080
|
+
)
|
|
1081
|
+
|
|
1082
|
+
# ==============================================================================
|
|
1083
|
+
# Create initial Pareto-finding goals
|
|
1084
|
+
# ==============================================================================
|
|
1085
|
+
symbol2goal = {}
|
|
1086
|
+
|
|
1087
|
+
def update_symbol2goal(
|
|
1088
|
+
symbol: Symbol, goal: Goal, s2g: dict[Symbol, Goal] = None
|
|
1089
|
+
):
|
|
1090
|
+
if s2g is None:
|
|
1091
|
+
s2g = symbol2goal
|
|
1092
|
+
s2g[symbol] = goal | s2g.get(symbol, Goal())
|
|
1093
|
+
|
|
1094
|
+
# If we're a symbol and a non-enumerated outer loop depends on us, then we need
|
|
1095
|
+
# to track this loop. Minimize it if we're imperfect (giving the outer the most
|
|
1096
|
+
# choices possible), or diff if we're perfect (since perfect constrains choices
|
|
1097
|
+
# so we can't just min).
|
|
1098
|
+
for s in symbols_enumerated:
|
|
1099
|
+
per_prime_factor = not (
|
|
1100
|
+
IMPERFECT
|
|
1101
|
+
or _get_n_prime_factors(what_tiles_symbol.get_max_size(s)) == 1
|
|
1102
|
+
)
|
|
1103
|
+
tiles = what_tiles_symbol.get_outer_tiles(s, none_if_fail=True)
|
|
1104
|
+
if isinstance(tiles, Symbol) and tiles not in symbols_enumerated:
|
|
1105
|
+
update_symbol2goal(
|
|
1106
|
+
s, Goal("min_per_prime_factor" if per_prime_factor else "min")
|
|
1107
|
+
)
|
|
1108
|
+
|
|
1109
|
+
# Same for inner loops depending on us, but maximize if we're imperfect
|
|
1110
|
+
tiled_by = what_tiles_symbol.get_inner_tiles(s, none_if_fail=True)
|
|
1111
|
+
if isinstance(tiled_by, Symbol) and tiled_by not in symbols_enumerated:
|
|
1112
|
+
update_symbol2goal(
|
|
1113
|
+
s, Goal("max_per_prime_factor" if per_prime_factor else "max")
|
|
1114
|
+
)
|
|
1115
|
+
|
|
1116
|
+
# If we need to keep this symbol, must preserve all choices for it
|
|
1117
|
+
for s in set(symbols_enumerated) & set(keep_symbols):
|
|
1118
|
+
update_symbol2goal(s, Goal("diff"))
|
|
1119
|
+
|
|
1120
|
+
symbols_non_enumerated_set = set(symbols) - set(symbols_enumerated)
|
|
1121
|
+
sym_enumerated_set = set(symbols_enumerated)
|
|
1122
|
+
|
|
1123
|
+
if job.spec.mapper.ffm._count_option_for_mapsapce_size_evaluation != ():
|
|
1124
|
+
choices_enumerated = choices_enumerated[:1, :]
|
|
1125
|
+
continue
|
|
1126
|
+
|
|
1127
|
+
choices_enumerated_float = choices_enumerated.astype(util.NUMPY_FLOAT_TYPE)
|
|
1128
|
+
|
|
1129
|
+
# ==============================================================================
|
|
1130
|
+
# Create functions to Pareto using objectives
|
|
1131
|
+
# ==============================================================================
|
|
1132
|
+
for objective in list(objectives):
|
|
1133
|
+
goals = partition_formula(
|
|
1134
|
+
objective.formula, sym_enumerated_set, what_tiles_symbol.bounds
|
|
1135
|
+
)
|
|
1136
|
+
if any(g.goal == "diff" for g in goals.values()):
|
|
1137
|
+
goals2 = partition_formula(
|
|
1138
|
+
sympy.expand(objective.formula),
|
|
1139
|
+
sym_enumerated_set,
|
|
1140
|
+
what_tiles_symbol.bounds,
|
|
1141
|
+
)
|
|
1142
|
+
goals = min(
|
|
1143
|
+
(goals, goals2),
|
|
1144
|
+
key=lambda x: sum(g.goal == "diff" for g in x.values()),
|
|
1145
|
+
)
|
|
1146
|
+
|
|
1147
|
+
# ==========================================================================
|
|
1148
|
+
# If there's a max value, then check for validity
|
|
1149
|
+
# ==========================================================================
|
|
1150
|
+
complete = objective.formula.free_symbols.issubset(sym_enumerated_set)
|
|
1151
|
+
prev_size = choices_enumerated.shape[0]
|
|
1152
|
+
if objective.max_value is not None:
|
|
1153
|
+
try:
|
|
1154
|
+
# minimize_for_objective may raise a TypeError if there's unknown
|
|
1155
|
+
# symbols
|
|
1156
|
+
result = eval_objective(
|
|
1157
|
+
objective.formula,
|
|
1158
|
+
choices_enumerated_float,
|
|
1159
|
+
minimize_formula=objective.formula,
|
|
1160
|
+
)
|
|
1161
|
+
if objective.inclusive:
|
|
1162
|
+
valid = result <= objective.max_value
|
|
1163
|
+
else:
|
|
1164
|
+
valid = result < objective.max_value
|
|
1165
|
+
if not isinstance(valid, np.ndarray):
|
|
1166
|
+
valid = (
|
|
1167
|
+
np.zeros(choices_enumerated.shape[0], dtype=bool) + valid
|
|
1168
|
+
)
|
|
1169
|
+
choices_enumerated = choices_enumerated[valid]
|
|
1170
|
+
choices_enumerated_float = choices_enumerated_float[valid]
|
|
1171
|
+
except (TypeError, ValueError):
|
|
1172
|
+
pass
|
|
1173
|
+
if objective.min_value is not None:
|
|
1174
|
+
try:
|
|
1175
|
+
# minimize_for_objective may raise a TypeError if there's unknown
|
|
1176
|
+
# symbols
|
|
1177
|
+
result = eval_objective(
|
|
1178
|
+
objective.formula,
|
|
1179
|
+
choices_enumerated_float,
|
|
1180
|
+
maximize_formula=objective.formula,
|
|
1181
|
+
)
|
|
1182
|
+
if objective.inclusive:
|
|
1183
|
+
valid = result >= objective.min_value
|
|
1184
|
+
else:
|
|
1185
|
+
valid = result > objective.min_value
|
|
1186
|
+
if not isinstance(valid, np.ndarray):
|
|
1187
|
+
valid = (
|
|
1188
|
+
np.zeros(choices_enumerated.shape[0], dtype=bool) + valid
|
|
1189
|
+
)
|
|
1190
|
+
|
|
1191
|
+
if not objective.try_best_if_none_reaches_min:
|
|
1192
|
+
choices_enumerated = choices_enumerated[valid]
|
|
1193
|
+
choices_enumerated_float = choices_enumerated_float[valid]
|
|
1194
|
+
else:
|
|
1195
|
+
if valid.any():
|
|
1196
|
+
choices_enumerated = choices_enumerated[valid]
|
|
1197
|
+
choices_enumerated_float = choices_enumerated_float[valid]
|
|
1198
|
+
elif complete:
|
|
1199
|
+
valid |= result == result.min()
|
|
1200
|
+
choices_enumerated = choices_enumerated[valid]
|
|
1201
|
+
choices_enumerated_float = choices_enumerated_float[valid]
|
|
1202
|
+
except (TypeError, ValueError):
|
|
1203
|
+
pass
|
|
1204
|
+
|
|
1205
|
+
porp = sum(valid) / max(1, choices_enumerated.shape[0])
|
|
1206
|
+
job.log_porp_pmappings_kept(
|
|
1207
|
+
f"{objective.name}",
|
|
1208
|
+
sum(valid) / max(1, prev_size),
|
|
1209
|
+
)
|
|
1210
|
+
log_message(f"Valid check", f"{objective.name}", f"porp={porp:.2%}")
|
|
1211
|
+
if complete:
|
|
1212
|
+
objective.max_value = None # We don't care anymore
|
|
1213
|
+
if objective.only_care_if_valid:
|
|
1214
|
+
objectives.remove(objective)
|
|
1215
|
+
log_message(f"Removed {objective.name} because it is always valid")
|
|
1216
|
+
goals.clear()
|
|
1217
|
+
|
|
1218
|
+
log_message(f"formula", f"{objective.formula}", f"{goals}")
|
|
1219
|
+
|
|
1220
|
+
for symbol, goal in goals.items():
|
|
1221
|
+
update_symbol2goal(symbol, goal)
|
|
1222
|
+
|
|
1223
|
+
job.n_evaluated_pmappings += choices_enumerated.shape[0]
|
|
1224
|
+
if not choices_enumerated.shape[0]:
|
|
1225
|
+
return np.array([]).reshape(-1, len(symbols))
|
|
1226
|
+
|
|
1227
|
+
if choices_enumerated.shape[0] < 100:
|
|
1228
|
+
continue
|
|
1229
|
+
|
|
1230
|
+
# ==============================================================================
|
|
1231
|
+
# Coalesce symbols. This simplifies our tracked goals. It also breaks down
|
|
1232
|
+
# partially-unknown goals into fully-known and/or fully-unknown goals.
|
|
1233
|
+
# ==============================================================================
|
|
1234
|
+
symbol2goal = coalesce_symbols(
|
|
1235
|
+
symbols_enumerated=symbols_enumerated,
|
|
1236
|
+
symbol2goal=symbol2goal,
|
|
1237
|
+
update_symbol2goal=update_symbol2goal,
|
|
1238
|
+
log_message=log_message,
|
|
1239
|
+
bounds=what_tiles_symbol.bounds,
|
|
1240
|
+
)
|
|
1241
|
+
|
|
1242
|
+
log_message("coalesce symbols", f"{symbol2goal}")
|
|
1243
|
+
|
|
1244
|
+
paretoed_by_key = fzs((f, g.goal) for f, g in symbol2goal.items())
|
|
1245
|
+
if any(p.issubset(paretoed_by_key) for p in paretoed_by):
|
|
1246
|
+
job.log_message(
|
|
1247
|
+
"Skipping Pareto because we've already found a Pareto with these objectives."
|
|
1248
|
+
)
|
|
1249
|
+
continue
|
|
1250
|
+
paretoed_by.append(paretoed_by_key)
|
|
1251
|
+
|
|
1252
|
+
objective_values = {}
|
|
1253
|
+
for formula, goal in list(symbol2goal.items()):
|
|
1254
|
+
objective_values[formula] = eval_objective(
|
|
1255
|
+
formula, choices_enumerated_float
|
|
1256
|
+
)
|
|
1257
|
+
symbol2goal[formula] = goal
|
|
1258
|
+
log_message("eval", f"{goal.goal}", f"{formula}")
|
|
1259
|
+
|
|
1260
|
+
if not objective_values:
|
|
1261
|
+
# Objective values don't depend on tile shapes
|
|
1262
|
+
choices_enumerated = choices_enumerated[:1, :]
|
|
1263
|
+
choices_enumerated_float = choices_enumerated_float[:1, :]
|
|
1264
|
+
|
|
1265
|
+
elif not all(
|
|
1266
|
+
symbol2goal.get(s, None) == Goal("diff") for s in symbols_enumerated
|
|
1267
|
+
):
|
|
1268
|
+
to_pareto = np.concatenate(
|
|
1269
|
+
[v.reshape(-1, 1) for v in objective_values.values()], axis=1
|
|
1270
|
+
)
|
|
1271
|
+
log_message("Pareto", f"size {to_pareto.shape[0]}", "with objectives:")
|
|
1272
|
+
for obj in objectives:
|
|
1273
|
+
log_message(f"\t{obj.name}: {obj.formula}")
|
|
1274
|
+
log_message("Formulas:")
|
|
1275
|
+
for formula, goal in symbol2goal.items():
|
|
1276
|
+
log_message(f"\t{goal.goal}: {formula}")
|
|
1277
|
+
|
|
1278
|
+
drop_cols = []
|
|
1279
|
+
pareto_goals = []
|
|
1280
|
+
for i, (formula, goal) in enumerate(objective_values.items()):
|
|
1281
|
+
goal = symbol2goal[formula]
|
|
1282
|
+
if i not in drop_cols:
|
|
1283
|
+
pareto_goals.append(goal.goal)
|
|
1284
|
+
to_pareto = to_pareto[
|
|
1285
|
+
:, [i for i in range(to_pareto.shape[1]) if i not in drop_cols]
|
|
1286
|
+
]
|
|
1287
|
+
keep = makepareto_numpy(to_pareto, pareto_goals, dirty=True)
|
|
1288
|
+
prev_size = choices_enumerated.shape[0]
|
|
1289
|
+
choices_enumerated = choices_enumerated[keep]
|
|
1290
|
+
job.log_porp_pmappings_kept(
|
|
1291
|
+
f"Pareto", sum(keep) / choices_enumerated.shape[0]
|
|
1292
|
+
)
|
|
1293
|
+
log_message("pareto", f"size {prev_size} -> {choices_enumerated.shape[0]}")
|
|
1294
|
+
|
|
1295
|
+
# ==================================================================================
|
|
1296
|
+
# Return the choices
|
|
1297
|
+
# ==================================================================================
|
|
1298
|
+
t = time.time() - start_time
|
|
1299
|
+
if t > 60:
|
|
1300
|
+
a = [
|
|
1301
|
+
f"Total time: {t:.2f}s",
|
|
1302
|
+
f"Pmapping: {job.mapping.compact_str()}",
|
|
1303
|
+
]
|
|
1304
|
+
print("\n\t" + f"\n\t".join(a + job.messages))
|
|
1305
|
+
|
|
1306
|
+
# Rearrange in tile shape order
|
|
1307
|
+
if choices_enumerated is None:
|
|
1308
|
+
return np.array([])
|
|
1309
|
+
return choices_enumerated[:, [symbols_enumerated.index(s) for s in symbols]]
|
|
1310
|
+
|
|
1311
|
+
|
|
1312
|
+
def makesymbol(name: str):
|
|
1313
|
+
# TODO: Do the solve() calls work with integer=True?
|
|
1314
|
+
return Symbol(name, positive=True, integer=True)
|
|
1315
|
+
|
|
1316
|
+
|
|
1317
|
+
def make_keep_symbols(pmapping: Mapping) -> set[Symbol]:
|
|
1318
|
+
keep_symbols = set()
|
|
1319
|
+
for node in pmapping.nodes:
|
|
1320
|
+
if isinstance(node, Loop) and node._fused:
|
|
1321
|
+
if isinstance(node.initial_tile_shape, Symbol):
|
|
1322
|
+
keep_symbols.add(node.initial_tile_shape)
|
|
1323
|
+
if isinstance(node.tile_shape, Symbol):
|
|
1324
|
+
keep_symbols.add(node.tile_shape)
|
|
1325
|
+
return keep_symbols
|
|
1326
|
+
|
|
1327
|
+
|
|
1328
|
+
def get_rank_var_to_fused_loops(
|
|
1329
|
+
pmapping: Mapping, shape: dict[str, int]
|
|
1330
|
+
) -> dict[str, list[Symbol]]:
|
|
1331
|
+
rank_var_to_fused_loops: dict[str, list[Symbol]] = {}
|
|
1332
|
+
for node in [n for n in pmapping.nodes if isinstance(n, Loop) and n._fused]:
|
|
1333
|
+
rank_var_to_fused_loops.setdefault(node.rank_variable, []).append(
|
|
1334
|
+
node.tile_shape
|
|
1335
|
+
)
|
|
1336
|
+
return rank_var_to_fused_loops
|
|
1337
|
+
|
|
1338
|
+
|
|
1339
|
+
def set_last_tile_shape_to_one(pmapping):
|
|
1340
|
+
pmapping = pmapping.nodes
|
|
1341
|
+
|
|
1342
|
+
rank_var_to_last_node = {}
|
|
1343
|
+
for node in pmapping:
|
|
1344
|
+
if isinstance(node, Temporal) or isinstance(node, Spatial):
|
|
1345
|
+
rank_var_to_last_node[node.rank_variable] = node
|
|
1346
|
+
|
|
1347
|
+
for last_node in rank_var_to_last_node.values():
|
|
1348
|
+
last_node.initial_tile_shape = None
|
|
1349
|
+
last_node.tile_shape = 1
|
|
1350
|
+
|
|
1351
|
+
|
|
1352
|
+
# This was made only so we could do some counting of the time.
|
|
1353
|
+
def call_compiled_objective(f, *args):
|
|
1354
|
+
return f(*args)
|
|
1355
|
+
|
|
1356
|
+
|
|
1357
|
+
def _make_tile_shapes(job: "Job"):
|
|
1358
|
+
# We're going to convert the job into a list of symbols and objectives
|
|
1359
|
+
pmapping = job.mapping
|
|
1360
|
+
constraints = job.constraints
|
|
1361
|
+
constraints.set_loop_indices(pmapping.nodes)
|
|
1362
|
+
set_last_tile_shape_to_one(pmapping)
|
|
1363
|
+
t0 = time.time()
|
|
1364
|
+
(
|
|
1365
|
+
symbols,
|
|
1366
|
+
symbolic_df,
|
|
1367
|
+
per_memory_usage_df,
|
|
1368
|
+
usage_df,
|
|
1369
|
+
tensor2mapping,
|
|
1370
|
+
) = run_model(job)
|
|
1371
|
+
|
|
1372
|
+
model_time = time.time() - t0
|
|
1373
|
+
shape = job.rank_variable_bounds
|
|
1374
|
+
what_tiles_symbol = SymbolRelations.from_pmapping_and_shape(
|
|
1375
|
+
pmapping, shape, job.spec.workload
|
|
1376
|
+
)
|
|
1377
|
+
keep_symbols = make_keep_symbols(pmapping)
|
|
1378
|
+
rank_var_to_fused_loops = get_rank_var_to_fused_loops(pmapping, shape)
|
|
1379
|
+
all_fused_loops = set(sum(rank_var_to_fused_loops.values(), []))
|
|
1380
|
+
|
|
1381
|
+
objectives = []
|
|
1382
|
+
|
|
1383
|
+
# ==================================================================================
|
|
1384
|
+
# Loop bounds constraints. Put these before the other objectives so that hopefully
|
|
1385
|
+
# if 100% of the pmappings are pruned, then we're given the actual architecture
|
|
1386
|
+
# component that caused it and not the loop bound constraint.
|
|
1387
|
+
# ==================================================================================
|
|
1388
|
+
loops = [n for n in pmapping.nodes if isinstance(n, Loop)]
|
|
1389
|
+
for c in constraints.loop_bounds_constraints:
|
|
1390
|
+
min_value, max_value, inclusive = None, None, True
|
|
1391
|
+
is_product = "product" in c.constraint.operator
|
|
1392
|
+
operator = c.constraint.operator.replace("product", "")
|
|
1393
|
+
if operator in ["==", "<=", "<"]:
|
|
1394
|
+
max_value = c.constraint.value
|
|
1395
|
+
if operator in [">=", ">", "=="]:
|
|
1396
|
+
min_value = c.constraint.value
|
|
1397
|
+
if operator in ["<", ">"]:
|
|
1398
|
+
inclusive = False
|
|
1399
|
+
|
|
1400
|
+
targets = []
|
|
1401
|
+
for i in c._target_loop_indices:
|
|
1402
|
+
n = loops[i]
|
|
1403
|
+
size = what_tiles_symbol.get_outer_tiles(n.tile_shape, none_if_fail=True)
|
|
1404
|
+
if size is None:
|
|
1405
|
+
size = what_tiles_symbol.get_max_size(n.tile_shape)
|
|
1406
|
+
targets.append(size / n.tile_shape)
|
|
1407
|
+
|
|
1408
|
+
# targets = [loops[i]._calculated_n_iterations for i in c._target_loop_indices]
|
|
1409
|
+
if not targets:
|
|
1410
|
+
continue
|
|
1411
|
+
|
|
1412
|
+
if is_product:
|
|
1413
|
+
targets = [sympy.Mul(*targets)]
|
|
1414
|
+
|
|
1415
|
+
if max_value is None and min_value is not None:
|
|
1416
|
+
max_value = -min_value
|
|
1417
|
+
targets = [-target for target in targets]
|
|
1418
|
+
min_value = None
|
|
1419
|
+
|
|
1420
|
+
for target in targets:
|
|
1421
|
+
objectives.append(
|
|
1422
|
+
Objective(
|
|
1423
|
+
name=f"loop_bounds_{c.constraint}",
|
|
1424
|
+
formula=target,
|
|
1425
|
+
symbols=symbols,
|
|
1426
|
+
only_care_if_valid=True,
|
|
1427
|
+
max_value=max_value,
|
|
1428
|
+
min_value=min_value,
|
|
1429
|
+
inclusive=inclusive,
|
|
1430
|
+
)
|
|
1431
|
+
)
|
|
1432
|
+
|
|
1433
|
+
# ==================================================================================
|
|
1434
|
+
# Memory usage and usage constraints.
|
|
1435
|
+
# ==================================================================================
|
|
1436
|
+
for k, v in {**per_memory_usage_df, **usage_df}.items():
|
|
1437
|
+
# If we only track for pmappings, we only care if it's valid. If we track for
|
|
1438
|
+
# all, we care about the value too.
|
|
1439
|
+
|
|
1440
|
+
only_care_if_valid = False
|
|
1441
|
+
if k in job.memories_track_pmappings_only:
|
|
1442
|
+
only_care_if_valid = True
|
|
1443
|
+
|
|
1444
|
+
# TODO: Update check to see if we may be sharing usage with other
|
|
1445
|
+
# pmappings in parallel/pipeline.
|
|
1446
|
+
if k in usage_df:
|
|
1447
|
+
only_care_if_valid = True
|
|
1448
|
+
|
|
1449
|
+
objectives.append(
|
|
1450
|
+
Objective(
|
|
1451
|
+
name=k,
|
|
1452
|
+
formula=v,
|
|
1453
|
+
symbols=symbols,
|
|
1454
|
+
only_care_if_valid=only_care_if_valid,
|
|
1455
|
+
max_value=1,
|
|
1456
|
+
)
|
|
1457
|
+
)
|
|
1458
|
+
|
|
1459
|
+
# ==================================================================================
|
|
1460
|
+
# Min usage constraints. Put this last because it has some try best if none reach
|
|
1461
|
+
# min logic.
|
|
1462
|
+
# ==================================================================================
|
|
1463
|
+
for (
|
|
1464
|
+
component_name,
|
|
1465
|
+
name,
|
|
1466
|
+
), constraint in job.constraints.min_usage_constraints.items():
|
|
1467
|
+
objectives.append(
|
|
1468
|
+
Objective(
|
|
1469
|
+
name=f"min_usage_{component_name}_{name}",
|
|
1470
|
+
formula=v,
|
|
1471
|
+
symbols=symbols,
|
|
1472
|
+
only_care_if_valid=True,
|
|
1473
|
+
min_value=constraint.min_usage,
|
|
1474
|
+
try_best_if_none_reaches_min=True,
|
|
1475
|
+
)
|
|
1476
|
+
)
|
|
1477
|
+
|
|
1478
|
+
for k, v in symbolic_df.items():
|
|
1479
|
+
if "Total" not in k:
|
|
1480
|
+
continue
|
|
1481
|
+
|
|
1482
|
+
objectives.append(
|
|
1483
|
+
Objective(
|
|
1484
|
+
name=k,
|
|
1485
|
+
formula=v,
|
|
1486
|
+
symbols=symbols,
|
|
1487
|
+
)
|
|
1488
|
+
)
|
|
1489
|
+
|
|
1490
|
+
rank2symbols = {}
|
|
1491
|
+
for node in pmapping.nodes:
|
|
1492
|
+
if isinstance(node, (Temporal, Spatial)):
|
|
1493
|
+
if node.tile_shape in symbols:
|
|
1494
|
+
rank2symbols.setdefault(node.rank_variable, []).append(node.tile_shape)
|
|
1495
|
+
|
|
1496
|
+
max_loop_check_groups = [
|
|
1497
|
+
(job.spec.mapper.ffm.max_fused_loops, all_fused_loops),
|
|
1498
|
+
*[
|
|
1499
|
+
(job.spec.mapper.ffm.max_fused_loops_per_rank_variable, x)
|
|
1500
|
+
for x in rank_var_to_fused_loops.values()
|
|
1501
|
+
],
|
|
1502
|
+
]
|
|
1503
|
+
|
|
1504
|
+
max_loop_check_groups = [g for g in max_loop_check_groups if g[1]]
|
|
1505
|
+
|
|
1506
|
+
choices_enumerated = get_tile_shape_choices(
|
|
1507
|
+
objectives=objectives,
|
|
1508
|
+
symbols=symbols,
|
|
1509
|
+
what_tiles_symbol=what_tiles_symbol,
|
|
1510
|
+
job=job,
|
|
1511
|
+
keep_symbols=keep_symbols,
|
|
1512
|
+
max_loop_check_groups=max_loop_check_groups,
|
|
1513
|
+
)
|
|
1514
|
+
|
|
1515
|
+
try:
|
|
1516
|
+
compiled_df = compile_dict(symbols, symbolic_df)
|
|
1517
|
+
compiled_per_memory_usage_df = compile_dict(symbols, per_memory_usage_df)
|
|
1518
|
+
compiled_usage_df = compile_dict(symbols, usage_df)
|
|
1519
|
+
except Exception as e:
|
|
1520
|
+
print("Compilation failed for this mapping:")
|
|
1521
|
+
for node in pmapping.nodes:
|
|
1522
|
+
if hasattr(node, "compact_str"):
|
|
1523
|
+
print(node.compact_str())
|
|
1524
|
+
print(symbolic_df)
|
|
1525
|
+
e.add_note("Compilation failed")
|
|
1526
|
+
raise
|
|
1527
|
+
|
|
1528
|
+
choices_float = choices_enumerated.astype(util.NUMPY_FLOAT_TYPE)
|
|
1529
|
+
# choices_float = np.tile(choices_float, (1000000, 1))
|
|
1530
|
+
# choices_enumerated = np.tile(choices_enumerated, (1000000, 1))
|
|
1531
|
+
|
|
1532
|
+
df = {}
|
|
1533
|
+
for i, symbol in enumerate(symbols):
|
|
1534
|
+
df[symbol.name] = choices_enumerated[:, i]
|
|
1535
|
+
|
|
1536
|
+
t0 = time.time()
|
|
1537
|
+
for key in compiled_df:
|
|
1538
|
+
df[key] = call_compiled_objective(compiled_df[key], *choices_float.T)
|
|
1539
|
+
if "latency" in key and "first_latency" not in key:
|
|
1540
|
+
val = [df[key]] if isinstance(df[key], Number) else df[key]
|
|
1541
|
+
if any(l < 0 for l in val):
|
|
1542
|
+
raise ValueError(f"Negative latency for {key}: {val}")
|
|
1543
|
+
if "energy" in key:
|
|
1544
|
+
val = [df[key]] if isinstance(df[key], Number) else df[key]
|
|
1545
|
+
if any(l < 0 for l in val):
|
|
1546
|
+
raise ValueError(f"Negative energy for {key}: {val}")
|
|
1547
|
+
|
|
1548
|
+
# Some initial tile shapes are invalid
|
|
1549
|
+
for nloops, n in enumerate(
|
|
1550
|
+
node for node in job.mapping.nodes if isinstance(node, Loop) and node._fused
|
|
1551
|
+
):
|
|
1552
|
+
stride = n.tile_pattern.tile_shape
|
|
1553
|
+
initial = (
|
|
1554
|
+
n.tile_pattern.initial_tile_shape
|
|
1555
|
+
if n.tile_pattern.initial_tile_shape is not None
|
|
1556
|
+
else stride
|
|
1557
|
+
)
|
|
1558
|
+
outer_stride = what_tiles_symbol.get_outer_tiles(stride)
|
|
1559
|
+
outer_initial = what_tiles_symbol.get_initial(outer_stride, none_if_fail=True)
|
|
1560
|
+
outer_stride = (
|
|
1561
|
+
df[outer_stride.name] if isinstance(outer_stride, Symbol) else outer_stride
|
|
1562
|
+
)
|
|
1563
|
+
|
|
1564
|
+
outer_initial = (
|
|
1565
|
+
df[outer_initial.name]
|
|
1566
|
+
if isinstance(outer_initial, Symbol)
|
|
1567
|
+
else outer_stride
|
|
1568
|
+
)
|
|
1569
|
+
|
|
1570
|
+
rank_var_stride = df[stride.name] if isinstance(stride, Symbol) else stride
|
|
1571
|
+
rank_var_initial = df[initial.name] if isinstance(initial, Symbol) else initial
|
|
1572
|
+
|
|
1573
|
+
# NOTE: The concept of having one "n_iterations" is precarious when imperfect factorization in involved
|
|
1574
|
+
df[iterations2col(nloops)] = np.ceil(
|
|
1575
|
+
(outer_initial - rank_var_initial) / rank_var_stride + 1
|
|
1576
|
+
)
|
|
1577
|
+
df[f"lower_iterations<SEP>{nloops}"] = outer_stride - rank_var_initial
|
|
1578
|
+
|
|
1579
|
+
# Generate rank columns
|
|
1580
|
+
einsum: Einsum = job.spec.workload.einsums[job.einsum_name]
|
|
1581
|
+
for tensor_access in einsum.tensor_accesses:
|
|
1582
|
+
tensor = tensor_access.name
|
|
1583
|
+
projections = get_projection_expr(einsum, tensor)
|
|
1584
|
+
for rank, expr in projections.items():
|
|
1585
|
+
free_symbols = tuple(expr.free_symbols)
|
|
1586
|
+
free_symbols_str = tuple(symbol.name for symbol in free_symbols)
|
|
1587
|
+
if n.rank_variable not in free_symbols_str:
|
|
1588
|
+
continue
|
|
1589
|
+
|
|
1590
|
+
rank_stride = expr.coeff(n.rank_variable) * rank_var_stride
|
|
1591
|
+
|
|
1592
|
+
args = []
|
|
1593
|
+
for free_rank_var in free_symbols:
|
|
1594
|
+
if free_rank_var.name == n.rank_variable:
|
|
1595
|
+
args.append(rank_var_initial)
|
|
1596
|
+
else:
|
|
1597
|
+
args.append(shape[free_rank_var.name])
|
|
1598
|
+
rank_initial = lambdify(free_symbols, expr)(*args)
|
|
1599
|
+
|
|
1600
|
+
df[stride2col(rank, nloops)] = rank_stride
|
|
1601
|
+
df[initial2col(rank, nloops)] = rank_initial
|
|
1602
|
+
|
|
1603
|
+
try:
|
|
1604
|
+
df = pd.DataFrame(df, columns=df.keys())
|
|
1605
|
+
except ValueError as e:
|
|
1606
|
+
df = pd.DataFrame(df, columns=df.keys(), index=[0])
|
|
1607
|
+
assert not df.isna().any().any()
|
|
1608
|
+
|
|
1609
|
+
energy_cols = [c for c in df.columns if "Total<SEP>energy" in c]
|
|
1610
|
+
if (df[energy_cols] < 0).any(axis=None):
|
|
1611
|
+
mapping_with_negative_energy = df[(df[energy_cols] < 0).any(axis=1)]
|
|
1612
|
+
print(df.columns)
|
|
1613
|
+
msg = ""
|
|
1614
|
+
for _, row in mapping_with_negative_energy.iterrows():
|
|
1615
|
+
for k, v in row.items():
|
|
1616
|
+
msg += f"{k}: {v}\n"
|
|
1617
|
+
msg += "\n"
|
|
1618
|
+
raise RuntimeError(f"negative energy:\n{msg}")
|
|
1619
|
+
|
|
1620
|
+
job.n_valid_pmappings = job.n_total_pmappings * prod(
|
|
1621
|
+
job.pmapping_keep_rates.values()
|
|
1622
|
+
)
|
|
1623
|
+
return df, tensor2mapping
|
|
1624
|
+
|
|
1625
|
+
|
|
1626
|
+
def make_tile_shapes(job: "Job"):
|
|
1627
|
+
memory_limit = job.memory_limit // 8 # Bytes -> bits
|
|
1628
|
+
if job.memory_limit != float("inf"):
|
|
1629
|
+
try:
|
|
1630
|
+
resource.setrlimit(resource.RLIMIT_AS, (job.memory_limit, job.memory_limit))
|
|
1631
|
+
except (ValueError, OSError):
|
|
1632
|
+
# Ignore permission errors when trying to set memory limits
|
|
1633
|
+
pass
|
|
1634
|
+
|
|
1635
|
+
if job.time_limit != float("inf"):
|
|
1636
|
+
try:
|
|
1637
|
+
resource.setrlimit(
|
|
1638
|
+
resource.RLIMIT_CPU, (ceil(job.time_limit), ceil(job.time_limit))
|
|
1639
|
+
)
|
|
1640
|
+
except (ValueError, OSError):
|
|
1641
|
+
# Ignore permission errors when trying to set CPU limits
|
|
1642
|
+
pass
|
|
1643
|
+
|
|
1644
|
+
def format_memory_limit() -> str:
|
|
1645
|
+
if memory_limit == float("inf"):
|
|
1646
|
+
return "infinite"
|
|
1647
|
+
if memory_limit > 1024 * 1024 * 1024:
|
|
1648
|
+
return f"{memory_limit / (1024 * 1024 * 1024):.2f} GB"
|
|
1649
|
+
elif memory_limit > 1024 * 1024:
|
|
1650
|
+
return f"{memory_limit / (1024 * 1024):.2f} MB"
|
|
1651
|
+
elif memory_limit > 1024:
|
|
1652
|
+
return f"{memory_limit / 1024:.2f} KB"
|
|
1653
|
+
else:
|
|
1654
|
+
return f"{memory_limit:.2f} B"
|
|
1655
|
+
|
|
1656
|
+
try:
|
|
1657
|
+
return _make_tile_shapes(job)
|
|
1658
|
+
except MemoryError as e:
|
|
1659
|
+
s = f"Job ran out of memory with memory limit {format_memory_limit()}"
|
|
1660
|
+
job.log_message(f"Tile shape exploration failed: {s}")
|
|
1661
|
+
raise RuntimeError(job.pretty_str()) from e
|
|
1662
|
+
except TimeoutError as e:
|
|
1663
|
+
s = f"Job timed out with time limit {job.time_limit:.2f} seconds"
|
|
1664
|
+
job.log_message(f"Tile shape exploration failed: {s}")
|
|
1665
|
+
raise RuntimeError(job.pretty_str()) from e
|
|
1666
|
+
|
|
1667
|
+
finally:
|
|
1668
|
+
try:
|
|
1669
|
+
resource.setrlimit(
|
|
1670
|
+
resource.RLIMIT_AS, (resource.RLIM_INFINITY, resource.RLIM_INFINITY)
|
|
1671
|
+
)
|
|
1672
|
+
except (ValueError, OSError):
|
|
1673
|
+
# Ignore permission errors when trying to reset memory limits
|
|
1674
|
+
pass
|
|
1675
|
+
try:
|
|
1676
|
+
resource.setrlimit(
|
|
1677
|
+
resource.RLIMIT_CPU, (resource.RLIM_INFINITY, resource.RLIM_INFINITY)
|
|
1678
|
+
)
|
|
1679
|
+
except (ValueError, OSError):
|
|
1680
|
+
# Ignore permission errors when trying to reset CPU limits
|
|
1681
|
+
pass
|