pertpy 0.9.4__py3-none-any.whl → 0.9.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -18,12 +18,16 @@ from scvi.data.fields import CategoricalObsField, LayerField
18
18
  from scvi.model.base import BaseModelClass, JaxTrainingMixin
19
19
  from scvi.utils import setup_anndata_dsp
20
20
 
21
+ from pertpy._doc import _doc_params, doc_common_plot_args
22
+
21
23
  from ._scgenvae import JaxSCGENVAE
22
24
  from ._utils import balancer, extractor
23
25
 
24
26
  if TYPE_CHECKING:
25
27
  from collections.abc import Sequence
26
28
 
29
+ from matplotlib.pyplot import Figure
30
+
27
31
  font = {"family": "Arial", "size": 14}
28
32
 
29
33
 
@@ -377,9 +381,8 @@ class Scgen(JaxTrainingMixin, BaseModelClass):
377
381
  condition_key: str,
378
382
  axis_keys: dict[str, str],
379
383
  labels: dict[str, str],
380
- save: str | bool | None = None,
384
+ *,
381
385
  gene_list: list[str] = None,
382
- show: bool = False,
383
386
  top_100_genes: list[str] = None,
384
387
  verbose: bool = False,
385
388
  legend: bool = True,
@@ -387,6 +390,8 @@ class Scgen(JaxTrainingMixin, BaseModelClass):
387
390
  x_coeff: float = 0.30,
388
391
  y_coeff: float = 0.8,
389
392
  fontsize: float = 14,
393
+ show: bool = False,
394
+ save: str | bool | None = None,
390
395
  **kwargs,
391
396
  ) -> tuple[float, float] | float:
392
397
  """Plots mean matching for a set of specified genes.
@@ -397,21 +402,23 @@ class Scgen(JaxTrainingMixin, BaseModelClass):
397
402
  corresponding to batch and cell type metadata, respectively.
398
403
  condition_key: The key for the condition
399
404
  axis_keys: Dictionary of `adata.obs` keys that are used by the axes of the plot. Has to be in the following form:
400
- `{"x": "Key for x-axis", "y": "Key for y-axis"}`.
401
- labels: Dictionary of axes labels of the form `{"x": "x-axis-name", "y": "y-axis name"}`.
402
- path_to_save: path to save the plot.
403
- save: Specify if the plot should be saved or not.
405
+ {`x`: `Key for x-axis`, `y`: `Key for y-axis`}.
406
+ labels: Dictionary of axes labels of the form {`x`: `x-axis-name`, `y`: `y-axis name`}.
404
407
  gene_list: list of gene names to be plotted.
405
- show: if `True`: will show to the plot after saving it.
406
408
  top_100_genes: List of the top 100 differentially expressed genes. Specify if you want the top 100 DEGs to be assessed extra.
407
- verbose: Specify if you want information to be printed while creating the plot.,
409
+ verbose: Specify if you want information to be printed while creating the plot.
408
410
  legend: Whether to plot a legend.
409
411
  title: Set if you want the plot to display a title.
410
412
  x_coeff: Offset to print the R^2 value in x-direction.
411
413
  y_coeff: Offset to print the R^2 value in y-direction.
412
414
  fontsize: Fontsize used for text in the plot.
415
+ show: if `True`, will show to the plot after saving it.
416
+ save: Specify if the plot should be saved or not.
413
417
  **kwargs:
414
418
 
419
+ Returns:
420
+ Returns R^2 value for all genes and R^2 value for top 100 DEGs if `top_100_genes` is not `None`.
421
+
415
422
  Examples:
416
423
  >>> import pertpy as pt
417
424
  >>> data = pt.dt.kang_2018()
@@ -498,6 +505,7 @@ class Scgen(JaxTrainingMixin, BaseModelClass):
498
505
  r"$\mathrm{R^2_{\mathrm{\mathsf{top\ 100\ DEGs}}}}$= " + f"{r_value_diff ** 2:.2f}",
499
506
  fontsize=kwargs.get("textsize", fontsize),
500
507
  )
508
+
501
509
  if save:
502
510
  plt.savefig(save, bbox_inches="tight")
503
511
  if show:
@@ -514,16 +522,17 @@ class Scgen(JaxTrainingMixin, BaseModelClass):
514
522
  condition_key: str,
515
523
  axis_keys: dict[str, str],
516
524
  labels: dict[str, str],
517
- save: str | bool | None = None,
525
+ *,
518
526
  gene_list: list[str] = None,
519
527
  top_100_genes: list[str] = None,
520
- show: bool = False,
521
528
  legend: bool = True,
522
529
  title: str = None,
523
530
  verbose: bool = False,
524
531
  x_coeff: float = 0.3,
525
532
  y_coeff: float = 0.8,
526
533
  fontsize: float = 14,
534
+ show: bool = True,
535
+ save: str | bool | None = None,
527
536
  **kwargs,
528
537
  ) -> tuple[float, float] | float:
529
538
  """Plots variance matching for a set of specified genes.
@@ -534,19 +543,18 @@ class Scgen(JaxTrainingMixin, BaseModelClass):
534
543
  corresponding to batch and cell type metadata, respectively.
535
544
  condition_key: Key of the condition.
536
545
  axis_keys: Dictionary of `adata.obs` keys that are used by the axes of the plot. Has to be in the following form:
537
- `{"x": "Key for x-axis", "y": "Key for y-axis"}`.
538
- labels: Dictionary of axes labels of the form `{"x": "x-axis-name", "y": "y-axis name"}`.
539
- path_to_save: path to save the plot.
540
- save: Specify if the plot should be saved or not.
546
+ {"x": "Key for x-axis", "y": "Key for y-axis"}.
547
+ labels: Dictionary of axes labels of the form {"x": "x-axis-name", "y": "y-axis name"}.
541
548
  gene_list: list of gene names to be plotted.
542
- show: if `True`: will show to the plot after saving it.
543
549
  top_100_genes: List of the top 100 differentially expressed genes. Specify if you want the top 100 DEGs to be assessed extra.
544
- legend: Whether to plot a elgend
550
+ legend: Whether to plot a legend.
545
551
  title: Set if you want the plot to display a title.
546
552
  verbose: Specify if you want information to be printed while creating the plot.
547
553
  x_coeff: Offset to print the R^2 value in x-direction.
548
554
  y_coeff: Offset to print the R^2 value in y-direction.
549
555
  fontsize: Fontsize used for text in the plot.
556
+ show: if `True`, will show to the plot after saving it.
557
+ save: Specify if the plot should be saved or not.
550
558
  """
551
559
  import seaborn as sns
552
560
 
@@ -636,6 +644,7 @@ class Scgen(JaxTrainingMixin, BaseModelClass):
636
644
  else:
637
645
  return r_value**2
638
646
 
647
+ @_doc_params(common_plot_args=doc_common_plot_args)
639
648
  def plot_binary_classifier(
640
649
  self,
641
650
  scgen: Scgen,
@@ -643,10 +652,11 @@ class Scgen(JaxTrainingMixin, BaseModelClass):
643
652
  delta: np.ndarray,
644
653
  ctrl_key: str,
645
654
  stim_key: str,
646
- show: bool = False,
647
- save: str | bool | None = None,
655
+ *,
648
656
  fontsize: float = 14,
649
- ) -> plt.Axes | None:
657
+ show: bool = True,
658
+ return_fig: bool = False,
659
+ ) -> Figure | None:
650
660
  """Plots the dot product between delta and latent representation of a linear classifier.
651
661
 
652
662
  Builds a linear classifier based on the dot product between
@@ -661,9 +671,11 @@ class Scgen(JaxTrainingMixin, BaseModelClass):
661
671
  delta: Difference between stimulated and control cells in latent space
662
672
  ctrl_key: Key for `control` part of the `data` found in `condition_key`.
663
673
  stim_key: Key for `stimulated` part of the `data` found in `condition_key`.
664
- path_to_save: Path to save the plot.
665
- save: Specify if the plot should be saved or not.
666
674
  fontsize: Set the font size of the plot.
675
+ {common_plot_args}
676
+
677
+ Returns:
678
+ If `return_fig` is `True`, returns the figure, otherwise `None`.
667
679
  """
668
680
  plt.close("all")
669
681
  adata = scgen._validate_anndata(adata)
@@ -693,12 +705,10 @@ class Scgen(JaxTrainingMixin, BaseModelClass):
693
705
  ax = plt.gca()
694
706
  ax.grid(False)
695
707
 
696
- if save:
697
- plt.savefig(save, bbox_inches="tight")
698
708
  if show:
699
709
  plt.show()
700
- if not (show or save):
701
- return ax
710
+ if return_fig:
711
+ return plt.gcf()
702
712
  return None
703
713
 
704
714
 
@@ -1,6 +1,6 @@
1
- Metadata-Version: 2.3
1
+ Metadata-Version: 2.4
2
2
  Name: pertpy
3
- Version: 0.9.4
3
+ Version: 0.9.5
4
4
  Summary: Perturbation Analysis in the scverse ecosystem.
5
5
  Project-URL: Documentation, https://pertpy.readthedocs.io
6
6
  Project-URL: Source, https://github.com/scverse/pertpy
@@ -44,9 +44,8 @@ Classifier: Programming Language :: Python :: 3.11
44
44
  Classifier: Programming Language :: Python :: 3.12
45
45
  Classifier: Topic :: Scientific/Engineering :: Bio-Informatics
46
46
  Classifier: Topic :: Scientific/Engineering :: Visualization
47
- Requires-Python: >=3.10
47
+ Requires-Python: <3.13,>=3.10
48
48
  Requires-Dist: adjusttext
49
- Requires-Dist: anndata<0.10.9
50
49
  Requires-Dist: blitzgsea
51
50
  Requires-Dist: decoupler
52
51
  Requires-Dist: lamin-utils
@@ -69,7 +68,8 @@ Requires-Dist: pyqt5; extra == 'coda'
69
68
  Requires-Dist: toytree; extra == 'coda'
70
69
  Provides-Extra: de
71
70
  Requires-Dist: formulaic; extra == 'de'
72
- Requires-Dist: pydeseq2; extra == 'de'
71
+ Requires-Dist: formulaic-contrasts>=0.2.0; extra == 'de'
72
+ Requires-Dist: pydeseq2>=v0.5.0pre1; extra == 'de'
73
73
  Provides-Extra: dev
74
74
  Requires-Dist: pre-commit; extra == 'dev'
75
75
  Provides-Extra: doc
@@ -1,57 +1,57 @@
1
- pertpy/__init__.py,sha256=k0tPuH0DdvQraT7I-zYrI1TwJHK3GnBx-Nvi-cMobvM,658
1
+ pertpy/__init__.py,sha256=r5QhDw2-Ls4yYLs1kJJVe_r6dstQ7SjoASFutlTU9JA,658
2
+ pertpy/_doc.py,sha256=pVt5Iegvh4rC1N81fd9e4cwmoGPNSgttZxxPWbLK6Bs,453
2
3
  pertpy/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
3
4
  pertpy/data/__init__.py,sha256=ah3yvoxkgbdMUNAWxS3SyqcUuVamBOSeuWkF2QRAEwM,2703
4
5
  pertpy/data/_dataloader.py,sha256=ENbk1T3w3N6tVI11V4FVUxuWFEwOHP8_kIB-ehiMlVQ,2428
5
- pertpy/data/_datasets.py,sha256=OwI0HSSXnUPnUw_lAG9w5jNMILjLnPZS2Wj_LfrXSoI,65616
6
+ pertpy/data/_datasets.py,sha256=c_U2U2NvncPZc6vs6w_s77zmWqr8ywDpzmkx6edCfUE,65616
6
7
  pertpy/metadata/__init__.py,sha256=zoE_VXNyuKa4nlXlUk2nTgsHRW3jSQSpDEulcCnzOT0,222
7
- pertpy/metadata/_cell_line.py,sha256=-8KSqmP5XjmLEmNX3TavxSM_MtIHwLWS_x3MVkk6JEw,38595
8
- pertpy/metadata/_compound.py,sha256=JEFwP_TOTyMzfd2qFMb2VkJJvPhCVIvu6gs9Bq_stgs,4756
8
+ pertpy/metadata/_cell_line.py,sha256=Ell5PDVoMlrhHXPDKGCiPGwNY0DAeghbUUvTYL-SFF0,38919
9
+ pertpy/metadata/_compound.py,sha256=ywNNqtib0exHv0z8ctmTRf1Hk64tSGWSiUEffycxf6A,4755
9
10
  pertpy/metadata/_drug.py,sha256=8QDSyxiFl25JdS80EQJC_krg6fEe5LIQEE6BsV1r8nY,9006
10
11
  pertpy/metadata/_look_up.py,sha256=DoWp6OxIk_HyyyOhW1p8z5E68IZ31_nZDnqxk1rJqps,28778
11
- pertpy/metadata/_metadata.py,sha256=pvarnv3X5pblnvG8AQ8Omu5jQcC5ORzCxRk3FRhOLgs,3276
12
+ pertpy/metadata/_metadata.py,sha256=hV2LTFrExddLNU_RsDkZju6lQUSRoP4OIn_dumCyQao,3277
12
13
  pertpy/metadata/_moa.py,sha256=u_OcMonjOeeoW5P9xOltquVSoTH3Vs80ztHsXf-X1DY,4701
13
14
  pertpy/plot/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
14
15
  pertpy/preprocessing/__init__.py,sha256=VAPFaeq2_qCvdFkQTCj_Hm460HC4Tersu8Rig_tnp_Y,71
15
- pertpy/preprocessing/_guide_rna.py,sha256=IgKhXEyfRwEA7ccKJNLA_aIxKHm09QJINM09KaIwn68,7644
16
- pertpy/tools/__init__.py,sha256=guh4QL8ac_CKP8S2nFHmxWJ6epCFB_Jfga4eK9HYGnE,2598
17
- pertpy/tools/_augur.py,sha256=UWro1nIEZe_rWtjlQCBv4ucqeh3Vt1m8IRzKlux72Z8,55683
18
- pertpy/tools/_cinemaot.py,sha256=vMm9oTNW6pb8HBe993-BvkVKjSHbfbqlZY1SSCvj12Y,39521
19
- pertpy/tools/_dialogue.py,sha256=f2fbhKWdm4Co79ZzVgtVq9xYwjYWFLdGNDeGFOO_pfM,51990
20
- pertpy/tools/_enrichment.py,sha256=rjPHK9YBCJZfpa5Rvfxo3Ii7W5Mvm5dOdolAD7QazVg,21440
16
+ pertpy/preprocessing/_guide_rna.py,sha256=D9hEh8LOOTs_UfVBsBW3b-o6ipRFq09J471Hg1s0tlM,7963
17
+ pertpy/tools/__init__.py,sha256=NUTwCGxRdzUzLTgsS3r7MywENwPAdcGZDKrl83sU8mo,2599
18
+ pertpy/tools/_augur.py,sha256=Vghsx5-fYlaEeu__8-HUg6v5_KoVNhiRDjhgE3pORpY,55339
19
+ pertpy/tools/_cinemaot.py,sha256=U6vCb_mI4ZPFshYgsx-hOOsDA1IPwI7ZR_-IH4F9s7s,39621
20
+ pertpy/tools/_dialogue.py,sha256=BShXZ1ehO2eMbP5PV-ONJ-1SsxD6h9nAN7bGQ4_F6Rw,51906
21
+ pertpy/tools/_enrichment.py,sha256=jxVdOrpS_lAu7GCpemgdB4JJvsGH9SJTQsAKLBKi9Tc,21640
21
22
  pertpy/tools/_kernel_pca.py,sha256=_EJ9WlBLjHOafF34sZGdyBgZL6Fj0WiJ1elVT1XMmo4,1579
22
- pertpy/tools/_milo.py,sha256=FDFGmGMkJiVrvATEnOeAMCe-Q2w7F0nbBMuACVbyIQI,43699
23
- pertpy/tools/_mixscape.py,sha256=FtH3PKvbLTe03LPgN4O9sS70oj_6AHz4Mz5otzEwRl8,52406
23
+ pertpy/tools/_milo.py,sha256=SQqknT2zkzI0pcUmTm0ijWMs7CFMRiyRnXt9rC0jvmg,43811
24
+ pertpy/tools/_mixscape.py,sha256=T-oUHDnepao5aujAHw9bAbbQHPSK6oD_8Wr_mw4U0nc,52089
24
25
  pertpy/tools/decoupler_LICENSE,sha256=OXLcl0T2SZ8Pmy2_dmlvKuetivmyPd5m1q-Gyd-zaYY,35149
25
26
  pertpy/tools/transferlearning_MMD_LICENSE,sha256=MUvDA-o_j9htRpI8fStVdCRuyLdPkQUuIH0a_EIc57w,1069
26
27
  pertpy/tools/_coda/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
27
- pertpy/tools/_coda/_base_coda.py,sha256=jnoLPFfluxB0_CK8-T-qolPa7xPIEb6NpeEpGwHPiNg,113058
28
+ pertpy/tools/_coda/_base_coda.py,sha256=Uxb_vAORL77LkUri8QqL9SxaULK4hfpgRGsCQdtsAgk,111808
28
29
  pertpy/tools/_coda/_sccoda.py,sha256=gGmyd0MGpchulV9d4PxKSmGORyZ8fCDS9tQVPOuF_Og,22622
29
30
  pertpy/tools/_coda/_tasccoda.py,sha256=vNk43OQHn7pBLsez2rmSj0bMZKOf8jZTI7G8TfBByRg,30665
30
- pertpy/tools/_differential_gene_expression/__init__.py,sha256=sabAXym8mMLwp19ZjyBN7wp-oJh32iVj9plvJ-AbXlE,521
31
- pertpy/tools/_differential_gene_expression/_base.py,sha256=qnQkK_hyIcViHBSkgJcAazC26JQ72bEyafKiytZikCY,23624
31
+ pertpy/tools/_differential_gene_expression/__init__.py,sha256=SEydWg0iT3Y1pApjnCAOuHxFeI6xVUfgyBHv2s3LADU,487
32
+ pertpy/tools/_differential_gene_expression/_base.py,sha256=yc9DBj2KgJVk4mkjz7EDFoBj8WBZW92Z4ayD-Xdla1g,38514
32
33
  pertpy/tools/_differential_gene_expression/_checks.py,sha256=SxNHJDsCYZ6rWLTMEymEBpigs_B9cnXyw0kkAe1l6e0,1675
33
34
  pertpy/tools/_differential_gene_expression/_dge_comparison.py,sha256=9HjmWkrqZhj_ZJeR-ymyEDzpRJNx7JiYJoStvCfKuCU,4188
34
- pertpy/tools/_differential_gene_expression/_edger.py,sha256=JziiW5rkXuQBJISAD_LvB2HOZUgJ1_qoqiR5Q4hEoP0,4321
35
- pertpy/tools/_differential_gene_expression/_formulaic.py,sha256=X4rPv4j8SDu5VJnf6_AIYJCCquUQka7G2LGtDLa8FhE,8715
36
- pertpy/tools/_differential_gene_expression/_pydeseq2.py,sha256=JK7H7u4va0q_TLE_sqi4JEzoPBd_xNRycYGu1507HS4,4117
35
+ pertpy/tools/_differential_gene_expression/_edger.py,sha256=ttgTocAYnr8BTDcixwHGjRZew6zeja-U77TLKkSdd1Y,4857
36
+ pertpy/tools/_differential_gene_expression/_pydeseq2.py,sha256=aOqsdu8hKp8_h2HhjkxS0B_itxRBnzEU2oSnU2PYiQ4,2942
37
37
  pertpy/tools/_differential_gene_expression/_simple_tests.py,sha256=tTSr0Z2Qbpxdy9bcO8Gi_up6R616IcoK_e4_rlanyx4,6621
38
- pertpy/tools/_differential_gene_expression/_statsmodels.py,sha256=zSOwJYDJyrl3hsEhMI5Q9Pyw2XLuEuj7T0zSAVcP6tQ,2585
38
+ pertpy/tools/_differential_gene_expression/_statsmodels.py,sha256=jBCtaCglOvvVjkIBGXuTCTDB6g2AJsZMCf7iOlDyn48,2195
39
39
  pertpy/tools/_distances/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
40
40
  pertpy/tools/_distances/_distance_tests.py,sha256=mNmNu5cX0Wj5IegR6x5K-CbBSid8EhrH2jZPQxuvK4U,13521
41
- pertpy/tools/_distances/_distances.py,sha256=iuHpBtWZbJhMZNSEjQkZUu6KPJXCjs_fX6YjopIWvwY,50343
41
+ pertpy/tools/_distances/_distances.py,sha256=CmrOKevVCTY9j3PzhpVc3ga6SwZy9wbbJa0_7bwLMWQ,50569
42
42
  pertpy/tools/_perturbation_space/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
43
43
  pertpy/tools/_perturbation_space/_clustering.py,sha256=m52-J8c8OnIXRCf3NoFabIO2yMHIuy1X0m0amtsK2vE,3556
44
44
  pertpy/tools/_perturbation_space/_comparison.py,sha256=rLO-EGU0I7t5MnLw4k1gYU-ypRu-JsDPLat1t4h2U2M,4329
45
45
  pertpy/tools/_perturbation_space/_discriminator_classifiers.py,sha256=OA2eZeG_4iuW1T5ilsRIkS0rU-azmwEch7IuB546KSY,21617
46
46
  pertpy/tools/_perturbation_space/_metrics.py,sha256=y8-baP8WRdB1iDgvP3uuQxSCDxA2lcxvEHHM2C_vWHY,3248
47
- pertpy/tools/_perturbation_space/_perturbation_space.py,sha256=cZPPzzK4_UZV7ktcD5BQVXEy6ITHrfkg1CLFov3TzsY,18497
48
- pertpy/tools/_perturbation_space/_simple.py,sha256=LH5EYvcAbzFMvgd9bH7AUPKFmdioPiy2xG8xGaXzmq0,12624
47
+ pertpy/tools/_perturbation_space/_perturbation_space.py,sha256=F-F-_pMCTWxjkVQSLre6hrE6PeRfCRscpt2ug3NlfuU,19531
48
+ pertpy/tools/_perturbation_space/_simple.py,sha256=RQv6B0xPq4RJa5zlkLkxMYXQ3LAJLglmQDGaTMseaA8,14238
49
49
  pertpy/tools/_scgen/__init__.py,sha256=uERFlFyF88TH0uLiwmsUGEfHfLVCiZMFuk8gO5f7164,45
50
50
  pertpy/tools/_scgen/_base_components.py,sha256=Qq8myRUm43q9XBrZ9gBggfa2cSV2wbz_KYoLgH7iF1A,3009
51
- pertpy/tools/_scgen/_scgen.py,sha256=HPvFVjY9SS9bGqgTkCDuPYjmA4QHW7rKgHnI2yuI_Q4,30608
51
+ pertpy/tools/_scgen/_scgen.py,sha256=oVY2JNYhDn1OrPoq22ATIP5-H615BafidBCC0eC5C-4,30756
52
52
  pertpy/tools/_scgen/_scgenvae.py,sha256=v_6tZ4wY-JjdMH1QVd_wG4_N0PoaqB-FM8zC2JsDu1o,3935
53
53
  pertpy/tools/_scgen/_utils.py,sha256=1upgOt1FpadfvNG05YpMjYYG-IAlxrC3l_ZxczmIczo,2841
54
- pertpy-0.9.4.dist-info/METADATA,sha256=0gKL9NKX-_hyYAGZvXqTNZySfUSG-VuJdOL_zNCBDrs,6882
55
- pertpy-0.9.4.dist-info/WHEEL,sha256=1yFddiXMmvYK7QYTqtRNtX66WJ0Mz8PYEiEUoOUUxRY,87
56
- pertpy-0.9.4.dist-info/licenses/LICENSE,sha256=OZ-ZkXM5CmExJiEMM90b_7dGNNvRpj7kdE-49AnrLuI,1070
57
- pertpy-0.9.4.dist-info/RECORD,,
54
+ pertpy-0.9.5.dist-info/METADATA,sha256=vuf16H5cVKgNKRg35pFSWN7Oa7tT4IOLqqymVzZfnr4,6927
55
+ pertpy-0.9.5.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
56
+ pertpy-0.9.5.dist-info/licenses/LICENSE,sha256=OZ-ZkXM5CmExJiEMM90b_7dGNNvRpj7kdE-49AnrLuI,1070
57
+ pertpy-0.9.5.dist-info/RECORD,,
@@ -1,4 +1,4 @@
1
1
  Wheel-Version: 1.0
2
- Generator: hatchling 1.25.0
2
+ Generator: hatchling 1.27.0
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
@@ -1,189 +0,0 @@
1
- """Helpers to interact with Formulaic Formulas
2
-
3
- Some helpful definitions for working with formulaic formulas (e.g. `~ 0 + C(donor):treatment + np.log1p(continuous)`):
4
- * A *term* refers to an expression in the formula, separated by `+`, e.g. `C(donor):treatment`, or `np.log1p(continuous)`.
5
- * A *variable* refers to a column of the data frame passed to formulaic, e.g. `donor`.
6
- * A *factor* is the specification of how a certain variable is represented in the design matrix, e.g. treatment coding with base level "A" and reduced rank.
7
- """
8
-
9
- from collections import defaultdict
10
- from collections.abc import Mapping, Sequence
11
- from dataclasses import dataclass
12
- from typing import Any
13
-
14
- from formulaic import FactorValues, ModelSpec
15
- from formulaic.materializers import PandasMaterializer
16
- from formulaic.materializers.types import EvaluatedFactor
17
- from formulaic.parser.types import Factor
18
- from interface_meta import override
19
-
20
-
21
- @dataclass
22
- class FactorMetadata:
23
- """Store (relevant) metadata for a factor of a formula."""
24
-
25
- name: str
26
- """The unambiguous factor name as specified in the formula. E.g. `donor`, or `C(donor, contr.treatment(base="A"))`"""
27
-
28
- reduced_rank: bool
29
- """Whether a column will be dropped because it is redundant"""
30
-
31
- custom_encoder: bool
32
- """Whether or not a custom encoder (e.g. `C(...)`) was used."""
33
-
34
- categories: Sequence[str]
35
- """The unique categories in this factor (after applying `drop_rows`)"""
36
-
37
- kind: Factor.Kind
38
- """Type of the factor"""
39
-
40
- drop_field: str = None
41
- """The category that is dropped.
42
-
43
- Note that
44
- * this may also be populated if `reduced_rank = False`
45
- * this is only populated when no encoder was used (e.g. `~ donor` but NOT `~ C(donor)`.
46
- """
47
-
48
- column_names: Sequence[str] = None
49
- """The column names for this factor included in the design matrix.
50
-
51
- This may be the same as `categories` if the default encoder is used, or
52
- categories without the base level if a custom encoder (e.g. `C(...)`) is used.
53
- """
54
-
55
- colname_format: str = None
56
- """A formattable string that can be used to generate the column name in the design matrix, e.g. `{name}[T.{field}]`"""
57
-
58
- @property
59
- def base(self) -> str | None:
60
- """
61
- The base category for this categorical factor.
62
-
63
- This is derived from `drop_field` (for default encoding) or by comparing the column names in
64
- the design matrix with all categories (for custom encoding, e.g. `C(...)`).
65
- """
66
- if not self.reduced_rank:
67
- return None
68
- else:
69
- if self.custom_encoder:
70
- tmp_base = set(self.categories) - set(self.column_names)
71
- assert len(tmp_base) == 1
72
- return tmp_base.pop()
73
- else:
74
- assert self.drop_field is not None
75
- return self.drop_field
76
-
77
-
78
- def get_factor_storage_and_materializer() -> tuple[dict[str, list[FactorMetadata]], dict[str, set[str]], type]:
79
- """Keeps track of categorical factors used in a model specification by generating a custom materializer.
80
-
81
- This materializer reports back metadata upon materialization of the model matrix.
82
-
83
- Returns:
84
- - A dictionary storing metadata for each factor processed by the custom materializer, named `factor_storage`.
85
- - A dictionary mapping variables to factor names, which works similarly to model_spec.variable_terms
86
- but maps to factors rather than terms, named `variable_to_factors`.
87
- - A materializer class tied to the specific instance of `factor_storage`.
88
- """
89
- # There can be multiple FactorMetadata entries per sample, for instance when formulaic generates an interaction
90
- # term, it generates the factor with both full rank and reduced rank.
91
- factor_storage: dict[str, list[FactorMetadata]] = defaultdict(list)
92
- variable_to_factors: dict[str, set[str]] = defaultdict(set)
93
-
94
- class CustomPandasMaterializer(PandasMaterializer):
95
- """An extension of the PandasMaterializer that records all categorical variables and their (base) categories."""
96
-
97
- REGISTER_NAME = "custom_pandas"
98
- REGISTER_INPUTS = ("pandas.core.frame.DataFrame",)
99
- REGISTER_OUTPUTS = ("pandas", "numpy", "sparse")
100
-
101
- def __init__(
102
- self,
103
- data: Any,
104
- context: Mapping[str, Any] | None = None,
105
- record_factor_metadata: bool = False,
106
- **params: Any,
107
- ):
108
- """Initialize the Materializer.
109
-
110
- Args:
111
- data: Passed to PandasMaterializer.
112
- context: Passed to PandasMaterializer
113
- record_factor_metadata: Flag that tells whether this particular instance of the custom materializer class
114
- is supposed to record factor metadata. Only the instance that is used for building the design
115
- matrix should record the metadata. All other instances (e.g. used to generate contrast vectors)
116
- should not record metadata to not overwrite the specifications from the design matrix.
117
- **params: Passed to PandasMaterializer
118
- """
119
- self.factor_metadata_storage = factor_storage if record_factor_metadata else None
120
- self.variable_to_factors = variable_to_factors if record_factor_metadata else None
121
- # temporary pointer to metadata of factor that is currently evaluated
122
- self._current_factor: FactorMetadata = None
123
- super().__init__(data, context, **params)
124
-
125
- @override
126
- def _encode_evaled_factor(
127
- self, factor: EvaluatedFactor, spec: ModelSpec, drop_rows: Sequence[int], reduced_rank: bool = False
128
- ) -> dict[str, Any]:
129
- """Function is called just before the factor is evaluated.
130
-
131
- We can record some metadata, before we call the original function.
132
- """
133
- assert (
134
- self._current_factor is None
135
- ), "_current_factor should always be None when we start recording metadata"
136
- if self.factor_metadata_storage is not None:
137
- # Don't store if the factor is cached (then we should already have recorded it)
138
- if factor.expr in self.encoded_cache or (factor.expr, reduced_rank) in self.encoded_cache:
139
- assert factor.expr in self.factor_metadata_storage, "Factor should be there since it's cached"
140
- else:
141
- for var in factor.variables:
142
- self.variable_to_factors[var].add(factor.expr)
143
- self._current_factor = FactorMetadata(
144
- name=factor.expr,
145
- reduced_rank=reduced_rank,
146
- categories=tuple(sorted(factor.values.drop(index=factor.values.index[drop_rows]).unique())),
147
- custom_encoder=factor.metadata.encoder is not None,
148
- kind=factor.metadata.kind,
149
- )
150
- return super()._encode_evaled_factor(factor, spec, drop_rows, reduced_rank)
151
-
152
- @override
153
- def _flatten_encoded_evaled_factor(self, name: str, values: FactorValues[dict]) -> dict[str, Any]:
154
- """
155
- Function is called at the end, before the design matrix gets materialized.
156
-
157
- Here we have access to additional metadata, such as `drop_field`.
158
- """
159
- if self._current_factor is not None:
160
- assert self._current_factor.name == name
161
- self._current_factor.drop_field = values.__formulaic_metadata__.drop_field
162
- self._current_factor.column_names = values.__formulaic_metadata__.column_names
163
- self._current_factor.colname_format = values.__formulaic_metadata__.format
164
- self.factor_metadata_storage[name].append(self._current_factor)
165
- self._current_factor = None
166
-
167
- return super()._flatten_encoded_evaled_factor(name, values)
168
-
169
- return factor_storage, variable_to_factors, CustomPandasMaterializer
170
-
171
-
172
- class AmbiguousAttributeError(ValueError):
173
- pass
174
-
175
-
176
- def resolve_ambiguous(objs: Sequence[Any], attr: str) -> Any:
177
- """Given a list of objects, return an attribute if it is the same between all object. Otherwise, raise an error."""
178
- if not objs:
179
- raise ValueError("Collection is empty")
180
-
181
- first_obj_attr = getattr(objs[0], attr)
182
-
183
- # Check if the attribute is the same for all objects
184
- for obj in objs[1:]:
185
- if getattr(obj, attr) != first_obj_attr:
186
- raise AmbiguousAttributeError(f"Ambiguous attribute '{attr}': values differ between objects")
187
-
188
- # If attribute is the same for all objects, return it
189
- return first_obj_attr