geney 1.2.20-py2.py3-none-any.whl → 1.2.22-py2.py3-none-any.whl

This diff shows the content of publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registry.

Potentially problematic release: this version of geney might be problematic.

Files changed (39)
  1. geney/oncosplice.py +1 -1
  2. {geney-1.2.20.dist-info → geney-1.2.22.dist-info}/METADATA +1 -1
  3. geney-1.2.22.dist-info/RECORD +19 -0
  4. geney/Gene.py +0 -258
  5. geney/analyzers/__init__.py +0 -0
  6. geney/analyzers/benchmark_clinvar.py +0 -158
  7. geney/analyzers/characterize_epistasis.py +0 -15
  8. geney/analyzers/compare_sets.py +0 -91
  9. geney/analyzers/group_comparison.py +0 -81
  10. geney/analyzers/survival.py +0 -144
  11. geney/analyzers/tcga_annotations.py +0 -194
  12. geney/analyzers/visualize_protein_conservation.py +0 -398
  13. geney/benchmark_clinvar.py +0 -158
  14. geney/compare_sets.py +0 -91
  15. geney/data_parsers/__init__.py +0 -0
  16. geney/data_parsers/gtex.py +0 -68
  17. geney/gtex.py +0 -68
  18. geney/immunotherapy/__init__.py +0 -0
  19. geney/immunotherapy/netchop.py +0 -78
  20. geney/mutations/__init__.py +0 -0
  21. geney/mutations/variant_utils.py +0 -125
  22. geney/netchop.py +0 -79
  23. geney/oncosplice/__init__.py +0 -0
  24. geney/oncosplice_mouse.py +0 -277
  25. geney/oncosplice_pipeline.py +0 -1588
  26. geney/performance_utils.py +0 -138
  27. geney/pipelines/__init__.py +0 -0
  28. geney/pipelines/dask_utils.py +0 -153
  29. geney/splicing/__init__.py +0 -2
  30. geney/splicing/spliceai_utils.py +0 -253
  31. geney/splicing/splicing_isoform_utils.py +0 -0
  32. geney/splicing/splicing_utils.py +0 -366
  33. geney/survival.py +0 -124
  34. geney/tcga_annotations.py +0 -352
  35. geney/translation_termination/__init__.py +0 -0
  36. geney/translation_termination/tts_utils.py +0 -0
  37. geney-1.2.20.dist-info/RECORD +0 -52
  38. {geney-1.2.20.dist-info → geney-1.2.22.dist-info}/WHEEL +0 -0
  39. {geney-1.2.20.dist-info → geney-1.2.22.dist-info}/top_level.txt +0 -0
geney/splicing/splicing_utils.py DELETED
@@ -1,366 +0,0 @@
- from geney.utils import reverse_complement, find_files_by_gene_name, unload_json, dump_json, unload_pickle
- from geney.Fasta_segment import Fasta_segment
- from geney.mutations.variant_utils import generate_mut_variant
- from geney import config_setup
- import networkx as nx
- import random
- from dataclasses import dataclass
-
- '''
- SpliceAI util functions.
- '''
- import numpy as np
- import tensorflow as tf
- from keras.models import load_model
- from pkg_resources import resource_filename
- from spliceai.utils import one_hot_encode
-
- tf.config.threading.set_intra_op_parallelism_threads(1)
- tf.config.threading.set_inter_op_parallelism_threads(1)
-
- sai_paths = ('models/spliceai{}.h5'.format(x) for x in range(1, 6))
- sai_models = [load_model(resource_filename('spliceai', x)) for x in sai_paths]
-
- def sai_predict_probs(seq: str, models: list) -> list:
-     '''
-     Predicts the donor and acceptor junction probability of each
-     NT in seq using SpliceAI.
-
-     Let m := 2*sai_mrg_context + L be the input seq length. It is assumed
-     that the input seq has the following structure:
-
-         seq = |<sai_mrg_context NTs><L NTs><sai_mrg_context NTs>|
-
-     The returned probability matrix is of size 2xL, where
-     the first row is the acceptor probability and the second row
-     is the donor probability. These probabilities correspond to the
-     middle <L NTs> of the input seq.
-     '''
-     x = one_hot_encode(seq)[None, :]
-     y = np.mean([models[m].predict(x, verbose=0) for m in range(5)], axis=0)
-     return y[0, :, 1:].T
-
-
- def get_actual_sai_seq(seq: str, sai_mrg_context: int = 5000) -> str:
-     '''
-     This function assumes that the input seq has the following structure:
-
-         seq = |<sai_mrg_context NTs><L NTs><sai_mrg_context NTs>|
-
-     and returns the middle sequence: |<L NTs>|.
-     '''
-     return seq[sai_mrg_context:-sai_mrg_context]
-
-
- ############################################################################################
- ############################################################################################
- ############# BEGIN CUSTOM SAI USE CASES ###################################################
- ############################################################################################
- ############################################################################################
-
-
- def find_ss_changes(ref_dct, mut_dct, known_splice_sites, threshold=0.5):
-     '''
-     :param ref_dct: the SpliceAI probability of each nucleotide (keyed by genomic position) for the reference sequence
-     :param mut_dct: the SpliceAI probability of each nucleotide (keyed by genomic position) for the mutated sequence
-     :param known_splice_sites: the indices (by genomic position) that serve as known splice sites
-     :param threshold: the threshold for detection (difference between reference and mutated probabilities)
-     :return: two dictionaries; discovered_pos contains all the positions that meet the threshold for discovery,
-              and deleted_pos contains all the positions that meet the threshold and condition for missing
-     '''
-
-     new_dict = {v: mut_dct.get(v, 0) - ref_dct.get(v, 0) for v in
-                 list(set(list(ref_dct.keys()) + list(mut_dct.keys())))}
-
-     discovered_pos = {k: {'delta': round(float(v), 3), 'absolute': round(float(mut_dct[k]), 3)} for k, v in
-                       new_dict.items() if (k not in known_splice_sites and v >= threshold) or (v > 0.45)}
-
-     deleted_pos = {k: {'delta': round(float(v), 3), 'absolute': round(float(mut_dct.get(k, 0)), 3)} for k, v in
-                    new_dict.items() if k in known_splice_sites and v <= -threshold}
-
-     return discovered_pos, deleted_pos
-
-
- def run_spliceai(mutations, gene_data, sai_mrg_context=5000, min_coverage=2500, sai_threshold=0.5):
-     positions = mutations.positions  # [m.start for m in mutations]
-     seq_start_pos = min(positions) - sai_mrg_context - min_coverage
-     seq_end_pos = max(positions) + sai_mrg_context + min_coverage  # + 1
-
-     # ref_seq, ref_indices = pull_fasta_seq_endpoints(mutations.chrom, seq_start_pos, seq_end_pos)
-     fasta_obj = Fasta_segment()
-     ref_seq, ref_indices = fasta_obj.read_segment_endpoints(config_setup['CHROM_SOURCE'] / f'chr{mutations.chrom}.fasta',
-                                                             seq_start_pos,
-                                                             seq_end_pos)
-
-     # gene_data = unload_pickle(
-     #     find_files_by_gene_name(gene_name=mutations.gene))
-     gene_start, gene_end, rev = gene_data.gene_start, gene_data.gene_end, gene_data.rev
-
-     mrna_acceptors = sorted(list(set([lst for lsts in
-                                       [mrna.get('acceptors', []) for mrna in gene_data.transcripts.values() if
-                                        mrna['transcript_biotype'] == 'protein_coding'] for lst in lsts])))
-     mrna_donors = sorted(list(set([lst for lsts in
-                                    [mrna.get('donors', []) for mrna in gene_data.transcripts.values() if
-                                     mrna['transcript_biotype'] == 'protein_coding'] for lst in lsts])))
-
-     visible_donors = np.intersect1d(mrna_donors, ref_indices)
-     visible_acceptors = np.intersect1d(mrna_acceptors, ref_indices)
-
-     start_pad = ref_indices.index(gene_start) if gene_start in ref_indices else 0
-     end_cutoff = ref_indices.index(gene_end) if gene_end in ref_indices else len(ref_indices)  # - 1
-     end_pad = len(ref_indices) - end_cutoff
-     ref_seq = 'N' * start_pad + ref_seq[start_pad:end_cutoff] + 'N' * end_pad
-     ref_indices = [-1] * start_pad + ref_indices[start_pad:end_cutoff] + [-1] * end_pad
-     mut_seq, mut_indices = ref_seq, ref_indices
-
-     for mut in mutations:
-         mut_seq, mut_indices, _, _ = generate_mut_variant(seq=mut_seq, indices=mut_indices, mut=mut)
-
-     ref_indices = ref_indices[sai_mrg_context:-sai_mrg_context]
-     mut_indices = mut_indices[sai_mrg_context:-sai_mrg_context]
-
-     if rev:
-         ref_seq = reverse_complement(ref_seq)
-         mut_seq = reverse_complement(mut_seq)
-         ref_indices = ref_indices[::-1]
-         mut_indices = mut_indices[::-1]
-
-     ref_seq_probs_temp = sai_predict_probs(ref_seq, sai_models)
-     mut_seq_probs_temp = sai_predict_probs(mut_seq, sai_models)
-
-     ref_seq_acceptor_probs, ref_seq_donor_probs = ref_seq_probs_temp[0, :], ref_seq_probs_temp[1, :]
-     mut_seq_acceptor_probs, mut_seq_donor_probs = mut_seq_probs_temp[0, :], mut_seq_probs_temp[1, :]
-
-     assert len(ref_indices) == len(ref_seq_acceptor_probs), 'Reference pos not the same'
-     assert len(mut_indices) == len(mut_seq_acceptor_probs), 'Mut pos not the same'
-
-     iap, dap = find_ss_changes({p: v for p, v in list(zip(ref_indices, ref_seq_acceptor_probs))},
-                                {p: v for p, v in list(zip(mut_indices, mut_seq_acceptor_probs))},
-                                visible_acceptors,
-                                threshold=sai_threshold)
-
-     assert len(ref_indices) == len(ref_seq_donor_probs), 'Reference pos not the same'
-     assert len(mut_indices) == len(mut_seq_donor_probs), 'Mut pos not the same'
-
-     idp, ddp = find_ss_changes({p: v for p, v in list(zip(ref_indices, ref_seq_donor_probs))},
-                                {p: v for p, v in list(zip(mut_indices, mut_seq_donor_probs))},
-                                visible_donors,
-                                threshold=sai_threshold)
-
-     missplicing = {'missed_acceptors': dap, 'missed_donors': ddp, 'discovered_acceptors': iap, 'discovered_donors': idp}
-     missplicing = {outk: {float(k): v for k, v in outv.items()} for outk, outv in missplicing.items()}
-     return {outk: {int(k) if k.is_integer() else k: v for k, v in outv.items()} for outk, outv in missplicing.items()}
-
-
- class PredictSpliceAI:
-     def __init__(self, mutation, gene_data, threshold=0.5, force=False, sai_mrg_context=5000, min_coverage=2500):
-         self.modification = mutation
-         self.threshold = threshold
-
-         # if '|' in mutation.mut_id:
-         self.spliceai_db = config_setup['MISSPLICING_PATH'] / f'spliceai_epistatic'
-         # else:
-         #     self.spliceai_db = config_setup['MISSPLICING_PATH'] / f'spliceai_individual'
-
-         self.missplicing = {}
-
-         if self.prediction_file_exists() and not force:
-             self.missplicing = self.load_sai_predictions()
-         else:
-             self.missplicing = run_spliceai(self.modification, gene_data=gene_data, sai_mrg_context=sai_mrg_context, min_coverage=min_coverage, sai_threshold=0.1)
-             self.save_sai_predictions()
-
-     def __repr__(self):
-         return f'Missplicing({self.modification.mut_id}) --> {self.missplicing}'
-
-     def __str__(self):
-         return str(self.aberrant_splicing)
-
-     def __bool__(self):
-         for event, details in self.aberrant_splicing.items():
-             if details:
-                 return True
-         return False
-
-     def __eq__(self, alt_splicing):
-         flag, _ = check_splicing_difference(self.missplicing, alt_splicing, self.threshold)
-         return not flag
-
-     @property
-     def aberrant_splicing(self):
-         return self.apply_sai_threshold(self.missplicing, self.threshold)
-
-     @property
-     def prediction_file(self):
-         return self.spliceai_db / self.modification.gene / self.modification.file_identifier_json
-
-     def prediction_file_exists(self):
-         return self.prediction_file.exists()
-
-     def load_sai_predictions(self):
-         missplicing = unload_json(self.prediction_file)
-         missplicing = {outk: {float(k): v for k, v in outv.items()} for outk, outv in missplicing.items()}
-         missplicing = {outk: {int(k) if k.is_integer() or 'missed' in outk else k: v for k, v in outv.items()} for
-                        outk, outv in
-                        missplicing.items()}
-         return missplicing
-
-     def save_sai_predictions(self):
-         self.prediction_file.parent.mkdir(parents=True, exist_ok=True)
-         dump_json(self.prediction_file, self.missplicing)
-
-     def apply_sai_threshold(self, splicing_dict=None, threshold=None):
-         splicing_dict = self.missplicing if not splicing_dict else splicing_dict
-         threshold = self.threshold if not threshold else threshold
-         new_dict = {}
-         for event, details in splicing_dict.items():
-             for e, d in details.items():
-                 if abs(d['delta']) >= threshold:
-                     return splicing_dict
-             new_dict[event] = {}  # {k: v for k, v in details.items() if abs(v['delta']) >= threshold}
-         return new_dict
-
-     def get_max_missplicing_delta(self):
-         max_delta = 0
-         for event, details in self.missplicing.items():
-             for e, d in details.items():
-                 if abs(d['delta']) > max_delta:
-                     max_delta = abs(d['delta'])
-         return max_delta
-
-
- def check_splicing_difference(missplicing1, missplicing2, threshold=None):
-     flag = False
-     true_differences = {}
-     for event in ['missed_acceptors', 'missed_donors', 'discovered_acceptors', 'discovered_donors']:
-         td = {}
-         dct1 = missplicing1[event]
-         dct2 = missplicing2[event]
-         for k in list(set(list(dct1.keys()) + list(dct2.keys()))):
-             diff = abs(dct1.get(k, {'delta': 0})['delta']) - abs(dct2.get(k, {'delta': 0})['delta'])
-             if abs(diff) >= threshold:
-                 flag = True
-                 td[k] = diff
-         true_differences[event] = td
-     return flag, true_differences
-
-
- def develop_aberrant_splicing(transcript, aberrant_splicing):
-     boundaries = [lst for lsts in [[a, b] for a, b in transcript.exons] for lst in lsts]
-     exon_starts, exon_ends = list(zip(*transcript.exons))
-     transcript_start, transcript_end = exon_starts[0], exon_ends[-1]
-     next_exon_end = exon_ends[-2]
-     rev = transcript.rev
-     upper_range, lower_range = max(boundaries), min(boundaries)
-     exon_starts = {v: 1 for v in exon_starts}
-     exon_ends = {v: 1 for v in exon_ends}
-     for k, v in aberrant_splicing.get('missed_donors', {}).items():
-         if k in exon_ends.keys():
-             exon_ends[k] = max(v['absolute'], 0.001)
-     exon_ends.update(
-         {k: v['absolute'] for k, v in aberrant_splicing.get('discovered_donors', {}).items() if lower_range <= k <= upper_range})
-     for k, v in aberrant_splicing.get('missed_acceptors', {}).items():
-         if k in exon_starts.keys():
-             exon_starts[k] = max(v['absolute'], 0.001)
-     exon_starts.update(
-         {k: v['absolute'] for k, v in aberrant_splicing.get('discovered_acceptors', {}).items() if lower_range <= k <= upper_range})
-     nodes = [SpliceSite(pos=pos, ss_type=0, prob=prob) for pos, prob in exon_ends.items() if
-              lower_range <= pos <= upper_range] + \
-             [SpliceSite(pos=pos, ss_type=1, prob=prob) for pos, prob in exon_starts.items() if
-              lower_range <= pos <= upper_range]
-     nodes = [s for s in nodes if s.prob > 0]
-     nodes.sort(key=lambda x: x.pos, reverse=rev)
-     G = nx.DiGraph()
-     G.add_nodes_from([n.pos for n in nodes])
-     for i in range(len(nodes)):
-         trailing_prob, in_between = 0, []
-         for j in range(i + 1, len(nodes)):
-             curr_node, next_node = nodes[i], nodes[j]
-             spread = curr_node.ss_type in in_between
-             in_between.append(next_node.ss_type)
-             if curr_node.ss_type != next_node.ss_type:
-                 if spread:
-                     new_prob = next_node.prob - trailing_prob
-                     if new_prob <= 0:
-                         break
-                     G.add_edge(curr_node.pos, next_node.pos)
-                     G.edges[curr_node.pos, next_node.pos]['weight'] = new_prob
-                     trailing_prob += next_node.prob
-                 else:
-                     G.add_edge(curr_node.pos, next_node.pos)
-                     G.edges[curr_node.pos, next_node.pos]['weight'] = next_node.prob
-                     trailing_prob += next_node.prob
-
-     new_paths, prob_sum = {}, 0
-     for i, path in enumerate(nx.all_simple_paths(G, transcript_start, transcript_end)):
-         curr_prob = path_weight_mult(G, path, 'weight')
-         prob_sum += curr_prob
-         new_paths[i] = {'acceptors': sorted([p for p in path if p in exon_starts.keys() and p != transcript_start], reverse=rev),
-                         'donors': sorted([p for p in path if p in exon_ends.keys() and p != transcript_end], reverse=rev),
-                         'path_weight': curr_prob}
-         continuance = i + 1
-
-     if prob_sum < 0.1:
-         for j, path in enumerate(nx.all_simple_paths(G, transcript_start, next_exon_end)):
-             curr_prob = path_weight_mult(G, path, 'weight')
-             if curr_prob < 0.1:
-                 continue
-             prob_sum += curr_prob
-             new_paths[continuance + j] = {'acceptors': sorted([p for p in path if p in exon_starts.keys() and p != transcript_start],
-                                                               reverse=rev),
-                                           'donors': sorted([p for p in path if p in exon_ends.keys() and p != transcript_end],
-                                                            reverse=rev),
-                                           'path_weight': curr_prob}
-     for i, d in new_paths.items():
-         d['path_weight'] = round(d['path_weight'] / prob_sum, 3)
-     new_paths = {k: v for k, v in new_paths.items() if v['path_weight'] > 0.01}
-     return list(new_paths.values())
-
-
- def path_weight_mult(G, path, weight):
-     multigraph = G.is_multigraph()
-     cost = 1
-     if not nx.is_path(G, path):
-         raise nx.NetworkXNoPath("path does not exist")
-     for node, nbr in nx.utils.pairwise(path):
-         if multigraph:
-             cost *= min(v[weight] for v in G[node][nbr].values())
-         else:
-             cost *= G[node][nbr][weight]
-     return cost
-
-
- @dataclass
- class SpliceSite(object):
-     pos: int
-     ss_type: int
-     prob: float
-
-     def __post_init__(self):
-         pass
-
-     def __lt__(self, other):
-         return self.pos < other.pos
-
-
- def generate_random_as(transcript):
-     ma = random.sample(transcript.acceptors, 1)[0]
-     md = random.sample(transcript.donors, 1)[0]
-     da = random.sample(list(range(min(transcript.acceptors), max(transcript.acceptors))), 1)[0]
-     dd = random.sample(list(range(min(transcript.donors), max(transcript.donors))), 1)[0]
-     return {
-         'discovered_acceptors': {da: {'absolute': 0.9}},
-         'discovered_donors': {dd: {'absolute': 0.6}},
-         'missed_donors': {md: {'absolute': 0.2}},
-         'missed_acceptors': {ma: {'absolute': 0.1}},
-     }
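
For reference, find_ss_changes was the core comparison step of this deleted module: it diffs the per-position SpliceAI probabilities of a reference and a mutated sequence and reports gained and lost splice sites. A minimal sketch of that behavior, assuming the old wheel is installed (pip install geney==1.2.20, plus the tensorflow/keras/spliceai dependencies, since the module loads the five SpliceAI models at import time); the probabilities below are toy values, not real predictions:

    from geney.splicing.splicing_utils import find_ss_changes

    ref = {100: 0.02, 200: 0.95, 300: 0.01}  # reference acceptor probabilities, keyed by genomic position
    mut = {100: 0.80, 200: 0.10, 300: 0.02}  # the same positions after the mutation
    known_sites = [200]                      # annotated splice sites

    gained, lost = find_ss_changes(ref, mut, known_sites, threshold=0.5)
    # gained -> {100: {'delta': 0.78, 'absolute': 0.8}}   a novel site was discovered
    # lost   -> {200: {'delta': -0.85, 'absolute': 0.1}}  a known site was weakened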
geney/survival.py DELETED
@@ -1,124 +0,0 @@
- import pandas as pd
- import numpy as np
- import matplotlib.pyplot as plt
- from pathlib import Path
- from scipy.integrate import trapz
- from geney.utils import unload_pickle, unload_json, contains
- from lifelines.exceptions import ConvergenceError
- from lifelines import KaplanMeierFitter
- from lifelines.statistics import logrank_test
- from lifelines import CoxPHFitter
-
- pd.set_option('display.max_columns', None)
- pd.options.mode.chained_assignment = None
-
-
- def prepare_clinical_data(df=None):
-     if df is None:
-         CLINICAL_DATA_FILE = Path('/tamir2/yoramzar/Projects/Cancer_mut/Explore_data/reports/df_p_all.pkl')
-         df = unload_pickle(CLINICAL_DATA_FILE)
-
-     df.rename(columns={'patient_uuid': 'case_id'}, inplace=True)
-     cols = list(df.columns)
-     cols_days_to_followup = [col for col in cols if 'days_to_followup' in col] + [col for col in cols if 'days_to_last_followup' in col]
-     cols_days_to_know_alive = [col for col in cols if 'days_to_know_alive' in col] + [col for col in cols if 'days_to_last_known_alive' in col]
-     cols_days_to_death = [col for col in cols if 'days_to_death' in col]
-     cols_duration = cols_days_to_followup + cols_days_to_know_alive + cols_days_to_death
-     col_vital_status = 'days_to_death'
-     event_col_label = 'event'
-     duration_col_label = 'duration'
-     df.insert(1, event_col_label, df.apply(lambda x: int(not np.isnan(x[col_vital_status])), axis=1))
-     df.insert(1, duration_col_label, df.apply(lambda x: max([x[col] for col in cols_duration if not np.isnan(x[col])], default=-1), axis=1))
-     df[duration_col_label] /= 365
-     df = df.query(f"{duration_col_label}>=0.0")[['duration', 'event', 'case_id', 'chemotherapy', 'hormone_therapy', 'immunotherapy', 'targeted_molecular_therapy', 'Proj_name']]
-     # df.to_csv('/tamir2/nicolaslynn/data/tcga_metadata/tcga_clinical_data.csv')
-     return df
-
-
- class SurvivalAnalysis:
-     def __init__(self, clindf=None):
-         self.clindf = prepare_clinical_data(clindf)
-         self.treatment_features = ['chemotherapy', 'hormone_therapy', 'immunotherapy', 'targeted_molecular_therapy']
-         self.df = self.clindf.copy()
-         self.df['group'] = 0
-         self.df.fillna(0, inplace=True)
-
-     def generate_clinical_dataframe(self, target_cases, control_cases=None, inplace=False, features_of_interest=None):
-         features_of_interest = features_of_interest or []  # avoid a mutable default argument
-         df = self.df.copy()
-         df.loc[df[df.case_id.isin(target_cases)].index, 'group'] = 2
-         if control_cases is not None:
-             df.loc[df[df.case_id.isin(control_cases)].index, 'group'] = 1
-
-         df = df[df.group > 0]
-         df.group -= 1
-         core_features = ['duration', 'event']
-         df = df[core_features + features_of_interest]
-
-         for col in self.treatment_features:
-             if col not in df:
-                 continue
-             df.loc[df[col] > 0, col] = 1
-
-         df = df[core_features + [col for col in features_of_interest if
-                                  df[col].nunique() > 1]]  # and df[col].value_counts(normalize=True).min() >= 0.01]]
-         return df
-
-     def kaplan_meier_analysis(self, df, control_label='CV', target_label='Epistasis', feature='group', plot=False, time_cap=False):
-         # Can only be performed on features with two unique values
-         cap_time = df.groupby(feature).duration.max().min()
-         # df['duration'] = df['duration'].clip(upper=cap_time)
-         auc_vals = []
-         results = pd.Series()
-         count = 0
-         for val in [0, 1]:
-             g = df[df[feature] == val]
-             kmf = KaplanMeierFitter()
-             label = f"{control_label} ({len(g)} cases)" if val == 0 else f"{target_label} ({len(g)} cases)"
-             if val == 0:
-                 results[control_label] = len(g)
-             else:
-                 results[target_label] = len(g)
-
-             kmf.fit(g['duration'], g['event'], label=label)
-             surv_func = kmf.survival_function_
-             auc = trapz(surv_func[label], surv_func.index)
-             auc_vals.append(auc)
-             if plot:
-                 if count == 0:
-                     ax = kmf.plot()
-                 else:
-                     kmf.plot(ax=ax)
-             count += 1
-         p_value = self.log_rank(df[df[feature] == 1], df[df[feature] == 0])
-
-         if plot:
-             ax.text(0.5, 0.85, f'p-value: {p_value:.4f}', transform=ax.transAxes, fontsize=12,
-                     horizontalalignment='center')
-             plt.title('Kaplan-Meier Survival Curves')
-             plt.xlabel('Time')
-             plt.ylabel('Survival Probability')
-             if time_cap:
-                 plt.xlim([0, cap_time])
-             plt.show()
-
-         results['p_value'] = p_value
-         results['auc_target'] = auc_vals[-1]
-         if len(auc_vals) > 1:
-             results['auc_delta'] = auc_vals[-1] - auc_vals[0]
-             results['auc_control'] = auc_vals[0]
-
-         return results
-
-     def log_rank(self, group1, group2):
-         return logrank_test(group1['duration'], group2['duration'],
-                             event_observed_A=group1['event'],
-                             event_observed_B=group2['event']).p_value
-
-     def perform_cox_analysis(self, df, features_of_interest):
-         # Very simple... will return a series with p values for each feature
-         try:
-             return CoxPHFitter().fit(df[features_of_interest + ['duration', 'event']], 'duration', 'event').summary.p
-         except ConvergenceError:
-             print("Convergence Error")
-             return pd.Series()
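
For reference, the deleted survival module wrapped lifelines' Kaplan-Meier and Cox machinery behind SurvivalAnalysis. A minimal sketch of the intended flow, assuming geney==1.2.20 with lifelines and an older scipy installed (the module imports scipy.integrate.trapz, which was removed in scipy 1.14); the dataframe below is toy data shaped to the columns prepare_clinical_data reads, not a real cohort:

    import numpy as np
    import pandas as pd
    from geney.survival import SurvivalAnalysis

    clin = pd.DataFrame({
        'patient_uuid': [f'case_{i}' for i in range(6)],
        'days_to_death': [300.0, np.nan, 150.0, np.nan, 500.0, np.nan],   # NaN = censored
        'days_to_last_followup': [np.nan, 700.0, np.nan, 800.0, np.nan, 900.0],
        'chemotherapy': [1, 0, 1, 0, 1, 0],
        'hormone_therapy': [0, 0, 0, 0, 0, 0],
        'immunotherapy': [0, 1, 0, 1, 0, 1],
        'targeted_molecular_therapy': [0, 0, 0, 0, 0, 0],
        'Proj_name': ['TCGA-BRCA'] * 6,
    })

    sa = SurvivalAnalysis(clindf=clin)
    km_df = sa.generate_clinical_dataframe(
        target_cases=['case_0', 'case_2', 'case_4'],    # becomes group 1
        control_cases=['case_1', 'case_3', 'case_5'],   # becomes group 0
        features_of_interest=['group'],
    )
    print(sa.kaplan_meier_analysis(km_df))  # case counts, log-rank p-value, survival-curve AUCs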