accelforge 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (258) hide show
  1. accelforge/__init__.py +21 -0
  2. accelforge/_accelerated_imports.py +16 -0
  3. accelforge/_deprecate/_simanneal/evalmapping.py +271 -0
  4. accelforge/_deprecate/_simanneal/mapspaceglobals.py +298 -0
  5. accelforge/_deprecate/_simanneal/simanneal.py +666 -0
  6. accelforge/_deprecate/_simanneal/tracking.py +105 -0
  7. accelforge/_deprecate/_simanneal/wrappers.py +218 -0
  8. accelforge/_deprecate/_simanneal2/__init__.py +7 -0
  9. accelforge/_deprecate/_simanneal2/simanneal.py +493 -0
  10. accelforge/_deprecate/_simanneal2/tracking.py +116 -0
  11. accelforge/_deprecate/compatibility_util.py +181 -0
  12. accelforge/_deprecate/layerdeduplication/__init__.py +2 -0
  13. accelforge/_deprecate/layerdeduplication/group_similar_einsums.py +160 -0
  14. accelforge/_deprecate/layerdeduplication/grouped_einsums.py +84 -0
  15. accelforge/_deprecate/mapping_filter_tags/__init__.py +2 -0
  16. accelforge/_deprecate/mapping_filter_tags/ffmt.py +212 -0
  17. accelforge/_deprecate/mapping_filter_tags/onesplit.py +24 -0
  18. accelforge/_deprecate/mapping_filter_tags/util.py +24 -0
  19. accelforge/_deprecate/tags.py +69 -0
  20. accelforge/_deprecate/viz/__init__.py +0 -0
  21. accelforge/_deprecate/viz/interactive.py +159 -0
  22. accelforge/_deprecate/viz/reservationtree.py +307 -0
  23. accelforge/_deprecate/viz/ski_slope.py +88 -0
  24. accelforge/_version.py +15 -0
  25. accelforge/examples.py +39 -0
  26. accelforge/frontend/__init__.py +10 -0
  27. accelforge/frontend/_binding.py +129 -0
  28. accelforge/frontend/_workload_isl/__init__.py +2 -0
  29. accelforge/frontend/_workload_isl/_isl.py +149 -0
  30. accelforge/frontend/_workload_isl/_symbolic.py +141 -0
  31. accelforge/frontend/arch copy.py +1544 -0
  32. accelforge/frontend/arch.py +1642 -0
  33. accelforge/frontend/config.py +63 -0
  34. accelforge/frontend/mapper/__init__.py +5 -0
  35. accelforge/frontend/mapper/ffm.py +126 -0
  36. accelforge/frontend/mapper/mapper.py +7 -0
  37. accelforge/frontend/mapper/metrics.py +30 -0
  38. accelforge/frontend/mapping/__init__.py +1 -0
  39. accelforge/frontend/mapping/mapping.py +1736 -0
  40. accelforge/frontend/model.py +14 -0
  41. accelforge/frontend/renames.py +150 -0
  42. accelforge/frontend/spec copy.py +230 -0
  43. accelforge/frontend/spec.py +301 -0
  44. accelforge/frontend/variables.py +12 -0
  45. accelforge/frontend/workload.py +952 -0
  46. accelforge/mapper/FFM/__init__.py +9 -0
  47. accelforge/mapper/FFM/_join_pmappings/__init__.py +0 -0
  48. accelforge/mapper/FFM/_join_pmappings/compatibility.py +653 -0
  49. accelforge/mapper/FFM/_join_pmappings/compress_pmappings.py +140 -0
  50. accelforge/mapper/FFM/_join_pmappings/join_pmappings.py +703 -0
  51. accelforge/mapper/FFM/_join_pmappings/pmapping_dataframe.py +901 -0
  52. accelforge/mapper/FFM/_join_pmappings/pmapping_group.py +337 -0
  53. accelforge/mapper/FFM/_make_pmappings/contraints/__init__.py +0 -0
  54. accelforge/mapper/FFM/_make_pmappings/contraints/constraints.py +360 -0
  55. accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/__init__.py +1 -0
  56. accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_loops.py +373 -0
  57. accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_pmapping_templates.py +463 -0
  58. accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_reservations.py +95 -0
  59. accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_storage_order.py +382 -0
  60. accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_storages.py +155 -0
  61. accelforge/mapper/FFM/_make_pmappings/make_pmappings.py +411 -0
  62. accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/__init__.py +1 -0
  63. accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/make_pmappings_from_templates.py +407 -0
  64. accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/make_tile_shapes.py +1681 -0
  65. accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/run_model.py +170 -0
  66. accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/symbol_relations.py +174 -0
  67. accelforge/mapper/FFM/_make_pmappings/pmapper_job.py +282 -0
  68. accelforge/mapper/FFM/_pareto_df/df_convention.py +273 -0
  69. accelforge/mapper/FFM/_pareto_df/pareto copy.py +836 -0
  70. accelforge/mapper/FFM/_pareto_df/pareto.py +508 -0
  71. accelforge/mapper/FFM/data.py +61 -0
  72. accelforge/mapper/FFM/main copy.py +236 -0
  73. accelforge/mapper/FFM/main.py +208 -0
  74. accelforge/mapper/FFM/mappings.py +510 -0
  75. accelforge/mapper/FFM/pmappings.py +310 -0
  76. accelforge/mapper/__init__.py +4 -0
  77. accelforge/mapper.py +0 -0
  78. accelforge/model/__init__.py +1 -0
  79. accelforge/model/_looptree/__init__.py +0 -0
  80. accelforge/model/_looptree/accesses.py +335 -0
  81. accelforge/model/_looptree/capacity/__init__.py +1 -0
  82. accelforge/model/_looptree/capacity/aggregators.py +36 -0
  83. accelforge/model/_looptree/capacity/capacity.py +47 -0
  84. accelforge/model/_looptree/energy.py +150 -0
  85. accelforge/model/_looptree/equivalent_ranks.py +29 -0
  86. accelforge/model/_looptree/latency/__init__.py +1 -0
  87. accelforge/model/_looptree/latency/latency.py +98 -0
  88. accelforge/model/_looptree/latency/memory.py +120 -0
  89. accelforge/model/_looptree/latency/processors.py +92 -0
  90. accelforge/model/_looptree/mapping_utilities.py +71 -0
  91. accelforge/model/_looptree/reuse/__init__.py +4 -0
  92. accelforge/model/_looptree/reuse/isl/__init__.py +1 -0
  93. accelforge/model/_looptree/reuse/isl/des.py +59 -0
  94. accelforge/model/_looptree/reuse/isl/isl_functions.py +374 -0
  95. accelforge/model/_looptree/reuse/isl/mapping_to_isl/__init__.py +4 -0
  96. accelforge/model/_looptree/reuse/isl/mapping_to_isl/analyze_mapping.py +297 -0
  97. accelforge/model/_looptree/reuse/isl/mapping_to_isl/skews_from_mapping.py +236 -0
  98. accelforge/model/_looptree/reuse/isl/mapping_to_isl/tiling.py +685 -0
  99. accelforge/model/_looptree/reuse/isl/mapping_to_isl/types.py +188 -0
  100. accelforge/model/_looptree/reuse/isl/spatial.py +260 -0
  101. accelforge/model/_looptree/reuse/isl/temporal.py +182 -0
  102. accelforge/model/_looptree/reuse/symbolic/__init__.py +1 -0
  103. accelforge/model/_looptree/reuse/symbolic/symbolic copy 2.py +1346 -0
  104. accelforge/model/_looptree/reuse/symbolic/symbolic copy.py +1408 -0
  105. accelforge/model/_looptree/reuse/symbolic/symbolic.py +1396 -0
  106. accelforge/model/_looptree/run.py +122 -0
  107. accelforge/model/_looptree/types.py +26 -0
  108. accelforge/model/_looptree/visualization/__init__.py +0 -0
  109. accelforge/model/_looptree/visualization/occupancy.py +11 -0
  110. accelforge/model/main.py +222 -0
  111. accelforge/plotting/__init__.py +2 -0
  112. accelforge/plotting/mappings.py +219 -0
  113. accelforge/plotting/specs.py +57 -0
  114. accelforge/util/__init__.py +4 -0
  115. accelforge/util/_base_analysis_types.py +24 -0
  116. accelforge/util/_basetypes.py +1089 -0
  117. accelforge/util/_frozenset.py +36 -0
  118. accelforge/util/_isl.py +29 -0
  119. accelforge/util/_itertools.py +14 -0
  120. accelforge/util/_mathfuncs.py +57 -0
  121. accelforge/util/_parse_expressions.py +339 -0
  122. accelforge/util/_picklecache.py +32 -0
  123. accelforge/util/_setexpressions.py +268 -0
  124. accelforge/util/_sympy/__init__.py +0 -0
  125. accelforge/util/_sympy/broadcast_max.py +18 -0
  126. accelforge/util/_visualization.py +112 -0
  127. accelforge/util/_yaml.py +579 -0
  128. accelforge/util/parallel.py +193 -0
  129. accelforge-0.0.1.dist-info/METADATA +64 -0
  130. accelforge-0.0.1.dist-info/RECORD +258 -0
  131. accelforge-0.0.1.dist-info/WHEEL +5 -0
  132. accelforge-0.0.1.dist-info/licenses/LICENSE +19 -0
  133. accelforge-0.0.1.dist-info/top_level.txt +5 -0
  134. docs/_build/html/_sources/fastfusion.frontend.mapper.rst.txt +37 -0
  135. docs/_build/html/_sources/fastfusion.frontend.rst.txt +70 -0
  136. docs/_build/html/_sources/fastfusion.frontend.workload.rst.txt +21 -0
  137. docs/_build/html/_sources/fastfusion.mapper.FFM.rst.txt +37 -0
  138. docs/_build/html/_sources/fastfusion.mapper.rst.txt +18 -0
  139. docs/_build/html/_sources/fastfusion.rst.txt +20 -0
  140. docs/_build/html/_sources/fastfusion.util.rst.txt +21 -0
  141. docs/_build/html/_sources/index.rst.txt +87 -0
  142. docs/_build/html/_sources/modules.rst.txt +7 -0
  143. docs/_build/html/_sources/notes/citation.rst.txt +45 -0
  144. docs/_build/html/_sources/notes/definitions.rst.txt +43 -0
  145. docs/_build/html/_sources/notes/faqs.rst.txt +39 -0
  146. docs/_build/html/_sources/notes/modeling/accelerator_energy_latency.rst.txt +72 -0
  147. docs/_build/html/_sources/notes/modeling/component_energy_area.rst.txt +96 -0
  148. docs/_build/html/_sources/notes/modeling/mapping.rst.txt +100 -0
  149. docs/_build/html/_sources/notes/modeling.rst.txt +33 -0
  150. docs/_build/html/_sources/notes/parsing/arithmetic_parsing.rst.txt +136 -0
  151. docs/_build/html/_sources/notes/parsing/setexpressions.rst.txt +63 -0
  152. docs/_build/html/_sources/notes/parsing/yaml_parsing.rst.txt +176 -0
  153. docs/_build/html/_sources/notes/quickstart_and_installation.rst.txt +9 -0
  154. docs/_build/html/_sources/notes/spec/architecture.rst.txt +133 -0
  155. docs/_build/html/_sources/notes/spec/mapping.rst.txt +12 -0
  156. docs/_build/html/_sources/notes/spec/workload.rst.txt +83 -0
  157. docs/_build/html/_sources/notes/spec.rst.txt +36 -0
  158. docs/source/_ext/include_attrs.py +213 -0
  159. docs/source/_ext/include_docstring.py +364 -0
  160. docs/source/_ext/include_functions.py +154 -0
  161. docs/source/_ext/include_notebook.py +131 -0
  162. docs/source/_ext/include_yaml.py +119 -0
  163. docs/source/_ext/inherited_attributes.py +222 -0
  164. docs/source/_ext/paths.py +4 -0
  165. docs/source/conf.py +79 -0
  166. examples/arches/compute_in_memory/_include.yaml +74 -0
  167. examples/arches/compute_in_memory/_include_functions.py +229 -0
  168. examples/arches/compute_in_memory/_load_spec.py +57 -0
  169. examples/arches/compute_in_memory/components/c2c_multiplier.py +181 -0
  170. examples/arches/compute_in_memory/components/dac_c2c_r2r.py +605 -0
  171. examples/arches/compute_in_memory/components/misc.py +195 -0
  172. examples/arches/compute_in_memory/components/util/bit_functions.py +51 -0
  173. examples/arches/compute_in_memory/components/zero_comparator.py +92 -0
  174. examples/arches/compute_in_memory/isaac.yaml +233 -0
  175. examples/arches/compute_in_memory/memory_cells/ecram_demo.yaml +63 -0
  176. examples/arches/compute_in_memory/memory_cells/rram_example.yaml +63 -0
  177. examples/arches/compute_in_memory/memory_cells/rram_isaac_isca_2016.yaml +64 -0
  178. examples/arches/compute_in_memory/memory_cells/rram_neurosim_default.yaml +63 -0
  179. examples/arches/compute_in_memory/memory_cells/rram_raella_isca_2023.yaml +70 -0
  180. examples/arches/compute_in_memory/memory_cells/rram_wan_nature_2022.yaml +63 -0
  181. examples/arches/compute_in_memory/memory_cells/sram_colonnade_jssc_2021.yaml +63 -0
  182. examples/arches/compute_in_memory/memory_cells/sram_example.yaml +63 -0
  183. examples/arches/compute_in_memory/memory_cells/sram_jia_jssc_2020.yaml +63 -0
  184. examples/arches/compute_in_memory/memory_cells/sram_sinangil_jssc_2021.yaml +63 -0
  185. examples/arches/compute_in_memory/memory_cells/sram_wang_vlsi_2022.yaml +63 -0
  186. examples/arches/compute_in_memory/wang_vlsi_2022.yaml +289 -0
  187. examples/arches/eyeriss.yaml +68 -0
  188. examples/arches/fanout_variations/at_glb.yaml +31 -0
  189. examples/arches/fanout_variations/at_glb_with_fanout_node.yaml +34 -0
  190. examples/arches/fanout_variations/at_mac.yaml +31 -0
  191. examples/arches/fanout_variations/at_mac_with_constraints.yaml +38 -0
  192. examples/arches/fanout_variations/at_mac_with_fanout_node.yaml +34 -0
  193. examples/arches/nvdla.yaml +47 -0
  194. examples/arches/simple.yaml +28 -0
  195. examples/arches/tpu_v4i.yaml +67 -0
  196. examples/mappings/unfused_matmuls_to_simple.yaml +33 -0
  197. examples/misc/component_annotated.yaml +33 -0
  198. examples/workloads/gpt3_6.7B.yaml +124 -0
  199. examples/workloads/matmuls.yaml +20 -0
  200. examples/workloads/mobilenet_28.yaml +81 -0
  201. examples/workloads/mobilenet_various_separate.yaml +106 -0
  202. examples/workloads/three_matmuls_annotated.yaml +59 -0
  203. notebooks/.ipynb_checkpoints/fastfusion_arch_study_michael-checkpoint.ipynb +359 -0
  204. notebooks/compute_in_memory/_scripts.py +339 -0
  205. notebooks/compute_in_memory/isaac.guide.ipynb +270 -0
  206. notebooks/compute_in_memory/wang_vlsi_2022.ipynb +602 -0
  207. notebooks/paths.py +4 -0
  208. notebooks/tutorials/.ipynb_checkpoints/1_FFM-checkpoint.ipynb +3110 -0
  209. notebooks/tutorials/FFM.ipynb +3498 -0
  210. notebooks/tutorials/_include.py +48 -0
  211. notebooks/tutorials/component_energy_area.ipynb +363 -0
  212. tests/Q_mapping.yaml +38 -0
  213. tests/__init__.py +0 -0
  214. tests/conv.mapping.yaml +27 -0
  215. tests/conv.workload.yaml +13 -0
  216. tests/conv_sym.mapping.yaml +43 -0
  217. tests/copy.mapping.yaml +35 -0
  218. tests/copy.workload.yaml +15 -0
  219. tests/distribuffers/__init__.py +0 -0
  220. tests/distribuffers/multicast/test_cases.yaml +482 -0
  221. tests/distribuffers/spec/binding/valid_bindings.yaml +97 -0
  222. tests/distribuffers/spec/distributed.yaml +100 -0
  223. tests/distribuffers/spec/logical_arch.yaml +32 -0
  224. tests/distribuffers/spec/physical_arch.yaml +69 -0
  225. tests/distribuffers/test_binding.py +48 -0
  226. tests/frontend/__init__.py +0 -0
  227. tests/frontend/test_mapping_viz.py +52 -0
  228. tests/mapper/__init__.py +0 -0
  229. tests/mapper/configs/conv1d/conv1d.mapping.yaml +31 -0
  230. tests/mapper/configs/conv1d/conv1d.workload.yaml +11 -0
  231. tests/mapper/configs/two_conv1d/two_conv1d.expected.yaml +38 -0
  232. tests/mapper/configs/two_conv1d/two_conv1d.mapping.yaml +54 -0
  233. tests/mapper/configs/two_conv1d/two_conv1d.workload.yaml +19 -0
  234. tests/mapper/test_mapping_to_isl.py +90 -0
  235. tests/mapper/test_spatial_reuse_analysis.py +67 -0
  236. tests/mapper/test_temporal_reuse_analysis.py +56 -0
  237. tests/mapper/util.py +58 -0
  238. tests/matmul.mapping.yaml +29 -0
  239. tests/matmul.workload.yaml +12 -0
  240. tests/matmul_spatial.mapping.yaml +44 -0
  241. tests/mha.renames.yaml +65 -0
  242. tests/mha.workload.yaml +67 -0
  243. tests/mha.yaml +59 -0
  244. tests/mha_full.workload.yaml +67 -0
  245. tests/mobilenet.workload.yaml +35 -0
  246. tests/mobilenet_long.workload.yaml +64 -0
  247. tests/pmappingcache.py +24 -0
  248. tests/processing_stage.arch.yaml +40 -0
  249. tests/snowcat.arch.yaml +36 -0
  250. tests/test_ffm_join_pmappings.py +106 -0
  251. tests/test_ffm_make_pmappings.py +82 -0
  252. tests/test_ffm_make_tile_shapes.py +49 -0
  253. tests/test_mapper.py +100 -0
  254. tests/test_model.py +37 -0
  255. tests/test_plotting.py +72 -0
  256. tests/test_processing_stage.py +46 -0
  257. tests/test_symbolic_model.py +248 -0
  258. tests/test_workload.py +141 -0
@@ -0,0 +1,579 @@
1
+ import copy
2
+ import functools
3
+ import logging
4
+ import os
5
+ import glob
6
+ import re
7
+ import io
8
+ from typing import Callable, List, Dict, Any, OrderedDict, Tuple
9
+ import ruamel.yaml
10
+ import warnings
11
+ from ruamel.yaml.error import ReusedAnchorWarning
12
+ from jinja2 import StrictUndefined, Environment, FileSystemLoader
13
+ import threading
14
+ import time
15
+
16
+
17
# Module-wide lock serialising YAML loading/dumping (see LockAcquirer).
PARSING_LOCK = threading.Lock()
# threading.get_ident() of the thread currently holding PARSING_LOCK, or 0.
THREAD_ID = 0

# NOTE(review): not referenced in this module; presumably read or extended by
# other modules -- confirm before changing.
SCRIPTS_FROM = []
EXTRA_PLUG_IN_PATHS = []


class LockAcquirer:
    """
    Re-entrant context manager around the module-wide ``PARSING_LOCK``.

    The outermost ``with LockAcquirer():`` in a thread acquires the lock.
    Nested uses in the same thread enter without blocking, which happens
    e.g. when ``load_yaml`` recurses into included files. Only the instance
    that actually acquired the lock releases it.
    """

    def __init__(self):
        self.has_lock = False

    def __enter__(self):
        global THREAD_ID
        # THREAD_ID equals our ident only while we hold the lock: it is set
        # right after acquiring and cleared right before releasing. A match
        # therefore means we are re-entering, never a stale leftover.
        if THREAD_ID == threading.get_ident():
            return
        PARSING_LOCK.acquire()
        THREAD_ID = threading.get_ident()
        self.has_lock = True

    def __exit__(self, exc_type, exc_value, traceback):
        global THREAD_ID
        if self.has_lock:
            # Clear the owner *before* releasing. Otherwise a thread that last
            # held the lock could match its own stale id while another thread
            # holds the lock, and enter concurrently.
            THREAD_ID = 0
            self.has_lock = False
            PARSING_LOCK.release()
42
+
43
+
44
def recursive_mutator_stop(func, enabled: bool = False):
    """
    Decorator guarding a recursive walker against revisiting the same object.

    The guard is opt-in. By default ``func`` is returned unchanged.

    With ``enabled=True``, the wrapper remembers ``id(args[0])`` for the
    duration of the call. If ``func`` re-enters with that same object, the
    wrapper returns the object instead of recursing again. This stops
    infinite loops on self-referential containers.

    :param func: function whose first positional argument is the object
        being walked. When guarded, it must be called without keyword
        arguments.
    :param enabled: whether to install the guard.
    :return: ``func`` itself, or the guarding wrapper.
    """
    if not enabled:
        return func

    active = set()

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        assert not kwargs, (
            f"Recursive mutator stop only works with non-keyword arguments. "
            f"Args were {args} and kwargs were {kwargs}."
        )
        key = id(args[0])
        if key in active:
            return args[0]
        active.add(key)
        try:
            return func(*args, **kwargs)
        finally:
            active.remove(key)

    return wrapper
65
+
66
+
67
def recursive_mutator_eq_stop(func, enabled: bool = False):
    """
    Decorator that detects a function re-entering itself with an equal first
    argument, e.g. a YAML file that (transitively) includes itself.

    The guard is opt-in. By default ``func`` is returned unchanged.

    With ``enabled=True``, a nested call whose first argument is already
    being processed raises ``RuntimeError`` instead of recursing forever.

    :param func: function whose first positional argument is hashable.
    :param enabled: whether to install the guard.
    :return: ``func`` itself, or the guarding wrapper.
    :raises RuntimeError: (guarded only) on recursive re-entry.
    """
    if not enabled:
        return func

    active = set()

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        key = args[0]
        if key in active:
            failstr = (
                f"Infinite recursion detected: {func.__name__}("
                f"{args[0]}) called from within itself. Does a YAML file "
                f"include itself?"
            )
            raise RuntimeError(failstr)
        active.add(key)
        try:
            return func(*args, **kwargs)
        finally:
            active.remove(key)

    return wrapper
88
+
89
+
90
class MultiIncludeWrapper:
    """
    Holds the parsed documents of an ``include_all`` call that matched more
    than one file.

    ``!include_loaded`` lookups are applied to each document in ``contents``.
    """

    def __init__(self, contents: List):
        # One parsed document per matched file, in match order.
        self.contents: List = contents
93
+
94
+
95
def append_path(p: str, cur_path: str, include_dirs: List[str]):
    """
    Extend ``include_dirs`` in place with every path matched by ``p``.

    Exposed to templates as ``add_to_path``. It returns an empty string so
    the call renders to nothing.

    :param p: path or glob pattern, resolved like any include path
    :param cur_path: file whose directory anchors relative patterns
    :param include_dirs: search path to extend (mutated)
    :return: ``""``
    """
    resolved = find_paths(p, cur_path, include_dirs)
    include_dirs.extend(resolved)
    logging.info(f"YAML Adding {resolved} to include paths")
    return ""
100
+
101
+
102
def find_paths(p: str, cur_path: str, include_dirs: List[str]):
    """
    Resolve a path or glob pattern, or a list of them, to unique real paths.

    Absolute patterns are used as-is. Relative patterns are tried against
    the directory of ``cur_path`` and then against each entry of
    ``include_dirs``; matches from every location are collected.

    :param p: path/glob pattern, or a (possibly nested) list of them. List
        results are flattened in order.
    :param cur_path: file whose directory anchors relative patterns
    :param include_dirs: extra directories to search
    :return: de-duplicated real paths, in first-seen order. An empty list
        input yields ``[]``.
    :raises FileNotFoundError: if a single pattern matches nothing
    """
    if isinstance(p, list):
        # Flatten instead of returning a list of lists, which the
        # de-duplication below cannot normalise.
        paths = [match for entry in p for match in find_paths(entry, cur_path, include_dirs)]
    else:
        searched = []
        paths = []
        prepend = (
            [""] if os.path.isabs(p) else [os.path.dirname(cur_path)] + include_dirs
        )
        for d in prepend:
            s = os.path.abspath(os.path.realpath(os.path.join(d, p)))
            paths.extend(glob.glob(s))
            searched.append(s)
        if not paths:
            raise FileNotFoundError(
                f"Could not find file {p} in any of the following paths:"
                + "\n "
                + "\n ".join(searched)
            )

    unique_paths = []
    seen = set()
    for found in paths:
        found = os.path.realpath(os.path.abspath(found))
        if found not in seen:
            unique_paths.append(found)
            seen.add(found)
    return unique_paths
133
+
134
+
135
def find_path(p: str, cur_path: str, include_dirs: List[str]):
    """
    Return the first existing real path for ``p``.

    Absolute paths are checked directly. Relative paths are tried against
    the directory of ``cur_path`` and then against each of ``include_dirs``.

    :raises FileNotFoundError: listing every candidate that was tried
    """
    if os.path.isabs(p):
        bases = [""]
    else:
        bases = [os.path.dirname(cur_path), *include_dirs]

    tried = []
    for base in bases:
        candidate = os.path.abspath(os.path.realpath(os.path.join(base, p)))
        if os.path.exists(candidate):
            return candidate
        tried.append(candidate)

    raise FileNotFoundError(
        f"Could not find file {p} in any of the following paths:"
        + "\n "
        + "\n ".join(tried)
    )
148
+
149
+
150
@recursive_mutator_eq_stop
def load_file_and_includes(
    path: str,
    data: Dict[str, Any] = None,
    include_dirs: List[str] = None,
) -> Tuple[str, Dict[str, Any], List[str]]:
    """
    Render a YAML file as a Jinja2 template, recursively loading included files.

    The template can call these helpers:

    - ``include`` / ``include_all``: load YAML files and reference them via
      ``!include_loaded`` tags.
    - ``include_text``: paste another file's rendered text.
    - ``find_path``, ``find_paths``, ``path_exists``, ``cwd``: path helpers.
    - ``setenv``: set an environment variable.
    - ``add_to_path``: extend the include search path.

    :param path: path of the YAML file to render
    :param data: template variables. Copied, so the caller's dict is not
        modified.
    :param include_dirs: extra directories searched for included files.
        Copied, so the caller's list is not modified.
    :return: ``(rendered_text, data, include_dirs)``. ``data`` is the copied
        template data plus one entry per ``include``/``include_all`` call.
        ``include_dirs`` is the search path after any ``add_to_path`` calls.
    :raises FileNotFoundError: if ``path`` or an included file does not exist
    :raises RuntimeError: if ``include`` matches more than one file
    """
    path = os.path.abspath(os.path.realpath(path))
    if not os.path.exists(path):
        raise FileNotFoundError(f"Could not find file {path}")

    data = {k: v for k, v in (data or {}).items()}
    include_dirs = list(include_dirs or [])

    include_counter = 0

    def include(p, single, indices: str = ""):
        # Load every file matching p and stash the result in data under a
        # unique name. The template receives an !include_loaded tag naming
        # that entry, optionally followed by dotted keys.
        nonlocal include_counter
        indices = indices.lstrip(".")
        include_name = (
            os.path.basename(path).rsplit(".", 1)[0] + "_" + str(include_counter)
        )
        # Strip non-word characters (including ".") so the name contains no
        # dots that would be split as key separators.
        include_name = re.sub(r"\W+", "", include_name)
        found_paths = find_paths(p, path, include_dirs)
        to_include = []
        for np in found_paths:
            logging.info(
                f"YAML Adding {np} to document with !include{'_all' if not single else ''}"
            )
            to_include.append(load_yaml(np, data, include_dirs))

        if single and len(to_include) > 1:
            raise RuntimeError(
                f"More than one file found for {p}: {found_paths}. "
                f"To include multiple files, use include_all()."
            )

        data[include_name] = (
            to_include[0] if len(to_include) == 1 else MultiIncludeWrapper(to_include)
        )
        v = f"!include_loaded {include_name}"
        if indices:
            v += "." + indices
        include_counter += 1
        return v

    def include_single(p, indices: str = ""):
        return include(p, True, indices)

    def include_all(p, indices: str = ""):
        return include(p, False, indices)

    def include_text(p):
        found = []
        for np in find_paths(p, path, include_dirs):
            found.append(load_file_and_includes(np, data, include_dirs)[0])
            logging.info(f"YAML Adding {np} to document with !include_text")
        return "\n".join(found)

    env = Environment(
        loader=FileSystemLoader(os.path.dirname(path)), undefined=StrictUndefined
    )

    def setenv(key, value):
        key, value = str(key), str(value)
        os.environ[key] = value
        # Re-emit the call into the rendered text.
        # NOTE(review): presumably so the assignment is replayed if this text
        # is rendered again -- confirm. The closing "}}" must match the
        # opening "{{".
        return "{{ setenv('" + key + "', '" + value + "') }}"

    def path_exists(p):
        try:
            find_path(p, path, include_dirs)
        except FileNotFoundError:
            return False
        return True

    env.globals["cwd"] = lambda: os.path.dirname(path)
    env.globals["include"] = include_single
    env.globals["include_all"] = include_all
    env.globals["include_text"] = include_text
    env.globals["find_path"] = lambda x: find_path(x, path, include_dirs)
    env.globals["find_paths"] = lambda x: find_paths(x, path, include_dirs)
    env.globals["path_exists"] = path_exists
    env.globals["setenv"] = setenv
    env.globals["add_to_path"] = lambda p: append_path(p, path, include_dirs)

    # The loader is rooted at the file's directory, so the template name is
    # the bare file name. Slicing off len(dirname) + 1 characters would drop
    # a character for files directly under "/".
    template = env.get_template(os.path.basename(path))
    string = template.render(data, undefined=StrictUndefined)
    return string, data, include_dirs
250
+
251
+
252
@recursive_mutator_stop
def merge_check(x: dict[str, Any] | list[Any] | Any) -> dict[str, Any] | list[Any] | Any:
    """
    Recursively resolve "<<<" and "<<" merge keys in parsed YAML data.

    - "<<<" merges recursively via ``merge``: nested dicts are merged and
      lists are concatenated.
    - "<<" only adds keys that are missing.

    The merge source is deep-copied before merging, so it is not aliased
    into ``x``.

    :param x: parsed YAML value. Lists and dicts are updated in place.
    :return: the resolved value (the same list/dict object for containers)
    :raises AssertionError: if a dict has more than one merge key
    """
    if isinstance(x, list):
        for i, v in enumerate(x):
            x[i] = merge_check(v)
        return x
    if not isinstance(x, dict):
        return x

    merge_keys = [k for k in x if str(k) in ("<<<", "<<")]
    assert len(merge_keys) <= 1, (
        f'Cannot have multiple "<<<" or "<<" keys in a dict. '
        f"Keys were {list(x.keys())}"
    )
    # Resolve every value (including the merge source) before merging. If a
    # merge happens mid-iteration, a later key is reassigned from the
    # pre-merge snapshot, and a list that "<<<" concatenated reverts to its
    # original value.
    for k, v in list(x.items()):
        x[k] = merge_check(v)
    for k in merge_keys:
        x = merge(x, copy.deepcopy(x.pop(k)), str(k) == "<<<")
    return x
269
+
270
+
271
# Number of YAML parse failures so far in this process. It gives each dumped
# copy of an offending rendered file a unique name.
ERRCOUNT = 0


def load_yaml(
    path: str,
    data: Dict[str, Any] = None,
    include_dirs: List[str] = None,
) -> Dict[str, Any]:
    """
    Render a YAML file (including its Jinja2 includes) and parse the result.

    :param path: path of the YAML file to load
    :param data: template variables. Copied, so the caller's dict is not
        modified.
    :param include_dirs: extra directories searched for included files
    :return: parsed YAML content with "<<<"/"<<" merge keys resolved
    :raises Exception: whatever rendering or parsing raised, annotated via
        ``add_note`` with the file involved. On a parse error, the rendered
        text is written to the system temp directory for inspection.
    """
    global ERRCOUNT
    import tempfile

    with LockAcquirer():
        data = {k: v for k, v in data.items()} if data is not None else {}
        path = os.path.abspath(os.path.realpath(path))
        try:
            parsed, data, include_dirs = load_file_and_includes(
                path, data, include_dirs
            )
        except Exception as e:
            e.add_note(f"Error loading YAML file {path}")
            raise
        try:
            return merge_check(get_yaml(path, data).load(parsed))
        except Exception as e:
            # Use the platform temp dir rather than a hard-coded /tmp.
            failpath = os.path.join(
                tempfile.gettempdir(), f"yaml_parse_error{ERRCOUNT}.yaml"
            )
            ERRCOUNT += 1
            try:
                with open(failpath, "w", encoding="utf-8") as f:
                    f.write(parsed)
            except OSError as write_error:
                # A failed debug dump must not replace the real parse error.
                e.add_note(
                    f"Error parsing YAML file {path}. Could not write offending "
                    f"file to {failpath}: {write_error}"
                )
            else:
                e.add_note(
                    f"Error parsing YAML file {path}. Offending file written to "
                    f"{failpath}"
                )
            raise
310
+
311
+
312
@recursive_mutator_stop
def merge(
    merge_into: dict, tomerge: dict | list | tuple, recursive: bool = True
) -> dict:
    """
    Merge ``tomerge`` into ``merge_into`` in place and return ``merge_into``.

    A list/tuple ``tomerge`` is first folded left-to-right into a single
    dict. Keys missing from ``merge_into`` are always copied in.

    For keys already present, nothing happens unless ``recursive`` is true.
    Even then:

    - values wrapped by ``!nomerge`` are left alone;
    - dict/dict pairs are merged recursively;
    - list/list pairs are concatenated (existing items first);
    - any other pair keeps the existing value.

    :raises ValueError: if either side is not a dict after folding
    """
    if isinstance(tomerge, (list, tuple)):
        folded = {}
        for part in tomerge:
            folded = merge(folded, part, recursive)
        tomerge = folded
    if not isinstance(tomerge, dict):
        raise ValueError(
            f'Expected a dict under the "<<<" or "<<" keys, but ' f"got {tomerge}"
        )
    if not isinstance(merge_into, dict):
        raise ValueError(
            f'Expected to merge into a dict with the "<<<" key, '
            f"but got {merge_into}"
        )

    for key, incoming in tomerge.items():
        if key not in merge_into:
            merge_into[key] = incoming
            continue
        if not recursive:
            continue
        existing = merge_into[key]
        if isinstance(existing, (NoMergeListWrapper, NoMergeDictWrapper)):
            continue
        if isinstance(existing, dict) and isinstance(incoming, dict):
            merge_into[key] = merge(existing, incoming, recursive)
        elif isinstance(existing, list) and isinstance(incoming, list):
            merge_into[key] = existing + incoming
    return merge_into
343
+
344
+
345
def represent_none(self, data: None) -> "ruamel.yaml.nodes.ScalarNode":
    """
    Represent None as the YAML scalar ``null``.

    :param self: YAML representer object
    :param data: the None value being represented (unused)
    :return: the node built by ``self.represent_scalar`` (not a plain string)
    """
    return self.represent_scalar("tag:yaml.org,2002:null", "null")
353
+
354
+
355
def ordereddict_to_dict(
    self, dictionary: OrderedDict
) -> "ruamel.yaml.nodes.MappingNode":
    """
    Represent an OrderedDict as a plain YAML mapping.

    :param self: YAML representer object
    :param dictionary: OrderedDict to represent
    :return: the node built by ``self.represent_dict`` for a plain-dict copy
    """
    # Converting to a plain dict keeps key order, since dicts preserve
    # insertion order. The previous copy was built but never used.
    return self.represent_dict(dict(dictionary))
366
+
367
+
368
@recursive_mutator_stop
def recursive_unorder_dict(to_unorder: Dict[str, Any]) -> Dict[str, Any]:
    """
    Return a copy of nested containers with every dict subclass turned into
    a plain ``dict`` and every list into a plain ``list``.

    Other values are returned as-is.
    """
    if isinstance(to_unorder, list):
        return [recursive_unorder_dict(item) for item in to_unorder]
    if isinstance(to_unorder, dict):
        return {key: recursive_unorder_dict(val) for key, val in to_unorder.items()}
    return to_unorder
375
+
376
+
377
@recursive_mutator_stop
def callables2strings(to_convert: Dict[str, Any]) -> Dict[str, Any]:
    """
    Return a copy of nested dicts/lists with every callable replaced by a
    string.

    The string is the callable's ``_original_expression`` attribute when it
    has one, otherwise ``str(callable)``. Other values are returned as-is.
    """
    if isinstance(to_convert, dict):
        return {key: callables2strings(val) for key, val in to_convert.items()}
    if isinstance(to_convert, list):
        return [callables2strings(item) for item in to_convert]
    if isinstance(to_convert, Callable):
        source = getattr(to_convert, "_original_expression", to_convert)
        return str(source)
    return to_convert
386
+
387
+
388
def write_yaml_file(filepath: str, content: Dict[str, Any]) -> None:
    """
    Write ``content`` as YAML to ``filepath``.

    Any existing file is replaced, and missing parent directories are
    created.

    :param filepath: destination file path
    :param content: data to serialise with ``to_yaml_string``
    :return: None
    """
    # Serialise first, so a serialisation error leaves any existing file
    # intact.
    text = to_yaml_string(content)
    directory = os.path.dirname(filepath)
    if directory:
        os.makedirs(directory, exist_ok=True)
    # "w" truncates an existing file, and the context manager closes the
    # handle. The old code leaked the handle opened in append mode.
    with open(filepath, "w", encoding="utf-8") as out_file:
        out_file.write(text)
401
+
402
+
403
def to_yaml_string(content: Dict[str, Any]) -> str:
    """
    Serialise ``content`` to a YAML string using the base YAML settings.

    Before dumping, nested dict subclasses become plain dicts and callables
    become strings.

    :param content: data to serialise
    :return: the YAML text
    """
    with LockAcquirer():
        prepared = callables2strings(recursive_unorder_dict(content))
        buffer = io.StringIO()
        dumper = get_base_yaml()
        dumper.dump(prepared, stream=buffer)
        return buffer.getvalue()
415
+
416
+
417
def get_base_yaml() -> ruamel.yaml.YAML:
    """
    Build a round-trip ``ruamel.yaml.YAML`` with this project's output style.

    - Mappings and sequences are indented by 4, with a 2-space dash offset.
    - Quotes are preserved.
    - ``None`` is written as ``null``.
    - ``OrderedDict`` is written as a plain mapping.
    """
    import collections

    yaml = ruamel.yaml.YAML(typ="rt")
    yaml.indent(mapping=4, sequence=4, offset=2)
    yaml.preserve_quotes = True

    # Representers are looked up by the value's runtime type, which for an
    # ordered dict is collections.OrderedDict. The typing.OrderedDict alias
    # is a different, non-equal object and would never match.
    #
    # The functions are registered without a recursion guard. A guard keyed
    # on the first argument would key on the representer itself, so a
    # nested OrderedDict would be "represented" as the representer object.
    yaml.representer.add_representer(type(None), represent_none)
    yaml.representer.add_representer(collections.OrderedDict, ordereddict_to_dict)

    return yaml
450
+
451
+
452
def get_yaml(path: str, data: Dict[str, Any] = None) -> ruamel.yaml.YAML:
    """
    Get a YAML object that understands this project's custom tags.

    ``!include_loaded``, ``!include``, ``!includedir`` and ``!nomerge`` are
    handled by a ``YAMLFileLoader`` bound to ``path`` and ``data``.

    :param path: file being parsed. Relative ``!include``/``!includedir``
        paths are resolved against its directory.
    :param data: rendered template data that ``!include_loaded`` looks
        names up in
    :return: configured ``ruamel.yaml.YAML`` instance
    """
    yaml = get_base_yaml()
    ymf = YAMLFileLoader(path, data)

    warnings.simplefilter("ignore", ReusedAnchorWarning)

    # ruamel's add_constructor() is a classmethod that edits the table shared
    # by every round-trip constructor in the process. A nested load_yaml
    # (e.g. from !include) would then rebind these tags to the nested file's
    # loader, and later tags in the outer file would resolve against the
    # wrong directory and data. Give this instance its own copy instead.
    constructor = yaml.constructor
    constructor.yaml_constructors = {
        **constructor.yaml_constructors,
        "!include_loaded": ymf.include_loaded,
        "!include": ymf.include,
        "!includedir": ymf.includedir,
        "!nomerge": ymf.nomerge,
    }

    return yaml
465
+
466
+
467
class NoMergeListWrapper(list):
    """
    List produced by a ``!nomerge`` tag.

    When an existing value is of this type, ``merge`` leaves it untouched
    instead of concatenating inherited items.
    """
470
+
471
+
472
class NoMergeDictWrapper(dict):
    """
    Dict produced by a ``!nomerge`` tag.

    When an existing value is of this type, ``merge`` leaves it untouched
    instead of merging inherited keys into it.
    """
475
+
476
+
477
class YAMLFileLoader:
    """
    Constructors for the custom YAML tags (registered by ``get_yaml``).

    ``path`` anchors relative ``!include``/``!includedir`` paths. ``data``
    holds the objects that ``!include_loaded`` tags refer to by name.
    """

    def __init__(self, path: str, data: Dict[str, Any] = None) -> None:
        self.path = path
        self.data = data or {}
        # NOTE(review): include_counter and env are not used by the methods
        # below; kept in case other code reads them.
        self.include_counter = 0
        self.loading_from_dir = os.path.abspath(os.path.dirname(path))
        self.include_data = data or {}
        self.env = Environment(
            loader=FileSystemLoader(os.path.dirname(path)), undefined=StrictUndefined
        )

    @staticmethod
    def _tag_path(constructor, node, tag: str) -> str:
        """Return a tag's scalar path with one trailing comma removed."""
        filepath = constructor.construct_scalar(node)
        if not filepath:
            raise ValueError(f"{tag} requires a non-empty path")
        return filepath[:-1] if filepath[-1] == "," else filepath

    def nomerge(self, constructor, node):
        """
        Construct a ``!nomerge`` list or dict.

        The result is wrapped in ``NoMergeListWrapper`` or
        ``NoMergeDictWrapper`` so that ``merge`` skips it.

        :raises ValueError: if the tag is on anything but a sequence or mapping
        """
        if isinstance(node, ruamel.yaml.nodes.SequenceNode):
            return NoMergeListWrapper(constructor.construct_sequence(node, deep=True))
        if isinstance(node, ruamel.yaml.nodes.MappingNode):
            return NoMergeDictWrapper(
                ruamel.yaml.SafeConstructor.construct_mapping(
                    constructor, node, deep=True
                )
            )
        raise ValueError(f"!nomerge tag must be applied to a list or dict, not {node}")

    def include_loaded(
        self,
        constructor: ruamel.yaml.constructor.Constructor,
        node: ruamel.yaml.nodes.ScalarNode,
    ) -> dict[str, Any] | None:
        """
        Resolve ``!include_loaded name[.key...]`` against the include data.

        ``name`` is looked up in ``include_data``, then each dotted key
        indexes the result in turn. For a multi-file include, each key is
        applied to every document.

        :param constructor: YAML constructor object
        :param node: scalar node holding the dotted reference
        :return: the referenced value; for a multi-file include, a list with
            one value per document
        :raises KeyError, TypeError: if a key cannot be resolved (annotated
            with a note naming the reference)
        """
        x = constructor.construct_scalar(node)
        found = self.include_data
        for k in str(x).split("."):
            if isinstance(found, MultiIncludeWrapper):
                # Build a new wrapper rather than overwriting the contents of
                # the one stored in include_data, which other lookups share.
                indexed = []
                for f in found.contents:
                    try:
                        indexed.append(f[k])
                    except (KeyError, TypeError) as e:
                        e.add_note(
                            f"Could not parse !include_loaded {x}: {k} not found "
                            f"in {f}"
                        )
                        raise
                found = MultiIncludeWrapper(indexed)
                continue
            try:
                found = found[k]
            except (KeyError, TypeError) as e:
                # found may not be a mapping (e.g. a list indexed by name).
                available = list(found.keys()) if hasattr(found, "keys") else found
                e.add_note(
                    f"Could not parse !include_loaded {x}: {k} not found "
                    f"in {available}"
                )
                raise
        return found.contents if isinstance(found, MultiIncludeWrapper) else found

    def include(
        self,
        constructor: ruamel.yaml.constructor.Constructor,
        node: ruamel.yaml.nodes.ScalarNode,
    ) -> dict[str, Any] | None:
        """
        Load the file named by ``!include relative_file_path``.

        The path is resolved relative to this file's directory.

        :return: parsed YAML content
        :raises ValueError: if the path is empty
        """
        filepath = self._tag_path(constructor, node, "!include")
        return load_yaml(
            os.path.join(self.loading_from_dir, filepath), self.include_data
        )

    def includedir(
        self,
        constructor: ruamel.yaml.constructor.Constructor,
        node: ruamel.yaml.nodes.ScalarNode,
    ) -> list[dict[str, Any]]:
        """
        Load every ``*.yaml`` file in the directory named by
        ``!includedir relative_dir``.

        The directory is resolved relative to this file's directory.

        :return: parsed contents, one per file, in sorted path order
        :raises ValueError: if the path is empty
        """
        filepath = self._tag_path(constructor, node, "!includedir")
        dirname = os.path.join(self.loading_from_dir, filepath)
        # Escape the directory so glob metacharacters in its name match
        # literally, and sort so the result order is deterministic.
        pattern = os.path.join(glob.escape(dirname), "*.yaml")
        return [
            load_yaml(filename, self.include_data)
            for filename in sorted(glob.glob(pattern))
        ]


# Shared module-level YAML instance with the base output settings.
yaml = get_base_yaml()