SURE-tools: sure_tools-2.4.2-py3-none-any.whl → sure_tools-2.4.5-py3-none-any.whl

This diff shows the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they were published in their respective public registries.

Potentially problematic release.


This version of SURE-tools might be problematic. (The original page links to further details here; that link is not preserved in this copy.)

@@ -86,14 +86,16 @@ class TranscriptomeDecoder:
86
86
  output = torch.matmul(latent_proj, gene_embeds) # [batch, gene_dim]
87
87
 
88
88
  return self.activation(output)
89
-
90
- class ChunkedTransformer(nn.Module):
91
- """Process genes in chunks to reduce memory usage"""
92
- def __init__(self, gene_dim, hidden_dim, chunk_size=1000, num_layers=4):
89
+
90
+ class ChunkedTransformer(nn.Module):
91
+ def __init__(self, gene_dim, hidden_dim=512, chunk_size=2000, num_layers=3):
93
92
  super().__init__()
94
93
  self.chunk_size = chunk_size
94
+ self.hidden_dim = hidden_dim
95
95
  self.num_chunks = (gene_dim + chunk_size - 1) // chunk_size
96
- self.layers = nn.ModuleList([
96
+
97
+ # 共享的Transformer层
98
+ self.transformer_layers = nn.ModuleList([
97
99
  nn.Sequential(
98
100
  nn.Linear(hidden_dim, hidden_dim),
99
101
  nn.GELU(),
@@ -101,22 +103,40 @@ class TranscriptomeDecoder:
101
103
  nn.Linear(hidden_dim, hidden_dim),
102
104
  ) for _ in range(num_layers)
103
105
  ])
104
-
106
+
107
+ # 每个chunk独立的投影层
108
+ self.input_projections = nn.ModuleList([
109
+ nn.Linear(min(chunk_size, gene_dim - i * chunk_size), hidden_dim)
110
+ for i in range(self.num_chunks)
111
+ ])
112
+ self.output_projections = nn.ModuleList([
113
+ nn.Linear(hidden_dim, min(chunk_size, gene_dim - i * chunk_size))
114
+ for i in range(self.num_chunks)
115
+ ])
116
+
105
117
  def forward(self, x):
106
- # Process in chunks to save memory
107
- batch_size = x.shape[0]
118
+ batch_size, gene_dim = x.shape
108
119
  output = torch.zeros_like(x)
109
-
120
+
110
121
  for i in range(self.num_chunks):
111
122
  start_idx = i * self.chunk_size
112
- end_idx = min((i + 1) * self.chunk_size, x.shape[1])
113
-
114
- chunk = x[:, start_idx:end_idx]
115
- for layer in self.layers:
116
- chunk = layer(chunk) + chunk # Residual connection
117
-
118
- output[:, start_idx:end_idx] = chunk
123
+ end_idx = min((i + 1) * self.chunk_size, gene_dim)
124
+ current_chunk_size = end_idx - start_idx
125
+
126
+ chunk = x[:, start_idx:end_idx] # [batch_size, current_chunk_size]
127
+
128
+ # 投影到hidden_dim
129
+ chunk_proj = self.input_projections[i](chunk) # [batch_size, hidden_dim]
119
130
 
131
+ # Transformer处理
132
+ for layer in self.transformer_layers:
133
+ chunk_proj = layer(chunk_proj) + chunk_proj
134
+
135
+ # 投影回原始维度
136
+ chunk_out = self.output_projections[i](chunk_proj) # [batch_size, current_chunk_size]
137
+
138
+ output[:, start_idx:end_idx] = chunk_out
139
+
120
140
  return output
121
141
 
122
142
  class Decoder(nn.Module):
@@ -166,6 +186,8 @@ class TranscriptomeDecoder:
166
186
  self.output_scale = nn.Parameter(torch.ones(1))
167
187
  self.output_bias = nn.Parameter(torch.zeros(1))
168
188
 
189
+ self.latent_to_gene = nn.Linear(hidden_dim, gene_dim)
190
+
169
191
  self._init_weights()
170
192
 
171
193
  def _init_weights(self):
@@ -185,8 +207,8 @@ class TranscriptomeDecoder:
185
207
  gene_features = self.gene_projection(latent)
186
208
 
187
209
  # 3. Add latent information
188
- print(f'{gene_features.shape}; {latent_expanded.shape}')
189
- gene_features = gene_features + latent_expanded.unsqueeze(1)
210
+ latent_gene_injection = self.latent_to_gene(latent_expanded)
211
+ gene_features = gene_features + latent_gene_injection
190
212
 
191
213
  # 4. Chunked processing (memory efficient)
192
214
  gene_features = self.chunked_processor(gene_features)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: SURE-tools
3
- Version: 2.4.2
3
+ Version: 2.4.5
4
4
  Summary: Succinct Representation of Single Cells
5
5
  Home-page: https://github.com/ZengFLab/SURE
6
6
  Author: Feng Zeng
@@ -1,7 +1,7 @@
1
1
  SURE/DensityFlow.py,sha256=YvaE9aPbAC2U7WhTye5i2AMtcw0BI_qS3gv9SP4aE0k,56676
2
2
  SURE/PerturbE.py,sha256=DxEp-qef--x8-GMZdPfBf8ts8UDDc34h2P5AnpqZ-YM,52265
3
3
  SURE/SURE.py,sha256=MXs7iuvcj-lU4dJ_MwKegpL2Rqk2HB4eFfAgHRA3RtA,47744
4
- SURE/TranscriptomeDecoder.py,sha256=z0NtjvpT1nN4H9WC97VJPWuTvqaZYn36MbpaIWjXgGU,20324
4
+ SURE/TranscriptomeDecoder.py,sha256=fjTl2wC-nGTdbQGgFDbTmWYI8RoEg6J4cHPmoUoJJfI,21286
5
5
  SURE/__init__.py,sha256=pNSGQ4BMqMXBAPHpFOYNB8_0vFW-RqPy3rr5fvdEEyU,473
6
6
  SURE/assembly/__init__.py,sha256=jxZLURXKPzXe21LhrZ09LgZr33iqdjlQy4oSEj5gR2Q,172
7
7
  SURE/assembly/assembly.py,sha256=6IMdelPOiRO4mUb4dC7gVCoF1Uvfw86-Map8P_jnUag,21477
@@ -19,9 +19,9 @@ SURE/utils/__init__.py,sha256=YF5jB-PAHJQ40OlcZ7BCZbsN2q1JKuPT6EppilRXQqM,680
19
19
  SURE/utils/custom_mlp.py,sha256=Rn_PQouxPMSda-KKBYrwVVv3GFFuUmCLxp8cV5LszZo,10580
20
20
  SURE/utils/queue.py,sha256=E_5PA5EWcBoGAZj8BkKQnkCK0p4C-4-xcTPqdIXaPXU,1892
21
21
  SURE/utils/utils.py,sha256=IUHjDDtYaAYllCWsZyIzqQwaLul6fJRvHRH4vIYcR-c,8462
22
- sure_tools-2.4.2.dist-info/licenses/LICENSE,sha256=TFHKwmrAViXQbSX5W-NDItkWFjm45HWOeUniDrqmnu0,1065
23
- sure_tools-2.4.2.dist-info/METADATA,sha256=gorPWZ40-GBJ3Rz6CH-eWZFz0ZKW8LacnDEJbg1cik4,2677
24
- sure_tools-2.4.2.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
25
- sure_tools-2.4.2.dist-info/entry_points.txt,sha256=-nJI8rVe_qqrR0HmfAODzj-JNfEqCcSsyVh6okSqyHk,83
26
- sure_tools-2.4.2.dist-info/top_level.txt,sha256=BtFTebdiJeqra4r6mm-uEtwVRFLZ_IjYsQ7OnalrOvY,5
27
- sure_tools-2.4.2.dist-info/RECORD,,
22
+ sure_tools-2.4.5.dist-info/licenses/LICENSE,sha256=TFHKwmrAViXQbSX5W-NDItkWFjm45HWOeUniDrqmnu0,1065
23
+ sure_tools-2.4.5.dist-info/METADATA,sha256=2GjCK_HUQ_Vs6b8AT2PIelOadhiVeOakI8B_OqbRyi0,2677
24
+ sure_tools-2.4.5.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
25
+ sure_tools-2.4.5.dist-info/entry_points.txt,sha256=-nJI8rVe_qqrR0HmfAODzj-JNfEqCcSsyVh6okSqyHk,83
26
+ sure_tools-2.4.5.dist-info/top_level.txt,sha256=BtFTebdiJeqra4r6mm-uEtwVRFLZ_IjYsQ7OnalrOvY,5
27
+ sure_tools-2.4.5.dist-info/RECORD,,