mlda 2024.11.22__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mlda-2024.11.22/LICENSE +28 -0
- mlda-2024.11.22/PKG-INFO +25 -0
- mlda-2024.11.22/README.md +4 -0
- mlda-2024.11.22/mlda/__init__.py +11 -0
- mlda-2024.11.22/mlda/da.py +333 -0
- mlda-2024.11.22/mlda/obs.py +80 -0
- mlda-2024.11.22/mlda/prior.py +329 -0
- mlda-2024.11.22/mlda/psm.py +147 -0
- mlda-2024.11.22/mlda/utils.py +103 -0
- mlda-2024.11.22/mlda.egg-info/PKG-INFO +25 -0
- mlda-2024.11.22/mlda.egg-info/SOURCES.txt +15 -0
- mlda-2024.11.22/mlda.egg-info/dependency_links.txt +1 -0
- mlda-2024.11.22/mlda.egg-info/not-zip-safe +1 -0
- mlda-2024.11.22/mlda.egg-info/requires.txt +7 -0
- mlda-2024.11.22/mlda.egg-info/top_level.txt +1 -0
- mlda-2024.11.22/setup.cfg +4 -0
- mlda-2024.11.22/setup.py +33 -0
mlda-2024.11.22/LICENSE
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
1
|
+
BSD 3-Clause License
|
|
2
|
+
|
|
3
|
+
Copyright (c) 2024, Feng Zhu
|
|
4
|
+
|
|
5
|
+
Redistribution and use in source and binary forms, with or without
|
|
6
|
+
modification, are permitted provided that the following conditions are met:
|
|
7
|
+
|
|
8
|
+
1. Redistributions of source code must retain the above copyright notice, this
|
|
9
|
+
list of conditions and the following disclaimer.
|
|
10
|
+
|
|
11
|
+
2. Redistributions in binary form must reproduce the above copyright notice,
|
|
12
|
+
this list of conditions and the following disclaimer in the documentation
|
|
13
|
+
and/or other materials provided with the distribution.
|
|
14
|
+
|
|
15
|
+
3. Neither the name of the copyright holder nor the names of its
|
|
16
|
+
contributors may be used to endorse or promote products derived from
|
|
17
|
+
this software without specific prior written permission.
|
|
18
|
+
|
|
19
|
+
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
|
20
|
+
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
|
21
|
+
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
|
22
|
+
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
|
23
|
+
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
|
24
|
+
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
|
25
|
+
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
|
26
|
+
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
|
27
|
+
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
|
28
|
+
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
mlda-2024.11.22/PKG-INFO
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
1
|
+
Metadata-Version: 2.1
|
|
2
|
+
Name: mlda
|
|
3
|
+
Version: 2024.11.22
|
|
4
|
+
Summary: mlda: A Python package for Machine Learning-based Data Assimilation
|
|
5
|
+
Home-page: https://github.com/fzhu2e/mlda
|
|
6
|
+
Author: Feng Zhu, Weimin Si
|
|
7
|
+
Author-email: fengzhu@ucar.edu, weimin_si@brown.edu
|
|
8
|
+
License: BSD-3
|
|
9
|
+
Keywords: Machine Learning,Data Assimilation
|
|
10
|
+
Classifier: Natural Language :: English
|
|
11
|
+
Classifier: Programming Language :: Python :: 3.12
|
|
12
|
+
Description-Content-Type: text/markdown
|
|
13
|
+
License-File: LICENSE
|
|
14
|
+
Requires-Dist: netCDF4
|
|
15
|
+
Requires-Dist: xarray
|
|
16
|
+
Requires-Dist: dask
|
|
17
|
+
Requires-Dist: nc-time-axis
|
|
18
|
+
Requires-Dist: colorama
|
|
19
|
+
Requires-Dist: tqdm
|
|
20
|
+
Requires-Dist: x4c-exp
|
|
21
|
+
|
|
22
|
+
# mlda: A Python package for Machine Learning-based Data Assimilation
|
|
23
|
+
|
|
24
|
+
`mlda` is a Python package for Machine Learning-based Data Assimilation (DA).
|
|
25
|
+
It aims to provide a universal framework and the corresponding utilities for conducting reproducible data assimilation experiments using novel machine learning-based DA methods.
|
|
@@ -0,0 +1,4 @@
|
|
|
1
|
+
# mlda: A Python package for Machine Learning-based Data Assimilation
|
|
2
|
+
|
|
3
|
+
`mlda` is a Python package for Machine Learning-based Data Assimilation (DA).
|
|
4
|
+
It aims to provide a universal framework and the corresponding utilities for conducting reproducible data assimilation experiments using novel machine learning-based DA methods.
|
|
@@ -0,0 +1,11 @@
|
|
|
1
|
+
# get the version of the installed `mlda` distribution
from importlib.metadata import version, PackageNotFoundError
try:
    __version__ = version('mlda')
except PackageNotFoundError:
    # e.g. imported from a source tree that was never pip-installed
    __version__ = 'unknown'

import warnings
# NOTE: this silences every UserWarning process-wide, not only those raised by mlda
warnings.filterwarnings("ignore", category=UserWarning)

# (`cesm` was imported here previously, but no such module ships with the package)
from . import utils
from .prior import Prior, PriorMember
from .obs import Obs
from .da import Solver
|
|
11
|
+
|
|
@@ -0,0 +1,333 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
import xarray as xr
|
|
3
|
+
from scipy.linalg import cholesky, sqrtm
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
from . import utils
|
|
7
|
+
|
|
8
|
+
def gaspari_cohn(dist, loc_radius):
    '''
    Vectorized Gaspari-Cohn localization function (Eq. 4.10 of Gaspari & Cohn, 1999).

    Distances are normalized as ``r = |dist| / loc_radius``, so ``loc_radius`` plays
    the role of the half-width ``c`` in Gaspari & Cohn (1999): the weight is 1 at zero
    distance, decays smoothly, reaches 0 at ``|dist| = 2 * loc_radius`` and stays 0
    for all larger distances.

    Args:
        dist (array_like or float): Distance(s) between model state and observation.
        loc_radius (float): Gaspari-Cohn half-width; the support of the function
            ends at ``2 * loc_radius`` (covariance is set to zero beyond that).

    Returns:
        ndarray: Localization weights in [0, 1] with the same shape as ``dist``
        (a 0-d array for scalar input).

    Reference:
        Gaspari, G., Cohn, S.E., 1999. Construction of correlation functions in two and three dimensions.
        Quarterly Journal of the Royal Meteorological Society 125, 723-757. https://doi.org/10.1002/qj.49712555417
    '''
    # Normalize the distances; a float array handles scalars, lists and int arrays uniformly
    r = np.abs(np.asarray(dist, dtype=float)) / loc_radius

    # Initialize the result array with zeros (weights for r > 2 stay zero)
    f = np.zeros_like(r)

    # Eq. (4.10) in Gaspari & Cohn (1999)
    mask1 = r <= 1
    f[mask1] = -r[mask1]**5 / 4 + r[mask1]**4 / 2 + 5/8 * r[mask1]**3 - 5/3 * r[mask1]**2 + 1

    mask2 = (r > 1) & (r <= 2)
    f[mask2] = r[mask2]**5 / 12 - r[mask2]**4 / 2 + 5/8 * r[mask2]**3 + 5/3 * r[mask2]**2 - 5 * r[mask2] + 4 - 2/3 / r[mask2]

    f[f < 0] = 0  # remove tiny negative values caused by round-off near r = 2
    return f
|
|
35
|
+
|
|
36
|
+
def gaspari_cohn_dash(dist, loc_radius, scale=0.5):
    """
    Implements a Gaspari-Cohn 5th order polynomial localization function.

    The polynomial uses the length scale ``c = scale * loc_radius``: the inner branch
    applies for ``dist <= c``, the outer branch for ``c < dist <= 2c``. Weights are
    zero wherever ``dist > loc_radius`` or ``dist > 2c`` (the polynomial's own support),
    so ``loc_radius`` is always a hard cutoff.

    Parameters:
        dist (array_like or float): An array of distances.
        loc_radius (float): The cutoff radius, beyond which weights are zero.
        scale (float or str, optional): The length scale for the polynomial.
            Must be on the interval 0 < scale <= 0.5, or 'optimal' to use the optimal
            length scale as described by Lorenc (2003). Default is 0.5.

    Returns:
        weights (ndarray): Covariance localization weights (float) with the same shape
        as distances (0-d for scalar input).
    """
    # Set the scale if 'optimal' is specified
    if isinstance(scale, str) and scale == 'optimal':
        scale = np.sqrt(10 / 3)

    # Float array: int inputs must not truncate the weights, scalars must be indexable
    dist = np.asarray(dist, dtype=float)

    # Define length scale and normalized distance
    c = scale * loc_radius
    X = dist / c

    # Disjoint regions. Zero beyond the cutoff radius AND beyond the polynomial support
    # (X > 2): for scale < 0.5 the outer branch would otherwise be evaluated past X = 2
    # (spurious positive weights), and for scale > 1 the inner branch would otherwise
    # override the cutoff at loc_radius.
    outside = (dist > loc_radius) | (X > 2)
    inside_scale = (dist <= c) & ~outside
    in_between = ~inside_scale & ~outside

    weights = np.zeros_like(X)

    # Apply Gaspari-Cohn polynomial
    Xb = X[in_between]
    weights[in_between] = Xb**5 / 12 - 0.5 * Xb**4 + 0.625 * Xb**3 + (5 / 3) * Xb**2 - 5 * Xb + 4 - 2 / (3 * Xb)
    Xi = X[inside_scale]
    weights[inside_scale] = -0.25 * Xi**5 + 0.5 * Xi**4 + 0.625 * Xi**3 - (5 / 3) * Xi**2 + 1

    # Ensure weights are non-negative due to rounding errors
    weights[weights < 0] = 0

    return weights
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
class EnSRF:
    ''' Ensemble square-root filter (batch form, symmetric right transform).

    The ensemble mean is updated with the (optionally localized) Kalman gain. The
    perturbations are updated as ``Xp @ T`` with ``T`` the symmetric square root of
    ``I - Yp^T C^{-1} Yp / (N-1)``, so that without localization the posterior
    covariance equals ``(I - KH) P``. The state-space localization ``L`` only enters
    the mean update; ``Lobs`` enters both through ``C``.
    '''
    def __init__(self, X=None, Y=None, y=None, R=None, L=None, Lobs=None):
        self.X = X  # ensemble of the prior state vectors (n x N)
        self.Y = Y  # ensemble of the forward estimates (m x N); Y=H(X)
        self.y = y  # observations (m x 1)
        self.R = R  # obs err matrix (m x m)
        self.L = L  # localization matrix (n x m)
        self.Lobs = Lobs  # localization matrix (m x m)

    def update(self, debug=False):
        ''' Perform an EnSRF update with localization; stores ``self.X_updated`` (n x N). '''
        N = self.X.shape[1]  # Ensemble size

        # Compute the ensemble mean
        Xm = np.mean(self.X, axis=1, keepdims=True)
        Xp = self.X - Xm

        Ym = np.mean(self.Y, axis=1, keepdims=True)
        Yp = self.Y - Ym

        # Forward-estimate covariance (HPH^T)
        Ycov = (Yp @ Yp.T) / (N - 1)

        # Localize the forward-estimate covariance matrix
        if self.Lobs is not None:
            Ycov_loc = Ycov * self.Lobs
        else:
            Ycov_loc = Ycov

        C = Ycov_loc + self.R
        C_inv = np.linalg.inv(C)  # computed once, used for both the gain and the transform

        # Kalman gain matrix
        XYcov = (Xp @ Yp.T) / (N - 1)

        # Localize the Kalman gain
        if self.L is not None:
            XYcov_loc = XYcov * self.L
        else:
            XYcov_loc = XYcov

        K = XYcov_loc @ C_inv

        # Observation innovation
        d = self.y - Ym

        # Update the ensemble mean
        Xm_updated = Xm + K @ d

        # Update the ensemble perturbations with the symmetric square root of
        # M = I - Yp^T C^-1 Yp / (N-1). Using M itself (not its square root) applies
        # the transform "twice" and underestimates the posterior spread.
        M = np.eye(N) - (Yp.T @ C_inv) @ Yp / (N - 1)
        M = 0.5 * (M + M.T)  # symmetrize against round-off before eigh
        evals, evecs = np.linalg.eigh(M)
        # clip: with localization in C, tiny negative eigenvalues may appear
        T = (evecs * np.sqrt(np.clip(evals, 0, None))) @ evecs.T
        Xp_updated = Xp @ T

        # Combine updated mean and perturbations
        self.X_updated = Xm_updated + Xp_updated

        if debug:
            self.Xm = Xm
            self.Xp = Xp
            self.Ym = Ym
            self.Yp = Yp
            self.C = C
            self.K = K
            self.d = d
            self.T = T
|
|
142
|
+
|
|
143
|
+
class EnSRF_DASH:
    ''' Batch ensemble square-root filter in the DASH / Whitaker & Hamill (2002) form.

    The mean is updated with the Kalman gain ``K = PH^T C^{-1}``; the perturbations
    with the adjusted gain ``Ka = PH^T (C^{1/2})^{-T} (C^{1/2} + R^{1/2})^{-1}``, where
    ``C = HPH^T + R`` and ``PH^T`` is the (optionally localized) covariance between
    the state and the forward estimates.
    '''
    def __init__(self, X=None, Y=None, y=None, R=None, L=None, Lobs=None):
        self.X = X  # ensemble of the prior state vectors (n x N)
        self.Y = Y  # ensemble of the forward estimates (m x N); Y=H(X)
        self.y = y  # observations (m x 1)
        self.R = R  # obs err matrix (m x m)
        self.L = L  # localization matrix (n x m)
        self.Lobs = Lobs  # localization matrix (m x m)

    def update(self, debug=False):
        ''' Perform an EnSRF update with localization; stores ``self.X_updated`` (n x N). '''
        N = self.X.shape[1]  # Ensemble size

        # Compute the ensemble mean
        Xm = np.mean(self.X, axis=1, keepdims=True)
        Xp = self.X - Xm

        Ym = np.mean(self.Y, axis=1, keepdims=True)
        Yp = self.Y - Ym

        # Forward-estimate covariance (HPH^T)
        Ycov = (Yp @ Yp.T) / (N - 1)

        # Localize the forward-estimate covariance matrix
        if self.Lobs is not None:
            Ycov_loc = self.Lobs * Ycov
        else:
            Ycov_loc = Ycov

        C = Ycov_loc + self.R

        # State / forward-estimate covariance (PH^T)
        XYcov = (Xp @ Yp.T) / (N - 1)

        # Localize the Kalman gain
        if self.L is not None:
            XYcov_loc = self.L * XYcov
        else:
            XYcov_loc = XYcov

        K = XYcov_loc @ np.linalg.inv(C)

        # Observation innovation
        d = self.y - Ym

        # Update the ensemble mean
        Xm_updated = Xm + K @ d

        # Update the ensemble perturbations with the adjusted gain. It must be built
        # from PH^T (XYcov_loc), not from K: K already contains C^{-1}, and using it
        # here applied C^{-1} twice, leaving the posterior spread far too large.
        Csqrt = sqrtm(C)
        Csqrt_inv_transpose = np.linalg.inv(Csqrt).T
        Rcov_sqrt = sqrtm(self.R)
        Ka = XYcov_loc @ Csqrt_inv_transpose @ np.linalg.inv(Csqrt + Rcov_sqrt)
        Xp_updated = Xp - Ka @ Yp

        # Combine updated mean and perturbations
        self.X_updated = Xm_updated + Xp_updated

        if debug:
            self.Xm = Xm
            self.Xp = Xp
            self.Ym = Ym
            self.Yp = Yp
            self.C = C
            self.K = K
            self.d = d
            self.Ka = Ka
|
|
209
|
+
|
|
210
|
+
|
|
211
|
+
class EnOI:
    ''' Ensemble optimal interpolation.

    A static ensemble (``X``, ``Y``) supplies the background covariances; the
    resulting (optionally gain-localized) increment is added to ``X_target``.
    '''
    def __init__(self, X_target=None, X=None, Y=None, y=None, R=None, L=None):
        self.X_target = X_target  # the **monthly** prior state vectors (n x 1)
        self.X = X  # ensemble of the prior state vectors (n x N)
        self.Y = Y  # ensemble of the forward estimates (m x N); Y=H(X)
        self.y = y  # observations (m x 1)
        self.R = R  # obs err matrix (m x m)
        self.L = L  # localization matrix (n x m), applied element-wise to the gain

    def update(self, debug=False):
        ''' Perform an EnOI update with localization; stores ``self.X_target_updated``. '''
        n_members = self.X.shape[1]
        dof = n_members - 1

        # ensemble anomalies of state and forward estimates
        x_mean = np.mean(self.X, axis=1, keepdims=True)
        x_anom = self.X - x_mean
        y_mean = np.mean(self.Y, axis=1, keepdims=True)
        y_anom = self.Y - y_mean

        # innovation covariance HPH^T + R and the unlocalized gain
        innov_cov = (y_anom @ y_anom.T) / dof + self.R
        gain = (x_anom @ y_anom.T) / dof @ np.linalg.inv(innov_cov)

        # localization is applied to the gain itself (after the inversion)
        gain_loc = gain if self.L is None else gain * self.L

        innovation = self.y - y_mean
        self.X_target_updated = self.X_target + gain_loc @ innovation

        if debug:
            self.Xm = x_mean
            self.Xp = x_anom
            self.Ym = y_mean
            self.Yp = y_anom
            self.C = innov_cov
            self.K = gain
            self.K_loc = gain_loc
            self.d = innovation
|
|
261
|
+
|
|
262
|
+
|
|
263
|
+
class Solver:
    ''' High-level driver for a DA experiment.

    ``prep()`` regrids and annualizes the prior, computes the forward estimates
    Y = H(X) and (optionally) the localization matrices; ``run()`` performs the
    update with one of ``EnSRF``, ``EnSRF_DASH`` or ``EnOI`` and stores the
    formatted posterior in ``self.post``.
    '''
    def __init__(self, prior=None, obs=None, prior_target=None):
        # work on copies so prep()/run() do not modify the caller's objects
        self.prior = prior.copy() if prior is not None else None
        self.obs = obs.copy() if obs is not None else None
        self.prior_target = prior_target.copy() if prior_target is not None else None

    def prep(self, localize=True, loc_radius=2500, dist_vsf=1, dlat=1, dlon=1, loc_method='dash',
             recon_season=None, startover=False, nearest_valid_radius=5, **fwd_kws):
        ''' Prepare Y=H(X) and the localization matrix for DA

        Args:
            localize (bool): compute ``self.L``/``self.Lobs``; otherwise both are None.
            loc_radius (float): radius passed to the localization function.
            dist_vsf (float, list of float): the vertical scaling factor of the distance
            dlat, dlon (float): resolution passed to ``Prior.regrid``.
            loc_method (str): 'dash' -> ``gaspari_cohn_dash``, 'cpda' -> ``gaspari_cohn``.
            recon_season (list of int): months averaged by ``Prior.annualize``;
                defaults to all twelve months.
            startover (bool): recompute Y (and the distances) even if already available.
            nearest_valid_radius: forwarded to ``Prior.get_Y``.
            **fwd_kws: forwarded to ``Prior.get_Y`` (per-PSM forward keyword dicts).

        Note: cached distances are reused unless Y is recomputed, so changing
        ``dist_vsf`` requires ``startover=True``.
        '''
        if recon_season is None:
            recon_season = list(range(1, 13))

        if not hasattr(self.prior, 'ds_rgd'):
            utils.p_header(f'>>> Regridding the prior (dlat={dlat}, dlon={dlon})')
            self.prior.regrid(dlat=dlat, dlon=dlon)

        recompute_Y = startover or not hasattr(self.prior, 'Y')
        if recompute_Y:
            utils.p_header('>>> Proxy System Modeling: Y = H(X)')
            self.prior.get_Y(self.obs, nearest_valid_radius=nearest_valid_radius, **fwd_kws)

        if not hasattr(self.prior, 'ds_ann'):
            utils.p_header(f'>>> Annualizing prior w/ season: {recon_season}')
            self.prior.annualize(months=recon_season)

        if not localize:
            self.L = None
            self.Lobs = None
            return

        loc_func = {
            'cpda': gaspari_cohn,
            'dash': gaspari_cohn_dash,
        }[loc_method]

        utils.p_header('>>> Computing the localization matrix')
        # Distances must refer to the proxies actually assimilated: get_Y may drop
        # some, so use prior.obs_assim (not the full obs) to keep L/Lobs aligned with Y.
        obs_assim = getattr(self.prior, 'obs_assim', self.obs)
        if recompute_Y or not hasattr(self.prior, 'dist'):
            self.prior.get_dist(obs_assim, dist_vsf)
        if recompute_Y or not hasattr(obs_assim, 'dist'):
            obs_assim.get_dist()

        # Always build the matrices when localizing (cached distances are reused above);
        # previously an existing prior.dist silently disabled localization.
        self.L = loc_func(self.prior.dist, loc_radius)
        self.Lobs = loc_func(obs_assim.dist, loc_radius)

    def run(self, method='EnSRF', debug=False):
        ''' Run the DA update and store the posterior Dataset in ``self.post``.

        Args:
            method (str): 'EnSRF', 'EnSRF_DASH' or 'EnOI'.
            debug (bool): forwarded to the solver's ``update`` (keeps intermediates).

        Raises:
            KeyError: unknown ``method``.
            ValueError: ``method='EnOI'`` without a ``prior_target``.
        '''
        algo = {
            'EnSRF': EnSRF,
            'EnSRF_DASH': EnSRF_DASH,
            'EnOI': EnOI,
        }
        if method not in algo:
            raise KeyError(f'Unknown DA method: {method}; choose from {list(algo)}')

        # y and R must describe the same proxies as Y (i.e. after any drops in get_Y)
        obs_assim = getattr(self.prior, 'obs_assim', self.obs)
        kws = {
            'X': self.prior.X,
            'Y': self.prior.Y,
            'y': obs_assim.y,
            'R': obs_assim.R,
            'L': self.L,
        }

        if method == 'EnOI':
            # EnOI takes no Lobs; it updates prior_target instead of the ensemble
            if self.prior_target is None:
                raise ValueError('method="EnOI" requires a `prior_target`.')
            kws['X_target'] = self.prior_target.X
        else:
            kws['Lobs'] = self.Lobs

        self.S = algo[method](**kws)

        utils.p_header('>>> DA update')
        self.S.update(debug=debug)

        utils.p_header('>>> Formatting the posterior')
        if method == 'EnOI':
            self.post = utils.states2ds(self.S.X_target_updated, self.prior_target.ds_ann)
        else:
            self.post = utils.states2ds(self.S.X_updated, self.prior.ds_ann)
|
|
@@ -0,0 +1,80 @@
|
|
|
1
|
+
from copy import deepcopy
|
|
2
|
+
import numpy as np
|
|
3
|
+
import pandas as pd
|
|
4
|
+
import xarray as xr
|
|
5
|
+
|
|
6
|
+
from . import utils
|
|
7
|
+
class Obs:
    ''' A collection of proxy observations backed by a DataFrame (one row per record).

    Columns read here: 'pid', 'lat', 'lon' (normalized to [0, 360)), 'value'
    (``y``) and 'R' (diagonal of ``R``). ``records`` maps each pid to a ``ProxyRecord``.
    '''
    def __init__(self, df:pd.DataFrame):
        # Copy first so the longitude normalization does not modify the caller's DataFrame
        self.df = df.copy()
        self.df['lon'] = (self.df['lon'] + 360) % 360
        self.nobs = len(self.df)
        self.pids = self.df['pid'].values
        self.records = {pid: self[pid] for pid in self.pids}

    @property
    def y(self):
        ''' Observation values as a column vector (nobs x 1). '''
        return self.df['value'].values[..., np.newaxis]

    @property
    def y_locs(self):
        ''' (nobs x 2) array of [lat, lon]. '''
        return self.df[['lat', 'lon']].values

    @property
    def R(self):
        ''' Diagonal observation-error matrix built from the 'R' column. '''
        return np.diag(self.df['R'].values)

    def copy(self):
        return deepcopy(self)

    def __getitem__(self, pid:str):
        ''' Return a ProxyRecord built from the first row whose 'pid' equals ``pid``. '''
        row = self.df[self.df['pid'] == pid].iloc[0]
        return ProxyRecord(row)

    def get_dist(self):
        ''' Pairwise great-circle distances between all observations (nobs x nobs). '''
        lats = self.df['lat'].values
        lons = self.df['lon'].values
        lat1, lat2 = np.meshgrid(lats, lats)
        lon1, lon2 = np.meshgrid(lons, lons)
        self.dist = utils.gcd(lat1, lon1, lat2, lon2)
        return self.dist
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
class ProxyRecord:
    ''' A single proxy record built from one row (``pd.Series``) of an Obs DataFrame. '''
    def __init__(self, data:pd.Series):
        self.data = data.copy()
        # store time/value as numpy arrays regardless of how they were provided
        if 'time' in data: self.data['time'] = np.array(data['time'])
        if 'value' in data: self.data['value'] = np.array(data['value'])

        if 'seasonality' in data:
            season = data['seasonality']
            if isinstance(season, str):
                self.data['seasonality'] = utils.str2list(season)
            elif isinstance(season, list):
                self.data['seasonality'] = season
            elif isinstance(season, (tuple, range)):
                # other month sequences are normalized to a list
                self.data['seasonality'] = list(season)
            elif isinstance(season, np.ndarray):
                self.data['seasonality'] = np.atleast_1d(season).tolist()
            else:
                raise ValueError('Wrong seasonality type; should be a string or a list.')

    def get_clim(self, clim_ds, vns:list=None, verbose=False):
        ''' Extract the seasonal-mean climate at this record's location into ``self.clim``.

        Args:
            clim_ds (xr.Dataset): climate fields with a 'month' dimension.
            vns (list, optional): variables to extract; names missing from
                ``clim_ds`` are skipped. Defaults to all data variables.
            verbose (bool): report each created variable.

        Requires ``self.data`` to provide 'lat', 'lon' and 'seasonality'.
        '''
        if vns is None:
            vns = clim_ds.data_vars
        else:
            vns = [vn for vn in vns if vn in clim_ds.data_vars]

        self.clim = xr.Dataset()
        for vn in vns:
            self.clim[vn] = clim_ds[vn].x.nearest2d(
                lat=self.data.lat,
                lon=self.data.lon,
                method='nearest',
            ).sel(month=self.data.seasonality).mean(dim='month')
            if verbose: utils.p_success(f'>>> ProxyRecord.clim["{vn}"] created')

        self.clim.attrs['seasonality'] = self.data.seasonality
|
|
80
|
+
|
|
@@ -0,0 +1,329 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
import xarray as xr
|
|
3
|
+
from tqdm import tqdm
|
|
4
|
+
from scipy.stats import norm
|
|
5
|
+
from tqdm import tqdm
|
|
6
|
+
from copy import deepcopy
|
|
7
|
+
|
|
8
|
+
from . import psm
|
|
9
|
+
from . import utils
|
|
10
|
+
|
|
11
|
+
class PriorMember:
    ''' One prior realization (or a generator of synthetic members) wrapping an xarray Dataset. '''
    def __init__(self, ds):
        # a bare DataArray is promoted to a single-variable Dataset
        ds = ds.to_dataset() if isinstance(ds, xr.DataArray) else ds
        self.ds = ds
        self.vns = list(ds.data_vars)

    def gen_samples_Gaussian(self, local_sigma:dict, global_sigma:dict, nens:int=100, seed:int=2333):
        ''' Generate samples following Gaussian

        Each member equals the field plus one offset shared by every point of that
        member (drawn with ``global_sigma``) plus independent point-wise noise
        (drawn with ``local_sigma``).

        Args:
            local_sigma (dict): per-variable std of the point-wise perturbation.
            global_sigma (dict): per-variable std of the per-member shared perturbation.
            nens (int): Number of ensemble members to generate.
            seed (int): Seed for reproducibility.

        Sets ``self.samples``: a Dataset whose variables gain a trailing 'ens' dimension.
        '''
        rng = np.random.default_rng(seed)
        self.samples = out = xr.Dataset()
        for name in self.vns:
            field = self.ds[name]
            base = field.values
            # draw order (shared first, then point-wise) fixes the random stream
            shared = norm.rvs(loc=0, scale=global_sigma[name], size=nens, random_state=rng)
            pointwise = norm.rvs(loc=0, scale=local_sigma[name], size=(*base.shape, nens), random_state=rng)
            member_da = xr.DataArray(
                base[..., np.newaxis] + shared + pointwise,
                dims=(*field.dims, 'ens'),
                coords=field.coords,
            )
            member_da.attrs = field.attrs
            out[name] = member_da

    def gen_samples_bootstrap(self, nens:int=30, clim_yrs:int=50, seed:int=0, dim='time'):
        ''' Generate samples from the prior pool

        Member k (k = 1..nens) is the mean of ``clim_yrs`` entries of ``self.ds`` drawn
        without replacement along ``dim`` using seed ``seed + k`` (``seed=None`` -> 0).

        Args:
            nens (int): number of members.
            clim_yrs (int): entries averaged per member (must not exceed the pool size).
            seed (int or None): base seed.
            dim (str): dimension to sample along; it becomes 'ens' in ``self.samples``.
        '''
        pool_idx = list(range(len(self.ds[dim])))
        base_seed = 0 if seed is None else seed
        members = [
            self.ds.isel({dim: np.random.default_rng(base_seed + k).choice(pool_idx, size=clim_yrs, replace=False)}).mean(dim)
            for k in range(1, nens + 1)
        ]
        self.samples = xr.Dataset(xr.concat(members, dim=dim)).rename({dim: 'ens'})
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
class Prior:
|
|
58
|
+
def __init__(self, members, lat_name='TLAT', lon_name='TLONG', depth_name='z_t'):
|
|
59
|
+
if not isinstance(members, list): members = [members]
|
|
60
|
+
ds_list = []
|
|
61
|
+
for m in members:
|
|
62
|
+
if hasattr(m, 'samples'):
|
|
63
|
+
ds_list.append(m.samples)
|
|
64
|
+
else:
|
|
65
|
+
if 'ens' not in m.ds.dims:
|
|
66
|
+
ds_list.append(m.ds.expand_dims({'ens': 1}))
|
|
67
|
+
else:
|
|
68
|
+
ds_list.append(m.ds)
|
|
69
|
+
|
|
70
|
+
self.ds = xr.concat(ds_list, dim='ens').transpose(..., 'ens')
|
|
71
|
+
self.lat_name = lat_name
|
|
72
|
+
self.lon_name = lon_name
|
|
73
|
+
self.depth_name = depth_name
|
|
74
|
+
if depth_name is not None:
|
|
75
|
+
self.nz = len(self.ds[depth_name])
|
|
76
|
+
|
|
77
|
+
self.nlat, self.nlon = self.ds[lat_name].shape, self.ds[lon_name].shape
|
|
78
|
+
self.nens = len(self.ds.ens)
|
|
79
|
+
self.nvar = len(self.ds.data_vars)
|
|
80
|
+
|
|
81
|
+
def regrid(self, dlat=1, dlon=1, verbose=False):
|
|
82
|
+
self.ds_rgd = xr.Dataset()
|
|
83
|
+
for vn in tqdm(self.ds.data_vars, desc=f'Regridding variables to {dlat}x{dlon}'):
|
|
84
|
+
self.ds_rgd[vn] = self.ds.x[vn].x.regrid(dlat=dlat, dlon=dlon)
|
|
85
|
+
if verbose: utils.p_success(f'>>> Prior.ds_rgd["{vn}"] created')
|
|
86
|
+
|
|
87
|
+
def annualize(self, months=list(range(1, 13))):
|
|
88
|
+
self.ds_ann = self.ds.sel(month=months).mean('month')
|
|
89
|
+
|
|
90
|
+
def inflate(self, factor=2):
|
|
91
|
+
self.ds_raw = self.ds.copy()
|
|
92
|
+
ens_mean = self.ds.mean('ens')
|
|
93
|
+
ens_pert = self.ds - self.ds.mean('ens')
|
|
94
|
+
inflated_pert = ens_pert * factor
|
|
95
|
+
self.ds = ens_mean + inflated_pert
|
|
96
|
+
|
|
97
|
+
def copy(self):
|
|
98
|
+
return deepcopy(self)
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
@property
|
|
102
|
+
def X(self):
|
|
103
|
+
res = []
|
|
104
|
+
for vn in self.ds_ann.data_vars:
|
|
105
|
+
res.append(self.ds_ann[vn].values.reshape(-1, self.nens))
|
|
106
|
+
|
|
107
|
+
res = np.array(res).reshape(-1, self.nens)
|
|
108
|
+
return res
|
|
109
|
+
|
|
110
|
+
# def get_Y(self, obs, **fwd_kws):
|
|
111
|
+
# self.obs_assim = obs.copy()
|
|
112
|
+
# lats = obs.df['lat'].values
|
|
113
|
+
# lons = obs.df['lon'].values
|
|
114
|
+
# pids = obs.df['pid'].values
|
|
115
|
+
# depths = obs.df['depth'].values
|
|
116
|
+
# if 'clean' in obs.df.columns: cleans = obs.df['clean'].values
|
|
117
|
+
# if 'species' in obs.df.columns: specs = obs.df['species'].values
|
|
118
|
+
|
|
119
|
+
# psms = obs.df['psm'].values
|
|
120
|
+
# psm_names = list(set(psms))
|
|
121
|
+
|
|
122
|
+
# pseudo_obs = np.empty((len(obs.df), self.nens))
|
|
123
|
+
|
|
124
|
+
# # Loop over PSM types (psm_names)
|
|
125
|
+
# for psm_name in psm_names:
|
|
126
|
+
# mask = psms == psm_name
|
|
127
|
+
# idx = np.where(mask)[0]
|
|
128
|
+
# if np.any(mask):
|
|
129
|
+
# lat_lon_pairs = xr.Dataset({
|
|
130
|
+
# 'lat': (('obs',), lats[mask]),
|
|
131
|
+
# 'lon': (('obs',), lons[mask]),
|
|
132
|
+
# })
|
|
133
|
+
# self.clim_proxy_locs = xr.Dataset()
|
|
134
|
+
# for vn in self.ds_rgd.data_vars:
|
|
135
|
+
# filled_da = self.ds_rgd[vn].ffill(dim='lat').bfill(dim='lat').ffill(dim='lon').bfill(dim='lon')
|
|
136
|
+
# self.clim_proxy_locs[vn] = filled_da.sel(
|
|
137
|
+
# lat=lat_lon_pairs['lat'],
|
|
138
|
+
# lon=lat_lon_pairs['lon'],
|
|
139
|
+
# method='nearest',
|
|
140
|
+
# ).transpose(..., 'ens')
|
|
141
|
+
|
|
142
|
+
# for i in tqdm(range(len(idx)), desc=f'>>> Looping over sites w/ PSM - {psm_name}'):
|
|
143
|
+
# pid = pids[idx[i]]
|
|
144
|
+
# lat = lats[idx[i]]
|
|
145
|
+
# lon = lons[idx[i]]
|
|
146
|
+
# depth = depths[idx[i]]
|
|
147
|
+
# if np.isnan(depth): depth = 0
|
|
148
|
+
|
|
149
|
+
# obs_meta = {
|
|
150
|
+
# 'pid': pid,
|
|
151
|
+
# 'lat': lat,
|
|
152
|
+
# 'lon': lon,
|
|
153
|
+
# 'depth': depth,
|
|
154
|
+
# }
|
|
155
|
+
# if 'clean' in obs.df.columns:
|
|
156
|
+
# clean = cleans[idx[i]]
|
|
157
|
+
# if np.isnan(clean): clean = 0
|
|
158
|
+
# obs_meta['clean'] = clean
|
|
159
|
+
|
|
160
|
+
# if 'species' in obs.df.columns:
|
|
161
|
+
# species = specs[idx[i]]
|
|
162
|
+
# if not isinstance(species, str): species = 'all'
|
|
163
|
+
# obs_meta['species'] = species
|
|
164
|
+
|
|
165
|
+
# mdl = psm.__dict__[psm_name](obs_meta, self.clim_proxy_locs.isel({'obs': i}))
|
|
166
|
+
# _fwd_kws = {}
|
|
167
|
+
# _fwd_kws[psm_name] = {}
|
|
168
|
+
# if psm_name in fwd_kws:
|
|
169
|
+
# _fwd_kws[psm_name].update(fwd_kws[psm_name])
|
|
170
|
+
|
|
171
|
+
# res = mdl.forward(**_fwd_kws[psm_name])
|
|
172
|
+
# if res is None:
|
|
173
|
+
# utils.p_warning(f'>>> Dropping proxy: {pid}')
|
|
174
|
+
# self.obs_assim.df = obs.df.drop(obs.df[obs.df['pid'] == pid].index)
|
|
175
|
+
# pseudo_obs[idx[i]] = np.nan
|
|
176
|
+
# else:
|
|
177
|
+
# pseudo_obs[idx[i]] = res
|
|
178
|
+
|
|
179
|
+
# self.obs_assim.nobs = len(self.obs_assim.df)
|
|
180
|
+
# pseudo_obs = pseudo_obs[~np.isnan(pseudo_obs).any(axis=1)]
|
|
181
|
+
# self.Y = pseudo_obs
|
|
182
|
+
# self.obs_assim.df['Ym'] = self.Y.mean(axis=1)
|
|
183
|
+
|
|
184
|
+
|
|
185
|
+
# def get_Y(self, obs, **fwd_kws):
|
|
186
|
+
# self.obs_assim = obs.copy()
|
|
187
|
+
# pseudo_obs = np.empty((len(obs.df), self.nens))
|
|
188
|
+
|
|
189
|
+
# for i, (pid, rec) in tqdm(enumerate(obs.records.items()), total=obs.nobs, desc='Looping over records'):
|
|
190
|
+
# mdl = psm.__dict__[rec.data.psm](rec)
|
|
191
|
+
# mdl.record.get_clim(self.ds_rgd, vns=mdl.clim_vns)
|
|
192
|
+
|
|
193
|
+
# _fwd_kws = {}
|
|
194
|
+
# _fwd_kws[rec.data.psm] = {}
|
|
195
|
+
# if rec.data.psm in fwd_kws:
|
|
196
|
+
# _fwd_kws[rec.data.psm].update(fwd_kws[rec.data.psm])
|
|
197
|
+
|
|
198
|
+
# res = mdl.forward(**_fwd_kws[rec.data.psm])
|
|
199
|
+
# if res is None:
|
|
200
|
+
# utils.p_warning(f'>>> Dropping proxy: {pid}')
|
|
201
|
+
# self.obs_assim.df = obs.df.drop(obs.df[obs.df['pid'] == pid].index)
|
|
202
|
+
# pseudo_obs[i] = np.nan
|
|
203
|
+
# else:
|
|
204
|
+
# pseudo_obs[i] = res
|
|
205
|
+
|
|
206
|
+
# self.obs_assim.nobs = len(self.obs_assim.df)
|
|
207
|
+
# pseudo_obs = pseudo_obs[~np.isnan(pseudo_obs).any(axis=1)]
|
|
208
|
+
# self.Y = pseudo_obs
|
|
209
|
+
# self.obs_assim.df['Ym'] = self.Y.mean(axis=1)
|
|
210
|
+
|
|
211
|
+
def get_Y(self, obs, nearest_valid_radius=5, **fwd_kws):
|
|
212
|
+
self.obs_assim = obs.copy()
|
|
213
|
+
pseudo_obs = np.empty((len(obs.df), self.nens))
|
|
214
|
+
|
|
215
|
+
psm_names = set(obs.df['psm_name'])
|
|
216
|
+
clim_vns = list({
|
|
217
|
+
vn for psm_name in psm_names
|
|
218
|
+
for vn in psm.__dict__[psm_name]().clim_vns
|
|
219
|
+
if vn in self.ds_rgd.data_vars
|
|
220
|
+
})
|
|
221
|
+
|
|
222
|
+
lat_lon_pairs = xr.Dataset({
|
|
223
|
+
'lat': (('sites',), obs.df['lat'].values),
|
|
224
|
+
'lon': (('sites',), obs.df['lon'].values),
|
|
225
|
+
})
|
|
226
|
+
self.ds_proxy_locs = xr.Dataset()
|
|
227
|
+
for vn in clim_vns:
|
|
228
|
+
# filled_da = self.ds_rgd[vn].ffill(dim='lon').bfill(dim='lon').ffill(dim='lat').bfill(dim='lat')
|
|
229
|
+
# ds_proxy_locs[vn] = filled_da.sel(
|
|
230
|
+
# lat=lat_lon_pairs['lat'],
|
|
231
|
+
# lon=lat_lon_pairs['lon'],
|
|
232
|
+
# method='nearest',
|
|
233
|
+
# ).transpose(..., 'ens')
|
|
234
|
+
|
|
235
|
+
self.ds_proxy_locs[vn] = self.ds_rgd[vn].x.nearest2d(
|
|
236
|
+
lat=lat_lon_pairs['lat'],
|
|
237
|
+
lon=lat_lon_pairs['lon'],
|
|
238
|
+
r=nearest_valid_radius,
|
|
239
|
+
extra_dim='ens',
|
|
240
|
+
).transpose(..., 'ens')
|
|
241
|
+
|
|
242
|
+
if 'sites' not in self.ds_proxy_locs.dims:
|
|
243
|
+
self.ds_proxy_locs = self.ds_proxy_locs.expand_dims({'sites': [0]})
|
|
244
|
+
|
|
245
|
+
# if ds_proxy_locs[vn].isnull().any():
|
|
246
|
+
# for idx in obs.df.index:
|
|
247
|
+
# if ds_proxy_locs[vn].sel(sites=idx).isnull().any():
|
|
248
|
+
# utils.p_warning(f"NaN detected for {vn}: {obs.df.iloc[idx][['pid', 'lat', 'lon']].values}")
|
|
249
|
+
# print(ds_proxy_locs[vn].sel(sites=idx).dims)
|
|
250
|
+
# print(ds_proxy_locs[vn].sel(sites=idx).values)
|
|
251
|
+
# utils.p_warning('------------------------------------')
|
|
252
|
+
# raise ValueError('Some of the nearest gridcell values are NaN.')
|
|
253
|
+
|
|
254
|
+
nearest_lats, nearest_lons = [], []
|
|
255
|
+
for i, (pid, rec) in tqdm(enumerate(obs.records.items()), total=obs.nobs, desc='Looping over records'):
|
|
256
|
+
# nearest_clim = self.ds_proxy_locs.isel({'sites': i}).sel(month=rec.data.seasonality).mean(dim='month')
|
|
257
|
+
# nearest_lat = nearest_clim.lat.values.mean()
|
|
258
|
+
# nearest_lon = nearest_clim.lon.values.mean()
|
|
259
|
+
# nearest_lats.append(nearest_lat)
|
|
260
|
+
# nearest_lons.append(nearest_lon)
|
|
261
|
+
# rec.data.lat = nearest_lat
|
|
262
|
+
# rec.data.lon = nearest_lon
|
|
263
|
+
|
|
264
|
+
mdl = psm.__dict__[rec.data.psm_name](rec)
|
|
265
|
+
mdl.record.clim = self.ds_proxy_locs.isel({'sites': i}).sel(month=rec.data.seasonality).mean(dim='month')
|
|
266
|
+
for vn in clim_vns:
|
|
267
|
+
if mdl.record.clim[vn].isnull().any():
|
|
268
|
+
# print(i, ds_proxy_locs[vn].isel({'sites': i}))
|
|
269
|
+
# print(ds_proxy_locs[vn].isel({'sites': i, 'ens': 6}))
|
|
270
|
+
# print(ds_proxy_locs.sel(month=rec.data.seasonality).mean(dim='month')[vn].values[i])
|
|
271
|
+
# print(vn, rec.data.pid, rec.data.lat, rec.data.lon, rec.data.seasonality)
|
|
272
|
+
# print(mdl.record.clim[vn].values)
|
|
273
|
+
raise ValueError(f'NaN values detected in input climate for forward modeling of: {pid}')
|
|
274
|
+
|
|
275
|
+
obs.records[pid].psm = mdl # for debugging purposes
|
|
276
|
+
|
|
277
|
+
_fwd_kws = {}
|
|
278
|
+
_fwd_kws[rec.data.psm_name] = {}
|
|
279
|
+
if rec.data.psm_name in fwd_kws:
|
|
280
|
+
_fwd_kws[rec.data.psm_name].update(fwd_kws[rec.data.psm_name])
|
|
281
|
+
mdl.forward(**_fwd_kws[rec.data.psm_name])
|
|
282
|
+
if mdl.output is None:
|
|
283
|
+
utils.p_warning(f'>>> Dropping proxy: {pid}')
|
|
284
|
+
self.obs_assim.df = obs.df.drop(obs.df[obs.df['pid'] == pid].index)
|
|
285
|
+
pseudo_obs[i] = np.nan
|
|
286
|
+
else:
|
|
287
|
+
pseudo_obs[i] = mdl.output
|
|
288
|
+
|
|
289
|
+
self.obs_assim.nobs = len(self.obs_assim.df)
|
|
290
|
+
# self.obs_assim.df['lat'] = nearest_lats
|
|
291
|
+
# self.obs_assim.df['lon'] = nearest_lons
|
|
292
|
+
pseudo_obs = pseudo_obs[~np.isnan(pseudo_obs).any(axis=1)]
|
|
293
|
+
self.Y = pseudo_obs
|
|
294
|
+
self.obs_assim.df['Ym'] = self.Y.mean(axis=1)
|
|
295
|
+
|
|
296
|
+
def get_dist(self, obs, s=1):
    ''' Compute distances between every state-vector element and every observation.

    The horizontal great-circle distance (via ``utils.gcd``) is computed between
    each grid cell of ``self.ds`` and each observation location in ``obs.df``.
    The result is then expanded to the state-vector layout and stored in ``self.dist``.

    Args:
        obs: observation collection whose ``df`` provides ``lat`` and ``lon``
            columns, one row per observation.
        s (float or array-like): per-level scaling factor applied to the
            horizontal distances. It is only used when ``self`` has an ``nz``
            attribute (3D localization). A scalar is broadcast to all
            ``self.nz`` levels.

    Sets:
        self.dist (numpy.ndarray): shape ``(nvar * nz * ngrid, nobs)`` when
            ``self.nz`` exists, otherwise ``(nvar * ngrid, nobs)``, where
            ``ngrid`` is the number of grid cells.
    '''
    # Grid coordinates: either 1D axes of a rectilinear grid or 2D fields
    lat_grid = self.ds[self.lat_name].values
    lon_grid = self.ds[self.lon_name].values

    if lat_grid.ndim == 1 and lon_grid.ndim == 1:
        # Expand 1D axes to 2D so that every grid cell has its own (lat, lon)
        lon_grid, lat_grid = np.meshgrid(lon_grid, lat_grid)

    # Flatten the grid to shape (ngrid,)
    lat_grid_flat = lat_grid.ravel()
    lon_grid_flat = lon_grid.ravel()

    # Observation coordinates, shape (nobs,)
    obs_lats = obs.df['lat'].values
    obs_lons = obs.df['lon'].values
    # Derive nobs from the coordinates actually used, so that grid/obs pairing
    # cannot be misaligned if obs.nobs were out of sync with obs.df.
    nobs = obs_lats.size

    # Broadcast (ngrid, 1) against (1, nobs). Row g, column k is the distance
    # between grid cell g and observation k, which is the same layout the former
    # repeat/tile version produced. This avoids four (ngrid * nobs,) temporaries.
    dist0 = utils.gcd(
        lat_grid_flat[:, np.newaxis], lon_grid_flat[:, np.newaxis],
        obs_lats[np.newaxis, :], obs_lons[np.newaxis, :],
    )

    if hasattr(self, 'nz'):
        # 3D localization: one copy of the horizontal distances per vertical
        # level, each scaled by that level's entry of `s`
        scale = (np.ones(self.nz) * s).reshape(-1, 1, 1)
        dist3d = dist0[np.newaxis, :, :] * scale  # (nz, ngrid, nobs)
        dist1 = dist3d.reshape(dist3d.shape[0] * dist3d.shape[1], nobs)
        # Stack identical copies for each state variable
        self.dist = np.tile(dist1, (self.nvar, 1))
    else:
        self.dist = np.tile(dist0, (self.nvar, 1))
|
|
@@ -0,0 +1,147 @@
|
|
|
1
|
+
import xarray as xr
|
|
2
|
+
import pybaywatch as pb
|
|
3
|
+
import numpy as np
|
|
4
|
+
|
|
5
|
+
from . import utils
|
|
6
|
+
from . import obs
|
|
7
|
+
|
|
8
|
+
class IdenticalSST:
    ''' Identity PSM: the pseudo-proxy is the ``TEMP`` field at the first ``z_t`` index.

    Args:
        record (obs.ProxyRecord): proxy record whose ``clim`` dataset holds the
            input climate for this proxy.
    '''
    def __init__(self, record:obs.ProxyRecord=None):
        self.record = record

    @property
    def clim_vns(self):
        ''' Climate variable names read by this PSM. '''
        return ['TEMP']

    def forward(self):
        ''' Store ``TEMP`` at the first ``z_t`` index as a NumPy array in ``self.output``. '''
        first_level = self.record.clim['TEMP'].isel(z_t=0)
        self.output = first_level.values
|
|
18
|
+
|
|
19
|
+
class IdenticalSSS:
    ''' Identity PSM: the pseudo-proxy is the ``SALT`` field at the first ``z_t`` index.

    Args:
        record (obs.ProxyRecord): proxy record whose ``clim`` dataset holds the
            input climate for this proxy.
    '''
    def __init__(self, record:obs.ProxyRecord=None):
        self.record = record

    @property
    def clim_vns(self):
        ''' Climate variable names read by this PSM. '''
        return ['SALT']

    def forward(self):
        ''' Store ``SALT`` at the first ``z_t`` index as a NumPy array in ``self.output``. '''
        first_level = self.record.clim['SALT'].isel(z_t=0)
        self.output = first_level.values
|
|
29
|
+
|
|
30
|
+
class IdenticalSSTSSS:
    ''' Test PSM: the pseudo-proxy is ``TEMP + SALT``, both at the first ``z_t`` index.

    Args:
        record (obs.ProxyRecord): proxy record whose ``clim`` dataset holds the
            input climate for this proxy.
    '''
    def __init__(self, record:obs.ProxyRecord=None):
        self.record = record

    @property
    def clim_vns(self):
        ''' Climate variable names read by this PSM. '''
        return ['TEMP', 'SALT']

    def forward(self):
        ''' Store the element-wise sum of first-level ``TEMP`` and ``SALT`` in ``self.output``. '''
        clim = self.record.clim
        temp_vals, salt_vals = [clim[vn].isel(z_t=0).values for vn in ('TEMP', 'SALT')]
        self.output = temp_vals + salt_vals
|
|
40
|
+
|
|
41
|
+
class TEX86:
    ''' TEX86 proxy system model wrapping ``pybaywatch.TEX_forward``.

    Args:
        record (obs.ProxyRecord): proxy record whose ``clim`` dataset holds the
            input climate and whose ``data`` provides ``lat``/``lon``.
    '''
    def __init__(self, record:obs.ProxyRecord=None):
        self.record = record

    @property
    def clim_vns(self):
        ''' Candidate SST variable names; the first one present in ``record.clim`` is used. '''
        return ['TEMP', 'tos', 'sst']

    def _get_sst(self):
        ''' Return SST values from ``record.clim``.

        Variables are tried in order: ``TEMP`` (first ``z_t`` index), ``tos``, ``sst``.

        Raises:
            ValueError: if none of these variables is present.
        '''
        clim = self.record.clim
        if 'TEMP' in clim:
            return clim['TEMP'].isel(z_t=0).values
        if 'tos' in clim:
            return clim['tos'].values
        if 'sst' in clim:
            return clim['sst'].values
        # Previously this fell through to an UnboundLocalError on `sst`
        raise ValueError(
            f'TEX86 forward modeling requires one of {self.clim_vns} in the input climate.'
        )

    def forward(self, seed=2333, mode='analog', type='SST', tolerance=1):
        ''' Run ``pb.TEX_forward`` and store the median prediction in ``self.output``.

        Args:
            seed, mode, type, tolerance: passed through unchanged to
                ``pb.TEX_forward``. The name ``type`` shadows the builtin but is
                kept for API compatibility.

        Sets:
            self.params (dict): keyword arguments passed to ``pb.TEX_forward``.
            self.output: ``np.median(res['values'], axis=1)``, or ``None`` if
                ``pb.TEX_forward`` reports ``status == 'FAIL'``.

        Raises:
            ValueError: if no supported SST variable is present in ``record.clim``.
        '''
        sst = self._get_sst()

        lat = self.record.data.lat
        lon = self.record.data.lon
        lon180 = utils.lon180(lon)

        # run
        self.params = {
            'lat': lat,
            'lon': lon180,
            'temp': sst,
            'seed': seed,
            'type': type,
            'mode': mode,
            'tolerance': tolerance,
        }
        res = pb.TEX_forward(**self.params)
        if res['status'] == 'FAIL':
            # This class has no `meta` attribute, so take the proxy id from the
            # record. The fallback keeps the warning itself from crashing.
            pid = getattr(self.record.data, 'pid', '<unknown>')
            utils.p_warning(f'>>> Forward modeling failed for proxy: {pid}')
            self.output = None
        else:
            self.output = np.median(res['values'], axis=1)
|
|
77
|
+
|
|
78
|
+
class UK37:
    ''' UK'37 proxy system model wrapping ``pybaywatch.UK_forward``.

    Args:
        record (obs.ProxyRecord): proxy record whose ``clim`` dataset holds the
            input climate for this proxy.
    '''
    def __init__(self, record:obs.ProxyRecord=None):
        self.record = record

    @property
    def clim_vns(self):
        ''' Candidate SST variable names; the first one present in ``record.clim`` is used. '''
        return ['TEMP', 'tos', 'sst']

    def _get_sst(self):
        ''' Return SST values from ``record.clim``.

        Variables are tried in order: ``TEMP`` (first ``z_t`` index), ``tos``, ``sst``.

        Raises:
            ValueError: if none of these variables is present.
        '''
        clim = self.record.clim
        if 'TEMP' in clim:
            # Read from record.clim. The former `self.clim` does not exist on
            # this class, so the TEMP path always raised AttributeError.
            return clim['TEMP'].isel(z_t=0).values
        if 'tos' in clim:
            return clim['tos'].values
        if 'sst' in clim:
            return clim['sst'].values
        raise ValueError(
            f'UK37 forward modeling requires one of {self.clim_vns} in the input climate.'
        )

    def forward(self, order=3, seed=2333):
        ''' Run ``pb.UK_forward`` and store the median prediction in ``self.output``.

        Args:
            order, seed: passed through unchanged to ``pb.UK_forward``.

        Sets:
            self.params (dict): keyword arguments passed to ``pb.UK_forward``.
            self.output: ``np.median(res['values'], axis=1)``.

        Raises:
            ValueError: if no supported SST variable is present in ``record.clim``.
        '''
        sst = self._get_sst()

        # run
        self.params = {
            'sst': sst,
            'order': order,
            'seed': seed,
        }
        res = pb.UK_forward(**self.params)
        self.output = np.median(res['values'], axis=1)
|
|
102
|
+
|
|
103
|
+
class MgCa:
    ''' Mg/Ca proxy system model wrapping ``pybaywatch.MgCa_forward``.

    Args:
        record (obs.ProxyRecord): proxy record whose ``clim`` dataset holds the
            input climate and whose ``data`` provides ``lat``, ``lon``, ``depth``,
            ``clean`` and ``species``.
    '''
    def __init__(self, record:obs.ProxyRecord=None):
        self.record = record

    @property
    def clim_vns(self):
        ''' Candidate temperature/salinity variable names read from ``record.clim``. '''
        return ['TEMP', 'tos', 'sst', 'SALT', 'sos', 'sss']

    def _get_sst_sss(self):
        ''' Return ``(sst, sss)`` from the first complete variable pair in ``record.clim``.

        Pairs are tried in order: (``TEMP``, ``SALT``) at the first ``z_t`` index,
        then (``tos``, ``sos``), then (``sst``, ``sss``).

        Raises:
            ValueError: if no complete pair is present.
        '''
        clim = self.record.clim
        if 'TEMP' in clim and 'SALT' in clim:
            return clim['TEMP'].isel(z_t=0).values, clim['SALT'].isel(z_t=0).values
        if 'tos' in clim and 'sos' in clim:
            return clim['tos'].values, clim['sos'].values
        if 'sst' in clim and 'sss' in clim:
            return clim['sst'].values, clim['sss'].values
        # Previously this fell through to an UnboundLocalError on `sst`/`sss`
        raise ValueError(
            'MgCa forward modeling requires one of the variable pairs '
            '(TEMP, SALT), (tos, sos) or (sst, sss) in the input climate.'
        )

    def forward(self, age, omega=None, pH=None, clean=None, species=None, sw=2, H=1, seed=2333):
        ''' Run ``pb.MgCa_forward`` and store the median prediction in ``self.output``.

        Args:
            age: passed through to ``pb.MgCa_forward``.
            omega, pH: carbonate-chemistry inputs. Any that is ``None`` is looked
                up with ``pb.core.omgph`` at the record's lat/lon/depth, and a
                value supplied by the caller is kept.
            clean, species: default to ``record.data.clean`` / ``record.data.species``.
            sw, H, seed: passed through unchanged to ``pb.MgCa_forward``.

        Sets:
            self.params (dict): keyword arguments passed to ``pb.MgCa_forward``.
            self.output: ``np.median(res['values'], axis=1)``.

        Raises:
            ValueError: if no complete temperature/salinity pair is in ``record.clim``.
        '''
        sst, sss = self._get_sst_sss()

        # get omega and pH
        lat = self.record.data.lat
        lon = self.record.data.lon
        depth = self.record.data.depth
        if omega is None or pH is None:
            # Use the shared helper (as TEX86 does) for the [-180, 180) conversion.
            # Fill only the value(s) the caller did not supply. Previously a
            # half-specified pair passed None on to MgCa_forward.
            omega_site, pH_site = pb.core.omgph(lat, utils.lon180(lon), depth)
            if omega is None:
                omega = omega_site
            if pH is None:
                pH = pH_site

        if clean is None:
            clean = self.record.data.clean
        if species is None:
            species = self.record.data.species

        # run
        self.params = {
            'age': age,
            'sst': sst,
            'salinity': sss,
            'pH': pH,
            'omega': omega,
            'species': species,
            'clean': clean,
            'sw': sw,
            'H': H,
            'seed': seed,
        }
        res = pb.MgCa_forward(**self.params)
        self.output = np.median(res['values'], axis=1)
|
|
@@ -0,0 +1,103 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
import xarray as xr
|
|
3
|
+
import colorama as ca
|
|
4
|
+
|
|
5
|
+
def _p_styled(color, text):
    ''' Print ``text`` prefixed with ``color`` + bright style, then reset all styling. '''
    print(color + ca.Style.BRIGHT + text + ca.Style.RESET_ALL)

def p_header(text):
    ''' Print ``text`` in bright cyan. '''
    _p_styled(ca.Fore.CYAN, text)

def p_hint(text):
    ''' Print ``text`` in bright light-black (grey). '''
    _p_styled(ca.Fore.LIGHTBLACK_EX, text)

def p_success(text):
    ''' Print ``text`` in bright green. '''
    _p_styled(ca.Fore.GREEN, text)

def p_fail(text):
    ''' Print ``text`` in bright red. '''
    _p_styled(ca.Fore.RED, text)

def p_warning(text):
    ''' Print ``text`` in bright yellow. '''
    _p_styled(ca.Fore.YELLOW, text)
|
|
19
|
+
|
|
20
|
+
def gcd(lat1, lon1, lat2, lon2, radius=6378.137):
    ''' 2D Great Circle Distance [km], computed with the haversine formula.

    All coordinate arguments may be scalars or NumPy-broadcastable arrays.

    Args:
        lat1, lon1 (float or array-like): latitude/longitude of the first point(s) [degree]
        lat2, lon2 (float or array-like): latitude/longitude of the second point(s) [degree]
        radius (float): Earth radius [km]; 6378.137 is the equatorial radius

    Returns:
        Distance(s) in the units of ``radius``, broadcast over the inputs.
    '''
    # Convert degrees to radians
    lon1, lat1, lon2, lat2 = map(np.radians, [lon1, lat1, lon2, lat2])
    dlat, dlon = lat2 - lat1, lon2 - lon1
    a = np.sin(dlat / 2)**2 + np.cos(lat1) * np.cos(lat2) * np.sin(dlon / 2)**2
    # Floating-point rounding can push `a` marginally above 1 for (near-)antipodal
    # points, where arcsin(sqrt(a)) would return NaN. Clamp it to the valid range.
    a = np.clip(a, 0.0, 1.0)
    c = 2 * np.arcsin(np.sqrt(a))
    dist = radius * c
    return dist
|
|
33
|
+
|
|
34
|
+
def states2ds(states, ds):
    ''' Unpack a stacked state array into a new Dataset shaped like ``ds``.

    Variables are consumed from ``states`` row-wise and consecutively, in
    ``ds.data_vars`` order:

    - A variable with an ``ens`` dimension occupies ``prod(shape[:-1])`` rows.
    - Any other variable occupies ``prod(shape)`` rows.

    Each slice is reshaped to the variable's original shape, and cells that are
    NaN in ``ds`` are set to NaN in the output. Dims, coords and attrs are
    copied from ``ds``.

    Args:
        states (numpy.ndarray): stacked state. NOTE(review): the ``ens`` branch
            assumes ``ens`` is the last dimension and matches the trailing axis
            of ``states``; confirm against callers.
        ds (xarray.Dataset): template providing shapes, dims, coords, attrs and
            the NaN mask.

    Returns:
        xarray.Dataset: a new dataset. ``states`` is not modified.
    '''
    ds_out = xr.Dataset()
    start_loc = 0
    for vn in ds.data_vars:
        da_in = ds[vn]
        shape = da_in.shape
        # int() because np.prod(()) is the float 1.0, which cannot be used as a
        # slice bound for 0-d variables
        if 'ens' in da_in.dims:
            nrows = int(np.prod(shape[:-1]))
        else:
            nrows = int(np.prod(shape))
        end_loc = start_loc + nrows

        # Basic slicing + reshape yields a view of `states`. Copy before
        # masking so the NaN assignment never writes into the caller's array.
        data = states[start_loc:end_loc].reshape(shape).copy()
        data[np.isnan(da_in.values)] = np.nan

        ds_out[vn] = xr.DataArray(
            data,
            dims=da_in.dims,
            coords=da_in.coords,
        )
        start_loc = end_loc
        ds_out[vn].attrs = da_in.attrs

    return ds_out
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
# def gcd_3d(loc1, loc2, radius=6371.0):
|
|
69
|
+
# ''' 3D Great Circle Distance [km]
|
|
70
|
+
|
|
71
|
+
# Args:
|
|
72
|
+
# loc1 (tuple): lat1 [degree], lon1 [degree], depth1 [km]
|
|
73
|
+
# loc2 (tuple): lat2 [degree], lon2 [degree], depth2 [km]
|
|
74
|
+
# radius (float): Earth radius
|
|
75
|
+
# '''
|
|
76
|
+
# lat1, lon1, depth1 = loc1
|
|
77
|
+
# lat2, lon2, depth2 = loc2
|
|
78
|
+
|
|
79
|
+
# # Convert degrees to radians
|
|
80
|
+
# lat1, lon1 = np.radians(lat1), np.radians(lon1)
|
|
81
|
+
# lat2, lon2 = np.radians(lat2), np.radians(lon2)
|
|
82
|
+
|
|
83
|
+
# # Calculate radial distances (Earth's radius minus depth)
|
|
84
|
+
# r1 = radius - depth1
|
|
85
|
+
# r2 = radius - depth2
|
|
86
|
+
|
|
87
|
+
# # Compute central angle component
|
|
88
|
+
# central_angle = np.sin(lat1) * np.sin(lat2) + np.cos(lat1) * np.cos(lat2) * np.cos(lon2 - lon1)
|
|
89
|
+
|
|
90
|
+
# # Compute the 3D distance
|
|
91
|
+
# distance_3d = np.sqrt(r1**2 + r2**2 - 2 * r1 * r2 * central_angle)
|
|
92
|
+
|
|
93
|
+
# return distance_3d
|
|
94
|
+
|
|
95
|
+
def str2list(s, sep=','):
    ''' Parse a separated string of integers into a list of ``int``.

    Whitespace around each item is ignored. Empty items, such as those produced
    by an empty string or a trailing separator, are skipped.

    Args:
        s (str): input such as ``'1, 2, 3'``
        sep (str): item separator

    Returns:
        list[int]: the parsed integers, in order.

    Raises:
        ValueError: if a non-empty item is not a valid integer.
    '''
    items = (token.strip() for token in s.split(sep))
    return [int(item) for item in items if item]
|
|
98
|
+
|
|
99
|
+
def lon360(lon180):
    ''' Convert longitude(s) to the [0, 360) convention by wrapping modulo 360.

    Args:
        lon180 (float or array-like): longitude(s) in degrees, e.g. in [-180, 180).

    Returns:
        The wrapped longitude(s); ``np.remainder`` is the same ufunc as ``np.mod``.
    '''
    return np.remainder(lon180, 360)
|
|
101
|
+
|
|
102
|
+
def lon180(lon360):
    ''' Convert longitude(s) to the [-180, 180) convention.

    Args:
        lon360 (float or array-like): longitude(s) in degrees, e.g. in [0, 360).

    Returns:
        The wrapped longitude(s). Note that 180 maps to -180.
    '''
    shifted = lon360 + 180
    return np.remainder(shifted, 360) - 180
|
|
@@ -0,0 +1,25 @@
|
|
|
1
|
+
Metadata-Version: 2.1
|
|
2
|
+
Name: mlda
|
|
3
|
+
Version: 2024.11.22
|
|
4
|
+
Summary: mlda: A Python package for Machine Learning-based Data Assimilation
|
|
5
|
+
Home-page: https://github.com/fzhu2e/mlda
|
|
6
|
+
Author: Feng Zhu, Weimin Si
|
|
7
|
+
Author-email: fengzhu@ucar.edu, weimin_si@brown.edu
|
|
8
|
+
License: BSD-3
|
|
9
|
+
Keywords: Machine Learning,Data Assimilation
|
|
10
|
+
Classifier: Natural Language :: English
|
|
11
|
+
Classifier: Programming Language :: Python :: 3.12
|
|
12
|
+
Description-Content-Type: text/markdown
|
|
13
|
+
License-File: LICENSE
|
|
14
|
+
Requires-Dist: netCDF4
|
|
15
|
+
Requires-Dist: xarray
|
|
16
|
+
Requires-Dist: dask
|
|
17
|
+
Requires-Dist: nc-time-axis
|
|
18
|
+
Requires-Dist: colorama
|
|
19
|
+
Requires-Dist: tqdm
|
|
20
|
+
Requires-Dist: x4c-exp
|
|
21
|
+
|
|
22
|
+
# mlda: A Python package for Machine Learning-based Data Assimilation
|
|
23
|
+
|
|
24
|
+
`mlda` is a Python package for Machine Learning-based Data Assimilation (DA).
|
|
25
|
+
It aims to provide a universal framework and the corresponding utilities for conducting reproducible data assimilation experiments using novel machine learning-based DA methods.
|
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
LICENSE
|
|
2
|
+
README.md
|
|
3
|
+
setup.py
|
|
4
|
+
mlda/__init__.py
|
|
5
|
+
mlda/da.py
|
|
6
|
+
mlda/obs.py
|
|
7
|
+
mlda/prior.py
|
|
8
|
+
mlda/psm.py
|
|
9
|
+
mlda/utils.py
|
|
10
|
+
mlda.egg-info/PKG-INFO
|
|
11
|
+
mlda.egg-info/SOURCES.txt
|
|
12
|
+
mlda.egg-info/dependency_links.txt
|
|
13
|
+
mlda.egg-info/not-zip-safe
|
|
14
|
+
mlda.egg-info/requires.txt
|
|
15
|
+
mlda.egg-info/top_level.txt
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
mlda
|
mlda-2024.11.22/setup.py
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
1
|
+
from setuptools import setup, find_packages

# Read as UTF-8 explicitly so the build does not depend on the platform's
# default locale encoding.
with open('README.md', 'r', encoding='utf-8') as fh:
    long_description = fh.read()

setup(
    name='mlda', # required
    version='2024.11.22',
    description='mlda: A Python package for Machine Learning-based Data Assimilation',
    long_description=long_description,
    long_description_content_type='text/markdown',
    author='Feng Zhu, Weimin Si',
    author_email='fengzhu@ucar.edu, weimin_si@brown.edu',
    url='https://github.com/fzhu2e/mlda',
    packages=find_packages(),
    include_package_data=True,
    license='BSD-3',
    zip_safe=False,
    keywords=['Machine Learning', 'Data Assimilation'],
    classifiers=[
        'Natural Language :: English',
        'Programming Language :: Python :: 3.12',
    ],
    install_requires=[
        # numpy is imported directly by the package (utils.py, psm.py)
        'numpy',
        'netCDF4',
        'xarray',
        'dask',
        'nc-time-axis',
        'colorama',
        'tqdm',
        'x4c-exp',
        # NOTE(review): mlda/psm.py also imports `pybaywatch`, which is not
        # declared here. Confirm its distribution name and add it.
    ],
)
|