accelforge 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (258) hide show
  1. accelforge/__init__.py +21 -0
  2. accelforge/_accelerated_imports.py +16 -0
  3. accelforge/_deprecate/_simanneal/evalmapping.py +271 -0
  4. accelforge/_deprecate/_simanneal/mapspaceglobals.py +298 -0
  5. accelforge/_deprecate/_simanneal/simanneal.py +666 -0
  6. accelforge/_deprecate/_simanneal/tracking.py +105 -0
  7. accelforge/_deprecate/_simanneal/wrappers.py +218 -0
  8. accelforge/_deprecate/_simanneal2/__init__.py +7 -0
  9. accelforge/_deprecate/_simanneal2/simanneal.py +493 -0
  10. accelforge/_deprecate/_simanneal2/tracking.py +116 -0
  11. accelforge/_deprecate/compatibility_util.py +181 -0
  12. accelforge/_deprecate/layerdeduplication/__init__.py +2 -0
  13. accelforge/_deprecate/layerdeduplication/group_similar_einsums.py +160 -0
  14. accelforge/_deprecate/layerdeduplication/grouped_einsums.py +84 -0
  15. accelforge/_deprecate/mapping_filter_tags/__init__.py +2 -0
  16. accelforge/_deprecate/mapping_filter_tags/ffmt.py +212 -0
  17. accelforge/_deprecate/mapping_filter_tags/onesplit.py +24 -0
  18. accelforge/_deprecate/mapping_filter_tags/util.py +24 -0
  19. accelforge/_deprecate/tags.py +69 -0
  20. accelforge/_deprecate/viz/__init__.py +0 -0
  21. accelforge/_deprecate/viz/interactive.py +159 -0
  22. accelforge/_deprecate/viz/reservationtree.py +307 -0
  23. accelforge/_deprecate/viz/ski_slope.py +88 -0
  24. accelforge/_version.py +15 -0
  25. accelforge/examples.py +39 -0
  26. accelforge/frontend/__init__.py +10 -0
  27. accelforge/frontend/_binding.py +129 -0
  28. accelforge/frontend/_workload_isl/__init__.py +2 -0
  29. accelforge/frontend/_workload_isl/_isl.py +149 -0
  30. accelforge/frontend/_workload_isl/_symbolic.py +141 -0
  31. accelforge/frontend/arch copy.py +1544 -0
  32. accelforge/frontend/arch.py +1642 -0
  33. accelforge/frontend/config.py +63 -0
  34. accelforge/frontend/mapper/__init__.py +5 -0
  35. accelforge/frontend/mapper/ffm.py +126 -0
  36. accelforge/frontend/mapper/mapper.py +7 -0
  37. accelforge/frontend/mapper/metrics.py +30 -0
  38. accelforge/frontend/mapping/__init__.py +1 -0
  39. accelforge/frontend/mapping/mapping.py +1736 -0
  40. accelforge/frontend/model.py +14 -0
  41. accelforge/frontend/renames.py +150 -0
  42. accelforge/frontend/spec copy.py +230 -0
  43. accelforge/frontend/spec.py +301 -0
  44. accelforge/frontend/variables.py +12 -0
  45. accelforge/frontend/workload.py +952 -0
  46. accelforge/mapper/FFM/__init__.py +9 -0
  47. accelforge/mapper/FFM/_join_pmappings/__init__.py +0 -0
  48. accelforge/mapper/FFM/_join_pmappings/compatibility.py +653 -0
  49. accelforge/mapper/FFM/_join_pmappings/compress_pmappings.py +140 -0
  50. accelforge/mapper/FFM/_join_pmappings/join_pmappings.py +703 -0
  51. accelforge/mapper/FFM/_join_pmappings/pmapping_dataframe.py +901 -0
  52. accelforge/mapper/FFM/_join_pmappings/pmapping_group.py +337 -0
  53. accelforge/mapper/FFM/_make_pmappings/contraints/__init__.py +0 -0
  54. accelforge/mapper/FFM/_make_pmappings/contraints/constraints.py +360 -0
  55. accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/__init__.py +1 -0
  56. accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_loops.py +373 -0
  57. accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_pmapping_templates.py +463 -0
  58. accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_reservations.py +95 -0
  59. accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_storage_order.py +382 -0
  60. accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_storages.py +155 -0
  61. accelforge/mapper/FFM/_make_pmappings/make_pmappings.py +411 -0
  62. accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/__init__.py +1 -0
  63. accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/make_pmappings_from_templates.py +407 -0
  64. accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/make_tile_shapes.py +1681 -0
  65. accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/run_model.py +170 -0
  66. accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/symbol_relations.py +174 -0
  67. accelforge/mapper/FFM/_make_pmappings/pmapper_job.py +282 -0
  68. accelforge/mapper/FFM/_pareto_df/df_convention.py +273 -0
  69. accelforge/mapper/FFM/_pareto_df/pareto copy.py +836 -0
  70. accelforge/mapper/FFM/_pareto_df/pareto.py +508 -0
  71. accelforge/mapper/FFM/data.py +61 -0
  72. accelforge/mapper/FFM/main copy.py +236 -0
  73. accelforge/mapper/FFM/main.py +208 -0
  74. accelforge/mapper/FFM/mappings.py +510 -0
  75. accelforge/mapper/FFM/pmappings.py +310 -0
  76. accelforge/mapper/__init__.py +4 -0
  77. accelforge/mapper.py +0 -0
  78. accelforge/model/__init__.py +1 -0
  79. accelforge/model/_looptree/__init__.py +0 -0
  80. accelforge/model/_looptree/accesses.py +335 -0
  81. accelforge/model/_looptree/capacity/__init__.py +1 -0
  82. accelforge/model/_looptree/capacity/aggregators.py +36 -0
  83. accelforge/model/_looptree/capacity/capacity.py +47 -0
  84. accelforge/model/_looptree/energy.py +150 -0
  85. accelforge/model/_looptree/equivalent_ranks.py +29 -0
  86. accelforge/model/_looptree/latency/__init__.py +1 -0
  87. accelforge/model/_looptree/latency/latency.py +98 -0
  88. accelforge/model/_looptree/latency/memory.py +120 -0
  89. accelforge/model/_looptree/latency/processors.py +92 -0
  90. accelforge/model/_looptree/mapping_utilities.py +71 -0
  91. accelforge/model/_looptree/reuse/__init__.py +4 -0
  92. accelforge/model/_looptree/reuse/isl/__init__.py +1 -0
  93. accelforge/model/_looptree/reuse/isl/des.py +59 -0
  94. accelforge/model/_looptree/reuse/isl/isl_functions.py +374 -0
  95. accelforge/model/_looptree/reuse/isl/mapping_to_isl/__init__.py +4 -0
  96. accelforge/model/_looptree/reuse/isl/mapping_to_isl/analyze_mapping.py +297 -0
  97. accelforge/model/_looptree/reuse/isl/mapping_to_isl/skews_from_mapping.py +236 -0
  98. accelforge/model/_looptree/reuse/isl/mapping_to_isl/tiling.py +685 -0
  99. accelforge/model/_looptree/reuse/isl/mapping_to_isl/types.py +188 -0
  100. accelforge/model/_looptree/reuse/isl/spatial.py +260 -0
  101. accelforge/model/_looptree/reuse/isl/temporal.py +182 -0
  102. accelforge/model/_looptree/reuse/symbolic/__init__.py +1 -0
  103. accelforge/model/_looptree/reuse/symbolic/symbolic copy 2.py +1346 -0
  104. accelforge/model/_looptree/reuse/symbolic/symbolic copy.py +1408 -0
  105. accelforge/model/_looptree/reuse/symbolic/symbolic.py +1396 -0
  106. accelforge/model/_looptree/run.py +122 -0
  107. accelforge/model/_looptree/types.py +26 -0
  108. accelforge/model/_looptree/visualization/__init__.py +0 -0
  109. accelforge/model/_looptree/visualization/occupancy.py +11 -0
  110. accelforge/model/main.py +222 -0
  111. accelforge/plotting/__init__.py +2 -0
  112. accelforge/plotting/mappings.py +219 -0
  113. accelforge/plotting/specs.py +57 -0
  114. accelforge/util/__init__.py +4 -0
  115. accelforge/util/_base_analysis_types.py +24 -0
  116. accelforge/util/_basetypes.py +1089 -0
  117. accelforge/util/_frozenset.py +36 -0
  118. accelforge/util/_isl.py +29 -0
  119. accelforge/util/_itertools.py +14 -0
  120. accelforge/util/_mathfuncs.py +57 -0
  121. accelforge/util/_parse_expressions.py +339 -0
  122. accelforge/util/_picklecache.py +32 -0
  123. accelforge/util/_setexpressions.py +268 -0
  124. accelforge/util/_sympy/__init__.py +0 -0
  125. accelforge/util/_sympy/broadcast_max.py +18 -0
  126. accelforge/util/_visualization.py +112 -0
  127. accelforge/util/_yaml.py +579 -0
  128. accelforge/util/parallel.py +193 -0
  129. accelforge-0.0.1.dist-info/METADATA +64 -0
  130. accelforge-0.0.1.dist-info/RECORD +258 -0
  131. accelforge-0.0.1.dist-info/WHEEL +5 -0
  132. accelforge-0.0.1.dist-info/licenses/LICENSE +19 -0
  133. accelforge-0.0.1.dist-info/top_level.txt +5 -0
  134. docs/_build/html/_sources/fastfusion.frontend.mapper.rst.txt +37 -0
  135. docs/_build/html/_sources/fastfusion.frontend.rst.txt +70 -0
  136. docs/_build/html/_sources/fastfusion.frontend.workload.rst.txt +21 -0
  137. docs/_build/html/_sources/fastfusion.mapper.FFM.rst.txt +37 -0
  138. docs/_build/html/_sources/fastfusion.mapper.rst.txt +18 -0
  139. docs/_build/html/_sources/fastfusion.rst.txt +20 -0
  140. docs/_build/html/_sources/fastfusion.util.rst.txt +21 -0
  141. docs/_build/html/_sources/index.rst.txt +87 -0
  142. docs/_build/html/_sources/modules.rst.txt +7 -0
  143. docs/_build/html/_sources/notes/citation.rst.txt +45 -0
  144. docs/_build/html/_sources/notes/definitions.rst.txt +43 -0
  145. docs/_build/html/_sources/notes/faqs.rst.txt +39 -0
  146. docs/_build/html/_sources/notes/modeling/accelerator_energy_latency.rst.txt +72 -0
  147. docs/_build/html/_sources/notes/modeling/component_energy_area.rst.txt +96 -0
  148. docs/_build/html/_sources/notes/modeling/mapping.rst.txt +100 -0
  149. docs/_build/html/_sources/notes/modeling.rst.txt +33 -0
  150. docs/_build/html/_sources/notes/parsing/arithmetic_parsing.rst.txt +136 -0
  151. docs/_build/html/_sources/notes/parsing/setexpressions.rst.txt +63 -0
  152. docs/_build/html/_sources/notes/parsing/yaml_parsing.rst.txt +176 -0
  153. docs/_build/html/_sources/notes/quickstart_and_installation.rst.txt +9 -0
  154. docs/_build/html/_sources/notes/spec/architecture.rst.txt +133 -0
  155. docs/_build/html/_sources/notes/spec/mapping.rst.txt +12 -0
  156. docs/_build/html/_sources/notes/spec/workload.rst.txt +83 -0
  157. docs/_build/html/_sources/notes/spec.rst.txt +36 -0
  158. docs/source/_ext/include_attrs.py +213 -0
  159. docs/source/_ext/include_docstring.py +364 -0
  160. docs/source/_ext/include_functions.py +154 -0
  161. docs/source/_ext/include_notebook.py +131 -0
  162. docs/source/_ext/include_yaml.py +119 -0
  163. docs/source/_ext/inherited_attributes.py +222 -0
  164. docs/source/_ext/paths.py +4 -0
  165. docs/source/conf.py +79 -0
  166. examples/arches/compute_in_memory/_include.yaml +74 -0
  167. examples/arches/compute_in_memory/_include_functions.py +229 -0
  168. examples/arches/compute_in_memory/_load_spec.py +57 -0
  169. examples/arches/compute_in_memory/components/c2c_multiplier.py +181 -0
  170. examples/arches/compute_in_memory/components/dac_c2c_r2r.py +605 -0
  171. examples/arches/compute_in_memory/components/misc.py +195 -0
  172. examples/arches/compute_in_memory/components/util/bit_functions.py +51 -0
  173. examples/arches/compute_in_memory/components/zero_comparator.py +92 -0
  174. examples/arches/compute_in_memory/isaac.yaml +233 -0
  175. examples/arches/compute_in_memory/memory_cells/ecram_demo.yaml +63 -0
  176. examples/arches/compute_in_memory/memory_cells/rram_example.yaml +63 -0
  177. examples/arches/compute_in_memory/memory_cells/rram_isaac_isca_2016.yaml +64 -0
  178. examples/arches/compute_in_memory/memory_cells/rram_neurosim_default.yaml +63 -0
  179. examples/arches/compute_in_memory/memory_cells/rram_raella_isca_2023.yaml +70 -0
  180. examples/arches/compute_in_memory/memory_cells/rram_wan_nature_2022.yaml +63 -0
  181. examples/arches/compute_in_memory/memory_cells/sram_colonnade_jssc_2021.yaml +63 -0
  182. examples/arches/compute_in_memory/memory_cells/sram_example.yaml +63 -0
  183. examples/arches/compute_in_memory/memory_cells/sram_jia_jssc_2020.yaml +63 -0
  184. examples/arches/compute_in_memory/memory_cells/sram_sinangil_jssc_2021.yaml +63 -0
  185. examples/arches/compute_in_memory/memory_cells/sram_wang_vlsi_2022.yaml +63 -0
  186. examples/arches/compute_in_memory/wang_vlsi_2022.yaml +289 -0
  187. examples/arches/eyeriss.yaml +68 -0
  188. examples/arches/fanout_variations/at_glb.yaml +31 -0
  189. examples/arches/fanout_variations/at_glb_with_fanout_node.yaml +34 -0
  190. examples/arches/fanout_variations/at_mac.yaml +31 -0
  191. examples/arches/fanout_variations/at_mac_with_constraints.yaml +38 -0
  192. examples/arches/fanout_variations/at_mac_with_fanout_node.yaml +34 -0
  193. examples/arches/nvdla.yaml +47 -0
  194. examples/arches/simple.yaml +28 -0
  195. examples/arches/tpu_v4i.yaml +67 -0
  196. examples/mappings/unfused_matmuls_to_simple.yaml +33 -0
  197. examples/misc/component_annotated.yaml +33 -0
  198. examples/workloads/gpt3_6.7B.yaml +124 -0
  199. examples/workloads/matmuls.yaml +20 -0
  200. examples/workloads/mobilenet_28.yaml +81 -0
  201. examples/workloads/mobilenet_various_separate.yaml +106 -0
  202. examples/workloads/three_matmuls_annotated.yaml +59 -0
  203. notebooks/.ipynb_checkpoints/fastfusion_arch_study_michael-checkpoint.ipynb +359 -0
  204. notebooks/compute_in_memory/_scripts.py +339 -0
  205. notebooks/compute_in_memory/isaac.guide.ipynb +270 -0
  206. notebooks/compute_in_memory/wang_vlsi_2022.ipynb +602 -0
  207. notebooks/paths.py +4 -0
  208. notebooks/tutorials/.ipynb_checkpoints/1_FFM-checkpoint.ipynb +3110 -0
  209. notebooks/tutorials/FFM.ipynb +3498 -0
  210. notebooks/tutorials/_include.py +48 -0
  211. notebooks/tutorials/component_energy_area.ipynb +363 -0
  212. tests/Q_mapping.yaml +38 -0
  213. tests/__init__.py +0 -0
  214. tests/conv.mapping.yaml +27 -0
  215. tests/conv.workload.yaml +13 -0
  216. tests/conv_sym.mapping.yaml +43 -0
  217. tests/copy.mapping.yaml +35 -0
  218. tests/copy.workload.yaml +15 -0
  219. tests/distribuffers/__init__.py +0 -0
  220. tests/distribuffers/multicast/test_cases.yaml +482 -0
  221. tests/distribuffers/spec/binding/valid_bindings.yaml +97 -0
  222. tests/distribuffers/spec/distributed.yaml +100 -0
  223. tests/distribuffers/spec/logical_arch.yaml +32 -0
  224. tests/distribuffers/spec/physical_arch.yaml +69 -0
  225. tests/distribuffers/test_binding.py +48 -0
  226. tests/frontend/__init__.py +0 -0
  227. tests/frontend/test_mapping_viz.py +52 -0
  228. tests/mapper/__init__.py +0 -0
  229. tests/mapper/configs/conv1d/conv1d.mapping.yaml +31 -0
  230. tests/mapper/configs/conv1d/conv1d.workload.yaml +11 -0
  231. tests/mapper/configs/two_conv1d/two_conv1d.expected.yaml +38 -0
  232. tests/mapper/configs/two_conv1d/two_conv1d.mapping.yaml +54 -0
  233. tests/mapper/configs/two_conv1d/two_conv1d.workload.yaml +19 -0
  234. tests/mapper/test_mapping_to_isl.py +90 -0
  235. tests/mapper/test_spatial_reuse_analysis.py +67 -0
  236. tests/mapper/test_temporal_reuse_analysis.py +56 -0
  237. tests/mapper/util.py +58 -0
  238. tests/matmul.mapping.yaml +29 -0
  239. tests/matmul.workload.yaml +12 -0
  240. tests/matmul_spatial.mapping.yaml +44 -0
  241. tests/mha.renames.yaml +65 -0
  242. tests/mha.workload.yaml +67 -0
  243. tests/mha.yaml +59 -0
  244. tests/mha_full.workload.yaml +67 -0
  245. tests/mobilenet.workload.yaml +35 -0
  246. tests/mobilenet_long.workload.yaml +64 -0
  247. tests/pmappingcache.py +24 -0
  248. tests/processing_stage.arch.yaml +40 -0
  249. tests/snowcat.arch.yaml +36 -0
  250. tests/test_ffm_join_pmappings.py +106 -0
  251. tests/test_ffm_make_pmappings.py +82 -0
  252. tests/test_ffm_make_tile_shapes.py +49 -0
  253. tests/test_mapper.py +100 -0
  254. tests/test_model.py +37 -0
  255. tests/test_plotting.py +72 -0
  256. tests/test_processing_stage.py +46 -0
  257. tests/test_symbolic_model.py +248 -0
  258. tests/test_workload.py +141 -0
@@ -0,0 +1,129 @@
1
+ """
2
+ Defines Pydantic models to handle Binding Specifications that relate logical to
3
+ physical architectures.
4
+ """
5
+
6
+ from abc import abstractmethod
7
+ from typing import Dict, Tuple
8
+
9
+ from pydantic import StrictFloat
10
+ import islpy as isl
11
+
12
+ from accelforge.util._basetypes import ParsableDict, ParsableList, ParsableModel
13
+
14
+
15
class Domain(ParsableModel):
    """
    Represents an architecture dangling reference of the binding.

    Abstract base for binding domains: both properties below are declared
    abstract and their bodies only raise ``NotImplementedError``, so concrete
    subclasses are expected to override them.
    """

    # Name of the domain.
    name: str

    @property
    @abstractmethod
    def isl_space(self) -> isl.Space:
        """Gets the domain as an isl.Space"""
        raise NotImplementedError(f"{type(self)} has not implemented isl_space")

    @property
    @abstractmethod
    def isl_universe(self) -> isl.Set:
        """Gets the domain as an isl.Set"""
        raise NotImplementedError(f"{type(self)} has not implemented isl_universe")
33
+
34
+
35
class LogicalDomain(Domain):
    """
    Represents the logical architecture domain space of logical dims × tensor ranks.

    The domain is modelled as an ISL map space whose input tuple is ``ranks``
    and whose output tuple is ``l_dims``; the output tuple is named
    ``l_<name>_dims``.
    """

    # Tensor rank names forming the input tuple of the map space.
    # Declared as a variadic tuple: ``Tuple[str]`` means "exactly one string",
    # which does not describe the seven-element default and would presumably
    # make Pydantic validation reject an explicitly supplied list of ranks.
    ranks: Tuple[str, ...] = ("c", "h", "w", "p", "q", "r", "s")
    # Logical dimension names forming the output tuple of the map space.
    l_dims: ParsableList[str]

    @property
    def isl_space(self) -> isl.Space:
        """Gets the map space ``ranks -> l_dims`` with a named output tuple."""
        return isl.Space.create_from_names(
            isl.DEFAULT_CONTEXT, in_=self.ranks, out=self.l_dims
        ).set_tuple_name(isl.dim_type.out, f"l_{self.name}_dims")

    @property
    def isl_universe(self) -> isl.Map:
        """Gets the unconstrained isl.Map over ``isl_space``."""
        return isl.Map.universe(self.isl_space)
52
+
53
+
54
class PhysicalDomain(Domain):
    """
    Represents the physical architecture domain space of physical dims.
    The physical space is defined as the physical architecture dims.
    """

    # Names of the physical dimensions; they become the set dimensions of the
    # ISL space built below.
    p_dims: ParsableList[str]

    @property
    def isl_space(self) -> isl.Space:
        """Gets the set space over ``p_dims``, with tuple name ``p_<name>_dims``."""
        return isl.Space.create_from_names(
            isl.DEFAULT_CONTEXT, set=self.p_dims
        ).set_tuple_name(isl.dim_type.set, f"p_{self.name}_dims")

    @property
    def isl_universe(self) -> isl.Set:
        """Gets the unconstrained isl.Set over ``isl_space``."""
        return isl.Set.universe(self.isl_space)
71
+
72
+
73
class BindingNode(ParsableModel):
    """
    How a logical architecture is implemented on a particular physical architecture
    for a particular hardware level. Represents an injection relation between points
    in logical to physical space.

    The logical space is defined as logical architecture dims × tensor dims.
    The physical space is defined as the physical architecture dims.
    """

    # Logical side of the binding (map space: ranks -> logical dims).
    logical: LogicalDomain
    # Physical side of the binding (set space over the physical dims).
    physical: PhysicalDomain
    # Maps a key to a constraint string. The key names the input tuple
    # (``<key>_ranks``); the string is spliced into the ISL text as the
    # constraint part of the map.
    relations: ParsableDict[str, str]

    @property
    def isl_relations(self) -> Dict[str, isl.Map]:
        """
        Converts the logical, physical, and binding relation strings into
        isl.Maps representing the bindings at this binding node.

        Returns a dictionary with one isl.Map per key of ``relations``. The
        maps are rebuilt from text on every access.
        """

        def islify_relation(key: str) -> isl.Map:
            """Converts a relation at a given key into isl"""
            relation: str = self.relations[key]
            # Name the input (ranks) tuple after the relation key.
            logical_space: isl.Space = self.logical.isl_space.set_tuple_name(
                isl.dim_type.in_, f"{key}_ranks"
            )

            # Wrap the logical map space so it can serve as the domain of a
            # map whose range is the physical space.
            binding_space: isl.Space = logical_space.wrap().map_from_domain_and_range(
                range=self.physical.isl_space,
            )

            # Simple bodge to get the binding space into a real space
            # The relation is inserted as ": <relation> " just before the last
            # character of the printed space (presumably the closing brace —
            # this relies on the exact output format of ``to_str``).
            binding_str: str = binding_space.to_str()
            binding_str: str = f"{binding_str[:-1]}: {relation} {binding_str[-1]}"

            binding: isl.Map = isl.Map.read_from_str(
                ctx=isl.DEFAULT_CONTEXT, str=binding_str
            )

            return binding

        isl_relations: Dict[str, isl.Map] = {
            key: islify_relation(key) for key in self.relations
        }

        return isl_relations
120
+
121
+
122
class Binding(ParsableModel):
    """
    A collection of binding nodes that fully specifies a relation between the
    logical and physical space.
    """

    # NOTE(review): the ``version`` field below is disabled in this release.
    # version: StrictFloat
    # The binding nodes that make up this binding.
    nodes: ParsableList[BindingNode]
@@ -0,0 +1,2 @@
1
+ from accelforge.frontend.workload import *
2
+ from ._isl import get_rank_variable_bounds
@@ -0,0 +1,149 @@
1
+ import math
2
+ import islpy as isl
3
+
4
+ from accelforge.frontend.renames import RankVariable
5
+ from accelforge.frontend.workload import Workload, TensorName, Einsum, EinsumName
6
+
7
+
8
def get_einsum_operation_space(workload: Workload, einsum_name: str) -> isl.Set:
    """
    Return isl.Set of all operations in an einsum.

    The set is named ``<einsum_name>_operation``, has one dimension per rank
    variable of the Einsum, and is constrained by the string returned from
    ``workload.get_iteration_space_shape_isl_string``.

    Raises:
        Exception: If the set cannot be built from the generated ISL text.
            The underlying error is attached as the cause.
    """
    einsum_shape = workload.get_iteration_space_shape_isl_string(einsum_name)
    rank_variable_names = ",".join(
        map(str, workload.einsums[einsum_name].rank_variables)
    )
    try:
        return isl.Set(
            f"{{ {einsum_name}_operation[{rank_variable_names}] : {einsum_shape} }}"
        )
    except Exception as err:
        # Only catch Exception so KeyboardInterrupt/SystemExit still propagate,
        # and chain the original error so the ISL diagnostic is not lost.
        raise Exception(
            f"Error creating isl.Set for {einsum_name}: {einsum_shape}"
        ) from err
20
+
21
+
22
def get_dim_bounds(isl_set: isl.Set) -> list[int]:
    """
    Return the extent of every set dimension of ``isl_set``.

    For each dimension the extent is ``max - min + 1`` (the maximum is
    inclusive), converted to a Python number.

    Raises:
        Exception: If an extent cannot be converted to a Python number; the
            underlying error is attached as the cause.
    """
    bounds = []
    for i in range(isl_set.dim(isl.dim_type.set)):
        max_val = isl_set.dim_max_val(i)
        min_val = isl_set.dim_min_val(i)
        shape = max_val - min_val + 1  # max is inclusive
        try:
            bounds.append(shape.to_python())
        except Exception as err:
            # Only catch Exception so KeyboardInterrupt/SystemExit still
            # propagate, and keep the conversion error as the cause.
            raise Exception(
                f"Shape is not an integer. Are all rank variables bounded? "
                f"Shape {shape} for rank variable {i} in {isl_set}"
            ) from err
    return bounds
36
+
37
+
38
def get_rank_variable_bounds(
    workload: Workload, einsum_name: EinsumName
) -> dict[RankVariable, int]:
    """
    Map every rank variable of an Einsum to the extent of its dimension.

    Extents come from the Einsum's operation space and are paired with the
    rank variables positionally.
    """
    extents = get_dim_bounds(get_einsum_operation_space(workload, einsum_name))
    rank_variables = workload.einsums[einsum_name].rank_variables
    return dict(zip(rank_variables, extents))
50
+
51
+
52
def get_projection_multi_aff(einsum: Einsum, tensor: TensorName) -> isl.MultiAff:
    """
    Build the isl.MultiAff that projects an Einsum's operations onto a tensor.

    The domain tuple is ``<einsum.name>_operation[...]`` over the rank
    variables; the range tuple is ``<tensor>[...]`` with one ``rank=expr``
    entry per item of the tensor access projection.
    """
    domain_vars = einsum.rank_variables
    accesses = einsum.tensor_accesses[tensor].projection

    domain_text = ",".join(str(rank_variable) for rank_variable in domain_vars)
    range_entries = [f"{rank}={expression}" for rank, expression in accesses.items()]
    range_text = ", ".join(range_entries)

    isl_text = f"{{ {einsum.name}_operation[{domain_text}] -> {tensor}[{range_text}] }}"
    return isl.MultiAff(isl_text)
68
+
69
+
70
def get_projection_map(einsum: Einsum, tensor: TensorName) -> isl.Map:
    """Build the Einsum-to-tensor projection and return it as an isl.Map."""
    multi_aff = get_projection_multi_aff(einsum, tensor)
    return multi_aff.as_map()
73
+
74
+
75
def get_tensor_data_space(workload: Workload, tensor: TensorName) -> isl.Set:
    """
    Compute a tensor's data space from its 'canonical' Einsums.

    The canonical Einsums are the writers of the tensor when at least one
    Einsum outputs it, and its readers otherwise. Each canonical Einsum's
    operation space is projected onto the tensor, and the projected sets are
    intersected.

    Returns ``None`` when there is no canonical Einsum at all.
    """
    canonical_einsums = workload.einsums_with_tensor_as_output(tensor)
    if len(canonical_einsums) == 0:
        # Never written: fall back to the Einsums that read the tensor.
        canonical_einsums = workload.einsums_with_tensor_as_input(tensor)

    data_space = None
    for einsum in canonical_einsums:
        operations = get_einsum_operation_space(workload, einsum.name)
        projected = operations.apply(get_projection_map(einsum, tensor))
        data_space = projected if data_space is None else data_space.intersect(projected)

    return data_space
103
+
104
+
105
def _card_box(data_space: isl.Set) -> int:
    """
    Count the points of a box-shaped set.

    The count is the product over all set dimensions of ``max - min + 1``.

    Raises:
        ValueError: If the minimum or maximum of a dimension is not constant.
    """
    extents = []
    n_dims = data_space.dim(isl.dim_type.set)
    for axis in range(n_dims):
        lower = data_space.dim_min(axis)
        upper = data_space.dim_max(axis)

        # Guard: both bounds must be constants for the set to be a box.
        if not (lower.is_cst() and upper.is_cst()):
            raise ValueError(f"Data space is not rectangular: {data_space}")

        lower_val = lower.as_aff().get_constant_val().to_python()
        upper_val = upper.as_aff().get_constant_val().to_python()
        extents.append(upper_val - lower_val + 1)

    return math.prod(extents)
120
+
121
+
122
# Message template for the RuntimeError raised by the size helpers in this
# module when a set is not box-shaped and has no ``card`` attribute. The
# ``{space}`` placeholder is filled in with ``str.format``.
ERRMSG = """ Non-box-shaped sets are not supported. This happens if ISL is installed without
Barvinok support. Please install ISL with Barvinok support, or use workloads with
rectangular data spaces and operation spaces. Non-rectangular spaces occur when the
workload contains complex expressions involving the rank variables, such as
multi-rank-variable inequalities.
Offending space: {space}.
"""
129
+
130
+
131
def get_tensor_size(workload: Workload, tensor: TensorName):
    """
    Return the number of elements in a tensor's data space.

    Box-shaped spaces are counted directly from their bounds. Other spaces
    are counted through ``isl.PwQPolynomial.card``, evaluated at a sample
    point of its domain.

    Raises:
        RuntimeError: If the space is not a box and has no ``card`` attribute.
    """
    space = get_tensor_data_space(workload, tensor)
    if not space.is_box():
        if hasattr(space, "card"):
            cardinality = isl.PwQPolynomial.card(space)
            sample = cardinality.domain().sample_point()
            return cardinality.eval(sample).to_python()
        raise RuntimeError(ERRMSG.format(space=str(space)))
    return _card_box(space)
140
+
141
+
142
def get_operation_space_size(workload: Workload, einsum_name: str):
    """
    Return the number of operations in an Einsum's operation space.

    Box-shaped spaces are counted directly from their bounds. Other spaces
    are counted through ``isl.PwQPolynomial.card``, evaluated at a sample
    point of its domain.

    Raises:
        RuntimeError: If the space is not a box and has no ``card`` attribute.
    """
    space = get_einsum_operation_space(workload, einsum_name)
    if space.is_box():
        size = _card_box(space)
    elif hasattr(space, "card"):
        cardinality = isl.PwQPolynomial.card(space)
        sample = cardinality.domain().sample_point()
        size = cardinality.eval(sample).to_python()
    else:
        raise RuntimeError(ERRMSG.format(space=str(space)))
    return size
@@ -0,0 +1,141 @@
1
+ from dataclasses import dataclass
2
+ from typing import Any
3
+ from functools import reduce
4
+ from operator import mul
5
+
6
+ import sympy
7
+
8
+ from accelforge.frontend.workload import (
9
+ TensorName,
10
+ Einsum,
11
+ EinsumName,
12
+ Workload,
13
+ Rank,
14
+ RankVariable,
15
+ )
16
+ from ._isl import get_rank_variable_bounds
17
+
18
+
19
def get_projection_expr(einsum: Einsum, tensor: TensorName) -> dict[Rank, sympy.Expr]:
    """
    Parse a tensor's access projection into sympy expressions.

    Returns a dictionary from rank name to the parsed projection of that rank.

    NOTE(review): ``parse_expr`` evaluates its input as Python code, so the
    projection strings should come from a trusted workload description —
    confirm against callers.
    """
    projection = einsum.tensor_accesses[tensor].projection
    parsed = {}
    for rank_name, proj_str in projection.items():
        parsed[rank_name] = sympy.parsing.sympy_parser.parse_expr(proj_str)
    return parsed
25
+
26
+
27
class Irrelevant:
    """
    Relevancy marker: the rank variable appears in none of the tensor's rank
    projections. This is the initial verdict for every rank variable in
    ``get_rank_variable_relevancy``.
    """

    pass
29
+
30
+
31
@dataclass
class Relevant:
    """
    Relevancy marker: the rank variable appears in the projection of ``rank``
    and that projection is a single additive term.
    """

    # The rank whose projection contains the rank variable.
    rank: Any
34
+
35
+
36
@dataclass
class PartiallyRelevant:
    """
    Relevancy marker: the rank variable appears in the projection of ``rank``
    and that projection has more than one additive term.
    """

    # The rank whose projection contains the rank variable.
    rank: Any
39
+
40
+
41
def get_rank_variable_relevancy(einsum: Einsum, tensor: TensorName):
    """
    Classify each rank variable of an Einsum with respect to one tensor.

    For every rank variable, the tensor's rank projections are scanned in
    order and the first one containing the variable decides the verdict:
    ``Relevant`` if that projection is a single additive term,
    ``PartiallyRelevant`` otherwise. Variables found in no projection are
    ``Irrelevant``.

    Returns a dictionary from rank variable to its verdict object.
    """
    projection = einsum.tensor_accesses[tensor].projection
    relevancy = {}
    for rank_variable in einsum.rank_variables:
        verdict = Irrelevant()
        for rank_name, projection_str in projection.items():
            expr = sympy.parsing.sympy_parser.parse_expr(projection_str)
            term_count = len(sympy.Add.make_args(expr))

            if sympy.symbols(f"{rank_variable}") not in expr.free_symbols:
                continue

            # First projection mentioning the variable wins; stop scanning.
            verdict_type = Relevant if term_count == 1 else PartiallyRelevant
            verdict = verdict_type(rank=rank_name)
            break
        relevancy[rank_variable] = verdict
    return relevancy
63
+
64
+
65
def compute_dense_tile_occupancy(
    projection_expr: dict[str, sympy.Expr], rank_variable_shapes: dict
):
    """
    Compute the number of elements in a dense tile of a tensor.

    Every rank variable is substituted with its last index (``shape - 1``);
    each rank projection evaluated there, plus one, is the extent of that
    rank, and the occupancy is the product of all rank extents.

    Args:
        projection_expr: Rank name -> projection expression of that rank.
        rank_variable_shapes: Rank variable -> shape (extent) of that variable.

    Returns:
        The product of the rank extents, or ``1`` when ``projection_expr`` is
        empty (a tensor with no ranks holds a single element).
    """
    substitutions = [
        (rank_variable, rank_variable_shape - 1)
        for rank_variable, rank_variable_shape in rank_variable_shapes.items()
    ]
    rank_extents = [
        index_expr.subs(substitutions) + 1 for index_expr in projection_expr.values()
    ]
    if not rank_extents:
        # reduce() without an initial value raises TypeError on an empty list.
        return 1
    return reduce(mul, rank_extents)
76
+
77
+
78
def compute_rank_occupancy(projection_expr: sympy.Expr, rank_variable_shapes: dict):
    """
    Compute the extent of a single rank.

    Each rank variable is replaced by its last index (``shape - 1``) in the
    projection expression; the result plus one is returned.
    """
    last_indices = []
    for rank_variable, extent in rank_variable_shapes.items():
        last_indices.append((rank_variable, extent - 1))
    largest_index = projection_expr.subs(last_indices)
    return largest_index + 1
84
+
85
+
86
def get_stride_and_halo_of_einsum(
    einsum_name: str,
    workload: Workload,
    rank_variable_bounds: dict[RankVariable, int] | None = None,
) -> dict[TensorName, dict[tuple[Rank, RankVariable], tuple[int, int]]]:
    """
    Get stride and halo (initial delta) for an Einsum in workload.

    Returns dictionary mapping tensor to another dictionary mapping
    (rank, rank_var) to the stride and halo.

    The stride is the coefficient of ``rank_var`` in the rank's projection.
    The halo is the rank occupancy computed with ``rank_var`` collapsed to an
    extent of one, minus one.

    Args:
        einsum_name: Name of the Einsum to analyse.
        workload: Workload containing the Einsum.
        rank_variable_bounds: Optional rank variable -> bound mapping. When
            omitted, the bounds are derived from the workload. The mapping is
            never modified.

    Raises:
        KeyError: If a rank variable used by a tensor access has no bound.
    """
    stride_and_halo = {}
    einsum = workload.einsums[einsum_name]
    if rank_variable_bounds is None:
        shape = get_rank_variable_bounds(workload, einsum_name)
    else:
        shape = rank_variable_bounds
    for tensor in einsum.tensor_names:
        stride_and_halo[tensor] = {}
        tensor_stride_and_halo = stride_and_halo[tensor]

        projection = get_projection_expr(einsum, tensor)
        tensor_accesses = einsum.tensor_accesses[tensor]
        for rank, rank_vars in tensor_accesses.rank2rank_variables.items():
            rank_projection = projection[rank]
            for rank_var in rank_vars:
                stride = rank_projection.coeff(rank_var)

                if rank_var not in shape:
                    raise KeyError(rank_var)
                # Collapse this rank variable in a private copy so the
                # caller's bounds are never left modified, even if the
                # occupancy computation raises.
                collapsed_shape = dict(shape)
                collapsed_shape[rank_var] = 1
                halo = compute_rank_occupancy(rank_projection, collapsed_shape) - 1

                tensor_stride_and_halo[(rank, rank_var)] = (stride, halo)
    return stride_and_halo
122
+
123
+
124
def get_stride_and_halo(
    workload: Workload,
) -> dict[
    tuple[EinsumName, TensorName],
    dict[tuple[Rank, RankVariable], tuple[int, int]],
]:
    """
    Collect stride and halo (initial delta) for every Einsum in the workload.

    The per-Einsum results are flattened into one dictionary keyed by
    (Einsum name, tensor); each value maps (rank, rank_var) to the stride
    and halo.
    """
    return {
        (einsum.name, tensor): per_rank
        for einsum in workload.einsums
        for tensor, per_rank in get_stride_and_halo_of_einsum(
            einsum.name, workload
        ).items()
    }