accelforge-0.0.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (258)
  1. accelforge/__init__.py +21 -0
  2. accelforge/_accelerated_imports.py +16 -0
  3. accelforge/_deprecate/_simanneal/evalmapping.py +271 -0
  4. accelforge/_deprecate/_simanneal/mapspaceglobals.py +298 -0
  5. accelforge/_deprecate/_simanneal/simanneal.py +666 -0
  6. accelforge/_deprecate/_simanneal/tracking.py +105 -0
  7. accelforge/_deprecate/_simanneal/wrappers.py +218 -0
  8. accelforge/_deprecate/_simanneal2/__init__.py +7 -0
  9. accelforge/_deprecate/_simanneal2/simanneal.py +493 -0
  10. accelforge/_deprecate/_simanneal2/tracking.py +116 -0
  11. accelforge/_deprecate/compatibility_util.py +181 -0
  12. accelforge/_deprecate/layerdeduplication/__init__.py +2 -0
  13. accelforge/_deprecate/layerdeduplication/group_similar_einsums.py +160 -0
  14. accelforge/_deprecate/layerdeduplication/grouped_einsums.py +84 -0
  15. accelforge/_deprecate/mapping_filter_tags/__init__.py +2 -0
  16. accelforge/_deprecate/mapping_filter_tags/ffmt.py +212 -0
  17. accelforge/_deprecate/mapping_filter_tags/onesplit.py +24 -0
  18. accelforge/_deprecate/mapping_filter_tags/util.py +24 -0
  19. accelforge/_deprecate/tags.py +69 -0
  20. accelforge/_deprecate/viz/__init__.py +0 -0
  21. accelforge/_deprecate/viz/interactive.py +159 -0
  22. accelforge/_deprecate/viz/reservationtree.py +307 -0
  23. accelforge/_deprecate/viz/ski_slope.py +88 -0
  24. accelforge/_version.py +15 -0
  25. accelforge/examples.py +39 -0
  26. accelforge/frontend/__init__.py +10 -0
  27. accelforge/frontend/_binding.py +129 -0
  28. accelforge/frontend/_workload_isl/__init__.py +2 -0
  29. accelforge/frontend/_workload_isl/_isl.py +149 -0
  30. accelforge/frontend/_workload_isl/_symbolic.py +141 -0
  31. accelforge/frontend/arch copy.py +1544 -0
  32. accelforge/frontend/arch.py +1642 -0
  33. accelforge/frontend/config.py +63 -0
  34. accelforge/frontend/mapper/__init__.py +5 -0
  35. accelforge/frontend/mapper/ffm.py +126 -0
  36. accelforge/frontend/mapper/mapper.py +7 -0
  37. accelforge/frontend/mapper/metrics.py +30 -0
  38. accelforge/frontend/mapping/__init__.py +1 -0
  39. accelforge/frontend/mapping/mapping.py +1736 -0
  40. accelforge/frontend/model.py +14 -0
  41. accelforge/frontend/renames.py +150 -0
  42. accelforge/frontend/spec copy.py +230 -0
  43. accelforge/frontend/spec.py +301 -0
  44. accelforge/frontend/variables.py +12 -0
  45. accelforge/frontend/workload.py +952 -0
  46. accelforge/mapper/FFM/__init__.py +9 -0
  47. accelforge/mapper/FFM/_join_pmappings/__init__.py +0 -0
  48. accelforge/mapper/FFM/_join_pmappings/compatibility.py +653 -0
  49. accelforge/mapper/FFM/_join_pmappings/compress_pmappings.py +140 -0
  50. accelforge/mapper/FFM/_join_pmappings/join_pmappings.py +703 -0
  51. accelforge/mapper/FFM/_join_pmappings/pmapping_dataframe.py +901 -0
  52. accelforge/mapper/FFM/_join_pmappings/pmapping_group.py +337 -0
  53. accelforge/mapper/FFM/_make_pmappings/contraints/__init__.py +0 -0
  54. accelforge/mapper/FFM/_make_pmappings/contraints/constraints.py +360 -0
  55. accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/__init__.py +1 -0
  56. accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_loops.py +373 -0
  57. accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_pmapping_templates.py +463 -0
  58. accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_reservations.py +95 -0
  59. accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_storage_order.py +382 -0
  60. accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_storages.py +155 -0
  61. accelforge/mapper/FFM/_make_pmappings/make_pmappings.py +411 -0
  62. accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/__init__.py +1 -0
  63. accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/make_pmappings_from_templates.py +407 -0
  64. accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/make_tile_shapes.py +1681 -0
  65. accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/run_model.py +170 -0
  66. accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/symbol_relations.py +174 -0
  67. accelforge/mapper/FFM/_make_pmappings/pmapper_job.py +282 -0
  68. accelforge/mapper/FFM/_pareto_df/df_convention.py +273 -0
  69. accelforge/mapper/FFM/_pareto_df/pareto copy.py +836 -0
  70. accelforge/mapper/FFM/_pareto_df/pareto.py +508 -0
  71. accelforge/mapper/FFM/data.py +61 -0
  72. accelforge/mapper/FFM/main copy.py +236 -0
  73. accelforge/mapper/FFM/main.py +208 -0
  74. accelforge/mapper/FFM/mappings.py +510 -0
  75. accelforge/mapper/FFM/pmappings.py +310 -0
  76. accelforge/mapper/__init__.py +4 -0
  77. accelforge/mapper.py +0 -0
  78. accelforge/model/__init__.py +1 -0
  79. accelforge/model/_looptree/__init__.py +0 -0
  80. accelforge/model/_looptree/accesses.py +335 -0
  81. accelforge/model/_looptree/capacity/__init__.py +1 -0
  82. accelforge/model/_looptree/capacity/aggregators.py +36 -0
  83. accelforge/model/_looptree/capacity/capacity.py +47 -0
  84. accelforge/model/_looptree/energy.py +150 -0
  85. accelforge/model/_looptree/equivalent_ranks.py +29 -0
  86. accelforge/model/_looptree/latency/__init__.py +1 -0
  87. accelforge/model/_looptree/latency/latency.py +98 -0
  88. accelforge/model/_looptree/latency/memory.py +120 -0
  89. accelforge/model/_looptree/latency/processors.py +92 -0
  90. accelforge/model/_looptree/mapping_utilities.py +71 -0
  91. accelforge/model/_looptree/reuse/__init__.py +4 -0
  92. accelforge/model/_looptree/reuse/isl/__init__.py +1 -0
  93. accelforge/model/_looptree/reuse/isl/des.py +59 -0
  94. accelforge/model/_looptree/reuse/isl/isl_functions.py +374 -0
  95. accelforge/model/_looptree/reuse/isl/mapping_to_isl/__init__.py +4 -0
  96. accelforge/model/_looptree/reuse/isl/mapping_to_isl/analyze_mapping.py +297 -0
  97. accelforge/model/_looptree/reuse/isl/mapping_to_isl/skews_from_mapping.py +236 -0
  98. accelforge/model/_looptree/reuse/isl/mapping_to_isl/tiling.py +685 -0
  99. accelforge/model/_looptree/reuse/isl/mapping_to_isl/types.py +188 -0
  100. accelforge/model/_looptree/reuse/isl/spatial.py +260 -0
  101. accelforge/model/_looptree/reuse/isl/temporal.py +182 -0
  102. accelforge/model/_looptree/reuse/symbolic/__init__.py +1 -0
  103. accelforge/model/_looptree/reuse/symbolic/symbolic copy 2.py +1346 -0
  104. accelforge/model/_looptree/reuse/symbolic/symbolic copy.py +1408 -0
  105. accelforge/model/_looptree/reuse/symbolic/symbolic.py +1396 -0
  106. accelforge/model/_looptree/run.py +122 -0
  107. accelforge/model/_looptree/types.py +26 -0
  108. accelforge/model/_looptree/visualization/__init__.py +0 -0
  109. accelforge/model/_looptree/visualization/occupancy.py +11 -0
  110. accelforge/model/main.py +222 -0
  111. accelforge/plotting/__init__.py +2 -0
  112. accelforge/plotting/mappings.py +219 -0
  113. accelforge/plotting/specs.py +57 -0
  114. accelforge/util/__init__.py +4 -0
  115. accelforge/util/_base_analysis_types.py +24 -0
  116. accelforge/util/_basetypes.py +1089 -0
  117. accelforge/util/_frozenset.py +36 -0
  118. accelforge/util/_isl.py +29 -0
  119. accelforge/util/_itertools.py +14 -0
  120. accelforge/util/_mathfuncs.py +57 -0
  121. accelforge/util/_parse_expressions.py +339 -0
  122. accelforge/util/_picklecache.py +32 -0
  123. accelforge/util/_setexpressions.py +268 -0
  124. accelforge/util/_sympy/__init__.py +0 -0
  125. accelforge/util/_sympy/broadcast_max.py +18 -0
  126. accelforge/util/_visualization.py +112 -0
  127. accelforge/util/_yaml.py +579 -0
  128. accelforge/util/parallel.py +193 -0
  129. accelforge-0.0.1.dist-info/METADATA +64 -0
  130. accelforge-0.0.1.dist-info/RECORD +258 -0
  131. accelforge-0.0.1.dist-info/WHEEL +5 -0
  132. accelforge-0.0.1.dist-info/licenses/LICENSE +19 -0
  133. accelforge-0.0.1.dist-info/top_level.txt +5 -0
  134. docs/_build/html/_sources/fastfusion.frontend.mapper.rst.txt +37 -0
  135. docs/_build/html/_sources/fastfusion.frontend.rst.txt +70 -0
  136. docs/_build/html/_sources/fastfusion.frontend.workload.rst.txt +21 -0
  137. docs/_build/html/_sources/fastfusion.mapper.FFM.rst.txt +37 -0
  138. docs/_build/html/_sources/fastfusion.mapper.rst.txt +18 -0
  139. docs/_build/html/_sources/fastfusion.rst.txt +20 -0
  140. docs/_build/html/_sources/fastfusion.util.rst.txt +21 -0
  141. docs/_build/html/_sources/index.rst.txt +87 -0
  142. docs/_build/html/_sources/modules.rst.txt +7 -0
  143. docs/_build/html/_sources/notes/citation.rst.txt +45 -0
  144. docs/_build/html/_sources/notes/definitions.rst.txt +43 -0
  145. docs/_build/html/_sources/notes/faqs.rst.txt +39 -0
  146. docs/_build/html/_sources/notes/modeling/accelerator_energy_latency.rst.txt +72 -0
  147. docs/_build/html/_sources/notes/modeling/component_energy_area.rst.txt +96 -0
  148. docs/_build/html/_sources/notes/modeling/mapping.rst.txt +100 -0
  149. docs/_build/html/_sources/notes/modeling.rst.txt +33 -0
  150. docs/_build/html/_sources/notes/parsing/arithmetic_parsing.rst.txt +136 -0
  151. docs/_build/html/_sources/notes/parsing/setexpressions.rst.txt +63 -0
  152. docs/_build/html/_sources/notes/parsing/yaml_parsing.rst.txt +176 -0
  153. docs/_build/html/_sources/notes/quickstart_and_installation.rst.txt +9 -0
  154. docs/_build/html/_sources/notes/spec/architecture.rst.txt +133 -0
  155. docs/_build/html/_sources/notes/spec/mapping.rst.txt +12 -0
  156. docs/_build/html/_sources/notes/spec/workload.rst.txt +83 -0
  157. docs/_build/html/_sources/notes/spec.rst.txt +36 -0
  158. docs/source/_ext/include_attrs.py +213 -0
  159. docs/source/_ext/include_docstring.py +364 -0
  160. docs/source/_ext/include_functions.py +154 -0
  161. docs/source/_ext/include_notebook.py +131 -0
  162. docs/source/_ext/include_yaml.py +119 -0
  163. docs/source/_ext/inherited_attributes.py +222 -0
  164. docs/source/_ext/paths.py +4 -0
  165. docs/source/conf.py +79 -0
  166. examples/arches/compute_in_memory/_include.yaml +74 -0
  167. examples/arches/compute_in_memory/_include_functions.py +229 -0
  168. examples/arches/compute_in_memory/_load_spec.py +57 -0
  169. examples/arches/compute_in_memory/components/c2c_multiplier.py +181 -0
  170. examples/arches/compute_in_memory/components/dac_c2c_r2r.py +605 -0
  171. examples/arches/compute_in_memory/components/misc.py +195 -0
  172. examples/arches/compute_in_memory/components/util/bit_functions.py +51 -0
  173. examples/arches/compute_in_memory/components/zero_comparator.py +92 -0
  174. examples/arches/compute_in_memory/isaac.yaml +233 -0
  175. examples/arches/compute_in_memory/memory_cells/ecram_demo.yaml +63 -0
  176. examples/arches/compute_in_memory/memory_cells/rram_example.yaml +63 -0
  177. examples/arches/compute_in_memory/memory_cells/rram_isaac_isca_2016.yaml +64 -0
  178. examples/arches/compute_in_memory/memory_cells/rram_neurosim_default.yaml +63 -0
  179. examples/arches/compute_in_memory/memory_cells/rram_raella_isca_2023.yaml +70 -0
  180. examples/arches/compute_in_memory/memory_cells/rram_wan_nature_2022.yaml +63 -0
  181. examples/arches/compute_in_memory/memory_cells/sram_colonnade_jssc_2021.yaml +63 -0
  182. examples/arches/compute_in_memory/memory_cells/sram_example.yaml +63 -0
  183. examples/arches/compute_in_memory/memory_cells/sram_jia_jssc_2020.yaml +63 -0
  184. examples/arches/compute_in_memory/memory_cells/sram_sinangil_jssc_2021.yaml +63 -0
  185. examples/arches/compute_in_memory/memory_cells/sram_wang_vlsi_2022.yaml +63 -0
  186. examples/arches/compute_in_memory/wang_vlsi_2022.yaml +289 -0
  187. examples/arches/eyeriss.yaml +68 -0
  188. examples/arches/fanout_variations/at_glb.yaml +31 -0
  189. examples/arches/fanout_variations/at_glb_with_fanout_node.yaml +34 -0
  190. examples/arches/fanout_variations/at_mac.yaml +31 -0
  191. examples/arches/fanout_variations/at_mac_with_constraints.yaml +38 -0
  192. examples/arches/fanout_variations/at_mac_with_fanout_node.yaml +34 -0
  193. examples/arches/nvdla.yaml +47 -0
  194. examples/arches/simple.yaml +28 -0
  195. examples/arches/tpu_v4i.yaml +67 -0
  196. examples/mappings/unfused_matmuls_to_simple.yaml +33 -0
  197. examples/misc/component_annotated.yaml +33 -0
  198. examples/workloads/gpt3_6.7B.yaml +124 -0
  199. examples/workloads/matmuls.yaml +20 -0
  200. examples/workloads/mobilenet_28.yaml +81 -0
  201. examples/workloads/mobilenet_various_separate.yaml +106 -0
  202. examples/workloads/three_matmuls_annotated.yaml +59 -0
  203. notebooks/.ipynb_checkpoints/fastfusion_arch_study_michael-checkpoint.ipynb +359 -0
  204. notebooks/compute_in_memory/_scripts.py +339 -0
  205. notebooks/compute_in_memory/isaac.guide.ipynb +270 -0
  206. notebooks/compute_in_memory/wang_vlsi_2022.ipynb +602 -0
  207. notebooks/paths.py +4 -0
  208. notebooks/tutorials/.ipynb_checkpoints/1_FFM-checkpoint.ipynb +3110 -0
  209. notebooks/tutorials/FFM.ipynb +3498 -0
  210. notebooks/tutorials/_include.py +48 -0
  211. notebooks/tutorials/component_energy_area.ipynb +363 -0
  212. tests/Q_mapping.yaml +38 -0
  213. tests/__init__.py +0 -0
  214. tests/conv.mapping.yaml +27 -0
  215. tests/conv.workload.yaml +13 -0
  216. tests/conv_sym.mapping.yaml +43 -0
  217. tests/copy.mapping.yaml +35 -0
  218. tests/copy.workload.yaml +15 -0
  219. tests/distribuffers/__init__.py +0 -0
  220. tests/distribuffers/multicast/test_cases.yaml +482 -0
  221. tests/distribuffers/spec/binding/valid_bindings.yaml +97 -0
  222. tests/distribuffers/spec/distributed.yaml +100 -0
  223. tests/distribuffers/spec/logical_arch.yaml +32 -0
  224. tests/distribuffers/spec/physical_arch.yaml +69 -0
  225. tests/distribuffers/test_binding.py +48 -0
  226. tests/frontend/__init__.py +0 -0
  227. tests/frontend/test_mapping_viz.py +52 -0
  228. tests/mapper/__init__.py +0 -0
  229. tests/mapper/configs/conv1d/conv1d.mapping.yaml +31 -0
  230. tests/mapper/configs/conv1d/conv1d.workload.yaml +11 -0
  231. tests/mapper/configs/two_conv1d/two_conv1d.expected.yaml +38 -0
  232. tests/mapper/configs/two_conv1d/two_conv1d.mapping.yaml +54 -0
  233. tests/mapper/configs/two_conv1d/two_conv1d.workload.yaml +19 -0
  234. tests/mapper/test_mapping_to_isl.py +90 -0
  235. tests/mapper/test_spatial_reuse_analysis.py +67 -0
  236. tests/mapper/test_temporal_reuse_analysis.py +56 -0
  237. tests/mapper/util.py +58 -0
  238. tests/matmul.mapping.yaml +29 -0
  239. tests/matmul.workload.yaml +12 -0
  240. tests/matmul_spatial.mapping.yaml +44 -0
  241. tests/mha.renames.yaml +65 -0
  242. tests/mha.workload.yaml +67 -0
  243. tests/mha.yaml +59 -0
  244. tests/mha_full.workload.yaml +67 -0
  245. tests/mobilenet.workload.yaml +35 -0
  246. tests/mobilenet_long.workload.yaml +64 -0
  247. tests/pmappingcache.py +24 -0
  248. tests/processing_stage.arch.yaml +40 -0
  249. tests/snowcat.arch.yaml +36 -0
  250. tests/test_ffm_join_pmappings.py +106 -0
  251. tests/test_ffm_make_pmappings.py +82 -0
  252. tests/test_ffm_make_tile_shapes.py +49 -0
  253. tests/test_mapper.py +100 -0
  254. tests/test_model.py +37 -0
  255. tests/test_plotting.py +72 -0
  256. tests/test_processing_stage.py +46 -0
  257. tests/test_symbolic_model.py +248 -0
  258. tests/test_workload.py +141 -0
@@ -0,0 +1,188 @@
1
+ """
2
+ Relevant name changes:
3
+ - [logical] buffer/lbuf -> buffet
4
+ - [logical] comp/lcomp -> compute_einsum
5
+ -
6
+ """
7
+
8
+ from abc import ABC
9
+
10
+ from collections import defaultdict
11
+ from dataclasses import dataclass
12
+ from typing import List, TypeAlias
13
+
14
+ import islpy as isl
15
+
16
+ from accelforge.frontend.mapping import Compute, MappingNode
17
+ from accelforge.frontend.workload import TensorName
18
+ from accelforge.model._looptree.types import Buffet
19
+
20
+
21
EinsumName: TypeAlias = str
"Einsum's identifier."

# Mapper intermediates.
#
# ``Tiling`` is an Iteration -> Operation relation that specifies the tiling.
# Combined with the skew and data-distribution relations, it is what lets us
# distribute data and operations. A tiling may leave some bounds unspecified,
# to be inferred by LoopTree; the tiling handed to nest analysis is
# guaranteed to be fully specified.
Tiling: TypeAlias = isl.Map
"Tiling of data and operations."
BranchTiling: TypeAlias = defaultdict[MappingNode, Tiling]
"Relation between a node and its tiling."
BuffetTiling: TypeAlias = defaultdict[Buffet, Tiling]
"Relation between a buffet and its tiling."
39
+
40
+
41
+ @dataclass(frozen=True, slots=True)
42
+ class Tag(ABC): # pylint: disable=too-few-public-methods
43
+ """Associating an element with its type metadata without introspection?"""
44
+
45
+
46
@dataclass(frozen=True, slots=True)
class TemporalTag(Tag):  # pylint: disable=too-few-public-methods
    """
    Marks a dimension as temporal: the associated element spreads over time.

    This tag is a member of both ``TEMPORAL_TAGS`` and ``LOOP_TAGS``.
    It is decorated like its base :class:`Tag` so instances stay frozen and
    slotted. An undecorated subclass would silently regain a per-instance
    ``__dict__``.
    """
48
+
49
+
50
@dataclass(slots=True, frozen=True)
class SpatialTag(Tag):  # pylint: disable=too-few-public-methods
    """
    Marks a dimension as spatial: the associated element spreads across
    instances of ``buffer``.

    Spatial analysis selects the dimensions relevant to a buffer by matching
    ``buffer`` (see ``get_spatial_tags_idxs`` in the spatial reuse module).
    """

    spatial_dim: int
    "Which spatial dimension of ``buffer`` this tag refers to (presumably 0-based)."
    buffer: MappingNode
    "The buffer (mapping node) that this spatial dimension spans."
58
+
59
+
60
@dataclass(frozen=True, slots=True)
class PipelineTag(Tag):  # pylint: disable=too-few-public-methods
    """
    Marks the associated element as pipelined.

    This tag is a member of ``BRANCH_TAGS``. It is decorated like its base
    :class:`Tag` so instances stay frozen and slotted. An undecorated
    subclass would silently regain a per-instance ``__dict__``.
    """
62
+
63
+
64
@dataclass(frozen=True, slots=True)
class SequentialTag(Tag):  # pylint: disable=too-few-public-methods
    """
    Marks the associated element as serialized (executed sequentially).

    This tag is a member of both ``TEMPORAL_TAGS`` and ``BRANCH_TAGS``.
    It is decorated like its base :class:`Tag` so instances stay frozen and
    slotted. An undecorated subclass would silently regain a per-instance
    ``__dict__``.
    """
66
+
67
+
68
# Tag groupings, meant to be passed straight to ``isinstance`` (which accepts
# a tuple of classes). SequentialTag belongs to both the temporal and the
# branch groups.
TEMPORAL_TAGS = (
    TemporalTag,
    SequentialTag,
)
BRANCH_TAGS = (
    PipelineTag,
    SequentialTag,
)
LOOP_TAGS = (
    TemporalTag,
    SpatialTag,
)
71
+
72
+
73
@dataclass(frozen=True, slots=True)
class TaggedMap:  # pylint: disable=too-few-public-methods
    """
    An :class:`isl.Map` whose dimensions are labelled with :class:`Tag` objects.

    ``tags`` labels the dimensions of ``map_``. Subclasses such as
    :class:`Occupancy` and :class:`Fill` check that there is exactly one tag
    per input dimension.
    """

    tags: List[Tag]
    map_: isl.Map

    def __repr__(self):
        # Use the bare class name: formatting ``type(self)`` directly renders
        # as "<class 'package.module.Name'>", which is noisy in debug output.
        return f"{type(self).__name__}({self.tags}, {self.map_})"
82
+
83
+
84
class Occupancy(TaggedMap):  # pylint: disable=too-few-public-methods
    """
    Location of data in [logical?] hardware elements.

    Requires exactly one tag per input dimension of ``map_``.
    """

    def __init__(self, tags: list[Tag], map_: isl.Map):
        # Explicit expansion of an ``assert`` statement: the check (and the
        # lazily built message) is skipped entirely under ``python -O``.
        if __debug__:
            if len(tags) != map_.dim(isl.dim_type.in_):
                raise AssertionError(
                    "Occupancy labels input dims with tags\n"
                    "-------------------------------------\n"
                    f"tags: {tags}\n"
                    f"map: {map_}\n"
                )
        super().__init__(tags, map_)
95
+
96
+
97
class OperationOccupancy(TaggedMap):  # pylint: disable=too-few-public-methods
    """
    Location of operations in [logical?] hardware elements.

    Unlike :class:`Occupancy`, no tag-count check is performed on construction.
    """
99
+
100
+
101
class Fill(TaggedMap):
    """
    Spacetime -> fill of a logical buffer.

    Requires exactly one tag per input dimension of ``map_``.
    """

    def __init__(self, tags: list[Tag], map_: isl.Map):
        # Explicit expansion of an ``assert`` statement: the check (and the
        # lazily built message) is skipped entirely under ``python -O``.
        if __debug__:
            if len(tags) != map_.dim(isl.dim_type.in_):
                raise AssertionError(
                    "Fill labels input dims with tags\n"
                    "--------------------------------\n"
                    f"tags: {tags}\n"
                    f"map: {map_}\n"
                )
        super().__init__(tags, map_)
112
+
113
+
114
class Skew(TaggedMap):  # pylint: disable=too-few-public-methods
    """
    A tagged skew relation.

    :class:`SkewsInfo` stores one per :class:`BufferTensorEinsum` and one per
    :class:`ComputeEinsum`. NOTE(review): the exact meaning of the relation is
    not defined in this module; document it once confirmed.
    """
116
+
117
+
118
@dataclass(frozen=True, slots=True)
class BufferTensorEinsum:
    """
    A buffet: relates a [logical?] hardware element storing data, a tensor it
    contains, and the [logical?] hardware element requesting that tensor.

    Used as the key type of :attr:`MappingAnalysisResult.buffet_to_occupancy`
    and :attr:`SkewsInfo.bte_to_skew`.

    See Also
    --------
    :class:`accelforge.model._looptree.types.Buffet`
    """

    buffer: str
    "The logical name of the buffer supplying the tensor."
    tensor: TensorName
    "The tensor being supplied."
    einsum: Compute
    "The leaf in the mapping doing the einsum compute on ``tensor``."
135
+
136
+
137
@dataclass(slots=True, frozen=True)
class ComputeEinsum:
    """
    A logical computation that (presumably) the workload needs to carry out.

    Used as the key type of
    :attr:`MappingAnalysisResult.compute_einsum_to_occupancy` and
    :attr:`SkewsInfo.ce_unit_to_skew`.
    """

    compute: str
    """NOTE(review): meaning not established here; presumably the name of the
    compute unit performing the Einsum. Confirm."""
    branch_leaf_node: Compute
    """The compute element at the leaf of a :class:`BranchTiling` (TODO confirm)."""
145
+
146
+
147
# Output classes of the mapping analysis.
@dataclass(slots=True, frozen=True)
class SkewsInfo:  # pylint: disable=too-few-public-methods
    """
    The skew associated with every buffet and every compute unit.

    NOTE(review): how these skews are consumed is not visible in this module.
    """

    bte_to_skew: defaultdict[BufferTensorEinsum, Skew]
    """Maps each :class:`~.BufferTensorEinsum` to its :class:`~.Skew`."""
    ce_unit_to_skew: defaultdict[ComputeEinsum, Skew]
    """Maps each :class:`~.ComputeEinsum` to its :class:`~.Skew`."""
156
+
157
+
158
@dataclass(slots=True, frozen=True)
class MappingAnalysisResult:  # pylint: disable=too-few-public-methods
    """
    Output of mapping analysis, which becomes the input to reuse analysis.
    """

    buffet_direct_above_sequential: defaultdict[Buffet, bool]
    """
    Whether each buffet sits directly above a sequential node. Capacity
    calculation uses this, because with sequential mapping and no tiling some
    data can be dropped earlier than usual.
    """
    buffet_to_occupancy: defaultdict[BufferTensorEinsum, Occupancy]
    """The occupancy of every buffet, as defined by the mapping."""
    compute_einsum_to_occupancy: defaultdict[ComputeEinsum, OperationOccupancy]
    """The occupancy of every compute unit."""
    # TODO: Figure out if this is deprecated:
    # https://github.com/NVlabs/timeloop/blob/32370826fdf1aa3c8deb0c93e6b2a2fc7cf053aa/include/loop-analysis/mapping-to-isl/fused-mapping-to-isl.hpp#L31-L35
    # node_to_buffets: the buffets found between the current root/branch node
    # and the next one.
    branch_tiling: BranchTiling
    """
    Tiling of each branch: a relation between tiling variables and operations.
    An incompletely tiled branch has a multiple-valued :class:`isl.Map`.
    """
    compute_to_assumed_parallelism: defaultdict[MappingNode, float]
    """
    Assumed parallelism per compute node. Compute latency can then be
    approximated quickly as (number of operations) / (assumed parallelism).
    """
@@ -0,0 +1,260 @@
1
+ """
2
+ Handles the ISL spatial reuse functions.
3
+ """
4
+
5
+ from abc import ABC, abstractmethod
6
+ from dataclasses import dataclass
7
+ from typing import Optional
8
+
9
+ import islpy as isl
10
+
11
+ from accelforge.frontend.mapping import MappingNode
12
+ from accelforge.model._looptree.reuse.isl.isl_functions import (
13
+ insert_equal_dims_map,
14
+ reorder_projector,
15
+ )
16
+ from accelforge.model._looptree.reuse.isl.mapping_to_isl.types import (
17
+ TEMPORAL_TAGS,
18
+ Fill,
19
+ Occupancy,
20
+ SpatialTag,
21
+ Tag,
22
+ TaggedMap,
23
+ )
24
+
25
+
26
class Transfers(TaggedMap):
    """
    Transfers between regions in spacetime.

    Transfer models use it for the fills satisfied by peer-to-peer transfers
    (``TransferInfo.fulfilled_fill``).
    """
28
+
29
+
30
class Reads(TaggedMap):
    """
    Reads between regions in spacetime.

    Transfer models use it for reads from a parent buffer
    (``TransferInfo.parent_reads``).
    """
32
+
33
+
34
@dataclass(slots=True, frozen=True)
class TransferInfo:
    """Data-transfer results for a [subset of the] chip, produced by a :class:`TransferModel`."""

    # Core transfer results.
    fulfilled_fill: Transfers
    """Fills satisfied by peer-to-peer transfers."""
    unfulfilled_fill: Fill
    """Fills that the peer-to-peer transfers did not satisfy."""
    parent_reads: Reads
    """Fills done by parent-to-child transfers."""
    hops: isl.PwQPolynomial
    """Peer-to-peer transfer cost metric across spacetime."""

    # Metadata about how the transfers were modeled.
    link_transfer: bool
    """Whether a link transfer model produced this result (``SimpleLinkTransferModel`` always sets ``True``)."""
50
+
51
+
52
class TransferModel(ABC):
    """
    A peer-to-peer/multicast transfer model for spatial analysis.

    Subclasses implement :meth:`apply`.
    """

    @abstractmethod
    def apply(self, buff: MappingNode, fills: Fill, occs: Occupancy) -> TransferInfo:
        """
        Given a buffer, its fills across time, and its occupancies across time,
        calculate the spatial transfers.

        Parameters
        ----------
        buff:
            The buffer whose spatial analysis is being considered.
        fills:
            The fill of `buff` across time from parents.
        occs:
            The occupancy of `buff` across time.

        Returns
        -------
        The fulfilled fills, the unfulfilled fills, and the parent reads per
        position in spacetime, plus the hops per timestep.
        """
        raise NotImplementedError(
            f"{type(self)} has not implemented `apply(self, MappingNode, Fill, Occupancy)`"
        )

    def __repr__(self):
        """Return the transfer model's class name, e.g. ``SimpleLinkTransferModel()``."""
        # ``f"{type(self)}"`` would render as "<class 'package.module.Name'>".
        return f"{type(self).__name__}()"
84
+
85
+
86
class SimpleLinkTransferModel(TransferModel):
    """
    Basic link transfer model.

    A fill counts as fulfilled when a neighbor on an orthogonal mesh (see
    :func:`make_mesh_connectivity`) held the same data one step earlier. The
    remaining fills are reported as unfulfilled. Parent reads are always
    empty, because only peer-to-peer transfers are analyzed.
    """

    def apply(self, buff: MappingNode, fills: Fill, occs: Occupancy) -> TransferInfo:
        # Sanity check that the fill is for the same occupancy. Necessary but
        # insufficient proof. Message layout matches Occupancy/Fill's checks.
        assert fills.tags == occs.tags, (
            "Fill and Occupancy mismatch\n"
            "---------------------------\n"
            f"Fill: {fills}\n"
            f"Occs: {occs}\n"
        )

        # Number of input (spacetime) dims, the spatial dims belonging to
        # `buff`, and the deepest temporal dim.
        n: int = fills.map_.dim(isl.dim_type.in_)
        spatial_dims: list[int] = get_spatial_tags_idxs(fills.tags, buff)
        last_temporal: Optional[int] = get_last_temporal_tag_idx(fills.tags)

        # No temporal or no spatial dims: data is not moving between peers
        # across time, so no transfers occur and every fill stays unfulfilled.
        if last_temporal is None or len(spatial_dims) == 0:
            return TransferInfo(
                fulfilled_fill=Transfers(
                    fills.tags, fills.map_.subtract(fills.map_)
                ),  # Empty map
                unfulfilled_fill=fills,  # No fulfilled fills, so all are unfulfilled
                parent_reads=Reads(
                    occs.tags, occs.map_.subtract(occs.map_)
                ),  # Empty map
                # NOTE(review): this zero is defined on the spacetime domain,
                # whereas the general path below builds `hops` on the wrapped
                # (spacetime -> data) space. Confirm callers accept both.
                hops=isl.PwQPolynomial.from_qpolynomial(
                    isl.QPolynomial.zero_on_domain(fills.map_.domain().get_space())
                ),
                link_transfer=True,
            )

        # Mesh connectivity over [t, spatial dims of `buff`...].
        connectivity: isl.Map = make_mesh_connectivity(
            len(spatial_dims), occs.map_.get_tuple_name(isl.dim_type.in_)
        )
        # Presumably pads the mesh with identity dims so it spans all n input
        # dims (TODO confirm insert_equal_dims_map's argument order).
        padded_connectivity: isl.Map = insert_equal_dims_map(
            connectivity, 0, 0, n - len(spatial_dims) - 1
        )
        # The mesh has its spatial dims last: reorder so `buff`'s spatial dims
        # come last, apply the mesh, then restore the original dim order.
        permutation: list[int] = make_connectivity_permutation(spatial_dims, n)
        reorder_map: isl.Map = reorder_projector(
            permutation, occs.map_.get_tuple_name(isl.dim_type.in_)
        )
        complete_connectivity: isl.Map = reorder_map.apply_range(
            padded_connectivity
        ).apply_range(reorder_map.reverse())

        # Data available from neighbors at each point in space per time.
        available_from_neighbors: isl.Map = complete_connectivity.apply_range(occs.map_)
        # Fills that a neighbor can serve, so they need not come from higher
        # in the memory hierarchy.
        neighbor_filled: isl.Map = fills.map_.intersect(available_from_neighbors)

        return TransferInfo(
            fulfilled_fill=Transfers(fills.tags, neighbor_filled.coalesce()),
            unfulfilled_fill=Fill(
                fills.tags, fills.map_.subtract(neighbor_filled).coalesce()
            ),
            # Empty, since only p2p is analyzed.
            parent_reads=Reads(
                fills.tags, neighbor_filled.subtract(neighbor_filled).coalesce()
            ),
            # One hop per neighbor-served (spacetime -> data) pair.
            hops=isl.PwQPolynomial.from_qpolynomial(
                isl.QPolynomial.one_on_domain(neighbor_filled.wrap().get_space())
            )
            .intersect_domain(neighbor_filled.wrap())
            .coalesce(),
            link_transfer=True,
        )
157
+
158
+
159
def make_mesh_connectivity(n: int, spacetime: str) -> isl.Map:
    """
    Build a nearest-neighbor (orthogonal) mesh connection over ``n`` spatial dims.

    Each point ``[t, x...]`` is related to every orthogonal neighbor (one
    spatial coordinate off by one) at time ``t-1``.

    Parameters
    ----------
    n:
        Number of spatial dimensions. Only 1 and 2 are supported.
    spacetime:
        Tuple name given to both the input and the output space of the mesh.

    Returns
    -------
    The adjacency map on ``spacetime[t, x_1, ..., x_n]``.

    Raises
    ------
    ValueError
        If ``n`` is neither 1 nor 2.
    """
    if n == 1:
        description = "{ [t, x] -> [t-1, x'] : (x'=x-1) or (x'=x+1) }"
    elif n == 2:
        description = (
            "{ [t, x, y] -> [t-1, x', y'] : "
            " (y'=y and x'=x-1) or (y'=y and x'=x+1) "
            " or (x'=x and y'=y-1) or (x'=x and y'=y+1) }"
        )
    else:
        raise ValueError(f"Cannot make mesh with {n} spatial dims")

    named_in = isl.Map.read_from_str(isl.DEFAULT_CONTEXT, description).set_tuple_name(
        isl.dim_type.in_, spacetime
    )
    return named_in.set_tuple_name(isl.dim_type.out, spacetime)
196
+
197
+
198
def make_connectivity_permutation(spatial_idxs: list[int], dims: int) -> list[int]:
    """
    Build a dimension permutation that moves the spatial dimensions to the end.

    The result lists every index in ``range(dims)`` that is not in
    ``spatial_idxs`` (ascending), followed by ``spatial_idxs`` in the given
    order. ``SimpleLinkTransferModel.apply`` uses it to line the spacetime
    dims up with the mesh connectivity, whose spatial dims come last.

    Parameters
    ----------
    spatial_idxs:
        Indices of the spatial dimensions. They do not need to be sorted.
    dims:
        Total number of dimensions.

    Returns
    -------
    The permutation, as a list of dimension indices.

    Examples
    --------
    >>> make_connectivity_permutation([1, 3], 5)
    [0, 2, 4, 1, 3]
    """
    # Membership test rather than walking a cursor through spatial_idxs: the
    # cursor walk silently produced an invalid permutation (duplicated and
    # missing indices) whenever spatial_idxs was not sorted.
    spatial = set(spatial_idxs)
    permutation: list[int] = [i for i in range(dims) if i not in spatial]
    permutation.extend(spatial_idxs)
    return permutation
213
+
214
+
215
def get_spatial_tags_idxs(tags: list[Tag], buffer: MappingNode) -> list[int]:
    """
    Given a list of tags, find the dimensions that are spatial over ``buffer``.

    Parameters
    ----------
    tags:
        The `Occupancy` or `Fill` domain dimension tags.
    buffer:
        The `MappingNode` (logical memory) whose spatial dims we are looking for.

    Returns
    -------
    The indices of the matching :class:`SpatialTag` entries, in ascending order.
    """
    matches: list[int] = []
    for position, tag in enumerate(tags):
        if not isinstance(tag, SpatialTag):
            continue
        if tag.buffer == buffer:
            matches.append(position)
    return matches
238
+
239
+
240
def get_last_temporal_tag_idx(tags: list[Tag]) -> Optional[int]:
    """
    Return the index of the deepest (last) temporal tag in ``tags``.

    Parameters
    ----------
    tags:
        A list of `Tags`.

    Returns
    -------
    The highest index whose tag is an instance of one of ``TEMPORAL_TAGS``,
    or ``None`` if there is no such tag (including when ``tags`` is empty).
    """
    # Scan from the back so the first hit is the deepest temporal tag.
    return next(
        (
            position
            for position in range(len(tags) - 1, -1, -1)
            if isinstance(tags[position], TEMPORAL_TAGS)
        ),
        None,
    )
@@ -0,0 +1,182 @@
1
+ """
2
+ Handles the ISL temporal reuse functions.
3
+ """
4
+
5
+ from dataclasses import dataclass
6
+
7
+ import islpy as isl
8
+
9
+ from accelforge.model._looptree.reuse.isl.isl_functions import map_to_shifted
10
+ from accelforge.model._looptree.reuse.isl.mapping_to_isl.types import (
11
+ TEMPORAL_TAGS,
12
+ Fill,
13
+ Occupancy,
14
+ Tag,
15
+ )
16
+
17
+
18
@dataclass(frozen=True)
class TemporalReuse:
    """Results of a temporal reuse analysis."""

    effective_occupancy: Occupancy
    """Occupancy after the analysis.

    Temporal dimensions that turned out not to affect which data is occupied
    may have been projected out (together with their tags). When reuse is not
    exploited, this is the input occupancy unchanged.
    """
    fill: Fill
    """Data deliveries to locations in spacetime that need to be made."""
26
+
27
+
28
def analyze_temporal_reuse(
    occ: Occupancy, exploit_reuse: bool = True, multi_loop_reuse: bool = True
) -> TemporalReuse:
    """
    Compute the fill required to satisfy a buffer occupancy.

    When `exploit_reuse` is set, the fill only contains data that is not
    already resident in the buffer from the previous time step. Otherwise the
    fill is the whole occupancy.

    Parameters
    ----------
    occ:
        The logical occupancy to be temporally analyzed.
    exploit_reuse:
        Whether data kept in the buffer may be reused at the next time step.
    multi_loop_reuse:
        Forwarded to `fill_from_occupancy`.
        - True: the previous time step is found across all temporal loops.
        - False: only the deepest non-trivial temporal loop is shifted.

    Returns
    -------
    A `TemporalReuse` with two parts:
    - a `..types.Fill` describing how data is loaded into the buffer over time;
    - a `..types.Occupancy` (the effective occupancy) describing what must be
      held in the buffer per time step.

    TODO: Make sure spaces are named properly
    """
    if not exploit_reuse:
        # No reuse: everything that is occupied must also be filled.
        no_reuse_fill = Fill(occ.tags, occ.map_)
        return TemporalReuse(occ, no_reuse_fill)
    return fill_from_occupancy(occ, multi_loop_reuse)
59
+
60
+
61
def fill_from_occupancy(
    occupancy: Occupancy, multiple_loop_reuse: bool
) -> TemporalReuse:
    """
    Calculate the `fill` and the `effective_occupancy` of an occupancy.

    The calculation depends on whether reuse across loops is allowed.
    Temporal dimensions are visited from the deepest (last) to the shallowest.
    - Trivial dimensions: a temporal dimension whose value does not change the
      occupied data is projected out of the occupancy, and its tag is dropped.
    - First non-trivial dimension: it determines the fill, which is the data
      occupied at a time step but not at the previous one.
    - No non-trivial dimension: if every temporal dimension is trivial, the
      fill equals the (reduced) occupancy.

    Parameters
    ----------
    occupancy:
        The logical occupancy of data in logical buffers.
    multiple_loop_reuse:
        How "previous time step" is computed:
        - True: via `construct_time_shift`, i.e. the most recent earlier time
          point over all temporal dimensions at the same spatial location.
        - False: via `map_to_shifted` on only the current temporal dimension,
          presumably a shift of -1 along that dimension.

    Returns
    -------
    A `TemporalReuse` object that contains the `fill` and `effective_occupancy`
    of the lowest buffer level.
    """
    occ = occupancy.map_.copy()
    tags = occupancy.tags.copy()
    # Projected maps are always named f"{base_name}_abridged", however many
    # dimensions have been removed, so repeated projections keep one stable
    # name.
    base_name = occ.get_tuple_name(isl.dim_type.in_).removesuffix("_abridged")

    # Iterate in reverse (deepest loop first). Popping tags at higher indices
    # therefore never shifts the indices still to be visited.
    for dim_idx, tag in reversed(list(enumerate(occupancy.tags))):
        if not isinstance(tag, TEMPORAL_TAGS):
            continue

        # Check whether the temporal dimension is "trivial". Project it out,
        # re-insert it as an unconstrained dimension (restricted to the
        # original domain), and see whether the map is unchanged.
        current_name = occ.get_tuple_name(isl.dim_type.in_)
        proj_occ: isl.Map = occ.project_out(
            isl.dim_type.in_, dim_idx, 1
        ).set_tuple_name(isl.dim_type.in_, f"{base_name}_abridged")
        # Name the re-inserted map after the *current* `occ`, so the domain
        # intersection and the comparison below use matching spaces. This
        # still holds after an earlier projection renamed `occ`.
        reinserted_occ: isl.Map = (
            proj_occ.insert_dims(isl.dim_type.in_, dim_idx, 1)
            .set_tuple_name(isl.dim_type.in_, current_name)
            .intersect_domain(occ.domain())
        )

        if occ.plain_is_equal(reinserted_occ) or occ.is_equal(reinserted_occ):
            occ = proj_occ
            tags.pop(dim_idx)
            continue

        # Nontrivial analysis
        time_shift: isl.Map
        if not multiple_loop_reuse:
            # TODO: Verify space names are preserved and/or replace.
            time_shift = map_to_shifted(occ.domain().get_space(), dim_idx, -1)
        else:
            # Calculates the time_shift assuming no cache flushing for loops.
            # TODO: this is a better way of getting time_shift. Use method to
            # replace the other branch (!multi_loop_reuse)
            time_shift = construct_time_shift(occ, tags)

        # Gets the fill (i.e., feeds data not currently in buffer).
        occ_before: isl.Map = time_shift.apply_range(occ)
        fill: isl.Map = occ.subtract(occ_before)

        return TemporalReuse(Occupancy(tags, occ), Fill(tags, fill))

    # Every temporal dimension was trivial: nothing can be reused over time.
    return TemporalReuse(Occupancy(tags, occ.coalesce()), Fill(tags, occ.coalesce()))
122
+
123
+
124
def construct_time_shift(occ: isl.Map, tags: list[Tag]):
    """
    Relate each spacetime point of `occ`'s domain to its predecessor.

    The predecessor is the point at the same spatial coordinates in the
    immediately preceding time step.

    Parameters
    ----------
    occ:
        The occupancy map we're analyzing the reuse for.
    tags:
        The tags of what each input dimension of `occ` represents.

    Returns
    -------
    time_shift:
        An `isl.Map` from spacetime to spacetime relating each point to the
        lexicographically most recent earlier time point (within the domain of
        `occ`) at the same spatial location.
    """
    spacetime: isl.Set = occ.domain()
    spacetime_name = spacetime.get_tuple_name()

    # Start from two identity relations. Dropping output dimensions turns one
    # into a projection onto time and the other into a projection onto space.
    to_time: isl.Map = isl.Map.identity(spacetime.get_space().map_from_set())
    to_space: isl.Map = isl.Map.identity(spacetime.get_space().map_from_set())
    # Highest index first, so the lower indices stay valid while dimensions
    # are removed.
    for position, dim_tag in reversed(list(enumerate(tags))):
        if isinstance(dim_tag, TEMPORAL_TAGS):
            to_space = to_space.project_out(isl.dim_type.out, position, 1)
        else:
            to_time = to_time.project_out(isl.dim_type.out, position, 1)
    to_time = to_time.set_tuple_name(isl.dim_type.out, f"{spacetime_name}_time")
    to_space = to_space.set_tuple_name(isl.dim_type.out, f"{spacetime_name}_space")

    # Only the time points that actually occur in the occupancy domain.
    to_time = to_time.intersect_domain(spacetime)
    time_points: isl.Set = to_time.range()

    # Each time point -> all strictly earlier time points; lexmax keeps only
    # the most recent one.
    earlier: isl.Map = isl.Map.lex_gt(time_points.get_space())
    earlier = earlier.intersect_domain(time_points).intersect_range(time_points)
    previous = earlier.lexmax()

    # spacetime -> time -> previous time -> every spacetime at that time.
    shift: isl.Map = to_time.apply_range(previous.apply_range(to_time.reverse()))

    # Spatial coordinates were lost going through time only. Restrict the
    # relation to pairs sharing the same spatial location.
    same_location: isl.Map = to_space.apply_range(to_space.reverse())
    return shift.intersect(same_location)
@@ -0,0 +1 @@
1
+ from .symbolic import *