eclipse-ms 0.1.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- eclipse_ms/__init__.py +45 -0
- eclipse_ms/cli.py +107 -0
- eclipse_ms/cluster.py +105 -0
- eclipse_ms/config.py +53 -0
- eclipse_ms/consensus.py +148 -0
- eclipse_ms/embed.py +96 -0
- eclipse_ms/layers.py +84 -0
- eclipse_ms/modelhub.py +173 -0
- eclipse_ms/models.py +451 -0
- eclipse_ms/preprocessing.py +85 -0
- eclipse_ms-0.1.2.dist-info/METADATA +152 -0
- eclipse_ms-0.1.2.dist-info/RECORD +15 -0
- eclipse_ms-0.1.2.dist-info/WHEEL +4 -0
- eclipse_ms-0.1.2.dist-info/entry_points.txt +2 -0
- eclipse_ms-0.1.2.dist-info/licenses/LICENSE +21 -0
eclipse_ms/modelhub.py
ADDED
|
@@ -0,0 +1,173 @@
|
|
|
1
|
+
"""Model registry, download/cache, and loaders.
|
|
2
|
+
|
|
3
|
+
The trained weights are far too large to ship inside the PyPI wheel, so they
|
|
4
|
+
live in external storage (a GitHub Release asset, a Hugging Face Hub file, or a
|
|
5
|
+
Zenodo record) and are downloaded on first use and cached locally, with a
|
|
6
|
+
SHA-256 integrity check.
|
|
7
|
+
|
|
8
|
+
Resolution order for any model file:
|
|
9
|
+
1. ``ECLIPSE_MODEL_DIR`` env var, if set and the file exists there;
|
|
10
|
+
2. the local cache (``platformdirs`` user cache dir);
|
|
11
|
+
3. download from the registry URL into the cache.
|
|
12
|
+
|
|
13
|
+
You can also bypass the registry entirely and pass explicit local paths to
|
|
14
|
+
:func:`load_encoder` / :func:`load_autoencoder` (e.g. on an HPC node where you
|
|
15
|
+
already have the weights).
|
|
16
|
+
"""
|
|
17
|
+
|
|
18
|
+
from __future__ import annotations
|
|
19
|
+
|
|
20
|
+
import hashlib
|
|
21
|
+
import json
|
|
22
|
+
import os
|
|
23
|
+
import urllib.request
|
|
24
|
+
from pathlib import Path
|
|
25
|
+
from typing import Optional
|
|
26
|
+
|
|
27
|
+
from platformdirs import user_cache_dir
|
|
28
|
+
|
|
29
|
+
# ---------------------------------------------------------------------------
# Registry of downloadable model files, keyed by the name passed to
# get_model_file(). After uploading your weights, fill in `url` and `sha256`
# for each entry. Fields:
#   filename -- file name looked up in ECLIPSE_MODEL_DIR and in the cache dir;
#   url      -- download location; None / "" / "REPLACE_ME" make _download()
#               raise a "No download URL configured" RuntimeError;
#   sha256   -- expected hex digest (hashlib's hexdigest() is lower case);
#               a falsy value such as None disables the integrity check
#               (not recommended for a release).
# Compute a hash with: python -c "import hashlib,sys;
# print(hashlib.sha256(open(sys.argv[1],'rb').read()).hexdigest())" FILE
#
# NOTE(review): load_autoencoder() resolves the keys "ae-weights" and
# "ae-config", which are not registered here, so calling it without explicit
# paths fails the key lookup. Add those entries once full-autoencoder weights
# are published.
# ---------------------------------------------------------------------------
REGISTRY: dict[str, dict] = {
    # Slim, recommended for embedding/clustering: encoder weights only (~half size).
    "encoder-weights": {
        "filename": "specclust_encoder.weights.h5",
        "url": "https://github.com/VilenneFrederique/ECLIPSE/releases/download/v0.1.0/specclust_encoder.weights.h5",
        "sha256": "3c90bb9bb5c9960251f9b2165dd61be89f5ed78be6b3d21f5d28a0bd49877a6e",
    },
    # JSON whose keys (minus a few filtered ones) are passed as keyword
    # arguments to ConditionalSpectrumEncoder by load_encoder().
    "encoder-config": {
        "filename": "encoder_config.json",
        "url": "https://github.com/VilenneFrederique/ECLIPSE/releases/download/v0.1.0/encoder_config.json",
        "sha256": "89e53685f735458973c358746fb5444148cc6813725d93d7a45fcfd9974c0a00",
    },
}
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
def cache_dir() -> Path:
    """Return the per-user folder used to cache downloaded model files.

    The folder (and any missing parents) is created on every call, so callers
    may write into the returned path immediately.
    """
    folder = Path(user_cache_dir("eclipse-ms"))
    folder.mkdir(exist_ok=True, parents=True)
    return folder
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
def _sha256(path: Path) -> str:
|
|
58
|
+
h = hashlib.sha256()
|
|
59
|
+
with open(path, "rb") as f:
|
|
60
|
+
for chunk in iter(lambda: f.read(1 << 20), b""):
|
|
61
|
+
h.update(chunk)
|
|
62
|
+
return h.hexdigest()
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
def _download(url: str, dest: Path) -> None:
|
|
66
|
+
if url in (None, "", "REPLACE_ME"):
|
|
67
|
+
raise RuntimeError(
|
|
68
|
+
f"No download URL configured for {dest.name}. Either set the URL in "
|
|
69
|
+
f"eclipse_ms.modelhub.REGISTRY, set the ECLIPSE_MODEL_DIR environment "
|
|
70
|
+
f"variable to a folder containing the file, or pass an explicit path."
|
|
71
|
+
)
|
|
72
|
+
tmp = dest.with_suffix(dest.suffix + ".part")
|
|
73
|
+
print(f"Downloading {dest.name} from {url} ...")
|
|
74
|
+
with urllib.request.urlopen(url) as resp, open(tmp, "wb") as out: # noqa: S310
|
|
75
|
+
total = int(resp.headers.get("Content-Length", 0))
|
|
76
|
+
read = 0
|
|
77
|
+
while True:
|
|
78
|
+
chunk = resp.read(1 << 20)
|
|
79
|
+
if not chunk:
|
|
80
|
+
break
|
|
81
|
+
out.write(chunk)
|
|
82
|
+
read += len(chunk)
|
|
83
|
+
if total:
|
|
84
|
+
pct = 100 * read / total
|
|
85
|
+
print(f"\r {read / 1e6:,.0f} / {total / 1e6:,.0f} MB ({pct:.0f}%)", end="")
|
|
86
|
+
print()
|
|
87
|
+
tmp.replace(dest)
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
def get_model_file(key: str) -> Path:
    """Resolve a registry key to a local file path, downloading it if needed.

    Lookup order:

    1. ``$ECLIPSE_MODEL_DIR/<filename>`` if the variable is set and that path
       is a regular file (used as-is, no checksum check);
    2. ``<cache_dir()>/<filename>`` if present and its SHA-256 matches the
       registry (a mismatching cached copy is deleted and re-downloaded);
    3. a fresh download from the entry's ``url`` into the cache.

    Args:
        key: A key of ``REGISTRY``.

    Returns:
        Path to the resolved file.

    Raises:
        KeyError: If ``key`` is not in ``REGISTRY``.
        RuntimeError: If the entry has no usable URL, or the downloaded file
            fails its checksum.
    """
    if key not in REGISTRY:
        raise KeyError(f"Unknown model key '{key}'. Known: {list(REGISTRY)}")
    entry = REGISTRY[key]
    filename = entry["filename"]
    # hashlib's hexdigest() is lower case; normalize the registry value so a
    # digest pasted in upper case (e.g. from PowerShell's Get-FileHash) still
    # matches. An empty/None digest disables the integrity check.
    expected = (entry.get("sha256") or "").lower()

    def _matches(path: Path) -> bool:
        return not expected or _sha256(path) == expected

    env_dir = os.environ.get("ECLIPSE_MODEL_DIR")
    if env_dir:
        candidate = Path(env_dir) / filename
        # is_file() rather than exists(): a directory of that name is no model.
        if candidate.is_file():
            return candidate

    cached = cache_dir() / filename
    if cached.exists():
        if _matches(cached):
            return cached
        print(f"Cached {filename} failed checksum; re-downloading.")
        cached.unlink()

    # .get(): a missing "url" reaches _download(), which raises a descriptive
    # "No download URL configured" error instead of a bare KeyError.
    _download(entry.get("url"), cached)
    if not _matches(cached):
        cached.unlink(missing_ok=True)
        raise RuntimeError(f"Checksum mismatch for {filename} after download.")
    return cached
|
|
116
|
+
|
|
117
|
+
|
|
118
|
+
# ---------------------------------------------------------------------------
|
|
119
|
+
# Loaders
|
|
120
|
+
# ---------------------------------------------------------------------------
|
|
121
|
+
def _build_and_load_encoder(config: dict, weights_path: str):
    """Instantiate a ``ConditionalSpectrumEncoder`` from ``config`` and load weights.

    The autoencoder-only options ``use_kl`` and ``kl_weight``, plus the
    ``conditional`` flag, are dropped. All remaining keys are passed to the
    encoder constructor as keyword arguments. The model is then called once on
    an all-zero batch so its variables exist before ``load_weights`` runs.

    Args:
        config: Parsed JSON configuration.
        weights_path: Path to the saved encoder weights.

    Returns:
        The encoder with weights loaded.
    """
    import tensorflow as tf

    from .config import COND_DIM
    from .models import ConditionalSpectrumEncoder

    skipped = ("conditional", "use_kl", "kl_weight")
    ctor_kwargs = {name: value for name, value in config.items() if name not in skipped}
    encoder = ConditionalSpectrumEncoder(**ctor_kwargs)

    # Build every layer with a two-example dummy batch.
    spectra_width = config.get("n_bins", 3200)
    cond_width = config.get("cond_dim", COND_DIM)
    dummy_batch = (tf.zeros((2, spectra_width)), tf.zeros((2, cond_width)))
    _ = encoder(dummy_batch, training=False)

    encoder.load_weights(weights_path)
    return encoder
|
|
135
|
+
|
|
136
|
+
|
|
137
|
+
def load_encoder(weights: Optional[str] = None, config: Optional[str] = None):
    """Return the spectrum encoder with weights loaded, ready for embedding.

    Args:
        weights: Path to an encoder ``.h5`` weights file. When omitted (or
            empty), the ``"encoder-weights"`` registry file is used, which is
            downloaded into the cache on first use.
        config: Path to the matching ``.json`` config. When omitted (or
            empty), the ``"encoder-config"`` registry file is used.

    Returns:
        A built ``ConditionalSpectrumEncoder``.
    """
    if not weights:
        weights = str(get_model_file("encoder-weights"))
    if not config:
        config = str(get_model_file("encoder-config"))
    with open(config) as handle:
        parsed = json.load(handle)
    return _build_and_load_encoder(parsed, weights)
|
|
149
|
+
|
|
150
|
+
|
|
151
|
+
def load_autoencoder(weights: Optional[str] = None, config: Optional[str] = None):
    """Load the full autoencoder (encoder + decoder).

    Needed only for reconstruction / visualisation; embedding and clustering
    use :func:`load_encoder`, which is roughly half the download.

    Args:
        weights: Path to the autoencoder ``.h5`` weights. When omitted, the
            ``"ae-weights"`` registry entry is used if it exists.
        config: Path to the autoencoder ``.json`` config. When omitted, the
            ``"ae-config"`` registry entry is used if it exists.

    Returns:
        A built ``ConditionalSpectrumAutoencoder`` with weights loaded.

    Raises:
        KeyError: If a path is omitted and the registry has no entry for it;
            the message explains how to supply the file instead.
    """
    import tensorflow as tf

    from .config import COND_DIM
    from .models import ConditionalSpectrumAutoencoder

    def _resolve(explicit: Optional[str], key: str, what: str) -> str:
        # Explicit paths always win; otherwise fall back to the registry, with
        # an actionable message when the key is not registered.
        if explicit:
            return explicit
        if key not in REGISTRY:
            raise KeyError(
                f"No {what} is registered under '{key}' in "
                "eclipse_ms.modelhub.REGISTRY. Pass an explicit path to "
                f"load_autoencoder(), or add a '{key}' entry to the registry."
            )
        return str(get_model_file(key))

    weights_path = _resolve(weights, "ae-weights", "autoencoder weights file")
    config_path = _resolve(config, "ae-config", "autoencoder config file")
    with open(config_path) as f:
        cfg = json.load(f)

    # "conditional" is not one of the autoencoder's named constructor parameters.
    ctor = {k: v for k, v in cfg.items() if k != "conditional"}
    ae = ConditionalSpectrumAutoencoder(**ctor)
    # Call once on an all-zero batch so every variable exists before loading.
    cond_dim = cfg.get("cond_dim", COND_DIM)
    n_bins = cfg.get("n_bins", 3200)
    _ = ae((tf.zeros((2, n_bins)), tf.zeros((2, cond_dim))), training=False)
    ae.load_weights(weights_path)
    return ae
|
eclipse_ms/models.py
ADDED
|
@@ -0,0 +1,451 @@
|
|
|
1
|
+
"""Conditional spectrum autoencoder models.
|
|
2
|
+
|
|
3
|
+
Ported verbatim from the training code. The encoder is the only part needed for
|
|
4
|
+
embedding/clustering; the decoder and the full autoencoder (with train/test
|
|
5
|
+
steps) are included so the same classes can load full weights and be retrained.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import tensorflow as tf
|
|
9
|
+
from tensorflow import keras
|
|
10
|
+
from tensorflow.keras import layers
|
|
11
|
+
|
|
12
|
+
from .layers import PatchEmbedding, TransformerBlock
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
@keras.utils.register_keras_serializable()
class ConditionalSpectrumEncoder(keras.Model):
    """Encode a binned spectrum + conditioning vector to a latent vector.

    The spectrum is turned into patch tokens by ``PatchEmbedding`` and the
    conditioning vector is projected to a single extra token. A learned CLS
    token and the condition token are prepended to the patch tokens, learned
    positional embeddings are added, and the sequence passes through
    ``num_layers`` ``TransformerBlock``s. The final-normalized CLS token is
    mapped to a ``latent_dim`` vector.

    Args:
        n_bins: Length of the binned spectrum.
        patch_size: Bins per patch; ``num_patches = n_bins // patch_size``.
        embed_dim: Token width.
        num_heads: Passed to every ``TransformerBlock``.
        num_layers: Number of ``TransformerBlock``s.
        ff_dim: Passed to every ``TransformerBlock``.
        latent_dim: Width of the output vector.
        cond_dim: Recorded for serialization only; no layer is sized from it
            (the conditioning ``Dense`` layers infer their input width on the
            first call).
        dropout: Passed to every ``TransformerBlock``.
        **kwargs: Forwarded to ``keras.Model`` (e.g. ``name``).
    """

    def __init__(
        self,
        n_bins: int = 3200,
        patch_size: int = 16,
        embed_dim: int = 256,
        num_heads: int = 8,
        num_layers: int = 4,
        ff_dim: int = 512,
        latent_dim: int = 256,
        cond_dim: int = 8,
        dropout: float = 0.1,
        **kwargs,
    ):
        super().__init__(**kwargs)

        # Every constructor hyperparameter is stored so get_config() can
        # rebuild the exact same architecture.
        self.n_bins = n_bins
        self.patch_size = patch_size
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.ff_dim = ff_dim
        self.latent_dim = latent_dim
        self.cond_dim = cond_dim
        self.dropout = dropout
        self.num_patches = n_bins // patch_size

        self.patch_embed = PatchEmbedding(embed_dim, patch_size)

        self.cond_proj = keras.Sequential(
            [
                layers.Dense(embed_dim, activation="gelu"),
                layers.LayerNormalization(epsilon=1e-6),
                layers.Dense(embed_dim),
            ],
            name="cond_projection",
        )

        self.cls_token = self.add_weight(
            name="cls_token",
            shape=(1, 1, embed_dim),
            initializer=keras.initializers.TruncatedNormal(stddev=0.02),
            trainable=True,
        )

        # +2 positions: the CLS token and the condition token precede the
        # patch tokens (see call()).
        self.pos_embed = self.add_weight(
            name="pos_embed",
            shape=(1, self.num_patches + 2, embed_dim),
            initializer="glorot_uniform",
            trainable=True,
        )

        self.transformer_blocks = [
            TransformerBlock(embed_dim, num_heads, ff_dim, dropout)
            for _ in range(num_layers)
        ]

        self.final_norm = layers.LayerNormalization(epsilon=1e-6)

        self.to_latent = keras.Sequential(
            [
                layers.Dense(latent_dim, activation="gelu"),
                layers.LayerNormalization(epsilon=1e-6),
                layers.Dense(latent_dim),
            ]
        )

    def call(self, inputs, training=False):
        """Encode ``inputs = (spectrum, cond)`` and return the latent batch.

        ``spectrum`` is presumably ``(batch, n_bins)`` and ``cond``
        ``(batch, cond_features)`` (that is how the package's loaders call
        it). ``PatchEmbedding`` must yield ``num_patches`` tokens of width
        ``embed_dim`` for the positional-embedding addition to broadcast.
        Returns a ``(batch, latent_dim)`` tensor.
        """
        x, cond = inputs
        batch_size = tf.shape(x)[0]

        x = self.patch_embed(x)

        cond_token = self.cond_proj(cond)
        cond_token = tf.expand_dims(cond_token, 1)

        cls_tokens = tf.repeat(self.cls_token, batch_size, axis=0)

        # Sequence layout: [CLS, COND, patch_0, ..., patch_{num_patches-1}].
        x = tf.concat([cls_tokens, cond_token, x], axis=1)
        x = x + self.pos_embed

        for block in self.transformer_blocks:
            x = block(x, training=training)

        x = self.final_norm(x)
        cls_output = x[:, 0, :]
        z = self.to_latent(cls_output)

        return z

    def get_config(self):
        """Return the constructor arguments needed to rebuild this encoder.

        All architecture hyperparameters are included, so ``from_config``
        (``cls(**config)``) recreates the same layer shapes.
        """
        return {
            "n_bins": self.n_bins,
            "patch_size": self.patch_size,
            "embed_dim": self.embed_dim,
            "num_heads": self.num_heads,
            "num_layers": self.num_layers,
            "ff_dim": self.ff_dim,
            "latent_dim": self.latent_dim,
            "cond_dim": self.cond_dim,
            "dropout": self.dropout,
        }
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
@keras.utils.register_keras_serializable()
class ConditionalSpectrumDecoder(keras.Model):
    """Two-head conditional decoder: latent + conditioning -> spectrum.

    ``concat(z, cond)`` is expanded by ``from_latent`` into ``num_patches``
    tokens of width ``embed_dim``. A projected condition token is prepended,
    learned positional embeddings are added, and the sequence passes through
    ``num_layers`` ``TransformerBlock``s. After the condition token is dropped,
    a presence head and an intensity head each emit ``patch_size`` values per
    token, which are reshaped to ``n_bins`` values per spectrum. ``n_bins``
    must therefore be a multiple of ``patch_size``, or that reshape fails.

    Args:
        n_bins: Length of the reconstructed spectrum.
        patch_size: Bins per token; ``num_patches = n_bins // patch_size``.
        embed_dim: Token width.
        num_heads: Passed to every ``TransformerBlock``.
        num_layers: Number of ``TransformerBlock``s.
        ff_dim: Passed to every ``TransformerBlock``; also the hidden width of
            both output heads.
        latent_dim: Recorded for serialization; the input width of
            ``from_latent`` is inferred on the first call.
        cond_dim: Recorded for serialization; inferred likewise.
        dropout: Passed to every ``TransformerBlock``.
        **kwargs: Forwarded to ``keras.Model``.
    """

    def __init__(
        self,
        n_bins: int = 3200,
        patch_size: int = 16,
        embed_dim: int = 256,
        num_heads: int = 8,
        num_layers: int = 4,
        ff_dim: int = 512,
        latent_dim: int = 256,
        cond_dim: int = 8,
        dropout: float = 0.1,
        **kwargs,
    ):
        super().__init__(**kwargs)

        # Every constructor hyperparameter is stored so get_config() can
        # rebuild the exact same architecture.
        self.n_bins = n_bins
        self.patch_size = patch_size
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.ff_dim = ff_dim
        self.latent_dim = latent_dim
        self.cond_dim = cond_dim
        self.dropout = dropout
        self.num_patches = n_bins // patch_size

        self.cond_proj = keras.Sequential(
            [
                layers.Dense(embed_dim, activation="gelu"),
                layers.Dense(embed_dim),
            ],
            name="cond_projection",
        )

        self.from_latent = keras.Sequential(
            [
                layers.Dense(embed_dim * 4, activation="gelu"),
                layers.LayerNormalization(epsilon=1e-6),
                layers.Dense(embed_dim * self.num_patches),
                layers.Reshape((self.num_patches, embed_dim)),
            ]
        )

        # +1 position for the condition token prepended in call().
        self.pos_embed = self.add_weight(
            name="dec_pos_embed",
            shape=(1, self.num_patches + 1, embed_dim),
            initializer="glorot_uniform",
            trainable=True,
        )

        self.transformer_blocks = [
            TransformerBlock(embed_dim, num_heads, ff_dim, dropout)
            for _ in range(num_layers)
        ]

        self.final_norm = layers.LayerNormalization(epsilon=1e-6)

        self.presence_head = keras.Sequential(
            [
                layers.Dense(ff_dim, activation="gelu"),
                layers.Dense(patch_size, dtype="float32"),
            ],
            name="presence_head",
        )

        self.intensity_head = keras.Sequential(
            [
                layers.Dense(ff_dim, activation="gelu"),
                layers.Dense(patch_size, dtype="float32"),
            ],
            name="intensity_head",
        )

        # Inference-time cut-off on the presence probability, and the factor
        # applied to presence logits before the sigmoid.
        self.presence_threshold = 0.5
        self.presence_temperature = 2.0

    def call(self, inputs, training=False):
        """Decode ``inputs = (z, cond)`` into a ``(batch, n_bins)`` spectrum.

        Per bin, one head predicts a presence logit and the other an intensity.
        With ``training=True`` the output is the soft product
        ``sigmoid(presence_temperature * logit) * sigmoid(intensity_raw)``.
        Otherwise the presence probability is thresholded at
        ``presence_threshold`` and used as a hard 0/1 mask. The raw logits,
        the presence probabilities and the intensities from the most recent
        call are stored on ``self.last_presence_logits``,
        ``self.last_presence_prob`` and ``self.last_intensity``, where the
        autoencoder's loss code reads them.
        """
        z, cond = inputs
        batch_size = tf.shape(z)[0]

        z_cond = tf.concat([z, cond], axis=-1)
        x = self.from_latent(z_cond)

        cond_token = self.cond_proj(cond)
        cond_token = tf.expand_dims(cond_token, 1)

        # Sequence layout: [COND, token_0, ..., token_{num_patches-1}].
        x = tf.concat([cond_token, x], axis=1)
        x = x + self.pos_embed

        for block in self.transformer_blocks:
            x = block(x, training=training)

        x = self.final_norm(x)
        # Drop the condition token; only patch tokens are decoded to bins.
        x = x[:, 1:, :]

        presence_logits = self.presence_head(x)
        presence_logits = tf.reshape(presence_logits, [batch_size, self.n_bins])
        presence_prob = tf.nn.sigmoid(presence_logits * self.presence_temperature)

        intensity_raw = self.intensity_head(x)
        intensity_raw = tf.reshape(intensity_raw, [batch_size, self.n_bins])
        intensity = tf.nn.sigmoid(intensity_raw)

        self.last_presence_prob = presence_prob
        self.last_presence_logits = presence_logits
        self.last_intensity = intensity

        if training:
            x_recon = presence_prob * intensity
        else:
            presence_mask = tf.cast(presence_prob > self.presence_threshold, tf.float32)
            x_recon = presence_mask * intensity

        return x_recon

    def get_config(self):
        """Return the constructor arguments needed to rebuild this decoder.

        All architecture hyperparameters are included, so ``from_config``
        (``cls(**config)``) recreates the same layer shapes.
        """
        return {
            "n_bins": self.n_bins,
            "patch_size": self.patch_size,
            "embed_dim": self.embed_dim,
            "num_heads": self.num_heads,
            "num_layers": self.num_layers,
            "ff_dim": self.ff_dim,
            "latent_dim": self.latent_dim,
            "cond_dim": self.cond_dim,
            "dropout": self.dropout,
        }
|
|
237
|
+
|
|
238
|
+
|
|
239
|
+
@keras.utils.register_keras_serializable()
class ConditionalSpectrumAutoencoder(keras.Model):
    """Conditional autoencoder: (spectrum, conditioning) -> latent -> spectrum.

    Wraps a :class:`ConditionalSpectrumEncoder` and a
    :class:`ConditionalSpectrumDecoder` built from the same hyperparameters.
    Custom ``train_step`` / ``test_step`` methods optimise a composite
    reconstruction loss (see :meth:`_compute_losses`) plus an optional KL term.

    Args:
        n_bins: Spectrum length (encoder and decoder).
        patch_size: Bins per patch/token (encoder and decoder).
        embed_dim: Token width (encoder and decoder).
        num_heads: Forwarded to encoder and decoder.
        num_layers: Forwarded to encoder and decoder.
        ff_dim: Forwarded to encoder and decoder.
        latent_dim: Latent width seen by the decoder. With ``use_kl=True`` the
            encoder outputs ``2 * latent_dim`` values, split into mean and
            log-variance.
        cond_dim: Forwarded to encoder and decoder.
        dropout: Forwarded to encoder and decoder.
        use_kl: Treat the latent as Gaussian (reparameterised sampling during
            training) and add a KL term to the total loss.
        kl_weight: Weight of the KL term.
        **kwargs: Forwarded to ``keras.Model``.
    """

    def __init__(
        self,
        n_bins: int = 3200,
        patch_size: int = 16,
        embed_dim: int = 256,
        num_heads: int = 8,
        num_layers: int = 4,
        ff_dim: int = 512,
        latent_dim: int = 256,
        cond_dim: int = 8,
        dropout: float = 0.1,
        use_kl: bool = False,
        kl_weight: float = 1e-4,
        **kwargs,
    ):
        super().__init__(**kwargs)

        # Every constructor hyperparameter is stored so get_config() can
        # rebuild the exact same architecture.
        self.n_bins = n_bins
        self.patch_size = patch_size
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.ff_dim = ff_dim
        self.latent_dim = latent_dim
        self.cond_dim = cond_dim
        self.dropout = dropout
        self.use_kl = use_kl
        self.kl_weight = kl_weight

        self.encoder = ConditionalSpectrumEncoder(
            n_bins=n_bins,
            patch_size=patch_size,
            embed_dim=embed_dim,
            num_heads=num_heads,
            num_layers=num_layers,
            ff_dim=ff_dim,
            # With KL the encoder emits [mu, logvar] concatenated.
            latent_dim=latent_dim if not use_kl else latent_dim * 2,
            cond_dim=cond_dim,
            dropout=dropout,
        )

        self.decoder = ConditionalSpectrumDecoder(
            n_bins=n_bins,
            patch_size=patch_size,
            embed_dim=embed_dim,
            num_heads=num_heads,
            num_layers=num_layers,
            ff_dim=ff_dim,
            latent_dim=latent_dim,
            cond_dim=cond_dim,
            dropout=dropout,
        )

        self.recon_loss_tracker = keras.metrics.Mean(name="recon_loss")
        self.kl_loss_tracker = keras.metrics.Mean(name="kl_loss")
        self.total_loss_tracker = keras.metrics.Mean(name="loss")
        self.cosine_sim_tracker = keras.metrics.Mean(name="cosine_sim")
        self.sparsity_tracker = keras.metrics.Mean(name="sparsity")
        self.presence_acc_tracker = keras.metrics.Mean(name="presence_acc")
        self.fp_rate_tracker = keras.metrics.Mean(name="fp_rate")

    def encode(self, x, cond, training=False):
        """Encode to the latent space.

        Returns ``z`` when ``use_kl`` is False. Otherwise returns
        ``(z, mu, logvar)``, where ``z`` is a reparameterised sample when
        ``training`` is True and ``mu`` otherwise.
        """
        z = self.encoder((x, cond), training=training)

        if self.use_kl:
            mu = z[:, : self.latent_dim]
            logvar = z[:, self.latent_dim :]

            if training:
                std = tf.exp(0.5 * logvar)
                eps = tf.random.normal(tf.shape(std))
                z = mu + eps * std
            else:
                z = mu

            return z, mu, logvar

        return z

    def decode(self, z, cond, training=False):
        """Decode a latent batch back to spectra via the decoder."""
        return self.decoder((z, cond), training=training)

    def call(self, inputs, training=False):
        """Reconstruct ``inputs = (x, cond)``; returns only the reconstruction."""
        x, cond = inputs

        if self.use_kl:
            z, mu, logvar = self.encode(x, cond, training=training)
        else:
            z = self.encode(x, cond, training=training)

        x_recon = self.decode(z, cond, training=training)
        return x_recon

    def _compute_losses(self, x, x_recon):
        """Compute the composite reconstruction loss for the last decode.

        Reads the presence logits/probabilities and intensities that the
        decoder stored during its most recent call, so it must run right after
        ``decode``. Bins of ``x`` above 0.05 count as peaks. The loss sums:
        presence BCE (x1.0), intensity MSE over peak bins only (x1.0), spectral
        angle ``1 - cos`` (x0.5), and the mean presence probability on non-peak
        bins (x0.5).

        Returns:
            ``(recon_loss, cos_sim, presence_prob, peak_mask)``.
        """
        presence_prob = self.decoder.last_presence_prob
        presence_logits = self.decoder.last_presence_logits
        intensity = self.decoder.last_intensity

        peak_mask = tf.cast(x > 0.05, tf.float32)

        # NOTE(review): the BCE uses the raw logits, while presence_prob (used
        # by the false-positive penalty and the metrics) is
        # sigmoid(presence_temperature * logits). Both agree on the 0.5
        # decision boundary but not on calibration -- confirm this is intended.
        presence_bce = tf.nn.sigmoid_cross_entropy_with_logits(
            labels=peak_mask, logits=presence_logits
        )
        presence_loss = tf.reduce_mean(presence_bce)

        intensity_error = tf.square(x - intensity)
        masked_intensity_error = intensity_error * peak_mask
        # Per-spectrum mean over peak bins; epsilon guards peak-free spectra.
        num_peaks = tf.reduce_sum(peak_mask, axis=-1, keepdims=True) + 1e-6
        intensity_loss = tf.reduce_mean(
            tf.reduce_sum(masked_intensity_error, axis=-1) / tf.squeeze(num_peaks)
        )

        x_norm = tf.nn.l2_normalize(x, axis=-1)
        x_recon_norm = tf.nn.l2_normalize(x_recon, axis=-1)
        cos_sim = tf.reduce_sum(x_norm * x_recon_norm, axis=-1)
        spectral_angle_loss = tf.reduce_mean(1 - cos_sim)

        false_positive_mask = 1 - peak_mask
        false_positive_penalty = tf.reduce_mean(presence_prob * false_positive_mask)

        recon_loss = (
            1.0 * presence_loss
            + 1.0 * intensity_loss
            + 0.5 * spectral_angle_loss
            + 0.5 * false_positive_penalty
        )
        return recon_loss, cos_sim, presence_prob, peak_mask

    def _update_trackers(self, recon_loss, kl_loss, total_loss, cos_sim, presence_prob, peak_mask):
        """Update every running metric from one batch's loss outputs."""
        self.recon_loss_tracker.update_state(recon_loss)
        self.kl_loss_tracker.update_state(kl_loss)
        self.total_loss_tracker.update_state(total_loss)
        self.cosine_sim_tracker.update_state(tf.reduce_mean(cos_sim))

        # Fraction of bins the decoder considers almost surely empty.
        sparsity = tf.reduce_mean(tf.cast(presence_prob < 0.1, tf.float32))
        self.sparsity_tracker.update_state(sparsity)

        presence_pred = tf.cast(presence_prob > 0.5, tf.float32)
        presence_acc = tf.reduce_mean(tf.cast(tf.equal(presence_pred, peak_mask), tf.float32))
        self.presence_acc_tracker.update_state(presence_acc)

        # Share of predicted peaks that are not real peaks (batch-wide).
        predicted_peaks = tf.reduce_sum(presence_pred)
        false_positives = tf.reduce_sum(presence_pred * (1 - peak_mask))
        fp_rate = false_positives / (predicted_peaks + 1e-6)
        self.fp_rate_tracker.update_state(fp_rate)

    def _results(self):
        """Return the current value of every tracker, keyed by metric name."""
        return {
            "loss": self.total_loss_tracker.result(),
            "recon_loss": self.recon_loss_tracker.result(),
            "kl_loss": self.kl_loss_tracker.result(),
            "cosine_sim": self.cosine_sim_tracker.result(),
            "sparsity": self.sparsity_tracker.result(),
            "presence_acc": self.presence_acc_tracker.result(),
            "fp_rate": self.fp_rate_tracker.result(),
        }

    def train_step(self, data):
        """One optimisation step; ``data`` must unpack as ``(x, cond)``."""
        x, cond = data
        with tf.GradientTape() as tape:
            if self.use_kl:
                z, mu, logvar = self.encode(x, cond, training=True)
                x_recon = self.decode(z, cond, training=True)
                kl_loss = -0.5 * tf.reduce_mean(1 + logvar - tf.square(mu) - tf.exp(logvar))
            else:
                z = self.encode(x, cond, training=True)
                x_recon = self.decode(z, cond, training=True)
                kl_loss = 0.0

            recon_loss, cos_sim, presence_prob, peak_mask = self._compute_losses(x, x_recon)
            total_loss = recon_loss
            if self.use_kl:
                total_loss += self.kl_weight * kl_loss

        gradients = tape.gradient(total_loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))
        self._update_trackers(recon_loss, kl_loss, total_loss, cos_sim, presence_prob, peak_mask)
        return self._results()

    def test_step(self, data):
        """Evaluation step (no sampling, hard presence mask); ``data`` is ``(x, cond)``."""
        x, cond = data
        if self.use_kl:
            z, mu, logvar = self.encode(x, cond, training=False)
            x_recon = self.decode(z, cond, training=False)
            kl_loss = -0.5 * tf.reduce_mean(1 + logvar - tf.square(mu) - tf.exp(logvar))
        else:
            z = self.encode(x, cond, training=False)
            x_recon = self.decode(z, cond, training=False)
            kl_loss = 0.0

        recon_loss, cos_sim, presence_prob, peak_mask = self._compute_losses(x, x_recon)
        total_loss = recon_loss
        if self.use_kl:
            total_loss += self.kl_weight * kl_loss
        self._update_trackers(recon_loss, kl_loss, total_loss, cos_sim, presence_prob, peak_mask)
        return self._results()

    @property
    def metrics(self):
        # Listing the trackers lets Keras reset them at each epoch.
        return [
            self.total_loss_tracker,
            self.recon_loss_tracker,
            self.kl_loss_tracker,
            self.cosine_sim_tracker,
            self.sparsity_tracker,
            self.presence_acc_tracker,
            self.fp_rate_tracker,
        ]

    def get_config(self):
        """Return the constructor arguments needed to rebuild this autoencoder.

        All architecture hyperparameters are included, so ``from_config``
        (``cls(**config)``) recreates the same encoder/decoder shapes.
        """
        return {
            "n_bins": self.n_bins,
            "patch_size": self.patch_size,
            "embed_dim": self.embed_dim,
            "num_heads": self.num_heads,
            "num_layers": self.num_layers,
            "ff_dim": self.ff_dim,
            "latent_dim": self.latent_dim,
            "cond_dim": self.cond_dim,
            "dropout": self.dropout,
            "use_kl": self.use_kl,
            "kl_weight": self.kl_weight,
        }
|
|
@@ -0,0 +1,85 @@
|
|
|
1
|
+
"""Spectrum preprocessing (NumPy).
|
|
2
|
+
|
|
3
|
+
These reproduce the exact binning and conditioning used during training, so
|
|
4
|
+
embeddings computed at inference time match the model's expectations. The
|
|
5
|
+
NumPy implementation keeps this module importable without TensorFlow.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
|
|
10
|
+
import numpy as np
|
|
11
|
+
|
|
12
|
+
from .config import Config
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def bin_spectrum_numpy(mz: np.ndarray, intensity: np.ndarray, config=Config) -> np.ndarray:
    """Convert one peak list into the fixed-length vector the encoder expects.

    Processing order:

    1. keep peaks with ``MZ_MIN <= m/z < MZ_MAX`` and positive intensity;
    2. divide by the most intense remaining peak;
    3. drop peaks below ``RELATIVE_INTENSITY_THRESHOLD``;
    4. if ``TOP_N_PEAKS`` is set (truthy), keep only that many strongest peaks;
    5. take the square root of the intensities;
    6. bin by ``BIN_SIZE`` starting at ``MZ_MIN``, keeping the maximum per bin;
    7. rescale so the largest bin equals 1.

    Returns:
        A ``float32`` array of length ``config.N_BINS``; all zeros when no peak
        survives steps 1-3.
    """
    def _zeros() -> np.ndarray:
        return np.zeros(config.N_BINS, dtype=np.float32)

    peaks_mz = np.asarray(mz, dtype=np.float64)
    peaks_int = np.asarray(intensity, dtype=np.float64)

    # Step 1: m/z window and strictly positive intensity.
    window = (peaks_mz >= config.MZ_MIN) & (peaks_mz < config.MZ_MAX) & (peaks_int > 0)
    peaks_mz, peaks_int = peaks_mz[window], peaks_int[window]
    if peaks_int.size == 0:
        return _zeros()

    # Steps 2-3: base-peak scaling, then the relative-intensity cut-off.
    peaks_int = peaks_int / peaks_int.max()
    strong = peaks_int >= config.RELATIVE_INTENSITY_THRESHOLD
    peaks_mz, peaks_int = peaks_mz[strong], peaks_int[strong]
    if peaks_int.size == 0:
        return _zeros()

    # Step 4: optional top-N selection.
    top_n = getattr(config, "TOP_N_PEAKS", None)
    if top_n and peaks_int.size > top_n:
        strongest = np.argsort(peaks_int)[-top_n:]
        peaks_mz, peaks_int = peaks_mz[strongest], peaks_int[strongest]

    # Step 5.
    peaks_int = np.sqrt(peaks_int)

    # Step 6: assign each peak to a bin; colliding peaks keep the maximum.
    slots = ((peaks_mz - config.MZ_MIN) / config.BIN_SIZE).astype(int)
    slots = np.clip(slots, 0, config.N_BINS - 1)
    vector = _zeros()
    np.maximum.at(vector, slots, peaks_int.astype(np.float32))

    # Step 7.
    top = vector.max()
    if top > 0:
        vector = vector / top

    return vector
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
def build_cond_vector(
    precursor_mz: float,
    charge: int,
    ion_mobility: float,
    config=Config,
) -> np.ndarray:
    """Assemble the conditioning vector that accompanies each binned spectrum.

    Layout (``float32``, length ``config.MAX_CHARGE + 2``):

    * indices ``0 .. MAX_CHARGE - 1`` -- one-hot charge, after clamping the
      charge to ``[1, MAX_CHARGE]``;
    * index ``MAX_CHARGE`` -- ``precursor_mz / PRECURSOR_MZ_MAX`` (not clipped);
    * index ``MAX_CHARGE + 1`` -- ion mobility min-max scaled with
      ``IM_MIN`` / ``IM_MAX`` and clipped to ``[0, 1]``.
    """
    max_charge = config.MAX_CHARGE
    clamped_charge = max(1, min(int(charge), max_charge))
    onehot = np.zeros(max_charge, dtype=np.float32)
    onehot[clamped_charge - 1] = 1.0

    scaled_mz = float(precursor_mz) / config.PRECURSOR_MZ_MAX

    im_range = config.IM_MAX - config.IM_MIN
    scaled_im = float(np.clip((ion_mobility - config.IM_MIN) / im_range, 0.0, 1.0))

    extras = [scaled_mz, scaled_im]
    return np.concatenate([onehot, extras]).astype(np.float32)
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
def preprocess(
    mz: np.ndarray,
    intensity: np.ndarray,
    precursor_mz: float,
    charge: int,
    ion_mobility: float,
    config=Config,
) -> tuple[np.ndarray, np.ndarray]:
    """Prepare one spectrum for the encoder in a single call.

    Equivalent to running :func:`bin_spectrum_numpy` and then
    :func:`build_cond_vector` with the same ``config``.

    Returns:
        ``(binned_spectrum, conditioning_vector)``.
    """
    return (
        bin_spectrum_numpy(mz, intensity, config),
        build_cond_vector(precursor_mz, charge, ion_mobility, config),
    )
|