QuantileFlow 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- QuantileFlow/__init__.py +31 -0
- QuantileFlow/momentsketch/__init__.py +22 -0
- QuantileFlow/momentsketch/core.py +200 -0
- QuantileFlow/momentsketch/example.py +160 -0
- QuantileFlow/momentsketch/optimizer.py +205 -0
- QuantileFlow/momentsketch/simple_moment_sketch.py +495 -0
- QuantileFlow/momentsketch/utils.py +328 -0
- docs/__init__.py +0 -0
- docs/conf.py +90 -0
- quantileflow-0.0.1.dist-info/METADATA +45 -0
- quantileflow-0.0.1.dist-info/RECORD +20 -0
- quantileflow-0.0.1.dist-info/WHEEL +5 -0
- quantileflow-0.0.1.dist-info/licenses/LICENSE +201 -0
- quantileflow-0.0.1.dist-info/licenses/LICENSE.txt +19 -0
- quantileflow-0.0.1.dist-info/top_level.txt +3 -0
- tests/__init__.py +0 -0
- tests/hdrhistogram/__init__.py +3 -0
- tests/hdrhistogram/test_hdr.py +265 -0
- tests/momentsketch/__init__.py +1 -0
- tests/momentsketch/integration_test.py +315 -0
QuantileFlow/__init__.py
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
1
|
+
"""
QuantileFlow: Efficient Quantile Computation for Anomaly Detection

This package provides APIs and algorithms to efficiently compute quantiles for anomaly detection
in service and system logs. It implements sketching algorithms optimized for:

- Low memory footprint
- Fast updates and queries
- Distributed computation support through mergeable sketches
- Accuracy guarantees for quantile approximation

The package is designed around three implementations:

1. DDSketch: A deterministic algorithm with relative error guarantees
2. MomentSketch: A moment-based algorithm using maximum entropy optimization
3. HDRHistogram: A high dynamic range histogram for tracking values across wide ranges

Availability in this release: the 0.0.1 wheel ships only ``QuantileFlow.momentsketch``.
``HDRHistogram`` is exported only when a ``QuantileFlow.hdrhistogram`` subpackage is
importable, and DDSketch is not yet included.

All implementations are designed to handle high-throughput data streams and provide
accurate quantile estimates with minimal memory overhead.
"""
from QuantileFlow.momentsketch.core import MomentSketch

__version__ = "0.0.1"
__all__ = [
    "MomentSketch",
]

# The 0.0.1 wheel contains no QuantileFlow.hdrhistogram subpackage, so an
# unconditional import made `import QuantileFlow` itself fail. Import it
# opportunistically and only advertise it in __all__ when it is present.
try:
    from QuantileFlow.hdrhistogram.core import HDRHistogram
except ModuleNotFoundError as exc:
    # Only swallow "the subpackage is absent"; a missing module *inside* an
    # existing hdrhistogram package is a real error and must surface.
    if exc.name not in ("QuantileFlow.hdrhistogram", "QuantileFlow.hdrhistogram.core"):
        raise
else:
    __all__.append("HDRHistogram")

if __name__ == "__main__":
    print("This is root of QuantileFlow module. API not to be exposed as a script!")
|
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
"""
MomentSketch: quantile estimation with moment-based sketching.

This subpackage exposes a moment-based sketch for estimating quantiles and
summary statistics over a stream of numbers. Highlights:

- Fixed memory: a constant number of moments is kept no matter how much data is added
- Mergeable: sketches built on separate shards can be combined for distributed use
- Accurate: the distribution is recovered with maximum entropy optimization
- Flexible: values can optionally be compressed for data spanning many magnitudes
- Comprehensive: summary statistics are available in addition to quantiles

Because the sketch is built from power sums and reconstructed via maximum
entropy, it suits streaming workloads where both memory and accuracy matter.
"""

from .core import MomentSketch

__all__ = ["MomentSketch"]
|
|
@@ -0,0 +1,200 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Core MomentSketch implementation.
|
|
3
|
+
|
|
4
|
+
This module provides the main MomentSketch class which serves as the public API
|
|
5
|
+
for the moment-based quantile estimation algorithm. The MomentSketch implementation:
|
|
6
|
+
|
|
7
|
+
- Maintains a compact representation of a dataset using power sums
|
|
8
|
+
- Provides accurate quantile estimates through maximum entropy optimization
|
|
9
|
+
- Supports merging sketches for distributed computation
|
|
10
|
+
- Offers comprehensive summary statistics and visualization capabilities
|
|
11
|
+
- Handles data compression for widely distributed values
|
|
12
|
+
|
|
13
|
+
The implementation is designed to be memory-efficient and accurate, making it suitable
|
|
14
|
+
for streaming data applications and monitoring systems where traditional approaches
|
|
15
|
+
would require excessive memory.
|
|
16
|
+
"""
|
|
17
|
+
|
|
18
|
+
from typing import List, Union, Dict
|
|
19
|
+
import numpy as np
|
|
20
|
+
|
|
21
|
+
from .simple_moment_sketch import SimpleMS
|
|
22
|
+
|
|
23
|
+
class MomentSketch:
    """
    MomentSketch implementation for quantile approximation using the moment-based approach.

    This implementation uses power sums, Chebyshev moment conversion, and maximum entropy
    optimization to estimate the probability distribution of data and compute quantiles.
    It supports merging sketches from distributed sources and provides accurate quantile
    estimates with a compact representation.

    This class is a thin validating facade: all state lives in the ``SimpleMS``
    instance stored on ``self.sketch`` and every query is delegated to it.

    Reference:
        E. Gan, J. Ding, K. S. Tai, V. Sharan and P. Bailis,
        "Moment-Based Quantile Sketches for Efficient High Cardinality Aggregation Queries",
        Proceedings of the VLDB Endowment 11(11), 2018.
    """

    def __init__(
        self,
        num_moments: int = 20,
        compress_values: bool = False
    ):
        """
        Initialize MomentSketch.

        Args:
            num_moments: Number of moments to track (default 20).
                Higher values increase accuracy at the cost of computation.
            compress_values: Whether to compress values using arcsinh transformation (default False).
                Useful for handling widely distributed data with extreme values.
        """
        self.sketch = SimpleMS(num_moments)
        self.sketch.set_compressed(compress_values)

    def insert(self, value: Union[int, float]) -> None:
        """
        Insert a single value into the sketch.

        Args:
            value: The value to insert.
        """
        self.sketch.add(value)

    def insert_batch(self, values: Union[List[float], np.ndarray]) -> None:
        """
        Insert multiple values into the sketch.

        Args:
            values: Array or list of values to insert.
        """
        self.sketch.add_many(values)

    def merge(self, other: 'MomentSketch') -> None:
        """
        Merge another MomentSketch into this one, in place.

        Args:
            other: Another MomentSketch instance to merge.

        Raises:
            ValueError: If the sketches are incompatible (e.g. different compression
                settings). NOTE(review): no check is made here; this relies on
                ``SimpleMS.merge`` enforcing it, which should be confirmed.
        """
        self.sketch.merge(other.sketch)

    def quantile(self, fraction: float) -> float:
        """
        Get the value at a given quantile.

        Args:
            fraction: Quantile fraction between 0 and 1 (e.g., 0.5 for median).

        Returns:
            Estimated value at the requested quantile.

        Raises:
            ValueError: If fraction is not between 0 and 1 (NaN is rejected too).
        """
        if not 0 <= fraction <= 1:
            raise ValueError("Quantile must be between 0 and 1")

        return self.sketch.get_quantile(fraction)

    def quantiles(self, fractions: List[float]) -> List[float]:
        """
        Get values at multiple quantiles.

        Args:
            fractions: Quantile fractions between 0 and 1. Lists, tuples and NumPy
                arrays are passed through unchanged. Any other iterable (e.g. a
                generator) is materialized into a list first.

        Returns:
            List of estimated values at the requested quantiles.

        Raises:
            ValueError: If any fraction is not between 0 and 1.
        """
        # Validation iterates over `fractions` before it is handed to SimpleMS.
        # A one-shot iterator would arrive there exhausted, so materialize it.
        if not isinstance(fractions, (list, tuple, np.ndarray)):
            fractions = list(fractions)

        for q in fractions:
            if not 0 <= q <= 1:
                raise ValueError("All quantiles must be between 0 and 1")

        return self.sketch.get_quantiles(fractions)

    def median(self) -> float:
        """
        Get the median value (50th percentile).

        Returns:
            Estimated median value.
        """
        return self.sketch.get_median()

    def percentile(self, p: float) -> float:
        """
        Get the p-th percentile value.

        Args:
            p: Percentile between 0 and 100 (e.g., 75 for 75th percentile).

        Returns:
            Estimated value at the requested percentile.

        Raises:
            ValueError: If p is not between 0 and 100.
        """
        if not 0 <= p <= 100:
            raise ValueError("Percentile must be between 0 and 100")

        return self.sketch.get_percentile(p)

    def interquartile_range(self) -> float:
        """
        Get the interquartile range (IQR).

        Returns:
            Estimated IQR (difference between 75th and 25th percentiles).
        """
        return self.sketch.get_iqr()

    def summary_statistics(self) -> Dict[str, float]:
        """
        Get summary statistics.

        Returns:
            Dictionary produced by ``SimpleMS.get_stats``. It is documented as holding
            min, q1, median, q3, max, count, and mean.
        """
        return self.sketch.get_stats()

    def plot_distribution(self, figsize=(10, 6)):
        """
        Plot the estimated probability distribution.

        Args:
            figsize: Figure size (width, height) in inches.

        Returns:
            Matplotlib figure object.
        """
        return self.sketch.plot_dist(figsize=figsize)

    def to_dict(self) -> Dict:
        """
        Convert sketch to a dictionary for serialization.

        Returns:
            Dictionary representation of the sketch.
        """
        return self.sketch.to_dict()

    @classmethod
    def from_dict(cls, data: Dict) -> 'MomentSketch':
        """
        Create a sketch from a dictionary.

        Args:
            data: Dictionary representation of a sketch (as produced by ``to_dict``).

        Returns:
            New MomentSketch instance.
        """
        # cls() builds a default SimpleMS that is immediately replaced by the
        # deserialized one, so its settings do not matter.
        sketch = cls()
        sketch.sketch = SimpleMS.from_dict(data)
        return sketch
|
|
@@ -0,0 +1,160 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Example script demonstrating the usage of MomentSketch for quantile estimation.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
import matplotlib.pyplot as plt
|
|
7
|
+
import time
|
|
8
|
+
from QuantileFlow import MomentSketch
|
|
9
|
+
|
|
10
|
+
def basic_usage_demo():
    """Build a sketch over a skewed sample and compare its quantiles with exact ones.

    Prints build/query timings, a per-quantile relative-error table and the
    sketch's summary statistics.

    Returns:
        tuple: ``(sketch, data)``, the populated MomentSketch and the raw sample.
    """
    print("\nBasic MomentSketch Usage Example")
    print("-" * 50)

    # Skewed sample: a normal bulk followed by an exponential tail.
    np.random.seed(8990)
    normal_part = np.random.normal(10, 2, 10000)
    tail_part = np.random.exponential(5, 5000)
    data = np.concatenate([normal_part, tail_part])

    sketch = MomentSketch(num_moments=20)

    t0 = time.time()
    sketch.insert_batch(data)
    build_seconds = time.time() - t0

    t0 = time.time()
    levels = [0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99]
    estimates = sketch.quantiles(levels)
    query_seconds = time.time() - t0

    exact = np.quantile(data, levels)

    print(f"Data size: {len(data):,d} points")
    print(f"Time to build sketch: {build_seconds * 1000:.2f} ms")
    print(f"Time to query quantiles: {query_seconds * 1000:.2f} ms")
    print("\nQuantile Comparison:")
    print(f"{'Percentile':>10} | {'True':>10} | {'Estimated':>10} | {'Error %':>10}")
    print("-" * 50)

    for level, true_val, est_val in zip(levels, exact, estimates):
        if true_val != 0:
            error_pct = 100 * abs(true_val - est_val) / abs(true_val)
        else:
            error_pct = 0
        print(f"{level * 100:10.1f}% | {true_val:10.4f} | {est_val:10.4f} | {error_pct:10.2f}%")

    summary = sketch.summary_statistics()
    print("\nSummary Statistics:")
    for name, stat in summary.items():
        print(f"{name}: {stat}")

    return sketch, data
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
def distribution_demo(sketch, data):
    """Overlay a histogram of the raw data on the sketch's estimated density.

    Args:
        sketch: A populated MomentSketch.
        data: The raw values that were inserted into ``sketch``.

    Returns:
        The figure returned by ``sketch.plot_distribution``, with a histogram of
        ``data`` drawn on its first axes.
    """
    print("\nDistribution Visualization")
    print("-" * 50)

    figure = sketch.plot_distribution(figsize=(10, 6))

    first_axes = figure.axes[0]
    first_axes.hist(data, bins=50, density=True, alpha=0.5, color='blue', label='Original Data')
    first_axes.legend()

    return figure
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
def merge_demo():
    """Demonstrate merging multiple sketches.

    Builds one sketch per synthetic data stream, merges both into an empty
    sketch, and compares the merged median with the exact median of the
    concatenated data.

    Returns:
        tuple: ``(combined_sketch, combined_data)``.
    """
    print("\nMerging Sketches Example")
    print("-" * 50)

    # Create two different data streams
    np.random.seed(8990)
    data1 = np.random.normal(1, 2, 10000)
    data2 = np.random.normal(2, 3, 15000)

    # Create sketches for each stream
    sketch1 = MomentSketch(num_moments=20)
    sketch2 = MomentSketch(num_moments=20)

    sketch1.insert_batch(data1)
    sketch2.insert_batch(data2)

    # Create a combined sketch by merging both into an empty one
    combined_sketch = MomentSketch(num_moments=20)
    combined_sketch.merge(sketch1)
    combined_sketch.merge(sketch2)

    # Exact reference built from the raw data
    combined_data = np.concatenate([data1, data2])

    median1 = sketch1.median()
    median2 = sketch2.median()
    combined_median = combined_sketch.median()
    true_combined_median = np.median(combined_data)

    print(f"Data 1 size: {len(data1):,d}, Median: {median1:.4f}")
    print(f"Data 2 size: {len(data2):,d}, Median: {median2:.4f}")
    print(f"Combined data size: {len(combined_data):,d}")
    print(f"True combined median: {true_combined_median:.4f}")
    print(f"Estimated combined median: {combined_median:.4f}")

    # Relative error: divide by |true| so a negative median cannot produce a
    # negative "error", and avoid dividing by zero when the true median is 0.
    if true_combined_median != 0:
        error_pct = 100 * abs(combined_median - true_combined_median) / abs(true_combined_median)
    else:
        error_pct = 0
    print(f"Error: {error_pct:.2f}%")

    return combined_sketch, combined_data
|
|
116
|
+
|
|
117
|
+
|
|
118
|
+
def serialization_demo():
    """Round-trip a sketch through ``to_dict``/``from_dict`` and compare quartiles.

    Returns:
        dict: The serialized representation produced by ``MomentSketch.to_dict``.
    """
    print("\nSerialization Example")
    print("-" * 50)

    np.random.seed(8990)
    samples = np.random.lognormal(0, 1, 20000)

    original = MomentSketch(num_moments=20)
    original.insert_batch(samples)

    payload = original.to_dict()
    print(f"Serialized sketch keys: {list(payload.keys())}")

    restored = MomentSketch.from_dict(payload)

    quartiles = (0.25, 0.5, 0.75)
    # Each query gets its own fresh list.
    before = original.quantiles(list(quartiles))
    after = restored.quantiles(list(quartiles))

    print("\nQuantile comparison:")
    print(f"{'Percentile':>10} | {'Original':>10} | {'Restored':>10}")
    print("-" * 50)

    for level, before_q, after_q in zip(quartiles, before, after):
        print(f"{level * 100:10.1f}% | {before_q:10.4f} | {after_q:10.4f}")

    return payload
|
|
149
|
+
|
|
150
|
+
|
|
151
|
+
if __name__ == "__main__":
    # Run the demos in order: the basic demo's sketch and data feed the
    # plotting demo, then the merge and serialization demos run independently.
    demo_sketch, demo_data = basic_usage_demo()
    demo_figure = distribution_demo(demo_sketch, demo_data)

    merged_sketch, merged_data = merge_demo()
    serialized = serialization_demo()

    # Lay out the current figure, then display the open figures.
    plt.tight_layout()
    plt.show()
|
|
@@ -0,0 +1,205 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Optimization algorithms for the MomentSketch implementation.
|
|
3
|
+
|
|
4
|
+
This module provides optimization algorithms used to solve the maximum entropy problem
|
|
5
|
+
in the MomentSketch implementation. It includes:
|
|
6
|
+
|
|
7
|
+
- BaseOptimizer: An abstract base class defining the common interface for optimizers
|
|
8
|
+
- NewtonOptimizer: An implementation of damped Newton's method for convex optimization
|
|
9
|
+
|
|
10
|
+
The optimizers handle numerical stability issues that can arise in the maximum entropy
|
|
11
|
+
optimization problem, including matrix conditioning problems and numerical precision issues.
|
|
12
|
+
"""
|
|
13
|
+
|
|
14
|
+
import numpy as np
|
|
15
|
+
from .utils import Util
|
|
16
|
+
from scipy.linalg import svd, solve
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class BaseOptimizer:
    """
    Common interface shared by the MomentSketch optimizers.

    Every method here is a no-op placeholder that returns ``None``; concrete
    optimizers such as ``NewtonOptimizer`` override them.
    """

    def set_verbose(self, flag):
        """Enable or disable diagnostic output (no-op in the base class)."""
        return None

    def set_max_iterations(self, max_iterations):
        """Set the iteration budget (no-op in the base class)."""
        return None

    def is_converged(self):
        """Report whether the last solve converged (``None`` in the base class)."""
        return None

    def get_iteration_count(self):
        """Return the number of steps taken by the last solve (``None`` here)."""
        return None

    def get_function(self):
        """Return the objective being optimized (``None`` in the base class)."""
        return None

    def solve(self, initial_point, gradient_tolerance):
        """Run the optimization from ``initial_point`` (``None`` in the base class)."""
        return None
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
class NewtonOptimizer(BaseOptimizer):
    """
    Minimizes a convex function using damped Newton's method.

    Each iteration solves ``H d = -g`` for the Newton direction. It tries a
    positive-definite (Cholesky) solve first and falls back to an SVD
    pseudo-inverse if that raises. It then backtracks along ``d`` until an
    Armijo-style sufficient-decrease test passes.

    The objective must expose ``dim``, ``compute_all(point, precision)``,
    ``get_value()``, ``get_gradient()`` and ``get_hessian()``. ``compute_all``
    appears to refresh the cached value, gradient and Hessian that the getters
    return.
    """

    def __init__(self, objective_function):
        """
        Initialize the optimizer

        Args:
            objective_function: FunctionWithHessian to optimize
        """
        self.objective_function = objective_function
        self.max_iterations = 200
        self.iteration_count = 0
        self.steps = 0
        self.converged = False

        # Armijo sufficient-decrease constant and the factor by which the
        # step shrinks on each backtracking attempt.
        self.alpha = 0.3
        self.backtracking_rate = 0.25
        self.verbose = False

    def set_verbose(self, flag):
        """Set verbose output flag"""
        self.verbose = flag

    def set_max_iterations(self, max_iterations):
        """Set maximum iterations"""
        self.max_iterations = max_iterations

    def get_iteration_count(self):
        """Number of Newton steps taken by the most recent ``solve()``."""
        return self.iteration_count

    def is_converged(self):
        """Whether the most recent ``solve()`` reached the gradient tolerance."""
        return self.converged

    def get_backtracking_count(self):
        """Number of damped (step size < 1) Newton steps in the most recent ``solve()``."""
        return self.steps

    def get_function(self):
        """Get the function being optimized"""
        return self.objective_function

    def solve(self, initial_point, grad_tolerance):
        """
        Minimize the objective starting from ``initial_point``.

        Args:
            initial_point: Starting point (array-like, length ``objective_function.dim``).
            grad_tolerance: Convergence threshold on the gradient. The solve stops
                when ``Util.get_mse(gradient) < grad_tolerance ** 2``.

        Returns:
            The final point as a NumPy array. Use ``is_converged()`` to check
            whether the tolerance was reached within ``max_iterations``.
        """
        # Per-solve state: reset so repeated solves don't accumulate counts.
        self.converged = False
        self.iteration_count = 0
        self.steps = 0

        dimension = self.objective_function.dim
        current_point = np.array(initial_point).copy()

        required_precision = grad_tolerance / 10
        self.objective_function.compute_all(current_point, required_precision)

        squared_tolerance = grad_tolerance * grad_tolerance

        for iteration in range(self.max_iterations):
            function_value = self.objective_function.get_value()
            gradient = self.objective_function.get_gradient()
            hessian = self.objective_function.get_hessian()

            # Check for NaN or Inf in gradient or hessian
            if not np.all(np.isfinite(gradient)) or not np.all(np.isfinite(hessian)):
                if self.verbose:
                    print("Warning: NaN or Inf detected in gradient or Hessian. Using fallback approach.")
                # Fall back to an identity Hessian (plain gradient step)
                hessian = np.eye(dimension)
                gradient = np.nan_to_num(gradient, nan=0.0, posinf=1.0, neginf=-1.0)

            mean_squared_error = Util.get_mse(gradient)

            if self.verbose:
                print(f"Iteration: {iteration:3d} GradRMSE: {np.sqrt(mean_squared_error):10.5g} Value: {function_value:10.5g}")

            if mean_squared_error < squared_tolerance:
                self.converged = True
                break

            newton_direction = self._newton_direction(gradient, hessian, dimension)
            directional_derivative = np.sum(newton_direction * gradient)

            current_point, step_size = self._take_step(
                current_point,
                newton_direction,
                function_value,
                directional_derivative,
                grad_tolerance,
                squared_tolerance,
                required_precision,
            )

            if step_size < 1.0:
                self.steps += 1
                if self.verbose:
                    print(f"Step Size: {step_size}")

            # Count only iterations that actually moved the point. The
            # original `iteration + 1` over-counted on convergence and raised
            # UnboundLocalError when max_iterations <= 0.
            self.iteration_count += 1

        return current_point

    def _newton_direction(self, gradient, hessian, dimension):
        """Return the Newton step ``-H^+ g``. Uses an SVD pseudo-inverse if the PD solve fails."""
        try:
            # Add small regularization to ensure positive definiteness if needed
            if np.any(np.diag(hessian) <= 0):
                if self.verbose:
                    print("Adding regularization to Hessian")
                hessian = hessian + 1e-8 * np.eye(dimension)

            direction = solve(hessian, gradient, assume_a='pos')
        except np.linalg.LinAlgError:
            # Fall back to SVD if Cholesky fails
            if self.verbose:
                print("Cholesky decomposition failed, falling back to SVD")
            u, singular_values, vh = svd(hessian)
            pseudoinverse_values = np.array([1 / x if x > 1e-10 else 0 for x in singular_values])
            direction = (vh.T * pseudoinverse_values) @ (u.T @ gradient)

        return -direction

    def _take_step(self, current_point, newton_direction, function_value,
                   directional_derivative, grad_tolerance, squared_tolerance,
                   required_precision):
        """Choose a step along ``newton_direction`` and evaluate the objective there.

        Returns ``(new_point, step_size)``. When this returns normally, the
        objective's cached value, gradient and Hessian were last computed at
        ``new_point``.
        """
        step_size = 1.0
        candidate_point = current_point + step_size * newton_direction

        # Warning: this overwrites gradient and hessian
        try:
            self.objective_function.compute_all(candidate_point, required_precision)
        except Exception as e:
            if self.verbose:
                print(f"Error computing objective at candidate point: {e}")
            step_size = 0.1
            candidate_point = current_point + step_size * newton_direction
            self.objective_function.compute_all(candidate_point, required_precision)

        # Do not look for damped steps if we are near stationary point.
        # Written as `not (... > ...)` so a NaN derivative also skips the search.
        if not (directional_derivative * directional_derivative > squared_tolerance):
            return candidate_point, step_size

        max_backtracking_steps = 10
        backtracking_steps = 0

        while True:
            new_function_value = self.objective_function.get_value()
            improvement = function_value + self.alpha * step_size * directional_derivative - new_function_value

            if improvement >= -grad_tolerance or step_size < 1e-3 or backtracking_steps >= max_backtracking_steps:
                break

            last_good_step = step_size
            step_size *= self.backtracking_rate
            backtracking_steps += 1

            candidate_point = current_point + step_size * newton_direction
            try:
                self.objective_function.compute_all(candidate_point, required_precision)
            except Exception:
                # The failed evaluation leaves the cached objective state
                # undefined. Revert to the last step that evaluated cleanly and
                # re-evaluate there so the accepted point and the cache agree.
                step_size = last_good_step
                candidate_point = current_point + step_size * newton_direction
                self.objective_function.compute_all(candidate_point, required_precision)
                break

        return candidate_point, step_size
|