SURE-tools 2.1.34__py3-none-any.whl → 2.2.24__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -54,26 +54,27 @@ def set_random_seed(seed):
54
54
  # Set seed for Pyro
55
55
  pyro.set_rng_seed(seed)
56
56
 
57
- class PerturbFlow(nn.Module):
57
+ class DensityFlow(nn.Module):
58
58
  def __init__(self,
59
59
  input_size: int,
60
60
  codebook_size: int = 200,
61
61
  cell_factor_size: int = 0,
62
+ turn_off_cell_specific: bool = False,
62
63
  supervised_mode: bool = False,
63
64
  z_dim: int = 10,
64
- z_dist: Literal['normal','studentt','laplacian','cauchy','gumbel'] = 'normal',
65
- loss_func: Literal['negbinomial','poisson','multinomial','bernoulli'] = 'negbinomial',
65
+ z_dist: Literal['normal','studentt','laplacian','cauchy','gumbel'] = 'gumbel',
66
+ loss_func: Literal['negbinomial','poisson','multinomial','bernoulli'] = 'multinomial',
66
67
  inverse_dispersion: float = 10.0,
67
- use_zeroinflate: bool = True,
68
- hidden_layers: list = [300],
68
+ use_zeroinflate: bool = False,
69
+ hidden_layers: list = [500],
69
70
  hidden_layer_activation: Literal['relu','softplus','leakyrelu','linear'] = 'relu',
70
71
  nn_dropout: float = 0.1,
71
72
  post_layer_fct: list = ['layernorm'],
72
73
  post_act_fct: list = None,
73
74
  config_enum: str = 'parallel',
74
- use_cuda: bool = False,
75
+ use_cuda: bool = True,
75
76
  seed: int = 42,
76
- zero_bias: bool = True,
77
+ zero_bias: bool|list = True,
77
78
  dtype = torch.float32, # type: ignore
78
79
  ):
79
80
  super().__init__()
@@ -97,7 +98,12 @@ class PerturbFlow(nn.Module):
97
98
  self.post_layer_fct = post_layer_fct
98
99
  self.post_act_fct = post_act_fct
99
100
  self.hidden_layer_activation = hidden_layer_activation
100
- self.use_bias = not zero_bias
101
+ if type(zero_bias) == list:
102
+ self.use_bias = [not x for x in zero_bias]
103
+ else:
104
+ self.use_bias = [not zero_bias] * self.cell_factor_size
105
+ #self.use_bias = not zero_bias
106
+ self.turn_off_cell_specific = turn_off_cell_specific
101
107
 
102
108
  self.codebook_weights = None
103
109
 
@@ -198,38 +204,62 @@ class PerturbFlow(nn.Module):
198
204
  if self.cell_factor_size>0:
199
205
  self.cell_factor_effect = nn.ModuleList()
200
206
  for i in np.arange(self.cell_factor_size):
201
- if self.use_bias:
202
- self.cell_factor_effect.append(MLP(
203
- [self.latent_dim+1] + self.decoder_hidden_layers + [self.latent_dim],
204
- activation=activate_fct,
205
- output_activation=None,
206
- post_layer_fct=post_layer_fct,
207
- post_act_fct=post_act_fct,
208
- allow_broadcast=self.allow_broadcast,
209
- use_cuda=self.use_cuda,
207
+ if self.use_bias[i]:
208
+ if self.turn_off_cell_specific:
209
+ self.cell_factor_effect.append(MLP(
210
+ [1] + self.decoder_hidden_layers + [self.latent_dim],
211
+ activation=activate_fct,
212
+ output_activation=None,
213
+ post_layer_fct=post_layer_fct,
214
+ post_act_fct=post_act_fct,
215
+ allow_broadcast=self.allow_broadcast,
216
+ use_cuda=self.use_cuda,
217
+ )
218
+ )
219
+ else:
220
+ self.cell_factor_effect.append(MLP(
221
+ [self.latent_dim+1] + self.decoder_hidden_layers + [self.latent_dim],
222
+ activation=activate_fct,
223
+ output_activation=None,
224
+ post_layer_fct=post_layer_fct,
225
+ post_act_fct=post_act_fct,
226
+ allow_broadcast=self.allow_broadcast,
227
+ use_cuda=self.use_cuda,
228
+ )
210
229
  )
211
- )
212
230
  else:
213
- self.cell_factor_effect.append(ZeroBiasMLP(
214
- [self.latent_dim+1] + self.decoder_hidden_layers + [self.latent_dim],
215
- activation=activate_fct,
216
- output_activation=None,
217
- post_layer_fct=post_layer_fct,
218
- post_act_fct=post_act_fct,
219
- allow_broadcast=self.allow_broadcast,
220
- use_cuda=self.use_cuda,
231
+ if self.turn_off_cell_specific:
232
+ self.cell_factor_effect.append(ZeroBiasMLP(
233
+ [1] + self.decoder_hidden_layers + [self.latent_dim],
234
+ activation=activate_fct,
235
+ output_activation=None,
236
+ post_layer_fct=post_layer_fct,
237
+ post_act_fct=post_act_fct,
238
+ allow_broadcast=self.allow_broadcast,
239
+ use_cuda=self.use_cuda,
240
+ )
241
+ )
242
+ else:
243
+ self.cell_factor_effect.append(ZeroBiasMLP(
244
+ [self.latent_dim+1] + self.decoder_hidden_layers + [self.latent_dim],
245
+ activation=activate_fct,
246
+ output_activation=None,
247
+ post_layer_fct=post_layer_fct,
248
+ post_act_fct=post_act_fct,
249
+ allow_broadcast=self.allow_broadcast,
250
+ use_cuda=self.use_cuda,
251
+ )
221
252
  )
222
- )
223
253
 
224
254
  self.decoder_concentrate = MLP(
225
- [self.latent_dim] + self.decoder_hidden_layers + [self.input_size],
226
- activation=activate_fct,
227
- output_activation=None,
228
- post_layer_fct=post_layer_fct,
229
- post_act_fct=post_act_fct,
230
- allow_broadcast=self.allow_broadcast,
231
- use_cuda=self.use_cuda,
232
- )
255
+ [self.latent_dim] + self.decoder_hidden_layers + [self.input_size],
256
+ activation=activate_fct,
257
+ output_activation=None,
258
+ post_layer_fct=post_layer_fct,
259
+ post_act_fct=post_act_fct,
260
+ allow_broadcast=self.allow_broadcast,
261
+ use_cuda=self.use_cuda,
262
+ )
233
263
 
234
264
  if self.latent_dist == 'studentt':
235
265
  self.codebook = MLP(
@@ -304,7 +334,7 @@ class PerturbFlow(nn.Module):
304
334
  return xs
305
335
 
306
336
  def model1(self, xs):
307
- pyro.module('PerturbFlow', self)
337
+ pyro.module('DensityFlow', self)
308
338
 
309
339
  eps = torch.finfo(xs.dtype).eps
310
340
  batch_size = xs.size(0)
@@ -318,7 +348,7 @@ class PerturbFlow(nn.Module):
318
348
  gate_logits = pyro.param("dropout_rate", xs.new_zeros(self.input_size))
319
349
 
320
350
  acs_scale = pyro.param("codebook_scale", xs.new_ones(self.latent_dim), constraint=constraints.positive)
321
-
351
+
322
352
  I = torch.eye(self.code_size)
323
353
  if self.latent_dist=='studentt':
324
354
  acs_dof,acs_loc = self.codebook(I)
@@ -347,12 +377,13 @@ class PerturbFlow(nn.Module):
347
377
 
348
378
  zs = zns
349
379
  concentrate = self.decoder_concentrate(zs)
350
- if self.loss_func == 'bernoulli':
380
+ if self.loss_func in ['bernoulli']:
351
381
  log_theta = concentrate
352
382
  else:
353
383
  rate = concentrate.exp()
354
- if self.loss_func != 'poisson':
355
- theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
384
+ theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
385
+ if self.loss_func == 'poisson':
386
+ rate = theta * torch.sum(xs, dim=1, keepdim=True)
356
387
 
357
388
  if self.loss_func == 'negbinomial':
358
389
  if self.use_zeroinflate:
@@ -374,14 +405,15 @@ class PerturbFlow(nn.Module):
374
405
 
375
406
  def guide1(self, xs):
376
407
  with pyro.plate('data'):
377
- zn_loc, zn_scale = self.encoder_zn(xs)
408
+ #zn_loc, zn_scale = self.encoder_zn(xs)
409
+ zn_loc, zn_scale = self._get_basal_embedding(xs)
378
410
  zns = pyro.sample('zn', dist.Normal(zn_loc, zn_scale).to_event(1))
379
411
 
380
412
  alpha = self.encoder_n(zns)
381
413
  ns = pyro.sample('n', dist.OneHotCategorical(logits=alpha))
382
414
 
383
415
  def model2(self, xs, us=None):
384
- pyro.module('PerturbFlow', self)
416
+ pyro.module('DensityFlow', self)
385
417
 
386
418
  eps = torch.finfo(xs.dtype).eps
387
419
  batch_size = xs.size(0)
@@ -423,23 +455,19 @@ class PerturbFlow(nn.Module):
423
455
  zns = pyro.sample('zn', dist.Gumbel(zn_loc, zn_scale).to_event(1))
424
456
 
425
457
  if self.cell_factor_size>0:
426
- zus = None
427
- for i in np.arange(self.cell_factor_size):
428
- if i==0:
429
- zus = self.cell_factor_effect[i]([zns,us[:,i].reshape(-1,1)])
430
- else:
431
- zus = zus + self.cell_factor_effect[i]([zns,us[:,i].reshape(-1,1)])
458
+ zus = self._total_effects(zns, us)
432
459
  zs = zns+zus
433
460
  else:
434
461
  zs = zns
435
462
 
436
463
  concentrate = self.decoder_concentrate(zs)
437
- if self.loss_func == 'bernoulli':
464
+ if self.loss_func in ['bernoulli']:
438
465
  log_theta = concentrate
439
466
  else:
440
467
  rate = concentrate.exp()
441
- if self.loss_func != 'poisson':
442
- theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
468
+ theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
469
+ if self.loss_func == 'poisson':
470
+ rate = theta * torch.sum(xs, dim=1, keepdim=True)
443
471
 
444
472
  if self.loss_func == 'negbinomial':
445
473
  if self.use_zeroinflate:
@@ -461,14 +489,15 @@ class PerturbFlow(nn.Module):
461
489
 
462
490
  def guide2(self, xs, us=None):
463
491
  with pyro.plate('data'):
464
- zn_loc, zn_scale = self.encoder_zn(xs)
492
+ #zn_loc, zn_scale = self.encoder_zn(xs)
493
+ zn_loc, zn_scale = self._get_basal_embedding(xs)
465
494
  zns = pyro.sample('zn', dist.Normal(zn_loc, zn_scale).to_event(1))
466
495
 
467
496
  alpha = self.encoder_n(zns)
468
497
  ns = pyro.sample('n', dist.OneHotCategorical(logits=alpha))
469
498
 
470
499
  def model3(self, xs, ys, embeds=None):
471
- pyro.module('PerturbFlow', self)
500
+ pyro.module('DensityFlow', self)
472
501
 
473
502
  eps = torch.finfo(xs.dtype).eps
474
503
  batch_size = xs.size(0)
@@ -528,12 +557,13 @@ class PerturbFlow(nn.Module):
528
557
  zs = zns
529
558
 
530
559
  concentrate = self.decoder_concentrate(zs)
531
- if self.loss_func == 'bernoulli':
560
+ if self.loss_func in ['bernoulli']:
532
561
  log_theta = concentrate
533
562
  else:
534
563
  rate = concentrate.exp()
535
- if self.loss_func != 'poisson':
536
- theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
564
+ theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
565
+ if self.loss_func == 'poisson':
566
+ rate = theta * torch.sum(xs, dim=1, keepdim=True)
537
567
 
538
568
  if self.loss_func == 'negbinomial':
539
569
  if self.use_zeroinflate:
@@ -556,11 +586,14 @@ class PerturbFlow(nn.Module):
556
586
  def guide3(self, xs, ys, embeds=None):
557
587
  with pyro.plate('data'):
558
588
  if embeds is None:
559
- zn_loc, zn_scale = self.encoder_zn(xs)
589
+ #zn_loc, zn_scale = self.encoder_zn(xs)
590
+ zn_loc, zn_scale = self._get_basal_embedding(xs)
560
591
  zns = pyro.sample('zn', dist.Normal(zn_loc, zn_scale).to_event(1))
592
+ else:
593
+ zns = embeds
561
594
 
562
595
  def model4(self, xs, us, ys, embeds=None):
563
- pyro.module('PerturbFlow', self)
596
+ pyro.module('DensityFlow', self)
564
597
 
565
598
  eps = torch.finfo(xs.dtype).eps
566
599
  batch_size = xs.size(0)
@@ -618,23 +651,25 @@ class PerturbFlow(nn.Module):
618
651
  zns = pyro.sample('zn', dist.Gumbel(zn_loc, zn_scale).to_event(1), obs=embeds)
619
652
 
620
653
  if self.cell_factor_size>0:
621
- zus = None
622
- for i in np.arange(self.cell_factor_size):
623
- if i==0:
624
- zus = self.cell_factor_effect[i]([zns,us[:,i].reshape(-1,1)])
625
- else:
626
- zus = zus + self.cell_factor_effect[i]([zns,us[:,i].reshape(-1,1)])
654
+ #zus = None
655
+ #for i in np.arange(self.cell_factor_size):
656
+ # if i==0:
657
+ # zus = self.cell_factor_effect[i]([zns,us[:,i].reshape(-1,1)])
658
+ # else:
659
+ # zus = zus + self.cell_factor_effect[i]([zns,us[:,i].reshape(-1,1)])
660
+ zus = self._total_effects(zns, us)
627
661
  zs = zns+zus
628
662
  else:
629
663
  zs = zns
630
664
 
631
665
  concentrate = self.decoder_concentrate(zs)
632
- if self.loss_func == 'bernoulli':
666
+ if self.loss_func in ['bernoulli']:
633
667
  log_theta = concentrate
634
668
  else:
635
669
  rate = concentrate.exp()
636
- if self.loss_func != 'poisson':
637
- theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
670
+ theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
671
+ if self.loss_func == 'poisson':
672
+ rate = theta * torch.sum(xs, dim=1, keepdim=True)
638
673
 
639
674
  if self.loss_func == 'negbinomial':
640
675
  if self.use_zeroinflate:
@@ -657,9 +692,32 @@ class PerturbFlow(nn.Module):
657
692
  def guide4(self, xs, us, ys, embeds=None):
658
693
  with pyro.plate('data'):
659
694
  if embeds is None:
660
- zn_loc, zn_scale = self.encoder_zn(xs)
695
+ #zn_loc, zn_scale = self.encoder_zn(xs)
696
+ zn_loc, zn_scale = self._get_basal_embedding(xs)
661
697
  zns = pyro.sample('zn', dist.Normal(zn_loc, zn_scale).to_event(1))
662
-
698
+ else:
699
+ zns = embeds
700
+
701
+ def _total_effects(self, zns, us):
702
+ zus = None
703
+ for i in np.arange(self.cell_factor_size):
704
+ if i==0:
705
+ #if self.turn_off_cell_specific:
706
+ # zus = self.cell_factor_effect[i](us[:,i].reshape(-1,1))
707
+ #else:
708
+ # zus = self.cell_factor_effect[i]([zns,us[:,i].reshape(-1,1)])
709
+ zus = self._cell_response(zns, i, us[:,i].reshape(-1,1))
710
+ else:
711
+ #if self.turn_off_cell_specific:
712
+ # zus = zus + self.cell_factor_effect[i](us[:,i].reshape(-1,1))
713
+ #else:
714
+ # zus = zus + self.cell_factor_effect[i]([zns,us[:,i].reshape(-1,1)])
715
+ zus = zus + self._cell_response(zns, i, us[:,i].reshape(-1,1))
716
+ return zus
717
+
718
+ def _get_codebook_identity(self):
719
+ return torch.eye(self.code_size, **self.options)
720
+
663
721
  def _get_codebook(self):
664
722
  I = torch.eye(self.code_size, **self.options)
665
723
  if self.latent_dist=='studentt':
@@ -672,13 +730,13 @@ class PerturbFlow(nn.Module):
672
730
  """
673
731
  Return the mean part of metacell codebook
674
732
  """
675
- cb = self._get_metacell_coordinates()
733
+ cb = self._get_codebook()
676
734
  cb = tensor_to_numpy(cb)
677
735
  return cb
678
736
 
679
737
  def _get_basal_embedding(self, xs):
680
- zns, _ = self.encoder_zn(xs)
681
- return zns
738
+ loc, scale = self.encoder_zn(xs)
739
+ return loc, scale
682
740
 
683
741
  def get_basal_embedding(self,
684
742
  xs,
@@ -705,7 +763,7 @@ class PerturbFlow(nn.Module):
705
763
  Z = []
706
764
  with tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
707
765
  for X_batch, _ in dataloader:
708
- zns = self._get_basal_embedding(X_batch)
766
+ zns,_ = self._get_basal_embedding(X_batch)
709
767
  Z.append(tensor_to_numpy(zns))
710
768
  pbar.update(1)
711
769
 
@@ -716,7 +774,8 @@ class PerturbFlow(nn.Module):
716
774
  if self.supervised_mode:
717
775
  alpha = self.encoder_n(xs)
718
776
  else:
719
- zns,_ = self.encoder_zn(xs)
777
+ #zns,_ = self.encoder_zn(xs)
778
+ zns,_ = self._get_basal_embedding(xs)
720
779
  alpha = self.encoder_n(zns)
721
780
  return alpha
722
781
 
@@ -785,46 +844,80 @@ class PerturbFlow(nn.Module):
785
844
  A = np.concatenate(A)
786
845
  return A
787
846
 
788
- def _cell_response(self, xs, factor_idx, perturb):
789
- zns,_ = self.encoder_zn(xs)
847
+ def predict(self, xs, us, perturbs_predict:list, perturbs_reference:list, library_sizes=None):
848
+ perturbs_reference = np.array(perturbs_reference)
849
+
850
+ # basal embedding
851
+ zs = self.get_basal_embedding(xs)
852
+ for pert in perturbs_predict:
853
+ pert_idx = int(np.where(perturbs_reference==pert)[0])
854
+ us_i = us[:,pert_idx].reshape(-1,1)
855
+
856
+ # factor effect of xs
857
+ dzs0 = self.get_cell_response(zs, perturb_idx=pert_idx, perturb_us=us_i)
858
+
859
+ # perturbation effect
860
+ ps = np.ones_like(us_i)
861
+ if np.sum(np.abs(ps-us_i))>=1:
862
+ dzs = self.get_cell_response(zs, perturb_idx=pert_idx, perturb_us=ps)
863
+ zs = zs + dzs0 + dzs
864
+ else:
865
+ zs = zs + dzs0
866
+
867
+ if library_sizes is None:
868
+ library_sizes = np.sum(xs, axis=1, keepdims=True)
869
+ elif type(library_sizes) == list:
870
+ library_sizes = np.array(library_sizes)
871
+ library_sizes = library_sizes.reshape(-1,1)
872
+ elif len(library_sizes.shape)==1:
873
+ library_sizes = library_sizes.reshape(-1,1)
874
+
875
+ counts = self.get_counts(zs, library_sizes=library_sizes)
876
+
877
+ return counts, zs
878
+
879
+ def _cell_response(self, zs, perturb_idx, perturb):
880
+ #zns,_ = self.encoder_zn(xs)
881
+ #zns,_ = self._get_basal_embedding(xs)
882
+ zns = zs
790
883
  if perturb.ndim==2:
791
- ms = self.cell_factor_effect[factor_idx]([zns, perturb])
884
+ if self.turn_off_cell_specific:
885
+ ms = self.cell_factor_effect[perturb_idx](perturb)
886
+ else:
887
+ ms = self.cell_factor_effect[perturb_idx]([zns, perturb])
792
888
  else:
793
- ms = self.cell_factor_effect[factor_idx]([zns, perturb.reshape(-1,1)])
889
+ if self.turn_off_cell_specific:
890
+ ms = self.cell_factor_effect[perturb_idx](perturb.reshape(-1,1))
891
+ else:
892
+ ms = self.cell_factor_effect[perturb_idx]([zns, perturb.reshape(-1,1)])
794
893
 
795
894
  return ms
796
895
 
797
896
  def get_cell_response(self,
798
- xs,
799
- factor_idx,
800
- perturb,
897
+ zs,
898
+ perturb_idx,
899
+ perturb_us,
801
900
  batch_size: int = 1024):
802
901
  """
803
902
  Return cells' changes in the latent space induced by specific perturbation of a factor
804
903
 
805
904
  """
806
- xs = self.preprocess(xs)
807
- xs = convert_to_tensor(xs, device=self.get_device())
808
- ps = convert_to_tensor(perturb, device=self.get_device())
809
- dataset = CustomDataset2(xs,ps)
905
+ #xs = self.preprocess(xs)
906
+ zs = convert_to_tensor(zs, device=self.get_device())
907
+ ps = convert_to_tensor(perturb_us, device=self.get_device())
908
+ dataset = CustomDataset2(zs,ps)
810
909
  dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
811
910
 
812
911
  Z = []
813
912
  with tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
814
- for X_batch, P_batch, _ in dataloader:
815
- zns = self._cell_response(X_batch, factor_idx, P_batch)
913
+ for Z_batch, P_batch, _ in dataloader:
914
+ zns = self._cell_response(Z_batch, perturb_idx, P_batch)
816
915
  Z.append(tensor_to_numpy(zns))
817
916
  pbar.update(1)
818
917
 
819
918
  Z = np.concatenate(Z)
820
919
  return Z
821
920
 
822
- def get_metacell_response(self, factor_idx, perturb):
823
- zs = self._get_codebook()
824
- ps = convert_to_tensor(perturb, device=self.get_device())
825
- ms = self.cell_factor_effect[factor_idx]([zs,ps])
826
- return tensor_to_numpy(ms)
827
-
828
921
  def _get_expression_response(self, delta_zs):
829
922
  return self.decoder_concentrate(delta_zs)
830
923
 
@@ -849,38 +942,35 @@ class PerturbFlow(nn.Module):
849
942
  R = np.concatenate(R)
850
943
  return R
851
944
 
852
- def _count(self,concentrate):
853
- if self.loss_func == 'bernoulli':
854
- counts = self.sigmoid(concentrate)
855
- else:
856
- counts = concentrate.exp()
857
- return counts
858
-
859
- def _count_sample(self,concentrate):
945
+ def _count(self, concentrate, library_size=None):
860
946
  if self.loss_func == 'bernoulli':
861
- logits = concentrate
862
- counts = dist.Bernoulli(logits=logits).to_event(1).sample()
947
+ #counts = self.sigmoid(concentrate)
948
+ counts = dist.Bernoulli(logits=concentrate).to_event(1).mean
863
949
  else:
864
- counts = self._count(concentrate=concentrate)
865
- counts = dist.Poisson(rate=counts).to_event(1).sample()
950
+ rate = concentrate.exp()
951
+ theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
952
+ counts = theta * library_size
866
953
  return counts
867
954
 
868
- def get_counts(self, zs,
869
- batch_size: int = 1024,
870
- use_sampler: bool = False):
955
+ def get_counts(self, zs, library_sizes,
956
+ batch_size: int = 1024):
871
957
 
872
958
  zs = convert_to_tensor(zs, device=self.get_device())
873
- dataset = CustomDataset(zs)
959
+
960
+ if type(library_sizes) == list:
961
+ library_sizes = np.array(library_sizes).reshape(-1,1)
962
+ elif len(library_sizes.shape)==1:
963
+ library_sizes = library_sizes.reshape(-1,1)
964
+ ls = convert_to_tensor(library_sizes, device=self.get_device())
965
+
966
+ dataset = CustomDataset2(zs,ls)
874
967
  dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
875
968
 
876
969
  E = []
877
970
  with tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
878
- for Z_batch, _ in dataloader:
879
- concentrate = self._expression(Z_batch)
880
- if use_sampler:
881
- counts = self._count_sample(concentrate)
882
- else:
883
- counts = self._count(concentrate)
971
+ for Z_batch, L_batch, _ in dataloader:
972
+ concentrate = self._get_expression_response(Z_batch)
973
+ counts = self._count(concentrate, L_batch)
884
974
  E.append(tensor_to_numpy(counts))
885
975
  pbar.update(1)
886
976
 
@@ -903,7 +993,7 @@ class PerturbFlow(nn.Module):
903
993
  us = None,
904
994
  ys = None,
905
995
  zs = None,
906
- num_epochs: int = 200,
996
+ num_epochs: int = 500,
907
997
  learning_rate: float = 0.0001,
908
998
  batch_size: int = 256,
909
999
  algo: Literal['adam','rmsprop','adamw'] = 'adam',
@@ -912,9 +1002,9 @@ class PerturbFlow(nn.Module):
912
1002
  decay_rate: float = 0.9,
913
1003
  config_enum: str = 'parallel',
914
1004
  threshold: int = 0,
915
- use_jax: bool = False):
1005
+ use_jax: bool = True):
916
1006
  """
917
- Train the PerturbFlow model.
1007
+ Train the DensityFlow model.
918
1008
 
919
1009
  Parameters
920
1010
  ----------
@@ -940,7 +1030,7 @@ class PerturbFlow(nn.Module):
940
1030
  Parameter for optimization.
941
1031
  use_jax
942
1032
  If toggled on, Jax will be used for speeding up. CAUTION: This will raise errors because of unknown reasons when it is called in
943
- the Python script or Jupyter notebook. It is OK if it is used when runing PerturbFlow in the shell command.
1033
+ the Python script or Jupyter notebook. It is OK if it is used when runing DensityFlow in the shell command.
944
1034
  """
945
1035
  xs = self.preprocess(xs, threshold=threshold)
946
1036
  xs = convert_to_tensor(xs, dtype=self.dtype, device=self.get_device())
@@ -1025,7 +1115,7 @@ class PerturbFlow(nn.Module):
1025
1115
  # Update progress bar
1026
1116
  pbar.set_postfix({'loss': str_loss})
1027
1117
  pbar.update(1)
1028
-
1118
+
1029
1119
  @classmethod
1030
1120
  def save_model(cls, model, file_path, compression=False):
1031
1121
  """Save the model to the specified file path."""
@@ -1058,12 +1148,12 @@ class PerturbFlow(nn.Module):
1058
1148
 
1059
1149
 
1060
1150
  EXAMPLE_RUN = (
1061
- "example run: PerturbFlow --help"
1151
+ "example run: DensityFlow --help"
1062
1152
  )
1063
1153
 
1064
1154
  def parse_args():
1065
1155
  parser = argparse.ArgumentParser(
1066
- description="PerturbFlow\n{}".format(EXAMPLE_RUN))
1156
+ description="DensityFlow\n{}".format(EXAMPLE_RUN))
1067
1157
 
1068
1158
  parser.add_argument(
1069
1159
  "--cuda", action="store_true", help="use GPU(s) to speed up training"
@@ -1250,7 +1340,7 @@ def main():
1250
1340
  cell_factor_size = 0 if us is None else us.shape[1]
1251
1341
 
1252
1342
  ###########################################
1253
- perturbflow = PerturbFlow(
1343
+ DensityFlow = DensityFlow(
1254
1344
  input_size=input_size,
1255
1345
  cell_factor_size=cell_factor_size,
1256
1346
  inverse_dispersion=args.inverse_dispersion,
@@ -1269,7 +1359,7 @@ def main():
1269
1359
  dtype=dtype,
1270
1360
  )
1271
1361
 
1272
- perturbflow.fit(xs, us=us,
1362
+ DensityFlow.fit(xs, us=us,
1273
1363
  num_epochs=args.num_epochs,
1274
1364
  learning_rate=args.learning_rate,
1275
1365
  batch_size=args.batch_size,
@@ -1281,12 +1371,11 @@ def main():
1281
1371
 
1282
1372
  if args.save_model is not None:
1283
1373
  if args.save_model.endswith('gz'):
1284
- PerturbFlow.save_model(perturbflow, args.save_model, compression=True)
1374
+ DensityFlow.save_model(DensityFlow, args.save_model, compression=True)
1285
1375
  else:
1286
- PerturbFlow.save_model(perturbflow, args.save_model)
1376
+ DensityFlow.save_model(DensityFlow, args.save_model)
1287
1377
 
1288
1378
 
1289
1379
 
1290
1380
  if __name__ == "__main__":
1291
-
1292
1381
  main()
SURE/SURE.py CHANGED
@@ -99,19 +99,18 @@ class SURE(nn.Module):
99
99
  cell_factor_size: int = 0,
100
100
  supervised_mode: bool = False,
101
101
  z_dim: int = 10,
102
- z_dist: Literal['normal','studentt','laplacian','cauchy','gumbel'] = 'normal',
103
- loss_func: Literal['negbinomial','poisson','multinomial','bernoulli'] = 'negbinomial',
102
+ z_dist: Literal['normal','studentt','laplacian','cauchy','gumbel'] = 'gumbel',
103
+ loss_func: Literal['negbinomial','poisson','multinomial','bernoulli'] = 'poisson',
104
104
  inverse_dispersion: float = 10.0,
105
105
  use_zeroinflate: bool = True,
106
- hidden_layers: list = [300],
106
+ hidden_layers: list = [500],
107
107
  hidden_layer_activation: Literal['relu','softplus','leakyrelu','linear'] = 'relu',
108
108
  nn_dropout: float = 0.1,
109
109
  post_layer_fct: list = ['layernorm'],
110
110
  post_act_fct: list = None,
111
111
  config_enum: str = 'parallel',
112
- use_cuda: bool = False,
112
+ use_cuda: bool = True,
113
113
  seed: int = 42,
114
- zero_bias: bool = True,
115
114
  dtype = torch.float32, # type: ignore
116
115
  ):
117
116
  super().__init__()
@@ -135,7 +134,6 @@ class SURE(nn.Module):
135
134
  self.post_layer_fct = post_layer_fct
136
135
  self.post_act_fct = post_act_fct
137
136
  self.hidden_layer_activation = hidden_layer_activation
138
- self.use_bias = not zero_bias
139
137
 
140
138
  self.codebook_weights = None
141
139
 
@@ -234,26 +232,16 @@ class SURE(nn.Module):
234
232
  )
235
233
 
236
234
  if self.cell_factor_size>0:
237
- if self.use_bias:
238
- self.cell_factor_effect = MLP(
239
- [self.latent_dim + self.cell_factor_size] + self.decoder_hidden_layers + [self.latent_dim],
240
- activation=activate_fct,
241
- output_activation=None,
242
- post_layer_fct=post_layer_fct,
243
- post_act_fct=post_act_fct,
244
- allow_broadcast=self.allow_broadcast,
245
- use_cuda=self.use_cuda,
246
- )
247
- else:
248
- self.cell_factor_effect = ZeroBiasMLP(
249
- [self.latent_dim + self.cell_factor_size] + self.decoder_hidden_layers + [self.latent_dim],
250
- activation=activate_fct,
251
- output_activation=None,
252
- post_layer_fct=post_layer_fct,
253
- post_act_fct=post_act_fct,
254
- allow_broadcast=self.allow_broadcast,
255
- use_cuda=self.use_cuda,
256
- )
235
+ self.cell_factor_effect = MLP(
236
+ [self.latent_dim + self.cell_factor_size] + self.decoder_hidden_layers + [self.latent_dim],
237
+ activation=activate_fct,
238
+ output_activation=None,
239
+ post_layer_fct=post_layer_fct,
240
+ post_act_fct=post_act_fct,
241
+ allow_broadcast=self.allow_broadcast,
242
+ use_cuda=self.use_cuda,
243
+ )
244
+
257
245
 
258
246
  self.decoder_concentrate = MLP(
259
247
  [self.latent_dim] + self.decoder_hidden_layers + [self.input_size],
@@ -381,12 +369,13 @@ class SURE(nn.Module):
381
369
 
382
370
  zs = zns
383
371
  concentrate = self.decoder_concentrate(zs)
384
- if self.loss_func == 'bernoulli':
372
+ if self.loss_func in ['bernoulli']:
385
373
  log_theta = concentrate
386
374
  else:
387
375
  rate = concentrate.exp()
388
- if self.loss_func != 'poisson':
389
- theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
376
+ theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
377
+ if self.loss_func == 'poisson':
378
+ rate = theta * torch.sum(xs, dim=1, keepdim=True)
390
379
 
391
380
  if self.loss_func == 'negbinomial':
392
381
  if self.use_zeroinflate:
@@ -463,12 +452,13 @@ class SURE(nn.Module):
463
452
  zs = zns
464
453
 
465
454
  concentrate = self.decoder_concentrate(zs)
466
- if self.loss_func == 'bernoulli':
455
+ if self.loss_func in ['bernoulli']:
467
456
  log_theta = concentrate
468
457
  else:
469
458
  rate = concentrate.exp()
470
- if self.loss_func != 'poisson':
471
- theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
459
+ theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
460
+ if self.loss_func == 'poisson':
461
+ rate = theta * torch.sum(xs, dim=1, keepdim=True)
472
462
 
473
463
  if self.loss_func == 'negbinomial':
474
464
  if self.use_zeroinflate:
@@ -557,12 +547,13 @@ class SURE(nn.Module):
557
547
  zs = zns
558
548
 
559
549
  concentrate = self.decoder_concentrate(zs)
560
- if self.loss_func == 'bernoulli':
550
+ if self.loss_func in ['bernoulli']:
561
551
  log_theta = concentrate
562
552
  else:
563
553
  rate = concentrate.exp()
564
- if self.loss_func != 'poisson':
565
- theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
554
+ theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
555
+ if self.loss_func == 'poisson':
556
+ rate = theta * torch.sum(xs, dim=1, keepdim=True)
566
557
 
567
558
  if self.loss_func == 'negbinomial':
568
559
  if self.use_zeroinflate:
@@ -653,13 +644,14 @@ class SURE(nn.Module):
653
644
  zs = zns
654
645
 
655
646
  concentrate = self.decoder_concentrate(zs)
656
- if self.loss_func == 'bernoulli':
647
+ if self.loss_func in ['bernoulli']:
657
648
  log_theta = concentrate
658
649
  else:
659
650
  rate = concentrate.exp()
660
- if self.loss_func != 'poisson':
661
- theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
662
-
651
+ theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
652
+ if self.loss_func == 'poisson':
653
+ rate = theta * torch.sum(xs, dim=1, keepdim=True)
654
+
663
655
  if self.loss_func == 'negbinomial':
664
656
  if self.use_zeroinflate:
665
657
  pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=total_count, probs=theta),gate_logits=gate_logits).to_event(1), obs=xs)
@@ -825,7 +817,7 @@ class SURE(nn.Module):
825
817
  us = None,
826
818
  ys = None,
827
819
  zs = None,
828
- num_epochs: int = 200,
820
+ num_epochs: int = 500,
829
821
  learning_rate: float = 0.0001,
830
822
  batch_size: int = 256,
831
823
  algo: Literal['adam','rmsprop','adamw'] = 'adam',
@@ -834,7 +826,7 @@ class SURE(nn.Module):
834
826
  decay_rate: float = 0.9,
835
827
  config_enum: str = 'parallel',
836
828
  threshold: int = 0,
837
- use_jax: bool = False):
829
+ use_jax: bool = True):
838
830
  """
839
831
  Train the SURE model.
840
832
 
SURE/__init__.py CHANGED
@@ -1,12 +1,12 @@
1
1
  from .SURE import SURE
2
- from .PerturbFlow import PerturbFlow
2
+ from .DensityFlow import DensityFlow
3
3
 
4
4
  from . import utils
5
5
  from . import codebook
6
6
  from . import SURE
7
- from . import PerturbFlow
7
+ from . import DensityFlow
8
8
  from . import atac
9
9
  from . import flow
10
10
  from . import perturb
11
11
 
12
- __all__ = ['SURE', 'PerturbFlow', 'flow', 'perturb', 'atac', 'utils', 'codebook']
12
+ __all__ = ['SURE', 'DensityFlow', 'flow', 'perturb', 'atac', 'utils', 'codebook']
SURE/flow/flow_stats.py CHANGED
@@ -1,4 +1,5 @@
1
1
  import numpy as np
2
+ from scipy.interpolate import griddata
2
3
  from scipy.spatial.distance import pdist, squareform
3
4
  from sklearn.decomposition import PCA
4
5
  from scipy.stats import pearsonr
@@ -16,6 +17,42 @@ class VectorFieldEval:
16
17
 
17
18
  def momentum_flow_metric(self, vectors, masses=None):
18
19
  return momentum_flow_metric(vectors=vectors, masses=masses)
20
+
21
+ def divergence(self, points, vectors, grid_resolution=30):
22
+ # 提取坐标和向量分量
23
+ x_coords = points[:, 0]
24
+ y_coords = points[:, 1]
25
+ u_components = vectors[:, 0] # x方向分量
26
+ v_components = vectors[:, 1] # y方向分量
27
+
28
+ # 创建规则网格
29
+ x_grid = np.linspace(x_coords.min(), x_coords.max(), grid_resolution)
30
+ y_grid = np.linspace(y_coords.min(), y_coords.max(), grid_resolution)
31
+ X, Y = np.meshgrid(x_grid, y_grid)
32
+
33
+ # 插值到网格
34
+ U_grid = griddata((x_coords, y_coords), u_components, (X, Y), method='linear')
35
+ V_grid = griddata((x_coords, y_coords), v_components, (X, Y), method='linear')
36
+
37
+ # 计算散度
38
+ dU_dx = np.gradient(U_grid, x_grid, axis=1)
39
+ dV_dy = np.gradient(V_grid, y_grid, axis=0)
40
+ divergence = dU_dx + dV_dy
41
+ divergence[np.isnan(divergence)] = 0
42
+
43
+ return divergence
44
+
45
+ def movement_stats(self,vectors):
46
+ return calculate_movement_stats(vectors)
47
+
48
+ def direction_stats(self, vectors):
49
+ return calculate_direction_stats(vectors)
50
+
51
+ def movement_energy(self, vectors, masses=None):
52
+ return calculate_movement_energy(vectors, masses)
53
+
54
+ def movement_divergence(self, positions, vectors):
55
+ return calculate_movement_divergence(positions, vectors)
19
56
 
20
57
 
21
58
  def calculate_movement_stats(vectors):
SURE/perturb/__init__.py CHANGED
@@ -1 +1 @@
1
- from .perturb import LabelMatrix
1
+ from .perturb import LabelMatrix,DoseMatrix
SURE/perturb/perturb.py CHANGED
@@ -1,5 +1,7 @@
1
1
  import re
2
2
  import numpy as np
3
+ import pandas as pd
4
+ from numba import njit
3
5
  from itertools import chain
4
6
  from joblib import Parallel, delayed
5
7
  from typing import Literal
@@ -7,8 +9,10 @@ from typing import Literal
7
9
  class LabelMatrix:
8
10
  def __init__(self):
9
11
  self.labels_ = None
12
+ self.control_label = None
13
+ self.sep_pattern = None
10
14
 
11
- def fit_transform(self, labels, sep_pattern=r'[;_\-\s]', speedup: Literal['none','vectorize','parallel']='none'):
15
+ def fit_transform(self, labels, control_label=None, sep_pattern=r'[,;_\s]', speedup: Literal['none','vectorize','parallel']='none'):
12
16
  if speedup=='none':
13
17
  mat, self.labels_ = label_to_matrix(labels=labels, sep_pattern=sep_pattern)
14
18
  elif speedup=='vectorize':
@@ -17,10 +21,53 @@ class LabelMatrix:
17
21
  mat, self.labels_ = parallel_label_to_matrix(labels=labels, sep_pattern=sep_pattern)
18
22
 
19
23
  self.labels_ = np.array(self.labels_)
20
- return mat
24
+
25
+ if control_label is not None:
26
+ idx = np.where(self.labels_==control_label)[0]
27
+ mat = np.delete(mat, idx, axis=1)
28
+ self.labels_ = np.delete(self.labels_, idx)
21
29
 
30
+ self.control_label = control_label
31
+ self.sep_pattern=sep_pattern
32
+
33
+ return mat
34
+
35
+ def transform(self, labels, speedup: Literal['none','vectorize','parallel']='none'):
36
+ sep_pattern = self.sep_pattern
37
+ if speedup=='none':
38
+ mat, labels_ = label_to_matrix(labels=labels, sep_pattern=sep_pattern)
39
+ elif speedup=='vectorize':
40
+ mat, labels_ = vectorized_label_to_matrix(labels=labels, sep_pattern=sep_pattern)
41
+ elif speedup=='parallel':
42
+ mat, labels_ = parallel_label_to_matrix(labels=labels, sep_pattern=sep_pattern)
43
+
44
+ mat_df = pd.DataFrame(mat, columns=labels_)
45
+
46
+ labels_valid = [x for x in labels_ if x in self.labels_]
47
+ mat_df = mat_df[labels_valid]
48
+
49
+ mat_valid = np.zeros([mat.shape[0], len(self.labels_)])
50
+ mat_valid_df = pd.DataFrame(mat_valid, columns=self.labels_)
51
+ mat_valid_df[labels_valid] = mat_df
52
+
53
+ return mat_valid_df.values
54
+
22
55
  def inverse_transform(self, matrix):
23
56
  return matrix_to_labels(matrix=matrix, unique_labels=self.labels_)
57
+
58
+ class DoseMatrix:
59
+ def __init__(self):
60
+ self.labels_ = None
61
+
62
+ def fit_transform(self, labels, label_dose, control_label=None):
63
+ mat, self.labels_ = dose_to_matrix(labels, label_dose)
64
+
65
+ if control_label is not None:
66
+ idx = np.where(self.labels_==control_label)[0]
67
+ mat = np.delete(mat, idx, axis=1)
68
+ self.labels_ = np.delete(self.labels_, idx)
69
+
70
+ return mat
24
71
 
25
72
  def label_to_matrix(labels, sep_pattern=r'[;_\-\s]'):
26
73
  """
@@ -85,3 +132,38 @@ def parallel_label_to_matrix(labels, sep_pattern=r'[;_\-\s]', n_jobs=4):
85
132
  def matrix_to_labels(matrix, unique_labels):
86
133
  return [';'.join([unique_labels[i] for i in np.where(row)[0]])
87
134
  for row in matrix]
135
+
136
+
137
+
138
+
139
+
140
+
141
+
142
+ @njit(parallel=True)
143
+ def _numba_fill_matrix(dose_matrix, label_indices, label_doses):
144
+ """Numba 加速的矩阵填充函数"""
145
+ for i in range(len(label_indices)):
146
+ dose_matrix[i, label_indices[i]] = label_doses[i]
147
+
148
+ def dose_to_matrix(labels, label_dose, all_labels=None):
149
+ """
150
+ 使用 Numba 的终极加速版本(需预先安装 numba)
151
+ """
152
+ if all_labels is None:
153
+ all_labels = sorted(set().union(labels))
154
+
155
+ label_to_idx = {label: idx for idx, label in enumerate(all_labels)}
156
+ n_samples = len(labels)
157
+ n_labels = len(all_labels)
158
+ dose_matrix = np.zeros((n_samples, n_labels), dtype=np.float64)
159
+
160
+ # 预处理为 Numba 兼容格式
161
+ label_indices = []
162
+ label_doses = []
163
+ for i, label in enumerate(labels):
164
+ label_indices.append(label_to_idx[label])
165
+ label_doses.append(label_dose[i])
166
+
167
+ # 调用 Numba 加速函数
168
+ _numba_fill_matrix(dose_matrix, label_indices, label_doses)
169
+ return dose_matrix,np.array(all_labels)
SURE/utils/__init__.py CHANGED
@@ -7,7 +7,7 @@ from .utils import find_partitions_greedy
7
7
 
8
8
  from .queue import PriorityQueue
9
9
 
10
- from .custom_mlp import MLP, Exp, ZeroBiasMLP
10
+ from .custom_mlp import MLP, Exp, ZeroBiasMLP, HDMLP
11
11
 
12
12
  # Importing modules
13
13
  #from . import utils
SURE/utils/custom_mlp.py CHANGED
@@ -239,6 +239,43 @@ class ZeroBiasMLP(nn.Module):
239
239
  def forward(self, x):
240
240
  y = self.mlp(x)
241
241
  mask = torch.zeros_like(y)
242
- mask[x[1][:,0]>0,:] = 1
242
+ if len(y.shape)==2:
243
+ mask[x[1][:,0]>0,:] = 1
244
+ elif len(y.shape)==3:
245
+ mask[:,x[1][:,0]>0,:] = 1
243
246
  return y*mask
244
-
247
+
248
+
249
+ class HDMLP(nn.Module):
250
+ def __init__(
251
+ self,
252
+ input_size,
253
+ hidden_sizes,
254
+ output_depth,
255
+ activation=nn.ReLU,
256
+ output_activation=None,
257
+ post_layer_fct=lambda layer_ix, total_layers, layer: None,
258
+ post_act_fct=lambda layer_ix, total_layers, layer: None,
259
+ allow_broadcast=False,
260
+ use_cuda=False,
261
+ ):
262
+ # init the module object
263
+ super().__init__()
264
+ self.mlp = MLP(mlp_sizes=[1] + hidden_sizes + [output_depth],
265
+ activation=activation,
266
+ output_activation=output_activation,
267
+ post_layer_fct=post_layer_fct,
268
+ post_act_fct=post_act_fct,
269
+ allow_broadcast=allow_broadcast,
270
+ use_cuda=use_cuda,
271
+ bias=True)
272
+ self.input_size=input_size
273
+ self.output_depth=output_depth
274
+
275
+ # pass through our sequential for the output!
276
+ def forward(self, x):
277
+ batch_size, n = x.shape
278
+ x = x.view(batch_size * n, 1)
279
+ out = self.mlp(x)
280
+ out = out.view(batch_size, n, self.output_depth)
281
+ return out
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: SURE-tools
3
- Version: 2.1.34
3
+ Version: 2.2.24
4
4
  Summary: Succinct Representation of Single Cells
5
5
  Home-page: https://github.com/ZengFLab/SURE
6
6
  Author: Feng Zeng
@@ -20,6 +20,7 @@ Requires-Dist: numpy
20
20
  Requires-Dist: scikit-learn
21
21
  Requires-Dist: pandas
22
22
  Requires-Dist: pyro-ppl
23
+ Requires-Dist: jax[cuda12]
23
24
  Requires-Dist: leidenalg
24
25
  Requires-Dist: python-igraph
25
26
  Requires-Dist: networkx
@@ -0,0 +1,25 @@
1
+ SURE/DensityFlow.py,sha256=IpObVzq3pb2GAYt8f0rCkR8d9YdsRg1RwPruvNKHoHM,56132
2
+ SURE/SURE.py,sha256=MXs7iuvcj-lU4dJ_MwKegpL2Rqk2HB4eFfAgHRA3RtA,47744
3
+ SURE/__init__.py,sha256=NVp22RCHrhSwHNMomABC-eftoCYvt7vV1XOzim-UZHE,293
4
+ SURE/assembly/__init__.py,sha256=jxZLURXKPzXe21LhrZ09LgZr33iqdjlQy4oSEj5gR2Q,172
5
+ SURE/assembly/assembly.py,sha256=6IMdelPOiRO4mUb4dC7gVCoF1Uvfw86-Map8P_jnUag,21477
6
+ SURE/assembly/atlas.py,sha256=ALjmVWutm_tOHTcT1aqOxmuCEQw-XzrtDoMCV_8oXLk,21794
7
+ SURE/atac/__init__.py,sha256=3smP8IKHfwNCd1G_sZH3pKHXuLkLpFuLtjUTUSy7_As,34
8
+ SURE/atac/utils.py,sha256=m4NYwpy9O5T1pXTzgCOCcmlwrC6GTi-cQ5sm2wZu2O8,4354
9
+ SURE/codebook/__init__.py,sha256=2T5gjp8JIaBayrXAnOJYSebQHsWprOs87difpR1OPNw,243
10
+ SURE/codebook/codebook.py,sha256=ZlN6gRX9Gj2D2u3P5KeOsbZri0MoMAiJo9lNeL-MK-I,17117
11
+ SURE/flow/__init__.py,sha256=rsAjYsh1xVIrxBCuwOE0Q_6N5th1wBgjJceV0ABPG3c,183
12
+ SURE/flow/flow_stats.py,sha256=6SzNMT59WRFRP1nC6bvpBPF7BugWnkIS_DSlr4S-Ez0,11338
13
+ SURE/flow/plot_quiver.py,sha256=UbmuScUcgbQHeMmjKmgqxjrIjHhiHx0VWct16UMMwuE,8110
14
+ SURE/perturb/__init__.py,sha256=8TP1dSUhXiZzKpFebHZmm8XMMGbUz_OfQ10xu-6uPPY,43
15
+ SURE/perturb/perturb.py,sha256=ey7cxsM1tO1MW4UaE_MLpLHK87CjvXzn2CBPtvv1VZ0,6116
16
+ SURE/utils/__init__.py,sha256=YF5jB-PAHJQ40OlcZ7BCZbsN2q1JKuPT6EppilRXQqM,680
17
+ SURE/utils/custom_mlp.py,sha256=HuNb7f8-6RFjsvfEu1XOuNpLrHZkGYHgf8TpJfPSNO0,9382
18
+ SURE/utils/queue.py,sha256=E_5PA5EWcBoGAZj8BkKQnkCK0p4C-4-xcTPqdIXaPXU,1892
19
+ SURE/utils/utils.py,sha256=IUHjDDtYaAYllCWsZyIzqQwaLul6fJRvHRH4vIYcR-c,8462
20
+ sure_tools-2.2.24.dist-info/licenses/LICENSE,sha256=TFHKwmrAViXQbSX5W-NDItkWFjm45HWOeUniDrqmnu0,1065
21
+ sure_tools-2.2.24.dist-info/METADATA,sha256=oQslRmRo5_NDhapldLCnsck6dXrGEEHj-VAEl4XzWNU,2678
22
+ sure_tools-2.2.24.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
23
+ sure_tools-2.2.24.dist-info/entry_points.txt,sha256=-nJI8rVe_qqrR0HmfAODzj-JNfEqCcSsyVh6okSqyHk,83
24
+ sure_tools-2.2.24.dist-info/top_level.txt,sha256=BtFTebdiJeqra4r6mm-uEtwVRFLZ_IjYsQ7OnalrOvY,5
25
+ sure_tools-2.2.24.dist-info/RECORD,,
@@ -1,25 +0,0 @@
1
- SURE/PerturbFlow.py,sha256=BoaNDubCKpsYJcwipZxrSCpol4nVvCttP28MizHffzY,51650
2
- SURE/SURE.py,sha256=ghagk4vO3xrAXwdyYTIv7y0X2KXr1R2baXH8lqvUl7k,48094
3
- SURE/__init__.py,sha256=NOJI_K-eCqPgStXXvgl3wIEMp6d8saMTDYLJ7Ga9MqE,293
4
- SURE/assembly/__init__.py,sha256=jxZLURXKPzXe21LhrZ09LgZr33iqdjlQy4oSEj5gR2Q,172
5
- SURE/assembly/assembly.py,sha256=6IMdelPOiRO4mUb4dC7gVCoF1Uvfw86-Map8P_jnUag,21477
6
- SURE/assembly/atlas.py,sha256=ALjmVWutm_tOHTcT1aqOxmuCEQw-XzrtDoMCV_8oXLk,21794
7
- SURE/atac/__init__.py,sha256=3smP8IKHfwNCd1G_sZH3pKHXuLkLpFuLtjUTUSy7_As,34
8
- SURE/atac/utils.py,sha256=m4NYwpy9O5T1pXTzgCOCcmlwrC6GTi-cQ5sm2wZu2O8,4354
9
- SURE/codebook/__init__.py,sha256=2T5gjp8JIaBayrXAnOJYSebQHsWprOs87difpR1OPNw,243
10
- SURE/codebook/codebook.py,sha256=ZlN6gRX9Gj2D2u3P5KeOsbZri0MoMAiJo9lNeL-MK-I,17117
11
- SURE/flow/__init__.py,sha256=rsAjYsh1xVIrxBCuwOE0Q_6N5th1wBgjJceV0ABPG3c,183
12
- SURE/flow/flow_stats.py,sha256=cBBsPEDpWNMpbzlyQ3f0385RSrX6_5RCH2caOyi4ihM,9908
13
- SURE/flow/plot_quiver.py,sha256=UbmuScUcgbQHeMmjKmgqxjrIjHhiHx0VWct16UMMwuE,8110
14
- SURE/perturb/__init__.py,sha256=ouxShhbxZM4r5Gf7GmKiutrsmtyq7QL8rHjhgF0BU08,32
15
- SURE/perturb/perturb.py,sha256=CqO3xPfNA3cG175tadDidKvGsTu_yKfJRRLn_93awKM,3303
16
- SURE/utils/__init__.py,sha256=QJUOfrXzdWSmoM0P3LH8oKEHttzCWqpDy2UF0F0dtN4,673
17
- SURE/utils/custom_mlp.py,sha256=rHnx9jEef02zfCUdbYVCmbuHcDdIBmRgt__wpdpZvYg,8104
18
- SURE/utils/queue.py,sha256=E_5PA5EWcBoGAZj8BkKQnkCK0p4C-4-xcTPqdIXaPXU,1892
19
- SURE/utils/utils.py,sha256=IUHjDDtYaAYllCWsZyIzqQwaLul6fJRvHRH4vIYcR-c,8462
20
- sure_tools-2.1.34.dist-info/licenses/LICENSE,sha256=TFHKwmrAViXQbSX5W-NDItkWFjm45HWOeUniDrqmnu0,1065
21
- sure_tools-2.1.34.dist-info/METADATA,sha256=EV5AA3dO2YrSVqVC0wytPt0RkCIb7nuws0FSUhNcXuE,2651
22
- sure_tools-2.1.34.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
23
- sure_tools-2.1.34.dist-info/entry_points.txt,sha256=-nJI8rVe_qqrR0HmfAODzj-JNfEqCcSsyVh6okSqyHk,83
24
- sure_tools-2.1.34.dist-info/top_level.txt,sha256=BtFTebdiJeqra4r6mm-uEtwVRFLZ_IjYsQ7OnalrOvY,5
25
- sure_tools-2.1.34.dist-info/RECORD,,