pymc-extras 0.2.6__py3-none-any.whl → 0.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,583 +0,0 @@
- # Copyright 2024 The PyMC Developers
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- #   http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
-
-
- import logging
-
- from functools import reduce
- from importlib.util import find_spec
- from itertools import product
- from typing import Literal
-
- import arviz as az
- import numpy as np
- import pymc as pm
- import pytensor
- import pytensor.tensor as pt
- import xarray as xr
-
- from arviz import dict_to_dataset
- from better_optimize.constants import minimize_method
- from pymc import DictToArrayBijection
- from pymc.backends.arviz import (
-     coords_and_dims_for_inferencedata,
-     find_constants,
-     find_observations,
- )
- from pymc.blocking import RaveledVars
- from pymc.model.transform.conditioning import remove_value_transforms
- from pymc.model.transform.optimization import freeze_dims_and_data
- from pymc.util import get_default_varnames
- from scipy import stats
-
- from pymc_extras.inference.find_map import (
-     GradientBackend,
-     _unconstrained_vector_to_constrained_rvs,
-     find_MAP,
-     get_nearest_psd,
-     scipy_optimize_funcs_from_loss,
- )
-
- _log = logging.getLogger(__name__)
-
-
- def laplace_draws_to_inferencedata(
-     posterior_draws: list[np.ndarray[float | int]], model: pm.Model | None = None
- ) -> az.InferenceData:
-     """
-     Convert draws from a posterior estimated with the Laplace approximation to an InferenceData object.
-
-     Parameters
-     ----------
-     posterior_draws: list of np.ndarray
-         A list of arrays containing the posterior draws. Each array should have shape (chains, draws, *shape),
-         where shape is the shape of the variable in the posterior.
-     model: Model, optional
-         A PyMC model. If None, the model is taken from the current model context.
-
-     Returns
-     -------
-     idata: az.InferenceData
-         An InferenceData object containing the approximated posterior samples.
-     """
-     model = pm.modelcontext(model)
-     chains, draws, *_ = posterior_draws[0].shape
-
-     def make_rv_coords(name):
-         coords = {"chain": range(chains), "draw": range(draws)}
-         extra_dims = model.named_vars_to_dims.get(name)
-         if extra_dims is None:
-             return coords
-         return coords | {dim: list(model.coords[dim]) for dim in extra_dims}
-
-     def make_rv_dims(name):
-         dims = ["chain", "draw"]
-         extra_dims = model.named_vars_to_dims.get(name)
-         if extra_dims is None:
-             return dims
-         return dims + list(extra_dims)
-
-     names = [
-         x.name for x in get_default_varnames(model.unobserved_value_vars, include_transformed=False)
-     ]
-     idata = {
-         name: xr.DataArray(
-             data=draws,
-             coords=make_rv_coords(name),
-             dims=make_rv_dims(name),
-             name=name,
-         )
-         for name, draws in zip(names, posterior_draws)
-     }
-
-     coords, dims = coords_and_dims_for_inferencedata(model)
-     idata = az.convert_to_inference_data(idata, coords=coords, dims=dims)
-
-     return idata
-
-
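For context, laplace_draws_to_inferencedata expects one array of shape (chains, draws, *shape) per untransformed free variable, in the order given by get_default_varnames(model.unobserved_value_vars, include_transformed=False). A minimal sketch against the removed 0.2.6 API; the toy model and random draws are illustrative stand-ins, not real posterior samples:

    import numpy as np
    import pymc as pm

    from pymc_extras.inference.laplace import laplace_draws_to_inferencedata

    with pm.Model(coords={"city": ["A", "B"]}) as m:
        mu = pm.Normal("mu", 0, 1, dims="city")
        sigma = pm.HalfNormal("sigma", 1)

    rng = np.random.default_rng(0)
    posterior_draws = [
        rng.normal(size=(2, 500, 2)),       # "mu": (chain, draw, city)
        np.abs(rng.normal(size=(2, 500))),  # "sigma": (chain, draw)
    ]
    idata = laplace_draws_to_inferencedata(posterior_draws, model=m)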
- def add_fit_to_inferencedata(
-     idata: az.InferenceData, mu: RaveledVars, H_inv: np.ndarray, model: pm.Model | None = None
- ) -> az.InferenceData:
-     """
-     Add the mean vector and covariance matrix of the Laplace approximation to an InferenceData object.
-
-     Parameters
-     ----------
-     idata: az.InferenceData
-         An InferenceData object containing the approximated posterior samples.
-     mu: RaveledVars
-         The MAP estimate of the model parameters.
-     H_inv: np.ndarray
-         The inverse Hessian matrix of the log-posterior evaluated at the MAP estimate.
-     model: Model, optional
-         A PyMC model. If None, the model is taken from the current model context.
-
-     Returns
-     -------
-     idata: az.InferenceData
-         The provided InferenceData, with the mean vector and covariance matrix added to the "fit" group.
-     """
-     model = pm.modelcontext(model)
-     coords = model.coords
-
-     variable_names, *_ = zip(*mu.point_map_info)
-
-     def make_unpacked_variable_names(name):
-         value_to_dim = {
-             x.name: model.named_vars_to_dims.get(model.values_to_rvs[x].name, None)
-             for x in model.value_vars
-         }
-         value_to_dim = {k: v for k, v in value_to_dim.items() if v is not None}
-
-         rv_to_dim = model.named_vars_to_dims
-         dims_dict = rv_to_dim | value_to_dim
-
-         dims = dims_dict.get(name)
-         if dims is None:
-             return [name]
-         labels = product(*(coords[dim] for dim in dims))
-         return [f"{name}[{','.join(map(str, label))}]" for label in labels]
-
-     unpacked_variable_names = reduce(
-         lambda lst, name: lst + make_unpacked_variable_names(name), variable_names, []
-     )
-
-     mean_dataarray = xr.DataArray(mu.data, dims=["rows"], coords={"rows": unpacked_variable_names})
-     cov_dataarray = xr.DataArray(
-         H_inv,
-         dims=["rows", "columns"],
-         coords={"rows": unpacked_variable_names, "columns": unpacked_variable_names},
-     )
-
-     dataset = xr.Dataset({"mean_vector": mean_dataarray, "covariance_matrix": cov_dataarray})
-     idata.add_groups(fit=dataset)
-
-     return idata
-
-
- def add_data_to_inferencedata(
-     idata: az.InferenceData,
-     progressbar: bool = True,
-     model: pm.Model | None = None,
-     compile_kwargs: dict | None = None,
- ) -> az.InferenceData:
-     """
-     Add observed and constant data to an InferenceData object.
-
-     Parameters
-     ----------
-     idata: az.InferenceData
-         An InferenceData object containing the approximated posterior samples.
-     progressbar: bool
-         Whether to display a progress bar during computations. Default is True.
-     model: Model, optional
-         A PyMC model. If None, the model is taken from the current model context.
-     compile_kwargs: dict, optional
-         Additional keyword arguments to pass to pytensor.function.
-
-     Returns
-     -------
-     idata: az.InferenceData
-         The provided InferenceData, with observed and constant data added.
-     """
-     model = pm.modelcontext(model)
-
-     if model.deterministics:
-         idata.posterior = pm.compute_deterministics(
-             idata.posterior,
-             model=model,
-             merge_dataset=True,
-             progressbar=progressbar,
-             compile_kwargs=compile_kwargs,
-         )
-
-     coords, dims = coords_and_dims_for_inferencedata(model)
-
-     observed_data = dict_to_dataset(
-         find_observations(model),
-         library=pm,
-         coords=coords,
-         dims=dims,
-         default_dims=[],
-     )
-
-     constant_data = dict_to_dataset(
-         find_constants(model),
-         library=pm,
-         coords=coords,
-         dims=dims,
-         default_dims=[],
-     )
-
-     idata.add_groups(
-         {"observed_data": observed_data, "constant_data": constant_data},
-         coords=coords,
-         dims=dims,
-     )
-
-     return idata
-
-
- def fit_mvn_at_MAP(
-     optimized_point: dict[str, np.ndarray],
-     model: pm.Model | None = None,
-     on_bad_cov: Literal["warn", "error", "ignore"] = "ignore",
-     transform_samples: bool = False,
-     gradient_backend: GradientBackend = "pytensor",
-     zero_tol: float = 1e-8,
-     diag_jitter: float | None = 1e-8,
-     compile_kwargs: dict | None = None,
- ) -> tuple[RaveledVars, np.ndarray]:
-     """
-     Create a multivariate normal distribution using the inverse of the negative Hessian matrix of the
-     log-posterior evaluated at the MAP estimate. This is the basis of the Laplace approximation.
-
-     Parameters
-     ----------
-     optimized_point : dict[str, np.ndarray]
-         Local maximum a posteriori (MAP) point returned from pymc.find_MAP or jax_tools.fit_map
-     model : Model, optional
-         A PyMC model. If None, the model is taken from the current model context.
-     on_bad_cov : str, one of 'ignore', 'warn', or 'error', default: 'ignore'
-         What to do when ``H_inv`` (inverse Hessian) is not positive semi-definite.
-         If 'ignore' or 'warn', the closest positive-semi-definite matrix to ``H_inv`` (in L1 norm) will be
-         returned. If 'error', an error will be raised.
-     transform_samples : bool
-         Whether to transform the samples back to the original parameter space. Default is False.
-     gradient_backend: str, default "pytensor"
-         The backend to use for gradient computations. Must be one of "pytensor" or "jax".
-     zero_tol: float
-         Value below which an element of the Hessian matrix is counted as 0.
-         This is used to stabilize the computation of the inverse Hessian matrix. Default is 1e-8.
-     diag_jitter: float | None
-         A small value added to the diagonal of the inverse Hessian matrix to ensure it is positive
-         semi-definite. If None, no jitter is added. Default is 1e-8.
-     compile_kwargs: dict, optional
-         Additional keyword arguments to pass to pytensor.function when compiling loss functions.
-
-     Returns
-     -------
-     map_estimate: RaveledVars
-         The MAP estimate of the model parameters, raveled into a 1D array.
-     inverse_hessian: np.ndarray
-         The inverse Hessian matrix of the log-posterior evaluated at the MAP estimate.
-     """
-     if gradient_backend == "jax" and not find_spec("jax"):
-         raise ImportError("JAX must be installed to use JAX gradients")
-
-     model = pm.modelcontext(model)
-     compile_kwargs = {} if compile_kwargs is None else compile_kwargs
-     frozen_model = freeze_dims_and_data(model)
-
-     if not transform_samples:
-         untransformed_model = remove_value_transforms(frozen_model)
-         logp = untransformed_model.logp(jacobian=False)
-         variables = untransformed_model.continuous_value_vars
-     else:
-         logp = frozen_model.logp(jacobian=True)
-         variables = frozen_model.continuous_value_vars
-
-     variable_names = {var.name for var in variables}
-     optimized_free_params = {k: v for k, v in optimized_point.items() if k in variable_names}
-     mu = DictToArrayBijection.map(optimized_free_params)
-
-     _, f_hess, _ = scipy_optimize_funcs_from_loss(
-         loss=-logp,
-         inputs=variables,
-         initial_point_dict=optimized_free_params,
-         use_grad=True,
-         use_hess=True,
-         use_hessp=False,
-         gradient_backend=gradient_backend,
-         compile_kwargs=compile_kwargs,
-     )
-
-     # The covariance of the Laplace approximation is the pseudo-inverse of the Hessian of the
-     # negative log-posterior, with near-zero entries zeroed out for numerical stability.
-     H = -f_hess(mu.data)
-     H_inv = np.linalg.pinv(np.where(np.abs(H) < zero_tol, 0, -H))
-
-     def stabilize(x, jitter):
-         return x + np.eye(x.shape[0]) * jitter
-
-     H_inv = H_inv if diag_jitter is None else stabilize(H_inv, diag_jitter)
-
-     try:
-         np.linalg.cholesky(H_inv)
-     except np.linalg.LinAlgError:
-         if on_bad_cov == "error":
-             raise np.linalg.LinAlgError(
-                 "Inverse Hessian not positive semi-definite at the provided point"
-             )
-         H_inv = get_nearest_psd(H_inv)
-         if on_bad_cov == "warn":
-             _log.warning(
-                 "Inverse Hessian is not positive semi-definite at the provided point, using the closest PSD "
-                 "matrix in L1-norm instead"
-             )
-
-     return mu, H_inv
-
-
- def sample_laplace_posterior(
-     mu: RaveledVars,
-     H_inv: np.ndarray,
-     model: pm.Model | None = None,
-     chains: int = 2,
-     draws: int = 500,
-     transform_samples: bool = False,
-     progressbar: bool = True,
-     random_seed: int | np.random.Generator | None = None,
-     compile_kwargs: dict | None = None,
- ) -> az.InferenceData:
-     """
-     Generate samples from a multivariate normal distribution with mean `mu` and covariance matrix `H_inv`.
-
-     Parameters
-     ----------
-     mu: RaveledVars
-         The MAP estimate of the model parameters.
-     H_inv: np.ndarray
-         The inverse Hessian matrix of the log-posterior evaluated at the MAP estimate, used as the
-         covariance matrix of the approximate posterior.
-     model : Model
-         A PyMC model.
-     chains : int
-         The number of sampling chains running in parallel. Default is 2.
-     draws : int
-         The number of samples to draw from the approximated posterior. Default is 500.
-     transform_samples : bool
-         Whether to transform the samples back to the original parameter space. Default is False.
-     progressbar : bool
-         Whether to display a progress bar during computations. Default is True.
-     random_seed: int | np.random.Generator | None
-         Seed for the random number generator or a numpy Generator for reproducibility.
-     compile_kwargs: dict, optional
-         Additional keyword arguments to pass to pytensor.function.
-
-     Returns
-     -------
-     idata: az.InferenceData
-         An InferenceData object containing the approximated posterior samples.
-     """
-     model = pm.modelcontext(model)
-     compile_kwargs = {} if compile_kwargs is None else compile_kwargs
-     rng = np.random.default_rng(random_seed)
-
-     posterior_dist = stats.multivariate_normal(
-         mean=mu.data, cov=H_inv, allow_singular=True, seed=rng
-     )
-
-     posterior_draws = posterior_dist.rvs(size=(chains, draws))
-     if mu.data.shape == (1,):
-         posterior_draws = np.expand_dims(posterior_draws, -1)
-
-     if transform_samples:
-         constrained_rvs, unconstrained_vector = _unconstrained_vector_to_constrained_rvs(model)
-         batched_values = pt.tensor(
-             "batched_values",
-             shape=(chains, draws, *unconstrained_vector.type.shape),
-             dtype=unconstrained_vector.type.dtype,
-         )
-         batched_rvs = pytensor.graph.vectorize_graph(
-             constrained_rvs, replace={unconstrained_vector: batched_values}
-         )
-
-         f_constrain = pm.compile(inputs=[batched_values], outputs=batched_rvs, **compile_kwargs)
-         posterior_draws = f_constrain(posterior_draws)
-     else:
-         info = mu.point_map_info
-         flat_shapes = [size for _, _, size, _ in info]
-         slices = [
-             slice(sum(flat_shapes[:i]), sum(flat_shapes[: i + 1])) for i in range(len(flat_shapes))
-         ]
-
-         posterior_draws = [
-             posterior_draws[..., idx].reshape((chains, draws, *shape)).astype(dtype)
-             for idx, (name, shape, _, dtype) in zip(slices, info)
-         ]
-
-     idata = laplace_draws_to_inferencedata(posterior_draws, model)
-     idata = add_fit_to_inferencedata(idata, mu, H_inv)
-     idata = add_data_to_inferencedata(idata, progressbar, model, compile_kwargs)
-
-     return idata
-
-
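Taken together, the functions above form a three-step pipeline: find_MAP locates the posterior mode, fit_mvn_at_MAP builds the Gaussian approximation around it, and sample_laplace_posterior draws from that Gaussian. A minimal end-to-end sketch against the removed 0.2.6 API; the toy model and data are illustrative:

    import numpy as np
    import pymc as pm

    from pymc_extras.inference.find_map import find_MAP
    from pymc_extras.inference.laplace import fit_mvn_at_MAP, sample_laplace_posterior

    y = np.random.default_rng(0).normal(loc=3.0, scale=1.5, size=100)

    with pm.Model() as m:
        mu = pm.Normal("mu", 0, 10)
        sigma = pm.Exponential("sigma", 1)
        pm.Normal("obs", mu=mu, sigma=sigma, observed=y)

        # Step 1: optimize to the MAP point
        optimized_point = find_MAP(method="BFGS", use_grad=True)

        # Step 2: mean and covariance of the Gaussian approximation at the MAP
        mu_map, H_inv = fit_mvn_at_MAP(optimized_point=optimized_point)

        # Step 3: draw from the resulting multivariate normal
        idata = sample_laplace_posterior(mu=mu_map, H_inv=H_inv, chains=2, draws=500)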
- def fit_laplace(
-     optimize_method: minimize_method | Literal["basinhopping"] = "BFGS",
-     *,
-     model: pm.Model | None = None,
-     use_grad: bool | None = None,
-     use_hessp: bool | None = None,
-     use_hess: bool | None = None,
-     initvals: dict | None = None,
-     random_seed: int | np.random.Generator | None = None,
-     return_raw: bool = False,
-     jitter_rvs: list[pt.TensorVariable] | None = None,
-     progressbar: bool = True,
-     include_transformed: bool = True,
-     gradient_backend: GradientBackend = "pytensor",
-     chains: int = 2,
-     draws: int = 500,
-     on_bad_cov: Literal["warn", "error", "ignore"] = "ignore",
-     fit_in_unconstrained_space: bool = False,
-     zero_tol: float = 1e-8,
-     diag_jitter: float | None = 1e-8,
-     optimizer_kwargs: dict | None = None,
-     compile_kwargs: dict | None = None,
- ) -> az.InferenceData:
-     """
-     Create a Laplace (quadratic) approximation for a posterior distribution.
-
-     This function generates a Laplace approximation for a given posterior distribution using a specified
-     number of draws. This is useful for obtaining a parametric approximation to the posterior distribution
-     that can be used for further analysis.
-
-     Parameters
-     ----------
-     model : pm.Model, optional
-         The PyMC model to be fit. If None, the current model context is used.
-     optimize_method : str
-         The optimization method to use. Valid choices are: Nelder-Mead, Powell, CG, BFGS, L-BFGS-B, TNC,
-         SLSQP, trust-constr, dogleg, trust-ncg, trust-exact, trust-krylov, and basinhopping.
-
-         See the scipy.optimize.minimize documentation for details.
-     use_grad : bool | None, optional
-         Whether to use gradients in the optimization. Defaults to None, which determines this automatically
-         based on the ``optimize_method``.
-     use_hessp : bool | None, optional
-         Whether to use Hessian-vector products in the optimization. Defaults to None, which determines this
-         automatically based on the ``optimize_method``.
-     use_hess : bool | None, optional
-         Whether to use the Hessian matrix in the optimization. Defaults to None, which determines this
-         automatically based on the ``optimize_method``.
-     initvals : None | dict, optional
-         Initial values for the model parameters, as str:ndarray key-value pairs. Partial initialization is
-         permitted. If None, the model's default initial values are used.
-     random_seed : None | int | np.random.Generator, optional
-         Seed for the random number generator or a numpy Generator for reproducibility.
-     return_raw : bool, optional, default False
-         Whether to also return the full output of `scipy.optimize.minimize`.
-     jitter_rvs : list of TensorVariables, optional
-         Variables whose initial values should be jittered. If None, all variables are jittered.
-     progressbar : bool, optional
-         Whether to display a progress bar during optimization. Defaults to True.
-     fit_in_unconstrained_space : bool, default False
-         Whether to fit the Laplace approximation in the unconstrained parameter space. If True, samples will
-         be drawn from a mean and covariance matrix computed at a point in the **unconstrained** parameter
-         space. Samples will then be transformed back to the original parameter space. This guarantees that
-         the samples will respect the domain of prior distributions (for example, samples from a Beta
-         distribution will be strictly between 0 and 1).
-
-         .. warning::
-             This argument should be considered highly experimental. It has not been verified whether this
-             method produces valid draws from the posterior. **Use at your own risk**.
-
-     gradient_backend : str, default "pytensor"
-         The backend to use for gradient computations. Must be one of "pytensor" or "jax".
-     chains : int, default 2
-         The number of chain dimensions to sample. Note that this is *not* the number of chains to run in
-         parallel, because the Laplace approximation is not an MCMC method. This argument exists to ensure
-         that outputs are compatible with the ArviZ library.
-     draws : int, default 500
-         The number of samples to draw from the approximated posterior. The total number of samples will be
-         chains * draws.
-     on_bad_cov : str, one of 'ignore', 'warn', or 'error', default: 'ignore'
-         What to do when ``H_inv`` (inverse Hessian) is not positive semi-definite.
-         If 'ignore' or 'warn', the closest positive-semi-definite matrix to ``H_inv`` (in L1 norm) will be
-         returned. If 'error', an error will be raised.
-     zero_tol : float
-         Value below which an element of the Hessian matrix is counted as 0.
-         This is used to stabilize the computation of the inverse Hessian matrix. Default is 1e-8.
-     diag_jitter : float | None
-         A small value added to the diagonal of the inverse Hessian matrix to ensure it is positive
-         semi-definite. If None, no jitter is added. Default is 1e-8.
-     optimizer_kwargs : dict, optional
-         Additional keyword arguments to pass to the ``scipy.optimize`` function being used. Unless
-         ``optimize_method = "basinhopping"``, ``scipy.optimize.minimize`` will be used. For ``basinhopping``,
-         ``scipy.optimize.basinhopping`` will be used. See the documentation of these functions for details.
-     compile_kwargs : dict, optional
-         Additional keyword arguments to pass to pytensor.function.
-
-     Returns
-     -------
-     :class:`~arviz.InferenceData`
-         An InferenceData object containing the approximated posterior samples.
-
-     Examples
-     --------
-     >>> from pymc_extras.inference.laplace import fit_laplace
-     >>> import numpy as np
-     >>> import pymc as pm
-     >>> import arviz as az
-     >>> y = np.array([2642, 3503, 4358] * 10)
-     >>> with pm.Model() as m:
-     ...     logsigma = pm.Uniform("logsigma", 1, 100)
-     ...     mu = pm.Uniform("mu", -10000, 10000)
-     ...     yobs = pm.Normal("y", mu=mu, sigma=pm.math.exp(logsigma), observed=y)
-     ...     idata = fit_laplace()
-
-     Notes
-     -----
-     This method of approximation may not be suitable for all types of posterior distributions,
-     especially those with significant skewness or multimodality.
-
-     See Also
-     --------
-     fit : Calling the inference function 'fit' like pmx.fit(method="laplace", model=m)
-         will forward the call to 'fit_laplace'.
-     """
-     compile_kwargs = {} if compile_kwargs is None else compile_kwargs
-     optimizer_kwargs = {} if optimizer_kwargs is None else optimizer_kwargs
-
-     optimized_point = find_MAP(
-         method=optimize_method,
-         model=model,
-         use_grad=use_grad,
-         use_hessp=use_hessp,
-         use_hess=use_hess,
-         initvals=initvals,
-         random_seed=random_seed,
-         return_raw=return_raw,
-         jitter_rvs=jitter_rvs,
-         progressbar=progressbar,
-         include_transformed=include_transformed,
-         gradient_backend=gradient_backend,
-         compile_kwargs=compile_kwargs,
-         **optimizer_kwargs,
-     )
-
-     mu, H_inv = fit_mvn_at_MAP(
-         optimized_point=optimized_point,
-         model=model,
-         on_bad_cov=on_bad_cov,
-         transform_samples=fit_in_unconstrained_space,
-         gradient_backend=gradient_backend,
-         zero_tol=zero_tol,
-         diag_jitter=diag_jitter,
-         compile_kwargs=compile_kwargs,
-     )
-
-     return sample_laplace_posterior(
-         mu=mu,
-         H_inv=H_inv,
-         model=model,
-         chains=chains,
-         draws=draws,
-         transform_samples=fit_in_unconstrained_space,
-         progressbar=progressbar,
-         random_seed=random_seed,
-         compile_kwargs=compile_kwargs,
-     )
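Because add_fit_to_inferencedata stores the approximation itself, the mean vector and covariance matrix can be read back from the "fit" group of the returned InferenceData. A brief sketch, again assuming the removed 0.2.6 API and an illustrative toy model:

    import numpy as np
    import pymc as pm

    from pymc_extras.inference.laplace import fit_laplace

    y = np.random.default_rng(1).normal(loc=2.0, size=50)

    with pm.Model() as m:
        mu = pm.Normal("mu", 0, 10)
        pm.Normal("obs", mu=mu, sigma=1, observed=y)
        idata = fit_laplace()

    print(idata.fit["mean_vector"].values)        # raveled MAP estimate
    print(idata.fit["covariance_matrix"].values)  # inverse Hessian at the MAP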
@@ -1,69 +0,0 @@
- try:
-     import torch
-
-     from gpytorch.utils.permutation import apply_permutation
- except ImportError as e:
-     raise ImportError("PyTorch and GPyTorch not found.") from e
-
- import numpy as np
-
-
- def pp(x):
-     # Pretty-print an array with fixed four-decimal precision.
-     return np.array2string(x, precision=4, floatmode="fixed")
-
-
- def pivoted_cholesky(mat: np.matrix, error_tol=1e-6, max_iter=np.inf):
-     """
-     Compute a low-rank pivoted Cholesky factorization of an N x N matrix ``mat``.
-
-     This replicates what is done in GPyTorch verbatim.
-     """
-     n = mat.shape[-1]
-     max_iter = int(min(max_iter, n))  # min() first, so the default max_iter=np.inf survives int()
-
-     d = np.array(np.diag(mat))
-     orig_error = np.max(d)
-     error = np.linalg.norm(d, 1) / orig_error
-     pi = np.arange(n)
-
-     L = np.zeros((max_iter, n))
-
-     m = 0
-     while m < max_iter and error > error_tol:
-         permuted_d = d[pi]
-         max_diag_idx = np.argmax(permuted_d[m:])
-         max_diag_idx = max_diag_idx + m
-         max_diag_val = permuted_d[max_diag_idx]
-         i = max_diag_idx
-
-         # swap pi_m and pi_i
-         pi[m], pi[i] = pi[i], pi[m]
-         pim = pi[m]
-
-         L[m, pim] = np.sqrt(max_diag_val)
-
-         if m + 1 < n:
-             row = apply_permutation(
-                 torch.from_numpy(mat), torch.tensor(pim), right_permutation=None
-             )  # left permutation just swaps the row
-             row = row.numpy().flatten()
-             pi_i = pi[m + 1 :]
-             L_m_new = row[pi_i]  # remaining entries of the pivoted row
-
-             if m > 0:
-                 L_prev = L[:m, pi_i]
-                 update = L[:m, pim]
-                 prod = update @ L_prev
-                 L_m_new = L_m_new - prod
-
-             L_m = L[m, :]
-             L_m_new = L_m_new / L_m[pim]
-             L_m[pi_i] = L_m_new
-
-             matrix_diag_current = d[pi_i]
-             d[pi_i] = matrix_diag_current - L_m_new**2
-
-             L[m, :] = L_m
-             error = np.linalg.norm(d[pi_i], 1) / orig_error
-         m = m + 1
-     return L, pi
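The returned factor L has its rows in pivot order but its columns in the original index order, so at convergence mat ≈ L.T @ L. A quick self-check sketch, assuming torch and gpytorch are installed (the module requires both) and using an illustrative random positive-definite matrix:

    import numpy as np

    rng = np.random.default_rng(42)
    A = rng.normal(size=(6, 6))
    mat = A @ A.T + 1e-3 * np.eye(6)  # symmetric positive-definite test matrix

    L, pi = pivoted_cholesky(mat, error_tol=1e-12, max_iter=6)

    print(pi)  # pivot (permutation) order chosen by the algorithm
    assert np.allclose(L.T @ L, mat, atol=1e-6)  # full-rank reconstruction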