accelforge-0.0.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (258)
  1. accelforge/__init__.py +21 -0
  2. accelforge/_accelerated_imports.py +16 -0
  3. accelforge/_deprecate/_simanneal/evalmapping.py +271 -0
  4. accelforge/_deprecate/_simanneal/mapspaceglobals.py +298 -0
  5. accelforge/_deprecate/_simanneal/simanneal.py +666 -0
  6. accelforge/_deprecate/_simanneal/tracking.py +105 -0
  7. accelforge/_deprecate/_simanneal/wrappers.py +218 -0
  8. accelforge/_deprecate/_simanneal2/__init__.py +7 -0
  9. accelforge/_deprecate/_simanneal2/simanneal.py +493 -0
  10. accelforge/_deprecate/_simanneal2/tracking.py +116 -0
  11. accelforge/_deprecate/compatibility_util.py +181 -0
  12. accelforge/_deprecate/layerdeduplication/__init__.py +2 -0
  13. accelforge/_deprecate/layerdeduplication/group_similar_einsums.py +160 -0
  14. accelforge/_deprecate/layerdeduplication/grouped_einsums.py +84 -0
  15. accelforge/_deprecate/mapping_filter_tags/__init__.py +2 -0
  16. accelforge/_deprecate/mapping_filter_tags/ffmt.py +212 -0
  17. accelforge/_deprecate/mapping_filter_tags/onesplit.py +24 -0
  18. accelforge/_deprecate/mapping_filter_tags/util.py +24 -0
  19. accelforge/_deprecate/tags.py +69 -0
  20. accelforge/_deprecate/viz/__init__.py +0 -0
  21. accelforge/_deprecate/viz/interactive.py +159 -0
  22. accelforge/_deprecate/viz/reservationtree.py +307 -0
  23. accelforge/_deprecate/viz/ski_slope.py +88 -0
  24. accelforge/_version.py +15 -0
  25. accelforge/examples.py +39 -0
  26. accelforge/frontend/__init__.py +10 -0
  27. accelforge/frontend/_binding.py +129 -0
  28. accelforge/frontend/_workload_isl/__init__.py +2 -0
  29. accelforge/frontend/_workload_isl/_isl.py +149 -0
  30. accelforge/frontend/_workload_isl/_symbolic.py +141 -0
  31. accelforge/frontend/arch copy.py +1544 -0
  32. accelforge/frontend/arch.py +1642 -0
  33. accelforge/frontend/config.py +63 -0
  34. accelforge/frontend/mapper/__init__.py +5 -0
  35. accelforge/frontend/mapper/ffm.py +126 -0
  36. accelforge/frontend/mapper/mapper.py +7 -0
  37. accelforge/frontend/mapper/metrics.py +30 -0
  38. accelforge/frontend/mapping/__init__.py +1 -0
  39. accelforge/frontend/mapping/mapping.py +1736 -0
  40. accelforge/frontend/model.py +14 -0
  41. accelforge/frontend/renames.py +150 -0
  42. accelforge/frontend/spec copy.py +230 -0
  43. accelforge/frontend/spec.py +301 -0
  44. accelforge/frontend/variables.py +12 -0
  45. accelforge/frontend/workload.py +952 -0
  46. accelforge/mapper/FFM/__init__.py +9 -0
  47. accelforge/mapper/FFM/_join_pmappings/__init__.py +0 -0
  48. accelforge/mapper/FFM/_join_pmappings/compatibility.py +653 -0
  49. accelforge/mapper/FFM/_join_pmappings/compress_pmappings.py +140 -0
  50. accelforge/mapper/FFM/_join_pmappings/join_pmappings.py +703 -0
  51. accelforge/mapper/FFM/_join_pmappings/pmapping_dataframe.py +901 -0
  52. accelforge/mapper/FFM/_join_pmappings/pmapping_group.py +337 -0
  53. accelforge/mapper/FFM/_make_pmappings/contraints/__init__.py +0 -0
  54. accelforge/mapper/FFM/_make_pmappings/contraints/constraints.py +360 -0
  55. accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/__init__.py +1 -0
  56. accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_loops.py +373 -0
  57. accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_pmapping_templates.py +463 -0
  58. accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_reservations.py +95 -0
  59. accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_storage_order.py +382 -0
  60. accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_storages.py +155 -0
  61. accelforge/mapper/FFM/_make_pmappings/make_pmappings.py +411 -0
  62. accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/__init__.py +1 -0
  63. accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/make_pmappings_from_templates.py +407 -0
  64. accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/make_tile_shapes.py +1681 -0
  65. accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/run_model.py +170 -0
  66. accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/symbol_relations.py +174 -0
  67. accelforge/mapper/FFM/_make_pmappings/pmapper_job.py +282 -0
  68. accelforge/mapper/FFM/_pareto_df/df_convention.py +273 -0
  69. accelforge/mapper/FFM/_pareto_df/pareto copy.py +836 -0
  70. accelforge/mapper/FFM/_pareto_df/pareto.py +508 -0
  71. accelforge/mapper/FFM/data.py +61 -0
  72. accelforge/mapper/FFM/main copy.py +236 -0
  73. accelforge/mapper/FFM/main.py +208 -0
  74. accelforge/mapper/FFM/mappings.py +510 -0
  75. accelforge/mapper/FFM/pmappings.py +310 -0
  76. accelforge/mapper/__init__.py +4 -0
  77. accelforge/mapper.py +0 -0
  78. accelforge/model/__init__.py +1 -0
  79. accelforge/model/_looptree/__init__.py +0 -0
  80. accelforge/model/_looptree/accesses.py +335 -0
  81. accelforge/model/_looptree/capacity/__init__.py +1 -0
  82. accelforge/model/_looptree/capacity/aggregators.py +36 -0
  83. accelforge/model/_looptree/capacity/capacity.py +47 -0
  84. accelforge/model/_looptree/energy.py +150 -0
  85. accelforge/model/_looptree/equivalent_ranks.py +29 -0
  86. accelforge/model/_looptree/latency/__init__.py +1 -0
  87. accelforge/model/_looptree/latency/latency.py +98 -0
  88. accelforge/model/_looptree/latency/memory.py +120 -0
  89. accelforge/model/_looptree/latency/processors.py +92 -0
  90. accelforge/model/_looptree/mapping_utilities.py +71 -0
  91. accelforge/model/_looptree/reuse/__init__.py +4 -0
  92. accelforge/model/_looptree/reuse/isl/__init__.py +1 -0
  93. accelforge/model/_looptree/reuse/isl/des.py +59 -0
  94. accelforge/model/_looptree/reuse/isl/isl_functions.py +374 -0
  95. accelforge/model/_looptree/reuse/isl/mapping_to_isl/__init__.py +4 -0
  96. accelforge/model/_looptree/reuse/isl/mapping_to_isl/analyze_mapping.py +297 -0
  97. accelforge/model/_looptree/reuse/isl/mapping_to_isl/skews_from_mapping.py +236 -0
  98. accelforge/model/_looptree/reuse/isl/mapping_to_isl/tiling.py +685 -0
  99. accelforge/model/_looptree/reuse/isl/mapping_to_isl/types.py +188 -0
  100. accelforge/model/_looptree/reuse/isl/spatial.py +260 -0
  101. accelforge/model/_looptree/reuse/isl/temporal.py +182 -0
  102. accelforge/model/_looptree/reuse/symbolic/__init__.py +1 -0
  103. accelforge/model/_looptree/reuse/symbolic/symbolic copy 2.py +1346 -0
  104. accelforge/model/_looptree/reuse/symbolic/symbolic copy.py +1408 -0
  105. accelforge/model/_looptree/reuse/symbolic/symbolic.py +1396 -0
  106. accelforge/model/_looptree/run.py +122 -0
  107. accelforge/model/_looptree/types.py +26 -0
  108. accelforge/model/_looptree/visualization/__init__.py +0 -0
  109. accelforge/model/_looptree/visualization/occupancy.py +11 -0
  110. accelforge/model/main.py +222 -0
  111. accelforge/plotting/__init__.py +2 -0
  112. accelforge/plotting/mappings.py +219 -0
  113. accelforge/plotting/specs.py +57 -0
  114. accelforge/util/__init__.py +4 -0
  115. accelforge/util/_base_analysis_types.py +24 -0
  116. accelforge/util/_basetypes.py +1089 -0
  117. accelforge/util/_frozenset.py +36 -0
  118. accelforge/util/_isl.py +29 -0
  119. accelforge/util/_itertools.py +14 -0
  120. accelforge/util/_mathfuncs.py +57 -0
  121. accelforge/util/_parse_expressions.py +339 -0
  122. accelforge/util/_picklecache.py +32 -0
  123. accelforge/util/_setexpressions.py +268 -0
  124. accelforge/util/_sympy/__init__.py +0 -0
  125. accelforge/util/_sympy/broadcast_max.py +18 -0
  126. accelforge/util/_visualization.py +112 -0
  127. accelforge/util/_yaml.py +579 -0
  128. accelforge/util/parallel.py +193 -0
  129. accelforge-0.0.1.dist-info/METADATA +64 -0
  130. accelforge-0.0.1.dist-info/RECORD +258 -0
  131. accelforge-0.0.1.dist-info/WHEEL +5 -0
  132. accelforge-0.0.1.dist-info/licenses/LICENSE +19 -0
  133. accelforge-0.0.1.dist-info/top_level.txt +5 -0
  134. docs/_build/html/_sources/fastfusion.frontend.mapper.rst.txt +37 -0
  135. docs/_build/html/_sources/fastfusion.frontend.rst.txt +70 -0
  136. docs/_build/html/_sources/fastfusion.frontend.workload.rst.txt +21 -0
  137. docs/_build/html/_sources/fastfusion.mapper.FFM.rst.txt +37 -0
  138. docs/_build/html/_sources/fastfusion.mapper.rst.txt +18 -0
  139. docs/_build/html/_sources/fastfusion.rst.txt +20 -0
  140. docs/_build/html/_sources/fastfusion.util.rst.txt +21 -0
  141. docs/_build/html/_sources/index.rst.txt +87 -0
  142. docs/_build/html/_sources/modules.rst.txt +7 -0
  143. docs/_build/html/_sources/notes/citation.rst.txt +45 -0
  144. docs/_build/html/_sources/notes/definitions.rst.txt +43 -0
  145. docs/_build/html/_sources/notes/faqs.rst.txt +39 -0
  146. docs/_build/html/_sources/notes/modeling/accelerator_energy_latency.rst.txt +72 -0
  147. docs/_build/html/_sources/notes/modeling/component_energy_area.rst.txt +96 -0
  148. docs/_build/html/_sources/notes/modeling/mapping.rst.txt +100 -0
  149. docs/_build/html/_sources/notes/modeling.rst.txt +33 -0
  150. docs/_build/html/_sources/notes/parsing/arithmetic_parsing.rst.txt +136 -0
  151. docs/_build/html/_sources/notes/parsing/setexpressions.rst.txt +63 -0
  152. docs/_build/html/_sources/notes/parsing/yaml_parsing.rst.txt +176 -0
  153. docs/_build/html/_sources/notes/quickstart_and_installation.rst.txt +9 -0
  154. docs/_build/html/_sources/notes/spec/architecture.rst.txt +133 -0
  155. docs/_build/html/_sources/notes/spec/mapping.rst.txt +12 -0
  156. docs/_build/html/_sources/notes/spec/workload.rst.txt +83 -0
  157. docs/_build/html/_sources/notes/spec.rst.txt +36 -0
  158. docs/source/_ext/include_attrs.py +213 -0
  159. docs/source/_ext/include_docstring.py +364 -0
  160. docs/source/_ext/include_functions.py +154 -0
  161. docs/source/_ext/include_notebook.py +131 -0
  162. docs/source/_ext/include_yaml.py +119 -0
  163. docs/source/_ext/inherited_attributes.py +222 -0
  164. docs/source/_ext/paths.py +4 -0
  165. docs/source/conf.py +79 -0
  166. examples/arches/compute_in_memory/_include.yaml +74 -0
  167. examples/arches/compute_in_memory/_include_functions.py +229 -0
  168. examples/arches/compute_in_memory/_load_spec.py +57 -0
  169. examples/arches/compute_in_memory/components/c2c_multiplier.py +181 -0
  170. examples/arches/compute_in_memory/components/dac_c2c_r2r.py +605 -0
  171. examples/arches/compute_in_memory/components/misc.py +195 -0
  172. examples/arches/compute_in_memory/components/util/bit_functions.py +51 -0
  173. examples/arches/compute_in_memory/components/zero_comparator.py +92 -0
  174. examples/arches/compute_in_memory/isaac.yaml +233 -0
  175. examples/arches/compute_in_memory/memory_cells/ecram_demo.yaml +63 -0
  176. examples/arches/compute_in_memory/memory_cells/rram_example.yaml +63 -0
  177. examples/arches/compute_in_memory/memory_cells/rram_isaac_isca_2016.yaml +64 -0
  178. examples/arches/compute_in_memory/memory_cells/rram_neurosim_default.yaml +63 -0
  179. examples/arches/compute_in_memory/memory_cells/rram_raella_isca_2023.yaml +70 -0
  180. examples/arches/compute_in_memory/memory_cells/rram_wan_nature_2022.yaml +63 -0
  181. examples/arches/compute_in_memory/memory_cells/sram_colonnade_jssc_2021.yaml +63 -0
  182. examples/arches/compute_in_memory/memory_cells/sram_example.yaml +63 -0
  183. examples/arches/compute_in_memory/memory_cells/sram_jia_jssc_2020.yaml +63 -0
  184. examples/arches/compute_in_memory/memory_cells/sram_sinangil_jssc_2021.yaml +63 -0
  185. examples/arches/compute_in_memory/memory_cells/sram_wang_vlsi_2022.yaml +63 -0
  186. examples/arches/compute_in_memory/wang_vlsi_2022.yaml +289 -0
  187. examples/arches/eyeriss.yaml +68 -0
  188. examples/arches/fanout_variations/at_glb.yaml +31 -0
  189. examples/arches/fanout_variations/at_glb_with_fanout_node.yaml +34 -0
  190. examples/arches/fanout_variations/at_mac.yaml +31 -0
  191. examples/arches/fanout_variations/at_mac_with_constraints.yaml +38 -0
  192. examples/arches/fanout_variations/at_mac_with_fanout_node.yaml +34 -0
  193. examples/arches/nvdla.yaml +47 -0
  194. examples/arches/simple.yaml +28 -0
  195. examples/arches/tpu_v4i.yaml +67 -0
  196. examples/mappings/unfused_matmuls_to_simple.yaml +33 -0
  197. examples/misc/component_annotated.yaml +33 -0
  198. examples/workloads/gpt3_6.7B.yaml +124 -0
  199. examples/workloads/matmuls.yaml +20 -0
  200. examples/workloads/mobilenet_28.yaml +81 -0
  201. examples/workloads/mobilenet_various_separate.yaml +106 -0
  202. examples/workloads/three_matmuls_annotated.yaml +59 -0
  203. notebooks/.ipynb_checkpoints/fastfusion_arch_study_michael-checkpoint.ipynb +359 -0
  204. notebooks/compute_in_memory/_scripts.py +339 -0
  205. notebooks/compute_in_memory/isaac.guide.ipynb +270 -0
  206. notebooks/compute_in_memory/wang_vlsi_2022.ipynb +602 -0
  207. notebooks/paths.py +4 -0
  208. notebooks/tutorials/.ipynb_checkpoints/1_FFM-checkpoint.ipynb +3110 -0
  209. notebooks/tutorials/FFM.ipynb +3498 -0
  210. notebooks/tutorials/_include.py +48 -0
  211. notebooks/tutorials/component_energy_area.ipynb +363 -0
  212. tests/Q_mapping.yaml +38 -0
  213. tests/__init__.py +0 -0
  214. tests/conv.mapping.yaml +27 -0
  215. tests/conv.workload.yaml +13 -0
  216. tests/conv_sym.mapping.yaml +43 -0
  217. tests/copy.mapping.yaml +35 -0
  218. tests/copy.workload.yaml +15 -0
  219. tests/distribuffers/__init__.py +0 -0
  220. tests/distribuffers/multicast/test_cases.yaml +482 -0
  221. tests/distribuffers/spec/binding/valid_bindings.yaml +97 -0
  222. tests/distribuffers/spec/distributed.yaml +100 -0
  223. tests/distribuffers/spec/logical_arch.yaml +32 -0
  224. tests/distribuffers/spec/physical_arch.yaml +69 -0
  225. tests/distribuffers/test_binding.py +48 -0
  226. tests/frontend/__init__.py +0 -0
  227. tests/frontend/test_mapping_viz.py +52 -0
  228. tests/mapper/__init__.py +0 -0
  229. tests/mapper/configs/conv1d/conv1d.mapping.yaml +31 -0
  230. tests/mapper/configs/conv1d/conv1d.workload.yaml +11 -0
  231. tests/mapper/configs/two_conv1d/two_conv1d.expected.yaml +38 -0
  232. tests/mapper/configs/two_conv1d/two_conv1d.mapping.yaml +54 -0
  233. tests/mapper/configs/two_conv1d/two_conv1d.workload.yaml +19 -0
  234. tests/mapper/test_mapping_to_isl.py +90 -0
  235. tests/mapper/test_spatial_reuse_analysis.py +67 -0
  236. tests/mapper/test_temporal_reuse_analysis.py +56 -0
  237. tests/mapper/util.py +58 -0
  238. tests/matmul.mapping.yaml +29 -0
  239. tests/matmul.workload.yaml +12 -0
  240. tests/matmul_spatial.mapping.yaml +44 -0
  241. tests/mha.renames.yaml +65 -0
  242. tests/mha.workload.yaml +67 -0
  243. tests/mha.yaml +59 -0
  244. tests/mha_full.workload.yaml +67 -0
  245. tests/mobilenet.workload.yaml +35 -0
  246. tests/mobilenet_long.workload.yaml +64 -0
  247. tests/pmappingcache.py +24 -0
  248. tests/processing_stage.arch.yaml +40 -0
  249. tests/snowcat.arch.yaml +36 -0
  250. tests/test_ffm_join_pmappings.py +106 -0
  251. tests/test_ffm_make_pmappings.py +82 -0
  252. tests/test_ffm_make_tile_shapes.py +49 -0
  253. tests/test_mapper.py +100 -0
  254. tests/test_model.py +37 -0
  255. tests/test_plotting.py +72 -0
  256. tests/test_processing_stage.py +46 -0
  257. tests/test_symbolic_model.py +248 -0
  258. tests/test_workload.py +141 -0
accelforge/frontend/workload.py
@@ -0,0 +1,952 @@
+ """
+ All the objects used for a Workload description in AccelForge.
+ """
+
+ from itertools import product
+ import itertools
+ import logging
+ import re
+ from typing import Annotated, Any, TypeAlias
+
+ import pydot
+
+ from accelforge.util.parallel import _SVGJupyterRender
+
+ from accelforge.util._basetypes import (
+     ParsableDict,
+     ParsableList,
+     ParsableModel,
+     ParsesTo,
+ )
+ from accelforge.util._visualization import _pydot_graph
+ from accelforge.frontend.renames import (
+     EinsumName,
+     RankVariable,
+     Rename,
+     RenameList,
+     Renames,
+     TensorName,
+     Rank,
+     rename_list_factory,
+ )
+ from accelforge.util._parse_expressions import ParseError, parse_expression
+ from accelforge.util._setexpressions import InvertibleSet, eval_set_expression
+ from accelforge._version import __version__
+
+
+ CLIST_OPERATORS = [
+     "EQ",
+     "NE",
+     "LT",
+     "GT",
+     "LE",
+     "GE",
+     "NG",
+     "NL",
+     "AND",
+     "OR",
+ ]
+
+ _ISL_REGEX = re.compile(
+     r"\b(?!(?:" + "|".join(CLIST_OPERATORS) + r")\b)[a-zA-Z#$@][a-zA-Z0-9_]*\b"
+ )
+ """
+ A compiled regex pattern that matches words that are not exactly in
+ CLIST_OPERATORS (case-sensitive), start with a letter, `#`, `$`, or `@`, and
+ continue with zero or more letters, digits, or underscores.
+ """
+
+
+ def isl_expression_has_variable(expression: str, variable: RankVariable) -> bool:
+     """
+     Returns True if the given ISL expression has the given rank variable.
+
+     Parameters
+     ----------
+     expression : str
+         The ISL expression to check.
+     variable : RankVariable
+         The rank variable to check for.
+
+     Returns
+     -------
+     bool
+         True if the given ISL expression has the given rank variable.
+     """
+     return variable in re.findall(_ISL_REGEX, expression)
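A quick sketch of how these helpers behave (hedged: the import path is assumed from this file's location in the package):

```python
# Hedged sketch; the import path is assumed from this file's location.
from accelforge.frontend.workload import isl_expression_has_variable

assert isl_expression_has_variable("0 <= m + n < 10", "m")
assert not isl_expression_has_variable("0 <= m + n < 10", "k")
# CLIST operators are excluded by the negative lookahead, so "AND" is never
# treated as a variable name:
assert not isl_expression_has_variable("m > 0 AND n > 0", "AND")
```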
+
+
+ SymbolTable: TypeAlias = dict[str, InvertibleSet]
+
+
+ class TensorAccess(ParsableModel):
+     """Information about how an Einsum accesses a tensor."""
+
+     name: TensorName
+     """ The name of the tensor. """
+
+     projection: dict[str, str] | list[str]
+     """
+     How the rank variables of the Einsum project into the tensor. If this is a list,
+     then it is assumed that each of the elements of the list is a single rank variable
+     and they index into the tensor in ranks that equal the uppercase of the rank
+     variable. For example:
+
+     name: X, projection: [a, b, c] means X[A=a, B=b, C=c]
+
+     If this is a dictionary, it is a mapping from rank names to rank variable
+     expressions. This can be used either to project into a non-matching rank name or
+     to project into a tensor using an expression. For example:
+
+     name: X, projection: {A: a, B2: b, C: a+b} means X[A=a, B2=b, C=a+b]
+     """
+
+     output: bool = False
+     """ Whether the tensor is an output. False means the tensor is an input. """
+
+     persistent: bool = False
+     """ If True, then a copy of this tensor must remain in backing storage for the full
+     duration of the workload's execution. """
+
+     backing_storage_size_scale: float = 1.0
+     """ If != 1, then the backing storage size will be scaled by this factor. """
+
+     bits_per_value: int | str | None = None
+     """ Bits per value for this tensor. """
+
+     def model_post_init(self, __context__=None) -> None:
+         self.projection = _projection_factory(self.projection)
+
+     def _to_formatted_string(self) -> str:
+         """Returns a string representation of the tensor access for Pydot nodes."""
+         subscript = ",".join(self.projection.values())
+         if isinstance(self.projection, ImpliedProjection):
+             return f"{self.name}<sub>{subscript}</sub>"
+
+         string = [self.name]
+         for k, v in self.projection.items():
+             if len(string) < len(self.projection):
+                 string.append(f"<sup>{k},</sup><sub>{v},</sub>")
+             else:
+                 string.append(f"<sup>{k}</sup><sub>{v}</sub>")
+         return "".join(string)
+
+     @property
+     def rank2rank_variables(self) -> dict[Rank, set[RankVariable]]:
+         """
+         Returns a dictionary of rank names to the rank variables that project into
+         that rank.
+         """
+         return {
+             Rank(rank): set(
+                 RankVariable(rank_var)
+                 for rank_var in re.findall(_ISL_REGEX, projection)
+             )
+             for rank, projection in self.projection.items()
+         }
+
+     @property
+     def rank_variable2ranks(self) -> dict[RankVariable, set[Rank]]:
+         """
+         Returns a dictionary of rank variables to the ranks into which that rank
+         variable projects.
+         """
+         result = {}
+         for rank, projection in self.projection.items():
+             for rank_var in re.findall(_ISL_REGEX, projection):
+                 rank_set: set = result.setdefault(rank_var, set())
+                 rank_set.add(rank)
+         return result
+
+     @property
+     def ranks(self) -> tuple[Rank, ...]:
+         """Returns the ranks of this access's tensor."""
+         return tuple(Rank(x) for x in self.projection.keys())
+
+     @property
+     def rank_variables(self) -> set[RankVariable]:
+         """Returns all rank variables used in this access."""
+         # Projection values may be expressions, so we need to grab all identifiers
+         return set(
+             RankVariable(x)
+             for x in re.findall(_ISL_REGEX, " ".join(self.projection.values()))
+         )
+
+     @property
+     def directly_indexing_rank_variables(self) -> set[RankVariable]:
+         """
+         Returns the rank variables that directly index into this tensor without any
+         expression (e.g., "M=m", NOT "M=m+n").
+         """
+         # fullmatch so that an expression like "m+n" is not mistaken for the
+         # bare variable "m"
+         return set(
+             RankVariable(x)
+             for x in self.projection.values()
+             if _ISL_REGEX.fullmatch(x)
+         )
+
+     @property
+     def expression_indexing_rank_variables(self) -> set[RankVariable]:
+         """
+         Returns the rank variables that indirectly index into this tensor through an
+         expression (e.g., "M=m+n") instead of a direct index (e.g., "M=m").
+         """
+         return self.rank_variables - self.directly_indexing_rank_variables
+
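A hedged illustration of the two projection forms (it assumes TensorAccess can be constructed directly, as a pydantic model normally can):

```python
from accelforge.frontend.workload import TensorAccess

x = TensorAccess(name="X", projection=["a", "b"])  # implied: X[A=a, B=b]
y = TensorAccess(name="Y", projection={"M": "m", "N": "m+n"}, output=True)

print(x.ranks)                               # ('A', 'B')
print(y.directly_indexing_rank_variables)    # {'m'}
print(y.expression_indexing_rank_variables)  # {'n'}
```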
+
+ class ImpliedProjection(dict):
+     """
+     Holds a projection that has been implied by a list of rank variables. The implied
+     rank names are uppercased versions of the rank variables; for example, [a, b, c] ->
+     {A: a, B: b, C: c}.
+     """
+
+
+ def _projection_factory(projection: dict | list):
+     if isinstance(projection, list):
+         for i, x in enumerate(projection):
+             if not isinstance(x, str):
+                 raise TypeError(f"Element at index {i} must be a string, got {type(x)}")
+             # fullmatch: an expression like "a+b" must not pass as an identifier
+             if not _ISL_REGEX.fullmatch(x):
+                 raise ValueError(
+                     f"Element '{x}' at index {i} is not a valid ISL identifier. "
+                     f"In a projection list, all elements must be valid ISL identifiers. "
+                     f"For expressions, use a dictionary projection."
+                 )
+         projection = ImpliedProjection({x.upper(): x for x in projection})
+     elif not isinstance(projection, dict):
+         raise TypeError(
+             f"Invalid projection: {projection}. Must be a list of rank variables or a "
+             f"dictionary mapping rank names to projection expressions."
+         )
+     for key in projection:
+         if not isinstance(key, str):
+             raise TypeError(f"Invalid projection key: {key}. Must be a string.")
+         if not key.isidentifier():
+             raise ValueError(
+                 f"Invalid projection key: {key}. Must be a valid identifier. Check with "
+                 f"the Python isidentifier() function."
+             )
+     return projection
+
+
+ class Shape(ParsableList):
+     """
+     Specifies valid values for the rank variables. This is a list of strings, each one
+     an ISL expression. The total space is considered to be the logical AND of all the
+     expressions in the list.
+     """
+
+     @property
+     def rank_variables(self) -> set[str]:
+         """Returns all rank variables used in this shape."""
+         if not self:
+             return set()
+         return set.union(*[set(re.findall(_ISL_REGEX, x)) for x in self])
+
+
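For instance (a sketch; it assumes `Shape`, like a list, can be built from an iterable):

```python
from accelforge.frontend.workload import Shape

shape = Shape(["0 <= m < 10", "m + n < 16"])
print(shape.rank_variables)  # {'m', 'n'}
```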
+ class Einsum(ParsableModel):
+     """
+     Represents an Einsum, which is a single computation step in the workload. The Einsum
+     includes a set of rank variables, which are used to index into tensors. Rank
+     variables iterate through an iteration space.
+
+     For example, if the Einsum is A[m, n] += B[m, k] * C[k, n] and we define the
+     iteration space as "0 <= m < 10, 0 <= n < 10, 0 <= k < 10", then the Einsum will
+     iterate through all possible values of (m, n, k) in the iteration space, indexing
+     into tensors for each and updating A[m, n] with B[m, k] * C[k, n].
+     """
+
+     name: EinsumName
+     """ The name of the Einsum. """
+     tensor_accesses: ParsableList[TensorAccess]
+     """ The tensors accessed by this Einsum, and how they are accessed. """
+     iteration_space_shape: Shape[str] = Shape()
+     """
+     Bounds of valid rank variable values. This is a list of expressions, each one an ISL
+     expression. Additionally, global iteration_space_shape expressions are appended to
+     the list if their rank variables are present in the Einsum's rank_variables. For
+     example, if the global scope has "m: 0 <= m < 10" and the Einsum has "m" in its
+     rank_variables, then "0 <= m < 10" will be appended to the iteration_space_shape.
+     """
+     rank_sizes: ParsableDict[Rank, int] = ParsableDict()
+     """
+     Sizes of ranks. This is a dictionary of rank names to sizes. Sizes are integers, and
+     the rank's bounds are 0 <= rank < size. Accesses outside of these bounds are
+     skipped.
+     """
+     is_copy_operation: bool = False
+     """ Whether the Einsum is a copy operation. Copy operations take the input tensor
+     and directly place it at the location of the output tensor(s) without any
+     computation. If the destination tensor is at the same location, then this is a
+     no-op."""
+     renames: RenameList[Rename] = RenameList()
+     """ Renames of the Einsum. Renames here can be used to rename rank variables or
+     tensors. When this Einsum is executed on an architecture, the architecture can use
+     renamed tensors and rank variables to access the tensors and rank variables. """
+     n_instances: int = 1
+     """
+     Number of times to repeat the Einsum. Multiplied by `Workload.n_instances` to get
+     the total number of Einsum instances. Energy, latency, and other summable metrics
+     are multiplied by this value. Persistent reservations are also multiplied by this
+     value, but non-persistent reservations are not, as they are assumed to be freed
+     between each instance.
+     """
+
+     def model_post_init(self, __context__=None) -> None:
+         if self.name == "Total":
+             raise ValueError(
+                 'Einsum name "Total" is reserved for totaling across Einsums. '
+                 "Use a different name for the Einsum."
+             )
+
+     def __init__(self, *args, **kwargs):
+         if "renames" in kwargs:
+             kwargs["renames"] = rename_list_factory(kwargs["renames"])
+         super().__init__(*args, **kwargs)
+
+     @property
+     def rank_variables(self) -> set[RankVariable]:
+         """Returns all rank variables used in this Einsum."""
+         if not self.tensor_accesses:
+             return set()
+         return set.union(*[t.rank_variables for t in self.tensor_accesses])
+
+     @property
+     def ranks(self) -> set[Rank]:
+         """Returns all ranks used in this Einsum."""
+         if not self.tensor_accesses:
+             return set()
+         return set.union(*[set(t.ranks) for t in self.tensor_accesses])
+
+     @property
+     def input_tensor_names(self) -> set[TensorName]:
+         """Returns the names of the input tensors of this Einsum."""
+         return set([TensorName(t.name) for t in self.tensor_accesses if not t.output])
+
+     @property
+     def output_tensor_names(self) -> set[TensorName]:
+         """Returns the names of the output tensors of this Einsum."""
+         return set([TensorName(t.name) for t in self.tensor_accesses if t.output])
+
+     @property
+     def tensor_names(self) -> set[TensorName]:
+         """Returns the names of all tensors of this Einsum."""
+         return set([TensorName(t.name) for t in self.tensor_accesses])
+
+     @property
+     def tensor2rank_variables(self) -> dict[TensorName, set[RankVariable]]:
+         """Returns a dictionary of tensor names to the rank variables that project into
+         that tensor."""
+         return {TensorName(t.name): t.rank_variables for t in self.tensor_accesses}
+
+     @property
+     def tensor2directly_indexing_rank_variables(
+         self,
+     ) -> dict[TensorName, set[RankVariable]]:
+         """
+         Returns a dictionary of tensor names to the rank variables that directly index
+         into that tensor. Direct indexing means that the rank variable is used as a
+         direct index into the tensor, without any expression (e.g., "M=m", NOT "M=m+n").
+         """
+         return {
+             TensorName(t.name): t.directly_indexing_rank_variables
+             for t in self.tensor_accesses
+         }
+
+     @property
+     def tensor2expression_indexing_rank_variables(
+         self,
+     ) -> dict[TensorName, set[RankVariable]]:
+         """
+         Returns a dictionary of tensor names to the rank variables that indirectly index
+         into that tensor through an expression (e.g., "M=m+n") instead of a direct index
+         (e.g., "M=m").
+         """
+         fully_relevant_rank_vars = self.tensor2directly_indexing_rank_variables
+         return {
+             TensorName(t.name): t.rank_variables - fully_relevant_rank_vars[t.name]
+             for t in self.tensor_accesses
+         }
+
+     @property
+     def tensor2irrelevant_rank_variables(
+         self,
+     ) -> dict[TensorName, set[RankVariable]]:
+         """
+         Returns a dictionary of tensor names to the rank variables that are irrelevant
+         to that tensor. Irrelevant rank variables are rank variables that are not used
+         to index into the tensor.
+         """
+         partially_relevant = self.tensor2expression_indexing_rank_variables
+         fully_relevant = self.tensor2directly_indexing_rank_variables
+         rank_variables = self.rank_variables
+         return {
+             TensorName(t.name): rank_variables
+             - fully_relevant[t.name]
+             - partially_relevant[t.name]
+             for t in self.tensor_accesses
+         }
+
+     def _to_formatted_string(self, compress: bool = False) -> str:
+         """
+         Returns a string representation of this Einsum for use in a Pydot graph.
+
+         Parameters
+         ----------
+         compress : bool, optional
+             If True, the string is broken across multiple lines to keep it narrow
+             (e.g., for graph node labels). If False, everything is placed on one line.
+
+         Returns
+         -------
+         str
+             A string representation of this Einsum for use in a Pydot graph.
+         """
+         lhs_join = ",\n" if compress else " , "
+         rhs_join = " \n " if compress else " "
+         lhs = lhs_join.join(
+             [t._to_formatted_string() for t in self.tensor_accesses if t.output]
+         )
+         rhs = rhs_join.join(
+             [t._to_formatted_string() for t in self.tensor_accesses if not t.output]
+         )
+         return f"{lhs}=\n{rhs}" if compress else f"{lhs} = {rhs}"
+
+     def copy_source_tensor(self) -> TensorName | None:
+         """
+         If this Einsum is a copy operation, returns the name of the tensor that is the
+         source of the copy. Otherwise, returns None.
+         """
+         if not self.is_copy_operation:
+             return None
+         input_tensors = self.input_tensor_names
+         if len(input_tensors) != 1:
+             raise ValueError(
+                 f"Copy Einsum {self.name} has {len(input_tensors)} input tensors, "
+                 f"expected 1"
+             )
+         return input_tensors.pop()
+
+     @property
+     def rank_variable2ranks(self) -> dict[RankVariable, set[Rank]]:
+         """
+         Returns a dictionary of rank variables to the ranks that are indexed into by
+         that rank variable.
+         """
+         result: dict[RankVariable, set[Rank]] = {}
+         for tensor_access in self.tensor_accesses:
+             new = tensor_access.rank_variable2ranks
+             for rank_var, ranks in new.items():
+                 result.setdefault(rank_var, set()).update(ranks)
+         return result
+
+     @property
+     def indexing_expressions(self) -> set[str]:
+         """
+         Returns a set of all the expressions that index into the tensors of this
+         Einsum.
+         """
+         result = set()
+         for tensor_access in self.tensor_accesses:
+             for _, projection in tensor_access.projection.items():
+                 result.add(projection)
+         return result
+
+     def _parse_expressions(self, symbol_table: dict[str, Any], *args, **kwargs):
+         workload: Workload = symbol_table["spec_workload"]
+         renames: Renames = symbol_table["spec_renames"]
+
+         # Put together the renames symbol table
+         inputs = self.input_tensor_names
+         outputs = self.output_tensor_names
+         all_ = inputs | outputs
+         persistent = {t.name for t in self.tensor_accesses if t.persistent}
+         element_to_child_space = {}
+         all_rank_variables = self.rank_variables
+         for tensor in self.tensor_names:
+             element_to_child_space[tensor] = InvertibleSet(
+                 instance=self.tensor2rank_variables[tensor],
+                 full_space=all_rank_variables,
+                 space_type=RankVariable,
+             )
+
+         intermediates = {
+             t
+             for t in all_
+             if workload.einsums_with_tensor_as_input(t)
+             and workload.einsums_with_tensor_as_output(t)
+         }
+         shared = {
+             t
+             for t in all_
+             if len(
+                 set(e.name for e in workload.einsums_with_tensor_as_input(t))
+                 | set(e.name for e in workload.einsums_with_tensor_as_output(t))
+             )
+             > 1
+         }
+
+         kwargs_tensors = dict(
+             full_space=all_,
+             space_type=TensorName,
+             child_access_name="rank_variables",
+             element_to_child_space=element_to_child_space,
+         )
+         kwargs_rank_variables = dict(
+             full_space=all_rank_variables,
+             space_type=RankVariable,
+         )
+         rename_symbol_table = {
+             "All": InvertibleSet(instance=all_, **kwargs_tensors),
+             "Tensors": InvertibleSet(instance=all_, **kwargs_tensors),
+             "Nothing": InvertibleSet(instance=(), **kwargs_tensors),
+             "Inputs": InvertibleSet(instance=inputs, **kwargs_tensors),
+             "Outputs": InvertibleSet(instance=outputs, **kwargs_tensors),
+             "Intermediates": InvertibleSet(instance=intermediates, **kwargs_tensors),
+             "Shared": InvertibleSet(instance=shared, **kwargs_tensors),
+             "Persistent": InvertibleSet(instance=persistent, **kwargs_tensors),
+             **{t: InvertibleSet(instance=(t,), **kwargs_tensors) for t in all_},
+             **{
+                 r: InvertibleSet(instance=(r,), **kwargs_rank_variables)
+                 for r in all_rank_variables
+             },
+             "Einsum": self.name,
+             "Above": InvertibleSet(instance=(), **kwargs_tensors),
+         }
+
+         for t in workload.tensor_names:
+             if t not in rename_symbol_table:
+                 rename_symbol_table[t] = InvertibleSet(instance=(), **kwargs_tensors)
+
+         for r in workload.rank_variables:
+             if r not in rename_symbol_table:
+                 rename_symbol_table[r] = InvertibleSet(
+                     instance=(), **kwargs_rank_variables
+                 )
+
+         st = {**rename_symbol_table, **symbol_table}
+
+         # Work on a copy so the original Einsum is left untouched
+         self: Einsum = self.model_copy()
+         self.renames = RenameList(self.renames)
+
+         # Grab the default renames and update the renames with more values
+         default_renames = renames.get_renames_for_einsum("default")
+         for tensor_rename in default_renames.tensor_accesses:
+             if tensor_rename.name not in self.renames:
+                 self.renames.append(tensor_rename)
+         for rank_variable_rename in default_renames.rank_variables:
+             if rank_variable_rename.name not in self.renames:
+                 self.renames.append(rank_variable_rename)
+
+         # Parse me!
+         kwargs["must_parse_try_parse_to"] = True
+         parsed, _ = super(self.__class__, self)._parse_expressions(st, *args, **kwargs)
+
+         # Update the renames with the new values
+         for k, v in rename_symbol_table.items():
+             if k not in parsed.renames:
+                 parsed.renames.append(Rename(name=k, source=v))
+
+         # Parse the bits per value
+         bits_per_value = dict()
+         bpv_to_source = dict()
+         for k, v in symbol_table["workload_bits_per_value"].items():
+             bpv = eval_set_expression(
+                 expression=k,
+                 symbol_table=st,
+                 expected_space=TensorName,
+                 location=f"(workload global bits_per_value)[{k}]",
+             )
+             for t in bpv:
+                 if t in bits_per_value:
+                     raise ParseError(
+                         f"Tensor {t} is specified in multiple entries in the workload "
+                         f"global bits_per_value dictionary.",
+                         source_field=f"({k} AND {bpv_to_source[t]})",
+                     )
+                 bits_per_value[t] = v
+                 bpv_to_source[t] = k
+
+         for t in parsed.tensor_accesses:
+             if t.bits_per_value is None and t.name not in bits_per_value:
+                 raise ParseError(
+                     f"Tensor {t.name} in Einsum does not have a bits per value "
+                     f"specified. Ensure that the tensor is either covered by the set "
+                     f"expressions in the workload.bits_per_value dictionary "
+                     f"or bits_per_value is specified for the tensor access.",
+                     source_field=f"tensor_accesses[{t.name}].bits_per_value",
+                 )
+             if t.bits_per_value is None:
+                 t.bits_per_value = bits_per_value[t.name]
+
+         return parsed, symbol_table
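To make the pieces above concrete, here is a hedged sketch of constructing an Einsum directly (it assumes these pydantic models coerce plain dicts for nested fields, which their defaults suggest; this is not a documented API):

```python
from accelforge.frontend.workload import Einsum

mm = Einsum(
    name="MatMul",
    tensor_accesses=[
        {"name": "A", "projection": ["m", "n"], "output": True},
        {"name": "B", "projection": ["m", "k"]},
        {"name": "C", "projection": ["k", "n"]},
    ],
)
print(sorted(mm.rank_variables))       # ['k', 'm', 'n']
print(sorted(mm.input_tensor_names))   # ['B', 'C']
print(sorted(mm.output_tensor_names))  # ['A']
```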
+
+
+ class Workload(ParsableModel):
+     """
+     The workload specification as a cascade of Einsums, with each Einsum being a
+     computation step in the workload.
+     """
+
+     # version: Annotated[str, assert_version] = __version__
+     # """ The version of the workload specification. """
+
+     einsums: ParsableList[Einsum] = ParsableList()
+     """ The Einsums in the workload. """
+
+     iteration_space_shape: ParsableDict[RankVariable, str] = ParsableDict()
+     """
+     Bounds of valid rank variable values. This is a dictionary of rank variable
+     names to bounds of valid rank variable values. The bounds are specified as a string
+     in the ISL format. For example, "0 <= a < 10" means that the rank variable `a` must
+     be between 0 and 10, including 0 but not 10. Bounds are included for all Einsums
+     that include that rank variable.
+     """
+
+     rank_sizes: ParsableDict[Rank, ParsesTo[int]] = ParsableDict()
+     """
+     Rank sizes. This is a dictionary of rank names to sizes. Sizes are integers, and the
+     rank's bounds are 0 <= rank < size. Accesses outside of these bounds are skipped.
+     """
+
+     n_instances: int = 1
+     """
+     Number of times to repeat the workload. Multiplied by `Einsum.n_instances` to get
+     the total number of Einsum instances. Energy, latency, and other summable metrics
+     are multiplied by this value. Persistent reservations are also multiplied by this
+     value, but non-persistent reservations are not, as they are assumed to be freed
+     between each instance.
+     """
+
+     bits_per_value: ParsableDict[str, int | str] = ParsableDict()
+     """
+     Bits per value for each tensor. The workload-level bits_per_value is overridden if
+     bits_per_value is specified for any given tensor access. This is a dictionary of
+     set expressions to bits per value for the tensors given by those expressions. For
+     example, we may write "Inputs: 8" to set the bits per value to 8 for all input
+     tensors, unless overridden.
+     """
+
+     def model_post_init(self, __context__=None) -> None:
+         self._validate()
+
+     def _validate(self):
+         tensor2ranks = {}
+         einsum_names = set()
+         for einsum in self.einsums:
+             if einsum.name in einsum_names:
+                 raise ValueError(f"Einsum name {einsum.name} is not unique")
+             einsum_names.add(einsum.name)
+             for tensor_access in einsum.tensor_accesses:
+                 tensor2ranks.setdefault(tensor_access.name, tensor_access.ranks)
+                 if tensor2ranks[tensor_access.name] != tensor_access.ranks:
+                     einsum_list = ", ".join(
+                         e.name for e in self.einsums_with_tensor(tensor_access.name)
+                     )
+                     raise ValueError(
+                         f"TensorName {tensor_access.name} has inconsistent ranks. Found "
+                         f"{tensor2ranks[tensor_access.name]} and {tensor_access.ranks}. "
+                         f"TensorName is in Einsums {einsum_list}"
+                     )
+
+     @property
+     def einsum_names(self) -> list[EinsumName]:
+         """Returns the names of the Einsums in the workload."""
+         return [EinsumName(e.name) for e in self.einsums]
+
+     def einsums_with_tensor(self, tensor: TensorName) -> list["Einsum"]:
+         """
+         Returns the Einsums in the workload that access the given tensor.
+
+         Parameters
+         ----------
+         tensor : TensorName
+             The tensor to check.
+
+         Returns
+         -------
+         list[Einsum]
+             The Einsums in the workload that access the given tensor. Order is the same
+             as the order in this workload's Einsums list.
+         """
+         return [e for e in self.einsums if tensor in e.tensor_names]
+
+     def einsums_with_tensor_as_input(self, tensor: TensorName) -> list["Einsum"]:
+         """
+         Returns the Einsums in the workload that use the given tensor as an input.
+
+         Parameters
+         ----------
+         tensor : TensorName
+             The tensor to check.
+
+         Returns
+         -------
+         list[Einsum]
+             The Einsums in the workload that use the given tensor as an input. Order is
+             the same as the order in this workload's Einsums list.
+         """
+         return [e for e in self.einsums if tensor in e.input_tensor_names]
+
+     def einsums_with_tensor_as_output(self, tensor: TensorName) -> list["Einsum"]:
+         """
+         Returns the Einsums in the workload that have the given tensor as an output.
+
+         Parameters
+         ----------
+         tensor : TensorName
+             The tensor to check.
+
+         Returns
+         -------
+         list[Einsum]
+             The Einsums in the workload that have the given tensor as an output. Order
+             is the same as the order in this workload's Einsums list.
+         """
+         return [e for e in self.einsums if tensor in e.output_tensor_names]
+
+     def accesses_for_tensor(self, tensor: TensorName) -> list[TensorAccess]:
+         """
+         Returns all TensorAccess objects that access the given tensor across all
+         Einsums.
+
+         Parameters
+         ----------
+         tensor : TensorName
+             The tensor to check.
+
+         Returns
+         -------
+         list[TensorAccess]
+             The TensorAccess objects that access the given tensor across all Einsums.
+             Order is the same as the order in this workload's Einsums list.
+         """
+         return [t for e in self.einsums for t in e.tensor_accesses if t.name == tensor]
+
+     def get_iteration_space_shape_isl_string(self, einsum_name: str) -> str:
+         """
+         Returns the ISL string representing the iteration space of the given Einsum.
+
+         Parameters
+         ----------
+         einsum_name : str
+             The name of the Einsum for which to get the iteration space shape.
+
+         Returns
+         -------
+         str
+             The ISL string representing the iteration space shape of the given Einsum.
+         """
+         einsum = self.einsums[einsum_name]
+         einsum_shape = einsum.iteration_space_shape
+         my_ispace = self.iteration_space_shape
+         global_shape = [my_ispace[r] for r in einsum.rank_variables if r in my_ispace]
+         rank_sizes = einsum.rank_sizes
+         global_rank_sizes = {
+             r: self.rank_sizes[r] for r in einsum.ranks if r in self.rank_sizes
+         }
+
+         exprs = einsum_shape + global_shape
+         for tensor in einsum.tensor_accesses:
+             for rank, projection in tensor.projection.items():
+                 if rank in rank_sizes:
+                     exprs.append(f"0 <= {projection} < {rank_sizes[rank]}")
+                 elif rank in global_rank_sizes:
+                     exprs.append(f"0 <= {projection} < {global_rank_sizes[rank]}")
+
+         return " and ".join(exprs)
+
+     def _check_consistent_persistent(self):
+         for tensor in self.tensor_names:
+             persistents = {
+                 e.tensor_accesses[tensor].persistent
+                 for e in self.einsums_with_tensor(tensor)
+             }
+             if len(persistents) > 1:
+                 raise ValueError(
+                     f"Tensor {tensor} is used in multiple Einsums with different "
+                     f"persistent values. Persistent values must be consistent across "
+                     f"all Einsums that use the tensor."
+                 )
+
+     @property
+     def tensor_names_used_in_multiple_einsums(self) -> set[TensorName]:
+         """Returns the names of the tensors that are used in multiple Einsums."""
+         return {t for t in self.tensor_names if len(self.einsums_with_tensor(t)) > 1}
+
+     @property
+     def tensor_names(self) -> set[TensorName]:
+         """Returns the names of all tensors in the workload."""
+         return {TensorName(t.name) for e in self.einsums for t in e.tensor_accesses}
+
+     @property
+     def rank_variables(self) -> set[RankVariable]:
+         """Returns the names of all rank variables in the workload."""
+         return {RankVariable(r) for e in self.einsums for r in e.rank_variables}
+
+     def _repr_svg_(self) -> str:
+         return self.render()
+
+     def render(self) -> str:
+         """Renders the workload as a Pydot graph. Returns an SVG string."""
+         graph = _pydot_graph()
+
+         # Add Einsums as box nodes and tensors as oval nodes
+         tensors = []
+         seen_tensor_names = set()
+         for einsum in self.einsums:
+             node = pydot.Node(
+                 f"Einsum_{einsum.name}",
+                 shape="box",
+                 label=f"<{einsum._to_formatted_string(compress=True)}>",
+             )
+             graph.add_node(node)
+             for tensor_access in einsum.tensor_accesses:
+                 if tensor_access.name not in seen_tensor_names:
+                     tensors.append(tensor_access.name)
+                     seen_tensor_names.add(tensor_access.name)
+                     node = pydot.Node(
+                         f"Tensor_{tensor_access.name}",
+                         shape="oval",
+                         label=f"<{tensor_access._to_formatted_string()}>",
+                     )
+                     graph.add_node(node)
+
+         # Add edges between tensors and Einsums
+         for einsum in self.einsums:
+             for tensor_access in einsum.tensor_accesses:
+                 if tensor_access.output:
+                     # Output tensor: einsum -> tensor
+                     edge = pydot.Edge(
+                         f"Einsum_{einsum.name}", f"Tensor_{tensor_access.name}"
+                     )
+                     graph.add_edge(edge)
+                 else:
+                     # Input tensor: tensor -> einsum
+                     edge = pydot.Edge(
+                         f"Tensor_{tensor_access.name}", f"Einsum_{einsum.name}"
+                     )
+                     graph.add_edge(edge)
+         return _SVGJupyterRender(graph.create_svg(prog="dot").decode("utf-8"))
+
+     def _parse_expressions(
+         self, symbol_table: dict[str, Any], *args, renames: Renames, **kwargs
+     ):
+         bpv, _ = self.bits_per_value._parse_expressions(symbol_table, *args, **kwargs)
+         new_st = {
+             **symbol_table,
+             "spec_workload": self,
+             "spec_renames": renames,
+             "workload_bits_per_value": bpv,
+         }
+         parsed, new_st = super()._parse_expressions(new_st, *args, **kwargs)
+
+         # Ensure bits_per_value is consistent across Einsums
+         bits_per_value_per_einsum = {}
+         bits_per_value = {}
+         for einsum in parsed.einsums:
+             cur_bpv = {t.name: t.bits_per_value for t in einsum.tensor_accesses}
+             # Check for consistency across Einsums
+             for prev_einsum, prev_bpv in bits_per_value_per_einsum.items():
+                 shared_keys = set(cur_bpv.keys()) & set(prev_bpv.keys())
+                 for t in shared_keys:
+                     b0 = cur_bpv[t]
+                     b1 = prev_bpv[t]
+                     if b0 != b1:
+                         raise ValueError(
+                             f"Tensor {t} has bits per value {b0} in Einsum {einsum.name} "
+                             f"and {b1} in Einsum {prev_einsum}. Bits per value must be "
+                             "consistent across all Einsums that access a tensor."
+                         )
+             bits_per_value_per_einsum[einsum.name] = cur_bpv
+             bits_per_value.update(cur_bpv)
+
+         for einsum in parsed.einsums:
+             for t, bpv in bits_per_value.items():
+                 einsum.renames[t].source.bits_per_value = bpv
+
+             for r in einsum.renames:
+                 src: InvertibleSet = r.source
+                 if (
+                     isinstance(src, InvertibleSet)
+                     and len(src) == 1
+                     and src.space_type == TensorName
+                     and next(iter(src)) in bits_per_value
+                 ):
+                     src.bits_per_value = bits_per_value[next(iter(src))]
+
+         parsed._check_consistent_persistent()
+
+         return parsed, symbol_table
+
+     def _get_ranks_that_share_indexing_rank_variables(self) -> dict[Rank, set[Rank]]:
+         """
+         Returns a dictionary of ranks to the ranks with which they share indexing rank
+         variables. For example, if one Einsum indexes into rank A with rank variable a
+         and another Einsum indexes into rank B with rank variable a, then A and B share
+         the indexing rank variable a. Then we'd have in our return value both A: {A, B}
+         and B: {A, B}. This relation is transitive and reflexive.
+
+         Returns
+         -------
+         dict[Rank, set[Rank]]
+             A dictionary of ranks to the ranks with which they share indexing rank
+             variables. The ranks are the keys, and the values are sets of ranks that
+             share indexing rank variables with the key.
+         """
+         rank2rankvars = {}
+         for tensor in self.tensor_names:
+             for acc in self.accesses_for_tensor(tensor):
+                 for rank, rank_vars in acc.rank2rank_variables.items():
+                     rank2rankvars.setdefault(rank, set()).update(rank_vars)
+
+         rank_var_to_ranks = {}
+         for rank, rank_vars in rank2rankvars.items():
+             for rank_var in rank_vars:
+                 rank_var_to_ranks.setdefault(rank_var, set()).add(rank)
+
+         # Merge to a fixpoint: keep unioning overlapping rank sets until no set
+         # gains a new member.
+         rank_to_ranks = {r: set((r,)) for r in rank2rankvars.keys()}
+         update_with = list(rank_var_to_ranks.values())
+         changed = True
+         while changed:
+             changed = False
+             for ranks in rank_to_ranks.values():
+                 for u in update_with:
+                     if u & ranks:
+                         changed = changed or bool(u - ranks)
+                         ranks.update(u)
+
+         return rank_to_ranks
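The merge loop above is a plain fixpoint computation; a standalone rendition for intuition (illustrative only, not part of the package):

```python
# Standalone rendition of the fixpoint merge above (illustrative only).
rank_to_ranks = {"A": {"A"}, "B": {"B"}, "C": {"C"}}
update_with = [{"A", "B"}, {"B", "C"}]  # rank groups sharing a rank variable

changed = True
while changed:
    changed = False
    for ranks in rank_to_ranks.values():
        for u in update_with:
            if u & ranks:
                changed = changed or bool(u - ranks)
                ranks.update(u)

print(rank_to_ranks)  # every rank maps to {'A', 'B', 'C'}
```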
+
+     def get_tensor_copies(self) -> dict[TensorName, set[TensorName]]:
+         """
+         Returns a dictionary specifying which tensors are copies of which other tensors.
+         Each copy Einsum links its source tensor and each of its destination tensors
+         symmetrically; for example, if Einsum A copies tensor X into tensors Y and Z,
+         then we'd have in the return value X: {Y, Z}, Y: {X}, and Z: {X}.
+
+         Returns
+         -------
+         dict[TensorName, set[TensorName]]
+             A dictionary specifying which tensors are copies of which other tensors.
+             The keys are tensors, and the values are the sets of tensors to or from
+             which the key was copied.
+         """
+         tensor_copies = {}
+         for einsum in self.einsums:
+             if not einsum.is_copy_operation:
+                 continue
+             input_tensor = einsum.copy_source_tensor()
+             for output_tensor in einsum.output_tensor_names:
+                 tensor_copies.setdefault(input_tensor, set()).add(output_tensor)
+                 tensor_copies.setdefault(output_tensor, set()).add(input_tensor)
+         return tensor_copies