pertpy 0.9.3__py3-none-any.whl → 0.9.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pertpy/__init__.py +1 -1
- pertpy/_doc.py +20 -0
- pertpy/data/_dataloader.py +4 -4
- pertpy/data/_datasets.py +3 -3
- pertpy/metadata/_cell_line.py +19 -7
- pertpy/metadata/_compound.py +3 -4
- pertpy/metadata/_metadata.py +1 -1
- pertpy/preprocessing/_guide_rna.py +19 -6
- pertpy/tools/__init__.py +12 -15
- pertpy/tools/_augur.py +36 -46
- pertpy/tools/_cinemaot.py +24 -18
- pertpy/tools/_coda/_base_coda.py +87 -106
- pertpy/tools/_dialogue.py +17 -21
- pertpy/tools/_differential_gene_expression/__init__.py +1 -2
- pertpy/tools/_differential_gene_expression/_base.py +495 -113
- pertpy/tools/_differential_gene_expression/_edger.py +30 -21
- pertpy/tools/_differential_gene_expression/_pydeseq2.py +15 -29
- pertpy/tools/_differential_gene_expression/_statsmodels.py +0 -11
- pertpy/tools/_distances/_distances.py +15 -8
- pertpy/tools/_enrichment.py +18 -8
- pertpy/tools/_milo.py +58 -46
- pertpy/tools/_mixscape.py +111 -100
- pertpy/tools/_perturbation_space/_perturbation_space.py +40 -31
- pertpy/tools/_perturbation_space/_simple.py +50 -0
- pertpy/tools/_scgen/_scgen.py +35 -25
- {pertpy-0.9.3.dist-info → pertpy-0.9.5.dist-info}/METADATA +5 -4
- {pertpy-0.9.3.dist-info → pertpy-0.9.5.dist-info}/RECORD +29 -29
- {pertpy-0.9.3.dist-info → pertpy-0.9.5.dist-info}/WHEEL +1 -1
- pertpy/tools/_differential_gene_expression/_formulaic.py +0 -189
- {pertpy-0.9.3.dist-info → pertpy-0.9.5.dist-info}/licenses/LICENSE +0 -0
pertpy/tools/_scgen/_scgen.py
CHANGED
@@ -18,12 +18,16 @@ from scvi.data.fields import CategoricalObsField, LayerField
|
|
18
18
|
from scvi.model.base import BaseModelClass, JaxTrainingMixin
|
19
19
|
from scvi.utils import setup_anndata_dsp
|
20
20
|
|
21
|
+
from pertpy._doc import _doc_params, doc_common_plot_args
|
22
|
+
|
21
23
|
from ._scgenvae import JaxSCGENVAE
|
22
24
|
from ._utils import balancer, extractor
|
23
25
|
|
24
26
|
if TYPE_CHECKING:
|
25
27
|
from collections.abc import Sequence
|
26
28
|
|
29
|
+
from matplotlib.pyplot import Figure
|
30
|
+
|
27
31
|
font = {"family": "Arial", "size": 14}
|
28
32
|
|
29
33
|
|
@@ -377,9 +381,8 @@ class Scgen(JaxTrainingMixin, BaseModelClass):
|
|
377
381
|
condition_key: str,
|
378
382
|
axis_keys: dict[str, str],
|
379
383
|
labels: dict[str, str],
|
380
|
-
|
384
|
+
*,
|
381
385
|
gene_list: list[str] = None,
|
382
|
-
show: bool = False,
|
383
386
|
top_100_genes: list[str] = None,
|
384
387
|
verbose: bool = False,
|
385
388
|
legend: bool = True,
|
@@ -387,6 +390,8 @@ class Scgen(JaxTrainingMixin, BaseModelClass):
|
|
387
390
|
x_coeff: float = 0.30,
|
388
391
|
y_coeff: float = 0.8,
|
389
392
|
fontsize: float = 14,
|
393
|
+
show: bool = False,
|
394
|
+
save: str | bool | None = None,
|
390
395
|
**kwargs,
|
391
396
|
) -> tuple[float, float] | float:
|
392
397
|
"""Plots mean matching for a set of specified genes.
|
@@ -397,21 +402,23 @@ class Scgen(JaxTrainingMixin, BaseModelClass):
|
|
397
402
|
corresponding to batch and cell type metadata, respectively.
|
398
403
|
condition_key: The key for the condition
|
399
404
|
axis_keys: Dictionary of `adata.obs` keys that are used by the axes of the plot. Has to be in the following form:
|
400
|
-
`
|
401
|
-
labels: Dictionary of axes labels of the form `
|
402
|
-
path_to_save: path to save the plot.
|
403
|
-
save: Specify if the plot should be saved or not.
|
405
|
+
{`x`: `Key for x-axis`, `y`: `Key for y-axis`}.
|
406
|
+
labels: Dictionary of axes labels of the form {`x`: `x-axis-name`, `y`: `y-axis name`}.
|
404
407
|
gene_list: list of gene names to be plotted.
|
405
|
-
show: if `True`: will show to the plot after saving it.
|
406
408
|
top_100_genes: List of the top 100 differentially expressed genes. Specify if you want the top 100 DEGs to be assessed extra.
|
407
|
-
verbose: Specify if you want information to be printed while creating the plot
|
409
|
+
verbose: Specify if you want information to be printed while creating the plot.
|
408
410
|
legend: Whether to plot a legend.
|
409
411
|
title: Set if you want the plot to display a title.
|
410
412
|
x_coeff: Offset to print the R^2 value in x-direction.
|
411
413
|
y_coeff: Offset to print the R^2 value in y-direction.
|
412
414
|
fontsize: Fontsize used for text in the plot.
|
415
|
+
show: if `True`, will show to the plot after saving it.
|
416
|
+
save: Specify if the plot should be saved or not.
|
413
417
|
**kwargs:
|
414
418
|
|
419
|
+
Returns:
|
420
|
+
Returns R^2 value for all genes and R^2 value for top 100 DEGs if `top_100_genes` is not `None`.
|
421
|
+
|
415
422
|
Examples:
|
416
423
|
>>> import pertpy as pt
|
417
424
|
>>> data = pt.dt.kang_2018()
|
@@ -498,6 +505,7 @@ class Scgen(JaxTrainingMixin, BaseModelClass):
|
|
498
505
|
r"$\mathrm{R^2_{\mathrm{\mathsf{top\ 100\ DEGs}}}}$= " + f"{r_value_diff ** 2:.2f}",
|
499
506
|
fontsize=kwargs.get("textsize", fontsize),
|
500
507
|
)
|
508
|
+
|
501
509
|
if save:
|
502
510
|
plt.savefig(save, bbox_inches="tight")
|
503
511
|
if show:
|
@@ -514,16 +522,17 @@ class Scgen(JaxTrainingMixin, BaseModelClass):
|
|
514
522
|
condition_key: str,
|
515
523
|
axis_keys: dict[str, str],
|
516
524
|
labels: dict[str, str],
|
517
|
-
|
525
|
+
*,
|
518
526
|
gene_list: list[str] = None,
|
519
527
|
top_100_genes: list[str] = None,
|
520
|
-
show: bool = False,
|
521
528
|
legend: bool = True,
|
522
529
|
title: str = None,
|
523
530
|
verbose: bool = False,
|
524
531
|
x_coeff: float = 0.3,
|
525
532
|
y_coeff: float = 0.8,
|
526
533
|
fontsize: float = 14,
|
534
|
+
show: bool = True,
|
535
|
+
save: str | bool | None = None,
|
527
536
|
**kwargs,
|
528
537
|
) -> tuple[float, float] | float:
|
529
538
|
"""Plots variance matching for a set of specified genes.
|
@@ -534,19 +543,18 @@ class Scgen(JaxTrainingMixin, BaseModelClass):
|
|
534
543
|
corresponding to batch and cell type metadata, respectively.
|
535
544
|
condition_key: Key of the condition.
|
536
545
|
axis_keys: Dictionary of `adata.obs` keys that are used by the axes of the plot. Has to be in the following form:
|
537
|
-
|
538
|
-
labels: Dictionary of axes labels of the form
|
539
|
-
path_to_save: path to save the plot.
|
540
|
-
save: Specify if the plot should be saved or not.
|
546
|
+
{"x": "Key for x-axis", "y": "Key for y-axis"}.
|
547
|
+
labels: Dictionary of axes labels of the form {"x": "x-axis-name", "y": "y-axis name"}.
|
541
548
|
gene_list: list of gene names to be plotted.
|
542
|
-
show: if `True`: will show to the plot after saving it.
|
543
549
|
top_100_genes: List of the top 100 differentially expressed genes. Specify if you want the top 100 DEGs to be assessed extra.
|
544
|
-
legend: Whether to plot a
|
550
|
+
legend: Whether to plot a legend.
|
545
551
|
title: Set if you want the plot to display a title.
|
546
552
|
verbose: Specify if you want information to be printed while creating the plot.
|
547
553
|
x_coeff: Offset to print the R^2 value in x-direction.
|
548
554
|
y_coeff: Offset to print the R^2 value in y-direction.
|
549
555
|
fontsize: Fontsize used for text in the plot.
|
556
|
+
show: if `True`, will show to the plot after saving it.
|
557
|
+
save: Specify if the plot should be saved or not.
|
550
558
|
"""
|
551
559
|
import seaborn as sns
|
552
560
|
|
@@ -636,6 +644,7 @@ class Scgen(JaxTrainingMixin, BaseModelClass):
|
|
636
644
|
else:
|
637
645
|
return r_value**2
|
638
646
|
|
647
|
+
@_doc_params(common_plot_args=doc_common_plot_args)
|
639
648
|
def plot_binary_classifier(
|
640
649
|
self,
|
641
650
|
scgen: Scgen,
|
@@ -643,10 +652,11 @@ class Scgen(JaxTrainingMixin, BaseModelClass):
|
|
643
652
|
delta: np.ndarray,
|
644
653
|
ctrl_key: str,
|
645
654
|
stim_key: str,
|
646
|
-
|
647
|
-
save: str | bool | None = None,
|
655
|
+
*,
|
648
656
|
fontsize: float = 14,
|
649
|
-
|
657
|
+
show: bool = True,
|
658
|
+
return_fig: bool = False,
|
659
|
+
) -> Figure | None:
|
650
660
|
"""Plots the dot product between delta and latent representation of a linear classifier.
|
651
661
|
|
652
662
|
Builds a linear classifier based on the dot product between
|
@@ -661,9 +671,11 @@ class Scgen(JaxTrainingMixin, BaseModelClass):
|
|
661
671
|
delta: Difference between stimulated and control cells in latent space
|
662
672
|
ctrl_key: Key for `control` part of the `data` found in `condition_key`.
|
663
673
|
stim_key: Key for `stimulated` part of the `data` found in `condition_key`.
|
664
|
-
path_to_save: Path to save the plot.
|
665
|
-
save: Specify if the plot should be saved or not.
|
666
674
|
fontsize: Set the font size of the plot.
|
675
|
+
{common_plot_args}
|
676
|
+
|
677
|
+
Returns:
|
678
|
+
If `return_fig` is `True`, returns the figure, otherwise `None`.
|
667
679
|
"""
|
668
680
|
plt.close("all")
|
669
681
|
adata = scgen._validate_anndata(adata)
|
@@ -693,12 +705,10 @@ class Scgen(JaxTrainingMixin, BaseModelClass):
|
|
693
705
|
ax = plt.gca()
|
694
706
|
ax.grid(False)
|
695
707
|
|
696
|
-
if save:
|
697
|
-
plt.savefig(save, bbox_inches="tight")
|
698
708
|
if show:
|
699
709
|
plt.show()
|
700
|
-
if
|
701
|
-
return
|
710
|
+
if return_fig:
|
711
|
+
return plt.gcf()
|
702
712
|
return None
|
703
713
|
|
704
714
|
|
@@ -1,6 +1,6 @@
|
|
1
|
-
Metadata-Version: 2.
|
1
|
+
Metadata-Version: 2.4
|
2
2
|
Name: pertpy
|
3
|
-
Version: 0.9.
|
3
|
+
Version: 0.9.5
|
4
4
|
Summary: Perturbation Analysis in the scverse ecosystem.
|
5
5
|
Project-URL: Documentation, https://pertpy.readthedocs.io
|
6
6
|
Project-URL: Source, https://github.com/scverse/pertpy
|
@@ -44,7 +44,7 @@ Classifier: Programming Language :: Python :: 3.11
|
|
44
44
|
Classifier: Programming Language :: Python :: 3.12
|
45
45
|
Classifier: Topic :: Scientific/Engineering :: Bio-Informatics
|
46
46
|
Classifier: Topic :: Scientific/Engineering :: Visualization
|
47
|
-
Requires-Python:
|
47
|
+
Requires-Python: <3.13,>=3.10
|
48
48
|
Requires-Dist: adjusttext
|
49
49
|
Requires-Dist: blitzgsea
|
50
50
|
Requires-Dist: decoupler
|
@@ -68,7 +68,8 @@ Requires-Dist: pyqt5; extra == 'coda'
|
|
68
68
|
Requires-Dist: toytree; extra == 'coda'
|
69
69
|
Provides-Extra: de
|
70
70
|
Requires-Dist: formulaic; extra == 'de'
|
71
|
-
Requires-Dist:
|
71
|
+
Requires-Dist: formulaic-contrasts>=0.2.0; extra == 'de'
|
72
|
+
Requires-Dist: pydeseq2>=v0.5.0pre1; extra == 'de'
|
72
73
|
Provides-Extra: dev
|
73
74
|
Requires-Dist: pre-commit; extra == 'dev'
|
74
75
|
Provides-Extra: doc
|
@@ -1,57 +1,57 @@
|
|
1
|
-
pertpy/__init__.py,sha256=
|
1
|
+
pertpy/__init__.py,sha256=r5QhDw2-Ls4yYLs1kJJVe_r6dstQ7SjoASFutlTU9JA,658
|
2
|
+
pertpy/_doc.py,sha256=pVt5Iegvh4rC1N81fd9e4cwmoGPNSgttZxxPWbLK6Bs,453
|
2
3
|
pertpy/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
3
4
|
pertpy/data/__init__.py,sha256=ah3yvoxkgbdMUNAWxS3SyqcUuVamBOSeuWkF2QRAEwM,2703
|
4
|
-
pertpy/data/_dataloader.py,sha256=
|
5
|
-
pertpy/data/_datasets.py,sha256=
|
5
|
+
pertpy/data/_dataloader.py,sha256=ENbk1T3w3N6tVI11V4FVUxuWFEwOHP8_kIB-ehiMlVQ,2428
|
6
|
+
pertpy/data/_datasets.py,sha256=c_U2U2NvncPZc6vs6w_s77zmWqr8ywDpzmkx6edCfUE,65616
|
6
7
|
pertpy/metadata/__init__.py,sha256=zoE_VXNyuKa4nlXlUk2nTgsHRW3jSQSpDEulcCnzOT0,222
|
7
|
-
pertpy/metadata/_cell_line.py,sha256
|
8
|
-
pertpy/metadata/_compound.py,sha256=
|
8
|
+
pertpy/metadata/_cell_line.py,sha256=Ell5PDVoMlrhHXPDKGCiPGwNY0DAeghbUUvTYL-SFF0,38919
|
9
|
+
pertpy/metadata/_compound.py,sha256=ywNNqtib0exHv0z8ctmTRf1Hk64tSGWSiUEffycxf6A,4755
|
9
10
|
pertpy/metadata/_drug.py,sha256=8QDSyxiFl25JdS80EQJC_krg6fEe5LIQEE6BsV1r8nY,9006
|
10
11
|
pertpy/metadata/_look_up.py,sha256=DoWp6OxIk_HyyyOhW1p8z5E68IZ31_nZDnqxk1rJqps,28778
|
11
|
-
pertpy/metadata/_metadata.py,sha256=
|
12
|
+
pertpy/metadata/_metadata.py,sha256=hV2LTFrExddLNU_RsDkZju6lQUSRoP4OIn_dumCyQao,3277
|
12
13
|
pertpy/metadata/_moa.py,sha256=u_OcMonjOeeoW5P9xOltquVSoTH3Vs80ztHsXf-X1DY,4701
|
13
14
|
pertpy/plot/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
14
15
|
pertpy/preprocessing/__init__.py,sha256=VAPFaeq2_qCvdFkQTCj_Hm460HC4Tersu8Rig_tnp_Y,71
|
15
|
-
pertpy/preprocessing/_guide_rna.py,sha256=
|
16
|
-
pertpy/tools/__init__.py,sha256=
|
17
|
-
pertpy/tools/_augur.py,sha256=
|
18
|
-
pertpy/tools/_cinemaot.py,sha256=
|
19
|
-
pertpy/tools/_dialogue.py,sha256=
|
20
|
-
pertpy/tools/_enrichment.py,sha256=
|
16
|
+
pertpy/preprocessing/_guide_rna.py,sha256=D9hEh8LOOTs_UfVBsBW3b-o6ipRFq09J471Hg1s0tlM,7963
|
17
|
+
pertpy/tools/__init__.py,sha256=NUTwCGxRdzUzLTgsS3r7MywENwPAdcGZDKrl83sU8mo,2599
|
18
|
+
pertpy/tools/_augur.py,sha256=Vghsx5-fYlaEeu__8-HUg6v5_KoVNhiRDjhgE3pORpY,55339
|
19
|
+
pertpy/tools/_cinemaot.py,sha256=U6vCb_mI4ZPFshYgsx-hOOsDA1IPwI7ZR_-IH4F9s7s,39621
|
20
|
+
pertpy/tools/_dialogue.py,sha256=BShXZ1ehO2eMbP5PV-ONJ-1SsxD6h9nAN7bGQ4_F6Rw,51906
|
21
|
+
pertpy/tools/_enrichment.py,sha256=jxVdOrpS_lAu7GCpemgdB4JJvsGH9SJTQsAKLBKi9Tc,21640
|
21
22
|
pertpy/tools/_kernel_pca.py,sha256=_EJ9WlBLjHOafF34sZGdyBgZL6Fj0WiJ1elVT1XMmo4,1579
|
22
|
-
pertpy/tools/_milo.py,sha256=
|
23
|
-
pertpy/tools/_mixscape.py,sha256=
|
23
|
+
pertpy/tools/_milo.py,sha256=SQqknT2zkzI0pcUmTm0ijWMs7CFMRiyRnXt9rC0jvmg,43811
|
24
|
+
pertpy/tools/_mixscape.py,sha256=T-oUHDnepao5aujAHw9bAbbQHPSK6oD_8Wr_mw4U0nc,52089
|
24
25
|
pertpy/tools/decoupler_LICENSE,sha256=OXLcl0T2SZ8Pmy2_dmlvKuetivmyPd5m1q-Gyd-zaYY,35149
|
25
26
|
pertpy/tools/transferlearning_MMD_LICENSE,sha256=MUvDA-o_j9htRpI8fStVdCRuyLdPkQUuIH0a_EIc57w,1069
|
26
27
|
pertpy/tools/_coda/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
27
|
-
pertpy/tools/_coda/_base_coda.py,sha256=
|
28
|
+
pertpy/tools/_coda/_base_coda.py,sha256=Uxb_vAORL77LkUri8QqL9SxaULK4hfpgRGsCQdtsAgk,111808
|
28
29
|
pertpy/tools/_coda/_sccoda.py,sha256=gGmyd0MGpchulV9d4PxKSmGORyZ8fCDS9tQVPOuF_Og,22622
|
29
30
|
pertpy/tools/_coda/_tasccoda.py,sha256=vNk43OQHn7pBLsez2rmSj0bMZKOf8jZTI7G8TfBByRg,30665
|
30
|
-
pertpy/tools/_differential_gene_expression/__init__.py,sha256=
|
31
|
-
pertpy/tools/_differential_gene_expression/_base.py,sha256=
|
31
|
+
pertpy/tools/_differential_gene_expression/__init__.py,sha256=SEydWg0iT3Y1pApjnCAOuHxFeI6xVUfgyBHv2s3LADU,487
|
32
|
+
pertpy/tools/_differential_gene_expression/_base.py,sha256=yc9DBj2KgJVk4mkjz7EDFoBj8WBZW92Z4ayD-Xdla1g,38514
|
32
33
|
pertpy/tools/_differential_gene_expression/_checks.py,sha256=SxNHJDsCYZ6rWLTMEymEBpigs_B9cnXyw0kkAe1l6e0,1675
|
33
34
|
pertpy/tools/_differential_gene_expression/_dge_comparison.py,sha256=9HjmWkrqZhj_ZJeR-ymyEDzpRJNx7JiYJoStvCfKuCU,4188
|
34
|
-
pertpy/tools/_differential_gene_expression/_edger.py,sha256=
|
35
|
-
pertpy/tools/_differential_gene_expression/
|
36
|
-
pertpy/tools/_differential_gene_expression/_pydeseq2.py,sha256=JK7H7u4va0q_TLE_sqi4JEzoPBd_xNRycYGu1507HS4,4117
|
35
|
+
pertpy/tools/_differential_gene_expression/_edger.py,sha256=ttgTocAYnr8BTDcixwHGjRZew6zeja-U77TLKkSdd1Y,4857
|
36
|
+
pertpy/tools/_differential_gene_expression/_pydeseq2.py,sha256=aOqsdu8hKp8_h2HhjkxS0B_itxRBnzEU2oSnU2PYiQ4,2942
|
37
37
|
pertpy/tools/_differential_gene_expression/_simple_tests.py,sha256=tTSr0Z2Qbpxdy9bcO8Gi_up6R616IcoK_e4_rlanyx4,6621
|
38
|
-
pertpy/tools/_differential_gene_expression/_statsmodels.py,sha256=
|
38
|
+
pertpy/tools/_differential_gene_expression/_statsmodels.py,sha256=jBCtaCglOvvVjkIBGXuTCTDB6g2AJsZMCf7iOlDyn48,2195
|
39
39
|
pertpy/tools/_distances/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
40
40
|
pertpy/tools/_distances/_distance_tests.py,sha256=mNmNu5cX0Wj5IegR6x5K-CbBSid8EhrH2jZPQxuvK4U,13521
|
41
|
-
pertpy/tools/_distances/_distances.py,sha256=
|
41
|
+
pertpy/tools/_distances/_distances.py,sha256=CmrOKevVCTY9j3PzhpVc3ga6SwZy9wbbJa0_7bwLMWQ,50569
|
42
42
|
pertpy/tools/_perturbation_space/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
43
43
|
pertpy/tools/_perturbation_space/_clustering.py,sha256=m52-J8c8OnIXRCf3NoFabIO2yMHIuy1X0m0amtsK2vE,3556
|
44
44
|
pertpy/tools/_perturbation_space/_comparison.py,sha256=rLO-EGU0I7t5MnLw4k1gYU-ypRu-JsDPLat1t4h2U2M,4329
|
45
45
|
pertpy/tools/_perturbation_space/_discriminator_classifiers.py,sha256=OA2eZeG_4iuW1T5ilsRIkS0rU-azmwEch7IuB546KSY,21617
|
46
46
|
pertpy/tools/_perturbation_space/_metrics.py,sha256=y8-baP8WRdB1iDgvP3uuQxSCDxA2lcxvEHHM2C_vWHY,3248
|
47
|
-
pertpy/tools/_perturbation_space/_perturbation_space.py,sha256=
|
48
|
-
pertpy/tools/_perturbation_space/_simple.py,sha256=
|
47
|
+
pertpy/tools/_perturbation_space/_perturbation_space.py,sha256=F-F-_pMCTWxjkVQSLre6hrE6PeRfCRscpt2ug3NlfuU,19531
|
48
|
+
pertpy/tools/_perturbation_space/_simple.py,sha256=RQv6B0xPq4RJa5zlkLkxMYXQ3LAJLglmQDGaTMseaA8,14238
|
49
49
|
pertpy/tools/_scgen/__init__.py,sha256=uERFlFyF88TH0uLiwmsUGEfHfLVCiZMFuk8gO5f7164,45
|
50
50
|
pertpy/tools/_scgen/_base_components.py,sha256=Qq8myRUm43q9XBrZ9gBggfa2cSV2wbz_KYoLgH7iF1A,3009
|
51
|
-
pertpy/tools/_scgen/_scgen.py,sha256=
|
51
|
+
pertpy/tools/_scgen/_scgen.py,sha256=oVY2JNYhDn1OrPoq22ATIP5-H615BafidBCC0eC5C-4,30756
|
52
52
|
pertpy/tools/_scgen/_scgenvae.py,sha256=v_6tZ4wY-JjdMH1QVd_wG4_N0PoaqB-FM8zC2JsDu1o,3935
|
53
53
|
pertpy/tools/_scgen/_utils.py,sha256=1upgOt1FpadfvNG05YpMjYYG-IAlxrC3l_ZxczmIczo,2841
|
54
|
-
pertpy-0.9.
|
55
|
-
pertpy-0.9.
|
56
|
-
pertpy-0.9.
|
57
|
-
pertpy-0.9.
|
54
|
+
pertpy-0.9.5.dist-info/METADATA,sha256=vuf16H5cVKgNKRg35pFSWN7Oa7tT4IOLqqymVzZfnr4,6927
|
55
|
+
pertpy-0.9.5.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
56
|
+
pertpy-0.9.5.dist-info/licenses/LICENSE,sha256=OZ-ZkXM5CmExJiEMM90b_7dGNNvRpj7kdE-49AnrLuI,1070
|
57
|
+
pertpy-0.9.5.dist-info/RECORD,,
|
@@ -1,189 +0,0 @@
|
|
1
|
-
"""Helpers to interact with Formulaic Formulas
|
2
|
-
|
3
|
-
Some helpful definitions for working with formulaic formulas (e.g. `~ 0 + C(donor):treatment + np.log1p(continuous)`):
|
4
|
-
* A *term* refers to an expression in the formula, separated by `+`, e.g. `C(donor):treatment`, or `np.log1p(continuous)`.
|
5
|
-
* A *variable* refers to a column of the data frame passed to formulaic, e.g. `donor`.
|
6
|
-
* A *factor* is the specification of how a certain variable is represented in the design matrix, e.g. treatment coding with base level "A" and reduced rank.
|
7
|
-
"""
|
8
|
-
|
9
|
-
from collections import defaultdict
|
10
|
-
from collections.abc import Mapping, Sequence
|
11
|
-
from dataclasses import dataclass
|
12
|
-
from typing import Any
|
13
|
-
|
14
|
-
from formulaic import FactorValues, ModelSpec
|
15
|
-
from formulaic.materializers import PandasMaterializer
|
16
|
-
from formulaic.materializers.types import EvaluatedFactor
|
17
|
-
from formulaic.parser.types import Factor
|
18
|
-
from interface_meta import override
|
19
|
-
|
20
|
-
|
21
|
-
@dataclass
|
22
|
-
class FactorMetadata:
|
23
|
-
"""Store (relevant) metadata for a factor of a formula."""
|
24
|
-
|
25
|
-
name: str
|
26
|
-
"""The unambiguous factor name as specified in the formula. E.g. `donor`, or `C(donor, contr.treatment(base="A"))`"""
|
27
|
-
|
28
|
-
reduced_rank: bool
|
29
|
-
"""Whether a column will be dropped because it is redundant"""
|
30
|
-
|
31
|
-
custom_encoder: bool
|
32
|
-
"""Whether or not a custom encoder (e.g. `C(...)`) was used."""
|
33
|
-
|
34
|
-
categories: Sequence[str]
|
35
|
-
"""The unique categories in this factor (after applying `drop_rows`)"""
|
36
|
-
|
37
|
-
kind: Factor.Kind
|
38
|
-
"""Type of the factor"""
|
39
|
-
|
40
|
-
drop_field: str = None
|
41
|
-
"""The category that is dropped.
|
42
|
-
|
43
|
-
Note that
|
44
|
-
* this may also be populated if `reduced_rank = False`
|
45
|
-
* this is only populated when no encoder was used (e.g. `~ donor` but NOT `~ C(donor)`.
|
46
|
-
"""
|
47
|
-
|
48
|
-
column_names: Sequence[str] = None
|
49
|
-
"""The column names for this factor included in the design matrix.
|
50
|
-
|
51
|
-
This may be the same as `categories` if the default encoder is used, or
|
52
|
-
categories without the base level if a custom encoder (e.g. `C(...)`) is used.
|
53
|
-
"""
|
54
|
-
|
55
|
-
colname_format: str = None
|
56
|
-
"""A formattable string that can be used to generate the column name in the design matrix, e.g. `{name}[T.{field}]`"""
|
57
|
-
|
58
|
-
@property
|
59
|
-
def base(self) -> str | None:
|
60
|
-
"""
|
61
|
-
The base category for this categorical factor.
|
62
|
-
|
63
|
-
This is derived from `drop_field` (for default encoding) or by comparing the column names in
|
64
|
-
the design matrix with all categories (for custom encoding, e.g. `C(...)`).
|
65
|
-
"""
|
66
|
-
if not self.reduced_rank:
|
67
|
-
return None
|
68
|
-
else:
|
69
|
-
if self.custom_encoder:
|
70
|
-
tmp_base = set(self.categories) - set(self.column_names)
|
71
|
-
assert len(tmp_base) == 1
|
72
|
-
return tmp_base.pop()
|
73
|
-
else:
|
74
|
-
assert self.drop_field is not None
|
75
|
-
return self.drop_field
|
76
|
-
|
77
|
-
|
78
|
-
def get_factor_storage_and_materializer() -> tuple[dict[str, list[FactorMetadata]], dict[str, set[str]], type]:
|
79
|
-
"""Keeps track of categorical factors used in a model specification by generating a custom materializer.
|
80
|
-
|
81
|
-
This materializer reports back metadata upon materialization of the model matrix.
|
82
|
-
|
83
|
-
Returns:
|
84
|
-
- A dictionary storing metadata for each factor processed by the custom materializer, named `factor_storage`.
|
85
|
-
- A dictionary mapping variables to factor names, which works similarly to model_spec.variable_terms
|
86
|
-
but maps to factors rather than terms, named `variable_to_factors`.
|
87
|
-
- A materializer class tied to the specific instance of `factor_storage`.
|
88
|
-
"""
|
89
|
-
# There can be multiple FactorMetadata entries per sample, for instance when formulaic generates an interaction
|
90
|
-
# term, it generates the factor with both full rank and reduced rank.
|
91
|
-
factor_storage: dict[str, list[FactorMetadata]] = defaultdict(list)
|
92
|
-
variable_to_factors: dict[str, set[str]] = defaultdict(set)
|
93
|
-
|
94
|
-
class CustomPandasMaterializer(PandasMaterializer):
|
95
|
-
"""An extension of the PandasMaterializer that records all categorical variables and their (base) categories."""
|
96
|
-
|
97
|
-
REGISTER_NAME = "custom_pandas"
|
98
|
-
REGISTER_INPUTS = ("pandas.core.frame.DataFrame",)
|
99
|
-
REGISTER_OUTPUTS = ("pandas", "numpy", "sparse")
|
100
|
-
|
101
|
-
def __init__(
|
102
|
-
self,
|
103
|
-
data: Any,
|
104
|
-
context: Mapping[str, Any] | None = None,
|
105
|
-
record_factor_metadata: bool = False,
|
106
|
-
**params: Any,
|
107
|
-
):
|
108
|
-
"""Initialize the Materializer.
|
109
|
-
|
110
|
-
Args:
|
111
|
-
data: Passed to PandasMaterializer.
|
112
|
-
context: Passed to PandasMaterializer
|
113
|
-
record_factor_metadata: Flag that tells whether this particular instance of the custom materializer class
|
114
|
-
is supposed to record factor metadata. Only the instance that is used for building the design
|
115
|
-
matrix should record the metadata. All other instances (e.g. used to generate contrast vectors)
|
116
|
-
should not record metadata to not overwrite the specifications from the design matrix.
|
117
|
-
**params: Passed to PandasMaterializer
|
118
|
-
"""
|
119
|
-
self.factor_metadata_storage = factor_storage if record_factor_metadata else None
|
120
|
-
self.variable_to_factors = variable_to_factors if record_factor_metadata else None
|
121
|
-
# temporary pointer to metadata of factor that is currently evaluated
|
122
|
-
self._current_factor: FactorMetadata = None
|
123
|
-
super().__init__(data, context, **params)
|
124
|
-
|
125
|
-
@override
|
126
|
-
def _encode_evaled_factor(
|
127
|
-
self, factor: EvaluatedFactor, spec: ModelSpec, drop_rows: Sequence[int], reduced_rank: bool = False
|
128
|
-
) -> dict[str, Any]:
|
129
|
-
"""Function is called just before the factor is evaluated.
|
130
|
-
|
131
|
-
We can record some metadata, before we call the original function.
|
132
|
-
"""
|
133
|
-
assert (
|
134
|
-
self._current_factor is None
|
135
|
-
), "_current_factor should always be None when we start recording metadata"
|
136
|
-
if self.factor_metadata_storage is not None:
|
137
|
-
# Don't store if the factor is cached (then we should already have recorded it)
|
138
|
-
if factor.expr in self.encoded_cache or (factor.expr, reduced_rank) in self.encoded_cache:
|
139
|
-
assert factor.expr in self.factor_metadata_storage, "Factor should be there since it's cached"
|
140
|
-
else:
|
141
|
-
for var in factor.variables:
|
142
|
-
self.variable_to_factors[var].add(factor.expr)
|
143
|
-
self._current_factor = FactorMetadata(
|
144
|
-
name=factor.expr,
|
145
|
-
reduced_rank=reduced_rank,
|
146
|
-
categories=tuple(sorted(factor.values.drop(index=factor.values.index[drop_rows]).unique())),
|
147
|
-
custom_encoder=factor.metadata.encoder is not None,
|
148
|
-
kind=factor.metadata.kind,
|
149
|
-
)
|
150
|
-
return super()._encode_evaled_factor(factor, spec, drop_rows, reduced_rank)
|
151
|
-
|
152
|
-
@override
|
153
|
-
def _flatten_encoded_evaled_factor(self, name: str, values: FactorValues[dict]) -> dict[str, Any]:
|
154
|
-
"""
|
155
|
-
Function is called at the end, before the design matrix gets materialized.
|
156
|
-
|
157
|
-
Here we have access to additional metadata, such as `drop_field`.
|
158
|
-
"""
|
159
|
-
if self._current_factor is not None:
|
160
|
-
assert self._current_factor.name == name
|
161
|
-
self._current_factor.drop_field = values.__formulaic_metadata__.drop_field
|
162
|
-
self._current_factor.column_names = values.__formulaic_metadata__.column_names
|
163
|
-
self._current_factor.colname_format = values.__formulaic_metadata__.format
|
164
|
-
self.factor_metadata_storage[name].append(self._current_factor)
|
165
|
-
self._current_factor = None
|
166
|
-
|
167
|
-
return super()._flatten_encoded_evaled_factor(name, values)
|
168
|
-
|
169
|
-
return factor_storage, variable_to_factors, CustomPandasMaterializer
|
170
|
-
|
171
|
-
|
172
|
-
class AmbiguousAttributeError(ValueError):
|
173
|
-
pass
|
174
|
-
|
175
|
-
|
176
|
-
def resolve_ambiguous(objs: Sequence[Any], attr: str) -> Any:
|
177
|
-
"""Given a list of objects, return an attribute if it is the same between all object. Otherwise, raise an error."""
|
178
|
-
if not objs:
|
179
|
-
raise ValueError("Collection is empty")
|
180
|
-
|
181
|
-
first_obj_attr = getattr(objs[0], attr)
|
182
|
-
|
183
|
-
# Check if the attribute is the same for all objects
|
184
|
-
for obj in objs[1:]:
|
185
|
-
if getattr(obj, attr) != first_obj_attr:
|
186
|
-
raise AmbiguousAttributeError(f"Ambiguous attribute '{attr}': values differ between objects")
|
187
|
-
|
188
|
-
# If attribute is the same for all objects, return it
|
189
|
-
return first_obj_attr
|
File without changes
|