@sjcrh/proteinpaint-rust 2.149.0 → 2.152.1-0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/src/aichatbot.rs CHANGED
@@ -1,67 +1,63 @@
1
+ // Syntax: cd .. && cargo build --release && time cat ~/sjpp/test.txt | target/release/aichatbot
2
+ #![allow(non_snake_case)]
1
3
  use anyhow::Result;
2
4
  use json::JsonValue;
3
5
  use r2d2_sqlite::SqliteConnectionManager;
4
6
  use rig::agent::AgentBuilder;
5
- use rig::client::CompletionClient;
6
- use rig::client::EmbeddingsClient;
7
7
  use rig::completion::Prompt;
8
8
  use rig::embeddings::builder::EmbeddingsBuilder;
9
- use std::collections::HashMap;
10
- //use rig::providers::ollama;
11
9
  use rig::vector_store::in_memory_store::InMemoryVectorStore;
12
10
  use schemars::JsonSchema;
13
11
  use serde_json::{Map, Value, json};
14
- use std::io::{self};
12
+ use std::collections::HashMap;
13
+ use std::fs;
14
+ use std::io;
15
+ use std::path::Path;
16
+ mod ollama; // Importing custom rig module for invoking ollama server
15
17
  mod sjprovider; // Importing custom rig module for invoking SJ GPU server
16
18
 
17
- #[allow(non_camel_case_types)]
18
- #[derive(Debug, Clone)]
19
- enum llm_backend {
20
- Ollama(),
21
- Sj(),
22
- }
19
+ mod test_ai; // Test examples for AI chatbot
23
20
 
24
- #[derive(Debug, JsonSchema)]
25
- #[allow(dead_code)]
26
- struct OutputJson {
27
- pub answer: String,
21
+ // Struct for intaking data from dataset json
22
+ #[derive(PartialEq, Debug, Clone, schemars::JsonSchema, serde::Serialize, serde::Deserialize)]
23
+ pub struct AiJsonFormat {
24
+ hasGeneExpression: bool,
25
+ db: String, // Dataset db
26
+ genedb: String, // Gene db
27
+ charts: Vec<Charts>,
28
28
  }
29
29
 
30
- #[allow(non_camel_case_types)]
31
- #[derive(Debug, JsonSchema)]
32
- #[allow(dead_code)]
33
- enum cutoff_info {
34
- lesser(f32),
35
- greater(f32),
36
- equalto(f32),
30
+ #[derive(PartialEq, Debug, Clone, schemars::JsonSchema, serde::Serialize, serde::Deserialize)]
31
+ enum Charts {
32
+ // More chart types will be added here later
33
+ Summary(TrainTestData),
34
+ DE(TrainTestData),
37
35
  }
38
36
 
39
- #[derive(Debug, JsonSchema)]
40
- #[allow(dead_code)]
41
- struct Cutoff {
42
- cutoff_name: cutoff_info,
43
- units: Option<String>,
37
+ #[derive(PartialEq, Debug, Clone, schemars::JsonSchema, serde::Serialize, serde::Deserialize)]
38
+ struct TrainTestData {
39
+ SystemPrompt: String,
40
+ TrainingData: Vec<QuestionAnswer>,
41
+ TestData: Vec<QuestionAnswer>,
44
42
  }
45
43
 
46
- #[derive(Debug, JsonSchema)]
47
- #[allow(dead_code)]
48
- struct Filter {
49
- name: String,
50
- cutoff: Cutoff,
44
+ #[derive(PartialEq, Debug, Clone, schemars::JsonSchema, serde::Serialize, serde::Deserialize)]
45
+ struct QuestionAnswer {
46
+ question: String,
47
+ answer: String,
51
48
  }
52
49
 
53
- #[derive(Debug, JsonSchema)]
54
- #[allow(dead_code)]
55
- struct Group {
56
- name: String,
57
- filter: Filter,
50
+ #[allow(non_camel_case_types)]
51
+ #[derive(Debug, Clone)]
52
+ pub enum llm_backend {
53
+ Ollama(),
54
+ Sj(),
58
55
  }
59
56
 
60
57
  #[derive(Debug, JsonSchema)]
61
58
  #[allow(dead_code)]
62
- struct DEOutput {
63
- group1: Group,
64
- group2: Group,
59
+ struct OutputJson {
60
+ pub answer: String,
65
61
  }
66
62
 
67
63
  #[tokio::main]
@@ -73,23 +69,64 @@ async fn main() -> Result<()> {
73
69
  let input_json = json::parse(&input);
74
70
  match input_json {
75
71
  Ok(json_string) => {
72
+ //println!("json_string:{}", json_string);
76
73
  let user_input_json: &JsonValue = &json_string["user_input"];
77
- //let user_input = "Does aspirin leads to decrease in death rates among Africans?";
78
- //let user_input = "Show the point deletion in TP53 gene.";
79
- //let user_input = "Generate DE plot for men with weight greater than 30lbs vs women less than 20lbs";
80
74
  let user_input: &str;
81
75
  match user_input_json.as_str() {
82
76
  Some(inp) => user_input = inp,
83
77
  None => panic!("user_input field is missing in input json"),
84
78
  }
85
79
 
86
- let dataset_db_json: &JsonValue = &json_string["dataset_db"];
87
- let mut dataset_db: Option<&str> = None;
88
- match dataset_db_json.as_str() {
89
- Some(inp) => dataset_db = Some(inp),
90
- None => {}
80
+ if user_input.len() == 0 {
81
+ panic!("The user input is empty");
82
+ }
83
+
84
+ let tpmasterdir_json: &JsonValue = &json_string["tpmasterdir"];
85
+ let tpmasterdir: &str;
86
+ match tpmasterdir_json.as_str() {
87
+ Some(inp) => tpmasterdir = inp,
88
+ None => panic!("tpmasterdir not found"),
89
+ }
90
+
91
+ let binpath_json: &JsonValue = &json_string["binpath"];
92
+ let binpath: &str;
93
+ match binpath_json.as_str() {
94
+ Some(inp) => binpath = inp,
95
+ None => panic!("binpath not found"),
96
+ }
97
+
98
+ let ai_json_file_json: &JsonValue = &json_string["aifiles"];
99
+ let ai_json_file: String;
100
+ match ai_json_file_json.as_str() {
101
+ Some(inp) => ai_json_file = String::from(binpath) + &"/../../" + &inp,
102
+ None => {
103
+ panic!("ai json file not found")
104
+ }
91
105
  }
92
106
 
107
+ let ai_json_file = Path::new(&ai_json_file);
108
+ let ai_json_file_path;
109
+ let current_dir = std::env::current_dir().unwrap();
110
+ match ai_json_file.canonicalize() {
111
+ Ok(p) => ai_json_file_path = p,
112
+ Err(_) => {
113
+ panic!(
114
+ "AI JSON file path not found:{:?}, current directory:{:?}",
115
+ ai_json_file, current_dir
116
+ )
117
+ }
118
+ }
119
+
120
+ // Read the file
121
+ let ai_data = fs::read_to_string(ai_json_file_path).unwrap();
122
+
123
+ // Parse the JSON data
124
+ let ai_json: AiJsonFormat =
125
+ serde_json::from_str(&ai_data).expect("AI JSON file does not have the correct format");
126
+
127
+ let genedb = String::from(tpmasterdir) + &"/" + &ai_json.genedb;
128
+ let dataset_db = String::from(tpmasterdir) + &"/" + &ai_json.db;
129
+
93
130
  let apilink_json: &JsonValue = &json_string["apilink"];
94
131
  let apilink: &str;
95
132
  match apilink_json.as_str() {
@@ -131,7 +168,7 @@ async fn main() -> Result<()> {
131
168
  } else if llm_backend_name == "ollama".to_string() {
132
169
  llm_backend_type = llm_backend::Ollama();
133
170
  // Initialize Ollama client
134
- let ollama_client = rig::providers::ollama::Client::builder()
171
+ let ollama_client = ollama::Client::builder()
135
172
  .base_url(apilink)
136
173
  .build()
137
174
  .expect("Ollama server not found");
@@ -145,10 +182,11 @@ async fn main() -> Result<()> {
145
182
  temperature,
146
183
  max_new_tokens,
147
184
  top_p,
148
- dataset_db,
185
+ &dataset_db,
186
+ &genedb,
187
+ &ai_json,
149
188
  )
150
189
  .await;
151
- // "gpt-oss:20b" "granite3-dense:latest" "PetrosStav/gemma3-tools:12b" "llama3-groq-tool-use:latest" "PetrosStav/gemma3-tools:12b"
152
190
  } else if llm_backend_name == "SJ".to_string() {
153
191
  llm_backend_type = llm_backend::Sj();
154
192
  // Initialize Sj provider client
@@ -166,17 +204,19 @@ async fn main() -> Result<()> {
166
204
  temperature,
167
205
  max_new_tokens,
168
206
  top_p,
169
- dataset_db,
207
+ &dataset_db,
208
+ &genedb,
209
+ &ai_json,
170
210
  )
171
211
  .await;
172
212
  }
173
213
 
174
214
  match final_output {
175
215
  Some(fin_out) => {
176
- println!("final_output:{:?}", fin_out);
216
+ println!("final_output:{:?}", fin_out.replace("\\", ""));
177
217
  }
178
218
  None => {
179
- println!("final_output:{{\"{}\":\"{}\"}}", "chartType", "unknown");
219
+ println!("final_output:{{\"{}\":\"{}\"}}", "action", "unknown");
180
220
  }
181
221
  }
182
222
  }
@@ -188,7 +228,7 @@ async fn main() -> Result<()> {
188
228
  Ok(())
189
229
  }
190
230
 
191
- async fn run_pipeline(
231
+ pub async fn run_pipeline(
192
232
  user_input: &str,
193
233
  comp_model: impl rig::completion::CompletionModel + 'static,
194
234
  embedding_model: impl rig::embeddings::EmbeddingModel + 'static,
@@ -196,7 +236,9 @@ async fn run_pipeline(
196
236
  temperature: f64,
197
237
  max_new_tokens: usize,
198
238
  top_p: f32,
199
- dataset_db: Option<&str>,
239
+ dataset_db: &str,
240
+ genedb: &str,
241
+ ai_json: &AiJsonFormat,
200
242
  ) -> Option<String> {
201
243
  let mut classification: String = classify_query_by_dataset_type(
202
244
  user_input,
@@ -223,7 +265,7 @@ async fn run_pipeline(
223
265
  .await;
224
266
  final_output = format!(
225
267
  "{{\"{}\":\"{}\",\"{}\":[{}}}",
226
- "chartType",
268
+ "action",
227
269
  "dge",
228
270
  "DE_output",
229
271
  de_result + &"]"
@@ -238,32 +280,32 @@ async fn run_pipeline(
238
280
  max_new_tokens,
239
281
  top_p,
240
282
  dataset_db,
283
+ genedb,
284
+ ai_json,
241
285
  )
242
286
  .await;
243
- } else if classification == "hierarchial".to_string() {
287
+ } else if classification == "hierarchical".to_string() {
244
288
  // Not implemented yet
245
- final_output = format!("{{\"{}\":\"{}\"}}", "chartType", "hierarchial");
289
+ final_output = format!("{{\"{}\":\"{}\"}}", "action", "hierarchical");
246
290
  } else if classification == "snv_indel".to_string() {
247
291
  // Not implemented yet
248
- final_output = format!("{{\"{}\":\"{}\"}}", "chartType", "snv_indel");
292
+ final_output = format!("{{\"{}\":\"{}\"}}", "action", "snv_indel");
249
293
  } else if classification == "cnv".to_string() {
250
294
  // Not implemented yet
251
- final_output = format!("{{\"{}\":\"{}\"}}", "chartType", "cnv");
295
+ final_output = format!("{{\"{}\":\"{}\"}}", "action", "cnv");
252
296
  } else if classification == "variant_calling".to_string() {
253
297
  // Not implemented yet and will never be supported. Need a separate messages for this
254
- final_output = format!("{{\"{}\":\"{}\"}}", "chartType", "variant_calling");
255
- } else if classification == "surivial".to_string() {
298
+ final_output = format!("{{\"{}\":\"{}\"}}", "action", "variant_calling");
299
+ } else if classification == "survival".to_string() {
256
300
  // Not implemented yet
257
- final_output = format!("{{\"{}\":\"{}\"}}", "chartType", "surivial");
301
+ final_output = format!("{{\"{}\":\"{}\"}}", "action", "surivial");
258
302
  } else if classification == "none".to_string() {
259
- final_output = format!("{{\"{}\":\"{}\"}}", "chartType", "none");
260
- println!("The input query did not match any known features in Proteinpaint");
261
- } else {
262
303
  final_output = format!(
263
- "{{\"{}\":\"{}\"}}",
264
- "chartType",
265
- "unknown:".to_string() + &classification
304
+ "{{\"{}\":\"{}\",\"{}\":\"{}\"}}",
305
+ "action", "none", "message", "The input query did not match any known features in Proteinpaint"
266
306
  );
307
+ } else {
308
+ final_output = format!("{{\"{}\":\"{}\"}}", "action", "unknown:".to_string() + &classification);
267
309
  }
268
310
  Some(final_output)
269
311
  }
@@ -295,19 +337,35 @@ Structural variants/fusions (SV) are genomic mutations when eith a DNA region is
295
337
  If a ProteinPaint dataset contains structural variation or gene fusion data then return JSON with single key, 'sv_fusion'.
296
338
  ---
297
339
 
298
- Hierarchial clustering of gene expression is an unsupervised learning technique where several number of relevant genes and the samples are clustered so as to determine (previously unknown) cohorts of samples (or patients) or structure in data. It is very commonly used to determine subtypes of a particular disease based on RNA sequencing data.
340
+ Hierarchical clustering of gene expression is an unsupervised learning technique where several number of relevant genes and the samples are clustered so as to determine (previously unknown) cohorts of samples (or patients) or structure in data. It is very commonly used to determine subtypes of a particular disease based on RNA sequencing data.
299
341
 
300
- If a ProteinPaint dataset contains hierarchial data then return JSON with single key, 'hierarchial'.
342
+ If a ProteinPaint dataset contains hierarchical data then return JSON with single key, 'hierarchical'.
301
343
 
302
344
  ---
303
345
 
304
- Differential Gene Expression (DGE or DE) is a technique where the most upregulated and downregulated genes between two cohorts of samples (or patients) are determined. A volcano plot is shown with fold-change in the x-axis and adjusted p-value on the y-axis. So, the upregulated and downregulared genes are on opposite sides of the graph and the most significant genes (based on adjusted p-value) is on the top of the graph. Following differential gene expression generally GeneSet Enrichment Analysis (GSEA) is carried out where based on the genes and their corresponding fold changes the upregulation/downregulation of genesets (or pathways) is determined.
346
+ Differential Gene Expression (DGE or DE) is a technique where the most upregulated (or highest) and downregulated (or lowest) genes between two cohorts of samples (or patients) are determined from a pool of THOUSANDS of genes. Differential gene expression CANNOT be computed for a SINGLE gene. A volcano plot is shown with fold-change in the x-axis and adjusted p-value on the y-axis. So, the upregulated and downregulared genes are on opposite sides of the graph and the most significant genes (based on adjusted p-value) is on the top of the graph. Following differential gene expression generally GeneSet Enrichment Analysis (GSEA) is carried out where based on the genes and their corresponding fold changes the upregulation/downregulation of genesets (or pathways) is determined.
347
+
348
+ Sample Query1: \"Which gene has the highest expression between the two genders\"
349
+ Sample Answer1: { \"answer\": \"dge\" }
350
+
351
+ Sample Query2: \"Which gene has the lowest expression between the two races\"
352
+ Sample Answer2: { \"answer\": \"dge\" }
353
+
354
+ Sample Query1: \"Which genes are the most upregulated genes between group A and group B\"
355
+ Sample Answer1: { \"answer\": \"dge\" }
356
+
357
+ Sample Query3: \"Which gene are overexpressed between male and female\"
358
+ Sample Answer3: { \"answer\": \"dge\" }
359
+
360
+ Sample Query4: \"Which gene are housekeeping genes between male and female\"
361
+ Sample Answer4: { \"answer\": \"dge\" }
362
+
305
363
 
306
364
  If a ProteinPaint dataset contains differential gene expression data then return JSON with single key, 'dge'.
307
365
 
308
366
  ---
309
367
 
310
- Survival analysis (also called time-to-event analysis or duration analysis) is a branch of statistics aimed at analyzing the duration of time from a well-defined time origin until one or more events happen, called survival times or duration times. In other words, in survival analysis, we are interested in a certain event and want to analyze the time until the event happens.
368
+ Survival analysis (also called time-to-event analysis or duration analysis) is a branch of statistics aimed at analyzing the duration of time from a well-defined time origin until one or more events happen, called survival times or duration times. In other words, in survival analysis, we are interested in a certain event and want to analyze the time until the event happens. Generally in survival analysis survival rates between two (or more) cohorts of patients is compared.
311
369
 
312
370
  There are two main methods of survival analysis:
313
371
 
@@ -319,6 +377,10 @@ There are two main methods of survival analysis:
319
377
  HR < 1: Reduction in the hazard
320
378
  HR > 1: Increase in Hazard
321
379
 
380
+ Sample Query1: \"Compare survival rates between group A and B\"
381
+ Sample Answer1: { \"answer\": \"survival\" }
382
+
383
+
322
384
  If a ProteinPaint dataset contains survival data then return JSON with single key, 'survival'.
323
385
 
324
386
  ---
@@ -329,15 +391,20 @@ If a user query asks about variant calling or mapping reads then JSON with singl
329
391
 
330
392
  ---
331
393
 
332
- Summary plot in ProteinPaint shows the various facets of the datasets. It may show all the samples according to their respective diagnosis or subtypes of cancer. It is also useful for visualizing all the different facets of the dataset. You can display a categorical variable and overlay another variable on top it and stratify (or divide) using a third variable simultaneously. You can also custom filters to the dataset so that you can only study part of the dataset. If a user query asks about variant calling or mapping reads then JSON with single key, 'summary'.
394
+ Summary plot in ProteinPaint shows the various facets of the datasets. Show expression of a SINGLE gene or compare the expression of a SINGLE gene across two different cohorts defined by the user. It may show all the samples according to their respective diagnosis or subtypes of cancer. It is also useful for comparing and correlating different clinical variables. It can show all possible distributions, frequency of a category, overlay, correlate or cross-tabulate with another variable on top of it. If a user query asks about a SINGLE gene expression or correlating clinical variables then return JSON with single key, 'summary'.
333
395
 
334
396
  Sample Query1: \"Show all fusions for patients with age less than 30\"
335
397
  Sample Answer1: { \"answer\": \"summary\" }
336
398
 
337
- Sample Query1: \"List all molecular subtypes of leukemia\"
338
- Sample Answer1: { \"answer\": \"summary\" }
399
+ Sample Query2: \"List all molecular subtypes of leukemia\"
400
+ Sample Answer2: { \"answer\": \"summary\" }
401
+
402
+ Sample Query3: \"is tp53 expression higher in men than women ?\"
403
+ Sample Answer3: { \"answer\": \"summary\" }
404
+
405
+ Sample Query4: \"Compare ATM expression between races for women greater than 80yrs\"
406
+ Sample Answer4: { \"answer\": \"summary\" }
339
407
 
340
- ---
341
408
 
342
409
  If a query does not match any of the fields described above, then return JSON with single key, 'none'
343
410
  ");
@@ -345,14 +412,16 @@ If a query does not match any of the fields described above, then return JSON wi
345
412
  // Split the contents by the delimiter "---"
346
413
  let parts: Vec<&str> = contents.split("---").collect();
347
414
  let schema_json: Value = serde_json::to_value(schemars::schema_for!(OutputJson)).unwrap(); // error handling here
415
+ let schema_json_string = serde_json::to_string_pretty(&schema_json).unwrap();
348
416
 
349
417
  let additional;
350
418
  match llm_backend_type {
351
419
  llm_backend::Ollama() => {
352
420
  additional = json!({
353
- "format": schema_json
354
- }
355
- );
421
+ "max_new_tokens": max_new_tokens,
422
+ "top_p": top_p,
423
+ "schema_json": schema_json_string
424
+ });
356
425
  }
357
426
  llm_backend::Sj() => {
358
427
  additional = json!({
@@ -369,7 +438,7 @@ If a query does not match any of the fields described above, then return JSON wi
369
438
  rag_docs.push(part.trim().to_string())
370
439
  }
371
440
 
372
- let top_k: usize = 3;
441
+ //let top_k: usize = 3;
373
442
  // Create embeddings and add to vector store
374
443
  let embeddings = EmbeddingsBuilder::new(embedding_model.clone())
375
444
  .documents(rag_docs)
@@ -383,20 +452,25 @@ If a query does not match any of the fields described above, then return JSON wi
383
452
  InMemoryVectorStore::add_documents(&mut vector_store, embeddings);
384
453
 
385
454
  // Create RAG agent
386
- let agent = AgentBuilder::new(comp_model).preamble("Generate classification for the user query into summary, dge, hierarchial, snv_indel, cnv, variant_calling, sv_fusion and none categories. Return output in JSON with ALWAYS a single word answer { \"answer\": \"dge\" }, that is 'summary' for summary plot, 'dge' for differential gene expression, 'hierarchial' for hierarchial clustering, 'snv_indel' for SNV/Indel, 'cnv' for CNV and 'sv_fusion' for SV/fusion, 'variant_calling' for variant calling, 'surivial' for survival data, 'none' for none of the previously described categories. The summary plot list and summarizes the cohort of patients according to the user query. The answer should always be in lower case").dynamic_context(top_k, vector_store.index(embedding_model)).temperature(temperature).additional_params(additional).build();
455
+ let agent = AgentBuilder::new(comp_model).preamble(&(String::from("Generate classification for the user query into summary, dge, hierarchical, snv_indel, cnv, variant_calling, sv_fusion and none categories. Return output in JSON with ALWAYS a single word answer { \"answer\": \"dge\" }, that is 'summary' for summary plot, 'dge' for differential gene expression, 'hierarchical' for hierarchical clustering, 'snv_indel' for SNV/Indel, 'cnv' for CNV and 'sv_fusion' for SV/fusion, 'variant_calling' for variant calling, 'surivial' for survival data, 'none' for none of the previously described categories. The summary plot list and summarizes the cohort of patients according to the user query. The answer should always be in lower case\n The options are as follows:\n") + &contents + "\nQuestion= {question} \nanswer")).temperature(temperature).additional_params(additional).build();
456
+ //.dynamic_context(top_k, vector_store.index(embedding_model))
387
457
 
388
- let response = agent.prompt(user_input).await.expect("Failed to prompt ollama");
458
+ let response = agent.prompt(user_input).await.expect("Failed to prompt server");
389
459
 
390
460
  //println!("Ollama: {}", response);
391
461
  let result = response.replace("json", "").replace("```", "");
392
462
  let json_value: Value = serde_json::from_str(&result).expect("REASON");
393
463
  match llm_backend_type {
394
- llm_backend::Ollama() => json_value.as_object().unwrap()["answer"].to_string().replace("\"", ""),
464
+ llm_backend::Ollama() => {
465
+ let json_value2: Value = serde_json::from_str(&json_value["content"].to_string()).expect("REASON2");
466
+ let json_value3: Value = serde_json::from_str(&json_value2.as_str().unwrap()).expect("REASON3");
467
+ json_value3["answer"].to_string()
468
+ }
395
469
  llm_backend::Sj() => {
396
470
  let json_value2: Value =
397
471
  serde_json::from_str(&json_value[0]["generated_text"].to_string()).expect("REASON2");
398
472
  //println!("json_value2:{}", json_value2.as_str().unwrap());
399
- let json_value3: Value = serde_json::from_str(&json_value2.as_str().unwrap()).expect("REASON2");
473
+ let json_value3: Value = serde_json::from_str(&json_value2.as_str().unwrap()).expect("REASON3");
400
474
  //let json_value3: Value = serde_json::from_str(&json_value2["answer"].to_string()).expect("REASON2");
401
475
  //println!("Classification result:{}", json_value3["answer"]);
402
476
  json_value3["answer"].to_string()
@@ -404,6 +478,45 @@ If a query does not match any of the fields described above, then return JSON wi
404
478
  }
405
479
  }
406
480
 
481
+ // DE JSON output schema
482
+
483
+ #[allow(non_camel_case_types)]
484
+ #[derive(Debug, JsonSchema)]
485
+ #[allow(dead_code)]
486
+ enum cutoff_info {
487
+ lesser(f32),
488
+ greater(f32),
489
+ equalto(f32),
490
+ }
491
+
492
+ #[derive(Debug, JsonSchema)]
493
+ #[allow(dead_code)]
494
+ struct Cutoff {
495
+ cutoff_name: cutoff_info,
496
+ units: Option<String>,
497
+ }
498
+
499
+ #[derive(Debug, JsonSchema)]
500
+ #[allow(dead_code)]
501
+ struct Filter {
502
+ name: String,
503
+ cutoff: Cutoff,
504
+ }
505
+
506
+ #[derive(Debug, JsonSchema)]
507
+ #[allow(dead_code)]
508
+ struct Group {
509
+ name: String,
510
+ filter: Filter,
511
+ }
512
+
513
+ #[derive(Debug, JsonSchema)]
514
+ #[allow(dead_code)]
515
+ struct DEOutput {
516
+ group1: Group,
517
+ group2: Group,
518
+ }
519
+
407
520
  #[allow(non_snake_case)]
408
521
  async fn extract_DE_search_terms_from_query(
409
522
  user_input: &str,
@@ -440,16 +553,17 @@ Output JSON query5: {\"group1\": {\"name\": \"males\", \"filter\": {\"name\": \"
440
553
  let parts: Vec<&str> = contents.split("---").collect();
441
554
 
442
555
  let schema_json: Value = serde_json::to_value(schemars::schema_for!(DEOutput)).unwrap(); // error handling here
443
-
556
+ let schema_json_string = serde_json::to_string_pretty(&schema_json).unwrap();
444
557
  //println!("DE schema:{}", schema_json);
445
558
 
446
559
  let additional;
447
560
  match llm_backend_type {
448
561
  llm_backend::Ollama() => {
449
562
  additional = json!({
450
- "format": schema_json
451
- }
452
- );
563
+ "max_new_tokens": max_new_tokens,
564
+ "top_p": top_p,
565
+ "schema_json": schema_json_string
566
+ });
453
567
  }
454
568
  llm_backend::Sj() => {
455
569
  additional = json!({
@@ -480,16 +594,21 @@ Output JSON query5: {\"group1\": {\"name\": \"males\", \"filter\": {\"name\": \"
480
594
  InMemoryVectorStore::add_documents(&mut vector_store, embeddings);
481
595
 
482
596
  // Create RAG agent
483
- let router_instructions = "Extract the group variable names for differential gene expression from input query. When two groups are found give the following JSON output with no extra comments. Show {{\"group1\": {\"name\": \"groupA\"}, \"group2\": {\"name\": \"groupB\"}}}. In case no suitable groups are found, show {\"output\":\"No suitable two groups found for differential gene expression\"}. In case of a continuous variable such as age, height added additional field to the group called \"filter\". This should contain a sub-field called \"names\" followed by a subfield called \"cutoff\". This sub-field should contain a key either greater, lesser or equalto. If the continuous variable has units provided by the user then add it in a separate field called \"units\". User query1: \"Show volcano plot for Asians with age less than 20 and African greater than 80\". Output JSON query1: {\"group1\": {\"name\": \"Asians\", \"filter\": {\"name\": \"age\", \"cutoff\": {\"lesser\": 20}}}, \"group2\": {\"name\": \"African\", \"filter\": {\"name\": \"age\", \"cutoff\": {\"greater\": 80}}}}. User query2: \"Show Differential gene expression plot for males with height greater than 185cm and women with less than 100cm\". Output JSON query2: {\"group1\": {\"name\": \"males\", \"filter\": {\"name\": \"height\", \"cutoff\": {\"greater\": 185, \"units\":\"cm\"}}}, \"group2\": {\"name\": \"women\", \"filter\": {\"name\": \"height\", \"cutoff\": {\"lesser\": 100, \"units\": \"cm\"}}}}. User query3: \"Show DE plot between healthy and diseased groups. Output JSON query3: {\"group1\":{\"name\":\"healthy\"},\"group2\":{\"name\":\"diseased\"}}";
597
+ let router_instructions = String::from(
598
+ "Extract the group variable names for differential gene expression from input query. When two groups are found give the following JSON output with no extra comments. Show {{\"group1\": {\"name\": \"groupA\"}, \"group2\": {\"name\": \"groupB\"}}}. In case no suitable groups are found, show {\"output\":\"No suitable two groups found for differential gene expression\"}. In case of a continuous variable such as age, height added additional field to the group called \"filter\". This should contain a sub-field called \"names\" followed by a subfield called \"cutoff\". This sub-field should contain a key either greater, lesser or equalto. If the continuous variable has units provided by the user then add it in a separate field called \"units\".",
599
+ ) + &contents
600
+ + " The JSON schema is as follows"
601
+ + &schema_json_string
602
+ + "\n Examples: User query1: \"Show volcano plot for Asians with age less than 20 and African greater than 80\". Output JSON query1: {\"group1\": {\"name\": \"Asians\", \"filter\": {\"name\": \"age\", \"cutoff\": {\"lesser\": 20}}}, \"group2\": {\"name\": \"African\", \"filter\": {\"name\": \"age\", \"cutoff\": {\"greater\": 80}}}}. User query2: \"Show Differential gene expression plot for males with height greater than 185cm and women with less than 100cm\". Output JSON query2: {\"group1\": {\"name\": \"males\", \"filter\": {\"name\": \"height\", \"cutoff\": {\"greater\": 185, \"units\":\"cm\"}}}, \"group2\": {\"name\": \"women\", \"filter\": {\"name\": \"height\", \"cutoff\": {\"lesser\": 100, \"units\": \"cm\"}}}}. User query3: \"Show DE plot between healthy and diseased groups. Output JSON query3: {\"group1\":{\"name\":\"healthy\"},\"group2\":{\"name\":\"diseased\"}} \nQuestion= {question} \nanswer";
484
603
  //println! {"router_instructions:{}",router_instructions};
485
604
  let agent = AgentBuilder::new(comp_model)
486
- .preamble(router_instructions)
605
+ .preamble(&router_instructions)
487
606
  .dynamic_context(rag_docs_length, vector_store.index(embedding_model))
488
607
  .temperature(temperature)
489
608
  .additional_params(additional)
490
609
  .build();
491
610
 
492
- let response = agent.prompt(user_input).await.expect("Failed to prompt ollama");
611
+ let response = agent.prompt(user_input).await.expect("Failed to prompt server");
493
612
 
494
613
  //println!("Ollama_groups: {}", response);
495
614
  let result = response.replace("json", "").replace("```", "");
@@ -497,7 +616,12 @@ Output JSON query5: {\"group1\": {\"name\": \"males\", \"filter\": {\"name\": \"
497
616
  let json_value: Value = serde_json::from_str(&result).expect("REASON");
498
617
  //println!("json_value:{}", json_value);
499
618
  match llm_backend_type {
500
- llm_backend::Ollama() => json_value.to_string(),
619
+ llm_backend::Ollama() => {
620
+ let json_value2: Value = serde_json::from_str(&json_value["content"].to_string()).expect("REASON2");
621
+ //println!("json_value2:{:?}", json_value2);
622
+ let json_value3: Value = serde_json::from_str(&json_value2.as_str().unwrap()).expect("REASON3");
623
+ json_value3.to_string()
624
+ }
501
625
  llm_backend::Sj() => {
502
626
  let json_value2: Value =
503
627
  serde_json::from_str(&json_value[0]["generated_text"].to_string()).expect("REASON2");
@@ -509,6 +633,7 @@ Output JSON query5: {\"group1\": {\"name\": \"males\", \"filter\": {\"name\": \"
509
633
  }
510
634
  }
511
635
 
636
+ #[derive(Debug, Clone)]
512
637
  struct DbRows {
513
638
  name: String,
514
639
  description: Option<String>,
@@ -516,6 +641,21 @@ struct DbRows {
516
641
  values: Vec<String>,
517
642
  }
518
643
 
644
+ async fn parse_geneset_db(db: &str) -> Vec<String> {
645
+ let manager = SqliteConnectionManager::file(db);
646
+ let pool = r2d2::Pool::new(manager).unwrap();
647
+ let conn = pool.get().unwrap();
648
+ let sql_statement_genedb = "SELECT * from codingGenes";
649
+ let mut genedb = conn.prepare(&sql_statement_genedb).unwrap();
650
+ let mut rows_genedb = genedb.query([]).unwrap();
651
+ let mut gene_list = Vec::<String>::new();
652
+ while let Some(coding_gene) = rows_genedb.next().unwrap() {
653
+ let code_gene: String = coding_gene.get(0).unwrap();
654
+ gene_list.push(code_gene)
655
+ }
656
+ gene_list
657
+ }
658
+
519
659
  trait ParseDbRows {
520
660
  fn parse_db_rows(&self) -> String;
521
661
  }
@@ -544,7 +684,7 @@ impl ParseDbRows for DbRows {
544
684
  }
545
685
  }
546
686
 
547
- async fn parse_dataset_db(db: &str) -> Vec<String> {
687
+ async fn parse_dataset_db(db: &str) -> (Vec<String>, Vec<DbRows>) {
548
688
  let manager = SqliteConnectionManager::file(db);
549
689
  let pool = r2d2::Pool::new(manager).unwrap();
550
690
  let conn = pool.get().unwrap();
@@ -574,7 +714,7 @@ async fn parse_dataset_db(db: &str) -> Vec<String> {
574
714
  }
575
715
 
576
716
  //// Open the file
577
- //let mut file = File::open(dataset_agnostic_file).unwrap();
717
+ //let mut file = File::open(dataset_file).unwrap();
578
718
 
579
719
  //// Create a string to hold the file contents
580
720
  //let mut contents = String::new();
@@ -603,6 +743,7 @@ async fn parse_dataset_db(db: &str) -> Vec<String> {
603
743
  // Print the separated parts
604
744
  let mut rag_docs = Vec::<String>::new();
605
745
  let mut names = Vec::<String>::new();
746
+ let mut db_vec = Vec::<DbRows>::new();
606
747
  while let Some(row) = rows_terms.next().unwrap() {
607
748
  //println!("row:{:?}", row);
608
749
  let name: String = row.get(0).unwrap();
@@ -637,6 +778,7 @@ async fn parse_dataset_db(db: &str) -> Vec<String> {
637
778
  term_type: item_type,
638
779
  values: keys,
639
780
  };
781
+ db_vec.push(item.clone());
640
782
  //println!("Field details:{}", item.parse_db_rows());
641
783
  rag_docs.push(item.parse_db_rows());
642
784
  names.push(name)
@@ -645,60 +787,109 @@ async fn parse_dataset_db(db: &str) -> Vec<String> {
645
787
  }
646
788
  }
647
789
  //println!("names:{:?}", names);
648
- rag_docs
790
+ (rag_docs, db_vec)
649
791
  }
650
792
 
651
793
  async fn extract_summary_information(
652
794
  user_input: &str,
653
795
  comp_model: impl rig::completion::CompletionModel + 'static,
654
- embedding_model: impl rig::embeddings::EmbeddingModel + 'static,
796
+ _embedding_model: impl rig::embeddings::EmbeddingModel + 'static,
655
797
  llm_backend_type: &llm_backend,
656
798
  temperature: f64,
657
799
  max_new_tokens: usize,
658
800
  top_p: f32,
659
- dataset_db: Option<&str>,
801
+ dataset_db: &str,
802
+ genedb: &str,
803
+ ai_json: &AiJsonFormat,
660
804
  ) -> String {
661
- match dataset_db {
662
- Some(db) => {
663
- let rag_docs = parse_dataset_db(db).await;
664
- //println!("rag_docs:{:?}", rag_docs);
665
- let additional;
666
- match llm_backend_type {
667
- llm_backend::Ollama() => {
668
- additional = json!({});
669
- }
670
- llm_backend::Sj() => {
671
- additional = json!({
672
- "max_new_tokens": max_new_tokens,
673
- "top_p": top_p
674
- });
675
- }
676
- }
677
-
678
- let rag_docs_length = rag_docs.len();
679
- // Create embeddings and add to vector store
680
- let embeddings = EmbeddingsBuilder::new(embedding_model.clone())
681
- .documents(rag_docs)
682
- .expect("Reason1")
683
- .build()
684
- .await
685
- .unwrap();
805
+ let (rag_docs, db_vec) = parse_dataset_db(dataset_db).await;
806
+ let additional;
807
+ let schema_json = schemars::schema_for!(SummaryType); // error handling here
808
+ let schema_json_string = serde_json::to_string_pretty(&schema_json).unwrap();
809
+ //println!("schema_json summary:{}", schema_json_string);
810
+ match llm_backend_type {
811
+ llm_backend::Ollama() => {
812
+ additional = json!({
813
+ "max_new_tokens": max_new_tokens,
814
+ "top_p": top_p,
815
+ "schema_json": schema_json_string
816
+ });
817
+ }
818
+ llm_backend::Sj() => {
819
+ additional = json!({
820
+ "max_new_tokens": max_new_tokens,
821
+ "top_p": top_p
822
+ });
823
+ }
824
+ }
686
825
 
687
- // Create vector store
688
- let mut vector_store = InMemoryVectorStore::<String>::default();
689
- InMemoryVectorStore::add_documents(&mut vector_store, embeddings);
826
+ // Create embeddings and add to vector store
827
+ //let embeddings = EmbeddingsBuilder::new(embedding_model.clone())
828
+ // .documents(rag_docs)
829
+ // .expect("Reason1")
830
+ // .build()
831
+ // .await
832
+ // .unwrap();
833
+
834
+ //// Create vector store
835
+ //let mut vector_store = InMemoryVectorStore::<String>::default();
836
+ //InMemoryVectorStore::add_documents(&mut vector_store, embeddings);
837
+
838
+ let gene_list: Vec<String> = parse_geneset_db(genedb).await;
839
+ let lowercase_user_input = user_input.to_lowercase();
840
+ let user_words: Vec<&str> = lowercase_user_input.split_whitespace().collect();
841
+ let user_words2: Vec<String> = user_words.into_iter().map(|s| s.to_string()).collect();
842
+
843
+ let common_genes: Vec<String> = gene_list
844
+ .into_iter()
845
+ .filter(|x| user_words2.contains(&x.to_lowercase()))
846
+ .collect();
847
+
848
+ let mut summary_data_check: Option<TrainTestData> = None;
849
+ for chart in ai_json.charts.clone() {
850
+ if let Charts::Summary(traindata) = chart {
851
+ summary_data_check = Some(traindata);
852
+ break;
853
+ }
854
+ }
690
855
 
691
- //let system_prompt = "I am an assistant that figures out the summary term from its respective dataset file. Extract the summary term {summary_term} from user query. The final output must be in the following JSON format {{\"chartType\":\"summary\",\"term\":{{\"id\":\"{{summary_term}}\"}}}}";
856
+ match summary_data_check {
857
+ Some(summary_data) => {
858
+ let mut training_data: String = String::from("");
859
+ let mut train_iter = 0;
860
+ for ques_ans in summary_data.TrainingData {
861
+ train_iter += 1;
862
+ training_data += "Example question";
863
+ training_data += &train_iter.to_string();
864
+ training_data += &":";
865
+ training_data += &ques_ans.question;
866
+ training_data += &" ";
867
+ training_data += "Example answer";
868
+ training_data += &train_iter.to_string();
869
+ training_data += &":";
870
+ training_data += &ques_ans.answer;
871
+ training_data += &"\n";
872
+ }
692
873
 
693
- let top_k = rag_docs_length;
694
- let system_prompt = String::from(
695
- "I am an assistant that extracts the summary term from user query. It has four fields: group_categories (required), overlay (optional), filter (optional) and divide_by (optional). group_categories (required) is the primary variable being displayed. Overlay consists of the variable that must be overlayed on top of group_categories. divide_by is the variable used to stratify group_categories into two or more categories. The final output must be in the following JSON format with no extra comments: {\"chartType\":\"summary\",\"term\":{\"group_categories\":\"{group_category_answer}\",\"overlay\":\"{overlay_answer}\",\"divide_by\":\"{divide_by_answer}\",\"filter\":\"{filter_answer}\"}}. The values being added to the JSON parameters must be previously defined as field in the database. If the filter variable is a \"value\" of a \"field\" in the database, use the field name and add the value as a \"filter cutoff\" . If the \"filter\" field is defined in the user query, it should contain an array with each item containing a subfield called \"name\" with the name of the filter variable. If the type of variable is \"categories\", add another field as \"variable_type\" = \"categories\". In case the type of the variable is \"categories\", show the sub-category as a separate sub-field \"cutoff\" with a sub nested JSON with \"name\" as the field containing the subcategory name. In case the type of the variable is \"float\" it should contain a subfield called \"name\" followed by subfield \"variable_type\" = \"float\". In the \"cutoff\" subfield, the nested JSON should contain the field \"lower\" containing the lower numeric limit and the \"upper\" field containing the upper numeric limit. If the upper and lower cutoffs are not defined in the user query, use a default value of 0 and 100 respectively. Sample query1: \"Show ETR1 subtype\" Answer query1: \"{\"chartType\":\"summary\",\"term\":{\"group_categories\":\"ETR1\"}}. 
Sample query2: \"Show hyperdiploid subtype with age overlayed on top of it\" Answer query2: \"{\"chartType\":\"summary\",\"term\":{\"group_categories\":\"hyperdiploid\", \"overlay\":\"age\"}}. Sample query3: \"Show BAR1 subtype with age overlayed on top of it and stratify it on the basis of gender\" Answer query4: \"{\"chartType\":\"summary\",\"term\":{\"group_categories\":\"BAR1\", \"overlay\":\"age\", \"divide_by\":\"sex\"}}. Sample query5: \"Show summary for cancer-diagnosis only for men\". Since gender is a categorical variable and the user wants to select for men, the answer query for sample query5 is as follows: \"{\"chartType\":\"summary\",\"term\":{\"group_categories\":\"cancer-diagnosis\", \"filter\": {\"name\": \"gender\", \"variable_type\": \"categories\", \"cutoff\": {\"name\": \"male\"}}}}. Sample query6: \"Show molecular subtype summary for patients with age less than 30\". Age is a float variable so we need to provide the lower and higher cutoffs. So the answer to sample query6 is as follows: \"{\"chartType\":\"summary\",\"term\":{\"group_categories\":\"Molecular subtype\", \"filter\": {\"name\": \"age\", \"variable_type\": \"float\", \"cutoff\": {\"lower\": 0, \"higher\": 30}}}} ",
874
+ let system_prompt: String = String::from(
875
+ String::from(
876
+ "I am an assistant that extracts the summary terms from user query. The final output must be in the following JSON format with NO extra comments. There are three fields in the JSON to be returned: The \"action\" field will ALWAYS be \"summary\". The \"summaryterms\" field should contain all the variables that the user wants to visualize. The \"clinical\" subfield should ONLY contain names of the fields from the sqlite db. ",
877
+ ) + &summary_data.SystemPrompt
878
+ + &" The \"filter\" field is optional and should contain an array of JSON terms with which the dataset will be filtered. A variable simultaneously CANNOT be part of both \"summaryterms\" and \"filter\". There are two kinds of filter variables: \"Categorical\" and \"Numeric\". \"Categorical\" variables are those variables which can have a fixed set of values e.g. gender, molecular subtypes. They are defined by the \"CategoricalFilterTerm\" which consists of \"term\" (a field from the sqlite3 db) and \"value\" (a value of the field from the sqlite db). \"Numeric\" variables are those which can have any numeric value. They are defined by \"NumericFilterTerm\" and contain the subfields \"term\" (a field from the sqlite3 db), \"greaterThan\" an optional filter which is defined when a lower cutoff is defined in the user input for the numeric variable and \"lessThan\" an optional filter which is defined when a higher cutoff is defined in the user input for the numeric variable. The \"message\" field only contain messages of terms in the user input that were not found in their respective databases. The JSON schema is as follows:"
879
+ + &schema_json_string
880
+ + &training_data
881
+ + "The sqlite db in plain language is as follows:\n"
882
+ + &rag_docs.join(",")
883
+ + &"\n Relevant genes are as follows (separated by comma(,)):"
884
+ + &common_genes.join(",")
885
+ + &"\nQuestion: {question} \nanswer:",
696
886
  );
887
+
697
888
  //println!("system_prompt:{}", system_prompt);
698
889
  // Create RAG agent
699
890
  let agent = AgentBuilder::new(comp_model)
700
891
  .preamble(&system_prompt)
701
- .dynamic_context(top_k, vector_store.index(embedding_model))
892
+ //.dynamic_context(top_k, vector_store.index(embedding_model))
702
893
  .temperature(temperature)
703
894
  .additional_params(additional)
704
895
  .build();
@@ -711,20 +902,463 @@ async fn extract_summary_information(
711
902
  let json_value: Value = serde_json::from_str(&result).expect("REASON");
712
903
  //println!("Classification result:{}", json_value);
713
904
 
905
+ let final_llm_json;
714
906
  match llm_backend_type {
715
- llm_backend::Ollama() => json_value.to_string(),
907
+ llm_backend::Ollama() => {
908
+ let json_value2: Value = serde_json::from_str(&json_value["content"].to_string()).expect("REASON2");
909
+ let json_value3: Value = serde_json::from_str(&json_value2.as_str().unwrap()).expect("REASON3");
910
+ final_llm_json = json_value3.to_string()
911
+ }
716
912
  llm_backend::Sj() => {
717
913
  let json_value2: Value =
718
914
  serde_json::from_str(&json_value[0]["generated_text"].to_string()).expect("REASON2");
719
915
  //println!("json_value2:{}", json_value2.as_str().unwrap());
720
- let json_value3: Value = serde_json::from_str(&json_value2.as_str().unwrap()).expect("REASON2");
916
+ let json_value3: Value = serde_json::from_str(&json_value2.as_str().unwrap()).expect("REASON3");
721
917
  //println!("Classification result:{}", json_value3);
722
- json_value3.to_string()
918
+ final_llm_json = json_value3.to_string()
723
919
  }
724
920
  }
921
+ //println!("final_llm_json:{}", final_llm_json);
922
+ let final_validated_json = validate_summary_output(final_llm_json.clone(), db_vec, common_genes, ai_json);
923
+ final_validated_json
725
924
  }
726
925
  None => {
727
- panic!("Dataset db file needed for summary term extraction from user input")
926
+ panic!("summary chart train and test data is not defined in dataset JSON file")
927
+ }
928
+ }
929
+ }
930
+
931
+ fn get_summary_string() -> String {
932
+ "summary".to_string()
933
+ }
934
+
935
+ //const action: &str = &"summary";
936
+ //const geneExpression: &str = &"geneExpression";
937
+
938
/// Validated shape of the LLM's summary answer: `action` is expected to be
/// "summary" (serde falls back to `get_summary_string` when the key is
/// absent), `summaryterms` lists the variables to visualize, `filter`
/// optionally restricts the cohort, and `message` carries any note from the
/// LLM about terms it could not match.
#[derive(PartialEq, Debug, Clone, schemars::JsonSchema, serde::Serialize, serde::Deserialize)]
struct SummaryType {
    // Serde uses the default for deserialization; schemars uses the rename
    // for schema generation.
    #[serde(default = "get_summary_string")]
    #[schemars(rename = "action")]
    action: String,
    summaryterms: Vec<SummaryTerms>,
    filter: Option<Vec<FilterTerm>>,
    message: Option<String>,
}
949
+
950
+ impl SummaryType {
951
+ #[allow(dead_code)]
952
+ pub fn sort_summarytype_struct(&mut self) {
953
+ // This function is necessary for testing (test_ai.rs) to see if two variables of type "SummaryType" are equal or not. Without this a vector of two Summarytype holding the same values but in different order will be classified separately.
954
+ self.summaryterms.sort();
955
+
956
+ match self.filter.clone() {
957
+ Some(ref mut filterterms) => filterterms.sort(),
958
+ None => {}
959
+ }
960
+ }
961
+ }
962
+
963
/// One variable requested for the summary view: either a clinical field
/// from the dataset db or a gene whose expression is to be shown.
#[derive(PartialEq, Eq, Ord, Debug, Clone, schemars::JsonSchema, serde::Serialize, serde::Deserialize)]
enum SummaryTerms {
    // Lower-case / camelCase variant names mirror the JSON keys exchanged
    // with the LLM, hence the lint allowances.
    #[allow(non_camel_case_types)]
    clinical(String),
    #[allow(non_camel_case_types)]
    geneExpression(String),
}
970
+
971
+ impl PartialOrd for SummaryTerms {
972
+ fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
973
+ match (self, other) {
974
+ (SummaryTerms::clinical(_), SummaryTerms::clinical(_)) => Some(std::cmp::Ordering::Equal),
975
+ (SummaryTerms::geneExpression(_), SummaryTerms::geneExpression(_)) => Some(std::cmp::Ordering::Equal),
976
+ (SummaryTerms::clinical(_), SummaryTerms::geneExpression(_)) => Some(std::cmp::Ordering::Greater),
977
+ (SummaryTerms::geneExpression(_), SummaryTerms::clinical(_)) => Some(std::cmp::Ordering::Greater),
978
+ }
979
+ }
980
+ }
981
+
982
/// One cohort filter produced by the LLM: either a categorical field/value
/// pair or a numeric field with optional lower/upper cutoffs.
#[derive(PartialEq, Eq, Ord, Debug, Clone, schemars::JsonSchema, serde::Serialize, serde::Deserialize)]
enum FilterTerm {
    Categorical(CategoricalFilterTerm),
    Numeric(NumericFilterTerm),
}
987
+
988
+ impl PartialOrd for FilterTerm {
989
+ fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
990
+ match (self, other) {
991
+ (FilterTerm::Categorical(_), FilterTerm::Categorical(_)) => Some(std::cmp::Ordering::Equal),
992
+ (FilterTerm::Numeric(_), FilterTerm::Numeric(_)) => Some(std::cmp::Ordering::Equal),
993
+ (FilterTerm::Categorical(_), FilterTerm::Numeric(_)) => Some(std::cmp::Ordering::Greater),
994
+ (FilterTerm::Numeric(_), FilterTerm::Categorical(_)) => Some(std::cmp::Ordering::Greater),
995
+ }
996
+ }
997
+ }
998
+
999
/// Categorical filter: `term` must be a field of the dataset db and `value`
/// one of that field's values (both are checked in `validate_summary_output`).
#[derive(PartialEq, Eq, PartialOrd, Ord, Debug, Clone, schemars::JsonSchema, serde::Serialize, serde::Deserialize)]
struct CategoricalFilterTerm {
    term: String,  // field name from the dataset db
    value: String, // one of that field's values
}
1004
+
1005
/// Numeric filter: `term` names a db field; `greaterThan`/`lessThan` are
/// optional lower/upper cutoffs extracted from the user query.
#[derive(Debug, Clone, schemars::JsonSchema, serde::Serialize, serde::Deserialize)]
#[allow(non_snake_case)] // field names mirror the JSON keys exchanged with the LLM
struct NumericFilterTerm {
    term: String,
    greaterThan: Option<f32>,
    lessThan: Option<f32>,
}
1012
+
1013
+ impl PartialEq for NumericFilterTerm {
1014
+ fn eq(&self, other: &Self) -> bool {
1015
+ let greater_equality: bool;
1016
+ match (self.greaterThan, other.greaterThan) {
1017
+ (Some(a), Some(b)) => greater_equality = (a - b).abs() < 1e-6,
1018
+ (None, None) => greater_equality = true,
1019
+ _ => greater_equality = false,
1020
+ }
1021
+
1022
+ let less_equality: bool;
1023
+ match (self.lessThan, other.lessThan) {
1024
+ (Some(a), Some(b)) => less_equality = (a - b).abs() < 1e-6,
1025
+ (None, None) => less_equality = true,
1026
+ _ => less_equality = false,
1027
+ }
1028
+
1029
+ if greater_equality == true && less_equality == true {
1030
+ true
1031
+ } else {
1032
+ false
1033
+ }
1034
+ }
1035
+ }
1036
+
1037
// NOTE(review): `Eq` promises a full equivalence relation, but the
// epsilon-based `PartialEq` above is not transitive ((a-b).abs() < 1e-6 and
// (b-c).abs() < 1e-6 do not imply (a-c).abs() < 1e-6) — confirm this marker
// impl is acceptable for how these values are sorted/compared.
impl Eq for NumericFilterTerm {}
1038
+
1039
+ impl PartialOrd for NumericFilterTerm {
1040
+ fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
1041
+ if self.greaterThan < other.greaterThan {
1042
+ Some(std::cmp::Ordering::Less)
1043
+ } else if self.greaterThan > other.greaterThan {
1044
+ Some(std::cmp::Ordering::Greater)
1045
+ } else if self.lessThan < other.lessThan {
1046
+ Some(std::cmp::Ordering::Less)
1047
+ } else if self.lessThan > other.lessThan {
1048
+ Some(std::cmp::Ordering::Greater)
1049
+ } else {
1050
+ Some(std::cmp::Ordering::Equal)
1051
+ }
1052
+ }
1053
+ }
1054
+
1055
impl Ord for NumericFilterTerm {
    // Total order required by the `sort()` calls on filter vectors.
    // `partial_cmp` for this type returns `Some` on every branch, so the
    // unwrap cannot panic.
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        self.partial_cmp(other).unwrap()
    }
}
1060
+
1061
/// Validates the raw summary JSON returned by the LLM against the dataset db
/// (`db_vec`) and the genes found in the user's words (`common_genes`),
/// keeping only terms and filters that actually exist, and returns the
/// cleaned JSON as a string. Unresolvable items are reported through a
/// "message" field rather than causing an error.
///
/// Panics (expect) if `raw_llm_json` does not deserialize into `SummaryType`.
fn validate_summary_output(
    raw_llm_json: String,
    db_vec: Vec<DbRows>,
    common_genes: Vec<String>,
    ai_json: &AiJsonFormat,
) -> String {
    let json_value: SummaryType =
        serde_json::from_str(&raw_llm_json).expect("Did not get a valid JSON of type {action: summary, summaryterms:[{clinical: term1}, {geneExpression: gene}], filter:[{term: term1, value: value1}]} from the LLM");
    let mut message: String = String::from("");
    match json_value.message {
        Some(mes) => {
            message = message + &mes; // Append any message given by the LLM
        }
        None => {}
    }

    let mut new_json: Value; // New JSON value that will contain items of the final validated JSON
    if json_value.action != String::from("summary") {
        // NOTE(review): `new_json` stays JSON null here, so every
        // `as_object_mut()` below returns None — even this error message is
        // silently dropped and the function returns the string "null".
        message = message + &"Did not return a summary action";
        new_json = serde_json::json!(null);
    } else {
        new_json = serde_json::from_str(&"{\"action\":\"summary\"}").expect("Not a valid JSON");
    }

    // Pass 1: keep only summary terms that resolve against the db / gene
    // list; complaints accumulate in `message`.
    let mut validated_summary_terms = Vec::<SummaryTerms>::new();
    let mut summary_terms_tobe_removed = Vec::<SummaryTerms>::new();
    for sum_term in &json_value.summaryterms {
        match sum_term {
            SummaryTerms::clinical(clin) => {
                let term_verification = verify_json_field(clin, &db_vec);
                // NOTE(review): `Some(x.clone()).is_some()` is always true, so
                // this condition reduces to `correct_value.is_none()`; the
                // inner match still routes the not-found case correctly, but
                // the intent was probably
                // `term_verification.correct_field.is_some()`.
                if Some(term_verification.correct_field.clone()).is_some()
                    && term_verification.correct_value.clone().is_none()
                {
                    match term_verification.correct_field {
                        Some(tm) => validated_summary_terms.push(SummaryTerms::clinical(tm)),
                        None => {
                            message = message + &"\"" + &clin + &"\"" + &" not found in db.";
                        }
                    }
                // NOTE(review): both operands below are likewise always true;
                // this branch effectively runs whenever `correct_value` is
                // Some, i.e. the "field" actually matched a *value* of a field.
                } else if Some(term_verification.correct_field.clone()).is_some()
                    && Some(term_verification.correct_value.clone()).is_some()
                {
                    // NOTE(review): missing space after the value — renders as
                    // "<value>is a value of <field>."
                    message = message
                        + &term_verification.correct_value.unwrap()
                        + &"is a value of "
                        + &term_verification.correct_field.unwrap()
                        + &".";
                }
            }
            SummaryTerms::geneExpression(gene) => {
                match ai_json.hasGeneExpression {
                    true => {
                        // Keep the gene only if it appears verbatim among the
                        // genes matched from the user's words.
                        let mut num_gene_verification = 0;
                        for common_gene in &common_genes {
                            // Comparing predicted gene against the common gene
                            if common_gene == gene {
                                num_gene_verification += 1;
                                validated_summary_terms.push(SummaryTerms::geneExpression(String::from(gene)));
                            }
                        }

                        if num_gene_verification == 0 || common_genes.len() == 0 {
                            if message.to_lowercase().contains(&gene.to_lowercase()) { // Check if the LLM has already added the message, if not then add it
                            } else {
                                message = message + &"\"" + &gene + &"\"" + &" not found in genedb.";
                            }
                        }
                    }
                    false => {
                        let missing_gene_data: &str = "gene expression is not supported for this dataset";
                        if message.to_lowercase().contains(&missing_gene_data.to_lowercase()) { // Check if the LLM has already added the message, if not then add it
                        } else {
                            message = message + &"Gene expression not supported for this dataset";
                        }
                    }
                }
            }
        }
    }

    // Pass 2: validate filter terms, then mark any summary term that also
    // appears as a filter term for removal (a variable cannot be in both).
    match &json_value.filter {
        Some(filter_terms_array) => {
            let mut validated_filter_terms = Vec::<FilterTerm>::new();
            for parsed_filter_term in filter_terms_array {
                match parsed_filter_term {
                    FilterTerm::Categorical(categorical) => {
                        let term_verification = verify_json_field(&categorical.term, &db_vec);
                        // Confirm the filter value exists under that field.
                        let mut value_verification: Option<String> = None;
                        for item in &db_vec {
                            if &item.name == &categorical.term {
                                for val in &item.values {
                                    if &categorical.value == val {
                                        value_verification = Some(val.clone());
                                        break;
                                    }
                                }
                            }
                            if value_verification != None {
                                break;
                            }
                        }
                        if term_verification.correct_field.is_some() && value_verification.is_some() {
                            let verified_filter = CategoricalFilterTerm {
                                term: term_verification.correct_field.clone().unwrap(),
                                value: value_verification.clone().unwrap(),
                            };
                            let categorical_filter_term: FilterTerm = FilterTerm::Categorical(verified_filter);
                            validated_filter_terms.push(categorical_filter_term);
                        }
                        if term_verification.correct_field.is_none() {
                            message = message + &"\"" + &categorical.term + &"\" filter term not found in db";
                        }
                        if value_verification.is_none() {
                            message = message
                                + &"\""
                                + &categorical.value
                                + &"\" filter value not found for filter field \""
                                + &categorical.term
                                + "\" in db";
                        }
                    }
                    FilterTerm::Numeric(numeric) => {
                        // Numeric cutoffs are not range-checked; only the
                        // field name is validated against the db.
                        let term_verification = verify_json_field(&numeric.term, &db_vec);
                        if term_verification.correct_field.is_none() {
                            message = message + &"\"" + &numeric.term + &"\" filter term not found in db";
                        } else {
                            let numeric_filter_term: FilterTerm = FilterTerm::Numeric(numeric.clone());
                            validated_filter_terms.push(numeric_filter_term);
                        }
                    }
                }
            }

            // Flag summary terms whose name also appears as a filter term.
            for summary_term in &validated_summary_terms {
                match summary_term {
                    SummaryTerms::clinical(clinicial_term) => {
                        for filter_term in &validated_filter_terms {
                            match filter_term {
                                FilterTerm::Categorical(categorical) => {
                                    if &categorical.term == clinicial_term {
                                        summary_terms_tobe_removed.push(summary_term.clone());
                                    }
                                }
                                FilterTerm::Numeric(numeric) => {
                                    if &numeric.term == clinicial_term {
                                        summary_terms_tobe_removed.push(summary_term.clone());
                                    }
                                }
                            }
                        }
                    }
                    SummaryTerms::geneExpression(gene) => {
                        for filter_term in &validated_filter_terms {
                            match filter_term {
                                FilterTerm::Categorical(categorical) => {
                                    if &categorical.term == gene {
                                        summary_terms_tobe_removed.push(summary_term.clone());
                                    }
                                }
                                FilterTerm::Numeric(numeric) => {
                                    if &numeric.term == gene {
                                        summary_terms_tobe_removed.push(summary_term.clone());
                                    }
                                }
                            }
                        }
                    }
                }
            }

            if validated_filter_terms.len() > 0 {
                if let Some(obj) = new_json.as_object_mut() {
                    obj.insert(String::from("filter"), serde_json::json!(validated_filter_terms));
                }
            }
        }
        None => {}
    }

    // Removing terms that are found both in filter term as well summary
    let mut validated_summary_terms_final = Vec::<SummaryTerms>::new();

    for summary_term in &validated_summary_terms {
        let mut hit = 0;
        match summary_term {
            SummaryTerms::clinical(clinical_term) => {
                for summary_term2 in &summary_terms_tobe_removed {
                    match summary_term2 {
                        SummaryTerms::clinical(clinical_term2) => {
                            if clinical_term == clinical_term2 {
                                hit = 1;
                            }
                        }
                        SummaryTerms::geneExpression(gene2) => {
                            if clinical_term == gene2 {
                                hit = 1;
                            }
                        }
                    }
                }
            }
            SummaryTerms::geneExpression(gene) => {
                for summary_term2 in &summary_terms_tobe_removed {
                    match summary_term2 {
                        SummaryTerms::clinical(clinical_term2) => {
                            if gene == clinical_term2 {
                                hit = 1;
                            }
                        }
                        SummaryTerms::geneExpression(gene2) => {
                            if gene == gene2 {
                                hit = 1;
                            }
                        }
                    }
                }
            }
        }
        if hit == 0 {
            validated_summary_terms_final.push(summary_term.clone())
        }
    }

    if let Some(obj) = new_json.as_object_mut() {
        obj.insert(
            String::from("summaryterms"),
            serde_json::json!(validated_summary_terms_final),
        );
    }

    if message.len() > 0 {
        if let Some(obj) = new_json.as_object_mut() {
            // The `if let` ensures we only proceed if the top-level JSON is an object.
            // Append a new string field.
            obj.insert(String::from("message"), serde_json::json!(message));
        }
    }
    serde_json::to_string(&new_json).unwrap()
}
1300
+
1301
/// Result of checking an LLM-suggested field name against the dataset db.
#[derive(Debug, Clone)]
struct VerifiedField {
    correct_field: Option<String>, // Name of the matching db field, if any
    correct_value: Option<String>, // Set when the suggested "field" was actually a value of `correct_field`
    _probable_fields: Option<Vec<String>>, // Reserved for fuzzy/embedding matches; currently always None
}
1307
+
1308
+ fn verify_json_field(llm_field_name: &str, db_vec: &Vec<DbRows>) -> VerifiedField {
1309
+ // Check if llm_field_name exists or not in db name field
1310
+ let verified_result: VerifiedField;
1311
+ if db_vec.iter().any(|item| item.name == llm_field_name) {
1312
+ //println!("Found \"{}\" in db", llm_field_name);
1313
+ verified_result = VerifiedField {
1314
+ correct_field: Some(String::from(llm_field_name)),
1315
+ correct_value: None,
1316
+ _probable_fields: None,
1317
+ };
1318
+ } else {
1319
+ println!("Did not find \"{}\" in db", llm_field_name);
1320
+ // Check to see if llm_field_name exists as values under any of the fields
1321
+ let (search_field, search_val) = verify_json_value(llm_field_name, &db_vec);
1322
+
1323
+ match search_field {
1324
+ Some(x) => {
1325
+ verified_result = VerifiedField {
1326
+ correct_field: Some(String::from(x)),
1327
+ correct_value: search_val,
1328
+ _probable_fields: None,
1329
+ };
1330
+ }
1331
+ None => {
1332
+ // Incorrect field found neither in any of the fields nor any of the values. This will then invoke embedding match across all the fields and their corresponding values
1333
+
1334
+ let mut search_terms = Vec::<String>::new();
1335
+ search_terms.push(String::from(llm_field_name)); // Added the incorrect field item to the search
1336
+ verified_result = VerifiedField {
1337
+ correct_field: None,
1338
+ correct_value: None,
1339
+ _probable_fields: None,
1340
+ };
1341
+ }
1342
+ }
1343
+ }
1344
+ verified_result
1345
+ }
1346
+
1347
+ fn verify_json_value(llm_value_name: &str, db_vec: &Vec<DbRows>) -> (Option<String>, Option<String>) {
1348
+ let mut search_field: Option<String> = None;
1349
+ let mut search_val: Option<String> = None;
1350
+ for item in db_vec {
1351
+ for val in &item.values {
1352
+ if llm_value_name == val {
1353
+ search_field = Some(item.name.clone());
1354
+ search_val = Some(String::from(val));
1355
+ break;
1356
+ }
1357
+ }
1358
+ match search_field {
1359
+ Some(_) => break,
1360
+ None => {}
728
1361
  }
729
1362
  }
1363
+ (search_field, search_val)
730
1364
  }