accelforge 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
  Files changed (258)
  1. accelforge/__init__.py +21 -0
  2. accelforge/_accelerated_imports.py +16 -0
  3. accelforge/_deprecate/_simanneal/evalmapping.py +271 -0
  4. accelforge/_deprecate/_simanneal/mapspaceglobals.py +298 -0
  5. accelforge/_deprecate/_simanneal/simanneal.py +666 -0
  6. accelforge/_deprecate/_simanneal/tracking.py +105 -0
  7. accelforge/_deprecate/_simanneal/wrappers.py +218 -0
  8. accelforge/_deprecate/_simanneal2/__init__.py +7 -0
  9. accelforge/_deprecate/_simanneal2/simanneal.py +493 -0
  10. accelforge/_deprecate/_simanneal2/tracking.py +116 -0
  11. accelforge/_deprecate/compatibility_util.py +181 -0
  12. accelforge/_deprecate/layerdeduplication/__init__.py +2 -0
  13. accelforge/_deprecate/layerdeduplication/group_similar_einsums.py +160 -0
  14. accelforge/_deprecate/layerdeduplication/grouped_einsums.py +84 -0
  15. accelforge/_deprecate/mapping_filter_tags/__init__.py +2 -0
  16. accelforge/_deprecate/mapping_filter_tags/ffmt.py +212 -0
  17. accelforge/_deprecate/mapping_filter_tags/onesplit.py +24 -0
  18. accelforge/_deprecate/mapping_filter_tags/util.py +24 -0
  19. accelforge/_deprecate/tags.py +69 -0
  20. accelforge/_deprecate/viz/__init__.py +0 -0
  21. accelforge/_deprecate/viz/interactive.py +159 -0
  22. accelforge/_deprecate/viz/reservationtree.py +307 -0
  23. accelforge/_deprecate/viz/ski_slope.py +88 -0
  24. accelforge/_version.py +15 -0
  25. accelforge/examples.py +39 -0
  26. accelforge/frontend/__init__.py +10 -0
  27. accelforge/frontend/_binding.py +129 -0
  28. accelforge/frontend/_workload_isl/__init__.py +2 -0
  29. accelforge/frontend/_workload_isl/_isl.py +149 -0
  30. accelforge/frontend/_workload_isl/_symbolic.py +141 -0
  31. accelforge/frontend/arch copy.py +1544 -0
  32. accelforge/frontend/arch.py +1642 -0
  33. accelforge/frontend/config.py +63 -0
  34. accelforge/frontend/mapper/__init__.py +5 -0
  35. accelforge/frontend/mapper/ffm.py +126 -0
  36. accelforge/frontend/mapper/mapper.py +7 -0
  37. accelforge/frontend/mapper/metrics.py +30 -0
  38. accelforge/frontend/mapping/__init__.py +1 -0
  39. accelforge/frontend/mapping/mapping.py +1736 -0
  40. accelforge/frontend/model.py +14 -0
  41. accelforge/frontend/renames.py +150 -0
  42. accelforge/frontend/spec copy.py +230 -0
  43. accelforge/frontend/spec.py +301 -0
  44. accelforge/frontend/variables.py +12 -0
  45. accelforge/frontend/workload.py +952 -0
  46. accelforge/mapper/FFM/__init__.py +9 -0
  47. accelforge/mapper/FFM/_join_pmappings/__init__.py +0 -0
  48. accelforge/mapper/FFM/_join_pmappings/compatibility.py +653 -0
  49. accelforge/mapper/FFM/_join_pmappings/compress_pmappings.py +140 -0
  50. accelforge/mapper/FFM/_join_pmappings/join_pmappings.py +703 -0
  51. accelforge/mapper/FFM/_join_pmappings/pmapping_dataframe.py +901 -0
  52. accelforge/mapper/FFM/_join_pmappings/pmapping_group.py +337 -0
  53. accelforge/mapper/FFM/_make_pmappings/contraints/__init__.py +0 -0
  54. accelforge/mapper/FFM/_make_pmappings/contraints/constraints.py +360 -0
  55. accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/__init__.py +1 -0
  56. accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_loops.py +373 -0
  57. accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_pmapping_templates.py +463 -0
  58. accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_reservations.py +95 -0
  59. accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_storage_order.py +382 -0
  60. accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_storages.py +155 -0
  61. accelforge/mapper/FFM/_make_pmappings/make_pmappings.py +411 -0
  62. accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/__init__.py +1 -0
  63. accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/make_pmappings_from_templates.py +407 -0
  64. accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/make_tile_shapes.py +1681 -0
  65. accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/run_model.py +170 -0
  66. accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/symbol_relations.py +174 -0
  67. accelforge/mapper/FFM/_make_pmappings/pmapper_job.py +282 -0
  68. accelforge/mapper/FFM/_pareto_df/df_convention.py +273 -0
  69. accelforge/mapper/FFM/_pareto_df/pareto copy.py +836 -0
  70. accelforge/mapper/FFM/_pareto_df/pareto.py +508 -0
  71. accelforge/mapper/FFM/data.py +61 -0
  72. accelforge/mapper/FFM/main copy.py +236 -0
  73. accelforge/mapper/FFM/main.py +208 -0
  74. accelforge/mapper/FFM/mappings.py +510 -0
  75. accelforge/mapper/FFM/pmappings.py +310 -0
  76. accelforge/mapper/__init__.py +4 -0
  77. accelforge/mapper.py +0 -0
  78. accelforge/model/__init__.py +1 -0
  79. accelforge/model/_looptree/__init__.py +0 -0
  80. accelforge/model/_looptree/accesses.py +335 -0
  81. accelforge/model/_looptree/capacity/__init__.py +1 -0
  82. accelforge/model/_looptree/capacity/aggregators.py +36 -0
  83. accelforge/model/_looptree/capacity/capacity.py +47 -0
  84. accelforge/model/_looptree/energy.py +150 -0
  85. accelforge/model/_looptree/equivalent_ranks.py +29 -0
  86. accelforge/model/_looptree/latency/__init__.py +1 -0
  87. accelforge/model/_looptree/latency/latency.py +98 -0
  88. accelforge/model/_looptree/latency/memory.py +120 -0
  89. accelforge/model/_looptree/latency/processors.py +92 -0
  90. accelforge/model/_looptree/mapping_utilities.py +71 -0
  91. accelforge/model/_looptree/reuse/__init__.py +4 -0
  92. accelforge/model/_looptree/reuse/isl/__init__.py +1 -0
  93. accelforge/model/_looptree/reuse/isl/des.py +59 -0
  94. accelforge/model/_looptree/reuse/isl/isl_functions.py +374 -0
  95. accelforge/model/_looptree/reuse/isl/mapping_to_isl/__init__.py +4 -0
  96. accelforge/model/_looptree/reuse/isl/mapping_to_isl/analyze_mapping.py +297 -0
  97. accelforge/model/_looptree/reuse/isl/mapping_to_isl/skews_from_mapping.py +236 -0
  98. accelforge/model/_looptree/reuse/isl/mapping_to_isl/tiling.py +685 -0
  99. accelforge/model/_looptree/reuse/isl/mapping_to_isl/types.py +188 -0
  100. accelforge/model/_looptree/reuse/isl/spatial.py +260 -0
  101. accelforge/model/_looptree/reuse/isl/temporal.py +182 -0
  102. accelforge/model/_looptree/reuse/symbolic/__init__.py +1 -0
  103. accelforge/model/_looptree/reuse/symbolic/symbolic copy 2.py +1346 -0
  104. accelforge/model/_looptree/reuse/symbolic/symbolic copy.py +1408 -0
  105. accelforge/model/_looptree/reuse/symbolic/symbolic.py +1396 -0
  106. accelforge/model/_looptree/run.py +122 -0
  107. accelforge/model/_looptree/types.py +26 -0
  108. accelforge/model/_looptree/visualization/__init__.py +0 -0
  109. accelforge/model/_looptree/visualization/occupancy.py +11 -0
  110. accelforge/model/main.py +222 -0
  111. accelforge/plotting/__init__.py +2 -0
  112. accelforge/plotting/mappings.py +219 -0
  113. accelforge/plotting/specs.py +57 -0
  114. accelforge/util/__init__.py +4 -0
  115. accelforge/util/_base_analysis_types.py +24 -0
  116. accelforge/util/_basetypes.py +1089 -0
  117. accelforge/util/_frozenset.py +36 -0
  118. accelforge/util/_isl.py +29 -0
  119. accelforge/util/_itertools.py +14 -0
  120. accelforge/util/_mathfuncs.py +57 -0
  121. accelforge/util/_parse_expressions.py +339 -0
  122. accelforge/util/_picklecache.py +32 -0
  123. accelforge/util/_setexpressions.py +268 -0
  124. accelforge/util/_sympy/__init__.py +0 -0
  125. accelforge/util/_sympy/broadcast_max.py +18 -0
  126. accelforge/util/_visualization.py +112 -0
  127. accelforge/util/_yaml.py +579 -0
  128. accelforge/util/parallel.py +193 -0
  129. accelforge-0.0.1.dist-info/METADATA +64 -0
  130. accelforge-0.0.1.dist-info/RECORD +258 -0
  131. accelforge-0.0.1.dist-info/WHEEL +5 -0
  132. accelforge-0.0.1.dist-info/licenses/LICENSE +19 -0
  133. accelforge-0.0.1.dist-info/top_level.txt +5 -0
  134. docs/_build/html/_sources/fastfusion.frontend.mapper.rst.txt +37 -0
  135. docs/_build/html/_sources/fastfusion.frontend.rst.txt +70 -0
  136. docs/_build/html/_sources/fastfusion.frontend.workload.rst.txt +21 -0
  137. docs/_build/html/_sources/fastfusion.mapper.FFM.rst.txt +37 -0
  138. docs/_build/html/_sources/fastfusion.mapper.rst.txt +18 -0
  139. docs/_build/html/_sources/fastfusion.rst.txt +20 -0
  140. docs/_build/html/_sources/fastfusion.util.rst.txt +21 -0
  141. docs/_build/html/_sources/index.rst.txt +87 -0
  142. docs/_build/html/_sources/modules.rst.txt +7 -0
  143. docs/_build/html/_sources/notes/citation.rst.txt +45 -0
  144. docs/_build/html/_sources/notes/definitions.rst.txt +43 -0
  145. docs/_build/html/_sources/notes/faqs.rst.txt +39 -0
  146. docs/_build/html/_sources/notes/modeling/accelerator_energy_latency.rst.txt +72 -0
  147. docs/_build/html/_sources/notes/modeling/component_energy_area.rst.txt +96 -0
  148. docs/_build/html/_sources/notes/modeling/mapping.rst.txt +100 -0
  149. docs/_build/html/_sources/notes/modeling.rst.txt +33 -0
  150. docs/_build/html/_sources/notes/parsing/arithmetic_parsing.rst.txt +136 -0
  151. docs/_build/html/_sources/notes/parsing/setexpressions.rst.txt +63 -0
  152. docs/_build/html/_sources/notes/parsing/yaml_parsing.rst.txt +176 -0
  153. docs/_build/html/_sources/notes/quickstart_and_installation.rst.txt +9 -0
  154. docs/_build/html/_sources/notes/spec/architecture.rst.txt +133 -0
  155. docs/_build/html/_sources/notes/spec/mapping.rst.txt +12 -0
  156. docs/_build/html/_sources/notes/spec/workload.rst.txt +83 -0
  157. docs/_build/html/_sources/notes/spec.rst.txt +36 -0
  158. docs/source/_ext/include_attrs.py +213 -0
  159. docs/source/_ext/include_docstring.py +364 -0
  160. docs/source/_ext/include_functions.py +154 -0
  161. docs/source/_ext/include_notebook.py +131 -0
  162. docs/source/_ext/include_yaml.py +119 -0
  163. docs/source/_ext/inherited_attributes.py +222 -0
  164. docs/source/_ext/paths.py +4 -0
  165. docs/source/conf.py +79 -0
  166. examples/arches/compute_in_memory/_include.yaml +74 -0
  167. examples/arches/compute_in_memory/_include_functions.py +229 -0
  168. examples/arches/compute_in_memory/_load_spec.py +57 -0
  169. examples/arches/compute_in_memory/components/c2c_multiplier.py +181 -0
  170. examples/arches/compute_in_memory/components/dac_c2c_r2r.py +605 -0
  171. examples/arches/compute_in_memory/components/misc.py +195 -0
  172. examples/arches/compute_in_memory/components/util/bit_functions.py +51 -0
  173. examples/arches/compute_in_memory/components/zero_comparator.py +92 -0
  174. examples/arches/compute_in_memory/isaac.yaml +233 -0
  175. examples/arches/compute_in_memory/memory_cells/ecram_demo.yaml +63 -0
  176. examples/arches/compute_in_memory/memory_cells/rram_example.yaml +63 -0
  177. examples/arches/compute_in_memory/memory_cells/rram_isaac_isca_2016.yaml +64 -0
  178. examples/arches/compute_in_memory/memory_cells/rram_neurosim_default.yaml +63 -0
  179. examples/arches/compute_in_memory/memory_cells/rram_raella_isca_2023.yaml +70 -0
  180. examples/arches/compute_in_memory/memory_cells/rram_wan_nature_2022.yaml +63 -0
  181. examples/arches/compute_in_memory/memory_cells/sram_colonnade_jssc_2021.yaml +63 -0
  182. examples/arches/compute_in_memory/memory_cells/sram_example.yaml +63 -0
  183. examples/arches/compute_in_memory/memory_cells/sram_jia_jssc_2020.yaml +63 -0
  184. examples/arches/compute_in_memory/memory_cells/sram_sinangil_jssc_2021.yaml +63 -0
  185. examples/arches/compute_in_memory/memory_cells/sram_wang_vlsi_2022.yaml +63 -0
  186. examples/arches/compute_in_memory/wang_vlsi_2022.yaml +289 -0
  187. examples/arches/eyeriss.yaml +68 -0
  188. examples/arches/fanout_variations/at_glb.yaml +31 -0
  189. examples/arches/fanout_variations/at_glb_with_fanout_node.yaml +34 -0
  190. examples/arches/fanout_variations/at_mac.yaml +31 -0
  191. examples/arches/fanout_variations/at_mac_with_constraints.yaml +38 -0
  192. examples/arches/fanout_variations/at_mac_with_fanout_node.yaml +34 -0
  193. examples/arches/nvdla.yaml +47 -0
  194. examples/arches/simple.yaml +28 -0
  195. examples/arches/tpu_v4i.yaml +67 -0
  196. examples/mappings/unfused_matmuls_to_simple.yaml +33 -0
  197. examples/misc/component_annotated.yaml +33 -0
  198. examples/workloads/gpt3_6.7B.yaml +124 -0
  199. examples/workloads/matmuls.yaml +20 -0
  200. examples/workloads/mobilenet_28.yaml +81 -0
  201. examples/workloads/mobilenet_various_separate.yaml +106 -0
  202. examples/workloads/three_matmuls_annotated.yaml +59 -0
  203. notebooks/.ipynb_checkpoints/fastfusion_arch_study_michael-checkpoint.ipynb +359 -0
  204. notebooks/compute_in_memory/_scripts.py +339 -0
  205. notebooks/compute_in_memory/isaac.guide.ipynb +270 -0
  206. notebooks/compute_in_memory/wang_vlsi_2022.ipynb +602 -0
  207. notebooks/paths.py +4 -0
  208. notebooks/tutorials/.ipynb_checkpoints/1_FFM-checkpoint.ipynb +3110 -0
  209. notebooks/tutorials/FFM.ipynb +3498 -0
  210. notebooks/tutorials/_include.py +48 -0
  211. notebooks/tutorials/component_energy_area.ipynb +363 -0
  212. tests/Q_mapping.yaml +38 -0
  213. tests/__init__.py +0 -0
  214. tests/conv.mapping.yaml +27 -0
  215. tests/conv.workload.yaml +13 -0
  216. tests/conv_sym.mapping.yaml +43 -0
  217. tests/copy.mapping.yaml +35 -0
  218. tests/copy.workload.yaml +15 -0
  219. tests/distribuffers/__init__.py +0 -0
  220. tests/distribuffers/multicast/test_cases.yaml +482 -0
  221. tests/distribuffers/spec/binding/valid_bindings.yaml +97 -0
  222. tests/distribuffers/spec/distributed.yaml +100 -0
  223. tests/distribuffers/spec/logical_arch.yaml +32 -0
  224. tests/distribuffers/spec/physical_arch.yaml +69 -0
  225. tests/distribuffers/test_binding.py +48 -0
  226. tests/frontend/__init__.py +0 -0
  227. tests/frontend/test_mapping_viz.py +52 -0
  228. tests/mapper/__init__.py +0 -0
  229. tests/mapper/configs/conv1d/conv1d.mapping.yaml +31 -0
  230. tests/mapper/configs/conv1d/conv1d.workload.yaml +11 -0
  231. tests/mapper/configs/two_conv1d/two_conv1d.expected.yaml +38 -0
  232. tests/mapper/configs/two_conv1d/two_conv1d.mapping.yaml +54 -0
  233. tests/mapper/configs/two_conv1d/two_conv1d.workload.yaml +19 -0
  234. tests/mapper/test_mapping_to_isl.py +90 -0
  235. tests/mapper/test_spatial_reuse_analysis.py +67 -0
  236. tests/mapper/test_temporal_reuse_analysis.py +56 -0
  237. tests/mapper/util.py +58 -0
  238. tests/matmul.mapping.yaml +29 -0
  239. tests/matmul.workload.yaml +12 -0
  240. tests/matmul_spatial.mapping.yaml +44 -0
  241. tests/mha.renames.yaml +65 -0
  242. tests/mha.workload.yaml +67 -0
  243. tests/mha.yaml +59 -0
  244. tests/mha_full.workload.yaml +67 -0
  245. tests/mobilenet.workload.yaml +35 -0
  246. tests/mobilenet_long.workload.yaml +64 -0
  247. tests/pmappingcache.py +24 -0
  248. tests/processing_stage.arch.yaml +40 -0
  249. tests/snowcat.arch.yaml +36 -0
  250. tests/test_ffm_join_pmappings.py +106 -0
  251. tests/test_ffm_make_pmappings.py +82 -0
  252. tests/test_ffm_make_tile_shapes.py +49 -0
  253. tests/test_mapper.py +100 -0
  254. tests/test_model.py +37 -0
  255. tests/test_plotting.py +72 -0
  256. tests/test_processing_stage.py +46 -0
  257. tests/test_symbolic_model.py +248 -0
  258. tests/test_workload.py +141 -0
@@ -0,0 +1,268 @@
1
+ import functools
2
+ from accelforge.util._parse_expressions import ParseError
3
+ from pydantic import BaseModel, ConfigDict, model_serializer
4
+ from typing import Iterator, Optional, TypeVar, Generic, Any, Union
5
+ from accelforge.util._parse_expressions import MATH_FUNCS
6
+
# Element type of an InvertibleSet.
T = TypeVar("T")


def _reconstruct_invertible_set(state, private=None):
    """Rebuild an ``InvertibleSet`` during unpickling without running validation.

    ``object.__new__`` bypasses ``BaseModel.__init__``, so pydantic's bookkeeping
    slots (``__pydantic_fields_set__``, ``__pydantic_extra__`` and
    ``__pydantic_private__``) would otherwise be left unset. Pydantic keeps
    private attributes such as ``_bits_per_value`` in ``__pydantic_private__``,
    so reading them (and comparing models) fails while those slots are missing.
    They are therefore populated here.

    Args:
        state: Mapping of model field names to values (an instance ``__dict__``).
        private: Optional mapping of private-attribute values to apply on top of
            the declared private-attribute defaults. ``None`` keeps the defaults.

    Returns:
        The reconstructed ``InvertibleSet``.
    """
    obj = object.__new__(InvertibleSet)
    obj.__dict__.update(state)
    object.__setattr__(obj, "__pydantic_fields_set__", set(state))
    object.__setattr__(obj, "__pydantic_extra__", None)
    private_state = {
        name: attr.get_default()
        for name, attr in InvertibleSet.__private_attributes__.items()
    }
    if private:
        private_state.update(private)
    object.__setattr__(obj, "__pydantic_private__", private_state)
    return obj
15
+
16
+
class InvertibleSet(BaseModel, Generic[T]):
    """A set of elements paired with the full space it is drawn from.

    The operators ``&``, ``|``, ``-`` and ``^`` accept another ``InvertibleSet``
    or a plain set. ``~`` returns ``full_space - instance``. Every result is a
    new ``InvertibleSet`` that shares this set's ``full_space``, ``space_type``
    and ``element_to_child_space``.
    """

    # Elements currently in the set.
    instance: frozenset[T]
    # Universe the set is drawn from; the complement (``~``) is taken against it.
    full_space: frozenset[T]
    # Type of the elements. Only check_match_space_name compares it; the binary
    # operators below do not.
    space_type: type[T]
    # Optional mapping from element to its "child" set; values are expected to be
    # InvertibleSets (see _cast_to_child_space).
    element_to_child_space: Optional[dict[str, Any]] = None
    # Underscore-prefixed, so pydantic stores it as a private attribute (in
    # ``__pydantic_private__``); exposed through the ``bits_per_value`` property.
    _bits_per_value: Optional[int] = None

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    @model_serializer
    def _serialize_model(self):
        """Custom serializer for InvertibleSet to avoid Pydantic serialization warnings.

        The frozensets are emitted as lists and ``space_type`` as its class name.
        """
        return {
            "instance": list(self.instance),
            "full_space": list(self.full_space),
            "space_type": self.space_type.__name__,
            "element_to_child_space": self.element_to_child_space,
            "_bits_per_value": self._bits_per_value,
        }

    @property
    def bits_per_value(self) -> int:
        """Bits per value of the single element in this set.

        Raises:
            ValueError: If the set does not hold exactly one element, or if no
                value has been assigned.
        """
        if len(self.instance) != 1:
            raise ValueError(
                f"Can not access bits_per_value for a set !=1 elements: "
                f"{self.instance}."
            )
        if self._bits_per_value is None:
            raise ValueError(f"Bits per value is not defined for set {self.instance}.")
        return self._bits_per_value

    @bits_per_value.setter
    def bits_per_value(self, value: int):
        self._bits_per_value = value

    @classmethod
    def _default_private_state(cls) -> dict[str, Any]:
        """Return a fresh ``{name: default}`` mapping of the private attributes."""
        return {
            name: attr.get_default()
            for name, attr in cls.__private_attributes__.items()
        }

    def __reduce__(self):
        """Pickle via ``_reconstruct_invertible_set``, keeping private attributes.

        Pickle passes the third element to ``__setstate__`` after reconstruction.
        Private attributes such as ``_bits_per_value`` are not stored in
        ``__dict__``, so this is what lets them survive a round trip.
        """
        private = getattr(self, "__pydantic_private__", None)
        return (
            _reconstruct_invertible_set,
            (self.__dict__,),
            {"__pydantic_private__": dict(private or {})},
        )

    def __getstate__(self):
        return self.__dict__

    def __setstate__(self, state):
        """Restore field values from ``state`` and repair pydantic's bookkeeping.

        ``state`` maps field names to values. It may also contain a
        ``"__pydantic_private__"`` entry that holds private-attribute values.
        """
        state = dict(state)
        private = state.pop("__pydantic_private__", None)
        self.__dict__.update(state)
        # Instances created with object.__new__ never ran BaseModel.__init__,
        # so these slots may be unset; fill them in.
        if not hasattr(self, "__pydantic_fields_set__"):
            object.__setattr__(self, "__pydantic_fields_set__", set(self.__dict__))
        if not hasattr(self, "__pydantic_extra__"):
            object.__setattr__(self, "__pydantic_extra__", None)
        if getattr(self, "__pydantic_private__", None) is None:
            object.__setattr__(
                self, "__pydantic_private__", type(self)._default_private_state()
            )
        if private:
            self.__pydantic_private__.update(private)

    def __deepcopy__(self, memo):
        """Custom deepcopy implementation to avoid pydantic deepcopy issues.

        Field values are deep-copied straight into a fresh instance, with no
        validation. Pydantic's bookkeeping slots are copied from ``self`` rather
        than reset, so the copy keeps ``_bits_per_value`` and compares equal to
        the source.
        """
        import copy

        cls = type(self)
        # Create a new instance without calling __init__
        new_obj = cls.__new__(cls)
        # Mark it in the memo to handle circular references
        memo[id(self)] = new_obj
        # Deep copy the __dict__ directly to avoid triggering setattr
        new_obj.__dict__.update(copy.deepcopy(self.__dict__, memo))
        # Copy pydantic's internal slots from the source; fall back to empty /
        # default values when the source itself never ran __init__.
        fields_set = getattr(self, "__pydantic_fields_set__", None)
        extra = getattr(self, "__pydantic_extra__", None)
        private = getattr(self, "__pydantic_private__", None)
        object.__setattr__(
            new_obj,
            "__pydantic_fields_set__",
            set(fields_set) if fields_set is not None else set(),
        )
        object.__setattr__(new_obj, "__pydantic_extra__", copy.deepcopy(extra, memo))
        object.__setattr__(
            new_obj,
            "__pydantic_private__",
            copy.deepcopy(private, memo)
            if private is not None
            else cls._default_private_state(),
        )
        return new_obj

    def __repr__(self):
        return f"InvertibleSet({self.instance})"

    def __str__(self):
        return self.__repr__()

    def __invert__(self):
        """Return the complement of this set within ``full_space``."""
        return self.to_my_space(self.full_space - self.instance)

    def check_match_space_name(self, other):
        """Raise ValueError if ``other`` has a different ``space_type``."""
        if self.space_type != other.space_type:
            raise ValueError(
                f"Can not perform set operations between different spaces "
                f"{self.space_type} and {other.space_type}."
            )

    def to_my_space(self, other) -> Union[set, "InvertibleSet"]:
        """Wrap ``other`` (a set or an InvertibleSet's elements) in this space."""
        return InvertibleSet(
            instance=other.instance if isinstance(other, InvertibleSet) else other,
            full_space=self.full_space,
            space_type=self.space_type,
            element_to_child_space=self.element_to_child_space,
        )

    @staticmethod
    def _make_set(x) -> set:
        """Return the raw elements of ``x`` (unwrapping an InvertibleSet)."""
        return x.instance if isinstance(x, InvertibleSet) else x

    def __and__(self, other: "InvertibleSet[T]") -> "InvertibleSet[T]":
        a, b = self._make_set(self), self._make_set(other)
        return self.to_my_space(a & b)

    def __or__(self, other: "InvertibleSet[T]") -> "InvertibleSet[T]":
        a, b = self._make_set(self), self._make_set(other)
        return self.to_my_space(a | b)

    def __sub__(self, other: "InvertibleSet[T]") -> "InvertibleSet[T]":
        a, b = self._make_set(self), self._make_set(other)
        return self.to_my_space(a - b)

    def __xor__(self, other: "InvertibleSet[T]") -> "InvertibleSet[T]":
        a, b = self._make_set(self), self._make_set(other)
        return self.to_my_space(a ^ b)

    def __call__(self):
        """Return ``self``; lets a set stand in for a zero-argument callable."""
        return self

    def _cast_to_child_space(self, *args, **kwargs):
        """Return the union of the child sets of every element in this set.

        The result is placed in the space of the first value of
        ``element_to_child_space``. Extra arguments are accepted and ignored.

        Raises:
            ValueError: If ``full_space`` is empty, ``element_to_child_space``
                is unset or empty, or an element has no entry in it.
        """
        if not self.full_space:
            raise ValueError(f"Full space is empty for set {self.space_type}.")
        # Validate the mapping before the membership loop: with the mapping
        # unset (None), ``item not in None`` would raise TypeError instead.
        if not self.element_to_child_space:
            raise ValueError(
                f"Element to child space is not set for set {self.space_type}."
            )
        for item in self:
            if item not in self.element_to_child_space:
                raise ValueError(
                    f"Item {item} is not in the element_to_child_space "
                    f"for set {self.space_type}."
                )

        first_child_space_item: InvertibleSet = next(
            iter(self.element_to_child_space.values())
        )
        return first_child_space_item.to_my_space(
            set.union(*(set(self.element_to_child_space[item]) for item in self), set())
        )

    def __bool__(self):
        return bool(self.instance)

    def __len__(self):
        return len(self.instance)

    def __contains__(self, item):
        return item in self.instance

    def __iter__(self):
        return iter(self.instance)

    def __getitem__(self, item):
        # NOTE(review): ``instance`` is declared as a frozenset, which does not
        # support indexing, so this raises TypeError for frozenset instances.
        return self.instance[item]

    def iter_one_element_sets(self) -> Iterator["InvertibleSet[T]"]:
        """Yield a single-element InvertibleSet (same space) for each element."""
        for item in self.instance:
            yield InvertibleSet(
                instance=set((item,)),
                full_space=self.full_space,
                space_type=self.space_type,
                element_to_child_space=self.element_to_child_space,
            )

    @property
    def rank_variables(self) -> set["RankVariable"]:
        """For a TensorName set, return the union of its elements' child sets.

        Raises:
            ValueError: If ``space_type`` is not ``TensorName`` (or if the
                child-space lookup fails).
        """
        from accelforge.frontend.workload import RankVariable
        from accelforge.frontend.renames import TensorName

        if self.space_type == TensorName:
            return self._cast_to_child_space()
        raise ValueError(
            f"Can not get rank variables for a set with space type "
            f"{self.space_type.__name__}."
        )

    @property
    def tensors(self) -> set["TensorName"]:
        """Return ``self`` if this is a TensorName set; otherwise raise ValueError."""
        from accelforge.frontend.renames import TensorName

        if self.space_type == TensorName:
            return self
        raise ValueError(
            f"Can not get tensors for a set with space type "
            f"{self.space_type.__name__}."
        )
200
+
201
+
def set_expression_type_check(
    result: InvertibleSet[T],
    expected_space: Optional[type[T]],
    expected_count: int | None = None,
    location: str | None = None,
) -> None:
    """Check that ``result`` is an InvertibleSet with the expected space and size.

    Args:
        result: Value produced by evaluating a set expression.
        expected_space: Required ``space_type``; ``None`` skips this check.
        expected_count: Required number of elements; ``None`` skips this check.
        location: Currently unused by this function.

    Raises:
        TypeError: If ``result`` is not an ``InvertibleSet``.
        ValueError: If the space type or element count does not match.
    """
    if not isinstance(result, InvertibleSet):
        raise TypeError(f"Expected a InvertibleSet, got {type(result)}: {result}")
    if expected_space is not None and result.space_type != expected_space:
        raise ValueError(
            f"Expected a set with space type '{expected_space.__name__}', got {result.space_type.__name__}"
        )
    if expected_count is not None and len(result) != expected_count:
        # Plain ``{expected_count}``: the ``=`` debug specifier used previously
        # rendered the garbled "Expected expected_count=2 elements".
        raise ValueError(
            f"Expected {expected_count} elements, got {len(result)}: {result.instance}"
        )
218
+
219
+
def eval_set_expression(
    expression: str | InvertibleSet,
    symbol_table: dict[str, InvertibleSet],
    expected_space: type[T],
    location: str,
    expected_count: int | None = None,
) -> InvertibleSet[T]:
    """Evaluate a set expression to an InvertibleSet and validate the result.

    A string expression is resolved in this order:
      1. an exact key of ``symbol_table``;
      2. ``"name()"`` where ``name`` is a key, in which case that entry is called;
      3. otherwise ``eval`` with ``symbol_table`` as locals and ``MATH_FUNCS``
         as the only builtins.
    An ``InvertibleSet`` argument is used as-is.

    Args:
        expression: Expression string, or an already-built InvertibleSet.
        symbol_table: Names visible to the expression.
        expected_space: Required ``space_type`` of the result (``None`` skips it).
        location: Field path attached to the error via ``add_field`` on failure.
        expected_count: Required number of elements, or ``None`` for any.

    Returns:
        The validated InvertibleSet.

    Raises:
        ParseError: Wraps any error raised while resolving, evaluating or
            validating the expression. The message includes the expression and a
            truncated dump of the symbol table.
    """
    # Unique "not resolved yet" marker, compared by identity.
    not_found = object()
    err = None
    try:
        if not isinstance(expression, (InvertibleSet, str)):
            raise TypeError(f"Expected a string, got {type(expression)}: {expression}")

        if isinstance(expression, str):
            result = not_found
            if expression in symbol_table:
                result = symbol_table[expression]
            elif expression.endswith("()") and expression[:-2] in symbol_table:
                try:
                    result = symbol_table[expression[:-2]]()
                except Exception:
                    # Best effort: leave unresolved so eval() below retries the
                    # call and any error is reported with full context.
                    pass
        else:
            result = expression

        if result is not_found:
            # SECURITY: a restricted __builtins__ is not a sandbox; expressions
            # must come from trusted specification files.
            result = eval(expression, {"__builtins__": MATH_FUNCS}, symbol_table)

        if not isinstance(result, InvertibleSet):
            raise TypeError(
                f"Returned a non-InvertibleSet with type {type(result)}: {result}"
            )
        set_expression_type_check(result, expected_space, expected_count, location)

    except Exception as e:

        def strformat(v):
            v = str(v)
            return v if len(v) <= 100 else v[:100] + "..."

        err = ParseError(
            f'{e}. Set expression: "{expression}". Symbol table:\n\t'
            + "\n\t".join(f"{k}: {strformat(v)}" for k, v in symbol_table.items())
        )
        if location is not None:
            err.add_field(location)
    # Raised outside the except block; ``is not None`` avoids relying on the
    # exception object's truthiness.
    if err is not None:
        raise err
    return result
File without changes
@@ -0,0 +1,18 @@
1
+ import sympy
2
+ from sympy import Max
3
+ from sympy import Min
4
+
5
+ # MAX BUG FIX.
6
+ # def Min(a, *bs):
7
+ # """More post-lambdify broadcast-friendly option than sympy.Min"""
8
+ # result = a
9
+ # for b in bs:
10
+ # result = sympy.Piecewise((result, result < b), (b, True))
11
+ # return result
12
+
13
+ # def Max(a, *bs):
14
+ # """More post-lambdify broadcast-friendly option than sympy.Max"""
15
+ # result = a
16
+ # for b in bs:
17
+ # result = sympy.Piecewise((result, result > b), (b, True))
18
+ # return result
@@ -0,0 +1,112 @@
1
+ import pydot
2
+
3
+
4
+ def _pydot_graph() -> pydot.Dot:
5
+ graph = pydot.Dot(graph_type="graph", rankdir="TD", ranksep=0.2)
6
+ graph.set_node_defaults(shape="box", fontname="Arial", fontsize="12")
7
+ graph.set_edge_defaults(fontname="Arial", fontsize="10")
8
+ return graph
9
+
10
+
11
+ # =============================================================================
12
+ # Color Map for Visualization
13
+ # =============================================================================
14
+
15
+
class ColorMap:
    """Assign a border color to each key and render items as an HTML-like label.

    Keys are paired, in order, with colors from ``_make_color_map``.
    ``format_list`` produces a single-row Graphviz HTML table. In that table,
    every item that is one of the keys gets a colored border.
    """

    def __init__(self, keys: list[str]):
        self.keys = keys
        self.color_list = self._make_color_map(len(keys))
        # The i-th key gets the i-th color.
        self.color_map = dict(zip(keys, self.color_list))

    def format_list(self, items: list[str]) -> str:
        """Render ``items`` as one row of a borderless Graphviz HTML table.

        Known keys get a ``BORDER="5"`` cell in their color; other items get a
        plain centered cell. Item text is inserted as-is (not escaped).
        """
        cells = []
        for entry in items:
            if entry in self.color_map:
                open_tag = (
                    f'<TD ALIGN="CENTER" BORDER="5" COLOR="{self.color_map[entry]}">'
                )
            else:
                open_tag = '<TD ALIGN="CENTER">'
            cells.append(f"{open_tag}{entry}</TD>")
        return (
            '<<TABLE BORDER="0" CELLBORDER="0" CELLSPACING="0"><TR>'
            + "".join(cells)
            + "</TR></TABLE>>"
        )

    def _make_color_map(self, n_colors: int) -> list[str]:
        """Return ``n_colors`` hex color strings (``[]`` when ``n_colors <= 0``).

        The first twelve come from a fixed high-contrast palette. Any beyond that
        are fully saturated, full-value HSV colors with hue ``i / n_colors`` for
        ``i = 12 .. n_colors - 1``, written as lowercase ``#rrggbb``.
        """
        if n_colors <= 0:
            return []

        palette = [
            "#FF0000",  # Red
            "#00FF00",  # Green
            "#0000FF",  # Blue
            "#FFFF00",  # Yellow
            "#FF00FF",  # Magenta
            "#00FFFF",  # Cyan
            "#FF8000",  # Orange
            "#8000FF",  # Purple
            "#008000",  # Dark Green
            "#800000",  # Dark Red
            "#000080",  # Dark Blue
            "#808000",  # Olive
        ]
        if n_colors <= len(palette):
            return palette[:n_colors]

        saturation = 1.0
        value = 1.0
        chroma = value * saturation
        offset = value - chroma

        generated = []
        for idx in range(len(palette), n_colors):
            sector_pos = (idx / n_colors) * 6
            second = chroma * (1 - abs(sector_pos % 2 - 1))
            # RGB ordering for each 60-degree hue sector (standard HSV -> RGB).
            by_sector = (
                (chroma, second, 0),
                (second, chroma, 0),
                (0, chroma, second),
                (0, second, chroma),
                (second, 0, chroma),
                (chroma, 0, second),
            )
            channels = by_sector[min(int(sector_pos), 5)]
            generated.append(
                "#" + "".join(f"{int((ch + offset) * 255):02x}" for ch in channels)
            )
        return palette + generated