accelforge 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (258) hide show
  1. accelforge/__init__.py +21 -0
  2. accelforge/_accelerated_imports.py +16 -0
  3. accelforge/_deprecate/_simanneal/evalmapping.py +271 -0
  4. accelforge/_deprecate/_simanneal/mapspaceglobals.py +298 -0
  5. accelforge/_deprecate/_simanneal/simanneal.py +666 -0
  6. accelforge/_deprecate/_simanneal/tracking.py +105 -0
  7. accelforge/_deprecate/_simanneal/wrappers.py +218 -0
  8. accelforge/_deprecate/_simanneal2/__init__.py +7 -0
  9. accelforge/_deprecate/_simanneal2/simanneal.py +493 -0
  10. accelforge/_deprecate/_simanneal2/tracking.py +116 -0
  11. accelforge/_deprecate/compatibility_util.py +181 -0
  12. accelforge/_deprecate/layerdeduplication/__init__.py +2 -0
  13. accelforge/_deprecate/layerdeduplication/group_similar_einsums.py +160 -0
  14. accelforge/_deprecate/layerdeduplication/grouped_einsums.py +84 -0
  15. accelforge/_deprecate/mapping_filter_tags/__init__.py +2 -0
  16. accelforge/_deprecate/mapping_filter_tags/ffmt.py +212 -0
  17. accelforge/_deprecate/mapping_filter_tags/onesplit.py +24 -0
  18. accelforge/_deprecate/mapping_filter_tags/util.py +24 -0
  19. accelforge/_deprecate/tags.py +69 -0
  20. accelforge/_deprecate/viz/__init__.py +0 -0
  21. accelforge/_deprecate/viz/interactive.py +159 -0
  22. accelforge/_deprecate/viz/reservationtree.py +307 -0
  23. accelforge/_deprecate/viz/ski_slope.py +88 -0
  24. accelforge/_version.py +15 -0
  25. accelforge/examples.py +39 -0
  26. accelforge/frontend/__init__.py +10 -0
  27. accelforge/frontend/_binding.py +129 -0
  28. accelforge/frontend/_workload_isl/__init__.py +2 -0
  29. accelforge/frontend/_workload_isl/_isl.py +149 -0
  30. accelforge/frontend/_workload_isl/_symbolic.py +141 -0
  31. accelforge/frontend/arch copy.py +1544 -0
  32. accelforge/frontend/arch.py +1642 -0
  33. accelforge/frontend/config.py +63 -0
  34. accelforge/frontend/mapper/__init__.py +5 -0
  35. accelforge/frontend/mapper/ffm.py +126 -0
  36. accelforge/frontend/mapper/mapper.py +7 -0
  37. accelforge/frontend/mapper/metrics.py +30 -0
  38. accelforge/frontend/mapping/__init__.py +1 -0
  39. accelforge/frontend/mapping/mapping.py +1736 -0
  40. accelforge/frontend/model.py +14 -0
  41. accelforge/frontend/renames.py +150 -0
  42. accelforge/frontend/spec copy.py +230 -0
  43. accelforge/frontend/spec.py +301 -0
  44. accelforge/frontend/variables.py +12 -0
  45. accelforge/frontend/workload.py +952 -0
  46. accelforge/mapper/FFM/__init__.py +9 -0
  47. accelforge/mapper/FFM/_join_pmappings/__init__.py +0 -0
  48. accelforge/mapper/FFM/_join_pmappings/compatibility.py +653 -0
  49. accelforge/mapper/FFM/_join_pmappings/compress_pmappings.py +140 -0
  50. accelforge/mapper/FFM/_join_pmappings/join_pmappings.py +703 -0
  51. accelforge/mapper/FFM/_join_pmappings/pmapping_dataframe.py +901 -0
  52. accelforge/mapper/FFM/_join_pmappings/pmapping_group.py +337 -0
  53. accelforge/mapper/FFM/_make_pmappings/contraints/__init__.py +0 -0
  54. accelforge/mapper/FFM/_make_pmappings/contraints/constraints.py +360 -0
  55. accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/__init__.py +1 -0
  56. accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_loops.py +373 -0
  57. accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_pmapping_templates.py +463 -0
  58. accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_reservations.py +95 -0
  59. accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_storage_order.py +382 -0
  60. accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_storages.py +155 -0
  61. accelforge/mapper/FFM/_make_pmappings/make_pmappings.py +411 -0
  62. accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/__init__.py +1 -0
  63. accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/make_pmappings_from_templates.py +407 -0
  64. accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/make_tile_shapes.py +1681 -0
  65. accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/run_model.py +170 -0
  66. accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/symbol_relations.py +174 -0
  67. accelforge/mapper/FFM/_make_pmappings/pmapper_job.py +282 -0
  68. accelforge/mapper/FFM/_pareto_df/df_convention.py +273 -0
  69. accelforge/mapper/FFM/_pareto_df/pareto copy.py +836 -0
  70. accelforge/mapper/FFM/_pareto_df/pareto.py +508 -0
  71. accelforge/mapper/FFM/data.py +61 -0
  72. accelforge/mapper/FFM/main copy.py +236 -0
  73. accelforge/mapper/FFM/main.py +208 -0
  74. accelforge/mapper/FFM/mappings.py +510 -0
  75. accelforge/mapper/FFM/pmappings.py +310 -0
  76. accelforge/mapper/__init__.py +4 -0
  77. accelforge/mapper.py +0 -0
  78. accelforge/model/__init__.py +1 -0
  79. accelforge/model/_looptree/__init__.py +0 -0
  80. accelforge/model/_looptree/accesses.py +335 -0
  81. accelforge/model/_looptree/capacity/__init__.py +1 -0
  82. accelforge/model/_looptree/capacity/aggregators.py +36 -0
  83. accelforge/model/_looptree/capacity/capacity.py +47 -0
  84. accelforge/model/_looptree/energy.py +150 -0
  85. accelforge/model/_looptree/equivalent_ranks.py +29 -0
  86. accelforge/model/_looptree/latency/__init__.py +1 -0
  87. accelforge/model/_looptree/latency/latency.py +98 -0
  88. accelforge/model/_looptree/latency/memory.py +120 -0
  89. accelforge/model/_looptree/latency/processors.py +92 -0
  90. accelforge/model/_looptree/mapping_utilities.py +71 -0
  91. accelforge/model/_looptree/reuse/__init__.py +4 -0
  92. accelforge/model/_looptree/reuse/isl/__init__.py +1 -0
  93. accelforge/model/_looptree/reuse/isl/des.py +59 -0
  94. accelforge/model/_looptree/reuse/isl/isl_functions.py +374 -0
  95. accelforge/model/_looptree/reuse/isl/mapping_to_isl/__init__.py +4 -0
  96. accelforge/model/_looptree/reuse/isl/mapping_to_isl/analyze_mapping.py +297 -0
  97. accelforge/model/_looptree/reuse/isl/mapping_to_isl/skews_from_mapping.py +236 -0
  98. accelforge/model/_looptree/reuse/isl/mapping_to_isl/tiling.py +685 -0
  99. accelforge/model/_looptree/reuse/isl/mapping_to_isl/types.py +188 -0
  100. accelforge/model/_looptree/reuse/isl/spatial.py +260 -0
  101. accelforge/model/_looptree/reuse/isl/temporal.py +182 -0
  102. accelforge/model/_looptree/reuse/symbolic/__init__.py +1 -0
  103. accelforge/model/_looptree/reuse/symbolic/symbolic copy 2.py +1346 -0
  104. accelforge/model/_looptree/reuse/symbolic/symbolic copy.py +1408 -0
  105. accelforge/model/_looptree/reuse/symbolic/symbolic.py +1396 -0
  106. accelforge/model/_looptree/run.py +122 -0
  107. accelforge/model/_looptree/types.py +26 -0
  108. accelforge/model/_looptree/visualization/__init__.py +0 -0
  109. accelforge/model/_looptree/visualization/occupancy.py +11 -0
  110. accelforge/model/main.py +222 -0
  111. accelforge/plotting/__init__.py +2 -0
  112. accelforge/plotting/mappings.py +219 -0
  113. accelforge/plotting/specs.py +57 -0
  114. accelforge/util/__init__.py +4 -0
  115. accelforge/util/_base_analysis_types.py +24 -0
  116. accelforge/util/_basetypes.py +1089 -0
  117. accelforge/util/_frozenset.py +36 -0
  118. accelforge/util/_isl.py +29 -0
  119. accelforge/util/_itertools.py +14 -0
  120. accelforge/util/_mathfuncs.py +57 -0
  121. accelforge/util/_parse_expressions.py +339 -0
  122. accelforge/util/_picklecache.py +32 -0
  123. accelforge/util/_setexpressions.py +268 -0
  124. accelforge/util/_sympy/__init__.py +0 -0
  125. accelforge/util/_sympy/broadcast_max.py +18 -0
  126. accelforge/util/_visualization.py +112 -0
  127. accelforge/util/_yaml.py +579 -0
  128. accelforge/util/parallel.py +193 -0
  129. accelforge-0.0.1.dist-info/METADATA +64 -0
  130. accelforge-0.0.1.dist-info/RECORD +258 -0
  131. accelforge-0.0.1.dist-info/WHEEL +5 -0
  132. accelforge-0.0.1.dist-info/licenses/LICENSE +19 -0
  133. accelforge-0.0.1.dist-info/top_level.txt +5 -0
  134. docs/_build/html/_sources/fastfusion.frontend.mapper.rst.txt +37 -0
  135. docs/_build/html/_sources/fastfusion.frontend.rst.txt +70 -0
  136. docs/_build/html/_sources/fastfusion.frontend.workload.rst.txt +21 -0
  137. docs/_build/html/_sources/fastfusion.mapper.FFM.rst.txt +37 -0
  138. docs/_build/html/_sources/fastfusion.mapper.rst.txt +18 -0
  139. docs/_build/html/_sources/fastfusion.rst.txt +20 -0
  140. docs/_build/html/_sources/fastfusion.util.rst.txt +21 -0
  141. docs/_build/html/_sources/index.rst.txt +87 -0
  142. docs/_build/html/_sources/modules.rst.txt +7 -0
  143. docs/_build/html/_sources/notes/citation.rst.txt +45 -0
  144. docs/_build/html/_sources/notes/definitions.rst.txt +43 -0
  145. docs/_build/html/_sources/notes/faqs.rst.txt +39 -0
  146. docs/_build/html/_sources/notes/modeling/accelerator_energy_latency.rst.txt +72 -0
  147. docs/_build/html/_sources/notes/modeling/component_energy_area.rst.txt +96 -0
  148. docs/_build/html/_sources/notes/modeling/mapping.rst.txt +100 -0
  149. docs/_build/html/_sources/notes/modeling.rst.txt +33 -0
  150. docs/_build/html/_sources/notes/parsing/arithmetic_parsing.rst.txt +136 -0
  151. docs/_build/html/_sources/notes/parsing/setexpressions.rst.txt +63 -0
  152. docs/_build/html/_sources/notes/parsing/yaml_parsing.rst.txt +176 -0
  153. docs/_build/html/_sources/notes/quickstart_and_installation.rst.txt +9 -0
  154. docs/_build/html/_sources/notes/spec/architecture.rst.txt +133 -0
  155. docs/_build/html/_sources/notes/spec/mapping.rst.txt +12 -0
  156. docs/_build/html/_sources/notes/spec/workload.rst.txt +83 -0
  157. docs/_build/html/_sources/notes/spec.rst.txt +36 -0
  158. docs/source/_ext/include_attrs.py +213 -0
  159. docs/source/_ext/include_docstring.py +364 -0
  160. docs/source/_ext/include_functions.py +154 -0
  161. docs/source/_ext/include_notebook.py +131 -0
  162. docs/source/_ext/include_yaml.py +119 -0
  163. docs/source/_ext/inherited_attributes.py +222 -0
  164. docs/source/_ext/paths.py +4 -0
  165. docs/source/conf.py +79 -0
  166. examples/arches/compute_in_memory/_include.yaml +74 -0
  167. examples/arches/compute_in_memory/_include_functions.py +229 -0
  168. examples/arches/compute_in_memory/_load_spec.py +57 -0
  169. examples/arches/compute_in_memory/components/c2c_multiplier.py +181 -0
  170. examples/arches/compute_in_memory/components/dac_c2c_r2r.py +605 -0
  171. examples/arches/compute_in_memory/components/misc.py +195 -0
  172. examples/arches/compute_in_memory/components/util/bit_functions.py +51 -0
  173. examples/arches/compute_in_memory/components/zero_comparator.py +92 -0
  174. examples/arches/compute_in_memory/isaac.yaml +233 -0
  175. examples/arches/compute_in_memory/memory_cells/ecram_demo.yaml +63 -0
  176. examples/arches/compute_in_memory/memory_cells/rram_example.yaml +63 -0
  177. examples/arches/compute_in_memory/memory_cells/rram_isaac_isca_2016.yaml +64 -0
  178. examples/arches/compute_in_memory/memory_cells/rram_neurosim_default.yaml +63 -0
  179. examples/arches/compute_in_memory/memory_cells/rram_raella_isca_2023.yaml +70 -0
  180. examples/arches/compute_in_memory/memory_cells/rram_wan_nature_2022.yaml +63 -0
  181. examples/arches/compute_in_memory/memory_cells/sram_colonnade_jssc_2021.yaml +63 -0
  182. examples/arches/compute_in_memory/memory_cells/sram_example.yaml +63 -0
  183. examples/arches/compute_in_memory/memory_cells/sram_jia_jssc_2020.yaml +63 -0
  184. examples/arches/compute_in_memory/memory_cells/sram_sinangil_jssc_2021.yaml +63 -0
  185. examples/arches/compute_in_memory/memory_cells/sram_wang_vlsi_2022.yaml +63 -0
  186. examples/arches/compute_in_memory/wang_vlsi_2022.yaml +289 -0
  187. examples/arches/eyeriss.yaml +68 -0
  188. examples/arches/fanout_variations/at_glb.yaml +31 -0
  189. examples/arches/fanout_variations/at_glb_with_fanout_node.yaml +34 -0
  190. examples/arches/fanout_variations/at_mac.yaml +31 -0
  191. examples/arches/fanout_variations/at_mac_with_constraints.yaml +38 -0
  192. examples/arches/fanout_variations/at_mac_with_fanout_node.yaml +34 -0
  193. examples/arches/nvdla.yaml +47 -0
  194. examples/arches/simple.yaml +28 -0
  195. examples/arches/tpu_v4i.yaml +67 -0
  196. examples/mappings/unfused_matmuls_to_simple.yaml +33 -0
  197. examples/misc/component_annotated.yaml +33 -0
  198. examples/workloads/gpt3_6.7B.yaml +124 -0
  199. examples/workloads/matmuls.yaml +20 -0
  200. examples/workloads/mobilenet_28.yaml +81 -0
  201. examples/workloads/mobilenet_various_separate.yaml +106 -0
  202. examples/workloads/three_matmuls_annotated.yaml +59 -0
  203. notebooks/.ipynb_checkpoints/fastfusion_arch_study_michael-checkpoint.ipynb +359 -0
  204. notebooks/compute_in_memory/_scripts.py +339 -0
  205. notebooks/compute_in_memory/isaac.guide.ipynb +270 -0
  206. notebooks/compute_in_memory/wang_vlsi_2022.ipynb +602 -0
  207. notebooks/paths.py +4 -0
  208. notebooks/tutorials/.ipynb_checkpoints/1_FFM-checkpoint.ipynb +3110 -0
  209. notebooks/tutorials/FFM.ipynb +3498 -0
  210. notebooks/tutorials/_include.py +48 -0
  211. notebooks/tutorials/component_energy_area.ipynb +363 -0
  212. tests/Q_mapping.yaml +38 -0
  213. tests/__init__.py +0 -0
  214. tests/conv.mapping.yaml +27 -0
  215. tests/conv.workload.yaml +13 -0
  216. tests/conv_sym.mapping.yaml +43 -0
  217. tests/copy.mapping.yaml +35 -0
  218. tests/copy.workload.yaml +15 -0
  219. tests/distribuffers/__init__.py +0 -0
  220. tests/distribuffers/multicast/test_cases.yaml +482 -0
  221. tests/distribuffers/spec/binding/valid_bindings.yaml +97 -0
  222. tests/distribuffers/spec/distributed.yaml +100 -0
  223. tests/distribuffers/spec/logical_arch.yaml +32 -0
  224. tests/distribuffers/spec/physical_arch.yaml +69 -0
  225. tests/distribuffers/test_binding.py +48 -0
  226. tests/frontend/__init__.py +0 -0
  227. tests/frontend/test_mapping_viz.py +52 -0
  228. tests/mapper/__init__.py +0 -0
  229. tests/mapper/configs/conv1d/conv1d.mapping.yaml +31 -0
  230. tests/mapper/configs/conv1d/conv1d.workload.yaml +11 -0
  231. tests/mapper/configs/two_conv1d/two_conv1d.expected.yaml +38 -0
  232. tests/mapper/configs/two_conv1d/two_conv1d.mapping.yaml +54 -0
  233. tests/mapper/configs/two_conv1d/two_conv1d.workload.yaml +19 -0
  234. tests/mapper/test_mapping_to_isl.py +90 -0
  235. tests/mapper/test_spatial_reuse_analysis.py +67 -0
  236. tests/mapper/test_temporal_reuse_analysis.py +56 -0
  237. tests/mapper/util.py +58 -0
  238. tests/matmul.mapping.yaml +29 -0
  239. tests/matmul.workload.yaml +12 -0
  240. tests/matmul_spatial.mapping.yaml +44 -0
  241. tests/mha.renames.yaml +65 -0
  242. tests/mha.workload.yaml +67 -0
  243. tests/mha.yaml +59 -0
  244. tests/mha_full.workload.yaml +67 -0
  245. tests/mobilenet.workload.yaml +35 -0
  246. tests/mobilenet_long.workload.yaml +64 -0
  247. tests/pmappingcache.py +24 -0
  248. tests/processing_stage.arch.yaml +40 -0
  249. tests/snowcat.arch.yaml +36 -0
  250. tests/test_ffm_join_pmappings.py +106 -0
  251. tests/test_ffm_make_pmappings.py +82 -0
  252. tests/test_ffm_make_tile_shapes.py +49 -0
  253. tests/test_mapper.py +100 -0
  254. tests/test_model.py +37 -0
  255. tests/test_plotting.py +72 -0
  256. tests/test_processing_stage.py +46 -0
  257. tests/test_symbolic_model.py +248 -0
  258. tests/test_workload.py +141 -0
@@ -0,0 +1,1089 @@
1
+ import copy
2
+ import glob
3
+ import logging
4
+ import os
5
+ from pathlib import Path
6
+ import re
7
+ from pydantic import BaseModel, ConfigDict, Tag, ValidationError
8
+ from pydantic.main import IncEx
9
+ from pydantic_core.core_schema import (
10
+ CoreSchema,
11
+ chain_schema,
12
+ list_schema,
13
+ union_schema,
14
+ no_info_plain_validator_function,
15
+ str_schema,
16
+ dict_schema,
17
+ tagged_union_schema,
18
+ )
19
+ from typing import (
20
+ Iterator,
21
+ List,
22
+ Mapping,
23
+ TypeVar,
24
+ Generic,
25
+ Any,
26
+ Callable,
27
+ TypeVarTuple,
28
+ Dict,
29
+ Optional,
30
+ Type,
31
+ TypeAlias,
32
+ Union,
33
+ get_args,
34
+ get_origin,
35
+ TYPE_CHECKING,
36
+ Self,
37
+ )
38
+
39
+ from accelforge.util import _yaml
40
+ from accelforge.util._parse_expressions import (
41
+ parse_expression,
42
+ ParseError,
43
+ LiteralString,
44
+ is_literal_string,
45
+ )
46
+
47
# The set-expression helpers are imported here only for static type checking (or when
# TYPE_CHECKING_RUNTIME is switched on by hand). Importing them eagerly at module load
# would create a circular dependency, so _parse_field imports them locally instead.
TYPE_CHECKING_RUNTIME = False
if TYPE_CHECKING_RUNTIME or TYPE_CHECKING:
    from accelforge.util._setexpressions import InvertibleSet, eval_set_expression

# Type variables shared by the generic helpers in this module.
T = TypeVar("T")
M = TypeVar("M", bound=BaseModel)  # bound to pydantic models
K = TypeVar("K")
V = TypeVar("V")
PM = TypeVar("PM", bound="ParsableModel")  # forward reference, resolved lazily
PL = TypeVar("PL", bound="ParsableList[Any]")  # forward reference, resolved lazily

Ts = TypeVarTuple("Ts")  # variadic parameters, used by _InferFromTag[...]
60
+
61
+
62
+ def _get_tag(value: Any) -> str:
63
+ if not isinstance(value, dict):
64
+ return value.__class__.__name__
65
+ tag = None
66
+
67
+ def try_get_tag(attr: str) -> str:
68
+ if hasattr(value, attr) and getattr(value, attr) is not None:
69
+ return getattr(value, attr)
70
+ return None
71
+
72
+ def try_index(attr: str) -> str:
73
+ try:
74
+ return value[attr]
75
+ except:
76
+ return None
77
+
78
+ tag = None
79
+ for attr in ("type", "_type", "_yaml_tag"):
80
+ if tag := try_get_tag(attr):
81
+ break
82
+ if tag := try_index(attr):
83
+ break
84
+ if tag is None:
85
+ raise ValueError(
86
+ f"No tag found for {value}. Either set the type field " "or use a YAML tag."
87
+ )
88
+ tag = str(tag)
89
+ if tag.startswith("!"):
90
+ tag = tag[1:]
91
+ return tag
92
+
93
+
94
+ def _uninstantiable(cls):
95
+ prev_init = cls.__init__
96
+
97
+ def _get_all_subclasses(cls):
98
+ subclasses = set()
99
+ for subclass in cls.__subclasses__():
100
+ subclasses.add(subclass.__name__)
101
+ subclasses.update(_get_all_subclasses(subclass))
102
+ return subclasses
103
+
104
+ def __init__(self, *args, **kwargs):
105
+ if self.__class__ is cls:
106
+ subclasses = _get_all_subclasses(cls)
107
+ raise ValueError(
108
+ f"{cls} can not be instantiated directly. Use a subclass. "
109
+ f"Supported subclasses are:\n\t" + "\n\t".join(sorted(subclasses))
110
+ )
111
+ return prev_init(self, *args, **kwargs)
112
+
113
+ cls.__init__ = __init__
114
+ return cls
115
+
116
+
117
+ class _InferFromTag(Generic[*Ts]):
118
+ @classmethod
119
+ def __get_pydantic_core_schema__(
120
+ cls, source_type: Any, handler: Callable
121
+ ) -> CoreSchema:
122
+ type_args = get_args(source_type)
123
+ if not type_args:
124
+ raise TypeError(
125
+ f"_InferFromTag must be used with a type parameter, e.g. _InferFromTag[int]"
126
+ )
127
+
128
+ # type_args contains all the possible types: (Compute, Memory, "Hierarchical")
129
+ target_types = []
130
+ for arg in type_args:
131
+ if isinstance(arg, str):
132
+ # Handle string type names - we'll need to resolve them later
133
+ target_types.append(arg)
134
+ elif isinstance(arg, type):
135
+ target_types.append(arg)
136
+ else:
137
+ target_types.append(arg)
138
+
139
+ # Create tag to class mapping
140
+ tag2class = {}
141
+ for target_type in target_types:
142
+ if isinstance(target_type, str):
143
+ # For string types, use the string as both key and placeholder
144
+ tag2class[target_type] = target_type
145
+ elif hasattr(target_type, "__name__"):
146
+ tag2class[target_type.__name__] = target_type
147
+ else:
148
+ # Fallback for other types
149
+ tag2class[str(target_type)] = target_type
150
+
151
+ def validate(value: Any) -> T:
152
+ if hasattr(value, "_yaml_tag"):
153
+ tag = value._yaml_tag
154
+ elif hasattr(value, "_type"):
155
+ tag = value._type
156
+ else:
157
+ for to_try in ("_yaml_tag", "_type", "type"):
158
+ try:
159
+ tag = value[to_try]
160
+ break
161
+ except:
162
+ pass
163
+ else:
164
+ raise ValueError(
165
+ f"No tag found for {value}. Either set the type field "
166
+ "or use a YAML tag."
167
+ )
168
+ tag = str(tag)
169
+ if tag.startswith("!"):
170
+ tag = tag[1:]
171
+ value._type = tag
172
+
173
+ print(f"Tag found! {tag}")
174
+ if tag in tag2class:
175
+ return tag2class[tag](**value)
176
+ else:
177
+ raise ValueError(
178
+ f"Unknown tag: {tag}. Supported tags are: {sorted(tag2class.keys())}"
179
+ )
180
+
181
+ # target_schema = handler.generate_schema(target_types)
182
+ schemas = []
183
+ for t in target_types:
184
+ schemas.append(handler.generate_schema(t))
185
+ target_schema = union_schema(schemas)
186
+ # return chain_schema([
187
+ # no_info_plain_validator_function(validate),
188
+ # target_schema
189
+ # ])
190
+ return chain_schema(
191
+ [
192
+ no_info_plain_validator_function(validate),
193
+ tagged_union_schema(tag2class, discriminator="_type"),
194
+ ]
195
+ )
196
+
197
+
198
class NoParse(Generic[T]):
    """A type that skips parsing of the specified object.

    For pydantic validation, ``NoParse[X]`` behaves exactly like ``X``: no
    string-expression alternatives are added, unlike ``ParsesTo[X]``.
    """

    _class_name: str = "NoParse"

    def __init__(self, value: T):
        self._value = value
        # NOTE(review): this stores the module-level TypeVar, not the runtime type
        # argument of NoParse[...].
        self._type = T

    @classmethod
    def __get_pydantic_core_schema__(
        cls, source_type: Any, handler: Callable
    ) -> CoreSchema:
        # Get the type parameter X from NoParse[X].
        type_args = get_args(source_type)
        if not type_args:
            raise TypeError(
                f"{cls._class_name} must be used with a type parameter, "
                f"e.g. {cls._class_name}[int]"
            )
        # Validate as the wrapped type itself.
        return handler(type_args[0])
230
+
231
+
232
class ParsesTo(Generic[T]):
    """A type that parses to the specified type T.

    For pydantic validation, a ``ParsesTo[X]`` field accepts any of:
      1. a literal string (per ``is_literal_string``), wrapped in ``LiteralString``;
      2. any other string, kept as an expression to be parsed later;
      3. a value that already validates as ``X``.

    Example:
        class Example(ParsableModel):
            a: ParsesTo[int] # Will parse string expressions to integers
            b: ParsesTo[str] # Will parse string expressions to strings
            c: str # Regular string, no parsing
    """

    _class_name: str = "ParsesTo"

    def __init__(self, value: str):
        self._value = value
        self._is_literal_string = is_literal_string(value)
        # NOTE(review): ``T`` is the module-level TypeVar, not the runtime type
        # argument, so the former ``assert self._type != str`` could never fire and
        # was removed. It also contradicted the ParsesTo[str] example above. If
        # ParsesTo[str] must be rejected, check the type argument in
        # __get_pydantic_core_schema__ instead.
        self._type = T

    def __str__(self) -> str:
        return str(self._value)

    def __repr__(self) -> str:
        return f"{self._class_name}({repr(self._value)})"

    @classmethod
    def _schema_choices(cls, target_schema: CoreSchema) -> list[CoreSchema]:
        """Return the union members, tried in order, for ``cls[X]``."""

        def validate_raw_string(value):
            # Returns None for anything that is not a literal string. The following
            # str_schema() rejects None, so the union moves on to the next option.
            if isinstance(value, str) and is_literal_string(value):
                return LiteralString(value)
            return None

        return [
            # First option: validate as raw (literal) string.
            chain_schema(
                [
                    no_info_plain_validator_function(validate_raw_string),
                    str_schema(),
                ]
            ),
            # Second option: any string, kept as an expression to parse later.
            chain_schema([str_schema()]),
            # Third option: direct target type validation.
            target_schema,
        ]

    @classmethod
    def __get_pydantic_core_schema__(
        cls, source_type: Any, handler: Callable
    ) -> CoreSchema:
        # Get the type parameter X from cls[X].
        type_args = get_args(source_type)
        if not type_args:
            raise TypeError(
                f"{cls._class_name} must be used with a type parameter, "
                f"e.g. {cls._class_name}[int]"
            )
        target_schema = handler(type_args[0])
        # Subclasses customize the accepted forms by overriding _schema_choices.
        return union_schema(cls._schema_choices(target_schema))


class TryParseTo(ParsesTo, Generic[T]):
    """
    A type that tries to parse to the specified type T. If the parsing fails, the value
    is returned as a string.

    For pydantic validation it accepts the same forms as ``ParsesTo``, plus a final
    plain-string fallback.
    """

    _class_name: str = "TryParseTo"

    def __init__(self, value: str):
        super().__init__(value)

    @classmethod
    def _schema_choices(cls, target_schema: CoreSchema) -> list[CoreSchema]:
        # Fourth option: return the value as a string.
        # NOTE(review): option 2 already accepts every string, so this fallback is
        # effectively unreachable; kept for parity with the previous schema.
        return super()._schema_choices(target_schema) + [str_schema()]
362
+
363
+
364
if TYPE_CHECKING:
    # Static-analysis-only view: to a type checker, ParsesTo[X] and TryParseTo[X]
    # are aliases of X, so annotated fields accept plain X values. None of this
    # runs at runtime, where TYPE_CHECKING is False.
    try:
        _T_alias = TypeVar("_T_alias")
        from typing_extensions import TypeAliasType

        ParsesTo = TypeAliasType("ParsesTo", _T_alias, type_params=(_T_alias,))
        TryParseTo = TypeAliasType("TryParseTo", _T_alias, type_params=(_T_alias,))
    except Exception:
        # Best-effort: checkers lacking TypeAliasType keep the real classes.
        pass
374
+
375
+
376
class _PostCall(Generic[T]):
    """Hook applied to each field right after it has been parsed.

    ``Parsable._parse_expressions_final`` invokes every post-call as
    ``post_call(field, value, parsed, symbol_table)``. ``value`` is the field's
    value before parsing and ``parsed`` is the result so far. The return value
    replaces ``parsed``. This base implementation is the identity, so it returns
    ``parsed``. Returning ``value`` would undo the parse.
    """

    def __call__(
        self, field: str, value: T, parsed: T, symbol_table: dict[str, Any]
    ) -> T:
        return parsed
379
+
380
+
381
@_uninstantiable
class Parsable(Generic[M]):
    """Abstract base for objects whose fields may contain expressions to evaluate.

    Subclasses implement ``_parse_expressions``, ``get_fields`` and
    ``get_validator``. ``_parse_expressions_final`` then evaluates each field
    against a symbol table. ``Parsable`` itself cannot be instantiated; see the
    ``_uninstantiable`` decorator.
    """

    def _parse_expressions(
        self, symbol_table: dict[str, Any] = None, **kwargs
    ) -> tuple[M, dict[str, Any]]:
        raise NotImplementedError("Subclasses must implement this method")

    def get_fields(self) -> list[str]:
        raise NotImplementedError("Subclasses must implement this method")

    def get_validator(self, field: str) -> type:
        raise NotImplementedError("Subclasses must implement this method")

    def _parse_expressions_final(
        self,
        symbol_table: dict[str, Any],
        order: tuple[str, ...],
        post_calls: tuple[_PostCall[T], ...],
        use_setattr: bool = True,
        already_parsed: dict[str, Any] | None = None,
        **kwargs,
    ) -> tuple["Parsable", dict[str, Any]]:
        """Parse every field in dependency order, updating ``symbol_table`` in place.

        Fields are read and written as attributes when ``use_setattr`` is true, and
        by item access otherwise. Values in ``already_parsed`` are stored directly
        without parsing. Every other field is parsed with ``_parse_field``, passed
        through each of ``post_calls``, stored back, and published to
        ``symbol_table``.

        Raises:
            ParseError: if a ``global_*`` entry from the incoming symbol table ends
                up with a different value.
        """
        self._parsed = True
        already_parsed = {} if already_parsed is None else already_parsed

        def read(name):
            return getattr(self, name) if use_setattr else self[name]

        def write(name, new_value):
            if use_setattr:
                setattr(self, name, new_value)
            else:
                self[name] = new_value

        pending = [name for name in self.get_fields() if name not in already_parsed]
        field_order = _get_parsable_field_order(
            order,
            [(name, read(name), self.get_validator(name)) for name in pending],
        )
        # Snapshot of the outer scope, used to detect changes to global_* entries.
        outer_scope = symbol_table.copy()

        for name, fixed in already_parsed.items():
            symbol_table[name] = fixed
            write(name, fixed)

        for name in field_order:
            raw = read(name)
            result = _parse_field(
                name, raw, self.get_validator(name), symbol_table, self, **kwargs
            )
            for post_call in post_calls:
                result = post_call(name, raw, result, symbol_table)
            write(name, result)
            symbol_table[name] = result

        for key, outer_value in outer_scope.items():
            if not (isinstance(key, str) and key.startswith("global_")):
                continue
            if symbol_table.get(key, None) != outer_value:
                raise ParseError(
                    f"Global variable {key} is already set to {outer_value} in the outer scope. "
                    f"It cannot be changed to {symbol_table[key]}."
                )

        return self, symbol_table
467
+
468
+
469
class _FromYAMLAble:
    @classmethod
    def from_yaml(
        cls: type[T],
        *files: str | list[str] | Path | list[Path],
        jinja_parse_data: dict[str, Any] | None = None,
        top_key: str | None = None,
        **kwargs,
    ) -> T:
        """
        Loads an instance of this class from one or more yaml files.

        Each yaml file should contain a dictionary. Dictionaries are combined in the
        order they are given. If a top-level key appears in more than one file, the
        first occurrence is kept and later ones are only logged.

        Keyword arguments are also passed to the constructor.

        Args:
            files:
                Paths or glob patterns (or lists of them) of the yaml files to load.
                The same file on disk is loaded only once.
            jinja_parse_data: Optional[Dict[str, Any]]
                A dictionary of Jinja2 data to use when parsing the yaml files.
            top_key: Optional[str]
                If given, only the content under this top-level key is used.
            kwargs: Extra keyword arguments to be passed to the constructor.

        Returns:
            An instance of ``cls`` built from the combined dictionaries. If that
            fails and the data has a single top-level key whose value is a dict,
            construction from that value is attempted instead.

        Raises:
            FileNotFoundError: if a path/pattern matches no file.
            TypeError: if a file does not contain a dictionary.
            KeyError: if ``top_key`` is given but not present.
            ValueError: if there is no data to construct from.
        """
        jinja_parse_data = jinja_parse_data or {}

        # Flatten the arguments into path strings; convert every Path, including
        # Paths inside lists (even empty or mixed ones).
        patterns: list[str] = []
        for f in files:
            items = f if isinstance(f, (list, tuple)) else [f]
            patterns.extend(str(x) if isinstance(x, Path) else x for x in items)
        files = patterns

        # Expand globs and drop duplicates that refer to the same file on disk.
        to_parse: list[str] = []
        for pattern in files:
            matches = [x for x in glob.glob(pattern) if os.path.isfile(x)]
            if not matches:
                raise FileNotFoundError(f"Could not find file {pattern}")
            for match in matches:
                if any(os.path.samefile(match, seen) for seen in to_parse):
                    logging.info('Ignoring duplicate file "%s" in yaml load', match)
                else:
                    to_parse.append(match)

        rval = {}
        extra_elems = []
        for f in to_parse:
            if not f.endswith((".yaml", ".jinja", ".jinja2")):
                # The file is still loaded; the old message wrongly said "Skipping."
                logging.warning(
                    "File %s does not end with .yaml, .jinja, or .jinja2. "
                    "Attempting to load it anyway.",
                    f,
                )
            logging.info("Loading yaml file %s", f)
            loaded = _yaml.load_yaml(f, data=jinja_parse_data)
            if not isinstance(loaded, dict):
                raise TypeError(
                    f"Expected a dictionary from file {f}, got {type(loaded)}"
                )
            for k, v in loaded.items():
                if k in rval:
                    logging.info("Found extra top-key %s in %s", k, f)
                    extra_elems.append((k, v))
                else:
                    logging.info("Found top key %s in %s", k, f)
                    rval[k] = v

        if top_key is not None:
            if top_key not in rval:
                raise KeyError(f"Top key {top_key} not found in {files}")
            rval = rval[top_key]

        c = None
        try:
            c = cls(**rval, **kwargs)
        except Exception:
            # Best effort: try the alternatives below. If they also fail, the final
            # construction attempt re-raises a meaningful error.
            pass

        if c is None and rval is None:
            if top_key is not None:
                raise ValueError(
                    f"No data to parse from {files} with top key {top_key}. Is there "
                    f"content under the top key {top_key}?"
                )
            raise ValueError(
                f"No data to parse from {files}. Is there content in the file(s)?"
            )

        if c is None and isinstance(rval, dict) and len(rval) == 1:
            logging.warning(
                "Trying to parse a single element dictionary as a %s. ", cls.__name__
            )
            try:
                rval_first = next(iter(rval.values()))
                if not isinstance(rval_first, dict):
                    raise TypeError(
                        f"Expected a dictionary as the top-level element in {files}, "
                        f"got {type(rval_first)}."
                    )
                c = cls(**rval_first, **kwargs)
            except Exception as e:
                logging.warning(
                    "Error parsing %s with top key %s. Error: %s", files, top_key, e
                )
        if c is None:
            c = cls(**rval, **kwargs)

        if extra_elems:
            # NOTE(review): duplicate top-level keys are only logged here; they are
            # not merged into ``c``.
            logging.info(
                "Parsing extra attributes %s",
                ", ".join(str(k) for k, _ in extra_elems),
            )
        c._yaml_source = ",".join(files)
        return c
593
+
594
+
595
def _parse_field(
    field,
    value,
    validator,
    symbol_table,
    parent,
    must_parse_try_parse_to: bool = False,
    must_copy: bool = True,
    **kwargs,
):
    """Parse a single field value according to its validator.

    When the validator's origin is ``ParsesTo`` or ``TryParseTo``, ``value`` is
    evaluated against the first type argument (the target type):

    * ``InvertibleSet`` targets go through ``eval_set_expression``. A Python
      ``set`` is first joined into a ``" | "`` expression string.
    * Literal strings (per ``is_literal_string``) become ``LiteralString``.
    * Anything else goes through ``parse_expression``.

    The parsed result must be an instance of the target type unless the target
    is ``Any`` or a tuple containing ``Any``. For ``TryParseTo``, a ``ParseError``
    falls back to ``LiteralString(value)`` unless ``must_parse_try_parse_to``
    is set.

    Afterwards, ``Parsable`` results are parsed recursively, unless the
    validator's origin is ``NoParse``. Plain ``str`` results are wrapped in
    ``LiteralString``.

    Args:
        field: Name/key of the field; used as the error location.
        value: The raw field value.
        validator: The field's type annotation.
        symbol_table: Names available to expressions.
        parent: Container holding the field. ``parent[field].name`` is preferred
            as the error label when it exists.
        must_parse_try_parse_to: Treat ``TryParseTo`` failures as errors instead
            of falling back to a literal string.
        must_copy: Deep-copy a ``ParsesTo``/``TryParseTo`` result that is the very
            same object as the input, and forward the flag to nested parsing.
        **kwargs: Forwarded to nested ``_parse_expressions`` calls.

    Returns:
        The parsed value.

    Raises:
        ParseError: If parsing fails; the field label is attached with
            ``add_field`` before re-raising.
        NotImplementedError: If ``InvertibleSet`` appears inside a tuple target.
    """
    from accelforge.util._setexpressions import InvertibleSet, eval_set_expression

    def check_subclass(x, cls):
        return isinstance(x, type) and issubclass(x, cls)

    try:
        # Get the origin type (ParsesTo or TryParseTo) and its arguments
        origin = get_origin(validator)
        if origin is ParsesTo or origin is TryParseTo:
            try:
                target_type = get_args(validator)[0]
                parsed = value
                if isinstance(target_type, tuple) and any(
                    check_subclass(t, InvertibleSet) for t in target_type
                ):
                    raise NotImplementedError(
                        f"InvertibleSet must be used directly, not as a part of a "
                        f"union, else this function must be updated."
                    )

                # Check if validator is for InvertibleSet
                if check_subclass(target_type, InvertibleSet):
                    # If the given type is a set, replace it with a string that'll parse
                    if isinstance(value, set):
                        value = " | ".join(str(v) for v in value)

                    # Get the element type from the InvertibleSet's generic argument
                    type_args = target_type.__pydantic_generic_metadata__["args"]
                    assert len(type_args) == 1, "Expected exactly one type argument"
                    expected_element_type = type_args[0]

                    # eval_set_expression does the type checking for us. A
                    # ParseError raised here is handled by the except clause
                    # below (TryParseTo fallback or re-raise).
                    return eval_set_expression(
                        value,
                        symbol_table,
                        expected_space=expected_element_type,
                        location=field,
                    )
                elif is_literal_string(value):
                    parsed = LiteralString(value)
                else:
                    parsed = parse_expression(value, symbol_table)

                # Never hand back the caller's object itself when a copy is required.
                if must_copy and parsed is value:
                    parsed = copy.deepcopy(parsed)

                # Any, alone or inside a tuple of types, accepts every result.
                target_any = target_type is Any or (
                    isinstance(target_type, tuple) and Any in target_type
                )
                if not target_any and not isinstance(parsed, target_type):
                    raise ParseError(
                        f'{value} parsed to "{parsed}" with type {type(parsed).__name__}.'
                        f" Expected {target_type}.",
                    )
            except ParseError:
                if origin is TryParseTo and not must_parse_try_parse_to:
                    return LiteralString(value)
                raise
        else:
            parsed = value

        if isinstance(parsed, Parsable) and origin is not NoParse:
            parsed, _ = parsed._parse_expressions(
                symbol_table=symbol_table,
                must_copy=must_copy,
                must_parse_try_parse_to=must_parse_try_parse_to,
                **kwargs,
            )
            return parsed
        if isinstance(parsed, str):
            return LiteralString(parsed)
        return parsed
    except ParseError as e:
        # Prefer the stored value's own ``name`` (e.g. a named list element) as
        # the error location. Fall back to the raw field key if that lookup
        # fails for any reason (missing key, no ``name`` attribute, ...).
        try:
            e.add_field(parent[field].name)
        except Exception:
            e.add_field(field)
        raise e
693
+
694
+
695
+ # python_name_regex = re.compile(r'\b[a-zA-Z_][a-zA-Z0-9_]*\b')
696
+
697
+
698
def _get_parsable_field_order(
    order: tuple[str, ...], field_value_validator_triples: list[tuple[str, Any, type]]
) -> list[str]:
    """Return the order in which fields should be parsed.

    The result is built in three stages:

    1. Fields already listed in ``order`` keep their positions at the front.
    2. Fields whose validator origin is not ``ParsesTo`` and whose value is not
       ``Parsable`` are appended next, in input order.
    3. The remaining fields are topologically sorted. A field whose value is a
       non-literal string that mentions another remaining field's name as a
       whole word (regex ``\\b``) is placed after that field. When several fields
       are ready, the first non-``Parsable`` one (in input order) is taken;
       otherwise the first ready one.

    Args:
        order: Field names that must come first, in this order.
        field_value_validator_triples: ``(field, value, validator)`` for every
            field to be ordered.

    Returns:
        The full list of field names in parse order.

    Raises:
        ParseError: If the remaining fields' dependencies form a cycle.
    """

    def is_parsable(value, validator):
        # Only the value's type matters today; the validator is accepted so the
        # rule can be extended without changing call sites.
        return isinstance(value, Parsable)

    order = list(order)
    placed = set(order)  # mirrors ``order`` for O(1) membership checks
    to_sort = []

    for field, value, validator in field_value_validator_triples:
        if field in placed:
            continue
        # NOTE(review): only ParsesTo is dependency-sorted; TryParseTo-annotated
        # string fields take this branch too. Confirm that is intended.
        if get_origin(validator) is not ParsesTo and not is_parsable(value, validator):
            order.append(field)
            placed.add(field)
            continue
        to_sort.append((field, value))

    field2validator = {f: validator for f, _, validator in field_value_validator_triples}

    dependencies = {field: set() for field, _ in to_sort}
    for other_field, other_value in to_sort:
        # Can't have any dependencies if you're not going to be parsed
        if not isinstance(other_value, str) or is_literal_string(other_value):
            continue
        for field, _ in to_sort:
            if field != other_field and re.search(
                r"\b" + re.escape(field) + r"\b", other_value
            ):
                dependencies[other_field].add(field)

    while to_sort:
        ready = [(f, v) for f, v in to_sort if dependencies[f] <= placed]
        if not ready:
            raise ParseError(
                f"Circular dependency detected in expressions. "
                f"Fields: {', '.join(t[0] for t in to_sort)}"
            )
        # Parsables last: take the first ready non-Parsable, else the first ready.
        chosen_field, _ = next(
            ((f, v) for f, v in ready if not is_parsable(v, field2validator[f])),
            ready[0],
        )
        order.append(chosen_field)
        placed.add(chosen_field)
        to_sort = [item for item in to_sort if item[0] != chosen_field]
    return order
749
+
750
+
751
class _OurBaseModel(BaseModel, _FromYAMLAble, Mapping):
    """Shared pydantic base adding YAML dumping and a mapping-style interface.

    The mapping interface works as follows:

    * ``model[key]`` reads an attribute and ``model[key] = value`` sets one.
    * ``del model[key]`` deletes an attribute.
    * Iteration and ``len`` cover the names returned by ``get_fields``.
    """

    # Exclude is supported OK, but makes the docs a lot longer because it's in so many
    # objects and has a very long type.
    def to_yaml(self, f: str | None = None) -> str:
        """
        Dump the model to a YAML string.

        Parameters
        ----------
        f: str | None
            If given, the YAML is also written to this file.

        Returns
        -------
        str
            The YAML string (returned whether or not ``f`` is given).
        """

        def _to_str(value: Any):
            # Rebuild lists/dicts recursively and coerce ``str`` subclasses to
            # plain ``str`` so the YAML writer only sees builtin types.
            if isinstance(value, list):
                return [_to_str(x) for x in value]
            elif isinstance(value, dict):
                return {_to_str(k): _to_str(v) for k, v in value.items()}
            elif isinstance(value, str):
                return str(value)
            return value

        plain = _to_str(self.model_dump())
        if f is not None:
            _yaml.write_yaml_file(f, plain)
        return _yaml.to_yaml_string(plain)

    def get_fields(self) -> list[str]:
        """Return the sorted names of the declared fields plus any extra fields.

        Subclasses may override this. It backs ``__iter__``, ``__len__`` and
        ``shallow_model_dump``, which previously failed with AttributeError on
        subclasses that did not define it.
        """
        fields = set(self.__class__.model_fields.keys())
        extra = getattr(self, "__pydantic_extra__", None)
        if extra is not None:
            fields.update(extra.keys())
        return sorted(fields)

    def all_fields_default(self):
        """Return True if every declared field equals its declared default.

        Fields declared with ``default_factory`` are compared against a freshly
        built default. Using ``FieldInfo.default`` alone would compare them with
        pydantic's "undefined" marker and always report them as non-default.
        Fields with no default never count as default.
        """
        for name, info in self.__class__.model_fields.items():
            default = info.default
            if info.default_factory is not None:
                try:
                    default = info.get_default(call_default_factory=True)
                except (TypeError, ValueError):
                    # e.g. factories that need the other validated fields; keep
                    # the previous behaviour (compare with the raw default).
                    pass
            if getattr(self, name) != default:
                return False
        return True

    def model_dump_non_none(self, **kwargs):
        """``model_dump(**kwargs)`` without the top-level entries whose value is None."""
        return {k: v for k, v in self.model_dump(**kwargs).items() if v is not None}

    def shallow_model_dump(self, include_None: bool = False, **kwargs):
        """Return ``{name: attribute}`` without recursively dumping sub-models.

        Keys are ``get_fields()`` followed by any extra (``__pydantic_extra__``)
        keys not already present. Entries whose value is None are dropped unless
        ``include_None`` is True. ``kwargs`` is accepted for signature
        compatibility and ignored.
        """
        keys = list(self.get_fields())
        seen = set(keys)
        extra = getattr(self, "__pydantic_extra__", None)
        if extra is not None:
            for k in extra.keys():
                if k not in seen:
                    keys.append(k)
                    seen.add(k)

        dumped = {k: getattr(self, k) for k in keys}
        if not include_None:
            dumped = {k: v for k, v in dumped.items() if v is not None}
        return dumped

    def __contains__(self, key: str) -> bool:
        try:
            self[key]
            return True
        except KeyError:
            return False

    def __getitem__(self, key: str) -> Any:
        # Any attribute reachable with getattr is readable by key.
        try:
            return getattr(self, key)
        except AttributeError:
            raise KeyError(
                f"Key {key} not found in {self.__class__.__name__}"
            ) from None

    def __setitem__(self, key: str, value: Any):
        setattr(self, key, value)

    def __delitem__(self, key: str):
        delattr(self, key)

    def __iter__(self) -> Iterator[str]:
        return iter(self.get_fields())

    def __len__(self) -> int:
        return len(self.get_fields())
832
+
833
+
834
@_uninstantiable
class ParsableModel(_OurBaseModel, Parsable["ParsableModel"]):
    """A model that will parse any fields that are given to it. When parsing, submodels
    will also be parsed if they support it. Parsing will parse any fields that are given
    as strings and do not match the expected type.
    """

    model_config = ConfigDict(extra="forbid")
    # ``type`` is deliberately not a declared field: ``__init__`` pops it from the
    # keyword arguments and only uses it as an ``isinstance`` check.

    def __init__(self, **kwargs):
        """Construct the model, checking field names and an optional ``type``.

        Args:
            **kwargs: Field values. An optional ``type`` entry is removed before
                validation. If it is not None, the constructed object must be an
                instance of it.

        Raises:
            ValueError: If ``extra == "forbid"`` and an unknown field is given.
                The message lists the supported field names.
            TypeError: If ``type`` cannot be used with ``isinstance``, or the
                constructed object is not an instance of it.
        """
        required_type = kwargs.pop("type", None)

        if self.model_config["extra"] == "forbid":
            # Raise our own error (listing valid names) before pydantic does.
            supported_fields = set(self.__class__.model_fields.keys())
            for k in kwargs:
                if k not in supported_fields:
                    raise ValueError(
                        f"Field {k} is not supported for {self.__class__.__name__}. "
                        f"Supported fields are:\n\t"
                        + "\n\t".join(sorted(supported_fields))
                        + "\n",
                    )

        super().__init__(**kwargs)
        if required_type is not None:
            try:
                passed_check = isinstance(self, required_type)
            except TypeError:
                raise TypeError(
                    f"Error checking required type. Was given type argument "
                    f"{required_type} a valid type?"
                ) from None

            if not passed_check:
                raise TypeError(
                    f"type field {required_type} does not match "
                    f"{self.__class__.__name__}"
                )

    def get_validator(self, field: str) -> Type:
        """Declared fields use their annotation; any other name is ``ParsesTo[Any]``."""
        if field in self.__class__.model_fields:
            return self.__class__.model_fields[field].annotation
        return ParsesTo[Any]

    def get_fields(self) -> list[str]:
        """Return the sorted declared field names plus any extra field names."""
        fields = set(self.__class__.model_fields.keys())
        if getattr(self, "__pydantic_extra__", None) is not None:
            fields.update(self.__pydantic_extra__.keys())
        return sorted(fields)

    def _parse_expressions(
        self,
        symbol_table: dict[str, Any] = None,
        order: tuple[str, ...] = (),
        post_calls: tuple[_PostCall[T], ...] = (),
        already_parsed: dict[str, Any] | None = None,
        **kwargs,
    ) -> tuple[Self, dict[str, Any]]:
        """Parse a shallow ``model_copy()`` of this model.

        Parsed values are assigned with setattr on the copy, not on ``self``. The
        caller's ``symbol_table`` is copied before use.
        """
        new = self.model_copy()
        symbol_table = symbol_table.copy() if symbol_table is not None else {}
        return new._parse_expressions_final(
            symbol_table,
            order,
            post_calls,
            use_setattr=True,
            already_parsed=already_parsed,
            **kwargs,
        )
904
+
905
+
906
class NonParsableModel(_OurBaseModel):
    """Model whose fields are taken as-is; every field's validator is ``Any``."""

    type: Optional[str] = None
    model_config = ConfigDict(extra="forbid")

    def get_validator(self, field: str) -> Type:
        # Uniform answer for every field name: nothing is parsed.
        return Any
914
+
915
+
916
class ParsableList(list[T], Parsable["ParsableList[T]"], Generic[T]):
    """
    A list that can be parsed from a string. ParsableList[T] means that a given string
    can be parsed, yielding a list of objects of type T.

    Besides integer and slice indexing, elements can be looked up by name.
    ``lst["x"]`` returns the unique element whose ``name`` equals ``"x"``: the
    attribute for objects, or the ``"name"`` key for dicts.
    """

    def get_validator(self, field: str) -> Type:
        # Every position is validated against the element type.
        return T

    def _parse_expressions(
        self,
        symbol_table: dict[str, Any] = None,
        order: tuple[str, ...] = (),
        post_calls: tuple[_PostCall[T], ...] = (),
        already_parsed: dict[str, Any] | None = None,
        **kwargs,
    ) -> tuple["ParsableList[T]", dict[str, Any]]:
        """Parse a copy of this list.

        Indices missing from ``order`` are appended to it in ascending order.
        """
        new = ParsableList[T](x for x in self)
        symbol_table = symbol_table.copy() if symbol_table is not None else {}
        order = order + tuple(x for x in range(len(new)) if x not in order)
        return new._parse_expressions_final(
            symbol_table,
            order,
            post_calls,
            use_setattr=False,
            already_parsed=already_parsed,
            **kwargs,
        )

    def get_fields(self) -> list[int]:
        """Return the valid indices ``0 .. len(self) - 1``."""
        return list(range(len(self)))

    @staticmethod
    def _element_name(elem: Any) -> Any:
        """Return ``elem["name"]`` for dicts, ``elem.name`` if present, else None."""
        if isinstance(elem, dict):
            return elem.get("name", None)
        if hasattr(elem, "name"):
            return elem.name
        return None

    @classmethod
    def __get_pydantic_core_schema__(
        cls, source_type: Any, handler: Callable
    ) -> CoreSchema:
        # Get the type parameter T from ParsableList[T]
        type_args = get_args(source_type)
        if not type_args:
            raise TypeError(
                f"ParsableList must be used with a type parameter, e.g. ParsableList[int]"
            )
        item_type = type_args[0]

        # The item schema is unrelated to the list's own context, so use
        # generate_schema (as ParsableDict does) rather than handler(...).
        item_schema = handler.generate_schema(item_type)

        # Validate as a list of items, then wrap the result in this class.
        return chain_schema(
            [
                list_schema(item_schema),
                no_info_plain_validator_function(lambda x: cls(x)),
            ]
        )

    def __getitem__(self, key: str | int | slice) -> T:
        """Index by position, slice, or element name.

        Raises:
            IndexError: For an out-of-range integer index.
            ValueError: If several elements share the requested name.
            KeyError: If no element has the requested name, or the key type is
                unsupported.
        """
        if isinstance(key, int):
            return super().__getitem__(key)  # type: ignore

        if isinstance(key, slice):
            return ParsableList[T](super().__getitem__(key))

        if isinstance(key, str):
            found = None
            for elem in self:
                name = self._element_name(elem)
                if name is not None and name == key:
                    if found is not None:
                        raise ValueError(f'Multiple elements with name "{key}" found.')
                    found = elem
            if found is not None:
                return found

        # List indices and element names, to help the user pick a valid key.
        candidates = list(self.get_fields()) + [self._element_name(x) for x in self]
        available = ", ".join(sorted(str(x) for x in candidates if x is not None))
        raise KeyError(
            f'No element with name "{key}" found. Available names: {available}'
        )

    def __contains__(self, item: Any) -> bool:
        # Only strings are looked up by element name. Integers and slices would
        # otherwise be treated as *positions* by __getitem__, making ``0 in lst``
        # True for any non-empty list and raising IndexError for an
        # out-of-range int. Everything else uses plain list membership.
        if isinstance(item, str):
            try:
                self[item]
                return True
            except KeyError:
                pass
        return super().__contains__(item)

    def __copy__(self) -> Self:
        return type(self)(x for x in self)
1016
+
1017
+
1018
class ParsableDict(
    dict[K, V], Parsable["ParsableDict[K, V]"], Generic[K, V], _FromYAMLAble
):
    """Dictionary that takes part in expression parsing.

    ``ParsableDict[K, V]`` holds keys of type ``K`` and values of type ``V``.
    ``get_validator`` reports ``V`` for every key, and ``get_fields`` lists the
    keys in sorted order.
    """

    def get_validator(self, field: str) -> type:
        # One validator (the value type) for every key.
        return V

    def get_fields(self) -> list[str]:
        # Iterating a dict yields its keys.
        return sorted(self)

    def _parse_expressions(
        self,
        symbol_table: dict[str, Any] = None,
        order: tuple[str, ...] = (),
        post_calls: tuple[_PostCall[V], ...] = (),
        already_parsed: dict[str, Any] | None = None,
        **kwargs,
    ) -> tuple["ParsableDict[K, V]", dict[str, Any]]:
        """Parse a shallow copy of this dict, leaving the caller's symbol table alone."""
        clone = ParsableDict[K, V](self)
        if symbol_table is None:
            table = {}
        else:
            table = symbol_table.copy()
        return clone._parse_expressions_final(
            table,
            order,
            post_calls,
            use_setattr=False,
            already_parsed=already_parsed,
            **kwargs,
        )

    @classmethod
    def __get_pydantic_core_schema__(
        cls, source_type: Any, handler: Callable
    ) -> CoreSchema:
        # ParsableDict[K, V] must carry exactly two type arguments.
        params = get_args(source_type)
        if len(params) != 2:
            raise TypeError(
                f"ParsableDict must be used with two type parameters, e.g. ParsableDict[str, int]"
            )
        key_type, value_type = params

        # Validate as a plain dict of (K, V) first, then wrap in this class.
        as_plain_dict = dict_schema(
            handler.generate_schema(key_type),
            handler.generate_schema(value_type),
        )

        def _wrap(validated):
            return cls(validated)

        return chain_schema([as_plain_dict, no_info_plain_validator_function(_wrap)])

    def __copy__(self) -> Self:
        return type(self)(dict(self.items()))
1077
+
1078
+
1079
class ParseExtras(ParsableModel):
    """
    Parsable model that also accepts fields not declared on the class and parses them.
    """

    model_config = ConfigDict(extra="allow")

    def get_validator(self, field: str) -> type:
        # Declared fields keep their annotation; undeclared extras parse to anything.
        declared = self.__class__.model_fields
        return declared[field].annotation if field in declared else ParsesTo[Any]