py-neuromodulation 0.0.4__py3-none-any.whl → 0.0.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- py_neuromodulation/ConnectivityDecoding/_get_grid_hull.m +34 -34
- py_neuromodulation/ConnectivityDecoding/_get_grid_whole_brain.py +95 -106
- py_neuromodulation/ConnectivityDecoding/_helper_write_connectome.py +107 -119
- py_neuromodulation/__init__.py +80 -13
- py_neuromodulation/{nm_RMAP.py → analysis/RMAP.py} +496 -531
- py_neuromodulation/analysis/__init__.py +4 -0
- py_neuromodulation/{nm_decode.py → analysis/decode.py} +918 -992
- py_neuromodulation/{nm_analysis.py → analysis/feature_reader.py} +994 -1074
- py_neuromodulation/{nm_plots.py → analysis/plots.py} +627 -612
- py_neuromodulation/{nm_stats.py → analysis/stats.py} +458 -480
- py_neuromodulation/data/README +6 -6
- py_neuromodulation/data/dataset_description.json +8 -8
- py_neuromodulation/data/participants.json +32 -32
- py_neuromodulation/data/participants.tsv +2 -2
- py_neuromodulation/data/sub-testsub/ses-EphysMedOff/ieeg/sub-testsub_ses-EphysMedOff_space-mni_coordsystem.json +5 -5
- py_neuromodulation/data/sub-testsub/ses-EphysMedOff/ieeg/sub-testsub_ses-EphysMedOff_space-mni_electrodes.tsv +11 -11
- py_neuromodulation/data/sub-testsub/ses-EphysMedOff/ieeg/sub-testsub_ses-EphysMedOff_task-gripforce_run-0_channels.tsv +11 -11
- py_neuromodulation/data/sub-testsub/ses-EphysMedOff/ieeg/sub-testsub_ses-EphysMedOff_task-gripforce_run-0_ieeg.json +18 -18
- py_neuromodulation/data/sub-testsub/ses-EphysMedOff/ieeg/sub-testsub_ses-EphysMedOff_task-gripforce_run-0_ieeg.vhdr +35 -35
- py_neuromodulation/data/sub-testsub/ses-EphysMedOff/ieeg/sub-testsub_ses-EphysMedOff_task-gripforce_run-0_ieeg.vmrk +13 -13
- py_neuromodulation/data/sub-testsub/ses-EphysMedOff/sub-testsub_ses-EphysMedOff_scans.tsv +2 -2
- py_neuromodulation/default_settings.yaml +241 -0
- py_neuromodulation/features/__init__.py +31 -0
- py_neuromodulation/features/bandpower.py +165 -0
- py_neuromodulation/features/bispectra.py +157 -0
- py_neuromodulation/features/bursts.py +297 -0
- py_neuromodulation/features/coherence.py +255 -0
- py_neuromodulation/features/feature_processor.py +121 -0
- py_neuromodulation/features/fooof.py +142 -0
- py_neuromodulation/features/hjorth_raw.py +57 -0
- py_neuromodulation/features/linelength.py +21 -0
- py_neuromodulation/features/mne_connectivity.py +148 -0
- py_neuromodulation/features/nolds.py +94 -0
- py_neuromodulation/features/oscillatory.py +249 -0
- py_neuromodulation/features/sharpwaves.py +432 -0
- py_neuromodulation/filter/__init__.py +3 -0
- py_neuromodulation/filter/kalman_filter.py +67 -0
- py_neuromodulation/filter/kalman_filter_external.py +1890 -0
- py_neuromodulation/filter/mne_filter.py +128 -0
- py_neuromodulation/filter/notch_filter.py +93 -0
- py_neuromodulation/grid_cortex.tsv +40 -40
- py_neuromodulation/liblsl/libpugixml.so.1.12 +0 -0
- py_neuromodulation/liblsl/linux/bionic_amd64/liblsl.1.16.2.so +0 -0
- py_neuromodulation/liblsl/linux/bookworm_amd64/liblsl.1.16.2.so +0 -0
- py_neuromodulation/liblsl/linux/focal_amd46/liblsl.1.16.2.so +0 -0
- py_neuromodulation/liblsl/linux/jammy_amd64/liblsl.1.16.2.so +0 -0
- py_neuromodulation/liblsl/linux/jammy_x86/liblsl.1.16.2.so +0 -0
- py_neuromodulation/liblsl/linux/noble_amd64/liblsl.1.16.2.so +0 -0
- py_neuromodulation/liblsl/macos/amd64/liblsl.1.16.2.dylib +0 -0
- py_neuromodulation/liblsl/macos/arm64/liblsl.1.16.0.dylib +0 -0
- py_neuromodulation/liblsl/windows/amd64/liblsl.1.16.2.dll +0 -0
- py_neuromodulation/liblsl/windows/x86/liblsl.1.16.2.dll +0 -0
- py_neuromodulation/processing/__init__.py +10 -0
- py_neuromodulation/{nm_artifacts.py → processing/artifacts.py} +29 -25
- py_neuromodulation/processing/data_preprocessor.py +77 -0
- py_neuromodulation/processing/filter_preprocessing.py +78 -0
- py_neuromodulation/processing/normalization.py +175 -0
- py_neuromodulation/{nm_projection.py → processing/projection.py} +370 -394
- py_neuromodulation/{nm_rereference.py → processing/rereference.py} +97 -95
- py_neuromodulation/{nm_resample.py → processing/resample.py} +56 -50
- py_neuromodulation/stream/__init__.py +3 -0
- py_neuromodulation/stream/data_processor.py +325 -0
- py_neuromodulation/stream/generator.py +53 -0
- py_neuromodulation/stream/mnelsl_player.py +94 -0
- py_neuromodulation/stream/mnelsl_stream.py +120 -0
- py_neuromodulation/stream/settings.py +292 -0
- py_neuromodulation/stream/stream.py +427 -0
- py_neuromodulation/utils/__init__.py +2 -0
- py_neuromodulation/{nm_define_nmchannels.py → utils/channels.py} +305 -302
- py_neuromodulation/utils/database.py +149 -0
- py_neuromodulation/utils/io.py +378 -0
- py_neuromodulation/utils/keyboard.py +52 -0
- py_neuromodulation/utils/logging.py +66 -0
- py_neuromodulation/utils/types.py +251 -0
- {py_neuromodulation-0.0.4.dist-info → py_neuromodulation-0.0.6.dist-info}/METADATA +28 -33
- py_neuromodulation-0.0.6.dist-info/RECORD +89 -0
- {py_neuromodulation-0.0.4.dist-info → py_neuromodulation-0.0.6.dist-info}/WHEEL +1 -1
- {py_neuromodulation-0.0.4.dist-info → py_neuromodulation-0.0.6.dist-info}/licenses/LICENSE +21 -21
- py_neuromodulation/FieldTrip.py +0 -589
- py_neuromodulation/_write_example_dataset_helper.py +0 -65
- py_neuromodulation/nm_EpochStream.py +0 -92
- py_neuromodulation/nm_IO.py +0 -417
- py_neuromodulation/nm_across_patient_decoding.py +0 -927
- py_neuromodulation/nm_bispectra.py +0 -168
- py_neuromodulation/nm_bursts.py +0 -198
- py_neuromodulation/nm_coherence.py +0 -205
- py_neuromodulation/nm_cohortwrapper.py +0 -435
- py_neuromodulation/nm_eval_timing.py +0 -239
- py_neuromodulation/nm_features.py +0 -116
- py_neuromodulation/nm_features_abc.py +0 -39
- py_neuromodulation/nm_filter.py +0 -219
- py_neuromodulation/nm_filter_preprocessing.py +0 -91
- py_neuromodulation/nm_fooof.py +0 -159
- py_neuromodulation/nm_generator.py +0 -37
- py_neuromodulation/nm_hjorth_raw.py +0 -73
- py_neuromodulation/nm_kalmanfilter.py +0 -58
- py_neuromodulation/nm_linelength.py +0 -33
- py_neuromodulation/nm_mne_connectivity.py +0 -112
- py_neuromodulation/nm_nolds.py +0 -93
- py_neuromodulation/nm_normalization.py +0 -214
- py_neuromodulation/nm_oscillatory.py +0 -448
- py_neuromodulation/nm_run_analysis.py +0 -435
- py_neuromodulation/nm_settings.json +0 -338
- py_neuromodulation/nm_settings.py +0 -68
- py_neuromodulation/nm_sharpwaves.py +0 -401
- py_neuromodulation/nm_stream_abc.py +0 -218
- py_neuromodulation/nm_stream_offline.py +0 -359
- py_neuromodulation/utils/_logging.py +0 -24
- py_neuromodulation-0.0.4.dist-info/RECORD +0 -72
|
@@ -1,992 +1,918 @@
|
|
|
1
|
-
from sklearn import
|
|
2
|
-
|
|
3
|
-
|
|
4
|
-
|
|
5
|
-
|
|
6
|
-
|
|
7
|
-
|
|
8
|
-
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
from
|
|
13
|
-
|
|
14
|
-
from
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
self
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
self.
|
|
57
|
-
self.
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
self.
|
|
178
|
-
self.
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
|
|
193
|
-
|
|
194
|
-
|
|
195
|
-
|
|
196
|
-
|
|
197
|
-
|
|
198
|
-
|
|
199
|
-
|
|
200
|
-
|
|
201
|
-
|
|
202
|
-
|
|
203
|
-
|
|
204
|
-
|
|
205
|
-
|
|
206
|
-
|
|
207
|
-
|
|
208
|
-
|
|
209
|
-
|
|
210
|
-
|
|
211
|
-
|
|
212
|
-
|
|
213
|
-
|
|
214
|
-
|
|
215
|
-
|
|
216
|
-
|
|
217
|
-
|
|
218
|
-
|
|
219
|
-
|
|
220
|
-
|
|
221
|
-
|
|
222
|
-
|
|
223
|
-
|
|
224
|
-
|
|
225
|
-
|
|
226
|
-
|
|
227
|
-
|
|
228
|
-
|
|
229
|
-
|
|
230
|
-
|
|
231
|
-
|
|
232
|
-
|
|
233
|
-
|
|
234
|
-
|
|
235
|
-
|
|
236
|
-
|
|
237
|
-
|
|
238
|
-
|
|
239
|
-
|
|
240
|
-
|
|
241
|
-
|
|
242
|
-
|
|
243
|
-
|
|
244
|
-
|
|
245
|
-
|
|
246
|
-
|
|
247
|
-
|
|
248
|
-
|
|
249
|
-
|
|
250
|
-
|
|
251
|
-
|
|
252
|
-
|
|
253
|
-
|
|
254
|
-
|
|
255
|
-
|
|
256
|
-
|
|
257
|
-
|
|
258
|
-
|
|
259
|
-
|
|
260
|
-
|
|
261
|
-
|
|
262
|
-
|
|
263
|
-
|
|
264
|
-
|
|
265
|
-
|
|
266
|
-
|
|
267
|
-
|
|
268
|
-
|
|
269
|
-
|
|
270
|
-
|
|
271
|
-
|
|
272
|
-
|
|
273
|
-
|
|
274
|
-
|
|
275
|
-
|
|
276
|
-
|
|
277
|
-
|
|
278
|
-
|
|
279
|
-
|
|
280
|
-
|
|
281
|
-
|
|
282
|
-
|
|
283
|
-
|
|
284
|
-
|
|
285
|
-
|
|
286
|
-
|
|
287
|
-
|
|
288
|
-
|
|
289
|
-
|
|
290
|
-
|
|
291
|
-
|
|
292
|
-
|
|
293
|
-
|
|
294
|
-
|
|
295
|
-
|
|
296
|
-
|
|
297
|
-
|
|
298
|
-
|
|
299
|
-
|
|
300
|
-
|
|
301
|
-
|
|
302
|
-
|
|
303
|
-
|
|
304
|
-
|
|
305
|
-
|
|
306
|
-
|
|
307
|
-
|
|
308
|
-
|
|
309
|
-
|
|
310
|
-
|
|
311
|
-
|
|
312
|
-
|
|
313
|
-
|
|
314
|
-
|
|
315
|
-
|
|
316
|
-
|
|
317
|
-
|
|
318
|
-
|
|
319
|
-
|
|
320
|
-
|
|
321
|
-
|
|
322
|
-
|
|
323
|
-
|
|
324
|
-
|
|
325
|
-
|
|
326
|
-
|
|
327
|
-
|
|
328
|
-
|
|
329
|
-
|
|
330
|
-
|
|
331
|
-
|
|
332
|
-
|
|
333
|
-
|
|
334
|
-
|
|
335
|
-
|
|
336
|
-
|
|
337
|
-
|
|
338
|
-
|
|
339
|
-
|
|
340
|
-
|
|
341
|
-
|
|
342
|
-
|
|
343
|
-
|
|
344
|
-
|
|
345
|
-
|
|
346
|
-
|
|
347
|
-
|
|
348
|
-
|
|
349
|
-
|
|
350
|
-
|
|
351
|
-
|
|
352
|
-
|
|
353
|
-
|
|
354
|
-
|
|
355
|
-
|
|
356
|
-
|
|
357
|
-
|
|
358
|
-
|
|
359
|
-
|
|
360
|
-
|
|
361
|
-
|
|
362
|
-
|
|
363
|
-
|
|
364
|
-
|
|
365
|
-
|
|
366
|
-
|
|
367
|
-
]
|
|
368
|
-
|
|
369
|
-
|
|
370
|
-
|
|
371
|
-
|
|
372
|
-
|
|
373
|
-
|
|
374
|
-
|
|
375
|
-
|
|
376
|
-
|
|
377
|
-
|
|
378
|
-
|
|
379
|
-
|
|
380
|
-
|
|
381
|
-
|
|
382
|
-
|
|
383
|
-
|
|
384
|
-
|
|
385
|
-
|
|
386
|
-
|
|
387
|
-
|
|
388
|
-
|
|
389
|
-
|
|
390
|
-
|
|
391
|
-
|
|
392
|
-
|
|
393
|
-
|
|
394
|
-
|
|
395
|
-
|
|
396
|
-
|
|
397
|
-
|
|
398
|
-
|
|
399
|
-
|
|
400
|
-
|
|
401
|
-
|
|
402
|
-
|
|
403
|
-
|
|
404
|
-
|
|
405
|
-
|
|
406
|
-
|
|
407
|
-
|
|
408
|
-
|
|
409
|
-
|
|
410
|
-
|
|
411
|
-
|
|
412
|
-
|
|
413
|
-
|
|
414
|
-
|
|
415
|
-
|
|
416
|
-
|
|
417
|
-
|
|
418
|
-
|
|
419
|
-
|
|
420
|
-
|
|
421
|
-
|
|
422
|
-
|
|
423
|
-
|
|
424
|
-
|
|
425
|
-
|
|
426
|
-
|
|
427
|
-
|
|
428
|
-
|
|
429
|
-
|
|
430
|
-
|
|
431
|
-
|
|
432
|
-
|
|
433
|
-
|
|
434
|
-
|
|
435
|
-
|
|
436
|
-
|
|
437
|
-
|
|
438
|
-
|
|
439
|
-
|
|
440
|
-
|
|
441
|
-
|
|
442
|
-
|
|
443
|
-
|
|
444
|
-
|
|
445
|
-
|
|
446
|
-
|
|
447
|
-
|
|
448
|
-
|
|
449
|
-
|
|
450
|
-
|
|
451
|
-
|
|
452
|
-
|
|
453
|
-
|
|
454
|
-
|
|
455
|
-
|
|
456
|
-
|
|
457
|
-
)
|
|
458
|
-
|
|
459
|
-
|
|
460
|
-
|
|
461
|
-
|
|
462
|
-
|
|
463
|
-
|
|
464
|
-
|
|
465
|
-
|
|
466
|
-
|
|
467
|
-
|
|
468
|
-
|
|
469
|
-
|
|
470
|
-
|
|
471
|
-
|
|
472
|
-
|
|
473
|
-
|
|
474
|
-
|
|
475
|
-
|
|
476
|
-
|
|
477
|
-
|
|
478
|
-
|
|
479
|
-
|
|
480
|
-
|
|
481
|
-
|
|
482
|
-
|
|
483
|
-
|
|
484
|
-
|
|
485
|
-
|
|
486
|
-
|
|
487
|
-
|
|
488
|
-
|
|
489
|
-
|
|
490
|
-
|
|
491
|
-
|
|
492
|
-
|
|
493
|
-
|
|
494
|
-
|
|
495
|
-
|
|
496
|
-
|
|
497
|
-
|
|
498
|
-
|
|
499
|
-
|
|
500
|
-
|
|
501
|
-
|
|
502
|
-
|
|
503
|
-
|
|
504
|
-
|
|
505
|
-
|
|
506
|
-
|
|
507
|
-
|
|
508
|
-
|
|
509
|
-
|
|
510
|
-
|
|
511
|
-
|
|
512
|
-
|
|
513
|
-
|
|
514
|
-
|
|
515
|
-
|
|
516
|
-
|
|
517
|
-
|
|
518
|
-
|
|
519
|
-
|
|
520
|
-
|
|
521
|
-
|
|
522
|
-
|
|
523
|
-
|
|
524
|
-
|
|
525
|
-
|
|
526
|
-
|
|
527
|
-
|
|
528
|
-
|
|
529
|
-
|
|
530
|
-
|
|
531
|
-
|
|
532
|
-
|
|
533
|
-
|
|
534
|
-
|
|
535
|
-
|
|
536
|
-
|
|
537
|
-
|
|
538
|
-
|
|
539
|
-
|
|
540
|
-
|
|
541
|
-
|
|
542
|
-
|
|
543
|
-
|
|
544
|
-
|
|
545
|
-
|
|
546
|
-
|
|
547
|
-
|
|
548
|
-
|
|
549
|
-
|
|
550
|
-
|
|
551
|
-
#
|
|
552
|
-
|
|
553
|
-
|
|
554
|
-
|
|
555
|
-
|
|
556
|
-
|
|
557
|
-
|
|
558
|
-
|
|
559
|
-
|
|
560
|
-
|
|
561
|
-
|
|
562
|
-
|
|
563
|
-
|
|
564
|
-
|
|
565
|
-
|
|
566
|
-
|
|
567
|
-
|
|
568
|
-
|
|
569
|
-
|
|
570
|
-
|
|
571
|
-
|
|
572
|
-
|
|
573
|
-
|
|
574
|
-
|
|
575
|
-
|
|
576
|
-
|
|
577
|
-
|
|
578
|
-
|
|
579
|
-
|
|
580
|
-
cv_res
|
|
581
|
-
|
|
582
|
-
|
|
583
|
-
|
|
584
|
-
|
|
585
|
-
|
|
586
|
-
|
|
587
|
-
|
|
588
|
-
|
|
589
|
-
|
|
590
|
-
|
|
591
|
-
|
|
592
|
-
|
|
593
|
-
|
|
594
|
-
|
|
595
|
-
|
|
596
|
-
|
|
597
|
-
|
|
598
|
-
|
|
599
|
-
|
|
600
|
-
|
|
601
|
-
|
|
602
|
-
|
|
603
|
-
|
|
604
|
-
|
|
605
|
-
cv_res
|
|
606
|
-
|
|
607
|
-
|
|
608
|
-
|
|
609
|
-
|
|
610
|
-
|
|
611
|
-
|
|
612
|
-
|
|
613
|
-
|
|
614
|
-
|
|
615
|
-
|
|
616
|
-
|
|
617
|
-
|
|
618
|
-
|
|
619
|
-
|
|
620
|
-
|
|
621
|
-
|
|
622
|
-
|
|
623
|
-
|
|
624
|
-
|
|
625
|
-
|
|
626
|
-
|
|
627
|
-
|
|
628
|
-
|
|
629
|
-
|
|
630
|
-
|
|
631
|
-
|
|
632
|
-
|
|
633
|
-
|
|
634
|
-
|
|
635
|
-
|
|
636
|
-
|
|
637
|
-
|
|
638
|
-
|
|
639
|
-
|
|
640
|
-
|
|
641
|
-
|
|
642
|
-
|
|
643
|
-
|
|
644
|
-
|
|
645
|
-
|
|
646
|
-
|
|
647
|
-
|
|
648
|
-
|
|
649
|
-
|
|
650
|
-
|
|
651
|
-
|
|
652
|
-
|
|
653
|
-
|
|
654
|
-
|
|
655
|
-
|
|
656
|
-
|
|
657
|
-
|
|
658
|
-
|
|
659
|
-
|
|
660
|
-
|
|
661
|
-
|
|
662
|
-
|
|
663
|
-
|
|
664
|
-
|
|
665
|
-
|
|
666
|
-
|
|
667
|
-
|
|
668
|
-
|
|
669
|
-
|
|
670
|
-
|
|
671
|
-
|
|
672
|
-
|
|
673
|
-
|
|
674
|
-
|
|
675
|
-
)
|
|
676
|
-
|
|
677
|
-
|
|
678
|
-
if self.STACK_FEATURES_N_SAMPLES
|
|
679
|
-
if
|
|
680
|
-
X_train, y_train
|
|
681
|
-
X_train,
|
|
682
|
-
|
|
683
|
-
|
|
684
|
-
|
|
685
|
-
|
|
686
|
-
)
|
|
687
|
-
|
|
688
|
-
|
|
689
|
-
|
|
690
|
-
|
|
691
|
-
|
|
692
|
-
|
|
693
|
-
|
|
694
|
-
|
|
695
|
-
|
|
696
|
-
|
|
697
|
-
|
|
698
|
-
|
|
699
|
-
|
|
700
|
-
|
|
701
|
-
|
|
702
|
-
|
|
703
|
-
|
|
704
|
-
|
|
705
|
-
|
|
706
|
-
|
|
707
|
-
|
|
708
|
-
|
|
709
|
-
|
|
710
|
-
|
|
711
|
-
|
|
712
|
-
|
|
713
|
-
|
|
714
|
-
|
|
715
|
-
|
|
716
|
-
|
|
717
|
-
|
|
718
|
-
|
|
719
|
-
|
|
720
|
-
|
|
721
|
-
|
|
722
|
-
|
|
723
|
-
|
|
724
|
-
|
|
725
|
-
|
|
726
|
-
|
|
727
|
-
|
|
728
|
-
|
|
729
|
-
|
|
730
|
-
|
|
731
|
-
|
|
732
|
-
|
|
733
|
-
|
|
734
|
-
|
|
735
|
-
|
|
736
|
-
|
|
737
|
-
|
|
738
|
-
|
|
739
|
-
|
|
740
|
-
|
|
741
|
-
|
|
742
|
-
|
|
743
|
-
|
|
744
|
-
|
|
745
|
-
|
|
746
|
-
|
|
747
|
-
|
|
748
|
-
|
|
749
|
-
|
|
750
|
-
|
|
751
|
-
|
|
752
|
-
|
|
753
|
-
|
|
754
|
-
|
|
755
|
-
|
|
756
|
-
cv_res = self.
|
|
757
|
-
|
|
758
|
-
|
|
759
|
-
|
|
760
|
-
|
|
761
|
-
|
|
762
|
-
|
|
763
|
-
|
|
764
|
-
|
|
765
|
-
|
|
766
|
-
|
|
767
|
-
|
|
768
|
-
|
|
769
|
-
|
|
770
|
-
|
|
771
|
-
|
|
772
|
-
|
|
773
|
-
|
|
774
|
-
|
|
775
|
-
|
|
776
|
-
|
|
777
|
-
|
|
778
|
-
|
|
779
|
-
|
|
780
|
-
|
|
781
|
-
|
|
782
|
-
|
|
783
|
-
|
|
784
|
-
|
|
785
|
-
|
|
786
|
-
|
|
787
|
-
|
|
788
|
-
|
|
789
|
-
|
|
790
|
-
|
|
791
|
-
|
|
792
|
-
|
|
793
|
-
|
|
794
|
-
|
|
795
|
-
|
|
796
|
-
|
|
797
|
-
|
|
798
|
-
|
|
799
|
-
|
|
800
|
-
|
|
801
|
-
|
|
802
|
-
|
|
803
|
-
|
|
804
|
-
|
|
805
|
-
|
|
806
|
-
|
|
807
|
-
|
|
808
|
-
|
|
809
|
-
|
|
810
|
-
|
|
811
|
-
|
|
812
|
-
|
|
813
|
-
|
|
814
|
-
|
|
815
|
-
|
|
816
|
-
|
|
817
|
-
|
|
818
|
-
|
|
819
|
-
|
|
820
|
-
|
|
821
|
-
|
|
822
|
-
|
|
823
|
-
|
|
824
|
-
|
|
825
|
-
|
|
826
|
-
|
|
827
|
-
|
|
828
|
-
|
|
829
|
-
|
|
830
|
-
|
|
831
|
-
|
|
832
|
-
|
|
833
|
-
|
|
834
|
-
|
|
835
|
-
|
|
836
|
-
|
|
837
|
-
|
|
838
|
-
|
|
839
|
-
|
|
840
|
-
|
|
841
|
-
|
|
842
|
-
|
|
843
|
-
|
|
844
|
-
|
|
845
|
-
|
|
846
|
-
|
|
847
|
-
|
|
848
|
-
|
|
849
|
-
|
|
850
|
-
|
|
851
|
-
|
|
852
|
-
|
|
853
|
-
|
|
854
|
-
|
|
855
|
-
|
|
856
|
-
|
|
857
|
-
|
|
858
|
-
|
|
859
|
-
|
|
860
|
-
|
|
861
|
-
|
|
862
|
-
|
|
863
|
-
|
|
864
|
-
|
|
865
|
-
|
|
866
|
-
|
|
867
|
-
|
|
868
|
-
|
|
869
|
-
|
|
870
|
-
|
|
871
|
-
|
|
872
|
-
|
|
873
|
-
|
|
874
|
-
|
|
875
|
-
|
|
876
|
-
|
|
877
|
-
|
|
878
|
-
|
|
879
|
-
|
|
880
|
-
|
|
881
|
-
|
|
882
|
-
|
|
883
|
-
|
|
884
|
-
|
|
885
|
-
|
|
886
|
-
|
|
887
|
-
|
|
888
|
-
|
|
889
|
-
|
|
890
|
-
|
|
891
|
-
|
|
892
|
-
|
|
893
|
-
|
|
894
|
-
|
|
895
|
-
|
|
896
|
-
|
|
897
|
-
|
|
898
|
-
|
|
899
|
-
|
|
900
|
-
|
|
901
|
-
|
|
902
|
-
|
|
903
|
-
|
|
904
|
-
|
|
905
|
-
|
|
906
|
-
|
|
907
|
-
|
|
908
|
-
|
|
909
|
-
|
|
910
|
-
|
|
911
|
-
|
|
912
|
-
|
|
913
|
-
|
|
914
|
-
|
|
915
|
-
|
|
916
|
-
|
|
917
|
-
|
|
918
|
-
|
|
919
|
-
Parameters
|
|
920
|
-
----------
|
|
921
|
-
X_train: np.ndarray
|
|
922
|
-
y_train: np.ndarray
|
|
923
|
-
X_test: np.ndarray
|
|
924
|
-
y_test: np.ndarray
|
|
925
|
-
rounds : int, optional
|
|
926
|
-
optimizing rounds, by default 10
|
|
927
|
-
base_estimator : str, optional
|
|
928
|
-
surrogate model, used as optimization function instead of cross validation, by default "GP"
|
|
929
|
-
acq_func : str, optional
|
|
930
|
-
function to minimize over the posterior distribution, by default "EI"
|
|
931
|
-
acq_optimizer : str, optional
|
|
932
|
-
method to minimize the acquisition function, by default "sampling"
|
|
933
|
-
initial_point_generator : str, optional
|
|
934
|
-
sets a initial point generator, by default "lhs"
|
|
935
|
-
|
|
936
|
-
Returns
|
|
937
|
-
-------
|
|
938
|
-
skopt result parameters
|
|
939
|
-
"""
|
|
940
|
-
|
|
941
|
-
def get_f_val(model_bo):
|
|
942
|
-
|
|
943
|
-
try:
|
|
944
|
-
model_bo = self.fit_model(model_bo, X_train, y_train)
|
|
945
|
-
except Decoder.ClassMissingException:
|
|
946
|
-
pass
|
|
947
|
-
|
|
948
|
-
return self.eval_method(y_test, model_bo.predict(X_test))
|
|
949
|
-
|
|
950
|
-
opt = Optimizer(
|
|
951
|
-
self.bay_opt_param_space,
|
|
952
|
-
base_estimator=base_estimator,
|
|
953
|
-
acq_func=acq_func,
|
|
954
|
-
acq_optimizer=acq_optimizer,
|
|
955
|
-
initial_point_generator=initial_point_generator,
|
|
956
|
-
)
|
|
957
|
-
|
|
958
|
-
for _ in range(rounds):
|
|
959
|
-
next_x = opt.ask()
|
|
960
|
-
# set model values
|
|
961
|
-
model_bo = clone(self.model)
|
|
962
|
-
for i in range(len(next_x)):
|
|
963
|
-
setattr(model_bo, self.bay_opt_param_space[i].name, next_x[i])
|
|
964
|
-
f_val = get_f_val(model_bo)
|
|
965
|
-
res = opt.tell(next_x, f_val)
|
|
966
|
-
if self.VERBOSE:
|
|
967
|
-
print(f_val)
|
|
968
|
-
|
|
969
|
-
# res is here automatically appended by skopt
|
|
970
|
-
return res.x
|
|
971
|
-
|
|
972
|
-
def save(
|
|
973
|
-
self, feature_path: str, feature_file: str, str_save_add=None
|
|
974
|
-
) -> None:
|
|
975
|
-
"""Save decoder object to pickle"""
|
|
976
|
-
|
|
977
|
-
# why is the decoder not saved to a .json?
|
|
978
|
-
|
|
979
|
-
if str_save_add is None:
|
|
980
|
-
PATH_OUT = os.path.join(
|
|
981
|
-
feature_path, feature_file, feature_file + "_ML_RES.p"
|
|
982
|
-
)
|
|
983
|
-
else:
|
|
984
|
-
PATH_OUT = os.path.join(
|
|
985
|
-
feature_path,
|
|
986
|
-
feature_file,
|
|
987
|
-
feature_file + "_" + str_save_add + "_ML_RES.p",
|
|
988
|
-
)
|
|
989
|
-
|
|
990
|
-
print("model being saved to: " + str(PATH_OUT))
|
|
991
|
-
with open(PATH_OUT, "wb") as output: # Overwrites any existing file.
|
|
992
|
-
cPickle.dump(self, output)
|
|
1
|
+
from sklearn import model_selection
|
|
2
|
+
from sklearn.linear_model import LinearRegression
|
|
3
|
+
from sklearn.base import clone
|
|
4
|
+
from sklearn.metrics import r2_score
|
|
5
|
+
|
|
6
|
+
import pandas as pd
|
|
7
|
+
import numpy as np
|
|
8
|
+
from copy import deepcopy
|
|
9
|
+
from pathlib import PurePath
|
|
10
|
+
import pickle
|
|
11
|
+
|
|
12
|
+
from py_neuromodulation import logger
|
|
13
|
+
|
|
14
|
+
from typing import Callable
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class CV_res:
|
|
18
|
+
def __init__(
|
|
19
|
+
self,
|
|
20
|
+
get_movement_detection_rate: bool = False,
|
|
21
|
+
RUN_BAY_OPT: bool = False,
|
|
22
|
+
mrmr_select: bool = False,
|
|
23
|
+
model_save: bool = False,
|
|
24
|
+
) -> None:
|
|
25
|
+
self.score_train: list = []
|
|
26
|
+
self.score_test: list = []
|
|
27
|
+
self.y_test: list = []
|
|
28
|
+
self.y_train: list = []
|
|
29
|
+
self.y_test_pr: list = []
|
|
30
|
+
self.y_train_pr: list = []
|
|
31
|
+
self.X_test: list = []
|
|
32
|
+
self.X_train: list = []
|
|
33
|
+
self.coef: list = []
|
|
34
|
+
|
|
35
|
+
if get_movement_detection_rate:
|
|
36
|
+
self.mov_detection_rates_test: list = []
|
|
37
|
+
self.tprate_test: list = []
|
|
38
|
+
self.fprate_test: list = []
|
|
39
|
+
self.mov_detection_rates_train: list = []
|
|
40
|
+
self.tprate_train: list = []
|
|
41
|
+
self.fprate_train: list = []
|
|
42
|
+
if RUN_BAY_OPT:
|
|
43
|
+
self.best_bay_opt_params: list = []
|
|
44
|
+
if mrmr_select:
|
|
45
|
+
self.mrmr_select: list = []
|
|
46
|
+
if model_save:
|
|
47
|
+
self.model_save: list = []
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
class Decoder:
|
|
51
|
+
class ClassMissingException(Exception):
|
|
52
|
+
def __init__(
|
|
53
|
+
self,
|
|
54
|
+
message="Only one class present.",
|
|
55
|
+
) -> None:
|
|
56
|
+
self.message = message
|
|
57
|
+
super().__init__(self.message)
|
|
58
|
+
|
|
59
|
+
def __str__(self):
|
|
60
|
+
return self.message
|
|
61
|
+
|
|
62
|
+
def __init__(
|
|
63
|
+
self,
|
|
64
|
+
features: "pd.DataFrame| None " = None,
|
|
65
|
+
label: np.ndarray | None = None,
|
|
66
|
+
label_name: str | None = None,
|
|
67
|
+
used_chs: list[str] = [],
|
|
68
|
+
model=LinearRegression(),
|
|
69
|
+
eval_method: Callable = r2_score,
|
|
70
|
+
cv_method=model_selection.KFold(n_splits=3, shuffle=False),
|
|
71
|
+
use_nested_cv: bool = False,
|
|
72
|
+
threshold_score=True,
|
|
73
|
+
mov_detection_threshold: float = 0.5,
|
|
74
|
+
TRAIN_VAL_SPLIT: bool = False,
|
|
75
|
+
RUN_BAY_OPT: bool = False,
|
|
76
|
+
STACK_FEATURES_N_SAMPLES: bool = False,
|
|
77
|
+
time_stack_n_samples: int = 5,
|
|
78
|
+
save_coef: bool = False,
|
|
79
|
+
get_movement_detection_rate: bool = False,
|
|
80
|
+
min_consequent_count: int = 3,
|
|
81
|
+
bay_opt_param_space: list = [],
|
|
82
|
+
VERBOSE: bool = False,
|
|
83
|
+
sfreq: int | None = None,
|
|
84
|
+
undersampling: bool = False,
|
|
85
|
+
oversampling: bool = False,
|
|
86
|
+
mrmr_select: bool = False,
|
|
87
|
+
pca: bool = False,
|
|
88
|
+
cca: bool = False,
|
|
89
|
+
model_save: bool = False,
|
|
90
|
+
) -> None:
|
|
91
|
+
"""Initialize here a feature file for processing
|
|
92
|
+
Read settings.json channels.csv and features.csv
|
|
93
|
+
Read target label
|
|
94
|
+
|
|
95
|
+
Parameters
|
|
96
|
+
----------
|
|
97
|
+
model : machine learning model
|
|
98
|
+
model that utilizes fit and predict functions
|
|
99
|
+
eval_method : sklearn metrics
|
|
100
|
+
evaluation scoring method, will default to r2_score if not passed
|
|
101
|
+
cv_method : sklearm model_selection method
|
|
102
|
+
threshold_score : boolean
|
|
103
|
+
if True set lower threshold at zero (useful for r2),
|
|
104
|
+
mov_detection_threshold : float
|
|
105
|
+
if get_movement_detection_rate is True, find given minimum 'threshold' respective
|
|
106
|
+
consecutive movement blocks, by default 0.5
|
|
107
|
+
TRAIN_VAL_SPLIT (boolean):
|
|
108
|
+
if true split data into additinal validation, and run class weighted CV
|
|
109
|
+
save_coef (boolean):
|
|
110
|
+
if true, save model._coef trained coefficients
|
|
111
|
+
get_movement_detection_rate (boolean):
|
|
112
|
+
save detection rate and tpr / fpr as well
|
|
113
|
+
min_consequent_count (int):
|
|
114
|
+
if get_movement_detection_rate is True, find given 'min_consequent_count' respective
|
|
115
|
+
consecutive movement blocks with minimum size of 'min_consequent_count'
|
|
116
|
+
"""
|
|
117
|
+
|
|
118
|
+
self.model = model
|
|
119
|
+
self.eval_method = eval_method
|
|
120
|
+
self.cv_method = cv_method
|
|
121
|
+
self.use_nested_cv = use_nested_cv
|
|
122
|
+
self.threshold_score = threshold_score
|
|
123
|
+
self.mov_detection_threshold = mov_detection_threshold
|
|
124
|
+
self.TRAIN_VAL_SPLIT = TRAIN_VAL_SPLIT
|
|
125
|
+
self.RUN_BAY_OPT = RUN_BAY_OPT
|
|
126
|
+
self.save_coef = save_coef
|
|
127
|
+
self.sfreq = sfreq
|
|
128
|
+
self.get_movement_detection_rate = get_movement_detection_rate
|
|
129
|
+
self.min_consequent_count = min_consequent_count
|
|
130
|
+
self.STACK_FEATURES_N_SAMPLES = STACK_FEATURES_N_SAMPLES
|
|
131
|
+
self.time_stack_n_samples = time_stack_n_samples
|
|
132
|
+
self.bay_opt_param_space = bay_opt_param_space
|
|
133
|
+
self.VERBOSE = VERBOSE
|
|
134
|
+
self.undersampling = undersampling
|
|
135
|
+
self.oversampling = oversampling
|
|
136
|
+
self.mrmr_select = mrmr_select
|
|
137
|
+
self.used_chs = used_chs
|
|
138
|
+
self.label = label
|
|
139
|
+
self.label_name = label_name
|
|
140
|
+
self.cca = cca
|
|
141
|
+
self.pca = pca
|
|
142
|
+
self.model_save = model_save
|
|
143
|
+
|
|
144
|
+
self.set_data(features)
|
|
145
|
+
|
|
146
|
+
self.ch_ind_data = {}
|
|
147
|
+
self.grid_point_ind_data = {}
|
|
148
|
+
self.active_gridpoints = []
|
|
149
|
+
self.feature_names = []
|
|
150
|
+
self.ch_ind_results = {}
|
|
151
|
+
self.gridpoint_ind_results = {}
|
|
152
|
+
self.all_ch_results = {}
|
|
153
|
+
self.columns_names_single_ch = None
|
|
154
|
+
|
|
155
|
+
if undersampling:
|
|
156
|
+
from imblearn.under_sampling import RandomUnderSampler
|
|
157
|
+
|
|
158
|
+
self.rus = RandomUnderSampler(random_state=0)
|
|
159
|
+
|
|
160
|
+
if oversampling:
|
|
161
|
+
from imblearn.over_sampling import RandomOverSampler
|
|
162
|
+
|
|
163
|
+
self.ros = RandomOverSampler(random_state=0)
|
|
164
|
+
|
|
165
|
+
def set_data(self, features):
|
|
166
|
+
if features is not None:
|
|
167
|
+
self.features = features
|
|
168
|
+
self.feature_names = [
|
|
169
|
+
col
|
|
170
|
+
for col in self.features.columns
|
|
171
|
+
if not (("time" in col) or (self.label_name in col))
|
|
172
|
+
]
|
|
173
|
+
self.data = np.nan_to_num(np.array(self.features[self.feature_names]))
|
|
174
|
+
|
|
175
|
+
def set_data_ind_channels(self):
|
|
176
|
+
"""specified channel individual data"""
|
|
177
|
+
self.ch_ind_data = {}
|
|
178
|
+
for ch in self.used_chs:
|
|
179
|
+
self.ch_ind_data[ch] = np.nan_to_num(
|
|
180
|
+
np.array(
|
|
181
|
+
self.features[
|
|
182
|
+
[col for col in self.features.columns if col.startswith(ch)]
|
|
183
|
+
]
|
|
184
|
+
)
|
|
185
|
+
)
|
|
186
|
+
|
|
187
|
+
def set_CV_results(self, attr_name, contact_point=None):
|
|
188
|
+
"""set CV results in respectie nm_decode attributes
|
|
189
|
+
The reference is first stored in obj_set, and the used lateron
|
|
190
|
+
|
|
191
|
+
Parameters
|
|
192
|
+
----------
|
|
193
|
+
attr_name : string
|
|
194
|
+
is either all_ch_results, ch_ind_results, gridpoint_ind_results
|
|
195
|
+
contact_point : object, optional
|
|
196
|
+
usually an int specifying the grid_point or string, specifying the used channel,
|
|
197
|
+
by default None
|
|
198
|
+
"""
|
|
199
|
+
if contact_point is not None:
|
|
200
|
+
getattr(self, attr_name)[contact_point] = {}
|
|
201
|
+
obj_set = getattr(self, attr_name)[contact_point]
|
|
202
|
+
else:
|
|
203
|
+
obj_set = getattr(self, attr_name)
|
|
204
|
+
|
|
205
|
+
def set_scores(cv_res: CV_res, set_inner_CV_res: bool = False):
|
|
206
|
+
"""
|
|
207
|
+
This function renames the CV_res keys for InnerCV
|
|
208
|
+
"""
|
|
209
|
+
|
|
210
|
+
def set_score(key_: str, val):
|
|
211
|
+
if set_inner_CV_res:
|
|
212
|
+
key_ = "InnerCV_" + key_
|
|
213
|
+
obj_set[key_] = val
|
|
214
|
+
|
|
215
|
+
set_score("score_train", cv_res.score_train)
|
|
216
|
+
set_score("score_test", cv_res.score_test)
|
|
217
|
+
set_score("y_test", cv_res.y_test)
|
|
218
|
+
set_score("y_train", cv_res.y_train)
|
|
219
|
+
set_score("y_test_pr", cv_res.y_test_pr)
|
|
220
|
+
set_score("y_train_pr", cv_res.y_train_pr)
|
|
221
|
+
set_score("X_train", cv_res.X_train)
|
|
222
|
+
set_score("X_test", cv_res.X_test)
|
|
223
|
+
|
|
224
|
+
if self.save_coef:
|
|
225
|
+
set_score("coef", cv_res.coef)
|
|
226
|
+
if self.get_movement_detection_rate:
|
|
227
|
+
set_score("mov_detection_rates_test", cv_res.mov_detection_rates_test)
|
|
228
|
+
set_score(
|
|
229
|
+
"mov_detection_rates_train",
|
|
230
|
+
cv_res.mov_detection_rates_train,
|
|
231
|
+
)
|
|
232
|
+
set_score("fprate_test", cv_res.fprate_test)
|
|
233
|
+
set_score("fprate_train", cv_res.fprate_train)
|
|
234
|
+
set_score("tprate_test", cv_res.tprate_test)
|
|
235
|
+
set_score("tprate_train", cv_res.tprate_train)
|
|
236
|
+
|
|
237
|
+
if self.RUN_BAY_OPT:
|
|
238
|
+
set_score("best_bay_opt_params", cv_res.best_bay_opt_params)
|
|
239
|
+
|
|
240
|
+
if self.mrmr_select:
|
|
241
|
+
set_score("mrmr_select", cv_res.mrmr_select)
|
|
242
|
+
if self.model_save:
|
|
243
|
+
set_score("model_save", cv_res.model_save)
|
|
244
|
+
return obj_set
|
|
245
|
+
|
|
246
|
+
obj_set = set_scores(self.cv_res)
|
|
247
|
+
|
|
248
|
+
if self.use_nested_cv:
|
|
249
|
+
obj_set = set_scores(self.cv_res_inner, set_inner_CV_res=True)
|
|
250
|
+
|
|
251
|
+
def run_CV_caller(self, feature_contacts: str = "ind_channels"):
|
|
252
|
+
"""Wrapper that call for all channels / grid points / combined channels the CV function
|
|
253
|
+
|
|
254
|
+
Parameters
|
|
255
|
+
----------
|
|
256
|
+
feature_contacts : str, optional
|
|
257
|
+
"grid_points", "ind_channels" or "all_channels_combined" , by default "ind_channels"
|
|
258
|
+
"""
|
|
259
|
+
valid_feature_contacts = [
|
|
260
|
+
"ind_channels",
|
|
261
|
+
"all_channels_combined",
|
|
262
|
+
"grid_points",
|
|
263
|
+
]
|
|
264
|
+
if feature_contacts not in valid_feature_contacts:
|
|
265
|
+
raise ValueError(f"{feature_contacts} not in {valid_feature_contacts}")
|
|
266
|
+
|
|
267
|
+
if feature_contacts == "grid_points":
|
|
268
|
+
for grid_point in self.active_gridpoints:
|
|
269
|
+
self.run_CV(self.grid_point_ind_data[grid_point], self.label)
|
|
270
|
+
self.set_CV_results("gridpoint_ind_results", contact_point=grid_point)
|
|
271
|
+
return self.gridpoint_ind_results
|
|
272
|
+
|
|
273
|
+
if feature_contacts == "ind_channels":
|
|
274
|
+
for ch in self.used_chs:
|
|
275
|
+
self.ch_name_tested = ch
|
|
276
|
+
self.run_CV(self.ch_ind_data[ch], self.label)
|
|
277
|
+
self.set_CV_results("ch_ind_results", contact_point=ch)
|
|
278
|
+
return self.ch_ind_results
|
|
279
|
+
|
|
280
|
+
if feature_contacts == "all_channels_combined":
|
|
281
|
+
dat_combined = np.array(self.data)
|
|
282
|
+
self.run_CV(dat_combined, self.label)
|
|
283
|
+
self.set_CV_results("all_ch_results", contact_point=None)
|
|
284
|
+
return self.all_ch_results
|
|
285
|
+
|
|
286
|
+
def set_data_grid_points(self, cortex_only=False, subcortex_only=False):
    """Collect per-grid-point feature matrices from ``self.features``.

    Projected data has the shape (samples, grid points, features).
    Grid-point ids are the first two "_"-separated tokens of every column
    name that contains "grid" (e.g. "gridcortex_3").
    """
    cols = self.features.columns

    # activate_gridpoints stores cortex + subcortex data
    self.active_gridpoints = np.unique(
        ["_".join(c.split("_")[:2]) for c in cols if "grid" in c]
    )

    if cortex_only:
        self.active_gridpoints = [
            gp for gp in self.active_gridpoints if gp.startswith("gridcortex")
        ]

    if subcortex_only:
        self.active_gridpoints = [
            gp for gp in self.active_gridpoints if gp.startswith("gridsubcortex")
        ]

    # feature names are the column suffixes after the grid-point prefix
    prefix = self.active_gridpoints[0] + "_"
    self.feature_names = [c[len(prefix):] for c in cols if prefix in c]

    self.grid_point_ind_data = {}

    # NaNs are replaced with zeros so downstream estimators can fit
    self.grid_point_ind_data = {
        gp: np.nan_to_num(self.features[[c for c in cols if gp + "_" in c]])
        for gp in self.active_gridpoints
    }
|
|
326
|
+
|
|
327
|
+
def get_movement_grouped_array(
    self, prediction, threshold=0.5, min_consequent_count=5
):
    """Group consecutive supra-threshold samples into labeled blocks.

    Parameters
    ----------
    prediction : np.ndarray
        1D array of predictions or labels to be grouped
    threshold : float, optional
        threshold applied to 'prediction', by default 0.5
    min_consequent_count : int, optional
        minimum number of consecutive supra-threshold samples for a block
        to survive, by default 5

    Returns
    -------
    labeled_array : np.ndarray
        vector of the same size with an incrementing integer per block
    labels_count : int
        count of individual movement blocks
    """

    from scipy.ndimage import binary_dilation, binary_erosion
    from scipy.ndimage import label as label_ndimage

    above = prediction > threshold
    # morphological opening (erosion then dilation) removes runs that are
    # shorter than min_consequent_count while keeping longer runs intact
    kernel = [True] * min_consequent_count
    opened = binary_dilation(binary_erosion(above, kernel), kernel)
    return label_ndimage(opened)
|
|
358
|
+
|
|
359
|
+
def calc_movement_detection_rate(
    self, y_label, prediction, threshold=0.5, min_consequent_count=3
):
    """Given a label and prediction, return the movement detection rate on the basis of
    movements classified in blocks of 'min_consequent_count'.

    Parameters
    ----------
    y_label : np.ndarray
        ground-truth label vector; thresholded and grouped into movement blocks
    prediction : np.ndarray
        model output vector of the same length as y_label
    threshold : float, optional
        threshold to be applied to 'prediction', by default 0.5
    min_consequent_count : int, optional
        minimum required consective samples higher than 'threshold', by default 3

    Returns
    -------
    mov_detection_rate : float
        movement detection rate, where at least 'min_consequent_count' samples where high in prediction
    fpr : np.ndarray
        sklearn.metrics false positive rate np.ndarray
    tpr : np.ndarray
        sklearn.metrics true positive rate np.ndarray
    """
    from sklearn.metrics import confusion_matrix

    pred_grouped, _ = self.get_movement_grouped_array(
        prediction, threshold, min_consequent_count
    )
    y_grouped, labels_count = self.get_movement_grouped_array(
        y_label, threshold, min_consequent_count
    )

    # hit_rate[i] counts predicted-positive samples inside label block i+1
    hit_rate = np.zeros(labels_count)
    pred_group_bin = np.array(pred_grouped > 0)

    for label_number in range(1, labels_count + 1):  # labeling starts from 1
        hit_rate[label_number - 1] = np.sum(
            pred_group_bin[np.where(y_grouped == label_number)[0]]
        )

    try:
        # a block counts as detected if it overlaps any predicted movement
        mov_detection_rate = np.where(hit_rate > 0)[0].shape[0] / labels_count
    except ZeroDivisionError:
        # labels_count == 0: no movement blocks in the label at all
        logger.warning("no movements in label")
        return 0, 0, 0

    # calculating TPR and FPR: https://stackoverflow.com/a/40324184/5060208
    # NOTE(review): confusion_matrix here uses the raw y_label/prediction
    # vectors — this presumably assumes both are already binary; confirm
    # upstream thresholding for regression-style predictions.
    CM = confusion_matrix(y_label, prediction)

    TN = CM[0][0]
    FN = CM[1][0]
    TP = CM[1][1]
    FP = CM[0][1]
    fpr = FP / (FP + TN)
    tpr = TP / (TP + FN)

    return mov_detection_rate, fpr, tpr
|
|
419
|
+
|
|
420
|
+
def init_cv_res(self) -> "CV_res":
    """Create a fresh ``CV_res`` result container configured from the
    decoder's flags.

    Returns
    -------
    CV_res
        empty result accumulator for one cross-validation run

    Note: the original annotation claimed ``-> None`` although the method
    clearly returns a ``CV_res`` instance; the annotation is corrected here.
    """
    return CV_res(
        get_movement_detection_rate=self.get_movement_detection_rate,
        RUN_BAY_OPT=self.RUN_BAY_OPT,
        mrmr_select=self.mrmr_select,
        model_save=self.model_save,
    )
|
|
427
|
+
|
|
428
|
+
# @staticmethod
|
|
429
|
+
# @jit(nopython=True)
|
|
430
|
+
def append_previous_n_samples(X: np.ndarray, y: np.ndarray, n: int = 5):
    """
    Stack the feature vectors of the previous ``n`` samples.

    Output row ``t`` concatenates ``X[t], X[t-1], ..., X[t-n+1]``; the
    first ``n`` samples (with incomplete history) are dropped from both
    the features and the label.
    """
    n_rows = X.shape[0] - n
    # column block k holds the features lagged by k samples
    lagged_blocks = [X[n - k : n - k + n_rows] for k in range(n)]
    # cast to float64 to match the dtype of the np.empty buffer used before
    stacked = np.concatenate(lagged_blocks, axis=1).astype(np.float64)
    return stacked, y[n:]
|
|
444
|
+
|
|
445
|
+
@staticmethod
def append_samples_val(X_train, y_train, X_val, y_val, n):
    """Apply ``n``-sample feature stacking to both the train and the
    validation split and return the four transformed arrays."""
    stacked_train = Decoder.append_previous_n_samples(X_train, y_train, n=n)
    stacked_val = Decoder.append_previous_n_samples(X_val, y_val, n=n)
    return (*stacked_train, *stacked_val)
|
|
450
|
+
|
|
451
|
+
def fit_model(self, model, X_train, y_train):
    """Fit ``model`` on the training data.

    When ``self.TRAIN_VAL_SPLIT`` is set, the training data is split once
    more (70/30, unshuffled) into train/validation parts; otherwise the
    optional over-/undersampling is applied before fitting.

    Raises
    ------
    Decoder.ClassMissingException
        if either split contains no positive label
    """
    if self.TRAIN_VAL_SPLIT:
        X_train, X_val, y_train, y_val = model_selection.train_test_split(
            X_train, y_train, train_size=0.7, shuffle=False
        )

        # both splits need at least one positive sample
        if y_train.sum() == 0 or y_val.sum(0) == 0:
            raise Decoder.ClassMissingException

        # NOTE(review): with the xgboost-specific fitting below commented
        # out, this branch never calls model.fit — the model is returned
        # unfitted whenever TRAIN_VAL_SPLIT is enabled. Confirm intended.
        # if type(model) is xgboost.sklearn.XGBClassifier:
        #     classes_weights = class_weight.compute_sample_weight(
        #         class_weight="balanced", y=y_train
        #     )
        #     model.set_params(eval_metric="logloss")
        #     model.fit(
        #         X_train,
        #         y_train,
        #         eval_set=[(X_val, y_val)],
        #         early_stopping_rounds=7,
        #         sample_weight=classes_weights,
        #         verbose=self.VERBOSE,
        #     )
        # elif type(model) is xgboost.sklearn.XGBRegressor:
        #     # might be necessary to adapt for other classifiers
        #
        #     def evalerror(preds, dtrain):
        #         labels = dtrain.get_label()
        #         # return a pair metric_name, result. The metric name must not contain a
        #         # colon (:) or a space since preds are margin(before logistic
        #         # transformation, cutoff at 0)
        #
        #         r2 = metrics.r2_score(labels, preds)
        #
        #         if r2 < 0:
        #             r2 = 0
        #
        #         return "r2", -r2
        #
        #     model.set_params(eval_metric=evalerror)
        #     model.fit(
        #         X_train,
        #         y_train,
        #         eval_set=[(X_val, y_val)],
        #         early_stopping_rounds=10,
        #         verbose=self.VERBOSE,
        #     )
        # else:
        #     model.fit(X_train, y_train, eval_set=[(X_val, y_val)])
    else:
        # check for LDA; and apply rebalancing
        if self.oversampling:
            X_train, y_train = self.ros.fit_resample(X_train, y_train)
        if self.undersampling:
            X_train, y_train = self.rus.fit_resample(X_train, y_train)

        # if type(model) is xgboost.sklearn.XGBClassifier:
        #     model.set_params(eval_metric="logloss")
        #     model.fit(X_train, y_train)
        # else:
        model.fit(X_train, y_train)

    return model
|
|
513
|
+
|
|
514
|
+
def eval_model(
    self,
    model_train,
    X_train,
    X_test,
    y_train,
    y_test,
    cv_res: CV_res,
    save_data=True,
    save_probabilities=False,
) -> CV_res:
    """Score a fitted model on train and test data and record the results.

    Parameters
    ----------
    model_train : fitted estimator
        model to evaluate (must provide ``predict``; ``predict_proba``
        when ``save_probabilities`` is set, ``coef_`` when ``self.save_coef``)
    X_train, X_test : np.ndarray
        train / test feature matrices
    y_train, y_test : np.ndarray
        train / test labels
    cv_res : CV_res
        result container that scores and data are appended to
    save_data : bool, optional
        also store the raw feature matrices, by default True
    save_probabilities : bool, optional
        store ``predict_proba`` outputs instead of ``predict`` outputs,
        by default False

    Returns
    -------
    CV_res
        the same container, with this fold's results appended
    """
    if self.save_coef:
        cv_res.coef.append(model_train.coef_)

    y_test_pr = model_train.predict(X_test)
    y_train_pr = model_train.predict(X_train)

    sc_te = self.eval_method(y_test, y_test_pr)
    sc_tr = self.eval_method(y_train, y_train_pr)

    if self.threshold_score:
        # clip negative scores (e.g. r^2 below chance) to zero
        if sc_tr < 0:
            sc_tr = 0
        if sc_te < 0:
            sc_te = 0

    if self.get_movement_detection_rate:
        self._set_movement_detection_rates(
            y_test, y_test_pr, y_train, y_train_pr, cv_res
        )

    cv_res.score_train.append(sc_tr)
    cv_res.score_test.append(sc_te)
    if save_data:
        cv_res.X_train.append(X_train)
        cv_res.X_test.append(X_test)
    if self.model_save:
        cv_res.model_save.append(deepcopy(model_train))  # clone won't copy params
    cv_res.y_train.append(y_train)
    cv_res.y_test.append(y_test)

    if not save_probabilities:
        cv_res.y_train_pr.append(y_train_pr)
        cv_res.y_test_pr.append(y_test_pr)
    else:
        cv_res.y_train_pr.append(model_train.predict_proba(X_train))
        cv_res.y_test_pr.append(model_train.predict_proba(X_test))
    return cv_res
|
|
562
|
+
|
|
563
|
+
def _set_movement_detection_rates(
    self,
    y_test: np.ndarray,
    y_test_pr: np.ndarray,
    y_train: np.ndarray,
    y_train_pr: np.ndarray,
    cv_res: CV_res,
) -> CV_res:
    """Compute movement detection rate, TPR and FPR for the test and the
    train split and append them to ``cv_res``."""
    splits = (
        (
            y_test,
            y_test_pr,
            cv_res.mov_detection_rates_test,
            cv_res.tprate_test,
            cv_res.fprate_test,
        ),
        (
            y_train,
            y_train_pr,
            cv_res.mov_detection_rates_train,
            cv_res.tprate_train,
            cv_res.fprate_train,
        ),
    )

    # test split first, then train split (same order as before)
    for y_true, y_pred, rate_list, tpr_list, fpr_list in splits:
        rate, fp_rate, tp_rate = self.calc_movement_detection_rate(
            y_true,
            y_pred,
            self.mov_detection_threshold,
            self.min_consequent_count,
        )
        rate_list.append(rate)
        tpr_list.append(tp_rate)
        fpr_list.append(fp_rate)

    return cv_res
|
|
594
|
+
|
|
595
|
+
def wrapper_model_train(
    self,
    X_train,
    y_train,
    X_test=None,
    y_test=None,
    cv_res: CV_res | None = None,
    return_fitted_model_only: bool = False,
    save_data=True,
):
    """Train (and optionally evaluate) one model on one data split.

    Applies, in order: optional feature stacking, optional Bayesian
    hyperparameter optimization, optional mRMR feature selection, optional
    PCA/CCA dimensionality reduction, then fits via ``self.fit_model``.
    Returns the fitted model when ``return_fitted_model_only`` is set,
    otherwise the ``cv_res`` container updated by ``self.eval_model``.

    Raises
    ------
    Decoder.ClassMissingException
        if a split contains no positive label
    """
    if cv_res is None:
        cv_res = CV_res(
            get_movement_detection_rate=self.get_movement_detection_rate,
            RUN_BAY_OPT=self.RUN_BAY_OPT,
            mrmr_select=self.mrmr_select,
            model_save=self.model_save,
        )

    model_train = clone(self.model)
    if self.STACK_FEATURES_N_SAMPLES:
        if X_test is not None:
            X_train, y_train, X_test, y_test = Decoder.append_samples_val(
                X_train,
                y_train,
                X_test,
                y_test,
                n=self.time_stack_n_samples,
            )
        else:
            X_train, y_train = Decoder.append_previous_n_samples(
                X_train, y_train, n=self.time_stack_n_samples
            )

    if y_train.sum() == 0 or (
        y_test is not None and y_test.sum() == 0
    ):  # only one class present
        raise Decoder.ClassMissingException

    if self.RUN_BAY_OPT:
        model_train = self.bay_opt_wrapper(model_train, X_train, y_train)

    if self.mrmr_select:
        from mrmr import mrmr_classif

        if len(self.feature_names) > X_train.shape[1]:
            # analyze induvidual ch
            columns_names = [
                col
                for col in self.feature_names
                if col.startswith(self.ch_name_tested)
            ]
            if self.columns_names_single_ch is None:
                # channel-agnostic names, cached once for later reporting
                self.columns_names_single_ch = [
                    f[len(self.ch_name_tested) + 1 :] for f in columns_names
                ]
        else:
            # analyze all_ch_combined
            columns_names = self.feature_names
        X_train = pd.DataFrame(X_train, columns=columns_names)
        X_test = pd.DataFrame(X_test, columns=columns_names)

        y_train = pd.Series(y_train)
        selected_features = mrmr_classif(X=X_train, y=y_train, K=20, n_jobs=60)

        X_train = X_train[selected_features]
        X_test = X_test[selected_features]

    if self.pca:
        from sklearn.decomposition import PCA

        # PCA is fit on the training split only to avoid test leakage
        pca = PCA(n_components=10)
        pca.fit(X_train)
        X_train = pca.transform(X_train)
        X_test = pca.transform(X_test)

    if self.cca:
        from sklearn.cross_decomposition import CCA

        cca = CCA(n_components=10)
        cca.fit(X_train, y_train)
        X_train = cca.transform(X_train)
        X_test = cca.transform(X_test)

    # NOTE(review): feature stacking was already applied at the top of this
    # method when STACK_FEATURES_N_SAMPLES is set — this second block stacks
    # a second time. Additionally, self.append_previous_n_samples binds
    # `self` to the function's first parameter `X` (the method is defined
    # without @staticmethod), so these calls look broken. Confirm intended.
    if self.STACK_FEATURES_N_SAMPLES:
        if return_fitted_model_only:
            X_train, y_train = self.append_previous_n_samples(
                X_train, y_train, self.time_stack_n_samples
            )
        else:
            X_train, y_train, X_test, y_test = self.append_samples_val(
                X_train, y_train, X_test, y_test, self.time_stack_n_samples
            )

    # fit model
    model_train = self.fit_model(model_train, X_train, y_train)

    if return_fitted_model_only:
        return model_train

    cv_res = self.eval_model(
        model_train, X_train, X_test, y_train, y_test, cv_res, save_data
    )

    if self.mrmr_select:
        cv_res.mrmr_select.append(selected_features)

    return cv_res
|
|
702
|
+
|
|
703
|
+
def run_CV(self, data, label):
    """Evaluate model performance on the specified cross validation.
    If no data and label is specified, use whole feature class attributes.

    Parameters
    ----------
    data (np.ndarray):
        data to train and test with shape samples, features
    label (np.ndarray):
        label to train and test with shape samples, features
    """

    def split_data(data):
        # Generator of (train_index, test_index) pairs, either from the
        # custom "NonShuffledTrainTestSplit" scheme or from an sklearn-style
        # cv object stored in self.cv_method.
        if self.cv_method == "NonShuffledTrainTestSplit":
            # set outer 10s set to train index
            # test index is thus in the middle starting at random number
            N_samples = data.shape[0]
            # samples available for the test block after reserving 10 s
            # of training data at each end of the recording
            test_area_points = (N_samples - self.sfreq * 10) - (self.sfreq * 10)
            test_points = int(N_samples * 0.3)

            if test_area_points > test_points:
                # random contiguous 30% test block in the middle
                start_index = np.random.randint(
                    int(self.sfreq * 10),
                    N_samples - self.sfreq * 10 - test_points,
                )
                test_index = np.arange(start_index, start_index + test_points)
                train_index = np.concatenate(
                    (
                        np.arange(0, start_index),
                        np.arange(start_index + test_points, N_samples),
                    ),
                    axis=0,
                ).flatten()
                yield train_index, test_index
            else:
                # recording too short for the buffered scheme: fall back to
                # a single unshuffled 70/30 split
                cv_single_tr_te_split = model_selection.check_cv(
                    cv=[
                        model_selection.train_test_split(
                            np.arange(data.shape[0]),
                            test_size=0.3,
                            shuffle=False,
                        )
                    ]
                )
                for (
                    train_index,
                    test_index,
                ) in cv_single_tr_te_split.split():
                    yield train_index, test_index
        else:
            for train_index, test_index in self.cv_method.split(data):
                yield train_index, test_index

    cv_res = self.init_cv_res()

    if self.use_nested_cv:
        cv_res_inner = self.init_cv_res()

    for train_index, test_index in split_data(data):
        X_train, y_train = data[train_index, :], label[train_index]
        X_test, y_test = data[test_index], label[test_index]
        try:
            cv_res = self.wrapper_model_train(
                X_train, y_train, X_test, y_test, cv_res
            )
        except Decoder.ClassMissingException:
            # fold had only one class: skip it (inner CV is skipped too)
            continue

        if self.use_nested_cv:
            # inner CV re-splits this fold's training data only
            data_inner = data[train_index]
            label_inner = label[train_index]
            for train_index_inner, test_index_inner in split_data(data_inner):
                X_train_inner = data_inner[train_index_inner, :]
                y_train_inner = label_inner[train_index_inner]
                X_test_inner = data_inner[test_index_inner]
                y_test_inner = label_inner[test_index_inner]
                try:
                    cv_res_inner = self.wrapper_model_train(
                        X_train_inner,
                        y_train_inner,
                        X_test_inner,
                        y_test_inner,
                        cv_res_inner,
                    )
                except Decoder.ClassMissingException:
                    continue

    self.cv_res = cv_res
    if self.use_nested_cv:
        self.cv_res_inner = cv_res_inner
|
|
793
|
+
|
|
794
|
+
def bay_opt_wrapper(self, model_train, X_train, y_train):
    """Run bayesian optimization and test best params to model_train
    Save best params into self.best_bay_opt_params

    Parameters
    ----------
    model_train : estimator
        model whose hyperparameters (self.bay_opt_param_space) are tuned
    X_train : np.ndarray
    y_train : np.ndarray

    Returns
    -------
    estimator
        model_train with the best found parameters set

    Raises
    ------
    Decoder.ClassMissingException
        if either of the internal 70/30 splits has no positive label
    """

    # internal unshuffled 70/30 split used only for the optimization
    (
        X_train_bo,
        X_test_bo,
        y_train_bo,
        y_test_bo,
    ) = model_selection.train_test_split(
        X_train, y_train, train_size=0.7, shuffle=False
    )

    if y_train_bo.sum() == 0 or y_test_bo.sum() == 0:
        logger.critical("could not start Bay. Opt. with no labels > 0")
        raise Decoder.ClassMissingException

    params_bo = self.run_Bay_Opt(
        X_train_bo, y_train_bo, X_test_bo, y_test_bo, rounds=10
    )

    # set bay. opt. obtained best params to model
    params_bo_dict = {}
    for i in range(len(params_bo)):
        setattr(model_train, self.bay_opt_param_space[i].name, params_bo[i])
        params_bo_dict[self.bay_opt_param_space[i].name] = params_bo[i]

    self.best_bay_opt_params.append(params_bo_dict)

    return model_train
|
|
825
|
+
|
|
826
|
+
def run_Bay_Opt(
    self,
    X_train,
    y_train,
    X_test,
    y_test,
    rounds=30,
    base_estimator="GP",
    acq_func="EI",
    acq_optimizer="sampling",
    initial_point_generator="lhs",
):
    """Run skopt bayesian optimization
    skopt.Optimizer:
    https://scikit-optimize.github.io/stable/modules/generated/skopt.Optimizer.html#skopt.Optimizer

    example:
    https://scikit-optimize.github.io/stable/auto_examples/ask-and-tell.html#sphx-glr-auto-examples-ask-and-tell-py

    Special attention needs to be made with the run_CV output,
    some metrics are minimized (MAE), some are maximized (r^2)

    Parameters
    ----------
    X_train: np.ndarray
    y_train: np.ndarray
    X_test: np.ndarray
    y_test: np.ndarray
    rounds : int, optional
        optimizing rounds, by default 30
    base_estimator : str, optional
        surrogate model, used as optimization function instead of cross validation, by default "GP"
    acq_func : str, optional
        function to minimize over the posterior distribution, by default "EI"
    acq_optimizer : str, optional
        method to minimize the acquisition function, by default "sampling"
    initial_point_generator : str, optional
        sets a initial point generator, by default "lhs"

    Returns
    -------
    skopt result parameters
    """

    def get_f_val(model_bo):
        # fit the candidate model and score it on the held-out set
        try:
            model_bo = self.fit_model(model_bo, X_train, y_train)
        except Decoder.ClassMissingException:
            pass

        return self.eval_method(y_test, model_bo.predict(X_test))

    from skopt import Optimizer

    opt = Optimizer(
        self.bay_opt_param_space,
        base_estimator=base_estimator,
        acq_func=acq_func,
        acq_optimizer=acq_optimizer,
        initial_point_generator=initial_point_generator,
    )

    # ask-and-tell loop: propose params, evaluate, report back to skopt
    for _ in range(rounds):
        next_x = opt.ask()
        # set model values
        model_bo = clone(self.model)
        for i in range(len(next_x)):
            setattr(model_bo, self.bay_opt_param_space[i].name, next_x[i])
        f_val = get_f_val(model_bo)
        res = opt.tell(next_x, f_val)
        if self.VERBOSE:
            logger.info(f_val)

    # res is here automatically appended by skopt
    return res.x
|
|
901
|
+
|
|
902
|
+
def save(self, feature_path: str, feature_file: str, str_save_add=None) -> None:
    """Pickle the whole decoder object next to the feature file.

    The output path is ``feature_path/feature_file/<feature_file>[_<str_save_add>]_ML_RES.p``.
    """

    # why is the decoder not saved to a .json?

    if str_save_add is None:
        file_name = feature_file + "_ML_RES.p"
    else:
        file_name = feature_file + "_" + str_save_add + "_ML_RES.p"
    PATH_OUT = PurePath(feature_path, feature_file, file_name)

    logger.info(f"model being saved to: {PATH_OUT}")
    with open(PATH_OUT, "wb") as output:  # Overwrites any existing file.
        pickle.dump(self, output)
|