gpu-glm 0.1.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
gpu_glm-0.1.0/LICENSE ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2026 Greg McMahan
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
gpu_glm-0.1.0/PKG-INFO ADDED
@@ -0,0 +1,72 @@
1
+ Metadata-Version: 2.4
2
+ Name: gpu-glm
3
+ Version: 0.1.0
4
+ Summary: Regularized GLM models running on a GPU.
5
+ Author-email: Greg McMahan <gmcmacran@gmail.com>
6
+ License-Expression: MIT
7
+ Project-URL: Homepage, https://github.com/gmcmacran/gpu_glm
8
+ Requires-Python: >=3.12
9
+ Description-Content-Type: text/markdown
10
+ License-File: LICENSE
11
+ Requires-Dist: numpy>=2.4.3
12
+ Requires-Dist: scipy>=1.17.1
13
+ Provides-Extra: gpu
14
+ Requires-Dist: cupy>=14.0.1; extra == "gpu"
15
+ Provides-Extra: dev
16
+ Requires-Dist: numpy>=2.4.3; extra == "dev"
17
+ Requires-Dist: scipy>=1.17.1; extra == "dev"
18
+ Requires-Dist: scikit-learn>=1.8.0; extra == "dev"
19
+ Requires-Dist: pytest>=9.0.2; extra == "dev"
20
+ Requires-Dist: ruff>=0.15.1; extra == "dev"
21
+ Requires-Dist: pre-commit>=4.5.1; extra == "dev"
22
+ Requires-Dist: mkdocs>=1.6.1; extra == "dev"
23
+ Requires-Dist: mkdocs-material>=9.7.1; extra == "dev"
24
+ Requires-Dist: mkdocstrings[python]>=1.0.3; extra == "dev"
25
+ Dynamic: license-file
26
+
27
+ # gpu_glm
28
+
29
+ A lightweight Python implementation of Generalized Linear Models (GLMs) that runs on a GPU.
30
+
31
+ This package provides:
32
+
33
+ - Gaussian, Bernoulli, Poisson, Gamma, and Inverse Gaussian models.
34
+ - Multiple link functions (identity, log, inverse, logit, probit, etc.)
35
+ - A Cupy-based implementation that falls back to Numpy.
36
+ - A scikit-learn interface.
37
+ - L2 regularization.
38
+
39
+ ---
40
+
41
+ ## Installation
42
+ To use the GPU, cupy must be installed with GPU dependencies already working. If cupy is unavailable, numpy is used.
43
+
44
+ ```bash
45
+ pip install gpu-glm
46
+ ```
47
+
48
+ A conda package to handle GPU dependencies is under development.
49
+
50
+
51
+ ## Quick Example
52
+
53
+ Below fits a linear regression model.
54
+
55
+ ```python
56
+ import numpy as np
57
+ from gpu_glm import gaussian_glm
58
+ from sklearn.metrics import root_mean_squared_error
59
+
60
+ # Simulated data
61
+ X = np.column_stack([np.random.randn(100), np.ones(100)])
62
+ y = 2 * X[:, 0] + 3 + np.random.randn(100)
63
+
64
+ # Fit model
65
+ model = gaussian_glm()
66
+ model.fit(X, y)
67
+ print(f"coefficients: {model.coef()}")
68
+
69
+ y_hat = model.predict(X)
70
+ rmse = root_mean_squared_error(y, y_hat)
71
+ print(f"RMSE: {np.round(rmse, 3)}")
72
+ ```
@@ -0,0 +1,46 @@
1
+ # gpu_glm
2
+
3
+ A lightweight Python implementation of Generalized Linear Models (GLMs) that runs on a GPU.
4
+
5
+ This package provides:
6
+
7
+ - Gaussian, Bernoulli, Poisson, Gamma, and Inverse Gaussian models.
8
+ - Multiple link functions (identity, log, inverse, logit, probit, etc.)
9
+ - A Cupy-based implementation that falls back to Numpy.
10
+ - A scikit-learn interface.
11
+ - L2 regularization.
12
+
13
+ ---
14
+
15
+ ## Installation
16
+ To use the GPU, cupy must be installed with GPU dependencies already working. If cupy is unavailable, numpy is used.
17
+
18
+ ```bash
19
+ pip install gpu-glm
20
+ ```
21
+
22
+ A conda package to handle GPU dependencies is under development.
23
+
24
+
25
+ ## Quick Example
26
+
27
+ Below fits a linear regression model.
28
+
29
+ ```python
30
+ import numpy as np
31
+ from gpu_glm import gaussian_glm
32
+ from sklearn.metrics import root_mean_squared_error
33
+
34
+ # Simulated data
35
+ X = np.column_stack([np.random.randn(100), np.ones(100)])
36
+ y = 2 * X[:, 0] + 3 + np.random.randn(100)
37
+
38
+ # Fit model
39
+ model = gaussian_glm()
40
+ model.fit(X, y)
41
+ print(f"coefficients: {model.coef()}")
42
+
43
+ y_hat = model.predict(X)
44
+ rmse = root_mean_squared_error(y, y_hat)
45
+ print(f"RMSE: {np.round(rmse, 3)}")
46
+ ```
@@ -0,0 +1,54 @@
1
+ [build-system]
2
+ requires = ["setuptools>=69.0", "wheel"]
3
+ build-backend = "setuptools.build_meta"
4
+
5
+ [project]
6
+ name = "gpu-glm"
7
+ version = "0.1.0"
8
+ description = "Regularized GLM models running on a GPU."
9
+ readme = "README.md"
10
+ requires-python = ">=3.12"
11
+ license = "MIT"
12
+
13
+ authors = [
14
+ { name = "Greg McMahan", email = "gmcmacran@gmail.com" }
15
+ ]
16
+
17
+ # IMPORTANT:
18
+ # Do NOT list CuPy here if you want Conda to install it.
19
+ # These dependencies must be CPU‑safe and pip‑installable everywhere.
20
+ dependencies = [
21
+ "numpy>=2.4.3",
22
+ "scipy>=1.17.1"
23
+ ]
24
+
25
+ [project.optional-dependencies]
26
+ # Users who want GPU support can install this extra via pip,
27
+ # but Conda users will ignore it and install CuPy themselves.
28
+ gpu = [
29
+ "cupy>=14.0.1"
30
+ ]
31
+
32
+ dev = [
33
+ "numpy>=2.4.3",
34
+ "scipy>=1.17.1",
35
+ "scikit-learn>=1.8.0",
36
+ "pytest>=9.0.2",
37
+ "ruff>=0.15.1",
38
+ "pre-commit>=4.5.1",
39
+ "mkdocs>=1.6.1",
40
+ "mkdocs-material>=9.7.1",
41
+ "mkdocstrings[python]>=1.0.3"
42
+ ]
43
+
44
+ [project.urls]
45
+ Homepage = "https://github.com/gmcmacran/gpu_glm"
46
+
47
+ [tool.ruff]
48
+ line-length = 88
49
+ target-version = "py313"
50
+
51
+ [tool.ruff.lint]
52
+ select = ["E", "F", "I", "B", "UP", "SIM", "C4"]
53
+
54
+ [tool.ruff.format]
@@ -0,0 +1,4 @@
1
+ [egg_info]
2
+ tag_build =
3
+ tag_date = 0
4
+
@@ -0,0 +1,15 @@
1
+ from .models import (
2
+ bernoulli_glm,
3
+ gamma_glm,
4
+ gaussian_glm,
5
+ inverse_gaussian_glm,
6
+ poisson_glm,
7
+ )
8
+
9
+ __all__ = [
10
+ "gaussian_glm",
11
+ "bernoulli_glm",
12
+ "poisson_glm",
13
+ "gamma_glm",
14
+ "inverse_gaussian_glm",
15
+ ]
@@ -0,0 +1,480 @@
1
+ from abc import ABC, ABCMeta, abstractmethod
2
+
3
+ import numpy as _np
4
+
5
+ try:
6
+ import cupy as _cp
7
+
8
+ CUPY_AVAILABLE = True
9
+ except ImportError:
10
+ _cp = None
11
+ CUPY_AVAILABLE = False
12
+
13
+ import scipy.stats as stats
14
+
15
+
16
def xp():
    """Return the active array module (CuPy if available, else NumPy)."""
    if CUPY_AVAILABLE:
        return _cp
    return _np
19
+
20
+
21
def backend_info():
    """
    Return a human-readable description of the active compute backend.

    Returns
    -------
    str
        A multi-line string describing whether the package is using
        NumPy (CPU) or CuPy (GPU), and GPU details if available.
    """
    # No CuPy installed: everything runs on the CPU via NumPy.
    if not CUPY_AVAILABLE:
        return "Backend: NumPy (CPU)\nCuPy not installed. All computations run on CPU."

    # CuPy is available; query the active device for its properties.
    try:
        runtime = _cp.cuda.runtime
        props = runtime.getDeviceProperties(runtime.getDevice())
        report = [
            "Backend: CuPy (GPU)",
            f"Device: {props['name'].decode('utf-8')}",
            f"Total Memory: {props['totalGlobalMem'] / (1024**3):.2f} GB",
            f"Multiprocessors: {props['multiProcessorCount']}",
            f"CUDA Runtime Version: {runtime.runtimeGetVersion()}",
            f"CUDA Driver Version: {runtime.driverGetVersion()}",
        ]
        return "\n".join(report)
    except Exception as e:
        # The device query can fail (e.g. no driver); report rather than raise.
        return (
            "Backend: CuPy (GPU)\n"
            "CuPy is installed, but GPU properties could not be retrieved.\n"
            f"Error: {e}"
        )
62
+
63
+
64
class IRLS(ABC):
    """
    Base class implementing the Iteratively Reweighted Least Squares (IRLS)
    algorithm for fitting Generalized Linear Models (GLMs).

    Subclasses must implement:

    - ``_var_mu(mu)``: variance function of the mean
    - ``_a_of_phi(Y, mu, B)``: dispersion-related function

    Parameters
    ----------
    link : str
        Name of the link function. Supported values depend on the subclass.
    alpha : float, default=0.0
        L2 regularization strength. Matches scikit-learn's ``alpha``
        parameter in ``PoissonRegressor``, ``GammaRegressor``, etc.
        The penalty term is ``(alpha / 2) * ||coef||^2``. The intercept
        is not regularized.
    """

    # NOTE: the ``__metaclass__ = ABCMeta`` attribute previously set here was
    # Python-2 syntax and had no effect; inheriting from ABC already supplies
    # the metaclass.

    def __init__(self, link, alpha=0.0):
        """
        Initialize the IRLS model.

        Parameters
        ----------
        link : str
            The link function to use (e.g., ``"identity"``, ``"log"``,
            ``"logit"``, ``"inverse"``, ``"probit"``, ``"sqrt"``,
            ``"1/mu^2"``).
        alpha : float, default=0.0
            L2 regularization strength. The penalty is
            ``(alpha / 2) * ||coef||^2``. The intercept is excluded
            from regularization.

        Raises
        ------
        ValueError
            If ``alpha`` is negative.
        """
        # Validate up front so an invalid model is never half-constructed.
        if alpha < 0.0:
            raise ValueError(f"Invalid alpha: {alpha}. Alpha must be non-negative.")
        self._alpha = alpha  # was assigned twice in the original; once suffices
        self._link = link
        self._B = None  # coefficient vector, populated by fit()
        super().__init__()

    # -----------------------------
    # Backend helpers
    # -----------------------------
    def _to_backend(self, arr):
        """Convert input to backend array (CuPy if available, else NumPy)."""
        if CUPY_AVAILABLE:
            return _cp.asarray(arr)
        return _np.asarray(arr)

    def _to_numpy(self, arr):
        """Convert backend array to NumPy (for SciPy/stats)."""
        if CUPY_AVAILABLE:
            return _cp.asnumpy(arr)
        return arr

    # -----------------------------
    # Public API
    # -----------------------------
    def coef(self):
        """
        Return the fitted coefficient vector.

        Returns
        -------
        np.ndarray
            The coefficient vector ``B`` of shape ``(n_features,)``.
        """
        return self._to_numpy(self._B)

    def fit(self, X, Y, max_iter=100):
        """
        Fit the GLM using the IRLS algorithm.

        Parameters
        ----------
        X : np.ndarray, shape (n_samples, n_features)
            Design matrix. A column of ones is appended automatically when
            the last column is not already constant 1.
        Y : np.ndarray, shape (n_samples,)
            Response vector.
        max_iter : int, default=100
            Safety cap on IRLS iterations. The previous implementation
            looped until convergence with no bound and could spin forever
            on ill-conditioned data; a RuntimeWarning is emitted if the
            cap is reached before convergence.

        Returns
        -------
        self
        """
        xp_backend = xp()

        # Convert to backend (CuPy or NumPy)
        X = self._to_backend(X)
        Y = self._to_backend(Y)

        # Add intercept column if missing
        if not xp_backend.allclose(X[:, -1], 1):
            ones = xp_backend.ones((X.shape[0], 1))
            X = xp_backend.concatenate([X, ones], axis=1)

        n_features = X.shape[1]

        # Initialize coefficients; intercept starts at the sample mean.
        self._B = xp_backend.zeros(n_features)
        self._B[-1] = Y.mean()

        tol = 1e6
        n_iter = 0
        while tol > 1e-5 and n_iter < max_iter:
            n_iter += 1
            eta = X.dot(self._B)
            mu = self._inv_link(eta)

            # IRLS weights: 1 / (V(mu) * a(phi)) * (d eta / d mu)^2
            w = (
                1 / (self._var_mu(mu) * self._a_of_phi(Y, mu, self._B))
            ) * xp_backend.power(self._del_eta_del_mu(mu), 2)

            # Working response z
            z = (Y - mu) * self._del_eta_del_mu(mu) + eta

            # Weighted least squares without forming diag(W)
            # X^T W X == (X * w[:, None]).T @ X
            Xw = X * w[:, None]
            XtWX = Xw.T.dot(X)
            XtWz = Xw.T.dot(z)

            # L2 regularization: add alpha to diagonal (exclude intercept)
            if self._alpha > 0:
                n_coef = n_features - 1
                XtWX[:n_coef, :n_coef] += self._alpha * xp_backend.eye(n_coef)

            # Solve the normal equations for the updated coefficients
            B_new = xp_backend.linalg.solve(XtWX, XtWz)

            # L1 norm of the coefficient step is the convergence criterion
            tol = xp_backend.sum(xp_backend.abs(B_new - self._B))
            self._B = B_new

        if tol > 1e-5:
            import warnings

            warnings.warn(
                f"IRLS did not converge within {max_iter} iterations "
                f"(last step L1 norm: {tol}).",
                RuntimeWarning,
                stacklevel=2,
            )

        return self

    def predict(self, X):
        """
        Predict the mean response for new data.

        Parameters
        ----------
        X : np.ndarray, shape (n_samples, n_features)
            Design matrix.

        Returns
        -------
        np.ndarray
            Predicted mean response ``mu``.
        """
        xp_backend = xp()
        Xb = self._to_backend(X)
        # Add intercept column if missing
        if not xp_backend.allclose(Xb[:, -1], 1):
            ones = xp_backend.ones((Xb.shape[0], 1))
            Xb = xp_backend.concatenate([Xb, ones], axis=1)
        eta = Xb.dot(self._B)
        mu = self._inv_link(eta)
        return self._to_numpy(mu)

    # -----------------------------
    # Link functions
    # -----------------------------
    def _inv_link(self, eta):
        """
        Apply the inverse link function.

        Parameters
        ----------
        eta : np.ndarray
            Linear predictor.

        Returns
        -------
        np.ndarray
            Mean response ``mu``.

        Raises
        ------
        ValueError
            If the link name is not recognized. (Previously an unknown
            link silently returned ``None``; subclass constructors
            validate links, so this is a defensive guard.)
        """
        xp_backend = xp()

        if self._link == "identity":
            return eta
        elif self._link == "log":
            return xp_backend.exp(eta)
        elif self._link == "inverse":
            return 1 / eta
        elif self._link == "logit":
            e = xp_backend.exp(eta)
            return e / (1 + e)
        elif self._link == "probit":
            # SciPy only accepts NumPy arrays, so round-trip via CPU.
            eta_np = self._to_numpy(eta)
            return self._to_backend(stats.norm.cdf(eta_np))
        elif self._link == "sqrt":
            return xp_backend.power(eta, 2)
        elif self._link == "1/mu^2":
            return 1 / xp_backend.power(eta, 0.5)
        raise ValueError(f"Unknown link: {self._link}")

    def _del_eta_del_mu(self, mu):
        """
        Compute derivative :math:`d\\eta/d\\mu` for the link function.

        Parameters
        ----------
        mu : np.ndarray
            Mean response.

        Returns
        -------
        np.ndarray
            Derivative ``dη/dμ`` evaluated at ``mu``.

        Raises
        ------
        ValueError
            If the link name is not recognized.
        """
        xp_backend = xp()

        if self._link == "identity":
            return xp_backend.ones(mu.shape)
        elif self._link == "log":
            return 1 / mu
        elif self._link == "inverse":
            return -1 / xp_backend.power(mu, 2)
        elif self._link == "logit":
            return 1 / (mu * (1 - mu))
        elif self._link == "probit":
            # d eta/d mu = 1/phi(Phi^{-1}(mu)); computed on CPU via SciPy.
            mu_np = self._to_numpy(mu)
            return self._to_backend(stats.norm.pdf(stats.norm.ppf(mu_np)))
        elif self._link == "sqrt":
            return 0.5 * xp_backend.power(mu, -0.5)
        elif self._link == "1/mu^2":
            return -2 / xp_backend.power(mu, 3)
        raise ValueError(f"Unknown link: {self._link}")

    # -----------------------------
    # Abstract variance + dispersion
    # -----------------------------
    @abstractmethod
    def _var_mu(self, mu):
        """
        Variance function :math:`\\mathrm{Var}(Y \\mid \\mu)`.

        Parameters
        ----------
        mu : np.ndarray
            Mean response.

        Returns
        -------
        np.ndarray
            Variance evaluated at ``mu``.
        """
        pass

    @abstractmethod
    def _a_of_phi(self, Y, mu, B):
        """
        Dispersion-related function :math:`a(\\phi)`.

        Parameters
        ----------
        Y : np.ndarray
            Observed response.
        mu : np.ndarray
            Mean response.
        B : np.ndarray
            Coefficient vector.

        Returns
        -------
        np.ndarray or float
            Dispersion-related quantity.
        """
        pass
332
+
333
+
334
+ # -----------------------------
335
+ # GLM Subclasses
336
+ # -----------------------------
337
class gaussian_glm(IRLS):
    """
    Gaussian GLM with identity, log, or inverse link.

    Parameters
    ----------
    link : str, default="identity"
        Link function. One of ``"identity"``, ``"log"``, ``"inverse"``.
    alpha : float, default=0.0
        L2 regularization strength.
    """

    def __init__(self, link="identity", alpha=0.0):
        # Guard clause: reject unsupported links before touching the base.
        if link not in ("identity", "log", "inverse"):
            raise ValueError(f"Invalid link: {link}")
        super().__init__(link, alpha=alpha)

    def _var_mu(self, mu):
        # Gaussian variance does not depend on the mean.
        return xp().ones(mu.shape)

    def _a_of_phi(self, Y, mu, B):
        # Residual-variance estimate of the dispersion phi.
        backend = xp()
        rss = backend.sum((Y - mu) ** 2)
        return rss / (Y.shape[0] - B.shape[0])
361
+
362
+
363
class bernoulli_glm(IRLS):
    """
    Bernoulli GLM with logit or probit link.

    Parameters
    ----------
    link : str, default="logit"
        Link function. One of ``"logit"``, ``"probit"``.
    alpha : float, default=0.0
        L2 regularization strength.
    """

    def __init__(self, link="logit", alpha=0.0):
        # Guard clause: reject unsupported links before touching the base.
        if link not in ("logit", "probit"):
            raise ValueError(f"Invalid link: {link}")
        super().__init__(link, alpha=alpha)

    def _var_mu(self, mu):
        # Bernoulli variance: mu * (1 - mu).
        return mu * (1 - mu)

    def _a_of_phi(self, Y, mu, B):
        # Dispersion is fixed at 1 for the Bernoulli family.
        return xp().ones(Y.shape[0])

    def predict_proba(self, X):
        """Return class probabilities as columns [P(y=0), P(y=1)]."""
        backend = xp()
        design = self._to_backend(X)
        # Append an intercept column when one is not already present.
        if not backend.allclose(design[:, -1], 1):
            intercept = backend.ones((design.shape[0], 1))
            design = backend.concatenate([design, intercept], axis=1)
        p1 = self._to_numpy(self._inv_link(design.dot(self._B)))
        return _np.column_stack([1 - p1, p1])

    def predict(self, X):
        """Return hard class labels using a 0.5 probability threshold."""
        return (self.predict_proba(X)[:, 1] > 0.5).astype(int)
401
+
402
+
403
class poisson_glm(IRLS):
    """
    Poisson GLM with log, identity, or sqrt link.

    Parameters
    ----------
    link : str, default="log"
        Link function. One of ``"log"``, ``"identity"``, ``"sqrt"``.
    alpha : float, default=0.0
        L2 regularization strength.
    """

    def __init__(self, link="log", alpha=0.0):
        # Guard clause: reject unsupported links before touching the base.
        if link not in ("log", "identity", "sqrt"):
            raise ValueError(f"Invalid link: {link}")
        super().__init__(link, alpha=alpha)

    def _var_mu(self, mu):
        # Poisson variance equals the mean.
        return mu

    def _a_of_phi(self, Y, mu, B):
        # Dispersion is fixed at 1 for the Poisson family.
        return xp().ones(Y.shape[0])
426
+
427
+
428
class gamma_glm(IRLS):
    """
    Gamma GLM with inverse, identity, or log link.

    Parameters
    ----------
    link : str, default="inverse"
        Link function. One of ``"inverse"``, ``"identity"``, ``"log"``.
    alpha : float, default=0.0
        L2 regularization strength.
    """

    def __init__(self, link="inverse", alpha=0.0):
        # Guard clause: reject unsupported links before touching the base.
        if link not in ("inverse", "identity", "log"):
            raise ValueError(f"Invalid link: {link}")
        super().__init__(link, alpha=alpha)

    def _var_mu(self, mu):
        # Gamma variance: mu^2.
        return mu**2

    def _a_of_phi(self, Y, mu, B):
        # Pearson-style dispersion estimate, broadcast to sample length.
        backend = xp()
        dof = Y.shape[0] - B.shape[0]
        phi = backend.sum((Y - mu) ** 2 / (mu**2 * dof))
        return backend.ones(Y.shape[0]) * phi
455
+
456
+
457
class inverse_gaussian_glm(IRLS):
    """
    Inverse Gaussian GLM with 1/μ², inverse, identity, or log link.

    Parameters
    ----------
    link : str, default="1/mu^2"
        Link function. One of ``"1/mu^2"``, ``"inverse"``, ``"identity"``, ``"log"``.
    alpha : float, default=0.0
        L2 regularization strength.
    """

    def __init__(self, link="1/mu^2", alpha=0.0):
        if link in ("1/mu^2", "inverse", "identity", "log"):
            super().__init__(link, alpha=alpha)
        else:
            raise ValueError(f"Invalid link: {link}")

    def _var_mu(self, mu):
        # Inverse Gaussian variance: mu^3.
        return mu**3

    def _a_of_phi(self, Y, mu, B):
        # Pearson estimate of dispersion: sum((y - mu)^2 / V(mu)) / (n - p)
        # with V(mu) = mu^3 for the inverse Gaussian family (mirrors the
        # gamma family's estimator with V(mu) = mu^2).
        #
        # BUG FIX: the previous implementation returned
        #   -sum((Y - mu)^2) / (n - p)
        # i.e. it negated the residual sum and omitted the mu^3 variance
        # factor. A negative "dispersion" flips the sign of every IRLS
        # weight; that cancels out in the unregularized solve, but corrupts
        # the ridge diagonal whenever alpha > 0.
        backend = xp()
        dof = Y.shape[0] - B.shape[0]
        phi = backend.sum((Y - mu) ** 2 / mu**3) / dof
        return backend.ones(Y.shape[0]) * phi