spacr 0.0.35__py3-none-any.whl → 0.0.61__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
spacr/sequencing.py ADDED
@@ -0,0 +1,1130 @@
1
+ import os, re, time, math, subprocess
2
+ import numpy as np
3
+ import pandas as pd
4
+ import matplotlib.pyplot as plt
5
+ import seaborn as sns
6
+ from Bio import pairwise2
7
+ import statsmodels.api as sm
8
+ import statsmodels.formula.api as smf
9
+ from scipy.stats import gmean
10
+ from difflib import SequenceMatcher
11
+
12
+ def reverse_complement(dna_sequence):
13
+ complement_dict = {'A': 'T', 'T': 'A', 'C': 'G', 'G': 'C', 'N':'N'}
14
+ reverse_seq = dna_sequence[::-1]
15
+ reverse_complement_seq = ''.join([complement_dict[base] for base in reverse_seq])
16
+ return reverse_complement_seq
17
+
18
+ def complement(dna_sequence):
19
+ complement_dict = {'A': 'T', 'T': 'A', 'C': 'G', 'G': 'C', 'N':'N'}
20
+ complement_seq = ''.join([complement_dict[base] for base in dna_sequence])
21
+ return complement_seq
22
+
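+ # Example (illustrative, not called anywhere in this module):
+ # reverse_complement('ATGCN') -> 'NGCAT'
+ # complement('ATGCN') -> 'TACGN'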
23
+ def file_len(fname):
24
+ p = subprocess.Popen(['wc', '-l', fname], stdout=subprocess.PIPE, stderr=subprocess.PIPE)
25
+ result, err = p.communicate()
26
+ if p.returncode != 0:
27
+ raise IOError(err)
28
+ return int(result.strip().split()[0])
29
+
30
+ def generate_plate_heatmap(df, plate_number, variable, grouping, min_max):
31
+ if grouping == 'mean':
32
+ temp = df.groupby(['plate','row','col']).mean()[variable]
33
+ if grouping == 'sum':
34
+ temp = df.groupby(['plate','row','col']).sum()[variable]
35
+ if grouping == 'count':
36
+ temp = df.groupby(['plate','row','col']).count()[variable]
37
+ if grouping in ['mean', 'count', 'sum']:
38
+ temp = pd.DataFrame(temp)
39
+ if min_max == 'all':
40
+ min_max=[np.min(temp[variable]),np.max(temp[variable])]
41
+ if min_max == 'allq':
42
+ min_max = np.quantile(temp[variable], [0.2, 0.98])
43
+ plate = df[df['plate'] == plate_number]
44
+ plate = pd.DataFrame(plate)
45
+ if grouping == 'mean':
46
+ plate = plate.groupby(['plate','row','col']).mean()[variable]
47
+ if grouping == 'sum':
48
+ plate = plate.groupby(['plate','row','col']).sum()[variable]
49
+ if grouping == 'count':
50
+ plate = plate.groupby(['plate','row','col']).count()[variable]
51
+ if grouping not in ['mean', 'count', 'sum']:
52
+ plate = plate.groupby(['plate','row','col']).mean()[variable]
53
+ if min_max == 'plate':
54
+ min_max = [np.min(plate), np.max(plate)] # plate is a Series of grouped values at this point
55
+ plate = pd.DataFrame(plate)
56
+ plate = plate.reset_index()
57
+ if 'plate' in plate.columns:
58
+ plate = plate.drop(['plate'], axis=1)
59
+ pcol = [*range(1,28,1)]
60
+ prow = [*range(1,17,1)]
61
+ new_col = []
62
+ for v in pcol:
63
+ col = 'c'+str(v)
64
+ new_col.append(col)
65
+ new_col.remove('c15')
66
+ new_row = []
67
+ for v in prow:
68
+ ro = 'r'+str(v)
69
+ new_row.append(ro)
70
+ plate_map = pd.DataFrame(columns=new_col, index = new_row)
71
+ for index, row in plate.iterrows():
72
+ r = row['row']
73
+ c = row['col']
74
+ v = row[variable]
75
+ plate_map.loc[r,c]=v
76
+ plate_map = plate_map.fillna(0)
77
+ return pd.DataFrame(plate_map), min_max
78
+
79
+ def plot_plates(df, variable, grouping, min_max, cmap):
80
+ try:
81
+ plates = np.unique(df['plate'], return_counts=False)
82
+ except:
83
+ try:
84
+ df[['plate', 'row', 'col']] = df['prc'].str.split('_', expand=True)
85
+ df = pd.DataFrame(df)
86
+ plates = np.unique(df['plate'], return_counts=False)
87
+ except:
88
+ pass
89
+ #plates = np.unique(df['plate'], return_counts=False)
90
+ nr_of_plates = len(plates)
91
+ print('nr_of_plates:',nr_of_plates)
92
+ # Calculate the number of rows and columns for the subplot grid
93
+ if nr_of_plates in [1, 2, 3, 4]:
94
+ n_rows, n_cols = 1, 4
95
+ elif nr_of_plates in [5, 6, 7, 8]:
96
+ n_rows, n_cols = 2, 4
97
+ elif nr_of_plates in [9, 10, 11, 12]:
98
+ n_rows, n_cols = 3, 4
99
+ elif nr_of_plates in [13, 14, 15, 16]:
100
+ n_rows, n_cols = 4, 4
+ else:
+ n_rows, n_cols = math.ceil(nr_of_plates / 4), 4
101
+
102
+ # Create the subplot grid with the specified number of rows and columns
103
+ fig, ax = plt.subplots(n_rows, n_cols, figsize=(40, 5 * n_rows))
104
+
105
+ # Flatten the axes array to a one-dimensional array
106
+ ax = ax.flatten()
107
+
108
+ # Loop over each plate and plot the heatmap
109
+ for index, plate in enumerate(plates):
110
+ plate_number = plate
111
+ plate_map, min_max = generate_plate_heatmap(df=df, plate_number=plate_number, variable=variable, grouping=grouping, min_max=min_max)
112
+ if index == 0:
113
+ print('plate_number:',plate_number,'minimum:',min_max[0], 'maximum:',min_max[1])
114
+ # Plot the heatmap on the appropriate subplot
115
+ sns.heatmap(plate_map, cmap=cmap, vmin=min_max[0], vmax=min_max[1], ax=ax[index])
116
+ ax[index].set_title(plate_number)
117
+
118
+ # Remove any empty subplots
119
+ for i in range(nr_of_plates, n_rows * n_cols):
120
+ fig.delaxes(ax[i])
121
+
122
+ # Adjust the spacing between the subplots
123
+ plt.subplots_adjust(wspace=0.1, hspace=0.4)
124
+
125
+ # Show the plot
126
+ plt.show()
127
+ print()
128
+ return
129
+
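+ # Usage sketch: plot_plates expects a DataFrame with either 'plate'/'row'/'col'
+ # columns or a 'prc' column formatted as 'plate_row_col', plus the value column to plot,
+ # e.g. plot_plates(df, variable='count', grouping='mean', min_max='allq', cmap='viridis')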
130
+ def count_mismatches(seq1, seq2, align_length=10):
131
+ alignments = pairwise2.align.globalxx(seq1, seq2)
132
+ # choose the first alignment (there might be several with the same score)
133
+ alignment = alignments[0]
134
+ # alignment is a tuple (seq1_aligned, seq2_aligned, score, begin, end)
135
+ seq1_aligned, seq2_aligned, score, begin, end = alignment
136
+ # Determine the start of alignment (first position where at least align_length bases are the same)
137
+ start_of_alignment = next(i for i in range(len(seq1_aligned) - align_length + 1)
138
+ if seq1_aligned[i:i+align_length] == seq2_aligned[i:i+align_length])
139
+ # Trim the sequences to the same length from the start of the alignment
140
+ seq1_aligned = seq1_aligned[start_of_alignment:]
141
+ seq2_aligned = seq2_aligned[start_of_alignment:]
142
+ # Trim the sequences to be of the same length (from the end)
143
+ min_length = min(len(seq1_aligned), len(seq2_aligned))
144
+ seq1_aligned = seq1_aligned[:min_length]
145
+ seq2_aligned = seq2_aligned[:min_length]
146
+ mismatches = sum(c1 != c2 for c1, c2 in zip(seq1_aligned, seq2_aligned))
147
+ return mismatches
148
+
149
+
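+ # Note: count_mismatches returns a single integer (the number of differing
+ # positions after the pairwise2 global alignment); see the calls in
+ # get_sequence_data below.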
150
+ def get_sequence_data(r1,r2):
151
+ forward_regex = re.compile(r'^(...GGTGCCACTT)TTTCAAGTTG.*?TTCTAGCTCT(AAAAC[A-Z]{18,22}AACTT)GACATCCCCA.*?AAGGCAAACA(CCCCCTTCGG....).*')
152
+ r1fd = forward_regex.search(r1)
153
+ reverse_regex = re.compile(r'^(...CCGAAGGGGG)TGTTTGCCTT.*?TGGGGATGTC(AAGTT[A-Z]{18,22}GTTTT)AGAGCTAGAA.*?CAACTTGAAA(AAGTGGCACC...).*')
154
+ r2fd = reverse_regex.search(r2)
155
+ rc_r1 = reverse_complement(r1)
156
+ rc_r2 = reverse_complement(r2)
157
+ if all(var is not None for var in [r1fd, r2fd]):
158
+ try:
159
+ r1_mis_matches = count_mismatches(seq1=r1, seq2=rc_r2, align_length=5)
160
+ r2_mis_matches = count_mismatches(seq1=r2, seq2=rc_r1, align_length=5)
161
+ except:
162
+ r1_mis_matches = None
163
+ r2_mis_matches = None
164
+ column_r1 = reverse_complement(r1fd[1])
165
+ sgrna_r1 = r1fd[2]
166
+ platerow_r1 = r1fd[3]
167
+ column_r2 = r2fd[3]
168
+ sgrna_r2 = reverse_complement(r2fd[2])
169
+ platerow_r2 = reverse_complement(r2fd[1])+'N'
170
+
171
+ data_dict = {'r1_plate_row':platerow_r1,
172
+ 'r1_col':column_r1,
173
+ 'r1_gRNA':sgrna_r1,
174
+ 'r1_read':r1,
175
+ 'r2_plate_row':platerow_r2,
176
+ 'r2_col':column_r2,
177
+ 'r2_gRNA':sgrna_r2,
178
+ 'r2_read':r2,
179
+ 'r1_r2_rc_mismatch':r1_mis_matches,
180
+ 'r2_r1_rc_mismatch':r2_mis_matches,
181
+ 'r1_len':len(r1),
182
+ 'r2_len':len(r2)}
183
+ else:
184
+ try:
185
+ r1_mis_matches = count_mismatches(r1, rc_r2, align_length=5)
186
+ r2_mis_matches = count_mismatches(r2, rc_r1, align_length=5)
187
+ except:
188
+ r1_mis_matches = None
189
+ r2_mis_matches = None
190
+ data_dict = {'r1_plate_row':None,
191
+ 'r1_col':None,
192
+ 'r1_gRNA':None,
193
+ 'r1_read':r1,
194
+ 'r2_plate_row':None,
195
+ 'r2_col':None,
196
+ 'r2_gRNA':None,
197
+ 'r2_read':r2,
198
+ 'r1_r2_rc_mismatch':r1_mis_matches,
199
+ 'r2_r1_rc_mismatch':r2_mis_matches,
200
+ 'r1_len':len(r1),
201
+ 'r2_len':len(r2)}
202
+
203
+ return data_dict
204
+
205
+ def get_read_data(identifier, prefix):
206
+ if identifier.startswith("@"):
207
+ parts = identifier.split(" ")
208
+ # The first part contains the instrument, run number, flowcell ID, lane, tile, and coordinates
209
+ instrument, run_number, flowcell_id, lane, tile, x_pos, y_pos = parts[0][1:].split(":")
210
+ # The second part contains the read number, filter status, control number, and sample number
211
+ read, is_filtered, control_number, sample_number = parts[1].split(":")
212
+ rund_data_dict = {'instrument':instrument,
213
+ 'run_number':run_number,
214
+ 'flowcell_id':flowcell_id,
215
+ 'lane':lane,
216
+ 'tile':tile,
217
+ 'x_pos':x_pos,
218
+ 'y_pos':y_pos,
219
+ 'read':read,
220
+ 'is_filtered':is_filtered,
221
+ 'control_number':control_number,
222
+ 'sample_number':sample_number}
223
+ modified_dict = {prefix + key: value for key, value in rund_data_dict.items()}
224
+ return modified_dict
225
+
226
+ def pos_dict(string):
227
+ pos_dict = {}
228
+ for i, char in enumerate(string):
229
+ if char not in pos_dict:
230
+ pos_dict[char] = [i]
231
+ else:
232
+ pos_dict[char].append(i)
233
+ return pos_dict
234
+
235
+ def truncate_read(seq,qual,target):
236
+ index = seq.find(target)
237
+ end = len(seq)-(3+len(target))
238
+ if index != -1: # If the sequence is found
239
+ if index-3 >= 0:
240
+ seq = seq[index-3:]
241
+ qual = qual[index-3:]
242
+
243
+ return seq, qual
244
+
245
+ def equalize_lengths(seq1, seq2, pad_char='N'):
246
+ len_diff = len(seq1) - len(seq2)
247
+
248
+ if len_diff > 0: # seq1 is longer
249
+ seq2 += pad_char * len_diff # pad seq2 with 'N's
250
+ elif len_diff < 0: # seq2 is longer
251
+ seq1 += pad_char * (-len_diff) # pad seq1 with 'N's
252
+
253
+ return seq1, seq2
254
+
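+ # NOTE: get_read_data is redefined below with a reduced set of fields
+ # (instrument, x_pos, y_pos); the later definition is the one that takes effect.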
255
+ def get_read_data(identifier, prefix):
256
+ if identifier.startswith("@"):
257
+ parts = identifier.split(" ")
258
+ # The first part contains the instrument, run number, flowcell ID, lane, tile, and coordinates
259
+ instrument, run_number, flowcell_id, lane, tile, x_pos, y_pos = parts[0][1:].split(":")
260
+ # The second part contains the read number, filter status, control number, and sample number
261
+ read, is_filtered, control_number, sample_number = parts[1].split(":")
262
+ rund_data_dict = {'instrument':instrument,
263
+ 'x_pos':x_pos,
264
+ 'y_pos':y_pos}
265
+ modified_dict = {prefix + key: value for key, value in rund_data_dict.items()}
266
+ return modified_dict
267
+
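+ # Example (illustrative Illumina-style header, not real data):
+ # get_read_data('@M00001:12:000000000-ABCDE:1:1101:15589:1331 1:N:0:1', prefix='r1_')
+ # -> {'r1_instrument': 'M00001', 'r1_x_pos': '15589', 'r1_y_pos': '1331'}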
268
+ def extract_barecodes(r1_fastq, r2_fastq, csv_loc, chunk_size=100000):
269
+ data_chunk = []
270
+ # Open both FASTQ files.
271
+ with open(r1_fastq) as r1_file, open(r2_fastq) as r2_file:
272
+ index = 0
273
+ save_index = 0
274
+ while True:
275
+ index += 1
276
+ start = time.time()
277
+ # Read 4 lines at a time
278
+ r1_identifier = r1_file.readline().strip()
279
+ r1_sequence = r1_file.readline().strip()
280
+ r1_plus = r1_file.readline().strip()
281
+ r1_quality = r1_file.readline().strip()
282
+ r2_identifier = r2_file.readline().strip()
283
+ r2_sequence = r2_file.readline().strip()
284
+ r2_sequence = reverse_complement(r2_sequence)
286
+ r2_plus = r2_file.readline().strip()
287
+ r2_quality = r2_file.readline().strip()
289
+ if not r1_identifier or not r2_identifier:
290
+ break
291
+ #if index > 100:
292
+ # break
293
+ target = 'GGTGCCACTT'
294
+ r1_sequence, r1_quality = truncate_read(r1_sequence, r1_quality, target)
295
+ r2_sequence, r2_quality = truncate_read(r2_sequence, r2_quality, target)
296
+ r1_sequence, r2_sequence = equalize_lengths(r1_sequence, r2_sequence, pad_char='N')
297
+ r1_quality, r2_quality = equalize_lengths(r1_quality, r2_quality, pad_char='-')
298
+ alignments = pairwise2.align.globalxx(r1_sequence, r2_sequence)
299
+ alignment = alignments[0]
300
+ score = alignment[2]
301
+ column = None
302
+ platerow = None
303
+ grna = None
304
+ if score >= 125:
305
+ aligned_r1 = alignment[0]
306
+ aligned_r2 = alignment[1]
307
+ position_dict = {i+1: (base1, base2) for i, (base1, base2) in enumerate(zip(aligned_r1, aligned_r2))}
308
+ phred_quality1 = [ord(char) - 33 for char in r1_quality]
309
+ phred_quality2 = [ord(char) - 33 for char in r2_quality]
310
+ r1_q_dict = {i+1: quality for i, quality in enumerate(phred_quality1)}
311
+ r2_q_dict = {i+1: quality for i, quality in enumerate(phred_quality2)}
312
+ read = ''
313
+ for key in sorted(position_dict.keys()):
314
+ if position_dict[key][0] != '-' and (position_dict[key][1] == '-' or r1_q_dict.get(key, 0) >= r2_q_dict.get(key, 0)):
315
+ read = read + position_dict[key][0]
316
+ elif position_dict[key][1] != '-' and (position_dict[key][0] == '-' or r2_q_dict.get(key, 0) > r1_q_dict.get(key, 0)):
317
+ read = read + position_dict[key][1]
318
+ pattern = re.compile(r'^(...GGTGC)CACTT.*GCTCT(TAAAC[A-Z]{18,22}AACTT)GACAT.*CCCCC(TTCGG....).*')
319
+ regex_patterns = pattern.search(read)
320
+ if all(var is not None for var in [regex_patterns]):
321
+ column = regex_patterns[1]
322
+ grna = reverse_complement(regex_patterns[2])
323
+ platerow = reverse_complement(regex_patterns[3])
324
+ elif score < 125:
325
+ read = r1_sequence
326
+ pattern = re.compile(r'^(...GGTGC)CACTT.*GCTCT(TAAAC[A-Z]{18,22}AACTT)GACAT.*CCCCC(TTCGG....).*')
327
+ regex_patterns = pattern.search(read)
328
+ if all(var is not None for var in [regex_patterns]):
329
+ column = regex_patterns[1]
330
+ grna = reverse_complement(regex_patterns[2])
331
+ platerow = reverse_complement(regex_patterns[3])
332
+ #print('2', platerow)
333
+ data_dict = {'read':read,'column':column,'platerow':platerow,'grna':grna, 'score':score}
334
+ end = time.time()
335
+ if data_dict.get('grna') is not None:
336
+ save_index += 1
337
+ r1_rund_data_dict = get_read_data(r1_identifier, prefix='r1_')
338
+ r2_rund_data_dict = get_read_data(r2_identifier, prefix='r2_')
339
+ r1_rund_data_dict.update(r2_rund_data_dict)
340
+ r1_rund_data_dict.update(data_dict)
341
+ r1_rund_data_dict['r1_quality'] = r1_quality
342
+ r1_rund_data_dict['r2_quality'] = r2_quality
343
+ data_chunk.append(r1_rund_data_dict)
344
+ print(f'Processed reads: {index}; found barcodes in {save_index}; time/read: {end - start}', end='\r', flush=True)
345
+ if len(data_chunk) >= chunk_size: # Write once chunk_size barcoded reads have accumulated
346
+ if not os.path.isfile(csv_loc):
347
+ df = pd.DataFrame(data_chunk)
348
+ df.to_csv(csv_loc, index=False)
349
+ else:
350
+ df = pd.DataFrame(data_chunk)
351
+ df.to_csv(csv_loc, mode='a', header=False, index=False)
352
+ data_chunk = [] # Clear the chunk
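+ # Sketch of a final flush (same write logic as above) so that reads still held in
+ # data_chunk when the loop exits are not lost; this belongs after the while loop:
+ if data_chunk:
+ df = pd.DataFrame(data_chunk)
+ if not os.path.isfile(csv_loc):
+ df.to_csv(csv_loc, index=False)
+ else:
+ df.to_csv(csv_loc, mode='a', header=False, index=False)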
353
+
354
+ def split_fastq(input_fastq, output_base, num_files):
355
+ # Create file objects for each output file
356
+ outputs = [open(f"{output_base}_{i}.fastq", "w") for i in range(num_files)]
357
+ with open(input_fastq, "r") as f:
358
+ # Initialize a counter for the lines
359
+ line_counter = 0
360
+ for line in f:
361
+ # Determine the output file
362
+ output_file = outputs[line_counter // 4 % num_files]
363
+ # Write the line to the appropriate output file
364
+ output_file.write(line)
365
+ # Increment the line counter
366
+ line_counter += 1
367
+ # Close output files
368
+ for output in outputs:
369
+ output.close()
370
+
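+ # Usage sketch (hypothetical paths): split paired FASTQ files into 4 parts each,
+ # then extract barcodes from one pair of parts into a CSV.
+ # split_fastq('R1.fastq', 'R1_part', num_files=4)
+ # split_fastq('R2.fastq', 'R2_part', num_files=4)
+ # extract_barecodes('R1_part_0.fastq', 'R2_part_0.fastq', 'barcodes_0.csv', chunk_size=100000)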
371
+ def process_barecodes(df):
372
+ print('==== Preprocessing barcodes ====')
373
+ plate_ls = []
374
+ row_ls = []
375
+ column_ls = []
376
+ grna_ls = []
377
+ read_ls = []
378
+ score_ls = []
379
+ match_score_ls = []
380
+ index_ls = []
381
+ index = 0
382
+ print_every = 100
383
+ for i,row in df.iterrows():
384
+ index += 1
385
+ r1_instrument=row['r1_instrument']
386
+ r1_x_pos=row['r1_x_pos']
387
+ r1_y_pos=row['r1_y_pos']
388
+ r2_instrument=row['r2_instrument']
389
+ r2_x_pos=row['r2_x_pos']
390
+ r2_y_pos=row['r2_y_pos']
391
+ read=row['read']
392
+ column=row['column']
393
+ platerow=row['platerow']
394
+ grna=row['grna']
395
+ score=row['score']
396
+ r1_quality=row['r1_quality']
397
+ r2_quality=row['r2_quality']
398
+ if r1_x_pos == r2_x_pos:
399
+ if r1_y_pos == r2_y_pos:
400
+ match_score = 0
401
+
402
+ if grna.startswith('AAGTT'):
403
+ match_score += 0.5
404
+ if column.endswith('GGTGC'):
405
+ match_score += 0.5
406
+ if platerow.endswith('CCGAA'):
407
+ match_score += 0.5
408
+ index_ls.append(index)
409
+ match_score_ls.append(match_score)
410
+ score_ls.append(score)
411
+ read_ls.append(read)
412
+ plate_ls.append(platerow[:2])
413
+ row_ls.append(platerow[2:4])
414
+ column_ls.append(column[:3])
415
+ grna_ls.append(grna)
416
+ if index % print_every == 0:
417
+ print(f'Processed reads: {index}', end='\r', flush=True)
418
+ df = pd.DataFrame()
419
+ df['index'] = index_ls
420
+ df['score'] = score_ls
421
+ df['match_score'] = match_score_ls
422
+ df['plate'] = plate_ls
423
+ df['row'] = row_ls
424
+ df['col'] = column_ls
425
+ df['seq'] = grna_ls
426
+ df_high_score = df[df['score']>=125]
427
+ df_low_score = df[df['score']<125]
428
+ print(f'', flush=True)
429
+ print(f'Found {len(df_high_score)} high-score reads; found {len(df_low_score)} low-score reads')
430
+ return df, df_high_score, df_low_score
431
+
432
+ def find_grna(df, grna_df):
433
+ print('==== Finding gRNAs ====')
434
+ seqs = list(set(df.seq.tolist()))
435
+ seq_ls = []
436
+ grna_ls = []
437
+ index = 0
438
+ print_every = 1000
439
+ for grna in grna_df.Seq.tolist():
440
+ reverse_regex = re.compile(r'.*({}).*'.format(grna))
441
+ for seq in seqs:
442
+ index += 1
443
+ if index % print_every == 0:
444
+ print(f'Processed reads: {index}', end='\r', flush=True)
445
+ found_grna = reverse_regex.search(seq)
446
+ if found_grna is None:
447
+ seq_ls.append('error')
448
+ grna_ls.append('error')
449
+ else:
450
+ seq_ls.append(found_grna[0])
451
+ grna_ls.append(found_grna[1])
452
+ grna_dict = dict(zip(seq_ls, grna_ls))
453
+ df = df.assign(grna_seq=df['seq'].map(grna_dict).fillna('error'))
454
+ print(f'', flush=True)
455
+ return df
456
+
457
+ def map_unmapped_grnas(df):
458
+ print('==== Mapping lost gRNA barcodes ====')
459
+ def similar(a, b):
460
+ return SequenceMatcher(None, a, b).ratio()
461
+ index = 0
462
+ print_every = 100
463
+ sequence_list = df[df['grna_seq'] != 'error']['seq'].unique().tolist()
464
+ grna_error = df[df['grna_seq']=='error']
465
+ df = grna_error.copy()
466
+ similarity_dict = {}
467
+ # TODO: change this so that it iterates through each well
468
+ for idx, row in df.iterrows():
469
+ matches = 0
470
+ match_string = None
471
+ for string in sequence_list:
472
+ index += 1
473
+ if index % print_every == 0:
474
+ print(f'Processed reads: {index}', end='\r', flush=True)
475
+ ratio = similar(row['seq'], string)
476
+ # check if only one character is different
477
+ if ratio > ((len(row['seq']) - 1) / len(row['seq'])):
478
+ matches += 1
479
+ if matches > 1: # if we find more than one match, we break and don't add anything to the dictionary
480
+ break
481
+ match_string = string
482
+ if matches == 1: # only add to the dictionary if there was exactly one match
483
+ similarity_dict[row['seq']] = match_string
484
+ return similarity_dict
485
+
486
+ def translate_barecodes(df, grna_df, map_unmapped=False):
487
+ print('==== Translating barcodes ====')
488
+ if map_unmapped:
489
+ similarity_dict = map_unmapped_grnas(df)
490
+ df = df.assign(seq=df['seq'].map(similarity_dict).fillna('error'))
491
+ df = df.groupby(['plate','row', 'col'])['grna_seq'].value_counts().reset_index(name='count')
492
+ grna_dict = grna_df.set_index('Seq')['gene'].to_dict()
493
+
494
+ plate_barcodes = {'AA':'p1','TT':'p2','CC':'p3','GG':'p4','AT':'p5','TA':'p6','CG':'p7','GC':'p8'}
495
+
496
+ row_barcodes = {'AA':'r1','AT':'r2','AC':'r3','AG':'r4','TT':'r5','TA':'r6','TC':'r7','TG':'r8',
497
+ 'CC':'r9','CA':'r10','CT':'r11','CG':'r12','GG':'r13','GA':'r14','GT':'r15','GC':'r16'}
498
+
499
+ col_barcodes = {'AAA':'c1','TTT':'c2','CCC':'c3','GGG':'c4','AAT':'c5','AAC':'c6','AAG':'c7',
500
+ 'TTA':'c8','TTC':'c9','TTG':'c10','CCA':'c11','CCT':'c12','CCG':'c13','GGA':'c14',
501
+ 'CCT':'c15','GGC':'c16','ATT':'c17','ACC':'c18','AGG':'c19','TAA':'c20','TCC':'c21',
502
+ 'TGG':'c22','CAA':'c23','CGG':'c24'}
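+ # NOTE: 'CCT' appears twice in col_barcodes (mapped to 'c12' and then to 'c15');
+ # Python keeps only the last value, so no barcode currently resolves to 'c12'.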
503
+
504
+
505
+ df['plate'] = df['plate'].map(plate_barcodes)
506
+ df['row'] = df['row'].map(row_barcodes)
507
+ df['col'] = df['col'].map(col_barcodes)
508
+ df['grna'] = df['grna_seq'].map(grna_dict)
509
+ df['gene'] = df['grna'].str.split('_').str[1]
510
+ df = df.fillna('error')
511
+ df['prc'] = df['plate']+'_'+df['row']+'_'+df['col']
512
+ df = df[df['count']>=2]
513
+ error_count = df[df.apply(lambda row: row.astype(str).str.contains('error').any(), axis=1)].shape[0]
514
+ plate_error = df['plate'].str.contains('error').sum()/len(df)
515
+ row_error = df['row'].str.contains('error').sum()/len(df)
516
+ col_error = df['col'].str.contains('error').sum()/len(df)
517
+ grna_error = df['grna'].str.contains('error').sum()/len(df)
518
+ print(f'Matched: {len(df)} rows; Errors: plate:{plate_error*100:.3f}% row:{row_error*100:.3f}% column:{col_error*100:.3f}% gRNA:{grna_error*100:.3f}%')
519
+ return df
520
+
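+ # Pipeline sketch (hypothetical file names; grna_df needs 'Seq' and 'gene' columns):
+ # raw = pd.read_csv('barcodes_0.csv')
+ # df, df_high, df_low = process_barecodes(raw)
+ # df_high = find_grna(df_high, grna_df)
+ # counts = translate_barecodes(df_high, grna_df, map_unmapped=False)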
521
+ def vert_horiz(v, h, n_col):
522
+ h = h+1
523
+ if h not in [*range(0,n_col)]:
524
+ v = v+1
525
+ h = 0
526
+ return v,h
527
+
528
+ def plot_data(df, v, h, color, n_col, ax, x_axis, y_axis, fontsize=12, lw=2, ls='-', log_x=False, log_y=False, title=None):
529
+ ax[v, h].plot(df[x_axis], df[y_axis], ls=ls, lw=lw, color=color, label=y_axis)
530
+ ax[v, h].set_title(None)
531
+ ax[v, h].set_xlabel(None)
532
+ ax[v, h].set_ylabel(None)
533
+ ax[v, h].legend(fontsize=fontsize)
534
+
535
+ if log_x:
536
+ ax[v, h].set_xscale('log')
537
+ if log_y:
538
+ ax[v, h].set_yscale('log')
539
+ v,h =vert_horiz(v, h, n_col)
540
+ return v, h
541
+
542
+ def test_error(df, min_=25,max_=3025, metric='count',log_x=False, log_y=False):
543
+ max_ = max_+min_
544
+ step = math.sqrt(min_)
545
+ plate_error_ls = []
546
+ col_error_ls = []
547
+ row_error_ls = []
548
+ grna_error_ls = []
549
+ prc_error_ls = []
550
+ total_error_ls = []
551
+ temp_len_ls = []
552
+ val_ls = []
553
+ df['sum_count'] = df.groupby('prc')['count'].transform('sum')
554
+ df['fraction'] = df['count'] / df['sum_count']
555
+ if metric=='fraction':
556
+ range_ = np.arange(min_, max_, step).tolist()
557
+ if metric=='count':
558
+ range_ = [*range(int(min_),int(max_),int(step))]
559
+ for val in range_:
560
+ temp = pd.DataFrame(df[df[metric]>val])
561
+ temp_len = len(temp)
562
+ if temp_len == 0:
563
+ break
564
+ temp_len_ls.append(temp_len)
565
+ error_count = temp[temp.apply(lambda row: row.astype(str).str.contains('error').any(), axis=1)].shape[0]/len(temp)
566
+ plate_error = temp['plate'].str.contains('error').sum()/temp_len
567
+ row_error = temp['row'].str.contains('error').sum()/temp_len
568
+ col_error = temp['col'].str.contains('error').sum()/temp_len
569
+ prc_error = temp['prc'].str.contains('error').sum()/temp_len
570
+ grna_error = temp['gene'].str.contains('error').sum()/temp_len
571
+ #print(error_count, plate_error, row_error, col_error, prc_error, grna_error)
572
+ val_ls.append(val)
573
+ total_error_ls.append(error_count)
574
+ plate_error_ls.append(plate_error)
575
+ row_error_ls.append(row_error)
576
+ col_error_ls.append(col_error)
577
+ prc_error_ls.append(prc_error)
578
+ grna_error_ls.append(grna_error)
579
+ df2 = pd.DataFrame()
580
+ df2['val'] = val_ls
581
+ df2['plate'] = plate_error_ls
582
+ df2['row'] = row_error_ls
583
+ df2['col'] = col_error_ls
584
+ df2['gRNA'] = grna_error_ls
585
+ df2['prc'] = prc_error_ls
586
+ df2['total'] = total_error_ls
587
+ df2['len'] = temp_len_ls
588
+
589
+ n_row, n_col = 2, 7
590
+ v, h, lw, ls, color = 0, 0, 1, '-', 'teal'
591
+ fig, ax = plt.subplots(n_row, n_col, figsize=(n_col*5, n_row*5))
592
+
593
+ v, h = plot_data(df=df2, v=v, h=h, color=color, n_col=n_col, ax=ax, x_axis='val', y_axis='total',log_x=log_x, log_y=log_y)
594
+ v, h = plot_data(df=df2, v=v, h=h, color=color, n_col=n_col, ax=ax, x_axis='val', y_axis='prc',log_x=log_x, log_y=log_y)
595
+ v, h = plot_data(df=df2, v=v, h=h, color=color, n_col=n_col, ax=ax, x_axis='val', y_axis='plate',log_x=log_x, log_y=log_y)
596
+ v, h = plot_data(df=df2, v=v, h=h, color=color, n_col=n_col, ax=ax, x_axis='val', y_axis='row',log_x=log_x, log_y=log_y)
597
+ v, h = plot_data(df=df2, v=v, h=h, color=color, n_col=n_col, ax=ax, x_axis='val', y_axis='col',log_x=log_x, log_y=log_y)
598
+ v, h = plot_data(df=df2, v=v, h=h, color=color, n_col=n_col, ax=ax, x_axis='val', y_axis='gRNA',log_x=log_x, log_y=log_y)
599
+ v, h = plot_data(df=df2, v=v, h=h, color=color, n_col=n_col, ax=ax, x_axis='val', y_axis='len',log_x=log_x, log_y=log_y)
600
+
601
+ def generate_fraction_map(df, gene_column, min_=10, plates=['p1','p2','p3','p4'], metric = 'count', plot=False):
602
+ df['prcs'] = df['prc']+''+df['grna_seq']
603
+ df['gene'] = df['grna'].str.split('_').str[1]
604
+ if metric == 'count':
605
+ df = pd.DataFrame(df[df['count']>min_])
606
+ df = df[~(df == 'error').any(axis=1)]
607
+ df = df[df['plate'].isin(plates)]
608
+ gRNA_well_count = df.groupby('prc')['prcs'].transform('nunique')
609
+ df['gRNA_well_count'] = gRNA_well_count
610
+ df = df[df['gRNA_well_count']>=2]
611
+ df = df[df['gRNA_well_count']<=100]
612
+ well_sum = df.groupby('prc')['count'].transform('sum')
613
+ df['well_sum'] = well_sum
614
+ df['gRNA_fraction'] = df['count']/df['well_sum']
615
+ if metric == 'fraction':
616
+ df = pd.DataFrame(df[df['gRNA_fraction']>=min_])
617
+ df = df[df['plate'].isin(plates)]
618
+ gRNA_well_count = df.groupby('prc')['prcs'].transform('nunique')
619
+ df['gRNA_well_count'] = gRNA_well_count
620
+ well_sum = df.groupby('prc')['count'].transform('sum')
621
+ df['well_sum'] = well_sum
622
+ df['gRNA_fraction'] = df['count']/df['well_sum']
623
+ if plot:
624
+ print('gRNAs/well')
625
+ plot_plates(df=df, variable='gRNA_well_count', grouping='mean', min_max='allq', cmap='viridis')
626
+ print('well read sum')
627
+ plot_plates(df=df, variable='well_sum', grouping='mean', min_max='allq', cmap='viridis')
628
+ genes = df[gene_column].unique().tolist()
629
+ wells = df['prc'].unique().tolist()
630
+ print('number of genes:', len(genes), 'number of wells:', len(wells))
631
+ independent_variables = pd.DataFrame(columns=genes, index = wells)
632
+ for index, row in df.iterrows():
633
+ prc = row['prc']
634
+ gene = row[gene_column]
635
+ fraction = row['gRNA_fraction']
636
+ independent_variables.loc[prc,gene]=fraction
637
+ independent_variables = independent_variables.fillna(0.0)
638
+ independent_variables['sum'] = independent_variables.sum(axis=1)
639
+ independent_variables = independent_variables[independent_variables['sum']==1.0]
640
+ independent_variables = independent_variables.drop('sum', axis=1)
641
+ independent_variables.index.name = 'prc'
642
+ independent_variables = independent_variables.loc[:, (independent_variables.sum() != 0)]
643
+ return independent_variables
644
+
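+ # Sketch: build the regression design matrix from the translated barcode counts and
+ # save it as the independent-variable CSV read by reg_model/mixed_model (path is hypothetical):
+ # iv = generate_fraction_map(counts, gene_column='gene', min_=10)
+ # iv.to_csv('iv.csv')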
645
+ # Check if filename or path
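+ # Filenames are expected to look like plate<N>_<Well>_<Field>_<Object>.<ext>,
+ # e.g. 'plate1_A01_3_2.tif' (the example name is illustrative).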
646
+ def split_filenames(df, filename_column):
647
+ plate_ls = []
648
+ well_ls = []
649
+ col_ls = []
650
+ row_ls = []
651
+ field_ls = []
652
+ obj_ls = []
653
+ ls = df[filename_column].tolist()
654
+ if '/' in ls[0]:
655
+ file_list = [os.path.basename(path) for path in ls]
656
+ else:
657
+ file_list = ls
658
+ print('first file',file_list[0])
659
+ for filename in file_list:
660
+ plate = filename.split('_')[0]
661
+ plate = plate.split('plate')[1]
662
+ well = filename.split('_')[1]
663
+ field = filename.split('_')[2]
664
+ object_nr = filename.split('_')[3]
665
+ object_nr = object_nr.split('.')[0]
666
+ object_nr = 'o'+str(object_nr)
667
+ if re.match('A..', well):
668
+ row = 'r1'
669
+ if re.match('B..', well):
670
+ row = 'r2'
671
+ if re.match('C..', well):
672
+ row = 'r3'
673
+ if re.match('D..', well):
674
+ row = 'r4'
675
+ if re.match('E..', well):
676
+ row = 'r5'
677
+ if re.match('F..', well):
678
+ row = 'r6'
679
+ if re.match('G..', well):
680
+ row = 'r7'
681
+ if re.match('H..', well):
682
+ row = 'r8'
683
+ if re.match('I..', well):
684
+ row = 'r9'
685
+ if re.match('J..', well):
686
+ row = 'r10'
687
+ if re.match('K..', well):
688
+ row = 'r11'
689
+ if re.match('L..', well):
690
+ row = 'r12'
691
+ if re.match('M..', well):
692
+ row = 'r13'
693
+ if re.match('N..', well):
694
+ row = 'r14'
695
+ if re.match('O..', well):
696
+ row = 'r15'
697
+ if re.match('P..', well):
698
+ row = 'r16'
699
+ if re.match('.01', well):
700
+ col = 'c1'
701
+ if re.match('.02', well):
702
+ col = 'c2'
703
+ if re.match('.03', well):
704
+ col = 'c3'
705
+ if re.match('.04', well):
706
+ col = 'c4'
707
+ if re.match('.05', well):
708
+ col = 'c5'
709
+ if re.match('.06', well):
710
+ col = 'c6'
711
+ if re.match('.07', well):
712
+ col = 'c7'
713
+ if re.match('.08', well):
714
+ col = 'c8'
715
+ if re.match('.09', well):
716
+ col = 'c9'
717
+ if re.match('.10', well):
718
+ col = 'c10'
719
+ if re.match('.11', well):
720
+ col = 'c11'
721
+ if re.match('.12', well):
722
+ col = 'c12'
723
+ if re.match('.13', well):
724
+ col = 'c13'
725
+ if re.match('.14', well):
726
+ col = 'c14'
727
+ if re.match('.15', well):
728
+ col = 'c15'
729
+ if re.match('.16', well):
730
+ col = 'c16'
731
+ if re.match('.17', well):
732
+ col = 'c17'
733
+ if re.match('.18', well):
734
+ col = 'c18'
735
+ if re.match('.19', well):
736
+ col = 'c19'
737
+ if re.match('.20', well):
738
+ col = 'c20'
739
+ if re.match('.21', well):
740
+ col = 'c21'
741
+ if re.match('.22', well):
742
+ col = 'c22'
743
+ if re.match('.23', well):
744
+ col = 'c23'
745
+ if re.match('.24', well):
746
+ col = 'c24'
747
+ plate_ls.append(plate)
748
+ well_ls.append(well)
749
+ field_ls.append(field)
750
+ obj_ls.append(object_nr)
751
+ row_ls.append(row)
752
+ col_ls.append(col)
753
+ df['file'] = ls
754
+ df['plate'] = plate_ls
755
+ df['well'] = well_ls
756
+ df['row'] = row_ls
757
+ df['col'] = col_ls
758
+ df['field'] = field_ls
759
+ df['obj'] = obj_ls
760
+ df['plate_well'] = df['plate']+'_'+df['well']
761
+ df = df.set_index(filename_column)
762
+ return df
763
+
764
+ def rename_plate_metadata(df):
765
+ df = df.drop(columns=['plateID', 'rowID', 'columnID', 'plate_row_col', 'Unnamed: 0', 'Unnamed: 0.1'], errors='ignore')
774
+
775
+ df['plate'] = df['plate'].astype('string')
776
+ df.plate.replace('1', 'A', inplace=True)
777
+ df.plate.replace('2', 'B', inplace=True)
778
+ df.plate.replace('3', 'C', inplace=True)
779
+ df.plate.replace('4', 'D', inplace=True)
780
+ df.plate.replace('5', 'E', inplace=True)
781
+ df.plate.replace('6', 'F', inplace=True)
782
+ df.plate.replace('7', 'G', inplace=True)
783
+ df.plate.replace('8', 'H', inplace=True)
784
+ df.plate.replace('9', 'I', inplace=True)
785
+ df.plate.replace('10', 'J', inplace=True)
786
+
787
+ df.plate.replace('A', 'p1', inplace=True)# 1 - 1
788
+ df.plate.replace('B', 'p2', inplace=True)# 2 - 2
789
+ df.plate.replace('C', 'p3', inplace=True)# 3 - 3
790
+ df.plate.replace('E', 'p4', inplace=True)# 5 - 4
791
+
792
+ df.plate.replace('F', 'p5', inplace=True)# 6 - 5
793
+ df.plate.replace('G', 'p6', inplace=True)# 7 - 6
794
+ df.plate.replace('H', 'p7', inplace=True)# 8 - 7
795
+ df.plate.replace('I', 'p8', inplace=True)# 9 - 8
796
+
797
+ df['plateID'] = df['plate']
798
+
799
+ df.loc[(df['plateID'].isin(['D'])) & (df['col'].isin(['c1', 'c2', 'c3'])), 'plate'] = 'p1'
800
+ df.loc[(df['plateID'].isin(['D'])) & (df['col'].isin(['c4', 'c5', 'c6'])), 'plate'] = 'p2'
801
+ df.loc[(df['plateID'].isin(['D'])) & (df['col'].isin(['c7', 'c8', 'c9'])), 'plate'] = 'p3'
802
+ df.loc[(df['plateID'].isin(['D'])) & (df['col'].isin(['c10', 'c11', 'c12'])), 'plate'] = 'p4'
803
+
804
+ df.loc[(df['plateID'].isin(['J'])) & (df['col'].isin(['c1', 'c2', 'c3'])), 'plate'] = 'p5'
805
+ df.loc[(df['plateID'].isin(['J'])) & (df['col'].isin(['c4', 'c5', 'c6'])), 'plate'] = 'p6'
806
+ df.loc[(df['plateID'].isin(['J'])) & (df['col'].isin(['c7', 'c8', 'c9'])), 'plate'] = 'p7'
807
+ df.loc[(df['plateID'].isin(['J'])) & (df['col'].isin(['c10', 'c11', 'c12'])), 'plate'] = 'p8'
808
+
809
+ df.loc[(df['plateID'].isin(['D', 'J'])) & (df['col'].isin(['c1', 'c4', 'c7', 'c10'])), 'col'] = 'c1'
810
+ df.loc[(df['plateID'].isin(['D', 'J'])) & (df['col'].isin(['c2', 'c5', 'c8', 'c11'])), 'col'] = 'c2'
811
+ df.loc[(df['plateID'].isin(['D', 'J'])) & (df['col'].isin(['c3', 'c6', 'c9', 'c12'])), 'col'] = 'c3'
812
+
813
+ df.loc[(~df['plateID'].isin(['D', 'J'])) & (df['col'].isin(['c1'])), 'col'] = 'c25'
814
+ df.loc[(~df['plateID'].isin(['D', 'J'])) & (df['col'].isin(['c2'])), 'col'] = 'c26'
815
+ df.loc[(~df['plateID'].isin(['D', 'J'])) & (df['col'].isin(['c3'])), 'col'] = 'c27'
816
+
818
+
819
+ df = df.drop(['plateID'], axis=1)
820
+
821
+ df = df.loc[~df['plate'].isin(['D', 'J'])]
822
+
823
+ screen_cols = ['c1','c2','c3','c4','c5','c6','c7','c8','c9','c10','c11','c12','c13','c14','c15','c16','c17','c18','c19','c20','c21','c22','c23','c24']
824
+ screen_plates = ['p1','p2','p3','p4']
825
+ positive_control_plates = ['p5','p6','p7','p8']
826
+ positive_control_cols = screen_cols
827
+ negative_control_cols = ['c25','c26','c27']
828
+ #extra_plates = ['p9','p10']
829
+ cond_ls = []
830
+
831
+ cols = df.col.tolist()
832
+ for index, plate in enumerate(df.plate.tolist()):
833
+ co = cols[index]
834
+ if plate in screen_plates:
835
+ if co in screen_cols:
836
+ cond = 'SCREEN'
837
+ if co in negative_control_cols:
838
+ cond = 'NC'
839
+ if plate in positive_control_plates:
840
+ if co in positive_control_cols:
841
+ cond = 'PC'
842
+ if co in negative_control_cols:
843
+ cond = 'NC'
844
+ cond_ls.append(cond)
845
+
846
+ df['cond'] = cond_ls
847
+ df['plate'] = df['plate'].astype('string')
848
+ df['row'] = df['row'].astype('string')
849
+ df['col'] = df['col'].astype('string')
850
+ df['obj'] = df['obj'].astype('string')
851
+ df['prco'] = df['plate']+'_'+df['row']+'_'+df['col']+'_'+df['field']+'_'+df['obj']
852
+ df['prc'] = df['plate']+'_'+df['row']+'_'+df['col']
853
+ df = df.set_index(['prco'], drop=True)
854
+ df = df.sort_values(by = ['plate'], ascending = [True], na_position = 'first')
855
+ values, counts = np.unique(df['plate'], return_counts=True)
856
+ print('plates:', values)
857
+ print('well count:', counts)
858
+ return df
859
+
860
+ def plot_reg_res(df, coef_col='coef', col_p='P>|t|'):
861
+ df['gene'] = df.index
862
+ df[coef_col] = pd.to_numeric(df[coef_col], errors='coerce')
863
+ df[col_p] = pd.to_numeric(df[col_p], errors='coerce')
864
+ df = df.sort_values(by = [coef_col], ascending = [False], na_position = 'first')
865
+ df['color'] = 'None'
866
+ df.loc[df['gene'].str.contains('239740'), 'color'] = '239740'
867
+ df.loc[df['gene'].str.contains('205250'), 'color'] = '205250'
868
+
869
+ df.loc[df['gene'].str.contains('000000'), 'color'] = '000000'
870
+ df.loc[df['gene'].str.contains('000001'), 'color'] = '000000'
871
+ df.loc[df['gene'].str.contains('000002'), 'color'] = '000000'
872
+ df.loc[df['gene'].str.contains('000003'), 'color'] = '000000'
873
+ df.loc[df['gene'].str.contains('000004'), 'color'] = '000000'
874
+ df.loc[df['gene'].str.contains('000005'), 'color'] = '000000'
875
+ df.loc[df['gene'].str.contains('000006'), 'color'] = '000000'
876
+ df.loc[df['gene'].str.contains('000007'), 'color'] = '000000'
877
+ df.loc[df['gene'].str.contains('000008'), 'color'] = '000000'
878
+ df.loc[df['gene'].str.contains('000009'), 'color'] = '000000'
879
+ df.loc[df['gene'].str.contains('000010'), 'color'] = '000000'
880
+ fig, ax = plt.subplots(figsize=(10,10))
881
+ df.loc[df[col_p] == 0.000, col_p] = 0.001
882
+ df['logp'] = -np.log10(df[col_p])
883
+ sns.scatterplot(data = df, x = coef_col, y = 'logp', legend = False, ax = ax,
884
+ hue= 'color', hue_order = ['239740','205250','None', '000000'],
885
+ palette = ['purple', 'teal', 'lightgrey', 'black'],
886
+ size = 'color', sizes = (100, 10))
887
+ g14 = df[df['gene'].str.contains('239740')]
888
+ r18 = df[df['gene'].str.contains('205250')]
889
+ res = pd.concat([g14, r18], axis=0)
890
+ res = res[[coef_col, col_p]]
891
+ print(res)
892
+ return df, res
893
+
894
+ def reg_model(iv_loc,dv_loc):
895
+ independent_variables = pd.read_csv(iv_loc)
896
+ dependent_variable = pd.read_csv(dv_loc)
897
+ independent_variables = independent_variables.set_index('prc')
898
+ columns = independent_variables.columns
899
+ new_columns = [col.replace('TGGT1_', '') for col in columns]
900
+ independent_variables.columns = new_columns
901
+
902
+ dependent_variable = dependent_variable.set_index('prc')
903
+
904
+ reg_input = pd.DataFrame(pd.merge(independent_variables, dependent_variable, left_index=True, right_index=True))
905
+ reg_input = reg_input.dropna(axis=0, how='any')
906
+ reg_input = reg_input.dropna(axis=1, how='any')
907
+ print('Number of wells',len(reg_input))
908
+ x = reg_input.drop(['score'], axis=1)
909
+ x = sm.add_constant(x)
910
+ y = np.log10(reg_input['score']+1)
911
+ model = sm.OLS(y, x).fit()
912
+ predictions = model.predict(x)
913
+ results_summary = model.summary()
914
+ print(results_summary)
915
+ results_as_html = results_summary.tables[1].as_html()
916
+ results_df = pd.read_html(results_as_html, header=0, index_col=0)[0]
917
+ df, res = plot_reg_res(df=results_df)
918
+ return df, res
919
+
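+ # Usage sketch (hypothetical CSV paths produced by generate_fraction_map and
+ # preprocess_image_data): results_df, hits = reg_model('iv.csv', 'dv.csv')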
920
+ def mixed_model(iv_loc,dv_loc):
921
+ independent_variables = pd.read_csv(iv_loc)
922
+ dependent_variable = pd.read_csv(dv_loc)
923
+ independent_variables = independent_variables.set_index('prc')
924
+ columns = independent_variables.columns
925
+ new_columns = [col.replace('TGGT1_', '') for col in columns]
926
+ independent_variables.columns = new_columns
927
+ dependent_variable = dependent_variable.set_index('prc')
928
+ reg_input = pd.DataFrame(pd.merge(independent_variables, dependent_variable, left_index=True, right_index=True))
929
+ reg_input = reg_input.dropna(axis=0, how='any')
930
+
931
+ y = np.log10(reg_input['score']+1)
932
+ X = reg_input.drop('score', axis=1)
933
+ X.columns = pd.MultiIndex.from_tuples([tuple(col.split('_')) for col in X.columns],
934
+ names=['main_variable', 'sub_variable'])
935
+ # Melt the DataFrame
936
+ X_long = X.melt(ignore_index=False, var_name=['main_variable', 'sub_variable'], value_name='value')
937
+ X_long = X_long[X_long['value']>0]
938
+
939
+ # Create a new column to represent the nested structure of gRNA within gene
940
+ X_long['gene_gRNA'] = X_long['main_variable'].astype(str) + "_" + X_long['sub_variable'].astype(str)
941
+
942
+ # Add 'score' to the DataFrame
943
+ X_long['score'] = y
944
+
945
+ # Create and convert the plate, row, and column variables to type category
946
+ X_long.reset_index(inplace=True)
947
+ split_values = X_long['prc'].str.split('_', expand=True)
948
+ X_long[['plate', 'row', 'col']] = split_values
949
+ X_long['plate'] = X_long['plate'].str[1:]
950
+ X_long['plate'] = X_long['plate'].astype(int)
951
+ X_long['row'] = X_long['row'].str[1:]
952
+ X_long['row'] = X_long['row'].astype(int)
953
+ X_long['col'] = X_long['col'].str[1:]
954
+ X_long['col'] = X_long['col'].astype(int)
955
+ X_long = X_long.set_index('prc')
956
+ # Create a new column to represent the nested structure of plate, row, and column
957
+ X_long['plate_row_col'] = X_long['plate'].astype(str) + "_" + X_long['row'].astype(str) + "_" + X_long['col'].astype(str)
958
+ n_group = pd.DataFrame(X_long.groupby(['gene_gRNA']).count()['main_variable'])
959
+ n_group = n_group.rename({'main_variable': 'n_group'}, axis=1)
960
+ n_group = n_group.reset_index(drop=False)
961
+ X_long = pd.merge(X_long, n_group, on='gene_gRNA')
962
+ X_long = X_long[X_long['n_group']>1]
963
+ #print(X_long.isna().sum())
964
+
965
+ X_long['main_variable'] = X_long['main_variable'].astype('category')
966
+ X_long['sub_variable'] = X_long['sub_variable'].astype('category')
967
+ X_long['plate'] = X_long['plate'].astype('category')
968
+ X_long['row'] = X_long['row'].astype('category')
969
+ X_long['col'] = X_long['col'].astype('category')
970
+ X_long = pd.DataFrame(X_long)
971
+ print(X_long)
972
+
973
+ md = smf.mixedlm("score ~ C(main_variable)", X_long,
974
+ groups=X_long["sub_variable"])
975
+
976
+ # Define your nonlinear function here
977
+ def nonlinear_function(x, *params):
978
+ pass # Implement non linear function here
979
+
980
+ mdf = md.fit(method='bfgs', maxiter=1000)
981
+ print(mdf.summary())
982
+ summary = mdf.summary()
983
+ df = pd.DataFrame(summary.tables[1])
984
+ df, res = plot_reg_res(df, coef_col='Coef.', col_p='P>|z|')
985
+ return df, res
986
+
987
+ def calculate_accuracy(df):
988
+ df.loc[df['pc_score'] <= 0.5, 'pred'] = 0
989
+ df.loc[df['pc_score'] >= 0.5, 'pred'] = 1
990
+ df.loc[df['cond'] == 'NC', 'lab'] = 0
991
+ df.loc[df['cond'] == 'PC', 'lab'] = 1
992
+ df = df[df['cond'] != 'SCREEN']
993
+ df_nc = df[df['cond'] == 'NC']
994
+ df_pc = df[df['cond'] == 'PC']
995
+ correct = []
996
+ all_ls = []
997
+ pred_list = df['pred'].tolist()
998
+ lab_list = df['lab'].tolist()
999
+ for i,v in enumerate(pred_list):
1000
+ all_ls.append(1)
1001
+ if v == lab_list[i]:
1002
+ correct.append(1)
1003
+ print('total accuracy',len(correct)/len(all_ls))
1004
+ correct = []
1005
+ all_ls = []
1006
+ pred_list = df_pc['pred'].tolist()
1007
+ lab_list = df_pc['lab'].tolist()
1008
+ for i,v in enumerate(pred_list):
1009
+ all_ls.append(1)
1010
+ if v == lab_list[i]:
1011
+ correct.append(1)
1012
+ print('positives accuracy', len(correct)/len(all_ls))
1013
+ correct = []
1014
+ all_ls = []
1015
+ pred_list = df_nc['pred'].tolist()
1016
+ lab_list = df_nc['lab'].tolist()
1017
+ for i,v in enumerate(pred_list):
1018
+ all_ls.append(1)
1019
+ if v == lab_list[i]:
1020
+ correct.append(1)
1021
+ print('negatives accuracy',len(correct)/len(all_ls))
1022
+
1023
+ def preprocess_image_data(df, resnet_loc, min_count=25, metric='mean', plot=True, score='pc_score'):
1024
+ print('number of cells', len(df))
1025
+ resnet_preds = pd.read_csv(resnet_loc)
1026
+ res_df = split_filenames(df=resnet_preds, filename_column='path')
1027
+ pred_df = rename_plate_metadata(df=res_df)
1028
+ pred_df['prcfo'] = pred_df['plate']+'_'+pred_df['row']+'_'+pred_df['col']+'_'+pred_df['field']+'_'+pred_df['obj']
1029
+ print('number of resnet scores', len(pred_df))
1030
+ merged_df = pd.merge(df, pred_df, on='prcfo', how='inner', suffixes=('', '_right'))
1031
+ merged_df = merged_df.rename(columns={'pred': 'pc_score'})
1032
+
1033
+ merged_df = merged_df[(merged_df['pc_score'] <= 0.25) | (merged_df['pc_score'] >= 0.75)]
1034
+
1035
+ merged_df['recruitment'] = merged_df['Toxo_channel_1_quartiles75']/merged_df['Cytosol_channel_1_quartiles75']
1036
+ merged_df = pd.DataFrame(merged_df[merged_df['duplicates'] == 1.0])
1037
+ columns_to_drop = [col for col in merged_df.columns if col.endswith('_right')]
1038
+ merged_df = merged_df.drop(columns_to_drop, axis=1)
1039
+ well_group = pd.DataFrame(merged_df.groupby(['prc']).count()['cond'])
1040
+ well_group = well_group.rename({'cond': 'cell_count'}, axis=1)
1041
+ merged_df = pd.merge(merged_df, well_group, on='prc', how='inner', suffixes=('', '_right'))
1042
+ columns_to_drop = [col for col in merged_df.columns if col.endswith('_right')]
1043
+ merged_df = merged_df.drop(columns_to_drop, axis=1)
1044
+ #merged_df = merged_df.drop(['duplicates', 'outlier', 'prcfo.1'], axis=1)
1045
+ merged_df = merged_df.drop(['duplicates', 'prcfo.1'], axis=1)
1046
+ merged_df = pd.DataFrame(merged_df[merged_df['cell_count'] >= min_count])
1047
+
1048
+ if metric == 'mean':
1049
+ well_scores_score = pd.DataFrame(merged_df.groupby(['prc']).mean()['pc_score'])
1050
+ well_scores_score = well_scores_score.rename({'pc_score': 'mean_pc_score'}, axis=1)
1051
+ well_scores_rec = pd.DataFrame(merged_df.groupby(['prc']).mean()['recruitment'])
1052
+ well_scores_rec = well_scores_rec.rename({'recruitment': 'mean_recruitment'}, axis=1)
1053
+ if metric == 'geomean':
1054
+ well_scores_score = pd.DataFrame(merged_df.groupby(['prc'])['pc_score'].apply(gmean))
1055
+ well_scores_score = well_scores_score.rename({'pc_score': 'mean_pc_score'}, axis=1)
1056
+ well_scores_rec = pd.DataFrame(merged_df.groupby(['prc'])['recruitment'].apply(gmean))
1057
+ well_scores_rec = well_scores_rec.rename({'recruitment': 'mean_recruitment'}, axis=1)
1058
+ if metric == 'median':
1059
+ well_scores_score = pd.DataFrame(merged_df.groupby(['prc']).median()['pc_score'])
1060
+ well_scores_score = well_scores_score.rename({'pc_score': 'mean_pc_score'}, axis=1)
1061
+ well_scores_rec = pd.DataFrame(merged_df.groupby(['prc']).median()['recruitment'])
1062
+ well_scores_rec = well_scores_rec.rename({'recruitment': 'mean_recruitment'}, axis=1)
1063
+ if metric in ('quantile', 'quntile'): # accept corrected spelling of 'quantile'
1064
+ well_scores_score = pd.DataFrame(merged_df.groupby(['prc']).quantile(0.75)['pc_score'])
1065
+ well_scores_score = well_scores_score.rename({'pc_score': 'mean_pc_score'}, axis=1)
1066
+ well_scores_rec = pd.DataFrame(merged_df.groupby(['prc']).quantile(0.75)['recruitment'])
1067
+ well_scores_rec = well_scores_rec.rename({'recruitment': 'mean_recruitment'}, axis=1)
1068
+ well = pd.DataFrame(pd.DataFrame(merged_df.select_dtypes(include=['object'])).groupby(['prc']).first())
1069
+ well['mean_pc_score'] = well_scores_score['mean_pc_score']
1070
+ well['mean_recruitment'] = well_scores_rec['mean_recruitment']
1071
+ nc = well[well['cond'] == 'NC']
1072
+ max_nc = nc['mean_recruitment'].max()
1073
+ pc = well[well['cond'] == 'PC']
1074
+ screen = well[well['cond'] == 'SCREEN']
1075
+ screen = screen[screen['mean_recruitment'] <= max_nc]
1076
+ if plot:
1077
+ x_axis = 'mean_pc_score'
1078
+ fig, ax = plt.subplots(1,3,figsize=(30,10))
1079
+ sns.histplot(data=nc, x=x_axis, kde=False, stat='density', element="step", ax=ax[0], color='lightgray', log_scale=False)
1080
+ sns.histplot(data=pc, x=x_axis, kde=False, stat='density', element="step", ax=ax[0], color='teal', log_scale=False)
1081
+ sns.histplot(data=screen, x=x_axis, kde=False, stat='density', element="step", ax=ax[1], color='purple', log_scale=False)
1082
+ sns.histplot(data=nc, x=x_axis, kde=False, stat='density', element="step", ax=ax[2], color='lightgray', log_scale=False)
1083
+ sns.histplot(data=pc, x=x_axis, kde=False, stat='density', element="step", ax=ax[2], color='teal', log_scale=False)
1084
+ sns.histplot(data=screen, x=x_axis, kde=False, stat='density', element="step", ax=ax[2], color='purple', log_scale=False)
1085
+ ax[0].set_title('NC vs PC wells')
1086
+ ax[1].set_title('Screen wells')
1087
+ ax[2].set_title('NC vs PC vs Screen wells')
1088
+ ax[0].spines['top'].set_visible(False)
1089
+ ax[0].spines['right'].set_visible(False)
1090
+ ax[1].spines['top'].set_visible(False)
1091
+ ax[1].spines['right'].set_visible(False)
1092
+ ax[2].spines['top'].set_visible(False)
1093
+ ax[2].spines['right'].set_visible(False)
1094
+ ax[0].set_xlim([0, 1])
1095
+ ax[1].set_xlim([0, 1])
1096
+ ax[2].set_xlim([0, 1])
1097
+ loc = '/media/olafsson/umich/matt_graphs/resnet_score_well_av.pdf'
1098
+ fig.savefig(loc, dpi = 600, format='pdf', bbox_inches='tight')
1099
+ x_axis = 'mean_recruitment'
1100
+ fig, ax = plt.subplots(1,3,figsize=(30,10))
1101
+ sns.histplot(data=nc, x=x_axis, kde=False, stat='density', element="step", ax=ax[0], color='lightgray', log_scale=False)
1102
+ sns.histplot(data=pc, x=x_axis, kde=False, stat='density', element="step", ax=ax[0], color='teal', log_scale=False)
1103
+ sns.histplot(data=screen, x=x_axis, kde=False, stat='density', element="step", ax=ax[1], color='purple', log_scale=False)
1104
+ sns.histplot(data=nc, x=x_axis, kde=False, stat='density', element="step", ax=ax[2], color='lightgray', log_scale=False)
1105
+ sns.histplot(data=pc, x=x_axis, kde=False, stat='density', element="step", ax=ax[2], color='teal', log_scale=False)
1106
+ sns.histplot(data=screen, x=x_axis, kde=False, stat='density', element="step", ax=ax[2], color='purple', log_scale=False)
1107
+ ax[0].set_title('NC vs PC wells')
1108
+ ax[1].set_title('Screen wells')
1109
+ ax[2].set_title('NC vs PC vs Screen wells')
1110
+ ax[0].spines['top'].set_visible(False)
1111
+ ax[0].spines['right'].set_visible(False)
1112
+ ax[1].spines['top'].set_visible(False)
1113
+ ax[1].spines['right'].set_visible(False)
1114
+ ax[2].spines['top'].set_visible(False)
1115
+ ax[2].spines['right'].set_visible(False)
1116
+ loc = '/media/olafsson/umich/matt_graphs/mean_recruitment_well_av.pdf'
1117
+ fig.savefig(loc, dpi = 600, format='pdf', bbox_inches='tight')
1118
+ plates = ['p1','p2','p3','p4']
1119
+ screen = screen[screen['plate'].isin(plates)]
1120
+ if score == 'pc_score':
1121
+ dv = pd.DataFrame(screen['mean_pc_score'])
1122
+ dv = dv.rename({'mean_pc_score': 'score'}, axis=1)
1123
+ if score == 'recruitment':
1124
+ dv = pd.DataFrame(screen['mean_recruitment'])
1125
+ dv = dv.rename({'mean_recruitment': 'score'}, axis=1)
1126
+ print('dependent variable well count:', len(well))
1127
+ dv_loc = '/media/olafsson/Data2/methods_paper/data/dv.csv'
1128
+ dv.to_csv(dv_loc)
1129
+ calculate_accuracy(df=merged_df)
1130
+ return merged_df, well
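+ # End-to-end sketch for the imaging side (cell_df, 'resnet_preds.csv' and 'iv.csv' are
+ # hypothetical inputs): score cells with the ResNet predictions, write the per-well
+ # dependent variable, then fit either model against the barcode fractions.
+ # cells, wells = preprocess_image_data(cell_df, 'resnet_preds.csv', min_count=25, metric='mean', score='pc_score')
+ # reg_results, reg_hits = reg_model('iv.csv', '/media/olafsson/Data2/methods_paper/data/dv.csv')
+ # mm_results, mm_hits = mixed_model('iv.csv', '/media/olafsson/Data2/methods_paper/data/dv.csv')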