accelforge 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- accelforge/__init__.py +21 -0
- accelforge/_accelerated_imports.py +16 -0
- accelforge/_deprecate/_simanneal/evalmapping.py +271 -0
- accelforge/_deprecate/_simanneal/mapspaceglobals.py +298 -0
- accelforge/_deprecate/_simanneal/simanneal.py +666 -0
- accelforge/_deprecate/_simanneal/tracking.py +105 -0
- accelforge/_deprecate/_simanneal/wrappers.py +218 -0
- accelforge/_deprecate/_simanneal2/__init__.py +7 -0
- accelforge/_deprecate/_simanneal2/simanneal.py +493 -0
- accelforge/_deprecate/_simanneal2/tracking.py +116 -0
- accelforge/_deprecate/compatibility_util.py +181 -0
- accelforge/_deprecate/layerdeduplication/__init__.py +2 -0
- accelforge/_deprecate/layerdeduplication/group_similar_einsums.py +160 -0
- accelforge/_deprecate/layerdeduplication/grouped_einsums.py +84 -0
- accelforge/_deprecate/mapping_filter_tags/__init__.py +2 -0
- accelforge/_deprecate/mapping_filter_tags/ffmt.py +212 -0
- accelforge/_deprecate/mapping_filter_tags/onesplit.py +24 -0
- accelforge/_deprecate/mapping_filter_tags/util.py +24 -0
- accelforge/_deprecate/tags.py +69 -0
- accelforge/_deprecate/viz/__init__.py +0 -0
- accelforge/_deprecate/viz/interactive.py +159 -0
- accelforge/_deprecate/viz/reservationtree.py +307 -0
- accelforge/_deprecate/viz/ski_slope.py +88 -0
- accelforge/_version.py +15 -0
- accelforge/examples.py +39 -0
- accelforge/frontend/__init__.py +10 -0
- accelforge/frontend/_binding.py +129 -0
- accelforge/frontend/_workload_isl/__init__.py +2 -0
- accelforge/frontend/_workload_isl/_isl.py +149 -0
- accelforge/frontend/_workload_isl/_symbolic.py +141 -0
- accelforge/frontend/arch copy.py +1544 -0
- accelforge/frontend/arch.py +1642 -0
- accelforge/frontend/config.py +63 -0
- accelforge/frontend/mapper/__init__.py +5 -0
- accelforge/frontend/mapper/ffm.py +126 -0
- accelforge/frontend/mapper/mapper.py +7 -0
- accelforge/frontend/mapper/metrics.py +30 -0
- accelforge/frontend/mapping/__init__.py +1 -0
- accelforge/frontend/mapping/mapping.py +1736 -0
- accelforge/frontend/model.py +14 -0
- accelforge/frontend/renames.py +150 -0
- accelforge/frontend/spec copy.py +230 -0
- accelforge/frontend/spec.py +301 -0
- accelforge/frontend/variables.py +12 -0
- accelforge/frontend/workload.py +952 -0
- accelforge/mapper/FFM/__init__.py +9 -0
- accelforge/mapper/FFM/_join_pmappings/__init__.py +0 -0
- accelforge/mapper/FFM/_join_pmappings/compatibility.py +653 -0
- accelforge/mapper/FFM/_join_pmappings/compress_pmappings.py +140 -0
- accelforge/mapper/FFM/_join_pmappings/join_pmappings.py +703 -0
- accelforge/mapper/FFM/_join_pmappings/pmapping_dataframe.py +901 -0
- accelforge/mapper/FFM/_join_pmappings/pmapping_group.py +337 -0
- accelforge/mapper/FFM/_make_pmappings/contraints/__init__.py +0 -0
- accelforge/mapper/FFM/_make_pmappings/contraints/constraints.py +360 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/__init__.py +1 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_loops.py +373 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_pmapping_templates.py +463 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_reservations.py +95 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_storage_order.py +382 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_storages.py +155 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmappings.py +411 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/__init__.py +1 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/make_pmappings_from_templates.py +407 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/make_tile_shapes.py +1681 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/run_model.py +170 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/symbol_relations.py +174 -0
- accelforge/mapper/FFM/_make_pmappings/pmapper_job.py +282 -0
- accelforge/mapper/FFM/_pareto_df/df_convention.py +273 -0
- accelforge/mapper/FFM/_pareto_df/pareto copy.py +836 -0
- accelforge/mapper/FFM/_pareto_df/pareto.py +508 -0
- accelforge/mapper/FFM/data.py +61 -0
- accelforge/mapper/FFM/main copy.py +236 -0
- accelforge/mapper/FFM/main.py +208 -0
- accelforge/mapper/FFM/mappings.py +510 -0
- accelforge/mapper/FFM/pmappings.py +310 -0
- accelforge/mapper/__init__.py +4 -0
- accelforge/mapper.py +0 -0
- accelforge/model/__init__.py +1 -0
- accelforge/model/_looptree/__init__.py +0 -0
- accelforge/model/_looptree/accesses.py +335 -0
- accelforge/model/_looptree/capacity/__init__.py +1 -0
- accelforge/model/_looptree/capacity/aggregators.py +36 -0
- accelforge/model/_looptree/capacity/capacity.py +47 -0
- accelforge/model/_looptree/energy.py +150 -0
- accelforge/model/_looptree/equivalent_ranks.py +29 -0
- accelforge/model/_looptree/latency/__init__.py +1 -0
- accelforge/model/_looptree/latency/latency.py +98 -0
- accelforge/model/_looptree/latency/memory.py +120 -0
- accelforge/model/_looptree/latency/processors.py +92 -0
- accelforge/model/_looptree/mapping_utilities.py +71 -0
- accelforge/model/_looptree/reuse/__init__.py +4 -0
- accelforge/model/_looptree/reuse/isl/__init__.py +1 -0
- accelforge/model/_looptree/reuse/isl/des.py +59 -0
- accelforge/model/_looptree/reuse/isl/isl_functions.py +374 -0
- accelforge/model/_looptree/reuse/isl/mapping_to_isl/__init__.py +4 -0
- accelforge/model/_looptree/reuse/isl/mapping_to_isl/analyze_mapping.py +297 -0
- accelforge/model/_looptree/reuse/isl/mapping_to_isl/skews_from_mapping.py +236 -0
- accelforge/model/_looptree/reuse/isl/mapping_to_isl/tiling.py +685 -0
- accelforge/model/_looptree/reuse/isl/mapping_to_isl/types.py +188 -0
- accelforge/model/_looptree/reuse/isl/spatial.py +260 -0
- accelforge/model/_looptree/reuse/isl/temporal.py +182 -0
- accelforge/model/_looptree/reuse/symbolic/__init__.py +1 -0
- accelforge/model/_looptree/reuse/symbolic/symbolic copy 2.py +1346 -0
- accelforge/model/_looptree/reuse/symbolic/symbolic copy.py +1408 -0
- accelforge/model/_looptree/reuse/symbolic/symbolic.py +1396 -0
- accelforge/model/_looptree/run.py +122 -0
- accelforge/model/_looptree/types.py +26 -0
- accelforge/model/_looptree/visualization/__init__.py +0 -0
- accelforge/model/_looptree/visualization/occupancy.py +11 -0
- accelforge/model/main.py +222 -0
- accelforge/plotting/__init__.py +2 -0
- accelforge/plotting/mappings.py +219 -0
- accelforge/plotting/specs.py +57 -0
- accelforge/util/__init__.py +4 -0
- accelforge/util/_base_analysis_types.py +24 -0
- accelforge/util/_basetypes.py +1089 -0
- accelforge/util/_frozenset.py +36 -0
- accelforge/util/_isl.py +29 -0
- accelforge/util/_itertools.py +14 -0
- accelforge/util/_mathfuncs.py +57 -0
- accelforge/util/_parse_expressions.py +339 -0
- accelforge/util/_picklecache.py +32 -0
- accelforge/util/_setexpressions.py +268 -0
- accelforge/util/_sympy/__init__.py +0 -0
- accelforge/util/_sympy/broadcast_max.py +18 -0
- accelforge/util/_visualization.py +112 -0
- accelforge/util/_yaml.py +579 -0
- accelforge/util/parallel.py +193 -0
- accelforge-0.0.1.dist-info/METADATA +64 -0
- accelforge-0.0.1.dist-info/RECORD +258 -0
- accelforge-0.0.1.dist-info/WHEEL +5 -0
- accelforge-0.0.1.dist-info/licenses/LICENSE +19 -0
- accelforge-0.0.1.dist-info/top_level.txt +5 -0
- docs/_build/html/_sources/fastfusion.frontend.mapper.rst.txt +37 -0
- docs/_build/html/_sources/fastfusion.frontend.rst.txt +70 -0
- docs/_build/html/_sources/fastfusion.frontend.workload.rst.txt +21 -0
- docs/_build/html/_sources/fastfusion.mapper.FFM.rst.txt +37 -0
- docs/_build/html/_sources/fastfusion.mapper.rst.txt +18 -0
- docs/_build/html/_sources/fastfusion.rst.txt +20 -0
- docs/_build/html/_sources/fastfusion.util.rst.txt +21 -0
- docs/_build/html/_sources/index.rst.txt +87 -0
- docs/_build/html/_sources/modules.rst.txt +7 -0
- docs/_build/html/_sources/notes/citation.rst.txt +45 -0
- docs/_build/html/_sources/notes/definitions.rst.txt +43 -0
- docs/_build/html/_sources/notes/faqs.rst.txt +39 -0
- docs/_build/html/_sources/notes/modeling/accelerator_energy_latency.rst.txt +72 -0
- docs/_build/html/_sources/notes/modeling/component_energy_area.rst.txt +96 -0
- docs/_build/html/_sources/notes/modeling/mapping.rst.txt +100 -0
- docs/_build/html/_sources/notes/modeling.rst.txt +33 -0
- docs/_build/html/_sources/notes/parsing/arithmetic_parsing.rst.txt +136 -0
- docs/_build/html/_sources/notes/parsing/setexpressions.rst.txt +63 -0
- docs/_build/html/_sources/notes/parsing/yaml_parsing.rst.txt +176 -0
- docs/_build/html/_sources/notes/quickstart_and_installation.rst.txt +9 -0
- docs/_build/html/_sources/notes/spec/architecture.rst.txt +133 -0
- docs/_build/html/_sources/notes/spec/mapping.rst.txt +12 -0
- docs/_build/html/_sources/notes/spec/workload.rst.txt +83 -0
- docs/_build/html/_sources/notes/spec.rst.txt +36 -0
- docs/source/_ext/include_attrs.py +213 -0
- docs/source/_ext/include_docstring.py +364 -0
- docs/source/_ext/include_functions.py +154 -0
- docs/source/_ext/include_notebook.py +131 -0
- docs/source/_ext/include_yaml.py +119 -0
- docs/source/_ext/inherited_attributes.py +222 -0
- docs/source/_ext/paths.py +4 -0
- docs/source/conf.py +79 -0
- examples/arches/compute_in_memory/_include.yaml +74 -0
- examples/arches/compute_in_memory/_include_functions.py +229 -0
- examples/arches/compute_in_memory/_load_spec.py +57 -0
- examples/arches/compute_in_memory/components/c2c_multiplier.py +181 -0
- examples/arches/compute_in_memory/components/dac_c2c_r2r.py +605 -0
- examples/arches/compute_in_memory/components/misc.py +195 -0
- examples/arches/compute_in_memory/components/util/bit_functions.py +51 -0
- examples/arches/compute_in_memory/components/zero_comparator.py +92 -0
- examples/arches/compute_in_memory/isaac.yaml +233 -0
- examples/arches/compute_in_memory/memory_cells/ecram_demo.yaml +63 -0
- examples/arches/compute_in_memory/memory_cells/rram_example.yaml +63 -0
- examples/arches/compute_in_memory/memory_cells/rram_isaac_isca_2016.yaml +64 -0
- examples/arches/compute_in_memory/memory_cells/rram_neurosim_default.yaml +63 -0
- examples/arches/compute_in_memory/memory_cells/rram_raella_isca_2023.yaml +70 -0
- examples/arches/compute_in_memory/memory_cells/rram_wan_nature_2022.yaml +63 -0
- examples/arches/compute_in_memory/memory_cells/sram_colonnade_jssc_2021.yaml +63 -0
- examples/arches/compute_in_memory/memory_cells/sram_example.yaml +63 -0
- examples/arches/compute_in_memory/memory_cells/sram_jia_jssc_2020.yaml +63 -0
- examples/arches/compute_in_memory/memory_cells/sram_sinangil_jssc_2021.yaml +63 -0
- examples/arches/compute_in_memory/memory_cells/sram_wang_vlsi_2022.yaml +63 -0
- examples/arches/compute_in_memory/wang_vlsi_2022.yaml +289 -0
- examples/arches/eyeriss.yaml +68 -0
- examples/arches/fanout_variations/at_glb.yaml +31 -0
- examples/arches/fanout_variations/at_glb_with_fanout_node.yaml +34 -0
- examples/arches/fanout_variations/at_mac.yaml +31 -0
- examples/arches/fanout_variations/at_mac_with_constraints.yaml +38 -0
- examples/arches/fanout_variations/at_mac_with_fanout_node.yaml +34 -0
- examples/arches/nvdla.yaml +47 -0
- examples/arches/simple.yaml +28 -0
- examples/arches/tpu_v4i.yaml +67 -0
- examples/mappings/unfused_matmuls_to_simple.yaml +33 -0
- examples/misc/component_annotated.yaml +33 -0
- examples/workloads/gpt3_6.7B.yaml +124 -0
- examples/workloads/matmuls.yaml +20 -0
- examples/workloads/mobilenet_28.yaml +81 -0
- examples/workloads/mobilenet_various_separate.yaml +106 -0
- examples/workloads/three_matmuls_annotated.yaml +59 -0
- notebooks/.ipynb_checkpoints/fastfusion_arch_study_michael-checkpoint.ipynb +359 -0
- notebooks/compute_in_memory/_scripts.py +339 -0
- notebooks/compute_in_memory/isaac.guide.ipynb +270 -0
- notebooks/compute_in_memory/wang_vlsi_2022.ipynb +602 -0
- notebooks/paths.py +4 -0
- notebooks/tutorials/.ipynb_checkpoints/1_FFM-checkpoint.ipynb +3110 -0
- notebooks/tutorials/FFM.ipynb +3498 -0
- notebooks/tutorials/_include.py +48 -0
- notebooks/tutorials/component_energy_area.ipynb +363 -0
- tests/Q_mapping.yaml +38 -0
- tests/__init__.py +0 -0
- tests/conv.mapping.yaml +27 -0
- tests/conv.workload.yaml +13 -0
- tests/conv_sym.mapping.yaml +43 -0
- tests/copy.mapping.yaml +35 -0
- tests/copy.workload.yaml +15 -0
- tests/distribuffers/__init__.py +0 -0
- tests/distribuffers/multicast/test_cases.yaml +482 -0
- tests/distribuffers/spec/binding/valid_bindings.yaml +97 -0
- tests/distribuffers/spec/distributed.yaml +100 -0
- tests/distribuffers/spec/logical_arch.yaml +32 -0
- tests/distribuffers/spec/physical_arch.yaml +69 -0
- tests/distribuffers/test_binding.py +48 -0
- tests/frontend/__init__.py +0 -0
- tests/frontend/test_mapping_viz.py +52 -0
- tests/mapper/__init__.py +0 -0
- tests/mapper/configs/conv1d/conv1d.mapping.yaml +31 -0
- tests/mapper/configs/conv1d/conv1d.workload.yaml +11 -0
- tests/mapper/configs/two_conv1d/two_conv1d.expected.yaml +38 -0
- tests/mapper/configs/two_conv1d/two_conv1d.mapping.yaml +54 -0
- tests/mapper/configs/two_conv1d/two_conv1d.workload.yaml +19 -0
- tests/mapper/test_mapping_to_isl.py +90 -0
- tests/mapper/test_spatial_reuse_analysis.py +67 -0
- tests/mapper/test_temporal_reuse_analysis.py +56 -0
- tests/mapper/util.py +58 -0
- tests/matmul.mapping.yaml +29 -0
- tests/matmul.workload.yaml +12 -0
- tests/matmul_spatial.mapping.yaml +44 -0
- tests/mha.renames.yaml +65 -0
- tests/mha.workload.yaml +67 -0
- tests/mha.yaml +59 -0
- tests/mha_full.workload.yaml +67 -0
- tests/mobilenet.workload.yaml +35 -0
- tests/mobilenet_long.workload.yaml +64 -0
- tests/pmappingcache.py +24 -0
- tests/processing_stage.arch.yaml +40 -0
- tests/snowcat.arch.yaml +36 -0
- tests/test_ffm_join_pmappings.py +106 -0
- tests/test_ffm_make_pmappings.py +82 -0
- tests/test_ffm_make_tile_shapes.py +49 -0
- tests/test_mapper.py +100 -0
- tests/test_model.py +37 -0
- tests/test_plotting.py +72 -0
- tests/test_processing_stage.py +46 -0
- tests/test_symbolic_model.py +248 -0
- tests/test_workload.py +141 -0
|
@@ -0,0 +1,373 @@
|
|
|
1
|
+
from accelforge.frontend.renames import TensorName
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
import itertools
|
|
5
|
+
from enum import Enum
|
|
6
|
+
|
|
7
|
+
import accelforge.frontend.arch as arch
|
|
8
|
+
from accelforge.frontend.mapping import (
|
|
9
|
+
MappingNode,
|
|
10
|
+
ProcessingStage,
|
|
11
|
+
Temporal,
|
|
12
|
+
Spatial,
|
|
13
|
+
TensorHolder,
|
|
14
|
+
)
|
|
15
|
+
from accelforge.frontend.workload import (
|
|
16
|
+
Einsum,
|
|
17
|
+
RankVariable,
|
|
18
|
+
Workload,
|
|
19
|
+
)
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
# =================================================================================================
|
|
23
|
+
# Insert loops
|
|
24
|
+
# =================================================================================================
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class LowerChoice(Enum):
    """Three-valued choice: YES, NO, or OPTIONAL.

    NOTE(review): this enum is not referenced in the visible part of this module;
    the names suggest it encodes whether a TensorHolder must be lowered (YES),
    must not be lowered (NO), or may be either (OPTIONAL) — confirm with callers.
    """

    # Explicit integer values (not ``auto()``) so the numbering starts at 0.
    YES = 0
    NO = 1
    OPTIONAL = 2
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def insert_temporal_loops(
    mapping: list[TensorHolder],
    einsum: Einsum,
    first_memory: arch.Memory,
    rank_variable_bounds: dict[RankVariable, int],
    ranks_with_tile_pattern: set,
    workload: Workload,
    _can_lower_outermost_memory: bool,
    flattened_arch: list[arch.Leaf],
    max_fused_loops: int,
):
    """Enumerate temporal-loop insertions and lowering choices for a mapping.

    The TensorHolder nodes of ``mapping`` are grouped so that every node of
    ``first_memory`` is in the first group and every later group holds exactly one
    node. After each group, a canonical loop order (see ``canonical_loop_orders``)
    of Temporal nodes may be inserted. Each TensorHolder also gets a tuple of
    allowed ``_lower`` values. Every combination is yielded.

    Args:
        mapping: TensorHolder nodes, outermost first. The node objects are reused
            (not copied) in every yielded mapping.
        einsum: The Einsum being mapped; supplies rank variables and the
            tensor -> (directly indexing / expression indexing / irrelevant) rank
            variable tables.
        first_memory: The outermost memory; its TensorHolders must come first.
        rank_variable_bounds: Not used by this function.
        ranks_with_tile_pattern: Not used by this function.
        workload: Used to find tensors that appear in multiple Einsums (fusable).
        _can_lower_outermost_memory: Whether backing TensorHolders of
            ``first_memory`` may be lowered.
        flattened_arch: Architecture leaves in hierarchy order (presumably
            outermost first); used to compute the cumulative fanout at each
            component.
        max_fused_loops: When 0, no temporal loops are inserted while some fusable
            tensor has not yet been seen.

    Yields:
        ``(full_mapping, n_loop_orders)``. ``full_mapping`` is a new list of the
        TensorHolder nodes interleaved with new Temporal nodes.
        ``n_loop_orders`` is the number of loop-order combinations, excluding
        lowering combinations. NOTE: ``_lower`` is assigned in place on the
        shared TensorHolder nodes just before each yield, so each yielded mapping
        must be consumed before the generator is advanced.

    Raises:
        ValueError: (when the generator is first advanced) if ``first_memory``'s
            nodes are not all at the top of the mapping, or a Spatial node is
            found in the first group.
    """
    # First establish insertion points. Insertion points are:
    # - Below the last instance of the first memory
    # - Between any two TensorHolder nodes
    # - After the last TensorHolder node
    #
    # The grouping below keeps all the TensorHolder nodes of the outermost memory
    # together at the beginning of the split mapping. After that, each group in
    # the split mapping has a single TensorHolder.
    split_mapping: list[list[TensorHolder]] = [[]]
    for m in mapping:
        if m.component == first_memory.name:
            split_mapping[-1].append(m)
        else:
            split_mapping.append([m])

    for i, group in enumerate(split_mapping):
        for m in group:
            if i == 0 and m.component != first_memory.name:
                raise ValueError(
                    "The first TensorHolder in the mapping is not for the outermost "
                    "memory. This isn't known to be invalid, but the code may not "
                    "handle it."
                )
            elif i > 0 and m.component == first_memory.name:
                raise ValueError(
                    "First memory isn't at the top of the hierarchy. This isn't known "
                    "to be invalid, but the code may not handle it."
                )
            elif i == 0 and isinstance(m, Spatial):
                raise ValueError(
                    "Found Spatial node before any TensorHolder. This isn't known to "
                    "be invalid, but the code may not handle it."
                )

    split_mapping = [group for group in split_mapping if group]

    # These Einsum properties are recalculated since Einsum is mutable.
    # We're pre-computing and reusing them for efficiency. Because they are reused
    # across groups, they must never be modified in place below.
    tensor2fully_relevant_rank_vars = einsum.tensor2directly_indexing_rank_variables
    tensor2partially_relevant_rank_vars = (
        einsum.tensor2expression_indexing_rank_variables
    )
    tensor2irrelevant_rank_vars = einsum.tensor2irrelevant_rank_variables
    tensors = einsum.tensor_names

    fusable_tensors = tensors & workload.tensor_names_used_in_multiple_einsums

    # Cumulative fanout from the top of the architecture down to each component.
    fanouts = {}
    fanout = 1
    for node in flattened_arch:
        fanout *= node.get_fanout()
        fanouts[node.name] = fanout

    def _get_next_storages(i: int, pstage_allowed: bool = False) -> list[TensorHolder]:
        """Return the first group after index ``i``, or ``[]`` if there is none.

        ProcessingStage groups are skipped unless ``pstage_allowed`` is true.
        """
        for group in split_mapping[i + 1 :]:
            assert len(group) <= 1
            # We don't add loops before processing stages
            if isinstance(group[0], ProcessingStage) and not pstage_allowed:
                continue
            return group
        return []

    is_fused_loops = True
    seen_tensors = set()
    choices = []
    lowering_choices: list[tuple[bool, ...]] = []

    for i, prev_storages in enumerate(split_mapping):
        # =============================================================================
        # Choose what temporal loops to insert between prev_storages and the next
        # TensorHolder node(s).
        # =============================================================================

        next_storages = _get_next_storages(i)
        next_anything = _get_next_storages(i, pstage_allowed=True)

        for s in prev_storages:
            # No tensor holder may mix backing and non-backing tensors.
            assert not s._backing or all(t in s._backing for t in s.tensors)
            # One tensor per holder
            assert len(s.tensors) == 1

        # Copy: the in-place set updates below must not touch the Einsum's set.
        rank_variables = set(einsum.rank_variables)
        prev_tensors = set().union(*(s.tensors for s in prev_storages))
        seen_tensors |= prev_tensors
        is_fused_loops = is_fused_loops and bool(fusable_tensors - seen_tensors)
        next_persistent = set().union(
            *(s.tensors for s in next_storages if s.persistent)
        )

        # NOTE(review): with no earlier groups this defaults to inf, so the
        # spatial-placement condition below reduces to next_fanout > cur_fanout for
        # the first group — confirm this is intended.
        max_fanout_before = max(
            (fanouts[s2.component] for group in split_mapping[:i] for s2 in group),
            default=float("inf"),
        )
        cur_fanouts = {fanouts[s2.component] for s2 in prev_storages}
        next_fanouts = {fanouts[s2.component] for s2 in next_anything} or {
            float("inf")
        }
        # Either it's main memory or we have one entry in the list, so there should
        # only be one
        assert len(cur_fanouts) == 1
        assert len(next_fanouts) == 1
        cur_fanout = next(iter(cur_fanouts))
        next_fanout = next(iter(next_fanouts))

        # Can't have loops above persistent tensor holders
        if next_persistent:
            rank_variables.clear()

        # No recomputation: If we haven't seen a tensor yet, must only iterate over
        # fully-relevant rank variables.
        for t in tensors - seen_tensors:
            rank_variables &= tensor2fully_relevant_rank_vars[t]

        if max_fused_loops == 0 and (fusable_tensors - seen_tensors):
            rank_variables.clear()

        # The fanout for a prior node may be placed here, so spatial nodes may be
        # moved here
        someone_elses_spatials_may_be_placed_below = (
            next_fanout > cur_fanout and max_fanout_before > cur_fanout
        )

        # If the fanout is about to increase, then spatial loops may be placed below
        # the current node. There may have been constrained temporal loops earlier
        # that need to be placed here, so we won't prohibit any loops. Otherwise,
        # apply the optimality-preserving pruning below.
        if not someone_elses_spatials_may_be_placed_below:
            # Optimality-preserving optimization: Loops below processing stages
            # aren't helpful because there is no storage. Ctrl-F for
            # CONTIGUOUS_ITERATION_SPACE_DISCUSSION: Can't do this if we may put
            # another node's spatial loops below this one, because lowering would
            # move the spatials down, which would constrain the temporals due to
            # spatial-temporal crossing.
            if isinstance(prev_storages[0], ProcessingStage):
                rank_variables.clear()

            # Generally we want to only use rank variables that are irrelevant to
            # the previous tensors, else we'd just lower those tensors. However, we
            # can't lower backing TensorHolder nodes because this will add loops to
            # compatibility.

            # Optimality-preserving optimization: We can trivially lower non-backing
            # TensorHolder nodes through fully-relevant loops. Can't do this if the
            # loops are fused because that'd add loops to the compatibility. Ctrl-F
            # for CONTIGUOUS_ITERATION_SPACE_DISCUSSION: Can't do this if we may put
            # another node's spatial loops below this one, because lowering would
            # move the spatials down, which would constrain the temporals due to
            # spatial-temporal crossing.
            for s in prev_storages:
                for t in s.tensors:
                    if t not in s._backing and not s._must_be_here:
                        rank_variables -= tensor2fully_relevant_rank_vars[t]

            # Optimality-preserving optimization: We can trivially raise
            # TensorHolder nodes through irrelevant unfused loops. Can't do this if
            # the loops are fused because that'd increase the lifetime of the
            # TensorHolder node. Can't do this if the irrelevant rank variables are
            # partially-relevant to the previous tensors, since that affects the
            # permutation. See CONTIGUOUS_ITERATION_SPACE_DISCUSSION: Can't do this
            # if we may put another node's spatial loops above this one, because
            # raising would move the temporals down, which would constrain them due
            # to spatial-temporal crossing. TODO: CONTIGUOUS_ITERATION_SPACE_DISCUSSION:
            # This causes all loops to be added, but really we only need to re-add
            # the ones that may conflict with a spatial loop.
            if not is_fused_loops:
                for s in next_storages:
                    if s._must_be_here:
                        continue
                    for t in s.tensors:
                        # Copy: the shared per-call table must not be mutated, or
                        # later groups would see a shrunken irrelevant set.
                        rvs = set(tensor2irrelevant_rank_vars[t])
                        for t2 in prev_tensors:
                            rvs -= tensor2partially_relevant_rank_vars[t2]
                        rank_variables -= rvs

        # =============================================================================
        # Determine whether to lower TensorHolder nodes through partially-relevant
        # loops.
        # =============================================================================
        permutable_partially_relevant = set()
        group_lowering_choices: list[tuple[bool, ...]] = []

        # NOTE: If the lowering logic for backing TensorHolders is updated & we can
        # lower through >1 loops, then also update label_fused_loops
        for s in prev_storages:
            partially_relevant_to_previous = rank_variables & set().union(
                *(tensor2partially_relevant_rank_vars[t] for t in s.tensors)
            )
            lowerable_backing = (
                _can_lower_outermost_memory or s.component != first_memory.name
            )

            if s.persistent:
                # Persistent. Must be at the top of the mapping.
                choice = (False,)
            elif someone_elses_spatials_may_be_placed_below:
                # Don't lower our own reservations through someone else's spatials.
                choice = (False,)
            elif isinstance(s, ProcessingStage):
                # Processing stage. Lowering doesn't matter. Don't lower.
                choice = (False,)
            elif s._backing and lowerable_backing and partially_relevant_to_previous:
                # Backing with partially-relevant rank variables. May want to lower
                # to reduce memory footprint, or raise to reduce number of fused
                # loops.
                choice = (False, True)
                permutable_partially_relevant |= partially_relevant_to_previous
            elif not s._backing:
                # Not backing. No cost to lowering. Lower all.
                choice = (True,)
                permutable_partially_relevant |= partially_relevant_to_previous
            else:
                # Backing but not lowerable, or no partially-relevant rank vars.
                choice = (False,)
            group_lowering_choices.append(choice)

        lowering_choices.extend(group_lowering_choices)

        # =============================================================================
        # Create loop order and lowering choices
        # =============================================================================

        # Only this group's choices matter: permutable_partially_relevant is
        # non-empty only if one of them allows lowering.
        can_lower = any(any(c) for c in group_lowering_choices)

        # Create canonical loop orders that avoid repeating reuse patterns.
        choices.append(
            list(
                canonical_loop_orders(
                    rank_variables, permutable_partially_relevant, can_lower
                )
            )
        )

    # ==================================================================================
    # Iterate over all possible mappings
    # ==================================================================================

    # TODO: Optimization: If we can optionally lower a tensor & the loop below it is
    # not something through which we can lower for a given permutation, skip options
    # that lower that tensor because they get the same result as not lowering the
    # tensor.

    # Size of itertools.product(*choices), computed without materializing it.
    n_loop_orders = 1
    for loop_order_options in choices:
        n_loop_orders *= len(loop_order_options)

    # The TensorHolders are the same (and in the same order) for every loop order,
    # since only Temporal nodes are interleaved between the groups.
    storages = [
        node
        for group in split_mapping
        for node in group
        if isinstance(node, TensorHolder)
    ]
    assert len(lowering_choices) == len(storages)

    for loop_orders in itertools.product(*choices):
        full_mapping = []
        for prev_storages, loop_order in zip(split_mapping, loop_orders):
            full_mapping.extend(prev_storages)
            full_mapping.extend(Temporal(rank_variable=r) for r in loop_order)

        for lowering_choice in itertools.product(*lowering_choices):
            # Mutates shared nodes; see the docstring's NOTE.
            for lower, node in zip(lowering_choice, storages):
                node._lower = lower

            yield list(full_mapping), n_loop_orders
|
|
310
|
+
|
|
311
|
+
|
|
312
|
+
def insert_spatial_loops(
    mapping: list[MappingNode],
    einsum: Einsum,
    flattened_arch: list[arch.Memory],
):
    """Insert Spatial nodes into ``mapping`` in place.

    For every architecture node whose ``get_fanout()`` exceeds 1, one Spatial node
    is created per (spatial dimension, Einsum rank variable) pair. Each node is
    inserted at the index returned by
    ``_idx_of_highest_tensor_holder_with_component_below_fanout``, i.e. just above
    the first TensorHolder whose component is at or after the fanout node in
    ``flattened_arch``, or at the end of the mapping if there is none.

    Every Spatial node for a given architecture node is inserted at the same index,
    so later-created nodes end up above earlier-created ones.
    """
    fanout_leaves = [leaf for leaf in flattened_arch if leaf.get_fanout() > 1]
    names_in_arch_order = [leaf.name for leaf in flattened_arch]

    for leaf in fanout_leaves:
        position = _idx_of_highest_tensor_holder_with_component_below_fanout(
            leaf, mapping, names_in_arch_order
        )

        einsum_rank_vars = einsum.rank_variables
        for dim in leaf.spatial:
            for rank_var in einsum_rank_vars:
                new_spatial = Spatial(
                    rank_variable=rank_var,
                    name=dim.name,
                    component_object=leaf,
                    component=leaf.name,
                )
                # list.insert at index len(mapping) appends, so one call covers
                # both the "end of mapping" and "middle of mapping" cases.
                mapping.insert(position, new_spatial)
|
|
338
|
+
|
|
339
|
+
|
|
340
|
+
def _idx_of_highest_tensor_holder_with_component_below_fanout(
|
|
341
|
+
fanout_node, mapping, arch_node_names
|
|
342
|
+
):
|
|
343
|
+
for i in range(len(mapping)):
|
|
344
|
+
if not isinstance(mapping[i], TensorHolder):
|
|
345
|
+
continue
|
|
346
|
+
if arch_node_names.index(mapping[i].component) >= arch_node_names.index(
|
|
347
|
+
fanout_node.name
|
|
348
|
+
):
|
|
349
|
+
return i
|
|
350
|
+
return len(mapping)
|
|
351
|
+
|
|
352
|
+
|
|
353
|
+
def canonical_loop_orders(
    rank_variables: set[RankVariable],
    partially_relevant_to_previous: set[RankVariable],
    can_lower: bool,
):
    """Generate loop orders that result in unique reuse patterns.

    Lowering happens through at most one rank variable, so among the
    partially-relevant rank variables only the choice of the first (outermost)
    one is meaningful. Every other variable is placed in sorted (canonical) order.

    Args:
        rank_variables: Rank variables to order; each yielded tuple contains each
            of them exactly once.
        partially_relevant_to_previous: Candidate first rank variables. Only those
            also present in ``rank_variables`` are considered.
        can_lower: Whether any preceding TensorHolder may be lowered at all.

    Yields:
        Tuples of rank variables. If there are no candidates or ``can_lower`` is
        false, a single sorted tuple is yielded. Otherwise one tuple is yielded per
        candidate, in sorted candidate order, so the output does not depend on
        set iteration (hash) order.
    """
    candidates = set(partially_relevant_to_previous) & set(rank_variables)
    if not candidates or not can_lower:
        yield tuple(sorted(rank_variables))
        return

    # Loop-invariant: the variables that are never a candidate for first place.
    rest_rank_vars = tuple(sorted(set(rank_variables) - candidates))
    for first_rank_var in sorted(candidates):
        # Since order does not matter beyond the first loop, we choose sorted
        # (alphabetical) order as canonical.
        rest_of_candidates = tuple(sorted(candidates - {first_rank_var}))
        yield (first_rank_var,) + rest_of_candidates + rest_rank_vars
|