SURE-tools 2.4.2__py3-none-any.whl → 2.4.5__py3-none-any.whl
This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those package versions as they appear in their respective public registries.
Potentially problematic release.
This version of SURE-tools might be problematic. Click here for more details.
- SURE/TranscriptomeDecoder.py +40 -18
- {sure_tools-2.4.2.dist-info → sure_tools-2.4.5.dist-info}/METADATA +1 -1
- {sure_tools-2.4.2.dist-info → sure_tools-2.4.5.dist-info}/RECORD +7 -7
- {sure_tools-2.4.2.dist-info → sure_tools-2.4.5.dist-info}/WHEEL +0 -0
- {sure_tools-2.4.2.dist-info → sure_tools-2.4.5.dist-info}/entry_points.txt +0 -0
- {sure_tools-2.4.2.dist-info → sure_tools-2.4.5.dist-info}/licenses/LICENSE +0 -0
- {sure_tools-2.4.2.dist-info → sure_tools-2.4.5.dist-info}/top_level.txt +0 -0
SURE/TranscriptomeDecoder.py
CHANGED
|
@@ -86,14 +86,16 @@ class TranscriptomeDecoder:
|
|
|
86
86
|
output = torch.matmul(latent_proj, gene_embeds) # [batch, gene_dim]
|
|
87
87
|
|
|
88
88
|
return self.activation(output)
|
|
89
|
-
|
|
90
|
-
class ChunkedTransformer(nn.Module):
|
|
91
|
-
|
|
92
|
-
def __init__(self, gene_dim, hidden_dim, chunk_size=1000, num_layers=4):
|
|
89
|
+
|
|
90
|
+
class ChunkedTransformer(nn.Module):
|
|
91
|
+
def __init__(self, gene_dim, hidden_dim=512, chunk_size=2000, num_layers=3):
|
|
93
92
|
super().__init__()
|
|
94
93
|
self.chunk_size = chunk_size
|
|
94
|
+
self.hidden_dim = hidden_dim
|
|
95
95
|
self.num_chunks = (gene_dim + chunk_size - 1) // chunk_size
|
|
96
|
-
|
|
96
|
+
|
|
97
|
+
# 共享的Transformer层
|
|
98
|
+
self.transformer_layers = nn.ModuleList([
|
|
97
99
|
nn.Sequential(
|
|
98
100
|
nn.Linear(hidden_dim, hidden_dim),
|
|
99
101
|
nn.GELU(),
|
|
@@ -101,22 +103,40 @@ class TranscriptomeDecoder:
|
|
|
101
103
|
nn.Linear(hidden_dim, hidden_dim),
|
|
102
104
|
) for _ in range(num_layers)
|
|
103
105
|
])
|
|
104
|
-
|
|
106
|
+
|
|
107
|
+
# 每个chunk独立的投影层
|
|
108
|
+
self.input_projections = nn.ModuleList([
|
|
109
|
+
nn.Linear(min(chunk_size, gene_dim - i * chunk_size), hidden_dim)
|
|
110
|
+
for i in range(self.num_chunks)
|
|
111
|
+
])
|
|
112
|
+
self.output_projections = nn.ModuleList([
|
|
113
|
+
nn.Linear(hidden_dim, min(chunk_size, gene_dim - i * chunk_size))
|
|
114
|
+
for i in range(self.num_chunks)
|
|
115
|
+
])
|
|
116
|
+
|
|
105
117
|
def forward(self, x):
|
|
106
|
-
|
|
107
|
-
batch_size = x.shape[0]
|
|
118
|
+
batch_size, gene_dim = x.shape
|
|
108
119
|
output = torch.zeros_like(x)
|
|
109
|
-
|
|
120
|
+
|
|
110
121
|
for i in range(self.num_chunks):
|
|
111
122
|
start_idx = i * self.chunk_size
|
|
112
|
-
end_idx = min((i + 1) * self.chunk_size,
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
123
|
+
end_idx = min((i + 1) * self.chunk_size, gene_dim)
|
|
124
|
+
current_chunk_size = end_idx - start_idx
|
|
125
|
+
|
|
126
|
+
chunk = x[:, start_idx:end_idx] # [batch_size, current_chunk_size]
|
|
127
|
+
|
|
128
|
+
# 投影到hidden_dim
|
|
129
|
+
chunk_proj = self.input_projections[i](chunk) # [batch_size, hidden_dim]
|
|
119
130
|
|
|
131
|
+
# Transformer处理
|
|
132
|
+
for layer in self.transformer_layers:
|
|
133
|
+
chunk_proj = layer(chunk_proj) + chunk_proj
|
|
134
|
+
|
|
135
|
+
# 投影回原始维度
|
|
136
|
+
chunk_out = self.output_projections[i](chunk_proj) # [batch_size, current_chunk_size]
|
|
137
|
+
|
|
138
|
+
output[:, start_idx:end_idx] = chunk_out
|
|
139
|
+
|
|
120
140
|
return output
|
|
121
141
|
|
|
122
142
|
class Decoder(nn.Module):
|
|
@@ -166,6 +186,8 @@ class TranscriptomeDecoder:
|
|
|
166
186
|
self.output_scale = nn.Parameter(torch.ones(1))
|
|
167
187
|
self.output_bias = nn.Parameter(torch.zeros(1))
|
|
168
188
|
|
|
189
|
+
self.latent_to_gene = nn.Linear(hidden_dim, gene_dim)
|
|
190
|
+
|
|
169
191
|
self._init_weights()
|
|
170
192
|
|
|
171
193
|
def _init_weights(self):
|
|
@@ -185,8 +207,8 @@ class TranscriptomeDecoder:
|
|
|
185
207
|
gene_features = self.gene_projection(latent)
|
|
186
208
|
|
|
187
209
|
# 3. Add latent information
|
|
188
|
-
|
|
189
|
-
gene_features = gene_features +
|
|
210
|
+
latent_gene_injection = self.latent_to_gene(latent_expanded)
|
|
211
|
+
gene_features = gene_features + latent_gene_injection
|
|
190
212
|
|
|
191
213
|
# 4. Chunked processing (memory efficient)
|
|
192
214
|
gene_features = self.chunked_processor(gene_features)
|
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
SURE/DensityFlow.py,sha256=YvaE9aPbAC2U7WhTye5i2AMtcw0BI_qS3gv9SP4aE0k,56676
|
|
2
2
|
SURE/PerturbE.py,sha256=DxEp-qef--x8-GMZdPfBf8ts8UDDc34h2P5AnpqZ-YM,52265
|
|
3
3
|
SURE/SURE.py,sha256=MXs7iuvcj-lU4dJ_MwKegpL2Rqk2HB4eFfAgHRA3RtA,47744
|
|
4
|
-
SURE/TranscriptomeDecoder.py,sha256=
|
|
4
|
+
SURE/TranscriptomeDecoder.py,sha256=fjTl2wC-nGTdbQGgFDbTmWYI8RoEg6J4cHPmoUoJJfI,21286
|
|
5
5
|
SURE/__init__.py,sha256=pNSGQ4BMqMXBAPHpFOYNB8_0vFW-RqPy3rr5fvdEEyU,473
|
|
6
6
|
SURE/assembly/__init__.py,sha256=jxZLURXKPzXe21LhrZ09LgZr33iqdjlQy4oSEj5gR2Q,172
|
|
7
7
|
SURE/assembly/assembly.py,sha256=6IMdelPOiRO4mUb4dC7gVCoF1Uvfw86-Map8P_jnUag,21477
|
|
@@ -19,9 +19,9 @@ SURE/utils/__init__.py,sha256=YF5jB-PAHJQ40OlcZ7BCZbsN2q1JKuPT6EppilRXQqM,680
|
|
|
19
19
|
SURE/utils/custom_mlp.py,sha256=Rn_PQouxPMSda-KKBYrwVVv3GFFuUmCLxp8cV5LszZo,10580
|
|
20
20
|
SURE/utils/queue.py,sha256=E_5PA5EWcBoGAZj8BkKQnkCK0p4C-4-xcTPqdIXaPXU,1892
|
|
21
21
|
SURE/utils/utils.py,sha256=IUHjDDtYaAYllCWsZyIzqQwaLul6fJRvHRH4vIYcR-c,8462
|
|
22
|
-
sure_tools-2.4.
|
|
23
|
-
sure_tools-2.4.
|
|
24
|
-
sure_tools-2.4.
|
|
25
|
-
sure_tools-2.4.
|
|
26
|
-
sure_tools-2.4.
|
|
27
|
-
sure_tools-2.4.
|
|
22
|
+
sure_tools-2.4.5.dist-info/licenses/LICENSE,sha256=TFHKwmrAViXQbSX5W-NDItkWFjm45HWOeUniDrqmnu0,1065
|
|
23
|
+
sure_tools-2.4.5.dist-info/METADATA,sha256=2GjCK_HUQ_Vs6b8AT2PIelOadhiVeOakI8B_OqbRyi0,2677
|
|
24
|
+
sure_tools-2.4.5.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
25
|
+
sure_tools-2.4.5.dist-info/entry_points.txt,sha256=-nJI8rVe_qqrR0HmfAODzj-JNfEqCcSsyVh6okSqyHk,83
|
|
26
|
+
sure_tools-2.4.5.dist-info/top_level.txt,sha256=BtFTebdiJeqra4r6mm-uEtwVRFLZ_IjYsQ7OnalrOvY,5
|
|
27
|
+
sure_tools-2.4.5.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|