accelforge 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- accelforge/__init__.py +21 -0
- accelforge/_accelerated_imports.py +16 -0
- accelforge/_deprecate/_simanneal/evalmapping.py +271 -0
- accelforge/_deprecate/_simanneal/mapspaceglobals.py +298 -0
- accelforge/_deprecate/_simanneal/simanneal.py +666 -0
- accelforge/_deprecate/_simanneal/tracking.py +105 -0
- accelforge/_deprecate/_simanneal/wrappers.py +218 -0
- accelforge/_deprecate/_simanneal2/__init__.py +7 -0
- accelforge/_deprecate/_simanneal2/simanneal.py +493 -0
- accelforge/_deprecate/_simanneal2/tracking.py +116 -0
- accelforge/_deprecate/compatibility_util.py +181 -0
- accelforge/_deprecate/layerdeduplication/__init__.py +2 -0
- accelforge/_deprecate/layerdeduplication/group_similar_einsums.py +160 -0
- accelforge/_deprecate/layerdeduplication/grouped_einsums.py +84 -0
- accelforge/_deprecate/mapping_filter_tags/__init__.py +2 -0
- accelforge/_deprecate/mapping_filter_tags/ffmt.py +212 -0
- accelforge/_deprecate/mapping_filter_tags/onesplit.py +24 -0
- accelforge/_deprecate/mapping_filter_tags/util.py +24 -0
- accelforge/_deprecate/tags.py +69 -0
- accelforge/_deprecate/viz/__init__.py +0 -0
- accelforge/_deprecate/viz/interactive.py +159 -0
- accelforge/_deprecate/viz/reservationtree.py +307 -0
- accelforge/_deprecate/viz/ski_slope.py +88 -0
- accelforge/_version.py +15 -0
- accelforge/examples.py +39 -0
- accelforge/frontend/__init__.py +10 -0
- accelforge/frontend/_binding.py +129 -0
- accelforge/frontend/_workload_isl/__init__.py +2 -0
- accelforge/frontend/_workload_isl/_isl.py +149 -0
- accelforge/frontend/_workload_isl/_symbolic.py +141 -0
- accelforge/frontend/arch copy.py +1544 -0
- accelforge/frontend/arch.py +1642 -0
- accelforge/frontend/config.py +63 -0
- accelforge/frontend/mapper/__init__.py +5 -0
- accelforge/frontend/mapper/ffm.py +126 -0
- accelforge/frontend/mapper/mapper.py +7 -0
- accelforge/frontend/mapper/metrics.py +30 -0
- accelforge/frontend/mapping/__init__.py +1 -0
- accelforge/frontend/mapping/mapping.py +1736 -0
- accelforge/frontend/model.py +14 -0
- accelforge/frontend/renames.py +150 -0
- accelforge/frontend/spec copy.py +230 -0
- accelforge/frontend/spec.py +301 -0
- accelforge/frontend/variables.py +12 -0
- accelforge/frontend/workload.py +952 -0
- accelforge/mapper/FFM/__init__.py +9 -0
- accelforge/mapper/FFM/_join_pmappings/__init__.py +0 -0
- accelforge/mapper/FFM/_join_pmappings/compatibility.py +653 -0
- accelforge/mapper/FFM/_join_pmappings/compress_pmappings.py +140 -0
- accelforge/mapper/FFM/_join_pmappings/join_pmappings.py +703 -0
- accelforge/mapper/FFM/_join_pmappings/pmapping_dataframe.py +901 -0
- accelforge/mapper/FFM/_join_pmappings/pmapping_group.py +337 -0
- accelforge/mapper/FFM/_make_pmappings/contraints/__init__.py +0 -0
- accelforge/mapper/FFM/_make_pmappings/contraints/constraints.py +360 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/__init__.py +1 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_loops.py +373 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_pmapping_templates.py +463 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_reservations.py +95 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_storage_order.py +382 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_storages.py +155 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmappings.py +411 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/__init__.py +1 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/make_pmappings_from_templates.py +407 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/make_tile_shapes.py +1681 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/run_model.py +170 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/symbol_relations.py +174 -0
- accelforge/mapper/FFM/_make_pmappings/pmapper_job.py +282 -0
- accelforge/mapper/FFM/_pareto_df/df_convention.py +273 -0
- accelforge/mapper/FFM/_pareto_df/pareto copy.py +836 -0
- accelforge/mapper/FFM/_pareto_df/pareto.py +508 -0
- accelforge/mapper/FFM/data.py +61 -0
- accelforge/mapper/FFM/main copy.py +236 -0
- accelforge/mapper/FFM/main.py +208 -0
- accelforge/mapper/FFM/mappings.py +510 -0
- accelforge/mapper/FFM/pmappings.py +310 -0
- accelforge/mapper/__init__.py +4 -0
- accelforge/mapper.py +0 -0
- accelforge/model/__init__.py +1 -0
- accelforge/model/_looptree/__init__.py +0 -0
- accelforge/model/_looptree/accesses.py +335 -0
- accelforge/model/_looptree/capacity/__init__.py +1 -0
- accelforge/model/_looptree/capacity/aggregators.py +36 -0
- accelforge/model/_looptree/capacity/capacity.py +47 -0
- accelforge/model/_looptree/energy.py +150 -0
- accelforge/model/_looptree/equivalent_ranks.py +29 -0
- accelforge/model/_looptree/latency/__init__.py +1 -0
- accelforge/model/_looptree/latency/latency.py +98 -0
- accelforge/model/_looptree/latency/memory.py +120 -0
- accelforge/model/_looptree/latency/processors.py +92 -0
- accelforge/model/_looptree/mapping_utilities.py +71 -0
- accelforge/model/_looptree/reuse/__init__.py +4 -0
- accelforge/model/_looptree/reuse/isl/__init__.py +1 -0
- accelforge/model/_looptree/reuse/isl/des.py +59 -0
- accelforge/model/_looptree/reuse/isl/isl_functions.py +374 -0
- accelforge/model/_looptree/reuse/isl/mapping_to_isl/__init__.py +4 -0
- accelforge/model/_looptree/reuse/isl/mapping_to_isl/analyze_mapping.py +297 -0
- accelforge/model/_looptree/reuse/isl/mapping_to_isl/skews_from_mapping.py +236 -0
- accelforge/model/_looptree/reuse/isl/mapping_to_isl/tiling.py +685 -0
- accelforge/model/_looptree/reuse/isl/mapping_to_isl/types.py +188 -0
- accelforge/model/_looptree/reuse/isl/spatial.py +260 -0
- accelforge/model/_looptree/reuse/isl/temporal.py +182 -0
- accelforge/model/_looptree/reuse/symbolic/__init__.py +1 -0
- accelforge/model/_looptree/reuse/symbolic/symbolic copy 2.py +1346 -0
- accelforge/model/_looptree/reuse/symbolic/symbolic copy.py +1408 -0
- accelforge/model/_looptree/reuse/symbolic/symbolic.py +1396 -0
- accelforge/model/_looptree/run.py +122 -0
- accelforge/model/_looptree/types.py +26 -0
- accelforge/model/_looptree/visualization/__init__.py +0 -0
- accelforge/model/_looptree/visualization/occupancy.py +11 -0
- accelforge/model/main.py +222 -0
- accelforge/plotting/__init__.py +2 -0
- accelforge/plotting/mappings.py +219 -0
- accelforge/plotting/specs.py +57 -0
- accelforge/util/__init__.py +4 -0
- accelforge/util/_base_analysis_types.py +24 -0
- accelforge/util/_basetypes.py +1089 -0
- accelforge/util/_frozenset.py +36 -0
- accelforge/util/_isl.py +29 -0
- accelforge/util/_itertools.py +14 -0
- accelforge/util/_mathfuncs.py +57 -0
- accelforge/util/_parse_expressions.py +339 -0
- accelforge/util/_picklecache.py +32 -0
- accelforge/util/_setexpressions.py +268 -0
- accelforge/util/_sympy/__init__.py +0 -0
- accelforge/util/_sympy/broadcast_max.py +18 -0
- accelforge/util/_visualization.py +112 -0
- accelforge/util/_yaml.py +579 -0
- accelforge/util/parallel.py +193 -0
- accelforge-0.0.1.dist-info/METADATA +64 -0
- accelforge-0.0.1.dist-info/RECORD +258 -0
- accelforge-0.0.1.dist-info/WHEEL +5 -0
- accelforge-0.0.1.dist-info/licenses/LICENSE +19 -0
- accelforge-0.0.1.dist-info/top_level.txt +5 -0
- docs/_build/html/_sources/fastfusion.frontend.mapper.rst.txt +37 -0
- docs/_build/html/_sources/fastfusion.frontend.rst.txt +70 -0
- docs/_build/html/_sources/fastfusion.frontend.workload.rst.txt +21 -0
- docs/_build/html/_sources/fastfusion.mapper.FFM.rst.txt +37 -0
- docs/_build/html/_sources/fastfusion.mapper.rst.txt +18 -0
- docs/_build/html/_sources/fastfusion.rst.txt +20 -0
- docs/_build/html/_sources/fastfusion.util.rst.txt +21 -0
- docs/_build/html/_sources/index.rst.txt +87 -0
- docs/_build/html/_sources/modules.rst.txt +7 -0
- docs/_build/html/_sources/notes/citation.rst.txt +45 -0
- docs/_build/html/_sources/notes/definitions.rst.txt +43 -0
- docs/_build/html/_sources/notes/faqs.rst.txt +39 -0
- docs/_build/html/_sources/notes/modeling/accelerator_energy_latency.rst.txt +72 -0
- docs/_build/html/_sources/notes/modeling/component_energy_area.rst.txt +96 -0
- docs/_build/html/_sources/notes/modeling/mapping.rst.txt +100 -0
- docs/_build/html/_sources/notes/modeling.rst.txt +33 -0
- docs/_build/html/_sources/notes/parsing/arithmetic_parsing.rst.txt +136 -0
- docs/_build/html/_sources/notes/parsing/setexpressions.rst.txt +63 -0
- docs/_build/html/_sources/notes/parsing/yaml_parsing.rst.txt +176 -0
- docs/_build/html/_sources/notes/quickstart_and_installation.rst.txt +9 -0
- docs/_build/html/_sources/notes/spec/architecture.rst.txt +133 -0
- docs/_build/html/_sources/notes/spec/mapping.rst.txt +12 -0
- docs/_build/html/_sources/notes/spec/workload.rst.txt +83 -0
- docs/_build/html/_sources/notes/spec.rst.txt +36 -0
- docs/source/_ext/include_attrs.py +213 -0
- docs/source/_ext/include_docstring.py +364 -0
- docs/source/_ext/include_functions.py +154 -0
- docs/source/_ext/include_notebook.py +131 -0
- docs/source/_ext/include_yaml.py +119 -0
- docs/source/_ext/inherited_attributes.py +222 -0
- docs/source/_ext/paths.py +4 -0
- docs/source/conf.py +79 -0
- examples/arches/compute_in_memory/_include.yaml +74 -0
- examples/arches/compute_in_memory/_include_functions.py +229 -0
- examples/arches/compute_in_memory/_load_spec.py +57 -0
- examples/arches/compute_in_memory/components/c2c_multiplier.py +181 -0
- examples/arches/compute_in_memory/components/dac_c2c_r2r.py +605 -0
- examples/arches/compute_in_memory/components/misc.py +195 -0
- examples/arches/compute_in_memory/components/util/bit_functions.py +51 -0
- examples/arches/compute_in_memory/components/zero_comparator.py +92 -0
- examples/arches/compute_in_memory/isaac.yaml +233 -0
- examples/arches/compute_in_memory/memory_cells/ecram_demo.yaml +63 -0
- examples/arches/compute_in_memory/memory_cells/rram_example.yaml +63 -0
- examples/arches/compute_in_memory/memory_cells/rram_isaac_isca_2016.yaml +64 -0
- examples/arches/compute_in_memory/memory_cells/rram_neurosim_default.yaml +63 -0
- examples/arches/compute_in_memory/memory_cells/rram_raella_isca_2023.yaml +70 -0
- examples/arches/compute_in_memory/memory_cells/rram_wan_nature_2022.yaml +63 -0
- examples/arches/compute_in_memory/memory_cells/sram_colonnade_jssc_2021.yaml +63 -0
- examples/arches/compute_in_memory/memory_cells/sram_example.yaml +63 -0
- examples/arches/compute_in_memory/memory_cells/sram_jia_jssc_2020.yaml +63 -0
- examples/arches/compute_in_memory/memory_cells/sram_sinangil_jssc_2021.yaml +63 -0
- examples/arches/compute_in_memory/memory_cells/sram_wang_vlsi_2022.yaml +63 -0
- examples/arches/compute_in_memory/wang_vlsi_2022.yaml +289 -0
- examples/arches/eyeriss.yaml +68 -0
- examples/arches/fanout_variations/at_glb.yaml +31 -0
- examples/arches/fanout_variations/at_glb_with_fanout_node.yaml +34 -0
- examples/arches/fanout_variations/at_mac.yaml +31 -0
- examples/arches/fanout_variations/at_mac_with_constraints.yaml +38 -0
- examples/arches/fanout_variations/at_mac_with_fanout_node.yaml +34 -0
- examples/arches/nvdla.yaml +47 -0
- examples/arches/simple.yaml +28 -0
- examples/arches/tpu_v4i.yaml +67 -0
- examples/mappings/unfused_matmuls_to_simple.yaml +33 -0
- examples/misc/component_annotated.yaml +33 -0
- examples/workloads/gpt3_6.7B.yaml +124 -0
- examples/workloads/matmuls.yaml +20 -0
- examples/workloads/mobilenet_28.yaml +81 -0
- examples/workloads/mobilenet_various_separate.yaml +106 -0
- examples/workloads/three_matmuls_annotated.yaml +59 -0
- notebooks/.ipynb_checkpoints/fastfusion_arch_study_michael-checkpoint.ipynb +359 -0
- notebooks/compute_in_memory/_scripts.py +339 -0
- notebooks/compute_in_memory/isaac.guide.ipynb +270 -0
- notebooks/compute_in_memory/wang_vlsi_2022.ipynb +602 -0
- notebooks/paths.py +4 -0
- notebooks/tutorials/.ipynb_checkpoints/1_FFM-checkpoint.ipynb +3110 -0
- notebooks/tutorials/FFM.ipynb +3498 -0
- notebooks/tutorials/_include.py +48 -0
- notebooks/tutorials/component_energy_area.ipynb +363 -0
- tests/Q_mapping.yaml +38 -0
- tests/__init__.py +0 -0
- tests/conv.mapping.yaml +27 -0
- tests/conv.workload.yaml +13 -0
- tests/conv_sym.mapping.yaml +43 -0
- tests/copy.mapping.yaml +35 -0
- tests/copy.workload.yaml +15 -0
- tests/distribuffers/__init__.py +0 -0
- tests/distribuffers/multicast/test_cases.yaml +482 -0
- tests/distribuffers/spec/binding/valid_bindings.yaml +97 -0
- tests/distribuffers/spec/distributed.yaml +100 -0
- tests/distribuffers/spec/logical_arch.yaml +32 -0
- tests/distribuffers/spec/physical_arch.yaml +69 -0
- tests/distribuffers/test_binding.py +48 -0
- tests/frontend/__init__.py +0 -0
- tests/frontend/test_mapping_viz.py +52 -0
- tests/mapper/__init__.py +0 -0
- tests/mapper/configs/conv1d/conv1d.mapping.yaml +31 -0
- tests/mapper/configs/conv1d/conv1d.workload.yaml +11 -0
- tests/mapper/configs/two_conv1d/two_conv1d.expected.yaml +38 -0
- tests/mapper/configs/two_conv1d/two_conv1d.mapping.yaml +54 -0
- tests/mapper/configs/two_conv1d/two_conv1d.workload.yaml +19 -0
- tests/mapper/test_mapping_to_isl.py +90 -0
- tests/mapper/test_spatial_reuse_analysis.py +67 -0
- tests/mapper/test_temporal_reuse_analysis.py +56 -0
- tests/mapper/util.py +58 -0
- tests/matmul.mapping.yaml +29 -0
- tests/matmul.workload.yaml +12 -0
- tests/matmul_spatial.mapping.yaml +44 -0
- tests/mha.renames.yaml +65 -0
- tests/mha.workload.yaml +67 -0
- tests/mha.yaml +59 -0
- tests/mha_full.workload.yaml +67 -0
- tests/mobilenet.workload.yaml +35 -0
- tests/mobilenet_long.workload.yaml +64 -0
- tests/pmappingcache.py +24 -0
- tests/processing_stage.arch.yaml +40 -0
- tests/snowcat.arch.yaml +36 -0
- tests/test_ffm_join_pmappings.py +106 -0
- tests/test_ffm_make_pmappings.py +82 -0
- tests/test_ffm_make_tile_shapes.py +49 -0
- tests/test_mapper.py +100 -0
- tests/test_model.py +37 -0
- tests/test_plotting.py +72 -0
- tests/test_processing_stage.py +46 -0
- tests/test_symbolic_model.py +248 -0
- tests/test_workload.py +141 -0
accelforge/__init__.py
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
from accelforge.frontend import arch
|
|
2
|
+
from accelforge.frontend import config
|
|
3
|
+
from accelforge.frontend import mapping
|
|
4
|
+
from accelforge.frontend import renames
|
|
5
|
+
from accelforge.frontend import spec
|
|
6
|
+
from accelforge.frontend import variables
|
|
7
|
+
from accelforge.frontend import workload
|
|
8
|
+
from accelforge.frontend.spec import Spec, Spec
|
|
9
|
+
from accelforge.mapper.FFM import Metrics
|
|
10
|
+
from accelforge.util import set_n_parallel_jobs
|
|
11
|
+
from accelforge.util import LiteralString
|
|
12
|
+
import accelforge.mapper as mapper
|
|
13
|
+
from accelforge.examples import examples
|
|
14
|
+
|
|
15
|
+
from accelforge.frontend.variables import Variables
|
|
16
|
+
from accelforge.frontend.arch import Arch
|
|
17
|
+
from accelforge.frontend.config import Config
|
|
18
|
+
from accelforge.frontend.mapping import Mapping
|
|
19
|
+
from accelforge.frontend.renames import Renames
|
|
20
|
+
from accelforge.frontend.spec import Spec
|
|
21
|
+
from accelforge.frontend.workload import Workload
|
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
import os

# The switch is overwritten on every import of this module, so any value
# supplied through the caller's environment is discarded and the check below
# always resolves to the CPU libraries.
os.environ["ACCELFORGE_ACCELERATED_IMPORTS"] = "0"

# True only when the switch reads exactly "1".
ACCELERATED = os.environ.get("ACCELFORGE_ACCELERATED_IMPORTS", "0") == "1"

if ACCELERATED:
    # GPU-backed libraries exported under the usual pd / np / scipy names.
    # NOTE(review): cupy is bound to the name ``scipy`` here -- confirm that
    # is intentional before making this branch reachable again.
    import cudf as pd
    import cupy as np
    import cupy as scipy
else:
    import pandas as pd
    import numpy as np
    import scipy
|
|
@@ -0,0 +1,271 @@
|
|
|
1
|
+
from collections import defaultdict
|
|
2
|
+
import itertools
|
|
3
|
+
import time
|
|
4
|
+
from fastfusion._accelerated_imports import pd
|
|
5
|
+
from fastfusion.mapper.FFM._join_pmappings.sim import PmappingGroup, Loop, Compatibility
|
|
6
|
+
from fastfusion.mapper.FFM._join_pmappings.pmapping_group import PmappingDataframe
|
|
7
|
+
from fastfusion.mapper.simanneal.mapspaceglobals import MapspaceGlobals
|
|
8
|
+
from fastfusion.util._frozenset import fzs
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def mapping2sims(einsum_to_result: Compatibility):
    """Return one list of ``paretofy`` results per Einsum.

    ``einsum_to_result`` is walked with ``.items()``; each value is itself
    walked with ``.items()`` and every ``(key, value)`` pair is handed to
    ``paretofy``. The Einsum names are dropped: only the per-Einsum lists are
    returned, in iteration order.

    NOTE(review): the parameter is annotated ``Compatibility`` but is used as
    a mapping of mappings -- confirm the intended type.
    NOTE(review): ``paretofy`` is neither defined nor imported in the visible
    part of this module -- confirm where it is expected to come from.
    """
    return [
        [paretofy(key, value) for key, value in compat_dict.items()]
        for _einsum_name, compat_dict in einsum_to_result.items()
    ]
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def get_possible_translations(
    t: Compatibility,
    pairwise_equivalent_rank_variables: dict[str, set[str]],
    full_equivalent_rank_variables: dict[str, set[str]],
    right_rank_variables: set[str],
):
    """Yield versions of ``t`` whose loops are renamed into the next Einsum.

    Fused ranks should be transitive, but if a fused loop indexes into two
    different ranks in the next Einsum, we can't fuse because it will tile in
    multiple directions.

    For every loop of ``t`` two sets are computed, both restricted to
    ``right_rank_variables``:

    * the union of the *full* equivalents of the loop's rank variables --
      the ranks the loop CAN fuse with in the next Einsum;
    * the union of the *pairwise* equivalents -- the ranks the loop MUST
      index into in the next Einsum.

    If the second set has more than one member the loop aliases into multiple
    ranks and gets no translation. Otherwise one single-rank ``Loop`` (same
    bound and spatial flag) is offered per member of the first set.

    Args:
        t: Object exposing ``loops`` and ``update(loops=...)``.
        pairwise_equivalent_rank_variables: Rank variable -> directly
            equivalent rank variables.
        full_equivalent_rank_variables: Rank variable -> all equivalent rank
            variables.
        right_rank_variables: Rank variables of the next Einsum.

    Yields:
        ``t.update(loops=combo)`` for every combination that picks one
        translation per loop. Nothing is yielded when any loop has no
        translation.
    """

    def equivalents_on_right(equivalences, loop):
        # Union over all of the loop's rank variables, kept only if the next
        # Einsum actually has that rank variable.
        merged = set.union(*(equivalences[name] for name in loop.rank_variable_names))
        return merged & right_rank_variables

    def translate_loop(l: Loop):
        can_fuse_with = equivalents_on_right(full_equivalent_rank_variables, l)
        must_index_into = equivalents_on_right(pairwise_equivalent_rank_variables, l)
        if len(must_index_into) <= 1:
            for name in can_fuse_with:
                yield Loop(fzs((name,)), l.bound, l.is_spatial)

    options_per_loop = [translate_loop(loop) for loop in t.loops]
    for loops in itertools.product(*options_per_loop):
        yield t.update(loops=loops)
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
# Timestamp of the most recent timing checkpoint, as returned by time.time();
# stays 0 until init_print_time() stores a real timestamp.
prev_time: float = 0
# Seconds accumulated per label by print_time(); unseen labels start at 0.
total_time: defaultdict = defaultdict(int)
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
def init_print_time():
    """Reset the module-level timing state.

    Replaces ``total_time`` with a fresh, empty ``defaultdict(int)`` and sets
    ``prev_time`` to the current ``time.time()``.
    """
    global prev_time, total_time
    total_time = defaultdict(int)
    prev_time = time.time()
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
def print_time(what: str):
    """Print and accumulate the time elapsed since the last checkpoint.

    The elapsed time is ``time.time() - prev_time``. It is printed as
    ``"<what>: <seconds> seconds"`` with two decimals and added to
    ``total_time[what]``. ``prev_time`` is then set to the current time, read
    after the print.

    Args:
        what: Label for the step that just finished.
    """
    global prev_time
    elapsed = time.time() - prev_time
    print(f"{what}: {elapsed:.2f} seconds")
    total_time[what] = total_time[what] + elapsed
    prev_time = time.time()
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
def print_total_time():
    """Print a summary of the per-label times held in ``total_time``.

    Output is a header line, one ``"<label>: <seconds> seconds"`` line per
    entry in insertion order, the grand total, and a footer line. The total
    also shows minutes when it exceeds 60 seconds.
    """
    print("\n======== Total time ========")
    for label, seconds in total_time.items():
        print(f"{label}: {seconds:.2f} seconds")
    total = sum(total_time.values())
    summary = f"\nTotal: {total:.2f} seconds"
    if total > 60:
        summary += f" ({total/60:.2f} minutes)"
    print(summary)
    print("============================\n")
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
class PmappingsOneEinsum:
    """Pairs one Einsum's name with its list of pmapping groups.

    Indexing the holder indexes straight into ``pmapping_groups``.
    """

    def __init__(self, einsum_name: str, pm_group_list: list[PmappingGroup]):
        """Store the name and groups, and derive ``tensor_names``.

        Args:
            einsum_name: Name of the Einsum the groups belong to.
            pm_group_list: The groups for that Einsum. The list is stored
                as-is, not copied.

        Raises:
            IndexError: If ``pm_group_list`` is empty.
        """
        self.einsum_name: str = einsum_name
        self.pmapping_groups: list[PmappingGroup] = pm_group_list
        # Only the first group is consulted for the tensor names.
        # presumably every group of one Einsum lists the same tensors -- TODO confirm
        first_group = pm_group_list[0]
        self.tensor_names: set[str] = set(first_group.tensor_names)

    def __getitem__(self, i):
        """Return ``self.pmapping_groups[i]``."""
        return self.pmapping_groups[i]
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
def make_full_equivalent_rank_variables(pairwise_equivalent_rank_variables):
    """Expand pairwise rank-variable equivalences to their transitive closure.

    Each rank variable's set is grown until it contains everything reachable
    by repeatedly following the equivalence sets of its members. The input
    mapping and its sets are not modified.

    Args:
        pairwise_equivalent_rank_variables: Rank variable -> iterable of
            directly equivalent rank variables.

    Returns:
        A new dict with the same keys, mapping each rank variable to the set
        of all rank variables reachable from it.

    Raises:
        KeyError: If a reachable rank variable is not itself a key.
    """
    closure = {
        rank: set(peers) for rank, peers in pairwise_equivalent_rank_variables.items()
    }
    while True:
        grew = False
        for reachable in closure.values():
            # Snapshot the members: the set is extended while we walk it.
            for via in tuple(reachable):
                missing = closure[via] - reachable
                if missing:
                    reachable |= missing
                    grew = True
        if not grew:
            return closure
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
def quick_join(
    pmapping_groups: dict[str, PmappingGroup],
    mapspace_globals: MapspaceGlobals,
):
    """Join the pmapping groups of all Einsums, left to right, into one result.

    The Einsums are processed in the iteration order of ``pmapping_groups``.
    Every Einsum's groups are first consolidated and grouped. The running
    "left" side is then merged with each following Einsum in turn, and a final
    consolidate/combine must leave exactly one group.

    Args:
        pmapping_groups: Einsum name -> groups for that Einsum.
            NOTE(review): each value is indexed with ``[0]`` and sorted, so it
            is presumably a list of ``PmappingGroup`` rather than a single
            group as annotated -- confirm.
        mapspace_globals: Supplies ``resource2capacity``,
            ``pairwise_equivalent_ranks``, ``full_equivalent_ranks``,
            ``aliased_tensors`` and ``einsum2ranks``.

    Returns:
        The ``mappings`` attribute of the single remaining group.

    Raises:
        ValueError: If ``pmapping_groups`` is empty, or if merging with some
            Einsum produces no group at all.
        AssertionError: If the final combine leaves more or fewer than one
            group.
    """
    resource2capacity = mapspace_globals.resource2capacity
    pairwise_equivalent_rank_variables = mapspace_globals.pairwise_equivalent_ranks
    aliased_tensors = mapspace_globals.aliased_tensors
    full_equivalent_rank_variables = mapspace_globals.full_equivalent_ranks

    # NOTE(review): the bookkeeping below is filled in but never read or
    # returned by this function.
    n_mappings = {}
    runtime = {}
    nbuckets = []
    n_evaluations = 0

    named_groups = list(pmapping_groups.items())

    init_print_time()

    holders = [PmappingsOneEinsum(*pair) for pair in named_groups]
    if not holders:
        raise ValueError("No PmappingGroups to join")

    # ======================================================================
    # Initial consolidate and group all PmappingGroups
    # ======================================================================
    for index, holder in enumerate(holders):
        tensors_after = set().union(*(h.tensor_names for h in holders[index + 1 :]))
        if index == 0:
            # The first Einsum has nothing to its left: left-consolidate only.
            holder.pmapping_groups = PmappingGroup.left_consolidate(
                holder.pmapping_groups,
                tensors_after,
            )
            continue

        started = time.time()
        tensors_before = set().union(*(h.tensor_names for h in holders[:index]))
        shared_with_left = tensors_before & holder.tensor_names

        # Largest groups (most mapping rows) first.
        groups = sorted(
            holder.pmapping_groups,
            key=lambda group: len(group.mappings.data),
            reverse=True,
        )
        groups = PmappingGroup.right_consolidate(groups, tensors_after, shared_with_left)
        groups = PmappingGroup.combine_combineable(groups, tensors_before | tensors_after)
        groups = PmappingGroup.group_right(groups, tensors_before, drop_tags=True)
        holder.pmapping_groups = groups

        einsum, prev_einsum = holder.einsum_name, holders[index - 1].einsum_name
        runtime[f"{prev_einsum} → {einsum}"] = time.time() - started

    n_iterations = 0
    total_iterations = len(holders)

    def take_next_holder() -> (
        tuple[dict[Compatibility, list[PmappingGroup]], str, set[str]]
    ):
        # Pop the leftmost remaining Einsum and unpack what the join needs.
        nonlocal n_iterations
        n_iterations += 1
        holder = holders.pop(0)
        return holder.pmapping_groups, holder.einsum_name, holder.tensor_names

    # ``holders`` is known to be non-empty here (checked above).
    left, left_einsum, left_tensors = take_next_holder()

    partial_mapping_size = 1
    while holders:
        started = time.time()
        # ==================================================================
        # Grab the next Einsum from the right. Record logging data and find
        # the tensors that will still be live after this Einsum.
        # ==================================================================
        nbuckets.append(len(left))
        right, right_einsum, right_tensors = take_next_holder()
        right_rank_variables = mapspace_globals.einsum2ranks[right_einsum]

        partial_mapping_size += 1

        live_tensors = set().union(*(h.tensor_names for h in holders))
        shared_tensors = set(left_tensors) & set(right_tensors)
        live_tensors_with_right = live_tensors | right_tensors

        # ==================================================================
        # Clean up the previously-combined PmappingGroups: combine them, then
        # group them into buckets keyed for the right-hand Einsum.
        # ==================================================================
        left = PmappingGroup.combine_combineable(
            left,
            live_tensors | right_tensors,
        )
        left = PmappingGroup.group_left(left, right_tensors, drop_tags=True)

        # ==================================================================
        # Remove dead tensors from left and right. This happens after grouping
        # because space for a shared tensor is only reserved once it is dead,
        # in case the tensor's lifetime extends beyond the Einsums that use
        # it.
        # ==================================================================
        groups_to_clean = []
        for buckets in (left, right):
            for bucket in buckets.values():
                groups_to_clean.extend(bucket)
        PmappingGroup.remove_dead_tensors(groups_to_clean, live_tensors)

        # ==================================================================
        # Merge the left and right buckets.
        # ==================================================================
        combined: list[PmappingGroup] = []
        for left_key in left:
            translations = get_possible_translations(
                left_key,
                pairwise_equivalent_rank_variables,
                full_equivalent_rank_variables,
                right_rank_variables,
            )
            for right_key in translations:
                candidates = itertools.product(left[left_key], right.get(right_key, []))
                for a, b in candidates:
                    if not a.compatibility.tags.are_compatible_with(b.compatibility.tags):
                        continue
                    merged = a.merge_next(
                        b,
                        live_tensors,
                        live_tensors_with_right,
                        aliased_tensors,
                        resource2capacity,
                        delay=False,
                    )
                    combined.append(merged)

        if not combined:
            raise ValueError("No match found for any group")

        # ==================================================================
        # Update left for the next iteration.
        # ==================================================================
        left = combined
        left_einsum = right_einsum
        left_tensors |= right_tensors

    # ======================================================================
    # Final consolidate and group
    # ======================================================================
    started = time.time()
    left = PmappingGroup.left_consolidate(left, None)
    s_final = PmappingGroup.combine_combineable(left, set(), drop_tags=True)
    assert len(s_final) == 1
    return s_final[0].mappings
|
|
@@ -0,0 +1,298 @@
|
|
|
1
|
+
from collections import defaultdict
|
|
2
|
+
import itertools
|
|
3
|
+
|
|
4
|
+
from fastfusion.frontend import arch
|
|
5
|
+
from fastfusion.frontend.spec import Spec
|
|
6
|
+
from fastfusion.mapper.FFM._join_pmappings.join_pmappings import PmappingGroup
|
|
7
|
+
from fastfusion.mapper.FFM._join_pmappings.compatibility import Loop, Compatibility
|
|
8
|
+
from fastfusion.util._frozenset import fzs
|
|
9
|
+
from fastfusion.mapper.FFM._join_pmappings.join_pmappings import (
|
|
10
|
+
make_full_equivalent_rank_variables,
|
|
11
|
+
)
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class MapspaceGlobals:
|
|
15
|
+
def __init__(
    self,
    pmapping_groups: dict[str, list[PmappingGroup]],
    spec: Spec,
    objective_function_cols: list[str] = None,
    flattened_architecture: list[arch.Leaf] = None,
):
    """Precompute the lookup tables shared across a mapspace search.

    Besides filling in attributes, the constructor prunes the lists inside
    ``pmapping_groups`` IN PLACE. For every ordered pair of Einsums
    ``(left, right)`` that share at least one tensor, it removes a left group
    when none of its translated compatibilities matches a right group, and
    then does the same for the right groups against the surviving left ones.

    Args:
        pmapping_groups: Einsum name -> list of groups for that Einsum.
            presumably iterated in the same order as
            ``spec.workload.einsum_names``, since positions in the two are
            used interchangeably -- TODO confirm.
        spec: Specification providing the workload and, unless
            ``flattened_architecture`` is given, the flattened architecture.
        objective_function_cols: Stored unchanged on the instance.
        flattened_architecture: Architecture leaves to read memory sizes
            from. A falsy value means "ask ``spec`` for it".

    Raises:
        AssertionError: If pruning empties the group list of some Einsum.
        ZeroDivisionError: If the groups hold no mapping rows at all.
    """
    workload = spec.workload

    self.pmapping_groups = pmapping_groups
    self.einsum_names = workload.einsum_names
    self.einsum2ranks = {
        name: workload.einsums[name].rank_variables for name in self.einsum_names
    }
    self.einsum2tensors = {
        name: workload.einsums[name].tensor_names for name in self.einsum_names
    }
    self.tensor_names = set().union(
        *(self.einsum2tensors[name] for name in self.einsum_names)
    )
    self.tensor_names_used_in_multiple_einsums = (
        workload.tensor_names_used_in_multiple_einsums
    )
    self.pairwise_equivalent_ranks = workload.get_pairwise_equivalent_rank_variables()
    self.full_equivalent_ranks = make_full_equivalent_rank_variables(
        self.pairwise_equivalent_ranks
    )

    # Capacity table: one entry per memory leaf of the architecture.
    flattened_architecture = flattened_architecture or spec.get_flattened_architecture()
    self.resource2capacity = {
        leaf.name: leaf.size
        for leaf in flattened_architecture
        if isinstance(leaf, arch.Memory)
    }

    self.objective_function_cols = objective_function_cols
    self.rank_translations = self._create_rank_translations(self.einsum2ranks)

    def reduced_tilings(sims, other_live, other_tensors):
        # Compatibilities of ``sims`` with tensors dead on the other side
        # cleared, collected into a set for membership tests.
        return {
            s.compatibility.clear_dead_tensors(
                live_tensors=other_live
            ).clear_dead_tensors(other_tensors, keep_loops=True)
            for s in sims
        }

    def prune(sims, target_id, target_live, target_tensors, target_tilings):
        # Drop, in place, every group with no translation into ``target_id``
        # that matches one of ``target_tilings``.
        for s in list(sims):
            for t in self.get_possible_translations(s.compatibility, target_id):
                t = t.clear_dead_tensors(live_tensors=target_live)
                t = t.clear_dead_tensors(live_tensors=target_tensors, keep_loops=True)
                if t in target_tilings:
                    break
            else:
                # Reached only when no translation matched.
                sims.remove(s)

    # Every pair of Einsums with the left one strictly earlier.
    einsum_pairs = itertools.combinations(enumerate(pmapping_groups.items()), 2)
    for (i, (left_id, left_sims)), (j, (right_id, right_sims)) in einsum_pairs:
        left_live = self.get_live_tensors(*self.einsum_names[: i + 1])
        right_live = self.get_live_tensors(*self.einsum_names[j:])
        left_tensors = self.get_tensors(self.einsum_names[i])
        right_tensors = self.get_tensors(self.einsum_names[j])

        # Nothing flows between the two sides: no constraint to enforce.
        if not (left_live & right_live):
            continue
        print(f"Checking {left_id} {right_id}")

        right_tilings = reduced_tilings(right_sims, left_live, left_tensors)
        assert right_tilings, f"R {left_id} {right_id}"
        prune(left_sims, right_id, right_live, right_tensors, right_tilings)
        assert (
            left_sims
        ), f"Removed all of left {left_id} while checking right {right_id}"

        # Built from the left groups that survived the pruning above.
        left_tilings = reduced_tilings(left_sims, right_live, right_tensors)
        assert left_tilings, f"L {left_id} {right_id}"
        prune(right_sims, left_id, left_live, left_tensors, left_tilings)
        assert (
            right_sims
        ), f"Removed all of right {right_id} while checking left {left_id}"

    self.tensor2possible_loops_above = self._create_tensor2possible_loops_above()
    self.tensor2possible_loops_above_set = {
        outer_key: {inner_key: set(loops) for inner_key, loops in inner.items()}
        for outer_key, inner in self.tensor2possible_loops_above.items()
    }
    self.tensor2memories = self._create_tensor2memories()
    self.einsum_tiling_2_sim = self._create_einsum_tiling_2_sim()
    self.einsum_rank_index_to_loops = self._create_einsum_rank_index_to_loops()
    (
        self.compatibility2leftcompatibility,
        self.compatibility2rightcompatibility,
        self.leftcompatibility2tiling,
        self.rightcompatibility2tiling,
    ) = self._create_compatibility()
    self.size_scale = len(self.einsum2ranks)

    # Ratio of the groups' summed ``n_pmappings`` to their summed row counts.
    n_optimal = sum(
        len(s.mappings.data)
        for simlist in self.pmapping_groups.values()
        for s in simlist
    )
    n_pmappings = sum(
        s.mappings.n_pmappings
        for simlist in self.pmapping_groups.values()
        for s in simlist
    )
    self.find_pmapping_scale = n_pmappings / n_optimal
    self.aliased_tensors = workload.get_tensor_copies()
|
|
139
|
+
|
|
140
|
+
def get_live_tensors(self, *einsums: str):
    """Return the union of the tensor collections of the given einsums.

    Args:
        *einsums: Keys into ``self.einsum2tensors``.

    Returns:
        A new ``set`` holding every element of ``self.einsum2tensors[e]``
        for each requested einsum. An empty set is returned when no einsum
        is given.

    Raises:
        KeyError: If an einsum is not a key of ``self.einsum2tensors``.
    """
    # Start from an empty set rather than calling the unbound ``set.union``:
    # the unbound form raises TypeError with zero operands and when the first
    # operand is not a plain ``set`` (e.g. a frozenset).
    return set().union(*(self.einsum2tensors[e] for e in einsums))
|
|
142
|
+
|
|
143
|
+
def _create_compatibility(self):
    """Build forward and inverse maps between tilings and reduced tilings.

    For every einsum in ``self.pmapping_groups`` (in iteration order), each
    group's ``compatibility`` is reduced with ``clear_dead_tensors``:

    * against the tensors of all *earlier* einsums ("left"), skipped for the
      first einsum;
    * against the tensors of all *later* einsums ("right"), skipped for the
      last einsum.

    Returns:
        A 4-tuple ``(tiling2left, tiling2right, left2tiling, right2tiling)``.
        The first two map ``einsum -> {compatibility: reduced}``; the last two
        map ``einsum -> {reduced: [compatibility, ...]}``.
    """
    tiling2leftcompatibility = {}
    tiling2rightcompatibility = {}
    last_position = len(self.pmapping_groups) - 1

    for position, (einsum_name, groups) in enumerate(self.pmapping_groups.items()):
        if position > 0:
            live_before = self.get_live_tensors(*self.einsum_names[:position])
            tiling2leftcompatibility[einsum_name] = {
                c: c.clear_dead_tensors(live_tensors=live_before)
                for c in [g.compatibility for g in groups]
            }
        if position < last_position:
            live_after = self.get_live_tensors(*self.einsum_names[position + 1 :])
            tiling2rightcompatibility[einsum_name] = {
                c: c.clear_dead_tensors(live_tensors=live_after)
                for c in [g.compatibility for g in groups]
            }

    # Invert each forward map: reduced compatibility -> list of full tilings.
    leftcompatibility2tiling = {}
    rightcompatibility2tiling = {}
    directions = (
        (tiling2leftcompatibility, leftcompatibility2tiling),
        (tiling2rightcompatibility, rightcompatibility2tiling),
    )
    for einsum_name in self.einsum_names:
        for forward, inverse in directions:
            if einsum_name in forward:
                buckets = inverse.setdefault(einsum_name, {})
                for tiling, reduced in forward[einsum_name].items():
                    buckets.setdefault(reduced, []).append(tiling)

    return (
        tiling2leftcompatibility,
        tiling2rightcompatibility,
        leftcompatibility2tiling,
        rightcompatibility2tiling,
    )
|
|
182
|
+
|
|
183
|
+
def _create_einsum_tiling_2_sim(self):
    """Merge, per einsum, all groups that share the same compatibility.

    Returns:
        ``{einsum: {compatibility: merged}}`` where ``merged`` is the result
        of ``PmappingGroup.concat`` applied to the list of groups of that
        einsum whose ``compatibility`` attribute is equal.
    """
    merged_by_einsum = {}
    for einsum_name, groups in self.pmapping_groups.items():
        # Bucket the groups by compatibility, keeping first-seen order.
        buckets = {}
        for group in groups:
            buckets.setdefault(group.compatibility, []).append(group)
        merged_by_einsum[einsum_name] = {
            compatibility: PmappingGroup.concat(members)
            for compatibility, members in buckets.items()
        }
    return merged_by_einsum
|
|
194
|
+
|
|
195
|
+
def _create_tensor2possible_loops_above(self):
    """Collect, per einsum and tensor, the loops seen in front of that tensor.

    For every group of every einsum, each entry of ``compatibility.tensors``
    contributes the prefix ``compatibility.loops[:tensor.above_loop_index]``;
    contributions for the same tensor are unioned across groups.

    Returns:
        ``{einsum: {tensor: [loop, ...]}}``. The lists are built from sets, so
        they hold no duplicates and their order is unspecified.
    """
    gathered = {}
    for einsum_name, groups in self.pmapping_groups.items():
        per_tensor = gathered[einsum_name] = {}
        for group in groups:
            compatibility = group.compatibility
            for tensor in compatibility.tensors:
                seen = per_tensor.setdefault(tensor, set())
                seen |= set(compatibility.loops[: tensor.above_loop_index])

    # Freeze the accumulated sets into plain lists for the caller.
    return {
        einsum_name: {tensor: list(loops) for tensor, loops in tensor_map.items()}
        for einsum_name, tensor_map in gathered.items()
    }
|
|
208
|
+
|
|
209
|
+
def _create_tensor2memories(self):
    """Find, for each shared tensor, the entries common to all its einsums.

    For every name in ``self.tensor_names_used_in_multiple_einsums`` and every
    einsum whose first group lists that name in ``tensor_names``, the values
    of ``compatibility.get_tensor_by_name(name)`` over all groups are
    collected into a set. The per-einsum sets are then intersected.

    Returns:
        ``{tensor_name: [entry, ...]}`` with the intersected entries as a
        list (order unspecified).

    Raises:
        ValueError: If no einsum's first group lists the tensor name.
    """
    tensor2memories = {}
    for t in self.tensor_names_used_in_multiple_einsums:
        # Only the first group is consulted to decide whether an einsum
        # touches the tensor at all.
        choices_per_einsum = [
            {group.compatibility.get_tensor_by_name(t) for group in groups}
            for groups in self.pmapping_groups.values()
            if t in groups[0].tensor_names
        ]
        if not choices_per_einsum:
            raise ValueError(f"No memories for {t}")
        tensor2memories[t] = list(set.intersection(*choices_per_einsum))
    return tensor2memories
|
|
226
|
+
|
|
227
|
+
def _create_rank_translations(self, einsum2ranks: dict[str, set[str]]):
    """Tabulate which ranks of one einsum are equivalent to ranks of another.

    Args:
        einsum2ranks: Maps each einsum name to the set of its ranks.

    Returns:
        ``{source: {target: {rank: [equivalent, ...]}}}`` where ``rank`` runs
        over ``einsum2ranks[source]`` and the list holds the members of
        ``self.full_equivalent_ranks[rank]`` that also occur in
        ``einsum2ranks[target]``. Every name in ``self.einsum_names`` appears
        as a ``target`` key, with an empty dict if it has no entry in
        ``einsum2ranks``.
    """
    rank_translations = {}
    for source, source_ranks in einsum2ranks.items():
        per_target = {target: {} for target in self.einsum_names}
        for target, target_ranks in einsum2ranks.items():
            for rank in source_ranks:
                per_target[target][rank] = list(
                    self.full_equivalent_ranks[rank] & target_ranks
                )
        rank_translations[source] = per_target
    return rank_translations
|
|
240
|
+
|
|
241
|
+
def _create_full_equivalent_ranks(
    self, pairwise_equivalent_ranks: dict[str, set[str]]
):
    """Extend pairwise rank equivalences to their transitive closure.

    Args:
        pairwise_equivalent_ranks: Maps a rank to the ranks directly
            equivalent to it. Every rank that appears in a value must also
            be a key. The argument itself is not modified.

    Returns:
        A new dict mapping each rank to the set of all ranks reachable by
        following equivalences any number of times.

    Raises:
        KeyError: If a rank appears in a value but not as a key.
    """
    closure = {rank: set(peers) for rank, peers in pairwise_equivalent_ranks.items()}

    # Keep propagating until a full pass adds nothing new.
    grew = True
    while grew:
        grew = False
        for reachable in closure.values():
            # Snapshot: ``reachable`` is extended while we walk it.
            for via in tuple(reachable):
                missing = closure[via] - reachable
                if missing:
                    reachable |= missing
                    grew = True
    return closure
|
|
258
|
+
|
|
259
|
+
def _create_einsum_rank_index_to_loops(
    self,
) -> dict[str, dict[str, dict[int, list[Loop]]]]:
    """Index every loop by einsum, rank variable name and loop position.

    Returns:
        ``{einsum: {rank_variable_name: {position: [loop, ...]}}}`` where
        ``position`` is the index of the loop within its group's
        ``compatibility.loops``. Loops from different groups that share a
        rank variable name and position are appended to the same list, in
        group order, without de-duplication.
    """
    index = {}
    for einsum_name, groups in self.pmapping_groups.items():
        by_rank = index[einsum_name] = {}
        for group in groups:
            for position, loop in enumerate(group.compatibility.loops):
                by_position = by_rank.setdefault(loop.rank_variable_name, {})
                by_position.setdefault(position, []).append(loop)
    return index
|
|
272
|
+
|
|
273
|
+
def get_tensors(self, *einsums: str):
    """Return the union of the tensor collections of the given einsums.

    Args:
        *einsums: Keys into ``self.einsum2tensors``.

    Returns:
        A new ``set`` holding every element of ``self.einsum2tensors[e]``
        for each requested einsum. An empty set is returned when no einsum
        is given.

    Raises:
        KeyError: If an einsum is not a key of ``self.einsum2tensors``.
    """
    # ``set().union(...)`` tolerates zero operands and non-``set`` iterables
    # (e.g. frozensets); the unbound ``set.union(*...)`` raises TypeError in
    # both situations.
    return set().union(*(self.einsum2tensors[e] for e in einsums))
|
|
275
|
+
|
|
276
|
+
def get_possible_translations(self, t: Compatibility, to_einsum: str):
    """Yield versions of ``t`` whose loops are renamed to ranks of ``to_einsum``.

    Each loop of ``t`` is replaced by a single-rank loop for every rank of
    ``to_einsum`` that is in ``self.full_equivalent_ranks`` of one of the
    loop's rank variable names. A loop with more than one match in
    ``self.pairwise_equivalent_ranks`` among the target ranks contributes no
    candidates, so nothing is yielded for ``t``.

    Args:
        t: Compatibility whose ``loops``, ``tensors`` and ``tags`` are read.
        to_einsum: Key into ``self.einsum2ranks`` naming the target einsum.

    Yields:
        One ``Compatibility`` per combination of per-loop candidates (the
        Cartesian product), with ``t.tensors`` and ``t.tags`` passed through
        unchanged.
    """
    pairwise_lookup = self.pairwise_equivalent_ranks
    full_lookup = self.full_equivalent_ranks
    target_ranks = self.einsum2ranks[to_einsum]

    def candidates_for(l: Loop):
        # Ranks of the target einsum reachable through full equivalence.
        reachable = (
            set.union(*(full_lookup[name] for name in l.rank_variable_names))
            & target_ranks
        )
        # Ranks of the target einsum that are directly (pairwise) equivalent.
        direct = (
            set.union(*(pairwise_lookup[name] for name in l.rank_variable_names))
            & target_ranks
        )
        # Several direct matches: produce no candidates for this loop.
        if len(direct) <= 1:
            for rank in reachable:
                # NOTE(review): ``fzs`` is presumably a frozenset alias; confirm.
                yield Loop(fzs((rank,)), l.bound, l.is_spatial)

    per_loop_candidates = (candidates_for(loop) for loop in t.loops)
    for loops in itertools.product(*per_loop_candidates):
        yield Compatibility(loops, t.tensors, t.tags)
|