@sjcrh/proteinpaint-rust 2.167.0 → 2.170.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/src/aichatbot.rs CHANGED
@@ -1,7 +1,7 @@
  // Syntax: cd .. && cargo build --release && time cat ~/sjpp/test.txt | target/release/aichatbot
  #![allow(non_snake_case)]
- use anyhow::Result;
- use json::JsonValue;
+ //use anyhow::Result;
+ //use json::JsonValue;
  use r2d2_sqlite::SqliteConnectionManager;
  use rig::agent::AgentBuilder;
  use rig::completion::Prompt;
@@ -11,40 +11,53 @@ use schemars::JsonSchema;
  use serde_json::{Map, Value, json};
  use std::collections::HashMap;
  use std::fs;
- use std::io;
- use std::path::Path;
- mod ollama; // Importing custom rig module for invoking ollama server
- mod sjprovider; // Importing custom rig module for invoking SJ GPU server
-
- mod test_ai; // Test examples for AI chatbot

  // Struct for intaking data from dataset json
  #[derive(PartialEq, Debug, Clone, schemars::JsonSchema, serde::Serialize, serde::Deserialize)]
  pub struct AiJsonFormat {
- hasGeneExpression: bool,
- db: String, // Dataset db
- genedb: String, // Gene db
- charts: Vec<Charts>,
+ pub hasGeneExpression: bool,
+ pub db: String, // Dataset db
+ pub genedb: String, // Gene db
+ pub charts: Vec<TrainTestData>,
+ }
+
+ #[derive(PartialEq, Debug, Clone, schemars::JsonSchema, serde::Serialize, serde::Deserialize)]
+ pub struct TrainTestData {
+ pub r#type: String,
+ pub SystemPrompt: String,
+ pub TrainingData: Vec<QuestionAnswer>,
+ pub TestData: Vec<QuestionAnswer>,
  }

  #[derive(PartialEq, Debug, Clone, schemars::JsonSchema, serde::Serialize, serde::Deserialize)]
- enum Charts {
- // More chart types will be added here later
- Summary(TrainTestData),
- DE(TrainTestData),
+ pub struct QuestionAnswer {
+ pub question: String,
+ pub answer: AnswerFormat,
  }

  #[derive(PartialEq, Debug, Clone, schemars::JsonSchema, serde::Serialize, serde::Deserialize)]
- struct TrainTestData {
- SystemPrompt: String,
- TrainingData: Vec<QuestionAnswer>,
- TestData: Vec<QuestionAnswer>,
+ pub enum AnswerFormat {
+ #[allow(non_camel_case_types)]
+ summary_type(SummaryType),
+ #[allow(non_camel_case_types)]
+ DE_type(DEType),
  }

  #[derive(PartialEq, Debug, Clone, schemars::JsonSchema, serde::Serialize, serde::Deserialize)]
- struct QuestionAnswer {
- question: String,
- answer: String,
+ pub struct DEType {
+ action: String,
+ DE_output: DETerms,
+ }
+
+ #[derive(PartialEq, Debug, Clone, schemars::JsonSchema, serde::Serialize, serde::Deserialize)]
+ pub struct DETerms {
+ group1: GroupType,
+ group2: GroupType,
+ }
+
+ #[derive(PartialEq, Debug, Clone, schemars::JsonSchema, serde::Serialize, serde::Deserialize)]
+ pub struct GroupType {
+ name: String,
  }

  #[allow(non_camel_case_types)]
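For illustration, a dataset JSON file matching the new AiJsonFormat / TrainTestData / QuestionAnswer / AnswerFormat structs above could look roughly like the sketch below. This is a hypothetical example rather than a file shipped with the package: the paths, prompt, and question are invented, and the "answer" object assumes serde's default externally tagged encoding of the AnswerFormat enum.

// Hypothetical sketch of the dataset JSON consumed by AiJsonFormat (all values invented).
fn main() {
    let example = r#"{
        "hasGeneExpression": true,
        "db": "files/hg38/ExampleDataset/db",
        "genedb": "genes/gencode.hg38.db",
        "charts": [{
            "type": "DE",
            "SystemPrompt": "Extract the two groups being compared ...",
            "TrainingData": [{
                "question": "Which genes are most upregulated between males and females?",
                "answer": { "DE_type": {
                    "action": "DE",
                    "DE_output": { "group1": { "name": "male" }, "group2": { "name": "female" } }
                } }
            }],
            "TestData": []
        }]
    }"#;
    // With the crate's types in scope this would be parsed as:
    // let parsed: AiJsonFormat = serde_json::from_str(example).unwrap();
    let parsed: serde_json::Value = serde_json::from_str(example).unwrap();
    assert_eq!(parsed["charts"][0]["type"], "DE");
}

Compared with the old format, the chart kind moves from an enum wrapper (Charts::Summary / Charts::DE) into the "type" string, and the answer changes from a free-form string to a structured object.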
@@ -56,180 +69,188 @@ pub enum llm_backend {

  #[derive(Debug, JsonSchema)]
  #[allow(dead_code)]
- struct OutputJson {
+ pub struct OutputJson {
  pub answer: String,
  }

- #[tokio::main]
- async fn main() -> Result<()> {
- let mut input = String::new();
- match io::stdin().read_line(&mut input) {
- // Accepting the piped input from nodejs (or command line from testing)
- Ok(_n) => {
- let input_json = json::parse(&input);
- match input_json {
- Ok(json_string) => {
- //println!("json_string:{}", json_string);
- let user_input_json: &JsonValue = &json_string["user_input"];
- let user_input: &str;
- match user_input_json.as_str() {
- Some(inp) => user_input = inp,
- None => panic!("user_input field is missing in input json"),
- }
-
- if user_input.len() == 0 {
- panic!("The user input is empty");
- }
-
- let tpmasterdir_json: &JsonValue = &json_string["tpmasterdir"];
- let tpmasterdir: &str;
- match tpmasterdir_json.as_str() {
- Some(inp) => tpmasterdir = inp,
- None => panic!("tpmasterdir not found"),
- }
-
- let binpath_json: &JsonValue = &json_string["binpath"];
- let binpath: &str;
- match binpath_json.as_str() {
- Some(inp) => binpath = inp,
- None => panic!("binpath not found"),
- }
-
- let ai_json_file_json: &JsonValue = &json_string["aifiles"];
- let ai_json_file: String;
- match ai_json_file_json.as_str() {
- Some(inp) => ai_json_file = String::from(binpath) + &"/../../" + &inp,
- None => {
- panic!("ai json file not found")
- }
- }
-
- let ai_json_file = Path::new(&ai_json_file);
- let ai_json_file_path;
- let current_dir = std::env::current_dir().unwrap();
- match ai_json_file.canonicalize() {
- Ok(p) => ai_json_file_path = p,
- Err(_) => {
- panic!(
- "AI JSON file path not found:{:?}, current directory:{:?}",
- ai_json_file, current_dir
- )
- }
- }
-
- // Read the file
- let ai_data = fs::read_to_string(ai_json_file_path).unwrap();
-
- // Parse the JSON data
- let ai_json: AiJsonFormat =
- serde_json::from_str(&ai_data).expect("AI JSON file does not have the correct format");
-
- let genedb = String::from(tpmasterdir) + &"/" + &ai_json.genedb;
- let dataset_db = String::from(tpmasterdir) + &"/" + &ai_json.db;
-
- let apilink_json: &JsonValue = &json_string["apilink"];
- let apilink: &str;
- match apilink_json.as_str() {
- Some(inp) => apilink = inp,
- None => panic!("apilink field is missing in input json"),
- }
-
- let comp_model_name_json: &JsonValue = &json_string["comp_model_name"];
- let comp_model_name: &str;
- match comp_model_name_json.as_str() {
- Some(inp) => comp_model_name = inp,
- None => panic!("comp_model_name field is missing in input json"),
- }
-
- let embedding_model_name_json: &JsonValue = &json_string["embedding_model_name"];
- let embedding_model_name: &str;
- match embedding_model_name_json.as_str() {
- Some(inp) => embedding_model_name = inp,
- None => panic!("embedding_model_name field is missing in input json"),
- }
-
- let llm_backend_name_json: &JsonValue = &json_string["llm_backend_name"];
- let llm_backend_name: &str;
- match llm_backend_name_json.as_str() {
- Some(inp) => llm_backend_name = inp,
- None => panic!("llm_backend_name field is missing in input json"),
- }
-
- let llm_backend_type: llm_backend;
- let mut final_output: Option<String> = None;
- let temperature: f64 = 0.01;
- let max_new_tokens: usize = 512;
- let top_p: f32 = 0.95;
- let testing = false; // This variable is always false in production, this is true in test_ai.rs for testing code
- if llm_backend_name != "ollama" && llm_backend_name != "SJ" {
- panic!(
- "This code currently supports only Ollama and SJ provider. llm_backend_name must be \"ollama\" or \"SJ\""
- );
- } else if llm_backend_name == "ollama".to_string() {
- llm_backend_type = llm_backend::Ollama();
- // Initialize Ollama client
- let ollama_client = ollama::Client::builder()
- .base_url(apilink)
- .build()
- .expect("Ollama server not found");
- let embedding_model = ollama_client.embedding_model(embedding_model_name);
- let comp_model = ollama_client.completion_model(comp_model_name);
- final_output = run_pipeline(
- user_input,
- comp_model,
- embedding_model,
- llm_backend_type,
- temperature,
- max_new_tokens,
- top_p,
- &dataset_db,
- &genedb,
- &ai_json,
- testing,
- )
- .await;
- } else if llm_backend_name == "SJ".to_string() {
- llm_backend_type = llm_backend::Sj();
- // Initialize Sj provider client
- let sj_client = sjprovider::Client::builder()
- .base_url(apilink)
- .build()
- .expect("SJ server not found");
- let embedding_model = sj_client.embedding_model(embedding_model_name);
- let comp_model = sj_client.completion_model(comp_model_name);
- final_output = run_pipeline(
- user_input,
- comp_model,
- embedding_model,
- llm_backend_type,
- temperature,
- max_new_tokens,
- top_p,
- &dataset_db,
- &genedb,
- &ai_json,
- testing,
- )
- .await;
- }
-
- match final_output {
- Some(fin_out) => {
- println!("final_output:{:?}", fin_out.replace("\\", ""));
- }
- None => {
- println!("final_output:{{\"{}\":\"{}\"}}", "action", "unknown");
- }
- }
- }
- Err(error) => println!("Incorrect json:{}", error),
- }
- }
- Err(error) => println!("Piping error: {}", error),
- }
- Ok(())
- }
+ //#[tokio::main]
+ //async fn main() -> Result<()> {
+ // let mut input = String::new();
+ // match io::stdin().read_line(&mut input) {
+ // // Accepting the piped input from nodejs (or command line from testing)
+ // Ok(_n) => {
+ // let input_json = json::parse(&input);
+ // match input_json {
+ // Ok(json_string) => {
+ // //println!("json_string:{}", json_string);
+ // let user_input_json: &JsonValue = &json_string["user_input"];
+ // let user_input: &str;
+ // match user_input_json.as_str() {
+ // Some(inp) => user_input = inp,
+ // None => panic!("user_input field is missing in input json"),
+ // }
+ // let dataset_db_json: &JsonValue = &json_string["dataset_db"];
+ // let dataset_db_str: &str;
+ // match dataset_db_json.as_str() {
+ // Some(inp) => dataset_db_str = inp,
+ // None => panic!("dataset_db field is missing in input json"),
+ // }
+ // let genedb_json: &JsonValue = &json_string["genedb"];
+ // let genedb_str: &str;
+ // match genedb_json.as_str() {
+ // Some(inp) => genedb_str = inp,
+ // None => panic!("genedb field is missing in input json"),
+ // }
+ // let aiRoute_json: &JsonValue = &json_string["aiRoute"];
+ // let aiRoute_str: &str;
+ // match aiRoute_json.as_str() {
+ // Some(inp) => aiRoute_str = inp,
+ // None => panic!("aiRoute field is missing in input json"),
+ // }
+ // if user_input.len() == 0 {
+ // panic!("The user input is empty");
+ // }
+ // let tpmasterdir_json: &JsonValue = &json_string["tpmasterdir"];
+ // let tpmasterdir: &str;
+ // match tpmasterdir_json.as_str() {
+ // Some(inp) => tpmasterdir = inp,
+ // None => panic!("tpmasterdir not found"),
+ // }
+ // let binpath_json: &JsonValue = &json_string["binpath"];
+ // let binpath: &str;
+ // match binpath_json.as_str() {
+ // Some(inp) => binpath = inp,
+ // None => panic!("binpath not found"),
+ // }
+ // let ai_json_file_json: &JsonValue = &json_string["aifiles"];
+ // let ai_json_file: String;
+ // match ai_json_file_json.as_str() {
+ // Some(inp) => ai_json_file = String::from(binpath) + &"/../../" + &inp,
+ // None => {
+ // panic!("ai json file not found")
+ // }
+ // }
+ // let ai_json_file = Path::new(&ai_json_file);
+ // let ai_json_file_path;
+ // let current_dir = std::env::current_dir().unwrap();
+ // match ai_json_file.canonicalize() {
+ // Ok(p) => ai_json_file_path = p,
+ // Err(_) => {
+ // panic!(
+ // "AI JSON file path not found:{:?}, current directory:{:?}",
+ // ai_json_file, current_dir
+ // )
+ // }
+ // }
+ // // Read the file
+ // let ai_data = fs::read_to_string(ai_json_file_path).unwrap();
+ // // Parse the JSON data
+ // let ai_json: AiJsonFormat =
+ // serde_json::from_str(&ai_data).expect("AI JSON file does not have the correct format");
+ // let genedb = String::from(tpmasterdir) + &"/" + &genedb_str;
+ // let dataset_db = String::from(tpmasterdir) + &"/" + &dataset_db_str;
+ // let airoute = String::from(binpath) + &"/../../" + &aiRoute_str;
+ // let apilink_json: &JsonValue = &json_string["apilink"];
+ // let apilink: &str;
+ // match apilink_json.as_str() {
+ // Some(inp) => apilink = inp,
+ // None => panic!("apilink field is missing in input json"),
+ // }
+ // let comp_model_name_json: &JsonValue = &json_string["comp_model_name"];
+ // let comp_model_name: &str;
+ // match comp_model_name_json.as_str() {
+ // Some(inp) => comp_model_name = inp,
+ // None => panic!("comp_model_name field is missing in input json"),
+ // }
+ // let embedding_model_name_json: &JsonValue = &json_string["embedding_model_name"];
+ // let embedding_model_name: &str;
+ // match embedding_model_name_json.as_str() {
+ // Some(inp) => embedding_model_name = inp,
+ // None => panic!("embedding_model_name field is missing in input json"),
+ // }
+ // let llm_backend_name_json: &JsonValue = &json_string["llm_backend_name"];
+ // let llm_backend_name: &str;
+ // match llm_backend_name_json.as_str() {
+ // Some(inp) => llm_backend_name = inp,
+ // None => panic!("llm_backend_name field is missing in input json"),
+ // }
+ // let llm_backend_type: llm_backend;
+ // let mut final_output: Option<String> = None;
+ // let temperature: f64 = 0.01;
+ // let max_new_tokens: usize = 512;
+ // let top_p: f32 = 0.95;
+ // let testing = false; // This variable is always false in production, this is true in test_ai.rs for testing code
+ // if llm_backend_name != "ollama" && llm_backend_name != "SJ" {
+ // panic!(
+ // "This code currently supports only Ollama and SJ provider. llm_backend_name must be \"ollama\" or \"SJ\""
+ // );
+ // } else if llm_backend_name == "ollama".to_string() {
+ // llm_backend_type = llm_backend::Ollama();
+ // // Initialize Ollama client
+ // let ollama_client = ollama::Client::builder()
+ // .base_url(apilink)
+ // .build()
+ // .expect("Ollama server not found");
+ // let embedding_model = ollama_client.embedding_model(embedding_model_name);
+ // let comp_model = ollama_client.completion_model(comp_model_name);
+ // final_output = run_pipeline(
+ // user_input,
+ // comp_model,
+ // embedding_model,
+ // llm_backend_type,
+ // temperature,
+ // max_new_tokens,
+ // top_p,
+ // &dataset_db,
+ // &genedb,
+ // &ai_json,
+ // &airoute,
+ // testing,
+ // )
+ // .await;
+ // } else if llm_backend_name == "SJ".to_string() {
+ // llm_backend_type = llm_backend::Sj();
+ // // Initialize Sj provider client
+ // let sj_client = sjprovider::Client::builder()
+ // .base_url(apilink)
+ // .build()
+ // .expect("SJ server not found");
+ // let embedding_model = sj_client.embedding_model(embedding_model_name);
+ // let comp_model = sj_client.completion_model(comp_model_name);
+ // final_output = run_pipeline(
+ // user_input,
+ // comp_model,
+ // embedding_model,
+ // llm_backend_type,
+ // temperature,
+ // max_new_tokens,
+ // top_p,
+ // &dataset_db,
+ // &genedb,
+ // &ai_json,
+ // &airoute,
+ // testing,
+ // )
+ // .await;
+ // }
+ // match final_output {
+ // Some(fin_out) => {
+ // println!("final_output:{:?}", fin_out.replace("\\", ""));
+ // }
+ // None => {
+ // println!("final_output:{{\"{}\":\"{}\"}}", "action", "unknown");
+ // }
+ // }
+ // }
+ // Err(error) => println!("Incorrect json:{}", error),
+ // }
+ // }
+ // Err(error) => println!("Piping error: {}", error),
+ // }
+ // Ok(())
+ //}

+ #[allow(dead_code)]
  pub async fn run_pipeline(
  user_input: &str,
  comp_model: impl rig::completion::CompletionModel + 'static,
@@ -241,6 +262,7 @@ pub async fn run_pipeline(
  dataset_db: &str,
  genedb: &str,
  ai_json: &AiJsonFormat,
+ ai_route: &str,
  testing: bool,
  ) -> Option<String> {
  let mut classification: String = classify_query_by_dataset_type(
@@ -251,6 +273,7 @@ pub async fn run_pipeline(
  temperature,
  max_new_tokens,
  top_p,
+ ai_route,
  )
  .await;
  classification = classification.replace("\"", "");
@@ -373,104 +396,40 @@ pub async fn run_pipeline(
  Some(final_output)
  }

- async fn classify_query_by_dataset_type(
+ #[allow(dead_code)]
+ pub async fn classify_query_by_dataset_type(
  user_input: &str,
  comp_model: impl rig::completion::CompletionModel + 'static,
- embedding_model: impl rig::embeddings::EmbeddingModel + 'static,
+ _embedding_model: impl rig::embeddings::EmbeddingModel + 'static,
  llm_backend_type: &llm_backend,
  temperature: f64,
  max_new_tokens: usize,
  top_p: f32,
+ ai_route: &str,
  ) -> String {
- // Create a string to hold the file contents
- let contents = String::from("SNV/SNP or point mutations nucleotide mutations are very common forms of mutations which can often give rise to genetic diseases such as cancer, Alzheimer's disease etc. They can be duw to substitution of nucleotide, or insertion or deletion of a nucleotide. Indels are multi-nucleotide insertion/deletion/substitutions. Complex indels are indels where insertion and deletion have happened in the same genomic locus. Every genomic sample from each patient has its own set of mutations therefore requiring personalized treatment.
-
- If a ProteinPaint dataset contains SNV/Indel/SV data then return JSON with single key, 'snv_indel'.
-
- ---
-
- Copy number variation (CNV) is a phenomenon in which sections of the genome are repeated and the number of repeats in the genome varies between individuals.[1] Copy number variation is a special type of structural variation: specifically, it is a type of duplication or deletion event that affects a considerable number of base pairs.
+ // Read the file
+ let ai_route_data = fs::read_to_string(ai_route).unwrap();

- If a ProteinPaint dataset contains copy number variation data then return JSON with single key, 'cnv'.
-
- ---
-
- Structural variants/fusions (SV) are genomic mutations when eith a DNA region is translocated or copied to an entirely different genomic locus. In case of transcriptomic data, when RNA is fused from two different genes its called a gene fusion.
-
- If a ProteinPaint dataset contains structural variation or gene fusion data then return JSON with single key, 'sv_fusion'.
- ---
-
- Hierarchical clustering of gene expression is an unsupervised learning technique where several number of relevant genes and the samples are clustered so as to determine (previously unknown) cohorts of samples (or patients) or structure in data. It is very commonly used to determine subtypes of a particular disease based on RNA sequencing data.
-
- If a ProteinPaint dataset contains hierarchical data then return JSON with single key, 'hierarchical'.
-
- ---
-
- Differential Gene Expression (DGE or DE) is a technique where the most upregulated (or highest) and downregulated (or lowest) genes between two cohorts of samples (or patients) are determined from a pool of THOUSANDS of genes. Differential gene expression CANNOT be computed for a SINGLE gene. A volcano plot is shown with fold-change in the x-axis and adjusted p-value on the y-axis. So, the upregulated and downregulared genes are on opposite sides of the graph and the most significant genes (based on adjusted p-value) is on the top of the graph. Following differential gene expression generally GeneSet Enrichment Analysis (GSEA) is carried out where based on the genes and their corresponding fold changes the upregulation/downregulation of genesets (or pathways) is determined.
-
- Sample Query1: \"Which gene has the highest expression between the two genders\"
- Sample Answer1: { \"answer\": \"dge\" }
-
- Sample Query2: \"Which gene has the lowest expression between the two races\"
- Sample Answer2: { \"answer\": \"dge\" }
-
- Sample Query1: \"Which genes are the most upregulated genes between group A and group B\"
- Sample Answer1: { \"answer\": \"dge\" }
-
- Sample Query3: \"Which gene are overexpressed between male and female\"
- Sample Answer3: { \"answer\": \"dge\" }
-
- Sample Query4: \"Which gene are housekeeping genes between male and female\"
- Sample Answer4: { \"answer\": \"dge\" }
-
-
- If a ProteinPaint dataset contains differential gene expression data then return JSON with single key, 'dge'.
-
- ---
-
- Survival analysis (also called time-to-event analysis or duration analysis) is a branch of statistics aimed at analyzing the duration of time from a well-defined time origin until one or more events happen, called survival times or duration times. In other words, in survival analysis, we are interested in a certain event and want to analyze the time until the event happens. Generally in survival analysis survival rates between two (or more) cohorts of patients is compared.
-
- There are two main methods of survival analysis:
-
- 1) Kaplan-Meier (HM) analysis is a univariate test that only takes into account a single categorical variable.
- 2) Cox proportional hazards model (coxph) is a multivariate test that can take into account multiple variables.
-
- The hazard ratio (HR) is an indicator of the effect of the stimulus (e.g. drug dose, treatment) between two cohorts of patients.
- HR = 1: No effect
- HR < 1: Reduction in the hazard
- HR > 1: Increase in Hazard
-
- Sample Query1: \"Compare survival rates between group A and B\"
- Sample Answer1: { \"answer\": \"survival\" }
-
-
- If a ProteinPaint dataset contains survival data then return JSON with single key, 'survival'.
-
- ---
-
- Next generation sequencing reads (NGS) are mapped to a human genome using alignment algorithm such as burrows-wheelers alignment algorithm. Then these reads are called using variant calling algorithms such as GATK (Genome Analysis Toolkit). However this type of analysis is too compute intensive and beyond the scope of visualization software such as ProteinPaint.
-
- If a user query asks about variant calling or mapping reads then JSON with single key, 'variant_calling'.
-
- ---
-
- Summary plot in ProteinPaint shows the various facets of the datasets. Show expression of a SINGLE gene or compare the expression of a SINGLE gene across two different cohorts defined by the user. It may show all the samples according to their respective diagnosis or subtypes of cancer. It is also useful for comparing and correlating different clinical variables. It can show all possible distributions, frequency of a category, overlay, correlate or cross-tabulate with another variable on top of it. If a user query asks about a SINGLE gene expression or correlating clinical variables then return JSON with single key, 'summary'.
-
- Sample Query1: \"Show all fusions for patients with age less than 30\"
- Sample Answer1: { \"answer\": \"summary\" }
-
- Sample Query2: \"List all molecular subtypes of leukemia\"
- Sample Answer2: { \"answer\": \"summary\" }
-
- Sample Query3: \"is tp53 expression higher in men than women ?\"
- Sample Answer3: { \"answer\": \"summary\" }
-
- Sample Query4: \"Compare ATM expression between races for women greater than 80yrs\"
- Sample Answer4: { \"answer\": \"summary\" }
+ // Parse the JSON data
+ let ai_json: Value = serde_json::from_str(&ai_route_data).expect("AI JSON file does not have the correct format");

+ // Create a string to hold the file contents
+ let mut contents = String::from("");
+
+ if let Some(object) = ai_json.as_object() {
+ contents = object["general"].to_string();
+ for (key, value) in object {
+ if key != "general" {
+ contents += &value.as_str().unwrap();
+ contents += "---"; // Adding delimiter
+ }
+ }
+ }

- If a query does not match any of the fields described above, then return JSON with single key, 'none'
- ");
+ // Removing the last "---" characters
+ contents.pop();
+ contents.pop();
+ contents.pop();

  // Split the contents by the delimiter "---"
  let parts: Vec<&str> = contents.split("---").collect();
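The hard-coded classification text above is thus replaced by a small "aiRoute" JSON file: an object of strings whose "general" entry is taken first and whose remaining entries are concatenated with "---" delimiters. The sketch below shows roughly what such a file and the assembly step could look like; it is a hypothetical example, the key names other than "general" and the abridged descriptions are invented, and only serde_json (already a dependency of this file) is used.

// Hypothetical sketch of an aiRoute file and the "---"-joined contents built from it.
fn main() {
    let route = r#"{
        "general": "Classify the user query into summary, dge, hierarchical, snv_indel, cnv, variant_calling, sv_fusion or none.",
        "dge": "Differential gene expression compares THOUSANDS of genes between two cohorts ...",
        "summary": "The summary plot shows a SINGLE gene or clinical variables across a cohort ..."
    }"#;
    let ai_json: serde_json::Value = serde_json::from_str(route).unwrap();
    // Same assembly as the function above: "general" first, then the other entries joined by "---".
    let mut contents = ai_json["general"].to_string();
    if let Some(object) = ai_json.as_object() {
        for (key, value) in object {
            if key != "general" {
                contents += value.as_str().unwrap();
                contents += "---";
            }
        }
    }
    // Drop the trailing "---" delimiter, as the function above does with three pop() calls.
    for _ in 0..3 {
        contents.pop();
    }
    let parts: Vec<&str> = contents.split("---").collect();
    assert_eq!(parts.len(), 2); // one part per non-"general" entry in this sketch
}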
@@ -501,18 +460,18 @@ If a query does not match any of the fields described above, then return JSON wi
  rag_docs.push(part.trim().to_string())
  }

- //let top_k: usize = 3;
+ //let top_k: usize = 3; // Embedding model not used currently
  // Create embeddings and add to vector store
- let embeddings = EmbeddingsBuilder::new(embedding_model.clone())
- .documents(rag_docs)
- .expect("Reason1")
- .build()
- .await
- .unwrap();
+ //let embeddings = EmbeddingsBuilder::new(embedding_model.clone())
+ // .documents(rag_docs)
+ // .expect("Reason1")
+ // .build()
+ // .await
+ // .unwrap();

- // Create vector store
- let mut vector_store = InMemoryVectorStore::<String>::default();
- InMemoryVectorStore::add_documents(&mut vector_store, embeddings);
+ //// Create vector store
+ //let mut vector_store = InMemoryVectorStore::<String>::default();
+ //InMemoryVectorStore::add_documents(&mut vector_store, embeddings);

  // Create RAG agent
  let agent = AgentBuilder::new(comp_model).preamble(&(String::from("Generate classification for the user query into summary, dge, hierarchical, snv_indel, cnv, variant_calling, sv_fusion and none categories. Return output in JSON with ALWAYS a single word answer { \"answer\": \"dge\" }, that is 'summary' for summary plot, 'dge' for differential gene expression, 'hierarchical' for hierarchical clustering, 'snv_indel' for SNV/Indel, 'cnv' for CNV and 'sv_fusion' for SV/fusion, 'variant_calling' for variant calling, 'surivial' for survival data, 'none' for none of the previously described categories. The summary plot list and summarizes the cohort of patients according to the user query. The answer should always be in lower case\n The options are as follows:\n") + &contents + "\nQuestion= {question} \nanswer")).temperature(temperature).additional_params(additional).build();
@@ -580,6 +539,7 @@ struct DEOutput {
  group2: Group,
  }

+ #[allow(dead_code)]
  #[allow(non_snake_case)]
  async fn extract_DE_search_terms_from_query(
  user_input: &str,
@@ -853,7 +813,7 @@ async fn parse_dataset_db(db: &str) -> (Vec<String>, Vec<DbRows>) {
  (rag_docs, db_vec)
  }

- async fn extract_summary_information(
+ pub async fn extract_summary_information(
  user_input: &str,
  comp_model: impl rig::completion::CompletionModel + 'static,
  _embedding_model: impl rig::embeddings::EmbeddingModel + 'static,
@@ -911,8 +871,8 @@ async fn extract_summary_information(

  let mut summary_data_check: Option<TrainTestData> = None;
  for chart in ai_json.charts.clone() {
- if let Charts::Summary(traindata) = chart {
- summary_data_check = Some(traindata);
+ if chart.r#type == "Summary" {
+ summary_data_check = Some(chart);
  break;
  }
  }
@@ -922,17 +882,23 @@ async fn extract_summary_information(
  let mut training_data: String = String::from("");
  let mut train_iter = 0;
  for ques_ans in summary_data.TrainingData {
- train_iter += 1;
- training_data += "Example question";
- training_data += &train_iter.to_string();
- training_data += &":";
- training_data += &ques_ans.question;
- training_data += &" ";
- training_data += "Example answer";
- training_data += &train_iter.to_string();
- training_data += &":";
- training_data += &ques_ans.answer;
- training_data += &"\n";
+ match ques_ans.answer {
+ AnswerFormat::summary_type(sum) => {
+ let summary_answer: SummaryType = sum;
+ train_iter += 1;
+ training_data += "Example question";
+ training_data += &train_iter.to_string();
+ training_data += &":";
+ training_data += &ques_ans.question;
+ training_data += &" ";
+ training_data += "Example answer";
+ training_data += &train_iter.to_string();
+ training_data += &":";
+ training_data += &serde_json::to_string(&summary_answer).unwrap();
+ training_data += &"\n";
+ }
+ AnswerFormat::DE_type(_) => panic!("DE type not valid for summary"),
+ }
  }

  let system_prompt: String = String::from(
@@ -1001,7 +967,7 @@ fn get_summary_string() -> String {
  //const geneExpression: &str = &"geneExpression";

  #[derive(PartialEq, Debug, Clone, schemars::JsonSchema, serde::Serialize, serde::Deserialize)]
- struct SummaryType {
+ pub struct SummaryType {
  // Serde uses this for deserialization.
  #[serde(default = "get_summary_string")]
  // Schemars uses this for schema generation.
@@ -1014,7 +980,7 @@ struct SummaryType {

  impl SummaryType {
  #[allow(dead_code)]
- pub fn sort_summarytype_struct(&mut self) {
+ pub fn sort_summarytype_struct(mut self) -> SummaryType {
  // This function is necessary for testing (test_ai.rs) to see if two variables of type "SummaryType" are equal or not. Without this a vector of two Summarytype holding the same values but in different order will be classified separately.
  self.summaryterms.sort();

@@ -1022,6 +988,7 @@ impl SummaryType {
  Some(ref mut filterterms) => filterterms.sort(),
  None => {}
  }
+ self.clone()
  }
  }

@@ -1039,7 +1006,7 @@ impl PartialOrd for SummaryTerms {
  (SummaryTerms::clinical(_), SummaryTerms::clinical(_)) => Some(std::cmp::Ordering::Equal),
  (SummaryTerms::geneExpression(_), SummaryTerms::geneExpression(_)) => Some(std::cmp::Ordering::Equal),
  (SummaryTerms::clinical(_), SummaryTerms::geneExpression(_)) => Some(std::cmp::Ordering::Greater),
- (SummaryTerms::geneExpression(_), SummaryTerms::clinical(_)) => Some(std::cmp::Ordering::Greater),
+ (SummaryTerms::geneExpression(_), SummaryTerms::clinical(_)) => Some(std::cmp::Ordering::Less),
  }
  }
  }
@@ -1313,10 +1280,40 @@ fn validate_summary_output(
  + &categorical_filter.value
  + &"\"},";
  validated_filter_terms_PP += &string_json;
- filter_hits += 1; // Once numeric term is also implemented, this statement will go outside the match block
  }
- FilterTerm::Numeric(_numeric_term) => {} // To be implemented later
+ FilterTerm::Numeric(numeric_filter) => {
+ let string_json;
+ if numeric_filter.greaterThan.is_some() && numeric_filter.lessThan.is_none() {
+ string_json = "{\"term\":\"".to_string()
+ + &numeric_filter.term
+ + &"\", \"gt\":\""
+ + &numeric_filter.greaterThan.unwrap().to_string()
+ + &"\"},";
+ } else if numeric_filter.greaterThan.is_none() && numeric_filter.lessThan.is_some() {
+ string_json = "{\"term\":\"".to_string()
+ + &numeric_filter.term
+ + &"\", \"lt\":\""
+ + &numeric_filter.lessThan.unwrap().to_string()
+ + &"\"},";
+ } else if numeric_filter.greaterThan.is_some() && numeric_filter.lessThan.is_some() {
+ string_json = "{\"term\":\"".to_string()
+ + &numeric_filter.term
+ + &"\", \"lt\":\""
+ + &numeric_filter.lessThan.unwrap().to_string()
+ + &"\", \"gt\":\""
+ + &numeric_filter.greaterThan.unwrap().to_string()
+ + &"\"},";
+ } else {
+ // When both greater and less than are none
+ panic!(
+ "Numeric filter term {} is missing both greater than and less than values. One of them must be defined",
+ &numeric_filter.term
+ );
+ }
+ validated_filter_terms_PP += &string_json;
+ }
  };
+ filter_hits += 1;
  }
  println!("validated_filter_terms_PP:{}", validated_filter_terms_PP);
  if filter_hits > 0 {
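For reference, the Numeric branch added above emits filter fragments with "gt" and/or "lt" keys. The sketch below is a hypothetical illustration of the three possible outputs; the term name and bounds are invented and the exact field types of the numeric filter struct are not shown in this diff.

// Hypothetical sketch of the numeric filter fragments produced by the branch above (values invented).
fn main() {
    let term = "age";
    let greater_than: Option<f64> = Some(30.0);
    let less_than: Option<f64> = Some(80.0);
    let fragment = match (greater_than, less_than) {
        (Some(gt), None) => format!("{{\"term\":\"{}\", \"gt\":\"{}\"}},", term, gt),
        (None, Some(lt)) => format!("{{\"term\":\"{}\", \"lt\":\"{}\"}},", term, lt),
        (Some(gt), Some(lt)) => format!("{{\"term\":\"{}\", \"lt\":\"{}\", \"gt\":\"{}\"}},", term, lt, gt),
        (None, None) => panic!("one of greater than / less than must be defined"),
    };
    // With both bounds set this yields: {"term":"age", "lt":"80", "gt":"30"},
    assert_eq!(fragment, "{\"term\":\"age\", \"lt\":\"80\", \"gt\":\"30\"},");
}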