python-gls 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,190 @@
1
+ """Spatial correlation structures.
2
+
3
+ Implements isotropic spatial correlation functions parameterized by
4
+ range and optional nugget effect.
5
+ """
6
+
7
+ import numpy as np
8
+ from numpy.typing import NDArray
9
+
10
+ from python_gls.correlation.base import CorStruct
11
+
12
+
13
class _SpatialCorStruct(CorStruct):
    """Base class for isotropic spatial correlation structures.

    All spatial structures share:
    - A range parameter controlling decay
    - An optional nugget parameter (discontinuity at distance 0)
    - Distance matrices stored per group

    Parameters
    ----------
    range_param : float, optional
        Initial range parameter (> 0).
    nugget : bool
        Whether to include a nugget effect.
    """

    def __init__(self, range_param: float | None = None, nugget: bool = False):
        super().__init__()
        if range_param is not None:
            # bool is a subclass of int, so exclude it explicitly: a
            # True/False range would otherwise be silently accepted as
            # 1.0/0.0, defeating this validation's purpose.
            if isinstance(range_param, bool) or not isinstance(
                range_param, (int, float)
            ):
                raise TypeError(
                    f"range_param must be a number, got {type(range_param).__name__}"
                )
            if range_param <= 0:
                raise ValueError(
                    f"range_param must be positive, got {range_param}"
                )
        if not isinstance(nugget, bool):
            raise TypeError(f"nugget must be a boolean, got {type(nugget).__name__}")
        self._nugget = nugget
        # Distance matrices keyed by group id; populated via
        # set_distances() or set_coordinates().
        self._distances: dict[int, NDArray] = {}
        if range_param is not None:
            if nugget:
                # Second slot is the nugget in unconstrained form; 0.0
                # maps to 0.5 through the sigmoid in get_correlation_matrix.
                self._params = np.array([float(range_param), 0.0])
            else:
                self._params = np.array([float(range_param)])

    @property
    def n_params(self) -> int:
        # One range parameter, plus one nugget parameter when enabled.
        return 2 if self._nugget else 1

    def set_distances(self, group_id: int, dist_matrix: NDArray) -> None:
        """Set the distance matrix for a group."""
        self._distances[group_id] = np.asarray(dist_matrix, dtype=float)

    def set_coordinates(self, group_id: int, coords: NDArray) -> None:
        """Set coordinates for a group; distances computed automatically."""
        coords = np.asarray(coords, dtype=float)
        if coords.ndim == 1:
            # Treat a flat vector as 1-D coordinates (one column).
            coords = coords[:, None]
        from scipy.spatial.distance import cdist
        self._distances[group_id] = cdist(coords, coords)

    def _correlation_function(self, d: NDArray, range_param: float) -> NDArray:
        """Compute correlation from distances. Override in subclasses."""
        raise NotImplementedError

    def get_correlation_matrix(self, group_size: int, **kwargs) -> NDArray:
        """Return the (group_size, group_size) correlation matrix.

        Uses the stored distance matrix for ``kwargs["group_id"]`` when
        available, otherwise assumes unit-spaced observations.
        NOTE(review): a stored distance matrix is used as-is and assumed
        to match group_size — confirm callers keep the two in sync.
        """
        if self._params is None:
            # No parameters estimated yet: uncorrelated.
            return np.eye(group_size)

        range_param = self._params[0]
        group_id = kwargs.get("group_id", None)

        if group_id is not None and group_id in self._distances:
            dist = self._distances[group_id]
        else:
            # Default: unit-spaced observations on a line.
            idx = np.arange(group_size, dtype=float)
            dist = np.abs(idx[:, None] - idx[None, :])

        R = self._correlation_function(dist, range_param)

        if self._nugget and len(self._params) > 1:
            # Map the unconstrained nugget parameter into (0, 1), shrink
            # the off-diagonal correlations, and restore an exact unit
            # diagonal.
            nug = 1 / (1 + np.exp(-self._params[1]))  # sigmoid to (0, 1)
            R = (1 - nug) * R
            np.fill_diagonal(R, 1.0)

        return R

    def _get_init_params(self, residuals_by_group: list[NDArray]) -> NDArray:
        """Initial parameters: range set to the median pairwise distance."""
        if self._distances:
            all_dists = []
            for d in self._distances.values():
                # Strict upper triangle: each unordered pair counted once.
                mask = np.triu(np.ones(d.shape, dtype=bool), k=1)
                all_dists.extend(d[mask].tolist())
            if all_dists:
                range_init = np.median(all_dists)
            else:
                range_init = 1.0
        else:
            range_init = 1.0
        if self._nugget:
            return np.array([range_init, 0.0])
        return np.array([range_init])

    def _params_to_unconstrained(self, params: NDArray) -> NDArray:
        """Map natural parameters to the unconstrained optimizer space."""
        u = np.zeros_like(params)
        # range > 0: log transform, clamped away from log(0).
        u[0] = np.log(max(params[0], 1e-10))
        if self._nugget and len(params) > 1:
            u[1] = params[1]  # already unconstrained (sigmoid applied in get_corr)
        return u

    def _unconstrained_to_params(self, uparams: NDArray) -> NDArray:
        """Inverse of `_params_to_unconstrained`."""
        p = np.zeros_like(uparams)
        p[0] = np.exp(uparams[0])
        if self._nugget and len(uparams) > 1:
            p[1] = uparams[1]
        return p
124
+
125
+
126
class CorExp(_SpatialCorStruct):
    """Exponential spatial correlation.

    Correlation decays exponentially with distance:
    R(d) = exp(-d / range).

    Equivalent to R's `corExp()`.
    """

    def _correlation_function(self, d: NDArray, range_param: float) -> NDArray:
        # Exponential decay in the range-scaled distance.
        scaled = d / range_param
        return np.exp(-scaled)
136
+
137
+
138
class CorGaus(_SpatialCorStruct):
    """Gaussian spatial correlation.

    Correlation decays with the squared range-scaled distance:
    R(d) = exp(-(d/range)^2).

    Equivalent to R's `corGaus()`.
    """

    def _correlation_function(self, d: NDArray, range_param: float) -> NDArray:
        scaled = d / range_param
        return np.exp(-scaled ** 2)
148
+
149
+
150
class CorLin(_SpatialCorStruct):
    """Linear spatial correlation.

    Correlation falls off linearly and is truncated at zero:
    R(d) = max(1 - d/range, 0).

    Equivalent to R's `corLin()`.
    """

    def _correlation_function(self, d: NDArray, range_param: float) -> NDArray:
        scaled = d / range_param
        return np.maximum(1 - scaled, 0)
160
+
161
+
162
class CorRatio(_SpatialCorStruct):
    """Rational quadratic spatial correlation.

    R(d) = 1 / (1 + (d/range)^2).

    Equivalent to R's `corRatio()`.
    """

    def _correlation_function(self, d: NDArray, range_param: float) -> NDArray:
        scaled_sq = (d / range_param) ** 2
        return 1.0 / (1.0 + scaled_sq)
172
+
173
+
174
class CorSpher(_SpatialCorStruct):
    """Spherical spatial correlation.

    R(d) = 1 - 1.5*(d/range) + 0.5*(d/range)^3 for d < range,
    and R(d) = 0 once the distance reaches the range.

    Equivalent to R's `corSpher()`.
    """

    def _correlation_function(self, d: NDArray, range_param: float) -> NDArray:
        x = d / range_param
        # Cubic polynomial inside the range, exactly zero beyond it.
        inside = 1 - 1.5 * x + 0.5 * x ** 3
        return np.where(x < 1, inside, 0.0)
@@ -0,0 +1,85 @@
1
+ """Unstructured (general symmetric) correlation structure."""
2
+
3
+ import numpy as np
4
+ from numpy.typing import NDArray
5
+
6
+ from python_gls.correlation.base import CorStruct
7
+ from python_gls._parametrization import (
8
+ unconstrained_to_corr,
9
+ corr_to_unconstrained,
10
+ )
11
+
12
+
13
class CorSymm(CorStruct):
    """Unstructured (general symmetric) correlation.

    Every one of the d(d-1)/2 off-diagonal correlations is a free
    parameter; a spherical parametrization keeps the implied matrix
    positive-definite.

    Equivalent to R's `corSymm()`.

    Parameters
    ----------
    dim : int or None
        Dimension of the correlation matrix. If None, inferred from data.
    """

    def __init__(self, dim: int | None = None):
        super().__init__()
        if dim is not None:
            if not isinstance(dim, int):
                raise TypeError(f"dim must be an integer, got {type(dim).__name__}")
            if dim < 2:
                raise ValueError(f"dim must be >= 2 for a correlation matrix, got {dim}")
        self._dim = dim

    @property
    def n_params(self) -> int:
        if self._dim is None:
            raise ValueError("Dimension not set. Call initialize() first or pass dim=.")
        return self._dim * (self._dim - 1) // 2

    def get_correlation_matrix(self, group_size: int, **kwargs) -> NDArray:
        """Build unstructured correlation matrix from current parameters."""
        if self._params is None:
            return np.eye(group_size)
        return unconstrained_to_corr(self._params, group_size)

    def _get_init_params(self, residuals_by_group: list[NDArray]) -> NDArray:
        """Initialize from sample correlation of residuals."""
        # Dimension: the common group size when balanced, else the largest
        # group (max of equal values is that value, so one expression covers
        # both cases).
        sizes = [len(r) for r in residuals_by_group]
        d = max(sizes)
        self._dim = d

        # Sample correlation from the groups of full size d; fall back to
        # the identity when there are too few rows for a stable estimate.
        full_size = [r for r in residuals_by_group if len(r) == d]
        R_sample = np.eye(d)
        if len(full_size) >= 2:
            stacked = np.vstack(full_size)
            if stacked.shape[0] > d:
                R_sample = np.corrcoef(stacked.T)

        # Nudge toward positive-definiteness, then rescale back to a unit
        # diagonal so it remains a valid correlation matrix.
        min_eig = float(np.min(np.linalg.eigvalsh(R_sample)))
        if min_eig < 1e-6:
            R_sample = R_sample + (1e-6 - min_eig + 1e-6) * np.eye(d)
            inv_sd = np.diag(1.0 / np.sqrt(np.diag(R_sample)))
            R_sample = inv_sd @ R_sample @ inv_sd

        return corr_to_unconstrained(R_sample)

    def _params_to_unconstrained(self, params: NDArray) -> NDArray:
        # Spherical parametrization is already unconstrained.
        return np.copy(params)

    def _unconstrained_to_params(self, uparams: NDArray) -> NDArray:
        return np.copy(uparams)
@@ -0,0 +1,302 @@
1
+ """ML and REML log-likelihood functions for GLS estimation.
2
+
3
+ Implements profile log-likelihood where fixed effects (beta) are profiled out,
4
+ leaving only correlation and variance parameters to be optimized.
5
+ """
6
+
7
+ import warnings
8
+
9
+ import numpy as np
10
+ from numpy.typing import NDArray
11
+
12
+
13
+ def _build_omega_block(
14
+ corr_matrix: NDArray,
15
+ var_weights: NDArray,
16
+ sigma2: float,
17
+ ) -> NDArray:
18
+ """Build covariance block: sigma^2 * A^{1/2} R A^{1/2}.
19
+
20
+ Parameters
21
+ ----------
22
+ corr_matrix : (m, m) correlation matrix for this group.
23
+ var_weights : (m,) variance weights for this group (standard deviations).
24
+ sigma2 : scalar residual variance.
25
+
26
+ Returns
27
+ -------
28
+ Omega block of shape (m, m).
29
+ """
30
+ A_half = np.diag(var_weights)
31
+ return sigma2 * A_half @ corr_matrix @ A_half
32
+
33
+
34
+ def _build_omega_inv_block(
35
+ corr_matrix: NDArray,
36
+ var_weights: NDArray,
37
+ ) -> NDArray:
38
+ """Build Omega^{-1} block (up to 1/sigma^2 scaling).
39
+
40
+ Returns (1/sigma^2) * A^{-1/2} R^{-1} A^{-1/2}.
41
+ We drop the sigma^2 factor since it cancels in the profile likelihood.
42
+ """
43
+ # Guard against zero/near-zero weights
44
+ safe_weights = np.where(np.abs(var_weights) < 1e-15, 1e-15, var_weights)
45
+ A_inv_half = np.diag(1.0 / safe_weights)
46
+ R_inv = np.linalg.solve(corr_matrix, np.eye(corr_matrix.shape[0]))
47
+ return A_inv_half @ R_inv @ A_inv_half
48
+
49
+
50
+ def _safe_log_weights(var_weights: NDArray) -> float:
51
+ """Compute sum of log(weights) with protection against zero/negative."""
52
+ safe_weights = np.maximum(np.abs(var_weights), 1e-300)
53
+ return float(np.sum(np.log(safe_weights)))
54
+
55
+
56
def profile_loglik_ml(
    X_groups: list[NDArray],
    y_groups: list[NDArray],
    corr_matrices: list[NDArray],
    var_weights_groups: list[NDArray],
    nobs: int,
) -> float:
    """Profile log-likelihood under ML estimation.

    Beta and sigma^2 are profiled out. The returned value is the
    concentrated log-likelihood as a function of correlation and variance
    parameters only.

    Parameters
    ----------
    X_groups : list of (m_g, k) design matrices per group.
    y_groups : list of (m_g,) response vectors per group.
    corr_matrices : list of (m_g, m_g) correlation matrices per group.
    var_weights_groups : list of (m_g,) variance weight vectors per group.
    nobs : int, total number of observations.

    Returns
    -------
    float : profile log-likelihood value (-inf for inadmissible parameters).
    """
    k = X_groups[0].shape[1]
    N = nobs

    # Accumulate X'Omega^{-1}X, X'Omega^{-1}y and log|Omega| across groups.
    # Each Omega^{-1} block requires an O(m^3) solve, so cache the blocks
    # rather than rebuilding them for the weighted-RSS pass below.
    XtOiX = np.zeros((k, k))
    XtOiy = np.zeros(k)
    log_det_omega = 0.0
    omega_inv_blocks: list[NDArray] = []

    for Xg, yg, Rg, wg in zip(X_groups, y_groups, corr_matrices, var_weights_groups):
        Omega_inv = _build_omega_inv_block(Rg, wg)
        omega_inv_blocks.append(Omega_inv)
        XtOiX += Xg.T @ Omega_inv @ Xg
        XtOiy += Xg.T @ Omega_inv @ yg

        # log|Omega_g| = log|R_g| + 2*sum(log(w_g)) (sigma^2 factor added later)
        sign, logdet_R = np.linalg.slogdet(Rg)
        if sign <= 0:
            # Correlation matrix not positive-definite: reject these params.
            return -np.inf
        log_det_omega += logdet_R + 2 * _safe_log_weights(wg)

    # Profile beta: beta_hat = (X'Omega^{-1}X)^{-1} X'Omega^{-1}y
    try:
        beta_hat = np.linalg.solve(XtOiX, XtOiy)
    except np.linalg.LinAlgError:
        return -np.inf

    # Profile sigma^2: sigma^2_hat = (1/N) * sum_g (y_g - X_g beta)' Omega_inv_g (y_g - X_g beta)
    rss_weighted = 0.0
    for Xg, yg, Omega_inv in zip(X_groups, y_groups, omega_inv_blocks):
        resid = yg - Xg @ beta_hat
        rss_weighted += resid @ Omega_inv @ resid

    sigma2_hat = rss_weighted / N

    if sigma2_hat <= 0:
        return -np.inf

    # Log-likelihood: -N/2 * log(2*pi) - N/2 * log(sigma^2) - 1/2 * log|Omega/sigma^2| - N/2
    # where log|Omega| = N*log(sigma^2) + log_det_omega (without sigma^2),
    # and the quadratic form collapses to N at the profiled sigma^2_hat.
    loglik = (
        -0.5 * N * np.log(2 * np.pi)
        - 0.5 * N * np.log(sigma2_hat)
        - 0.5 * log_det_omega
        - 0.5 * N
    )

    if np.isnan(loglik) or np.isinf(loglik):
        return -np.inf

    return loglik
131
+
132
+
133
def profile_loglik_reml(
    X_groups: list[NDArray],
    y_groups: list[NDArray],
    corr_matrices: list[NDArray],
    var_weights_groups: list[NDArray],
    nobs: int,
) -> float:
    """Profile log-likelihood under REML estimation.

    Like ML, but integrates out the fixed effects for unbiased variance
    estimation. The REML adjustment adds -0.5 * log|X'Omega^{-1}X| and
    uses N - k degrees of freedom for sigma^2.

    Parameters
    ----------
    X_groups : list of (m_g, k) design matrices per group.
    y_groups : list of (m_g,) response vectors per group.
    corr_matrices : list of (m_g, m_g) correlation matrices per group.
    var_weights_groups : list of (m_g,) variance weight vectors per group.
    nobs : int, total number of observations.

    Returns
    -------
    float : profile REML log-likelihood value (-inf for inadmissible parameters).
    """
    k = X_groups[0].shape[1]
    N = nobs

    XtOiX = np.zeros((k, k))
    XtOiy = np.zeros(k)
    log_det_omega = 0.0
    # Cache Omega^{-1} blocks: building one requires an O(m^3) solve, and
    # each block is needed again for the weighted-RSS pass below.
    omega_inv_blocks: list[NDArray] = []

    for Xg, yg, Rg, wg in zip(X_groups, y_groups, corr_matrices, var_weights_groups):
        Omega_inv = _build_omega_inv_block(Rg, wg)
        omega_inv_blocks.append(Omega_inv)
        XtOiX += Xg.T @ Omega_inv @ Xg
        XtOiy += Xg.T @ Omega_inv @ yg

        sign, logdet_R = np.linalg.slogdet(Rg)
        if sign <= 0:
            # Correlation matrix not positive-definite: reject these params.
            return -np.inf
        log_det_omega += logdet_R + 2 * _safe_log_weights(wg)

    try:
        beta_hat = np.linalg.solve(XtOiX, XtOiy)
    except np.linalg.LinAlgError:
        return -np.inf

    # REML uses N-k for sigma^2
    rss_weighted = 0.0
    for Xg, yg, Omega_inv in zip(X_groups, y_groups, omega_inv_blocks):
        resid = yg - Xg @ beta_hat
        rss_weighted += resid @ Omega_inv @ resid

    N_reml = N - k
    if N_reml <= 0:
        return -np.inf
    sigma2_hat = rss_weighted / N_reml

    if sigma2_hat <= 0:
        return -np.inf

    # REML log-likelihood: the extra -0.5*log|X'Omega^{-1}X| term accounts
    # for the integrated-out fixed effects.
    sign_xtox, logdet_xtox = np.linalg.slogdet(XtOiX)
    if sign_xtox <= 0:
        return -np.inf

    loglik = (
        -0.5 * N_reml * np.log(2 * np.pi)
        - 0.5 * N_reml * np.log(sigma2_hat)
        - 0.5 * log_det_omega
        - 0.5 * logdet_xtox
        - 0.5 * N_reml
    )

    if np.isnan(loglik) or np.isinf(loglik):
        return -np.inf

    return loglik
211
+
212
+
213
def compute_gls_estimates(
    X_groups: list[NDArray],
    y_groups: list[NDArray],
    corr_matrices: list[NDArray],
    var_weights_groups: list[NDArray],
    nobs: int,
    method: str = "REML",
) -> tuple[NDArray, NDArray, float, float]:
    """Compute GLS beta, covariance, sigma^2, and log-likelihood.

    Given the estimated correlation and variance parameters, compute the
    final GLS estimates.

    Parameters
    ----------
    X_groups : list of (m_g, k) design matrices per group.
    y_groups : list of (m_g,) response vectors per group.
    corr_matrices : list of (m_g, m_g) correlation matrices per group.
    var_weights_groups : list of (m_g,) variance weight vectors per group.
    nobs : int, total number of observations.
    method : "REML" for the N-k variance denominator; any other value
        uses the ML denominator N.

    Returns
    -------
    beta_hat : (k,) estimated coefficients
    cov_beta : (k, k) covariance of beta estimates
    sigma2_hat : estimated residual variance
    loglik : log-likelihood at these estimates

    Raises
    ------
    numpy.linalg.LinAlgError
        If X'Omega^{-1}X is singular (collinear design or degenerate
        correlation structure).
    """
    k = X_groups[0].shape[1]
    N = nobs

    XtOiX = np.zeros((k, k))
    XtOiy = np.zeros(k)
    # Cache Omega^{-1} per group: each block costs an O(m^3) solve and is
    # needed again for the weighted RSS below. log|Omega| is intentionally
    # not accumulated here -- the loglik helpers compute it themselves.
    omega_inv_blocks: list[NDArray] = []

    for Xg, yg, Rg, wg in zip(X_groups, y_groups, corr_matrices, var_weights_groups):
        Omega_inv = _build_omega_inv_block(Rg, wg)
        omega_inv_blocks.append(Omega_inv)
        XtOiX += Xg.T @ Omega_inv @ Xg
        XtOiy += Xg.T @ Omega_inv @ yg

    try:
        beta_hat = np.linalg.solve(XtOiX, XtOiy)
    except np.linalg.LinAlgError as e:
        raise np.linalg.LinAlgError(
            "X'Omega^{-1}X is singular. This may indicate perfect collinearity "
            "in the design matrix or a degenerate correlation structure."
        ) from e

    rss_weighted = 0.0
    for Xg, yg, Omega_inv in zip(X_groups, y_groups, omega_inv_blocks):
        resid = yg - Xg @ beta_hat
        rss_weighted += resid @ Omega_inv @ resid

    if method == "REML":
        denom = N - k
        if denom <= 0:
            # Degenerate REML degrees of freedom: warn and fall back to ML.
            warnings.warn(
                f"N - k = {denom} <= 0 for REML. Using ML denominator (N={N}) instead.",
                stacklevel=2,
            )
            denom = N
        sigma2_hat = rss_weighted / denom
    else:
        sigma2_hat = rss_weighted / N

    # Covariance of beta: sigma^2 * (X'Omega^{-1}X)^{-1}
    try:
        cov_beta = sigma2_hat * np.linalg.inv(XtOiX)
    except np.linalg.LinAlgError as e:
        raise np.linalg.LinAlgError(
            "Failed to invert X'Omega^{-1}X for covariance estimation. "
            "The design matrix may be singular or near-singular."
        ) from e

    # Negative diagonal entries signal numerical breakdown; warn, don't fail.
    if np.any(np.diag(cov_beta) < 0):
        warnings.warn(
            "Some variance estimates are negative, indicating numerical instability. "
            "Standard errors may be unreliable.",
            stacklevel=2,
        )

    # Compute log-likelihood at these estimates
    if method == "REML":
        loglik = profile_loglik_reml(
            X_groups, y_groups, corr_matrices, var_weights_groups, nobs
        )
    else:
        loglik = profile_loglik_ml(
            X_groups, y_groups, corr_matrices, var_weights_groups, nobs
        )

    return beta_hat, cov_beta, sigma2_hat, loglik