@yibeichan/claude-skills 1.0.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/LICENSE +21 -0
- package/README.md +98 -0
- package/cli.js +272 -0
- package/install.py +240 -0
- package/package.json +44 -0
- package/skills/bidsapp-nidm-standards/SKILL.md +202 -0
- package/skills/bidsapp-nidm-standards/references/babs_config.md +20 -0
- package/skills/bidsapp-nidm-standards/references/cli_arguments.md +76 -0
- package/skills/bidsapp-nidm-standards/references/container_patterns.md +53 -0
- package/skills/bidsapp-nidm-standards/references/nidm_integration.md +403 -0
- package/skills/bidsapp-nidm-standards/references/repo_structure.md +121 -0
- package/skills/bidsapp-nidm-standards/references/testing_patterns.md +82 -0
- package/skills/dicom2fmriprep/SKILL.md +377 -0
- package/skills/dicom2fmriprep/evals/evals.json +26 -0
- package/skills/dicom2fmriprep/references/babs-details.md +407 -0
- package/skills/dicom2fmriprep/references/fmriprep-details.md +250 -0
- package/skills/dicom2fmriprep/references/heudiconv-details.md +243 -0
- package/skills/fmri-ssm/SKILL.md +317 -0
- package/skills/fmri-ssm/references/code_templates.md +1570 -0
- package/skills/fmri-ssm/references/downstream_analysis.md +680 -0
- package/skills/fmri-ssm/references/group_inference.md +608 -0
- package/skills/fmri-ssm/references/hrf_modeling.md +447 -0
- package/skills/fmri-ssm/references/model_catalog.md +436 -0
- package/skills/fmri-ssm/references/paradigm_guide.md +406 -0
- package/skills/fmri-ssm/references/preprocessing.md +614 -0
- package/skills/fmri-ssm.zip +0 -0
- package/skills/neuroimaging-qc/SKILL.md +203 -0
- package/skills/neuroimaging-qc/references/eeg_qc.md +400 -0
- package/skills/neuroimaging-qc/references/fmri_qc.md +343 -0
- package/skills/neuroimaging-qc/references/fnirs_qc.md +430 -0
- package/skills/neuroimaging-qc/references/structural_qc.md +454 -0
- package/skills/neuroimaging-qc/scripts/parse_fmriprep_confounds.py +153 -0
- package/skills/neuroimaging-qc/scripts/parse_mriqc.py +114 -0
- package/skills/neuroimaging-qc/scripts/qc_report.py +295 -0
- package/skills/scientific-writer/SKILL.md +202 -0
- package/skills/scientific-writer/references/citation_styles.md +163 -0
- package/skills/scientific-writer/references/field_conventions.md +245 -0
- package/skills/scientific-writer/references/figures_tables.md +225 -0
- package/skills/scientific-writer/references/reporting_guidelines.md +225 -0
- package/skills.json +54 -0
@@ -0,0 +1,1570 @@
# Python Code Templates for SSMs on fMRI Data

## Table of Contents
1. [Gaussian HMM with hmmlearn](#gaussian-hmm)
2. [Gaussian HMM with ssm library](#gaussian-hmm-ssm)
3. [Sticky HMM](#sticky-hmm)
4. [HMM-MAR with osl-dynamics](#hmm-mar)
5. [Input-Output HMM](#io-hmm)
6. [SLDS with ssm library](#slds)
7. [rSLDS with ssm library](#rslds)
8. [Model Selection (choosing K)](#model-selection)
9. [State Visualization](#visualization)
10. [Reproducibility and Initialization](#reproducibility)
11. [Model Diagnostics: Detecting Pathological Fits](#diagnostics)
12. [JAX-Based HMM for GPU Acceleration](#jax-gpu)
13. [dynamax: Lego-Style Custom SSMs](#dynamax)

All code assumes preprocessed, parcellated timeseries. See `preprocessing.md` for
how to get from raw fMRIPrep outputs to the data matrices used here.

---

## 1. Gaussian HMM with hmmlearn {#gaussian-hmm}

The simplest and most widely used SSM for fMRI. Start here unless you have
specific reasons to use a more complex model.

```python
"""Gaussian HMM for fMRI using hmmlearn.

Fits a K-state HMM with Gaussian emissions on parcellated BOLD timeseries.
Includes multiple random restarts, K-means initialization, and run-boundary handling.
"""
import numpy as np
from hmmlearn import hmm
from sklearn.cluster import KMeans


def fit_gaussian_hmm(data, lengths, K, covariance_type='full',
                     n_restarts=50, n_iter=200, random_state=42):
    """Fit Gaussian HMM with multiple random restarts.

    Parameters
    ----------
    data : array, shape (total_T, n_features)
        Concatenated BOLD timeseries across runs
    lengths : list of int
        Number of TRs per run (for run boundary handling)
    K : int
        Number of hidden states
    covariance_type : str
        'full', 'diag', 'tied', or 'spherical'
    n_restarts : int
        Number of random restarts (take best log-likelihood)
    n_iter : int
        Max EM iterations per restart
    random_state : int
        Base random seed

    Returns
    -------
    best_model : GaussianHMM
        Fitted model with highest log-likelihood
    best_score : float
        Log-likelihood of best model
    all_scores : list
        Log-likelihoods from all restarts
    """
    best_model = None
    best_score = -np.inf
    all_scores = []

    for restart in range(n_restarts):
        model = hmm.GaussianHMM(
            n_components=K,
            covariance_type=covariance_type,
            n_iter=n_iter,
            tol=1e-4,
            random_state=random_state + restart,
            init_params='stmc',  # initialize all parameters
            verbose=False,
        )

        # K-means initialization for means (much better than random)
        if restart == 0:
            kmeans = KMeans(n_clusters=K, random_state=random_state, n_init=10)
            kmeans.fit(data)
            model.means_init = kmeans.cluster_centers_
            # Drop 'm' from init_params, otherwise fit() re-initializes
            # the means and means_init is ignored.
            model.init_params = 'stc'

        try:
            model.fit(data, lengths=lengths)
            score = model.score(data, lengths=lengths)
            all_scores.append(score)

            if score > best_score:
                best_score = score
                best_model = model
        except Exception:
            all_scores.append(np.nan)
            continue

    n_converged = sum(1 for s in all_scores if not np.isnan(s))
    print(f"Converged: {n_converged}/{n_restarts}")
    print(f"Best log-likelihood: {best_score:.2f}")
    print(f"Score range: {np.nanmin(all_scores):.2f} to {np.nanmax(all_scores):.2f}")

    return best_model, best_score, all_scores


def extract_state_info(model, data, lengths):
    """Extract state sequence, dwell times, and transition matrix.

    Returns
    -------
    results : dict with keys:
        'states': state sequence
        'state_probs': posterior state probabilities
        'dwell_times': dict mapping state -> list of dwell durations (in TRs)
        'fractional_occupancy': proportion of time in each state
        'transition_matrix': estimated transition matrix
        'means': state means, shape (K, n_features)
        'covariances': state covariances
    """
    states = model.predict(data, lengths=lengths)
    state_probs = model.predict_proba(data, lengths=lengths)

    # Compute dwell times (respecting run boundaries)
    dwell_times = {k: [] for k in range(model.n_components)}
    offset = 0
    for length in lengths:
        run_states = states[offset:offset + length]
        current_state = run_states[0]
        current_dwell = 1
        for t in range(1, length):
            if run_states[t] == current_state:
                current_dwell += 1
            else:
                dwell_times[current_state].append(current_dwell)
                current_state = run_states[t]
                current_dwell = 1
        dwell_times[current_state].append(current_dwell)  # last dwell
        offset += length

    # Fractional occupancy
    frac_occ = np.array([(states == k).sum() / len(states)
                         for k in range(model.n_components)])

    return {
        'states': states,
        'state_probs': state_probs,
        'dwell_times': dwell_times,
        'fractional_occupancy': frac_occ,
        'transition_matrix': model.transmat_,
        'means': model.means_,
        'covariances': model.covars_,
    }
```
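
A minimal usage sketch (the synthetic data, shapes, and K below are placeholders, not recommendations):

```python
# Hypothetical example: three runs of 200 TRs over 100 parcels
import numpy as np

rng = np.random.default_rng(0)
runs = [rng.standard_normal((200, 100)) for _ in range(3)]
data = np.vstack(runs)                # (600, 100) concatenated across runs
lengths = [r.shape[0] for r in runs]  # run boundaries for hmmlearn

model, best_score, all_scores = fit_gaussian_hmm(data, lengths, K=5, n_restarts=10)
info = extract_state_info(model, data, lengths)
print(info['fractional_occupancy'])   # proportion of TRs spent in each state
```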

---

## 2. Gaussian HMM with ssm library {#gaussian-hmm-ssm}

The `ssm` library (Linderman lab) provides a unified API for HMMs, SLDS, and rSLDS.
Useful when you want to compare model families within the same framework.

```python
"""Gaussian HMM using the ssm library (Linderman et al.)

Install from GitHub: pip install git+https://github.com/lindermanlab/ssm
"""
import ssm
import numpy as np


def fit_hmm_ssm(data_list, K, D=None, n_restarts=20, n_iters=200):
    """Fit Gaussian HMM using ssm library.

    Parameters
    ----------
    data_list : list of arrays
        Each array is (T_run, D) for one run. Do NOT concatenate — ssm
        handles multiple sequences natively.
    K : int
        Number of states
    D : int or None
        Observation dimension (inferred from data if None)

    Returns
    -------
    best_model : ssm.HMM
    best_lls : list of float
        Log-likelihood per EM iteration for best run
    """
    if D is None:
        D = data_list[0].shape[1]

    best_model = None
    best_ll = -np.inf
    best_lls = None

    for restart in range(n_restarts):
        model = ssm.HMM(
            K=K,
            D=D,
            observations='gaussian',   # or 'diagonal_gaussian', 'studentst'
            transitions='standard',    # or 'sticky'
        )

        lls = model.fit(
            data_list,
            method='em',
            num_iters=n_iters,
            tolerance=1e-4,
        )

        final_ll = lls[-1]
        if final_ll > best_ll:
            best_ll = final_ll
            best_model = model
            best_lls = lls

    print(f"Best log-likelihood: {best_ll:.2f}")
    return best_model, best_lls


def decode_states_ssm(model, data_list):
    """Get most likely state sequences and posterior probabilities."""
    all_states = []
    all_posteriors = []

    for data in data_list:
        # Viterbi (most likely sequence)
        states = model.most_likely_states(data)
        all_states.append(states)

        # Posterior probabilities
        posteriors = model.expected_states(data)[0]  # (T, K)
        all_posteriors.append(posteriors)

    return all_states, all_posteriors
```
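
The decoded outputs feed directly into the summary statistics used elsewhere in this guide; a small NumPy-only sketch (the function name is ours, not part of ssm):

```python
import numpy as np

def summarize_ssm_states(all_states, K):
    """Fractional occupancy and switching rate from per-run Viterbi paths."""
    states = np.concatenate(all_states)
    frac_occ = np.array([(states == k).mean() for k in range(K)])
    # Switching rate: fraction of within-run steps that change state
    n_switches = sum((np.diff(s) != 0).sum() for s in all_states)
    n_steps = sum(len(s) - 1 for s in all_states)
    return frac_occ, n_switches / n_steps
```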

---

## 3. Sticky HMM {#sticky-hmm}

Adds self-transition bias to prevent unrealistically rapid state switching.

```python
"""Sticky HMM using ssm library."""
import numpy as np
import ssm


def fit_sticky_hmm(data_list, K, D, kappa=100, n_restarts=20, n_iters=200):
    """Fit sticky HMM.

    Parameters
    ----------
    kappa : float
        Stickiness parameter. Higher = longer dwell times.
        kappa=0 is a standard HMM; kappa=100-1000 is typical for fMRI.
        Rule of thumb: set kappa so expected dwell time is ~5-10 TRs.
    """
    best_model = None
    best_ll = -np.inf

    for restart in range(n_restarts):
        model = ssm.HMM(
            K=K,
            D=D,
            observations='gaussian',
            transitions='sticky',
            transition_kwargs={'kappa': kappa},
        )

        lls = model.fit(data_list, method='em', num_iters=n_iters)

        if lls[-1] > best_ll:
            best_ll = lls[-1]
            best_model = model

    # Check effective dwell times (use the normalized transition matrix;
    # raw np.exp(log_Ps) is unnormalized)
    trans_mat = best_model.transitions.transition_matrix
    expected_dwell = 1.0 / (1.0 - np.diag(trans_mat))
    print(f"Expected dwell times (TRs): {expected_dwell}")

    return best_model
```
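
The docstring's rule of thumb follows from the geometric dwell-time distribution used above: expected dwell = 1 / (1 - p_ii). A one-line helper to invert it (a sketch; in practice tune `kappa` until the printed dwell times match the target):

```python
def target_self_transition(dwell_trs):
    """Self-transition probability implied by a target expected dwell time.

    Expected dwell of a geometric distribution is 1 / (1 - p_ii), so
    p_ii = 1 - 1 / dwell. E.g. a 5-TR target dwell needs p_ii = 0.8.
    """
    return 1.0 - 1.0 / dwell_trs
```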

---

## 4. HMM-MAR with osl-dynamics {#hmm-mar}

```python
"""HMM-MAR using osl-dynamics (Oxford's toolbox for dynamic brain analysis).

Install: pip install osl-dynamics
This is the Python successor to the MATLAB HMM-MAR toolbox (Vidaurre et al.).
"""
from osl_dynamics.models.hmm import Config, Model
from osl_dynamics.data import Data
import numpy as np


def fit_hmm_mar(data_files, K, n_channels, sequence_length=200,
                n_ar_lags=3, learn_means=True, learn_covariances=True,
                n_epochs=40, batch_size=64):
    """Fit a TDE-HMM, the osl-dynamics analogue of HMM-MAR.

    Parameters
    ----------
    data_files : list of str or list of arrays
        Paths to numpy files or arrays, each (T, n_channels)
    K : int
        Number of states
    n_channels : int
        Number of brain regions / ICA components
    sequence_length : int
        Length of sequences for training (segments of the timeseries)
    n_ar_lags : int
        Number of autoregressive lags (typically 1-5 for fMRI);
        sets the time-delay embedding window below
    learn_means : bool
        Whether states have different means
    learn_covariances : bool
        Whether states have different covariances
    """
    # Prepare data: time-delay embedding + PCA captures the spectral
    # structure that MAR observations model explicitly
    data = Data(data_files, store_dir='/tmp/osl_dynamics_data')
    data.prepare({
        'tde_pca': {'n_embeddings': n_ar_lags * 2 + 1, 'n_pca_components': n_channels},
        'standardize': {},
    })

    # Configure model
    config = Config(
        n_states=K,
        n_channels=n_channels,
        sequence_length=sequence_length,
        learn_means=learn_means,
        learn_covariances=learn_covariances,
        batch_size=batch_size,
        learning_rate=0.01,
        n_epochs=n_epochs,
    )

    model = Model(config)
    model.random_state_time_course_initialization(data, n_init=5, n_epochs=2)
    history = model.fit(data)

    # Get state time courses
    alpha = model.get_alpha(data)  # list of (T, K) arrays — state probabilities

    return model, alpha, history


# Alternative: using glhmm (Vidaurre's newer library)
def fit_glhmm(data_list, K):
    """Fit a Gaussian HMM using glhmm.

    Install: pip install glhmm
    Note: `from glhmm import glhmm` imports the glhmm *module*; the model
    class inside it is also named glhmm, so instantiate with glhmm.glhmm(...).
    For MAR-style models, build a lagged design matrix (see glhmm.preproc)
    and pass it as X.
    """
    from glhmm import glhmm

    # Stack runs and build the (n_sessions, 2) start/end index array glhmm expects
    data_concat = np.vstack(data_list)
    ends = np.cumsum([d.shape[0] for d in data_list])
    starts = np.concatenate([[0], ends[:-1]])
    indices = np.stack([starts, ends], axis=1)

    model = glhmm.glhmm(K=K, covtype='full', model_mean='state', model_beta='no')
    model.train(X=None, Y=data_concat, indices=indices)

    # Viterbi state path
    vpath = model.decode(X=None, Y=data_concat, indices=indices, viterbi=True)

    return model, vpath
```

---

## 5. Input-Output HMM {#io-hmm}

For task-based fMRI where external events drive state transitions.

```python
"""Input-Output HMM using ssm library.

Task events enter as inputs that modulate either transitions or emissions.
"""
import ssm
import numpy as np
from nilearn.glm.first_level import spm_hrf


def prepare_task_inputs(events_df, n_trs, tr, hrf_convolve=True):
    """Convert task events to input matrix for IO-HMM.

    Parameters
    ----------
    events_df : DataFrame
        Columns: onset, duration, trial_type
    n_trs : int
        Total number of TRs
    tr : float
        Repetition time
    hrf_convolve : bool
        Whether to convolve inputs with HRF. Set True when fitting on BOLD
        (so inputs align with BOLD timing). Set False if fitting on deconvolved data.

    Returns
    -------
    inputs : array, shape (n_trs, n_conditions)
        One column per trial_type
    condition_names : list of str
    """
    trial_types = sorted(events_df['trial_type'].unique())
    n_conditions = len(trial_types)

    # Build stimulus timecourse at TR resolution
    inputs = np.zeros((n_trs, n_conditions))

    for i, tt in enumerate(trial_types):
        events = events_df[events_df['trial_type'] == tt]
        for _, event in events.iterrows():
            onset_tr = int(np.round(event['onset'] / tr))
            dur_trs = max(1, int(np.round(event['duration'] / tr)))
            end_tr = min(n_trs, onset_tr + dur_trs)
            inputs[onset_tr:end_tr, i] = 1.0

    if hrf_convolve:
        hrf = spm_hrf(tr, oversampling=1)
        for i in range(n_conditions):
            # Causal convolution: convolve fully, then truncate to n_trs.
            # (mode='same' would center the kernel and shift responses earlier.)
            inputs[:, i] = np.convolve(inputs[:, i], hrf)[:n_trs]

    return inputs, trial_types


def fit_io_hmm(data_list, inputs_list, K, n_features, n_inputs,
               input_driven='transitions', n_restarts=20):
    """Fit input-output HMM.

    Parameters
    ----------
    data_list : list of arrays
        Each (T, n_features)
    inputs_list : list of arrays
        Each (T, n_inputs) — task inputs for each run
    K : int
        Number of states
    input_driven : str
        'transitions': inputs affect state switching probabilities
        'observations': inputs affect emission means
        'both': inputs affect both
    """
    if input_driven == 'transitions':
        transitions = 'inputdriven'
        observations = 'gaussian'
    elif input_driven == 'observations':
        transitions = 'standard'
        # ssm does not ship an input-driven Gaussian observation class.
        # For input-modulated emissions, subclass ssm.observations.GaussianObservations
        # and override the log_likelihoods() method to add B_k @ u_t to the mean.
        # See: https://github.com/lindermanlab/ssm/blob/master/ssm/observations.py
        observations = 'gaussian'  # placeholder — customise as needed
    elif input_driven == 'both':
        transitions = 'inputdriven'
        observations = 'gaussian'  # replace with custom class for input-driven emissions
    else:
        raise ValueError(f"Unknown input_driven option: {input_driven}")

    best_model = None
    best_ll = -np.inf

    for restart in range(n_restarts):
        model = ssm.HMM(
            K=K,
            D=n_features,
            M=n_inputs,
            observations=observations,
            transitions=transitions,
        )

        lls = model.fit(
            data_list,
            inputs=inputs_list,
            method='em',
            num_iters=200,
        )

        if lls[-1] > best_ll:
            best_ll = lls[-1]
            best_model = model

    return best_model
```
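
A usage sketch assuming a BIDS-style events.tsv and one parcellated run (the file paths, TR, and K are hypothetical):

```python
import pandas as pd
import numpy as np

# Hypothetical inputs — replace with your own paths and TR
events_df = pd.read_csv('sub-01_task-stroop_events.tsv', sep='\t')
bold = np.load('sub-01_task-stroop_parcellated.npy')   # (n_trs, n_parcels)

inputs, condition_names = prepare_task_inputs(events_df, n_trs=bold.shape[0], tr=2.0)
model = fit_io_hmm([bold], [inputs], K=4,
                   n_features=bold.shape[1], n_inputs=inputs.shape[1])
```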

---

## 6. SLDS with ssm library {#slds}

```python
"""Switching Linear Dynamical System using ssm library."""
import ssm
import numpy as np


def fit_slds(data_list, K, D, latent_dim,
             n_restarts=10, n_iters=100):
    """Fit SLDS.

    Parameters
    ----------
    data_list : list of arrays
        Each (T, D) observation timeseries
    K : int
        Number of discrete switching states
    D : int
        Observation dimension
    latent_dim : int
        Continuous latent state dimension (typically 5-15 for fMRI)

    Returns
    -------
    best_model : ssm.SLDS
    best_q : variational posterior for the training data
    best_elbos : ELBO per iteration for the best restart
    """
    best_model = None
    best_q = None
    best_elbos = None
    best_elbo = -np.inf

    for restart in range(n_restarts):
        model = ssm.SLDS(
            N=D,           # observation dimension
            K=K,           # number of discrete states
            D=latent_dim,  # latent dimension
            emissions='gaussian_orthog',  # orthogonal emission matrix
            dynamics='gaussian',
            transitions='standard',       # or 'recurrent' for rSLDS
        )

        # Fit with Laplace-EM; fit() returns (elbos, variational posterior)
        elbos, q = model.fit(
            data_list,
            method='laplace_em',
            variational_posterior='structured_meanfield',
            num_iters=n_iters,
            initialize=True,
        )

        if elbos[-1] > best_elbo:
            best_elbo = elbos[-1]
            best_model = model
            best_q = q
            best_elbos = elbos

    print(f"Best ELBO: {best_elbo:.2f}")
    return best_model, best_q, best_elbos


def decode_slds(model, q, data_list):
    """Decode latent states from a fitted SLDS.

    Uses the variational posterior q returned by fit(); for SLDS,
    discrete-state decoding conditions on the inferred continuous states.
    """
    all_discrete_states = []
    all_continuous_states = []

    for i, data in enumerate(data_list):
        # Posterior mean of continuous states
        x = q.mean_continuous_states[i]   # (T, latent_dim)
        all_continuous_states.append(x)

        # Most likely discrete states given the continuous trajectory
        z = model.most_likely_states(x, data)
        all_discrete_states.append(z)

    return all_discrete_states, all_continuous_states
```
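
The decoded outputs visualize naturally as a continuous trajectory colored by discrete state; a minimal matplotlib sketch over the first two latent dimensions (the function name is ours):

```python
import matplotlib.pyplot as plt

def plot_slds_latents(x, z, K):
    """Scatter the first two continuous latent dims, colored by discrete state."""
    fig, ax = plt.subplots(figsize=(5, 5))
    for k in range(K):
        mask = (z == k)
        ax.scatter(x[mask, 0], x[mask, 1], s=5, alpha=0.6, label=f'State {k + 1}')
    ax.set_xlabel('Latent dim 1')
    ax.set_ylabel('Latent dim 2')
    ax.legend()
    return fig
```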

---

## 7. rSLDS with ssm library {#rslds}

```python
"""Recurrent SLDS — discrete states depend on continuous latent state."""
import ssm
import numpy as np


def fit_rslds(data_list, K, D, latent_dim, n_restarts=10, n_iters=100):
    """Fit rSLDS.

    The key difference from SLDS: transitions='recurrent', meaning
    P(z_t | z_{t-1}, x_{t-1}) depends on the continuous state x.
    """
    best_model = None
    best_q = None
    best_elbo = -np.inf

    for restart in range(n_restarts):
        model = ssm.SLDS(
            N=D,
            K=K,
            D=latent_dim,
            emissions='gaussian_orthog',
            dynamics='gaussian',
            transitions='recurrent',  # This makes it recurrent
        )

        # fit() returns (elbos, variational posterior), as for SLDS
        elbos, q = model.fit(
            data_list,
            method='laplace_em',
            variational_posterior='structured_meanfield',
            num_iters=n_iters,
            initialize=True,
        )

        if elbos[-1] > best_elbo:
            best_elbo = elbos[-1]
            best_model = model
            best_q = q

    return best_model, best_q
```

---

## 8. Model Selection — Choosing K {#model-selection}

```python
"""Model selection utilities for SSMs on fMRI data."""
import numpy as np
from sklearn.model_selection import KFold


def select_K_bic(data, lengths, K_range=range(2, 16), covariance_type='full',
                 n_restarts=30):
    """Select K using BIC (Bayesian Information Criterion).

    BIC = -2 * log_likelihood + n_params * log(n_samples)
    Lower BIC is better. Uses fit_gaussian_hmm() from Section 1.
    """
    results = {}
    T, p = data.shape

    for K in K_range:
        model, score, _ = fit_gaussian_hmm(
            data, lengths, K, covariance_type=covariance_type,
            n_restarts=n_restarts
        )

        # Count parameters
        n_params = (K - 1)       # initial state distribution
        n_params += K * (K - 1)  # transition matrix
        n_params += K * p        # means
        if covariance_type == 'full':
            n_params += K * p * (p + 1) // 2  # covariances
        elif covariance_type == 'diag':
            n_params += K * p

        bic = -2 * score + n_params * np.log(T)
        aic = -2 * score + 2 * n_params

        results[K] = {
            'log_likelihood': score,
            'bic': bic,
            'aic': aic,
            'n_params': n_params,
            'model': model,
        }
        print(f"K={K}: LL={score:.1f}, BIC={bic:.1f}, AIC={aic:.1f}, params={n_params}")

    best_K_bic = min(results, key=lambda k: results[k]['bic'])
    best_K_aic = min(results, key=lambda k: results[k]['aic'])
    print(f"\nBest K by BIC: {best_K_bic}")
    print(f"Best K by AIC: {best_K_aic}")

    return results, best_K_bic


def select_K_crossval(data_runs, K_range=range(2, 16), n_folds=None,
                      covariance_type='full', n_restarts=20):
    """Select K using cross-validated log-likelihood on held-out runs.

    Parameters
    ----------
    data_runs : list of arrays
        One array per fMRI run, shape (T_run, n_features)
    K_range : range
        K values to test
    n_folds : int or None
        If None, use leave-one-run-out
    """
    n_runs = len(data_runs)
    if n_folds is None:
        n_folds = n_runs  # leave-one-run-out

    results = {}

    for K in K_range:
        fold_scores = []

        kf = KFold(n_splits=min(n_folds, n_runs), shuffle=False)
        run_indices = np.arange(n_runs)

        for train_idx, test_idx in kf.split(run_indices):
            # Concatenate training runs
            train_data = np.vstack([data_runs[i] for i in train_idx])
            train_lengths = [data_runs[i].shape[0] for i in train_idx]

            # Concatenate test runs
            test_data = np.vstack([data_runs[i] for i in test_idx])
            test_lengths = [data_runs[i].shape[0] for i in test_idx]

            # Fit on train
            model, _, _ = fit_gaussian_hmm(
                train_data, train_lengths, K,
                covariance_type=covariance_type,
                n_restarts=n_restarts
            )

            # Score on test, normalized by number of test time points
            test_score = model.score(test_data, lengths=test_lengths)
            fold_scores.append(test_score / test_data.shape[0])

        mean_score = np.mean(fold_scores)
        std_score = np.std(fold_scores)
        results[K] = {'mean_cv_ll': mean_score, 'std_cv_ll': std_score}
        print(f"K={K}: CV log-lik = {mean_score:.4f} ± {std_score:.4f}")

    best_K = max(results, key=lambda k: results[k]['mean_cv_ll'])
    print(f"\nBest K by CV: {best_K}")

    return results, best_K


def state_stability_analysis(data, lengths, K, n_splits=10, n_restarts=30):
    """Check if states are reproducible across random splits and initializations.

    Fits the model on random half-splits of the data (first or second half
    of each run) and measures how similar the inferred states are. Stable
    states should be recoverable across splits.
    """
    from scipy.optimize import linear_sum_assignment
    from scipy.spatial.distance import cdist

    all_means = []

    for split in range(n_splits):
        rng = np.random.RandomState(split)
        # Randomly take the first or second half of each run
        split_data_list = []
        split_lengths = []
        offset = 0
        for length in lengths:
            run_data = data[offset:offset + length]
            mid = length // 2
            if rng.rand() > 0.5:
                split_data_list.append(run_data[:mid])
            else:
                split_data_list.append(run_data[mid:])
            split_lengths.append(split_data_list[-1].shape[0])
            offset += length

        split_data = np.vstack(split_data_list)
        model, _, _ = fit_gaussian_hmm(
            split_data, split_lengths, K,
            n_restarts=n_restarts
        )
        all_means.append(model.means_)

    # Compare all pairs of splits using the Hungarian algorithm
    similarities = []
    for i in range(n_splits):
        for j in range(i + 1, n_splits):
            cost = cdist(all_means[i], all_means[j], metric='correlation')
            row_ind, col_ind = linear_sum_assignment(cost)
            matched_corr = 1 - cost[row_ind, col_ind].mean()
            similarities.append(matched_corr)

    print(f"State stability (mean matched correlation): {np.mean(similarities):.3f} ± {np.std(similarities):.3f}")
    print("Values > 0.8 suggest stable states; < 0.5 suggests instability")

    return similarities
```
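
The selection curves are easier to judge visually than from printed lines; a small plotting sketch over the `results` dicts returned above (the function name is ours):

```python
import matplotlib.pyplot as plt

def plot_selection_curves(bic_results, cv_results=None):
    """Plot BIC (and optionally cross-validated log-likelihood) against K."""
    fig, ax = plt.subplots(figsize=(6, 4))
    Ks = sorted(bic_results)
    ax.plot(Ks, [bic_results[k]['bic'] for k in Ks], 'o-', label='BIC (lower = better)')
    ax.set_xlabel('Number of states K')
    ax.set_ylabel('BIC')
    if cv_results is not None:
        ax2 = ax.twinx()  # second y-axis for the CV curve
        Ks_cv = sorted(cv_results)
        ax2.plot(Ks_cv, [cv_results[k]['mean_cv_ll'] for k in Ks_cv],
                 's--', color='gray', label='CV log-lik (higher = better)')
        ax2.set_ylabel('CV log-likelihood / TR')
    ax.legend(loc='best')
    fig.tight_layout()
    return fig
```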

---

## 9. State Visualization {#visualization}

```python
"""Visualization utilities for SSM results on fMRI data."""
import numpy as np
import matplotlib.pyplot as plt


def plot_state_timecourse(states, tr, run_boundaries=None, ax=None,
                          state_colors=None, title='State timecourse'):
    """Plot the inferred state sequence over time."""
    if ax is None:
        fig, ax = plt.subplots(figsize=(14, 2))

    T = len(states)
    K = int(np.max(states)) + 1  # robust even if some states are unused
    times = np.arange(T) * tr

    if state_colors is None:
        cmap = plt.cm.Set2
        state_colors = [cmap(i / K) for i in range(K)]

    for t in range(T - 1):
        ax.axvspan(times[t], times[t + 1], color=state_colors[states[t]], alpha=0.8)

    if run_boundaries is not None:
        for b in run_boundaries:
            ax.axvline(b * tr, color='black', linewidth=2, linestyle='--', alpha=0.5)

    ax.set_xlim(0, times[-1])
    ax.set_xlabel('Time (s)')
    ax.set_title(title)
    ax.set_yticks([])

    return ax


def plot_state_spatial_maps(means, roi_labels=None, n_top_regions=10):
    """Plot the top activated regions for each state."""
    K, p = means.shape

    fig, axes = plt.subplots(1, K, figsize=(4 * K, 6))
    if K == 1:
        axes = [axes]

    for k in range(K):
        ax = axes[k]
        state_mean = means[k]

        # Top positive and negative regions
        top_pos = np.argsort(state_mean)[-n_top_regions:][::-1]
        top_neg = np.argsort(state_mean)[:n_top_regions]
        top_idx = np.concatenate([top_pos, top_neg])

        values = state_mean[top_idx]
        if roi_labels is not None:
            labels = [roi_labels[i] for i in top_idx]
        else:
            labels = [f'ROI {i}' for i in top_idx]

        colors = ['#e74c3c' if v > 0 else '#3498db' for v in values]
        ax.barh(range(len(values)), values, color=colors)
        ax.set_yticks(range(len(values)))
        ax.set_yticklabels(labels, fontsize=8)
        ax.set_title(f'State {k + 1}')
        ax.axvline(0, color='black', linewidth=0.5)

    plt.tight_layout()
    return fig


def plot_transition_matrix(transmat, ax=None):
    """Plot transition probability matrix as heatmap."""
    if ax is None:
        fig, ax = plt.subplots(figsize=(6, 5))

    K = transmat.shape[0]
    im = ax.imshow(transmat, cmap='Blues', vmin=0, vmax=1)

    for i in range(K):
        for j in range(K):
            ax.text(j, i, f'{transmat[i, j]:.2f}', ha='center', va='center',
                    color='white' if transmat[i, j] > 0.5 else 'black', fontsize=10)

    ax.set_xticks(range(K))
    ax.set_yticks(range(K))
    ax.set_xticklabels([f'State {k+1}' for k in range(K)])
    ax.set_yticklabels([f'State {k+1}' for k in range(K)])
    ax.set_xlabel('To state')
    ax.set_ylabel('From state')
    ax.set_title('Transition matrix')
    plt.colorbar(im, ax=ax)

    return ax


def plot_dwell_time_distributions(dwell_times, tr, ax=None):
    """Plot dwell time distributions for each state."""
    K = len(dwell_times)
    if ax is None:
        fig, ax = plt.subplots(figsize=(8, 4))

    for k in range(K):
        dwells_sec = np.array(dwell_times[k]) * tr
        if len(dwells_sec) > 0:
            ax.hist(dwells_sec, bins=20, alpha=0.5, label=f'State {k+1} '
                    f'(mean={dwells_sec.mean():.1f}s)', density=True)

    ax.set_xlabel('Dwell time (seconds)')
    ax.set_ylabel('Density')
    ax.set_title('Dwell time distributions')
    ax.legend()

    return ax
```
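
These panels combine into a one-page summary per fit; a usage sketch driven by the `results` dict from `extract_state_info()` in Section 1 (the wrapper name is ours):

```python
import matplotlib.pyplot as plt

def state_report(results, tr, run_boundaries=None):
    """One-page summary: state timecourse, transition matrix, dwell times."""
    fig = plt.figure(figsize=(14, 8))
    ax1 = fig.add_subplot(3, 1, 1)
    plot_state_timecourse(results['states'], tr, run_boundaries=run_boundaries, ax=ax1)
    ax2 = fig.add_subplot(3, 2, 3)
    plot_transition_matrix(results['transition_matrix'], ax=ax2)
    ax3 = fig.add_subplot(3, 2, 4)
    plot_dwell_time_distributions(results['dwell_times'], tr, ax=ax3)
    fig.tight_layout()
    return fig
```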

---

## 10. Reproducibility and Initialization {#reproducibility}

```python
"""Best practices for reproducible SSM fitting on fMRI data."""
import json
import numpy as np


def save_ssm_config(filepath, **kwargs):
    """Save SSM configuration for reproducibility."""
    config = {
        'model_type': kwargs.get('model_type', 'gaussian_hmm'),
        'K': kwargs.get('K'),
        'covariance_type': kwargs.get('covariance_type', 'full'),
        'n_restarts': kwargs.get('n_restarts', 50),
        'n_iter': kwargs.get('n_iter', 200),
        'random_seed': kwargs.get('random_seed', 42),
        'parcellation': kwargs.get('parcellation', 'schaefer200'),
        'confound_strategy': kwargs.get('confound_strategy', 'moderate'),
        'hrf_strategy': kwargs.get('hrf_strategy', 'bold_direct'),
        'tr': kwargs.get('tr'),
        'preprocessing': kwargs.get('preprocessing', 'fmriprep+xcpd'),
        'notes': kwargs.get('notes', ''),
    }
    with open(filepath, 'w') as f:
        json.dump(config, f, indent=2)
    print(f"Config saved to {filepath}")


def align_state_labels(reference_means, target_means):
    """Align state labels between two models using the Hungarian algorithm.

    Use when comparing states across subjects or between runs.
    Returns a dict mapping target state index -> matched reference state index.
    """
    from scipy.optimize import linear_sum_assignment
    from scipy.spatial.distance import cdist

    cost = cdist(reference_means, target_means, metric='correlation')
    row_ind, col_ind = linear_sum_assignment(cost)

    label_mapping = {col: row for row, col in zip(row_ind, col_ind)}
    return label_mapping
```
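
Applying the mapping to relabel a decoded state sequence (a small sketch; `relabel_states` is ours):

```python
import numpy as np

def relabel_states(states, label_mapping):
    """Relabel a target model's state sequence into the reference labeling."""
    K = len(label_mapping)
    perm = np.empty(K, dtype=int)
    for target_label, reference_label in label_mapping.items():
        perm[target_label] = reference_label
    return perm[states]  # element-wise lookup via the permutation array
```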

---

## 11. Model Diagnostics: Detecting Pathological Fits {#diagnostics}

Before reporting SSM results, always run these checks. Pathological fits produce
scientifically meaningless states that can nonetheless appear statistically significant.

```python
"""Diagnostics for detecting common HMM failure modes on fMRI data."""
import numpy as np


def diagnose_hmm_fit(model, state_seq, confounds, fd, tr,
                     dominant_state_threshold=0.70,
                     min_dwell_trs=2,
                     motion_corr_threshold=0.30):
    """Run a battery of diagnostics on a fitted HMM.

    Parameters
    ----------
    model : fitted HMM (hmmlearn GaussianHMM or similar)
    state_seq : array, shape (T,)
        Viterbi state sequence
    confounds : array, shape (T, n_confounds)
        Confound matrix (motion params, WM/CSF, etc.)
    fd : array, shape (T,)
        Framewise displacement per TR
    tr : float
        Repetition time in seconds
    dominant_state_threshold : float
        Warn if any state occupies more than this fraction of TRs
    min_dwell_trs : int
        Warn if mean dwell time is below this many TRs
    motion_corr_threshold : float
        Warn if any state's occurrence correlates with FD above this value

    Returns
    -------
    report : dict
        Diagnostic results with 'warnings' list
    """
    K = model.n_components
    T = len(state_seq)
    warnings = []

    # --- 11a. Motion-driven state detection ---
    fd_clean = np.nan_to_num(fd, nan=0.0)
    motion_corrs = {}
    for k in range(K):
        state_indicator = (state_seq == k).astype(float)
        r = np.corrcoef(state_indicator, fd_clean)[0, 1]
        motion_corrs[k] = r
        if abs(r) > motion_corr_threshold:
            warnings.append(
                f"State {k}: |r|={abs(r):.2f} with framewise displacement "
                f"(threshold {motion_corr_threshold}). May be motion-driven."
            )

    # --- 11b. Dominant state pathology ---
    frac_occ = np.array([(state_seq == k).sum() / T for k in range(K)])
    dominant_states = np.where(frac_occ > dominant_state_threshold)[0]
    for k in dominant_states:
        warnings.append(
            f"State {k} dominates: {frac_occ[k]:.1%} of TRs. "
            f"Model may have collapsed — check BIC at lower K."
        )

    # --- 11c. Per-TR switching (too-fast switching) ---
    dwell_times = {k: [] for k in range(K)}
    current_state = state_seq[0]
    current_dwell = 1
    for t in range(1, T):
        if state_seq[t] == current_state:
            current_dwell += 1
        else:
            dwell_times[current_state].append(current_dwell)
            current_state = state_seq[t]
            current_dwell = 1
    dwell_times[current_state].append(current_dwell)

    mean_dwell = {k: np.mean(dwell_times[k]) if dwell_times[k] else 0 for k in range(K)}
    fast_states = [k for k, d in mean_dwell.items() if d < min_dwell_trs]
    for k in fast_states:
        warnings.append(
            f"State {k}: mean dwell = {mean_dwell[k]:.1f} TRs "
            f"(< {min_dwell_trs} TR minimum). "
            f"Solutions: add sticky prior, reduce K, check preprocessing."
        )

    # --- 11d. State-confound correlation check ---
    confound_corrs = {}
    for k in range(K):
        state_indicator = (state_seq == k).astype(float)
        corrs = [np.corrcoef(state_indicator, confounds[:, j])[0, 1]
                 for j in range(confounds.shape[1])]
        max_corr = np.max(np.abs(corrs))
        confound_corrs[k] = max_corr
        if max_corr > motion_corr_threshold:
            warnings.append(
                f"State {k}: max |r|={max_corr:.2f} with confound regressors. "
                f"Consider re-running with that confound removed from the data."
            )

    report = {
        'fractional_occupancy': frac_occ,
        'mean_dwell_times_trs': mean_dwell,
        'mean_dwell_times_sec': {k: v * tr for k, v in mean_dwell.items()},
        'motion_correlations': motion_corrs,
        'confound_correlations': confound_corrs,
        'warnings': warnings,
    }

    if warnings:
        print(f"=== {len(warnings)} diagnostic warning(s) ===")
        for w in warnings:
            print(f"  WARNING: {w}")
    else:
        print("All diagnostics passed.")

    return report
```

**Quick interpretation guide:**

| Warning | Likely cause | Fix |
|---------|-------------|-----|
| State correlates with FD | Motion artifact state | Tighter scrubbing; check if the state disappears after removing high-motion subjects |
| Dominant state (>70%) | Model collapsed to trivial solution | Lower K; check for degenerate covariance; more restarts |
| Mean dwell < 2 TRs | Noise-driven rapid switching | Add sticky prior (`kappa`); or apply a post-hoc minimum dwell-time filter (see sketch below) |
| High confound correlation | Confound leakage | Revisit confound strategy; ensure confound regression happened before SSM fitting |
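
The minimum dwell-time filter mentioned in the table can be a simple post-hoc pass; a minimal sketch (absorbing short episodes into the preceding state is one reasonable merge rule, not the only one):

```python
import numpy as np

def enforce_min_dwell(states, min_dwell=2):
    """Reassign episodes shorter than min_dwell TRs to the preceding state."""
    states = states.copy()
    T = len(states)
    t = 0
    while t < T:
        start = t
        while t < T and states[t] == states[start]:
            t += 1  # advance to the end of the current episode
        if (t - start) < min_dwell and start > 0:
            states[start:t] = states[start - 1]  # absorb short episode backwards
    return states
```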

---

## 12. JAX-Based HMM for GPU Acceleration {#jax-gpu}

JAX-native SSM libraries run the same code on CPU and GPU; JAX dispatches to whatever
accelerator is available. Note that the classic `lindermanlab/ssm` library is built on
NumPy, autograd, and numba and runs on CPU only, so for GPU acceleration use a
JAX-native library such as `dynamax` (Section 13).

```python
"""GPU-accelerated HMM fitting with JAX (via dynamax).

Install:
    pip install jax jaxlib    # CPU JAX
    # For GPU (CUDA 12):
    pip install "jax[cuda12]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
    pip install dynamax       # JAX-native SSM library (see Section 13)

When GPU acceleration matters:
- >50 subjects or >1000 TRs/subject
- rSLDS / SNLDS (Laplace-EM is computationally heavy)
- Model selection sweeps over K (embarrassingly parallel)
JAX code runs unchanged on CPU; GPU is a drop-in speedup.
"""
import jax
import jax.numpy as jnp
import numpy as np


def check_jax_device():
    """Report which device JAX is using."""
    backend = jax.default_backend()
    devices = jax.devices()
    print(f"JAX backend: {backend}")
    print(f"Available devices: {devices}")
    if backend == 'cpu':
        print("NOTE: Running on CPU. For GPU, install jax[cuda12] and ensure CUDA is available.")
    return backend


def fit_hmm_jax(data_list, K, n_restarts=20, n_iters=200, seed=0):
    """Fit a Gaussian HMM with dynamax; runs on GPU automatically if JAX sees one.

    Parameters
    ----------
    data_list : list of np.ndarray, each (T, D)
        Multiple runs / subjects; concatenated here (see Section 13 for details)
    K : int
        Number of states

    Notes
    -----
    dynamax implements forward-backward and EM in JAX, so inference runs on
    GPU whenever JAX is configured with one. Convert data to float32 (vs
    float64) for faster GPU computation.
    """
    import jax.random as jr
    from dynamax.hidden_markov_model import GaussianHMM

    check_jax_device()

    D = data_list[0].shape[1]
    # float32 is markedly faster than float64 on most GPUs
    emissions = jnp.array(np.vstack(data_list), dtype=jnp.float32)

    model = GaussianHMM(num_states=K, emission_dim=D)
    best_params, best_lls, best_ll = None, None, -np.inf

    for restart in range(n_restarts):
        key = jr.PRNGKey(seed + restart)
        init_method = 'kmeans' if restart == 0 else 'prior'
        params, props = model.initialize(key, method=init_method, emissions=emissions)
        params, lls = model.fit_em(params, props, emissions, num_iters=n_iters)
        if lls[-1] > best_ll:
            best_ll = float(lls[-1])
            best_params = params
            best_lls = lls

    print(f"Best log-likelihood: {best_ll:.2f}")
    return best_params, best_lls, model


def fit_rslds_large(data_list, K, D, latent_dim, n_restarts=5, n_iters=100):
    """rSLDS for large datasets — the model where acceleration matters most.

    rSLDS runtime scales roughly as O(T × D² × latent_dim) per iteration.
    Note: ssm's Laplace-EM runs on CPU (no GPU backend); the float32 cast
    below reduces memory but does not change the backend. True GPU
    switching-SSM inference requires a JAX implementation.
    """
    import ssm

    data_f32 = [d.astype(np.float32) for d in data_list]

    best_model = None
    best_elbo = -np.inf

    for restart in range(n_restarts):
        model = ssm.SLDS(
            N=D,
            K=K,
            D=latent_dim,
            emissions='gaussian_orthog',
            dynamics='gaussian',
            transitions='recurrent',
        )
        elbos, _ = model.fit(
            data_f32,
            method='laplace_em',
            variational_posterior='structured_meanfield',
            num_iters=n_iters,
            initialize=True,
        )
        if elbos[-1] > best_elbo:
            best_elbo = elbos[-1]
            best_model = model

    return best_model


# --- DyNeMo (osl-dynamics): GPU-native deep generative model ---
#
# DyNeMo is a variational recurrent neural network that learns dynamic
# network modes from BOLD data. It requires a GPU for practical runtimes.
#
# Install: pip install osl-dynamics  (TensorFlow backend)
#
# Key use case: when HMM-MAR underfits and you need a more expressive model
# for naturalistic / resting-state data with complex temporal dynamics.
#
# from osl_dynamics.models.dynemo import Config, Model
# config = Config(
#     n_modes=K,
#     n_channels=n_rois,
#     sequence_length=200,
#     inference_n_units=64,
#     inference_normalization='layer',
#     model_n_units=64,
#     model_normalization='layer',
#     learn_means=False,
#     learn_covariances=True,
#     batch_size=16,
#     learning_rate=0.01,
#     n_epochs=50,
# )
# model = Model(config)
# model.compile()
# history = model.fit(data)
# alpha = model.get_alpha(data)  # (T, K) mixing coefficients
```

**GPU recommendation summary:**

| Scenario | Recommendation |
|----------|---------------|
| hmmlearn, standard Gaussian HMM | CPU is fine; hmmlearn has no JAX/GPU path |
| ssm (lindermanlab) HMM / SLDS / rSLDS | CPU only (NumPy/numba backend, no GPU path) |
| dynamax HMM, <50 subjects | CPU is fine |
| dynamax HMM, >50 subjects or >1000 TRs | GPU recommended |
| dynamax custom SSM with jit | GPU recommended; jit alone gives a large speedup |
| osl-dynamics DyNeMo | GPU required for practical training |
| Model selection sweeps (many K values) | Parallelize across GPUs or use batch jobs (see sketch below) |
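
For the sweep row above, a process pool is often enough on a multicore node; a sketch reusing `fit_gaussian_hmm()` from Section 1 (the worker must be module-level so it pickles, and the data arrays are copied into each worker process):

```python
from concurrent.futures import ProcessPoolExecutor
from functools import partial


def _fit_one_K(K, data, lengths):
    """Worker: fit one HMM for a given K (module-level so it is picklable)."""
    model, score, _ = fit_gaussian_hmm(data, lengths, K, n_restarts=20)
    return K, score


def sweep_K_parallel(data, lengths, K_values, max_workers=4):
    """Fit one HMM per K in parallel processes; returns {K: best log-likelihood}."""
    worker = partial(_fit_one_K, data=data, lengths=lengths)
    with ProcessPoolExecutor(max_workers=max_workers) as pool:
        return dict(pool.map(worker, K_values))
```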

---

## 13. dynamax: Lego-Style Custom SSMs {#dynamax}

`dynamax` (probml / Murphy lab) is a **JAX-native** SSM library built around a modular,
composable design. Instead of picking from a fixed menu of model types, you assemble your
model from independent, swappable pieces — like Lego bricks:

```
Your model = InitialComponent + TransitionComponent + EmissionComponent
```

Each component can be swapped independently without touching the others. This makes it
trivial to, e.g., test Gaussian vs. diagonal-Gaussian emissions with the same sticky
transition model, or compare standard vs. input-driven transitions with identical emissions.

All inference (EM, Viterbi, forward-backward, Kalman filter/smoother) runs under JAX JIT,
giving GPU acceleration and fast CPU execution automatically.

```
Install: pip install dynamax (requires jax jaxlib)
Docs: https://probml.github.io/dynamax/
```

### 13a. The Lego pieces available in dynamax

```
dynamax.hidden_markov_model
├── Transitions
│   ├── StandardTransitions         # unconstrained K×K matrix
│   ├── StickyTransitions           # adds κ self-transition bias (= sticky HMM)
│   └── (subclass AbstractTransitions for custom)
│
├── Emissions
│   ├── GaussianEmissions           # full-covariance Gaussian — standard choice for fMRI
│   ├── DiagonalGaussianEmissions   # diagonal covariance — for high-dimensional data
│   ├── LowRankGaussianEmissions    # low-rank + diagonal — balances FC and parameters
│   ├── SphericalGaussianEmissions  # isotropic — minimal parameters
│   ├── GaussianMixtureEmissions    # mixture-of-Gaussians per state (multi-modal)
│   └── (subclass AbstractEmissions for custom HRF-aware emissions, AR emissions, etc.)
│
└── Initial
    └── StandardInitialDistribution  # learnable π vector

dynamax.linear_gaussian_ssm
├── LinearGaussianSSM            # standard LGSSM with Kalman filter/smoother
└── LinearGaussianConjugateSSM   # conjugate priors — EM with closed-form M-step
```

### 13b. Drop-in Gaussian HMM (baseline)

```python
"""Standard Gaussian HMM with dynamax — equivalent to hmmlearn but JAX-native."""
import numpy as np
import jax.numpy as jnp
import jax.random as jr
from dynamax.hidden_markov_model import GaussianHMM


def fit_gaussian_hmm_dynamax(data_list, K, n_restarts=20, n_iters=100, seed=0):
    """Fit Gaussian HMM using dynamax EM.

    Parameters
    ----------
    data_list : list of np.ndarray, each (T, D)
        Multiple runs.
    K : int
        Number of states
    n_restarts : int
        Number of random restarts (take best final log-likelihood)
    n_iters : int
        Max EM iterations per restart

    Returns
    -------
    best_params : dynamax parameter pytree
    best_lls : array of log-likelihoods per EM iteration
    model : GaussianHMM instance (for inference calls)
    """
    D = data_list[0].shape[1]

    # dynamax expects a single 2D array for one sequence (or a batched 3D
    # array of equal-length sequences). For multi-run fMRI with unequal run
    # lengths, fit on concatenated data and keep the lengths for later scoring.
    emissions = jnp.array(np.vstack(data_list))

    model = GaussianHMM(num_states=K, emission_dim=D)

    best_params = None
    best_lls = None
    best_ll = -jnp.inf

    for restart in range(n_restarts):
        key = jr.PRNGKey(seed + restart)

        # K-means initialization on first restart, random otherwise
        init_method = "kmeans" if restart == 0 else "prior"
        params, props = model.initialize(key, method=init_method, emissions=emissions)

        params, lls = model.fit_em(params, props, emissions, num_iters=n_iters)

        if lls[-1] > best_ll:
            best_ll = lls[-1]
            best_params = params
            best_lls = lls

    print(f"Best final log-likelihood: {best_ll:.2f}")
    return best_params, best_lls, model


def decode_dynamax(model, params, emissions_jnp):
    """Viterbi decoding and posterior smoothing with dynamax."""
    # Most likely state sequence (Viterbi)
    most_likely_states = model.posterior_mode(params, emissions_jnp)

    # Posterior state probabilities (forward-backward smoother)
    posterior = model.smoother(params, emissions_jnp)
    smoothed_probs = posterior.smoothed_probs  # (T, K)

    return most_likely_states, smoothed_probs
```
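
End-to-end usage on synthetic data (the shapes and K are placeholders):

```python
import numpy as np
import jax.numpy as jnp

rng = np.random.default_rng(0)
runs = [rng.standard_normal((200, 50)).astype(np.float32) for _ in range(3)]

params, lls, model = fit_gaussian_hmm_dynamax(runs, K=4, n_restarts=5)
states, probs = decode_dynamax(model, params, jnp.array(np.vstack(runs)))
print(states.shape, probs.shape)  # (600,), (600, 4)
```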
### 13c. Swapping the emission component (the Lego idea)

```python
"""Replace Gaussian with DiagonalGaussian or LowRankGaussian — same training code."""
import jax.numpy as jnp
import jax.random as jr
from dynamax.hidden_markov_model import (
    GaussianHMM,
    DiagonalGaussianHMM,
    LowRankGaussianHMM,
)


def compare_emission_types(emissions_jnp, K, rank=5, seed=0):
    """Fit three HMMs that differ only in emission covariance structure.

    This is the Lego principle: swap one brick (emission), keep everything else.
    Use BIC to decide which covariance structure the data supports.

    Parameters
    ----------
    emissions_jnp : jax array, shape (T, D)
    K : int
        Number of states
    rank : int
        Rank for LowRankGaussianHMM (ignored for other types)
    """
    T, D = emissions_jnp.shape
    key = jr.PRNGKey(seed)

    results = {}

    emission_models = {
        # Full covariance: captures all pairwise FC per state
        # Parameters per state: D*(D+1)/2 — expensive for large D
        'full': GaussianHMM(num_states=K, emission_dim=D),

        # Diagonal covariance: ignores FC, only models per-region variance
        # Parameters per state: D — use when D is large or data is limited
        'diagonal': DiagonalGaussianHMM(num_states=K, emission_dim=D),

        # Low-rank + diagonal: rank-r approximation to FC + independent noise
        # Parameters per state: D*rank + D — good middle ground for parcellated data
        'low_rank': LowRankGaussianHMM(num_states=K, emission_dim=D, emission_rank=rank),
    }

    for i, (name, model) in enumerate(emission_models.items()):
        k1 = jr.fold_in(key, i)  # one subkey per model
        params, props = model.initialize(k1, method="kmeans", emissions=emissions_jnp)
        params, lls = model.fit_em(params, props, emissions_jnp, num_iters=100)

        final_ll = float(lls[-1])
        # Count parameters for BIC (the state means, K*D, are identical
        # across the three models, so omitting them does not change the ranking)
        n_transition = K * (K - 1)
        n_emission = {
            'full': K * D * (D + 1) // 2,
            'diagonal': K * D,
            'low_rank': K * (D * rank + D),
        }[name]
        n_params = n_transition + n_emission + (K - 1)  # + initial distribution
        bic = -2 * final_ll + n_params * jnp.log(T)

        results[name] = {'ll': final_ll, 'bic': float(bic), 'params': params, 'model': model}
        print(f"{name:10s}: LL={final_ll:.1f} BIC={float(bic):.1f} n_params={n_params}")

    best = min(results, key=lambda x: results[x]['bic'])
    print(f"\nBest emission type by BIC: {best}")
    return results


# Swap the TRANSITION component instead:
# Standard → Sticky, keeping the same GaussianHMM emissions.
# The sticky prior is set via transition_matrix_stickiness at initialization:
#
# params, props = model.initialize(
#     key,
#     transition_matrix=jnp.full((K, K), 1.0/K),  # starting point
# )
# Then in props, set the stickiness concentration:
# props.transitions.transition_matrix_concentration = ...
# props.transitions.transition_matrix_stickiness = 50.0  # κ (higher = stickier)
#
# See: GaussianHMM(..., transition_matrix_stickiness=50.0) constructor kwarg.
```
|
|
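
A quick usage sketch: `emissions_jnp` below is synthetic stand-in data just to make the call runnable; with real fMRI, pass a z-scored (T, D) parcellated timeseries instead.

```python
import numpy as np
import jax.numpy as jnp

rng = np.random.default_rng(0)
emissions_jnp = jnp.array(rng.standard_normal((400, 50)))  # T=400 TRs, D=50 parcels

results = compare_emission_types(emissions_jnp, K=6, rank=5, seed=0)

# With only 400 samples and D=50, expect BIC to favor 'diagonal' or 'low_rank':
# the full-covariance model needs 6 * 50*51/2 = 7650 covariance parameters alone.
best_bic, best_name = sorted((v['bic'], k) for k, v in results.items())[0]
```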

### 13d. Custom emission model — HRF-aware Gaussian emissions

```python
"""Build a custom emission class that absorbs HRF smoothing in its mean structure.

The Lego design lets you subclass the HMM emissions base class and plug the result
into any dynamax HMM without touching the transition or initial components.

This is a sketch — fill in the Gaussian log-likelihood and adapt it to your data.
Class and module names below follow dynamax 0.1.x; check your installed version.
"""
import jax.numpy as jnp
from dynamax.hidden_markov_model.models.abstractions import HMM
from dynamax.hidden_markov_model.models.gaussian_hmm import GaussianHMMEmissions


class HRFGaussianEmissions(GaussianHMMEmissions):
    """Gaussian emissions whose mean is HRF-convolved neural activity.

    Each state has a neural mean μ_k. Strictly, the HRF-convolved emission mean
    at time t depends on the whole state path, not just the current state; this
    sketch uses the common simplification that each state's mean is held constant
    in time, so convolution reduces to scaling μ_k by the HRF matrix row sums.
    This bakes the HRF into the emission model rather than into preprocessing.

    Parameters
    ----------
    hrf_matrix : jax array, shape (T, T)
        Pre-computed HRF convolution matrix (lower-triangular Toeplitz);
        see build_hrf_toeplitz above.
    """

    def __init__(self, num_states, emission_dim, hrf_matrix, **kwargs):
        super().__init__(num_states, emission_dim, **kwargs)
        self.hrf_matrix = hrf_matrix  # (T, T), fixed — not learned

    def log_likelihoods(self, params, emissions, inputs=None):
        # params.means: (K, D). Under the constant-state simplification the
        # convolved mean is a per-timepoint rescaling of each state mean:
        rowsum = self.hrf_matrix.sum(axis=1)                    # (T,)
        convolved_means = rowsum[:, None, None] * params.means  # (T, K, D)
        # ... then evaluate N(emissions[t] | convolved_means[t, k], Sigma_k)
        # for every t, k and return the (T, K) array of log-likelihoods.
        raise NotImplementedError("Fill in the multivariate Gaussian log-likelihood here.")


# Usage pattern:
# hrf_matrix = build_hrf_toeplitz(hrf_kernel, T=n_trs)
# emission_component = HRFGaussianEmissions(K, D, hrf_matrix)
# model = HMM(num_states=K, initial_component=...,
#             transition_component=..., emission_component=emission_component)
```
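
The usage pattern above calls a `build_hrf_toeplitz` helper that the sketch leaves undefined. One possible implementation, assuming an SPM-style double-gamma kernel built from `scipy.stats.gamma` (the kernel shape and both helper names are assumptions of this sketch, not dynamax API):

```python
import numpy as np
from scipy.linalg import toeplitz
from scipy.stats import gamma


def double_gamma_hrf(tr, duration=32.0):
    """Double-gamma HRF sampled at the TR (peak near 5 s, undershoot near 15 s)."""
    t = np.arange(0, duration, tr)
    hrf = gamma.pdf(t, 6) - gamma.pdf(t, 16) / 6.0
    return hrf / hrf.sum()


def build_hrf_toeplitz(hrf_kernel, T):
    """Lower-triangular (T, T) convolution matrix: (H @ x)[t] = sum_s hrf[t-s] * x[s]."""
    first_col = np.zeros(T)
    n = min(T, len(hrf_kernel))
    first_col[:n] = hrf_kernel[:n]
    return toeplitz(first_col, np.zeros(T))  # zero first row above the diagonal: causal
```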

### 13e. Linear Gaussian SSM (Kalman filter) for smooth latent dynamics

```python
"""LGSSM with dynamax — useful when you want continuous latent dynamics without discrete switching.

Think of it as SLDS with K=1 (single regime). Use it as a baseline before fitting SLDS/rSLDS.
"""
import jax.numpy as jnp
import jax.random as jr
from dynamax.linear_gaussian_ssm import LinearGaussianSSM


def fit_lgssm_fmri(bold_data, state_dim=10, n_iters=100, seed=0):
    """Fit a Linear Gaussian SSM (Kalman filter model) to BOLD data.

    Learns a low-dimensional latent trajectory that best explains the BOLD signal.
    No discrete switching — use this as a baseline or for smooth dynamics analyses.

    Parameters
    ----------
    bold_data : array, shape (T, D)
        Preprocessed, parcellated BOLD timeseries
    state_dim : int
        Latent state dimension (typically 5–20 for parcellated fMRI)

    Returns
    -------
    params : fitted LGSSM parameters
    posterior : smoothed posterior (means, covariances, marginal_loglik)
    """
    T, D = bold_data.shape
    emissions = jnp.array(bold_data)

    model = LinearGaussianSSM(state_dim=state_dim, emission_dim=D)
    params, props = model.initialize(jr.PRNGKey(seed))
    params, lls = model.fit_em(params, props, emissions, num_iters=n_iters)

    # Smooth the latent trajectory
    posterior = model.smoother(params, emissions)
    latent_means = posterior.smoothed_means       # (T, state_dim)
    latent_covs = posterior.smoothed_covariances  # (T, state_dim, state_dim)
    marginal_ll = posterior.marginal_loglik

    print(f"Final marginal log-likelihood: {float(marginal_ll):.2f}")
    print(f"Emission matrix shape: {params.emissions.weights.shape}")  # (D, state_dim)

    return params, posterior
```
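
To pick `state_dim`, one option is to refit at a few dimensions and watch the marginal log-likelihood; in-sample it grows with dimension, so look for the elbow or score a held-out run. A sketch reusing `fit_lgssm_fmri` above (`sweep_state_dim` is an illustrative helper, not dynamax API):

```python
def sweep_state_dim(bold_data, dims=(2, 5, 10, 15, 20), seed=0):
    """Fit the LGSSM at several latent dimensions and report the marginal LL."""
    lls = {}
    for d in dims:
        _, posterior = fit_lgssm_fmri(bold_data, state_dim=d, seed=seed)
        lls[d] = float(posterior.marginal_loglik)
        print(f"state_dim={d:3d}  marginal LL={lls[d]:.1f}")
    # In-sample LL increases with dimension; prefer the elbow, or compare
    # model.marginal_log_prob(params, heldout_emissions) across dims instead.
    return lls
```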

### 13f. When to use dynamax vs. ssm vs. hmmlearn

| Situation | Best choice |
|-----------|------------|
| Quickest path to a working Gaussian HMM | `hmmlearn` |
| Need rSLDS, SNLDS, or input-driven SLDS | `ssm` |
| Want to experiment with different emission/transition types rapidly | **`dynamax`** |
| Want to build a custom emission (e.g., HRF-aware, AR, mixture) | **`dynamax`** (subclass) |
| Want Kalman filter / LGSSM as a baseline | **`dynamax`** |
| Need JAX JIT + GPU for large-scale inference | **`dynamax`** or `ssm` |
| Group-level HMM with neuroimaging-specific features | `glhmm` |
| Deep generative model (DyNeMo) | `osl-dynamics` |