accelforge 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- accelforge/__init__.py +21 -0
- accelforge/_accelerated_imports.py +16 -0
- accelforge/_deprecate/_simanneal/evalmapping.py +271 -0
- accelforge/_deprecate/_simanneal/mapspaceglobals.py +298 -0
- accelforge/_deprecate/_simanneal/simanneal.py +666 -0
- accelforge/_deprecate/_simanneal/tracking.py +105 -0
- accelforge/_deprecate/_simanneal/wrappers.py +218 -0
- accelforge/_deprecate/_simanneal2/__init__.py +7 -0
- accelforge/_deprecate/_simanneal2/simanneal.py +493 -0
- accelforge/_deprecate/_simanneal2/tracking.py +116 -0
- accelforge/_deprecate/compatibility_util.py +181 -0
- accelforge/_deprecate/layerdeduplication/__init__.py +2 -0
- accelforge/_deprecate/layerdeduplication/group_similar_einsums.py +160 -0
- accelforge/_deprecate/layerdeduplication/grouped_einsums.py +84 -0
- accelforge/_deprecate/mapping_filter_tags/__init__.py +2 -0
- accelforge/_deprecate/mapping_filter_tags/ffmt.py +212 -0
- accelforge/_deprecate/mapping_filter_tags/onesplit.py +24 -0
- accelforge/_deprecate/mapping_filter_tags/util.py +24 -0
- accelforge/_deprecate/tags.py +69 -0
- accelforge/_deprecate/viz/__init__.py +0 -0
- accelforge/_deprecate/viz/interactive.py +159 -0
- accelforge/_deprecate/viz/reservationtree.py +307 -0
- accelforge/_deprecate/viz/ski_slope.py +88 -0
- accelforge/_version.py +15 -0
- accelforge/examples.py +39 -0
- accelforge/frontend/__init__.py +10 -0
- accelforge/frontend/_binding.py +129 -0
- accelforge/frontend/_workload_isl/__init__.py +2 -0
- accelforge/frontend/_workload_isl/_isl.py +149 -0
- accelforge/frontend/_workload_isl/_symbolic.py +141 -0
- accelforge/frontend/arch copy.py +1544 -0
- accelforge/frontend/arch.py +1642 -0
- accelforge/frontend/config.py +63 -0
- accelforge/frontend/mapper/__init__.py +5 -0
- accelforge/frontend/mapper/ffm.py +126 -0
- accelforge/frontend/mapper/mapper.py +7 -0
- accelforge/frontend/mapper/metrics.py +30 -0
- accelforge/frontend/mapping/__init__.py +1 -0
- accelforge/frontend/mapping/mapping.py +1736 -0
- accelforge/frontend/model.py +14 -0
- accelforge/frontend/renames.py +150 -0
- accelforge/frontend/spec copy.py +230 -0
- accelforge/frontend/spec.py +301 -0
- accelforge/frontend/variables.py +12 -0
- accelforge/frontend/workload.py +952 -0
- accelforge/mapper/FFM/__init__.py +9 -0
- accelforge/mapper/FFM/_join_pmappings/__init__.py +0 -0
- accelforge/mapper/FFM/_join_pmappings/compatibility.py +653 -0
- accelforge/mapper/FFM/_join_pmappings/compress_pmappings.py +140 -0
- accelforge/mapper/FFM/_join_pmappings/join_pmappings.py +703 -0
- accelforge/mapper/FFM/_join_pmappings/pmapping_dataframe.py +901 -0
- accelforge/mapper/FFM/_join_pmappings/pmapping_group.py +337 -0
- accelforge/mapper/FFM/_make_pmappings/contraints/__init__.py +0 -0
- accelforge/mapper/FFM/_make_pmappings/contraints/constraints.py +360 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/__init__.py +1 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_loops.py +373 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_pmapping_templates.py +463 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_reservations.py +95 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_storage_order.py +382 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_storages.py +155 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmappings.py +411 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/__init__.py +1 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/make_pmappings_from_templates.py +407 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/make_tile_shapes.py +1681 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/run_model.py +170 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/symbol_relations.py +174 -0
- accelforge/mapper/FFM/_make_pmappings/pmapper_job.py +282 -0
- accelforge/mapper/FFM/_pareto_df/df_convention.py +273 -0
- accelforge/mapper/FFM/_pareto_df/pareto copy.py +836 -0
- accelforge/mapper/FFM/_pareto_df/pareto.py +508 -0
- accelforge/mapper/FFM/data.py +61 -0
- accelforge/mapper/FFM/main copy.py +236 -0
- accelforge/mapper/FFM/main.py +208 -0
- accelforge/mapper/FFM/mappings.py +510 -0
- accelforge/mapper/FFM/pmappings.py +310 -0
- accelforge/mapper/__init__.py +4 -0
- accelforge/mapper.py +0 -0
- accelforge/model/__init__.py +1 -0
- accelforge/model/_looptree/__init__.py +0 -0
- accelforge/model/_looptree/accesses.py +335 -0
- accelforge/model/_looptree/capacity/__init__.py +1 -0
- accelforge/model/_looptree/capacity/aggregators.py +36 -0
- accelforge/model/_looptree/capacity/capacity.py +47 -0
- accelforge/model/_looptree/energy.py +150 -0
- accelforge/model/_looptree/equivalent_ranks.py +29 -0
- accelforge/model/_looptree/latency/__init__.py +1 -0
- accelforge/model/_looptree/latency/latency.py +98 -0
- accelforge/model/_looptree/latency/memory.py +120 -0
- accelforge/model/_looptree/latency/processors.py +92 -0
- accelforge/model/_looptree/mapping_utilities.py +71 -0
- accelforge/model/_looptree/reuse/__init__.py +4 -0
- accelforge/model/_looptree/reuse/isl/__init__.py +1 -0
- accelforge/model/_looptree/reuse/isl/des.py +59 -0
- accelforge/model/_looptree/reuse/isl/isl_functions.py +374 -0
- accelforge/model/_looptree/reuse/isl/mapping_to_isl/__init__.py +4 -0
- accelforge/model/_looptree/reuse/isl/mapping_to_isl/analyze_mapping.py +297 -0
- accelforge/model/_looptree/reuse/isl/mapping_to_isl/skews_from_mapping.py +236 -0
- accelforge/model/_looptree/reuse/isl/mapping_to_isl/tiling.py +685 -0
- accelforge/model/_looptree/reuse/isl/mapping_to_isl/types.py +188 -0
- accelforge/model/_looptree/reuse/isl/spatial.py +260 -0
- accelforge/model/_looptree/reuse/isl/temporal.py +182 -0
- accelforge/model/_looptree/reuse/symbolic/__init__.py +1 -0
- accelforge/model/_looptree/reuse/symbolic/symbolic copy 2.py +1346 -0
- accelforge/model/_looptree/reuse/symbolic/symbolic copy.py +1408 -0
- accelforge/model/_looptree/reuse/symbolic/symbolic.py +1396 -0
- accelforge/model/_looptree/run.py +122 -0
- accelforge/model/_looptree/types.py +26 -0
- accelforge/model/_looptree/visualization/__init__.py +0 -0
- accelforge/model/_looptree/visualization/occupancy.py +11 -0
- accelforge/model/main.py +222 -0
- accelforge/plotting/__init__.py +2 -0
- accelforge/plotting/mappings.py +219 -0
- accelforge/plotting/specs.py +57 -0
- accelforge/util/__init__.py +4 -0
- accelforge/util/_base_analysis_types.py +24 -0
- accelforge/util/_basetypes.py +1089 -0
- accelforge/util/_frozenset.py +36 -0
- accelforge/util/_isl.py +29 -0
- accelforge/util/_itertools.py +14 -0
- accelforge/util/_mathfuncs.py +57 -0
- accelforge/util/_parse_expressions.py +339 -0
- accelforge/util/_picklecache.py +32 -0
- accelforge/util/_setexpressions.py +268 -0
- accelforge/util/_sympy/__init__.py +0 -0
- accelforge/util/_sympy/broadcast_max.py +18 -0
- accelforge/util/_visualization.py +112 -0
- accelforge/util/_yaml.py +579 -0
- accelforge/util/parallel.py +193 -0
- accelforge-0.0.1.dist-info/METADATA +64 -0
- accelforge-0.0.1.dist-info/RECORD +258 -0
- accelforge-0.0.1.dist-info/WHEEL +5 -0
- accelforge-0.0.1.dist-info/licenses/LICENSE +19 -0
- accelforge-0.0.1.dist-info/top_level.txt +5 -0
- docs/_build/html/_sources/fastfusion.frontend.mapper.rst.txt +37 -0
- docs/_build/html/_sources/fastfusion.frontend.rst.txt +70 -0
- docs/_build/html/_sources/fastfusion.frontend.workload.rst.txt +21 -0
- docs/_build/html/_sources/fastfusion.mapper.FFM.rst.txt +37 -0
- docs/_build/html/_sources/fastfusion.mapper.rst.txt +18 -0
- docs/_build/html/_sources/fastfusion.rst.txt +20 -0
- docs/_build/html/_sources/fastfusion.util.rst.txt +21 -0
- docs/_build/html/_sources/index.rst.txt +87 -0
- docs/_build/html/_sources/modules.rst.txt +7 -0
- docs/_build/html/_sources/notes/citation.rst.txt +45 -0
- docs/_build/html/_sources/notes/definitions.rst.txt +43 -0
- docs/_build/html/_sources/notes/faqs.rst.txt +39 -0
- docs/_build/html/_sources/notes/modeling/accelerator_energy_latency.rst.txt +72 -0
- docs/_build/html/_sources/notes/modeling/component_energy_area.rst.txt +96 -0
- docs/_build/html/_sources/notes/modeling/mapping.rst.txt +100 -0
- docs/_build/html/_sources/notes/modeling.rst.txt +33 -0
- docs/_build/html/_sources/notes/parsing/arithmetic_parsing.rst.txt +136 -0
- docs/_build/html/_sources/notes/parsing/setexpressions.rst.txt +63 -0
- docs/_build/html/_sources/notes/parsing/yaml_parsing.rst.txt +176 -0
- docs/_build/html/_sources/notes/quickstart_and_installation.rst.txt +9 -0
- docs/_build/html/_sources/notes/spec/architecture.rst.txt +133 -0
- docs/_build/html/_sources/notes/spec/mapping.rst.txt +12 -0
- docs/_build/html/_sources/notes/spec/workload.rst.txt +83 -0
- docs/_build/html/_sources/notes/spec.rst.txt +36 -0
- docs/source/_ext/include_attrs.py +213 -0
- docs/source/_ext/include_docstring.py +364 -0
- docs/source/_ext/include_functions.py +154 -0
- docs/source/_ext/include_notebook.py +131 -0
- docs/source/_ext/include_yaml.py +119 -0
- docs/source/_ext/inherited_attributes.py +222 -0
- docs/source/_ext/paths.py +4 -0
- docs/source/conf.py +79 -0
- examples/arches/compute_in_memory/_include.yaml +74 -0
- examples/arches/compute_in_memory/_include_functions.py +229 -0
- examples/arches/compute_in_memory/_load_spec.py +57 -0
- examples/arches/compute_in_memory/components/c2c_multiplier.py +181 -0
- examples/arches/compute_in_memory/components/dac_c2c_r2r.py +605 -0
- examples/arches/compute_in_memory/components/misc.py +195 -0
- examples/arches/compute_in_memory/components/util/bit_functions.py +51 -0
- examples/arches/compute_in_memory/components/zero_comparator.py +92 -0
- examples/arches/compute_in_memory/isaac.yaml +233 -0
- examples/arches/compute_in_memory/memory_cells/ecram_demo.yaml +63 -0
- examples/arches/compute_in_memory/memory_cells/rram_example.yaml +63 -0
- examples/arches/compute_in_memory/memory_cells/rram_isaac_isca_2016.yaml +64 -0
- examples/arches/compute_in_memory/memory_cells/rram_neurosim_default.yaml +63 -0
- examples/arches/compute_in_memory/memory_cells/rram_raella_isca_2023.yaml +70 -0
- examples/arches/compute_in_memory/memory_cells/rram_wan_nature_2022.yaml +63 -0
- examples/arches/compute_in_memory/memory_cells/sram_colonnade_jssc_2021.yaml +63 -0
- examples/arches/compute_in_memory/memory_cells/sram_example.yaml +63 -0
- examples/arches/compute_in_memory/memory_cells/sram_jia_jssc_2020.yaml +63 -0
- examples/arches/compute_in_memory/memory_cells/sram_sinangil_jssc_2021.yaml +63 -0
- examples/arches/compute_in_memory/memory_cells/sram_wang_vlsi_2022.yaml +63 -0
- examples/arches/compute_in_memory/wang_vlsi_2022.yaml +289 -0
- examples/arches/eyeriss.yaml +68 -0
- examples/arches/fanout_variations/at_glb.yaml +31 -0
- examples/arches/fanout_variations/at_glb_with_fanout_node.yaml +34 -0
- examples/arches/fanout_variations/at_mac.yaml +31 -0
- examples/arches/fanout_variations/at_mac_with_constraints.yaml +38 -0
- examples/arches/fanout_variations/at_mac_with_fanout_node.yaml +34 -0
- examples/arches/nvdla.yaml +47 -0
- examples/arches/simple.yaml +28 -0
- examples/arches/tpu_v4i.yaml +67 -0
- examples/mappings/unfused_matmuls_to_simple.yaml +33 -0
- examples/misc/component_annotated.yaml +33 -0
- examples/workloads/gpt3_6.7B.yaml +124 -0
- examples/workloads/matmuls.yaml +20 -0
- examples/workloads/mobilenet_28.yaml +81 -0
- examples/workloads/mobilenet_various_separate.yaml +106 -0
- examples/workloads/three_matmuls_annotated.yaml +59 -0
- notebooks/.ipynb_checkpoints/fastfusion_arch_study_michael-checkpoint.ipynb +359 -0
- notebooks/compute_in_memory/_scripts.py +339 -0
- notebooks/compute_in_memory/isaac.guide.ipynb +270 -0
- notebooks/compute_in_memory/wang_vlsi_2022.ipynb +602 -0
- notebooks/paths.py +4 -0
- notebooks/tutorials/.ipynb_checkpoints/1_FFM-checkpoint.ipynb +3110 -0
- notebooks/tutorials/FFM.ipynb +3498 -0
- notebooks/tutorials/_include.py +48 -0
- notebooks/tutorials/component_energy_area.ipynb +363 -0
- tests/Q_mapping.yaml +38 -0
- tests/__init__.py +0 -0
- tests/conv.mapping.yaml +27 -0
- tests/conv.workload.yaml +13 -0
- tests/conv_sym.mapping.yaml +43 -0
- tests/copy.mapping.yaml +35 -0
- tests/copy.workload.yaml +15 -0
- tests/distribuffers/__init__.py +0 -0
- tests/distribuffers/multicast/test_cases.yaml +482 -0
- tests/distribuffers/spec/binding/valid_bindings.yaml +97 -0
- tests/distribuffers/spec/distributed.yaml +100 -0
- tests/distribuffers/spec/logical_arch.yaml +32 -0
- tests/distribuffers/spec/physical_arch.yaml +69 -0
- tests/distribuffers/test_binding.py +48 -0
- tests/frontend/__init__.py +0 -0
- tests/frontend/test_mapping_viz.py +52 -0
- tests/mapper/__init__.py +0 -0
- tests/mapper/configs/conv1d/conv1d.mapping.yaml +31 -0
- tests/mapper/configs/conv1d/conv1d.workload.yaml +11 -0
- tests/mapper/configs/two_conv1d/two_conv1d.expected.yaml +38 -0
- tests/mapper/configs/two_conv1d/two_conv1d.mapping.yaml +54 -0
- tests/mapper/configs/two_conv1d/two_conv1d.workload.yaml +19 -0
- tests/mapper/test_mapping_to_isl.py +90 -0
- tests/mapper/test_spatial_reuse_analysis.py +67 -0
- tests/mapper/test_temporal_reuse_analysis.py +56 -0
- tests/mapper/util.py +58 -0
- tests/matmul.mapping.yaml +29 -0
- tests/matmul.workload.yaml +12 -0
- tests/matmul_spatial.mapping.yaml +44 -0
- tests/mha.renames.yaml +65 -0
- tests/mha.workload.yaml +67 -0
- tests/mha.yaml +59 -0
- tests/mha_full.workload.yaml +67 -0
- tests/mobilenet.workload.yaml +35 -0
- tests/mobilenet_long.workload.yaml +64 -0
- tests/pmappingcache.py +24 -0
- tests/processing_stage.arch.yaml +40 -0
- tests/snowcat.arch.yaml +36 -0
- tests/test_ffm_join_pmappings.py +106 -0
- tests/test_ffm_make_pmappings.py +82 -0
- tests/test_ffm_make_tile_shapes.py +49 -0
- tests/test_mapper.py +100 -0
- tests/test_model.py +37 -0
- tests/test_plotting.py +72 -0
- tests/test_processing_stage.py +46 -0
- tests/test_symbolic_model.py +248 -0
- tests/test_workload.py +141 -0
|
@@ -0,0 +1,685 @@
|
|
|
1
|
+
"""
|
|
2
|
+
File for all the functions that conduct tiling analysis for the overall mapping
|
|
3
|
+
analysis.
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
from collections import defaultdict, deque
|
|
7
|
+
from typing import List, Tuple, Optional
|
|
8
|
+
|
|
9
|
+
from pprint import pformat
|
|
10
|
+
|
|
11
|
+
import islpy as isl
|
|
12
|
+
|
|
13
|
+
from accelforge.frontend.mapping import (
|
|
14
|
+
# Types
|
|
15
|
+
MappingNode,
|
|
16
|
+
# Mapping objects
|
|
17
|
+
Mapping,
|
|
18
|
+
MappingNodeWithChildren,
|
|
19
|
+
Nested,
|
|
20
|
+
# Physical object types in Mappings.
|
|
21
|
+
Compute,
|
|
22
|
+
Storage,
|
|
23
|
+
# Logical object types in Mappings.
|
|
24
|
+
Loop,
|
|
25
|
+
Spatial,
|
|
26
|
+
Temporal,
|
|
27
|
+
Split,
|
|
28
|
+
)
|
|
29
|
+
from accelforge.frontend.workload import (
|
|
30
|
+
# Workload class for all of AccelForge.
|
|
31
|
+
Workload,
|
|
32
|
+
)
|
|
33
|
+
from accelforge.frontend._workload_isl._isl import (
|
|
34
|
+
get_einsum_operation_space,
|
|
35
|
+
get_projection_map,
|
|
36
|
+
)
|
|
37
|
+
from accelforge.frontend.mapping import TensorName
|
|
38
|
+
from accelforge.model._looptree.reuse.isl.isl_functions import (
|
|
39
|
+
add_dims_preserve_name_map,
|
|
40
|
+
insert_dims_preserve_name_map,
|
|
41
|
+
map_to_prior_coordinate,
|
|
42
|
+
)
|
|
43
|
+
from accelforge.model._looptree.reuse.isl.mapping_to_isl import DUMP_ISL_IR
|
|
44
|
+
from accelforge.model._looptree.reuse.isl.mapping_to_isl.types import (
|
|
45
|
+
EinsumName,
|
|
46
|
+
Tiling,
|
|
47
|
+
BranchTiling,
|
|
48
|
+
)
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
def get_mapping_group_einsums(
    mapping: Mapping,
) -> defaultdict[MappingNode, set[EinsumName]]:
    """
    From a mapping, get the group of einsums for a given node.

    Walks the mapping tree with an iterative DFS. Each node is tracked with
    the last node seen above it that was not a branch point (a branch point
    being a node with more than one child). Compute leaves register their
    einsum under that last non-branch node. After the DFS, einsums are pushed
    upward through the recorded branch points (in reverse discovery order) so
    every grouping node also contains the einsums of all descendant groups.

    Parameters
    ----------
    mapping:
        The mapping we are getting the grouped einsums for.

    Returns
    -------
    A dictionary relating a MappingNode to a set of einsums.
    """
    # Each pair is a (current_node, last_non_branch_node)
    dfs_stack: deque[Tuple[MappingNode, MappingNode]] = deque()
    # Each pair is a (last_non_branch_node, set_of_children_nodes)
    child_stack: deque[Tuple[MappingNode, set[MappingNode]]] = deque()
    result: defaultdict[MappingNode, set[EinsumName]] = defaultdict(set)

    # Start DFS hierarchical search from the root. The root acts as its own
    # last_non_branch node.
    dfs_stack.append((mapping, mapping))

    # Exhaustive DFS search.
    while dfs_stack:
        # Grabs latest node to search.
        node, last_non_branch = dfs_stack.pop()

        # Differentiates behavior by number of child nodes.
        match node:
            case MappingNodeWithChildren():
                match len(node.nodes):
                    # No children, log as a folded result.
                    case 0:
                        # Note:: Check necesary in case Distrobuffers elides
                        # computes into one large unit.
                        if isinstance(node, Compute):
                            result[last_non_branch].add(node.einsum)
                        else:
                            raise TypeError(
                                f"The following node should be of class "
                                f"Compute as it has no children:\n---\n{node}"
                            )
                    # Explore the children further.
                    case 1:
                        dfs_stack.append((node.nodes[0], last_non_branch))
                    # Log all branching children and explore all children.
                    # Each child becomes its own last_non_branch node below
                    # this branch point.
                    case _:
                        children: set[MappingNode] = set(node.nodes)
                        child_stack.append((last_non_branch, children))
                        dfs_stack.extend((child, child) for child in children)
            # Assumed no children, log as a folded result.
            case Compute():
                result[last_non_branch].add(node.einsum)
            # These had children in Timeloop we had to add to the DFS, but because
            # of our extension of dfs_stack we can just skip this node.
            case Spatial() | Temporal() | Storage():
                continue
            case _:
                raise AttributeError(
                    f"The following node of class {type(node)} has "
                    f"indeterminant number of children:\n---\n"
                    f"{node}"
                )

    # Push up einsums to parents. Reverse order guarantees deeper branch
    # points are merged before their ancestors see them.
    for node, children in reversed(child_stack):
        node_einsum_set: set[EinsumName] = result[node]
        for child in children:
            node_einsum_set.update(result[child])

    return result
|
126
|
+
def get_head_among_einsums(
    einsum_set: set[EinsumName], workload: Workload
) -> set[EinsumName]:
    """
    Return the "head" einsums of `einsum_set`: those whose output tensors are
    not consumed by any other einsum within the set (i.e., the sinks of the
    dataflow restricted to this set).

    Parameters
    ----------
    einsum_set:
        Set of einsums to consider.
    workload:
        The workload context the einsums exist in.

    Returns
    -------
    The set of all head einsums.
    """

    def _is_head(candidate: EinsumName) -> bool:
        # A head einsum writes no tensor that is read by an einsum in the set.
        for output_tensor in workload.einsums[candidate].output_tensor_names:
            for consumer in workload.einsums_with_tensor_as_input(output_tensor):
                if consumer.name in einsum_set:
                    return False
        return True

    return {candidate for candidate in einsum_set if _is_head(candidate)}
|
|
157
|
+
def add_new_tile_dim(
    old_tiling: Tiling, dim_idx: int, tile_size: int, rank_var: Optional[str] = None
) -> Tiling:
    """
    Given a tiling, add a new dimension to the tiling.

    The new iteration dimension is appended to the end of the tiling's input
    (tiled-iteration) space, and the operation dimension `dim_idx` is
    constrained to lie within a window of `tile_size` whose origin advances by
    `tile_size` per step of the new iteration dimension.

    Parameters
    ----------
    old_tiling:
        The previous tiling the mapper proposed.
    dim_idx:
        The index of the dimension being tiled.
    tile_size:
        The size of the tiling on dim_idx.
    rank_var:
        Rank variable name to assign to the new input dimension, if provided.

    Returns
    -------
    The new Tiling with tiled dimension at dim_idx.
    """

    # new_tiling has one extra dimension at the end compared to old_tiling.
    new_tiling = insert_dims_preserve_name_map(
        old_tiling, isl.dim_type.in_, old_tiling.dim(isl.dim_type.in_), 1
    )
    if rank_var:
        # Name the newly appended input dimension after the rank variable.
        new_tiling = new_tiling.set_dim_name(
            isl.dim_type.in_, old_tiling.dim(isl.dim_type.in_), rank_var
        )

    # Min and max of the dim_idx dimension being tiled, each as a function of
    # the tiled (input) dimensions.
    dim_min: isl.PwAff = new_tiling.dim_min(dim_idx)
    dim_max: isl.PwAff = new_tiling.dim_max(dim_idx)

    # Aff from tiled dimensions space to value of newest dim (the last input
    # dimension, index = count - 1).
    new_dim_id: isl.Aff = isl.Aff.var_on_domain(
        dim_min.get_domain_space().to_local_space(),
        isl.dim_type.set,
        dim_min.dim(isl.dim_type.in_) - 1,
    )

    # Aff from tiled dimensions space to the tile size constant.
    tile_size_aff: isl.Aff = isl.Aff.val_on_domain_space(
        dim_min.get_domain_space(), isl.Val.int_from_ui(isl.DEFAULT_CONTEXT, tile_size)
    )

    # PwAff from tiled dimension space to tile_size * newest_dim.
    tile_translate: isl.PwAff = isl.PwAff.from_aff(new_dim_id.mul(tile_size_aff))

    # What dim_min should be given new tiling: the old minimum shifted by one
    # whole tile per step of the new iteration dimension.
    new_dim_min: isl.PwAff = dim_min.add(tile_translate)

    # What dim_max should be given new tiling: the new minimum plus
    # (tile_size - 1), i.e. an inclusive window of tile_size elements.
    new_dim_max: isl.PwAff = new_dim_min.add(
        isl.PwAff.from_aff(tile_size_aff.add_constant_val(-1))
    )

    # TODO: Might be logically equivalent to new_dim_id:
    # https://github.com/NVlabs/timeloop/blob/32370826fdf1aa3c8deb0c93e6b2a2fc7cf053aa/src/loop-analysis/mapping-to-isl/tiling.cpp#L52-L59
    new_iter_id: isl.PwAff = isl.PwAff.from_aff(
        isl.Aff.var_on_domain(
            new_tiling.get_space().domain(),
            isl.dim_type.set,
            old_tiling.dim(isl.dim_type.in_),
        )
    )

    # The set of valid values of the new tiled dimensions: the new iterator
    # may not exceed ceil(dim_max / tile_size) tiles, and the window start may
    # not precede the original minimum.
    iter_set: isl.Set = new_tiling.domain()
    iter_set = iter_set.intersect(new_iter_id.le_set(dim_max.div(tile_size_aff).ceil()))
    iter_set = iter_set.intersect(new_dim_min.ge_set(dim_min))

    # The value of iter dims cannot exceed what was available before tiling.
    new_tiling = new_tiling.intersect_domain(iter_set)

    # The set of operations need to follow the new tile bounds:
    # new_dim_min <= op[dim_idx] <= new_dim_max.
    identity: isl.PwAff = isl.PwAff.from_aff(
        isl.Aff.var_on_domain(new_tiling.get_space().range(), isl.dim_type.set, dim_idx)
    )
    new_tiling = new_tiling.intersect(new_dim_min.le_map(identity))
    new_tiling = new_tiling.intersect(new_dim_max.ge_map(identity))

    return new_tiling
|
|
243
|
+
def shared_input_based_tile_shape_inference(
    workload: Workload,
    tiling_info: defaultdict[EinsumName, Tiling],
    einsums: set[EinsumName],
    shared_input_tensor: TensorName,
    tiled_einsum: EinsumName,
) -> None:
    """
    Restrict every einsum in `einsums` other than `tiled_einsum` to the
    operations that read the same portion of `shared_input_tensor` that
    `tiled_einsum`'s tiling reads. When tiled, the shared input is multicast,
    so einsums tiled together must share data.

    Parameters
    ----------
    workload:
        The workload context the tiling is occurring in.
    tiling_info:
        Relation of `EinsumName` and its viable tiling on hardware.
    einsums:
        The set of all einsums.
    shared_input_tensor:
        The singular tensor `einsums` all read from.
    tiled_einsum:
        The einsum being tiled.

    Returns
    -------
    None

    Postconditions
    --------------
    `tiling_info` is updated such that each Tiling contains only compatible
    tilings with `tiled_einsum`.
    """
    # Portion of shared_input_tensor read per tile of tiled_einsum.
    tiled_reads: isl.Map = get_projection_map(
        workload.einsums[tiled_einsum], shared_input_tensor
    )
    read_data: isl.Map = tiling_info[tiled_einsum].apply_range(tiled_reads)

    # Restrict every other einsum's tiling to the operations that can execute
    # on the multicast data.
    for other_einsum in einsums - {tiled_einsum}:
        other_reads: isl.Map = get_projection_map(
            workload.einsums[other_einsum], shared_input_tensor
        )
        # Operations of other_einsum touching the data read by tiled_einsum,
        # clipped to other_einsum's real operation space.
        compatible_ops: isl.Map = read_data.apply_range(
            other_reads.reverse()
        ).intersect_range(get_einsum_operation_space(workload, other_einsum))

        tiling_info[other_einsum] = tiling_info[other_einsum].intersect(
            compatible_ops
        )
|
+
|
|
303
|
+
def consumer_based_tile_shape_inference(
    workload: Workload,
    tiling_info: defaultdict[EinsumName, Tiling],
    tensor_to_reuse_level: defaultdict[TensorName, int],
    einsums: set[EinsumName],
    tiled_einsum: EinsumName,
):
    """
    Given a `tiled_einsum` in a `workload`, restrict the other `einsums`' execution
    in this tiling to one in which the data is required for the tensors read by
    `tiled_einsum`. This is because, when tiled, data is multicast so the other
    einsums being tiled together must share data.

    Parameters
    ----------
    workload:
        The workload context the tiling is occurring in.
    tiling_info:
        Relation of `EinsumName` and its viable tiling on hardware.
    tensor_to_reuse_level:
        A relation between a tensor and the amount of reuse occurring.
    einsums:
        The set of all einsums.
    tiled_einsum:
        The einsum being tiled.

    Returns
    -------
    None

    Postconditions
    --------------
    `tiling_info` is updated such that each Tiling contains only compatible tilings
    with `tiled_einsum`.
    """
    # Goes recursively through tensor dependencies (read tensors) and tiles them,
    # breadth-first starting from the tiled einsum.
    queue: deque[EinsumName] = deque([tiled_einsum])
    while queue:
        einsum: EinsumName = queue.popleft()
        tiling: Tiling = tiling_info[einsum]

        # For each tensor read by this einsum, tile that tensor's producers.
        for tensor in workload.einsums[einsum].input_tensor_names:
            # BUGFIX: the producers of `tensor` are the einsums that WRITE
            # `tensor`, not the output tensor names of the current (consumer)
            # einsum, which is what the previous code iterated.
            # NOTE(review): assumes `workload.einsums` is a mapping of
            # EinsumName -> einsum (it is indexed by name elsewhere in this
            # module) — confirm `.items()` availability.
            producer_einsums: set[EinsumName] = {
                name
                for name, candidate in workload.einsums.items()
                if tensor in candidate.output_tensor_names
            }
            if len(producer_einsums) > 1:
                raise NotImplementedError(
                    "Tile shape inference cannot handle multiple einsums writing the same tensor."
                )

            # Not an intermediate tensor.
            if not producer_einsums:
                continue

            producer_einsums.intersection_update(einsums)
            # No producer einsum in this fusion set.
            if not producer_einsums:
                continue

            # Exactly one producer remains after the checks above.
            producer_einsum: EinsumName = next(iter(producer_einsums))
            read_accesses: isl.Map = get_projection_map(
                workload.einsums[einsum], tensor
            )
            # Required data of the tiling as a mapping of read accesses.
            required_data: isl.Map = tiling.apply_range(read_accesses)

            # Calculates the data computed by the producer einsums. Data
            # already buffered at the reuse level (i.e., available from the
            # prior iteration) is subtracted so it is not recomputed.
            computed_data: isl.Map = required_data
            if tensor in tensor_to_reuse_level:
                reuse_level: int = tensor_to_reuse_level[tensor]
                shifter: isl.Map = map_to_prior_coordinate(
                    tiling.dim(isl.dim_type.in_),
                    reuse_level,
                    tiling.get_tuple_name(isl.dim_type.in_),
                )
                buffered_data: isl.Map = shifter.apply_range(required_data)
                computed_data = computed_data.subtract(buffered_data).coalesce()

            # Grabs the elements this tensor relies on from producer_einsum.
            producer_write_dependency: isl.Map = get_projection_map(
                workload.einsums[producer_einsum], tensor
            )
            # Gets the required operations to produce the current tensor,
            # clipped to the producer's real operation space.
            required_operations: isl.Map = computed_data.apply_range(
                producer_write_dependency.reverse()
            )
            required_operations = required_operations.intersect_range(
                get_einsum_operation_space(workload, producer_einsum)
            )

            # Mutations of the tilings of producer einsums.
            # TODO: Deal with fusing naming better (perhaps mix the names?)
            tiling_info[producer_einsum] = tiling_info[producer_einsum].intersect(
                required_operations.set_tuple_name(
                    isl.dim_type.in_,
                    tiling_info[producer_einsum].get_tuple_name(isl.dim_type.in_),
                )
            )

            queue.append(producer_einsum)
+
|
|
407
|
+
def detect_shared_input_tensor(
    fused_set: set[EinsumName], workload: Workload
) -> List[TensorName]:
    """
    Given a set of fused einsums on a workload, detect the input tensor that they
    all are dependent on, if it exists.

    Parameters
    ----------
    fused_set:
        The set of fused einsums being analyzed.
    workload:
        The workload context the einsums exist in.

    Returns
    -------
    The list of tensors shared by the inputs. Because we default to consumer-based
    analysis if there's more than 1 shared input among the tensors, we only return
    lists of length {0, 1, 2}.
    """
    # BUGFIX: the einsum count is simply the size of the fused set; the
    # previous loop-counter accumulation was fragile (an increment nested in
    # the tensor loop counts reads, not einsums).
    n_einsums: int = len(fused_set)
    tensor_read_counts: defaultdict[TensorName, int] = defaultdict(int)

    # Counts the number of einsums reading each tensor.
    for einsum in fused_set:
        for tensor in workload.einsums[einsum].input_tensor_names:
            tensor_read_counts[tensor] += 1

    shared_input_tensors: List[TensorName] = []
    for tensor, count in tensor_read_counts.items():
        # Tensor is shared by all einsums.
        if count == n_einsums:
            shared_input_tensors.append(tensor)
        # More than one shared input: caller should resort to consumer-based
        # fusing methods, so stop early.
        if len(shared_input_tensors) > 1:
            return shared_input_tensors

    return shared_input_tensors
|
+
|
|
448
|
+
def tiling_from_mapping(mapping: Mapping, workload: Workload) -> BranchTiling:
    """
    Given a mapping and a workload generates a tiling.

    Walks the mapping tree depth-first from the root, fusing/tiling each
    branch in turn via the local helper `_tile_branch`, and records the final
    per-Compute-node tiling in `result`.

    Parameters
    ----------
    mapping:
        A mapping of data to hardware.
    workload:
        The problem being solved.

    Returns
    -------
    BranchTiling associating a node's ID with its tiling.
    """
    # NOTE(review): entries are keyed by the Compute node object itself (see
    # the Compute() case below) — "node's ID" presumably means the node used
    # as a key; confirm against BranchTiling's definition.
    result: BranchTiling = BranchTiling()
    # Grabs the head einsums.
    mapping_groups: defaultdict[MappingNode, set[EinsumName]] = (
        get_mapping_group_einsums(mapping)
    )
    mapping_group_heads: defaultdict[MappingNode, set[EinsumName]] = defaultdict(
        set,
        {
            node: get_head_among_einsums(group, workload)
            for node, group in mapping_groups.items()
        },
    )

    # NOTE(review): defaultdict() with no default_factory behaves exactly like
    # a plain dict (missing keys raise KeyError) — the type is kept for
    # consistency with the annotations used elsewhere in this function.
    tensor_to_reuse_level: defaultdict[TensorName, int] = defaultdict()
    dfs_stack: deque[MappingNode] = deque([mapping])  # DFS starts at mapping root.

    # Maps last non-branch to tiling of each in the group.
    tiling_info: defaultdict[MappingNode, defaultdict[EinsumName, Tiling]] = (
        defaultdict(defaultdict)
    )

    # Appends info for the root: each einsum starts from its full (untiled)
    # operation space, tagged with a per-einsum tuple name.
    for einsum_name in workload.einsum_names:
        tiling_info[mapping][einsum_name] = isl.Map.from_range(
            get_einsum_operation_space(workload, einsum_name)
        ).set_tuple_name(isl.dim_type.in_, f"{einsum_name}_tiled_iteration")

    # Tracks rank_var specified to partitioned_rank_var index, as traversal
    # in tiling goes down the partition.
    rank_var_partitions: defaultdict[str, int] = defaultdict(lambda: 0)

    def _get_rank_var_partition(rank_var: str) -> str:
        """
        Given a rank_var, get the partition at the current point in execution
        and increment for the next retrieval.
        """
        # `nonlocal` is not strictly required here (the dict is mutated, not
        # rebound), but it documents the shared-state dependency.
        nonlocal rank_var_partitions
        # Partition names are the rank_var suffixed with a running counter,
        # e.g. "M" -> "M0", "M1", ...
        rank_var_partition: str = f"{rank_var}{rank_var_partitions[rank_var]}"
        rank_var_partitions[rank_var] += 1
        return rank_var_partition

    def _tile_branch(heads: set[EinsumName], fusing_node: MappingNode) -> None:
        """
        Given a set of `heads` to fuse at `fusing_node`, fuse as much as possible
        in this branch.

        Parameters
        ----------
        heads:
            The heads being fused.
        fusing_node:
            The node node in the mapping at which the fusing is happening.

        Preconditions
        -------------
        1. `dfs_stack`: initialized with tiles to proceed to explore.
        2. `tiling_info`: prima facie populated.
        3. `tensor_to_reuse_level`: initialized and unmutated from last time this
           function was run.

        Postconditions
        --------------
        1. `dfs_stack`: progressed to the next node to tile at.
        2. `tiling_info`: updated to include the fusing and tiling.
        3. `tensor_to_reuse_level`: populated if information has changed from tiling.
        """
        nonlocal dfs_stack
        nonlocal tiling_info
        nonlocal tensor_to_reuse_level

        # Walk node-by-node (popping successors off dfs_stack) until a
        # terminating node is reached: Compute, Split, or an un-tileable
        # multi-head Split (break).
        current_node: MappingNode = fusing_node
        while True:
            # Fuses current_node to one of the heads.
            match current_node:
                # For or Par-For loop handling.
                case Loop():
                    if len(heads) != 1:
                        raise ValueError(
                            f"Cannot fuse tiled set with {len(heads)} heads.\n"
                        )

                    # Tiles `current_node.rank_variable` at `head`
                    head = next(iter(heads))
                    tiling: Tiling = tiling_info[fusing_node][head]
                    # Downstreams of "heads" is also constant as it is a set, not
                    # AbstractSet.
                    # Position of this loop's rank variable within the head
                    # einsum's rank-variable ordering.
                    idx: int = tuple(workload.einsums[head].rank_variables).index(
                        current_node.rank_variable
                    )

                    # Adds a new tile_dim to the old tiling.
                    # TODO: Handle stride.
                    # Only the uniform case is supported: a nonzero integer
                    # initial tile shape equal to the steady-state tile shape.
                    if (
                        isinstance(
                            _ := current_node.tile_pattern.initial_tile_shape, int
                        )
                        and (_ != 0)
                        and (_ == current_node.tile_pattern.tile_shape)
                    ):
                        tiling: Tiling = add_new_tile_dim(
                            tiling,
                            idx,
                            current_node.tile_pattern.initial_tile_shape,
                            _get_rank_var_partition(current_node.rank_variable),
                        )
                    else:
                        raise NotImplementedError(
                            f"Tile size analysis not implemented for type {type(fusing_node)} "
                            f"with tile shape {current_node.tile_pattern.initial_tile_shape}"
                        )

                    # Saves the fused tiling.
                    tiling_info[fusing_node][head] = tiling

                    # Adds the ranks to the tiling isl.Map.
                    # Mirror the head's new tile dimension onto every other
                    # einsum in this group so their iteration spaces agree.
                    iteration_set: isl.Set = tiling.domain()
                    for einsum in mapping_groups[fusing_node] - {head}:
                        tiling = tiling_info[fusing_node][einsum]
                        # Index variables for the branch.
                        tiling = insert_dims_preserve_name_map(
                            tiling, isl.dim_type.in_, tiling.dim(isl.dim_type.in_), 1
                        )
                        tiling = tiling.set_dim_name(
                            isl.dim_type.in_,
                            tiling.dim(isl.dim_type.in_) - 1,
                            _get_rank_var_partition(current_node.rank_variable),
                        )
                        # TODO: Figure out if this intersection is correct.
                        tiling = tiling.intersect_domain(
                            iteration_set.set_tuple_name(
                                tiling.get_tuple_name(isl.dim_type.in_)
                            )
                        )
                        tiling_info[fusing_node][einsum] = tiling

                    current_node = dfs_stack.pop()
                # Notes what reuse level the tensor is on.
                case Storage():
                    # See current_node is the highest level of Storage to determine reuse level.
                    for tensor in current_node.tensors:
                        # Check second term
                        # First Storage node encountered for a tensor wins:
                        # record its reuse level (the current number of tiled
                        # input dimensions) only once.
                        if tensor not in tensor_to_reuse_level:
                            random_einsum: EinsumName = next(
                                iter(mapping_groups[fusing_node])
                            )
                            tiling: Tiling = tiling_info[fusing_node][random_einsum]
                            tensor_to_reuse_level[tensor] = tiling.dim(isl.dim_type.in_)

                    current_node = dfs_stack.pop()
                # If we are at the Mapping root, just go to the actual Nodes.
                case Mapping():
                    # Reversed so that children are popped in original order.
                    dfs_stack.extend(reversed(current_node.nodes))
                    current_node = dfs_stack.pop()
                # If we hit the compute node, we've finished tiling, end!
                case Compute():
                    result[current_node] = tiling_info[fusing_node][current_node.einsum]
                    return
                case Split():
                    fused_set: set[EinsumName] = mapping_groups[fusing_node]
                    if len(heads) != 1:
                        # There can't be a tiling, so no inference to be done.
                        break

                    # Pick an inference strategy: shared-input-based if the
                    # fused einsums share exactly one input tensor, otherwise
                    # consumer-based. Both mutate tiling_info[fusing_node].
                    random_head = next(iter(heads))
                    if len(_ := detect_shared_input_tensor(fused_set, workload)) == 1:
                        shared_input_based_tile_shape_inference(
                            workload,
                            tiling_info[fusing_node],
                            fused_set,
                            _[0],
                            random_head,
                        )
                    else:
                        consumer_based_tile_shape_inference(
                            workload,
                            tiling_info[fusing_node],
                            tensor_to_reuse_level,
                            fused_set,
                            random_head,
                        )

                    # Goes through each child node of the current node and propagate
                    # the tiling updates.
                    for idx, child in enumerate(current_node.nodes):
                        # Each child needs tilings for all Einsums in its group.
                        group: set[EinsumName] = mapping_groups[child]
                        tilings: defaultdict[EinsumName, Tiling] = defaultdict()

                        # For all einsums the child is involved in, update their tilings.
                        for einsum in group:
                            tiling: Tiling = tiling_info[fusing_node][einsum]
                            # Add dimension that iterates over branches.
                            new_tiling: Tiling = add_dims_preserve_name_map(
                                tiling, isl.dim_type.in_, 1
                            )

                            # Pin the new branch dimension to this child's
                            # index within the Split.
                            tilings[einsum] = new_tiling.fix_input_si(
                                new_tiling.dim(isl.dim_type.in_) - 1, idx
                            )

                        # Update the tiling info for the child.
                        tiling_info[child] = tilings
                        # DFS tile on the child.
                        dfs_stack.append(child)

                    return
                case Nested():
                    # Reversed so that children are popped in original order.
                    dfs_stack.extend(reversed(current_node.nodes))
                    current_node = dfs_stack.pop()
                case _:
                    raise NotImplementedError(
                        f"Type {type(fusing_node)} not handled.\n"
                        f"---\n"
                        f"node={pformat(fusing_node)}"
                    )

    # Drive the DFS: each popped node becomes the root of a branch-tiling pass.
    # `_tile_branch` may push further children (Split case) for later passes.
    while dfs_stack:
        fusing_node = dfs_stack.pop()
        if DUMP_ISL_IR:
            print(f"New Tiling Root: {pformat(fusing_node)}")
        _tile_branch(mapping_group_heads[fusing_node], fusing_node)

    return result