accelforge 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- accelforge/__init__.py +21 -0
- accelforge/_accelerated_imports.py +16 -0
- accelforge/_deprecate/_simanneal/evalmapping.py +271 -0
- accelforge/_deprecate/_simanneal/mapspaceglobals.py +298 -0
- accelforge/_deprecate/_simanneal/simanneal.py +666 -0
- accelforge/_deprecate/_simanneal/tracking.py +105 -0
- accelforge/_deprecate/_simanneal/wrappers.py +218 -0
- accelforge/_deprecate/_simanneal2/__init__.py +7 -0
- accelforge/_deprecate/_simanneal2/simanneal.py +493 -0
- accelforge/_deprecate/_simanneal2/tracking.py +116 -0
- accelforge/_deprecate/compatibility_util.py +181 -0
- accelforge/_deprecate/layerdeduplication/__init__.py +2 -0
- accelforge/_deprecate/layerdeduplication/group_similar_einsums.py +160 -0
- accelforge/_deprecate/layerdeduplication/grouped_einsums.py +84 -0
- accelforge/_deprecate/mapping_filter_tags/__init__.py +2 -0
- accelforge/_deprecate/mapping_filter_tags/ffmt.py +212 -0
- accelforge/_deprecate/mapping_filter_tags/onesplit.py +24 -0
- accelforge/_deprecate/mapping_filter_tags/util.py +24 -0
- accelforge/_deprecate/tags.py +69 -0
- accelforge/_deprecate/viz/__init__.py +0 -0
- accelforge/_deprecate/viz/interactive.py +159 -0
- accelforge/_deprecate/viz/reservationtree.py +307 -0
- accelforge/_deprecate/viz/ski_slope.py +88 -0
- accelforge/_version.py +15 -0
- accelforge/examples.py +39 -0
- accelforge/frontend/__init__.py +10 -0
- accelforge/frontend/_binding.py +129 -0
- accelforge/frontend/_workload_isl/__init__.py +2 -0
- accelforge/frontend/_workload_isl/_isl.py +149 -0
- accelforge/frontend/_workload_isl/_symbolic.py +141 -0
- accelforge/frontend/arch copy.py +1544 -0
- accelforge/frontend/arch.py +1642 -0
- accelforge/frontend/config.py +63 -0
- accelforge/frontend/mapper/__init__.py +5 -0
- accelforge/frontend/mapper/ffm.py +126 -0
- accelforge/frontend/mapper/mapper.py +7 -0
- accelforge/frontend/mapper/metrics.py +30 -0
- accelforge/frontend/mapping/__init__.py +1 -0
- accelforge/frontend/mapping/mapping.py +1736 -0
- accelforge/frontend/model.py +14 -0
- accelforge/frontend/renames.py +150 -0
- accelforge/frontend/spec copy.py +230 -0
- accelforge/frontend/spec.py +301 -0
- accelforge/frontend/variables.py +12 -0
- accelforge/frontend/workload.py +952 -0
- accelforge/mapper/FFM/__init__.py +9 -0
- accelforge/mapper/FFM/_join_pmappings/__init__.py +0 -0
- accelforge/mapper/FFM/_join_pmappings/compatibility.py +653 -0
- accelforge/mapper/FFM/_join_pmappings/compress_pmappings.py +140 -0
- accelforge/mapper/FFM/_join_pmappings/join_pmappings.py +703 -0
- accelforge/mapper/FFM/_join_pmappings/pmapping_dataframe.py +901 -0
- accelforge/mapper/FFM/_join_pmappings/pmapping_group.py +337 -0
- accelforge/mapper/FFM/_make_pmappings/contraints/__init__.py +0 -0
- accelforge/mapper/FFM/_make_pmappings/contraints/constraints.py +360 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/__init__.py +1 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_loops.py +373 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_pmapping_templates.py +463 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_reservations.py +95 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_storage_order.py +382 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_storages.py +155 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmappings.py +411 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/__init__.py +1 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/make_pmappings_from_templates.py +407 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/make_tile_shapes.py +1681 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/run_model.py +170 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/symbol_relations.py +174 -0
- accelforge/mapper/FFM/_make_pmappings/pmapper_job.py +282 -0
- accelforge/mapper/FFM/_pareto_df/df_convention.py +273 -0
- accelforge/mapper/FFM/_pareto_df/pareto copy.py +836 -0
- accelforge/mapper/FFM/_pareto_df/pareto.py +508 -0
- accelforge/mapper/FFM/data.py +61 -0
- accelforge/mapper/FFM/main copy.py +236 -0
- accelforge/mapper/FFM/main.py +208 -0
- accelforge/mapper/FFM/mappings.py +510 -0
- accelforge/mapper/FFM/pmappings.py +310 -0
- accelforge/mapper/__init__.py +4 -0
- accelforge/mapper.py +0 -0
- accelforge/model/__init__.py +1 -0
- accelforge/model/_looptree/__init__.py +0 -0
- accelforge/model/_looptree/accesses.py +335 -0
- accelforge/model/_looptree/capacity/__init__.py +1 -0
- accelforge/model/_looptree/capacity/aggregators.py +36 -0
- accelforge/model/_looptree/capacity/capacity.py +47 -0
- accelforge/model/_looptree/energy.py +150 -0
- accelforge/model/_looptree/equivalent_ranks.py +29 -0
- accelforge/model/_looptree/latency/__init__.py +1 -0
- accelforge/model/_looptree/latency/latency.py +98 -0
- accelforge/model/_looptree/latency/memory.py +120 -0
- accelforge/model/_looptree/latency/processors.py +92 -0
- accelforge/model/_looptree/mapping_utilities.py +71 -0
- accelforge/model/_looptree/reuse/__init__.py +4 -0
- accelforge/model/_looptree/reuse/isl/__init__.py +1 -0
- accelforge/model/_looptree/reuse/isl/des.py +59 -0
- accelforge/model/_looptree/reuse/isl/isl_functions.py +374 -0
- accelforge/model/_looptree/reuse/isl/mapping_to_isl/__init__.py +4 -0
- accelforge/model/_looptree/reuse/isl/mapping_to_isl/analyze_mapping.py +297 -0
- accelforge/model/_looptree/reuse/isl/mapping_to_isl/skews_from_mapping.py +236 -0
- accelforge/model/_looptree/reuse/isl/mapping_to_isl/tiling.py +685 -0
- accelforge/model/_looptree/reuse/isl/mapping_to_isl/types.py +188 -0
- accelforge/model/_looptree/reuse/isl/spatial.py +260 -0
- accelforge/model/_looptree/reuse/isl/temporal.py +182 -0
- accelforge/model/_looptree/reuse/symbolic/__init__.py +1 -0
- accelforge/model/_looptree/reuse/symbolic/symbolic copy 2.py +1346 -0
- accelforge/model/_looptree/reuse/symbolic/symbolic copy.py +1408 -0
- accelforge/model/_looptree/reuse/symbolic/symbolic.py +1396 -0
- accelforge/model/_looptree/run.py +122 -0
- accelforge/model/_looptree/types.py +26 -0
- accelforge/model/_looptree/visualization/__init__.py +0 -0
- accelforge/model/_looptree/visualization/occupancy.py +11 -0
- accelforge/model/main.py +222 -0
- accelforge/plotting/__init__.py +2 -0
- accelforge/plotting/mappings.py +219 -0
- accelforge/plotting/specs.py +57 -0
- accelforge/util/__init__.py +4 -0
- accelforge/util/_base_analysis_types.py +24 -0
- accelforge/util/_basetypes.py +1089 -0
- accelforge/util/_frozenset.py +36 -0
- accelforge/util/_isl.py +29 -0
- accelforge/util/_itertools.py +14 -0
- accelforge/util/_mathfuncs.py +57 -0
- accelforge/util/_parse_expressions.py +339 -0
- accelforge/util/_picklecache.py +32 -0
- accelforge/util/_setexpressions.py +268 -0
- accelforge/util/_sympy/__init__.py +0 -0
- accelforge/util/_sympy/broadcast_max.py +18 -0
- accelforge/util/_visualization.py +112 -0
- accelforge/util/_yaml.py +579 -0
- accelforge/util/parallel.py +193 -0
- accelforge-0.0.1.dist-info/METADATA +64 -0
- accelforge-0.0.1.dist-info/RECORD +258 -0
- accelforge-0.0.1.dist-info/WHEEL +5 -0
- accelforge-0.0.1.dist-info/licenses/LICENSE +19 -0
- accelforge-0.0.1.dist-info/top_level.txt +5 -0
- docs/_build/html/_sources/fastfusion.frontend.mapper.rst.txt +37 -0
- docs/_build/html/_sources/fastfusion.frontend.rst.txt +70 -0
- docs/_build/html/_sources/fastfusion.frontend.workload.rst.txt +21 -0
- docs/_build/html/_sources/fastfusion.mapper.FFM.rst.txt +37 -0
- docs/_build/html/_sources/fastfusion.mapper.rst.txt +18 -0
- docs/_build/html/_sources/fastfusion.rst.txt +20 -0
- docs/_build/html/_sources/fastfusion.util.rst.txt +21 -0
- docs/_build/html/_sources/index.rst.txt +87 -0
- docs/_build/html/_sources/modules.rst.txt +7 -0
- docs/_build/html/_sources/notes/citation.rst.txt +45 -0
- docs/_build/html/_sources/notes/definitions.rst.txt +43 -0
- docs/_build/html/_sources/notes/faqs.rst.txt +39 -0
- docs/_build/html/_sources/notes/modeling/accelerator_energy_latency.rst.txt +72 -0
- docs/_build/html/_sources/notes/modeling/component_energy_area.rst.txt +96 -0
- docs/_build/html/_sources/notes/modeling/mapping.rst.txt +100 -0
- docs/_build/html/_sources/notes/modeling.rst.txt +33 -0
- docs/_build/html/_sources/notes/parsing/arithmetic_parsing.rst.txt +136 -0
- docs/_build/html/_sources/notes/parsing/setexpressions.rst.txt +63 -0
- docs/_build/html/_sources/notes/parsing/yaml_parsing.rst.txt +176 -0
- docs/_build/html/_sources/notes/quickstart_and_installation.rst.txt +9 -0
- docs/_build/html/_sources/notes/spec/architecture.rst.txt +133 -0
- docs/_build/html/_sources/notes/spec/mapping.rst.txt +12 -0
- docs/_build/html/_sources/notes/spec/workload.rst.txt +83 -0
- docs/_build/html/_sources/notes/spec.rst.txt +36 -0
- docs/source/_ext/include_attrs.py +213 -0
- docs/source/_ext/include_docstring.py +364 -0
- docs/source/_ext/include_functions.py +154 -0
- docs/source/_ext/include_notebook.py +131 -0
- docs/source/_ext/include_yaml.py +119 -0
- docs/source/_ext/inherited_attributes.py +222 -0
- docs/source/_ext/paths.py +4 -0
- docs/source/conf.py +79 -0
- examples/arches/compute_in_memory/_include.yaml +74 -0
- examples/arches/compute_in_memory/_include_functions.py +229 -0
- examples/arches/compute_in_memory/_load_spec.py +57 -0
- examples/arches/compute_in_memory/components/c2c_multiplier.py +181 -0
- examples/arches/compute_in_memory/components/dac_c2c_r2r.py +605 -0
- examples/arches/compute_in_memory/components/misc.py +195 -0
- examples/arches/compute_in_memory/components/util/bit_functions.py +51 -0
- examples/arches/compute_in_memory/components/zero_comparator.py +92 -0
- examples/arches/compute_in_memory/isaac.yaml +233 -0
- examples/arches/compute_in_memory/memory_cells/ecram_demo.yaml +63 -0
- examples/arches/compute_in_memory/memory_cells/rram_example.yaml +63 -0
- examples/arches/compute_in_memory/memory_cells/rram_isaac_isca_2016.yaml +64 -0
- examples/arches/compute_in_memory/memory_cells/rram_neurosim_default.yaml +63 -0
- examples/arches/compute_in_memory/memory_cells/rram_raella_isca_2023.yaml +70 -0
- examples/arches/compute_in_memory/memory_cells/rram_wan_nature_2022.yaml +63 -0
- examples/arches/compute_in_memory/memory_cells/sram_colonnade_jssc_2021.yaml +63 -0
- examples/arches/compute_in_memory/memory_cells/sram_example.yaml +63 -0
- examples/arches/compute_in_memory/memory_cells/sram_jia_jssc_2020.yaml +63 -0
- examples/arches/compute_in_memory/memory_cells/sram_sinangil_jssc_2021.yaml +63 -0
- examples/arches/compute_in_memory/memory_cells/sram_wang_vlsi_2022.yaml +63 -0
- examples/arches/compute_in_memory/wang_vlsi_2022.yaml +289 -0
- examples/arches/eyeriss.yaml +68 -0
- examples/arches/fanout_variations/at_glb.yaml +31 -0
- examples/arches/fanout_variations/at_glb_with_fanout_node.yaml +34 -0
- examples/arches/fanout_variations/at_mac.yaml +31 -0
- examples/arches/fanout_variations/at_mac_with_constraints.yaml +38 -0
- examples/arches/fanout_variations/at_mac_with_fanout_node.yaml +34 -0
- examples/arches/nvdla.yaml +47 -0
- examples/arches/simple.yaml +28 -0
- examples/arches/tpu_v4i.yaml +67 -0
- examples/mappings/unfused_matmuls_to_simple.yaml +33 -0
- examples/misc/component_annotated.yaml +33 -0
- examples/workloads/gpt3_6.7B.yaml +124 -0
- examples/workloads/matmuls.yaml +20 -0
- examples/workloads/mobilenet_28.yaml +81 -0
- examples/workloads/mobilenet_various_separate.yaml +106 -0
- examples/workloads/three_matmuls_annotated.yaml +59 -0
- notebooks/.ipynb_checkpoints/fastfusion_arch_study_michael-checkpoint.ipynb +359 -0
- notebooks/compute_in_memory/_scripts.py +339 -0
- notebooks/compute_in_memory/isaac.guide.ipynb +270 -0
- notebooks/compute_in_memory/wang_vlsi_2022.ipynb +602 -0
- notebooks/paths.py +4 -0
- notebooks/tutorials/.ipynb_checkpoints/1_FFM-checkpoint.ipynb +3110 -0
- notebooks/tutorials/FFM.ipynb +3498 -0
- notebooks/tutorials/_include.py +48 -0
- notebooks/tutorials/component_energy_area.ipynb +363 -0
- tests/Q_mapping.yaml +38 -0
- tests/__init__.py +0 -0
- tests/conv.mapping.yaml +27 -0
- tests/conv.workload.yaml +13 -0
- tests/conv_sym.mapping.yaml +43 -0
- tests/copy.mapping.yaml +35 -0
- tests/copy.workload.yaml +15 -0
- tests/distribuffers/__init__.py +0 -0
- tests/distribuffers/multicast/test_cases.yaml +482 -0
- tests/distribuffers/spec/binding/valid_bindings.yaml +97 -0
- tests/distribuffers/spec/distributed.yaml +100 -0
- tests/distribuffers/spec/logical_arch.yaml +32 -0
- tests/distribuffers/spec/physical_arch.yaml +69 -0
- tests/distribuffers/test_binding.py +48 -0
- tests/frontend/__init__.py +0 -0
- tests/frontend/test_mapping_viz.py +52 -0
- tests/mapper/__init__.py +0 -0
- tests/mapper/configs/conv1d/conv1d.mapping.yaml +31 -0
- tests/mapper/configs/conv1d/conv1d.workload.yaml +11 -0
- tests/mapper/configs/two_conv1d/two_conv1d.expected.yaml +38 -0
- tests/mapper/configs/two_conv1d/two_conv1d.mapping.yaml +54 -0
- tests/mapper/configs/two_conv1d/two_conv1d.workload.yaml +19 -0
- tests/mapper/test_mapping_to_isl.py +90 -0
- tests/mapper/test_spatial_reuse_analysis.py +67 -0
- tests/mapper/test_temporal_reuse_analysis.py +56 -0
- tests/mapper/util.py +58 -0
- tests/matmul.mapping.yaml +29 -0
- tests/matmul.workload.yaml +12 -0
- tests/matmul_spatial.mapping.yaml +44 -0
- tests/mha.renames.yaml +65 -0
- tests/mha.workload.yaml +67 -0
- tests/mha.yaml +59 -0
- tests/mha_full.workload.yaml +67 -0
- tests/mobilenet.workload.yaml +35 -0
- tests/mobilenet_long.workload.yaml +64 -0
- tests/pmappingcache.py +24 -0
- tests/processing_stage.arch.yaml +40 -0
- tests/snowcat.arch.yaml +36 -0
- tests/test_ffm_join_pmappings.py +106 -0
- tests/test_ffm_make_pmappings.py +82 -0
- tests/test_ffm_make_tile_shapes.py +49 -0
- tests/test_mapper.py +100 -0
- tests/test_model.py +37 -0
- tests/test_plotting.py +72 -0
- tests/test_processing_stage.py +46 -0
- tests/test_symbolic_model.py +248 -0
- tests/test_workload.py +141 -0
@@ -0,0 +1,382 @@ (new file; the +382 line count matches accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_storage_order.py in the listing above)

from collections.abc import Collection, Generator, Sequence
from dataclasses import dataclass
from itertools import product
import logging
from typing import Any

import accelforge.frontend.arch as arch
from accelforge.frontend.mapping import MappingNode, ProcessingStage, TensorHolder
from accelforge.frontend.spec import Spec
from accelforge.frontend.workload import EinsumName, SymbolTable, TensorName
from accelforge.util._parse_expressions import MATH_FUNCS
from accelforge.util._setexpressions import eval_set_expression

from accelforge.mapper.FFM._make_pmappings.make_pmapping_templates.make_storages import (
    make_storage_choices_all_levels,
)


def get_tensor_choices(
    einsum_name: EinsumName,
    nodes: list[arch.Memory],
    symbol_table: SymbolTable,
    spec: Spec,
    first_memory: arch.Memory,
    fusable_tensors: set[TensorName],
) -> Generator[tuple[list[TensorHolder], SymbolTable, arch.Compute], None, None]:
    nodes, compute = nodes[:-1], nodes[-1]
    # Skip leading non-memory and disabled nodes; stop at the first enabled memory.
    while True:
        if not nodes:
            return
        if not isinstance(nodes[0], arch.Memory):
            nodes = nodes[1:]
            continue
        assert isinstance(nodes[0].enabled, bool)
        if not nodes[0].enabled:
            nodes = nodes[1:]
            continue
        break

    tensors = spec.workload.einsums[einsum_name].tensor_names
    is_copy_op = spec.workload.einsums[einsum_name].is_copy_operation
    persistent_tensors = {
        t.name
        for t in spec.workload.einsums[einsum_name].tensor_accesses
        if t.persistent
    }

    for choice, symbol_table in make_storage_choices_all_levels(
        nodes=nodes,
        symbol_table=symbol_table,
        is_copy_op=is_copy_op,
        persistent_tensors=persistent_tensors,
        seen_tensors=set(),
        einsum_name=einsum_name,
    ):
        all_tensor_holders = [v2 for v in choice.values() for v2 in v]
        choice_str = ", ".join(n.compact_str() for n in all_tensor_holders)
        logging.info(f"\t\tUnordered storage choice: {choice_str}")

        # Start out the mapping with the outermost memory name
        base_mapping = []
        # for node in list(all_tensor_holders[::-1]):
        #     if node.component == first_tensor_holder.name:
        #         all_tensor_holders.remove(node)
        #         base_mapping.append(node)

        # Get the dataflow constraints for the mapping
        required_order = get_tensor_order_constraint(nodes, symbol_table, tensors)

        symbol_table["arch_attributes"] = {}
        cur_compute = compute._parse_expressions(
            symbol_table,
            location=f"arch.{compute.name}",
            must_parse_try_parse_to=True,
            must_copy=False,
        )[0]
        assert isinstance(cur_compute.enabled, bool)
        if not cur_compute.enabled:
            continue

        for mapping in recursive_order_tensor_choices(
            einsum_name,
            tensors,
            base_mapping,
            nodes,
            all_tensor_holders,
            required_order,
            spec,
            is_copy_op,
            first_memory,
            fusable_tensors,
        ):
            yield mapping, symbol_table, cur_compute

def get_tensor_order_constraint(nodes, symbol_table, tensors):
    required_order: dict[str, list[Order]] = {}
    for node in nodes:
        if isinstance(node, arch.Fanout):
            continue
        for order_constraint in node.tensors.tensor_order_options:
            order = Order()
            for together_tensors in order_constraint:
                in_mapping_together_tensors = [
                    tensor for tensor in together_tensors if tensor in tensors
                ]
                if len(in_mapping_together_tensors) == 1:
                    only_tensor = in_mapping_together_tensors[0]
                    order.add_tensor(only_tensor)
                elif len(in_mapping_together_tensors) > 1:
                    order.add_together_tensors(in_mapping_together_tensors)
            if order.order:
                required_order.setdefault(node.name, []).append(order)
    return required_order

def recursive_order_tensor_choices(
    einsum_name: EinsumName,
    tensors: set[TensorName],
    mapping: Sequence[MappingNode],
    nodes: list[arch.Memory],
    remaining_choices: list,
    required_order: dict[str, list["Order"]],
    spec: Spec,
    is_copy_op: bool,
    first_memory: arch.Memory,
    fusable_tensors: set[TensorName],
) -> Generator[list[MappingNode], None, None]:
    def check_has_tensors(mapping: list[MappingNode]):
        tensor_holders = [node for node in mapping if isinstance(node, TensorHolder)]
        tensors_in_mapping = {
            tensor
            for tensor_holder in tensor_holders
            for tensor in tensor_holder.tensors
        }
        if tensors_in_mapping != tensors:
            raise ValueError(
                f"Einsum {einsum_name} has a pmapping template that is missing tensors. Ensure "
                f"that there is a storage node storing each tensor in the Einsum. Missing "
                f"tensors: {tensors - tensors_in_mapping}. Pmapping template:\n\t"
                + "\n\t".join(m.compact_str() for m in mapping)
            )

    mapping = list(mapping)
    if not remaining_choices:
        check_has_tensors(mapping)
        yield mapping
        return

    # If it's a copy op and we have the backing storage for every tensor, return
    # immediately
    if is_copy_op:
        tensor_holders = [node for node in mapping if isinstance(node, TensorHolder)]
        if set().union(*[t._backing for t in tensor_holders]) == tensors:
            check_has_tensors(mapping)
            yield mapping
            return

    for choice in sorted(remaining_choices, key=lambda x: x.compact_str()):
        mapping.append(choice)
        new_remaining = [c for c in remaining_choices if c != choice]
        valid, reason = valid_tensor_holder_order(
            mapping,
            [n.name for n in nodes],
            required_order,
            spec,
            first_memory,
            fusable_tensors,
        )
        if valid:
            yield from recursive_order_tensor_choices(
                einsum_name,
                tensors,
                mapping,
                nodes,
                new_remaining,
                required_order,
                spec,
                is_copy_op,
                first_memory,
                fusable_tensors,
            )
        else:
            mapping_str = ", ".join(n.compact_str() for n in mapping)
            logging.info(
                "\t\t"
                + " " * len(mapping)
                + f"Invalid tensor holder order: {mapping_str}: {reason}"
            )
        mapping.pop()

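
# For reference: recursive_order_tensor_choices above follows the standard
# backtracking pattern (append a candidate, validate the prefix, recurse on the
# remainder, pop). A minimal self-contained sketch of the same pattern;
# is_valid_prefix and the string items are hypothetical stand-ins, not
# accelforge APIs:
def order_choices(prefix, remaining, is_valid_prefix):
    # Yield every full ordering whose every prefix passes is_valid_prefix.
    if not remaining:
        yield list(prefix)
        return
    for choice in sorted(remaining):  # sorted for deterministic output
        prefix.append(choice)
        if is_valid_prefix(prefix):
            yield from order_choices(
                prefix, [c for c in remaining if c != choice], is_valid_prefix
            )
        prefix.pop()

# Example: forbid "B" from directly following "A".
no_ab = lambda p: p[-2:] != ["A", "B"]
results = list(order_choices([], ["A", "B", "C"], no_ab))
assert ["A", "C", "B"] in results and ["A", "B", "C"] not in results
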
def valid_tensor_holder_order(
    mapping: Sequence[TensorHolder],
    node_names: list[str],
    required_orders: dict[str, list["Order"]],
    spec: Spec,
    first_memory: arch.Memory,
    fusable_tensors: set[TensorName],
):
    memory_to_satisfied_constraints: dict[str, set] = {}
    for i, m0 in enumerate(mapping):
        for j, m1 in enumerate(mapping[i:]):
            j += i

            s1, s2 = m0.component, m1.component
            s1_idx, s2_idx = node_names.index(s1), node_names.index(s2)
            s1_persistent, s2_persistent = m0.persistent, m1.persistent
            either_persistent = s1_persistent or s2_persistent

            assert len(m0.tensors) == 1
            assert len(m1.tensors) == 1

            # If they're persistent, they're forced to be at the top.
            force_order = (
                spec.mapper.ffm.force_memory_hierarchy_order and not either_persistent
            )
            force_order &= m0.component_object.tensors.force_memory_hierarchy_order
            force_order &= m1.component_object.tensors.force_memory_hierarchy_order

            # Ctrl-F for CONTIGUOUS_ITERATION_SPACE_DISCUSSION: The following line
            # does not let backing storage be above, in the mapping, anything that
            # is below it in the memory hierarchy. THIS IS NOT FUNDAMENTAL. If we
            # remove this constraint, then the fused loops may be different across
            # different backing storages, so we would need to update
            # make_pmappings_from_templates.py to make compatibility from the
            # mapping for each tensor.
            force_order |= bool(m0._backing & fusable_tensors)

            if force_order and i < j and s2_idx < s1_idx:
                return (
                    False,
                    f"Memory {s1} is below memory {s2}, violating memory hierarchy order.",
                )

            s1_outermost = s1_persistent
            s2_outermost = s2_persistent
            if not spec.mapper.ffm._can_lower_outermost_memory:
                s1_outermost |= s1 == first_memory.name
                s2_outermost |= s2 == first_memory.name

            # Persistent tensors must be at the top of the hierarchy
            if s2_outermost and not s1_outermost and i < j:
                return (
                    False,
                    f"Outermost {m1.compact_str()} (persistent: {s2_persistent}) is below "
                    f"non-outermost {m0.compact_str()} (persistent: {s1_persistent}).",
                )

            # We don't really care about processing stage order, so just make it
            # follow the regular memory hierarchy order. For processing stages at a
            # given level, canonicalize on tensor name.
            if (
                isinstance(m0, ProcessingStage)
                and m0.component == m1.component
                and m0.tensor < m1.tensor
            ):
                return (
                    False,
                    f"Processing stage {m0} is not in the canonical tensor order; "
                    f"it has tensor {m0.tensor} before {m1.tensor}",
                )

            # If there is a processing stage, don't explore order. If there are two
            # back-to-back nodes and one is a processing stage, make them follow
            # the memory hierarchy order.
            if isinstance(m0, ProcessingStage) and s2_idx < s1_idx and i == j - 1:
                return False, f"Processing stage {m0} is directly above {m1}"
            if isinstance(m1, ProcessingStage) and s2_idx < s1_idx and i == j - 1:
                return False, f"Processing stage {m1} is directly below {m0}"

            if s1 == s2 and s1 in required_orders and i != j:
                if s1 not in memory_to_satisfied_constraints:
                    memory_to_satisfied_constraints[s1] = set(
                        range(len(required_orders[s1]))
                    )

                for order_idx, order_choice in enumerate(required_orders[s1]):
                    if order_idx not in memory_to_satisfied_constraints[s1]:
                        continue

                    good = True
                    for t1, t2 in product(mapping[i].tensors, mapping[j].tensors):
                        idx_of_i_in_order = order_choice.index(t1)
                        idx_of_j_in_order = order_choice.index(t2)

                        if idx_of_i_in_order is None or idx_of_j_in_order is None:
                            continue

                        if idx_of_i_in_order > idx_of_j_in_order:
                            good = False
                            reason = (
                                f"Tensor {t1} appears before tensor {t2}, but the "
                                f"order {order_choice} requires it after"
                            )
                            break
                    if not good:
                        memory_to_satisfied_constraints[s1].remove(order_idx)

                if len(memory_to_satisfied_constraints[s1]) == 0:
                    return False, reason

            if not (set(m0.tensors) & set(m1.tensors)):
                continue

            if i < j and s2_idx < s1_idx:
                return False, f"{m0.compact_str()} is below {m1.compact_str()}"

            # If a tensor is stored in two levels back-to-back, then we should have
            # bypassed the outer TensorHolder if possible.
            both_backing = m0._backing & m1._backing
            if (
                "redundant_dataplacements"
                not in spec.mapper.ffm._count_option_for_mapsapce_size_evaluation
            ):
                if i == j or i == j - 1:
                    if s1_idx < s2_idx and not (
                        (set(m0._must_keep_tensors) & set(m1.tensors)) or both_backing
                    ):
                        shared = set(m0.tensors) & set(m1.tensors)
                        return (
                            False,
                            f"{shared} stored in back-to-back storage nodes, and could "
                            f"have bypassed the outer one.",
                        )
                    if s2_idx < s1_idx and not (
                        (set(m1._must_keep_tensors) & set(m0.tensors)) or both_backing
                    ):
                        shared = set(m0.tensors) & set(m1.tensors)
                        return (
                            False,
                            f"{shared} is stored in back-to-back storage nodes, and "
                            f"could have bypassed the outer one.",
                        )

    for i, m0 in enumerate(mapping):
        for j, m1 in enumerate(mapping[i:]):
            j += i  # realign j with m1's index, as in the pass above
            s1, s2 = m0.component, m1.component
            if s1 != s2 or s1 not in memory_to_satisfied_constraints or i == j:
                continue

            satisfied_orders = memory_to_satisfied_constraints[s1]
            assert len(satisfied_orders) > 0

            for order_idx in satisfied_orders:
                order = required_orders[s1][order_idx]
                for tensor_i in m0.tensors:
                    for tensor_j in m1.tensors:
                        if order.index(tensor_i) != order.index(tensor_j):
                            continue
                        break

    return True, ""

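
# A subtlety in the pair loops above: enumerate(mapping[i:]) restarts j at 0,
# so j must be realigned with j += i before any index comparison (e.g. the
# i == j self-pair check). A standalone illustration of the realigned
# triangular iteration over all ordered pairs with i <= j:
seq = ["a", "b", "c"]
pairs = []
for _i, _m0 in enumerate(seq):
    for _j, _m1 in enumerate(seq[_i:]):
        _j += _i  # realign _j with _m1's true index in seq
        pairs.append((_i, _j))
assert pairs == [(0, 0), (0, 1), (0, 2), (1, 1), (1, 2), (2, 2)]
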
@dataclass(frozen=True)
class Alone:
    tensor: Any


@dataclass(frozen=True)
class Together:
    tensors: Collection[Any]


class Order:
    """An ordering of tensors."""

    def __init__(self):
        self.order = []

    def __repr__(self):
        return f"Order({self.order})"

    def add_tensor(self, tensor):
        self.order.append(Alone(tensor))

    def add_together_tensors(self, together_tensors):
        self.order.append(Together(together_tensors))

    def index(self, tensor):
        for i, order_term in enumerate(self.order):
            if (isinstance(order_term, Alone) and order_term.tensor == tensor) or (
                isinstance(order_term, Together) and tensor in order_term.tensors
            ):
                return i
        return None

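
# Illustration of the Order semantics above, with placeholder tensor names:
# Alone entries occupy their own position, Together entries share one, and
# index() returns None for tensors the order does not mention.
order = Order()
order.add_tensor("A")                   # position 0: A alone
order.add_together_tensors(["B", "C"])  # position 1: B and C tied
assert order.index("A") == 0
assert order.index("B") == order.index("C") == 1
assert order.index("D") is None         # unconstrained tensor
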
@@ -0,0 +1,155 @@ (new file; the +155 line count matches accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_storages.py in the listing above)

import copy
from collections.abc import Generator
from itertools import chain, combinations
import logging

import accelforge.frontend.arch as arch
from accelforge.frontend.mapping import Storage, TensorHolder, ProcessingStage
from accelforge.frontend.workload import TensorName, SymbolTable

from accelforge.util._parse_expressions import ParseError
from accelforge.util._setexpressions import InvertibleSet


def make_tensor_choices_one_level(
    node: arch.Leaf,
    symbol_table: dict[str, InvertibleSet],
    persistent_tensors: set[TensorName],
    seen_tensors: set[TensorName] = frozenset(),
    is_copy_op: bool = False,
    einsum_name: str | None = None,
) -> Generator[tuple[list[TensorHolder], SymbolTable, set[TensorName]], None, None]:
    """
    Generate combinations of TensorHolder nodes based on keep and bypass
    constraints.

    Each generated list contains TensorHolder nodes for single tensors.
    """
    assert "All" in symbol_table
    tensors = symbol_table["All"]

    if not isinstance(node, arch.TensorHolder):
        yield [], symbol_table, set(seen_tensors)
        return

    if isinstance(node, arch.Memory):
        target_type = Storage
    elif isinstance(node, arch.ProcessingStage):
        target_type = ProcessingStage
    elif isinstance(node, arch.Dummy):
        yield [], symbol_table, set(seen_tensors)
        return
    else:
        raise ValueError(f"Unexpected tensor holder type: {type(node)}")

    new_symbol_table = copy.copy(symbol_table)

    node = copy.copy(node)
    try:
        node.tensors: arch.Tensors = node.tensors._parse_expressions(
            symbol_table=symbol_table,
            must_parse_try_parse_to=True,
            must_copy=False,
            location=f"arch.{node.name}.tensors",
        )[0]
    except ParseError as e:
        e.add_field(f"Einsum {einsum_name} arch.{node.name}.tensors")
        raise e

    must_keep = tensors.to_my_space(node.tensors.keep | node.tensors.back)
    may_keep = tensors.to_my_space(node.tensors.may_keep)
    may_keep -= must_keep

    if seen_tensors & set(node.tensors.back):
        return

    if must_keep - tensors:
        raise KeyError(
            f"Keep constraint for {node.name} includes tensors that are "
            f"not in the workload: {must_keep - tensors}"
        )
    if may_keep - tensors:
        raise KeyError(
            f"Bypass constraint for {node.name} includes tensors that are "
            f"not in the workload: {may_keep - tensors}"
        )

    logging.info(
        f"\t\t{node.name} must keep {sorted(must_keep)}, may keep {sorted(may_keep)}"
    )

    # No reuse in copy operations, so no need to keep tensors in more places
    if is_copy_op:
        may_keep -= tensors.to_my_space(seen_tensors)

    for subset in powerset(sorted(may_keep, key=str)):
        # Make keep choice & update symbol table
        subset = tensors.to_my_space(set(subset))
        keep_choice = tensors.to_my_space(subset | must_keep)
        # Below line is so users can do MainMemory().tensors() or MainMemory.tensors
        new_symbol_table[node.name] = keep_choice
        new_symbol_table["Above"] |= keep_choice
        new_seen_tensors = seen_tensors | set(keep_choice)

        # Make sure they're all tensors
        assert all(isinstance(k, TensorName) for k in keep_choice)
        keep_choice = keep_choice.to_my_space({copy.copy(t) for t in keep_choice})
        nodes = []

        # Create storage nodes. Sort them to keep this deterministic. Ordering is
        # done later.
        for t in sorted(keep_choice, key=str):
            nodes.append(
                target_type(tensors=[t], component=node.name, component_object=node)
            )
            if t not in seen_tensors:
                nodes[-1]._backing.add(t)
                nodes[-1]._must_keep_tensors = [t]
                nodes[-1].persistent = t in persistent_tensors
            elif t in must_keep:
                nodes[-1]._must_keep_tensors = [t]

        yield nodes, new_symbol_table, new_seen_tensors

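
# For illustration: every choice yielded above keeps all of must_keep plus one
# subset of may_keep. A plain-set sketch with hypothetical tensor names (the
# real code walks InvertibleSets and builds Storage/ProcessingStage nodes; the
# subset walk mirrors the powerset helper at the end of this file):
example_must_keep, example_may_keep = {"W"}, {"I", "O"}
example_keep_sets = [
    set(s) | example_must_keep
    for r in range(len(example_may_keep) + 1)
    for s in combinations(sorted(example_may_keep), r)
]
assert example_keep_sets == [{"W"}, {"I", "W"}, {"O", "W"}, {"I", "O", "W"}]
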
def make_storage_choices_all_levels(
    nodes: list[TensorHolder],
    symbol_table: dict[str, InvertibleSet],
    persistent_tensors: set[TensorName],
    seen_tensors: set[TensorName] | None = None,
    is_copy_op: bool = False,
    einsum_name: str | None = None,
) -> Generator[tuple[dict[str, list[TensorHolder]], SymbolTable], None, None]:
    """
    Generate combinations of TensorHolder nodes based on keep and bypass
    constraints.

    Each generated dict maps memory name to a list of TensorHolder nodes for
    single tensors.
    """
    seen_tensors = set() if seen_tensors is None else seen_tensors
    if len(nodes) == 0:
        yield dict(), symbol_table
        return
    for choice, symbol_table, new_seen_tensors in make_tensor_choices_one_level(
        node=nodes[0],
        symbol_table=symbol_table,
        persistent_tensors=persistent_tensors,
        seen_tensors=seen_tensors,
        is_copy_op=is_copy_op,
        einsum_name=einsum_name,
    ):
        for subchoices, symbol_table in make_storage_choices_all_levels(
            nodes=nodes[1:],
            symbol_table=symbol_table,
            persistent_tensors=persistent_tensors,
            seen_tensors=new_seen_tensors,
            is_copy_op=is_copy_op,
            einsum_name=einsum_name,
        ):
            yield {**subchoices, nodes[0].name: choice}, symbol_table

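
# For illustration: the recursion above is a lazy cartesian product over
# per-level choices, threading the symbol table through each branch. A minimal
# analogue without the symbol table; level names and options are hypothetical:
def product_of_levels(levels):
    # One (name, options) pair per level; yield every combination as a dict.
    if not levels:
        yield {}
        return
    name, options = levels[0]
    for option in options:
        for rest in product_of_levels(levels[1:]):
            yield {**rest, name: option}

assert list(product_of_levels([("GLB", ["keep_W", "bypass"]), ("PE", ["keep_I"])])) == [
    {"PE": "keep_I", "GLB": "keep_W"},
    {"PE": "keep_I", "GLB": "bypass"},
]
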
def powerset(iterable):
    s = list(iterable)
    return chain.from_iterable(combinations(s, r) for r in range(len(s) + 1))
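
# Illustration: this classic powerset recipe enumerates subsets smallest-first.
assert list(powerset(["I", "O"])) == [(), ("I",), ("O",), ("I", "O")]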