accelforge 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- accelforge/__init__.py +21 -0
- accelforge/_accelerated_imports.py +16 -0
- accelforge/_deprecate/_simanneal/evalmapping.py +271 -0
- accelforge/_deprecate/_simanneal/mapspaceglobals.py +298 -0
- accelforge/_deprecate/_simanneal/simanneal.py +666 -0
- accelforge/_deprecate/_simanneal/tracking.py +105 -0
- accelforge/_deprecate/_simanneal/wrappers.py +218 -0
- accelforge/_deprecate/_simanneal2/__init__.py +7 -0
- accelforge/_deprecate/_simanneal2/simanneal.py +493 -0
- accelforge/_deprecate/_simanneal2/tracking.py +116 -0
- accelforge/_deprecate/compatibility_util.py +181 -0
- accelforge/_deprecate/layerdeduplication/__init__.py +2 -0
- accelforge/_deprecate/layerdeduplication/group_similar_einsums.py +160 -0
- accelforge/_deprecate/layerdeduplication/grouped_einsums.py +84 -0
- accelforge/_deprecate/mapping_filter_tags/__init__.py +2 -0
- accelforge/_deprecate/mapping_filter_tags/ffmt.py +212 -0
- accelforge/_deprecate/mapping_filter_tags/onesplit.py +24 -0
- accelforge/_deprecate/mapping_filter_tags/util.py +24 -0
- accelforge/_deprecate/tags.py +69 -0
- accelforge/_deprecate/viz/__init__.py +0 -0
- accelforge/_deprecate/viz/interactive.py +159 -0
- accelforge/_deprecate/viz/reservationtree.py +307 -0
- accelforge/_deprecate/viz/ski_slope.py +88 -0
- accelforge/_version.py +15 -0
- accelforge/examples.py +39 -0
- accelforge/frontend/__init__.py +10 -0
- accelforge/frontend/_binding.py +129 -0
- accelforge/frontend/_workload_isl/__init__.py +2 -0
- accelforge/frontend/_workload_isl/_isl.py +149 -0
- accelforge/frontend/_workload_isl/_symbolic.py +141 -0
- accelforge/frontend/arch copy.py +1544 -0
- accelforge/frontend/arch.py +1642 -0
- accelforge/frontend/config.py +63 -0
- accelforge/frontend/mapper/__init__.py +5 -0
- accelforge/frontend/mapper/ffm.py +126 -0
- accelforge/frontend/mapper/mapper.py +7 -0
- accelforge/frontend/mapper/metrics.py +30 -0
- accelforge/frontend/mapping/__init__.py +1 -0
- accelforge/frontend/mapping/mapping.py +1736 -0
- accelforge/frontend/model.py +14 -0
- accelforge/frontend/renames.py +150 -0
- accelforge/frontend/spec copy.py +230 -0
- accelforge/frontend/spec.py +301 -0
- accelforge/frontend/variables.py +12 -0
- accelforge/frontend/workload.py +952 -0
- accelforge/mapper/FFM/__init__.py +9 -0
- accelforge/mapper/FFM/_join_pmappings/__init__.py +0 -0
- accelforge/mapper/FFM/_join_pmappings/compatibility.py +653 -0
- accelforge/mapper/FFM/_join_pmappings/compress_pmappings.py +140 -0
- accelforge/mapper/FFM/_join_pmappings/join_pmappings.py +703 -0
- accelforge/mapper/FFM/_join_pmappings/pmapping_dataframe.py +901 -0
- accelforge/mapper/FFM/_join_pmappings/pmapping_group.py +337 -0
- accelforge/mapper/FFM/_make_pmappings/contraints/__init__.py +0 -0
- accelforge/mapper/FFM/_make_pmappings/contraints/constraints.py +360 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/__init__.py +1 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_loops.py +373 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_pmapping_templates.py +463 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_reservations.py +95 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_storage_order.py +382 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_storages.py +155 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmappings.py +411 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/__init__.py +1 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/make_pmappings_from_templates.py +407 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/make_tile_shapes.py +1681 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/run_model.py +170 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/symbol_relations.py +174 -0
- accelforge/mapper/FFM/_make_pmappings/pmapper_job.py +282 -0
- accelforge/mapper/FFM/_pareto_df/df_convention.py +273 -0
- accelforge/mapper/FFM/_pareto_df/pareto copy.py +836 -0
- accelforge/mapper/FFM/_pareto_df/pareto.py +508 -0
- accelforge/mapper/FFM/data.py +61 -0
- accelforge/mapper/FFM/main copy.py +236 -0
- accelforge/mapper/FFM/main.py +208 -0
- accelforge/mapper/FFM/mappings.py +510 -0
- accelforge/mapper/FFM/pmappings.py +310 -0
- accelforge/mapper/__init__.py +4 -0
- accelforge/mapper.py +0 -0
- accelforge/model/__init__.py +1 -0
- accelforge/model/_looptree/__init__.py +0 -0
- accelforge/model/_looptree/accesses.py +335 -0
- accelforge/model/_looptree/capacity/__init__.py +1 -0
- accelforge/model/_looptree/capacity/aggregators.py +36 -0
- accelforge/model/_looptree/capacity/capacity.py +47 -0
- accelforge/model/_looptree/energy.py +150 -0
- accelforge/model/_looptree/equivalent_ranks.py +29 -0
- accelforge/model/_looptree/latency/__init__.py +1 -0
- accelforge/model/_looptree/latency/latency.py +98 -0
- accelforge/model/_looptree/latency/memory.py +120 -0
- accelforge/model/_looptree/latency/processors.py +92 -0
- accelforge/model/_looptree/mapping_utilities.py +71 -0
- accelforge/model/_looptree/reuse/__init__.py +4 -0
- accelforge/model/_looptree/reuse/isl/__init__.py +1 -0
- accelforge/model/_looptree/reuse/isl/des.py +59 -0
- accelforge/model/_looptree/reuse/isl/isl_functions.py +374 -0
- accelforge/model/_looptree/reuse/isl/mapping_to_isl/__init__.py +4 -0
- accelforge/model/_looptree/reuse/isl/mapping_to_isl/analyze_mapping.py +297 -0
- accelforge/model/_looptree/reuse/isl/mapping_to_isl/skews_from_mapping.py +236 -0
- accelforge/model/_looptree/reuse/isl/mapping_to_isl/tiling.py +685 -0
- accelforge/model/_looptree/reuse/isl/mapping_to_isl/types.py +188 -0
- accelforge/model/_looptree/reuse/isl/spatial.py +260 -0
- accelforge/model/_looptree/reuse/isl/temporal.py +182 -0
- accelforge/model/_looptree/reuse/symbolic/__init__.py +1 -0
- accelforge/model/_looptree/reuse/symbolic/symbolic copy 2.py +1346 -0
- accelforge/model/_looptree/reuse/symbolic/symbolic copy.py +1408 -0
- accelforge/model/_looptree/reuse/symbolic/symbolic.py +1396 -0
- accelforge/model/_looptree/run.py +122 -0
- accelforge/model/_looptree/types.py +26 -0
- accelforge/model/_looptree/visualization/__init__.py +0 -0
- accelforge/model/_looptree/visualization/occupancy.py +11 -0
- accelforge/model/main.py +222 -0
- accelforge/plotting/__init__.py +2 -0
- accelforge/plotting/mappings.py +219 -0
- accelforge/plotting/specs.py +57 -0
- accelforge/util/__init__.py +4 -0
- accelforge/util/_base_analysis_types.py +24 -0
- accelforge/util/_basetypes.py +1089 -0
- accelforge/util/_frozenset.py +36 -0
- accelforge/util/_isl.py +29 -0
- accelforge/util/_itertools.py +14 -0
- accelforge/util/_mathfuncs.py +57 -0
- accelforge/util/_parse_expressions.py +339 -0
- accelforge/util/_picklecache.py +32 -0
- accelforge/util/_setexpressions.py +268 -0
- accelforge/util/_sympy/__init__.py +0 -0
- accelforge/util/_sympy/broadcast_max.py +18 -0
- accelforge/util/_visualization.py +112 -0
- accelforge/util/_yaml.py +579 -0
- accelforge/util/parallel.py +193 -0
- accelforge-0.0.1.dist-info/METADATA +64 -0
- accelforge-0.0.1.dist-info/RECORD +258 -0
- accelforge-0.0.1.dist-info/WHEEL +5 -0
- accelforge-0.0.1.dist-info/licenses/LICENSE +19 -0
- accelforge-0.0.1.dist-info/top_level.txt +5 -0
- docs/_build/html/_sources/fastfusion.frontend.mapper.rst.txt +37 -0
- docs/_build/html/_sources/fastfusion.frontend.rst.txt +70 -0
- docs/_build/html/_sources/fastfusion.frontend.workload.rst.txt +21 -0
- docs/_build/html/_sources/fastfusion.mapper.FFM.rst.txt +37 -0
- docs/_build/html/_sources/fastfusion.mapper.rst.txt +18 -0
- docs/_build/html/_sources/fastfusion.rst.txt +20 -0
- docs/_build/html/_sources/fastfusion.util.rst.txt +21 -0
- docs/_build/html/_sources/index.rst.txt +87 -0
- docs/_build/html/_sources/modules.rst.txt +7 -0
- docs/_build/html/_sources/notes/citation.rst.txt +45 -0
- docs/_build/html/_sources/notes/definitions.rst.txt +43 -0
- docs/_build/html/_sources/notes/faqs.rst.txt +39 -0
- docs/_build/html/_sources/notes/modeling/accelerator_energy_latency.rst.txt +72 -0
- docs/_build/html/_sources/notes/modeling/component_energy_area.rst.txt +96 -0
- docs/_build/html/_sources/notes/modeling/mapping.rst.txt +100 -0
- docs/_build/html/_sources/notes/modeling.rst.txt +33 -0
- docs/_build/html/_sources/notes/parsing/arithmetic_parsing.rst.txt +136 -0
- docs/_build/html/_sources/notes/parsing/setexpressions.rst.txt +63 -0
- docs/_build/html/_sources/notes/parsing/yaml_parsing.rst.txt +176 -0
- docs/_build/html/_sources/notes/quickstart_and_installation.rst.txt +9 -0
- docs/_build/html/_sources/notes/spec/architecture.rst.txt +133 -0
- docs/_build/html/_sources/notes/spec/mapping.rst.txt +12 -0
- docs/_build/html/_sources/notes/spec/workload.rst.txt +83 -0
- docs/_build/html/_sources/notes/spec.rst.txt +36 -0
- docs/source/_ext/include_attrs.py +213 -0
- docs/source/_ext/include_docstring.py +364 -0
- docs/source/_ext/include_functions.py +154 -0
- docs/source/_ext/include_notebook.py +131 -0
- docs/source/_ext/include_yaml.py +119 -0
- docs/source/_ext/inherited_attributes.py +222 -0
- docs/source/_ext/paths.py +4 -0
- docs/source/conf.py +79 -0
- examples/arches/compute_in_memory/_include.yaml +74 -0
- examples/arches/compute_in_memory/_include_functions.py +229 -0
- examples/arches/compute_in_memory/_load_spec.py +57 -0
- examples/arches/compute_in_memory/components/c2c_multiplier.py +181 -0
- examples/arches/compute_in_memory/components/dac_c2c_r2r.py +605 -0
- examples/arches/compute_in_memory/components/misc.py +195 -0
- examples/arches/compute_in_memory/components/util/bit_functions.py +51 -0
- examples/arches/compute_in_memory/components/zero_comparator.py +92 -0
- examples/arches/compute_in_memory/isaac.yaml +233 -0
- examples/arches/compute_in_memory/memory_cells/ecram_demo.yaml +63 -0
- examples/arches/compute_in_memory/memory_cells/rram_example.yaml +63 -0
- examples/arches/compute_in_memory/memory_cells/rram_isaac_isca_2016.yaml +64 -0
- examples/arches/compute_in_memory/memory_cells/rram_neurosim_default.yaml +63 -0
- examples/arches/compute_in_memory/memory_cells/rram_raella_isca_2023.yaml +70 -0
- examples/arches/compute_in_memory/memory_cells/rram_wan_nature_2022.yaml +63 -0
- examples/arches/compute_in_memory/memory_cells/sram_colonnade_jssc_2021.yaml +63 -0
- examples/arches/compute_in_memory/memory_cells/sram_example.yaml +63 -0
- examples/arches/compute_in_memory/memory_cells/sram_jia_jssc_2020.yaml +63 -0
- examples/arches/compute_in_memory/memory_cells/sram_sinangil_jssc_2021.yaml +63 -0
- examples/arches/compute_in_memory/memory_cells/sram_wang_vlsi_2022.yaml +63 -0
- examples/arches/compute_in_memory/wang_vlsi_2022.yaml +289 -0
- examples/arches/eyeriss.yaml +68 -0
- examples/arches/fanout_variations/at_glb.yaml +31 -0
- examples/arches/fanout_variations/at_glb_with_fanout_node.yaml +34 -0
- examples/arches/fanout_variations/at_mac.yaml +31 -0
- examples/arches/fanout_variations/at_mac_with_constraints.yaml +38 -0
- examples/arches/fanout_variations/at_mac_with_fanout_node.yaml +34 -0
- examples/arches/nvdla.yaml +47 -0
- examples/arches/simple.yaml +28 -0
- examples/arches/tpu_v4i.yaml +67 -0
- examples/mappings/unfused_matmuls_to_simple.yaml +33 -0
- examples/misc/component_annotated.yaml +33 -0
- examples/workloads/gpt3_6.7B.yaml +124 -0
- examples/workloads/matmuls.yaml +20 -0
- examples/workloads/mobilenet_28.yaml +81 -0
- examples/workloads/mobilenet_various_separate.yaml +106 -0
- examples/workloads/three_matmuls_annotated.yaml +59 -0
- notebooks/.ipynb_checkpoints/fastfusion_arch_study_michael-checkpoint.ipynb +359 -0
- notebooks/compute_in_memory/_scripts.py +339 -0
- notebooks/compute_in_memory/isaac.guide.ipynb +270 -0
- notebooks/compute_in_memory/wang_vlsi_2022.ipynb +602 -0
- notebooks/paths.py +4 -0
- notebooks/tutorials/.ipynb_checkpoints/1_FFM-checkpoint.ipynb +3110 -0
- notebooks/tutorials/FFM.ipynb +3498 -0
- notebooks/tutorials/_include.py +48 -0
- notebooks/tutorials/component_energy_area.ipynb +363 -0
- tests/Q_mapping.yaml +38 -0
- tests/__init__.py +0 -0
- tests/conv.mapping.yaml +27 -0
- tests/conv.workload.yaml +13 -0
- tests/conv_sym.mapping.yaml +43 -0
- tests/copy.mapping.yaml +35 -0
- tests/copy.workload.yaml +15 -0
- tests/distribuffers/__init__.py +0 -0
- tests/distribuffers/multicast/test_cases.yaml +482 -0
- tests/distribuffers/spec/binding/valid_bindings.yaml +97 -0
- tests/distribuffers/spec/distributed.yaml +100 -0
- tests/distribuffers/spec/logical_arch.yaml +32 -0
- tests/distribuffers/spec/physical_arch.yaml +69 -0
- tests/distribuffers/test_binding.py +48 -0
- tests/frontend/__init__.py +0 -0
- tests/frontend/test_mapping_viz.py +52 -0
- tests/mapper/__init__.py +0 -0
- tests/mapper/configs/conv1d/conv1d.mapping.yaml +31 -0
- tests/mapper/configs/conv1d/conv1d.workload.yaml +11 -0
- tests/mapper/configs/two_conv1d/two_conv1d.expected.yaml +38 -0
- tests/mapper/configs/two_conv1d/two_conv1d.mapping.yaml +54 -0
- tests/mapper/configs/two_conv1d/two_conv1d.workload.yaml +19 -0
- tests/mapper/test_mapping_to_isl.py +90 -0
- tests/mapper/test_spatial_reuse_analysis.py +67 -0
- tests/mapper/test_temporal_reuse_analysis.py +56 -0
- tests/mapper/util.py +58 -0
- tests/matmul.mapping.yaml +29 -0
- tests/matmul.workload.yaml +12 -0
- tests/matmul_spatial.mapping.yaml +44 -0
- tests/mha.renames.yaml +65 -0
- tests/mha.workload.yaml +67 -0
- tests/mha.yaml +59 -0
- tests/mha_full.workload.yaml +67 -0
- tests/mobilenet.workload.yaml +35 -0
- tests/mobilenet_long.workload.yaml +64 -0
- tests/pmappingcache.py +24 -0
- tests/processing_stage.arch.yaml +40 -0
- tests/snowcat.arch.yaml +36 -0
- tests/test_ffm_join_pmappings.py +106 -0
- tests/test_ffm_make_pmappings.py +82 -0
- tests/test_ffm_make_tile_shapes.py +49 -0
- tests/test_mapper.py +100 -0
- tests/test_model.py +37 -0
- tests/test_plotting.py +72 -0
- tests/test_processing_stage.py +46 -0
- tests/test_symbolic_model.py +248 -0
- tests/test_workload.py +141 -0
|
@@ -0,0 +1,374 @@
|
|
|
1
|
+
"""
|
|
2
|
+
ISL functions that encapsulate more commonly used workflows in looptree for the
|
|
3
|
+
sake of code concision.
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
import logging
|
|
7
|
+
from typing import List
|
|
8
|
+
|
|
9
|
+
import islpy as isl
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def project_dim_in_after(map_: isl.Map, start: int) -> isl.Map:
    """
    Projects out every input dim of `map_` whose index is `start` or greater.

    Parameters
    ----------
    map_:
        The map to project input dims out of.
    start:
        The index of the first input dim to project out.

    Returns
    -------
    `map_` without the input dims [start:]. When `start` is larger than the
    number of input dims, `map_` is returned unchanged.
    """
    total_in_dims: int = map_.dim(isl.dim_type.in_)
    # Nothing sits at or after `start`, so there is nothing to project out.
    if start > total_in_dims:
        return map_
    return map_.project_out(isl.dim_type.in_, start, total_in_dims - start)
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def dim_projector_range(space: isl.Space, start: int, n: int) -> isl.Map:
    """
    Given a set space, create a map that projects out the dims [start, start+n).

    The projector is built from the identity map of `space` onto itself, with
    the requested dims projected out of that map's input tuple.

    Parameters
    ----------
    space:
        The set space the dim projector is created from.
    start:
        The index of the first dim to project out.
    n:
        The number of dims, counted from `start`, to project out.

    Returns
    -------
    An `isl.Map` derived from `space` with dims [start, start+n) projected out.

    NOTE(review): the dims are removed from the input tuple
    (`isl.dim_type.in_`) while the output tuple keeps every dim; confirm this
    is the direction callers expect.
    """
    identity: isl.Map = isl.Map.identity(isl.Space.map_from_set(space))
    # TODO: propagate tuple names from `space` onto `base_map` (e.g.,
    # `space.get_tuple_name(isl.dim_type.set)`) so the projector keeps the
    # original set label on both domain and range.
    return identity.project_out(isl.dim_type.in_, start, n)
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
def dim_projector_mask(space: isl.Space, mask: List[bool]) -> isl.Map:
    """
    Given a set space, create a map that projects out every dim flagged `True`
    in `mask`.

    Parameters
    ----------
    space:
        The set space the projector is created from.
    mask:
        `mask[i]` is truthy when dim `i` must be projected out.

    Returns
    -------
    The identity map of `space` with each masked dim projected out, i.e. a dim
    `x_i` is missing from the projected tuple only if `mask[i]` is set.

    NOTE(review): the dims are removed from the input tuple
    (`isl.dim_type.in_`) while the output tuple keeps every dim; confirm this
    is the direction callers expect.
    """
    projector: isl.Map = isl.Map.identity(isl.Space.map_from_set(space))
    # TODO: set tuple names on `projector` using the set name (e.g.,
    # `space.get_tuple_name(isl.dim_type.set)`) so domain/range remain
    # attributable after masking.

    # Walk from the highest index down so that removing a dim never shifts the
    # position of a dim that still has to be inspected.
    for dim_idx in reversed(range(len(mask))):
        if not mask[dim_idx]:
            continue
        projector = projector.project_out(isl.dim_type.in_, dim_idx, 1)

    return projector
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
def add_dims_preserve_name_map(
    map_: isl.Map, dim_type: isl.dim_type, n: int
) -> isl.Map:
    """
    Wrapper of `isl.Map.add_dims` that re-applies the tuple name of `dim_type`
    after the dims were added.

    Parameters
    ----------
    map_:
        The map dims are added to.
    dim_type:
        The dimension tuple the dims are added to.
    n:
        The number of dimensions to add.

    Returns
    -------
    The map with `n` extra dims in `dim_type`, carrying the tuple name it had
    before the addition. If the tuple was unnamed, a warning (with stack
    info) is logged and the grown map is returned without a name.
    """
    # Read the name before growing the tuple so it can be put back afterwards.
    name = map_.get_tuple_name(dim_type)
    map_ = map_.add_dims(dim_type, n)
    if name is not None:
        return map_.set_tuple_name(dim_type, name)

    logging.warning(f"unnamed space for {map_}", stack_info=True)
    return map_
|
|
118
|
+
|
|
119
|
+
|
|
120
|
+
def insert_dims_preserve_name_map(
    map_: isl.Map, dim_type: isl.dim_type, pos: int, n: int
) -> isl.Map:
    """
    Wrapper of `isl.Map.insert_dims` that re-applies the tuple name of
    `dim_type` after the dims were inserted.

    Parameters
    ----------
    map_:
        The map dims are inserted into.
    dim_type:
        The dimension tuple the dims are inserted into.
    pos:
        The position to start inserting at, inclusive.
    n:
        The number of dimensions to insert.

    Returns
    -------
    The map with `n` extra dims at `pos` in `dim_type`, carrying the tuple
    name it had before the insertion. If the tuple was unnamed, a warning
    (with stack info) is logged and the map is returned without a name.
    """
    # Read the name before inserting so it can be put back afterwards.
    name = map_.get_tuple_name(dim_type)
    map_ = map_.insert_dims(dim_type, pos, n)
    if name is not None:
        return map_.set_tuple_name(dim_type, name)

    logging.warning(f"unnamed space for {map_}", stack_info=True)
    return map_
|
|
153
|
+
|
|
154
|
+
|
|
155
|
+
def insert_equal_dims_maff(
    maff: isl.MultiAff, in_pos: int, out_pos: int, n: int
) -> isl.MultiAff:
    """
    Insert `n` input dims and `n` output dims into a multi affine and make each
    new output dim equal to its matching new input dim.

    Parameters
    ----------
    maff:
        The multi affine to insert dims into.
    in_pos:
        The index at which the new input dims start in `maff`.
    out_pos:
        The index at which the new output dims start in `maff`.
    n:
        The number of dims to insert on each side.

    Returns
    -------
    A multi affine equivalent to `maff` except for `n` new input dims starting
    at `in_pos` and `n` new output dims starting at `out_pos`, where the
    affine at `out_pos + i` has coefficient 1 on input dim `in_pos + i`.
    """
    # Make room on both sides first; the equalities are wired up afterwards.
    maff = maff.insert_dims(isl.dim_type.in_, in_pos, n)
    maff = maff.insert_dims(isl.dim_type.out, out_pos, n)

    # Pair each new input dim with the new output dim at the same offset.
    new_inputs = range(in_pos, in_pos + n)
    new_outputs = range(out_pos, out_pos + n)
    for in_idx, out_idx in zip(new_inputs, new_outputs):
        tied: isl.Aff = maff.get_at(out_idx).set_coefficient_val(
            isl.dim_type.in_, in_idx, 1
        )
        maff = maff.set_aff(out_idx, tied)

    return maff
|
|
189
|
+
|
|
190
|
+
|
|
191
|
+
def insert_equal_dims_map(map_: isl.Map, in_pos: int, out_pos: int, n: int) -> isl.Map:
    """
    Insert `n` input dims and `n` output dims into a map and constrain each new
    input dim to equal its matching new output dim.

    Parameters
    ----------
    map_:
        The map to insert dims into.
    in_pos:
        The index at which the new input dims start in `map_`.
    out_pos:
        The index at which the new output dims start in `map_`.
    n:
        The number of dims to insert on each side.

    Returns
    -------
    A new map equivalent to `map_` except for `n` new input dims starting at
    `in_pos` and `n` new output dims starting at `out_pos`, with
    `in[in_pos + i] == out[out_pos + i]` for every `i` in `[0, n)`.
    """
    # Grow both tuples; the wrapper keeps the tuple names intact.
    map_ = insert_dims_preserve_name_map(map_, isl.dim_type.in_, in_pos, n)
    map_ = insert_dims_preserve_name_map(map_, isl.dim_type.out, out_pos, n)

    # All equalities live in the local space of the already-grown map.
    local_space: isl.LocalSpace = map_.get_space().to_local_space()
    for offset in range(n):
        # in - out == 0  =>  in == out
        tie: isl.Constraint = (
            isl.Constraint.alloc_equality(local_space)
            .set_coefficient_val(isl.dim_type.in_, in_pos + offset, 1)
            .set_coefficient_val(isl.dim_type.out, out_pos + offset, -1)
        )
        map_ = map_.add_constraint(tie)

    return map_
|
|
227
|
+
|
|
228
|
+
|
|
229
|
+
def map_to_prior_coordinate(n_in_dims: int, shifted_idx: int, name: str) -> isl.Map:
    """
    Create a map that relates a current time index vector to previous index
    vectors.

    The result is the union of up to two relations over an
    `n_in_dims -> n_in_dims` space:

    - when `shifted_idx > 0`: dims `[0, shifted_idx - 1)` are kept equal and
      the dim at `shifted_idx - 1` satisfies `out == in - 1`; no constraint is
      placed on the later dims by this part.
    - when `shifted_idx < n_in_dims`: an `isl.Map.lex_gt` relation over the
      trailing `n_in_dims - shifted_idx` dims, with `shifted_idx` leading dims
      inserted on both sides and constrained to be equal.

    Parameters
    ----------
    n_in_dims:
        The number of input/output dims of the dataspace.
    shifted_idx:
        The coordinate being shifted.
    name:
        The tuple name given to both the domain and the range of the result.

    Returns
    -------
    A map relating a current index vector to previous index vectors, with both
    tuples named `name`.

    Preconditions
    -------------
    - 0 <= shifted_idx <= n_in_dims - 1 (as documented by the original author;
      the code itself also accepts `shifted_idx == n_in_dims`).
    """
    # Space, accumulator map, and local space the relation is built in.
    space: isl.Space = isl.Space.alloc(isl.DEFAULT_CONTEXT, 0, n_in_dims, n_in_dims)
    result: isl.Map = isl.Map.empty(space)
    local_space: isl.LocalSpace = isl.LocalSpace.from_space(space)

    if shifted_idx > 0:
        pivot: int = shifted_idx - 1
        decremented: isl.Map = isl.Map.universe(space)

        # Dims before the pivot are carried over unchanged.
        # out - in == 0  =>  out == in
        for dim in range(pivot):
            unchanged: isl.Constraint = (
                isl.Constraint.alloc_equality(local_space)
                .set_coefficient_val(isl.dim_type.out, dim, 1)
                .set_coefficient_val(isl.dim_type.in_, dim, -1)
            )
            decremented = decremented.add_constraint(unchanged)

        # The pivot dim is decremented.
        # out - in + 1 == 0  =>  out == in - 1
        step_back: isl.Constraint = (
            isl.Constraint.alloc_equality(local_space)
            .set_coefficient_val(isl.dim_type.out, pivot, 1)
            .set_coefficient_val(isl.dim_type.in_, pivot, -1)
            .set_constant_val(1)
        )
        decremented = decremented.add_constraint(step_back)

        result = result.union(decremented)

    if shifted_idx < n_in_dims:
        # Lexicographic ordering over the trailing dims, then prepend
        # `shifted_idx` dims that are equal on both sides.
        trailing_space: isl.Space = isl.Space.set_alloc(
            isl.DEFAULT_CONTEXT,
            result.dim(isl.dim_type.param),
            n_in_dims - shifted_idx,
        )
        ordered: isl.Map = isl.Map.lex_gt(trailing_space)
        ordered = insert_equal_dims_map(ordered, 0, 0, shifted_idx)
        result = result.union(ordered)

    # Label both tuples so the map composes with other maps on `name`.
    result = result.set_tuple_name(isl.dim_type.in_, name)
    result = result.set_tuple_name(isl.dim_type.out, name)

    return result
|
|
307
|
+
|
|
308
|
+
|
|
309
|
+
def map_to_shifted(domain_space: isl.Space, pos: int, shift: int) -> isl.Map:
    """
    Given a `domain_space`, return a map from a point in that space to the
    point `shift` ahead of it along the dimension at `pos`.

    Parameters
    ----------
    domain_space:
        The space the shift is constructed on.
    pos:
        The dimension the shift is applied to.
    shift:
        The amount of shift.

    Returns
    -------
    A mapping from `[x_0, x_1, ..., x_n] -> [x_0, ..., x_{pos} + shift, ..., x_n]`
    on `domain_space`.
    """
    identity: isl.MultiAff = domain_space.identity_multi_aff_on_domain()
    # The identity affine at `pos` is x_pos; setting its constant term to
    # `shift` turns it into x_pos + shift.
    shifted_aff: isl.Aff = identity.get_at(pos).set_constant_val(shift)
    shifted: isl.MultiAff = identity.set_at(pos, shifted_aff)
    return shifted.as_map()
|
|
331
|
+
|
|
332
|
+
|
|
333
|
+
def reorder_projector(
    permutation: list[int], space: str, ctx: isl.Context = isl.DEFAULT_CONTEXT
) -> isl.Map:
    """
    A projection to reorder a space from [i_2, i_0, ..., i_n, ...] -> [i_0, i_1, ..., i_n]

    Parameters
    ----------
    permutation:
        A list where element `k` is the index of the dimension currently
        sitting at position `k` of the input tuple. It is assumed to be a
        permutation of `range(len(permutation))`; this is not validated.
    space:
        The name of the space being permuted; it is set as the tuple name of
        both the domain and the range.
    ctx:
        The isl context the map is parsed in.

    Returns
    -------
    A map that reorders the input dimensions into ascending index order,
    e.g. `[2, 0, 1]` yields `{ space[i2, i0, i1] -> space[i0, i1, i2] }`.
    """
    # Constructs the permutation in the form of a string, because it's easier
    # and less operations than explicit construction via objects.
    in_dims: str = ", ".join(f"i{dim_idx}" for dim_idx in permutation)
    # The output lists every dim in ascending order. The last output dim must
    # be i{n-1}; emitting i{permutation[-1]} there (as before) duplicated one
    # dim and dropped another whenever the permutation did not end in n-1.
    out_dims: str = ", ".join(f"i{i}" for i in range(len(permutation)))
    pattern: str = f"{{ [{in_dims}] -> [{out_dims}] }}"

    # Creates the actual map and labels both tuples with the space name.
    reorder_proj: isl.Map = isl.Map.read_from_str(ctx, pattern)
    reorder_proj = reorder_proj.set_tuple_name(isl.dim_type.in_, space)
    reorder_proj = reorder_proj.set_tuple_name(isl.dim_type.out, space)

    return reorder_proj
|
|
@@ -0,0 +1,297 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Flow of analysis:
|
|
3
|
+
- From mapping, create the iteration space. The iteration space is the
|
|
4
|
+
space of iterators in the mapping.
|
|
5
|
+
- Create the relation from iteration space to operation space.
|
|
6
|
+
- Create the relation from the iteration space to tensor space for each
|
|
7
|
+
(buffer, tensor, einsum) tuple.
|
|
8
|
+
- Run tile shape inference.
|
|
9
|
+
|
|
10
|
+
Adapted from:
|
|
11
|
+
https://github.com/NVlabs/timeloop/blob/4cf6d4cd043bc2a5d2eb02afa9063d7117a4dc11/ \
|
|
12
|
+
src/loop-analysis/mapping-to-isl/fused-mapping-to-isl.cpp
|
|
13
|
+
Relevant Name Changes:
|
|
14
|
+
- DataspaceId -> TensorName
|
|
15
|
+
- LogicalBuffer -> Buffet
|
|
16
|
+
- LogicalComputeUnit -> ComputeEinsum
|
|
17
|
+
- Loop.op_dim -> Loop.rank_variable
|
|
18
|
+
- *MappingNode.child -> MappingNode.flatten()[0]
|
|
19
|
+
- Root -> Mapping
|
|
20
|
+
- Compute.kernel -> Compute.einsum
|
|
21
|
+
- Branch -> Split
|
|
22
|
+
- FusedMapping -> Mapping
|
|
23
|
+
- [node]_id -> node (conditional on MappingNode having a valid hashing functor)
|
|
24
|
+
"""
|
|
25
|
+
|
|
26
|
+
from collections import defaultdict, deque
|
|
27
|
+
from pprint import pformat
|
|
28
|
+
from typing import List, Optional
|
|
29
|
+
|
|
30
|
+
import islpy as isl
|
|
31
|
+
|
|
32
|
+
from accelforge.frontend.mapping import (
|
|
33
|
+
Compute,
|
|
34
|
+
Mapping,
|
|
35
|
+
MappingNode,
|
|
36
|
+
MappingNodeWithChildren,
|
|
37
|
+
Sequential,
|
|
38
|
+
Storage,
|
|
39
|
+
)
|
|
40
|
+
from accelforge.frontend.workload import Workload
|
|
41
|
+
from accelforge.frontend._workload_isl._isl import get_projection_map
|
|
42
|
+
from accelforge.frontend.workload import TensorName
|
|
43
|
+
|
|
44
|
+
from accelforge.model._looptree.mapping_utilities import get_paths
|
|
45
|
+
from accelforge.model._looptree.types import Buffet
|
|
46
|
+
from accelforge.model._looptree.reuse.isl.isl_functions import project_dim_in_after
|
|
47
|
+
from accelforge.model._looptree.reuse.isl.mapping_to_isl.skews_from_mapping import (
|
|
48
|
+
skews_from_mapping,
|
|
49
|
+
)
|
|
50
|
+
|
|
51
|
+
from . import DUMP_ISL_IR
|
|
52
|
+
from .tiling import tiling_from_mapping
|
|
53
|
+
from .types import (
|
|
54
|
+
BranchTiling,
|
|
55
|
+
BufferTensorEinsum,
|
|
56
|
+
ComputeEinsum,
|
|
57
|
+
MappingAnalysisResult,
|
|
58
|
+
Occupancy,
|
|
59
|
+
OperationOccupancy,
|
|
60
|
+
SkewsInfo,
|
|
61
|
+
)
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
def buffet_direct_above_sequential(mapping: Mapping) -> defaultdict[Buffet, bool]:
    """
    Record, for every buffet met in the mapping, whether it is directly above a
    :class:`~.Sequential`.

    Every path returned by ``get_paths(mapping)`` is scanned from its first
    node to its last. Each :class:`~.Storage` node contributes one buffet per
    tensor it holds, built with the einsum of the path's last node and the
    storage's component as the level. The buffets gathered from an
    uninterrupted run of storages are marked ``True`` when the node ending the
    run is a :class:`~.Sequential`; when any other node ends the run they are
    only registered, keeping a previous ``True`` and defaulting to ``False``.

    TODO: Verify this docstring against the Timeloop implementation.

    Parameters
    ----------
    mapping:
        The mapping whose paths are scanned.

    Returns
    -------
    A dictionary (defaulting to ``False``) of buffets and whether they're
    directly above a Sequential.
    """
    above_sequential: defaultdict[Buffet, bool] = defaultdict(lambda: False)

    # TODO: Figure out if get_paths is just for certain MappingNodesWithChildren
    # or not.
    for path in get_paths(mapping):
        compute_leaf: Compute = path[-1]
        # Buffets collected since the last node that was not a Storage.
        pending: List[Buffet] = []

        for node in path:
            if isinstance(node, Storage):
                # TODO: Verify this port:
                # https://github.com/NVlabs/timeloop/blob/32370826fdf1aa3c8deb0c93e6b2a2fc7cf053aa/src/loop-analysis/mapping-to-isl/fused-mapping-to-isl.cpp#L518-L520
                # Note: Buffet seems to have changed a lot?
                # https://github.com/NVlabs/timeloop/blob/master/include/loop-analysis/isl-ir.hpp#L96
                for tensor in node.tensors:
                    pending.append(
                        Buffet(
                            tensor=tensor,
                            einsum=compute_leaf.einsum,
                            level=node.component,
                        )
                    )
                # TODO: Check that all buffets are unique, because right now
                # it seems it's dependent on the last leaf in traversal?
                continue

            if isinstance(node, Sequential):
                # The whole pending run of buffets ends at a Sequential.
                for buffet in pending:
                    above_sequential[buffet] = True
            else:
                # Any other node breaks the run: register the buffets without
                # overwriting a True recorded on an earlier path.
                for buffet in pending:
                    above_sequential.setdefault(buffet, False)
            pending.clear()

    return above_sequential
|
|
118
|
+
|
|
119
|
+
|
|
120
|
+
def get_parallelism(mapping: Mapping) -> defaultdict[MappingNode, float]:
    """
    Collect the parallelism of every Compute node reachable from the root of a
    `accelforge.frontend.mapping.Mapping`.

    The tree is walked depth-first with an explicit stack. Nodes with children
    are expanded, Compute nodes are recorded, and every other node is ignored.

    Parameters
    ----------
    mapping:
        The mapping to get parallelism for.

    Returns
    -------
    A map relating Compute nodes with their parallelism. A Compute without a
    ``parallelism`` attribute is recorded with a parallelism of 1. The
    container has no default factory, so looking up a missing node raises
    ``KeyError``.
    """
    parallelism: defaultdict[MappingNode, float] = defaultdict()

    # Depth-first traversal starting at the root of the mapping.
    pending: deque[MappingNode] = deque([mapping])
    while pending:
        current: MappingNode = pending.pop()

        if isinstance(current, MappingNodeWithChildren):
            # Descend into the children to find the computes.
            pending.extend(current.nodes)
        elif isinstance(current, Compute):
            # Trust a pre-specified parallelism when the Compute carries one;
            # otherwise assume none.
            parallelism[current] = getattr(current, "parallelism", 1)

    return parallelism
|
|
157
|
+
|
|
158
|
+
|
|
159
|
+
def align_dim_names(
    map_: isl.Map,
    reference: isl.Map,
    map_align_dim_type: isl.dim_type = isl.dim_type.in_,
    reference_dim_type: Optional[isl.dim_type] = None,
) -> isl.Map:
    """
    Copy as many dimension names as possible from a reference `isl.Map` onto
    another `isl.Map`, position by position.

    e.g. with the default dimension types, `map_ = [i] -> [o]` and
    `reference = [x] -> [y]` yield `[x] -> [o]`.

    Parameters
    ----------
    map_:
        The map whose dimensions are being renamed.
    reference:
        The map whose dimension names are used as reference for aligning
        `map_`.
    map_align_dim_type:
        Dimension tuple in `map_` to align. Defaults to `isl.dim_type.in_`.
    reference_dim_type:
        Dimension tuple in `reference` whose names should be copied. Defaults to
        `map_align_dim_type`.

    Returns
    -------
    A version of `map_` whose leading `map_align_dim_type` dimensions carry the
    names of the corresponding `reference` dimensions. Only the positions both
    tuples have are considered, and unnamed reference dimensions are skipped.
    """
    source_dim_type = (
        map_align_dim_type if reference_dim_type is None else reference_dim_type
    )

    # Only the positions present in both tuples can be aligned.
    shared_dims = min(map_.dim(map_align_dim_type), reference.dim(source_dim_type))

    aligned = map_
    for position in range(shared_dims):
        name: Optional[str] = reference.get_dim_name(source_dim_type, position)
        if name is None:
            # Nothing to copy for an unnamed reference dimension.
            continue
        aligned = aligned.set_dim_name(map_align_dim_type, position, name)

    return aligned
|
|
199
|
+
|
|
200
|
+
|
|
201
|
+
def occupancies_from_mapping(
    mapping: Mapping, workload: Workload
) -> MappingAnalysisResult:
    """
    Derive the data and operation occupancies of a Workload under a Mapping.

    The tilings come from ``tiling_from_mapping`` and the skews from
    ``skews_from_mapping``. For every entry of ``skews.bte_to_skew`` whose
    tensor is an input or output of its einsum, the occupancy is the skew
    (with its input names aligned to the tiling) composed with the tiling
    applied to the tensor's projection map. Entries whose tensor is neither
    read nor written by the einsum are skipped. For every entry of
    ``skews.ce_unit_to_skew`` the occupancy is the skew composed with the
    tiling alone. When ``DUMP_ISL_IR`` is set, intermediate maps are printed.

    Parameters
    ----------
    mapping:
        The Mapping of data to hardware.
    workload:
        The Workload occurring on chip.

    Returns
    -------
    The occupancies as an analysis of the Workload on Mapping, bundled with the
    result of ``buffet_direct_above_sequential``, the result of
    ``get_parallelism`` and the branch tiling.
    """
    branch_tiling: BranchTiling = tiling_from_mapping(mapping, workload)
    # tiling: [tile_iteration_space] -> [iteration_space]
    if DUMP_ISL_IR:
        for node, tiling in branch_tiling.items():
            print(f"[Tiling]Node({node}): {tiling}")
            # TODO: Port this line
            # https://github.com/NVlabs/timeloop/blob/32370826fdf1aa3c8deb0c93e6b2a2fc7cf053aa/src/loop-analysis/mapping-to-isl/fused-mapping-to-isl.cpp#L55-L64
            print(f"[Ops]Node({node}): ")

    skews: SkewsInfo = skews_from_mapping(mapping, workload)
    # skew [Spacetime] -> [tile_iteration_space]
    if DUMP_ISL_IR:
        print(f"skews: {pformat(skews)}")

    ### Somewhere, call the domain space of the returned tilings {einsum}_iterations
    occupancies: defaultdict[BufferTensorEinsum, Occupancy] = defaultdict()
    for bte, skew in skews.bte_to_skew.items():
        if DUMP_ISL_IR:
            print(f"{bte} has skew: {skew}")
        tiling = branch_tiling[bte.einsum]

        einsum = workload.einsums[bte.einsum.einsum]
        # A tensor the einsum neither reads nor writes has no accesses to map.
        if (
            bte.tensor not in einsum.input_tensor_names
            and bte.tensor not in einsum.output_tensor_names
        ):
            continue
        accesses: isl.Map = get_projection_map(einsum, bte.tensor)

        aligned_skew: isl.Map = align_dim_names(skew.map_, tiling)
        if DUMP_ISL_IR:
            print(f"Skew: {skew.map_}")
            print(f"Aligned Skew: {aligned_skew}")
            print(f"Tiling: {tiling}")
            print(f"{tiling.apply_range(accesses)}")
            print(f"{skew.map_.dim(isl.dim_type.out)}")

        # Trim the tile-to-data map with the skew's output dimension count and
        # rename its input tuple after the skew's output tuple before composing.
        tile_to_data: isl.Map = project_dim_in_after(
            tiling.apply_range(accesses),
            skew.map_.dim(isl.dim_type.out),
        )
        # TODO: fix this unsafe mixing.
        tile_to_data = tile_to_data.set_tuple_name(
            isl.dim_type.in_, aligned_skew.get_tuple_name(isl.dim_type.out)
        )
        occupancy: isl.Map = aligned_skew.apply_range(tile_to_data)

        occupancies[bte] = Occupancy(skew.tags, occupancy)

    operations_occupancies: defaultdict[ComputeEinsum, OperationOccupancy] = (
        defaultdict()
    )
    for ce, skew in skews.ce_unit_to_skew.items():
        tiling = branch_tiling[ce.branch_leaf_node]
        if DUMP_ISL_IR:
            print(f"skew.map_ {skew.map_}")

        tile_to_ops: isl.Map = project_dim_in_after(
            tiling, skew.map_.dim(isl.dim_type.out)
        )
        # TODO: Unify the names at some point...
        tile_to_ops = tile_to_ops.set_tuple_name(
            isl.dim_type.in_,
            skew.map_.get_tuple_name(isl.dim_type.out),
        )
        operation_occupancy: isl.Map = skew.map_.apply_range(tile_to_ops)

        operations_occupancies[ce] = OperationOccupancy(skew.tags, operation_occupancy)

    return MappingAnalysisResult(
        buffet_to_occupancy=occupancies,
        compute_einsum_to_occupancy=operations_occupancies,
        buffet_direct_above_sequential=buffet_direct_above_sequential(mapping),
        compute_to_assumed_parallelism=get_parallelism(mapping),
        branch_tiling=branch_tiling,
    )
|