@sjcrh/proteinpaint-server 2.105.0 → 2.106.0

This diff shows the changes between publicly available versions of this package that have been released to one of the supported registries. It is provided for informational purposes only and reflects the package contents as they appear in their respective public registries.
package/utils/gsea.py CHANGED
@@ -1,129 +1,140 @@
1
- # cat ~/sjpp/test.txt | python gsea.py
1
+ # Test syntax: cat ~/sjpp/test.txt | time python gsea.py
2
+ # test.txt contains the json string autogenerated by the commented out nodejs code.
3
+ import blitzgsea as blitz
4
+ import json
5
+ import time
6
+ import sys
7
+ import sqlite3
8
+ import os
9
+ import numpy as np
10
+ import pandas as pd
2
11
 
3
- import blitzgsea as blitz
4
- import json
5
- import time
6
- import sys
7
- import sqlite3
8
- import os
9
- import numpy as np
10
- import pandas as pd
11
-
12
+ # Helper function to extract gene symbols from a dictionary
12
13
  def extract_symbols(x):
13
- return x['symbol']
14
-
15
- def extract_plot_data(signature, geneset, library, result, center=True):
16
- signature = signature.copy()
17
- signature.columns = ["i","v"]
18
- signature = signature.sort_values("v", ascending=False).set_index("i")
19
- signature = signature[~signature.index.duplicated(keep='first')]
20
- if center:
21
- signature.loc[:,"v"] -= np.mean(signature.loc[:,"v"])
22
- signature_map = {}
23
- for i,h in enumerate(signature.index):
24
- signature_map[h] = i
25
-
26
- gs = set(library[geneset])
27
- hits = [i for i,x in enumerate(signature.index) if x in gs]
28
-
29
- running_sum, es = blitz.enrichment_score(np.array(np.abs(signature.iloc[:,0])), signature_map, gs)
30
- running_sum = list(running_sum)
31
- nn = np.where(np.abs(running_sum)==np.max(np.abs(running_sum)))[0][0]
32
- #print ("nn:",nn)
33
- #print ("running_sum:",running_sum)
34
- #print ("es:",es)
35
- running_sum_str=[str(elem) for elem in running_sum]
36
- print ('result: {"nn":'+str(nn)+',"running_sum":"'+",".join(running_sum_str)+'","es":'+str(es)+'}')
14
+ return x['symbol'] # Return the 'symbol' field from the dictionary
37
15
 
38
-
39
- # Main function
16
+ # Main function
40
17
  try:
41
- # Try to read a single character from stdin without blocking
18
+ # Check if there is input from stdin
42
19
  if sys.stdin.read(1):
43
- # Read from stdin
20
+ # Read each line from stdin
44
21
  for line in sys.stdin:
45
- # Process each line
22
+ # Parse the JSON input
46
23
  json_object = json.loads(line)
47
- cachedir=json_object['cachedir']
48
- genes=json_object['genes']
49
- fold_change=json_object['fold_change']
50
- table_name=json_object['geneset_group']
51
- filter_non_coding_genes=json_object['filter_non_coding_genes']
52
- df = {'Genes': genes, 'fold_change': fold_change}
53
- signature=pd.DataFrame(df)
54
- db=json_object['db']
55
- # Connect to the SQLite database
56
- conn = sqlite3.connect(db)
57
-
58
- # Create a cursor object using the cursor() method
59
- cursor = conn.cursor()
60
-
61
- # SQL query to select all data from the table
62
- query = f"select id from terms where parent_id='" + table_name + "'"
63
- # Execute the SQL query
64
- cursor.execute(query)
65
- if filter_non_coding_genes == True:
66
- # SQL query to code all the protein coding genes
67
- coding_genes_query = f"select * from codingGenes"
68
- genedb = json_object['genedb']
69
- gene_conn = sqlite3.connect(genedb)
70
- gene_cursor = gene_conn.cursor()
71
- gene_cursor.execute(coding_genes_query)
72
- coding_genes_list=gene_cursor.fetchall()
73
- coding_genes_list=list(map(lambda x: x[0],coding_genes_list))
74
- signature=signature[signature['Genes'].isin(coding_genes_list)]
75
-
76
- # Fetch all rows from the executed SQL query
77
- rows = cursor.fetchall()
78
-
79
- start_loop_time = time.time()
80
- msigdb_library={}
81
- # Iterate over the rows and print them
82
- for row in rows:
83
- #print(row[0])
84
- query2=f"select genes from term2genes where id='" + row[0] + "'"
85
- cursor.execute(query2)
86
- rows2 = cursor.fetchall()
87
- row3=json.loads(rows2[0][0])
88
- msigdb_library[row[0]] = list(map(extract_symbols,row3))
24
+ cachedir = json_object['cachedir'] # Get the cache directory from the JSON object
25
+ genes = json_object['genes'] # Get the genes from the JSON object
26
+ fold_change = json_object['fold_change'] # Get the fold change values from the JSON object
27
+ table_name = json_object['geneset_group'] # Get the gene set group from the JSON object
28
+ filter_non_coding_genes = json_object['filter_non_coding_genes'] # Get the filter_non_coding_genes flag from the JSON object
29
+ db = json_object['db'] # Get the database path from the JSON object
30
+ # Create a DataFrame for the signature
31
+ df = {'Genes': genes, 'fold_change': fold_change} # Create a dictionary with genes and fold change
32
+ signature = pd.DataFrame(df) # Convert the dictionary to a DataFrame
89
33
 
90
- #print ("msigdb_library:",msigdb_library)
91
- # Close the cursor and connection to the database
92
- cursor.close()
93
- conn.close()
94
- stop_loop_time = time.time()
95
- execution_time = stop_loop_time - start_loop_time
96
- print(f"Execution time: {execution_time} seconds")
97
- try: # Extract ES data to be plotted on client side
98
- geneset_name=json_object['geneset_name'] # Checks if geneset_name is present, if yes it indicates the server request is for generating the image. It retrieves the result.pkl file and generates the image without having to recompute gsea again.
99
- pickle_file=json_object['pickle_file']
100
- result = pd.read_pickle(os.path.join(cachedir,pickle_file))
101
- fig = blitz.plot.running_sum(signature, geneset_name, msigdb_library, result=result.T, compact=True)
102
- random_num = np.random.rand()
103
- png_filename = "gsea_plot_" + str(random_num) + ".png"
104
- fig.savefig(os.path.join(cachedir,png_filename), bbox_inches='tight')
105
- #extract_plot_data(signature, geneset_name, msigdb_library, result) # This returns raw data to client side, not currently used
106
- print ('image: {"image_file":"' + png_filename + '"}')
107
- except KeyError: #Initial GSEA calculation, result saved to a result.pkl pickle file
108
- # run enrichment analysis
109
- start_gsea_time = time.time()
110
- if __name__ == "__main__":
111
- result = blitz.gsea(signature, msigdb_library).T
112
- random_num = np.random.rand()
113
- pickle_filename="gsea_result_"+ str(random_num) +".pkl"
114
- result.to_pickle(os.path.join(cachedir,pickle_filename))
115
- gsea_str='{"data":' + result.to_json() + '}'
116
- pickle_str='{"pickle_file":"' + pickle_filename + '"}'
117
- #print ("pickle_file:",pickle_str)
118
- gsea_dict = json.loads(gsea_str)
119
- pickle_dict = json.loads(pickle_str)
120
- result_dict = {**gsea_dict, **pickle_dict}
121
- print ("result:",json.dumps(result_dict))
122
- stop_gsea_time = time.time()
123
- gsea_time = stop_gsea_time - start_gsea_time
124
- print (f"GSEA time: {gsea_time} seconds")
34
+ # Connect to the SQLite database
35
+ conn = sqlite3.connect(db) # Connect to the SQLite database
36
+ cursor = conn.cursor() # Create a cursor object
37
+
38
+ msigdb_library = {} # Initialize an empty dictionary for the gene set library
39
+ if table_name == "REACTOME--blitzgsea": # Parse from blitzgsea reactome library
40
+ msigdb_library = blitz.enrichr.get_library("Reactome_2022")
41
+ elif table_name == "KEGG--blitzgsea": # Parse from blitzgsea KEGG library
42
+ msigdb_library = blitz.enrichr.get_library("KEGG_2021_Human")
43
+ elif table_name == "WikiPathways--blitzgsea": # Parse from blitzgsea WikiPathways library
44
+ msigdb_library = blitz.enrichr.get_library("WikiPathways_2019_Human")
45
+ else: # Use geneset groups from msigdb
46
+ # Query to get gene set IDs
47
+ query = f"SELECT id FROM terms WHERE parent_id='{table_name}'" # SQL query to get gene set IDs
48
+ cursor.execute(query) # Execute the query
49
+
50
+ # Fetch all gene set IDs
51
+ rows = cursor.fetchall() # Fetch all rows from the executed query
52
+
53
+ start_loop_time = time.time() # Record the start time of the loop
54
+
55
+ # Iterate over gene set IDs and fetch corresponding genes
56
+ for row in rows:
57
+ query2 = f"SELECT genes FROM term2genes WHERE id='{row[0]}'" # SQL query to get genes for a gene set ID
58
+ cursor.execute(query2) # Execute the query
59
+ rows2 = cursor.fetchall() # Fetch all rows from the executed query
60
+ row3 = json.loads(rows2[0][0]) # Parse the JSON data
61
+ msigdb_library[row[0]] = list(set(map(extract_symbols, row3))) # Extract only unique gene symbols and add them to the library. "set" command selects only unique genes
62
+ #print ("msigdb_library:",msigdb_library)
63
+
64
+ # Close the cursor and connection to the database
65
+ cursor.close() # Close the cursor
66
+ conn.close() # Close the connection
67
+
68
+ stop_loop_time = time.time() # Record the stop time of the loop
69
+ execution_time = stop_loop_time - start_loop_time # Calculate the execution time
70
+ print(f"Execution time: {execution_time} seconds") # Print the execution time
71
+
72
+ # Filter out non-coding genes if specified
73
+ if filter_non_coding_genes:
74
+ coding_genes_query = "SELECT * FROM codingGenes" # SQL query to get coding genes
75
+ genedb = json_object['genedb'] # Get the gene database path from the JSON object
76
+ gene_conn = sqlite3.connect(genedb) # Connect to the gene database
77
+ gene_cursor = gene_conn.cursor() # Create a cursor object for the gene database
78
+ gene_cursor.execute(coding_genes_query) # Execute the query to get coding genes
79
+ coding_genes_list = gene_cursor.fetchall() # Fetch all coding genes
80
+ coding_genes_list = list(map(lambda x: x[0], coding_genes_list)) # Extract the gene symbols
81
+ signature = signature[signature['Genes'].isin(coding_genes_list)] # Filter the signature to include only coding genes
125
82
 
83
+ try:
84
+ # Check if geneset_name and pickle_file are present for generating the plot
85
+ geneset_name = json_object['geneset_name'] # Get the gene set name from the JSON object
86
+ pickle_file = json_object['pickle_file'] # Get the pickle file name from the JSON object
87
+ result = pd.read_pickle(os.path.join(cachedir, pickle_file)) # Load the result from the pickle file
88
+ fig = blitz.plot.running_sum(signature, geneset_name, msigdb_library, result=result.T, compact=True) # Generate the running sum plot
89
+ random_num = np.random.rand() # Generate a random number for unique png filename
90
+ png_filename = f"gsea_plot_{random_num}.png" # Create a filename for the plot
91
+ fig.savefig(os.path.join(cachedir, png_filename), bbox_inches='tight') # Save the plot as a PNG file
92
+ print(f'image: {{"image_file": "{png_filename}"}}') # Print the image file path in JSON format
93
+ except KeyError:
94
+ # Initial GSEA calculation and save the result to a pickle file
95
+ start_gsea_time = time.time() # Record the start time of GSEA
96
+ if __name__ == "__main__":
97
+ num_permutations = json_object['num_permutations'] # Number of permutations for GSEA analysis
98
+ result = blitz.gsea(signature, msigdb_library, permutations=num_permutations).T # Perform GSEA and transpose the result
99
+ random_num = np.random.rand() # Generate a random number for unique pickle filename
100
+ pickle_filename = f"gsea_result_{random_num}.pkl" # Create a filename for the pickle file
101
+ result.to_pickle(os.path.join(cachedir, pickle_filename)) # Save the result to the pickle file
102
+ gsea_str = f'{{"data": {result.to_json()}}}' # Convert the result to JSON format
103
+ pickle_str = f'{{"pickle_file": "{pickle_filename}"}}' # Create a JSON string for the pickle file
104
+ gsea_dict = json.loads(gsea_str) # Parse the JSON string
105
+ pickle_dict = json.loads(pickle_str) # Parse the JSON string
106
+ result_dict = {**gsea_dict, **pickle_dict} # Merge the dictionaries
107
+ print(f"result: {json.dumps(result_dict)}") # Print the result in JSON format
108
+ stop_gsea_time = time.time() # Record the stop time of GSEA
109
+ gsea_time = stop_gsea_time - start_gsea_time # Calculate the GSEA execution time
110
+ print(f"GSEA time: {gsea_time} seconds") # Print the GSEA execution time
126
111
  else:
127
- pass
112
+ pass # Do nothing if there is no input from stdin
128
113
  except (EOFError, IOError):
129
- pass
114
+ pass # Handle EOFError and IOError exceptions gracefully
115
+
116
+ # Function to extract plot data for GSEA visualization (NOT currently being used, but will be used for generating client side gsea plots)
117
+ def extract_plot_data(signature, geneset, library, result, center=True):
118
+ print("signature", signature)
119
+ print("result", result)
120
+ print("geneset", geneset)
121
+ print("library", library)
122
+ signature = signature.copy() # Create a copy of the signature DataFrame
123
+ signature.columns = ["i", "v"] # Rename columns to 'i' and 'v'
124
+ signature = signature.sort_values("v", ascending=False).set_index("i") # Sort by 'v' in descending order and set 'i' as index
125
+ signature = signature[~signature.index.duplicated(keep='first')] # Remove duplicate indices, keeping the first occurrence
126
+
127
+ if center:
128
+ signature.loc[:, "v"] -= np.mean(signature.loc[:, "v"]) # Center the signature values by subtracting the mean
129
+
130
+ signature_map = {h: i for i, h in enumerate(signature.index)} # Create a mapping of signature indices
131
+
132
+ gs = set(library[geneset]) # Get the gene set from the library
133
+ hits = [i for i, x in enumerate(signature.index) if x in gs] # Find the indices of hits in the signature
134
+
135
+ running_sum, es = blitz.enrichment_score(np.array(np.abs(signature.iloc[:, 0])), signature_map, gs) # Compute running sum and enrichment score
136
+ running_sum = list(running_sum) # Convert running sum to a list
137
+ nn = np.where(np.abs(running_sum) == np.max(np.abs(running_sum)))[0][0] # Find the index of the maximum absolute running sum
138
+
139
+ running_sum_str = [str(elem) for elem in running_sum] # Convert running sum elements to strings
140
+ print(f'result: {{"nn": {nn}, "running_sum": "{",".join(running_sum_str)}", "es": {es}}}') # Print the result in JSON format