eclipse-ms 0.1.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
eclipse_ms/__init__.py ADDED
@@ -0,0 +1,45 @@
1
+ """ECLIPSE: conditional spectrum autoencoder + clustering for MS/MS.
2
+
3
+ Importing this package does NOT import TensorFlow; TF is loaded lazily the
4
+ first time you build or load a model (e.g. ``load_encoder``) or call
5
+ ``embed_spectra``. This keeps imports fast and lets the clustering / consensus
6
+ utilities be used in TF-free environments.
7
+
8
+ The Keras model classes live in ``eclipse_ms.models`` (importing that submodule
9
+ does import TensorFlow).
10
+ """
11
+
12
+ from .config import COND_DIM, Config
13
+ from .preprocessing import bin_spectrum_numpy, build_cond_vector, preprocess
14
+ from .cluster import cluster_latents, score_clusters
15
+ from .consensus import generate_consensus_spectrum, write_mzml
16
+ from .modelhub import (
17
+ REGISTRY,
18
+ cache_dir,
19
+ get_model_file,
20
+ load_autoencoder,
21
+ load_encoder,
22
+ )
23
+ from .embed import embed_raw_spectra, embed_spectra
24
+
25
__version__ = "0.1.2"

# Public API of the package, grouped by the submodule each name comes from.
# Order is significant only for ``from eclipse_ms import *`` listing.
__all__ = [
    "__version__",
    # .config
    "Config",
    "COND_DIM",
    # .preprocessing
    "bin_spectrum_numpy",
    "build_cond_vector",
    "preprocess",
    # .modelhub
    "load_encoder",
    "load_autoencoder",
    "get_model_file",
    "cache_dir",
    "REGISTRY",
    # .embed
    "embed_spectra",
    "embed_raw_spectra",
    # .cluster
    "cluster_latents",
    "score_clusters",
    # .consensus
    "generate_consensus_spectrum",
    "write_mzml",
]
eclipse_ms/cli.py ADDED
@@ -0,0 +1,107 @@
1
+ """ECLIPSE command-line interface.
2
+
3
+ Subcommands:
4
+ embed Bin + encode spectra from parquet to a latents .npy
5
+ cluster Cluster a latents .npy into cluster labels
6
+ consensus Build consensus spectra (mzML) from clusters
7
+
8
+ Training and HPC data-prep live in the repo's ``training/`` scripts, not in the
9
+ installed package.
10
+ """
11
+
12
+ from __future__ import annotations
13
+
14
+ import argparse
15
+ import glob
16
+ import json
17
+ import os
18
+
19
+ import numpy as np
20
+
21
+
22
+ def _cmd_embed(args):
23
+ import pandas as pd
24
+
25
+ from .config import Config
26
+ from .embed import embed_raw_spectra
27
+ from .modelhub import load_encoder
28
+
29
+ encoder = load_encoder(weights=args.weights, config=args.config)
30
+
31
+ files = sorted(glob.glob(os.path.join(args.input, "*.parquet")))
32
+ print(f"Found {len(files)} parquet files")
33
+
34
+ mz, inten, pmz, charge, im = [], [], [], [], []
35
+ for fp in files:
36
+ df = pd.read_parquet(fp)
37
+ for _, row in df.iterrows():
38
+ mz.append(row["mz_array"])
39
+ inten.append(row["intensity_array"])
40
+ pmz.append(float(row.get("precursor_mz", 0.0)))
41
+ charge.append(int(row.get("precursor_charge", 2)))
42
+ im.append(float(row.get("ion_mobility", 0.0)))
43
+ if args.max_spectra and len(mz) >= args.max_spectra:
44
+ break
45
+ if args.max_spectra and len(mz) >= args.max_spectra:
46
+ break
47
+
48
+ latents = embed_raw_spectra(encoder, mz, inten, pmz, charge, im, Config, args.batch_size)
49
+ os.makedirs(os.path.dirname(args.output) or ".", exist_ok=True)
50
+ np.save(args.output, latents)
51
+ print(f"Saved {latents.shape} latents to {args.output}")
52
+
53
+
54
def _cmd_cluster(args):
    """Cluster a latents ``.npy`` file and write the results to a directory.

    Writes ``cluster_labels.npy``, ``cluster_info.json`` and
    ``cluster_scores.csv`` into ``args.output`` (created if needed), then prints
    a short summary.
    """
    from .cluster import cluster_latents, score_clusters

    out_dir = args.output
    latents = np.load(args.input)
    labels, info = cluster_latents(
        latents,
        method=args.method,
        min_cluster_size=args.min_cluster_size,
    )

    os.makedirs(out_dir, exist_ok=True)
    labels_path = os.path.join(out_dir, "cluster_labels.npy")
    info_path = os.path.join(out_dir, "cluster_info.json")
    scores_path = os.path.join(out_dir, "cluster_scores.csv")

    np.save(labels_path, labels)
    with open(info_path, "w") as fh:
        json.dump(info, fh, indent=2)
    score_clusters(latents, labels).to_csv(scores_path, index=False)

    n_clusters = info.get("n_clusters")
    n_noise = info.get("n_noise", 0)
    print(f"Clusters: {n_clusters}, noise: {n_noise}")
    print(f"Wrote labels, info, and scores to {out_dir}")
69
+
70
+
71
+ def _cmd_consensus(args):
72
+ print(
73
+ "Consensus generation needs spectra grouped by cluster. See "
74
+ "eclipse_ms.consensus.generate_consensus_spectrum and the example in "
75
+ "training/consensus_reference.py for the full pipeline."
76
+ )
77
+
78
+
79
def main(argv=None):
    """Parse ``argv`` and dispatch to the selected subcommand handler.

    ``argv=None`` lets argparse read the process arguments. Each subparser
    stores its handler under ``func``; the handler receives the parsed
    namespace.
    """
    parser = argparse.ArgumentParser(prog="eclipse", description="ECLIPSE")
    commands = parser.add_subparsers(dest="command", required=True)

    # eclipse embed
    embed_cmd = commands.add_parser("embed", help="Encode spectra to latents")
    embed_cmd.add_argument("-i", "--input", required=True, help="Parquet directory")
    embed_cmd.add_argument("-o", "--output", required=True, help="Output latents .npy")
    embed_cmd.add_argument("--weights", default=None, help="Local encoder weights (.h5)")
    embed_cmd.add_argument("--config", default=None, help="Local encoder config (.json)")
    embed_cmd.add_argument("--batch-size", type=int, default=256)
    embed_cmd.add_argument("--max-spectra", type=int, default=None)
    embed_cmd.set_defaults(func=_cmd_embed)

    # eclipse cluster
    cluster_cmd = commands.add_parser("cluster", help="Cluster latents")
    cluster_cmd.add_argument("-i", "--input", required=True, help="latents .npy")
    cluster_cmd.add_argument("-o", "--output", required=True, help="Output directory")
    cluster_cmd.add_argument("--method", choices=["hdbscan", "kmeans"], default="hdbscan")
    cluster_cmd.add_argument("--min-cluster-size", type=int, default=5)
    cluster_cmd.set_defaults(func=_cmd_cluster)

    # eclipse consensus
    consensus_cmd = commands.add_parser("consensus", help="Consensus spectra from clusters")
    consensus_cmd.set_defaults(func=_cmd_consensus)

    parsed = parser.parse_args(argv)
    handler = parsed.func
    handler(parsed)


if __name__ == "__main__":
    main()
eclipse_ms/cluster.py ADDED
@@ -0,0 +1,105 @@
1
+ """Cluster latent vectors and score the resulting clusters."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import time
6
+ from typing import Tuple
7
+
8
+ import numpy as np
9
+
10
+
11
def cluster_latents(
    latents: np.ndarray,
    method: str = "hdbscan",
    min_cluster_size: int = 5,
    pca_dims: int = 100,
    random_state: int = 42,
    min_samples: int = 3,
) -> Tuple[np.ndarray, dict]:
    """Cluster latent vectors with HDBSCAN (preferred) or MiniBatchKMeans.

    When ``n_dims > pca_dims`` the latents are first reduced with PCA to
    ``min(pca_dims, n_samples)`` components (PCA cannot produce more components
    than samples). If HDBSCAN is requested but not installed, a message is
    printed and MiniBatchKMeans is used instead; ``info["method"]`` records the
    method that actually ran.

    Args:
        latents: array ``[n_samples, n_dims]``.
        method: ``"hdbscan"`` or ``"kmeans"``.
        min_cluster_size: HDBSCAN ``min_cluster_size`` (ignored by KMeans).
        pca_dims: PCA target dimensionality; PCA is skipped when ``n_dims`` is
            not larger than this.
        random_state: seed for PCA and MiniBatchKMeans.
        min_samples: HDBSCAN ``min_samples`` (ignored by KMeans).

    Returns:
        ``(labels, info)``. ``labels`` holds one integer per sample, ``-1`` for
        HDBSCAN noise. ``info`` holds ``method``, ``n_samples``, ``time``
        (seconds), ``n_clusters``, ``n_noise`` and, when PCA ran,
        ``pca_variance_explained``.

    Raises:
        ValueError: if ``method`` is not ``"hdbscan"`` or ``"kmeans"``.
    """
    # Validate before doing any (possibly expensive) PCA work.
    if method not in ("hdbscan", "kmeans"):
        raise ValueError(f"Unknown method: {method}")

    latents = np.asarray(latents)
    n_samples, n_dims = latents.shape

    if method == "hdbscan":
        try:
            import hdbscan
        except ImportError:
            print("hdbscan not installed (`pip install eclipse-ms[hdbscan]`); using KMeans.")
            method = "kmeans"

    info = {"method": method, "n_samples": int(n_samples)}

    if n_samples == 0:
        info.update(time=0.0, n_clusters=0, n_noise=0)
        return np.empty(0, dtype=np.intp), info

    if n_dims > pca_dims:
        from sklearn.decomposition import PCA

        pca = PCA(n_components=min(pca_dims, n_samples), random_state=random_state)
        reduced = pca.fit_transform(latents)
        info["pca_variance_explained"] = float(pca.explained_variance_ratio_.sum())
    else:
        reduced = latents

    if method == "hdbscan":
        start = time.time()
        clusterer = hdbscan.HDBSCAN(
            min_cluster_size=min_cluster_size,
            min_samples=min_samples,
            metric="euclidean",
            cluster_selection_method="eom",
            core_dist_n_jobs=-1,
        )
        labels = clusterer.fit_predict(reduced)
        info["time"] = time.time() - start
        info["n_clusters"] = int(np.count_nonzero(np.unique(labels) != -1))
        info["n_noise"] = int(np.count_nonzero(labels == -1))
    else:
        from sklearn.cluster import MiniBatchKMeans

        # About one cluster per 10 spectra, clamped to [2, 10000], and never
        # more clusters than samples (KMeans requires n_samples >= n_clusters).
        n_clusters = min(max(2, min(10000, n_samples // 10)), n_samples)
        start = time.time()
        kmeans = MiniBatchKMeans(
            n_clusters=n_clusters, batch_size=1024, random_state=random_state, n_init=3
        )
        labels = kmeans.fit_predict(reduced)
        info["time"] = time.time() - start
        info["n_clusters"] = int(np.unique(labels).size)
        info["n_noise"] = 0

    return labels, info
76
+
77
+
78
def score_clusters(latents: np.ndarray, labels: np.ndarray) -> "pd.DataFrame":  # noqa: F821
    """Lightweight per-cluster quality scores from latent geometry.

    For each non-noise cluster (label != -1), reports its size and
    intra-cluster cohesion: the mean and standard deviation of the Euclidean
    distance from each member to the cluster centroid (smaller = tighter).

    Args:
        latents: array ``[n, dims]``.
        labels: integer label per row of ``latents``; ``-1`` marks noise.

    Returns:
        DataFrame with columns ``cluster``, ``size``, ``cohesion_mean_dist``,
        ``cohesion_std_dist``, sorted by ``size`` descending. Ties keep
        ascending cluster-id order. The columns are present even when there
        are no clusters.

    Raises:
        ValueError: if ``latents`` and ``labels`` differ in length.
    """
    import pandas as pd

    columns = ["cluster", "size", "cohesion_mean_dist", "cohesion_std_dist"]
    latents = np.asarray(latents)
    labels = np.asarray(labels)
    if len(latents) != len(labels):
        raise ValueError(
            f"latents and labels differ in length: {len(latents)} vs {len(labels)}"
        )

    # Group members with a single stable sort instead of scanning every label
    # per cluster (O(n log n) rather than O(n * n_clusters)). Stability keeps
    # each cluster's members in original index order, as np.where would.
    order = np.argsort(labels, kind="stable")
    ids, starts, sizes = np.unique(labels[order], return_index=True, return_counts=True)

    rows = []
    for cid, begin, size in zip(ids.tolist(), starts.tolist(), sizes.tolist()):
        if cid == -1:
            continue
        pts = latents[order[begin:begin + size]]
        dists = np.linalg.norm(pts - pts.mean(axis=0), axis=1)
        rows.append(
            {
                "cluster": int(cid),
                "size": int(size),
                "cohesion_mean_dist": float(dists.mean()),
                "cohesion_std_dist": float(dists.std()),
            }
        )

    df = pd.DataFrame(rows, columns=columns)
    if not df.empty:
        df = df.sort_values("size", ascending=False, kind="stable").reset_index(drop=True)
    return df
eclipse_ms/config.py ADDED
@@ -0,0 +1,53 @@
1
+ """Centralised configuration for ECLIPSE.
2
+
3
+ These values define the spectrum binning and the autoencoder architecture. They
4
+ MUST match the configuration used to train the published weights; the binning
5
+ parameters in particular are baked into the model input.
6
+ """
7
+
8
+
9
class Config:
    """Centralized configuration.

    Values are class attributes read straight off the class: callers pass
    ``Config`` itself (not an instance), e.g. as the ``config`` argument of
    ``embed_raw_spectra``.
    """

    SEED = 42

    # Binning parameters (autoencoder input)
    MZ_MIN = 100.0
    MZ_MAX = 1700.0
    BIN_SIZE = 0.5
    # round(), not int(): a float quotient can land just below the intended
    # integer (e.g. 0.3 / 0.1 == 2.9999999999999996), and int() would silently
    # drop a bin. Gives the same 3200 for the values above.
    N_BINS = round((MZ_MAX - MZ_MIN) / BIN_SIZE)  # 3200

    # Preprocessing
    RELATIVE_INTENSITY_THRESHOLD = 0.01
    TOP_N_PEAKS = 100  # keep only the top N most intense peaks (None to disable)

    # Ion mobility normalisation (1/K0 range for peptides)
    IM_MIN = 0.6
    IM_MAX = 1.6

    # Precursor features
    MAX_CHARGE = 6
    PRECURSOR_MZ_MAX = 1700.0

    # Autoencoder architecture
    LATENT_DIM = 256
    AE_EMBED_DIM = 256
    AE_NUM_HEADS = 8
    AE_NUM_LAYERS = 4
    AE_FF_DIM = 512
    # Presumably used as PatchEmbedding's patch_size; that layer reshapes its
    # input into patches of this length, so it should divide N_BINS
    # (3200 / 16 = 200) — confirm against the model builder.
    AE_PATCH_SIZE = 16
    AE_DROPOUT = 0.1

    # Conditioning embedding
    COND_EMBED_DIM = 256

    # Training - Autoencoder
    AE_BATCH_SIZE = 256
    AE_INITIAL_LR = 1e-4
    AE_EPOCHS = 50
    AE_WARMUP_EPOCHS = 5
    WEIGHT_DECAY = 1e-5


# Conditioning dimension: one-hot charge (MAX_CHARGE) + precursor_mz (1) + IM (1).
COND_DIM = Config.MAX_CHARGE + 2
@@ -0,0 +1,148 @@
1
+ """Consensus (averaged) spectrum generation for clusters.
2
+
3
+ The consensus algorithm is ported from the project's consensus script; the full
4
+ original (with the richer mzML writer and FragPipe-specific helpers) is kept in
5
+ ``training/consensus_reference.py`` for reference.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ from typing import Dict, List, Optional
11
+
12
+ import numpy as np
13
+
14
+
15
def generate_consensus_spectrum(
    spectra: List[Dict],
    mz_tolerance: float = 0.01,
    min_peak_occurrence: float = 0.3,
) -> Optional[Dict]:
    """Average multiple spectra into a single consensus spectrum.

    Algorithm:
    1. Each spectrum's intensities are scaled so its base peak is 1.
    2. All peaks are pooled and sorted by m/z, then grouped greedily: a peak
       joins the current group while it lies within ``mz_tolerance`` (Da) of
       the group's first (lowest m/z) peak.
    3. A group is kept if peaks from at least ``min_peak_occurrence`` of the
       input spectra fall into it. Each spectrum counts once per group, so
       several nearby peaks from one spectrum cannot pass the filter alone.
    4. Kept groups contribute their mean m/z and mean normalised intensity;
       the result is rescaled so its base peak is 10000.

    Args:
        spectra: dicts with ``mz_array``, ``intensity_array`` and
            ``precursor_mz``. ``charge`` and ``ion_mobility`` are optional;
            missing, ``None``, NaN or zero charges and missing, ``None``, NaN or
            non-positive ion mobilities are ignored when averaging.
        mz_tolerance: grouping tolerance in Da.
        min_peak_occurrence: minimum fraction of spectra a group must occur in.

    Returns:
        A dict with ``mz_array``, ``intensity_array``, ``precursor_mz``,
        ``charge`` (median, default 2), ``ion_mobility`` (mean, default 0.0),
        ``n_spectra`` and ``n_peaks``. Returns ``None`` if ``spectra`` is
        empty, contains no peaks, or no group passes the filter.
    """
    if not spectra:
        return None
    n_spectra = len(spectra)

    mz_parts, inten_parts, owner_parts = [], [], []
    for spec_idx, spec in enumerate(spectra):
        mz = np.asarray(spec["mz_array"], dtype=float)
        intensity = np.asarray(spec["intensity_array"], dtype=float)
        if len(intensity) > 0 and intensity.max() > 0:
            intensity = intensity / intensity.max()
        n_pairs = min(len(mz), len(intensity))  # pair peaks the way zip() does
        mz_parts.append(mz[:n_pairs])
        inten_parts.append(intensity[:n_pairs])
        owner_parts.append(np.full(n_pairs, spec_idx))

    all_mz = np.concatenate(mz_parts)
    if all_mz.size == 0:
        return None
    order = np.argsort(all_mz, kind="stable")
    all_mz = all_mz[order]
    all_inten = np.concatenate(inten_parts)[order]
    owners = np.concatenate(owner_parts)[order].tolist()
    mz_values = all_mz.tolist()

    consensus_mz, consensus_intensity = [], []
    n_total = len(mz_values)
    first = 0
    for stop in range(1, n_total + 1):
        if stop < n_total and mz_values[stop] - mz_values[first] <= mz_tolerance:
            continue  # still within tolerance of the current group's first peak
        # Peaks [first, stop) form one group; count distinct source spectra.
        if len(set(owners[first:stop])) / n_spectra >= min_peak_occurrence:
            consensus_mz.append(all_mz[first:stop].mean())
            consensus_intensity.append(all_inten[first:stop].mean())
        first = stop

    if not consensus_mz:
        return None

    # Groups are disjoint and visited in ascending m/z, so their means are
    # already sorted; no re-sort is needed.
    consensus_mz = np.array(consensus_mz)
    consensus_intensity = np.array(consensus_intensity)
    base_peak = consensus_intensity.max()
    if base_peak > 0:
        consensus_intensity = consensus_intensity / base_peak * 10000

    precursor_mz = float(np.mean([s["precursor_mz"] for s in spectra]))
    # ``or 0`` maps None to 0 (None > 0 would raise); NaN > 0 is False.
    ims = [s["ion_mobility"] for s in spectra if (s.get("ion_mobility") or 0) > 0]
    ion_mobility = float(np.mean(ims)) if ims else 0.0
    charges = [
        s["charge"] for s in spectra if s.get("charge") and not np.isnan(s["charge"])
    ]
    charge = int(np.median(charges)) if charges else 2

    return {
        "mz_array": consensus_mz,
        "intensity_array": consensus_intensity,
        "precursor_mz": precursor_mz,
        "charge": charge,
        "ion_mobility": ion_mobility,
        "n_spectra": n_spectra,
        "n_peaks": len(consensus_mz),
    }
91
+
92
+
93
def write_mzml(consensus_spectra: List[Dict], output_path: str) -> None:
    """Write consensus spectra to a minimal mzML file.

    What it writes:
    - The PSI-MS controlled vocabulary (``cvList``) used by every ``cvRef``.
    - Each entry as an MS2 centroid spectrum with a single precursor (selected
      ion m/z and charge).
    - Each binary array as uncompressed, little-endian base64: m/z as 64-bit
      floats, intensity as 32-bit floats.

    What it omits:
    - Several elements the mzML schema requires for full validity, e.g.
      ``fileDescription``, ``softwareList``, ``instrumentConfigurationList``
      and ``dataProcessingList``.
    - Ion mobility.

    For the full-featured writer (instrument metadata, etc.), see
    ``training/consensus_reference.py``.

    Args:
        consensus_spectra: dicts with ``mz_array``, ``intensity_array``,
            ``precursor_mz`` and ``charge`` (e.g. from
            :func:`generate_consensus_spectrum`).
        output_path: destination file; overwritten if it exists.
    """
    import base64
    from datetime import datetime, timezone

    def _b64(values):
        return base64.b64encode(values.tobytes()).decode("ascii")

    def _array_xml(encoded, type_acc, type_name, kind_acc, kind_name):
        # One <binaryDataArray>: precision, compression and array-kind cvParams,
        # followed by the base64 payload.
        return (
            f'        <binaryDataArray encodedLength="{len(encoded)}">'
            f'<cvParam cvRef="MS" accession="{type_acc}" name="{type_name}"/>'
            '<cvParam cvRef="MS" accession="MS:1000576" name="no compression"/>'
            f'<cvParam cvRef="MS" accession="{kind_acc}" name="{kind_name}"/>'
            f"<binary>{encoded}</binary></binaryDataArray>"
        )

    lines = [
        '<?xml version="1.0" encoding="utf-8"?>',
        '<mzML xmlns="http://psi.hupo.org/ms/mzml" version="1.1.0">',
        '  <cvList count="1">',
        '    <cv id="MS" fullName="Proteomics Standards Initiative Mass Spectrometry Ontology" '
        'URI="https://raw.githubusercontent.com/HUPO-PSI/psi-ms-CV/master/psi-ms.obo"/>',
        "  </cvList>",
        '  <run id="eclipse_consensus" '
        f'startTimeStamp="{datetime.now(timezone.utc).isoformat()}">',
        f'    <spectrumList count="{len(consensus_spectra)}">',
    ]
    for idx, s in enumerate(consensus_spectra):
        # mzML binary arrays are little-endian; make the byte order explicit.
        mz = np.asarray(s["mz_array"], dtype="<f8")
        inten = np.asarray(s["intensity_array"], dtype="<f4")
        lines += [
            f'      <spectrum index="{idx}" id="scan={idx + 1}" '
            f'defaultArrayLength="{len(mz)}">',
            '        <cvParam cvRef="MS" accession="MS:1000580" name="MSn spectrum"/>',
            '        <cvParam cvRef="MS" accession="MS:1000511" name="ms level" value="2"/>',
            '        <cvParam cvRef="MS" accession="MS:1000127" name="centroid spectrum"/>',
            '        <precursorList count="1"><precursor><selectedIonList count="1">'
            "<selectedIon>",
            '          <cvParam cvRef="MS" accession="MS:1000744" '
            f'name="selected ion m/z" value="{s["precursor_mz"]}"/>',
            '          <cvParam cvRef="MS" accession="MS:1000041" '
            f'name="charge state" value="{s["charge"]}"/>',
            "        </selectedIon></selectedIonList></precursor></precursorList>",
            '        <binaryDataArrayList count="2">',
            _array_xml(_b64(mz), "MS:1000523", "64-bit float", "MS:1000514", "m/z array"),
            _array_xml(_b64(inten), "MS:1000521", "32-bit float", "MS:1000515", "intensity array"),
            "        </binaryDataArrayList>",
            "      </spectrum>",
        ]
    lines += ["    </spectrumList>", "  </run>", "</mzML>"]
    with open(output_path, "w", encoding="utf-8") as f:
        f.write("\n".join(lines) + "\n")
    print(f"Wrote {len(consensus_spectra)} consensus spectra to {output_path}")
eclipse_ms/embed.py ADDED
@@ -0,0 +1,96 @@
1
+ """Embed spectra into the autoencoder latent space."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import time
6
+ from typing import Optional
7
+
8
+ import numpy as np
9
+
10
+ from .config import Config
11
+ from .preprocessing import bin_spectrum_numpy, build_cond_vector
12
+
13
+
14
def embed_spectra(
    encoder,
    spectra: np.ndarray,
    conditioning: np.ndarray,
    batch_size: int = 256,
    latent_dim: Optional[int] = None,
    verbose: bool = True,
) -> np.ndarray:
    """Encode pre-binned spectra to latent vectors in batches.

    Args:
        encoder: either a loaded encoder (see
            :func:`eclipse_ms.modelhub.load_encoder`) or a full autoencoder.
            An encoder is called as ``encoder((x, cond), training=False)``.
            An autoencoder must expose ``.encode(x, cond, training=False)``;
            if that returns a tuple/list, its first element is used.
        spectra: array ``[n, N_BINS]`` of binned spectra (cast to float32).
        conditioning: array ``[n, cond_dim]`` of conditioning vectors (cast to
            float32).
        batch_size: encode this many spectra at a time; must be positive.
        latent_dim: output dimension; inferred from the first batch if None.
        verbose: print progress every 100 batches and the total time.

    Returns:
        float32 array ``[n, latent_dim]`` of latent vectors. For ``n == 0``
        the encoder is not called and TensorFlow is not imported; the result
        is ``(0, latent_dim)``, or ``(0, 0)`` when ``latent_dim`` is None.

    Raises:
        ValueError: if ``spectra`` and ``conditioning`` differ in length, or
            ``batch_size`` is not positive.
    """
    n = len(spectra)
    if n != len(conditioning):
        raise ValueError(f"spectra and conditioning differ in length: {n} vs {len(conditioning)}")
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")
    if n == 0:
        # Nothing to encode; avoid pushing an empty batch through the model.
        return np.zeros((0, 0 if latent_dim is None else latent_dim), dtype=np.float32)

    import tensorflow as tf

    spectra = np.asarray(spectra, dtype=np.float32)
    conditioning = np.asarray(conditioning, dtype=np.float32)

    def _encode(lo, hi):
        x = tf.convert_to_tensor(spectra[lo:hi])
        cond = tf.convert_to_tensor(conditioning[lo:hi])
        # Support both an Encoder (callable) and a full AE (has .encode).
        if hasattr(encoder, "encode"):
            z = encoder.encode(x, cond, training=False)
            if isinstance(z, (tuple, list)):
                z = z[0]
        else:
            z = encoder((x, cond), training=False)
        return np.asarray(z)

    latents = None
    start = time.time()
    for i in range(0, n, batch_size):
        end = min(i + batch_size, n)
        z = _encode(i, end)
        if latents is None:
            # Size the output from the first real batch instead of running a
            # separate one-spectrum probe through the encoder.
            if latent_dim is None:
                latent_dim = int(z.shape[-1])
            latents = np.zeros((n, latent_dim), dtype=np.float32)
        latents[i:end] = z
        if verbose and (i // batch_size + 1) % 100 == 0:
            rate = end / (time.time() - start)
            print(f" encoded {end:,}/{n:,} ({rate:.0f}/s)")
    if verbose:
        print(f" done in {(time.time() - start) / 60:.1f} min")
    return latents
72
+
73
+
74
def embed_raw_spectra(
    encoder,
    mz_list,
    intensity_list,
    precursor_mz,
    charge,
    ion_mobility,
    config=Config,
    batch_size: int = 256,
    verbose: bool = True,
    latent_dim: Optional[int] = None,
) -> np.ndarray:
    """Bin + condition raw peak lists, then embed.

    Each spectrum is binned with ``bin_spectrum_numpy`` into ``config.N_BINS``
    bins. Its conditioning vector (width ``config.MAX_CHARGE + 2``) is built
    with ``build_cond_vector``. The result is then encoded with
    :func:`embed_spectra`.

    Args:
        encoder: model accepted by :func:`embed_spectra`.
        mz_list, intensity_list: per-spectrum peak arrays.
        precursor_mz, charge, ion_mobility: per-spectrum scalars; each must be
            a sequence with the same length as ``mz_list``.
        config: configuration class (default :class:`Config`).
        batch_size: forwarded to :func:`embed_spectra`.
        verbose: forwarded to :func:`embed_spectra`.
        latent_dim: forwarded to :func:`embed_spectra` (inferred if None).

    Returns:
        float32 array ``[n, latent_dim]``.

    Raises:
        ValueError: if any input sequence's length differs from ``mz_list``.
    """
    n = len(mz_list)
    lengths = {
        "intensity_list": len(intensity_list),
        "precursor_mz": len(precursor_mz),
        "charge": len(charge),
        "ion_mobility": len(ion_mobility),
    }
    mismatched = {name: size for name, size in lengths.items() if size != n}
    if mismatched:
        raise ValueError(f"inputs must align with mz_list (length {n}); got {mismatched}")

    spectra = np.zeros((n, config.N_BINS), dtype=np.float32)
    cond = np.zeros((n, config.MAX_CHARGE + 2), dtype=np.float32)
    rows = zip(mz_list, intensity_list, precursor_mz, charge, ion_mobility)
    for i, (mz, inten, pmz, z, im) in enumerate(rows):
        spectra[i] = bin_spectrum_numpy(mz, inten, config)
        cond[i] = build_cond_vector(pmz, z, im, config)
    return embed_spectra(
        encoder,
        spectra,
        cond,
        batch_size=batch_size,
        latent_dim=latent_dim,
        verbose=verbose,
    )
eclipse_ms/layers.py ADDED
@@ -0,0 +1,84 @@
1
+ """Keras layer building blocks for the ECLIPSE autoencoder.
2
+
3
+ Ported verbatim from the training code so the registered serialisable layers
4
+ reconstruct with exactly the same weight structure when loading published
5
+ weights.
6
+ """
7
+
8
+ import tensorflow as tf
9
+ from tensorflow import keras
10
+ from tensorflow.keras import layers
11
+
12
+
13
@keras.utils.register_keras_serializable()
class PatchEmbedding(layers.Layer):
    """Split a 1D spectrum into fixed-size patches and project each patch.

    The input is reshaped to ``[batch, -1, patch_size]`` and passed through a
    single ``Dense(embed_dim)``, giving ``[batch, n_patches, embed_dim]``.
    Assumes ``x`` is ``[batch, n_bins]`` with ``n_bins`` a multiple of
    ``patch_size`` (otherwise the reshape fails) — TODO confirm at call sites.
    """

    def __init__(self, embed_dim: int = 256, patch_size: int = 16, **kwargs):
        super().__init__(**kwargs)
        self.embed_dim = embed_dim
        self.patch_size = patch_size
        # Keep the sublayer name unchanged so it matches the weight structure
        # the module docstring requires for loading published weights.
        self.projection = layers.Dense(embed_dim)

    def call(self, x):
        n_rows = tf.shape(x)[0]
        patches = tf.reshape(x, [n_rows, -1, self.patch_size])
        return self.projection(patches)

    def get_config(self):
        base = super().get_config()
        return {
            **base,
            "embed_dim": self.embed_dim,
            "patch_size": self.patch_size,
        }
32
+
33
+
34
@keras.utils.register_keras_serializable()
class TransformerBlock(layers.Layer):
    """Pre-norm transformer block.

    Computes ``x + MHA(LN1(x))`` followed by ``x + FFN(LN2(x))``.
    Sublayers are created in the same order and under the same attribute
    names as the training code. Keep it that way so the weight structure still
    matches the published weights.
    """

    def __init__(self, embed_dim=256, num_heads=8, ff_dim=512, dropout=0.1, **kwargs):
        super().__init__(**kwargs)
        # Constructor arguments, stored for get_config() round-tripping.
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.ff_dim = ff_dim
        self.dropout_rate = dropout

        # Per-head key size (integer division of embed_dim across the heads).
        head_dim = embed_dim // num_heads
        self.att = layers.MultiHeadAttention(
            num_heads=num_heads,
            key_dim=head_dim,
            dropout=dropout,
        )

        # Position-wise feed-forward: expand to ff_dim with GELU, project back.
        ffn_stack = [
            layers.Dense(ff_dim, activation="gelu"),
            layers.Dropout(dropout),
            layers.Dense(embed_dim),
            layers.Dropout(dropout),
        ]
        self.ffn = keras.Sequential(ffn_stack)

        self.norm1 = layers.LayerNormalization(epsilon=1e-6)
        self.norm2 = layers.LayerNormalization(epsilon=1e-6)

    def call(self, x, training=False):
        # Self-attention sub-block: normalise, attend, add residual.
        h = self.norm1(x)
        x = x + self.att(h, h, training=training)
        # Feed-forward sub-block: normalise, transform, add residual.
        h = self.norm2(x)
        return x + self.ffn(h, training=training)

    def get_config(self):
        base = super().get_config()
        return {
            **base,
            "embed_dim": self.embed_dim,
            "num_heads": self.num_heads,
            "ff_dim": self.ff_dim,
            "dropout": self.dropout_rate,
        }