spacr 0.3.1__py3-none-any.whl → 0.3.22__py3-none-any.whl

This diff compares the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the versions as they appear in their respective public registries.
Files changed (41)
  1. spacr/__init__.py +19 -3
  2. spacr/cellpose.py +311 -0
  3. spacr/core.py +245 -2494
  4. spacr/deep_spacr.py +316 -48
  5. spacr/gui.py +1 -0
  6. spacr/gui_core.py +74 -63
  7. spacr/gui_elements.py +110 -5
  8. spacr/gui_utils.py +346 -6
  9. spacr/io.py +680 -141
  10. spacr/logger.py +28 -9
  11. spacr/measure.py +107 -95
  12. spacr/mediar.py +0 -3
  13. spacr/ml.py +1051 -0
  14. spacr/openai.py +37 -0
  15. spacr/plot.py +707 -20
  16. spacr/resources/data/lopit.csv +3833 -0
  17. spacr/resources/data/toxoplasma_metadata.csv +8843 -0
  18. spacr/resources/icons/convert.png +0 -0
  19. spacr/resources/{models/cp/toxo_plaque_cyto_e25000_X1120_Y1120.CP_model → icons/dna_matrix.mp4} +0 -0
  20. spacr/sequencing.py +241 -1311
  21. spacr/settings.py +134 -47
  22. spacr/sim.py +0 -2
  23. spacr/submodules.py +349 -0
  24. spacr/timelapse.py +0 -2
  25. spacr/toxo.py +238 -0
  26. spacr/utils.py +419 -180
  27. {spacr-0.3.1.dist-info → spacr-0.3.22.dist-info}/METADATA +31 -22
  28. {spacr-0.3.1.dist-info → spacr-0.3.22.dist-info}/RECORD +32 -33
  29. spacr/chris.py +0 -50
  30. spacr/graph_learning.py +0 -340
  31. spacr/resources/MEDIAR/.git +0 -1
  32. spacr/resources/MEDIAR_weights/.DS_Store +0 -0
  33. spacr/resources/icons/.DS_Store +0 -0
  34. spacr/resources/icons/spacr_logo_rotation.gif +0 -0
  35. spacr/resources/models/cp/toxo_plaque_cyto_e25000_X1120_Y1120.CP_model_settings.csv +0 -23
  36. spacr/resources/models/cp/toxo_pv_lumen.CP_model +0 -0
  37. spacr/sim_app.py +0 -0
  38. {spacr-0.3.1.dist-info → spacr-0.3.22.dist-info}/LICENSE +0 -0
  39. {spacr-0.3.1.dist-info → spacr-0.3.22.dist-info}/WHEEL +0 -0
  40. {spacr-0.3.1.dist-info → spacr-0.3.22.dist-info}/entry_points.txt +0 -0
  41. {spacr-0.3.1.dist-info → spacr-0.3.22.dist-info}/top_level.txt +0 -0
spacr/sequencing.py CHANGED
@@ -1,71 +1,7 @@
1
- import os, gzip, re, time, math, subprocess, gzip
1
+ import os, gzip, re, time
2
2
  import pandas as pd
3
- import numpy as np
4
- import matplotlib.pyplot as plt
5
- import seaborn as sns
6
- from Bio import pairwise2
7
- import statsmodels.api as sm
8
- from statsmodels.regression.mixed_linear_model import MixedLM
9
- from statsmodels.stats.outliers_influence import variance_inflation_factor
10
- from scipy.stats import gmean
11
- from scipy import stats
12
- from difflib import SequenceMatcher
13
- from collections import Counter
14
- from IPython.display import display
15
3
  from multiprocessing import Pool, cpu_count, Queue, Process
16
- from rapidfuzz import process, fuzz
17
-
18
- from sklearn.linear_model import LinearRegression, Lasso, Ridge
19
- from sklearn.preprocessing import FunctionTransformer, MinMaxScaler
20
-
21
- from scipy.stats import shapiro
22
- from patsy import dmatrices
23
-
24
- from Bio import SeqIO
25
4
  from Bio.Seq import Seq
26
- from Bio.SeqRecord import SeqRecord
27
-
28
- from collections import defaultdict
29
-
30
- import gzip, re
31
- from Bio.Seq import Seq
32
- import pandas as pd
33
- import numpy as np
34
- import gzip, re
35
- from Bio.Seq import Seq
36
- import pandas as pd
37
- import numpy as np
38
- from multiprocessing import Pool, cpu_count
39
-
40
- def parse_gz_files(folder_path):
41
- """
42
- Parses the .fastq.gz files in the specified folder path and returns a dictionary
43
- containing the sample names and their corresponding file paths.
44
-
45
- Args:
46
- folder_path (str): The path to the folder containing the .fastq.gz files.
47
-
48
- Returns:
49
- dict: A dictionary where the keys are the sample names and the values are
50
- dictionaries containing the file paths for the 'R1' and 'R2' read directions.
51
- """
52
- files = os.listdir(folder_path)
53
- gz_files = [f for f in files if f.endswith('.fastq.gz')]
54
-
55
- samples_dict = {}
56
- for gz_file in gz_files:
57
- parts = gz_file.split('_')
58
- sample_name = parts[0]
59
- read_direction = parts[1]
60
-
61
- if sample_name not in samples_dict:
62
- samples_dict[sample_name] = {}
63
-
64
- if read_direction == "R1":
65
- samples_dict[sample_name]['R1'] = os.path.join(folder_path, gz_file)
66
- elif read_direction == "R2":
67
- samples_dict[sample_name]['R2'] = os.path.join(folder_path, gz_file)
68
- return samples_dict
69
5
 
70
6
  # Function to map sequences to names (same as your original)
71
7
  def map_sequences_to_names(csv_file, sequences, rc):
@@ -148,16 +84,13 @@ def reverse_complement(seq):
148
84
  # Core logic for processing a chunk (same as your original)
149
85
  def process_chunk(chunk_data):
150
86
 
151
- def find_sequence_in_chunk_reads(r1_chunk, r2_chunk, target_sequence, offset_start, expected_end):
152
- i = 0
153
- fail_count = 0
154
- failed_cases = []
155
- regex = r"^(?P<column>.{8})TGCTG.*TAAAC(?P<grna>.{20,21})AACTT.*AGAAG(?P<row>.{8}).*"
87
+ def paired_find_sequence_in_chunk_reads(r1_chunk, r2_chunk, target_sequence, offset_start, expected_end, regex):
88
+
156
89
  consensus_sequences, columns, grnas, rows = [], [], [], []
157
-
90
+
158
91
  for r1_lines, r2_lines in zip(r1_chunk, r2_chunk):
159
- r1_header, r1_sequence, r1_plus, r1_quality = r1_lines.split('\n')
160
- r2_header, r2_sequence, r2_plus, r2_quality = r2_lines.split('\n')
92
+ _, r1_sequence, _, r1_quality = r1_lines.split('\n')
93
+ _, r2_sequence, _, r2_quality = r2_lines.split('\n')
161
94
  r2_sequence = reverse_complement(r2_sequence)
162
95
 
163
96
  r1_pos = r1_sequence.find(target_sequence)
@@ -192,14 +125,84 @@ def process_chunk(chunk_data):
192
125
  grnas.append(grna_sequence)
193
126
  rows.append(row_sequence)
194
127
 
195
- return consensus_sequences, columns, grnas, rows, fail_count
128
+ if len(consensus_sequences) == 0:
129
+ print(f"WARNING: No sequences matched {regex} in chunk")
130
+ print(f"Are bacode sequences in the correct orientation?")
131
+ print(f"Is {consensus_seq} compatible with {regex} ?")
196
132
 
197
- r1_chunk, r2_chunk, target_sequence, offset_start, expected_end, column_csv, grna_csv, row_csv = chunk_data
198
- consensus_sequences, columns, grnas, rows, _ = find_sequence_in_chunk_reads(r1_chunk, r2_chunk, target_sequence, offset_start, expected_end)
133
+ if len(consensus_seq) >= expected_end:
134
+ consensus_seq_rc = reverse_complement(consensus_seq)
135
+ match = re.match(regex, consensus_seq_rc)
136
+ if match:
137
+ print(f"Reverse complement of last sequence in chunk matched {regex}")
138
+
139
+ return consensus_sequences, columns, grnas, rows
140
+
141
+ def single_find_sequence_in_chunk_reads(r1_chunk, target_sequence, offset_start, expected_end, regex):
142
+
143
+ consensus_sequences, columns, grnas, rows = [], [], [], []
144
+
145
+ for r1_lines in r1_chunk:
146
+ _, r1_sequence, _, r1_quality = r1_lines.split('\n')
147
+
148
+ # Find the target sequence in R1
149
+ r1_pos = r1_sequence.find(target_sequence)
150
+
151
+ if r1_pos != -1:
152
+ # Adjust start and end positions based on the offset and expected length
153
+ r1_start = max(r1_pos + offset_start, 0)
154
+ r1_end = min(r1_start + expected_end, len(r1_sequence))
155
+
156
+ # Extract the sequence and quality within the defined region
157
+ r1_seq, r1_qual = extract_sequence_and_quality(r1_sequence, r1_quality, r1_start, r1_end)
158
+
159
+ # If the sequence is shorter than expected, pad with 'N's and '!' for quality
160
+ if len(r1_seq) < expected_end:
161
+ r1_seq += 'N' * (expected_end - len(r1_seq))
162
+ r1_qual += '!' * (expected_end - len(r1_qual))
163
+
164
+ # Use the R1 sequence as the "consensus"
165
+ consensus_seq = r1_seq
166
+
167
+ # Check if the consensus sequence matches the regex
168
+ if len(consensus_seq) >= expected_end:
169
+ match = re.match(regex, consensus_seq)
170
+ if match:
171
+ consensus_sequences.append(consensus_seq)
172
+ column_sequence = match.group('column')
173
+ grna_sequence = match.group('grna')
174
+ row_sequence = match.group('row')
175
+ columns.append(column_sequence)
176
+ grnas.append(grna_sequence)
177
+ rows.append(row_sequence)
178
+
179
+ if len(consensus_sequences) == 0:
180
+ print(f"WARNING: No sequences matched {regex} in chunk")
181
+ print(f"Are bacode sequences in the correct orientation?")
182
+ print(f"Is {consensus_seq} compatible with {regex} ?")
183
+
184
+ if len(consensus_seq) >= expected_end:
185
+ consensus_seq_rc = reverse_complement(consensus_seq)
186
+ match = re.match(regex, consensus_seq_rc)
187
+ if match:
188
+ print(f"Reverse complement of last sequence in chunk matched {regex}")
189
+
190
+ return consensus_sequences, columns, grnas, rows
191
+
192
+ if len(chunk_data) == 9:
193
+ r1_chunk, r2_chunk, regex, target_sequence, offset_start, expected_end, column_csv, grna_csv, row_csv = chunk_data
194
+ if len(chunk_data) == 8:
195
+ r1_chunk, regex, target_sequence, offset_start, expected_end, column_csv, grna_csv, row_csv = chunk_data
196
+ r2_chunk = None
197
+
198
+ if r2_chunk is None:
199
+ consensus_sequences, columns, grnas, rows = single_find_sequence_in_chunk_reads(r1_chunk, target_sequence, offset_start, expected_end, regex)
200
+ else:
201
+ consensus_sequences, columns, grnas, rows = paired_find_sequence_in_chunk_reads(r1_chunk, r2_chunk, target_sequence, offset_start, expected_end, regex)
199
202
 
200
203
  column_names = map_sequences_to_names(column_csv, columns, rc=False)
201
- grna_names = map_sequences_to_names(grna_csv, grnas, rc=True)
202
- row_names = map_sequences_to_names(row_csv, rows, rc=True)
204
+ grna_names = map_sequences_to_names(grna_csv, grnas, rc=False)
205
+ row_names = map_sequences_to_names(row_csv, rows, rc=False)
203
206
 
204
207
  df = pd.DataFrame({
205
208
  'read': consensus_sequences,
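For orientation: the extraction above is driven entirely by a regular expression with named capture groups (column, grna, row), which in 0.3.22 is passed in through the new regex argument instead of being hard-coded. A minimal, self-contained sketch of that step, using the pattern that was hard-coded in 0.3.1 and a made-up read (extract_barcodes is a hypothetical helper, not part of spacr):

import re

# Pattern hard-coded in spacr 0.3.1; in 0.3.22 it is supplied via the `regex` argument.
BARCODE_REGEX = r"^(?P<column>.{8})TGCTG.*TAAAC(?P<grna>.{20,21})AACTT.*AGAAG(?P<row>.{8}).*"

def extract_barcodes(consensus_seq, regex=BARCODE_REGEX):
    """Return (column, grna, row) barcodes, or None if the read does not match."""
    match = re.match(regex, consensus_seq)
    if match is None:
        return None
    return match.group('column'), match.group('grna'), match.group('row')

# Made-up read, for illustration only.
read = "AAAACCCC" + "TGCTG" + "N" * 10 + "TAAAC" + "G" * 20 + "AACTT" + "N" * 10 + "AGAAG" + "TTTTGGGG"
print(extract_barcodes(read))  # ('AAAACCCC', 'GGGGGGGGGGGGGGGGGGGG', 'TTTTGGGG')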
@@ -220,18 +223,18 @@ def process_chunk(chunk_data):
220
223
  return df, unique_combinations, qc_df
221
224
 
222
225
  # Function to save data from the queue
223
- def saver_process(save_queue, hdf5_file, unique_combinations_csv, qc_csv_file, comp_type, comp_level):
226
+ def saver_process(save_queue, hdf5_file, save_h5, unique_combinations_csv, qc_csv_file, comp_type, comp_level):
224
227
  while True:
225
228
  item = save_queue.get()
226
229
  if item == "STOP":
227
230
  break
228
231
  df, unique_combinations, qc_df = item
229
- save_df_to_hdf5(df, hdf5_file, key='df', comp_type=comp_type, comp_level=comp_level)
232
+ if save_h5:
233
+ save_df_to_hdf5(df, hdf5_file, key='df', comp_type=comp_type, comp_level=comp_level)
230
234
  save_unique_combinations_to_csv(unique_combinations, unique_combinations_csv)
231
235
  save_qc_df_to_csv(qc_df, qc_csv_file)
232
236
 
233
- # Updated chunked_processing with improved multiprocessing logic
234
- def chunked_processing(r1_file, r2_file, target_sequence, offset_start, expected_end, column_csv, grna_csv, row_csv, save_h5, comp_type, comp_level, hdf5_file, unique_combinations_csv, qc_csv_file, chunk_size=10000, n_jobs=None):
237
+ def paired_read_chunked_processing(r1_file, r2_file, regex, target_sequence, offset_start, expected_end, column_csv, grna_csv, row_csv, save_h5, comp_type, comp_level, hdf5_file, unique_combinations_csv, qc_csv_file, chunk_size=10000, n_jobs=None, test=False):
235
238
 
236
239
  from .utils import count_reads_in_fastq, print_progress
237
240
 
@@ -239,24 +242,30 @@ def chunked_processing(r1_file, r2_file, target_sequence, offset_start, expected
239
242
  if n_jobs is None:
240
243
  n_jobs = cpu_count() - 3
241
244
 
242
- analyzed_chunks = 0
243
245
  chunk_count = 0
244
246
  time_ls = []
245
247
 
246
- print(f'Calculating read count for {r1_file}...')
247
- total_reads = count_reads_in_fastq(r1_file)
248
- chunks_nr = int(total_reads / chunk_size)
248
+ if not test:
249
+ print(f'Calculating read count for {r1_file}...')
250
+ total_reads = count_reads_in_fastq(r1_file)
251
+ chunks_nr = int(total_reads / chunk_size)+1
252
+ else:
253
+ total_reads = chunk_size
254
+ chunks_nr = 1
255
+
249
256
  print(f'Mapping barcodes for {total_reads} reads in {chunks_nr} batches for {r1_file}...')
250
257
 
251
258
  # Queue for saving
252
259
  save_queue = Queue()
253
260
 
254
261
  # Start the saving process
255
- save_process = Process(target=saver_process, args=(save_queue, hdf5_file, unique_combinations_csv, qc_csv_file, comp_type, comp_level))
262
+ save_process = Process(target=saver_process, args=(save_queue, hdf5_file, save_h5, unique_combinations_csv, qc_csv_file, comp_type, comp_level))
256
263
  save_process.start()
257
264
 
258
265
  pool = Pool(n_jobs)
259
266
 
267
+ print(f'Chunk size: {chunk_size}')
268
+
260
269
  with gzip.open(r1_file, 'rt') as r1, gzip.open(r2_file, 'rt') as r2:
261
270
  fastq_iter = zip(r1, r2)
262
271
  while True:
@@ -265,22 +274,27 @@ def chunked_processing(r1_file, r2_file, target_sequence, offset_start, expected
265
274
  r2_chunk = []
266
275
 
267
276
  for _ in range(chunk_size):
268
- try:
269
- r1_lines = [r1.readline().strip() for _ in range(4)]
270
- r2_lines = [r2.readline().strip() for _ in range(4)]
271
- r1_chunk.append('\n'.join(r1_lines))
272
- r2_chunk.append('\n'.join(r2_lines))
273
- except StopIteration:
277
+ # Read the next 4 lines for both R1 and R2 files
278
+ r1_lines = [r1.readline().strip() for _ in range(4)]
279
+ r2_lines = [r2.readline().strip() for _ in range(4)]
280
+
281
+ # Break if we've reached the end of either file
282
+ if not r1_lines[0] or not r2_lines[0]:
274
283
  break
275
284
 
285
+ r1_chunk.append('\n'.join(r1_lines))
286
+ r2_chunk.append('\n'.join(r2_lines))
287
+
288
+ # If the chunks are empty, break the outer while loop
276
289
  if not r1_chunk:
277
290
  break
278
291
 
279
292
  chunk_count += 1
280
- chunk_data = (r1_chunk, r2_chunk, target_sequence, offset_start, expected_end, column_csv, grna_csv, row_csv)
293
+ chunk_data = (r1_chunk, r2_chunk, regex, target_sequence, offset_start, expected_end, column_csv, grna_csv, row_csv)
281
294
 
282
295
  # Process chunks in parallel
283
296
  result = pool.apply_async(process_chunk, (chunk_data,))
297
+
284
298
  df, unique_combinations, qc_df = result.get()
285
299
 
286
300
  # Queue the results for saving
@@ -291,6 +305,11 @@ def chunked_processing(r1_file, r2_file, target_sequence, offset_start, expected
291
305
  time_ls.append(chunk_time)
292
306
  print_progress(files_processed=chunk_count, files_to_process=chunks_nr, n_jobs=n_jobs, time_ls=time_ls, batch_size=chunk_size, operation_type="Mapping Barcodes")
293
307
 
308
+ if test:
309
+ print('First 100 rows of chunk 1')
310
+ print(df[:100])
311
+ break
312
+
294
313
  # Cleanup the pool
295
314
  pool.close()
296
315
  pool.join()
@@ -299,1255 +318,166 @@ def chunked_processing(r1_file, r2_file, target_sequence, offset_start, expected
299
318
  save_queue.put("STOP")
300
319
  save_process.join()
301
320
 
302
- def generate_barecode_mapping(settings={}):
303
-
304
- from .settings import set_default_generate_barecode_mapping
305
-
306
- settings = set_default_generate_barecode_mapping(settings)
307
-
308
- samples_dict = parse_gz_files(settings['src'])
309
-
310
- for key in samples_dict:
311
-
312
- if samples_dict[key]['R1'] and samples_dict[key]['R2']:
313
-
314
- dst = os.path.join(settings['src'], key)
315
- hdf5_file = os.path.join(dst, 'annotated_reads.h5')
316
- unique_combinations_csv = os.path.join(dst, 'unique_combinations.csv')
317
- qc_csv_file = os.path.join(dst, 'qc.csv')
318
- os.makedirs(dst, exist_ok=True)
319
-
320
- print(f'Analyzing reads from sample {key}')
321
-
322
- chunked_processing(r1_file=samples_dict[key]['R1'],
323
- r2_file=samples_dict[key]['R2'],
324
- target_sequence=settings['target_sequence'],
325
- offset_start=settings['offset_start'],
326
- expected_end=settings['expected_end'],
327
- column_csv=settings['column_csv'],
328
- grna_csv=settings['grna_csv'],
329
- row_csv=settings['row_csv'],
330
- save_h5 = settings['save_h5'],
331
- comp_type = settings['comp_type'],
332
- comp_level=settings['comp_level'],
333
- hdf5_file=hdf5_file,
334
- unique_combinations_csv=unique_combinations_csv,
335
- qc_csv_file=qc_csv_file,
336
- chunk_size=settings['chunk_size'],
337
- n_jobs=settings['n_jobs'])
338
-
339
-
340
-
341
-
342
-
343
-
344
-
345
-
346
-
347
-
348
-
349
-
350
-
351
-
352
-
353
-
354
-
355
-
356
-
357
- def grna_plate_heatmap(path, specific_grna=None, min_max='all', cmap='viridis', min_count=0, save=True):
358
- """
359
- Generate a heatmap of gRNA plate data.
360
-
361
- Args:
362
- path (str): The path to the CSV file containing the gRNA plate data.
363
- specific_grna (str, optional): The specific gRNA to filter the data for. Defaults to None.
364
- min_max (str or list or tuple, optional): The range of values to use for the color scale.
365
- If 'all', the range will be determined by the minimum and maximum values in the data.
366
- If 'allq', the range will be determined by the 2nd and 98th percentiles of the data.
367
- If a list or tuple of two values, the range will be determined by those values.
368
- Defaults to 'all'.
369
- cmap (str, optional): The colormap to use for the heatmap. Defaults to 'viridis'.
370
- min_count (int, optional): The minimum count threshold for including a gRNA in the heatmap.
371
- Defaults to 0.
372
- save (bool, optional): Whether to save the heatmap as a PDF file. Defaults to True.
373
-
374
- Returns:
375
- matplotlib.figure.Figure: The generated heatmap figure.
376
- """
377
- def generate_grna_plate_heatmap(df, plate_number, min_max, min_count, specific_grna=None):
378
- df = df.copy() # Work on a copy to avoid SettingWithCopyWarning
379
-
380
- # Filtering the dataframe based on the plate_number and specific gRNA if provided
381
- df = df[df['plate_row'].str.startswith(plate_number)]
382
- if specific_grna:
383
- df = df[df['grna'] == specific_grna]
384
-
385
- # Split plate_row into plate and row
386
- df[['plate', 'row']] = df['plate_row'].str.split('_', expand=True)
387
-
388
- # Ensure proper ordering
389
- row_order = [f'r{i}' for i in range(1, 17)]
390
- col_order = [f'c{i}' for i in range(1, 28)]
391
-
392
- df['row'] = pd.Categorical(df['row'], categories=row_order, ordered=True)
393
- df['column'] = pd.Categorical(df['column'], categories=col_order, ordered=True)
394
-
395
- # Group by row and column, summing counts
396
- grouped = df.groupby(['row', 'column'], observed=True)['count'].sum().reset_index()
397
-
398
- plate_map = pd.pivot_table(grouped, values='count', index='row', columns='column').fillna(0)
399
-
400
- if min_max == 'all':
401
- min_max = [plate_map.min().min(), plate_map.max().max()]
402
- elif min_max == 'allq':
403
- min_max = np.quantile(plate_map.values, [0.02, 0.98])
404
- elif isinstance(min_max, (list, tuple)) and len(min_max) == 2:
405
- if isinstance(min_max[0], (float)) and isinstance(min_max[1], (float)):
406
- min_max = np.quantile(plate_map.values, [min_max[0], min_max[1]])
407
- if isinstance(min_max[0], (int)) and isinstance(min_max[1], (int)):
408
- min_max = [min_max[0], min_max[1]]
409
-
410
- return plate_map, min_max
411
-
412
- if isinstance(path, pd.DataFrame):
413
- df = path
414
- else:
415
- df = pd.read_csv(path)
416
-
417
- plates = df['plate_row'].str.split('_', expand=True)[0].unique()
418
- n_rows, n_cols = (len(plates) + 3) // 4, 4
419
- fig, ax = plt.subplots(n_rows, n_cols, figsize=(40, 5 * n_rows))
420
- ax = ax.flatten()
421
-
422
- for index, plate in enumerate(plates):
423
- plate_map, min_max_values = generate_grna_plate_heatmap(df, plate, min_max, min_count, specific_grna)
424
- sns.heatmap(plate_map, cmap=cmap, vmin=min_max_values[0], vmax=min_max_values[1], ax=ax[index])
425
- ax[index].set_title(plate)
426
-
427
- for i in range(len(plates), n_rows * n_cols):
428
- fig.delaxes(ax[i])
429
-
430
- plt.subplots_adjust(wspace=0.1, hspace=0.4)
431
-
432
- # Save the figure
433
- if save:
434
- filename = path.replace('.csv', '')
435
- if specific_grna:
436
- filename += f'_{specific_grna}'
437
- filename += '.pdf'
438
- plt.savefig(filename)
439
- print(f'saved {filename}')
440
- plt.show()
441
-
442
- return fig
443
-
444
- def reverse_complement(dna_sequence):
445
- complement_dict = {'A': 'T', 'T': 'A', 'C': 'G', 'G': 'C', 'N':'N'}
446
- reverse_seq = dna_sequence[::-1]
447
- reverse_complement_seq = ''.join([complement_dict[base] for base in reverse_seq])
448
- return reverse_complement_seq
449
-
450
- def complement(dna_sequence):
451
- complement_dict = {'A': 'T', 'T': 'A', 'C': 'G', 'G': 'C', 'N':'N'}
452
- complement_seq = ''.join([complement_dict[base] for base in dna_sequence])
453
- return complement_seq
454
-
455
- def file_len(fname):
456
- p = subprocess.Popen(['wc', '-l', fname], stdout=subprocess.PIPE, stderr=subprocess.PIPE)
457
- result, err = p.communicate()
458
- if p.returncode != 0:
459
- raise IOError(err)
460
- return int(result.strip().split()[0])
461
-
462
- def generate_plate_heatmap(df, plate_number, variable, grouping, min_max):
463
- if grouping == 'mean':
464
- temp = df.groupby(['plate','row','col']).mean()[variable]
465
- if grouping == 'sum':
466
- temp = df.groupby(['plate','row','col']).sum()[variable]
467
- if grouping == 'count':
468
- temp = df.groupby(['plate','row','col']).count()[variable]
469
- if grouping in ['mean', 'count', 'sum']:
470
- temp = pd.DataFrame(temp)
471
- if min_max == 'all':
472
- min_max=[np.min(temp[variable]),np.max(temp[variable])]
473
- if min_max == 'allq':
474
- min_max = np.quantile(temp[variable], [0.2, 0.98])
475
- plate = df[df['plate'] == plate_number]
476
- plate = pd.DataFrame(plate)
477
- if grouping == 'mean':
478
- plate = plate.groupby(['plate','row','col']).mean()[variable]
479
- if grouping == 'sum':
480
- plate = plate.groupby(['plate','row','col']).sum()[variable]
481
- if grouping == 'count':
482
- plate = plate.groupby(['plate','row','col']).count()[variable]
483
- if grouping not in ['mean', 'count', 'sum']:
484
- plate = plate.groupby(['plate','row','col']).mean()[variable]
485
- if min_max == 'plate':
486
- min_max=[np.min(plate[variable]),np.max(plate[variable])]
487
- plate = pd.DataFrame(plate)
488
- plate = plate.reset_index()
489
- if 'plate' in plate.columns:
490
- plate = plate.drop(['plate'], axis=1)
491
- pcol = [*range(1,28,1)]
492
- prow = [*range(1,17,1)]
493
- new_col = []
494
- for v in pcol:
495
- col = 'c'+str(v)
496
- new_col.append(col)
497
- new_col.remove('c15')
498
- new_row = []
499
- for v in prow:
500
- ro = 'r'+str(v)
501
- new_row.append(ro)
502
- plate_map = pd.DataFrame(columns=new_col, index = new_row)
503
- for index, row in plate.iterrows():
504
- r = row['row']
505
- c = row['col']
506
- v = row[variable]
507
- plate_map.loc[r,c]=v
508
- plate_map = plate_map.fillna(0)
509
- return pd.DataFrame(plate_map), min_max
510
-
511
- def plot_plates(df, variable, grouping, min_max, cmap):
512
- try:
513
- plates = np.unique(df['plate'], return_counts=False)
514
- except:
515
- try:
516
- df[['plate', 'row', 'col']] = df['prc'].str.split('_', expand=True)
517
- df = pd.DataFrame(df)
518
- plates = np.unique(df['plate'], return_counts=False)
519
- except:
520
- next
521
- #plates = np.unique(df['plate'], return_counts=False)
522
- nr_of_plates = len(plates)
523
- print('nr_of_plates:',nr_of_plates)
524
- # Calculate the number of rows and columns for the subplot grid
525
- if nr_of_plates in [1, 2, 3, 4]:
526
- n_rows, n_cols = 1, 4
527
- elif nr_of_plates in [5, 6, 7, 8]:
528
- n_rows, n_cols = 2, 4
529
- elif nr_of_plates in [9, 10, 11, 12]:
530
- n_rows, n_cols = 3, 4
531
- elif nr_of_plates in [13, 14, 15, 16]:
532
- n_rows, n_cols = 4, 4
533
-
534
- # Create the subplot grid with the specified number of rows and columns
535
- fig, ax = plt.subplots(n_rows, n_cols, figsize=(40, 5 * n_rows))
536
-
537
- # Flatten the axes array to a one-dimensional array
538
- ax = ax.flatten()
539
-
540
- # Loop over each plate and plot the heatmap
541
- for index, plate in enumerate(plates):
542
- plate_number = plate
543
- plate_map, min_max = generate_plate_heatmap(df=df, plate_number=plate_number, variable=variable, grouping=grouping, min_max=min_max)
544
- if index == 0:
545
- print('plate_number:',plate_number,'minimum:',min_max[0], 'maximum:',min_max[1])
546
- # Plot the heatmap on the appropriate subplot
547
- sns.heatmap(plate_map, cmap=cmap, vmin=min_max[0], vmax=min_max[1], ax=ax[index])
548
- ax[index].set_title(plate_number)
549
-
550
- # Remove any empty subplots
551
- for i in range(nr_of_plates, n_rows * n_cols):
552
- fig.delaxes(ax[i])
553
-
554
- # Adjust the spacing between the subplots
555
- plt.subplots_adjust(wspace=0.1, hspace=0.4)
556
-
557
- # Show the plot
558
- plt.show()
559
- print()
560
- return
561
-
562
- def count_mismatches(seq1, seq2, align_length=10):
563
- alignments = pairwise2.align.globalxx(seq1, seq2)
564
- # choose the first alignment (there might be several with the same score)
565
- alignment = alignments[0]
566
- # alignment is a tuple (seq1_aligned, seq2_aligned, score, begin, end)
567
- seq1_aligned, seq2_aligned, score, begin, end = alignment
568
- # Determine the start of alignment (first position where at least align_length bases are the same)
569
- start_of_alignment = next(i for i in range(len(seq1_aligned) - align_length + 1)
570
- if seq1_aligned[i:i+align_length] == seq2_aligned[i:i+align_length])
571
- # Trim the sequences to the same length from the start of the alignment
572
- seq1_aligned = seq1_aligned[start_of_alignment:]
573
- seq2_aligned = seq2_aligned[start_of_alignment:]
574
- # Trim the sequences to be of the same length (from the end)
575
- min_length = min(len(seq1_aligned), len(seq2_aligned))
576
- seq1_aligned = seq1_aligned[:min_length]
577
- seq2_aligned = seq2_aligned[:min_length]
578
- mismatches = sum(c1 != c2 for c1, c2 in zip(seq1_aligned, seq2_aligned))
579
- return mismatches
580
-
581
- def get_sequence_data(r1,r2):
582
- forward_regex = re.compile(r'^(...GGTGCCACTT)TTTCAAGTTG.*?TTCTAGCTCT(AAAAC[A-Z]{18,22}AACTT)GACATCCCCA.*?AAGGCAAACA(CCCCCTTCGG....).*')
583
- r1fd = forward_regex.search(r1)
584
- reverce_regex = re.compile(r'^(...CCGAAGGGGG)TGTTTGCCTT.*?TGGGGATGTC(AAGTT[A-Z]{18,22}GTTTT)AGAGCTAGAA.*?CAACTTGAAA(AAGTGGCACC...).*')
585
- r2fd = reverce_regex.search(r2)
586
- rc_r1 = reverse_complement(r1)
587
- rc_r2 = reverse_complement(r2)
588
- if all(var is not None for var in [r1fd, r2fd]):
589
- try:
590
- r1_mis_matches, _ = count_mismatches(seq1=r1, seq2=rc_r2, align_length=5)
591
- r2_mis_matches, _ = count_mismatches(seq1=r2, seq2=rc_r1, align_length=5)
592
- except:
593
- r1_mis_matches = None
594
- r2_mis_matches = None
595
- column_r1 = reverse_complement(r1fd[1])
596
- sgrna_r1 = r1fd[2]
597
- platerow_r1 = r1fd[3]
598
- column_r2 = r2fd[3]
599
- sgrna_r2 = reverse_complement(r2fd[2])
600
- platerow_r2 = reverse_complement(r2fd[1])+'N'
601
-
602
- data_dict = {'r1_plate_row':platerow_r1,
603
- 'r1_col':column_r1,
604
- 'r1_gRNA':sgrna_r1,
605
- 'r1_read':r1,
606
- 'r2_plate_row':platerow_r2,
607
- 'r2_col':column_r2,
608
- 'r2_gRNA':sgrna_r2,
609
- 'r2_read':r2,
610
- 'r1_r2_rc_mismatch':r1_mis_matches,
611
- 'r2_r1_rc_mismatch':r2_mis_matches,
612
- 'r1_len':len(r1),
613
- 'r2_len':len(r2)}
614
- else:
615
- try:
616
- r1_mis_matches, _ = count_mismatches(r1, rc_r2, align_length=5)
617
- r2_mis_matches, _ = count_mismatches(r2, rc_r1, align_length=5)
618
- except:
619
- r1_mis_matches = None
620
- r2_mis_matches = None
621
- data_dict = {'r1_plate_row':None,
622
- 'r1_col':None,
623
- 'r1_gRNA':None,
624
- 'r1_read':r1,
625
- 'r2_plate_row':None,
626
- 'r2_col':None,
627
- 'r2_gRNA':None,
628
- 'r2_read':r2,
629
- 'r1_r2_rc_mismatch':r1_mis_matches,
630
- 'r2_r1_rc_mismatch':r2_mis_matches,
631
- 'r1_len':len(r1),
632
- 'r2_len':len(r2)}
633
-
634
- return data_dict
635
-
636
- def get_read_data(identifier, prefix):
637
- if identifier.startswith("@"):
638
- parts = identifier.split(" ")
639
- # The first part contains the instrument, run number, flowcell ID, lane, tile, and coordinates
640
- instrument, run_number, flowcell_id, lane, tile, x_pos, y_pos = parts[0][1:].split(":")
641
- # The second part contains the read number, filter status, control number, and sample number
642
- read, is_filtered, control_number, sample_number = parts[1].split(":")
643
- rund_data_dict = {'instrument':instrument,
644
- 'run_number':run_number,
645
- 'flowcell_id':flowcell_id,
646
- 'lane':lane,
647
- 'tile':tile,
648
- 'x_pos':x_pos,
649
- 'y_pos':y_pos,
650
- 'read':read,
651
- 'is_filtered':is_filtered,
652
- 'control_number':control_number,
653
- 'sample_number':sample_number}
654
- modified_dict = {prefix + key: value for key, value in rund_data_dict.items()}
655
- return modified_dict
656
-
657
- def pos_dict(string):
658
- pos_dict = {}
659
- for i, char in enumerate(string):
660
- if char not in pos_dict:
661
- pos_dict[char] = [i]
662
- else:
663
- pos_dict[char].append(i)
664
- return pos_dict
665
-
666
- def truncate_read(seq,qual,target):
667
- index = seq.find(target)
668
- end = len(seq)-(3+len(target))
669
- if index != -1: # If the sequence is found
670
- if index-3 >= 0:
671
- seq = seq[index-3:]
672
- qual = qual[index-3:]
673
-
674
- return seq, qual
675
-
676
- def equalize_lengths(seq1, seq2, pad_char='N'):
677
- len_diff = len(seq1) - len(seq2)
678
-
679
- if len_diff > 0: # seq1 is longer
680
- seq2 += pad_char * len_diff # pad seq2 with 'N's
681
- elif len_diff < 0: # seq2 is longer
682
- seq1 += pad_char * (-len_diff) # pad seq1 with 'N's
683
-
684
- return seq1, seq2
685
-
686
- def get_read_data(identifier, prefix):
687
- if identifier.startswith("@"):
688
- parts = identifier.split(" ")
689
- # The first part contains the instrument, run number, flowcell ID, lane, tile, and coordinates
690
- instrument, run_number, flowcell_id, lane, tile, x_pos, y_pos = parts[0][1:].split(":")
691
- # The second part contains the read number, filter status, control number, and sample number
692
- read, is_filtered, control_number, sample_number = parts[1].split(":")
693
- rund_data_dict = {'instrument':instrument,
694
- 'x_pos':x_pos,
695
- 'y_pos':y_pos}
696
- modified_dict = {prefix + key: value for key, value in rund_data_dict.items()}
697
- return modified_dict
698
-
699
- def extract_barecodes(r1_fastq, r2_fastq, csv_loc, chunk_size=100000):
700
- data_chunk = []
701
- # Open both FASTQ files.
702
- with open(r1_fastq) as r1_file, open(r2_fastq) as r2_file:
703
- index = 0
704
- save_index = 0
705
- while True:
706
- index += 1
707
- start = time.time()
708
- # Read 4 lines at a time
709
- r1_identifier = r1_file.readline().strip()
710
- r1_sequence = r1_file.readline().strip()
711
- r1_plus = r1_file.readline().strip()
712
- r1_quality = r1_file.readline().strip()
713
- r2_identifier = r2_file.readline().strip()
714
- r2_sequence = r2_file.readline().strip()
715
- r2_sequence = reverse_complement(r2_sequence)
716
- r2_sequence = r2_sequence
717
- r2_plus = r2_file.readline().strip()
718
- r2_quality = r2_file.readline().strip()
719
- r2_quality = r2_quality
720
- if not r1_identifier or not r2_identifier:
721
- break
722
- #if index > 100:
723
- # break
724
- target = 'GGTGCCACTT'
725
- r1_sequence, r1_quality = truncate_read(r1_sequence, r1_quality, target)
726
- r2_sequence, r2_quality = truncate_read(r2_sequence, r2_quality, target)
727
- r1_sequence, r2_sequence = equalize_lengths(r1_sequence, r2_sequence, pad_char='N')
728
- r1_quality, r2_quality = equalize_lengths(r1_quality, r2_quality, pad_char='-')
729
- alignments = pairwise2.align.globalxx(r1_sequence, r2_sequence)
730
- alignment = alignments[0]
731
- score = alignment[2]
732
- column = None
733
- platerow = None
734
- grna = None
735
- if score >= 125:
736
- aligned_r1 = alignment[0]
737
- aligned_r2 = alignment[1]
738
- position_dict = {i+1: (base1, base2) for i, (base1, base2) in enumerate(zip(aligned_r1, aligned_r2))}
739
- phred_quality1 = [ord(char) - 33 for char in r1_quality]
740
- phred_quality2 = [ord(char) - 33 for char in r2_quality]
741
- r1_q_dict = {i+1: quality for i, quality in enumerate(phred_quality1)}
742
- r2_q_dict = {i+1: quality for i, quality in enumerate(phred_quality2)}
743
- read = ''
744
- for key in sorted(position_dict.keys()):
745
- if position_dict[key][0] != '-' and (position_dict[key][1] == '-' or r1_q_dict.get(key, 0) >= r2_q_dict.get(key, 0)):
746
- read = read + position_dict[key][0]
747
- elif position_dict[key][1] != '-' and (position_dict[key][0] == '-' or r2_q_dict.get(key, 0) > r1_q_dict.get(key, 0)):
748
- read = read + position_dict[key][1]
749
- pattern = re.compile(r'^(...GGTGC)CACTT.*GCTCT(TAAAC[A-Z]{18,22}AACTT)GACAT.*CCCCC(TTCGG....).*')
750
- regex_patterns = pattern.search(read)
751
- if all(var is not None for var in [regex_patterns]):
752
- column = regex_patterns[1]
753
- grna = reverse_complement(regex_patterns[2])
754
- platerow = reverse_complement(regex_patterns[3])
755
- elif score < 125:
756
- read = r1_sequence
757
- pattern = re.compile(r'^(...GGTGC)CACTT.*GCTCT(TAAAC[A-Z]{18,22}AACTT)GACAT.*CCCCC(TTCGG....).*')
758
- regex_patterns = pattern.search(read)
759
- if all(var is not None for var in [regex_patterns]):
760
- column = regex_patterns[1]
761
- grna = reverse_complement(regex_patterns[2])
762
- platerow = reverse_complement(regex_patterns[3])
763
- #print('2', platerow)
764
- data_dict = {'read':read,'column':column,'platerow':platerow,'grna':grna, 'score':score}
765
- end = time.time()
766
- if data_dict.get('grna') is not None:
767
- save_index += 1
768
- r1_rund_data_dict = get_read_data(r1_identifier, prefix='r1_')
769
- r2_rund_data_dict = get_read_data(r2_identifier, prefix='r2_')
770
- r1_rund_data_dict.update(r2_rund_data_dict)
771
- r1_rund_data_dict.update(data_dict)
772
- r1_rund_data_dict['r1_quality'] = r1_quality
773
- r1_rund_data_dict['r2_quality'] = r2_quality
774
- data_chunk.append(r1_rund_data_dict)
775
- print(f'Processed reads: {index} Found barecodes in {save_index} Time/read: {end - start}', end='\r', flush=True)
776
- if save_index % chunk_size == 0: # Every `chunk_size` reads, write to the CSV
777
- if not os.path.isfile(csv_loc):
778
- df = pd.DataFrame(data_chunk)
779
- df.to_csv(csv_loc, index=False)
780
- else:
781
- df = pd.DataFrame(data_chunk)
782
- df.to_csv(csv_loc, mode='a', header=False, index=False)
783
- data_chunk = [] # Clear the chunk
784
-
785
- def split_fastq(input_fastq, output_base, num_files):
786
- # Create file objects for each output file
787
- outputs = [open(f"{output_base}_{i}.fastq", "w") for i in range(num_files)]
788
- with open(input_fastq, "r") as f:
789
- # Initialize a counter for the lines
790
- line_counter = 0
791
- for line in f:
792
- # Determine the output file
793
- output_file = outputs[line_counter // 4 % num_files]
794
- # Write the line to the appropriate output file
795
- output_file.write(line)
796
- # Increment the line counter
797
- line_counter += 1
798
- # Close output files
799
- for output in outputs:
800
- output.close()
801
-
802
- def process_barecodes(df):
803
- print('==== Preprocessing barecodes ====')
804
- plate_ls = []
805
- row_ls = []
806
- column_ls = []
807
- grna_ls = []
808
- read_ls = []
809
- score_ls = []
810
- match_score_ls = []
811
- index_ls = []
812
- index = 0
813
- print_every = 100
814
- for i,row in df.iterrows():
815
- index += 1
816
- r1_instrument=row['r1_instrument']
817
- r1_x_pos=row['r1_x_pos']
818
- r1_y_pos=row['r1_y_pos']
819
- r2_instrument=row['r2_instrument']
820
- r2_x_pos=row['r2_x_pos']
821
- r2_y_pos=row['r2_y_pos']
822
- read=row['read']
823
- column=row['column']
824
- platerow=row['platerow']
825
- grna=row['grna']
826
- score=row['score']
827
- r1_quality=row['r1_quality']
828
- r2_quality=row['r2_quality']
829
- if r1_x_pos == r2_x_pos:
830
- if r1_y_pos == r2_y_pos:
831
- match_score = 0
832
-
833
- if grna.startswith('AAGTT'):
834
- match_score += 0.5
835
- if column.endswith('GGTGC'):
836
- match_score += 0.5
837
- if platerow.endswith('CCGAA'):
838
- match_score += 0.5
839
- index_ls.append(index)
840
- match_score_ls.append(match_score)
841
- score_ls.append(score)
842
- read_ls.append(read)
843
- plate_ls.append(platerow[:2])
844
- row_ls.append(platerow[2:4])
845
- column_ls.append(column[:3])
846
- grna_ls.append(grna)
847
- if index % print_every == 0:
848
- print(f'Processed reads: {index}', end='\r', flush=True)
849
- df = pd.DataFrame()
850
- df['index'] = index_ls
851
- df['score'] = score_ls
852
- df['match_score'] = match_score_ls
853
- df['plate'] = plate_ls
854
- df['row'] = row_ls
855
- df['col'] = column_ls
856
- df['seq'] = grna_ls
857
- df_high_score = df[df['score']>=125]
858
- df_low_score = df[df['score']<125]
859
- print(f'', flush=True)
860
- print(f'Found {len(df_high_score)} high score reads;Found {len(df_low_score)} low score reads')
861
- return df, df_high_score, df_low_score
862
-
863
- def find_grna(df, grna_df):
864
- print('==== Finding gRNAs ====')
865
- seqs = list(set(df.seq.tolist()))
866
- seq_ls = []
867
- grna_ls = []
868
- index = 0
869
- print_every = 1000
870
- for grna in grna_df.Seq.tolist():
871
- reverse_regex = re.compile(r'.*({}).*'.format(grna))
872
- for seq in seqs:
873
- index += 1
874
- if index % print_every == 0:
875
- print(f'Processed reads: {index}', end='\r', flush=True)
876
- found_grna = reverse_regex.search(seq)
877
- if found_grna is None:
878
- seq_ls.append('error')
879
- grna_ls.append('error')
880
- else:
881
- seq_ls.append(found_grna[0])
882
- grna_ls.append(found_grna[1])
883
- grna_dict = dict(zip(seq_ls, grna_ls))
884
- df = df.assign(grna_seq=df['seq'].map(grna_dict).fillna('error'))
885
- print(f'', flush=True)
886
- return df
887
-
888
- def map_unmapped_grnas(df):
889
- print('==== Mapping lost gRNA barecodes ====')
890
- def similar(a, b):
891
- return SequenceMatcher(None, a, b).ratio()
892
- index = 0
893
- print_every = 100
894
- sequence_list = df[df['grna_seq'] != 'error']['seq'].unique().tolist()
895
- grna_error = df[df['grna_seq']=='error']
896
- df = grna_error.copy()
897
- similarity_dict = {}
898
- #change this so that it itterates throug each well
899
- for idx, row in df.iterrows():
900
- matches = 0
901
- match_string = None
902
- for string in sequence_list:
903
- index += 1
904
- if index % print_every == 0:
905
- print(f'Processed reads: {index}', end='\r', flush=True)
906
- ratio = similar(row['seq'], string)
907
- # check if only one character is different
908
- if ratio > ((len(row['seq']) - 1) / len(row['seq'])):
909
- matches += 1
910
- if matches > 1: # if we find more than one match, we break and don't add anything to the dictionary
911
- break
912
- match_string = string
913
- if matches == 1: # only add to the dictionary if there was exactly one match
914
- similarity_dict[row['seq']] = match_string
915
- return similarity_dict
916
-
917
- def translate_barecodes(df, grna_df, map_unmapped=False):
918
- print('==== Translating barecodes ====')
919
- if map_unmapped:
920
- similarity_dict = map_unmapped_grnas(df)
921
- df = df.assign(seq=df['seq'].map(similarity_dict).fillna('error'))
922
- df = df.groupby(['plate','row', 'col'])['grna_seq'].value_counts().reset_index(name='count')
923
- grna_dict = grna_df.set_index('Seq')['gene'].to_dict()
924
-
925
- plate_barcodes = {'AA':'p1','TT':'p2','CC':'p3','GG':'p4','AT':'p5','TA':'p6','CG':'p7','GC':'p8'}
926
-
927
- row_barcodes = {'AA':'r1','AT':'r2','AC':'r3','AG':'r4','TT':'r5','TA':'r6','TC':'r7','TG':'r8',
928
- 'CC':'r9','CA':'r10','CT':'r11','CG':'r12','GG':'r13','GA':'r14','GT':'r15','GC':'r16'}
929
-
930
- col_barcodes = {'AAA':'c1','TTT':'c2','CCC':'c3','GGG':'c4','AAT':'c5','AAC':'c6','AAG':'c7',
931
- 'TTA':'c8','TTC':'c9','TTG':'c10','CCA':'c11','CCT':'c12','CCG':'c13','GGA':'c14',
932
- 'CCT':'c15','GGC':'c16','ATT':'c17','ACC':'c18','AGG':'c19','TAA':'c20','TCC':'c21',
933
- 'TGG':'c22','CAA':'c23','CGG':'c24'}
934
-
935
-
936
- df['plate'] = df['plate'].map(plate_barcodes)
937
- df['row'] = df['row'].map(row_barcodes)
938
- df['col'] = df['col'].map(col_barcodes)
939
- df['grna'] = df['grna_seq'].map(grna_dict)
940
- df['gene'] = df['grna'].str.split('_').str[1]
941
- df = df.fillna('error')
942
- df['prc'] = df['plate']+'_'+df['row']+'_'+df['col']
943
- df = df[df['count']>=2]
944
- error_count = df[df.apply(lambda row: row.astype(str).str.contains('error').any(), axis=1)].shape[0]
945
- plate_error = df['plate'].str.contains('error').sum()/len(df)
946
- row_error = df['row'].str.contains('error').sum()/len(df)
947
- col_error = df['col'].str.contains('error').sum()/len(df)
948
- grna_error = df['grna'].str.contains('error').sum()/len(df)
949
- print(f'Matched: {len(df)} rows; Errors: plate:{plate_error*100:.3f}% row:{row_error*100:.3f}% column:{col_error*100:.3f}% gRNA:{grna_error*100:.3f}%')
950
- return df
951
-
952
- def vert_horiz(v, h, n_col):
953
- h = h+1
954
- if h not in [*range(0,n_col)]:
955
- v = v+1
956
- h = 0
957
- return v,h
958
-
959
- def plot_data(df, v, h, color, n_col, ax, x_axis, y_axis, fontsize=12, lw=2, ls='-', log_x=False, log_y=False, title=None):
960
- ax[v, h].plot(df[x_axis], df[y_axis], ls=ls, lw=lw, color=color, label=y_axis)
961
- ax[v, h].set_title(None)
962
- ax[v, h].set_xlabel(None)
963
- ax[v, h].set_ylabel(None)
964
- ax[v, h].legend(fontsize=fontsize)
965
-
966
- if log_x:
967
- ax[v, h].set_xscale('log')
968
- if log_y:
969
- ax[v, h].set_yscale('log')
970
- v,h =vert_horiz(v, h, n_col)
971
- return v, h
972
-
973
- def test_error(df, min_=25,max_=3025, metric='count',log_x=False, log_y=False):
974
- max_ = max_+min_
975
- step = math.sqrt(min_)
976
- plate_error_ls = []
977
- col_error_ls = []
978
- row_error_ls = []
979
- grna_error_ls = []
980
- prc_error_ls = []
981
- total_error_ls = []
982
- temp_len_ls = []
983
- val_ls = []
984
- df['sum_count'] = df.groupby('prc')['count'].transform('sum')
985
- df['fraction'] = df['count'] / df['sum_count']
986
- if metric=='fraction':
987
- range_ = np.arange(min_, max_, step).tolist()
988
- if metric=='count':
989
- range_ = [*range(int(min_),int(max_),int(step))]
990
- for val in range_:
991
- temp = pd.DataFrame(df[df[metric]>val])
992
- temp_len = len(temp)
993
- if temp_len == 0:
994
- break
995
- temp_len_ls.append(temp_len)
996
- error_count = temp[temp.apply(lambda row: row.astype(str).str.contains('error').any(), axis=1)].shape[0]/len(temp)
997
- plate_error = temp['plate'].str.contains('error').sum()/temp_len
998
- row_error = temp['row'].str.contains('error').sum()/temp_len
999
- col_error = temp['col'].str.contains('error').sum()/temp_len
1000
- prc_error = temp['prc'].str.contains('error').sum()/temp_len
1001
- grna_error = temp['gene'].str.contains('error').sum()/temp_len
1002
- #print(error_count, plate_error, row_error, col_error, prc_error, grna_error)
1003
- val_ls.append(val)
1004
- total_error_ls.append(error_count)
1005
- plate_error_ls.append(plate_error)
1006
- row_error_ls.append(row_error)
1007
- col_error_ls.append(col_error)
1008
- prc_error_ls.append(prc_error)
1009
- grna_error_ls.append(grna_error)
1010
- df2 = pd.DataFrame()
1011
- df2['val'] = val_ls
1012
- df2['plate'] = plate_error_ls
1013
- df2['row'] = row_error_ls
1014
- df2['col'] = col_error_ls
1015
- df2['gRNA'] = grna_error_ls
1016
- df2['prc'] = prc_error_ls
1017
- df2['total'] = total_error_ls
1018
- df2['len'] = temp_len_ls
1019
-
1020
- n_row, n_col = 2, 7
1021
- v, h, lw, ls, color = 0, 0, 1, '-', 'teal'
1022
- fig, ax = plt.subplots(n_row, n_col, figsize=(n_col*5, n_row*5))
1023
-
1024
- v, h = plot_data(df=df2, v=v, h=h, color=color, n_col=n_col, ax=ax, x_axis='val', y_axis='total',log_x=log_x, log_y=log_y)
1025
- v, h = plot_data(df=df2, v=v, h=h, color=color, n_col=n_col, ax=ax, x_axis='val', y_axis='prc',log_x=log_x, log_y=log_y)
1026
- v, h = plot_data(df=df2, v=v, h=h, color=color, n_col=n_col, ax=ax, x_axis='val', y_axis='plate',log_x=log_x, log_y=log_y)
1027
- v, h = plot_data(df=df2, v=v, h=h, color=color, n_col=n_col, ax=ax, x_axis='val', y_axis='row',log_x=log_x, log_y=log_y)
1028
- v, h = plot_data(df=df2, v=v, h=h, color=color, n_col=n_col, ax=ax, x_axis='val', y_axis='col',log_x=log_x, log_y=log_y)
1029
- v, h = plot_data(df=df2, v=v, h=h, color=color, n_col=n_col, ax=ax, x_axis='val', y_axis='gRNA',log_x=log_x, log_y=log_y)
1030
- v, h = plot_data(df=df2, v=v, h=h, color=color, n_col=n_col, ax=ax, x_axis='val', y_axis='len',log_x=log_x, log_y=log_y)
1031
-
1032
- def generate_fraction_map(df, gene_column, min_=10, plates=['p1','p2','p3','p4'], metric = 'count', plot=False):
1033
- df['prcs'] = df['prc']+''+df['grna_seq']
1034
- df['gene'] = df['grna'].str.split('_').str[1]
1035
- if metric == 'count':
1036
- df = pd.DataFrame(df[df['count']>min_])
1037
- df = df[~(df == 'error').any(axis=1)]
1038
- df = df[df['plate'].isin(plates)]
1039
- gRNA_well_count = df.groupby('prc')['prcs'].transform('nunique')
1040
- df['gRNA_well_count'] = gRNA_well_count
1041
- df = df[df['gRNA_well_count']>=2]
1042
- df = df[df['gRNA_well_count']<=100]
1043
- well_sum = df.groupby('prc')['count'].transform('sum')
1044
- df['well_sum'] = well_sum
1045
- df['gRNA_fraction'] = df['count']/df['well_sum']
1046
- if metric == 'fraction':
1047
- df = pd.DataFrame(df[df['gRNA_fraction']>=min_])
1048
- df = df[df['plate'].isin(plates)]
1049
- gRNA_well_count = df.groupby('prc')['prcs'].transform('nunique')
1050
- df['gRNA_well_count'] = gRNA_well_count
1051
- well_sum = df.groupby('prc')['count'].transform('sum')
1052
- df['well_sum'] = well_sum
1053
- df['gRNA_fraction'] = df['count']/df['well_sum']
1054
- if plot:
1055
- print('gRNAs/well')
1056
- plot_plates(df=df, variable='gRNA_well_count', grouping='mean', min_max='allq', cmap='viridis')
1057
- print('well read sum')
1058
- plot_plates(df=df, variable='well_sum', grouping='mean', min_max='allq', cmap='viridis')
1059
- genes = df[gene_column].unique().tolist()
1060
- wells = df['prc'].unique().tolist()
1061
- print('numer of genes:',len(genes),'numer of wells:', len(wells))
1062
- independent_variables = pd.DataFrame(columns=genes, index = wells)
1063
- for index, row in df.iterrows():
1064
- prc = row['prc']
1065
- gene = row[gene_column]
1066
- fraction = row['gRNA_fraction']
1067
- independent_variables.loc[prc,gene]=fraction
1068
- independent_variables = independent_variables.fillna(0.0)
1069
- independent_variables['sum'] = independent_variables.sum(axis=1)
1070
- independent_variables = independent_variables[independent_variables['sum']==1.0]
1071
- independent_variables = independent_variables.drop('sum', axis=1)
1072
- independent_variables.index.name = 'prc'
1073
- independent_variables = independent_variables.loc[:, (independent_variables.sum() != 0)]
1074
- return independent_variables
1075
-
1076
- def precess_reads(csv_path, fraction_threshold, plate):
1077
- # Read the CSV file into a DataFrame
1078
- csv_df = pd.read_csv(csv_path)
1079
-
1080
- # Ensure the necessary columns are present
1081
- if not all(col in csv_df.columns for col in ['grna', 'count', 'column']):
1082
- raise ValueError("The CSV file must contain 'grna', 'count', 'plate_row', and 'column' columns.")
1083
-
1084
- if 'plate_row' in csv_df.columns:
1085
- csv_df[['plate', 'row']] = csv_df['plate_row'].str.split('_', expand=True)
1086
- if plate is not None:
1087
- csv_df = csv_df.drop(columns=['plate'])
1088
- csv_df['plate'] = plate
1089
-
1090
- if plate is not None:
1091
- csv_df['plate'] = plate
1092
-
1093
- # Create the prc column
1094
- csv_df['prc'] = csv_df['plate'] + '_' + csv_df['row'] + '_' + csv_df['column']
1095
-
1096
- # Group by prc and calculate the sum of counts
1097
- grouped_df = csv_df.groupby('prc')['count'].sum().reset_index()
1098
- grouped_df = grouped_df.rename(columns={'count': 'total_counts'})
1099
- merged_df = pd.merge(csv_df, grouped_df, on='prc')
1100
- merged_df['fraction'] = merged_df['count'] / merged_df['total_counts']
1101
-
1102
- # Filter rows with fraction under the threshold
1103
- if fraction_threshold is not None:
1104
- observations_before = len(merged_df)
1105
- merged_df = merged_df[merged_df['fraction'] >= fraction_threshold]
1106
- observations_after = len(merged_df)
1107
- removed = observations_before - observations_after
1108
- print(f'Removed {removed} observation below fraction threshold: {fraction_threshold}')
1109
-
1110
- merged_df = merged_df[['prc', 'grna', 'fraction']]
1111
-
1112
- if not all(col in merged_df.columns for col in ['grna', 'gene']):
1113
- try:
1114
- merged_df[['org', 'gene', 'grna']] = merged_df['grna'].str.split('_', expand=True)
1115
- merged_df = merged_df.drop(columns=['org'])
1116
- merged_df['grna'] = merged_df['gene'] + '_' + merged_df['grna']
1117
- except:
1118
- print('Error splitting grna into org, gene, grna.')
1119
-
1120
- return merged_df
1121
-
1122
- def apply_transformation(X, transform):
1123
- if transform == 'log':
1124
- transformer = FunctionTransformer(np.log1p, validate=True)
1125
- elif transform == 'sqrt':
1126
- transformer = FunctionTransformer(np.sqrt, validate=True)
1127
- elif transform == 'square':
1128
- transformer = FunctionTransformer(np.square, validate=True)
1129
- else:
1130
- transformer = None
1131
- return transformer
1132
-
1133
- def check_normality(data, variable_name, verbose=False):
1134
- """Check if the data is normally distributed using the Shapiro-Wilk test."""
1135
- stat, p_value = shapiro(data)
1136
- if verbose:
1137
- print(f"Shapiro-Wilk Test for {variable_name}:\nStatistic: {stat}, P-value: {p_value}")
1138
- if p_value > 0.05:
1139
- if verbose:
1140
- print(f"The data for {variable_name} is normally distributed.")
1141
- return True
1142
- else:
1143
- if verbose:
1144
- print(f"The data for {variable_name} is not normally distributed.")
1145
- return False
1146
-
1147
- def process_scores(df, dependent_variable, plate, min_cell_count=25, agg_type='mean', transform=None, regression_type='ols'):
1148
-
1149
- if plate is not None:
1150
- df['plate'] = plate
1151
-
1152
- if 'col' not in df.columns:
1153
- df['col'] = df['column']
1154
-
1155
- df['prc'] = df['plate'] + '_' + df['row'] + '_' + df['col']
1156
- df = df[['prc', dependent_variable]]
321
+ def single_read_chunked_processing(r1_file, r2_file, regex, target_sequence, offset_start, expected_end, column_csv, grna_csv, row_csv, save_h5, comp_type, comp_level, hdf5_file, unique_combinations_csv, qc_csv_file, chunk_size=10000, n_jobs=None, test=False):
1157
322
 
1158
- # Group by prc and calculate the mean and count of the dependent_variable
1159
- grouped = df.groupby('prc')[dependent_variable]
1160
-
1161
- if regression_type != 'poisson':
1162
-
1163
- print(f'Using agg_type: {agg_type}')
1164
-
1165
- if agg_type == 'median':
1166
- dependent_df = grouped.median().reset_index()
1167
- elif agg_type == 'mean':
1168
- dependent_df = grouped.mean().reset_index()
1169
- elif agg_type == 'quantile':
1170
- dependent_df = grouped.quantile(0.75).reset_index()
1171
- elif agg_type == None:
1172
- dependent_df = df.reset_index()
1173
- if 'prcfo' in dependent_df.columns:
1174
- dependent_df = dependent_df.drop(columns=['prcfo'])
1175
- else:
1176
- raise ValueError(f"Unsupported aggregation type {agg_type}")
1177
-
1178
- if regression_type == 'poisson':
1179
- agg_type = 'count'
1180
- print(f'Using agg_type: {agg_type} for poisson regression')
1181
- dependent_df = grouped.sum().reset_index()
1182
-
1183
- # Calculate cell_count for all cases
1184
- cell_count = grouped.size().reset_index(name='cell_count')
1185
-
1186
- if agg_type is None:
1187
- dependent_df = pd.merge(dependent_df, cell_count, on='prc')
1188
- else:
1189
- dependent_df['cell_count'] = cell_count['cell_count']
1190
-
1191
- dependent_df = dependent_df[dependent_df['cell_count'] >= min_cell_count]
1192
-
1193
- is_normal = check_normality(dependent_df[dependent_variable], dependent_variable)
323
+ from .utils import count_reads_in_fastq, print_progress
1194
324
 
1195
- if not transform is None:
1196
- transformer = apply_transformation(dependent_df[dependent_variable], transform=transform)
1197
- transformed_var = f'{transform}_{dependent_variable}'
1198
- dependent_df[transformed_var] = transformer.fit_transform(dependent_df[[dependent_variable]])
1199
- dependent_variable = transformed_var
1200
- is_normal = check_normality(dependent_df[transformed_var], transformed_var)
325
+ # Use cpu_count minus 3 cores if n_jobs isn't specified
326
+ if n_jobs is None:
327
+ n_jobs = cpu_count() - 3
1201
328
 
1202
- if not is_normal:
1203
- print(f'{dependent_variable} is not normally distributed')
1204
- else:
1205
- print(f'{dependent_variable} is normally distributed')
329
+ chunk_count = 0
330
+ time_ls = []
1206
331
 
1207
- return dependent_df, dependent_variable
1208
-
1209
- def perform_mixed_model(y, X, groups, alpha=1.0):
1210
- # Ensure groups are defined correctly and check for multicollinearity
1211
- if groups is None:
1212
- raise ValueError("Groups must be defined for mixed model regression")
1213
-
1214
- # Check for multicollinearity by calculating the VIF for each feature
1215
- X_np = X.values
1216
- vif = [variance_inflation_factor(X_np, i) for i in range(X_np.shape[1])]
1217
- print(f"VIF: {vif}")
1218
- if any(v > 10 for v in vif):
1219
- print(f"Multicollinearity detected with VIF: {vif}. Applying Ridge regression to the fixed effects.")
1220
- ridge = Ridge(alpha=alpha)
1221
- ridge.fit(X, y)
1222
- X_ridge = ridge.coef_ * X # Adjust X with Ridge coefficients
1223
- model = MixedLM(y, X_ridge, groups=groups)
332
+ if not test:
333
+ print(f'Calculating read count for {r1_file}...')
334
+ total_reads = count_reads_in_fastq(r1_file)
335
+ chunks_nr = int(total_reads / chunk_size) + 1
1224
336
  else:
1225
- model = MixedLM(y, X, groups=groups)
1226
-
1227
- result = model.fit()
1228
- return result
1229
-
1230
- def regression_model(X, y, regression_type='ols', groups=None, alpha=1.0, remove_row_column_effect=True):
1231
-
1232
- if regression_type == 'ols':
1233
- model = sm.OLS(y, X).fit()
1234
-
1235
- elif regression_type == 'gls':
1236
- model = sm.GLS(y, X).fit()
1237
-
1238
- elif regression_type == 'wls':
1239
- weights = 1 / np.sqrt(X.iloc[:, 1])
1240
- model = sm.WLS(y, X, weights=weights).fit()
1241
-
1242
- elif regression_type == 'rlm':
1243
- model = sm.RLM(y, X, M=sm.robust.norms.HuberT()).fit()
1244
- #model = sm.RLM(y, X, M=sm.robust.norms.TukeyBiweight()).fit()
1245
- #model = sm.RLM(y, X, M=sm.robust.norms.Hampel()).fit()
1246
- #model = sm.RLM(y, X, M=sm.robust.norms.LeastSquares()).fit()
1247
- #model = sm.RLM(y, X, M=sm.robust.norms.RamsayE()).fit()
1248
- #model = sm.RLM(y, X, M=sm.robust.norms.TrimmedMean()).fit()
1249
-
1250
- elif regression_type == 'glm':
1251
- model = sm.GLM(y, X, family=sm.families.Gaussian()).fit() # Gaussian: Used for continuous data, similar to OLS regression.
1252
- #model = sm.GLM(y, X, family=sm.families.Binomial()).fit() # Binomial: Used for binary data, modeling the probability of success.
1253
- #model = sm.GLM(y, X, family=sm.families.Poisson()).fit() # Poisson: Used for count data.
1254
- #model = sm.GLM(y, X, family=sm.families.Gamma()).fit() # Gamma: Used for continuous, positive data, often for modeling waiting times or life data.
1255
- #model = sm.GLM(y, X, family=sm.families.InverseGaussian()).fit() # Inverse Gaussian: Used for positive continuous data with a variance that increases with the
1256
- #model = sm.GLM(y, X, family=sm.families.NegativeBinomial()).fit() # Negative Binomial: Used for count data with overdispersion (variance greater than the mean).
1257
- #model = sm.GLM(y, X, family=sm.families.Tweedie()).fit() # Tweedie: Used for data that can take both positive continuous and count values, allowing for a mixture of distributions.
1258
-
1259
- elif regression_type == 'mixed':
1260
- model = perform_mixed_model(y, X, groups, alpha=alpha)
1261
-
1262
- elif regression_type == 'quantile':
1263
- model = sm.QuantReg(y, X).fit(q=alpha)
1264
-
1265
- elif regression_type == 'logit':
1266
- model = sm.Logit(y, X).fit()
337
+ total_reads = chunk_size
338
+ chunks_nr = 1
1267
339
 
1268
- elif regression_type == 'probit':
1269
- model = sm.Probit(y, X).fit()
1270
-
1271
- elif regression_type == 'poisson':
1272
- model = sm.Poisson(y, X).fit()
340
+ print(f'Mapping barcodes for {total_reads} reads in {chunks_nr} batches for {r1_file}...')
1273
341
 
1274
- elif regression_type == 'lasso':
1275
- model = Lasso(alpha=alpha).fit(X, y)
342
+ # Queue for saving
343
+ save_queue = Queue()
1276
344
 
1277
- elif regression_type == 'ridge':
1278
- model = Ridge(alpha=alpha).fit(X, y)
345
+ # Start the saving process
346
+ save_process = Process(target=saver_process, args=(save_queue, hdf5_file, save_h5, unique_combinations_csv, qc_csv_file, comp_type, comp_level))
347
+ save_process.start()
1279
348
 
1280
- else:
1281
- raise ValueError(f"Unsupported regression type {regression_type}")
1282
-
1283
- if regression_type in ['lasso', 'ridge']:
1284
- y_pred = model.predict(X)
1285
- plt.scatter(X.iloc[:, 1], y, color='blue', label='Data')
1286
- plt.plot(X.iloc[:, 1], y_pred, color='red', label='Regression line')
1287
- plt.xlabel('Features')
1288
- plt.ylabel('Dependent Variable')
1289
- plt.legend()
1290
- plt.show()
1291
-
1292
- return model
1293
-
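
For reference, a short usage sketch of a few of the statsmodels estimators the removed regression_model helper dispatched between (synthetic data; this is not spacr API, just the underlying library calls):

    import numpy as np
    import pandas as pd
    import statsmodels.api as sm

    rng = np.random.default_rng(0)
    X = pd.DataFrame({'fraction': rng.uniform(size=50)})
    y = 2.0 * X['fraction'] + rng.normal(scale=0.1, size=50)
    X = sm.add_constant(X)                                       # adds the intercept column

    ols_fit = sm.OLS(y, X).fit()                                 # regression_type='ols'
    rlm_fit = sm.RLM(y, X, M=sm.robust.norms.HuberT()).fit()     # regression_type='rlm'
    glm_fit = sm.GLM(y, X, family=sm.families.Gaussian()).fit()  # regression_type='glm'
    quant_fit = sm.QuantReg(y, X).fit(q=0.5)                     # regression_type='quantile'
    print(ols_fit.params, rlm_fit.params, glm_fit.params, quant_fit.params, sep='\n')
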
1294
- def clean_controls(df,pc,nc,other):
1295
- if 'col' in df.columns:
1296
- df['column'] = df['col']
1297
- if nc != None:
1298
- df = df[~df['column'].isin([nc])]
1299
- if pc != None:
1300
- df = df[~df['column'].isin([pc])]
1301
- if other != None:
1302
- df = df[~df['column'].isin([other])]
1303
- print(f'Removed data from {nc, pc, other}')
1304
- return df
1305
-
1306
- # Remove outliers by capping values at 1st and 99th percentiles for numerical columns only
1307
- def remove_outliers(df, low=0.01, high=0.99):
1308
- numerical_cols = df.select_dtypes(include=[np.number]).columns
1309
- quantiles = df[numerical_cols].quantile([low, high])
1310
- for col in numerical_cols:
1311
- df[col] = np.clip(df[col], quantiles.loc[low, col], quantiles.loc[high, col])
1312
- return df
1313
-
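
The removed remove_outliers helper caps numeric columns at chosen quantiles rather than dropping rows. A minimal sketch of that percentile-clipping idea (the 1st/99th defaults come from the code above; the function name here is illustrative):

    import numpy as np
    import pandas as pd

    def cap_outliers(df: pd.DataFrame, low: float = 0.01, high: float = 0.99) -> pd.DataFrame:
        out = df.copy()
        numeric_cols = out.select_dtypes(include=[np.number]).columns
        quantiles = out[numeric_cols].quantile([low, high])
        for col in numeric_cols:
            # Clip each numeric column to its own low/high quantile.
            out[col] = out[col].clip(quantiles.loc[low, col], quantiles.loc[high, col])
        return out
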
1314
- def calculate_p_values(X, y, model):
1315
- # Predict y values
1316
- y_pred = model.predict(X)
1317
-
1318
- # Calculate residuals
1319
- residuals = y - y_pred
1320
-
1321
- # Calculate the standard error of the residuals
1322
- dof = X.shape[0] - X.shape[1] - 1
1323
- residual_std_error = np.sqrt(np.sum(residuals ** 2) / dof)
1324
-
1325
- # Calculate the standard error of the coefficients
1326
- X_design = np.hstack((np.ones((X.shape[0], 1)), X)) # Add intercept
1327
-
1328
- # Use pseudoinverse instead of inverse to handle singular matrices
1329
- coef_var_covar = residual_std_error ** 2 * np.linalg.pinv(X_design.T @ X_design)
1330
- coef_standard_errors = np.sqrt(np.diag(coef_var_covar))
1331
-
1332
- # Calculate t-statistics
1333
- t_stats = model.coef_ / coef_standard_errors[1:] # Skip intercept error
1334
-
1335
- # Calculate p-values
1336
- p_values = [2 * (1 - stats.t.cdf(np.abs(t), dof)) for t in t_stats]
1337
-
1338
- return np.array(p_values) # Ensure p_values is a 1-dimensional array
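
The removed calculate_p_values helper derives t-test p-values for sklearn coefficients from the residual variance and the pseudoinverse of X^T X. A compact, self-contained sketch of that calculation (assuming X already contains an intercept column and coefs are the fitted coefficients, intercept included):

    import numpy as np
    from scipy import stats

    def coef_p_values(X: np.ndarray, y: np.ndarray, coefs: np.ndarray) -> np.ndarray:
        residuals = y - X @ coefs
        dof = X.shape[0] - X.shape[1]                       # residual degrees of freedom
        sigma2 = residuals @ residuals / dof                # residual variance
        cov = sigma2 * np.linalg.pinv(X.T @ X)              # pseudoinverse tolerates singular X
        std_errors = np.sqrt(np.diag(cov))
        t_stats = coefs / std_errors
        return 2 * (1 - stats.t.cdf(np.abs(t_stats), dof))  # two-sided p-values
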
349
+ pool = Pool(n_jobs)
1339
350
 
1340
- def regression(df, csv_path, dependent_variable='predictions', regression_type=None, alpha=1.0, remove_row_column_effect=False):
351
+ with gzip.open(r1_file, 'rt') as r1:
352
+ while True:
353
+ start_time = time.time()
354
+ r1_chunk = []
1341
355
 
1342
- from .plot import volcano_plot, plot_histogram
356
+ for _ in range(chunk_size):
357
+ # Read the next 4 lines (one FASTQ record) from the R1 file
358
+ r1_lines = [r1.readline().strip() for _ in range(4)]
1343
359
 
1344
- volcano_filename = os.path.splitext(os.path.basename(csv_path))[0] + '_volcano_plot.pdf'
1345
- volcano_filename = regression_type+'_'+volcano_filename
1346
- if regression_type == 'quantile':
1347
- volcano_filename = str(alpha)+'_'+volcano_filename
1348
- volcano_path=os.path.join(os.path.dirname(csv_path), volcano_filename)
360
+ # Break if we've reached the end of either file
361
+ if not r1_lines[0]:
362
+ break
1349
363
 
1350
- is_normal = check_normality(df[dependent_variable], dependent_variable)
364
+ r1_chunk.append('\n'.join(r1_lines))
1351
365
 
1352
- if regression_type is None:
1353
- if is_normal:
1354
- regression_type = 'ols'
1355
- else:
1356
- regression_type = 'glm'
366
+ # If the chunks are empty, break the outer while loop
367
+ if not r1_chunk:
368
+ break
1357
369
 
1358
- #df = remove_outliers(df)
370
+ chunk_count += 1
371
+ chunk_data = (r1_chunk, regex, target_sequence, offset_start, expected_end, column_csv, grna_csv, row_csv)
1359
372
 
1360
- if remove_row_column_effect:
373
+ # Process chunks in parallel
374
+ result = pool.apply_async(process_chunk, (chunk_data,))
375
+ df, unique_combinations, qc_df = result.get()
1361
376
 
1362
- ## 1. Fit the initial model with row and column to estimate their effects
1363
- ## 2. Fit the initial model using the specified regression type
1364
- ## 3. Calculate the residuals
1365
- ### Residual calculation: Residuals are the differences between the observed and predicted values. This step checks if the initial_model has an attribute resid (residuals). If it does, it directly uses them. Otherwise, it calculates residuals manually by subtracting the predicted values from the observed values (y_with_row_col).
1366
- ## 4. Use the residuals as the new dependent variable in the final regression model without row and column
1367
- ### Formula creation: A new regression formula is created, excluding row and column effects, with residuals as the new dependent variable.
1368
- ### Matrix creation: dmatrices is used again to create new design matrices (X for independent variables and y for the new dependent variable, residuals) based on the new formula and the dataframe df.
1369
- #### Remove Confounding Effects:Variables like row and column can introduce systematic biases or confounding effects that might obscure the relationships between the dependent variable and the variables of interest (fraction:gene and fraction:grna).
1370
- #### By first estimating the effects of row and column and then using the residuals (the part of the dependent variable that is not explained by row and column), we can focus the final regression model on the relationships of interest without the interference from row and column.
377
+ # Queue the results for saving
378
+ save_queue.put((df, unique_combinations, qc_df))
1371
379
 
1372
- #### Reduce Multicollinearity: Including variables like row and column along with other predictors can sometimes lead to multicollinearity, where predictors are highly correlated with each other. This can make it difficult to determine the individual effect of each predictor.
1373
- #### By regressing out the effects of row and column first, we reduce potential multicollinearity issues in the final model.
1374
-
1375
- # Fit the initial model with row and column to estimate their effects
1376
- formula_with_row_col = f'{dependent_variable} ~ row + column'
1377
- y_with_row_col, X_with_row_col = dmatrices(formula_with_row_col, data=df, return_type='dataframe')
380
+ end_time = time.time()
381
+ chunk_time = end_time - start_time
382
+ time_ls.append(chunk_time)
383
+ print_progress(files_processed=chunk_count, files_to_process=chunks_nr, n_jobs=n_jobs, time_ls=time_ls, batch_size=chunk_size, operation_type="Mapping Barcodes")
1378
384
 
1379
- # Fit the initial model using the specified regression type
1380
- initial_model = regression_model(X_with_row_col, y_with_row_col, regression_type=regression_type, alpha=alpha)
385
+ if test:
386
+ print('First 100 rows of chunk 1:')
387
+ print(df[:100])
388
+ break
1381
389
 
1382
- # Calculate the residuals manually
1383
- if hasattr(initial_model, 'resid'):
1384
- df['residuals'] = initial_model.resid
1385
- else:
1386
- df['residuals'] = y_with_row_col.values.ravel() - initial_model.predict(X_with_row_col)
390
+ # Cleanup the pool
391
+ pool.close()
392
+ pool.join()
1387
393
 
1388
- # Use the residuals as the new dependent variable in the final regression model without row and column
1389
- formula_without_row_col = 'residuals ~ fraction:gene + fraction:grna'
1390
- y, X = dmatrices(formula_without_row_col, data=df, return_type='dataframe')
394
+ # Send stop signal to saver process
395
+ save_queue.put("STOP")
396
+ save_process.join()
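
The added chunked-processing function above follows a reader -> worker pool -> saver-process pattern: FASTQ records are read four lines at a time into chunks, each chunk is parsed by a multiprocessing Pool, and results are pushed onto a Queue drained by a dedicated saver process until a "STOP" sentinel arrives. A stripped-down, runnable sketch of the same pattern (file names and the per-chunk parser are placeholders, not spacr API):

    import gzip
    from multiprocessing import Pool, Process, Queue

    def parse_chunk(records):
        # Placeholder for the real per-chunk barcode extraction: just count reads.
        return len(records)

    def saver(queue, out_path):
        # Drain results until the "STOP" sentinel arrives, appending them to disk.
        with open(out_path, 'a') as out:
            while True:
                item = queue.get()
                if item == "STOP":
                    break
                out.write(f"{item}\n")

    def map_in_chunks(fastq_gz, out_path, chunk_size=10_000, n_jobs=4):
        queue = Queue()
        saver_proc = Process(target=saver, args=(queue, out_path))
        saver_proc.start()
        with Pool(n_jobs) as pool, gzip.open(fastq_gz, 'rt') as fh:
            while True:
                chunk = []
                for _ in range(chunk_size):
                    record = [fh.readline().strip() for _ in range(4)]  # one FASTQ record
                    if not record[0]:
                        break
                    chunk.append(record)
                if not chunk:
                    break
                queue.put(pool.apply_async(parse_chunk, (chunk,)).get())
        queue.put("STOP")
        saver_proc.join()

    if __name__ == '__main__':
        map_in_chunks('reads_R1.fastq.gz', 'chunk_counts.txt')  # hypothetical paths
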
1391
397
 
1392
- # Plot histogram of the residuals
1393
- plot_histogram(df, 'residuals')
398
+ def generate_barecode_mapping(settings={}):
1394
399
 
1395
- # Scale the independent variables and residuals
1396
- scaler_X = MinMaxScaler()
1397
- scaler_y = MinMaxScaler()
1398
- X = pd.DataFrame(scaler_X.fit_transform(X), columns=X.columns)
1399
- y = scaler_y.fit_transform(y)
400
+ from .settings import set_default_generate_barecode_mapping
401
+ from .utils import save_settings
402
+ from .io import parse_gz_files
1400
403
 
1401
- else:
1402
- formula = f'{dependent_variable} ~ fraction:gene + fraction:grna + row + column'
1403
- y, X = dmatrices(formula, data=df, return_type='dataframe')
404
+ settings = set_default_generate_barecode_mapping(settings)
405
+ save_settings(settings, name=f"sequencing_{settings['mode']}_{settings['single_direction']}", show=True)
1404
406
 
1405
- plot_histogram(y, dependent_variable)
407
+ regex = settings['regex']
1406
408
 
1407
- # Scale the independent variables and dependent variable
1408
- scaler_X = MinMaxScaler()
1409
- scaler_y = MinMaxScaler()
1410
- X = pd.DataFrame(scaler_X.fit_transform(X), columns=X.columns)
1411
- y = scaler_y.fit_transform(y)
409
+ print(f'Using regex: {regex} to extract barcode information')
1412
410
 
1413
- groups = df['prc'] if regression_type == 'mixed' else None
1414
- print(f'performing {regression_type} regression')
1415
- model = regression_model(X, y, regression_type=regression_type, groups=groups, alpha=alpha, remove_row_column_effect=remove_row_column_effect)
1416
-
1417
- # Get the model coefficients and p-values
1418
- if regression_type in ['ols','gls','wls','rlm','glm','mixed','quantile','logit','probit','poisson']:
1419
- coefs = model.params
1420
- p_values = model.pvalues
1421
-
1422
- coef_df = pd.DataFrame({
1423
- 'feature': coefs.index,
1424
- 'coefficient': coefs.values,
1425
- 'p_value': p_values.values
1426
- })
1427
- elif regression_type in ['ridge', 'lasso']:
1428
- coefs = model.coef_
1429
- coefs = np.array(coefs).flatten()
1430
- # Calculate p-values
1431
- p_values = calculate_p_values(X, y, model)
1432
- p_values = np.array(p_values).flatten()
1433
-
1434
- # Create a DataFrame for the coefficients and p-values
1435
- coef_df = pd.DataFrame({
1436
- 'feature': X.columns,
1437
- 'coefficient': coefs,
1438
- 'p_value': p_values})
1439
- else:
1440
- coefs = model.coef_
1441
- intercept = model.intercept_
1442
- feature_names = X.design_info.column_names
1443
-
1444
- coef_df = pd.DataFrame({
1445
- 'feature': feature_names,
1446
- 'coefficient': coefs
1447
- })
1448
- coef_df.loc[0, 'coefficient'] += intercept
1449
- coef_df['p_value'] = np.nan # Placeholder since sklearn doesn't provide p-values
1450
-
1451
- coef_df['-log10(p_value)'] = -np.log10(coef_df['p_value'])
1452
- coef_df_v = coef_df[coef_df['feature'] != 'Intercept']
1453
-
1454
- # Create the highlight column
1455
- coef_df['highlight'] = coef_df['feature'].apply(lambda x: '220950' in x)
1456
- coef_df = coef_df[~coef_df['feature'].str.contains('row|column')]
1457
- volcano_plot(coef_df, volcano_path)
1458
-
1459
- return model, coef_df
1460
-
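
A condensed sketch of the two-stage "remove_row_column_effect" procedure described in the comments of the removed regression function: fit the plate-position terms first, keep the residuals, then regress the residuals on the terms of interest. The dependent-variable name 'score' is illustrative; row, column, fraction, gene and grna follow the formulas above, and the sketch assumes df has no missing values in those columns.

    import pandas as pd
    import statsmodels.api as sm
    from patsy import dmatrices

    def residualized_fit(df: pd.DataFrame, dependent_variable: str = 'score'):
        # Stage 1: estimate row/column (plate-position) effects.
        y0, X0 = dmatrices(f'{dependent_variable} ~ row + column', data=df,
                           return_type='dataframe')
        stage1 = sm.OLS(y0, X0).fit()
        df = df.assign(residuals=stage1.resid.values)
        # Stage 2: model the variation that row/column could not explain.
        y1, X1 = dmatrices('residuals ~ fraction:gene + fraction:grna', data=df,
                           return_type='dataframe')
        return sm.OLS(y1, X1).fit()
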
1461
- def perform_regression(df, settings):
1462
-
1463
- from spacr.plot import plot_plates
1464
- from .utils import merge_regression_res_with_metadata
1465
- from .settings import get_perform_regression_default_settings
1466
-
1467
- reg_types = ['ols','gls','wls','rlm','glm','mixed','quantile','logit','probit','poisson','lasso','ridge']
1468
- if settings['regression_type'] not in reg_types:
1469
- print(f'Possible regression types: {reg_types}')
1470
- raise ValueError(f"Unsupported regression type {settings['regression_type']}")
1471
-
1472
- if isinstance(df, str):
1473
- df = pd.read_csv(df)
1474
- elif isinstance(df, pd.DataFrame):
1475
- pass
1476
- else:
1477
- raise ValueError("Data must be a DataFrame or a path to a CSV file")
1478
-
1479
-
1480
- if settings['dependent_variable'] not in df.columns:
1481
- print(f'Columns in DataFrame:')
1482
- for col in df.columns:
1483
- print(col)
1484
- raise ValueError(f"Dependent variable {settings['dependent_variable']} not found in the DataFrame")
1485
-
1486
- results_filename = os.path.splitext(os.path.basename(settings['gene_weights_csv']))[0] + '_results.csv'
1487
- hits_filename = os.path.splitext(os.path.basename(settings['gene_weights_csv']))[0] + '_results_significant.csv'
1488
-
1489
- results_filename = settings['regression_type']+'_'+results_filename
1490
- hits_filename = settings['regression_type']+'_'+hits_filename
1491
- if settings['regression_type'] == 'quantile':
1492
- results_filename = str(settings['alpha'])+'_'+results_filename
1493
- hits_filename = str(settings['alpha'])+'_'+hits_filename
1494
- results_path=os.path.join(os.path.dirname(settings['gene_weights_csv']), results_filename)
1495
- hits_path=os.path.join(os.path.dirname(settings['gene_weights_csv']), hits_filename)
1496
-
1497
- settings = get_perform_regression_default_settings(settings)
411
+ samples_dict = parse_gz_files(settings['src'])
1498
412
 
1499
- settings_df = pd.DataFrame(list(settings.items()), columns=['Key', 'Value'])
1500
- settings_dir = os.path.dirname(settings['gene_weights_csv'])
1501
- settings_csv = os.path.join(settings_dir,f"{settings['regression_type']}_regression_settings.csv")
1502
- settings_df.to_csv(settings_csv, index=False)
1503
- display(settings_df)
413
+ print(f'If compression is low and save_h5 is True, saving might take longer than processing.')
1504
414
 
1505
- df = clean_controls(df,settings['pc'],settings['nc'],settings['other'])
415
+ for key in samples_dict:
416
+ if (settings['mode'] == 'paired' and samples_dict[key]['R1'] and samples_dict[key]['R2']) or (settings['mode'] == 'single' and samples_dict[key]['R1']) or (settings['mode'] == 'single' and samples_dict[key]['R2']):
417
+ key_mode = f"{key}_{settings['mode']}"
418
+ if settings['mode'] == 'single':
419
+ key_mode = f"{key_mode}_{settings['single_direction']}"
420
+ dst = os.path.join(settings['src'], key_mode)
421
+ hdf5_file = os.path.join(dst, 'annotated_reads.h5')
422
+ unique_combinations_csv = os.path.join(dst, 'unique_combinations.csv')
423
+ qc_csv_file = os.path.join(dst, 'qc.csv')
424
+ os.makedirs(dst, exist_ok=True)
1506
425
 
1507
- if 'prediction_probability_class_1' in df.columns:
1508
- if not settings['class_1_threshold'] is None:
1509
- df['predictions'] = (df['prediction_probability_class_1'] >= settings['class_1_threshold']).astype(int)
426
+ print(f'Analyzing reads from sample {key}')
1510
427
 
1511
- dependent_df, dependent_variable = process_scores(df, settings['dependent_variable'], settings['plate'], settings['min_cell_count'], settings['agg_type'], settings['transform'])
1512
-
1513
- display(dependent_df)
1514
-
1515
- independent_df = precess_reads(settings['gene_weights_csv'], settings['fraction_threshold'], settings['plate'])
1516
- display(independent_df)
1517
-
1518
- merged_df = pd.merge(independent_df, dependent_df, on='prc')
1519
-
1520
- merged_df[['plate', 'row', 'column']] = merged_df['prc'].str.split('_', expand=True)
1521
-
1522
- if settings['transform'] is None:
1523
- _ = plot_plates(df, variable=dependent_variable, grouping='mean', min_max='allq', cmap='viridis', min_count=settings['min_cell_count'])
428
+ if settings['mode'] == 'paired':
429
+ function = paired_read_chunked_processing
430
+ R1=samples_dict[key]['R1']
431
+ R2=samples_dict[key]['R2']
432
+
433
+ elif settings['mode'] == 'single':
434
+ function = single_read_chunked_processing
435
+
436
+ if settings['single_direction'] == 'R1':
437
+ R1=samples_dict[key]['R1']
438
+ R2=None
439
+ elif settings['single_direction'] == 'R2':
440
+ R1=samples_dict[key]['R2']
441
+ R2=None
442
+
443
+ function(r1_file=R1,
444
+ r2_file=R2,
445
+ regex=regex,
446
+ target_sequence=settings['target_sequence'],
447
+ offset_start=settings['offset_start'],
448
+ expected_end=settings['expected_end'],
449
+ column_csv=settings['column_csv'],
450
+ grna_csv=settings['grna_csv'],
451
+ row_csv=settings['row_csv'],
452
+ save_h5 = settings['save_h5'],
453
+ comp_type = settings['comp_type'],
454
+ comp_level=settings['comp_level'],
455
+ hdf5_file=hdf5_file,
456
+ unique_combinations_csv=unique_combinations_csv,
457
+ qc_csv_file=qc_csv_file,
458
+ chunk_size=settings['chunk_size'],
459
+ n_jobs=settings['n_jobs'],
460
+ test=settings['test'])
461
+
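
A hedged usage sketch for generate_barecode_mapping (paths are placeholders; the keys mirror those read from settings above, and anything left out is assumed to be filled in by set_default_generate_barecode_mapping):

    from spacr.sequencing import generate_barecode_mapping

    settings = {
        'src': '/path/to/fastq_dir',   # folder containing the gzipped FASTQ files
        'mode': 'single',              # 'single' or 'paired'
        'single_direction': 'R1',      # which read to use in single mode
        'chunk_size': 100000,          # reads per processing batch
        'n_jobs': 4,                   # worker processes
        'test': True,                  # process only the first chunk and print it
    }
    generate_barecode_mapping(settings)
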
462
+ # Function to read the CSV, compute reverse complement, and save it
463
+ def barecodes_reverse_complement(csv_file):
464
+
465
+ def reverse_complement(sequence):
466
+ complement = {'A': 'T', 'T': 'A', 'G': 'C', 'C': 'G', 'N': 'N'}
467
+ return ''.join(complement[base] for base in reversed(sequence))
468
+
469
+ # Read the CSV file
470
+ df = pd.read_csv(csv_file)
1524
471
 
1525
- model, coef_df = regression(merged_df, settings['gene_weights_csv'], dependent_variable, settings['regression_type'], settings['alpha'], settings['remove_row_column_effect'])
1526
-
1527
- coef_df.to_csv(results_path, index=False)
1528
-
1529
- if settings['regression_type'] == 'lasso':
1530
- significant = coef_df[coef_df['coefficient'] > 0]
1531
-
1532
- else:
1533
- significant = coef_df[coef_df['p_value']<= 0.05]
1534
- #significant = significant[significant['coefficient'] > 0.1]
1535
- significant.sort_values(by='coefficient', ascending=False, inplace=True)
1536
- significant = significant[~significant['feature'].str.contains('row|column')]
1537
-
1538
- if settings['regression_type'] == 'ols':
1539
- print(model.summary())
1540
-
1541
- significant.to_csv(hits_path, index=False)
472
+ # Compute reverse complement for each sequence
473
+ df['sequence'] = df['sequence'].apply(reverse_complement)
1542
474
 
1543
- me49 = '/home/carruthers/Documents/TGME49_Summary.csv'
1544
- gt1 = '/home/carruthers/Documents/TGGT1_Summary.csv'
475
+ # Create the new filename
476
+ file_dir, file_name = os.path.split(csv_file)
477
+ file_name_no_ext = os.path.splitext(file_name)[0]
478
+ new_filename = os.path.join(file_dir, f"{file_name_no_ext}_RC.csv")
1545
479
 
1546
- _ = merge_regression_res_with_metadata(hits_path, me49, name='_me49_metadata')
1547
- _ = merge_regression_res_with_metadata(hits_path, gt1, name='_gt1_metadata')
1548
- _ = merge_regression_res_with_metadata(results_path, me49, name='_me49_metadata')
1549
- _ = merge_regression_res_with_metadata(results_path, gt1, name='_gt1_metadata')
480
+ # Save the DataFrame with the reverse complement sequences
481
+ df.to_csv(new_filename, index=False)
1550
482
 
1551
- print('Significant Genes')
1552
- display(significant)
1553
- return coef_df
483
+ print(f"Reverse complement file saved as {new_filename}")
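
As an aside, Biopython's Seq class provides the same operation as the hand-rolled complement map above; a quick equivalence check (assuming Biopython is installed):

    from Bio.Seq import Seq

    def reverse_complement(sequence):
        complement = {'A': 'T', 'T': 'A', 'G': 'C', 'C': 'G', 'N': 'N'}
        return ''.join(complement[base] for base in reversed(sequence))

    # Both approaches yield 'NACGT' for the input 'ACGTN'.
    assert reverse_complement('ACGTN') == str(Seq('ACGTN').reverse_complement())
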