pertpy 0.9.3__py3-none-any.whl → 0.9.5__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- pertpy/__init__.py +1 -1
- pertpy/_doc.py +20 -0
- pertpy/data/_dataloader.py +4 -4
- pertpy/data/_datasets.py +3 -3
- pertpy/metadata/_cell_line.py +19 -7
- pertpy/metadata/_compound.py +3 -4
- pertpy/metadata/_metadata.py +1 -1
- pertpy/preprocessing/_guide_rna.py +19 -6
- pertpy/tools/__init__.py +12 -15
- pertpy/tools/_augur.py +36 -46
- pertpy/tools/_cinemaot.py +24 -18
- pertpy/tools/_coda/_base_coda.py +87 -106
- pertpy/tools/_dialogue.py +17 -21
- pertpy/tools/_differential_gene_expression/__init__.py +1 -2
- pertpy/tools/_differential_gene_expression/_base.py +495 -113
- pertpy/tools/_differential_gene_expression/_edger.py +30 -21
- pertpy/tools/_differential_gene_expression/_pydeseq2.py +15 -29
- pertpy/tools/_differential_gene_expression/_statsmodels.py +0 -11
- pertpy/tools/_distances/_distances.py +15 -8
- pertpy/tools/_enrichment.py +18 -8
- pertpy/tools/_milo.py +58 -46
- pertpy/tools/_mixscape.py +111 -100
- pertpy/tools/_perturbation_space/_perturbation_space.py +40 -31
- pertpy/tools/_perturbation_space/_simple.py +50 -0
- pertpy/tools/_scgen/_scgen.py +35 -25
- {pertpy-0.9.3.dist-info → pertpy-0.9.5.dist-info}/METADATA +5 -4
- {pertpy-0.9.3.dist-info → pertpy-0.9.5.dist-info}/RECORD +29 -29
- {pertpy-0.9.3.dist-info → pertpy-0.9.5.dist-info}/WHEEL +1 -1
- pertpy/tools/_differential_gene_expression/_formulaic.py +0 -189
- {pertpy-0.9.3.dist-info → pertpy-0.9.5.dist-info}/licenses/LICENSE +0 -0
pertpy/tools/_scgen/_scgen.py
CHANGED
@@ -18,12 +18,16 @@ from scvi.data.fields import CategoricalObsField, LayerField
|
|
18
18
|
from scvi.model.base import BaseModelClass, JaxTrainingMixin
|
19
19
|
from scvi.utils import setup_anndata_dsp
|
20
20
|
|
21
|
+
from pertpy._doc import _doc_params, doc_common_plot_args
|
22
|
+
|
21
23
|
from ._scgenvae import JaxSCGENVAE
|
22
24
|
from ._utils import balancer, extractor
|
23
25
|
|
24
26
|
if TYPE_CHECKING:
|
25
27
|
from collections.abc import Sequence
|
26
28
|
|
29
|
+
from matplotlib.pyplot import Figure
|
30
|
+
|
27
31
|
font = {"family": "Arial", "size": 14}
|
28
32
|
|
29
33
|
|
@@ -377,9 +381,8 @@ class Scgen(JaxTrainingMixin, BaseModelClass):
|
|
377
381
|
condition_key: str,
|
378
382
|
axis_keys: dict[str, str],
|
379
383
|
labels: dict[str, str],
|
380
|
-
|
384
|
+
*,
|
381
385
|
gene_list: list[str] = None,
|
382
|
-
show: bool = False,
|
383
386
|
top_100_genes: list[str] = None,
|
384
387
|
verbose: bool = False,
|
385
388
|
legend: bool = True,
|
@@ -387,6 +390,8 @@ class Scgen(JaxTrainingMixin, BaseModelClass):
|
|
387
390
|
x_coeff: float = 0.30,
|
388
391
|
y_coeff: float = 0.8,
|
389
392
|
fontsize: float = 14,
|
393
|
+
show: bool = False,
|
394
|
+
save: str | bool | None = None,
|
390
395
|
**kwargs,
|
391
396
|
) -> tuple[float, float] | float:
|
392
397
|
"""Plots mean matching for a set of specified genes.
|
@@ -397,21 +402,23 @@ class Scgen(JaxTrainingMixin, BaseModelClass):
|
|
397
402
|
corresponding to batch and cell type metadata, respectively.
|
398
403
|
condition_key: The key for the condition
|
399
404
|
axis_keys: Dictionary of `adata.obs` keys that are used by the axes of the plot. Has to be in the following form:
|
400
|
-
`
|
401
|
-
labels: Dictionary of axes labels of the form `
|
402
|
-
path_to_save: path to save the plot.
|
403
|
-
save: Specify if the plot should be saved or not.
|
405
|
+
{`x`: `Key for x-axis`, `y`: `Key for y-axis`}.
|
406
|
+
labels: Dictionary of axes labels of the form {`x`: `x-axis-name`, `y`: `y-axis name`}.
|
404
407
|
gene_list: list of gene names to be plotted.
|
405
|
-
show: if `True`: will show to the plot after saving it.
|
406
408
|
top_100_genes: List of the top 100 differentially expressed genes. Specify if you want the top 100 DEGs to be assessed extra.
|
407
|
-
verbose: Specify if you want information to be printed while creating the plot
|
409
|
+
verbose: Specify if you want information to be printed while creating the plot.
|
408
410
|
legend: Whether to plot a legend.
|
409
411
|
title: Set if you want the plot to display a title.
|
410
412
|
x_coeff: Offset to print the R^2 value in x-direction.
|
411
413
|
y_coeff: Offset to print the R^2 value in y-direction.
|
412
414
|
fontsize: Fontsize used for text in the plot.
|
415
|
+
show: if `True`, will show to the plot after saving it.
|
416
|
+
save: Specify if the plot should be saved or not.
|
413
417
|
**kwargs:
|
414
418
|
|
419
|
+
Returns:
|
420
|
+
Returns R^2 value for all genes and R^2 value for top 100 DEGs if `top_100_genes` is not `None`.
|
421
|
+
|
415
422
|
Examples:
|
416
423
|
>>> import pertpy as pt
|
417
424
|
>>> data = pt.dt.kang_2018()
|
@@ -498,6 +505,7 @@ class Scgen(JaxTrainingMixin, BaseModelClass):
|
|
498
505
|
r"$\mathrm{R^2_{\mathrm{\mathsf{top\ 100\ DEGs}}}}$= " + f"{r_value_diff ** 2:.2f}",
|
499
506
|
fontsize=kwargs.get("textsize", fontsize),
|
500
507
|
)
|
508
|
+
|
501
509
|
if save:
|
502
510
|
plt.savefig(save, bbox_inches="tight")
|
503
511
|
if show:
|
@@ -514,16 +522,17 @@ class Scgen(JaxTrainingMixin, BaseModelClass):
|
|
514
522
|
condition_key: str,
|
515
523
|
axis_keys: dict[str, str],
|
516
524
|
labels: dict[str, str],
|
517
|
-
|
525
|
+
*,
|
518
526
|
gene_list: list[str] = None,
|
519
527
|
top_100_genes: list[str] = None,
|
520
|
-
show: bool = False,
|
521
528
|
legend: bool = True,
|
522
529
|
title: str = None,
|
523
530
|
verbose: bool = False,
|
524
531
|
x_coeff: float = 0.3,
|
525
532
|
y_coeff: float = 0.8,
|
526
533
|
fontsize: float = 14,
|
534
|
+
show: bool = True,
|
535
|
+
save: str | bool | None = None,
|
527
536
|
**kwargs,
|
528
537
|
) -> tuple[float, float] | float:
|
529
538
|
"""Plots variance matching for a set of specified genes.
|
@@ -534,19 +543,18 @@ class Scgen(JaxTrainingMixin, BaseModelClass):
|
|
534
543
|
corresponding to batch and cell type metadata, respectively.
|
535
544
|
condition_key: Key of the condition.
|
536
545
|
axis_keys: Dictionary of `adata.obs` keys that are used by the axes of the plot. Has to be in the following form:
|
537
|
-
|
538
|
-
labels: Dictionary of axes labels of the form
|
539
|
-
path_to_save: path to save the plot.
|
540
|
-
save: Specify if the plot should be saved or not.
|
546
|
+
{"x": "Key for x-axis", "y": "Key for y-axis"}.
|
547
|
+
labels: Dictionary of axes labels of the form {"x": "x-axis-name", "y": "y-axis name"}.
|
541
548
|
gene_list: list of gene names to be plotted.
|
542
|
-
show: if `True`: will show to the plot after saving it.
|
543
549
|
top_100_genes: List of the top 100 differentially expressed genes. Specify if you want the top 100 DEGs to be assessed extra.
|
544
|
-
legend: Whether to plot a
|
550
|
+
legend: Whether to plot a legend.
|
545
551
|
title: Set if you want the plot to display a title.
|
546
552
|
verbose: Specify if you want information to be printed while creating the plot.
|
547
553
|
x_coeff: Offset to print the R^2 value in x-direction.
|
548
554
|
y_coeff: Offset to print the R^2 value in y-direction.
|
549
555
|
fontsize: Fontsize used for text in the plot.
|
556
|
+
show: if `True`, will show to the plot after saving it.
|
557
|
+
save: Specify if the plot should be saved or not.
|
550
558
|
"""
|
551
559
|
import seaborn as sns
|
552
560
|
|
@@ -636,6 +644,7 @@ class Scgen(JaxTrainingMixin, BaseModelClass):
|
|
636
644
|
else:
|
637
645
|
return r_value**2
|
638
646
|
|
647
|
+
@_doc_params(common_plot_args=doc_common_plot_args)
|
639
648
|
def plot_binary_classifier(
|
640
649
|
self,
|
641
650
|
scgen: Scgen,
|
@@ -643,10 +652,11 @@ class Scgen(JaxTrainingMixin, BaseModelClass):
|
|
643
652
|
delta: np.ndarray,
|
644
653
|
ctrl_key: str,
|
645
654
|
stim_key: str,
|
646
|
-
|
647
|
-
save: str | bool | None = None,
|
655
|
+
*,
|
648
656
|
fontsize: float = 14,
|
649
|
-
|
657
|
+
show: bool = True,
|
658
|
+
return_fig: bool = False,
|
659
|
+
) -> Figure | None:
|
650
660
|
"""Plots the dot product between delta and latent representation of a linear classifier.
|
651
661
|
|
652
662
|
Builds a linear classifier based on the dot product between
|
@@ -661,9 +671,11 @@ class Scgen(JaxTrainingMixin, BaseModelClass):
|
|
661
671
|
delta: Difference between stimulated and control cells in latent space
|
662
672
|
ctrl_key: Key for `control` part of the `data` found in `condition_key`.
|
663
673
|
stim_key: Key for `stimulated` part of the `data` found in `condition_key`.
|
664
|
-
path_to_save: Path to save the plot.
|
665
|
-
save: Specify if the plot should be saved or not.
|
666
674
|
fontsize: Set the font size of the plot.
|
675
|
+
{common_plot_args}
|
676
|
+
|
677
|
+
Returns:
|
678
|
+
If `return_fig` is `True`, returns the figure, otherwise `None`.
|
667
679
|
"""
|
668
680
|
plt.close("all")
|
669
681
|
adata = scgen._validate_anndata(adata)
|
@@ -693,12 +705,10 @@ class Scgen(JaxTrainingMixin, BaseModelClass):
|
|
693
705
|
ax = plt.gca()
|
694
706
|
ax.grid(False)
|
695
707
|
|
696
|
-
if save:
|
697
|
-
plt.savefig(save, bbox_inches="tight")
|
698
708
|
if show:
|
699
709
|
plt.show()
|
700
|
-
if
|
701
|
-
return
|
710
|
+
if return_fig:
|
711
|
+
return plt.gcf()
|
702
712
|
return None
|
703
713
|
|
704
714
|
|
@@ -1,6 +1,6 @@
|
|
1
|
-
Metadata-Version: 2.
|
1
|
+
Metadata-Version: 2.4
|
2
2
|
Name: pertpy
|
3
|
-
Version: 0.9.
|
3
|
+
Version: 0.9.5
|
4
4
|
Summary: Perturbation Analysis in the scverse ecosystem.
|
5
5
|
Project-URL: Documentation, https://pertpy.readthedocs.io
|
6
6
|
Project-URL: Source, https://github.com/scverse/pertpy
|
@@ -44,7 +44,7 @@ Classifier: Programming Language :: Python :: 3.11
|
|
44
44
|
Classifier: Programming Language :: Python :: 3.12
|
45
45
|
Classifier: Topic :: Scientific/Engineering :: Bio-Informatics
|
46
46
|
Classifier: Topic :: Scientific/Engineering :: Visualization
|
47
|
-
Requires-Python:
|
47
|
+
Requires-Python: <3.13,>=3.10
|
48
48
|
Requires-Dist: adjusttext
|
49
49
|
Requires-Dist: blitzgsea
|
50
50
|
Requires-Dist: decoupler
|
@@ -68,7 +68,8 @@ Requires-Dist: pyqt5; extra == 'coda'
|
|
68
68
|
Requires-Dist: toytree; extra == 'coda'
|
69
69
|
Provides-Extra: de
|
70
70
|
Requires-Dist: formulaic; extra == 'de'
|
71
|
-
Requires-Dist:
|
71
|
+
Requires-Dist: formulaic-contrasts>=0.2.0; extra == 'de'
|
72
|
+
Requires-Dist: pydeseq2>=v0.5.0pre1; extra == 'de'
|
72
73
|
Provides-Extra: dev
|
73
74
|
Requires-Dist: pre-commit; extra == 'dev'
|
74
75
|
Provides-Extra: doc
|
@@ -1,57 +1,57 @@
|
|
1
|
-
pertpy/__init__.py,sha256=
|
1
|
+
pertpy/__init__.py,sha256=r5QhDw2-Ls4yYLs1kJJVe_r6dstQ7SjoASFutlTU9JA,658
|
2
|
+
pertpy/_doc.py,sha256=pVt5Iegvh4rC1N81fd9e4cwmoGPNSgttZxxPWbLK6Bs,453
|
2
3
|
pertpy/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
3
4
|
pertpy/data/__init__.py,sha256=ah3yvoxkgbdMUNAWxS3SyqcUuVamBOSeuWkF2QRAEwM,2703
|
4
|
-
pertpy/data/_dataloader.py,sha256=
|
5
|
-
pertpy/data/_datasets.py,sha256=
|
5
|
+
pertpy/data/_dataloader.py,sha256=ENbk1T3w3N6tVI11V4FVUxuWFEwOHP8_kIB-ehiMlVQ,2428
|
6
|
+
pertpy/data/_datasets.py,sha256=c_U2U2NvncPZc6vs6w_s77zmWqr8ywDpzmkx6edCfUE,65616
|
6
7
|
pertpy/metadata/__init__.py,sha256=zoE_VXNyuKa4nlXlUk2nTgsHRW3jSQSpDEulcCnzOT0,222
|
7
|
-
pertpy/metadata/_cell_line.py,sha256
|
8
|
-
pertpy/metadata/_compound.py,sha256=
|
8
|
+
pertpy/metadata/_cell_line.py,sha256=Ell5PDVoMlrhHXPDKGCiPGwNY0DAeghbUUvTYL-SFF0,38919
|
9
|
+
pertpy/metadata/_compound.py,sha256=ywNNqtib0exHv0z8ctmTRf1Hk64tSGWSiUEffycxf6A,4755
|
9
10
|
pertpy/metadata/_drug.py,sha256=8QDSyxiFl25JdS80EQJC_krg6fEe5LIQEE6BsV1r8nY,9006
|
10
11
|
pertpy/metadata/_look_up.py,sha256=DoWp6OxIk_HyyyOhW1p8z5E68IZ31_nZDnqxk1rJqps,28778
|
11
|
-
pertpy/metadata/_metadata.py,sha256=
|
12
|
+
pertpy/metadata/_metadata.py,sha256=hV2LTFrExddLNU_RsDkZju6lQUSRoP4OIn_dumCyQao,3277
|
12
13
|
pertpy/metadata/_moa.py,sha256=u_OcMonjOeeoW5P9xOltquVSoTH3Vs80ztHsXf-X1DY,4701
|
13
14
|
pertpy/plot/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
14
15
|
pertpy/preprocessing/__init__.py,sha256=VAPFaeq2_qCvdFkQTCj_Hm460HC4Tersu8Rig_tnp_Y,71
|
15
|
-
pertpy/preprocessing/_guide_rna.py,sha256=
|
16
|
-
pertpy/tools/__init__.py,sha256=
|
17
|
-
pertpy/tools/_augur.py,sha256=
|
18
|
-
pertpy/tools/_cinemaot.py,sha256=
|
19
|
-
pertpy/tools/_dialogue.py,sha256=
|
20
|
-
pertpy/tools/_enrichment.py,sha256=
|
16
|
+
pertpy/preprocessing/_guide_rna.py,sha256=D9hEh8LOOTs_UfVBsBW3b-o6ipRFq09J471Hg1s0tlM,7963
|
17
|
+
pertpy/tools/__init__.py,sha256=NUTwCGxRdzUzLTgsS3r7MywENwPAdcGZDKrl83sU8mo,2599
|
18
|
+
pertpy/tools/_augur.py,sha256=Vghsx5-fYlaEeu__8-HUg6v5_KoVNhiRDjhgE3pORpY,55339
|
19
|
+
pertpy/tools/_cinemaot.py,sha256=U6vCb_mI4ZPFshYgsx-hOOsDA1IPwI7ZR_-IH4F9s7s,39621
|
20
|
+
pertpy/tools/_dialogue.py,sha256=BShXZ1ehO2eMbP5PV-ONJ-1SsxD6h9nAN7bGQ4_F6Rw,51906
|
21
|
+
pertpy/tools/_enrichment.py,sha256=jxVdOrpS_lAu7GCpemgdB4JJvsGH9SJTQsAKLBKi9Tc,21640
|
21
22
|
pertpy/tools/_kernel_pca.py,sha256=_EJ9WlBLjHOafF34sZGdyBgZL6Fj0WiJ1elVT1XMmo4,1579
|
22
|
-
pertpy/tools/_milo.py,sha256=
|
23
|
-
pertpy/tools/_mixscape.py,sha256=
|
23
|
+
pertpy/tools/_milo.py,sha256=SQqknT2zkzI0pcUmTm0ijWMs7CFMRiyRnXt9rC0jvmg,43811
|
24
|
+
pertpy/tools/_mixscape.py,sha256=T-oUHDnepao5aujAHw9bAbbQHPSK6oD_8Wr_mw4U0nc,52089
|
24
25
|
pertpy/tools/decoupler_LICENSE,sha256=OXLcl0T2SZ8Pmy2_dmlvKuetivmyPd5m1q-Gyd-zaYY,35149
|
25
26
|
pertpy/tools/transferlearning_MMD_LICENSE,sha256=MUvDA-o_j9htRpI8fStVdCRuyLdPkQUuIH0a_EIc57w,1069
|
26
27
|
pertpy/tools/_coda/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
27
|
-
pertpy/tools/_coda/_base_coda.py,sha256=
|
28
|
+
pertpy/tools/_coda/_base_coda.py,sha256=Uxb_vAORL77LkUri8QqL9SxaULK4hfpgRGsCQdtsAgk,111808
|
28
29
|
pertpy/tools/_coda/_sccoda.py,sha256=gGmyd0MGpchulV9d4PxKSmGORyZ8fCDS9tQVPOuF_Og,22622
|
29
30
|
pertpy/tools/_coda/_tasccoda.py,sha256=vNk43OQHn7pBLsez2rmSj0bMZKOf8jZTI7G8TfBByRg,30665
|
30
|
-
pertpy/tools/_differential_gene_expression/__init__.py,sha256=
|
31
|
-
pertpy/tools/_differential_gene_expression/_base.py,sha256=
|
31
|
+
pertpy/tools/_differential_gene_expression/__init__.py,sha256=SEydWg0iT3Y1pApjnCAOuHxFeI6xVUfgyBHv2s3LADU,487
|
32
|
+
pertpy/tools/_differential_gene_expression/_base.py,sha256=yc9DBj2KgJVk4mkjz7EDFoBj8WBZW92Z4ayD-Xdla1g,38514
|
32
33
|
pertpy/tools/_differential_gene_expression/_checks.py,sha256=SxNHJDsCYZ6rWLTMEymEBpigs_B9cnXyw0kkAe1l6e0,1675
|
33
34
|
pertpy/tools/_differential_gene_expression/_dge_comparison.py,sha256=9HjmWkrqZhj_ZJeR-ymyEDzpRJNx7JiYJoStvCfKuCU,4188
|
34
|
-
pertpy/tools/_differential_gene_expression/_edger.py,sha256=
|
35
|
-
pertpy/tools/_differential_gene_expression/
|
36
|
-
pertpy/tools/_differential_gene_expression/_pydeseq2.py,sha256=JK7H7u4va0q_TLE_sqi4JEzoPBd_xNRycYGu1507HS4,4117
|
35
|
+
pertpy/tools/_differential_gene_expression/_edger.py,sha256=ttgTocAYnr8BTDcixwHGjRZew6zeja-U77TLKkSdd1Y,4857
|
36
|
+
pertpy/tools/_differential_gene_expression/_pydeseq2.py,sha256=aOqsdu8hKp8_h2HhjkxS0B_itxRBnzEU2oSnU2PYiQ4,2942
|
37
37
|
pertpy/tools/_differential_gene_expression/_simple_tests.py,sha256=tTSr0Z2Qbpxdy9bcO8Gi_up6R616IcoK_e4_rlanyx4,6621
|
38
|
-
pertpy/tools/_differential_gene_expression/_statsmodels.py,sha256=
|
38
|
+
pertpy/tools/_differential_gene_expression/_statsmodels.py,sha256=jBCtaCglOvvVjkIBGXuTCTDB6g2AJsZMCf7iOlDyn48,2195
|
39
39
|
pertpy/tools/_distances/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
40
40
|
pertpy/tools/_distances/_distance_tests.py,sha256=mNmNu5cX0Wj5IegR6x5K-CbBSid8EhrH2jZPQxuvK4U,13521
|
41
|
-
pertpy/tools/_distances/_distances.py,sha256=
|
41
|
+
pertpy/tools/_distances/_distances.py,sha256=CmrOKevVCTY9j3PzhpVc3ga6SwZy9wbbJa0_7bwLMWQ,50569
|
42
42
|
pertpy/tools/_perturbation_space/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
43
43
|
pertpy/tools/_perturbation_space/_clustering.py,sha256=m52-J8c8OnIXRCf3NoFabIO2yMHIuy1X0m0amtsK2vE,3556
|
44
44
|
pertpy/tools/_perturbation_space/_comparison.py,sha256=rLO-EGU0I7t5MnLw4k1gYU-ypRu-JsDPLat1t4h2U2M,4329
|
45
45
|
pertpy/tools/_perturbation_space/_discriminator_classifiers.py,sha256=OA2eZeG_4iuW1T5ilsRIkS0rU-azmwEch7IuB546KSY,21617
|
46
46
|
pertpy/tools/_perturbation_space/_metrics.py,sha256=y8-baP8WRdB1iDgvP3uuQxSCDxA2lcxvEHHM2C_vWHY,3248
|
47
|
-
pertpy/tools/_perturbation_space/_perturbation_space.py,sha256=
|
48
|
-
pertpy/tools/_perturbation_space/_simple.py,sha256=
|
47
|
+
pertpy/tools/_perturbation_space/_perturbation_space.py,sha256=F-F-_pMCTWxjkVQSLre6hrE6PeRfCRscpt2ug3NlfuU,19531
|
48
|
+
pertpy/tools/_perturbation_space/_simple.py,sha256=RQv6B0xPq4RJa5zlkLkxMYXQ3LAJLglmQDGaTMseaA8,14238
|
49
49
|
pertpy/tools/_scgen/__init__.py,sha256=uERFlFyF88TH0uLiwmsUGEfHfLVCiZMFuk8gO5f7164,45
|
50
50
|
pertpy/tools/_scgen/_base_components.py,sha256=Qq8myRUm43q9XBrZ9gBggfa2cSV2wbz_KYoLgH7iF1A,3009
|
51
|
-
pertpy/tools/_scgen/_scgen.py,sha256=
|
51
|
+
pertpy/tools/_scgen/_scgen.py,sha256=oVY2JNYhDn1OrPoq22ATIP5-H615BafidBCC0eC5C-4,30756
|
52
52
|
pertpy/tools/_scgen/_scgenvae.py,sha256=v_6tZ4wY-JjdMH1QVd_wG4_N0PoaqB-FM8zC2JsDu1o,3935
|
53
53
|
pertpy/tools/_scgen/_utils.py,sha256=1upgOt1FpadfvNG05YpMjYYG-IAlxrC3l_ZxczmIczo,2841
|
54
|
-
pertpy-0.9.
|
55
|
-
pertpy-0.9.
|
56
|
-
pertpy-0.9.
|
57
|
-
pertpy-0.9.
|
54
|
+
pertpy-0.9.5.dist-info/METADATA,sha256=vuf16H5cVKgNKRg35pFSWN7Oa7tT4IOLqqymVzZfnr4,6927
|
55
|
+
pertpy-0.9.5.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
56
|
+
pertpy-0.9.5.dist-info/licenses/LICENSE,sha256=OZ-ZkXM5CmExJiEMM90b_7dGNNvRpj7kdE-49AnrLuI,1070
|
57
|
+
pertpy-0.9.5.dist-info/RECORD,,
|
@@ -1,189 +0,0 @@
|
|
1
|
-
"""Helpers to interact with Formulaic Formulas
|
2
|
-
|
3
|
-
Some helpful definitions for working with formulaic formulas (e.g. `~ 0 + C(donor):treatment + np.log1p(continuous)`):
|
4
|
-
* A *term* refers to an expression in the formula, separated by `+`, e.g. `C(donor):treatment`, or `np.log1p(continuous)`.
|
5
|
-
* A *variable* refers to a column of the data frame passed to formulaic, e.g. `donor`.
|
6
|
-
* A *factor* is the specification of how a certain variable is represented in the design matrix, e.g. treatment coding with base level "A" and reduced rank.
|
7
|
-
"""
|
8
|
-
|
9
|
-
from collections import defaultdict
|
10
|
-
from collections.abc import Mapping, Sequence
|
11
|
-
from dataclasses import dataclass
|
12
|
-
from typing import Any
|
13
|
-
|
14
|
-
from formulaic import FactorValues, ModelSpec
|
15
|
-
from formulaic.materializers import PandasMaterializer
|
16
|
-
from formulaic.materializers.types import EvaluatedFactor
|
17
|
-
from formulaic.parser.types import Factor
|
18
|
-
from interface_meta import override
|
19
|
-
|
20
|
-
|
21
|
-
@dataclass
|
22
|
-
class FactorMetadata:
|
23
|
-
"""Store (relevant) metadata for a factor of a formula."""
|
24
|
-
|
25
|
-
name: str
|
26
|
-
"""The unambiguous factor name as specified in the formula. E.g. `donor`, or `C(donor, contr.treatment(base="A"))`"""
|
27
|
-
|
28
|
-
reduced_rank: bool
|
29
|
-
"""Whether a column will be dropped because it is redundant"""
|
30
|
-
|
31
|
-
custom_encoder: bool
|
32
|
-
"""Whether or not a custom encoder (e.g. `C(...)`) was used."""
|
33
|
-
|
34
|
-
categories: Sequence[str]
|
35
|
-
"""The unique categories in this factor (after applying `drop_rows`)"""
|
36
|
-
|
37
|
-
kind: Factor.Kind
|
38
|
-
"""Type of the factor"""
|
39
|
-
|
40
|
-
drop_field: str = None
|
41
|
-
"""The category that is dropped.
|
42
|
-
|
43
|
-
Note that
|
44
|
-
* this may also be populated if `reduced_rank = False`
|
45
|
-
* this is only populated when no encoder was used (e.g. `~ donor` but NOT `~ C(donor)`.
|
46
|
-
"""
|
47
|
-
|
48
|
-
column_names: Sequence[str] = None
|
49
|
-
"""The column names for this factor included in the design matrix.
|
50
|
-
|
51
|
-
This may be the same as `categories` if the default encoder is used, or
|
52
|
-
categories without the base level if a custom encoder (e.g. `C(...)`) is used.
|
53
|
-
"""
|
54
|
-
|
55
|
-
colname_format: str = None
|
56
|
-
"""A formattable string that can be used to generate the column name in the design matrix, e.g. `{name}[T.{field}]`"""
|
57
|
-
|
58
|
-
@property
|
59
|
-
def base(self) -> str | None:
|
60
|
-
"""
|
61
|
-
The base category for this categorical factor.
|
62
|
-
|
63
|
-
This is derived from `drop_field` (for default encoding) or by comparing the column names in
|
64
|
-
the design matrix with all categories (for custom encoding, e.g. `C(...)`).
|
65
|
-
"""
|
66
|
-
if not self.reduced_rank:
|
67
|
-
return None
|
68
|
-
else:
|
69
|
-
if self.custom_encoder:
|
70
|
-
tmp_base = set(self.categories) - set(self.column_names)
|
71
|
-
assert len(tmp_base) == 1
|
72
|
-
return tmp_base.pop()
|
73
|
-
else:
|
74
|
-
assert self.drop_field is not None
|
75
|
-
return self.drop_field
|
76
|
-
|
77
|
-
|
78
|
-
def get_factor_storage_and_materializer() -> tuple[dict[str, list[FactorMetadata]], dict[str, set[str]], type]:
|
79
|
-
"""Keeps track of categorical factors used in a model specification by generating a custom materializer.
|
80
|
-
|
81
|
-
This materializer reports back metadata upon materialization of the model matrix.
|
82
|
-
|
83
|
-
Returns:
|
84
|
-
- A dictionary storing metadata for each factor processed by the custom materializer, named `factor_storage`.
|
85
|
-
- A dictionary mapping variables to factor names, which works similarly to model_spec.variable_terms
|
86
|
-
but maps to factors rather than terms, named `variable_to_factors`.
|
87
|
-
- A materializer class tied to the specific instance of `factor_storage`.
|
88
|
-
"""
|
89
|
-
# There can be multiple FactorMetadata entries per sample, for instance when formulaic generates an interaction
|
90
|
-
# term, it generates the factor with both full rank and reduced rank.
|
91
|
-
factor_storage: dict[str, list[FactorMetadata]] = defaultdict(list)
|
92
|
-
variable_to_factors: dict[str, set[str]] = defaultdict(set)
|
93
|
-
|
94
|
-
class CustomPandasMaterializer(PandasMaterializer):
|
95
|
-
"""An extension of the PandasMaterializer that records all categorical variables and their (base) categories."""
|
96
|
-
|
97
|
-
REGISTER_NAME = "custom_pandas"
|
98
|
-
REGISTER_INPUTS = ("pandas.core.frame.DataFrame",)
|
99
|
-
REGISTER_OUTPUTS = ("pandas", "numpy", "sparse")
|
100
|
-
|
101
|
-
def __init__(
|
102
|
-
self,
|
103
|
-
data: Any,
|
104
|
-
context: Mapping[str, Any] | None = None,
|
105
|
-
record_factor_metadata: bool = False,
|
106
|
-
**params: Any,
|
107
|
-
):
|
108
|
-
"""Initialize the Materializer.
|
109
|
-
|
110
|
-
Args:
|
111
|
-
data: Passed to PandasMaterializer.
|
112
|
-
context: Passed to PandasMaterializer
|
113
|
-
record_factor_metadata: Flag that tells whether this particular instance of the custom materializer class
|
114
|
-
is supposed to record factor metadata. Only the instance that is used for building the design
|
115
|
-
matrix should record the metadata. All other instances (e.g. used to generate contrast vectors)
|
116
|
-
should not record metadata to not overwrite the specifications from the design matrix.
|
117
|
-
**params: Passed to PandasMaterializer
|
118
|
-
"""
|
119
|
-
self.factor_metadata_storage = factor_storage if record_factor_metadata else None
|
120
|
-
self.variable_to_factors = variable_to_factors if record_factor_metadata else None
|
121
|
-
# temporary pointer to metadata of factor that is currently evaluated
|
122
|
-
self._current_factor: FactorMetadata = None
|
123
|
-
super().__init__(data, context, **params)
|
124
|
-
|
125
|
-
@override
|
126
|
-
def _encode_evaled_factor(
|
127
|
-
self, factor: EvaluatedFactor, spec: ModelSpec, drop_rows: Sequence[int], reduced_rank: bool = False
|
128
|
-
) -> dict[str, Any]:
|
129
|
-
"""Function is called just before the factor is evaluated.
|
130
|
-
|
131
|
-
We can record some metadata, before we call the original function.
|
132
|
-
"""
|
133
|
-
assert (
|
134
|
-
self._current_factor is None
|
135
|
-
), "_current_factor should always be None when we start recording metadata"
|
136
|
-
if self.factor_metadata_storage is not None:
|
137
|
-
# Don't store if the factor is cached (then we should already have recorded it)
|
138
|
-
if factor.expr in self.encoded_cache or (factor.expr, reduced_rank) in self.encoded_cache:
|
139
|
-
assert factor.expr in self.factor_metadata_storage, "Factor should be there since it's cached"
|
140
|
-
else:
|
141
|
-
for var in factor.variables:
|
142
|
-
self.variable_to_factors[var].add(factor.expr)
|
143
|
-
self._current_factor = FactorMetadata(
|
144
|
-
name=factor.expr,
|
145
|
-
reduced_rank=reduced_rank,
|
146
|
-
categories=tuple(sorted(factor.values.drop(index=factor.values.index[drop_rows]).unique())),
|
147
|
-
custom_encoder=factor.metadata.encoder is not None,
|
148
|
-
kind=factor.metadata.kind,
|
149
|
-
)
|
150
|
-
return super()._encode_evaled_factor(factor, spec, drop_rows, reduced_rank)
|
151
|
-
|
152
|
-
@override
|
153
|
-
def _flatten_encoded_evaled_factor(self, name: str, values: FactorValues[dict]) -> dict[str, Any]:
|
154
|
-
"""
|
155
|
-
Function is called at the end, before the design matrix gets materialized.
|
156
|
-
|
157
|
-
Here we have access to additional metadata, such as `drop_field`.
|
158
|
-
"""
|
159
|
-
if self._current_factor is not None:
|
160
|
-
assert self._current_factor.name == name
|
161
|
-
self._current_factor.drop_field = values.__formulaic_metadata__.drop_field
|
162
|
-
self._current_factor.column_names = values.__formulaic_metadata__.column_names
|
163
|
-
self._current_factor.colname_format = values.__formulaic_metadata__.format
|
164
|
-
self.factor_metadata_storage[name].append(self._current_factor)
|
165
|
-
self._current_factor = None
|
166
|
-
|
167
|
-
return super()._flatten_encoded_evaled_factor(name, values)
|
168
|
-
|
169
|
-
return factor_storage, variable_to_factors, CustomPandasMaterializer
|
170
|
-
|
171
|
-
|
172
|
-
class AmbiguousAttributeError(ValueError):
|
173
|
-
pass
|
174
|
-
|
175
|
-
|
176
|
-
def resolve_ambiguous(objs: Sequence[Any], attr: str) -> Any:
|
177
|
-
"""Given a list of objects, return an attribute if it is the same between all object. Otherwise, raise an error."""
|
178
|
-
if not objs:
|
179
|
-
raise ValueError("Collection is empty")
|
180
|
-
|
181
|
-
first_obj_attr = getattr(objs[0], attr)
|
182
|
-
|
183
|
-
# Check if the attribute is the same for all objects
|
184
|
-
for obj in objs[1:]:
|
185
|
-
if getattr(obj, attr) != first_obj_attr:
|
186
|
-
raise AmbiguousAttributeError(f"Ambiguous attribute '{attr}': values differ between objects")
|
187
|
-
|
188
|
-
# If attribute is the same for all objects, return it
|
189
|
-
return first_obj_attr
|
File without changes
|