cystainer 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cystainer/__init__.py +16 -0
- cystainer/data.py +204 -0
- cystainer/modules.py +277 -0
- cystainer/runner.py +318 -0
- cystainer-0.1.0.dist-info/METADATA +170 -0
- cystainer-0.1.0.dist-info/RECORD +9 -0
- cystainer-0.1.0.dist-info/WHEEL +5 -0
- cystainer-0.1.0.dist-info/licenses/LICENSE +21 -0
- cystainer-0.1.0.dist-info/top_level.txt +1 -0
cystainer/__init__.py
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
# cystainer/__init__.py
|
|
2
|
+
|
|
3
|
+
# Package version
|
|
4
|
+
__version__ = "0.1.0"
|
|
5
|
+
|
|
6
|
+
# Imports
|
|
7
|
+
from .runner import CyStainer
|
|
8
|
+
from .data import load_data_from_folder
|
|
9
|
+
from .modules import CyStainerModel
|
|
10
|
+
|
|
11
|
+
# Imports `from cystainer import *`
|
|
12
|
+
__all__ = [
|
|
13
|
+
"CyStainerModel",
|
|
14
|
+
"CyStainer",
|
|
15
|
+
"load_data_from_folder",
|
|
16
|
+
]
|
cystainer/data.py
ADDED
|
@@ -0,0 +1,204 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import itertools
|
|
3
|
+
import torch
|
|
4
|
+
import numpy as np
|
|
5
|
+
import pandas as pd
|
|
6
|
+
import anndata as ad
|
|
7
|
+
import matplotlib.pyplot as plt
|
|
8
|
+
import matplotlib.patches as patches
|
|
9
|
+
from torch.utils.data import DataLoader
|
|
10
|
+
from .modules import get_marker_list, pad_and_get_pad_mask
|
|
11
|
+
|
|
12
|
+
def load_data_from_folder(folder_path=None, adata_list=None, exclude_files=(), max_n_cells=int(1e+6), batch_size=1024, is_train=True, drop_markers=None,
                          reference_markers=None, reference_batch_dict=None, normalize=False, batch_info=None, get_panel_vis=False):
    """Reads .h5ad files from a folder OR accepts a list of AnnData objects directly.

    Args:
        folder_path (str): Folder scanned for files ending in '.h5ad'. Only used
            when 'adata_list' is None.
        adata_list (list): AnnData objects to use instead of reading from disk.
        exclude_files (collection): File names inside 'folder_path' to skip.
        max_n_cells (int): Only the first 'max_n_cells' rows of every sample are kept.
        batch_size (int): Batch size of the returned DataLoader.
        is_train (bool): Passed as 'shuffle' to the DataLoader.
        drop_markers (collection): Column names removed from every sample.
        reference_markers (dict): Dict with the keys 'marker_list', 'shared_markers'
            and 'unique_markers'. When given, it is used instead of deriving the
            marker layout from the loaded samples.
        reference_batch_dict (dict): Existing batch-name -> index mapping. Unseen
            batch names are appended after its current maximum index.
        normalize (bool): If True, every sample is z-scored per column
            (a constant column yields NaN because its std is 0).
        batch_info (str): Column of 'adata.obs' holding the batch label of each cell.
            When None every cell is labelled 'single_batch'.
        get_panel_vis (bool): If True, plot the panel overlap before any marker is dropped.

    Returns:
        tuple: (dataloader, markers_info, batch_dict). Every dataset item is a dict
        with 'x' (float32 marker values), 'batch' (batch index) and 'pad'
        (1.0 where the marker exists in the sample, 0.0 where it was padded).

    Raises:
        ValueError: If neither input source is given, or the folder holds no .h5ad file.
    """

    # Establish the adata_list
    if adata_list is None:
        if folder_path is None:
            raise ValueError("You must provide either 'folder_path' or 'adata_list'.")

        # Samples are loaded in the order os.listdir reports them
        sample_names = [f for f in os.listdir(folder_path) if f.endswith('.h5ad') and (f not in exclude_files)]
        if not sample_names:
            raise ValueError(f"No files found in {folder_path}")

        # Load adatas from disk
        loaded_adatas = [ad.read_h5ad(os.path.join(folder_path, f)) for f in sample_names]
    else:
        # Use the provided list of AnnData objects
        loaded_adatas = adata_list

    # Extract DataFrames and Batch Info
    df_list = [adata.to_df().iloc[:max_n_cells, :] for adata in loaded_adatas]

    if batch_info is not None:
        # Truncate the labels exactly like the expression matrix above; otherwise the
        # flattened label list is longer than the cell list and every sample after a
        # truncated one receives the wrong batch ids.
        batch_list = [adata.obs[batch_info].iloc[:max_n_cells].tolist() for adata in loaded_adatas]
        batch_names = list(itertools.chain.from_iterable(batch_list))
    else:
        print("No batch info passed, setting 'single_batch' for the data.")
        batch_names = list(itertools.chain.from_iterable(['single_batch'] * df.shape[0] for df in df_list))

    # Visualization, Processing, and Normalization
    if get_panel_vis:
        visualize_panel_overlap(df_list)

    if drop_markers is not None:
        df_list = [df.loc[:, ~df.columns.isin(drop_markers)] for df in df_list]

    if normalize:
        df_list = [(df - df.mean()) / df.std() for df in df_list]

    # Handle Batches
    if reference_batch_dict is None:
        # Sort for deterministic behavior
        batch_dict = {batch: i for i, batch in enumerate(sorted(set(batch_names)))}
    else:
        # Work on a copy so the caller's reference mapping is never modified
        batch_dict = reference_batch_dict.copy()
        current_max_idx = max(batch_dict.values()) if batch_dict else -1

        # Find only the new batches and sort them
        new_batches = sorted(set(batch_names) - set(batch_dict.keys()))

        for batch in new_batches:
            current_max_idx += 1
            batch_dict[batch] = current_max_idx

    batch_ids = np.array([batch_dict[name] for name in batch_names], dtype=np.int64)

    # Handle Markers and Padding
    if reference_markers is None:
        marker_list, shared_markers, unique_markers = get_marker_list(df_list)
    else:
        marker_list = reference_markers['marker_list']
        shared_markers = reference_markers['shared_markers']
        unique_markers = reference_markers['unique_markers']

    df_list, pad_mask = pad_and_get_pad_mask(df_list, marker_list)

    # PyTorch Dataloader: one dict per cell
    data = [{'x': x.astype('float32'), 'batch': b, 'pad': p}
            for x, b, p in zip(pd.concat(df_list).values, batch_ids, np.concatenate(pad_mask, axis=0))]

    dataloader = DataLoader(data, batch_size=batch_size, shuffle=is_train)

    markers_info = {
        'marker_list': marker_list,
        'shared_markers': shared_markers,
        'unique_markers': unique_markers
    }

    return dataloader, markers_info, batch_dict
|
|
91
|
+
|
|
92
|
+
def visualize_panel_overlap(df_list, alignment='left'):
    """
    Creates a block alignment plot showing shared and unique markers
    across different panels.

    DataFrames with the same set of columns count as one panel. Panels are
    named 'Panel 1', 'Panel 2', ... in the order their column set first
    appears in df_list.

    Args:
        df_list (list): List of pandas DataFrames.
        alignment (str): 'center' for a center-outwards pyramid style,
                         'left' for a cascading left-aligned style.

    Raises:
        ValueError: If df_list is empty or 'alignment' is not 'left' or 'center'.
    """
    if not df_list:
        raise ValueError("df_list must contain at least one DataFrame.")

    # Extract markers and build presence matrix.
    # dict.fromkeys de-duplicates like a set but keeps first-seen order, so the
    # panel numbering is the same on every run.
    panel_markers = list(dict.fromkeys(frozenset(df.columns) for df in df_list))
    # Sorted so that ties in the alignment sort below are broken reproducibly
    all_markers = sorted(set().union(*panel_markers))
    num_panels = len(panel_markers)
    panel_names = [f'Panel {i+1}' for i in range(num_panels)]

    matrix = pd.DataFrame(index=all_markers, columns=panel_names)
    for name, markers in zip(panel_names, panel_markers):
        matrix[name] = matrix.index.isin(markers).astype(int)

    # Sort markers based on requested alignment
    if alignment == 'left':
        # Simple cascade sort
        matrix = matrix.sort_values(by=panel_names, ascending=False)
        sorted_markers = matrix.index.tolist()

    elif alignment == 'center':
        # Center-outwards sort (Frequency and Center of Mass)
        matrix['freq'] = matrix.sum(axis=1)
        col_indices = np.arange(num_panels)

        # Mean panel index of the panels containing the marker
        matrix['com'] = (matrix[panel_names] * col_indices).sum(axis=1) / matrix['freq']
        # Presence pattern as a string such as '101', used as a grouping key
        matrix['pattern'] = matrix[panel_names].astype(str).agg(''.join, axis=1)

        grouped = matrix.groupby('pattern')
        group_stats = grouped.first()[['freq', 'com']].sort_values(by=['freq', 'com'], ascending=[False, True])

        left_part = []
        right_part = []
        center_part = []
        mid_point = (num_panels - 1) / 2.0

        for pattern, row in group_stats.iterrows():
            group_markers = matrix[matrix['pattern'] == pattern].index.tolist()
            group_markers.sort()  # Alphabetical fallback

            if row['freq'] == num_panels:
                # Markers present in every panel sit in the middle
                center_part.extend(group_markers)
            else:
                if row['com'] < mid_point:
                    left_part = group_markers + left_part
                elif row['com'] > mid_point:
                    right_part = right_part + group_markers
                else:
                    # Balanced pattern: grow whichever side is currently shorter
                    if len(left_part) <= len(right_part):
                        left_part = group_markers + left_part
                    else:
                        right_part = right_part + group_markers

        sorted_markers = left_part + center_part + right_part

    else:
        raise ValueError("The 'alignment' parameter must be either 'left' or 'center'.")

    # Setup the plot
    fig, ax = plt.subplots(figsize=(14, len(panel_names) * 1.5))
    rect_height = 0.6
    colors = ['#72A0C1', '#D65A61', '#E8B358', '#7CE0C9', '#8291A8', '#C48CB3']

    # Draw the blocks: one rectangle per run of consecutive present markers
    for i, (name, markers) in enumerate(zip(panel_names, panel_markers)):
        presence = [1 if m in markers else 0 for m in sorted_markers]
        start_idx = None
        for j, val in enumerate(presence):
            if val == 1 and start_idx is None:
                start_idx = j
            elif val == 0 and start_idx is not None:
                rect = patches.Rectangle(
                    (start_idx, i - rect_height/2), j - start_idx, rect_height,
                    edgecolor='black', facecolor=colors[i % len(colors)], alpha=0.9, linewidth=0.8
                )
                ax.add_patch(rect)
                start_idx = None

        # Catch end blocks
        if start_idx is not None:
            rect = patches.Rectangle(
                (start_idx, i - rect_height/2), len(presence) - start_idx, rect_height,
                edgecolor='black', facecolor=colors[i % len(colors)], alpha=0.9, linewidth=0.8
            )
            ax.add_patch(rect)

        ax.text(-0.8, i, name, va='center', ha='right', fontsize=12, fontweight='bold')

    # Formatting
    ax.set_xlim(0, len(sorted_markers))
    ax.set_ylim(-1, len(panel_names))
    ax.set_yticks([])
    ax.set_xticks(np.arange(len(sorted_markers)) + 0.5)
    ax.set_xticklabels(sorted_markers, rotation=90, ha='center', fontsize=10)

    # Remove borders
    for spine in ['top', 'right', 'left']:
        ax.spines[spine].set_visible(False)

    ax.xaxis.grid(True, linestyle='--', alpha=0.4)
    ax.set_axisbelow(True)

    plt.title("Panel Marker Alignment", fontsize=16, pad=25, fontweight='bold')
    plt.tight_layout()
    plt.show()
|
cystainer/modules.py
ADDED
|
@@ -0,0 +1,277 @@
|
|
|
1
|
+
# modules
|
|
2
|
+
import numpy as np
|
|
3
|
+
import pandas as pd
|
|
4
|
+
import torch
|
|
5
|
+
import torch.nn as nn
|
|
6
|
+
import torch.nn.functional as F
|
|
7
|
+
from torch.distributions import Normal
|
|
8
|
+
from torch.distributions import kl_divergence as kl
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class CyStainerModel(nn.Module):
    """Variational autoencoder with a transformer encoder over per-marker tokens.

    Each marker value becomes one token made of a value embedding, a marker
    embedding and the cell's batch embedding. The decoder is a small MLP that
    maps the latent vector plus the batch embedding back to one value per marker.
    """

    def __init__(self, marker_info, batch_dict, hidden_size=(512, 256, 128), feature_hidden_size=(128, 64, 32),
                 num_layers=3, num_heads=8, value_emb_dim=8, marker_emb_dim=8,
                 batch_emb_dim=8, latent_size=16, dropout=0.1):
        """
        Args:
            marker_info (dict): Must contain 'marker_list'; its length fixes the
                number of marker tokens.
            batch_dict (dict): Batch-name -> index mapping; its length fixes the
                number of batch embeddings.
            hidden_size (sequence): [0] is the transformer feed-forward width and the
                first latent-projector width, [1] the second latent-projector width.
            feature_hidden_size (sequence): Three widths of the per-token feature encoder.
            num_layers (int): Number of transformer encoder layers.
            num_heads (int): Attention heads; value_emb_dim + marker_emb_dim +
                batch_emb_dim is used as d_model.
            value_emb_dim (int): Width of the projected marker value.
            marker_emb_dim (int): Width of the marker embedding.
            batch_emb_dim (int): Width of the batch embedding.
            latent_size (int): Dimension of the latent space.
            dropout (float): Dropout probability used throughout.
        """
        super().__init__()
        self.latent_size = latent_size
        self.batch_dict = batch_dict
        self.num_batches = len(batch_dict)
        self.marker_info = marker_info
        self.n_marker = len(marker_info['marker_list'])
        self.value_emb_dim = value_emb_dim
        self.marker_emb_dim = marker_emb_dim
        self.batch_emb_dim = batch_emb_dim

        self.marker_embedding = nn.Embedding(self.n_marker, marker_emb_dim)
        self.batch_embedding = nn.Embedding(self.num_batches, batch_emb_dim)
        self.value_projector = nn.Linear(1, value_emb_dim)
        self.value_layer_norm = nn.LayerNorm(value_emb_dim)

        # Token width: the three embeddings are concatenated in encode()
        transformer_input_dim = value_emb_dim + marker_emb_dim + batch_emb_dim

        transformer_encoder_layer = nn.TransformerEncoderLayer(
            d_model=transformer_input_dim,
            nhead=num_heads,
            dim_feedforward=hidden_size[0],
            dropout=dropout,
            activation='gelu',
            batch_first=True
        )
        self.transformer_encoder = nn.TransformerEncoder(transformer_encoder_layer,
                                                         num_layers=num_layers,
                                                         enable_nested_tensor=False)
        self.transformer_layer_norm = nn.LayerNorm(transformer_input_dim)

        self.feature_encoder = nn.Sequential(
            nn.Linear(transformer_input_dim, feature_hidden_size[0]),
            nn.Dropout(dropout),
            nn.SiLU(),
            nn.Linear(feature_hidden_size[0], feature_hidden_size[1]),
            nn.Dropout(dropout),
            nn.SiLU(),
            nn.Linear(feature_hidden_size[1], feature_hidden_size[2]))

        # Consumes the per-token features of all markers flattened into one vector
        self.latent_projector = nn.Sequential(
            nn.Linear(feature_hidden_size[2] * self.n_marker, hidden_size[0]),
            nn.Dropout(dropout),
            nn.SiLU(),
            nn.Linear(hidden_size[0], hidden_size[1]),
            nn.Dropout(dropout),
            nn.SiLU())

        self.mu_encoder = nn.Linear(hidden_size[1], latent_size)
        self.logvar_encoder = nn.Linear(hidden_size[1], latent_size)

        decoder_input_dim = latent_size + batch_emb_dim

        self.fc_decoder = nn.Sequential(
            nn.Linear(decoder_input_dim, self.n_marker),
            nn.Dropout(dropout),
            nn.SiLU(),
            nn.Linear(self.n_marker, self.n_marker),
            nn.Dropout(dropout),
            nn.SiLU(),
            nn.Linear(self.n_marker, self.n_marker))

    def encode(self, x, batch_idx, attention_mask=None):
        """Encodes cells into the latent distribution.

        Args:
            x (torch.Tensor): 2-D tensor (cells, markers). The marker count must
                equal n_marker because the token features are flattened into the
                fixed-width latent projector.
            batch_idx (torch.Tensor): Batch index of every cell.
            attention_mask (torch.Tensor): Forwarded as 'src_key_padding_mask'
                to the transformer encoder.

        Returns:
            tuple: (mu, var, latent_sample), where var = exp(logvar) + 1e-4 and
            latent_sample is a reparameterized draw from Normal(mu, sqrt(var)).
        """
        batch_size, seq_len = x.shape
        device = x.device

        # Token position i always stands for marker i
        marker_indices = torch.arange(seq_len, device=device).unsqueeze(0).expand(batch_size, -1)
        marker_emb = self.marker_embedding(marker_indices)

        # The cell's batch embedding is repeated on every token
        batch_emb_flat = self.batch_embedding(batch_idx)
        batch_emb = batch_emb_flat.unsqueeze(1).expand(-1, seq_len, -1)

        value_emb = self.value_projector(x.unsqueeze(-1))
        value_emb = self.value_layer_norm(value_emb)

        combined_emb = torch.cat([value_emb, marker_emb, batch_emb], dim=-1)

        transformer_output = self.transformer_encoder(
            combined_emb, src_key_padding_mask=attention_mask
        )
        normalized_output = self.transformer_layer_norm(transformer_output)

        encoded_features = self.feature_encoder(normalized_output).flatten(1, -1)
        latent_projection = self.latent_projector(encoded_features)
        mu = self.mu_encoder(latent_projection)
        logvar = self.logvar_encoder(latent_projection)
        # Small floor keeps the variance strictly positive
        var = torch.exp(logvar) + 1e-4
        dist = Normal(mu, var.sqrt())
        latent_sample = dist.rsample()
        return mu, var, latent_sample

    def decode(self, z, batch_idx):
        """Decodes latent vectors, conditioned on the batch embedding, into marker values."""
        batch_emb = self.batch_embedding(batch_idx)
        decode_input = torch.cat([z, batch_emb], dim=-1)
        x = self.fc_decoder(decode_input)
        return x

    def forward(self, x, batch_idx, attention_mask=None):
        """Encodes and decodes x.

        In training mode the decoder receives the sampled latent, otherwise the mean.

        Returns:
            tuple: (reconstruction, mu, var, latent_sample).
        """
        mu, var, latent_sample = self.encode(x, batch_idx, attention_mask)
        z = mu if not self.training else latent_sample
        x = self.decode(z, batch_idx)
        return x, mu, var, latent_sample

    def add_new_batches(self, n_new_batches):
        """
        Expands the embedding layer to accommodate new batches.
        Useful for fine-tuning on new datasets without retraining entirely.

        Args:
            n_new_batches (int): Number of embedding rows to append (>= 0).

        Returns:
            nn.Embedding: The new embedding layer, also stored as self.batch_embedding.

        Raises:
            ValueError: If n_new_batches is negative.
        """
        if n_new_batches < 0:
            raise ValueError("n_new_batches must be non-negative.")

        old_embedding = self.batch_embedding
        old_num_embeddings = old_embedding.num_embeddings
        new_num_embeddings = old_num_embeddings + n_new_batches

        # Create new embedding layer
        device = old_embedding.weight.device
        new_embedding = nn.Embedding(new_num_embeddings, old_embedding.embedding_dim, device=device)

        # Initialize weights (copy old ones, new rows keep nn.Embedding's default init)
        with torch.no_grad():
            new_embedding.weight[:old_num_embeddings].copy_(old_embedding.weight)

        self.batch_embedding = new_embedding
        # Keep the bookkeeping attribute in sync with the resized layer
        self.num_batches = new_num_embeddings
        return new_embedding

    @torch.no_grad()
    def correct_batch(self, x, source_batch_idx, target_batch_idx, attention_mask=None):
        """
        Translates cells from their original batch to a reference batch distribution.

        Encodes with the source batch index and decodes the latent mean with the
        target batch index. Note: switches the module to eval mode as a side effect.
        """
        self.eval()
        mu, _, _ = self.encode(x, source_batch_idx, attention_mask)
        corrected_x = self.decode(mu, target_batch_idx)
        return corrected_x
|
|
146
|
+
|
|
147
|
+
|
|
148
|
+
def get_marker_list(df_list):
    """Collects the markers (column names) found across a list of DataFrames.

    Returns:
        tuple: (marker_list, shared_markers, unique_markers) where marker_list is
        the sorted union of all column names, shared_markers are those present in
        every DataFrame and unique_markers are the remaining ones.
    """
    column_names = [name for frame in df_list for name in frame.columns.tolist()]
    marker_list = list(pd.Series(column_names).unique())
    marker_list.sort()

    shared_markers = []
    unique_markers = []
    for marker in marker_list:
        in_every_panel = True
        for frame in df_list:
            if marker not in frame.columns:
                in_every_panel = False
                break
        # Both output lists inherit the sorted order of marker_list
        if in_every_panel:
            shared_markers.append(marker)
        else:
            unique_markers.append(marker)

    return marker_list, shared_markers, unique_markers
|
|
160
|
+
|
|
161
|
+
|
|
162
|
+
def pad_and_get_pad_mask(df_list, marker_list):
    """Aligns every DataFrame to marker_list and reports which markers were real.

    Returns:
        tuple: (padded_frames, pad_mask). padded_frames are new DataFrames whose
        columns are exactly marker_list, with missing markers filled with 0.0.
        pad_mask holds one float32 array per DataFrame of shape
        (n_rows, len(marker_list)) with 1.0 where the marker existed in that
        DataFrame and 0.0 where it was added. The input list is not modified.
    """
    padded_frames = []
    pad_mask = []
    for frame in df_list:
        presence_row = np.array(
            [1.0 if marker in frame.columns else 0.0 for marker in marker_list],
            dtype=np.float32,
        )
        # The same presence row applies to every cell of this sample
        pad_mask.append(np.tile(presence_row, (len(frame), 1)))
        padded_frames.append(frame.reindex(columns=marker_list, fill_value=0.0))

    return padded_frames, pad_mask
|
|
177
|
+
|
|
178
|
+
|
|
179
|
+
def mask_batch(batch,
               marker_list=None,
               max_mask_rate_overlap=0.3,
               separated_mask_rate=True,
               max_mask_rate_unique=1.0,
               shared_markers=None,
               unique_markers=None):
    """Randomly zeroes out present marker values of a batch.

    Args:
        batch (dict): Needs the tensors 'x' (2-D), 'pad' (non-zero where the
            marker is present) and 'batch'.
        marker_list (list): Marker name of every column of 'x'.
        max_mask_rate_overlap (float): Upper rate for shared markers, or for all
            present markers when separated_mask_rate is False.
        separated_mask_rate (bool): Draw separate mask counts for shared and
            unique markers.
        max_mask_rate_unique (float): Upper rate for unique markers.
        shared_markers (list): Names treated as shared markers.
        unique_markers (list): Names treated as unique markers.

    Returns:
        dict: 'x' (copy with masked values set to 0.0), 'batch' and 'pad'
        (clones), and 'mask' (True where a value was masked or padded).

    Raises:
        ValueError: If separated_mask_rate is True and any marker list is missing.
    """
    masked_x = batch['x'].clone()
    present = batch['pad'].bool()

    n_rows, n_cols = masked_x.shape
    dev = masked_x.device

    def lowest_scores_mask(scores, eligible, quota):
        """Marks, per row, the `quota` eligible positions with the lowest score."""
        # Ineligible positions get +inf so they rank behind every eligible one
        scores = scores.masked_fill(~eligible, float('inf'))
        order = torch.argsort(scores, dim=1)
        # argsort of a permutation is its inverse: the rank of every column
        ranks = torch.argsort(order, dim=1)
        return ranks < quota.unsqueeze(1)

    def draw_quota(max_count):
        """Draws a per-row count from floor(U[0, 1) * (max_count + 1))."""
        return torch.floor(torch.rand(n_rows, device=dev) * (max_count.float() + 1))

    # Drawn first so the random stream matches regardless of the mode below
    rand_scores = torch.rand(n_rows, n_cols, device=dev)

    if not separated_mask_rate:
        # Standard masking mode: one quota over all present markers
        quota = draw_quota(present.sum(dim=1) * max_mask_rate_overlap)
        final_mask = lowest_scores_mask(rand_scores, present, quota)
    else:
        if marker_list is None or shared_markers is None or unique_markers is None:
            raise ValueError(
                "marker_list, shared_markers, and unique_markers must be provided "
                "when separated_mask_rate is True."
            )

        # Sets give constant-time membership tests
        shared_lookup = set(shared_markers)
        unique_lookup = set(unique_markers)

        # Per-column flags for the two marker groups
        shared_cols = torch.tensor([m in shared_lookup for m in marker_list], device=dev)
        unique_cols = torch.tensor([m in unique_lookup for m in marker_list], device=dev)

        # Maskable = belongs to the group AND is present in the sample
        maskable_shared = shared_cols.unsqueeze(0) & present
        maskable_unique = unique_cols.unsqueeze(0) & present

        # Shared quota is drawn before the unique quota
        quota_shared = draw_quota(maskable_shared.sum(dim=1) * max_mask_rate_overlap)
        quota_unique = draw_quota(maskable_unique.sum(dim=1) * max_mask_rate_unique)

        # Both groups rank the same random scores
        final_mask = (
            lowest_scores_mask(rand_scores, maskable_shared, quota_shared)
            | lowest_scores_mask(rand_scores, maskable_unique, quota_unique)
        )

    # Masked expression values become 0.0
    masked_x[final_mask] = 0.0

    return {
        'x': masked_x,
        'batch': batch['batch'].clone(),
        'pad': batch['pad'].clone(),
        'mask': final_mask | ~batch['pad'].bool()
    }
|
|
250
|
+
|
|
251
|
+
|
|
252
|
+
def loss_function(recon_x, x, pad, mu, var, kl_weight):
    """
    Calculates the VAE loss, combining reconstruction and KL divergence.

    Args:
        recon_x (torch.Tensor): The reconstructed data from the model.
        x (torch.Tensor): The original input data.
        pad (torch.Tensor): Mask used to index recon_x and x; only the
            selected entries enter the reconstruction loss.
        mu (torch.Tensor): The mean of the latent distribution.
        var (torch.Tensor): The variance of the latent distribution.
        kl_weight (float): A weight factor for the KL divergence term.

    Returns:
        torch.Tensor: Huber reconstruction loss plus the weighted KL term.
    """
    # Reconstruction term: Huber loss restricted to the entries selected by `pad`
    predicted = recon_x[pad]
    target = x[pad]
    recon_loss = F.huber_loss(predicted, target)

    # KL term: posterior N(mu, var) against a standard normal, summed over the
    # latent dimensions and averaged over the batch
    posterior = Normal(mu, var.sqrt())
    standard_normal = Normal(torch.zeros_like(mu), torch.ones_like(var))
    kl_per_cell = kl(posterior, standard_normal).sum(dim=1)
    weighted_kl = kl_per_cell.mean() * kl_weight

    return recon_loss + weighted_kl
|
cystainer/runner.py
ADDED
|
@@ -0,0 +1,318 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import torch
|
|
3
|
+
from torch.utils.data import RandomSampler
|
|
4
|
+
import pandas as pd
|
|
5
|
+
import anndata as ad
|
|
6
|
+
import numpy as np
|
|
7
|
+
from tqdm import tqdm
|
|
8
|
+
|
|
9
|
+
from .modules import CyStainerModel, mask_batch, loss_function
|
|
10
|
+
from .data import load_data_from_folder
|
|
11
|
+
|
|
12
|
+
def get_default_device():
    """Returns the preferred torch device: CUDA first, then MPS, then CPU."""
    if torch.cuda.is_available():
        backend = 'cuda'
    elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
        # Torch builds without an `mps` backend attribute fall through to CPU
        backend = 'mps'
    else:
        backend = 'cpu'
    return torch.device(backend)
|
|
19
|
+
|
|
20
|
+
class CyStainer:
|
|
21
|
+
def __init__(self, marker_info=None, batch_dict=None, model=None, device=None, **model_kwargs):
|
|
22
|
+
"""
|
|
23
|
+
Initializes the CyStainer wrapper. Can be initialized empty and populated later
|
|
24
|
+
via .load_train_data() and .build_model().
|
|
25
|
+
"""
|
|
26
|
+
self.device = device or get_default_device()
|
|
27
|
+
|
|
28
|
+
self.marker_info = marker_info
|
|
29
|
+
self.batch_dict = batch_dict
|
|
30
|
+
self.train_dataloader = None
|
|
31
|
+
self.finetune_dataloader = None
|
|
32
|
+
self.predict_dataloader = None
|
|
33
|
+
self.model = None
|
|
34
|
+
|
|
35
|
+
if model is not None:
|
|
36
|
+
self.model = model.to(self.device)
|
|
37
|
+
self.marker_info = model.marker_info
|
|
38
|
+
self.batch_dict = model.batch_dict
|
|
39
|
+
elif marker_info is not None and batch_dict is not None:
|
|
40
|
+
self.build_model(**model_kwargs)
|
|
41
|
+
|
|
42
|
+
def __getstate__(self):
|
|
43
|
+
"""
|
|
44
|
+
Dictates what gets saved when torch.save(stainer) is called.
|
|
45
|
+
We create a copy of the object's dictionary and remove the dataloader.
|
|
46
|
+
"""
|
|
47
|
+
state = self.__dict__.copy()
|
|
48
|
+
state['train_dataloader'] = None
|
|
49
|
+
state['finetune_dataloader'] = None
|
|
50
|
+
state['predict_dataloader'] = None
|
|
51
|
+
return state
|
|
52
|
+
|
|
53
|
+
def load_train_data(self, folder_path=None, adata_list=None, **kwargs):
    """
    Loads training data from a folder or an AnnData list and stores the
    resulting dataloader, marker info and batch mapping on the instance.

    Extra keyword arguments are forwarded to load_data_from_folder.
    Returns self so calls can be chained.
    """
    source_msg = "provided AnnData list" if not folder_path else folder_path
    print(f"Loading data from {source_msg}...")

    loader, marker_info, batch_dict = load_data_from_folder(
        folder_path=folder_path,
        adata_list=adata_list,
        **kwargs
    )
    self.train_dataloader = loader
    self.marker_info = marker_info
    self.batch_dict = batch_dict

    print(f"Data loaded. Found {len(self.marker_info['marker_list'])} markers and {len(self.batch_dict)} batches.")
    return self
|
|
68
|
+
|
|
69
|
+
def load_finetune_data(self, folder_path=None, adata_list=None, **kwargs):
|
|
70
|
+
"""
|
|
71
|
+
Loads a new dataset for fine-tuning, using the existing model's
|
|
72
|
+
markers and batches as a reference to align the new data.
|
|
73
|
+
"""
|
|
74
|
+
if self.marker_info is None or self.batch_dict is None:
|
|
75
|
+
raise RuntimeError("Base model state not found. Train or load a base model first before loading fine-tuning data.")
|
|
76
|
+
|
|
77
|
+
source_msg = folder_path if folder_path else "provided AnnData list"
|
|
78
|
+
print(f"Loading fine-tuning data from {source_msg}...")
|
|
79
|
+
|
|
80
|
+
self.finetune_dataloader, _, self.ft_batch_dict = load_data_from_folder(
|
|
81
|
+
folder_path=folder_path,
|
|
82
|
+
adata_list=adata_list,
|
|
83
|
+
reference_markers=self.marker_info,
|
|
84
|
+
reference_batch_dict=self.batch_dict,
|
|
85
|
+
**kwargs
|
|
86
|
+
)
|
|
87
|
+
|
|
88
|
+
ft_batches_count = len(self.ft_batch_dict) - len(self.batch_dict)
|
|
89
|
+
print(f"Fine-tuning data loaded. Found {ft_batches_count} new batches to integrate.")
|
|
90
|
+
return self
|
|
91
|
+
|
|
92
|
+
def load_predict_data(self, folder_path=None, adata_list=None, **kwargs):
|
|
93
|
+
"""
|
|
94
|
+
Loads data specifically for inference. Forces is_train=False to
|
|
95
|
+
ensure the predicted cells maintain their original order.
|
|
96
|
+
"""
|
|
97
|
+
if self.marker_info is None or self.batch_dict is None:
|
|
98
|
+
raise RuntimeError("Model state missing. Train or load a model first.")
|
|
99
|
+
|
|
100
|
+
source_msg = folder_path if folder_path else "provided AnnData list"
|
|
101
|
+
print(f"Loading prediction data from {source_msg}...")
|
|
102
|
+
|
|
103
|
+
# Force is_train to False to prevent shuffling!
|
|
104
|
+
kwargs['is_train'] = False
|
|
105
|
+
|
|
106
|
+
self.predict_dataloader, _, _ = load_data_from_folder(
|
|
107
|
+
folder_path=folder_path,
|
|
108
|
+
adata_list=adata_list,
|
|
109
|
+
reference_markers=self.marker_info,
|
|
110
|
+
reference_batch_dict=self.batch_dict,
|
|
111
|
+
**kwargs
|
|
112
|
+
)
|
|
113
|
+
print("Prediction data loaded successfully.")
|
|
114
|
+
return self
|
|
115
|
+
|
|
116
|
+
def build_model(self, **model_kwargs):
|
|
117
|
+
"""
|
|
118
|
+
Initializes the PyTorch model using the state extracted from load_train_data.
|
|
119
|
+
"""
|
|
120
|
+
if self.marker_info is None or self.batch_dict is None:
|
|
121
|
+
raise RuntimeError("Cannot build model: marker_info and batch_dict are missing. Run .load_train_data() first.")
|
|
122
|
+
|
|
123
|
+
self.model = CyStainerModel(
|
|
124
|
+
marker_info=self.marker_info,
|
|
125
|
+
batch_dict=self.batch_dict,
|
|
126
|
+
**model_kwargs
|
|
127
|
+
).to(self.device)
|
|
128
|
+
print("Model initialized successfully.")
|
|
129
|
+
return self
|
|
130
|
+
|
|
131
|
+
def train(self, dataloader=None, max_epochs=300, lr=0.0001, kl_weight=0.0001, tol=1e-4, patience=3,
|
|
132
|
+
model_name='cystainer.pt', loss_name='loss.csv', metrics_path='./metrics_and_loss', separated_mask_rate=True):
|
|
133
|
+
"""Trains the model from scratch with early stopping."""
|
|
134
|
+
if self.model is None:
|
|
135
|
+
raise RuntimeError("Model is not built. Call .build_model() before training.")
|
|
136
|
+
|
|
137
|
+
train_loader = dataloader or self.train_dataloader
|
|
138
|
+
if train_loader is None:
|
|
139
|
+
raise ValueError("No dataloader found. Please provide one or run .load_train_data() first.")
|
|
140
|
+
|
|
141
|
+
optimizer = torch.optim.AdamW(self.model.parameters(), lr=lr)
|
|
142
|
+
self._run_loop(train_loader, optimizer, max_epochs, kl_weight, tol, patience, model_name, loss_name, metrics_path, separated_mask_rate)
|
|
143
|
+
|
|
144
|
+
def finetune(self, ft_loader=None, ft_batch_dict=None, max_epochs=50, lr=0.001, kl_weight=0.0001, tol=1e-4, patience=3,
|
|
145
|
+
model_name='cystainer_finetuned.pt', loss_name='loss_finetuned.csv', metrics_path='./metrics_and_loss', separated_mask_rate=True):
|
|
146
|
+
"""Fine-tunes only the batch embeddings for new data with early stopping."""
|
|
147
|
+
if self.model is None:
|
|
148
|
+
raise RuntimeError("Model is not trained or loaded.")
|
|
149
|
+
|
|
150
|
+
# Fallback to internal state if arguments aren't provided
|
|
151
|
+
ft_loader = ft_loader or self.finetune_dataloader
|
|
152
|
+
ft_batch_dict = ft_batch_dict or self.ft_batch_dict
|
|
153
|
+
|
|
154
|
+
if ft_loader is None or ft_batch_dict is None:
|
|
155
|
+
raise ValueError("Missing fine-tuning data. Run .load_finetune_data() first or pass the dataloader and new_batch_dict explicitly.")
|
|
156
|
+
|
|
157
|
+
# Calculate how many new batches were added
|
|
158
|
+
ft_batches_count = len(ft_batch_dict) - len(self.batch_dict)
|
|
159
|
+
if ft_batches_count > 0:
|
|
160
|
+
self.model.add_new_batches(ft_batches_count)
|
|
161
|
+
self.batch_dict = ft_batch_dict
|
|
162
|
+
|
|
163
|
+
# Freeze all parameters except batch_embedding
|
|
164
|
+
for param in self.model.parameters():
|
|
165
|
+
param.requires_grad = False
|
|
166
|
+
self.model.batch_embedding.weight.requires_grad = True
|
|
167
|
+
|
|
168
|
+
# Calculate the index where new batches start
|
|
169
|
+
old_num_embeddings = len(self.batch_dict) - ft_batches_count
|
|
170
|
+
|
|
171
|
+
# Register a backward hook to freeze old batch embeddings
|
|
172
|
+
def freeze_old_embeddings(grad):
|
|
173
|
+
grad_clone = grad.clone()
|
|
174
|
+
grad_clone[:old_num_embeddings] = 0.0
|
|
175
|
+
return grad_clone
|
|
176
|
+
|
|
177
|
+
self.model.batch_embedding.weight.register_hook(freeze_old_embeddings)
|
|
178
|
+
|
|
179
|
+
optimizer = torch.optim.AdamW(self.model.batch_embedding.parameters(), lr=lr)
|
|
180
|
+
self._run_loop(ft_loader, optimizer, max_epochs, kl_weight, tol, patience, model_name, loss_name, metrics_path, separated_mask_rate)
|
|
181
|
+
|
|
182
|
+
def predict(self, dataloader=None, output_path='adata_pred.h5ad', keep_markers=None, correct_batch=False, target_batch_name=None, return_pred=False):
    """Run inference and either return the predictions or save them as AnnData.

    Args:
        dataloader: Loader to predict on. When None, falls back to
            ``self.predict_dataloader``, then ``self.finetune_dataloader``,
            then ``self.train_dataloader``.
        output_path: File the predicted AnnData object is written to
            (gzip-compressed) when ``return_pred`` is False.
        keep_markers: Optional column selection applied to the prediction
            frame before it is returned or saved.
        correct_batch: When True, predictions come from
            ``self.model.correct_batch`` with a target batch instead of
            ``self.model.forward``.
        target_batch_name: Key of ``self.batch_dict`` naming the target
            batch. Only consulted when ``correct_batch`` is True; if it is
            None in that case, target batch index 0 is used.
        return_pred: When True, return the predictions as a DataFrame
            instead of writing them to ``output_path``.

    Returns:
        A pandas DataFrame whose columns are
        ``self.marker_info['marker_list']`` (optionally restricted to
        ``keep_markers``) when ``return_pred`` is True, otherwise None.

    Raises:
        RuntimeError: If no model is available, or if the selected loader
            uses a ``RandomSampler`` (row order would not match the input).
        ValueError: If no dataloader can be found, if ``target_batch_name``
            is not in ``self.batch_dict``, or if the loader yields no batches.
    """
    if self.model is None:
        raise RuntimeError("Model is not built or loaded.")

    # Fallback logic: User provided -> Predict loader -> Finetune loader -> Train loader.
    # Each candidate is compared against None instead of being chained with
    # `or`: truthiness of a DataLoader goes through __len__, so a loader with
    # zero batches would be silently swapped for another one, and a loader
    # over an un-sized iterable dataset can raise TypeError.
    pred_loader = dataloader
    if pred_loader is None:
        pred_loader = self.predict_dataloader
    if pred_loader is None:
        pred_loader = self.finetune_dataloader
    if pred_loader is None:
        pred_loader = self.train_dataloader

    if pred_loader is None:
        raise ValueError("No dataloader found. Pass one explicitly or run .load_predict_data() first.")

    # Block predictions on shuffled dataloaders
    if isinstance(pred_loader.sampler, RandomSampler):
        raise RuntimeError(
            "CRITICAL: You are attempting to predict on a shuffled DataLoader! "
            "The output rows will NOT correspond to the original order of your cells in the AnnData object. "
            "You must load the data using .load_predict_data() before running .predict()."
        )

    target_batch_int = 0
    if correct_batch and target_batch_name is not None:
        if target_batch_name not in self.batch_dict:
            raise ValueError(f"Batch name '{target_batch_name}' not found. Available batches: {list(self.batch_dict.keys())}")
        target_batch_int = self.batch_dict[target_batch_name]

    self.model.eval()
    result = []
    progress_label = "Predicting (Batch Corrected)" if correct_batch else "Predicting"

    with torch.no_grad():
        for batch in tqdm(pred_loader, desc=progress_label):
            batch = {key: value.to(self.device) for key, value in batch.items()}
            # Both model entry points receive the inverted pad mask.
            inverted_pad = ~batch['pad'].bool()
            if correct_batch:
                target_batch = torch.full_like(batch['batch'], fill_value=target_batch_int)
                x_pred = self.model.correct_batch(batch['x'], batch['batch'], target_batch, inverted_pad)
            else:
                x_pred, _, _, _ = self.model.forward(batch['x'], batch['batch'], inverted_pad)
            result.append(x_pred.cpu().numpy())

    if not result:
        # np.concatenate([]) would fail with an opaque message; say what happened.
        raise ValueError("The prediction dataloader yielded no batches; nothing to predict.")

    df_pred = pd.DataFrame(np.concatenate(result), columns=self.marker_info['marker_list'])
    if keep_markers is not None:
        df_pred = df_pred[keep_markers]
    if return_pred:
        return df_pred

    ad.AnnData(df_pred).write(output_path, compression='gzip')
    print(f"Predictions saved to {output_path}")
    return None
|
|
232
|
+
|
|
233
|
+
def save(self, losses=None, model_name=None, loss_name=None, metrics_path=None):
    """Save the training losses (optional) and the entire object.

    Args:
        losses: Optional sequence of loss values. When given, it is written
            as a one-column CSV (column ``loss``) to
            ``metrics_path/loss_name``; the directory is created if needed.
        model_name: Path the whole object is serialized to with
            ``torch.save``. Required.
        loss_name: File name of the loss CSV. Required when ``losses`` is given.
        metrics_path: Directory for the loss CSV. Required when ``losses``
            is given.

    Raises:
        ValueError: If ``model_name`` is None, or if ``losses`` is given
            without both ``loss_name`` and ``metrics_path``.
    """
    # Validate up front so a missing argument fails with a clear message
    # before anything is written to disk.
    if model_name is None:
        raise ValueError("model_name is required to save the CyStainer object.")

    # Save losses if provided
    if losses is not None:
        if metrics_path is None or loss_name is None:
            raise ValueError("metrics_path and loss_name are required when losses are provided.")
        # exist_ok=True already covers the "directory exists" case.
        os.makedirs(metrics_path, exist_ok=True)
        df_loss = pd.DataFrame(losses, columns=['loss'])
        df_loss.to_csv(os.path.join(metrics_path, loss_name), index=False)

    torch.save(self, model_name)
|
|
243
|
+
|
|
244
|
+
def _run_loop(self, dataloader, optimizer, max_epochs, kl_weight, tol, patience, model_name, loss_name, metrics_path, separated_mask_rate):
    """Run the training epochs, checkpointing each epoch, with early stopping.

    Every batch is moved to ``self.device``, masked with ``mask_batch``, fed
    through ``self.model.forward`` and scored with ``loss_function`` against
    the unmasked batch. After each epoch the mean batch loss is recorded and
    the object is saved through ``self.save``. The loop stops early once the
    absolute change of the mean epoch loss has stayed below ``tol`` for
    ``patience`` consecutive epochs.
    """
    self.model.train()
    progress = tqdm(range(max_epochs))

    history = []
    stalled_epochs = 0
    last_avg = float('inf')

    for epoch in progress:
        batch_losses = []

        for raw in dataloader:
            batch = {name: tensor.to(self.device) for name, tensor in raw.items()}
            masked = mask_batch(
                batch,
                self.marker_info['marker_list'],
                separated_mask_rate=separated_mask_rate,
                shared_markers=self.marker_info['shared_markers'],
                unique_markers=self.marker_info['unique_markers']
            )

            outputs = self.model.forward(masked['x'], masked['batch'], masked['mask'].bool())
            x_pred, mu, var, _ = outputs
            loss = loss_function(x_pred, batch['x'], batch['pad'].bool(), mu, var, kl_weight)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            batch_losses.append(loss.item())

        # Record the epoch average and checkpoint before deciding whether to stop.
        avg_loss = np.mean(batch_losses)
        history.append(avg_loss)
        progress.set_description(f'Loss: {avg_loss:.4f}')

        self.save(losses=history, model_name=model_name, loss_name=loss_name, metrics_path=metrics_path)

        # Early stopping: count consecutive epochs whose loss barely moved.
        loss_diff = abs(last_avg - avg_loss)
        stalled_epochs = stalled_epochs + 1 if loss_diff < tol else 0

        if stalled_epochs >= patience:
            # tqdm.write keeps the progress bar rendering intact.
            tqdm.write(f"\nEarly stopping triggered at epoch {epoch + 1}. Loss change ({loss_diff:.6f}) is below tolerance ({tol}).")
            break

        last_avg = avg_loss
|
|
295
|
+
|
|
296
|
+
@classmethod
def load(cls, filepath, device=None):
    """Load a saved object from disk and map it onto the requested device.

    Args:
        filepath: Path of a file previously written by ``torch.save`` on the
            whole object.
        device: Target device. When falsy, ``get_default_device()`` is used.

    Returns:
        The restored object with ``device`` set to the target device and its
        ``model`` (if not None) moved there.

    Raises:
        TypeError: If the file does not contain an instance of ``cls``.

    Warning:
        The file is unpickled with ``weights_only=False``, which can execute
        arbitrary code embedded in the file. Only load files you trust.
    """
    # Determine the target device
    target_device = device or get_default_device()

    print(f"Loading model from {filepath} onto {target_device}...")

    # 1. Load the object, mapping stored tensors to the target device.
    # SECURITY: weights_only=False performs a full unpickle, needed to restore
    # the complete object but unsafe for untrusted files.
    stainer = torch.load(filepath, map_location=target_device, weights_only=False)

    # Fail clearly when the file holds something else (e.g. a bare state
    # dict) instead of erroring later on a missing attribute.
    if not isinstance(stainer, cls):
        raise TypeError(
            f"Expected a saved {cls.__name__} object in {filepath}, "
            f"got {type(stainer).__name__}."
        )

    # 2. Update the internal device attribute so future dataloaders route correctly
    stainer.device = target_device

    # 3. Explicitly move the PyTorch model to the new device as well
    if stainer.model is not None:
        stainer.model = stainer.model.to(target_device)

    print("Model loaded successfully.")
    return stainer
|
|
@@ -0,0 +1,170 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: cystainer
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: A PyTorch-based tool for predicting missing proteins in cytometry/single-cell data
|
|
5
|
+
Author: Konstantin Ivanov
|
|
6
|
+
Author-email: kivanov@uef.fi
|
|
7
|
+
Classifier: Programming Language :: Python :: 3
|
|
8
|
+
Classifier: License :: OSI Approved :: MIT License
|
|
9
|
+
Classifier: Operating System :: OS Independent
|
|
10
|
+
Requires-Python: >=3.11.9
|
|
11
|
+
Description-Content-Type: text/markdown
|
|
12
|
+
License-File: LICENSE
|
|
13
|
+
Requires-Dist: torch>=2.7.1
|
|
14
|
+
Requires-Dist: numpy>=1.26.4
|
|
15
|
+
Requires-Dist: pandas>=2.2.3
|
|
16
|
+
Requires-Dist: anndata>=0.12.4
|
|
17
|
+
Requires-Dist: scipy>=1.14.1
|
|
18
|
+
Requires-Dist: scikit-learn>=1.5.1
|
|
19
|
+
Requires-Dist: tqdm>=4.66.5
|
|
20
|
+
Requires-Dist: matplotlib>=3.9.2
|
|
21
|
+
Dynamic: author
|
|
22
|
+
Dynamic: author-email
|
|
23
|
+
Dynamic: classifier
|
|
24
|
+
Dynamic: description
|
|
25
|
+
Dynamic: description-content-type
|
|
26
|
+
Dynamic: license-file
|
|
27
|
+
Dynamic: requires-dist
|
|
28
|
+
Dynamic: requires-python
|
|
29
|
+
Dynamic: summary
|
|
30
|
+
|
|
31
|
+
# CyStainer
|
|
32
|
+
CyStainer package for cytometry marker imputation
|
|
33
|
+
|
|
34
|
+
**CyStainer** is a PyTorch-based deep learning tool for predicting missing proteins and imputing marker expression in cytometry and single-cell data. It utilizes a combination of Variational Autoencoders (VAEs) and Transformer architectures to integrate multiple batches/panels and infer missing markers accurately.
|
|
35
|
+
|
|
36
|
+
## 📦 Installation
|
|
37
|
+
|
|
38
|
+
Since CyStainer is built on PyTorch, ensure you have an environment with Python 3.11.9 or newer and the appropriate PyTorch version for your hardware (CUDA recommended).
|
|
39
|
+
|
|
40
|
+
You can install CyStainer directly from the source:
|
|
41
|
+
|
|
42
|
+
```bash
|
|
43
|
+
git clone https://github.com/sysgen-uef/cystainer_package.git
|
|
44
|
+
cd cystainer_package
|
|
45
|
+
pip install .
|
|
46
|
+
```
|
|
47
|
+
## 🧹 Data Preprocessing (From .fcs to .h5ad)
|
|
48
|
+
|
|
49
|
+
**CyStainer** expects your input data to be formatted as `AnnData` objects (either passed as a list in memory or saved locally as `.h5ad` files). Before feeding your cytometry data into the model, it must be properly preprocessed.
|
|
50
|
+
|
|
51
|
+
**Mandatory Pre-processing:**
|
|
52
|
+
* **Cleaning:** Ensure your data is pre-gated to remove doublets, debris, and dead cells.
|
|
53
|
+
* **Compensation:** Your `.fcs` files should already be compensated.
|
|
54
|
+
|
|
55
|
+
**Transformation & Scaling (Recommended):**
|
|
56
|
+
You will typically need to transform your fluorescence intensities (e.g., using an arcsinh transformation) and optionally scale them. Removing saturated values (extreme highs or zeros) can also improve model performance depending on your dataset.
|
|
57
|
+
|
|
58
|
+
Here is a minimal example of how to process a standard `.fcs` file into a ready-to-use `.h5ad` file using `FlowKit` and `AnnData`:
|
|
59
|
+
|
|
60
|
+
```python
|
|
61
|
+
import flowkit as fk
|
|
62
|
+
import pandas as pd
|
|
63
|
+
import numpy as np
|
|
64
|
+
import anndata as ad
|
|
65
|
+
|
|
66
|
+
# Load the compensated and cleaned .fcs file
|
|
67
|
+
sample = fk.Sample('path/to/cleaned_sample.fcs')
|
|
68
|
+
df = sample.as_dataframe('raw')
|
|
69
|
+
df.columns = sample.pnn_labels # Set marker names
|
|
70
|
+
|
|
71
|
+
# Transformation (e.g., arcsinh with a cofactor of 100 or 150)
|
|
72
|
+
cofactor = 100
|
|
73
|
+
non_scatter = ['FSC' not in c and 'SSC' not in c for c in df.columns]
|
|
74
|
+
df.loc[:, non_scatter] = np.arcsinh(df.loc[:, non_scatter] / cofactor)
|
|
75
|
+
|
|
76
|
+
# Optional: Remove saturation / extreme outliers
|
|
77
|
+
# Useful if your instrument records artificial bounds (e.g., exactly 0 or max value)
|
|
78
|
+
df = df[~np.any(((df <= 0) | (df >= df.max().max())), axis=1)]
|
|
79
|
+
|
|
80
|
+
# Optional: Scaling (Min-Max scaling to [0, 1] or Z-score normalization)
|
|
81
|
+
# Min-Max Scaling example:
|
|
82
|
+
df = (df - df.min()) / (df.max() - df.min())
|
|
83
|
+
# Z-score Scaling example (alternatively):
|
|
84
|
+
# df = (df - df.mean()) / df.std()
|
|
85
|
+
|
|
86
|
+
# Convert to AnnData and save for CyStainer
|
|
87
|
+
adata = ad.AnnData(df)
|
|
88
|
+
adata.write('preprocessed_sample.h5ad', compression='gzip')
|
|
89
|
+
```
|
|
90
|
+
|
|
91
|
+
Once your `.fcs` files are converted into `.h5ad` objects, you can load them directly into your workflow.
|
|
92
|
+
|
|
93
|
+
## 🚀 Quick Start Guide
|
|
94
|
+
|
|
95
|
+
The primary way to interact with the package is through the `CyStainer` wrapper class. The standard workflow consists of initializing the model, loading training data, building the network, and running inference.
|
|
96
|
+
|
|
97
|
+
### 1. Training a Base Model
|
|
98
|
+
|
|
99
|
+
You can load your data either from a folder of `.h5ad` files or by passing a list of `AnnData` objects directly in memory.
|
|
100
|
+
|
|
101
|
+
```python
|
|
102
|
+
from cystainer import CyStainer
|
|
103
|
+
|
|
104
|
+
# Initialize the stainer (automatically detects CUDA/CPU)
|
|
105
|
+
stainer = CyStainer()
|
|
106
|
+
|
|
107
|
+
# Load training data
|
|
108
|
+
# Alternatively, use: adata_list=[adata1, adata2]
|
|
109
|
+
# CyStainer includes a built-in utility to visualize how markers overlap across different panels or batches.
|
|
110
|
+
stainer.load_train_data(folder_path='./data_example/train', get_panel_vis=True)
|
|
111
|
+
```
|
|
112
|
+

|
|
113
|
+
|
|
114
|
+
```python
|
|
115
|
+
# Build the model
|
|
116
|
+
# You can pass custom hyperparameters here if needed
|
|
117
|
+
stainer.build_model()
|
|
118
|
+
|
|
119
|
+
# Train the model
|
|
120
|
+
stainer.train()
|
|
121
|
+
|
|
122
|
+
# By default the model is automatically saved in the same directory
|
|
123
|
+
# as cystainer.pt
|
|
124
|
+
```
|
|
125
|
+
|
|
126
|
+
### 2. Predicting Missing Markers
|
|
127
|
+
|
|
128
|
+
Once a model is trained, you can load inference data. The `.load_predict_data()` method ensures your cells are not shuffled so that the output matches your input order.
|
|
129
|
+
|
|
130
|
+
```python
|
|
131
|
+
# Load prediction data using the base model's reference markers
|
|
132
|
+
stainer.load_predict_data(folder_path='./data_example/test')
|
|
133
|
+
|
|
134
|
+
# Run predictions and save directly to disk
|
|
135
|
+
stainer.predict(output_path='imputed_cells.h5ad')
|
|
136
|
+
|
|
137
|
+
# Alternatively, return the predictions as a pandas DataFrame:
|
|
138
|
+
# df_imputed = stainer.predict(return_pred=True)
|
|
139
|
+
```
|
|
140
|
+
|
|
141
|
+
### 3. Fine-Tuning on New Batches
|
|
142
|
+
|
|
143
|
+
If you receive new data from a different batch or panel, you don't need to retrain from scratch. CyStainer can freeze the main network and fine-tune only the batch embeddings to align the new data.
|
|
144
|
+
|
|
145
|
+
```python
|
|
146
|
+
# Load the pre-trained model
|
|
147
|
+
stainer = CyStainer.load('cystainer.pt')
|
|
148
|
+
|
|
149
|
+
# Load the new data for fine-tuning
|
|
150
|
+
# Note: the AnnData objects must have a batch info column (its name is passed via `batch_info`)
|
|
151
|
+
stainer.load_finetune_data(folder_path='./data_example/test', batch_info='batch')
|
|
152
|
+
|
|
153
|
+
# Fine-tune the batch embeddings
|
|
154
|
+
stainer.finetune()
|
|
155
|
+
|
|
156
|
+
# Predict on the new data with the fine-tuned model
|
|
157
|
+
stainer.predict(output_path='imputed_new_batch.h5ad')
|
|
158
|
+
```
|
|
159
|
+
|
|
160
|
+
### 4. Batch Correction
|
|
161
|
+
|
|
162
|
+
CyStainer allows you to translate cells from their original batch to a specific target batch distribution.
|
|
163
|
+
|
|
164
|
+
```python
|
|
165
|
+
stainer.predict(
|
|
166
|
+
output_path='batch_corrected_cells.h5ad',
|
|
167
|
+
correct_batch=True,
|
|
168
|
+
target_batch_name='single_batch' # Must match a batch name in stainer.batch_dict
|
|
169
|
+
)
|
|
170
|
+
```
|
|
@@ -0,0 +1,9 @@
|
|
|
1
|
+
cystainer/__init__.py,sha256=Yv_K9NHhvtxenToMvf3FadTmDS1YSmBwkbjBnkwewJM,301
|
|
2
|
+
cystainer/data.py,sha256=xS5-XX1NguRH4Mxa4otMjwzE00HlhJzwooz43aPD1gw,8321
|
|
3
|
+
cystainer/modules.py,sha256=rKZMW_-7yYJXSEJlSjTVtsY4cyFqT8_qwoXVHVfWrco,11195
|
|
4
|
+
cystainer/runner.py,sha256=CaEsn50m-v3YnyTSgBvHacEG7DbUA5ScsKIXrMkJv7Q,14413
|
|
5
|
+
cystainer-0.1.0.dist-info/licenses/LICENSE,sha256=0qOacCooGqkP7CL7qDsSOyMSmGtEb1s_Q1oSoZSrLxs,1067
|
|
6
|
+
cystainer-0.1.0.dist-info/METADATA,sha256=sSH7f70qUwymDJXee7Bq60cZFETGyfSRE2L2rCPmy3Q,6211
|
|
7
|
+
cystainer-0.1.0.dist-info/WHEEL,sha256=aeYiig01lYGDzBgS8HxWXOg3uV61G9ijOsup-k9o1sk,91
|
|
8
|
+
cystainer-0.1.0.dist-info/top_level.txt,sha256=8kyy2H_PFxOOr53It06GAWAyUNDGd3Tw1ZThnhaSPcg,10
|
|
9
|
+
cystainer-0.1.0.dist-info/RECORD,,
|
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
MIT License
|
|
2
|
+
|
|
3
|
+
Copyright (c) 2026 Sysgen lab
|
|
4
|
+
|
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
7
|
+
in the Software without restriction, including without limitation the rights
|
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
10
|
+
furnished to do so, subject to the following conditions:
|
|
11
|
+
|
|
12
|
+
The above copyright notice and this permission notice shall be included in all
|
|
13
|
+
copies or substantial portions of the Software.
|
|
14
|
+
|
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
21
|
+
SOFTWARE.
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
cystainer
|