accelforge 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- accelforge/__init__.py +21 -0
- accelforge/_accelerated_imports.py +16 -0
- accelforge/_deprecate/_simanneal/evalmapping.py +271 -0
- accelforge/_deprecate/_simanneal/mapspaceglobals.py +298 -0
- accelforge/_deprecate/_simanneal/simanneal.py +666 -0
- accelforge/_deprecate/_simanneal/tracking.py +105 -0
- accelforge/_deprecate/_simanneal/wrappers.py +218 -0
- accelforge/_deprecate/_simanneal2/__init__.py +7 -0
- accelforge/_deprecate/_simanneal2/simanneal.py +493 -0
- accelforge/_deprecate/_simanneal2/tracking.py +116 -0
- accelforge/_deprecate/compatibility_util.py +181 -0
- accelforge/_deprecate/layerdeduplication/__init__.py +2 -0
- accelforge/_deprecate/layerdeduplication/group_similar_einsums.py +160 -0
- accelforge/_deprecate/layerdeduplication/grouped_einsums.py +84 -0
- accelforge/_deprecate/mapping_filter_tags/__init__.py +2 -0
- accelforge/_deprecate/mapping_filter_tags/ffmt.py +212 -0
- accelforge/_deprecate/mapping_filter_tags/onesplit.py +24 -0
- accelforge/_deprecate/mapping_filter_tags/util.py +24 -0
- accelforge/_deprecate/tags.py +69 -0
- accelforge/_deprecate/viz/__init__.py +0 -0
- accelforge/_deprecate/viz/interactive.py +159 -0
- accelforge/_deprecate/viz/reservationtree.py +307 -0
- accelforge/_deprecate/viz/ski_slope.py +88 -0
- accelforge/_version.py +15 -0
- accelforge/examples.py +39 -0
- accelforge/frontend/__init__.py +10 -0
- accelforge/frontend/_binding.py +129 -0
- accelforge/frontend/_workload_isl/__init__.py +2 -0
- accelforge/frontend/_workload_isl/_isl.py +149 -0
- accelforge/frontend/_workload_isl/_symbolic.py +141 -0
- accelforge/frontend/arch copy.py +1544 -0
- accelforge/frontend/arch.py +1642 -0
- accelforge/frontend/config.py +63 -0
- accelforge/frontend/mapper/__init__.py +5 -0
- accelforge/frontend/mapper/ffm.py +126 -0
- accelforge/frontend/mapper/mapper.py +7 -0
- accelforge/frontend/mapper/metrics.py +30 -0
- accelforge/frontend/mapping/__init__.py +1 -0
- accelforge/frontend/mapping/mapping.py +1736 -0
- accelforge/frontend/model.py +14 -0
- accelforge/frontend/renames.py +150 -0
- accelforge/frontend/spec copy.py +230 -0
- accelforge/frontend/spec.py +301 -0
- accelforge/frontend/variables.py +12 -0
- accelforge/frontend/workload.py +952 -0
- accelforge/mapper/FFM/__init__.py +9 -0
- accelforge/mapper/FFM/_join_pmappings/__init__.py +0 -0
- accelforge/mapper/FFM/_join_pmappings/compatibility.py +653 -0
- accelforge/mapper/FFM/_join_pmappings/compress_pmappings.py +140 -0
- accelforge/mapper/FFM/_join_pmappings/join_pmappings.py +703 -0
- accelforge/mapper/FFM/_join_pmappings/pmapping_dataframe.py +901 -0
- accelforge/mapper/FFM/_join_pmappings/pmapping_group.py +337 -0
- accelforge/mapper/FFM/_make_pmappings/contraints/__init__.py +0 -0
- accelforge/mapper/FFM/_make_pmappings/contraints/constraints.py +360 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/__init__.py +1 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_loops.py +373 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_pmapping_templates.py +463 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_reservations.py +95 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_storage_order.py +382 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_storages.py +155 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmappings.py +411 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/__init__.py +1 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/make_pmappings_from_templates.py +407 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/make_tile_shapes.py +1681 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/run_model.py +170 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/symbol_relations.py +174 -0
- accelforge/mapper/FFM/_make_pmappings/pmapper_job.py +282 -0
- accelforge/mapper/FFM/_pareto_df/df_convention.py +273 -0
- accelforge/mapper/FFM/_pareto_df/pareto copy.py +836 -0
- accelforge/mapper/FFM/_pareto_df/pareto.py +508 -0
- accelforge/mapper/FFM/data.py +61 -0
- accelforge/mapper/FFM/main copy.py +236 -0
- accelforge/mapper/FFM/main.py +208 -0
- accelforge/mapper/FFM/mappings.py +510 -0
- accelforge/mapper/FFM/pmappings.py +310 -0
- accelforge/mapper/__init__.py +4 -0
- accelforge/mapper.py +0 -0
- accelforge/model/__init__.py +1 -0
- accelforge/model/_looptree/__init__.py +0 -0
- accelforge/model/_looptree/accesses.py +335 -0
- accelforge/model/_looptree/capacity/__init__.py +1 -0
- accelforge/model/_looptree/capacity/aggregators.py +36 -0
- accelforge/model/_looptree/capacity/capacity.py +47 -0
- accelforge/model/_looptree/energy.py +150 -0
- accelforge/model/_looptree/equivalent_ranks.py +29 -0
- accelforge/model/_looptree/latency/__init__.py +1 -0
- accelforge/model/_looptree/latency/latency.py +98 -0
- accelforge/model/_looptree/latency/memory.py +120 -0
- accelforge/model/_looptree/latency/processors.py +92 -0
- accelforge/model/_looptree/mapping_utilities.py +71 -0
- accelforge/model/_looptree/reuse/__init__.py +4 -0
- accelforge/model/_looptree/reuse/isl/__init__.py +1 -0
- accelforge/model/_looptree/reuse/isl/des.py +59 -0
- accelforge/model/_looptree/reuse/isl/isl_functions.py +374 -0
- accelforge/model/_looptree/reuse/isl/mapping_to_isl/__init__.py +4 -0
- accelforge/model/_looptree/reuse/isl/mapping_to_isl/analyze_mapping.py +297 -0
- accelforge/model/_looptree/reuse/isl/mapping_to_isl/skews_from_mapping.py +236 -0
- accelforge/model/_looptree/reuse/isl/mapping_to_isl/tiling.py +685 -0
- accelforge/model/_looptree/reuse/isl/mapping_to_isl/types.py +188 -0
- accelforge/model/_looptree/reuse/isl/spatial.py +260 -0
- accelforge/model/_looptree/reuse/isl/temporal.py +182 -0
- accelforge/model/_looptree/reuse/symbolic/__init__.py +1 -0
- accelforge/model/_looptree/reuse/symbolic/symbolic copy 2.py +1346 -0
- accelforge/model/_looptree/reuse/symbolic/symbolic copy.py +1408 -0
- accelforge/model/_looptree/reuse/symbolic/symbolic.py +1396 -0
- accelforge/model/_looptree/run.py +122 -0
- accelforge/model/_looptree/types.py +26 -0
- accelforge/model/_looptree/visualization/__init__.py +0 -0
- accelforge/model/_looptree/visualization/occupancy.py +11 -0
- accelforge/model/main.py +222 -0
- accelforge/plotting/__init__.py +2 -0
- accelforge/plotting/mappings.py +219 -0
- accelforge/plotting/specs.py +57 -0
- accelforge/util/__init__.py +4 -0
- accelforge/util/_base_analysis_types.py +24 -0
- accelforge/util/_basetypes.py +1089 -0
- accelforge/util/_frozenset.py +36 -0
- accelforge/util/_isl.py +29 -0
- accelforge/util/_itertools.py +14 -0
- accelforge/util/_mathfuncs.py +57 -0
- accelforge/util/_parse_expressions.py +339 -0
- accelforge/util/_picklecache.py +32 -0
- accelforge/util/_setexpressions.py +268 -0
- accelforge/util/_sympy/__init__.py +0 -0
- accelforge/util/_sympy/broadcast_max.py +18 -0
- accelforge/util/_visualization.py +112 -0
- accelforge/util/_yaml.py +579 -0
- accelforge/util/parallel.py +193 -0
- accelforge-0.0.1.dist-info/METADATA +64 -0
- accelforge-0.0.1.dist-info/RECORD +258 -0
- accelforge-0.0.1.dist-info/WHEEL +5 -0
- accelforge-0.0.1.dist-info/licenses/LICENSE +19 -0
- accelforge-0.0.1.dist-info/top_level.txt +5 -0
- docs/_build/html/_sources/fastfusion.frontend.mapper.rst.txt +37 -0
- docs/_build/html/_sources/fastfusion.frontend.rst.txt +70 -0
- docs/_build/html/_sources/fastfusion.frontend.workload.rst.txt +21 -0
- docs/_build/html/_sources/fastfusion.mapper.FFM.rst.txt +37 -0
- docs/_build/html/_sources/fastfusion.mapper.rst.txt +18 -0
- docs/_build/html/_sources/fastfusion.rst.txt +20 -0
- docs/_build/html/_sources/fastfusion.util.rst.txt +21 -0
- docs/_build/html/_sources/index.rst.txt +87 -0
- docs/_build/html/_sources/modules.rst.txt +7 -0
- docs/_build/html/_sources/notes/citation.rst.txt +45 -0
- docs/_build/html/_sources/notes/definitions.rst.txt +43 -0
- docs/_build/html/_sources/notes/faqs.rst.txt +39 -0
- docs/_build/html/_sources/notes/modeling/accelerator_energy_latency.rst.txt +72 -0
- docs/_build/html/_sources/notes/modeling/component_energy_area.rst.txt +96 -0
- docs/_build/html/_sources/notes/modeling/mapping.rst.txt +100 -0
- docs/_build/html/_sources/notes/modeling.rst.txt +33 -0
- docs/_build/html/_sources/notes/parsing/arithmetic_parsing.rst.txt +136 -0
- docs/_build/html/_sources/notes/parsing/setexpressions.rst.txt +63 -0
- docs/_build/html/_sources/notes/parsing/yaml_parsing.rst.txt +176 -0
- docs/_build/html/_sources/notes/quickstart_and_installation.rst.txt +9 -0
- docs/_build/html/_sources/notes/spec/architecture.rst.txt +133 -0
- docs/_build/html/_sources/notes/spec/mapping.rst.txt +12 -0
- docs/_build/html/_sources/notes/spec/workload.rst.txt +83 -0
- docs/_build/html/_sources/notes/spec.rst.txt +36 -0
- docs/source/_ext/include_attrs.py +213 -0
- docs/source/_ext/include_docstring.py +364 -0
- docs/source/_ext/include_functions.py +154 -0
- docs/source/_ext/include_notebook.py +131 -0
- docs/source/_ext/include_yaml.py +119 -0
- docs/source/_ext/inherited_attributes.py +222 -0
- docs/source/_ext/paths.py +4 -0
- docs/source/conf.py +79 -0
- examples/arches/compute_in_memory/_include.yaml +74 -0
- examples/arches/compute_in_memory/_include_functions.py +229 -0
- examples/arches/compute_in_memory/_load_spec.py +57 -0
- examples/arches/compute_in_memory/components/c2c_multiplier.py +181 -0
- examples/arches/compute_in_memory/components/dac_c2c_r2r.py +605 -0
- examples/arches/compute_in_memory/components/misc.py +195 -0
- examples/arches/compute_in_memory/components/util/bit_functions.py +51 -0
- examples/arches/compute_in_memory/components/zero_comparator.py +92 -0
- examples/arches/compute_in_memory/isaac.yaml +233 -0
- examples/arches/compute_in_memory/memory_cells/ecram_demo.yaml +63 -0
- examples/arches/compute_in_memory/memory_cells/rram_example.yaml +63 -0
- examples/arches/compute_in_memory/memory_cells/rram_isaac_isca_2016.yaml +64 -0
- examples/arches/compute_in_memory/memory_cells/rram_neurosim_default.yaml +63 -0
- examples/arches/compute_in_memory/memory_cells/rram_raella_isca_2023.yaml +70 -0
- examples/arches/compute_in_memory/memory_cells/rram_wan_nature_2022.yaml +63 -0
- examples/arches/compute_in_memory/memory_cells/sram_colonnade_jssc_2021.yaml +63 -0
- examples/arches/compute_in_memory/memory_cells/sram_example.yaml +63 -0
- examples/arches/compute_in_memory/memory_cells/sram_jia_jssc_2020.yaml +63 -0
- examples/arches/compute_in_memory/memory_cells/sram_sinangil_jssc_2021.yaml +63 -0
- examples/arches/compute_in_memory/memory_cells/sram_wang_vlsi_2022.yaml +63 -0
- examples/arches/compute_in_memory/wang_vlsi_2022.yaml +289 -0
- examples/arches/eyeriss.yaml +68 -0
- examples/arches/fanout_variations/at_glb.yaml +31 -0
- examples/arches/fanout_variations/at_glb_with_fanout_node.yaml +34 -0
- examples/arches/fanout_variations/at_mac.yaml +31 -0
- examples/arches/fanout_variations/at_mac_with_constraints.yaml +38 -0
- examples/arches/fanout_variations/at_mac_with_fanout_node.yaml +34 -0
- examples/arches/nvdla.yaml +47 -0
- examples/arches/simple.yaml +28 -0
- examples/arches/tpu_v4i.yaml +67 -0
- examples/mappings/unfused_matmuls_to_simple.yaml +33 -0
- examples/misc/component_annotated.yaml +33 -0
- examples/workloads/gpt3_6.7B.yaml +124 -0
- examples/workloads/matmuls.yaml +20 -0
- examples/workloads/mobilenet_28.yaml +81 -0
- examples/workloads/mobilenet_various_separate.yaml +106 -0
- examples/workloads/three_matmuls_annotated.yaml +59 -0
- notebooks/.ipynb_checkpoints/fastfusion_arch_study_michael-checkpoint.ipynb +359 -0
- notebooks/compute_in_memory/_scripts.py +339 -0
- notebooks/compute_in_memory/isaac.guide.ipynb +270 -0
- notebooks/compute_in_memory/wang_vlsi_2022.ipynb +602 -0
- notebooks/paths.py +4 -0
- notebooks/tutorials/.ipynb_checkpoints/1_FFM-checkpoint.ipynb +3110 -0
- notebooks/tutorials/FFM.ipynb +3498 -0
- notebooks/tutorials/_include.py +48 -0
- notebooks/tutorials/component_energy_area.ipynb +363 -0
- tests/Q_mapping.yaml +38 -0
- tests/__init__.py +0 -0
- tests/conv.mapping.yaml +27 -0
- tests/conv.workload.yaml +13 -0
- tests/conv_sym.mapping.yaml +43 -0
- tests/copy.mapping.yaml +35 -0
- tests/copy.workload.yaml +15 -0
- tests/distribuffers/__init__.py +0 -0
- tests/distribuffers/multicast/test_cases.yaml +482 -0
- tests/distribuffers/spec/binding/valid_bindings.yaml +97 -0
- tests/distribuffers/spec/distributed.yaml +100 -0
- tests/distribuffers/spec/logical_arch.yaml +32 -0
- tests/distribuffers/spec/physical_arch.yaml +69 -0
- tests/distribuffers/test_binding.py +48 -0
- tests/frontend/__init__.py +0 -0
- tests/frontend/test_mapping_viz.py +52 -0
- tests/mapper/__init__.py +0 -0
- tests/mapper/configs/conv1d/conv1d.mapping.yaml +31 -0
- tests/mapper/configs/conv1d/conv1d.workload.yaml +11 -0
- tests/mapper/configs/two_conv1d/two_conv1d.expected.yaml +38 -0
- tests/mapper/configs/two_conv1d/two_conv1d.mapping.yaml +54 -0
- tests/mapper/configs/two_conv1d/two_conv1d.workload.yaml +19 -0
- tests/mapper/test_mapping_to_isl.py +90 -0
- tests/mapper/test_spatial_reuse_analysis.py +67 -0
- tests/mapper/test_temporal_reuse_analysis.py +56 -0
- tests/mapper/util.py +58 -0
- tests/matmul.mapping.yaml +29 -0
- tests/matmul.workload.yaml +12 -0
- tests/matmul_spatial.mapping.yaml +44 -0
- tests/mha.renames.yaml +65 -0
- tests/mha.workload.yaml +67 -0
- tests/mha.yaml +59 -0
- tests/mha_full.workload.yaml +67 -0
- tests/mobilenet.workload.yaml +35 -0
- tests/mobilenet_long.workload.yaml +64 -0
- tests/pmappingcache.py +24 -0
- tests/processing_stage.arch.yaml +40 -0
- tests/snowcat.arch.yaml +36 -0
- tests/test_ffm_join_pmappings.py +106 -0
- tests/test_ffm_make_pmappings.py +82 -0
- tests/test_ffm_make_tile_shapes.py +49 -0
- tests/test_mapper.py +100 -0
- tests/test_model.py +37 -0
- tests/test_plotting.py +72 -0
- tests/test_processing_stage.py +46 -0
- tests/test_symbolic_model.py +248 -0
- tests/test_workload.py +141 -0
accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_pmapping_templates.py

@@ -0,0 +1,463 @@

import copy
from collections import defaultdict
import itertools
import logging
from typing import Any, Iterator, List
import uuid

from tqdm import tqdm

import accelforge.frontend.arch as arch
from accelforge.frontend.mapping import (
    Compute,
    Loop,
    Mapping,
    MappingNode,
    Spatial,
    TensorHolder,
    Temporal,
)
from accelforge.frontend.spec import Spec
from accelforge.frontend._workload_isl._isl import get_rank_variable_bounds
from accelforge.frontend._workload_isl._symbolic import (
    Relevant,
    get_rank_variable_relevancy,
    get_stride_and_halo,
    get_stride_and_halo_of_einsum,
    PartiallyRelevant,
)
from accelforge.frontend.workload import (
    TensorName,
    Einsum,
    EinsumName,
    RankVariable,
    Workload,
    isl_expression_has_variable,
    SymbolTable,
)
from accelforge.mapper.FFM._make_pmappings.make_pmapping_templates.make_storage_order import (
    get_tensor_choices,
)
from accelforge.mapper.FFM._make_pmappings.make_pmapping_templates.make_reservations import (
    get_reservation_choices,
)
from accelforge.mapper.FFM._make_pmappings.contraints.constraints import (
    MappingConstraints,
    get_constraints,
)
from accelforge.mapper.FFM._make_pmappings.make_pmapping_templates.make_loops import (
    insert_temporal_loops,
    insert_spatial_loops,
)
from accelforge.mapper.FFM._make_pmappings.pmapper_job import (
    Job,
    SameEinsumJobs,
)
from accelforge.model._looptree.reuse.symbolic import label_fused_loops
def unpack_loops_to_rank_variables(mapping: List[MappingNode]):
    mapping_new = []
    for node in mapping:
        if not isinstance(node, Loop) or not isinstance(node.rank_variable, set):
            mapping_new.append(node)
            continue

        for r in sorted(node.rank_variable):
            mapping_new.append(
                type(node)(
                    rank_variable=r,
                    **node.model_dump(exclude={"rank_variable"}, recursive=False),
                )
            )
    return mapping_new
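A loop node may carry a whole set of rank variables up to this point; here is a minimal standalone sketch of the unpacking step, using a stand-in `Loop` dataclass rather than the real `accelforge.frontend.mapping` node types:

```python
# Standalone sketch of unpack_loops_to_rank_variables: a loop over a set of
# rank variables is split into one loop per variable, in sorted order.
from dataclasses import dataclass

@dataclass
class Loop:  # stand-in for the real mapping node type
    rank_variable: object  # one rank variable, or a set of them

def unpack(mapping):
    out = []
    for node in mapping:
        if isinstance(node, Loop) and isinstance(node.rank_variable, set):
            out.extend(Loop(rank_variable=r) for r in sorted(node.rank_variable))
        else:
            out.append(node)
    return out

print(unpack([Loop(rank_variable={"n", "m"})]))
# [Loop(rank_variable='m'), Loop(rank_variable='n')]
```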
# =================================================================================================
# Iterate over mappings
# =================================================================================================
def place_missing_temporal_loops(
    mapping: List[MappingNode], einsum: Einsum, flattened_arch: list[arch.Leaf]
):
    """
    Adds temporal loops to the mapping to fill in any rank variables that are missing.
    This may occur if there are no points where it'd be helpful to add a non-fused loop,
    so we just need to add one somewhere.
    """
    # If any rank variables are missing, add them as high as possible. Copy the
    # set so the discard() calls below don't mutate the einsum's rank variables.
    rank_variables = set(einsum.rank_variables)
    for m in mapping:
        if isinstance(m, Temporal) and not m._fused:
            rank_variables.discard(m.rank_variable)

    # Insert point: right under the last backing & below any out-of-order fanouts
    fanouts = {}
    fanout = 1
    for node in flattened_arch:
        fanouts[node.name] = (fanout := fanout * node.get_fanout())

    insert_point = 0
    greatest_previous_fanout = 1
    for i in range(len(mapping)):
        if isinstance(mapping[i], TensorHolder):
            if mapping[i]._backing:
                insert_point = i + 1
            cur_fanout = fanouts[mapping[i].component]
            if cur_fanout < greatest_previous_fanout:
                insert_point = i + 1
            greatest_previous_fanout = max(greatest_previous_fanout, cur_fanout)

        # Put it below all the other temporals here in case we're lowering through them
        if isinstance(mapping[i], Temporal) and insert_point == i:
            insert_point = i + 1

    temporals = [Temporal(rank_variable=r) for r in sorted(rank_variables)]

    if insert_point == len(mapping):
        mapping.extend(temporals)
    else:
        for t in temporals:
            mapping.insert(insert_point, t)
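The cumulative-fanout table drives the "out-of-order fanout" check: a holder whose level has a smaller cumulative fanout than a level seen above it pushes the insert point down past it. A standalone sketch of that table (the level names and fanouts are made up):

```python
# Cumulative fanout per level, mirroring the walrus-assignment loop above:
# each entry is the product of all fanouts at or above that level.
levels = [("DRAM", 1), ("GLB", 1), ("PE", 16), ("MAC", 4)]  # (name, fanout)

fanouts, fanout = {}, 1
for name, f in levels:
    fanouts[name] = (fanout := fanout * f)

print(fanouts)  # {'DRAM': 1, 'GLB': 1, 'PE': 16, 'MAC': 64}
```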
def remove_unordered_spatial_temporal_loops(
    mapping: list[MappingNode],
    flattened_arch: list[arch.Leaf],
    einsum: Einsum,
    explore_unordered_spatial_loops: bool = True,
):
    fanout = 1
    fanouts = {}
    for node in flattened_arch:
        fanouts[node.name] = (fanout := fanout * node.get_fanout())

    index_exprs = einsum.indexing_expressions

    # Remove a temporal loop if:
    # - It's between a spatial loop and a storage node above that fanout in the arch
    # - It indexes into one of the same indexing expressions as the spatial loop

    disallowed_combinations: list[tuple[set[int], set[int]]] = []
    for i, node in enumerate(mapping):
        if not isinstance(node, Spatial):
            continue

        last_idx_to_check = _idx_of_lowest_tensor_holder_with_component_above_fanout(
            mapping, i, fanouts, node
        )
        to_check = mapping[i + 1 : last_idx_to_check]
        to_remove = set()
        for n in to_check:
            if isinstance(n, Temporal):
                for expr in index_exprs:
                    if not isl_expression_has_variable(expr, node.rank_variable):
                        continue
                    if not isl_expression_has_variable(expr, n.rank_variable):
                        continue
                    to_remove.add(id(n))
                    break

        if to_remove:
            disallowed_combinations.append((set([id(node)]), to_remove))

    if not explore_unordered_spatial_loops:
        disallowed_combinations = [x[1:] for x in disallowed_combinations]

    for combo in itertools.product(*disallowed_combinations):
        combo = set.union(set(), *combo)
        yield [n for n in mapping if id(n) not in combo]
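The product at the end chooses, independently for each conflicting spatial loop, whether to drop the spatial loop itself or the temporals it conflicts with. A standalone sketch with hypothetical node ids:

```python
import itertools

# One (spatial_ids, temporal_ids) pair per conflicting spatial loop.
disallowed = [({"S1"}, {"T1", "T2"}), ({"S2"}, {"T3"})]

for combo in itertools.product(*disallowed):
    print(sorted(set.union(set(), *combo)))
# ['S1', 'S2']           drop both spatial loops
# ['S1', 'T3']           drop S1; keep S2 and drop its temporals
# ['S2', 'T1', 'T2']
# ['T1', 'T2', 'T3']     keep both spatials, drop all conflicting temporals
```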
def _idx_of_lowest_tensor_holder_with_component_above_fanout(
    mapping, start_idx, fanouts, node
):
    """
    Return idx of the lowest tensor holder with component above fanout. If none
    found, returns the index right under start_idx (start_idx + 1).
    """
    for j in range(len(mapping) - 1, start_idx, -1):
        n = mapping[j]
        if (
            isinstance(n, TensorHolder)
            and fanouts[n.component] < fanouts[node.component]
        ):
            return j
    return start_idx + 1


def pad_with_bottom_loops(mapping: list[MappingNode], einsum: Einsum):
    rank_variables = einsum.rank_variables
    rank_var_to_count = defaultdict(lambda: 0)
    for node in mapping:
        if isinstance(node, Temporal):
            rank_var_to_count[node.rank_variable] += 1

    for rank_var in rank_variables:
        if rank_var_to_count[rank_var] < 2:
            mapping.append(Temporal(rank_variable=rank_var))
def _timeloop_style_even(mapping: list[MappingNode]):
    # Iterate through the mapping. If there are >2 TensorHolder nodes for the same
    # memory, move all below the 2nd to the same level as the 2nd.
    mapping = copy.deepcopy(mapping)
    memory2indices = defaultdict(list)
    i = 0
    while i < len(mapping):
        node = mapping[i]
        if not isinstance(mapping[i], TensorHolder):
            i += 1
            continue
        node: TensorHolder
        seen = memory2indices[node.component]
        mapping[i]._lower = False  # Lowering might re-uneven the reservations

        if len(seen) <= 1:
            seen.append(i)
        else:
            mapping.insert(seen[-1] + 1, mapping.pop(i))
        i += 1
    return mapping
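A standalone sketch of the evening rule with string stand-ins for the mapping nodes: every holder of a given memory past the second is hoisted up to sit right after the second one:

```python
from collections import defaultdict

def even(mapping):
    mapping = list(mapping)
    memory2indices, i = defaultdict(list), 0
    while i < len(mapping):
        if mapping[i].startswith("GLB"):  # stand-in for the TensorHolder check
            seen = memory2indices["GLB"]
            if len(seen) <= 1:
                seen.append(i)
            else:
                mapping.insert(seen[-1] + 1, mapping.pop(i))
        i += 1
    return mapping

print(even(["GLB:A", "GLB:B", "loop_m", "loop_n", "GLB:C"]))
# ['GLB:A', 'GLB:B', 'GLB:C', 'loop_m', 'loop_n']
```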
def assert_proper_fusion_labeling(
    mapping: list[MappingNode],
    fusable_tensors: set[TensorName],
    check_loops: bool = True,
):
    tensors = set()
    for i, t in enumerate(mapping):
        if not isinstance(t, TensorHolder):
            continue

        new = (set(t.tensors) - tensors) & fusable_tensors

        if new and check_loops:
            for j in range(i):
                if isinstance(mapping[j], Loop):
                    assert mapping[
                        j
                    ]._fused, f"Node {j} is not fused in {' '.join(m.compact_str() for m in mapping)}"
            assert (
                t._backing & fusable_tensors
            ) == new, f"Node {i} backing missing {new - t._backing} in {' '.join(m.compact_str() for m in mapping)}"
        tensors.update(new)
        tensors.update(t.tensors)
def get_initial_delta_choices(einsum_name: str, workload: Workload):
    stride_and_halo = get_stride_and_halo(workload)
    einsum = workload.einsums[einsum_name]

    choices = defaultdict(lambda: set([0]))
    consumer_chains = []
    stack = [[(None, einsum)]]
    while stack:
        cur_chain = stack.pop()
        last_tensor, last_einsum = cur_chain[-1]
        for tensor in last_einsum.output_tensor_names:
            einsums_with_tensor_as_input = workload.einsums_with_tensor_as_input(tensor)

            if len(einsums_with_tensor_as_input) == 0:
                consumer_chains.append(cur_chain)

            for next_einsum in einsums_with_tensor_as_input:
                stack.append(cur_chain + [(tensor, next_einsum)])

    for chain in consumer_chains:
        # Walk each chain backward, propagating delta choices from each consumer
        # to its producer.
        for (_, producer), (tensor, consumer) in zip(
            list(reversed(chain))[1:], reversed(chain)
        ):
            if tensor is None:
                break  # reached the head of the chain; done
            rank_stride_and_halo = stride_and_halo[(consumer.name, tensor)]

            for cons_rank_var in consumer.rank_variables:
                for prod_rank_var in producer.rank_variables:
                    for cons_choice in choices[cons_rank_var]:
                        if (prod_rank_var, cons_rank_var) not in rank_stride_and_halo:
                            continue
                        stride, halo = rank_stride_and_halo[
                            (prod_rank_var, cons_rank_var)
                        ]
                        choices[prod_rank_var].add(cons_choice * stride + halo)

    return choices


def get_ranks_with_tile_pattern(producer_name: EinsumName, workload: Workload):
    initial_choices = get_initial_delta_choices(producer_name, workload)
    return {
        rank_var
        for rank_var in workload.einsums[producer_name].rank_variables
        if len(initial_choices[rank_var]) > 1
    }
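The propagation rule is producer_delta = consumer_delta * stride + halo, applied backward along each producer-to-consumer chain; any rank whose choice set grows beyond {0} is then reported by get_ranks_with_tile_pattern. A worked numeric sketch (the stride and halo values are made up):

```python
from collections import defaultdict

choices = defaultdict(lambda: {0})

# Hypothetical chain: Einsum2 consumes Einsum1's output, and rank q of the
# consumer reads rank p of the producer with stride 1 and halo 2 (e.g. a
# 3-wide sliding window).
stride, halo = 1, 2
for cons_choice in set(choices["q"]):
    choices["p"].add(cons_choice * stride + halo)

print(choices["p"])  # {0, 2}: two delta choices, so rank p has a tile pattern
```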
def iterate_mappings_no_constraints(
    spec: Spec,
    einsum_name: str,
    flattened_arch: list[arch.Leaf],
    rank_variable_bounds: dict[RankVariable, int],
    job: Job,
) -> Iterator[tuple[Mapping, SymbolTable, arch.Compute, int]]:
    first_memory = None
    for node in flattened_arch:
        if isinstance(node, arch.Memory):
            first_memory = node
            break
    if first_memory is None:
        raise ValueError("No memory found in architecture")

    ranks_with_tile_pattern = get_ranks_with_tile_pattern(einsum_name, spec.workload)

    einsum = spec.workload.einsums[einsum_name]
    symbol_table = {r.name: r.source for r in einsum.renames}
    fusable_tensors = job.fusable_tensors

    for mapping, symbol_table, compute in get_tensor_choices(
        einsum_name,
        flattened_arch,
        symbol_table,
        spec,
        first_memory,
        fusable_tensors,
    ):
        logging.info(
            "\tGenerated tensor choices: " + ", ".join(m.compact_str() for m in mapping)
        )
        mapping = copy.deepcopy(mapping)
        for mapping, n_orders in insert_temporal_loops(
            mapping,
            einsum,
            first_memory,
            rank_variable_bounds,
            ranks_with_tile_pattern,
            spec.workload,
            spec.mapper.ffm._can_lower_outermost_memory,
            flattened_arch,
            spec.mapper.ffm.max_fused_loops,
        ):
            mapping = copy.deepcopy(mapping)
            insert_spatial_loops(mapping, einsum, flattened_arch)
            mapping = unpack_loops_to_rank_variables(mapping)
            if spec.mapper.ffm._timeloop_style_even:
                mapping = _timeloop_style_even(mapping)

            place_missing_temporal_loops(mapping, einsum, flattened_arch)
            label_fused_loops(mapping, fusable_tensors)
            assert_proper_fusion_labeling(mapping, fusable_tensors)
            yield mapping, symbol_table, compute, n_orders
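Structurally this is a pipeline of nested generators, where each stage multiplies the number of candidates produced by the stage above it; a tiny standalone sketch of the shape (stand-in stages, not the real ones):

```python
def tensor_choices():
    yield from ("placementA", "placementB")

def temporal_orders(placement):
    yield from ((placement, order) for order in ("mn", "nm"))

# Nesting the loops enumerates the cross product of choices.
templates = [t for p in tensor_choices() for t in temporal_orders(p)]
print(len(templates))  # 4 candidate pmapping templates
```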
def iterate_mappings_constraints(
    spec: Spec,
    einsum_names: list[str] | str,
    flattened_arch: list[arch.Leaf],
    rank_variable_bounds: dict[RankVariable, int],
    tensor_to_relevancy: dict[
        TensorName, dict[RankVariable, Relevant | PartiallyRelevant]
    ],
    job: Job,
) -> Iterator[tuple[Mapping, MappingConstraints, dict[str, str]]]:
    compute_name = flattened_arch[-1].name

    n_yielded = 0

    if isinstance(einsum_names, str):
        einsum_names = [einsum_names]

    for einsum_name in einsum_names:
        logging.info(
            f"Generating pmapping templates for compute {compute_name} Einsum "
            f"{einsum_name}"
        )

        for mapping, symbol_table, compute, n_orders in iterate_mappings_no_constraints(
            spec,
            einsum_name,
            flattened_arch,
            rank_variable_bounds,
            job,
        ):
            mapping, constraints = get_constraints(
                flattened_arch, mapping, symbol_table, einsum_name, tensor_to_relevancy
            )

            # This goes after the constraints because constraints may remove some
            # loops, giving us fewer that may be reordered.
            for mapping in remove_unordered_spatial_temporal_loops(
                mapping,
                flattened_arch,
                spec.workload.einsums[einsum_name],
                spec.mapper.ffm.out_of_order_hierarchy_explore_removing_spatials_for_more_temporals,
            ):
                constraints.remove_missing_targets(mapping)

                mapping.append(
                    Compute(
                        einsum=einsum_name,
                        component=compute_name,
                        component_object=compute,
                    )
                )

                # MAPPING MUST NOT BE MODIFIED AFTER constraints.set_loop_indices
                constraints.set_loop_indices(mapping)

                mapping = Mapping(nodes=[copy.copy(n) for n in mapping])
                mapping._n_loop_orders = n_orders
                yield mapping, constraints, symbol_table
                n_yielded += 1
                if n_yielded >= spec.mapper.ffm.max_pmapping_templates_per_einsum:
                    return
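The n_yielded counter caps the total number of templates across all einsum names; the same behavior could be had by wrapping the generator in itertools.islice, as in this sketch:

```python
import itertools

def templates():
    n = 0
    while True:  # stand-in for an unbounded template generator
        yield f"template{n}"
        n += 1

print(list(itertools.islice(templates(), 3)))
# ['template0', 'template1', 'template2']
```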
# =================================================================================================
# Top level
# =================================================================================================
def make_pmapping_templates(job: Job) -> SameEinsumJobs:
    compute_name = job.flattened_arch[-1].name

    job.tensor_to_relevancy = {
        tensor: get_rank_variable_relevancy(
            job.spec.workload.einsums[job.einsum_name], tensor
        )
        for tensor in job.spec.workload.einsums[job.einsum_name].tensor_names
    }

    mappings_constraints = tqdm(
        iterate_mappings_constraints(
            job.spec,
            job.einsum_name,
            job.flattened_arch,
            job.rank_variable_bounds,
            job.tensor_to_relevancy,
            job,
        ),
        desc=f"Generating pmapping templates for compute {compute_name} Einsum {job.einsum_name}",
    )

    stride_and_halo = get_stride_and_halo_of_einsum(
        job.einsum_name, job.spec.workload, job.rank_variable_bounds
    )

    jobs = SameEinsumJobs()
    only_output_pmapping_index = job.spec.mapper.ffm._only_output_pmapping_index
    for i, (mapping, constraints, symbol_table) in enumerate(mappings_constraints):
        if only_output_pmapping_index is not None and i != only_output_pmapping_index:
            continue
        new_job = copy.copy(job)
        new_job.mapping = mapping
        new_job.constraints = constraints
        new_job.job_id = uuid.uuid4()
        new_job.rank_variable_bounds = job.rank_variable_bounds
        new_job.stride_and_halo = stride_and_halo
        new_job.compatibility  # presumably forces evaluation of a lazy/cached property
        jobs.append(new_job)

    return jobs
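The per-template fan-out at the bottom is a shallow-copy-plus-fresh-id pattern: the spec and workload are shared across jobs while each job gets its own mapping and UUID. A minimal standalone sketch with a stand-in Job:

```python
import copy
import uuid
from dataclasses import dataclass

@dataclass
class Job:  # stand-in for the real pmapper Job
    einsum_name: str
    mapping: object = None
    job_id: uuid.UUID | None = None

base = Job(einsum_name="Q")
jobs = []
for template in ("template0", "template1"):
    new_job = copy.copy(base)  # shallow copy: heavyweight state stays shared
    new_job.mapping = template
    new_job.job_id = uuid.uuid4()
    jobs.append(new_job)

assert len({j.job_id for j in jobs}) == 2  # every job gets a distinct id
```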
accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_reservations.py

@@ -0,0 +1,95 @@

from collections.abc import Generator
from typing import Any

import accelforge.frontend.arch as arch
from accelforge.frontend.mapping import MappingNode, Reservation, Storage, TensorHolder
from accelforge.frontend.spec import Spec
def _recursive_iter_fence_positions(
    fence_positions: dict[str, int],
    max_size: int,
) -> Generator[dict[str, int], None, None]:
    if not fence_positions:
        yield {}
        return  # without this, next() below raises on the empty dict
    mine = next(iter(fence_positions))
    myval = fence_positions[mine]
    for i in range(myval, max_size):
        # Later fences may never sit above this one.
        following = {k: max(v, i) for k, v in fence_positions.items() if k != mine}
        for following in _recursive_iter_fence_positions(following, max_size):
            yield {mine: i, **following}
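With the early return in place, the recursion enumerates monotone fence assignments, where each later fence is clamped to sit at least as deep as every earlier one. A quick check with a hypothetical two-fanout input in a mapping of length 3:

```python
for fp in _recursive_iter_fence_positions({"PE": 0, "MAC": 1}, max_size=3):
    print(fp)
# {'PE': 0, 'MAC': 1}
# {'PE': 0, 'MAC': 2}
# {'PE': 1, 'MAC': 1}
# {'PE': 1, 'MAC': 2}
# {'PE': 2, 'MAC': 2}
```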
def get_reservation_choices(
    mapping: list[TensorHolder],
    flattened_arch: list[arch.Leaf],
) -> Generator[list[MappingNode], None, None]:
    # Rules:
    # - In general, reservations go right under their storage node
    # - If a storage node is associated with a fanout, explore putting the reservation
    #   below it, below the next storage node, and so on. Stop once we don't have any
    #   more spatial loops to place. Push down all reservations below this fanout
    #   together.

    # Spatial loops:
    # - Must go below all storage nodes associated with something above the fanout.
    #   -> Memories above fanout must serve all fetches across fanout instances.
    # - Must go above all reservations associated with something below the fanout.
    #   -> Memories below fanout must be reserved for each fanout instance.
    # - If below any storage node associated with the fanout, then must be relevant.
    #   -> No peer-to-peer communication

    # Temporal loops:
    # - If between a storage node and a reservation node, the outermost temporal loop
    #   may be partially relevant. All others must be relevant.

    # Design choices here:
    # - Where to put the 'fence' for each fanout

    fanout_nodes = [n for n in flattened_arch if n.get_fanout() > 1]
    fanout_node_names = set[str](n.name for n in fanout_nodes)
    last_seen_fanout = None
    node2lastfanout = {}

    fence_positions: dict[str, int] = {}
    for i, node in enumerate(mapping):
        if node.component in fanout_node_names:
            fence_positions.setdefault(node.component, i)
            last_seen_fanout = node.component
        node2lastfanout[id(node)] = last_seen_fanout

    def try_add_reservations(
        new_mapping: list[MappingNode],
        reservations_to_add: list[TensorHolder],
        fence_positions: dict[str, int],
    ):
        # Note: reads the enclosing loop variable `i`; a reservation is emitted
        # only once the walk has passed its fanout's fence position.
        for res in list(reservations_to_add):
            add = False
            if node2lastfanout[id(res)] is None:
                add = True
            elif i >= fence_positions[node2lastfanout[id(res)]]:
                add = True
            if add:
                new_mapping.append(
                    Reservation(
                        purposes=[res.component],
                        resource=res.component,
                        persistent=res.persistent,
                    )
                )
                reservations_to_add.remove(res)

    # Fence positions are indices of storage nodes below which we'll push all the
    # reservations below that fanout
    for fence_positions in _recursive_iter_fence_positions(
        fence_positions, len(mapping)
    ):
        new_mapping = []
        reservations_to_add = []
        for i, node in enumerate(mapping):
            new_mapping.append(node)
            reservations_to_add.append(node)
            try_add_reservations(new_mapping, reservations_to_add, fence_positions)
        # Flush anything still pending after the walk (uses the final value of i).
        try_add_reservations(new_mapping, reservations_to_add, fence_positions)
        yield new_mapping
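To see what a fence does, here is a standalone sketch of the deferral pattern with hypothetical holders: the reservation for a holder under a fanout is emitted only once the walk index passes that fanout's fence:

```python
mapping = ["GLB", "PE_buf", "MAC_reg"]  # tensor holders, top to bottom
node2lastfanout = {"GLB": None, "PE_buf": "PE", "MAC_reg": "PE"}
fence_positions = {"PE": 2}  # push PE reservations below index 2

out, pending = [], []
for i, node in enumerate(mapping):
    out.append(node)
    pending.append(node)
    for res in list(pending):
        fanout = node2lastfanout[res]
        if fanout is None or i >= fence_positions[fanout]:
            out.append(f"Reservation({res})")
            pending.remove(res)

print(out)
# ['GLB', 'Reservation(GLB)', 'PE_buf', 'MAC_reg',
#  'Reservation(PE_buf)', 'Reservation(MAC_reg)']
```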