lrdbench 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- analysis/high_performance/jax/__init__.py +36 -0
- analysis/high_performance/jax/cwt_jax.py +242 -0
- analysis/high_performance/jax/dfa_jax.py +437 -0
- analysis/high_performance/jax/dma_jax.py +397 -0
- analysis/high_performance/jax/gph_jax.py +289 -0
- analysis/high_performance/jax/higuchi_jax.py +403 -0
- analysis/high_performance/jax/mfdfa_jax.py +309 -0
- analysis/high_performance/jax/multifractal_wavelet_leaders_jax.py +320 -0
- analysis/high_performance/jax/periodogram_jax.py +422 -0
- analysis/high_performance/jax/rs_jax.py +391 -0
- analysis/high_performance/jax/wavelet_log_variance_jax.py +243 -0
- analysis/high_performance/jax/wavelet_variance_jax.py +242 -0
- analysis/high_performance/jax/wavelet_whittle_jax.py +233 -0
- analysis/high_performance/jax/whittle_jax.py +539 -0
- analysis/high_performance/numba/__init__.py +36 -0
- analysis/high_performance/numba/cwt_numba.py +264 -0
- analysis/high_performance/numba/dfa_numba.py +427 -0
- analysis/high_performance/numba/dma_numba.py +388 -0
- analysis/high_performance/numba/gph_numba.py +370 -0
- analysis/high_performance/numba/higuchi_numba.py +375 -0
- analysis/high_performance/numba/mfdfa_numba.py +370 -0
- analysis/high_performance/numba/multifractal_wavelet_leaders_numba.py +351 -0
- analysis/high_performance/numba/periodogram_numba.py +475 -0
- analysis/high_performance/numba/rs_numba.py +384 -0
- analysis/high_performance/numba/wavelet_log_variance_numba.py +259 -0
- analysis/high_performance/numba/wavelet_variance_numba.py +258 -0
- analysis/high_performance/numba/wavelet_whittle_numba.py +231 -0
- analysis/high_performance/numba/whittle_numba.py +608 -0
- analysis/machine_learning/__init__.py +37 -0
- analysis/machine_learning/base_ml_estimator.py +587 -0
- analysis/machine_learning/cnn_estimator.py +349 -0
- analysis/machine_learning/gradient_boosting_estimator.py +112 -0
- analysis/machine_learning/gru_estimator.py +235 -0
- analysis/machine_learning/lstm_estimator.py +266 -0
- analysis/machine_learning/neural_network_estimator.py +110 -0
- analysis/machine_learning/random_forest_estimator.py +113 -0
- analysis/machine_learning/svr_estimator.py +97 -0
- analysis/machine_learning/transformer_estimator.py +432 -0
- analysis/multifractal/mfdfa/mfdfa_estimator.py +405 -0
- analysis/multifractal/wavelet_leaders/multifractal_wavelet_leaders_estimator.py +446 -0
- analysis/spectral/__init__.py +23 -0
- analysis/spectral/gph/__init__.py +22 -0
- analysis/spectral/gph/gph_estimator.py +207 -0
- analysis/spectral/periodogram/__init__.py +22 -0
- analysis/spectral/periodogram/periodogram_estimator.py +224 -0
- analysis/spectral/whittle/__init__.py +22 -0
- analysis/spectral/whittle/whittle_estimator.py +279 -0
- analysis/temporal/dfa/__init__.py +11 -0
- analysis/temporal/dfa/dfa_estimator.py +303 -0
- analysis/temporal/dma/__init__.py +10 -0
- analysis/temporal/dma/dma_estimator.py +314 -0
- analysis/temporal/higuchi/__init__.py +10 -0
- analysis/temporal/higuchi/higuchi_estimator.py +327 -0
- analysis/temporal/rs/__init__.py +10 -0
- analysis/temporal/rs/rs_estimator.py +439 -0
- analysis/wavelet/cwt/__init__.py +10 -0
- analysis/wavelet/cwt/cwt_estimator.py +256 -0
- analysis/wavelet/log_variance/__init__.py +10 -0
- analysis/wavelet/log_variance/wavelet_log_variance_estimator.py +243 -0
- analysis/wavelet/variance/__init__.py +17 -0
- analysis/wavelet/variance/wavelet_variance_estimator.py +236 -0
- analysis/wavelet/whittle/__init__.py +10 -0
- analysis/wavelet/whittle/wavelet_whittle_estimator.py +300 -0
- lrdbench-1.0.0.dist-info/METADATA +369 -0
- lrdbench-1.0.0.dist-info/RECORD +92 -0
- lrdbench-1.0.0.dist-info/WHEEL +5 -0
- lrdbench-1.0.0.dist-info/entry_points.txt +4 -0
- lrdbench-1.0.0.dist-info/licenses/LICENSE +21 -0
- lrdbench-1.0.0.dist-info/top_level.txt +2 -0
- models/__init__.py +12 -0
- models/contamination/complex_time_series_library.py +686 -0
- models/contamination/contamination_models.py +575 -0
- models/data_models/__init__.py +10 -0
- models/data_models/arfima/__init__.py +9 -0
- models/data_models/arfima/arfima_model.py +354 -0
- models/data_models/base_model.py +105 -0
- models/data_models/fbm/__init__.py +10 -0
- models/data_models/fbm/fbm_model.py +242 -0
- models/data_models/fgn/__init__.py +13 -0
- models/data_models/fgn/fgn_model.py +103 -0
- models/data_models/mrw/__init__.py +10 -0
- models/data_models/mrw/mrw_model.py +263 -0
- models/data_models/neural_fsde/__init__.py +183 -0
- models/data_models/neural_fsde/base_neural_fsde.py +261 -0
- models/data_models/neural_fsde/fractional_brownian_motion.py +386 -0
- models/data_models/neural_fsde/hybrid_factory.py +439 -0
- models/data_models/neural_fsde/jax_fsde_net.py +764 -0
- models/data_models/neural_fsde/numerical_solvers.py +702 -0
- models/data_models/neural_fsde/test_hybrid_system.py +279 -0
- models/data_models/neural_fsde/torch_fsde_net.py +785 -0
- models/estimators/__init__.py +11 -0
- models/estimators/base_estimator.py +136 -0
|
@@ -0,0 +1,36 @@
|
|
|
1
|
+
"""
|
|
2
|
+
JAX-optimized high-performance estimators.
|
|
3
|
+
|
|
4
|
+
This package contains JAX-optimized versions of all estimators for GPU acceleration
|
|
5
|
+
and improved performance on large datasets.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from .dfa_jax import DFAEstimatorJAX
|
|
9
|
+
from .rs_jax import RSEstimatorJAX
|
|
10
|
+
from .higuchi_jax import HiguchiEstimatorJAX
|
|
11
|
+
from .dma_jax import DMAEstimatorJAX
|
|
12
|
+
from .periodogram_jax import PeriodogramEstimatorJAX
|
|
13
|
+
from .whittle_jax import WhittleEstimatorJAX
|
|
14
|
+
from .gph_jax import GPHEstimatorJAX
|
|
15
|
+
from .wavelet_log_variance_jax import WaveletLogVarianceEstimatorJAX
|
|
16
|
+
from .wavelet_variance_jax import WaveletVarianceEstimatorJAX
|
|
17
|
+
from .wavelet_whittle_jax import WaveletWhittleEstimatorJAX
|
|
18
|
+
from .cwt_jax import CWTEstimatorJAX
|
|
19
|
+
from .mfdfa_jax import MFDFAEstimatorJAX
|
|
20
|
+
from .multifractal_wavelet_leaders_jax import MultifractalWaveletLeadersEstimatorJAX
|
|
21
|
+
|
|
22
|
+
__all__ = [
|
|
23
|
+
'DFAEstimatorJAX',
|
|
24
|
+
'RSEstimatorJAX',
|
|
25
|
+
'HiguchiEstimatorJAX',
|
|
26
|
+
'DMAEstimatorJAX',
|
|
27
|
+
'PeriodogramEstimatorJAX',
|
|
28
|
+
'WhittleEstimatorJAX',
|
|
29
|
+
'GPHEstimatorJAX',
|
|
30
|
+
'WaveletLogVarianceEstimatorJAX',
|
|
31
|
+
'WaveletVarianceEstimatorJAX',
|
|
32
|
+
'WaveletWhittleEstimatorJAX',
|
|
33
|
+
'CWTEstimatorJAX',
|
|
34
|
+
'MFDFAEstimatorJAX',
|
|
35
|
+
'MultifractalWaveletLeadersEstimatorJAX'
|
|
36
|
+
]
|
|
@@ -0,0 +1,242 @@
|
|
|
1
|
+
"""
|
|
2
|
+
JAX-optimized Continuous Wavelet Transform (CWT) Analysis estimator.
|
|
3
|
+
|
|
4
|
+
This module provides JAX-optimized Continuous Wavelet Transform analysis for estimating
|
|
5
|
+
the Hurst parameter from time series data using continuous wavelet decomposition.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import numpy as np
|
|
9
|
+
import jax
|
|
10
|
+
import jax.numpy as jnp
|
|
11
|
+
from jax import jit, vmap
|
|
12
|
+
from typing import Optional, Tuple, List, Dict, Any
|
|
13
|
+
from models.estimators.base_estimator import BaseEstimator
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class CWTEstimatorJAX(BaseEstimator):
    """
    JAX-based Continuous Wavelet Transform (CWT) style Hurst estimator.

    The estimator computes, for every configured scale, the mean squared
    magnitude ("power") of the series convolved with a scale-dependent kernel,
    then fits a straight line to log2(power) versus log2(scale). The Hurst
    parameter is derived from the fitted slope as ``H = (slope + 1) / 2``.

    NOTE: the per-scale coefficients come from a normalized Gaussian kernel
    (see ``_compute_cwt_jax``). The ``wavelet`` name is stored and type-checked
    but is not otherwise consulted by the methods defined in this class.

    Attributes:
        wavelet (str): Wavelet type name supplied by the caller
        scales (np.ndarray): Array of scales for the analysis
        confidence (float): Confidence level; validated to lie in (0, 1)
        use_gpu (bool): Whether a GPU backend was requested and found
        results (dict): Output of the most recent ``estimate`` call
    """

    def __init__(self, wavelet: str = 'cmor1.5-1.0', scales: Optional[np.ndarray] = None,
                 confidence: float = 0.95, use_gpu: bool = False):
        """
        Initialize the JAX-optimized CWT estimator.

        Args:
            wavelet (str): Wavelet type name (default: 'cmor1.5-1.0')
            scales (np.ndarray, optional): Array of scales for analysis.
                If None, 20 logarithmically spaced scales from 10 to 10**4 are used
            confidence (float): Confidence level, strictly between 0 and 1 (default: 0.95)
            use_gpu (bool): Whether to use GPU acceleration (default: False)

        Raises:
            ValueError: If any parameter fails ``_validate_parameters``
        """
        super().__init__()
        self.wavelet = wavelet
        self.confidence = confidence
        self.use_gpu = use_gpu

        # Set default scales if not provided
        if scales is None:
            self.scales = np.logspace(1, 4, 20)  # Logarithmically spaced scales
        else:
            self.scales = scales

        # Results storage
        self.results = {}
        self._validate_parameters()
        self._jit_functions()

        # GPU setup: probe for a GPU backend and fall back to CPU when absent.
        if self.use_gpu:
            try:
                jax.devices('gpu')
                print("JAX CWT: Using GPU acceleration")
            except RuntimeError:
                # JAX signals a missing/uninitializable backend with RuntimeError.
                # Catching only that keeps KeyboardInterrupt and genuine bugs visible.
                print("JAX CWT: GPU not available, using CPU")
                self.use_gpu = False

    def _validate_parameters(self) -> None:
        """
        Validate the estimator parameters.

        Raises:
            ValueError: If ``wavelet`` is not a string, ``scales`` is not a
                non-empty numpy array, or ``confidence`` is outside (0, 1)
        """
        if not isinstance(self.wavelet, str):
            raise ValueError("wavelet must be a string")
        if not isinstance(self.scales, np.ndarray) or len(self.scales) == 0:
            raise ValueError("scales must be a non-empty numpy array")
        if not (0 < self.confidence < 1):
            raise ValueError("confidence must be between 0 and 1")

    def _jit_functions(self):
        """JIT compile the core computation functions (currently a no-op)."""
        # Note: Functions have dynamic parameters (kernel length depends on the
        # scale value), so we don't JIT them to avoid tracing issues
        pass

    def _compute_cwt_jax(self, data: jnp.ndarray, scale: float) -> jnp.ndarray:
        """
        Compute approximate CWT coefficients for a given scale using JAX.

        The series is convolved (``mode='valid'``) with a unit-sum Gaussian
        kernel whose length is ``max(3, int(scale * 10))`` samples.

        Args:
            data: Input time series data
            scale: Wavelet scale

        Returns:
            Coefficients at the given scale, or an empty array when the series
            is shorter than the kernel
        """
        # Kernel size proportional to scale, with a floor of 3 taps
        kernel_size = max(int(scale * 10), 3)

        # Bail out before building the kernel: with large scales the kernel can
        # be far longer than the series and would be allocated for nothing.
        if len(data) < kernel_size:
            return jnp.array([])

        # Create Gaussian-like kernel and normalize it to unit sum
        x = jnp.linspace(-3, 3, kernel_size)
        kernel = jnp.exp(-x**2 / (2 * scale**2))
        kernel = kernel / jnp.sum(kernel)

        # 'valid' keeps only positions where the kernel fully overlaps the data
        return jnp.convolve(data, kernel, mode='valid')

    def _linear_regression_jax(self, x: jnp.ndarray, y: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
        """
        Perform ordinary least-squares linear regression using JAX.

        The degenerate cases are resolved with ``jnp.where`` instead of Python
        ``if`` statements, so the function can also be traced by JAX
        transformations.

        Args:
            x: Independent variable
            y: Dependent variable

        Returns:
            Tuple of (slope, intercept, r_squared). The slope is 0 when ``x``
            has zero variance; r_squared is 0 when ``y`` has zero variance
        """
        # Center the data
        x_mean = jnp.mean(x)
        y_mean = jnp.mean(y)
        x_centered = x - x_mean
        y_centered = y - y_mean

        # Compute slope
        numerator = jnp.sum(x_centered * y_centered)
        denominator = jnp.sum(x_centered ** 2)

        # Substitute a harmless divisor in the degenerate case so the branch
        # that is discarded by the outer `where` never evaluates 0/0.
        x_is_constant = denominator == 0
        safe_denominator = jnp.where(x_is_constant, 1.0, denominator)
        slope = jnp.where(x_is_constant, 0.0, numerator / safe_denominator)

        # Compute intercept
        intercept = y_mean - slope * x_mean

        # Compute R-squared
        y_pred = slope * x + intercept
        ss_res = jnp.sum((y - y_pred) ** 2)
        ss_tot = jnp.sum(y_centered ** 2)

        y_is_constant = ss_tot == 0
        safe_ss_tot = jnp.where(y_is_constant, 1.0, ss_tot)
        r_squared = jnp.where(y_is_constant, 0.0, 1 - (ss_res / safe_ss_tot))

        return slope, intercept, r_squared

    def estimate(self, data: np.ndarray) -> Dict[str, Any]:
        """
        Estimate the Hurst parameter using JAX-based CWT-style analysis.

        Args:
            data: Input time series data

        Returns:
            Dictionary (also stored in ``self.results``) with the keys
            ``hurst_parameter``, ``r_squared``, ``std_error``,
            ``confidence_interval``, ``scale_powers``, ``scale_logs``,
            ``power_logs``, ``slope`` and ``intercept``. When fewer than two
            scales yield a positive power, neutral defaults are returned
            (H = 0.5, slope = 0).

        Raises:
            ValueError: If the series has fewer than 100 samples
        """
        data = jnp.asarray(data)

        if len(data) < 100:
            raise ValueError("Data length must be at least 100 for CWT analysis")

        # Calculate the power at each scale and collect the log-log points
        scale_logs = []
        power_logs = []
        scale_powers = {}

        for scale in self.scales:
            coeffs = self._compute_cwt_jax(data, scale)

            # Scales whose kernel is longer than the series contribute nothing
            if len(coeffs) == 0:
                continue

            power = jnp.mean(jnp.abs(coeffs) ** 2)
            power_value = float(power)
            scale_powers[scale] = power_value

            # log2 is only defined for strictly positive power
            if power_value > 0:
                scale_logs.append(jnp.log2(scale))
                power_logs.append(jnp.log2(power))

        if len(scale_logs) < 2:
            # Not enough points for a regression: return neutral defaults with
            # the same key set as the regular result so callers can read any
            # key without special-casing this path.
            self.results = {
                "hurst_parameter": 0.5,
                "r_squared": 0.0,
                "std_error": 0.0,
                "confidence_interval": (0.5, 0.5),
                "scale_powers": scale_powers,
                "scale_logs": [float(v) for v in scale_logs],
                "power_logs": [float(v) for v in power_logs],
                "slope": 0.0,
                "intercept": 0.0
            }
            return self.results

        # Convert to JAX arrays and perform the linear regression
        x = jnp.array(scale_logs)
        y = jnp.array(power_logs)
        slope, intercept, r_squared = self._linear_regression_jax(x, y)

        slope_value = float(slope)
        r_squared_value = float(r_squared)

        # Hurst parameter is related to the slope: H = (slope + 1) / 2,
        # restricted to [0.01, 0.99]. Plain min/max keeps the value a Python
        # float (no float32 re-rounding) and still propagates NaN.
        hurst_parameter = min(max((slope_value + 1) / 2, 0.01), 0.99)

        # Calculate confidence interval (simplified heuristic based on
        # R-squared; it does not use self.confidence)
        n = len(scale_logs)
        if n > 2:
            margin = 0.1 * (1 - r_squared_value)
            confidence_interval = (hurst_parameter - margin, hurst_parameter + margin)
        else:
            confidence_interval = (hurst_parameter, hurst_parameter)

        # Store results
        self.results = {
            "hurst_parameter": hurst_parameter,
            "r_squared": r_squared_value,
            "std_error": 0.0,  # Simplified for JAX version
            "confidence_interval": confidence_interval,
            "scale_powers": scale_powers,
            "scale_logs": [float(v) for v in scale_logs],
            "power_logs": [float(v) for v in power_logs],
            "slope": slope_value,
            "intercept": float(intercept)
        }

        return self.results
|