ecgen 0.4.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ecgen-0.4.0/LICENSE +21 -0
- ecgen-0.4.0/PKG-INFO +37 -0
- ecgen-0.4.0/pyproject.toml +28 -0
- ecgen-0.4.0/setup.cfg +4 -0
- ecgen-0.4.0/src/ecgen/__init__.py +7 -0
- ecgen-0.4.0/src/ecgen/data/datamodule.py +32 -0
- ecgen-0.4.0/src/ecgen/data/mimic_dataset.py +214 -0
- ecgen-0.4.0/src/ecgen/data/pulse2pulse_mimic.py +106 -0
- ecgen-0.4.0/src/ecgen/data/transforms.py +0 -0
- ecgen-0.4.0/src/ecgen/models/__init__.py +33 -0
- ecgen-0.4.0/src/ecgen/models/pulse2pulse.py +411 -0
- ecgen-0.4.0/src/ecgen/models/vae.py +343 -0
- ecgen-0.4.0/src/ecgen/training/callbacks.py +167 -0
- ecgen-0.4.0/src/ecgen/training/losses.py +79 -0
- ecgen-0.4.0/src/ecgen/training/metrics.py +64 -0
- ecgen-0.4.0/src/ecgen/training/test.py +27 -0
- ecgen-0.4.0/src/ecgen/training/train.py +183 -0
- ecgen-0.4.0/src/ecgen/training/validate.py +27 -0
- ecgen-0.4.0/src/ecgen/utils/io.py +36 -0
- ecgen-0.4.0/src/ecgen/utils/logging.py +16 -0
- ecgen-0.4.0/src/ecgen/utils/metadata.py +104 -0
- ecgen-0.4.0/src/ecgen/utils/seed.py +27 -0
- ecgen-0.4.0/src/ecgen.egg-info/PKG-INFO +37 -0
- ecgen-0.4.0/src/ecgen.egg-info/SOURCES.txt +28 -0
- ecgen-0.4.0/src/ecgen.egg-info/dependency_links.txt +1 -0
- ecgen-0.4.0/src/ecgen.egg-info/requires.txt +6 -0
- ecgen-0.4.0/src/ecgen.egg-info/top_level.txt +1 -0
- ecgen-0.4.0/tests/test_data.py +0 -0
- ecgen-0.4.0/tests/test_models.py +0 -0
- ecgen-0.4.0/tests/test_vae.py +76 -0
ecgen-0.4.0/LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
MIT License
|
|
2
|
+
|
|
3
|
+
Copyright (c) 2026 Vajira Thambawita
|
|
4
|
+
|
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
7
|
+
in the Software without restriction, including without limitation the rights
|
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
10
|
+
furnished to do so, subject to the following conditions:
|
|
11
|
+
|
|
12
|
+
The above copyright notice and this permission notice shall be included in all
|
|
13
|
+
copies or substantial portions of the Software.
|
|
14
|
+
|
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
21
|
+
SOFTWARE.
|
ecgen-0.4.0/PKG-INFO
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: ecgen
|
|
3
|
+
Version: 0.4.0
|
|
4
|
+
Summary: ECG generation and modeling experiments
|
|
5
|
+
License: MIT License
|
|
6
|
+
|
|
7
|
+
Copyright (c) 2026 Vajira Thambawita
|
|
8
|
+
|
|
9
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
10
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
11
|
+
in the Software without restriction, including without limitation the rights
|
|
12
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
13
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
14
|
+
furnished to do so, subject to the following conditions:
|
|
15
|
+
|
|
16
|
+
The above copyright notice and this permission notice shall be included in all
|
|
17
|
+
copies or substantial portions of the Software.
|
|
18
|
+
|
|
19
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
20
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
21
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
22
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
23
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
24
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
25
|
+
SOFTWARE.
|
|
26
|
+
|
|
27
|
+
Project-URL: Homepage, https://github.com/vlbthambawita/ECGEN
|
|
28
|
+
Requires-Python: >=3.8
|
|
29
|
+
Description-Content-Type: text/markdown
|
|
30
|
+
License-File: LICENSE
|
|
31
|
+
Requires-Dist: torch
|
|
32
|
+
Requires-Dist: pyyaml
|
|
33
|
+
Requires-Dist: pytorch-lightning
|
|
34
|
+
Requires-Dist: pandas
|
|
35
|
+
Requires-Dist: scikit-learn
|
|
36
|
+
Requires-Dist: wfdb
|
|
37
|
+
Dynamic: license-file
|
|
@@ -0,0 +1,28 @@
|
|
|
1
|
+
[build-system]
|
|
2
|
+
requires = ["setuptools>=61", "wheel"]
|
|
3
|
+
build-backend = "setuptools.build_meta"
|
|
4
|
+
|
|
5
|
+
[project]
|
|
6
|
+
name = "ecgen"
|
|
7
|
+
dynamic = ["version"]
|
|
8
|
+
description = "ECG generation and modeling experiments"
|
|
9
|
+
readme = "README.md"
|
|
10
|
+
license = { file = "LICENSE" }
|
|
11
|
+
requires-python = ">=3.8"
|
|
12
|
+
dependencies = [
|
|
13
|
+
"torch",
|
|
14
|
+
"pyyaml",
|
|
15
|
+
"pytorch-lightning",
|
|
16
|
+
"pandas",
|
|
17
|
+
"scikit-learn",
|
|
18
|
+
"wfdb",
|
|
19
|
+
]
|
|
20
|
+
|
|
21
|
+
[project.urls]
|
|
22
|
+
Homepage = "https://github.com/vlbthambawita/ECGEN"
|
|
23
|
+
|
|
24
|
+
[tool.setuptools.packages.find]
|
|
25
|
+
where = ["src"]
|
|
26
|
+
|
|
27
|
+
[tool.setuptools.dynamic]
|
|
28
|
+
version = { attr = "ecgen.__version__" }
|
ecgen-0.4.0/setup.cfg
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
1
|
+
from dataclasses import dataclass
|
|
2
|
+
from typing import Optional
|
|
3
|
+
|
|
4
|
+
from torch.utils.data import DataLoader, Dataset
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
@dataclass
class ECGDataModule:
    """
    Minimal placeholder datamodule describing train/val/test datasets.

    Stores up to three datasets together with shared loader settings
    (``batch_size``, ``num_workers``) and builds a ``DataLoader`` for each
    on request. Only the training loader shuffles; requesting a loader for a
    dataset that was never assigned raises ``RuntimeError``.
    """

    train_dataset: Optional[Dataset] = None
    val_dataset: Optional[Dataset] = None
    test_dataset: Optional[Dataset] = None
    batch_size: int = 32
    num_workers: int = 4

    def _build_loader(self, dataset: Optional[Dataset], missing_msg: str, shuffle: bool = False) -> DataLoader:
        """Wrap *dataset* in a DataLoader, or raise RuntimeError(missing_msg) if it is None."""
        if dataset is None:
            raise RuntimeError(missing_msg)
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=shuffle,
        )

    def train_dataloader(self) -> DataLoader:
        """Shuffled loader over ``train_dataset``."""
        return self._build_loader(self.train_dataset, "train_dataset is not set.", shuffle=True)

    def val_dataloader(self) -> DataLoader:
        """Sequential (unshuffled) loader over ``val_dataset``."""
        return self._build_loader(self.val_dataset, "val_dataset is not set.")

    def test_dataloader(self) -> DataLoader:
        """Sequential (unshuffled) loader over ``test_dataset``."""
        return self._build_loader(self.test_dataset, "test_dataset is not set.")
|
|
@@ -0,0 +1,214 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Native MIMIC-IV-ECG dataset implementation.
|
|
3
|
+
|
|
4
|
+
Dataset structure:
|
|
5
|
+
- ECG waveforms: files/p{XXXX}/p{subject_id}/s{study_id}/{study_id}.hea/.dat (WFDB format)
|
|
6
|
+
- Machine measurements: machine_measurements.csv
|
|
7
|
+
|
|
8
|
+
Expected columns in machine_measurements.csv:
|
|
9
|
+
- subject_id, study_id
|
|
10
|
+
- rr_interval, p_onset, p_end, qrs_onset, qrs_end, t_end
|
|
11
|
+
- p_axis, qrs_axis, t_axis
|
|
12
|
+
"""
|
|
13
|
+
|
|
14
|
+
from __future__ import annotations
|
|
15
|
+
|
|
16
|
+
import os
|
|
17
|
+
from typing import Optional, Tuple
|
|
18
|
+
|
|
19
|
+
import numpy as np
|
|
20
|
+
import pandas as pd
|
|
21
|
+
import torch
|
|
22
|
+
import wfdb
|
|
23
|
+
from sklearn.model_selection import train_test_split
|
|
24
|
+
from torch.utils.data import Dataset
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class MIMICIVECGDataset(Dataset):
    """
    Dataset for MIMIC-IV-ECG signals with machine measurements conditioning.

    Returns (ecg, features) where:
    - ecg: (num_leads, seq_length) normalized ECG signal, float32
    - features: (9,) normalized machine measurements, float32

    Rows are split into train/val/test at the subject level (all studies of a
    subject land in the same split). Feature normalization statistics are
    computed on the *training* split, or taken from ``feature_stats`` when
    given, so that every split is z-scored with the same mean/std.
    """

    FEATURE_NAMES = [
        "rr_interval",
        "p_onset",
        "p_end",
        "qrs_onset",
        "qrs_end",
        "t_end",
        "p_axis",
        "qrs_axis",
        "t_axis",
    ]

    def __init__(
        self,
        mimic_path: str,
        split: str = "train",
        val_split: float = 0.1,
        test_split: float = 0.1,
        max_samples: Optional[int] = None,
        seed: int = 42,
        skip_missing_check: bool = False,
        ecg_norm_eps: float = 1e-6,
        ecg_norm_factor: Optional[float] = None,
        num_leads: int = 12,
        seq_length: int = 5000,
        feature_stats: Optional[dict] = None,
    ) -> None:
        """
        Args:
            mimic_path: Root directory containing ``machine_measurements.csv``
                and the ``files/`` waveform tree.
            split: Which split to expose: "train", "val" or "test".
            val_split: Fraction of subjects assigned to validation (0 allowed).
            test_split: Fraction of subjects assigned to test (0 allowed).
            max_samples: If given, keep only the first ``max_samples`` rows of
                the chosen split (applied after file filtering).
            seed: Random state passed to ``train_test_split``.
            skip_missing_check: If True, do not drop rows whose ``.hea`` file
                is missing (a warning is emitted instead).
            ecg_norm_eps: Lower bound on the per-record std used for scaling.
            ecg_norm_factor: Fixed scale to use instead of the per-record std.
            num_leads: Number of leads returned (pad with zeros / truncate).
            seq_length: Number of samples returned (pad with zeros / truncate).
            feature_stats: Optional ``{feature_name: {"mean": m, "std": s}}``
                mapping to use for feature normalization. When omitted, stats
                are computed from the training split.
        """
        self.mimic_path = mimic_path
        self.ecg_norm_eps = ecg_norm_eps
        self.ecg_norm_factor = ecg_norm_factor
        self.split = split
        self.seed = seed
        self.skip_missing_check = skip_missing_check
        self.num_leads = num_leads
        self.seq_length = seq_length

        self.load_measurements()
        self.create_splits(val_split, test_split)

        # Stats must come from the training rows (or the caller) and be fixed
        # before filtering to this split, so val/test features are scaled
        # exactly like the training features.
        if feature_stats is not None:
            self.feature_stats = {name: dict(feature_stats[name]) for name in self.FEATURE_NAMES}
        else:
            self.compute_feature_stats(self.measurements[self.measurements["split"] == "train"])

        self.filter_by_split()

        if not skip_missing_check:
            self.filter_missing_files()
        else:
            import warnings
            warnings.warn(
                "Skipping missing file check. Some samples may fail during loading.",
                UserWarning,
                stacklevel=2,
            )

        if max_samples is not None:
            self.measurements = self.measurements.head(max_samples).reset_index(drop=True)

    def load_measurements(self) -> None:
        """Read ``machine_measurements.csv``, validate columns, drop rows with missing features.

        Raises:
            FileNotFoundError: if the CSV is not present under ``mimic_path``.
            ValueError: if any required column is absent.
        """
        path = os.path.join(self.mimic_path, "machine_measurements.csv")
        if not os.path.isfile(path):
            raise FileNotFoundError(
                f"machine_measurements.csv not found at {path}. "
                "Download from https://physionet.org/content/mimic-iv-ecg/1.0/"
            )
        self.measurements = pd.read_csv(path)

        required = ["subject_id", "study_id"] + self.FEATURE_NAMES
        missing = [c for c in required if c not in self.measurements.columns]
        if missing:
            raise ValueError(f"machine_measurements.csv missing columns: {missing}")

        self.measurements = self.measurements.dropna(subset=self.FEATURE_NAMES).reset_index(drop=True)

    def create_splits(self, val_split: float, test_split: float) -> None:
        """Add a ``split`` column labelling every row "train", "val" or "test" by subject.

        Subjects are first split into train+val / test, then train+val is
        split again so that val gets ``val_split`` of all subjects. A
        fraction of 0 yields an empty split instead of an error.
        """
        subjects = self.measurements["subject_id"].unique()
        empty = subjects[:0]

        if test_split > 0:
            train_subjects, test_subjects = train_test_split(
                subjects, test_size=test_split, random_state=self.seed
            )
        else:
            train_subjects, test_subjects = subjects, empty

        if val_split > 0:
            train_subjects, val_subjects = train_test_split(
                train_subjects,
                test_size=val_split / (1 - test_split),
                random_state=self.seed,
            )
        else:
            val_subjects = empty

        # Vectorised membership test: O(rows + subjects) instead of a per-row
        # Python scan of the subject arrays. "val" is written last so it wins
        # on (impossible) overlap, matching val-before-test precedence.
        subject_ids = self.measurements["subject_id"]
        labels = np.full(len(subject_ids), "train", dtype=object)
        labels[subject_ids.isin(test_subjects).to_numpy()] = "test"
        labels[subject_ids.isin(val_subjects).to_numpy()] = "val"
        self.measurements["split"] = labels

    def filter_by_split(self) -> None:
        """Keep only rows whose ``split`` column equals ``self.split``."""
        self.measurements = self.measurements[
            self.measurements["split"] == self.split
        ].reset_index(drop=True)

    def filter_missing_files(self) -> None:
        """Drop rows whose WFDB header (``.hea``) file does not exist.

        If the ``files/`` directory itself is missing, no rows are dropped and
        a warning is emitted.
        """
        files_dir = os.path.join(self.mimic_path, "files")
        if not os.path.isdir(files_dir):
            import warnings
            warnings.warn(
                f"ECG directory {files_dir} not found; skipping missing-file check.",
                UserWarning,
                stacklevel=2,
            )
            return

        keep = np.fromiter(
            (
                os.path.isfile(self._ecg_record_path(subject_id, study_id) + ".hea")
                for subject_id, study_id in zip(
                    self.measurements["subject_id"], self.measurements["study_id"]
                )
            ),
            dtype=bool,
            count=len(self.measurements),
        )
        self.measurements = self.measurements.loc[keep].reset_index(drop=True)

    def _ecg_record_path(self, subject_id: int, study_id: int) -> str:
        """Return the WFDB record path (without extension) for a study.

        Layout: ``files/p{first 4 digits}/p{subject_id}/s{study_id}/{study_id}``.
        Ids are cast to ``int`` so float-typed ids (e.g. 10000032.0) still
        produce valid directory names.
        """
        subject_id = int(subject_id)
        study_id = int(study_id)
        sub_str = str(subject_id)
        return os.path.join(
            self.mimic_path,
            "files",
            f"p{sub_str[:4]}",
            f"p{sub_str}",
            f"s{study_id}",
            str(study_id),
        )

    def compute_feature_stats(self, frame: Optional[pd.DataFrame] = None) -> None:
        """Set ``self.feature_stats`` to per-feature mean/std.

        Args:
            frame: Rows to compute stats from; defaults to ``self.measurements``.

        ``std`` is the population std plus 1e-6 to avoid division by zero. An
        empty frame yields identity stats (mean 0, std 1) rather than NaN.
        """
        source = self.measurements if frame is None else frame
        self.feature_stats = {}
        for name in self.FEATURE_NAMES:
            vals = source[name].to_numpy(dtype=np.float64)
            if vals.size == 0:
                self.feature_stats[name] = {"mean": 0.0, "std": 1.0}
                continue
            self.feature_stats[name] = {
                "mean": float(np.mean(vals)),
                "std": float(np.std(vals)) + 1e-6,
            }

    def load_ecg(self, idx: int) -> np.ndarray:
        """Load record ``idx`` as a float32 ``(num_leads, seq_length)`` array.

        Leads/samples beyond the target size are truncated; shortfalls are
        zero-padded at the end. Non-finite samples (NaN/inf) are replaced with
        0 so that a single bad value cannot make the normalized record NaN.
        """
        row = self.measurements.iloc[idx]
        record = wfdb.rdrecord(self._ecg_record_path(row["subject_id"], row["study_id"]))

        # p_signal is (time, leads); astype copies, so in-place edits are safe.
        signal = record.p_signal.T.astype(np.float32)  # (leads, time)
        signal[~np.isfinite(signal)] = 0.0

        signal = signal[: self.num_leads, : self.seq_length]
        pad_leads = self.num_leads - signal.shape[0]
        pad_time = self.seq_length - signal.shape[1]
        if pad_leads > 0 or pad_time > 0:
            signal = np.pad(signal, ((0, max(pad_leads, 0)), (0, max(pad_time, 0))))

        return signal

    def _get_features(self, idx: int) -> np.ndarray:
        """Return the z-scored machine measurements of row ``idx`` as float32."""
        row = self.measurements.iloc[idx]
        out = []
        for name in self.FEATURE_NAMES:
            s = self.feature_stats[name]
            out.append((float(row[name]) - s["mean"]) / s["std"])
        return np.array(out, dtype=np.float32)

    def __len__(self) -> int:
        return len(self.measurements)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(ecg, features)`` tensors for row ``idx``.

        The ECG is centred on its global mean and divided by either
        ``ecg_norm_factor`` or its own std (floored at ``ecg_norm_eps``).
        """
        ecg = self.load_ecg(idx)

        if self.ecg_norm_factor is not None:
            # Python float keeps the float32 dtype even if a numpy float64 was passed.
            scale = float(self.ecg_norm_factor)
        else:
            scale = max(float(np.std(ecg)), self.ecg_norm_eps)
        ecg = ((ecg - ecg.mean()) / scale).astype(np.float32, copy=False)

        features = self._get_features(idx)

        return torch.from_numpy(ecg), torch.from_numpy(features)
|
|
@@ -0,0 +1,106 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from dataclasses import dataclass
|
|
4
|
+
from typing import Optional
|
|
5
|
+
|
|
6
|
+
import torch
|
|
7
|
+
from torch.utils.data import DataLoader, Dataset
|
|
8
|
+
|
|
9
|
+
import pytorch_lightning as pl
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class ECGDatasetAdapter(Dataset):
    """
    Wrap (ecg, cond) style dataset and return only ECG signals for GAN training.

    Each base sample must unpack into exactly two items; the second
    (conditioning) item is discarded. The ECG keeps at most ``num_leads``
    leads (it is never padded) and is cast to float32. Samples are returned
    as ``{"ecg_signals": tensor}``.
    """

    def __init__(self, base_dataset: Dataset, num_leads: int = 8) -> None:
        self.num_leads = num_leads
        self.base = base_dataset

    def __len__(self) -> int:
        return len(self.base)  # type: ignore[arg-type]

    def __getitem__(self, idx: int):
        signal, _condition = self.base[idx]
        limit = self.num_leads
        trimmed = signal[:limit] if signal.shape[0] > limit else signal
        return {"ecg_signals": trimmed.float()}
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
@dataclass
class Pulse2PulseMIMICConfig:
    """Settings for the Pulse2Pulse MIMIC-IV-ECG data pipeline.

    ``data_dir`` is required; every other field has a default. Fields are
    listed in positional-constructor order.
    """

    data_dir: str  # dataset root directory
    batch_size: int = 128  # samples per batch
    num_workers: int = 4  # DataLoader worker processes
    max_samples: Optional[int] = None  # optional row cap; None means no cap
    skip_missing_check: bool = True  # skip the on-disk existence check of records
    num_channels: int = 8  # number of ECG leads kept per sample
    seq_length: int = 5000  # samples per lead
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
class Pulse2PulseMIMICDataModule(pl.LightningDataModule):
    """
    LightningDataModule for MIMIC-IV-ECG ECG generation with Pulse2Pulse.

    Uses native ecgen.data.mimic_dataset.MIMICIVECGDataset. Both splits load
    12-lead records and are wrapped in ``ECGDatasetAdapter``, which keeps the
    first ``config.num_channels`` leads and drops the conditioning features.
    The validation split is capped at ``_MAX_VAL_SAMPLES`` rows.
    """

    # Upper bound on validation rows (also bounded by config.max_samples).
    _MAX_VAL_SAMPLES = 1000

    def __init__(self, config: Pulse2PulseMIMICConfig | dict) -> None:
        """
        Args:
            config: A ``Pulse2PulseMIMICConfig`` or a dict of its fields.
        """
        super().__init__()
        if isinstance(config, dict):
            config = Pulse2PulseMIMICConfig(**config)
        self.config = config

        self.train_dataset: Optional[Dataset] = None
        self.val_dataset: Optional[Dataset] = None

    def setup(self, stage: Optional[str] = None) -> None:
        """Build the train/val datasets once; subsequent calls are no-ops.

        ``stage`` is accepted for the Lightning hook signature and not used.
        """
        if self.train_dataset is not None and self.val_dataset is not None:
            return

        from ecgen.data.mimic_dataset import MIMICIVECGDataset

        cfg = self.config
        # Arguments shared by both splits; the base dataset always yields 12
        # leads and the adapter trims them to ``num_channels``.
        shared = dict(
            mimic_path=cfg.data_dir,
            skip_missing_check=cfg.skip_missing_check,
            num_leads=12,
            seq_length=cfg.seq_length,
        )

        # Compare against None explicitly: an explicit max_samples=0 must stay
        # 0 rather than being treated as "unset" by a truthiness test.
        if cfg.max_samples is None:
            val_max = self._MAX_VAL_SAMPLES
        else:
            val_max = min(cfg.max_samples, self._MAX_VAL_SAMPLES)

        train_base = MIMICIVECGDataset(split="train", max_samples=cfg.max_samples, **shared)
        val_base = MIMICIVECGDataset(split="val", max_samples=val_max, **shared)

        self.train_dataset = ECGDatasetAdapter(train_base, num_leads=cfg.num_channels)
        self.val_dataset = ECGDatasetAdapter(val_base, num_leads=cfg.num_channels)

    def _loader(self, dataset: Dataset, shuffle: bool) -> DataLoader:
        """Build a DataLoader with the configured batch size and workers.

        Memory is pinned only when CUDA is available.
        """
        return DataLoader(
            dataset,
            batch_size=self.config.batch_size,
            shuffle=shuffle,
            num_workers=self.config.num_workers,
            pin_memory=torch.cuda.is_available(),
        )

    def train_dataloader(self) -> DataLoader:
        """Shuffled training loader. Raises RuntimeError if setup() was not run."""
        if self.train_dataset is None:
            raise RuntimeError("train_dataset is not set. Did you forget to call setup()? ")
        return self._loader(self.train_dataset, shuffle=True)

    def val_dataloader(self) -> DataLoader:
        """Unshuffled validation loader. Raises RuntimeError if setup() was not run."""
        if self.val_dataset is None:
            raise RuntimeError("val_dataset is not set. Did you forget to call setup()? ")
        return self._loader(self.val_dataset, shuffle=False)
|
|
106
|
+
|
|
File without changes
|
|
@@ -0,0 +1,33 @@
|
|
|
1
|
+
"""
|
|
2
|
+
ECG generation models.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
from ecgen.models.pulse2pulse import (
|
|
6
|
+
Pulse2PulseConfig,
|
|
7
|
+
Pulse2PulseGAN,
|
|
8
|
+
WaveGANDiscriminator,
|
|
9
|
+
WaveGANGenerator,
|
|
10
|
+
)
|
|
11
|
+
from ecgen.models.vae import (
|
|
12
|
+
Decoder1D,
|
|
13
|
+
Encoder1D,
|
|
14
|
+
ResidualBlock1D,
|
|
15
|
+
VAE1D,
|
|
16
|
+
VAEConfig,
|
|
17
|
+
VAELightning,
|
|
18
|
+
vae_loss,
|
|
19
|
+
)
|
|
20
|
+
|
|
21
|
+
__all__ = [
    # Pulse2Pulse / WaveGAN components (ecgen.models.pulse2pulse)
    "Pulse2PulseConfig",
    "Pulse2PulseGAN",
    "WaveGANDiscriminator",
    "WaveGANGenerator",
    # 1D VAE components (ecgen.models.vae)
    "ResidualBlock1D",
    "Encoder1D",
    "Decoder1D",
    "VAE1D",
    "VAEConfig",
    "VAELightning",
    "vae_loss",
]
|