PyEvoMotion 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,520 @@
+ from typing import Any
+ 
+ import pandas as pd
+ from matplotlib import pyplot as plt
+ from matplotlib.backends.backend_pdf import PdfPages
+ 
+ from .base import PyEvoMotionBase
+ from .parser import PyEvoMotionParser
+ 
+ 
+ class PyEvoMotion(PyEvoMotionParser, PyEvoMotionBase):
+     """
+     Main class to analyze the data as intended by ``PyEvoMotion``. It inherits from :class:`PyEvoMotionParser` and :class:`PyEvoMotionBase`, and calls :meth:`count_mutation_types` on construction.
+ 
+     :param input_fasta: The path to the input ``.fasta`` file.
+     :type input_fasta: str
+     :param input_meta: The path to the input metadata file. It must have a column named ``date``. Accepts ``.csv`` and ``.tsv`` files.
+     :type input_meta: str
+     :param dt: The datetime interval string that governs the grouping for the statistics. Default is 7 days (``7D``).
+     :type dt: str
+     :param filters: The filters to apply to the data. Default is ``None``.
+     :type filters: dict[str, list[str] | str] | None
+     :param positions: The positions to filter by. Default is ``None``.
+     :type positions: tuple[int, int] | None
+     :param date_range: The date range to filter by. Default is ``None``.
+     :type date_range: tuple[str, str] | None
+ 
+     Attributes:
+     -----------
+     data: ``pd.DataFrame``
+         The parsed data from the input files.
+     reference: ``str``
+         The reference sequence.
+     _MUTATION_TYPES: ``list[str]``
+         The types of mutations that can be found in the data, namely ``substitutions`` and ``indels``.
+     """
+ 
+     _MUTATION_TYPES = ["substitutions", "indels"]
+ 
+     def __init__(self,
+         input_fasta: str,
+         input_meta: str,
+         dt: str = "7D",
+         filters: dict[str, list[str] | str] | None = None,
+         positions: tuple[int, int] | None = None,
+         date_range: tuple[str, str] | None = None
+     ) -> None:
+         """
+         Initialize the class.
+ 
+         It invokes ``PyEvoMotionParser.__init__()`` to parse the input files and then calls ``count_mutation_types()``.
+ 
+         :param input_fasta: The path to the input fasta file.
+         :type input_fasta: str
+         :param input_meta: The path to the input metadata file.
+         :type input_meta: str
+         :param dt: The datetime interval string that governs the grouping for the statistics. Default is 7 days.
+         :type dt: str
+         :param filters: The filters to apply to the data. Default is ``None``.
+         :type filters: dict[str, list[str] | str] | None
+         :param positions: The positions to filter by. Default is ``None``.
+         :type positions: tuple[int, int] | None
+         :param date_range: The date range to filter by. Default is ``None``.
+         :type date_range: tuple[str, str] | None
+         """
+ 
+         self.dt = dt
+ 
+         # Parse the input fasta and metadata files
+         super().__init__(
+             input_fasta,
+             input_meta,
+             filters if filters else {},
+             positions if positions else (0, 0),
+             date_range
+         )
+ 
+         self._check_dataset_is_not_empty(
+             self.data,
+             "Perhaps there were no entries or the filters provided (if any) are too restrictive."
+         )
+ 
+         # Set the origin of the data to the earliest date available,
+         # extending it backwards if an earlier start date was requested
+         self.origin = self.data["date"].min()
+         if date_range and date_range[0]:
+             self.origin = min(self.origin, pd.to_datetime(date_range[0]))
+ 
+         self.count_mutation_types()
+ 
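A minimal construction sketch of this class. The file names, date window and import path are hypothetical (the module only shows relative imports); the metadata file must contain a ``date`` column:

```python
from pyevomotion import PyEvoMotion  # assumed import path for this package

# Parse sequences and metadata, grouping statistics in 7-day steps
model = PyEvoMotion(
    "sequences.fasta",
    "metadata.tsv",
    dt="7D",
    date_range=("2020-03-01", "2020-12-31"),  # optional date window
)
print(model.origin)        # earliest date, used as the time origin
print(model.data.columns)  # includes the "number of ..." mutation counts
```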
+     @classmethod
+     def plot_results(cls,
+         stats: pd.DataFrame,
+         regs: dict[str, dict[str, Any]],
+         data_xlabel_units: str
+     ) -> None:
+         """
+         Plot the results of the analysis.
+ 
+         :param stats: The statistics of the data. The first column must be the date, the second the mean and the third the variance.
+         :type stats: pd.DataFrame
+         :param regs: The regression models.
+         :type regs: dict[str, dict[str, Any]]
+         :param data_xlabel_units: The ``xlabel`` units for the plots.
+         :type data_xlabel_units: str
+         """
+ 
+         _, ax = plt.subplots(3, 1, figsize=(10, 10))
+ 
+         # Mean
+         _model = next(
+             v
+             for k, v in regs.items()
+             if k.startswith("mean")
+         )
+         _mean_data = stats[stats.columns[1]]
+         cls.plot_single_data_and_model(
+             stats.index,
+             _mean_data,
+             _mean_data.name,
+             _model["model"],
+             r"$r^2$: " + f"{_model['r2']:.2f}",
+             data_xlabel_units,
+             ax[0]
+         )
+ 
+         # Variance
+         _model = next(
+             v
+             for k, v in regs.items()
+             if k.startswith("scaled var")
+         )
+         _variance_data = stats[stats.columns[2]]
+         cls.plot_single_data_and_model(
+             stats.index,
+             _variance_data,
+             _variance_data.name,
+             _model["model"],
+             r"$r^2$: " + f"{_model['r2']:.2f}",
+             data_xlabel_units,
+             ax[1]
+         )
+ 
+         # Dispersion index, plotted against the constant Poissonian baseline of 1
+         cls.plot_single_data_and_model(
+             stats.index,
+             _mean_data/_variance_data,
+             f"dispersion index of {' '.join(_mean_data.name.split()[1:])}",
+             lambda x: [1]*len(x),
+             "Poissonian regime",
+             data_xlabel_units,
+             ax[2],
+             line_linestyle="--",
+             line_color="black"
+         )
+ 
+         plt.tight_layout()
+         plt.show()
+ 
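A sketch of the inputs ``plot_results`` expects. The structure of ``regs`` is inferred from how the method reads it: each value only needs a callable under ``"model"`` and a fit quality under ``"r2"``, and the keys must start with ``"mean"`` and ``"scaled var"``. All values below are made up:

```python
import numpy as np
import pandas as pd

stats = pd.DataFrame({
    "date": pd.date_range("2020-03-01", periods=5, freq="7D"),
    "mean number of mutations": [1.0, 1.9, 3.1, 4.2, 4.8],
    "var number of mutations": [1.1, 2.0, 2.9, 4.0, 5.2],
})
regs = {
    "mean number of mutations per 7D model": {
        "model": lambda x: 1.0 * np.asarray(x) + 1.0,  # fitted mean trend
        "r2": 0.99,
    },
    "scaled var number of mutations per 7D model": {
        "model": lambda x: 1.0 * np.asarray(x),  # fitted (shifted) variance trend
        "r2": 0.98,
    },
}
PyEvoMotion.plot_results(stats, regs, "in steps of 7D since 2020-03-01")
```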
+     @classmethod
+     def export_plot_results(cls,
+         stats: pd.DataFrame,
+         regs: dict[str, dict[str, Any]],
+         data_xlabel_units: str,
+         output_ptr: PdfPages | None = None
+     ) -> None:
+         """
+         Export the results of the analysis to a ``.pdf`` file.
+ 
+         :param stats: The statistics of the data.
+         :type stats: pd.DataFrame
+         :param regs: The regression models.
+         :type regs: dict[str, dict[str, Any]]
+         :param data_xlabel_units: The ``xlabel`` units for the plots.
+         :type data_xlabel_units: str
+         :param output_ptr: An open ``PdfPages`` handle to save the figures to. If ``None``, a new ``output_plots.pdf`` file is created and closed on return.
+         :type output_ptr: PdfPages | None
+         """
+ 
+         pdf = output_ptr if output_ptr else PdfPages("output_plots.pdf")
+ 
+         plt.figure()
+         # Mean
+         _model = next(
+             v
+             for k, v in regs.items()
+             if k.startswith("mean")
+         )
+         _mean_data = stats[stats.columns[1]]
+         cls.plot_single_data_and_model(
+             stats.index,
+             _mean_data,
+             _mean_data.name,
+             _model["model"],
+             r"$r^2$: " + f"{_model['r2']:.2f}",
+             data_xlabel_units,
+             plt.gca()
+         )
+ 
+         plt.title(_mean_data.name)
+         plt.tight_layout()
+         pdf.savefig()
+         plt.close()
+ 
+         plt.figure()
+         # Variance
+         _model = next(
+             v
+             for k, v in regs.items()
+             if k.startswith("scaled var")
+         )
+         _variance_data = stats[stats.columns[2]]
+         cls.plot_single_data_and_model(
+             stats.index,
+             _variance_data,
+             _variance_data.name,
+             lambda x: _model["model"](x) + _variance_data.min(),  # Shift the scaled model back to the original variance
+             r"$r^2$: " + f"{_model['r2']:.2f}",
+             data_xlabel_units,
+             plt.gca()
+         )
+ 
+         plt.title(_variance_data.name)
+         plt.tight_layout()
+         pdf.savefig()
+         plt.close()
+ 
+         plt.figure()
+         # Dispersion index
+         _name = " ".join(_mean_data.name.split()[1:])
+         cls.plot_single_data_and_model(
+             stats.index,
+             _mean_data/_variance_data,
+             f"dispersion index of {_name}",
+             lambda x: [1]*len(x),
+             "Poissonian regime",
+             data_xlabel_units,
+             plt.gca(),
+             line_linestyle="--",
+             line_color="black"
+         )
+ 
+         plt.title(f"Dispersion index of {_name}")
+         plt.tight_layout()
+         pdf.savefig()
+         plt.close()
+ 
+         # Close the file only if it was created here
+         if output_ptr is None:
+             pdf.close()
+ 
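A usage sketch, reusing the hypothetical ``stats`` and ``regs`` from the previous example. Passing an open ``PdfPages`` handle lets several result sets share one multi-page file:

```python
from matplotlib.backends.backend_pdf import PdfPages

# Write the three figures into a caller-managed pdf
with PdfPages("my_results.pdf") as pdf:
    PyEvoMotion.export_plot_results(stats, regs, "in steps of 7D since 2020-03-01", pdf)

# Or let the method create (and close) its own "output_plots.pdf":
# PyEvoMotion.export_plot_results(stats, regs, "in steps of 7D since 2020-03-01")
```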
+     def count_mutation_types(self) -> None:
+         """
+         Count the number of substitutions, insertions and deletions in the data.
+ 
+         It updates the ``data`` attribute by adding the columns ``number of substitutions``, ``number of indels`` and ``number of mutations``.
+         """
+ 
+         # Count by the first letter of each mutation instruction:
+         # "s" (substitutions), "i" (insertions) or "d" (deletions)
+         for _type in self._MUTATION_TYPES + ["insertions", "deletions"]:
+             self.data[f"number of {_type}"] = self.data["mutation instructions"].apply(
+                 lambda x: self.count_prefixes(_type[0], x)
+             )
+ 
+         # Indels are the sum of insertions and deletions
+         # (this overwrites the first-letter count above, which only matched insertions)
+         self.data["number of indels"] = (
+             self.data["number of insertions"]
+             + self.data["number of deletions"]
+         )
+ 
+         self.data["number of mutations"] = self.data["mutation instructions"].apply(len)
+ 
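An illustration of the counting logic. The exact instruction format comes from ``PyEvoMotionParser`` and is not shown here, and the stand-in ``count_prefixes`` below only assumes the semantics implied by its use above (count elements starting with a prefix); the instruction strings are hypothetical:

```python
instructions = ["s21T", "i105AT", "d240", "s300G"]  # hypothetical entries

def count_prefixes(prefix: str, xs: list[str]) -> int:
    # Assumed semantics of PyEvoMotionBase.count_prefixes
    return sum(1 for x in xs if x.startswith(prefix))

print(count_prefixes("s", instructions))  # 2 substitutions
print(count_prefixes("i", instructions)
      + count_prefixes("d", instructions))  # 2 indels
print(len(instructions))                    # 4 mutations in total
```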
+     def get_lengths(self) -> pd.Series:
+         """
+         Get the lengths of the sequences in the dataset.
+ 
+         :return: The lengths of the sequences.
+         :rtype: ``pd.Series``
+         """
+ 
+         # Each sequence length is the reference length plus the net
+         # length change contributed by its mutations
+         return (
+             self.data["mutation instructions"].apply(
+                 lambda x: sum(map(
+                     self.mutation_length_modification,
+                     x
+                 ))
+             )
+             + len(self.reference)
+         )
+ 
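A sketch of the length arithmetic, assuming ``mutation_length_modification`` returns 0 for a substitution, a positive delta for an insertion and a negative one for a deletion (values below are hypothetical):

```python
reference_length = 29903      # e.g., a SARS-CoV-2-sized reference
deltas = [0, +2, -3]          # per-mutation length modifications
print(reference_length + sum(deltas))  # 29902, the reconstructed sequence length
```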
+     def length_filter(self, length: int, how: str = "gt") -> None:
+         """
+         Filter the data by sequence length.
+ 
+         It updates the ``data`` attribute by keeping only the rows that satisfy the length condition.
+ 
+         :param length: The length to filter by.
+         :type length: int
+         :param how: The filter condition. It can be ``gt`` (greater than), ``lt`` (less than) or ``eq`` (equal to).
+         :type how: str
+         """
+ 
+         if how == "gt":
+             self.data = self.data[self.get_lengths() > length]
+         elif how == "lt":
+             self.data = self.data[self.get_lengths() < length]
+         elif how == "eq":
+             self.data = self.data[self.get_lengths() == length]
+         else:
+             raise ValueError(f"Filter \"{how}\" not recognized")
+ 
+         self.data.reset_index(drop=True, inplace=True)
+ 
+     def n_filter(self, threshold: float | int = 0.01, how: str = "lt") -> None:
+         """
+         Filter the data by the frequency of ``N`` in the sequence.
+ 
+         It updates the ``data`` attribute by keeping only the rows whose fraction of ``N`` satisfies the condition.
+ 
+         :param threshold: The threshold to filter by. Must be between 0 and 1. Default is 0.01.
+         :type threshold: float | int
+         :param how: The filter condition. It can be ``gt`` (greater than), ``lt`` (less than) or ``eq`` (equal to).
+         :type how: str
+         """
+ 
+         if not (0 <= threshold <= 1):
+             raise ValueError("Threshold must be between 0 and 1")
+ 
+         N_freq = self.data["N count"]/self.get_lengths()
+ 
+         if how == "gt":
+             self.data = self.data[N_freq > threshold]
+         elif how == "lt":
+             self.data = self.data[N_freq < threshold]
+         elif how == "eq":
+             self.data = self.data[N_freq == threshold]
+         else:
+             raise ValueError(f"Filter \"{how}\" not recognized")
+ 
+         self.data.reset_index(drop=True, inplace=True)
+ 
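Both filters mutate ``model.data`` in place (``model`` from the construction sketch above; thresholds hypothetical):

```python
model.n_filter(threshold=0.01, how="lt")     # keep sequences with <1% ambiguous bases
model.length_filter(length=29000, how="gt")  # keep sequences longer than 29000 nt
print(len(model.data))                       # rows remaining after both filters
```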
+     @classmethod
+     def _mutation_type_switch(cls, mutation_kind: str) -> list[str]:
+         """
+         Map the mutation kind to the corresponding list of mutation types.
+ 
+         This is used to subset the analysis to the desired mutation kind.
+ 
+         :param mutation_kind: The kind of mutation to compute the statistics for. Has to be one of ``all``, ``total``, ``substitutions`` or ``indels``.
+         :type mutation_kind: str
+         :return: The list of mutation types.
+         :rtype: list[str]
+         """
+ 
+         cases = {
+             "all": cls._MUTATION_TYPES + ["mutations"],
+             "total": ["mutations"],
+             "substitutions": [cls._MUTATION_TYPES[0]],
+             "indels": [cls._MUTATION_TYPES[1]]
+         }
+ 
+         choice = cases.get(mutation_kind, None)
+ 
+         if choice is None:
+             raise ValueError(f'Mutation kind "{mutation_kind}" not recognized. It has to be one of {", ".join(cases.keys())}')
+ 
+         return choice
+ 
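The four accepted kinds and the column subsets they map to, read directly from the ``cases`` dictionary above:

```python
print(PyEvoMotion._mutation_type_switch("all"))            # ['substitutions', 'indels', 'mutations']
print(PyEvoMotion._mutation_type_switch("total"))          # ['mutations']
print(PyEvoMotion._mutation_type_switch("substitutions"))  # ['substitutions']
print(PyEvoMotion._mutation_type_switch("indels"))         # ['indels']
```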
+     def compute_stats(self,
+         DT: str,
+         origin: str,
+         n_threshold: int | None = None,
+         mutation_kind: str = "all"
+     ) -> pd.DataFrame:
+         """
+         Compute the size, mean and variance of the data.
+ 
+         It computes the mean and variance of the data for the specified mutation kind (or kinds) over the specified datetime interval, counting from the given origin.
+ 
+         :param DT: The datetime interval string that governs the grouping.
+         :type DT: str
+         :param origin: The datetime string that is the origin of the grouping.
+         :type origin: str
+         :param n_threshold: Minimum number of sequences required in a time interval to compute statistics.
+         :type n_threshold: int | None
+         :param mutation_kind: The kind of mutation to compute the statistics for. Has to be one of ``all``, ``total``, ``substitutions`` or ``indels``. Default is ``all``.
+         :type mutation_kind: str
+         :return: The statistics of the data.
+         :rtype: ``pd.DataFrame``
+         """
+ 
+         grouped = self.date_grouper(self.data, DT, origin)
+ 
+         # Only keep intervals where the number of observations meets the threshold
+         if n_threshold:
+ 
+             _filtered = grouped.filter(lambda x: len(x) >= n_threshold)
+ 
+             if len(_filtered) == 0:
+                 raise ValueError(
+                     f"No groups with at least {n_threshold} observations. Consider lowering the threshold."
+                 )
+ 
+             grouped = self.date_grouper(
+                 _filtered,
+                 DT,
+                 origin
+             )
+ 
+         levels = [
+             f"number of {x}"
+             for x in self._mutation_type_switch(mutation_kind)
+         ]
+ 
+         # Aggregate each selected column with the mean, variance and group size
+         return pd.concat(
+             (
+                 pd.DataFrame(self._invoke_method(grouped[levels], method))
+                 .rename(
+                     columns=lambda col: f"{method} {col}"
+                     if method != "size" else "size"
+                 )
+                 for method in ("mean", "var", "size")
+             ),
+             axis=1
+         ).reset_index(level=["date"])
+ 
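A usage sketch (origin and threshold values hypothetical). The column layout follows from the rename logic above: one ``mean`` and one ``var`` column per selected mutation type, plus ``size``:

```python
# Weekly statistics for substitutions only
stats = model.compute_stats(
    "7D", "2020-03-01",
    n_threshold=10,
    mutation_kind="substitutions",
)
# Expected columns:
#   date | mean number of substitutions | var number of substitutions | size
print(stats.head())
```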
+     def analysis(self,
+         length: int,
+         n_threshold: int | None = None,
+         show: bool = False,
+         mutation_kind: str = "all",
+         export_plots_filename: str | None = None
+     ) -> tuple[pd.DataFrame, dict[str, dict[str, Any]]]:
+         """
+         Perform the global analysis of the data.
+ 
+         It computes the statistics and the regression models for the mean and variance of the data.
+ 
+         :param length: The length to filter by.
+         :type length: int
+         :param n_threshold: Minimum number of sequences required in a time interval to compute statistics.
+         :type n_threshold: int | None
+         :param show: Whether to show the plots. Default is ``False``.
+         :type show: bool
+         :param mutation_kind: The kind of mutation to compute the statistics for. Has to be one of ``all``, ``total``, ``substitutions`` or ``indels``. Default is ``all``.
+         :type mutation_kind: str
+         :param export_plots_filename: Filename to export the plots to. Default is ``None``, which does not export the plots.
+         :type export_plots_filename: str | None
+         :return: The statistics and the regression models.
+         :rtype: ``tuple[pd.DataFrame, dict[str, dict[str, Any]]]``
+         """
+ 
+         # Apply filters
+         self.n_filter()
+         self.length_filter(length=length)
+ 
+         # Compute the statistics for the specified mutation kinds
+         stats = self.compute_stats(
+             self.dt,
+             self.origin,
+             n_threshold,
+             mutation_kind
+         )
+ 
+         regs = {}
+         # For each column in the statistics (except the date and the size), compute the corresponding regression model
+         for col in stats.columns[1:-1]:
+             if col.startswith("mean"):
+                 _single_regression = {
+                     f"{col} per {self.dt} model": self.linear_regression(
+                         *self._remove_nan(
+                             stats.index, # The regression is over the index, i.e. over time in steps of dt
+                             stats[col]
+                         )
+                     )
+                 }
+             elif col.startswith("var"):
+                 # Shift the variance so the model is fit to data starting at 0
+                 _single_regression = self.adjust_model(
+                     stats.index,
+                     stats[col] - stats[col].min(),
+                     name=f"scaled {col} per {self.dt} model"
+                 )
+             # Save the regression model
+             regs.update(_single_regression)
+ 
+         # Sets of mutation types used in the analysis
+         _sets = sorted({
+             " ".join(x.split()[1:])
+             for x in stats.columns[1:-1]
+         })
+ 
+         # Plot the results
+         if show:
+             # For each set of mutation types
+             for _type in _sets:
+                 self.plot_results(
+                     stats[["date", f"mean {_type}", f"var {_type}"]],
+                     {
+                         k: v
+                         for k, v in regs.items()
+                         if k in (
+                             f"mean {_type} per {self.dt} model",
+                             f"scaled var {_type} per {self.dt} model"
+                         )
+                     },
+                     f"in steps of {self.dt} since {self.origin}"
+                 )
+ 
+         # Export the plots
+         if export_plots_filename:
+             # Open the pdf file pointer
+             pdf = PdfPages(f"{export_plots_filename}.pdf")
+             # For each set of mutation types, save the plots
+             for _type in _sets:
+                 self.export_plot_results(
+                     stats[["date", f"mean {_type}", f"var {_type}"]],
+                     {
+                         k: v
+                         for k, v in regs.items()
+                         if k in (
+                             f"mean {_type} per {self.dt} model",
+                             f"scaled var {_type} per {self.dt} model"
+                         )
+                     },
+                     f"in steps of {self.dt} since {self.origin}",
+                     pdf
+                 )
+             # Close the pdf file pointer
+             pdf.close()
+ 
+         return stats, regs
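An end-to-end sketch tying the pieces together, again with the hypothetical ``model`` from the construction example and made-up parameter values:

```python
stats, regs = model.analysis(
    length=29000,                     # keep sequences longer than 29000 nt
    n_threshold=10,                   # skip intervals with fewer than 10 sequences
    show=True,                        # display the mean/variance/dispersion plots
    mutation_kind="all",
    export_plots_filename="results",  # also writes results.pdf
)
for name, reg in regs.items():
    print(name, f"r2={reg['r2']:.2f}")
```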