pertpy 0.7.0__py3-none-any.whl → 0.8.0__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (56) hide show
  1. pertpy/__init__.py +2 -1
  2. pertpy/data/__init__.py +61 -0
  3. pertpy/data/_dataloader.py +27 -23
  4. pertpy/data/_datasets.py +58 -0
  5. pertpy/metadata/__init__.py +2 -0
  6. pertpy/metadata/_cell_line.py +39 -70
  7. pertpy/metadata/_compound.py +3 -4
  8. pertpy/metadata/_drug.py +2 -6
  9. pertpy/metadata/_look_up.py +38 -51
  10. pertpy/metadata/_metadata.py +7 -10
  11. pertpy/metadata/_moa.py +2 -6
  12. pertpy/plot/__init__.py +0 -5
  13. pertpy/preprocessing/__init__.py +2 -0
  14. pertpy/preprocessing/_guide_rna.py +2 -3
  15. pertpy/tools/__init__.py +42 -4
  16. pertpy/tools/_augur.py +14 -15
  17. pertpy/tools/_cinemaot.py +2 -2
  18. pertpy/tools/_coda/_base_coda.py +118 -142
  19. pertpy/tools/_coda/_sccoda.py +16 -15
  20. pertpy/tools/_coda/_tasccoda.py +21 -22
  21. pertpy/tools/_dialogue.py +18 -23
  22. pertpy/tools/_differential_gene_expression/__init__.py +20 -0
  23. pertpy/tools/_differential_gene_expression/_base.py +657 -0
  24. pertpy/tools/_differential_gene_expression/_checks.py +41 -0
  25. pertpy/tools/_differential_gene_expression/_dge_comparison.py +86 -0
  26. pertpy/tools/_differential_gene_expression/_edger.py +125 -0
  27. pertpy/tools/_differential_gene_expression/_formulaic.py +189 -0
  28. pertpy/tools/_differential_gene_expression/_pydeseq2.py +95 -0
  29. pertpy/tools/_differential_gene_expression/_simple_tests.py +162 -0
  30. pertpy/tools/_differential_gene_expression/_statsmodels.py +72 -0
  31. pertpy/tools/_distances/_distance_tests.py +21 -16
  32. pertpy/tools/_distances/_distances.py +406 -70
  33. pertpy/tools/_enrichment.py +10 -15
  34. pertpy/tools/_kernel_pca.py +1 -1
  35. pertpy/tools/_milo.py +76 -53
  36. pertpy/tools/_mixscape.py +15 -11
  37. pertpy/tools/_perturbation_space/_clustering.py +5 -2
  38. pertpy/tools/_perturbation_space/_comparison.py +112 -0
  39. pertpy/tools/_perturbation_space/_discriminator_classifiers.py +20 -22
  40. pertpy/tools/_perturbation_space/_perturbation_space.py +23 -21
  41. pertpy/tools/_perturbation_space/_simple.py +3 -3
  42. pertpy/tools/_scgen/__init__.py +1 -1
  43. pertpy/tools/_scgen/_base_components.py +2 -3
  44. pertpy/tools/_scgen/_scgen.py +33 -28
  45. pertpy/tools/_scgen/_utils.py +2 -2
  46. {pertpy-0.7.0.dist-info → pertpy-0.8.0.dist-info}/METADATA +22 -13
  47. pertpy-0.8.0.dist-info/RECORD +57 -0
  48. {pertpy-0.7.0.dist-info → pertpy-0.8.0.dist-info}/WHEEL +1 -1
  49. pertpy/plot/_augur.py +0 -171
  50. pertpy/plot/_coda.py +0 -601
  51. pertpy/plot/_guide_rna.py +0 -64
  52. pertpy/plot/_milopy.py +0 -209
  53. pertpy/plot/_mixscape.py +0 -355
  54. pertpy/tools/_differential_gene_expression.py +0 -325
  55. pertpy-0.7.0.dist-info/RECORD +0 -53
  56. {pertpy-0.7.0.dist-info → pertpy-0.8.0.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,657 @@
1
+ import os
2
+ from abc import ABC, abstractmethod
3
+ from dataclasses import dataclass
4
+ from itertools import chain
5
+ from types import MappingProxyType
6
+
7
+ import adjustText
8
+ import anndata as ad
9
+ import matplotlib.patheffects as PathEffects
10
+ import matplotlib.pyplot as plt
11
+ import numpy as np
12
+ import pandas as pd
13
+ import seaborn as sns
14
+ from matplotlib.ticker import MaxNLocator
15
+
16
+ from pertpy.tools._differential_gene_expression._checks import check_is_numeric_matrix
17
+ from pertpy.tools._differential_gene_expression._formulaic import (
18
+ AmbiguousAttributeError,
19
+ Factor,
20
+ get_factor_storage_and_materializer,
21
+ resolve_ambiguous,
22
+ )
23
+
24
+
25
+ @dataclass
26
+ class Contrast:
27
+ """Simple contrast for comparison between groups"""
28
+
29
+ column: str
30
+ baseline: str
31
+ group_to_compare: str
32
+
33
+
34
+ ContrastType = Contrast | tuple[str, str, str]
35
+
36
+
37
class MethodBase(ABC):
    """Abstract base class for differential expression methods.

    Handles optional feature masking, layer selection and numeric validation of the
    data matrix, and provides a shared volcano plot helper. Subclasses implement
    :meth:`compare_groups`.
    """

    def __init__(self, adata, *, mask=None, layer=None, **kwargs):
        """
        Initialize the method.

        Args:
            adata: AnnData object, usually pseudobulked.
            mask: A column in `adata.var` that contains a boolean mask with selected features.
            layer: Layer to use in fit(). If None, use the X array.
            **kwargs: Keyword arguments specific to the method implementation.

        Raises:
            ValueError: If the selected data matrix is not numeric or contains NaN/Inf values.
        """
        self.adata = adata
        if mask is not None:
            # Subset features (columns) only; observations are left untouched.
            self.adata = self.adata[:, self.adata.var[mask]]

        self.layer = layer
        check_is_numeric_matrix(self.data)

    @property
    def data(self):
        """Get the data matrix from anndata this object was initialized with (X or layer)."""
        if self.layer is None:
            return self.adata.X
        # AnnData stores named matrices in `.layers` (plural); there is no `.layer` attribute.
        return self.adata.layers[self.layer]

    @classmethod
    @abstractmethod
    def compare_groups(
        cls,
        adata,
        column,
        baseline,
        groups_to_compare,
        *,
        paired_by=None,
        mask=None,
        layer=None,
        fit_kwargs=MappingProxyType({}),
        test_kwargs=MappingProxyType({}),
    ):
        """
        Compare between groups in a specified column.

        Args:
            adata: AnnData object.
            column: column in obs that contains the grouping information.
            baseline: baseline value (one category from variable).
            groups_to_compare: One or multiple categories from variable to compare against baseline.
            paired_by: Column from `obs` that contains information about paired sample (e.g. subject_id).
            mask: Subset anndata by a boolean mask stored in this column in `.var` before making any tests.
            layer: Use this layer instead of `.X`.
            fit_kwargs: Additional fit options.
            test_kwargs: Additional test options.

        Returns:
            Pandas dataframe with results ordered by significance. If multiple comparisons were performed this is indicated in an additional column.
        """
        ...

    def plot_volcano(
        self,
        data: pd.DataFrame | ad.AnnData,
        *,
        log2fc_col: str = "log_fc",
        pvalue_col: str = "adj_p_value",
        symbol_col: str = "variable",
        pval_thresh: float = 0.05,
        log2fc_thresh: float = 0.75,
        to_label: int | list[str] = 5,
        s_curve: bool | None = False,
        colors: list[str] | None = None,
        varm_key: str | None = None,
        color_dict: dict[str, list[str]] | None = None,
        shape_dict: dict[str, list[str]] | None = None,
        size_col: str | None = None,
        fontsize: int = 10,
        top_right_frame: bool = False,
        figsize: tuple[int, int] = (5, 5),
        legend_pos: tuple[float, float] = (1.6, 1),
        point_sizes: tuple[int, int] = (15, 150),
        save: bool | str | None = None,
        shapes: list[str] | None = None,
        shape_order: list[str] | None = None,
        x_label: str | None = None,
        y_label: str | None = None,
        **kwargs,
    ) -> None:
        """Creates a volcano plot from a pandas DataFrame or Anndata.

        Args:
            data: DataFrame or Anndata to plot.
            log2fc_col: Column name of log2 Fold-Change values.
            pvalue_col: Column name of the p values.
            symbol_col: Column name of gene IDs.
            varm_key: Key in Anndata.varm slot to use for plotting if an Anndata object was passed.
            size_col: Column name to size points by.
            point_sizes: Lower and upper bounds of point sizes.
            pval_thresh: Threshold p value for significance.
            log2fc_thresh: Threshold for log2 fold change significance.
            to_label: Number of top genes (at each end of the ranking) or list of genes to label. 0 labels nothing.
            s_curve: Whether to use a reciprocal threshold for up and down gene determination.
            color_dict: Dictionary for coloring dots by categories.
            shape_dict: Dictionary for shaping dots by categories.
            fontsize: Size of gene labels.
            colors: Colors for [non-DE, up, down] genes. Defaults to ['gray', '#D62728', '#1F77B4'].
            top_right_frame: Whether to show the top and right frame of the plot.
            figsize: Size of the figure.
            legend_pos: Position of the legend as determined by matplotlib.
            save: If a string, saves the plot to `<save>.png` and `<save>.svg`. If True, saves it to the
                first unused `volcano_XX` prefix in the working directory.
            shapes: List of matplotlib marker ids.
            shape_order: Order of categories for shapes.
            x_label: Label for the x-axis.
            y_label: Label for the y-axis.
            **kwargs: Additional arguments for seaborn.scatterplot.

        Raises:
            ValueError: If `data` is an AnnData object and no `varm_key` was given.
            NotImplementedError: If `data` is an AnnData object (not supported yet).
        """
        if colors is None:
            colors = ["gray", "#D62728", "#1F77B4"]

        def _pval_reciprocal(lfc: float) -> float:
            """
            Function for relating -log10(pvalue) and logfoldchange in a reciprocal.

            Used for plotting the S-curve. `pval_thresh` is read when this is called,
            i.e. after it has been converted to -log10 space further below.
            """
            return pval_thresh / (lfc - log2fc_thresh)

        def _map_shape(symbol: str) -> str:
            """Return the `shape_dict` key whose gene list contains `symbol`, else "other"."""
            if shape_dict is not None:
                for k in shape_dict.keys():
                    if shape_dict[k] is not None and symbol in shape_dict[k]:
                        return k
            return "other"

        # TODO join the two mapping functions
        def _map_genes_categories(
            row: pd.Series,
            log2fc_col: str,
            nlog10_col: str,
            log2fc_thresh: float,
            pval_thresh: float = None,
            s_curve: bool = False,
        ) -> str:
            """
            Map genes to categorize based on log2fc and pvalue.

            These categories are used for coloring the dots.
            Used when no color_dict is passed, sets up/down/nonsignificant.
            """
            log2fc = row[log2fc_col]
            nlog10 = row[nlog10_col]

            if s_curve:
                # S-curve condition for Up or Down categorization
                reciprocal_thresh = _pval_reciprocal(abs(log2fc))
                if log2fc > log2fc_thresh and nlog10 > reciprocal_thresh:
                    return "Up"
                elif log2fc < -log2fc_thresh and nlog10 > reciprocal_thresh:
                    return "Down"
                else:
                    return "not DE"
            else:
                # Standard condition for Up or Down categorization
                if log2fc > log2fc_thresh and nlog10 > pval_thresh:
                    return "Up"
                elif log2fc < -log2fc_thresh and nlog10 > pval_thresh:
                    return "Down"
                else:
                    return "not DE"

        def _map_genes_categories_highlight(
            row: pd.Series,
            log2fc_col: str,
            nlog10_col: str,
            log2fc_thresh: float,
            pval_thresh: float = None,
            s_curve: bool = False,
            symbol_col: str = None,
        ) -> str:
            """
            Map genes to categorize based on log2fc and pvalue.

            These categories are used for coloring the dots.
            Used when color_dict is passed, sets DE / not DE for background and user supplied highlight genes.
            """
            log2fc = row[log2fc_col]
            nlog10 = row[nlog10_col]
            symbol = row[symbol_col]

            if color_dict is not None:
                for k in color_dict.keys():
                    if symbol in color_dict[k]:
                        return k

            if s_curve:
                # Use S-curve condition for filtering DE
                if nlog10 > _pval_reciprocal(abs(log2fc)) and abs(log2fc) > log2fc_thresh:
                    return "DE"
                return "not DE"
            else:
                # Use standard condition for filtering DE
                if abs(log2fc) < log2fc_thresh or nlog10 < pval_thresh:
                    return "not DE"
                return "DE"

        if isinstance(data, ad.AnnData):
            if varm_key is None:
                raise ValueError("Please pass a .varm key to use for plotting")
            # TODO: build the frame from `data.varm[varm_key]` once AnnData input is supported.
            raise NotImplementedError("Anndata not implemented yet")

        df = data.copy(deep=True)

        # clean and replace 0s as they would lead to -inf
        if df[[log2fc_col, pvalue_col]].isnull().values.any():
            print("NaNs encountered, dropping rows with NaNs")
            df = df.dropna(subset=[log2fc_col, pvalue_col])

        if df[pvalue_col].min() == 0:
            print("0s encountered for p value, replacing with 1e-323")
            df.loc[df[pvalue_col] == 0, pvalue_col] = 1e-323

        # convert p value threshold to nlog10
        pval_thresh = -np.log10(pval_thresh)
        # make nlog10 column
        df["nlog10"] = -np.log10(df[pvalue_col])
        y_max = df["nlog10"].max() + 1
        # horizontal extent of the data; bounds the S-curve so it does not stretch the x-axis
        x_max = df[log2fc_col].abs().max()
        # make a column to pick top genes
        df["top_genes"] = df["nlog10"] * df[log2fc_col]

        # Label everything with assigned color / shape
        if shape_dict or color_dict:
            combined_labels = []
            if isinstance(shape_dict, dict):
                combined_labels.extend([item for sublist in shape_dict.values() for item in sublist])
            if isinstance(color_dict, dict):
                combined_labels.extend([item for sublist in color_dict.values() for item in sublist])
            label_df = df[df[symbol_col].isin(combined_labels)]

        # Label top n genes at both ends of the ranking
        elif isinstance(to_label, int):
            ranked = df.sort_values("top_genes")
            n_rows = len(ranked)
            # Explicit positions instead of `[-n:]`: `[-0:]` would select every row.
            n_label = min(max(to_label, 0), n_rows)
            top_positions = list(range(n_rows - n_label, n_rows))
            # When the two ends overlap (2 * n_label > n_rows), label each gene only once.
            bottom_positions = [pos for pos in range(n_label) if pos < n_rows - n_label]
            label_df = ranked.iloc[top_positions + bottom_positions]

        # assume that a list of genes was passed to label
        else:
            label_df = df[df[symbol_col].isin(to_label)]

        # By default mode colors by up/down if no dict is passed
        if color_dict is None:
            df["color"] = df.apply(
                lambda row: _map_genes_categories(
                    row,
                    log2fc_col=log2fc_col,
                    nlog10_col="nlog10",
                    log2fc_thresh=log2fc_thresh,
                    pval_thresh=pval_thresh,
                    s_curve=s_curve,
                ),
                axis=1,
            )
            # `colors` is ordered as [non-DE, up, down]; cycle if fewer colors were given.
            categories = ["not DE", "Up", "Down"]
            category_colors = {cat: colors[i % len(colors)] for i, cat in enumerate(categories)}

        # coloring if dictionary passed, subtle background + highlight
        else:
            df["color"] = df.apply(
                lambda row: _map_genes_categories_highlight(
                    row,
                    log2fc_col=log2fc_col,
                    nlog10_col="nlog10",
                    log2fc_thresh=log2fc_thresh,
                    pval_thresh=pval_thresh,
                    symbol_col=symbol_col,
                    s_curve=s_curve,
                ),
                axis=1,
            )
            user_added_cats = [x for x in df.color.unique() if x not in ["DE", "not DE"]]
            categories = ["DE", "not DE"] + user_added_cats
            highlight_colors = [
                "tab:blue",
                "tab:orange",
                "tab:green",
                "tab:red",
                "tab:purple",
                "tab:brown",
                "tab:pink",
                "tab:olive",
                "tab:cyan",
            ]
            category_colors = {"DE": "dimgrey", "not DE": "lightgrey"}
            for i, cat in enumerate(user_added_cats):
                category_colors[cat] = highlight_colors[i % len(highlight_colors)]

        # Keep only categories that actually occur, and map each one to its color by name.
        # Truncating the ordered lists by count (as before) misaligned hue and palette whenever
        # an earlier category was absent, which silently dropped genes from the plot.
        present_colors = set(df["color"].unique())
        hues = [cat for cat in categories if cat in present_colors]
        palette = {cat: category_colors[cat] for cat in hues}

        # map shapes if dictionary exists
        if shape_dict is not None:
            df["shape"] = df[symbol_col].map(_map_shape)
            present_shapes = list(df["shape"].unique())
            shape_cats = [x for x in present_shapes if x != "other"]
            # Only list "other" if some genes fall outside shape_dict, so markers match the style levels.
            shape_order = (["other"] if "other" in present_shapes else []) + shape_cats
            if shapes is None:
                shapes = ["o", "^", "s", "X", "*", "d"]
            shapes = shapes[: len(shape_order)]
            shape_col = "shape"
        else:
            shape_col = None

        # We want plot highlighted genes on top + at bigger size, split dataframe
        df_highlight = None
        if shape_dict or color_dict:
            label_genes = label_df[symbol_col].unique()
            df_highlight = df[df[symbol_col].isin(label_genes)]
            df = df[~df[symbol_col].isin(label_genes)]

        plt.figure(figsize=figsize)
        # Plot non-highlighted genes
        ax = sns.scatterplot(
            data=df,
            x=log2fc_col,
            y="nlog10",
            hue="color",
            hue_order=hues,
            palette=palette,
            size=size_col,
            sizes=point_sizes,
            style=shape_col,
            style_order=shape_order,
            markers=shapes,
            **kwargs,
        )
        # Plot highlighted genes
        if df_highlight is not None:
            ax = sns.scatterplot(
                data=df_highlight,
                x=log2fc_col,
                y="nlog10",
                hue="color",
                hue_order=hues,
                palette=palette,
                size=size_col,
                sizes=point_sizes,
                style=shape_col,
                style_order=shape_order,
                markers=shapes,
                legend=False,
                edgecolor="black",
                linewidth=1,
                **kwargs,
            )

        # plot vertical and horizontal lines
        if s_curve:
            # x-range follows the log2FC extent of the data (the y-axis maximum was used before)
            x = np.arange((log2fc_thresh + 0.000001), x_max, 0.01)
            y = _pval_reciprocal(x)
            ax.plot(x, y, zorder=1, c="k", lw=2, ls="--")
            ax.plot(-x, y, zorder=1, c="k", lw=2, ls="--")

        else:
            ax.axhline(pval_thresh, zorder=1, c="k", lw=2, ls="--")
            ax.axvline(log2fc_thresh, zorder=1, c="k", lw=2, ls="--")
            ax.axvline(log2fc_thresh * -1, zorder=1, c="k", lw=2, ls="--")
        plt.ylim(0, y_max)
        ax.xaxis.set_major_locator(MaxNLocator(integer=True))

        # make labels
        texts = []
        for x_pos, y_pos, gene in zip(label_df[log2fc_col], label_df["nlog10"], label_df[symbol_col]):
            txt = plt.text(x=x_pos, y=y_pos, s=gene, fontsize=fontsize)
            txt.set_path_effects([PathEffects.withStroke(linewidth=3, foreground="w")])
            texts.append(txt)

        adjustText.adjust_text(texts, arrowprops={"arrowstyle": "-", "color": "k", "zorder": 5})

        # make things pretty
        for axis in ["bottom", "left", "top", "right"]:
            ax.spines[axis].set_linewidth(2)

        if not top_right_frame:
            ax.spines["top"].set_visible(False)
            ax.spines["right"].set_visible(False)

        ax.tick_params(width=2)
        # `size` and `fontsize` are aliases; pass a single one (fontsize=10 was the value applied last).
        plt.xticks(fontsize=10)
        plt.yticks(size=11)

        # Set default axis titles
        if x_label is None:
            x_label = log2fc_col
        if y_label is None:
            y_label = f"-$log_{{10}}$ {pvalue_col}"

        plt.xlabel(x_label, size=15)
        plt.ylabel(y_label, size=15)

        plt.legend(loc=1, bbox_to_anchor=legend_pos, frameon=False)

        # TODO replace with scanpy save style
        if save:
            if isinstance(save, str):
                # an explicit path prefix was given (previously unreachable: any string is truthy)
                plt.savefig(save + ".png", dpi=300, bbox_inches="tight")
                plt.savefig(save + ".svg", bbox_inches="tight")
            else:
                existing_files = os.listdir()
                for idx in range(100):
                    file_pref = "volcano_" + "%02d" % (idx,)
                    if not any(name.startswith(file_pref) for name in existing_files):
                        plt.savefig(file_pref + ".png", dpi=300, bbox_inches="tight")
                        plt.savefig(file_pref + ".svg", bbox_inches="tight")
                        break

        plt.show()
466
+
467
+
468
class LinearModelBase(MethodBase):
    """Base class for DE methods that fit a linear model described by a design matrix or formula.

    Subclasses implement `_check_counts`, `fit` and `_test_single_contrast`.
    """

    def __init__(self, adata, design, *, mask=None, layer=None, **kwargs):
        """
        Initialize the method.

        Args:
            adata: AnnData object, usually pseudobulked.
            design: Model design. Can be either a design matrix or a formulaic formula in the format 'x + z' or '~x+z'.
            mask: A column in adata.var that contains a boolean mask with selected features.
            layer: Layer to use in fit(). If None, use the X array.
            **kwargs: Keyword arguments specific to the method implementation.
        """
        super().__init__(adata, mask=mask, layer=layer)
        self._check_counts()

        # Factor metadata is only available when the design was given as a formula; `cond` relies on it.
        self.factor_storage = None
        self.variable_to_factors = None

        if isinstance(design, str):
            self.factor_storage, self.variable_to_factors, materializer_class = get_factor_storage_and_materializer()
            self.design = materializer_class(adata.obs, record_factor_metadata=True).get_model_matrix(design)
        else:
            self.design = design

    @classmethod
    def compare_groups(
        cls,
        adata,
        column,
        baseline,
        groups_to_compare,
        *,
        paired_by=None,
        mask=None,
        layer=None,
        fit_kwargs=MappingProxyType({}),
        test_kwargs=MappingProxyType({}),
    ):
        """
        Fit `~column (+ paired_by)` and test each group in `groups_to_compare` against `baseline`.

        Returns:
            The concatenated results of `test_contrasts`, with the compared group in the `contrast` column.
        """
        design = f"~{column}"
        if paired_by is not None:
            design += f"+{paired_by}"
        if isinstance(groups_to_compare, str):
            groups_to_compare = [groups_to_compare]
        model = cls(adata, design=design, mask=mask, layer=layer)

        model.fit(**fit_kwargs)

        de_res = model.test_contrasts(
            {
                group_to_compare: model.contrast(column=column, baseline=baseline, group_to_compare=group_to_compare)
                for group_to_compare in groups_to_compare
            },
            **test_kwargs,
        )

        return de_res

    @property
    def variables(self):
        """Get the names of the variables used in the model definition."""
        try:
            return self.design.model_spec.variables_by_source["data"]
        except AttributeError:
            raise ValueError(
                "Retrieving variables is only possible if the model was initialized using a formula."
            ) from None

    @abstractmethod
    def _check_counts(self):
        """
        Check that counts are valid for the specific method.

        Raises:
            ValueError: if the data matrix does not comply with the expectations.
        """
        ...

    @abstractmethod
    def fit(self, **kwargs):
        """
        Fit the model.

        Args:
            **kwargs: Additional arguments for fitting the specific method.
        """
        ...

    @abstractmethod
    def _test_single_contrast(self, contrast, **kwargs): ...

    def test_contrasts(self, contrasts, **kwargs):
        """
        Perform a comparison as specified in a contrast vector.

        Args:
            contrasts: Either a numeric contrast vector, or a dictionary of numeric contrast vectors.
            **kwargs: passed to the respective implementation.

        Returns:
            A dataframe with the results.
        """
        if not isinstance(contrasts, dict):
            contrasts = {None: contrasts}
        results = [
            self._test_single_contrast(contrast, **kwargs).assign(contrast=name) for name, contrast in contrasts.items()
        ]
        return pd.concat(results)

    def test_reduced(self, modelB):
        """
        Test against a reduced model.

        Args:
            modelB: the reduced model against which to test.

        Example:
            modelA = Model().fit()
            modelB = Model().fit()
            modelA.test_reduced(modelB)
        """
        raise NotImplementedError

    def cond(self, **kwargs):
        """
        Get a contrast vector representing a specific condition.

        Variables of the model that are not specified get their default value
        (the base category for categorical variables, 0 otherwise).

        Args:
            **kwargs: column/value pairs.

        Returns:
            A contrast vector that aligns to the columns of the design matrix.

        Raises:
            RuntimeError: If the model was not specified with a formulaic formula.
            ValueError: If an unknown variable or a non-existent category is given.
        """
        if self.factor_storage is None:
            raise RuntimeError(
                "Building contrasts with `cond` only works if you specified the model using a formulaic formula. Please manually provide a contrast vector."
            )
        # Work on an explicit copy instead of mutating `kwargs` through an alias.
        cond_dict = dict(kwargs)
        if not set(cond_dict.keys()).issubset(self.variables):
            raise ValueError(
                "You specified a variable that is not part of the model. Available variables: "
                + ",".join(self.variables)
            )
        for var in self.variables:
            if var in cond_dict:
                self._check_category(var, cond_dict[var])
            else:
                cond_dict[var] = self._get_default_value(var)
        df = pd.DataFrame([cond_dict])
        return self.design.model_spec.get_model_matrix(df).iloc[0]

    def _get_factor_metadata_for_variable(self, var):
        """Collect the factor metadata of all factors derived from `var`."""
        factors = self.variable_to_factors[var]
        return list(chain.from_iterable(self.factor_storage[f] for f in factors))

    def _get_default_value(self, var):
        """Default value of `var`: its base category if categorical (or "\\0" if none), else 0."""
        factor_metadata = self._get_factor_metadata_for_variable(var)
        if resolve_ambiguous(factor_metadata, "kind") == Factor.Kind.CATEGORICAL:
            try:
                tmp_base = resolve_ambiguous(factor_metadata, "base")
            except AmbiguousAttributeError as e:
                raise ValueError(
                    f"Could not automatically resolve base category for variable {var}. Please specify it explicity in `model.cond`."
                ) from e
            return tmp_base if tmp_base is not None else "\0"
        else:
            return 0

    def _check_category(self, var, value):
        """Raise ValueError if `var` is categorical and `value` is not one of its categories."""
        factor_metadata = self._get_factor_metadata_for_variable(var)
        if resolve_ambiguous(factor_metadata, "kind") != Factor.Kind.CATEGORICAL:
            # Only categorical variables have a closed set of valid values.
            return
        tmp_categories = resolve_ambiguous(factor_metadata, "categories")
        if value not in tmp_categories:
            # map(str, ...) so non-string (e.g. integer) categories do not turn this into a TypeError
            raise ValueError(
                f"You specified a non-existant category for {var}. Possible categories: {', '.join(map(str, tmp_categories))}"
            )

    def contrast(self, column, baseline, group_to_compare):
        """
        Build a simple contrast for pairwise comparisons.

        Args:
            column: column in adata.obs to test on.
            baseline: baseline category (denominator).
            group_to_compare: category to compare against baseline (numerator).

        Returns:
            Numeric contrast vector.
        """
        return self.cond(**{column: group_to_compare}) - self.cond(**{column: baseline})
@@ -0,0 +1,41 @@
1
+ import numpy as np
2
+ from scipy.sparse import issparse, spmatrix
3
+
4
+
5
+ def check_is_numeric_matrix(array: np.ndarray | spmatrix) -> None:
6
+ """Check if a matrix is numeric and only contains finite/non-NA values.
7
+
8
+ Args:
9
+ array: Dense or sparse matrix to check.
10
+
11
+ Raises:
12
+ ValueError: If the matrix is not numeric or contains NaNs or infinite values.
13
+ """
14
+ if not np.issubdtype(array.dtype, np.number):
15
+ raise ValueError("Counts must be numeric.")
16
+ if issparse(array):
17
+ if np.any(~np.isfinite(array.data)):
18
+ raise ValueError("Counts cannot contain negative, NaN or Inf values.")
19
+ else:
20
+ if np.any(~np.isfinite(array)):
21
+ raise ValueError("Counts cannot contain negative, NaN or Inf values.")
22
+
23
+
24
+ def check_is_integer_matrix(array: np.ndarray | spmatrix, tolerance: float = 1e-6) -> None:
25
+ """Check if a matrix container integers, or floats that are close to integers.
26
+
27
+ Args:
28
+ array: Dense or sparse matrix to check.
29
+ tolerance: Values must be this close to integers.
30
+
31
+ Raises:
32
+ ValueError: If the matrix contains values that are not close to integers.
33
+ """
34
+ if issparse(array):
35
+ if not array.data.dtype.kind == "i" and not np.all(np.abs(array.data - np.round(array.data)) < tolerance):
36
+ raise ValueError("Non-zero elements of the matrix must be close to integer values.")
37
+ else:
38
+ if not array.dtype.kind == "i" and not np.all(np.abs(array - np.round(array)) < tolerance):
39
+ raise ValueError("Matrix must be a count matrix.")
40
+ if (array < 0).sum() > 0:
41
+ raise ValueError("Non-zero elements of the matrix must be positive.")