braindecode 0.8.1__py3-none-any.whl → 1.0.0__py3-none-any.whl



Files changed (108)
  1. braindecode/__init__.py +1 -2
  2. braindecode/augmentation/__init__.py +39 -19
  3. braindecode/augmentation/base.py +25 -28
  4. braindecode/augmentation/functional.py +237 -100
  5. braindecode/augmentation/transforms.py +326 -158
  6. braindecode/classifier.py +26 -24
  7. braindecode/datasets/__init__.py +28 -10
  8. braindecode/datasets/base.py +220 -134
  9. braindecode/datasets/bbci.py +43 -52
  10. braindecode/datasets/bcicomp.py +47 -32
  11. braindecode/datasets/bids.py +245 -0
  12. braindecode/datasets/mne.py +45 -24
  13. braindecode/datasets/moabb.py +87 -27
  14. braindecode/datasets/nmt.py +311 -0
  15. braindecode/datasets/sleep_physio_challe_18.py +412 -0
  16. braindecode/datasets/sleep_physionet.py +43 -26
  17. braindecode/datasets/tuh.py +324 -140
  18. braindecode/datasets/xy.py +27 -12
  19. braindecode/datautil/__init__.py +34 -18
  20. braindecode/datautil/serialization.py +98 -71
  21. braindecode/eegneuralnet.py +63 -47
  22. braindecode/functional/__init__.py +10 -0
  23. braindecode/functional/functions.py +251 -0
  24. braindecode/functional/initialization.py +47 -0
  25. braindecode/models/__init__.py +36 -14
  26. braindecode/models/atcnet.py +153 -159
  27. braindecode/models/attentionbasenet.py +550 -0
  28. braindecode/models/base.py +86 -102
  29. braindecode/models/biot.py +483 -0
  30. braindecode/models/contrawr.py +296 -0
  31. braindecode/models/ctnet.py +450 -0
  32. braindecode/models/deep4.py +64 -75
  33. braindecode/models/deepsleepnet.py +149 -119
  34. braindecode/models/eegconformer.py +111 -171
  35. braindecode/models/eeginception_erp.py +109 -118
  36. braindecode/models/eeginception_mi.py +155 -97
  37. braindecode/models/eegitnet.py +215 -151
  38. braindecode/models/eegminer.py +255 -0
  39. braindecode/models/eegnet.py +229 -161
  40. braindecode/models/eegnex.py +247 -0
  41. braindecode/models/eegresnet.py +234 -152
  42. braindecode/models/eegsimpleconv.py +199 -0
  43. braindecode/models/eegtcnet.py +335 -0
  44. braindecode/models/fbcnet.py +221 -0
  45. braindecode/models/fblightconvnet.py +313 -0
  46. braindecode/models/fbmsnet.py +325 -0
  47. braindecode/models/hybrid.py +52 -71
  48. braindecode/models/ifnet.py +441 -0
  49. braindecode/models/labram.py +1166 -0
  50. braindecode/models/msvtnet.py +375 -0
  51. braindecode/models/sccnet.py +182 -0
  52. braindecode/models/shallow_fbcsp.py +50 -56
  53. braindecode/models/signal_jepa.py +1012 -0
  54. braindecode/models/sinc_shallow.py +337 -0
  55. braindecode/models/sleep_stager_blanco_2020.py +55 -46
  56. braindecode/models/sleep_stager_chambon_2018.py +54 -53
  57. braindecode/models/sleep_stager_eldele_2021.py +248 -141
  58. braindecode/models/sparcnet.py +378 -0
  59. braindecode/models/summary.csv +41 -0
  60. braindecode/models/syncnet.py +232 -0
  61. braindecode/models/tcn.py +158 -88
  62. braindecode/models/tidnet.py +280 -167
  63. braindecode/models/tsinception.py +258 -0
  64. braindecode/models/usleep.py +190 -177
  65. braindecode/models/util.py +109 -141
  66. braindecode/modules/__init__.py +38 -0
  67. braindecode/modules/activation.py +60 -0
  68. braindecode/modules/attention.py +757 -0
  69. braindecode/modules/blocks.py +108 -0
  70. braindecode/modules/convolution.py +274 -0
  71. braindecode/modules/filter.py +632 -0
  72. braindecode/modules/layers.py +133 -0
  73. braindecode/modules/linear.py +50 -0
  74. braindecode/modules/parametrization.py +38 -0
  75. braindecode/modules/stats.py +77 -0
  76. braindecode/modules/util.py +77 -0
  77. braindecode/modules/wrapper.py +75 -0
  78. braindecode/preprocessing/__init__.py +36 -11
  79. braindecode/preprocessing/mne_preprocess.py +13 -7
  80. braindecode/preprocessing/preprocess.py +139 -75
  81. braindecode/preprocessing/windowers.py +576 -187
  82. braindecode/regressor.py +23 -12
  83. braindecode/samplers/__init__.py +16 -8
  84. braindecode/samplers/base.py +148 -32
  85. braindecode/samplers/ssl.py +162 -17
  86. braindecode/training/__init__.py +18 -10
  87. braindecode/training/callbacks.py +2 -4
  88. braindecode/training/losses.py +3 -8
  89. braindecode/training/scoring.py +76 -68
  90. braindecode/util.py +55 -59
  91. braindecode/version.py +1 -1
  92. braindecode/visualization/__init__.py +2 -3
  93. braindecode/visualization/confusion_matrices.py +117 -73
  94. braindecode/visualization/gradients.py +14 -10
  95. {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info}/METADATA +39 -55
  96. braindecode-1.0.0.dist-info/RECORD +101 -0
  97. {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info}/WHEEL +1 -1
  98. {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info/licenses}/LICENSE.txt +1 -1
  99. braindecode-1.0.0.dist-info/licenses/NOTICE.txt +20 -0
  100. braindecode/datautil/mne.py +0 -9
  101. braindecode/datautil/preprocess.py +0 -12
  102. braindecode/datautil/windowers.py +0 -6
  103. braindecode/datautil/xy.py +0 -9
  104. braindecode/models/eeginception.py +0 -317
  105. braindecode/models/functions.py +0 -47
  106. braindecode/models/modules.py +0 -358
  107. braindecode-0.8.1.dist-info/RECORD +0 -68
  108. {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info}/top_level.txt +0 -0
braindecode/models/biot.py (new file)
@@ -0,0 +1,483 @@
+ import math
+ from warnings import warn
+
+ import torch
+ import torch.nn as nn
+ from linear_attention_transformer import LinearAttentionTransformer
+
+ from braindecode.models.base import EEGModuleMixin
+
+
+ class BIOT(EEGModuleMixin, nn.Module):
+     """BIOT from Yang et al. (2023) [Yang2023]_
+
+     .. figure:: https://braindecode.org/dev/_static/model/biot.jpg
+        :align: center
+        :alt: BioT
+
+     BIOT: Cross-data Biosignal Learning in the Wild.
+
+     BIOT is a transformer-based model for biosignal classification. It is
+     a wrapper around the `BIOTEncoder` and `ClassificationHead` modules.
+
+     It is designed for N-dimensional biosignal data such as EEG, ECG, etc.
+     The method was proposed by Yang et al. [Yang2023]_ and the code is
+     available at [Code2023]_.
+
+     The original model was pre-trained with a contrastive loss on large
+     biosignal corpora: the TUH Abnormal EEG Corpus (about 400K samples) and
+     the Sleep Heart Health Study (about 5M samples). Here, we only provide
+     the model architecture, not the pre-trained weights or the contrastive
+     pre-training.
+
+     The architecture is based on the `LinearAttentionTransformer` and
+     `PatchFrequencyEmbedding` modules.
+     The `BIOTEncoder` is a transformer that maps the input data to a
+     fixed-size representation. More details are given in the
+     `BIOTEncoder` class.
+
+     The `ClassificationHead` is an ELU activation, followed by a simple
+     linear layer that takes the output of the `BIOTEncoder` and produces
+     the classification logits.
+
+     .. versionadded:: 0.9
+
+     Parameters
+     ----------
+     emb_size : int, optional
+         The size of the embedding layer, by default 256
+     att_num_heads : int, optional
+         The number of attention heads, by default 8
+     n_layers : int, optional
+         The number of transformer layers, by default 4
+     activation : nn.Module, default=nn.ELU
+         Activation function class to apply. Should be a PyTorch activation
+         module class like ``nn.ReLU`` or ``nn.ELU``. Default is ``nn.ELU``.
+     return_feature : bool, optional
+         If True, the forward pass returns the encoder embedding alongside
+         the output tensor; if False (default), only the output tensor is
+         returned.
+     hop_length : int, optional
+         The hop length for the torch.stft transformation in the
+         encoder. The default is 100.
+     sfreq : int, optional
+         Sampling frequency of the input data; also used as ``n_fft`` in the
+         encoder's STFT. The default is 200.
+     drop_prob : float, optional
+         Dropout rate applied in the encoder's positional encoding.
+         The default is 0.5.
+
+     References
+     ----------
+     .. [Yang2023] Yang, C., Westover, M.B. and Sun, J., 2023, November.
+        BIOT: Biosignal Transformer for Cross-data Learning in the Wild.
+        In Thirty-seventh Conference on Neural Information Processing
+        Systems, NeurIPS.
+     .. [Code2023] Yang, C., Westover, M.B. and Sun, J., 2023. BIOT:
+        Biosignal Transformer for Cross-data Learning in the Wild.
+        GitHub https://github.com/ycq091044/BIOT (accessed 2024-02-13)
+     """
+
+     def __init__(
+         self,
+         emb_size=256,
+         att_num_heads=8,
+         n_layers=4,
+         sfreq=200,
+         hop_length=100,
+         return_feature=False,
+         n_outputs=None,
+         n_chans=None,
+         chs_info=None,
+         n_times=None,
+         input_window_seconds=None,
+         activation: nn.Module = nn.ELU,
+         drop_prob: float = 0.5,
+     ):
+         super().__init__(
+             n_outputs=n_outputs,
+             n_chans=n_chans,
+             chs_info=chs_info,
+             n_times=n_times,
+             input_window_seconds=input_window_seconds,
+             sfreq=sfreq,
+         )
+         del n_outputs, n_chans, chs_info, n_times, sfreq
+         self.emb_size = emb_size
+         self.hop_length = hop_length
+         self.att_num_heads = att_num_heads
+         self.n_layers = n_layers
+         self.return_feature = return_feature
+         if self.sfreq is not None and self.sfreq != 200:
+             warn(
+                 "This model has only been trained on data sampled at "
+                 "200 Hz; there is no guarantee it will generalize well "
+                 "with the default parameters.",
+                 UserWarning,
+             )
+         if self.n_chans > emb_size:
+             warn(
+                 "The number of channels is larger than the embedding size. "
+                 "This may cause overfitting. Consider using a larger "
+                 "embedding size or a smaller number of channels.",
+                 UserWarning,
+             )
+         if self.hop_length > self.sfreq:
+             warn(
+                 "The hop length is larger than the window length "
+                 "(n_fft = sfreq), so samples between frames would be "
+                 "skipped. Falling back to hop_length = sfreq // 2.",
+                 UserWarning,
+             )
+             hop_length = self.sfreq // 2
+             # Keep the stored attribute consistent with the value
+             # actually passed to the encoder.
+             self.hop_length = hop_length
+         self.encoder = _BIOTEncoder(
+             emb_size=emb_size,
+             att_num_heads=att_num_heads,
+             n_layers=n_layers,
+             n_chans=self.n_chans,
+             n_fft=self.sfreq,
+             hop_length=hop_length,
+             drop_prob=drop_prob,
+         )
+
+         self.final_layer = _ClassificationHead(
+             emb_size=emb_size,
+             n_outputs=self.n_outputs,
+             activation=activation,
+         )
+
+     def forward(self, x):
+         """
+         Pass the input through the BIOT encoder, and then through the
+         classification head.
+
+         Parameters
+         ----------
+         x: Tensor
+             (batch_size, n_channels, n_times)
+
+         Returns
+         -------
+         out: Tensor
+             (batch_size, n_outputs), returned when ``return_feature``
+             is False.
+         (out, emb): tuple of Tensors
+             (batch_size, n_outputs) and (batch_size, emb_size), returned
+             when ``return_feature`` is True.
+         """
+         emb = self.encoder(x)
+         x = self.final_layer(emb)
+
+         if self.return_feature:
+             return x, emb
+         else:
+             return x
+
+
+ class _PatchFrequencyEmbedding(nn.Module):
+     """
+     Patch Frequency Embedding.
+
+     A simple linear layer learns a representation over the frequency
+     domain; the input is first permuted so that the frequency axis can be
+     projected to the embedding size.
+
+     Parameters
+     ----------
+     emb_size: int
+         The size of the embedding layer
+     n_freq: int
+         The number of frequency points after the Fourier transform
+
+     Returns
+     -------
+     out: Tensor
+         (batch, time, emb_size)
+     """
+
+     def __init__(self, emb_size: int = 256, n_freq: int = 101):
+         super().__init__()
+         self.projection = nn.Linear(n_freq, emb_size)
+
+     def forward(self, x):
+         """
+         Parameters
+         ----------
+         x: Tensor
+             (batch, n_freq, time)
+
+         Returns
+         -------
+         out: Tensor
+             (batch, time, emb_size)
+         """
+         x = x.permute(0, 2, 1)
+         x = self.projection(x)
+         return x
+
+
+ class _ClassificationHead(nn.Module):
+     """
+     Classification head for the BIOT model.
+
+     Simple linear layer with ELU activation function.
+
+     Parameters
+     ----------
+     emb_size: int
+         The size of the embedding layer
+     n_outputs: int
+         The number of classes
+     activation: nn.Module, default=nn.ELU
+         Activation function class to apply. Should be a PyTorch activation
+         module class like ``nn.ReLU`` or ``nn.ELU``. Default is ``nn.ELU``.
+
+     Returns
+     -------
+     out: Tensor
+         (batch, n_outputs)
+     """
+
+     def __init__(self, emb_size: int, n_outputs: int, activation: nn.Module = nn.ELU):
+         super().__init__()
+         self.classification_head = nn.Sequential(
+             activation(),
+             nn.Linear(emb_size, n_outputs),
+         )
+
+     def forward(self, x):
+         out = self.classification_head(x)
+         return out
+
+
+ class _PositionalEncoding(nn.Module):
+     """
+     Positional Encoding.
+
+     We first create a ``pe`` zero matrix of shape (max_len, emb_size),
+     where max_len is the maximum length of the sequence and emb_size is
+     the size of the embedding.
+
+     Then we create a ``position`` tensor of shape (max_len, 1) with indices
+     from 0 to max_len, and a ``div_term`` tensor of shape (emb_size // 2,)
+     holding the exponential of the even indices from 0 to emb_size
+     multiplied by -log(10000.0) / emb_size.
+
+     For more details about the positional encoding, see the
+     `Attention Is All You Need` paper.
+
+     Parameters
+     ----------
+     emb_size: int
+         The size of the embedding layer
+     drop_prob: float, default=0.1
+         The dropout rate for regularization. Values should be between
+         0 and 1.
+     max_len: int, default=1000
+         The maximum length of the sequence
+
+     Returns
+     -------
+     out: Tensor
+         (batch, max_len, emb_size)
+     """
+
+     def __init__(self, emb_size: int, drop_prob: float = 0.1, max_len: int = 1000):
+         super().__init__()
+
+         # Compute the positional encodings once in log space.
+         pe = torch.zeros(max_len, emb_size)
+         position = torch.arange(0, max_len).unsqueeze(1).float()
+         div_term = torch.exp(
+             torch.arange(0, emb_size, 2).float() * -(math.log(10000.0) / emb_size)
+         )
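+         # Together with the sin/cos assignments below, this yields the
+         # closed form from `Attention Is All You Need`:
+         #   pe[pos, 2i]   = sin(pos / 10000**(2i / emb_size))
+         #   pe[pos, 2i+1] = cos(pos / 10000**(2i / emb_size))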
+         pe[:, 0::2] = torch.sin(position * div_term)
+         pe[:, 1::2] = torch.cos(position * div_term)
+         pe = pe.unsqueeze(0)
+         self.register_buffer("pe", pe)
+         self.dropout = nn.Dropout(p=drop_prob)
+
+     def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
+         """
+         Parameters
+         ----------
+         x: FloatTensor
+             embeddings, shape (batch, seq_len, emb_size)
+
+         Returns
+         -------
+         Tensor
+             encoder input, shape (batch, seq_len, emb_size)
+         """
+         x = x + self.pe[:, : x.size(1)]
+         return self.dropout(x)
+
+
+ class _BIOTEncoder(nn.Module):
+     """
+     BIOT Encoder.
+
+     The BIOT encoder is a transformer that takes the time series input data
+     and returns a fixed-size embedding representation of the input data.
+     The architecture is based on the `LinearAttentionTransformer` and
+     `PatchFrequencyEmbedding` modules.
+
+     The input data is transformed into a spectrogram and then embedded
+     using a "patch" embedding.
+     The channel token is added to the patch embedding, and then
+     positional encoding is applied (simple index positional).
+     The resulting embeddings are concatenated and passed through a
+     transformer layer. The mean across the different channel embeddings
+     is returned.
+
+     Parameters
+     ----------
+     n_chans: int
+         The number of channels
+     emb_size: int
+         The size of the embedding layer
+     att_num_heads: int
+         The number of attention heads
+     n_layers: int
+         The number of transformer layers
+     n_fft: int
+         The number of Fourier transform points
+     hop_length: int (default 100)
+         The distance between neighboring sliding window frames
+     """
+
+     def __init__(
+         self,
+         emb_size=256,  # The size of the embedding layer
+         att_num_heads=8,  # The number of attention heads
+         n_chans=16,  # The number of channels
+         n_layers=4,  # The number of transformer layers
+         n_fft=200,  # Determines the frequency resolution
+         hop_length=100,
+         drop_prob: float = 0.1,
+     ):
+         super().__init__()
+
+         self.n_fft = int(n_fft)
+         self.hop_length = hop_length
+
+         self.patch_embedding = _PatchFrequencyEmbedding(
+             emb_size=emb_size, n_freq=int(self.n_fft // 2 + 1)
+         )
+         self.transformer = LinearAttentionTransformer(
+             dim=emb_size,
+             heads=att_num_heads,
+             depth=n_layers,
+             max_seq_len=1024,
+             attn_layer_dropout=0.2,  # dropout right after the self-attention layer
+             attn_dropout=0.2,  # dropout post-attention
+         )
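+         # Note: forward() feeds the concatenation of all channel embeddings
+         # through this transformer, so n_chans * n_frames should not exceed
+         # max_seq_len (1024 here).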
+         self.positional_encoding = _PositionalEncoding(emb_size, drop_prob=drop_prob)
+
+         # channel tokens; n_chans must be >= the actual number of channels
+         self.channel_tokens = nn.Embedding(
+             num_embeddings=n_chans, embedding_dim=emb_size
+         )
+         self.index = nn.Parameter(torch.LongTensor(range(n_chans)), requires_grad=False)
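+         # self.index[i + n_channel_offset] is used in forward() to select
+         # the channel token for input channel i, so i + n_channel_offset
+         # must stay below n_chans.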
+
+     def stft(self, sample):
+         """
+         Short-time Fourier transform.
+         For more details, see `torch.stft`.
+
+         The size of the Fourier transform is given by `n_fft`, and the
+         distance between neighboring sliding window frames by `hop_length`,
+         both defined in the __init__ function.
+
+         Parameters
+         ----------
+         sample: Tensor
+             channel representation with size (batch_size, n_times)
+
+         Returns
+         -------
+         spectral: Tensor
+             Absolute value of the Fourier transform with size
+             (batch_size, n_fft // 2 + 1, 1 + (n_times - n_fft) // hop_length),
+             since ``center=False``.
+         """
+         spectral = torch.stft(
+             input=sample.squeeze(1),
+             n_fft=int(self.n_fft),
+             hop_length=self.hop_length,
+             center=False,
+             onesided=True,
+             return_complex=True,
+         )
+         return torch.abs(spectral)
+
+     def forward(self, x, n_channel_offset=0, perturb=False):
+         """
+         Forward pass of the BIOT encoder.
+
+         Each channel is transformed into a spectrogram with the STFT; the
+         spectrogram representation is permuted and passed through a linear
+         layer to learn a representation over the frequency domain (the
+         patch embedding).
+
+         For each embedding in the sequence, the channel token is added to
+         the patch embedding, and then positional encoding is applied.
+
+         The resulting embeddings are concatenated and passed through a
+         transformer layer. The mean of the resulting embeddings is
+         returned.
+
+         Parameters
+         ----------
+         x: Tensor
+             (batch_size, n_channels, n_times)
+         n_channel_offset: int (default 0)
+             The offset term to be added to the channel token indices
+         perturb: bool (default False)
+             If True, randomly subsample the time steps of each channel
+             embedding, keeping between ts // 2 and ts steps.
+
+         Returns
+         -------
+         emb: Tensor
+             (batch_size, emb_size)
+         """
+         emb_seq = []
+         for i in range(x.shape[1]):
+             # Getting the spectrogram
+             channel_spec_emb = self.stft(x[:, i : i + 1, :])
+             # Linear layer to learn some representation over the frequency
+             # domain, with permutation
+             channel_spec_emb = self.patch_embedding(channel_spec_emb)
+             batch_size, ts, _ = channel_spec_emb.shape
+             # (batch_size, ts, emb)
+             # Step by step, the following lines do the following operations:
+             # - self.channel_tokens(self.index[i + n_channel_offset]):
+             #   fetches the embedding for the channel specified by
+             #   i + n_channel_offset, where i is the current index and
+             #   n_channel_offset adjusts which channel's embedding is
+             #   retrieved.
+             # - .unsqueeze(0).unsqueeze(0):
+             #   adds two singleton dimensions to the embedding tensor,
+             #   [emb_size] to [1, 1, emb_size].
+             # - .repeat(batch_size, ts, 1):
+             #   replicates the embedding tensor to match the batch size
+             #   and the number of time steps (ts).
+             channel_token_emb = (
+                 self.channel_tokens(self.index[i + n_channel_offset])
+                 .unsqueeze(0)
+                 .unsqueeze(0)
+                 .repeat(batch_size, ts, 1)
+             )
+             # (batch_size, ts, emb)
+             # The positional encoding is explained in more detail in the
+             # _PositionalEncoding class.
+             channel_emb = self.positional_encoding(channel_spec_emb + channel_token_emb)
+             # In case of perturb, a random number of time steps is selected
+             # and the channel embedding is reduced to those time steps.
+             if perturb:
+                 ts = channel_emb.shape[1]
+                 ts_new = torch.randint(low=ts // 2, high=ts, size=(1,)).item()
+                 selected_ts = torch.randperm(ts)[:ts_new]
+                 channel_emb = channel_emb[:, selected_ts]
+             emb_seq.append(channel_emb)
+
+         # Concat and transformer
+         # (batch_size, n_chans * ts, emb)
+         emb = torch.cat(emb_seq, dim=1)
+         # (batch_size, emb)
+         emb = self.transformer(emb).mean(dim=1)
+         return emb
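
For orientation, here is a minimal usage sketch of the new model, based on the constructor and forward signature shown above (the import path, shapes, and parameter values are illustrative, not taken from this diff):

    import torch
    from braindecode.models import BIOT

    # 4-class classification on 16-channel EEG sampled at 200 Hz
    model = BIOT(n_outputs=4, n_chans=16, n_times=2000, sfreq=200)
    x = torch.randn(8, 16, 2000)  # (batch_size, n_channels, n_times)
    logits = model(x)             # (batch_size, n_outputs) == (8, 4)

    # With return_feature=True, forward also yields the encoder embedding
    model = BIOT(n_outputs=4, n_chans=16, n_times=2000, sfreq=200,
                 return_feature=True)
    logits, emb = model(x)        # (8, 4) and (8, 256)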