mlda 2024.11.22__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,28 @@
1
+ BSD 3-Clause License
2
+
3
+ Copyright (c) 2024, Feng Zhu
4
+
5
+ Redistribution and use in source and binary forms, with or without
6
+ modification, are permitted provided that the following conditions are met:
7
+
8
+ 1. Redistributions of source code must retain the above copyright notice, this
9
+ list of conditions and the following disclaimer.
10
+
11
+ 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ this list of conditions and the following disclaimer in the documentation
13
+ and/or other materials provided with the distribution.
14
+
15
+ 3. Neither the name of the copyright holder nor the names of its
16
+ contributors may be used to endorse or promote products derived from
17
+ this software without specific prior written permission.
18
+
19
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
@@ -0,0 +1,25 @@
1
+ Metadata-Version: 2.1
2
+ Name: mlda
3
+ Version: 2024.11.22
4
+ Summary: mlda: A Python package for Machine Learning-based Data Assimilation
5
+ Home-page: https://github.com/fzhu2e/mlda
6
+ Author: Feng Zhu, Weimin Si
7
+ Author-email: fengzhu@ucar.edu, weimin_si@brown.edu
8
+ License: BSD-3
9
+ Keywords: Machine Learning,Data Assimilation
10
+ Classifier: Natural Language :: English
11
+ Classifier: Programming Language :: Python :: 3.12
12
+ Description-Content-Type: text/markdown
13
+ License-File: LICENSE
14
+ Requires-Dist: netCDF4
15
+ Requires-Dist: xarray
16
+ Requires-Dist: dask
17
+ Requires-Dist: nc-time-axis
18
+ Requires-Dist: colorama
19
+ Requires-Dist: tqdm
20
+ Requires-Dist: x4c-exp
21
+
22
+ # mlda: A Python package for Machine Learning-based Data Assimilation
23
+
24
+ `mlda` is a Python package for Machine Learning-based Data Assimilation (DA).
25
+ It aims to provide a universal framework and the corresponding utilities for conducting reproducible data assimilation experiments using novel machine learning-based DA methods.
@@ -0,0 +1,4 @@
1
+ # mlda: A Python package for Machine Learning-based Data Assimilation
2
+
3
+ `mlda` is a Python package for Machine Learning-based Data Assimilation (DA).
4
+ It aims to provide a universal framework and the corresponding utilities for conducting reproducible data assimilation experiments using novel machine learning-based DA methods.
@@ -0,0 +1,11 @@
1
# get the version
from importlib.metadata import version, PackageNotFoundError
try:
    # The distribution is published as 'mlda' (see PKG-INFO `Name:`); querying any other
    # name either fails or reports the version of an unrelated package.
    __version__ = version('mlda')
except PackageNotFoundError:
    # e.g. importing from a source tree that has not been pip-installed
    __version__ = 'unknown'

import warnings
# NOTE: this silences *all* UserWarnings process-wide as soon as the package is imported
warnings.filterwarnings("ignore", category=UserWarning)

from . import utils, cesm
from .prior import Prior, PriorMember
from .obs import Obs
from .da import Solver
11
+
@@ -0,0 +1,333 @@
1
+ import numpy as np
2
+ import xarray as xr
3
+ from scipy.linalg import cholesky, sqrtm
4
+
5
+
6
+ from . import utils
7
+
8
def gaspari_cohn(dist, loc_radius):
    '''
    Vectorized Gaspari-Cohn localization function (Gaspari & Cohn 1999, Eq. 4.10).

    Distances are normalized as ``r = |dist| / loc_radius``. The weight is 1 at ``r = 0``,
    follows the inner polynomial for ``r <= 1`` and the outer polynomial for ``1 < r <= 2``,
    and is 0 for ``r > 2``. Hence ``loc_radius`` acts as the half-width of the taper:
    weights vanish only beyond ``2 * loc_radius``.

    Args:
        dist (array_like): Distance(s) between model state and observation.
        loc_radius (float): Half-width of the taper (weights are zero beyond ``2 * loc_radius``).

    Returns:
        ndarray: Non-negative float weights with the same shape as ``dist``.

    Raises:
        ValueError: If ``loc_radius`` is not strictly positive.

    Reference:
        Gaspari, G., Cohn, S.E., 1999. Construction of correlation functions in two and three dimensions.
        Quarterly Journal of the Royal Meteorological Society 125, 723-757. https://doi.org/10.1002/qj.49712555417
    '''
    # `~(x > 0)` also rejects NaN radii
    if np.any(~(np.asarray(loc_radius) > 0)):
        raise ValueError(f'loc_radius must be positive; got {loc_radius!r}')

    # Normalize the distances (accepts scalars, lists and arrays)
    r = np.abs(np.asarray(dist, dtype=float)) / loc_radius

    # Weights default to zero (the r > 2 region)
    f = np.zeros_like(r)

    # Eq. (4.10) in Gaspari & Cohn (1999), inner branch: 0 <= r <= 1
    inner = r <= 1
    ri = r[inner]
    f[inner] = -ri**5 / 4 + ri**4 / 2 + 5/8 * ri**3 - 5/3 * ri**2 + 1

    # Outer branch: 1 < r <= 2
    outer = (r > 1) & (r <= 2)
    ro = r[outer]
    f[outer] = ro**5 / 12 - ro**4 / 2 + 5/8 * ro**3 + 5/3 * ro**2 - 5 * ro + 4 - 2/3 / ro

    # Round-off can leave tiny negative values near r = 2
    f[f < 0] = 0
    return f
35
+
36
def gaspari_cohn_dash(dist, loc_radius, scale=0.5):
    """
    Gaspari-Cohn 5th-order polynomial localization with length scale ``c = scale * loc_radius``.

    With ``X = |dist| / c`` the weight is the inner polynomial for ``X <= 1`` and the outer
    polynomial for ``1 < X <= 2``. It is 0 beyond the cutoff ``min(loc_radius, 2 * c)``;
    beyond ``2c`` the outer polynomial becomes positive again, so it must not be evaluated
    there. With the default ``scale=0.5`` the cutoff is exactly ``loc_radius``. For every
    ``scale``, weights are zero beyond ``loc_radius``.

    Parameters:
        dist (array_like): An array of distances.
        loc_radius (float): The cutoff radius, beyond which weights are zero.
        scale (float or str, optional): The length scale for the polynomial, as a fraction
            of ``loc_radius``. Intended range is 0 < scale <= 0.5. Smaller values taper
            faster and reach zero at ``2 * scale * loc_radius``. ``'optimal'`` sets
            ``scale = sqrt(10/3)``, labelled in the original code as the optimal length
            scale of Lorenc (2003). Default is 0.5.

    Returns:
        weights (ndarray): Non-negative float weights with the same shape as ``dist``.

    Raises:
        ValueError: If ``scale`` is a string other than ``'optimal'`` or is not positive.
    """
    if isinstance(scale, str):
        if scale != 'optimal':
            raise ValueError(f"scale must be a positive number or 'optimal'; got {scale!r}")
        # NOTE(review): sqrt(10/3) > 0.5 lies outside the intended interval. With it, c > loc_radius,
        # so only the inner polynomial is used and weights drop from ~0.63 to 0 at loc_radius.
        # Confirm this matches Lorenc (2003).
        scale = np.sqrt(10 / 3)
    elif not scale > 0:
        raise ValueError(f'scale must be positive; got {scale!r}')

    # Float dtype so the in-place assignments below are not truncated for integer inputs
    dist = np.abs(np.asarray(dist, dtype=float))

    # Length scale and effective cutoff (the polynomial is only valid for X <= 2)
    c = scale * loc_radius
    cutoff = np.minimum(loc_radius, 2 * c)

    # Disjoint masks so no region is assigned twice
    outside = dist > cutoff
    inner = (dist <= c) & ~outside
    between = (dist > c) & ~outside

    weights = np.ones_like(dist)
    X = dist / c
    Xi = X[inner]
    Xb = X[between]
    weights[outside] = 0
    weights[between] = Xb**5 / 12 - 0.5 * Xb**4 + 0.625 * Xb**3 + (5 / 3) * Xb**2 - 5 * Xb + 4 - 2 / (3 * Xb)
    weights[inner] = -0.25 * Xi**5 + 0.5 * Xi**4 + 0.625 * Xi**3 - (5 / 3) * Xi**2 + 1

    # Ensure weights are non-negative due to rounding errors
    weights[weights < 0] = 0

    return weights
75
+
76
+
77
+
78
class EnSRF:
    ''' Ensemble update (all observations at once) with optional localization.

    After ``update()``, the posterior ensemble is stored in ``self.X_updated`` (n x N).
    '''
    def __init__(self, X=None, Y=None, y=None, R=None, L=None, Lobs=None):
        self.X = X  # ensemble of the prior state vectors (n x N)
        self.Y = Y  # ensemble of the forward estimates (m x N); Y=H(X)
        self.y = y  # observations (m x 1)
        self.R = R  # obs err matrix (m x m)
        self.L = L  # localization matrix (n x m)
        self.Lobs = Lobs  # localization matrix (m x m)

    def update(self, debug=False):
        ''' Perform an EnSRF update with localization.

        Args:
            debug (bool): if True, keep the intermediate matrices as attributes
                (Xm, Xp, Ym, Yp, C, K, d, T).
        '''
        N = self.X.shape[1]  # Ensemble size

        # Ensemble means and perturbations
        Xm = np.mean(self.X, axis=1, keepdims=True)
        Xp = self.X - Xm

        Ym = np.mean(self.Y, axis=1, keepdims=True)
        Yp = self.Y - Ym

        # Ensemble covariance of the forward estimates, H P H^T (m x m)
        Ycov = (Yp @ Yp.T) / (N - 1)

        # Localize it in observation space (element-wise)
        Ycov_loc = Ycov * self.Lobs if self.Lobs is not None else Ycov

        # Innovation covariance; inverted once and reused for both K and T
        C = Ycov_loc + self.R
        C_inv = np.linalg.inv(C)

        # State-observation cross covariance P H^T (n x m), localized element-wise
        XYcov = (Xp @ Yp.T) / (N - 1)
        XYcov_loc = XYcov * self.L if self.L is not None else XYcov

        # Kalman gain
        K = XYcov_loc @ C_inv

        # Observation innovation and mean update
        d = self.y - Ym
        Xm_updated = Xm + K @ d

        # Perturbation update.
        # NOTE(review): Xp @ T == Xp - (Xp Yp^T / (N-1)) C^{-1} Yp, i.e. the *unlocalized*
        # gain applied to Yp, with no square-root factor. Confirm this is the intended EnSRF form.
        T = np.eye(N) - (Yp.T @ C_inv) @ Yp / (N - 1)
        Xp_updated = Xp @ T

        # Combine updated mean and perturbations
        self.X_updated = Xm_updated + Xp_updated

        if debug:
            self.Xm = Xm
            self.Xp = Xp
            self.Ym = Ym
            self.Yp = Yp
            self.C = C
            self.K = K
            self.d = d
            self.T = T
142
+
143
class EnSRF_DASH:
    ''' Ensemble update using an adjusted gain for the perturbations (per the class name,
    presumably following the DASH toolbox).

    The mean is updated with the (optionally localized) gain K. The perturbations are updated
    with Ka = K (C^{1/2})^{-T} (C^{1/2} + R^{1/2})^{-1}, where C is the innovation covariance.
    After ``update()``, the posterior ensemble is stored in ``self.X_updated`` (n x N).
    '''
    def __init__(self, X=None, Y=None, y=None, R=None, L=None, Lobs=None):
        self.X = X  # ensemble of the prior state vectors (n x N)
        self.Y = Y  # ensemble of the forward estimates (m x N); Y=H(X)
        self.y = y  # observations (m x 1)
        self.R = R  # obs err matrix (m x m)
        self.L = L  # localization matrix (n x m)
        self.Lobs = Lobs  # localization matrix (m x m)

    def update(self, debug=False):
        ''' Perform an EnSRF update with localization.

        Args:
            debug (bool): if True, keep the intermediate matrices as attributes
                (Xm, Xp, Ym, Yp, C, K, d, Ka).
        '''
        N = self.X.shape[1]  # Ensemble size

        # Ensemble means and perturbations
        Xm = np.mean(self.X, axis=1, keepdims=True)
        Xp = self.X - Xm

        Ym = np.mean(self.Y, axis=1, keepdims=True)
        Yp = self.Y - Ym

        # Ensemble covariance of the forward estimates, H P H^T (m x m), localized element-wise
        Ycov = (Yp @ Yp.T) / (N - 1)
        Ycov_loc = self.Lobs * Ycov if self.Lobs is not None else Ycov

        # Innovation covariance
        C = Ycov_loc + self.R

        # State-observation cross covariance P H^T (n x m), localized element-wise
        XYcov = (Xp @ Yp.T) / (N - 1)
        XYcov_loc = self.L * XYcov if self.L is not None else XYcov

        # Kalman gain
        K = XYcov_loc @ np.linalg.inv(C)

        # Observation innovation and mean update
        d = self.y - Ym
        Xm_updated = Xm + K @ d

        # Adjusted gain for the perturbations (square roots of C and R, not of K).
        # NOTE(review): scipy's sqrtm may return a complex array if C or R is not numerically
        # positive semi-definite, which would make X_updated complex.
        Csqrt = sqrtm(C)
        Csqrt_inv_T = np.linalg.inv(Csqrt).T
        Rsqrt = sqrtm(self.R)
        Ka = K @ Csqrt_inv_T @ np.linalg.inv(Csqrt + Rsqrt)
        Xp_updated = Xp - Ka @ Yp

        # Combine updated mean and perturbations
        self.X_updated = Xm_updated + Xp_updated

        if debug:
            self.Xm = Xm
            self.Xp = Xp
            self.Ym = Ym
            self.Yp = Yp
            self.C = C
            self.K = K
            self.d = d
            self.Ka = Ka
209
+
210
+
211
class EnOI:
    ''' EnOI update: the gain is estimated from the ensemble ``X``/``Y`` and the resulting
    increment is added to a separate target state ``X_target``.

    After ``update()``, the result is stored in ``self.X_target_updated``.
    '''
    def __init__(self, X_target=None, X=None, Y=None, y=None, R=None, L=None, Lobs=None):
        self.X_target = X_target  # the **monthly** prior state vectors (n x 1)
        self.X = X  # ensemble of the prior state vectors (n x N)
        self.Y = Y  # ensemble of the forward estimates (m x N); Y=H(X)
        self.y = y  # observations (m x 1)
        self.R = R  # obs err matrix (m x m)
        self.L = L  # localization matrix (n x m), applied element-wise to the gain
        # optional obs-space localization (m x m), applied element-wise to H P H^T;
        # accepted for interface parity with EnSRF; None keeps the original EnOI behavior
        self.Lobs = Lobs

    def update(self, debug=False):
        ''' Perform an EnOI update with localization.

        Args:
            debug (bool): if True, keep the intermediate matrices as attributes
                (Xm, Xp, Ym, Yp, C, K, K_loc, d).

        Raises:
            ValueError: if ``X_target`` was not provided.
        '''
        if self.X_target is None:
            raise ValueError('EnOI.update() requires `X_target`, the state to which the increment is added.')

        N = self.X.shape[1]  # Ensemble size

        # Ensemble means and perturbations
        Xm = np.mean(self.X, axis=1, keepdims=True)
        Xp = self.X - Xm

        Ym = np.mean(self.Y, axis=1, keepdims=True)
        Yp = self.Y - Ym

        # Innovation covariance: H P H^T (optionally localized) + R
        Ycov = (Yp @ Yp.T) / (N - 1)
        if self.Lobs is not None:
            Ycov = Ycov * self.Lobs
        C = Ycov + self.R

        # Kalman gain
        K = (Xp @ Yp.T) / (N - 1) @ np.linalg.inv(C)

        # Localize the Kalman gain
        K_loc = K * self.L if self.L is not None else K

        # Observation innovation
        d = self.y - Ym

        # the increment
        inc = K_loc @ d

        # update
        self.X_target_updated = self.X_target + inc

        if debug:
            self.Xm = Xm
            self.Xp = Xp
            self.Ym = Ym
            self.Yp = Yp
            self.C = C
            self.K = K
            self.K_loc = K_loc
            self.d = d
261
+
262
+
263
class Solver:
    ''' Run one data-assimilation experiment: prepare Y = H(X) and localization, then update.

    Args:
        prior (Prior): prior ensemble; copied on construction.
        obs (Obs): observations; copied on construction.
        prior_target (Prior, optional): state updated by the 'EnOI' method; copied on construction.
    '''
    def __init__(self, prior=None, obs=None, prior_target=None):
        # Work on copies so prep()/run() never modify the caller's objects
        self.prior = prior.copy() if prior is not None else None
        self.obs = obs.copy() if obs is not None else None
        self.prior_target = prior_target.copy() if prior_target is not None else None

    def prep(self, localize=True, loc_radius=2500, dist_vsf=1, dlat=1, dlon=1, loc_method='dash',
             recon_season=None, startover=False, nearest_valid_radius=5, **fwd_kws):
        ''' Prepare Y=H(X) and the localization matrix for DA

        Args:
            localize (bool): compute ``self.L`` (state x obs) and ``self.Lobs`` (obs x obs);
                if False, both are set to None.
            loc_radius (float): radius passed to the localization function.
            dist_vsf (float, list of float): the vertical scaling factor of the distance
            dlat, dlon (float): resolution passed to ``Prior.regrid``.
            loc_method (str): 'dash' (``gaspari_cohn_dash``) or 'cpda' (``gaspari_cohn``).
            recon_season (list of int, optional): months passed to ``Prior.annualize``;
                defaults to 1-12.
            startover (bool): recompute Y (and the prior-obs distances) even if cached.
            nearest_valid_radius: forwarded to ``Prior.get_Y``.
            **fwd_kws: per-PSM forward-model kwargs forwarded to ``Prior.get_Y``.

        Note:
            The prior-obs distances are cached on the prior. Pass ``startover=True`` after
            changing ``dist_vsf``.

        Raises:
            ValueError: if ``localize`` is True and ``loc_method`` is unknown.
        '''
        loc_funcs = {
            'cpda': gaspari_cohn,
            'dash': gaspari_cohn_dash,
        }
        # Fail fast, before any expensive regridding / forward modeling
        if localize and loc_method not in loc_funcs:
            raise ValueError(f'Unknown loc_method: {loc_method!r}; expected one of {sorted(loc_funcs)}')

        if recon_season is None:
            recon_season = list(range(1, 13))

        if not hasattr(self.prior, 'ds_rgd'):
            utils.p_header(f'>>> Regridding the prior (dlat={dlat}, dlon={dlon})')
            self.prior.regrid(dlat=dlat, dlon=dlon)

        new_Y = startover or not hasattr(self.prior, 'Y')
        if new_Y:
            utils.p_header('>>> Proxy System Modeling: Y = H(X)')
            self.prior.get_Y(self.obs, nearest_valid_radius=nearest_valid_radius, **fwd_kws)

        if not hasattr(self.prior, 'ds_ann'):
            utils.p_header(f'>>> Annualizing prior w/ season: {recon_season}')
            self.prior.annualize(months=recon_season)

        if localize:
            utils.p_header('>>> Computing the localization matrix')
            loc_func = loc_funcs[loc_method]
            # Use the records actually kept by get_Y so L/Lobs line up with the rows of Y.
            obs_assim = getattr(self.prior, 'obs_assim', self.obs)
            # The (expensive) distances are cached, but must be refreshed when Y was recomputed.
            # L is recomputed on every call, so a second prep() no longer disables localization.
            if new_Y or not hasattr(self.prior, 'dist'):
                self.prior.get_dist(obs_assim, dist_vsf)
            self.L = loc_func(self.prior.dist, loc_radius)
            obs_assim.get_dist()
            self.Lobs = loc_func(obs_assim.dist, loc_radius)
        else:
            self.L = None
            self.Lobs = None

    def run(self, method='EnSRF', debug=False):
        ''' Run the DA update and store the solver in ``self.S`` and the posterior in ``self.post``.

        Args:
            method (str): 'EnSRF', 'EnSRF_DASH', or 'EnOI' ('EnOI' requires ``prior_target``).
            debug (bool): forwarded to the solver's ``update``.

        Raises:
            ValueError: for an unknown method, or 'EnOI' without ``prior_target``.
        '''
        algo = {
            'EnSRF': EnSRF,
            'EnSRF_DASH': EnSRF_DASH,
            'EnOI': EnOI,
        }
        if method not in algo:
            raise ValueError(f'Unknown DA method: {method!r}; expected one of {sorted(algo)}')
        if method == 'EnOI' and self.prior_target is None:
            raise ValueError('method="EnOI" requires `prior_target` to be set.')

        # y and R must come from the records kept by get_Y so they match the rows of Y
        obs_assim = getattr(self.prior, 'obs_assim', self.obs)
        kws = {
            'X': self.prior.X,
            'Y': self.prior.Y,
            'y': obs_assim.y,
            'R': obs_assim.R,
            'L': self.L,
        }
        if method == 'EnOI':
            # EnOI localizes the gain with L only; Lobs is not passed to it
            kws['X_target'] = self.prior_target.X
        else:
            kws['Lobs'] = self.Lobs

        self.S = algo[method](**kws)

        utils.p_header('>>> DA update')
        self.S.update(debug=debug)

        utils.p_header('>>> Formatting the posterior')
        if method == 'EnOI':
            self.post = utils.states2ds(self.S.X_target_updated, self.prior_target.ds_ann)
        else:
            self.post = utils.states2ds(self.S.X_updated, self.prior.ds_ann)
@@ -0,0 +1,80 @@
1
+ from copy import deepcopy
2
+ import numpy as np
3
+ import pandas as pd
4
+ import xarray as xr
5
+
6
+ from . import utils
7
class Obs:
    ''' Container for a table of observations (one row per record).

    Columns read by this class: 'pid', 'lat', 'lon', 'value' and 'R'.
    '''
    def __init__(self, df:pd.DataFrame):
        '''
        Args:
            df (pd.DataFrame): observation table. It is copied, so normalizing 'lon' to
                [0, 360) does not modify the caller's DataFrame.
        '''
        self.df = df.copy()
        self.df['lon'] = (self.df['lon'] + 360) % 360
        self.nobs = len(self.df)
        self.pids = self.df['pid'].values
        # One ProxyRecord per pid (for a duplicated pid, the first matching row is used)
        self.records = {pid: self[pid] for pid in self.pids}

    @property
    def y(self):
        ''' Observed values from the 'value' column as an (nobs, 1) array. '''
        return self.df['value'].values[..., np.newaxis]

    @property
    def y_locs(self):
        ''' (nobs, 2) array of [lat, lon] per observation. '''
        return self.df[['lat', 'lon']].values

    @property
    def R(self):
        ''' Diagonal (nobs, nobs) matrix built from the 'R' column. '''
        return np.diag(self.df['R'].values)

    def copy(self):
        return deepcopy(self)

    def __getitem__(self, pid:str):
        ''' Return a ProxyRecord for the first row whose 'pid' equals ``pid``.

        Raises:
            KeyError: if no row has this pid.
        '''
        rows = self.df[self.df['pid'] == pid]
        if rows.empty:
            raise KeyError(pid)
        return ProxyRecord(rows.iloc[0])

    def get_dist(self):
        ''' Pairwise great-circle distances between observations, via ``utils.gcd``.

        Stores and returns an (nobs, nobs) symmetric matrix in ``self.dist``.
        '''
        lats = self.df['lat'].values
        lons = self.df['lon'].values
        lat1, lat2 = np.meshgrid(lats, lats)
        lon1, lon2 = np.meshgrid(lons, lons)
        self.dist = utils.gcd(lat1, lon1, lat2, lon2)
        return self.dist
45
+
46
+
47
class ProxyRecord:
    ''' One observation record; its row of metadata is kept in ``self.data`` (a pd.Series). '''
    def __init__(self, data:pd.Series):
        '''
        Args:
            data (pd.Series): one row of an observation table. If present, 'time' and
                'value' are converted to numpy arrays, and 'seasonality' is converted to a
                list (``get_clim`` selects these values along ``month``).

        Raises:
            ValueError: if 'seasonality' is not a string, list, tuple or numpy array.
        '''
        self.data = data.copy()
        if 'time' in data: self.data['time'] = np.array(data['time'])
        if 'value' in data: self.data['value'] = np.array(data['value'])

        if 'seasonality' in data:
            seasonality = data['seasonality']
            if isinstance(seasonality, str):
                self.data['seasonality'] = utils.str2list(seasonality)
            elif isinstance(seasonality, (list, tuple, np.ndarray)):
                self.data['seasonality'] = list(seasonality)
            else:
                raise ValueError('Wrong seasonality type; should be a string, list, tuple, or numpy array.')

    def get_clim(self, clim_ds, vns:list=None, verbose=False):
        ''' Extract this record's climate from ``clim_ds`` into ``self.clim`` (an xr.Dataset).

        Each variable is sampled near (lat, lon) through the ``.x.nearest2d`` accessor
        (presumably registered by the ``x4c`` dependency). It is then averaged over the
        months in ``self.data.seasonality``.

        Args:
            clim_ds (xr.Dataset): climate fields with a ``month`` dimension.
            vns (list, optional): variables to extract; names missing from ``clim_ds`` are
                skipped. Defaults to all data variables.
            verbose (bool): print a message for each extracted variable.
        '''
        if vns is None:
            vns = clim_ds.data_vars
        else:
            vns = [vn for vn in vns if vn in clim_ds.data_vars]

        self.clim = xr.Dataset()
        for vn in vns:
            self.clim[vn] = clim_ds[vn].x.nearest2d(
                lat=self.data.lat,
                lon=self.data.lon,
                method='nearest',
            ).sel(month=self.data.seasonality).mean(dim='month')
            if verbose: utils.p_success(f'>>> ProxyRecord.clim["{vn}"] created')

        self.clim.attrs['seasonality'] = self.data.seasonality
80
+
@@ -0,0 +1,329 @@
1
+ import numpy as np
2
+ import xarray as xr
3
+ from tqdm import tqdm
4
+ from scipy.stats import norm
5
+ from tqdm import tqdm
6
+ from copy import deepcopy
7
+
8
+ from . import psm
9
+ from . import utils
10
+
11
class PriorMember:
    ''' A single prior dataset from which ensemble samples can be generated into ``self.samples``. '''
    def __init__(self, ds):
        '''
        Args:
            ds (xr.Dataset or xr.DataArray): prior fields; a DataArray is converted with
                ``to_dataset()``.
        '''
        if isinstance(ds, xr.DataArray): ds = ds.to_dataset()
        self.ds = ds
        self.vns = list(ds.data_vars)

    def gen_samples_Gaussian(self, local_sigma:dict, global_sigma:dict, nens:int=100, seed:int=2333):
        ''' Generate samples following Gaussian

        For each variable, member ``k`` is ``ds[vn] + g_k + e_k``, where:

        - ``g_k`` is a single N(0, global_sigma[vn]) draw shared by all grid points of member ``k``;
        - ``e_k`` is an independent N(0, local_sigma[vn]) draw at every grid point.

        The result is stored in ``self.samples`` with a trailing ``ens`` dimension. All
        variables draw from one random generator, in ``self.vns`` order.

        Args:
            local_sigma (dict): per-variable standard deviation of the pointwise noise.
            global_sigma (dict): per-variable standard deviation of the per-member offset.
            nens (int): Number of ensemble members to generate.
            seed (int): Seed for reproducibility.
        '''
        rng = np.random.default_rng(seed)
        self.samples = xr.Dataset()
        for vn in self.vns:
            mean = self.ds[vn].values
            samples_shape = (*mean.shape, nens)
            # shape (nens,): broadcast along the trailing ens axis
            global_perturbation = norm.rvs(loc=0, scale=global_sigma[vn], size=nens, random_state=rng)
            local_perturbation = norm.rvs(loc=0, scale=local_sigma[vn], size=samples_shape, random_state=rng)
            samples = mean[..., np.newaxis] + global_perturbation + local_perturbation
            samples_da = xr.DataArray(samples, dims=(*self.ds[vn].dims, 'ens'), coords=self.ds[vn].coords)
            samples_da.attrs = self.ds[vn].attrs
            self.samples[vn] = samples_da

    def gen_samples_bootstrap(self, nens:int=30, clim_yrs:int=50, seed:int=0, dim='time'):
        ''' Generate samples from the prior pool

        Each of the ``nens`` members is the mean of ``self.ds`` over ``clim_yrs`` indices
        drawn without replacement along ``dim``. The members are stored in ``self.samples``,
        with ``dim`` renamed to ``ens``.

        Args:
            nens (int): number of members to generate.
            clim_yrs (int): number of indices averaged per member; must not exceed
                ``len(self.ds[dim])`` (numpy raises ValueError otherwise).
            seed (int or None): base seed; member ``i`` (0-based) uses ``seed + i + 1``.
                None is treated as 0.
            dim (str): dimension to resample along.
        '''
        nt = len(self.ds[dim])
        pool_idx = list(range(nt))
        sample_list = []
        if seed is None: seed = 0
        for i in range(nens):
            seed += 1
            rng = np.random.default_rng(seed)
            sample_idx = rng.choice(pool_idx, size=clim_yrs, replace=False)
            sample = self.ds.isel({dim: sample_idx}).mean(dim)
            sample_list.append(sample)

        samples = xr.concat(sample_list, dim=dim)
        self.samples = xr.Dataset(samples).rename({dim: 'ens'})
55
+
56
+
57
class Prior:
    ''' Prior ensemble built by concatenating one or more ``PriorMember`` objects along ``ens``.

    Every variable in ``self.ds`` has ``ens`` as its last dimension.
    '''
    def __init__(self, members, lat_name='TLAT', lon_name='TLONG', depth_name='z_t'):
        '''
        Args:
            members (PriorMember or list of PriorMember): members to combine. A member with
                ``samples`` contributes those samples. Otherwise its ``ds`` is used, gaining a
                length-1 ``ens`` dimension if it has none.
            lat_name (str): name of the latitude coordinate.
            lon_name (str): name of the longitude coordinate.
            depth_name (str or None): name of the vertical coordinate. If None, ``nz`` is not
                set and ``get_dist`` uses no vertical scaling.
        '''
        if not isinstance(members, list): members = [members]
        ds_list = []
        for m in members:
            if hasattr(m, 'samples'):
                ds_list.append(m.samples)
            elif 'ens' not in m.ds.dims:
                ds_list.append(m.ds.expand_dims({'ens': 1}))
            else:
                ds_list.append(m.ds)

        self.ds = xr.concat(ds_list, dim='ens').transpose(..., 'ens')
        self.lat_name = lat_name
        self.lon_name = lon_name
        self.depth_name = depth_name
        if depth_name is not None:
            self.nz = len(self.ds[depth_name])

        # NOTE: these are the *shapes* (tuples) of the lat/lon coordinate arrays, not counts
        self.nlat, self.nlon = self.ds[lat_name].shape, self.ds[lon_name].shape
        self.nens = len(self.ds.ens)
        self.nvar = len(self.ds.data_vars)

    def regrid(self, dlat=1, dlon=1, verbose=False):
        ''' Regrid every variable to dlat x dlon and store the result in ``self.ds_rgd``.

        Uses the ``.x`` accessor (presumably registered by the ``x4c`` dependency).
        '''
        self.ds_rgd = xr.Dataset()
        for vn in tqdm(self.ds.data_vars, desc=f'Regridding variables to {dlat}x{dlon}'):
            self.ds_rgd[vn] = self.ds.x[vn].x.regrid(dlat=dlat, dlon=dlon)
            if verbose: utils.p_success(f'>>> Prior.ds_rgd["{vn}"] created')

    def annualize(self, months=None):
        ''' Average ``self.ds`` over the given ``month`` values (default 1-12) into ``self.ds_ann``. '''
        if months is None:
            months = list(range(1, 13))
        self.ds_ann = self.ds.sel(month=months).mean('month')

    def inflate(self, factor=2):
        ''' Scale the ensemble perturbations by ``factor``; the original is kept in ``self.ds_raw``. '''
        self.ds_raw = self.ds.copy()
        ens_mean = self.ds.mean('ens')
        ens_pert = self.ds - ens_mean
        self.ds = ens_mean + ens_pert * factor

    def copy(self):
        return deepcopy(self)

    @property
    def X(self):
        ''' State matrix (n x nens) from ``self.ds_ann``, variables stacked in ``data_vars`` order.

        Each variable is flattened with ``reshape(-1, nens)``, which relies on ``ens`` being its
        last dimension. Variables of different sizes are supported.
        '''
        blocks = [self.ds_ann[vn].values.reshape(-1, self.nens) for vn in self.ds_ann.data_vars]
        if not blocks:
            return np.empty((0, self.nens))
        return np.concatenate(blocks, axis=0)

    def get_Y(self, obs, nearest_valid_radius=5, **fwd_kws):
        ''' Run each record's proxy system model (PSM) on the prior: Y = H(X).

        Sets the following attributes:

        - ``self.obs_assim``: a copy of ``obs`` without records whose PSM output is None or NaN.
        - ``self.ds_proxy_locs``: the prior climate near each site.
        - ``self.Y``: one row per kept record, nens columns.

        For debugging, it also sets ``obs.records[pid].psm`` on the passed-in ``obs``.

        Args:
            obs (Obs): observations; ``obs.df`` needs 'psm_name', 'lat', 'lon' and 'pid'.
            nearest_valid_radius: passed as ``r`` to the ``.x.nearest2d`` accessor.
            **fwd_kws: per-PSM forward kwargs keyed by PSM class name, e.g. ``{'MgCa': {'age': ...}}``.

        Raises:
            ValueError: if the climate extracted for a record contains NaNs.
        '''
        self.obs_assim = obs.copy()
        pseudo_obs = np.empty((len(obs.df), self.nens))

        # Climate variables needed by any PSM in use and available in the regridded prior
        psm_names = set(obs.df['psm_name'])
        clim_vns = list({
            vn for psm_name in psm_names
            for vn in psm.__dict__[psm_name]().clim_vns
            if vn in self.ds_rgd.data_vars
        })

        lat_lon_pairs = xr.Dataset({
            'lat': (('sites',), obs.df['lat'].values),
            'lon': (('sites',), obs.df['lon'].values),
        })
        self.ds_proxy_locs = xr.Dataset()
        for vn in clim_vns:
            self.ds_proxy_locs[vn] = self.ds_rgd[vn].x.nearest2d(
                lat=lat_lon_pairs['lat'],
                lon=lat_lon_pairs['lon'],
                r=nearest_valid_radius,
                extra_dim='ens',
            ).transpose(..., 'ens')

        if 'sites' not in self.ds_proxy_locs.dims:
            self.ds_proxy_locs = self.ds_proxy_locs.expand_dims({'sites': [0]})

        for i, (pid, rec) in tqdm(enumerate(obs.records.items()), total=obs.nobs, desc='Looping over records'):
            psm_name = rec.data.psm_name
            mdl = psm.__dict__[psm_name](rec)
            mdl.record.clim = self.ds_proxy_locs.isel({'sites': i}).sel(month=rec.data.seasonality).mean(dim='month')
            for vn in clim_vns:
                if mdl.record.clim[vn].isnull().any():
                    raise ValueError(f'NaN values detected in input climate for forward modeling of: {pid}')

            obs.records[pid].psm = mdl  # for debugging purposes

            mdl.forward(**fwd_kws.get(psm_name, {}))
            failed = mdl.output is None or np.isnan(np.asarray(mdl.output, dtype=float)).any()
            if failed:
                utils.p_warning(f'>>> Dropping proxy: {pid}')
                # Drop from the accumulated obs_assim (not from obs.df) so that every failed
                # record stays dropped, not only the last one
                df = self.obs_assim.df
                self.obs_assim.df = df.drop(df.index[df['pid'] == pid])
                pseudo_obs[i] = np.nan
            else:
                pseudo_obs[i] = mdl.output

        self.obs_assim.nobs = len(self.obs_assim.df)
        pseudo_obs = pseudo_obs[~np.isnan(pseudo_obs).any(axis=1)]
        self.Y = pseudo_obs
        self.obs_assim.df['Ym'] = self.Y.mean(axis=1)

    def get_dist(self, obs, s=1):
        ''' Great-circle distances [km] between prior grid points and observations.

        Stores an (n_state, nobs) matrix in ``self.dist``. The (ngrid, nobs) block is:

        - if ``nz`` is set, repeated for each vertical level and multiplied by the per-level
          factor ``s`` (scalar or length-nz);
        - then tiled once per variable (``nvar`` times).

        This assumes every variable shares the same grid.
        '''
        # Extract grid latitudes and longitudes
        lat_grid = self.ds[self.lat_name].values
        lon_grid = self.ds[self.lon_name].values

        if lat_grid.ndim == 1 and lon_grid.ndim == 1:
            # If lat and lon are 1D, create a meshgrid
            lon_grid, lat_grid = np.meshgrid(lon_grid, lat_grid)

        lat_grid_flat = lat_grid.ravel()
        lon_grid_flat = lon_grid.ravel()
        lats_obs = obs.df['lat'].values
        lons_obs = obs.df['lon'].values

        # Broadcasting (ngrid, 1) against (1, nobs) yields the (ngrid, nobs) matrix directly,
        # identical to the repeat/tile + reshape construction without the repeated inputs
        dist0 = utils.gcd(
            lat_grid_flat[:, np.newaxis], lon_grid_flat[:, np.newaxis],
            lats_obs[np.newaxis, :], lons_obs[np.newaxis, :],
        )
        nobs = dist0.shape[1]

        if hasattr(self, 'nz'):
            # 3D localization: per-level scaling of the horizontal distance
            s = (np.ones(self.nz) * s).reshape(-1, 1, 1)
            dist1 = (dist0[np.newaxis, :, :] * s).reshape((-1, nobs))
            self.dist = np.tile(dist1, (self.nvar, 1))
        else:
            self.dist = np.tile(dist0, (self.nvar, 1))
@@ -0,0 +1,147 @@
1
+ import xarray as xr
2
+ import pybaywatch as pb
3
+ import numpy as np
4
+
5
+ from . import utils
6
+ from . import obs
7
+
8
class IdenticalSST:
    ''' Identity forward operator: the pseudo-observation is ``TEMP`` at ``z_t`` index 0. '''
    def __init__(self, record:obs.ProxyRecord=None):
        # the record whose ``clim`` is read by ``forward``
        self.record = record

    @property
    def clim_vns(self):
        ''' Climate variables this operator reads from ``record.clim``. '''
        return ['TEMP']

    def forward(self):
        ''' Set ``self.output`` to ``record.clim['TEMP']`` at the first ``z_t`` level. '''
        first_level = self.record.clim['TEMP'].isel(z_t=0)
        self.output = first_level.values
18
+
19
class IdenticalSSS:
    ''' Identity forward operator: the pseudo-observation is ``SALT`` at ``z_t`` index 0. '''
    def __init__(self, record:obs.ProxyRecord=None):
        # the record whose ``clim`` is read by ``forward``
        self.record = record

    @property
    def clim_vns(self):
        ''' Climate variables this operator reads from ``record.clim``. '''
        return ['SALT']

    def forward(self):
        ''' Set ``self.output`` to ``record.clim['SALT']`` at the first ``z_t`` level. '''
        first_level = self.record.clim['SALT'].isel(z_t=0)
        self.output = first_level.values
29
+
30
class IdenticalSSTSSS:
    ''' Test forward operator: ``TEMP + SALT``, both taken at ``z_t`` index 0. '''
    def __init__(self, record:obs.ProxyRecord=None):
        # the record whose ``clim`` is read by ``forward``
        self.record = record

    @property
    def clim_vns(self):
        ''' Climate variables this operator reads from ``record.clim``. '''
        return ['TEMP', 'SALT']

    def forward(self):
        ''' Set ``self.output`` to the sum of first-level TEMP and SALT values.

        NOTE(review): adding temperature and salinity is presumably a synthetic test operator.
        '''
        clim = self.record.clim
        temp0 = clim['TEMP'].isel(z_t=0).values
        salt0 = clim['SALT'].isel(z_t=0).values
        self.output = temp0 + salt0
40
+
41
class TEX86:
    ''' TEX86 proxy system model wrapping ``pybaywatch.TEX_forward``. '''
    def __init__(self, record:obs.ProxyRecord=None):
        self.record = record

    @property
    def clim_vns(self):
        return ['TEMP', 'tos', 'sst']

    def _get_sst(self):
        ''' Return SST from ``record.clim``: 'TEMP' at ``z_t`` index 0, else 'tos', else 'sst'.

        Raises:
            ValueError: if none of these variables is present.
        '''
        clim = self.record.clim
        if 'TEMP' in clim:
            return clim['TEMP'].isel(z_t=0).values
        for vn in ('tos', 'sst'):
            if vn in clim:
                return clim[vn].values
        raise ValueError(f'None of {self.clim_vns} found in the record climate; cannot run TEX86.forward().')

    def forward(self, seed=2333, mode='analog', type='SST', tolerance=1):
        ''' Run ``pb.TEX_forward``.

        Sets ``self.output`` to the median over axis 1 of the returned ``values``, or to
        ``None`` when pybaywatch reports ``status == 'FAIL'``. The inputs are kept in
        ``self.params``.

        Args:
            seed, mode, type, tolerance: passed through to ``pb.TEX_forward``.
        '''
        sst = self._get_sst()

        lat = self.record.data.lat
        lon = self.record.data.lon
        lon180 = utils.lon180(lon)

        # run
        self.params = {
            'lat': lat,
            'lon': lon180,
            'temp': sst,
            'seed': seed,
            'type': type,
            'mode': mode,
            'tolerance': tolerance,
        }
        res = pb.TEX_forward(**self.params)
        if res['status'] == 'FAIL':
            # the record id lives in record.data (this class has no `meta` attribute)
            utils.p_warning(f'>>> Forward modeling failed for proxy: {self.record.data.pid}')
            self.output = None
        else:
            self.output = np.median(res['values'], axis=1)
77
+
78
class UK37:
    ''' UK'37 proxy system model wrapping ``pybaywatch.UK_forward``. '''
    def __init__(self, record:obs.ProxyRecord=None):
        self.record = record

    @property
    def clim_vns(self):
        return ['TEMP', 'tos', 'sst']

    def _get_sst(self):
        ''' Return SST from ``record.clim``: 'TEMP' at ``z_t`` index 0, else 'tos', else 'sst'.

        Raises:
            ValueError: if none of these variables is present.
        '''
        clim = self.record.clim
        if 'TEMP' in clim:
            return clim['TEMP'].isel(z_t=0).values
        for vn in ('tos', 'sst'):
            if vn in clim:
                return clim[vn].values
        raise ValueError(f'None of {self.clim_vns} found in the record climate; cannot run UK37.forward().')

    def forward(self, order=3, seed=2333):
        ''' Run ``pb.UK_forward`` and set ``self.output`` to the median over axis 1 of ``values``.

        The climate is read from ``self.record.clim`` (this class has no ``clim`` of its own).
        The inputs are kept in ``self.params``.

        Args:
            order, seed: passed through to ``pb.UK_forward``.
        '''
        sst = self._get_sst()

        # run
        self.params = {
            'sst': sst,
            'order': order,
            'seed': seed,
        }
        res = pb.UK_forward(**self.params)
        self.output = np.median(res['values'], axis=1)
102
+
103
class MgCa:
    ''' Mg/Ca proxy system model wrapping ``pybaywatch.MgCa_forward``. '''
    def __init__(self, record:obs.ProxyRecord=None):
        self.record = record

    @property
    def clim_vns(self):
        return ['TEMP', 'tos', 'sst', 'SALT', 'sos', 'sss']

    def _get_sst_sss(self):
        ''' Return (sst, sss) from ``record.clim``.

        Pairs are tried in order: ('TEMP', 'SALT') at ``z_t`` index 0, then ('tos', 'sos'),
        then ('sst', 'sss').

        Raises:
            ValueError: if no complete pair is present.
        '''
        clim = self.record.clim
        if 'TEMP' in clim and 'SALT' in clim:
            return clim['TEMP'].isel(z_t=0).values, clim['SALT'].isel(z_t=0).values
        for t_vn, s_vn in (('tos', 'sos'), ('sst', 'sss')):
            if t_vn in clim and s_vn in clim:
                return clim[t_vn].values, clim[s_vn].values
        raise ValueError('No (TEMP, SALT), (tos, sos) or (sst, sss) pair found in the record climate; '
                         'cannot run MgCa.forward().')

    def forward(self, age, omega=None, pH=None, clean=None, species=None, sw=2, H=1, seed=2333):
        ''' Run ``pb.MgCa_forward`` and set ``self.output`` to the median over axis 1 of ``values``.

        Args:
            age: passed through to ``pb.MgCa_forward`` (required; supply it via the
                PSM kwargs when called from ``Prior.get_Y``).
            omega, pH: if both are None, they are looked up with ``pb.core.omgph`` from
                the record's lat/lon/depth.
            clean, species: default to ``record.data.clean`` / ``record.data.species``.
            sw, H, seed: passed through to ``pb.MgCa_forward``.
        '''
        sst, sss = self._get_sst_sss()

        # get omega and pH (the record's location is only needed for this lookup)
        if omega is None and pH is None:
            lat = self.record.data.lat
            lon = self.record.data.lon
            depth = self.record.data.depth
            lon180 = np.mod(lon + 180, 360) - 180
            omega, pH = pb.core.omgph(lat, lon180, depth)

        if clean is None: clean = self.record.data.clean
        if species is None: species = self.record.data.species

        # run
        self.params = {
            'age': age,
            'sst': sst,
            'salinity': sss,
            'pH': pH,
            'omega': omega,
            'species': species,
            'clean': clean,
            'sw': sw,
            'H': H,
            'seed': seed,
        }
        res = pb.MgCa_forward(**self.params)
        self.output = np.median(res['values'], axis=1)
@@ -0,0 +1,103 @@
1
+ import numpy as np
2
+ import xarray as xr
3
+ import colorama as ca
4
+
5
def p_header(text):
    ''' Print ``text`` in bright cyan, then reset the terminal style. '''
    styled = ''.join((ca.Fore.CYAN, ca.Style.BRIGHT, text, ca.Style.RESET_ALL))
    print(styled)
7
+
8
def p_hint(text):
    ''' Print ``text`` in bright light-black (grey), then reset the terminal style. '''
    styled = ''.join((ca.Fore.LIGHTBLACK_EX, ca.Style.BRIGHT, text, ca.Style.RESET_ALL))
    print(styled)
10
+
11
def p_success(text):
    ''' Print ``text`` in bright green, then reset the terminal style. '''
    styled = ''.join((ca.Fore.GREEN, ca.Style.BRIGHT, text, ca.Style.RESET_ALL))
    print(styled)
13
+
14
def p_fail(text):
    ''' Print ``text`` in bright red, then reset the terminal style. '''
    styled = ''.join((ca.Fore.RED, ca.Style.BRIGHT, text, ca.Style.RESET_ALL))
    print(styled)
16
+
17
def p_warning(text):
    ''' Print ``text`` in bright yellow, then reset the terminal style. '''
    styled = ''.join((ca.Fore.YELLOW, ca.Style.BRIGHT, text, ca.Style.RESET_ALL))
    print(styled)
19
+
20
+ def gcd(lat1, lon1, lat2, lon2, radius=6378.137):
21
+ ''' 2D Great Circle Distance [km]
22
+
23
+ Args:
24
+ radius (float): Earth radius
25
+ '''
26
+ # Convert degrees to radians
27
+ lon1, lat1, lon2, lat2 = map(np.radians, [lon1, lat1, lon2, lat2])
28
+ dlat, dlon = lat2 - lat1, lon2 - lon1
29
+ a = np.sin(dlat / 2)**2 + np.cos(lat1) * np.cos(lat2) * np.sin(dlon / 2)**2
30
+ c = 2 * np.arcsin(np.sqrt(a))
31
+ dist = radius * c
32
+ return dist
33
+
34
def states2ds(states, ds):
    ''' Unpack a stacked state array into a Dataset shaped like ``ds``.

    Variables are read from consecutive slices of ``states`` (along its first
    axis) in ``ds.data_vars`` order; each slice is reshaped to that
    variable's shape in ``ds``. Wherever ``ds`` holds NaN, the output is set
    to NaN as well. Dims, coords and attrs are taken from ``ds``. The input
    ``states`` array is not modified.

    Args:
        states (numpy.ndarray): stacked state values. For a variable with an
            ``'ens'`` dim, ``prod(shape[:-1])`` rows are consumed; this
            assumes ``'ens'`` is the variable's last dim and that ``states``
            carries the ensemble along its trailing axis (TODO confirm).
            For other variables ``prod(shape)`` rows are consumed, which only
            reshapes cleanly when each row holds a single value
            (NOTE(review): confirm the intended layout).
        ds (xarray.Dataset): template whose variables define the names,
            shapes, dims, coords, attrs and NaN masks of the output.

    Returns:
        xarray.Dataset: the unpacked variables.
    '''
    ds_out = xr.Dataset()
    start_loc = 0
    for vn in ds.data_vars:
        da = ds[vn]
        block_shape = da.shape[:-1] if 'ens' in da.dims else da.shape
        # int(): np.prod(()) is the float 1.0, which is not a valid slice bound
        end_loc = start_loc + int(np.prod(block_shape))

        # .copy(): reshaping a basic slice usually yields a view, so the NaN
        # masking below would otherwise write into the caller's ``states``
        data = states[start_loc:end_loc].reshape(da.shape).copy()
        data[np.isnan(da.values)] = np.nan

        ds_out[vn] = xr.DataArray(
            data,
            dims=da.dims,
            coords=da.coords,
            attrs=da.attrs,
        )
        start_loc = end_loc

    return ds_out
66
+
67
+
68
+ # def gcd_3d(loc1, loc2, radius=6371.0):
69
+ # ''' 3D Great Circle Distance [km]
70
+
71
+ # Args:
72
+ # loc1 (tuple): lat1 [degree], lon1 [degree], depth1 [km]
73
+ # loc2 (tuple): lat2 [degree], lon2 [degree], depth2 [km]
74
+ # radius (float): Earth radius
75
+ # '''
76
+ # lat1, lon1, depth1 = loc1
77
+ # lat2, lon2, depth2 = loc2
78
+
79
+ # # Convert degrees to radians
80
+ # lat1, lon1 = np.radians(lat1), np.radians(lon1)
81
+ # lat2, lon2 = np.radians(lat2), np.radians(lon2)
82
+
83
+ # # Calculate radial distances (Earth's radius minus depth)
84
+ # r1 = radius - depth1
85
+ # r2 = radius - depth2
86
+
87
+ # # Compute central angle component
88
+ # central_angle = np.sin(lat1) * np.sin(lat2) + np.cos(lat1) * np.cos(lat2) * np.cos(lon2 - lon1)
89
+
90
+ # # Compute the 3D distance
91
+ # distance_3d = np.sqrt(r1**2 + r2**2 - 2 * r1 * r2 * central_angle)
92
+
93
+ # return distance_3d
94
+
95
def str2list(s, sep=','):
    ''' Parse a ``sep``-separated string of integers into a list of ``int``.

    Whitespace around each item is ignored, and empty items (from an empty
    string or a trailing/doubled separator) are skipped.

    Args:
        s (str): text such as ``'1, 2, 3'``
        sep (str): item separator

    Returns:
        list[int]: the parsed integers, in order

    Raises:
        ValueError: if a non-empty item is not a valid integer literal.
    '''
    items = (tok.strip() for tok in s.split(sep))
    return [int(item) for item in items if item]
98
+
99
def lon360(lon180):
    ''' Wrap longitude(s) in degrees into [0, 360), up to floating-point rounding.

    Args:
        lon180: longitude(s) in degrees, scalar or array-like

    Returns:
        The equivalent longitude(s), as computed by ``np.remainder(..., 360)``.
    '''
    wrapped = np.remainder(lon180, 360)
    return wrapped
101
+
102
def lon180(lon360):
    ''' Wrap longitude(s) in degrees into [-180, 180); note that 180 maps to -180.

    Args:
        lon360: longitude(s) in degrees, scalar or array-like

    Returns:
        The equivalent longitude(s) in [-180, 180).
    '''
    # shift so the target interval starts at 0, wrap, then shift back
    shifted = lon360 + 180
    return np.remainder(shifted, 360) - 180
@@ -0,0 +1,25 @@
1
+ Metadata-Version: 2.1
2
+ Name: mlda
3
+ Version: 2024.11.22
4
+ Summary: mlda: A Python package for Machine Learning-based Data Assimilation
5
+ Home-page: https://github.com/fzhu2e/mlda
6
+ Author: Feng Zhu, Weimin Si
7
+ Author-email: fengzhu@ucar.edu, weimin_si@brown.edu
8
+ License: BSD-3
9
+ Keywords: Machine Learning,Data Assimilation
10
+ Classifier: Natural Language :: English
11
+ Classifier: Programming Language :: Python :: 3.12
12
+ Description-Content-Type: text/markdown
13
+ License-File: LICENSE
14
+ Requires-Dist: netCDF4
15
+ Requires-Dist: xarray
16
+ Requires-Dist: dask
17
+ Requires-Dist: nc-time-axis
18
+ Requires-Dist: colorama
19
+ Requires-Dist: tqdm
20
+ Requires-Dist: x4c-exp
21
+
22
+ # mlda: A Python package for Machine Learning-based Data Assimilation
23
+
24
+ `mlda` is a Python package for Machine Learning-based Data Assimilation (DA).
25
+ It aims to provide a universal framework and the corresponding utilities for conducting reproducible data assimilation experiments using novel machine learning-based DA methods.
@@ -0,0 +1,15 @@
1
+ LICENSE
2
+ README.md
3
+ setup.py
4
+ mlda/__init__.py
5
+ mlda/da.py
6
+ mlda/obs.py
7
+ mlda/prior.py
8
+ mlda/psm.py
9
+ mlda/utils.py
10
+ mlda.egg-info/PKG-INFO
11
+ mlda.egg-info/SOURCES.txt
12
+ mlda.egg-info/dependency_links.txt
13
+ mlda.egg-info/not-zip-safe
14
+ mlda.egg-info/requires.txt
15
+ mlda.egg-info/top_level.txt
@@ -0,0 +1 @@
1
+
@@ -0,0 +1,7 @@
1
+ netCDF4
2
+ xarray
3
+ dask
4
+ nc-time-axis
5
+ colorama
6
+ tqdm
7
+ x4c-exp
@@ -0,0 +1 @@
1
+ mlda
@@ -0,0 +1,4 @@
1
+ [egg_info]
2
+ tag_build =
3
+ tag_date = 0
4
+
@@ -0,0 +1,33 @@
1
+ from setuptools import setup, find_packages
2
+
3
# Read the README as UTF-8 explicitly so the build does not depend on the
# platform's default locale encoding.
with open('README.md', 'r', encoding='utf-8') as fh:
    long_description = fh.read()

setup(
    name='mlda',  # required
    version='2024.11.22',
    description='mlda: A Python package for Machine Learning-based Data Assimilation',
    long_description=long_description,
    long_description_content_type='text/markdown',
    author='Feng Zhu, Weimin Si',
    author_email='fengzhu@ucar.edu, weimin_si@brown.edu',
    url='https://github.com/fzhu2e/mlda',
    packages=find_packages(),
    include_package_data=True,
    license='BSD-3-Clause',  # SPDX identifier matching the LICENSE file
    zip_safe=False,
    keywords=['Machine Learning', 'Data Assimilation'],
    classifiers=[
        'Natural Language :: English',
        'Programming Language :: Python :: 3.12',
    ],
    install_requires=[
        'netCDF4',
        'xarray',
        'dask',
        'nc-time-axis',
        'colorama',
        'tqdm',
        'x4c-exp',
    ],
)