@sjcrh/proteinpaint-rust 2.166.0 → 2.169.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/package.json CHANGED
@@ -1,5 +1,5 @@
1
1
  {
2
- "version": "2.166.0",
2
+ "version": "2.169.0",
3
3
  "name": "@sjcrh/proteinpaint-rust",
4
4
  "type": "module",
5
5
  "description": "Rust-based utilities for proteinpaint",
package/src/aichatbot.rs CHANGED
@@ -30,21 +30,51 @@ pub struct AiJsonFormat {
30
30
  #[derive(PartialEq, Debug, Clone, schemars::JsonSchema, serde::Serialize, serde::Deserialize)]
31
31
  enum Charts {
32
32
  // More chart types will be added here later
33
- Summary(TrainTestData),
34
- DE(TrainTestData),
33
+ Summary(TrainTestDataSummary),
34
+ DE(TrainTestDataDE),
35
35
  }
36
36
 
37
37
  #[derive(PartialEq, Debug, Clone, schemars::JsonSchema, serde::Serialize, serde::Deserialize)]
38
- struct TrainTestData {
38
+ struct TrainTestDataSummary {
39
39
  SystemPrompt: String,
40
- TrainingData: Vec<QuestionAnswer>,
41
- TestData: Vec<QuestionAnswer>,
40
+ TrainingData: Vec<QuestionAnswerSummary>,
41
+ TestData: Vec<QuestionAnswerSummary>,
42
42
  }
43
43
 
44
44
  #[derive(PartialEq, Debug, Clone, schemars::JsonSchema, serde::Serialize, serde::Deserialize)]
45
- struct QuestionAnswer {
45
+ struct QuestionAnswerSummary {
46
46
  question: String,
47
- answer: String,
47
+ answer: SummaryType,
48
+ }
49
+
50
+ #[derive(PartialEq, Debug, Clone, schemars::JsonSchema, serde::Serialize, serde::Deserialize)]
51
+ struct TrainTestDataDE {
52
+ SystemPrompt: String,
53
+ TrainingData: Vec<QuestionAnswerDE>,
54
+ TestData: Vec<QuestionAnswerDE>,
55
+ }
56
+
57
+ #[derive(PartialEq, Debug, Clone, schemars::JsonSchema, serde::Serialize, serde::Deserialize)]
58
+ struct QuestionAnswerDE {
59
+ question: String,
60
+ answer: DEType,
61
+ }
62
+
63
+ #[derive(PartialEq, Debug, Clone, schemars::JsonSchema, serde::Serialize, serde::Deserialize)]
64
+ struct DEType {
65
+ action: String,
66
+ DE_output: DETerms,
67
+ }
68
+
69
+ #[derive(PartialEq, Debug, Clone, schemars::JsonSchema, serde::Serialize, serde::Deserialize)]
70
+ struct DETerms {
71
+ group1: GroupType,
72
+ group2: GroupType,
73
+ }
74
+
75
+ #[derive(PartialEq, Debug, Clone, schemars::JsonSchema, serde::Serialize, serde::Deserialize)]
76
+ struct GroupType {
77
+ name: String,
48
78
  }
49
79
 
50
80
  #[allow(non_camel_case_types)]
@@ -77,6 +107,27 @@ async fn main() -> Result<()> {
77
107
  None => panic!("user_input field is missing in input json"),
78
108
  }
79
109
 
110
+ let dataset_db_json: &JsonValue = &json_string["dataset_db"];
111
+ let dataset_db_str: &str;
112
+ match dataset_db_json.as_str() {
113
+ Some(inp) => dataset_db_str = inp,
114
+ None => panic!("dataset_db field is missing in input json"),
115
+ }
116
+
117
+ let genedb_json: &JsonValue = &json_string["genedb"];
118
+ let genedb_str: &str;
119
+ match genedb_json.as_str() {
120
+ Some(inp) => genedb_str = inp,
121
+ None => panic!("genedb field is missing in input json"),
122
+ }
123
+
124
+ let aiRoute_json: &JsonValue = &json_string["aiRoute"];
125
+ let aiRoute_str: &str;
126
+ match aiRoute_json.as_str() {
127
+ Some(inp) => aiRoute_str = inp,
128
+ None => panic!("aiRoute field is missing in input json"),
129
+ }
130
+
80
131
  if user_input.len() == 0 {
81
132
  panic!("The user input is empty");
82
133
  }
@@ -124,8 +175,9 @@ async fn main() -> Result<()> {
124
175
  let ai_json: AiJsonFormat =
125
176
  serde_json::from_str(&ai_data).expect("AI JSON file does not have the correct format");
126
177
 
127
- let genedb = String::from(tpmasterdir) + &"/" + &ai_json.genedb;
128
- let dataset_db = String::from(tpmasterdir) + &"/" + &ai_json.db;
178
+ let genedb = String::from(tpmasterdir) + &"/" + &genedb_str;
179
+ let dataset_db = String::from(tpmasterdir) + &"/" + &dataset_db_str;
180
+ let airoute = String::from(binpath) + &"/../../" + &aiRoute_str;
129
181
 
130
182
  let apilink_json: &JsonValue = &json_string["apilink"];
131
183
  let apilink: &str;
@@ -160,7 +212,7 @@ async fn main() -> Result<()> {
160
212
  let temperature: f64 = 0.01;
161
213
  let max_new_tokens: usize = 512;
162
214
  let top_p: f32 = 0.95;
163
-
215
+ let testing = false; // This variable is always false in production, this is true in test_ai.rs for testing code
164
216
  if llm_backend_name != "ollama" && llm_backend_name != "SJ" {
165
217
  panic!(
166
218
  "This code currently supports only Ollama and SJ provider. llm_backend_name must be \"ollama\" or \"SJ\""
@@ -185,6 +237,8 @@ async fn main() -> Result<()> {
185
237
  &dataset_db,
186
238
  &genedb,
187
239
  &ai_json,
240
+ &airoute,
241
+ testing,
188
242
  )
189
243
  .await;
190
244
  } else if llm_backend_name == "SJ".to_string() {
@@ -207,6 +261,8 @@ async fn main() -> Result<()> {
207
261
  &dataset_db,
208
262
  &genedb,
209
263
  &ai_json,
264
+ &airoute,
265
+ testing,
210
266
  )
211
267
  .await;
212
268
  }
@@ -239,6 +295,8 @@ pub async fn run_pipeline(
239
295
  dataset_db: &str,
240
296
  genedb: &str,
241
297
  ai_json: &AiJsonFormat,
298
+ ai_route: &str,
299
+ testing: bool,
242
300
  ) -> Option<String> {
243
301
  let mut classification: String = classify_query_by_dataset_type(
244
302
  user_input,
@@ -248,6 +306,7 @@ pub async fn run_pipeline(
248
306
  temperature,
249
307
  max_new_tokens,
250
308
  top_p,
309
+ ai_route,
251
310
  )
252
311
  .await;
253
312
  classification = classification.replace("\"", "");
@@ -263,13 +322,20 @@ pub async fn run_pipeline(
263
322
  top_p,
264
323
  )
265
324
  .await;
266
- final_output = format!(
267
- "{{\"{}\":\"{}\",\"{}\":[{}}}",
268
- "action",
269
- "dge",
270
- "DE_output",
271
- de_result + &"]"
272
- );
325
+ if testing == true {
326
+ final_output = format!(
327
+ "{{\"{}\":\"{}\",\"{}\":[{}}}",
328
+ "action",
329
+ "dge",
330
+ "DE_output",
331
+ de_result + &"]"
332
+ );
333
+ } else {
334
+ final_output = format!(
335
+ "{{\"{}\":\"{}\",\"{}\":\"{}\"}}",
336
+ "type", "html", "html", "DE agent not implemented yet"
337
+ );
338
+ }
273
339
  } else if classification == "summary".to_string() {
274
340
  final_output = extract_summary_information(
275
341
  user_input,
@@ -282,30 +348,83 @@ pub async fn run_pipeline(
282
348
  dataset_db,
283
349
  genedb,
284
350
  ai_json,
351
+ testing,
285
352
  )
286
353
  .await;
287
354
  } else if classification == "hierarchical".to_string() {
288
355
  // Not implemented yet
289
- final_output = format!("{{\"{}\":\"{}\"}}", "action", "hierarchical");
356
+ if testing == true {
357
+ final_output = format!("{{\"{}\":\"{}\"}}", "action", "hierarchical");
358
+ } else {
359
+ final_output = format!(
360
+ "{{\"{}\":\"{}\",\"{}\":\"{}\"}}",
361
+ "type", "html", "html", "hierarchical clustering agent not implemented yet"
362
+ );
363
+ }
290
364
  } else if classification == "snv_indel".to_string() {
291
365
  // Not implemented yet
292
- final_output = format!("{{\"{}\":\"{}\"}}", "action", "snv_indel");
366
+ if testing == true {
367
+ final_output = format!("{{\"{}\":\"{}\"}}", "action", "snv_indel");
368
+ } else {
369
+ final_output = format!(
370
+ "{{\"{}\":\"{}\",\"{}\":\"{}\"}}",
371
+ "type", "html", "html", "snv_indel agent not implemented yet"
372
+ );
373
+ }
293
374
  } else if classification == "cnv".to_string() {
294
375
  // Not implemented yet
295
- final_output = format!("{{\"{}\":\"{}\"}}", "action", "cnv");
376
+ if testing == true {
377
+ final_output = format!("{{\"{}\":\"{}\"}}", "action", "cnv");
378
+ } else {
379
+ final_output = format!(
380
+ "{{\"{}\":\"{}\",\"{}\":\"{}\"}}",
381
+ "type", "html", "html", "cnv agent not implemented yet"
382
+ );
383
+ }
296
384
  } else if classification == "variant_calling".to_string() {
297
385
  // Not implemented yet and will never be supported. Need a separate messages for this
298
- final_output = format!("{{\"{}\":\"{}\"}}", "action", "variant_calling");
386
+ if testing == true {
387
+ final_output = format!("{{\"{}\":\"{}\"}}", "action", "variant_calling");
388
+ } else {
389
+ final_output = format!(
390
+ "{{\"{}\":\"{}\",\"{}\":\"{}\"}}",
391
+ "type", "html", "html", "variant_calling agent not implemented yet"
392
+ );
393
+ }
299
394
  } else if classification == "survival".to_string() {
300
395
  // Not implemented yet
301
- final_output = format!("{{\"{}\":\"{}\"}}", "action", "surivial");
396
+ if testing == true {
397
+ final_output = format!("{{\"{}\":\"{}\"}}", "action", "surivial");
398
+ } else {
399
+ final_output = format!(
400
+ "{{\"{}\":\"{}\",\"{}\":\"{}\"}}",
401
+ "type", "html", "html", "survival agent not implemented yet"
402
+ );
403
+ }
302
404
  } else if classification == "none".to_string() {
303
- final_output = format!(
304
- "{{\"{}\":\"{}\",\"{}\":\"{}\"}}",
305
- "action", "none", "message", "The input query did not match any known features in Proteinpaint"
306
- );
405
+ if testing == true {
406
+ final_output = format!(
407
+ "{{\"{}\":\"{}\",\"{}\":\"{}\"}}",
408
+ "action", "none", "message", "The input query did not match any known features in Proteinpaint"
409
+ );
410
+ } else {
411
+ final_output = format!(
412
+ "{{\"{}\":\"{}\",\"{}\":\"{}\"}}",
413
+ "type", "html", "html", "The input query did not match any known features in Proteinpaint"
414
+ );
415
+ }
307
416
  } else {
308
- final_output = format!("{{\"{}\":\"{}\"}}", "action", "unknown:".to_string() + &classification);
417
+ if testing == true {
418
+ final_output = format!("{{\"{}\":\"{}\"}}", "action", "unknown:".to_string() + &classification);
419
+ } else {
420
+ final_output = format!(
421
+ "{{\"{}\":\"{}\",\"{}\":\"{}\"}}",
422
+ "type",
423
+ "html",
424
+ "html",
425
+ "unknown:".to_string() + &classification
426
+ );
427
+ }
309
428
  }
310
429
  Some(final_output)
311
430
  }
@@ -313,101 +432,33 @@ pub async fn run_pipeline(
313
432
  async fn classify_query_by_dataset_type(
314
433
  user_input: &str,
315
434
  comp_model: impl rig::completion::CompletionModel + 'static,
316
- embedding_model: impl rig::embeddings::EmbeddingModel + 'static,
435
+ _embedding_model: impl rig::embeddings::EmbeddingModel + 'static,
317
436
  llm_backend_type: &llm_backend,
318
437
  temperature: f64,
319
438
  max_new_tokens: usize,
320
439
  top_p: f32,
440
+ ai_route: &str,
321
441
  ) -> String {
322
- // Create a string to hold the file contents
323
- let contents = String::from("SNV/SNP or point mutations nucleotide mutations are very common forms of mutations which can often give rise to genetic diseases such as cancer, Alzheimer's disease etc. They can be duw to substitution of nucleotide, or insertion or deletion of a nucleotide. Indels are multi-nucleotide insertion/deletion/substitutions. Complex indels are indels where insertion and deletion have happened in the same genomic locus. Every genomic sample from each patient has its own set of mutations therefore requiring personalized treatment.
324
-
325
- If a ProteinPaint dataset contains SNV/Indel/SV data then return JSON with single key, 'snv_indel'.
326
-
327
- ---
328
-
329
- Copy number variation (CNV) is a phenomenon in which sections of the genome are repeated and the number of repeats in the genome varies between individuals.[1] Copy number variation is a special type of structural variation: specifically, it is a type of duplication or deletion event that affects a considerable number of base pairs.
330
-
331
- If a ProteinPaint dataset contains copy number variation data then return JSON with single key, 'cnv'.
332
-
333
- ---
334
-
335
- Structural variants/fusions (SV) are genomic mutations when eith a DNA region is translocated or copied to an entirely different genomic locus. In case of transcriptomic data, when RNA is fused from two different genes its called a gene fusion.
336
-
337
- If a ProteinPaint dataset contains structural variation or gene fusion data then return JSON with single key, 'sv_fusion'.
338
- ---
339
-
340
- Hierarchical clustering of gene expression is an unsupervised learning technique where several number of relevant genes and the samples are clustered so as to determine (previously unknown) cohorts of samples (or patients) or structure in data. It is very commonly used to determine subtypes of a particular disease based on RNA sequencing data.
341
-
342
- If a ProteinPaint dataset contains hierarchical data then return JSON with single key, 'hierarchical'.
343
-
344
- ---
345
-
346
- Differential Gene Expression (DGE or DE) is a technique where the most upregulated (or highest) and downregulated (or lowest) genes between two cohorts of samples (or patients) are determined from a pool of THOUSANDS of genes. Differential gene expression CANNOT be computed for a SINGLE gene. A volcano plot is shown with fold-change in the x-axis and adjusted p-value on the y-axis. So, the upregulated and downregulared genes are on opposite sides of the graph and the most significant genes (based on adjusted p-value) is on the top of the graph. Following differential gene expression generally GeneSet Enrichment Analysis (GSEA) is carried out where based on the genes and their corresponding fold changes the upregulation/downregulation of genesets (or pathways) is determined.
347
-
348
- Sample Query1: \"Which gene has the highest expression between the two genders\"
349
- Sample Answer1: { \"answer\": \"dge\" }
350
-
351
- Sample Query2: \"Which gene has the lowest expression between the two races\"
352
- Sample Answer2: { \"answer\": \"dge\" }
353
-
354
- Sample Query1: \"Which genes are the most upregulated genes between group A and group B\"
355
- Sample Answer1: { \"answer\": \"dge\" }
356
-
357
- Sample Query3: \"Which gene are overexpressed between male and female\"
358
- Sample Answer3: { \"answer\": \"dge\" }
359
-
360
- Sample Query4: \"Which gene are housekeeping genes between male and female\"
361
- Sample Answer4: { \"answer\": \"dge\" }
362
-
363
-
364
- If a ProteinPaint dataset contains differential gene expression data then return JSON with single key, 'dge'.
365
-
366
- ---
442
+ // Read the file
443
+ let ai_route_data = fs::read_to_string(ai_route).unwrap();
367
444
 
368
- Survival analysis (also called time-to-event analysis or duration analysis) is a branch of statistics aimed at analyzing the duration of time from a well-defined time origin until one or more events happen, called survival times or duration times. In other words, in survival analysis, we are interested in a certain event and want to analyze the time until the event happens. Generally in survival analysis survival rates between two (or more) cohorts of patients is compared.
445
+ // Parse the JSON data
446
+ let ai_json: Value = serde_json::from_str(&ai_route_data).expect("AI JSON file does not have the correct format");
369
447
 
370
- There are two main methods of survival analysis:
371
-
372
- 1) Kaplan-Meier (HM) analysis is a univariate test that only takes into account a single categorical variable.
373
- 2) Cox proportional hazards model (coxph) is a multivariate test that can take into account multiple variables.
374
-
375
- The hazard ratio (HR) is an indicator of the effect of the stimulus (e.g. drug dose, treatment) between two cohorts of patients.
376
- HR = 1: No effect
377
- HR < 1: Reduction in the hazard
378
- HR > 1: Increase in Hazard
379
-
380
- Sample Query1: \"Compare survival rates between group A and B\"
381
- Sample Answer1: { \"answer\": \"survival\" }
382
-
383
-
384
- If a ProteinPaint dataset contains survival data then return JSON with single key, 'survival'.
385
-
386
- ---
387
-
388
- Next generation sequencing reads (NGS) are mapped to a human genome using alignment algorithm such as burrows-wheelers alignment algorithm. Then these reads are called using variant calling algorithms such as GATK (Genome Analysis Toolkit). However this type of analysis is too compute intensive and beyond the scope of visualization software such as ProteinPaint.
389
-
390
- If a user query asks about variant calling or mapping reads then JSON with single key, 'variant_calling'.
391
-
392
- ---
393
-
394
- Summary plot in ProteinPaint shows the various facets of the datasets. Show expression of a SINGLE gene or compare the expression of a SINGLE gene across two different cohorts defined by the user. It may show all the samples according to their respective diagnosis or subtypes of cancer. It is also useful for comparing and correlating different clinical variables. It can show all possible distributions, frequency of a category, overlay, correlate or cross-tabulate with another variable on top of it. If a user query asks about a SINGLE gene expression or correlating clinical variables then return JSON with single key, 'summary'.
395
-
396
- Sample Query1: \"Show all fusions for patients with age less than 30\"
397
- Sample Answer1: { \"answer\": \"summary\" }
398
-
399
- Sample Query2: \"List all molecular subtypes of leukemia\"
400
- Sample Answer2: { \"answer\": \"summary\" }
401
-
402
- Sample Query3: \"is tp53 expression higher in men than women ?\"
403
- Sample Answer3: { \"answer\": \"summary\" }
404
-
405
- Sample Query4: \"Compare ATM expression between races for women greater than 80yrs\"
406
- Sample Answer4: { \"answer\": \"summary\" }
448
+ // Create a string to hold the file contents
449
+ let mut contents = String::from("");
407
450
 
451
+ if let Some(object) = ai_json.as_object() {
452
+ for (_key, value) in object {
453
+ contents += &value.as_str().unwrap();
454
+ contents += "---"; // Adding delimiter
455
+ }
456
+ }
408
457
 
409
- If a query does not match any of the fields described above, then return JSON with single key, 'none'
410
- ");
458
+ // Removing the last "---" characters
459
+ contents.pop();
460
+ contents.pop();
461
+ contents.pop();
411
462
 
412
463
  // Split the contents by the delimiter "---"
413
464
  let parts: Vec<&str> = contents.split("---").collect();
@@ -438,18 +489,18 @@ If a query does not match any of the fields described above, then return JSON wi
438
489
  rag_docs.push(part.trim().to_string())
439
490
  }
440
491
 
441
- //let top_k: usize = 3;
492
+ //let top_k: usize = 3; // Embedding model not used currently
442
493
  // Create embeddings and add to vector store
443
- let embeddings = EmbeddingsBuilder::new(embedding_model.clone())
444
- .documents(rag_docs)
445
- .expect("Reason1")
446
- .build()
447
- .await
448
- .unwrap();
494
+ //let embeddings = EmbeddingsBuilder::new(embedding_model.clone())
495
+ // .documents(rag_docs)
496
+ // .expect("Reason1")
497
+ // .build()
498
+ // .await
499
+ // .unwrap();
449
500
 
450
- // Create vector store
451
- let mut vector_store = InMemoryVectorStore::<String>::default();
452
- InMemoryVectorStore::add_documents(&mut vector_store, embeddings);
501
+ //// Create vector store
502
+ //let mut vector_store = InMemoryVectorStore::<String>::default();
503
+ //InMemoryVectorStore::add_documents(&mut vector_store, embeddings);
453
504
 
454
505
  // Create RAG agent
455
506
  let agent = AgentBuilder::new(comp_model).preamble(&(String::from("Generate classification for the user query into summary, dge, hierarchical, snv_indel, cnv, variant_calling, sv_fusion and none categories. Return output in JSON with ALWAYS a single word answer { \"answer\": \"dge\" }, that is 'summary' for summary plot, 'dge' for differential gene expression, 'hierarchical' for hierarchical clustering, 'snv_indel' for SNV/Indel, 'cnv' for CNV and 'sv_fusion' for SV/fusion, 'variant_calling' for variant calling, 'surivial' for survival data, 'none' for none of the previously described categories. The summary plot list and summarizes the cohort of patients according to the user query. The answer should always be in lower case\n The options are as follows:\n") + &contents + "\nQuestion= {question} \nanswer")).temperature(temperature).additional_params(additional).build();
@@ -801,6 +852,7 @@ async fn extract_summary_information(
801
852
  dataset_db: &str,
802
853
  genedb: &str,
803
854
  ai_json: &AiJsonFormat,
855
+ testing: bool,
804
856
  ) -> String {
805
857
  let (rag_docs, db_vec) = parse_dataset_db(dataset_db).await;
806
858
  let additional;
@@ -845,7 +897,7 @@ async fn extract_summary_information(
845
897
  .filter(|x| user_words2.contains(&x.to_lowercase()))
846
898
  .collect();
847
899
 
848
- let mut summary_data_check: Option<TrainTestData> = None;
900
+ let mut summary_data_check: Option<TrainTestDataSummary> = None;
849
901
  for chart in ai_json.charts.clone() {
850
902
  if let Charts::Summary(traindata) = chart {
851
903
  summary_data_check = Some(traindata);
@@ -858,6 +910,7 @@ async fn extract_summary_information(
858
910
  let mut training_data: String = String::from("");
859
911
  let mut train_iter = 0;
860
912
  for ques_ans in summary_data.TrainingData {
913
+ let summary_answer: SummaryType = ques_ans.answer;
861
914
  train_iter += 1;
862
915
  training_data += "Example question";
863
916
  training_data += &train_iter.to_string();
@@ -867,7 +920,7 @@ async fn extract_summary_information(
867
920
  training_data += "Example answer";
868
921
  training_data += &train_iter.to_string();
869
922
  training_data += &":";
870
- training_data += &ques_ans.answer;
923
+ training_data += &serde_json::to_string(&summary_answer).unwrap();
871
924
  training_data += &"\n";
872
925
  }
873
926
 
@@ -919,7 +972,8 @@ async fn extract_summary_information(
919
972
  }
920
973
  }
921
974
  //println!("final_llm_json:{}", final_llm_json);
922
- let final_validated_json = validate_summary_output(final_llm_json.clone(), db_vec, common_genes, ai_json);
975
+ let final_validated_json =
976
+ validate_summary_output(final_llm_json.clone(), db_vec, common_genes, ai_json, testing);
923
977
  final_validated_json
924
978
  }
925
979
  None => {
@@ -949,7 +1003,7 @@ struct SummaryType {
949
1003
 
950
1004
  impl SummaryType {
951
1005
  #[allow(dead_code)]
952
- pub fn sort_summarytype_struct(&mut self) {
1006
+ pub fn sort_summarytype_struct(mut self) -> SummaryType {
953
1007
  // This function is necessary for testing (test_ai.rs) to see if two variables of type "SummaryType" are equal or not. Without this a vector of two Summarytype holding the same values but in different order will be classified separately.
954
1008
  self.summaryterms.sort();
955
1009
 
@@ -957,6 +1011,7 @@ impl SummaryType {
957
1011
  Some(ref mut filterterms) => filterterms.sort(),
958
1012
  None => {}
959
1013
  }
1014
+ self.clone()
960
1015
  }
961
1016
  }
962
1017
 
@@ -974,7 +1029,7 @@ impl PartialOrd for SummaryTerms {
974
1029
  (SummaryTerms::clinical(_), SummaryTerms::clinical(_)) => Some(std::cmp::Ordering::Equal),
975
1030
  (SummaryTerms::geneExpression(_), SummaryTerms::geneExpression(_)) => Some(std::cmp::Ordering::Equal),
976
1031
  (SummaryTerms::clinical(_), SummaryTerms::geneExpression(_)) => Some(std::cmp::Ordering::Greater),
977
- (SummaryTerms::geneExpression(_), SummaryTerms::clinical(_)) => Some(std::cmp::Ordering::Greater),
1032
+ (SummaryTerms::geneExpression(_), SummaryTerms::clinical(_)) => Some(std::cmp::Ordering::Less),
978
1033
  }
979
1034
  }
980
1035
  }
@@ -1063,6 +1118,7 @@ fn validate_summary_output(
1063
1118
  db_vec: Vec<DbRows>,
1064
1119
  common_genes: Vec<String>,
1065
1120
  ai_json: &AiJsonFormat,
1121
+ testing: bool,
1066
1122
  ) -> String {
1067
1123
  let json_value: SummaryType =
1068
1124
  serde_json::from_str(&raw_llm_json).expect("Did not get a valid JSON of type {action: summary, summaryterms:[{clinical: term1}, {geneExpression: gene}], filter:[{term: term1, value: value1}]} from the LLM");
@@ -1094,7 +1150,7 @@ fn validate_summary_output(
1094
1150
  match term_verification.correct_field {
1095
1151
  Some(tm) => validated_summary_terms.push(SummaryTerms::clinical(tm)),
1096
1152
  None => {
1097
- message = message + &"\"" + &clin + &"\"" + &" not found in db.";
1153
+ message = message + &"'" + &clin + &"'" + &" not found in db.";
1098
1154
  }
1099
1155
  }
1100
1156
  } else if Some(term_verification.correct_field.clone()).is_some()
@@ -1122,7 +1178,7 @@ fn validate_summary_output(
1122
1178
  if num_gene_verification == 0 || common_genes.len() == 0 {
1123
1179
  if message.to_lowercase().contains(&gene.to_lowercase()) { // Check if the LLM has already added the message, if not then add it
1124
1180
  } else {
1125
- message = message + &"\"" + &gene + &"\"" + &" not found in genedb.";
1181
+ message = message + &"'" + &gene + &"'" + &" not found in genedb.";
1126
1182
  }
1127
1183
  }
1128
1184
  }
@@ -1138,6 +1194,8 @@ fn validate_summary_output(
1138
1194
  }
1139
1195
  }
1140
1196
 
1197
+ let mut pp_plot_json: Value; // The PP compliant plot JSON
1198
+ pp_plot_json = serde_json::from_str(&"{\"chartType\":\"summary\"}").expect("Not a valid JSON");
1141
1199
  match &json_value.filter {
1142
1200
  Some(filter_terms_array) => {
1143
1201
  let mut validated_filter_terms = Vec::<FilterTerm>::new();
@@ -1168,21 +1226,21 @@ fn validate_summary_output(
1168
1226
  validated_filter_terms.push(categorical_filter_term);
1169
1227
  }
1170
1228
  if term_verification.correct_field.is_none() {
1171
- message = message + &"\"" + &categorical.term + &"\" filter term not found in db";
1229
+ message = message + &"'" + &categorical.term + &"' filter term not found in db";
1172
1230
  }
1173
1231
  if value_verification.is_none() {
1174
1232
  message = message
1175
- + &"\""
1233
+ + &"'"
1176
1234
  + &categorical.value
1177
- + &"\" filter value not found for filter field \""
1235
+ + &"' filter value not found for filter field '"
1178
1236
  + &categorical.term
1179
- + "\" in db";
1237
+ + "' in db";
1180
1238
  }
1181
1239
  }
1182
1240
  FilterTerm::Numeric(numeric) => {
1183
1241
  let term_verification = verify_json_field(&numeric.term, &db_vec);
1184
1242
  if term_verification.correct_field.is_none() {
1185
- message = message + &"\"" + &numeric.term + &"\" filter term not found in db";
1243
+ message = message + &"'" + &numeric.term + &"' filter term not found in db";
1186
1244
  } else {
1187
1245
  let numeric_filter_term: FilterTerm = FilterTerm::Numeric(numeric.clone());
1188
1246
  validated_filter_terms.push(numeric_filter_term);
@@ -1229,8 +1287,68 @@ fn validate_summary_output(
1229
1287
  }
1230
1288
 
1231
1289
  if validated_filter_terms.len() > 0 {
1232
- if let Some(obj) = new_json.as_object_mut() {
1233
- obj.insert(String::from("filter"), serde_json::json!(validated_filter_terms));
1290
+ if testing == true {
1291
+ if let Some(obj) = new_json.as_object_mut() {
1292
+ obj.insert(String::from("filter"), serde_json::json!(validated_filter_terms));
1293
+ }
1294
+ } else {
1295
+ let mut validated_filter_terms_PP: String = "[".to_string();
1296
+ let mut filter_hits = 0;
1297
+ for validated_term in validated_filter_terms {
1298
+ match validated_term {
1299
+ FilterTerm::Categorical(categorical_filter) => {
1300
+ let string_json = "{\"term\":\"".to_string()
1301
+ + &categorical_filter.term
1302
+ + &"\", \"category\":\""
1303
+ + &categorical_filter.value
1304
+ + &"\"},";
1305
+ validated_filter_terms_PP += &string_json;
1306
+ }
1307
+ FilterTerm::Numeric(numeric_filter) => {
1308
+ let string_json;
1309
+ if numeric_filter.greaterThan.is_some() && numeric_filter.lessThan.is_none() {
1310
+ string_json = "{\"term\":\"".to_string()
1311
+ + &numeric_filter.term
1312
+ + &"\", \"gt\":\""
1313
+ + &numeric_filter.greaterThan.unwrap().to_string()
1314
+ + &"\"},";
1315
+ } else if numeric_filter.greaterThan.is_none() && numeric_filter.lessThan.is_some() {
1316
+ string_json = "{\"term\":\"".to_string()
1317
+ + &numeric_filter.term
1318
+ + &"\", \"lt\":\""
1319
+ + &numeric_filter.lessThan.unwrap().to_string()
1320
+ + &"\"},";
1321
+ } else if numeric_filter.greaterThan.is_some() && numeric_filter.lessThan.is_some() {
1322
+ string_json = "{\"term\":\"".to_string()
1323
+ + &numeric_filter.term
1324
+ + &"\", \"lt\":\""
1325
+ + &numeric_filter.lessThan.unwrap().to_string()
1326
+ + &"\", \"gt\":\""
1327
+ + &numeric_filter.greaterThan.unwrap().to_string()
1328
+ + &"\"},";
1329
+ } else {
1330
+ // When both greater and less than are none
1331
+ panic!(
1332
+ "Numeric filter term {} is missing both greater than and less than values. One of them must be defined",
1333
+ &numeric_filter.term
1334
+ );
1335
+ }
1336
+ validated_filter_terms_PP += &string_json;
1337
+ }
1338
+ };
1339
+ filter_hits += 1;
1340
+ }
1341
+ println!("validated_filter_terms_PP:{}", validated_filter_terms_PP);
1342
+ if filter_hits > 0 {
1343
+ validated_filter_terms_PP.pop();
1344
+ validated_filter_terms_PP += &"]";
1345
+ if let Some(obj) = pp_plot_json.as_object_mut() {
1346
+ obj.insert(
1347
+ String::from("simpleFilter"),
1348
+ serde_json::from_str(&validated_filter_terms_PP).expect("Not a valid JSON"),
1349
+ );
1350
+ }
1351
+ }
1234
1352
  }
1235
1353
  }
1236
1354
  }
@@ -1240,6 +1358,10 @@ fn validate_summary_output(
1240
1358
  // Removing terms that are found both in filter term as well summary
1241
1359
  let mut validated_summary_terms_final = Vec::<SummaryTerms>::new();
1242
1360
 
1361
+ let mut sum_iter = 0;
1362
+ let mut pp_json: Value; // New JSON value that will contain items of the final PP compliant JSON
1363
+ pp_json = serde_json::from_str(&"{\"type\":\"plot\"}").expect("Not a valid JSON");
1364
+
1243
1365
  for summary_term in &validated_summary_terms {
1244
1366
  let mut hit = 0;
1245
1367
  match summary_term {
@@ -1276,9 +1398,53 @@ fn validate_summary_output(
1276
1398
  }
1277
1399
  }
1278
1400
  }
1401
+
1279
1402
  if hit == 0 {
1403
+ let mut termidpp: Option<TermIDPP> = None;
1404
+ let mut geneexp: Option<GeneExpressionPP> = None;
1405
+ match summary_term {
1406
+ SummaryTerms::clinical(clinical_term) => {
1407
+ termidpp = Some(TermIDPP {
1408
+ id: clinical_term.to_string(),
1409
+ });
1410
+ }
1411
+ SummaryTerms::geneExpression(gene) => {
1412
+ geneexp = Some(GeneExpressionPP {
1413
+ gene: gene.to_string(),
1414
+ r#type: "geneExpression".to_string(),
1415
+ });
1416
+ }
1417
+ }
1418
+ if sum_iter == 0 {
1419
+ if termidpp.is_some() {
1420
+ if let Some(obj) = pp_plot_json.as_object_mut() {
1421
+ obj.insert(String::from("term"), serde_json::json!(Some(termidpp)));
1422
+ }
1423
+ }
1424
+
1425
+ if geneexp.is_some() {
1426
+ let gene_term = GeneTerm { term: geneexp.unwrap() };
1427
+ if let Some(obj) = pp_plot_json.as_object_mut() {
1428
+ obj.insert(String::from("term"), serde_json::json!(gene_term));
1429
+ }
1430
+ }
1431
+ } else if sum_iter == 1 {
1432
+ if termidpp.is_some() {
1433
+ if let Some(obj) = pp_plot_json.as_object_mut() {
1434
+ obj.insert(String::from("term2"), serde_json::json!(Some(termidpp)));
1435
+ }
1436
+ }
1437
+
1438
+ if geneexp.is_some() {
1439
+ let gene_term = GeneTerm { term: geneexp.unwrap() };
1440
+ if let Some(obj) = pp_plot_json.as_object_mut() {
1441
+ obj.insert(String::from("term2"), serde_json::json!(gene_term));
1442
+ }
1443
+ }
1444
+ }
1280
1445
  validated_summary_terms_final.push(summary_term.clone())
1281
1446
  }
1447
+ sum_iter += 1
1282
1448
  }
1283
1449
 
1284
1450
  if let Some(obj) = new_json.as_object_mut() {
@@ -1288,14 +1454,61 @@ fn validate_summary_output(
1288
1454
  );
1289
1455
  }
1290
1456
 
1457
+ if let Some(obj) = pp_json.as_object_mut() {
1458
+ // The `if let` ensures we only proceed if the top-level JSON is an object.
1459
+ // Append a new string field.
1460
+ obj.insert(String::from("plot"), serde_json::json!(pp_plot_json));
1461
+ }
1462
+
1463
+ let mut err_json: Value; // Error JSON containing the error message (if present)
1291
1464
  if message.len() > 0 {
1292
- if let Some(obj) = new_json.as_object_mut() {
1293
- // The `if let` ensures we only proceed if the top-level JSON is an object.
1294
- // Append a new string field.
1295
- obj.insert(String::from("message"), serde_json::json!(message));
1465
+ if testing == false {
1466
+ err_json = serde_json::from_str(&"{\"type\":\"html\"}").expect("Not a valid JSON");
1467
+ if let Some(obj) = err_json.as_object_mut() {
1468
+ // The `if let` ensures we only proceed if the top-level JSON is an object.
1469
+ // Append a new string field.
1470
+ obj.insert(String::from("html"), serde_json::json!(message));
1471
+ };
1472
+ serde_json::to_string(&err_json).unwrap()
1473
+ } else {
1474
+ if let Some(obj) = new_json.as_object_mut() {
1475
+ // The `if let` ensures we only proceed if the top-level JSON is an object.
1476
+ // Append a new string field.
1477
+ obj.insert(String::from("message"), serde_json::json!(message));
1478
+ };
1479
+ serde_json::to_string(&new_json).unwrap()
1480
+ }
1481
+ } else {
1482
+ if testing == true {
1483
+ // When testing script output native LLM JSON
1484
+ serde_json::to_string(&new_json).unwrap()
1485
+ } else {
1486
+ // When in production output PP compliant JSON
1487
+ serde_json::to_string(&pp_json).unwrap()
1296
1488
  }
1297
1489
  }
1298
- serde_json::to_string(&new_json).unwrap()
1490
+ }
1491
+
1492
/// Default value for `GeneExpressionPP::r#type`.
///
/// It is referenced by name from `#[serde(default = "getGeneExpression")]`, so serde calls it
/// whenever a deserialized gene-expression term has no `type` key.
///
/// Returns the literal tag `"geneExpression"`.
// The camelCase name is kept because the serde attribute refers to this function by its
// string name. Renaming it would break that attribute, so the lint is silenced locally instead.
#[allow(non_snake_case)]
fn getGeneExpression() -> String {
    String::from("geneExpression")
}
1495
+
1496
+ #[derive(PartialEq, Debug, Clone, schemars::JsonSchema, serde::Serialize, serde::Deserialize)]
1497
+ struct TermIDPP {
1498
+ id: String,
1499
+ }
1500
+
1501
+ #[derive(PartialEq, Debug, Clone, schemars::JsonSchema, serde::Serialize, serde::Deserialize)]
1502
+ struct GeneTerm {
1503
+ term: GeneExpressionPP,
1504
+ }
1505
+
1506
+ #[derive(PartialEq, Debug, Clone, schemars::JsonSchema, serde::Serialize, serde::Deserialize)]
1507
+ struct GeneExpressionPP {
1508
+ gene: String,
1509
+ // Serde uses this for deserialization.
1510
+ #[serde(default = "getGeneExpression")]
1511
+ r#type: String,
1299
1512
  }
1300
1513
 
1301
1514
  #[derive(Debug, Clone)]
@@ -57,6 +57,7 @@ struct InteractiveData {
57
57
  x_buffer: i64,
58
58
  y_min: f64,
59
59
  y_max: f64,
60
+ device_pixel_ratio: f64,
60
61
  }
61
62
 
62
63
  #[derive(Serialize)]
@@ -216,8 +217,10 @@ fn grin2_file_read(
216
217
  Some(q) => q,
217
218
  None => continue,
218
219
  };
220
+
219
221
  let q_val: f64 = match q_val_str.parse() {
220
222
  Ok(v) if v > 0.0 => v,
223
+ Ok(v) if v == 0.0 => 1e-300, // Treat exact 0 as ~1e-300 so we can still show q-values that are 0 and not filter them out
221
224
  _ => continue,
222
225
  };
223
226
  let neg_log10_q = -q_val.log10();
@@ -335,12 +338,8 @@ fn plot_grin2_manhattan(
335
338
  let png_width = plot_width + 2 * png_dot_radius;
336
339
  let png_height = plot_height + 2 * png_dot_radius;
337
340
 
338
- let w: u32 = (png_width * device_pixel_ratio as u64)
339
- .try_into()
340
- .expect("PNG width too large for u32");
341
- let h: u32 = (png_height * device_pixel_ratio as u64)
342
- .try_into()
343
- .expect("PNG height too large for u32");
341
+ let w: u32 = ((png_width as f64) * dpr) as u32;
342
+ let h: u32 = ((png_height as f64) * dpr) as u32;
344
343
 
345
344
  // Create RGB buffer for Plotters
346
345
  let mut buffer = vec![0u8; w as usize * h as usize * 3];
@@ -402,8 +401,8 @@ fn plot_grin2_manhattan(
402
401
 
403
402
  for (i, p) in point_details.iter_mut().enumerate() {
404
403
  let (px, py) = pixel_positions[*&sig_indices[i]];
405
- p.pixel_x = px;
406
- p.pixel_y = py;
404
+ p.pixel_x = px / dpr;
405
+ p.pixel_y = py / dpr;
407
406
  }
408
407
 
409
408
  // flush root drawing area
@@ -469,6 +468,7 @@ fn plot_grin2_manhattan(
469
468
  x_buffer,
470
469
  y_min,
471
470
  y_max,
471
+ device_pixel_ratio: dpr,
472
472
  };
473
473
  Ok((png_data, interactive_data))
474
474
  }
package/src/test_ai.rs CHANGED
@@ -20,6 +20,7 @@ mod tests {
20
20
  ollama_comp_model_name: String,
21
21
  ollama_embedding_model_name: String,
22
22
  genomes: Vec<Genomes>,
23
+ aiRoute: String,
23
24
  }
24
25
 
25
26
  #[derive(PartialEq, Debug, Clone, schemars::JsonSchema, serde::Serialize, serde::Deserialize)]
@@ -42,13 +43,14 @@ mod tests {
42
43
  let top_p: f32 = 0.95;
43
44
  let serverconfig_file_path = Path::new("../../serverconfig.json");
44
45
  let absolute_path = serverconfig_file_path.canonicalize().unwrap();
46
+ let testing = true; // This causes the JSON being output from run_pipeline() to be in LLM JSON format
45
47
 
46
48
  // Read the file
47
49
  let data = fs::read_to_string(absolute_path).unwrap();
48
50
 
49
51
  // Parse the JSON data
50
52
  let serverconfig: ServerConfig = serde_json::from_str(&data).expect("JSON not in serverconfig.json format");
51
-
53
+ let airoute = String::from("../../") + &serverconfig.aiRoute;
52
54
  for genome in &serverconfig.genomes {
53
55
  for dataset in &genome.datasets {
54
56
  match &dataset.aifiles {
@@ -83,7 +85,6 @@ mod tests {
83
85
  .expect("Ollama server not found");
84
86
  let embedding_model = ollama_client.embedding_model(ollama_embedding_model_name);
85
87
  let comp_model = ollama_client.completion_model(ollama_comp_model_name);
86
-
87
88
  for chart in ai_json.charts.clone() {
88
89
  match chart {
89
90
  super::super::Charts::Summary(testdata) => {
@@ -100,13 +101,16 @@ mod tests {
100
101
  &dataset_db,
101
102
  &genedb,
102
103
  &ai_json,
104
+ &airoute,
105
+ testing,
103
106
  )
104
107
  .await;
105
- let mut llm_json_value: super::super::SummaryType = serde_json::from_str(&llm_output.unwrap()).expect("Did not get a valid JSON of type {action: summary, summaryterms:[{clinical: term1}, {geneExpression: gene}], filter:[{term: term1, value: value1}]} from the LLM");
106
- let mut expected_json_value: super::super::SummaryType = serde_json::from_str(&ques_ans.answer).expect("Did not get a valid JSON of type {action: summary, summaryterms:[{clinical: term1}, {geneExpression: gene}], filter:[{term: term1, value: value1}]} from the LLM");
108
+ let llm_json_value: super::super::SummaryType = serde_json::from_str(&llm_output.unwrap()).expect("Did not get a valid JSON of type {action: summary, summaryterms:[{clinical: term1}, {geneExpression: gene}], filter:[{term: term1, value: value1}]} from the LLM");
109
+ let sum: super::super::SummaryType = ques_ans.answer;
110
+ //println!("expected answer:{:?}", &sum);
107
111
  assert_eq!(
108
112
  llm_json_value.sort_summarytype_struct(),
109
- expected_json_value.sort_summarytype_struct()
113
+ sum.sort_summarytype_struct()
110
114
  );
111
115
  }
112
116
  }
@@ -142,13 +146,27 @@ mod tests {
142
146
  &dataset_db,
143
147
  &genedb,
144
148
  &ai_json,
149
+ &airoute,
150
+ testing,
145
151
  )
146
152
  .await;
147
- let mut llm_json_value: super::super::SummaryType = serde_json::from_str(&llm_output.unwrap()).expect("Did not get a valid JSON of type {action: summary, summaryterms:[{clinical: term1}, {geneExpression: gene}], filter:[{term: term1, value: value1}]} from the LLM");
148
- let mut expected_json_value: super::super::SummaryType = serde_json::from_str(&ques_ans.answer).expect("Did not get a valid JSON of type {action: summary, summaryterms:[{clinical: term1}, {geneExpression: gene}], filter:[{term: term1, value: value1}]} from the LLM");
153
+ //println!("user_input:{}", user_input);
154
+ //println!("llm_answer:{:?}", llm_output);
155
+ //println!("expected answer:{:?}", &ques_ans.answer);
156
+ let llm_json_value: super::super::SummaryType = serde_json::from_str(&llm_output.unwrap()).expect("Did not get a valid JSON of type {action: summary, summaryterms:[{clinical: term1}, {geneExpression: gene}], filter:[{term: term1, value: value1}]} from the LLM");
157
+ //println!(
158
+ // "llm_answer:{:?}",
159
+ // llm_json_value.clone().sort_summarytype_struct()
160
+ //);
161
+ //println!(
162
+ // "expected answer:{:?}",
163
+ // &expected_json_value.clone().sort_summarytype_struct()
164
+ //);
165
+ let sum: super::super::SummaryType = ques_ans.answer;
166
+ //println!("expected answer:{:?}", &sum);
149
167
  assert_eq!(
150
168
  llm_json_value.sort_summarytype_struct(),
151
- expected_json_value.sort_summarytype_struct()
169
+ sum.sort_summarytype_struct()
152
170
  );
153
171
  } else {
154
172
  panic!("The user input is empty");