accelforge 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- accelforge/__init__.py +21 -0
- accelforge/_accelerated_imports.py +16 -0
- accelforge/_deprecate/_simanneal/evalmapping.py +271 -0
- accelforge/_deprecate/_simanneal/mapspaceglobals.py +298 -0
- accelforge/_deprecate/_simanneal/simanneal.py +666 -0
- accelforge/_deprecate/_simanneal/tracking.py +105 -0
- accelforge/_deprecate/_simanneal/wrappers.py +218 -0
- accelforge/_deprecate/_simanneal2/__init__.py +7 -0
- accelforge/_deprecate/_simanneal2/simanneal.py +493 -0
- accelforge/_deprecate/_simanneal2/tracking.py +116 -0
- accelforge/_deprecate/compatibility_util.py +181 -0
- accelforge/_deprecate/layerdeduplication/__init__.py +2 -0
- accelforge/_deprecate/layerdeduplication/group_similar_einsums.py +160 -0
- accelforge/_deprecate/layerdeduplication/grouped_einsums.py +84 -0
- accelforge/_deprecate/mapping_filter_tags/__init__.py +2 -0
- accelforge/_deprecate/mapping_filter_tags/ffmt.py +212 -0
- accelforge/_deprecate/mapping_filter_tags/onesplit.py +24 -0
- accelforge/_deprecate/mapping_filter_tags/util.py +24 -0
- accelforge/_deprecate/tags.py +69 -0
- accelforge/_deprecate/viz/__init__.py +0 -0
- accelforge/_deprecate/viz/interactive.py +159 -0
- accelforge/_deprecate/viz/reservationtree.py +307 -0
- accelforge/_deprecate/viz/ski_slope.py +88 -0
- accelforge/_version.py +15 -0
- accelforge/examples.py +39 -0
- accelforge/frontend/__init__.py +10 -0
- accelforge/frontend/_binding.py +129 -0
- accelforge/frontend/_workload_isl/__init__.py +2 -0
- accelforge/frontend/_workload_isl/_isl.py +149 -0
- accelforge/frontend/_workload_isl/_symbolic.py +141 -0
- accelforge/frontend/arch copy.py +1544 -0
- accelforge/frontend/arch.py +1642 -0
- accelforge/frontend/config.py +63 -0
- accelforge/frontend/mapper/__init__.py +5 -0
- accelforge/frontend/mapper/ffm.py +126 -0
- accelforge/frontend/mapper/mapper.py +7 -0
- accelforge/frontend/mapper/metrics.py +30 -0
- accelforge/frontend/mapping/__init__.py +1 -0
- accelforge/frontend/mapping/mapping.py +1736 -0
- accelforge/frontend/model.py +14 -0
- accelforge/frontend/renames.py +150 -0
- accelforge/frontend/spec copy.py +230 -0
- accelforge/frontend/spec.py +301 -0
- accelforge/frontend/variables.py +12 -0
- accelforge/frontend/workload.py +952 -0
- accelforge/mapper/FFM/__init__.py +9 -0
- accelforge/mapper/FFM/_join_pmappings/__init__.py +0 -0
- accelforge/mapper/FFM/_join_pmappings/compatibility.py +653 -0
- accelforge/mapper/FFM/_join_pmappings/compress_pmappings.py +140 -0
- accelforge/mapper/FFM/_join_pmappings/join_pmappings.py +703 -0
- accelforge/mapper/FFM/_join_pmappings/pmapping_dataframe.py +901 -0
- accelforge/mapper/FFM/_join_pmappings/pmapping_group.py +337 -0
- accelforge/mapper/FFM/_make_pmappings/contraints/__init__.py +0 -0
- accelforge/mapper/FFM/_make_pmappings/contraints/constraints.py +360 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/__init__.py +1 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_loops.py +373 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_pmapping_templates.py +463 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_reservations.py +95 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_storage_order.py +382 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_storages.py +155 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmappings.py +411 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/__init__.py +1 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/make_pmappings_from_templates.py +407 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/make_tile_shapes.py +1681 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/run_model.py +170 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/symbol_relations.py +174 -0
- accelforge/mapper/FFM/_make_pmappings/pmapper_job.py +282 -0
- accelforge/mapper/FFM/_pareto_df/df_convention.py +273 -0
- accelforge/mapper/FFM/_pareto_df/pareto copy.py +836 -0
- accelforge/mapper/FFM/_pareto_df/pareto.py +508 -0
- accelforge/mapper/FFM/data.py +61 -0
- accelforge/mapper/FFM/main copy.py +236 -0
- accelforge/mapper/FFM/main.py +208 -0
- accelforge/mapper/FFM/mappings.py +510 -0
- accelforge/mapper/FFM/pmappings.py +310 -0
- accelforge/mapper/__init__.py +4 -0
- accelforge/mapper.py +0 -0
- accelforge/model/__init__.py +1 -0
- accelforge/model/_looptree/__init__.py +0 -0
- accelforge/model/_looptree/accesses.py +335 -0
- accelforge/model/_looptree/capacity/__init__.py +1 -0
- accelforge/model/_looptree/capacity/aggregators.py +36 -0
- accelforge/model/_looptree/capacity/capacity.py +47 -0
- accelforge/model/_looptree/energy.py +150 -0
- accelforge/model/_looptree/equivalent_ranks.py +29 -0
- accelforge/model/_looptree/latency/__init__.py +1 -0
- accelforge/model/_looptree/latency/latency.py +98 -0
- accelforge/model/_looptree/latency/memory.py +120 -0
- accelforge/model/_looptree/latency/processors.py +92 -0
- accelforge/model/_looptree/mapping_utilities.py +71 -0
- accelforge/model/_looptree/reuse/__init__.py +4 -0
- accelforge/model/_looptree/reuse/isl/__init__.py +1 -0
- accelforge/model/_looptree/reuse/isl/des.py +59 -0
- accelforge/model/_looptree/reuse/isl/isl_functions.py +374 -0
- accelforge/model/_looptree/reuse/isl/mapping_to_isl/__init__.py +4 -0
- accelforge/model/_looptree/reuse/isl/mapping_to_isl/analyze_mapping.py +297 -0
- accelforge/model/_looptree/reuse/isl/mapping_to_isl/skews_from_mapping.py +236 -0
- accelforge/model/_looptree/reuse/isl/mapping_to_isl/tiling.py +685 -0
- accelforge/model/_looptree/reuse/isl/mapping_to_isl/types.py +188 -0
- accelforge/model/_looptree/reuse/isl/spatial.py +260 -0
- accelforge/model/_looptree/reuse/isl/temporal.py +182 -0
- accelforge/model/_looptree/reuse/symbolic/__init__.py +1 -0
- accelforge/model/_looptree/reuse/symbolic/symbolic copy 2.py +1346 -0
- accelforge/model/_looptree/reuse/symbolic/symbolic copy.py +1408 -0
- accelforge/model/_looptree/reuse/symbolic/symbolic.py +1396 -0
- accelforge/model/_looptree/run.py +122 -0
- accelforge/model/_looptree/types.py +26 -0
- accelforge/model/_looptree/visualization/__init__.py +0 -0
- accelforge/model/_looptree/visualization/occupancy.py +11 -0
- accelforge/model/main.py +222 -0
- accelforge/plotting/__init__.py +2 -0
- accelforge/plotting/mappings.py +219 -0
- accelforge/plotting/specs.py +57 -0
- accelforge/util/__init__.py +4 -0
- accelforge/util/_base_analysis_types.py +24 -0
- accelforge/util/_basetypes.py +1089 -0
- accelforge/util/_frozenset.py +36 -0
- accelforge/util/_isl.py +29 -0
- accelforge/util/_itertools.py +14 -0
- accelforge/util/_mathfuncs.py +57 -0
- accelforge/util/_parse_expressions.py +339 -0
- accelforge/util/_picklecache.py +32 -0
- accelforge/util/_setexpressions.py +268 -0
- accelforge/util/_sympy/__init__.py +0 -0
- accelforge/util/_sympy/broadcast_max.py +18 -0
- accelforge/util/_visualization.py +112 -0
- accelforge/util/_yaml.py +579 -0
- accelforge/util/parallel.py +193 -0
- accelforge-0.0.1.dist-info/METADATA +64 -0
- accelforge-0.0.1.dist-info/RECORD +258 -0
- accelforge-0.0.1.dist-info/WHEEL +5 -0
- accelforge-0.0.1.dist-info/licenses/LICENSE +19 -0
- accelforge-0.0.1.dist-info/top_level.txt +5 -0
- docs/_build/html/_sources/fastfusion.frontend.mapper.rst.txt +37 -0
- docs/_build/html/_sources/fastfusion.frontend.rst.txt +70 -0
- docs/_build/html/_sources/fastfusion.frontend.workload.rst.txt +21 -0
- docs/_build/html/_sources/fastfusion.mapper.FFM.rst.txt +37 -0
- docs/_build/html/_sources/fastfusion.mapper.rst.txt +18 -0
- docs/_build/html/_sources/fastfusion.rst.txt +20 -0
- docs/_build/html/_sources/fastfusion.util.rst.txt +21 -0
- docs/_build/html/_sources/index.rst.txt +87 -0
- docs/_build/html/_sources/modules.rst.txt +7 -0
- docs/_build/html/_sources/notes/citation.rst.txt +45 -0
- docs/_build/html/_sources/notes/definitions.rst.txt +43 -0
- docs/_build/html/_sources/notes/faqs.rst.txt +39 -0
- docs/_build/html/_sources/notes/modeling/accelerator_energy_latency.rst.txt +72 -0
- docs/_build/html/_sources/notes/modeling/component_energy_area.rst.txt +96 -0
- docs/_build/html/_sources/notes/modeling/mapping.rst.txt +100 -0
- docs/_build/html/_sources/notes/modeling.rst.txt +33 -0
- docs/_build/html/_sources/notes/parsing/arithmetic_parsing.rst.txt +136 -0
- docs/_build/html/_sources/notes/parsing/setexpressions.rst.txt +63 -0
- docs/_build/html/_sources/notes/parsing/yaml_parsing.rst.txt +176 -0
- docs/_build/html/_sources/notes/quickstart_and_installation.rst.txt +9 -0
- docs/_build/html/_sources/notes/spec/architecture.rst.txt +133 -0
- docs/_build/html/_sources/notes/spec/mapping.rst.txt +12 -0
- docs/_build/html/_sources/notes/spec/workload.rst.txt +83 -0
- docs/_build/html/_sources/notes/spec.rst.txt +36 -0
- docs/source/_ext/include_attrs.py +213 -0
- docs/source/_ext/include_docstring.py +364 -0
- docs/source/_ext/include_functions.py +154 -0
- docs/source/_ext/include_notebook.py +131 -0
- docs/source/_ext/include_yaml.py +119 -0
- docs/source/_ext/inherited_attributes.py +222 -0
- docs/source/_ext/paths.py +4 -0
- docs/source/conf.py +79 -0
- examples/arches/compute_in_memory/_include.yaml +74 -0
- examples/arches/compute_in_memory/_include_functions.py +229 -0
- examples/arches/compute_in_memory/_load_spec.py +57 -0
- examples/arches/compute_in_memory/components/c2c_multiplier.py +181 -0
- examples/arches/compute_in_memory/components/dac_c2c_r2r.py +605 -0
- examples/arches/compute_in_memory/components/misc.py +195 -0
- examples/arches/compute_in_memory/components/util/bit_functions.py +51 -0
- examples/arches/compute_in_memory/components/zero_comparator.py +92 -0
- examples/arches/compute_in_memory/isaac.yaml +233 -0
- examples/arches/compute_in_memory/memory_cells/ecram_demo.yaml +63 -0
- examples/arches/compute_in_memory/memory_cells/rram_example.yaml +63 -0
- examples/arches/compute_in_memory/memory_cells/rram_isaac_isca_2016.yaml +64 -0
- examples/arches/compute_in_memory/memory_cells/rram_neurosim_default.yaml +63 -0
- examples/arches/compute_in_memory/memory_cells/rram_raella_isca_2023.yaml +70 -0
- examples/arches/compute_in_memory/memory_cells/rram_wan_nature_2022.yaml +63 -0
- examples/arches/compute_in_memory/memory_cells/sram_colonnade_jssc_2021.yaml +63 -0
- examples/arches/compute_in_memory/memory_cells/sram_example.yaml +63 -0
- examples/arches/compute_in_memory/memory_cells/sram_jia_jssc_2020.yaml +63 -0
- examples/arches/compute_in_memory/memory_cells/sram_sinangil_jssc_2021.yaml +63 -0
- examples/arches/compute_in_memory/memory_cells/sram_wang_vlsi_2022.yaml +63 -0
- examples/arches/compute_in_memory/wang_vlsi_2022.yaml +289 -0
- examples/arches/eyeriss.yaml +68 -0
- examples/arches/fanout_variations/at_glb.yaml +31 -0
- examples/arches/fanout_variations/at_glb_with_fanout_node.yaml +34 -0
- examples/arches/fanout_variations/at_mac.yaml +31 -0
- examples/arches/fanout_variations/at_mac_with_constraints.yaml +38 -0
- examples/arches/fanout_variations/at_mac_with_fanout_node.yaml +34 -0
- examples/arches/nvdla.yaml +47 -0
- examples/arches/simple.yaml +28 -0
- examples/arches/tpu_v4i.yaml +67 -0
- examples/mappings/unfused_matmuls_to_simple.yaml +33 -0
- examples/misc/component_annotated.yaml +33 -0
- examples/workloads/gpt3_6.7B.yaml +124 -0
- examples/workloads/matmuls.yaml +20 -0
- examples/workloads/mobilenet_28.yaml +81 -0
- examples/workloads/mobilenet_various_separate.yaml +106 -0
- examples/workloads/three_matmuls_annotated.yaml +59 -0
- notebooks/.ipynb_checkpoints/fastfusion_arch_study_michael-checkpoint.ipynb +359 -0
- notebooks/compute_in_memory/_scripts.py +339 -0
- notebooks/compute_in_memory/isaac.guide.ipynb +270 -0
- notebooks/compute_in_memory/wang_vlsi_2022.ipynb +602 -0
- notebooks/paths.py +4 -0
- notebooks/tutorials/.ipynb_checkpoints/1_FFM-checkpoint.ipynb +3110 -0
- notebooks/tutorials/FFM.ipynb +3498 -0
- notebooks/tutorials/_include.py +48 -0
- notebooks/tutorials/component_energy_area.ipynb +363 -0
- tests/Q_mapping.yaml +38 -0
- tests/__init__.py +0 -0
- tests/conv.mapping.yaml +27 -0
- tests/conv.workload.yaml +13 -0
- tests/conv_sym.mapping.yaml +43 -0
- tests/copy.mapping.yaml +35 -0
- tests/copy.workload.yaml +15 -0
- tests/distribuffers/__init__.py +0 -0
- tests/distribuffers/multicast/test_cases.yaml +482 -0
- tests/distribuffers/spec/binding/valid_bindings.yaml +97 -0
- tests/distribuffers/spec/distributed.yaml +100 -0
- tests/distribuffers/spec/logical_arch.yaml +32 -0
- tests/distribuffers/spec/physical_arch.yaml +69 -0
- tests/distribuffers/test_binding.py +48 -0
- tests/frontend/__init__.py +0 -0
- tests/frontend/test_mapping_viz.py +52 -0
- tests/mapper/__init__.py +0 -0
- tests/mapper/configs/conv1d/conv1d.mapping.yaml +31 -0
- tests/mapper/configs/conv1d/conv1d.workload.yaml +11 -0
- tests/mapper/configs/two_conv1d/two_conv1d.expected.yaml +38 -0
- tests/mapper/configs/two_conv1d/two_conv1d.mapping.yaml +54 -0
- tests/mapper/configs/two_conv1d/two_conv1d.workload.yaml +19 -0
- tests/mapper/test_mapping_to_isl.py +90 -0
- tests/mapper/test_spatial_reuse_analysis.py +67 -0
- tests/mapper/test_temporal_reuse_analysis.py +56 -0
- tests/mapper/util.py +58 -0
- tests/matmul.mapping.yaml +29 -0
- tests/matmul.workload.yaml +12 -0
- tests/matmul_spatial.mapping.yaml +44 -0
- tests/mha.renames.yaml +65 -0
- tests/mha.workload.yaml +67 -0
- tests/mha.yaml +59 -0
- tests/mha_full.workload.yaml +67 -0
- tests/mobilenet.workload.yaml +35 -0
- tests/mobilenet_long.workload.yaml +64 -0
- tests/pmappingcache.py +24 -0
- tests/processing_stage.arch.yaml +40 -0
- tests/snowcat.arch.yaml +36 -0
- tests/test_ffm_join_pmappings.py +106 -0
- tests/test_ffm_make_pmappings.py +82 -0
- tests/test_ffm_make_tile_shapes.py +49 -0
- tests/test_mapper.py +100 -0
- tests/test_model.py +37 -0
- tests/test_plotting.py +72 -0
- tests/test_processing_stage.py +46 -0
- tests/test_symbolic_model.py +248 -0
- tests/test_workload.py +141 -0
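
The listing above is the full manifest of the wheel. As a quick local sanity check (not part of the package itself), the same file list can be reproduced with the standard-library zipfile module, assuming the wheel has been downloaded to the working directory:

# Hypothetical local check, not part of the accelforge distribution:
# a wheel is a zip archive, so its manifest can be listed directly.
import zipfile

with zipfile.ZipFile("accelforge-0.0.1-py3-none-any.whl") as wheel:
    for name in wheel.namelist():
        print(name)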
accelforge/mapper/FFM/__init__.py
@@ -0,0 +1,9 @@
from accelforge.mapper.FFM.main import (
    map_workload_to_arch,
    make_pmappings,
    join_pmappings,
    MultiEinsumPmappings,
    Mappings,
)
from accelforge.frontend.mapper.metrics import Metrics
from accelforge.mapper.FFM._join_pmappings.pmapping_group import PmappingGroup
accelforge/mapper/FFM/_join_pmappings/__init__.py
File without changes
accelforge/mapper/FFM/_join_pmappings/compatibility.py
@@ -0,0 +1,653 @@
from collections import defaultdict
from enum import Enum
from dataclasses import dataclass, replace
import itertools
from numbers import Number
from typing import Literal, TypeVar

from accelforge.frontend.workload import TensorAccess, Workload
from accelforge.frontend.mapping import (
    Compute,
    Loop,
    Mapping,
    Spatial,
    TensorHolder,
    Reservation as MappingReservation,
    Split as MappingSplit,
    TilePattern,
    Loop as MappingLoop,
)
from accelforge.frontend.renames import Rank, RankVariable, TensorName
from accelforge.mapper.FFM._pareto_df.df_convention import (
    make_fused_loop_col,
    stride2col,
    initial2col,
    iterations2col,
)

from accelforge.util import _expfmt, fzs

# Abstractions:
# 1. Each tensor is stored above some loop index. 0 is the outermost loop, 1 the
#    next-innermost...
# 2. All loops above any shared tensor are co-tiled and must match between PmappingGroups.

T = TypeVar("T", bound="Updatable")


class Updatable:
    def update(self: T, **kwargs) -> T:
        return replace(self, **kwargs)


def _update_rename_dict(
    renames: dict[str, str],
    new_renames: dict[str, str],
):
    for mine, other in new_renames.items():
        if mine not in renames:
            renames[mine] = other
        elif renames[mine] != other:
            raise ValueError(
                f"Renaming {mine} to {other} conflicts with {renames[mine]}"
            )


@dataclass(frozen=True, order=True, eq=True)
class Loop(Updatable):
    rank_name: Rank
    tile_pattern: TilePattern | None
    is_spatial: bool

    def __post_init__(self):
        assert isinstance(self.rank_name, Rank)
        assert isinstance(self.tile_pattern, Number | TilePattern | str | None)
        assert isinstance(self.is_spatial, bool)
        assert isinstance(
            self.tile_pattern.initial_tile_shape,
            Number | str | None,
        )
        assert isinstance(
            self.tile_pattern.tile_shape,
            Number | str | None,
        )

    def __repr__(self):
        return (
            f"Loop({self.rank_name.__repr__()}, {self.tile_pattern}, {self.is_spatial})"
        )

    def __str__(self):
        return (
            "S-" if self.is_spatial else ""
        ) + f"{self.rank_name}-{self.tile_pattern}"

    def pydot_str(self):
        if self.is_spatial:
            return f"S-for R{self.rank_name} size {_expfmt(self.tile_pattern)}"
        return f"for {self.rank_name} size {_expfmt(self.tile_pattern)}"

    def to_yaml(self):
        return {"type": "loop", **self.__dict__}

    def merge_next(self, right: "Loop") -> "Loop":
        assert self.tile_pattern == right.tile_pattern
        return Loop(
            self.rank_name | right.rank_name,
            right.tile_pattern,
            self.is_spatial,
        )

    def clear_loop_bound(self, value=0):
        return self.update(tile_pattern=value)

    def populate(self, nloop: int) -> "Loop":
        tile_pattern = TilePattern(
            tile_shape=stride2col(self.rank_name, nloop),
            initial_tile_shape=initial2col(self.rank_name, nloop),
            calculated_n_iterations=iterations2col(nloop),
        )
        return self.update(tile_pattern=tile_pattern)

    def _prepend_symbols(self, prepend: str) -> "Loop":
        return self.update(tile_pattern=self.tile_pattern._prepend_symbols(prepend))

    def clear_symbolic_tile_patterns(self) -> "Loop":
        return self.update(tile_pattern=self.tile_pattern._clear_symbols())

    def make_fused_loop_symbols(self, prefix: str) -> tuple[dict[str, str], "Loop"]:
        r = {}
        new = self

        def replace(attr, new):
            g = getattr(self.tile_pattern, attr)
            if not isinstance(g, str):
                return new
            g2 = make_fused_loop_col(f"{prefix}<SEP>{g}")
            r[g] = g2
            return new.update(tile_pattern=new.tile_pattern.update(**{attr: g2}))

        for s in new.tile_pattern._symbol_attrs():
            new = replace(s, new)

        return r, new

    def add_n_iteration_symbols(self) -> "Loop":
        return self.update(tile_pattern=self.tile_pattern.add_n_iteration_symbols())

    def _rename_to_match(self, other: "Loop") -> tuple["Loop", dict[str, str]]:
        new_tp, renames = self.tile_pattern._rename_to_match(other.tile_pattern)
        return self.update(rank_name=other.rank_name, tile_pattern=new_tp), renames


@dataclass(frozen=True, eq=True, order=True)
class TensorReservation(Updatable):
    # This order is important. Above loop index should be before resource name
    # so when we sort reservations for tensors then the backing tensor holder comes
    # first.
    # Size is not included in hash or equality functions. This is because there
    # may be floating point rounding errors in reservation sizes. The other
    # attributes are sufficient to determine equality.
    loops: tuple[Loop]
    name: TensorName
    resource_name: str
    persistent: bool = False

    def __post_init__(self):
        if self.persistent:
            assert len(self.loops) == 0, "Persistent tensors be above all loops"

    @property
    def above_loop_index(self) -> int:
        return -1 if self.persistent else len(self.loops)

    def __str__(self):
        return f"[{self.resource_name}] {self.name} below {self.loops}"

    def __repr__(self):
        return f"Reservation({repr(self.name)}, {repr(self.loops)}, {repr(self.resource_name)})"

    def pydot_str(self):
        return f"[{self.resource_name}] {self.name}"

    def permute(self, permutation) -> "Reservation":
        new_loops = [self.loops[permutation[i]] for i in range(len(self.loops))]
        return self.update(loops=tuple(new_loops))

    def clear_loop_bounds(self) -> "Reservation":
        return self.update(loops=tuple(loop.clear_loop_bound() for loop in self.loops))

    def populate_loops(self) -> "TensorReservation":
        return self.update(
            loops=tuple(loop.populate(nloop) for nloop, loop in enumerate(self.loops))
        )

    @staticmethod
    def get_backing_tensors(
        all_tensors: set["TensorReservation"],
    ) -> list["TensorReservation"]:
        id2tensor = defaultdict(lambda: [])
        for t in all_tensors:
            id2tensor[t.name].append(t)
        return sorted(sorted(v)[0] for v in id2tensor.values())

    def drop_loop_indices(self, loop_indices: set[int]) -> "TensorReservation":
        loops = tuple(l for i, l in enumerate(self.loops) if i not in loop_indices)
        return self.update(loops=loops)

    def _prepend_symbols(self, prepend: str) -> "TensorReservation":
        return self.update(
            loops=tuple(l._prepend_symbols(prepend) for l in self.loops),
        )

    def clear_symbolic_tile_patterns(self) -> "TensorReservation":
        return self.update(
            loops=tuple(l.clear_symbolic_tile_patterns() for l in self.loops),
        )

    def make_fused_loop_symbols(
        self, prefix: str
    ) -> tuple[dict[str, str], "TensorReservation"]:
        result = {}
        loops = []
        for l in self.loops:
            r, l = l.make_fused_loop_symbols(prefix)
            result.update(r)
            loops.append(l)
        return result, self.update(loops=tuple(loops))

    def add_n_iteration_symbols(self) -> "TensorReservation":
        return self.update(
            loops=tuple(l.add_n_iteration_symbols() for l in self.loops),
        )

    def _rename_to_match(
        self, other: "TensorReservation"
    ) -> tuple["TensorReservation", dict[str, str]]:
        renames = {}
        new_loops = []
        for l_mine, l_other in zip(self.loops, other.loops):
            l_mine, new_renames = l_mine._rename_to_match(l_other)
            _update_rename_dict(renames, new_renames)
            new_loops.append(l_mine)
        return self.update(loops=tuple(new_loops)), renames


class SplitKind(Enum):
    SEQUENTIAL = 0
    PIPELINE = 1


@dataclass(frozen=True, order=True, eq=True)
class Split:
    split: MappingSplit
    above_loop_index: int


@dataclass(frozen=True)
class Compatibility(Updatable):
    tensors: fzs[TensorReservation]
    splits: fzs[Split] = fzs()
    reservation_indices: fzs[int] = fzs()
    check_reservation_indices: bool = True

    @property
    def n_loops(self) -> int:
        return max([len(s.loops) for s in self.tensors], default=0)

    @property
    def loops(self) -> tuple[Loop, ...]:
        return max([t.loops for t in self.tensors], key=len) if self.tensors else ()

    def _get_hash_tuple(self):
        return self.n_loops, self.tensors, self.reservation_indices

    def __hash__(self):
        return hash(self._get_hash_tuple())

    def __eq__(self, other):
        return self._get_hash_tuple() == other._get_hash_tuple()

    def __post_init__(self):
        assert isinstance(self.n_loops, int)
        assert isinstance(self.tensors, fzs)
        assert isinstance(self.splits, fzs)
        assert isinstance(self.reservation_indices, fzs)
        assert (
            max(self.reservation_indices, default=-1) <= self.n_loops
        ), f"Extra reservation indices {self.reservation_indices} are greater than n_loops {self.n_loops}"
        if self.check_reservation_indices:
            p = f"are not in reservation indices {self.reservation_indices}"
            assert all(
                i >= 0 for i in self.reservation_indices
            ), f"Reservation indices {self.reservation_indices} are not all >= 0"
            assert all(
                s.above_loop_index in self.reservation_indices for s in self.splits
            ), f"Split above loop indices {self.splits} {p}"
            assert all(
                len(s.loops) in self.reservation_indices for s in self.tensors
            ), f"Tensor loops {self.tensors} {p}"

    def get_backing_levels(self) -> dict[str, int]:
        backings = {}
        for t in self.tensors:
            prev = backings.get(t.name, t.above_loop_index)
            backings[t.name] = min(prev, t.above_loop_index)
        return backings

    @property
    def tensor_names(self) -> set[str]:
        return {t.name for t in self.tensors}

    @property
    def max_above_loop_index(self) -> int:
        if len(self.tensors) == 0:
            return 0
        return max(s.above_loop_index for s in self.tensors)

    def shared_loop_index(self, live_tensors: set[str]) -> int:
        n = [l for t, l in self.get_backing_levels().items() if t in live_tensors]
        return max(n) - 1 if n else -1

    def __len__(self) -> int:
        return self.max_above_loop_index

    def _rename_to_match(
        self, other: "Compatibility"
    ) -> tuple["Compatibility", dict[str, str]]:
        renames = {}
        assert (
            self.clear_symbolic_tile_patterns() == other.clear_symbolic_tile_patterns()
        )
        tensors = []
        for t in self.tensors:
            other_t = other.get_tensor_by_name(t.name)
            t, new_renames = t._rename_to_match(other_t)
            tensors.append(t)
            _update_rename_dict(renames, new_renames)

        return (
            Compatibility(
                tensors=fzs(tensors),
                splits=self.splits,
                reservation_indices=self.reservation_indices,
            ),
            renames,
        )

    def clear_dead_tensors(
        self,
        live_tensors: set[str] | Literal["All"],
    ) -> "Compatibility":
        """
        Return a new compatibility with "dead" tensors removed by:
        1. keeping only loops relevant to `live_tensors` and
        2. keeping only `live_tensors`.

        If `keep_loops` is `True`, then all loops are kept.
        If `keep_tensors` is a set, tensors in the set are kept.
        """
        if live_tensors == "All":
            live_tensors = self.tensor_names

        remaining_tensors = fzs(s for s in self.tensors if s.name in live_tensors)
        new_n_loops = max((len(s.loops) for s in remaining_tensors), default=0)
        new_splits = fzs(
            split for split in self.splits if split.above_loop_index < new_n_loops
        )
        reservation_indices = fzs(
            {min(i, new_n_loops) for i in self.reservation_indices}
        )
        reservation_indices = fzs(x for x in reservation_indices if x >= 0)

        return self.update(
            tensors=remaining_tensors,
            splits=new_splits,
            reservation_indices=reservation_indices,
        )

    def __lt__(self, other):
        return self._get_hash_tuple() < other._get_hash_tuple()

    def __str__(self):
        return self.__repr__()

    def __repr__(self):
        return f"Compatibility(n_loops={self.n_loops}, tensors={repr(self.tensors)}), splits={repr(self.splits)}"

    def _and_tensors_with_names(self, names: set[str]) -> "Compatibility":
        return fzs(s for s in self.tensors if s.name in names)

    def merge_next(
        self,
        right: "Compatibility",
        live_tensors: set[str],
        mixable_ranks: dict[Rank, set[Rank]],
    ) -> "Compatibility":
        self_freed = self.clear_dead_tensors(live_tensors)
        right_freed = right.clear_dead_tensors(live_tensors)
        if self_freed.n_loops > right_freed.n_loops:
            # This can be relaxed if we have a way to do order-independent joining
            # and/or non-looptree mappings.
            raise ValueError(
                f"Can't merge. I have more loops than the next, so my dataflow can't "
                f"be carried through a LoopTree to where it's needed."
            )

        live_minus_mine = live_tensors - {s.name for s in self.tensors}
        tensors_a = self._and_tensors_with_names(live_tensors)
        tensors_b = right._and_tensors_with_names(live_minus_mine)

        # TODO: split handling?
        joined = Compatibility(
            tensors=tensors_a | tensors_b,
            reservation_indices=self_freed.reservation_indices
            | right_freed.reservation_indices,
        )

        if mixable_ranks is not None and not joined._is_valid(mixable_ranks):
            raise ValueError(f"Invalid rank mixing.")

        return joined

    def has_tensor(self, *tensors: TensorReservation) -> bool:
        return all(any(s == t for s in self.tensors) for t in tensors)

    def _permute_stops(self):
        stops = set(len(s.loops) for s in self.tensors)
        stops |= self.reservation_indices
        stops |= set(s.above_loop_index for s in self.splits)
        return stops

    def permute(
        self,
        loop_changes: list[int],
    ) -> "Compatibility":
        assert len(loop_changes) <= self.n_loops
        assert set(loop_changes) == set(
            range(len(loop_changes))
        ), f"Loop changes {loop_changes} are not a permutation of {range(len(loop_changes))}"
        if len(loop_changes) < len(self.loops):
            loop_changes = loop_changes + list(
                range(len(loop_changes), len(self.loops))
            )

        permute_stops = self._permute_stops()
        for i, c in enumerate(loop_changes):
            for r in permute_stops:
                assert (i < r) == (
                    c < r
                ), f"Loop changes {loop_changes} cross reservation {r}"
        new_tensors = fzs(s.permute(loop_changes) for s in self.tensors)
        return self.update(tensors=new_tensors)

    def make_equivalent_permutations(self) -> list[tuple["Compatibility", list[int]]]:
        # Get contiguous blocks of loops with no tensor reservation between them
        blocks = []
        current_block = []
        permute_stops = self._permute_stops()
        for i in range(self.n_loops):
            # Can't permute loops if there's a reservation between them
            if i in permute_stops:
                blocks.append(current_block)
                current_block = []
            current_block.append(i)
        if current_block:
            blocks.append(current_block)

        per_block_permutations = [
            list(itertools.permutations(block)) for block in blocks
        ]
        all_permutations = list(itertools.product(*per_block_permutations))
        all_permutations = [
            list(itertools.chain(*loop_changes)) for loop_changes in all_permutations
        ]
        return [(self.permute(p), p) for p in all_permutations]

    def get_tensor_by_name(self, tensor: str) -> TensorReservation:
        for s in self.tensors:
            if s.name == tensor:
                return s
        raise ValueError(f"No reservation found for {tensor}")

    def per_tensor_compatibility(self) -> dict[str, "Compatibility"]:
        result = {}
        for s in self.tensors:
            result[s.name] = self.clear_dead_tensors(set([s.name]))
        return result

    def clear_loop_bounds(self) -> "Compatibility":
        return self.update(tensors=fzs(t.clear_loop_bounds() for t in self.tensors))

    def compatible_with(self, other: "Compatibility") -> bool:
        return True
        # for a in self.tensors:
        #     a = a.loops
        #     for b in other.tensors:
        #         b = b.loops
        #         if a[:len(b)] != b[:len(a)]:
        #             return False
        # return True

    def populate_loops(self):
        return self.update(
            tensors=fzs(t.populate_loops() for t in self.tensors),
        )

    @classmethod
    def from_mapping(
        cls,
        mapping: Mapping,
        tensors: set[TensorName],
        rank_variable_to_ranks: dict[TensorName, dict[RankVariable, Rank]],
    ) -> "Compatibility":
        # TODO: update compatibility to handle spatial-for loop per-tensor update
        tensor_indices = []
        split_above_loop_indices = []
        reservation_indices = []
        backing_remaining = set(tensors)
        n_seen_loops = 0
        n_fused_loops = 0
        for i, n in enumerate(mapping.nodes):
            if isinstance(n, MappingReservation):
                reservation_indices.append(n_seen_loops)
                if not (backing := set(n.purposes) & backing_remaining):
                    continue
                backing_remaining -= backing
                assert (
                    len(n.purposes) == 1
                ), "Backing reservations should have only one purpose"
                tensor_indices.append(i)
            elif isinstance(n, MappingSplit):
                split_above_loop_indices.append(n_seen_loops)
            elif isinstance(n, MappingLoop):
                n_seen_loops += 1
                n_fused_loops += bool(backing_remaining)
            elif isinstance(n, TensorHolder):
                reservation_indices.append(n_seen_loops)

        reservation_indices = fzs(min(n, n_fused_loops) for n in reservation_indices)
        reservation_indices = fzs(x for x in reservation_indices if x >= 0)

        assert (
            not backing_remaining
        ), f"Tensors {backing_remaining} not found in mapping"

        def get_rank(rank_variable, tensor):
            rv = rank_variable_to_ranks[tensor].get(rank_variable, set())
            assert (
                len(rv) <= 1
            ), f"Rank variable {rank_variable} indexes into multiple ranks {rv} for tensor {tensor} "
            return next(iter(rv), Rank("NO RANK. RECOMPUTED."))

        def make_loops(above_index: int, tensor_name: TensorName) -> list[MappingLoop]:
            loops = [
                n for n in mapping.nodes[:above_index] if isinstance(n, MappingLoop)
            ]
            loops = [
                Loop(
                    rank_name=get_rank(n.rank_variable, tensor_name),
                    tile_pattern=n.tile_pattern._symbol2str(),
                    is_spatial=isinstance(n, Spatial),
                )
                for n in loops
            ]
            return tuple(loops)

        return cls(
            tensors=fzs(
                TensorReservation(
                    name=mapping.nodes[i].purpose,
                    loops=make_loops(i, mapping.nodes[i].purpose),
                    resource_name=mapping.nodes[i].resource,
                    persistent=mapping.nodes[i].persistent,
                )
                for i in tensor_indices
            ),
            splits=fzs(
                Split(split=n, above_loop_index=i) for i in split_above_loop_indices
            ),
            reservation_indices=fzs(reservation_indices),
        )

    def symbols(self) -> list[str]:
        symbols = []

        def add(x: str):
            if isinstance(x, str) and x not in symbols:
                symbols.append(x)

        for t in self.tensors:
            for l in t.loops:
                add(l.tile_pattern.initial_tile_shape)
                add(l.tile_pattern.tile_shape)
                add(l.tile_pattern.calculated_n_iterations)
        return symbols

    def drop_loop_indices(self, loop_indices: set[int]) -> "Compatibility":
        loop_indices = set(loop_indices)
        tensors = fzs(t.drop_loop_indices(loop_indices) for t in self.tensors)
        splits = fzs(s for s in self.splits if s.above_loop_index not in loop_indices)

        def adjust(i: int) -> int:
            return i - sum(x < i for x in loop_indices)

        reservation_indices = fzs(adjust(i) for i in self.reservation_indices)
        reservation_indices = fzs(x for x in reservation_indices if x >= 0)

        splits = fzs(
            s.update(above_loop_index=adjust(s.above_loop_index)) for s in self.splits
        )
        return Compatibility(
            tensors=tensors,
            splits=splits,
            reservation_indices=reservation_indices,
        )

    def _prepend_symbols(self, prepend: str) -> "Compatibility":
        return self.update(
            tensors=fzs(t._prepend_symbols(prepend) for t in self.tensors)
        )

    def clear_tile_patterns_and_reservation_indices(self) -> "Compatibility":
        return self.update(
            tensors=fzs(t.clear_symbolic_tile_patterns() for t in self.tensors),
            reservation_indices=fzs(),
            check_reservation_indices=False,
        )

    def clear_symbolic_tile_patterns(self) -> "Compatibility":
        return self.update(
            tensors=fzs(t.clear_symbolic_tile_patterns() for t in self.tensors)
        )

    def make_fused_loop_symbols(
        self, prefix: str
    ) -> tuple[dict[str, str], "Compatibility"]:
        result = {}
        tensors = []
        for t in self.tensors:
            r, t = t.make_fused_loop_symbols(prefix)
            tensors.append(t)
            result.update(r)

        return result, self.update(tensors=fzs(tensors))

    def add_n_iteration_symbols(self) -> "Compatibility":
        return self.update(
            tensors=fzs(t.add_n_iteration_symbols() for t in self.tensors)
        )

    def _is_valid(self, mixable_ranks: dict[Rank, set[Rank]]) -> bool:
        # Mixable ranks: Ranks that may be co-iterated by a single loop.
        ranks_at_each_loop_index = []
        for i in range(self.n_loops):
            ranks_at_each_loop_index.append(
                set(t.loops[i].rank_name for t in self.tensors if i < len(t.loops))
            )

        for ranks in ranks_at_each_loop_index:
            for r1, r2 in itertools.combinations(ranks, 2):
                if r1 not in mixable_ranks[r2]:
                    return False
        return True