pertpy 0.9.3__py3-none-any.whl → 0.9.5__py3-none-any.whl

@@ -18,12 +18,16 @@ from scvi.data.fields import CategoricalObsField, LayerField
  from scvi.model.base import BaseModelClass, JaxTrainingMixin
  from scvi.utils import setup_anndata_dsp

+ from pertpy._doc import _doc_params, doc_common_plot_args
+
  from ._scgenvae import JaxSCGENVAE
  from ._utils import balancer, extractor

  if TYPE_CHECKING:
  from collections.abc import Sequence

+ from matplotlib.pyplot import Figure
+
  font = {"family": "Arial", "size": 14}

@@ -377,9 +381,8 @@ class Scgen(JaxTrainingMixin, BaseModelClass):
  condition_key: str,
  axis_keys: dict[str, str],
  labels: dict[str, str],
- save: str | bool | None = None,
+ *,
  gene_list: list[str] = None,
- show: bool = False,
  top_100_genes: list[str] = None,
  verbose: bool = False,
  legend: bool = True,
@@ -387,6 +390,8 @@ class Scgen(JaxTrainingMixin, BaseModelClass):
  x_coeff: float = 0.30,
  y_coeff: float = 0.8,
  fontsize: float = 14,
+ show: bool = False,
+ save: str | bool | None = None,
  **kwargs,
  ) -> tuple[float, float] | float:
  """Plots mean matching for a set of specified genes.
@@ -397,21 +402,23 @@ class Scgen(JaxTrainingMixin, BaseModelClass):
  corresponding to batch and cell type metadata, respectively.
  condition_key: The key for the condition
  axis_keys: Dictionary of `adata.obs` keys that are used by the axes of the plot. Has to be in the following form:
- `{"x": "Key for x-axis", "y": "Key for y-axis"}`.
- labels: Dictionary of axes labels of the form `{"x": "x-axis-name", "y": "y-axis name"}`.
- path_to_save: path to save the plot.
- save: Specify if the plot should be saved or not.
+ {`x`: `Key for x-axis`, `y`: `Key for y-axis`}.
+ labels: Dictionary of axes labels of the form {`x`: `x-axis-name`, `y`: `y-axis name`}.
  gene_list: list of gene names to be plotted.
- show: if `True`: will show to the plot after saving it.
  top_100_genes: List of the top 100 differentially expressed genes. Specify if you want the top 100 DEGs to be assessed extra.
- verbose: Specify if you want information to be printed while creating the plot.,
+ verbose: Specify if you want information to be printed while creating the plot.
  legend: Whether to plot a legend.
  title: Set if you want the plot to display a title.
  x_coeff: Offset to print the R^2 value in x-direction.
  y_coeff: Offset to print the R^2 value in y-direction.
  fontsize: Fontsize used for text in the plot.
+ show: if `True`, will show to the plot after saving it.
+ save: Specify if the plot should be saved or not.
  **kwargs:

+ Returns:
+ Returns R^2 value for all genes and R^2 value for top 100 DEGs if `top_100_genes` is not `None`.
+
  Examples:
  >>> import pertpy as pt
  >>> data = pt.dt.kang_2018()
@@ -498,6 +505,7 @@ class Scgen(JaxTrainingMixin, BaseModelClass):
  r"$\mathrm{R^2_{\mathrm{\mathsf{top\ 100\ DEGs}}}}$= " + f"{r_value_diff ** 2:.2f}",
  fontsize=kwargs.get("textsize", fontsize),
  )
+
  if save:
  plt.savefig(save, bbox_inches="tight")
  if show:
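The hunks above make every argument after `labels` keyword-only (the new bare `*`) and move `show`/`save` to the end of the signature; the hunks that follow apply the same change to the variance-matching plot. A minimal sketch of a call under the new signature, assuming the method shown here is `Scgen.plot_reg_mean_plot` and following the `kang_2018` setup from its docstring example; the training settings and gene names are illustrative only:

import pertpy as pt

data = pt.dt.kang_2018()
pt.tl.Scgen.setup_anndata(data, batch_key="label", labels_key="cell_type")
model = pt.tl.Scgen(data)
model.train(max_epochs=10)  # shortened, illustrative training run

# Everything after `labels` must now be passed by keyword,
# with `show` and `save` at the end of the signature.
r2_all = model.plot_reg_mean_plot(
    data,
    condition_key="label",
    axis_keys={"x": "ctrl", "y": "stim"},
    labels={"x": "control", "y": "stimulated"},
    gene_list=["ISG15", "IFI6"],  # illustrative gene names
    show=False,
    save=None,
)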
@@ -514,16 +522,17 @@ class Scgen(JaxTrainingMixin, BaseModelClass):
  condition_key: str,
  axis_keys: dict[str, str],
  labels: dict[str, str],
- save: str | bool | None = None,
+ *,
  gene_list: list[str] = None,
  top_100_genes: list[str] = None,
- show: bool = False,
  legend: bool = True,
  title: str = None,
  verbose: bool = False,
  x_coeff: float = 0.3,
  y_coeff: float = 0.8,
  fontsize: float = 14,
+ show: bool = True,
+ save: str | bool | None = None,
  **kwargs,
  ) -> tuple[float, float] | float:
  """Plots variance matching for a set of specified genes.
@@ -534,19 +543,18 @@ class Scgen(JaxTrainingMixin, BaseModelClass):
  corresponding to batch and cell type metadata, respectively.
  condition_key: Key of the condition.
  axis_keys: Dictionary of `adata.obs` keys that are used by the axes of the plot. Has to be in the following form:
- `{"x": "Key for x-axis", "y": "Key for y-axis"}`.
- labels: Dictionary of axes labels of the form `{"x": "x-axis-name", "y": "y-axis name"}`.
- path_to_save: path to save the plot.
- save: Specify if the plot should be saved or not.
+ {"x": "Key for x-axis", "y": "Key for y-axis"}.
+ labels: Dictionary of axes labels of the form {"x": "x-axis-name", "y": "y-axis name"}.
  gene_list: list of gene names to be plotted.
- show: if `True`: will show to the plot after saving it.
  top_100_genes: List of the top 100 differentially expressed genes. Specify if you want the top 100 DEGs to be assessed extra.
- legend: Whether to plot a elgend
+ legend: Whether to plot a legend.
  title: Set if you want the plot to display a title.
  verbose: Specify if you want information to be printed while creating the plot.
  x_coeff: Offset to print the R^2 value in x-direction.
  y_coeff: Offset to print the R^2 value in y-direction.
  fontsize: Fontsize used for text in the plot.
+ show: if `True`, will show to the plot after saving it.
+ save: Specify if the plot should be saved or not.
  """
  import seaborn as sns

@@ -636,6 +644,7 @@ class Scgen(JaxTrainingMixin, BaseModelClass):
  else:
  return r_value**2

+ @_doc_params(common_plot_args=doc_common_plot_args)
  def plot_binary_classifier(
  self,
  scgen: Scgen,
@@ -643,10 +652,11 @@ class Scgen(JaxTrainingMixin, BaseModelClass):
  delta: np.ndarray,
  ctrl_key: str,
  stim_key: str,
- show: bool = False,
- save: str | bool | None = None,
+ *,
  fontsize: float = 14,
- ) -> plt.Axes | None:
+ show: bool = True,
+ return_fig: bool = False,
+ ) -> Figure | None:
  """Plots the dot product between delta and latent representation of a linear classifier.

  Builds a linear classifier based on the dot product between
@@ -661,9 +671,11 @@ class Scgen(JaxTrainingMixin, BaseModelClass):
  delta: Difference between stimulated and control cells in latent space
  ctrl_key: Key for `control` part of the `data` found in `condition_key`.
  stim_key: Key for `stimulated` part of the `data` found in `condition_key`.
- path_to_save: Path to save the plot.
- save: Specify if the plot should be saved or not.
  fontsize: Set the font size of the plot.
+ {common_plot_args}
+
+ Returns:
+ If `return_fig` is `True`, returns the figure, otherwise `None`.
  """
  plt.close("all")
  adata = scgen._validate_anndata(adata)
@@ -693,12 +705,10 @@ class Scgen(JaxTrainingMixin, BaseModelClass):
  ax = plt.gca()
  ax.grid(False)

- if save:
- plt.savefig(save, bbox_inches="tight")
  if show:
  plt.show()
- if not (show or save):
- return ax
+ if return_fig:
+ return plt.gcf()
  return None
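In `plot_binary_classifier`, the `save` parameter is replaced by `return_fig` (documented through the shared `{common_plot_args}` block that the new `_doc_params` decorator fills in), the return type changes from `plt.Axes` to `Figure`, and the internal `plt.savefig` call is gone, so saving is now the caller's job. A hedged sketch of the new pattern; `model` and `delta` are assumed to come from a trained `Scgen` and its prediction step, which this diff does not show:

# `model` is a trained pertpy Scgen instance and `delta` the latent-space shift
# from its prediction step (both assumed, not part of this diff).
fig = model.plot_binary_classifier(
    scgen=model,
    adata=None,  # assumption: None falls back to the AnnData registered with the model
    delta=delta,
    ctrl_key="ctrl",
    stim_key="stim",
    show=False,
    return_fig=True,
)
if fig is not None:
    # The method no longer saves internally; persist the figure yourself.
    fig.savefig("binary_classifier.png", bbox_inches="tight")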
@@ -1,6 +1,6 @@
- Metadata-Version: 2.3
+ Metadata-Version: 2.4
  Name: pertpy
- Version: 0.9.3
+ Version: 0.9.5
  Summary: Perturbation Analysis in the scverse ecosystem.
  Project-URL: Documentation, https://pertpy.readthedocs.io
  Project-URL: Source, https://github.com/scverse/pertpy
@@ -44,7 +44,7 @@ Classifier: Programming Language :: Python :: 3.11
  Classifier: Programming Language :: Python :: 3.12
  Classifier: Topic :: Scientific/Engineering :: Bio-Informatics
  Classifier: Topic :: Scientific/Engineering :: Visualization
- Requires-Python: >=3.10
+ Requires-Python: <3.13,>=3.10
  Requires-Dist: adjusttext
  Requires-Dist: blitzgsea
  Requires-Dist: decoupler
@@ -68,7 +68,8 @@ Requires-Dist: pyqt5; extra == 'coda'
  Requires-Dist: toytree; extra == 'coda'
  Provides-Extra: de
  Requires-Dist: formulaic; extra == 'de'
- Requires-Dist: pydeseq2; extra == 'de'
+ Requires-Dist: formulaic-contrasts>=0.2.0; extra == 'de'
+ Requires-Dist: pydeseq2>=v0.5.0pre1; extra == 'de'
  Provides-Extra: dev
  Requires-Dist: pre-commit; extra == 'dev'
  Provides-Extra: doc
@@ -1,57 +1,57 @@
- pertpy/__init__.py,sha256=BDOzyW_PnNzv7Nfa8Skj90mC9T1ILiYtxI_bPXwhc1E,658
+ pertpy/__init__.py,sha256=r5QhDw2-Ls4yYLs1kJJVe_r6dstQ7SjoASFutlTU9JA,658
+ pertpy/_doc.py,sha256=pVt5Iegvh4rC1N81fd9e4cwmoGPNSgttZxxPWbLK6Bs,453
  pertpy/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  pertpy/data/__init__.py,sha256=ah3yvoxkgbdMUNAWxS3SyqcUuVamBOSeuWkF2QRAEwM,2703
- pertpy/data/_dataloader.py,sha256=fl16n82nun01gGiP7qhr5sShfcDchp0szzZp7aXkfBI,2495
- pertpy/data/_datasets.py,sha256=I-keaJSTsRBySCPjiVonKmC9rRIM0AEgo0_0UlEX804,65616
+ pertpy/data/_dataloader.py,sha256=ENbk1T3w3N6tVI11V4FVUxuWFEwOHP8_kIB-ehiMlVQ,2428
+ pertpy/data/_datasets.py,sha256=c_U2U2NvncPZc6vs6w_s77zmWqr8ywDpzmkx6edCfUE,65616
  pertpy/metadata/__init__.py,sha256=zoE_VXNyuKa4nlXlUk2nTgsHRW3jSQSpDEulcCnzOT0,222
- pertpy/metadata/_cell_line.py,sha256=-8KSqmP5XjmLEmNX3TavxSM_MtIHwLWS_x3MVkk6JEw,38595
- pertpy/metadata/_compound.py,sha256=JEFwP_TOTyMzfd2qFMb2VkJJvPhCVIvu6gs9Bq_stgs,4756
+ pertpy/metadata/_cell_line.py,sha256=Ell5PDVoMlrhHXPDKGCiPGwNY0DAeghbUUvTYL-SFF0,38919
+ pertpy/metadata/_compound.py,sha256=ywNNqtib0exHv0z8ctmTRf1Hk64tSGWSiUEffycxf6A,4755
  pertpy/metadata/_drug.py,sha256=8QDSyxiFl25JdS80EQJC_krg6fEe5LIQEE6BsV1r8nY,9006
  pertpy/metadata/_look_up.py,sha256=DoWp6OxIk_HyyyOhW1p8z5E68IZ31_nZDnqxk1rJqps,28778
- pertpy/metadata/_metadata.py,sha256=pvarnv3X5pblnvG8AQ8Omu5jQcC5ORzCxRk3FRhOLgs,3276
+ pertpy/metadata/_metadata.py,sha256=hV2LTFrExddLNU_RsDkZju6lQUSRoP4OIn_dumCyQao,3277
  pertpy/metadata/_moa.py,sha256=u_OcMonjOeeoW5P9xOltquVSoTH3Vs80ztHsXf-X1DY,4701
  pertpy/plot/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  pertpy/preprocessing/__init__.py,sha256=VAPFaeq2_qCvdFkQTCj_Hm460HC4Tersu8Rig_tnp_Y,71
- pertpy/preprocessing/_guide_rna.py,sha256=IgKhXEyfRwEA7ccKJNLA_aIxKHm09QJINM09KaIwn68,7644
- pertpy/tools/__init__.py,sha256=Zl4W4bWDIIUVRUFDh1qjT14Rg2hjJ6gvRHcupV_sywk,2647
- pertpy/tools/_augur.py,sha256=UWro1nIEZe_rWtjlQCBv4ucqeh3Vt1m8IRzKlux72Z8,55683
- pertpy/tools/_cinemaot.py,sha256=BD_oYC1TktbFMX7fpp0A57QAF6frLEgNQ_2wFUpxjyo,39509
- pertpy/tools/_dialogue.py,sha256=f2fbhKWdm4Co79ZzVgtVq9xYwjYWFLdGNDeGFOO_pfM,51990
- pertpy/tools/_enrichment.py,sha256=rjPHK9YBCJZfpa5Rvfxo3Ii7W5Mvm5dOdolAD7QazVg,21440
+ pertpy/preprocessing/_guide_rna.py,sha256=D9hEh8LOOTs_UfVBsBW3b-o6ipRFq09J471Hg1s0tlM,7963
+ pertpy/tools/__init__.py,sha256=NUTwCGxRdzUzLTgsS3r7MywENwPAdcGZDKrl83sU8mo,2599
+ pertpy/tools/_augur.py,sha256=Vghsx5-fYlaEeu__8-HUg6v5_KoVNhiRDjhgE3pORpY,55339
+ pertpy/tools/_cinemaot.py,sha256=U6vCb_mI4ZPFshYgsx-hOOsDA1IPwI7ZR_-IH4F9s7s,39621
+ pertpy/tools/_dialogue.py,sha256=BShXZ1ehO2eMbP5PV-ONJ-1SsxD6h9nAN7bGQ4_F6Rw,51906
+ pertpy/tools/_enrichment.py,sha256=jxVdOrpS_lAu7GCpemgdB4JJvsGH9SJTQsAKLBKi9Tc,21640
  pertpy/tools/_kernel_pca.py,sha256=_EJ9WlBLjHOafF34sZGdyBgZL6Fj0WiJ1elVT1XMmo4,1579
- pertpy/tools/_milo.py,sha256=FDFGmGMkJiVrvATEnOeAMCe-Q2w7F0nbBMuACVbyIQI,43699
- pertpy/tools/_mixscape.py,sha256=FtH3PKvbLTe03LPgN4O9sS70oj_6AHz4Mz5otzEwRl8,52406
+ pertpy/tools/_milo.py,sha256=SQqknT2zkzI0pcUmTm0ijWMs7CFMRiyRnXt9rC0jvmg,43811
+ pertpy/tools/_mixscape.py,sha256=T-oUHDnepao5aujAHw9bAbbQHPSK6oD_8Wr_mw4U0nc,52089
  pertpy/tools/decoupler_LICENSE,sha256=OXLcl0T2SZ8Pmy2_dmlvKuetivmyPd5m1q-Gyd-zaYY,35149
  pertpy/tools/transferlearning_MMD_LICENSE,sha256=MUvDA-o_j9htRpI8fStVdCRuyLdPkQUuIH0a_EIc57w,1069
  pertpy/tools/_coda/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
- pertpy/tools/_coda/_base_coda.py,sha256=jnoLPFfluxB0_CK8-T-qolPa7xPIEb6NpeEpGwHPiNg,113058
+ pertpy/tools/_coda/_base_coda.py,sha256=Uxb_vAORL77LkUri8QqL9SxaULK4hfpgRGsCQdtsAgk,111808
  pertpy/tools/_coda/_sccoda.py,sha256=gGmyd0MGpchulV9d4PxKSmGORyZ8fCDS9tQVPOuF_Og,22622
  pertpy/tools/_coda/_tasccoda.py,sha256=vNk43OQHn7pBLsez2rmSj0bMZKOf8jZTI7G8TfBByRg,30665
- pertpy/tools/_differential_gene_expression/__init__.py,sha256=sabAXym8mMLwp19ZjyBN7wp-oJh32iVj9plvJ-AbXlE,521
- pertpy/tools/_differential_gene_expression/_base.py,sha256=qnQkK_hyIcViHBSkgJcAazC26JQ72bEyafKiytZikCY,23624
+ pertpy/tools/_differential_gene_expression/__init__.py,sha256=SEydWg0iT3Y1pApjnCAOuHxFeI6xVUfgyBHv2s3LADU,487
+ pertpy/tools/_differential_gene_expression/_base.py,sha256=yc9DBj2KgJVk4mkjz7EDFoBj8WBZW92Z4ayD-Xdla1g,38514
  pertpy/tools/_differential_gene_expression/_checks.py,sha256=SxNHJDsCYZ6rWLTMEymEBpigs_B9cnXyw0kkAe1l6e0,1675
  pertpy/tools/_differential_gene_expression/_dge_comparison.py,sha256=9HjmWkrqZhj_ZJeR-ymyEDzpRJNx7JiYJoStvCfKuCU,4188
- pertpy/tools/_differential_gene_expression/_edger.py,sha256=JziiW5rkXuQBJISAD_LvB2HOZUgJ1_qoqiR5Q4hEoP0,4321
- pertpy/tools/_differential_gene_expression/_formulaic.py,sha256=X4rPv4j8SDu5VJnf6_AIYJCCquUQka7G2LGtDLa8FhE,8715
- pertpy/tools/_differential_gene_expression/_pydeseq2.py,sha256=JK7H7u4va0q_TLE_sqi4JEzoPBd_xNRycYGu1507HS4,4117
+ pertpy/tools/_differential_gene_expression/_edger.py,sha256=ttgTocAYnr8BTDcixwHGjRZew6zeja-U77TLKkSdd1Y,4857
+ pertpy/tools/_differential_gene_expression/_pydeseq2.py,sha256=aOqsdu8hKp8_h2HhjkxS0B_itxRBnzEU2oSnU2PYiQ4,2942
  pertpy/tools/_differential_gene_expression/_simple_tests.py,sha256=tTSr0Z2Qbpxdy9bcO8Gi_up6R616IcoK_e4_rlanyx4,6621
- pertpy/tools/_differential_gene_expression/_statsmodels.py,sha256=zSOwJYDJyrl3hsEhMI5Q9Pyw2XLuEuj7T0zSAVcP6tQ,2585
+ pertpy/tools/_differential_gene_expression/_statsmodels.py,sha256=jBCtaCglOvvVjkIBGXuTCTDB6g2AJsZMCf7iOlDyn48,2195
  pertpy/tools/_distances/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  pertpy/tools/_distances/_distance_tests.py,sha256=mNmNu5cX0Wj5IegR6x5K-CbBSid8EhrH2jZPQxuvK4U,13521
- pertpy/tools/_distances/_distances.py,sha256=iuHpBtWZbJhMZNSEjQkZUu6KPJXCjs_fX6YjopIWvwY,50343
+ pertpy/tools/_distances/_distances.py,sha256=CmrOKevVCTY9j3PzhpVc3ga6SwZy9wbbJa0_7bwLMWQ,50569
  pertpy/tools/_perturbation_space/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  pertpy/tools/_perturbation_space/_clustering.py,sha256=m52-J8c8OnIXRCf3NoFabIO2yMHIuy1X0m0amtsK2vE,3556
  pertpy/tools/_perturbation_space/_comparison.py,sha256=rLO-EGU0I7t5MnLw4k1gYU-ypRu-JsDPLat1t4h2U2M,4329
  pertpy/tools/_perturbation_space/_discriminator_classifiers.py,sha256=OA2eZeG_4iuW1T5ilsRIkS0rU-azmwEch7IuB546KSY,21617
  pertpy/tools/_perturbation_space/_metrics.py,sha256=y8-baP8WRdB1iDgvP3uuQxSCDxA2lcxvEHHM2C_vWHY,3248
- pertpy/tools/_perturbation_space/_perturbation_space.py,sha256=cZPPzzK4_UZV7ktcD5BQVXEy6ITHrfkg1CLFov3TzsY,18497
- pertpy/tools/_perturbation_space/_simple.py,sha256=LH5EYvcAbzFMvgd9bH7AUPKFmdioPiy2xG8xGaXzmq0,12624
+ pertpy/tools/_perturbation_space/_perturbation_space.py,sha256=F-F-_pMCTWxjkVQSLre6hrE6PeRfCRscpt2ug3NlfuU,19531
+ pertpy/tools/_perturbation_space/_simple.py,sha256=RQv6B0xPq4RJa5zlkLkxMYXQ3LAJLglmQDGaTMseaA8,14238
  pertpy/tools/_scgen/__init__.py,sha256=uERFlFyF88TH0uLiwmsUGEfHfLVCiZMFuk8gO5f7164,45
  pertpy/tools/_scgen/_base_components.py,sha256=Qq8myRUm43q9XBrZ9gBggfa2cSV2wbz_KYoLgH7iF1A,3009
- pertpy/tools/_scgen/_scgen.py,sha256=HPvFVjY9SS9bGqgTkCDuPYjmA4QHW7rKgHnI2yuI_Q4,30608
+ pertpy/tools/_scgen/_scgen.py,sha256=oVY2JNYhDn1OrPoq22ATIP5-H615BafidBCC0eC5C-4,30756
  pertpy/tools/_scgen/_scgenvae.py,sha256=v_6tZ4wY-JjdMH1QVd_wG4_N0PoaqB-FM8zC2JsDu1o,3935
  pertpy/tools/_scgen/_utils.py,sha256=1upgOt1FpadfvNG05YpMjYYG-IAlxrC3l_ZxczmIczo,2841
- pertpy-0.9.3.dist-info/METADATA,sha256=S6HYSnvP3MYzaICvPCVeFkkOd6HSQU14kMzRTv2RUkI,6852
- pertpy-0.9.3.dist-info/WHEEL,sha256=1yFddiXMmvYK7QYTqtRNtX66WJ0Mz8PYEiEUoOUUxRY,87
- pertpy-0.9.3.dist-info/licenses/LICENSE,sha256=OZ-ZkXM5CmExJiEMM90b_7dGNNvRpj7kdE-49AnrLuI,1070
- pertpy-0.9.3.dist-info/RECORD,,
+ pertpy-0.9.5.dist-info/METADATA,sha256=vuf16H5cVKgNKRg35pFSWN7Oa7tT4IOLqqymVzZfnr4,6927
+ pertpy-0.9.5.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+ pertpy-0.9.5.dist-info/licenses/LICENSE,sha256=OZ-ZkXM5CmExJiEMM90b_7dGNNvRpj7kdE-49AnrLuI,1070
+ pertpy-0.9.5.dist-info/RECORD,,
@@ -1,4 +1,4 @@
  Wheel-Version: 1.0
- Generator: hatchling 1.25.0
+ Generator: hatchling 1.27.0
  Root-Is-Purelib: true
  Tag: py3-none-any
@@ -1,189 +0,0 @@
- """Helpers to interact with Formulaic Formulas
-
- Some helpful definitions for working with formulaic formulas (e.g. `~ 0 + C(donor):treatment + np.log1p(continuous)`):
- * A *term* refers to an expression in the formula, separated by `+`, e.g. `C(donor):treatment`, or `np.log1p(continuous)`.
- * A *variable* refers to a column of the data frame passed to formulaic, e.g. `donor`.
- * A *factor* is the specification of how a certain variable is represented in the design matrix, e.g. treatment coding with base level "A" and reduced rank.
- """
-
- from collections import defaultdict
- from collections.abc import Mapping, Sequence
- from dataclasses import dataclass
- from typing import Any
-
- from formulaic import FactorValues, ModelSpec
- from formulaic.materializers import PandasMaterializer
- from formulaic.materializers.types import EvaluatedFactor
- from formulaic.parser.types import Factor
- from interface_meta import override
-
-
- @dataclass
- class FactorMetadata:
- """Store (relevant) metadata for a factor of a formula."""
-
- name: str
- """The unambiguous factor name as specified in the formula. E.g. `donor`, or `C(donor, contr.treatment(base="A"))`"""
-
- reduced_rank: bool
- """Whether a column will be dropped because it is redundant"""
-
- custom_encoder: bool
- """Whether or not a custom encoder (e.g. `C(...)`) was used."""
-
- categories: Sequence[str]
- """The unique categories in this factor (after applying `drop_rows`)"""
-
- kind: Factor.Kind
- """Type of the factor"""
-
- drop_field: str = None
- """The category that is dropped.
-
- Note that
- * this may also be populated if `reduced_rank = False`
- * this is only populated when no encoder was used (e.g. `~ donor` but NOT `~ C(donor)`.
- """
-
- column_names: Sequence[str] = None
- """The column names for this factor included in the design matrix.
-
- This may be the same as `categories` if the default encoder is used, or
- categories without the base level if a custom encoder (e.g. `C(...)`) is used.
- """
-
- colname_format: str = None
- """A formattable string that can be used to generate the column name in the design matrix, e.g. `{name}[T.{field}]`"""
-
- @property
- def base(self) -> str | None:
- """
- The base category for this categorical factor.
-
- This is derived from `drop_field` (for default encoding) or by comparing the column names in
- the design matrix with all categories (for custom encoding, e.g. `C(...)`).
- """
- if not self.reduced_rank:
- return None
- else:
- if self.custom_encoder:
- tmp_base = set(self.categories) - set(self.column_names)
- assert len(tmp_base) == 1
- return tmp_base.pop()
- else:
- assert self.drop_field is not None
- return self.drop_field
-
-
- def get_factor_storage_and_materializer() -> tuple[dict[str, list[FactorMetadata]], dict[str, set[str]], type]:
- """Keeps track of categorical factors used in a model specification by generating a custom materializer.
-
- This materializer reports back metadata upon materialization of the model matrix.
-
- Returns:
- - A dictionary storing metadata for each factor processed by the custom materializer, named `factor_storage`.
- - A dictionary mapping variables to factor names, which works similarly to model_spec.variable_terms
- but maps to factors rather than terms, named `variable_to_factors`.
- - A materializer class tied to the specific instance of `factor_storage`.
- """
- # There can be multiple FactorMetadata entries per sample, for instance when formulaic generates an interaction
- # term, it generates the factor with both full rank and reduced rank.
- factor_storage: dict[str, list[FactorMetadata]] = defaultdict(list)
- variable_to_factors: dict[str, set[str]] = defaultdict(set)
-
- class CustomPandasMaterializer(PandasMaterializer):
- """An extension of the PandasMaterializer that records all categorical variables and their (base) categories."""
-
- REGISTER_NAME = "custom_pandas"
- REGISTER_INPUTS = ("pandas.core.frame.DataFrame",)
- REGISTER_OUTPUTS = ("pandas", "numpy", "sparse")
-
- def __init__(
- self,
- data: Any,
- context: Mapping[str, Any] | None = None,
- record_factor_metadata: bool = False,
- **params: Any,
- ):
- """Initialize the Materializer.
-
- Args:
- data: Passed to PandasMaterializer.
- context: Passed to PandasMaterializer
- record_factor_metadata: Flag that tells whether this particular instance of the custom materializer class
- is supposed to record factor metadata. Only the instance that is used for building the design
- matrix should record the metadata. All other instances (e.g. used to generate contrast vectors)
- should not record metadata to not overwrite the specifications from the design matrix.
- **params: Passed to PandasMaterializer
- """
- self.factor_metadata_storage = factor_storage if record_factor_metadata else None
- self.variable_to_factors = variable_to_factors if record_factor_metadata else None
- # temporary pointer to metadata of factor that is currently evaluated
- self._current_factor: FactorMetadata = None
- super().__init__(data, context, **params)
-
- @override
- def _encode_evaled_factor(
- self, factor: EvaluatedFactor, spec: ModelSpec, drop_rows: Sequence[int], reduced_rank: bool = False
- ) -> dict[str, Any]:
- """Function is called just before the factor is evaluated.
-
- We can record some metadata, before we call the original function.
- """
- assert (
- self._current_factor is None
- ), "_current_factor should always be None when we start recording metadata"
- if self.factor_metadata_storage is not None:
- # Don't store if the factor is cached (then we should already have recorded it)
- if factor.expr in self.encoded_cache or (factor.expr, reduced_rank) in self.encoded_cache:
- assert factor.expr in self.factor_metadata_storage, "Factor should be there since it's cached"
- else:
- for var in factor.variables:
- self.variable_to_factors[var].add(factor.expr)
- self._current_factor = FactorMetadata(
- name=factor.expr,
- reduced_rank=reduced_rank,
- categories=tuple(sorted(factor.values.drop(index=factor.values.index[drop_rows]).unique())),
- custom_encoder=factor.metadata.encoder is not None,
- kind=factor.metadata.kind,
- )
- return super()._encode_evaled_factor(factor, spec, drop_rows, reduced_rank)
-
- @override
- def _flatten_encoded_evaled_factor(self, name: str, values: FactorValues[dict]) -> dict[str, Any]:
- """
- Function is called at the end, before the design matrix gets materialized.
-
- Here we have access to additional metadata, such as `drop_field`.
- """
- if self._current_factor is not None:
- assert self._current_factor.name == name
- self._current_factor.drop_field = values.__formulaic_metadata__.drop_field
- self._current_factor.column_names = values.__formulaic_metadata__.column_names
- self._current_factor.colname_format = values.__formulaic_metadata__.format
- self.factor_metadata_storage[name].append(self._current_factor)
- self._current_factor = None
-
- return super()._flatten_encoded_evaled_factor(name, values)
-
- return factor_storage, variable_to_factors, CustomPandasMaterializer
-
-
- class AmbiguousAttributeError(ValueError):
- pass
-
-
- def resolve_ambiguous(objs: Sequence[Any], attr: str) -> Any:
- """Given a list of objects, return an attribute if it is the same between all object. Otherwise, raise an error."""
- if not objs:
- raise ValueError("Collection is empty")
-
- first_obj_attr = getattr(objs[0], attr)
-
- # Check if the attribute is the same for all objects
- for obj in objs[1:]:
- if getattr(obj, attr) != first_obj_attr:
- raise AmbiguousAttributeError(f"Ambiguous attribute '{attr}': values differ between objects")
-
- # If attribute is the same for all objects, return it
- return first_obj_attr
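The removed `_formulaic.py` helpers (their role presumably taken over by the new `formulaic-contrasts` dependency in the `de` extra above) were built around the terminology in their module docstring. A small illustrative sketch of those terms with plain formulaic; the formula is the one quoted in the removed docstring, while the toy data frame, and hence the resulting design-matrix columns, are made up:

# Illustrative only; requires formulaic, numpy and pandas.
import numpy as np  # referenced inside the formula via formulaic's evaluation context
import pandas as pd
from formulaic import model_matrix

df = pd.DataFrame(
    {
        "donor": ["A", "A", "B", "B"],  # a *variable*: a column of the data frame
        "treatment": ["ctrl", "stim", "ctrl", "stim"],
        "continuous": [0.1, 1.2, 2.3, 3.4],
    }
)

# `C(donor):treatment` and `np.log1p(continuous)` are *terms* (separated by `+`);
# how `donor` is encoded (e.g. treatment coding with base "A", reduced rank) is a *factor*.
mm = model_matrix("~ 0 + C(donor):treatment + np.log1p(continuous)", df)
print(list(mm.columns))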