eclipse-ms 0.1.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
eclipse_ms/modelhub.py ADDED
@@ -0,0 +1,173 @@
1
+ """Model registry, download/cache, and loaders.
2
+
3
+ The trained weights are far too large to ship inside the PyPI wheel, so they
4
+ live in external storage (a GitHub Release asset, a Hugging Face Hub file, or a
5
+ Zenodo record) and are downloaded on first use and cached locally, with a
6
+ SHA-256 integrity check.
7
+
8
+ Resolution order for any model file:
9
+ 1. ``ECLIPSE_MODEL_DIR`` env var, if set and the file exists there;
10
+ 2. the local cache (``platformdirs`` user cache dir);
11
+ 3. download from the registry URL into the cache.
12
+
13
+ You can also bypass the registry entirely and pass explicit local paths to
14
+ :func:`load_encoder` / :func:`load_autoencoder` (e.g. on an HPC node where you
15
+ already have the weights).
16
+ """
17
+
18
+ from __future__ import annotations
19
+
20
+ import hashlib
21
+ import json
22
+ import os
23
+ import urllib.request
24
+ from pathlib import Path
25
+ from typing import Optional
26
+
27
+ from platformdirs import user_cache_dir
28
+
29
# ---------------------------------------------------------------------------
# Registry of downloadable model files, keyed by a short name.
#
# Each entry holds:
#   filename -- name of the file in the cache / ECLIPSE_MODEL_DIR folder;
#   url      -- where to fetch the file when it is not available locally;
#   sha256   -- expected hex digest; None disables the integrity check
#               (not recommended for a release).
#
# After uploading new weights, update `url` and `sha256` for each entry. A
# digest can be computed with:
#   python -c "import hashlib,sys;
#   print(hashlib.sha256(open(sys.argv[1],'rb').read()).hexdigest())" FILE
# ---------------------------------------------------------------------------
_RELEASE_BASE_URL = "https://github.com/VilenneFrederique/ECLIPSE/releases/download/v0.1.0"

REGISTRY: dict[str, dict] = {
    # Slim, recommended for embedding/clustering: encoder weights only (~half size).
    "encoder-weights": {
        "filename": "specclust_encoder.weights.h5",
        "url": f"{_RELEASE_BASE_URL}/specclust_encoder.weights.h5",
        "sha256": "3c90bb9bb5c9960251f9b2165dd61be89f5ed78be6b3d21f5d28a0bd49877a6e",
    },
    "encoder-config": {
        "filename": "encoder_config.json",
        "url": f"{_RELEASE_BASE_URL}/encoder_config.json",
        "sha256": "89e53685f735458973c358746fb5444148cc6813725d93d7a45fcfd9974c0a00",
    },
}
48
+
49
+
50
def cache_dir() -> Path:
    """Return the per-user cache folder used for downloaded model files.

    The folder is ``platformdirs.user_cache_dir("eclipse-ms")``; it is created
    (including missing parents) on every call if it does not exist yet.
    """
    folder = Path(user_cache_dir("eclipse-ms"))
    folder.mkdir(exist_ok=True, parents=True)
    return folder
55
+
56
+
57
+ def _sha256(path: Path) -> str:
58
+ h = hashlib.sha256()
59
+ with open(path, "rb") as f:
60
+ for chunk in iter(lambda: f.read(1 << 20), b""):
61
+ h.update(chunk)
62
+ return h.hexdigest()
63
+
64
+
65
+ def _download(url: str, dest: Path) -> None:
66
+ if url in (None, "", "REPLACE_ME"):
67
+ raise RuntimeError(
68
+ f"No download URL configured for {dest.name}. Either set the URL in "
69
+ f"eclipse_ms.modelhub.REGISTRY, set the ECLIPSE_MODEL_DIR environment "
70
+ f"variable to a folder containing the file, or pass an explicit path."
71
+ )
72
+ tmp = dest.with_suffix(dest.suffix + ".part")
73
+ print(f"Downloading {dest.name} from {url} ...")
74
+ with urllib.request.urlopen(url) as resp, open(tmp, "wb") as out: # noqa: S310
75
+ total = int(resp.headers.get("Content-Length", 0))
76
+ read = 0
77
+ while True:
78
+ chunk = resp.read(1 << 20)
79
+ if not chunk:
80
+ break
81
+ out.write(chunk)
82
+ read += len(chunk)
83
+ if total:
84
+ pct = 100 * read / total
85
+ print(f"\r {read / 1e6:,.0f} / {total / 1e6:,.0f} MB ({pct:.0f}%)", end="")
86
+ print()
87
+ tmp.replace(dest)
88
+
89
+
90
def get_model_file(key: str) -> Path:
    """Return a local path for registry entry *key*, downloading if needed.

    Lookup order:

    1. ``$ECLIPSE_MODEL_DIR/<filename>`` if the variable is set and the file
       exists there (returned as-is, without checksum verification);
    2. the cached copy under :func:`cache_dir`, if it passes the checksum (a
       copy that fails is deleted and re-downloaded);
    3. a fresh download into the cache, verified against ``sha256``.

    A missing or empty ``sha256`` in the entry disables verification.

    Raises:
        KeyError: *key* is not present in ``REGISTRY``.
        RuntimeError: no URL is configured for the entry, or the freshly
            downloaded file fails its checksum (the bad file is removed).
    """
    if key not in REGISTRY:
        raise KeyError(f"Unknown model key '{key}'. Known: {list(REGISTRY)}")
    entry = REGISTRY[key]
    filename = entry["filename"]
    expected = entry.get("sha256")

    def checksum_ok(path: Path) -> bool:
        # A falsy digest means "don't verify".
        return not expected or _sha256(path) == expected

    override_dir = os.environ.get("ECLIPSE_MODEL_DIR")
    if override_dir:
        candidate = Path(override_dir) / filename
        if candidate.exists():
            return candidate

    cached = cache_dir() / filename
    if cached.exists():
        if checksum_ok(cached):
            return cached
        print(f"Cached {filename} failed checksum; re-downloading.")
        cached.unlink()

    _download(entry["url"], cached)
    if not checksum_ok(cached):
        cached.unlink(missing_ok=True)
        raise RuntimeError(f"Checksum mismatch for {filename} after download.")
    return cached
116
+
117
+
118
+ # ---------------------------------------------------------------------------
119
+ # Loaders
120
+ # ---------------------------------------------------------------------------
121
def _build_and_load_encoder(config: dict, weights_path: str):
    """Instantiate a ``ConditionalSpectrumEncoder`` from *config* and load weights.

    Keys that belong to the full autoencoder only (``conditional``,
    ``use_kl``, ``kl_weight``) are dropped before calling the constructor.
    The encoder is then called once on an all-zero batch of two spectra, so
    that its lazily-built sublayers create their variables before
    ``load_weights`` runs.

    Args:
        config: parsed encoder config (constructor keyword arguments).
        weights_path: path to the ``.h5`` weights file.

    Returns:
        The encoder with weights loaded.
    """
    import tensorflow as tf

    from .config import COND_DIM
    from .models import ConditionalSpectrumEncoder

    autoencoder_only = ("conditional", "use_kl", "kl_weight")
    ctor_kwargs = {name: value for name, value in config.items() if name not in autoencoder_only}
    encoder = ConditionalSpectrumEncoder(**ctor_kwargs)

    # Dummy forward pass to build all variables.
    dummy_spectra = tf.zeros((2, config.get("n_bins", 3200)))
    dummy_cond = tf.zeros((2, config.get("cond_dim", COND_DIM)))
    encoder((dummy_spectra, dummy_cond), training=False)

    encoder.load_weights(weights_path)
    return encoder
135
+
136
+
137
def load_encoder(weights: Optional[str] = None, config: Optional[str] = None):
    """Load the encoder used to embed spectra.

    Any argument left empty is resolved through the model registry
    (``"encoder-weights"`` / ``"encoder-config"``), downloading and caching
    the published file on first use.

    Args:
        weights: path to a local encoder ``.h5`` weights file.
        config: path to a local encoder ``.json`` config file.

    Returns:
        The built encoder with its weights loaded.
    """
    if not weights:
        weights = str(get_model_file("encoder-weights"))
    if not config:
        config = str(get_model_file("encoder-config"))
    with open(config) as handle:
        parsed = json.load(handle)
    return _build_and_load_encoder(parsed, weights)
149
+
150
+
151
def load_autoencoder(weights: Optional[str] = None, config: Optional[str] = None):
    """Load the full autoencoder (encoder + decoder).

    Needed only for reconstruction / visualisation; embedding and clustering
    use :func:`load_encoder`, which is roughly half the download.

    Args:
        weights: path to the autoencoder ``.h5`` weights. If omitted, the
            ``"ae-weights"`` registry entry is used.
        config: path to the autoencoder ``.json`` config. If omitted, the
            ``"ae-config"`` registry entry is used.

    Returns:
        The built ``ConditionalSpectrumAutoencoder`` with weights loaded.

    Raises:
        KeyError: a path was omitted and the corresponding registry entry does
            not exist in ``REGISTRY``.
    """
    # Fail fast, before the slow TensorFlow import, with an actionable message
    # instead of a bare "Unknown model key" from get_model_file().
    unregistered = [
        key
        for key, given in (("ae-weights", weights), ("ae-config", config))
        if not given and key not in REGISTRY
    ]
    if unregistered:
        raise KeyError(
            f"No registry entry for {unregistered}. Add it to "
            "eclipse_ms.modelhub.REGISTRY or pass explicit `weights` and `config` "
            "paths to load_autoencoder(); use load_encoder() for embedding/clustering."
        )

    import tensorflow as tf

    from .config import COND_DIM
    from .models import ConditionalSpectrumAutoencoder

    weights_path = weights or str(get_model_file("ae-weights"))
    config_path = config or str(get_model_file("ae-config"))
    with open(config_path) as f:
        cfg = json.load(f)

    # "conditional" is a config flag, not a constructor argument.
    ctor = {k: v for k, v in cfg.items() if k != "conditional"}
    ae = ConditionalSpectrumAutoencoder(**ctor)
    cond_dim = cfg.get("cond_dim", COND_DIM)
    n_bins = cfg.get("n_bins", 3200)
    # Dummy forward pass so every sublayer builds its variables before loading.
    _ = ae((tf.zeros((2, n_bins)), tf.zeros((2, cond_dim))), training=False)
    ae.load_weights(weights_path)
    return ae
eclipse_ms/models.py ADDED
@@ -0,0 +1,451 @@
1
+ """Conditional spectrum autoencoder models.
2
+
3
+ Ported verbatim from the training code. The encoder is the only part needed for
4
+ embedding/clustering; the decoder and the full autoencoder (with train/test
5
+ steps) are included so the same classes can load full weights and be retrained.
6
+ """
7
+
8
+ import tensorflow as tf
9
+ from tensorflow import keras
10
+ from tensorflow.keras import layers
11
+
12
+ from .layers import PatchEmbedding, TransformerBlock
13
+
14
+
15
@keras.utils.register_keras_serializable()
class ConditionalSpectrumEncoder(keras.Model):
    """Encode a binned spectrum + conditioning vector to a latent vector.

    Forward pass (see ``call``): the spectrum is turned into patch tokens by
    ``PatchEmbedding``; a learned CLS token and a projected conditioning token
    are prepended; learned position embeddings are added; the sequence runs
    through ``num_layers`` ``TransformerBlock``s and a final LayerNorm; the
    CLS position is projected to ``latent_dim`` by ``to_latent``.

    Args:
        n_bins: length of the binned input spectrum.
        patch_size: bins per patch. ``pos_embed`` is sized for
            ``n_bins // patch_size`` patch tokens plus CLS and conditioning
            tokens (presumably requires ``n_bins`` to be a multiple of
            ``patch_size`` -- TODO confirm against ``PatchEmbedding``).
        embed_dim: token width.
        num_heads: attention heads per ``TransformerBlock``.
        num_layers: number of ``TransformerBlock``s.
        ff_dim: feed-forward width inside each ``TransformerBlock``.
        latent_dim: size of the returned latent vector.
        cond_dim: size of the conditioning vector (stored for ``get_config``;
            the conditioning projection infers its input width when built).
        dropout: dropout rate passed to each ``TransformerBlock``.
        **kwargs: forwarded to ``keras.Model``.
    """

    def __init__(
        self,
        n_bins: int = 3200,
        patch_size: int = 16,
        embed_dim: int = 256,
        num_heads: int = 8,
        num_layers: int = 4,
        ff_dim: int = 512,
        latent_dim: int = 256,
        cond_dim: int = 8,
        dropout: float = 0.1,
        **kwargs,
    ):
        super().__init__(**kwargs)

        self.n_bins = n_bins
        self.patch_size = patch_size
        self.embed_dim = embed_dim
        # Kept so get_config() can round-trip the full architecture.
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.ff_dim = ff_dim
        self.latent_dim = latent_dim
        self.cond_dim = cond_dim
        self.dropout_rate = dropout
        self.num_patches = n_bins // patch_size

        self.patch_embed = PatchEmbedding(embed_dim, patch_size)

        self.cond_proj = keras.Sequential(
            [
                layers.Dense(embed_dim, activation="gelu"),
                layers.LayerNormalization(epsilon=1e-6),
                layers.Dense(embed_dim),
            ],
            name="cond_projection",
        )

        self.cls_token = self.add_weight(
            name="cls_token",
            shape=(1, 1, embed_dim),
            initializer=keras.initializers.TruncatedNormal(stddev=0.02),
            trainable=True,
        )

        # +2: one slot for the CLS token and one for the conditioning token.
        self.pos_embed = self.add_weight(
            name="pos_embed",
            shape=(1, self.num_patches + 2, embed_dim),
            initializer="glorot_uniform",
            trainable=True,
        )

        self.transformer_blocks = [
            TransformerBlock(embed_dim, num_heads, ff_dim, dropout)
            for _ in range(num_layers)
        ]

        self.final_norm = layers.LayerNormalization(epsilon=1e-6)

        self.to_latent = keras.Sequential(
            [
                layers.Dense(latent_dim, activation="gelu"),
                layers.LayerNormalization(epsilon=1e-6),
                layers.Dense(latent_dim),
            ]
        )

    def call(self, inputs, training=False):
        """Map ``(spectrum, cond)`` to a latent vector of size ``latent_dim``."""
        x, cond = inputs
        batch_size = tf.shape(x)[0]

        x = self.patch_embed(x)

        cond_token = self.cond_proj(cond)
        cond_token = tf.expand_dims(cond_token, 1)

        cls_tokens = tf.repeat(self.cls_token, batch_size, axis=0)

        # Token order: [CLS, cond, patch_0 .. patch_{P-1}].
        x = tf.concat([cls_tokens, cond_token, x], axis=1)
        x = x + self.pos_embed

        for block in self.transformer_blocks:
            x = block(x, training=training)

        x = self.final_norm(x)
        cls_output = x[:, 0, :]
        z = self.to_latent(cls_output)

        return z

    def get_config(self):
        """Return every constructor argument needed to rebuild this encoder.

        All architecture hyper-parameters are included so that
        ``from_config(get_config())`` recreates the same layer shapes rather
        than silently reverting ``num_heads``, ``num_layers``, ``ff_dim`` and
        ``dropout`` to their defaults.
        """
        return {
            "n_bins": self.n_bins,
            "patch_size": self.patch_size,
            "embed_dim": self.embed_dim,
            "num_heads": self.num_heads,
            "num_layers": self.num_layers,
            "ff_dim": self.ff_dim,
            "latent_dim": self.latent_dim,
            "cond_dim": self.cond_dim,
            "dropout": self.dropout_rate,
        }
112
+
113
+
114
@keras.utils.register_keras_serializable()
class ConditionalSpectrumDecoder(keras.Model):
    """Two-head conditional decoder: latent + conditioning -> spectrum.

    ``[z, cond]`` is expanded into ``n_bins // patch_size`` patch tokens; a
    projected conditioning token is prepended; the sequence runs through
    ``num_layers`` ``TransformerBlock``s. Two heads then predict, per bin, a
    peak-presence logit and an intensity (both emitted as float32):

    * training: reconstruction = ``sigmoid(2 * logit) * sigmoid(intensity)``
      (soft, differentiable);
    * inference: reconstruction = ``[sigmoid(2 * logit) > 0.5] * sigmoid(intensity)``
      (hard mask).

    The most recent presence probabilities/logits and intensities are kept on
    ``last_presence_prob``, ``last_presence_logits`` and ``last_intensity`` so
    the autoencoder's loss code can read them after a forward pass.

    Args mirror :class:`ConditionalSpectrumEncoder`; ``latent_dim`` and
    ``cond_dim`` are stored for ``get_config`` (the input projection infers its
    input width when built).
    """

    def __init__(
        self,
        n_bins: int = 3200,
        patch_size: int = 16,
        embed_dim: int = 256,
        num_heads: int = 8,
        num_layers: int = 4,
        ff_dim: int = 512,
        latent_dim: int = 256,
        cond_dim: int = 8,
        dropout: float = 0.1,
        **kwargs,
    ):
        super().__init__(**kwargs)

        self.n_bins = n_bins
        self.patch_size = patch_size
        self.embed_dim = embed_dim
        # Kept so get_config() can round-trip the full architecture.
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.ff_dim = ff_dim
        self.latent_dim = latent_dim
        self.cond_dim = cond_dim
        self.dropout_rate = dropout
        self.num_patches = n_bins // patch_size

        self.cond_proj = keras.Sequential(
            [
                layers.Dense(embed_dim, activation="gelu"),
                layers.Dense(embed_dim),
            ],
            name="cond_projection",
        )

        self.from_latent = keras.Sequential(
            [
                layers.Dense(embed_dim * 4, activation="gelu"),
                layers.LayerNormalization(epsilon=1e-6),
                layers.Dense(embed_dim * self.num_patches),
                layers.Reshape((self.num_patches, embed_dim)),
            ]
        )

        # +1: one slot for the conditioning token.
        self.pos_embed = self.add_weight(
            name="dec_pos_embed",
            shape=(1, self.num_patches + 1, embed_dim),
            initializer="glorot_uniform",
            trainable=True,
        )

        self.transformer_blocks = [
            TransformerBlock(embed_dim, num_heads, ff_dim, dropout)
            for _ in range(num_layers)
        ]

        self.final_norm = layers.LayerNormalization(epsilon=1e-6)

        self.presence_head = keras.Sequential(
            [
                layers.Dense(ff_dim, activation="gelu"),
                layers.Dense(patch_size, dtype="float32"),
            ],
            name="presence_head",
        )

        self.intensity_head = keras.Sequential(
            [
                layers.Dense(ff_dim, activation="gelu"),
                layers.Dense(patch_size, dtype="float32"),
            ],
            name="intensity_head",
        )

        self.presence_threshold = 0.5
        self.presence_temperature = 2.0

    def call(self, inputs, training=False):
        """Map ``(z, cond)`` to a reconstructed spectrum of length ``n_bins``."""
        z, cond = inputs
        batch_size = tf.shape(z)[0]

        z_cond = tf.concat([z, cond], axis=-1)
        x = self.from_latent(z_cond)

        cond_token = self.cond_proj(cond)
        cond_token = tf.expand_dims(cond_token, 1)

        x = tf.concat([cond_token, x], axis=1)
        x = x + self.pos_embed

        for block in self.transformer_blocks:
            x = block(x, training=training)

        x = self.final_norm(x)
        # Drop the conditioning token; keep only the patch tokens.
        x = x[:, 1:, :]

        presence_logits = self.presence_head(x)
        presence_logits = tf.reshape(presence_logits, [batch_size, self.n_bins])
        presence_prob = tf.nn.sigmoid(presence_logits * self.presence_temperature)

        intensity_raw = self.intensity_head(x)
        intensity_raw = tf.reshape(intensity_raw, [batch_size, self.n_bins])
        intensity = tf.nn.sigmoid(intensity_raw)

        # Exposed for ConditionalSpectrumAutoencoder._compute_losses.
        self.last_presence_prob = presence_prob
        self.last_presence_logits = presence_logits
        self.last_intensity = intensity

        if training:
            x_recon = presence_prob * intensity
        else:
            presence_mask = tf.cast(presence_prob > self.presence_threshold, tf.float32)
            x_recon = presence_mask * intensity

        return x_recon

    def get_config(self):
        """Return every constructor argument needed to rebuild this decoder.

        All architecture hyper-parameters are included so that
        ``from_config(get_config())`` recreates the same layer shapes rather
        than silently reverting ``num_heads``, ``num_layers``, ``ff_dim`` and
        ``dropout`` to their defaults.
        """
        return {
            "n_bins": self.n_bins,
            "patch_size": self.patch_size,
            "embed_dim": self.embed_dim,
            "num_heads": self.num_heads,
            "num_layers": self.num_layers,
            "ff_dim": self.ff_dim,
            "latent_dim": self.latent_dim,
            "cond_dim": self.cond_dim,
            "dropout": self.dropout_rate,
        }
237
+
238
+
239
@keras.utils.register_keras_serializable()
class ConditionalSpectrumAutoencoder(keras.Model):
    """Conditional autoencoder: (spectrum, conditioning) -> latent -> spectrum.

    Wraps a :class:`ConditionalSpectrumEncoder` and a
    :class:`ConditionalSpectrumDecoder` that share the same architecture
    hyper-parameters. When ``use_kl`` is true, the encoder outputs
    ``2 * latent_dim`` values that are split into ``mu`` and ``logvar``. A
    reparameterised sample is drawn during training, ``mu`` is used otherwise,
    and a KL term weighted by ``kl_weight`` is added to the loss.

    Custom ``train_step`` / ``test_step`` implement the reconstruction loss in
    ``_compute_losses`` and report the trackers listed in ``metrics``.
    """

    def __init__(
        self,
        n_bins: int = 3200,
        patch_size: int = 16,
        embed_dim: int = 256,
        num_heads: int = 8,
        num_layers: int = 4,
        ff_dim: int = 512,
        latent_dim: int = 256,
        cond_dim: int = 8,
        dropout: float = 0.1,
        use_kl: bool = False,
        kl_weight: float = 1e-4,
        **kwargs,
    ):
        super().__init__(**kwargs)

        # Kept so get_config() can round-trip the full architecture.
        self.n_bins = n_bins
        self.patch_size = patch_size
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.ff_dim = ff_dim
        self.dropout_rate = dropout

        self.latent_dim = latent_dim
        self.cond_dim = cond_dim
        self.use_kl = use_kl
        self.kl_weight = kl_weight

        self.encoder = ConditionalSpectrumEncoder(
            n_bins=n_bins,
            patch_size=patch_size,
            embed_dim=embed_dim,
            num_heads=num_heads,
            num_layers=num_layers,
            ff_dim=ff_dim,
            # With KL, the encoder emits [mu, logvar] concatenated.
            latent_dim=latent_dim if not use_kl else latent_dim * 2,
            cond_dim=cond_dim,
            dropout=dropout,
        )

        self.decoder = ConditionalSpectrumDecoder(
            n_bins=n_bins,
            patch_size=patch_size,
            embed_dim=embed_dim,
            num_heads=num_heads,
            num_layers=num_layers,
            ff_dim=ff_dim,
            latent_dim=latent_dim,
            cond_dim=cond_dim,
            dropout=dropout,
        )

        self.recon_loss_tracker = keras.metrics.Mean(name="recon_loss")
        self.kl_loss_tracker = keras.metrics.Mean(name="kl_loss")
        self.total_loss_tracker = keras.metrics.Mean(name="loss")
        self.cosine_sim_tracker = keras.metrics.Mean(name="cosine_sim")
        self.sparsity_tracker = keras.metrics.Mean(name="sparsity")
        self.presence_acc_tracker = keras.metrics.Mean(name="presence_acc")
        self.fp_rate_tracker = keras.metrics.Mean(name="fp_rate")

    def encode(self, x, cond, training=False):
        """Encode to a latent vector.

        Returns ``z`` when ``use_kl`` is false, otherwise ``(z, mu, logvar)``
        where ``z`` is a reparameterised sample in training and ``mu`` at
        inference.
        """
        z = self.encoder((x, cond), training=training)

        if self.use_kl:
            mu = z[:, : self.latent_dim]
            logvar = z[:, self.latent_dim :]

            if training:
                std = tf.exp(0.5 * logvar)
                eps = tf.random.normal(tf.shape(std))
                z = mu + eps * std
            else:
                z = mu

            return z, mu, logvar

        return z

    def decode(self, z, cond, training=False):
        """Decode ``(z, cond)`` back to a spectrum via the decoder."""
        return self.decoder((z, cond), training=training)

    def call(self, inputs, training=False):
        """Reconstruct a spectrum from ``(spectrum, cond)``."""
        x, cond = inputs

        if self.use_kl:
            z, mu, logvar = self.encode(x, cond, training=training)
        else:
            z = self.encode(x, cond, training=training)

        x_recon = self.decode(z, cond, training=training)
        return x_recon

    def _compute_losses(self, x, x_recon):
        """Compute the reconstruction loss from the last decoder pass.

        Bins with ``x > 0.05`` count as peaks. The loss is
        ``1.0 * presence BCE + 1.0 * per-spectrum mean squared intensity error
        over peak bins + 0.5 * (1 - cosine(x, x_recon)) + 0.5 * mean presence
        probability on non-peak bins``.

        Relies on ``self.decoder.last_*`` having been set by the decode call
        that produced *x_recon*.

        Returns:
            ``(recon_loss, cos_sim, presence_prob, peak_mask)``.
        """
        presence_prob = self.decoder.last_presence_prob
        presence_logits = self.decoder.last_presence_logits
        intensity = self.decoder.last_intensity

        peak_mask = tf.cast(x > 0.05, tf.float32)

        presence_bce = tf.nn.sigmoid_cross_entropy_with_logits(
            labels=peak_mask, logits=presence_logits
        )
        presence_loss = tf.reduce_mean(presence_bce)

        intensity_error = tf.square(x - intensity)
        masked_intensity_error = intensity_error * peak_mask
        num_peaks = tf.reduce_sum(peak_mask, axis=-1, keepdims=True) + 1e-6
        intensity_loss = tf.reduce_mean(
            tf.reduce_sum(masked_intensity_error, axis=-1) / tf.squeeze(num_peaks)
        )

        x_norm = tf.nn.l2_normalize(x, axis=-1)
        x_recon_norm = tf.nn.l2_normalize(x_recon, axis=-1)
        cos_sim = tf.reduce_sum(x_norm * x_recon_norm, axis=-1)
        spectral_angle_loss = tf.reduce_mean(1 - cos_sim)

        false_positive_mask = 1 - peak_mask
        false_positive_penalty = tf.reduce_mean(presence_prob * false_positive_mask)

        recon_loss = (
            1.0 * presence_loss
            + 1.0 * intensity_loss
            + 0.5 * spectral_angle_loss
            + 0.5 * false_positive_penalty
        )
        return recon_loss, cos_sim, presence_prob, peak_mask

    def _update_trackers(self, recon_loss, kl_loss, total_loss, cos_sim, presence_prob, peak_mask):
        """Update all metric trackers for one batch."""
        self.recon_loss_tracker.update_state(recon_loss)
        self.kl_loss_tracker.update_state(kl_loss)
        self.total_loss_tracker.update_state(total_loss)
        self.cosine_sim_tracker.update_state(tf.reduce_mean(cos_sim))

        # Fraction of bins the model is confident are empty.
        sparsity = tf.reduce_mean(tf.cast(presence_prob < 0.1, tf.float32))
        self.sparsity_tracker.update_state(sparsity)

        presence_pred = tf.cast(presence_prob > 0.5, tf.float32)
        presence_acc = tf.reduce_mean(tf.cast(tf.equal(presence_pred, peak_mask), tf.float32))
        self.presence_acc_tracker.update_state(presence_acc)

        # Share of predicted peaks that are not real peaks.
        predicted_peaks = tf.reduce_sum(presence_pred)
        false_positives = tf.reduce_sum(presence_pred * (1 - peak_mask))
        fp_rate = false_positives / (predicted_peaks + 1e-6)
        self.fp_rate_tracker.update_state(fp_rate)

    def _results(self):
        """Return the current value of every tracker, keyed by metric name."""
        return {
            "loss": self.total_loss_tracker.result(),
            "recon_loss": self.recon_loss_tracker.result(),
            "kl_loss": self.kl_loss_tracker.result(),
            "cosine_sim": self.cosine_sim_tracker.result(),
            "sparsity": self.sparsity_tracker.result(),
            "presence_acc": self.presence_acc_tracker.result(),
            "fp_rate": self.fp_rate_tracker.result(),
        }

    def train_step(self, data):
        """One optimisation step on a ``(spectrum, cond)`` batch."""
        x, cond = data
        with tf.GradientTape() as tape:
            if self.use_kl:
                z, mu, logvar = self.encode(x, cond, training=True)
                x_recon = self.decode(z, cond, training=True)
                kl_loss = -0.5 * tf.reduce_mean(1 + logvar - tf.square(mu) - tf.exp(logvar))
            else:
                z = self.encode(x, cond, training=True)
                x_recon = self.decode(z, cond, training=True)
                kl_loss = 0.0

            recon_loss, cos_sim, presence_prob, peak_mask = self._compute_losses(x, x_recon)
            total_loss = recon_loss
            if self.use_kl:
                total_loss += self.kl_weight * kl_loss

        gradients = tape.gradient(total_loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))
        self._update_trackers(recon_loss, kl_loss, total_loss, cos_sim, presence_prob, peak_mask)
        return self._results()

    def test_step(self, data):
        """Evaluation step on a ``(spectrum, cond)`` batch (inference-mode decode)."""
        x, cond = data
        if self.use_kl:
            z, mu, logvar = self.encode(x, cond, training=False)
            x_recon = self.decode(z, cond, training=False)
            kl_loss = -0.5 * tf.reduce_mean(1 + logvar - tf.square(mu) - tf.exp(logvar))
        else:
            z = self.encode(x, cond, training=False)
            x_recon = self.decode(z, cond, training=False)
            kl_loss = 0.0

        recon_loss, cos_sim, presence_prob, peak_mask = self._compute_losses(x, x_recon)
        total_loss = recon_loss
        if self.use_kl:
            total_loss += self.kl_weight * kl_loss
        self._update_trackers(recon_loss, kl_loss, total_loss, cos_sim, presence_prob, peak_mask)
        return self._results()

    @property
    def metrics(self):
        """Trackers Keras resets at the start of each epoch / evaluation."""
        return [
            self.total_loss_tracker,
            self.recon_loss_tracker,
            self.kl_loss_tracker,
            self.cosine_sim_tracker,
            self.sparsity_tracker,
            self.presence_acc_tracker,
            self.fp_rate_tracker,
        ]

    def get_config(self):
        """Return every constructor argument needed to rebuild this model.

        All architecture hyper-parameters are included so that
        ``from_config(get_config())`` recreates the same encoder/decoder shapes
        rather than silently falling back to the constructor defaults.
        """
        return {
            "n_bins": self.n_bins,
            "patch_size": self.patch_size,
            "embed_dim": self.embed_dim,
            "num_heads": self.num_heads,
            "num_layers": self.num_layers,
            "ff_dim": self.ff_dim,
            "latent_dim": self.latent_dim,
            "cond_dim": self.cond_dim,
            "dropout": self.dropout_rate,
            "use_kl": self.use_kl,
            "kl_weight": self.kl_weight,
        }
@@ -0,0 +1,85 @@
1
+ """Spectrum preprocessing (NumPy).
2
+
3
+ These reproduce the exact binning and conditioning used during training, so
4
+ embeddings computed at inference time match the model's expectations. The
5
+ NumPy implementation keeps this module importable without TensorFlow.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ import numpy as np
11
+
12
+ from .config import Config
13
+
14
+
15
def bin_spectrum_numpy(mz: np.ndarray, intensity: np.ndarray, config=Config) -> np.ndarray:
    """Bin one spectrum into the fixed-length vector the encoder expects.

    Steps:

    1. keep peaks with ``MZ_MIN <= mz < MZ_MAX`` and positive intensity;
    2. scale intensities so the most intense remaining peak is 1, then drop
       peaks below ``RELATIVE_INTENSITY_THRESHOLD``;
    3. if ``config`` has a truthy ``TOP_N_PEAKS``, keep only that many of the
       most intense peaks;
    4. square-root transform the intensities;
    5. assign each peak to bin ``int((mz - MZ_MIN) / BIN_SIZE)`` (clipped to
       ``[0, N_BINS - 1]``), keeping the maximum when peaks share a bin;
    6. rescale so the largest bin is 1.

    Returns:
        ``float32`` array of length ``config.N_BINS``; all zeros if no peak
        survives the filters.
    """
    mzs = np.asarray(mz, dtype=np.float64)
    ints = np.asarray(intensity, dtype=np.float64)

    def _blank() -> np.ndarray:
        return np.zeros(config.N_BINS, dtype=np.float32)

    in_window = (mzs >= config.MZ_MIN) & (mzs < config.MZ_MAX) & (ints > 0)
    mzs, ints = mzs[in_window], ints[in_window]
    if ints.size == 0:
        return _blank()

    ints = ints / ints.max()
    strong_enough = ints >= config.RELATIVE_INTENSITY_THRESHOLD
    mzs, ints = mzs[strong_enough], ints[strong_enough]
    if ints.size == 0:
        return _blank()

    top_n = getattr(config, "TOP_N_PEAKS", None)
    if top_n and ints.size > top_n:
        most_intense = np.argsort(ints)[-top_n:]
        mzs, ints = mzs[most_intense], ints[most_intense]

    ints = np.sqrt(ints)

    slots = np.clip(((mzs - config.MZ_MIN) / config.BIN_SIZE).astype(int), 0, config.N_BINS - 1)
    vector = np.zeros(config.N_BINS, dtype=np.float32)
    # Unbuffered max-reduction so colliding peaks keep the strongest value.
    np.maximum.at(vector, slots, ints.astype(np.float32))

    top = vector.max()
    return vector / top if top > 0 else vector
51
+
52
+
53
def build_cond_vector(
    precursor_mz: float,
    charge: int,
    ion_mobility: float,
    config=Config,
) -> np.ndarray:
    """Build the conditioning vector for one precursor.

    Layout (length ``config.MAX_CHARGE + 2``, ``float32``):

    * one-hot charge, with the charge clamped to ``[1, MAX_CHARGE]``;
    * ``precursor_mz / PRECURSOR_MZ_MAX`` (not clipped);
    * ion mobility min-max scaled with ``IM_MIN``/``IM_MAX`` and clipped to
      ``[0, 1]``.
    """
    clamped_charge = max(1, min(int(charge), config.MAX_CHARGE))
    charge_slots = np.zeros(config.MAX_CHARGE, dtype=np.float32)
    charge_slots[clamped_charge - 1] = 1.0

    scaled_mz = float(precursor_mz) / config.PRECURSOR_MZ_MAX
    im_fraction = (ion_mobility - config.IM_MIN) / (config.IM_MAX - config.IM_MIN)
    scaled_im = float(np.clip(im_fraction, 0.0, 1.0))

    tail = [scaled_mz, scaled_im]
    return np.concatenate([charge_slots, tail]).astype(np.float32)
72
+
73
+
74
def preprocess(
    mz: np.ndarray,
    intensity: np.ndarray,
    precursor_mz: float,
    charge: int,
    ion_mobility: float,
    config=Config,
) -> tuple[np.ndarray, np.ndarray]:
    """Return ``(binned_spectrum, conditioning_vector)`` for one spectrum.

    Convenience wrapper that calls :func:`bin_spectrum_numpy` and then
    :func:`build_cond_vector` with the same *config*.
    """
    return (
        bin_spectrum_numpy(mz, intensity, config),
        build_cond_vector(precursor_mz, charge, ion_mobility, config),
    )
+ return binned, cond