spacr 0.0.20__py3-none-any.whl → 0.0.35__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
spacr/foldseek.py ADDED
@@ -0,0 +1,793 @@
+ import os, shutil, subprocess, tarfile, glob, requests, time, random
+ 
+ import numpy as np
+ import pandas as pd
+ import seaborn as sns
+ import matplotlib.pyplot as plt
+ from matplotlib.ticker import FixedLocator
+ from scipy.stats import fisher_exact
+ from statsmodels.stats.multitest import multipletests
+ from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score
+ from concurrent.futures import ProcessPoolExecutor, as_completed
+ 
+ def run_command(command):
+     print(f"Executing: {command}")
+     result = subprocess.run(command, shell=True, capture_output=True, text=True)
+     if result.returncode != 0:
+         print(f"Error running command: {command}")
+         print(result.stdout)
+         print(result.stderr)
+         return False
+     return True
+ 
+ def add_headers_and_save_csv(input_tsv_path, output_csv_path, results_dir):
+ 
+     headers = [
+         'query', 'target', 'fident', 'alnlen', 'mismatch', 'gapopen',
+         'qstart', 'qend', 'tstart', 'tend', 'evalue', 'bits'
+     ]
+ 
+     # Fall back to the aln_tmscore file in results_dir when no explicit path is given
+     if input_tsv_path is None:
+         input_tsv_path = f"{results_dir}/aln_tmscore"
+ 
+     # Rename the aln_tmscore file to have a .tsv extension if it doesn't already
+     if not input_tsv_path.endswith('.tsv'):
+         os.rename(input_tsv_path, input_tsv_path + '.tsv')
+         input_tsv_path += '.tsv'
+ 
+     # Read the TSV file into a DataFrame
+     df = pd.read_csv(input_tsv_path, sep='\t', header=None)
+ 
+     # Assign headers to the DataFrame
+     df.columns = headers
+ 
+     # Save the DataFrame as a CSV file
+     df.to_csv(output_csv_path, index=False)
+     print(f"File saved as {output_csv_path}")
+ 
+ def generate_database(path, base_dir, mode='file'):
+     structures_dir = f'{base_dir}/structures'
+     os.makedirs(structures_dir, exist_ok=True)
+ 
+     if mode == 'tar':
+         if os.path.exists(structures_dir) and not os.listdir(structures_dir):
+             if not os.path.exists(path):
+                 print(f"Structure tar file {path} not found.")
+             else:
+                 tar = tarfile.open(path)
+                 tar.extractall(path=structures_dir)
+                 tar.close()
+         if not run_command(f"foldseek createdb {structures_dir} {structures_dir}/structures_db"):
+             raise Exception("Failed to create structures database.")
+ 
+     if mode == 'file':
+         if os.path.exists(structures_dir) and not os.listdir(structures_dir):
+             if not os.path.exists(path):
+                 print(f"Structure folder {path} not found.")
+             else:
+                 for file in os.listdir(path):
+                     file_path = os.path.join(path, file)
+                     new_path = os.path.join(structures_dir, file)
+                     #print(path)
+                     #print(structures_dir)
+                     shutil.copy(file_path, new_path)
+ 
+         if not run_command(f"foldseek createdb {structures_dir} {structures_dir}/structures_db"):
+             raise Exception("Failed to create structures database.")
+     return structures_dir
+ 
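A minimal sketch of building the query database from a folder of predicted structures (the paths below are hypothetical and the foldseek binary must be installed and on PATH):

# Illustrative sketch only: hypothetical paths; requires the foldseek binary on PATH.
from spacr.foldseek import generate_database

structures_dir = generate_database(
    path="/data/my_proteome/cif",    # hypothetical folder of .cif/.pdb files
    base_dir="/data/foldseek_run",   # hypothetical working directory
    mode="file",
)
# structures_dir now contains the copied structures and the Foldseek DB files (structures_db*)
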
+ def align_to_database(structure_fldr_path, base_dir='/home/carruthers/foldseek', cores=25):
+ 
+     databases_dir = f'{base_dir}/foldseek_databases'
+     results_dir = f'{base_dir}/results'
+     tmp_dir = f'{base_dir}/tmp'
+ 
+     os.makedirs(databases_dir, exist_ok=True)
+     os.makedirs(results_dir, exist_ok=True)
+     os.makedirs(tmp_dir, exist_ok=True)
+ 
+     # Check and download PDB database if not exists
+     pdb_db_path = os.path.join(databases_dir, "pdb")
+     if not os.path.exists(pdb_db_path):
+         print("Downloading PDB database...")
+         if not run_command(f"foldseek databases PDB {pdb_db_path} {tmp_dir}"):
+             raise Exception("Failed to download PDB database.")
+ 
+     # Check and download AlphaFold database if not exists
+     afdb_db_path = os.path.join(databases_dir, "afdb")
+     if not os.path.exists(afdb_db_path):
+         print("Downloading AlphaFold database...")
+         if not run_command(f"foldseek databases Alphafold/Proteome {afdb_db_path} {tmp_dir}"):
+             raise Exception("Failed to download AlphaFold database.")
+ 
+     structures_dir = generate_database(structure_fldr_path, base_dir, mode='file')
+ 
+     for i, targetDB in enumerate([pdb_db_path, afdb_db_path]):
+ 
+         if i == 0:
+             results_dir = os.path.join(base_dir, 'results', "pdb")
+             os.makedirs(results_dir, exist_ok=True)
+             print("Running Foldseek on PDB...")
+         if i == 1:
+             results_dir = os.path.join(base_dir, 'results', "afdb")
+             os.makedirs(results_dir, exist_ok=True)
+             print("Running Foldseek on AFdb...")
+ 
+         aln_tmscore = f"{results_dir}/aln_tmscore"
+         aln_tmscore_tsv = f"{results_dir}/aln_tmscore.tsv"
+ 
+         queryDB = f"{structures_dir}/structures_db"
+         aln = f"{results_dir}/results"
+ 
+         if not run_command(f"foldseek search {queryDB} {targetDB} {aln} {tmp_dir} -a --threads {cores}"):
+             raise Exception(f"Foldseek search against {targetDB} failed.")
+ 
+         if not run_command(f"foldseek aln2tmscore {queryDB} {targetDB} {aln} {aln_tmscore} --threads {cores}"):
+             raise Exception(f"Foldseek aln2tmscore against {targetDB} failed.")
+ 
+         output_format = "query,target,fident,alnlen,mismatch,gapopen,qstart,qend,tstart,tend,evalue,bits"
+ 
+         if not run_command(f"foldseek createtsv {queryDB} {targetDB} {aln} {aln_tmscore} {aln_tmscore_tsv} --format-output {output_format}"):
+             raise Exception(f"Foldseek createtsv against {targetDB} failed.")
+ 
+         input_tsv_path = f"{results_dir}/aln_tmscore"
+         output_csv_path = f"{results_dir}/aln_tmscore.csv"
+ 
+         # Add column headers and convert this database's TSV to CSV
+         add_headers_and_save_csv(input_tsv_path, output_csv_path, results_dir)
+ 
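A usage sketch for the full alignment step (paths and core count are hypothetical; on first use this downloads the PDB and AlphaFold/Proteome Foldseek databases, which are large):

# Illustrative sketch only: hypothetical paths and core count.
from spacr.foldseek import align_to_database

align_to_database(
    structure_fldr_path="/data/my_proteome/cif",
    base_dir="/data/foldseek_run",
    cores=8,
)
# Expected outputs: /data/foldseek_run/results/pdb/aln_tmscore.csv
#                   /data/foldseek_run/results/afdb/aln_tmscore.csv
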
+ def check_uniprot_structure(uniprot_id):
+     base_url = "https://www.ebi.ac.uk/proteins/api/proteins"
+     headers = {"Accept": "application/json"}
+     response = requests.get(f"{base_url}/{uniprot_id}", headers=headers)
+     if response.status_code == 200:
+         data = response.json()
+         print(data)  # Print the whole JSON to examine its structure
+     else:
+         print(f"Failed to retrieve data for {uniprot_id}: {response.status_code}")
+ 
+ def get_ec_numbers(data):
+     try:
+         # Navigate through the nested structure with checks at each step
+         protein_info = data.get('protein', {})
+         recommended_name = protein_info.get('recommendedName', {})
+         ec_numbers = recommended_name.get('ecNumber', [])
+ 
+         # Extract the 'value' field from each EC number entry
+         return ", ".join(ec['value'] for ec in ec_numbers if 'value' in ec)
+     except Exception as e:
+         print(f"Failed to extract EC numbers: {str(e)}")
+         return ""
+ 
+ def process_protein_data(data, verbose=False):
+     if data is None:
+         return None
+ 
+     uniprot_id = data.get('accession')
+     protein_data = {}
+     protein_data[uniprot_id] = {
+         'UniProt ID': uniprot_id,
+         'Entry Name': data.get('id'),
+         'Organism': next((name['value'] for name in data.get('organism', {}).get('names', []) if name['type'] == 'scientific'), None),
+         #'Taxonomic Lineage': ", ".join(data.get('organism', {}).get('lineage', [])),
+         'Taxonomy ID': data.get('organism', {}).get('taxonomy'),
+         'Sequence Length': data.get('sequence', {}).get('length'),
+         #'EC Number': ", ".join([ec['value'] for ec in data.get('protein', {}).get('recommendedName', {}).get('ecNumber', [])]),
+         'EC Number': get_ec_numbers(data),
+         'Function': "; ".join([func['text'][0]['value'] for func in data.get('comments', []) if func['type'] == 'FUNCTION']),
+         'Recommended Name': data.get('protein', {}).get('recommendedName', {}).get('fullName', {}).get('value', ''),
+         'Alternative Names': "; ".join([alt['fullName']['value'] for alt in data.get('protein', {}).get('alternativeName', [])]),
+         'GO Biological Process': [],
+         'GO Cellular Component': [],
+         'GO Molecular Function': [],
+         'GO IDs': [],
+         'KEGG': [],
+         'OrthoDB': [],
+         'Sequence': data.get('sequence', {}).get('sequence', ''),
+         'Family and Domains': {},
+         'Catalytic Activity': "; ".join([cat['reaction']['name'] for cat in data.get('comments', []) if cat['type'] == 'CATALYTIC_ACTIVITY']),
+         'Cofactor': "; ".join([cof['cofactors'][0]['name'] for cof in data.get('comments', []) if cof['type'] == 'COFACTOR']),
+         'Enzyme Regulation': "; ".join([reg['text'][0]['value'] for reg in data.get('comments', []) if reg['type'] == 'ENZYME_REGULATION']),
+         'Disease Association': "; ".join([dis['text'][0]['value'] for dis in data.get('comments', []) if dis['type'] == 'DISEASE']),
+         'Interaction Partners': "; ".join([inter['id'] for inter in data.get('dbReferences', []) if inter['type'] == 'InterPro'])
+     }
+ 
+     # Subcellular Location processing
+     protein_data[uniprot_id].update({
+         'sub_loc_Intermembrane': "",
+         'sub_loc_Topological_Domain': "",
+         'sub_loc_Subcellular_Location': "",
+         'sub_loc_Transmembrane': ""
+     })
+ 
+     for loc in data.get('comments', []):
+         if loc['type'] == 'SUBCELLULAR_LOCATION':
+             for component in loc.get('locations', []):
+                 if 'topology' in component:
+                     protein_data[uniprot_id]['sub_loc_Topological_Domain'] += component['topology']['value'] + "; "
+                 if 'orientation' in component:
+                     protein_data[uniprot_id]['sub_loc_Intermembrane'] += component['orientation']['value'] + "; "
+                 if 'location' in component:
+                     protein_data[uniprot_id]['sub_loc_Subcellular_Location'] += component['location']['value'] + "; "
+                 if 'subcellularLocation' in component:
+                     protein_data[uniprot_id]['sub_loc_Transmembrane'] += component['subcellularLocation']['value'] + "; "
+ 
+     # Initialize PTM/Processing details
+     ptms = set(ptm['type'] for ptm in data.get('features', []) if ptm['category'] == 'PTM')
+     for ptm in ptms:
+         protein_data[uniprot_id][ptm] = []
+ 
+     # Process each PTM type
+     for ptm in data.get('features', []):
+         if ptm['category'] == 'PTM' and ptm['type'] in protein_data[uniprot_id]:
+             ptm_description = ptm.get('description', '')
+             ptm_details = f"{ptm_description} (positions {ptm.get('begin')} to {ptm.get('end')})"
+             protein_data[uniprot_id][ptm['type']].append(ptm_details)
+ 
+     # Gene Ontology Annotations
+     for go in data.get('dbReferences', []):
+         if go['type'] == 'GO' and 'properties' in go:
+             go_term = go['properties']['term']
+             if go_term.startswith('P:'):
+                 protein_data[uniprot_id]['GO Biological Process'].append(go_term[2:])
+             elif go_term.startswith('C:'):
+                 protein_data[uniprot_id]['GO Cellular Component'].append(go_term[2:])
+             elif go_term.startswith('F:'):
+                 protein_data[uniprot_id]['GO Molecular Function'].append(go_term[2:])
+             protein_data[uniprot_id]['GO IDs'].append(go['id'])
+ 
+     # External sources
+     for xref in data.get('dbReferences', []):
+         if xref['type'] == 'KEGG':
+             protein_data[uniprot_id]['KEGG'].append(xref['id'])
+         elif xref['type'] == 'OrthoDB':
+             protein_data[uniprot_id]['OrthoDB'].append(xref['id'])
+ 
+     # Initialize Family and Domains from 'features'
+     for feature in data.get('features', []):
+         if feature['type'] in ['DOMAIN', 'MOTIF', 'REGION']:
+             domain_key = f"{feature['type']} {feature.get('description', 'N/A')}"
+             if domain_key not in protein_data[uniprot_id]:
+                 protein_data[uniprot_id][domain_key] = f"Positions {feature.get('begin')} to {feature.get('end')}"
+     if verbose:
+         print(protein_data)
+     return protein_data
+ 
+ def fetch_data_for_uniprot_id(uniprot_id):
+     """ Fetch data for a single UniProt ID from the UniProt API. """
+     base_url = "https://www.ebi.ac.uk/proteins/api/proteins"
+     headers = {"Accept": "application/json"}
+     request_url = f"{base_url}/{uniprot_id}"
+     response = requests.get(request_url, headers=headers)
+     if response.status_code == 200:
+         return response.json()
+     else:
+         print(f"Failed to retrieve data for {uniprot_id}: {response.status_code}")
+         return None
+ 
+ def fetch_and_aggregate_functional_data(uniprot_ids, num_workers=4):
+     """
+     Fetch and process functional data for a list of UniProt IDs using multiple processes.
+     """
+     # Create a process pool to fetch data asynchronously
+     with ProcessPoolExecutor(max_workers=num_workers) as executor:
+         # Map each UniProt ID to a future object responsible for fetching and processing its data
+         future_to_uniprot = {executor.submit(fetch_data_for_uniprot_id, uid): uid for uid in uniprot_ids}
+ 
+         # Dictionary to hold processed protein data
+         protein_data = {}
+ 
+         # Collect results as they are completed
+         for future in as_completed(future_to_uniprot):
+             data = future.result()
+             if data:
+                 processed_data = process_protein_data(data)
+                 if processed_data:
+                     # Each key in processed_data is a UniProt ID and the value a dictionary of attributes
+                     protein_data.update(processed_data)  # Merge the processed data into the main dictionary
+ 
+     # Convert the accumulated dictionary into a pandas DataFrame
+     df = pd.DataFrame.from_dict(protein_data, orient='index')
+ 
+     return df
+ 
+ def get_unique_uniprot_ids(mapping):
+     # Extract all UniProt IDs from the mapping
+     all_uniprot_ids = set(mapping.values())  # This gets all the unique values (UniProt IDs)
+     return list(all_uniprot_ids)
+ 
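A small sketch of the UniProt helpers above (requires network access to the EBI Proteins API; the accessions are only examples):

# Illustrative sketch only: example accessions, live API calls.
from spacr.foldseek import fetch_and_aggregate_functional_data

df = fetch_and_aggregate_functional_data(["P69905", "P68871"], num_workers=2)
print(df[['UniProt ID', 'Entry Name', 'Organism', 'Sequence Length']])
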
+ def pdb_to_uniprot(pdb_chain_map=None):
+ 
+     import re, time, json, zlib, requests
+     from xml.etree import ElementTree
+     from urllib.parse import urlparse, parse_qs, urlencode
+     from requests.adapters import HTTPAdapter, Retry
+ 
+     if pdb_chain_map is None:
+         pdb_chain_map = {}
+ 
+     POLLING_INTERVAL = 3
+     API_URL = "https://rest.uniprot.org"
+     retries = Retry(total=5, backoff_factor=0.25, status_forcelist=[500, 502, 503, 504])
+     session = requests.Session()
+     session.mount("https://", HTTPAdapter(max_retries=retries))
+ 
+     # The maximum number of IDs we can submit in one request
+     MAX_IDS_PER_REQUEST = 90000
+ 
+     def check_response(response):
+         try:
+             response.raise_for_status()
+         except requests.HTTPError:
+             print(response.json())
+             raise
+ 
+     def submit_id_mapping(from_db, to_db, ids):
+         request = requests.post(
+             f"{API_URL}/idmapping/run",
+             data={"from": from_db, "to": to_db, "ids": ",".join(ids)},
+         )
+         check_response(request)
+         return request.json()["jobId"]
+ 
+     def get_next_link(headers):
+         re_next_link = re.compile(r'<(.+)>; rel="next"')
+         if "Link" in headers:
+             match = re_next_link.match(headers["Link"])
+             if match:
+                 return match.group(1)
+ 
+     def check_id_mapping_results_ready(job_id):
+         while True:
+             request = session.get(f"{API_URL}/idmapping/status/{job_id}")
+             check_response(request)
+             j = request.json()
+             if "jobStatus" in j:
+                 if j["jobStatus"] in ("NEW", "RUNNING"):
+                     print(f"Retrying in {POLLING_INTERVAL}s")
+                     time.sleep(POLLING_INTERVAL)
+                 else:
+                     raise Exception(j["jobStatus"])
+             else:
+                 return bool(j["results"] or j["failedIds"])
+ 
+     def get_batch(batch_response, file_format, compressed):
+         batch_url = get_next_link(batch_response.headers)
+         while batch_url:
+             batch_response = session.get(batch_url)
+             batch_response.raise_for_status()
+             yield decode_results(batch_response, file_format, compressed)
+             batch_url = get_next_link(batch_response.headers)
+ 
+     def combine_batches(all_results, batch_results, file_format):
+         if file_format == "json":
+             for key in ("results", "failedIds"):
+                 if key in batch_results and batch_results[key]:
+                     all_results[key] += batch_results[key]
+         elif file_format == "tsv":
+             return all_results + batch_results[1:]
+         else:
+             return all_results + batch_results
+         return all_results
+ 
+     def get_id_mapping_results_link(job_id):
+         url = f"{API_URL}/idmapping/details/{job_id}"
+         request = session.get(url)
+         check_response(request)
+         return request.json()["redirectURL"]
+ 
+     def decode_results(response, file_format, compressed):
+         if compressed:
+             decompressed = zlib.decompress(response.content, 16 + zlib.MAX_WBITS)
+             if file_format == "json":
+                 j = json.loads(decompressed.decode("utf-8"))
+                 return j
+             elif file_format == "tsv":
+                 return [line for line in decompressed.decode("utf-8").split("\n") if line]
+             elif file_format == "xlsx":
+                 return [decompressed]
+             elif file_format == "xml":
+                 return [decompressed.decode("utf-8")]
+             else:
+                 return decompressed.decode("utf-8")
+         elif file_format == "json":
+             return response.json()
+         elif file_format == "tsv":
+             return [line for line in response.text.split("\n") if line]
+         elif file_format == "xlsx":
+             return [response.content]
+         elif file_format == "xml":
+             return [response.text]
+         return response.text
+ 
+     def get_xml_namespace(element):
+         m = re.match(r"\{(.*)\}", element.tag)
+         return m.groups()[0] if m else ""
+ 
+     def merge_xml_results(xml_results):
+         merged_root = ElementTree.fromstring(xml_results[0])
+         for result in xml_results[1:]:
+             root = ElementTree.fromstring(result)
+             for child in root.findall("{http://uniprot.org/uniprot}entry"):
+                 merged_root.insert(-1, child)
+         ElementTree.register_namespace("", get_xml_namespace(merged_root[0]))
+         return ElementTree.tostring(merged_root, encoding="utf-8", xml_declaration=True)
+ 
+     def print_progress_batches(batch_index, size, total):
+         n_fetched = min((batch_index + 1) * size, total)
+         print(f"Fetched: {n_fetched} / {total}")
+ 
+     def get_id_mapping_results_search(url):
+         parsed = urlparse(url)
+         query = parse_qs(parsed.query)
+         file_format = query["format"][0] if "format" in query else "json"
+         if "size" in query:
+             size = int(query["size"][0])
+         else:
+             size = 500
+             query["size"] = size
+         compressed = (
+             query["compressed"][0].lower() == "true" if "compressed" in query else False
+         )
+         parsed = parsed._replace(query=urlencode(query, doseq=True))
+         url = parsed.geturl()
+         request = session.get(url)
+         check_response(request)
+         results = decode_results(request, file_format, compressed)
+         total = int(request.headers["x-total-results"])
+         print_progress_batches(0, size, total)
+         for i, batch in enumerate(get_batch(request, file_format, compressed), 1):
+             results = combine_batches(results, batch, file_format)
+             print_progress_batches(i, size, total)
+         if file_format == "xml":
+             return merge_xml_results(results)
+         return results
+ 
+     def get_id_mapping_results_stream(url):
+         if "/stream/" not in url:
+             url = url.replace("/results/", "/results/stream/")
+         request = session.get(url)
+         check_response(request)
+         parsed = urlparse(url)
+         query = parse_qs(parsed.query)
+         file_format = query["format"][0] if "format" in query else "json"
+         compressed = (
+             query["compressed"][0].lower() == "true" if "compressed" in query else False
+         )
+         return decode_results(request, file_format, compressed)
+ 
+     def extract_uniprot_names(results):
+         uniprot_mapping = {}
+         for result in results.get('results', []):
+             pdb_name = result['from']
+             #print(result['to'])
+             #time.sleep(1)
+             uniprot_name = result['to'].get('primaryAccession', '')  #uniProtkbId
+             if uniprot_name:
+                 uniprot_mapping[pdb_name] = uniprot_name
+         return uniprot_mapping
+ 
+     def chunks(lst, n):
+         """Yield successive n-sized chunks from lst."""
+         for i in range(0, len(lst), n):
+             yield lst[i:i + n]
+ 
+     uniprot_names = {}
+     formatted_ids = [f"{pdb_id}:{chain}" for pdb_id, chain in pdb_chain_map.items()]
+ 
+     # Iterate over each chunk of formatted IDs and submit separate jobs
+     for formatted_ids_chunk in chunks(formatted_ids, MAX_IDS_PER_REQUEST):
+         #print('chunk',formatted_ids_chunk)
+         job_id = submit_id_mapping("PDB", "UniProtKB", formatted_ids_chunk)
+         #accession, UniProtKB
+         if check_id_mapping_results_ready(job_id):
+             link = get_id_mapping_results_link(job_id)
+             results = get_id_mapping_results_search(link)
+             uniprot_names.update(extract_uniprot_names(results))
+     return uniprot_names
+ 
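The ID-mapping wrapper can be exercised on a single PDB chain, roughly like this (live call to the UniProt ID-mapping service; the returned accession depends on UniProt's current mapping):

# Illustrative sketch only: queries the UniProt REST API.
from spacr.foldseek import pdb_to_uniprot

mapping = pdb_to_uniprot({"1TIM": "A"})  # {PDB ID: chain}
print(mapping)  # e.g. {'1TIM:A': '<UniProt accession>'}
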
+ def functionally_annotate_foldseek_hits(csv_file_path, num_workers=25, limit=None, threshold=None):
+ 
+     foldseek_df = pd.read_csv(csv_file_path)
+ 
+     if threshold is not None:
+         foldseek_df = foldseek_df[foldseek_df['evalue'] < threshold]
+ 
+     if limit is not None:
+         foldseek_df = foldseek_df.sample(n=limit)
+ 
+     # Extract PDB IDs and chains, and convert the IDs to uppercase
+     foldseek_df['target_pdbID'] = foldseek_df['target'].str.split('-').str[0].str.upper()
+     foldseek_df['chain'] = foldseek_df['target'].str.split('_').str[-1]
+     unique_pdb_ids = dict(zip(foldseek_df['target_pdbID'], foldseek_df['chain']))
+ 
+     print(f'Found {len(unique_pdb_ids)} unique target proteins')
+ 
+     # Fetch UniProt mapping for the unique PDB IDs
+     unique_pdb_mapping = pdb_to_uniprot(unique_pdb_ids)
+     #print(unique_pdb_mapping)
+ 
+     # Map the target PDB IDs and chains to UniProt IDs using the unique_pdb_mapping
+     foldseek_df['target_uniprotID'] = foldseek_df.apply(
+         lambda row: unique_pdb_mapping.get(f"{row['target_pdbID']}:{row['chain']}", pd.NA),
+         axis=1
+     )
+ 
+     #display(foldseek_df)
+     #display(unique_pdb_mapping)
+     unique_uniprot_ids = get_unique_uniprot_ids(unique_pdb_mapping)
+     #print(unique_uniprot_ids)
+     target_metadata_df = fetch_and_aggregate_functional_data(unique_uniprot_ids, num_workers=num_workers)
+     #display(target_metadata_df)
+     merged_df = pd.merge(foldseek_df, target_metadata_df, left_on='target_uniprotID', right_on='UniProt ID')
+     return merged_df
+ 
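Typical use on one of the CSVs written by align_to_database (the path and e-value cutoff below are hypothetical):

# Illustrative sketch only: hypothetical path and cutoff, live UniProt lookups.
from spacr.foldseek import functionally_annotate_foldseek_hits

annotated = functionally_annotate_foldseek_hits(
    "/data/foldseek_run/results/pdb/aln_tmscore.csv",
    num_workers=8,
    threshold=1e-5,   # keep hits with evalue below this cutoff
)
print(annotated.shape)
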
+ def _analyze_group(args):
+     group, df, feature_columns, query = args
+     results = []
+     total = df.shape[0]
+     group_total = group.shape[0]
+     for feature in feature_columns:
+         try:
+             all_features = set(group[feature].explode().dropna().unique())
+         except TypeError:
+             all_features = set(group[feature].dropna().apply(lambda x: x if isinstance(x, list) else [x]).explode().unique())
+ 
+         for specific_feature in all_features:
+             observed_present = group[feature].apply(lambda x: specific_feature in x if isinstance(x, list) else specific_feature == x).sum()
+             observed_absent = group_total - observed_present
+             # Background counts come from the full hit table, so the test compares this query's hits to all hits
+             expected_present = df[feature].apply(lambda x: specific_feature in x if isinstance(x, list) else specific_feature == x).sum()
+             expected_absent = total - expected_present
+ 
+             contingency_table = [[observed_present, observed_absent], [expected_present, expected_absent]]
+             odds_ratio, p_value = fisher_exact(contingency_table, alternative='greater')
+ 
+             results.append({
+                 'query': query,
+                 'feature': specific_feature,
+                 'p_value': p_value,
+                 'category': feature
+             })
+     return results
+ 
+ def perform_enrichment_analysis(df, num_workers=4):
+ 
+     exclude_columns = [
+         'query', 'target', 'fident', 'alnlen', 'mismatch', 'gapopen', 'qstart', 'qend',
+         'tstart', 'tend', 'evalue', 'bits', 'target_pdbID', 'target_uniprotID', 'UniProt ID',
+         'Entry Name', 'Organism', 'Taxonomy ID', 'Sequence Length', 'Sequence', 'EC Number', 'Function',
+         'Recommended Name', 'Alternative Names'
+     ]
+     feature_columns = df.columns.difference(exclude_columns)
+ 
+     with ProcessPoolExecutor(max_workers=num_workers) as executor:
+         future_to_group = {executor.submit(_analyze_group, (group, df, feature_columns, query)): query for query, group in df.groupby('query')}
+         results = []
+         for future in as_completed(future_to_group):
+             results.extend(future.result())
+ 
+     results_df = pd.DataFrame(results)
+     correction_method = 'fdr_bh'
+     p_adjust = multipletests(results_df['p_value'], method=correction_method)
+     results_df['adjusted_p_value'] = p_adjust[1]
+ 
+     return results_df
+ 
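Each per-feature test above reduces to a one-sided Fisher's exact test on a 2x2 table (feature counts within one query's hits vs. the full hit table), followed by Benjamini-Hochberg correction; a standalone sketch with made-up counts:

# Illustrative sketch only: made-up counts, not package data.
from scipy.stats import fisher_exact
from statsmodels.stats.multitest import multipletests

# 8 of a query's 20 hits carry the feature; 15 of all 200 hits do.
table = [[8, 12], [15, 185]]
odds_ratio, p_value = fisher_exact(table, alternative='greater')

# BH-correct a batch of such p-values, as perform_enrichment_analysis does.
adjusted_p_values = multipletests([p_value, 0.04, 0.20], method='fdr_bh')[1]
print(odds_ratio, p_value, adjusted_p_values)
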
+ def compare_features(enrichment_results, verbose=False):
+ 
+     # Check feature matches
+     def check_feature_match(row):
+         category = row['category']
+         feature = row['feature']
+         # Check if the category column exists in the DataFrame
+         if category in protein_data_df.columns:
+             # Flatten the list if it's not already a scalar
+             values = row[category]
+             if verbose:
+                 print(f'category:{category}, feature:{feature}, values:{values}')
+             if isinstance(values, list) or isinstance(values, np.ndarray):
+                 if any(pd.isna(values)):
+                     return np.nan
+                 else:
+                     # Check if the feature is within the list of values
+                     return 1 if feature in values else 0
+             else:
+                 # Direct comparison if it's scalar
+                 if pd.isna(values):
+                     return np.nan
+                 return 1 if feature == values else 0
+         else:
+             print(f'Could not find {category} in columns')
+             return np.nan
+ 
+     # Assuming the format 'something-UniProtID' in the 'query' column
+     enrichment_results['UniProt ID'] = enrichment_results['query'].str.split('-').str[1]
+ 
+     # Get unique UniProt IDs
+     uniprot_ids = enrichment_results['UniProt ID'].unique().tolist()
+ 
+     # Fetch data for these UniProt IDs
+     protein_data_df = fetch_and_aggregate_functional_data(uniprot_ids)
+ 
+     # Assuming the fetched protein_data_df is indexed by 'UniProt ID', merge it
+     comparison_df = pd.merge(enrichment_results, protein_data_df, on='UniProt ID', how='left')
+ 
+     # Filter significant features (copy to avoid modifying a view)
+     significant_features = comparison_df[comparison_df['adjusted_p_value'] < 0.05].copy()
+ 
+     # Apply the checking function to each row
+     significant_features['comparison'] = significant_features.apply(check_feature_match, axis=1)
+ 
+     return significant_features
+ 
+ def calculate_feature_metrics(comparison_df):
+     # Drop rows where comparison is NaN (copy to avoid modifying a view)
+     filtered_df = comparison_df.dropna(subset=['comparison']).copy()
+ 
+     # Convert 'comparison' to integer for metrics calculation
+     filtered_df['comparison'] = filtered_df['comparison'].astype(int)
+ 
+     # Initialize list to store metrics by category and feature
+     metrics = []
+ 
+     # Group by category and feature for detailed metrics
+     grouped = filtered_df.groupby(['category', 'feature'])
+     for (category, feature), group in grouped:
+         # True labels are 'comparison', predictions assume 1 if 'comparison' > 0 (already true for 1 and 0)
+         true_labels = group['comparison']
+         pred_labels = (group['comparison'] > 0).astype(int)  # Prediction: 1 if comparison > 0, else 0
+ 
+         # Calculating precision, recall, F1-score, and accuracy
+         precision = precision_score(true_labels, pred_labels, zero_division=0)
+         recall = recall_score(true_labels, pred_labels, zero_division=0)
+         f1 = f1_score(true_labels, pred_labels, zero_division=0)
+         accuracy = accuracy_score(true_labels, pred_labels)
+ 
+         # Append results to metrics list
+         metrics.append({
+             'category': category,
+             'feature': feature,
+             'precision': precision,
+             'recall': recall,
+             'f1_score': f1,
+             'accuracy': accuracy
+         })
+ 
+     # Convert list of metrics to DataFrame
+     metrics_df = pd.DataFrame(metrics)
+ 
+     return metrics_df
+ 
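For reference, the sklearn metrics used above behave like this on a small made-up label set:

# Illustrative sketch only: toy labels, unrelated to package data.
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score

true_labels = [1, 0, 1, 1, 0]
pred_labels = [1, 0, 0, 1, 1]
print(precision_score(true_labels, pred_labels, zero_division=0))  # 0.666...
print(recall_score(true_labels, pred_labels, zero_division=0))     # 0.666...
print(f1_score(true_labels, pred_labels, zero_division=0))         # 0.666...
print(accuracy_score(true_labels, pred_labels))                    # 0.6
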
+ def visualize_heatmap(data, pivot_index, pivot_columns, values):
+     # Pivoting the data for heatmap
+     heatmap_data = data.pivot_table(index=pivot_index, columns=pivot_columns, values=values, aggfunc='first')
+ 
+     # Create a figure and axes object
+     fig, ax = plt.subplots(figsize=(10, 8))
+ 
+     # Create the heatmap on the specified axes
+     sns.heatmap(heatmap_data, annot=True, cmap='viridis', fmt=".2g", linewidths=.5, ax=ax)
+ 
+     ax.set_title('Heatmap of Enriched Features Across Queries')
+     ax.set_ylabel('Query')
+     ax.set_xlabel('Feature')
+ 
+     # Return the figure object
+     return fig
+ 
+ def visualize_bar_chart(data):
+     # Counting occurrences of significant features
+     feature_counts = data['feature'].value_counts().reset_index()
+     feature_counts.columns = ['feature', 'counts']
+ 
+     # Create a figure and axes object
+     fig, ax = plt.subplots(figsize=(12, 8))
+ 
+     # Create the bar plot on the specified axes
+     bar_plot = sns.barplot(x='counts', y='feature', data=feature_counts.head(20), ax=ax)
+ 
+     # Optional: set color palette manually if needed
+     #bar_plot.set_palette(sns.color_palette("viridis", n_colors=20))
+ 
+     ax.set_title('Top Enriched Features Across All Queries')
+     ax.set_xlabel('Counts of Significant Enrichment')
+     ax.set_ylabel('Features')
+ 
+     # Properly setting the x-ticks and rotating them
+     ax.set_xticks(ax.get_xticks())  # This ensures the ticks are explicitly set
+     ax.set_xticklabels(ax.get_xticklabels(), rotation=90)
+ 
+     # Return the figure object
+     return fig
+ 
+ def visualize_dot_plot(data):
+     # Adjusting data for visualization
+     data['-log10(p_value)'] = -np.log10(data['adjusted_p_value'])
+ 
+     # Create a figure object
+     fig, ax = plt.subplots(figsize=(10, 8))
+ 
+     # Create the plot on the specified axes
+     sns.scatterplot(data=data, x='feature', y='query', size='-log10(p_value)',
+                     legend=None, sizes=(20, 200), hue='category', ax=ax)
+ 
+     ax.set_xticklabels(ax.get_xticklabels(), rotation=90)
+     ax.set_title('Dot Plot of Feature Enrichment Across Queries')
+     ax.set_xlabel('Feature')
+     ax.set_ylabel('Query')
+     ax.grid(True)
+ 
+     # Return the figure object
+     return fig
+ 
+ def analyze_results(foldseek_csv_path, base_dir):
+ 
+     results = functionally_annotate_foldseek_hits(foldseek_csv_path, limit=None, threshold=None)
+     #display(results)
+ 
+     enrichment_results = perform_enrichment_analysis(results, num_workers=25)
+     filtered_results = enrichment_results[enrichment_results['adjusted_p_value'] < 0.05]
+     filtered_results = filtered_results[filtered_results['feature'].str.strip().astype(bool)]
+     #display(filtered_results)
+ 
+     fldr = os.path.dirname(foldseek_csv_path)
+ 
+     heatmap_path = os.path.join(fldr, 'heatmap.pdf')
+     bar_path = os.path.join(fldr, 'bar.pdf')
+     dot_path = os.path.join(fldr, 'dot.pdf')
+ 
+     heatmap_fig = visualize_heatmap(filtered_results, 'query', 'feature', 'adjusted_p_value')
+     bar_fig = visualize_bar_chart(filtered_results)
+     dot_fig = visualize_dot_plot(filtered_results)
+ 
+     heatmap_fig.savefig(heatmap_path, bbox_inches='tight')
+     bar_fig.savefig(bar_path, bbox_inches='tight')
+     dot_fig.savefig(dot_path, bbox_inches='tight')
+ 
+     comparison_results = compare_features(filtered_results)
+     #display(comparison_results)
+     feature_metrics_results = calculate_feature_metrics(comparison_results)
+     #display(feature_metrics_results)
+ 
+     merged_path = os.path.join(fldr, 'merged.csv')
+     enrichment_path = os.path.join(fldr, 'enrichment.csv')
+     comparison_path = os.path.join(fldr, 'comparison.csv')
+ 
+     results.to_csv(merged_path, index=False)
+     filtered_results.to_csv(enrichment_path, index=False)
+     comparison_results.to_csv(comparison_path, index=False)
+ 
+     print(f'Saved merged results to {merged_path}')
+     print(f'Saved enrichment results to {enrichment_path}')
+     print(f'Saved comparison results to {comparison_path}')
+ 
+     #display(functional_data_df)
+ 
+ # Example pipeline runs (guarded so that importing spacr.foldseek does not launch them)
+ if __name__ == "__main__":
+     # Set up directories
+     structure_fldr_path = "/home/carruthers/Downloads/ME49_proteome/cif"
+     base_dir = '/home/carruthers/foldseek/me49'
+ 
+     align_to_database(structure_fldr_path, base_dir, cores=25)
+     #foldseek_csv_path = f'{base_dir}/results/pdb/aln_tmscore.csv'
+     #analyze_results(foldseek_csv_path, base_dir)
+ 
+     # Set up directories
+     structure_fldr_path = "/home/carruthers/Downloads/GT1_proteome/cif"
+     base_dir = '/home/carruthers/foldseek/gt1'
+ 
+     align_to_database(structure_fldr_path, base_dir, cores=25)
+     #foldseek_csv_path = f'{base_dir}/results/pdb/aln_tmscore.csv'
+     #analyze_results(foldseek_csv_path, base_dir)