SURE-tools 2.1.4__py3-none-any.whl → 2.1.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of SURE-tools might be problematic. Click here for more details.

SURE/PerturbFlow.py CHANGED
@@ -55,44 +55,6 @@ def set_random_seed(seed):
55
55
  pyro.set_rng_seed(seed)
56
56
 
57
57
  class PerturbFlow(nn.Module):
58
- """SUccinct REpresentation of single-omics cells
59
-
60
- Parameters
61
- ----------
62
- inpute_size
63
- Number of features (e.g., genes, peaks, proteins, etc.) per cell.
64
- codebook_size
65
- Number of metacells.
66
- cell_factor_size
67
- Number of cell-level factors.
68
- z_dim
69
- Dimensionality of latent states and metacells.
70
- hidden_layers
71
- A list give the numbers of neurons for each hidden layer.
72
- loss_func
73
- The likelihood model for single-cell data generation.
74
-
75
- One of the following:
76
- * ``'negbinomial'`` - negative binomial distribution (default)
77
- * ``'poisson'`` - poisson distribution
78
- * ``'multinomial'`` - multinomial distribution
79
- z_dist
80
- The distribution model for latent states.
81
-
82
- One of the following:
83
- * ``'normal'`` - normal distribution
84
- * ``'laplacian'`` - Laplacian distribution
85
- * ``'studentt'`` - Student-t distribution.
86
- use_cuda
87
- A boolean option for switching on cuda device.
88
-
89
- Examples
90
- --------
91
- >>>
92
- >>>
93
- >>>
94
-
95
- """
96
58
  def __init__(self,
97
59
  input_size: int,
98
60
  codebook_size: int = 200,
@@ -447,7 +409,6 @@ class PerturbFlow(nn.Module):
447
409
  zns = pyro.sample('zn', dist.Gumbel(zn_loc, zn_scale).to_event(1))
448
410
 
449
411
  if self.cell_factor_size>0:
450
- #zus = self.decoder_undesired([zns,us])
451
412
  zus = None
452
413
  for i in np.arange(self.cell_factor_size):
453
414
  if i==0:
@@ -643,7 +604,6 @@ class PerturbFlow(nn.Module):
643
604
  zns = pyro.sample('zn', dist.Gumbel(zn_loc, zn_scale).to_event(1), obs=embeds)
644
605
 
645
606
  if self.cell_factor_size>0:
646
- #zus = self.decoder_undesired([zns,us])
647
607
  zus = None
648
608
  for i in np.arange(self.cell_factor_size):
649
609
  if i==0:
@@ -710,7 +670,7 @@ class PerturbFlow(nn.Module):
710
670
  xs,
711
671
  batch_size: int = 1024):
712
672
  """
713
- Return cells' latent representations
673
+ Return cells' basis latent representations
714
674
 
715
675
  Parameters
716
676
  ----------
@@ -731,7 +691,7 @@ class PerturbFlow(nn.Module):
731
691
  Z = []
732
692
  with tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
733
693
  for X_batch, _ in dataloader:
734
- zns = self._get_cell_embedding(X_batch)
694
+ zns = self._get_basis_embedding(X_batch)
735
695
  Z.append(tensor_to_numpy(zns))
736
696
  pbar.update(1)
737
697
 
@@ -858,7 +818,7 @@ class PerturbFlow(nn.Module):
858
818
  delta_zs,
859
819
  batch_size: int = 1024):
860
820
  """
861
- Return cells' changes in the latent space induced by specific perturbation of a factor
821
+ Return cells' changes in the feature space induced by specific perturbation of a factor
862
822
 
863
823
  """
864
824
  delta_zs = convert_to_tensor(delta_zs, device=self.get_device())
@@ -1083,24 +1043,18 @@ def parse_args():
1083
1043
  help="the file for the record of cell-level factors",
1084
1044
  )
1085
1045
  parser.add_argument(
1086
- "-delta",
1087
- "--delta",
1088
- default=0.0,
1089
- type=float,
1090
- help="penalty weight for zero inflation loss",
1091
- )
1092
- parser.add_argument(
1093
- "-64",
1094
- "--float64",
1095
- action="store_true",
1096
- help="use double float precision",
1046
+ "-bs",
1047
+ "--batch-size",
1048
+ default=1000,
1049
+ type=int,
1050
+ help="number of cells to be considered in a batch",
1097
1051
  )
1098
1052
  parser.add_argument(
1099
- "--z-dist",
1100
- default='normal',
1101
- type=str,
1102
- choices=['normal','laplacian','studentt','cauchy'],
1103
- help="distribution model for latent representation",
1053
+ "-lr",
1054
+ "--learning-rate",
1055
+ default=0.0001,
1056
+ type=float,
1057
+ help="learning rate for Adam optimizer",
1104
1058
  )
1105
1059
  parser.add_argument(
1106
1060
  "-cs",
@@ -1109,6 +1063,13 @@ def parse_args():
1109
1063
  type=int,
1110
1064
  help="size of vector quantization codebook",
1111
1065
  )
1066
+ parser.add_argument(
1067
+ "--z-dist",
1068
+ default='gumbel',
1069
+ type=str,
1070
+ choices=['normal','laplacian','studentt','gumbel','cauchy'],
1071
+ help="distribution model for latent representation",
1072
+ )
1112
1073
  parser.add_argument(
1113
1074
  "-zd",
1114
1075
  "--z-dim",
@@ -1116,6 +1077,27 @@ def parse_args():
1116
1077
  type=int,
1117
1078
  help="size of the tensor representing the latent variable z variable",
1118
1079
  )
1080
+ parser.add_argument(
1081
+ "-likeli",
1082
+ "--likelihood",
1083
+ default='negbinomial',
1084
+ type=str,
1085
+ choices=['negbinomial', 'multinomial', 'poisson', 'bernoulli'],
1086
+ help="specify the distribution likelihood function",
1087
+ )
1088
+ parser.add_argument(
1089
+ "-zi",
1090
+ "--zeroinflate",
1091
+ action="store_true",
1092
+ help="use zero-inflated estimation",
1093
+ )
1094
+ parser.add_argument(
1095
+ "-id",
1096
+ "--inverse-dispersion",
1097
+ default=10.0,
1098
+ type=float,
1099
+ help="inverse dispersion prior for negative binomial",
1100
+ )
1119
1101
  parser.add_argument(
1120
1102
  "-hl",
1121
1103
  "--hidden-layers",
@@ -1150,18 +1132,10 @@ def parse_args():
1150
1132
  help="post functions for activation layers, could be none or dropout, default is 'none'",
1151
1133
  )
1152
1134
  parser.add_argument(
1153
- "-id",
1154
- "--inverse-dispersion",
1155
- default=10.0,
1156
- type=float,
1157
- help="inverse dispersion prior for negative binomial",
1158
- )
1159
- parser.add_argument(
1160
- "-lr",
1161
- "--learning-rate",
1162
- default=0.0001,
1163
- type=float,
1164
- help="learning rate for Adam optimizer",
1135
+ "-64",
1136
+ "--float64",
1137
+ action="store_true",
1138
+ help="use double float precision",
1165
1139
  )
1166
1140
  parser.add_argument(
1167
1141
  "-dr",
@@ -1182,50 +1156,7 @@ def parse_args():
1182
1156
  default=0.95,
1183
1157
  type=float,
1184
1158
  help="beta-1 parameter for Adam optimizer",
1185
- )
1186
- parser.add_argument(
1187
- "-bs",
1188
- "--batch-size",
1189
- default=1000,
1190
- type=int,
1191
- help="number of cells to be considered in a batch",
1192
- )
1193
- parser.add_argument(
1194
- "-gp",
1195
- "--gate-prior",
1196
- default=0.6,
1197
- type=float,
1198
- help="gate prior for zero-inflated model",
1199
- )
1200
- parser.add_argument(
1201
- "-likeli",
1202
- "--likelihood",
1203
- default='negbinomial',
1204
- type=str,
1205
- choices=['negbinomial', 'multinomial', 'poisson', 'gaussian','lognormal'],
1206
- help="specify the distribution likelihood function",
1207
- )
1208
- parser.add_argument(
1209
- "-dirichlet",
1210
- "--use-dirichlet",
1211
- action="store_true",
1212
- help="use Dirichlet distribution over gene frequency",
1213
- )
1214
- parser.add_argument(
1215
- "-mass",
1216
- "--dirichlet-mass",
1217
- default=1,
1218
- type=float,
1219
- help="mass param for dirichlet model",
1220
- )
1221
- parser.add_argument(
1222
- "-zi",
1223
- "--zero-inflation",
1224
- default='exact',
1225
- type=str,
1226
- choices=['none','exact','inexact'],
1227
- help="use zero-inflated estimation",
1228
- )
1159
+ )
1229
1160
  parser.add_argument(
1230
1161
  "--seed",
1231
1162
  default=None,
@@ -1266,29 +1197,23 @@ def main():
1266
1197
  input_size = xs.shape[1]
1267
1198
  cell_factor_size = 0 if us is None else us.shape[1]
1268
1199
 
1269
- latent_dist = args.z_dist
1270
-
1271
1200
  ###########################################
1272
1201
  perturbflow = PerturbFlow(
1273
1202
  input_size=input_size,
1274
1203
  cell_factor_size=cell_factor_size,
1275
1204
  inverse_dispersion=args.inverse_dispersion,
1276
- latent_dim=args.latent_dim,
1205
+ z_dim=args.z_dim,
1277
1206
  hidden_layers=args.hidden_layers,
1278
1207
  hidden_layer_activation=args.hidden_layer_activation,
1279
1208
  use_cuda=args.cuda,
1280
1209
  config_enum=args.enum_discrete,
1281
- use_dirichlet=args.use_dirichlet,
1282
- zero_inflation=args.zero_inflation,
1283
- gate_prior=args.gate_prior,
1284
- delta=args.delta,
1210
+ use_zeroinflate=args.zeroinflate,
1285
1211
  loss_func=args.likelihood,
1286
- dirichlet_mass=args.dirichlet_mass,
1287
1212
  nn_dropout=args.layer_dropout_rate,
1288
1213
  post_layer_fct=args.post_layer_function,
1289
1214
  post_act_fct=args.post_activation_function,
1290
1215
  codebook_size=args.codebook_size,
1291
- latent_dist = latent_dist,
1216
+ z_dist = args.z_dist,
1292
1217
  dtype=dtype,
1293
1218
  )
1294
1219
 
SURE/SURE.py CHANGED
@@ -1004,24 +1004,18 @@ def parse_args():
1004
1004
  help="the file for the record of cell-level factors",
1005
1005
  )
1006
1006
  parser.add_argument(
1007
- "-delta",
1008
- "--delta",
1009
- default=0.0,
1010
- type=float,
1011
- help="penalty weight for zero inflation loss",
1012
- )
1013
- parser.add_argument(
1014
- "-64",
1015
- "--float64",
1016
- action="store_true",
1017
- help="use double float precision",
1007
+ "-bs",
1008
+ "--batch-size",
1009
+ default=1000,
1010
+ type=int,
1011
+ help="number of cells to be considered in a batch",
1018
1012
  )
1019
1013
  parser.add_argument(
1020
- "--z-dist",
1021
- default='normal',
1022
- type=str,
1023
- choices=['normal','laplacian','studentt','cauchy'],
1024
- help="distribution model for latent representation",
1014
+ "-lr",
1015
+ "--learning-rate",
1016
+ default=0.0001,
1017
+ type=float,
1018
+ help="learning rate for Adam optimizer",
1025
1019
  )
1026
1020
  parser.add_argument(
1027
1021
  "-cs",
@@ -1030,6 +1024,13 @@ def parse_args():
1030
1024
  type=int,
1031
1025
  help="size of vector quantization codebook",
1032
1026
  )
1027
+ parser.add_argument(
1028
+ "--z-dist",
1029
+ default='gumbel',
1030
+ type=str,
1031
+ choices=['normal','laplacian','studentt','cauchy','gumbel'],
1032
+ help="distribution model for latent representation",
1033
+ )
1033
1034
  parser.add_argument(
1034
1035
  "-zd",
1035
1036
  "--z-dim",
@@ -1037,6 +1038,27 @@ def parse_args():
1037
1038
  type=int,
1038
1039
  help="size of the tensor representing the latent variable z variable",
1039
1040
  )
1041
+ parser.add_argument(
1042
+ "-likeli",
1043
+ "--likelihood",
1044
+ default='negbinomial',
1045
+ type=str,
1046
+ choices=['negbinomial', 'multinomial', 'poisson', 'bernoulli'],
1047
+ help="specify the distribution likelihood function",
1048
+ )
1049
+ parser.add_argument(
1050
+ "-zi",
1051
+ "--zeroinflate",
1052
+ action="store_true",
1053
+ help="use zeroinflation",
1054
+ )
1055
+ parser.add_argument(
1056
+ "-id",
1057
+ "--inverse-dispersion",
1058
+ default=10.0,
1059
+ type=float,
1060
+ help="inverse dispersion prior for negative binomial",
1061
+ )
1040
1062
  parser.add_argument(
1041
1063
  "-hl",
1042
1064
  "--hidden-layers",
@@ -1070,20 +1092,6 @@ def parse_args():
1070
1092
  type=str,
1071
1093
  help="post functions for activation layers, could be none or dropout, default is 'none'",
1072
1094
  )
1073
- parser.add_argument(
1074
- "-id",
1075
- "--inverse-dispersion",
1076
- default=10.0,
1077
- type=float,
1078
- help="inverse dispersion prior for negative binomial",
1079
- )
1080
- parser.add_argument(
1081
- "-lr",
1082
- "--learning-rate",
1083
- default=0.0001,
1084
- type=float,
1085
- help="learning rate for Adam optimizer",
1086
- )
1087
1095
  parser.add_argument(
1088
1096
  "-dr",
1089
1097
  "--decay-rate",
@@ -1105,47 +1113,10 @@ def parse_args():
1105
1113
  help="beta-1 parameter for Adam optimizer",
1106
1114
  )
1107
1115
  parser.add_argument(
1108
- "-bs",
1109
- "--batch-size",
1110
- default=1000,
1111
- type=int,
1112
- help="number of cells to be considered in a batch",
1113
- )
1114
- parser.add_argument(
1115
- "-gp",
1116
- "--gate-prior",
1117
- default=0.6,
1118
- type=float,
1119
- help="gate prior for zero-inflated model",
1120
- )
1121
- parser.add_argument(
1122
- "-likeli",
1123
- "--likelihood",
1124
- default='negbinomial',
1125
- type=str,
1126
- choices=['negbinomial', 'multinomial', 'poisson', 'gaussian','lognormal'],
1127
- help="specify the distribution likelihood function",
1128
- )
1129
- parser.add_argument(
1130
- "-dirichlet",
1131
- "--use-dirichlet",
1116
+ "-64",
1117
+ "--float64",
1132
1118
  action="store_true",
1133
- help="use Dirichlet distribution over gene frequency",
1134
- )
1135
- parser.add_argument(
1136
- "-mass",
1137
- "--dirichlet-mass",
1138
- default=1,
1139
- type=float,
1140
- help="mass param for dirichlet model",
1141
- )
1142
- parser.add_argument(
1143
- "-zi",
1144
- "--zero-inflation",
1145
- default='exact',
1146
- type=str,
1147
- choices=['none','exact','inexact'],
1148
- help="use zero-inflated estimation",
1119
+ help="use double float precision",
1149
1120
  )
1150
1121
  parser.add_argument(
1151
1122
  "--seed",
@@ -1187,29 +1158,23 @@ def main():
1187
1158
  input_size = xs.shape[1]
1188
1159
  cell_factor_size = 0 if us is None else us.shape[1]
1189
1160
 
1190
- latent_dist = args.z_dist
1191
-
1192
1161
  ###########################################
1193
1162
  sure = SURE(
1194
1163
  input_size=input_size,
1195
1164
  cell_factor_size=cell_factor_size,
1196
1165
  inverse_dispersion=args.inverse_dispersion,
1197
- latent_dim=args.latent_dim,
1166
+ z_dim=args.z_dim,
1198
1167
  hidden_layers=args.hidden_layers,
1199
1168
  hidden_layer_activation=args.hidden_layer_activation,
1200
1169
  use_cuda=args.cuda,
1201
1170
  config_enum=args.enum_discrete,
1202
- use_dirichlet=args.use_dirichlet,
1203
- zero_inflation=args.zero_inflation,
1204
- gate_prior=args.gate_prior,
1205
- delta=args.delta,
1171
+ use_zeroinflate=args.zeroinflate,
1206
1172
  loss_func=args.likelihood,
1207
- dirichlet_mass=args.dirichlet_mass,
1208
1173
  nn_dropout=args.layer_dropout_rate,
1209
1174
  post_layer_fct=args.post_layer_function,
1210
1175
  post_act_fct=args.post_activation_function,
1211
1176
  codebook_size=args.codebook_size,
1212
- latent_dist = latent_dist,
1177
+ z_dist = args.z_dist,
1213
1178
  dtype=dtype,
1214
1179
  )
1215
1180
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: SURE-tools
3
- Version: 2.1.4
3
+ Version: 2.1.6
4
4
  Summary: Succinct Representation of Single Cells
5
5
  Home-page: https://github.com/ZengFLab/SURE
6
6
  Author: Feng Zeng
@@ -1,5 +1,5 @@
1
- SURE/PerturbFlow.py,sha256=M_h66aHWTl5Vf0TC6e-UJrC4T3FQGH0wFLZ95jgdt7U,51674
2
- SURE/SURE.py,sha256=_ZOymj24DLQju0Lb90lKspHPmqIUDDzjIEr9t4qgqCI,48364
1
+ SURE/PerturbFlow.py,sha256=eW_RUuNd-D4UUY-YhCHklWm0TMdOdfyuzfvwRWwfwAc,49553
2
+ SURE/SURE.py,sha256=u60yf0eVa8MglVXZiIyGiwJGTBFqn3-92CC6VlvHp6w,47434
3
3
  SURE/SURE2.py,sha256=8wlnMwb1xuf9QUksNkWdWx5ZWq-xIy9NLx8RdUnE82o,48501
4
4
  SURE/__init__.py,sha256=xV10iBbh69g4mjBMb1cQxjuHe8e3Aq7pDzkZmx5G754,260
5
5
  SURE/assembly/__init__.py,sha256=jxZLURXKPzXe21LhrZ09LgZr33iqdjlQy4oSEj5gR2Q,172
@@ -17,9 +17,9 @@ SURE/utils/__init__.py,sha256=Htqv4KqVKcRiaaTBsR-6yZ4LSlbhbzutjNKXGD9-uds,660
17
17
  SURE/utils/custom_mlp.py,sha256=07TYX1HgxfEjb_3i5MpiZfNhOhx3dKntuwGkrpteWiM,7036
18
18
  SURE/utils/queue.py,sha256=E_5PA5EWcBoGAZj8BkKQnkCK0p4C-4-xcTPqdIXaPXU,1892
19
19
  SURE/utils/utils.py,sha256=IUHjDDtYaAYllCWsZyIzqQwaLul6fJRvHRH4vIYcR-c,8462
20
- sure_tools-2.1.4.dist-info/licenses/LICENSE,sha256=TFHKwmrAViXQbSX5W-NDItkWFjm45HWOeUniDrqmnu0,1065
21
- sure_tools-2.1.4.dist-info/METADATA,sha256=oJNTuJ2eturhVARGLQkmZ3-wGhGmWTTMAOXHABKUhJg,2650
22
- sure_tools-2.1.4.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
23
- sure_tools-2.1.4.dist-info/entry_points.txt,sha256=-nJI8rVe_qqrR0HmfAODzj-JNfEqCcSsyVh6okSqyHk,83
24
- sure_tools-2.1.4.dist-info/top_level.txt,sha256=BtFTebdiJeqra4r6mm-uEtwVRFLZ_IjYsQ7OnalrOvY,5
25
- sure_tools-2.1.4.dist-info/RECORD,,
20
+ sure_tools-2.1.6.dist-info/licenses/LICENSE,sha256=TFHKwmrAViXQbSX5W-NDItkWFjm45HWOeUniDrqmnu0,1065
21
+ sure_tools-2.1.6.dist-info/METADATA,sha256=BP7LH0ln9g3DZeiuQlh8Mjxz8hHunCF3n9bwelpQ1c4,2650
22
+ sure_tools-2.1.6.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
23
+ sure_tools-2.1.6.dist-info/entry_points.txt,sha256=-nJI8rVe_qqrR0HmfAODzj-JNfEqCcSsyVh6okSqyHk,83
24
+ sure_tools-2.1.6.dist-info/top_level.txt,sha256=BtFTebdiJeqra4r6mm-uEtwVRFLZ_IjYsQ7OnalrOvY,5
25
+ sure_tools-2.1.6.dist-info/RECORD,,