cofi 0.1.3.dev2__tar.gz → 0.2.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (42)
  1. {cofi-0.1.3.dev2/src/cofi.egg-info → cofi-0.2.0}/PKG-INFO +5 -4
  2. {cofi-0.1.3.dev2 → cofi-0.2.0}/README.md +4 -3
  3. {cofi-0.1.3.dev2 → cofi-0.2.0}/src/cofi/_base_problem.py +25 -114
  4. cofi-0.2.0/src/cofi/_version.py +1 -0
  5. {cofi-0.1.3.dev2 → cofi-0.2.0}/src/cofi/tools/_base_inference_tool.py +4 -2
  6. {cofi-0.1.3.dev2 → cofi-0.2.0}/src/cofi/tools/_scipy_lstsq.py +1 -2
  7. {cofi-0.1.3.dev2 → cofi-0.2.0}/src/cofi/tools/_scipy_opt_min.py +5 -5
  8. cofi-0.2.0/src/cofi/utils/__init__.py +26 -0
  9. cofi-0.2.0/src/cofi/utils/_reg_base.py +214 -0
  10. cofi-0.2.0/src/cofi/utils/_reg_lp_norm.py +371 -0
  11. cofi-0.2.0/src/cofi/utils/_reg_model_cov.py +161 -0
  12. cofi-0.2.0/src/cofi/version.py +1 -0
  13. {cofi-0.1.3.dev2 → cofi-0.2.0/src/cofi.egg-info}/PKG-INFO +5 -4
  14. {cofi-0.1.3.dev2 → cofi-0.2.0}/src/cofi.egg-info/SOURCES.txt +7 -2
  15. cofi-0.2.0/tests/test_base_problem_basics.py +92 -0
  16. cofi-0.2.0/tests/test_base_problem_for_optimization.py +221 -0
  17. cofi-0.2.0/tests/test_base_problem_for_sampling.py +39 -0
  18. cofi-0.2.0/tests/test_base_problem_misc.py +173 -0
  19. cofi-0.1.3.dev2/src/cofi/_version.py +0 -1
  20. cofi-0.1.3.dev2/src/cofi/utils/__init__.py +0 -13
  21. cofi-0.1.3.dev2/src/cofi/utils/_regularization.py +0 -463
  22. cofi-0.1.3.dev2/src/cofi/version.py +0 -1
  23. cofi-0.1.3.dev2/tests/test_base_problem.py +0 -604
  24. {cofi-0.1.3.dev2 → cofi-0.2.0}/LICENCE +0 -0
  25. {cofi-0.1.3.dev2 → cofi-0.2.0}/pyproject.toml +0 -0
  26. {cofi-0.1.3.dev2 → cofi-0.2.0}/setup.cfg +0 -0
  27. {cofi-0.1.3.dev2 → cofi-0.2.0}/src/cofi/__init__.py +0 -0
  28. {cofi-0.1.3.dev2 → cofi-0.2.0}/src/cofi/_exceptions.py +0 -0
  29. {cofi-0.1.3.dev2 → cofi-0.2.0}/src/cofi/_inversion.py +0 -0
  30. {cofi-0.1.3.dev2 → cofi-0.2.0}/src/cofi/_inversion_options.py +0 -0
  31. {cofi-0.1.3.dev2 → cofi-0.2.0}/src/cofi/tools/__init__.py +0 -0
  32. {cofi-0.1.3.dev2 → cofi-0.2.0}/src/cofi/tools/_cofi_simple_newton.py +0 -0
  33. {cofi-0.1.3.dev2 → cofi-0.2.0}/src/cofi/tools/_emcee.py +0 -0
  34. {cofi-0.1.3.dev2 → cofi-0.2.0}/src/cofi/tools/_pytorch_optim.py +0 -0
  35. {cofi-0.1.3.dev2 → cofi-0.2.0}/src/cofi/tools/_scipy_opt_lstsq.py +0 -0
  36. {cofi-0.1.3.dev2 → cofi-0.2.0}/src/cofi.egg-info/dependency_links.txt +0 -0
  37. {cofi-0.1.3.dev2 → cofi-0.2.0}/src/cofi.egg-info/requires.txt +0 -0
  38. {cofi-0.1.3.dev2 → cofi-0.2.0}/src/cofi.egg-info/top_level.txt +0 -0
  39. {cofi-0.1.3.dev2 → cofi-0.2.0}/tests/test_deprecation.py +0 -0
  40. {cofi-0.1.3.dev2 → cofi-0.2.0}/tests/test_inversion.py +0 -0
  41. {cofi-0.1.3.dev2 → cofi-0.2.0}/tests/test_inversion_options.py +0 -0
  42. {cofi-0.1.3.dev2 → cofi-0.2.0}/tests/test_inversion_result.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: cofi
3
- Version: 0.1.3.dev2
3
+ Version: 0.2.0
4
4
  Summary: Common Framework for Inference
5
5
  Author: InLab, CoFI development team
6
6
  Keywords: inversion,inference,python package,geoscience,geophysics
@@ -24,7 +24,6 @@ License-File: LICENCE
24
24
 
25
25
  # <img src="https://raw.githubusercontent.com/inlab-geo/cofi/main/docs/source/_static/latte_art_cropped.png" width="5%" style="vertical-align:bottom"/> CoFI (Common Framework for Inference)
26
26
 
27
-
28
27
  [![PyPI version](https://img.shields.io/pypi/v/cofi?logo=pypi&style=flat-square&color=cae9ff&labelColor=f8f9fa)](https://pypi.org/project/cofi/)
29
28
  [![Conda Version](https://img.shields.io/conda/vn/conda-forge/cofi.svg?logo=condaforge&style=flat-square&color=cce3de&labelColor=f8f9fa&logoColor=344e41)](https://anaconda.org/conda-forge/cofi)
30
29
  [![Documentation Status](https://img.shields.io/readthedocs/cofi?logo=readthedocs&style=flat-square&color=fed9b7&labelColor=f8f9fa&logoColor=eaac8b)](https://cofi.readthedocs.io/en/latest/?badge=latest)
@@ -32,6 +31,9 @@ License-File: LICENCE
32
31
  [![Slack](https://img.shields.io/badge/Slack-InLab_community-4A154B?logo=slack&style=flat-square&color=cdb4db&labelColor=f8f9fa&logoColor=9c89b8)](https://join.slack.com/t/inlab-community/shared_invite/zt-1ejny069z-v5ZyvP2tDjBR42OAu~TkHg)
33
32
  <!-- [![Wheels](https://img.shields.io/pypi/wheel/cofi)](https://pypi.org/project/cofi/) -->
34
33
 
34
+ > Related repositories by [InLab](https://inlab.edu.au/community/):
35
+ > - [CoFI Examples](https://github.com/inlab-geo/cofi-examples)
36
+ > - [Espresso](https://github.com/inlab-geo/espresso)
35
37
 
36
38
  ## Introduction
37
39
 
@@ -39,8 +41,7 @@ CoFI (Common Framework for Inference) is an open source initiative for interfaci
39
41
 
40
42
  With a mission to bridge the gap between the domain expertise and the inference expertise, CoFI provides an interface across a wide range of inference algorithms from different sources, underpinned by a rich set of domain relevant [examples](https://cofi.readthedocs.io/en/latest/examples/generated/index.html).
41
43
 
42
- > This project and [documentation](https://cofi.readthedocs.io/en/latest/) are under initial development stage. Please feel free to contact us for feedback or issues!
43
-
44
+ Read [the documentation](https://cofi.readthedocs.io/en/latest/), and let us know your feedback or any issues!
44
45
 
45
46
  ## Installation
46
47
 
@@ -2,7 +2,6 @@
2
2
 
3
3
  # <img src="https://raw.githubusercontent.com/inlab-geo/cofi/main/docs/source/_static/latte_art_cropped.png" width="5%" style="vertical-align:bottom"/> CoFI (Common Framework for Inference)
4
4
 
5
-
6
5
  [![PyPI version](https://img.shields.io/pypi/v/cofi?logo=pypi&style=flat-square&color=cae9ff&labelColor=f8f9fa)](https://pypi.org/project/cofi/)
7
6
  [![Conda Version](https://img.shields.io/conda/vn/conda-forge/cofi.svg?logo=condaforge&style=flat-square&color=cce3de&labelColor=f8f9fa&logoColor=344e41)](https://anaconda.org/conda-forge/cofi)
8
7
  [![Documentation Status](https://img.shields.io/readthedocs/cofi?logo=readthedocs&style=flat-square&color=fed9b7&labelColor=f8f9fa&logoColor=eaac8b)](https://cofi.readthedocs.io/en/latest/?badge=latest)
@@ -10,6 +9,9 @@
10
9
  [![Slack](https://img.shields.io/badge/Slack-InLab_community-4A154B?logo=slack&style=flat-square&color=cdb4db&labelColor=f8f9fa&logoColor=9c89b8)](https://join.slack.com/t/inlab-community/shared_invite/zt-1ejny069z-v5ZyvP2tDjBR42OAu~TkHg)
11
10
  <!-- [![Wheels](https://img.shields.io/pypi/wheel/cofi)](https://pypi.org/project/cofi/) -->
12
11
 
12
+ > Related repositories by [InLab](https://inlab.edu.au/community/):
13
+ > - [CoFI Examples](https://github.com/inlab-geo/cofi-examples)
14
+ > - [Espresso](https://github.com/inlab-geo/espresso)
13
15
 
14
16
  ## Introduction
15
17
 
@@ -17,8 +19,7 @@ CoFI (Common Framework for Inference) is an open source initiative for interfaci
17
19
 
18
20
  With a mission to bridge the gap between the domain expertise and the inference expertise, CoFI provides an interface across a wide range of inference algorithms from different sources, underpinned by a rich set of domain relevant [examples](https://cofi.readthedocs.io/en/latest/examples/generated/index.html).
19
21
 
20
- > This project and [documentation](https://cofi.readthedocs.io/en/latest/) are under initial development stage. Please feel free to contact us for feedback or issues!
21
-
22
+ Read [the documentation](https://cofi.readthedocs.io/en/latest/), and let us know your feedback or any issues!
22
23
 
23
24
  ## Installation
24
25
 
@@ -5,6 +5,7 @@ import json
5
5
 
6
6
  import numpy as np
7
7
 
8
+ from .utils import BaseRegularization
8
9
  from ._exceptions import (
9
10
  DimensionMismatchError,
10
11
  InvalidOptionError,
@@ -197,7 +198,6 @@ class BaseProblem:
197
198
  BaseProblem.data_misfit
198
199
  BaseProblem.regularization
199
200
  BaseProblem.regularization_matrix
200
- BaseProblem.regularization_factor
201
201
  BaseProblem.forward
202
202
  BaseProblem.name
203
203
  BaseProblem.data
@@ -230,7 +230,6 @@ class BaseProblem:
230
230
  "data_misfit",
231
231
  "regularization",
232
232
  "regularization_matrix",
233
- "regularization_factor",
234
233
  "forward",
235
234
  "data",
236
235
  "data_covariance",
@@ -968,8 +967,7 @@ class BaseProblem:
968
967
 
969
968
  def set_regularization(
970
969
  self,
971
- regularization: Union[str, Callable[[np.ndarray], Number]],
972
- regularization_factor: Number = 1,
970
+ regularization: Union[Callable[[np.ndarray], Number], BaseRegularization],
973
971
  regularization_matrix: Union[
974
972
  np.ndarray, Callable[[np.ndarray], np.ndarray]
975
973
  ] = None,
@@ -989,11 +987,6 @@ class BaseProblem:
989
987
  regularization : str or (function - np.ndarray -> Number)
990
988
  either a string from pre-built functions above, or a regularization function that
991
989
  matches :meth:`regularization` in signature.
992
- regularization_factor : Number, optional
993
- the regularization factor (lamda) that adjusts the ratio of the regularization
994
- term over the data misfit, by default 1. If ``regularization`` and ``data_misfit``
995
- are set but ``objective`` isn't, then we will generate ``objective`` function as
996
- following: :math:`\text{objective}(model)=\text{data_misfit}(model)+\text{factor}\times\text{regularization}(model)`
997
990
  regularization_matrix : np.ndarray or (function - np.ndarray -> np.ndarray)
998
991
  a matrix of shape ``(model_size, model_size)``, or a function that takes in
999
992
  a model and calculates the (weighting) matrix.
@@ -1010,11 +1003,6 @@ class BaseProblem:
1010
1003
  kwargs : dict, optional
1011
1004
  extra dict of keyword arguments for regularization function
1012
1005
 
1013
- Raises
1014
- ------
1015
- InvalidOptionError
1016
- when you've passed in a string not in our supported regularization list
1017
-
1018
1006
  Examples
1019
1007
  --------
1020
1008
 
@@ -1023,87 +1011,41 @@ class BaseProblem:
1023
1011
  >>> from cofi import BaseProblem
1024
1012
  >>> inv_problem = BaseProblem()
1025
1013
 
1026
- 1. Example with an L1 norm
1027
-
1028
- >>> inv_problem.set_regularization(1)
1029
- >>> inv_problem.regularization([1,1])
1030
- 2
1031
-
1032
- 2. Example with an inf norm
1033
-
1034
- >>> inv_problem.set_regularization("inf")
1035
- >>> inv_problem.regularization([1,1])
1036
- 1
1037
-
1038
- 3. Example with a custom regularization function
1014
+ 1. Example with a custom regularization function
1039
1015
 
1040
1016
  >>> inv_problem.set_regularization(lambda x: sum(x))
1041
1017
  >>> inv_problem.regularization([1,1])
1042
1018
  2
1043
1019
 
1044
- 4. Example with an L2 norm and regularization factor of 0.5 (by default 1)
1020
+ 2. Example with a custom regularization + a regularization matrix
1045
1021
 
1046
- >>> inv_problem.set_regularization(2, 0.5)
1022
+ >>> inv_problem.set_regularization(lambda x: np.sum(x**2), np.eye(3))
1047
1023
  >>> inv_problem.regularization([1,1])
1048
- 0.7071067811865476
1049
-
1050
- 5. Example with a regularization matrix
1051
-
1052
- >>> inv_problem.set_regularization(2, 0.5, np.array([[2,0], [0,1]]))
1053
- >>> inv_problem.regularization([1,1])
1054
- 1.118033988749895
1024
+ 2
1055
1025
  """
1056
- # preprocess regularization_matrix
1057
- if np.ndim(regularization_matrix) != 0:
1026
+ # preprocess regularization_matrix if there is one
1027
+ _reg_matrix = None
1028
+ if regularization_matrix is not None:
1029
+ _reg_matrix = regularization_matrix
1030
+ elif isinstance(regularization, BaseRegularization) and hasattr(
1031
+ regularization, "matrix"
1032
+ ):
1033
+ _reg_matrix = regularization.matrix
1034
+ # wrap regularization_matrix as a function
1035
+ if _reg_matrix is not None and np.ndim(_reg_matrix) != 0:
1058
1036
  self.regularization_matrix = _FunctionWrapper(
1059
- "regularization_matrix", _matrix_to_func, args=[regularization_matrix]
1037
+ "regularization_matrix", _matrix_to_func, args=[_reg_matrix]
1060
1038
  )
1061
- elif callable(regularization_matrix):
1039
+ elif _reg_matrix is not None and callable(_reg_matrix):
1062
1040
  self.regularization_matrix = _FunctionWrapper(
1063
- "regularization_matrix", regularization_matrix
1041
+ "regularization_matrix", _reg_matrix
1064
1042
  )
1065
1043
  else:
1066
1044
  self.regularization_matrix = None
1067
- # preprocess regularization function without lambda
1068
- if isinstance(regularization, (Number, str)) or not regularization:
1069
- order = regularization
1070
- if (
1071
- isinstance(order, str)
1072
- and order not in ["fro", "nuc", "inf", "-inf"]
1073
- or isinstance(order, Number)
1074
- and order < 0
1075
- ):
1076
- raise InvalidOptionError(
1077
- name="regularization order",
1078
- invalid_option=order,
1079
- valid_options=(
1080
- "[None, 'fro', 'nuc', numpy.inf, -numpy.inf] or any positive"
1081
- " number"
1082
- ),
1083
- )
1084
- elif isinstance(order, str) and order in ["inf", "-inf"]:
1085
- order = float(order)
1086
- _reg = _FunctionWrapper(
1087
- "regularization_none_lamda", np.linalg.norm, args=[order]
1088
- )
1089
- else:
1090
- _reg = _FunctionWrapper(
1091
- "regularization_none_lamda", regularization, args, kwargs
1092
- )
1093
- # wrapper function that calculates: lambda * raw regularization value
1094
- self._regularization_factor = regularization_factor
1095
- if self.regularization_matrix is None:
1096
- self.regularization = _FunctionWrapper(
1097
- "regularization",
1098
- _regularization_with_lamda,
1099
- args=[_reg, regularization_factor],
1100
- )
1101
- else:
1102
- self.regularization = _FunctionWrapper(
1103
- "regularization",
1104
- _regularization_with_lamda_n_matrix,
1105
- args=[_reg, regularization_factor, self.regularization_matrix],
1106
- )
1045
+ # process regularization function
1046
+ self.regularization = _FunctionWrapper(
1047
+ "regularization", regularization, args, kwargs
1048
+ )
1107
1049
  # update some autogenerated functions (as usual)
1108
1050
  self._update_autogen("regularization")
1109
1051
 
@@ -1444,24 +1386,6 @@ class BaseProblem:
1444
1386
  return self._blobs_dtype
1445
1387
  raise NotDefinedError(needs="blobs name and type")
1446
1388
 
1447
- @property
1448
- def regularization_factor(self) -> Number:
1449
- r"""regularization factor (lambda) that adjusts weights of the regularization
1450
- term
1451
-
1452
- Raises
1453
- ------
1454
- NotDefinedError
1455
- when this property has not been defined (by
1456
- :meth:`set_regularization`
1457
- """
1458
- if (
1459
- hasattr(self, "_regularization_factor")
1460
- and self._regularization_factor is not None
1461
- ):
1462
- return self._regularization_factor
1463
- raise NotDefinedError(needs="regularization_factor (lamda)")
1464
-
1465
1389
  @property
1466
1390
  def bounds(self):
1467
1391
  r"""TODO: document me!
@@ -1597,11 +1521,6 @@ class BaseProblem:
1597
1521
  r"""indicates whether :meth:`blobs_dtype` has been defined"""
1598
1522
  return self._check_property_defined("blobs_dtype")
1599
1523
 
1600
- @property
1601
- def regularization_factor_defined(self) -> bool:
1602
- r"""indicates whether :meth:`regularization_factor` has been defined"""
1603
- return self._check_property_defined("regularization_factor")
1604
-
1605
1524
  @property
1606
1525
  def bounds_defined(self) -> bool:
1607
1526
  r"""indicates whether :meth:`bounds` has been defined"""
@@ -1718,12 +1637,12 @@ class BaseProblem:
1718
1637
  if self.data_covariance_inv_defined:
1719
1638
  if _is_diag(self.data_covariance_inv):
1720
1639
  weighted_res = np.diag(self.data_covariance_inv) * res
1721
- return np.sum(np.square(weighted_res))
1640
+ return res @ weighted_res
1722
1641
  else:
1723
1642
  return res.T @ self.data_covariance_inv @ res
1724
1643
  elif self.data_covariance_defined and _is_diag(self.data_covariance):
1725
1644
  weighted_res = res / np.diag(self.data_covariance)
1726
- return np.sum(np.square(weighted_res))
1645
+ return res @ weighted_res
1727
1646
  else:
1728
1647
  return np.sum(np.square(res))
1729
1648
  except Exception as exception:
@@ -1895,14 +1814,6 @@ def _jacobian_times_vector_from_jcb(model, vector, jacobian):
1895
1814
  ) from exception
1896
1815
 
1897
1816
 
1898
- def _regularization_with_lamda(model, reg_func, lamda):
1899
- return lamda * reg_func(model)
1900
-
1901
-
1902
- def _regularization_with_lamda_n_matrix(model, reg_func, lamda, reg_matrix_func):
1903
- return lamda * reg_func(reg_matrix_func(model) @ model)
1904
-
1905
-
1906
1817
  def _matrix_to_func(_, matrix):
1907
1818
  return matrix
1908
1819
 
@@ -0,0 +1 @@
1
+ __version__ = "0.2.0"
@@ -310,8 +310,10 @@ def error_handler(when, context):
310
310
  return func(*args, **kwargs)
311
311
  except Exception as e:
312
312
  raise CofiError(
313
- f"error ocurred {when} ({context}). Check exception details "
314
- "from message above.",
313
+ (
314
+ f"error ocurred {when} ({context}). Check exception details "
315
+ "from message above."
316
+ ),
315
317
  ) from e
316
318
 
317
319
  return wrapped_func
@@ -119,7 +119,6 @@ class ScipyLstSq(BaseInferenceTool):
119
119
  )
120
120
  # get lamda and L matrix if needed
121
121
  if self._params["with_tikhonov"]:
122
- self._lamda = inv_problem.regularization_factor
123
122
  if inv_problem.regularization_matrix_defined:
124
123
  try:
125
124
  _L = inv_problem.regularization_matrix(dummy_model)
@@ -133,7 +132,7 @@ class ScipyLstSq(BaseInferenceTool):
133
132
  self._components_used.append("regularization_matrix")
134
133
  else:
135
134
  self._LtL = np.identity(self._G.shape[1])
136
- self._a += self._lamda * self._LtL
135
+ self._a += self._LtL
137
136
 
138
137
  def __call__(self) -> dict:
139
138
  res_p, residual, rank, singular_vals = self._call_lstsq()
@@ -6,20 +6,20 @@ from . import BaseInferenceTool, error_handler
6
6
  # Official documentation for scipy.optimize.minimize
7
7
  # https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.minimize.html
8
8
 
9
- # 'jac' is only for:
9
+ # 'jac' will be used when choosing the following methods:
10
10
  # CG, BFGS, Newton-CG, L-BFGS-B, TNC, SLSQP, dogleg, trust-ncg, trust-krylov, trust-exact
11
11
  # and trust-constr
12
12
 
13
- # 'hess' is only for:
13
+ # 'hess' will be used when choosing the following methods:
14
14
  # Newton-CG, dogleg, trust-ncg, trust-krylov, trust-exact and trust-constr
15
15
 
16
- # 'hessp' is only for:
16
+ # 'hessp' will be used when choosing the following methods:
17
17
  # Newton-CG, trust-ncg, trust-krylov, trust-constr
18
18
 
19
- # 'bounds' is only for:
19
+ # 'bounds' will be used when choosing the following methods:
20
20
  # Nelder-Mead, L-BFGS-B, TNC, SLSQP, Powell, and trust-constr
21
21
 
22
- # 'constraints' is only for:
22
+ # 'constraints' will be used when choosing the following methods:
23
23
  # COBYLA, SLSQP and trust-constr
24
24
 
25
25
  # other arguments include: tol, options, callback
@@ -0,0 +1,26 @@
1
+ r"""Utility classes and functions (e.g. to generate regularization terms and more)
2
+
3
+ The class inheritance of regularization classes:
4
+
5
+ .. mermaid::
6
+
7
+ graph TD;
8
+ BaseRegularization --> LpNormRegularization;
9
+ LpNormRegularization --> QuadraticReg;
10
+ BaseRegularization --> ModelCovariance;
11
+ ModelCovariance --> GaussianPrior;
12
+
13
+ """
14
+
15
+ from ._reg_base import BaseRegularization
16
+ from ._reg_lp_norm import LpNormRegularization, QuadraticReg
17
+ from ._reg_model_cov import ModelCovariance, GaussianPrior
18
+
19
+
20
+ __all__ = [
21
+ "BaseRegularization",
22
+ "LpNormRegularization",
23
+ "QuadraticReg",
24
+ "ModelCovariance",
25
+ "GaussianPrior",
26
+ ]
@@ -0,0 +1,214 @@
1
+ from abc import abstractmethod, ABCMeta
2
+ from numbers import Number
3
+ from functools import reduce
4
+ import numpy as np
5
+
6
+ from .._exceptions import DimensionMismatchError
7
+
8
+
9
class BaseRegularization(metaclass=ABCMeta):
    r"""Base class for a regularization term

    Check :class:`QuadraticReg` for a concrete example.

    .. rubric:: Basic interface

    The basic properties / methods for a regularization term in ``cofi.utils``
    include the following:

    .. autosummary::
        BaseRegularization.model_shape
        BaseRegularization.model_size
        BaseRegularization.reg
        BaseRegularization.gradient
        BaseRegularization.hessian
        BaseRegularization.__call__

    Subclasses must implement :attr:`model_shape`, :meth:`reg`, :meth:`gradient` and
    :meth:`hessian`; :attr:`model_size` and :meth:`__call__` are derived from them.

    .. rubric:: Adding two terms

    Two instances of ``BaseRegularization`` can also be added together using the ``+``
    operator:

    .. autosummary::
        BaseRegularization.__add__

    .. rubric:: Scaling a term

    A regularization term can be multiplied by a number, on either side, using the
    ``*`` operator:

    .. autosummary::
        BaseRegularization.__rmul__
        BaseRegularization.__mul__
    """

    def __init__(
        self,
    ):
        pass

    @property
    @abstractmethod
    def model_shape(self) -> tuple:
        """the shape of models that current regularization function accepts"""
        raise NotImplementedError

    @property
    def model_size(self) -> Number:
        """the number of unknowns that current regularization function accepts

        Computed as the product of the entries of :attr:`model_shape` (``1`` for an
        empty shape, thanks to the initial value of the reduction).
        """
        return reduce(lambda a, b: a * b, np.array(self.model_shape), 1)

    def __call__(self, model: np.ndarray) -> Number:
        r"""a class instance itself can also be called as a function

        It works exactly the same as :meth:`reg`.

        In other words, the following two usages are exactly the same::

            >>> my_reg = QuadraticReg(model_shape=(3,), weighting_matrix="damping")
            >>> my_reg_value = my_reg(np.array([1,2,3]))            # usage 1
            >>> my_reg_value = my_reg.reg(np.array([1,2,3]))        # usage 2
        """
        return self.reg(model)

    @abstractmethod
    def reg(self, model: np.ndarray) -> Number:
        """the regularization function value given a model to evaluate"""
        raise NotImplementedError

    @abstractmethod
    def gradient(self, model: np.ndarray) -> np.ndarray:
        """the gradient of regularization function with respect to model given a model

        The usual size for the gradient is :math:`(M,)` where :math:`M` is the number
        of model parameters
        """
        raise NotImplementedError

    @abstractmethod
    def hessian(self, model: np.ndarray) -> np.ndarray:
        """the hessian of regularization function with respect to model given a model

        The usual size for the Hessian is :math:`(M,M)` where :math:`M` is the number
        of model parameters
        """
        raise NotImplementedError

    def __add__(self, other_reg):
        r"""Adds two regularization terms

        Parameters
        ----------
        other_reg : BaseRegularization
            the second argument of "+" operator; must also be a
            :class:`BaseRegularization` instance

        Returns
        -------
        BaseRegularization
            a regularization term ``resRegularization`` such that:

            - :math:`\text{resRegularization.reg}(m)=\text{self.reg}(m)+\text{other_reg.reg}(m)`
            - :math:`\text{resRegularization.gradient}(m)=\text{self.gradient}(m)+\text{other_reg.gradient}(m)`
            - :math:`\text{resRegularization.hessian}(m)=\text{self.hessian}(m)+\text{other_reg.hessian}(m)`

            Its :attr:`model_shape` is the one of ``self``.

        Raises
        ------
        TypeError
            when the ``other_reg`` is not a regularization term generated by CoFI Utils
        DimensionMismatchError
            when the ``other_reg`` doesn't accept model_size that matches the one of
            ``self``

        Examples
        --------

        >>> from cofi import BaseProblem
        >>> from cofi.utils import QuadraticReg
        >>> reg1 = QuadraticReg(model_shape=(3,), weighting_matrix="damping")
        >>> reg2 = QuadraticReg(model_shape=(3,), weighting_matrix="smoothing")
        >>> my_problem = BaseProblem()
        >>> my_problem.set_regularization(reg1 + reg2)

        """
        if not isinstance(other_reg, BaseRegularization):
            raise TypeError(
                f"unsupported operand type(s) for +: '{self.__class__.__name__}' "
                f"and '{other_reg.__class__.__name__}'"
            )
        if self.model_size != other_reg.model_size:
            raise DimensionMismatchError(
                entered_name="the second regularization term",
                entered_dimension=other_reg.model_size,
                expected_source="the first regularization term",
                expected_dimension=self.model_size,
            )
        # Inside the composite's methods ``self`` is the composite instance, so the
        # first operand's shape and bound methods are captured in locals here.
        tmp_model_shape = self.model_shape
        tmp_reg = self.reg
        tmp_grad = self.gradient
        tmp_hess = self.hessian

        class CompositeRegularization(BaseRegularization):
            @property
            def model_shape(self):
                return tmp_model_shape

            def reg(self, model):
                return tmp_reg(model) + other_reg(model)

            def gradient(self, model):
                return tmp_grad(model) + other_reg.gradient(model)

            def hessian(self, model):
                return tmp_hess(model) + other_reg.hessian(model)

        return CompositeRegularization()

    def __rmul__(self, coefficient):
        r"""Multiply a regularization term with a constant number

        Parameters
        ----------
        coefficient : Number
            the first argument of "*" operator; must be a Number

        Returns
        -------
        BaseRegularization
            a regularization term ``resRegularization`` such that:

            - :math:`\text{resRegularization.reg}(m)=\text{coefficient}\times\text{self.reg}(m)`
            - :math:`\text{resRegularization.gradient}(m)=\text{coefficient}\times\text{self.gradient}(m)`
            - :math:`\text{resRegularization.hessian}(m)=\text{coefficient}\times\text{self.hessian}(m)`

        Raises
        ------
        TypeError
            when the ``coefficient`` is not of a python Number type

        Examples
        --------

        >>> from cofi import BaseProblem
        >>> from cofi.utils import QuadraticReg
        >>> reg = QuadraticReg(model_shape=(3,), weighting_matrix="damping")
        >>> my_problem = BaseProblem()
        >>> my_problem.set_regularization(10 * reg)

        """
        if not isinstance(coefficient, Number):
            raise TypeError(
                f"unsupported operand type(s) for *: '{coefficient.__class__.__name__}'"
                f" and '{self.__class__.__name__}'"
            )
        # Inside the composite's methods ``self`` is the composite instance, so the
        # scaled term's shape and bound methods are captured in locals here.
        tmp_model_shape = self.model_shape
        tmp_reg = self.reg
        tmp_grad = self.gradient
        tmp_hess = self.hessian

        class CompositeRegularization(BaseRegularization):
            @property
            def model_shape(self):
                return tmp_model_shape

            def reg(self, model):
                return coefficient * tmp_reg(model)

            def gradient(self, model):
                return coefficient * tmp_grad(model)

            def hessian(self, model):
                return coefficient * tmp_hess(model)

        return CompositeRegularization()

    def __mul__(self, coefficient):
        r"""Multiply a regularization term with a constant number placed on the right

        ``reg * coefficient`` gives the same term as ``coefficient * reg`` (see
        :meth:`__rmul__`).

        Parameters
        ----------
        coefficient : Number
            the second argument of "*" operator

        Returns
        -------
        BaseRegularization
            the scaled regularization term, or ``NotImplemented`` when
            ``coefficient`` is not a Number. In that case Python tries the other
            operand's reflected method before raising ``TypeError`` itself.

        Examples
        --------

        >>> from cofi.utils import QuadraticReg
        >>> reg = QuadraticReg(model_shape=(3,), weighting_matrix="damping")
        >>> scaled = reg * 10
        """
        if not isinstance(coefficient, Number):
            return NotImplemented
        return self.__rmul__(coefficient)