@yibeichan/claude-skills 1.0.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (40)
  1. package/LICENSE +21 -0
  2. package/README.md +98 -0
  3. package/cli.js +272 -0
  4. package/install.py +240 -0
  5. package/package.json +44 -0
  6. package/skills/bidsapp-nidm-standards/SKILL.md +202 -0
  7. package/skills/bidsapp-nidm-standards/references/babs_config.md +20 -0
  8. package/skills/bidsapp-nidm-standards/references/cli_arguments.md +76 -0
  9. package/skills/bidsapp-nidm-standards/references/container_patterns.md +53 -0
  10. package/skills/bidsapp-nidm-standards/references/nidm_integration.md +403 -0
  11. package/skills/bidsapp-nidm-standards/references/repo_structure.md +121 -0
  12. package/skills/bidsapp-nidm-standards/references/testing_patterns.md +82 -0
  13. package/skills/dicom2fmriprep/SKILL.md +377 -0
  14. package/skills/dicom2fmriprep/evals/evals.json +26 -0
  15. package/skills/dicom2fmriprep/references/babs-details.md +407 -0
  16. package/skills/dicom2fmriprep/references/fmriprep-details.md +250 -0
  17. package/skills/dicom2fmriprep/references/heudiconv-details.md +243 -0
  18. package/skills/fmri-ssm/SKILL.md +317 -0
  19. package/skills/fmri-ssm/references/code_templates.md +1570 -0
  20. package/skills/fmri-ssm/references/downstream_analysis.md +680 -0
  21. package/skills/fmri-ssm/references/group_inference.md +608 -0
  22. package/skills/fmri-ssm/references/hrf_modeling.md +447 -0
  23. package/skills/fmri-ssm/references/model_catalog.md +436 -0
  24. package/skills/fmri-ssm/references/paradigm_guide.md +406 -0
  25. package/skills/fmri-ssm/references/preprocessing.md +614 -0
  26. package/skills/fmri-ssm.zip +0 -0
  27. package/skills/neuroimaging-qc/SKILL.md +203 -0
  28. package/skills/neuroimaging-qc/references/eeg_qc.md +400 -0
  29. package/skills/neuroimaging-qc/references/fmri_qc.md +343 -0
  30. package/skills/neuroimaging-qc/references/fnirs_qc.md +430 -0
  31. package/skills/neuroimaging-qc/references/structural_qc.md +454 -0
  32. package/skills/neuroimaging-qc/scripts/parse_fmriprep_confounds.py +153 -0
  33. package/skills/neuroimaging-qc/scripts/parse_mriqc.py +114 -0
  34. package/skills/neuroimaging-qc/scripts/qc_report.py +295 -0
  35. package/skills/scientific-writer/SKILL.md +202 -0
  36. package/skills/scientific-writer/references/citation_styles.md +163 -0
  37. package/skills/scientific-writer/references/field_conventions.md +245 -0
  38. package/skills/scientific-writer/references/figures_tables.md +225 -0
  39. package/skills/scientific-writer/references/reporting_guidelines.md +225 -0
  40. package/skills.json +54 -0
@@ -0,0 +1,1570 @@

# Python Code Templates for SSMs on fMRI Data

## Table of Contents
1. [Gaussian HMM with hmmlearn](#gaussian-hmm)
2. [Gaussian HMM with ssm library](#gaussian-hmm-ssm)
3. [Sticky HMM](#sticky-hmm)
4. [HMM-MAR with osl-dynamics](#hmm-mar)
5. [Input-Output HMM](#io-hmm)
6. [SLDS with ssm library](#slds)
7. [rSLDS with ssm library](#rslds)
8. [Model Selection (choosing K)](#model-selection)
9. [State Visualization](#visualization)
10. [Reproducibility and Initialization](#reproducibility)
11. [Model Diagnostics: Detecting Pathological Fits](#diagnostics)
12. [JAX-Based HMM for GPU Acceleration](#jax-gpu)
13. [dynamax: Lego-Style Custom SSMs](#dynamax)

All code assumes preprocessed, parcellated timeseries. See `preprocessing.md` for
how to get from raw fMRIPrep outputs to the data matrices used here.
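
For orientation, here is a minimal loader sketch producing the `data` and `lengths`
used throughout. The file paths, atlas choice, and standardization are illustrative
placeholders; follow `preprocessing.md` for the recommended pipeline.

```python
"""Loader sketch: parcellate preprocessed BOLD into (data, lengths).

Paths and options below are placeholders; adapt to your BIDS layout.
"""
import numpy as np
from nilearn.maskers import NiftiLabelsMasker
from nilearn.datasets import fetch_atlas_schaefer_2018

atlas = fetch_atlas_schaefer_2018(n_rois=200)
masker = NiftiLabelsMasker(labels_img=atlas.maps, standardize=True)

runs = [
    "sub-01_task-rest_run-1_desc-preproc_bold.nii.gz",  # placeholder paths
    "sub-01_task-rest_run-2_desc-preproc_bold.nii.gz",
]

data_runs = [masker.fit_transform(r) for r in runs]  # each (T_run, 200)
data = np.vstack(data_runs)                          # (total_T, n_features)
lengths = [d.shape[0] for d in data_runs]            # run boundaries for hmmlearn
```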

---

## 1. Gaussian HMM with hmmlearn {#gaussian-hmm}

The simplest and most widely used SSM for fMRI. Start here unless you have
specific reasons to use a more complex model.

```python
"""Gaussian HMM for fMRI using hmmlearn.

Fits a K-state HMM with Gaussian emissions on parcellated BOLD timeseries.
Includes multiple random restarts, K-means initialization, and run-boundary handling.
"""
import numpy as np
from hmmlearn import hmm
from sklearn.cluster import KMeans

def fit_gaussian_hmm(data, lengths, K, covariance_type='full',
                     n_restarts=50, n_iter=200, random_state=42):
    """Fit Gaussian HMM with multiple random restarts.

    Parameters
    ----------
    data : array, shape (total_T, n_features)
        Concatenated BOLD timeseries across runs
    lengths : list of int
        Number of TRs per run (for run boundary handling)
    K : int
        Number of hidden states
    covariance_type : str
        'full', 'diag', 'tied', or 'spherical'
    n_restarts : int
        Number of random restarts (take best log-likelihood)
    n_iter : int
        Max EM iterations per restart
    random_state : int
        Base random seed

    Returns
    -------
    best_model : GaussianHMM
        Fitted model with highest log-likelihood
    best_score : float
        Log-likelihood of best model
    all_scores : list
        Log-likelihoods from all restarts
    """
    best_model = None
    best_score = -np.inf
    all_scores = []

    for restart in range(n_restarts):
        model = hmm.GaussianHMM(
            n_components=K,
            covariance_type=covariance_type,
            n_iter=n_iter,
            tol=1e-4,
            random_state=random_state + restart,
            init_params='stmc',  # initialize all parameters (overridden below on restart 0)
            verbose=False,
        )

        # K-means initialization for means on the first restart (much better
        # than random). hmmlearn only keeps manually set means_ if 'm' is
        # removed from init_params; otherwise fit() overwrites them.
        if restart == 0:
            kmeans = KMeans(n_clusters=K, random_state=random_state, n_init=10)
            kmeans.fit(data)
            model.init_params = 'stc'  # don't re-initialize means
            model.means_ = kmeans.cluster_centers_

        try:
            model.fit(data, lengths=lengths)
            score = model.score(data, lengths=lengths)
            all_scores.append(score)

            if score > best_score:
                best_score = score
                best_model = model
        except Exception:
            all_scores.append(np.nan)
            continue

    n_converged = sum(1 for s in all_scores if not np.isnan(s))
    print(f"Converged: {n_converged}/{n_restarts}")
    print(f"Best log-likelihood: {best_score:.2f}")
    print(f"Score range: {np.nanmin(all_scores):.2f} to {np.nanmax(all_scores):.2f}")

    return best_model, best_score, all_scores


def extract_state_info(model, data, lengths):
    """Extract state sequence, dwell times, and transition matrix.

    Returns
    -------
    results : dict with keys:
        'states': state sequence
        'state_probs': posterior state probabilities
        'dwell_times': dict mapping state -> list of dwell durations (in TRs)
        'fractional_occupancy': proportion of time in each state
        'transition_matrix': estimated transition matrix
        'means': state means, shape (K, n_features)
        'covariances': state covariances
    """
    states = model.predict(data, lengths=lengths)
    state_probs = model.predict_proba(data, lengths=lengths)

    # Compute dwell times (respecting run boundaries)
    dwell_times = {k: [] for k in range(model.n_components)}
    offset = 0
    for length in lengths:
        run_states = states[offset:offset + length]
        current_state = run_states[0]
        current_dwell = 1
        for t in range(1, length):
            if run_states[t] == current_state:
                current_dwell += 1
            else:
                dwell_times[current_state].append(current_dwell)
                current_state = run_states[t]
                current_dwell = 1
        dwell_times[current_state].append(current_dwell)  # last dwell
        offset += length

    # Fractional occupancy
    frac_occ = np.array([(states == k).sum() / len(states)
                         for k in range(model.n_components)])

    return {
        'states': states,
        'state_probs': state_probs,
        'dwell_times': dwell_times,
        'fractional_occupancy': frac_occ,
        'transition_matrix': model.transmat_,
        'means': model.means_,
        'covariances': model.covars_,
    }
```
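
A minimal end-to-end usage sketch, assuming `data` and `lengths` were built as in
the loader sketch above and a placeholder TR of 2.0 s:

```python
# Usage sketch: fit, extract, and summarize (data/lengths as in the loader above)
model, best_ll, scores = fit_gaussian_hmm(data, lengths, K=6)
info = extract_state_info(model, data, lengths)

tr = 2.0  # placeholder TR in seconds
print("Fractional occupancy:", np.round(info['fractional_occupancy'], 3))
for k, dwells in info['dwell_times'].items():
    if dwells:
        print(f"State {k}: mean dwell {np.mean(dwells) * tr:.1f}s across {len(dwells)} visits")
```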

---

## 2. Gaussian HMM with ssm library {#gaussian-hmm-ssm}

The `ssm` library (Linderman lab) provides a unified API for HMMs, SLDS, and rSLDS.
Useful when you want to compare model families within the same framework.

```python
"""Gaussian HMM using the ssm library (Linderman et al.)

Install from GitHub: pip install git+https://github.com/lindermanlab/ssm
"""
import ssm
import numpy as np

def fit_hmm_ssm(data_list, K, D=None, n_restarts=20, n_iters=200):
    """Fit Gaussian HMM using ssm library.

    Parameters
    ----------
    data_list : list of arrays
        Each array is (T_run, D) for one run. Do NOT concatenate — ssm
        handles multiple sequences natively.
    K : int
        Number of states
    D : int or None
        Observation dimension (inferred from data if None)

    Returns
    -------
    best_model : ssm.HMM
    best_lls : list of float
        Log-likelihood per EM iteration for best run
    """
    if D is None:
        D = data_list[0].shape[1]

    best_model = None
    best_ll = -np.inf

    for restart in range(n_restarts):
        model = ssm.HMM(
            K=K,
            D=D,
            observations='gaussian',  # or 'diagonal_gaussian', 'studentst'
            transitions='standard',   # or 'sticky'
        )

        lls = model.fit(
            data_list,
            method='em',
            num_iters=n_iters,
            tolerance=1e-4,
        )

        final_ll = lls[-1]
        if final_ll > best_ll:
            best_ll = final_ll
            best_model = model
            best_lls = lls

    print(f"Best log-likelihood: {best_ll:.2f}")
    return best_model, best_lls


def decode_states_ssm(model, data_list):
    """Get most likely state sequences and posterior probabilities."""
    all_states = []
    all_posteriors = []

    for data in data_list:
        # Viterbi (most likely sequence)
        states = model.most_likely_states(data)
        all_states.append(states)

        # Posterior probabilities (first element of the expected_states tuple)
        posteriors = model.expected_states(data)[0]  # (T, K)
        all_posteriors.append(posteriors)

    return all_states, all_posteriors
```
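
Because only the constructor arguments change, comparing model variants is a small
loop. A sketch assuming `data_list` as above (all variants shown are accepted by
`ssm.HMM`):

```python
# Compare emission/transition variants under one API (sketch).
variants = {
    'gaussian-standard':  ('gaussian', 'standard'),
    'gaussian-sticky':    ('gaussian', 'sticky'),
    'studentst-standard': ('studentst', 'standard'),  # heavy-tailed emissions
}

for name, (obs, trans) in variants.items():
    model = ssm.HMM(K=6, D=data_list[0].shape[1], observations=obs, transitions=trans)
    lls = model.fit(data_list, method='em', num_iters=100)
    print(f"{name:22s} final LL: {lls[-1]:.1f}")
```

Raw log-likelihoods are not directly comparable across families with different
parameter counts; use the BIC and cross-validation utilities in section 8 for the
actual decision.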

---

## 3. Sticky HMM {#sticky-hmm}

Adds self-transition bias to prevent unrealistically rapid state switching.

```python
"""Sticky HMM using ssm library."""
import numpy as np
import ssm

def fit_sticky_hmm(data_list, K, D, kappa=100, n_restarts=20, n_iters=200):
    """Fit sticky HMM.

    Parameters
    ----------
    kappa : float
        Stickiness parameter. Higher = longer dwell times.
        kappa=0 is standard HMM. kappa=100-1000 is typical for fMRI.
        Rule of thumb: set kappa so expected dwell time is ~5-10 TRs.
    """
    best_model = None
    best_ll = -np.inf

    for restart in range(n_restarts):
        model = ssm.HMM(
            K=K,
            D=D,
            observations='gaussian',
            transitions='sticky',
            transition_kwargs={'kappa': kappa},
        )

        lls = model.fit(data_list, method='em', num_iters=n_iters)

        if lls[-1] > best_ll:
            best_ll = lls[-1]
            best_model = model

    # Check effective dwell times (use the normalized transition matrix,
    # not raw log_Ps, which are unnormalized log-potentials)
    trans_mat = best_model.transitions.transition_matrix
    expected_dwell = 1.0 / (1.0 - np.diag(trans_mat))
    print(f"Expected dwell times (TRs): {expected_dwell}")

    return best_model
```
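
The rule of thumb can be made concrete: for a geometric dwell distribution,
E[dwell] = 1 / (1 - p_self), so a target of 8 TRs implies p_self ≈ 0.875. The
mapping from kappa to the fitted self-transition probabilities depends on K and
the data, so in practice sweep kappa and check the resulting dwell times. A
sketch using `fit_sticky_hmm` from above:

```python
def choose_kappa(data_list, K, D, target_dwell_trs=8, kappas=(0, 10, 100, 1000)):
    """Sweep kappa; report mean expected dwell time (in TRs) for each fit.

    Prefer the smallest kappa that reaches the target; larger values
    oversmooth genuine transitions.
    """
    for kappa in kappas:
        model = fit_sticky_hmm(data_list, K, D, kappa=kappa,
                               n_restarts=5, n_iters=100)
        trans_mat = model.transitions.transition_matrix
        dwell = float(np.mean(1.0 / (1.0 - np.diag(trans_mat))))
        flag = "  <- target reached" if dwell >= target_dwell_trs else ""
        print(f"kappa={kappa:6g}: mean expected dwell {dwell:.1f} TRs{flag}")
```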

---

## 4. HMM-MAR with osl-dynamics {#hmm-mar}

```python
"""HMM-MAR using osl-dynamics (Oxford's toolbox for dynamic brain analysis).

Install: pip install osl-dynamics
This is the Python successor to the MATLAB HMM-MAR toolbox (Vidaurre et al.).
"""
from osl_dynamics.models.hmm import Config, Model
from osl_dynamics.data import Data
import numpy as np

def fit_hmm_mar(data_files, K, n_channels, sequence_length=200,
                n_ar_lags=3, learn_means=True, learn_covariances=True,
                n_epochs=40, batch_size=64):
    """Fit HMM-MAR using osl-dynamics.

    Parameters
    ----------
    data_files : list of str or list of arrays
        Paths to numpy files or arrays, each (T, n_channels)
    K : int
        Number of states
    n_channels : int
        Number of brain regions / ICA components
    sequence_length : int
        Length of sequences for training (segments of the timeseries)
    n_ar_lags : int
        Number of autoregressive lags (typically 1-5 for fMRI)
    learn_means : bool
        Whether states have different means
    learn_covariances : bool
        Whether states have different covariances
    """
    # Prepare data (time-delay embedding + PCA captures the AR structure)
    data = Data(data_files, store_dir='/tmp/osl_dynamics_data')
    data.prepare({
        'tde_pca': {'n_embeddings': n_ar_lags * 2 + 1, 'n_pca_components': n_channels},
        'standardize': {},
    })

    # Configure model
    config = Config(
        n_states=K,
        n_channels=n_channels,
        sequence_length=sequence_length,
        learn_means=learn_means,
        learn_covariances=learn_covariances,
        batch_size=batch_size,
        learning_rate=0.01,
        n_epochs=n_epochs,
    )

    model = Model(config)
    model.random_state_time_course_initialization(data, n_init=5, n_epochs=2)
    history = model.fit(data)

    # Get state time courses
    alpha = model.get_alpha(data)  # list of (T, K) arrays — state probabilities

    return model, alpha, history


# Alternative: using glhmm (Vidaurre's newer library)
def fit_glhmm(data_list, K):
    """Fit a Gaussian HMM using glhmm.

    Install: pip install glhmm
    Note: `from glhmm import glhmm` imports the glhmm *submodule*, so the
    class is instantiated with the two-level call glhmm.glhmm(...), as in
    the library's documentation.
    """
    from glhmm import glhmm, auxiliary

    # Concatenate runs and build (start, end) indices per session
    data_concat = np.vstack(data_list)
    T_list = [d.shape[0] for d in data_list]
    indices = auxiliary.make_indices_from_T(T_list)

    model = glhmm.glhmm(
        K=K,
        covtype='full',
        model_mean='state',
        model_beta='no',  # no regressors here; pass lagged data as X with
                          # model_beta='state' for MAR-style models
    )
    model.train(X=None, Y=data_concat, indices=indices)

    # Viterbi path
    vpath = model.decode(X=None, Y=data_concat, indices=indices, viterbi=True)

    return model, vpath
```

---

## 5. Input-Output HMM {#io-hmm}

For task-based fMRI where external events drive state transitions.

```python
"""Input-Output HMM using ssm library.

Task events enter as inputs that modulate either transitions or emissions.
"""
import ssm
import numpy as np
from nilearn.glm.first_level import spm_hrf

def prepare_task_inputs(events_df, n_trs, tr, hrf_convolve=True):
    """Convert task events to input matrix for IO-HMM.

    Parameters
    ----------
    events_df : DataFrame
        Columns: onset, duration, trial_type
    n_trs : int
        Total number of TRs
    tr : float
        Repetition time
    hrf_convolve : bool
        Whether to convolve inputs with HRF. Set True when fitting on BOLD
        (so inputs align with BOLD timing). Set False if fitting on deconvolved data.

    Returns
    -------
    inputs : array, shape (n_trs, n_conditions)
        One column per trial_type
    condition_names : list of str
    """
    trial_types = sorted(events_df['trial_type'].unique())
    n_conditions = len(trial_types)

    # Build stimulus timecourse at TR resolution
    inputs = np.zeros((n_trs, n_conditions))

    for i, tt in enumerate(trial_types):
        events = events_df[events_df['trial_type'] == tt]
        for _, event in events.iterrows():
            onset_tr = int(np.round(event['onset'] / tr))
            dur_trs = max(1, int(np.round(event['duration'] / tr)))
            end_tr = min(n_trs, onset_tr + dur_trs)
            inputs[onset_tr:end_tr, i] = 1.0

    if hrf_convolve:
        hrf = spm_hrf(tr, oversampling=1)
        for i in range(n_conditions):
            # Causal convolution: keep the first n_trs samples of the full
            # convolution. (mode='same' would center the kernel and shift
            # the response earlier in time.)
            inputs[:, i] = np.convolve(inputs[:, i], hrf)[:n_trs]

    return inputs, trial_types


def fit_io_hmm(data_list, inputs_list, K, n_features, n_inputs,
               input_driven='transitions', n_restarts=20):
    """Fit input-output HMM.

    Parameters
    ----------
    data_list : list of arrays
        Each (T, n_features)
    inputs_list : list of arrays
        Each (T, n_inputs) — task inputs for each run
    K : int
        Number of states
    input_driven : str
        'transitions': inputs affect state switching probabilities
        'observations': inputs affect emission means
        'both': inputs affect both
    """
    if input_driven == 'transitions':
        transitions = 'inputdriven'
        observations = 'gaussian'
    elif input_driven == 'observations':
        transitions = 'standard'
        # ssm has no built-in input-driven *Gaussian* observation class.
        # For input-modulated emissions, subclass
        # ssm.observations.GaussianObservations and override
        # log_likelihoods() to add B_k @ u_t to the mean.
        # See: https://github.com/lindermanlab/ssm/blob/master/ssm/observations.py
        observations = 'gaussian'  # placeholder — customise as needed
    elif input_driven == 'both':
        transitions = 'inputdriven'
        observations = 'gaussian'  # replace with custom class for input-driven emissions
    else:
        raise ValueError(f"Unknown input_driven option: {input_driven!r}")

    best_model = None
    best_ll = -np.inf

    for restart in range(n_restarts):
        model = ssm.HMM(
            K=K,
            D=n_features,
            M=n_inputs,
            observations=observations,
            transitions=transitions,
        )

        lls = model.fit(
            data_list,
            inputs=inputs_list,
            method='em',
            num_iters=200,
        )

        if lls[-1] > best_ll:
            best_ll = lls[-1]
            best_model = model

    return best_model
```
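
A usage sketch tying this to a BIDS events file (the filename and TR are
placeholders):

```python
import pandas as pd

# BIDS events file — placeholder path
events = pd.read_csv("sub-01_task-stroop_run-1_events.tsv", sep="\t")

tr = 2.0                          # placeholder TR in seconds
n_trs = data_list[0].shape[0]     # data_list as in the sections above

inputs, condition_names = prepare_task_inputs(events, n_trs=n_trs, tr=tr)
inputs_list = [inputs]            # one entry per run, matching data_list

model = fit_io_hmm(data_list, inputs_list, K=4,
                   n_features=data_list[0].shape[1],
                   n_inputs=inputs.shape[1],
                   input_driven='transitions')
```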

---

## 6. SLDS with ssm library {#slds}

```python
"""Switching Linear Dynamical System using ssm library."""
import ssm
import numpy as np

def fit_slds(data_list, K, D, latent_dim,
             n_restarts=10, n_iters=100):
    """Fit SLDS.

    Parameters
    ----------
    data_list : list of arrays
        Each (T, D) observation timeseries
    K : int
        Number of discrete switching states
    D : int
        Observation dimension
    latent_dim : int
        Continuous latent state dimension (typically 5-15 for fMRI)

    Returns
    -------
    best_model : ssm.SLDS
    best_elbos : list of float
        ELBO per iteration for the best restart
    """
    best_model = None
    best_elbo = -np.inf

    for restart in range(n_restarts):
        model = ssm.SLDS(
            N=D,            # observation dimension
            K=K,            # number of discrete states
            D=latent_dim,   # latent dimension
            emissions='gaussian_orthog',  # orthogonal emission matrix
            dynamics='gaussian',
            transitions='standard',  # or 'recurrent' for rSLDS
        )

        # Fit with Laplace-EM. Note: SLDS.fit returns (elbos, posterior),
        # unlike HMM.fit which returns only the log-likelihoods.
        elbos, q = model.fit(
            data_list,
            method='laplace_em',
            variational_posterior='structured_meanfield',
            num_iters=n_iters,
            initialize=True,
        )

        if elbos[-1] > best_elbo:
            best_elbo = elbos[-1]
            best_model = model
            best_elbos = elbos

    print(f"Best ELBO: {best_elbo:.2f}")
    return best_model, best_elbos


def decode_slds(model, data_list):
    """Decode latent states from fitted SLDS."""
    all_discrete_states = []
    all_continuous_states = []

    for data in data_list:
        # Posterior over continuous states first (needed for discrete decoding)
        q = model.approximate_posterior(
            data,
            method='laplace_em',
            variational_posterior='structured_meanfield',
            num_iters=50,
        )
        x = q.mean_continuous_states[0]  # (T, latent_dim)
        all_continuous_states.append(x)

        # Most likely discrete states given the continuous trajectory
        z = model.most_likely_states(x, data)
        all_discrete_states.append(z)

    return all_discrete_states, all_continuous_states
```

---

## 7. rSLDS with ssm library {#rslds}

```python
"""Recurrent SLDS — discrete states depend on continuous latent state."""
import numpy as np
import ssm

def fit_rslds(data_list, K, D, latent_dim, n_restarts=10, n_iters=100):
    """Fit rSLDS.

    The key difference from SLDS: transitions='recurrent', meaning
    P(z_t | z_{t-1}, x_{t-1}) depends on the continuous state x.
    """
    best_model = None
    best_elbo = -np.inf

    for restart in range(n_restarts):
        model = ssm.SLDS(
            N=D,
            K=K,
            D=latent_dim,
            emissions='gaussian_orthog',
            dynamics='gaussian',
            transitions='recurrent',  # this makes it recurrent
        )

        # SLDS.fit returns (elbos, posterior)
        elbos, q = model.fit(
            data_list,
            method='laplace_em',
            variational_posterior='structured_meanfield',
            num_iters=n_iters,
            initialize=True,
        )

        if elbos[-1] > best_elbo:
            best_elbo = elbos[-1]
            best_model = model

    return best_model
```

---

## 8. Model Selection — Choosing K {#model-selection}

```python
"""Model selection utilities for SSMs on fMRI data."""
import numpy as np
from sklearn.model_selection import KFold

def select_K_bic(data, lengths, K_range=range(2, 16), covariance_type='full',
                 n_restarts=30):
    """Select K using BIC (Bayesian Information Criterion).

    BIC = -2 * log_likelihood + n_params * log(n_samples)
    Lower BIC is better.
    """
    results = {}
    T, p = data.shape

    for K in K_range:
        model, score, _ = fit_gaussian_hmm(
            data, lengths, K, covariance_type=covariance_type,
            n_restarts=n_restarts
        )

        # Count parameters
        n_params = (K - 1)       # initial state
        n_params += K * (K - 1)  # transition matrix
        n_params += K * p        # means
        if covariance_type == 'full':
            n_params += K * p * (p + 1) // 2  # covariances
        elif covariance_type == 'diag':
            n_params += K * p

        bic = -2 * score + n_params * np.log(T)
        aic = -2 * score + 2 * n_params

        results[K] = {
            'log_likelihood': score,
            'bic': bic,
            'aic': aic,
            'n_params': n_params,
            'model': model,
        }
        print(f"K={K}: LL={score:.1f}, BIC={bic:.1f}, AIC={aic:.1f}, params={n_params}")

    best_K_bic = min(results, key=lambda k: results[k]['bic'])
    best_K_aic = min(results, key=lambda k: results[k]['aic'])
    print(f"\nBest K by BIC: {best_K_bic}")
    print(f"Best K by AIC: {best_K_aic}")

    return results, best_K_bic


def select_K_crossval(data_runs, K_range=range(2, 16), n_folds=None,
                      covariance_type='full', n_restarts=20):
    """Select K using cross-validated log-likelihood on held-out runs.

    Parameters
    ----------
    data_runs : list of arrays
        One array per fMRI run, shape (T_run, n_features)
    K_range : range
        K values to test
    n_folds : int or None
        If None, use leave-one-run-out
    """
    n_runs = len(data_runs)
    if n_folds is None:
        n_folds = n_runs  # leave-one-run-out

    results = {}

    for K in K_range:
        fold_scores = []

        kf = KFold(n_splits=min(n_folds, n_runs), shuffle=False)
        run_indices = np.arange(n_runs)

        for train_idx, test_idx in kf.split(run_indices):
            # Concatenate training runs
            train_data = np.vstack([data_runs[i] for i in train_idx])
            train_lengths = [data_runs[i].shape[0] for i in train_idx]

            # Concatenate test runs
            test_data = np.vstack([data_runs[i] for i in test_idx])
            test_lengths = [data_runs[i].shape[0] for i in test_idx]

            # Fit on train
            model, _, _ = fit_gaussian_hmm(
                train_data, train_lengths, K,
                covariance_type=covariance_type,
                n_restarts=n_restarts
            )

            # Score on test, normalized by number of test time points
            test_score = model.score(test_data, lengths=test_lengths)
            fold_scores.append(test_score / test_data.shape[0])

        mean_score = np.mean(fold_scores)
        std_score = np.std(fold_scores)
        results[K] = {'mean_cv_ll': mean_score, 'std_cv_ll': std_score}
        print(f"K={K}: CV log-lik = {mean_score:.4f} ± {std_score:.4f}")

    best_K = max(results, key=lambda k: results[k]['mean_cv_ll'])
    print(f"\nBest K by CV: {best_K}")

    return results, best_K


def state_stability_analysis(data, lengths, K, n_splits=10, n_restarts=30):
    """Check if states are reproducible across random splits and initializations.

    Fits the model on random half-splits of the data and measures how
    similar the inferred states are. Stable states should be recoverable
    across splits.
    """
    from scipy.optimize import linear_sum_assignment
    from scipy.spatial.distance import cdist

    all_means = []

    for split in range(n_splits):
        rng = np.random.RandomState(split)
        # Randomly take the first or second half of each run
        split_data_list = []
        split_lengths = []
        offset = 0
        for length in lengths:
            run_data = data[offset:offset + length]
            mid = length // 2
            if rng.rand() > 0.5:
                split_data_list.append(run_data[:mid])
            else:
                split_data_list.append(run_data[mid:])
            split_lengths.append(split_data_list[-1].shape[0])
            offset += length

        split_data = np.vstack(split_data_list)
        model, _, _ = fit_gaussian_hmm(
            split_data, split_lengths, K,
            n_restarts=n_restarts
        )
        all_means.append(model.means_)

    # Compare all pairs of splits using the Hungarian algorithm
    similarities = []
    for i in range(n_splits):
        for j in range(i + 1, n_splits):
            cost = cdist(all_means[i], all_means[j], metric='correlation')
            row_ind, col_ind = linear_sum_assignment(cost)
            matched_corr = 1 - cost[row_ind, col_ind].mean()
            similarities.append(matched_corr)

    print(f"State stability (mean matched correlation): "
          f"{np.mean(similarities):.3f} ± {np.std(similarities):.3f}")
    print("Values > 0.8 suggest stable states; < 0.5 suggests instability")

    return similarities
```
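
Plotting the criteria makes the elbow easier to judge than the printed values. A
small sketch over the `results` dict returned by `select_K_bic`:

```python
import matplotlib.pyplot as plt

def plot_model_selection(results):
    """Plot BIC and AIC versus K from select_K_bic output (lower is better)."""
    Ks = sorted(results)
    fig, ax = plt.subplots(figsize=(6, 4))
    ax.plot(Ks, [results[k]['bic'] for k in Ks], 'o-', label='BIC')
    ax.plot(Ks, [results[k]['aic'] for k in Ks], 's--', label='AIC')
    ax.set_xlabel('Number of states K')
    ax.set_ylabel('Criterion (lower is better)')
    ax.legend()
    fig.tight_layout()
    return fig
```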

---

## 9. State Visualization {#visualization}

```python
"""Visualization utilities for SSM results on fMRI data."""
import numpy as np
import matplotlib.pyplot as plt

def plot_state_timecourse(states, tr, run_boundaries=None, ax=None,
                          state_colors=None, title='State timecourse'):
    """Plot the inferred state sequence over time."""
    if ax is None:
        fig, ax = plt.subplots(figsize=(14, 2))

    T = len(states)
    # Use max label + 1, not len(unique): a state that is never visited
    # would otherwise shift the color indexing out of range
    K = int(np.max(states)) + 1
    times = np.arange(T) * tr

    if state_colors is None:
        cmap = plt.cm.Set2
        state_colors = [cmap(i / K) for i in range(K)]

    for t in range(T - 1):
        ax.axvspan(times[t], times[t + 1], color=state_colors[states[t]], alpha=0.8)

    if run_boundaries is not None:
        for b in run_boundaries:
            ax.axvline(b * tr, color='black', linewidth=2, linestyle='--', alpha=0.5)

    ax.set_xlim(0, times[-1])
    ax.set_xlabel('Time (s)')
    ax.set_title(title)
    ax.set_yticks([])

    return ax


def plot_state_spatial_maps(means, roi_labels=None, n_top_regions=10):
    """Plot the top activated regions for each state."""
    K, p = means.shape

    fig, axes = plt.subplots(1, K, figsize=(4 * K, 6))
    if K == 1:
        axes = [axes]

    for k in range(K):
        ax = axes[k]
        state_mean = means[k]

        # Top positive and negative regions
        top_pos = np.argsort(state_mean)[-n_top_regions:][::-1]
        top_neg = np.argsort(state_mean)[:n_top_regions]
        top_idx = np.concatenate([top_pos, top_neg])

        values = state_mean[top_idx]
        if roi_labels is not None:
            labels = [roi_labels[i] for i in top_idx]
        else:
            labels = [f'ROI {i}' for i in top_idx]

        colors = ['#e74c3c' if v > 0 else '#3498db' for v in values]
        ax.barh(range(len(values)), values, color=colors)
        ax.set_yticks(range(len(values)))
        ax.set_yticklabels(labels, fontsize=8)
        ax.set_title(f'State {k + 1}')
        ax.axvline(0, color='black', linewidth=0.5)

    plt.tight_layout()
    return fig


def plot_transition_matrix(transmat, ax=None):
    """Plot transition probability matrix as heatmap."""
    if ax is None:
        fig, ax = plt.subplots(figsize=(6, 5))

    K = transmat.shape[0]
    im = ax.imshow(transmat, cmap='Blues', vmin=0, vmax=1)

    for i in range(K):
        for j in range(K):
            ax.text(j, i, f'{transmat[i, j]:.2f}', ha='center', va='center',
                    color='white' if transmat[i, j] > 0.5 else 'black', fontsize=10)

    ax.set_xticks(range(K))
    ax.set_yticks(range(K))
    ax.set_xticklabels([f'State {k+1}' for k in range(K)])
    ax.set_yticklabels([f'State {k+1}' for k in range(K)])
    ax.set_xlabel('To state')
    ax.set_ylabel('From state')
    ax.set_title('Transition matrix')
    plt.colorbar(im, ax=ax)

    return ax


def plot_dwell_time_distributions(dwell_times, tr, ax=None):
    """Plot dwell time distributions for each state."""
    K = len(dwell_times)
    if ax is None:
        fig, ax = plt.subplots(figsize=(8, 4))

    for k in range(K):
        dwells_sec = np.array(dwell_times[k]) * tr
        if len(dwells_sec) > 0:
            ax.hist(dwells_sec, bins=20, alpha=0.5, density=True,
                    label=f'State {k+1} (mean={dwells_sec.mean():.1f}s)')

    ax.set_xlabel('Dwell time (seconds)')
    ax.set_ylabel('Density')
    ax.set_title('Dwell time distributions')
    ax.legend()

    return ax
```

---

## 10. Reproducibility and Initialization {#reproducibility}

```python
"""Best practices for reproducible SSM fitting on fMRI data."""
import numpy as np
import json

def save_ssm_config(filepath, **kwargs):
    """Save SSM configuration for reproducibility."""
    config = {
        'model_type': kwargs.get('model_type', 'gaussian_hmm'),
        'K': kwargs.get('K'),
        'covariance_type': kwargs.get('covariance_type', 'full'),
        'n_restarts': kwargs.get('n_restarts', 50),
        'n_iter': kwargs.get('n_iter', 200),
        'random_seed': kwargs.get('random_seed', 42),
        'parcellation': kwargs.get('parcellation', 'schaefer200'),
        'confound_strategy': kwargs.get('confound_strategy', 'moderate'),
        'hrf_strategy': kwargs.get('hrf_strategy', 'bold_direct'),
        'tr': kwargs.get('tr'),
        'preprocessing': kwargs.get('preprocessing', 'fmriprep+xcpd'),
        'notes': kwargs.get('notes', ''),
    }
    with open(filepath, 'w') as f:
        json.dump(config, f, indent=2)
    print(f"Config saved to {filepath}")


def align_state_labels(reference_means, target_means):
    """Align state labels between two models using the Hungarian algorithm.

    Use when comparing states across subjects or between runs.
    """
    from scipy.optimize import linear_sum_assignment
    from scipy.spatial.distance import cdist

    cost = cdist(reference_means, target_means, metric='correlation')
    row_ind, col_ind = linear_sum_assignment(cost)

    label_mapping = {col: row for row, col in zip(row_ind, col_ind)}
    return label_mapping
```
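
Applying the mapping: a short sketch relabeling one model's state sequence into
the reference model's labels (the dict returned above maps target label ->
reference label):

```python
import numpy as np

def relabel_states(states, label_mapping):
    """Relabel a state sequence using the output of align_state_labels."""
    lut = np.array([label_mapping[k] for k in range(len(label_mapping))])
    return lut[states]

# Example (hypothetical variables): align subject 2 to subject 1's labeling
# mapping = align_state_labels(means_sub1, means_sub2)
# states_sub2_aligned = relabel_states(states_sub2, mapping)
```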

---

## 11. Model Diagnostics: Detecting Pathological Fits {#diagnostics}

Before reporting SSM results, always run these checks. Pathological fits produce
scientifically meaningless states that can appear statistically significant.

```python
"""Diagnostics for detecting common HMM failure modes on fMRI data."""
import numpy as np


def diagnose_hmm_fit(model, state_seq, confounds, fd, tr,
                     dominant_state_threshold=0.70,
                     min_dwell_trs=2,
                     motion_corr_threshold=0.30):
    """Run a battery of diagnostics on a fitted HMM.

    Parameters
    ----------
    model : fitted HMM (hmmlearn GaussianHMM or similar)
    state_seq : array, shape (T,)
        Viterbi state sequence
    confounds : array, shape (T, n_confounds)
        Confound matrix (motion params, WM/CSF, etc.)
    fd : array, shape (T,)
        Framewise displacement per TR
    tr : float
        Repetition time in seconds
    dominant_state_threshold : float
        Warn if any state occupies more than this fraction of TRs
    min_dwell_trs : int
        Warn if mean dwell time is below this many TRs
    motion_corr_threshold : float
        Warn if any state's occurrence correlates with FD above this value

    Returns
    -------
    report : dict
        Diagnostic results with 'warnings' list
    """
    K = model.n_components
    T = len(state_seq)
    warnings = []

    # --- 11a. Motion-driven state detection ---
    # (fMRIPrep's FD and derivative confounds have NaN in the first row)
    fd_clean = np.nan_to_num(fd, nan=0.0)
    confounds = np.nan_to_num(confounds, nan=0.0)
    motion_corrs = {}
    for k in range(K):
        state_indicator = (state_seq == k).astype(float)
        r = np.corrcoef(state_indicator, fd_clean)[0, 1]
        motion_corrs[k] = r
        if abs(r) > motion_corr_threshold:
            warnings.append(
                f"State {k}: |r|={abs(r):.2f} with framewise displacement "
                f"(threshold {motion_corr_threshold}). May be motion-driven."
            )

    # --- 11b. Dominant state pathology ---
    frac_occ = np.array([(state_seq == k).sum() / T for k in range(K)])
    dominant_states = np.where(frac_occ > dominant_state_threshold)[0]
    for k in dominant_states:
        warnings.append(
            f"State {k} dominates: {frac_occ[k]:.1%} of TRs. "
            f"Model may have collapsed — check BIC at lower K."
        )

    # --- 11c. Per-TR switching (too-fast switching) ---
    dwell_times = {k: [] for k in range(K)}
    current_state = state_seq[0]
    current_dwell = 1
    for t in range(1, T):
        if state_seq[t] == current_state:
            current_dwell += 1
        else:
            dwell_times[current_state].append(current_dwell)
            current_state = state_seq[t]
            current_dwell = 1
    dwell_times[current_state].append(current_dwell)

    mean_dwell = {k: np.mean(dwell_times[k]) if dwell_times[k] else 0 for k in range(K)}
    fast_states = [k for k, d in mean_dwell.items() if d < min_dwell_trs]
    for k in fast_states:
        warnings.append(
            f"State {k}: mean dwell = {mean_dwell[k]:.1f} TRs "
            f"(< {min_dwell_trs} TR minimum). "
            f"Solutions: add sticky prior, reduce K, check preprocessing."
        )

    # --- 11d. State-confound correlation check ---
    confound_corrs = {}
    for k in range(K):
        state_indicator = (state_seq == k).astype(float)
        corrs = [np.corrcoef(state_indicator, confounds[:, j])[0, 1]
                 for j in range(confounds.shape[1])]
        max_corr = np.max(np.abs(corrs))
        confound_corrs[k] = max_corr
        if max_corr > motion_corr_threshold:
            warnings.append(
                f"State {k}: max |r|={max_corr:.2f} with confound regressors. "
                f"Consider re-running with that confound removed from the data."
            )

    report = {
        'fractional_occupancy': frac_occ,
        'mean_dwell_times_trs': mean_dwell,
        'mean_dwell_times_sec': {k: v * tr for k, v in mean_dwell.items()},
        'motion_correlations': motion_corrs,
        'confound_correlations': confound_corrs,
        'warnings': warnings,
    }

    if warnings:
        print(f"=== {len(warnings)} diagnostic warning(s) ===")
        for w in warnings:
            print(f"  WARNING: {w}")
    else:
        print("All diagnostics passed.")

    return report
```

**Quick interpretation guide:**

| Warning | Likely cause | Fix |
|---------|-------------|-----|
| State correlates with FD | Motion artifact state | Tighter scrubbing; check if state disappears after removing high-motion subjects |
| Dominant state (>70%) | Model collapsed to trivial solution | Lower K; check for degenerate covariance; more restarts |
| Mean dwell < 2 TRs | Noise-driven rapid switching | Add sticky prior (`kappa`); or apply the post-hoc minimum dwell-time filter sketched below |
| High confound correlation | Confound leakage | Revisit confound strategy; ensure confound regression happened before SSM fitting |
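
The post-hoc minimum dwell-time filter referenced in the table, as a sketch.
Merging sub-threshold visits into the preceding state is one common convention;
alternatives include reassigning from the posterior probabilities:

```python
import numpy as np

def enforce_min_dwell(states, min_dwell=2):
    """Merge state visits shorter than min_dwell TRs into the preceding visit.

    Post-hoc smoothing only: prefer a sticky prior at fit time; use this to
    clean residual single-TR flickers before computing dwell-time statistics.
    """
    states = np.asarray(states).copy()
    T = len(states)
    t = 0
    while t < T:
        start = t
        while t < T and states[t] == states[start]:
            t += 1
        if (t - start) < min_dwell:
            if start > 0:
                states[start:t] = states[start - 1]  # absorb into previous visit
            elif t < T:
                states[start:t] = states[t]          # first visit: absorb into next
    return states
```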

---

## 12. JAX-Based HMM for GPU Acceleration {#jax-gpu}

JAX dispatches the same code to CPU or GPU automatically, but only for JAX-native
libraries: the Linderman-lab `ssm` package is built on autograd and numba and runs
on CPU only. For GPU-ready fitting use `dynamax` (section 13); this section covers
the device check, a minimal GPU-ready HMM fit, and when the GPU actually pays off.

```python
"""GPU-accelerated HMM fitting with JAX.

Install:
    pip install jax jaxlib   # CPU JAX
    # For GPU (CUDA 12):
    pip install "jax[cuda12]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
    pip install dynamax      # JAX-native SSMs (see section 13)

When GPU acceleration matters:
  - >50 subjects or >1000 TRs/subject
  - Model selection sweeps over K (embarrassingly parallel)
  - Heavy models (rSLDS/SNLDS) are the natural candidates, but note these
    live in ssm, which is CPU-only; there is no off-the-shelf GPU rSLDS in
    these libraries, so parallelize restarts and subjects across batch jobs.
All JAX code here also runs on CPU — GPU is a drop-in speedup.
"""
import jax
import jax.numpy as jnp
import jax.random as jr
import numpy as np
from dynamax.hidden_markov_model import GaussianHMM


def check_jax_device():
    """Report which device JAX is using."""
    backend = jax.default_backend()
    devices = jax.devices()
    print(f"JAX backend: {backend}")
    print(f"Available devices: {devices}")
    if backend == 'cpu':
        print("NOTE: Running on CPU. For GPU, install jax[cuda12] and ensure CUDA is available.")
    return backend


def fit_hmm_gpu(emissions, K, n_restarts=20, n_iters=200, seed=0):
    """Fit a Gaussian HMM under JAX (GPU if available) using dynamax.

    dynamax's EM runs under jax.jit and executes on whatever device JAX
    selected (see check_jax_device). Section 13 walks through dynamax in
    detail; this is the minimal GPU-ready fit.

    Parameters
    ----------
    emissions : array, shape (T, D)
        Concatenated timeseries. Converted to float32 below, since float64
        is much slower on most GPUs.
    K : int
        Number of states
    """
    check_jax_device()
    emissions = jnp.asarray(emissions, dtype=jnp.float32)
    model = GaussianHMM(num_states=K, emission_dim=emissions.shape[1])

    best_params, best_lls, best_ll = None, None, -jnp.inf
    for restart in range(n_restarts):
        key = jr.PRNGKey(seed + restart)
        init_method = "kmeans" if restart == 0 else "prior"
        params, props = model.initialize(key, method=init_method, emissions=emissions)
        params, lls = model.fit_em(params, props, emissions, num_iters=n_iters)
        if lls[-1] > best_ll:
            best_ll, best_params, best_lls = lls[-1], params, lls

    print(f"Best log-likelihood: {float(best_ll):.2f}")
    return best_params, best_lls, model


# --- DyNeMo (osl-dynamics): GPU-native deep generative model ---
#
# DyNeMo is a variational recurrent neural network that learns dynamic
# network modes from BOLD data. It effectively requires a GPU for
# practical runtimes.
#
# Install: pip install osl-dynamics (TensorFlow-based)
#
# Key use case: when HMM-MAR underfits and you need a more expressive model
# for naturalistic / resting-state data with complex temporal dynamics.
#
# from osl_dynamics.models.dynemo import Config, Model
# config = Config(
#     n_modes=K,
#     n_channels=n_rois,
#     sequence_length=200,
#     inference_n_units=64,
#     inference_normalization='layer',
#     model_n_units=64,
#     model_normalization='layer',
#     learn_means=False,
#     learn_covariances=True,
#     batch_size=16,
#     learning_rate=0.01,
#     n_epochs=50,
# )
# model = Model(config)
# model.compile()
# history = model.fit(data)
# alpha = model.get_alpha(data)  # (T, K) mixing coefficients
```

**GPU recommendation summary:**

| Scenario | Recommendation |
|----------|---------------|
| hmmlearn, standard Gaussian HMM | CPU only; hmmlearn has no GPU path |
| ssm HMM / SLDS / rSLDS (autograd + numba) | CPU only; parallelize restarts and subjects across jobs |
| dynamax HMM or LGSSM, <50 subjects | CPU is fine |
| dynamax HMM or LGSSM, >50 subjects or >1000 TRs | GPU recommended |
| dynamax custom SSM with jit | GPU recommended; jit alone is a large speedup even on CPU |
| osl-dynamics DyNeMo | GPU effectively required for practical training |
| Model selection sweeps (many K values) | Embarrassingly parallel; spread across GPUs or batch jobs |
1251
+ ---
1252
+
1253
+ ## 13. dynamax: Lego-Style Custom SSMs {#dynamax}
1254
+
1255
+ `dynamax` (probml / Murphy lab) is a **JAX-native** SSM library built around a modular,
1256
+ composable design. Instead of picking from a fixed menu of model types, you assemble your
1257
+ model from independent, swappable pieces — like Lego bricks:
1258
+
1259
+ ```
1260
+ Your model = InitialComponent + TransitionComponent + EmissionComponent
1261
+ ```
1262
+
1263
+ Each component can be swapped independently without touching the others. This makes it
1264
+ trivial to, e.g., test Gaussian vs. diagonal-Gaussian emissions with the same sticky
1265
+ transition model, or compare standard vs. input-driven transitions with identical emissions.
1266
+
1267
+ All inference (EM, Viterbi, forward-backward, Kalman filter/smoother) runs under JAX JIT,
1268
+ giving GPU acceleration and fast CPU execution automatically.
1269
+
1270
+ ```
1271
+ Install: pip install dynamax (requires jax jaxlib)
1272
+ Docs: https://probml.github.io/dynamax/
1273
+ ```
1274
+
1275
+ ### 13a. The Lego pieces available in dynamax
1276
+
1277
+ ```
1278
+ dynamax.hidden_markov_model
1279
+ ├── Transitions
1280
+ │ ├── StandardTransitions # unconstrained K×K matrix
1281
+ │ ├── StickyTransitions # adds κ self-transition bias (= sticky HMM)
1282
+ │ └── (subclass AbstractTransitions for custom)
1283
+
1284
+ ├── Emissions
1285
+ │ ├── GaussianEmissions # full-covariance Gaussian — standard choice for fMRI
1286
+ │ ├── DiagonalGaussianEmissions # diagonal covariance — for high-dimensional data
1287
+ │ ├── LowRankGaussianEmissions # low-rank + diagonal — balances FC and parameters
1288
+ │ ├── SphericalGaussianEmissions # isotropic — minimal parameters
1289
+ │ ├── GaussianMixtureEmissions # mixture-of-Gaussians per state (multi-modal)
1290
+ │ └── (subclass AbstractEmissions for custom HRF-aware emissions, AR emissions, etc.)
1291
+
1292
+ └── Initial
1293
+ └── StandardInitialDistribution # learnable π vector
1294
+
1295
+ dynamax.linear_gaussian_ssm
1296
+ ├── LinearGaussianSSM # standard LGSSM with Kalman filter/smoother
1297
+ └── LinearGaussianConjugateSSM # conjugate priors — EM with closed-form M-step
1298
+ ```
1299
+
1300
+ ### 13b. Drop-in Gaussian HMM (baseline)
1301
+
1302
+ ```python
1303
+ """Standard Gaussian HMM with dynamax — equivalent to hmmlearn but JAX-native."""
1304
+ import jax.numpy as jnp
1305
+ import jax.random as jr
1306
+ from dynamax.hidden_markov_model import GaussianHMM
1307
+
1308
+
1309
+ def fit_gaussian_hmm_dynamax(data_list, K, n_restarts=20, n_iters=100, seed=0):
1310
+ """Fit Gaussian HMM using dynamax EM.
1311
+
1312
+ Parameters
1313
+ ----------
1314
+ data_list : list of np.ndarray, each (T, D)
1315
+ Multiple runs — dynamax handles variable lengths via a list.
1316
+ K : int
1317
+ Number of states
1318
+ n_restarts : int
1319
+ Number of random restarts (take best final log-likelihood)
1320
+ n_iters : int
1321
+ Max EM iterations per restart
1322
+
1323
+ Returns
1324
+ -------
1325
+ best_params : dynamax parameter pytree
1326
+ best_lls : array of log-likelihoods per EM iteration
1327
+ model : GaussianHMM instance (for inference calls)
1328
+ """
1329
+ import numpy as np
1330
+ D = data_list[0].shape[1]
1331
+
1332
+ # dynamax expects a single 2D array for single-sequence fitting,
1333
+ # or use jax.vmap / a loop for multiple sequences.
1334
+ # For multi-run fMRI, fit on concatenated data (pass lengths separately for scoring).
1335
+ emissions = jnp.array(np.vstack(data_list))
1336
+
1337
+ model = GaussianHMM(num_states=K, emission_dim=D)
1338
+
1339
+ best_params = None
1340
+ best_ll = -jnp.inf
1341
+
1342
+ for restart in range(n_restarts):
1343
+ key = jr.PRNGKey(seed + restart)
1344
+
1345
+ # K-means initialization on first restart, random otherwise
1346
+ init_method = "kmeans" if restart == 0 else "prior"
1347
+ params, props = model.initialize(key, method=init_method, emissions=emissions)
1348
+
1349
+ params, lls = model.fit_em(params, props, emissions, num_iters=n_iters)
1350
+
1351
+ if lls[-1] > best_ll:
1352
+ best_ll = lls[-1]
1353
+ best_params = params
1354
+ best_lls = lls
1355
+
1356
+ print(f"Best final log-likelihood: {best_ll:.2f}")
1357
+ return best_params, best_lls, model
1358
+
1359
+
1360
+ def decode_dynamax(model, params, emissions_jnp):
1361
+ """Viterbi decoding and posterior smoothing with dynamax."""
1362
+ # Most likely state sequence (Viterbi)
1363
+ most_likely_states = model.posterior_mode(params, emissions_jnp)
1364
+
1365
+ # Posterior state probabilities (forward-backward smoother)
1366
+ posterior = model.smoother(params, emissions_jnp)
1367
+ smoothed_probs = posterior.smoothed_probs # (T, K)
1368
+
1369
+ return most_likely_states, smoothed_probs
1370
+ ```
1371
+

### 13c. Swapping the emission component (the Lego idea)

```python
"""Replace full Gaussian with DiagonalGaussian or LowRankGaussian — same training code."""
import jax.numpy as jnp
import jax.random as jr
from dynamax.hidden_markov_model import (
    GaussianHMM,
    DiagonalGaussianHMM,
    LowRankGaussianHMM,
)


def compare_emission_types(emissions_jnp, K, rank=5, seed=0):
    """Fit three HMMs that differ only in emission covariance structure.

    This is the Lego principle: swap one brick (emission), keep everything else.
    Use BIC to decide which covariance structure the data supports.

    Parameters
    ----------
    emissions_jnp : jax array, shape (T, D)
    K : int
        Number of states
    rank : int
        Rank for LowRankGaussianHMM (ignored for other types)
    """
    T, D = emissions_jnp.shape
    key = jr.PRNGKey(seed)

    results = {}

    emission_models = {
        # Full covariance: captures all pairwise FC per state
        # Covariance parameters per state: D*(D+1)/2 — expensive for large D
        'full': GaussianHMM(num_states=K, emission_dim=D),

        # Diagonal covariance: ignores FC, only models per-region variance
        # Covariance parameters per state: D — use when D is large or data is limited
        'diagonal': DiagonalGaussianHMM(num_states=K, emission_dim=D),

        # Low-rank + diagonal: rank-r approximation to FC + independent noise
        # Covariance parameters per state: D*rank + D — good middle ground for parcellated data
        'low_rank': LowRankGaussianHMM(num_states=K, emission_dim=D, emission_rank=rank),
    }

    for i, (name, model) in enumerate(emission_models.items()):
        # Derive a per-model key from the loop index; hash(name) is not
        # stable across Python runs (PYTHONHASHSEED), so avoid it for seeding
        k1 = jr.fold_in(key, i)
        params, props = model.initialize(k1, method="kmeans", emissions=emissions_jnp)
        params, lls = model.fit_em(params, props, emissions_jnp, num_iters=100)

        final_ll = float(lls[-1])
        # Count parameters for BIC
        n_transition = K * (K - 1)
        n_means = K * D  # state means (same for all three variants)
        n_cov = {
            'full': K * D * (D + 1) // 2,
            'diagonal': K * D,
            'low_rank': K * (D * rank + D),
        }[name]
        n_params = n_transition + n_means + n_cov + (K - 1)  # + initial distribution
        bic = -2 * final_ll + n_params * jnp.log(T)

        results[name] = {'ll': final_ll, 'bic': float(bic), 'params': params, 'model': model}
        print(f"{name:10s}: LL={final_ll:.1f}  BIC={float(bic):.1f}  n_params={n_params}")

    best = min(results, key=lambda x: results[x]['bic'])
    print(f"\nBest emission type by BIC: {best}")
    return results


# Swap the TRANSITION component instead:
# Standard → Sticky, keeping the same Gaussian emissions.
# In dynamax, stickiness is a prior hyperparameter on the standard
# transitions, set in the constructor:
#
#   model = GaussianHMM(num_states=K, emission_dim=D,
#                       transition_matrix_stickiness=50.0)  # κ: higher = stickier
#
# The stickiness adds κ to the Dirichlet concentration of the diagonal, so
# EM's M-step favours self-transitions — the same idea as ssm's sticky HMM.
```

### 13d. Custom emission model — HRF-aware Gaussian emissions

```python
"""Build a custom emission class that absorbs HRF smoothing in its mean structure.

The Lego design lets you subclass the Gaussian emission component and plug it
into a dynamax HMM without touching the transition or initial components.

This is a sketch: the module paths and method signatures follow dynamax's
source layout and may shift between versions; fill in the HRF convolution
matrix H and the likelihood for your data.
"""
import jax.numpy as jnp
from dynamax.hidden_markov_model.models.abstractions import HMM
from dynamax.hidden_markov_model.models.gaussian_hmm import GaussianHMMEmissions


class HRFGaussianEmissions(GaussianHMMEmissions):
    """Gaussian emissions whose mean is HRF-convolved neural activity.

    Each state has a neural mean μ_k. The observed emission mean follows
    H @ M, where H is the T×T lower-triangular Toeplitz HRF convolution
    matrix and M stacks the per-TR state means. This bakes the HRF into
    the emission model rather than preprocessing.

    Parameters
    ----------
    hrf_matrix : jax array, shape (T, T)
        Pre-computed HRF convolution matrix (lower-triangular Toeplitz).
        Build with: scipy.linalg.toeplitz(np.r_[hrf, np.zeros(T - len(hrf))],
        np.zeros(T)).
    """

    def __init__(self, num_states, emission_dim, hrf_matrix, **kwargs):
        super().__init__(num_states, emission_dim, **kwargs)
        self.hrf_matrix = hrf_matrix  # (T, T) fixed — not learned

    def log_likelihoods(self, params, emissions, inputs=None):
        # Sketch: because H mixes time points, the HRF-smoothed mean at time t
        # depends on the state sequence around t, not on z_t alone, so the
        # likelihood no longer factorizes over t. A faithful implementation
        # needs an approximation, e.g. a truncated HRF window folded into an
        # augmented discrete state, or a mean-field smoothing of the means.
        # Fill in your chosen approximation here.
        raise NotImplementedError("Implement the HRF-aware Gaussian log-likelihood.")


# Usage pattern (component argument names are illustrative — match your dynamax version):
# hrf_matrix = build_hrf_toeplitz(hrf_kernel, T=n_trs)   # your helper
# emission_component = HRFGaussianEmissions(K, D, hrf_matrix)
# model = HMM(num_states=K, initial_component=...,
#             transition_component=..., emission_component=emission_component)
```

### 13e. Linear Gaussian SSM (Kalman filter) for smooth latent dynamics

```python
"""LGSSM with dynamax — useful when you want continuous latent dynamics without
discrete switching.

Think of it as SLDS with K=1 (single regime). Use it as a baseline before fitting SLDS/rSLDS.
"""
import jax.numpy as jnp
import jax.random as jr
from dynamax.linear_gaussian_ssm import LinearGaussianSSM


def fit_lgssm_fmri(bold_data, state_dim=10, n_iters=100, seed=0):
    """Fit a Linear Gaussian SSM (Kalman filter model) to BOLD data.

    Learns a low-dimensional latent trajectory that best explains the BOLD signal.
    No discrete switching — use this as a baseline or for smooth dynamics analyses.

    Parameters
    ----------
    bold_data : array, shape (T, D)
        Preprocessed, parcellated BOLD timeseries
    state_dim : int
        Latent state dimension (typically 5–20 for parcellated fMRI)

    Returns
    -------
    params : fitted LGSSM parameters
    posterior : smoothed posterior (means, covariances, marginal_loglik)
    """
    T, D = bold_data.shape
    emissions = jnp.array(bold_data)

    model = LinearGaussianSSM(state_dim=state_dim, emission_dim=D)
    params, props = model.initialize(jr.PRNGKey(seed))
    params, lls = model.fit_em(params, props, emissions, num_iters=n_iters)

    # Smooth the latent trajectory
    posterior = model.smoother(params, emissions)
    latent_means = posterior.smoothed_means       # (T, state_dim)
    latent_covs = posterior.smoothed_covariances  # (T, state_dim, state_dim)
    marginal_ll = posterior.marginal_loglik

    print(f"Final marginal log-likelihood: {float(marginal_ll):.2f}")
    print(f"Emission matrix shape: {params.emissions.weights.shape}")  # (D, state_dim)

    return params, posterior
```

### 13f. When to use dynamax vs. ssm vs. hmmlearn

| Situation | Best choice |
|-----------|------------|
| Quickest path to a working Gaussian HMM | `hmmlearn` |
| Need rSLDS, SNLDS, or input-driven SLDS | `ssm` |
| Want to experiment with different emission/transition types rapidly | **`dynamax`** |
| Want to build a custom emission (e.g., HRF-aware, AR, mixture) | **`dynamax`** (subclass) |
| Want Kalman filter / LGSSM as a baseline | **`dynamax`** |
| Need JAX JIT + GPU for large-scale inference | **`dynamax`** (ssm is autograd/numba, CPU-only) |
| Group-level HMM with neuroimaging-specific features | `glhmm` |
| Deep generative model (DyNeMo) | `osl-dynamics` |