SURE-tools 3.6.11__tar.gz → 3.6.13__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {sure_tools-3.6.11 → sure_tools-3.6.13}/PKG-INFO +1 -1
- {sure_tools-3.6.11 → sure_tools-3.6.13}/SURE/SURE_nsf.py +5 -6
- {sure_tools-3.6.11 → sure_tools-3.6.13}/SURE/SURE_vae.py +8 -10
- {sure_tools-3.6.11 → sure_tools-3.6.13}/SURE_tools.egg-info/PKG-INFO +1 -1
- {sure_tools-3.6.11 → sure_tools-3.6.13}/setup.py +1 -1
- {sure_tools-3.6.11 → sure_tools-3.6.13}/LICENSE +0 -0
- {sure_tools-3.6.11 → sure_tools-3.6.13}/README.md +0 -0
- {sure_tools-3.6.11 → sure_tools-3.6.13}/SURE/SURE.py +0 -0
- {sure_tools-3.6.11 → sure_tools-3.6.13}/SURE/SUREMO.py +0 -0
- {sure_tools-3.6.11 → sure_tools-3.6.13}/SURE/SURE_vanilla.py +0 -0
- {sure_tools-3.6.11 → sure_tools-3.6.13}/SURE/__init__.py +0 -0
- {sure_tools-3.6.11 → sure_tools-3.6.13}/SURE/atac/__init__.py +0 -0
- {sure_tools-3.6.11 → sure_tools-3.6.13}/SURE/atac/utils.py +0 -0
- {sure_tools-3.6.11 → sure_tools-3.6.13}/SURE/dist/__init__.py +0 -0
- {sure_tools-3.6.11 → sure_tools-3.6.13}/SURE/dist/negbinomial.py +0 -0
- {sure_tools-3.6.11 → sure_tools-3.6.13}/SURE/graph/__init__.py +0 -0
- {sure_tools-3.6.11 → sure_tools-3.6.13}/SURE/graph/graph_utils.py +0 -0
- {sure_tools-3.6.11 → sure_tools-3.6.13}/SURE/utils/__init__.py +0 -0
- {sure_tools-3.6.11 → sure_tools-3.6.13}/SURE/utils/custom_mlp.py +0 -0
- {sure_tools-3.6.11 → sure_tools-3.6.13}/SURE/utils/label.py +0 -0
- {sure_tools-3.6.11 → sure_tools-3.6.13}/SURE/utils/queue.py +0 -0
- {sure_tools-3.6.11 → sure_tools-3.6.13}/SURE/utils/utils.py +0 -0
- {sure_tools-3.6.11 → sure_tools-3.6.13}/SURE_tools.egg-info/SOURCES.txt +0 -0
- {sure_tools-3.6.11 → sure_tools-3.6.13}/SURE_tools.egg-info/dependency_links.txt +0 -0
- {sure_tools-3.6.11 → sure_tools-3.6.13}/SURE_tools.egg-info/requires.txt +0 -0
- {sure_tools-3.6.11 → sure_tools-3.6.13}/SURE_tools.egg-info/top_level.txt +0 -0
- {sure_tools-3.6.11 → sure_tools-3.6.13}/setup.cfg +0 -0
|
@@ -168,7 +168,7 @@ class SURENF(nn.Module):
|
|
|
168
168
|
self.use_mask = False
|
|
169
169
|
self.mask_ratio = 0
|
|
170
170
|
|
|
171
|
-
self.kmeans =
|
|
171
|
+
self.kmeans = KMeans(n_clusters=codebook_size)
|
|
172
172
|
self.seed = seed
|
|
173
173
|
|
|
174
174
|
set_random_seed(seed)
|
|
@@ -593,8 +593,8 @@ class SURENF(nn.Module):
|
|
|
593
593
|
return A
|
|
594
594
|
|
|
595
595
|
def predict_cluster(self, xs, batch_size=1024, show_progress=True):
|
|
596
|
-
|
|
597
|
-
return
|
|
596
|
+
zs = self.get_cell_embedding(xs, batch_size=batch_size, show_progress=show_progress)
|
|
597
|
+
return self.kmeans.predict(zs)
|
|
598
598
|
|
|
599
599
|
def predict(self, xs, cs, batch_size=1024, show_progress=True):
|
|
600
600
|
"""
|
|
@@ -1013,9 +1013,8 @@ class SURENF(nn.Module):
|
|
|
1013
1013
|
print("Restored model and optimizer states from best epoch.")
|
|
1014
1014
|
|
|
1015
1015
|
zs = self.get_cell_embedding(xs, show_progress=False)
|
|
1016
|
-
|
|
1017
|
-
|
|
1018
|
-
self.codebook_loc = ns @ zs
|
|
1016
|
+
self.kmeans.fit(zs)
|
|
1017
|
+
self.codebook_loc = self.kmeans.cluster_centers_
|
|
1019
1018
|
|
|
1020
1019
|
def get_model_and_optimizer_state(self, scheduler=None):
|
|
1021
1020
|
"""获取模型和优化器的完整状态"""
|
|
@@ -169,7 +169,7 @@ class SUREVAE(nn.Module):
|
|
|
169
169
|
self.transforms = transforms
|
|
170
170
|
self.flow_hidden_layers = flow_hidden_layers
|
|
171
171
|
|
|
172
|
-
self.kmeans =
|
|
172
|
+
self.kmeans = KMeans(n_clusters=codebook_size)
|
|
173
173
|
self.seed = seed
|
|
174
174
|
|
|
175
175
|
set_random_seed(seed)
|
|
@@ -593,10 +593,9 @@ class SUREVAE(nn.Module):
|
|
|
593
593
|
X_batch = X_batch.to(self.get_device())
|
|
594
594
|
C_batch = cs[idx].to(self.get_device())
|
|
595
595
|
|
|
596
|
-
#
|
|
597
|
-
#z_basal =
|
|
598
|
-
|
|
599
|
-
z_basal = torch.matmul(ns, cb_loc)
|
|
596
|
+
#ns = self._soft_assignments(X_batch)
|
|
597
|
+
#z_basal = torch.matmul(ns, cb_loc)
|
|
598
|
+
z_basal = self._get_cell_embedding(X_batch)
|
|
600
599
|
dzs = self.condition_effect([C_batch,z_basal])
|
|
601
600
|
|
|
602
601
|
A.append(tensor_to_numpy(dzs))
|
|
@@ -606,8 +605,8 @@ class SUREVAE(nn.Module):
|
|
|
606
605
|
return A
|
|
607
606
|
|
|
608
607
|
def predict_cluster(self, xs, batch_size=1024, show_progress=True):
|
|
609
|
-
|
|
610
|
-
return
|
|
608
|
+
zs = self.get_cell_embedding(xs, batch_size=batch_size, show_progress=show_progress)
|
|
609
|
+
return self.kmeans.predict(zs)
|
|
611
610
|
|
|
612
611
|
def predict(self, xs, cs, batch_size=1024, show_progress=True):
|
|
613
612
|
"""
|
|
@@ -945,9 +944,8 @@ class SUREVAE(nn.Module):
|
|
|
945
944
|
print("Restored model and optimizer states from best epoch.")
|
|
946
945
|
|
|
947
946
|
zs = self.get_cell_embedding(xs, show_progress=False)
|
|
948
|
-
|
|
949
|
-
|
|
950
|
-
self.codebook_loc = ns @ zs
|
|
947
|
+
self.kmeans.fit(zs)
|
|
948
|
+
self.codebook_loc = self.kmeans.cluster_centers_
|
|
951
949
|
|
|
952
950
|
def get_model_and_optimizer_state(self, scheduler=None):
|
|
953
951
|
"""获取模型和优化器的完整状态"""
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|