accelforge 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- accelforge/__init__.py +21 -0
- accelforge/_accelerated_imports.py +16 -0
- accelforge/_deprecate/_simanneal/evalmapping.py +271 -0
- accelforge/_deprecate/_simanneal/mapspaceglobals.py +298 -0
- accelforge/_deprecate/_simanneal/simanneal.py +666 -0
- accelforge/_deprecate/_simanneal/tracking.py +105 -0
- accelforge/_deprecate/_simanneal/wrappers.py +218 -0
- accelforge/_deprecate/_simanneal2/__init__.py +7 -0
- accelforge/_deprecate/_simanneal2/simanneal.py +493 -0
- accelforge/_deprecate/_simanneal2/tracking.py +116 -0
- accelforge/_deprecate/compatibility_util.py +181 -0
- accelforge/_deprecate/layerdeduplication/__init__.py +2 -0
- accelforge/_deprecate/layerdeduplication/group_similar_einsums.py +160 -0
- accelforge/_deprecate/layerdeduplication/grouped_einsums.py +84 -0
- accelforge/_deprecate/mapping_filter_tags/__init__.py +2 -0
- accelforge/_deprecate/mapping_filter_tags/ffmt.py +212 -0
- accelforge/_deprecate/mapping_filter_tags/onesplit.py +24 -0
- accelforge/_deprecate/mapping_filter_tags/util.py +24 -0
- accelforge/_deprecate/tags.py +69 -0
- accelforge/_deprecate/viz/__init__.py +0 -0
- accelforge/_deprecate/viz/interactive.py +159 -0
- accelforge/_deprecate/viz/reservationtree.py +307 -0
- accelforge/_deprecate/viz/ski_slope.py +88 -0
- accelforge/_version.py +15 -0
- accelforge/examples.py +39 -0
- accelforge/frontend/__init__.py +10 -0
- accelforge/frontend/_binding.py +129 -0
- accelforge/frontend/_workload_isl/__init__.py +2 -0
- accelforge/frontend/_workload_isl/_isl.py +149 -0
- accelforge/frontend/_workload_isl/_symbolic.py +141 -0
- accelforge/frontend/arch copy.py +1544 -0
- accelforge/frontend/arch.py +1642 -0
- accelforge/frontend/config.py +63 -0
- accelforge/frontend/mapper/__init__.py +5 -0
- accelforge/frontend/mapper/ffm.py +126 -0
- accelforge/frontend/mapper/mapper.py +7 -0
- accelforge/frontend/mapper/metrics.py +30 -0
- accelforge/frontend/mapping/__init__.py +1 -0
- accelforge/frontend/mapping/mapping.py +1736 -0
- accelforge/frontend/model.py +14 -0
- accelforge/frontend/renames.py +150 -0
- accelforge/frontend/spec copy.py +230 -0
- accelforge/frontend/spec.py +301 -0
- accelforge/frontend/variables.py +12 -0
- accelforge/frontend/workload.py +952 -0
- accelforge/mapper/FFM/__init__.py +9 -0
- accelforge/mapper/FFM/_join_pmappings/__init__.py +0 -0
- accelforge/mapper/FFM/_join_pmappings/compatibility.py +653 -0
- accelforge/mapper/FFM/_join_pmappings/compress_pmappings.py +140 -0
- accelforge/mapper/FFM/_join_pmappings/join_pmappings.py +703 -0
- accelforge/mapper/FFM/_join_pmappings/pmapping_dataframe.py +901 -0
- accelforge/mapper/FFM/_join_pmappings/pmapping_group.py +337 -0
- accelforge/mapper/FFM/_make_pmappings/contraints/__init__.py +0 -0
- accelforge/mapper/FFM/_make_pmappings/contraints/constraints.py +360 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/__init__.py +1 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_loops.py +373 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_pmapping_templates.py +463 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_reservations.py +95 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_storage_order.py +382 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_storages.py +155 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmappings.py +411 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/__init__.py +1 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/make_pmappings_from_templates.py +407 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/make_tile_shapes.py +1681 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/run_model.py +170 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/symbol_relations.py +174 -0
- accelforge/mapper/FFM/_make_pmappings/pmapper_job.py +282 -0
- accelforge/mapper/FFM/_pareto_df/df_convention.py +273 -0
- accelforge/mapper/FFM/_pareto_df/pareto copy.py +836 -0
- accelforge/mapper/FFM/_pareto_df/pareto.py +508 -0
- accelforge/mapper/FFM/data.py +61 -0
- accelforge/mapper/FFM/main copy.py +236 -0
- accelforge/mapper/FFM/main.py +208 -0
- accelforge/mapper/FFM/mappings.py +510 -0
- accelforge/mapper/FFM/pmappings.py +310 -0
- accelforge/mapper/__init__.py +4 -0
- accelforge/mapper.py +0 -0
- accelforge/model/__init__.py +1 -0
- accelforge/model/_looptree/__init__.py +0 -0
- accelforge/model/_looptree/accesses.py +335 -0
- accelforge/model/_looptree/capacity/__init__.py +1 -0
- accelforge/model/_looptree/capacity/aggregators.py +36 -0
- accelforge/model/_looptree/capacity/capacity.py +47 -0
- accelforge/model/_looptree/energy.py +150 -0
- accelforge/model/_looptree/equivalent_ranks.py +29 -0
- accelforge/model/_looptree/latency/__init__.py +1 -0
- accelforge/model/_looptree/latency/latency.py +98 -0
- accelforge/model/_looptree/latency/memory.py +120 -0
- accelforge/model/_looptree/latency/processors.py +92 -0
- accelforge/model/_looptree/mapping_utilities.py +71 -0
- accelforge/model/_looptree/reuse/__init__.py +4 -0
- accelforge/model/_looptree/reuse/isl/__init__.py +1 -0
- accelforge/model/_looptree/reuse/isl/des.py +59 -0
- accelforge/model/_looptree/reuse/isl/isl_functions.py +374 -0
- accelforge/model/_looptree/reuse/isl/mapping_to_isl/__init__.py +4 -0
- accelforge/model/_looptree/reuse/isl/mapping_to_isl/analyze_mapping.py +297 -0
- accelforge/model/_looptree/reuse/isl/mapping_to_isl/skews_from_mapping.py +236 -0
- accelforge/model/_looptree/reuse/isl/mapping_to_isl/tiling.py +685 -0
- accelforge/model/_looptree/reuse/isl/mapping_to_isl/types.py +188 -0
- accelforge/model/_looptree/reuse/isl/spatial.py +260 -0
- accelforge/model/_looptree/reuse/isl/temporal.py +182 -0
- accelforge/model/_looptree/reuse/symbolic/__init__.py +1 -0
- accelforge/model/_looptree/reuse/symbolic/symbolic copy 2.py +1346 -0
- accelforge/model/_looptree/reuse/symbolic/symbolic copy.py +1408 -0
- accelforge/model/_looptree/reuse/symbolic/symbolic.py +1396 -0
- accelforge/model/_looptree/run.py +122 -0
- accelforge/model/_looptree/types.py +26 -0
- accelforge/model/_looptree/visualization/__init__.py +0 -0
- accelforge/model/_looptree/visualization/occupancy.py +11 -0
- accelforge/model/main.py +222 -0
- accelforge/plotting/__init__.py +2 -0
- accelforge/plotting/mappings.py +219 -0
- accelforge/plotting/specs.py +57 -0
- accelforge/util/__init__.py +4 -0
- accelforge/util/_base_analysis_types.py +24 -0
- accelforge/util/_basetypes.py +1089 -0
- accelforge/util/_frozenset.py +36 -0
- accelforge/util/_isl.py +29 -0
- accelforge/util/_itertools.py +14 -0
- accelforge/util/_mathfuncs.py +57 -0
- accelforge/util/_parse_expressions.py +339 -0
- accelforge/util/_picklecache.py +32 -0
- accelforge/util/_setexpressions.py +268 -0
- accelforge/util/_sympy/__init__.py +0 -0
- accelforge/util/_sympy/broadcast_max.py +18 -0
- accelforge/util/_visualization.py +112 -0
- accelforge/util/_yaml.py +579 -0
- accelforge/util/parallel.py +193 -0
- accelforge-0.0.1.dist-info/METADATA +64 -0
- accelforge-0.0.1.dist-info/RECORD +258 -0
- accelforge-0.0.1.dist-info/WHEEL +5 -0
- accelforge-0.0.1.dist-info/licenses/LICENSE +19 -0
- accelforge-0.0.1.dist-info/top_level.txt +5 -0
- docs/_build/html/_sources/fastfusion.frontend.mapper.rst.txt +37 -0
- docs/_build/html/_sources/fastfusion.frontend.rst.txt +70 -0
- docs/_build/html/_sources/fastfusion.frontend.workload.rst.txt +21 -0
- docs/_build/html/_sources/fastfusion.mapper.FFM.rst.txt +37 -0
- docs/_build/html/_sources/fastfusion.mapper.rst.txt +18 -0
- docs/_build/html/_sources/fastfusion.rst.txt +20 -0
- docs/_build/html/_sources/fastfusion.util.rst.txt +21 -0
- docs/_build/html/_sources/index.rst.txt +87 -0
- docs/_build/html/_sources/modules.rst.txt +7 -0
- docs/_build/html/_sources/notes/citation.rst.txt +45 -0
- docs/_build/html/_sources/notes/definitions.rst.txt +43 -0
- docs/_build/html/_sources/notes/faqs.rst.txt +39 -0
- docs/_build/html/_sources/notes/modeling/accelerator_energy_latency.rst.txt +72 -0
- docs/_build/html/_sources/notes/modeling/component_energy_area.rst.txt +96 -0
- docs/_build/html/_sources/notes/modeling/mapping.rst.txt +100 -0
- docs/_build/html/_sources/notes/modeling.rst.txt +33 -0
- docs/_build/html/_sources/notes/parsing/arithmetic_parsing.rst.txt +136 -0
- docs/_build/html/_sources/notes/parsing/setexpressions.rst.txt +63 -0
- docs/_build/html/_sources/notes/parsing/yaml_parsing.rst.txt +176 -0
- docs/_build/html/_sources/notes/quickstart_and_installation.rst.txt +9 -0
- docs/_build/html/_sources/notes/spec/architecture.rst.txt +133 -0
- docs/_build/html/_sources/notes/spec/mapping.rst.txt +12 -0
- docs/_build/html/_sources/notes/spec/workload.rst.txt +83 -0
- docs/_build/html/_sources/notes/spec.rst.txt +36 -0
- docs/source/_ext/include_attrs.py +213 -0
- docs/source/_ext/include_docstring.py +364 -0
- docs/source/_ext/include_functions.py +154 -0
- docs/source/_ext/include_notebook.py +131 -0
- docs/source/_ext/include_yaml.py +119 -0
- docs/source/_ext/inherited_attributes.py +222 -0
- docs/source/_ext/paths.py +4 -0
- docs/source/conf.py +79 -0
- examples/arches/compute_in_memory/_include.yaml +74 -0
- examples/arches/compute_in_memory/_include_functions.py +229 -0
- examples/arches/compute_in_memory/_load_spec.py +57 -0
- examples/arches/compute_in_memory/components/c2c_multiplier.py +181 -0
- examples/arches/compute_in_memory/components/dac_c2c_r2r.py +605 -0
- examples/arches/compute_in_memory/components/misc.py +195 -0
- examples/arches/compute_in_memory/components/util/bit_functions.py +51 -0
- examples/arches/compute_in_memory/components/zero_comparator.py +92 -0
- examples/arches/compute_in_memory/isaac.yaml +233 -0
- examples/arches/compute_in_memory/memory_cells/ecram_demo.yaml +63 -0
- examples/arches/compute_in_memory/memory_cells/rram_example.yaml +63 -0
- examples/arches/compute_in_memory/memory_cells/rram_isaac_isca_2016.yaml +64 -0
- examples/arches/compute_in_memory/memory_cells/rram_neurosim_default.yaml +63 -0
- examples/arches/compute_in_memory/memory_cells/rram_raella_isca_2023.yaml +70 -0
- examples/arches/compute_in_memory/memory_cells/rram_wan_nature_2022.yaml +63 -0
- examples/arches/compute_in_memory/memory_cells/sram_colonnade_jssc_2021.yaml +63 -0
- examples/arches/compute_in_memory/memory_cells/sram_example.yaml +63 -0
- examples/arches/compute_in_memory/memory_cells/sram_jia_jssc_2020.yaml +63 -0
- examples/arches/compute_in_memory/memory_cells/sram_sinangil_jssc_2021.yaml +63 -0
- examples/arches/compute_in_memory/memory_cells/sram_wang_vlsi_2022.yaml +63 -0
- examples/arches/compute_in_memory/wang_vlsi_2022.yaml +289 -0
- examples/arches/eyeriss.yaml +68 -0
- examples/arches/fanout_variations/at_glb.yaml +31 -0
- examples/arches/fanout_variations/at_glb_with_fanout_node.yaml +34 -0
- examples/arches/fanout_variations/at_mac.yaml +31 -0
- examples/arches/fanout_variations/at_mac_with_constraints.yaml +38 -0
- examples/arches/fanout_variations/at_mac_with_fanout_node.yaml +34 -0
- examples/arches/nvdla.yaml +47 -0
- examples/arches/simple.yaml +28 -0
- examples/arches/tpu_v4i.yaml +67 -0
- examples/mappings/unfused_matmuls_to_simple.yaml +33 -0
- examples/misc/component_annotated.yaml +33 -0
- examples/workloads/gpt3_6.7B.yaml +124 -0
- examples/workloads/matmuls.yaml +20 -0
- examples/workloads/mobilenet_28.yaml +81 -0
- examples/workloads/mobilenet_various_separate.yaml +106 -0
- examples/workloads/three_matmuls_annotated.yaml +59 -0
- notebooks/.ipynb_checkpoints/fastfusion_arch_study_michael-checkpoint.ipynb +359 -0
- notebooks/compute_in_memory/_scripts.py +339 -0
- notebooks/compute_in_memory/isaac.guide.ipynb +270 -0
- notebooks/compute_in_memory/wang_vlsi_2022.ipynb +602 -0
- notebooks/paths.py +4 -0
- notebooks/tutorials/.ipynb_checkpoints/1_FFM-checkpoint.ipynb +3110 -0
- notebooks/tutorials/FFM.ipynb +3498 -0
- notebooks/tutorials/_include.py +48 -0
- notebooks/tutorials/component_energy_area.ipynb +363 -0
- tests/Q_mapping.yaml +38 -0
- tests/__init__.py +0 -0
- tests/conv.mapping.yaml +27 -0
- tests/conv.workload.yaml +13 -0
- tests/conv_sym.mapping.yaml +43 -0
- tests/copy.mapping.yaml +35 -0
- tests/copy.workload.yaml +15 -0
- tests/distribuffers/__init__.py +0 -0
- tests/distribuffers/multicast/test_cases.yaml +482 -0
- tests/distribuffers/spec/binding/valid_bindings.yaml +97 -0
- tests/distribuffers/spec/distributed.yaml +100 -0
- tests/distribuffers/spec/logical_arch.yaml +32 -0
- tests/distribuffers/spec/physical_arch.yaml +69 -0
- tests/distribuffers/test_binding.py +48 -0
- tests/frontend/__init__.py +0 -0
- tests/frontend/test_mapping_viz.py +52 -0
- tests/mapper/__init__.py +0 -0
- tests/mapper/configs/conv1d/conv1d.mapping.yaml +31 -0
- tests/mapper/configs/conv1d/conv1d.workload.yaml +11 -0
- tests/mapper/configs/two_conv1d/two_conv1d.expected.yaml +38 -0
- tests/mapper/configs/two_conv1d/two_conv1d.mapping.yaml +54 -0
- tests/mapper/configs/two_conv1d/two_conv1d.workload.yaml +19 -0
- tests/mapper/test_mapping_to_isl.py +90 -0
- tests/mapper/test_spatial_reuse_analysis.py +67 -0
- tests/mapper/test_temporal_reuse_analysis.py +56 -0
- tests/mapper/util.py +58 -0
- tests/matmul.mapping.yaml +29 -0
- tests/matmul.workload.yaml +12 -0
- tests/matmul_spatial.mapping.yaml +44 -0
- tests/mha.renames.yaml +65 -0
- tests/mha.workload.yaml +67 -0
- tests/mha.yaml +59 -0
- tests/mha_full.workload.yaml +67 -0
- tests/mobilenet.workload.yaml +35 -0
- tests/mobilenet_long.workload.yaml +64 -0
- tests/pmappingcache.py +24 -0
- tests/processing_stage.arch.yaml +40 -0
- tests/snowcat.arch.yaml +36 -0
- tests/test_ffm_join_pmappings.py +106 -0
- tests/test_ffm_make_pmappings.py +82 -0
- tests/test_ffm_make_tile_shapes.py +49 -0
- tests/test_mapper.py +100 -0
- tests/test_model.py +37 -0
- tests/test_plotting.py +72 -0
- tests/test_processing_stage.py +46 -0
- tests/test_symbolic_model.py +248 -0
- tests/test_workload.py +141 -0
|
@@ -0,0 +1,337 @@
|
|
|
1
|
+
from collections import defaultdict
|
|
2
|
+
from functools import cached_property
|
|
3
|
+
from typing import Any, Callable, Iterable
|
|
4
|
+
import pandas as pd
|
|
5
|
+
from joblib import delayed
|
|
6
|
+
|
|
7
|
+
from accelforge.mapper.FFM._join_pmappings.pmapping_dataframe import PmappingDataframe
|
|
8
|
+
|
|
9
|
+
from accelforge.mapper.FFM._join_pmappings.compatibility import *
|
|
10
|
+
from accelforge.util import parallel
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class PmappingGroup:
    """A set of candidate pmappings (partial mappings) sharing one Compatibility.

    Bundles a ``PmappingDataframe`` of candidate mappings with the
    ``Compatibility`` signature they all satisfy, plus a by-name index of the
    tensor reservations referenced by that compatibility. Groups are merged
    pairwise (``merge_next``), consolidated as tensors die, and concatenated
    when their compatibilities match.
    """

    def __init__(self, compatibility: Compatibility, mappings: PmappingDataframe):
        self.compatibility: Compatibility = compatibility
        self.mappings: PmappingDataframe = mappings
        # Index the compatibility's tensor reservations by name for O(1) lookup.
        self.tensors: dict[str, TensorReservation] = {
            t.name: t for t in self.compatibility.tensors
        }
        # Product of the merged dataframes' sizes before Pareto pruning;
        # populated by merge_next.
        self.n_pre_prune_mappings = 0

    def compatibility_str(self) -> str:
        """Human-readable summary: compatibility tensors || tracked reservations."""
        compatibility = ",".join(str(l) for l in self.compatibility.tensors)
        compatibility += " || " + ", ".join(str(t) for t in self.tensors.values())
        return compatibility

    @cached_property
    def tensor_names(self) -> set[str]:
        """Names of the tracked tensor reservations. Cached: do not rely on it
        after mutating ``self.tensors`` (e.g. via remove_dead_tensors)."""
        return set(self.tensors)

    def copy(self) -> "PmappingGroup":
        """Copy sharing the compatibility but with a copied mapping dataframe."""
        return PmappingGroup(self.compatibility, self.mappings.copy())

    def __len__(self) -> int:
        return len(self.mappings)

    def merge_next(
        self,
        right: "PmappingGroup",
        live_tensors: set[str],
        live_tensors_with_right: set[str],
        aliased_tensors: dict[str, set[str]],
        compatibility_joined: Compatibility,
        ignored_resources: set[str],
        drop_valid_reservations: bool = True,
        delay: bool = False,
        _pmapping_row_filter_function: Callable[[pd.Series], bool] | None = None,
    ) -> "PmappingGroup":
        """Merge this group with ``right`` into a group for the joined compatibility.

        If ``delay`` is True, the underlying dataframe merge is left as a
        joblib ``delayed`` triple (func, args, kwargs) for the caller to
        execute; otherwise it is run immediately.
        """
        # Loops shared between this group and everything that must stay live
        # once right's tensors are in scope.
        shared_loop_index = self.compatibility.shared_loop_index(
            right.compatibility.tensor_names | live_tensors
        )
        next_shared_loop_index = compatibility_joined.shared_loop_index(live_tensors)

        # Reservations we carry forward: still live and not already covered by right.
        still_live_reservations = [
            r
            for r in self.compatibility.tensors
            if r.name in live_tensors and r.name not in right.compatibility.tensor_names
        ]

        # Tensors aliased across the two groups that land on the same resource
        # would be double-counted; record them so the merge can deduplicate.
        duplicated_aliased_tensors = set()
        for name, my_tensor in self.tensors.items():
            for aliased_tensor in aliased_tensors.get(name, set()):
                if (aliased_tensor := right.tensors.get(aliased_tensor, None)) is None:
                    continue
                if my_tensor.resource_name == aliased_tensor.resource_name:
                    duplicated_aliased_tensors.add(aliased_tensor.name)

        mapping = delayed(self.mappings.merge_next)(
            right.mappings,
            shared_loop_index,
            next_shared_loop_index,
            live_tensors_with_right,
            still_live_reservations,
            duplicated_aliased_tensors,
            compatibility_left=self.compatibility,
            compatibility_right=right.compatibility,
            compatibility_joined=compatibility_joined,
            drop_valid_reservations=drop_valid_reservations,
            _pmapping_row_filter_function=_pmapping_row_filter_function,
            ignored_resources=ignored_resources,
        )

        if not delay:
            # joblib.delayed returns (func, args, kwargs); execute it now.
            mapping = mapping[0](*mapping[1], **mapping[2])

        s = PmappingGroup(compatibility_joined, mapping)
        assert (
            compatibility_joined.max_above_loop_index == next_shared_loop_index + 1
        ), f"{self.compatibility} {right.compatibility} {next_shared_loop_index + 1} -> {compatibility_joined} {len(compatibility_joined.loops)}"
        # Left side wins on name collisions (self updated last).
        s.tensors.update(right.tensors)
        s.tensors.update(self.tensors)
        s.n_pre_prune_mappings = len(self.mappings.data) * len(right.mappings.data)
        return s

    def get_shared_loop_index(self, live_tensors: set[str]) -> int:
        """Index of the innermost loop shared with this group's tensors plus
        the given live tensors.

        NOTE(review): the original code built
        ``list(tensor_names) + [live_tensors]``, appending the whole set as a
        single list element. A set union matches how shared_loop_index is
        invoked everywhere else in this class — confirm against
        Compatibility.shared_loop_index.
        """
        return self.compatibility.shared_loop_index(
            set(self.compatibility.tensor_names) | live_tensors
        )

    def _right_consolidate(
        self,
        live_tensors: set[str] = None,
        shared_tensors: set[str] = None,
    ):
        """Drop dead tensors, free loops below the shared index, and re-Pareto
        if anything was freed. Returns self for chaining."""
        dead_tensors = set(self.tensors) - (live_tensors or set())
        check_tensors = (shared_tensors or set()) | (live_tensors or set())
        shared_loop_index = self.compatibility.shared_loop_index(check_tensors)
        for t in dead_tensors:
            self.tensors.pop(t)
        if self.mappings.free_to_loop_index(
            shared_loop_index, live_tensors=live_tensors
        ):
            self.mappings.make_pareto()
        return self

    def _left_consolidate(self, live_tensors: set[str] = None):
        """Free loops below the shared index for the given live tensors.
        Unlike _right_consolidate, does not drop tensors or re-Pareto."""
        check_tensors = live_tensors or set()
        shared_loop_index = self.compatibility.shared_loop_index(check_tensors)
        self.mappings.free_to_loop_index(shared_loop_index, live_tensors=live_tensors)
        return self

    @staticmethod
    def right_consolidate(
        pmapping_groups: list["PmappingGroup"],
        live_tensors: set[str],
        shared_tensors: set[str] = None,
        pbar: str = None,
        parallelize: bool = True,
    ) -> list["PmappingGroup"]:
        """Right-consolidate every group, optionally in parallel with a pbar."""

        def job(s):
            return s._right_consolidate(live_tensors, shared_tensors)

        if not parallelize:
            return [job(s) for s in pmapping_groups]

        return parallel([delayed(job)(s) for s in pmapping_groups], pbar=pbar)

    @staticmethod
    def left_consolidate(
        pmapping_groups: list["PmappingGroup"],
        live_tensors: set[str],
        pbar: str = None,
        parallelize: bool = True,
    ) -> list["PmappingGroup"]:
        """Left-consolidate every group, optionally in parallel with a pbar."""

        def job(s):
            return s._left_consolidate(live_tensors)

        if not parallelize:
            return [job(s) for s in pmapping_groups]

        return parallel([delayed(job)(s) for s in pmapping_groups], pbar=pbar)

    def _hashable_attrs(self):
        # Frozen view of the tensor index so the group can participate in
        # hashing/equality alongside its mappings.
        return self.mappings, fzs(self.tensors.items())

    @staticmethod
    def concat(
        pmapping_groups: Iterable["PmappingGroup"],
        allow_different_compatibilies: bool = False,
    ) -> "PmappingGroup":
        """Concatenate groups into one, renaming each to the first group's
        compatibility. Unless allowed, all compatibilities (modulo symbolic
        tile patterns) must match."""
        pmapping_groups = list(pmapping_groups)
        assert len(pmapping_groups) > 0, "Cannot concat empty list of PmappingGroups"
        if not allow_different_compatibilies:
            s = set(
                s.compatibility.clear_symbolic_tile_patterns() for s in pmapping_groups
            )
            if len(s) > 1:
                # Diagnostic path: find the first mismatching pair and fail
                # with a targeted message.
                a = pmapping_groups[0]
                for b in pmapping_groups[1:]:
                    if a.compatibility != b.compatibility:
                        break
                PmappingGroup.combine_combineable((a, b), "All")
                assert (
                    a == b
                ), f"Cannot concat PmappingGroups with different compatibilies:\n\t{a}\n\t{b}"
            assert len(s) == 1, (
                f"Cannot concat PmappingGroups with different compatibilies:\n\t"
                + "\n\t".join(str(s2) for s2 in s)
            )

        c0 = pmapping_groups[0].compatibility
        to_concat = [pmapping_groups[0]] + [
            s.rename_compatibility(c0) for s in pmapping_groups[1:]
        ]
        return PmappingGroup(
            c0, PmappingDataframe.concat([s.mappings for s in to_concat])
        )

    def rename_compatibility(self, new_c: Compatibility) -> "PmappingGroup":
        # NOTE: return annotation corrected — this returns a PmappingGroup,
        # not a Compatibility.
        c, renamed = self.compatibility._rename_to_match(new_c)
        return PmappingGroup(c, self.mappings.rename(renamed))

    @staticmethod
    def _group(
        pmapping_groups: list["PmappingGroup"],
        live_tensors: set[str] | Literal["All"],
        clear_tile_patterns_and_reservation_indices: bool = False,
        include_permutations: bool = False,
        clear_symbolic_tile_patterns: bool = False,
        try_permute_into_equivalent: bool = False,
    ) -> (
        dict[Compatibility, list["PmappingGroup"]]
        | dict[Compatibility, list[tuple["PmappingGroup", list[int]]]]
    ):
        """
        Clears dead tensors (may keep loops), then group PmappingGroups based on
        compatibility.
        """
        grouped = defaultdict(list)

        def clear(c: Compatibility):
            if clear_symbolic_tile_patterns:
                c = c.clear_symbolic_tile_patterns()
            if clear_tile_patterns_and_reservation_indices:
                return c.clear_tile_patterns_and_reservation_indices()
            return c

        for s in pmapping_groups:
            compatibility = s.compatibility.clear_dead_tensors(live_tensors)

            if include_permutations or try_permute_into_equivalent:
                keys = compatibility.make_equivalent_permutations()
                for t, loop_changes in keys:
                    # Line below DOES NOT MUTATE. It's a check that the
                    # permutation works.
                    s.compatibility.permute(loop_changes)
                    grouped[clear(t)].append((s, loop_changes))
            else:
                grouped[clear(compatibility)].append(s)

        if clear_tile_patterns_and_reservation_indices:
            for k in grouped:
                assert (
                    len(k.reservation_indices) == 0
                ), f"Extra reservation indices are not empty: {k.reservation_indices}"

        if try_permute_into_equivalent:
            assert not include_permutations
            # Greedy cover: assign each group to the largest equivalence class
            # it belongs to, applying the permutation that puts it there.
            new_grouped = {}
            pmgroups_remaining = {id(s) for s in pmapping_groups}
            for c, g in sorted(grouped.items(), key=lambda x: len(x[1]), reverse=True):
                if not pmgroups_remaining:
                    break
                g = [
                    (s, loop_changes)
                    for s, loop_changes in g
                    if id(s) in pmgroups_remaining
                ]
                if g:
                    pmgroups_remaining -= {id(s) for s, _ in g}
                    permuted = [
                        PmappingGroup(s.compatibility.permute(lc), s.mappings)
                        for s, lc in g
                    ]
                    new_grouped[c] = permuted
            grouped = new_grouped

        return grouped

    @staticmethod
    def combine_combineable(
        pmapping_groups: list["PmappingGroup"],
        live_tensors: set[str] | Literal["All"],
        allow_different_compatibilies: bool = False,
        combine_reservations: bool = True,
        pbar_postfix: str = "",
    ) -> list["PmappingGroup"]:
        """Concatenate groups whose compatibilities are equivalent (possibly
        after permutation). Groups with reservations can be held out via
        combine_reservations=False. Empty groups are dropped."""
        pmapping_groups = [s for s in pmapping_groups if len(s.mappings.data) > 0]
        no_combine = []
        if not combine_reservations:
            has_reservations = [s.mappings.has_reservations() for s in pmapping_groups]
            no_combine = [s for s, h in zip(pmapping_groups, has_reservations) if h]
            pmapping_groups = [
                s for s, h in zip(pmapping_groups, has_reservations) if not h
            ]
        groups = list(
            PmappingGroup._group(
                pmapping_groups,
                live_tensors,
                clear_symbolic_tile_patterns=True,
                try_permute_into_equivalent=True,
            ).values()
        )
        groups_with_one = [g[0] for g in groups if len(g) == 1]
        if len(groups_with_one) == len(groups):
            # Nothing to combine; skip the parallel concat entirely.
            return groups_with_one + no_combine

        others = parallel(
            [
                delayed(PmappingGroup.concat)(g, allow_different_compatibilies)
                for g in groups
                if len(g) > 1
            ],
            pbar=f"Grouping pmappings{pbar_postfix}",
        )
        return groups_with_one + others + no_combine

    @staticmethod
    def filter_by_tensors(
        pmapping_groups: list["PmappingGroup"] | dict[Compatibility, Any],
        tensors: set[str],
    ) -> list["PmappingGroup"]:
        """Keep only groups whose tensors match the given tensors.

        NOTE(review): despite the ``set[str]`` annotation, elements of
        ``tensors`` are accessed via ``.name`` and compared with ``!=`` to
        tensor objects, so callers pass tensor objects ("*" name is a
        wildcard) — confirm and tighten the annotation upstream.
        """

        def check(tensors_to_check):
            for t in tensors_to_check:
                for t2 in tensors:
                    if (t2.name == "*" or t.name == t2.name) and t != t2:
                        return False
            return True

        tensors = set(tensors)
        if isinstance(pmapping_groups, list):
            return [s for s in pmapping_groups if check(s.compatibility.tensors)]
        if isinstance(pmapping_groups, dict):
            return {k: v for k, v in pmapping_groups.items() if check(k.tensors)}
        raise ValueError(f"Invalid type {type(pmapping_groups)}")

    @staticmethod
    def group(
        pmapping_groups: list["PmappingGroup"], live_tensors: set[str]
    ) -> dict[tuple[Compatibility, ...], list[tuple["PmappingGroup", list[int]]]]:
        """Group by compatibility with permutations included and tile patterns
        / reservation indices cleared."""
        x = PmappingGroup._group(
            pmapping_groups,
            live_tensors,
            clear_tile_patterns_and_reservation_indices=True,
            include_permutations=True,
        )
        return x

    @staticmethod
    def remove_dead_tensors(
        pmapping_groups: list["PmappingGroup"], live_tensors: set[str]
    ):
        """Drop tracked tensor reservations that are no longer live (in place)."""
        for s in pmapping_groups:
            for t in list(s.tensors):
                if t not in live_tensors:
                    del s.tensors[t]
|
|
File without changes
|
|
@@ -0,0 +1,360 @@
|
|
|
1
|
+
from collections import defaultdict
|
|
2
|
+
import logging
|
|
3
|
+
from typing import List
|
|
4
|
+
from accelforge._accelerated_imports import np
|
|
5
|
+
from accelforge.frontend._workload_isl._symbolic import PartiallyRelevant, Relevant
|
|
6
|
+
import accelforge.frontend.arch as arch
|
|
7
|
+
from accelforge.frontend.arch import (
|
|
8
|
+
Comparison,
|
|
9
|
+
_MinUsageConstraintLambda,
|
|
10
|
+
_TileShapeConstraintLambda,
|
|
11
|
+
_LoopBoundsConstraintLambda,
|
|
12
|
+
_ConstraintLambda,
|
|
13
|
+
)
|
|
14
|
+
from accelforge.frontend.mapping import (
|
|
15
|
+
Loop,
|
|
16
|
+
MappingNode,
|
|
17
|
+
TensorHolder,
|
|
18
|
+
Temporal,
|
|
19
|
+
Spatial,
|
|
20
|
+
)
|
|
21
|
+
from accelforge.frontend.renames import TensorName
|
|
22
|
+
from accelforge.frontend.workload import EinsumName, RankVariable
|
|
23
|
+
from accelforge.util._setexpressions import InvertibleSet
|
|
24
|
+
from accelforge.util._frozenset import fzs
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
# =================================================================================================
|
|
28
|
+
# Attach constraints to mapping
|
|
29
|
+
# =================================================================================================
|
|
30
|
+
class MappingConstraints:
    """Constraints attached to a single Einsum's mapping.

    Holds tile-shape, loop-bounds, and per-(component, name) minimum-usage
    constraint lambdas, and resolves each constraint's target mapping nodes to
    node/loop indices once the mapping's node list is fixed.
    """

    def __init__(self):
        self.tile_shape_constraints: list[_TileShapeConstraintLambda] = []
        self.loop_bounds_constraints: list[_LoopBoundsConstraintLambda] = []
        # Keyed by (component_name, name) so each usage site has at most one
        # constraint.
        self.min_usage_constraints: dict[tuple[str, str], _MinUsageConstraintLambda] = (
            {}
        )

    def get_all_constraints(self) -> list[_ConstraintLambda]:
        """All constraints as one flat list (tile-shape, loop-bounds, min-usage)."""
        return (
            self.tile_shape_constraints
            + self.loop_bounds_constraints
            + list(self.min_usage_constraints.values())
        )

    def check_tile_shape_constraints(
        self, tile_shapes: np.ndarray, complete_indices: list[int]
    ):
        """Boolean mask over rows of ``tile_shapes`` passing every tile-shape
        constraint; each constraint sees only its target loops' columns."""
        # dtype=bool: the np.bool alias was removed in NumPy 1.24 and raised
        # AttributeError here.
        mask = np.ones(tile_shapes.shape[0], dtype=bool)
        for c in self.tile_shape_constraints:
            mask = mask & c(complete_indices, tile_shapes[:, c._target_loop_indices])
        return mask

    def check_min_usage_constraints(
        self,
        component_name: str,
        name: str,
        usage: np.ndarray,
        complete_indices: list[int],
    ):
        """Boolean mask for the (component_name, name) min-usage constraint;
        all-True when no such constraint is registered."""
        if (component_name, name) not in self.min_usage_constraints:
            # dtype=bool: the np.bool alias was removed in NumPy 1.24.
            return np.ones(usage.shape[0], dtype=bool)

        return self.min_usage_constraints[(component_name, name)](
            complete_indices, usage
        )

    def set_loop_indices(self, nodes: list[MappingNode]):
        """Resolve every constraint's target nodes to indices into ``nodes``
        and into the Loop-only sublist. Must be called before the check_*
        methods, which rely on ``_target_loop_indices``."""
        loops = [n for n in nodes if isinstance(n, Loop)]
        for c in self.get_all_constraints():
            c._target_node_indices = [nodes.index(t) for t in c.target_mapping_nodes]
            c._target_loop_indices = [loops.index(t) for t in c.target_mapping_nodes]

        # Min usage constraints also depend on the loop ABOVE the target loop
        # because the loop above determines the number of tiles
        for c in self.min_usage_constraints.values():
            # Rank variables must be unique between mapping nodes
            rank_variables = set(t.rank_variable for t in c.target_mapping_nodes)
            assert len(rank_variables) == len(
                c.target_mapping_nodes
            ), "Rank variables must be unique between mapping nodes"

            for target_mapping_node in c.target_mapping_nodes:
                assert isinstance(target_mapping_node, Spatial)
                # Walk upward from just above the target spatial loop to the
                # nearest loop over one of the constraint's rank variables.
                loop_index = loops.index(target_mapping_node) - 1
                while loop_index >= 0:
                    loop = loops[loop_index]
                    if loop.rank_variable in rank_variables:
                        c._target_loop_indices.append(loop_index)
                        c._target_node_indices.append(nodes.index(loop))
                        break
                    loop_index -= 1
|
|
92
|
+
|
|
93
|
+
def clear_constrained_to_one(
    self, mapping: list["MappingNode"], einsum_name: EinsumName
) -> list["MappingNode"]:
    """Prune loops whose bounds are constrained to exactly 1.

    Loops targeted only by constrained-to-one loop-bounds constraints can be
    deleted from the mapping (a bound of 1 makes them no-ops). Loops that
    also appear in other constraints are kept, with a warning. Returns the
    pruned mapping; the constrained-to-one constraints themselves are
    dropped from ``self.loop_bounds_constraints``.
    """
    # Not constrained to one --> Can't remove
    node2constraints = defaultdict(list)
    do_not_remove = set()
    for c in self.tile_shape_constraints:
        for t in c.target_mapping_nodes:
            node2constraints[id(t)].append(c)
            do_not_remove.add(id(t))
    for c in self.loop_bounds_constraints:
        if not c.constraint._constrained_to_one():
            for t in c.target_mapping_nodes:
                node2constraints[id(t)].append(c)
                do_not_remove.add(id(t))

    # Constrained to one --> remove iff not in do_not_remove
    to_remove = set()
    for c in self.loop_bounds_constraints:
        if c.constraint._constrained_to_one():
            # id() keys throughout: nodes are tracked by identity, not ==.
            my_remove = set(id(t) for t in c.target_mapping_nodes)
            if my_remove & do_not_remove:
                loops = [n for n in mapping if id(n) in my_remove]
                # p: singular/plural switch for the warning text.
                p = len(loops) == 1
                loops = (", ".join(n.compact_str() for n in loops)).strip()
                isare = "is" if p else "are"
                # NOTE(review): `t` here is the stale loop variable left over
                # from the collection loops above, not one of `c`'s targets —
                # `all_others` likely should enumerate the constraints on the
                # nodes in `my_remove & do_not_remove`. TODO confirm and fix.
                all_others = ", ".join(
                    str(c2) for c2 in node2constraints[id(t)] if c2 != c
                )
                logging.warning(
                    f"For Einsum {einsum_name}, loop{'s' * (not p)} {loops} "
                    f"{isare} set to be removed by {c} and also appear{'s' * p} in "
                    f"{all_others}. The loop{'s' * (not p)} will not be removed "
                    f"from the mapping, but it may be subject to conflicting "
                    f"constraints."
                )

            # NOTE(review): the warning promises conflicting loops will NOT
            # be removed, yet `to_remove` receives every target of `c`
            # without subtracting `do_not_remove` — verify intended behavior.
            c.target_mapping_nodes = [
                t for t in c.target_mapping_nodes if id(t) not in my_remove
            ]
            to_remove.update(my_remove)
    # The constrained-to-one constraints are now fully handled; drop them.
    self.loop_bounds_constraints = [
        c
        for c in self.loop_bounds_constraints
        if not c.constraint._constrained_to_one()
    ]

    # Scrub removed nodes from every surviving constraint's target list.
    for c in self.get_all_constraints():
        c.target_mapping_nodes = [
            n for n in c.target_mapping_nodes if id(n) not in to_remove
        ]

    return [m for m in mapping if id(m) not in to_remove]
|
|
146
|
+
|
|
147
|
+
def pretty_str(self) -> str:
    """Return a human-readable listing of all constraints, grouped by kind.

    Each constraint line is prefixed with its index in the combined
    ``get_all_constraints()`` list so entries can be cross-referenced.
    """
    all_constraints = self.get_all_constraints()

    def numbered(constraint) -> str:
        return f"\t{all_constraints.index(constraint)} {constraint.pretty_str()}\n"

    pieces = ["Tile shape constraints:\n"]
    pieces.extend(numbered(c) for c in self.tile_shape_constraints)
    pieces.append("Loop bounds constraints:\n")
    pieces.extend(numbered(c) for c in self.loop_bounds_constraints)
    pieces.append("Min usage constraints:\n")
    pieces.extend(numbered(c) for c in self.min_usage_constraints.values())
    return "".join(pieces)
|
|
160
|
+
|
|
161
|
+
def remove_missing_targets(self, mapping: list[MappingNode]):
|
|
162
|
+
for c in self.get_all_constraints():
|
|
163
|
+
c.target_mapping_nodes = [n for n in c.target_mapping_nodes if n in mapping]
|
|
164
|
+
|
|
165
|
+
self.tile_shape_constraints = [c for c in self.tile_shape_constraints if c]
|
|
166
|
+
self.loop_bounds_constraints = [c for c in self.loop_bounds_constraints if c]
|
|
167
|
+
self.min_usage_constraints = {
|
|
168
|
+
k: c for k, c in self.min_usage_constraints.items() if c
|
|
169
|
+
}
|
|
170
|
+
|
|
171
|
+
|
|
172
|
+
def first_tensor_holder_index(mapping: list["MappingNode"], memory_name: str) -> int:
    """Return the position of the first TensorHolder in ``mapping`` whose
    component is ``memory_name``, or None when no such node exists."""
    for position, node in enumerate(mapping):
        if not isinstance(node, TensorHolder):
            continue
        if node.component == memory_name:
            return position
    return None
|
|
177
|
+
|
|
178
|
+
|
|
179
|
+
def constrained_loops(
|
|
180
|
+
mapping: list["MappingNode"],
|
|
181
|
+
rank_variables: set[RankVariable],
|
|
182
|
+
start_index: int = None,
|
|
183
|
+
look_behind: bool = False,
|
|
184
|
+
component: str = None,
|
|
185
|
+
one_loop_per_rank_variable: bool = True,
|
|
186
|
+
) -> list[Loop]:
|
|
187
|
+
nodes = []
|
|
188
|
+
remaining_rank_variables = set(rank_variables)
|
|
189
|
+
|
|
190
|
+
if look_behind:
|
|
191
|
+
to_check = list(enumerate(mapping))
|
|
192
|
+
to_check.reverse()
|
|
193
|
+
if start_index is not None:
|
|
194
|
+
to_check = [
|
|
195
|
+
m for i, m in to_check if start_index is None or i <= start_index
|
|
196
|
+
]
|
|
197
|
+
else:
|
|
198
|
+
to_check = list(enumerate(mapping))
|
|
199
|
+
to_check = [m for i, m in to_check if start_index is None or i >= start_index]
|
|
200
|
+
|
|
201
|
+
for m in to_check:
|
|
202
|
+
if not isinstance(m, Loop):
|
|
203
|
+
continue
|
|
204
|
+
if component is not None and (
|
|
205
|
+
not isinstance(m, Spatial) or m.component != component
|
|
206
|
+
):
|
|
207
|
+
continue
|
|
208
|
+
assert isinstance(m.rank_variable, RankVariable)
|
|
209
|
+
if m.rank_variable in remaining_rank_variables:
|
|
210
|
+
nodes.append(m)
|
|
211
|
+
if one_loop_per_rank_variable:
|
|
212
|
+
remaining_rank_variables.discard(m.rank_variable)
|
|
213
|
+
# TODO: what is this supposed to do?
|
|
214
|
+
# for r in remaining_rank_variables:
|
|
215
|
+
# assert (
|
|
216
|
+
# component is None
|
|
217
|
+
# ), "There should be a spatial loop for every rank variable"
|
|
218
|
+
return nodes
|
|
219
|
+
|
|
220
|
+
|
|
221
|
+
def get_constraints(
    flattened_arch: list[arch.Leaf],
    mapping: List[MappingNode],
    symbol_table: dict[str, InvertibleSet],
    einsum_name: EinsumName,
    tensor_to_relevancy: dict[
        TensorName, dict[RankVariable, Relevant | PartiallyRelevant]
    ],
) -> tuple[List[MappingNode], MappingConstraints]:
    """Build the MappingConstraints for ``mapping`` from the architecture.

    Collects, per memory: tile-shape constraints and no-refetch
    (bound == 1) constraints; per memory/fanout spatial dimension:
    loop-bounds, reuse, and min-usage constraints; plus per-node
    constrained-to-one spatial constraints. Finally prunes loops that are
    constrained to 1 and returns the (possibly shortened) mapping together
    with the constraints object.
    """

    constraints = MappingConstraints()

    # Tensor constraints
    for m in flattened_arch:
        # Ignore if not a memory
        if not isinstance(m, arch.Memory):
            continue

        # Ignore if it doesn't hold any tensors
        if (index := first_tensor_holder_index(mapping, m.name)) is None:
            continue

        # Tile shape constraints
        for c in m.tensors.tile_shape:
            # Only loops above this memory's first tensor holder can shape
            # its tiles, hence the backward scan from index - 1.
            nodes = constrained_loops(
                mapping, c.expression, index - 1, look_behind=True
            )
            # One constraint per independent sub-expression.
            for exp in c._split_expression():
                new_nodes = [n for n in nodes if n.rank_variable in exp]
                constraint = _TileShapeConstraintLambda(c, new_nodes, exp)
                constraints.tile_shape_constraints.append(constraint)

        # Tensors of this memory that must not be refetched from above.
        exp = symbol_table[m.name] & m.tensors.no_refetch_from_above

        nodes = []
        for no_refetch in exp.iter_one_element_sets():
            # Start from the first index of the tensor holder, stop at index - 1
            start_index = 0
            n = next(iter(no_refetch))
            # First holder (anywhere) of tensor n.
            while start_index < len(mapping):
                if (
                    isinstance(mapping[start_index], TensorHolder)
                    and n in mapping[start_index].tensors
                ):
                    break
                start_index += 1

            # First holder of tensor n on THIS memory.
            end_index = start_index
            while end_index < len(mapping):
                if (
                    isinstance(mapping[end_index], TensorHolder)
                    and n in mapping[end_index].tensors
                    and mapping[end_index].component == m.name
                ):
                    break
                end_index += 1

            # Any non-fully-relevant loop between those holders would force a
            # refetch, so it must be constrained to a bound of 1.
            for i in range(start_index, end_index):
                if isinstance(mapping[i], Loop) and not isinstance(
                    tensor_to_relevancy[n][mapping[i].rank_variable], Relevant
                ):
                    if mapping[i] not in nodes:
                        nodes.append(mapping[i])

        if nodes:
            constraints.loop_bounds_constraints.append(
                _LoopBoundsConstraintLambda(
                    Comparison(expression=exp, operator="==", value=1), nodes, exp
                )
            )

    # Spatial constraints
    for m in flattened_arch:
        if not isinstance(m, (arch.Memory, arch.Fanout)):
            continue

        for dim in m.spatial:
            # Spatial loops mapped onto this component/dimension.
            loops = [
                n
                for n in mapping
                if isinstance(n, Spatial)
                and (n.component, n.name) == (m.name, dim.name)
            ]
            loop_bounds = list(dim.loop_bounds)
            if dim.reuse:
                # A reuse spec forbids spatial iteration over those rank
                # variables, expressed as a bound == 1 constraint.
                loop_bounds.append(
                    Comparison(
                        expression=dim.reuse.rank_variables,
                        operator="==",
                        value=1,
                    )
                )
                loop_bounds[-1]._str_repr = f"reuse {set(dim.reuse)}"

            # Loop bounds constraints
            if loop_bounds:
                for c in loop_bounds:
                    nodes = constrained_loops(loops, c.expression, component=m.name)
                    for exp in c._split_expression():
                        new_nodes = [l for l in loops if l.rank_variable in exp]
                        constraint = _LoopBoundsConstraintLambda(c, new_nodes, exp)
                        constraints.loop_bounds_constraints.append(constraint)

            # Min usage constraints
            target_mapping_nodes = [
                n
                for n in mapping
                if isinstance(n, Spatial)
                and n.component == m.name
                and n.name == dim.name
            ]
            if dim.min_usage > 0:
                if not target_mapping_nodes:
                    continue
                rank_variables = {t.rank_variable for t in target_mapping_nodes}
                constraint = _MinUsageConstraintLambda(
                    target_mapping_nodes,
                    rank_variables,
                    dim.min_usage,
                )
                key = (m.name, dim.name)
                constraints.min_usage_constraints[key] = constraint

                # Propagate the dimension's reuse permission to its loops.
                for t in target_mapping_nodes:
                    t._may_reuse = dim.may_reuse

    # Additional spatial constraints
    for m in mapping:
        if isinstance(m, Spatial) and m._constrained_to_one:
            constraints.loop_bounds_constraints.append(
                _LoopBoundsConstraintLambda(
                    Comparison(expression=m.rank_variable, operator="==", value=1),
                    [m],
                    m.rank_variable,
                )
            )

    # Drop loops that are now provably bound to 1.
    mapping = constraints.clear_constrained_to_one(mapping, einsum_name)

    return mapping, constraints
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
from .make_pmapping_templates import make_pmapping_templates
|