pysolverkit 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pysolverkit/__init__.py +70 -0
- pysolverkit/enums.py +91 -0
- pysolverkit/functions/__init__.py +17 -0
- pysolverkit/functions/base.py +445 -0
- pysolverkit/functions/elementary.py +192 -0
- pysolverkit/functions/fem2d.py +128 -0
- pysolverkit/functions/multivariate.py +57 -0
- pysolverkit/linalg/__init__.py +5 -0
- pysolverkit/linalg/linear_system.py +134 -0
- pysolverkit/linalg/matrix.py +42 -0
- pysolverkit/linalg/vector.py +46 -0
- pysolverkit/ode/__init__.py +12 -0
- pysolverkit/ode/base.py +6 -0
- pysolverkit/ode/first_order.py +297 -0
- pysolverkit/ode/second_order.py +323 -0
- pysolverkit/util.py +62 -0
- pysolverkit-0.1.0.dist-info/METADATA +58 -0
- pysolverkit-0.1.0.dist-info/RECORD +20 -0
- pysolverkit-0.1.0.dist-info/WHEEL +5 -0
- pysolverkit-0.1.0.dist-info/top_level.txt +1 -0
pysolverkit/__init__.py
ADDED
|
@@ -0,0 +1,70 @@
|
|
|
1
|
+
from .functions import (
|
|
2
|
+
Function,
|
|
3
|
+
Polynomial,
|
|
4
|
+
Exponent,
|
|
5
|
+
Sin,
|
|
6
|
+
Cos,
|
|
7
|
+
Tan,
|
|
8
|
+
Log,
|
|
9
|
+
MultiVariableFunction,
|
|
10
|
+
BivariateFunction,
|
|
11
|
+
FEM2D,
|
|
12
|
+
)
|
|
13
|
+
from .linalg import Vector, Matrix, LinearSystem
|
|
14
|
+
from .ode import (
|
|
15
|
+
OrdinaryDifferentialEquation,
|
|
16
|
+
LinearODE,
|
|
17
|
+
FirstOrderLinearODE,
|
|
18
|
+
SecondOrderLinearODE_BVP,
|
|
19
|
+
SecondOrderODE_IVP,
|
|
20
|
+
SecondOrderODE_BVP,
|
|
21
|
+
)
|
|
22
|
+
from .util import Util
|
|
23
|
+
from .enums import (
|
|
24
|
+
DifferentiationMethod,
|
|
25
|
+
IntegrationMethod,
|
|
26
|
+
RootFindingMethod,
|
|
27
|
+
InterpolationMethod,
|
|
28
|
+
InterpolationForm,
|
|
29
|
+
ODEMethod,
|
|
30
|
+
BVPMethod,
|
|
31
|
+
NonlinearBVPMethod,
|
|
32
|
+
LinearSolverMethod,
|
|
33
|
+
)
|
|
34
|
+
|
|
35
|
+
# Public API re-exported from the subpackages; this list is what
# ``from pysolverkit import *`` binds.
__all__ = [
    # Functions
    "Function", "Polynomial", "Exponent", "Sin", "Cos", "Tan", "Log",
    "MultiVariableFunction", "BivariateFunction", "FEM2D",
    # Linear algebra
    "Vector", "Matrix", "LinearSystem",
    # ODEs
    "OrdinaryDifferentialEquation", "LinearODE", "FirstOrderLinearODE",
    "SecondOrderLinearODE_BVP", "SecondOrderODE_IVP", "SecondOrderODE_BVP",
    # Utilities
    "Util",
    # Enums
    "DifferentiationMethod", "IntegrationMethod", "RootFindingMethod",
    "InterpolationMethod", "InterpolationForm", "ODEMethod",
    "BVPMethod", "NonlinearBVPMethod", "LinearSolverMethod",
]
|
pysolverkit/enums.py
ADDED
|
@@ -0,0 +1,91 @@
|
|
|
1
|
+
from enum import Enum
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
class DifferentiationMethod(Enum):
    """Numerical differentiation method.

    Selects the finite-difference scheme. Members can also be looked up by
    their string value, e.g. ``DifferentiationMethod("central")``.
    """

    FORWARD = "forward"  # one-sided difference using x and x + h
    BACKWARD = "backward"  # one-sided difference using x - h and x
    CENTRAL = "central"  # symmetric difference using x - h and x + h
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class IntegrationMethod(Enum):
    """Numerical integration (quadrature) method.

    Members can also be looked up by their string value, e.g.
    ``IntegrationMethod("simpson")``.
    """

    RECTANGULAR = "rectangular"  # left-endpoint rectangle rule
    MIDPOINT = "midpoint"
    TRAPEZOIDAL = "trapezoidal"
    SIMPSON = "simpson"
    GAUSS = "gauss"  # Gauss-Legendre; the caller's n is the number of points
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class RootFindingMethod(Enum):
    """Root-finding method for scalar equations.

    Members can also be looked up by their string value, e.g.
    ``RootFindingMethod("secant")``. Note that ``RootFindingMethod.NEWTON`` and
    ``InterpolationMethod.NEWTON`` share the value ``"newton"`` but are distinct
    members of different enums and never compare equal.
    """

    BISECTION = "bisection"  # bracketing; needs a < b with f(a), f(b) of opposite sign
    NEWTON = "newton"  # needs one initial guess p0
    MODIFIED_NEWTON = "modified_newton"  # needs one initial guess p0
    SECANT = "secant"  # needs two initial guesses p0, p1
    REGULA_FALSI = "regula_falsi"  # bracketing; needs p0, p1 with a sign change
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class InterpolationMethod(Enum):
    """Polynomial interpolation method.

    Members can also be looked up by their string value, e.g.
    ``InterpolationMethod("lagrange")``.
    """

    LAGRANGE = "lagrange"
    # Same value as RootFindingMethod.NEWTON, but a distinct member of a
    # different enum; the two never compare equal.
    NEWTON = "newton"
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
class InterpolationForm(Enum):
    """Form of Newton interpolating polynomial.

    Members can also be looked up by their string value, e.g.
    ``InterpolationForm("forward_diff")``.
    """

    # NOTE(review): presumably STANDARD is the divided-difference form and the
    # *_DIFF members are the forward/backward-difference forms for equally
    # spaced nodes — confirm against the interpolation implementation.
    STANDARD = "standard"
    FORWARD_DIFF = "forward_diff"
    BACKWARD_DIFF = "backward_diff"
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
class ODEMethod(Enum):
    """Numerical solver for first-order ODE initial value problems.

    Members can also be looked up by their string value, e.g.
    ``ODEMethod("runge-kutta")``. Unlike most other enums in this module,
    several values here are hyphen-separated rather than underscore-separated.
    """

    EULER = "euler"
    RUNGE_KUTTA = "runge-kutta"
    TAYLOR = "taylor"
    TRAPEZOIDAL = "trapezoidal"
    # NOTE(review): the values below read "adam-..." rather than "adams-...";
    # they are kept verbatim because callers may construct members from them.
    ADAMS_BASHFORTH = "adam-bashforth"
    ADAMS_MOULTON = "adam-moulton"
    PREDICTOR_CORRECTOR = "predictor-corrector"
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
class BVPMethod(Enum):
    """Solver for second-order linear boundary value problems.

    Members can also be looked up by their string value, e.g.
    ``BVPMethod("shooting")``.
    """

    SHOOTING = "shooting"
    # Same value as NonlinearBVPMethod.FINITE_DIFFERENCE, but a distinct member
    # of a different enum; the two never compare equal.
    FINITE_DIFFERENCE = "finite_difference"
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
class NonlinearBVPMethod(Enum):
    """Solver for second-order nonlinear boundary value problems.

    Members can also be looked up by their string value, e.g.
    ``NonlinearBVPMethod("shooting_newton")``.
    """

    SHOOTING_NEWTON = "shooting_newton"
    # Same value as BVPMethod.FINITE_DIFFERENCE, but a distinct member of a
    # different enum; the two never compare equal.
    FINITE_DIFFERENCE = "finite_difference"
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
class LinearSolverMethod(Enum):
    """Solver for systems of linear equations.

    Members can also be looked up by their string value, e.g.
    ``LinearSolverMethod("gauss_seidel")``.
    """

    GAUSS_ELIMINATION = "gauss_elimination"
    GAUSS_JACOBI = "gauss_jacobi"
    GAUSS_SEIDEL = "gauss_seidel"
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
# Names bound by ``from pysolverkit.enums import *``.
__all__ = [
    "DifferentiationMethod", "IntegrationMethod", "RootFindingMethod",
    "InterpolationMethod", "InterpolationForm", "ODEMethod",
    "BVPMethod", "NonlinearBVPMethod", "LinearSolverMethod",
]
|
pysolverkit/functions/__init__.py
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
from .base import Function
|
|
2
|
+
from .elementary import Polynomial, Exponent, Sin, Cos, Tan, Log
|
|
3
|
+
from .multivariate import MultiVariableFunction, BivariateFunction
|
|
4
|
+
from .fem2d import FEM2D
|
|
5
|
+
|
|
6
|
+
# Names bound by ``from pysolverkit.functions import *``.
__all__ = [
    # From .base
    "Function",
    # From .elementary
    "Polynomial", "Exponent", "Sin", "Cos", "Tan", "Log",
    # From .multivariate and .fem2d
    "MultiVariableFunction", "BivariateFunction", "FEM2D",
]
|
pysolverkit/functions/base.py
ADDED
|
@@ -0,0 +1,445 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import math
|
|
4
|
+
from fractions import Fraction
|
|
5
|
+
from typing import TYPE_CHECKING, Callable
|
|
6
|
+
|
|
7
|
+
from ..enums import DifferentiationMethod, IntegrationMethod, RootFindingMethod
|
|
8
|
+
|
|
9
|
+
# Reserved for imports needed only by static type checkers. The guarded block
# is currently empty, so nothing extra is imported at runtime.
if TYPE_CHECKING:
    pass
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class Function:
    """A callable mathematical function with composable arithmetic and numerical methods.

    A :class:`Function` wraps any callable ``f: float -> float`` and provides:

    * **Arithmetic operators** — ``+``, ``-``, ``*``, ``/``, ``**`` (and their
      reflected variants) that return new :class:`Function` objects. The other
      operand may be any callable or a plain number; numbers are treated as
      constant functions, so ``f ** 2`` and ``f + 1`` work as expected.
    * **Differentiation** — :meth:`differentiate` and :meth:`multi_differentiate`
      via finite differences, or by supplying an exact derivative.
    * **Integration** — :meth:`integrate` with several quadrature rules.
    * **Root finding** — :meth:`root` with bisection, Newton-Raphson, secant,
      Regula Falsi, and modified Newton methods.
    * **Fixed-point iteration** — :meth:`fixed_point`.
    * **Plotting** — :meth:`plot` (requires ``matplotlib``).

    Parameters
    ----------
    function:
        Any callable ``f(x) -> float``.
    """

    def __init__(self, function: Callable) -> None:
        self.function = function

    # ------------------------------------------------------------------
    # Calling
    # ------------------------------------------------------------------

    def __call__(self, x):
        """Evaluate the function at *x*.

        If *x* is itself callable, no evaluation happens; instead the
        composition ``p -> self(x(p))`` is returned as a new :class:`Function`.
        """
        if callable(x):
            return Function(lambda p: self(x(p)))
        return self.function(x)

    # ------------------------------------------------------------------
    # Arithmetic operators
    # ------------------------------------------------------------------

    @staticmethod
    def _lift(other) -> Callable:
        """Return *other* unchanged if callable, else a constant function of it.

        Without this, a numeric operand (``f ** 2``, ``f + 1``) would build a
        Function that only fails later, with a TypeError, when evaluated.
        """
        if callable(other):
            return other
        return lambda _x: other

    def __add__(self, other: Function | Callable | float) -> Function:
        g = self._lift(other)
        return Function(lambda x: self(x) + g(x))

    def __sub__(self, other: Function | Callable | float) -> Function:
        g = self._lift(other)
        return Function(lambda x: self(x) - g(x))

    def __mul__(self, other: Function | Callable | float) -> Function:
        g = self._lift(other)
        return Function(lambda x: self(x) * g(x))

    def __truediv__(self, other: Function | Callable | float) -> Function:
        g = self._lift(other)
        return Function(lambda x: self(x) / g(x))

    def __pow__(self, other: Function | Callable | float) -> Function:
        g = self._lift(other)
        return Function(lambda x: self(x) ** g(x))

    def __radd__(self, other: Callable | float) -> Function:
        g = self._lift(other)
        return Function(lambda x: g(x) + self(x))

    def __rsub__(self, other: Callable | float) -> Function:
        g = self._lift(other)
        return Function(lambda x: g(x) - self(x))

    def __rmul__(self, other: Callable | float) -> Function:
        g = self._lift(other)
        return Function(lambda x: g(x) * self(x))

    def __rtruediv__(self, other: Callable | float) -> Function:
        g = self._lift(other)
        return Function(lambda x: g(x) / self(x))

    def __rpow__(self, other: Callable | float) -> Function:
        g = self._lift(other)
        return Function(lambda x: g(x) ** self(x))

    def __neg__(self) -> Function:
        return Function(lambda x: -self(x))

    # ------------------------------------------------------------------
    # Differentiation
    # ------------------------------------------------------------------

    def differentiate(
        self,
        func: Function | Callable | None = None,
        h: float = 1e-5,
        method: DifferentiationMethod = DifferentiationMethod.FORWARD,
    ) -> Function | None:
        """Return a numerical derivative, or set an exact derivative.

        * If *func* is ``None`` (default), returns a :class:`Function`
          approximating the derivative via finite differences.
        * If *func* is a :class:`Function` or callable, stores it as
          ``self.derivative`` and returns ``None``. The Newton-type root
          finders prefer this exact derivative over a finite difference.

        Parameters
        ----------
        func:
            Exact derivative function, or ``None`` to use finite differences.
        h:
            Step size for finite-difference approximation (must be non-zero).
        method:
            Finite-difference scheme: ``DifferentiationMethod.FORWARD``,
            ``BACKWARD``, or ``CENTRAL``.

        Raises
        ------
        ValueError
            If *h* is zero or *method* is not a known scheme.
        """
        if isinstance(func, Function):
            self.derivative = func
            return None
        if callable(func):
            self.derivative = Function(func)
            return None

        if h == 0:
            raise ValueError("h must be non-zero.")

        if method is DifferentiationMethod.FORWARD:
            return self._differentiate_forward(h)
        if method is DifferentiationMethod.BACKWARD:
            # (f(x - h) - f(x)) / (-h) == (f(x) - f(x - h)) / h
            return self._differentiate_forward(-h)
        if method is DifferentiationMethod.CENTRAL:
            return self._differentiate_central(h)

        raise ValueError(f"Unknown differentiation method: {method!r}")

    def _differentiate_forward(self, h: float) -> Function:
        """One-sided difference quotient ``(f(x + h) - f(x)) / h``."""
        return Function(lambda x: (self(x + h) - self(x)) / h)

    def _differentiate_central(self, h: float) -> Function:
        """Symmetric difference quotient ``(f(x + h) - f(x - h)) / (2h)``."""
        return Function(lambda x: (self(x + h) - self(x - h)) / (2 * h))

    @staticmethod
    def _derivative_of(fn: Function) -> Function:
        """Return *fn*'s exact derivative if one was registered, else ``fn.differentiate()``.

        An exact derivative is one stored via ``fn.differentiate(func)``.
        """
        exact = getattr(fn, "derivative", None)
        if isinstance(exact, Function):
            return exact
        return fn.differentiate()

    def multi_differentiate(
        self,
        n: int,
        h: float = 1e-5,
        method: DifferentiationMethod = DifferentiationMethod.FORWARD,
    ) -> Function:
        """Return the *n*-th order derivative via repeated finite differences.

        Parameters
        ----------
        n:
            Order of differentiation (``0`` returns ``self``).
        h:
            Step size.
        method:
            Finite-difference scheme.

        Raises
        ------
        ValueError
            If *n* is negative (which would otherwise recurse forever).
        """
        if n < 0:
            raise ValueError("n must be a non-negative integer.")
        if n == 0:
            return self
        return self.differentiate(h=h, method=method).multi_differentiate(n - 1, h, method)

    # ------------------------------------------------------------------
    # Integration
    # ------------------------------------------------------------------

    def integral(self, func: Function | Callable | None = None) -> None:
        """Set an exact antiderivative.

        Parameters
        ----------
        func:
            A :class:`Function` (or callable) representing the antiderivative.
            Once set, :meth:`integrate` with no *method* will use it.

        Raises
        ------
        NotImplementedError
            If *func* is not callable (symbolic antiderivatives are not supported).
        """
        if isinstance(func, Function):
            self._integral = func
        elif callable(func):
            self._integral = Function(func)
        else:
            raise NotImplementedError("Automatic antiderivative computation is not implemented.")

    def integrate(
        self,
        a: float,
        b: float,
        method: IntegrationMethod | None = None,
        n: int | None = None,
    ) -> float:
        """Compute the definite integral :math:`\\int_a^b f(x)\\,dx`.

        Parameters
        ----------
        a, b:
            Integration bounds.
        method:
            Quadrature rule (:class:`IntegrationMethod`). If ``None`` and an
            exact antiderivative has been set via :meth:`integral`, it is used
            instead.
        n:
            Number of subintervals (or quadrature points for Gauss, 1 to 5).
            For the composite rules a falsy *n* means a single interval.

        Raises
        ------
        ValueError
            If no method is given and no antiderivative has been set.
        NotImplementedError
            If Gauss quadrature is requested with an unsupported *n*.
        """
        if method is IntegrationMethod.RECTANGULAR:
            return self._integrate_rectangular(a, b, n)
        if method is IntegrationMethod.MIDPOINT:
            return self._integrate_midpoint(a, b, n)
        if method is IntegrationMethod.TRAPEZOIDAL:
            return self._integrate_trapezoidal(a, b, n)
        if method is IntegrationMethod.SIMPSON:
            return self._integrate_simpson(a, b, n)
        if method is IntegrationMethod.GAUSS:
            return self._integrate_gauss(a, b, n)

        if hasattr(self, "_integral"):
            return self._integral(b) - self._integral(a)

        raise ValueError("Specify a method or set an exact antiderivative via .integral().")

    def _integrate_rectangular(self, a: float, b: float, n: int | None) -> float:
        """Left-endpoint rectangle rule over *n* equal subintervals."""
        if not n:
            return (b - a) * self(a)
        h = (b - a) / n
        return h * sum(self(a + i * h) for i in range(n))

    def _integrate_midpoint(self, a: float, b: float, n: int | None) -> float:
        """Composite midpoint rule over *n* equal subintervals."""
        if not n:
            return (b - a) * self((a + b) / 2)
        h = (b - a) / n
        return h * sum(self(a + i * h + h / 2) for i in range(n))

    def _integrate_trapezoidal(self, a: float, b: float, n: int | None) -> float:
        """Composite trapezoidal rule over *n* equal subintervals."""
        if not n:
            return (b - a) * (self(a) + self(b)) / 2
        h = (b - a) / n
        return h * (self(a) + 2 * sum(self(a + i * h) for i in range(1, n)) + self(b)) / 2

    def _integrate_simpson(self, a: float, b: float, n: int | None) -> float:
        """Composite Simpson rule: *n* panels, each using its two ends and midpoint."""
        if not n:
            return (b - a) * (self(a) + 4 * self((a + b) / 2) + self(b)) / 6
        h = (b - a) / n
        return (
            h
            * (
                self(a)
                + 4 * sum(self(a + i * h + h / 2) for i in range(n))
                + 2 * sum(self(a + i * h) for i in range(1, n))
                + self(b)
            )
            / 6
        )

    @staticmethod
    def _gauss_legendre_rule(n: int | None) -> tuple[tuple[float, float], ...] | None:
        """Return ``(node, weight)`` pairs of the *n*-point Gauss-Legendre rule on [-1, 1].

        Returns ``None`` for unsupported *n* (anything other than 1 to 5).
        """
        if n == 1:
            return ((0.0, 2.0),)
        if n == 2:
            r = 1 / math.sqrt(3)
            return ((-r, 1.0), (r, 1.0))
        if n == 3:
            r = math.sqrt(3 / 5)
            return ((-r, 5 / 9), (0.0, 8 / 9), (r, 5 / 9))
        if n == 4:
            inner = math.sqrt(3 / 7 - 2 / 7 * math.sqrt(6 / 5))
            outer = math.sqrt(3 / 7 + 2 / 7 * math.sqrt(6 / 5))
            w_inner = (18 + math.sqrt(30)) / 36
            w_outer = (18 - math.sqrt(30)) / 36
            return ((-outer, w_outer), (-inner, w_inner), (inner, w_inner), (outer, w_outer))
        if n == 5:
            inner = math.sqrt(5 - 2 * math.sqrt(10 / 7)) / 3
            outer = math.sqrt(5 + 2 * math.sqrt(10 / 7)) / 3
            w_inner = (322 + 13 * math.sqrt(70)) / 900
            w_outer = (322 - 13 * math.sqrt(70)) / 900
            return (
                (-outer, w_outer),
                (-inner, w_inner),
                (0.0, 128 / 225),
                (inner, w_inner),
                (outer, w_outer),
            )
        return None

    def _integrate_gauss(self, a: float, b: float, n: int | None) -> float:
        """*n*-point Gauss-Legendre quadrature on [a, b] (n from 1 to 5)."""
        rule = self._gauss_legendre_rule(n)
        if rule is None:
            raise NotImplementedError(
                f"Gaussian quadrature is only implemented for n=1 to n=5 (got n={n!r})."
            )
        # Affine change of variables from [-1, 1] onto [a, b]:
        # x = mid + half * t, dx = half * dt.
        half = (b - a) / 2
        mid = (a + b) / 2
        return half * sum(w * self(mid + half * t) for t, w in rule)

    # ------------------------------------------------------------------
    # Root finding
    # ------------------------------------------------------------------

    def root(
        self,
        method: RootFindingMethod,
        a: float | None = None,
        b: float | None = None,
        p0: float | None = None,
        p1: float | None = None,
        TOLERANCE: float = 1e-10,
        N: int = 100,
        return_iterations: bool = False,
        early_stop: int | None = None,
    ) -> float | tuple[float, int] | None:
        """Find a root of the function.

        Parameters
        ----------
        method:
            Algorithm to use (:class:`RootFindingMethod`).
        a, b:
            Bracket for bisection (both required, ``a < b``, sign change).
        p0, p1:
            Initial guesses: Newton and modified Newton need *p0*; secant and
            Regula Falsi need both (Regula Falsi also needs a sign change).
        TOLERANCE:
            Convergence tolerance on successive iterates (bracket width for
            bisection).
        N:
            Maximum number of iterations.
        return_iterations:
            If ``True``, returns ``(root, n_iterations)`` instead of just the root.
        early_stop:
            If given, stop and return the current iterate once the 0-based
            iteration index reaches this value (so ``early_stop=k`` performs
            ``k + 1`` iterations), regardless of convergence.

        Returns
        -------
        The root, or ``None`` if no convergence within *N* iterations (the
        Newton variants also return ``None`` on division by zero or overflow).

        Raises
        ------
        ValueError
            If required arguments are missing or invalid, or *method* is unknown.
        """
        if method is RootFindingMethod.BISECTION:
            if a is None or b is None:
                raise ValueError("Bisection requires both a and b.")
            if a >= b:
                raise ValueError("a must be less than b.")
            if self(a) * self(b) >= 0:
                raise ValueError("f(a) and f(b) must have opposite signs.")
            sol, n = self._bisection(a, b, TOLERANCE, N, early_stop)
        elif method is RootFindingMethod.NEWTON:
            if p0 is None:
                raise ValueError("Newton's method requires p0.")
            sol, n = self._newton(p0, TOLERANCE, N, early_stop)
        elif method is RootFindingMethod.SECANT:
            if p0 is None or p1 is None:
                raise ValueError("Secant method requires both p0 and p1.")
            sol, n = self._secant(p0, p1, TOLERANCE, N, early_stop)
        elif method is RootFindingMethod.REGULA_FALSI:
            if p0 is None or p1 is None:
                raise ValueError("Regula falsi requires both p0 and p1.")
            if self(p0) * self(p1) >= 0:
                raise ValueError("f(p0) and f(p1) must have opposite signs.")
            sol, n = self._regula_falsi(p0, p1, TOLERANCE, N, early_stop)
        elif method is RootFindingMethod.MODIFIED_NEWTON:
            if p0 is None:
                raise ValueError("Modified Newton's method requires p0.")
            sol, n = self._modified_newton(p0, TOLERANCE, N, early_stop)
        else:
            raise ValueError(f"Unknown root-finding method: {method!r}")

        if return_iterations:
            return sol, n
        return sol

    def _bisection(
        self, a: float, b: float, tol: float, N: int, early_stop: int | None
    ) -> tuple[float | None, int]:
        """Halve ``[a, b]`` until f(midpoint) is 0 or the bracket is narrower than *tol*."""
        fa = self(a)  # cached: f(a) only changes when a moves
        for i in range(N):
            p = (a + b) / 2
            fp = self(p)
            if fp == 0 or abs(a - b) < tol or (early_stop is not None and i >= early_stop):
                return p, i + 1
            if fa * fp > 0:
                # Same sign at a and p: the sign change is in [p, b].
                a, fa = p, fp
            else:
                b = p
        return None, N

    def _newton(
        self, p0: float, tol: float, N: int, early_stop: int | None
    ) -> tuple[float | None, int]:
        """Newton-Raphson iteration, using the exact derivative when one is registered."""
        deriv = self._derivative_of(self)
        i = 0
        try:
            for i in range(N):
                p = p0 - self(p0) / deriv(p0)
                if abs(p - p0) < tol or (early_stop is not None and i >= early_stop):
                    return p, i + 1
                p0 = p
        except (ZeroDivisionError, OverflowError):
            # Zero derivative or divergence: report failure at the index reached.
            return None, i
        return None, N

    def _modified_newton(
        self, p0: float, tol: float, N: int, early_stop: int | None
    ) -> tuple[float | None, int]:
        """Newton's method applied to f/f' (robust to multiple roots).

        Uses registered exact first/second derivatives where available, else
        finite differences.
        """
        deriv = self._derivative_of(self)
        double_deriv = self._derivative_of(deriv)
        i = 0
        try:
            for i in range(N):
                f0 = self(p0)
                d0 = deriv(p0)
                p = p0 - f0 * d0 / (d0 ** 2 - f0 * double_deriv(p0))
                if abs(p - p0) < tol or (early_stop is not None and i >= early_stop):
                    return p, i + 1
                p0 = p
        except (ZeroDivisionError, OverflowError):
            return None, i
        return None, N

    def _secant(
        self, p0: float, p1: float, tol: float, N: int, early_stop: int | None
    ) -> tuple[float | None, int]:
        """Secant iteration from *p0*, *p1*.

        A ZeroDivisionError propagates if two successive iterates have equal
        function values.
        """
        for i in range(N):
            p = p1 - self(p1) * (p1 - p0) / (self(p1) - self(p0))
            if abs(p - p1) < tol or (early_stop is not None and i >= early_stop):
                return p, i + 1
            p0 = p1
            p1 = p
        return None, N

    def _regula_falsi(
        self, p0: float, p1: float, tol: float, N: int, early_stop: int | None
    ) -> tuple[float | None, int]:
        """False-position iteration that keeps the root bracketed by *p0*, *p1*."""
        for i in range(N):
            p = p1 - self(p1) * (p1 - p0) / (self(p1) - self(p0))
            if abs(p - p1) < tol or (early_stop is not None and i >= early_stop):
                return p, i + 1
            if self(p0) * self(p) > 0:
                # f(p) has the sign of f(p0): the root lies between p1 and p.
                p0 = p1
            p1 = p
        return None, N

    def fixed_point(
        self, p0: float, TOLERANCE: float = 1e-10, N: int = 100
    ) -> float | None:
        """Apply fixed-point iteration :math:`p_{n+1} = f(p_n)`.

        Parameters
        ----------
        p0:
            Initial approximation.
        TOLERANCE:
            Convergence tolerance on successive iterates.
        N:
            Maximum number of iterations.

        Returns
        -------
        The fixed point, or ``None`` on overflow or if not converged in *N* steps.
        """
        try:
            for _ in range(N):
                p = self(p0)
                if abs(p - p0) < TOLERANCE:
                    return p
                p0 = p
        except OverflowError:
            return None
        return None

    # ------------------------------------------------------------------
    # Plotting
    # ------------------------------------------------------------------

    def plot(
        self, min: float, max: float, N: int = 1000, file: str = "", clear: bool = False
    ) -> None:
        """Plot the function on ``[min, max]``.

        Requires ``matplotlib``.

        Parameters
        ----------
        min, max:
            Domain bounds (both endpoints are sampled).
        N:
            Number of sample points.
        file:
            If non-empty, save to this path instead of displaying.
        clear:
            If ``True``, clear the current figure before plotting.
        """
        import matplotlib.pyplot as plt

        # N evenly spaced samples that include both endpoints.
        if N > 1:
            x = [min + (i / (N - 1)) * (max - min) for i in range(N)]
        else:
            x = [min] * N  # N == 1: just the left endpoint; N <= 0: nothing
        y = [self(t) for t in x]

        if clear:
            plt.clf()
        plt.plot(x, y)
        plt.xlabel("x")
        plt.ylabel("y")
        if file:
            plt.savefig(file)
        else:
            plt.show()
|