ml4gw 0.7.8__py3-none-any.whl → 0.7.10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ml4gw/dataloading/hdf5_dataset.py +1 -0
- ml4gw/distributions.py +20 -8
- ml4gw/spectral.py +12 -4
- ml4gw/transforms/__init__.py +1 -0
- ml4gw/transforms/decimator.py +77 -29
- ml4gw/transforms/whitening.py +28 -4
- {ml4gw-0.7.8.dist-info → ml4gw-0.7.10.dist-info}/METADATA +3 -2
- {ml4gw-0.7.8.dist-info → ml4gw-0.7.10.dist-info}/RECORD +11 -11
- {ml4gw-0.7.8.dist-info → ml4gw-0.7.10.dist-info}/WHEEL +0 -0
- {ml4gw-0.7.8.dist-info → ml4gw-0.7.10.dist-info}/licenses/LICENSE +0 -0
- {ml4gw-0.7.8.dist-info → ml4gw-0.7.10.dist-info}/top_level.txt +0 -0
|
@@ -151,6 +151,7 @@ class Hdf5TimeSeriesDataset(torch.utils.data.IterableDataset):
|
|
|
151
151
|
unique_fnames, inv, counts = np.unique(
|
|
152
152
|
fnames, return_inverse=True, return_counts=True
|
|
153
153
|
)
|
|
154
|
+
inv = inv.reshape(-1)
|
|
154
155
|
for i, (fname, count) in enumerate(
|
|
155
156
|
zip(unique_fnames, counts, strict=True)
|
|
156
157
|
):
|
ml4gw/distributions.py
CHANGED
|
@@ -225,8 +225,10 @@ class UniformComovingVolume(dist.Distribution):
|
|
|
225
225
|
f"or 'luminosity_distance'; got {distance_type}"
|
|
226
226
|
)
|
|
227
227
|
|
|
228
|
-
self.minimum = minimum
|
|
229
|
-
self.maximum = maximum
|
|
228
|
+
self.minimum = torch.as_tensor(minimum)
|
|
229
|
+
self.maximum = torch.as_tensor(maximum)
|
|
230
|
+
if self.minimum.device != self.maximum.device:
|
|
231
|
+
raise RuntimeError("Min and max values are not on same device")
|
|
230
232
|
self.distance_type = distance_type
|
|
231
233
|
self.grid_size = grid_size
|
|
232
234
|
self.z_grid_max = z_grid_max
|
|
@@ -265,7 +267,9 @@ class UniformComovingVolume(dist.Distribution):
|
|
|
265
267
|
distances, using the specified distance type.
|
|
266
268
|
"""
|
|
267
269
|
self._generate_distance_grids()
|
|
268
|
-
bounds = torch.tensor(
|
|
270
|
+
bounds = torch.tensor(
|
|
271
|
+
[self.minimum, self.maximum], device=self.minimum.device
|
|
272
|
+
)
|
|
269
273
|
z_min, z_max = self._linear_interp_1d(
|
|
270
274
|
self.distance_grid, self.z_grid, bounds
|
|
271
275
|
)
|
|
@@ -276,7 +280,9 @@ class UniformComovingVolume(dist.Distribution):
|
|
|
276
280
|
"""
|
|
277
281
|
Generate distance grids based on the specified redshift range.
|
|
278
282
|
"""
|
|
279
|
-
self.z_grid = torch.linspace(
|
|
283
|
+
self.z_grid = torch.linspace(
|
|
284
|
+
0, self.z_grid_max, self.grid_size, device=self.minimum.device
|
|
285
|
+
)
|
|
280
286
|
self.dz = self.z_grid[1] - self.z_grid[0]
|
|
281
287
|
# C is specfied in m/s, h0 in km/s/Mpc, so divide by 1000 to convert
|
|
282
288
|
comoving_dist_grid = (
|
|
@@ -285,7 +291,9 @@ class UniformComovingVolume(dist.Distribution):
|
|
|
285
291
|
)
|
|
286
292
|
/ 1000
|
|
287
293
|
)
|
|
288
|
-
zero_prefix = torch.zeros(
|
|
294
|
+
zero_prefix = torch.zeros(
|
|
295
|
+
1, dtype=comoving_dist_grid.dtype, device=self.minimum.device
|
|
296
|
+
)
|
|
289
297
|
self.comoving_dist_grid = torch.cat([zero_prefix, comoving_dist_grid])
|
|
290
298
|
self.luminosity_dist_grid = self.comoving_dist_grid * (1 + self.z_grid)
|
|
291
299
|
|
|
@@ -315,7 +323,9 @@ class UniformComovingVolume(dist.Distribution):
|
|
|
315
323
|
p_of_distance, self.distance_grid
|
|
316
324
|
)
|
|
317
325
|
cdf = torch.cumulative_trapezoid(self.pdf, self.distance_grid)
|
|
318
|
-
zero_prefix = torch.zeros(
|
|
326
|
+
zero_prefix = torch.zeros(
|
|
327
|
+
1, dtype=cdf.dtype, device=self.minimum.device
|
|
328
|
+
)
|
|
319
329
|
self.cdf = torch.cat([zero_prefix, cdf])
|
|
320
330
|
self.log_pdf = torch.log(self.pdf)
|
|
321
331
|
|
|
@@ -333,7 +343,7 @@ class UniformComovingVolume(dist.Distribution):
|
|
|
333
343
|
|
|
334
344
|
def rsample(self, sample_shape: torch.Size = None) -> Tensor:
|
|
335
345
|
sample_shape = sample_shape or torch.Size()
|
|
336
|
-
u = torch.rand(sample_shape)
|
|
346
|
+
u = torch.rand(sample_shape, device=self.minimum.device)
|
|
337
347
|
return self._linear_interp_1d(self.cdf, self.distance_grid, u)
|
|
338
348
|
|
|
339
349
|
def log_prob(self, value: Tensor) -> Tensor:
|
|
@@ -341,7 +351,9 @@ class UniformComovingVolume(dist.Distribution):
|
|
|
341
351
|
self.distance_grid, self.log_pdf, value
|
|
342
352
|
)
|
|
343
353
|
inside_range = (value >= self.minimum) & (value <= self.maximum)
|
|
344
|
-
log_prob[~inside_range] =
|
|
354
|
+
log_prob[~inside_range] = torch.as_tensor(
|
|
355
|
+
float("-inf"), device=self.minimum.device
|
|
356
|
+
)
|
|
345
357
|
return log_prob
|
|
346
358
|
|
|
347
359
|
|
ml4gw/spectral.py
CHANGED
|
@@ -436,6 +436,7 @@ def normalize_by_psd(
|
|
|
436
436
|
psd: PSDTensor,
|
|
437
437
|
sample_rate: float,
|
|
438
438
|
pad: int,
|
|
439
|
+
crop: bool = True,
|
|
439
440
|
):
|
|
440
441
|
# compute the FFT of the section we want to whiten
|
|
441
442
|
# and divide it by the ASD of the background section.
|
|
@@ -452,7 +453,8 @@ def normalize_by_psd(
|
|
|
452
453
|
X = X.float() / sample_rate**0.5
|
|
453
454
|
|
|
454
455
|
# slice off corrupted data at edges of kernel
|
|
455
|
-
|
|
456
|
+
if crop:
|
|
457
|
+
X = X[:, :, pad:-pad]
|
|
456
458
|
return X
|
|
457
459
|
|
|
458
460
|
|
|
@@ -463,6 +465,7 @@ def whiten(
|
|
|
463
465
|
sample_rate: float,
|
|
464
466
|
highpass: float | None = None,
|
|
465
467
|
lowpass: float | None = None,
|
|
468
|
+
crop: bool = True,
|
|
466
469
|
) -> WaveformTensor:
|
|
467
470
|
"""
|
|
468
471
|
Whiten a batch of timeseries using the specified
|
|
@@ -506,9 +509,14 @@ def whiten(
|
|
|
506
509
|
the data, setting the frequency response in the
|
|
507
510
|
whitening filter to 0. If left as ``None``, no
|
|
508
511
|
lowpass filtering will be applied.
|
|
512
|
+
crop:
|
|
513
|
+
If ``True``, crop ``fduration / 2`` seconds of data
|
|
514
|
+
from both sides of the time dimension to remove the
|
|
515
|
+
corruption from the filter. If ``False``, return the
|
|
516
|
+
full timeseries.
|
|
509
517
|
Returns:
|
|
510
|
-
Batch of whitened multichannel timeseries with
|
|
511
|
-
|
|
518
|
+
Batch of whitened multichannel timeseries with ``fduration / 2``
|
|
519
|
+
seconds optionally trimmed from each side.
|
|
512
520
|
"""
|
|
513
521
|
|
|
514
522
|
# figure out how much data we'll need to slice
|
|
@@ -549,4 +557,4 @@ def whiten(
|
|
|
549
557
|
lowpass,
|
|
550
558
|
)
|
|
551
559
|
|
|
552
|
-
return normalize_by_psd(X, psd, sample_rate, pad)
|
|
560
|
+
return normalize_by_psd(X, psd, sample_rate, pad, crop)
|
ml4gw/transforms/__init__.py
CHANGED
ml4gw/transforms/decimator.py
CHANGED
|
@@ -21,17 +21,25 @@ class Decimator(torch.nn.Module):
|
|
|
21
21
|
schedule (torch.Tensor):
|
|
22
22
|
Tensor of shape `(N, 3)` defining start time, end time,
|
|
23
23
|
and target sample rate for each segment.
|
|
24
|
+
split (bool, optional):
|
|
25
|
+
- If True, the module returns a list of decimated segments
|
|
26
|
+
(one per schedule entry). Overlapping schedule segments are
|
|
27
|
+
only allowed when ``split=True``.
|
|
28
|
+
- If False (default), the segments are concatenated into a
|
|
29
|
+
single continuous output tensor.
|
|
24
30
|
|
|
25
31
|
Shape:
|
|
26
32
|
- Input: `(B, C, T)` where
|
|
27
33
|
- B = batch size
|
|
28
34
|
- C = channels
|
|
29
35
|
- T = number of timesteps
|
|
30
|
-
(must equal schedule duration
|
|
36
|
+
(must equal schedule duration x sample_rate)
|
|
31
37
|
- Output:
|
|
32
|
-
- If ``split=False`` → `(B, C, T')` where `T'` is total
|
|
38
|
+
- If ``split=False`` → `(B, C, T')` where `T'` is the total
|
|
33
39
|
number of decimated samples across all segments.
|
|
34
|
-
- If ``split=True`` → list of tensors,
|
|
40
|
+
- If ``split=True`` → list of tensors, each with shape
|
|
41
|
+
:math:`(B, C, T_i)`, corresponding to the decimated samples
|
|
42
|
+
in each schedule segment.
|
|
35
43
|
|
|
36
44
|
Returns:
|
|
37
45
|
torch.Tensor or List[torch.Tensor]:
|
|
@@ -45,16 +53,16 @@ class Decimator(torch.nn.Module):
|
|
|
45
53
|
>>> from ml4gw.transforms.decimator import Decimator
|
|
46
54
|
|
|
47
55
|
>>> sample_rate = 2048
|
|
48
|
-
>>> X_duration = 60
|
|
56
|
+
>>> X_duration = 60 # seconds
|
|
57
|
+
>>> X = torch.randn(1, 1, sample_rate * X_duration)
|
|
49
58
|
|
|
50
59
|
>>> schedule = torch.tensor(
|
|
51
60
|
... [[0, 40, 256], [40, 58, 512], [58, 60, 2048]],
|
|
52
61
|
... dtype=torch.int,
|
|
53
62
|
... )
|
|
54
|
-
|
|
55
63
|
>>> decimator = Decimator(sample_rate=sample_rate,
|
|
56
|
-
...
|
|
57
|
-
|
|
64
|
+
... schedule=schedule)
|
|
65
|
+
|
|
58
66
|
>>> X_dec = decimator(X)
|
|
59
67
|
>>> X_seg = decimator(X, split=True)
|
|
60
68
|
|
|
@@ -67,16 +75,34 @@ class Decimator(torch.nn.Module):
|
|
|
67
75
|
Segment 0 shape: torch.Size([1, 1, 10240])
|
|
68
76
|
Segment 1 shape: torch.Size([1, 1, 9216])
|
|
69
77
|
Segment 2 shape: torch.Size([1, 1, 4096])
|
|
78
|
+
|
|
79
|
+
>>> overlap_schedule = torch.tensor(
|
|
80
|
+
... [[0, 40, 256], [32, 58, 512]], [52, 60, 2048]],
|
|
81
|
+
... dtype=torch.int,
|
|
82
|
+
... )
|
|
83
|
+
>>> decimator_ov = Decimator(
|
|
84
|
+
... sample_rate=sample_rate,
|
|
85
|
+
... schedule=overlap_schedule,
|
|
86
|
+
... split=True,
|
|
87
|
+
... )
|
|
88
|
+
>>> X_overlap = decimator_ov(X)
|
|
89
|
+
>>> for i, seg in enumerate(X_overlap):
|
|
90
|
+
... print(f"Overlapping segment {i} shape:", seg.shape)
|
|
91
|
+
Overlapping segment 0 shape: torch.Size([1, 1, 10240])
|
|
92
|
+
Overlapping segment 1 shape: torch.Size([1, 1, 13312])
|
|
93
|
+
Overlapping segment 2 shape: torch.Size([1, 1, 16384])
|
|
70
94
|
"""
|
|
71
95
|
|
|
72
96
|
def __init__(
|
|
73
97
|
self,
|
|
74
98
|
sample_rate: int = None,
|
|
75
99
|
schedule: torch.Tensor = None,
|
|
100
|
+
split: bool = False,
|
|
76
101
|
) -> None:
|
|
77
102
|
super().__init__()
|
|
78
103
|
self.sample_rate = sample_rate
|
|
79
|
-
self.schedule
|
|
104
|
+
self.register_buffer("schedule", schedule)
|
|
105
|
+
self.split = split
|
|
80
106
|
|
|
81
107
|
self._validate_inputs()
|
|
82
108
|
idx = self.build_variable_indices()
|
|
@@ -89,7 +115,8 @@ class Decimator(torch.nn.Module):
|
|
|
89
115
|
|
|
90
116
|
def _validate_inputs(self) -> None:
|
|
91
117
|
r"""
|
|
92
|
-
Validate the schedule and sample_rate.
|
|
118
|
+
Validate the schedule and sample_rate. This method also checks
|
|
119
|
+
schedule segments do **not overlap** unless ``split=True``.
|
|
93
120
|
"""
|
|
94
121
|
if self.schedule.ndim != 2 or self.schedule.shape[1] != 3:
|
|
95
122
|
raise ValueError(
|
|
@@ -107,6 +134,15 @@ class Decimator(torch.nn.Module):
|
|
|
107
134
|
f"target rates {self.schedule[:, 2].tolist()}"
|
|
108
135
|
)
|
|
109
136
|
|
|
137
|
+
if not self.split:
|
|
138
|
+
starts = self.schedule[:, 0]
|
|
139
|
+
ends = self.schedule[:, 1]
|
|
140
|
+
if torch.any(starts[1:] < ends[:-1]):
|
|
141
|
+
raise ValueError(
|
|
142
|
+
"Schedule segments overlap — overlapping schedules "
|
|
143
|
+
"are only supported when split=True."
|
|
144
|
+
)
|
|
145
|
+
|
|
110
146
|
def build_variable_indices(self) -> torch.Tensor:
|
|
111
147
|
r"""
|
|
112
148
|
Compute the time indices to keep based on the schedule.
|
|
@@ -130,11 +166,15 @@ class Decimator(torch.nn.Module):
|
|
|
130
166
|
|
|
131
167
|
def split_by_schedule(self, X: torch.Tensor) -> tuple[torch.Tensor, ...]:
|
|
132
168
|
r"""
|
|
133
|
-
Split a
|
|
169
|
+
Split and decimate a timeseries into segments according to the
|
|
170
|
+
schedule.
|
|
171
|
+
|
|
172
|
+
This method applies the decimation defined by each schedule row
|
|
173
|
+
and returns a list of the resulting segments.
|
|
134
174
|
|
|
135
175
|
Args:
|
|
136
176
|
X (torch.Tensor):
|
|
137
|
-
|
|
177
|
+
Input timeseries of shape `(B, C, T)` before decimation.
|
|
138
178
|
|
|
139
179
|
Returns:
|
|
140
180
|
tuple of torch.Tensor:
|
|
@@ -142,33 +182,42 @@ class Decimator(torch.nn.Module):
|
|
|
142
182
|
where :math:`T_i` is the length implied by
|
|
143
183
|
the corresponding schedule row.
|
|
144
184
|
"""
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
.
|
|
149
|
-
|
|
185
|
+
segments = []
|
|
186
|
+
|
|
187
|
+
for s in self.schedule:
|
|
188
|
+
start = int(s[0] * self.sample_rate)
|
|
189
|
+
stop = int(s[1] * self.sample_rate)
|
|
190
|
+
step = int(self.sample_rate // s[2])
|
|
191
|
+
idx = torch.arange(
|
|
192
|
+
start,
|
|
193
|
+
stop,
|
|
194
|
+
step,
|
|
195
|
+
dtype=torch.long,
|
|
196
|
+
device=self.schedule.device,
|
|
197
|
+
)
|
|
198
|
+
seg = X.index_select(dim=-1, index=idx)
|
|
199
|
+
segments.append(seg)
|
|
150
200
|
|
|
151
|
-
return
|
|
201
|
+
return segments
|
|
152
202
|
|
|
153
203
|
def forward(
|
|
154
204
|
self,
|
|
155
205
|
X: torch.Tensor,
|
|
156
|
-
split: bool = False,
|
|
157
206
|
) -> torch.Tensor | list[torch.Tensor]:
|
|
158
207
|
r"""
|
|
159
|
-
Apply decimation to the input timeseries.
|
|
208
|
+
Apply decimation to the input timeseries according to the schedule.
|
|
160
209
|
|
|
161
210
|
Args:
|
|
162
211
|
X (torch.Tensor):
|
|
163
212
|
Input tensor of shape `(B, C, T)`, where `T` must equal
|
|
164
|
-
schedule duration
|
|
165
|
-
split (bool, optional):
|
|
166
|
-
If True, return a list of segments instead of a single
|
|
167
|
-
concatenated tensor. Default: False.
|
|
213
|
+
schedule duration x sample_rate.
|
|
168
214
|
|
|
169
215
|
Returns:
|
|
170
|
-
torch.Tensor or
|
|
171
|
-
|
|
216
|
+
torch.Tensor or list[torch.Tensor]:
|
|
217
|
+
- If ``split=False`` (default), returns a single decimated
|
|
218
|
+
tensor of shape `(B, C, T')`.
|
|
219
|
+
- If ``split=True``, returns a list of decimated segments,
|
|
220
|
+
one per schedule entry.
|
|
172
221
|
"""
|
|
173
222
|
if X.shape[-1] != self.expected_len:
|
|
174
223
|
raise ValueError(
|
|
@@ -176,8 +225,7 @@ class Decimator(torch.nn.Module):
|
|
|
176
225
|
f"expected schedule duration {self.expected_len}"
|
|
177
226
|
)
|
|
178
227
|
|
|
179
|
-
|
|
228
|
+
if self.split:
|
|
229
|
+
return self.split_by_schedule(X)
|
|
180
230
|
|
|
181
|
-
|
|
182
|
-
X_dec = self.split_by_schedule(X_dec)
|
|
183
|
-
return X_dec
|
|
231
|
+
return X.index_select(dim=-1, index=self.idx)
|
ml4gw/transforms/whitening.py
CHANGED
|
@@ -69,7 +69,10 @@ class Whiten(torch.nn.Module):
|
|
|
69
69
|
self.register_buffer("window", window)
|
|
70
70
|
|
|
71
71
|
def forward(
|
|
72
|
-
self,
|
|
72
|
+
self,
|
|
73
|
+
X: TimeSeries3d,
|
|
74
|
+
psd: FrequencySeries1to3d,
|
|
75
|
+
crop: bool = True,
|
|
73
76
|
) -> TimeSeries3d:
|
|
74
77
|
"""
|
|
75
78
|
Whiten a batch of multichannel timeseries by a
|
|
@@ -96,6 +99,11 @@ class Whiten(torch.nn.Module):
|
|
|
96
99
|
For more information about what these different
|
|
97
100
|
shapes for ``psd`` represent, consult the documentation
|
|
98
101
|
for :meth:`~ml4gw.spectral.whiten`.
|
|
102
|
+
crop:
|
|
103
|
+
If ``True``, crop ``fduration / 2`` seconds of data
|
|
104
|
+
from both sides of the time dimension to remove the
|
|
105
|
+
corruption from the filter. If ``False``, return the
|
|
106
|
+
full timeseries.
|
|
99
107
|
Returns:
|
|
100
108
|
Whitened timeseries, with ``fduration * sample_rate / 2``
|
|
101
109
|
samples cropped from each edge. Output shape will then
|
|
@@ -109,6 +117,7 @@ class Whiten(torch.nn.Module):
|
|
|
109
117
|
sample_rate=self.sample_rate,
|
|
110
118
|
highpass=self.highpass,
|
|
111
119
|
lowpass=self.lowpass,
|
|
120
|
+
crop=crop,
|
|
112
121
|
)
|
|
113
122
|
|
|
114
123
|
|
|
@@ -127,7 +136,7 @@ class FixedWhiten(FittableSpectralTransform):
|
|
|
127
136
|
frequency bins in the fit PSD.
|
|
128
137
|
sample_rate:
|
|
129
138
|
Rate at which timeseries will be sampled, in Hz
|
|
130
|
-
|
|
139
|
+
crop:
|
|
131
140
|
Datatype with which background PSD will be stored
|
|
132
141
|
"""
|
|
133
142
|
|
|
@@ -243,11 +252,24 @@ class FixedWhiten(FittableSpectralTransform):
|
|
|
243
252
|
fduration = torch.Tensor([fduration])
|
|
244
253
|
self.build(psd=psd, fduration=fduration)
|
|
245
254
|
|
|
246
|
-
def forward(self, X: TimeSeries3d) -> TimeSeries3d:
|
|
255
|
+
def forward(self, X: TimeSeries3d, crop: bool = True) -> TimeSeries3d:
|
|
247
256
|
"""
|
|
248
257
|
Whiten the input timeseries tensor using the
|
|
249
258
|
PSD fit by the ``.fit`` method, which must be
|
|
250
259
|
called **before** the first call to ``.forward``.
|
|
260
|
+
|
|
261
|
+
Args:
|
|
262
|
+
X:
|
|
263
|
+
Batch of multichannel timeseries to whiten.
|
|
264
|
+
Should have the shape ``(B, C, N)``, where
|
|
265
|
+
``B`` is the batch size, ``C`` is the number of
|
|
266
|
+
channels, and ``N`` is the number of seconds
|
|
267
|
+
in the timeseries times ``self.sample_rate``.
|
|
268
|
+
crop:
|
|
269
|
+
If ``True``, crop ``fduration / 2`` seconds of data
|
|
270
|
+
from both sides of the time dimension to remove the
|
|
271
|
+
corruption from the filter. If ``False``, return the
|
|
272
|
+
full timeseries.
|
|
251
273
|
"""
|
|
252
274
|
expected_dim = int(self.kernel_length * self.sample_rate)
|
|
253
275
|
if X.size(-1) != expected_dim:
|
|
@@ -258,4 +280,6 @@ class FixedWhiten(FittableSpectralTransform):
|
|
|
258
280
|
)
|
|
259
281
|
|
|
260
282
|
pad = int(self.fduration.item() * self.sample_rate / 2)
|
|
261
|
-
return spectral.normalize_by_psd(
|
|
283
|
+
return spectral.normalize_by_psd(
|
|
284
|
+
X, self.psd, self.sample_rate, pad, crop
|
|
285
|
+
)
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: ml4gw
|
|
3
|
-
Version: 0.7.
|
|
3
|
+
Version: 0.7.10
|
|
4
4
|
Summary: Tools for training torch models on gravitational wave data
|
|
5
5
|
Author-email: Ethan Marx <emarx@mit.edu>, Will Benoit <benoi090@umn.edu>, Deep Chatterjee <deep1018@mit.edu>, Alec Gunny <alec.gunny@ligo.org>, Ravi Kumar <ravi.kumar@ligo.org>
|
|
6
6
|
Maintainer-email: Ethan Marx <emarx@mit.edu>, Will Benoit <benoi090@umn.edu>, Deep Chatterjee <deep1018@mit.edu>
|
|
@@ -16,7 +16,7 @@ License-File: LICENSE
|
|
|
16
16
|
Requires-Dist: jaxtyping<0.3,>=0.2
|
|
17
17
|
Requires-Dist: torch~=2.0
|
|
18
18
|
Requires-Dist: torchaudio~=2.0
|
|
19
|
-
Requires-Dist: numpy
|
|
19
|
+
Requires-Dist: numpy>=1.0.0
|
|
20
20
|
Requires-Dist: scipy<1.15,>=1.9.0
|
|
21
21
|
Dynamic: license-file
|
|
22
22
|
|
|
@@ -26,6 +26,7 @@ Dynamic: license-file
|
|
|
26
26
|

|
|
27
27
|

|
|
28
28
|

|
|
29
|
+
[](https://doi.org/10.21105/joss.08836)
|
|
29
30
|
|
|
30
31
|
Torch utilities for training neural networks in gravitational wave physics applications.
|
|
31
32
|
|
|
@@ -1,13 +1,13 @@
|
|
|
1
1
|
ml4gw/__init__.py,sha256=81quoggCuIypZjZs3bbf1Ty70KHdva5RGEJxi0oC57E,25
|
|
2
2
|
ml4gw/augmentations.py,sha256=Jck8FVtKM18evUESIRrJC-WFk8amhTdqD4kM74HjlGI,1382
|
|
3
3
|
ml4gw/constants.py,sha256=RQPXwavlw_cWu3ByltvTejPsi6EWXHDJQ1HaV9iE3Lg,850
|
|
4
|
-
ml4gw/distributions.py,sha256=
|
|
4
|
+
ml4gw/distributions.py,sha256=gmTxbqi3zSM42Oizm_svPMrDK3GX8zYLhoNl3B7mUro,13288
|
|
5
5
|
ml4gw/gw.py,sha256=L4WgicskE_Nn6UQMDG1C0TNquXC7D3A84yd_I54e9y0,20181
|
|
6
|
-
ml4gw/spectral.py,sha256=
|
|
6
|
+
ml4gw/spectral.py,sha256=rDfptr2WQ8dE11EUjXVmYb8-T9ROwizMBYtryI_03ag,20213
|
|
7
7
|
ml4gw/types.py,sha256=weerxUI-PEWWvzGtzwY_A9AH_m_I1p-rGWvKbiZle-0,919
|
|
8
8
|
ml4gw/dataloading/__init__.py,sha256=Qrfq3yPGMy8-1gAMyo81MuKGK5gwp7y6Gx8gd4suEjc,240
|
|
9
9
|
ml4gw/dataloading/chunked_dataset.py,sha256=CZ40OTjuRx_v0nzLrApybIj40zCkdhP5xsGBnqGArB8,5255
|
|
10
|
-
ml4gw/dataloading/hdf5_dataset.py,sha256=
|
|
10
|
+
ml4gw/dataloading/hdf5_dataset.py,sha256=Xf_AKBCmwsuPxqRTaJQqJbt4w7BEpTPeWhGKm2YKGbE,7964
|
|
11
11
|
ml4gw/dataloading/in_memory_dataset.py,sha256=wbdWID76IqhD01u-XlD-QPRnWPSVmkSQTyHR2lyy_NA,9559
|
|
12
12
|
ml4gw/nn/__init__.py,sha256=Vn8CqewYAK6GD-gOvsR7TJK6I568soXLpp9Ga9sYqkY,188
|
|
13
13
|
ml4gw/nn/norm.py,sha256=zd8NcjrtqM4yFyHFmDkknuV623NA5Cj0o6jBdPv6xh0,3584
|
|
@@ -22,8 +22,8 @@ ml4gw/nn/resnet/resnet_2d.py,sha256=YiHxP3cNIfjOrEmKSVqZYOUxoVnIkpDdwCW9VieNM7E,
|
|
|
22
22
|
ml4gw/nn/streaming/__init__.py,sha256=zgjGR2L8t0txXLnil9ceZT0tM8Y2FC8yPxqIKYH0o1A,80
|
|
23
23
|
ml4gw/nn/streaming/online_average.py,sha256=22jQ_JbJTpusmqeGdo7Ta7lTsGoTBjYtKZnXzucW3wc,4676
|
|
24
24
|
ml4gw/nn/streaming/snapshotter.py,sha256=kH73np-LUGF0ZP-tkWY19TrCJa3m1RIvvZ-SmLA7YvM,4378
|
|
25
|
-
ml4gw/transforms/__init__.py,sha256=
|
|
26
|
-
ml4gw/transforms/decimator.py,sha256=
|
|
25
|
+
ml4gw/transforms/__init__.py,sha256=Q5o8L7LKLc2AYItI_-fapNQVU6gJE6a5BOaAPwQbBVo,665
|
|
26
|
+
ml4gw/transforms/decimator.py,sha256=_7AtlM9UqwpJbE1G9gyuV61l044gw7lrVfOfeGBLie8,8279
|
|
27
27
|
ml4gw/transforms/iirfilter.py,sha256=T4qgrJeA3vPeVWyZ-bPBOxQkJL0yfaUVoU0MTtpwhvg,3152
|
|
28
28
|
ml4gw/transforms/pearson.py,sha256=buPEfjaPJkMtdnBP5tvFzRzId8DbThAfJcGzaYmANqc,3202
|
|
29
29
|
ml4gw/transforms/qtransform.py,sha256=a7hqZ4pq9J6pq8L3Dm4Dqyxz-QyaznoNZO4mTdb5apY,20616
|
|
@@ -34,7 +34,7 @@ ml4gw/transforms/spectrogram.py,sha256=TGG_fVng-Y569KsIQAaf5_WN-W4-6F89oQSyHFxVo
|
|
|
34
34
|
ml4gw/transforms/spline_interpolation.py,sha256=QHBp5g_1_lOYmCFJqQyZAbPXBZaqzb5_LBWeLY6ppkI,18614
|
|
35
35
|
ml4gw/transforms/transform.py,sha256=_jAxsCnLmIo97g4b2J8WKS0Omy-yyOzNt6lFHEM5ESM,2463
|
|
36
36
|
ml4gw/transforms/waveforms.py,sha256=yFOzGlYjjM488oYxZLpikS9noZCnisVM0_gjgWoF-_E,3018
|
|
37
|
-
ml4gw/transforms/whitening.py,sha256=
|
|
37
|
+
ml4gw/transforms/whitening.py,sha256=s3nE5Z45jYCJIZSaAYCXn_j6Kic5AXF2WwS7USu2bvM,11226
|
|
38
38
|
ml4gw/utils/interferometer.py,sha256=lRGtMRFSco1mI1Y1O2kz4dRp5hmK5cp4nG7sYbAiYG4,2179
|
|
39
39
|
ml4gw/utils/slicing.py,sha256=kaO54GMUV8d7vIt3oodVZJ4jFwJMEK8aGknJ8UOJzRs,13667
|
|
40
40
|
ml4gw/waveforms/__init__.py,sha256=SxTc6rSkQfoOtEgNYvA-8tMJsQQQROTRRKaFDRQOmh4,172
|
|
@@ -50,8 +50,8 @@ ml4gw/waveforms/cbc/phenom_d_data.py,sha256=WA1FBxUp9fo1IQaV_OLJ_5g5gI166mY1FtG9
|
|
|
50
50
|
ml4gw/waveforms/cbc/phenom_p.py,sha256=m81Xt_zIffHiGlWgzf-AmI46mSn6CFZAX-6Fwwr5Tfk,27635
|
|
51
51
|
ml4gw/waveforms/cbc/taylorf2.py,sha256=emWbl3vjsCzBOooHOVO7pPlPcj05r4up6InlMkO5m_E,10422
|
|
52
52
|
ml4gw/waveforms/cbc/utils.py,sha256=LT1ky10_6ZrbwTcxIrWP1O75GUEuU5q2ZE2yYDhadQE,3037
|
|
53
|
-
ml4gw-0.7.
|
|
54
|
-
ml4gw-0.7.
|
|
55
|
-
ml4gw-0.7.
|
|
56
|
-
ml4gw-0.7.
|
|
57
|
-
ml4gw-0.7.
|
|
53
|
+
ml4gw-0.7.10.dist-info/licenses/LICENSE,sha256=OXLcl0T2SZ8Pmy2_dmlvKuetivmyPd5m1q-Gyd-zaYY,35149
|
|
54
|
+
ml4gw-0.7.10.dist-info/METADATA,sha256=wIeraROEpLgAXivws9yOmH2BMOXCGPULCbZzXlCU89k,4392
|
|
55
|
+
ml4gw-0.7.10.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
56
|
+
ml4gw-0.7.10.dist-info/top_level.txt,sha256=JnWLyPXJ3_WUcjr6fRV0ZTXj8FR0x4vBzjkg-1bl2tw,6
|
|
57
|
+
ml4gw-0.7.10.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|