cystainer 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
cystainer/__init__.py ADDED
@@ -0,0 +1,16 @@
1
+ # cystainer/__init__.py
2
+
3
+ # Package version
4
+ __version__ = "0.1.0"
5
+
6
+ # Imports
7
+ from .runner import CyStainer
8
+ from .data import load_data_from_folder
9
+ from .modules import CyStainerModel
10
+
11
+ # Imports `from cystainer import *`
12
+ __all__ = [
13
+ "CyStainerModel",
14
+ "CyStainer",
15
+ "load_data_from_folder",
16
+ ]
cystainer/data.py ADDED
@@ -0,0 +1,204 @@
1
+ import os
2
+ import itertools
3
+ import torch
4
+ import numpy as np
5
+ import pandas as pd
6
+ import anndata as ad
7
+ import matplotlib.pyplot as plt
8
+ import matplotlib.patches as patches
9
+ from torch.utils.data import DataLoader
10
+ from .modules import get_marker_list, pad_and_get_pad_mask
11
+
12
def load_data_from_folder(folder_path=None, adata_list=None, exclude_files=None, max_n_cells=int(1e+6), batch_size=1024, is_train=True, drop_markers=None,
                          reference_markers=None, reference_batch_dict=None, normalize=False, batch_info=None, get_panel_vis=False):
    """Read .h5ad files from a folder OR accept a list of AnnData objects, and build a DataLoader.

    Args:
        folder_path: Directory scanned for ``.h5ad`` files. Only used when
            ``adata_list`` is None.
        adata_list: AnnData objects to use directly instead of reading from disk.
        exclude_files: File names inside ``folder_path`` to skip (default: none).
        max_n_cells: Maximum number of leading cells (rows) kept per AnnData.
        batch_size: Batch size of the returned DataLoader.
        is_train: When True the returned DataLoader shuffles its samples.
        drop_markers: Marker (column) names removed from every sample.
        reference_markers: Dict with keys 'marker_list', 'shared_markers' and
            'unique_markers' to align the data to. When None these are derived
            from the loaded samples.
        reference_batch_dict: Existing batch-name -> index mapping. Batch names
            not present in it are appended with the next free indices.
        normalize: When True, z-score every marker within each sample.
        batch_info: Name of the ``adata.obs`` column holding the batch label of
            each cell. When None every cell is labelled 'single_batch'.
        get_panel_vis: When True, plot the marker overlap between panels.

    Returns:
        tuple: ``(dataloader, markers_info, batch_dict)``.

    Raises:
        ValueError: If no data source is given, no data is found, or the number
            of batch labels does not match the number of cells.
    """
    if exclude_files is None:
        exclude_files = []

    # Establish the adata_list
    if adata_list is None:
        if folder_path is None:
            raise ValueError("You must provide either 'folder_path' or 'adata_list'.")

        # os.listdir returns names in arbitrary order; sort so that the cell
        # order (and therefore the prediction order) is reproducible.
        sample_names = sorted(
            f for f in os.listdir(folder_path) if f.endswith('.h5ad') and (f not in exclude_files)
        )
        if not sample_names:
            raise ValueError(f"No files found in {folder_path}")

        # Load adatas from disk
        loaded_adatas = [ad.read_h5ad(os.path.join(folder_path, f)) for f in sample_names]
    else:
        # Use the provided list of AnnData objects
        loaded_adatas = adata_list

    if len(loaded_adatas) == 0:
        raise ValueError("'adata_list' is empty; there is no data to load.")

    # Extract DataFrames and Batch Info
    df_list = [adata.to_df().iloc[:max_n_cells, :] for adata in loaded_adatas]

    if batch_info is not None:
        # Truncate the labels exactly like the expression rows; otherwise the
        # labels of later samples are shifted relative to their cells.
        batch_list = [adata.obs[batch_info].iloc[:max_n_cells].tolist() for adata in loaded_adatas]
        batch_names = list(itertools.chain.from_iterable(batch_list))
    else:
        print("No batch info passed, setting 'single_batch' for the data.")
        batch_names = list(itertools.chain.from_iterable([['single_batch'] * df.shape[0] for df in df_list]))

    n_cells = sum(df.shape[0] for df in df_list)
    if len(batch_names) != n_cells:
        raise ValueError(
            f"Number of batch labels ({len(batch_names)}) does not match the number of cells ({n_cells})."
        )

    # Visualization, Processing, and Normalization
    if get_panel_vis:
        visualize_panel_overlap(df_list)

    if drop_markers is not None:
        df_list = [df.loc[:, ~df.columns.isin(drop_markers)] for df in df_list]

    if normalize:
        # A constant marker has zero std; divide by 1 instead so it becomes 0
        # rather than NaN (which would poison every downstream computation).
        df_list = [(df - df.mean()) / df.std().replace(0, 1.0) for df in df_list]

    # Handle Batches
    if reference_batch_dict is None:
        # Sort for deterministic behavior
        batch_dict = {batch: i for i, batch in enumerate(sorted(set(batch_names)))}
    else:
        batch_dict = reference_batch_dict.copy()
        current_max_idx = max(batch_dict.values()) if batch_dict else -1

        # Find only the new batches and sort them
        new_batches = sorted(set(batch_names) - set(batch_dict.keys()))

        for batch in new_batches:
            current_max_idx += 1
            batch_dict[batch] = current_max_idx

    batch_ids = np.fromiter(
        (batch_dict[name] for name in batch_names), dtype=np.int64, count=len(batch_names)
    )

    # Handle Markers and Padding
    if reference_markers is None:
        marker_list, shared_markers, unique_markers = get_marker_list(df_list)
    else:
        marker_list = reference_markers['marker_list']
        shared_markers = reference_markers['shared_markers']
        unique_markers = reference_markers['unique_markers']

    df_list, pad_mask = pad_and_get_pad_mask(df_list, marker_list)

    # PyTorch Dataloader: one dict per cell. The matrix is cast once instead of
    # casting every row separately.
    values = pd.concat(df_list).values.astype('float32')
    data = [{'x': x, 'batch': b, 'pad': p}
            for x, b, p in zip(values, batch_ids, np.concatenate(pad_mask, axis=0))]

    dataloader = DataLoader(data, batch_size=batch_size, shuffle=is_train)

    markers_info = {
        'marker_list': marker_list,
        'shared_markers': shared_markers,
        'unique_markers': unique_markers
    }

    return dataloader, markers_info, batch_dict
91
+
92
def visualize_panel_overlap(df_list, alignment='left'):
    """
    Creates a block alignment plot showing shared and unique markers
    across different panels.

    Panels are the distinct column sets found in ``df_list``; they are named
    'Panel 1', 'Panel 2', ... in order of first appearance, so the figure is
    the same on every run.

    Args:
        df_list (list): List of pandas DataFrames.
        alignment (str): 'center' for a center-outwards pyramid style,
                         'left' for a cascading left-aligned style.

    Raises:
        ValueError: If ``alignment`` is not 'left' or 'center', or if
            ``df_list`` is empty.
    """
    if alignment not in ('left', 'center'):
        raise ValueError("The 'alignment' parameter must be either 'left' or 'center'.")

    # dict.fromkeys de-duplicates identical panels while keeping first-seen
    # order (a plain set would order panels by hash, which varies between runs).
    panel_markers = list(dict.fromkeys(frozenset(df.columns) for df in df_list))
    if not panel_markers:
        raise ValueError("'df_list' is empty; there are no panels to visualize.")

    # Sorted so that ties in the ordering below are broken alphabetically.
    all_markers = sorted(set().union(*panel_markers))
    num_panels = len(panel_markers)
    panel_names = [f'Panel {i+1}' for i in range(num_panels)]

    # Presence matrix: rows are markers, columns are panels, 1 = marker in panel.
    matrix = pd.DataFrame(
        {name: [int(m in markers) for m in all_markers]
         for name, markers in zip(panel_names, panel_markers)},
        index=all_markers,
    )

    # Sort markers based on requested alignment
    if alignment == 'left':
        # Simple cascade sort
        matrix = matrix.sort_values(by=panel_names, ascending=False, kind='stable')
        sorted_markers = matrix.index.tolist()
    else:
        # Center-outwards sort (Frequency and Center of Mass)
        matrix['freq'] = matrix.sum(axis=1)
        col_indices = np.arange(num_panels)

        matrix['com'] = (matrix[panel_names] * col_indices).sum(axis=1) / matrix['freq']
        matrix['pattern'] = matrix[panel_names].astype(str).agg(''.join, axis=1)

        grouped = matrix.groupby('pattern')
        group_stats = grouped.first()[['freq', 'com']].sort_values(by=['freq', 'com'], ascending=[False, True])

        left_part = []
        right_part = []
        center_part = []
        mid_point = (num_panels - 1) / 2.0

        for pattern, row in group_stats.iterrows():
            group_markers = matrix[matrix['pattern'] == pattern].index.tolist()
            group_markers.sort()  # Alphabetical fallback

            if row['freq'] == num_panels:
                # Markers present in every panel sit in the middle.
                center_part.extend(group_markers)
            elif row['com'] < mid_point:
                left_part = group_markers + left_part
            elif row['com'] > mid_point:
                right_part = right_part + group_markers
            elif len(left_part) <= len(right_part):
                # Perfectly centered pattern: put it on the shorter side.
                left_part = group_markers + left_part
            else:
                right_part = right_part + group_markers

        sorted_markers = left_part + center_part + right_part

    # Setup the plot
    fig, ax = plt.subplots(figsize=(14, len(panel_names) * 1.5))
    rect_height = 0.6
    colors = ['#72A0C1', '#D65A61', '#E8B358', '#7CE0C9', '#8291A8', '#C48CB3']

    # Draw the blocks: one rectangle per run of consecutive present markers.
    for i, (name, markers) in enumerate(zip(panel_names, panel_markers)):
        position = 0
        for present, run in itertools.groupby(m in markers for m in sorted_markers):
            run_length = sum(1 for _ in run)
            if present:
                rect = patches.Rectangle(
                    (position, i - rect_height/2), run_length, rect_height,
                    edgecolor='black', facecolor=colors[i % len(colors)], alpha=0.9, linewidth=0.8
                )
                ax.add_patch(rect)
            position += run_length

        ax.text(-0.8, i, name, va='center', ha='right', fontsize=12, fontweight='bold')

    # Formatting
    ax.set_xlim(0, len(sorted_markers))
    ax.set_ylim(-1, len(panel_names))
    ax.set_yticks([])
    ax.set_xticks(np.arange(len(sorted_markers)) + 0.5)
    ax.set_xticklabels(sorted_markers, rotation=90, ha='center', fontsize=10)

    # Remove borders
    for spine in ['top', 'right', 'left']:
        ax.spines[spine].set_visible(False)

    ax.xaxis.grid(True, linestyle='--', alpha=0.4)
    ax.set_axisbelow(True)

    plt.title("Panel Marker Alignment", fontsize=16, pad=25, fontweight='bold')
    plt.tight_layout()
    plt.show()
cystainer/modules.py ADDED
@@ -0,0 +1,277 @@
1
+ # modules
2
+ import numpy as np
3
+ import pandas as pd
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ from torch.distributions import Normal
8
+ from torch.distributions import kl_divergence as kl
9
+
10
+
11
class CyStainerModel(nn.Module):
    """Transformer-encoder / MLP-decoder variational autoencoder over marker values.

    Every marker of a cell is turned into a token (value, marker and batch
    embeddings concatenated), the tokens pass through a Transformer encoder,
    and the flattened result is projected to a Gaussian latent. The decoder
    maps the latent plus a batch embedding back to all markers.
    """

    def __init__(self, marker_info, batch_dict, hidden_size=(512, 256, 128), feature_hidden_size=(128, 64, 32),
                 num_layers=3, num_heads=8, value_emb_dim=8, marker_emb_dim=8,
                 batch_emb_dim=8, latent_size=16, dropout=0.1):
        """
        Args:
            marker_info (dict): Must contain 'marker_list'; its length sets the
                number of marker tokens.
            batch_dict (dict): Batch name -> index; its length sets the number
                of batch embeddings.
            hidden_size (sequence): Widths used for the Transformer feed-forward
                layer (index 0) and the latent projector (indices 0 and 1).
            feature_hidden_size (sequence): Three widths of the per-token MLP.
            num_layers (int): Number of Transformer encoder layers.
            num_heads (int): Attention heads; must divide
                ``value_emb_dim + marker_emb_dim + batch_emb_dim``.
            value_emb_dim (int): Size of the projected marker value.
            marker_emb_dim (int): Size of the marker identity embedding.
            batch_emb_dim (int): Size of the batch embedding.
            latent_size (int): Dimensionality of the latent space.
            dropout (float): Dropout probability used throughout.
        """
        super().__init__()
        self.latent_size = latent_size
        self.batch_dict = batch_dict
        self.num_batches = len(batch_dict)
        self.marker_info = marker_info
        self.n_marker = len(marker_info['marker_list'])
        self.value_emb_dim = value_emb_dim
        self.marker_emb_dim = marker_emb_dim
        self.batch_emb_dim = batch_emb_dim

        self.marker_embedding = nn.Embedding(self.n_marker, marker_emb_dim)
        self.batch_embedding = nn.Embedding(self.num_batches, batch_emb_dim)
        self.value_projector = nn.Linear(1, value_emb_dim)
        self.value_layer_norm = nn.LayerNorm(value_emb_dim)

        transformer_input_dim = value_emb_dim + marker_emb_dim + batch_emb_dim

        transformer_encoder_layer = nn.TransformerEncoderLayer(
            d_model=transformer_input_dim,
            nhead=num_heads,
            dim_feedforward=hidden_size[0],
            dropout=dropout,
            activation='gelu',
            batch_first=True
        )
        self.transformer_encoder = nn.TransformerEncoder(transformer_encoder_layer,
                                                         num_layers=num_layers,
                                                         enable_nested_tensor=False)
        self.transformer_layer_norm = nn.LayerNorm(transformer_input_dim)

        # Per-token MLP applied to every marker token independently.
        self.feature_encoder = nn.Sequential(
            nn.Linear(transformer_input_dim, feature_hidden_size[0]),
            nn.Dropout(dropout),
            nn.SiLU(),
            nn.Linear(feature_hidden_size[0], feature_hidden_size[1]),
            nn.Dropout(dropout),
            nn.SiLU(),
            nn.Linear(feature_hidden_size[1], feature_hidden_size[2]))

        # Operates on all token features of a cell flattened into one vector.
        self.latent_projector = nn.Sequential(
            nn.Linear(feature_hidden_size[2] * self.n_marker, hidden_size[0]),
            nn.Dropout(dropout),
            nn.SiLU(),
            nn.Linear(hidden_size[0], hidden_size[1]),
            nn.Dropout(dropout),
            nn.SiLU())

        self.mu_encoder = nn.Linear(hidden_size[1], latent_size)
        self.logvar_encoder = nn.Linear(hidden_size[1], latent_size)

        decoder_input_dim = latent_size + batch_emb_dim

        self.fc_decoder = nn.Sequential(
            nn.Linear(decoder_input_dim, self.n_marker),
            nn.Dropout(dropout),
            nn.SiLU(),
            nn.Linear(self.n_marker, self.n_marker),
            nn.Dropout(dropout),
            nn.SiLU(),
            nn.Linear(self.n_marker, self.n_marker))

    def encode(self, x, batch_idx, attention_mask=None):
        """Encode cells into the latent distribution.

        Args:
            x (torch.Tensor): Marker values of shape (batch, n_marker).
            batch_idx (torch.Tensor): Batch index of every cell.
            attention_mask (torch.Tensor, optional): Forwarded to the
                Transformer as ``src_key_padding_mask``.

        Returns:
            tuple: ``(mu, var, latent_sample)`` where ``var`` is the variance
            (with a 1e-4 floor) and ``latent_sample`` a reparameterized draw.
        """
        batch_size, seq_len = x.shape
        device = x.device

        marker_indices = torch.arange(seq_len, device=device).unsqueeze(0).expand(batch_size, -1)
        marker_emb = self.marker_embedding(marker_indices)

        # The same batch embedding is attached to every marker token of a cell.
        batch_emb_flat = self.batch_embedding(batch_idx)
        batch_emb = batch_emb_flat.unsqueeze(1).expand(-1, seq_len, -1)

        value_emb = self.value_projector(x.unsqueeze(-1))
        value_emb = self.value_layer_norm(value_emb)

        combined_emb = torch.cat([value_emb, marker_emb, batch_emb], dim=-1)

        transformer_output = self.transformer_encoder(
            combined_emb, src_key_padding_mask=attention_mask
        )
        normalized_output = self.transformer_layer_norm(transformer_output)

        encoded_features = self.feature_encoder(normalized_output).flatten(1, -1)
        latent_projection = self.latent_projector(encoded_features)
        mu = self.mu_encoder(latent_projection)
        logvar = self.logvar_encoder(latent_projection)
        # Small floor keeps the standard deviation strictly positive.
        var = torch.exp(logvar) + 1e-4
        dist = Normal(mu, var.sqrt())
        latent_sample = dist.rsample()
        return mu, var, latent_sample

    def decode(self, z, batch_idx):
        """Decode latent vectors ``z`` into marker values for the given batches."""
        batch_emb = self.batch_embedding(batch_idx)
        decode_input = torch.cat([z, batch_emb], dim=-1)
        x = self.fc_decoder(decode_input)
        return x

    def forward(self, x, batch_idx, attention_mask=None):
        """Encode and decode ``x``.

        The decoder receives a latent sample in training mode and the latent
        mean otherwise.

        Returns:
            tuple: ``(reconstruction, mu, var, latent_sample)``.
        """
        mu, var, latent_sample = self.encode(x, batch_idx, attention_mask)
        z = mu if not self.training else latent_sample
        x = self.decode(z, batch_idx)
        return x, mu, var, latent_sample

    def add_new_batches(self, n_new_batches):
        """
        Expands the embedding layer to accommodate new batches.
        Useful for fine-tuning on new datasets without retraining entirely.

        Existing embedding rows are copied; the new rows keep the default
        ``nn.Embedding`` initialization. ``self.num_batches`` is updated to the
        new size. ``self.batch_dict`` is NOT changed here because the names of
        the new batches are unknown to the model.

        Args:
            n_new_batches (int): Number of embedding rows to append (>= 0).

        Returns:
            nn.Embedding: The new embedding layer (also stored on the model).

        Raises:
            ValueError: If ``n_new_batches`` is negative.
        """
        if n_new_batches < 0:
            raise ValueError(f"n_new_batches must be non-negative, got {n_new_batches}.")

        old_embedding = self.batch_embedding
        old_num_embeddings = old_embedding.num_embeddings
        new_num_embeddings = old_num_embeddings + n_new_batches

        # Create new embedding layer
        device = old_embedding.weight.device
        new_embedding = nn.Embedding(new_num_embeddings, old_embedding.embedding_dim, device=device)

        # Initialize weights (copy old ones, keep the random init for new ones)
        with torch.no_grad():
            new_embedding.weight[:old_num_embeddings].copy_(old_embedding.weight)

        self.batch_embedding = new_embedding
        # Keep the bookkeeping attribute in step with the embedding table.
        self.num_batches = new_num_embeddings
        return new_embedding

    @torch.no_grad()
    def correct_batch(self, x, source_batch_idx, target_batch_idx, attention_mask=None):
        """
        Translates cells from their original batch to a reference batch distribution.

        Encodes with the source batch indices and decodes the latent mean with
        the target batch indices. Note: switches the model to eval mode and
        leaves it there.
        """
        self.eval()
        mu, _, _ = self.encode(x, source_batch_idx, attention_mask)
        corrected_x = self.decode(mu, target_batch_idx)
        return corrected_x
146
+
147
+
148
def get_marker_list(df_list):
    """Collect the markers of all panels and split them into shared and unique.

    Args:
        df_list (list): pandas DataFrames whose columns are marker names.

    Returns:
        tuple: ``(marker_list, shared_markers, unique_markers)`` where
        ``marker_list`` is the sorted union of all column names,
        ``shared_markers`` are those present in every DataFrame and
        ``unique_markers`` are the remaining ones.
    """
    column_names = [name for frame in df_list for name in frame.columns.tolist()]
    marker_list = sorted(pd.Series(column_names).unique())

    shared_markers = []
    unique_markers = []
    for marker in marker_list:
        in_every_panel = True
        for frame in df_list:
            if marker not in frame.columns:
                in_every_panel = False
                break
        if in_every_panel:
            shared_markers.append(marker)
        else:
            unique_markers.append(marker)

    return marker_list, shared_markers, unique_markers
160
+
161
+
162
def pad_and_get_pad_mask(df_list, marker_list):
    """Align every DataFrame to ``marker_list`` and report which markers were measured.

    Args:
        df_list (list): pandas DataFrames whose columns are marker names.
        marker_list (list): Target column order.

    Returns:
        tuple: ``(padded, pad_mask)``. ``padded`` is a new list of DataFrames
        reindexed to ``marker_list`` with missing markers filled with 0.0.
        ``pad_mask`` is a list of float32 arrays of shape
        ``(len(df), len(marker_list))`` holding 1.0 where the marker existed in
        the original DataFrame and 0.0 where it was padded.
    """
    padded_frames = df_list.copy()
    pad_mask = []

    for position, frame in enumerate(padded_frames):
        # One presence row per panel, repeated for each of its cells.
        presence_row = np.array(
            [1.0 if marker in frame.columns else 0.0 for marker in marker_list],
            dtype=np.float32,
        )
        pad_mask.append(np.tile(presence_row, (len(frame), 1)))
        padded_frames[position] = frame.reindex(columns=marker_list, fill_value=0.0)

    return padded_frames, pad_mask
177
+
178
+
179
def mask_batch(batch,
               marker_list=None,
               max_mask_rate_overlap=0.3,
               separated_mask_rate=True,
               max_mask_rate_unique=1.0,
               shared_markers=None,
               unique_markers=None):
    """Randomly hide measured markers of every cell in a batch.

    For each cell a random number of measured (non-padded) markers is chosen
    and their values are set to 0.0. The number is drawn uniformly from
    ``0 .. floor(max_rate * n_measured)``, so it never exceeds the requested
    maximum rate.

    Args:
        batch (dict): Tensors 'x' (values), 'batch' (batch indices) and 'pad'
            (non-zero where a marker was measured), all with one row per cell.
        marker_list (list): Marker name of every column; required when
            ``separated_mask_rate`` is True.
        max_mask_rate_overlap (float): Maximum masked fraction of the shared
            markers (or of all measured markers when ``separated_mask_rate``
            is False).
        separated_mask_rate (bool): Mask shared and unique markers with
            separate rates.
        max_mask_rate_unique (float): Maximum masked fraction of the unique
            markers.
        shared_markers (list): Names of markers present in every panel.
        unique_markers (list): Names of markers missing from some panel.

    Returns:
        dict: 'x' (masked copy of the values), 'batch' and 'pad' (clones of the
        inputs) and 'mask' (True where a marker was masked or padded).

    Raises:
        ValueError: If ``separated_mask_rate`` is True and any of
            ``marker_list``, ``shared_markers``, ``unique_markers`` is missing.
    """
    batch_x_clone = batch['x'].clone()
    pad_mask = batch['pad'].bool()

    batch_size, seq_len = batch_x_clone.shape
    device = batch_x_clone.device

    def sample_n_to_mask(n_valid, max_rate):
        """Draw, per row, an integer in 0..floor(max_rate * n_valid)."""
        # Floor the cap BEFORE sampling: with a fractional cap (e.g. 1.5) the
        # draw could otherwise reach 2 and exceed the maximum rate. The small
        # epsilon guards against float32 products landing just below an integer.
        max_n = torch.floor(n_valid.float() * max_rate + 1e-4)
        return torch.floor(torch.rand(batch_size, device=device) * (max_n + 1))

    def get_mask_for_top_k(scores, is_valid_token_mask, n_to_mask):
        """Helper to select top-k random scores for masking."""
        # Invalid tokens get the worst score so they are ranked last.
        scores[~is_valid_token_mask] = float('inf')
        sorted_indices = torch.argsort(scores, dim=1)
        ranks = torch.empty_like(sorted_indices)
        row_ranks = torch.arange(seq_len, device=device).expand(batch_size, -1)
        # Invert the permutation: ranks[i, j] is the rank of token j in row i.
        ranks.scatter_(1, sorted_indices, row_ranks)
        return ranks < n_to_mask.unsqueeze(1)

    rand_scores = torch.rand(batch_size, seq_len, device=device)

    if separated_mask_rate:
        if marker_list is None or shared_markers is None or unique_markers is None:
            raise ValueError(
                "marker_list, shared_markers, and unique_markers must be provided "
                "when separated_mask_rate is True."
            )

        # Convert to sets for faster lookup
        shared_set = set(shared_markers)
        unique_set = set(unique_markers)

        # Create 1D boolean masks for the columns
        is_shared_col = torch.tensor([m in shared_set for m in marker_list], dtype=torch.bool, device=device)
        is_unique_col = torch.tensor([m in unique_set for m in marker_list], dtype=torch.bool, device=device)

        # A marker is valid to mask if it matches the group AND is present in the sample
        is_overlap = is_shared_col.unsqueeze(0) & pad_mask
        is_unique = is_unique_col.unsqueeze(0) & pad_mask

        # Number of markers to mask for the overlap and unique sets
        n_to_mask_overlap = sample_n_to_mask(is_overlap.sum(dim=1), max_mask_rate_overlap)
        n_to_mask_unique = sample_n_to_mask(is_unique.sum(dim=1), max_mask_rate_unique)

        # Get masks and combine them (clones: the helper edits scores in place)
        mask_overlap = get_mask_for_top_k(rand_scores.clone(), is_overlap, n_to_mask_overlap)
        mask_unique = get_mask_for_top_k(rand_scores.clone(), is_unique, n_to_mask_unique)
        final_mask = mask_overlap | mask_unique

    else:  # Standard masking mode
        n_to_mask = sample_n_to_mask(pad_mask.sum(dim=1), max_mask_rate_overlap)
        final_mask = get_mask_for_top_k(rand_scores, pad_mask, n_to_mask)

    # Apply the final mask: Set masked expression values to 0.0
    batch_x_clone[final_mask] = 0.0

    return {
        'x': batch_x_clone,
        'batch': batch['batch'].clone(),
        'pad': batch['pad'].clone(),
        'mask': final_mask | ~batch['pad'].bool()
    }
250
+
251
+
252
def loss_function(recon_x, x, pad, mu, var, kl_weight):
    """
    Calculates the VAE loss, combining reconstruction and KL divergence.

    Args:
        recon_x (torch.Tensor): The reconstructed data from the model.
        x (torch.Tensor): The original input data.
        pad (torch.Tensor): Mask with the same shape as ``x``; True (or
                            non-zero) marks non-padded values that are
                            included in the reconstruction loss.
        mu (torch.Tensor): The mean of the latent distribution.
        var (torch.Tensor): The variance of the latent distribution.
        kl_weight (float): A weight factor for the KL divergence term, often
                           used for annealing.

    Returns:
        torch.Tensor: The total calculated loss.
    """
    # Boolean indexing needs a bool mask; accept 0/1 float masks as well.
    pad = pad.bool()

    # Reconstruction loss (Huber loss) on non-padded values
    if pad.any():
        recon_loss = F.huber_loss(recon_x[pad], x[pad])
    else:
        # Averaging over an empty selection would give NaN; use a zero that is
        # still attached to the graph so backward() keeps working.
        recon_loss = recon_x.sum() * 0.0

    # KL divergence between the learned distribution and a standard normal
    prior = Normal(torch.zeros_like(mu), torch.ones_like(var))
    kl_div = kl(Normal(mu, var.sqrt()), prior).sum(dim=1).mean()
    kl_loss = kl_div * kl_weight

    return recon_loss + kl_loss
cystainer/runner.py ADDED
@@ -0,0 +1,318 @@
1
+ import os
2
+ import torch
3
+ from torch.utils.data import RandomSampler
4
+ import pandas as pd
5
+ import anndata as ad
6
+ import numpy as np
7
+ from tqdm import tqdm
8
+
9
+ from .modules import CyStainerModel, mask_batch, loss_function
10
+ from .data import load_data_from_folder
11
+
12
def get_default_device():
    """Return the preferred torch device: CUDA if available, else MPS, else CPU."""
    if torch.cuda.is_available():
        return torch.device('cuda')

    # Older torch builds have no ``torch.backends.mps`` attribute at all.
    mps_backend = getattr(torch.backends, 'mps', None)
    if mps_backend is not None and mps_backend.is_available():
        return torch.device('mps')

    return torch.device('cpu')
19
+
20
class CyStainer:
    """High-level wrapper that loads data and trains, fine-tunes, runs and saves a CyStainerModel."""

    def __init__(self, marker_info=None, batch_dict=None, model=None, device=None, **model_kwargs):
        """
        Initializes the CyStainer wrapper. Can be initialized empty and populated later
        via .load_train_data() and .build_model().

        Args:
            marker_info (dict): Marker description with 'marker_list',
                'shared_markers' and 'unique_markers'.
            batch_dict (dict): Batch name -> index mapping.
            model: An existing model; when given, its ``marker_info`` and
                ``batch_dict`` replace the arguments above.
            device: Target device; auto-detected when omitted.
            **model_kwargs: Forwarded to ``build_model`` when a model is built
                from ``marker_info`` and ``batch_dict``.
        """
        self.device = device or get_default_device()

        self.marker_info = marker_info
        self.batch_dict = batch_dict
        # Set by load_finetune_data(); initialised here so finetune() can
        # report missing data cleanly instead of failing on attribute lookup.
        self.ft_batch_dict = None
        self.train_dataloader = None
        self.finetune_dataloader = None
        self.predict_dataloader = None
        self.model = None

        if model is not None:
            self.model = model.to(self.device)
            self.marker_info = model.marker_info
            self.batch_dict = model.batch_dict
        elif marker_info is not None and batch_dict is not None:
            self.build_model(**model_kwargs)

    @staticmethod
    def _first_not_none(*candidates):
        """Return the first candidate that is not None (or None if all are).

        Used instead of ``a or b`` because a DataLoader with zero batches is
        falsy and would be skipped by ``or``.
        """
        for candidate in candidates:
            if candidate is not None:
                return candidate
        return None

    def __getstate__(self):
        """
        Dictates what gets saved when torch.save(stainer) is called.
        We create a copy of the object's dictionary and remove the dataloaders.
        """
        state = self.__dict__.copy()
        state['train_dataloader'] = None
        state['finetune_dataloader'] = None
        state['predict_dataloader'] = None
        return state

    def load_train_data(self, folder_path=None, adata_list=None, **kwargs):
        """
        Convenience method to load data from a folder or an AnnData list
        and populate the class state (dataloader, marker_info, batch_dict).

        Returns:
            CyStainer: ``self``, to allow chaining.
        """
        source_msg = folder_path if folder_path else "provided AnnData list"
        print(f"Loading data from {source_msg}...")

        self.train_dataloader, self.marker_info, self.batch_dict = load_data_from_folder(
            folder_path=folder_path,
            adata_list=adata_list,
            **kwargs
        )
        print(f"Data loaded. Found {len(self.marker_info['marker_list'])} markers and {len(self.batch_dict)} batches.")
        return self

    def load_finetune_data(self, folder_path=None, adata_list=None, **kwargs):
        """
        Loads a new dataset for fine-tuning, using the existing model's
        markers and batches as a reference to align the new data.

        Returns:
            CyStainer: ``self``, to allow chaining.

        Raises:
            RuntimeError: If no marker/batch state is available yet.
        """
        if self.marker_info is None or self.batch_dict is None:
            raise RuntimeError("Base model state not found. Train or load a base model first before loading fine-tuning data.")

        source_msg = folder_path if folder_path else "provided AnnData list"
        print(f"Loading fine-tuning data from {source_msg}...")

        self.finetune_dataloader, _, self.ft_batch_dict = load_data_from_folder(
            folder_path=folder_path,
            adata_list=adata_list,
            reference_markers=self.marker_info,
            reference_batch_dict=self.batch_dict,
            **kwargs
        )

        ft_batches_count = len(self.ft_batch_dict) - len(self.batch_dict)
        print(f"Fine-tuning data loaded. Found {ft_batches_count} new batches to integrate.")
        return self

    def load_predict_data(self, folder_path=None, adata_list=None, **kwargs):
        """
        Loads data specifically for inference. Forces is_train=False to
        ensure the predicted cells maintain their original order.

        Returns:
            CyStainer: ``self``, to allow chaining.

        Raises:
            RuntimeError: If no marker/batch state is available yet.
        """
        if self.marker_info is None or self.batch_dict is None:
            raise RuntimeError("Model state missing. Train or load a model first.")

        source_msg = folder_path if folder_path else "provided AnnData list"
        print(f"Loading prediction data from {source_msg}...")

        # Force is_train to False to prevent shuffling!
        kwargs['is_train'] = False

        self.predict_dataloader, _, _ = load_data_from_folder(
            folder_path=folder_path,
            adata_list=adata_list,
            reference_markers=self.marker_info,
            reference_batch_dict=self.batch_dict,
            **kwargs
        )
        print("Prediction data loaded successfully.")
        return self

    def build_model(self, **model_kwargs):
        """
        Initializes the PyTorch model using the state extracted from load_train_data.

        Returns:
            CyStainer: ``self``, to allow chaining.

        Raises:
            RuntimeError: If marker_info or batch_dict is missing.
        """
        if self.marker_info is None or self.batch_dict is None:
            raise RuntimeError("Cannot build model: marker_info and batch_dict are missing. Run .load_train_data() first.")

        self.model = CyStainerModel(
            marker_info=self.marker_info,
            batch_dict=self.batch_dict,
            **model_kwargs
        ).to(self.device)
        print("Model initialized successfully.")
        return self

    def train(self, dataloader=None, max_epochs=300, lr=0.0001, kl_weight=0.0001, tol=1e-4, patience=3,
              model_name='cystainer.pt', loss_name='loss.csv', metrics_path='./metrics_and_loss', separated_mask_rate=True):
        """Trains the model from scratch with early stopping.

        Raises:
            RuntimeError: If the model has not been built.
            ValueError: If no dataloader is available.
        """
        if self.model is None:
            raise RuntimeError("Model is not built. Call .build_model() before training.")

        train_loader = self._first_not_none(dataloader, self.train_dataloader)
        if train_loader is None:
            raise ValueError("No dataloader found. Please provide one or run .load_train_data() first.")

        optimizer = torch.optim.AdamW(self.model.parameters(), lr=lr)
        self._run_loop(train_loader, optimizer, max_epochs, kl_weight, tol, patience, model_name, loss_name, metrics_path, separated_mask_rate)

    def finetune(self, ft_loader=None, ft_batch_dict=None, max_epochs=50, lr=0.001, kl_weight=0.0001, tol=1e-4, patience=3,
                 model_name='cystainer_finetuned.pt', loss_name='loss_finetuned.csv', metrics_path='./metrics_and_loss', separated_mask_rate=True):
        """Fine-tunes only the batch embeddings for new data with early stopping.

        All parameters except the batch embedding are frozen, and they stay
        frozen after this method returns. Embedding rows of previously known
        batches are kept unchanged.

        Raises:
            RuntimeError: If no model is available.
            ValueError: If the fine-tuning dataloader or batch dict is missing.
        """
        if self.model is None:
            raise RuntimeError("Model is not trained or loaded.")

        # Fallback to internal state if arguments aren't provided. getattr keeps
        # objects created before `ft_batch_dict` existed (old pickles) working.
        ft_loader = self._first_not_none(ft_loader, self.finetune_dataloader)
        ft_batch_dict = ft_batch_dict or getattr(self, 'ft_batch_dict', None)

        if ft_loader is None or ft_batch_dict is None:
            raise ValueError("Missing fine-tuning data. Run .load_finetune_data() first or pass the dataloader and new_batch_dict explicitly.")

        # Calculate how many new batches were added
        ft_batches_count = len(ft_batch_dict) - len(self.batch_dict)
        if ft_batches_count > 0:
            self.model.add_new_batches(ft_batches_count)
            self.batch_dict = ft_batch_dict
            # Keep the model's own bookkeeping in step, so a saved model (or
            # CyStainer(model=...)) sees the expanded batch mapping.
            self.model.batch_dict = ft_batch_dict
            self.model.num_batches = len(ft_batch_dict)

        # Freeze all parameters except batch_embedding
        for param in self.model.parameters():
            param.requires_grad = False
        self.model.batch_embedding.weight.requires_grad = True

        # Calculate the index where new batches start
        old_num_embeddings = len(self.batch_dict) - ft_batches_count

        # Backward hook that zeroes the gradient of the old batch embeddings
        def freeze_old_embeddings(grad):
            grad_clone = grad.clone()
            grad_clone[:old_num_embeddings] = 0.0
            return grad_clone

        hook_handle = self.model.batch_embedding.weight.register_hook(freeze_old_embeddings)

        # weight_decay must be 0: AdamW's decoupled decay shrinks every row on
        # each step regardless of its gradient, which would slowly alter the
        # old embeddings that are supposed to stay frozen.
        optimizer = torch.optim.AdamW(self.model.batch_embedding.parameters(), lr=lr, weight_decay=0.0)
        try:
            self._run_loop(ft_loader, optimizer, max_epochs, kl_weight, tol, patience, model_name, loss_name, metrics_path, separated_mask_rate)
        finally:
            # Do not let hooks pile up across repeated finetune() calls.
            hook_handle.remove()

    def predict(self, dataloader=None, output_path='adata_pred.h5ad', keep_markers=None, correct_batch=False, target_batch_name=None, return_pred=False):
        """Runs inference and saves the predicted AnnData object.

        Args:
            dataloader: Loader to predict on. Falls back to the predict, then
                fine-tune, then train loader stored on the object.
            output_path (str): Where the AnnData is written when
                ``return_pred`` is False.
            keep_markers (list): Subset of markers kept in the output.
            correct_batch (bool): Decode every cell as if it belonged to the
                target batch.
            target_batch_name: Name of the target batch. When omitted with
                ``correct_batch=True``, batch index 0 is used.
            return_pred (bool): Return the DataFrame instead of writing a file.

        Returns:
            pandas.DataFrame or None: The predictions when ``return_pred`` is True.

        Raises:
            RuntimeError: If no model is available or the loader shuffles.
            ValueError: If no loader is available or the batch name is unknown.
        """
        if self.model is None:
            raise RuntimeError("Model is not built or loaded.")

        # Fallback logic: User provided -> Predict loader -> Finetune loader -> Train loader
        pred_loader = self._first_not_none(
            dataloader, self.predict_dataloader, self.finetune_dataloader, self.train_dataloader
        )

        if pred_loader is None:
            raise ValueError("No dataloader found. Pass one explicitly or run .load_predict_data() first.")

        # Block predictions on shuffled dataloaders
        if isinstance(pred_loader.sampler, RandomSampler):
            raise RuntimeError(
                "CRITICAL: You are attempting to predict on a shuffled DataLoader! "
                "The output rows will NOT correspond to the original order of your cells in the AnnData object. "
                "You must load the data using .load_predict_data() before running .predict()."
            )

        target_batch_int = 0
        if correct_batch and target_batch_name is not None:
            if target_batch_name not in self.batch_dict:
                raise ValueError(f"Batch name '{target_batch_name}' not found. Available batches: {list(self.batch_dict.keys())}")
            target_batch_int = self.batch_dict[target_batch_name]

        self.model.eval()
        result = []
        progress_desc = "Predicting (Batch Corrected)" if correct_batch else "Predicting"

        with torch.no_grad():
            for batch in tqdm(pred_loader, desc=progress_desc):
                batch = {key: value.to(self.device) for key, value in batch.items()}
                # The model receives True for positions that were NOT measured.
                padding_mask = ~batch['pad'].bool()
                if correct_batch:
                    target_batch = torch.full_like(batch['batch'], fill_value=target_batch_int)
                    x_pred = self.model.correct_batch(batch['x'], batch['batch'], target_batch, padding_mask)
                else:
                    x_pred, _, _, _ = self.model.forward(batch['x'], batch['batch'], padding_mask)
                result.append(x_pred.cpu().numpy())

        df_pred = pd.DataFrame(np.concatenate(result), columns=self.marker_info['marker_list'])
        if keep_markers is not None:
            df_pred = df_pred[keep_markers]
        if return_pred:
            return df_pred
        else:
            ad.AnnData(df_pred).write(output_path, compression='gzip')
            print(f"Predictions saved to {output_path}")

    def save(self, losses=None, model_name=None, loss_name=None, metrics_path=None):
        """Saves the training losses and the entire CyStainer object.

        Args:
            losses (list): Per-epoch losses; written as CSV when given.
            model_name (str): Destination file of the pickled object (required).
            loss_name (str): CSV file name; defaults to 'loss.csv'.
            metrics_path (str): Directory of the CSV; defaults to the current
                directory. Created if missing.

        Raises:
            ValueError: If ``model_name`` is None.
        """
        if model_name is None:
            raise ValueError("model_name must be provided to save the CyStainer object.")

        # Save losses if provided
        if losses is not None:
            metrics_dir = metrics_path if metrics_path is not None else '.'
            loss_file = loss_name if loss_name is not None else 'loss.csv'
            os.makedirs(metrics_dir, exist_ok=True)
            df_loss = pd.DataFrame(losses, columns=['loss'])
            df_loss.to_csv(os.path.join(metrics_dir, loss_file), index=False)

        torch.save(self, model_name)

    def _run_loop(self, dataloader, optimizer, max_epochs, kl_weight, tol, patience, model_name, loss_name, metrics_path, separated_mask_rate):
        """Internal training loop with early stopping mechanism.

        The object is saved after every epoch. Training stops once the change
        of the mean epoch loss stays below ``tol`` for ``patience`` consecutive
        epochs.

        Raises:
            ValueError: If the dataloader yields no batches.
        """
        self.model.train()
        pbar = tqdm(range(max_epochs))

        prev_loss = float('inf')
        patience_counter = 0
        losses = []

        for epoch in pbar:
            epoch_losses = []

            for batch in dataloader:
                batch = {key: value.to(self.device) for key, value in batch.items()}
                batch_masked = mask_batch(
                    batch,
                    self.marker_info['marker_list'],
                    separated_mask_rate=separated_mask_rate,
                    shared_markers=self.marker_info['shared_markers'],
                    unique_markers=self.marker_info['unique_markers']
                )

                # Predict from the masked input, score against the unmasked one.
                x_pred, mu, var, _ = self.model.forward(batch_masked['x'], batch_masked['batch'], batch_masked['mask'].bool())
                loss = loss_function(x_pred, batch['x'], batch['pad'].bool(), mu, var, kl_weight)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                epoch_losses.append(loss.item())

            if not epoch_losses:
                # A mean over nothing is NaN, which would never trigger early
                # stopping and silently spin through all epochs.
                raise ValueError("The dataloader produced no batches; there is nothing to train on.")

            # Early Stopping Logic
            avg_loss = np.mean(epoch_losses)
            losses.append(avg_loss)
            pbar.set_description(f'Loss: {avg_loss:.4f}')

            self.save(losses=losses, model_name=model_name, loss_name=loss_name, metrics_path=metrics_path)

            loss_diff = abs(prev_loss - avg_loss)

            if loss_diff < tol:
                patience_counter += 1
            else:
                patience_counter = 0

            if patience_counter >= patience:
                # We use tqdm.write so it doesn't break the progress bar visual
                tqdm.write(f"\nEarly stopping triggered at epoch {epoch + 1}. Loss change ({loss_diff:.6f}) is below tolerance ({tol}).")
                break

            prev_loss = avg_loss

    @classmethod
    def load(cls, filepath, device=None):
        """
        Loads a saved CyStainer object, mapping all tensors and internal
        states to the requested device (CPU, GPU, or MPS).

        SECURITY: the file is unpickled (``weights_only=False``), which can
        execute arbitrary code. Only load files from sources you trust.
        """
        # Determine the target device
        target_device = device or get_default_device()

        print(f"Loading model from {filepath} onto {target_device}...")

        # 1. Load the object, intercepting tensors and forcing them to the target device
        stainer = torch.load(filepath, map_location=target_device, weights_only=False)

        # 2. Update the internal device attribute so future dataloaders route correctly
        stainer.device = target_device

        # 3. Explicitly move the PyTorch model to the new device just to be perfectly safe
        if stainer.model is not None:
            stainer.model = stainer.model.to(target_device)

        print("Model loaded successfully.")
        return stainer
@@ -0,0 +1,170 @@
1
+ Metadata-Version: 2.4
2
+ Name: cystainer
3
+ Version: 0.1.0
4
+ Summary: A PyTorch-based tool for predicting missing proteins in cytometry/single-cell data
5
+ Author: Konstantin Ivanov
6
+ Author-email: kivanov@uef.fi
7
+ Classifier: Programming Language :: Python :: 3
8
+ Classifier: License :: OSI Approved :: MIT License
9
+ Classifier: Operating System :: OS Independent
10
+ Requires-Python: >=3.11.9
11
+ Description-Content-Type: text/markdown
12
+ License-File: LICENSE
13
+ Requires-Dist: torch>=2.7.1
14
+ Requires-Dist: numpy>=1.26.4
15
+ Requires-Dist: pandas>=2.2.3
16
+ Requires-Dist: anndata>=0.12.4
17
+ Requires-Dist: scipy>=1.14.1
18
+ Requires-Dist: scikit-learn>=1.5.1
19
+ Requires-Dist: tqdm>=4.66.5
20
+ Requires-Dist: matplotlib>=3.9.2
21
+ Dynamic: author
22
+ Dynamic: author-email
23
+ Dynamic: classifier
24
+ Dynamic: description
25
+ Dynamic: description-content-type
26
+ Dynamic: license-file
27
+ Dynamic: requires-dist
28
+ Dynamic: requires-python
29
+ Dynamic: summary
30
+
31
+ # CyStainer
32
+ CyStainer package for cytometry marker imputation
33
+
34
+ **CyStainer** is a PyTorch-based deep learning tool for predicting missing proteins and imputing marker expression in cytometry and single-cell data. It utilizes a combination of Variational Autoencoders (VAEs) and Transformer architectures to integrate multiple batches/panels and infer missing markers accurately.
35
+
36
+ ## 📦 Installation
37
+
38
+ Since CyStainer is built on PyTorch, ensure you have an environment with Python 3.11.9 or newer and a PyTorch build (2.7.1 or newer) appropriate for your hardware (CUDA recommended).
39
+
40
+ You can install CyStainer directly from the source:
41
+
42
+ ```bash
43
+ git clone https://github.com/sysgen-uef/cystainer_package.git
44
+ cd cystainer_package
45
+ pip install .
46
+ ```
47
+ ## 🧹 Data Preprocessing (From .fcs to .h5ad)
48
+
49
+ **CyStainer** expects your input data to be formatted as `AnnData` objects (either passed as a list in memory or saved locally as `.h5ad` files). Before feeding your cytometry data into the model, it must be properly preprocessed.
50
+
51
+ **Mandatory Pre-processing:**
52
+ * **Cleaning:** Ensure your data is pre-gated to remove doublets, debris, and dead cells.
53
+ * **Compensation:** Your `.fcs` files should already be compensated.
54
+
55
+ **Transformation & Scaling (Recommended):**
56
+ You will typically need to transform your fluorescence intensities (e.g., using an arcsinh transformation) and optionally scale them. Removing saturated values (extreme highs or zeros) can also improve model performance depending on your dataset.
57
+
58
+ Here is a minimal example of how to process a standard `.fcs` file into a ready-to-use `.h5ad` file using `FlowKit` and `AnnData`:
59
+
60
+ ```python
61
+ import flowkit as fk
62
+ import pandas as pd
63
+ import numpy as np
64
+ import anndata as ad
65
+
66
+ # Load the compensated and cleaned .fcs file
67
+ sample = fk.Sample('path/to/cleaned_sample.fcs')
68
+ df = sample.as_dataframe('raw')
69
+ df.columns = sample.pnn_labels # Set marker names
70
+
71
+ # Transformation (e.g., arcsinh with a cofactor of 100 or 150)
72
+ cofactor = 100
73
+ non_scatter = ['FSC' not in c and 'SSC' not in c for c in df.columns]
74
+ df.loc[:, non_scatter] = np.arcsinh(df.loc[:, non_scatter] / cofactor)
75
+
76
+ # Optional: Remove saturation / extreme outliers
77
+ # Useful if your instrument records artificial bounds (e.g., exactly 0 or max value)
78
+ df = df[~np.any(((df <= 0) | (df >= df.max().max())), axis=1)]
79
+
80
+ # Optional: Scaling (Min-Max scaling to [0, 1] or Z-score normalization)
81
+ # Min-Max Scaling example:
82
+ df = (df - df.min()) / (df.max() - df.min())
83
+ # Z-score Scaling example (alternatively):
84
+ # df = (df - df.mean()) / df.std()
85
+
86
+ # Convert to AnnData and save for CyStainer
87
+ adata = ad.AnnData(df)
88
+ adata.write('preprocessed_sample.h5ad', compression='gzip')
89
+ ```
90
+
91
+ Once your `.fcs` files are converted into `.h5ad` objects, you can load them directly into your workflow.
92
+
93
+ ## 🚀 Quick Start Guide
94
+
95
+ The primary way to interact with the package is through the `CyStainer` wrapper class. The standard workflow consists of initializing the stainer, loading training data, building the network, training it, and running inference.
96
+
97
+ ### 1. Training a Base Model
98
+
99
+ You can load your data either from a folder of `.h5ad` files or by passing a list of `AnnData` objects directly in memory.
100
+
101
+ ```python
102
+ from cystainer import CyStainer
103
+
104
+ # Initialize the stainer (automatically detects CUDA/CPU)
105
+ stainer = CyStainer()
106
+
107
+ # Load training data
108
+ # Alternatively, use: adata_list=[adata1, adata2]
109
+ # CyStainer includes a built-in utility to visualize how markers overlap across different panels or batches.
110
+ stainer.load_train_data(folder_path='./data_example/train', get_panel_vis=True)
111
+ ```
112
+ ![image](image/panel_alignment.png)
113
+
114
+ ```python
115
+ # Build the model
116
+ # You can pass custom hyperparameters here if needed
117
+ stainer.build_model()
118
+
119
+ # Train the model
120
+ stainer.train()
121
+
122
+ # By default, the model is automatically saved in the same directory
123
+ # under the file name cystainer.pt
124
+ ```
125
+
126
+ ### 2. Predicting Missing Markers
127
+
128
+ Once a model is trained, you can load inference data. The `.load_predict_data()` method ensures your cells are not shuffled so that the output matches your input order.
129
+
130
+ ```python
131
+ # Load prediction data using the base model's reference markers
132
+ stainer.load_predict_data(folder_path='./data_example/test')
133
+
134
+ # Run predictions and save directly to disk
135
+ stainer.predict(output_path='imputed_cells.h5ad')
136
+
137
+ # Alternatively, return the predictions as a pandas DataFrame:
138
+ # df_imputed = stainer.predict(return_pred=True)
139
+ ```
140
+
141
+ ### 3. Fine-Tuning on New Batches
142
+
143
+ If you receive new data from a different batch or panel, you don't need to retrain from scratch. CyStainer can freeze the main network and fine-tune only the batch embeddings to align the new data.
144
+
145
+ ```python
146
+ # Load the pre-trained model
147
+ stainer = CyStainer.load('cystainer.pt')
148
+
149
+ # Load the new data for fine-tuning
150
+ # Note: the AnnData objects must have a batch info column
151
+ stainer.load_finetune_data(folder_path='./data_example/test', batch_info='batch')
152
+
153
+ # Fine-tune the batch embeddings
154
+ stainer.finetune()
155
+
156
+ # Predict on the newly fine-tuned data
157
+ stainer.predict(output_path='imputed_new_batch.h5ad')
158
+ ```
159
+
160
+ ### 4. Batch Correction
161
+
162
+ CyStainer allows you to translate cells from their original batch to a specific target batch distribution.
163
+
164
+ ```python
165
+ stainer.predict(
166
+ output_path='batch_corrected_cells.h5ad',
167
+ correct_batch=True,
168
+ target_batch_name='single_batch' # Must match a batch name in stainer.batch_dict
169
+ )
170
+ ```
@@ -0,0 +1,9 @@
1
+ cystainer/__init__.py,sha256=Yv_K9NHhvtxenToMvf3FadTmDS1YSmBwkbjBnkwewJM,301
2
+ cystainer/data.py,sha256=xS5-XX1NguRH4Mxa4otMjwzE00HlhJzwooz43aPD1gw,8321
3
+ cystainer/modules.py,sha256=rKZMW_-7yYJXSEJlSjTVtsY4cyFqT8_qwoXVHVfWrco,11195
4
+ cystainer/runner.py,sha256=CaEsn50m-v3YnyTSgBvHacEG7DbUA5ScsKIXrMkJv7Q,14413
5
+ cystainer-0.1.0.dist-info/licenses/LICENSE,sha256=0qOacCooGqkP7CL7qDsSOyMSmGtEb1s_Q1oSoZSrLxs,1067
6
+ cystainer-0.1.0.dist-info/METADATA,sha256=sSH7f70qUwymDJXee7Bq60cZFETGyfSRE2L2rCPmy3Q,6211
7
+ cystainer-0.1.0.dist-info/WHEEL,sha256=aeYiig01lYGDzBgS8HxWXOg3uV61G9ijOsup-k9o1sk,91
8
+ cystainer-0.1.0.dist-info/top_level.txt,sha256=8kyy2H_PFxOOr53It06GAWAyUNDGd3Tw1ZThnhaSPcg,10
9
+ cystainer-0.1.0.dist-info/RECORD,,
@@ -0,0 +1,5 @@
1
+ Wheel-Version: 1.0
2
+ Generator: setuptools (82.0.1)
3
+ Root-Is-Purelib: true
4
+ Tag: py3-none-any
5
+
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2026 Sysgen lab
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
@@ -0,0 +1 @@
1
+ cystainer