accelforge-0.0.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- accelforge/__init__.py +21 -0
- accelforge/_accelerated_imports.py +16 -0
- accelforge/_deprecate/_simanneal/evalmapping.py +271 -0
- accelforge/_deprecate/_simanneal/mapspaceglobals.py +298 -0
- accelforge/_deprecate/_simanneal/simanneal.py +666 -0
- accelforge/_deprecate/_simanneal/tracking.py +105 -0
- accelforge/_deprecate/_simanneal/wrappers.py +218 -0
- accelforge/_deprecate/_simanneal2/__init__.py +7 -0
- accelforge/_deprecate/_simanneal2/simanneal.py +493 -0
- accelforge/_deprecate/_simanneal2/tracking.py +116 -0
- accelforge/_deprecate/compatibility_util.py +181 -0
- accelforge/_deprecate/layerdeduplication/__init__.py +2 -0
- accelforge/_deprecate/layerdeduplication/group_similar_einsums.py +160 -0
- accelforge/_deprecate/layerdeduplication/grouped_einsums.py +84 -0
- accelforge/_deprecate/mapping_filter_tags/__init__.py +2 -0
- accelforge/_deprecate/mapping_filter_tags/ffmt.py +212 -0
- accelforge/_deprecate/mapping_filter_tags/onesplit.py +24 -0
- accelforge/_deprecate/mapping_filter_tags/util.py +24 -0
- accelforge/_deprecate/tags.py +69 -0
- accelforge/_deprecate/viz/__init__.py +0 -0
- accelforge/_deprecate/viz/interactive.py +159 -0
- accelforge/_deprecate/viz/reservationtree.py +307 -0
- accelforge/_deprecate/viz/ski_slope.py +88 -0
- accelforge/_version.py +15 -0
- accelforge/examples.py +39 -0
- accelforge/frontend/__init__.py +10 -0
- accelforge/frontend/_binding.py +129 -0
- accelforge/frontend/_workload_isl/__init__.py +2 -0
- accelforge/frontend/_workload_isl/_isl.py +149 -0
- accelforge/frontend/_workload_isl/_symbolic.py +141 -0
- accelforge/frontend/arch copy.py +1544 -0
- accelforge/frontend/arch.py +1642 -0
- accelforge/frontend/config.py +63 -0
- accelforge/frontend/mapper/__init__.py +5 -0
- accelforge/frontend/mapper/ffm.py +126 -0
- accelforge/frontend/mapper/mapper.py +7 -0
- accelforge/frontend/mapper/metrics.py +30 -0
- accelforge/frontend/mapping/__init__.py +1 -0
- accelforge/frontend/mapping/mapping.py +1736 -0
- accelforge/frontend/model.py +14 -0
- accelforge/frontend/renames.py +150 -0
- accelforge/frontend/spec copy.py +230 -0
- accelforge/frontend/spec.py +301 -0
- accelforge/frontend/variables.py +12 -0
- accelforge/frontend/workload.py +952 -0
- accelforge/mapper/FFM/__init__.py +9 -0
- accelforge/mapper/FFM/_join_pmappings/__init__.py +0 -0
- accelforge/mapper/FFM/_join_pmappings/compatibility.py +653 -0
- accelforge/mapper/FFM/_join_pmappings/compress_pmappings.py +140 -0
- accelforge/mapper/FFM/_join_pmappings/join_pmappings.py +703 -0
- accelforge/mapper/FFM/_join_pmappings/pmapping_dataframe.py +901 -0
- accelforge/mapper/FFM/_join_pmappings/pmapping_group.py +337 -0
- accelforge/mapper/FFM/_make_pmappings/contraints/__init__.py +0 -0
- accelforge/mapper/FFM/_make_pmappings/contraints/constraints.py +360 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/__init__.py +1 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_loops.py +373 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_pmapping_templates.py +463 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_reservations.py +95 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_storage_order.py +382 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_storages.py +155 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmappings.py +411 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/__init__.py +1 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/make_pmappings_from_templates.py +407 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/make_tile_shapes.py +1681 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/run_model.py +170 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/symbol_relations.py +174 -0
- accelforge/mapper/FFM/_make_pmappings/pmapper_job.py +282 -0
- accelforge/mapper/FFM/_pareto_df/df_convention.py +273 -0
- accelforge/mapper/FFM/_pareto_df/pareto copy.py +836 -0
- accelforge/mapper/FFM/_pareto_df/pareto.py +508 -0
- accelforge/mapper/FFM/data.py +61 -0
- accelforge/mapper/FFM/main copy.py +236 -0
- accelforge/mapper/FFM/main.py +208 -0
- accelforge/mapper/FFM/mappings.py +510 -0
- accelforge/mapper/FFM/pmappings.py +310 -0
- accelforge/mapper/__init__.py +4 -0
- accelforge/mapper.py +0 -0
- accelforge/model/__init__.py +1 -0
- accelforge/model/_looptree/__init__.py +0 -0
- accelforge/model/_looptree/accesses.py +335 -0
- accelforge/model/_looptree/capacity/__init__.py +1 -0
- accelforge/model/_looptree/capacity/aggregators.py +36 -0
- accelforge/model/_looptree/capacity/capacity.py +47 -0
- accelforge/model/_looptree/energy.py +150 -0
- accelforge/model/_looptree/equivalent_ranks.py +29 -0
- accelforge/model/_looptree/latency/__init__.py +1 -0
- accelforge/model/_looptree/latency/latency.py +98 -0
- accelforge/model/_looptree/latency/memory.py +120 -0
- accelforge/model/_looptree/latency/processors.py +92 -0
- accelforge/model/_looptree/mapping_utilities.py +71 -0
- accelforge/model/_looptree/reuse/__init__.py +4 -0
- accelforge/model/_looptree/reuse/isl/__init__.py +1 -0
- accelforge/model/_looptree/reuse/isl/des.py +59 -0
- accelforge/model/_looptree/reuse/isl/isl_functions.py +374 -0
- accelforge/model/_looptree/reuse/isl/mapping_to_isl/__init__.py +4 -0
- accelforge/model/_looptree/reuse/isl/mapping_to_isl/analyze_mapping.py +297 -0
- accelforge/model/_looptree/reuse/isl/mapping_to_isl/skews_from_mapping.py +236 -0
- accelforge/model/_looptree/reuse/isl/mapping_to_isl/tiling.py +685 -0
- accelforge/model/_looptree/reuse/isl/mapping_to_isl/types.py +188 -0
- accelforge/model/_looptree/reuse/isl/spatial.py +260 -0
- accelforge/model/_looptree/reuse/isl/temporal.py +182 -0
- accelforge/model/_looptree/reuse/symbolic/__init__.py +1 -0
- accelforge/model/_looptree/reuse/symbolic/symbolic copy 2.py +1346 -0
- accelforge/model/_looptree/reuse/symbolic/symbolic copy.py +1408 -0
- accelforge/model/_looptree/reuse/symbolic/symbolic.py +1396 -0
- accelforge/model/_looptree/run.py +122 -0
- accelforge/model/_looptree/types.py +26 -0
- accelforge/model/_looptree/visualization/__init__.py +0 -0
- accelforge/model/_looptree/visualization/occupancy.py +11 -0
- accelforge/model/main.py +222 -0
- accelforge/plotting/__init__.py +2 -0
- accelforge/plotting/mappings.py +219 -0
- accelforge/plotting/specs.py +57 -0
- accelforge/util/__init__.py +4 -0
- accelforge/util/_base_analysis_types.py +24 -0
- accelforge/util/_basetypes.py +1089 -0
- accelforge/util/_frozenset.py +36 -0
- accelforge/util/_isl.py +29 -0
- accelforge/util/_itertools.py +14 -0
- accelforge/util/_mathfuncs.py +57 -0
- accelforge/util/_parse_expressions.py +339 -0
- accelforge/util/_picklecache.py +32 -0
- accelforge/util/_setexpressions.py +268 -0
- accelforge/util/_sympy/__init__.py +0 -0
- accelforge/util/_sympy/broadcast_max.py +18 -0
- accelforge/util/_visualization.py +112 -0
- accelforge/util/_yaml.py +579 -0
- accelforge/util/parallel.py +193 -0
- accelforge-0.0.1.dist-info/METADATA +64 -0
- accelforge-0.0.1.dist-info/RECORD +258 -0
- accelforge-0.0.1.dist-info/WHEEL +5 -0
- accelforge-0.0.1.dist-info/licenses/LICENSE +19 -0
- accelforge-0.0.1.dist-info/top_level.txt +5 -0
- docs/_build/html/_sources/fastfusion.frontend.mapper.rst.txt +37 -0
- docs/_build/html/_sources/fastfusion.frontend.rst.txt +70 -0
- docs/_build/html/_sources/fastfusion.frontend.workload.rst.txt +21 -0
- docs/_build/html/_sources/fastfusion.mapper.FFM.rst.txt +37 -0
- docs/_build/html/_sources/fastfusion.mapper.rst.txt +18 -0
- docs/_build/html/_sources/fastfusion.rst.txt +20 -0
- docs/_build/html/_sources/fastfusion.util.rst.txt +21 -0
- docs/_build/html/_sources/index.rst.txt +87 -0
- docs/_build/html/_sources/modules.rst.txt +7 -0
- docs/_build/html/_sources/notes/citation.rst.txt +45 -0
- docs/_build/html/_sources/notes/definitions.rst.txt +43 -0
- docs/_build/html/_sources/notes/faqs.rst.txt +39 -0
- docs/_build/html/_sources/notes/modeling/accelerator_energy_latency.rst.txt +72 -0
- docs/_build/html/_sources/notes/modeling/component_energy_area.rst.txt +96 -0
- docs/_build/html/_sources/notes/modeling/mapping.rst.txt +100 -0
- docs/_build/html/_sources/notes/modeling.rst.txt +33 -0
- docs/_build/html/_sources/notes/parsing/arithmetic_parsing.rst.txt +136 -0
- docs/_build/html/_sources/notes/parsing/setexpressions.rst.txt +63 -0
- docs/_build/html/_sources/notes/parsing/yaml_parsing.rst.txt +176 -0
- docs/_build/html/_sources/notes/quickstart_and_installation.rst.txt +9 -0
- docs/_build/html/_sources/notes/spec/architecture.rst.txt +133 -0
- docs/_build/html/_sources/notes/spec/mapping.rst.txt +12 -0
- docs/_build/html/_sources/notes/spec/workload.rst.txt +83 -0
- docs/_build/html/_sources/notes/spec.rst.txt +36 -0
- docs/source/_ext/include_attrs.py +213 -0
- docs/source/_ext/include_docstring.py +364 -0
- docs/source/_ext/include_functions.py +154 -0
- docs/source/_ext/include_notebook.py +131 -0
- docs/source/_ext/include_yaml.py +119 -0
- docs/source/_ext/inherited_attributes.py +222 -0
- docs/source/_ext/paths.py +4 -0
- docs/source/conf.py +79 -0
- examples/arches/compute_in_memory/_include.yaml +74 -0
- examples/arches/compute_in_memory/_include_functions.py +229 -0
- examples/arches/compute_in_memory/_load_spec.py +57 -0
- examples/arches/compute_in_memory/components/c2c_multiplier.py +181 -0
- examples/arches/compute_in_memory/components/dac_c2c_r2r.py +605 -0
- examples/arches/compute_in_memory/components/misc.py +195 -0
- examples/arches/compute_in_memory/components/util/bit_functions.py +51 -0
- examples/arches/compute_in_memory/components/zero_comparator.py +92 -0
- examples/arches/compute_in_memory/isaac.yaml +233 -0
- examples/arches/compute_in_memory/memory_cells/ecram_demo.yaml +63 -0
- examples/arches/compute_in_memory/memory_cells/rram_example.yaml +63 -0
- examples/arches/compute_in_memory/memory_cells/rram_isaac_isca_2016.yaml +64 -0
- examples/arches/compute_in_memory/memory_cells/rram_neurosim_default.yaml +63 -0
- examples/arches/compute_in_memory/memory_cells/rram_raella_isca_2023.yaml +70 -0
- examples/arches/compute_in_memory/memory_cells/rram_wan_nature_2022.yaml +63 -0
- examples/arches/compute_in_memory/memory_cells/sram_colonnade_jssc_2021.yaml +63 -0
- examples/arches/compute_in_memory/memory_cells/sram_example.yaml +63 -0
- examples/arches/compute_in_memory/memory_cells/sram_jia_jssc_2020.yaml +63 -0
- examples/arches/compute_in_memory/memory_cells/sram_sinangil_jssc_2021.yaml +63 -0
- examples/arches/compute_in_memory/memory_cells/sram_wang_vlsi_2022.yaml +63 -0
- examples/arches/compute_in_memory/wang_vlsi_2022.yaml +289 -0
- examples/arches/eyeriss.yaml +68 -0
- examples/arches/fanout_variations/at_glb.yaml +31 -0
- examples/arches/fanout_variations/at_glb_with_fanout_node.yaml +34 -0
- examples/arches/fanout_variations/at_mac.yaml +31 -0
- examples/arches/fanout_variations/at_mac_with_constraints.yaml +38 -0
- examples/arches/fanout_variations/at_mac_with_fanout_node.yaml +34 -0
- examples/arches/nvdla.yaml +47 -0
- examples/arches/simple.yaml +28 -0
- examples/arches/tpu_v4i.yaml +67 -0
- examples/mappings/unfused_matmuls_to_simple.yaml +33 -0
- examples/misc/component_annotated.yaml +33 -0
- examples/workloads/gpt3_6.7B.yaml +124 -0
- examples/workloads/matmuls.yaml +20 -0
- examples/workloads/mobilenet_28.yaml +81 -0
- examples/workloads/mobilenet_various_separate.yaml +106 -0
- examples/workloads/three_matmuls_annotated.yaml +59 -0
- notebooks/.ipynb_checkpoints/fastfusion_arch_study_michael-checkpoint.ipynb +359 -0
- notebooks/compute_in_memory/_scripts.py +339 -0
- notebooks/compute_in_memory/isaac.guide.ipynb +270 -0
- notebooks/compute_in_memory/wang_vlsi_2022.ipynb +602 -0
- notebooks/paths.py +4 -0
- notebooks/tutorials/.ipynb_checkpoints/1_FFM-checkpoint.ipynb +3110 -0
- notebooks/tutorials/FFM.ipynb +3498 -0
- notebooks/tutorials/_include.py +48 -0
- notebooks/tutorials/component_energy_area.ipynb +363 -0
- tests/Q_mapping.yaml +38 -0
- tests/__init__.py +0 -0
- tests/conv.mapping.yaml +27 -0
- tests/conv.workload.yaml +13 -0
- tests/conv_sym.mapping.yaml +43 -0
- tests/copy.mapping.yaml +35 -0
- tests/copy.workload.yaml +15 -0
- tests/distribuffers/__init__.py +0 -0
- tests/distribuffers/multicast/test_cases.yaml +482 -0
- tests/distribuffers/spec/binding/valid_bindings.yaml +97 -0
- tests/distribuffers/spec/distributed.yaml +100 -0
- tests/distribuffers/spec/logical_arch.yaml +32 -0
- tests/distribuffers/spec/physical_arch.yaml +69 -0
- tests/distribuffers/test_binding.py +48 -0
- tests/frontend/__init__.py +0 -0
- tests/frontend/test_mapping_viz.py +52 -0
- tests/mapper/__init__.py +0 -0
- tests/mapper/configs/conv1d/conv1d.mapping.yaml +31 -0
- tests/mapper/configs/conv1d/conv1d.workload.yaml +11 -0
- tests/mapper/configs/two_conv1d/two_conv1d.expected.yaml +38 -0
- tests/mapper/configs/two_conv1d/two_conv1d.mapping.yaml +54 -0
- tests/mapper/configs/two_conv1d/two_conv1d.workload.yaml +19 -0
- tests/mapper/test_mapping_to_isl.py +90 -0
- tests/mapper/test_spatial_reuse_analysis.py +67 -0
- tests/mapper/test_temporal_reuse_analysis.py +56 -0
- tests/mapper/util.py +58 -0
- tests/matmul.mapping.yaml +29 -0
- tests/matmul.workload.yaml +12 -0
- tests/matmul_spatial.mapping.yaml +44 -0
- tests/mha.renames.yaml +65 -0
- tests/mha.workload.yaml +67 -0
- tests/mha.yaml +59 -0
- tests/mha_full.workload.yaml +67 -0
- tests/mobilenet.workload.yaml +35 -0
- tests/mobilenet_long.workload.yaml +64 -0
- tests/pmappingcache.py +24 -0
- tests/processing_stage.arch.yaml +40 -0
- tests/snowcat.arch.yaml +36 -0
- tests/test_ffm_join_pmappings.py +106 -0
- tests/test_ffm_make_pmappings.py +82 -0
- tests/test_ffm_make_tile_shapes.py +49 -0
- tests/test_mapper.py +100 -0
- tests/test_model.py +37 -0
- tests/test_plotting.py +72 -0
- tests/test_processing_stage.py +46 -0
- tests/test_symbolic_model.py +248 -0
- tests/test_workload.py +141 -0
|
@@ -0,0 +1,1408 @@
|
|
|
1
|
+
import copy
|
|
2
|
+
from dataclasses import dataclass, field
|
|
3
|
+
import itertools
|
|
4
|
+
from accelforge.frontend.mapping import (
|
|
5
|
+
Compute,
|
|
6
|
+
Mapping,
|
|
7
|
+
Nested,
|
|
8
|
+
Pipeline,
|
|
9
|
+
ProcessingStage,
|
|
10
|
+
Reservation,
|
|
11
|
+
Sequential,
|
|
12
|
+
Spatial,
|
|
13
|
+
Split,
|
|
14
|
+
Storage,
|
|
15
|
+
Temporal,
|
|
16
|
+
)
|
|
17
|
+
from typing import Any
|
|
18
|
+
|
|
19
|
+
from accelforge.frontend import arch
|
|
20
|
+
import accelforge.frontend.mapping as mapping_spec
|
|
21
|
+
from accelforge.frontend.mapping import (
|
|
22
|
+
Mapping,
|
|
23
|
+
MappingNode,
|
|
24
|
+
Nested,
|
|
25
|
+
Spatial,
|
|
26
|
+
Temporal,
|
|
27
|
+
Storage,
|
|
28
|
+
Reservation,
|
|
29
|
+
Loop,
|
|
30
|
+
TensorHolder,
|
|
31
|
+
ProcessingStage,
|
|
32
|
+
)
|
|
33
|
+
from accelforge.frontend.workload import (
|
|
34
|
+
Workload,
|
|
35
|
+
TensorName,
|
|
36
|
+
isl_expression_has_variable,
|
|
37
|
+
)
|
|
38
|
+
from accelforge.frontend._workload_isl._isl import get_rank_variable_bounds
|
|
39
|
+
from accelforge.frontend._workload_isl._symbolic import (
|
|
40
|
+
get_projection_expr,
|
|
41
|
+
get_rank_variable_relevancy,
|
|
42
|
+
compute_dense_tile_occupancy,
|
|
43
|
+
Irrelevant,
|
|
44
|
+
Relevant,
|
|
45
|
+
PartiallyRelevant,
|
|
46
|
+
)
|
|
47
|
+
|
|
48
|
+
from accelforge.model._looptree.types import Buffet
|
|
49
|
+
|
|
50
|
+
from accelforge.mapper.FFM._make_pmappings.pmapper_job import Job
|
|
51
|
+
from accelforge.util._sympy.broadcast_max import Min, Max
|
|
52
|
+
|
|
53
|
+
import sympy
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
SYMBOL = "symbol"
|
|
57
|
+
IMPERFECT = False
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
@dataclass(eq=True, frozen=True)
|
|
61
|
+
class Compute:
|
|
62
|
+
einsum: str
|
|
63
|
+
level: str
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
class Uninitialized:
|
|
67
|
+
def __init__(self):
|
|
68
|
+
pass
|
|
69
|
+
|
|
70
|
+
def __str__(self):
|
|
71
|
+
return "Uninitialized"
|
|
72
|
+
|
|
73
|
+
def __repr__(self):
|
|
74
|
+
return "Uninitialized()"
|
|
75
|
+
|
|
76
|
+
def __rmul__(self, other):
|
|
77
|
+
return self * other
|
|
78
|
+
|
|
79
|
+
def __mul__(self, other):
|
|
80
|
+
return self
|
|
81
|
+
|
|
82
|
+
def __radd__(self, other):
|
|
83
|
+
return self + other
|
|
84
|
+
|
|
85
|
+
def __add__(self, other):
|
|
86
|
+
return self
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
# TODO: unsure if this is needed. If the sympy symbol is created with the
|
|
90
|
+
# correct assumption (e.g., positive), this should be automatic.
|
|
91
|
+
def min_nonzero(a: Any, b: Any) -> Any:
|
|
92
|
+
if a == 0:
|
|
93
|
+
return b
|
|
94
|
+
if b == 0:
|
|
95
|
+
return a
|
|
96
|
+
return Min(a, b)
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
def max_dict(a: dict[Any, Any], b: dict[Any, Any]) -> dict[Any, Any]:
|
|
100
|
+
new = {**a}
|
|
101
|
+
for key, value in b.items():
|
|
102
|
+
new[key] = Max(new[key], value) if key in new else value
|
|
103
|
+
assert isinstance(new, dict)
|
|
104
|
+
return new
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
@dataclass
|
|
108
|
+
class BuffetStats:
|
|
109
|
+
total_reads_to_parent: Any = field(default=0)
|
|
110
|
+
total_writes_to_parent: Any = field(default=0)
|
|
111
|
+
max_per_parent_reads_to_parent: Any = field(default=0)
|
|
112
|
+
max_per_parent_writes_to_parent: Any = field(default=0)
|
|
113
|
+
|
|
114
|
+
total_reads_to_peer: Any = field(default=0)
|
|
115
|
+
total_writes_to_peer: Any = field(default=0)
|
|
116
|
+
max_per_unit_reads_to_peer: Any = field(default=0)
|
|
117
|
+
max_per_unit_writes_to_peer: Any = field(default=0)
|
|
118
|
+
|
|
119
|
+
total_writes_to_child: Any = field(default=0)
|
|
120
|
+
total_reads_to_child: Any = field(default=0)
|
|
121
|
+
max_per_unit_writes_to_child: Any = field(default=0)
|
|
122
|
+
max_per_unit_reads_to_child: Any = field(default=0)
|
|
123
|
+
|
|
124
|
+
# Skip the first iteration of temporal loops for data that is written
|
|
125
|
+
total_skipped_first_reads_to_parent: Any = field(default=0)
|
|
126
|
+
total_skipped_first_reads_to_peer: Any = field(default=0)
|
|
127
|
+
total_skipped_first_writes_to_child: Any = field(default=0)
|
|
128
|
+
min_per_parent_skipped_first_reads_to_parent: Any = field(default=0)
|
|
129
|
+
min_per_unit_skipped_first_writes_to_peer: Any = field(default=0)
|
|
130
|
+
min_per_unit_skipped_first_writes_to_child: Any = field(default=0)
|
|
131
|
+
|
|
132
|
+
max_occupancy: Any = field(default=0)
|
|
133
|
+
_n_loops_above: int = field(default=1)
|
|
134
|
+
|
|
135
|
+
persistent: bool = field(default=False)
|
|
136
|
+
|
|
137
|
+
_write_scale: float = field(default=None)
|
|
138
|
+
_read_scale: float = field(default=None)
|
|
139
|
+
_count_upward_movement: bool = field(default=None)
|
|
140
|
+
_count_downward_movement: bool = field(default=None)
|
|
141
|
+
|
|
142
|
+
@property
|
|
143
|
+
def write_scale(self) -> Any:
|
|
144
|
+
return self._write_scale
|
|
145
|
+
|
|
146
|
+
@write_scale.setter
|
|
147
|
+
def write_scale(self, value: Any):
|
|
148
|
+
assert self._write_scale is None or self._write_scale == value, "BUG"
|
|
149
|
+
self._write_scale = value
|
|
150
|
+
|
|
151
|
+
@property
|
|
152
|
+
def read_scale(self) -> Any:
|
|
153
|
+
return self._read_scale
|
|
154
|
+
|
|
155
|
+
@read_scale.setter
|
|
156
|
+
def read_scale(self, value: Any):
|
|
157
|
+
assert self._read_scale is None or self._read_scale == value, "BUG"
|
|
158
|
+
self._read_scale = value
|
|
159
|
+
|
|
160
|
+
@property
|
|
161
|
+
def count_upward_movement(self) -> bool:
|
|
162
|
+
return self._count_upward_movement
|
|
163
|
+
|
|
164
|
+
@count_upward_movement.setter
|
|
165
|
+
def count_upward_movement(self, value: bool):
|
|
166
|
+
assert (
|
|
167
|
+
self._count_upward_movement is None or self._count_upward_movement == value
|
|
168
|
+
), "BUG"
|
|
169
|
+
self._count_upward_movement = value
|
|
170
|
+
|
|
171
|
+
@property
|
|
172
|
+
def count_downward_movement(self) -> bool:
|
|
173
|
+
return self._count_downward_movement
|
|
174
|
+
|
|
175
|
+
@count_downward_movement.setter
|
|
176
|
+
def count_downward_movement(self, value: bool):
|
|
177
|
+
assert (
|
|
178
|
+
self._count_downward_movement is None
|
|
179
|
+
or self._count_downward_movement == value
|
|
180
|
+
), "BUG"
|
|
181
|
+
self._count_downward_movement = value
|
|
182
|
+
|
|
183
|
+
@property
|
|
184
|
+
def n_loops_above(self) -> int:
|
|
185
|
+
if self.persistent:
|
|
186
|
+
return -1
|
|
187
|
+
return self._n_loops_above
|
|
188
|
+
|
|
189
|
+
@n_loops_above.setter
|
|
190
|
+
def n_loops_above(self, value: int):
|
|
191
|
+
self._n_loops_above = value
|
|
192
|
+
|
|
193
|
+
def repeat_temporal(self, factor: int, is_fully_relevant: bool) -> "BuffetStats":
|
|
194
|
+
new = copy.copy(self)
|
|
195
|
+
for attr in self.__dict__:
|
|
196
|
+
if not attr.startswith(("total_", "max_", "min_")):
|
|
197
|
+
continue
|
|
198
|
+
if "skipped_first" in attr and not is_fully_relevant:
|
|
199
|
+
continue # First actions occur once per relevant iteration.
|
|
200
|
+
if attr == "max_occupancy":
|
|
201
|
+
continue # Max occupancy is not affected by temporal loops above
|
|
202
|
+
setattr(new, attr, getattr(new, attr) * factor)
|
|
203
|
+
return new
|
|
204
|
+
|
|
205
|
+
def repeat_spatial(self, factor: int, reuse_parent_accesses: bool) -> "BuffetStats":
|
|
206
|
+
new = copy.copy(self)
|
|
207
|
+
for attr in self.__dict__:
|
|
208
|
+
if not attr.startswith(("total_", "max_", "min_")):
|
|
209
|
+
continue
|
|
210
|
+
if "parent" in attr and reuse_parent_accesses:
|
|
211
|
+
continue # If parent accesses are reused, no need to multiply
|
|
212
|
+
if "per_unit" in attr:
|
|
213
|
+
continue # Spatial fanout doesn't affect per-unit stats
|
|
214
|
+
if attr == "max_occupancy":
|
|
215
|
+
continue # Max occupancy is not affected by temporal loops above
|
|
216
|
+
setattr(new, attr, getattr(new, attr) * factor)
|
|
217
|
+
return new
|
|
218
|
+
|
|
219
|
+
def max(self, **kwargs: Any):
|
|
220
|
+
for key, value in kwargs.items():
|
|
221
|
+
setattr(self, key, Max(getattr(self, key), value))
|
|
222
|
+
|
|
223
|
+
def min(self, **kwargs: Any):
|
|
224
|
+
for key, value in kwargs.items():
|
|
225
|
+
setattr(self, key, Min(getattr(self, key), value))
|
|
226
|
+
|
|
227
|
+
def __add__(self, other: "BuffetStats") -> "BuffetStats":
|
|
228
|
+
new = copy.copy(self)
|
|
229
|
+
for attr in self.__dict__:
|
|
230
|
+
if attr.startswith("min_"):
|
|
231
|
+
setattr(
|
|
232
|
+
new, attr, min_nonzero(getattr(self, attr), getattr(other, attr))
|
|
233
|
+
)
|
|
234
|
+
elif attr.startswith("max_"):
|
|
235
|
+
setattr(new, attr, Max(getattr(self, attr), getattr(other, attr)))
|
|
236
|
+
elif attr.startswith("total_"):
|
|
237
|
+
setattr(new, attr, getattr(self, attr) + getattr(other, attr))
|
|
238
|
+
elif getattr(self, attr) is None:
|
|
239
|
+
setattr(new, attr, getattr(other, attr))
|
|
240
|
+
elif getattr(other, attr) is None:
|
|
241
|
+
setattr(new, attr, getattr(self, attr))
|
|
242
|
+
else:
|
|
243
|
+
assert getattr(self, attr) == getattr(
|
|
244
|
+
other, attr
|
|
245
|
+
), f"BUG: {attr} is different. self: {getattr(self, attr)} other: {getattr(other, attr)}"
|
|
246
|
+
return new
|
|
247
|
+
|
|
248
|
+
def __iadd__(self, other: "BuffetStats") -> "BuffetStats":
|
|
249
|
+
new = self + other
|
|
250
|
+
for key, value in new.__dict__.items():
|
|
251
|
+
setattr(self, key, value)
|
|
252
|
+
return self
|
|
253
|
+
|
|
254
|
+
def net_total_read_actions(self) -> Any:
|
|
255
|
+
return self.total_read_actions - self.total_skipped_first_read_actions
|
|
256
|
+
|
|
257
|
+
def net_total_write_actions(self) -> Any:
|
|
258
|
+
return self.total_write_actions - self.total_skipped_first_write_actions
|
|
259
|
+
|
|
260
|
+
def net_max_per_unit_read_actions(self) -> Any:
|
|
261
|
+
return (
|
|
262
|
+
self.max_per_unit_read_actions
|
|
263
|
+
- self.min_per_unit_skipped_first_read_actions
|
|
264
|
+
)
|
|
265
|
+
|
|
266
|
+
def net_max_per_unit_write_actions(self) -> Any:
|
|
267
|
+
return (
|
|
268
|
+
self.max_per_unit_write_actions
|
|
269
|
+
- self.min_per_unit_skipped_first_write_actions
|
|
270
|
+
)
|
|
271
|
+
|
|
272
|
+
def _get_actions(
|
|
273
|
+
self,
|
|
274
|
+
prefix: str,
|
|
275
|
+
) -> float | sympy.Expr:
|
|
276
|
+
# My reads to parent go down (parent->me), my reads to child go up (child->me)
|
|
277
|
+
if "read" in prefix:
|
|
278
|
+
parent, child = self.count_downward_movement, self.count_upward_movement
|
|
279
|
+
scale = self.write_scale # Write to other = read to self
|
|
280
|
+
# My writes to parent go up (me->parent), my writes to child go down (me->child)
|
|
281
|
+
elif "write" in prefix:
|
|
282
|
+
parent, child = self.count_upward_movement, self.count_downward_movement
|
|
283
|
+
scale = self.read_scale # Write to other = read to self
|
|
284
|
+
else:
|
|
285
|
+
raise ValueError(f"Invalid prefix: {prefix}")
|
|
286
|
+
|
|
287
|
+
total = getattr(self, f"{prefix.replace('write', 'read')}s_to_peer", 0)
|
|
288
|
+
total += getattr(self, f"{prefix}s_to_parent", 0) if parent else 0
|
|
289
|
+
total += getattr(self, f"{prefix}s_to_child", 0) if child else 0
|
|
290
|
+
|
|
291
|
+
return total * scale
|
|
292
|
+
|
|
293
|
+
@property
|
|
294
|
+
def total_write_actions(self):
|
|
295
|
+
return self._get_actions("total_read") # Read to other = write to self
|
|
296
|
+
|
|
297
|
+
@property
|
|
298
|
+
def total_read_actions(self):
|
|
299
|
+
return self._get_actions("total_write") # Write to other = read to self
|
|
300
|
+
|
|
301
|
+
@property
|
|
302
|
+
def max_per_unit_write_actions(self):
|
|
303
|
+
return self._get_actions("max_per_unit_read") # Read to other = write to self
|
|
304
|
+
|
|
305
|
+
@property
|
|
306
|
+
def max_per_unit_read_actions(self):
|
|
307
|
+
return self._get_actions("max_per_unit_write") # Write to other = read to self
|
|
308
|
+
|
|
309
|
+
@property
|
|
310
|
+
def total_skipped_first_write_actions(self):
|
|
311
|
+
# Read to other = write to self
|
|
312
|
+
return self._get_actions("total_skipped_first_read")
|
|
313
|
+
|
|
314
|
+
@property
|
|
315
|
+
def min_per_unit_skipped_first_write_actions(self):
|
|
316
|
+
# Read to other = write to self
|
|
317
|
+
return self._get_actions("min_per_unit_skipped_first_read")
|
|
318
|
+
|
|
319
|
+
@property
|
|
320
|
+
def total_skipped_first_read_actions(self):
|
|
321
|
+
# Write to other = read to self
|
|
322
|
+
return self._get_actions("total_skipped_first_write")
|
|
323
|
+
|
|
324
|
+
@property
|
|
325
|
+
def min_per_unit_skipped_first_read_actions(self):
|
|
326
|
+
# Write to other = read to self
|
|
327
|
+
return self._get_actions("min_per_unit_skipped_first_write")
|
|
328
|
+
|
|
329
|
+
|
|
330
|
+
def blank_buffet_stats() -> BuffetStats:
|
|
331
|
+
stats = BuffetStats()
|
|
332
|
+
stats.n_loops_above = None # Inherit from whoever is added to this
|
|
333
|
+
return stats
|
|
334
|
+
|
|
335
|
+
|
|
336
|
+
@dataclass
|
|
337
|
+
class ComputeStats:
|
|
338
|
+
total_ops: Any = field(default=0)
|
|
339
|
+
max_per_unit_ops: Any = field(default=0)
|
|
340
|
+
# "max" below refers to the longest latency of any iteration
|
|
341
|
+
max_latency: Any = field(default=0)
|
|
342
|
+
# Mapping from the loop-index (0 at top) to the latency of the first
|
|
343
|
+
# iteration of that loop. "Max" because we may have loops above that and we
|
|
344
|
+
# will take the maximum of the firsts.
|
|
345
|
+
max_first_latency: dict[int, Any] = field(default_factory=dict)
|
|
346
|
+
|
|
347
|
+
def repeat_temporal(self, factor: int) -> "ComputeStats":
|
|
348
|
+
new = copy.copy(self)
|
|
349
|
+
new.total_ops *= factor
|
|
350
|
+
new.max_per_unit_ops *= factor
|
|
351
|
+
new.max_latency *= factor
|
|
352
|
+
# NOTE: max_first_latency does not change
|
|
353
|
+
return new
|
|
354
|
+
|
|
355
|
+
def repeat_spatial(self, factor: int) -> "ComputeStats":
|
|
356
|
+
new = copy.copy(self)
|
|
357
|
+
new.total_ops *= factor
|
|
358
|
+
return new
|
|
359
|
+
|
|
360
|
+
def __add__(self, other: "ComputeStats") -> "ComputeStats":
|
|
361
|
+
new = copy.copy(self)
|
|
362
|
+
new.total_ops += other.total_ops
|
|
363
|
+
new.max_per_unit_ops += other.max_per_unit_ops
|
|
364
|
+
new.max_latency += other.max_latency
|
|
365
|
+
# max_first_latency is only ever updated across loops ABOVE the loop
|
|
366
|
+
# for which we calculated that first latency, so we should MAX
|
|
367
|
+
new.max_first_latency = max_dict(
|
|
368
|
+
self.max_first_latency, other.max_first_latency
|
|
369
|
+
) # FIRST LATENCY
|
|
370
|
+
return new
|
|
371
|
+
|
|
372
|
+
def combine_temporal(self, other: "ComputeStats"):
|
|
373
|
+
self.total_ops += other.total_ops
|
|
374
|
+
self.max_per_unit_ops += other.max_per_unit_ops
|
|
375
|
+
self.max_latency += other.max_latency
|
|
376
|
+
# max_first_latency is only ever updated across loops ABOVE the loop
|
|
377
|
+
# for which we calculated that first latency, so we should MAX
|
|
378
|
+
self.max_first_latency = max_dict(
|
|
379
|
+
self.max_first_latency, other.max_first_latency
|
|
380
|
+
) # FIRST LATENCY
|
|
381
|
+
|
|
382
|
+
def combine_spatial(self, other: "ComputeStats"):
|
|
383
|
+
self.total_ops += other.total_ops
|
|
384
|
+
self.max_per_unit_ops = Max(self.max_per_unit_ops, other.max_per_unit_ops)
|
|
385
|
+
self.max_latency = Max(self.max_latency, other.max_latency)
|
|
386
|
+
# max_first_latency is only ever updated across loops ABOVE the loop
|
|
387
|
+
# for which we calculated that first latency, so we should MAX
|
|
388
|
+
self.max_first_latency = max_dict(
|
|
389
|
+
self.max_first_latency, other.max_first_latency
|
|
390
|
+
) # FIRST LATENCY
|
|
391
|
+
|
|
392
|
+
|
|
393
|
+
@dataclass
|
|
394
|
+
class SymbolicAnalysisOutput:
|
|
395
|
+
compute_stats: dict[Compute, ComputeStats] = field(default_factory=dict)
|
|
396
|
+
|
|
397
|
+
buffet_stats: dict[Buffet, BuffetStats] = field(default_factory=dict)
|
|
398
|
+
|
|
399
|
+
# Mapping [level, einsum] to the fanout
|
|
400
|
+
fanout: dict[(Buffet, str), int] = field(default_factory=dict)
|
|
401
|
+
|
|
402
|
+
# Mapping [einsum] to the number of temporal steps
|
|
403
|
+
temporal_steps: dict[str, int] = field(default_factory=dict)
|
|
404
|
+
|
|
405
|
+
symbols: list[sympy.Symbol] = field(default_factory=list)
|
|
406
|
+
|
|
407
|
+
# tensor to the mapping for that particular tensor
|
|
408
|
+
tensor2mapping: dict[TensorName, Mapping] = field(default_factory=dict)
|
|
409
|
+
|
|
410
|
+
def get_buffet_for_tensor(self, tensor: TensorName) -> Buffet:
|
|
411
|
+
for buffet in self.buffet_stats:
|
|
412
|
+
if buffet.tensor == tensor:
|
|
413
|
+
return buffet
|
|
414
|
+
raise ValueError(f"Buffet for tensor {tensor} not found")
|
|
415
|
+
|
|
416
|
+
def max(self, **kwargs: Any):
|
|
417
|
+
for key, value in kwargs.items():
|
|
418
|
+
assert key in [
|
|
419
|
+
"compute_stats",
|
|
420
|
+
"stats",
|
|
421
|
+
"fanout",
|
|
422
|
+
"temporal_steps",
|
|
423
|
+
]
|
|
424
|
+
previous = getattr(self, key)
|
|
425
|
+
for k, v in value.items():
|
|
426
|
+
previous.setdefault(k, {})
|
|
427
|
+
for k2, v2 in v.items():
|
|
428
|
+
if k2 in previous[k]:
|
|
429
|
+
previous[k][k2] = Max(previous[k][k2], v2)
|
|
430
|
+
else:
|
|
431
|
+
previous[k][k2] = v2
|
|
432
|
+
|
|
433
|
+
def get_child_buffet_stats(self, buffet: Buffet) -> BuffetStats:
|
|
434
|
+
seen = False
|
|
435
|
+
for child_buffet, child_stats in reversed(self.buffet_stats.items()):
|
|
436
|
+
if not seen:
|
|
437
|
+
seen = child_buffet == buffet
|
|
438
|
+
continue
|
|
439
|
+
if child_buffet.tensor == buffet.tensor:
|
|
440
|
+
return child_stats
|
|
441
|
+
return None
|
|
442
|
+
|
|
443
|
+
def sum_buffet_stats_per_level(self) -> dict[str, BuffetStats]:
|
|
444
|
+
result: dict[str, BuffetStats] = {}
|
|
445
|
+
for buffet, stats in self.buffet_stats.items():
|
|
446
|
+
result.setdefault(buffet.level, blank_buffet_stats())
|
|
447
|
+
result[buffet.level] += stats
|
|
448
|
+
return result
|
|
449
|
+
|
|
450
|
+
def add_buffet_stats_and_symbols(self, other: "SymbolicAnalysisOutput"):
|
|
451
|
+
assert not (set(self.buffet_stats) & set(other.buffet_stats)), "BUG"
|
|
452
|
+
self.buffet_stats.update(other.buffet_stats)
|
|
453
|
+
# if self.temporal_steps != other.temporal_steps:
|
|
454
|
+
# print(f'Temporal steps are different.')
|
|
455
|
+
# print(f'\tmine: {self.temporal_steps}')
|
|
456
|
+
# print(f'\tother: {other.temporal_steps}')
|
|
457
|
+
# assert self.temporal_steps == other.temporal_steps, "BUG"
|
|
458
|
+
self.temporal_steps.update(other.temporal_steps)
|
|
459
|
+
self.symbols.extend([s for s in other.symbols if s not in self.symbols])
|
|
460
|
+
# Assert compute stats are the same
|
|
461
|
+
# assert self.compute_stats == other.compute_stats, "BUG"
|
|
462
|
+
|
|
463
|
+
|
|
464
|
+
@dataclass
|
|
465
|
+
class AnalysisInfo:
|
|
466
|
+
"""Information needed within the analysis step by multiple functions that
|
|
467
|
+
can be computed once at the beginning.
|
|
468
|
+
"""
|
|
469
|
+
|
|
470
|
+
mapping: Mapping
|
|
471
|
+
workload: Workload
|
|
472
|
+
full_rank_variable_shapes: dict
|
|
473
|
+
all_tensors: set
|
|
474
|
+
|
|
475
|
+
einsum_tensor_to_projection: dict
|
|
476
|
+
tensor_to_relevancy: dict
|
|
477
|
+
tensor_to_backer_id: dict[TensorName, int]
|
|
478
|
+
|
|
479
|
+
is_copy_operation: TensorName | None
|
|
480
|
+
|
|
481
|
+
job: Job
|
|
482
|
+
|
|
483
|
+
tensor_to_reservation_backer_id: dict[TensorName, int] = field(default_factory=dict)
|
|
484
|
+
|
|
485
|
+
# We track first latency for these nodes (should be Temporal)
|
|
486
|
+
last_temporal_node_idx: int = None
|
|
487
|
+
"""
|
|
488
|
+
node idx of the last (above) temporal node
|
|
489
|
+
"""
|
|
490
|
+
idxs_to_track_first_latency: set[int] = field(default_factory=set)
|
|
491
|
+
"""
|
|
492
|
+
node idxs for which we track first latency
|
|
493
|
+
"""
|
|
494
|
+
|
|
495
|
+
|
|
496
|
+
def quick_insert_reservation_nodes(job: Job) -> list[MappingNode]:
|
|
497
|
+
mapping = list(job.mapping.nodes)
|
|
498
|
+
workload = job.spec.workload
|
|
499
|
+
|
|
500
|
+
# TODO: Subclass reservation with TensorReservation or something so that we can
|
|
501
|
+
# track which are for tensors and which are for non-tensor resources.
|
|
502
|
+
|
|
503
|
+
info = AnalysisInfo(
|
|
504
|
+
mapping=None,
|
|
505
|
+
workload=workload,
|
|
506
|
+
full_rank_variable_shapes=None,
|
|
507
|
+
all_tensors=None,
|
|
508
|
+
einsum_tensor_to_projection=None,
|
|
509
|
+
tensor_to_relevancy=job.tensor_to_relevancy,
|
|
510
|
+
tensor_to_backer_id=None,
|
|
511
|
+
is_copy_operation=None,
|
|
512
|
+
job=None,
|
|
513
|
+
)
|
|
514
|
+
insert_reservation_nodes(mapping, info)
|
|
515
|
+
m = Mapping(nodes=mapping)
|
|
516
|
+
m._n_loop_orders = job.mapping._n_loop_orders
|
|
517
|
+
return m
|
|
518
|
+
|
|
519
|
+
|
|
520
|
+
def convert_to_copy(
|
|
521
|
+
mapping: list[MappingNode], workload: Workload
|
|
522
|
+
) -> tuple[list[MappingNode], dict[TensorName, int]]:
|
|
523
|
+
mapping = copy.deepcopy(mapping)
|
|
524
|
+
|
|
525
|
+
# Calculate this BEFORE we modify the mapping. We're going to have the copy source
|
|
526
|
+
# tensor moving upward sometimes, and we don't want the backing tensor holder
|
|
527
|
+
tensor_to_backer_id = get_tensor_to_backer_id(mapping)
|
|
528
|
+
|
|
529
|
+
first_input_tensor = workload.einsums[mapping[-1].einsum].copy_source_tensor()
|
|
530
|
+
|
|
531
|
+
for node in mapping:
|
|
532
|
+
if isinstance(node, TensorHolder):
|
|
533
|
+
if node.tensors:
|
|
534
|
+
node.tensors = [first_input_tensor]
|
|
535
|
+
node._lower = False
|
|
536
|
+
|
|
537
|
+
to_remove = []
|
|
538
|
+
i = 0
|
|
539
|
+
while i < len(mapping):
|
|
540
|
+
node = mapping[i]
|
|
541
|
+
if isinstance(node, TensorHolder):
|
|
542
|
+
j = i + 1
|
|
543
|
+
while j < len(mapping):
|
|
544
|
+
node2 = mapping[j]
|
|
545
|
+
if (
|
|
546
|
+
isinstance(node2, TensorHolder)
|
|
547
|
+
and node.component == node2.component
|
|
548
|
+
):
|
|
549
|
+
mapping.pop(j)
|
|
550
|
+
else:
|
|
551
|
+
j += 1
|
|
552
|
+
i += 1
|
|
553
|
+
mapping = [node for node in mapping if node not in to_remove]
|
|
554
|
+
|
|
555
|
+
return mapping, tensor_to_backer_id
|
|
556
|
+
|
|
557
|
+
|
|
558
|
+
def analyze_reuse_and_add_reservations_to_mapping(
|
|
559
|
+
job: Job,
|
|
560
|
+
) -> SymbolicAnalysisOutput:
|
|
561
|
+
mapping = job.mapping.nodes
|
|
562
|
+
workload = job.spec.workload
|
|
563
|
+
einsum_name = mapping[-1].einsum
|
|
564
|
+
|
|
565
|
+
is_copy_operation = workload.einsums[einsum_name].is_copy_operation
|
|
566
|
+
symbols = insert_sympy_symbols(job.mapping.nodes, job)
|
|
567
|
+
|
|
568
|
+
if is_copy_operation:
|
|
569
|
+
mapping, tensor_to_backer_id = convert_to_copy(mapping, workload)
|
|
570
|
+
else:
|
|
571
|
+
tensor_to_backer_id = get_tensor_to_backer_id(mapping)
|
|
572
|
+
|
|
573
|
+
job.mapping = quick_insert_reservation_nodes(job)
|
|
574
|
+
# print(f'Job mapping: {job.mapping.compact_str()}')
|
|
575
|
+
# for n in job.mapping.nodes:
|
|
576
|
+
# print(f'\t{n.compact_str()}')
|
|
577
|
+
|
|
578
|
+
einsum_tensor_to_projection = {}
|
|
579
|
+
einsum = workload.einsums[einsum_name]
|
|
580
|
+
all_tensors = einsum.tensor_names
|
|
581
|
+
for tensor in all_tensors:
|
|
582
|
+
einsum_tensor_to_projection[(einsum_name, tensor)] = get_projection_expr(
|
|
583
|
+
einsum, tensor
|
|
584
|
+
)
|
|
585
|
+
tensor_to_relevancy = {
|
|
586
|
+
tensor: get_rank_variable_relevancy(einsum, tensor) for tensor in all_tensors
|
|
587
|
+
}
|
|
588
|
+
assert all_tensors, f"Einsum {einsum_name} has no tensors"
|
|
589
|
+
|
|
590
|
+
"""
|
|
591
|
+
Note for how this works.
|
|
592
|
+
|
|
593
|
+
Spatial loops are weird, because they don't belong at a single point in the loop
|
|
594
|
+
nest. For example:
|
|
595
|
+
|
|
596
|
+
- DRAM keep A, B
|
|
597
|
+
- *
|
|
598
|
+
- Reg keep A
|
|
599
|
+
- for n in [0..N)
|
|
600
|
+
- GLB keep B
|
|
601
|
+
- *
|
|
602
|
+
- Compute
|
|
603
|
+
|
|
604
|
+
A loop spatial-for (Reg) k in [0..K) would affect the register at the point of the
|
|
605
|
+
first asterisk, but the global buffer at the point of the second asterisk.
|
|
606
|
+
|
|
607
|
+
To handle this, we make a separate mapping for each tensor, analyze each, and
|
|
608
|
+
combine the results.
|
|
609
|
+
|
|
610
|
+
To anyone who would like to create behavior that simultaneously looks at multiple
|
|
611
|
+
storage nodes for a given memory, note that there will be two challenges to address:
|
|
612
|
+
|
|
613
|
+
1. The code currently analyzes one tensor at a time. This could be fixed by
|
|
614
|
+
processing all mapping(s) together, applying loop(s) from each to only the
|
|
615
|
+
appropriate nodes.
|
|
616
|
+
2. The code must analyze one storage node at a time, and there may be temporal and
|
|
617
|
+
spatial nodes between two storage nodes for a given memory, which would separate
|
|
618
|
+
the analysis steps for the storage nodes. This may be addressed by only
|
|
619
|
+
performing such analysis until the outermost storage node for a particular memory
|
|
620
|
+
has been analyzed.
|
|
621
|
+
"""
|
|
622
|
+
result = None
|
|
623
|
+
|
|
624
|
+
tensor2mapping = {}
|
|
625
|
+
index_expressions = set(einsum.indexing_expressions)
|
|
626
|
+
for k, v in job.rank_variable_bounds.items():
|
|
627
|
+
index_expressions.add(f"0 < {k} <= {v}")
|
|
628
|
+
for tensor in all_tensors:
|
|
629
|
+
cur_mapping = job.mapping._get_single_tensor_mapping(
|
|
630
|
+
tensor, job.flattened_arch, index_expressions
|
|
631
|
+
)
|
|
632
|
+
info = AnalysisInfo(
|
|
633
|
+
mapping=cur_mapping.nodes,
|
|
634
|
+
workload=workload,
|
|
635
|
+
full_rank_variable_shapes=job.rank_variable_bounds,
|
|
636
|
+
all_tensors=set([tensor]),
|
|
637
|
+
einsum_tensor_to_projection=einsum_tensor_to_projection,
|
|
638
|
+
tensor_to_relevancy=tensor_to_relevancy,
|
|
639
|
+
tensor_to_backer_id=tensor_to_backer_id,
|
|
640
|
+
is_copy_operation=is_copy_operation,
|
|
641
|
+
job=job,
|
|
642
|
+
)
|
|
643
|
+
cur_result = analyze_node(0, job.rank_variable_bounds, info)
|
|
644
|
+
if result is None:
|
|
645
|
+
result = cur_result
|
|
646
|
+
else:
|
|
647
|
+
result.add_buffet_stats_and_symbols(cur_result)
|
|
648
|
+
tensor2mapping[tensor] = cur_mapping
|
|
649
|
+
|
|
650
|
+
result.symbols = symbols
|
|
651
|
+
result.tensor2mapping = tensor2mapping
|
|
652
|
+
return result
|
|
653
|
+
|
|
654
|
+
|
|
655
|
+
def get_tensor_to_backer_id(mapping: Mapping):
|
|
656
|
+
tensor_to_ids: dict[TensorName, set[int]] = {}
|
|
657
|
+
for node in mapping:
|
|
658
|
+
if isinstance(node, TensorHolder):
|
|
659
|
+
for tensor in node.tensors:
|
|
660
|
+
if tensor in tensor_to_ids:
|
|
661
|
+
continue
|
|
662
|
+
tensor_to_ids[tensor] = id(node)
|
|
663
|
+
return tensor_to_ids
|
|
664
|
+
|
|
665
|
+
|
|
666
|
+
class ReservationAnalysisTracker:
|
|
667
|
+
def __init__(self, buffet, node):
|
|
668
|
+
self.buffet: Buffet = buffet
|
|
669
|
+
self.node: TensorHolder = node
|
|
670
|
+
|
|
671
|
+
# These are interface (TODO: should be property)
|
|
672
|
+
self.is_fill_level = False
|
|
673
|
+
self.should_stop = False
|
|
674
|
+
self.insert_reservation_under = False
|
|
675
|
+
self.insert_fill_under = False
|
|
676
|
+
|
|
677
|
+
# Temporary values
|
|
678
|
+
self.has_filled = False
|
|
679
|
+
|
|
680
|
+
def track_temporal_loop(self, relevancy, node):
|
|
681
|
+
self.is_fill_level = False
|
|
682
|
+
self.insert_reservation_under = False
|
|
683
|
+
self.insert_fill_under = False
|
|
684
|
+
|
|
685
|
+
if isinstance(relevancy, Irrelevant):
|
|
686
|
+
if not self.has_filled:
|
|
687
|
+
self.is_fill_level = True
|
|
688
|
+
self.has_filled = True
|
|
689
|
+
|
|
690
|
+
self.should_stop = True
|
|
691
|
+
elif isinstance(relevancy, Relevant):
|
|
692
|
+
self.should_stop = False
|
|
693
|
+
elif isinstance(relevancy, PartiallyRelevant):
|
|
694
|
+
self.last = True
|
|
695
|
+
|
|
696
|
+
if not self.has_filled:
|
|
697
|
+
self.is_fill_level = True
|
|
698
|
+
self.has_filled = True
|
|
699
|
+
|
|
700
|
+
self.should_stop = True
|
|
701
|
+
self.insert_reservation_under = True
|
|
702
|
+
else:
|
|
703
|
+
raise ValueError(f"Unknown relevancy {relevancy}")
|
|
704
|
+
|
|
705
|
+
def track_compute(self):
|
|
706
|
+
self.should_stop = True
|
|
707
|
+
if not self.has_filled:
|
|
708
|
+
self.is_fill_level = True
|
|
709
|
+
self.has_filled = True
|
|
710
|
+
|
|
711
|
+
def track_spatial_loop(self, relevancy, node):
|
|
712
|
+
if node.component != self.buffet.level:
|
|
713
|
+
self.should_stop = True
|
|
714
|
+
if not self.has_filled:
|
|
715
|
+
self.is_fill_level = True
|
|
716
|
+
self.has_filled = True
|
|
717
|
+
return
|
|
718
|
+
|
|
719
|
+
self.is_fill_level = False
|
|
720
|
+
self.should_stop = False
|
|
721
|
+
|
|
722
|
+
|
|
723
|
+
def insert_reservation_nodes(mapping, info: AnalysisInfo):
|
|
724
|
+
trackers: list[ReservationAnalysisTracker] = []
|
|
725
|
+
einsum = info.workload.einsums[mapping[-1].einsum]
|
|
726
|
+
non_intermediate_tensors = (
|
|
727
|
+
einsum.tensor_names - info.workload.tensor_names_used_in_multiple_einsums
|
|
728
|
+
)
|
|
729
|
+
seen_tensors = set() # reservation for top-level buffets cannot be lowered
|
|
730
|
+
|
|
731
|
+
n_nodes = len(mapping)
|
|
732
|
+
i = 0
|
|
733
|
+
while i < n_nodes:
|
|
734
|
+
node = mapping[i]
|
|
735
|
+
to_remove = []
|
|
736
|
+
if isinstance(node, Reservation):
|
|
737
|
+
pass
|
|
738
|
+
elif isinstance(node, Temporal):
|
|
739
|
+
rank = node.rank_variable
|
|
740
|
+
for tracker in trackers:
|
|
741
|
+
relevancy = info.tensor_to_relevancy[tracker.buffet.tensor]
|
|
742
|
+
tracker.track_temporal_loop(relevancy[rank], node)
|
|
743
|
+
elif isinstance(node, Spatial):
|
|
744
|
+
rank = node.rank_variable
|
|
745
|
+
for tracker in trackers:
|
|
746
|
+
relevancy = info.tensor_to_relevancy[tracker.buffet.tensor]
|
|
747
|
+
tracker.track_spatial_loop(relevancy[rank], node)
|
|
748
|
+
elif isinstance(node, TensorHolder):
|
|
749
|
+
for tracker in trackers:
|
|
750
|
+
tracker.should_stop = True
|
|
751
|
+
tracker.insert_reservation_under = False
|
|
752
|
+
for tensor in node.tensors:
|
|
753
|
+
tensor = TensorName(tensor)
|
|
754
|
+
buffet = Buffet(tensor, mapping[-1].einsum, node.component)
|
|
755
|
+
trackers.append(ReservationAnalysisTracker(buffet, node))
|
|
756
|
+
if not node._lower or (
|
|
757
|
+
tensor not in seen_tensors and tensor in non_intermediate_tensors
|
|
758
|
+
):
|
|
759
|
+
seen_tensors.add(tensor)
|
|
760
|
+
trackers[-1].is_fill_level = True
|
|
761
|
+
trackers[-1].insert_reservation_under = True
|
|
762
|
+
trackers[-1].insert_fill_under = True
|
|
763
|
+
trackers[-1].should_stop = True
|
|
764
|
+
elif isinstance(node, mapping_spec.Compute):
|
|
765
|
+
for tracker in trackers:
|
|
766
|
+
tracker.track_compute()
|
|
767
|
+
tracker.insert_reservation_under = False
|
|
768
|
+
else:
|
|
769
|
+
raise NotImplementedError(f"Unknown node type {type(node)}")
|
|
770
|
+
|
|
771
|
+
reservation_insert_below = []
|
|
772
|
+
reservation_insert_above = []
|
|
773
|
+
for j in range(len(trackers) - 1, -1, -1):
|
|
774
|
+
if not trackers[j].should_stop:
|
|
775
|
+
continue
|
|
776
|
+
tracker = trackers.pop(j)
|
|
777
|
+
buffet = tracker.buffet
|
|
778
|
+
node = Reservation(purposes=[buffet.tensor], resource=buffet.level)
|
|
779
|
+
node.persistent = tracker.node.persistent
|
|
780
|
+
node._backing = tracker.node._backing
|
|
781
|
+
|
|
782
|
+
if (
|
|
783
|
+
buffet.tensor not in info.tensor_to_reservation_backer_id
|
|
784
|
+
and buffet.tensor in info.workload.tensor_names_used_in_multiple_einsums
|
|
785
|
+
):
|
|
786
|
+
info.tensor_to_reservation_backer_id[buffet.tensor] = id(node)
|
|
787
|
+
|
|
788
|
+
if tracker.insert_reservation_under:
|
|
789
|
+
reservation_insert_below.append(node)
|
|
790
|
+
else:
|
|
791
|
+
reservation_insert_above.append(node)
|
|
792
|
+
|
|
793
|
+
# The order of these for loops is important. Reservation must be below fill.
|
|
794
|
+
for node in reservation_insert_below:
|
|
795
|
+
mapping.insert(i + 1, node)
|
|
796
|
+
i += 1
|
|
797
|
+
for node in reservation_insert_above:
|
|
798
|
+
mapping.insert(i, node)
|
|
799
|
+
i += 1
|
|
800
|
+
|
|
801
|
+
i += 1
|
|
802
|
+
n_nodes = len(mapping)
|
|
803
|
+
|
|
804
|
+
label_fused_loops(mapping)
|
|
805
|
+
|
|
806
|
+
|
|
807
|
+
def label_fused_loops(mapping: list[MappingNode]):
|
|
808
|
+
last_backer = None
|
|
809
|
+
for i, node in enumerate(mapping):
|
|
810
|
+
if isinstance(node, Reservation) and node._backing:
|
|
811
|
+
last_backer = i
|
|
812
|
+
if last_backer is None:
|
|
813
|
+
raise ValueError(
|
|
814
|
+
f"No backing TensorHolder found in mapping {", ".join(m.compact_str() for m in mapping)}"
|
|
815
|
+
)
|
|
816
|
+
|
|
817
|
+
for i, node in enumerate(mapping):
|
|
818
|
+
if isinstance(node, Loop):
|
|
819
|
+
node._fused = i < last_backer
|
|
820
|
+
return mapping
|
|
821
|
+
|
|
822
|
+
|
|
823
|
+
def analyze_node(node_idx, current_shape, info: AnalysisInfo) -> SymbolicAnalysisOutput:
|
|
824
|
+
node = info.mapping[node_idx]
|
|
825
|
+
class2analysis_function = {
|
|
826
|
+
Temporal: analyze_temporal,
|
|
827
|
+
Spatial: analyze_spatial,
|
|
828
|
+
Storage: analyze_storage,
|
|
829
|
+
Reservation: analyze_reservation,
|
|
830
|
+
mapping_spec.Compute: analyze_compute,
|
|
831
|
+
ProcessingStage: analyze_processing_stage,
|
|
832
|
+
}
|
|
833
|
+
if type(node) not in class2analysis_function:
|
|
834
|
+
raise TypeError(f"Unknown node type {type(node)}")
|
|
835
|
+
return class2analysis_function[type(node)](node_idx, current_shape, info)
|
|
836
|
+
|
|
837
|
+
|
|
838
|
+
def analyze_temporal(
|
|
839
|
+
node_idx, current_shape, info: AnalysisInfo
|
|
840
|
+
) -> SymbolicAnalysisOutput:
|
|
841
|
+
mapping = info.mapping
|
|
842
|
+
node = mapping[node_idx]
|
|
843
|
+
stride_and_shape = get_stride_and_tile_shape(node, current_shape, node_idx, info)
|
|
844
|
+
|
|
845
|
+
result_accumulator = SymbolicAnalysisOutput()
|
|
846
|
+
|
|
847
|
+
first_latency = None
|
|
848
|
+
|
|
849
|
+
def handle_repeated_value(repeated_shape):
|
|
850
|
+
nonlocal first_latency
|
|
851
|
+
shape_value = repeated_shape.value
|
|
852
|
+
shape_repeats = repeated_shape.repeats
|
|
853
|
+
|
|
854
|
+
child_shape = current_shape.copy()
|
|
855
|
+
child_shape[node.rank_variable] = shape_value
|
|
856
|
+
|
|
857
|
+
child_result = analyze_node(node_idx + 1, child_shape, info)
|
|
858
|
+
|
|
859
|
+
accumulated_buffet_stats = result_accumulator.buffet_stats
|
|
860
|
+
for buffet, stats in child_result.buffet_stats.items():
|
|
861
|
+
relevancy = info.tensor_to_relevancy[buffet.tensor][node.rank_variable]
|
|
862
|
+
is_fully_relevant = isinstance(relevancy, Relevant)
|
|
863
|
+
accumulated_stats = accumulated_buffet_stats.setdefault(
|
|
864
|
+
buffet, blank_buffet_stats()
|
|
865
|
+
)
|
|
866
|
+
accumulated_stats += stats.repeat_temporal(
|
|
867
|
+
shape_repeats, is_fully_relevant=is_fully_relevant
|
|
868
|
+
)
|
|
869
|
+
accumulated_stats.n_loops_above = stats.n_loops_above + 1
|
|
870
|
+
|
|
871
|
+
for einsum, child_steps in child_result.temporal_steps.items():
|
|
872
|
+
if einsum not in result_accumulator.temporal_steps:
|
|
873
|
+
result_accumulator.temporal_steps[einsum] = 0
|
|
874
|
+
result_accumulator.temporal_steps[einsum] += child_steps * shape_repeats
|
|
875
|
+
|
|
876
|
+
result_accumulator.max(fanout=child_result.fanout)
|
|
877
|
+
|
|
878
|
+
for key in child_result.compute_stats:
|
|
879
|
+
if first_latency is None:
|
|
880
|
+
first_latency = child_result.compute_stats[key].max_latency
|
|
881
|
+
|
|
882
|
+
compute_stats = result_accumulator.compute_stats.setdefault(
|
|
883
|
+
key, ComputeStats()
|
|
884
|
+
)
|
|
885
|
+
compute_stats += child_result.compute_stats[key].repeat_temporal(
|
|
886
|
+
shape_repeats
|
|
887
|
+
)
|
|
888
|
+
result_accumulator.compute_stats[key] = compute_stats
|
|
889
|
+
|
|
890
|
+
info.last_temporal_node_idx = node_idx
|
|
891
|
+
|
|
892
|
+
shape = stride_and_shape.shape
|
|
893
|
+
if isinstance(shape, SequenceOfRepatedvalues):
|
|
894
|
+
for repeated_shape in shape.sequence:
|
|
895
|
+
assert isinstance(repeated_shape, RepeatedValue)
|
|
896
|
+
handle_repeated_value(repeated_shape)
|
|
897
|
+
elif isinstance(shape, RepeatedValue):
|
|
898
|
+
handle_repeated_value(shape)
|
|
899
|
+
|
|
900
|
+
if node_idx in info.idxs_to_track_first_latency:
|
|
901
|
+
for compute_stat in result_accumulator.compute_stats.values():
|
|
902
|
+
# Should be the first time we store this value
|
|
903
|
+
assert node_idx not in compute_stat.max_first_latency
|
|
904
|
+
compute_stat.max_first_latency[node_idx] = first_latency
|
|
905
|
+
|
|
906
|
+
return result_accumulator
|
|
907
|
+
|
|
908
|
+
|
|
909
|
+
def analyze_spatial(node_idx, current_shape, info: AnalysisInfo):
|
|
910
|
+
mapping = info.mapping
|
|
911
|
+
einsum_name = mapping[-1].einsum
|
|
912
|
+
node: Spatial = mapping[node_idx]
|
|
913
|
+
rank_var = node.rank_variable
|
|
914
|
+
node_dim = node.name
|
|
915
|
+
stride_and_shape = get_stride_and_tile_shape(node, current_shape, node_idx, info)
|
|
916
|
+
|
|
917
|
+
result_accumulator = SymbolicAnalysisOutput()
|
|
918
|
+
|
|
919
|
+
def handle_repeated_value(repeated_shape):
|
|
920
|
+
shape_value = repeated_shape.value
|
|
921
|
+
shape_repeats = repeated_shape.repeats
|
|
922
|
+
|
|
923
|
+
child_shape = current_shape.copy()
|
|
924
|
+
child_shape[node.rank_variable] = shape_value
|
|
925
|
+
|
|
926
|
+
child_result = analyze_node(node_idx + 1, child_shape, info)
|
|
927
|
+
|
|
928
|
+
accumulated_buffet_stats = result_accumulator.buffet_stats
|
|
929
|
+
child_stats = list(child_result.buffet_stats.items())
|
|
930
|
+
for i, (buffet, buffet_stats) in enumerate(child_stats):
|
|
931
|
+
stats = buffet_stats
|
|
932
|
+
accumulated_stats = accumulated_buffet_stats.setdefault(
|
|
933
|
+
buffet, blank_buffet_stats()
|
|
934
|
+
)
|
|
935
|
+
relevancy = info.tensor_to_relevancy[buffet.tensor][rank_var]
|
|
936
|
+
|
|
937
|
+
# Reuse parent accesses only:
|
|
938
|
+
# - Irrelevant loops
|
|
939
|
+
# - The outermost level that holds the tensor (the one whose parent accesses
|
|
940
|
+
# will be going through the network)
|
|
941
|
+
last_buffet = True
|
|
942
|
+
for other_buffet, _ in child_stats[i + 1 :]:
|
|
943
|
+
if other_buffet.tensor == buffet.tensor:
|
|
944
|
+
last_buffet = False
|
|
945
|
+
break
|
|
946
|
+
|
|
947
|
+
reuse_parent_accesses = (
|
|
948
|
+
last_buffet
|
|
949
|
+
and isinstance(relevancy, Irrelevant)
|
|
950
|
+
and buffet.tensor in node._may_reuse
|
|
951
|
+
)
|
|
952
|
+
|
|
953
|
+
accumulated_stats += stats.repeat_spatial(
|
|
954
|
+
shape_repeats, reuse_parent_accesses=reuse_parent_accesses
|
|
955
|
+
)
|
|
956
|
+
accumulated_stats.n_loops_above = stats.n_loops_above + 1
|
|
957
|
+
|
|
958
|
+
for einsum, child_steps in child_result.temporal_steps.items():
|
|
959
|
+
if einsum not in result_accumulator.temporal_steps:
|
|
960
|
+
result_accumulator.temporal_steps[einsum] = child_steps
|
|
961
|
+
else:
|
|
962
|
+
result_accumulator.temporal_steps[einsum] = Max(
|
|
963
|
+
result_accumulator.temporal_steps[einsum], child_steps
|
|
964
|
+
)
|
|
965
|
+
|
|
966
|
+
my_key = (node.component, einsum_name)
|
|
967
|
+
child_result.fanout.setdefault(my_key, {})
|
|
968
|
+
|
|
969
|
+
# Propagate up everything except the current level and dimension
|
|
970
|
+
child_fanout = copy.deepcopy(child_result.fanout)
|
|
971
|
+
target_fanout = child_fanout[my_key].pop(node_dim, 1)
|
|
972
|
+
result_accumulator.max(fanout=child_fanout)
|
|
973
|
+
|
|
974
|
+
# Prpoagate current level and dimension * shape_repeats
|
|
975
|
+
child_fanout = child_result.fanout[my_key]
|
|
976
|
+
fanout = result_accumulator.fanout.setdefault(my_key, {})
|
|
977
|
+
fanout.setdefault(node_dim, 0) # TODO: Assume sympy can just take in 0
|
|
978
|
+
# TODO: If node_dim was missing, the original code would have omitted
|
|
979
|
+
# shape_repeats. Is this correct?
|
|
980
|
+
fanout[node_dim] += target_fanout * shape_repeats
|
|
981
|
+
|
|
982
|
+
for key in child_result.compute_stats:
|
|
983
|
+
# TODO: ensure that `ComputeStats()`, which is initialized ONCE, is okay to use here
|
|
984
|
+
compute_stats = result_accumulator.compute_stats.setdefault(
|
|
985
|
+
key, ComputeStats()
|
|
986
|
+
)
|
|
987
|
+
# TODO: If check omitted. This was in the original code, check history if needed.
|
|
988
|
+
compute_stats.combine_spatial(
|
|
989
|
+
child_result.compute_stats[key].repeat_spatial(shape_repeats)
|
|
990
|
+
)
|
|
991
|
+
|
|
992
|
+
shape = stride_and_shape.shape
|
|
993
|
+
if isinstance(shape, SequenceOfRepatedvalues):
|
|
994
|
+
for repeated_shape in shape.sequence:
|
|
995
|
+
assert isinstance(repeated_shape, RepeatedValue)
|
|
996
|
+
handle_repeated_value(repeated_shape)
|
|
997
|
+
elif isinstance(shape, RepeatedValue):
|
|
998
|
+
handle_repeated_value(shape)
|
|
999
|
+
|
|
1000
|
+
return result_accumulator
|
|
1001
|
+
|
|
1002
|
+
|
|
1003
|
+
def reduce_dicts(dict1: dict, dict2: dict, reduce_op):
|
|
1004
|
+
for key in dict1:
|
|
1005
|
+
if key not in dict2:
|
|
1006
|
+
dict2[key] = dict1[key]
|
|
1007
|
+
else:
|
|
1008
|
+
dict2[key] = reduce_op(dict1[key], dict2[key])
|
|
1009
|
+
|
|
1010
|
+
|
|
1011
|
+


def get_total_to_per_unit(total, max_per_unit):
    if total == 0 and max_per_unit != 0:
        raise ValueError(f"total is 0 but max_per_unit is {max_per_unit}")
    if total == 0:
        return 1
    return max_per_unit / total


def has_parent_tensor_holder(
    tensor: TensorName, node_idx: int, info: AnalysisInfo
) -> bool:
    for node in info.mapping[:node_idx]:
        if isinstance(node, TensorHolder) and tensor in node.tensors:
            return True
    return False


def find_component_object(
    component: str, flattened_arch: list[arch.Leaf]
) -> arch.TensorHolder:
    for node in flattened_arch:
        if node.name == component:
            return node
    raise ValueError(f"Component {component} not found in flattened arch")


def analyze_storage(
    node_idx: int,
    current_shape: dict[str, int],
    info: AnalysisInfo,
    propagate_child_results: bool = False,
    count_writes: bool = True,
):
    mapping = info.mapping
    einsum_name = mapping[-1].einsum
    node: TensorHolder = mapping[node_idx]

    child_result = analyze_node(node_idx + 1, current_shape, info)

    for tensor in node.tensors:
        tensor = TensorName(tensor)
        buffet = Buffet(tensor, einsum_name, node.component)

        # Reservations make these, and they go below the storage node, so the buffet
        # stats are already made at this point
        stats = child_result.buffet_stats[buffet]
        backer_id = info.tensor_to_backer_id[tensor]
        is_backing = backer_id == id(node)
        if node.persistent:
            stats.persistent = True
        below_backing = backer_id in [id(m) for m in mapping[:node_idx]]

        projection = info.einsum_tensor_to_projection[(einsum_name, tensor)]

        fills = compute_dense_tile_occupancy(projection, current_shape)

        child = child_result.get_child_buffet_stats(buffet)
        inherit_from_child = propagate_child_results and child is not None

        # ==============================================================================
        # Calculate the total fills and reads to parent. These propagate upward.
        # ==============================================================================

        def inherit_add(attr: str, default_value: Any = fills) -> Any:
            val = getattr(child, attr) if inherit_from_child else default_value
            setattr(stats, attr, val + getattr(stats, attr))

        if has_parent_tensor_holder(tensor, node_idx, info):
            # Initial fetch: If we're below the backing storage, fetch data from above
            # at the beginning.
            if not is_backing and below_backing:
                inherit_add("total_reads_to_parent", fills)
                inherit_add("max_per_parent_reads_to_parent", fills)

            # Data writeback. Do not write back if it's a copy operation and we're below
            # the backing storage; data only flows upward.

            # Writeback occurs in two cases:
            # - We're at or above the backing storage, so we need to propagate our
            #   results upward to any storage nodes that will need this data.
            # - This is a written tensor, so we need to write back the written data.
            if (
                tensor in info.workload.einsums[einsum_name].output_tensor_names
                or not below_backing
            ):
                inherit_add("total_writes_to_parent")
                inherit_add("max_per_parent_writes_to_parent")

            # For read+write tensors, we skip the first fill because the data will be
            # initialized with a zero value.
            if tensor in info.workload.einsums[einsum_name].output_tensor_names:
                inherit_add("total_skipped_first_reads_to_parent")
                inherit_add("min_per_parent_skipped_first_reads_to_parent")

        # =========================
        # Data exchanges with child
        if child is not None:
            stats.total_writes_to_child += child.total_reads_to_parent
            stats.max_per_unit_writes_to_child += child.max_per_parent_reads_to_parent
            # Skip first read
            stats.total_skipped_first_writes_to_child += (
                child.total_skipped_first_reads_to_parent
            )
            stats.min_per_unit_skipped_first_writes_to_child += (
                child.min_per_parent_skipped_first_reads_to_parent
            )

            stats.total_reads_to_child += child.total_writes_to_parent
            stats.max_per_unit_reads_to_child += child.max_per_parent_writes_to_parent

        component_object = find_component_object(
            node.component, info.job.flattened_arch
        )
        bits_per_value_scale = component_object.attributes.bits_per_value_scale[tensor]
        bits_per_value = bits_per_value_scale * info.job.bits_per_value[tensor]
        read_bits_per_action = component_object.actions[
            "read"
        ].arguments.bits_per_action
        stats.read_scale = bits_per_value / read_bits_per_action
        if count_writes:
            write_bits_per_action = component_object.actions[
                "write"
            ].arguments.bits_per_action
            stats.write_scale = bits_per_value / write_bits_per_action
        else:
            stats.write_scale = 0

    return child_result
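
# Worked sketch (hypothetical numbers, for exposition only) of the action
# scaling at the end of `analyze_storage`: traffic is tallied in values, then
# converted into component actions via bits_per_value / bits_per_action.
_example_bits_per_value = 1.0 * 8      # bits_per_value_scale * 8-bit values
_example_read_bits_per_action = 64     # one read action moves 64 bits
_example_read_scale = _example_bits_per_value / _example_read_bits_per_action
assert _example_read_scale == 0.125    # i.e. 8 values per read action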


def analyze_processing_stage(node_idx, current_shape, info: AnalysisInfo):
    mapping = info.mapping
    einsum_name = mapping[-1].einsum
    node = mapping[node_idx]
    component_object = find_component_object(node.component, info.job.flattened_arch)
    storage_result = analyze_storage(
        node_idx,
        current_shape,
        info,
        propagate_child_results=True,
        count_writes=False,
    )
    for tensor in node.tensors:
        buffet = Buffet(tensor, einsum_name, node.component)
        stats = storage_result.buffet_stats[buffet]
        stats.max_occupancy = 0
        stats.count_downward_movement = component_object.attributes.direction != "up"
        stats.count_upward_movement = component_object.attributes.direction != "down"
        assert stats.total_write_actions == 0
    return storage_result


def analyze_reservation(node_idx, current_shape, info: AnalysisInfo):
    mapping = info.mapping
    einsum_name = mapping[-1].einsum
    node = mapping[node_idx]
    tensor = TensorName(node.purpose)

    if info.last_temporal_node_idx is not None and id(
        node
    ) == info.tensor_to_reservation_backer_id.get(node.purpose, None):
        info.idxs_to_track_first_latency.add(info.last_temporal_node_idx)

    child_result = analyze_node(node_idx + 1, current_shape, info)

    buffet = Buffet(tensor, einsum_name, node.resource)

    # Reservation nodes are the first to produce stats for a buffet
    assert buffet not in child_result.buffet_stats

    stats = BuffetStats()
    projection = info.einsum_tensor_to_projection[(einsum_name, tensor)]
    component_object = find_component_object(node.resource, info.job.flattened_arch)
    bits_per_value_scale = component_object.attributes.bits_per_value_scale[tensor]
    bits_per_value = bits_per_value_scale * info.job.bits_per_value[tensor]
    stats.max_occupancy = (
        compute_dense_tile_occupancy(projection, current_shape) * bits_per_value
    )
    child_result.buffet_stats[buffet] = stats

    fanout_key = (node.resource, einsum_name)
    if fanout_key not in child_result.fanout:
        child_result.fanout[fanout_key] = {}

    return child_result
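
# Worked sketch (hypothetical numbers, for exposition only) of the occupancy
# computed by `analyze_reservation`: the dense tile occupancy is counted in
# values, then converted to bits with the tensor's bits_per_value.
_example_tile_values = 32 * 32         # e.g. a 32x32 dense tile of the reserved tensor
_example_bits_per_value = 16           # 16-bit values
_example_max_occupancy = _example_tile_values * _example_bits_per_value
assert _example_max_occupancy == 16384  # bits reserved at this resource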


def analyze_compute(
    node_idx, current_shape, info: AnalysisInfo
) -> SymbolicAnalysisOutput:
    einsum = info.mapping[-1].einsum
    node = info.mapping[node_idx]

    computes = 0 if info.is_copy_operation else 1

    result_accumulator = SymbolicAnalysisOutput()

    result_accumulator.temporal_steps[einsum] = computes
    result_accumulator.compute_stats[Compute(einsum, node.component)] = ComputeStats(
        computes,
        computes,
        1,
    )

    if info.is_copy_operation:
        return result_accumulator

    for tensor in info.all_tensors:
        buffet = Buffet(tensor, einsum, node.component)
        stats = BuffetStats()
        stats.total_reads_to_parent = 1
        stats.max_per_parent_reads_to_parent = 1
        if tensor in info.workload.einsums[einsum].output_tensor_names:
            stats.total_writes_to_parent = 1
            stats.max_per_parent_writes_to_parent = 1
            stats.total_skipped_first_reads_to_parent = 1
            stats.min_per_parent_skipped_first_reads_to_parent = 1
        stats.max_occupancy = 1
        result_accumulator.buffet_stats[buffet] = stats

    return result_accumulator


@dataclass
class RepeatedValue[T]:
    value: T
    repeats: int


@dataclass
class SequenceOfRepatedvalues[T]:
    sequence: list[RepeatedValue[T]]


@dataclass
class StrideAndShape:
    stride: any
    shape: any
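
# Illustration (hypothetical numbers, for exposition only): how these
# dataclasses describe a tiled loop. A rank of shape 10 tiled with stride 4
# iterates over tile shapes 4, 4, 2 (two full tiles and a smaller last one).
_example_tiling = StrideAndShape(
    stride=4,
    shape=SequenceOfRepatedvalues(
        [RepeatedValue(4, 2), RepeatedValue(2, 1)]  # two tiles of 4, then one tile of 2
    ),
)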


def get_stride_and_tile_shape(node: Loop, full_shape, n: int, info: AnalysisInfo):
    rank = node.rank_variable
    rank_shape = full_shape[rank]

    stride = node.stride
    initial_tile_shape = node.initial_tile_shape

    # PERFECT:
    # - Node shape = stride
    # - # Iterations = total shape / stride
    # IMPERFECT:
    # - Node shape = stride
    # - # Iterations = ceil(total shape / stride)
    if IMPERFECT and initial_tile_shape is None:
        factor = sympy.ceiling(rank_shape / stride)
        stride_avg = stride / sympy.ceiling(rank_shape / stride)
        return StrideAndShape(stride_avg, RepeatedValue(stride, factor))

    if initial_tile_shape is None:
        if node._assume_perfect_factor or known_perfect_factor(stride, rank_shape):
            factor = rank_shape / stride
            return StrideAndShape(stride, RepeatedValue(stride, factor))
        else:
            factor = sympy.ceiling(rank_shape / sympy.Min(stride, rank_shape))
            return make_possibly_different_last(stride, factor, rank_shape)

    middle_shape_factor = sympy.ceiling((rank_shape - initial_tile_shape) / stride)
    # TODO: sometimes last_shape is 0, causing numerical instability
    # Currently, we are sometimes rounding up last shape.
    # last_shape = rank_shape - initial_tile_shape - stride*middle_shape_factor
    # has_last_shape = sympy.ceiling(last_shape/(last_shape+1))
    return StrideAndShape(
        stride,
        SequenceOfRepatedvalues(
            [
                RepeatedValue(initial_tile_shape, 1),
                RepeatedValue(stride, middle_shape_factor),
                # RepeatedValue(last_shape+0.01, has_last_shape)
            ]
        ),
    )
    # if node.tile_shape is not None:
    #     tile_shape = node.tile_shape
    #
    #     if node._assume_perfect_factor or known_perfect_factor(tile_shape, rank_shape):
    #         factor = rank_shape / tile_shape
    #         return StrideAndShape(tile_shape, RepeatedValue(tile_shape, factor))
    #     else:
    #         factor = sympy.ceiling(rank_shape / sympy.Min(tile_shape, rank_shape))
    #         return make_possibly_different_last(tile_shape, factor, rank_shape)
    # elif node.loop_bound is not None:
    #     factor = node.loop_bound
    #
    #     if node._assume_perfect_factor or known_perfect_factor(factor, rank_shape):
    #         tile_shape = rank_shape / factor
    #         return StrideAndShape(tile_shape, RepeatedValue(tile_shape, factor))
    #     else:
    #         tile_shape = sympy.ceiling(rank_shape / sympy.Min(rank_shape, factor))
    #         return make_possibly_different_last(tile_shape, factor, rank_shape)
    #
    # elif node.tile_pattern is not None:
    #     stride = node.tile_pattern.stride
    #     initial_tile_shape = node.tile_pattern.initial_tile_shape
    #     tile_shape = node.tile_pattern.tile_shape
    #
    #     if initial_tile_shape is not None:
    #         middle_shape_factor = sympy.ceiling((rank_shape - initial_tile_shape)/stride)
    #         # TODO: sometimes last_shape is 0, causing numerical instability
    #         # Currently, we are sometimes rounding up last shape.
    #         # last_shape = rank_shape - initial_tile_shape - stride*middle_shape_factor
    #         # has_last_shape = sympy.ceiling(last_shape/(last_shape+1))
    #         return StrideAndShape(
    #             stride,
    #             SequenceOfRepatedvalues([
    #                 RepeatedValue(initial_tile_shape, 1),
    #                 RepeatedValue(stride, middle_shape_factor),
    #                 # RepeatedValue(last_shape+0.01, has_last_shape)
    #             ])
    #         )


def known_perfect_factor(divisor, full_shape):
    return (
        isinstance(divisor, int)
        and isinstance(full_shape, int)
        and full_shape % divisor == 1
    )


def make_possibly_different_last(common_tile_shape, factor, full_shape):
    last_shape = full_shape - common_tile_shape * (factor - 1)
    all_shapes = SequenceOfRepatedvalues(
        [RepeatedValue(common_tile_shape, factor - 1), RepeatedValue(last_shape, 1)]
    )
    return StrideAndShape(common_tile_shape, all_shapes)
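
# Worked sketch (hypothetical integers, for exposition only) of
# `make_possibly_different_last`: a rank of shape 10 with tile shape 4 gives
# factor = ceil(10 / 4) = 3, so the last tile holds the 10 - 4 * 2 = 2
# remaining elements.
_example_split = make_possibly_different_last(common_tile_shape=4, factor=3, full_shape=10)
assert _example_split.stride == 4
assert _example_split.shape.sequence == [RepeatedValue(4, 2), RepeatedValue(2, 1)]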


def insert_sympy_symbols(mapping: list[MappingNode], job: Job):
    loop_idx = 0
    symbols = []
    rank_var_with_initial = set()
    for i, node in enumerate(mapping):
        if not isinstance(node, Loop):
            continue

        stride_halos = set()
        for t in job.spec.workload.einsums[job.einsum_name].tensor_names:
            for (rank, rank_variable), (stride, halo) in job.stride_and_halo[t].items():
                if rank_variable == node.rank_variable:
                    stride_halos.add((stride, halo))

        if len(stride_halos) == 0:
            raise RuntimeError(
                f"{repr(node.rank_variable)} not found in {job.stride_and_halo}"
            )

        # We only explore imperfect factors for the outermost fused loops
        simple = (
            (len(stride_halos) <= 1 and next(iter(stride_halos)) == (1, 0))
            or node.rank_variable in rank_var_with_initial
            or not node._fused
        )

        # NOTE: initial_tile_shape must be inserted into `symbols` before `stride`
        # because of the order of tile shape exploration.
        # TODO: there has to be a better way to do this.
        if simple:  # Just use the stride!
            node.initial_tile_shape = None
        elif node.initial_tile_shape == SYMBOL:
            rank_var_with_initial.add(node.rank_variable)
            initial_tile_shape = sympy.symbols(
                f"initial{loop_idx}", positive=True, integer=True
            )
            symbols.append(initial_tile_shape)
            node.initial_tile_shape = initial_tile_shape

        # TODO: Check for 0 < shape < 1 for loop bound target
        if job.rank_variable_bounds[node.rank_variable] == 1:
            node.stride = 1
        elif node.stride == SYMBOL:
            stride = sympy.symbols(f"stride{loop_idx}", positive=True, integer=True)
            symbols.append(stride)
            node.stride = stride

        # TODO: sometimes, a mapping is passed into the model twice.
        # E.g., after calling mapper, the model is called again for more
        # details.
        #
        # assert (
        #     node.calculated_n_iterations is None
        # ), "Number of iterations is derived from the model. Do not set it!"
        node.calculated_n_iterations = sympy.symbols(
            f"n_iterations{loop_idx}", positive=True, integer=True
        )

        loop_idx += 1

    return symbols
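
# Note (hypothetical example, for exposition only): the loop symbols above are
# declared as positive integers, presumably so that sympy can simplify the
# tile-shape expressions built from them; for instance, ceiling() collapses
# over an integer symbol.
import sympy

_example_stride = sympy.symbols("stride0", positive=True, integer=True)
assert sympy.ceiling(_example_stride) == _example_stride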