pythonCRO 0.1.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
pyCRO/fitting.py ADDED
@@ -0,0 +1,259 @@
1
+ import os
2
+ import sys
3
+ import numpy as np
4
+ import pandas as pd
5
+
6
+ from .fit_LR import fit_LR
7
+ from .fit_MLE import fit_MLE
8
+
9
def func_default_fitting_method(par_option_T, par_option_h, par_option_noise, table_path='table_default_fitting_method.txt', verbose=True):
    """
    Determine the default fitting method for the CRO (Community Recharge Oscillator)
    model based on prescribed parameter and noise options.

    The function reads a lookup table (`table_default_fitting_method.txt`) that maps
    combinations of seasonality, linearity, noise color, and noise amplitude type
    to a recommended fitting method.

    Parameters
    ----------
    par_option_T : sequence of int
        Options for prescribed terms in the SST equation. Entries 2, 3 and 4
        are inspected to decide linearity, so at least five entries are needed.
    par_option_h : sequence of int
        Options for prescribed terms in the thermocline equation. Entry 2 is
        inspected to decide linearity, so at least three entries are needed.
    par_option_noise : sequence of int
        Noise configuration array: [noise_color_T, noise_color_h, noise_amp_type].
        Only entries 0 (0 = red, 1 = white) and 2 (0 = multiplicative,
        1 = multiplicative-H, 2 = additive) are read here.
    table_path : str, optional
        Path to the whitespace-delimited lookup table text file (default:
        'table_default_fitting_method.txt'). Relative paths are resolved against
        the directory of this module; an absolute path is used as given.
    verbose : bool, optional
        If True, print the method selected from the table. The "no match"
        warning is printed regardless of this flag.

    Returns
    -------
    str
        Recommended fitting method, e.g., 'LR-F', 'LR-C', 'LR-F-MAC', or 'MLE'.
        Defaults to 'LR-F' if no matching entry is found.

    Raises
    ------
    ValueError
        If the noise color or noise amplitude code is not one of the values
        listed above.

    Notes
    -----
    - The model is "seasonal" when any prescribed option exceeds 1, else "constant".
    - The dynamics are "linear" when par_option_T[2:5] and par_option_h[2] are all 0.
    - The table must provide the columns `seasonal_type`, `det_type`,
      `noise_color_type`, `noise_amp_type` and `fitting_method`.
    """

    # Get path relative to the script file (os.path.join keeps an absolute
    # table_path untouched because it discards the preceding components).
    script_dir = os.path.dirname(os.path.abspath(__file__))
    table_path = os.path.join(script_dir, table_path)

    # Determine seasonality; unpacking avoids relying on list "+" so tuples
    # and arrays are accepted as well.
    all_options = (*par_option_T, *par_option_h)
    seasonal_type = "seasonal" if any(p > 1 for p in all_options) else "constant"

    # Determine linearity
    nonlinear_options = (par_option_T[2], par_option_T[3], par_option_T[4], par_option_h[2])
    det_type = "linear" if all(opt == 0 for opt in nonlinear_options) else "nonlinear"

    # Noise settings: reject unknown codes explicitly instead of failing later
    # with an UnboundLocalError.
    color_code = par_option_noise[0]
    if color_code == 0:
        noise_color_type = "red"
    elif color_code == 1:
        noise_color_type = "white"
    else:
        raise ValueError(
            f"Unknown noise color option {color_code!r}; expected 0 (red) or 1 (white)."
        )

    amp_code = par_option_noise[2]
    if amp_code == 0:
        noise_amp_type = "multiplicative"
    elif amp_code == 1:
        noise_amp_type = "multiplicative-H"
    elif amp_code == 2:
        noise_amp_type = "additive"
    else:
        raise ValueError(
            f"Unknown noise amplitude option {amp_code!r}; expected 0 (multiplicative), "
            "1 (multiplicative-H) or 2 (additive)."
        )

    # Load the method lookup table (columns separated by any whitespace)
    df = pd.read_csv(table_path, sep=r'\s+')

    # Filter for matching row
    matched = df[
        (df["seasonal_type"] == seasonal_type) &
        (df["det_type"] == det_type) &
        (df["noise_color_type"].astype(str) == noise_color_type) &
        (df["noise_amp_type"].astype(str) == noise_amp_type)
    ]

    if matched.empty:
        print("Warning: No matching fitting method found. Defaulting to 'LR-F'.")
        return "LR-F"

    # The first matching row wins if the table contains duplicates.
    method = matched["fitting_method"].iloc[0]
    if verbose:
        print("Referring to table_default_fitting_method.txt and using "+method)
    return method
87
+
88
+ ########################################################################################################
89
+
90
def RO_fitting(T, h, par_option_T, par_option_h, par_option_noise, method_fitting=None, dt=None, verbose=True):
    """
    Fit the Recharge Oscillator (RO) model parameters to T and h time series.

    This function performs parameter fitting for the RO system, using the specified
    prescribed parameter options and noise characteristics. Fitting is delegated
    to linear regression (`fit_LR`) or maximum likelihood estimation (`fit_MLE`).
    When no method is given, a default is selected with
    `func_default_fitting_method`.

    Parameters
    ----------
    T : ndarray
        Time series of SST anomalies (1D array).
    h : ndarray
        Time series of thermocline anomalies (1D array).
    par_option_T : dict
        Prescribed options for SST equation parameters. Only the values are
        used, in insertion order, so the ordering of the keys matters.
    par_option_h : dict
        Prescribed options for thermocline equation parameters. Only the values
        are used, in insertion order.
    par_option_noise : dict
        Noise settings with keys 'T', 'h' and 'T_type', e.g.,
        {'T': 'red', 'h': 'red', 'T_type': 'multiplicative'}. 'T' and 'h' must
        be equal. Recognised strings are 'white', 'red', 'additive',
        'multiplicative' and 'multiplicative-H'.
    method_fitting : str, optional
        Fitting method to use: 'LR-F', 'LR-C', 'LR-F-MAC', or 'MLE'.
        If None, the method is determined automatically.
    dt : float, optional
        Time step of the input time series. Default is 1.0 (months).
    verbose : bool, optional
        If True, prints fitting information and progress.

    Returns
    -------
    dict
        Dictionary of fitted CRO parameters, each value a plain list:
        Keys include:
            'R', 'F1', 'F2', 'epsilon', 'b_T', 'c_T', 'd_T', 'b_h',
            'sigma_T', 'sigma_h', 'B', 'm_T', 'm_h', 'n_T', 'n_h', 'n_g'.

    Raises
    ------
    ValueError
        If the noise options for T and h differ, or if `method_fitting` is not
        one of the supported methods.
    KeyError
        If a noise option string is not one of the recognised values.

    Notes
    -----
    - Applies the Ito-to-Stratonovich correction to the mean of R for
      multiplicative noise when the method is 'LR-F', 'LR-F-MAC' or 'MLE'.
    - Negligible seasonal amplitudes (< 1e-4) are rounded to zero.
    - Entries that are exactly zero are dropped from the first 13 parameters
      before they are returned.
    - Any preprocessing of T and h (e.g. mean removal) is not done here;
      presumably it happens inside fit_LR / fit_MLE -- TODO confirm.
    - Prints a summary of the fitting process if `verbose` is True.

    Examples
    --------
    >>> par_option_noise = {'T': 'red', 'h': 'red', 'T_type': 'multiplicative'}
    >>> fitted_params = RO_fitting(T, h, par_option_T, par_option_h, par_option_noise)
    >>> print(fitted_params['sigma_T'])
    """

    fitting_option_red = "ARn"  # "LR" or "AR1" or "ARn"

    ### Checking fitting setups ###
    if verbose:
        print("---------------------------------------------------------------------------------")
        print("Welcome to CRO Fitting! Your fitting setups:")
        print("---------------------------------------------------------------------------------")

    if dt is None:
        dt = 1.0
        if verbose:
            print(" - Data time step is not given, defaulting to: dt = 1.0 months.")
    if verbose:
        print(f" - Time series length: N = len(T)*dt = {len(T)*dt} months.")

        print(f" - Prescribed terms: {par_option_T}. \n"
              f" {par_option_h}. \n"
              " 0 - Do not prescribe. \n"
              " 1 - Prescribe only the annual mean. \n"
              " 3 - Prescribe the annual mean and annual seasonality. \n"
              " 5 - Prescribe the annual mean, annual seasonality, and semi-annual seasonality.")

    if par_option_noise['T'] != par_option_noise['h']:
        raise ValueError(f"par_option_noise['T'] = {par_option_noise['T']}, "
                         f"par_option_noise['h'] = {par_option_noise['h']}\n"
                         "Fitting methods for T and h equations should be the same.")

    if verbose:
        print(f" - Noise options: {par_option_noise}.")
        if (par_option_noise['T'] == 'red') and (par_option_noise['h'] == 'red'):
            print(f" - Fitting method for T and h red noises: {fitting_option_red}.\n"
                  " This option is defined internally within fitting.py.\n"
                  " Options available are: LR or AR1 or ARn.")

    # Convert parameter options to the positional/integer form used by the
    # fitting routines.
    par_option_T_vals = list(par_option_T.values())
    par_option_h_vals = list(par_option_h.values())
    noise_keys = ['T', 'h', 'T_type']

    noise_map = {
        'white': 1,
        'red': 0,
        'additive': 2,
        'multiplicative': 0,
        'multiplicative-H': 1
    }
    par_option_noise_array = [noise_map[str(par_option_noise[k])] for k in noise_keys]

    # Resolve the method first so the summary line reports the method that is
    # actually used rather than "None".
    if method_fitting is None:
        method_fitting = func_default_fitting_method(par_option_T_vals, par_option_h_vals, par_option_noise_array, verbose=verbose)

    if verbose:
        print(f" - Fitting method for T and h main equations: {method_fitting}.")

    # Perform fitting. The two trailing string arguments of fit_LR are passed
    # through exactly as each method name dictates.
    lr_args = {
        "LR-F": ("F", "LR"),
        "LR-C": ("C", "LR"),
        "LR-F-MAC": ("F", "MAC"),
    }
    if method_fitting in lr_args:
        first_flag, second_flag = lr_args[method_fitting]
        par = fit_LR(T, h, par_option_T_vals, par_option_h_vals,
                     par_option_noise_array, dt, first_flag, second_flag, fitting_option_red)
    elif method_fitting == "MLE":
        par = fit_MLE(T, h, par_option_T_vals, par_option_h_vals,
                      par_option_noise_array, dt)
    else:
        raise ValueError(f"Unknown fitting method: {method_fitting}")

    # Row layout of `par`, as implied by the key mapping at the end of this function:
    # 0 R, 1 F1, 2 b_T, 3 c_T, 4 d_T, 5 F2, 6 epsilon, 7 b_h,
    # 8 sigma_T, 9 sigma_h, 10 B, 11 m_T, 12 m_h, 13 n_T, 14 n_h, 15 n_g.
    if method_fitting in ("LR-F", "LR-F-MAC", "MLE"):
        n_g = par[15][0]
        if n_g == 0:  # Ito to Stratonovich Conversion for multiplicative noise
            par[0][0] = par[0][0] - 0.5 * (par[8][0] * par[10][0])**2
        elif n_g == 1:  # Ito to Stratonovich Conversion for Heaviside multiplicative noise
            par[0][0] = par[0][0] - 0.25 * (par[8][0] * par[10][0])**2
        # n_g == 2 (additive noise): no correction needed.

    ### Round negligible seasonal amplitudes of linear/nonlinear parameters to zero ###
    # Columns 1/2 and 3/4 are zeroed in pairs; presumably (amplitude, phase) of
    # the annual and semi-annual harmonics -- TODO confirm against fit_LR.
    for i in range(0, 8):
        if par[i][1] < 1e-4:
            par[i][1] = 0.0
            par[i][2] = 0.0
        if par[i][3] < 1e-4:
            par[i][3] = 0.0
            par[i][4] = 0.0

    # Filter out zero values (boolean-mask indexing: rows are expected to be arrays)
    Arr = [par[i][par[i] != 0] for i in range(0, 13)]

    # Organize into dictionary
    param_keys = [
        "R", "F1", "F2", "epsilon", "b_T", "c_T", "d_T", "b_h",
        "sigma_T", "sigma_h", "B", "m_T", "m_h", "n_T", "n_h", "n_g"
    ]
    param_values = [
        Arr[0], Arr[1], Arr[5], Arr[6], Arr[2], Arr[3], Arr[4], Arr[7],
        Arr[8], Arr[9], Arr[10], Arr[11], Arr[12],
        [int(par[13][0])], [int(par[14][0])], [int(par[15][0])]
    ]

    par = dict(zip(param_keys, param_values))
    par = {k: v.tolist() if isinstance(v, np.ndarray) else v for k, v in par.items()}

    ### Final print ###
    if verbose:
        print("---------------------------------------------------------------------------------")
        print("All steps are successfully completed!")
        print("---------------------------------------------------------------------------------")

    return par
pyCRO/par_load.py ADDED
@@ -0,0 +1,185 @@
1
+ import re
2
+ import os
3
+ import sys
4
+
5
+ import numpy as np
6
+
7
# Directory holding this module; the bundled parameter library is addressed
# relative to it.
script_dir = os.path.split(os.path.abspath(__file__))[0]
_MAT_FILENAME = os.path.join(script_dir, "../data/CRO_parlib_v0.0.mat")
9
+
10
+ def _try_load_mat(fname):
11
+ """Load .mat (v7 via scipy; fallback to v7.3 via mat73)."""
12
+ try:
13
+ from scipy.io import loadmat
14
+ mat = loadmat(fname, squeeze_me=True, struct_as_record=False)
15
+ # strip MATLAB metadata keys
16
+ return {k: v for k, v in mat.items() if not k.startswith("__")}
17
+ except Exception:
18
+ # v7.3 (HDF5) fallback if mat73 is available
19
+ try:
20
+ import mat73
21
+ return mat73.loadmat(fname)
22
+ except Exception as e:
23
+ raise RuntimeError(f"Failed to load {fname} with scipy and mat73") from e
24
+
25
+ def _to_str_array(x):
26
+ """Normalize MATLAB string/cellstr/char arrays -> numpy array of Python str (same shape)."""
27
+ x = np.asarray(x, dtype=object)
28
+ out = np.empty(x.shape, dtype=object)
29
+ it = np.nditer(out, flags=['multi_index', 'refs_ok'], op_flags=['writeonly'])
30
+ while not it.finished:
31
+ v = x[it.multi_index]
32
+ if isinstance(v, str):
33
+ s = v
34
+ elif isinstance(v, bytes):
35
+ s = v.decode("utf-8", errors="ignore")
36
+ elif isinstance(v, np.ndarray) and v.dtype.kind in ("U", "S"):
37
+ # MATLAB char array -> join characters
38
+ s = "".join(map(str, v.tolist()))
39
+ else:
40
+ s = str(v)
41
+ it[0] = s
42
+ it.iternext()
43
+ return out.astype(str)
44
+
45
+ def _as_col_cell(obj):
46
+ """
47
+ Ensure a (16,1) numpy object array from a MATLAB 16x1 cell stored inside S['par'] element.
48
+ The element could already be object array (16,), (16,1), list of 16, etc.
49
+ """
50
+ arr = obj
51
+ # Convert lists/tuples to np.object array
52
+ if isinstance(arr, (list, tuple)):
53
+ arr = np.array(arr, dtype=object)
54
+ if isinstance(arr, np.ndarray):
55
+ # flatten then reshape to (16,1)
56
+ arr = arr.astype(object)
57
+ arr = arr.reshape(-1, order="F") # MATLAB-friendly flatten
58
+ if arr.size != 16:
59
+ raise ValueError(f"Expected 16 elements, got {arr.size}")
60
+ return arr.reshape(16, 1, order="F")
61
+ # Anything else: treat as scalar and fail
62
+ raise TypeError("Unexpected parameter cell content type")
63
+
64
# Row index of each supported RO configuration in the `par` table.
_RO_NAME_INDEX = {
    "Linear-White-Additive": 0,
    "Seasonal-Linear-White-Additive": 1,
    "Nonlinear-White-Additive": 2,
    "Seasonal-Nonlinear-White-Additive": 3,
    "Linear-White-Multiplicative": 4,
    "Seasonal-Linear-White-Multiplicative": 5,
    "Nonlinear-White-Multiplicative": 6,
    "Seasonal-Nonlinear-White-Multiplicative": 7,
}

# Column index of each data set: column 0 is ORAS5, columns 1..48 are the
# individual CMIP6 historical entries.
_CMIP6_INDICES = range(1, 49)
_DATA_NAME_INDEX = {"ORAS5": 0}
_DATA_NAME_INDEX.update({f"CMIP6-historical-{i}": i for i in _CMIP6_INDICES})

# Order in which the 16 parameters are stored in one table cell.
_PAR_KEYS = (
    'R', 'F1', 'F2', 'epsilon', 'b_T', 'c_T', 'd_T', 'b_h',
    'sigma_T', 'sigma_h', 'B', 'm_T', 'm_h', 'n_T', 'n_h', 'n_g',
)


def _par_value_to_list(x):
    """Convert one stored parameter (array or scalar) to a plain Python list."""
    if isinstance(x, np.ndarray):
        if x.size == 0:  # empty placeholder
            return []
        if x.ndim == 0:  # scalar-like
            return [x.item()]
        return x.ravel().tolist()
    # numpy or python scalar
    try:
        return [x.item()]
    except AttributeError:
        return [x]


def _par_cell_to_dict(cell):
    """Map the 16 entries of one table cell onto their parameter names."""
    values = list(cell)
    if len(values) != len(_PAR_KEYS):
        raise ValueError(
            f"Expected {len(_PAR_KEYS)} parameters in a table cell, got {len(values)}"
        )
    return {key: _par_value_to_list(value) for key, value in zip(_PAR_KEYS, values)}


def par_load(data_name: str, ro_name: str):
    """
    Load CRO parameters by (data_name, ro_name) from CRO_parlib_v0.0.mat.

    Parameters
    ----------
    data_name : str
        'ORAS5', 'CMIP6-historical-<i>' with i in 1..48, or
        'CMIP6-historical-all' to load all 48 CMIP6 entries.
    ro_name : str
        One of the eight RO configurations, e.g. 'Linear-White-Additive' or
        'Seasonal-Nonlinear-White-Multiplicative' (exact, case-sensitive match).

    Returns
    -------
    dict or list of dict
        For a single data set, a dict mapping each of the 16 parameter names
        ('R', 'F1', 'F2', 'epsilon', 'b_T', 'c_T', 'd_T', 'b_h', 'sigma_T',
        'sigma_h', 'B', 'm_T', 'm_h', 'n_T', 'n_h', 'n_g') to a list of values.
        For 'CMIP6-historical-all', a list of 48 such dicts ordered by index.

    Raises
    ------
    ValueError
        If `ro_name` or `data_name` is not recognised. Both are validated
        before the parameter library is read.
    """
    data_name = str(data_name)
    ro_name = str(ro_name)

    try:
        ro_name_index = _RO_NAME_INDEX[ro_name]
    except KeyError:
        raise ValueError(f"Wrong input for `ro_name`: {ro_name!r}") from None

    load_all = data_name == "CMIP6-historical-all"
    if not load_all:
        try:
            data_name_index = _DATA_NAME_INDEX[data_name]
        except KeyError:
            raise ValueError(f"Invalid input for `data_name`: {data_name!r}") from None

    S = _try_load_mat(_MAT_FILENAME)
    par = S['par']  # indexed as par[ro_name_index, data_name_index]

    if load_all:
        return [_par_cell_to_dict(par[ro_name_index, i]) for i in _CMIP6_INDICES]
    return _par_cell_to_dict(par[ro_name_index, data_name_index])