ppdmod 2.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
ppdmod/parameter.py ADDED
@@ -0,0 +1,152 @@
1
+ from dataclasses import dataclass
2
+ from typing import Any, List
3
+
4
+ import astropy.units as u
5
+ import numpy as np
6
+ from numpy.typing import ArrayLike, NDArray
7
+
8
+ from .options import STANDARD_PARAMS
9
+ from .utils import smooth_interpolation
10
+
11
+
12
@dataclass()
class Parameter:
    """Defines a parameter.

    Any ``astropy.units.Quantity`` assigned to an attribute other than
    ``unit`` is stored as its bare ``.value`` (see ``__setattr__``); the unit
    lives in ``unit`` and is re-attached by ``__call__``.

    Attributes
    ----------
    name : str, optional
        The parameter's name.
    description : str, optional
        A free-text description of the parameter.
    value : Any, optional
        The parameter's value(s). Lists and tuples are converted to numpy
        arrays on initialisation.
    grid : numpy.ndarray, optional
        The points at which ``value`` is sampled; ``__call__`` interpolates
        over it. Lists and tuples are converted to numpy arrays.
    unit : astropy.units.Quantity, optional
        The unit attached to the value returned by ``__call__``.
    min, max : float, optional
        The parameter's limits, returned by ``get_limits``.
    dtype : type, optional
        Passed as ``dtype`` to ``astropy.units.Quantity`` in ``__call__``.
    smooth : bool, optional
        If truthy, ``__call__`` interpolates with ``smooth_interpolation``
        instead of ``numpy.interp``.
    reflective, periodic, free, shared : bool, optional
        Boolean flags. ``free`` and ``shared`` are reported by ``__str__``.
    base : str, optional
        Name of an entry of ``STANDARD_PARAMS`` used as a template for every
        attribute that is left as ``None``.
    uniform : float, optional
        Stored as given; not used within this class.
    """

    name: str | None = None
    description: str | None = None
    value: Any | None = None
    grid: np.ndarray | None = None
    unit: u.Quantity | None = None
    min: float | None = None
    max: float | None = None
    dtype: type | None = None
    smooth: bool | None = None
    reflective: bool | None = None
    periodic: bool | None = None
    free: bool | None = None
    shared: bool | None = None
    base: str | None = None
    uniform: float | None = None

    def _process_base(self, base: str | None) -> None:
        """Fill unset attributes from the ``STANDARD_PARAMS`` template ``base``.

        Every key of the template whose attribute is still ``None`` is copied
        over. Afterwards, each boolean flag that the template does not define
        and that is still ``None`` is set to ``False``. Does nothing if
        ``base`` is ``None``.
        """
        if base is None:
            return

        base_param = getattr(STANDARD_PARAMS, base)
        for key, value in base_param.items():
            if getattr(self, key) is None:
                setattr(self, key, value)

        for key in ["free", "shared", "smooth", "reflective", "periodic"]:
            if key not in base_param and getattr(self, key) is None:
                setattr(self, key, False)

    def _set_to_numpy_array(self, array: ArrayLike | None = None) -> Any | np.ndarray:
        """Convert a list or tuple to a numpy array.

        ``None`` is returned as ``None``; any other input is returned
        unchanged (no copy is made).
        """
        if array is None:
            return None

        if isinstance(array, (tuple, list)):
            return np.array(array)

        return array

    def __setattr__(self, key: str, value: Any) -> None:
        """Set an attribute, stripping units from Quantities except for ``unit``."""
        if key != "unit" and isinstance(value, u.Quantity):
            value = value.value
        super().__setattr__(key, value)

    def __str__(self) -> str:
        """Return a one-line human-readable summary of the parameter."""
        # np.round(None) raises, so only round an actual value.
        value = None if self.value is None else np.round(self.value, 2)
        message = (
            f"Parameter: {self.name} has the value {value}, "
            f"is {'free' if self.free else 'fixed'} and "
            f"is {'shared' if self.shared else 'non-shared'}"
        )
        if self.max is not None:
            # ``min`` may be unset even when ``max`` is; formatting None with
            # ``:.1f`` would raise a TypeError.
            lower = "None" if self.min is None else f"{self.min:.1f}"
            message += f" with its limits being {lower}-{self.max:.1f}"

        return message

    def __post_init__(self) -> None:
        """Post initialisation actions."""
        # Apply the template first so that values it provides as lists or
        # tuples are converted to numpy arrays as well. _process_base only
        # tests for None, which the conversion never changes.
        self._process_base(self.base)
        self.value = self._set_to_numpy_array(self.value)
        self.grid = self._set_to_numpy_array(self.grid)

    def __call__(
        self,
        t: NDArray[Any] | None = None,
        wl: NDArray[Any] | None = None,
    ) -> u.Quantity | None:
        """Gets the value for the parameter or the corresponding
        values for some points.

        Parameters
        ----------
        t : array_like, optional
            Accepted for interface compatibility; not used here.
        wl : astropy.units.Quantity or array_like, optional
            Points at which to interpolate ``value`` over ``grid``. If it is
            a Quantity, only its numerical ``.value`` is used (no unit
            conversion is performed).

        Returns
        -------
        astropy.units.Quantity or None
            ``None`` if ``value`` is unset; otherwise ``value`` (or its
            interpolation at ``wl`` when both ``wl`` and ``grid`` are given)
            as a Quantity with ``unit`` and ``dtype``.
        """
        if self.value is None:
            return None

        if wl is None or self.grid is None:
            value = self.value
        else:
            # Accept both Quantities and plain arrays as evaluation points.
            points = wl.value if isinstance(wl, u.Quantity) else wl
            interpolate = smooth_interpolation if self.smooth else np.interp
            value = interpolate(points, self.grid, self.value)

        return u.Quantity(value, unit=self.unit, dtype=self.dtype)

    def copy(self) -> "Parameter":
        """Return a new ``Parameter`` with the same attribute values.

        Array attributes (``value``, ``grid``) are shared with the original,
        not duplicated.
        """
        return Parameter(
            name=self.name,
            description=self.description,
            value=self.value,
            grid=self.grid,
            unit=self.unit,
            min=self.min,
            max=self.max,
            dtype=self.dtype,
            smooth=self.smooth,
            reflective=self.reflective,
            periodic=self.periodic,
            free=self.free,
            shared=self.shared,
            base=self.base,
            uniform=self.uniform,
        )

    def get_limits(self) -> tuple[float | None, float | None]:
        """Return the parameter's limits as a ``(min, max)`` tuple."""
        return self.min, self.max
125
+
126
+
127
@dataclass()
class MultiParam:
    """A collection of ``Parameter`` objects addressed by an index ``t``.

    On initialisation, every parameter whose name does not contain ``".t"``
    gets ``".t{index}"`` appended to its name. This mutates the ``Parameter``
    objects that were passed in.
    """

    params: List[Parameter] | None = None

    def __post_init__(self) -> None:
        """Post initialisation actions."""
        # ``params`` defaults to None; treat that as an empty collection
        # instead of failing on ``len(None)``.
        if self.params is None:
            self.params = []

        self.indices = np.arange(len(self.params))
        # NOTE(review): assumes every parameter has a string ``name``.
        for index, param in zip(self.indices, self.params):
            if ".t" not in param.name:
                param.name += f".t{index}"

    def __getitem__(self, t):
        """Return the parameter stored at index ``t``."""
        return self.params[t]

    def __call__(
        self,
        t: NDArray[Any] | None = None,
        wl: NDArray[Any] | None = None,
    ) -> np.ndarray:
        """Gets the value for the parameter or the corresponding
        values for some points.

        ``t`` must be a valid index into ``params``. The selected parameter
        is evaluated at ``wl``.
        """
        # Pass ``wl`` by keyword: Parameter.__call__'s first positional
        # parameter is ``t``, so a positional ``wl`` would be silently ignored.
        return self.params[t](wl=wl)

    def copy(self) -> "MultiParam":
        """Return a new ``MultiParam`` holding copies of each parameter."""
        return MultiParam([param.copy() for param in self.params])