accelforge-0.0.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- accelforge/__init__.py +21 -0
- accelforge/_accelerated_imports.py +16 -0
- accelforge/_deprecate/_simanneal/evalmapping.py +271 -0
- accelforge/_deprecate/_simanneal/mapspaceglobals.py +298 -0
- accelforge/_deprecate/_simanneal/simanneal.py +666 -0
- accelforge/_deprecate/_simanneal/tracking.py +105 -0
- accelforge/_deprecate/_simanneal/wrappers.py +218 -0
- accelforge/_deprecate/_simanneal2/__init__.py +7 -0
- accelforge/_deprecate/_simanneal2/simanneal.py +493 -0
- accelforge/_deprecate/_simanneal2/tracking.py +116 -0
- accelforge/_deprecate/compatibility_util.py +181 -0
- accelforge/_deprecate/layerdeduplication/__init__.py +2 -0
- accelforge/_deprecate/layerdeduplication/group_similar_einsums.py +160 -0
- accelforge/_deprecate/layerdeduplication/grouped_einsums.py +84 -0
- accelforge/_deprecate/mapping_filter_tags/__init__.py +2 -0
- accelforge/_deprecate/mapping_filter_tags/ffmt.py +212 -0
- accelforge/_deprecate/mapping_filter_tags/onesplit.py +24 -0
- accelforge/_deprecate/mapping_filter_tags/util.py +24 -0
- accelforge/_deprecate/tags.py +69 -0
- accelforge/_deprecate/viz/__init__.py +0 -0
- accelforge/_deprecate/viz/interactive.py +159 -0
- accelforge/_deprecate/viz/reservationtree.py +307 -0
- accelforge/_deprecate/viz/ski_slope.py +88 -0
- accelforge/_version.py +15 -0
- accelforge/examples.py +39 -0
- accelforge/frontend/__init__.py +10 -0
- accelforge/frontend/_binding.py +129 -0
- accelforge/frontend/_workload_isl/__init__.py +2 -0
- accelforge/frontend/_workload_isl/_isl.py +149 -0
- accelforge/frontend/_workload_isl/_symbolic.py +141 -0
- accelforge/frontend/arch copy.py +1544 -0
- accelforge/frontend/arch.py +1642 -0
- accelforge/frontend/config.py +63 -0
- accelforge/frontend/mapper/__init__.py +5 -0
- accelforge/frontend/mapper/ffm.py +126 -0
- accelforge/frontend/mapper/mapper.py +7 -0
- accelforge/frontend/mapper/metrics.py +30 -0
- accelforge/frontend/mapping/__init__.py +1 -0
- accelforge/frontend/mapping/mapping.py +1736 -0
- accelforge/frontend/model.py +14 -0
- accelforge/frontend/renames.py +150 -0
- accelforge/frontend/spec copy.py +230 -0
- accelforge/frontend/spec.py +301 -0
- accelforge/frontend/variables.py +12 -0
- accelforge/frontend/workload.py +952 -0
- accelforge/mapper/FFM/__init__.py +9 -0
- accelforge/mapper/FFM/_join_pmappings/__init__.py +0 -0
- accelforge/mapper/FFM/_join_pmappings/compatibility.py +653 -0
- accelforge/mapper/FFM/_join_pmappings/compress_pmappings.py +140 -0
- accelforge/mapper/FFM/_join_pmappings/join_pmappings.py +703 -0
- accelforge/mapper/FFM/_join_pmappings/pmapping_dataframe.py +901 -0
- accelforge/mapper/FFM/_join_pmappings/pmapping_group.py +337 -0
- accelforge/mapper/FFM/_make_pmappings/contraints/__init__.py +0 -0
- accelforge/mapper/FFM/_make_pmappings/contraints/constraints.py +360 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/__init__.py +1 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_loops.py +373 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_pmapping_templates.py +463 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_reservations.py +95 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_storage_order.py +382 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_storages.py +155 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmappings.py +411 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/__init__.py +1 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/make_pmappings_from_templates.py +407 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/make_tile_shapes.py +1681 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/run_model.py +170 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/symbol_relations.py +174 -0
- accelforge/mapper/FFM/_make_pmappings/pmapper_job.py +282 -0
- accelforge/mapper/FFM/_pareto_df/df_convention.py +273 -0
- accelforge/mapper/FFM/_pareto_df/pareto copy.py +836 -0
- accelforge/mapper/FFM/_pareto_df/pareto.py +508 -0
- accelforge/mapper/FFM/data.py +61 -0
- accelforge/mapper/FFM/main copy.py +236 -0
- accelforge/mapper/FFM/main.py +208 -0
- accelforge/mapper/FFM/mappings.py +510 -0
- accelforge/mapper/FFM/pmappings.py +310 -0
- accelforge/mapper/__init__.py +4 -0
- accelforge/mapper.py +0 -0
- accelforge/model/__init__.py +1 -0
- accelforge/model/_looptree/__init__.py +0 -0
- accelforge/model/_looptree/accesses.py +335 -0
- accelforge/model/_looptree/capacity/__init__.py +1 -0
- accelforge/model/_looptree/capacity/aggregators.py +36 -0
- accelforge/model/_looptree/capacity/capacity.py +47 -0
- accelforge/model/_looptree/energy.py +150 -0
- accelforge/model/_looptree/equivalent_ranks.py +29 -0
- accelforge/model/_looptree/latency/__init__.py +1 -0
- accelforge/model/_looptree/latency/latency.py +98 -0
- accelforge/model/_looptree/latency/memory.py +120 -0
- accelforge/model/_looptree/latency/processors.py +92 -0
- accelforge/model/_looptree/mapping_utilities.py +71 -0
- accelforge/model/_looptree/reuse/__init__.py +4 -0
- accelforge/model/_looptree/reuse/isl/__init__.py +1 -0
- accelforge/model/_looptree/reuse/isl/des.py +59 -0
- accelforge/model/_looptree/reuse/isl/isl_functions.py +374 -0
- accelforge/model/_looptree/reuse/isl/mapping_to_isl/__init__.py +4 -0
- accelforge/model/_looptree/reuse/isl/mapping_to_isl/analyze_mapping.py +297 -0
- accelforge/model/_looptree/reuse/isl/mapping_to_isl/skews_from_mapping.py +236 -0
- accelforge/model/_looptree/reuse/isl/mapping_to_isl/tiling.py +685 -0
- accelforge/model/_looptree/reuse/isl/mapping_to_isl/types.py +188 -0
- accelforge/model/_looptree/reuse/isl/spatial.py +260 -0
- accelforge/model/_looptree/reuse/isl/temporal.py +182 -0
- accelforge/model/_looptree/reuse/symbolic/__init__.py +1 -0
- accelforge/model/_looptree/reuse/symbolic/symbolic copy 2.py +1346 -0
- accelforge/model/_looptree/reuse/symbolic/symbolic copy.py +1408 -0
- accelforge/model/_looptree/reuse/symbolic/symbolic.py +1396 -0
- accelforge/model/_looptree/run.py +122 -0
- accelforge/model/_looptree/types.py +26 -0
- accelforge/model/_looptree/visualization/__init__.py +0 -0
- accelforge/model/_looptree/visualization/occupancy.py +11 -0
- accelforge/model/main.py +222 -0
- accelforge/plotting/__init__.py +2 -0
- accelforge/plotting/mappings.py +219 -0
- accelforge/plotting/specs.py +57 -0
- accelforge/util/__init__.py +4 -0
- accelforge/util/_base_analysis_types.py +24 -0
- accelforge/util/_basetypes.py +1089 -0
- accelforge/util/_frozenset.py +36 -0
- accelforge/util/_isl.py +29 -0
- accelforge/util/_itertools.py +14 -0
- accelforge/util/_mathfuncs.py +57 -0
- accelforge/util/_parse_expressions.py +339 -0
- accelforge/util/_picklecache.py +32 -0
- accelforge/util/_setexpressions.py +268 -0
- accelforge/util/_sympy/__init__.py +0 -0
- accelforge/util/_sympy/broadcast_max.py +18 -0
- accelforge/util/_visualization.py +112 -0
- accelforge/util/_yaml.py +579 -0
- accelforge/util/parallel.py +193 -0
- accelforge-0.0.1.dist-info/METADATA +64 -0
- accelforge-0.0.1.dist-info/RECORD +258 -0
- accelforge-0.0.1.dist-info/WHEEL +5 -0
- accelforge-0.0.1.dist-info/licenses/LICENSE +19 -0
- accelforge-0.0.1.dist-info/top_level.txt +5 -0
- docs/_build/html/_sources/fastfusion.frontend.mapper.rst.txt +37 -0
- docs/_build/html/_sources/fastfusion.frontend.rst.txt +70 -0
- docs/_build/html/_sources/fastfusion.frontend.workload.rst.txt +21 -0
- docs/_build/html/_sources/fastfusion.mapper.FFM.rst.txt +37 -0
- docs/_build/html/_sources/fastfusion.mapper.rst.txt +18 -0
- docs/_build/html/_sources/fastfusion.rst.txt +20 -0
- docs/_build/html/_sources/fastfusion.util.rst.txt +21 -0
- docs/_build/html/_sources/index.rst.txt +87 -0
- docs/_build/html/_sources/modules.rst.txt +7 -0
- docs/_build/html/_sources/notes/citation.rst.txt +45 -0
- docs/_build/html/_sources/notes/definitions.rst.txt +43 -0
- docs/_build/html/_sources/notes/faqs.rst.txt +39 -0
- docs/_build/html/_sources/notes/modeling/accelerator_energy_latency.rst.txt +72 -0
- docs/_build/html/_sources/notes/modeling/component_energy_area.rst.txt +96 -0
- docs/_build/html/_sources/notes/modeling/mapping.rst.txt +100 -0
- docs/_build/html/_sources/notes/modeling.rst.txt +33 -0
- docs/_build/html/_sources/notes/parsing/arithmetic_parsing.rst.txt +136 -0
- docs/_build/html/_sources/notes/parsing/setexpressions.rst.txt +63 -0
- docs/_build/html/_sources/notes/parsing/yaml_parsing.rst.txt +176 -0
- docs/_build/html/_sources/notes/quickstart_and_installation.rst.txt +9 -0
- docs/_build/html/_sources/notes/spec/architecture.rst.txt +133 -0
- docs/_build/html/_sources/notes/spec/mapping.rst.txt +12 -0
- docs/_build/html/_sources/notes/spec/workload.rst.txt +83 -0
- docs/_build/html/_sources/notes/spec.rst.txt +36 -0
- docs/source/_ext/include_attrs.py +213 -0
- docs/source/_ext/include_docstring.py +364 -0
- docs/source/_ext/include_functions.py +154 -0
- docs/source/_ext/include_notebook.py +131 -0
- docs/source/_ext/include_yaml.py +119 -0
- docs/source/_ext/inherited_attributes.py +222 -0
- docs/source/_ext/paths.py +4 -0
- docs/source/conf.py +79 -0
- examples/arches/compute_in_memory/_include.yaml +74 -0
- examples/arches/compute_in_memory/_include_functions.py +229 -0
- examples/arches/compute_in_memory/_load_spec.py +57 -0
- examples/arches/compute_in_memory/components/c2c_multiplier.py +181 -0
- examples/arches/compute_in_memory/components/dac_c2c_r2r.py +605 -0
- examples/arches/compute_in_memory/components/misc.py +195 -0
- examples/arches/compute_in_memory/components/util/bit_functions.py +51 -0
- examples/arches/compute_in_memory/components/zero_comparator.py +92 -0
- examples/arches/compute_in_memory/isaac.yaml +233 -0
- examples/arches/compute_in_memory/memory_cells/ecram_demo.yaml +63 -0
- examples/arches/compute_in_memory/memory_cells/rram_example.yaml +63 -0
- examples/arches/compute_in_memory/memory_cells/rram_isaac_isca_2016.yaml +64 -0
- examples/arches/compute_in_memory/memory_cells/rram_neurosim_default.yaml +63 -0
- examples/arches/compute_in_memory/memory_cells/rram_raella_isca_2023.yaml +70 -0
- examples/arches/compute_in_memory/memory_cells/rram_wan_nature_2022.yaml +63 -0
- examples/arches/compute_in_memory/memory_cells/sram_colonnade_jssc_2021.yaml +63 -0
- examples/arches/compute_in_memory/memory_cells/sram_example.yaml +63 -0
- examples/arches/compute_in_memory/memory_cells/sram_jia_jssc_2020.yaml +63 -0
- examples/arches/compute_in_memory/memory_cells/sram_sinangil_jssc_2021.yaml +63 -0
- examples/arches/compute_in_memory/memory_cells/sram_wang_vlsi_2022.yaml +63 -0
- examples/arches/compute_in_memory/wang_vlsi_2022.yaml +289 -0
- examples/arches/eyeriss.yaml +68 -0
- examples/arches/fanout_variations/at_glb.yaml +31 -0
- examples/arches/fanout_variations/at_glb_with_fanout_node.yaml +34 -0
- examples/arches/fanout_variations/at_mac.yaml +31 -0
- examples/arches/fanout_variations/at_mac_with_constraints.yaml +38 -0
- examples/arches/fanout_variations/at_mac_with_fanout_node.yaml +34 -0
- examples/arches/nvdla.yaml +47 -0
- examples/arches/simple.yaml +28 -0
- examples/arches/tpu_v4i.yaml +67 -0
- examples/mappings/unfused_matmuls_to_simple.yaml +33 -0
- examples/misc/component_annotated.yaml +33 -0
- examples/workloads/gpt3_6.7B.yaml +124 -0
- examples/workloads/matmuls.yaml +20 -0
- examples/workloads/mobilenet_28.yaml +81 -0
- examples/workloads/mobilenet_various_separate.yaml +106 -0
- examples/workloads/three_matmuls_annotated.yaml +59 -0
- notebooks/.ipynb_checkpoints/fastfusion_arch_study_michael-checkpoint.ipynb +359 -0
- notebooks/compute_in_memory/_scripts.py +339 -0
- notebooks/compute_in_memory/isaac.guide.ipynb +270 -0
- notebooks/compute_in_memory/wang_vlsi_2022.ipynb +602 -0
- notebooks/paths.py +4 -0
- notebooks/tutorials/.ipynb_checkpoints/1_FFM-checkpoint.ipynb +3110 -0
- notebooks/tutorials/FFM.ipynb +3498 -0
- notebooks/tutorials/_include.py +48 -0
- notebooks/tutorials/component_energy_area.ipynb +363 -0
- tests/Q_mapping.yaml +38 -0
- tests/__init__.py +0 -0
- tests/conv.mapping.yaml +27 -0
- tests/conv.workload.yaml +13 -0
- tests/conv_sym.mapping.yaml +43 -0
- tests/copy.mapping.yaml +35 -0
- tests/copy.workload.yaml +15 -0
- tests/distribuffers/__init__.py +0 -0
- tests/distribuffers/multicast/test_cases.yaml +482 -0
- tests/distribuffers/spec/binding/valid_bindings.yaml +97 -0
- tests/distribuffers/spec/distributed.yaml +100 -0
- tests/distribuffers/spec/logical_arch.yaml +32 -0
- tests/distribuffers/spec/physical_arch.yaml +69 -0
- tests/distribuffers/test_binding.py +48 -0
- tests/frontend/__init__.py +0 -0
- tests/frontend/test_mapping_viz.py +52 -0
- tests/mapper/__init__.py +0 -0
- tests/mapper/configs/conv1d/conv1d.mapping.yaml +31 -0
- tests/mapper/configs/conv1d/conv1d.workload.yaml +11 -0
- tests/mapper/configs/two_conv1d/two_conv1d.expected.yaml +38 -0
- tests/mapper/configs/two_conv1d/two_conv1d.mapping.yaml +54 -0
- tests/mapper/configs/two_conv1d/two_conv1d.workload.yaml +19 -0
- tests/mapper/test_mapping_to_isl.py +90 -0
- tests/mapper/test_spatial_reuse_analysis.py +67 -0
- tests/mapper/test_temporal_reuse_analysis.py +56 -0
- tests/mapper/util.py +58 -0
- tests/matmul.mapping.yaml +29 -0
- tests/matmul.workload.yaml +12 -0
- tests/matmul_spatial.mapping.yaml +44 -0
- tests/mha.renames.yaml +65 -0
- tests/mha.workload.yaml +67 -0
- tests/mha.yaml +59 -0
- tests/mha_full.workload.yaml +67 -0
- tests/mobilenet.workload.yaml +35 -0
- tests/mobilenet_long.workload.yaml +64 -0
- tests/pmappingcache.py +24 -0
- tests/processing_stage.arch.yaml +40 -0
- tests/snowcat.arch.yaml +36 -0
- tests/test_ffm_join_pmappings.py +106 -0
- tests/test_ffm_make_pmappings.py +82 -0
- tests/test_ffm_make_tile_shapes.py +49 -0
- tests/test_mapper.py +100 -0
- tests/test_model.py +37 -0
- tests/test_plotting.py +72 -0
- tests/test_processing_stage.py +46 -0
- tests/test_symbolic_model.py +248 -0
- tests/test_workload.py +141 -0
accelforge/frontend/workload.py
@@ -0,0 +1,952 @@
"""
All the objects used for a Workload description in AccelForge.
"""

from itertools import product
import itertools
import logging
import re
from typing import Annotated, Any, TypeAlias

import pydot

from accelforge.util.parallel import _SVGJupyterRender

from accelforge.util._basetypes import (
    ParsableDict,
    ParsableList,
    ParsableModel,
    ParsesTo,
)
from accelforge.util._visualization import _pydot_graph
from accelforge.frontend.renames import (
    EinsumName,
    RankVariable,
    Rename,
    RenameList,
    Renames,
    TensorName,
    Rank,
    rename_list_factory,
)
from accelforge.util._parse_expressions import ParseError, parse_expression
from accelforge.util._setexpressions import InvertibleSet, eval_set_expression
from accelforge._version import __version__


CLIST_OPERATORS = [
    "EQ",
    "NE",
    "LT",
    "GT",
    "LE",
    "GE",
    "NG",
    "NL",
    "AND",
    "OR",
]

_ISL_REGEX = re.compile(
    r"\b(?!(?:" + "|".join(CLIST_OPERATORS) + r")\b)[a-zA-Z#$@][a-zA-Z0-9_]*\b"
)
"""
_ISL_REGEX: A compiled regex pattern that matches words that are not exactly in
CLIST_OPERATORS (case-sensitive), start with a letter, `#`, `$`, or `@`, and
are followed by zero or more letters, digits, or underscores.
"""


def isl_expression_has_variable(expression: str, variable: RankVariable) -> bool:
    """
    Returns True if the given ISL expression has the given rank variable.

    Parameters
    ----------
    expression : str
        The ISL expression to check.
    variable : RankVariable
        The rank variable to check for.

    Returns
    -------
    bool
        True if the given ISL expression has the given rank variable.
    """
    return variable in re.findall(_ISL_REGEX, expression)


SymbolTable: TypeAlias = dict[str, InvertibleSet]


class TensorAccess(ParsableModel):
    """Information about how an Einsum accesses a tensor."""

    name: TensorName
    """ The name of the tensor. """

    projection: dict[str, str] | list[str]
    """
    How the rank variables of the Einsum project into the tensor. If this is a list,
    then it is assumed that each of the elements of the list is a single rank variable
    and they index into the tensor in ranks that equal the uppercase of the rank
    variable. For example:

        name: X, projection: [a, b, c] means X[A=a, B=b, C=c]

    If this is a dictionary, it is a mapping from rank names to rank variable
    expressions. This can be used to either project into a non-matching rank name or to
    project into a tensor using an expression. For example:

        name: X, projection: {A: a, B2: b, C: a+b} means X[A=a, B2=b, C=a+b]
    """

    output: bool = False
    """ Whether the tensor is an output. False means the tensor is an input. """

    persistent: bool = False
    """ If True, then a copy of this tensor must remain in backing storage for the full
    duration of the workload's execution. """

    backing_storage_size_scale: float = 1.0
    """ If != 1, then the backing storage size will be scaled by this factor. """

    bits_per_value: int | str | None = None
    """ Bits per value for this tensor. """

    def model_post_init(self, __context__=None) -> None:
        self.projection: ImpliedProjection = _projection_factory(self.projection)

    def _to_formatted_string(self) -> str:
        """Returns a string representation of the tensor access for Pydot nodes."""
        subscript = ",".join(self.projection.values())
        if isinstance(self.projection, ImpliedProjection):
            return f"{self.name}<sub>{subscript}</sub>"

        string = [self.name]
        for k, v in self.projection.items():
            if len(string) < len(self.projection):
                string.append(f"<sup>{k},</sup><sub>{v},</sub>")
            else:
                string.append(f"<sup>{k}</sup><sub>{v}</sub>")
        return "".join(string)

    @property
    def rank2rank_variables(self) -> dict[Rank, set[RankVariable]]:
        """
        Returns a dictionary of rank names to the rank variables that project into that
        rank.
        """
        return {
            Rank(rank): set(
                RankVariable(rank_var)
                for rank_var in re.findall(_ISL_REGEX, projection)
            )
            for rank, projection in self.projection.items()
        }

    @property
    def rank_variable2ranks(self) -> dict[RankVariable, set[Rank]]:
        """
        Returns a dictionary of rank variables to the ranks into which that rank
        variable projects.
        """
        result = {}
        for rank, projection in self.projection.items():
            for rank_var in re.findall(_ISL_REGEX, projection):
                rank_set: set = result.setdefault(rank_var, set())
                rank_set.add(rank)
        return result

    @property
    def ranks(self) -> tuple[Rank, ...]:
        """Returns the ranks of this access's tensor."""
        return tuple(Rank(x) for x in self.projection.keys())

    @property
    def rank_variables(self) -> set[RankVariable]:
        """Returns all rank variables used in this access."""
        # Projection values may be expressions, so we need to grab all identifiers
        return set(
            RankVariable(x)
            for x in re.findall(_ISL_REGEX, " ".join(self.projection.values()))
        )

    @property
    def directly_indexing_rank_variables(self) -> set[RankVariable]:
        """
        Returns the rank variables that directly index into this tensor without any
        expression (e.g., "M=m", NOT "M=m+n").
        """
        return set(
            RankVariable(x) for x in self.projection.values() if _ISL_REGEX.match(x)
        )

    @property
    def expression_indexing_rank_variables(self) -> set[RankVariable]:
        """
        Returns the rank variables that indirectly index into this tensor through an
        expression (e.g., "M=m+n") instead of a direct index (e.g., "M=m").
        """
        return self.rank_variables - self.directly_indexing_rank_variables


class ImpliedProjection(dict):
    """
    Holds a projection that has been implied by a list of rank variables. The implied
    rank names are uppercased versions of the rank variables; for example, [a, b, c] ->
    {A: a, B: b, C: c}.
    """


def _projection_factory(projection: dict | list):
    if isinstance(projection, list):
        for i, x in enumerate(projection):
            if not isinstance(x, str):
                raise TypeError(f"Element at index {i} must be a string, got {type(x)}")
            if not _ISL_REGEX.match(x):
                raise ValueError(
                    f"Element '{x}' at index {i} is not a valid ISL identifier. "
                    f"In a projection list, all elements must be valid ISL identifiers. "
                    f"For expressions, use a dictionary projection."
                )
        projection = ImpliedProjection({x.upper(): x for x in projection})
    elif not isinstance(projection, dict):
        raise TypeError(
            f"Invalid projection: {projection}. Must be a list of rank variables or a "
            f"dictionary of rank variable to projection."
        )
    for key in projection:
        if not isinstance(key, str):
            raise TypeError(f"Invalid projection key: {key}. Must be a string.")
        if not key.isidentifier():
            raise ValueError(
                f"Invalid projection key: {key}. Must be a valid identifier. Check with "
                f"the Python isidentifier() function."
            )
    return projection


class Shape(ParsableList):
    """
    Specifies valid values for the rank variables. This is a list of strings, each one
    an ISL expression. The total space is considered to be the logical AND of all the
    expressions in the list.
    """

    @property
    def rank_variables(self) -> set[str]:
        """Returns all rank variables used in this shape."""
        if not self:
            return set()
        return set.union(*[set(re.findall(_ISL_REGEX, x)) for x in self])


class Einsum(ParsableModel):
    """
    Represents an Einsum, which is a single computation step in the workload. The Einsum
    includes a set of rank variables, which are used to index into tensors. Rank
    variables iterate through an iteration space.

    For example, if the Einsum is A[m, n] += B[m, k] * C[k, n] and we define the
    iteration space as "0 <= m < 10, 0 <= n < 10, 0 <= k < 10", then the Einsum will
    iterate through all possible values of (m, n, k) in the iteration space, indexing
    into tensors for each and updating A[m, n] with B[m, k] * C[k, n].
    """

    name: EinsumName
    """ The name of the Einsum. """
    tensor_accesses: ParsableList[TensorAccess]
    """ The tensors accessed by this Einsum, and how they are accessed. """
    iteration_space_shape: Shape[str] = Shape()
    """
    Bounds of valid rank variable values. This is a list of expressions, each one an ISL
    expression. Additionally, global iteration_space_shape expressions are appended to
    the list if their rank variables are present in the Einsum's rank_variables. For
    example, if the global scope has "m: 0 <= m < 10" and the Einsum has "m" in its
    rank_variables, then "0 <= m < 10" will be appended to the iteration_space_shape.
    """
    rank_sizes: ParsableDict[Rank, int] = ParsableDict()
    """
    Sizes of ranks. This is a dictionary of rank names to sizes. Sizes are integers, and
    the rank's bounds are 0 <= rank < size. Accesses outside of these bounds are
    skipped.
    """
    is_copy_operation: bool = False
    """ Whether the Einsum is a copy operation. Copy operations take the input tensor
    and directly place it at the location of the output tensor(s) without any
    computation. If the destination tensor is at the same location, then this is a
    no-op. """
    renames: RenameList[Rename] = RenameList()
    """ Renames of the Einsum. Renames here can be used to rename rank variables or
    tensors. When this Einsum is executed on an architecture, the architecture can use
    renamed tensors and rank variables to access the tensors and rank variables. """
    n_instances: int = 1
    """
    Number of times to repeat the Einsum. Multiplied by `Workload.n_instances` to get
    the total number of Einsum instances. Energy, latency, and other summable metrics
    are multiplied by this value. Persistent reservations are also multiplied by this
    value, but non-persistent reservations are not, as they are assumed to be freed
    between each instance.
    """

    def model_post_init(self, __context__=None) -> None:
        if self.name == "Total":
            raise ValueError(
                'Einsum name "Total" is reserved for totaling across Einsums. '
                "Use a different name for the Einsum."
            )

    def __init__(self, *args, **kwargs):
        if "renames" in kwargs:
            kwargs["renames"] = rename_list_factory(kwargs["renames"])
        super().__init__(*args, **kwargs)

    @property
    def rank_variables(self) -> set[RankVariable]:
        """Returns all rank variables used in this Einsum."""
        if not self.tensor_accesses:
            return set()
        return set.union(*[t.rank_variables for t in self.tensor_accesses])

    @property
    def ranks(self) -> set[Rank]:
        """Returns all ranks used in this Einsum."""
        if not self.tensor_accesses:
            return set()
        return set.union(*[set(t.ranks) for t in self.tensor_accesses])

    @property
    def input_tensor_names(self) -> set[TensorName]:
        """Returns the names of the input tensors of this Einsum."""
        return set([TensorName(t.name) for t in self.tensor_accesses if not t.output])

    @property
    def output_tensor_names(self) -> set[TensorName]:
        """Returns the names of the output tensors of this Einsum."""
        return set([TensorName(t.name) for t in self.tensor_accesses if t.output])

    @property
    def tensor_names(self) -> set[TensorName]:
        """Returns the names of all tensors of this Einsum."""
        return set([TensorName(t.name) for t in self.tensor_accesses])

    @property
    def tensor2rank_variables(self) -> dict[TensorName, set[RankVariable]]:
        """Returns a dictionary of tensor names to the rank variables that project into
        that tensor."""
        return {TensorName(t.name): t.rank_variables for t in self.tensor_accesses}

    @property
    def tensor2directly_indexing_rank_variables(
        self,
    ) -> dict[TensorName, set[RankVariable]]:
        """
        Returns a dictionary of tensor names to the rank variables that directly index
        into that tensor. Direct indexing means that the rank variable is used as a
        direct index into the tensor, without any expression (e.g., "M=m", NOT "M=m+n").
        """
        return {
            TensorName(t.name): t.directly_indexing_rank_variables
            for t in self.tensor_accesses
        }

    @property
    def tensor2expression_indexing_rank_variables(
        self,
    ) -> dict[TensorName, set[RankVariable]]:
        """
        Returns a dictionary of tensor names to the rank variables that indirectly index
        into that tensor through an expression (e.g., "M=m+n") instead of a direct index
        (e.g., "M=m").
        """
        fully_relevant_rank_vars = self.tensor2directly_indexing_rank_variables
        return {
            TensorName(t.name): t.rank_variables - fully_relevant_rank_vars[t.name]
            for t in self.tensor_accesses
        }

    @property
    def tensor2irrelevant_rank_variables(
        self,
    ) -> dict[TensorName, set[RankVariable]]:
        """
        Returns a dictionary of tensor names to the rank variables that are irrelevant
        to that tensor. Irrelevant rank variables are rank variables that are not used
        to index into the tensor.
        """
        partially_relevant = self.tensor2expression_indexing_rank_variables
        fully_relevant = self.tensor2directly_indexing_rank_variables
        rank_variables = self.rank_variables
        return {
            TensorName(t.name): rank_variables
            - fully_relevant[t.name]
            - partially_relevant[t.name]
            for t in self.tensor_accesses
        }

    def _to_formatted_string(self, compress: bool = False) -> str:
        """
        Returns a string representation of this Einsum for use in a Pydot graph.

        Parameters
        ----------
        compress : bool, optional
            If True, the string will be compressed to a single line.

        Returns
        -------
        str
            A string representation of this Einsum for use in a Pydot graph.
        """
        lhs_join = ",\n" if compress else " , "
        rhs_join = " \n " if compress else " "
        lhs = lhs_join.join(
            [t._to_formatted_string() for t in self.tensor_accesses if t.output]
        )
        rhs = rhs_join.join(
            [t._to_formatted_string() for t in self.tensor_accesses if not t.output]
        )
        return f"{lhs}=\n{rhs}" if compress else f"{lhs} = {rhs}"

    def copy_source_tensor(self) -> TensorName | None:
        """
        If this Einsum is a copy operation, returns the name of the tensor that is the
        source of the copy. Otherwise, returns None.
        """
        if not self.is_copy_operation:
            return None
        input_tensors = self.input_tensor_names
        if len(input_tensors) != 1:
            raise ValueError(
                f"Copy Einsum {self.name} has {len(input_tensors)} input tensors, "
                f"expected 1"
            )
        return input_tensors.pop()

    @property
    def rank_variable2ranks(self) -> dict[RankVariable, set[Rank]]:
        """
        Returns a dictionary of rank variables to the ranks that are indexed into by
        that rank variable.
        """
        result: dict[RankVariable, set[Rank]] = {}
        for tensor_access in self.tensor_accesses:
            new = tensor_access.rank_variable2ranks
            for rank_var, ranks in new.items():
                result.setdefault(rank_var, set()).update(ranks)
        return result

    @property
    def indexing_expressions(self) -> set[str]:
        """
        Returns a set of all the expressions that index into the tensors of this
        Einsum.
        """
        result = set()
        for tensor_access in self.tensor_accesses:
            for _, projection in tensor_access.projection.items():
                result.add(projection)
        return result

    def _parse_expressions(self, symbol_table: dict[str, Any], *args, **kwargs):
        workload: Workload = symbol_table["spec_workload"]
        renames: Renames = symbol_table["spec_renames"]

        # Put together renames symbol table
        inputs = self.input_tensor_names
        outputs = self.output_tensor_names
        all_ = inputs | outputs
        persistent = {t.name for t in self.tensor_accesses if t.persistent}
        element_to_child_space = {}
        all_rank_variables = self.rank_variables
        for tensor in self.tensor_names:
            element_to_child_space[tensor] = InvertibleSet(
                instance=self.tensor2rank_variables[tensor],
                full_space=all_rank_variables,
                space_type=RankVariable,
            )

        intermediates = {
            t
            for t in all_
            if workload.einsums_with_tensor_as_input(t)
            and workload.einsums_with_tensor_as_output(t)
        }
        shared = {
            t
            for t in all_
            if len(
                set(e.name for e in workload.einsums_with_tensor_as_input(t))
                | set(e.name for e in workload.einsums_with_tensor_as_output(t))
            )
            > 1
        }

        kwargs_tensors = dict(
            full_space=all_,
            space_type=TensorName,
            child_access_name="rank_variables",
            element_to_child_space=element_to_child_space,
        )
        kwargs_rank_variables = dict(
            full_space=all_rank_variables,
            space_type=RankVariable,
        )
        rename_symbol_table = {
            "All": InvertibleSet(instance=all_, **kwargs_tensors),
            "Tensors": InvertibleSet(instance=all_, **kwargs_tensors),
            "Nothing": InvertibleSet(instance=(), **kwargs_tensors),
            "Inputs": InvertibleSet(instance=inputs, **kwargs_tensors),
            "Outputs": InvertibleSet(instance=outputs, **kwargs_tensors),
            "Intermediates": InvertibleSet(instance=intermediates, **kwargs_tensors),
            "Shared": InvertibleSet(instance=shared, **kwargs_tensors),
            "Persistent": InvertibleSet(instance=persistent, **kwargs_tensors),
            **{t: InvertibleSet(instance=(t,), **kwargs_tensors) for t in all_},
            **{
                r: InvertibleSet(instance=(r,), **kwargs_rank_variables)
                for r in all_rank_variables
            },
            "Einsum": self.name,
            "Above": InvertibleSet(instance=(), **kwargs_tensors),
        }

        for t in workload.tensor_names:
            if t not in rename_symbol_table:
                rename_symbol_table[t] = InvertibleSet(instance=(), **kwargs_tensors)

        for r in workload.rank_variables:
            if r not in rename_symbol_table:
                rename_symbol_table[r] = InvertibleSet(
                    instance=(), **kwargs_rank_variables
                )

        st = {**rename_symbol_table, **symbol_table}

        self: Einsum = self.model_copy()
        self.renames = RenameList(self.renames)

        # Grab the default renames and update the renames with more values
        default_renames = renames.get_renames_for_einsum("default")
        for tensor_rename in default_renames.tensor_accesses:
            if tensor_rename.name not in self.renames:
                self.renames.append(tensor_rename)
        for rank_variable_rename in default_renames.rank_variables:
            if rank_variable_rename.name not in self.renames:
                self.renames.append(rank_variable_rename)

        # Parse me!
        kwargs["must_parse_try_parse_to"] = True
        parsed, _ = super(self.__class__, self)._parse_expressions(st, *args, **kwargs)

        # Update the renames with the new values
        for k, v in rename_symbol_table.items():
            if k not in parsed.renames:
                parsed.renames.append(Rename(name=k, source=v))

        # Parse the bits per value
        bits_per_value = dict()
        bpv_to_source = dict()
        for k, v in symbol_table["workload_bits_per_value"].items():
            bpv = eval_set_expression(
                expression=k,
                symbol_table=st,
                expected_space=TensorName,
                location=f"(workload global bits_per_value)[{k}]",
            )
            for t in bpv:
                if t in bits_per_value:
                    raise ParseError(
                        f"Tensor {t} is specified in multiple entries in the workload "
                        f"global bits_per_value dictionary.",
                        source_field=f"({k} AND {bpv_to_source[t]})",
                    )
                bits_per_value[t] = v
                bpv_to_source[t] = k

        for t in parsed.tensor_accesses:
            if t.bits_per_value is None and t.name not in bits_per_value:
                raise ParseError(
                    f"Tensor {t.name} in Einsum does not have a bits per value "
                    f"specified. Ensure that the tensor is either covered by the set "
                    f"expressions in the workload.bits_per_value dictionary "
                    f"or bits_per_value is specified for the tensor access.",
                    source_field=f"tensor_accesses[{t.name}].bits_per_value",
                )
            if t.bits_per_value is None:
                t.bits_per_value = bits_per_value[t.name]

        return parsed, symbol_table


class Workload(ParsableModel):
    """
    The workload specification as a cascade of Einsums, with each Einsum being a
    computation step in the workload.
    """

    # version: Annotated[str, assert_version] = __version__
    # """ The version of the workload specification. """

    einsums: ParsableList[Einsum] = ParsableList()
    """ The Einsums in the workload. """

    iteration_space_shape: ParsableDict[RankVariable, str] = ParsableDict()
    """
    Bounds of valid rank variable values. This is a dictionary of rank variable
    names to bounds of valid rank variable values. The bounds are specified as a string
    in the ISL format. For example, "0 <= a < 10" means that the rank variable `a` must
    be between 0 and 10, including 0 but not 10. Bounds are included for all Einsums
    that include that rank variable.
    """

    rank_sizes: ParsableDict[Rank, ParsesTo[int]] = ParsableDict()
    """
    Rank sizes. This is a dictionary of rank names to sizes. Sizes are integers, and the
    rank's bounds are 0 <= rank < size. Accesses outside of these bounds are skipped.
    """

    n_instances: int = 1
    """
    Number of times to repeat the workload. Multiplied by `Einsum.n_instances` to get
    the total number of Einsum instances. Energy, latency, and other summable metrics
    are multiplied by this value. Persistent reservations are also multiplied by this
    value, but non-persistent reservations are not, as they are assumed to be freed
    between each instance.
    """

    bits_per_value: ParsableDict[str, int | str] = ParsableDict()
    """
    Bits per value for each tensor. The workload-level bits_per_value is overridden if
    bits_per_value is specified for any given tensor access. This is a dictionary of
    set expressions to bits per value for the tensors given by those expressions. For
    example, we may write "Inputs: 8" to set the bits per value to 8 for all input
    tensors, unless overridden.
    """

    def model_post_init(self, __context__=None) -> None:
        self._validate()

    def _validate(self):
        tensor2ranks = {}
        einsum_names = set()
        for einsum in self.einsums:
            if einsum.name in einsum_names:
                raise ValueError(f"Einsum name {einsum.name} is not unique")
            einsum_names.add(einsum.name)
            for tensor_accesses in einsum.tensor_accesses:
                tensor2ranks.setdefault(tensor_accesses.name, tensor_accesses.ranks)
                if tensor2ranks[tensor_accesses.name] != tensor_accesses.ranks:
                    raise ValueError(
                        f"TensorName {tensor_accesses.name} has inconsistent ranks. Found "
                        f"{tensor2ranks[tensor_accesses.name]} and {tensor_accesses.ranks}. "
                        "TensorName is in Einsums "
                        f"{', '.join(
                            e.name for e in self.einsums_with_tensor(tensor_accesses.name)
                        )}"
                    )

    @property
    def einsum_names(self) -> list[EinsumName]:
        """Returns the names of the Einsums in the workload."""
        return [EinsumName(e.name) for e in self.einsums]

    def einsums_with_tensor(self, tensor: TensorName) -> list["Einsum"]:
        """
        Returns the Einsums in the workload that access the given tensor.

        Parameters
        ----------
        tensor : TensorName
            The tensor to check.

        Returns
        -------
        list[Einsum]
            The Einsums in the workload that access the given tensor. Order is the same
            as the order in this workload's Einsums list.
        """
        return [e for e in self.einsums if tensor in e.tensor_names]

    def einsums_with_tensor_as_input(self, tensor: TensorName) -> list["Einsum"]:
        """
        Returns the Einsums in the workload that use the given tensor as an input.

        Parameters
        ----------
        tensor : TensorName
            The tensor to check.

        Returns
        -------
        list[Einsum]
            The Einsums in the workload that use the given tensor as an input. Order is
            the same as the order in this workload's Einsums list.
        """
        return [e for e in self.einsums if tensor in e.input_tensor_names]

    def einsums_with_tensor_as_output(self, tensor: TensorName) -> list["Einsum"]:
        """
        Returns the Einsums in the workload that have the given tensor as an output.

        Parameters
        ----------
        tensor : TensorName
            The tensor to check.

        Returns
        -------
        list[Einsum]
            The Einsums in the workload that have the given tensor as an output. Order
            is the same as the order in this workload's Einsums list.
        """
        return [e for e in self.einsums if tensor in e.output_tensor_names]

    def accesses_for_tensor(self, tensor: TensorName) -> list[TensorAccess]:
        """
        Returns all TensorAccess objects that access the given tensor across all
        Einsums.

        Parameters
        ----------
        tensor : TensorName
            The tensor to check.

        Returns
        -------
        list[TensorAccess]
            The TensorAccess objects that access the given tensor across all Einsums.
            Order is the same as the order in this workload's Einsums list.
        """
        return [t for e in self.einsums for t in e.tensor_accesses if t.name == tensor]

    def get_iteration_space_shape_isl_string(self, einsum_name: str) -> str:
        """
        Returns the ISL string representing the iteration space of the given Einsum.

        Parameters
        ----------
        einsum_name : str
            The name of the Einsum for which to get the iteration space shape.

        Returns
        -------
        str
            The ISL string representing the iteration space shape of the given Einsum.
        """
        einsum = self.einsums[einsum_name]
        einsum_shape = einsum.iteration_space_shape
        my_ispace = self.iteration_space_shape
        global_shape = [my_ispace[r] for r in einsum.rank_variables if r in my_ispace]
        rank_sizes = einsum.rank_sizes
        global_rank_sizes = {
            r: self.rank_sizes[r] for r in einsum.ranks if r in self.rank_sizes
        }

        exprs = einsum_shape + global_shape
        for tensor in einsum.tensor_accesses:
            for rank, projection in tensor.projection.items():
                if rank in rank_sizes:
                    exprs.append(f"0 <= {projection} < {rank_sizes[rank]}")
                elif rank in global_rank_sizes:
                    exprs.append(f"0 <= {projection} < {global_rank_sizes[rank]}")

        return " and ".join(exprs)

    def _check_consistent_persistent(self):
        for tensor in self.tensor_names:
            persistents = {
                e.tensor_accesses[tensor].persistent
                for e in self.einsums_with_tensor(tensor)
            }
            if len(persistents) > 1:
                raise ValueError(
                    f"Tensor {tensor} is used in multiple Einsums with different "
                    f"persistent values. Persistent values must be consistent across "
                    f"all Einsums that use the tensor."
                )

    @property
    def tensor_names_used_in_multiple_einsums(self) -> set[TensorName]:
        """Returns the names of the tensors that are used in multiple Einsums."""
        return {t for t in self.tensor_names if len(self.einsums_with_tensor(t)) > 1}

    @property
    def tensor_names(self) -> set[TensorName]:
        """Returns the names of all tensors in the workload."""
        return {TensorName(t.name) for e in self.einsums for t in e.tensor_accesses}

    @property
    def rank_variables(self) -> set[RankVariable]:
        """Returns the names of all rank variables in the workload."""
        return {RankVariable(r) for e in self.einsums for r in e.rank_variables}

    def _repr_svg_(self) -> str:
        return self.render()

    def render(self) -> str:
        """Renders the workload as a Pydot graph. Returns an SVG string."""
        graph = _pydot_graph()

        # Add Einsums as box nodes and tensors as oval nodes
        tensors = []
        seen_tensor_names = set()
        for einsum in self.einsums:
            node = pydot.Node(
                f"Einsum_{einsum.name}",
                shape="box",
                label=f"<{einsum._to_formatted_string(compress=True)}>",
            )
            graph.add_node(node)
            for tensor_access in einsum.tensor_accesses:
                if tensor_access.name not in seen_tensor_names:
                    tensors.append(tensor_access.name)
                    seen_tensor_names.add(tensor_access.name)
                    node = pydot.Node(
                        f"Tensor_{tensor_access.name}",
                        shape="oval",
                        label=f"<{tensor_access._to_formatted_string()}>",
                    )
                    graph.add_node(node)

        # Add edges between tensors and Einsums
        for einsum in self.einsums:
            for tensor_access in einsum.tensor_accesses:
                if tensor_access.output:
                    # Output tensor: einsum -> tensor
                    edge = pydot.Edge(
                        f"Einsum_{einsum.name}", f"Tensor_{tensor_access.name}"
                    )
                    graph.add_edge(edge)
                else:
                    # Input tensor: tensor -> einsum
                    edge = pydot.Edge(
                        f"Tensor_{tensor_access.name}", f"Einsum_{einsum.name}"
                    )
                    graph.add_edge(edge)
        return _SVGJupyterRender(graph.create_svg(prog="dot").decode("utf-8"))

    def _parse_expressions(
        self, symbol_table: dict[str, Any], *args, renames: Renames, **kwargs
    ):
        bpv, _ = self.bits_per_value._parse_expressions(symbol_table, *args, **kwargs)
        new_st = {
            **symbol_table,
            "spec_workload": self,
            "spec_renames": renames,
            "workload_bits_per_value": bpv,
        }
        parsed, new_st = super()._parse_expressions(new_st, *args, **kwargs)

        # Ensure bits_per_value is consistent across Einsums
        bits_per_value_per_einsum = {}
        bits_per_value = {}
        for einsum in parsed.einsums:
            cur_bpv = {t.name: t.bits_per_value for t in einsum.tensor_accesses}
            # Check for consistency across Einsums
            for prev_einsum, prev_bpv in bits_per_value_per_einsum.items():
                shared_keys = set(cur_bpv.keys()) & set(prev_bpv.keys())
                for t in shared_keys:
                    b0 = cur_bpv[t]
                    b1 = prev_bpv[t]
                    if b0 != b1:
                        raise ValueError(
                            f"Tensor {t} has bits per value {b0} in Einsum {einsum.name} "
                            f"and {b1} in Einsum {prev_einsum}. Bits per value must be "
                            "consistent across all Einsums that access a tensor."
                        )
            bits_per_value_per_einsum[einsum.name] = cur_bpv
            bits_per_value.update(cur_bpv)

        for einsum in parsed.einsums:
            for t, bpv in bits_per_value.items():
                einsum.renames[t].source.bits_per_value = bpv

            for r in einsum.renames:
                src: InvertibleSet = r.source
                if (
                    isinstance(src, InvertibleSet)
                    and len(src) == 1
                    and src.space_type == TensorName
                    and next(iter(src)) in bits_per_value
                ):
                    src.bits_per_value = bits_per_value[next(iter(src))]

        parsed._check_consistent_persistent()

        return parsed, symbol_table

    def _get_ranks_that_share_indexing_rank_variables(self) -> dict[Rank, set[Rank]]:
        """
        Returns a dictionary of ranks to the ranks with which they share indexing rank
        variables. For example, if one einsum indexes into rank A with rank variable a
        and another einsum indexes into rank B with rank variable a, then A and B share
        the indexing rank variable a. Then we'd have in our return value both A: {A, B}
        and B: {A, B}. This is transitive and reflexive.

        Returns
        -------
        dict[Rank, set[Rank]]
            A dictionary of ranks to the ranks with which they share indexing rank
            variables. The ranks are the keys, and the values are sets of ranks that
            share indexing rank variables with the key.
        """
        rank2rankvars = {}
        for tensor in self.tensor_names:
            for acc in self.accesses_for_tensor(tensor):
                for rank, rank_vars in acc.rank2rank_variables.items():
                    rank2rankvars.setdefault(rank, set()).update(rank_vars)

        rank_var_to_ranks = {}
        for rank, rank_vars in rank2rankvars.items():
            for rank_var in rank_vars:
                rank_var_to_ranks.setdefault(rank_var, set()).add(rank)

        rank_to_ranks = {r: set((r,)) for r in rank2rankvars.keys()}
        update_with = list(rank_var_to_ranks.values())
        changed = True
        while changed:
            changed = False
            for ranks in rank_to_ranks.values():
                for u in update_with:
                    if u & ranks:
                        changed = changed or (u - ranks)
                        ranks.update(u)

        return rank_to_ranks

    def get_tensor_copies(self) -> dict[TensorName, set[TensorName]]:
        """
        Returns a dictionary specifying which tensors are copies of which other tensors.
        For example, if einsum A copies tensor X into tensors Y and Z, then we'd have in
        the return value X: {Y, Z}, Y: {X, Z}, and Z: {X, Y}. This is transitive.

        Returns
        -------
        dict[TensorName, set[TensorName]]
            A dictionary specifying which tensors are copies of which other tensors. The
            keys are the tensors that are copies, and the values are sets of tensors
            that are copies of the key.
        """
        tensor_copies = {}
        for einsum in self.einsums:
            if not einsum.is_copy_operation:
                continue
            input_tensor = einsum.copy_source_tensor()
            for output_tensor in einsum.output_tensor_names:
                tensor_copies.setdefault(input_tensor, set()).add(output_tensor)
                tensor_copies.setdefault(output_tensor, set()).add(input_tensor)
        return tensor_copies