accelforge 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- accelforge/__init__.py +21 -0
- accelforge/_accelerated_imports.py +16 -0
- accelforge/_deprecate/_simanneal/evalmapping.py +271 -0
- accelforge/_deprecate/_simanneal/mapspaceglobals.py +298 -0
- accelforge/_deprecate/_simanneal/simanneal.py +666 -0
- accelforge/_deprecate/_simanneal/tracking.py +105 -0
- accelforge/_deprecate/_simanneal/wrappers.py +218 -0
- accelforge/_deprecate/_simanneal2/__init__.py +7 -0
- accelforge/_deprecate/_simanneal2/simanneal.py +493 -0
- accelforge/_deprecate/_simanneal2/tracking.py +116 -0
- accelforge/_deprecate/compatibility_util.py +181 -0
- accelforge/_deprecate/layerdeduplication/__init__.py +2 -0
- accelforge/_deprecate/layerdeduplication/group_similar_einsums.py +160 -0
- accelforge/_deprecate/layerdeduplication/grouped_einsums.py +84 -0
- accelforge/_deprecate/mapping_filter_tags/__init__.py +2 -0
- accelforge/_deprecate/mapping_filter_tags/ffmt.py +212 -0
- accelforge/_deprecate/mapping_filter_tags/onesplit.py +24 -0
- accelforge/_deprecate/mapping_filter_tags/util.py +24 -0
- accelforge/_deprecate/tags.py +69 -0
- accelforge/_deprecate/viz/__init__.py +0 -0
- accelforge/_deprecate/viz/interactive.py +159 -0
- accelforge/_deprecate/viz/reservationtree.py +307 -0
- accelforge/_deprecate/viz/ski_slope.py +88 -0
- accelforge/_version.py +15 -0
- accelforge/examples.py +39 -0
- accelforge/frontend/__init__.py +10 -0
- accelforge/frontend/_binding.py +129 -0
- accelforge/frontend/_workload_isl/__init__.py +2 -0
- accelforge/frontend/_workload_isl/_isl.py +149 -0
- accelforge/frontend/_workload_isl/_symbolic.py +141 -0
- accelforge/frontend/arch copy.py +1544 -0
- accelforge/frontend/arch.py +1642 -0
- accelforge/frontend/config.py +63 -0
- accelforge/frontend/mapper/__init__.py +5 -0
- accelforge/frontend/mapper/ffm.py +126 -0
- accelforge/frontend/mapper/mapper.py +7 -0
- accelforge/frontend/mapper/metrics.py +30 -0
- accelforge/frontend/mapping/__init__.py +1 -0
- accelforge/frontend/mapping/mapping.py +1736 -0
- accelforge/frontend/model.py +14 -0
- accelforge/frontend/renames.py +150 -0
- accelforge/frontend/spec copy.py +230 -0
- accelforge/frontend/spec.py +301 -0
- accelforge/frontend/variables.py +12 -0
- accelforge/frontend/workload.py +952 -0
- accelforge/mapper/FFM/__init__.py +9 -0
- accelforge/mapper/FFM/_join_pmappings/__init__.py +0 -0
- accelforge/mapper/FFM/_join_pmappings/compatibility.py +653 -0
- accelforge/mapper/FFM/_join_pmappings/compress_pmappings.py +140 -0
- accelforge/mapper/FFM/_join_pmappings/join_pmappings.py +703 -0
- accelforge/mapper/FFM/_join_pmappings/pmapping_dataframe.py +901 -0
- accelforge/mapper/FFM/_join_pmappings/pmapping_group.py +337 -0
- accelforge/mapper/FFM/_make_pmappings/contraints/__init__.py +0 -0
- accelforge/mapper/FFM/_make_pmappings/contraints/constraints.py +360 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/__init__.py +1 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_loops.py +373 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_pmapping_templates.py +463 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_reservations.py +95 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_storage_order.py +382 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_storages.py +155 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmappings.py +411 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/__init__.py +1 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/make_pmappings_from_templates.py +407 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/make_tile_shapes.py +1681 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/run_model.py +170 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/symbol_relations.py +174 -0
- accelforge/mapper/FFM/_make_pmappings/pmapper_job.py +282 -0
- accelforge/mapper/FFM/_pareto_df/df_convention.py +273 -0
- accelforge/mapper/FFM/_pareto_df/pareto copy.py +836 -0
- accelforge/mapper/FFM/_pareto_df/pareto.py +508 -0
- accelforge/mapper/FFM/data.py +61 -0
- accelforge/mapper/FFM/main copy.py +236 -0
- accelforge/mapper/FFM/main.py +208 -0
- accelforge/mapper/FFM/mappings.py +510 -0
- accelforge/mapper/FFM/pmappings.py +310 -0
- accelforge/mapper/__init__.py +4 -0
- accelforge/mapper.py +0 -0
- accelforge/model/__init__.py +1 -0
- accelforge/model/_looptree/__init__.py +0 -0
- accelforge/model/_looptree/accesses.py +335 -0
- accelforge/model/_looptree/capacity/__init__.py +1 -0
- accelforge/model/_looptree/capacity/aggregators.py +36 -0
- accelforge/model/_looptree/capacity/capacity.py +47 -0
- accelforge/model/_looptree/energy.py +150 -0
- accelforge/model/_looptree/equivalent_ranks.py +29 -0
- accelforge/model/_looptree/latency/__init__.py +1 -0
- accelforge/model/_looptree/latency/latency.py +98 -0
- accelforge/model/_looptree/latency/memory.py +120 -0
- accelforge/model/_looptree/latency/processors.py +92 -0
- accelforge/model/_looptree/mapping_utilities.py +71 -0
- accelforge/model/_looptree/reuse/__init__.py +4 -0
- accelforge/model/_looptree/reuse/isl/__init__.py +1 -0
- accelforge/model/_looptree/reuse/isl/des.py +59 -0
- accelforge/model/_looptree/reuse/isl/isl_functions.py +374 -0
- accelforge/model/_looptree/reuse/isl/mapping_to_isl/__init__.py +4 -0
- accelforge/model/_looptree/reuse/isl/mapping_to_isl/analyze_mapping.py +297 -0
- accelforge/model/_looptree/reuse/isl/mapping_to_isl/skews_from_mapping.py +236 -0
- accelforge/model/_looptree/reuse/isl/mapping_to_isl/tiling.py +685 -0
- accelforge/model/_looptree/reuse/isl/mapping_to_isl/types.py +188 -0
- accelforge/model/_looptree/reuse/isl/spatial.py +260 -0
- accelforge/model/_looptree/reuse/isl/temporal.py +182 -0
- accelforge/model/_looptree/reuse/symbolic/__init__.py +1 -0
- accelforge/model/_looptree/reuse/symbolic/symbolic copy 2.py +1346 -0
- accelforge/model/_looptree/reuse/symbolic/symbolic copy.py +1408 -0
- accelforge/model/_looptree/reuse/symbolic/symbolic.py +1396 -0
- accelforge/model/_looptree/run.py +122 -0
- accelforge/model/_looptree/types.py +26 -0
- accelforge/model/_looptree/visualization/__init__.py +0 -0
- accelforge/model/_looptree/visualization/occupancy.py +11 -0
- accelforge/model/main.py +222 -0
- accelforge/plotting/__init__.py +2 -0
- accelforge/plotting/mappings.py +219 -0
- accelforge/plotting/specs.py +57 -0
- accelforge/util/__init__.py +4 -0
- accelforge/util/_base_analysis_types.py +24 -0
- accelforge/util/_basetypes.py +1089 -0
- accelforge/util/_frozenset.py +36 -0
- accelforge/util/_isl.py +29 -0
- accelforge/util/_itertools.py +14 -0
- accelforge/util/_mathfuncs.py +57 -0
- accelforge/util/_parse_expressions.py +339 -0
- accelforge/util/_picklecache.py +32 -0
- accelforge/util/_setexpressions.py +268 -0
- accelforge/util/_sympy/__init__.py +0 -0
- accelforge/util/_sympy/broadcast_max.py +18 -0
- accelforge/util/_visualization.py +112 -0
- accelforge/util/_yaml.py +579 -0
- accelforge/util/parallel.py +193 -0
- accelforge-0.0.1.dist-info/METADATA +64 -0
- accelforge-0.0.1.dist-info/RECORD +258 -0
- accelforge-0.0.1.dist-info/WHEEL +5 -0
- accelforge-0.0.1.dist-info/licenses/LICENSE +19 -0
- accelforge-0.0.1.dist-info/top_level.txt +5 -0
- docs/_build/html/_sources/fastfusion.frontend.mapper.rst.txt +37 -0
- docs/_build/html/_sources/fastfusion.frontend.rst.txt +70 -0
- docs/_build/html/_sources/fastfusion.frontend.workload.rst.txt +21 -0
- docs/_build/html/_sources/fastfusion.mapper.FFM.rst.txt +37 -0
- docs/_build/html/_sources/fastfusion.mapper.rst.txt +18 -0
- docs/_build/html/_sources/fastfusion.rst.txt +20 -0
- docs/_build/html/_sources/fastfusion.util.rst.txt +21 -0
- docs/_build/html/_sources/index.rst.txt +87 -0
- docs/_build/html/_sources/modules.rst.txt +7 -0
- docs/_build/html/_sources/notes/citation.rst.txt +45 -0
- docs/_build/html/_sources/notes/definitions.rst.txt +43 -0
- docs/_build/html/_sources/notes/faqs.rst.txt +39 -0
- docs/_build/html/_sources/notes/modeling/accelerator_energy_latency.rst.txt +72 -0
- docs/_build/html/_sources/notes/modeling/component_energy_area.rst.txt +96 -0
- docs/_build/html/_sources/notes/modeling/mapping.rst.txt +100 -0
- docs/_build/html/_sources/notes/modeling.rst.txt +33 -0
- docs/_build/html/_sources/notes/parsing/arithmetic_parsing.rst.txt +136 -0
- docs/_build/html/_sources/notes/parsing/setexpressions.rst.txt +63 -0
- docs/_build/html/_sources/notes/parsing/yaml_parsing.rst.txt +176 -0
- docs/_build/html/_sources/notes/quickstart_and_installation.rst.txt +9 -0
- docs/_build/html/_sources/notes/spec/architecture.rst.txt +133 -0
- docs/_build/html/_sources/notes/spec/mapping.rst.txt +12 -0
- docs/_build/html/_sources/notes/spec/workload.rst.txt +83 -0
- docs/_build/html/_sources/notes/spec.rst.txt +36 -0
- docs/source/_ext/include_attrs.py +213 -0
- docs/source/_ext/include_docstring.py +364 -0
- docs/source/_ext/include_functions.py +154 -0
- docs/source/_ext/include_notebook.py +131 -0
- docs/source/_ext/include_yaml.py +119 -0
- docs/source/_ext/inherited_attributes.py +222 -0
- docs/source/_ext/paths.py +4 -0
- docs/source/conf.py +79 -0
- examples/arches/compute_in_memory/_include.yaml +74 -0
- examples/arches/compute_in_memory/_include_functions.py +229 -0
- examples/arches/compute_in_memory/_load_spec.py +57 -0
- examples/arches/compute_in_memory/components/c2c_multiplier.py +181 -0
- examples/arches/compute_in_memory/components/dac_c2c_r2r.py +605 -0
- examples/arches/compute_in_memory/components/misc.py +195 -0
- examples/arches/compute_in_memory/components/util/bit_functions.py +51 -0
- examples/arches/compute_in_memory/components/zero_comparator.py +92 -0
- examples/arches/compute_in_memory/isaac.yaml +233 -0
- examples/arches/compute_in_memory/memory_cells/ecram_demo.yaml +63 -0
- examples/arches/compute_in_memory/memory_cells/rram_example.yaml +63 -0
- examples/arches/compute_in_memory/memory_cells/rram_isaac_isca_2016.yaml +64 -0
- examples/arches/compute_in_memory/memory_cells/rram_neurosim_default.yaml +63 -0
- examples/arches/compute_in_memory/memory_cells/rram_raella_isca_2023.yaml +70 -0
- examples/arches/compute_in_memory/memory_cells/rram_wan_nature_2022.yaml +63 -0
- examples/arches/compute_in_memory/memory_cells/sram_colonnade_jssc_2021.yaml +63 -0
- examples/arches/compute_in_memory/memory_cells/sram_example.yaml +63 -0
- examples/arches/compute_in_memory/memory_cells/sram_jia_jssc_2020.yaml +63 -0
- examples/arches/compute_in_memory/memory_cells/sram_sinangil_jssc_2021.yaml +63 -0
- examples/arches/compute_in_memory/memory_cells/sram_wang_vlsi_2022.yaml +63 -0
- examples/arches/compute_in_memory/wang_vlsi_2022.yaml +289 -0
- examples/arches/eyeriss.yaml +68 -0
- examples/arches/fanout_variations/at_glb.yaml +31 -0
- examples/arches/fanout_variations/at_glb_with_fanout_node.yaml +34 -0
- examples/arches/fanout_variations/at_mac.yaml +31 -0
- examples/arches/fanout_variations/at_mac_with_constraints.yaml +38 -0
- examples/arches/fanout_variations/at_mac_with_fanout_node.yaml +34 -0
- examples/arches/nvdla.yaml +47 -0
- examples/arches/simple.yaml +28 -0
- examples/arches/tpu_v4i.yaml +67 -0
- examples/mappings/unfused_matmuls_to_simple.yaml +33 -0
- examples/misc/component_annotated.yaml +33 -0
- examples/workloads/gpt3_6.7B.yaml +124 -0
- examples/workloads/matmuls.yaml +20 -0
- examples/workloads/mobilenet_28.yaml +81 -0
- examples/workloads/mobilenet_various_separate.yaml +106 -0
- examples/workloads/three_matmuls_annotated.yaml +59 -0
- notebooks/.ipynb_checkpoints/fastfusion_arch_study_michael-checkpoint.ipynb +359 -0
- notebooks/compute_in_memory/_scripts.py +339 -0
- notebooks/compute_in_memory/isaac.guide.ipynb +270 -0
- notebooks/compute_in_memory/wang_vlsi_2022.ipynb +602 -0
- notebooks/paths.py +4 -0
- notebooks/tutorials/.ipynb_checkpoints/1_FFM-checkpoint.ipynb +3110 -0
- notebooks/tutorials/FFM.ipynb +3498 -0
- notebooks/tutorials/_include.py +48 -0
- notebooks/tutorials/component_energy_area.ipynb +363 -0
- tests/Q_mapping.yaml +38 -0
- tests/__init__.py +0 -0
- tests/conv.mapping.yaml +27 -0
- tests/conv.workload.yaml +13 -0
- tests/conv_sym.mapping.yaml +43 -0
- tests/copy.mapping.yaml +35 -0
- tests/copy.workload.yaml +15 -0
- tests/distribuffers/__init__.py +0 -0
- tests/distribuffers/multicast/test_cases.yaml +482 -0
- tests/distribuffers/spec/binding/valid_bindings.yaml +97 -0
- tests/distribuffers/spec/distributed.yaml +100 -0
- tests/distribuffers/spec/logical_arch.yaml +32 -0
- tests/distribuffers/spec/physical_arch.yaml +69 -0
- tests/distribuffers/test_binding.py +48 -0
- tests/frontend/__init__.py +0 -0
- tests/frontend/test_mapping_viz.py +52 -0
- tests/mapper/__init__.py +0 -0
- tests/mapper/configs/conv1d/conv1d.mapping.yaml +31 -0
- tests/mapper/configs/conv1d/conv1d.workload.yaml +11 -0
- tests/mapper/configs/two_conv1d/two_conv1d.expected.yaml +38 -0
- tests/mapper/configs/two_conv1d/two_conv1d.mapping.yaml +54 -0
- tests/mapper/configs/two_conv1d/two_conv1d.workload.yaml +19 -0
- tests/mapper/test_mapping_to_isl.py +90 -0
- tests/mapper/test_spatial_reuse_analysis.py +67 -0
- tests/mapper/test_temporal_reuse_analysis.py +56 -0
- tests/mapper/util.py +58 -0
- tests/matmul.mapping.yaml +29 -0
- tests/matmul.workload.yaml +12 -0
- tests/matmul_spatial.mapping.yaml +44 -0
- tests/mha.renames.yaml +65 -0
- tests/mha.workload.yaml +67 -0
- tests/mha.yaml +59 -0
- tests/mha_full.workload.yaml +67 -0
- tests/mobilenet.workload.yaml +35 -0
- tests/mobilenet_long.workload.yaml +64 -0
- tests/pmappingcache.py +24 -0
- tests/processing_stage.arch.yaml +40 -0
- tests/snowcat.arch.yaml +36 -0
- tests/test_ffm_join_pmappings.py +106 -0
- tests/test_ffm_make_pmappings.py +82 -0
- tests/test_ffm_make_tile_shapes.py +49 -0
- tests/test_mapper.py +100 -0
- tests/test_model.py +37 -0
- tests/test_plotting.py +72 -0
- tests/test_processing_stage.py +46 -0
- tests/test_symbolic_model.py +248 -0
- tests/test_workload.py +141 -0
|
@@ -0,0 +1,188 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Relevant name changes:
|
|
3
|
+
- [logical] buffer/lbuf -> buffet
|
|
4
|
+
- [logical] comp/lcomp -> compute_einsum
|
|
5
|
+
-
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from abc import ABC
|
|
9
|
+
|
|
10
|
+
from collections import defaultdict
|
|
11
|
+
from dataclasses import dataclass
|
|
12
|
+
from typing import List, TypeAlias
|
|
13
|
+
|
|
14
|
+
import islpy as isl
|
|
15
|
+
|
|
16
|
+
from accelforge.frontend.mapping import Compute, MappingNode
|
|
17
|
+
from accelforge.frontend.workload import TensorName
|
|
18
|
+
from accelforge.model._looptree.types import Buffet
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
# Mapper intermediates.
EinsumName: TypeAlias = str
"Identifier of an Einsum."

# Tiling: an Iteration -> Operation relation that specifies the tiling.
#
# Data and operations are distributed by combining the tiling relation with
# the skew and data distribution relations.
#
# A tiling relation may leave bounds unspecified, in which case LoopTree
# infers them. By the time a tiling relation reaches the nest analysis it is
# guaranteed to be fully specified.
Tiling: TypeAlias = isl.Map
"Tiling of data and operations."
BranchTiling: TypeAlias = defaultdict[MappingNode, Tiling]
"Maps a mapping node to its tiling."
BuffetTiling: TypeAlias = defaultdict[Buffet, Tiling]
"Maps a buffet to its tiling."
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
@dataclass(frozen=True, slots=True)
|
|
42
|
+
class Tag(ABC): # pylint: disable=too-few-public-methods
|
|
43
|
+
"""Associating an element with its type metadata without introspection?"""
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
class TemporalTag(Tag):  # pylint: disable=too-few-public-methods
    """Marks the tagged element as temporal (presumably: spread across time)."""
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
@dataclass(frozen=True, slots=True)
class SpatialTag(Tag):  # pylint: disable=too-few-public-methods
    """Marks the tagged element as spatial (presumably: spread across space)."""

    spatial_dim: int
    "Index of the spatial dimension, presumably within ``buffer``."
    buffer: MappingNode
    "Mapping node of the buffer the spatial dimension is (presumably) across."
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
class PipelineTag(Tag):  # pylint: disable=too-few-public-methods
    """Marks the tagged element as pipelined (presumably)."""
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
class SequentialTag(Tag):  # pylint: disable=too-few-public-methods
    """Marks the tagged element as serialized, i.e. sequential (presumably)."""
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
# Tag groupings, written as tuples so they can be passed straight to
# ``isinstance``. ``SequentialTag`` belongs to both the temporal and the
# branch group.
TEMPORAL_TAGS = (
    TemporalTag,
    SequentialTag,
)
BRANCH_TAGS = (
    PipelineTag,
    SequentialTag,
)
LOOP_TAGS = (
    TemporalTag,
    SpatialTag,
)
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
@dataclass(frozen=True, slots=True)
class TaggedMap:  # pylint: disable=too-few-public-methods
    """An :class:`isl.Map` paired with the tags that label its dimensions."""

    tags: List[Tag]
    map_: isl.Map

    def __repr__(self):
        """Render as ``<type>(<tags>, <map>)``."""
        kind = type(self)
        return f"{kind}({self.tags}, {self.map_})"
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
class Occupancy(TaggedMap):  # pylint: disable=too-few-public-methods
    """
    Location of data in (presumably logical) hardware elements.

    Construction checks that there is exactly one tag per input dimension of
    the map.
    """

    def __init__(self, tags: list[Tag], map_: isl.Map):
        tag_count = len(tags)
        in_dim_count = map_.dim(isl.dim_type.in_)
        # The message is only formatted when the check fails.
        assert tag_count == in_dim_count, (
            "Occupancy labels input dims with tags\n"
            "-------------------------------------\n"
            f"tags: {tags}\n"
            f"map: {map_}\n"
        )
        super().__init__(tags, map_)
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
class OperationOccupancy(TaggedMap):  # pylint: disable=too-few-public-methods
    """Location of operations in (presumably logical) hardware elements."""
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
class Fill(TaggedMap):
    """
    Spacetime -> fill of a logical buffer.

    Construction checks that there is exactly one tag per input dimension of
    the map.
    """

    def __init__(self, tags: list[Tag], map_: isl.Map):
        tag_count = len(tags)
        in_dim_count = map_.dim(isl.dim_type.in_)
        # The message is only formatted when the check fails.
        assert tag_count == in_dim_count, (
            "Fill labels input dims with tags\n"
            "--------------------------------\n"
            f"tags: {tags}\n"
            f"map: {map_}\n"
        )
        super().__init__(tags, map_)
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
class Skew(TaggedMap):  # pylint: disable=too-few-public-methods
    """A tagged map describing a skew. TODO: document its exact semantics."""
|
|
116
|
+
|
|
117
|
+
|
|
118
|
+
@dataclass(frozen=True, slots=True)
class BufferTensorEinsum:
    """
    Immutable (buffer, tensor, einsum) triple.

    It ties together the (presumably logical) hardware element storing data,
    a tensor it holds, and the compute leaf that requests that tensor.

    See Also:
    ---------
    :class:`accelforge.model._looptree.types.Buffet`
    """

    buffer: str
    "Logical name of the buffer that supplies the tensor."
    tensor: TensorName
    "Tensor being supplied."
    einsum: Compute
    "Leaf of the mapping that performs the einsum compute on the tensor."
|
|
135
|
+
|
|
136
|
+
|
|
137
|
+
@dataclass(frozen=True, slots=True)
class ComputeEinsum:
    """A logical computation that (presumably) the workload has to carry out."""

    compute: str
    """TODO: document; presumably the name of the compute element."""
    branch_leaf_node: Compute
    """TODO: confirm; the compute element at the leaf of a :class:`BranchTiling`."""
|
|
145
|
+
|
|
146
|
+
|
|
147
|
+
# Output classes.
|
|
148
|
+
@dataclass(frozen=True, slots=True)
class SkewsInfo:  # pylint: disable=too-few-public-methods
    """Skews keyed by what they apply to. TODO: document how they are used."""

    bte_to_skew: defaultdict[BufferTensorEinsum, Skew]
    """:class:`~.Skew` for each :class:`~.BufferTensorEinsum`."""
    ce_unit_to_skew: defaultdict[ComputeEinsum, Skew]
    """:class:`~.Skew` for each :class:`~.ComputeEinsum`."""
|
|
156
|
+
|
|
157
|
+
|
|
158
|
+
@dataclass(frozen=True, slots=True)
class MappingAnalysisResult:  # pylint: disable=too-few-public-methods
    """
    Output of the mapping analysis; this is what the reuse analysis takes as
    input.
    """

    buffet_direct_above_sequential: defaultdict[Buffet, bool]
    """
    True for a buffet that sits right above a sequential node. Capacity
    calculation uses this, because with a sequential mapping without tiling
    some data can be dropped earlier than usual.
    """
    buffet_to_occupancy: defaultdict[BufferTensorEinsum, Occupancy]
    """Occupancy of every buffet, as defined in the mapping."""
    compute_einsum_to_occupancy: defaultdict[ComputeEinsum, OperationOccupancy]
    """Occupancy of every compute unit."""
    # TODO: Figure out if this is deprecated:
    # https://github.com/NVlabs/timeloop/blob/32370826fdf1aa3c8deb0c93e6b2a2fc7cf053aa/include/loop-analysis/mapping-to-isl/fused-mapping-to-isl.hpp#L31-L35
    # node_to_buffets
    # Buffets found between the current root/branch node and the next one.
    branch_tiling: BranchTiling
    """
    Tiling of each branch, as a relation between tiling variables and
    operations. A branch that is not completely tiled has a multiple-valued
    :class:`isl.Map`.
    """
    compute_to_assumed_parallelism: defaultdict[MappingNode, float]
    """
    Assumed amount of parallelism per compute node. Dividing the number of
    operations by it gives a quick approximation of compute latency.
    """
|
|
@@ -0,0 +1,260 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Handles the ISL spatial reuse functions.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
from abc import ABC, abstractmethod
|
|
6
|
+
from dataclasses import dataclass
|
|
7
|
+
from typing import Optional
|
|
8
|
+
|
|
9
|
+
import islpy as isl
|
|
10
|
+
|
|
11
|
+
from accelforge.frontend.mapping import MappingNode
|
|
12
|
+
from accelforge.model._looptree.reuse.isl.isl_functions import (
|
|
13
|
+
insert_equal_dims_map,
|
|
14
|
+
reorder_projector,
|
|
15
|
+
)
|
|
16
|
+
from accelforge.model._looptree.reuse.isl.mapping_to_isl.types import (
|
|
17
|
+
TEMPORAL_TAGS,
|
|
18
|
+
Fill,
|
|
19
|
+
Occupancy,
|
|
20
|
+
SpatialTag,
|
|
21
|
+
Tag,
|
|
22
|
+
TaggedMap,
|
|
23
|
+
)
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class Transfers(TaggedMap):
    """Tagged map of transfers between regions in spacetime."""
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class Reads(TaggedMap):
    """Tagged map of reads between regions in spacetime."""
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
@dataclass(frozen=True, slots=True)
class TransferInfo:
    """Data-transfer information for a certain [subset] of the chip."""

    # Core results of the transfer analysis.
    fulfilled_fill: Transfers
    """Fills satisfied by peer-to-peer transfers."""
    unfulfilled_fill: Fill
    """Fills that were not performed."""
    parent_reads: Reads
    """Fills satisfied by parent-to-child transfers."""
    hops: isl.PwQPolynomial
    """Cost metric of peer-to-peer transfers across spacetime."""

    # Metadata about what kind of transfer took place.
    link_transfer: bool
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
class TransferModel(ABC):
    """
    Interface of a peer-to-peer/multicast transfer model used by the spatial
    analysis.
    """

    @abstractmethod
    def apply(self, buff: MappingNode, fills: Fill, occs: Occupancy) -> TransferInfo:
        """
        Calculate the spatial transfers of a buffer from its fills and
        occupancies across time.

        Parameters
        ----------
        buff:
            The buffer whose spatial analysis is being considered.
        fills:
            The fill of `buff` across time from parents.
        occs:
            The occupancy of `buff` across time.

        Returns
        -------
        The fills that were fulfilled, the fills that were left unfulfilled,
        and the parent reads per position in spacetime, together with the
        hops per timestep.
        """
        message = (
            f"{type(self)} has not implemented `apply(self, MappingNode, Fill, Occupancy)`"
        )
        raise NotImplementedError(message)

    def __repr__(self):
        """Identify the transfer model by its type."""
        return str(type(self))
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
class SimpleLinkTransferModel(TransferModel):
    """
    Basic link transfer model.

    A fill counts as fulfilled when the same data is present in the occupancy
    of a point related to it by the mesh from ``make_mesh_connectivity``.
    Every other fill is reported as unfulfilled. Parent reads are always
    empty, because only peer-to-peer transfers are analyzed here.
    """

    def apply(self, buff: MappingNode, fills: Fill, occs: Occupancy) -> TransferInfo:
        """
        Split `fills` into fills served by mesh neighbors and fills left over.

        Parameters
        ----------
        buff:
            The buffer whose spatial dims are analyzed.
        fills:
            The fill of `buff` across spacetime.
        occs:
            The occupancy of `buff` across spacetime; must carry the same
            tags as `fills`.

        Returns
        -------
        A `TransferInfo` whose `hops` is one for every fulfilled fill and
        whose `parent_reads` is empty.

        Raises
        ------
        AssertionError
            If the tags of `fills` and `occs` differ.
        ValueError
            From `make_mesh_connectivity`, if `buff` does not have exactly one
            or two spatial dims and a temporal dim is present.
        """
        # Sanity check the fill is for the same occupancy. Necessary but insufficient proof.
        assert fills.tags == occs.tags, (
            "Fill and Occupancy mismatch\n"
            "---------------------------\n"
            f"Fill: {fills}\n"
            f"Occs: {occs}\n"
        )

        # Gets number of input dimensions, along with spatial and temporal indices.
        n: int = fills.map_.dim(isl.dim_type.in_)
        spatial_dims: list[int] = get_spatial_tags_idxs(fills.tags, buff)
        # Only the presence of a temporal tag matters below, not its index.
        last_temporal: Optional[int] = get_last_temporal_tag_idx(fills.tags)

        # No temporal or no spatial dims, you're just not moving data across time
        # so no transfers occurring.
        if last_temporal is None or not spatial_dims:
            return TransferInfo(
                # Subtracting a map from itself yields an empty map in the same space.
                fulfilled_fill=Transfers(fills.tags, fills.map_.subtract(fills.map_)),
                # No fulfilled_fills, so only unfulfilled_fills.
                unfulfilled_fill=fills,
                parent_reads=Reads(occs.tags, occs.map_.subtract(occs.map_)),
                hops=isl.PwQPolynomial.from_qpolynomial(
                    isl.QPolynomial.zero_on_domain(fills.map_.domain().get_space())
                ),
                link_transfer=True,
            )

        # Gets the connectivity between points in space.
        spacetime_name = occs.map_.get_tuple_name(isl.dim_type.in_)
        connectivity: isl.Map = make_mesh_connectivity(len(spatial_dims), spacetime_name)
        # The mesh only covers [t, spatial...]; pad it with equal dims for the rest.
        padded_connectivity: isl.Map = insert_equal_dims_map(
            connectivity, 0, 0, n - len(spatial_dims) - 1
        )
        # Reorder so the spatial dims come last, apply the mesh, then undo the reorder.
        permutation: list[int] = make_connectivity_permutation(spatial_dims, n)
        reorder_map: isl.Map = reorder_projector(permutation, spacetime_name)
        complete_connectivity: isl.Map = reorder_map.apply_range(
            padded_connectivity
        ).apply_range(reorder_map.reverse())

        # Gets data available from neighbors at each point in space per time.
        available_from_neighbors: isl.Map = complete_connectivity.apply_range(occs.map_)
        # Prunes data that does not need to be fetched from a higher in the mem hierarchy.
        neighbor_filled: isl.Map = fills.map_.intersect(available_from_neighbors)

        return TransferInfo(
            fulfilled_fill=Transfers(fills.tags, neighbor_filled.coalesce()),
            unfulfilled_fill=Fill(
                fills.tags, fills.map_.subtract(neighbor_filled).coalesce()
            ),
            # Empty, since only p2p analyzed.
            parent_reads=Reads(
                fills.tags, neighbor_filled.subtract(neighbor_filled).coalesce()
            ),
            # One hop for every fulfilled (spacetime point, data) pair.
            hops=isl.PwQPolynomial.from_qpolynomial(
                isl.QPolynomial.one_on_domain(neighbor_filled.wrap().get_space())
            )
            .intersect_domain(neighbor_filled.wrap())
            .coalesce(),
            link_transfer=True,
        )
|
|
157
|
+
|
|
158
|
+
|
|
159
|
+
def make_mesh_connectivity(n: int, spacetime: str) -> isl.Map:
    """
    Build a neighbor-to-neighbor mesh connection for `n` spatial dims.

    Parameters
    ----------
    n:
        The number of spatial dimensions; only 1 and 2 are supported.
    spacetime:
        The name of the spacetime the mesh is operating on.

    Returns
    -------
    A direct orthogonal adjacency map on the space `spacetime[t, x_1, x_2, ..., x_n]`.
    Each point is related to its orthogonal neighbors at `t-1`.

    Raises
    ------
    ValueError
        If `n` is neither 1 nor 2.
    """
    if n == 2:
        relation = (
            "{ [t, x, y] -> [t-1, x', y'] : "
            " (y'=y and x'=x-1) or (y'=y and x'=x+1) "
            " or (x'=x and y'=y-1) or (x'=x and y'=y+1) }"
        )
    elif n == 1:
        relation = "{ [t, x] -> [t-1, x'] : (x'=x-1) or (x'=x+1) }"
    else:
        raise ValueError(f"Cannot make mesh with {n} spatial dims")

    unnamed: isl.Map = isl.Map.read_from_str(isl.DEFAULT_CONTEXT, relation)
    # Both the input and the output tuple live in the same named spacetime.
    named_in: isl.Map = unnamed.set_tuple_name(isl.dim_type.in_, spacetime)
    return named_in.set_tuple_name(isl.dim_type.out, spacetime)
|
|
196
|
+
|
|
197
|
+
|
|
198
|
+
def make_connectivity_permutation(spatial_idxs: list[int], dims: int) -> list[int]:
    """
    Build a dimension ordering that moves the spatial dims to the end.

    Parameters
    ----------
    spatial_idxs:
        Indices of the spatial dims. They are matched against `range(dims)`
        one at a time in the order given, so they are expected to be in
        ascending order.
    dims:
        Total number of dims.

    Returns
    -------
    The indices in `range(dims)` that were not matched as spatial, in
    ascending order, followed by every entry of `spatial_idxs` in its
    original order.
    """
    pending = iter(spatial_idxs)
    upcoming = next(pending, None)

    leading: list[int] = []
    for dim in range(dims):
        if dim == upcoming:
            # Matched the next expected spatial index: hold it back for the tail.
            upcoming = next(pending, None)
            continue
        leading.append(dim)

    return leading + list(spatial_idxs)
|
|
213
|
+
|
|
214
|
+
|
|
215
|
+
def get_spatial_tags_idxs(tags: list[Tag], buffer: MappingNode) -> list[int]:
    """
    Given a list of tags, identify the spatial dimensions that belong to `buffer`.

    Parameters
    ----------
    tags:
        The `Occupancy` or `Fill` domain dimension tags.
    buffer:
        The `MappingNode` which is the logical memory we're looking for
        spatial dims over.

    Returns
    -------
    The positions in `tags` that hold a `SpatialTag` whose `buffer` equals
    `buffer`, in ascending order.
    """
    matching_positions: list[int] = []
    for position, tag in enumerate(tags):
        if not isinstance(tag, SpatialTag):
            continue
        if tag.buffer == buffer:
            matching_positions.append(position)
    return matching_positions
|
|
238
|
+
|
|
239
|
+
|
|
240
|
+
def get_last_temporal_tag_idx(tags: list[Tag]) -> Optional[int]:
    """
    Find the index of the deepest (last) temporal tag in the list.

    Parameters
    ----------
    tags:
        A list of `Tags`.

    Returns
    -------
    The index of the last tag that is an instance of one of `TEMPORAL_TAGS`,
    or `None` when the list is empty or holds no such tag.
    """
    # Scan from the back so that the first hit is the last temporal tag.
    for position in range(len(tags) - 1, -1, -1):
        if isinstance(tags[position], TEMPORAL_TAGS):
            return position
    return None
|
|
@@ -0,0 +1,182 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Handles the ISL temporal reuse functions.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
from dataclasses import dataclass
|
|
6
|
+
|
|
7
|
+
import islpy as isl
|
|
8
|
+
|
|
9
|
+
from accelforge.model._looptree.reuse.isl.isl_functions import map_to_shifted
|
|
10
|
+
from accelforge.model._looptree.reuse.isl.mapping_to_isl.types import (
|
|
11
|
+
TEMPORAL_TAGS,
|
|
12
|
+
Fill,
|
|
13
|
+
Occupancy,
|
|
14
|
+
Tag,
|
|
15
|
+
)
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
@dataclass(frozen=True)
class TemporalReuse:
    """Immutable pair of results produced by a temporal reuse analysis."""

    # Occupancy reported by the analysis.
    # NOTE(review): the exact semantics are still undocumented upstream.
    effective_occupancy: Occupancy
    """TODO: Figure this out."""
    # Fill relation paired with the occupancy above.
    fill: Fill
    """Data deliveries to locations in spacetime that need to be made."""
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def analyze_temporal_reuse(
    occ: Occupancy, exploit_reuse: bool = True, multi_loop_reuse: bool = True
) -> TemporalReuse:
    """
    Compute the fill required to satisfy a buffer occupancy.

    When the buffer can `exploit_reuse`, the work is delegated to
    `fill_from_occupancy`. Otherwise the occupancy is returned unchanged and
    the fill is built directly from the occupancy's tags and map.

    Parameters
    ----------
    occ:
        The logical occupancy to be temporally analyzed.
    exploit_reuse:
        Temporally exploit reuse by persisting data currently in the buffer
        to the next time step.
    multi_loop_reuse:
        Forwarded to `fill_from_occupancy`; selects whether reuse may be
        carried across loop iterations. Ignored when `exploit_reuse` is false.

    Returns
    -------
    A `TemporalReuse` holding a `..types.Fill` (how data is loaded into the
    buffer across time) and a `..types.Occupancy` describing the
    effective_occupancy across time.

    TODO: Make sure spaces are named properly
    """
    if not exploit_reuse:
        # No reuse: the fill is built from the occupancy's own tags and map.
        return TemporalReuse(occ, Fill(occ.tags, occ.map_))

    return fill_from_occupancy(occ, multi_loop_reuse)
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
def fill_from_occupancy(
    occupancy: Occupancy, multiple_loop_reuse: bool
) -> TemporalReuse:
    """
    Given an occupancy and whether reuse across loops is allowed, calculate the
    `fill` and the `effective_occupancy` per time step.

    Temporal dimensions are visited deepest first. A temporal dimension on
    which the occupancy does not depend ("trivial") is projected out and its
    tag dropped. The first non-trivial temporal dimension determines the fill:
    everything occupied now that was not occupied at the preceding time step.

    Parameters
    ----------
    occupancy:
        The logical occupancy of data in logical buffers.
        NOTE(review): the input tuple of `occupancy.map_` is assumed to be
        named, since its name is used to label the derived maps.
    multiple_loop_reuse:
        If true, the preceding time step is found with `construct_time_shift`,
        which assumes the buffer is not flushed between loops. If false, the
        preceding time step is obtained by shifting only the analyzed
        dimension by -1.

    Returns
    -------
    A `TemporalReuse` object that contains the `fill` and `effective_occupancy`
    of the lowest buffer level. If every temporal dimension is trivial (or
    there is none), both are the coalesced, projected occupancy.
    """
    in_ = isl.dim_type.in_
    occ = occupancy.map_.copy()
    tags = occupancy.tags.copy()
    # Iterates through each dimension in reverse order (i.e., deepest loop
    # first) so that dropping a dimension never shifts an index yet to be seen.
    for dim_idx, tag in reversed(list(enumerate(occupancy.tags))):
        if not isinstance(tag, TEMPORAL_TAGS):
            continue

        # Check if temporal dimension is "trivial": projecting it out and
        # re-inserting it unconstrained must give back the same relation.
        in_name = occ.get_tuple_name(in_)
        proj_occ: isl.Map = occ.project_out(in_, dim_idx, 1).set_tuple_name(
            in_, f"{in_name}_abridged"
        )
        # The re-expanded map must live in exactly the same space as `occ`
        # (including the tuple name); otherwise the domain intersection and
        # the equality test below compare maps from different spaces.
        reinserted_occ: isl.Map = (
            proj_occ.insert_dims(in_, dim_idx, 1)
            .set_tuple_name(in_, in_name)
            .intersect_domain(occ.domain())
        )

        # `plain_is_equal` is tried first; `is_equal` is the fallback check.
        if occ.plain_is_equal(reinserted_occ) or occ.is_equal(reinserted_occ):
            occ = proj_occ
            tags.pop(dim_idx)
            continue

        # Nontrivial analysis
        time_shift: isl.Map
        if not multiple_loop_reuse:
            # TODO: Verify space names are preserved and/or replace.
            time_shift = map_to_shifted(occ.domain().get_space(), dim_idx, -1)
        else:
            # Calculates the time_shift assuming no cache flushing for loops.
            # TODO: this is a better way of getting time_shift. Use method to
            # replace the other branch (!multiple_loop_reuse)
            time_shift = construct_time_shift(occ, tags)

        # Gets the fill (i.e., feeds data not currently in buffer).
        occ_before: isl.Map = time_shift.apply_range(occ)
        fill: isl.Map = occ.subtract(occ_before)

        return TemporalReuse(Occupancy(tags, occ), Fill(tags, fill))

    # No non-trivial temporal dimension: the fill equals the occupancy.
    coalesced_occ: isl.Map = occ.coalesce()
    return TemporalReuse(Occupancy(tags, coalesced_occ), Fill(tags, coalesced_occ))
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
def construct_time_shift(occ: isl.Map, tags: list[Tag]):
    """
    Build the relation from each spacetime point of an occupancy to its direct
    predecessor in time at the same spatial coordinates.

    Parameters
    ----------
    occ:
        The occupancy map we're analyzing the reuse for.
    tags:
        The tags describing what each input dimension of `occ` represents.

    Returns
    -------
    time_shift:
        Relation of the current time step to the previous one across loops.
    """
    out = isl.dim_type.out
    spacetime: isl.Set = occ.domain()
    base_name = spacetime.get_tuple_name()

    # Both projections start out as the identity on the spacetime space.
    spacetime_to_time: isl.Map = isl.Map.identity(spacetime.get_space().map_from_set())
    spacetime_to_space: isl.Map = isl.Map.identity(spacetime.get_space().map_from_set())

    # Walk the dimensions back to front so earlier indices stay valid while
    # output dimensions are removed: temporal dims leave the space projection,
    # every other dim leaves the time projection.
    for idx in range(len(tags) - 1, -1, -1):
        if isinstance(tags[idx], TEMPORAL_TAGS):
            spacetime_to_space = spacetime_to_space.project_out(out, idx, 1)
        else:
            spacetime_to_time = spacetime_to_time.project_out(out, idx, 1)

    # Gets the names correct after transformations.
    spacetime_to_time = spacetime_to_time.set_tuple_name(out, f"{base_name}_time")
    spacetime_to_space = spacetime_to_space.set_tuple_name(out, f"{base_name}_space")

    # Restrict the time projection to the spacetime points that actually occur.
    spacetime_to_time = spacetime_to_time.intersect_domain(spacetime)
    time_: isl.Set = spacetime_to_time.range()

    # Every occurring time point related to all earlier occurring time points...
    lex_later_to_earlier: isl.Map = isl.Map.lex_gt(time_.get_space())
    time_to_past: isl.Map = lex_later_to_earlier.intersect_domain(
        time_
    ).intersect_range(time_)
    # ...then narrowed to only the most recent of those earlier points.
    time_to_most_recent_past = time_to_past.lexmax()

    # Relates the current spacetime to its direct predecessor in time.
    past_time_to_spacetime: isl.Map = time_to_most_recent_past.apply_range(
        spacetime_to_time.reverse()
    )
    time_shift: isl.Map = spacetime_to_time.apply_range(past_time_to_spacetime)

    # Pairs of spacetime points that share the same spatial coordinates.
    spacetime_space_preserver: isl.Map = spacetime_to_space.apply_range(
        spacetime_to_space.reverse()
    )

    # Space information is lost when spacetime is compressed to time_ and then
    # re-expanded, so keep only the pairs that stay at the same place.
    return time_shift.intersect(spacetime_space_preserver)
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
from .symbolic import *
|