accelforge 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- accelforge/__init__.py +21 -0
- accelforge/_accelerated_imports.py +16 -0
- accelforge/_deprecate/_simanneal/evalmapping.py +271 -0
- accelforge/_deprecate/_simanneal/mapspaceglobals.py +298 -0
- accelforge/_deprecate/_simanneal/simanneal.py +666 -0
- accelforge/_deprecate/_simanneal/tracking.py +105 -0
- accelforge/_deprecate/_simanneal/wrappers.py +218 -0
- accelforge/_deprecate/_simanneal2/__init__.py +7 -0
- accelforge/_deprecate/_simanneal2/simanneal.py +493 -0
- accelforge/_deprecate/_simanneal2/tracking.py +116 -0
- accelforge/_deprecate/compatibility_util.py +181 -0
- accelforge/_deprecate/layerdeduplication/__init__.py +2 -0
- accelforge/_deprecate/layerdeduplication/group_similar_einsums.py +160 -0
- accelforge/_deprecate/layerdeduplication/grouped_einsums.py +84 -0
- accelforge/_deprecate/mapping_filter_tags/__init__.py +2 -0
- accelforge/_deprecate/mapping_filter_tags/ffmt.py +212 -0
- accelforge/_deprecate/mapping_filter_tags/onesplit.py +24 -0
- accelforge/_deprecate/mapping_filter_tags/util.py +24 -0
- accelforge/_deprecate/tags.py +69 -0
- accelforge/_deprecate/viz/__init__.py +0 -0
- accelforge/_deprecate/viz/interactive.py +159 -0
- accelforge/_deprecate/viz/reservationtree.py +307 -0
- accelforge/_deprecate/viz/ski_slope.py +88 -0
- accelforge/_version.py +15 -0
- accelforge/examples.py +39 -0
- accelforge/frontend/__init__.py +10 -0
- accelforge/frontend/_binding.py +129 -0
- accelforge/frontend/_workload_isl/__init__.py +2 -0
- accelforge/frontend/_workload_isl/_isl.py +149 -0
- accelforge/frontend/_workload_isl/_symbolic.py +141 -0
- accelforge/frontend/arch copy.py +1544 -0
- accelforge/frontend/arch.py +1642 -0
- accelforge/frontend/config.py +63 -0
- accelforge/frontend/mapper/__init__.py +5 -0
- accelforge/frontend/mapper/ffm.py +126 -0
- accelforge/frontend/mapper/mapper.py +7 -0
- accelforge/frontend/mapper/metrics.py +30 -0
- accelforge/frontend/mapping/__init__.py +1 -0
- accelforge/frontend/mapping/mapping.py +1736 -0
- accelforge/frontend/model.py +14 -0
- accelforge/frontend/renames.py +150 -0
- accelforge/frontend/spec copy.py +230 -0
- accelforge/frontend/spec.py +301 -0
- accelforge/frontend/variables.py +12 -0
- accelforge/frontend/workload.py +952 -0
- accelforge/mapper/FFM/__init__.py +9 -0
- accelforge/mapper/FFM/_join_pmappings/__init__.py +0 -0
- accelforge/mapper/FFM/_join_pmappings/compatibility.py +653 -0
- accelforge/mapper/FFM/_join_pmappings/compress_pmappings.py +140 -0
- accelforge/mapper/FFM/_join_pmappings/join_pmappings.py +703 -0
- accelforge/mapper/FFM/_join_pmappings/pmapping_dataframe.py +901 -0
- accelforge/mapper/FFM/_join_pmappings/pmapping_group.py +337 -0
- accelforge/mapper/FFM/_make_pmappings/contraints/__init__.py +0 -0
- accelforge/mapper/FFM/_make_pmappings/contraints/constraints.py +360 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/__init__.py +1 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_loops.py +373 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_pmapping_templates.py +463 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_reservations.py +95 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_storage_order.py +382 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_storages.py +155 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmappings.py +411 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/__init__.py +1 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/make_pmappings_from_templates.py +407 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/make_tile_shapes.py +1681 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/run_model.py +170 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/symbol_relations.py +174 -0
- accelforge/mapper/FFM/_make_pmappings/pmapper_job.py +282 -0
- accelforge/mapper/FFM/_pareto_df/df_convention.py +273 -0
- accelforge/mapper/FFM/_pareto_df/pareto copy.py +836 -0
- accelforge/mapper/FFM/_pareto_df/pareto.py +508 -0
- accelforge/mapper/FFM/data.py +61 -0
- accelforge/mapper/FFM/main copy.py +236 -0
- accelforge/mapper/FFM/main.py +208 -0
- accelforge/mapper/FFM/mappings.py +510 -0
- accelforge/mapper/FFM/pmappings.py +310 -0
- accelforge/mapper/__init__.py +4 -0
- accelforge/mapper.py +0 -0
- accelforge/model/__init__.py +1 -0
- accelforge/model/_looptree/__init__.py +0 -0
- accelforge/model/_looptree/accesses.py +335 -0
- accelforge/model/_looptree/capacity/__init__.py +1 -0
- accelforge/model/_looptree/capacity/aggregators.py +36 -0
- accelforge/model/_looptree/capacity/capacity.py +47 -0
- accelforge/model/_looptree/energy.py +150 -0
- accelforge/model/_looptree/equivalent_ranks.py +29 -0
- accelforge/model/_looptree/latency/__init__.py +1 -0
- accelforge/model/_looptree/latency/latency.py +98 -0
- accelforge/model/_looptree/latency/memory.py +120 -0
- accelforge/model/_looptree/latency/processors.py +92 -0
- accelforge/model/_looptree/mapping_utilities.py +71 -0
- accelforge/model/_looptree/reuse/__init__.py +4 -0
- accelforge/model/_looptree/reuse/isl/__init__.py +1 -0
- accelforge/model/_looptree/reuse/isl/des.py +59 -0
- accelforge/model/_looptree/reuse/isl/isl_functions.py +374 -0
- accelforge/model/_looptree/reuse/isl/mapping_to_isl/__init__.py +4 -0
- accelforge/model/_looptree/reuse/isl/mapping_to_isl/analyze_mapping.py +297 -0
- accelforge/model/_looptree/reuse/isl/mapping_to_isl/skews_from_mapping.py +236 -0
- accelforge/model/_looptree/reuse/isl/mapping_to_isl/tiling.py +685 -0
- accelforge/model/_looptree/reuse/isl/mapping_to_isl/types.py +188 -0
- accelforge/model/_looptree/reuse/isl/spatial.py +260 -0
- accelforge/model/_looptree/reuse/isl/temporal.py +182 -0
- accelforge/model/_looptree/reuse/symbolic/__init__.py +1 -0
- accelforge/model/_looptree/reuse/symbolic/symbolic copy 2.py +1346 -0
- accelforge/model/_looptree/reuse/symbolic/symbolic copy.py +1408 -0
- accelforge/model/_looptree/reuse/symbolic/symbolic.py +1396 -0
- accelforge/model/_looptree/run.py +122 -0
- accelforge/model/_looptree/types.py +26 -0
- accelforge/model/_looptree/visualization/__init__.py +0 -0
- accelforge/model/_looptree/visualization/occupancy.py +11 -0
- accelforge/model/main.py +222 -0
- accelforge/plotting/__init__.py +2 -0
- accelforge/plotting/mappings.py +219 -0
- accelforge/plotting/specs.py +57 -0
- accelforge/util/__init__.py +4 -0
- accelforge/util/_base_analysis_types.py +24 -0
- accelforge/util/_basetypes.py +1089 -0
- accelforge/util/_frozenset.py +36 -0
- accelforge/util/_isl.py +29 -0
- accelforge/util/_itertools.py +14 -0
- accelforge/util/_mathfuncs.py +57 -0
- accelforge/util/_parse_expressions.py +339 -0
- accelforge/util/_picklecache.py +32 -0
- accelforge/util/_setexpressions.py +268 -0
- accelforge/util/_sympy/__init__.py +0 -0
- accelforge/util/_sympy/broadcast_max.py +18 -0
- accelforge/util/_visualization.py +112 -0
- accelforge/util/_yaml.py +579 -0
- accelforge/util/parallel.py +193 -0
- accelforge-0.0.1.dist-info/METADATA +64 -0
- accelforge-0.0.1.dist-info/RECORD +258 -0
- accelforge-0.0.1.dist-info/WHEEL +5 -0
- accelforge-0.0.1.dist-info/licenses/LICENSE +19 -0
- accelforge-0.0.1.dist-info/top_level.txt +5 -0
- docs/_build/html/_sources/fastfusion.frontend.mapper.rst.txt +37 -0
- docs/_build/html/_sources/fastfusion.frontend.rst.txt +70 -0
- docs/_build/html/_sources/fastfusion.frontend.workload.rst.txt +21 -0
- docs/_build/html/_sources/fastfusion.mapper.FFM.rst.txt +37 -0
- docs/_build/html/_sources/fastfusion.mapper.rst.txt +18 -0
- docs/_build/html/_sources/fastfusion.rst.txt +20 -0
- docs/_build/html/_sources/fastfusion.util.rst.txt +21 -0
- docs/_build/html/_sources/index.rst.txt +87 -0
- docs/_build/html/_sources/modules.rst.txt +7 -0
- docs/_build/html/_sources/notes/citation.rst.txt +45 -0
- docs/_build/html/_sources/notes/definitions.rst.txt +43 -0
- docs/_build/html/_sources/notes/faqs.rst.txt +39 -0
- docs/_build/html/_sources/notes/modeling/accelerator_energy_latency.rst.txt +72 -0
- docs/_build/html/_sources/notes/modeling/component_energy_area.rst.txt +96 -0
- docs/_build/html/_sources/notes/modeling/mapping.rst.txt +100 -0
- docs/_build/html/_sources/notes/modeling.rst.txt +33 -0
- docs/_build/html/_sources/notes/parsing/arithmetic_parsing.rst.txt +136 -0
- docs/_build/html/_sources/notes/parsing/setexpressions.rst.txt +63 -0
- docs/_build/html/_sources/notes/parsing/yaml_parsing.rst.txt +176 -0
- docs/_build/html/_sources/notes/quickstart_and_installation.rst.txt +9 -0
- docs/_build/html/_sources/notes/spec/architecture.rst.txt +133 -0
- docs/_build/html/_sources/notes/spec/mapping.rst.txt +12 -0
- docs/_build/html/_sources/notes/spec/workload.rst.txt +83 -0
- docs/_build/html/_sources/notes/spec.rst.txt +36 -0
- docs/source/_ext/include_attrs.py +213 -0
- docs/source/_ext/include_docstring.py +364 -0
- docs/source/_ext/include_functions.py +154 -0
- docs/source/_ext/include_notebook.py +131 -0
- docs/source/_ext/include_yaml.py +119 -0
- docs/source/_ext/inherited_attributes.py +222 -0
- docs/source/_ext/paths.py +4 -0
- docs/source/conf.py +79 -0
- examples/arches/compute_in_memory/_include.yaml +74 -0
- examples/arches/compute_in_memory/_include_functions.py +229 -0
- examples/arches/compute_in_memory/_load_spec.py +57 -0
- examples/arches/compute_in_memory/components/c2c_multiplier.py +181 -0
- examples/arches/compute_in_memory/components/dac_c2c_r2r.py +605 -0
- examples/arches/compute_in_memory/components/misc.py +195 -0
- examples/arches/compute_in_memory/components/util/bit_functions.py +51 -0
- examples/arches/compute_in_memory/components/zero_comparator.py +92 -0
- examples/arches/compute_in_memory/isaac.yaml +233 -0
- examples/arches/compute_in_memory/memory_cells/ecram_demo.yaml +63 -0
- examples/arches/compute_in_memory/memory_cells/rram_example.yaml +63 -0
- examples/arches/compute_in_memory/memory_cells/rram_isaac_isca_2016.yaml +64 -0
- examples/arches/compute_in_memory/memory_cells/rram_neurosim_default.yaml +63 -0
- examples/arches/compute_in_memory/memory_cells/rram_raella_isca_2023.yaml +70 -0
- examples/arches/compute_in_memory/memory_cells/rram_wan_nature_2022.yaml +63 -0
- examples/arches/compute_in_memory/memory_cells/sram_colonnade_jssc_2021.yaml +63 -0
- examples/arches/compute_in_memory/memory_cells/sram_example.yaml +63 -0
- examples/arches/compute_in_memory/memory_cells/sram_jia_jssc_2020.yaml +63 -0
- examples/arches/compute_in_memory/memory_cells/sram_sinangil_jssc_2021.yaml +63 -0
- examples/arches/compute_in_memory/memory_cells/sram_wang_vlsi_2022.yaml +63 -0
- examples/arches/compute_in_memory/wang_vlsi_2022.yaml +289 -0
- examples/arches/eyeriss.yaml +68 -0
- examples/arches/fanout_variations/at_glb.yaml +31 -0
- examples/arches/fanout_variations/at_glb_with_fanout_node.yaml +34 -0
- examples/arches/fanout_variations/at_mac.yaml +31 -0
- examples/arches/fanout_variations/at_mac_with_constraints.yaml +38 -0
- examples/arches/fanout_variations/at_mac_with_fanout_node.yaml +34 -0
- examples/arches/nvdla.yaml +47 -0
- examples/arches/simple.yaml +28 -0
- examples/arches/tpu_v4i.yaml +67 -0
- examples/mappings/unfused_matmuls_to_simple.yaml +33 -0
- examples/misc/component_annotated.yaml +33 -0
- examples/workloads/gpt3_6.7B.yaml +124 -0
- examples/workloads/matmuls.yaml +20 -0
- examples/workloads/mobilenet_28.yaml +81 -0
- examples/workloads/mobilenet_various_separate.yaml +106 -0
- examples/workloads/three_matmuls_annotated.yaml +59 -0
- notebooks/.ipynb_checkpoints/fastfusion_arch_study_michael-checkpoint.ipynb +359 -0
- notebooks/compute_in_memory/_scripts.py +339 -0
- notebooks/compute_in_memory/isaac.guide.ipynb +270 -0
- notebooks/compute_in_memory/wang_vlsi_2022.ipynb +602 -0
- notebooks/paths.py +4 -0
- notebooks/tutorials/.ipynb_checkpoints/1_FFM-checkpoint.ipynb +3110 -0
- notebooks/tutorials/FFM.ipynb +3498 -0
- notebooks/tutorials/_include.py +48 -0
- notebooks/tutorials/component_energy_area.ipynb +363 -0
- tests/Q_mapping.yaml +38 -0
- tests/__init__.py +0 -0
- tests/conv.mapping.yaml +27 -0
- tests/conv.workload.yaml +13 -0
- tests/conv_sym.mapping.yaml +43 -0
- tests/copy.mapping.yaml +35 -0
- tests/copy.workload.yaml +15 -0
- tests/distribuffers/__init__.py +0 -0
- tests/distribuffers/multicast/test_cases.yaml +482 -0
- tests/distribuffers/spec/binding/valid_bindings.yaml +97 -0
- tests/distribuffers/spec/distributed.yaml +100 -0
- tests/distribuffers/spec/logical_arch.yaml +32 -0
- tests/distribuffers/spec/physical_arch.yaml +69 -0
- tests/distribuffers/test_binding.py +48 -0
- tests/frontend/__init__.py +0 -0
- tests/frontend/test_mapping_viz.py +52 -0
- tests/mapper/__init__.py +0 -0
- tests/mapper/configs/conv1d/conv1d.mapping.yaml +31 -0
- tests/mapper/configs/conv1d/conv1d.workload.yaml +11 -0
- tests/mapper/configs/two_conv1d/two_conv1d.expected.yaml +38 -0
- tests/mapper/configs/two_conv1d/two_conv1d.mapping.yaml +54 -0
- tests/mapper/configs/two_conv1d/two_conv1d.workload.yaml +19 -0
- tests/mapper/test_mapping_to_isl.py +90 -0
- tests/mapper/test_spatial_reuse_analysis.py +67 -0
- tests/mapper/test_temporal_reuse_analysis.py +56 -0
- tests/mapper/util.py +58 -0
- tests/matmul.mapping.yaml +29 -0
- tests/matmul.workload.yaml +12 -0
- tests/matmul_spatial.mapping.yaml +44 -0
- tests/mha.renames.yaml +65 -0
- tests/mha.workload.yaml +67 -0
- tests/mha.yaml +59 -0
- tests/mha_full.workload.yaml +67 -0
- tests/mobilenet.workload.yaml +35 -0
- tests/mobilenet_long.workload.yaml +64 -0
- tests/pmappingcache.py +24 -0
- tests/processing_stage.arch.yaml +40 -0
- tests/snowcat.arch.yaml +36 -0
- tests/test_ffm_join_pmappings.py +106 -0
- tests/test_ffm_make_pmappings.py +82 -0
- tests/test_ffm_make_tile_shapes.py +49 -0
- tests/test_mapper.py +100 -0
- tests/test_model.py +37 -0
- tests/test_plotting.py +72 -0
- tests/test_processing_stage.py +46 -0
- tests/test_symbolic_model.py +248 -0
- tests/test_workload.py +141 -0
|
@@ -0,0 +1,181 @@
|
|
|
1
|
+
import itertools
|
|
2
|
+
|
|
3
|
+
from collections.abc import Iterable, Set
|
|
4
|
+
|
|
5
|
+
from fastfusion.frontend.spec import Spec
|
|
6
|
+
from fastfusion.frontend.workload import EinsumName, TensorName
|
|
7
|
+
from fastfusion.mapper.FFM._join_pmappings.compatibility import Compatibility
|
|
8
|
+
from fastfusion.mapper.FFM._join_pmappings.sim import SIM
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
DO_PRINT = False
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def myprint(*args, **kwargs):
    """Forward all arguments to `print` only when the module-level DO_PRINT flag is set."""
    if not DO_PRINT:
        return
    print(*args, **kwargs)
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def sims2untiled_compats(
    einsum2sims: dict[EinsumName, Iterable[SIM]],
) -> dict[EinsumName, set[Compatibility]]:
    """Map each Einsum name to the set of its SIMs' compatibilities with loop bounds cleared."""
    untiled: dict[EinsumName, set[Compatibility]] = {}
    for einsum_name, sims in einsum2sims.items():
        untiled[einsum_name] = {s.compatibility.clear_loop_bounds() for s in sims}
    return untiled
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def join_compatibilities(
    einsum2compatibilities: dict[EinsumName, Iterable[Compatibility]],
    spec: Spec = None,
) -> dict[EinsumName, set[Compatibility]]:
    """
    Return dict from Einsum name to compatibilities (without tile shape)
    that will ever contribute to full mappings.

    Joins Einsums pairwise left-to-right: at each step, left-side results are
    matched against the next Einsum's compatibilities, and only survivors
    continue. At the end, each surviving merged compatibility is traced back
    (via ``compat2einsum2original``) to the per-Einsum originals it came from.

    :param einsum2compatibilities: per-Einsum candidate compatibilities.
    :param spec: spec whose workload supplies each Einsum's tensor names.
    :raises ValueError: if there is nothing to join, an Einsum has no
        compatibilities/pmappings, or a join step produces no matches.

    CONTRACT FOR MAPPINGS GETTING TO THIS POINT: see `join_pmappings.join_sims`
    """
    for einsum_name, compats in einsum2compatibilities.items():
        # NOTE(review): len(c) presumably counts entries within a Compatibility;
        # an all-empty set of pmappings for an Einsum is unrecoverable here.
        if sum(len(c) for c in compats) == 0:
            raise ValueError(f"No pmappings for {einsum_name}")

    if len(einsum2compatibilities) == 0:
        raise ValueError("Nothing to join")

    for einsum_name, per_einsum_compats in einsum2compatibilities.items():
        if not per_einsum_compats:
            raise ValueError(f"No compatibility for {einsum_name}")

    # Maps every (possibly merged) compatibility back to the per-Einsum
    # original compatibilities it was built from.
    compat2einsum2original: dict[
        Compatibility, dict[EinsumName, set[Compatibility]]
    ] = {}
    for einsum_name, per_einsum_compats in einsum2compatibilities.items():
        for compat in per_einsum_compats:
            einsum2original = compat2einsum2original.setdefault(compat, {})
            original = einsum2original.setdefault(einsum_name, set())
            original.add(compat)

    compatibilities = list(einsum2compatibilities.items())

    einsum2tensor_names = {
        einsum_name: spec.workload.einsums[einsum_name].tensor_names
        for einsum_name in einsum2compatibilities
    }

    # while-loop states
    assert len(compatibilities) > 0
    left_einsum, all_left_compats = compatibilities.pop(0)
    left_tensors = einsum2tensor_names[left_einsum]
    # BUGFIX: initialize `combined` so the single-Einsum case (loop body never
    # runs) does not raise NameError below.
    combined = list(all_left_compats)

    while compatibilities:
        right_einsum, all_right_compats = compatibilities.pop(0)

        right_tensors = einsum2tensor_names[right_einsum]
        # Tensors still needed by Einsums not yet joined.
        live_tensors = set.union(
            set(), *(einsum2tensor_names[e] for e, _ in compatibilities)
        )

        grouped_left_compats = group_left(all_left_compats, right_tensors)
        grouped_right_compats = group_right(all_right_compats, left_tensors)

        combined = combine_left_and_right_compats(
            compat2einsum2original,
            grouped_left_compats,
            grouped_right_compats,
            live_tensors,
        )

        if DO_PRINT:
            print_reverse_unmatched(grouped_left_compats, grouped_right_compats)

        if not combined:
            raise ValueError("No match found for any group")

        # update while-loop states
        all_left_compats = combined
        left_einsum = right_einsum
        # BUGFIX: rebind rather than |= so the set stored in
        # einsum2tensor_names (owned by the workload) is not mutated in place.
        left_tensors = left_tensors | right_tensors

    # Trace each surviving merged compatibility back to the per-Einsum
    # originals that produced it.
    einsum2important_compatibilities: dict[EinsumName, set[Compatibility]] = {}
    for compat in combined:
        for einsum, original in compat2einsum2original[compat].items():
            important_compats = einsum2important_compatibilities.setdefault(
                einsum, set()
            )
            important_compats.update(original)
    return einsum2important_compatibilities
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
def combine_left_and_right_compats(
    compat2einsum2original: dict[Compatibility, dict[EinsumName, set[Compatibility]]],
    grouped_left_compats: dict[Compatibility, Iterable[Compatibility]],
    grouped_right_compats: dict[Compatibility, Iterable[Compatibility]],
    live_tensors: set[TensorName],
):
    """Merge every (left, right) pair of compatibilities that shares a join key
    and has compatible tags.

    Side effect: extends ``compat2einsum2original`` in place so each merged
    compatibility maps back to the union of its parents' per-Einsum originals.

    :param compat2einsum2original: provenance map, mutated in place; must
        already contain entries for every left and right compatibility.
    :param grouped_left_compats: left compatibilities bucketed by join key.
    :param grouped_right_compats: right compatibilities bucketed by join key.
    :param live_tensors: tensors still needed by Einsums not yet joined
        (forwarded to `merge_next`).
    :return: list of merged compatibilities.
    """
    combined: list[Compatibility] = []
    for left_key, left_compats in grouped_left_compats.items():
        myprint(f"Left key {left_key}")

        # Right-side candidates sharing this join key; no bucket means no match.
        compatible_right_compats = grouped_right_compats.get(left_key, [])

        if len(compatible_right_compats) == 0:
            if DO_PRINT:
                for l in left_compats:
                    print(f"\tNo match for {l}")
            continue

        # Try every cross pairing within the bucket; tags gate the merge.
        for l, r in itertools.product(left_compats, compatible_right_compats):
            if l.tags.are_compatible_with(r.tags):
                merged = l.merge_next(r, live_tensors)
                combined.append(merged)

                # Record provenance: the merged compat inherits the originals
                # of both parents, keyed by Einsum.
                einsum2original = compat2einsum2original.setdefault(merged, {})

                left_einsum2original = compat2einsum2original[l]
                right_einsum2original = compat2einsum2original[r]

                einsums = set(left_einsum2original) | set(right_einsum2original)
                for einsum in einsums:
                    einsum2original.setdefault(einsum, set()).update(
                        left_einsum2original.get(einsum, set())
                        | right_einsum2original.get(einsum, set())
                    )

                myprint(f"\t{l}\n\t<-->\n\t{r}")
                myprint(f"\t-->\n\t{merged}")
    return combined
|
|
149
|
+
|
|
150
|
+
|
|
151
|
+
def print_reverse_unmatched(
    grouped_left_compats: dict[Compatibility, Iterable[Compatibility]],
    grouped_right_compats: dict[Compatibility, Iterable[Compatibility]],
):
    """Debug helper: report right-side compatibilities whose join key has no left bucket."""
    unmatched_keys = (
        key for key in grouped_right_compats if key not in grouped_left_compats
    )
    for key in unmatched_keys:
        for compat in grouped_right_compats[key]:
            print(f"\tREVERSE: No match for {compat} using {key}")
|
|
159
|
+
|
|
160
|
+
|
|
161
|
+
def group_left(
    left_compatibilities: Iterable[Compatibility],
    right_tensors: Set[TensorName],
) -> dict[Compatibility, set[Compatibility]]:
    """Bucket left-side compatibilities by their join key.

    The key is the compatibility with tensors dead w.r.t. `right_tensors`
    cleared (loops kept, tags dropped).
    """
    buckets: dict[Compatibility, set[Compatibility]] = {}
    for candidate in left_compatibilities:
        join_key = candidate.clear_dead_tensors(
            right_tensors, keep_loops=True, drop_tags=True
        )
        if join_key in buckets:
            buckets[join_key].add(candidate)
        else:
            buckets[join_key] = {candidate}
    return buckets
|
|
170
|
+
|
|
171
|
+
|
|
172
|
+
def group_right(
    right_compatibilities: Iterable[Compatibility],
    left_tensors: Set[TensorName],
) -> dict[Compatibility, set[Compatibility]]:
    """Bucket right-side compatibilities by join key at every loop depth.

    Unlike `group_left`, each compatibility is registered once per entry of
    `key.all_n_loops()`, so it can match left keys at several loop counts.
    """
    buckets: dict[Compatibility, set[Compatibility]] = {}
    for candidate in right_compatibilities:
        base_key = candidate.clear_dead_tensors(
            left_tensors, keep_loops=True, drop_tags=True
        )
        for loop_key in base_key.all_n_loops():
            buckets.setdefault(loop_key, set()).add(candidate)
    return buckets
|
|
@@ -0,0 +1,160 @@
|
|
|
1
|
+
from collections import defaultdict
|
|
2
|
+
from collections.abc import Iterable
|
|
3
|
+
from itertools import permutations, product
|
|
4
|
+
|
|
5
|
+
from pytimeloop.bindings.looptree import LooptreeWorkload, LooptreeDependencyAnalyzer
|
|
6
|
+
|
|
7
|
+
from pytimeloop.looptree.mapping_utilities import get_intermediate_tensors
|
|
8
|
+
from fastfusion.util._frozenset import fzs
|
|
9
|
+
|
|
10
|
+
from .grouped_einsums import GroupOfSimilarEinsums, Id
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def group_similar_einsums(
    einsum_ids: Iterable[int],
    workload: LooptreeWorkload,
    analyzer: LooptreeDependencyAnalyzer,
) -> list[GroupOfSimilarEinsums[Id]]:
    """
    Groups similar Einsums in `einsum_ids`.

    Each Einsum either joins the first existing group whose reference Einsum
    it is equivalent to (per `is_equivalent`), or seeds a new group.
    """
    groups: list[GroupOfSimilarEinsums[Id]] = []
    for einsum_id in einsum_ids:
        for group in groups:
            rank_map, tensor_map = is_equivalent(
                group.reference_einsum, einsum_id, workload, analyzer
            )
            if rank_map is not None:
                group.add_similar_einsum(einsum_id, rank_map, tensor_map)
                break
        else:
            # No existing group matched: this Einsum becomes a new reference.
            groups.append(GroupOfSimilarEinsums(einsum_id, workload))
    return groups
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
def is_equivalent(
    einsum_id1: int,
    einsum_id2: int,
    workload: LooptreeWorkload,
    analyzer: LooptreeDependencyAnalyzer,
) -> tuple[dict[int, int], dict[int, int]]:
    """
    Determines whether two Einsums are equivalent in tensor shapes and
    tensor indexing expressions.

    If the two Einsums are equivalent, the rank and tensor renamings are
    returned.

    Returns:
        If the two Einsums are equivalent, the function returns two dicts,
        `rank_renaming` and `tensor_renaming`, representing how to rename
        ranks (tensors) of `einsum_id1` to `einsum_id2`.

        Otherwise, a tuple `(None, None)` is returned.
    """
    einsum1_ranks = workload.einsum_ospace_dimensions(einsum_id1)
    einsum2_ranks = workload.einsum_ospace_dimensions(einsum_id2)

    # Different rank counts can never be made equivalent by renaming.
    if len(einsum1_ranks) != len(einsum2_ranks):
        return None, None

    einsum1_input_tensors = workload.tensors_read_by_einsum(einsum_id1)
    einsum1_output_tensor = workload.tensors_written_by_einsum(einsum_id1)
    einsum2_input_tensors = workload.tensors_read_by_einsum(einsum_id2)
    einsum2_output_tensor = workload.tensors_written_by_einsum(einsum_id2)

    # Normalize "writes nothing" to an empty set so the loops below work.
    if einsum1_output_tensor is None:
        einsum1_output_tensor = set()
    if einsum2_output_tensor is None:
        einsum2_output_tensor = set()

    intermediate_tensors = get_intermediate_tensors(workload)

    # For each Einsum, classify its tensors by frozen property sets such as
    # {"input"}, {"output"}, {"input", "intermediate"}. Only tensors sharing
    # a property set are candidates for renaming onto each other.
    all_tensor_properties = []
    all_tensors = [
        (einsum1_input_tensors, einsum1_output_tensor),
        (einsum2_input_tensors, einsum2_output_tensor),
    ]
    for input_tensors, output_tensors in all_tensors:
        tensor_properties = defaultdict(set)
        for tensor in input_tensors:
            tensor_properties[tensor].add("input")
        for tensor in output_tensors:
            tensor_properties[tensor].add("output")
        for tensor in tensor_properties:
            if tensor in intermediate_tensors:
                tensor_properties[tensor].add("intermediate")
        # Freeze so the property sets are hashable dict keys below.
        tensor_properties = {
            tensor: fzs(properties) for tensor, properties in tensor_properties.items()
        }
        all_tensor_properties.append(tensor_properties)

    # property set -> (tensors of einsum1, tensors of einsum2) with that set.
    property_to_tensors = defaultdict(lambda: (set(), set()))
    for i, tensor_properties in enumerate(all_tensor_properties):
        for tensor, property in tensor_properties.items():
            tensor_sets = property_to_tensors[property]
            tensor_sets[i].add(tensor)

    # Check if we can rename tensors in einsum1 to einsum2
    for tensor_renaming in tensor_renamings(property_to_tensors):
        # Check if we can rename einsum1 ranks to create einsum2
        for renamed_ranks in permutations(einsum2_ranks):
            rank_renaming = {r1: r2 for r1, r2 in zip(einsum1_ranks, renamed_ranks)}
            if not _shape_is_equivalent(rank_renaming, workload):
                continue

            if not _dependency_is_equivalent(
                einsum_id1, einsum_id2, rank_renaming, tensor_renaming, analyzer
            ):
                continue

            # First renaming pair that preserves both shapes and
            # rank-to-tensor relevance wins.
            return rank_renaming, tensor_renaming
    return None, None
|
|
119
|
+
|
|
120
|
+
|
|
121
|
+
def tensor_renamings(property_to_tensors):
    """Yield every property-preserving tensor renaming from Einsum 1 to Einsum 2.

    `property_to_tensors` maps a property set to a pair
    `(tensors_of_einsum1, tensors_of_einsum2)`; a renaming pairs tensors only
    within the same property class. Yields nothing if any class's sizes differ.
    """
    groups = list(property_to_tensors.values())
    # A bijection is impossible if any property class differs in size.
    if any(len(src) != len(dst) for src, dst in groups):
        return

    sources = [t for src, _ in groups for t in src]
    per_group_orders = [permutations(dst) for _, dst in groups]
    for choice in product(*per_group_orders):
        targets = [t for ordering in choice for t in ordering]
        yield dict(zip(sources, targets))
|
|
136
|
+
|
|
137
|
+
|
|
138
|
+
def _shape_is_equivalent(rank_renaming, workload):
|
|
139
|
+
for r1, r2 in rank_renaming.items():
|
|
140
|
+
r1_shape = workload.get_rank_shape(r1)
|
|
141
|
+
r2_shape = workload.get_rank_shape(r2)
|
|
142
|
+
if r1_shape != r2_shape:
|
|
143
|
+
return False
|
|
144
|
+
return True
|
|
145
|
+
|
|
146
|
+
|
|
147
|
+
def _dependency_is_equivalent(
|
|
148
|
+
einsum_id1, einsum_id2, rank_renaming, tensor_renaming, analyzer
|
|
149
|
+
):
|
|
150
|
+
for t1, t2 in tensor_renaming.items():
|
|
151
|
+
for r1, r2 in rank_renaming.items():
|
|
152
|
+
r1_relevant_to_t1 = analyzer.einsum_dim_is_directly_relevant_to_tensor(
|
|
153
|
+
einsum_id1, r1, t1
|
|
154
|
+
)
|
|
155
|
+
r2_relevant_to_t2 = analyzer.einsum_dim_is_directly_relevant_to_tensor(
|
|
156
|
+
einsum_id2, r2, t2
|
|
157
|
+
)
|
|
158
|
+
if r1_relevant_to_t1 != r2_relevant_to_t2:
|
|
159
|
+
return False
|
|
160
|
+
return True
|
|
@@ -0,0 +1,84 @@
|
|
|
1
|
+
from collections.abc import Iterable
|
|
2
|
+
|
|
3
|
+
from bindings.looptree import LooptreeWorkload
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
type Id = int
|
|
7
|
+
type Name = str
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class GroupOfSimilarEinsums[IdOrName: Id | Name]:
|
|
11
|
+
def __init__(self, reference_einsum: Id, workload: LooptreeWorkload):
|
|
12
|
+
self.reference_einsum = reference_einsum
|
|
13
|
+
self.workload = workload
|
|
14
|
+
self.similar_einsums_to_renaming = {}
|
|
15
|
+
self.in_id = True
|
|
16
|
+
|
|
17
|
+
def add_similar_einsum(
|
|
18
|
+
self,
|
|
19
|
+
similar_einsum: IdOrName,
|
|
20
|
+
rank_renaming: IdOrName,
|
|
21
|
+
tensor_renaming: IdOrName,
|
|
22
|
+
):
|
|
23
|
+
self.similar_einsums_to_renaming[similar_einsum] = (
|
|
24
|
+
rank_renaming,
|
|
25
|
+
tensor_renaming,
|
|
26
|
+
)
|
|
27
|
+
|
|
28
|
+
@property
|
|
29
|
+
def similar_einsums(self) -> Iterable[IdOrName]:
|
|
30
|
+
return self.similar_einsums_to_renaming.keys()
|
|
31
|
+
|
|
32
|
+
@property
|
|
33
|
+
def get_renaming(
|
|
34
|
+
self, other_einsum: IdOrName
|
|
35
|
+
) -> tuple[dict[IdOrName, IdOrName], dict[IdOrName, IdOrName]]:
|
|
36
|
+
"""Returns iterable over tuple `(rank_renaming, tensor_renaming)`"""
|
|
37
|
+
try:
|
|
38
|
+
return self.similar_einsums_to_renaming[other_einsum]
|
|
39
|
+
except Exception as e:
|
|
40
|
+
e.add_note(f"{other_einsum} not in group of similar Einsums.")
|
|
41
|
+
raise
|
|
42
|
+
|
|
43
|
+
@property
|
|
44
|
+
def similar_einsums_and_renamings(
|
|
45
|
+
self,
|
|
46
|
+
) -> Iterable[
|
|
47
|
+
tuple[IdOrName, tuple[dict[IdOrName, IdOrName], dict[IdOrName, IdOrName]]]
|
|
48
|
+
]:
|
|
49
|
+
"""
|
|
50
|
+
Returns iterable over tuple `(similar_einsum, renaming)`
|
|
51
|
+
where `renaming` itself is `(rank_renaming, tensor_renaming).
|
|
52
|
+
"""
|
|
53
|
+
return self.similar_einsums_and_renamings.items()
|
|
54
|
+
|
|
55
|
+
def convert_id_to_name(self) -> "GroupOfSimilarEinsums[Name]":
|
|
56
|
+
einsum_id_to_name = self.workload.EinsumIdToName()
|
|
57
|
+
tensor_id_to_name = self.workload.DataSpaceIdToName()
|
|
58
|
+
rank_id_to_name = self.workload.DimensionIdToName()
|
|
59
|
+
|
|
60
|
+
grouped_einsums_in_name = GroupOfSimilarEinsums(
|
|
61
|
+
einsum_id_to_name[self.reference_einsum], self.workload
|
|
62
|
+
)
|
|
63
|
+
self.in_id = False
|
|
64
|
+
|
|
65
|
+
similar_einsums_to_renamings = self.get_einsums_similar_to_reference(
|
|
66
|
+
self.reference_einsum
|
|
67
|
+
)
|
|
68
|
+
for einsum_id, renaming in similar_einsums_to_renamings.items():
|
|
69
|
+
rank_renaming, tensor_renaming = renaming
|
|
70
|
+
rank_renaming_in_names = {
|
|
71
|
+
rank_id_to_name[k]: rank_id_to_name[v] for k, v in rank_renaming.items()
|
|
72
|
+
}
|
|
73
|
+
tensor_renaming_in_names = {
|
|
74
|
+
tensor_id_to_name[k]: tensor_id_to_name[v]
|
|
75
|
+
for k, v in tensor_renaming.items()
|
|
76
|
+
}
|
|
77
|
+
grouped_einsums_in_name.add_einsum_similar_to_reference(
|
|
78
|
+
einsum_id_to_name[self.reference_einsum],
|
|
79
|
+
einsum_id_to_name[einsum_id],
|
|
80
|
+
rank_renaming_in_names,
|
|
81
|
+
tensor_renaming_in_names,
|
|
82
|
+
)
|
|
83
|
+
|
|
84
|
+
return grouped_einsums_in_name
|
|
@@ -0,0 +1,212 @@
|
|
|
1
|
+
from fastfusion.frontend.mapping import Loop, Temporal
|
|
2
|
+
from fastfusion.mapper.FFM.deprecate_maybe.tags import Tags
|
|
3
|
+
|
|
4
|
+
from .util import get_fused_loops_per_tensor
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
# Tag values produced by the FFMT tagging functions below.
FFMT_VALID = "FFMT_VALID"  # mapping satisfies the FFMT fusion constraints
FFMT_WEIGHT_UNTILED = "FFMT_WEIGHT_UNTILED"  # weight tensor left whole (untiled)
FFMT_WEIGHT_TILED = "FFMT_WEIGHT_TILED"  # weight tensor tiled under fused loops
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def get_ffmt_tag(compatibility):
    """Return the FFMT tags for ``compatibility``.

    Currently always delegates to :func:`get_ffmt_matmul_tag`. The
    original body contained unreachable code after the first ``return``
    that dispatched to :func:`get_ffmt_mha_tag` based on an undefined
    ``einsum_name`` (a NameError had it ever run); that dead code has
    been removed without changing behavior.
    """
    return get_ffmt_matmul_tag(compatibility)
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def get_ffmt_matmul_tag(compatibility):
    """Return FFMT tags for a matmul Einsum's compatibility.

    FFMT is ``[input | output, weight]``: the fused (non-MainMemory)
    tensors' ``above_loop_index`` values must match one of the allowed
    patterns, shifted up by one per fused B/H loop.

    Raises:
        ValueError: when the fused-loop pattern is not allowed.
    """
    # Tensors backed in MainMemory are not fused; ignore them.
    tensors = [s for s in compatibility.tensors if s.resource_name != "MainMemory"]
    if len(tensors) <= 1:
        return Tags((FFMT_VALID,))

    # If there's >1 fused loop, they must be above the same number of loops
    allowed_n_loops = [
        (0, 0),
        (1, 1),
        (1, 2),
    ]

    # If there's a B or H fused loop, add one to the allowed n_loops
    for rank_var in "b", "h":
        if any(rank_var in l.rank_variable for l in compatibility.loops):
            allowed_n_loops = [(x + 1, y + 1) for x, y in allowed_n_loops]

    # BUG FIX: the original compared against a second hard-coded copy of
    # the base list, silently discarding the B/H adjustment computed
    # above. Compare against ``allowed_n_loops`` instead.
    if tuple(sorted(s.above_loop_index for s in tensors)) in allowed_n_loops:
        return Tags((FFMT_VALID,))
    raise ValueError()
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def get_ffmt_mha_tag(compatibility):
    """Return FFMT tags for an MHA Einsum's compatibility.

    NOTE(review): this definition is shadowed by a second
    ``get_ffmt_mha_tag(pmapping)`` defined later in this module, so this
    version is never the one resolved at call time.
    """
    # Tensors backed in MainMemory are not fused; ignore them.
    tensors = [s for s in compatibility.tensors if s.resource_name != "MainMemory"]
    if len(compatibility.loops) == 0:
        return Tags((FFMT_VALID,))

    # Loops have to be in the order (b, h)
    if len(compatibility.loops) == 1:
        # NOTE(review): ``FFMT_INVALID`` is not defined anywhere in this
        # module — reaching this branch raises NameError, not ValueError.
        return Tags((FFMT_INVALID,))

    # All fused tensors must sit above the same number of loops.
    if len(set(s.above_loop_index for s in tensors)) > 1:
        raise ValueError()
    return Tags((FFMT_VALID,))

    # NOTE(review): everything below is unreachable (dead code after the
    # return above) and references undefined names (``tensor``,
    # ``unique_loops``, ``tensor_to_n_fused_loops``); it appears to be a
    # leftover draft. Consider deleting it.
    for tensors in compatibility.tensors:
        if tensor.resource_name == "MainMemory":
            continue
        unique_loops.add(tensor.above_loop_index)

    if len(unique_loops) == 0:
        return Tags()  # unfused is compatible with anything

    untiled_fused = len(unique_loops) == 1 and next(iter(unique_loops)) == 0
    if untiled_fused:
        return Tags((FFMT_VALID,))

    min_weight_idx, max_weight_idx, max_non_weight_idx = float("inf"), 0, 0
    max_weight_idx = 0
    for tensor, n_loops in tensor_to_n_fused_loops.items():
        is_weight = "Filter" in tensor.name
        if is_weight:
            min_weight_idx = min(min_weight_idx, n_loops)
            max_weight_idx = max(max_weight_idx, n_loops)
        else:
            max_non_weight_idx = max(max_non_weight_idx, n_loops)

    weight_untiled = min_weight_idx == 0 and max_weight_idx == 0
    if weight_untiled:
        return Tags((FFMT_VALID, FFMT_WEIGHT_UNTILED))
    elif min_weight_idx >= max_non_weight_idx:
        return Tags((FFMT_VALID, FFMT_WEIGHT_TILED))
    raise ValueError()
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
def get_ffmt_mha_tag(pmapping):
    """Return FFMT tags for an MHA partial mapping.

    Checks that the mapping's loop permutation and per-tensor fused-loop
    counts match one of the allowed FFMT patterns for this Einsum, then
    classifies the weight tensor as tiled or untiled.

    NOTE(review): this redefines (shadows) the earlier
    ``get_ffmt_mha_tag(compatibility)`` in this module.

    Raises:
        RuntimeError: if any loop is not Temporal (non-Snowcat mapping).
        ValueError: if no allowed pattern matches.
    """
    einsum_name = pmapping[-1].einsum_name
    # One-letter rank variables used by the MHA workload.
    B, H, M, F, P, G, E, D, C, J = "bhmfpgedcj"
    # For each Einsum: [reduced (contracted) rank, output rank].
    EINSUM_NAME_TO_REDUCED_RANK_OUTPUT_RANK = {
        "Q": [D, E],
        "K": [D, E],
        "V": [D, F],
        "QK": [E, P],
        "AV": [P, F],
        "Z": [F, G],
        "FFA": [G, C],
        "FFB": [C, J],
    }

    # Collect the rank-variable order of the mapping's temporal loops.
    rank_var_permutation = []
    for node in pmapping:
        if isinstance(node, Loop):
            if not isinstance(node, Temporal):
                raise RuntimeError(
                    "get_ffmt_mha_tag should not be used for "
                    "anything other than Snowcat"
                )
            rank_var_permutation.append(node.rank_variable)

    # NOTE(review): ``intermediate_tensors`` is not defined in this
    # function or module scope — this raises NameError as written.
    # Presumably it should be a parameter; confirm against callers.
    tensor_to_n_fused_loops = get_fused_loops_per_tensor(
        pmapping, intermediate_tensors, "MainMemory"
    )
    unfused = all(
        n is None
        for t, n in tensor_to_n_fused_loops.items()
        if t in intermediate_tensors
    )
    if einsum_name not in EINSUM_NAME_TO_REDUCED_RANK_OUTPUT_RANK:
        # Unknown Einsums are only acceptable if fully unfused.
        if unfused:
            return Tags((FFMT_VALID,))
        raise ValueError()

    reduced_rank, output_rank = EINSUM_NAME_TO_REDUCED_RANK_OUTPUT_RANK[einsum_name]

    # Input/output tensor names along the MHA chain for each Einsum.
    EINSUM_NAME_TO_INPUT_OUTPUT_TENSORS = {
        "Q": ["I_I_to_Q_K_V", "Q_Q_to_QK"],
        "K": ["I_I_to_Q_K_V", "K_K_to_QK"],
        "V": ["I_I_to_Q_K_V", "V_V_to_AV"],
        "QK": ["Q_Q_to_QK", "QK_QK_to_AV"],
        "AV": ["QK_QK_to_AV", "AV_AV_to_Z"],
        "Z": ["AV_AV_to_Z", "Z_Z_to_FFA"],
        "FFA": ["Z_Z_to_FFA", "FFA_FFA_to_FFB"],
        "FFB": ["FFA_FFA_to_FFB", "FFB_FFB_to_n"],
    }

    input_tensor, output_tensor = EINSUM_NAME_TO_INPUT_OUTPUT_TENSORS[einsum_name]
    input_output_tensors = {input_tensor, output_tensor}

    min_weight_idx = float("inf")
    max_weight_idx = 0
    max_non_weight_idx = 0
    # first/last: whether this Einsum's input/output is unfused, i.e.
    # whether it sits at the start/end of the fused chain.
    first, last = True, True
    for tensor, n_loops in tensor_to_n_fused_loops.items():
        if tensor.name == input_tensor and n_loops is not None:
            first = False
        if tensor.name == output_tensor and n_loops is not None:
            last = False

        # Anything that is not the chain input/output is a weight.
        is_weight = tensor.name not in input_output_tensors
        # NOTE(review): ``n_loops`` may be None (unfused); min/max with
        # None raises TypeError on Python 3 — confirm unfused weights
        # cannot reach here.
        if is_weight:
            min_weight_idx = min(min_weight_idx, n_loops)
            max_weight_idx = max(max_weight_idx, n_loops)
        else:
            max_non_weight_idx = max(max_non_weight_idx, n_loops)

    # Rank variable order and the n_loops for (input, output)
    prefix_choices = [([B, H], (2, 2))]

    # Rank variable order and the n_loops for (input, output)
    extra_rank_choices = [
        ([M], (1, 1)),
    ]
    if first:
        if output_rank is not None:
            extra_rank_choices.append(([M, output_rank], (1, 2)))
        if reduced_rank is not None and output_rank is not None:
            extra_rank_choices.append(([M, output_rank, reduced_rank], (3, 2)))
        if output_rank is None and reduced_rank is not None:
            extra_rank_choices.append(([M, reduced_rank], (2, 1)))
    elif last:
        if output_rank is not None:
            extra_rank_choices.append(([M, output_rank], (1, 2)))
    else:
        if reduced_rank is not None:
            extra_rank_choices.append(([M, reduced_rank], (2, 1)))

    # Try every allowed (prefix + extra) pattern; the first one matching
    # both the loop permutation and the fused-loop counts decides the tag.
    for prefix_permutation, prefix_n_loops in prefix_choices:
        for extra_permutation, extra_n_loops in extra_rank_choices:
            permutation = prefix_permutation + extra_permutation
            input_n_loops = prefix_n_loops[0] + extra_n_loops[0]
            output_n_loops = prefix_n_loops[1] + extra_n_loops[1]
            untiled_weight_idx = len(prefix_permutation)

            # zip truncates: only the overlapping prefix is compared.
            permutation_matches = True
            for rank_var, ref_rank_var in zip(rank_var_permutation, permutation):
                if rank_var != ref_rank_var:
                    permutation_matches = False
                    break

            if not permutation_matches:
                continue

            # NOTE(review): these lookups key the dict with tensor *names*
            # (strings), but the dict built above is keyed by tensor
            # objects (compared via ``tensor.name``). Confirm tensors
            # hash/compare as their names, else this raises KeyError.
            if tensor_to_n_fused_loops[input_tensor] != input_n_loops:
                continue
            if tensor_to_n_fused_loops[output_tensor] != output_n_loops:
                continue

            weight_untiled = (
                min_weight_idx == untiled_weight_idx
                and max_weight_idx == untiled_weight_idx
            )
            if weight_untiled:
                return Tags((FFMT_VALID, FFMT_WEIGHT_UNTILED))
            elif min_weight_idx >= max_non_weight_idx:
                return Tags((FFMT_VALID, FFMT_WEIGHT_TILED))

    raise ValueError()
|
|
@@ -0,0 +1,24 @@
|
|
|
1
|
+
from fastfusion.mapper.FFM.deprecate_maybe.tags import Tags
|
|
2
|
+
from fastfusion.mapper.FFM._join_pmappings.compatibility import Compatibility
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
# Tag values for the "one split" fusion constraint.
ONE_SPLIT = "ONE_SPLIT"
# NOTE(review): NOT_ONE_SPLIT is never used in this module —
# get_one_split_tag returns the literal "INVALID" instead. Confirm
# which string consumers match on.
NOT_ONE_SPLIT = "NOT_ONE_SPLIT"
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def get_one_split_tag(compatibility: Compatibility) -> Tags:
    """Tag a compatibility by the single fused-loop level it uses.

    Unfused mappings get an empty tag set (compatible with anything);
    mappings whose fused tensors sit above differing loop counts are
    tagged "INVALID"; otherwise the tag records the shared loop count.
    """
    # TODO
    # Fused-loop levels of every tensor not backed in MainMemory.
    fused_levels = {
        tensor.above_loop_index
        for tensor in compatibility.tensors
        if tensor.resource_name != "MainMemory"
    }

    if not fused_levels:
        return Tags()  # unfused is compatible with anything

    # Fused with both sides. Make sure that the number of loops is the same.
    if len(fused_levels) > 1:
        return Tags(("INVALID",))

    return Tags((ONE_SPLIT, f"FUSED_LOOPS={next(iter(fused_levels))}"))
|
|
@@ -0,0 +1,24 @@
|
|
|
1
|
+
from fastfusion.frontend.mapping import Reservation, Loop, Mapping
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
def get_fused_loops_per_tensor(
    pmapping: Mapping, intermediate_tensors, non_fused_memory
):
    """
    Map each intermediate tensor to the number of loops above its first
    reservation, or to ``None`` when that reservation is backed in
    ``non_fused_memory`` (i.e. the tensor is unfused).
    """
    result = {}
    loops_seen = 0
    for node in pmapping.nodes:
        if isinstance(node, Reservation):
            tensor = node.tensor
            # Only the first reservation of each intermediate tensor counts.
            if tensor in result or tensor not in intermediate_tensors:
                continue
            result[tensor] = (
                None if node.component == non_fused_memory else loops_seen
            )
        elif isinstance(node, Loop):
            loops_seen += 1
    return result
|