ml4gw 0.7.4__py3-none-any.whl → 0.7.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

ml4gw/gw.py CHANGED
@@ -2,13 +2,11 @@
 Tools for manipulating raw gravitational waveforms
 and projecting them onto interferometer responses.
 Much of the projection code is an extension of the
- implementation made available in bilby:
-
- https://arxiv.org/abs/1811.02042
-
- Specifically the code here:
- https://github.com/lscsoft/bilby/blob/master/bilby/gw/detector/interferometer.py
- """
+ implementation made available in
+ `bilby <https://arxiv.org/abs/1811.02042>`_.
+ Specifically code from
+ `this module <https://github.com/lscsoft/bilby/blob/master/bilby/gw/detector/interferometer.py>`_.
+ """ # noqa E501

 from typing import List, Tuple, Union

@@ -134,6 +132,9 @@ def compute_antenna_responses(
 # shape: batch x num_polarizations x 3 x 3
 polarization = torch.stack(polarizations, axis=1)

+ # Ensure dtype consistency before einsum
+ detector_tensors = detector_tensors.to(polarization.dtype)
+
 # compute the weight of each interferometer's response
 # to each polarization: batch x polarizations x ifos
 return torch.einsum("...jk,ijk->...i", polarization, detector_tensors)
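The new cast above guards the einsum against mixed-precision operands (e.g. float64 detector tensors meeting float32 polarizations). A minimal standalone sketch of the failure mode and the fix, in plain PyTorch rather than ml4gw's API:

import torch

# hypothetical inputs: float64 detector tensors, float32 polarization tensors
detector_tensors = torch.randn(2, 3, 3, dtype=torch.float64)   # ifos x 3 x 3
polarization = torch.randn(8, 2, 3, 3, dtype=torch.float32)    # batch x pols x 3 x 3

# mixed dtypes can make the einsum below fail, so align them first
detector_tensors = detector_tensors.to(polarization.dtype)
weights = torch.einsum("...jk,ijk->...i", polarization, detector_tensors)
print(weights.shape)  # torch.Size([8, 2, 2]): batch x polarizations x ifos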
@@ -194,7 +195,7 @@ def compute_observed_strain(
 **polarizations: Float[Tensor, "batch time"],
 ) -> WaveformTensor:
 """
- Compute the strain timeseries $h(t)$ observed by a network
+ Compute the strain timeseries :math:`h(t)` observed by a network
 of interferometers from the given polarization timeseries
 corresponding to gravitational waveforms from sources with
 the indicated sky parameters.
@@ -222,13 +223,13 @@ def compute_observed_strain(
 between the waveform observed at the geocenter and
 the one observed at the detector site. To avoid
 adding any delay between the two, reset your coordinates
- such that the desired interferometer is at `(0., 0., 0.)`.
+ such that the desired interferometer is at ``(0., 0., 0.)``.
 sample_rate:
 Rate at which the polarization timeseries have been sampled
 polarziations:
 Timeseries for each waveform polarization which
 contributes to the interferometer response. Allowed
- polarizations are `cross`, `plus`, and `breathing`.
+ polarizations are ``cross``, ``plus``, and ``breathing``.
 Returns:
 Tensor representing the observed strain at each
 interferometer for each waveform.
@@ -236,13 +237,15 @@ def compute_observed_strain(

 # TODO: just use theta as the input parameter?
 # note that ** syntax is ordered, so we're safe
- # to be lazy and use `list` for the keys and values
+ # to be lazy and use ``list`` for the keys and values
 theta = torch.pi / 2 - dec
 antenna_responses = compute_antenna_responses(
 theta, psi, phi, detector_tensors, list(polarizations)
 )

 polarizations = torch.stack(list(polarizations.values()), axis=1)
+ # Ensure dtype consistency before einsum
+ antenna_responses = antenna_responses.to(polarizations.dtype)
 waveforms = torch.einsum(
 "...pi,...pt->...it", antenna_responses, polarizations
 )
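For reference, the "...pi,...pt->...it" contraction above weights each polarization timeseries by its antenna response and sums over polarizations. A toy, shape-only sketch of that step (the real function also applies per-detector time shifts, omitted here):

import torch

batch, num_pols, num_ifos, num_samples = 8, 2, 3, 2048
antenna_responses = torch.randn(batch, num_pols, num_ifos)  # F_{p,i}
polarizations = torch.randn(batch, num_pols, num_samples)   # h_p(t)

# h_i(t) = sum_p F_{p,i} * h_p(t) for each interferometer i
waveforms = torch.einsum("...pi,...pt->...it", antenna_responses, polarizations)
print(waveforms.shape)  # torch.Size([8, 3, 2048]): batch x ifos x time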
@@ -286,26 +289,28 @@ def compute_ifo_snr(
 highpass: Union[float, Float[Tensor, " frequency"], None] = None,
 lowpass: Union[float, Float[Tensor, " frequency"], None] = None,
 ) -> Float[Tensor, "batch num_ifos"]:
- r"""Compute the SNRs of a batch of interferometer responses
+ """Compute the SNRs of a batch of interferometer responses

 Compute the signal to noise ratio (SNR) of individual
 interferometer responses to gravitational waveforms with
 respect to a background PSD for each interferometer. The
- SNR of the $i$th waveform at the $j$th interferometer
+ SNR of the :math:`i` th waveform at the :math:`j` th interferometer
 is computed as:

- $$\rho_{ij} =
- 4 \int_{f_{\text{min}}}^{f_{\text{max}}}
- \frac{\tilde{h_{ij}}(f)\tilde{h_{ij}}^*(f)}
- {S_n^{(j)}(f)}df$$
+ .. math::

- Where $f_{\text{min}}$ is a minimum frequency denoted
- by `highpass`, `f_{\text{max}}` is the maximum frequency
- denoted by `lowpass`, which defaults to the Nyquist frequency
- dictated by `sample_rate`; `\tilde{h_{ij}}` and `\tilde{h_{ij}}*`
- indicate the fourier transform of the $i$th waveform at
- the $j$th inteferometer and its complex conjugate, respectively;
- and $S_n^{(j)}$ is the backround PSD at the $j$th interferometer.
+ \\rho_{ij} =
+ 4 \\int_{f_{\\text{min}}}^{f_{\\text{max}}}
+ \\frac{\\tilde{h_{ij}}(f)\\tilde{h_{ij}}^*(f)}
+ {S_n^{(j)}(f)}df
+
+ Where :math:`f_{\\text{min}}` is a minimum frequency denoted
+ by ``highpass``, :math:`f_{\\text{max}}` is the maximum frequency
+ denoted by ``lowpass``, which defaults to the Nyquist frequency
+ dictated by ``sample_rate``; :math:`\\tilde{h}_{ij}` and :math:`\\tilde{h}_{ij}^*`
+ indicate the fourier transform of the :math:`i` th waveform at
+ the :math:`j` th inteferometer and its complex conjugate, respectively;
+ and :math:`S_n^{(j)}` is the backround PSD at the :math:`j` th interferometer.

 Args:
 responses:
@@ -314,12 +319,12 @@ def compute_ifo_snr(
 psd:
 The one-sided power spectral density of the background
 noise at each interferometer to which a response
- in `responses` has been calculated. If 2D, each row of
- `psd` will be assumed to be the background PSD for each
- channel of _every_ batch element in `responses`. If 3D,
+ in ``responses`` has been calculated. If 2D, each row of
+ ``psd`` will be assumed to be the background PSD for each
+ channel of _every_ batch element in ``responses``. If 3D,
 this should contain a background PSD for each channel
- of each element in `responses`, and therefore the first
- two dimensions of `psd` and `responses` should match.
+ of each element in ``responses``, and therefore the first
+ two dimensions of ``psd`` and ``responses`` should match.
 sample_rate:
 The frequency at which the waveform responses timeseries
 have been sampled. Upon fourier transforming, should
@@ -329,18 +334,18 @@ def compute_ifo_snr(
 If a tensor is provided, it will be assumed to be a
 pre-computed mask used to 0-out low frequency components.
 If a float, it will be used to compute such a mask. If
- left as `None`, all frequencies up to `lowpass`
+ left as ``None``, all frequencies up to ``lowpass``
 will contribute to the SNR calculation.
 lowpass:
 The maximum frequency below which to compute the SNR.
 If a tensor is provided, it will be assumed to be a
 pre-computed mask used to 0-out high frequency components.
 If a float, it will be used to compute such a mask. If
- left as `None`, all frequencies from `highpass` up to
+ left as ``None``, all frequencies from ``highpass`` up to
 the Nyquist freqyency will contribute to the SNR calculation.
 Returns:
 Batch of SNRs computed for each interferometer
- """
+ """ # noqa E501

 # TODO: should we do windowing here?
 # compute frequency power, upsampling precision so that
@@ -388,10 +393,10 @@ def compute_ifo_snr(
 # that the user specify the sample rate by taking the
 # fft as-is (without dividing by sample rate) and then
 # taking the mean here (or taking the sum and dividing
- # by the sum of `highpass` if it's a mask). If we want
+ # by the sum of ``highpass`` if it's a mask). If we want
 # to allow the user to pass a float for highpass, we'll
 # need the sample rate to compute the mask, but if we
- # replace this with a `mask` argument instead we're in
+ # replace this with a ``mask`` argument instead we're in
 # the clear
 df = sample_rate / responses.shape[-1]
 integrated = integrand.sum(axis=-1) * df
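The two lines above are the discrete form of the integral in the docstring: the integrand |h~(f)|^2 / S_n(f) is summed over frequency bins and scaled by the bin width df. A rough standalone sketch of that accumulation (the FFT normalization here is an assumption and may differ from the library's internals):

import torch

sample_rate = 2048.0
responses = torch.randn(8, 2, 4096)   # batch x ifos x time
psd = torch.rand(2, 2049) + 1e-6      # one-sided PSD on rfft bins, per ifo

h_tilde = torch.fft.rfft(responses, dim=-1) / sample_rate  # assumed scaling
integrand = h_tilde.abs() ** 2 / psd                       # |h~(f)|^2 / S_n(f)

df = sample_rate / responses.shape[-1]
rho = 4 * integrand.sum(axis=-1) * df  # factor of 4 and df from the docstring formula
print(rho.shape)                       # torch.Size([8, 2]): batch x num_ifos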
@@ -408,15 +413,17 @@ def compute_network_snr(
 highpass: Union[float, Float[Tensor, " frequency"], None] = None,
 lowpass: Union[float, Float[Tensor, " frequency"], None] = None,
 ) -> BatchTensor:
- r"""
+ """
 Compute the total SNR from a gravitational waveform
 from a network of interferometers. The total SNR for
- the $i$th waveform is computed as
+ the :math:`i` th waveform is computed as
+
+ .. math::

- $$\rho_i = \sqrt{\sum_{j}^{N}\rho_{ij}^2}$$
+ \\rho_i = \\sqrt{\\sum_{j}^{N}\\rho_{ij}^2}

- where \rho_{ij} is the SNR for the $i$th waveform at
- the $j$th interferometer in the network and $N$ is
+ where :math:`\\rho_{ij}` is the SNR for the :math:`i` th waveform at
+ the :math:`j` th interferometer in the network and :math:`N` is
 the total number of interferometers.

 Args:
@@ -426,12 +433,12 @@ def compute_network_snr(
 backgrounds:
 The one-sided power spectral density of the background
 noise at each interferometer to which a response
- in `responses` has been calculated. If 2D, each row of
- `psd` will be assumed to be the background PSD for each
- channel of _every_ batch element in `responses`. If 3D,
+ in ``responses`` has been calculated. If 2D, each row of
+ ``psd`` will be assumed to be the background PSD for each
+ channel of **every** batch element in ``responses``. If 3D,
 this should contain a background PSD for each channel
- of each element in `responses`, and therefore the first
- two dimensions of `psd` and `responses` should match.
+ of each element in ``responses``, and therefore the first
+ two dimensions of ``psd`` and ``responses`` should match.
 sample_rate:
 The frequency at which the waveform responses timeseries
 have been sampled. Upon fourier transforming, should
@@ -441,14 +448,14 @@ def compute_network_snr(
 If a tensor is provided, it will be assumed to be a
 pre-computed mask used to 0-out low frequency components.
 If a float, it will be used to compute such a mask. If
- left as `None`, all frequencies up to `sample_rate / 2`
+ left as ``None``, all frequencies up to ``sample_rate / 2``
 will contribute to the SNR calculation.
 lowpass:
 The maximum frequency below which to compute the SNR.
 If a tensor is provided, it will be assumed to be a
 pre-computed mask used to 0-out high frequency components.
 If a float, it will be used to compute such a mask. If
- left as `None`, all frequencies from `highpass` up to
+ left as ``None``, all frequencies from ``highpass`` up to
 the Nyquist freqyency will contribute to the SNR calculation.
 Returns:
 Batch of SNRs for each waveform across the interferometer network
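The quadrature sum in the formula above is a one-liner once per-interferometer SNRs are in hand; a tiny illustrative sketch:

import torch

ifo_snrs = torch.tensor([[8.0, 6.0], [5.0, 12.0]])  # batch x num_ifos
network_snr = (ifo_snrs ** 2).sum(dim=-1).sqrt()
print(network_snr)  # tensor([10., 13.])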
@@ -478,12 +485,12 @@ def reweight_snrs(
 psd:
 The one-sided power spectral density of the background
 noise at each interferometer to which a response
- in `responses` has been calculated. If 2D, each row of
- `psd` will be assumed to be the background PSD for each
- channel of _every_ batch element in `responses`. If 3D,
+ in ``responses`` has been calculated. If 2D, each row of
+ ``psd`` will be assumed to be the background PSD for each
+ channel of **every** batch element in ``responses``. If 3D,
 this should contain a background PSD for each channel
- of each element in `responses`, and therefore the first
- two dimensions of `psd` and `responses` should match.
+ of each element in ``responses``, and therefore the first
+ two dimensions of ``psd`` and ``responses`` should match.
 sample_rate:
 The frequency at which the waveform responses timeseries
 have been sampled. Upon fourier transforming, should
@@ -493,14 +500,14 @@ def reweight_snrs(
 If a tensor is provided, it will be assumed to be a
 pre-computed mask used to 0-out low frequency components.
 If a float, it will be used to compute such a mask. If
- left as `None`, all frequencies up to `sample_rate / 2`
+ left as ``None``, all frequencies up to ``sample_rate / 2``
 will contribute to the SNR calculation.
 lowpass:
 The maximum frequency below which to compute the SNR.
 If a tensor is provided, it will be assumed to be a
 pre-computed mask used to 0-out high frequency components.
 If a float, it will be used to compute such a mask. If
- left as `None`, all frequencies from `highpass` up to
+ left as ``None``, all frequencies from ``highpass`` up to
 the Nyquist freqyency will contribute to the SNR calculation.
 Returns:
 Rescaled interferometer responses
@@ -12,18 +12,18 @@ class Autoencoder(torch.nn.Module):
 Base autoencoder class that defines some of the
 basic methods and functionality. Autoencoders are
 defined here as a set of sequential blocks that
- have an `encode` method, which acts on the input
- data to the autoencoder, and a `decode` method, which
- acts on the encoded vector generated by the `encode`
- method. `forward` just runs these steps one after the
+ have an ``encode`` method, which acts on the input
+ data to the autoencoder, and a ``decode`` method, which
+ acts on the encoded vector generated by the ``encode``
+ method. ``forward`` just runs these steps one after the
 other. Although it isn't explicitly enforced, a good
- rule of thumb is that the ouput of a block's `decode`
+ rule of thumb is that the ouput of a block's ``decode``
 method should have the same shape as the _input_ of its
- `encode` method.
+ ``encode`` method.

- Accepts a `skip_connection` argument that defines how to
- combine information from the input of one block's `encode`
- layer with the output to its `decode`layer. See `skip_connections.py`
+ Accepts a ``skip_connection`` argument that defines how to
+ combine information from the input of one block's ``encode``
+ layer with the output to its ``decode`` layer. See ``skip_connections.py``
 for more info about what these classes are expected to contain
 and how they operate.
 """
@@ -83,11 +83,11 @@ class ConvolutionalAutoencoder(Autoencoder):
 match the shape of the input to its corresponding
 encoder layer, except for the last decoder which
 can have an arbitrary number of channels specified
- by `decode_channels`.
+ by ``decode_channels``.

- All layers also share the same `activation` except
+ All layers also share the same ``activation`` except
 for the last decoder layer, which can have an
- arbitrary `output_activation`.
+ arbitrary ``output_activation``.
 """

 def __init__(
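To make the encode/decode contract described in the Autoencoder docstring above concrete, here is a hypothetical block that satisfies it; this is an illustration of the interface, not a class from ml4gw:

import torch


class ToyBlock(torch.nn.Module):
    """Hypothetical block obeying the encode/decode contract."""

    def __init__(self, in_features: int, hidden: int):
        super().__init__()
        self.down = torch.nn.Linear(in_features, hidden)
        self.up = torch.nn.Linear(hidden, in_features)

    def encode(self, x):
        # acts on the input data to the block
        return torch.relu(self.down(x))

    def decode(self, z):
        # maps the encoded vector back to the shape of encode's input
        return self.up(z)

    def forward(self, x):
        # forward just runs encode then decode
        return self.decode(self.encode(x))


block = ToyBlock(128, 16)
x = torch.randn(4, 128)
assert block(x).shape == x.shape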
@@ -115,7 +115,7 @@ class ConvolutionalAutoencoder(Autoencoder):
 # All intermediate layers should decode to
 # the same number of channels. The last decoder
 # should decode to whatever number of channels
- # was specified, even if it's `None` (in which
+ # was specified, even if it's ``None`` (in which
 # case it will just be in_channels anyway)
 decode = in_channels if i else decode_channels

@@ -108,10 +108,10 @@ class BasicBlock(nn.Module):
 class Bottleneck(nn.Module):
 """
 Bottleneck blocks implement one extra convolution
- compared to basic blocks. In this layers, the `planes`
- parameter is generally meant to _downsize_ the number
+ compared to basic blocks. In this layers, the ``planes``
+ parameter is generally meant to **downsize** the number
 of feature maps first, which then get expanded out to
- `planes * Bottleneck.expansion` feature maps at the
+ ``planes * Bottleneck.expansion`` feature maps at the
 output of the layer.
 """

@@ -192,9 +192,9 @@ class ResNet1D(nn.Module):
 A list representing the number of residual
 blocks to include in each "layer" of the
 network. Total layers (e.g. 50 in ResNet50)
- is `2 + sum(layers) * factor`, where factor
- is `2` for vanilla `ResNet` and `3` for
- `BottleneckResNet`.
+ is ``2 + sum(layers) * factor``, where factor
+ is ``2`` for vanilla ``ResNet`` and ``3`` for
+ ``BottleneckResNet``.
 kernel_size:
 The size of the convolutional kernel to
 use in all residual layers. _NOT_ the size
@@ -211,19 +211,19 @@ class ResNet1D(nn.Module):
 connections between feature maps at subsequent
 layers rather than global. Generally won't
 need this to be >1, and wil raise an error if
- >1 when using vanilla `ResNet`.
+ >1 when using vanilla ``ResNet``.
 width_per_group:
 Base width of each of the feature map groups,
 which is scaled up by the typical expansion
 factor at each layer of the network. Meaningless
- for vanilla `ResNet`.
+ for vanilla ``ResNet``.
 stride_type:
 Whether to achieve downsampling on the time axis
 by strided or dilated convolutions for each layer.
- If left as `None`, strided convolutions will be
- used at each layer. Otherwise, `stride_type` should
- be one element shorter than `layers` and indicate either
- `stride` or `dilation` for each layer after the first.
+ If left as ``None``, strided convolutions will be
+ used at each layer. Otherwise, ``stride_type`` should
+ be one element shorter than ``layers`` and indicate either
+ ``stride`` or ``dilation`` for each layer after the first.
 """

 block = BasicBlock
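As a quick illustration of the constraints spelled out in that docstring (the values below are arbitrary examples, not defaults):

layers = [3, 4, 6, 3]  # residual blocks per "layer" of the network

# stride_type must be one element shorter than layers: one entry
# ("stride" or "dilation") for each layer after the first
stride_type = ["stride", "dilation", "dilation"]
assert len(stride_type) == len(layers) - 1

# total depth for the vanilla (BasicBlock) variant, where factor = 2
depth = 2 + sum(layers) * 2
print(depth)  # 34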
@@ -316,7 +316,7 @@ class ResNet1D(nn.Module):
 nn.init.kaiming_normal_(
 m.weight, mode="fan_out", nonlinearity="relu"
 )
- elif isinstance(m, (nn.BatchNorm1d, nn.GroupNorm)):
+ elif isinstance(m, nn.BatchNorm1d):
 nn.init.constant_(m.weight, 1)
 nn.init.constant_(m.bias, 0)

@@ -105,10 +105,10 @@ class BasicBlock(nn.Module):
 class Bottleneck(nn.Module):
 """
 Bottleneck blocks implement one extra convolution
- compared to basic blocks. In this layers, the `planes`
+ compared to basic blocks. In this layers, the ``planes``
 parameter is generally meant to _downsize_ the number
 of feature maps first, which then get expanded out to
- `planes * Bottleneck.expansion` feature maps at the
+ ``planes * Bottleneck.expansion`` feature maps at the
 output of the layer.
 """

@@ -188,9 +188,9 @@ class ResNet2D(nn.Module):
 A list representing the number of residual
 blocks to include in each "layer" of the
 network. Total layers (e.g. 50 in ResNet50)
- is `2 + sum(layers) * factor`, where factor
- is `2` for vanilla `ResNet` and `3` for
- `BottleneckResNet`.
+ is ``2 + sum(layers) * factor``, where factor
+ is ``2`` for vanilla ``ResNet`` and ``3`` for
+ ``BottleneckResNet``.
 kernel_size:
 The size of the convolutional kernel to
 use in all residual layers. _NOT_ the size
@@ -207,22 +207,22 @@ class ResNet2D(nn.Module):
 connections between feature maps at subsequent
 layers rather than global. Generally won't
 need this to be >1, and wil raise an error if
- >1 when using vanilla `ResNet`.
+ >1 when using vanilla ``ResNet``.
 width_per_group:
 Base width of each of the feature map groups,
 which is scaled up by the typical expansion
 factor at each layer of the network. Meaningless
- for vanilla `ResNet`.
+ for vanilla ``ResNet``.
 stride_type:
 Whether to achieve downsampling on the time axis
 by strided or dilated convolutions for each layer.
- If left as `None`, strided convolutions will be
- used at each layer. Otherwise, `stride_type` should
- be one element shorter than `layers` and indicate either
- `stride` or `dilation` for each layer after the first.
+ If left as ``None``, strided convolutions will be
+ used at each layer. Otherwise, ``stride_type`` should
+ be one element shorter than ``layers`` and indicate either
+ ``stride`` or ``dilation`` for each layer after the first.
 norm_groups:
 The number of groups to use in GroupNorm layers
- throughout the model. If left as `-1`, the number
+ throughout the model. If left as ``-1``, the number
 of groups will be equal to the number of channels,
 making this equilavent to LayerNorm
 """
@@ -11,7 +11,7 @@ class OnlineAverager(torch.nn.Module):
 """
 Module for performing stateful online averaging of
 batches of overlapping timeseries. At present, the
- first `num_updates` predictions produced by this
+ first ``num_updates`` predictions produced by this
 model will underestimate the true average.

 Args:
@@ -12,24 +12,24 @@ class Snapshotter(torch.nn.Module):
 Model for converting streaming state updates into
 a batch of overlapping snaphots of a multichannel
 timeseries. Can support multiple timeseries in a
- single state update via the `channels_per_snapshot`
+ single state update via the ``channels_per_snapshot``
 kwarg.

 Specifically, maps tensors of shape
- `(num_channels, batch_size * stride_size)` to a tensor
- of shape `(batch_size, num_channels, snapshot_size)`.
- If `channels_per_snapshot` is specified, it will return
- `len(channels_per_snapshot)` tensors of this shape,
+ ``(num_channels, batch_size * stride_size)`` to a tensor
+ of shape ``(batch_size, num_channels, snapshot_size)``.
+ If ``channels_per_snapshot`` is specified, it will return
+ ``len(channels_per_snapshot)`` tensors of this shape,
 with the channel dimension replaced by the corresponding
- value of `channels_per_snapshot`. The last tensor returned
+ value of ``channels_per_snapshot``. The last tensor returned
 at call time will be the current state that can be passed
- to the next `forward` call.
+ to the next ``forward`` call.

 Args:
 num_channels:
 Number of channels in the timeseries. If
- `channels_per_snapshot` is not `None`,
- this should be equal to `sum(channels_per_snapshot)`.
+ ``channels_per_snapshot`` is not ``None``,
+ this should be equal to ``sum(channels_per_snapshot)``.
 snapshot_size:
 The size of the output snapshot windows in
 number of samples
@@ -39,17 +39,17 @@ class Snapshotter(torch.nn.Module):
 batch_size:
 The number of snapshots to produce at each
 update. The last dimension of the input
- tensor should have size `batch_size * stride_size`.
+ tensor should have size ``batch_size * stride_size``.
 channels_per_snapshot:
 How to split up the channels in the timeseries
- for different tensors. If left as `None`, all
+ for different tensors. If left as ``None``, all
 the channels will be returned in a single tensor.
 Otherwise, the channels will be split up into
- `len(channels_per_snapshot)` tensors, with each
+ ``len(channels_per_snapshot)`` tensors, with each
 tensor's channel dimension being equal to the
- corresponding value in `channels_per_snapshot`.
+ corresponding value in ``channels_per_snapshot``.
 Therefore, if specified, these values should
- add up to `num_channels`.
+ add up to ``num_channels``.
 """

 def __init__(
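Finally, a rough illustration of the shape contract the Snapshotter docstring describes: a (num_channels, batch_size * stride_size) update becomes (batch_size, num_channels, snapshot_size) snapshots. The retained state size below is an assumption made so the arithmetic works out, and Tensor.unfold stands in for whatever the module actually does internally:

import torch

num_channels, snapshot_size, stride_size, batch_size = 4, 512, 128, 8

# samples carried over from the previous call (size assumed for illustration)
state = torch.zeros(num_channels, snapshot_size - stride_size)
# new streaming update: num_channels x (batch_size * stride_size)
update = torch.randn(num_channels, batch_size * stride_size)

# append the update, then slide a snapshot_size window with step stride_size
full = torch.cat([state, update], dim=-1)
snapshots = full.unfold(-1, snapshot_size, stride_size)  # channels x batch x snapshot
snapshots = snapshots.transpose(0, 1)                    # batch x channels x snapshot
print(snapshots.shape)  # torch.Size([8, 4, 512])

# the trailing samples become the state for the next call
new_state = full[:, -(snapshot_size - stride_size):]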