accelforge-0.0.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- accelforge/__init__.py +21 -0
- accelforge/_accelerated_imports.py +16 -0
- accelforge/_deprecate/_simanneal/evalmapping.py +271 -0
- accelforge/_deprecate/_simanneal/mapspaceglobals.py +298 -0
- accelforge/_deprecate/_simanneal/simanneal.py +666 -0
- accelforge/_deprecate/_simanneal/tracking.py +105 -0
- accelforge/_deprecate/_simanneal/wrappers.py +218 -0
- accelforge/_deprecate/_simanneal2/__init__.py +7 -0
- accelforge/_deprecate/_simanneal2/simanneal.py +493 -0
- accelforge/_deprecate/_simanneal2/tracking.py +116 -0
- accelforge/_deprecate/compatibility_util.py +181 -0
- accelforge/_deprecate/layerdeduplication/__init__.py +2 -0
- accelforge/_deprecate/layerdeduplication/group_similar_einsums.py +160 -0
- accelforge/_deprecate/layerdeduplication/grouped_einsums.py +84 -0
- accelforge/_deprecate/mapping_filter_tags/__init__.py +2 -0
- accelforge/_deprecate/mapping_filter_tags/ffmt.py +212 -0
- accelforge/_deprecate/mapping_filter_tags/onesplit.py +24 -0
- accelforge/_deprecate/mapping_filter_tags/util.py +24 -0
- accelforge/_deprecate/tags.py +69 -0
- accelforge/_deprecate/viz/__init__.py +0 -0
- accelforge/_deprecate/viz/interactive.py +159 -0
- accelforge/_deprecate/viz/reservationtree.py +307 -0
- accelforge/_deprecate/viz/ski_slope.py +88 -0
- accelforge/_version.py +15 -0
- accelforge/examples.py +39 -0
- accelforge/frontend/__init__.py +10 -0
- accelforge/frontend/_binding.py +129 -0
- accelforge/frontend/_workload_isl/__init__.py +2 -0
- accelforge/frontend/_workload_isl/_isl.py +149 -0
- accelforge/frontend/_workload_isl/_symbolic.py +141 -0
- accelforge/frontend/arch copy.py +1544 -0
- accelforge/frontend/arch.py +1642 -0
- accelforge/frontend/config.py +63 -0
- accelforge/frontend/mapper/__init__.py +5 -0
- accelforge/frontend/mapper/ffm.py +126 -0
- accelforge/frontend/mapper/mapper.py +7 -0
- accelforge/frontend/mapper/metrics.py +30 -0
- accelforge/frontend/mapping/__init__.py +1 -0
- accelforge/frontend/mapping/mapping.py +1736 -0
- accelforge/frontend/model.py +14 -0
- accelforge/frontend/renames.py +150 -0
- accelforge/frontend/spec copy.py +230 -0
- accelforge/frontend/spec.py +301 -0
- accelforge/frontend/variables.py +12 -0
- accelforge/frontend/workload.py +952 -0
- accelforge/mapper/FFM/__init__.py +9 -0
- accelforge/mapper/FFM/_join_pmappings/__init__.py +0 -0
- accelforge/mapper/FFM/_join_pmappings/compatibility.py +653 -0
- accelforge/mapper/FFM/_join_pmappings/compress_pmappings.py +140 -0
- accelforge/mapper/FFM/_join_pmappings/join_pmappings.py +703 -0
- accelforge/mapper/FFM/_join_pmappings/pmapping_dataframe.py +901 -0
- accelforge/mapper/FFM/_join_pmappings/pmapping_group.py +337 -0
- accelforge/mapper/FFM/_make_pmappings/contraints/__init__.py +0 -0
- accelforge/mapper/FFM/_make_pmappings/contraints/constraints.py +360 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/__init__.py +1 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_loops.py +373 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_pmapping_templates.py +463 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_reservations.py +95 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_storage_order.py +382 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_storages.py +155 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmappings.py +411 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/__init__.py +1 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/make_pmappings_from_templates.py +407 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/make_tile_shapes.py +1681 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/run_model.py +170 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/symbol_relations.py +174 -0
- accelforge/mapper/FFM/_make_pmappings/pmapper_job.py +282 -0
- accelforge/mapper/FFM/_pareto_df/df_convention.py +273 -0
- accelforge/mapper/FFM/_pareto_df/pareto copy.py +836 -0
- accelforge/mapper/FFM/_pareto_df/pareto.py +508 -0
- accelforge/mapper/FFM/data.py +61 -0
- accelforge/mapper/FFM/main copy.py +236 -0
- accelforge/mapper/FFM/main.py +208 -0
- accelforge/mapper/FFM/mappings.py +510 -0
- accelforge/mapper/FFM/pmappings.py +310 -0
- accelforge/mapper/__init__.py +4 -0
- accelforge/mapper.py +0 -0
- accelforge/model/__init__.py +1 -0
- accelforge/model/_looptree/__init__.py +0 -0
- accelforge/model/_looptree/accesses.py +335 -0
- accelforge/model/_looptree/capacity/__init__.py +1 -0
- accelforge/model/_looptree/capacity/aggregators.py +36 -0
- accelforge/model/_looptree/capacity/capacity.py +47 -0
- accelforge/model/_looptree/energy.py +150 -0
- accelforge/model/_looptree/equivalent_ranks.py +29 -0
- accelforge/model/_looptree/latency/__init__.py +1 -0
- accelforge/model/_looptree/latency/latency.py +98 -0
- accelforge/model/_looptree/latency/memory.py +120 -0
- accelforge/model/_looptree/latency/processors.py +92 -0
- accelforge/model/_looptree/mapping_utilities.py +71 -0
- accelforge/model/_looptree/reuse/__init__.py +4 -0
- accelforge/model/_looptree/reuse/isl/__init__.py +1 -0
- accelforge/model/_looptree/reuse/isl/des.py +59 -0
- accelforge/model/_looptree/reuse/isl/isl_functions.py +374 -0
- accelforge/model/_looptree/reuse/isl/mapping_to_isl/__init__.py +4 -0
- accelforge/model/_looptree/reuse/isl/mapping_to_isl/analyze_mapping.py +297 -0
- accelforge/model/_looptree/reuse/isl/mapping_to_isl/skews_from_mapping.py +236 -0
- accelforge/model/_looptree/reuse/isl/mapping_to_isl/tiling.py +685 -0
- accelforge/model/_looptree/reuse/isl/mapping_to_isl/types.py +188 -0
- accelforge/model/_looptree/reuse/isl/spatial.py +260 -0
- accelforge/model/_looptree/reuse/isl/temporal.py +182 -0
- accelforge/model/_looptree/reuse/symbolic/__init__.py +1 -0
- accelforge/model/_looptree/reuse/symbolic/symbolic copy 2.py +1346 -0
- accelforge/model/_looptree/reuse/symbolic/symbolic copy.py +1408 -0
- accelforge/model/_looptree/reuse/symbolic/symbolic.py +1396 -0
- accelforge/model/_looptree/run.py +122 -0
- accelforge/model/_looptree/types.py +26 -0
- accelforge/model/_looptree/visualization/__init__.py +0 -0
- accelforge/model/_looptree/visualization/occupancy.py +11 -0
- accelforge/model/main.py +222 -0
- accelforge/plotting/__init__.py +2 -0
- accelforge/plotting/mappings.py +219 -0
- accelforge/plotting/specs.py +57 -0
- accelforge/util/__init__.py +4 -0
- accelforge/util/_base_analysis_types.py +24 -0
- accelforge/util/_basetypes.py +1089 -0
- accelforge/util/_frozenset.py +36 -0
- accelforge/util/_isl.py +29 -0
- accelforge/util/_itertools.py +14 -0
- accelforge/util/_mathfuncs.py +57 -0
- accelforge/util/_parse_expressions.py +339 -0
- accelforge/util/_picklecache.py +32 -0
- accelforge/util/_setexpressions.py +268 -0
- accelforge/util/_sympy/__init__.py +0 -0
- accelforge/util/_sympy/broadcast_max.py +18 -0
- accelforge/util/_visualization.py +112 -0
- accelforge/util/_yaml.py +579 -0
- accelforge/util/parallel.py +193 -0
- accelforge-0.0.1.dist-info/METADATA +64 -0
- accelforge-0.0.1.dist-info/RECORD +258 -0
- accelforge-0.0.1.dist-info/WHEEL +5 -0
- accelforge-0.0.1.dist-info/licenses/LICENSE +19 -0
- accelforge-0.0.1.dist-info/top_level.txt +5 -0
- docs/_build/html/_sources/fastfusion.frontend.mapper.rst.txt +37 -0
- docs/_build/html/_sources/fastfusion.frontend.rst.txt +70 -0
- docs/_build/html/_sources/fastfusion.frontend.workload.rst.txt +21 -0
- docs/_build/html/_sources/fastfusion.mapper.FFM.rst.txt +37 -0
- docs/_build/html/_sources/fastfusion.mapper.rst.txt +18 -0
- docs/_build/html/_sources/fastfusion.rst.txt +20 -0
- docs/_build/html/_sources/fastfusion.util.rst.txt +21 -0
- docs/_build/html/_sources/index.rst.txt +87 -0
- docs/_build/html/_sources/modules.rst.txt +7 -0
- docs/_build/html/_sources/notes/citation.rst.txt +45 -0
- docs/_build/html/_sources/notes/definitions.rst.txt +43 -0
- docs/_build/html/_sources/notes/faqs.rst.txt +39 -0
- docs/_build/html/_sources/notes/modeling/accelerator_energy_latency.rst.txt +72 -0
- docs/_build/html/_sources/notes/modeling/component_energy_area.rst.txt +96 -0
- docs/_build/html/_sources/notes/modeling/mapping.rst.txt +100 -0
- docs/_build/html/_sources/notes/modeling.rst.txt +33 -0
- docs/_build/html/_sources/notes/parsing/arithmetic_parsing.rst.txt +136 -0
- docs/_build/html/_sources/notes/parsing/setexpressions.rst.txt +63 -0
- docs/_build/html/_sources/notes/parsing/yaml_parsing.rst.txt +176 -0
- docs/_build/html/_sources/notes/quickstart_and_installation.rst.txt +9 -0
- docs/_build/html/_sources/notes/spec/architecture.rst.txt +133 -0
- docs/_build/html/_sources/notes/spec/mapping.rst.txt +12 -0
- docs/_build/html/_sources/notes/spec/workload.rst.txt +83 -0
- docs/_build/html/_sources/notes/spec.rst.txt +36 -0
- docs/source/_ext/include_attrs.py +213 -0
- docs/source/_ext/include_docstring.py +364 -0
- docs/source/_ext/include_functions.py +154 -0
- docs/source/_ext/include_notebook.py +131 -0
- docs/source/_ext/include_yaml.py +119 -0
- docs/source/_ext/inherited_attributes.py +222 -0
- docs/source/_ext/paths.py +4 -0
- docs/source/conf.py +79 -0
- examples/arches/compute_in_memory/_include.yaml +74 -0
- examples/arches/compute_in_memory/_include_functions.py +229 -0
- examples/arches/compute_in_memory/_load_spec.py +57 -0
- examples/arches/compute_in_memory/components/c2c_multiplier.py +181 -0
- examples/arches/compute_in_memory/components/dac_c2c_r2r.py +605 -0
- examples/arches/compute_in_memory/components/misc.py +195 -0
- examples/arches/compute_in_memory/components/util/bit_functions.py +51 -0
- examples/arches/compute_in_memory/components/zero_comparator.py +92 -0
- examples/arches/compute_in_memory/isaac.yaml +233 -0
- examples/arches/compute_in_memory/memory_cells/ecram_demo.yaml +63 -0
- examples/arches/compute_in_memory/memory_cells/rram_example.yaml +63 -0
- examples/arches/compute_in_memory/memory_cells/rram_isaac_isca_2016.yaml +64 -0
- examples/arches/compute_in_memory/memory_cells/rram_neurosim_default.yaml +63 -0
- examples/arches/compute_in_memory/memory_cells/rram_raella_isca_2023.yaml +70 -0
- examples/arches/compute_in_memory/memory_cells/rram_wan_nature_2022.yaml +63 -0
- examples/arches/compute_in_memory/memory_cells/sram_colonnade_jssc_2021.yaml +63 -0
- examples/arches/compute_in_memory/memory_cells/sram_example.yaml +63 -0
- examples/arches/compute_in_memory/memory_cells/sram_jia_jssc_2020.yaml +63 -0
- examples/arches/compute_in_memory/memory_cells/sram_sinangil_jssc_2021.yaml +63 -0
- examples/arches/compute_in_memory/memory_cells/sram_wang_vlsi_2022.yaml +63 -0
- examples/arches/compute_in_memory/wang_vlsi_2022.yaml +289 -0
- examples/arches/eyeriss.yaml +68 -0
- examples/arches/fanout_variations/at_glb.yaml +31 -0
- examples/arches/fanout_variations/at_glb_with_fanout_node.yaml +34 -0
- examples/arches/fanout_variations/at_mac.yaml +31 -0
- examples/arches/fanout_variations/at_mac_with_constraints.yaml +38 -0
- examples/arches/fanout_variations/at_mac_with_fanout_node.yaml +34 -0
- examples/arches/nvdla.yaml +47 -0
- examples/arches/simple.yaml +28 -0
- examples/arches/tpu_v4i.yaml +67 -0
- examples/mappings/unfused_matmuls_to_simple.yaml +33 -0
- examples/misc/component_annotated.yaml +33 -0
- examples/workloads/gpt3_6.7B.yaml +124 -0
- examples/workloads/matmuls.yaml +20 -0
- examples/workloads/mobilenet_28.yaml +81 -0
- examples/workloads/mobilenet_various_separate.yaml +106 -0
- examples/workloads/three_matmuls_annotated.yaml +59 -0
- notebooks/.ipynb_checkpoints/fastfusion_arch_study_michael-checkpoint.ipynb +359 -0
- notebooks/compute_in_memory/_scripts.py +339 -0
- notebooks/compute_in_memory/isaac.guide.ipynb +270 -0
- notebooks/compute_in_memory/wang_vlsi_2022.ipynb +602 -0
- notebooks/paths.py +4 -0
- notebooks/tutorials/.ipynb_checkpoints/1_FFM-checkpoint.ipynb +3110 -0
- notebooks/tutorials/FFM.ipynb +3498 -0
- notebooks/tutorials/_include.py +48 -0
- notebooks/tutorials/component_energy_area.ipynb +363 -0
- tests/Q_mapping.yaml +38 -0
- tests/__init__.py +0 -0
- tests/conv.mapping.yaml +27 -0
- tests/conv.workload.yaml +13 -0
- tests/conv_sym.mapping.yaml +43 -0
- tests/copy.mapping.yaml +35 -0
- tests/copy.workload.yaml +15 -0
- tests/distribuffers/__init__.py +0 -0
- tests/distribuffers/multicast/test_cases.yaml +482 -0
- tests/distribuffers/spec/binding/valid_bindings.yaml +97 -0
- tests/distribuffers/spec/distributed.yaml +100 -0
- tests/distribuffers/spec/logical_arch.yaml +32 -0
- tests/distribuffers/spec/physical_arch.yaml +69 -0
- tests/distribuffers/test_binding.py +48 -0
- tests/frontend/__init__.py +0 -0
- tests/frontend/test_mapping_viz.py +52 -0
- tests/mapper/__init__.py +0 -0
- tests/mapper/configs/conv1d/conv1d.mapping.yaml +31 -0
- tests/mapper/configs/conv1d/conv1d.workload.yaml +11 -0
- tests/mapper/configs/two_conv1d/two_conv1d.expected.yaml +38 -0
- tests/mapper/configs/two_conv1d/two_conv1d.mapping.yaml +54 -0
- tests/mapper/configs/two_conv1d/two_conv1d.workload.yaml +19 -0
- tests/mapper/test_mapping_to_isl.py +90 -0
- tests/mapper/test_spatial_reuse_analysis.py +67 -0
- tests/mapper/test_temporal_reuse_analysis.py +56 -0
- tests/mapper/util.py +58 -0
- tests/matmul.mapping.yaml +29 -0
- tests/matmul.workload.yaml +12 -0
- tests/matmul_spatial.mapping.yaml +44 -0
- tests/mha.renames.yaml +65 -0
- tests/mha.workload.yaml +67 -0
- tests/mha.yaml +59 -0
- tests/mha_full.workload.yaml +67 -0
- tests/mobilenet.workload.yaml +35 -0
- tests/mobilenet_long.workload.yaml +64 -0
- tests/pmappingcache.py +24 -0
- tests/processing_stage.arch.yaml +40 -0
- tests/snowcat.arch.yaml +36 -0
- tests/test_ffm_join_pmappings.py +106 -0
- tests/test_ffm_make_pmappings.py +82 -0
- tests/test_ffm_make_tile_shapes.py +49 -0
- tests/test_mapper.py +100 -0
- tests/test_model.py +37 -0
- tests/test_plotting.py +72 -0
- tests/test_processing_stage.py +46 -0
- tests/test_symbolic_model.py +248 -0
- tests/test_workload.py +141 -0
accelforge/mapper/FFM/_join_pmappings/join_pmappings.py

@@ -0,0 +1,703 @@
from collections import defaultdict
import itertools
import logging
import time
from typing import Callable

from accelforge._accelerated_imports import pd
from accelforge.frontend.spec import Spec
from accelforge.frontend.mapping import Mapping
from accelforge.frontend.mapper.metrics import Metrics
from accelforge.frontend.workload import EinsumName
from accelforge.mapper.FFM.mappings import Mappings
from accelforge.mapper.FFM.pmappings import MultiEinsumPmappings
from accelforge.mapper.FFM._join_pmappings.compress_pmappings import (
    compress_einsum2pmappings,
    decompress_pmappings,
)
from accelforge.mapper.FFM._make_pmappings.make_pmappings import (
    get_rank_variable_bounds_for_all_einsums,
)
from accelforge.mapper.FFM._join_pmappings.pmapping_dataframe import (
    row2pmappings,
)
from accelforge.mapper.FFM._pareto_df.df_convention import MAPPING_COLUMN
from accelforge.mapper.FFM._join_pmappings.pmapping_group import (
    PmappingGroup,
    Compatibility,
)
from accelforge.mapper.FFM._pareto_df.df_convention import col2nameloop
from accelforge.util import parallel, delayed


logger = logging.getLogger(__name__)

class JoiningTimer:
    def __init__(self):
        self.prev_time = time.time()
        self.total_time = defaultdict(int)

    def print_time(self, what: str):
        t = time.time() - self.prev_time
        logger.info(f"{what}: {t:.2f} seconds")
        self.total_time[what] += t
        self.prev_time = time.time()

    def log_total_time(self):
        logger.info(f"\n======== Total time ========")
        for k, v in self.total_time.items():
            logger.info(f"{k}: {v:.2f} seconds")
        total = sum(self.total_time.values())
        if total > 60:
            logger.info(f"\nTotal: {total:.2f} seconds ({total/60:.2f} minutes)")
        else:
            logger.info(f"\nTotal: {total:.2f} seconds")
        logger.info(f"============================\n")

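# clean_compress_and_join_pmappings drives the full join in this module: it
# compresses the per-Einsum pmappings, joins them across Einsums with
# join_pmappings, decompresses the result, and wraps the joined dataframe in a
# Mappings object. A minimal usage sketch (hedged; assumes `spec` and
# `pmappings` were produced by earlier FFM mapper steps that are not shown
# here):
#
#     mappings = clean_compress_and_join_pmappings(spec, pmappings)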
def clean_compress_and_join_pmappings(
    spec: Spec,
    pmappings: MultiEinsumPmappings,
    require_all_einsums: bool = True,
    _pmapping_row_filter_function: Callable[[pd.Series], bool] | None = None,
) -> Mappings:
    einsum2pmappings = pmappings.einsum2pmappings
    if not require_all_einsums:
        einsum2pmappings = {
            k: v
            for k, v in pmappings.einsum2pmappings.items()
            if k in pmappings.einsums_with_pmappings_generated
        }
    _check_einsum2pmappings_not_empty(einsum2pmappings, pmappings)

    compressed, decompress_data = compress_einsum2pmappings(einsum2pmappings)
    joined = join_pmappings(
        compressed,
        spec,
        _pmapping_row_filter_function=_pmapping_row_filter_function,
    )
    joined = decompress_pmappings(joined, decompress_data)

    for einsum_name in einsum2pmappings:
        col = f"{einsum_name}<SEP>{MAPPING_COLUMN}"
        joined.data[col] = joined.data[col].apply(
            lambda x: pmappings.pmapping_objects[einsum_name][x]
        )
    joined._data = joined.data.fillna(0).reset_index(drop=True)

    rank_variable_bounds = get_rank_variable_bounds_for_all_einsums(spec)
    einsum_names = list(einsum2pmappings.keys())
    joined.data[f"Total<SEP>{MAPPING_COLUMN}"] = [
        MappingFromRow(r, rank_variable_bounds, einsum_names)
        for _, r in joined.data.iterrows()
    ]
    # Fill nans with 0. We might get missing columns for some mapping entries if there
    # are energy entries for some pmappings but not others (e.g., one pmapping accesses
    # DRAM while another doesn't.)
    return Mappings(
        spec,
        list(
            x
            for x in list(einsum2pmappings.keys())
            if x in pmappings.einsums_with_pmappings_generated
        ),
        joined.data,
        total_mappings=joined.n_total_pmappings,
        valid_mappings=joined.n_valid_pmappings,
        flattened_arches=pmappings.flattened_arches,
        parsed_specs=pmappings.parsed_specs,
    )

class PmappingsOneEinsum:
    def __init__(self, einsum_name: str, pm_group_list: list[PmappingGroup]):
        self.einsum_name: str = einsum_name
        self.pmapping_groups: list[PmappingGroup] = pm_group_list
        self.tensor_names: set[str] = set(pm_group_list[0].tensor_names)

    def __getitem__(self, i):
        return self.pmapping_groups[i]

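# make_full_equivalent_rank_variables expands pairwise equivalences between
# rank variables to their transitive closure (following the given pairs; it
# does not add symmetric edges). A small worked example with hypothetical
# rank-variable names, not taken from the package:
#
#     make_full_equivalent_rank_variables({"M0": {"M1"}, "M1": {"M2"}, "M2": set()})
#     # -> {"M0": {"M1", "M2"}, "M1": {"M2"}, "M2": set()}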
def make_full_equivalent_rank_variables(pairwise_equivalent_rank_variables):
    full_equivalent_rank_variables = {
        k: set(v) for k, v in pairwise_equivalent_rank_variables.items()
    }
    changed = True
    while changed:
        changed = False
        for r in full_equivalent_rank_variables:
            for r2 in list(full_equivalent_rank_variables[r]):
                for r3 in list(full_equivalent_rank_variables[r2]):
                    if r3 in full_equivalent_rank_variables[r]:
                        continue
                    changed = True
                    full_equivalent_rank_variables[r].add(r3)
    return full_equivalent_rank_variables

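# A hedged summary of the function below: it identifies memories whose
# reservations never need to be tracked across Einsums -- memories whose
# reservations always sit below the fused loops and never hold a shared
# (compatibility) tensor, plus memories whose summed per-Einsum reservations
# appear to fit within capacity (total <= 1 in the normalized units used
# here) -- strips their reservation columns from the pmapping dataframes, and
# returns the pruned groups together with the set of dropped memory names.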
def get_memories_to_track(
    pmapping_groups: dict[str, list[PmappingGroup]],
) -> tuple[dict[str, list[PmappingGroup]], set[str]]:

    always_below = set()
    for _, einsum_pmapping_groups in pmapping_groups.items():
        for s in einsum_pmapping_groups:
            for col in s.mappings.data.columns:
                name_nloops = col2nameloop(col)
                if name_nloops is not None:
                    always_below.add(col2nameloop(col)[0])

    total_sizes = {}
    ignored_resources = set()

    for _, einsum_pmapping_groups in pmapping_groups.items():
        max_sizes = {}
        for s in einsum_pmapping_groups:
            n_fused_loops = s.compatibility.n_loops
            for col in s.mappings.data.columns:
                name_nloops = col2nameloop(col)
                if name_nloops is None:
                    continue

                name, nloops = name_nloops
                if name in always_below and nloops < n_fused_loops:
                    always_below.remove(name)
                # Check each of the compatibility's tensors
                for tensor in s.compatibility.tensors:
                    if tensor.resource_name in always_below:
                        always_below.remove(tensor.resource_name)
                size = s.mappings.data[col].max()
                max_sizes[name] = max(max_sizes.get(name, 0), size)

                # nloops < 0 means that the reservation will live through all Einsums
                if nloops < 0:
                    ignored_resources.add(name)

        for name, size in max_sizes.items():
            total_sizes[name] = total_sizes.get(name, 0) + size

    ignore = set(t for t, s in total_sizes.items() if s <= 1) | always_below

    if not ignore:
        return pmapping_groups, ignore

    def remove_unneeded_columns(s: PmappingGroup):
        data = s.mappings.data
        keep_cols = []
        for col in data.columns:
            name_nloops = col2nameloop(col)
            if name_nloops is None or name_nloops[0] not in ignore:
                keep_cols.append(col)
        run_pareto = len(keep_cols) < len(data.columns)
        return PmappingGroup(
            s.compatibility,
            s.mappings.update(data=data[keep_cols], skip_pareto=not run_pareto),
        )

    for a in sorted(always_below):
        print(f"Not tracking {a} because it is never reserved for multiple pmappings.")
    for t, s in sorted(total_sizes.items(), key=lambda x: x[1], reverse=True):
        if s <= 1:
            print(
                f"Not tracking {t} because its size is enough for the sum of all "
                f"reservations ({s * 100:.2f}% of the total)"
            )
            break

    new_pmapping_groups = {}
    for einsum_name, einsum_pmapping_groups in pmapping_groups.items():
        new_pmapping_groups[einsum_name] = list(
            parallel(
                [delayed(remove_unneeded_columns)(s) for s in einsum_pmapping_groups],
                pbar=f"Removing unneeded reservations for {einsum_name}",
                return_as="generator",
            )
        )
    return new_pmapping_groups, ignore

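# A hedged summary of join_pmappings below: pmapping groups are joined one
# Einsum at a time, left to right. Groups are bucketed by compatibility keys,
# compatible (left, right) pairs are merged, and an optional lookahead drops
# merged groups that cannot match any later Einsum. The docstring below states
# the contract that incoming mappings must satisfy.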
def join_pmappings(
    pmapping_groups: dict[str, list[PmappingGroup]],
    spec: Spec,
    # Optimality-maintaining optimizations.
    skip_invalid: bool = True,
    combine_reservations: bool = True,
    lookahead_filter: bool = True,
    metrics: Metrics = None,
    _pmapping_row_filter_function: Callable[[pd.Series], bool] | None = None,
):
    """
    CONTRACT FOR MAPPINGS GETTING TO THIS POINT:

    - Reservations at a level include reservations at all levels above it.
    - If one Einsum uses an aliased tensor more than once, then only one
      reservation is made for it. If overlapping lifetimes cause the aliases to
      be alive at the same time, then it is handled here.
    - Memory names should be sorted with higher memory names representing
      memories lower in the hierarchy. e.g., memory 0 is the largest,
      memory 1 the next largest, and memory N is the smallest.
    """
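    # Illustration of the memory-name ordering stated in the contract above
    # (hypothetical three-level hierarchy, not from the package): 0 -> DRAM
    # (largest), 1 -> global buffer, 2 -> register file (smallest).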
    metrics = spec.mapper.ffm.metrics

    drop_valid_reservations = not (Metrics.RESOURCE_USAGE & metrics)
    ignored_resources = set()

    if _pmapping_row_filter_function is not None:
        n = sum(len(s.mappings.data) for sg in pmapping_groups.values() for s in sg)
        pmapping_groups = {
            e: [
                PmappingGroup(
                    s.compatibility,
                    s.mappings.filter_rows(_pmapping_row_filter_function),
                )
                for s in pmapping_groups[e]
            ]
            for e in pmapping_groups
        }
        new_n = sum(len(s.mappings.data) for sg in pmapping_groups.values() for s in sg)
        print(f"Filtered {n} -> {new_n} ({new_n / n:.2%} kept) pmappings")

    if drop_valid_reservations:
        pmapping_groups, ignored_resources = get_memories_to_track(pmapping_groups)

    mixable_ranks = spec.workload._get_ranks_that_share_indexing_rank_variables()

    aliased_tensors = spec.workload.get_tensor_copies()

    n_mappings = {}
    runtime = {}
    nbuckets = []

    n_evaluations = 0

    pmapping_groups = list(pmapping_groups.items())

    if not skip_invalid:
        lookahead_filter = False

    for einsum_name, s in pmapping_groups:
        if not s:
            raise ValueError(f"No pmappings for {einsum_name}")

    timer = JoiningTimer()

    pmgroups = [PmappingsOneEinsum(*s) for s in pmapping_groups]

    if not pmgroups:
        raise ValueError("No pmappings to join")

    # ======================================================================
    # Initial consolidate and group all PmappingGroups
    # ======================================================================
    n_mappings["Post Intra-Layer"] = 0
    for i, einsum_pmappings in enumerate(pmgroups):
        cur_tensors = einsum_pmappings.tensor_names
        right_tensors = set.union(set(), *[s.tensor_names for s in pmgroups[i + 1 :]])
        # First Einsum: Remove dead tensors and left consolidate. This is because the
        # first Einsum will have the first pmappings that are joined from the left.
        if i == 0:
            if cur_tensors - right_tensors:
                PmappingGroup.remove_dead_tensors(
                    einsum_pmappings.pmapping_groups, right_tensors
                )
                for s in einsum_pmappings.pmapping_groups:
                    s.compatibility = s.compatibility.clear_dead_tensors(right_tensors)
            einsum_pmappings.pmapping_groups = PmappingGroup.left_consolidate(
                einsum_pmappings.pmapping_groups,
                right_tensors,
                parallelize=False,  # We're not pareto pruning, so parallelization doesn't help.
                pbar=f"Initial consolidate {einsum_pmappings.einsum_name} ({i+1}/{len(pmgroups)})",
            )
            continue

        # All other Einsums: Will be joined from the right. Remove dead tensors, right
        # consolidate, combine, group.
        t0 = time.time()
        left_tensors = set.union(set(), *[s.tensor_names for s in pmgroups[:i]])
        live_tensors = right_tensors
        shared_tensors = left_tensors & einsum_pmappings.tensor_names

        if cur_tensors - (right_tensors | left_tensors):
            PmappingGroup.remove_dead_tensors(
                einsum_pmappings.pmapping_groups, right_tensors | left_tensors
            )
            for s in einsum_pmappings.pmapping_groups:
                s.compatibility = s.compatibility.clear_dead_tensors(
                    right_tensors | left_tensors
                )

        einsum_pmappings.pmapping_groups = sorted(
            einsum_pmappings.pmapping_groups,
            key=lambda x: len(x.mappings.data),
            reverse=True,
        )
        einsum_pmappings.pmapping_groups = PmappingGroup.right_consolidate(
            einsum_pmappings.pmapping_groups,
            live_tensors,
            shared_tensors,
            parallelize=False,  # We're not pareto pruning, so parallelization doesn't help.
            pbar=f"Initial consolidate {einsum_pmappings.einsum_name} ({i+1}/{len(pmgroups)})",
        )
        einsum_pmappings.pmapping_groups = PmappingGroup.combine_combineable(
            einsum_pmappings.pmapping_groups,
            left_tensors | right_tensors,
            combine_reservations=combine_reservations,
            pbar_postfix=f" for {einsum_pmappings.einsum_name} ({i+1}/{len(pmgroups)})",
        )
        n_mappings["Post Intra-Layer"] += sum(
            len(s.mappings.data) for s in einsum_pmappings.pmapping_groups
        )
        einsum_pmappings.pmapping_groups = PmappingGroup.group(
            einsum_pmappings.pmapping_groups, left_tensors
        )
        einsum, prev_einsum = einsum_pmappings.einsum_name, pmgroups[i - 1].einsum_name
        runtime[f"{prev_einsum} → {einsum}"] = time.time() - t0
        t0 = time.time()
    timer.print_time(f"Initial consolidate and group")
    n_iterations = 0
    total_iterations = len(pmgroups)

    def grab_einsum_pmappings() -> (
        tuple[dict[Compatibility, list[PmappingGroup]], str, set[str]]
    ):
        nonlocal n_iterations
        n_iterations += 1
        holder = pmgroups.pop(0)
        return holder.pmapping_groups, holder.einsum_name, holder.tensor_names

    if pmgroups:
        left, left_einsum, left_tensors = grab_einsum_pmappings()

    partial_mapping_size = 1
    while pmgroups:
        t0 = time.time()
        # ======================================================================
        # Grab new Einsum from the right. Record logging data and find the
        # tensors that will still be live after this Einsum.
        # ======================================================================
        nbuckets.append(len(left))
        # nmappings.append(sum(len(s.mappings.data) for s in left))
        right, right_einsum, right_tensors = grab_einsum_pmappings()
        logger.info(f"Einsum {right_einsum} ({n_iterations}/{total_iterations})")

        partial_mapping_size += 1

        live_tensors = set.union(set(), *[s.tensor_names for s in pmgroups])
        shared_tensors = set(left_tensors) & set(right_tensors)
        live_tensors_with_right = live_tensors | right_tensors

        # ======================================================================
        # Clean up the previously-combined PmappingGroups. Consolidate, combine,
        # group them into buckets.
        # ======================================================================
        # print_time(f"Consolidating")

        left = PmappingGroup.combine_combineable(
            left,
            live_tensors | right_tensors,
            combine_reservations=combine_reservations,
        )

        # print_time(f"Combining")
        # Group left and right into buckets
        left = PmappingGroup.group(left, right_tensors)
        # print_time("Grouping")

        # ======================================================================
        # Remove dead tensors from left and right. This happens after grouping
        # because we only reserve space for shared tensors after they're dead
        # (alive is handled by the normal reservation system). This is in case
        # the tensor lifetime extends beyond the Einsums for which it is used.
        # ======================================================================
        PmappingGroup.remove_dead_tensors(
            [s for lr in [left, right] for v in lr.values() for s, _ in v], live_tensors
        )

        DO_PRINT = False
        DELAY = True
        # ======================================================================
        # Merge the left and right buckets.
        # ======================================================================
        combined: list[PmappingGroup] = []
        cur_nmappings = 0
        combined_ids: set[tuple[int, int, tuple[tuple[int, int], ...]]] = set()

        for k in left:
            found = False
            if DO_PRINT:
                print(f"Left key {k}")
            for (a, perm_a), (b, perm_b) in itertools.product(
                left[k], right.get(k, [])
            ):
                a: PmappingGroup
                b: PmappingGroup
                perm_a: list[int]
                perm_b: list[int]
                key_check = (
                    id(a),
                    id(b),
                    tuple((pa, pb) for pa, pb in zip(perm_a, perm_b)),
                )
                if key_check in combined_ids:
                    continue
                combined_ids.add(key_check)
                found = True

                compatibility_a = a.compatibility.permute(perm_a)
                compatibility_b = b.compatibility.permute(perm_b)
                try:
                    compatibility_joined = compatibility_a.merge_next(
                        compatibility_b,
                        live_tensors,
                        mixable_ranks,
                    )
                    if DO_PRINT:
                        print(f"\t{a.compatibility} <--> {b.compatibility}")
                except ValueError as e:  # Incompatible!
                    # if DO_PRINT:
                    #     print(f"\tIncompatible: {e}")
                    continue

                t0 = time.time()

                combined.append(
                    a.merge_next(
                        b,
                        live_tensors,
                        live_tensors_with_right,
                        aliased_tensors,
                        compatibility_joined=compatibility_joined,
                        drop_valid_reservations=drop_valid_reservations,
                        delay=DELAY,
                        _pmapping_row_filter_function=_pmapping_row_filter_function,
                        ignored_resources=ignored_resources,
                    )
                )
                t1 = time.time()
                # print(f'Took {t1 - t0:.2f} seconds to generate {len(combined[-1].mappings.data)} mappings')

                if not DELAY:
                    cur_nmappings += len(a.mappings.data) * len(b.mappings.data)
                if DO_PRINT:
                    # s = f"\t-->\n\t{combined[-1].compatibility}"
                    # s += f"({len(a.mappings.data)})x({len(b.mappings.data)})"
                    # print(s)
                    pass
            if DO_PRINT and not found:
                for a, _ in left[k]:
                    print(f"\tNo match for {a.compatibility}")

        if DO_PRINT:
            for k in right:
                if k not in left:
                    for b, _ in right[k]:
                        print(f"\tREVERSE: No match for {b.compatibility} using {k}")
# print_time("Bucket merging")
|
|
503
|
+
def raise_no_match_error():
|
|
504
|
+
estr = f"No match found for any group.\n"
|
|
505
|
+
estr += f"Left compatibility:\n\t" + "\n\t".join(
|
|
506
|
+
str(c) for c in left.keys()
|
|
507
|
+
)
|
|
508
|
+
estr += f"\nRight compatibility:\n\t" + "\n\t".join(
|
|
509
|
+
str(c) for c in right.keys()
|
|
510
|
+
)
|
|
511
|
+
raise ValueError(estr)
|
|
512
|
+
|
|
513
|
+
def no_match_lookahead_error(
|
|
514
|
+
combined: list[PmappingGroup],
|
|
515
|
+
next_keys: set[tuple[int, int, tuple[tuple[int, int], ...]]],
|
|
516
|
+
):
|
|
517
|
+
estr = f"No match found for any group. Left and right joined successfully, "
|
|
518
|
+
estr += f"but will not be compatible with following Einsums.\n"
|
|
519
|
+
estr += f"Left compatibility:\n\t" + "\n\t".join(
|
|
520
|
+
str(s.compatibility) for g in left.values() for s, _ in g
|
|
521
|
+
)
|
|
522
|
+
estr += f"\nRight compatibility:\n\t" + "\n\t".join(
|
|
523
|
+
str(s.compatibility) for g in right.values() for s, _ in g
|
|
524
|
+
)
|
|
525
|
+
estr += f"\nCombined compatibility:\n\t" + "\n\t".join(
|
|
526
|
+
str(s.compatibility) for s in combined
|
|
527
|
+
)
|
|
528
|
+
estr += f"\nFollowing Einsum compatibility:\n\t" + "\n\t".join(
|
|
529
|
+
str(c) for c in next_keys
|
|
530
|
+
)
|
|
531
|
+
raise ValueError(estr)
|
|
532
|
+
|
|
533
|
+
# ======================================================================
|
|
534
|
+
# Look ahead to the next Einsum and see if any of our groups will not
|
|
535
|
+
# be able to merge with it. If so, we can drop them immediately.
|
|
536
|
+
# ======================================================================
|
|
537
|
+
if lookahead_filter:
|
|
538
|
+
cur_tensors = left_tensors | right_tensors
|
|
539
|
+
for next_pmapping_groups in pmgroups:
|
|
540
|
+
next_right_tensors = next_pmapping_groups.tensor_names
|
|
541
|
+
if not next_right_tensors & cur_tensors:
|
|
542
|
+
continue
|
|
543
|
+
prev_combined = combined
|
|
544
|
+
combined = PmappingGroup.group(combined, next_right_tensors)
|
|
545
|
+
next_keys = {
|
|
546
|
+
c.clear_dead_tensors(
|
|
547
|
+
cur_tensors
|
|
548
|
+
).clear_tile_patterns_and_reservation_indices()
|
|
549
|
+
for c in next_pmapping_groups.pmapping_groups
|
|
550
|
+
}
|
|
551
|
+
for k in list(combined):
|
|
552
|
+
k_cleared = k.clear_dead_tensors(
|
|
553
|
+
next_right_tensors
|
|
554
|
+
).clear_tile_patterns_and_reservation_indices()
|
|
555
|
+
if k_cleared not in next_keys:
|
|
556
|
+
if DO_PRINT:
|
|
557
|
+
for b, _ in combined[k]:
|
|
558
|
+
print(
|
|
559
|
+
f"\tLOOKAHEAD to {next_pmapping_groups.einsum_name}: No match for {b.compatibility}"
|
|
560
|
+
)
|
|
561
|
+
del combined[k]
|
|
562
|
+
if not combined:
|
|
563
|
+
PmappingGroup.group(prev_combined, next_right_tensors)
|
|
564
|
+
no_match_lookahead_error(prev_combined, next_keys)
|
|
565
|
+
|
|
566
|
+
combined = list(itertools.chain.from_iterable(combined.values()))
|
|
567
|
+
combined = [c[0] for c in combined]
|
|
568
|
+
# Remove duplicates
|
|
569
|
+
id2combined = {id(c): c for c in combined}
|
|
570
|
+
combined = list(id2combined.values())
|
|
571
|
+
# print(
|
|
572
|
+
# f"Removed {prev_len - len(combined)}/{prev_len} ({len(combined)/prev_len*100:.2f}% remaining)"
|
|
573
|
+
# )
|
|
574
|
+
# print_time("Removing mappings that can't be combined later")
|
|
575
|
+
|
|
576
|
+
        if not combined:
            raise_no_match_error()

        # ======================================================================
        # If we delayed the mapping merging, do it now.
        # ======================================================================
        if DELAY:
            mappings = parallel(
                [c.mappings for c in combined],
                pbar=f"Joining pmappings for {left_einsum} <--> {right_einsum} ({n_iterations}/{total_iterations})",
                return_as="generator",
            )
            for c, mapping in zip(combined, mappings):
                c.mappings = mapping
                cur_nmappings += c.n_pre_prune_mappings
        timer.print_time("Pmapping merging")

        prev_nmappings = cur_nmappings
        if not skip_invalid:
            left_nmappings = sum(len(s.mappings.data) for k in left.values() for s in k)
            right_nmappings = sum(
                len(s.mappings.data) for k in right.values() for s in k
            )
            cur_nmappings = left_nmappings * right_nmappings
        n_mappings[f"{left_einsum} → {right_einsum}"] = cur_nmappings
        n_evaluations += cur_nmappings
        runtime[f"{left_einsum} → {right_einsum}"] += (time.time() - t0) * (
            cur_nmappings / prev_nmappings
        )
        # print(
        #     f'Scaled runtime by {cur_nmappings / prev_nmappings}. Runtime: {runtime[f"{prev_einsum} → {einsum}"]:.2f}'
        # )

        # ======================================================================
        # Print statements
        # ======================================================================
        logger.info(
            f"\tCombining {sum(len(s) for s in left)}({len(left)}) x {sum(len(s) for s in right)}({len(right)}) -> {len(combined)}"
        )

        nmappings = sum(len(s.mappings.data) for s in combined)
        for_einsum_text = f"for Einsum {right_einsum}"
        logger.info(f"\tNumber of groups {for_einsum_text}: {len(combined)}")
        # for c in combined:
        #     print(f"\t\t{c.compatibility}")
        logger.info(f"\tNumber of mappings {for_einsum_text}: {nmappings}")
        logger.info(
            f"\tMappings per group {for_einsum_text}: {nmappings / len(combined)}"
        )
        logger.info(
            f"\tLargest left: {max(len(s2.mappings.data) for s in left.values() for s2, _ in s)}"
        )
        logger.info(
            f"\tLargest right: {max(len(s2.mappings.data) for s in right.values() for s2, _ in s)}"
        )

        # ======================================================================
        # Update left for the next iteration.
        # ======================================================================
        left = combined
        left_einsum = right_einsum
        left_tensors |= right_tensors

    # ======================================================================
    # Final consolidate and group
    # ======================================================================
    t0 = time.time()
    left = PmappingGroup.left_consolidate(left, None, pbar="Final consolidate")
    s_final = PmappingGroup.combine_combineable(left, set())
    assert len(s_final) == 1
    mappings = s_final[0].mappings

    timer.log_total_time()
    # if evaluations_tracker is not None and "Total_latency" in data.columns and "Total_energy" in data.columns:
    #     edp = data["Total_latency"] * data["Total_energy"]
    #     edp_min = edp.min()
    #     evaluations_tracker.add_evaluation(n_evaluations, edp_min)
    #     evaluations_tracker.n_mappings.update(n_mappings)
    #     evaluations_tracker.runtime.update(runtime)

    return mappings

def _check_einsum2pmappings_not_empty(einsum2pmappings, pmappings):
    for einsum_name, einsum_pmappings in einsum2pmappings.items():
        total = sum(len(p.mappings.data) for p in einsum_pmappings)
        n_compatibilities = len(einsum_pmappings)
        logger.info(
            f"Einsum {einsum_name} has {total} pmappings with {n_compatibilities} compatibilities"
        )
        if total == 0:
            if einsum_name in pmappings.einsums_with_pmappings_generated:
                raise ValueError(
                    f"Einsum {einsum_name} has no pmappings. This likely means that "
                    f"no pmappings satisfied constraints for the Einsum. Please check "
                    f"the stats outputs from the MultiEinsumPmappings object."
                )

            raise ValueError(
                f"Einsum {einsum_name} has no pmappings generated. It looks like you "
                "may have used `make_pmappings` with `einsum_names` set. You may set "
                "`require_all_einsums=False` to ignore this error and map only the "
                "Einsums that have pmappings."
            )

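# MappingFromRow defers building a full Mapping object for each row of the
# joined dataframe: it stores the row, the rank-variable bounds, and the
# Einsum names, and only materializes the Mapping when the instance is called
# (or rendered).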
class MappingFromRow:
    def __init__(
        self,
        row: pd.Series,
        rank_variable_bounds: dict[str, dict[str, int]],
        einsum_names: list[EinsumName] | None = None,
    ):
        self.row = row
        self.rank_variable_bounds = rank_variable_bounds
        self.einsum_names = einsum_names

    def __call__(self) -> Mapping:
        return Mapping._from_pmappings(
            row2pmappings(self.row, self.einsum_names, self.rank_variable_bounds),
            rank_variable_bounds=self.rank_variable_bounds,
        )

    def _repr_svg_(self) -> str:
        return self.render()

    def render(self) -> str:
        return self().render()