mt-metadata 0.3.9__py2.py3-none-any.whl → 0.4.0__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mt-metadata might be problematic. Click here for more details.
- mt_metadata/__init__.py +1 -1
- mt_metadata/base/helpers.py +84 -9
- mt_metadata/base/metadata.py +137 -65
- mt_metadata/features/__init__.py +14 -0
- mt_metadata/features/coherence.py +303 -0
- mt_metadata/features/cross_powers.py +29 -0
- mt_metadata/features/fc_coherence.py +81 -0
- mt_metadata/features/feature.py +72 -0
- mt_metadata/features/feature_decimation_channel.py +26 -0
- mt_metadata/features/feature_fc.py +24 -0
- mt_metadata/{transfer_functions/processing/aurora/decimation.py → features/feature_fc_run.py} +9 -4
- mt_metadata/features/feature_ts.py +24 -0
- mt_metadata/{transfer_functions/processing/aurora/window.py → features/feature_ts_run.py} +11 -18
- mt_metadata/features/standards/__init__.py +6 -0
- mt_metadata/features/standards/base_feature.json +46 -0
- mt_metadata/features/standards/coherence.json +57 -0
- mt_metadata/features/standards/fc_coherence.json +57 -0
- mt_metadata/features/standards/feature_decimation_channel.json +68 -0
- mt_metadata/features/standards/feature_fc_run.json +35 -0
- mt_metadata/features/standards/feature_ts_run.json +35 -0
- mt_metadata/features/standards/feature_weighting_window.json +46 -0
- mt_metadata/features/standards/weight_kernel.json +46 -0
- mt_metadata/features/standards/weights.json +101 -0
- mt_metadata/features/test_helpers/channel_weight_specs_example.json +156 -0
- mt_metadata/features/weights/__init__.py +0 -0
- mt_metadata/features/weights/base.py +44 -0
- mt_metadata/features/weights/channel_weight_spec.py +209 -0
- mt_metadata/features/weights/feature_weight_spec.py +194 -0
- mt_metadata/features/weights/monotonic_weight_kernel.py +275 -0
- mt_metadata/features/weights/standards/__init__.py +6 -0
- mt_metadata/features/weights/standards/activation_monotonic_weight_kernel.json +38 -0
- mt_metadata/features/weights/standards/base.json +36 -0
- mt_metadata/features/weights/standards/channel_weight_spec.json +35 -0
- mt_metadata/features/weights/standards/composite.json +36 -0
- mt_metadata/features/weights/standards/feature_weight_spec.json +13 -0
- mt_metadata/features/weights/standards/monotonic_weight_kernel.json +49 -0
- mt_metadata/features/weights/standards/taper_monotonic_weight_kernel.json +16 -0
- mt_metadata/features/weights/taper_weight_kernel.py +60 -0
- mt_metadata/helper_functions.py +69 -0
- mt_metadata/timeseries/filters/channel_response.py +77 -37
- mt_metadata/timeseries/filters/coefficient_filter.py +6 -5
- mt_metadata/timeseries/filters/filter_base.py +11 -15
- mt_metadata/timeseries/filters/fir_filter.py +8 -1
- mt_metadata/timeseries/filters/frequency_response_table_filter.py +26 -11
- mt_metadata/timeseries/filters/helper_functions.py +0 -2
- mt_metadata/timeseries/filters/obspy_stages.py +4 -1
- mt_metadata/timeseries/filters/pole_zero_filter.py +9 -5
- mt_metadata/timeseries/filters/time_delay_filter.py +8 -1
- mt_metadata/timeseries/location.py +20 -5
- mt_metadata/timeseries/person.py +14 -7
- mt_metadata/timeseries/standards/person.json +1 -1
- mt_metadata/timeseries/standards/run.json +2 -2
- mt_metadata/timeseries/station.py +4 -2
- mt_metadata/timeseries/stationxml/__init__.py +5 -0
- mt_metadata/timeseries/stationxml/xml_channel_mt_channel.py +25 -27
- mt_metadata/timeseries/stationxml/xml_inventory_mt_experiment.py +16 -47
- mt_metadata/timeseries/stationxml/xml_station_mt_station.py +25 -24
- mt_metadata/transfer_functions/__init__.py +3 -0
- mt_metadata/transfer_functions/core.py +8 -11
- mt_metadata/transfer_functions/io/emtfxml/metadata/location.py +5 -0
- mt_metadata/transfer_functions/io/emtfxml/metadata/provenance.py +14 -3
- mt_metadata/transfer_functions/io/tools.py +2 -0
- mt_metadata/transfer_functions/io/zonge/metadata/header.py +1 -1
- mt_metadata/transfer_functions/io/zonge/metadata/standards/header.json +1 -1
- mt_metadata/transfer_functions/io/zonge/metadata/standards/job.json +2 -2
- mt_metadata/transfer_functions/io/zonge/zonge.py +19 -23
- mt_metadata/transfer_functions/processing/__init__.py +2 -1
- mt_metadata/transfer_functions/processing/aurora/__init__.py +2 -4
- mt_metadata/transfer_functions/processing/aurora/band.py +46 -125
- mt_metadata/transfer_functions/processing/aurora/channel_nomenclature.py +27 -20
- mt_metadata/transfer_functions/processing/aurora/decimation_level.py +324 -152
- mt_metadata/transfer_functions/processing/aurora/frequency_bands.py +230 -0
- mt_metadata/transfer_functions/processing/aurora/processing.py +3 -3
- mt_metadata/transfer_functions/processing/aurora/run.py +32 -7
- mt_metadata/transfer_functions/processing/aurora/standards/decimation_level.json +7 -73
- mt_metadata/transfer_functions/processing/aurora/stations.py +33 -4
- mt_metadata/transfer_functions/processing/fourier_coefficients/decimation.py +176 -178
- mt_metadata/transfer_functions/processing/fourier_coefficients/fc.py +11 -9
- mt_metadata/transfer_functions/processing/fourier_coefficients/standards/decimation.json +1 -111
- mt_metadata/transfer_functions/processing/short_time_fourier_transform.py +64 -0
- mt_metadata/transfer_functions/processing/standards/__init__.py +6 -0
- mt_metadata/transfer_functions/processing/standards/short_time_fourier_transform.json +94 -0
- mt_metadata/transfer_functions/processing/{aurora/standards/decimation.json → standards/time_series_decimation.json} +17 -6
- mt_metadata/transfer_functions/processing/{aurora/standards → standards}/window.json +13 -2
- mt_metadata/transfer_functions/processing/time_series_decimation.py +50 -0
- mt_metadata/transfer_functions/processing/window.py +118 -0
- mt_metadata/transfer_functions/tf/station.py +17 -1
- mt_metadata/utils/mttime.py +22 -3
- mt_metadata/utils/validators.py +4 -2
- {mt_metadata-0.3.9.dist-info → mt_metadata-0.4.0.dist-info}/METADATA +39 -15
- {mt_metadata-0.3.9.dist-info → mt_metadata-0.4.0.dist-info}/RECORD +95 -55
- {mt_metadata-0.3.9.dist-info → mt_metadata-0.4.0.dist-info}/WHEEL +1 -1
- {mt_metadata-0.3.9.dist-info → mt_metadata-0.4.0.dist-info}/AUTHORS.rst +0 -0
- {mt_metadata-0.3.9.dist-info → mt_metadata-0.4.0.dist-info}/LICENSE +0 -0
- {mt_metadata-0.3.9.dist-info → mt_metadata-0.4.0.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,44 @@
|
|
|
1
|
+
"""
|
|
2
|
+
The base class for a weighting kernel.
|
|
3
|
+
|
|
4
|
+
"""
|
|
5
|
+
from mt_metadata.base.helpers import write_lines
|
|
6
|
+
from mt_metadata.base import get_schema, Base
|
|
7
|
+
from .standards import SCHEMA_FN_PATHS
|
|
8
|
+
|
|
9
|
+
# attr_dict = get_schema("base", SCHEMA_FN_PATHS)
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class BaseWeightKernel(Base):
    """
    BaseWeightKernel

    Common parent for weighting kernels. A weighting kernel maps the values
    of a feature onto weights that express how much each value contributes
    to a final weight.

    Not meant to be used on its own: concrete kernel types (e.g.
    MonotonicWeightKernel, CompositeWeightKernel) subclass it and override
    :meth:`evaluate`.
    """

    def __init__(self, **kwargs):
        # No schema (attr_dict) is attached at this level; every keyword
        # argument is handed straight to Base.
        super(BaseWeightKernel, self).__init__(**kwargs)

    def evaluate(self, values):
        """
        Map feature values to weights.

        Parameters
        ----------
        values : np.ndarray or float
            The feature values the kernel is applied to.

        Returns
        -------
        weights : np.ndarray or float
            The resulting weight(s) (provided by subclasses).

        Raises
        ------
        NotImplementedError
            Always; subclasses must override this method.
        """
        message = "BaseWeightKernel cannot be evaluated directly."
        raise NotImplementedError(message)
|
|
43
|
+
|
|
44
|
+
|
|
@@ -0,0 +1,209 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Container for weighting strategy to apply to a single tf estimation
|
|
3
|
+
having a single output channel (usually one of "ex", "ey", "hz").
|
|
4
|
+
|
|
5
|
+
candidate data structure is stored in test_helpers/channel_weight_specs_example.json
|
|
6
|
+
|
|
7
|
+
Candidate names: processing_weights, feature_weights, channel_weights_spec, channel_weighting
|
|
8
|
+
|
|
9
|
+
Notes, and doc for weights PR.
|
|
10
|
+
|
|
11
|
+
channel_weight_specs is a candidate name for the json block like the following:
|
|
12
|
+
>>> diff processing_configuration_template.json test_processing_config_with_weights_block.json
|
|
13
|
+
(Another candidate name could be `processing_weights`, or `weights`, but the final nomenclature
|
|
14
|
+
can be sorted out after there is a functional prototype with the appropriate structure.)
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
This block is basically a dict that maps an output channel name to a ChannelWeightSpec (CWS) object.
|
|
18
|
+
|
|
19
|
+
There are at least three places we would like to be able to plug in such a dict to the processing flow.
|
|
20
|
+
1. At the frequency_band level, so that each band can be associated with a specialty CWS
|
|
21
|
+
2. At the decimation_level level, so that all bands in a GIB have a common default.
|
|
22
|
+
3. At a high level, so that all processing uses them.
|
|
23
|
+
TAI: In future, hopefully we could insert a custom CWS for a specific band, but leave
|
|
24
|
+
all other bands to use the DecimationLevel default CWS, for example. i.e. the CWS can
|
|
25
|
+
be defined for different scopes.
|
|
26
|
+
|
|
27
|
+
TODO FIXME: In mt_metadata/transfer_functions/processing/aurora/processing.py
|
|
28
|
+
when you output a json, it looks like the `decimations` level should be named:
|
|
29
|
+
`decimation_levels` instead.
|
|
30
|
+
|
|
31
|
+
The general model I'll try to follow will be to open an iterable of objects
|
|
32
|
+
with a plural of the object name. For example, the processing block called "bands"
|
|
33
|
+
follows with an iterable of:
|
|
34
|
+
{
|
|
35
|
+
"band": {
|
|
36
|
+
"center_averaging_type": "geometric",
|
|
37
|
+
...
|
|
38
|
+
"index_min": 25
|
|
39
|
+
}
|
|
40
|
+
}
|
|
41
|
+
...
|
|
42
|
+
{
|
|
43
|
+
"band": {
|
|
44
|
+
"center_averaging_type": "geometric",
|
|
45
|
+
...
|
|
46
|
+
"index_min": 25
|
|
47
|
+
}
|
|
48
|
+
}
|
|
49
|
+
|
|
50
|
+
Will start by plugging this into the DecimationLevel.
|
|
51
|
+
|
|
52
|
+
TODO: Determine if this class, which represents a single element of a list
|
|
53
|
+
of channel weight specs, which will be in the json, should have a wrapper or not.
|
|
54
|
+
|
|
55
|
+
In the same way that a DecimationLevel has Bands,
|
|
56
|
+
it will also have ChannelWeightSpecs.
|
|
57
|
+
"""
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
from mt_metadata.base.helpers import write_lines
|
|
61
|
+
from mt_metadata.base import get_schema, Base
|
|
62
|
+
from mt_metadata.features.weights.feature_weight_spec import FeatureWeightSpec
|
|
63
|
+
from mt_metadata.features.weights.standards import SCHEMA_FN_PATHS
|
|
64
|
+
from mt_metadata.helper_functions import cast_to_class_if_dict
|
|
65
|
+
from mt_metadata.helper_functions import validate_setter_input
|
|
66
|
+
from typing import List, Union
|
|
67
|
+
|
|
68
|
+
import numpy as np
|
|
69
|
+
import xarray as xr
|
|
70
|
+
|
|
71
|
+
# Attribute schema for ChannelWeightSpec, looked up under the name
# "channel_weight_spec" in SCHEMA_FN_PATHS (presumably the JSON files in the
# weights/standards package). It is passed to Base.__init__ and rendered into
# the class __doc__ via write_lines.
attr_dict = get_schema("channel_weight_spec", SCHEMA_FN_PATHS)
|
|
72
|
+
|
|
73
|
+
class ChannelWeightSpec(Base):
    """
    ChannelWeightSpec

    Defines a weighting model for one output channel (e.g., ex, ey, hz).
    Combines multiple feature-based weighting specifications into a
    single weight using the specified combination strategy.
    """

    # The hand-written docstring above is replaced by the schema-generated one.
    __doc__ = write_lines(attr_dict)

    def __init__(self, **kwargs):
        """
        Constructor.

        Keyword arguments are forwarded to ``Base.__init__`` together with
        this class's schema (``attr_dict``).
        """
        # Initialize the backing variables *before* Base.__init__ runs.
        # Base.__init__ presumably applies kwargs through setattr (and hence
        # through the property setters); resetting ``_weights`` afterwards,
        # as was done previously, would silently discard a ``weights=`` kwarg.
        self._weights = None
        # Guarantees the ``feature_weight_specs`` getter works even if neither
        # the schema defaults nor kwargs set it.
        self._feature_weight_specs = []
        super().__init__(attr_dict=attr_dict, **kwargs)

    @property
    def feature_weight_specs(self) -> List[FeatureWeightSpec]:
        """
        Return the list of FeatureWeightSpec objects for this channel.
        """
        return self._feature_weight_specs

    @feature_weight_specs.setter
    def feature_weight_specs(
        self, value: Union[List[Union[FeatureWeightSpec, dict]], FeatureWeightSpec]
    ) -> None:
        """
        Set the feature weight specs, casting any dicts to FeatureWeightSpec.

        :param value: FeatureWeightSpecs or equivalent dicts
        :type value: Union[List[Union[FeatureWeightSpec, dict]], FeatureWeightSpec]
        """
        values = validate_setter_input(value, FeatureWeightSpec)
        self._feature_weight_specs = [
            cast_to_class_if_dict(obj, FeatureWeightSpec) for obj in values
        ]

    def evaluate(self, feature_values_dict):
        """
        Evaluate the channel weight by combining weights from all features.

        Parameters
        ----------
        feature_values_dict : dict
            Dictionary mapping feature names to their computed values.
            e.g., {"coherence": ndarray, "multiple_coherence": ndarray}

        Returns
        -------
        channel_weight : float or np.ndarray
            1.0 if there are no feature weight specs; otherwise the per-feature
            weights reduced according to ``combination_style``.

        Raises
        ------
        KeyError
            If a required feature is missing from ``feature_values_dict``.
        ValueError
            If ``combination_style`` is not one of "multiplication", "mean",
            "minimum" or "maximum".
        """
        weights = []
        for feature_weight_spec in self.feature_weight_specs:
            fname = feature_weight_spec.feature.name
            if fname not in feature_values_dict:
                raise KeyError(f"Feature values missing for '{fname}'")
            weights.append(feature_weight_spec.evaluate(feature_values_dict[fname]))

        if not weights:
            return 1.0

        reducers = {
            "multiplication": np.prod,
            "mean": np.mean,
            "minimum": np.min,
            "maximum": np.max,
        }
        combo = self.combination_style
        if combo not in reducers:
            raise ValueError(f"Unknown combination style: {combo}")

        # A FeatureWeightSpec with no kernels evaluates to the scalar 1.0, so
        # the list may mix scalars and arrays. Reducing such a ragged list
        # directly raises in modern numpy; broadcast to a common shape first.
        stacked = np.stack(np.broadcast_arrays(*weights))
        return reducers[combo](stacked, axis=0)

    @property
    def weights(self):
        """Precomputed weights (xarray object, numpy array, or None)."""
        return self._weights

    @weights.setter
    def weights(self, value):
        """Set weights; must be an xarray DataArray/Dataset, numpy array, or None."""
        if not isinstance(value, (xr.DataArray, xr.Dataset, np.ndarray, type(None))):
            raise TypeError("Data must be a numpy array or xarray.")
        self._weights = value

    def get_weights_for_band(self, band):
        """
        Extract weights for the frequency bin closest to the band's center frequency.

        TODO: Add tests.

        Parameters
        ----------
        band : object
            Should have a .center_frequency attribute (float, Hz).

        Returns
        -------
        weights : np.ndarray or xarray.DataArray
            Weights for the closest frequency bin.

        Raises
        ------
        ValueError
            If no weights are set, no frequency dimension can be found, or the
            frequency axis is empty.
        TypeError
            If weights are not an xarray object or numpy array.
        """
        weights = self.weights
        if weights is None:
            raise ValueError("No weights have been set.")

        if hasattr(weights, "dims"):
            # xarray: the first dimension whose (string) name contains "freq".
            # Non-string dimension names are skipped rather than raising.
            freq_axis = next(
                (dim for dim in weights.dims if isinstance(dim, str) and "freq" in dim),
                None,
            )
            if freq_axis is None:
                raise ValueError("Could not find frequency dimension in weights.")
            freqs = weights[freq_axis].values
        elif isinstance(weights, np.ndarray):
            # A plain ndarray carries no frequency coordinates, so the first
            # axis is assumed to be frequency and bin *indices* stand in for
            # frequencies.
            # NOTE(review): these indices are compared against
            # band.center_frequency (Hz) below, which is only meaningful if
            # the bins are spaced 1 Hz apart starting at 0 Hz — confirm.
            freqs = np.arange(weights.shape[0])
            freq_axis = 0
        else:
            raise TypeError("Weights must be an xarray.DataArray, Dataset, or numpy array.")

        if np.size(freqs) == 0:
            raise ValueError("Weights have an empty frequency axis.")

        # Index of the frequency closest to the band center.
        idx = int(np.argmin(np.abs(freqs - band.center_frequency)))

        if hasattr(weights, "isel"):
            return weights.isel({freq_axis: idx})
        return weights[idx]
|
|
@@ -0,0 +1,194 @@
|
|
|
1
|
+
"""
|
|
2
|
+
FeatureWeightSpec is the next key layer of abstraction after WeightKernels.
|
|
3
|
+
|
|
4
|
+
It ties together a feature (including its parameterization),
|
|
5
|
+
and one or more weighting kernels (like MonotonicWeightKernel).
|
|
6
|
+
|
|
7
|
+
This will let you do things like:
|
|
8
|
+
- Evaluate "coherence" between ex and hy with a taper kernel
|
|
9
|
+
- Apply multiple kernels to the same feature (e.g., low cut and high cut)
|
|
10
|
+
- Plug this into a higher-level channel weighting model
|
|
11
|
+
|
|
12
|
+
"""
|
|
13
|
+
|
|
14
|
+
from mt_metadata.base.helpers import write_lines
|
|
15
|
+
from mt_metadata.base import get_schema, Base
|
|
16
|
+
from mt_metadata.features.feature import Feature
|
|
17
|
+
from mt_metadata.features.weights.monotonic_weight_kernel import MonotonicWeightKernel
|
|
18
|
+
from mt_metadata.features.weights.monotonic_weight_kernel import ActivationMonotonicWeightKernel
|
|
19
|
+
from mt_metadata.features.weights.monotonic_weight_kernel import TaperMonotonicWeightKernel
|
|
20
|
+
from mt_metadata.features.weights.standards import SCHEMA_FN_PATHS
|
|
21
|
+
import numpy as np
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
# Attribute schema used by FeatureWeightSpec (passed to Base.__init__ and
# rendered into the class __doc__ via write_lines).
# NOTE(review): this loads the "base" schema even though a
# feature_weight_spec.json standard ships in the same package — confirm
# that using "base" here is intentional.
attr_dict = get_schema("base", SCHEMA_FN_PATHS)
|
|
25
|
+
# no need to add to attr dict if we have lists of mtmetadata objs.
|
|
26
|
+
# attr_dict.add_dict(Feature()._attr_dict, "feature")
|
|
27
|
+
|
|
28
|
+
class FeatureWeightSpec(Base):
    """
    FeatureWeightSpec

    Defines how a particular feature is used to weight an output channel.
    Includes parameters needed to compute the feature and one or more
    weight kernels to evaluate its influence.
    """

    # The hand-written docstring above is replaced by the schema-generated one.
    __doc__ = write_lines(attr_dict)

    def __init__(self, **kwargs):
        """
        Constructor.

        Keyword arguments are forwarded to ``Base.__init__``. Afterwards
        ``weight_kernels`` (default: empty list) is (re)applied through the
        ``weight_kernels`` setter so dict entries become kernel objects.
        """
        # Initialize backing variables directly, before Base.__init__ may
        # route keyword arguments through the property setters.
        self._feature = None
        self._weight_kernels = []
        super().__init__(attr_dict=attr_dict, **kwargs)
        self.weight_kernels = kwargs.get("weight_kernels", [])

    # TODO: Remove this method after mt_metadata pydantic upgrade
    # This is a workaround to ensure the setter logic runs when feature is a dict
    # This is needed because the setter logic is not automatically triggered
    # when the object is created from a dict.
    def post_from_dict(self):
        """
        If feature is a dict, force the setter logic to run.
        """
        if isinstance(self.feature, dict):
            self.feature = self.feature
        # Optionally, do the same for weight_kernels if needed

    def from_dict(self, d, *args, **kwargs):
        """
        Populate this object from a dictionary.

        If ``d["feature"]`` is a dict it is first converted to a feature object
        through the ``feature`` setter, so the base-class ``from_dict`` receives
        the object rather than a nested dict. A shallow copy of ``d`` is made
        before that substitution so the caller's dictionary is left untouched.
        Any extra positional/keyword arguments are forwarded to
        ``Base.from_dict``.
        """
        if "feature" in d and isinstance(d["feature"], dict):
            d = dict(d)  # do not mutate the caller's dict
            self.feature = d["feature"]  # runs the property setter
            d["feature"] = self.feature  # now the correct feature object
        super().from_dict(d, *args, **kwargs)
        self.post_from_dict()

    @property
    def feature(self):
        """The feature object (e.g. Coherence) this spec weights on."""
        return self._feature

    @feature.setter
    def feature(self, value):
        """
        Set the feature for this weight spec.

        If a dict is provided, it will be used to initialize the feature object.
        If an object is provided, it will be used directly.
        Unwraps nested 'feature' keys if present.

        TODO: FIXME (circular import)
        Should be able to use a model like:
        SUPPORTED_FEATURE_CLASS_MAP = {
            "coherence": Coherence,
            # "multiple_coherence": MultipleCoherence,
            # Add more as needed
        }
        but that will result in a circular import if Coherence is imported
        at the top of the module.
        """
        # Unwrap if wrapped in 'feature' repeatedly
        while isinstance(value, dict) and "feature" in value and isinstance(value["feature"], dict):
            value = value["feature"]

        if not isinstance(value, dict):
            self._feature = value
            return

        feature_name = value.get("name")
        # Import here to avoid circular import at module level
        if feature_name == "coherence":
            from mt_metadata.features.coherence import Coherence
            feature_cls = Coherence
        elif feature_name == "striding_window_coherence":
            from mt_metadata.features.coherence import StridingWindowCoherence
            feature_cls = StridingWindowCoherence
        else:
            msg = f"feature_name {feature_name} not recognized -- resorting to base class"
            self.logger.warning(msg)
            feature_cls = Feature
        self._feature = feature_cls(**value)
        self.logger.debug(
            f"Feature setter: instantiated {feature_cls.__name__} for feature_name={feature_name}"
        )

    @property
    def weight_kernels(self):
        """List of weight kernel objects applied to the feature values."""
        return self._weight_kernels

    @weight_kernels.setter
    def weight_kernels(self, value):
        """
        Set the weight kernels, converting dict entries to kernel objects.

        ``None`` is treated as an empty list.
        """
        if value is None:
            value = []
        self._weight_kernels = _unpack_weight_kernels(weight_kernels=value)

    def evaluate(self, feature_values):
        """
        Evaluate this feature's weighting based on the list of kernels.

        Parameters
        ----------
        feature_values : np.ndarray or float
            The computed values for this feature.

        Returns
        -------
        combined_weight : np.ndarray or float
            The product of all kernel outputs, or 1.0 if there are no kernels.
        """
        weights = [kernel.evaluate(feature_values) for kernel in self.weight_kernels]
        return np.prod(weights, axis=0) if weights else 1.0
|
|
144
|
+
|
|
145
|
+
def _unpack_weight_kernels(weight_kernels):
    """
    Unpack weight kernels from a list of dictionaries or objects.

    Dict entries are converted to kernel objects, choosing the class by keys:
    ``activation_style`` or ``style == "activation"`` gives an
    ActivationMonotonicWeightKernel; ``half_window_style`` or
    ``style == "taper"`` gives a TaperMonotonicWeightKernel; any other dict
    gives a MonotonicWeightKernel. Non-dict entries are passed through as-is.

    Parameters
    ----------
    weight_kernels : iterable, dict, or None
        Kernels (dicts and/or objects). ``None`` is treated as an empty list,
        and a single dict as a one-element list (previously a dict was
        iterated key-by-key, yielding a list of key strings).

    Returns
    -------
    list
        The unpacked kernels, in input order.
    """
    if weight_kernels is None:
        return []
    if isinstance(weight_kernels, dict):
        weight_kernels = [weight_kernels]

    result = []
    for wk in weight_kernels:
        # Unwrap if wrapped in "weight_kernel" (TODO: Delete, or revert after mt_metadata pydantic upgrade.)
        if isinstance(wk, dict) and "weight_kernel" in wk:
            wk = wk["weight_kernel"]
        if not isinstance(wk, dict):
            result.append(wk)
        elif "activation_style" in wk or wk.get("style") == "activation":
            result.append(ActivationMonotonicWeightKernel(**wk))
        elif "half_window_style" in wk or wk.get("style") == "taper":
            result.append(TaperMonotonicWeightKernel(**wk))
        else:
            result.append(MonotonicWeightKernel(**wk))
    return result
|
|
165
|
+
|
|
166
|
+
def unwrap_known_wrappers(obj, known_keys=None):
    """
    Recursively unwrap dicts/lists for known single-key wrappers.

    A dict with exactly one key that is in ``known_keys`` and whose value is a
    dict or list is replaced by that value, repeatedly. Dict values and list
    items are then processed recursively. Anything else is returned unchanged.

    Parameters
    ----------
    obj : object
        The (possibly nested) structure to unwrap.
    known_keys : container of str, optional
        Wrapper keys to strip. Defaults to
        {"feature_weight_spec", "channel_weight_spec", "weight_kernel", "feature"}.

    Returns
    -------
    object
        A new dict/list with wrappers removed, or ``obj`` itself for scalars.

    Notes
    -----
    Because "feature" is a default wrapper key, a dict whose only key is
    ``"feature"`` collapses to the feature's own fields.
    """
    if known_keys is None:
        known_keys = {"feature_weight_spec", "channel_weight_spec", "weight_kernel", "feature"}

    # Strip wrappers only while obj is still a dict. Once a wrapper's value is
    # a list, stop: re-testing a list as if it were a dict used to raise
    # TypeError (hashing a dict element) or AttributeError (list.items()).
    while isinstance(obj, dict) and len(obj) == 1:
        ((key, inner),) = obj.items()
        if key in known_keys and isinstance(inner, (dict, list)):
            obj = inner
        else:
            break

    if isinstance(obj, dict):
        return {k: unwrap_known_wrappers(v, known_keys) for k, v in obj.items()}
    if isinstance(obj, list):
        return [unwrap_known_wrappers(item, known_keys) for item in obj]
    return obj
|
|
186
|
+
|
|
187
|
+
# Patch FeatureWeightSpec.from_dict to unwrap wrappers
|
|
188
|
+
orig_from_dict = FeatureWeightSpec.from_dict


def from_dict_unwrap(self, d, *args, **kwargs):
    """
    Replacement for ``FeatureWeightSpec.from_dict`` that first strips known
    single-key wrappers (see ``unwrap_known_wrappers``) from ``d``.

    Any extra positional/keyword arguments are forwarded unchanged to the
    original ``from_dict`` instead of being rejected by this wrapper.
    """
    d = unwrap_known_wrappers(d)
    return orig_from_dict(self, d, *args, **kwargs)


FeatureWeightSpec.from_dict = from_dict_unwrap
|