@sjcrh/proteinpaint-rust 2.167.0 → 2.170.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/Cargo.toml +8 -4
- package/package.json +1 -1
- package/src/aichatbot.rs +307 -310
- package/src/manhattan_plot.rs +31 -16
- package/src/query_classification.rs +152 -0
- package/src/summary_agent.rs +201 -0
- package/src/test_ai.rs +81 -58
package/src/aichatbot.rs
CHANGED
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
// Syntax: cd .. && cargo build --release && time cat ~/sjpp/test.txt | target/release/aichatbot
|
|
2
2
|
#![allow(non_snake_case)]
|
|
3
|
-
use anyhow::Result;
|
|
4
|
-
use json::JsonValue;
|
|
3
|
+
//use anyhow::Result;
|
|
4
|
+
//use json::JsonValue;
|
|
5
5
|
use r2d2_sqlite::SqliteConnectionManager;
|
|
6
6
|
use rig::agent::AgentBuilder;
|
|
7
7
|
use rig::completion::Prompt;
|
|
@@ -11,40 +11,53 @@ use schemars::JsonSchema;
|
|
|
11
11
|
use serde_json::{Map, Value, json};
|
|
12
12
|
use std::collections::HashMap;
|
|
13
13
|
use std::fs;
|
|
14
|
-
use std::io;
|
|
15
|
-
use std::path::Path;
|
|
16
|
-
mod ollama; // Importing custom rig module for invoking ollama server
|
|
17
|
-
mod sjprovider; // Importing custom rig module for invoking SJ GPU server
|
|
18
|
-
|
|
19
|
-
mod test_ai; // Test examples for AI chatbot
|
|
20
14
|
|
|
21
15
|
// Struct for intaking data from dataset json
|
|
22
16
|
#[derive(PartialEq, Debug, Clone, schemars::JsonSchema, serde::Serialize, serde::Deserialize)]
|
|
23
17
|
pub struct AiJsonFormat {
|
|
24
|
-
hasGeneExpression: bool,
|
|
25
|
-
db: String, // Dataset db
|
|
26
|
-
genedb: String, // Gene db
|
|
27
|
-
charts: Vec<
|
|
18
|
+
pub hasGeneExpression: bool,
|
|
19
|
+
pub db: String, // Dataset db
|
|
20
|
+
pub genedb: String, // Gene db
|
|
21
|
+
pub charts: Vec<TrainTestData>,
|
|
22
|
+
}
|
|
23
|
+
|
|
24
|
+
#[derive(PartialEq, Debug, Clone, schemars::JsonSchema, serde::Serialize, serde::Deserialize)]
|
|
25
|
+
pub struct TrainTestData {
|
|
26
|
+
pub r#type: String,
|
|
27
|
+
pub SystemPrompt: String,
|
|
28
|
+
pub TrainingData: Vec<QuestionAnswer>,
|
|
29
|
+
pub TestData: Vec<QuestionAnswer>,
|
|
28
30
|
}
|
|
29
31
|
|
|
30
32
|
#[derive(PartialEq, Debug, Clone, schemars::JsonSchema, serde::Serialize, serde::Deserialize)]
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
DE(TrainTestData),
|
|
33
|
+
pub struct QuestionAnswer {
|
|
34
|
+
pub question: String,
|
|
35
|
+
pub answer: AnswerFormat,
|
|
35
36
|
}
|
|
36
37
|
|
|
37
38
|
#[derive(PartialEq, Debug, Clone, schemars::JsonSchema, serde::Serialize, serde::Deserialize)]
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
39
|
+
pub enum AnswerFormat {
|
|
40
|
+
#[allow(non_camel_case_types)]
|
|
41
|
+
summary_type(SummaryType),
|
|
42
|
+
#[allow(non_camel_case_types)]
|
|
43
|
+
DE_type(DEType),
|
|
42
44
|
}
|
|
43
45
|
|
|
44
46
|
#[derive(PartialEq, Debug, Clone, schemars::JsonSchema, serde::Serialize, serde::Deserialize)]
|
|
45
|
-
struct
|
|
46
|
-
|
|
47
|
-
|
|
47
|
+
pub struct DEType {
|
|
48
|
+
action: String,
|
|
49
|
+
DE_output: DETerms,
|
|
50
|
+
}
|
|
51
|
+
|
|
52
|
+
#[derive(PartialEq, Debug, Clone, schemars::JsonSchema, serde::Serialize, serde::Deserialize)]
|
|
53
|
+
pub struct DETerms {
|
|
54
|
+
group1: GroupType,
|
|
55
|
+
group2: GroupType,
|
|
56
|
+
}
|
|
57
|
+
|
|
58
|
+
#[derive(PartialEq, Debug, Clone, schemars::JsonSchema, serde::Serialize, serde::Deserialize)]
|
|
59
|
+
pub struct GroupType {
|
|
60
|
+
name: String,
|
|
48
61
|
}
|
|
49
62
|
|
|
50
63
|
#[allow(non_camel_case_types)]
|
|
@@ -56,180 +69,188 @@ pub enum llm_backend {
|
|
|
56
69
|
|
|
57
70
|
#[derive(Debug, JsonSchema)]
|
|
58
71
|
#[allow(dead_code)]
|
|
59
|
-
struct OutputJson {
|
|
72
|
+
pub struct OutputJson {
|
|
60
73
|
pub answer: String,
|
|
61
74
|
}
|
|
62
75
|
|
|
63
|
-
|
|
64
|
-
async fn main() -> Result<()> {
|
|
65
|
-
let mut input = String::new();
|
|
66
|
-
match io::stdin().read_line(&mut input) {
|
|
67
|
-
// Accepting the piped input from nodejs (or command line from testing)
|
|
68
|
-
Ok(_n) => {
|
|
69
|
-
let input_json = json::parse(&input);
|
|
70
|
-
match input_json {
|
|
71
|
-
Ok(json_string) => {
|
|
72
|
-
//println!("json_string:{}", json_string);
|
|
73
|
-
let user_input_json: &JsonValue = &json_string["user_input"];
|
|
74
|
-
let user_input: &str;
|
|
75
|
-
match user_input_json.as_str() {
|
|
76
|
-
Some(inp) => user_input = inp,
|
|
77
|
-
None => panic!("user_input field is missing in input json"),
|
|
78
|
-
}
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
let
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
let
|
|
92
|
-
let
|
|
93
|
-
match
|
|
94
|
-
Some(inp) =>
|
|
95
|
-
None => panic!("
|
|
96
|
-
}
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
}
|
|
106
|
-
|
|
107
|
-
let
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
let
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
let
|
|
138
|
-
let
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
let
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
let
|
|
159
|
-
let
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
|
|
193
|
-
|
|
194
|
-
|
|
195
|
-
|
|
196
|
-
|
|
197
|
-
|
|
198
|
-
|
|
199
|
-
|
|
200
|
-
|
|
201
|
-
|
|
202
|
-
|
|
203
|
-
|
|
204
|
-
|
|
205
|
-
|
|
206
|
-
|
|
207
|
-
|
|
208
|
-
|
|
209
|
-
|
|
210
|
-
|
|
211
|
-
|
|
212
|
-
|
|
213
|
-
|
|
214
|
-
|
|
215
|
-
|
|
216
|
-
|
|
217
|
-
|
|
218
|
-
|
|
219
|
-
|
|
220
|
-
|
|
221
|
-
|
|
222
|
-
|
|
223
|
-
|
|
224
|
-
|
|
225
|
-
|
|
226
|
-
|
|
227
|
-
|
|
228
|
-
|
|
229
|
-
|
|
230
|
-
|
|
231
|
-
}
|
|
76
|
+
//#[tokio::main]
|
|
77
|
+
//async fn main() -> Result<()> {
|
|
78
|
+
// let mut input = String::new();
|
|
79
|
+
// match io::stdin().read_line(&mut input) {
|
|
80
|
+
// // Accepting the piped input from nodejs (or command line from testing)
|
|
81
|
+
// Ok(_n) => {
|
|
82
|
+
// let input_json = json::parse(&input);
|
|
83
|
+
// match input_json {
|
|
84
|
+
// Ok(json_string) => {
|
|
85
|
+
// //println!("json_string:{}", json_string);
|
|
86
|
+
// let user_input_json: &JsonValue = &json_string["user_input"];
|
|
87
|
+
// let user_input: &str;
|
|
88
|
+
// match user_input_json.as_str() {
|
|
89
|
+
// Some(inp) => user_input = inp,
|
|
90
|
+
// None => panic!("user_input field is missing in input json"),
|
|
91
|
+
// }
|
|
92
|
+
// let dataset_db_json: &JsonValue = &json_string["dataset_db"];
|
|
93
|
+
// let dataset_db_str: &str;
|
|
94
|
+
// match dataset_db_json.as_str() {
|
|
95
|
+
// Some(inp) => dataset_db_str = inp,
|
|
96
|
+
// None => panic!("dataset_db field is missing in input json"),
|
|
97
|
+
// }
|
|
98
|
+
// let genedb_json: &JsonValue = &json_string["genedb"];
|
|
99
|
+
// let genedb_str: &str;
|
|
100
|
+
// match genedb_json.as_str() {
|
|
101
|
+
// Some(inp) => genedb_str = inp,
|
|
102
|
+
// None => panic!("genedb field is missing in input json"),
|
|
103
|
+
// }
|
|
104
|
+
// let aiRoute_json: &JsonValue = &json_string["aiRoute"];
|
|
105
|
+
// let aiRoute_str: &str;
|
|
106
|
+
// match aiRoute_json.as_str() {
|
|
107
|
+
// Some(inp) => aiRoute_str = inp,
|
|
108
|
+
// None => panic!("aiRoute field is missing in input json"),
|
|
109
|
+
// }
|
|
110
|
+
// if user_input.len() == 0 {
|
|
111
|
+
// panic!("The user input is empty");
|
|
112
|
+
// }
|
|
113
|
+
// let tpmasterdir_json: &JsonValue = &json_string["tpmasterdir"];
|
|
114
|
+
// let tpmasterdir: &str;
|
|
115
|
+
// match tpmasterdir_json.as_str() {
|
|
116
|
+
// Some(inp) => tpmasterdir = inp,
|
|
117
|
+
// None => panic!("tpmasterdir not found"),
|
|
118
|
+
// }
|
|
119
|
+
// let binpath_json: &JsonValue = &json_string["binpath"];
|
|
120
|
+
// let binpath: &str;
|
|
121
|
+
// match binpath_json.as_str() {
|
|
122
|
+
// Some(inp) => binpath = inp,
|
|
123
|
+
// None => panic!("binpath not found"),
|
|
124
|
+
// }
|
|
125
|
+
// let ai_json_file_json: &JsonValue = &json_string["aifiles"];
|
|
126
|
+
// let ai_json_file: String;
|
|
127
|
+
// match ai_json_file_json.as_str() {
|
|
128
|
+
// Some(inp) => ai_json_file = String::from(binpath) + &"/../../" + &inp,
|
|
129
|
+
// None => {
|
|
130
|
+
// panic!("ai json file not found")
|
|
131
|
+
// }
|
|
132
|
+
// }
|
|
133
|
+
// let ai_json_file = Path::new(&ai_json_file);
|
|
134
|
+
// let ai_json_file_path;
|
|
135
|
+
// let current_dir = std::env::current_dir().unwrap();
|
|
136
|
+
// match ai_json_file.canonicalize() {
|
|
137
|
+
// Ok(p) => ai_json_file_path = p,
|
|
138
|
+
// Err(_) => {
|
|
139
|
+
// panic!(
|
|
140
|
+
// "AI JSON file path not found:{:?}, current directory:{:?}",
|
|
141
|
+
// ai_json_file, current_dir
|
|
142
|
+
// )
|
|
143
|
+
// }
|
|
144
|
+
// }
|
|
145
|
+
// // Read the file
|
|
146
|
+
// let ai_data = fs::read_to_string(ai_json_file_path).unwrap();
|
|
147
|
+
// // Parse the JSON data
|
|
148
|
+
// let ai_json: AiJsonFormat =
|
|
149
|
+
// serde_json::from_str(&ai_data).expect("AI JSON file does not have the correct format");
|
|
150
|
+
// let genedb = String::from(tpmasterdir) + &"/" + &genedb_str;
|
|
151
|
+
// let dataset_db = String::from(tpmasterdir) + &"/" + &dataset_db_str;
|
|
152
|
+
// let airoute = String::from(binpath) + &"/../../" + &aiRoute_str;
|
|
153
|
+
// let apilink_json: &JsonValue = &json_string["apilink"];
|
|
154
|
+
// let apilink: &str;
|
|
155
|
+
// match apilink_json.as_str() {
|
|
156
|
+
// Some(inp) => apilink = inp,
|
|
157
|
+
// None => panic!("apilink field is missing in input json"),
|
|
158
|
+
// }
|
|
159
|
+
// let comp_model_name_json: &JsonValue = &json_string["comp_model_name"];
|
|
160
|
+
// let comp_model_name: &str;
|
|
161
|
+
// match comp_model_name_json.as_str() {
|
|
162
|
+
// Some(inp) => comp_model_name = inp,
|
|
163
|
+
// None => panic!("comp_model_name field is missing in input json"),
|
|
164
|
+
// }
|
|
165
|
+
// let embedding_model_name_json: &JsonValue = &json_string["embedding_model_name"];
|
|
166
|
+
// let embedding_model_name: &str;
|
|
167
|
+
// match embedding_model_name_json.as_str() {
|
|
168
|
+
// Some(inp) => embedding_model_name = inp,
|
|
169
|
+
// None => panic!("embedding_model_name field is missing in input json"),
|
|
170
|
+
// }
|
|
171
|
+
// let llm_backend_name_json: &JsonValue = &json_string["llm_backend_name"];
|
|
172
|
+
// let llm_backend_name: &str;
|
|
173
|
+
// match llm_backend_name_json.as_str() {
|
|
174
|
+
// Some(inp) => llm_backend_name = inp,
|
|
175
|
+
// None => panic!("llm_backend_name field is missing in input json"),
|
|
176
|
+
// }
|
|
177
|
+
// let llm_backend_type: llm_backend;
|
|
178
|
+
// let mut final_output: Option<String> = None;
|
|
179
|
+
// let temperature: f64 = 0.01;
|
|
180
|
+
// let max_new_tokens: usize = 512;
|
|
181
|
+
// let top_p: f32 = 0.95;
|
|
182
|
+
// let testing = false; // This variable is always false in production, this is true in test_ai.rs for testing code
|
|
183
|
+
// if llm_backend_name != "ollama" && llm_backend_name != "SJ" {
|
|
184
|
+
// panic!(
|
|
185
|
+
// "This code currently supports only Ollama and SJ provider. llm_backend_name must be \"ollama\" or \"SJ\""
|
|
186
|
+
// );
|
|
187
|
+
// } else if llm_backend_name == "ollama".to_string() {
|
|
188
|
+
// llm_backend_type = llm_backend::Ollama();
|
|
189
|
+
// // Initialize Ollama client
|
|
190
|
+
// let ollama_client = ollama::Client::builder()
|
|
191
|
+
// .base_url(apilink)
|
|
192
|
+
// .build()
|
|
193
|
+
// .expect("Ollama server not found");
|
|
194
|
+
// let embedding_model = ollama_client.embedding_model(embedding_model_name);
|
|
195
|
+
// let comp_model = ollama_client.completion_model(comp_model_name);
|
|
196
|
+
// final_output = run_pipeline(
|
|
197
|
+
// user_input,
|
|
198
|
+
// comp_model,
|
|
199
|
+
// embedding_model,
|
|
200
|
+
// llm_backend_type,
|
|
201
|
+
// temperature,
|
|
202
|
+
// max_new_tokens,
|
|
203
|
+
// top_p,
|
|
204
|
+
// &dataset_db,
|
|
205
|
+
// &genedb,
|
|
206
|
+
// &ai_json,
|
|
207
|
+
// &airoute,
|
|
208
|
+
// testing,
|
|
209
|
+
// )
|
|
210
|
+
// .await;
|
|
211
|
+
// } else if llm_backend_name == "SJ".to_string() {
|
|
212
|
+
// llm_backend_type = llm_backend::Sj();
|
|
213
|
+
// // Initialize Sj provider client
|
|
214
|
+
// let sj_client = sjprovider::Client::builder()
|
|
215
|
+
// .base_url(apilink)
|
|
216
|
+
// .build()
|
|
217
|
+
// .expect("SJ server not found");
|
|
218
|
+
// let embedding_model = sj_client.embedding_model(embedding_model_name);
|
|
219
|
+
// let comp_model = sj_client.completion_model(comp_model_name);
|
|
220
|
+
// final_output = run_pipeline(
|
|
221
|
+
// user_input,
|
|
222
|
+
// comp_model,
|
|
223
|
+
// embedding_model,
|
|
224
|
+
// llm_backend_type,
|
|
225
|
+
// temperature,
|
|
226
|
+
// max_new_tokens,
|
|
227
|
+
// top_p,
|
|
228
|
+
// &dataset_db,
|
|
229
|
+
// &genedb,
|
|
230
|
+
// &ai_json,
|
|
231
|
+
// &airoute,
|
|
232
|
+
// testing,
|
|
233
|
+
// )
|
|
234
|
+
// .await;
|
|
235
|
+
// }
|
|
236
|
+
// match final_output {
|
|
237
|
+
// Some(fin_out) => {
|
|
238
|
+
// println!("final_output:{:?}", fin_out.replace("\\", ""));
|
|
239
|
+
// }
|
|
240
|
+
// None => {
|
|
241
|
+
// println!("final_output:{{\"{}\":\"{}\"}}", "action", "unknown");
|
|
242
|
+
// }
|
|
243
|
+
// }
|
|
244
|
+
// }
|
|
245
|
+
// Err(error) => println!("Incorrect json:{}", error),
|
|
246
|
+
// }
|
|
247
|
+
// }
|
|
248
|
+
// Err(error) => println!("Piping error: {}", error),
|
|
249
|
+
// }
|
|
250
|
+
// Ok(())
|
|
251
|
+
//}
|
|
232
252
|
|
|
253
|
+
#[allow(dead_code)]
|
|
233
254
|
pub async fn run_pipeline(
|
|
234
255
|
user_input: &str,
|
|
235
256
|
comp_model: impl rig::completion::CompletionModel + 'static,
|
|
@@ -241,6 +262,7 @@ pub async fn run_pipeline(
|
|
|
241
262
|
dataset_db: &str,
|
|
242
263
|
genedb: &str,
|
|
243
264
|
ai_json: &AiJsonFormat,
|
|
265
|
+
ai_route: &str,
|
|
244
266
|
testing: bool,
|
|
245
267
|
) -> Option<String> {
|
|
246
268
|
let mut classification: String = classify_query_by_dataset_type(
|
|
@@ -251,6 +273,7 @@ pub async fn run_pipeline(
|
|
|
251
273
|
temperature,
|
|
252
274
|
max_new_tokens,
|
|
253
275
|
top_p,
|
|
276
|
+
ai_route,
|
|
254
277
|
)
|
|
255
278
|
.await;
|
|
256
279
|
classification = classification.replace("\"", "");
|
|
@@ -373,104 +396,40 @@ pub async fn run_pipeline(
|
|
|
373
396
|
Some(final_output)
|
|
374
397
|
}
|
|
375
398
|
|
|
376
|
-
|
|
399
|
+
#[allow(dead_code)]
|
|
400
|
+
pub async fn classify_query_by_dataset_type(
|
|
377
401
|
user_input: &str,
|
|
378
402
|
comp_model: impl rig::completion::CompletionModel + 'static,
|
|
379
|
-
|
|
403
|
+
_embedding_model: impl rig::embeddings::EmbeddingModel + 'static,
|
|
380
404
|
llm_backend_type: &llm_backend,
|
|
381
405
|
temperature: f64,
|
|
382
406
|
max_new_tokens: usize,
|
|
383
407
|
top_p: f32,
|
|
408
|
+
ai_route: &str,
|
|
384
409
|
) -> String {
|
|
385
|
-
//
|
|
386
|
-
let
|
|
387
|
-
|
|
388
|
-
If a ProteinPaint dataset contains SNV/Indel/SV data then return JSON with single key, 'snv_indel'.
|
|
389
|
-
|
|
390
|
-
---
|
|
391
|
-
|
|
392
|
-
Copy number variation (CNV) is a phenomenon in which sections of the genome are repeated and the number of repeats in the genome varies between individuals.[1] Copy number variation is a special type of structural variation: specifically, it is a type of duplication or deletion event that affects a considerable number of base pairs.
|
|
410
|
+
// Read the file
|
|
411
|
+
let ai_route_data = fs::read_to_string(ai_route).unwrap();
|
|
393
412
|
|
|
394
|
-
|
|
395
|
-
|
|
396
|
-
---
|
|
397
|
-
|
|
398
|
-
Structural variants/fusions (SV) are genomic mutations when eith a DNA region is translocated or copied to an entirely different genomic locus. In case of transcriptomic data, when RNA is fused from two different genes its called a gene fusion.
|
|
399
|
-
|
|
400
|
-
If a ProteinPaint dataset contains structural variation or gene fusion data then return JSON with single key, 'sv_fusion'.
|
|
401
|
-
---
|
|
402
|
-
|
|
403
|
-
Hierarchical clustering of gene expression is an unsupervised learning technique where several number of relevant genes and the samples are clustered so as to determine (previously unknown) cohorts of samples (or patients) or structure in data. It is very commonly used to determine subtypes of a particular disease based on RNA sequencing data.
|
|
404
|
-
|
|
405
|
-
If a ProteinPaint dataset contains hierarchical data then return JSON with single key, 'hierarchical'.
|
|
406
|
-
|
|
407
|
-
---
|
|
408
|
-
|
|
409
|
-
Differential Gene Expression (DGE or DE) is a technique where the most upregulated (or highest) and downregulated (or lowest) genes between two cohorts of samples (or patients) are determined from a pool of THOUSANDS of genes. Differential gene expression CANNOT be computed for a SINGLE gene. A volcano plot is shown with fold-change in the x-axis and adjusted p-value on the y-axis. So, the upregulated and downregulared genes are on opposite sides of the graph and the most significant genes (based on adjusted p-value) is on the top of the graph. Following differential gene expression generally GeneSet Enrichment Analysis (GSEA) is carried out where based on the genes and their corresponding fold changes the upregulation/downregulation of genesets (or pathways) is determined.
|
|
410
|
-
|
|
411
|
-
Sample Query1: \"Which gene has the highest expression between the two genders\"
|
|
412
|
-
Sample Answer1: { \"answer\": \"dge\" }
|
|
413
|
-
|
|
414
|
-
Sample Query2: \"Which gene has the lowest expression between the two races\"
|
|
415
|
-
Sample Answer2: { \"answer\": \"dge\" }
|
|
416
|
-
|
|
417
|
-
Sample Query1: \"Which genes are the most upregulated genes between group A and group B\"
|
|
418
|
-
Sample Answer1: { \"answer\": \"dge\" }
|
|
419
|
-
|
|
420
|
-
Sample Query3: \"Which gene are overexpressed between male and female\"
|
|
421
|
-
Sample Answer3: { \"answer\": \"dge\" }
|
|
422
|
-
|
|
423
|
-
Sample Query4: \"Which gene are housekeeping genes between male and female\"
|
|
424
|
-
Sample Answer4: { \"answer\": \"dge\" }
|
|
425
|
-
|
|
426
|
-
|
|
427
|
-
If a ProteinPaint dataset contains differential gene expression data then return JSON with single key, 'dge'.
|
|
428
|
-
|
|
429
|
-
---
|
|
430
|
-
|
|
431
|
-
Survival analysis (also called time-to-event analysis or duration analysis) is a branch of statistics aimed at analyzing the duration of time from a well-defined time origin until one or more events happen, called survival times or duration times. In other words, in survival analysis, we are interested in a certain event and want to analyze the time until the event happens. Generally in survival analysis survival rates between two (or more) cohorts of patients is compared.
|
|
432
|
-
|
|
433
|
-
There are two main methods of survival analysis:
|
|
434
|
-
|
|
435
|
-
1) Kaplan-Meier (HM) analysis is a univariate test that only takes into account a single categorical variable.
|
|
436
|
-
2) Cox proportional hazards model (coxph) is a multivariate test that can take into account multiple variables.
|
|
437
|
-
|
|
438
|
-
The hazard ratio (HR) is an indicator of the effect of the stimulus (e.g. drug dose, treatment) between two cohorts of patients.
|
|
439
|
-
HR = 1: No effect
|
|
440
|
-
HR < 1: Reduction in the hazard
|
|
441
|
-
HR > 1: Increase in Hazard
|
|
442
|
-
|
|
443
|
-
Sample Query1: \"Compare survival rates between group A and B\"
|
|
444
|
-
Sample Answer1: { \"answer\": \"survival\" }
|
|
445
|
-
|
|
446
|
-
|
|
447
|
-
If a ProteinPaint dataset contains survival data then return JSON with single key, 'survival'.
|
|
448
|
-
|
|
449
|
-
---
|
|
450
|
-
|
|
451
|
-
Next generation sequencing reads (NGS) are mapped to a human genome using alignment algorithm such as burrows-wheelers alignment algorithm. Then these reads are called using variant calling algorithms such as GATK (Genome Analysis Toolkit). However this type of analysis is too compute intensive and beyond the scope of visualization software such as ProteinPaint.
|
|
452
|
-
|
|
453
|
-
If a user query asks about variant calling or mapping reads then JSON with single key, 'variant_calling'.
|
|
454
|
-
|
|
455
|
-
---
|
|
456
|
-
|
|
457
|
-
Summary plot in ProteinPaint shows the various facets of the datasets. Show expression of a SINGLE gene or compare the expression of a SINGLE gene across two different cohorts defined by the user. It may show all the samples according to their respective diagnosis or subtypes of cancer. It is also useful for comparing and correlating different clinical variables. It can show all possible distributions, frequency of a category, overlay, correlate or cross-tabulate with another variable on top of it. If a user query asks about a SINGLE gene expression or correlating clinical variables then return JSON with single key, 'summary'.
|
|
458
|
-
|
|
459
|
-
Sample Query1: \"Show all fusions for patients with age less than 30\"
|
|
460
|
-
Sample Answer1: { \"answer\": \"summary\" }
|
|
461
|
-
|
|
462
|
-
Sample Query2: \"List all molecular subtypes of leukemia\"
|
|
463
|
-
Sample Answer2: { \"answer\": \"summary\" }
|
|
464
|
-
|
|
465
|
-
Sample Query3: \"is tp53 expression higher in men than women ?\"
|
|
466
|
-
Sample Answer3: { \"answer\": \"summary\" }
|
|
467
|
-
|
|
468
|
-
Sample Query4: \"Compare ATM expression between races for women greater than 80yrs\"
|
|
469
|
-
Sample Answer4: { \"answer\": \"summary\" }
|
|
413
|
+
// Parse the JSON data
|
|
414
|
+
let ai_json: Value = serde_json::from_str(&ai_route_data).expect("AI JSON file does not have the correct format");
|
|
470
415
|
|
|
416
|
+
// Create a string to hold the file contents
|
|
417
|
+
let mut contents = String::from("");
|
|
418
|
+
|
|
419
|
+
if let Some(object) = ai_json.as_object() {
|
|
420
|
+
contents = object["general"].to_string();
|
|
421
|
+
for (key, value) in object {
|
|
422
|
+
if key != "general" {
|
|
423
|
+
contents += &value.as_str().unwrap();
|
|
424
|
+
contents += "---"; // Adding delimiter
|
|
425
|
+
}
|
|
426
|
+
}
|
|
427
|
+
}
|
|
471
428
|
|
|
472
|
-
|
|
473
|
-
|
|
429
|
+
// Removing the last "---" characters
|
|
430
|
+
contents.pop();
|
|
431
|
+
contents.pop();
|
|
432
|
+
contents.pop();
|
|
474
433
|
|
|
475
434
|
// Split the contents by the delimiter "---"
|
|
476
435
|
let parts: Vec<&str> = contents.split("---").collect();
|
|
@@ -501,18 +460,18 @@ If a query does not match any of the fields described above, then return JSON wi
|
|
|
501
460
|
rag_docs.push(part.trim().to_string())
|
|
502
461
|
}
|
|
503
462
|
|
|
504
|
-
//let top_k: usize = 3;
|
|
463
|
+
//let top_k: usize = 3; // Embedding model not used currently
|
|
505
464
|
// Create embeddings and add to vector store
|
|
506
|
-
let embeddings = EmbeddingsBuilder::new(embedding_model.clone())
|
|
507
|
-
|
|
508
|
-
|
|
509
|
-
|
|
510
|
-
|
|
511
|
-
|
|
465
|
+
//let embeddings = EmbeddingsBuilder::new(embedding_model.clone())
|
|
466
|
+
// .documents(rag_docs)
|
|
467
|
+
// .expect("Reason1")
|
|
468
|
+
// .build()
|
|
469
|
+
// .await
|
|
470
|
+
// .unwrap();
|
|
512
471
|
|
|
513
|
-
|
|
514
|
-
let mut vector_store = InMemoryVectorStore::<String>::default();
|
|
515
|
-
InMemoryVectorStore::add_documents(&mut vector_store, embeddings);
|
|
472
|
+
//// Create vector store
|
|
473
|
+
//let mut vector_store = InMemoryVectorStore::<String>::default();
|
|
474
|
+
//InMemoryVectorStore::add_documents(&mut vector_store, embeddings);
|
|
516
475
|
|
|
517
476
|
// Create RAG agent
|
|
518
477
|
let agent = AgentBuilder::new(comp_model).preamble(&(String::from("Generate classification for the user query into summary, dge, hierarchical, snv_indel, cnv, variant_calling, sv_fusion and none categories. Return output in JSON with ALWAYS a single word answer { \"answer\": \"dge\" }, that is 'summary' for summary plot, 'dge' for differential gene expression, 'hierarchical' for hierarchical clustering, 'snv_indel' for SNV/Indel, 'cnv' for CNV and 'sv_fusion' for SV/fusion, 'variant_calling' for variant calling, 'surivial' for survival data, 'none' for none of the previously described categories. The summary plot list and summarizes the cohort of patients according to the user query. The answer should always be in lower case\n The options are as follows:\n") + &contents + "\nQuestion= {question} \nanswer")).temperature(temperature).additional_params(additional).build();
|
|
@@ -580,6 +539,7 @@ struct DEOutput {
|
|
|
580
539
|
group2: Group,
|
|
581
540
|
}
|
|
582
541
|
|
|
542
|
+
#[allow(dead_code)]
|
|
583
543
|
#[allow(non_snake_case)]
|
|
584
544
|
async fn extract_DE_search_terms_from_query(
|
|
585
545
|
user_input: &str,
|
|
@@ -853,7 +813,7 @@ async fn parse_dataset_db(db: &str) -> (Vec<String>, Vec<DbRows>) {
|
|
|
853
813
|
(rag_docs, db_vec)
|
|
854
814
|
}
|
|
855
815
|
|
|
856
|
-
async fn extract_summary_information(
|
|
816
|
+
pub async fn extract_summary_information(
|
|
857
817
|
user_input: &str,
|
|
858
818
|
comp_model: impl rig::completion::CompletionModel + 'static,
|
|
859
819
|
_embedding_model: impl rig::embeddings::EmbeddingModel + 'static,
|
|
@@ -911,8 +871,8 @@ async fn extract_summary_information(
|
|
|
911
871
|
|
|
912
872
|
let mut summary_data_check: Option<TrainTestData> = None;
|
|
913
873
|
for chart in ai_json.charts.clone() {
|
|
914
|
-
if
|
|
915
|
-
summary_data_check = Some(
|
|
874
|
+
if chart.r#type == "Summary" {
|
|
875
|
+
summary_data_check = Some(chart);
|
|
916
876
|
break;
|
|
917
877
|
}
|
|
918
878
|
}
|
|
@@ -922,17 +882,23 @@ async fn extract_summary_information(
|
|
|
922
882
|
let mut training_data: String = String::from("");
|
|
923
883
|
let mut train_iter = 0;
|
|
924
884
|
for ques_ans in summary_data.TrainingData {
|
|
925
|
-
|
|
926
|
-
|
|
927
|
-
|
|
928
|
-
|
|
929
|
-
|
|
930
|
-
|
|
931
|
-
|
|
932
|
-
|
|
933
|
-
|
|
934
|
-
|
|
935
|
-
|
|
885
|
+
match ques_ans.answer {
|
|
886
|
+
AnswerFormat::summary_type(sum) => {
|
|
887
|
+
let summary_answer: SummaryType = sum;
|
|
888
|
+
train_iter += 1;
|
|
889
|
+
training_data += "Example question";
|
|
890
|
+
training_data += &train_iter.to_string();
|
|
891
|
+
training_data += &":";
|
|
892
|
+
training_data += &ques_ans.question;
|
|
893
|
+
training_data += &" ";
|
|
894
|
+
training_data += "Example answer";
|
|
895
|
+
training_data += &train_iter.to_string();
|
|
896
|
+
training_data += &":";
|
|
897
|
+
training_data += &serde_json::to_string(&summary_answer).unwrap();
|
|
898
|
+
training_data += &"\n";
|
|
899
|
+
}
|
|
900
|
+
AnswerFormat::DE_type(_) => panic!("DE type not valid for summary"),
|
|
901
|
+
}
|
|
936
902
|
}
|
|
937
903
|
|
|
938
904
|
let system_prompt: String = String::from(
|
|
@@ -1001,7 +967,7 @@ fn get_summary_string() -> String {
|
|
|
1001
967
|
//const geneExpression: &str = &"geneExpression";
|
|
1002
968
|
|
|
1003
969
|
#[derive(PartialEq, Debug, Clone, schemars::JsonSchema, serde::Serialize, serde::Deserialize)]
|
|
1004
|
-
struct SummaryType {
|
|
970
|
+
pub struct SummaryType {
|
|
1005
971
|
// Serde uses this for deserialization.
|
|
1006
972
|
#[serde(default = "get_summary_string")]
|
|
1007
973
|
// Schemars uses this for schema generation.
|
|
@@ -1014,7 +980,7 @@ struct SummaryType {
|
|
|
1014
980
|
|
|
1015
981
|
impl SummaryType {
|
|
1016
982
|
#[allow(dead_code)]
|
|
1017
|
-
pub fn sort_summarytype_struct(
|
|
983
|
+
pub fn sort_summarytype_struct(mut self) -> SummaryType {
|
|
1018
984
|
// This function is necessary for testing (test_ai.rs) to see if two variables of type "SummaryType" are equal or not. Without this a vector of two Summarytype holding the same values but in different order will be classified separately.
|
|
1019
985
|
self.summaryterms.sort();
|
|
1020
986
|
|
|
@@ -1022,6 +988,7 @@ impl SummaryType {
|
|
|
1022
988
|
Some(ref mut filterterms) => filterterms.sort(),
|
|
1023
989
|
None => {}
|
|
1024
990
|
}
|
|
991
|
+
self.clone()
|
|
1025
992
|
}
|
|
1026
993
|
}
|
|
1027
994
|
|
|
@@ -1039,7 +1006,7 @@ impl PartialOrd for SummaryTerms {
|
|
|
1039
1006
|
(SummaryTerms::clinical(_), SummaryTerms::clinical(_)) => Some(std::cmp::Ordering::Equal),
|
|
1040
1007
|
(SummaryTerms::geneExpression(_), SummaryTerms::geneExpression(_)) => Some(std::cmp::Ordering::Equal),
|
|
1041
1008
|
(SummaryTerms::clinical(_), SummaryTerms::geneExpression(_)) => Some(std::cmp::Ordering::Greater),
|
|
1042
|
-
(SummaryTerms::geneExpression(_), SummaryTerms::clinical(_)) => Some(std::cmp::Ordering::
|
|
1009
|
+
(SummaryTerms::geneExpression(_), SummaryTerms::clinical(_)) => Some(std::cmp::Ordering::Less),
|
|
1043
1010
|
}
|
|
1044
1011
|
}
|
|
1045
1012
|
}
|
|
@@ -1313,10 +1280,40 @@ fn validate_summary_output(
|
|
|
1313
1280
|
+ &categorical_filter.value
|
|
1314
1281
|
+ &"\"},";
|
|
1315
1282
|
validated_filter_terms_PP += &string_json;
|
|
1316
|
-
filter_hits += 1; // Once numeric term is also implemented, this statement will go outside the match block
|
|
1317
1283
|
}
|
|
1318
|
-
FilterTerm::Numeric(
|
|
1284
|
+
FilterTerm::Numeric(numeric_filter) => {
|
|
1285
|
+
let string_json;
|
|
1286
|
+
if numeric_filter.greaterThan.is_some() && numeric_filter.lessThan.is_none() {
|
|
1287
|
+
string_json = "{\"term\":\"".to_string()
|
|
1288
|
+
+ &numeric_filter.term
|
|
1289
|
+
+ &"\", \"gt\":\""
|
|
1290
|
+
+ &numeric_filter.greaterThan.unwrap().to_string()
|
|
1291
|
+
+ &"\"},";
|
|
1292
|
+
} else if numeric_filter.greaterThan.is_none() && numeric_filter.lessThan.is_some() {
|
|
1293
|
+
string_json = "{\"term\":\"".to_string()
|
|
1294
|
+
+ &numeric_filter.term
|
|
1295
|
+
+ &"\", \"lt\":\""
|
|
1296
|
+
+ &numeric_filter.lessThan.unwrap().to_string()
|
|
1297
|
+
+ &"\"},";
|
|
1298
|
+
} else if numeric_filter.greaterThan.is_some() && numeric_filter.lessThan.is_some() {
|
|
1299
|
+
string_json = "{\"term\":\"".to_string()
|
|
1300
|
+
+ &numeric_filter.term
|
|
1301
|
+
+ &"\", \"lt\":\""
|
|
1302
|
+
+ &numeric_filter.lessThan.unwrap().to_string()
|
|
1303
|
+
+ &"\", \"gt\":\""
|
|
1304
|
+
+ &numeric_filter.greaterThan.unwrap().to_string()
|
|
1305
|
+
+ &"\"},";
|
|
1306
|
+
} else {
|
|
1307
|
+
// When both greater and less than are none
|
|
1308
|
+
panic!(
|
|
1309
|
+
"Numeric filter term {} is missing both greater than and less than values. One of them must be defined",
|
|
1310
|
+
&numeric_filter.term
|
|
1311
|
+
);
|
|
1312
|
+
}
|
|
1313
|
+
validated_filter_terms_PP += &string_json;
|
|
1314
|
+
}
|
|
1319
1315
|
};
|
|
1316
|
+
filter_hits += 1;
|
|
1320
1317
|
}
|
|
1321
1318
|
println!("validated_filter_terms_PP:{}", validated_filter_terms_PP);
|
|
1322
1319
|
if filter_hits > 0 {
|