@sjcrh/proteinpaint-rust 2.166.0 → 2.169.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/package.json +1 -1
- package/src/aichatbot.rs +355 -142
- package/src/manhattan_plot.rs +8 -8
- package/src/test_ai.rs +26 -8
package/package.json
CHANGED
package/src/aichatbot.rs
CHANGED
|
@@ -30,21 +30,51 @@ pub struct AiJsonFormat {
|
|
|
30
30
|
#[derive(PartialEq, Debug, Clone, schemars::JsonSchema, serde::Serialize, serde::Deserialize)]
|
|
31
31
|
enum Charts {
|
|
32
32
|
// More chart types will be added here later
|
|
33
|
-
Summary(
|
|
34
|
-
DE(
|
|
33
|
+
Summary(TrainTestDataSummary),
|
|
34
|
+
DE(TrainTestDataDE),
|
|
35
35
|
}
|
|
36
36
|
|
|
37
37
|
#[derive(PartialEq, Debug, Clone, schemars::JsonSchema, serde::Serialize, serde::Deserialize)]
|
|
38
|
-
struct
|
|
38
|
+
struct TrainTestDataSummary {
|
|
39
39
|
SystemPrompt: String,
|
|
40
|
-
TrainingData: Vec<
|
|
41
|
-
TestData: Vec<
|
|
40
|
+
TrainingData: Vec<QuestionAnswerSummary>,
|
|
41
|
+
TestData: Vec<QuestionAnswerSummary>,
|
|
42
42
|
}
|
|
43
43
|
|
|
44
44
|
#[derive(PartialEq, Debug, Clone, schemars::JsonSchema, serde::Serialize, serde::Deserialize)]
|
|
45
|
-
struct
|
|
45
|
+
struct QuestionAnswerSummary {
|
|
46
46
|
question: String,
|
|
47
|
-
answer:
|
|
47
|
+
answer: SummaryType,
|
|
48
|
+
}
|
|
49
|
+
|
|
50
|
+
#[derive(PartialEq, Debug, Clone, schemars::JsonSchema, serde::Serialize, serde::Deserialize)]
|
|
51
|
+
struct TrainTestDataDE {
|
|
52
|
+
SystemPrompt: String,
|
|
53
|
+
TrainingData: Vec<QuestionAnswerDE>,
|
|
54
|
+
TestData: Vec<QuestionAnswerDE>,
|
|
55
|
+
}
|
|
56
|
+
|
|
57
|
+
#[derive(PartialEq, Debug, Clone, schemars::JsonSchema, serde::Serialize, serde::Deserialize)]
|
|
58
|
+
struct QuestionAnswerDE {
|
|
59
|
+
question: String,
|
|
60
|
+
answer: DEType,
|
|
61
|
+
}
|
|
62
|
+
|
|
63
|
+
#[derive(PartialEq, Debug, Clone, schemars::JsonSchema, serde::Serialize, serde::Deserialize)]
|
|
64
|
+
struct DEType {
|
|
65
|
+
action: String,
|
|
66
|
+
DE_output: DETerms,
|
|
67
|
+
}
|
|
68
|
+
|
|
69
|
+
#[derive(PartialEq, Debug, Clone, schemars::JsonSchema, serde::Serialize, serde::Deserialize)]
|
|
70
|
+
struct DETerms {
|
|
71
|
+
group1: GroupType,
|
|
72
|
+
group2: GroupType,
|
|
73
|
+
}
|
|
74
|
+
|
|
75
|
+
#[derive(PartialEq, Debug, Clone, schemars::JsonSchema, serde::Serialize, serde::Deserialize)]
|
|
76
|
+
struct GroupType {
|
|
77
|
+
name: String,
|
|
48
78
|
}
|
|
49
79
|
|
|
50
80
|
#[allow(non_camel_case_types)]
|
|
@@ -77,6 +107,27 @@ async fn main() -> Result<()> {
|
|
|
77
107
|
None => panic!("user_input field is missing in input json"),
|
|
78
108
|
}
|
|
79
109
|
|
|
110
|
+
let dataset_db_json: &JsonValue = &json_string["dataset_db"];
|
|
111
|
+
let dataset_db_str: &str;
|
|
112
|
+
match dataset_db_json.as_str() {
|
|
113
|
+
Some(inp) => dataset_db_str = inp,
|
|
114
|
+
None => panic!("dataset_db field is missing in input json"),
|
|
115
|
+
}
|
|
116
|
+
|
|
117
|
+
let genedb_json: &JsonValue = &json_string["genedb"];
|
|
118
|
+
let genedb_str: &str;
|
|
119
|
+
match genedb_json.as_str() {
|
|
120
|
+
Some(inp) => genedb_str = inp,
|
|
121
|
+
None => panic!("genedb field is missing in input json"),
|
|
122
|
+
}
|
|
123
|
+
|
|
124
|
+
let aiRoute_json: &JsonValue = &json_string["aiRoute"];
|
|
125
|
+
let aiRoute_str: &str;
|
|
126
|
+
match aiRoute_json.as_str() {
|
|
127
|
+
Some(inp) => aiRoute_str = inp,
|
|
128
|
+
None => panic!("aiRoute field is missing in input json"),
|
|
129
|
+
}
|
|
130
|
+
|
|
80
131
|
if user_input.len() == 0 {
|
|
81
132
|
panic!("The user input is empty");
|
|
82
133
|
}
|
|
@@ -124,8 +175,9 @@ async fn main() -> Result<()> {
|
|
|
124
175
|
let ai_json: AiJsonFormat =
|
|
125
176
|
serde_json::from_str(&ai_data).expect("AI JSON file does not have the correct format");
|
|
126
177
|
|
|
127
|
-
let genedb = String::from(tpmasterdir) + &"/" + &
|
|
128
|
-
let dataset_db = String::from(tpmasterdir) + &"/" + &
|
|
178
|
+
let genedb = String::from(tpmasterdir) + &"/" + &genedb_str;
|
|
179
|
+
let dataset_db = String::from(tpmasterdir) + &"/" + &dataset_db_str;
|
|
180
|
+
let airoute = String::from(binpath) + &"/../../" + &aiRoute_str;
|
|
129
181
|
|
|
130
182
|
let apilink_json: &JsonValue = &json_string["apilink"];
|
|
131
183
|
let apilink: &str;
|
|
@@ -160,7 +212,7 @@ async fn main() -> Result<()> {
|
|
|
160
212
|
let temperature: f64 = 0.01;
|
|
161
213
|
let max_new_tokens: usize = 512;
|
|
162
214
|
let top_p: f32 = 0.95;
|
|
163
|
-
|
|
215
|
+
let testing = false; // This variable is always false in production, this is true in test_ai.rs for testing code
|
|
164
216
|
if llm_backend_name != "ollama" && llm_backend_name != "SJ" {
|
|
165
217
|
panic!(
|
|
166
218
|
"This code currently supports only Ollama and SJ provider. llm_backend_name must be \"ollama\" or \"SJ\""
|
|
@@ -185,6 +237,8 @@ async fn main() -> Result<()> {
|
|
|
185
237
|
&dataset_db,
|
|
186
238
|
&genedb,
|
|
187
239
|
&ai_json,
|
|
240
|
+
&airoute,
|
|
241
|
+
testing,
|
|
188
242
|
)
|
|
189
243
|
.await;
|
|
190
244
|
} else if llm_backend_name == "SJ".to_string() {
|
|
@@ -207,6 +261,8 @@ async fn main() -> Result<()> {
|
|
|
207
261
|
&dataset_db,
|
|
208
262
|
&genedb,
|
|
209
263
|
&ai_json,
|
|
264
|
+
&airoute,
|
|
265
|
+
testing,
|
|
210
266
|
)
|
|
211
267
|
.await;
|
|
212
268
|
}
|
|
@@ -239,6 +295,8 @@ pub async fn run_pipeline(
|
|
|
239
295
|
dataset_db: &str,
|
|
240
296
|
genedb: &str,
|
|
241
297
|
ai_json: &AiJsonFormat,
|
|
298
|
+
ai_route: &str,
|
|
299
|
+
testing: bool,
|
|
242
300
|
) -> Option<String> {
|
|
243
301
|
let mut classification: String = classify_query_by_dataset_type(
|
|
244
302
|
user_input,
|
|
@@ -248,6 +306,7 @@ pub async fn run_pipeline(
|
|
|
248
306
|
temperature,
|
|
249
307
|
max_new_tokens,
|
|
250
308
|
top_p,
|
|
309
|
+
ai_route,
|
|
251
310
|
)
|
|
252
311
|
.await;
|
|
253
312
|
classification = classification.replace("\"", "");
|
|
@@ -263,13 +322,20 @@ pub async fn run_pipeline(
|
|
|
263
322
|
top_p,
|
|
264
323
|
)
|
|
265
324
|
.await;
|
|
266
|
-
|
|
267
|
-
|
|
268
|
-
|
|
269
|
-
|
|
270
|
-
|
|
271
|
-
|
|
272
|
-
|
|
325
|
+
if testing == true {
|
|
326
|
+
final_output = format!(
|
|
327
|
+
"{{\"{}\":\"{}\",\"{}\":[{}}}",
|
|
328
|
+
"action",
|
|
329
|
+
"dge",
|
|
330
|
+
"DE_output",
|
|
331
|
+
de_result + &"]"
|
|
332
|
+
);
|
|
333
|
+
} else {
|
|
334
|
+
final_output = format!(
|
|
335
|
+
"{{\"{}\":\"{}\",\"{}\":\"{}\"}}",
|
|
336
|
+
"type", "html", "html", "DE agent not implemented yet"
|
|
337
|
+
);
|
|
338
|
+
}
|
|
273
339
|
} else if classification == "summary".to_string() {
|
|
274
340
|
final_output = extract_summary_information(
|
|
275
341
|
user_input,
|
|
@@ -282,30 +348,83 @@ pub async fn run_pipeline(
|
|
|
282
348
|
dataset_db,
|
|
283
349
|
genedb,
|
|
284
350
|
ai_json,
|
|
351
|
+
testing,
|
|
285
352
|
)
|
|
286
353
|
.await;
|
|
287
354
|
} else if classification == "hierarchical".to_string() {
|
|
288
355
|
// Not implemented yet
|
|
289
|
-
|
|
356
|
+
if testing == true {
|
|
357
|
+
final_output = format!("{{\"{}\":\"{}\"}}", "action", "hierarchical");
|
|
358
|
+
} else {
|
|
359
|
+
final_output = format!(
|
|
360
|
+
"{{\"{}\":\"{}\",\"{}\":\"{}\"}}",
|
|
361
|
+
"type", "html", "html", "hierarchical clustering agent not implemented yet"
|
|
362
|
+
);
|
|
363
|
+
}
|
|
290
364
|
} else if classification == "snv_indel".to_string() {
|
|
291
365
|
// Not implemented yet
|
|
292
|
-
|
|
366
|
+
if testing == true {
|
|
367
|
+
final_output = format!("{{\"{}\":\"{}\"}}", "action", "snv_indel");
|
|
368
|
+
} else {
|
|
369
|
+
final_output = format!(
|
|
370
|
+
"{{\"{}\":\"{}\",\"{}\":\"{}\"}}",
|
|
371
|
+
"type", "html", "html", "snv_indel agent not implemented yet"
|
|
372
|
+
);
|
|
373
|
+
}
|
|
293
374
|
} else if classification == "cnv".to_string() {
|
|
294
375
|
// Not implemented yet
|
|
295
|
-
|
|
376
|
+
if testing == true {
|
|
377
|
+
final_output = format!("{{\"{}\":\"{}\"}}", "action", "cnv");
|
|
378
|
+
} else {
|
|
379
|
+
final_output = format!(
|
|
380
|
+
"{{\"{}\":\"{}\",\"{}\":\"{}\"}}",
|
|
381
|
+
"type", "html", "html", "cnv agent not implemented yet"
|
|
382
|
+
);
|
|
383
|
+
}
|
|
296
384
|
} else if classification == "variant_calling".to_string() {
|
|
297
385
|
// Not implemented yet and will never be supported. Need a separate messages for this
|
|
298
|
-
|
|
386
|
+
if testing == true {
|
|
387
|
+
final_output = format!("{{\"{}\":\"{}\"}}", "action", "variant_calling");
|
|
388
|
+
} else {
|
|
389
|
+
final_output = format!(
|
|
390
|
+
"{{\"{}\":\"{}\",\"{}\":\"{}\"}}",
|
|
391
|
+
"type", "html", "html", "variant_calling agent not implemented yet"
|
|
392
|
+
);
|
|
393
|
+
}
|
|
299
394
|
} else if classification == "survival".to_string() {
|
|
300
395
|
// Not implemented yet
|
|
301
|
-
|
|
396
|
+
if testing == true {
|
|
397
|
+
final_output = format!("{{\"{}\":\"{}\"}}", "action", "surivial");
|
|
398
|
+
} else {
|
|
399
|
+
final_output = format!(
|
|
400
|
+
"{{\"{}\":\"{}\",\"{}\":\"{}\"}}",
|
|
401
|
+
"type", "html", "html", "survival agent not implemented yet"
|
|
402
|
+
);
|
|
403
|
+
}
|
|
302
404
|
} else if classification == "none".to_string() {
|
|
303
|
-
|
|
304
|
-
|
|
305
|
-
|
|
306
|
-
|
|
405
|
+
if testing == true {
|
|
406
|
+
final_output = format!(
|
|
407
|
+
"{{\"{}\":\"{}\",\"{}\":\"{}\"}}",
|
|
408
|
+
"action", "none", "message", "The input query did not match any known features in Proteinpaint"
|
|
409
|
+
);
|
|
410
|
+
} else {
|
|
411
|
+
final_output = format!(
|
|
412
|
+
"{{\"{}\":\"{}\",\"{}\":\"{}\"}}",
|
|
413
|
+
"type", "html", "html", "The input query did not match any known features in Proteinpaint"
|
|
414
|
+
);
|
|
415
|
+
}
|
|
307
416
|
} else {
|
|
308
|
-
|
|
417
|
+
if testing == true {
|
|
418
|
+
final_output = format!("{{\"{}\":\"{}\"}}", "action", "unknown:".to_string() + &classification);
|
|
419
|
+
} else {
|
|
420
|
+
final_output = format!(
|
|
421
|
+
"{{\"{}\":\"{}\",\"{}\":\"{}\"}}",
|
|
422
|
+
"type",
|
|
423
|
+
"html",
|
|
424
|
+
"html",
|
|
425
|
+
"unknown:".to_string() + &classification
|
|
426
|
+
);
|
|
427
|
+
}
|
|
309
428
|
}
|
|
310
429
|
Some(final_output)
|
|
311
430
|
}
|
|
@@ -313,101 +432,33 @@ pub async fn run_pipeline(
|
|
|
313
432
|
async fn classify_query_by_dataset_type(
|
|
314
433
|
user_input: &str,
|
|
315
434
|
comp_model: impl rig::completion::CompletionModel + 'static,
|
|
316
|
-
|
|
435
|
+
_embedding_model: impl rig::embeddings::EmbeddingModel + 'static,
|
|
317
436
|
llm_backend_type: &llm_backend,
|
|
318
437
|
temperature: f64,
|
|
319
438
|
max_new_tokens: usize,
|
|
320
439
|
top_p: f32,
|
|
440
|
+
ai_route: &str,
|
|
321
441
|
) -> String {
|
|
322
|
-
//
|
|
323
|
-
let
|
|
324
|
-
|
|
325
|
-
If a ProteinPaint dataset contains SNV/Indel/SV data then return JSON with single key, 'snv_indel'.
|
|
326
|
-
|
|
327
|
-
---
|
|
328
|
-
|
|
329
|
-
Copy number variation (CNV) is a phenomenon in which sections of the genome are repeated and the number of repeats in the genome varies between individuals.[1] Copy number variation is a special type of structural variation: specifically, it is a type of duplication or deletion event that affects a considerable number of base pairs.
|
|
330
|
-
|
|
331
|
-
If a ProteinPaint dataset contains copy number variation data then return JSON with single key, 'cnv'.
|
|
332
|
-
|
|
333
|
-
---
|
|
334
|
-
|
|
335
|
-
Structural variants/fusions (SV) are genomic mutations when eith a DNA region is translocated or copied to an entirely different genomic locus. In case of transcriptomic data, when RNA is fused from two different genes its called a gene fusion.
|
|
336
|
-
|
|
337
|
-
If a ProteinPaint dataset contains structural variation or gene fusion data then return JSON with single key, 'sv_fusion'.
|
|
338
|
-
---
|
|
339
|
-
|
|
340
|
-
Hierarchical clustering of gene expression is an unsupervised learning technique where several number of relevant genes and the samples are clustered so as to determine (previously unknown) cohorts of samples (or patients) or structure in data. It is very commonly used to determine subtypes of a particular disease based on RNA sequencing data.
|
|
341
|
-
|
|
342
|
-
If a ProteinPaint dataset contains hierarchical data then return JSON with single key, 'hierarchical'.
|
|
343
|
-
|
|
344
|
-
---
|
|
345
|
-
|
|
346
|
-
Differential Gene Expression (DGE or DE) is a technique where the most upregulated (or highest) and downregulated (or lowest) genes between two cohorts of samples (or patients) are determined from a pool of THOUSANDS of genes. Differential gene expression CANNOT be computed for a SINGLE gene. A volcano plot is shown with fold-change in the x-axis and adjusted p-value on the y-axis. So, the upregulated and downregulared genes are on opposite sides of the graph and the most significant genes (based on adjusted p-value) is on the top of the graph. Following differential gene expression generally GeneSet Enrichment Analysis (GSEA) is carried out where based on the genes and their corresponding fold changes the upregulation/downregulation of genesets (or pathways) is determined.
|
|
347
|
-
|
|
348
|
-
Sample Query1: \"Which gene has the highest expression between the two genders\"
|
|
349
|
-
Sample Answer1: { \"answer\": \"dge\" }
|
|
350
|
-
|
|
351
|
-
Sample Query2: \"Which gene has the lowest expression between the two races\"
|
|
352
|
-
Sample Answer2: { \"answer\": \"dge\" }
|
|
353
|
-
|
|
354
|
-
Sample Query1: \"Which genes are the most upregulated genes between group A and group B\"
|
|
355
|
-
Sample Answer1: { \"answer\": \"dge\" }
|
|
356
|
-
|
|
357
|
-
Sample Query3: \"Which gene are overexpressed between male and female\"
|
|
358
|
-
Sample Answer3: { \"answer\": \"dge\" }
|
|
359
|
-
|
|
360
|
-
Sample Query4: \"Which gene are housekeeping genes between male and female\"
|
|
361
|
-
Sample Answer4: { \"answer\": \"dge\" }
|
|
362
|
-
|
|
363
|
-
|
|
364
|
-
If a ProteinPaint dataset contains differential gene expression data then return JSON with single key, 'dge'.
|
|
365
|
-
|
|
366
|
-
---
|
|
442
|
+
// Read the file
|
|
443
|
+
let ai_route_data = fs::read_to_string(ai_route).unwrap();
|
|
367
444
|
|
|
368
|
-
|
|
445
|
+
// Parse the JSON data
|
|
446
|
+
let ai_json: Value = serde_json::from_str(&ai_route_data).expect("AI JSON file does not have the correct format");
|
|
369
447
|
|
|
370
|
-
|
|
371
|
-
|
|
372
|
-
1) Kaplan-Meier (HM) analysis is a univariate test that only takes into account a single categorical variable.
|
|
373
|
-
2) Cox proportional hazards model (coxph) is a multivariate test that can take into account multiple variables.
|
|
374
|
-
|
|
375
|
-
The hazard ratio (HR) is an indicator of the effect of the stimulus (e.g. drug dose, treatment) between two cohorts of patients.
|
|
376
|
-
HR = 1: No effect
|
|
377
|
-
HR < 1: Reduction in the hazard
|
|
378
|
-
HR > 1: Increase in Hazard
|
|
379
|
-
|
|
380
|
-
Sample Query1: \"Compare survival rates between group A and B\"
|
|
381
|
-
Sample Answer1: { \"answer\": \"survival\" }
|
|
382
|
-
|
|
383
|
-
|
|
384
|
-
If a ProteinPaint dataset contains survival data then return JSON with single key, 'survival'.
|
|
385
|
-
|
|
386
|
-
---
|
|
387
|
-
|
|
388
|
-
Next generation sequencing reads (NGS) are mapped to a human genome using alignment algorithm such as burrows-wheelers alignment algorithm. Then these reads are called using variant calling algorithms such as GATK (Genome Analysis Toolkit). However this type of analysis is too compute intensive and beyond the scope of visualization software such as ProteinPaint.
|
|
389
|
-
|
|
390
|
-
If a user query asks about variant calling or mapping reads then JSON with single key, 'variant_calling'.
|
|
391
|
-
|
|
392
|
-
---
|
|
393
|
-
|
|
394
|
-
Summary plot in ProteinPaint shows the various facets of the datasets. Show expression of a SINGLE gene or compare the expression of a SINGLE gene across two different cohorts defined by the user. It may show all the samples according to their respective diagnosis or subtypes of cancer. It is also useful for comparing and correlating different clinical variables. It can show all possible distributions, frequency of a category, overlay, correlate or cross-tabulate with another variable on top of it. If a user query asks about a SINGLE gene expression or correlating clinical variables then return JSON with single key, 'summary'.
|
|
395
|
-
|
|
396
|
-
Sample Query1: \"Show all fusions for patients with age less than 30\"
|
|
397
|
-
Sample Answer1: { \"answer\": \"summary\" }
|
|
398
|
-
|
|
399
|
-
Sample Query2: \"List all molecular subtypes of leukemia\"
|
|
400
|
-
Sample Answer2: { \"answer\": \"summary\" }
|
|
401
|
-
|
|
402
|
-
Sample Query3: \"is tp53 expression higher in men than women ?\"
|
|
403
|
-
Sample Answer3: { \"answer\": \"summary\" }
|
|
404
|
-
|
|
405
|
-
Sample Query4: \"Compare ATM expression between races for women greater than 80yrs\"
|
|
406
|
-
Sample Answer4: { \"answer\": \"summary\" }
|
|
448
|
+
// Create a string to hold the file contents
|
|
449
|
+
let mut contents = String::from("");
|
|
407
450
|
|
|
451
|
+
if let Some(object) = ai_json.as_object() {
|
|
452
|
+
for (_key, value) in object {
|
|
453
|
+
contents += &value.as_str().unwrap();
|
|
454
|
+
contents += "---"; // Adding delimiter
|
|
455
|
+
}
|
|
456
|
+
}
|
|
408
457
|
|
|
409
|
-
|
|
410
|
-
|
|
458
|
+
// Removing the last "---" characters
|
|
459
|
+
contents.pop();
|
|
460
|
+
contents.pop();
|
|
461
|
+
contents.pop();
|
|
411
462
|
|
|
412
463
|
// Split the contents by the delimiter "---"
|
|
413
464
|
let parts: Vec<&str> = contents.split("---").collect();
|
|
@@ -438,18 +489,18 @@ If a query does not match any of the fields described above, then return JSON wi
|
|
|
438
489
|
rag_docs.push(part.trim().to_string())
|
|
439
490
|
}
|
|
440
491
|
|
|
441
|
-
//let top_k: usize = 3;
|
|
492
|
+
//let top_k: usize = 3; // Embedding model not used currently
|
|
442
493
|
// Create embeddings and add to vector store
|
|
443
|
-
let embeddings = EmbeddingsBuilder::new(embedding_model.clone())
|
|
444
|
-
|
|
445
|
-
|
|
446
|
-
|
|
447
|
-
|
|
448
|
-
|
|
494
|
+
//let embeddings = EmbeddingsBuilder::new(embedding_model.clone())
|
|
495
|
+
// .documents(rag_docs)
|
|
496
|
+
// .expect("Reason1")
|
|
497
|
+
// .build()
|
|
498
|
+
// .await
|
|
499
|
+
// .unwrap();
|
|
449
500
|
|
|
450
|
-
|
|
451
|
-
let mut vector_store = InMemoryVectorStore::<String>::default();
|
|
452
|
-
InMemoryVectorStore::add_documents(&mut vector_store, embeddings);
|
|
501
|
+
//// Create vector store
|
|
502
|
+
//let mut vector_store = InMemoryVectorStore::<String>::default();
|
|
503
|
+
//InMemoryVectorStore::add_documents(&mut vector_store, embeddings);
|
|
453
504
|
|
|
454
505
|
// Create RAG agent
|
|
455
506
|
let agent = AgentBuilder::new(comp_model).preamble(&(String::from("Generate classification for the user query into summary, dge, hierarchical, snv_indel, cnv, variant_calling, sv_fusion and none categories. Return output in JSON with ALWAYS a single word answer { \"answer\": \"dge\" }, that is 'summary' for summary plot, 'dge' for differential gene expression, 'hierarchical' for hierarchical clustering, 'snv_indel' for SNV/Indel, 'cnv' for CNV and 'sv_fusion' for SV/fusion, 'variant_calling' for variant calling, 'surivial' for survival data, 'none' for none of the previously described categories. The summary plot list and summarizes the cohort of patients according to the user query. The answer should always be in lower case\n The options are as follows:\n") + &contents + "\nQuestion= {question} \nanswer")).temperature(temperature).additional_params(additional).build();
|
|
@@ -801,6 +852,7 @@ async fn extract_summary_information(
|
|
|
801
852
|
dataset_db: &str,
|
|
802
853
|
genedb: &str,
|
|
803
854
|
ai_json: &AiJsonFormat,
|
|
855
|
+
testing: bool,
|
|
804
856
|
) -> String {
|
|
805
857
|
let (rag_docs, db_vec) = parse_dataset_db(dataset_db).await;
|
|
806
858
|
let additional;
|
|
@@ -845,7 +897,7 @@ async fn extract_summary_information(
|
|
|
845
897
|
.filter(|x| user_words2.contains(&x.to_lowercase()))
|
|
846
898
|
.collect();
|
|
847
899
|
|
|
848
|
-
let mut summary_data_check: Option<
|
|
900
|
+
let mut summary_data_check: Option<TrainTestDataSummary> = None;
|
|
849
901
|
for chart in ai_json.charts.clone() {
|
|
850
902
|
if let Charts::Summary(traindata) = chart {
|
|
851
903
|
summary_data_check = Some(traindata);
|
|
@@ -858,6 +910,7 @@ async fn extract_summary_information(
|
|
|
858
910
|
let mut training_data: String = String::from("");
|
|
859
911
|
let mut train_iter = 0;
|
|
860
912
|
for ques_ans in summary_data.TrainingData {
|
|
913
|
+
let summary_answer: SummaryType = ques_ans.answer;
|
|
861
914
|
train_iter += 1;
|
|
862
915
|
training_data += "Example question";
|
|
863
916
|
training_data += &train_iter.to_string();
|
|
@@ -867,7 +920,7 @@ async fn extract_summary_information(
|
|
|
867
920
|
training_data += "Example answer";
|
|
868
921
|
training_data += &train_iter.to_string();
|
|
869
922
|
training_data += &":";
|
|
870
|
-
training_data += &
|
|
923
|
+
training_data += &serde_json::to_string(&summary_answer).unwrap();
|
|
871
924
|
training_data += &"\n";
|
|
872
925
|
}
|
|
873
926
|
|
|
@@ -919,7 +972,8 @@ async fn extract_summary_information(
|
|
|
919
972
|
}
|
|
920
973
|
}
|
|
921
974
|
//println!("final_llm_json:{}", final_llm_json);
|
|
922
|
-
let final_validated_json =
|
|
975
|
+
let final_validated_json =
|
|
976
|
+
validate_summary_output(final_llm_json.clone(), db_vec, common_genes, ai_json, testing);
|
|
923
977
|
final_validated_json
|
|
924
978
|
}
|
|
925
979
|
None => {
|
|
@@ -949,7 +1003,7 @@ struct SummaryType {
|
|
|
949
1003
|
|
|
950
1004
|
impl SummaryType {
|
|
951
1005
|
#[allow(dead_code)]
|
|
952
|
-
pub fn sort_summarytype_struct(
|
|
1006
|
+
pub fn sort_summarytype_struct(mut self) -> SummaryType {
|
|
953
1007
|
// This function is necessary for testing (test_ai.rs) to see if two variables of type "SummaryType" are equal or not. Without this a vector of two Summarytype holding the same values but in different order will be classified separately.
|
|
954
1008
|
self.summaryterms.sort();
|
|
955
1009
|
|
|
@@ -957,6 +1011,7 @@ impl SummaryType {
|
|
|
957
1011
|
Some(ref mut filterterms) => filterterms.sort(),
|
|
958
1012
|
None => {}
|
|
959
1013
|
}
|
|
1014
|
+
self.clone()
|
|
960
1015
|
}
|
|
961
1016
|
}
|
|
962
1017
|
|
|
@@ -974,7 +1029,7 @@ impl PartialOrd for SummaryTerms {
|
|
|
974
1029
|
(SummaryTerms::clinical(_), SummaryTerms::clinical(_)) => Some(std::cmp::Ordering::Equal),
|
|
975
1030
|
(SummaryTerms::geneExpression(_), SummaryTerms::geneExpression(_)) => Some(std::cmp::Ordering::Equal),
|
|
976
1031
|
(SummaryTerms::clinical(_), SummaryTerms::geneExpression(_)) => Some(std::cmp::Ordering::Greater),
|
|
977
|
-
(SummaryTerms::geneExpression(_), SummaryTerms::clinical(_)) => Some(std::cmp::Ordering::
|
|
1032
|
+
(SummaryTerms::geneExpression(_), SummaryTerms::clinical(_)) => Some(std::cmp::Ordering::Less),
|
|
978
1033
|
}
|
|
979
1034
|
}
|
|
980
1035
|
}
|
|
@@ -1063,6 +1118,7 @@ fn validate_summary_output(
|
|
|
1063
1118
|
db_vec: Vec<DbRows>,
|
|
1064
1119
|
common_genes: Vec<String>,
|
|
1065
1120
|
ai_json: &AiJsonFormat,
|
|
1121
|
+
testing: bool,
|
|
1066
1122
|
) -> String {
|
|
1067
1123
|
let json_value: SummaryType =
|
|
1068
1124
|
serde_json::from_str(&raw_llm_json).expect("Did not get a valid JSON of type {action: summary, summaryterms:[{clinical: term1}, {geneExpression: gene}], filter:[{term: term1, value: value1}]} from the LLM");
|
|
@@ -1094,7 +1150,7 @@ fn validate_summary_output(
|
|
|
1094
1150
|
match term_verification.correct_field {
|
|
1095
1151
|
Some(tm) => validated_summary_terms.push(SummaryTerms::clinical(tm)),
|
|
1096
1152
|
None => {
|
|
1097
|
-
message = message + &"
|
|
1153
|
+
message = message + &"'" + &clin + &"'" + &" not found in db.";
|
|
1098
1154
|
}
|
|
1099
1155
|
}
|
|
1100
1156
|
} else if Some(term_verification.correct_field.clone()).is_some()
|
|
@@ -1122,7 +1178,7 @@ fn validate_summary_output(
|
|
|
1122
1178
|
if num_gene_verification == 0 || common_genes.len() == 0 {
|
|
1123
1179
|
if message.to_lowercase().contains(&gene.to_lowercase()) { // Check if the LLM has already added the message, if not then add it
|
|
1124
1180
|
} else {
|
|
1125
|
-
message = message + &"
|
|
1181
|
+
message = message + &"'" + &gene + &"'" + &" not found in genedb.";
|
|
1126
1182
|
}
|
|
1127
1183
|
}
|
|
1128
1184
|
}
|
|
@@ -1138,6 +1194,8 @@ fn validate_summary_output(
|
|
|
1138
1194
|
}
|
|
1139
1195
|
}
|
|
1140
1196
|
|
|
1197
|
+
let mut pp_plot_json: Value; // The PP compliant plot JSON
|
|
1198
|
+
pp_plot_json = serde_json::from_str(&"{\"chartType\":\"summary\"}").expect("Not a valid JSON");
|
|
1141
1199
|
match &json_value.filter {
|
|
1142
1200
|
Some(filter_terms_array) => {
|
|
1143
1201
|
let mut validated_filter_terms = Vec::<FilterTerm>::new();
|
|
@@ -1168,21 +1226,21 @@ fn validate_summary_output(
|
|
|
1168
1226
|
validated_filter_terms.push(categorical_filter_term);
|
|
1169
1227
|
}
|
|
1170
1228
|
if term_verification.correct_field.is_none() {
|
|
1171
|
-
message = message + &"
|
|
1229
|
+
message = message + &"'" + &categorical.term + &"' filter term not found in db";
|
|
1172
1230
|
}
|
|
1173
1231
|
if value_verification.is_none() {
|
|
1174
1232
|
message = message
|
|
1175
|
-
+ &"
|
|
1233
|
+
+ &"'"
|
|
1176
1234
|
+ &categorical.value
|
|
1177
|
-
+ &"
|
|
1235
|
+
+ &"' filter value not found for filter field '"
|
|
1178
1236
|
+ &categorical.term
|
|
1179
|
-
+ "
|
|
1237
|
+
+ "' in db";
|
|
1180
1238
|
}
|
|
1181
1239
|
}
|
|
1182
1240
|
FilterTerm::Numeric(numeric) => {
|
|
1183
1241
|
let term_verification = verify_json_field(&numeric.term, &db_vec);
|
|
1184
1242
|
if term_verification.correct_field.is_none() {
|
|
1185
|
-
message = message + &"
|
|
1243
|
+
message = message + &"'" + &numeric.term + &"' filter term not found in db";
|
|
1186
1244
|
} else {
|
|
1187
1245
|
let numeric_filter_term: FilterTerm = FilterTerm::Numeric(numeric.clone());
|
|
1188
1246
|
validated_filter_terms.push(numeric_filter_term);
|
|
@@ -1229,8 +1287,68 @@ fn validate_summary_output(
|
|
|
1229
1287
|
}
|
|
1230
1288
|
|
|
1231
1289
|
if validated_filter_terms.len() > 0 {
|
|
1232
|
-
if
|
|
1233
|
-
obj.
|
|
1290
|
+
if testing == true {
|
|
1291
|
+
if let Some(obj) = new_json.as_object_mut() {
|
|
1292
|
+
obj.insert(String::from("filter"), serde_json::json!(validated_filter_terms));
|
|
1293
|
+
}
|
|
1294
|
+
} else {
|
|
1295
|
+
let mut validated_filter_terms_PP: String = "[".to_string();
|
|
1296
|
+
let mut filter_hits = 0;
|
|
1297
|
+
for validated_term in validated_filter_terms {
|
|
1298
|
+
match validated_term {
|
|
1299
|
+
FilterTerm::Categorical(categorical_filter) => {
|
|
1300
|
+
let string_json = "{\"term\":\"".to_string()
|
|
1301
|
+
+ &categorical_filter.term
|
|
1302
|
+
+ &"\", \"category\":\""
|
|
1303
|
+
+ &categorical_filter.value
|
|
1304
|
+
+ &"\"},";
|
|
1305
|
+
validated_filter_terms_PP += &string_json;
|
|
1306
|
+
}
|
|
1307
|
+
FilterTerm::Numeric(numeric_filter) => {
|
|
1308
|
+
let string_json;
|
|
1309
|
+
if numeric_filter.greaterThan.is_some() && numeric_filter.lessThan.is_none() {
|
|
1310
|
+
string_json = "{\"term\":\"".to_string()
|
|
1311
|
+
+ &numeric_filter.term
|
|
1312
|
+
+ &"\", \"gt\":\""
|
|
1313
|
+
+ &numeric_filter.greaterThan.unwrap().to_string()
|
|
1314
|
+
+ &"\"},";
|
|
1315
|
+
} else if numeric_filter.greaterThan.is_none() && numeric_filter.lessThan.is_some() {
|
|
1316
|
+
string_json = "{\"term\":\"".to_string()
|
|
1317
|
+
+ &numeric_filter.term
|
|
1318
|
+
+ &"\", \"lt\":\""
|
|
1319
|
+
+ &numeric_filter.lessThan.unwrap().to_string()
|
|
1320
|
+
+ &"\"},";
|
|
1321
|
+
} else if numeric_filter.greaterThan.is_some() && numeric_filter.lessThan.is_some() {
|
|
1322
|
+
string_json = "{\"term\":\"".to_string()
|
|
1323
|
+
+ &numeric_filter.term
|
|
1324
|
+
+ &"\", \"lt\":\""
|
|
1325
|
+
+ &numeric_filter.lessThan.unwrap().to_string()
|
|
1326
|
+
+ &"\", \"gt\":\""
|
|
1327
|
+
+ &numeric_filter.greaterThan.unwrap().to_string()
|
|
1328
|
+
+ &"\"},";
|
|
1329
|
+
} else {
|
|
1330
|
+
// When both greater and less than are none
|
|
1331
|
+
panic!(
|
|
1332
|
+
"Numeric filter term {} is missing both greater than and less than values. One of them must be defined",
|
|
1333
|
+
&numeric_filter.term
|
|
1334
|
+
);
|
|
1335
|
+
}
|
|
1336
|
+
validated_filter_terms_PP += &string_json;
|
|
1337
|
+
}
|
|
1338
|
+
};
|
|
1339
|
+
filter_hits += 1;
|
|
1340
|
+
}
|
|
1341
|
+
println!("validated_filter_terms_PP:{}", validated_filter_terms_PP);
|
|
1342
|
+
if filter_hits > 0 {
|
|
1343
|
+
validated_filter_terms_PP.pop();
|
|
1344
|
+
validated_filter_terms_PP += &"]";
|
|
1345
|
+
if let Some(obj) = pp_plot_json.as_object_mut() {
|
|
1346
|
+
obj.insert(
|
|
1347
|
+
String::from("simpleFilter"),
|
|
1348
|
+
serde_json::from_str(&validated_filter_terms_PP).expect("Not a valid JSON"),
|
|
1349
|
+
);
|
|
1350
|
+
}
|
|
1351
|
+
}
|
|
1234
1352
|
}
|
|
1235
1353
|
}
|
|
1236
1354
|
}
|
|
@@ -1240,6 +1358,10 @@ fn validate_summary_output(
|
|
|
1240
1358
|
// Removing terms that are found both in filter term as well summary
|
|
1241
1359
|
let mut validated_summary_terms_final = Vec::<SummaryTerms>::new();
|
|
1242
1360
|
|
|
1361
|
+
let mut sum_iter = 0;
|
|
1362
|
+
let mut pp_json: Value; // New JSON value that will contain items of the final PP compliant JSON
|
|
1363
|
+
pp_json = serde_json::from_str(&"{\"type\":\"plot\"}").expect("Not a valid JSON");
|
|
1364
|
+
|
|
1243
1365
|
for summary_term in &validated_summary_terms {
|
|
1244
1366
|
let mut hit = 0;
|
|
1245
1367
|
match summary_term {
|
|
@@ -1276,9 +1398,53 @@ fn validate_summary_output(
|
|
|
1276
1398
|
}
|
|
1277
1399
|
}
|
|
1278
1400
|
}
|
|
1401
|
+
|
|
1279
1402
|
if hit == 0 {
|
|
1403
|
+
let mut termidpp: Option<TermIDPP> = None;
|
|
1404
|
+
let mut geneexp: Option<GeneExpressionPP> = None;
|
|
1405
|
+
match summary_term {
|
|
1406
|
+
SummaryTerms::clinical(clinical_term) => {
|
|
1407
|
+
termidpp = Some(TermIDPP {
|
|
1408
|
+
id: clinical_term.to_string(),
|
|
1409
|
+
});
|
|
1410
|
+
}
|
|
1411
|
+
SummaryTerms::geneExpression(gene) => {
|
|
1412
|
+
geneexp = Some(GeneExpressionPP {
|
|
1413
|
+
gene: gene.to_string(),
|
|
1414
|
+
r#type: "geneExpression".to_string(),
|
|
1415
|
+
});
|
|
1416
|
+
}
|
|
1417
|
+
}
|
|
1418
|
+
if sum_iter == 0 {
|
|
1419
|
+
if termidpp.is_some() {
|
|
1420
|
+
if let Some(obj) = pp_plot_json.as_object_mut() {
|
|
1421
|
+
obj.insert(String::from("term"), serde_json::json!(Some(termidpp)));
|
|
1422
|
+
}
|
|
1423
|
+
}
|
|
1424
|
+
|
|
1425
|
+
if geneexp.is_some() {
|
|
1426
|
+
let gene_term = GeneTerm { term: geneexp.unwrap() };
|
|
1427
|
+
if let Some(obj) = pp_plot_json.as_object_mut() {
|
|
1428
|
+
obj.insert(String::from("term"), serde_json::json!(gene_term));
|
|
1429
|
+
}
|
|
1430
|
+
}
|
|
1431
|
+
} else if sum_iter == 1 {
|
|
1432
|
+
if termidpp.is_some() {
|
|
1433
|
+
if let Some(obj) = pp_plot_json.as_object_mut() {
|
|
1434
|
+
obj.insert(String::from("term2"), serde_json::json!(Some(termidpp)));
|
|
1435
|
+
}
|
|
1436
|
+
}
|
|
1437
|
+
|
|
1438
|
+
if geneexp.is_some() {
|
|
1439
|
+
let gene_term = GeneTerm { term: geneexp.unwrap() };
|
|
1440
|
+
if let Some(obj) = pp_plot_json.as_object_mut() {
|
|
1441
|
+
obj.insert(String::from("term2"), serde_json::json!(gene_term));
|
|
1442
|
+
}
|
|
1443
|
+
}
|
|
1444
|
+
}
|
|
1280
1445
|
validated_summary_terms_final.push(summary_term.clone())
|
|
1281
1446
|
}
|
|
1447
|
+
sum_iter += 1
|
|
1282
1448
|
}
|
|
1283
1449
|
|
|
1284
1450
|
if let Some(obj) = new_json.as_object_mut() {
|
|
@@ -1288,14 +1454,61 @@ fn validate_summary_output(
|
|
|
1288
1454
|
);
|
|
1289
1455
|
}
|
|
1290
1456
|
|
|
1457
|
+
if let Some(obj) = pp_json.as_object_mut() {
|
|
1458
|
+
// The `if let` ensures we only proceed if the top-level JSON is an object.
|
|
1459
|
+
// Append a new string field.
|
|
1460
|
+
obj.insert(String::from("plot"), serde_json::json!(pp_plot_json));
|
|
1461
|
+
}
|
|
1462
|
+
|
|
1463
|
+
let mut err_json: Value; // Error JSON containing the error message (if present)
|
|
1291
1464
|
if message.len() > 0 {
|
|
1292
|
-
if
|
|
1293
|
-
|
|
1294
|
-
|
|
1295
|
-
|
|
1465
|
+
if testing == false {
|
|
1466
|
+
err_json = serde_json::from_str(&"{\"type\":\"html\"}").expect("Not a valid JSON");
|
|
1467
|
+
if let Some(obj) = err_json.as_object_mut() {
|
|
1468
|
+
// The `if let` ensures we only proceed if the top-level JSON is an object.
|
|
1469
|
+
// Append a new string field.
|
|
1470
|
+
obj.insert(String::from("html"), serde_json::json!(message));
|
|
1471
|
+
};
|
|
1472
|
+
serde_json::to_string(&err_json).unwrap()
|
|
1473
|
+
} else {
|
|
1474
|
+
if let Some(obj) = new_json.as_object_mut() {
|
|
1475
|
+
// The `if let` ensures we only proceed if the top-level JSON is an object.
|
|
1476
|
+
// Append a new string field.
|
|
1477
|
+
obj.insert(String::from("message"), serde_json::json!(message));
|
|
1478
|
+
};
|
|
1479
|
+
serde_json::to_string(&new_json).unwrap()
|
|
1480
|
+
}
|
|
1481
|
+
} else {
|
|
1482
|
+
if testing == true {
|
|
1483
|
+
// When testing script output native LLM JSON
|
|
1484
|
+
serde_json::to_string(&new_json).unwrap()
|
|
1485
|
+
} else {
|
|
1486
|
+
// When in production output PP compliant JSON
|
|
1487
|
+
serde_json::to_string(&pp_json).unwrap()
|
|
1296
1488
|
}
|
|
1297
1489
|
}
|
|
1298
|
-
|
|
1490
|
+
}
|
|
1491
|
+
|
|
1492
|
+
fn getGeneExpression() -> String {
|
|
1493
|
+
"geneExpression".to_string()
|
|
1494
|
+
}
|
|
1495
|
+
|
|
1496
|
+
#[derive(PartialEq, Debug, Clone, schemars::JsonSchema, serde::Serialize, serde::Deserialize)]
|
|
1497
|
+
struct TermIDPP {
|
|
1498
|
+
id: String,
|
|
1499
|
+
}
|
|
1500
|
+
|
|
1501
|
+
#[derive(PartialEq, Debug, Clone, schemars::JsonSchema, serde::Serialize, serde::Deserialize)]
|
|
1502
|
+
struct GeneTerm {
|
|
1503
|
+
term: GeneExpressionPP,
|
|
1504
|
+
}
|
|
1505
|
+
|
|
1506
|
+
#[derive(PartialEq, Debug, Clone, schemars::JsonSchema, serde::Serialize, serde::Deserialize)]
|
|
1507
|
+
struct GeneExpressionPP {
|
|
1508
|
+
gene: String,
|
|
1509
|
+
// Serde uses this for deserialization.
|
|
1510
|
+
#[serde(default = "getGeneExpression")]
|
|
1511
|
+
r#type: String,
|
|
1299
1512
|
}
|
|
1300
1513
|
|
|
1301
1514
|
#[derive(Debug, Clone)]
|
package/src/manhattan_plot.rs
CHANGED
|
@@ -57,6 +57,7 @@ struct InteractiveData {
|
|
|
57
57
|
x_buffer: i64,
|
|
58
58
|
y_min: f64,
|
|
59
59
|
y_max: f64,
|
|
60
|
+
device_pixel_ratio: f64,
|
|
60
61
|
}
|
|
61
62
|
|
|
62
63
|
#[derive(Serialize)]
|
|
@@ -216,8 +217,10 @@ fn grin2_file_read(
|
|
|
216
217
|
Some(q) => q,
|
|
217
218
|
None => continue,
|
|
218
219
|
};
|
|
220
|
+
|
|
219
221
|
let q_val: f64 = match q_val_str.parse() {
|
|
220
222
|
Ok(v) if v > 0.0 => v,
|
|
223
|
+
Ok(v) if v == 0.0 => 1e-300, // Treat exact 0 as ~1e-300 so we can still show q-values that are 0 and not filter them out
|
|
221
224
|
_ => continue,
|
|
222
225
|
};
|
|
223
226
|
let neg_log10_q = -q_val.log10();
|
|
@@ -335,12 +338,8 @@ fn plot_grin2_manhattan(
|
|
|
335
338
|
let png_width = plot_width + 2 * png_dot_radius;
|
|
336
339
|
let png_height = plot_height + 2 * png_dot_radius;
|
|
337
340
|
|
|
338
|
-
let w: u32 = (png_width *
|
|
339
|
-
|
|
340
|
-
.expect("PNG width too large for u32");
|
|
341
|
-
let h: u32 = (png_height * device_pixel_ratio as u64)
|
|
342
|
-
.try_into()
|
|
343
|
-
.expect("PNG height too large for u32");
|
|
341
|
+
let w: u32 = ((png_width as f64) * dpr) as u32;
|
|
342
|
+
let h: u32 = ((png_height as f64) * dpr) as u32;
|
|
344
343
|
|
|
345
344
|
// Create RGB buffer for Plotters
|
|
346
345
|
let mut buffer = vec![0u8; w as usize * h as usize * 3];
|
|
@@ -402,8 +401,8 @@ fn plot_grin2_manhattan(
|
|
|
402
401
|
|
|
403
402
|
for (i, p) in point_details.iter_mut().enumerate() {
|
|
404
403
|
let (px, py) = pixel_positions[*&sig_indices[i]];
|
|
405
|
-
p.pixel_x = px;
|
|
406
|
-
p.pixel_y = py;
|
|
404
|
+
p.pixel_x = px / dpr;
|
|
405
|
+
p.pixel_y = py / dpr;
|
|
407
406
|
}
|
|
408
407
|
|
|
409
408
|
// flush root drawing area
|
|
@@ -469,6 +468,7 @@ fn plot_grin2_manhattan(
|
|
|
469
468
|
x_buffer,
|
|
470
469
|
y_min,
|
|
471
470
|
y_max,
|
|
471
|
+
device_pixel_ratio: dpr,
|
|
472
472
|
};
|
|
473
473
|
Ok((png_data, interactive_data))
|
|
474
474
|
}
|
package/src/test_ai.rs
CHANGED
|
@@ -20,6 +20,7 @@ mod tests {
|
|
|
20
20
|
ollama_comp_model_name: String,
|
|
21
21
|
ollama_embedding_model_name: String,
|
|
22
22
|
genomes: Vec<Genomes>,
|
|
23
|
+
aiRoute: String,
|
|
23
24
|
}
|
|
24
25
|
|
|
25
26
|
#[derive(PartialEq, Debug, Clone, schemars::JsonSchema, serde::Serialize, serde::Deserialize)]
|
|
@@ -42,13 +43,14 @@ mod tests {
|
|
|
42
43
|
let top_p: f32 = 0.95;
|
|
43
44
|
let serverconfig_file_path = Path::new("../../serverconfig.json");
|
|
44
45
|
let absolute_path = serverconfig_file_path.canonicalize().unwrap();
|
|
46
|
+
let testing = true; // This causes the JSON being output from run_pipeline() to be in LLM JSON format
|
|
45
47
|
|
|
46
48
|
// Read the file
|
|
47
49
|
let data = fs::read_to_string(absolute_path).unwrap();
|
|
48
50
|
|
|
49
51
|
// Parse the JSON data
|
|
50
52
|
let serverconfig: ServerConfig = serde_json::from_str(&data).expect("JSON not in serverconfig.json format");
|
|
51
|
-
|
|
53
|
+
let airoute = String::from("../../") + &serverconfig.aiRoute;
|
|
52
54
|
for genome in &serverconfig.genomes {
|
|
53
55
|
for dataset in &genome.datasets {
|
|
54
56
|
match &dataset.aifiles {
|
|
@@ -83,7 +85,6 @@ mod tests {
|
|
|
83
85
|
.expect("Ollama server not found");
|
|
84
86
|
let embedding_model = ollama_client.embedding_model(ollama_embedding_model_name);
|
|
85
87
|
let comp_model = ollama_client.completion_model(ollama_comp_model_name);
|
|
86
|
-
|
|
87
88
|
for chart in ai_json.charts.clone() {
|
|
88
89
|
match chart {
|
|
89
90
|
super::super::Charts::Summary(testdata) => {
|
|
@@ -100,13 +101,16 @@ mod tests {
|
|
|
100
101
|
&dataset_db,
|
|
101
102
|
&genedb,
|
|
102
103
|
&ai_json,
|
|
104
|
+
&airoute,
|
|
105
|
+
testing,
|
|
103
106
|
)
|
|
104
107
|
.await;
|
|
105
|
-
let
|
|
106
|
-
let
|
|
108
|
+
let llm_json_value: super::super::SummaryType = serde_json::from_str(&llm_output.unwrap()).expect("Did not get a valid JSON of type {action: summary, summaryterms:[{clinical: term1}, {geneExpression: gene}], filter:[{term: term1, value: value1}]} from the LLM");
|
|
109
|
+
let sum: super::super::SummaryType = ques_ans.answer;
|
|
110
|
+
//println!("expected answer:{:?}", &sum);
|
|
107
111
|
assert_eq!(
|
|
108
112
|
llm_json_value.sort_summarytype_struct(),
|
|
109
|
-
|
|
113
|
+
sum.sort_summarytype_struct()
|
|
110
114
|
);
|
|
111
115
|
}
|
|
112
116
|
}
|
|
@@ -142,13 +146,27 @@ mod tests {
|
|
|
142
146
|
&dataset_db,
|
|
143
147
|
&genedb,
|
|
144
148
|
&ai_json,
|
|
149
|
+
&airoute,
|
|
150
|
+
testing,
|
|
145
151
|
)
|
|
146
152
|
.await;
|
|
147
|
-
|
|
148
|
-
|
|
153
|
+
//println!("user_input:{}", user_input);
|
|
154
|
+
//println!("llm_answer:{:?}", llm_output);
|
|
155
|
+
//println!("expected answer:{:?}", &ques_ans.answer);
|
|
156
|
+
let llm_json_value: super::super::SummaryType = serde_json::from_str(&llm_output.unwrap()).expect("Did not get a valid JSON of type {action: summary, summaryterms:[{clinical: term1}, {geneExpression: gene}], filter:[{term: term1, value: value1}]} from the LLM");
|
|
157
|
+
//println!(
|
|
158
|
+
// "llm_answer:{:?}",
|
|
159
|
+
// llm_json_value.clone().sort_summarytype_struct()
|
|
160
|
+
//);
|
|
161
|
+
//println!(
|
|
162
|
+
// "expected answer:{:?}",
|
|
163
|
+
// &expected_json_value.clone().sort_summarytype_struct()
|
|
164
|
+
//);
|
|
165
|
+
let sum: super::super::SummaryType = ques_ans.answer;
|
|
166
|
+
//println!("expected answer:{:?}", &sum);
|
|
149
167
|
assert_eq!(
|
|
150
168
|
llm_json_value.sort_summarytype_struct(),
|
|
151
|
-
|
|
169
|
+
sum.sort_summarytype_struct()
|
|
152
170
|
);
|
|
153
171
|
} else {
|
|
154
172
|
panic!("The user input is empty");
|