rapidtide 3.0.11__py3-none-any.whl → 3.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (144)
  1. rapidtide/Colortables.py +492 -27
  2. rapidtide/OrthoImageItem.py +1049 -46
  3. rapidtide/RapidtideDataset.py +1533 -86
  4. rapidtide/_version.py +3 -3
  5. rapidtide/calccoherence.py +196 -29
  6. rapidtide/calcnullsimfunc.py +188 -40
  7. rapidtide/calcsimfunc.py +242 -42
  8. rapidtide/correlate.py +1203 -383
  9. rapidtide/data/examples/src/testLD +56 -0
  10. rapidtide/data/examples/src/testalign +1 -1
  11. rapidtide/data/examples/src/testdelayvar +0 -1
  12. rapidtide/data/examples/src/testfmri +53 -3
  13. rapidtide/data/examples/src/testglmfilt +5 -5
  14. rapidtide/data/examples/src/testhappy +29 -7
  15. rapidtide/data/examples/src/testppgproc +17 -0
  16. rapidtide/data/examples/src/testrolloff +11 -0
  17. rapidtide/data/models/model_cnn_pytorch/best_model.pth +0 -0
  18. rapidtide/data/models/model_cnn_pytorch/loss.png +0 -0
  19. rapidtide/data/models/model_cnn_pytorch/loss.txt +1 -0
  20. rapidtide/data/models/model_cnn_pytorch/model.pth +0 -0
  21. rapidtide/data/models/model_cnn_pytorch/model_meta.json +68 -0
  22. rapidtide/decorators.py +91 -0
  23. rapidtide/dlfilter.py +2226 -110
  24. rapidtide/dlfiltertorch.py +4842 -0
  25. rapidtide/externaltools.py +327 -12
  26. rapidtide/fMRIData_class.py +79 -40
  27. rapidtide/filter.py +1899 -810
  28. rapidtide/fit.py +2011 -581
  29. rapidtide/genericmultiproc.py +93 -18
  30. rapidtide/happy_supportfuncs.py +2047 -172
  31. rapidtide/helper_classes.py +584 -43
  32. rapidtide/io.py +2370 -372
  33. rapidtide/linfitfiltpass.py +346 -99
  34. rapidtide/makelaggedtcs.py +210 -24
  35. rapidtide/maskutil.py +448 -62
  36. rapidtide/miscmath.py +827 -121
  37. rapidtide/multiproc.py +210 -22
  38. rapidtide/patchmatch.py +242 -42
  39. rapidtide/peakeval.py +31 -31
  40. rapidtide/ppgproc.py +2203 -0
  41. rapidtide/qualitycheck.py +352 -39
  42. rapidtide/refinedelay.py +431 -57
  43. rapidtide/refineregressor.py +494 -189
  44. rapidtide/resample.py +671 -185
  45. rapidtide/scripts/applyppgproc.py +28 -0
  46. rapidtide/scripts/showxcorr_legacy.py +7 -7
  47. rapidtide/scripts/stupidramtricks.py +15 -17
  48. rapidtide/simFuncClasses.py +1052 -77
  49. rapidtide/simfuncfit.py +269 -69
  50. rapidtide/stats.py +540 -238
  51. rapidtide/tests/happycomp +9 -0
  52. rapidtide/tests/test_cleanregressor.py +1 -2
  53. rapidtide/tests/test_dlfiltertorch.py +627 -0
  54. rapidtide/tests/test_findmaxlag.py +24 -8
  55. rapidtide/tests/test_fullrunhappy_v1.py +0 -2
  56. rapidtide/tests/test_fullrunhappy_v2.py +0 -2
  57. rapidtide/tests/test_fullrunhappy_v3.py +11 -4
  58. rapidtide/tests/test_fullrunhappy_v4.py +10 -2
  59. rapidtide/tests/test_fullrunrapidtide_v7.py +1 -1
  60. rapidtide/tests/test_getparsers.py +11 -3
  61. rapidtide/tests/test_refinedelay.py +0 -1
  62. rapidtide/tests/test_simroundtrip.py +16 -8
  63. rapidtide/tests/test_stcorrelate.py +3 -1
  64. rapidtide/tests/utils.py +9 -8
  65. rapidtide/tidepoolTemplate.py +142 -38
  66. rapidtide/tidepoolTemplate_alt.py +165 -44
  67. rapidtide/tidepoolTemplate_big.py +189 -52
  68. rapidtide/util.py +1217 -118
  69. rapidtide/voxelData.py +684 -37
  70. rapidtide/wiener.py +136 -23
  71. rapidtide/wiener2.py +113 -7
  72. rapidtide/workflows/adjustoffset.py +105 -3
  73. rapidtide/workflows/aligntcs.py +85 -2
  74. rapidtide/workflows/applydlfilter.py +87 -10
  75. rapidtide/workflows/applyppgproc.py +540 -0
  76. rapidtide/workflows/atlasaverage.py +210 -47
  77. rapidtide/workflows/atlastool.py +100 -3
  78. rapidtide/workflows/calcSimFuncMap.py +288 -69
  79. rapidtide/workflows/calctexticc.py +201 -9
  80. rapidtide/workflows/ccorrica.py +101 -6
  81. rapidtide/workflows/cleanregressor.py +165 -31
  82. rapidtide/workflows/delayvar.py +171 -23
  83. rapidtide/workflows/diffrois.py +81 -3
  84. rapidtide/workflows/endtidalproc.py +144 -4
  85. rapidtide/workflows/fdica.py +195 -15
  86. rapidtide/workflows/filtnifti.py +70 -3
  87. rapidtide/workflows/filttc.py +74 -3
  88. rapidtide/workflows/fitSimFuncMap.py +202 -51
  89. rapidtide/workflows/fixtr.py +73 -3
  90. rapidtide/workflows/gmscalc.py +113 -3
  91. rapidtide/workflows/happy.py +801 -199
  92. rapidtide/workflows/happy2std.py +144 -12
  93. rapidtide/workflows/happy_parser.py +163 -23
  94. rapidtide/workflows/histnifti.py +118 -2
  95. rapidtide/workflows/histtc.py +84 -3
  96. rapidtide/workflows/linfitfilt.py +117 -4
  97. rapidtide/workflows/localflow.py +328 -28
  98. rapidtide/workflows/mergequality.py +79 -3
  99. rapidtide/workflows/niftidecomp.py +322 -18
  100. rapidtide/workflows/niftistats.py +174 -4
  101. rapidtide/workflows/pairproc.py +98 -4
  102. rapidtide/workflows/pairwisemergenifti.py +85 -2
  103. rapidtide/workflows/parser_funcs.py +1421 -40
  104. rapidtide/workflows/physiofreq.py +137 -11
  105. rapidtide/workflows/pixelcomp.py +207 -5
  106. rapidtide/workflows/plethquality.py +103 -21
  107. rapidtide/workflows/polyfitim.py +151 -11
  108. rapidtide/workflows/proj2flow.py +75 -2
  109. rapidtide/workflows/rankimage.py +111 -4
  110. rapidtide/workflows/rapidtide.py +368 -76
  111. rapidtide/workflows/rapidtide2std.py +98 -2
  112. rapidtide/workflows/rapidtide_parser.py +109 -9
  113. rapidtide/workflows/refineDelayMap.py +144 -33
  114. rapidtide/workflows/refineRegressor.py +675 -96
  115. rapidtide/workflows/regressfrommaps.py +161 -37
  116. rapidtide/workflows/resamplenifti.py +85 -3
  117. rapidtide/workflows/resampletc.py +91 -3
  118. rapidtide/workflows/retrolagtcs.py +99 -9
  119. rapidtide/workflows/retroregress.py +176 -26
  120. rapidtide/workflows/roisummarize.py +174 -5
  121. rapidtide/workflows/runqualitycheck.py +71 -3
  122. rapidtide/workflows/showarbcorr.py +149 -6
  123. rapidtide/workflows/showhist.py +86 -2
  124. rapidtide/workflows/showstxcorr.py +160 -3
  125. rapidtide/workflows/showtc.py +159 -3
  126. rapidtide/workflows/showxcorrx.py +190 -10
  127. rapidtide/workflows/showxy.py +185 -15
  128. rapidtide/workflows/simdata.py +264 -38
  129. rapidtide/workflows/spatialfit.py +77 -2
  130. rapidtide/workflows/spatialmi.py +250 -27
  131. rapidtide/workflows/spectrogram.py +305 -32
  132. rapidtide/workflows/synthASL.py +154 -3
  133. rapidtide/workflows/tcfrom2col.py +76 -2
  134. rapidtide/workflows/tcfrom3col.py +74 -2
  135. rapidtide/workflows/tidepool.py +2971 -130
  136. rapidtide/workflows/utils.py +19 -14
  137. rapidtide/workflows/utils_doc.py +293 -0
  138. rapidtide/workflows/variabilityizer.py +116 -3
  139. {rapidtide-3.0.11.dist-info → rapidtide-3.1.1.dist-info}/METADATA +10 -8
  140. {rapidtide-3.0.11.dist-info → rapidtide-3.1.1.dist-info}/RECORD +144 -128
  141. {rapidtide-3.0.11.dist-info → rapidtide-3.1.1.dist-info}/entry_points.txt +1 -0
  142. {rapidtide-3.0.11.dist-info → rapidtide-3.1.1.dist-info}/WHEEL +0 -0
  143. {rapidtide-3.0.11.dist-info → rapidtide-3.1.1.dist-info}/licenses/LICENSE +0 -0
  144. {rapidtide-3.0.11.dist-info → rapidtide-3.1.1.dist-info}/top_level.txt +0 -0
rapidtide/dlfiltertorch.py (new file)
@@ -0,0 +1,4842 @@
+ #!/usr/bin/env python
+ # -*- coding: utf-8 -*-
+ #
+ # Copyright 2016-2025 Blaise Frederick
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ #
+ #
+ import glob
+ import logging
+ import os
+ import sys
+ import warnings
+
+ import matplotlib as mpl
+ import matplotlib.pyplot as plt
+ import numpy as np
+ import tqdm
+ from numpy.typing import NDArray
+
+ with warnings.catch_warnings():
+     warnings.simplefilter("ignore")
+     try:
+         import pyfftw
+     except ImportError:
+         pyfftwpresent = False
+     else:
+         pyfftwpresent = True
+
+ from scipy import fftpack
+ from statsmodels.robust.scale import mad
+
+ if pyfftwpresent:
+     fftpack = pyfftw.interfaces.scipy_fftpack
+     pyfftw.interfaces.cache.enable()
+
+ import torch
+ import torch.nn as nn
+ import torch.optim as optim
+ from torch.utils.data import DataLoader, TensorDataset
+
+ import rapidtide.io as tide_io
+
+ LGR = logging.getLogger("GENERAL")
+ LGR.debug("setting backend to Agg")
+ mpl.use("Agg")
+
+ # Select the compute device: prefer CUDA, then Apple MPS, then fall back to the CPU
+ if torch.cuda.is_available():
+     device = torch.device("cuda")
+     LGR.info(f"Using CUDA device: {torch.cuda.get_device_name(0)}")
+ elif torch.backends.mps.is_available():
+     device = torch.device("mps")
+     LGR.info("Using MPS device")
+ else:
+     device = torch.device("cpu")
+     LGR.info("Using CPU")
+
+ LGR.debug(f"pytorch version: >>>{torch.__version__}<<<")
+
+
+ class DeepLearningFilter:
+     """Base class for deep learning filter"""
+
+     thesuffix = "sliceres"
+     thedatadir = "/Users/frederic/Documents/MR_data/physioconn/timecourses"
+     inputfrag = "abc"
+     targetfrag = "xyz"
+     namesuffix = None
+     modelroot = "."
+     excludethresh = 4.0
+     modelname = None
+     intermediatemodelpath = None
+     usebadpts = False
+     activation = "tanh"
+     dofft = False
+     readlim = None
+     countlim = None
+     lossfilename = None
+     train_x = None
+     train_y = None
+     val_x = None
+     val_y = None
+     model = None
+     modelpath = None
+     inputsize = None
+     infodict = {}
+
+     def __init__(
+         self,
+         window_size: int = 128,
+         num_layers: int = 5,
+         dropout_rate: float = 0.3,
+         num_pretrain_epochs: int = 0,
+         num_epochs: int = 1,
+         activation: str = "relu",
+         modelroot: str = ".",
+         dofft: bool = False,
+         excludethresh: float = 4.0,
+         usebadpts: bool = False,
+         thesuffix: str = "25.0Hz",
+         modelpath: str = ".",
+         thedatadir: str = "/Users/frederic/Documents/MR_data/physioconn/timecourses",
+         inputfrag: str = "abc",
+         targetfrag: str = "xyz",
+         corrthresh: float = 0.5,
+         excludebysubject: bool = True,
+         startskip: int = 200,
+         endskip: int = 200,
+         step: int = 1,
+         namesuffix: str | None = None,
+         readlim: int | None = None,
+         readskip: int | None = None,
+         countlim: int | None = None,
+         **kwargs,
+     ) -> None:
+         """
+         Initialize the DeepLearningFilter with specified parameters.
+
+         This constructor sets up the configuration for a deep learning model used
+         for filtering physiological timecourses. It initializes various hyperparameters,
+         paths, and flags that control the behavior of the model and data processing.
+
+         Parameters
+         ----------
+         window_size : int, optional
+             Size of the sliding window used for processing time series data. Default is 128.
+         num_layers : int, optional
+             Number of layers in the neural network model. Default is 5.
+         dropout_rate : float, optional
+             Dropout rate for regularization during training. Default is 0.3.
+         num_pretrain_epochs : int, optional
+             Number of pre-training epochs. Default is 0.
+         num_epochs : int, optional
+             Number of training epochs. Default is 1.
+         activation : str, optional
+             Activation function to use in the model. Default is "relu".
+         modelroot : str, optional
+             Root directory for model storage. Default is ".".
+         dofft : bool, optional
+             Whether to apply FFT transformation to input data. Default is False.
+         excludethresh : float, optional
+             Amplitude threshold above which data segments are excluded. Default is 4.0.
+         usebadpts : bool, optional
+             Whether to include bad points in the input. Default is False.
+         thesuffix : str, optional
+             Suffix to append to filenames. Default is "25.0Hz".
+         modelpath : str, optional
+             Path to save or load the model. Default is ".".
+         thedatadir : str, optional
+             Directory containing the physiological data files. Default is
+             "/Users/frederic/Documents/MR_data/physioconn/timecourses".
+         inputfrag : str, optional
+             Fragment identifier for input data. Default is "abc".
+         targetfrag : str, optional
+             Fragment identifier for target data. Default is "xyz".
+         corrthresh : float, optional
+             Correlation threshold for filtering. Default is 0.5.
+         excludebysubject : bool, optional
+             Whether to exclude data by subject. Default is True.
+         startskip : int, optional
+             Number of samples to skip at the beginning of each timecourse. Default is 200.
+         endskip : int, optional
+             Number of samples to skip at the end of each timecourse. Default is 200.
+         step : int, optional
+             Step size for sliding window. Default is 1.
+         namesuffix : str, optional
+             Suffix to append to model name. Default is None.
+         readlim : int, optional
+             Limit on number of samples to read. Default is None.
+         readskip : int, optional
+             Number of samples to skip when reading data. Default is None.
+         countlim : int, optional
+             Limit on number of timecourses to process. Default is None.
+         **kwargs
+             Additional keyword arguments passed to the parent class.
+
+         Notes
+         -----
+         The `inputsize` is dynamically set based on the `usebadpts` flag:
+         - If `usebadpts` is True, input size is 2.
+         - Otherwise, input size is 1.
+
+         Examples
+         --------
+         >>> filter = DeepLearningFilter(
+         ...     window_size=256,
+         ...     num_layers=6,
+         ...     dropout_rate=0.2,
+         ...     modelroot="/models",
+         ...     dofft=True
+         ... )
+         """
+         self.window_size = window_size
+         self.dropout_rate = dropout_rate
+         self.num_pretrain_epochs = num_pretrain_epochs
+         self.num_epochs = num_epochs
+         self.usebadpts = usebadpts
+         self.num_layers = num_layers
+         if self.usebadpts:
+             self.inputsize = 2
+         else:
+             self.inputsize = 1
+         self.activation = activation
+         self.modelroot = modelroot
+         self.dofft = dofft
+         self.thesuffix = thesuffix
+         self.thedatadir = thedatadir
+         self.modelpath = modelpath
+         LGR.info(f"modeldir from DeepLearningFilter: {self.modelpath}")
+         self.corrthresh = corrthresh
+         self.excludethresh = excludethresh
+         self.readlim = readlim
+         self.readskip = readskip
+         self.countlim = countlim
+         self.model = None
+         self.initialized = False
+         self.trained = False
+         self.usetensorboard = False
+         self.inputfrag = inputfrag
+         self.targetfrag = targetfrag
+         self.namesuffix = namesuffix
+         self.startskip = startskip
+         self.endskip = endskip
+         self.step = step
+         self.excludebysubject = excludebysubject
+         self.device = device
+
+         # populate infodict
+         self.infodict["window_size"] = self.window_size
+         self.infodict["usebadpts"] = self.usebadpts
+         self.infodict["dofft"] = self.dofft
+         self.infodict["corrthresh"] = self.corrthresh
+         self.infodict["excludethresh"] = self.excludethresh
+         self.infodict["num_pretrain_epochs"] = self.num_pretrain_epochs
+         self.infodict["num_epochs"] = self.num_epochs
+         self.infodict["modelname"] = self.modelname
+         self.infodict["dropout_rate"] = self.dropout_rate
+         self.infodict["startskip"] = self.startskip
+         self.infodict["endskip"] = self.endskip
+         self.infodict["step"] = self.step
+         self.infodict["train_arch"] = sys.platform
+
+     def loaddata(self) -> None:
+         """
+         Load and preprocess data for training and validation.
+
+         This method initializes the data loading process by calling the `prep` function
+         with a set of parameters derived from the instance attributes. It handles both
+         FFT and non-FFT modes of data preprocessing. The loaded data is stored in
+         instance variables for use in subsequent training steps.
+
+         Parameters
+         ----------
+         self : object
+             The instance of the class containing the following attributes:
+             - initialized : bool
+                 Indicates whether the model has been initialized.
+             - dofft : bool
+                 Whether to apply FFT transformation to the data.
+             - window_size : int
+                 Size of the sliding window used for data segmentation.
+             - thesuffix : str
+                 Suffix to append to filenames when reading data.
+             - thedatadir : str
+                 Directory path where the data files are located.
+             - inputfrag : str
+                 Fragment identifier for input data.
+             - targetfrag : str
+                 Fragment identifier for target data.
+             - startskip : int
+                 Number of samples to skip at the beginning of each file.
+             - endskip : int
+                 Number of samples to skip at the end of each file.
+             - corrthresh : float
+                 Correlation threshold for filtering data.
+             - step : int
+                 Step size for sliding window.
+             - usebadpts : bool
+                 Whether to include bad points in the data.
+             - excludethresh : float
+                 Threshold for excluding data points.
+             - excludebysubject : bool
+                 Whether to exclude data by subject.
+             - readlim : int
+                 Limit on the number of samples to read.
+             - readskip : int
+                 Number of samples to skip while reading.
+             - countlim : int
+                 Limit on the number of data points to process.
+
+         Returns
+         -------
+         None
+             This method does not return any value. It modifies the instance attributes
+             in place.
+
+         Raises
+         ------
+         Exception
+             If the model is not initialized prior to calling this method.
+
+         Notes
+         -----
+         The method assigns the following attributes to the instance after loading:
+         - train_x : array-like
+             Training input data.
+         - train_y : array-like
+             Training target data.
+         - val_x : array-like
+             Validation input data.
+         - val_y : array-like
+             Validation target data.
+         - Ns : int
+             Number of samples.
+         - tclen : int
+             Length of time series.
+         - thebatchsize : int
+             Batch size for training.
+
+         Examples
+         --------
+         >>> model = MyModel()
+         >>> model.initialized = True
+         >>> model.loaddata()
+         >>> print(model.train_x.shape)
+         (1000, 10)
+         """
+         if not self.initialized:
+             raise Exception("model must be initialized prior to loading data")
+
+         if self.dofft:
+             (
+                 self.train_x,
+                 self.train_y,
+                 self.val_x,
+                 self.val_y,
+                 self.Ns,
+                 self.tclen,
+                 self.thebatchsize,
+                 dummy,
+                 dummy,
+             ) = prep(
+                 self.window_size,
+                 thesuffix=self.thesuffix,
+                 thedatadir=self.thedatadir,
+                 inputfrag=self.inputfrag,
+                 targetfrag=self.targetfrag,
+                 startskip=self.startskip,
+                 endskip=self.endskip,
+                 corrthresh=self.corrthresh,
+                 step=self.step,
+                 dofft=self.dofft,
+                 usebadpts=self.usebadpts,
+                 excludethresh=self.excludethresh,
+                 excludebysubject=self.excludebysubject,
+                 readlim=self.readlim,
+                 readskip=self.readskip,
+                 countlim=self.countlim,
+             )
+         else:
+             (
+                 self.train_x,
+                 self.train_y,
+                 self.val_x,
+                 self.val_y,
+                 self.Ns,
+                 self.tclen,
+                 self.thebatchsize,
+             ) = prep(
+                 self.window_size,
+                 thesuffix=self.thesuffix,
+                 thedatadir=self.thedatadir,
+                 inputfrag=self.inputfrag,
+                 targetfrag=self.targetfrag,
+                 startskip=self.startskip,
+                 endskip=self.endskip,
+                 corrthresh=self.corrthresh,
+                 step=self.step,
+                 dofft=self.dofft,
+                 usebadpts=self.usebadpts,
+                 excludethresh=self.excludethresh,
+                 excludebysubject=self.excludebysubject,
+                 readlim=self.readlim,
+                 readskip=self.readskip,
+                 countlim=self.countlim,
+             )
+
+     def predict_model(self, X: NDArray) -> NDArray:
+         """
+         Make predictions using the trained model.
+
+         Parameters
+         ----------
+         X : NDArray
+             Input features for prediction. Shape should be
+             (n_samples, window_size, n_channels), matching the windows built
+             by ``apply``.
+
+         Returns
+         -------
+         NDArray
+             Model predictions, returned with the same
+             (n_samples, window_size, n_channels) layout as the input.
+
+         Notes
+         -----
+         This method sets the model to inference mode via ``model.eval()``.
+         The predictions are made without computing gradients, making it efficient
+         for inference tasks. Input data is automatically converted to PyTorch tensors
+         and moved to the appropriate device. Special handling is included for
+         tensor dimension permutation to match model expectations.
+
+         Examples
+         --------
+         >>> # Assuming model is already trained
+         >>> X_test = np.random.randn(2, 128, 1).astype(np.float32)
+         >>> predictions = model.predict_model(X_test)
+         >>> print(predictions.shape)
+         (2, 128, 1)
+         """
+         self.model.eval()
+         with torch.no_grad():
+             if isinstance(X, np.ndarray):
+                 X = torch.from_numpy(X).float().to(self.device)
+             # PyTorch expects (batch, channels, length) but we have (batch, length, channels)
+             X = X.permute(0, 2, 1)
+             output = self.model(X)
+             # Convert back to (batch, length, channels)
+             output = output.permute(0, 2, 1)
+             return output.cpu().numpy()
+
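The permutation above is the one wrinkle in an otherwise thin wrapper: windows are built channel-last, but Conv1d consumes channel-second. A shape-only sketch of the round trip (hedged: `thefilter` stands in for any trained filter object from this module with window_size 128 and one input channel):

    import numpy as np
    X = np.random.randn(32, 128, 1).astype(np.float32)  # (batch, length, channels), as apply() builds it
    Y = thefilter.predict_model(X)                      # permuted to (batch, channels, length) internally, then back
    assert Y.shape == (32, 128, 1)                      # same layout coming out as going in
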
+     def evaluate(self) -> tuple[list, list, float, float]:
+         """
+         Evaluate the model performance on validation data and compute loss metrics.
+
+         This method performs model evaluation by computing prediction errors and
+         saving training/validation loss curves. It calculates both prediction error
+         (difference between predicted and actual values) and raw error (difference
+         between input and actual values). The method also generates and saves a
+         plot of the training and validation loss over epochs.
+
+         Parameters
+         ----------
+         self : object
+             The instance of the class containing the model and data attributes.
+
+         Returns
+         -------
+         tuple[list, list, float, float]
+             A tuple containing:
+             - training_loss : list
+                 List of training loss values per epoch
+             - validation_loss : list
+                 List of validation loss values per epoch
+             - prediction_error : float
+                 Mean squared error between predicted and actual values
+             - raw_error : float
+                 Mean squared error between input features and actual values
+
+         Notes
+         -----
+         This method modifies the instance attributes:
+         - self.lossfilename: Path to the saved loss plot
+         - self.pred_error: Computed prediction error
+         - self.raw_error: Computed raw error
+         - self.loss: Training loss history
+         - self.val_loss: Validation loss history
+
+         The method saves:
+         - Loss plot as PNG file
+         - Loss metrics as text file
+
+         Examples
+         --------
+         >>> model = MyModel()
+         >>> train_loss, val_loss, pred_error, raw_error = model.evaluate()
+         >>> print(f"Prediction Error: {pred_error}")
+         Prediction Error: 0.1234
+         """
+         self.lossfilename = os.path.join(self.modelname, "loss.png")
+         LGR.info(f"lossfilename: {self.lossfilename}")
+
+         YPred = self.predict_model(self.val_x)
+
+         error = self.val_y - YPred
+         self.pred_error = np.mean(np.square(error))
+
+         error2 = self.val_x - self.val_y
+         self.raw_error = np.mean(np.square(error2))
+         LGR.info(f"Prediction Error: {self.pred_error}\tRaw Error: {self.raw_error}")
+
+         f = open(os.path.join(self.modelname, "loss.txt"), "w")
+         f.write(
+             self.modelname
+             + ": Prediction Error: "
+             + str(self.pred_error)
+             + " Raw Error: "
+             + str(self.raw_error)
+             + "\n"
+         )
+         f.close()
+
+         epochs = range(len(self.loss))
+
+         self.updatemetadata()
+
+         plt.figure()
+         plt.plot(epochs, self.loss, "bo", label="Training loss")
+         plt.plot(epochs, self.val_loss, "b", label="Validation loss")
+         plt.title("Training and validation loss")
+         plt.legend()
+         plt.savefig(self.lossfilename)
+         plt.close()
+
+         return self.loss, self.val_loss, self.pred_error, self.raw_error
+
+     def initmetadata(self) -> None:
+         """
+         Initialize and store metadata information for the model.
+
+         This function creates a dictionary containing various model configuration parameters
+         and writes them to a JSON file for future reference and reproducibility.
+
+         Parameters
+         ----------
+         self : object
+             The instance of the class containing the metadata attributes. Expected to have
+             the following attributes:
+             - `nettype`: Type of neural network
+             - `window_size`: Size of the window used for processing
+             - `usebadpts`: Flag indicating whether bad points are handled
+             - `dofft`: Flag indicating whether FFT is used
+             - `excludethresh`: Threshold for exclusion
+             - `num_epochs`: Number of training epochs
+             - `num_layers`: Number of layers in the model
+             - `dropout_rate`: Dropout rate for regularization
+             - `modelname`: Name of the model
+
+         Returns
+         -------
+         None
+             This function does not return any value but writes metadata to a JSON file.
+
+         Notes
+         -----
+         The metadata includes:
+         - Window size for processing
+         - Bad point handling flag
+         - FFT usage flag
+         - Exclusion threshold
+         - Number of epochs and layers
+         - Dropout rate
+         - Operating system platform
+         - Model name
+
+         The metadata is saved to ``{modelname}/model_meta.json`` where ``modelname``
+         is the model's name attribute.
+
+         Examples
+         --------
+         >>> model = MyModel()
+         >>> model.initmetadata()
+         >>> # Metadata stored in modelname/model_meta.json
+         """
+
+         self.infodict = {}
+         self.infodict["nettype"] = self.nettype
+         self.infodict["window_size"] = self.window_size
+         self.infodict["usebadpts"] = self.usebadpts
+         self.infodict["dofft"] = self.dofft
+         self.infodict["excludethresh"] = self.excludethresh
+         self.infodict["num_epochs"] = self.num_epochs
+         self.infodict["num_layers"] = self.num_layers
+         self.infodict["dropout_rate"] = self.dropout_rate
+         self.infodict["train_arch"] = sys.platform
+         self.infodict["modelname"] = self.modelname
+         tide_io.writedicttojson(self.infodict, os.path.join(self.modelname, "model_meta.json"))
+
+     def updatemetadata(self) -> None:
+         """
+         Update metadata dictionary with model metrics and save to JSON file.
+
+         This method updates the internal information dictionary with various model
+         performance metrics and writes the complete metadata to a JSON file for
+         model persistence and tracking.
+
+         Parameters
+         ----------
+         self : object
+             The instance of the class containing the metadata and model information.
+             Expected to have the following attributes:
+             - infodict : dict
+                 Dictionary containing model metadata.
+             - loss : list
+                 Training loss history (one value per epoch).
+             - val_loss : list
+                 Validation loss history (one value per epoch).
+             - raw_error : float
+                 Raw error metric.
+             - pred_error : float
+                 Prediction error metric.
+             - modelname : str
+                 Name/path of the model for file output.
+
+         Returns
+         -------
+         None
+             This method does not return any value but modifies the `infodict` in-place
+             and writes to a JSON file.
+
+         Notes
+         -----
+         The method writes metadata to ``{modelname}/model_meta.json`` where
+         ``modelname`` is the model name attribute of the instance.
+
+         Examples
+         --------
+         >>> model = MyModel()
+         >>> model.updatemetadata()
+         >>> # Creates model_meta.json with loss, val_loss, raw_error, and pred_error
+         """
+         self.infodict["loss"] = self.loss
+         self.infodict["val_loss"] = self.val_loss
+         self.infodict["raw_error"] = self.raw_error
+         self.infodict["prediction_error"] = self.pred_error
+         tide_io.writedicttojson(self.infodict, os.path.join(self.modelname, "model_meta.json"))
+
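Taken together, initmetadata and updatemetadata leave model_meta.json shaped roughly like the dictionary below. The values here are purely illustrative, not taken from the diff (the release ships a real example at rapidtide/data/models/model_cnn_pytorch/model_meta.json):

    meta = {
        "nettype": "cnn",
        "window_size": 128,
        "usebadpts": False,
        "dofft": False,
        "excludethresh": 4.0,
        "num_epochs": 100,
        "num_layers": 5,
        "dropout_rate": 0.3,
        "train_arch": "linux",
        "modelname": "model_cnn_pytorch_example",  # hypothetical name
        # fields added by updatemetadata() after training:
        "loss": [0.9, 0.5, 0.4],
        "val_loss": [0.95, 0.6, 0.5],
        "raw_error": 1.1,
        "prediction_error": 0.4,
    }
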
+     def savemodel(self, altname: str | None = None) -> None:
+         """
+         Save the model to disk with the specified name.
+
+         This method saves the model's state dict, along with its configuration
+         (when the model provides ``get_config()``), to a PyTorch checkpoint
+         (``model.pth``) in a directory named according to the model name or an
+         alternative name provided.
+
+         Parameters
+         ----------
+         altname : str, optional
+             Alternative name to use for saving the model. If None, uses the
+             model's default name stored in `self.modelname`. Default is None.
+
+         Returns
+         -------
+         None
+             This method does not return any value.
+
+         Notes
+         -----
+         The checkpoint is written with ``torch.save`` as a dictionary holding
+         ``model_state_dict`` and ``model_config``, and is stored in a directory
+         with the same name as the model. The method logs the saving operation
+         using the logger instance `LGR`.
+
+         Examples
+         --------
+         >>> # Save model with default name
+         >>> savemodel()
+         >>>
+         >>> # Save model with alternative name
+         >>> savemodel(altname="my_custom_model")
+         """
+         if altname is None:
+             modelsavename = self.modelname
+         else:
+             modelsavename = altname
+         LGR.info(f"saving {modelsavename}")
+         torch.save(
+             {
+                 "model_state_dict": self.model.state_dict(),
+                 "model_config": (
+                     self.model.get_config() if hasattr(self.model, "get_config") else None
+                 ),
+             },
+             os.path.join(modelsavename, "model.pth"),
+         )
+
+     def loadmodel(self, modelname: str, verbose: bool = False) -> None:
+         """
+         Load a trained model from disk and initialize model parameters.
+
+         Load a PyTorch checkpoint from the specified model directory, along with
+         associated metadata and configuration information. The metadata is read
+         first so that the matching network architecture can be reconstructed
+         before the saved weights are restored.
+
+         Parameters
+         ----------
+         modelname : str
+             Name of the model to load, corresponding to a subdirectory in ``self.modelpath``.
+         verbose : bool, optional
+             If True, print the loaded metadata dictionary. Default is False.
+
+         Returns
+         -------
+         None
+             This method modifies the instance attributes in-place and does not return anything.
+
+         Notes
+         -----
+         The checkpoint is read from ``model.pth`` with ``torch.load``, using
+         ``map_location`` so that models trained on one device can be loaded on
+         another. The architecture is rebuilt from the ``nettype`` entry in the
+         metadata; if the nettype is not supported, the function exits with an
+         error message.
+
+         The loaded metadata is stored in ``self.infodict``, and the values in the
+         checkpoint's ``model_config`` are unpacked into instance attributes.
+         Additional attributes like ``window_size`` and ``usebadpts`` are extracted
+         from the metadata.
+
+         Examples
+         --------
+         >>> loader = ModelLoader()
+         >>> loader.loadmodel("my_model", verbose=True)
+         {'nettype': 'cnn', 'window_size': 100, ...}
+         >>> print(loader.window_size)
+         100
+         """
+         # read in the data
+         LGR.info(f"loading {modelname}")
+
+         # load additional information first to reconstruct model
+         self.infodict = tide_io.readdictfromjson(
+             os.path.join(self.modelpath, modelname, "model_meta.json")
+         )
+         if verbose:
+             print(self.infodict)
+         self.window_size = self.infodict["window_size"]
+         self.usebadpts = self.infodict["usebadpts"]
+
+         # Load the model as a dict
+         checkpoint = torch.load(
+             os.path.join(self.modelpath, modelname, "model.pth"), map_location=self.device
+         )
+
+         # Reconstruct the model architecture (must be done by subclass)
+         if self.infodict["nettype"] == "cnn":
+             self.num_filters = checkpoint["model_config"]["num_filters"]
+             self.kernel_size = checkpoint["model_config"]["kernel_size"]
+             self.num_layers = checkpoint["model_config"]["num_layers"]
+             self.dropout_rate = checkpoint["model_config"]["dropout_rate"]
+             self.dilation_rate = checkpoint["model_config"]["dilation_rate"]
+             self.activation = checkpoint["model_config"]["activation"]
+             self.inputsize = checkpoint["model_config"]["inputsize"]
+
+             self.model = CNNModel(
+                 self.num_filters,
+                 self.kernel_size,
+                 self.num_layers,
+                 self.dropout_rate,
+                 self.dilation_rate,
+                 self.activation,
+                 self.inputsize,
+             )
+         elif self.infodict["nettype"] == "autoencoder":
+             self.encoding_dim = checkpoint["model_config"]["encoding_dim"]
+             self.num_layers = checkpoint["model_config"]["num_layers"]
+             self.dropout_rate = checkpoint["model_config"]["dropout_rate"]
+             self.activation = checkpoint["model_config"]["activation"]
+             self.inputsize = checkpoint["model_config"]["inputsize"]
+
+             self.model = DenseAutoencoderModel(
+                 self.window_size,
+                 self.encoding_dim,
+                 self.num_layers,
+                 self.dropout_rate,
+                 self.activation,
+                 self.inputsize,
+             )
+         elif self.infodict["nettype"] == "multiscalecnn":
+             self.num_filters = checkpoint["model_config"]["num_filters"]
+             self.kernel_sizes = checkpoint["model_config"]["kernel_sizes"]
+             self.input_lens = checkpoint["model_config"]["input_lens"]
+             self.input_width = checkpoint["model_config"]["input_width"]
+             self.dilation_rate = checkpoint["model_config"]["dilation_rate"]
+
+             self.model = MultiscaleCNNModel(
+                 self.num_filters,
+                 self.kernel_sizes,
+                 self.input_lens,
+                 self.input_width,
+                 self.dilation_rate,
+             )
+         elif self.infodict["nettype"] == "convautoencoder":
+             self.encoding_dim = checkpoint["model_config"]["encoding_dim"]
+             self.num_filters = checkpoint["model_config"]["num_filters"]
+             self.kernel_size = checkpoint["model_config"]["kernel_size"]
+             self.dropout_rate = checkpoint["model_config"]["dropout_rate"]
+             self.activation = checkpoint["model_config"]["activation"]
+             self.inputsize = checkpoint["model_config"]["inputsize"]
+
+             self.model = ConvAutoencoderModel(
+                 self.window_size,
+                 self.encoding_dim,
+                 self.num_filters,
+                 self.kernel_size,
+                 self.dropout_rate,
+                 self.activation,
+                 self.inputsize,
+             )
+         elif self.infodict["nettype"] == "crnn":
+             self.num_filters = checkpoint["model_config"]["num_filters"]
+             self.kernel_size = checkpoint["model_config"]["kernel_size"]
+             self.encoding_dim = checkpoint["model_config"]["encoding_dim"]
+             self.dropout_rate = checkpoint["model_config"]["dropout_rate"]
+             self.activation = checkpoint["model_config"]["activation"]
+             self.inputsize = checkpoint["model_config"]["inputsize"]
+
+             self.model = CRNNModel(
+                 self.num_filters,
+                 self.kernel_size,
+                 self.encoding_dim,
+                 self.dropout_rate,
+                 self.activation,
+                 self.inputsize,
+             )
+         elif self.infodict["nettype"] == "lstm":
+             self.num_units = checkpoint["model_config"]["num_units"]
+             self.num_layers = checkpoint["model_config"]["num_layers"]
+             self.dropout_rate = checkpoint["model_config"]["dropout_rate"]
+             self.inputsize = checkpoint["model_config"]["inputsize"]
+
+             self.model = LSTMModel(
+                 self.num_units,
+                 self.num_layers,
+                 self.dropout_rate,
+                 self.window_size,
+                 self.inputsize,
+             )
+         elif self.infodict["nettype"] == "hybrid":
+             self.num_filters = checkpoint["model_config"]["num_filters"]
+             self.kernel_size = checkpoint["model_config"]["kernel_size"]
+             self.num_units = checkpoint["model_config"]["num_units"]
+             self.num_layers = checkpoint["model_config"]["num_layers"]
+             self.dropout_rate = checkpoint["model_config"]["dropout_rate"]
+             self.activation = checkpoint["model_config"]["activation"]
+             self.inputsize = checkpoint["model_config"]["inputsize"]
+             self.invert = checkpoint["model_config"]["invert"]
+
+             self.model = HybridModel(
+                 self.num_filters,
+                 self.kernel_size,
+                 self.num_units,
+                 self.num_layers,
+                 self.dropout_rate,
+                 self.activation,
+                 self.inputsize,
+                 self.window_size,
+                 self.invert,
+             )
+         else:
+             print(f"nettype {self.infodict['nettype']} is not supported!")
+             sys.exit()
+
+         self.model.load_state_dict(checkpoint["model_state_dict"])
+         self.model.to(self.device)
+
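The save/load pair defines a simple checkpoint contract: savemodel writes a dict with model_state_dict and model_config, and loadmodel rebuilds the matching architecture from model_config before restoring the weights. A minimal sketch of that contract in isolation (a standalone toy module, not rapidtide code):

    import torch
    import torch.nn as nn

    net = nn.Conv1d(1, 8, 5, padding="same")
    torch.save(
        {"model_state_dict": net.state_dict(), "model_config": {"inputsize": 1}},
        "model.pth",
    )
    checkpoint = torch.load("model.pth", map_location="cpu")
    # rebuild the architecture from the saved config, then restore the weights
    rebuilt = nn.Conv1d(checkpoint["model_config"]["inputsize"], 8, 5, padding="same")
    rebuilt.load_state_dict(checkpoint["model_state_dict"])
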
+     def initialize(self) -> None:
+         """
+         Initialize the model by setting up network architecture and metadata.
+
+         This method performs a series of initialization steps including retrieving
+         the model name, creating the network architecture, displaying model summary,
+         saving the model configuration, initializing metadata, and setting appropriate
+         flags to indicate initialization status.
+
+         Parameters
+         ----------
+         self : object
+             The instance of the model class being initialized.
+
+         Returns
+         -------
+         None
+             This method does not return any value.
+
+         Notes
+         -----
+         This method should be called before any training or prediction operations.
+         The initialization process sets `self.initialized` to True and `self.trained`
+         to False, indicating that the model is ready for training but has not been
+         trained yet.
+
+         Examples
+         --------
+         >>> model = MyModel()
+         >>> model.initialize()
+         >>> print(model.initialized)
+         True
+         >>> print(model.trained)
+         False
+         """
+         self.getname()
+         self.makenet()
+         print(self.model)
+         self.savemodel()
+         self.initmetadata()
+         self.initialized = True
+         self.trained = False
+
+     def train(self) -> None:
+         """
+         Train the model using the provided training and validation datasets.
+
+         This method performs model training with optional additional pretraining
+         epochs. It supports model checkpointing, early stopping, and termination on
+         NaN loss. The trained model is saved at the end of training.
+
+         Parameters
+         ----------
+         self : object
+             The instance of the class containing the model and training configuration.
+             Expected attributes include:
+             - `model`: PyTorch model to be trained.
+             - `train_x`, `train_y`, `val_x`, `val_y`: Training and validation data as numpy arrays.
+             - `device`: Device to run the training on (e.g., 'cpu' or 'cuda').
+             - `num_pretrain_epochs`: Number of pretraining epochs (default: 0).
+             - `num_epochs`: Number of main training epochs.
+             - `modelname`: Directory name to save model checkpoints and logs.
+             - `usetensorboard`: Boolean flag reserved for TensorBoard logging.
+             - `savemodel()`: Method to save the final trained model.
+
+         Returns
+         -------
+         None
+             This function does not return any value.
+
+         Notes
+         -----
+         - `self.usetensorboard` is accepted for configuration, but TensorBoard
+           logging is not wired into this training loop.
+         - `self.num_pretrain_epochs` is added to `self.num_epochs` to set the total
+           number of epochs run by the loop.
+         - The model is saved after training using the `savemodel()` method.
+         - Checkpointing, early stopping, and termination on NaN loss are implemented
+           directly in the training loop to manage the training process and prevent
+           overfitting or NaN issues.
+         - Intermediate model checkpoints are saved during training.
+         - The best model (based on validation loss) is retained and restored upon early stopping.
+
+         Examples
+         --------
+         >>> trainer = ModelTrainer(model, train_x, train_y, val_x, val_y)
+         >>> trainer.train()
+         """
+         self.model.train()
+         self.model.to(self.device)
+
+         # Convert numpy arrays to PyTorch tensors and transpose for Conv1d
+         print("converting tensors")
+         train_x_tensor = torch.from_numpy(self.train_x).float().permute(0, 2, 1)
+         train_y_tensor = torch.from_numpy(self.train_y).float().permute(0, 2, 1)
+         val_x_tensor = torch.from_numpy(self.val_x).float().permute(0, 2, 1)
+         val_y_tensor = torch.from_numpy(self.val_y).float().permute(0, 2, 1)
+
+         print("setting data")
+         train_dataset = TensorDataset(train_x_tensor, train_y_tensor)
+         val_dataset = TensorDataset(val_x_tensor, val_y_tensor)
+
+         train_loader = DataLoader(train_dataset, batch_size=1024, shuffle=True)
+         val_loader = DataLoader(val_dataset, batch_size=1024, shuffle=False)
+
+         print("setting criterion")
+         criterion = nn.MSELoss()
+
+         print("setting optimizer")
+         optimizer = optim.RMSprop(self.model.parameters())
+
+         self.loss = []
+         self.val_loss = []
+
+         best_val_loss = float("inf")
+         patience = 10
+         patience_counter = 0
+         nan_detected = False
+
+         total_epochs = self.num_pretrain_epochs + self.num_epochs
+
+         for epoch in range(total_epochs):
+             print(f"Epoch {epoch+1}/{total_epochs}")
+             # Training phase
+             self.model.train()
+             train_loss_epoch = 0.0
+             for batch_x, batch_y in tqdm.tqdm(
+                 train_loader,
+                 desc="Batch",
+                 unit="batches",
+                 disable=False,
+             ):
+                 batch_x, batch_y = batch_x.to(self.device), batch_y.to(self.device)
+
+                 optimizer.zero_grad()
+                 outputs = self.model(batch_x)
+                 loss = criterion(outputs, batch_y)
+
+                 if torch.isnan(loss):
+                     LGR.error("NaN loss detected, terminating training")
+                     nan_detected = True
+                     break
+
+                 loss.backward()
+                 optimizer.step()
+                 train_loss_epoch += loss.item()
+
+             if nan_detected:
+                 # a bare break above only leaves the batch loop; stop the epoch loop too
+                 break
+
+             train_loss_epoch /= len(train_loader)
+             self.loss.append(train_loss_epoch)
+
+             # Validation phase
+             self.model.eval()
+             val_loss_epoch = 0.0
+             with torch.no_grad():
+                 for batch_x, batch_y in val_loader:
+                     batch_x, batch_y = batch_x.to(self.device), batch_y.to(self.device)
+                     outputs = self.model(batch_x)
+                     loss = criterion(outputs, batch_y)
+                     val_loss_epoch += loss.item()
+
+             val_loss_epoch /= len(val_loader)
+             self.val_loss.append(val_loss_epoch)
+
+             LGR.info(
+                 f"Epoch {epoch+1}/{total_epochs} - Loss: {train_loss_epoch:.4f} - Val Loss: {val_loss_epoch:.4f}"
+             )
+
+             # Save checkpoint
+             self.intermediatemodelpath = os.path.join(
+                 self.modelname, f"model_e{epoch+1:02d}_v{val_loss_epoch:.4f}.pth"
+             )
+             torch.save(
+                 {
+                     "epoch": epoch,
+                     "model_state_dict": self.model.state_dict(),
+                     "optimizer_state_dict": optimizer.state_dict(),
+                     "loss": train_loss_epoch,
+                     "val_loss": val_loss_epoch,
+                 },
+                 self.intermediatemodelpath,
+             )
+
+             # Early stopping
+             if val_loss_epoch < best_val_loss:
+                 best_val_loss = val_loss_epoch
+                 patience_counter = 0
+                 # Save best model
+                 torch.save(self.model.state_dict(), os.path.join(self.modelname, "best_model.pth"))
+             else:
+                 patience_counter += 1
+                 if patience_counter >= patience:
+                     LGR.info(f"Early stopping triggered after {epoch+1} epochs")
+                     # Restore best weights
+                     self.model.load_state_dict(
+                         torch.load(
+                             os.path.join(self.modelname, "best_model.pth"), weights_only=True
+                         )
+                     )
+                     break
+         self.evaluate()
+
+         self.savemodel()
+         self.trained = True
+
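The early-stopping logic above is the standard patience pattern; stripped of the checkpointing bookkeeping it reduces to the following generic sketch (run_one_epoch is hypothetical, standing in for one pass of training plus validation):

    best, patience, wait = float("inf"), 10, 0
    for epoch in range(total_epochs):
        val = run_one_epoch()      # hypothetical: returns this epoch's validation loss
        if val < best:
            best, wait = val, 0    # improvement: reset the counter, keep this model
        else:
            wait += 1              # no improvement: burn one unit of patience
            if wait >= patience:
                break              # give up and fall back to the best saved weights
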
+     def apply(self, inputdata: NDArray, badpts: NDArray | None = None) -> NDArray:
+         """
+         Apply a sliding-window prediction model to the input data, optionally incorporating bad points.
+
+         This function performs a sliding-window prediction using a pre-trained model. It scales the input
+         data using the median absolute deviation (MAD), applies the model to overlapping windows of data,
+         and aggregates predictions with a weighted scheme. Optionally, bad points can be included in
+         the prediction process to influence the model's behavior.
+
+         Parameters
+         ----------
+         inputdata : NDArray
+             Input data array of shape (N,) to be processed.
+         badpts : NDArray | None, optional
+             Array of same shape as `inputdata` indicating bad or invalid points. If None, no bad points
+             are considered. Default is None.
+
+         Returns
+         -------
+         NDArray
+             Predicted data array of the same shape as `inputdata`, with predictions aggregated and
+             weighted across overlapping windows.
+
+         Notes
+         -----
+         - The function uses a sliding window of size `self.window_size` to process input data.
+         - Predictions are aggregated by summing over overlapping windows.
+         - A triangular weight scheme is applied to the aggregated predictions to reduce edge effects.
+         - If `self.usebadpts` is True, `badpts` are included as an additional feature in the model input.
+
+         Examples
+         --------
+         >>> model = MyModel(window_size=10, usebadpts=True)
+         >>> input_data = np.random.randn(100)
+         >>> bad_points = np.zeros_like(input_data)
+         >>> result = model.apply(input_data, bad_points)
+         """
+         initscale = mad(inputdata)
+         scaleddata = inputdata / initscale
+         predicteddata = scaleddata * 0.0
+         weightarray = scaleddata * 0.0
+         N_pts = len(scaleddata)
+         if self.usebadpts:
+             if badpts is None:
+                 badpts = scaleddata * 0.0
+             X = np.zeros(((N_pts - self.window_size - 1), self.window_size, 2))
+             for i in range(X.shape[0]):
+                 X[i, :, 0] = scaleddata[i : i + self.window_size]
+                 X[i, :, 1] = badpts[i : i + self.window_size]
+         else:
+             X = np.zeros(((N_pts - self.window_size - 1), self.window_size, 1))
+             for i in range(X.shape[0]):
+                 X[i, :, 0] = scaleddata[i : i + self.window_size]
+
+         Y = self.predict_model(X)
+         for i in range(X.shape[0]):
+             predicteddata[i : i + self.window_size] += Y[i, :, 0]
+
+         weightarray[:] = self.window_size
+         weightarray[0 : self.window_size] = np.linspace(
+             1.0, self.window_size, self.window_size, endpoint=False
+         )
+         weightarray[-(self.window_size + 1) : -1] = np.linspace(
+             self.window_size, 1.0, self.window_size, endpoint=False
+         )
+         return initscale * predicteddata / weightarray
+
+
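The weight array in apply compensates for the fact that interior samples are covered by roughly window_size overlapping windows while samples near the edges are covered by fewer. A small standalone numpy sketch of the same counting argument (not rapidtide code):

    import numpy as np

    N, w = 20, 5
    coverage = np.zeros(N)
    for i in range(N - w - 1):     # same window count as in apply()
        coverage[i : i + w] += 1   # each window contributes once per sample it covers
    print(coverage)                # ramps up at the left edge, flat w in the middle, ramps down at the right

Dividing the summed predictions by this coverage (approximated by the linear ramps in apply) is what keeps the output on the same scale as the input everywhere along the timecourse.
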
+ class CNNModel(nn.Module):
+     def __init__(
+         self,
+         num_filters: int,
+         kernel_size: int,
+         num_layers: int,
+         dropout_rate: float,
+         dilation_rate: int,
+         activation: str,
+         inputsize: int,
+     ) -> None:
+         """
+         Initialize the CNNModel with specified architecture parameters.
+
+         Parameters
+         ----------
+         num_filters : int
+             Number of convolutional filters in each layer.
+         kernel_size : int
+             Size of the convolutional kernel.
+         num_layers : int
+             Total number of layers in the network.
+         dropout_rate : float
+             Dropout rate for regularization.
+         dilation_rate : int
+             Dilation rate for dilated convolutions in intermediate layers.
+         activation : str
+             Activation function to use; options are 'relu' or 'tanh'.
+         inputsize : int
+             Number of input channels.
+
+         Returns
+         -------
+         None
+             This method initializes the model in-place and does not return any value.
+
+         Notes
+         -----
+         The model consists of an input layer, intermediate layers with dilated convolutions,
+         and an output layer. Batch normalization and dropout are applied after each convolutional
+         layer except the output layer. The activation function selected by the `activation`
+         parameter follows each of these blocks; unrecognized names fall back to ReLU.
+
+         Examples
+         --------
+         >>> model = CNNModel(
+         ...     num_filters=64,
+         ...     kernel_size=3,
+         ...     num_layers=4,
+         ...     dropout_rate=0.2,
+         ...     dilation_rate=2,
+         ...     activation="relu",
+         ...     inputsize=10
+         ... )
+         """
+         super(CNNModel, self).__init__()
+
+         self.num_filters = num_filters
+         self.kernel_size = kernel_size
+         self.num_layers = num_layers
+         self.dropout_rate = dropout_rate
+         self.dilation_rate = dilation_rate
+         self.activation = activation
+         self.inputsize = inputsize
+
+         self.layers = nn.ModuleList()
+
+         # Input layer
+         self.layers.append(nn.Conv1d(inputsize, num_filters, kernel_size, padding="same"))
+         self.layers.append(nn.BatchNorm1d(num_filters))
+         self.layers.append(nn.Dropout(dropout_rate))
+         if activation == "relu":
+             self.layers.append(nn.ReLU())
+         elif activation == "tanh":
+             self.layers.append(nn.Tanh())
+         else:
+             self.layers.append(nn.ReLU())
+
+         # Intermediate layers
+         for _ in range(num_layers - 2):
+             self.layers.append(
+                 nn.Conv1d(
+                     num_filters,
+                     num_filters,
+                     kernel_size,
+                     dilation=dilation_rate,
+                     padding="same",
+                 )
+             )
+             self.layers.append(nn.BatchNorm1d(num_filters))
+             self.layers.append(nn.Dropout(dropout_rate))
+             if activation == "relu":
+                 self.layers.append(nn.ReLU())
+             elif activation == "tanh":
+                 self.layers.append(nn.Tanh())
+             else:
+                 self.layers.append(nn.ReLU())
+
+         # Output layer
+         self.layers.append(nn.Conv1d(num_filters, inputsize, kernel_size, padding="same"))
+
+     def forward(self, x):
+         """
+         Forward pass through all layers.
+
+         Applies each layer in the network sequentially to the input tensor.
+
+         Parameters
+         ----------
+         x : torch.Tensor
+             Input tensor to the forward pass. Shape should be compatible with the
+             first layer's expected input dimensions.
+
+         Returns
+         -------
+         torch.Tensor
+             Output tensor after passing through all layers. Shape will depend on
+             the output dimensions of the last layer in the network.
+
+         Notes
+         -----
+         This method applies layers in the order they were added to the network.
+         Each layer's forward method is called sequentially, with the output of
+         one layer becoming the input to the next.
+
+         Examples
+         --------
+         >>> import torch
+         >>> model = CNNModel(10, 5, 5, 0.3, 1, "relu", 1)
+         >>> input_tensor = torch.randn(32, 1, 128)  # (batch, channels, length)
+         >>> output = model(input_tensor)
+         >>> print(output.shape)
+         torch.Size([32, 1, 128])
+         """
+         for layer in self.layers:
+             x = layer(x)
+         return x
+
+     def get_config(self):
+         """
+         Get the configuration parameters of the model.
+
+         Returns
+         -------
+         dict
+             A dictionary containing all configuration parameters with their current values:
+             - num_filters: int, number of filters in the convolutional layers
+             - kernel_size: int, size of the convolutional kernel
+             - num_layers: int, number of layers in the network
+             - dropout_rate: float, dropout rate for regularization
+             - dilation_rate: int, dilation rate for dilated convolution
+             - activation: str, activation function used in layers
+             - inputsize: int, number of input channels
+
+         Notes
+         -----
+         This method returns a copy of the current configuration. Modifications to the
+         returned dictionary will not affect the original model configuration.
+
+         Examples
+         --------
+         >>> config = model.get_config()
+         >>> print(config['num_filters'])
+         32
+         """
+         return {
+             "num_filters": self.num_filters,
+             "kernel_size": self.kernel_size,
+             "num_layers": self.num_layers,
+             "dropout_rate": self.dropout_rate,
+             "dilation_rate": self.dilation_rate,
+             "activation": self.activation,
+             "inputsize": self.inputsize,
+         }
+
+
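Because get_config returns exactly the keyword arguments the constructor takes, a model can be rebuilt from its saved configuration, which is what loadmodel relies on. A short sketch assuming the CNNModel class above:

    model = CNNModel(10, 5, 5, 0.3, 1, "relu", 1)
    config = model.get_config()
    rebuilt = CNNModel(**config)                  # same architecture from the saved config
    rebuilt.load_state_dict(model.state_dict())   # weights transfer because the shapes match
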
+ class CNNDLFilter(DeepLearningFilter):
+     def __init__(
+         self,
+         num_filters: int = 10,
+         kernel_size: int = 5,
+         dilation_rate: int = 1,
+         *args,
+         **kwargs,
+     ) -> None:
+         """
+         Initialize CNN deep learning filter.
+
+         Parameters
+         ----------
+         num_filters : int, optional
+             Number of convolutional filters to use, by default 10
+         kernel_size : int, optional
+             Size of the convolutional kernel, by default 5
+         dilation_rate : int, optional
+             Dilation rate for the convolutional layers, by default 1
+         *args
+             Variable length argument list passed to parent class
+         **kwargs
+             Arbitrary keyword arguments passed to parent class
+
+         Returns
+         -------
+         None
+             This method initializes the instance and does not return any value
+
+         Notes
+         -----
+         This constructor sets up the basic configuration for a CNN filter with
+         specified number of filters, kernel size, and dilation rate. The network
+         type is automatically set to "cnn" and information is stored in infodict
+         for later reference.
+
+         Examples
+         --------
+         >>> filter = CNNDLFilter(num_filters=32, kernel_size=3, dilation_rate=2)
+         >>> print(filter.num_filters)
+         32
+         """
+         self.num_filters = num_filters
+         self.kernel_size = kernel_size
+         self.dilation_rate = dilation_rate
+         self.nettype = "cnn"
+         self.infodict["nettype"] = self.nettype
+         self.infodict["num_filters"] = self.num_filters
+         self.infodict["kernel_size"] = self.kernel_size
+         super(CNNDLFilter, self).__init__(*args, **kwargs)
+
+     def getname(self):
+         """
+         Generate and configure the model name and path based on current parameters.
+
+         This method constructs a descriptive model name string using various instance
+         attributes and creates the corresponding directory path. The generated name
+         includes information about model architecture, hyperparameters, and configuration
+         options. The method also ensures the model directory exists by creating it if
+         necessary.
+
+         Parameters
+         ----------
+         self : object
+             The instance containing model configuration parameters.
+
+         Returns
+         -------
+         None
+             This method does not return a value but modifies instance attributes:
+             - self.modelname: Generated model name string
+             - self.modelpath: Full path to the model directory
+
+         Notes
+         -----
+         The generated model name follows a specific format:
+         "model_cnn_pytorch_wXXX_lYY_fnZZ_flZZ_eXXX_tT_ctC_sS_dD_activation[options]"
+
+         Where:
+         - XXX: window_size (3 digits, zero-padded)
+         - YY: num_layers (2 digits, zero-padded)
+         - ZZ: num_filters (2 digits, zero-padded)
+         - ZZ: kernel_size (2 digits, zero-padded)
+         - XXX: num_epochs (3 digits, zero-padded)
+         - T: excludethresh, formatted with str() (e.g. "4.0")
+         - C: corrthresh, formatted with str() (e.g. "0.5")
+         - S: step, formatted with str()
+         - D: dilation_rate, formatted with str()
+
+         Options are appended if corresponding boolean flags are True:
+         - _usebadpts: when usebadpts is True
+         - _excludebysubject: when excludebysubject is True
+
+         Examples
+         --------
+         >>> model = MyModel()
+         >>> model.window_size = 128
+         >>> model.num_layers = 3
+         >>> model.num_filters = 16
+         >>> model.kernel_size = 3
+         >>> model.num_epochs = 100
+         >>> model.excludethresh = 0.5
+         >>> model.corrthresh = 0.8
+         >>> model.step = 1
+         >>> model.dilation_rate = 2
+         >>> model.activation = "relu"
+         >>> model.usebadpts = True
+         >>> model.excludebysubject = False
+         >>> model.namesuffix = "test"
+         >>> model.getname()
+         >>> print(model.modelname)
+         'model_cnn_pytorch_w128_l03_fn16_fl03_e100_t0.5_ct0.8_s1_d2_relu_usebadpts_test'
+         """
+         self.modelname = "_".join(
+             [
+                 "model",
+                 "cnn",
+                 "pytorch",
+                 "w" + str(self.window_size).zfill(3),
+                 "l" + str(self.num_layers).zfill(2),
+                 "fn" + str(self.num_filters).zfill(2),
+                 "fl" + str(self.kernel_size).zfill(2),
+                 "e" + str(self.num_epochs).zfill(3),
+                 "t" + str(self.excludethresh),
+                 "ct" + str(self.corrthresh),
+                 "s" + str(self.step),
+                 "d" + str(self.dilation_rate),
+                 self.activation,
+             ]
+         )
+         if self.usebadpts:
+             self.modelname += "_usebadpts"
+         if self.excludebysubject:
+             self.modelname += "_excludebysubject"
+         if self.namesuffix is not None:
+             self.modelname += "_" + self.namesuffix
+         self.modelpath = os.path.join(self.modelroot, self.modelname)
+
+         try:
+             os.makedirs(self.modelpath)
+         except OSError:
+             pass
+
1459
+ def makenet(self):
1460
+ """
1461
+ Create and configure a CNN model for neural network training.
1462
+
1463
+ This method initializes a CNNModel with the specified parameters and moves
1464
+ it to the designated device (CPU or GPU). The model configuration is
1465
+ determined by the instance attributes set prior to calling this method.
1466
+
1467
+ Parameters
1468
+ ----------
1469
+ self : object
1470
+ The instance containing the following attributes:
1471
+ - num_filters : int
1472
+ Number of filters in each convolutional layer
1473
+ - kernel_size : int or tuple
1474
+ Size of the convolutional kernel
1475
+ - num_layers : int
1476
+ Number of convolutional layers in the network
1477
+ - dropout_rate : float
1478
+ Dropout rate for regularization
1479
+ - dilation_rate : int or tuple
1480
+ Dilation rate for dilated convolutions
1481
+ - activation : str or callable
1482
+ Activation function to use
1483
+ - inputsize : int
+ Number of input channels for the model
1485
+ - device : torch.device
1486
+ Device to move the model to (CPU or GPU)
1487
+
1488
+ Returns
1489
+ -------
1490
+ None
1491
+ This method does not return any value. It modifies the instance
1492
+ by setting the `model` attribute to the created CNNModel.
1493
+
1494
+ Notes
1495
+ -----
1496
+ The method assumes that all required attributes are properly initialized
1497
+ before calling. The model is automatically moved to the specified device
1498
+ using the `.to()` method.
1499
+
1500
+ Examples
1501
+ --------
1502
+ >>> # Assuming all required attributes are set on the filter instance
+ >>> thefilter.makenet()
+ >>> # Model is now available as thefilter.model
+ >>> print(thefilter.model)
+ CNNModel(...)
1507
+ """
1508
+ self.model = CNNModel(
1509
+ self.num_filters,
1510
+ self.kernel_size,
1511
+ self.num_layers,
1512
+ self.dropout_rate,
1513
+ self.dilation_rate,
1514
+ self.activation,
1515
+ self.inputsize,
1516
+ )
1517
+ self.model.to(self.device)
1518
+
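+ # --- Illustrative sketch (editorial example, not part of the released
+ # module): building the same CNNModel that CNNDLFilter.makenet() creates,
+ # outside the filter wrapper.  The positional signature is taken from the
+ # makenet() call above; the parameter values and the helper name are
+ # arbitrary examples.
+ def _example_standalone_cnnmodel(device="cpu"):
+     model = CNNModel(
+         10,  # num_filters
+         5,  # kernel_size
+         5,  # num_layers
+         0.3,  # dropout_rate
+         1,  # dilation_rate
+         "relu",  # activation
+         1,  # inputsize (number of input channels)
+     )
+     return model.to(device)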
1519
+
1520
+ class DenseAutoencoderModel(nn.Module):
1521
+ def __init__(
1522
+ self,
1523
+ window_size: int,
1524
+ encoding_dim: int,
1525
+ num_layers: int,
1526
+ dropout_rate: float,
1527
+ activation: str,
1528
+ inputsize: int,
1529
+ ) -> None:
1530
+ """
1531
+ Initialize a dense autoencoder model with configurable architecture.
1532
+
1533
+ This constructor builds a symmetric dense autoencoder with customizable number of layers,
1534
+ encoding dimension, dropout rate, and activation function. The model architecture follows
1535
+ a symmetric encoder-decoder structure, where the bottleneck layer reduces the feature space
1536
+ to `encoding_dim`, and then expands back to the original input size.
1537
+
1538
+ Parameters
1539
+ ----------
1540
+ window_size : int
1541
+ The size of the input window (number of time steps or features per sample).
1542
+ encoding_dim : int
1543
+ The dimension of the latent space (bottleneck layer size).
1544
+ num_layers : int
+ Number of linear layers in each half of the autoencoder: the encoder
+ (input layer through bottleneck) and the decoder each contain this many
+ linear layers.
1546
+ dropout_rate : float
1547
+ Dropout rate applied to all hidden layers to prevent overfitting.
1548
+ activation : str
1549
+ Activation function to use in hidden layers. Supported values: 'relu', 'tanh'.
1550
+ Defaults to 'relu' if an unsupported value is provided.
1551
+ inputsize : int
1552
+ The number of features per time step in the input data.
1553
+
1554
+ Returns
1555
+ -------
1556
+ None
1557
+ This method initializes the model in-place and does not return any value.
1558
+
1559
+ Notes
1560
+ -----
1561
+ - The model uses batch normalization after each linear layer except the output layer.
1562
+ - The architecture is symmetric: the number of neurons decreases from input to bottleneck
1563
+ and then increases back to the output size.
1564
+ - Logging is performed at various stages to track layer sizes during construction.
1565
+
1566
+ Examples
1567
+ --------
1568
+ >>> model = DenseAutoencoderModel(
1569
+ ... window_size=10,
1570
+ ... encoding_dim=5,
1571
+ ... num_layers=4,
1572
+ ... dropout_rate=0.2,
1573
+ ... activation="relu",
1574
+ ... inputsize=3
1575
+ ... )
1576
+ """
1577
+ super(DenseAutoencoderModel, self).__init__()
1578
+
1579
+ self.window_size = window_size
1580
+ self.encoding_dim = encoding_dim
1581
+ self.num_layers = num_layers
1582
+ self.dropout_rate = dropout_rate
1583
+ self.activation = activation
1584
+ self.inputsize = inputsize
1585
+
1586
+ self.layers = nn.ModuleList()
1587
+
1588
+ # Calculate initial size factor
1589
+ sizefac = 2
1590
+ for i in range(1, num_layers - 1):
1591
+ sizefac = int(sizefac * 2)
1592
+ LGR.info(f"input layer - sizefac: {sizefac}")
1593
+
1594
+ # Input layer
1595
+ self.layers.append(nn.Linear(window_size * inputsize, sizefac * encoding_dim))
1596
+ self.layers.append(nn.BatchNorm1d(sizefac * encoding_dim))
1597
+ self.layers.append(nn.Dropout(dropout_rate))
1598
+ if activation == "relu":
1599
+ self.layers.append(nn.ReLU())
1600
+ elif activation == "tanh":
1601
+ self.layers.append(nn.Tanh())
1602
+ else:
1603
+ self.layers.append(nn.ReLU())
1604
+
1605
+ # Encoding layers
1606
+ for i in range(1, num_layers - 1):
1607
+ sizefac = int(sizefac // 2)
1608
+ LGR.info(f"encoder layer {i + 1}, sizefac: {sizefac}")
1609
+ self.layers.append(nn.Linear(sizefac * 2 * encoding_dim, sizefac * encoding_dim))
1610
+ self.layers.append(nn.BatchNorm1d(sizefac * encoding_dim))
1611
+ self.layers.append(nn.Dropout(dropout_rate))
1612
+ if activation == "relu":
1613
+ self.layers.append(nn.ReLU())
1614
+ elif activation == "tanh":
1615
+ self.layers.append(nn.Tanh())
1616
+ else:
1617
+ self.layers.append(nn.ReLU())
1618
+
1619
+ # Encoding layer (bottleneck)
1620
+ sizefac = int(sizefac // 2)
1621
+ LGR.info(f"encoding layer - sizefac: {sizefac}")
1622
+ self.layers.append(nn.Linear(sizefac * 2 * encoding_dim, encoding_dim))
1623
+ self.layers.append(nn.BatchNorm1d(encoding_dim))
1624
+ self.layers.append(nn.Dropout(dropout_rate))
1625
+ if activation == "relu":
1626
+ self.layers.append(nn.ReLU())
1627
+ elif activation == "tanh":
1628
+ self.layers.append(nn.Tanh())
1629
+ else:
1630
+ self.layers.append(nn.ReLU())
1631
+
1632
+ # Decoding layers
1633
+ for i in range(1, num_layers):
1634
+ sizefac = int(sizefac * 2)
1635
+ LGR.info(f"decoding layer {i}, sizefac: {sizefac}")
1636
+ if i == 1:
1637
+ self.layers.append(nn.Linear(encoding_dim, sizefac * encoding_dim))
1638
+ else:
1639
+ self.layers.append(nn.Linear(sizefac // 2 * encoding_dim, sizefac * encoding_dim))
1640
+ self.layers.append(nn.BatchNorm1d(sizefac * encoding_dim))
1641
+ self.layers.append(nn.Dropout(dropout_rate))
1642
+ if activation == "relu":
1643
+ self.layers.append(nn.ReLU())
1644
+ elif activation == "tanh":
1645
+ self.layers.append(nn.Tanh())
1646
+ else:
1647
+ self.layers.append(nn.ReLU())
1648
+
1649
+ # Output layer
1650
+ self.layers.append(nn.Linear(sizefac * encoding_dim, window_size * inputsize))
1651
+
1652
+ def forward(self, x):
1653
+ """
1654
+ Forward pass through the network layers.
1655
+
1656
+ Applies a series of layers to the input tensor, flattening and reshaping
1657
+ as needed for processing.
1658
+
1659
+ Parameters
1660
+ ----------
1661
+ x : torch.Tensor
1662
+ Input tensor with shape (batch, channels, length) where batch is the
1663
+ number of samples, channels is the number of channels, and length is
1664
+ the sequence length.
1665
+
1666
+ Returns
1667
+ -------
1668
+ torch.Tensor
1669
+ Output tensor with shape (batch, channels, length) where the
1670
+ dimensions match the input shape after processing through the layers.
1671
+
1672
+ Notes
1673
+ -----
1674
+ The function first flattens the input from (batch, channels, length) to
1675
+ (batch, channels*length) to enable processing through linear layers,
1676
+ then reshapes the output back to the original format.
1677
+
1678
+ Examples
1679
+ --------
1680
+ >>> import torch
1681
+ >>> # Assuming self.layers contains appropriate layer definitions
1682
+ >>> x = torch.randn(32, 3, 100) # batch_size=32, channels=3, length=100
1683
+ >>> output = model.forward(x)
1684
+ >>> output.shape
1685
+ torch.Size([32, 3, 100])
1686
+ """
1687
+ # Flatten input from (batch, channels, length) to (batch, channels*length)
1688
+ batch_size = x.shape[0]
1689
+ x = x.reshape(batch_size, -1)
1690
+
1691
+ for layer in self.layers:
1692
+ x = layer(x)
1693
+
1694
+ # Reshape back to (batch, channels, length)
1695
+ x = x.reshape(batch_size, self.inputsize, self.window_size)
1696
+ return x
1697
+
1698
+ def get_config(self):
1699
+ """
1700
+ Get the configuration parameters of the model.
1701
+
1702
+ Returns
1703
+ -------
1704
+ dict
1705
+ A dictionary containing the model configuration parameters with the following keys:
1706
+ - "window_size" (int): The size of the sliding window used for input sequences
1707
+ - "encoding_dim" (int): The dimensionality of the encoding layer
1708
+ - "num_layers" (int): The number of layers in the model
1709
+ - "dropout_rate" (float): The dropout rate for regularization
1710
+ - "activation" (str): The activation function used in the model
1711
+ - "inputsize" (int): The size of the input features
1712
+
1713
+ Notes
1714
+ -----
1715
+ This method returns a copy of the internal configuration parameters. Modifications
1716
+ to the returned dictionary will not affect the original model configuration.
1717
+
1718
+ Examples
1719
+ --------
1720
+ >>> model = MyModel()
1721
+ >>> config = model.get_config()
1722
+ >>> print(config['window_size'])
1723
+ 10
1724
+ """
1725
+ return {
1726
+ "window_size": self.window_size,
1727
+ "encoding_dim": self.encoding_dim,
1728
+ "num_layers": self.num_layers,
1729
+ "dropout_rate": self.dropout_rate,
1730
+ "activation": self.activation,
1731
+ "inputsize": self.inputsize,
1732
+ }
1733
+
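+ # --- Illustrative sketch (editorial example, not part of the released
+ # module): a round-trip shape check for DenseAutoencoderModel, assuming the
+ # torch import at the top of this module.  Batch size must be > 1 so the
+ # BatchNorm1d layers can run in training mode; the helper name is
+ # hypothetical.
+ def _example_dense_autoencoder_roundtrip():
+     model = DenseAutoencoderModel(
+         window_size=10,
+         encoding_dim=5,
+         num_layers=4,
+         dropout_rate=0.2,
+         activation="relu",
+         inputsize=3,
+     )
+     x = torch.randn(8, 3, 10)  # (batch, inputsize, window_size)
+     y = model(x)
+     assert y.shape == x.shape  # symmetric encoder/decoder reconstruction
+     return y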
1734
+
1735
+ class DenseAutoencoderDLFilter(DeepLearningFilter):
1736
+ def __init__(self, encoding_dim: int = 10, *args, **kwargs) -> None:
1737
+ """
1738
+ Initialize DenseAutoencoderDLFilter instance.
1739
+
1740
+ Parameters
1741
+ ----------
1742
+ encoding_dim : int, default=10
1743
+ Dimension of the encoding layer in the autoencoder. This determines
1744
+ the size of the latent representation learned by the model.
1745
+ *args
1746
+ Variable length argument list passed to the parent class constructor.
1747
+ **kwargs
1748
+ Arbitrary keyword arguments passed to the parent class constructor.
1749
+
1750
+ Returns
1751
+ -------
1752
+ None
1753
+ This method initializes the instance and does not return any value.
1754
+
1755
+ Notes
1756
+ -----
1757
+ This constructor sets up the autoencoder architecture by:
1758
+ 1. Storing the encoding dimension as an instance attribute
1759
+ 2. Setting the network type to "autoencoder"
1760
+ 3. Updating the info dictionary with network type and encoding dimension
1761
+ 4. Calling the parent class constructor with any additional arguments
1762
+
1763
+ Examples
1764
+ --------
1765
+ >>> filter = DenseAutoencoderDLFilter(encoding_dim=15)
1766
+ >>> print(filter.encoding_dim)
1767
+ 15
1768
+ >>> print(filter.nettype)
1769
+ 'autoencoder'
1770
+ """
1771
+ self.encoding_dim = encoding_dim
1772
+ self.nettype = "autoencoder"
1773
+ self.infodict["nettype"] = self.nettype
1774
+ self.infodict["encoding_dim"] = self.encoding_dim
1775
+ super(DenseAutoencoderDLFilter, self).__init__(*args, **kwargs)
1776
+
1777
+ def getname(self):
1778
+ """
1779
+ Generate and return the model name and path based on current configuration parameters.
1780
+
1781
+ This method constructs a descriptive model name string using various configuration
1782
+ parameters and creates the corresponding directory path for model storage. The
1783
+ generated name includes information about window size, encoding dimensions, epochs,
1784
+ thresholds, step size, activation function, and additional flags.
1785
+
1786
+ Parameters
1787
+ ----------
1788
+ self : object
1789
+ The instance containing the following attributes:
1790
+ - window_size : int
1791
+ Size of the sliding window
1792
+ - encoding_dim : int
1793
+ Dimension of the encoding layer
1794
+ - num_epochs : int
1795
+ Number of training epochs
1796
+ - excludethresh : float
1797
+ Threshold for excluding data points
1798
+ - corrthresh : float
1799
+ Correlation threshold for filtering
1800
+ - step : int
1801
+ Step size for sliding window
1802
+ - activation : str
1803
+ Activation function name
1804
+ - usebadpts : bool
1805
+ Flag indicating whether to use bad points
1806
+ - excludebysubject : bool
1807
+ Flag indicating whether to exclude by subject
1808
+ - namesuffix : str, optional
1809
+ Additional suffix to append to model name
1810
+ - modelroot : str
1811
+ Root directory for model storage
1812
+
1813
+ Returns
1814
+ -------
1815
+ None
1816
+ This method modifies the instance attributes in-place:
1817
+ - self.modelname : str
1818
+ Generated model name string
1819
+ - self.modelpath : str
1820
+ Full path to the model directory
1821
+
1822
+ Notes
1823
+ -----
1824
+ The generated model name follows a specific format:
+ "model_denseautoencoder_pytorch_wXXX_enXXX_eXXX_tT_ctC_sS_activation[flags]"
+ where the XXX fields are zero-padded and T, C, and S are the string
+ representations of excludethresh, corrthresh, and step.
1826
+
1827
+ Examples
1828
+ --------
1829
+ >>> model = MyModel()
1830
+ >>> model.window_size = 100
1831
+ >>> model.encoding_dim = 50
1832
+ >>> model.num_epochs = 1000
1833
+ >>> model.excludethresh = 0.5
1834
+ >>> model.corrthresh = 0.8
1835
+ >>> model.step = 10
1836
+ >>> model.activation = 'relu'
1837
+ >>> model.usebadpts = True
1838
+ >>> model.excludebysubject = False
1839
+ >>> model.namesuffix = 'test'
1840
+ >>> model.modelroot = '/path/to/models'
1841
+ >>> model.getname()
1842
+ >>> print(model.modelname)
1843
+ 'model_denseautoencoder_pytorch_w100_en050_e1000_t0.5_ct0.8_s10_relu_usebadpts_test'
1844
+ """
1845
+ self.modelname = "_".join(
1846
+ [
1847
+ "model",
1848
+ "denseautoencoder",
1849
+ "pytorch",
1850
+ "w" + str(self.window_size).zfill(3),
1851
+ "en" + str(self.encoding_dim).zfill(3),
1852
+ "e" + str(self.num_epochs).zfill(3),
1853
+ "t" + str(self.excludethresh),
1854
+ "ct" + str(self.corrthresh),
1855
+ "s" + str(self.step),
1856
+ self.activation,
1857
+ ]
1858
+ )
1859
+ if self.usebadpts:
1860
+ self.modelname += "_usebadpts"
1861
+ if self.excludebysubject:
1862
+ self.modelname += "_excludebysubject"
1863
+ if self.namesuffix is not None:
1864
+ self.modelname += "_" + self.namesuffix
1865
+ self.modelpath = os.path.join(self.modelroot, self.modelname)
1866
+
1867
+ try:
1868
+ os.makedirs(self.modelpath)
1869
+ except OSError:
1870
+ pass
1871
+
1872
+ def makenet(self):
1873
+ """
1874
+ Create and configure a dense autoencoder model.
1875
+
1876
+ This method initializes a DenseAutoencoderModel with the specified parameters
1877
+ and moves it to the designated device (CPU or GPU).
1878
+
1879
+ Parameters
1880
+ ----------
1881
+ self : object
1882
+ The instance of the class containing this method. Expected to have the
1883
+ following attributes:
1884
+ - window_size : int
1885
+ - encoding_dim : int
1886
+ - num_layers : int
1887
+ - dropout_rate : float
1888
+ - activation : str or callable
1889
+ - inputsize : int
1890
+ - device : torch.device
1891
+
1892
+ Returns
1893
+ -------
1894
+ None
1895
+ This method does not return any value. It sets the model attribute
1896
+ of the instance to the created DenseAutoencoderModel.
1897
+
1898
+ Notes
1899
+ -----
1900
+ The model is automatically moved to the device specified by self.device.
1901
+ This method should be called after all required parameters have been set
1902
+ on the instance.
1903
+
1904
+ Examples
1905
+ --------
1906
+ >>> # Assuming a class with the makenet method
1907
+ >>> instance = MyClass()
1908
+ >>> instance.window_size = 100
1909
+ >>> instance.encoding_dim = 32
1910
+ >>> instance.num_layers = 3
1911
+ >>> instance.dropout_rate = 0.2
1912
+ >>> instance.activation = 'relu'
1913
+ >>> instance.inputsize = 100
1914
+ >>> instance.device = torch.device('cuda')
1915
+ >>> instance.makenet()
1916
+ >>> # Model is now available as instance.model
1917
+ """
1918
+ self.model = DenseAutoencoderModel(
1919
+ self.window_size,
1920
+ self.encoding_dim,
1921
+ self.num_layers,
1922
+ self.dropout_rate,
1923
+ self.activation,
1924
+ self.inputsize,
1925
+ )
1926
+ self.model.to(self.device)
1927
+
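+ # --- Illustrative sketch (editorial example, not part of the released
+ # module): how the getname() fields above compose.  zfill() pads the numeric
+ # fields, while the two thresholds and the step are embedded via str(), so
+ # fractional values appear verbatim (e.g. "t0.5").  The helper name is
+ # hypothetical.
+ def _example_autoencoder_modelname():
+     window_size, encoding_dim, num_epochs = 100, 50, 1000
+     excludethresh, corrthresh, step, activation = 0.5, 0.8, 10, "relu"
+     name = "_".join(
+         [
+             "model",
+             "denseautoencoder",
+             "pytorch",
+             "w" + str(window_size).zfill(3),
+             "en" + str(encoding_dim).zfill(3),
+             "e" + str(num_epochs).zfill(3),
+             "t" + str(excludethresh),
+             "ct" + str(corrthresh),
+             "s" + str(step),
+             activation,
+         ]
+     )
+     assert name == "model_denseautoencoder_pytorch_w100_en050_e1000_t0.5_ct0.8_s10_relu"
+     return name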
1928
+
1929
+ class MultiscaleCNNModel(nn.Module):
1930
+ def __init__(
1931
+ self,
1932
+ num_filters,
1933
+ kernel_sizes,
1934
+ input_lens,
1935
+ input_width,
1936
+ dilation_rate,
1937
+ ):
1938
+ """
1939
+ Initialize the MultiscaleCNNModel.
1940
+
1941
+ This constructor initializes a multiscale CNN model with three parallel branches
+ processing input at different scales. Each branch uses a different kernel size to
+ capture features at a different receptive field.
1944
+
1945
+ Parameters
1946
+ ----------
1947
+ num_filters : int
1948
+ Number of filters (channels) to use in the convolutional layers.
1949
+ kernel_sizes : list of int
1950
+ List of three kernel sizes for the three branches (small, medium, large scales).
1951
+ input_lens : list of int
1952
+ List of input lengths for each branch, corresponding to the input sequence lengths.
1953
+ input_width : int
1954
+ Width of the input features (number of input channels).
1955
+ dilation_rate : int
1956
+ Dilation rate to use in the dilated convolutional layers.
1957
+
1958
+ Returns
1959
+ -------
1960
+ None
1961
+ This method initializes the model instance and does not return any value.
1962
+
1963
+ Notes
1964
+ -----
1965
+ The model creates three parallel branches with different kernel sizes to capture
+ multi-scale temporal features. Note that although dilation_rate is stored and
+ reported by get_config, the branch convolutions as written do not apply it.
1968
+
1969
+ The final dense layer reduces the combined features to a single output value,
1970
+ followed by a sigmoid activation for binary classification.
1971
+
1972
+ Examples
1973
+ --------
1974
+ >>> model = MultiscaleCNNModel(
1975
+ ... num_filters=64,
1976
+ ... kernel_sizes=[3, 5, 7],
1977
+ ... input_lens=[100, 100, 100],
1978
+ ... input_width=10,
1979
+ ... dilation_rate=2
1980
+ ... )
1981
+ """
1982
+ super(MultiscaleCNNModel, self).__init__()
1983
+
1984
+ self.num_filters = num_filters
1985
+ self.kernel_sizes = kernel_sizes
1986
+ self.input_lens = input_lens
1987
+ self.input_width = input_width
1988
+ self.dilation_rate = dilation_rate
1989
+
1990
+ # Create three separate branches for different scales
1991
+ self.branch_small = self._make_branch(kernel_sizes[0])
1992
+ self.branch_med = self._make_branch(kernel_sizes[1])
1993
+ self.branch_large = self._make_branch(kernel_sizes[2])
1994
+
1995
+ # Final dense layer
1996
+ self.fc = nn.Linear(150, 1)
1997
+ self.sigmoid = nn.Sigmoid()
1998
+
1999
+ def _make_branch(self, kernel_size):
2000
+ """
2001
+ Create a convolutional branch for the neural network architecture.
2002
+
2003
+ This method constructs a sequential neural network branch consisting of
2004
+ convolutional, pooling, flattening, linear, activation, and dropout layers.
2005
+
2006
+ Parameters
2007
+ ----------
2008
+ kernel_size : int
2009
+ The size of the convolutional kernel to be used in the Conv1d layer.
2010
+
2011
+ Returns
2012
+ -------
2013
+ torch.nn.Sequential
2014
+ A sequential container containing the following layers:
2015
+ - Conv1d: 1D convolutional layer with input_width as input channels
2016
+ - AdaptiveMaxPool1d: 1D adaptive max pooling with output size 1
2017
+ - Flatten: Flattens the tensor for linear layer input
2018
+ - Linear: Linear layer with num_filters input features and 50 output features
2019
+ - Tanh: Hyperbolic tangent activation function
2020
+ - Dropout: 30% dropout regularization
2021
+
2022
+ Notes
2023
+ -----
2024
+ The branch is designed to process 1D input data through a convolutional
2025
+ feature extraction pathway followed by a fully connected classifier head.
2026
+ The padding="same" parameter ensures the output size matches the input size
2027
+ for the convolutional layer.
2028
+
2029
+ Examples
2030
+ --------
2031
+ >>> branch = self._make_branch(kernel_size=3)
2032
+ >>> print(type(branch))
2033
+ <class 'torch.nn.modules.container.Sequential'>
2034
+ """
2035
+ return nn.Sequential(
2036
+ nn.Conv1d(self.input_width, self.num_filters, kernel_size, padding="same"),
2037
+ nn.AdaptiveMaxPool1d(1),
2038
+ nn.Flatten(),
2039
+ nn.Linear(self.num_filters, 50),
2040
+ nn.Tanh(),
2041
+ nn.Dropout(0.3),
2042
+ )
2043
+
2044
+ def forward(self, x_small, x_med, x_large):
2045
+ """
2046
+ Forward pass of the multi-scale feature extraction network.
2047
+
2048
+ This function processes input tensors through three parallel branches with different
2049
+ receptive fields and concatenates the outputs before applying a final fully connected
2050
+ layer with sigmoid activation.
2051
+
2052
+ Parameters
2053
+ ----------
2054
+ x_small : torch.Tensor
+ Input tensor for the small-scale branch with shape (batch_size, input_width, length)
+ x_med : torch.Tensor
+ Input tensor for the medium-scale branch with shape (batch_size, input_width, length)
+ x_large : torch.Tensor
+ Input tensor for the large-scale branch with shape (batch_size, input_width, length)
2060
+
2061
+ Returns
2062
+ -------
2063
+ torch.Tensor
+ Output tensor with shape (batch_size, 1) containing a single
+ sigmoid-activated prediction per sample
2066
+
2067
+ Notes
2068
+ -----
2069
+ The function assumes that `self.branch_small`, `self.branch_med`, `self.branch_large`,
2070
+ `self.fc`, and `self.sigmoid` are properly initialized components of the class.
2071
+
2072
+ Examples
2073
+ --------
2074
+ >>> import torch
+ >>> # Assuming model is initialized with input_width=1
+ >>> x_small = torch.randn(1, 1, 64)
+ >>> x_med = torch.randn(1, 1, 128)
+ >>> x_large = torch.randn(1, 1, 192)
+ >>> output = model.forward(x_small, x_med, x_large)
+ >>> print(output.shape)
+ torch.Size([1, 1])
2082
+ """
2083
+ # Process each branch
2084
+ out_small = self.branch_small(x_small)
2085
+ out_med = self.branch_med(x_med)
2086
+ out_large = self.branch_large(x_large)
2087
+
2088
+ # Concatenate outputs
2089
+ merged = torch.cat([out_small, out_med, out_large], dim=1)
2090
+
2091
+ # Final output
2092
+ out = self.fc(merged)
2093
+ out = self.sigmoid(out)
2094
+
2095
+ return out
2096
+
2097
+ def get_config(self):
2098
+ """
2099
+ Get the configuration parameters of the layer.
2100
+
2101
+ Returns
2102
+ -------
2103
+ dict
2104
+ A dictionary containing the layer configuration parameters with the following keys:
2105
+ - "num_filters" (int): Number of filters in the layer
2106
+ - "kernel_sizes" (list of int): Size of the convolutional kernels
2107
+ - "input_lens" (list of int): Lengths of input sequences
2108
+ - "input_width" (int): Width of the input data
2109
+ - "dilation_rate" (int): Dilation rate for dilated convolution
2110
+
2111
+ Notes
2112
+ -----
2113
+ This method returns a copy of the internal configuration parameters
2114
+ that can be used to reconstruct the layer with the same settings.
2115
+
2116
+ Examples
2117
+ --------
2118
+ >>> config = layer.get_config()
2119
+ >>> print(config['num_filters'])
2120
+ 32
2121
+ """
2122
+ return {
2123
+ "num_filters": self.num_filters,
2124
+ "kernel_sizes": self.kernel_sizes,
2125
+ "input_lens": self.input_lens,
2126
+ "input_width": self.input_width,
2127
+ "dilation_rate": self.dilation_rate,
2128
+ }
2129
+
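+ # --- Illustrative sketch (editorial example, not part of the released
+ # module): feeding the three branches of MultiscaleCNNModel.  Each input is
+ # (batch, input_width, length); AdaptiveMaxPool1d(1) makes the branches
+ # length-agnostic, and the merged 3 x 50 features go through the final
+ # Linear(150, 1) + sigmoid.  The helper name is hypothetical.
+ def _example_multiscale_forward():
+     model = MultiscaleCNNModel(
+         num_filters=10,
+         kernel_sizes=[4, 8, 12],
+         input_lens=[64, 128, 192],
+         input_width=1,
+         dilation_rate=1,
+     )
+     x_small = torch.randn(2, 1, 64)
+     x_med = torch.randn(2, 1, 128)
+     x_large = torch.randn(2, 1, 192)
+     out = model(x_small, x_med, x_large)
+     assert out.shape == (2, 1)  # one sigmoid-activated value per sample
+     return out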
2130
+
2131
+ class MultiscaleCNNDLFilter(DeepLearningFilter):
2132
+ def __init__(
2133
+ self,
2134
+ num_filters=10,
2135
+ kernel_sizes=[4, 8, 12],
2136
+ input_lens=[64, 128, 192],
2137
+ input_width=1,
2138
+ dilation_rate=1,
2139
+ *args,
2140
+ **kwargs,
2141
+ ):
2142
+ """
2143
+ Initialize the MultiscaleCNNDLFilter layer.
2144
+
2145
+ This constructor initializes a multiscale CNN filter with configurable
2146
+ kernel sizes, input lengths, and dilation rates for multi-scale feature extraction.
2147
+
2148
+ Parameters
2149
+ ----------
2150
+ num_filters : int, optional
2151
+ Number of filters to use in each convolutional layer, default is 10
2152
+ kernel_sizes : list of int, optional
2153
+ List of kernel sizes for different convolutional layers, default is [4, 8, 12]
2154
+ input_lens : list of int, optional
2155
+ List of input sequence lengths for different scales, default is [64, 128, 192]
2156
+ input_width : int, optional
2157
+ Width of the input data, default is 1
2158
+ dilation_rate : int, optional
2159
+ Dilation rate for the convolutional layers, default is 1
2160
+ *args : tuple
2161
+ Variable length argument list passed to parent class
2162
+ **kwargs : dict
2163
+ Arbitrary keyword arguments passed to parent class
2164
+
2165
+ Returns
2166
+ -------
2167
+ None
2168
+ This method initializes the object and does not return any value
2169
+
2170
+ Notes
2171
+ -----
2172
+ The multiscale CNN filter processes the input through parallel convolutional
+ branches, each with a different kernel size and input length, to capture
+ features at multiple scales.
2176
+
2177
+ Examples
2178
+ --------
2179
+ >>> filter_layer = MultiscaleCNNDLFilter(
2180
+ ... num_filters=20,
2181
+ ... kernel_sizes=[3, 6, 9],
2182
+ ... input_lens=[32, 64, 128],
2183
+ ... input_width=2,
2184
+ ... dilation_rate=2
2185
+ ... )
2186
+ """
2187
+ self.num_filters = num_filters
2188
+ self.kernel_sizes = kernel_sizes
2189
+ self.input_lens = input_lens
2190
+ self.input_width = input_width
2191
+ self.dilation_rate = dilation_rate
2192
+ self.nettype = "multiscalecnn"
2193
+ self.infodict["nettype"] = self.nettype
2194
+ self.infodict["num_filters"] = self.num_filters
2195
+ self.infodict["kernel_sizes"] = self.kernel_sizes
2196
+ self.infodict["input_lens"] = self.input_lens
2197
+ self.infodict["input_width"] = self.input_width
2198
+ super(MultiscaleCNNDLFilter, self).__init__(*args, **kwargs)
2199
+
2200
+ def getname(self):
2201
+ """
2202
+ Generate and return the model name and path based on current configuration parameters.
2203
+
2204
+ This method constructs a descriptive model name string by joining various configuration
2205
+ parameters with specific prefixes and zero-padded numeric values. The resulting name
2206
+ is used to create a unique directory path for model storage.
2207
+
2208
+ Parameters
2209
+ ----------
2210
+ self : object
2211
+ The instance containing model configuration parameters.
2212
+
2213
+ Returns
2214
+ -------
2215
+ None
2216
+ This method does not return a value but sets the following attributes:
2217
+ - self.modelname: str, the generated model name
2218
+ - self.modelpath: str, the full path to the model directory
2219
+
2220
+ Notes
2221
+ -----
2222
+ The generated model name includes the following components:
2223
+ - Model type: "model_multiscalecnn_pytorch"
2224
+ - Window size: "w" + zero-padded window size
2225
+ - Number of layers: "l" + zero-padded layer count
2226
+ - Number of filters: "fn" + zero-padded filter count
2227
+ - First kernel size: "fl" + zero-padded kernel size
2228
+ - Number of epochs: "e" + zero-padded epoch count
2229
+ - Exclusion threshold: "t" + threshold value
2230
+ - Correlation threshold: "ct" + threshold value
2231
+ - Step size: "s" + step value (no zero-padding)
2232
+ - Dilation rate: "d" + dilation rate value
2233
+ - Activation function name
2234
+
2235
+ Additional suffixes are appended if:
2236
+ - usebadpts is True: "_usebadpts"
2237
+ - excludebysubject is True: "_excludebysubject"
2238
+ - namesuffix is not None: "_{namesuffix}"
2239
+
2240
+ Examples
2241
+ --------
2242
+ >>> model = MyModel()
2243
+ >>> model.window_size = 128
2244
+ >>> model.num_layers = 5
2245
+ >>> model.num_filters = 32
2246
+ >>> model.kernel_sizes = [3, 5, 7]
2247
+ >>> model.num_epochs = 100
2248
+ >>> model.excludethresh = 0.5
2249
+ >>> model.corrthresh = 0.8
2250
+ >>> model.step = 16
2251
+ >>> model.dilation_rate = 2
2252
+ >>> model.activation = "relu"
2253
+ >>> model.usebadpts = True
2254
+ >>> model.excludebysubject = False
2255
+ >>> model.namesuffix = "exp1"
2256
+ >>> model.getname()
2257
+ >>> print(model.modelname)
2258
+ 'model_multiscalecnn_pytorch_w128_l05_fn32_fl03_e100_t0.5_ct0.8_s16_d2_relu_usebadpts_exp1'
2259
+ """
2260
+ self.modelname = "_".join(
2261
+ [
2262
+ "model",
2263
+ "multiscalecnn",
2264
+ "pytorch",
2265
+ "w" + str(self.window_size).zfill(3),
2266
+ "l" + str(self.num_layers).zfill(2),
2267
+ "fn" + str(self.num_filters).zfill(2),
2268
+ "fl" + str(self.kernel_sizes[0]).zfill(2),
2269
+ "e" + str(self.num_epochs).zfill(3),
2270
+ "t" + str(self.excludethresh),
2271
+ "ct" + str(self.corrthresh),
2272
+ "s" + str(self.step),
2273
+ "d" + str(self.dilation_rate),
2274
+ self.activation,
2275
+ ]
2276
+ )
2277
+ if self.usebadpts:
2278
+ self.modelname += "_usebadpts"
2279
+ if self.excludebysubject:
2280
+ self.modelname += "_excludebysubject"
2281
+ if self.namesuffix is not None:
2282
+ self.modelname += "_" + self.namesuffix
2283
+ self.modelpath = os.path.join(self.modelroot, self.modelname)
2284
+
2285
+ try:
2286
+ os.makedirs(self.modelpath)
2287
+ except OSError:
2288
+ pass
2289
+
2290
+ def makenet(self):
2291
+ """
2292
+ Create and initialize a multiscale CNN model for network construction.
2293
+
2294
+ This method initializes a MultiscaleCNNModel with the specified parameters
2295
+ and moves the model to the designated device (CPU or GPU).
2296
+
2297
+ Parameters
2298
+ ----------
2299
+ self : object
2300
+ The instance containing the following attributes:
2301
+ - num_filters : int
2302
+ Number of filters for the CNN layers
2303
+ - kernel_sizes : list of int
2304
+ List of kernel sizes for different scales
2305
+ - input_lens : list of int
2306
+ List of input lengths for different scales
2307
+ - input_width : int
2308
+ Width of the input data
2309
+ - dilation_rate : int
2310
+ Dilation rate for the convolutional layers
2311
+ - device : torch.device
2312
+ Device to move the model to (e.g., 'cuda' or 'cpu')
2313
+
2314
+ Returns
2315
+ -------
2316
+ None
2317
+ This method does not return any value but modifies the instance
2318
+ by setting the `model` attribute to the created MultiscaleCNNModel.
2319
+
2320
+ Notes
2321
+ -----
2322
+ The method assumes that all required attributes are properly initialized
2323
+ in the instance before calling this method. The model is automatically
2324
+ moved to the specified device using the `.to()` method.
2325
+
2326
+ Examples
2327
+ --------
2328
+ >>> # Assuming instance with required attributes is created
2329
+ >>> instance.makenet()
2330
+ >>> print(instance.model)
2331
+ MultiscaleCNNModel(...)
2332
+ """
2333
+ self.model = MultiscaleCNNModel(
2334
+ self.num_filters,
2335
+ self.kernel_sizes,
2336
+ self.input_lens,
2337
+ self.input_width,
2338
+ self.dilation_rate,
2339
+ )
2340
+ self.model.to(self.device)
2341
+
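+ # --- Illustrative sketch (editorial example, not part of the released
+ # module): the infodict entries recorded by MultiscaleCNNDLFilter.__init__()
+ # above, assuming the default constructor arguments.  Note that
+ # dilation_rate is kept as an attribute but is not stored in infodict.
+ # {
+ #     "nettype": "multiscalecnn",
+ #     "num_filters": 10,
+ #     "kernel_sizes": [4, 8, 12],
+ #     "input_lens": [64, 128, 192],
+ #     "input_width": 1,
+ # }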
2342
+
2343
+ class ConvAutoencoderModel(nn.Module):
2344
+ def __init__(
2345
+ self,
2346
+ window_size,
2347
+ encoding_dim,
2348
+ num_filters,
2349
+ kernel_size,
2350
+ dropout_rate,
2351
+ activation,
2352
+ inputsize,
2353
+ ):
2354
+ """
2355
+ Initialize the ConvAutoencoderModel.
2356
+
2357
+ This class implements a convolutional autoencoder for time series data. The model
2358
+ consists of an encoder and a decoder, with symmetric architecture. The encoder
2359
+ reduces the input dimensionality through convolutional and pooling layers, while
2360
+ the decoder reconstructs the input from the encoded representation.
2361
+
2362
+ Parameters
2363
+ ----------
2364
+ window_size : int
2365
+ The length of the input time series window.
2366
+ encoding_dim : int
2367
+ The dimensionality of the latent space representation.
2368
+ num_filters : int
2369
+ The number of filters in the first convolutional layer.
2370
+ kernel_size : int
2371
+ The size of the convolutional kernels.
2372
+ dropout_rate : float
2373
+ The dropout rate applied after each convolutional layer.
2374
+ activation : str
2375
+ The activation function to use. Supported values are "relu" and "tanh".
2376
+ inputsize : int
2377
+ The number of input channels (e.g., number of features in the time series).
2378
+
2379
+ Returns
2380
+ -------
2381
+ None
2382
+ This method initializes the model in-place and does not return any value.
2383
+
2384
+ Notes
2385
+ -----
2386
+ The model uses a symmetric encoder-decoder architecture. The encoder reduces
2387
+ the input size through 4 max-pooling layers, and the decoder reconstructs
2388
+ the input using upsample and convolutional layers. The final layer uses
2389
+ a convolution with padding to match the input size.
2390
+
2391
+ Examples
2392
+ --------
2393
+ >>> model = ConvAutoencoderModel(
2394
+ ... window_size=100,
2395
+ ... encoding_dim=32,
2396
+ ... num_filters=32,
2397
+ ... kernel_size=3,
2398
+ ... dropout_rate=0.2,
2399
+ ... activation="relu",
2400
+ ... inputsize=1
2401
+ ... )
2402
+ """
2403
+ super(ConvAutoencoderModel, self).__init__()
2404
+
2405
+ self.window_size = window_size
2406
+ self.encoding_dim = encoding_dim
2407
+ self.num_filters = num_filters
2408
+ self.kernel_size = kernel_size
2409
+ self.dropout_rate = dropout_rate
2410
+ self.activation = activation
2411
+ self.inputsize = inputsize
2412
+
2413
+ # Get activation function
2414
+ if activation == "relu":
2415
+ act_fn = nn.ReLU
2416
+ elif activation == "tanh":
2417
+ act_fn = nn.Tanh
2418
+ else:
2419
+ act_fn = nn.ReLU
2420
+
2421
+ # Initial conv block
2422
+ self.encoder_layers = nn.ModuleList()
2423
+ self.encoder_layers.append(nn.Conv1d(inputsize, num_filters, kernel_size, padding="same"))
2424
+ self.encoder_layers.append(nn.BatchNorm1d(num_filters))
2425
+ self.encoder_layers.append(nn.Dropout(dropout_rate))
2426
+ self.encoder_layers.append(act_fn())
2427
+ self.encoder_layers.append(nn.MaxPool1d(2, padding=1))
2428
+
2429
+ # Encoding path (3 layers)
2430
+ nfilters = num_filters
2431
+ self.filter_list = []
2432
+ for _ in range(3):
2433
+ nfilters *= 2
2434
+ self.filter_list.append(nfilters)
2435
+ self.encoder_layers.append(
2436
+ nn.Conv1d(nfilters // 2, nfilters, kernel_size, padding="same")
2437
+ )
2438
+ self.encoder_layers.append(nn.BatchNorm1d(nfilters))
2439
+ self.encoder_layers.append(nn.Dropout(dropout_rate))
2440
+ self.encoder_layers.append(act_fn())
2441
+ self.encoder_layers.append(nn.MaxPool1d(2, padding=1))
2442
+
2443
+ # Calculate size after pooling
2444
+ self.encoded_size = window_size
2445
+ for _ in range(4): # 4 pooling layers
2446
+ # MaxPool1d(kernel_size=2, stride=2, padding=1) maps a length L to L // 2 + 1
+ self.encoded_size = self.encoded_size // 2 + 1
2447
+
2448
+ # Bottleneck
2449
+ self.flatten = nn.Flatten()
2450
+ self.encode_fc = nn.Linear(nfilters * self.encoded_size, encoding_dim)
2451
+ self.encode_act = act_fn()
2452
+ self.decode_fc = nn.Linear(encoding_dim, nfilters * self.encoded_size)
2453
+ self.decode_act = act_fn()
2454
+ self.unflatten_size = (nfilters, self.encoded_size)
2455
+
2456
+ # Decoding path (mirror)
2457
+ self.decoder_layers = nn.ModuleList()
2458
+ for i, filters in enumerate(reversed(self.filter_list)):
2459
+ self.decoder_layers.append(nn.Upsample(scale_factor=2, mode="nearest"))
2460
+ if i == 0:
2461
+ self.decoder_layers.append(
2462
+ nn.Conv1d(nfilters, filters, kernel_size, padding="same")
2463
+ )
2464
+ else:
2465
+ self.decoder_layers.append(
2466
+ nn.Conv1d(self.filter_list[-i], filters, kernel_size, padding="same")
2467
+ )
2468
+ self.decoder_layers.append(nn.BatchNorm1d(filters))
2469
+ self.decoder_layers.append(nn.Dropout(dropout_rate))
2470
+ self.decoder_layers.append(act_fn())
2471
+
2472
+ # Final upsampling
2473
+ self.decoder_layers.append(nn.Upsample(scale_factor=2, mode="nearest"))
2474
+ # The last decoder block above outputs num_filters * 2 channels, so the
+ # final convolution must accept that many input channels.
+ self.decoder_layers.append(nn.Conv1d(num_filters * 2, inputsize, kernel_size, padding="same"))
2475
+
2476
+ def forward(self, x):
2477
+ """
2478
+ Forward pass of the autoencoder.
2479
+
2480
+ Applies encoding, bottleneck processing, and decoding to the input tensor.
2481
+
2482
+ Parameters
2483
+ ----------
2484
+ x : torch.Tensor
+ Input tensor of shape (batch_size, inputsize, window_size).
2487
+
2488
+ Returns
2489
+ -------
2490
+ torch.Tensor
+ Reconstructed tensor of shape (batch_size, inputsize, window_size),
+ matching the input dimensions.
2493
+
2494
+ Notes
2495
+ -----
2496
+ The forward pass consists of three main stages:
2497
+ 1. Encoding: Input is passed through encoder layers
2498
+ 2. Bottleneck: Flattening, encoding, and decoding with activation functions
2499
+ 3. Decoding: Reconstructed features are passed through decoder layers
2500
+
2501
+ The output is cropped or padded to match the original window size.
2502
+
2503
+ Examples
2504
+ --------
2505
+ >>> import torch
2506
+ >>> model = AutoEncoder()
2507
+ >>> x = torch.randn(1, 3, 64, 64)
2508
+ >>> output = model.forward(x)
2509
+ >>> print(output.shape)
2510
+ torch.Size([1, 3, 64, 64])
2511
+ """
2512
+ # Encoding
2513
+ for layer in self.encoder_layers:
2514
+ x = layer(x)
2515
+
2516
+ # Bottleneck
2517
+ x = self.flatten(x)
2518
+ x = self.encode_fc(x)
2519
+ x = self.encode_act(x)
2520
+ x = self.decode_fc(x)
2521
+ x = self.decode_act(x)
2522
+ x = x.view(x.size(0), *self.unflatten_size)
2523
+
2524
+ # Decoding
2525
+ for layer in self.decoder_layers:
2526
+ x = layer(x)
2527
+
2528
+ # Crop/pad to original window size
2529
+ if x.size(2) > self.window_size:
2530
+ x = x[:, :, : self.window_size]
2531
+ elif x.size(2) < self.window_size:
2532
+ pad_size = self.window_size - x.size(2)
2533
+ x = nn.functional.pad(x, (0, pad_size))
2534
+
2535
+ return x
2536
+
2537
+ def get_config(self):
2538
+ """
2539
+ Get the configuration parameters of the model.
2540
+
2541
+ Returns
2542
+ -------
2543
+ dict
2544
+ A dictionary containing all configuration parameters with their current values:
2545
+ - "window_size" (int): Size of the sliding window
2546
+ - "encoding_dim" (int): Dimension of the encoding layer
2547
+ - "num_filters" (int): Number of filters in the convolutional layers
2548
+ - "kernel_size" (int): Size of the convolutional kernel
2549
+ - "dropout_rate" (float): Dropout rate for regularization
2550
+ - "activation" (str): Activation function to use
2551
+ - "inputsize" (int): Size of the input data
2552
+
2553
+ Notes
2554
+ -----
2555
+ This method returns a copy of the current configuration. Modifications to the
2556
+ returned dictionary will not affect the original model configuration.
2557
+
2558
+ Examples
2559
+ --------
2560
+ >>> config = model.get_config()
2561
+ >>> print(config['window_size'])
2562
+ 100
2563
+ """
2564
+ return {
2565
+ "window_size": self.window_size,
2566
+ "encoding_dim": self.encoding_dim,
2567
+ "num_filters": self.num_filters,
2568
+ "kernel_size": self.kernel_size,
2569
+ "dropout_rate": self.dropout_rate,
2570
+ "activation": self.activation,
2571
+ "inputsize": self.inputsize,
2572
+ }
2573
+
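+ # --- Illustrative sketch (editorial example, not part of the released
+ # module): a round-trip shape check for ConvAutoencoderModel, given the
+ # pooling-size bookkeeping above.  The decoder upsamples by 2 four times and
+ # then crops/pads back to window_size, so the output shape always matches
+ # the input.  Batch size must be > 1 for BatchNorm1d in training mode; the
+ # helper name is hypothetical.
+ def _example_conv_autoencoder_roundtrip():
+     model = ConvAutoencoderModel(
+         window_size=100,
+         encoding_dim=32,
+         num_filters=8,
+         kernel_size=5,
+         dropout_rate=0.2,
+         activation="relu",
+         inputsize=1,
+     )
+     x = torch.randn(4, 1, 100)  # (batch, inputsize, window_size)
+     y = model(x)
+     assert y.shape == x.shape
+     return y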
2574
+
2575
+ class ConvAutoencoderDLFilter(DeepLearningFilter):
2576
+ def __init__(
2577
+ self,
2578
+ encoding_dim: int = 10,
2579
+ num_filters: int = 5,
2580
+ kernel_size: int = 5,
2581
+ dilation_rate: int = 1,
2582
+ *args,
2583
+ **kwargs,
2584
+ ) -> None:
2585
+ """
2586
+ Initialize ConvAutoencoderDLFilter instance.
2587
+
2588
+ Parameters
2589
+ ----------
2590
+ encoding_dim : int, optional
2591
+ Dimension of the encoded representation, by default 10
2592
+ num_filters : int, optional
2593
+ Number of filters in the convolutional layers, by default 5
2594
+ kernel_size : int, optional
2595
+ Size of the convolutional kernel, by default 5
2596
+ dilation_rate : int, optional
2597
+ Dilation rate for the convolutional layers, by default 1
2598
+ *args
2599
+ Variable length argument list
2600
+ **kwargs
2601
+ Arbitrary keyword arguments
2602
+
2603
+ Returns
2604
+ -------
2605
+ None
2606
+ This method does not return any value
2607
+
2608
+ Notes
2609
+ -----
2610
+ This constructor initializes a convolutional autoencoder filter. The network
+ type is set to "convautoencoder" and various configuration parameters are
+ stored in the infodict for later reference. Note that dilation_rate is
+ stored but is not passed to the underlying ConvAutoencoderModel.
2613
+
2614
+ Examples
2615
+ --------
2616
+ >>> autoencoder = ConvAutoencoderDLFilter(
2617
+ ... encoding_dim=15,
2618
+ ... num_filters=8,
2619
+ ... kernel_size=3,
2620
+ ... dilation_rate=2
2621
+ ... )
2622
+ """
2623
+ self.encoding_dim = encoding_dim
2624
+ self.num_filters = num_filters
2625
+ self.kernel_size = kernel_size
2626
+ self.dilation_rate = dilation_rate
2627
+ self.nettype = "convautoencoder"
2628
+ self.infodict["num_filters"] = self.num_filters
2629
+ self.infodict["kernel_size"] = self.kernel_size
2630
+ self.infodict["nettype"] = self.nettype
2631
+ self.infodict["encoding_dim"] = self.encoding_dim
2632
+ super(ConvAutoencoderDLFilter, self).__init__(*args, **kwargs)
2633
+
2634
+ def getname(self):
2635
+ """
2636
+ Generate and configure the model name and path based on current parameters.
2637
+
2638
+ This method constructs a descriptive model name string using various
2639
+ configuration parameters and creates the corresponding directory path.
2640
+ The generated name includes information about window size, encoding dimensions,
2641
+ filters, kernel size, epochs, thresholds, and other model configuration options.
2642
+
2643
+ Parameters
2644
+ ----------
2645
+ self : object
2646
+ The instance of the class containing model configuration parameters.
2647
+ Expected attributes include:
2648
+ - window_size : int
2649
+ - encoding_dim : int
2650
+ - num_filters : int
2651
+ - kernel_size : int
2652
+ - num_epochs : int
2653
+ - excludethresh : float
2654
+ - corrthresh : float
2655
+ - step : int
2656
+ - activation : str
2657
+ - usebadpts : bool
2658
+ - excludebysubject : bool
2659
+ - namesuffix : str, optional
2660
+ - modelroot : str
2661
+
2662
+ Returns
2663
+ -------
2664
+ None
2665
+ This method does not return a value but modifies the instance attributes:
2666
+ - self.modelname : str
2667
+ - self.modelpath : str
2668
+
2669
+ Notes
2670
+ -----
2671
+ The model name is constructed with the following components:
2672
+ - "model_convautoencoder_pytorch" as base identifier
2673
+ - Window size with 3-digit zero-padded formatting (wXXX)
2674
+ - Encoding dimension with 3-digit zero-padded formatting (enXXX)
2675
+ - Number of filters with 2-digit zero-padded formatting (fnXX)
2676
+ - Kernel size with 2-digit zero-padded formatting (flXX)
2677
+ - Number of epochs with 3-digit zero-padded formatting (eXXX)
2678
+ - Exclusion threshold ("t" + string value, e.g. t0.5)
+ - Correlation threshold ("ct" + string value, e.g. ct0.8)
+ - Step size ("s" + string value)
2681
+ - Activation function name
2682
+
2683
+ Additional suffixes are appended based on:
2684
+ - usebadpts flag
2685
+ - excludebysubject flag
2686
+ - namesuffix parameter
2687
+
2688
+ Examples
2689
+ --------
2690
+ >>> model = MyModel()
2691
+ >>> model.window_size = 100
2692
+ >>> model.encoding_dim = 50
2693
+ >>> model.getname()
2694
+ >>> print(model.modelname)
2695
+ 'model_convautoencoder_pytorch_w100_en050_fn10_fl05_e001_t0.5_ct0.8_s1_relu'
2696
+ """
2697
+ self.modelname = "_".join(
2698
+ [
2699
+ "model",
2700
+ "convautoencoder",
2701
+ "pytorch",
2702
+ "w" + str(self.window_size).zfill(3),
2703
+ "en" + str(self.encoding_dim).zfill(3),
2704
+ "fn" + str(self.num_filters).zfill(2),
2705
+ "fl" + str(self.kernel_size).zfill(2),
2706
+ "e" + str(self.num_epochs).zfill(3),
2707
+ "t" + str(self.excludethresh),
2708
+ "ct" + str(self.corrthresh),
2709
+ "s" + str(self.step),
2710
+ self.activation,
2711
+ ]
2712
+ )
2713
+ if self.usebadpts:
2714
+ self.modelname += "_usebadpts"
2715
+ if self.excludebysubject:
2716
+ self.modelname += "_excludebysubject"
2717
+ if self.namesuffix is not None:
2718
+ self.modelname += "_" + self.namesuffix
2719
+ self.modelpath = os.path.join(self.modelroot, self.modelname)
2720
+
2721
+ try:
2722
+ os.makedirs(self.modelpath)
2723
+ except OSError:
2724
+ pass
2725
+
2726
+ def makenet(self):
2727
+ """
2728
+ Create and initialize a convolutional autoencoder model.
2729
+
2730
+ This method constructs a ConvAutoencoderModel with the specified parameters
2731
+ and moves it to the designated device (CPU or GPU).
2732
+
2733
+ Parameters
2734
+ ----------
2735
+ self : object
2736
+ The instance of the class containing this method. Expected to have the
2737
+ following attributes:
2738
+ - window_size : int
2739
+ Size of the input window
2740
+ - encoding_dim : int
2741
+ Dimension of the encoded representation
2742
+ - num_filters : int
2743
+ Number of filters in the convolutional layers
2744
+ - kernel_size : int
2745
+ Size of the convolutional kernel
2746
+ - dropout_rate : float
2747
+ Dropout rate for regularization
2748
+ - activation : str or callable
2749
+ Activation function to use
2750
+ - inputsize : int
+ Number of input channels
2752
+ - device : torch.device
2753
+ Device to move the model to (CPU or GPU)
2754
+
2755
+ Returns
2756
+ -------
2757
+ None
2758
+ This method does not return any value. It initializes the model
2759
+ attribute of the class instance.
2760
+
2761
+ Notes
2762
+ -----
2763
+ The method assumes that `ConvAutoencoderModel` is a valid class that accepts
2764
+ the specified parameters. The model is automatically moved to the device
2765
+ specified by `self.device`.
2766
+
2767
+ Examples
2768
+ --------
2769
+ >>> class MyModel:
2770
+ ... def __init__(self):
2771
+ ... self.window_size = 100
2772
+ ... self.encoding_dim = 32
2773
+ ... self.num_filters = 64
2774
+ ... self.kernel_size = 3
2775
+ ... self.dropout_rate = 0.2
2776
+ ... self.activation = 'relu'
2777
+ ... self.inputsize = 1
2778
+ ... self.device = torch.device('cpu')
2779
+ ... self.model = None
2780
+ ...
2781
+ ... def makenet(self):
2782
+ ... self.model = ConvAutoencoderModel(
2783
+ ... self.window_size,
2784
+ ... self.encoding_dim,
2785
+ ... self.num_filters,
2786
+ ... self.kernel_size,
2787
+ ... self.dropout_rate,
2788
+ ... self.activation,
2789
+ ... self.inputsize,
2790
+ ... )
2791
+ ... self.model.to(self.device)
2792
+ ...
2793
+ >>> model = MyModel()
2794
+ >>> model.makenet()
2795
+ >>> print(model.model)
2796
+ """
2797
+ self.model = ConvAutoencoderModel(
2798
+ self.window_size,
2799
+ self.encoding_dim,
2800
+ self.num_filters,
2801
+ self.kernel_size,
2802
+ self.dropout_rate,
2803
+ self.activation,
2804
+ self.inputsize,
2805
+ )
2806
+ self.model.to(self.device)
2807
+
2808
+
2809
+ class CRNNModel(nn.Module):
2810
+ def __init__(
2811
+ self, num_filters, kernel_size, encoding_dim, dropout_rate, activation, inputsize
2812
+ ):
2813
+ """
2814
+ Initialize the CRNNModel.
2815
+
2816
+ This function initializes a Convolutional Recurrent Neural Network (CRNN) model
2817
+ with convolutional front-end, bidirectional LSTM layers, and output mapping.
2818
+ The model processes sequential data through convolutional layers, applies
2819
+ bidirectional LSTM encoding, and maps the output back to the original input size.
2820
+
2821
+ Parameters
2822
+ ----------
2823
+ num_filters : int
2824
+ Number of filters in the convolutional layers
2825
+ kernel_size : int
2826
+ Size of the convolutional kernel
2827
+ encoding_dim : int
2828
+ Dimension of the LSTM encoding (hidden state size)
2829
+ dropout_rate : float
2830
+ Dropout rate for regularization
2831
+ activation : str
2832
+ Activation function to use ('relu' or 'tanh')
2833
+ inputsize : int
2834
+ Size of the input features
2835
+
2836
+ Returns
2837
+ -------
2838
+ None
2839
+ Initializes the CRNNModel instance
2840
+
2841
+ Notes
2842
+ -----
2843
+ The model uses a bidirectional LSTM with batch_first=True.
2844
+ The convolutional layers use 'same' padding to maintain sequence length.
2845
+ Default activation function is ReLU if an invalid activation is provided.
2846
+
2847
+ Examples
2848
+ --------
2849
+ >>> model = CRNNModel(
2850
+ ... num_filters=32,
2851
+ ... kernel_size=3,
2852
+ ... encoding_dim=64,
2853
+ ... dropout_rate=0.2,
2854
+ ... activation='relu',
2855
+ ... inputsize=128
2856
+ ... )
2857
+ """
2858
+ super(CRNNModel, self).__init__()
2859
+
2860
+ self.num_filters = num_filters
2861
+ self.kernel_size = kernel_size
2862
+ self.encoding_dim = encoding_dim
2863
+ self.dropout_rate = dropout_rate
2864
+ self.activation = activation
2865
+ self.inputsize = inputsize
2866
+
2867
+ # Get activation function
2868
+ if activation == "relu":
2869
+ act_fn = nn.ReLU
2870
+ elif activation == "tanh":
2871
+ act_fn = nn.Tanh
2872
+ else:
2873
+ act_fn = nn.ReLU
2874
+
2875
+ # Convolutional front-end
2876
+ self.conv1 = nn.Conv1d(inputsize, num_filters, kernel_size, padding="same")
2877
+ self.bn1 = nn.BatchNorm1d(num_filters)
2878
+ self.dropout1 = nn.Dropout(dropout_rate)
2879
+ self.act1 = act_fn()
2880
+
2881
+ self.conv2 = nn.Conv1d(num_filters, num_filters * 2, kernel_size, padding="same")
2882
+ self.bn2 = nn.BatchNorm1d(num_filters * 2)
2883
+ self.dropout2 = nn.Dropout(dropout_rate)
2884
+ self.act2 = act_fn()
2885
+
2886
+ # Bidirectional LSTM
2887
+ self.lstm = nn.LSTM(num_filters * 2, encoding_dim, batch_first=True, bidirectional=True)
2888
+
2889
+ # Output mapping
2890
+ self.fc_out = nn.Linear(encoding_dim * 2, inputsize)
2891
+
2892
+ def forward(self, x):
2893
+ """
2894
+ Forward pass through the neural network architecture.
2895
+
2896
+ This function processes input data through a convolutional neural network
2897
+ followed by an LSTM layer and a fully connected output layer. The input
2898
+ is first processed through two convolutional blocks, then reshaped for
2899
+ LSTM processing, and finally converted back to the original output format.
2900
+
2901
+ Parameters
2902
+ ----------
2903
+ x : torch.Tensor
2904
+ Input tensor of shape (batch_size, channels, length) containing
2905
+ the input sequence data to be processed.
2906
+
2907
+ Returns
2908
+ -------
2909
+ torch.Tensor
2910
+ Output tensor of shape (batch_size, channels, length) containing
2911
+ the processed sequence data after passing through all layers.
2912
+
2913
+ Notes
2914
+ -----
2915
+ The function performs the following operations in sequence:
2916
+ 1. Two convolutional blocks with batch normalization, dropout, and activation
2917
+ 2. Permute operation to reshape data for LSTM processing (batch, seq_len, features)
2918
+ 3. LSTM layer processing
2919
+ 4. Fully connected output layer
2920
+ 5. Final permutation to restore original shape (batch, channels, length)
2921
+
2922
+ Examples
2923
+ --------
2924
+ >>> import torch
2925
+ >>> model = YourModelClass()
2926
+ >>> x = torch.randn(32, 1, 100) # batch_size=32, channels=1, length=100
2927
+ >>> output = model.forward(x)
2928
+ >>> print(output.shape) # torch.Size([32, 1, 100])
2929
+ """
2930
+ # Conv layers expect (batch, channels, length)
2931
+ x = self.conv1(x)
2932
+ x = self.bn1(x)
2933
+ x = self.dropout1(x)
2934
+ x = self.act1(x)
2935
+
2936
+ x = self.conv2(x)
2937
+ x = self.bn2(x)
2938
+ x = self.dropout2(x)
2939
+ x = self.act2(x)
2940
+
2941
+ # LSTM expects (batch, seq_len, features)
2942
+ x = x.permute(0, 2, 1)
2943
+ x, _ = self.lstm(x)
2944
+
2945
+ # Output layer
2946
+ x = self.fc_out(x)
2947
+
2948
+ # Convert back to (batch, channels, length)
2949
+ x = x.permute(0, 2, 1)
2950
+
2951
+ return x
2952
+
2953
+ def get_config(self):
2954
+ """
2955
+ Get the configuration parameters of the model.
2956
+
2957
+ Returns
2958
+ -------
2959
+ dict
2960
+ A dictionary containing the model configuration parameters with the following keys:
2961
+ - "num_filters" (int): Number of filters in the convolutional layers
2962
+ - "kernel_size" (int): Size of the convolutional kernel
2963
+ - "encoding_dim" (int): Dimension of the encoding layer
2964
+ - "dropout_rate" (float): Dropout rate for regularization
2965
+ - "activation" (str): Activation function used in the layers
2966
+ - "inputsize" (int): Size of the input data
2967
+
2968
+ Notes
2969
+ -----
2970
+ This method returns a copy of the current configuration parameters. Modifications
2971
+ to the returned dictionary will not affect the original model configuration.
2972
+
2973
+ Examples
2974
+ --------
2975
+ >>> model = MyModel()
2976
+ >>> config = model.get_config()
2977
+ >>> print(config['num_filters'])
2978
+ 32
2979
+ """
2980
+ return {
2981
+ "num_filters": self.num_filters,
2982
+ "kernel_size": self.kernel_size,
2983
+ "encoding_dim": self.encoding_dim,
2984
+ "dropout_rate": self.dropout_rate,
2985
+ "activation": self.activation,
2986
+ "inputsize": self.inputsize,
2987
+ }
2988
+
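+ # --- Illustrative sketch (editorial example, not part of the released
+ # module): a shape check for CRNNModel.  The bidirectional LSTM doubles the
+ # feature dimension (encoding_dim * 2), which fc_out maps back to inputsize
+ # channels, so sequence length and channel count are preserved.  The helper
+ # name is hypothetical.
+ def _example_crnn_forward():
+     model = CRNNModel(
+         num_filters=16,
+         kernel_size=3,
+         encoding_dim=32,
+         dropout_rate=0.2,
+         activation="relu",
+         inputsize=1,
+     )
+     x = torch.randn(4, 1, 100)  # (batch, inputsize, length)
+     y = model(x)
+     assert y.shape == x.shape
+     return y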
2989
+
2990
+ class CRNNDLFilter(DeepLearningFilter):
2991
+ def __init__(
2992
+ self,
2993
+ encoding_dim: int = 10,
2994
+ num_filters: int = 10,
2995
+ kernel_size: int = 5,
2996
+ dilation_rate: int = 1,
2997
+ *args,
2998
+ **kwargs,
2999
+ ) -> None:
3000
+ """
3001
+ Initialize CRNNDLFilter layer.
3002
+
3003
+ Parameters
3004
+ ----------
3005
+ encoding_dim : int, optional
3006
+ Dimension of the encoding layer, by default 10
3007
+ num_filters : int, optional
3008
+ Number of filters in the convolutional layer, by default 10
3009
+ kernel_size : int, optional
3010
+ Size of the convolutional kernel, by default 5
3011
+ dilation_rate : int, optional
3012
+ Dilation rate for the convolutional layer, by default 1
3013
+ *args : tuple
3014
+ Variable length argument list
3015
+ **kwargs : dict
3016
+ Arbitrary keyword arguments
3017
+
3018
+ Returns
3019
+ -------
3020
+ None
3021
+ This method does not return any value
3022
+
3023
+ Notes
3024
+ -----
3025
+ This constructor initializes a CRNN (Convolutional Recurrent Neural Network)
+ filter. The network type is set to "crnn" and configuration parameters are
+ stored in infodict for later reference. Note that dilation_rate is stored
+ but is not passed to the underlying CRNNModel.
3028
+
3029
+ Examples
3030
+ --------
3031
+ >>> layer = CRNNDLFilter(encoding_dim=20, num_filters=15, kernel_size=3)
3032
+ >>> print(layer.nettype)
3033
+ 'crnn'
3034
+ """
3035
+ self.num_filters = num_filters
3036
+ self.kernel_size = kernel_size
3037
+ self.dilation_rate = dilation_rate
3038
+ self.encoding_dim = encoding_dim
3039
+ self.nettype = "crnn"
3040
+ self.infodict["nettype"] = self.nettype
3041
+ self.infodict["num_filters"] = self.num_filters
3042
+ self.infodict["kernel_size"] = self.kernel_size
3043
+ self.infodict["encoding_dim"] = self.encoding_dim
3044
+ super(CRNNDLFilter, self).__init__(*args, **kwargs)
3045
+
3046
+ def getname(self):
3047
+ """
3048
+ Generate and configure model name and path based on configuration parameters.
3049
+
3050
+ This method constructs a descriptive model name string based on various configuration
3051
+ parameters and creates the corresponding model directory path. The generated name
3052
+ includes information about window size, encoding dimensions, filters, kernel size,
3053
+ epochs, thresholds, step size, and activation function.
3054
+
3055
+ Parameters
3056
+ ----------
3057
+ self : object
3058
+ The instance containing configuration parameters for model naming.
3059
+
3060
+ Returns
3061
+ -------
3062
+ None
3063
+ This method modifies instance attributes in-place and does not return a value.
3064
+
3065
+ Notes
3066
+ -----
3067
+ The generated model name follows a consistent naming convention:
+ 'model_crnn_pytorch_wXXX_enXXX_fnXX_flXX_eXXX_tT_ctC_sS_activation'
+ where the XXX/XX fields are zero-padded numeric values and T, C, and S are
+ the string representations of excludethresh, corrthresh, and step.
3070
+
3071
+ Additional suffixes are appended based on:
3072
+ - usebadpts: '_usebadpts' if True
3073
+ - excludebysubject: '_excludebysubject' if True
3074
+ - namesuffix: '_{suffix}' if not None
3075
+
3076
+ Examples
3077
+ --------
3078
+ >>> model = ModelClass()
3079
+ >>> model.window_size = 100
3080
+ >>> model.encoding_dim = 128
3081
+ >>> model.num_filters = 32
3082
+ >>> model.kernel_size = 5
3083
+ >>> model.num_epochs = 100
3084
+ >>> model.excludethresh = 0.5
3085
+ >>> model.corrthresh = 0.8
3086
+ >>> model.step = 10
3087
+ >>> model.activation = 'relu'
3088
+ >>> model.modelroot = '/path/to/models'
3089
+ >>> model.getname()
3090
+ >>> print(model.modelname)
3091
+ 'model_crnn_pytorch_w100_en128_fn32_fl05_e100_t0.5_ct0.8_s10_relu'
3092
+ """
3093
+ self.modelname = "_".join(
3094
+ [
3095
+ "model",
3096
+ "crnn",
3097
+ "pytorch",
3098
+ "w" + str(self.window_size).zfill(3),
3099
+ "en" + str(self.encoding_dim).zfill(3),
3100
+ "fn" + str(self.num_filters).zfill(2),
3101
+ "fl" + str(self.kernel_size).zfill(2),
3102
+ "e" + str(self.num_epochs).zfill(3),
3103
+ "t" + str(self.excludethresh),
3104
+ "ct" + str(self.corrthresh),
3105
+ "s" + str(self.step),
3106
+ self.activation,
3107
+ ]
3108
+ )
3109
+ if self.usebadpts:
3110
+ self.modelname += "_usebadpts"
3111
+ if self.excludebysubject:
3112
+ self.modelname += "_excludebysubject"
3113
+ if self.namesuffix is not None:
3114
+ self.modelname += "_" + self.namesuffix
3115
+ self.modelpath = os.path.join(self.modelroot, self.modelname)
3116
+
3117
+ try:
3118
+ os.makedirs(self.modelpath)
3119
+ except OSError:
3120
+ pass
3121
+
3122
+ def makenet(self):
3123
+ """
3124
+ Create and initialize a CRNN model for neural network training.
3125
+
3126
+ This method initializes a CRNN (Convolutional Recurrent Neural Network) model
3127
+ using the specified configuration parameters and moves it to the designated
3128
+ device (CPU or GPU).
3129
+
3130
+ Parameters
3131
+ ----------
3132
+ self : object
3133
+ The instance of the class containing this method. Expected to have the
3134
+ following attributes:
3135
+ - num_filters : int
3136
+ Number of filters in the convolutional layers
3137
+ - kernel_size : int or tuple
3138
+ Size of the convolutional kernel
3139
+ - encoding_dim : int
3140
+ Dimension of the encoding layer
3141
+ - dropout_rate : float
3142
+ Dropout rate for regularization
3143
+ - activation : str or callable
3144
+ Activation function to use
3145
+            - inputsize : int
+                Number of input channels for the model
3147
+ - device : torch.device
3148
+ Device to move the model to (CPU or GPU)
3149
+
3150
+ Returns
3151
+ -------
3152
+ None
3153
+ This method does not return any value. It initializes the model attribute
3154
+ of the class instance.
3155
+
3156
+ Notes
3157
+ -----
3158
+ The method assumes that the CRNNModel class is properly imported and available
3159
+ in the namespace. The model is automatically moved to the device specified
3160
+ in self.device.
3161
+
3162
+ Examples
3163
+ --------
3164
+        >>> # given a CRNNDLFilter instance whose configuration attributes are set
+        >>> thefilter.makenet()
+        >>> print(thefilter.model)
3188
+ """
3189
+ self.model = CRNNModel(
3190
+ self.num_filters,
3191
+ self.kernel_size,
3192
+ self.encoding_dim,
3193
+ self.dropout_rate,
3194
+ self.activation,
3195
+ self.inputsize,
3196
+ )
3197
+ self.model.to(self.device)
3198
+
3199
+
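For reference, the name string produced by getname above can be reproduced by hand. A minimal sketch with arbitrary parameter values (not part of the module):

    parts = [
        "model", "crnn", "pytorch",
        "w" + str(128).zfill(3),  # window_size -> w128
        "en" + str(10).zfill(3),  # encoding_dim -> en010
        "fn" + str(10).zfill(2),  # num_filters -> fn10
        "fl" + str(5).zfill(2),   # kernel_size -> fl05
        "e" + str(100).zfill(3),  # num_epochs -> e100
        "t" + str(4.0),           # excludethresh, unpadded -> t4.0
        "ct" + str(0.5),          # corrthresh, unpadded -> ct0.5
        "s" + str(1),             # step -> s1
        "relu",                   # activation
    ]
    print("_".join(parts))  # model_crnn_pytorch_w128_en010_fn10_fl05_e100_t4.0_ct0.5_s1_relu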
3200
+ class LSTMModel(nn.Module):
3201
+ def __init__(self, num_units, num_layers, dropout_rate, window_size, inputsize):
3202
+ """
3203
+ Initialize the LSTMModel with specified architecture parameters.
3204
+
3205
+ Parameters
3206
+ ----------
3207
+ num_units : int
3208
+ Number of units in each LSTM layer
3209
+ num_layers : int
3210
+ Number of LSTM layers in the model
3211
+ dropout_rate : float
3212
+ Dropout rate for LSTM layers (applied only if num_layers > 1)
3213
+ window_size : int
3214
+ Size of the sliding window used for sequence processing
3215
+ inputsize : int
3216
+ Dimensionality of input features
3217
+
3218
+ Returns
3219
+ -------
3220
+ None
3221
+ Initializes the LSTMModel instance with the specified architecture
3222
+
3223
+ Notes
3224
+ -----
3225
+        This constructor creates a stack of bidirectional LSTM layers, each followed
+        by a time-distributed dense layer that maps the output back to the input
+        feature dimension. The dropout argument is passed to the LSTM layers only
+        when more than one layer is requested.
3229
+
3230
+ Examples
3231
+ --------
3232
+ >>> model = LSTMModel(num_units=128, num_layers=2, dropout_rate=0.2,
3233
+ ... window_size=10, inputsize=20)
3234
+ >>> print(model)
3235
+ """
3236
+ super(LSTMModel, self).__init__()
3237
+
3238
+ self.num_units = num_units
3239
+ self.num_layers = num_layers
3240
+ self.dropout_rate = dropout_rate
3241
+ self.window_size = window_size
3242
+ self.inputsize = inputsize
3243
+
3244
+ self.lstm_layers = nn.ModuleList()
3245
+ self.dense_layers = nn.ModuleList()
3246
+
3247
+ for _ in range(num_layers):
3248
+ # Bidirectional LSTM
3249
+ self.lstm_layers.append(
3250
+ nn.LSTM(
3251
+                    inputsize,  # every layer sees inputsize features, since each dense layer maps back to inputsize
3252
+ num_units,
3253
+ batch_first=True,
3254
+ bidirectional=True,
3255
+                    dropout=dropout_rate if num_layers > 1 else 0,  # inert here: PyTorch applies LSTM dropout only between stacked layers within one nn.LSTM
3256
+ )
3257
+ )
3258
+ # Time-distributed dense layer
3259
+ self.dense_layers.append(nn.Linear(num_units * 2, inputsize))
3260
+
3261
+ def forward(self, x):
3262
+ """
3263
+ Forward pass through LSTM and dense layers.
3264
+
3265
+ Apply a sequence of LSTM layers followed by dense layers to the input tensor,
3266
+ with appropriate dimension permutations to maintain correct data flow.
3267
+
3268
+ Parameters
3269
+ ----------
3270
+ x : torch.Tensor
3271
+ Input tensor with shape (batch, channels, length) containing the sequential data.
3272
+
3273
+ Returns
3274
+ -------
3275
+ torch.Tensor
3276
+ Output tensor with shape (batch, channels, length) after processing through
3277
+ LSTM and dense layers.
3278
+
3279
+ Notes
3280
+ -----
3281
+ The function performs the following operations:
3282
+ 1. Permutes input from (batch, channels, length) to (batch, length, channels)
3283
+ 2. Processes through LSTM layers sequentially
3284
+ 3. Applies dense layers to each time step
3285
+ 4. Permutes output back to (batch, channels, length)
3286
+
3287
+ Examples
3288
+ --------
3289
+ >>> import torch
3290
+ >>> # Assuming self.lstm_layers and self.dense_layers are initialized
3291
+ >>> x = torch.randn(32, 128, 100) # batch=32, channels=128, length=100
3292
+ >>> output = model.forward(x)
3293
+ >>> output.shape
3294
+ torch.Size([32, 128, 100])
3295
+ """
3296
+ # x is (batch, channels, length), convert to (batch, length, channels)
3297
+ x = x.permute(0, 2, 1)
3298
+
3299
+ for lstm, dense in zip(self.lstm_layers, self.dense_layers):
3300
+ x, _ = lstm(x)
3301
+ # Apply dense layer across time steps
3302
+ x = dense(x)
3303
+
3304
+ # Convert back to (batch, channels, length)
3305
+ x = x.permute(0, 2, 1)
3306
+
3307
+ return x
3308
+
3309
+ def get_config(self):
3310
+ """
3311
+ Get the configuration parameters of the model.
3312
+
3313
+ Returns
3314
+ -------
3315
+ dict
3316
+ A dictionary containing the model configuration parameters with the following keys:
3317
+ - "num_units" (int): Number of units in each layer
3318
+ - "num_layers" (int): Number of layers in the model
3319
+ - "dropout_rate" (float): Dropout rate for regularization
3320
+ - "window_size" (int): Size of the sliding window for sequence processing
3321
+ - "inputsize" (int): Size of the input features
3322
+
3323
+ Notes
3324
+ -----
3325
+ This method returns a copy of the internal configuration parameters.
3326
+ The returned dictionary can be used to recreate the model with the same configuration.
3327
+
3328
+ Examples
3329
+ --------
3330
+ >>> config = model.get_config()
3331
+ >>> print(config['num_units'])
3332
+ 128
3333
+ >>> new_model = ModelClass(**config)
3334
+ """
3335
+ return {
3336
+ "num_units": self.num_units,
3337
+ "num_layers": self.num_layers,
3338
+ "dropout_rate": self.dropout_rate,
3339
+ "window_size": self.window_size,
3340
+ "inputsize": self.inputsize,
3341
+ }
3342
+
3343
+
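A minimal shape-preservation check for the LSTMModel defined above (assuming the class is importable, e.g. from rapidtide.dlfiltertorch): because each time-distributed dense layer maps the bidirectional LSTM output back to the input feature count, the (batch, channels, length) layout survives the round trip:

    import torch

    model = LSTMModel(num_units=16, num_layers=2, dropout_rate=0.0,
                      window_size=64, inputsize=1)
    x = torch.randn(8, 1, 64)  # (batch, channels=inputsize, length=window_size)
    with torch.no_grad():
        y = model(x)
    assert y.shape == x.shape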
3344
+ class LSTMDLFilter(DeepLearningFilter):
3345
+ def __init__(self, num_units: int = 16, *args, **kwargs) -> None:
3346
+ """
3347
+        Initialize the LSTMDLFilter.
3348
+
3349
+ Parameters
3350
+ ----------
3351
+ num_units : int, optional
3352
+ Number of units in the LSTM layer, by default 16
3353
+ *args
3354
+ Variable length argument list passed to parent class
3355
+ **kwargs
3356
+ Arbitrary keyword arguments passed to parent class
3357
+
3358
+ Returns
3359
+ -------
3360
+ None
3361
+ This method initializes the instance and does not return any value
3362
+
3363
+ Notes
3364
+ -----
3365
+ This constructor sets up the LSTM layer with specified number of units and
3366
+ initializes the network type identifier. The infodict is updated with both
3367
+ the network type and number of units for tracking purposes.
3368
+
3369
+ Examples
3370
+ --------
3371
+ >>> layer = LSTMDLFilter(num_units=32)
3372
+ >>> print(layer.num_units)
3373
+ 32
3374
+ >>> print(layer.nettype)
3375
+ 'lstm'
3376
+ """
3377
+ self.num_units = num_units
3378
+ self.nettype = "lstm"
3379
+ self.infodict["nettype"] = self.nettype
3380
+ self.infodict["num_units"] = self.num_units
3381
+ super(LSTMDLFilter, self).__init__(*args, **kwargs)
3382
+
3383
+ def getname(self):
3384
+ """
3385
+ Generate and configure model name and path based on current parameters.
3386
+
3387
+ This method constructs a descriptive model name string using various
3388
+ hyperparameters and configuration settings. It then creates the
3389
+ corresponding directory path and ensures it exists.
3390
+
3391
+ Parameters
3392
+ ----------
3393
+ self : object
3394
+ The instance containing model configuration attributes.
3395
+
3396
+ Returns
3397
+ -------
3398
+ None
3399
+ This method modifies instance attributes in-place and does not return a value.
3400
+
3401
+ Notes
3402
+ -----
3403
+ The generated model name follows a specific format:
3404
+ "model_lstm_pytorch_wXXX_lYY_nuZZZ_dDD_rdDD_eFFF_tT_ctTT_sS"
3405
+ where XXX, YY, ZZZ, DD, FF, T, TT, S represent formatted parameter values.
3406
+
3407
+ Examples
3408
+ --------
3409
+ >>> model = MyModel()
3410
+ >>> model.window_size = 100
3411
+ >>> model.num_layers = 2
3412
+ >>> model.num_units = 128
3413
+ >>> model.dropout_rate = 0.2
3414
+ >>> model.num_epochs = 100
3415
+ >>> model.excludethresh = 0.5
3416
+ >>> model.corrthresh = 0.8
3417
+ >>> model.step = 1
3418
+ >>> model.excludebysubject = True
3419
+ >>> model.getname()
3420
+ >>> print(model.modelname)
3421
+        'model_lstm_pytorch_w100_l02_nu128_d0.2_rd0.2_e100_t0.5_ct0.8_s1_excludebysubject'
3422
+ """
3423
+ self.modelname = "_".join(
3424
+ [
3425
+ "model",
3426
+ "lstm",
3427
+ "pytorch",
3428
+ "w" + str(self.window_size).zfill(3),
3429
+ "l" + str(self.num_layers).zfill(2),
3430
+ "nu" + str(self.num_units),
3431
+ "d" + str(self.dropout_rate),
3432
+ "rd" + str(self.dropout_rate),
3433
+ "e" + str(self.num_epochs).zfill(3),
3434
+ "t" + str(self.excludethresh),
3435
+ "ct" + str(self.corrthresh),
3436
+ "s" + str(self.step),
3437
+ ]
3438
+ )
3439
+ if self.excludebysubject:
3440
+ self.modelname += "_excludebysubject"
3441
+ self.modelpath = os.path.join(self.modelroot, self.modelname)
3442
+
3443
+ try:
3444
+ os.makedirs(self.modelpath)
3445
+ except OSError:
3446
+ pass
3447
+
3448
+ def makenet(self):
3449
+ """
3450
+ Create and initialize an LSTM model for neural network training.
3451
+
3452
+ This method initializes an LSTMModel with the specified architecture parameters
3453
+ and moves the model to the designated device (CPU or GPU).
3454
+
3455
+ Parameters
3456
+ ----------
3457
+ self : object
3458
+ The instance containing the following attributes:
3459
+ - num_units : int
3460
+ Number of units in each LSTM layer
3461
+ - num_layers : int
3462
+ Number of LSTM layers in the model
3463
+ - dropout_rate : float
3464
+ Dropout rate for regularization
3465
+ - window_size : int
3466
+ Size of the input window for time series data
3467
+ - inputsize : int
3468
+ Size of the input features
3469
+ - device : torch.device
3470
+ Device to move the model to (e.g., 'cpu' or 'cuda')
3471
+
3472
+ Returns
3473
+ -------
3474
+ None
3475
+ This method does not return any value. It initializes the model attribute
3476
+ and moves it to the specified device.
3477
+
3478
+ Notes
3479
+ -----
3480
+ The method creates an LSTMModel instance with the following parameters:
3481
+ - num_units: Number of hidden units in LSTM layers
3482
+ - num_layers: Number of stacked LSTM layers
3483
+ - dropout_rate: Dropout probability for regularization
3484
+ - window_size: Input sequence length
3485
+ - inputsize: Feature dimension of input data
3486
+
3487
+ Examples
3488
+ --------
3489
+ >>> # Assuming self is an instance with required attributes
3490
+ >>> self.makenet()
3491
+ >>> # Model is now initialized and moved to specified device
3492
+ """
3493
+ self.model = LSTMModel(
3494
+ self.num_units,
3495
+ self.num_layers,
3496
+ self.dropout_rate,
3497
+ self.window_size,
3498
+ self.inputsize,
3499
+ )
3500
+ self.model.to(self.device)
3501
+
3502
+
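As with the CRNN filter, the LSTM model name can be reproduced directly from the string-building code in getname; note that the dropout and threshold fields are written unpadded. A sketch with arbitrary values:

    parts = [
        "model", "lstm", "pytorch",
        "w" + str(100).zfill(3),  # window_size -> w100
        "l" + str(2).zfill(2),    # num_layers -> l02
        "nu" + str(128),          # num_units -> nu128
        "d" + str(0.2),           # dropout_rate -> d0.2
        "rd" + str(0.2),          # repeated dropout field -> rd0.2
        "e" + str(100).zfill(3),  # num_epochs -> e100
        "t" + str(0.5),           # excludethresh -> t0.5
        "ct" + str(0.8),          # corrthresh -> ct0.8
        "s" + str(1),             # step -> s1
    ]
    print("_".join(parts))  # model_lstm_pytorch_w100_l02_nu128_d0.2_rd0.2_e100_t0.5_ct0.8_s1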
3503
+ class HybridModel(nn.Module):
3504
+ def __init__(
3505
+ self,
3506
+ num_filters,
3507
+ kernel_size,
3508
+ num_units,
3509
+ num_layers,
3510
+ dropout_rate,
3511
+ activation,
3512
+ inputsize,
3513
+ window_size,
3514
+ invert,
3515
+ ):
3516
+ """
3517
+ Initialize the HybridModel with configurable CNN-LSTM architecture.
3518
+
3519
+ Parameters
3520
+ ----------
3521
+ num_filters : int
3522
+ Number of filters in the convolutional layers.
3523
+ kernel_size : int
3524
+ Size of the convolutional kernel.
3525
+ num_units : int
3526
+ Number of units in the LSTM layers.
3527
+ num_layers : int
3528
+ Total number of layers in the model.
3529
+ dropout_rate : float
3530
+ Dropout rate for regularization.
3531
+ activation : str
3532
+ Activation function to use; options are 'relu' or 'tanh'.
3533
+ inputsize : int
3534
+ Size of the input features.
3535
+ window_size : int
3536
+ Size of the sliding window for input data.
3537
+ invert : bool
3538
+ If True, applies CNN first followed by LSTM. Otherwise, applies LSTM first followed by CNN.
3539
+
3540
+ Returns
3541
+ -------
3542
+ None
3543
+ This method initializes the model's layers and components but does not return any value.
3544
+
3545
+ Notes
3546
+ -----
3547
+ The model supports two architectures:
3548
+ - If `invert=False`: LSTM → CNN
3549
+ - If `invert=True`: CNN → LSTM
3550
+
3551
+ Examples
3552
+ --------
3553
+ >>> model = HybridModel(
3554
+ ... num_filters=64,
3555
+ ... kernel_size=3,
3556
+ ... num_units=128,
3557
+ ... num_layers=3,
3558
+ ... dropout_rate=0.2,
3559
+ ... activation="relu",
3560
+ ... inputsize=10,
3561
+ ... window_size=100,
3562
+ ... invert=True
3563
+ ... )
3564
+ """
3565
+ super(HybridModel, self).__init__()
3566
+
3567
+ self.num_filters = num_filters
3568
+ self.kernel_size = kernel_size
3569
+ self.num_units = num_units
3570
+ self.num_layers = num_layers
3571
+ self.dropout_rate = dropout_rate
3572
+ self.activation = activation
3573
+ self.inputsize = inputsize
3574
+ self.window_size = window_size
3575
+ self.invert = invert
3576
+
3577
+ # Get activation function
3578
+ if activation == "relu":
3579
+ act_fn = nn.ReLU
3580
+ elif activation == "tanh":
3581
+ act_fn = nn.Tanh
3582
+ else:
3583
+ act_fn = nn.ReLU
3584
+
3585
+ self.layers = nn.ModuleList()
3586
+
3587
+ if invert:
3588
+ # CNN first, then LSTM
3589
+ # Input layer
3590
+ self.layers.append(nn.Conv1d(inputsize, num_filters, kernel_size, padding="same"))
3591
+ self.layers.append(nn.BatchNorm1d(num_filters))
3592
+ self.layers.append(nn.Dropout(dropout_rate))
3593
+ self.layers.append(act_fn())
3594
+
3595
+ # Intermediate CNN layers
3596
+ for _ in range(num_layers - 2):
3597
+ self.layers.append(
3598
+ nn.Conv1d(num_filters, num_filters, kernel_size, padding="same")
3599
+ )
3600
+ self.layers.append(nn.BatchNorm1d(num_filters))
3601
+ self.layers.append(nn.Dropout(dropout_rate))
3602
+ self.layers.append(act_fn())
3603
+
3604
+            # LSTM layer (the dropout argument is inert for a single-layer nn.LSTM and triggers a PyTorch warning)
3605
+ self.lstm = nn.LSTM(
3606
+ num_filters, num_units, batch_first=True, bidirectional=True, dropout=dropout_rate
3607
+ )
3608
+ self.lstm_dense = nn.Linear(num_units * 2, inputsize)
3609
+
3610
+ else:
3611
+ # LSTM first, then CNN
3612
+ self.lstm = nn.LSTM(
3613
+ inputsize, num_units, batch_first=True, bidirectional=True, dropout=dropout_rate
3614
+ )
3615
+ self.lstm_dense = nn.Linear(num_units * 2, inputsize)
3616
+ self.lstm_dropout = nn.Dropout(dropout_rate)
3617
+
3618
+ # Intermediate CNN layers
3619
+            # the first intermediate layer sees the LSTM output (inputsize channels); later ones see num_filters
+            for i in range(num_layers - 2):
+                self.layers.append(
+                    nn.Conv1d(inputsize if i == 0 else num_filters, num_filters, kernel_size, padding="same")
+                )
3621
+ self.layers.append(nn.BatchNorm1d(num_filters))
3622
+ self.layers.append(nn.Dropout(dropout_rate))
3623
+ self.layers.append(act_fn())
3624
+
3625
+ # Output layer
3626
+ self.output_conv = nn.Conv1d(
3627
+ num_filters if num_layers > 2 else inputsize,
3628
+ inputsize,
3629
+ kernel_size,
3630
+ padding="same",
3631
+ )
3632
+
3633
+ def forward(self, x):
3634
+ """
3635
+ Forward pass of the model with optional CNN-LSTM hybrid architecture.
3636
+
3637
+ This method implements a flexible forward pass that can operate in two modes
3638
+        depending on the `invert` flag. When `invert` is True, the data flows through
+        the CNN stack and then the LSTM. When `invert` is False, it flows through the
+        LSTM first and then the CNN stack, ending in the output convolution.
3641
+
3642
+ Parameters
3643
+ ----------
3644
+ x : torch.Tensor
3645
+ Input tensor of shape (batch_size, channels, sequence_length) or
3646
+ (batch_size, sequence_length, channels) depending on the processing path.
3647
+
3648
+ Returns
3649
+ -------
3650
+ torch.Tensor
3651
+ Output tensor with the same batch dimension as input, with shape
3652
+ dependent on the specific layers and processing path used.
3653
+
3654
+ Notes
3655
+ -----
3656
+ The function handles different tensor permutations based on the processing
3657
+ order:
3658
+ - CNN → LSTM path: permutes from (B, C, L) to (B, L, C) for LSTM, then back
3659
+ - LSTM → CNN path: permutes from (B, C, L) to (B, L, C) for LSTM, then back
3660
+ The `invert` flag determines which processing order is used.
3661
+
3662
+ Examples
3663
+ --------
3664
+ >>> model = MyModel()
3665
+        >>> x = torch.randn(32, 10, 128)  # batch=32, channels (inputsize)=10, length=128
+        >>> output = model.forward(x)
+        >>> print(output.shape)
+        torch.Size([32, 10, 128])
3669
+ """
3670
+ if self.invert:
3671
+ # Apply CNN layers
3672
+ for layer in self.layers:
3673
+ x = layer(x)
3674
+
3675
+ # LSTM expects (batch, seq_len, features)
3676
+ x = x.permute(0, 2, 1)
3677
+ x, _ = self.lstm(x)
3678
+ x = self.lstm_dense(x)
3679
+
3680
+ # Convert back to (batch, channels, length)
3681
+ x = x.permute(0, 2, 1)
3682
+
3683
+ else:
3684
+ # LSTM first
3685
+ x = x.permute(0, 2, 1)
3686
+ x, _ = self.lstm(x)
3687
+ x = self.lstm_dense(x)
3688
+ x = self.lstm_dropout(x)
3689
+ x = x.permute(0, 2, 1)
3690
+
3691
+ # CNN layers
3692
+ for layer in self.layers:
3693
+ x = layer(x)
3694
+
3695
+ # Output layer
3696
+ if hasattr(self, "output_conv"):
3697
+ x = self.output_conv(x)
3698
+
3699
+ return x
3700
+
3701
+ def get_config(self):
3702
+ """
3703
+ Get the configuration parameters of the model.
3704
+
3705
+ Returns
3706
+ -------
3707
+ dict
3708
+ A dictionary containing all configuration parameters with their current values:
3709
+ - num_filters: int, number of filters in the convolutional layers
3710
+ - kernel_size: int, size of the convolutional kernel
3711
+            - num_units: int, number of units in the LSTM layers
3712
+ - num_layers: int, number of layers in the model
3713
+ - dropout_rate: float, dropout rate for regularization
3714
+ - activation: str or callable, activation function to use
3715
+ - inputsize: int, size of the input features
3716
+ - window_size: int, size of the sliding window
3717
+            - invert: bool, whether the CNN stage comes before the LSTM
3718
+
3719
+ Notes
3720
+ -----
3721
+ This method returns a copy of the internal configuration dictionary.
3722
+ Modifications to the returned dictionary will not affect the original model configuration.
3723
+
3724
+ Examples
3725
+ --------
3726
+ >>> config = model.get_config()
3727
+ >>> print(config['num_filters'])
3728
+ 32
3729
+ """
3730
+ return {
3731
+ "num_filters": self.num_filters,
3732
+ "kernel_size": self.kernel_size,
3733
+ "num_units": self.num_units,
3734
+ "num_layers": self.num_layers,
3735
+ "dropout_rate": self.dropout_rate,
3736
+ "activation": self.activation,
3737
+ "inputsize": self.inputsize,
3738
+ "window_size": self.window_size,
3739
+ "invert": self.invert,
3740
+ }
3741
+
3742
+
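A sketch exercising both orderings of the HybridModel defined above; with num_layers=3, both paths map a (batch, inputsize, length) input back to the same shape (the sizes here are arbitrary):

    import torch

    for invert in (False, True):
        model = HybridModel(num_filters=10, kernel_size=5, num_units=16,
                            num_layers=3, dropout_rate=0.0, activation="relu",
                            inputsize=1, window_size=64, invert=invert)
        x = torch.randn(4, 1, 64)  # (batch, channels=inputsize, length)
        with torch.no_grad():
            assert model(x).shape == x.shape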
3743
+ class HybridDLFilter(DeepLearningFilter):
3744
+ def __init__(
3745
+ self,
3746
+ invert: bool = False,
3747
+ num_filters: int = 10,
3748
+ kernel_size: int = 5,
3749
+ num_units: int = 16,
3750
+ *args,
3751
+ **kwargs,
3752
+ ) -> None:
3753
+ """
3754
+        Initialize the HybridDLFilter.
3755
+
3756
+ Parameters
3757
+ ----------
3758
+ invert : bool, default=False
3759
+            If True, applies the CNN stage before the LSTM; if False, the LSTM comes first.
3760
+ num_filters : int, default=10
3761
+ Number of filters to apply in the convolutional layer.
3762
+ kernel_size : int, default=5
3763
+ Size of the convolutional kernel.
3764
+ num_units : int, default=16
3765
+            Number of units in the LSTM layer.
3766
+ *args
3767
+ Variable length argument list.
3768
+ **kwargs
3769
+ Arbitrary keyword arguments.
3770
+
3771
+ Returns
3772
+ -------
3773
+ None
3774
+ This method initializes the HybridDLFilter instance and does not return any value.
3775
+
3776
+ Notes
3777
+ -----
3778
+        This constructor sets up a hybrid deep learning filter that combines convolutional
+        and LSTM layers. The infodict dictionary is populated with configuration parameters
3780
+ for tracking and logging purposes.
3781
+
3782
+ Examples
3783
+ --------
3784
+ >>> filter_layer = HybridDLFilter(
3785
+ ... invert=True,
3786
+ ... num_filters=20,
3787
+ ... kernel_size=3,
3788
+ ... num_units=32
3789
+ ... )
3790
+ """
3791
+ self.invert = invert
3792
+ self.num_filters = num_filters
3793
+ self.kernel_size = kernel_size
3794
+ self.num_units = num_units
3795
+ self.nettype = "hybrid"
3796
+ self.infodict["nettype"] = self.nettype
3797
+ self.infodict["num_filters"] = self.num_filters
3798
+ self.infodict["kernel_size"] = self.kernel_size
3799
+ self.infodict["invert"] = self.invert
3800
+ self.infodict["num_units"] = self.num_units
3801
+ super(HybridDLFilter, self).__init__(*args, **kwargs)
3802
+
3803
+ def getname(self):
3804
+ """
3805
+ Generate and configure the model name and path based on current parameters.
3806
+
3807
+ This method constructs a descriptive model name string using various
3808
+ hyperparameters and configuration settings. The generated name follows
3809
+ a standardized format that includes window size, layer count, filter count,
3810
+ kernel size, number of units, dropout rates, number of epochs, threshold
3811
+ values, step size, and activation function. The method also creates the
3812
+ corresponding model directory path and ensures it exists.
3813
+
3814
+ Parameters
3815
+ ----------
3816
+ self : object
3817
+ The instance of the class containing the model configuration attributes.
3818
+ Required attributes include:
3819
+ - window_size : int
3820
+ - num_layers : int
3821
+ - num_filters : int
3822
+ - kernel_size : int
3823
+ - num_units : int
3824
+ - dropout_rate : float
3825
+ - num_epochs : int
3826
+ - excludethresh : float
3827
+ - corrthresh : float
3828
+ - step : int
3829
+ - activation : str
3830
+ - invert : bool
3831
+ - excludebysubject : bool
3832
+ - modelroot : str
3833
+
3834
+ Returns
3835
+ -------
3836
+ None
3837
+ This method modifies the instance attributes in-place:
3838
+ - self.modelname : str
3839
+ - self.modelpath : str
3840
+
3841
+ Notes
3842
+ -----
3843
+ The model name format follows this pattern:
3844
+ "model_hybrid_pytorch_wXXX_lYY_fnZZ_flZZ_nuZZ_dZZ_rdZZ_eXXX_tX_ctX_sX_activation"
3845
+        where the integer parameters are zero-padded; dropout and threshold values appear unpadded (e.g. d0.5, t0.5).
3846
+
3847
+ Additional suffixes are appended based on:
3848
+ - "_invert" if self.invert is True
3849
+ - "_excludebysubject" if self.excludebysubject is True
3850
+
3851
+ Examples
3852
+ --------
3853
+ >>> model = MyModel()
3854
+ >>> model.window_size = 100
3855
+ >>> model.num_layers = 2
3856
+ >>> model.getname()
3857
+ >>> print(model.modelname)
3858
+        'model_hybrid_pytorch_w100_l02_fn08_fl08_nu128_d0.5_rd0.5_e100_t0.5_ct0.8_s1_relu'
3859
+ """
3860
+ self.modelname = "_".join(
3861
+ [
3862
+ "model",
3863
+ "hybrid",
3864
+ "pytorch",
3865
+ "w" + str(self.window_size).zfill(3),
3866
+ "l" + str(self.num_layers).zfill(2),
3867
+ "fn" + str(self.num_filters).zfill(2),
3868
+ "fl" + str(self.kernel_size).zfill(2),
3869
+ "nu" + str(self.num_units),
3870
+ "d" + str(self.dropout_rate),
3871
+ "rd" + str(self.dropout_rate),
3872
+ "e" + str(self.num_epochs).zfill(3),
3873
+ "t" + str(self.excludethresh),
3874
+ "ct" + str(self.corrthresh),
3875
+ "s" + str(self.step),
3876
+ self.activation,
3877
+ ]
3878
+ )
3879
+ if self.invert:
3880
+ self.modelname += "_invert"
3881
+ if self.excludebysubject:
3882
+ self.modelname += "_excludebysubject"
3883
+ self.modelpath = os.path.join(self.modelroot, self.modelname)
3884
+
3885
+ try:
3886
+ os.makedirs(self.modelpath)
3887
+ except OSError:
3888
+ pass
3889
+
3890
+ def makenet(self):
3891
+ """
3892
+ Create and initialize a hybrid neural network model.
3893
+
3894
+ This method constructs a HybridModel with the specified architecture parameters
3895
+ and moves it to the designated device (CPU or GPU).
3896
+
3897
+ Parameters
3898
+ ----------
3899
+ self : object
3900
+ The instance containing the following attributes:
3901
+ - num_filters : int
3902
+ Number of filters in the convolutional layers
3903
+ - kernel_size : int
3904
+ Size of the convolutional kernels
3905
+ - num_units : int
3906
+                Number of units in the LSTM layers
3907
+ - num_layers : int
3908
+ Number of layers in the model
3909
+ - dropout_rate : float
3910
+ Dropout rate for regularization
3911
+ - activation : str or callable
3912
+ Activation function to use
3913
+ - inputsize : int
3914
+ Size of the input features
3915
+ - window_size : int
3916
+ Size of the sliding window
3917
+ - invert : bool
3918
+                Whether the CNN stage precedes the LSTM
3919
+
3920
+ Returns
3921
+ -------
3922
+ None
3923
+ This method does not return any value. It initializes the model attribute
3924
+ and moves it to the specified device.
3925
+
3926
+ Notes
3927
+ -----
3928
+ The method assumes that the instance has all required attributes set before
3929
+ calling. The model is moved to the device specified by `self.device`.
3930
+
3931
+ Examples
3932
+ --------
3933
+ >>> model = MyModel()
3934
+ >>> model.num_filters = 32
3935
+ >>> model.kernel_size = 3
3936
+ >>> model.num_units = 64
3937
+ >>> model.num_layers = 2
3938
+ >>> model.dropout_rate = 0.2
3939
+ >>> model.activation = 'relu'
3940
+ >>> model.inputsize = 10
3941
+ >>> model.window_size = 5
3942
+ >>> model.invert = False
3943
+ >>> model.device = 'cuda'
3944
+ >>> model.makenet()
3945
+ >>> print(model.model)
3946
+ """
3947
+ self.model = HybridModel(
3948
+ self.num_filters,
3949
+ self.kernel_size,
3950
+ self.num_units,
3951
+ self.num_layers,
3952
+ self.dropout_rate,
3953
+ self.activation,
3954
+ self.inputsize,
3955
+ self.window_size,
3956
+ self.invert,
3957
+ )
3958
+ self.model.to(self.device)
3959
+
3960
+
3961
+ def filtscale(
3962
+ data: NDArray,
3963
+ scalefac: float = 1.0,
3964
+ reverse: bool = False,
3965
+ hybrid: bool = False,
3966
+ lognormalize: bool = True,
3967
+ epsilon: float = 1e-10,
3968
+ numorders: int = 6,
3969
+ ) -> tuple[NDArray, float] | NDArray:
3970
+ """
3971
+ Apply or reverse a scaling transformation to spectral data.
3972
+
3973
+ This function performs either forward or inverse scaling of input data,
3974
+ typically used in signal processing or spectral analysis. In forward mode,
3975
+ it computes the FFT of the input data and applies normalization and scaling
3976
+ to the magnitude and phase components. In reverse mode, it reconstructs
3977
+ the original time-domain signal from scaled magnitude and phase components.
3978
+
3979
+ Parameters
3980
+ ----------
3981
+ data : NDArray
3982
+ Input time-domain signal or scaled spectral data depending on `reverse` flag.
3983
+ scalefac : float, optional
3984
+ Scaling factor used in normalization. Default is 1.0.
3985
+ reverse : bool, optional
3986
+ If True, performs inverse transformation to reconstruct the original signal.
3987
+ If False, performs forward transformation. Default is False.
3988
+ hybrid : bool, optional
3989
+ If True, returns a hybrid output combining original signal and magnitude.
3990
+ Only applicable in forward mode. Default is False.
3991
+ lognormalize : bool, optional
3992
+ If True, applies logarithmic normalization to the magnitude. Default is True.
3993
+ epsilon : float, optional
3994
+ Small constant added to magnitude before log to avoid log(0). Default is 1e-10.
3995
+ numorders : int, optional
3996
+ Number of orders used in normalization scaling. Default is 6.
3997
+
3998
+ Returns
3999
+ -------
4000
+ tuple[NDArray, float] or NDArray
4001
+ - If `reverse` is False: Returns a tuple of (scaled_data, scalefac).
4002
+ `scaled_data` is a stacked array of magnitude and phase (or original signal
4003
+ and magnitude in hybrid mode).
4004
+ - If `reverse` is True: Returns the reconstructed time-domain signal as
4005
+ a numpy array.
4006
+
4007
+ Notes
4008
+ -----
4009
+ - In forward mode, the function computes the FFT of `data`, normalizes the
4010
+ magnitude, and scales it to a range suitable for further processing.
4011
+ - In reverse mode, the function reconstructs the time-domain signal using
4012
+ inverse FFT from the provided scaled magnitude and phase components.
4013
+ - The `hybrid` mode is useful for certain types of signal visualization or
4014
+ feature extraction where both time-domain and frequency-domain information
4015
+ are needed.
4016
+
4017
+ Examples
4018
+ --------
4019
+ >>> import numpy as np
4020
+ >>> from scipy import fftpack
4021
+ >>> x = np.random.randn(1024)
4022
+ >>> scaled_data, scalefac = filtscale(x)
4023
+ >>> reconstructed = filtscale(scaled_data, scalefac=scalefac, reverse=True)
4024
+ """
4025
+ if not reverse:
4026
+ specvals = fftpack.fft(data)
4027
+ if lognormalize:
4028
+ themag = np.log(np.absolute(specvals) + epsilon)
4029
+ scalefac = np.max(themag)
4030
+ themag = (themag - scalefac + numorders) / numorders
4031
+ themag[np.where(themag < 0.0)] = 0.0
4032
+ else:
4033
+ scalefac = np.std(data)
4034
+ themag = np.absolute(specvals) / scalefac
4035
+ thephase = np.angle(specvals)
4036
+ thephase = thephase / (2.0 * np.pi) - 0.5
4037
+ if hybrid:
4038
+ return np.stack((data, themag), axis=1), scalefac
4039
+ else:
4040
+ return np.stack((themag, thephase), axis=1), scalefac
4041
+ else:
4042
+ if hybrid:
4043
+ return data[:, 0]
4044
+ else:
4045
+ thephase = (data[:, 1] + 0.5) * 2.0 * np.pi
4046
+ if lognormalize:
4047
+ themag = np.exp(data[:, 0] * numorders - numorders + scalefac)
4048
+ else:
4049
+ themag = data[:, 0] * scalefac
4050
+ specvals = themag * np.exp(1.0j * thephase)
4051
+ return fftpack.ifft(specvals).real
4052
+
4053
+
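A round-trip sketch for filtscale: forward-transform a signal, then invert it. Reconstruction is near-exact except where spectral magnitudes fall more than numorders below the peak in log space, since those bins are clipped to zero on the forward pass:

    import numpy as np

    rng = np.random.default_rng(0)
    signal = rng.standard_normal(128)
    scaled, scalefac = filtscale(signal)
    recovered = filtscale(scaled, scalefac=scalefac, reverse=True)
    print(np.max(np.abs(recovered - signal)))  # small; nonzero only if bins were clipped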
4054
+ def tobadpts(name: str) -> str:
4055
+ """
4056
+ Convert a filename to its corresponding bad points filename.
4057
+
4058
+ This function takes a filename string and replaces the '.txt' extension
4059
+ with '_badpts.txt' to create a new filename for bad points data.
4060
+
4061
+ Parameters
4062
+ ----------
4063
+ name : str
4064
+ The input filename string, typically ending with '.txt'.
4065
+
4066
+ Returns
4067
+ -------
4068
+ str
4069
+ The converted filename with '_badpts.txt' extension instead of '.txt'.
4070
+
4071
+ Notes
4072
+ -----
4073
+ This function is useful for creating consistent naming conventions for
4074
+ bad points data files that correspond to original data files.
4075
+
4076
+ Examples
4077
+ --------
4078
+ >>> tobadpts("data.txt")
4079
+ 'data_badpts.txt'
4080
+
4081
+ >>> tobadpts("results.txt")
4082
+ 'results_badpts.txt'
4083
+
4084
+ >>> tobadpts("output.txt")
4085
+ 'output_badpts.txt'
4086
+ """
4087
+ return name.replace(".txt", "_badpts.txt")
4088
+
4089
+
4090
+ def targettoinput(name: str, targetfrag: str = "xyz", inputfrag: str = "abc") -> str:
4091
+ """
4092
+ Replace target fragment with input fragment in a string.
4093
+
4094
+ Parameters
4095
+ ----------
4096
+ name : str
4097
+ The input string to perform replacement on.
4098
+ targetfrag : str, default='xyz'
4099
+ The fragment to be replaced in the input string.
4100
+ inputfrag : str, default='abc'
4101
+ The fragment to replace the target fragment with.
4102
+
4103
+ Returns
4104
+ -------
4105
+ str
4106
+ The modified string with targetfrag replaced by inputfrag.
4107
+
4108
+ Notes
4109
+ -----
4110
+ This function uses Python's built-in string replace method, which replaces
4111
+ all occurrences of the target fragment with the input fragment.
4112
+
4113
+ Examples
4114
+ --------
4115
+ >>> targettoinput("hello xyz world")
4116
+ 'hello abc world'
4117
+
4118
+ >>> targettoinput("test xyz xyz test", "xyz", "123")
4119
+ 'test 123 123 test'
4120
+
4121
+ >>> targettoinput("abcdef", "cde", "XXX")
4122
+ 'abXXXf'
4123
+ """
4124
+ LGR.debug(f"replacing {targetfrag} with {inputfrag}")
4125
+ return name.replace(targetfrag, inputfrag)
4126
+
4127
+
4128
+ def getmatchedtcs(
4129
+ searchstring: str,
4130
+ usebadpts: bool = False,
4131
+ targetfrag: str = "xyz",
4132
+ inputfrag: str = "abc",
4133
+ debug: bool = False,
4134
+ ) -> tuple[list[str], int]:
4135
+ """
4136
+ Find and validate matched timecourse files based on a search pattern.
4137
+
4138
+ This function searches for timecourse files matching the given search string,
4139
+ verifies their completeness by checking for associated info files, and
4140
+ determines the length of the timecourses from the first valid file.
4141
+
4142
+ Parameters
4143
+ ----------
4144
+ searchstring : str
4145
+ A glob pattern to match target timecourse files.
4146
+ usebadpts : bool, optional
4147
+ Flag indicating whether bad points should be used (default is False).
4148
+ targetfrag : str, optional
4149
+ Target fragment identifier (default is "xyz").
4150
+ inputfrag : str, optional
4151
+ Input fragment identifier (default is "abc").
4152
+ debug : bool, optional
4153
+ If True, prints debug information including matched files (default is False).
4154
+
4155
+ Returns
4156
+ -------
4157
+ tuple[list[str], int]
4158
+ A tuple containing:
4159
+ - List of matched and validated file paths.
4160
+ - Length of the timecourses (number of timepoints).
4161
+
4162
+ Notes
4163
+ -----
4164
+ The function expects timecourse files to have a corresponding info file
4165
+ with the same base name but with "_info" appended. Only files with complete
4166
+ info files are considered valid.
4167
+
4168
+ Examples
4169
+ --------
4170
+ >>> matched_files, tc_length = getmatchedtcs("data/*cardiac*.tsv")
4171
+ >>> print(f"Found {len(matched_files)} files with {tc_length} timepoints")
4172
+ """
4173
+ # list all of the target files
4174
+ fromfile = sorted(glob.glob(searchstring))
4175
+ if debug:
4176
+ print(f"searchstring: {searchstring} -> {fromfile}")
4177
+
4178
+ # make sure all timecourses exist
4179
+ # we need cardiacfromfmri_25.0Hz as x, normpleth as y, and perhaps badpts
4180
+ matchedfilelist = []
4181
+ for targetname in fromfile:
4182
+ infofile = targetname.replace("_desc-stdrescardfromfmri_timeseries", "_info")
4183
+ if os.path.isfile(infofile):
4184
+ matchedfilelist.append(targetname)
4185
+ print(f"{targetname} is complete")
4186
+ LGR.debug(matchedfilelist[-1])
4187
+ else:
4188
+ print(f"{targetname} is incomplete")
4189
+ print(f"found {len(matchedfilelist)} matched files")
4190
+
4191
+ # find out how long the files are
4192
+ (
4193
+ samplerate,
4194
+ starttime,
4195
+ columns,
4196
+ inputarray,
4197
+ compression,
4198
+ columnsource,
4199
+ ) = tide_io.readbidstsv(
4200
+ matchedfilelist[0],
4201
+ colspec="cardiacfromfmri_25.0Hz,normpleth",
4202
+ )
4203
+ print(f"{inputarray.shape=}")
4204
+ tclen = inputarray.shape[1]
4205
+ LGR.info(f"tclen set to {tclen}")
4206
+ return matchedfilelist, tclen
4207
+
4208
+
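To illustrate the pairing that getmatchedtcs checks for (the subject labels below are hypothetical), a run is kept only when an info sidecar exists alongside its timeseries file:

    # kept:    sub-01_desc-stdrescardfromfmri_timeseries.json  (matched by the glob)
    #          sub-01_info.json                                (required sidecar)
    # dropped: sub-02_desc-stdrescardfromfmri_timeseries.json  (no sub-02_info.json)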
4209
+ def readindata(
4210
+ matchedfilelist: list[str],
4211
+ tclen: int,
4212
+ targetfrag: str = "xyz",
4213
+ inputfrag: str = "abc",
4214
+ usebadpts: bool = False,
4215
+ startskip: int = 0,
4216
+ endskip: int = 0,
4217
+ corrthresh: float = 0.5,
4218
+ readlim: int | None = None,
4219
+ readskip: int | None = None,
4220
+ debug: bool = False,
4221
+ ) -> tuple[NDArray, NDArray, list[str]] | tuple[NDArray, NDArray, list[str], NDArray]:
4222
+ """
4223
+ Read and process time-series data from a list of matched files.
4224
+
4225
+ This function reads cardiac and plethysmographic time-series data from a list of
4226
+ files, performs quality checks, and returns the data in arrays suitable for
4227
+ training or analysis. It supports filtering based on correlation thresholds,
4228
+ NaN values, and signal standard deviations, and allows for optional skipping
4229
+ of data at the start and end of each time series.
4230
+
4231
+ Parameters
4232
+ ----------
4233
+ matchedfilelist : list of str
4234
+ List of file paths to be processed. Each file should contain time-series data
4235
+ in a format compatible with `tide_io.readbidstsv`.
4236
+ tclen : int
4237
+ Length of the time series to be read from each file.
4238
+ targetfrag : str, optional
4239
+ Fragment identifier for target files, used in naming conversions. Default is "xyz".
4240
+ inputfrag : str, optional
4241
+ Fragment identifier for input files, used in naming conversions. Default is "abc".
4242
+ usebadpts : bool, optional
4243
+ If True, include a third array with bad point indicators. Default is False.
4244
+ startskip : int, optional
4245
+ Number of samples to skip at the beginning of each time series. Default is 0.
4246
+ endskip : int, optional
4247
+ Number of samples to skip at the end of each time series. Default is 0.
4248
+ corrthresh : float, optional
4249
+ Minimum correlation threshold between raw and plethysmographic signals.
4250
+ Files with lower correlation are excluded. Default is 0.5.
4251
+ readlim : int, optional
4252
+ Maximum number of files to read. If None, all files are read. Default is None.
4253
+ readskip : int, optional
4254
+ Number of files to skip at the beginning of the file list. If None, no files are skipped. Default is None.
4255
+ debug : bool, optional
4256
+ If True, print debug information for each file. Default is False.
4257
+
4258
+ Returns
4259
+ -------
4260
+ tuple of (NDArray, NDArray, list[str]) or (NDArray, NDArray, list[str], NDArray)
4261
+ - `x1`: Array of shape `(tclen, count)` containing x-time series data.
4262
+ - `y1`: Array of shape `(tclen, count)` containing y-time series data.
4263
+ - `names`: List of file names that passed quality checks.
4264
+ - `bad1`: Optional array of shape `(tclen, count)` with bad point indicators if `usebadpts=True`.
4265
+
4266
+ Notes
4267
+ -----
4268
+ - Files with NaNs, short data, extreme standard deviations, or low correlation are excluded.
4269
+ - The function logs information about excluded files for debugging and quality control.
4270
+ - The `startskip` and `endskip` parameters are applied after filtering and before returning the data.
4271
+
4272
+ Examples
4273
+ --------
4274
+ >>> x, y, names = readindata(filelist, tclen=1000)
4275
+ >>> x, y, names, bad = readindata(filelist, tclen=1000, usebadpts=True)
4276
+ """
4277
+ LGR.info(
4278
+ "readindata called with usebadpts, startskip, endskip, readlim, readskip, targetfrag, inputfrag = "
4279
+ f"{usebadpts} {startskip} {endskip} {readlim} {readskip} {targetfrag} {inputfrag}"
4280
+ )
4281
+ # allocate target arrays
4282
+ LGR.info("allocating arrays")
4283
+    if readskip is None:
+        readskip = 0
+    s = len(matchedfilelist[readskip:])
4284
+ if readlim is not None:
4285
+ if s > readlim:
4286
+ LGR.info(f"trimming read list to {readlim} from {s}")
4287
+ s = readlim
4288
+ x1 = np.zeros((tclen, s))
4289
+ y1 = np.zeros((tclen, s))
4290
+ names = []
4291
+ if usebadpts:
4292
+ bad1 = np.zeros((tclen, s))
4293
+
4294
+ # now read the data in
4295
+ count = 0
4296
+ LGR.info("checking data")
4297
+ lowcorrfiles = []
4298
+ nanfiles = []
4299
+ shortfiles = []
4300
+ strangemagfiles = []
4301
+ for i in range(readskip, readskip + s):
4302
+ lowcorrfound = False
4303
+ nanfound = False
4304
+ LGR.info(f"processing {matchedfilelist[i]}")
4305
+
4306
+ # read the info dict first
4307
+ infodict = tide_io.readdictfromjson(
4308
+ matchedfilelist[i].replace("_desc-stdrescardfromfmri_timeseries", "_info")
4309
+ )
4310
+ if infodict["corrcoeff_raw2pleth"] < corrthresh:
4311
+ lowcorrfound = True
4312
+ lowcorrfiles.append(matchedfilelist[i])
4313
+ (
4314
+ samplerate,
4315
+ starttime,
4316
+ columns,
4317
+ inputarray,
4318
+ compression,
4319
+ columnsource,
4320
+ ) = tide_io.readbidstsv(
4321
+ matchedfilelist[i],
4322
+ colspec="cardiacfromfmri_25.0Hz,normpleth",
4323
+ )
4324
+ tempy = inputarray[1, :]
4325
+ tempx = inputarray[0, :]
4326
+
4327
+ if np.any(np.isnan(tempy)):
4328
+ LGR.info(f"NaN found in file {matchedfilelist[i]} - discarding")
4329
+ nanfound = True
4330
+ nanfiles.append(matchedfilelist[i])
4331
+ if np.any(np.isnan(tempx)):
4332
+ nan_fname = targettoinput(
4333
+ matchedfilelist[i], targetfrag=targetfrag, inputfrag=inputfrag
4334
+ )
4335
+ LGR.info(f"NaN found in file {nan_fname} - discarding")
4336
+ nanfound = True
4337
+ nanfiles.append(nan_fname)
4338
+ strangefound = False
4339
+ if not (0.5 < np.std(tempx) < 20.0):
4340
+ strange_fname = matchedfilelist[i]
4341
+ LGR.info(
4342
+ f"file {strange_fname} has an extreme cardiacfromfmri standard deviation - discarding"
4343
+ )
4344
+ strangefound = True
4345
+ strangemagfiles.append(strange_fname)
4346
+ if not (0.5 < np.std(tempy) < 20.0):
4347
+ LGR.info(
4348
+ f"file {matchedfilelist[i]} has an extreme normpleth standard deviation - discarding"
4349
+ )
4350
+ strangefound = True
4351
+ strangemagfiles.append(matchedfilelist[i])
4352
+ shortfound = False
4353
+ ntempx = tempx.shape[0]
4354
+ ntempy = tempy.shape[0]
4355
+ if ntempx < tclen:
4356
+ short_fname = matchedfilelist[i]
4357
+ LGR.info(f"file {short_fname} is short - discarding")
4358
+ shortfound = True
4359
+ shortfiles.append(short_fname)
4360
+ if ntempy < tclen:
4361
+ LGR.info(f"file {matchedfilelist[i]} is short - discarding")
4362
+ shortfound = True
4363
+ shortfiles.append(matchedfilelist[i])
4364
+ if (
4365
+ (ntempx >= tclen)
4366
+ and (ntempy >= tclen)
4367
+ and (not nanfound)
4368
+ and (not shortfound)
4369
+ and (not strangefound)
4370
+ and (not lowcorrfound)
4371
+ ):
4372
+ x1[:tclen, count] = tempx[:tclen]
4373
+ y1[:tclen, count] = tempy[:tclen]
4374
+ names.append(matchedfilelist[i])
4375
+ if debug:
4376
+ print(f"{matchedfilelist[i]} included:")
4377
+ if usebadpts:
4378
+                bad1[:tclen, count] = inputarray[2, :]  # NB: assumes a third (bad-point) column was read; the colspec above requests only two
4379
+ count += 1
4380
+ else:
4381
+ print(f"{matchedfilelist[i]} excluded:")
4382
+ if ntempx < tclen:
4383
+ print("\tx data too short")
4384
+ if ntempy < tclen:
4385
+ print("\ty data too short")
4386
+ print(f"\t{nanfound=}")
4387
+ print(f"\t{shortfound=}")
4388
+ print(f"\t{strangefound=}")
4389
+ print(f"\t{lowcorrfound=}")
4390
+ LGR.info(f"{count} runs pass file length check")
4391
+ if len(lowcorrfiles) > 0:
4392
+ LGR.info("files with low raw/pleth correlations:")
4393
+ for thefile in lowcorrfiles:
4394
+ LGR.info(f"\t{thefile}")
4395
+ if len(nanfiles) > 0:
4396
+ LGR.info("files with NaNs:")
4397
+ for thefile in nanfiles:
4398
+ LGR.info(f"\t{thefile}")
4399
+ if len(shortfiles) > 0:
4400
+ LGR.info("short files:")
4401
+ for thefile in shortfiles:
4402
+ LGR.info(f"\t{thefile}")
4403
+ if len(strangemagfiles) > 0:
4404
+ LGR.info("files with extreme standard deviations:")
4405
+ for thefile in strangemagfiles:
4406
+ LGR.info(f"\t{thefile}")
4407
+
4408
+ print(f"training set contains {count} runs of length {tclen}")
4409
+    # map endskip == 0 to None so the trailing slice keeps the data
+    lastpt = -endskip if endskip > 0 else None
+    if usebadpts:
+        return (
+            x1[startskip:lastpt, :count],
+            y1[startskip:lastpt, :count],
+            names[:count],
+            bad1[startskip:lastpt, :count],
+        )
+    else:
+        return (
+            x1[startskip:lastpt, :count],
+            y1[startskip:lastpt, :count],
+            names[:count],
+        )
4422
+
4423
+
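The prep function below normalizes each run by its median absolute deviation and then slices it into overlapping windows. A standalone sketch of the windowing arithmetic with illustrative values:

    import numpy as np

    window_size, step = 100, 10
    tc = np.random.randn(1200)  # one normalized timecourse
    windowspersubject = int((len(tc) - window_size - 1) // step)
    windows = np.stack(
        [tc[step * i : step * i + window_size] for i in range(windowspersubject)]
    )
    print(windows.shape)  # (109, 100)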
4424
+ def prep(
4425
+ window_size: int,
4426
+ step: int = 1,
4427
+ excludethresh: float = 4.0,
4428
+ usebadpts: bool = False,
4429
+ startskip: int = 200,
4430
+ endskip: int = 200,
4431
+ excludebysubject: bool = True,
4432
+ thesuffix: str = "sliceres",
4433
+ thedatadir: str = "/data/frederic/physioconn/output_2025",
4434
+ inputfrag: str = "abc",
4435
+ targetfrag: str = "xyz",
4436
+ corrthresh: float = 0.5,
4437
+ dofft: bool = False,
4438
+ readlim: int | None = None,
4439
+ readskip: int | None = None,
4440
+ countlim: int | None = None,
4441
+ debug: bool = False,
4442
+ ) -> (
4443
+ tuple[NDArray, NDArray, NDArray, NDArray, int, int, int]
4444
+ | tuple[NDArray, NDArray, NDArray, NDArray, int, int, int, NDArray, NDArray]
4445
+ ):
4446
+ """
4447
+ Prepare time-series data for training and validation by reading, normalizing,
4448
+ windowing, and splitting into batches.
4449
+
4450
+ This function reads physiological time-series data from JSON files, normalizes
4451
+ the data, and organizes it into overlapping windows for model training and
4452
+ validation. It supports filtering by subject or by window, and can optionally
4453
+ apply FFT transformations to the data.
4454
+
4455
+ Parameters
4456
+ ----------
4457
+ window_size : int
4458
+ Size of the sliding window used to segment time series data.
4459
+ step : int, optional
4460
+ Step size for sliding window (default is 1).
4461
+ excludethresh : float, optional
4462
+ Threshold for excluding data points based on maximum absolute value
4463
+ (default is 4.0).
4464
+ usebadpts : bool, optional
4465
+ If True, includes bad points in the data processing (default is False).
4466
+ startskip : int, optional
4467
+ Number of time points to skip at the beginning of each time series
4468
+ (default is 200).
4469
+ endskip : int, optional
4470
+ Number of time points to skip at the end of each time series
4471
+ (default is 200).
4472
+ excludebysubject : bool, optional
4473
+ If True, exclude subjects with any region exceeding `excludethresh`;
4474
+ otherwise, exclude windows (default is True).
4475
+ thesuffix : str, optional
4476
+ Suffix used in file search pattern (default is "sliceres").
4477
+ thedatadir : str, optional
4478
+ Directory path where the data files are stored (default is
4479
+ "/data/frederic/physioconn/output_2025").
4480
+ inputfrag : str, optional
4481
+ Fragment identifier for input data (default is "abc").
4482
+ targetfrag : str, optional
4483
+ Fragment identifier for target data (default is "xyz").
4484
+ corrthresh : float, optional
4485
+ Correlation threshold for data filtering (default is 0.5).
4486
+ dofft : bool, optional
4487
+ If True, apply FFT transformation to the data (default is False).
4488
+ readlim : int, optional
4489
+ Limit on number of time points to read (default is None).
4490
+ readskip : int, optional
4491
+ Number of time points to skip when reading data (default is None).
4492
+ countlim : int, optional
4493
+ Maximum number of subjects to include (default is None).
4494
+ debug : bool, optional
4495
+ If True, enable debug logging (default is False).
4496
+
4497
+ Returns
4498
+ -------
4499
+ tuple of (NDArray, NDArray, NDArray, NDArray, int, int, int)
4500
+ If `dofft` is False:
4501
+ - train_x : Training input data (shape: [n_windows, window_size, 1])
4502
+ - train_y : Training target data (shape: [n_windows, window_size, 1])
4503
+ - val_x : Validation input data (shape: [n_windows, window_size, 1])
4504
+ - val_y : Validation target data (shape: [n_windows, window_size, 1])
4505
+ - N_subjs : Number of subjects
4506
+ - tclen : Total time points after skipping
4507
+ - batchsize : Number of windows per subject
4508
+
4509
+ tuple of (NDArray, NDArray, NDArray, NDArray, int, int, int, NDArray, NDArray)
4510
+ If `dofft` is True:
4511
+ - train_x : Training input data (shape: [n_windows, window_size, 2])
4512
+ - train_y : Training target data (shape: [n_windows, window_size, 2])
4513
+ - val_x : Validation input data (shape: [n_windows, window_size, 2])
4514
+ - val_y : Validation target data (shape: [n_windows, window_size, 2])
4515
+ - N_subjs : Number of subjects
4516
+ - tclen : Total time points after skipping
4517
+ - batchsize : Number of windows per subject
4518
+ - Xscale_fourier : Fourier scaling for input data
4519
+ - Yscale_fourier : Fourier scaling for target data
4520
+
4521
+ Notes
4522
+ -----
4523
+ - Data normalization is performed using median absolute deviation (MAD).
4524
+ - The function supports both window-based and subject-based exclusion strategies.
4525
+ - If `usebadpts` is True, bad points are included in the output arrays.
4526
+ - FFT transformations are applied using a helper function `filtscale`.
4527
+
4528
+ Examples
4529
+ --------
4530
+ >>> train_x, train_y, val_x, val_y, N_subjs, tclen, batchsize = prep(
4531
+ ... window_size=100,
4532
+ ... step=10,
4533
+ ... excludethresh=3.0,
4534
+ ... excludebysubject=True,
4535
+ ... dofft=False
4536
+ ... )
4537
+ """
4538
+ searchstring = os.path.join(thedatadir, "*", "*_desc-stdrescardfromfmri_timeseries.json")
4539
+
4540
+ # find matched files
4541
+ matchedfilelist, tclen = getmatchedtcs(
4542
+ searchstring,
4543
+ usebadpts=usebadpts,
4544
+ targetfrag=targetfrag,
4545
+ inputfrag=inputfrag,
4546
+ debug=debug,
4547
+ )
4548
+ # print("matchedfilelist", matchedfilelist)
4549
+ print("tclen", tclen)
4550
+
4551
+ # read in the data from the matched files
4552
+ print("about to read in data")
4553
+ if usebadpts:
4554
+ x, y, names, bad = readindata(
4555
+ matchedfilelist,
4556
+ tclen,
4557
+ corrthresh=corrthresh,
4558
+ targetfrag=targetfrag,
4559
+ inputfrag=inputfrag,
4560
+ usebadpts=True,
4561
+ startskip=startskip,
4562
+ endskip=endskip,
4563
+ readlim=readlim,
4564
+ readskip=readskip,
4565
+ )
4566
+ else:
4567
+ x, y, names = readindata(
4568
+ matchedfilelist,
4569
+ tclen,
4570
+ corrthresh=corrthresh,
4571
+ targetfrag=targetfrag,
4572
+ inputfrag=inputfrag,
4573
+ startskip=startskip,
4574
+ endskip=endskip,
4575
+ readlim=readlim,
4576
+ readskip=readskip,
4577
+ )
4578
+ print("finished reading in data")
4579
+ LGR.info(f"xshape, yshape: {x.shape} {y.shape}")
4580
+
4581
+ # normalize input and output data
4582
+ LGR.info("normalizing data")
4583
+ LGR.info(f"count: {x.shape[1]}")
4584
+ if LGR.getEffectiveLevel() <= logging.DEBUG:
4585
+ # Only take these steps if the logger is set to DEBUG.
4586
+ for thesubj in range(x.shape[1]):
4587
+ LGR.debug(
4588
+ f"prenorm sub {thesubj} min, max, mean, std, MAD x, y: "
4589
+ f"{thesubj} "
4590
+ f"{np.min(x[:, thesubj])} {np.max(x[:, thesubj])} {np.mean(x[:, thesubj])} "
4591
+ f"{np.std(x[:, thesubj])} {mad(x[:, thesubj])} {np.min(y[:, thesubj])} "
4592
+ f"{np.max(y[:, thesubj])} {np.mean(y[:, thesubj])} {np.std(x[:, thesubj])} "
4593
+ f"{mad(y[:, thesubj])}"
4594
+ )
4595
+
4596
+ y -= np.mean(y, axis=0)
4597
+ themad = mad(y, axis=0)
4598
+ for thesubj in range(themad.shape[0]):
4599
+ if themad[thesubj] > 0.0:
4600
+ y[:, thesubj] /= themad[thesubj]
4601
+
4602
+ x -= np.mean(x, axis=0)
4603
+ themad = mad(x, axis=0)
4604
+ for thesubj in range(themad.shape[0]):
4605
+ if themad[thesubj] > 0.0:
4606
+ x[:, thesubj] /= themad[thesubj]
4607
+
4608
+ if LGR.getEffectiveLevel() <= logging.DEBUG:
4609
+ # Only take these steps if the logger is set to DEBUG.
4610
+ for thesubj in range(x.shape[1]):
4611
+ LGR.debug(
4612
+ f"postnorm sub {thesubj} min, max, mean, std, MAD x, y: "
4613
+ f"{thesubj} "
4614
+ f"{np.min(x[:, thesubj])} {np.max(x[:, thesubj])} {np.mean(x[:, thesubj])} "
4615
+ f"{np.std(x[:, thesubj])} {mad(x[:, thesubj])} {np.min(y[:, thesubj])} "
4616
+ f"{np.max(y[:, thesubj])} {np.mean(y[:, thesubj])} {np.std(x[:, thesubj])} "
4617
+ f"{mad(y[:, thesubj])}"
4618
+ )
4619
+
4620
+ # now decide what to keep and what to exclude
4621
+ thefabs = np.fabs(x)
4622
+ if not excludebysubject:
4623
+ N_pts = x.shape[0]
4624
+ N_subjs = x.shape[1]
4625
+ windowspersubject = np.int64((N_pts - window_size - 1) // step)
4626
+ LGR.info(
4627
+ f"{N_subjs} subjects with {N_pts} points will be evaluated with "
4628
+ f"{windowspersubject} windows per subject with step {step}"
4629
+ )
4630
+ usewindow = np.zeros(N_subjs * windowspersubject, dtype=np.int64)
4631
+ subjectstarts = np.zeros(N_subjs, dtype=np.int64)
4632
+ # check each window
4633
+ numgoodwindows = 0
4634
+ LGR.info("checking windows")
4635
+ subjectnames = []
4636
+ for subj in range(N_subjs):
4637
+ subjectstarts[subj] = numgoodwindows
4638
+ subjectnames.append(names[subj])
4639
+ LGR.info(f"{names[subj]} starts at {numgoodwindows}")
4640
+ for windownumber in range(windowspersubject):
4641
+ if (
4642
+ np.max(
4643
+ thefabs[
4644
+ step * windownumber : (step * windownumber + window_size),
4645
+ subj,
4646
+ ]
4647
+ )
4648
+ <= excludethresh
4649
+ ):
4650
+ usewindow[subj * windowspersubject + windownumber] = 1
4651
+ numgoodwindows += 1
4652
+ LGR.info(
4653
+ f"found {numgoodwindows} out of a potential {N_subjs * windowspersubject} "
4654
+ f"({100.0 * numgoodwindows / (N_subjs * windowspersubject)}%)"
4655
+ )
4656
+
4657
+ for subj in range(N_subjs):
4658
+ LGR.info(f"{names[subj]} starts at {subjectstarts[subj]}")
4659
+
4660
+ LGR.info("copying data into windows")
4661
+ Xb = np.zeros((numgoodwindows, window_size, 1))
4662
+ Yb = np.zeros((numgoodwindows, window_size, 1))
4663
+ if usebadpts:
4664
+ Xb_withbad = np.zeros((numgoodwindows, window_size, 1))
4665
+ LGR.info(f"dimensions of Xb: {Xb.shape}")
4666
+ thiswindow = 0
4667
+ for subj in range(N_subjs):
4668
+ for windownumber in range(windowspersubject):
4669
+ if usewindow[subj * windowspersubject + windownumber] == 1:
4670
+ Xb[thiswindow, :, 0] = x[
4671
+ step * windownumber : (step * windownumber + window_size), subj
4672
+ ]
4673
+ Yb[thiswindow, :, 0] = y[
4674
+ step * windownumber : (step * windownumber + window_size), subj
4675
+ ]
4676
+ if usebadpts:
4677
+ Xb_withbad[thiswindow, :, 0] = bad[
4678
+ step * windownumber : (step * windownumber + window_size),
4679
+ subj,
4680
+ ]
4681
+ thiswindow += 1
4682
+
4683
+ else:
4684
+ # now check for subjects that have regions that exceed the target
4685
+ themax = np.max(thefabs, axis=0)
4686
+
4687
+ cleansubjs = np.where(themax < excludethresh)[0]
4688
+
4689
+        totalcount = x.shape[1]
4690
+ cleancount = len(cleansubjs)
4691
+ if countlim is not None:
4692
+ if cleancount > countlim:
4693
+ LGR.info(f"reducing count to {countlim} from {cleancount}")
4694
+                cleansubjs = cleansubjs[:countlim]
+                cleancount = len(cleansubjs)
4695
+
4696
+ x = x[:, cleansubjs]
4697
+ y = y[:, cleansubjs]
4698
+ cleannames = []
4699
+ for theindex in cleansubjs:
4700
+ cleannames.append(names[theindex])
4701
+ if usebadpts:
4702
+ bad = bad[:, cleansubjs]
4703
+ subjectnames = cleannames
4704
+
4705
+ LGR.info(f"after filtering, shape of x is {x.shape}")
4706
+
4707
+ N_pts = y.shape[0]
4708
+ N_subjs = y.shape[1]
4709
+
4710
+ X = np.zeros((1, N_pts, N_subjs))
4711
+ Y = np.zeros((1, N_pts, N_subjs))
4712
+ if usebadpts:
4713
+ BAD = np.zeros((1, N_pts, N_subjs))
4714
+
4715
+ X[0, :, :] = x
4716
+ Y[0, :, :] = y
4717
+ if usebadpts:
4718
+ BAD[0, :, :] = bad
4719
+
4720
+        windowspersubject = int((N_pts - window_size - 1) // step)
+        LGR.info(
+            f"found {windowspersubject * cleancount} out of a potential "
+            f"{windowspersubject * totalcount} "
+            f"({100.0 * cleancount / totalcount:.1f}%)"
+        )
+        LGR.info(
+            f"windowspersubject, cleancount, totalcount: "
+            f"{windowspersubject} {cleancount} {totalcount}"
+        )
+
+        Xb = np.zeros((N_subjs * windowspersubject, window_size, 1))
+        LGR.info(f"dimensions of Xb: {Xb.shape}")
+        for j in range(N_subjs):
+            LGR.info(
+                f"sub {j} ({cleannames[j]}) min, max X, Y: "
+                f"{np.min(X[0, :, j])} {np.max(X[0, :, j])} {np.min(Y[0, :, j])} "
+                f"{np.max(Y[0, :, j])}"
+            )
+            for i in range(windowspersubject):
+                Xb[j * windowspersubject + i, :, 0] = X[0, step * i : (step * i + window_size), j]
+
+        Yb = np.zeros((N_subjs * windowspersubject, window_size, 1))
+        LGR.info(f"dimensions of Yb: {Yb.shape}")
+        for j in range(N_subjs):
+            for i in range(windowspersubject):
+                Yb[j * windowspersubject + i, :, 0] = Y[0, step * i : (step * i + window_size), j]
+
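+        # In this branch every window survives, so subject j's windows occupy
+        # the contiguous row block
+        # [j * windowspersubject, (j + 1) * windowspersubject), which is what
+        # the subjectstarts list below encodes.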
+        if usebadpts:
+            Xb_withbad = np.zeros((N_subjs * windowspersubject, window_size, 2))
+            LGR.info(f"dimensions of Xb_withbad: {Xb_withbad.shape}")
+            for j in range(N_subjs):
+                LGR.info(f"packing data for subject {j}")
+                for i in range(windowspersubject):
+                    Xb_withbad[j * windowspersubject + i, :, 0] = X[
+                        0, step * i : (step * i + window_size), j
+                    ]
+                    Xb_withbad[j * windowspersubject + i, :, 1] = BAD[
+                        0, step * i : (step * i + window_size), j
+                    ]
+            Xb = Xb_withbad
+
+        subjectstarts = [i * windowspersubject for i in range(N_subjs)]
+        for subj in range(N_subjs):
+            LGR.info(f"{subjectnames[subj]} starts at {subjectstarts[subj]}")
+
+    LGR.info(f"Xb.shape: {Xb.shape}")
+    LGR.info(f"Yb.shape: {Yb.shape}")
+
+    if dofft:
+        Xb_fourier = np.zeros((N_subjs * windowspersubject, window_size, 2))
+        LGR.info(f"dimensions of Xb_fourier: {Xb_fourier.shape}")
+        Xscale_fourier = np.zeros((N_subjs, windowspersubject))
+        LGR.info(f"dimensions of Xscale_fourier: {Xscale_fourier.shape}")
+        Yb_fourier = np.zeros((N_subjs * windowspersubject, window_size, 2))
+        LGR.info(f"dimensions of Yb_fourier: {Yb_fourier.shape}")
+        Yscale_fourier = np.zeros((N_subjs, windowspersubject))
+        LGR.info(f"dimensions of Yscale_fourier: {Yscale_fourier.shape}")
+        for j in range(N_subjs):
+            LGR.info(f"transforming subject {j}")
+            # one transform per window; ranging over windowspersubject keeps
+            # the indices within the arrays allocated above
+            for i in range(windowspersubject):
+                (
+                    Xb_fourier[j * windowspersubject + i, :, :],
+                    Xscale_fourier[j, i],
+                ) = filtscale(X[0, step * i : (step * i + window_size), j])
+                (
+                    Yb_fourier[j * windowspersubject + i, :, :],
+                    Yscale_fourier[j, i],
+                ) = filtscale(Y[0, step * i : (step * i + window_size), j])
+
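+        # filtscale (defined earlier in this module) is assumed here to return
+        # a two-channel spectral representation of the window plus a per-window
+        # normalization factor, hence the trailing dimension of 2 and the
+        # separate Xscale_fourier/Yscale_fourier bookkeeping arrays.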
+    limit = np.int64(0.8 * Xb.shape[0])
+    LGR.info(f"limit: {limit} out of {Xb.shape[0]} windows")
+    # find nearest subject start
+    firstvalsubject = np.abs(np.asarray(subjectstarts) - limit).argmin()
+    LGR.info(f"firstvalsubject: {firstvalsubject}")
+    perm_train = np.random.permutation(np.int64(np.arange(subjectstarts[firstvalsubject])))
+    perm_val = np.random.permutation(
+        np.int64(np.arange(subjectstarts[firstvalsubject], Xb.shape[0]))
+    )
+
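+    # The 80/20 split is snapped to the nearest subject boundary so no subject
+    # contributes windows to both sets.  Hypothetical numbers: 10 subjects with
+    # 34 windows each give 340 windows; limit = int(0.8 * 340) = 272, and the
+    # closest entry of subjectstarts [0, 34, ..., 306] is 272, so subjects 0-7
+    # train and subjects 8-9 validate.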
+    LGR.info("training subjects:")
+    for i in range(0, firstvalsubject):
+        LGR.info(f"\t{i} {subjectnames[i]}")
+    LGR.info("validation subjects:")
+    for i in range(firstvalsubject, len(subjectstarts)):
+        LGR.info(f"\t{i} {subjectnames[i]}")
+
+    perm = range(Xb.shape[0])
+
+    batchsize = windowspersubject
+
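+    # batchsize equals windowspersubject, presumably so that each training
+    # batch corresponds to roughly one subject's worth of windows.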
+    if dofft:
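+        # note: the Fourier path slices the unshuffled window stack at limit,
+        # rather than using the subject-aligned permutations computed above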
+        train_x = Xb_fourier[perm[:limit], :, :]
+        train_y = Yb_fourier[perm[:limit], :, :]
+
+        val_x = Xb_fourier[perm[limit:], :, :]
+        val_y = Yb_fourier[perm[limit:], :, :]
+        LGR.info(f"train, val dims: {train_x.shape} {train_y.shape} {val_x.shape} {val_y.shape}")
+        return (
+            train_x,
+            train_y,
+            val_x,
+            val_y,
+            N_subjs,
+            tclen - startskip - endskip,
+            batchsize,
+            Xscale_fourier,
+            Yscale_fourier,
+        )
+    else:
+        train_x = Xb[perm_train, :, :]
+        train_y = Yb[perm_train, :, :]
+
+        val_x = Xb[perm_val, :, :]
+        val_y = Yb[perm_val, :, :]
+
+        LGR.info(f"train, val dims: {train_x.shape} {train_y.shape} {val_x.shape} {val_y.shape}")
+        return (
+            train_x,
+            train_y,
+            val_x,
+            val_y,
+            N_subjs,
+            tclen - startskip - endskip,
+            batchsize,
+        )
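+    # Both paths return train/val window stacks plus metadata for the caller:
+    # the subject count, what is presumably the usable timecourse length after
+    # trimming (tclen - startskip - endskip, all defined earlier in this
+    # function), and the batch size; the Fourier path also returns the
+    # per-window scale factors needed to invert filtscale's normalization.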