accelforge-0.0.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- accelforge/__init__.py +21 -0
- accelforge/_accelerated_imports.py +16 -0
- accelforge/_deprecate/_simanneal/evalmapping.py +271 -0
- accelforge/_deprecate/_simanneal/mapspaceglobals.py +298 -0
- accelforge/_deprecate/_simanneal/simanneal.py +666 -0
- accelforge/_deprecate/_simanneal/tracking.py +105 -0
- accelforge/_deprecate/_simanneal/wrappers.py +218 -0
- accelforge/_deprecate/_simanneal2/__init__.py +7 -0
- accelforge/_deprecate/_simanneal2/simanneal.py +493 -0
- accelforge/_deprecate/_simanneal2/tracking.py +116 -0
- accelforge/_deprecate/compatibility_util.py +181 -0
- accelforge/_deprecate/layerdeduplication/__init__.py +2 -0
- accelforge/_deprecate/layerdeduplication/group_similar_einsums.py +160 -0
- accelforge/_deprecate/layerdeduplication/grouped_einsums.py +84 -0
- accelforge/_deprecate/mapping_filter_tags/__init__.py +2 -0
- accelforge/_deprecate/mapping_filter_tags/ffmt.py +212 -0
- accelforge/_deprecate/mapping_filter_tags/onesplit.py +24 -0
- accelforge/_deprecate/mapping_filter_tags/util.py +24 -0
- accelforge/_deprecate/tags.py +69 -0
- accelforge/_deprecate/viz/__init__.py +0 -0
- accelforge/_deprecate/viz/interactive.py +159 -0
- accelforge/_deprecate/viz/reservationtree.py +307 -0
- accelforge/_deprecate/viz/ski_slope.py +88 -0
- accelforge/_version.py +15 -0
- accelforge/examples.py +39 -0
- accelforge/frontend/__init__.py +10 -0
- accelforge/frontend/_binding.py +129 -0
- accelforge/frontend/_workload_isl/__init__.py +2 -0
- accelforge/frontend/_workload_isl/_isl.py +149 -0
- accelforge/frontend/_workload_isl/_symbolic.py +141 -0
- accelforge/frontend/arch copy.py +1544 -0
- accelforge/frontend/arch.py +1642 -0
- accelforge/frontend/config.py +63 -0
- accelforge/frontend/mapper/__init__.py +5 -0
- accelforge/frontend/mapper/ffm.py +126 -0
- accelforge/frontend/mapper/mapper.py +7 -0
- accelforge/frontend/mapper/metrics.py +30 -0
- accelforge/frontend/mapping/__init__.py +1 -0
- accelforge/frontend/mapping/mapping.py +1736 -0
- accelforge/frontend/model.py +14 -0
- accelforge/frontend/renames.py +150 -0
- accelforge/frontend/spec copy.py +230 -0
- accelforge/frontend/spec.py +301 -0
- accelforge/frontend/variables.py +12 -0
- accelforge/frontend/workload.py +952 -0
- accelforge/mapper/FFM/__init__.py +9 -0
- accelforge/mapper/FFM/_join_pmappings/__init__.py +0 -0
- accelforge/mapper/FFM/_join_pmappings/compatibility.py +653 -0
- accelforge/mapper/FFM/_join_pmappings/compress_pmappings.py +140 -0
- accelforge/mapper/FFM/_join_pmappings/join_pmappings.py +703 -0
- accelforge/mapper/FFM/_join_pmappings/pmapping_dataframe.py +901 -0
- accelforge/mapper/FFM/_join_pmappings/pmapping_group.py +337 -0
- accelforge/mapper/FFM/_make_pmappings/contraints/__init__.py +0 -0
- accelforge/mapper/FFM/_make_pmappings/contraints/constraints.py +360 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/__init__.py +1 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_loops.py +373 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_pmapping_templates.py +463 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_reservations.py +95 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_storage_order.py +382 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_storages.py +155 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmappings.py +411 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/__init__.py +1 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/make_pmappings_from_templates.py +407 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/make_tile_shapes.py +1681 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/run_model.py +170 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/symbol_relations.py +174 -0
- accelforge/mapper/FFM/_make_pmappings/pmapper_job.py +282 -0
- accelforge/mapper/FFM/_pareto_df/df_convention.py +273 -0
- accelforge/mapper/FFM/_pareto_df/pareto copy.py +836 -0
- accelforge/mapper/FFM/_pareto_df/pareto.py +508 -0
- accelforge/mapper/FFM/data.py +61 -0
- accelforge/mapper/FFM/main copy.py +236 -0
- accelforge/mapper/FFM/main.py +208 -0
- accelforge/mapper/FFM/mappings.py +510 -0
- accelforge/mapper/FFM/pmappings.py +310 -0
- accelforge/mapper/__init__.py +4 -0
- accelforge/mapper.py +0 -0
- accelforge/model/__init__.py +1 -0
- accelforge/model/_looptree/__init__.py +0 -0
- accelforge/model/_looptree/accesses.py +335 -0
- accelforge/model/_looptree/capacity/__init__.py +1 -0
- accelforge/model/_looptree/capacity/aggregators.py +36 -0
- accelforge/model/_looptree/capacity/capacity.py +47 -0
- accelforge/model/_looptree/energy.py +150 -0
- accelforge/model/_looptree/equivalent_ranks.py +29 -0
- accelforge/model/_looptree/latency/__init__.py +1 -0
- accelforge/model/_looptree/latency/latency.py +98 -0
- accelforge/model/_looptree/latency/memory.py +120 -0
- accelforge/model/_looptree/latency/processors.py +92 -0
- accelforge/model/_looptree/mapping_utilities.py +71 -0
- accelforge/model/_looptree/reuse/__init__.py +4 -0
- accelforge/model/_looptree/reuse/isl/__init__.py +1 -0
- accelforge/model/_looptree/reuse/isl/des.py +59 -0
- accelforge/model/_looptree/reuse/isl/isl_functions.py +374 -0
- accelforge/model/_looptree/reuse/isl/mapping_to_isl/__init__.py +4 -0
- accelforge/model/_looptree/reuse/isl/mapping_to_isl/analyze_mapping.py +297 -0
- accelforge/model/_looptree/reuse/isl/mapping_to_isl/skews_from_mapping.py +236 -0
- accelforge/model/_looptree/reuse/isl/mapping_to_isl/tiling.py +685 -0
- accelforge/model/_looptree/reuse/isl/mapping_to_isl/types.py +188 -0
- accelforge/model/_looptree/reuse/isl/spatial.py +260 -0
- accelforge/model/_looptree/reuse/isl/temporal.py +182 -0
- accelforge/model/_looptree/reuse/symbolic/__init__.py +1 -0
- accelforge/model/_looptree/reuse/symbolic/symbolic copy 2.py +1346 -0
- accelforge/model/_looptree/reuse/symbolic/symbolic copy.py +1408 -0
- accelforge/model/_looptree/reuse/symbolic/symbolic.py +1396 -0
- accelforge/model/_looptree/run.py +122 -0
- accelforge/model/_looptree/types.py +26 -0
- accelforge/model/_looptree/visualization/__init__.py +0 -0
- accelforge/model/_looptree/visualization/occupancy.py +11 -0
- accelforge/model/main.py +222 -0
- accelforge/plotting/__init__.py +2 -0
- accelforge/plotting/mappings.py +219 -0
- accelforge/plotting/specs.py +57 -0
- accelforge/util/__init__.py +4 -0
- accelforge/util/_base_analysis_types.py +24 -0
- accelforge/util/_basetypes.py +1089 -0
- accelforge/util/_frozenset.py +36 -0
- accelforge/util/_isl.py +29 -0
- accelforge/util/_itertools.py +14 -0
- accelforge/util/_mathfuncs.py +57 -0
- accelforge/util/_parse_expressions.py +339 -0
- accelforge/util/_picklecache.py +32 -0
- accelforge/util/_setexpressions.py +268 -0
- accelforge/util/_sympy/__init__.py +0 -0
- accelforge/util/_sympy/broadcast_max.py +18 -0
- accelforge/util/_visualization.py +112 -0
- accelforge/util/_yaml.py +579 -0
- accelforge/util/parallel.py +193 -0
- accelforge-0.0.1.dist-info/METADATA +64 -0
- accelforge-0.0.1.dist-info/RECORD +258 -0
- accelforge-0.0.1.dist-info/WHEEL +5 -0
- accelforge-0.0.1.dist-info/licenses/LICENSE +19 -0
- accelforge-0.0.1.dist-info/top_level.txt +5 -0
- docs/_build/html/_sources/fastfusion.frontend.mapper.rst.txt +37 -0
- docs/_build/html/_sources/fastfusion.frontend.rst.txt +70 -0
- docs/_build/html/_sources/fastfusion.frontend.workload.rst.txt +21 -0
- docs/_build/html/_sources/fastfusion.mapper.FFM.rst.txt +37 -0
- docs/_build/html/_sources/fastfusion.mapper.rst.txt +18 -0
- docs/_build/html/_sources/fastfusion.rst.txt +20 -0
- docs/_build/html/_sources/fastfusion.util.rst.txt +21 -0
- docs/_build/html/_sources/index.rst.txt +87 -0
- docs/_build/html/_sources/modules.rst.txt +7 -0
- docs/_build/html/_sources/notes/citation.rst.txt +45 -0
- docs/_build/html/_sources/notes/definitions.rst.txt +43 -0
- docs/_build/html/_sources/notes/faqs.rst.txt +39 -0
- docs/_build/html/_sources/notes/modeling/accelerator_energy_latency.rst.txt +72 -0
- docs/_build/html/_sources/notes/modeling/component_energy_area.rst.txt +96 -0
- docs/_build/html/_sources/notes/modeling/mapping.rst.txt +100 -0
- docs/_build/html/_sources/notes/modeling.rst.txt +33 -0
- docs/_build/html/_sources/notes/parsing/arithmetic_parsing.rst.txt +136 -0
- docs/_build/html/_sources/notes/parsing/setexpressions.rst.txt +63 -0
- docs/_build/html/_sources/notes/parsing/yaml_parsing.rst.txt +176 -0
- docs/_build/html/_sources/notes/quickstart_and_installation.rst.txt +9 -0
- docs/_build/html/_sources/notes/spec/architecture.rst.txt +133 -0
- docs/_build/html/_sources/notes/spec/mapping.rst.txt +12 -0
- docs/_build/html/_sources/notes/spec/workload.rst.txt +83 -0
- docs/_build/html/_sources/notes/spec.rst.txt +36 -0
- docs/source/_ext/include_attrs.py +213 -0
- docs/source/_ext/include_docstring.py +364 -0
- docs/source/_ext/include_functions.py +154 -0
- docs/source/_ext/include_notebook.py +131 -0
- docs/source/_ext/include_yaml.py +119 -0
- docs/source/_ext/inherited_attributes.py +222 -0
- docs/source/_ext/paths.py +4 -0
- docs/source/conf.py +79 -0
- examples/arches/compute_in_memory/_include.yaml +74 -0
- examples/arches/compute_in_memory/_include_functions.py +229 -0
- examples/arches/compute_in_memory/_load_spec.py +57 -0
- examples/arches/compute_in_memory/components/c2c_multiplier.py +181 -0
- examples/arches/compute_in_memory/components/dac_c2c_r2r.py +605 -0
- examples/arches/compute_in_memory/components/misc.py +195 -0
- examples/arches/compute_in_memory/components/util/bit_functions.py +51 -0
- examples/arches/compute_in_memory/components/zero_comparator.py +92 -0
- examples/arches/compute_in_memory/isaac.yaml +233 -0
- examples/arches/compute_in_memory/memory_cells/ecram_demo.yaml +63 -0
- examples/arches/compute_in_memory/memory_cells/rram_example.yaml +63 -0
- examples/arches/compute_in_memory/memory_cells/rram_isaac_isca_2016.yaml +64 -0
- examples/arches/compute_in_memory/memory_cells/rram_neurosim_default.yaml +63 -0
- examples/arches/compute_in_memory/memory_cells/rram_raella_isca_2023.yaml +70 -0
- examples/arches/compute_in_memory/memory_cells/rram_wan_nature_2022.yaml +63 -0
- examples/arches/compute_in_memory/memory_cells/sram_colonnade_jssc_2021.yaml +63 -0
- examples/arches/compute_in_memory/memory_cells/sram_example.yaml +63 -0
- examples/arches/compute_in_memory/memory_cells/sram_jia_jssc_2020.yaml +63 -0
- examples/arches/compute_in_memory/memory_cells/sram_sinangil_jssc_2021.yaml +63 -0
- examples/arches/compute_in_memory/memory_cells/sram_wang_vlsi_2022.yaml +63 -0
- examples/arches/compute_in_memory/wang_vlsi_2022.yaml +289 -0
- examples/arches/eyeriss.yaml +68 -0
- examples/arches/fanout_variations/at_glb.yaml +31 -0
- examples/arches/fanout_variations/at_glb_with_fanout_node.yaml +34 -0
- examples/arches/fanout_variations/at_mac.yaml +31 -0
- examples/arches/fanout_variations/at_mac_with_constraints.yaml +38 -0
- examples/arches/fanout_variations/at_mac_with_fanout_node.yaml +34 -0
- examples/arches/nvdla.yaml +47 -0
- examples/arches/simple.yaml +28 -0
- examples/arches/tpu_v4i.yaml +67 -0
- examples/mappings/unfused_matmuls_to_simple.yaml +33 -0
- examples/misc/component_annotated.yaml +33 -0
- examples/workloads/gpt3_6.7B.yaml +124 -0
- examples/workloads/matmuls.yaml +20 -0
- examples/workloads/mobilenet_28.yaml +81 -0
- examples/workloads/mobilenet_various_separate.yaml +106 -0
- examples/workloads/three_matmuls_annotated.yaml +59 -0
- notebooks/.ipynb_checkpoints/fastfusion_arch_study_michael-checkpoint.ipynb +359 -0
- notebooks/compute_in_memory/_scripts.py +339 -0
- notebooks/compute_in_memory/isaac.guide.ipynb +270 -0
- notebooks/compute_in_memory/wang_vlsi_2022.ipynb +602 -0
- notebooks/paths.py +4 -0
- notebooks/tutorials/.ipynb_checkpoints/1_FFM-checkpoint.ipynb +3110 -0
- notebooks/tutorials/FFM.ipynb +3498 -0
- notebooks/tutorials/_include.py +48 -0
- notebooks/tutorials/component_energy_area.ipynb +363 -0
- tests/Q_mapping.yaml +38 -0
- tests/__init__.py +0 -0
- tests/conv.mapping.yaml +27 -0
- tests/conv.workload.yaml +13 -0
- tests/conv_sym.mapping.yaml +43 -0
- tests/copy.mapping.yaml +35 -0
- tests/copy.workload.yaml +15 -0
- tests/distribuffers/__init__.py +0 -0
- tests/distribuffers/multicast/test_cases.yaml +482 -0
- tests/distribuffers/spec/binding/valid_bindings.yaml +97 -0
- tests/distribuffers/spec/distributed.yaml +100 -0
- tests/distribuffers/spec/logical_arch.yaml +32 -0
- tests/distribuffers/spec/physical_arch.yaml +69 -0
- tests/distribuffers/test_binding.py +48 -0
- tests/frontend/__init__.py +0 -0
- tests/frontend/test_mapping_viz.py +52 -0
- tests/mapper/__init__.py +0 -0
- tests/mapper/configs/conv1d/conv1d.mapping.yaml +31 -0
- tests/mapper/configs/conv1d/conv1d.workload.yaml +11 -0
- tests/mapper/configs/two_conv1d/two_conv1d.expected.yaml +38 -0
- tests/mapper/configs/two_conv1d/two_conv1d.mapping.yaml +54 -0
- tests/mapper/configs/two_conv1d/two_conv1d.workload.yaml +19 -0
- tests/mapper/test_mapping_to_isl.py +90 -0
- tests/mapper/test_spatial_reuse_analysis.py +67 -0
- tests/mapper/test_temporal_reuse_analysis.py +56 -0
- tests/mapper/util.py +58 -0
- tests/matmul.mapping.yaml +29 -0
- tests/matmul.workload.yaml +12 -0
- tests/matmul_spatial.mapping.yaml +44 -0
- tests/mha.renames.yaml +65 -0
- tests/mha.workload.yaml +67 -0
- tests/mha.yaml +59 -0
- tests/mha_full.workload.yaml +67 -0
- tests/mobilenet.workload.yaml +35 -0
- tests/mobilenet_long.workload.yaml +64 -0
- tests/pmappingcache.py +24 -0
- tests/processing_stage.arch.yaml +40 -0
- tests/snowcat.arch.yaml +36 -0
- tests/test_ffm_join_pmappings.py +106 -0
- tests/test_ffm_make_pmappings.py +82 -0
- tests/test_ffm_make_tile_shapes.py +49 -0
- tests/test_mapper.py +100 -0
- tests/test_model.py +37 -0
- tests/test_plotting.py +72 -0
- tests/test_processing_stage.py +46 -0
- tests/test_symbolic_model.py +248 -0
- tests/test_workload.py +141 -0
accelforge/frontend/mapping/mapping.py

@@ -0,0 +1,1736 @@

```python
"""
A module containing the visualization and types needed to run mapspace exploration
in AccelForge.
"""

import copy
from dataclasses import dataclass
import inspect
import itertools
import pydot

from typing import (
    # Collections
    Any,
    List,
    # Object definitions
    Annotated,
    Callable,
    Literal,
    Self,
    # Type constructions
    Type,
    TypeVar,
    TypeAlias,
    # Variable meta-mandates
    Optional,
    Union,
    override,
)
from collections.abc import Set
from pydantic import ConfigDict, Discriminator, Tag, computed_field
import sympy

from accelforge.util._basetypes import (
    # Parsing helpers for the input files.
    ParsableModel,
    ParsableList,
    ParsesTo,
    # Retrieves information from YAML tags.
    _get_tag,
    _uninstantiable,
    NoParse,
)
from accelforge.frontend.workload import RankVariable, TensorName
from accelforge.util._visualization import ColorMap, _pydot_graph
from accelforge.util.parallel import _SVGJupyterRender
from accelforge._version import __version__
from accelforge.frontend import arch

T = TypeVar("T", bound="MappingNode")
"""TypeVar T: Restricts the allowable types to types of MappingNodes."""

NodeList: TypeAlias = ParsableList[
    Annotated[
        Union[
            Annotated["Split", Tag("Split")],
            Annotated["Compute", Tag("Compute")],
            Annotated["Storage", Tag("Storage")],
            Annotated["Temporal", Tag("Temporal")],
            Annotated["Spatial", Tag("Spatial")],
            Annotated["Sequential", Tag("Sequential")],
            Annotated["Pipeline", Tag("Pipeline")],
            Annotated["Nested", Tag("Nested")],
            Annotated["Reservation", Tag("Reservation")],
            Annotated["Mapping", Tag("Mapping")],
            Annotated["ProcessingStage", Tag("ProcessingStage")],
        ],
        Discriminator(_get_tag),
    ]
]
"""
TypeAlias NodeList: ParsableList that can contain and discriminate between
MappingNodes of different types.
"""

_NO_JOIN_MAPPING_VISUALIZATION = False

# =============================================================================
# LoopTree Mapping Nodes
# =============================================================================


@_uninstantiable
class MappingNode(ParsableModel):
    """
    Represents a Node in the Mapping, which can be a loop, a storage node, a compute
    node, etc.
    """

    _constraint_lambdas: List[Callable[[], bool]] = []
    """ Constraints that apply to this node. """

    _must_be_here: bool = False
    """ Can the mapper move this node? """

    _required: bool = False
    """ Must the mapper keep this node? """

    def _render_node_name(self) -> str:
        """The name for a Pydot node."""
        return f"{self.__class__.__name__}_{id(self)}"

    def _render_node_label(self, **kwargs) -> str:
        """The label for a Pydot node."""
        return self.__str__()

    def _render_node_shape(self) -> str:
        """The shape for a Pydot node."""
        return "box"

    def _render_node(self, **kwargs) -> pydot.Node:
        """Render this node using Pydot."""
        return pydot.Node(
            self._render_node_name(),
            label=self._render_node_label(**kwargs),
            shape=self._render_node_shape(),
            style="filled",
            fillcolor=self._render_node_color(),
            margin=0,
        )

    def _parent2next(self) -> "MappingNode":
        """
        Return the parent to the next node in the tree. This is used for nodes that
        don't appear in the tree, like Nested nodes.
        """
        return self

    def _parent2child(
        self, parent: "MappingNode"
    ) -> list[tuple["MappingNode", "MappingNode"]]:
        """
        Returns a list of tuples, each one being a parent and child node.
        """
        return []

    def _render_make_children(self, **kwargs) -> list[pydot.Node]:
        """
        Renders the children of this node and returns them as a list of Pydot nodes.
        """
        return []

    def get_nodes_of_type(self, types: Type[T] | tuple[Type[T], ...]) -> List[T]:
        """
        Returns all sub-nodes, including this one, that match the given types.
        """
        nodes: List[T] = []
        if isinstance(self, types):
            nodes.append(self)
        if isinstance(self, MappingNodeWithChildren):
            for node in self.nodes:
                # Recurse into each child; a matching child is collected by its
                # own get_nodes_of_type call, so it is not appended twice.
                nodes.extend(node.get_nodes_of_type(types))
        return nodes

    def _flatten(self) -> list["MappingNode"]:
        if isinstance(self, MappingNodeWithChildren):
            result = [self]
            for node in self.nodes:
                result.extend(node._flatten())
            return result
        return [self]

    def _render_node_color(self) -> str:
        """The color for a Pydot node."""
        return "white"

    def __hash__(self):
        """
        Hashing functor to create mappings of nodes to other objects.
        """
        return id(self)

    def __eq__(self, other: Any):
        return self is other

    def __init_subclass__(cls, **kwargs):
        # Let Pydantic build the subclass first.
        super().__init_subclass__(**kwargs)
        # Read the *raw* attribute without descriptor binding.
        h = inspect.getattr_static(cls, "__hash__", None)
        # Replace if unhashable (None) or if it's just BaseModel's default.
        if h is None or h is ParsableModel.__hash__:
            cls.__hash__ = MappingNode.__hash__

    def compact_str(self) -> str:
        """Returns a compact string representation of this node."""
        return self.__str__()


@dataclass(frozen=True)
class TilePattern:
    tile_shape: ParsesTo[
        Literal["symbol"] | sympy.Symbol | int | str | None | sympy.Expr
    ] = "symbol"
    """
    The common tile shape of the pattern. This is the number of indices by which
    the tile moves each iteration.
    """

    initial_tile_shape: ParsesTo[
        Literal["symbol"] | sympy.Symbol | int | None | str | sympy.Expr
    ] = "symbol"
    """
    The initial tile shape. This is the shape of the tile at the first iteration.
    Subsequent iterations may be smaller if they overlap previous iterations.
    """

    calculated_n_iterations: (
        Literal["symbol"] | sympy.Symbol | int | None | str | sympy.Expr
    ) = None
    """ The number of iterations in the pattern. Do not set this! Used internally by the
    mapper. """

    def _symbol_attrs(self) -> tuple[str, ...]:
        """The attributes that may be symbols."""
        return ("tile_shape", "initial_tile_shape", "calculated_n_iterations")

    def __str__(self) -> str:
        return self.as_str()

    def as_str(self, with_initial_tile_shape=True, with_tile_shape=True):
        s = []
        if self.calculated_n_iterations not in (None, "symbol"):
            s.append(f"in [0..{self.calculated_n_iterations})")
        if with_initial_tile_shape and (
            self.initial_tile_shape not in (None, "symbol")
        ):
            s.append(f"initial={self.initial_tile_shape}")
        if with_tile_shape and (self.tile_shape not in (None, "symbol")):
            s.append(f"tile_shape={self.tile_shape}")
        return " ".join(s)

    def update(self, **kwargs) -> "TilePattern":
        """Update the TilePattern with the given keyword arguments."""
        return type(self)(**{**self.__dict__, **kwargs})

    def _symbol2str(self) -> "TilePattern":
        """
        Convert the symbols in the TilePattern to strings, and return a new TilePattern
        with the symbols replaced by their names.
        """

        def _symbol2str(x: sympy.Symbol | int | None) -> str | int | None:
            return x.name if isinstance(x, sympy.Symbol) else x

        return type(self)(
            **{x: _symbol2str(getattr(self, x)) for x in self._symbol_attrs()}
        )

    def _prepend_symbols(self, prepend: str) -> "TilePattern":
        def _prepend(x: sympy.Symbol | int | None) -> str | int | None:
            if isinstance(x, sympy.Symbol):
                x = x.name
            return prepend + x if isinstance(x, str) else x

        return self.update(
            **{x: _prepend(getattr(self, x)) for x in self._symbol_attrs()}
        )

    def __eq__(self, other: Any) -> bool:
        if not isinstance(other, TilePattern):
            return False
        return all(getattr(self, x) == getattr(other, x) for x in self._symbol_attrs())

    def __hash__(self) -> int:
        return hash((self.initial_tile_shape, self.tile_shape))

    def _rename_to_match(
        self, other: "TilePattern"
    ) -> tuple["TilePattern", dict[str, str]]:
        """
        Changes the symbols in this TilePattern to match the other TilePattern.

        Parameters
        ----------
        other:
            The TilePattern to match.

        Returns
        -------
        A tuple containing the updated TilePattern and a dictionary of source->target
        symbol renames.
        """
        renames = {}
        setattrs = {}
        for x in self._symbol_attrs():
            if getattr(self, x) != getattr(other, x):
                renames[getattr(self, x)] = getattr(other, x)
                setattrs[x] = getattr(other, x)
        return self.update(**setattrs), renames

    def _clear_symbols(self) -> "TilePattern":
        """
        Clears the symbols in this TilePattern, replacing them with None.
        """

        def desymbol(x: str | sympy.Symbol | int | None) -> str | int | None:
            if isinstance(x, (str, sympy.Symbol)):
                return None
            return x

        return self.update(
            **{x: desymbol(getattr(self, x)) for x in self._symbol_attrs()}
        )
```
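`TilePattern` is easiest to see in action. A minimal sketch, assuming the module is importable by its path above and that the `ParsesTo[...]` annotations behave as ordinary type hints at construction time:

```python
import sympy

from accelforge.frontend.mapping.mapping import TilePattern

# Rendering: as_str() only prints fields that are concrete (not None/"symbol").
p = TilePattern(tile_shape=4, initial_tile_shape=6)
print(p)                                         # initial=6 tile_shape=4
print(p.as_str(with_initial_tile_shape=False))   # tile_shape=4

# Functional update: the dataclass is frozen, so update() returns a new copy.
p2 = p.update(tile_shape=sympy.Symbol("ts0"))
print(p2.tile_shape, p.tile_shape)               # ts0 4
```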
309
|
+
|
|
310
|
+
@_uninstantiable
|
|
311
|
+
class Loop(MappingNode):
|
|
312
|
+
"""
|
|
313
|
+
A bounded loop over a rank variable with a given shape and/or pattern.
|
|
314
|
+
|
|
315
|
+
Do not instantiate directly; inherited by :class:`~.Temporal` and
|
|
316
|
+
:class:`~.Spatial`.
|
|
317
|
+
"""
|
|
318
|
+
|
|
319
|
+
rank_variable: set[RankVariable] | RankVariable
|
|
320
|
+
""" The rank variable(s) iterated over in this loop. This may be a
|
|
321
|
+
single rank variable, or a set of rank variables if the loop is shared between
|
|
322
|
+
multiple Einsums. """
|
|
323
|
+
|
|
324
|
+
tile_shape: ParsesTo[sympy.Symbol | sympy.Expr | int | str] = "symbol"
|
|
325
|
+
"""
|
|
326
|
+
The (common) tile shape of the iteration. For example, if the iteration
|
|
327
|
+
space is range(6) and the tile shape is 3, then we create and iterate over
|
|
328
|
+
two tiles [0, 1, 2] and [3, 4, 5].
|
|
329
|
+
|
|
330
|
+
This attribute specifies the *common* tile shape because
|
|
331
|
+
`initial_tile_shape` may be specified.
|
|
332
|
+
|
|
333
|
+
For users writing YAML, the value should be an integer.
|
|
334
|
+
|
|
335
|
+
For those developing the mapper, the literal string "symbol" is often used
|
|
336
|
+
to tell the model to create a sympy symbol to use as the tile shape. Any
|
|
337
|
+
other string may be specified to explicitly request a variable name (later
|
|
338
|
+
converted to a sympy variable).
|
|
339
|
+
"""
|
|
340
|
+
|
|
341
|
+
initial_tile_shape: ParsesTo[sympy.Symbol | sympy.Expr | int | str | None] = None
|
|
342
|
+
"""
|
|
343
|
+
The shape of the first tile shape. This attribute is optional. If not
|
|
344
|
+
specified, all tiles have the same shape.
|
|
345
|
+
|
|
346
|
+
If specified, the initial tile shape may differ. For example, an initial
|
|
347
|
+
tile shape of 3 and tile shape of 2 creates the following tiles in the
|
|
348
|
+
iteration space: [0, 1, 2], [3, 4], [5, 6], ...
|
|
349
|
+
|
|
350
|
+
Similarly to tile shape, this value should be an integer when writing a
|
|
351
|
+
YAML input.
|
|
352
|
+
|
|
353
|
+
For those developing the mapper, this attribute can be a string. See
|
|
354
|
+
tile_shape for details.
|
|
355
|
+
"""
|
|
356
|
+
|
|
357
|
+
_calculated_n_iterations: (
|
|
358
|
+
Literal["symbol"] | sympy.Symbol | sympy.Expr | int | str | None
|
|
359
|
+
) = None
|
|
360
|
+
|
|
361
|
+
_assume_perfect_factor: bool = True
|
|
362
|
+
""" Whether the Mapper assumes that tile shapes perfectly divide tensor shapes and
|
|
363
|
+
parent tile shapes. """
|
|
364
|
+
|
|
365
|
+
_fused: bool = None
|
|
366
|
+
""" Whether this Loop is shared with another Einsum. """
|
|
367
|
+
|
|
368
|
+
model_config = ConfigDict(arbitrary_types_allowed=True)
|
|
369
|
+
|
|
370
|
+
def __str__(self) -> str:
|
|
371
|
+
return f"for {self.rank_variable} {self.tile_pattern}"
|
|
372
|
+
|
|
373
|
+
def __eq__(self, other: Any) -> bool:
|
|
374
|
+
return (
|
|
375
|
+
isinstance(other, Loop)
|
|
376
|
+
and self.rank_variable == other.rank_variable
|
|
377
|
+
and self.tile_pattern == other.tile_pattern
|
|
378
|
+
)
|
|
379
|
+
|
|
380
|
+
def _render_node_shape(self) -> str:
|
|
381
|
+
return "box"
|
|
382
|
+
|
|
383
|
+
def _render_node_color(self) -> str:
|
|
384
|
+
return "#FCC2FC"
|
|
385
|
+
|
|
386
|
+
@override
|
|
387
|
+
def compact_str(self) -> str:
|
|
388
|
+
"""Returns a compact string representation of this Loop."""
|
|
389
|
+
rv = self.rank_variable
|
|
390
|
+
if isinstance(rv, (set, frozenset)):
|
|
391
|
+
rv = ",".join(sorted(rv))
|
|
392
|
+
return f"{rv} {self.tile_pattern}"
|
|
393
|
+
|
|
394
|
+
def _merge(self, other: "Loop", **kwargs) -> "Loop":
|
|
395
|
+
"""Merge this Loop with another Loop, returning the result."""
|
|
396
|
+
if not isinstance(other, Loop):
|
|
397
|
+
raise ValueError(f"Expected Loop, got {type(other)}")
|
|
398
|
+
if self.tile_pattern != other.tile_pattern:
|
|
399
|
+
raise ValueError(
|
|
400
|
+
f"Tile patterns do not match: {self.tile_pattern} != {other.tile_pattern}"
|
|
401
|
+
)
|
|
402
|
+
|
|
403
|
+
my_rv, other_rv = self.rank_variable, other.rank_variable
|
|
404
|
+
my_rv = my_rv if isinstance(my_rv, (set, frozenset)) else set((my_rv,))
|
|
405
|
+
other_rv = (
|
|
406
|
+
other_rv if isinstance(other_rv, (set, frozenset)) else set((other_rv,))
|
|
407
|
+
)
|
|
408
|
+
return type(self)(
|
|
409
|
+
rank_variable=my_rv | other_rv,
|
|
410
|
+
tile_pattern=self.tile_pattern,
|
|
411
|
+
_assume_perfect_factor=self._assume_perfect_factor,
|
|
412
|
+
**kwargs,
|
|
413
|
+
)
|
|
414
|
+
|
|
415
|
+
@property
|
|
416
|
+
def tile_pattern(self) -> TilePattern:
|
|
417
|
+
return TilePattern(
|
|
418
|
+
tile_shape=self.tile_shape,
|
|
419
|
+
initial_tile_shape=self.initial_tile_shape,
|
|
420
|
+
calculated_n_iterations=self.calculated_n_iterations,
|
|
421
|
+
)
|
|
422
|
+
|
|
423
|
+
@tile_pattern.setter
|
|
424
|
+
def tile_pattern(self, value: TilePattern):
|
|
425
|
+
self.tile_shape = value.tile_shape
|
|
426
|
+
self.initial_tile_shape = value.initial_tile_shape
|
|
427
|
+
self.calculated_n_iterations = value.calculated_n_iterations
|
|
428
|
+
|
|
429
|
+
@property
|
|
430
|
+
def calculated_n_iterations(self) -> int:
|
|
431
|
+
"""The number of iterations performed by this loop."""
|
|
432
|
+
return self._calculated_n_iterations
|
|
433
|
+
|
|
434
|
+
@calculated_n_iterations.setter
|
|
435
|
+
def calculated_n_iterations(self, value: int) -> None:
|
|
436
|
+
"""Set the number of iterations performed by this loop. Do not set this!
|
|
437
|
+
This is calculated by the Mapper."""
|
|
438
|
+
self._calculated_n_iterations = value
|
|
439
|
+
|
|
440
|
+
|
|
441
|
+
class Temporal(Loop):
|
|
442
|
+
"""A Temporal :class:`~.Loop`."""
|
|
443
|
+
|
|
444
|
+
@override
|
|
445
|
+
def compact_str(self) -> str:
|
|
446
|
+
return f"T-{super().compact_str()}"
|
|
447
|
+
|
|
448
|
+
def __eq__(self, other: "Temporal") -> bool:
|
|
449
|
+
return isinstance(other, Temporal) and super().__eq__(other)
|
|
450
|
+
|
|
451
|
+
def _merge(self, other: "Temporal") -> "Temporal":
|
|
452
|
+
if not isinstance(other, Temporal):
|
|
453
|
+
raise ValueError(f"Expected Temporal, got {type(other)}")
|
|
454
|
+
return super()._merge(other)
|
|
455
|
+
|
|
456
|
+
def _render_node_label(self, **kwargs) -> str:
|
|
457
|
+
with_initial_tile_shape = True
|
|
458
|
+
with_tile_shape = kwargs.get("with_tile_shape", True)
|
|
459
|
+
return (
|
|
460
|
+
f"for {self.rank_variable} "
|
|
461
|
+
f"{self.tile_pattern.as_str(with_initial_tile_shape, with_tile_shape)}"
|
|
462
|
+
)
|
|
463
|
+
|
|
464
|
+
|
|
465
|
+
class Spatial(Loop):
|
|
466
|
+
"""A spatial :class:`~.Loop`."""
|
|
467
|
+
|
|
468
|
+
name: int | str
|
|
469
|
+
""" The dimension over which the spatial is occuring. """
|
|
470
|
+
|
|
471
|
+
component: str
|
|
472
|
+
""" The component name across which different spatial iterations occur. """
|
|
473
|
+
|
|
474
|
+
component_object: NoParse[arch.Leaf] = None
|
|
475
|
+
""" The component object across which different spatial iterations occur. """
|
|
476
|
+
|
|
477
|
+
_constrained_to_one: bool = False
|
|
478
|
+
""" Whether this Spatial loop is constrained to one iteration. Do not set this; used
|
|
479
|
+
internally by the Mapper."""
|
|
480
|
+
|
|
481
|
+
@override
|
|
482
|
+
def compact_str(self) -> str:
|
|
483
|
+
return f"S-{self.name}-{super().compact_str()}"
|
|
484
|
+
|
|
485
|
+
def __str__(self) -> str:
|
|
486
|
+
return f"S-{self.name} " + super().__str__()
|
|
487
|
+
|
|
488
|
+
def __eq__(self, other: "Spatial") -> bool:
|
|
489
|
+
return (
|
|
490
|
+
isinstance(other, Spatial)
|
|
491
|
+
and super().__eq__(other)
|
|
492
|
+
and self.name == other.name
|
|
493
|
+
and self.component == other.component
|
|
494
|
+
and self.component_object == other.component_object
|
|
495
|
+
)
|
|
496
|
+
|
|
497
|
+
def _merge(self, other: "Spatial") -> "Spatial":
|
|
498
|
+
if not isinstance(other, Spatial):
|
|
499
|
+
raise ValueError(f"Expected Spatial, got {type(other)}")
|
|
500
|
+
if self.name != other.name:
|
|
501
|
+
raise ValueError(f"Names do not match: {self.name} != {other.name}")
|
|
502
|
+
if self.component != other.component:
|
|
503
|
+
raise ValueError(
|
|
504
|
+
f"Components do not match: {self.component} != {other.component}"
|
|
505
|
+
)
|
|
506
|
+
return super()._merge(
|
|
507
|
+
other,
|
|
508
|
+
name=self.name,
|
|
509
|
+
component=self.component,
|
|
510
|
+
component_object=self.component_object,
|
|
511
|
+
)
|
|
512
|
+
|
|
513
|
+
def _render_node_label(self, **kwargs) -> str:
|
|
514
|
+
with_initial_tile_shape = kwargs.get("with_initial_tile_shape", True)
|
|
515
|
+
with_tile_shape = kwargs.get("with_tile_shape", True)
|
|
516
|
+
return (
|
|
517
|
+
f"S-{self.name}-for {self.rank_variable} "
|
|
518
|
+
f"{self.tile_pattern.as_str(with_initial_tile_shape, with_tile_shape)}"
|
|
519
|
+
)
|
|
520
|
+
|
|
521
|
+
|
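To make the `tile_shape` / `initial_tile_shape` semantics described in the docstrings above concrete, here is a small standalone helper (not package code) that enumerates the tiles a loop would produce over its iteration space:

```python
# Standalone illustration of the tiling described in the Loop docstrings.
def tiles(extent: int, tile_shape: int, initial_tile_shape: int | None = None):
    start = 0
    # The first tile may have its own shape; later tiles use the common shape.
    shape = initial_tile_shape if initial_tile_shape is not None else tile_shape
    while start < extent:
        yield list(range(start, min(start + shape, extent)))
        start += shape
        shape = tile_shape

print(list(tiles(6, 3)))     # [[0, 1, 2], [3, 4, 5]]     -- tile_shape=3 over range(6)
print(list(tiles(7, 2, 3)))  # [[0, 1, 2], [3, 4], [5, 6]] -- initial=3, tile_shape=2
```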
mapping.py (continued):

```python
class TensorHolder(MappingNode):
    """A node that represents a hardware Component holding a set of tensors."""

    tensors: ParsableList[TensorName]
    """ The names of the tensors being held in this node. """

    component: str
    """ The name of the component holding the tensors. """

    component_object: NoParse[arch.Component] = None
    """ The component object holding the tensors. """

    _must_keep_tensors: ParsableList[TensorName] = ParsableList()
    """ Which tensor(s) the Mapper must keep here. Do not set this! Used internally by
    the Mapper."""

    _backing: Set[TensorName] = set()
    """ Which tensor(s) are backed by this node. Do not set this! Used internally by
    the Mapper."""

    _lower: bool = True
    """ Whether this tensor holder can be lowered. Do not set this! Used internally by
    the Mapper."""

    persistent: bool = False
    """
    Whether this tensor holder is persistent. Persistent tensors can't be tiled and must
    be kept in backing storage for the full duration of the workload's execution.
    """

    def __eq__(self, other: Any) -> bool:
        return (
            isinstance(other, TensorHolder)
            and set(self.tensors) == set(other.tensors)
            and self.component == other.component
        )

    @override
    def compact_str(self) -> str:
        tname = ",".join(self.tensors)
        return f"[{tname} in {self.component}]"

    def __str__(self, color_map: ColorMap = None) -> str:
        tensors = self.tensors
        if color_map is not None:
            format_list = [f"{self.component} reuses"] + list(tensors)
            return color_map.format_list(format_list)
        return f"{self.component} reuses {', '.join(tensors)}"

    @property
    def tensor(self) -> TensorName:
        """If there is one tensor held in this tensor holder, returns its name.
        Otherwise, raises an error."""
        if len(self.tensors) != 1:
            raise ValueError(
                f"TensorHolder node {repr(self)} has {len(self.tensors)} tensors. "
                f"Access the tensors property instead."
            )
        return self.tensors[0]

    def _render_node_shape(self) -> str:
        return "cylinder"

    def _render_node_color(self) -> str:
        return "#D7FCD7"

    def _merge(self, other: "TensorHolder") -> "TensorHolder":
        if not isinstance(other, TensorHolder):
            raise ValueError(f"Expected TensorHolder, got {type(other)}")

        if self.component != other.component:
            raise ValueError(
                f"Components do not match: {self.component} != {other.component}"
            )

        new = type(self)(
            tensors=self.tensors + other.tensors,
            component=self.component,
            component_object=self.component_object,
        )
        return new


class Storage(TensorHolder):
    """
    A Storage :class:`~.TensorHolder` that can hold tensors for reuse.
    """

    def _merge(self, other: "Storage") -> "Storage":
        if not isinstance(other, Storage):
            raise ValueError(f"Expected Storage, got {type(other)}")
        return super()._merge(other)


class ProcessingStage(TensorHolder):
    """
    A ProcessingStage :class:`~.TensorHolder` that acts as a pass-through, where data is
    not reused but incurs accesses into this ProcessingStage.
    """

    def _render_node_shape(self) -> str:
        return "rarrow"

    def _render_node_color(self) -> str:
        return "#FFCC99"

    def __str__(self, color_map: ColorMap = None) -> str:
        tensors = self.tensors
        if color_map is not None:
            format_list = [f"{self.component} processes"] + list(tensors)
            return color_map.format_list(format_list)
        return f"{self.component} processes {', '.join(tensors)}"


class Compute(MappingNode):
    """A node that represents a compute operation. These nodes are the leaves of the
    LoopTree."""

    einsum: str
    """ The Einsum being computed. """

    component: str
    """ The name of the compute component performing the computation. """

    component_object: NoParse[arch.Compute | None] = None
    """ The :class:`~accelforge.frontend.arch.Compute` object performing the
    computation. """

    @override
    def compact_str(self) -> str:
        return f"{self.component} computes {self.einsum}"

    def __str__(self) -> str:
        return f"{self.component} computes {self.einsum}"

    def _render_node_shape(self) -> str:
        return "ellipse"

    def _render_node_color(self) -> str:
        return "#E0EEFF"
```
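A sketch of how the leaf node types behave, assuming pydantic-style keyword construction (plain lists coercing to `ParsableList`); the component and tensor names here (`GLB`, `MAC`, `A`, `B`) are hypothetical:

```python
from accelforge.frontend.mapping.mapping import Compute, Storage

s = Storage(tensors=["A", "B"], component="GLB")
print(str(s))           # GLB reuses A, B
print(s.compact_str())  # [A,B in GLB]

# The tensor property is a guarded convenience for single-tensor holders.
try:
    _ = s.tensor        # two tensors held, so this raises
except ValueError as e:
    print(e)

c = Compute(einsum="matmul", component="MAC")
print(c.compact_str())  # MAC computes matmul
```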
|
664
|
+
class MappingNodeWithChildren(MappingNode):
|
|
665
|
+
"""
|
|
666
|
+
A :class:`~.MappingNode` that also has child nodes.
|
|
667
|
+
"""
|
|
668
|
+
|
|
669
|
+
nodes: NodeList = ParsableList()
|
|
670
|
+
""" The child nodes. """
|
|
671
|
+
|
|
672
|
+
@override
|
|
673
|
+
def _parent2child(
|
|
674
|
+
self, parent: MappingNode
|
|
675
|
+
) -> list[tuple[MappingNode, MappingNode]]:
|
|
676
|
+
mine = [(self, node) for node in self.nodes]
|
|
677
|
+
for child in self.nodes:
|
|
678
|
+
mine.extend(child._parent2child(self))
|
|
679
|
+
return mine
|
|
680
|
+
|
|
681
|
+
@override
|
|
682
|
+
def _parent2next(self) -> MappingNode:
|
|
683
|
+
return None
|
|
684
|
+
|
|
685
|
+
@override
|
|
686
|
+
def _render_make_children(self, **kwargs) -> list[str]:
|
|
687
|
+
exclude_types = kwargs.get("exclude_types", tuple())
|
|
688
|
+
lines = []
|
|
689
|
+
for child in self.nodes:
|
|
690
|
+
if not isinstance(child, exclude_types):
|
|
691
|
+
lines.append(child._render_node(**kwargs))
|
|
692
|
+
lines.extend(child._render_make_children(**kwargs))
|
|
693
|
+
return lines
|
|
694
|
+
|
|
695
|
+
@override
|
|
696
|
+
def _get_backers(self) -> list[TensorHolder]:
|
|
697
|
+
backing = []
|
|
698
|
+
for child in self.nodes:
|
|
699
|
+
if isinstance(child, TensorHolder) and child._backing:
|
|
700
|
+
backing.append(child)
|
|
701
|
+
elif isinstance(child, MappingNodeWithChildren):
|
|
702
|
+
backing.extend(child._get_backers())
|
|
703
|
+
return backing
|
|
704
|
+
|
|
705
|
+
def clear_nodes_of_type(self, types: type | tuple[type]) -> None:
|
|
706
|
+
"""Clears all child nodes that match the given type(s)."""
|
|
707
|
+
new_nodes = []
|
|
708
|
+
for node in self.nodes:
|
|
709
|
+
if isinstance(node, types):
|
|
710
|
+
continue
|
|
711
|
+
if isinstance(node, MappingNodeWithChildren):
|
|
712
|
+
node.clear_nodes_of_type(types)
|
|
713
|
+
new_nodes.append(node)
|
|
714
|
+
self.nodes = ParsableList(new_nodes)
|
|
715
|
+
|
|
716
|
+
def clear_nodes(self, *nodes: MappingNode) -> None:
|
|
717
|
+
"""Removes nodes that equal any of the given nodes."""
|
|
718
|
+
new_nodes: list[MappingNode] = []
|
|
719
|
+
for node in self.nodes:
|
|
720
|
+
if any(n == node for n in nodes):
|
|
721
|
+
continue
|
|
722
|
+
if node in nodes:
|
|
723
|
+
continue
|
|
724
|
+
if isinstance(node, MappingNodeWithChildren):
|
|
725
|
+
node.clear_nodes(*nodes)
|
|
726
|
+
new_nodes.append(node)
|
|
727
|
+
self.nodes = ParsableList(new_nodes)
|
|
728
|
+
|
|
729
|
+
def _consolidate_tensor_holders(self) -> None:
|
|
730
|
+
new_nodes = []
|
|
731
|
+
for node in self.nodes:
|
|
732
|
+
if isinstance(node, TensorHolder):
|
|
733
|
+
found = False
|
|
734
|
+
for n in new_nodes[::-1]:
|
|
735
|
+
if isinstance(n, TensorHolder) and n.component == node.component:
|
|
736
|
+
n.tensors.extend(
|
|
737
|
+
n2 for n2 in node.tensors if n2 not in n.tensors
|
|
738
|
+
)
|
|
739
|
+
found = True
|
|
740
|
+
break
|
|
741
|
+
if isinstance(n, Loop):
|
|
742
|
+
break
|
|
743
|
+
if not found:
|
|
744
|
+
new_nodes.append(node)
|
|
745
|
+
else:
|
|
746
|
+
new_nodes.append(node)
|
|
747
|
+
if isinstance(node, MappingNodeWithChildren):
|
|
748
|
+
node._consolidate_tensor_holders()
|
|
749
|
+
assert new_nodes, "BUG"
|
|
750
|
+
self.nodes = ParsableList(new_nodes)
|
|
751
|
+
|
|
752
|
+
def _consolidate_reservations(self) -> None:
|
|
753
|
+
new_nodes = []
|
|
754
|
+
for node in self.nodes:
|
|
755
|
+
if isinstance(node, Reservation):
|
|
756
|
+
found = False
|
|
757
|
+
for n in new_nodes[::-1]:
|
|
758
|
+
if isinstance(n, Reservation) and n.resource == node.resource:
|
|
759
|
+
n.purposes.extend(node.purposes)
|
|
760
|
+
found = True
|
|
761
|
+
break
|
|
762
|
+
if isinstance(n, Loop):
|
|
763
|
+
break
|
|
764
|
+
if not found:
|
|
765
|
+
new_nodes.append(node)
|
|
766
|
+
else:
|
|
767
|
+
new_nodes.append(node)
|
|
768
|
+
if isinstance(node, MappingNodeWithChildren):
|
|
769
|
+
node._consolidate_reservations()
|
|
770
|
+
assert new_nodes, "BUG"
|
|
771
|
+
self.nodes = ParsableList(new_nodes)
|
|
772
|
+
|
|
773
|
+
def _elevate_persistent_nodes_above_splits(self) -> None:
|
|
774
|
+
new_nodes: list[MappingNode] = []
|
|
775
|
+
for node in self.nodes:
|
|
776
|
+
if isinstance(node, Split):
|
|
777
|
+
persistent_nodes = node._get_persistent_nodes()
|
|
778
|
+
new_nodes.extend(persistent_nodes)
|
|
779
|
+
node.clear_nodes(*persistent_nodes)
|
|
780
|
+
if isinstance(node, MappingNodeWithChildren):
|
|
781
|
+
node._elevate_persistent_nodes_above_splits()
|
|
782
|
+
new_nodes.append(node)
|
|
783
|
+
self.nodes = ParsableList(new_nodes)
|
|
784
|
+
|
|
785
|
+
def _elevate_tensor_holders_above_splits(self) -> None:
|
|
786
|
+
new_nodes: list[MappingNode] = []
|
|
787
|
+
for node in self.nodes:
|
|
788
|
+
if isinstance(node, Split):
|
|
789
|
+
shared_tensor_holders = node._get_shared_tensor_holders()
|
|
790
|
+
new_nodes.extend(shared_tensor_holders)
|
|
791
|
+
node.clear_nodes(*shared_tensor_holders)
|
|
792
|
+
if isinstance(node, MappingNodeWithChildren):
|
|
793
|
+
node._elevate_tensor_holders_above_splits()
|
|
794
|
+
new_nodes.append(node)
|
|
795
|
+
self.nodes = ParsableList(new_nodes)
|
|
796
|
+
|
|
797
|
+
def _propagate_reservations_between_splits(self) -> None:
|
|
798
|
+
for node in self.nodes:
|
|
799
|
+
if isinstance(node, MappingNodeWithChildren):
|
|
800
|
+
node._propagate_reservations_between_splits()
|
|
801
|
+
|
|
802
|
+
if not isinstance(self, Split):
|
|
803
|
+
return
|
|
804
|
+
|
|
805
|
+
for i, node1 in enumerate(self.nodes):
|
|
806
|
+
for j in range(i + 2, len(self.nodes)):
|
|
807
|
+
node2 = self.nodes[j]
|
|
808
|
+
reservations1 = node1.get_nodes_of_type(Reservation)
|
|
809
|
+
reservations2 = node2.get_nodes_of_type(Reservation)
|
|
810
|
+
|
|
811
|
+
shared_reservations = []
|
|
812
|
+
for reservation1 in reservations1:
|
|
813
|
+
for reservation2 in reservations2:
|
|
814
|
+
if reservation1 == reservation2:
|
|
815
|
+
shared_reservations.append(reservation1)
|
|
816
|
+
break
|
|
817
|
+
|
|
818
|
+
for s in shared_reservations:
|
|
819
|
+
for k in range(i + 1, j):
|
|
820
|
+
node3 = self.nodes[k]
|
|
821
|
+
if not isinstance(node3, Nested):
|
|
822
|
+
raise ValueError(f"Expected Nested node, got {type(node3)}")
|
|
823
|
+
reservations3 = node3.get_nodes_of_type(Reservation)
|
|
824
|
+
if s not in reservations3:
|
|
825
|
+
node3.nodes.insert(0, copy.deepcopy(s))
|
|
826
|
+
|
|
827
|
+
def _move_tensor_holders_above_reservations(self) -> None:
|
|
828
|
+
groups = []
|
|
829
|
+
cur_group = []
|
|
830
|
+
for node in self.nodes:
|
|
831
|
+
if isinstance(node, MappingNodeWithChildren):
|
|
832
|
+
node._move_tensor_holders_above_reservations()
|
|
833
|
+
if not isinstance(node, (TensorHolder, Reservation)):
|
|
834
|
+
groups.append(cur_group)
|
|
835
|
+
cur_group = []
|
|
836
|
+
cur_group.append(node)
|
|
837
|
+
groups.append(cur_group)
|
|
838
|
+
groups = [g for g in groups if g]
|
|
839
|
+
|
|
840
|
+
groups = [
|
|
841
|
+
[x for x in g if not isinstance(x, (TensorHolder, Reservation))]
|
|
842
|
+
+ [x for x in g if isinstance(x, (TensorHolder))]
|
|
843
|
+
+ [x for x in g if isinstance(x, (Reservation))]
|
|
844
|
+
for g in groups
|
|
845
|
+
]
|
|
846
|
+
self.nodes = ParsableList([x for g in groups for x in g])
|
|
847
|
+
|
|
848
|
+
def _remove_reservations_for_processing_stages(self) -> None:
|
|
849
|
+
processing_stages = self.get_nodes_of_type(ProcessingStage)
|
|
850
|
+
processing_stage_names = set(ps.component for ps in processing_stages)
|
|
851
|
+
reservations = self.get_nodes_of_type(Reservation)
|
|
852
|
+
remove = [r for r in reservations if r.resource in processing_stage_names]
|
|
853
|
+
self.clear_nodes(*remove)
|
|
854
|
+
|
|
855
|
+
|
|
856
|
+
class Split(MappingNodeWithChildren):
|
|
857
|
+
"""
|
|
858
|
+
A :class:`~.MappingNodeWithChildren` that splits the tree into multiple branches,
|
|
859
|
+
each applying to different Einsums.
|
|
860
|
+
"""
|
|
861
|
+
|
|
862
|
+
def __str__(self) -> str:
|
|
863
|
+
return "Split"
|
|
864
|
+
|
|
865
|
+
def _render_node_shape(self) -> str:
|
|
866
|
+
return "hexagon"
|
|
867
|
+
|
|
868
|
+
def _get_persistent_nodes(self) -> list[MappingNode]:
|
|
869
|
+
nodes = []
|
|
870
|
+
for n in self.nodes:
|
|
871
|
+
nodes.extend(n.get_nodes_of_type(TensorHolder))
|
|
872
|
+
nodes.extend(n.get_nodes_of_type(Reservation))
|
|
873
|
+
return [n for n in nodes if n.persistent]
|
|
874
|
+
|
|
875
|
+
def _get_shared_tensor_holders(self) -> list[TensorHolder]:
|
|
876
|
+
tensor_holders = [n.get_nodes_of_type(TensorHolder) for n in self.nodes]
|
|
877
|
+
shared_tensor_holders = []
|
|
878
|
+
for i in range(len(tensor_holders)):
|
|
879
|
+
for j in range(i + 1, len(tensor_holders)):
|
|
880
|
+
for a in tensor_holders[i]:
|
|
881
|
+
for b in tensor_holders[j]:
|
|
882
|
+
if a._backing & b._backing and a not in shared_tensor_holders:
|
|
883
|
+
assert len(a.tensors) == 1 and len(b.tensors) == 1, "BUG"
|
|
884
|
+
shared_tensor_holders.append(a)
|
|
885
|
+
break
|
|
886
|
+
return shared_tensor_holders
|
|
887
|
+
|
|
888
|
+
def _render_node_color(self) -> str:
|
|
889
|
+
return "#FFFFE0"
|
|
890
|
+
|
|
891
|
+
|
|
892
|
+
class Nested(MappingNodeWithChildren):
|
|
893
|
+
"""
|
|
894
|
+
A :class:`~.MappingNodeWithChildren` that represents a nested set of nodes. Each
|
|
895
|
+
node is the parent of the next node.
|
|
896
|
+
"""
|
|
897
|
+
|
|
898
|
+
def model_post_init(self, __context__=None) -> None:
|
|
899
|
+
for node in list(self.nodes)[:-1]:
|
|
900
|
+
assert not isinstance(
|
|
901
|
+
node, MappingNodeWithChildren
|
|
902
|
+
), f"Nested node has a child with children. Only the last child can have children."
|
|
903
|
+
|
|
904
|
+
def _parent2child(
|
|
905
|
+
self, parent: MappingNode
|
|
906
|
+
) -> list[tuple[MappingNode, MappingNode]]:
|
|
907
|
+
parent2child = []
|
|
908
|
+
for node in self.nodes:
|
|
909
|
+
parent2child.append((parent, node))
|
|
910
|
+
parent2child.extend(node._parent2child(parent))
|
|
911
|
+
parent = node._parent2next()
|
|
912
|
+
return parent2child
|
|
913
|
+
|
|
914
|
+
def _parent2next(self) -> MappingNode:
|
|
915
|
+
if not self.nodes:
|
|
916
|
+
raise ValueError("Nested node has no children")
|
|
917
|
+
return self.nodes[-1]._parent2next()
|
|
918
|
+
|
|
919
|
+
# def _render_connect_children(self, names_lines: list[tuple[str, str]], parent_name: str=None) -> list[str]:
|
|
920
|
+
# return super()._render_connect_children(names_lines)
|
|
921
|
+
|
|
922
|
+
def _render_node_label(self, **kwargs) -> str:
|
|
923
|
+
if not self.nodes:
|
|
924
|
+
raise ValueError("Nested node has no children")
|
|
925
|
+
return self.nodes[0]._render_node_label(**kwargs)
|
|
926
|
+
|
|
927
|
+
def _render_node_name(self) -> str:
|
|
928
|
+
if not self.nodes:
|
|
929
|
+
raise ValueError("Nested node has no children")
|
|
930
|
+
return self.nodes[0]._render_node_name()
|
|
931
|
+
|
|
932
|
+
def _get_n_shared_loops(self, other: "Nested") -> int:
|
|
933
|
+
my_backing = set(
|
|
934
|
+
(t, s.component) for s in self._get_backers() for t in s._backing
|
|
935
|
+
)
|
|
936
|
+
other_backing = set(
|
|
937
|
+
(t, s.component) for s in other._get_backers() for t in s._backing
|
|
938
|
+
)
|
|
939
|
+
shared_backing = my_backing & other_backing
|
|
940
|
+
|
|
941
|
+
if not shared_backing:
|
|
942
|
+
return 0
|
|
943
|
+
|
|
944
|
+
n_shared_loops = 0
|
|
945
|
+
for i, node in enumerate(self.nodes):
|
|
946
|
+
if isinstance(node, Loop):
|
|
947
|
+
n_shared_loops += 1
|
|
948
|
+
if (
|
|
949
|
+
isinstance(node, Reservation)
|
|
950
|
+
and (node.purpose, node.resource) in shared_backing
|
|
951
|
+
):
|
|
952
|
+
return n_shared_loops
|
|
953
|
+
if isinstance(node, Split):
|
|
954
|
+
for child in node.nodes:
|
|
955
|
+
max_child_n_shared_loops = 0
|
|
956
|
+
try:
|
|
957
|
+
max_child_n_shared_loops = max(
|
|
958
|
+
max_child_n_shared_loops, child._get_n_shared_loops(other)
|
|
959
|
+
)
|
|
960
|
+
except ValueError:
|
|
961
|
+
pass
|
|
962
|
+
return max_child_n_shared_loops + n_shared_loops
|
|
963
|
+
|
|
964
|
+
raise ValueError("BUG")
|
|
965
|
+
|
|
966
|
+
def _break_into_reorderable_groups(
|
|
967
|
+
self, stop_at_n_loops: int
|
|
968
|
+
) -> list[list[MappingNode]]:
|
|
969
|
+
# We can reorder loops relative to each other
|
|
970
|
+
groups = []
|
|
971
|
+
cur_group = None
|
|
972
|
+
|
|
973
|
+
seen_loops = 0
|
|
974
|
+
|
|
975
|
+
if stop_at_n_loops == 0 and not any(
|
|
976
|
+
isinstance(node, Loop) for node in self.nodes
|
|
977
|
+
):
|
|
978
|
+
return [list(self.nodes)]
|
|
979
|
+
|
|
980
|
+
i = 0
|
|
981
|
+
for i, node in enumerate(self.nodes):
|
|
982
|
+
if seen_loops >= stop_at_n_loops:
|
|
983
|
+
break
|
|
984
|
+
is_iteration = isinstance(node, Loop)
|
|
985
|
+
if cur_group is None:
|
|
986
|
+
cur_group = []
|
|
987
|
+
elif (is_iteration and not all(isinstance(x, Loop) for x in cur_group)) or (
|
|
988
|
+
not is_iteration and any(isinstance(x, Loop) for x in cur_group)
|
|
989
|
+
):
|
|
990
|
+
groups.append(cur_group)
|
|
991
|
+
cur_group = []
|
|
992
|
+
cur_group.append(node)
|
|
993
|
+
assert not isinstance(node, Sequential) or i == len(self.nodes) - 1, "BUG"
|
|
994
|
+
if isinstance(node, Loop):
|
|
995
|
+
seen_loops += 1
|
|
996
|
+
|
|
997
|
+
if cur_group:
|
|
998
|
+
groups.append(cur_group)
|
|
999
|
+
|
|
1000
|
+
final_group = self.nodes[i:]
|
|
1001
|
+
groups.append(final_group)
|
|
1002
|
+
|
|
1003
|
+
if seen_loops < stop_at_n_loops:
|
|
1004
|
+
raise ValueError(
|
|
1005
|
+
f"Expected {stop_at_n_loops} loops, but only found {seen_loops}"
|
|
1006
|
+
)
|
|
1007
|
+
|
|
1008
|
+
# Lower reservations. If reservations are in the second-to-last group
|
|
1009
|
+
# # non-iteration group, lower them to the last group.
|
|
1010
|
+
# if len(groups) > 3:
|
|
1011
|
+
# assert not any(isinstance(x, Loop) for x in groups[-1]), "BUG"
|
|
1012
|
+
# assert not any(isinstance(x, Loop) for x in groups[-3]), "BUG"
|
|
1013
|
+
# reservations = [x for x in groups[-2] if isinstance(x, Reservation)]
|
|
1014
|
+
# groups[-1].extend(reservations)
|
|
1015
|
+
# groups[-3] = [x for x in groups[-3] if x not in reservations]
|
|
1016
|
+
|
|
1017
|
+
return groups
|
|
1018
|
+
|
|
1019
|
    def _merge(self, other: "Nested", n_shared_loops: int) -> "Nested":
        # Break up the nodes above the indices. We need them in the format
        # [(loop, other stuff...), (loop, other stuff...), ...].
        my_groups = self._break_into_reorderable_groups(stop_at_n_loops=n_shared_loops)
        my_remaining = my_groups.pop(-1)
        other_groups = other._break_into_reorderable_groups(
            stop_at_n_loops=n_shared_loops
        )
        other_remaining = other_groups.pop(-1)

        # Reorder so that the loops are in the same order. We can't reorder groups
        # that have other stuff in them because that would change the behavior of
        # the mapping.
        zipped_groups = []

        def _pop_loop_group(groups: list[list[MappingNode]]) -> list[MappingNode]:
            # Flush leading non-loop groups straight into the output, then hand
            # back the first loop group (or nothing if none remain).
            while groups and not any(isinstance(x, Loop) for x in groups[0]):
                zipped_groups.append(groups.pop(0))
            return groups.pop(0) if groups else []

        my_loop_group = _pop_loop_group(my_groups)
        other_loop_group = _pop_loop_group(other_groups)
        while (my_groups or my_loop_group) and (other_groups or other_loop_group):
            if not my_loop_group:
                my_loop_group = _pop_loop_group(my_groups)
                continue
            if not other_loop_group:
                other_loop_group = _pop_loop_group(other_groups)
                continue

            # Add matching loops from the two groups. If we can't find a match,
            # raise an error.
            to_add = None
            for i, a in enumerate(my_loop_group):
                for j, b in enumerate(other_loop_group):
                    if a == b:
                        to_add = [a]
                        my_loop_group.pop(i)
                        other_loop_group.pop(j)
                        break
                if to_add is not None:
                    # Stop scanning: both lists were mutated above.
                    break

            if to_add is None:
                # TODO: This check for one is only to catch bugs early. The code
                # below says that if we couldn't find a match, ignore rank
                # variables and assume that rank variable translation will fix it.
                assert len(my_loop_group) == 1 or len(other_loop_group) == 1
                has_one, may_not_have_one = my_loop_group, other_loop_group
                if len(has_one) != 1:
                    has_one, may_not_have_one = other_loop_group, my_loop_group

                l = copy.deepcopy(has_one.pop(0))
                l.rank_variable = (
                    l.rank_variable
                    if isinstance(l.rank_variable, set)
                    else set([l.rank_variable])
                )
                for l2 in may_not_have_one:
                    if l2.calculated_n_iterations == l.calculated_n_iterations:
                        break
                else:
                    raise ValueError(
                        f"No matching loop found for {my_loop_group} and {other_loop_group}"
                    )
                print(
                    f"Warning: matching loops {l} and {l2}. Rank variable translation is needed here."
                )

                may_not_have_one.remove(l2)
                rv = l2.rank_variable
                rv = rv if isinstance(rv, set) else set([rv])
                l.rank_variable = l.rank_variable | rv
                to_add = [l]

            zipped_groups.append(to_add)

        assert not my_loop_group and not other_loop_group, "BUG"

        zipped_groups.extend(my_groups)
        zipped_groups.extend(other_groups)

        flattened = [x for group in zipped_groups for x in group]
        new_nodes = [x for x in flattened if not isinstance(x, Sequential)]
        new_nodes.extend(x for x in flattened if isinstance(x, Sequential))

        if isinstance(my_remaining[0], Sequential) and isinstance(
            other_remaining[0], Sequential
        ):
            my_remaining[0].nodes.extend(other_remaining[0].nodes)
            assert len(my_remaining) == 1 and len(other_remaining) == 1, "BUG"
            new_nodes.append(my_remaining[0])
        elif isinstance(my_remaining[0], Sequential):
            my_remaining[0].nodes.append(Nested(nodes=other_remaining))
            assert len(my_remaining) == 1, "BUG"
            new_nodes.append(my_remaining[0])
        elif isinstance(other_remaining[0], Sequential):
            other_remaining[0].nodes.append(Nested(nodes=my_remaining))
            assert len(other_remaining) == 1, "BUG"
            new_nodes.append(other_remaining[0])
        else:
            new_nodes.append(
                Sequential(
                    nodes=[Nested(nodes=my_remaining), Nested(nodes=other_remaining)]
                )
            )

        return Nested(nodes=new_nodes)
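The effect of `_merge` is easiest to see on flat data. A rough sketch, assuming both pmappings agree on their first `n_shared` loops (strings stand in for nodes; the real method also matches loops by rank variable and iteration count):

def zip_merge(a: list, b: list, n_shared: int) -> list:
    # Shared outer loops appear once; the two distinct tails run
    # one-after-another under a Sequential-like marker.
    assert a[:n_shared] == b[:n_shared], "shared loops must match"
    return a[:n_shared] + [("Sequential", a[n_shared:], b[n_shared:])]

print(zip_merge(["for m1", "for k", "MAC1"], ["for m1", "for n", "MAC2"], 1))
# ['for m1', ('Sequential', ['for k', 'MAC1'], ['for n', 'MAC2'])]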
    def _beautify_loops(
        self, rank_variable_bounds: Optional[dict[str, dict[str, int]]] = None
    ):
        to_remove = []
        rank_variable_bounds = rank_variable_bounds or {}

        for i, node in enumerate(self.nodes):
            if not isinstance(node, Loop):
                continue
            # Find the tile shape of the closest earlier loop over the same rank
            # variable; fall back to the rank variable's full bound.
            prev_tile_shape = None
            for j in range(i - 1, -1, -1):
                node2 = self.nodes[j]
                if not isinstance(node2, Loop):
                    continue
                if node2.tile_shape is None:
                    continue
                if node2.rank_variable != node.rank_variable:
                    continue
                prev_tile_shape = node2.tile_shape
                break
            if prev_tile_shape is None:
                prev_tile_shape = rank_variable_bounds.get(node.rank_variable, None)
            if prev_tile_shape is not None:
                if node.tile_shape == prev_tile_shape:
                    # The loop iterates once, so it can be dropped.
                    to_remove.append(i)
                    continue
                if node.tile_shape is not None:
                    node.tile_pattern = node.tile_pattern.update(
                        calculated_n_iterations=prev_tile_shape / node.tile_shape,
                    )

        def safe_int_cast(x: int | float | None) -> int | float | None:
            # Cast to int only when the value is integral; otherwise return as-is.
            try:
                int_x = int(x)
                return int_x if int_x == x else x
            except (TypeError, ValueError):
                return x

        for node in self.nodes:
            if not isinstance(node, Loop):
                continue
            node.tile_pattern = node.tile_pattern.update(
                initial_tile_shape=safe_int_cast(node.tile_pattern.initial_tile_shape),
                tile_shape=safe_int_cast(node.tile_pattern.tile_shape),
            )

        self.nodes = [node for i, node in enumerate(self.nodes) if i not in to_remove]
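For reference, the intended behavior of `safe_int_cast`: tile-shape arithmetic like `8 / 2` yields `4.0`, which should display as `4`. A standalone copy for illustration:

def safe_int_cast(x):
    try:
        int_x = int(x)
        return int_x if int_x == x else x
    except (TypeError, ValueError):
        return x

assert safe_int_cast(4.0) == 4 and isinstance(safe_int_cast(4.0), int)
assert safe_int_cast(4.5) == 4.5          # non-integral floats pass through
assert safe_int_cast(None) is None        # TypeError path
assert safe_int_cast("conv_M") == "conv_M"  # ValueError path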
    @override
    def compact_str(self) -> str:
        # Greedily merge each node into the previous one where possible; when a
        # merge fails, start a new entry.
        result = []
        prev = None
        for node in self.nodes:
            try:
                prev = prev._merge(node)
            except Exception:
                if prev is not None:
                    result.append(prev)
                prev = node
        if prev is not None:
            result.append(prev)

        return " ".join(node.compact_str() for node in result)

    def _get_single_tensor_mapping(
        self,
        tensor_name: TensorName,
        flattened_arch: list[arch.Leaf],
        indexing_expressions: set[str],
    ) -> Self:
        """
        Ctrl-F for CONTIGUOUS_ITERATION_SPACE_DISCUSSION

        Returns this Nested node with only the nodes associated with the given
        tensor.

        Includes loops and compute nodes, plus any tensor holders and reservations
        that are associated with the given tensor.

        Puts spatials as high as they can go while being below any node that is
        above them in the memory hierarchy. Between two tensor holders, generally
        puts spatial loops at the bottom, but may put them above temporal loops if
        that better lines up with the original order. When memory hierarchy order
        is followed globally (e.g., output in buffer must be above input in reg),
        the loop order going into this function will always match that going out.

        This function expects, as input, all spatials to be placed as low as they
        can go, but above their respective fanouts.

        When memory hierarchy order is only followed per-tensor (e.g., output in
        buffer must be above output in reg, but can be below input in reg), things
        may be more complicated. We discuss this in more detail using the
        following example:

        Hierarchy:

        - Buffer
        - 2x fanout
        - Reg

        Mapping:

            S-reg for m1 in [0, 2):
                [Reg reuses input]
                for m0 in [0, 2):
                    [Buffer reuses output]

        When given a mapping and architecture like the above, this function may
        reorder the spatial and temporal loops, yielding the following:

        Mapping for input:

            S-reg for m1 in [0, 2):
                [Reg reuses input]
                for m0 in [0, 2):

        Mapping for output:

            for m0 in [0, 2):
                [Buffer reuses output]
                S-reg for m1 in [0, 2):

        Unfortunately, such reordering is inevitable given our assumption of an
        inclusive memory hierarchy (because any tile stored in the reg must be
        stored in the buffer) and our desire to place the reg storage node higher.
        It is also a symptom of the following other issues:

        - In cases like these, storage nodes may need to keep non-contiguous
          chunks of the iteration space. For example, if the spatial loop is on
          top, then one reg holds [0, 1] while the other holds [2, 3]. Meanwhile,
          in the first temporal iteration, the buffer holds [0, 2] and in the
          second temporal iteration, the buffer holds [2, 4].
        - We get weird dependencies between loop order and compatibility for
          fusion because loop order affects the iteration-space tiles that are
          stored.

        To prevent these problems from occurring, we raise an error if there are
        any temporal loops in between that affect the same indexing expressions
        as the spatial loops. I tried to make this work with our model by
        constraining the temporal loops to be null (i.e., have the same tile
        shape as their outer loop), but when we run it per-tensor and reorder,
        the loop above the temporal changes, so the model returns inconsistent
        results for each tensor because the tile shape differs. The result is
        that we never reorder spatial and temporal loops that affect one another.

        I haven't thought through how this will work with more complex rank
        variable expressions, so to be safe, we require that no temporal and
        spatial loop may affect the same indexing expression or each other's
        loop bounds.

        These problems aren't necessarily impossible to solve; I just haven't
        thought them through. If we do, a good place to start would be to update
        the model to support non-contiguous chunks of the iteration space, then
        find a way to explore mappings and fusion using such chunks.

        TODO: The mapper then also needs to explore swapping temporal/spatial
        loops.
        """
        spatials = [n for n in self.nodes if isinstance(n, Spatial)]
        tensor_holders = [
            n for n in self.nodes if isinstance(n, (TensorHolder, Compute))
        ]
        others = [
            n
            for n in self.nodes
            if not isinstance(n, (TensorHolder, Reservation, Spatial))
            or (isinstance(n, TensorHolder) and n.tensor == tensor_name)
            or (isinstance(n, Reservation) and n.purpose == tensor_name)
        ]
        assert not any(isinstance(n, MappingNodeWithChildren) for n in others), "BUG"

        def arch_idx(node: MappingNode) -> int:
            for i, n in enumerate(flattened_arch):
                if n.name == node.component:
                    return i
            raise ValueError(f"Component {node.component} not found in flattened arch")

        spatials_above = {
            id(node): [s for s in spatials if arch_idx(s) <= arch_idx(node)]
            for node in tensor_holders
        }
        spatials_below = {
            id(node): [s for s in spatials if arch_idx(s) >= arch_idx(node)]
            for node in tensor_holders
        }

        mapping = []
        for to_add in others:
            if isinstance(to_add, (TensorHolder, Compute)):
                cur_spatials_above = [
                    s for s in spatials if s in spatials_above[id(to_add)]
                ]
                spatials = [s for s in spatials if s not in cur_spatials_above]
                mapping.extend(cur_spatials_above)
            mapping.append(to_add)

        mapping.extend(spatials)

        # Check that spatials are always above their respective fanouts
        for i, node in enumerate(mapping):
            if not isinstance(node, (TensorHolder, Compute)):
                continue
            for node2 in mapping[i + 1 :]:
                if not isinstance(node2, Spatial):
                    continue
                assert node2 in spatials_below[id(node)], "BUG"
                assert node2 not in spatials_above[id(node)], "BUG"

        # Split the mapping into groups of tensor holders and sequential loops
        id2idx = {id(node): i for i, node in enumerate(self.nodes)}
        groups = []
        for node in mapping:
            if (
                isinstance(node, Loop)
                and len(groups) > 0
                and isinstance(groups[-1][0], Loop)
            ):
                groups[-1].append(node)
            else:
                groups.append([node])

        groups = [sorted(g, key=lambda x: id2idx[id(x)]) for g in groups]
        mapping = [x for g in groups for x in g]

        # Check that all storage-temporal relations are preserved from before
        node2idx = {id(node): i for i, node in enumerate(mapping)}
        prev_node2idx = {id(node): i for i, node in enumerate(self.nodes)}
        for node, node2 in itertools.combinations(mapping, 2):
            idx1 = node2idx[id(node)]
            idx2 = node2idx[id(node2)]
            prev_idx1 = prev_node2idx[id(node)]
            prev_idx2 = prev_node2idx[id(node2)]
            if isinstance(node, TensorHolder) and isinstance(node2, TensorHolder):
                assert (idx1 > idx2) == (prev_idx1 > prev_idx2), "BUG"
            # Because of the reordering above, we may lower loops beneath tensor
            # holders and temporal loops in order to place them as low as
            # possible above the fanout.
            # elif isinstance(node, TensorHolder) and isinstance(node2, Spatial):
            #     assert (idx1 > idx2) == (prev_idx1 > prev_idx2), "BUG"
            # elif isinstance(node, Spatial) and isinstance(node2, TensorHolder):
            #     assert (idx1 > idx2) == (prev_idx1 > prev_idx2), "BUG"
            elif isinstance(node, Spatial) and isinstance(node2, Spatial):
                assert (idx1 > idx2) == (prev_idx1 > prev_idx2), "BUG"

        # Check for spatial/temporal loops that have been reordered. These cannot
        # co-exist because the tiling is inconsistent.
        # Ctrl-F for CONTIGUOUS_ITERATION_SPACE_DISCUSSION
        from accelforge.frontend.workload import isl_expression_has_variable

        node2idx = {id(node): i for i, node in enumerate(self.nodes)}
        for node1, node2 in itertools.combinations(mapping, 2):
            # Both must be loops
            if not isinstance(node1, Loop) or not isinstance(node2, Loop):
                continue
            # Must have been reordered
            if node2idx[id(node1)] <= node2idx[id(node2)]:
                continue
            # Must affect the same rank variable expression
            for expr in indexing_expressions:
                if not isl_expression_has_variable(expr, node1.rank_variable):
                    continue
                if not isl_expression_has_variable(expr, node2.rank_variable):
                    continue

                s = (
                    "In the given mapping, there exists (potentially with other "
                    "nodes in between) a spatial loop above a temporal loop above "
                    "a storage node, where the loops index into the same indexing "
                    "expression, and the storage node is not fanned out by the "
                    "spatial loop. This is not allowed.\n\nMapping:\n"
                )

                to_add = []
                for n in self.nodes:
                    if id(n) in (id(node1), id(node2)):
                        to_add.append(f"\t{n.compact_str()} <-- Offending Loop")
                    else:
                        to_add.append(f"\t{n.compact_str()}")

                raise ValueError(s + "\n".join(to_add))

        return type(self)(nodes=mapping)

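A stand-in sketch of the per-tensor filter at the top of this method: loops and compute nodes are always kept, while holders and reservations are kept only when their tensor matches. Tuples here are hypothetical stand-ins for accelforge nodes:

nodes = [
    ("spatial", "m1"),
    ("holder", "Reg", "input"),
    ("temporal", "m0"),
    ("holder", "Buffer", "output"),
    ("compute", "MAC"),
]

def single_tensor(nodes, tensor):
    return [
        n for n in nodes
        if n[0] not in ("holder", "reservation") or n[-1] == tensor
    ]

print(single_tensor(nodes, "input"))   # drops the Buffer/output holder
print(single_tensor(nodes, "output"))  # drops the Reg/input holder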
class Parallel(Split):
    """
    A :class:`~.Split` where each branch operates at the same time in different
    spatially-organized hardware.
    """

    pass


class Pipeline(Split):
    """
    A :class:`~.Split` where each branch operates at the same time in different
    spatially-organized hardware, with branches forming consecutive pipeline
    stages.
    """

    pass


class Sequential(Split):
    """
    A :class:`~.Split` where branches are processed one-after-another.
    """

    pass


# =============================================================================
# Nodes That May Only be Inserted by the Model
# =============================================================================

class Reservation(MappingNode):
    """A node that reserves a hardware resource for a specific task."""

    purposes: ParsableList[str]
    """ The reasons for reserving the resource. """

    resource: str
    """ The resource being reserved. """

    _backing: Set[str] = set()
    """ Tensors for which this reservation is reserving the tensor's backing
    storage. """

    persistent: bool = False
    """
    Whether this reservation is persistent. Persistent reservations can't be
    tiled and must be kept in backing storage for the full duration of the
    workload's execution.
    """

    @override
    def compact_str(self) -> str:
        return f'{",".join(self.purposes)} reserves {self.resource}'

    def __str__(self, color_map: Optional[ColorMap] = None) -> str:
        purposes = self.purposes
        if color_map is not None:
            format_list = [f"{self.resource} reserved for"] + list(purposes)
            return color_map.format_list(format_list)
        return f'{self.resource} reserved for {",".join(purposes)}'

    def _render_node_shape(self) -> str:
        return "component"

    @property
    def purpose(self) -> str:
        if len(self.purposes) == 1:
            return self.purposes[0]
        raise ValueError(f"Reservation has multiple purposes: {self.purposes}")

    def __eq__(self, other: "Reservation") -> bool:
        return (
            isinstance(other, Reservation)
            and self.purposes == other.purposes
            and self.resource == other.resource
        )

    def _render_node_color(self) -> str:
        return "#E8E8E8"  # Light gray

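A quick sketch of how a Reservation renders, assuming the usual pydantic-style keyword constructor; the field values here are made up:

r = Reservation(purposes=["Fmap1"], resource="GlobalBuffer")
print(r.compact_str())  # Fmap1 reserves GlobalBuffer
print(str(r))           # GlobalBuffer reserved for Fmap1
print(r.purpose)        # Fmap1 (only valid when there is a single purpose)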
# =============================================================================
# Top-level Mapping
# =============================================================================

MappingNodeTypes: TypeAlias = Union[
    Temporal,
    Spatial,
    Storage,
    Pipeline,
    Sequential,
    Compute,
    Reservation,
    # Fill,
    TensorHolder,
]
"""TypeAlias MappingNodeTypes: The types of MappingNodes possible."""


class Mapping(Nested):
    """A Mapping of a workload onto a hardware architecture."""

    # version: Annotated[str, assert_version] = __version__

    _n_loop_orders: int | None = None
    """ Used for counting the number of unique mappings. Do not touch. """

    def remove_reservations(self):
        self.nodes = [n for n in self.nodes if not isinstance(n, Reservation)]

    def split_loop_with_multiple_rank_variables(self):
        new_nodes = []
        for node in self.nodes:
            if isinstance(node, Loop) and isinstance(node.rank_variable, set):
                # Expand a loop over a set of rank variables into one copy per
                # rank variable.
                for rank_variable in node.rank_variable:
                    new_node = copy.copy(node)
                    new_node.rank_variable = rank_variable
                    new_nodes.append(new_node)
            else:
                new_nodes.append(node)
        self.nodes = new_nodes

    def split_tensor_holders_with_multiple_tensors(self):
        new_nodes = []
        for node in self.nodes:
            if isinstance(node, TensorHolder) and len(node.tensors) > 1:
                for tensor in node.tensors:
                    new_node = copy.copy(node)
                    new_node.tensors = [tensor]
                    new_nodes.append(new_node)
            else:
                new_nodes.append(node)
        self.nodes = new_nodes
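Both split methods above follow the same pattern: a node carrying several values becomes one shallow copy per value. A standalone sketch with a stand-in class (not the real TensorHolder):

import copy

class Holder:  # hypothetical stand-in
    def __init__(self, tensors):
        self.tensors = tensors

nodes = [Holder(["A", "B"]), Holder(["C"])]
new_nodes = []
for node in nodes:
    if len(node.tensors) > 1:
        for tensor in node.tensors:
            n = copy.copy(node)
            n.tensors = [tensor]  # rebind, so the copies don't share the list
            new_nodes.append(n)
    else:
        new_nodes.append(node)

print([n.tensors for n in new_nodes])  # [['A'], ['B'], ['C']]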
    def _get_fused_slice(self, fusable_tensors: set[TensorName]) -> "Mapping":
        """
        Return a mapping with:

        - All backing reservation nodes for intermediate tensors
        - Loop nodes above any backing reservation nodes
        """
        # All intermediate tensors that can be found in this mapping.
        # Note: `fusable_tensors` may cover the **whole workload**.
        relevant_intermediate_tensors = set()
        for node in self.nodes:
            if isinstance(node, Reservation):
                if node.purpose in fusable_tensors:
                    relevant_intermediate_tensors.add(node.purpose)

        fused_slice = Mapping(nodes=[])
        to_add = []
        for node in self.nodes:
            node = copy.copy(node)
            if isinstance(node, Reservation):
                if node.purpose not in relevant_intermediate_tensors:
                    continue
                fused_slice.nodes.extend(to_add + [node])
                to_add = []
                relevant_intermediate_tensors.remove(node.purpose)
                if len(relevant_intermediate_tensors) == 0:
                    break
            elif isinstance(node, Loop):
                to_add.append(node)
        return fused_slice
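The walk above buffers loops until the next relevant reservation and emits them together; irrelevant reservations are skipped without flushing the buffer. A stand-in sketch with strings in place of nodes:

nodes = ["loop m1", "reserve T0", "loop m0", "reserve T9", "loop k", "reserve T1"]
fusable = {"T0", "T1"}

out, pending = [], []
for n in nodes:
    if n.startswith("reserve"):
        tensor = n.split()[1]
        if tensor in fusable:
            out += pending + [n]
            pending = []
    else:
        pending.append(n)

print(out)
# ['loop m1', 'reserve T0', 'loop m0', 'loop k', 'reserve T1']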
    @property
    def loops(self) -> list[Loop]:
        """Returns all :class:`~.Loop` nodes in the Mapping."""
        return self.get_nodes_of_type(Loop)

    def _render_node_label(self, **kwargs) -> str:
        return "Root"

    def _repr_svg_(self) -> str:
        return self.render()

    def render_pydot(self, with_reservations=True, with_tile_shape=True) -> pydot.Dot:
        """Renders the mapping as a Pydot graph. Returns a :class:`pydot.Dot`."""
        graph = _pydot_graph()
        # Enable HTML-like labels for color support
        graph.set_node_defaults(label="")
        if not with_reservations:
            exclude_types = (Reservation,)
        else:
            exclude_types = tuple()
        for node in self._render_make_children(
            exclude_types=exclude_types, with_tile_shape=with_tile_shape
        ):
            graph.add_node(node)

        color_keys = set()
        all_nodes = self._flatten()
        for node in all_nodes:
            if isinstance(node, TensorHolder):
                color_keys.update(node.tensors)
            if isinstance(node, Reservation):
                color_keys.update(node.purposes)

        color_map = ColorMap(sorted(color_keys))

        for node in all_nodes:
            if isinstance(node, (TensorHolder, Reservation)):
                graph_nodes = graph.get_node(node._render_node_name())
                for graph_node in graph_nodes:
                    # Set HTML-like label for color support
                    new_label = node.__str__(color_map)
                    graph_node.set_label(new_label)
                    # graph_node.set_fillcolor(color_map[node._color_key()])
                    # graph_node.set_style('filled')

        added_edges = set()
        child2included_parent = {}
        for parent, child in self._parent2child(None):
            parent_name = parent._render_node_name() if parent is not None else None
            child_name = child._render_node_name()
            # If the parent is excluded, connect the child to the closest
            # included ancestor instead.
            if isinstance(parent, exclude_types):
                parent_name = child2included_parent.get(parent_name, None)
            child2included_parent[child_name] = parent_name
            if not isinstance(child, exclude_types):
                added_edges.add((parent_name, child_name))
        for parent_name, child_name in added_edges:
            if parent_name is not None:
                graph.add_edge(pydot.Edge(parent_name, child_name))
        return graph

    def render(self) -> _SVGJupyterRender:
        graph = self.render_pydot()
        return _SVGJupyterRender(graph.create_svg(prog="dot").decode("utf-8"))

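For example, in a notebook one might render a mapping without reservation nodes and save it to disk. Here `mapping` is a hypothetical Mapping instance, and `write_svg`/`create_svg` are pydot's generated per-format methods:

graph = mapping.render_pydot(with_reservations=False)
graph.write_svg("mapping.svg")                            # save to a file
svg_text = graph.create_svg(prog="dot").decode("utf-8")   # raw SVG string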
    @classmethod
    def _from_pmappings(
        cls,
        pmappings: list[Nested],
        rank_variable_bounds: Optional[dict[str, dict[str, int]]] = None,
    ) -> "Mapping":
        pmappings = list(copy.deepcopy(pmappings))
        for pmapping in pmappings:
            pmapping._beautify_loops(rank_variable_bounds)

        # Greedily merge the pmapping that shares the most outer loops with a
        # later pmapping into its immediate successor, until one remains.
        while len(pmappings) > 1:
            highest_n_shared_loops = 0
            highest_shared_pmapping_index = 0
            for i, pmapping in enumerate(pmappings):
                shared_index = 0
                for j in range(i + 1, len(pmappings)):
                    shared_index = max(
                        pmapping._get_n_shared_loops(pmappings[j]), shared_index
                    )
                if shared_index > highest_n_shared_loops:
                    highest_n_shared_loops = shared_index
                    highest_shared_pmapping_index = i

            pmappings[highest_shared_pmapping_index] = pmappings[
                highest_shared_pmapping_index
            ]._merge(
                pmappings.pop(highest_shared_pmapping_index + 1),
                0 if _NO_JOIN_MAPPING_VISUALIZATION else highest_n_shared_loops,
            )

        mapping: Mapping = cls(nodes=pmappings[0].nodes)
        mapping._elevate_persistent_nodes_above_splits()
        mapping._elevate_tensor_holders_above_splits()
        mapping._propagate_reservations_between_splits()
        mapping._consolidate_tensor_holders()
        mapping._consolidate_reservations()
        mapping._move_tensor_holders_above_reservations()
        mapping._remove_reservations_for_processing_stages()
        return mapping
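The join order is driven by how many outer loops two pmappings share. A stand-in illustration of that count, with lists of loop names in place of pmappings (the real `_get_n_shared_loops` compares actual Loop nodes):

def n_shared(a: list[str], b: list[str]) -> int:
    n = 0
    while n < min(len(a), len(b)) and a[n] == b[n]:
        n += 1
    return n

pmappings = [["m1", "k"], ["m1", "n"], ["p"]]
# Pmappings 0 and 1 share one outer loop; 1 and 2 share none, so the
# first pair is joined first.
print(n_shared(pmappings[0], pmappings[1]))  # 1
print(n_shared(pmappings[1], pmappings[2]))  # 0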

Split.model_rebuild()
Nested.model_rebuild()
Mapping.model_rebuild()