cystainer 0.1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cystainer-0.1.0/LICENSE +21 -0
- cystainer-0.1.0/PKG-INFO +170 -0
- cystainer-0.1.0/README.md +140 -0
- cystainer-0.1.0/cystainer/__init__.py +16 -0
- cystainer-0.1.0/cystainer/data.py +204 -0
- cystainer-0.1.0/cystainer/modules.py +277 -0
- cystainer-0.1.0/cystainer/runner.py +318 -0
- cystainer-0.1.0/cystainer.egg-info/PKG-INFO +170 -0
- cystainer-0.1.0/cystainer.egg-info/SOURCES.txt +12 -0
- cystainer-0.1.0/cystainer.egg-info/dependency_links.txt +1 -0
- cystainer-0.1.0/cystainer.egg-info/requires.txt +8 -0
- cystainer-0.1.0/cystainer.egg-info/top_level.txt +1 -0
- cystainer-0.1.0/setup.cfg +4 -0
- cystainer-0.1.0/setup.py +32 -0
cystainer-0.1.0/LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
MIT License
|
|
2
|
+
|
|
3
|
+
Copyright (c) 2026 Sysgen lab
|
|
4
|
+
|
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
7
|
+
in the Software without restriction, including without limitation the rights
|
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
10
|
+
furnished to do so, subject to the following conditions:
|
|
11
|
+
|
|
12
|
+
The above copyright notice and this permission notice shall be included in all
|
|
13
|
+
copies or substantial portions of the Software.
|
|
14
|
+
|
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
21
|
+
SOFTWARE.
|
cystainer-0.1.0/PKG-INFO
ADDED
|
@@ -0,0 +1,170 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: cystainer
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: A PyTorch-based tool for predicting missing proteins in cytometry/single-cell data
|
|
5
|
+
Author: Konstantin Ivanov
|
|
6
|
+
Author-email: kivanov@uef.fi
|
|
7
|
+
Classifier: Programming Language :: Python :: 3
|
|
8
|
+
Classifier: License :: OSI Approved :: MIT License
|
|
9
|
+
Classifier: Operating System :: OS Independent
|
|
10
|
+
Requires-Python: >=3.11.9
|
|
11
|
+
Description-Content-Type: text/markdown
|
|
12
|
+
License-File: LICENSE
|
|
13
|
+
Requires-Dist: torch>=2.7.1
|
|
14
|
+
Requires-Dist: numpy>=1.26.4
|
|
15
|
+
Requires-Dist: pandas>=2.2.3
|
|
16
|
+
Requires-Dist: anndata>=0.12.4
|
|
17
|
+
Requires-Dist: scipy>=1.14.1
|
|
18
|
+
Requires-Dist: scikit-learn>=1.5.1
|
|
19
|
+
Requires-Dist: tqdm>=4.66.5
|
|
20
|
+
Requires-Dist: matplotlib>=3.9.2
|
|
21
|
+
Dynamic: author
|
|
22
|
+
Dynamic: author-email
|
|
23
|
+
Dynamic: classifier
|
|
24
|
+
Dynamic: description
|
|
25
|
+
Dynamic: description-content-type
|
|
26
|
+
Dynamic: license-file
|
|
27
|
+
Dynamic: requires-dist
|
|
28
|
+
Dynamic: requires-python
|
|
29
|
+
Dynamic: summary
|
|
30
|
+
|
|
31
|
+
# CyStainer
|
|
32
|
+
CyStainer package for cytometry marker imputation
|
|
33
|
+
|
|
34
|
+
**CyStainer** is a PyTorch-based deep learning tool for predicting missing proteins and imputing marker expression in cytometry and single-cell data. It utilizes a combination of Variational Autoencoders (VAEs) and Transformer architectures to integrate multiple batches/panels and infer missing markers accurately.
|
|
35
|
+
|
|
36
|
+
## ๐ฆ Installation
|
|
37
|
+
|
|
38
|
+
Since CyStainer is built on PyTorch, ensure you have an environment with Python 3.11+ and the appropriate PyTorch version for your hardware (CUDA recommended).
|
|
39
|
+
|
|
40
|
+
You can install CyStainer directly from the source:
|
|
41
|
+
|
|
42
|
+
```bash
|
|
43
|
+
git clone https://github.com/sysgen-uef/cystainer_package.git
|
|
44
|
+
cd cystainer_package
|
|
45
|
+
pip install .
|
|
46
|
+
```
|
|
47
|
+
## ๐งน Data Preprocessing (From .fcs to .h5ad)
|
|
48
|
+
|
|
49
|
+
**CyStainer** expects your input data to be formatted as `AnnData` objects (either passed as a list in memory or saved locally as `.h5ad` files). Before feeding your cytometry data into the model, it must be properly preprocessed.
|
|
50
|
+
|
|
51
|
+
**Mandatory Pre-processing:**
|
|
52
|
+
* **Cleaning:** Ensure your data is pre-gated to remove doublets, debris, and dead cells.
|
|
53
|
+
* **Compensation:** Your `.fcs` files should be already compensated.
|
|
54
|
+
|
|
55
|
+
**Transformation & Scaling (Recommended):**
|
|
56
|
+
You will typically need to transform your fluorescence intensities (e.g., using an arcsinh transformation) and optionally scale them. Removing saturated values (extreme highs or zeros) can also improve model performance depending on your dataset.
|
|
57
|
+
|
|
58
|
+
Here is a minimal example of how to process a standard `.fcs` file into a ready-to-use `.h5ad` file using `FlowKit` and `AnnData`:
|
|
59
|
+
|
|
60
|
+
```python
|
|
61
|
+
import flowkit as fk
|
|
62
|
+
import pandas as pd
|
|
63
|
+
import numpy as np
|
|
64
|
+
import anndata as ad
|
|
65
|
+
|
|
66
|
+
# Load the compensated and cleaned .fcs file
|
|
67
|
+
sample = fk.Sample('path/to/cleaned_sample.fcs')
|
|
68
|
+
df = sample.as_dataframe('raw')
|
|
69
|
+
df.columns = sample.pnn_labels # Set marker names
|
|
70
|
+
|
|
71
|
+
# Transformation (e.g., arcsinh with a cofactor of 100 or 150)
|
|
72
|
+
cofactor = 100
|
|
73
|
+
non_scatter = ['FSC' not in c and 'SSC' not in c for c in df.columns]
|
|
74
|
+
df.loc[:, non_scatter] = np.arcsinh(df.loc[:, non_scatter] / cofactor)
|
|
75
|
+
|
|
76
|
+
# Optional: Remove saturation / extreme outliers
|
|
77
|
+
# Useful if your instrument records artificial bounds (e.g., exactly 0 or max value)
|
|
78
|
+
df = df[~np.any(((df <= 0) | (df >= df.max().max())), axis=1)]
|
|
79
|
+
|
|
80
|
+
# Optional: Scaling (Min-Max scaling to [0, 1] or Z-score normalization)
|
|
81
|
+
# Min-Max Scaling example:
|
|
82
|
+
df = (df - df.min()) / (df.max() - df.min())
|
|
83
|
+
# Z-score Scaling example (alternatively):
|
|
84
|
+
# df = (df - df.mean()) / df.std()
|
|
85
|
+
|
|
86
|
+
# Convert to AnnData and save for CyStainer
|
|
87
|
+
adata = ad.AnnData(df)
|
|
88
|
+
adata.write('preprocessed_sample.h5ad', compression='gzip')
|
|
89
|
+
```
|
|
90
|
+
|
|
91
|
+
Once your `.fcs` files are converted into `.h5ad` objects, you can load them directly into your workflow.
|
|
92
|
+
|
|
93
|
+
## ๐ Quick Start Guide
|
|
94
|
+
|
|
95
|
+
The primary way to interact with the package is through the `CyStainer` wrapper class. The standard workflow consists of initializing the model, loading training data, building the network, and running inference.
|
|
96
|
+
|
|
97
|
+
### 1. Training a Base Model
|
|
98
|
+
|
|
99
|
+
You can load your data either from a folder of `.h5ad` files or by passing a list of `AnnData` objects directly in memory.
|
|
100
|
+
|
|
101
|
+
```python
|
|
102
|
+
from cystainer import CyStainer
|
|
103
|
+
|
|
104
|
+
# Initialize the stainer (automatically detects CUDA/CPU)
|
|
105
|
+
stainer = CyStainer()
|
|
106
|
+
|
|
107
|
+
# Load training data
|
|
108
|
+
# Alternatively, use: adata_list=[adata1, adata2]
|
|
109
|
+
# CyStainer includes a built-in utility to visualize how markers overlap across different panels or batches.
|
|
110
|
+
stainer.load_train_data(folder_path='./data_example/train', get_panel_vis=True)
|
|
111
|
+
```
|
|
112
|
+

|
|
113
|
+
|
|
114
|
+
```python
|
|
115
|
+
# Build the model
|
|
116
|
+
# You can pass custom hyperparameters here if needed
|
|
117
|
+
stainer.build_model()
|
|
118
|
+
|
|
119
|
+
# Train the model
|
|
120
|
+
stainer.train()
|
|
121
|
+
|
|
122
|
+
# By default the model is automatically saved in the same directory
|
|
123
|
+
# as cystainer.pt
|
|
124
|
+
```
|
|
125
|
+
|
|
126
|
+
### 2. Predicting Missing Markers
|
|
127
|
+
|
|
128
|
+
Once a model is trained, you can load inference data. The `.load_predict_data()` method ensures your cells are not shuffled so that the output matches your input order.
|
|
129
|
+
|
|
130
|
+
```python
|
|
131
|
+
# Load prediction data using the base model's reference markers
|
|
132
|
+
stainer.load_predict_data(folder_path='./data_example/test')
|
|
133
|
+
|
|
134
|
+
# Run predictions and save directly to disk
|
|
135
|
+
stainer.predict(output_path='imputed_cells.h5ad')
|
|
136
|
+
|
|
137
|
+
# Alternatively, return the predictions as a pandas DataFrame:
|
|
138
|
+
# df_imputed = stainer.predict(return_pred=True)
|
|
139
|
+
```
|
|
140
|
+
|
|
141
|
+
### 3. Fine-Tuning on New Batches
|
|
142
|
+
|
|
143
|
+
If you receive new data from a different batch or panel, you don't need to retrain from scratch. CyStainer can freeze the main network and fine-tune only the batch embeddings to align the new data.
|
|
144
|
+
|
|
145
|
+
```python
|
|
146
|
+
# Load the pre-trained model
|
|
147
|
+
stainer = CyStainer.load('cystainer.pt')
|
|
148
|
+
|
|
149
|
+
# Load the new data for fine-tuning
|
|
150
|
+
# Note, anndata objects must have batch info column
|
|
151
|
+
stainer.load_finetune_data(folder_path='./data_example/test', batch_info='batch')
|
|
152
|
+
|
|
153
|
+
# Fine-tune the batch embeddings
|
|
154
|
+
stainer.finetune()
|
|
155
|
+
|
|
156
|
+
# Predict on the newly fine-tuned data
|
|
157
|
+
stainer.predict(output_path='imputed_new_batch.h5ad')
|
|
158
|
+
```
|
|
159
|
+
|
|
160
|
+
### 4. Batch Correction
|
|
161
|
+
|
|
162
|
+
CyStainer allows you to translate cells from their original batch to a specific target batch distribution.
|
|
163
|
+
|
|
164
|
+
```python
|
|
165
|
+
stainer.predict(
|
|
166
|
+
output_path='batch_corrected_cells.h5ad',
|
|
167
|
+
correct_batch=True,
|
|
168
|
+
target_batch_name='single_batch' # Must match a batch name in stainer.batch_dict
|
|
169
|
+
)
|
|
170
|
+
```
|
|
@@ -0,0 +1,140 @@
|
|
|
1
|
+
# CyStainer
|
|
2
|
+
CyStainer package for cytometry marker imputation
|
|
3
|
+
|
|
4
|
+
**CyStainer** is a PyTorch-based deep learning tool for predicting missing proteins and imputing marker expression in cytometry and single-cell data. It utilizes a combination of Variational Autoencoders (VAEs) and Transformer architectures to integrate multiple batches/panels and infer missing markers accurately.
|
|
5
|
+
|
|
6
|
+
## ๐ฆ Installation
|
|
7
|
+
|
|
8
|
+
Since CyStainer is built on PyTorch, ensure you have an environment with Python 3.11+ and the appropriate PyTorch version for your hardware (CUDA recommended).
|
|
9
|
+
|
|
10
|
+
You can install CyStainer directly from the source:
|
|
11
|
+
|
|
12
|
+
```bash
|
|
13
|
+
git clone https://github.com/sysgen-uef/cystainer_package.git
|
|
14
|
+
cd cystainer_package
|
|
15
|
+
pip install .
|
|
16
|
+
```
|
|
17
|
+
## ๐งน Data Preprocessing (From .fcs to .h5ad)
|
|
18
|
+
|
|
19
|
+
**CyStainer** expects your input data to be formatted as `AnnData` objects (either passed as a list in memory or saved locally as `.h5ad` files). Before feeding your cytometry data into the model, it must be properly preprocessed.
|
|
20
|
+
|
|
21
|
+
**Mandatory Pre-processing:**
|
|
22
|
+
* **Cleaning:** Ensure your data is pre-gated to remove doublets, debris, and dead cells.
|
|
23
|
+
* **Compensation:** Your `.fcs` files should be already compensated.
|
|
24
|
+
|
|
25
|
+
**Transformation & Scaling (Recommended):**
|
|
26
|
+
You will typically need to transform your fluorescence intensities (e.g., using an arcsinh transformation) and optionally scale them. Removing saturated values (extreme highs or zeros) can also improve model performance depending on your dataset.
|
|
27
|
+
|
|
28
|
+
Here is a minimal example of how to process a standard `.fcs` file into a ready-to-use `.h5ad` file using `FlowKit` and `AnnData`:
|
|
29
|
+
|
|
30
|
+
```python
|
|
31
|
+
import flowkit as fk
|
|
32
|
+
import pandas as pd
|
|
33
|
+
import numpy as np
|
|
34
|
+
import anndata as ad
|
|
35
|
+
|
|
36
|
+
# Load the compensated and cleaned .fcs file
|
|
37
|
+
sample = fk.Sample('path/to/cleaned_sample.fcs')
|
|
38
|
+
df = sample.as_dataframe('raw')
|
|
39
|
+
df.columns = sample.pnn_labels # Set marker names
|
|
40
|
+
|
|
41
|
+
# Transformation (e.g., arcsinh with a cofactor of 100 or 150)
|
|
42
|
+
cofactor = 100
|
|
43
|
+
non_scatter = ['FSC' not in c and 'SSC' not in c for c in df.columns]
|
|
44
|
+
df.loc[:, non_scatter] = np.arcsinh(df.loc[:, non_scatter] / cofactor)
|
|
45
|
+
|
|
46
|
+
# Optional: Remove saturation / extreme outliers
|
|
47
|
+
# Useful if your instrument records artificial bounds (e.g., exactly 0 or max value)
|
|
48
|
+
df = df[~np.any(((df <= 0) | (df >= df.max().max())), axis=1)]
|
|
49
|
+
|
|
50
|
+
# Optional: Scaling (Min-Max scaling to [0, 1] or Z-score normalization)
|
|
51
|
+
# Min-Max Scaling example:
|
|
52
|
+
df = (df - df.min()) / (df.max() - df.min())
|
|
53
|
+
# Z-score Scaling example (alternatively):
|
|
54
|
+
# df = (df - df.mean()) / df.std()
|
|
55
|
+
|
|
56
|
+
# Convert to AnnData and save for CyStainer
|
|
57
|
+
adata = ad.AnnData(df)
|
|
58
|
+
adata.write('preprocessed_sample.h5ad', compression='gzip')
|
|
59
|
+
```
|
|
60
|
+
|
|
61
|
+
Once your `.fcs` files are converted into `.h5ad` objects, you can load them directly into your workflow.
|
|
62
|
+
|
|
63
|
+
## ๐ Quick Start Guide
|
|
64
|
+
|
|
65
|
+
The primary way to interact with the package is through the `CyStainer` wrapper class. The standard workflow consists of initializing the model, loading training data, building the network, and running inference.
|
|
66
|
+
|
|
67
|
+
### 1. Training a Base Model
|
|
68
|
+
|
|
69
|
+
You can load your data either from a folder of `.h5ad` files or by passing a list of `AnnData` objects directly in memory.
|
|
70
|
+
|
|
71
|
+
```python
|
|
72
|
+
from cystainer import CyStainer
|
|
73
|
+
|
|
74
|
+
# Initialize the stainer (automatically detects CUDA/CPU)
|
|
75
|
+
stainer = CyStainer()
|
|
76
|
+
|
|
77
|
+
# Load training data
|
|
78
|
+
# Alternatively, use: adata_list=[adata1, adata2]
|
|
79
|
+
# CyStainer includes a built-in utility to visualize how markers overlap across different panels or batches.
|
|
80
|
+
stainer.load_train_data(folder_path='./data_example/train', get_panel_vis=True)
|
|
81
|
+
```
|
|
82
|
+

|
|
83
|
+
|
|
84
|
+
```python
|
|
85
|
+
# Build the model
|
|
86
|
+
# You can pass custom hyperparameters here if needed
|
|
87
|
+
stainer.build_model()
|
|
88
|
+
|
|
89
|
+
# Train the model
|
|
90
|
+
stainer.train()
|
|
91
|
+
|
|
92
|
+
# By default the model is automatically saved in the same directory
|
|
93
|
+
# as cystainer.pt
|
|
94
|
+
```
|
|
95
|
+
|
|
96
|
+
### 2. Predicting Missing Markers
|
|
97
|
+
|
|
98
|
+
Once a model is trained, you can load inference data. The `.load_predict_data()` method ensures your cells are not shuffled so that the output matches your input order.
|
|
99
|
+
|
|
100
|
+
```python
|
|
101
|
+
# Load prediction data using the base model's reference markers
|
|
102
|
+
stainer.load_predict_data(folder_path='./data_example/test')
|
|
103
|
+
|
|
104
|
+
# Run predictions and save directly to disk
|
|
105
|
+
stainer.predict(output_path='imputed_cells.h5ad')
|
|
106
|
+
|
|
107
|
+
# Alternatively, return the predictions as a pandas DataFrame:
|
|
108
|
+
# df_imputed = stainer.predict(return_pred=True)
|
|
109
|
+
```
|
|
110
|
+
|
|
111
|
+
### 3. Fine-Tuning on New Batches
|
|
112
|
+
|
|
113
|
+
If you receive new data from a different batch or panel, you don't need to retrain from scratch. CyStainer can freeze the main network and fine-tune only the batch embeddings to align the new data.
|
|
114
|
+
|
|
115
|
+
```python
|
|
116
|
+
# Load the pre-trained model
|
|
117
|
+
stainer = CyStainer.load('cystainer.pt')
|
|
118
|
+
|
|
119
|
+
# Load the new data for fine-tuning
|
|
120
|
+
# Note, anndata objects must have batch info column
|
|
121
|
+
stainer.load_finetune_data(folder_path='./data_example/test', batch_info='batch')
|
|
122
|
+
|
|
123
|
+
# Fine-tune the batch embeddings
|
|
124
|
+
stainer.finetune()
|
|
125
|
+
|
|
126
|
+
# Predict on the newly fine-tuned data
|
|
127
|
+
stainer.predict(output_path='imputed_new_batch.h5ad')
|
|
128
|
+
```
|
|
129
|
+
|
|
130
|
+
### 4. Batch Correction
|
|
131
|
+
|
|
132
|
+
CyStainer allows you to translate cells from their original batch to a specific target batch distribution.
|
|
133
|
+
|
|
134
|
+
```python
|
|
135
|
+
stainer.predict(
|
|
136
|
+
output_path='batch_corrected_cells.h5ad',
|
|
137
|
+
correct_batch=True,
|
|
138
|
+
target_batch_name='single_batch' # Must match a batch name in stainer.batch_dict
|
|
139
|
+
)
|
|
140
|
+
```
|
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
# cystainer/__init__.py
|
|
2
|
+
|
|
3
|
+
# Package version
|
|
4
|
+
__version__ = "0.1.0"
|
|
5
|
+
|
|
6
|
+
# Imports
|
|
7
|
+
from .runner import CyStainer
|
|
8
|
+
from .data import load_data_from_folder
|
|
9
|
+
from .modules import CyStainerModel
|
|
10
|
+
|
|
11
|
+
# Imports `from cystainer import *`
|
|
12
|
+
__all__ = [
|
|
13
|
+
"CyStainerModel",
|
|
14
|
+
"CyStainer",
|
|
15
|
+
"load_data_from_folder",
|
|
16
|
+
]
|
|
@@ -0,0 +1,204 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import itertools
|
|
3
|
+
import torch
|
|
4
|
+
import numpy as np
|
|
5
|
+
import pandas as pd
|
|
6
|
+
import anndata as ad
|
|
7
|
+
import matplotlib.pyplot as plt
|
|
8
|
+
import matplotlib.patches as patches
|
|
9
|
+
from torch.utils.data import DataLoader
|
|
10
|
+
from .modules import get_marker_list, pad_and_get_pad_mask
|
|
11
|
+
|
|
12
|
+
def load_data_from_folder(folder_path=None, adata_list=None, exclude_files=[], max_n_cells=int(1e+6), batch_size=1024, is_train=True, drop_markers=None,
|
|
13
|
+
reference_markers=None, reference_batch_dict=None, normalize=False, batch_info=None, get_panel_vis=False):
|
|
14
|
+
"""Reads .h5ad files from a folder OR accepts a list of AnnData objects directly."""
|
|
15
|
+
|
|
16
|
+
# Establish the adata_list
|
|
17
|
+
if adata_list is None:
|
|
18
|
+
if folder_path is None:
|
|
19
|
+
raise ValueError("You must provide either 'folder_path' or 'adata_list'.")
|
|
20
|
+
|
|
21
|
+
sample_names = [f for f in os.listdir(folder_path) if f.endswith('.h5ad') and (f not in exclude_files)]
|
|
22
|
+
if not sample_names:
|
|
23
|
+
raise ValueError(f"No files found in {folder_path}")
|
|
24
|
+
|
|
25
|
+
# Load adatas from disk
|
|
26
|
+
loaded_adatas = [ad.read_h5ad(os.path.join(folder_path, f)) for f in sample_names]
|
|
27
|
+
else:
|
|
28
|
+
# Use the provided list of AnnData objects
|
|
29
|
+
loaded_adatas = adata_list
|
|
30
|
+
|
|
31
|
+
# Extract DataFrames and Batch Info
|
|
32
|
+
df_list = [adata.to_df().iloc[:max_n_cells,:] for adata in loaded_adatas]
|
|
33
|
+
|
|
34
|
+
if batch_info is not None:
|
|
35
|
+
batch_list = [adata.obs[batch_info].tolist() for adata in loaded_adatas]
|
|
36
|
+
batch_names = list(itertools.chain.from_iterable(batch_list))
|
|
37
|
+
else:
|
|
38
|
+
print("No batch info passed, setting 'single_batch' for the data.")
|
|
39
|
+
batch_names = list(itertools.chain.from_iterable([['single_batch'] * df.shape[0] for df in df_list]))
|
|
40
|
+
|
|
41
|
+
# Visualization, Processing, and Normalization
|
|
42
|
+
if get_panel_vis:
|
|
43
|
+
visualize_panel_overlap(df_list)
|
|
44
|
+
|
|
45
|
+
if drop_markers is not None:
|
|
46
|
+
df_list = [df.loc[:,~df.columns.isin(drop_markers)] for df in df_list]
|
|
47
|
+
|
|
48
|
+
if normalize:
|
|
49
|
+
df_list = [(df - df.mean()) / df.std() for df in df_list]
|
|
50
|
+
|
|
51
|
+
# Handle Batches
|
|
52
|
+
if reference_batch_dict is None:
|
|
53
|
+
# Sort for deterministic behavior
|
|
54
|
+
batch_dict = {batch: i for i, batch in enumerate(sorted(list(set(batch_names))))}
|
|
55
|
+
else:
|
|
56
|
+
batch_dict = reference_batch_dict.copy()
|
|
57
|
+
current_max_idx = max(batch_dict.values()) if batch_dict else -1
|
|
58
|
+
|
|
59
|
+
# Find only the new batches and sort them
|
|
60
|
+
new_batches = sorted(list(set(batch_names) - set(batch_dict.keys())))
|
|
61
|
+
|
|
62
|
+
for batch in new_batches:
|
|
63
|
+
current_max_idx += 1
|
|
64
|
+
batch_dict[batch] = current_max_idx
|
|
65
|
+
|
|
66
|
+
batch_ids = pd.Series(batch_names).apply(lambda x: batch_dict[x]).values
|
|
67
|
+
|
|
68
|
+
# Handle Markers and Padding
|
|
69
|
+
if reference_markers is None:
|
|
70
|
+
marker_list, shared_markers, unique_markers = get_marker_list(df_list)
|
|
71
|
+
else:
|
|
72
|
+
marker_list = reference_markers['marker_list']
|
|
73
|
+
shared_markers = reference_markers['shared_markers']
|
|
74
|
+
unique_markers = reference_markers['unique_markers']
|
|
75
|
+
|
|
76
|
+
df_list, pad_mask = pad_and_get_pad_mask(df_list, marker_list)
|
|
77
|
+
|
|
78
|
+
# PyTorch Dataloader
|
|
79
|
+
data = [{'x': x.astype('float32'), 'batch': b, 'pad': p}
|
|
80
|
+
for x, b, p in zip(pd.concat(df_list).values, batch_ids, np.concatenate(pad_mask, axis=0))]
|
|
81
|
+
|
|
82
|
+
dataloader = DataLoader(data, batch_size=batch_size, shuffle=is_train)
|
|
83
|
+
|
|
84
|
+
markers_info = {
|
|
85
|
+
'marker_list': marker_list,
|
|
86
|
+
'shared_markers': shared_markers,
|
|
87
|
+
'unique_markers': unique_markers
|
|
88
|
+
}
|
|
89
|
+
|
|
90
|
+
return dataloader, markers_info, batch_dict
|
|
91
|
+
|
|
92
|
+
def visualize_panel_overlap(df_list, alignment='left'):
|
|
93
|
+
"""
|
|
94
|
+
Creates a block alignment plot showing shared and unique markers
|
|
95
|
+
across different panels.
|
|
96
|
+
|
|
97
|
+
Args:
|
|
98
|
+
df_list (list): List of pandas DataFrames.
|
|
99
|
+
panel_names (list): List of panel names as strings.
|
|
100
|
+
alignment (str): 'center' for a center-outwards pyramid style,
|
|
101
|
+
'left' for a cascading left-aligned style.
|
|
102
|
+
"""
|
|
103
|
+
# Extract markers and build presence matrix
|
|
104
|
+
panel_markers = set([frozenset(df.columns) for df in df_list])
|
|
105
|
+
all_markers = set().union(*panel_markers)
|
|
106
|
+
num_panels = len(panel_markers)
|
|
107
|
+
panel_names = [f'Panel {i+1}' for i in range(num_panels)]
|
|
108
|
+
|
|
109
|
+
matrix = pd.DataFrame(index=list(all_markers), columns=panel_names)
|
|
110
|
+
for name, markers in zip(panel_names, panel_markers):
|
|
111
|
+
matrix[name] = matrix.index.isin(markers).astype(int)
|
|
112
|
+
|
|
113
|
+
# Sort markers based on requested alignment
|
|
114
|
+
if alignment == 'left':
|
|
115
|
+
# Simple cascade sort
|
|
116
|
+
matrix = matrix.sort_values(by=panel_names, ascending=False)
|
|
117
|
+
sorted_markers = matrix.index.tolist()
|
|
118
|
+
|
|
119
|
+
elif alignment == 'center':
|
|
120
|
+
# Center-outwards sort (Frequency and Center of Mass)
|
|
121
|
+
matrix['freq'] = matrix.sum(axis=1)
|
|
122
|
+
col_indices = np.arange(num_panels)
|
|
123
|
+
|
|
124
|
+
matrix['com'] = (matrix[panel_names] * col_indices).sum(axis=1) / matrix['freq']
|
|
125
|
+
matrix['pattern'] = matrix[panel_names].astype(str).agg(''.join, axis=1)
|
|
126
|
+
|
|
127
|
+
grouped = matrix.groupby('pattern')
|
|
128
|
+
group_stats = grouped.first()[['freq', 'com']].sort_values(by=['freq', 'com'], ascending=[False, True])
|
|
129
|
+
|
|
130
|
+
left_part = []
|
|
131
|
+
right_part = []
|
|
132
|
+
center_part = []
|
|
133
|
+
mid_point = (num_panels - 1) / 2.0
|
|
134
|
+
|
|
135
|
+
for pattern, row in group_stats.iterrows():
|
|
136
|
+
group_markers = matrix[matrix['pattern'] == pattern].index.tolist()
|
|
137
|
+
group_markers.sort() # Alphabetical fallback
|
|
138
|
+
|
|
139
|
+
if row['freq'] == num_panels:
|
|
140
|
+
center_part.extend(group_markers)
|
|
141
|
+
else:
|
|
142
|
+
if row['com'] < mid_point:
|
|
143
|
+
left_part = group_markers + left_part
|
|
144
|
+
elif row['com'] > mid_point:
|
|
145
|
+
right_part = right_part + group_markers
|
|
146
|
+
else:
|
|
147
|
+
if len(left_part) <= len(right_part):
|
|
148
|
+
left_part = group_markers + left_part
|
|
149
|
+
else:
|
|
150
|
+
right_part = right_part + group_markers
|
|
151
|
+
|
|
152
|
+
sorted_markers = left_part + center_part + right_part
|
|
153
|
+
|
|
154
|
+
else:
|
|
155
|
+
raise ValueError("The 'alignment' parameter must be either 'left' or 'center'.")
|
|
156
|
+
|
|
157
|
+
# Setup the plot
|
|
158
|
+
fig, ax = plt.subplots(figsize=(14, len(panel_names) * 1.5))
|
|
159
|
+
rect_height = 0.6
|
|
160
|
+
colors = ['#72A0C1', '#D65A61', '#E8B358', '#7CE0C9', '#8291A8', '#C48CB3']
|
|
161
|
+
|
|
162
|
+
# Draw the blocks
|
|
163
|
+
for i, (name, markers) in enumerate(zip(panel_names, panel_markers)):
|
|
164
|
+
presence = [1 if m in markers else 0 for m in sorted_markers]
|
|
165
|
+
start_idx = None
|
|
166
|
+
for j, val in enumerate(presence):
|
|
167
|
+
if val == 1 and start_idx is None:
|
|
168
|
+
start_idx = j
|
|
169
|
+
elif val == 0 and start_idx is not None:
|
|
170
|
+
rect = patches.Rectangle(
|
|
171
|
+
(start_idx, i - rect_height/2), j - start_idx, rect_height,
|
|
172
|
+
edgecolor='black', facecolor=colors[i % len(colors)], alpha=0.9, linewidth=0.8
|
|
173
|
+
)
|
|
174
|
+
ax.add_patch(rect)
|
|
175
|
+
start_idx = None
|
|
176
|
+
|
|
177
|
+
# Catch end blocks
|
|
178
|
+
if start_idx is not None:
|
|
179
|
+
rect = patches.Rectangle(
|
|
180
|
+
(start_idx, i - rect_height/2), len(presence) - start_idx, rect_height,
|
|
181
|
+
edgecolor='black', facecolor=colors[i % len(colors)], alpha=0.9, linewidth=0.8
|
|
182
|
+
)
|
|
183
|
+
ax.add_patch(rect)
|
|
184
|
+
|
|
185
|
+
ax.text(-0.8, i, name, va='center', ha='right', fontsize=12, fontweight='bold')
|
|
186
|
+
|
|
187
|
+
# Formatting
|
|
188
|
+
ax.set_xlim(0, len(sorted_markers))
|
|
189
|
+
ax.set_ylim(-1, len(panel_names))
|
|
190
|
+
ax.set_yticks([])
|
|
191
|
+
ax.set_xticks(np.arange(len(sorted_markers)) + 0.5)
|
|
192
|
+
ax.set_xticklabels(sorted_markers, rotation=90, ha='center', fontsize=10)
|
|
193
|
+
|
|
194
|
+
# Remove borders
|
|
195
|
+
for spine in ['top', 'right', 'left']:
|
|
196
|
+
ax.spines[spine].set_visible(False)
|
|
197
|
+
|
|
198
|
+
ax.xaxis.grid(True, linestyle='--', alpha=0.4)
|
|
199
|
+
ax.set_axisbelow(True)
|
|
200
|
+
|
|
201
|
+
title_align = "Center" if alignment == 'center' else "Left"
|
|
202
|
+
plt.title(f"Panel Marker Alignment", fontsize=16, pad=25, fontweight='bold')
|
|
203
|
+
plt.tight_layout()
|
|
204
|
+
plt.show()
|