accelforge 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- accelforge/__init__.py +21 -0
- accelforge/_accelerated_imports.py +16 -0
- accelforge/_deprecate/_simanneal/evalmapping.py +271 -0
- accelforge/_deprecate/_simanneal/mapspaceglobals.py +298 -0
- accelforge/_deprecate/_simanneal/simanneal.py +666 -0
- accelforge/_deprecate/_simanneal/tracking.py +105 -0
- accelforge/_deprecate/_simanneal/wrappers.py +218 -0
- accelforge/_deprecate/_simanneal2/__init__.py +7 -0
- accelforge/_deprecate/_simanneal2/simanneal.py +493 -0
- accelforge/_deprecate/_simanneal2/tracking.py +116 -0
- accelforge/_deprecate/compatibility_util.py +181 -0
- accelforge/_deprecate/layerdeduplication/__init__.py +2 -0
- accelforge/_deprecate/layerdeduplication/group_similar_einsums.py +160 -0
- accelforge/_deprecate/layerdeduplication/grouped_einsums.py +84 -0
- accelforge/_deprecate/mapping_filter_tags/__init__.py +2 -0
- accelforge/_deprecate/mapping_filter_tags/ffmt.py +212 -0
- accelforge/_deprecate/mapping_filter_tags/onesplit.py +24 -0
- accelforge/_deprecate/mapping_filter_tags/util.py +24 -0
- accelforge/_deprecate/tags.py +69 -0
- accelforge/_deprecate/viz/__init__.py +0 -0
- accelforge/_deprecate/viz/interactive.py +159 -0
- accelforge/_deprecate/viz/reservationtree.py +307 -0
- accelforge/_deprecate/viz/ski_slope.py +88 -0
- accelforge/_version.py +15 -0
- accelforge/examples.py +39 -0
- accelforge/frontend/__init__.py +10 -0
- accelforge/frontend/_binding.py +129 -0
- accelforge/frontend/_workload_isl/__init__.py +2 -0
- accelforge/frontend/_workload_isl/_isl.py +149 -0
- accelforge/frontend/_workload_isl/_symbolic.py +141 -0
- accelforge/frontend/arch copy.py +1544 -0
- accelforge/frontend/arch.py +1642 -0
- accelforge/frontend/config.py +63 -0
- accelforge/frontend/mapper/__init__.py +5 -0
- accelforge/frontend/mapper/ffm.py +126 -0
- accelforge/frontend/mapper/mapper.py +7 -0
- accelforge/frontend/mapper/metrics.py +30 -0
- accelforge/frontend/mapping/__init__.py +1 -0
- accelforge/frontend/mapping/mapping.py +1736 -0
- accelforge/frontend/model.py +14 -0
- accelforge/frontend/renames.py +150 -0
- accelforge/frontend/spec copy.py +230 -0
- accelforge/frontend/spec.py +301 -0
- accelforge/frontend/variables.py +12 -0
- accelforge/frontend/workload.py +952 -0
- accelforge/mapper/FFM/__init__.py +9 -0
- accelforge/mapper/FFM/_join_pmappings/__init__.py +0 -0
- accelforge/mapper/FFM/_join_pmappings/compatibility.py +653 -0
- accelforge/mapper/FFM/_join_pmappings/compress_pmappings.py +140 -0
- accelforge/mapper/FFM/_join_pmappings/join_pmappings.py +703 -0
- accelforge/mapper/FFM/_join_pmappings/pmapping_dataframe.py +901 -0
- accelforge/mapper/FFM/_join_pmappings/pmapping_group.py +337 -0
- accelforge/mapper/FFM/_make_pmappings/contraints/__init__.py +0 -0
- accelforge/mapper/FFM/_make_pmappings/contraints/constraints.py +360 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/__init__.py +1 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_loops.py +373 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_pmapping_templates.py +463 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_reservations.py +95 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_storage_order.py +382 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_storages.py +155 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmappings.py +411 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/__init__.py +1 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/make_pmappings_from_templates.py +407 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/make_tile_shapes.py +1681 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/run_model.py +170 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/symbol_relations.py +174 -0
- accelforge/mapper/FFM/_make_pmappings/pmapper_job.py +282 -0
- accelforge/mapper/FFM/_pareto_df/df_convention.py +273 -0
- accelforge/mapper/FFM/_pareto_df/pareto copy.py +836 -0
- accelforge/mapper/FFM/_pareto_df/pareto.py +508 -0
- accelforge/mapper/FFM/data.py +61 -0
- accelforge/mapper/FFM/main copy.py +236 -0
- accelforge/mapper/FFM/main.py +208 -0
- accelforge/mapper/FFM/mappings.py +510 -0
- accelforge/mapper/FFM/pmappings.py +310 -0
- accelforge/mapper/__init__.py +4 -0
- accelforge/mapper.py +0 -0
- accelforge/model/__init__.py +1 -0
- accelforge/model/_looptree/__init__.py +0 -0
- accelforge/model/_looptree/accesses.py +335 -0
- accelforge/model/_looptree/capacity/__init__.py +1 -0
- accelforge/model/_looptree/capacity/aggregators.py +36 -0
- accelforge/model/_looptree/capacity/capacity.py +47 -0
- accelforge/model/_looptree/energy.py +150 -0
- accelforge/model/_looptree/equivalent_ranks.py +29 -0
- accelforge/model/_looptree/latency/__init__.py +1 -0
- accelforge/model/_looptree/latency/latency.py +98 -0
- accelforge/model/_looptree/latency/memory.py +120 -0
- accelforge/model/_looptree/latency/processors.py +92 -0
- accelforge/model/_looptree/mapping_utilities.py +71 -0
- accelforge/model/_looptree/reuse/__init__.py +4 -0
- accelforge/model/_looptree/reuse/isl/__init__.py +1 -0
- accelforge/model/_looptree/reuse/isl/des.py +59 -0
- accelforge/model/_looptree/reuse/isl/isl_functions.py +374 -0
- accelforge/model/_looptree/reuse/isl/mapping_to_isl/__init__.py +4 -0
- accelforge/model/_looptree/reuse/isl/mapping_to_isl/analyze_mapping.py +297 -0
- accelforge/model/_looptree/reuse/isl/mapping_to_isl/skews_from_mapping.py +236 -0
- accelforge/model/_looptree/reuse/isl/mapping_to_isl/tiling.py +685 -0
- accelforge/model/_looptree/reuse/isl/mapping_to_isl/types.py +188 -0
- accelforge/model/_looptree/reuse/isl/spatial.py +260 -0
- accelforge/model/_looptree/reuse/isl/temporal.py +182 -0
- accelforge/model/_looptree/reuse/symbolic/__init__.py +1 -0
- accelforge/model/_looptree/reuse/symbolic/symbolic copy 2.py +1346 -0
- accelforge/model/_looptree/reuse/symbolic/symbolic copy.py +1408 -0
- accelforge/model/_looptree/reuse/symbolic/symbolic.py +1396 -0
- accelforge/model/_looptree/run.py +122 -0
- accelforge/model/_looptree/types.py +26 -0
- accelforge/model/_looptree/visualization/__init__.py +0 -0
- accelforge/model/_looptree/visualization/occupancy.py +11 -0
- accelforge/model/main.py +222 -0
- accelforge/plotting/__init__.py +2 -0
- accelforge/plotting/mappings.py +219 -0
- accelforge/plotting/specs.py +57 -0
- accelforge/util/__init__.py +4 -0
- accelforge/util/_base_analysis_types.py +24 -0
- accelforge/util/_basetypes.py +1089 -0
- accelforge/util/_frozenset.py +36 -0
- accelforge/util/_isl.py +29 -0
- accelforge/util/_itertools.py +14 -0
- accelforge/util/_mathfuncs.py +57 -0
- accelforge/util/_parse_expressions.py +339 -0
- accelforge/util/_picklecache.py +32 -0
- accelforge/util/_setexpressions.py +268 -0
- accelforge/util/_sympy/__init__.py +0 -0
- accelforge/util/_sympy/broadcast_max.py +18 -0
- accelforge/util/_visualization.py +112 -0
- accelforge/util/_yaml.py +579 -0
- accelforge/util/parallel.py +193 -0
- accelforge-0.0.1.dist-info/METADATA +64 -0
- accelforge-0.0.1.dist-info/RECORD +258 -0
- accelforge-0.0.1.dist-info/WHEEL +5 -0
- accelforge-0.0.1.dist-info/licenses/LICENSE +19 -0
- accelforge-0.0.1.dist-info/top_level.txt +5 -0
- docs/_build/html/_sources/fastfusion.frontend.mapper.rst.txt +37 -0
- docs/_build/html/_sources/fastfusion.frontend.rst.txt +70 -0
- docs/_build/html/_sources/fastfusion.frontend.workload.rst.txt +21 -0
- docs/_build/html/_sources/fastfusion.mapper.FFM.rst.txt +37 -0
- docs/_build/html/_sources/fastfusion.mapper.rst.txt +18 -0
- docs/_build/html/_sources/fastfusion.rst.txt +20 -0
- docs/_build/html/_sources/fastfusion.util.rst.txt +21 -0
- docs/_build/html/_sources/index.rst.txt +87 -0
- docs/_build/html/_sources/modules.rst.txt +7 -0
- docs/_build/html/_sources/notes/citation.rst.txt +45 -0
- docs/_build/html/_sources/notes/definitions.rst.txt +43 -0
- docs/_build/html/_sources/notes/faqs.rst.txt +39 -0
- docs/_build/html/_sources/notes/modeling/accelerator_energy_latency.rst.txt +72 -0
- docs/_build/html/_sources/notes/modeling/component_energy_area.rst.txt +96 -0
- docs/_build/html/_sources/notes/modeling/mapping.rst.txt +100 -0
- docs/_build/html/_sources/notes/modeling.rst.txt +33 -0
- docs/_build/html/_sources/notes/parsing/arithmetic_parsing.rst.txt +136 -0
- docs/_build/html/_sources/notes/parsing/setexpressions.rst.txt +63 -0
- docs/_build/html/_sources/notes/parsing/yaml_parsing.rst.txt +176 -0
- docs/_build/html/_sources/notes/quickstart_and_installation.rst.txt +9 -0
- docs/_build/html/_sources/notes/spec/architecture.rst.txt +133 -0
- docs/_build/html/_sources/notes/spec/mapping.rst.txt +12 -0
- docs/_build/html/_sources/notes/spec/workload.rst.txt +83 -0
- docs/_build/html/_sources/notes/spec.rst.txt +36 -0
- docs/source/_ext/include_attrs.py +213 -0
- docs/source/_ext/include_docstring.py +364 -0
- docs/source/_ext/include_functions.py +154 -0
- docs/source/_ext/include_notebook.py +131 -0
- docs/source/_ext/include_yaml.py +119 -0
- docs/source/_ext/inherited_attributes.py +222 -0
- docs/source/_ext/paths.py +4 -0
- docs/source/conf.py +79 -0
- examples/arches/compute_in_memory/_include.yaml +74 -0
- examples/arches/compute_in_memory/_include_functions.py +229 -0
- examples/arches/compute_in_memory/_load_spec.py +57 -0
- examples/arches/compute_in_memory/components/c2c_multiplier.py +181 -0
- examples/arches/compute_in_memory/components/dac_c2c_r2r.py +605 -0
- examples/arches/compute_in_memory/components/misc.py +195 -0
- examples/arches/compute_in_memory/components/util/bit_functions.py +51 -0
- examples/arches/compute_in_memory/components/zero_comparator.py +92 -0
- examples/arches/compute_in_memory/isaac.yaml +233 -0
- examples/arches/compute_in_memory/memory_cells/ecram_demo.yaml +63 -0
- examples/arches/compute_in_memory/memory_cells/rram_example.yaml +63 -0
- examples/arches/compute_in_memory/memory_cells/rram_isaac_isca_2016.yaml +64 -0
- examples/arches/compute_in_memory/memory_cells/rram_neurosim_default.yaml +63 -0
- examples/arches/compute_in_memory/memory_cells/rram_raella_isca_2023.yaml +70 -0
- examples/arches/compute_in_memory/memory_cells/rram_wan_nature_2022.yaml +63 -0
- examples/arches/compute_in_memory/memory_cells/sram_colonnade_jssc_2021.yaml +63 -0
- examples/arches/compute_in_memory/memory_cells/sram_example.yaml +63 -0
- examples/arches/compute_in_memory/memory_cells/sram_jia_jssc_2020.yaml +63 -0
- examples/arches/compute_in_memory/memory_cells/sram_sinangil_jssc_2021.yaml +63 -0
- examples/arches/compute_in_memory/memory_cells/sram_wang_vlsi_2022.yaml +63 -0
- examples/arches/compute_in_memory/wang_vlsi_2022.yaml +289 -0
- examples/arches/eyeriss.yaml +68 -0
- examples/arches/fanout_variations/at_glb.yaml +31 -0
- examples/arches/fanout_variations/at_glb_with_fanout_node.yaml +34 -0
- examples/arches/fanout_variations/at_mac.yaml +31 -0
- examples/arches/fanout_variations/at_mac_with_constraints.yaml +38 -0
- examples/arches/fanout_variations/at_mac_with_fanout_node.yaml +34 -0
- examples/arches/nvdla.yaml +47 -0
- examples/arches/simple.yaml +28 -0
- examples/arches/tpu_v4i.yaml +67 -0
- examples/mappings/unfused_matmuls_to_simple.yaml +33 -0
- examples/misc/component_annotated.yaml +33 -0
- examples/workloads/gpt3_6.7B.yaml +124 -0
- examples/workloads/matmuls.yaml +20 -0
- examples/workloads/mobilenet_28.yaml +81 -0
- examples/workloads/mobilenet_various_separate.yaml +106 -0
- examples/workloads/three_matmuls_annotated.yaml +59 -0
- notebooks/.ipynb_checkpoints/fastfusion_arch_study_michael-checkpoint.ipynb +359 -0
- notebooks/compute_in_memory/_scripts.py +339 -0
- notebooks/compute_in_memory/isaac.guide.ipynb +270 -0
- notebooks/compute_in_memory/wang_vlsi_2022.ipynb +602 -0
- notebooks/paths.py +4 -0
- notebooks/tutorials/.ipynb_checkpoints/1_FFM-checkpoint.ipynb +3110 -0
- notebooks/tutorials/FFM.ipynb +3498 -0
- notebooks/tutorials/_include.py +48 -0
- notebooks/tutorials/component_energy_area.ipynb +363 -0
- tests/Q_mapping.yaml +38 -0
- tests/__init__.py +0 -0
- tests/conv.mapping.yaml +27 -0
- tests/conv.workload.yaml +13 -0
- tests/conv_sym.mapping.yaml +43 -0
- tests/copy.mapping.yaml +35 -0
- tests/copy.workload.yaml +15 -0
- tests/distribuffers/__init__.py +0 -0
- tests/distribuffers/multicast/test_cases.yaml +482 -0
- tests/distribuffers/spec/binding/valid_bindings.yaml +97 -0
- tests/distribuffers/spec/distributed.yaml +100 -0
- tests/distribuffers/spec/logical_arch.yaml +32 -0
- tests/distribuffers/spec/physical_arch.yaml +69 -0
- tests/distribuffers/test_binding.py +48 -0
- tests/frontend/__init__.py +0 -0
- tests/frontend/test_mapping_viz.py +52 -0
- tests/mapper/__init__.py +0 -0
- tests/mapper/configs/conv1d/conv1d.mapping.yaml +31 -0
- tests/mapper/configs/conv1d/conv1d.workload.yaml +11 -0
- tests/mapper/configs/two_conv1d/two_conv1d.expected.yaml +38 -0
- tests/mapper/configs/two_conv1d/two_conv1d.mapping.yaml +54 -0
- tests/mapper/configs/two_conv1d/two_conv1d.workload.yaml +19 -0
- tests/mapper/test_mapping_to_isl.py +90 -0
- tests/mapper/test_spatial_reuse_analysis.py +67 -0
- tests/mapper/test_temporal_reuse_analysis.py +56 -0
- tests/mapper/util.py +58 -0
- tests/matmul.mapping.yaml +29 -0
- tests/matmul.workload.yaml +12 -0
- tests/matmul_spatial.mapping.yaml +44 -0
- tests/mha.renames.yaml +65 -0
- tests/mha.workload.yaml +67 -0
- tests/mha.yaml +59 -0
- tests/mha_full.workload.yaml +67 -0
- tests/mobilenet.workload.yaml +35 -0
- tests/mobilenet_long.workload.yaml +64 -0
- tests/pmappingcache.py +24 -0
- tests/processing_stage.arch.yaml +40 -0
- tests/snowcat.arch.yaml +36 -0
- tests/test_ffm_join_pmappings.py +106 -0
- tests/test_ffm_make_pmappings.py +82 -0
- tests/test_ffm_make_tile_shapes.py +49 -0
- tests/test_mapper.py +100 -0
- tests/test_model.py +37 -0
- tests/test_plotting.py +72 -0
- tests/test_processing_stage.py +46 -0
- tests/test_symbolic_model.py +248 -0
- tests/test_workload.py +141 -0
|
@@ -0,0 +1,1089 @@
|
|
|
1
|
+
import copy
|
|
2
|
+
import glob
|
|
3
|
+
import logging
|
|
4
|
+
import os
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
import re
|
|
7
|
+
from pydantic import BaseModel, ConfigDict, Tag, ValidationError
|
|
8
|
+
from pydantic.main import IncEx
|
|
9
|
+
from pydantic_core.core_schema import (
|
|
10
|
+
CoreSchema,
|
|
11
|
+
chain_schema,
|
|
12
|
+
list_schema,
|
|
13
|
+
union_schema,
|
|
14
|
+
no_info_plain_validator_function,
|
|
15
|
+
str_schema,
|
|
16
|
+
dict_schema,
|
|
17
|
+
tagged_union_schema,
|
|
18
|
+
)
|
|
19
|
+
from typing import (
|
|
20
|
+
Iterator,
|
|
21
|
+
List,
|
|
22
|
+
Mapping,
|
|
23
|
+
TypeVar,
|
|
24
|
+
Generic,
|
|
25
|
+
Any,
|
|
26
|
+
Callable,
|
|
27
|
+
TypeVarTuple,
|
|
28
|
+
Dict,
|
|
29
|
+
Optional,
|
|
30
|
+
Type,
|
|
31
|
+
TypeAlias,
|
|
32
|
+
Union,
|
|
33
|
+
get_args,
|
|
34
|
+
get_origin,
|
|
35
|
+
TYPE_CHECKING,
|
|
36
|
+
Self,
|
|
37
|
+
)
|
|
38
|
+
|
|
39
|
+
from accelforge.util import _yaml
|
|
40
|
+
from accelforge.util._parse_expressions import (
|
|
41
|
+
parse_expression,
|
|
42
|
+
ParseError,
|
|
43
|
+
LiteralString,
|
|
44
|
+
is_literal_string,
|
|
45
|
+
)
|
|
46
|
+
|
|
47
|
+
# Import will be resolved at runtime to avoid circular dependency
|
|
48
|
+
TYPE_CHECKING_RUNTIME = False
|
|
49
|
+
if TYPE_CHECKING or TYPE_CHECKING_RUNTIME:
|
|
50
|
+
from accelforge.util._setexpressions import InvertibleSet, eval_set_expression
|
|
51
|
+
|
|
52
|
+
# Generic type variables shared by the parsing machinery in this module.
T = TypeVar("T")  # Arbitrary value type.
M = TypeVar("M", bound=BaseModel)  # Any pydantic model.
K = TypeVar("K")  # Mapping key type.
V = TypeVar("V")  # Mapping value type.
# Forward references: presumably bound to classes defined later in this module.
PM = TypeVar("PM", bound="ParsableModel")
PL = TypeVar("PL", bound="ParsableList[Any]")

# Variadic type parameter, used by _InferFromTag[*Ts] to accept any number
# of candidate classes.
Ts = TypeVarTuple("Ts")
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
def _get_tag(value: Any) -> str:
|
|
63
|
+
if not isinstance(value, dict):
|
|
64
|
+
return value.__class__.__name__
|
|
65
|
+
tag = None
|
|
66
|
+
|
|
67
|
+
def try_get_tag(attr: str) -> str:
|
|
68
|
+
if hasattr(value, attr) and getattr(value, attr) is not None:
|
|
69
|
+
return getattr(value, attr)
|
|
70
|
+
return None
|
|
71
|
+
|
|
72
|
+
def try_index(attr: str) -> str:
|
|
73
|
+
try:
|
|
74
|
+
return value[attr]
|
|
75
|
+
except:
|
|
76
|
+
return None
|
|
77
|
+
|
|
78
|
+
tag = None
|
|
79
|
+
for attr in ("type", "_type", "_yaml_tag"):
|
|
80
|
+
if tag := try_get_tag(attr):
|
|
81
|
+
break
|
|
82
|
+
if tag := try_index(attr):
|
|
83
|
+
break
|
|
84
|
+
if tag is None:
|
|
85
|
+
raise ValueError(
|
|
86
|
+
f"No tag found for {value}. Either set the type field " "or use a YAML tag."
|
|
87
|
+
)
|
|
88
|
+
tag = str(tag)
|
|
89
|
+
if tag.startswith("!"):
|
|
90
|
+
tag = tag[1:]
|
|
91
|
+
return tag
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
def _uninstantiable(cls):
|
|
95
|
+
prev_init = cls.__init__
|
|
96
|
+
|
|
97
|
+
def _get_all_subclasses(cls):
|
|
98
|
+
subclasses = set()
|
|
99
|
+
for subclass in cls.__subclasses__():
|
|
100
|
+
subclasses.add(subclass.__name__)
|
|
101
|
+
subclasses.update(_get_all_subclasses(subclass))
|
|
102
|
+
return subclasses
|
|
103
|
+
|
|
104
|
+
def __init__(self, *args, **kwargs):
|
|
105
|
+
if self.__class__ is cls:
|
|
106
|
+
subclasses = _get_all_subclasses(cls)
|
|
107
|
+
raise ValueError(
|
|
108
|
+
f"{cls} can not be instantiated directly. Use a subclass. "
|
|
109
|
+
f"Supported subclasses are:\n\t" + "\n\t".join(sorted(subclasses))
|
|
110
|
+
)
|
|
111
|
+
return prev_init(self, *args, **kwargs)
|
|
112
|
+
|
|
113
|
+
cls.__init__ = __init__
|
|
114
|
+
return cls
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
class _InferFromTag(Generic[*Ts]):
|
|
118
|
+
@classmethod
|
|
119
|
+
def __get_pydantic_core_schema__(
|
|
120
|
+
cls, source_type: Any, handler: Callable
|
|
121
|
+
) -> CoreSchema:
|
|
122
|
+
type_args = get_args(source_type)
|
|
123
|
+
if not type_args:
|
|
124
|
+
raise TypeError(
|
|
125
|
+
f"_InferFromTag must be used with a type parameter, e.g. _InferFromTag[int]"
|
|
126
|
+
)
|
|
127
|
+
|
|
128
|
+
# type_args contains all the possible types: (Compute, Memory, "Hierarchical")
|
|
129
|
+
target_types = []
|
|
130
|
+
for arg in type_args:
|
|
131
|
+
if isinstance(arg, str):
|
|
132
|
+
# Handle string type names - we'll need to resolve them later
|
|
133
|
+
target_types.append(arg)
|
|
134
|
+
elif isinstance(arg, type):
|
|
135
|
+
target_types.append(arg)
|
|
136
|
+
else:
|
|
137
|
+
target_types.append(arg)
|
|
138
|
+
|
|
139
|
+
# Create tag to class mapping
|
|
140
|
+
tag2class = {}
|
|
141
|
+
for target_type in target_types:
|
|
142
|
+
if isinstance(target_type, str):
|
|
143
|
+
# For string types, use the string as both key and placeholder
|
|
144
|
+
tag2class[target_type] = target_type
|
|
145
|
+
elif hasattr(target_type, "__name__"):
|
|
146
|
+
tag2class[target_type.__name__] = target_type
|
|
147
|
+
else:
|
|
148
|
+
# Fallback for other types
|
|
149
|
+
tag2class[str(target_type)] = target_type
|
|
150
|
+
|
|
151
|
+
def validate(value: Any) -> T:
|
|
152
|
+
if hasattr(value, "_yaml_tag"):
|
|
153
|
+
tag = value._yaml_tag
|
|
154
|
+
elif hasattr(value, "_type"):
|
|
155
|
+
tag = value._type
|
|
156
|
+
else:
|
|
157
|
+
for to_try in ("_yaml_tag", "_type", "type"):
|
|
158
|
+
try:
|
|
159
|
+
tag = value[to_try]
|
|
160
|
+
break
|
|
161
|
+
except:
|
|
162
|
+
pass
|
|
163
|
+
else:
|
|
164
|
+
raise ValueError(
|
|
165
|
+
f"No tag found for {value}. Either set the type field "
|
|
166
|
+
"or use a YAML tag."
|
|
167
|
+
)
|
|
168
|
+
tag = str(tag)
|
|
169
|
+
if tag.startswith("!"):
|
|
170
|
+
tag = tag[1:]
|
|
171
|
+
value._type = tag
|
|
172
|
+
|
|
173
|
+
print(f"Tag found! {tag}")
|
|
174
|
+
if tag in tag2class:
|
|
175
|
+
return tag2class[tag](**value)
|
|
176
|
+
else:
|
|
177
|
+
raise ValueError(
|
|
178
|
+
f"Unknown tag: {tag}. Supported tags are: {sorted(tag2class.keys())}"
|
|
179
|
+
)
|
|
180
|
+
|
|
181
|
+
# target_schema = handler.generate_schema(target_types)
|
|
182
|
+
schemas = []
|
|
183
|
+
for t in target_types:
|
|
184
|
+
schemas.append(handler.generate_schema(t))
|
|
185
|
+
target_schema = union_schema(schemas)
|
|
186
|
+
# return chain_schema([
|
|
187
|
+
# no_info_plain_validator_function(validate),
|
|
188
|
+
# target_schema
|
|
189
|
+
# ])
|
|
190
|
+
return chain_schema(
|
|
191
|
+
[
|
|
192
|
+
no_info_plain_validator_function(validate),
|
|
193
|
+
tagged_union_schema(tag2class, discriminator="_type"),
|
|
194
|
+
]
|
|
195
|
+
)
|
|
196
|
+
|
|
197
|
+
|
|
198
|
+
class NoParse(Generic[T]):
    """A type annotation that skips expression parsing of the wrapped object.

    ``NoParse[T]`` fields are validated directly against T's own schema; no
    expression-parsing layer is added.
    """

    # Human-readable name used in error messages.
    _class_name: str = "NoParse"

    def __init__(self, value: T):
        # The wrapped, unparsed value.
        self._value = value
        # NOTE(review): this stores the TypeVar object T itself, not the
        # concrete parameter type — TODO confirm intended.
        self._type = T

    @classmethod
    def __get_pydantic_core_schema__(
        cls, source_type: Any, handler: Callable
    ) -> CoreSchema:
        """Delegate validation of NoParse[T] entirely to T's schema."""
        # Get the type parameter T from NoParse[T].
        type_args = get_args(source_type)
        if not type_args:
            raise TypeError(
                f"{cls._class_name} must be used with a type parameter, "
                f"e.g. {cls._class_name}[int]"
            )
        target_type = type_args[0]

        # Fix: removed a dead inner `validate_raw_string` helper (copied from
        # ParsesTo but never referenced here) — this class adds no parsing.
        return handler(target_type)
|
|
230
|
+
|
|
231
|
+
|
|
232
|
+
class ParsesTo(Generic[T]):
    """A type that parses to the specified type T.

    Example:
        class Example(ParsableModel):
            a: ParsesTo[int]  # Will parse string expressions to integers
            b: ParsesTo[str]  # Will parse string expressions to strings
            c: str  # Regular string, no parsing
    """

    # Human-readable name used in error messages and repr; overridden by
    # subclasses (e.g. TryParseTo).
    _class_name: str = "ParsesTo"

    def __init__(self, value: str):
        # The raw, unparsed expression string.
        self._value = value
        # Whether the value is a quoted literal (literals skip expression parsing).
        self._is_literal_string = is_literal_string(value)
        # NOTE(review): this stores the TypeVar object T itself, not the
        # concrete type parameter, so the assert below compares a TypeVar
        # against str and is always true — the ParsesTo[str] guard never
        # fires at runtime. TODO confirm intended.
        self._type = T

        assert self._type != str, (
            f"{self._class_name}[str] is not allowed. Use str directly instead."
            f"If something should just be a string, no expressions are allowed. "
            f"This is so the users don't have to quote-wrap all strings."
        )

    def __str__(self) -> str:
        return str(self._value)

    def __repr__(self) -> str:
        return f"{self._class_name}({repr(self._value)})"

    @classmethod
    def __get_pydantic_core_schema__(
        cls, source_type: Any, handler: Callable
    ) -> CoreSchema:
        """Build the pydantic-core schema for ``ParsesTo[T]`` fields."""
        # Get the type parameter T from ParsesTo[T]
        type_args = get_args(source_type)
        if not type_args:
            raise TypeError(
                f"{cls._class_name} must be used with a type parameter, "
                f"e.g. {cls._class_name}[int]"
            )
        target_type = type_args[0]

        # Get the schema for the target type
        target_schema = handler(target_type)

        def validate_raw_string(value):
            # Wrap quoted literals so they bypass expression parsing.
            # Implicitly returns None for anything else; str_schema() in the
            # chain below then rejects the None, so the union falls through
            # to the next option.
            if isinstance(value, str) and is_literal_string(value):
                return LiteralString(value)
            # raise ValueError("Not a raw string")

        # Create a union schema that either validates as raw string or normal validation
        return union_schema(
            [
                # First option: validate as raw string
                chain_schema(
                    [
                        no_info_plain_validator_function(validate_raw_string),
                        str_schema(),
                        # target_schema
                    ]
                ),
                # Second option: normal validation (string then target type)
                # NOTE(review): with target_schema commented out, this option
                # currently accepts any string as-is — TODO confirm intended.
                chain_schema(
                    [
                        str_schema(),
                        # target_schema
                    ]
                ),
                # Third option: direct target type validation
                target_schema,
            ]
        )
|
|
304
|
+
|
|
305
|
+
|
|
306
|
+
class TryParseTo(ParsesTo, Generic[T]):
    """
    A type that tries to parse to the specified type T. If the parsing fails, the value
    is returned as a string.
    """

    # Overrides the name used in error messages / repr inherited from ParsesTo.
    _class_name: str = "TryParseTo"

    def __init__(self, value: str):
        # NOTE(review): delegates entirely to ParsesTo.__init__; this
        # override adds nothing and could be removed.
        super().__init__(value)

    @classmethod
    def __get_pydantic_core_schema__(
        cls, source_type: Any, handler: Callable
    ) -> CoreSchema:
        """Build the schema for ``TryParseTo[T]``.

        Near-duplicate of ``ParsesTo.__get_pydantic_core_schema__``; the only
        difference is the trailing ``str_schema()`` union option, which makes
        any value that fails every other option fall back to a plain string.
        """
        # Get the type parameter T from ParsesTo[T]
        type_args = get_args(source_type)
        if not type_args:
            raise TypeError(
                f"{cls._class_name} must be used with a type parameter, "
                f"e.g. {cls._class_name}[int]"
            )
        target_type = type_args[0]

        # Get the schema for the target type
        target_schema = handler(target_type)

        def validate_raw_string(value):
            # Wrap quoted literals so they bypass expression parsing.
            # Implicitly returns None otherwise; str_schema() below then
            # rejects the None, so the union falls through to the next option.
            if isinstance(value, str) and is_literal_string(value):
                return LiteralString(value)
            # raise ValueError("Not a raw string")

        # Create a union schema that either validates as raw string or normal validation
        return union_schema(
            [
                # First option: validate as raw string
                chain_schema(
                    [
                        no_info_plain_validator_function(validate_raw_string),
                        str_schema(),
                        # target_schema
                    ]
                ),
                # Second option: normal validation (string then target type)
                chain_schema(
                    [
                        str_schema(),
                        # target_schema
                    ]
                ),
                # Third option: direct target type validation
                target_schema,
                # Fourth option: return the value as a string
                str_schema(),
            ]
        )
|
|
362
|
+
|
|
363
|
+
|
|
364
|
+
if TYPE_CHECKING:
    # For static type checkers only: re-declare ParsesTo[T] / TryParseTo[T]
    # as transparent aliases of T, so annotated fields type-check as the
    # target type rather than as the runtime marker classes defined above.
    # This block never executes at runtime (TYPE_CHECKING is False).
    try:
        from typing_extensions import TypeAliasType

        _T_alias = TypeVar("_T_alias")
        ParsesTo = TypeAliasType("ParsesTo", _T_alias, type_params=(_T_alias,))
        TryParseTo = TypeAliasType("TryParseTo", _T_alias, type_params=(_T_alias,))
    except Exception:
        # Best-effort fallback for type checkers that don't support TypeAliasType;
        # the runtime classes above remain the visible definitions.
        pass
|
|
374
|
+
|
|
375
|
+
|
|
376
|
+
class _PostCall(Generic[T]):
    """Post-parse hook applied to each field after it is parsed.

    Instances are invoked by ``Parsable._parse_expressions_final`` as
    ``post_call(field, value, parsed, symbol_table)``. The default
    implementation returns the parsed value unchanged; subclasses override
    ``__call__`` to transform or validate parsed field values.
    """

    def __call__(
        self, field: str, value: T, parsed: T, symbol_table: dict[str, Any]
    ) -> T:
        # Fix: signature now matches the four-argument call site in
        # _parse_expressions_final. The previous (field, value, symbol_table)
        # signature would raise TypeError when the default hook was invoked.
        return parsed
|
|
379
|
+
|
|
380
|
+
|
|
381
|
+
@_uninstantiable
class Parsable(Generic[M]):
    """An abstract base class for parsing. Parsables support the `_parse_expressions`
    method, which is used to parse the object from a string.
    """

    def _parse_expressions(
        self, symbol_table: dict[str, Any] | None = None, **kwargs
    ) -> tuple[M, dict[str, Any]]:
        """Parse all expression-valued fields of this object.

        Returns the parsed object and the (possibly extended) symbol table.
        """
        raise NotImplementedError("Subclasses must implement this method")

    def get_fields(self) -> list[str]:
        """Return the names of the fields this object exposes for parsing."""
        raise NotImplementedError("Subclasses must implement this method")

    def get_validator(self, field: str) -> type:
        """Return the validator (target type) used to parse *field*."""
        raise NotImplementedError("Subclasses must implement this method")

    def _parse_expressions_final(
        self,
        symbol_table: dict[str, Any],
        order: tuple[str, ...],
        post_calls: tuple[_PostCall[T], ...],
        use_setattr: bool = True,
        already_parsed: dict[str, Any] | None = None,
        **kwargs,
    ) -> tuple["Parsable", dict[str, Any]]:
        """Parse every field of this object in place, in dependency order.

        Args:
            symbol_table: Mutable mapping of names available to expressions;
                updated in place as each field is parsed.
            order: Preferred field ordering, forwarded to
                ``_get_parsable_field_order``.
            post_calls: Hooks invoked after each field parse as
                ``post_call(field, value, parsed, symbol_table)``.
            use_setattr: If True, fields are read/written via attribute
                access; otherwise via item access (``self[field]``).
            already_parsed: Fields whose values are final and should be
                installed without re-parsing.
            kwargs: Forwarded to ``_parse_field``.

        Returns:
            ``(self, symbol_table)`` after all fields are parsed.

        Raises:
            ParseError: If a ``global_``-prefixed symbol changed relative to
                the value it had in the outer scope.
        """
        # Mark this object as parsed before processing begins.
        self._parsed = True

        if already_parsed is None:
            already_parsed = {}

        # Only fields that were not pre-parsed need processing.
        fields = [f for f in self.get_fields() if f not in already_parsed]

        # Determine a parse order so fields may reference earlier fields'
        # parsed values through the symbol table.
        field_order = _get_parsable_field_order(
            order,
            [
                (
                    f,
                    getattr(self, f) if use_setattr else self[f],
                    self.get_validator(f),
                )
                for f in fields
            ],
        )
        # Snapshot used below to detect illegal mutation of global_* symbols.
        prev_symbol_table = symbol_table.copy()
        # for k, v in symbol_table.items():
        #     if isinstance(k, str) and k.startswith("global_") and v is None:
        #         raise ParseError(
        #             f"Global variable {k} is required. Please set it in "
        #             f"either the attributes or an outer scope. Try setting it with "
        #             f"Spec.variables.{k} = [value]."
        #         )

        # Install pre-parsed values verbatim (no re-parsing, no post_calls).
        for field, value in already_parsed.items():
            symbol_table[field] = value
            if use_setattr:
                setattr(self, field, value)
            else:
                self[field] = value
            # NOTE(review): duplicates the assignment at the top of this loop
            # body; harmless but redundant.
            symbol_table[field] = value

        for field in field_order:
            value = getattr(self, field) if use_setattr else self[field]
            validator = self.get_validator(field)
            parsed = _parse_field(field, value, validator, symbol_table, self, **kwargs)

            # Hooks may transform the parsed value; note they receive four
            # arguments: (field, value, parsed, symbol_table).
            for post_call in post_calls:
                parsed = post_call(field, value, parsed, symbol_table)
            if use_setattr:
                setattr(self, field, parsed)
            else:
                self[field] = parsed
            # Parsed fields become visible to later fields' expressions.
            symbol_table[field] = parsed

        # Reject any change to a global_* symbol relative to the outer scope.
        for k, v in prev_symbol_table.items():
            if (
                isinstance(k, str)
                and k.startswith("global_")
                and symbol_table.get(k, None) != v
            ):
                raise ParseError(
                    f"Global variable {k} is already set to {v} in the outer scope. "
                    f"It cannot be changed to {symbol_table[k]}."
                )

        return self, symbol_table
|
|
467
|
+
|
|
468
|
+
|
|
469
|
+
class _FromYAMLAble:
    # Mixin: adds a `from_yaml` alternate constructor to any class whose
    # __init__ accepts the merged YAML mapping as keyword arguments.

    @classmethod
    def from_yaml(
        cls: type[T],
        *files: str | list[str] | Path | list[Path],
        jinja_parse_data: dict[str, Any] | None = None,
        top_key: str | None = None,
        **kwargs,
    ) -> T:
        """
        Construct an instance of this class from one or more YAML files.

        Each YAML file must contain a dictionary at the top level. The
        dictionaries are merged in the order the files are given; the first
        occurrence of a top-level key wins, and later duplicates are collected
        separately (see ``extra_elems`` below). File arguments may be strings,
        ``Path`` objects, or (possibly nested one level) lists of either, and
        each string is expanded with ``glob``.

        Keyword arguments are passed through to the class constructor alongside
        the merged dictionary.

        Args:
            files:
                YAML files (or glob patterns) to load.
            jinja_parse_data:
                Jinja2 template data used when rendering each file before YAML
                parsing.
            top_key:
                If given, only the sub-dictionary under this top-level key is
                used to construct the instance.
            kwargs: Extra keyword arguments to be passed to the constructor.

        Returns:
            An instance of ``cls`` built from the combined dictionaries.

        Raises:
            FileNotFoundError: If a file/glob pattern matches nothing.
            TypeError: If a file's top-level YAML value is not a dictionary.
            KeyError: If ``top_key`` is given but absent from the merged data.
            ValueError: If there is no data to construct from.
        """

        # Flatten the mixed str/Path/list arguments into a flat list of strings.
        allfiles = []
        jinja_parse_data = jinja_parse_data or {}
        for f in files:
            if isinstance(f, (list, tuple)):
                if isinstance(f[0], Path):
                    f = list(map(str, f))
                allfiles.extend(f)
            else:
                if isinstance(f, Path):
                    f = str(f)
                allfiles.append(f)
        files = allfiles
        rval = {}
        key2file = {}
        extra_elems = []  # (key, value) pairs whose key was already claimed
        to_parse = []
        # Expand globs and drop duplicate files (samefile catches symlinks and
        # differing relative spellings of the same path).
        for f in files:
            globbed = [x for x in glob.glob(f) if os.path.isfile(x)]
            if not globbed:
                raise FileNotFoundError(f"Could not find file {f}")
            for g in globbed:
                if any(os.path.samefile(g, x) for x in to_parse):
                    logging.info('Ignoring duplicate file "%s" in yaml load', g)
                else:
                    to_parse.append(g)

        for f in to_parse:
            if not (
                f.endswith(".yaml") or f.endswith(".jinja") or f.endswith(".jinja2")
            ):
                # NOTE(review): despite the "Skipping." message there is no
                # `continue` here — the file is still loaded below. Confirm
                # whether the message or the control flow is intended.
                logging.warning(
                    f"File {f} does not end with .yaml, .jinja, or .jinja2. Skipping."
                )
            logging.info("Loading yaml file %s", f)
            loaded = _yaml.load_yaml(f, data=jinja_parse_data)
            if not isinstance(loaded, dict):
                raise TypeError(
                    f"Expected a dictionary from file {f}, got {type(loaded)}"
                )
            # First file to define a top-level key wins; later duplicates are
            # remembered in extra_elems (only logged further down).
            for k, v in loaded.items():
                if k in rval:
                    logging.info("Found extra top-key %s in %s", k, f)
                    extra_elems.append((k, v))
                else:
                    logging.info("Found top key %s in %s", k, f)
                    key2file[k] = f
                    rval[k] = v

        if top_key is not None:
            if top_key not in rval:
                raise KeyError(f"Top key {top_key} not found in {files}")
            rval = rval[top_key]

        # First attempt: construct directly from the merged mapping.
        # NOTE(review): failures are silently swallowed here; if the fallbacks
        # below also fail, the final `cls(**rval, **kwargs)` re-raises.
        c = None
        try:
            c = cls(**rval, **kwargs)
        except Exception as e:
            pass

        # rval can be None when top_key selected a key mapped to null.
        if c is None and rval is None:
            if top_key is not None:
                raise ValueError(
                    f"No data to parse from {files} with top key {top_key}. Is there "
                    f"content under the top key {top_key}?"
                )
            raise ValueError(
                f"No data to parse from {files}. Is there content in the file(s)?"
            )

        # Second attempt: a single-key mapping may be a wrapper around the real
        # content — try constructing from its sole value.
        if c is None and len(rval) == 1:
            logging.warning(
                f"Trying to parse a single element dictionary as a {cls.__name__}. "
            )
            try:
                rval_first = list(rval.values())[0]
                if not isinstance(rval_first, dict):
                    raise TypeError(
                        f"Expected a dictionary as the top-level element in {files}, "
                        f"got {type(rval_first)}."
                    )
                c = cls(**rval_first, **kwargs)
            except Exception as e:
                logging.warning(
                    f"Error parsing {files} with top key {top_key}. " f"Error: {e}"
                )
        # Final attempt: repeat the direct construction so the original error
        # propagates to the caller instead of being hidden.
        if c is None:
            c = cls(**rval, **kwargs)

        if extra_elems:
            logging.info(
                "Parsing extra attributes %s", ", ".join([x[0] for x in extra_elems])
            )
        # Record provenance of the constructed object for debugging.
        c._yaml_source = ",".join(files)
        return c
|
|
593
|
+
|
|
594
|
+
|
|
595
|
+
def _parse_field(
    field,
    value,
    validator,
    symbol_table,
    parent,
    must_parse_try_parse_to: bool = False,
    must_copy: bool = True,
    **kwargs,
):
    """Parse one field's value according to its validator annotation.

    Behavior by validator:

    * ``ParsesTo[X]`` / ``TryParseTo[X]``: the value is evaluated as an
      expression (or a set expression for ``InvertibleSet`` targets) against
      ``symbol_table`` and type-checked against ``X``. ``TryParseTo`` falls
      back to returning the raw value as a ``LiteralString`` on ``ParseError``
      unless ``must_parse_try_parse_to`` is set.
    * Anything else: the value is passed through unchanged.

    Afterwards, ``Parsable`` results are recursively parsed (unless the
    validator origin is ``NoParse``) and plain strings are wrapped in
    ``LiteralString``. Any ``ParseError`` is annotated with the field name
    before re-raising.

    Args:
        field: Name (or index) of the field, used for error locations.
        value: The raw field value.
        validator: The type annotation guiding parsing.
        symbol_table: Symbols available to expression evaluation.
        parent: The container owning the field; used to look up a nicer name
            for error messages.
        must_parse_try_parse_to: If True, ``TryParseTo`` failures raise instead
            of falling back to ``LiteralString``.
        must_copy: If True, ensure the returned object is not the same object
            as ``value`` (deep-copies pass-through results).
        kwargs: Forwarded to nested ``_parse_expressions`` calls.
    """
    # Imported locally to avoid a circular import at module load time.
    from accelforge.util._setexpressions import InvertibleSet, eval_set_expression

    def check_subclass(x, cls):
        # issubclass raises on non-types, so guard with isinstance(x, type).
        return isinstance(x, type) and issubclass(x, cls)

    try:
        # Get the origin type (ParsesTo or TryParseTo) and its arguments
        origin = get_origin(validator)
        if origin is ParsesTo or origin is TryParseTo:
            try:
                target_type = get_args(validator)[0]
                parsed = value
                # InvertibleSet inside a union of targets is unsupported.
                if isinstance(target_type, tuple) and any(
                    check_subclass(t, InvertibleSet) for t in target_type
                ):
                    raise NotImplementedError(
                        f"InvertibleSet must be used directly, not as a part of a "
                        f"union, else this function must be updated."
                    )

                # Check if validator is for InvertibleSet
                if check_subclass(target_type, InvertibleSet):
                    # Get the target type from the validator

                    # If the given type is a set, replace it with a string that'll parse
                    if isinstance(value, set):
                        value = " | ".join(str(v) for v in value)

                    type_args = target_type.__pydantic_generic_metadata__["args"]
                    assert len(type_args) == 1, "Expected exactly one type argument"
                    expected_element_type = type_args[0]

                    try:
                        # eval_set_expression does the type checking for us
                        return eval_set_expression(
                            value,
                            symbol_table,
                            expected_space=expected_element_type,
                            location=field,
                        )
                    except ParseError as e:
                        # TryParseTo: fall back to the raw value as a literal.
                        if origin is TryParseTo and not must_parse_try_parse_to:
                            return LiteralString(value)
                        raise
                elif is_literal_string(value):
                    # Explicitly-literal strings bypass expression evaluation.
                    parsed = LiteralString(value)
                else:
                    parsed = parse_expression(value, symbol_table)

                # Guarantee the caller gets a distinct object when requested.
                if must_copy and id(parsed) == id(value):
                    parsed = copy.deepcopy(parsed)

                # Get the target type from the validator
                target_any = (
                    target_type is Any
                    or isinstance(target_type, tuple)
                    and Any in target_type
                )
                if not target_any and not isinstance(parsed, target_type):
                    raise ParseError(
                        f'{value} parsed to "{parsed}" with type {type(parsed).__name__}.'
                        f" Expected {target_type}.",
                    )
            except ParseError as e:
                # Same TryParseTo fallback for failures anywhere above.
                if origin is TryParseTo and not must_parse_try_parse_to:
                    return LiteralString(value)
                raise
        else:
            parsed = value

        # Recurse into Parsable results unless parsing is explicitly disabled.
        if isinstance(parsed, Parsable) and origin is not NoParse:
            parsed, _ = parsed._parse_expressions(
                symbol_table=symbol_table,
                must_copy=must_copy,
                must_parse_try_parse_to=must_parse_try_parse_to,
                **kwargs,
            )
            return parsed
        elif isinstance(parsed, str):
            # Strings that survive to here are treated as literals.
            return LiteralString(parsed)
        else:
            return parsed
    except ParseError as e:
        # Best-effort: prefer the parent's display name for the field; fall
        # back to the raw field key if the lookup fails for any reason.
        try:
            e.add_field(parent[field].name)
        except:
            e.add_field(field)
        raise e
|
|
693
|
+
|
|
694
|
+
|
|
695
|
+
# python_name_regex = re.compile(r'\b[a-zA-Z_][a-zA-Z0-9_]*\b')
|
|
696
|
+
|
|
697
|
+
|
|
698
|
+
def _get_parsable_field_order(
    order: tuple[str, ...], field_value_validator_triples: list[tuple[str, Any, type]]
) -> list[str]:
    """Compute the order in which fields should be parsed.

    Fields already listed in ``order`` keep their position. Fields that are
    neither ``ParsesTo`` targets nor ``Parsable`` values need no dependency
    analysis and are appended in declaration order. The remaining fields are
    topologically sorted: a string-valued field that mentions another field's
    name (as a whole word) must be parsed after it. Among ready fields,
    non-``Parsable`` values are scheduled before ``Parsable`` ones.

    Args:
        order: Field names whose position is fixed up front.
        field_value_validator_triples: ``(field, value, validator)`` for every
            field under consideration.

    Returns:
        All field names, in the order they should be parsed.

    Raises:
        ParseError: If the textual dependencies form a cycle.
    """

    def is_parsable(value, validator):
        # Only the value is inspected today; the validator is accepted for
        # symmetry with call sites and possible future use.
        return isinstance(value, Parsable)

    order = list(order)
    to_sort = []

    for field, value, validator in field_value_validator_triples:
        if field in order:
            continue
        if get_origin(validator) is not ParsesTo and not is_parsable(value, validator):
            order.append(field)
            continue
        to_sort.append((field, value))

    # BUG FIX: this previously read `{f: v for f, v, _ in ...}`, mapping each
    # field to its VALUE rather than its validator. It was harmless only
    # because is_parsable ignores its second argument; the mapping now matches
    # its name.
    field2validator = {f: validator for f, _, validator in field_value_validator_triples}

    # other_field depends on field when other_field's (to-be-parsed) string
    # value mentions field as a whole word.
    dependencies = {field: set() for field, _ in to_sort}
    for other_field, other_value in to_sort:
        # Can't have any dependencies if you're not going to be parsed
        if not isinstance(other_value, str) or is_literal_string(other_value):
            continue
        for field, value in to_sort:
            if field != other_field:
                if re.findall(r"\b" + re.escape(field) + r"\b", other_value):
                    dependencies[other_field].add(field)

    # Kahn-style topological sort, emitting one field per iteration.
    while to_sort:
        can_add = [
            (f, v) for f, v in to_sort if all(dep in order for dep in dependencies[f])
        ]
        if not can_add:
            raise ParseError(
                f"Circular dependency detected in expressions. "
                f"Fields: {', '.join(t[0] for t in to_sort)}"
            )
        # Parsables last
        for f, v in can_add:
            if not is_parsable(v, field2validator[f]):
                order.append(f)
                to_sort.remove((f, v))
                break
        else:
            order.append(can_add[0][0])
            to_sort.remove(can_add[0])
    return order
|
|
749
|
+
|
|
750
|
+
|
|
751
|
+
class _OurBaseModel(BaseModel, _FromYAMLAble, Mapping):
    """Shared base for this module's models: a pydantic ``BaseModel`` that can
    be loaded from YAML and behaves as a read/write ``Mapping`` over its
    fields."""

    # Exclude is supported OK, but makes the docs a lot longer because it's in so many
    # objects and has a very long type.
    def to_yaml(
        self, f: str | None = None
    ) -> str:  # , exclude: IncEx | None = None) -> str:
        """
        Dump the model to a YAML string.

        Parameters
        ----------
        f: str | None
            The file to write the YAML to. If not given, then returns as a string.

        Returns
        -------
        str
            The YAML string (returned whether or not a file was written).
        """
        dump = self.model_dump()  # exclude=exclude)

        def _to_str(value: Any):
            # Recursively normalize str subclasses (e.g. LiteralString) to
            # plain str so the YAML emitter writes ordinary strings.
            if isinstance(value, list):
                return [_to_str(x) for x in value]
            elif isinstance(value, dict):
                return {_to_str(k): _to_str(v) for k, v in value.items()}
            elif isinstance(value, str):
                return str(value)
            return value

        if f is not None:
            _yaml.write_yaml_file(f, _to_str(dump))
        return _yaml.to_yaml_string(_to_str(dump))

    def all_fields_default(self) -> bool:
        """Return True iff every declared field still holds its default."""
        for field in self.__class__.model_fields:
            default = self.__class__.model_fields[field].default
            if getattr(self, field) != default:
                return False
        return True

    def model_dump_non_none(self, **kwargs) -> dict:
        """Like ``model_dump`` but with None-valued entries removed."""
        return {k: v for k, v in self.model_dump(**kwargs).items() if v is not None}

    def shallow_model_dump(self, include_None: bool = False, **kwargs) -> dict:
        """Return ``{field: value}`` without recursing into submodels.

        Includes pydantic extra fields; drops None values unless
        ``include_None`` is True.
        """
        keys = self.get_fields()
        if getattr(self, "__pydantic_extra__", None) is not None:
            keys.extend([k for k in self.__pydantic_extra__.keys() if k not in keys])

        if not include_None:
            keys = [k for k in keys if getattr(self, k) is not None]

        return {k: getattr(self, k) for k in keys}

    def __contains__(self, key: str) -> bool:
        # Delegates to __getitem__ so attribute-backed keys are honored.
        try:
            self[key]
            return True
        except KeyError:
            return False

    def __getitem__(self, key: str) -> Any:
        # Mapping access is attribute access; missing attributes surface as
        # KeyError to satisfy the Mapping contract.
        try:
            return getattr(self, key)
        except AttributeError:
            pass
        raise KeyError(f"Key {key} not found in {self.__class__.__name__}")

    def __setitem__(self, key: str, value: Any):
        setattr(self, key, value)

    def __delitem__(self, key: str):
        delattr(self, key)

    def __iter__(self) -> Iterator[str]:
        # Iterates field NAMES (Mapping semantics), not values.
        return iter(self.get_fields())

    def __len__(self) -> int:
        return len(self.get_fields())
|
|
832
|
+
|
|
833
|
+
|
|
834
|
+
@_uninstantiable
class ParsableModel(_OurBaseModel, Parsable["ParsableModel"]):
    """A model that will parse any fields that are given to it. When parsing, submodels
    will also be parsed if they support it. Parsing will parse any fields that are given
    as strings and do not match the expected type.
    """

    model_config = ConfigDict(extra="forbid")
    # type: Optional[str] = None

    def __init__(self, **kwargs):
        """Construct the model, with two extensions over pydantic's default:

        * A ``type`` keyword is popped and, if given, checked after
          construction via ``isinstance(self, type)``.
        * When extra fields are forbidden, unknown keywords raise a
          ``ValueError`` listing the supported field names (friendlier than
          pydantic's own error).
        """
        required_type = kwargs.pop("type", None)

        if self.model_config["extra"] == "forbid":
            supported_fields = set(self.__class__.model_fields.keys())
            for k in kwargs.keys():
                if k not in supported_fields:
                    raise ValueError(
                        f"Field {k} is not supported for {self.__class__.__name__}. "
                        f"Supported fields are:\n\t"
                        + "\n\t".join(sorted(supported_fields))
                        + "\n",
                    )

        super().__init__(**kwargs)
        if required_type is not None:
            try:
                passed_check = isinstance(self, required_type)
            except TypeError:
                # isinstance raises TypeError when required_type is not a type.
                raise TypeError(
                    f"Error checking required type. Was given type argument "
                    f"{required_type} a valid type?"
                ) from None

            if not passed_check:
                # NOTE(review): message is missing a space between "match" and
                # the class name ("does not match<ClassName>").
                raise TypeError(
                    f"type field {required_type} does not match"
                    f"{self.__class__.__name__}"
                )

    def get_validator(self, field: str) -> Type:
        """Return the annotation used to parse ``field``; extra (undeclared)
        fields parse as ``ParsesTo[Any]``."""
        if field in self.__class__.model_fields:
            return self.__class__.model_fields[field].annotation
        return ParsesTo[Any]

    def get_fields(self) -> list[str]:
        """Return declared plus extra field names, sorted."""
        fields = set(self.__class__.model_fields.keys())
        if getattr(self, "__pydantic_extra__", None) is not None:
            fields.update(self.__pydantic_extra__.keys())
        return sorted(fields)

    def _parse_expressions(
        self,
        symbol_table: dict[str, Any] = None,
        order: tuple[str, ...] = (),
        post_calls: tuple[_PostCall[T], ...] = (),
        already_parsed: dict[str, Any] | None = None,
        **kwargs,
    ) -> tuple[Self, dict[str, Any]]:
        """Parse all field expressions on a copy of this model.

        Works on ``model_copy()`` and a copied symbol table so the original
        model and the caller's table are untouched. Delegates to
        ``_parse_expressions_final`` with ``use_setattr=True`` (fields are set
        as attributes).
        """
        new = self.model_copy()
        symbol_table = symbol_table.copy() if symbol_table is not None else {}
        kwargs = dict(kwargs)
        return new._parse_expressions_final(
            symbol_table,
            order,
            post_calls,
            use_setattr=True,
            already_parsed=already_parsed,
            **kwargs,
        )
|
|
904
|
+
|
|
905
|
+
|
|
906
|
+
class NonParsableModel(_OurBaseModel):
    """A model that will not parse any fields."""

    model_config = ConfigDict(extra="forbid")
    # Optional discriminator; accepted but not interpreted here.
    type: Optional[str] = None

    def get_validator(self, field: str) -> Type:
        # Every field validates as Any, so nothing triggers expression parsing.
        return Any
|
|
914
|
+
|
|
915
|
+
|
|
916
|
+
class ParsableList(list[T], Parsable["ParsableList[T]"], Generic[T]):
    """
    A list that can be parsed from a string. ParsableList[T] means that a given string
    can be parsed, yielding a list of objects of type T.

    Also supports name-based lookup: ``lst["x"]`` returns the unique element
    whose ``name`` attribute (or ``"name"`` dict key) equals ``"x"``.
    """

    def get_validator(self, field: str) -> Type:
        # Every element shares the list's type parameter as its validator.
        return T

    def _parse_expressions(
        self,
        symbol_table: dict[str, Any] = None,
        order: tuple[str, ...] = (),
        post_calls: tuple[_PostCall[T], ...] = (),
        already_parsed: dict[str, Any] | None = None,
        **kwargs,
    ) -> tuple["ParsableList[T]", dict[str, Any]]:
        """Parse every element on a copy of this list (indices are the
        "fields"); delegates to ``_parse_expressions_final`` with
        ``use_setattr=False`` so elements are assigned by index."""
        new = ParsableList[T](x for x in self)
        symbol_table = symbol_table.copy() if symbol_table is not None else {}
        # Append any indices not already pinned by the caller-supplied order.
        order = order + tuple(x for x in range(len(new)) if x not in order)
        return new._parse_expressions_final(
            symbol_table,
            order,
            post_calls,
            use_setattr=False,
            already_parsed=already_parsed,
            **kwargs,
        )

    def get_fields(self) -> list[str]:
        # Fields of a list are its indices.
        return sorted(range(len(self)))

    @classmethod
    def __get_pydantic_core_schema__(
        cls, source_type: Any, handler: Callable
    ) -> CoreSchema:
        """Pydantic v2 hook: validate as ``list[T]`` then wrap in this class."""
        # Get the type parameter T from ParsableList[T]
        type_args = get_args(source_type)
        if not type_args:
            raise TypeError(
                f"ParsableList must be used with a type parameter, e.g. ParsableList[int]"
            )
        item_type = type_args[0]

        # Get the schema for the item type
        item_schema = handler(item_type)

        # Create a schema that validates lists of the item type
        return chain_schema(
            [
                list_schema(item_schema),
                no_info_plain_validator_function(lambda x: cls(x)),
            ]
        )

    def __getitem__(self, key: str | int | slice) -> T:
        """Index, slice, or look up an element by its ``name``.

        Raises:
            ValueError: If multiple elements share the requested name.
            KeyError: If no element has the requested name.
        """
        if isinstance(key, int):
            return super().__getitem__(key)  # type: ignore

        elif isinstance(key, slice):
            # Slicing preserves the ParsableList type.
            return ParsableList[T](super().__getitem__(key))

        elif isinstance(key, str):
            found = None
            for elem in self:
                name = None
                if isinstance(elem, dict):
                    name = elem.get("name", None)
                elif hasattr(elem, "name"):
                    name = elem.name
                if name is not None and name == key:
                    if found is not None:
                        raise ValueError(f'Multiple elements with name "{key}" found.')
                    found = elem
            if found is not None:
                return found

            # Build a helpful error listing indices and available names.
            fields = self.get_fields()
            fields += [
                (
                    x.name
                    if hasattr(x, "name")
                    else x.get("name", None) if isinstance(x, dict) else None
                )
                for x in self
            ]
            fields = sorted(str(x) for x in fields if x is not None)
            # BUG FIX: this f-string previously nested single quotes inside a
            # single-quoted f-string ({', '.join(...)}), which is a syntax
            # error before Python 3.12 (PEP 701). Inner quotes changed to
            # double quotes; the runtime string is identical.
            raise KeyError(
                f'No element with name "{key}" found. Available names: {", ".join(fields)}'
            )

    def __contains__(self, item: Any) -> bool:
        # True for contained values OR for a name that __getitem__ resolves.
        try:
            self[item]
            return True
        except KeyError:
            return super().__contains__(item)

    def __copy__(self) -> Self:
        return type(self)(x for x in self)
|
|
1016
|
+
|
|
1017
|
+
|
|
1018
|
+
class ParsableDict(
    dict[K, V], Parsable["ParsableDict[K, V]"], Generic[K, V], _FromYAMLAble
):
    """A dictionary that can be parsed from a string. ParsableDict[K, V] means that a
    given string can be parsed, yielding a dictionary with keys of type K and values of
    type V.
    """

    def get_validator(self, field: str) -> type:
        # Every value shares the dict's value type parameter as its validator.
        return V

    def get_fields(self) -> list[str]:
        # Fields of a dict are its keys, sorted for deterministic ordering.
        return sorted(self.keys())

    def _parse_expressions(
        self,
        symbol_table: dict[str, Any] = None,
        order: tuple[str, ...] = (),
        post_calls: tuple[_PostCall[V], ...] = (),
        already_parsed: dict[str, Any] | None = None,
        **kwargs,
    ) -> tuple["ParsableDict[K, V]", dict[str, Any]]:
        """Parse every value on a copy of this dict; delegates to
        ``_parse_expressions_final`` with ``use_setattr=False`` so values are
        assigned by key."""
        new = ParsableDict[K, V](self)
        symbol_table = symbol_table.copy() if symbol_table is not None else {}
        return new._parse_expressions_final(
            symbol_table,
            order,
            post_calls,
            use_setattr=False,
            already_parsed=already_parsed,
            **kwargs,
        )

    @classmethod
    def __get_pydantic_core_schema__(
        cls, source_type: Any, handler: Callable
    ) -> CoreSchema:
        """Pydantic v2 hook: validate as ``dict[K, V]`` then wrap in this class."""
        # Get the type parameters K and V from ParsableDict[K, V]
        type_args = get_args(source_type)
        if len(type_args) != 2:
            raise TypeError(
                f"ParsableDict must be used with two type parameters, e.g. ParsableDict[str, int]"
            )
        key_type, value_type = type_args

        # Get the schemas for the key and value types
        # NOTE(review): uses handler.generate_schema(...) where ParsableList
        # calls handler(...) directly — both work; consider unifying.
        key_schema = handler.generate_schema(key_type)
        value_schema = handler.generate_schema(value_type)

        # Create a schema that validates dictionaries with the specified key and value types
        return chain_schema(
            [
                dict_schema(key_schema, value_schema),
                no_info_plain_validator_function(lambda x: cls(x)),
            ]
        )

    def __copy__(self) -> Self:
        return type(self)({k: v for k, v in self.items()})
|
|
1077
|
+
|
|
1078
|
+
|
|
1079
|
+
class ParseExtras(ParsableModel):
    """
    A model that will parse any extra fields that are given to it.
    """

    # Unlike ParsableModel (extra="forbid"), undeclared fields are accepted.
    model_config = ConfigDict(extra="allow")

    def get_validator(self, field: str) -> type:
        """Return the declared annotation for ``field``, or ``ParsesTo[Any]``
        when the field is an undeclared extra."""
        declared = self.__class__.model_fields
        if field in declared:
            return declared[field].annotation
        return ParsesTo[Any]
|