crisp-ase 1.0.0.post0.dev0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of crisp-ase might be problematic. Click here for more details.
- CRISP/__init__.py +19 -0
- CRISP/_version.py +1 -0
- CRISP/cli.py +31 -0
- CRISP/data_analysis/__init__.py +9 -0
- CRISP/data_analysis/clustering.py +828 -0
- CRISP/data_analysis/contact_coordination.py +915 -0
- CRISP/data_analysis/h_bond.py +716 -0
- CRISP/data_analysis/msd.py +1179 -0
- CRISP/data_analysis/prdf.py +403 -0
- CRISP/data_analysis/volumetric_atomic_density.py +527 -0
- CRISP/py.typed +1 -0
- CRISP/simulation_utility/__init__.py +9 -0
- CRISP/simulation_utility/atomic_indices.py +144 -0
- CRISP/simulation_utility/atomic_traj_linemap.py +278 -0
- CRISP/simulation_utility/error_analysis.py +253 -0
- CRISP/simulation_utility/interatomic_distances.py +198 -0
- CRISP/simulation_utility/subsampling.py +221 -0
- CRISP/tests/__init__.py +3 -0
- CRISP/tests/test_CRISP.py +28 -0
- CRISP/tests/test_crisp_comprehensive.py +677 -0
- crisp_ase-1.0.0.post0.dev0.dist-info/METADATA +116 -0
- crisp_ase-1.0.0.post0.dev0.dist-info/RECORD +25 -0
- crisp_ase-1.0.0.post0.dev0.dist-info/WHEEL +5 -0
- crisp_ase-1.0.0.post0.dev0.dist-info/entry_points.txt +2 -0
- crisp_ase-1.0.0.post0.dev0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,278 @@
|
|
|
1
|
+
"""
|
|
2
|
+
CRISP/simulation_utility/atomic_traj_linemap.py
|
|
3
|
+
|
|
4
|
+
This module provides functionality for visualizing atomic trajectories from molecular dynamics
|
|
5
|
+
simulations using interactive 3D plots.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import os
|
|
9
|
+
import numpy as np
|
|
10
|
+
from typing import List, Optional, Dict, Union
|
|
11
|
+
from ase.io import read
|
|
12
|
+
import plotly.graph_objects as go
|
|
13
|
+
import plotly.io as pio
|
|
14
|
+
pio.renderers.default = "notebook"
|
|
15
|
+
|
|
16
|
+
# Van der Waals radii in Ångström, keyed by element symbol and grouped by
# period of the periodic table.
VDW_RADII = {
    # Period 1
    "H": 1.20, "He": 1.40,
    # Period 2
    "Li": 1.82, "Be": 1.53, "B": 1.92, "C": 1.70,
    "N": 1.55, "O": 1.52, "F": 1.47, "Ne": 1.54,
    # Period 3
    "Na": 2.27, "Mg": 1.73, "Al": 1.84, "Si": 2.10,
    "P": 1.80, "S": 1.80, "Cl": 1.75, "Ar": 1.88,
    # Period 4
    "K": 2.75, "Ca": 2.31, "Sc": 2.30, "Ti": 2.15, "V": 2.05, "Cr": 2.05,
    "Mn": 2.05, "Fe": 2.05, "Co": 2.00, "Ni": 2.00, "Cu": 2.00, "Zn": 2.10,
    "Ga": 1.87, "Ge": 2.11, "As": 1.85, "Se": 1.90, "Br": 1.85, "Kr": 2.02,
    # Period 5
    "Rb": 3.03, "Sr": 2.49, "Y": 2.40, "Zr": 2.30, "Nb": 2.15, "Mo": 2.10,
    "Tc": 2.05, "Ru": 2.05, "Rh": 2.00, "Pd": 2.05, "Ag": 2.10, "Cd": 2.20,
    "In": 2.20, "Sn": 2.17, "Sb": 2.06, "Te": 2.06, "I": 1.98, "Xe": 2.16,
    # Period 6
    "Cs": 3.43, "Ba": 2.68, "La": 2.50, "Ce": 2.48, "Pr": 2.47, "Nd": 2.45,
    "Pm": 2.43, "Sm": 2.42, "Eu": 2.40, "Gd": 2.38, "Tb": 2.37, "Dy": 2.35,
    "Ho": 2.33, "Er": 2.32, "Tm": 2.30, "Yb": 2.28, "Lu": 2.27, "Hf": 2.25,
    "Ta": 2.20, "W": 2.10, "Re": 2.05, "Os": 2.00, "Ir": 2.00, "Pt": 2.05,
    "Au": 2.10, "Hg": 2.05, "Tl": 2.20, "Pb": 2.30, "Bi": 2.30, "Po": 2.00,
    "At": 2.00, "Rn": 2.00,
    # Period 7
    "Fr": 3.50, "Ra": 2.80, "Ac": 2.60, "Th": 2.40, "Pa": 2.30, "U": 2.30,
    "Np": 2.30, "Pu": 2.30, "Am": 2.30, "Cm": 2.30, "Bk": 2.30, "Cf": 2.30,
    "Es": 2.30, "Fm": 2.30, "Md": 2.30, "No": 2.30, "Lr": 2.30, "Rf": 2.30,
    "Db": 2.30, "Sg": 2.30, "Bh": 2.30, "Hs": 2.30, "Mt": 2.30, "Ds": 2.30,
    "Rg": 2.30, "Cn": 2.30, "Nh": 2.30, "Fl": 2.30, "Mc": 2.30, "Lv": 2.30,
    "Ts": 2.30, "Og": 2.30,
}
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
# Display colour (CSS colour name) per element symbol. Elements not listed
# here are expected to be handled by the caller's own fallback.
ELEMENT_COLORS = {
    # Common elements
    "H": "white", "C": "black", "N": "blue", "O": "red",
    "F": "green", "Na": "purple", "Mg": "pink", "Al": "gray",
    "Si": "yellow", "P": "orange", "S": "yellow", "Cl": "green",
    "K": "purple", "Ca": "gray", "Fe": "orange", "Cu": "orange",
    "Zn": "gray",
    # Additional common elements with colors
    "Br": "brown", "I": "purple", "Li": "purple", "B": "olive",
    "He": "cyan", "Ne": "cyan", "Ar": "cyan", "Kr": "cyan",
    "Xe": "cyan", "Mn": "gray", "Co": "blue", "Ni": "green",
    "Pd": "gray", "Pt": "gray", "Au": "gold", "Hg": "silver",
    "Pb": "darkgray", "Ag": "silver", "Ti": "gray", "V": "gray",
    "Cr": "gray", "Zr": "gray", "Mo": "gray", "W": "gray",
    "U": "green",
}
|
|
61
|
+
|
|
62
|
+
def _endpoint_annotation(position, text, color, ax):
    """Build a plotly scene annotation that points at one trajectory end."""
    return dict(
        x=position[0],
        y=position[1],
        z=position[2],
        text=text,
        showarrow=True,
        arrowhead=2,
        ax=ax,
        ay=-20,
        arrowcolor=color,
        font=dict(color=color, size=10)
    )


def plot_atomic_trajectory(
    traj_path: str,
    indices_path: Union[str, List[int]],
    output_dir: str,
    output_filename: str = "trajectory_plot.html",
    frame_skip: int = 100,
    plot_title: Optional[str] = None,
    show_plot: bool = False,
    atom_size_scale: float = 1.0,
    plot_lines: bool = False
):
    """
    Create a 3D visualization of atom trajectories with all atom types displayed.

    Parameters
    ----------
    traj_path : str
        Path to the ASE trajectory file (supports any ASE-readable format like XYZ)
    indices_path : str or List[int]
        Either a path to numpy file containing atom indices to plot trajectories for,
        or a direct list of atom indices
    output_dir : str
        Directory where the output visualization will be saved (created if missing)
    output_filename : str, optional
        Filename for the output visualization (default: "trajectory_plot.html")
    frame_skip : int, optional
        Use every nth frame from the trajectory (default: 100)
    plot_title : str, optional
        Custom title for the plot (default: auto-generated based on atom types)
    show_plot : bool, optional
        Whether to display the plot interactively (default: False)
    atom_size_scale : float, optional
        Scale factor for atom sizes in the visualization (default: 1.0)
    plot_lines : bool, optional
        Whether to connect trajectory points with lines (default: False)

    Returns
    -------
    plotly.graph_objects.Figure
        The generated plotly figure object that can be further customized

    Raises
    ------
    ValueError
        If no frames could be read from ``traj_path``

    Notes
    -----
    This function creates an interactive 3D visualization showing:
    1. All atoms from the first frame, colored by element
    2. Trajectory paths for selected atoms throughout all frames
    3. Annotations for the start and end positions of traced atoms

    Axis ranges are fixed to ``[0, cell length]`` taken from the first frame.
    When a cell length is zero (no cell stored in the file) that axis is left
    to plotly's automatic ranging instead of being pinned to ``[0, 0]``.

    The output is saved as an HTML file which can be opened in any web browser.
    """
    print(f"Loading trajectory from {traj_path} (using every {frame_skip}th frame)...")
    traj = read(traj_path, index=f'::{frame_skip}')

    # Convert to list if not already (happens with single frame)
    if not isinstance(traj, list):
        traj = [traj]

    print(f"Loaded {len(traj)} frames from trajectory")
    if not traj:
        raise ValueError(f"No frames could be read from {traj_path}")

    if isinstance(indices_path, str):
        selected_indices = np.load(indices_path)
        print(f"Loaded {len(selected_indices)} atoms for trajectory plotting from {indices_path}")
    else:
        selected_indices = np.array(indices_path)
        print(f"Using {len(selected_indices)} directly provided atom indices for trajectory plotting")

    first_frame = traj[0]
    box = first_frame.cell.lengths()
    print(f"Simulation box dimensions: {box} Å")

    max_index = max(atom.index for atom in first_frame)
    print(f"Analyzing atom types in first frame (total atoms: {len(first_frame)}, max index: {max_index})...")

    # Group the atom indices of the first frame by chemical symbol.
    atom_types = {}
    for atom in first_frame:
        atom_types.setdefault(atom.symbol, []).append(atom.index)

    print(f"Found {len(atom_types)} atom types: {', '.join(atom_types.keys())}")

    fig = go.Figure()

    # With more than five traced atoms a distinct colour per atom is not
    # attempted; every trajectory is drawn in blue.
    use_same_color = len(selected_indices) > 5
    colors = ['blue'] * len(selected_indices) if use_same_color else [
        'blue', 'green', 'red', 'orange', 'purple'
    ][:len(selected_indices)]

    # One marker trace per element, showing the first-frame configuration.
    for symbol, indices in atom_types.items():
        positions = first_frame.positions[indices]

        size = VDW_RADII.get(symbol, 1.0) * 3.0 * atom_size_scale
        color = ELEMENT_COLORS.get(symbol, 'gray')

        fig.add_trace(go.Scatter3d(
            x=positions[:, 0],
            y=positions[:, 1],
            z=positions[:, 2],
            mode='markers',
            name=f'{symbol} Atoms',
            marker=dict(
                size=size,
                color=color,
                symbol='circle',
                opacity=0.7,
                line=dict(color='black', width=0.5)
            )
        ))

    # Collect the per-frame positions of every traced atom. An index that is
    # out of range is reported once rather than once per frame.
    selected_positions = {idx: [] for idx in selected_indices}
    reported_out_of_range = set()
    for atoms in traj:
        n_atoms = len(atoms)
        for idx in selected_indices:
            if idx < n_atoms:
                selected_positions[idx].append(atoms.positions[idx])
            elif idx not in reported_out_of_range:
                reported_out_of_range.add(idx)
                print(f"Warning: Index {idx} is out of range")

    annotations = []
    for i, idx in enumerate(selected_indices):
        if not selected_positions[idx]:
            continue

        pos = np.array(selected_positions[idx])
        color = colors[i % len(colors)]

        # Add annotations for first and last frames
        annotations.append(_endpoint_annotation(pos[0], f'Start {idx}', color, ax=20))
        annotations.append(_endpoint_annotation(pos[-1], f'End {idx}', color, ax=-20))

        if plot_lines:
            # Trajectory drawn as a connected line with small markers
            trace_style = dict(
                mode='lines+markers',
                line=dict(width=3, color=color),
                marker=dict(size=4, color=color),
            )
        else:
            # Scatter-only mode: just showing points for each frame
            trace_style = dict(
                mode='markers',
                marker=dict(
                    size=5,
                    color=color,
                    symbol='circle',
                    opacity=0.8,
                    line=dict(color='black', width=0.5)
                ),
            )

        fig.add_trace(go.Scatter3d(
            x=pos[:, 0],
            y=pos[:, 1],
            z=pos[:, 2],
            name=f'Atom {idx}',
            **trace_style
        ))

    if not plot_title:
        atom_types_str = ', '.join(atom_types.keys())
        plot_title = f'Atomic Trajectories in {atom_types_str} System'

    # Pin each axis to the cell only when the cell length is usable; a zero
    # length would otherwise collapse the axis to the degenerate range [0, 0].
    axis_specs = [
        dict(range=[0, length]) if length > 0 else dict()
        for length in box
    ]

    fig.update_layout(
        title=plot_title,
        scene=dict(
            xaxis_title='X (Å)',
            yaxis_title='Y (Å)',
            zaxis_title='Z (Å)',
            xaxis=axis_specs[0],
            yaxis=axis_specs[1],
            zaxis=axis_specs[2],
            aspectmode='cube'
        ),
        margin=dict(l=0, r=0, b=0, t=40),
        scene_annotations=annotations
    )

    os.makedirs(output_dir, exist_ok=True)
    output_path = os.path.join(output_dir, output_filename)

    fig.write_html(output_path)
    print(f"Plot has been saved to {output_path}")

    if show_plot:
        fig.show()

    return fig
|
|
@@ -0,0 +1,253 @@
|
|
|
1
|
+
"""
|
|
2
|
+
CRISP/simulation_utility/error_analysis.py
|
|
3
|
+
|
|
4
|
+
This module provides statistical error analysis tools for molecular dynamics simulations,
|
|
5
|
+
including autocorrelation and block averaging methods to estimate statistical errors.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import numpy as np
|
|
9
|
+
import matplotlib.pyplot as plt
|
|
10
|
+
from statsmodels.tsa.stattools import acf
|
|
11
|
+
import warnings
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def optimal_lag(acf_values, threshold=0.05):
    """
    Locate the first lag whose autocorrelation magnitude is below ``threshold``.

    Parameters
    ----------
    acf_values : numpy.ndarray
        Autocorrelation function values, indexed by lag
    threshold : float, optional
        Magnitude below which the data is treated as uncorrelated (default: 0.05)

    Returns
    -------
    int
        Index of the first value with ``abs(value) < threshold``; if no value
        qualifies, the last available lag ``len(acf_values) - 1``

    Warns
    -----
    UserWarning
        If no value falls below the threshold within the supplied data
    """
    first_below = next(
        (lag for lag, value in enumerate(acf_values) if abs(value) < threshold),
        None,
    )
    if first_below is not None:
        return first_below

    # Nothing decayed below the threshold: warn and fall back to the last lag.
    last_lag = len(acf_values) - 1
    warnings.warn(
        "Autocorrelation function is not converged. "
        f"Consider increasing the 'max_lag' parameter (current: {last_lag}) "
        "or extending the simulation length."
    )
    return last_lag
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
def vector_acf(data, max_lag):
    """
    Compute the normalized autocorrelation of a vector-valued time series.

    For each lag the dot product of the mean-subtracted vectors separated by
    that lag is averaged over all available frame pairs and divided by the
    mean squared norm of the mean-subtracted data.

    Parameters
    ----------
    data : numpy.ndarray
        Input vector data with shape (n_frames, n_dimensions)
    max_lag : int
        Largest lag to evaluate

    Returns
    -------
    numpy.ndarray
        Autocorrelation values for lags 0 through max_lag (length max_lag + 1)
    """
    n_frames = data.shape[0]
    fluctuations = data - np.mean(data, axis=0)
    # Lag-zero normalization: mean squared norm of the fluctuations.
    mean_sq_norm = np.mean(np.sum(fluctuations**2, axis=1))

    acf_vals = np.zeros(max_lag + 1)
    for tau, _ in enumerate(acf_vals):
        leading = fluctuations[:n_frames - tau]
        lagged = fluctuations[tau:]
        acf_vals[tau] = np.mean(np.sum(leading * lagged, axis=1)) / mean_sq_norm

    return acf_vals
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
def autocorrelation_analysis(data, max_lag=None, threshold=0.05, plot_acf=False):
    """
    Perform autocorrelation analysis to estimate statistical errors.

    Parameters
    ----------
    data : array_like
        Input data. A 1D array is treated as a scalar time series; any other
        array is treated as vector data with frames along the first axis.
    max_lag : int, optional
        Maximum lag time to calculate autocorrelation for
        (default: min(1000, N // 10) for 1D data, min(1000, N // 20) otherwise)
    threshold : float, optional
        Correlation threshold below which data is considered uncorrelated (default: 0.05)
    plot_acf : bool, optional
        Whether to generate an autocorrelation plot, saved as
        "ACF_lag_analysis.png" in the current working directory (default: False)

    Returns
    -------
    dict
        Dictionary containing:
        - mean: Mean value of data (per component for vector data)
        - acf_err: Error estimate from autocorrelation analysis
        - std: Standard deviation of data (per component for vector data)
        - tau_int: Integrated autocorrelation time
        - optimal_lag: Optimal lag time where autocorrelation drops below threshold
    """
    # Accept lists and other array-likes when called directly.
    data = np.asarray(data)

    if data.ndim == 1:
        N = len(data)
        mean_value = np.mean(data)
        std_value = np.std(data, ddof=1)
        if max_lag is None:
            max_lag = min(1000, N // 10)
        acf_values = acf(data - mean_value, nlags=max_lag, fft=True)
    else:
        N = data.shape[0]
        mean_value = np.mean(data, axis=0)
        std_value = np.std(data, axis=0, ddof=1)
        if max_lag is None:
            # min() needs two arguments; the single-argument call raised
            # TypeError for every vector input with the default max_lag.
            max_lag = min(1000, N // 20)
        acf_values = vector_acf(data, max_lag)

    opt_lag = optimal_lag(acf_values, threshold)

    # Integrated autocorrelation time, summed up to the optimal lag.
    tau_int = 0.5 + np.sum(acf_values[1:opt_lag + 1])
    autocorr_error = std_value * np.sqrt(2 * tau_int / N)

    if plot_acf:
        plt.figure(figsize=(8, 5))
        plt.plot(np.arange(len(acf_values)), acf_values, linestyle='-', linewidth=2, color="blue", label='ACF')
        plt.axhline(y=threshold, color='red', linestyle='--', linewidth=1.5, label='Threshold')
        plt.xlabel('Lag')
        plt.ylabel('Autocorrelation')
        plt.title('Autocorrelation Function (ACF)')
        plt.legend()
        plt.savefig("ACF_lag_analysis.png", dpi=300, bbox_inches='tight')

    return {"mean": mean_value, "acf_err": autocorr_error, "std": std_value, "tau_int": tau_int, "optimal_lag": opt_lag}
|
|
133
|
+
|
|
134
|
+
|
|
135
|
+
def block_analysis(data, convergence_tol=0.001, plot_blocks=False):
    """
    Perform block averaging analysis to estimate statistical errors.

    The data is split into M equal-length blocks for M = 2, 3, ... and the
    standard error of the block means is recorded for each M. The scan stops
    as soon as the last five standard errors differ by less than
    ``convergence_tol``.

    Parameters
    ----------
    data : array_like
        Input data array (the blocking reshape assumes a 1D series)
    convergence_tol : float, optional
        Tolerance for determining convergence of standard error (default: 0.001)
    plot_blocks : bool, optional
        Whether to generate a block averaging plot, saved as
        "block_averaging_convergence.png" in the current working directory
        (default: False)

    Returns
    -------
    dict
        Dictionary containing:
        - mean: Mean value of data
        - block_err: Error estimate from block averaging
        - std: Standard deviation of data
        - converged_blocks: Number of blocks at convergence

    Raises
    ------
    ValueError
        If fewer than 6 data points are supplied, which leaves no block
        count to evaluate

    Warns
    -----
    UserWarning
        If block averaging does not converge with the given tolerance
    """
    # Accept lists and other array-likes; slicing + reshape needs an ndarray.
    data = np.asarray(data)
    N = len(data)
    if N < 6:
        raise ValueError(
            f"Block averaging needs at least 6 data points, got {N}."
        )

    mean_value = np.mean(data)
    std_value = np.std(data, ddof=1)

    # A single block has no spread between block means, so the scan starts at
    # two blocks. block_counts records the M belonging to each standard error
    # so the diagnostic plot uses the correct x values.
    block_counts = []
    standard_errors = []

    for M in np.arange(2, N // 2):
        block_length = N // M

        # Drop the trailing samples that do not fill a complete block.
        blocks = data[:block_length * M].reshape(M, block_length)
        block_means = np.mean(blocks, axis=1)

        std_error = np.std(block_means, ddof=1) / np.sqrt(M)
        block_counts.append(M)
        standard_errors.append(std_error)

        if len(standard_errors) > 5:
            recent_errors = standard_errors[-5:]
            if np.max(recent_errors) - np.min(recent_errors) < convergence_tol:
                converged_blocks = M
                final_error = std_error
                break
    else:
        # Scan exhausted without convergence: report the last evaluated point.
        converged_blocks = block_counts[-1]
        final_error = standard_errors[-1]
        warnings.warn("Block averaging did not fully converge. Consider increasing data length or lowering tolerance.")

    if plot_blocks:
        plt.figure(figsize=(8, 5))
        plt.plot(block_counts, standard_errors, color="blue", label='Standard Error')
        plt.xlabel('Number of Blocks')
        plt.ylabel('Standard Error')
        plt.title('Block Averaging Convergence')
        plt.savefig("block_averaging_convergence.png", dpi=300, bbox_inches='tight')

    return {
        "mean": mean_value,
        "block_err": final_error,
        "std": std_value,
        "converged_blocks": converged_blocks
    }
|
|
208
|
+
|
|
209
|
+
|
|
210
|
+
def error_analysis(data, max_lag=None, threshold=0.05, convergence_tol=0.001, plot=False):
    """
    Run autocorrelation and block averaging on the same data and bundle the results.

    Parameters
    ----------
    data : numpy.ndarray
        Input data array (converted with ``numpy.asarray`` before use)
    max_lag : int, optional
        Maximum lag time, forwarded to the autocorrelation analysis
    threshold : float, optional
        Correlation threshold, forwarded to the autocorrelation analysis (default: 0.05)
    convergence_tol : float, optional
        Convergence tolerance, forwarded to the block averaging analysis (default: 0.001)
    plot : bool, optional
        Whether both analyses should generate their diagnostic plots (default: False)

    Returns
    -------
    dict
        Dictionary containing results from both methods:
        - mean: Mean value of data (taken from the autocorrelation results)
        - std: Standard deviation of data (taken from the autocorrelation results)
        - acf_results: Full results from autocorrelation analysis
        - block_results: Full results from block averaging analysis
    """
    samples = np.asarray(data)

    acf_summary = autocorrelation_analysis(samples, max_lag, threshold, plot_acf=plot)
    block_summary = block_analysis(samples, convergence_tol, plot_blocks=plot)

    return {
        "mean": acf_summary["mean"],
        "std": acf_summary["std"],
        "acf_results": acf_summary,
        "block_results": block_summary,
    }
|
|
252
|
+
|
|
253
|
+
|