@sjcrh/proteinpaint-rust 2.145.2 → 2.146.4-0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,730 @@
1
+ use anyhow::Result;
2
+ use json::JsonValue;
3
+ use r2d2_sqlite::SqliteConnectionManager;
4
+ use rig::agent::AgentBuilder;
5
+ use rig::client::CompletionClient;
6
+ use rig::client::EmbeddingsClient;
7
+ use rig::completion::Prompt;
8
+ use rig::embeddings::builder::EmbeddingsBuilder;
9
+ use std::collections::HashMap;
10
+ //use rig::providers::ollama;
11
+ use rig::vector_store::in_memory_store::InMemoryVectorStore;
12
+ use schemars::JsonSchema;
13
+ use serde_json::{Map, Value, json};
14
+ use std::io::{self};
15
+ mod sjprovider; // Importing custom rig module for invoking SJ GPU server
16
+
17
/// Which LLM serving backend the pipeline talks to. The variants mirror the
/// accepted `llm_backend_name` input values "ollama" and "SJ".
#[allow(non_camel_case_types)]
#[derive(Debug, Clone)]
enum llm_backend {
    // Ollama server; answers are plain JSON (optionally fenced in markdown).
    Ollama(),
    // SJ GPU server via the custom `sjprovider` module; answers arrive wrapped
    // in a `[{"generated_text": ...}]` envelope (see the response parsing code).
    Sj(),
}
23
+
24
/// Schema handed to Ollama via the `format` additional parameter to force the
/// classifier to answer with a single-field object: `{ "answer": "<category>" }`.
#[derive(Debug, JsonSchema)]
#[allow(dead_code)]
struct OutputJson {
    pub answer: String,
}
29
+
30
/// One comparison operator plus its numeric threshold, used in the JSON schema
/// for differential-expression group filters (e.g. age lesser than 20).
#[allow(non_camel_case_types)]
#[derive(Debug, JsonSchema)]
#[allow(dead_code)]
enum cutoff_info {
    lesser(f32),
    greater(f32),
    equalto(f32),
}
38
+
39
/// A filter threshold: the comparison itself plus an optional unit
/// (cm, lbs, celsius, ...) when the user supplied one.
#[derive(Debug, JsonSchema)]
#[allow(dead_code)]
struct Cutoff {
    cutoff_name: cutoff_info,
    units: Option<String>,
}
45
+
46
/// A named continuous-variable filter (e.g. "age") with its cutoff.
#[derive(Debug, JsonSchema)]
#[allow(dead_code)]
struct Filter {
    name: String,
    cutoff: Cutoff,
}
52
+
53
/// One comparison cohort for differential expression: a group name plus the
/// filter restricting its members.
#[derive(Debug, JsonSchema)]
#[allow(dead_code)]
struct Group {
    name: String,
    filter: Filter,
}
59
+
60
/// Schema for the differential-expression extraction answer: the two cohorts
/// being compared. Passed to Ollama via the `format` additional parameter.
#[derive(Debug, JsonSchema)]
#[allow(dead_code)]
struct DEOutput {
    group1: Group,
    group2: Group,
}
66
+
67
+ #[tokio::main]
68
+ async fn main() -> Result<()> {
69
+ let mut input = String::new();
70
+ match io::stdin().read_line(&mut input) {
71
+ // Accepting the piped input from nodejs (or command line from testing)
72
+ Ok(_n) => {
73
+ let input_json = json::parse(&input);
74
+ match input_json {
75
+ Ok(json_string) => {
76
+ let user_input_json: &JsonValue = &json_string["user_input"];
77
+ //let user_input = "Does aspirin leads to decrease in death rates among Africans?";
78
+ //let user_input = "Show the point deletion in TP53 gene.";
79
+ //let user_input = "Generate DE plot for men with weight greater than 30lbs vs women less than 20lbs";
80
+ let user_input: &str;
81
+ match user_input_json.as_str() {
82
+ Some(inp) => user_input = inp,
83
+ None => panic!("user_input field is missing in input json"),
84
+ }
85
+
86
+ let dataset_db_json: &JsonValue = &json_string["dataset_db"];
87
+ let mut dataset_db: Option<&str> = None;
88
+ match dataset_db_json.as_str() {
89
+ Some(inp) => dataset_db = Some(inp),
90
+ None => {}
91
+ }
92
+
93
+ let apilink_json: &JsonValue = &json_string["apilink"];
94
+ let apilink: &str;
95
+ match apilink_json.as_str() {
96
+ Some(inp) => apilink = inp,
97
+ None => panic!("apilink field is missing in input json"),
98
+ }
99
+
100
+ let comp_model_name_json: &JsonValue = &json_string["comp_model_name"];
101
+ let comp_model_name: &str;
102
+ match comp_model_name_json.as_str() {
103
+ Some(inp) => comp_model_name = inp,
104
+ None => panic!("comp_model_name field is missing in input json"),
105
+ }
106
+
107
+ let embedding_model_name_json: &JsonValue = &json_string["embedding_model_name"];
108
+ let embedding_model_name: &str;
109
+ match embedding_model_name_json.as_str() {
110
+ Some(inp) => embedding_model_name = inp,
111
+ None => panic!("embedding_model_name field is missing in input json"),
112
+ }
113
+
114
+ let llm_backend_name_json: &JsonValue = &json_string["llm_backend_name"];
115
+ let llm_backend_name: &str;
116
+ match llm_backend_name_json.as_str() {
117
+ Some(inp) => llm_backend_name = inp,
118
+ None => panic!("llm_backend_name field is missing in input json"),
119
+ }
120
+
121
+ let llm_backend_type: llm_backend;
122
+ let mut final_output: Option<String> = None;
123
+ let temperature: f64 = 0.01;
124
+ let max_new_tokens: usize = 512;
125
+ let top_p: f32 = 0.95;
126
+
127
+ if llm_backend_name != "ollama" && llm_backend_name != "SJ" {
128
+ panic!(
129
+ "This code currently supports only Ollama and SJ provider. llm_backend_name must be \"ollama\" or \"SJ\""
130
+ );
131
+ } else if llm_backend_name == "ollama".to_string() {
132
+ llm_backend_type = llm_backend::Ollama();
133
+ // Initialize Ollama client
134
+ let ollama_client = rig::providers::ollama::Client::builder()
135
+ .base_url(apilink)
136
+ .build()
137
+ .expect("Ollama server not found");
138
+ let embedding_model = ollama_client.embedding_model(embedding_model_name);
139
+ let comp_model = ollama_client.completion_model(comp_model_name);
140
+ final_output = run_pipeline(
141
+ user_input,
142
+ comp_model,
143
+ embedding_model,
144
+ llm_backend_type,
145
+ temperature,
146
+ max_new_tokens,
147
+ top_p,
148
+ dataset_db,
149
+ )
150
+ .await;
151
+ // "gpt-oss:20b" "granite3-dense:latest" "PetrosStav/gemma3-tools:12b" "llama3-groq-tool-use:latest" "PetrosStav/gemma3-tools:12b"
152
+ } else if llm_backend_name == "SJ".to_string() {
153
+ llm_backend_type = llm_backend::Sj();
154
+ // Initialize Sj provider client
155
+ let sj_client = sjprovider::Client::builder()
156
+ .base_url(apilink)
157
+ .build()
158
+ .expect("SJ server not found");
159
+ let embedding_model = sj_client.embedding_model(embedding_model_name);
160
+ let comp_model = sj_client.completion_model(comp_model_name);
161
+ final_output = run_pipeline(
162
+ user_input,
163
+ comp_model,
164
+ embedding_model,
165
+ llm_backend_type,
166
+ temperature,
167
+ max_new_tokens,
168
+ top_p,
169
+ dataset_db,
170
+ )
171
+ .await;
172
+ }
173
+
174
+ match final_output {
175
+ Some(fin_out) => {
176
+ println!("final_output:{:?}", fin_out);
177
+ }
178
+ None => {
179
+ println!("final_output:{{\"{}\":\"{}\"}}", "chartType", "unknown");
180
+ }
181
+ }
182
+ }
183
+ Err(error) => println!("Incorrect json:{}", error),
184
+ }
185
+ }
186
+ Err(error) => println!("Piping error: {}", error),
187
+ }
188
+ Ok(())
189
+ }
190
+
191
+ async fn run_pipeline(
192
+ user_input: &str,
193
+ comp_model: impl rig::completion::CompletionModel + 'static,
194
+ embedding_model: impl rig::embeddings::EmbeddingModel + 'static,
195
+ llm_backend_type: llm_backend,
196
+ temperature: f64,
197
+ max_new_tokens: usize,
198
+ top_p: f32,
199
+ dataset_db: Option<&str>,
200
+ ) -> Option<String> {
201
+ let mut classification: String = classify_query_by_dataset_type(
202
+ user_input,
203
+ comp_model.clone(),
204
+ embedding_model.clone(),
205
+ &llm_backend_type,
206
+ temperature,
207
+ max_new_tokens,
208
+ top_p,
209
+ )
210
+ .await;
211
+ classification = classification.replace("\"", "");
212
+ let final_output;
213
+ if classification == "dge".to_string() {
214
+ let de_result = extract_DE_search_terms_from_query(
215
+ user_input,
216
+ comp_model,
217
+ embedding_model,
218
+ &llm_backend_type,
219
+ temperature,
220
+ max_new_tokens,
221
+ top_p,
222
+ )
223
+ .await;
224
+ final_output = format!(
225
+ "{{\"{}\":\"{}\",\"{}\":[{}}}",
226
+ "chartType",
227
+ "dge",
228
+ "DE_output",
229
+ de_result + &"]"
230
+ );
231
+ } else if classification == "summary".to_string() {
232
+ final_output = extract_summary_information(
233
+ user_input,
234
+ comp_model,
235
+ embedding_model,
236
+ &llm_backend_type,
237
+ temperature,
238
+ max_new_tokens,
239
+ top_p,
240
+ dataset_db,
241
+ )
242
+ .await;
243
+ } else if classification == "hierarchial".to_string() {
244
+ // Not implemented yet
245
+ final_output = format!("{{\"{}\":\"{}\"}}", "chartType", "hierarchial");
246
+ } else if classification == "snv_indel".to_string() {
247
+ // Not implemented yet
248
+ final_output = format!("{{\"{}\":\"{}\"}}", "chartType", "snv_indel");
249
+ } else if classification == "cnv".to_string() {
250
+ // Not implemented yet
251
+ final_output = format!("{{\"{}\":\"{}\"}}", "chartType", "cnv");
252
+ } else if classification == "variant_calling".to_string() {
253
+ // Not implemented yet and will never be supported. Need a separate messages for this
254
+ final_output = format!("{{\"{}\":\"{}\"}}", "chartType", "variant_calling");
255
+ } else if classification == "surivial".to_string() {
256
+ // Not implemented yet
257
+ final_output = format!("{{\"{}\":\"{}\"}}", "chartType", "surivial");
258
+ } else if classification == "none".to_string() {
259
+ final_output = format!("{{\"{}\":\"{}\"}}", "chartType", "none");
260
+ println!("The input query did not match any known features in Proteinpaint");
261
+ } else {
262
+ final_output = format!(
263
+ "{{\"{}\":\"{}\"}}",
264
+ "chartType",
265
+ "unknown:".to_string() + &classification
266
+ );
267
+ }
268
+ Some(final_output)
269
+ }
270
+
271
/// Classifies `user_input` into one of the ProteinPaint analysis categories
/// ("summary", "dge", "hierarchial", "snv_indel", "cnv", "sv_fusion",
/// "variant_calling", "surivial" or "none") by prompting the completion model
/// with a RAG context built from the hard-coded category descriptions below.
///
/// Returns the category word. For the Ollama backend the surrounding quotes
/// are stripped here; for the SJ backend the raw `to_string()` of the JSON
/// value is returned (possibly still quoted — the caller strips quotes).
async fn classify_query_by_dataset_type(
    user_input: &str,
    comp_model: impl rig::completion::CompletionModel + 'static,
    embedding_model: impl rig::embeddings::EmbeddingModel + 'static,
    llm_backend_type: &llm_backend,
    temperature: f64,
    max_new_tokens: usize,
    top_p: f32,
) -> String {
    // Knowledge base for the RAG context: one description per category,
    // separated by "---".
    // NOTE(review): the survival section instructs the model to answer
    // 'survival' while the agent preamble below says 'surivial'; also the
    // summary section's final sentence repeats the variant-calling wording
    // ("asks about variant calling or mapping reads") — likely copy-paste
    // slips in the prompt text; confirm before editing the literals.
    let contents = String::from("SNV/SNP or point mutations nucleotide mutations are very common forms of mutations which can often give rise to genetic diseases such as cancer, Alzheimer's disease etc. They can be duw to substitution of nucleotide, or insertion or deletion of a nucleotide. Indels are multi-nucleotide insertion/deletion/substitutions. Complex indels are indels where insertion and deletion have happened in the same genomic locus. Every genomic sample from each patient has its own set of mutations therefore requiring personalized treatment.

If a ProteinPaint dataset contains SNV/Indel/SV data then return JSON with single key, 'snv_indel'.

---

Copy number variation (CNV) is a phenomenon in which sections of the genome are repeated and the number of repeats in the genome varies between individuals.[1] Copy number variation is a special type of structural variation: specifically, it is a type of duplication or deletion event that affects a considerable number of base pairs.

If a ProteinPaint dataset contains copy number variation data then return JSON with single key, 'cnv'.

---

Structural variants/fusions (SV) are genomic mutations when eith a DNA region is translocated or copied to an entirely different genomic locus. In case of transcriptomic data, when RNA is fused from two different genes its called a gene fusion.

If a ProteinPaint dataset contains structural variation or gene fusion data then return JSON with single key, 'sv_fusion'.
---

Hierarchial clustering of gene expression is an unsupervised learning technique where several number of relevant genes and the samples are clustered so as to determine (previously unknown) cohorts of samples (or patients) or structure in data. It is very commonly used to determine subtypes of a particular disease based on RNA sequencing data.

If a ProteinPaint dataset contains hierarchial data then return JSON with single key, 'hierarchial'.

---

Differential Gene Expression (DGE or DE) is a technique where the most upregulated and downregulated genes between two cohorts of samples (or patients) are determined. A volcano plot is shown with fold-change in the x-axis and adjusted p-value on the y-axis. So, the upregulated and downregulared genes are on opposite sides of the graph and the most significant genes (based on adjusted p-value) is on the top of the graph. Following differential gene expression generally GeneSet Enrichment Analysis (GSEA) is carried out where based on the genes and their corresponding fold changes the upregulation/downregulation of genesets (or pathways) is determined.

If a ProteinPaint dataset contains differential gene expression data then return JSON with single key, 'dge'.

---

Survival analysis (also called time-to-event analysis or duration analysis) is a branch of statistics aimed at analyzing the duration of time from a well-defined time origin until one or more events happen, called survival times or duration times. In other words, in survival analysis, we are interested in a certain event and want to analyze the time until the event happens.

There are two main methods of survival analysis:

1) Kaplan-Meier (HM) analysis is a univariate test that only takes into account a single categorical variable.
2) Cox proportional hazards model (coxph) is a multivariate test that can take into account multiple variables.

The hazard ratio (HR) is an indicator of the effect of the stimulus (e.g. drug dose, treatment) between two cohorts of patients.
HR = 1: No effect
HR < 1: Reduction in the hazard
HR > 1: Increase in Hazard

If a ProteinPaint dataset contains survival data then return JSON with single key, 'survival'.

---

Next generation sequencing reads (NGS) are mapped to a human genome using alignment algorithm such as burrows-wheelers alignment algorithm. Then these reads are called using variant calling algorithms such as GATK (Genome Analysis Toolkit). However this type of analysis is too compute intensive and beyond the scope of visualization software such as ProteinPaint.

If a user query asks about variant calling or mapping reads then JSON with single key, 'variant_calling'.

---

Summary plot in ProteinPaint shows the various facets of the datasets. It may show all the samples according to their respective diagnosis or subtypes of cancer. It is also useful for visualizing all the different facets of the dataset. You can display a categorical variable and overlay another variable on top it and stratify (or divide) using a third variable simultaneously. You can also custom filters to the dataset so that you can only study part of the dataset. If a user query asks about variant calling or mapping reads then JSON with single key, 'summary'.

Sample Query1: \"Show all fusions for patients with age less than 30\"
Sample Answer1: { \"answer\": \"summary\" }

Sample Query1: \"List all molecular subtypes of leukemia\"
Sample Answer1: { \"answer\": \"summary\" }

---

If a query does not match any of the fields described above, then return JSON with single key, 'none'
");

    // Split the contents by the delimiter "---" into one RAG document per category.
    let parts: Vec<&str> = contents.split("---").collect();
    // JSON schema used to constrain Ollama's output shape. // error handling here
    let schema_json: Value = serde_json::to_value(schemars::schema_for!(OutputJson)).unwrap();

    // Backend-specific extra request parameters: Ollama enforces the output
    // schema via "format"; SJ takes raw sampling parameters instead.
    let additional;
    match llm_backend_type {
        llm_backend::Ollama() => {
            additional = json!({
                "format": schema_json
            }
            );
        }
        llm_backend::Sj() => {
            additional = json!({
                "max_new_tokens": max_new_tokens,
                "top_p": top_p
            });
        }
    }

    // Trim each category description into a RAG document.
    let mut rag_docs = Vec::<String>::new();
    for (_i, part) in parts.iter().enumerate() {
        //println!("Part {}: {}", i + 1, part.trim());
        rag_docs.push(part.trim().to_string())
    }

    // Number of documents retrieved as dynamic context for each prompt.
    let top_k: usize = 3;
    // Create embeddings and add to vector store
    let embeddings = EmbeddingsBuilder::new(embedding_model.clone())
        .documents(rag_docs)
        .expect("Reason1")
        .build()
        .await
        .unwrap();

    // Create vector store
    let mut vector_store = InMemoryVectorStore::<String>::default();
    InMemoryVectorStore::add_documents(&mut vector_store, embeddings);

    // Create RAG agent
    let agent = AgentBuilder::new(comp_model).preamble("Generate classification for the user query into summary, dge, hierarchial, snv_indel, cnv, variant_calling, sv_fusion and none categories. Return output in JSON with ALWAYS a single word answer { \"answer\": \"dge\" }, that is 'summary' for summary plot, 'dge' for differential gene expression, 'hierarchial' for hierarchial clustering, 'snv_indel' for SNV/Indel, 'cnv' for CNV and 'sv_fusion' for SV/fusion, 'variant_calling' for variant calling, 'surivial' for survival data, 'none' for none of the previously described categories. The summary plot list and summarizes the cohort of patients according to the user query. The answer should always be in lower case").dynamic_context(top_k, vector_store.index(embedding_model)).temperature(temperature).additional_params(additional).build();

    let response = agent.prompt(user_input).await.expect("Failed to prompt ollama");

    // Strip markdown code-fence noise ("```json ... ```") before parsing.
    let result = response.replace("json", "").replace("```", "");
    let json_value: Value = serde_json::from_str(&result).expect("REASON");
    match llm_backend_type {
        // Ollama returns the schema-constrained object directly.
        llm_backend::Ollama() => json_value.as_object().unwrap()["answer"].to_string().replace("\"", ""),
        // SJ wraps the answer: [{"generated_text": "<json-encoded string>"}],
        // so two further decode passes are needed.
        // NOTE(review): assumes that envelope shape from sjprovider — confirm.
        llm_backend::Sj() => {
            let json_value2: Value =
                serde_json::from_str(&json_value[0]["generated_text"].to_string()).expect("REASON2");
            //println!("json_value2:{}", json_value2.as_str().unwrap());
            let json_value3: Value = serde_json::from_str(&json_value2.as_str().unwrap()).expect("REASON2");
            //let json_value3: Value = serde_json::from_str(&json_value2["answer"].to_string()).expect("REASON2");
            //println!("Classification result:{}", json_value3["answer"]);
            json_value3["answer"].to_string()
        }
    }
}
406
+
407
/// Extracts the two comparison cohorts (and optional continuous-variable
/// filters with cutoffs/units) for a differential gene expression request.
/// Returns the model's JSON answer as a string, e.g.
/// `{"group1":{"name":...},"group2":{"name":...}}`, or the
/// "No suitable two groups found..." object when no cohorts are detected.
#[allow(non_snake_case)]
async fn extract_DE_search_terms_from_query(
    user_input: &str,
    comp_model: impl rig::completion::CompletionModel + 'static,
    embedding_model: impl rig::embeddings::EmbeddingModel + 'static,
    llm_backend_type: &llm_backend,
    temperature: f64,
    max_new_tokens: usize,
    top_p: f32,
) -> String {
    // RAG context: a single document describing the DE task with worked
    // query/answer examples (no "---" separators appear in this text, so the
    // split below yields one document).
    let contents = String::from("Differential Gene Expression (DGE or DE) is a technique where the most upregulated and downregulated genes between two cohorts of samples (or patients) are determined. A volcano plot is shown with fold-change in the x-axis and adjusted p-value on the y-axis. So, the upregulated and downregulared genes are on opposite sides of the graph and the most significant genes (based on adjusted p-value) is on the top of the graph.

The user may select a cutoff for a continuous variables such as age. In such cases the group should only include the range specified by the user. Inside the JSON each entry the name of the group must be inside the field \"name\". For the cutoff (if provided) a field called \"cutoff\" must be provided which should contain a subfield \"name\" containing the name of the cutoff, followed by \"greater\"/\"lesser\"/\"equal\" to followed by the numeric value of the cutoff. If the unit of the variable is provided such as cm,m,inches,celsius etc. then add it to a separate field called \"units\".

Example input user queries:
When two groups are found give the following JSON output show {\"group1\": \"groupA\", \"group2\": \"groupB\"}
User query1: \"Show me the differential gene expression plot for groups groupA and groupB\"
Output JSON query1: {\"group1\": {\"name\": \"groupA\"}, \"group2\": {\"name\": \"groupB\"}}

User query2: \"Show volcano plot for White vs Black\"
Output JSON query2: {\"group1\": {\"name\": \"White\"}, \"group2\": {\"name\": \"Black\"}}

In case no suitable groups are found, show {\"output\":\"No suitable two groups found for differential gene expression\"}
User query3: \"Who wants to have vodka?\"
Output JSON query3: {\"output\":\"No suitable two groups found for differential gene expression\"}

User query4: \"Show volcano plot for Asians with age less than 20 and African greater than 80\"
Output JSON query4: {\"group1\": {\"name\": \"Asians\", \"filter\": {\"name\": \"age\", \"cutoff\": {\"lesser\": 20}}}, \"group2\": {\"name\": \"African\", \"filter\": {\"name\": \"age\", \"cutoff\": {\"greater\": 80}}}}

User query5: \"Show Differential gene expression plot for males with height greater than 185cm and women with less than 100cm\"
Output JSON query5: {\"group1\": {\"name\": \"males\", \"filter\": {\"name\": \"height\", \"cutoff\": {\"greater\": 185, \"units\":\"cm\"}}}, \"group2\": {\"name\": \"women\", \"filter\": {\"name\": \"height\", \"cutoff\": {\"lesser\": 100, \"units\": \"cm\"}}}}");

    // Split the contents by the delimiter "---"
    let parts: Vec<&str> = contents.split("---").collect();

    // JSON schema used to constrain Ollama's output shape. // error handling here
    let schema_json: Value = serde_json::to_value(schemars::schema_for!(DEOutput)).unwrap();

    //println!("DE schema:{}", schema_json);

    // Backend-specific extra request parameters: Ollama enforces the output
    // schema via "format"; SJ takes raw sampling parameters instead.
    let additional;
    match llm_backend_type {
        llm_backend::Ollama() => {
            additional = json!({
                "format": schema_json
            }
            );
        }
        llm_backend::Sj() => {
            additional = json!({
                "max_new_tokens": max_new_tokens,
                "top_p": top_p
            });
        }
    }

    // Trim each part into a RAG document.
    let mut rag_docs = Vec::<String>::new();
    for (_i, part) in parts.iter().enumerate() {
        //println!("Part {}: {}", i + 1, part.trim());
        rag_docs.push(part.trim().to_string())
    }

    // All documents are retrieved as dynamic context (top_k = document count).
    let rag_docs_length = rag_docs.len();
    // Create embeddings and add to vector store
    let embeddings = EmbeddingsBuilder::new(embedding_model.clone())
        .documents(rag_docs)
        .expect("Reason1")
        .build()
        .await
        .unwrap();

    // Create vector store
    let mut vector_store = InMemoryVectorStore::<String>::default();
    InMemoryVectorStore::add_documents(&mut vector_store, embeddings);

    // Create RAG agent
    let router_instructions = "Extract the group variable names for differential gene expression from input query. When two groups are found give the following JSON output with no extra comments. Show {{\"group1\": {\"name\": \"groupA\"}, \"group2\": {\"name\": \"groupB\"}}}. In case no suitable groups are found, show {\"output\":\"No suitable two groups found for differential gene expression\"}. In case of a continuous variable such as age, height added additional field to the group called \"filter\". This should contain a sub-field called \"names\" followed by a subfield called \"cutoff\". This sub-field should contain a key either greater, lesser or equalto. If the continuous variable has units provided by the user then add it in a separate field called \"units\". User query1: \"Show volcano plot for Asians with age less than 20 and African greater than 80\". Output JSON query1: {\"group1\": {\"name\": \"Asians\", \"filter\": {\"name\": \"age\", \"cutoff\": {\"lesser\": 20}}}, \"group2\": {\"name\": \"African\", \"filter\": {\"name\": \"age\", \"cutoff\": {\"greater\": 80}}}}. User query2: \"Show Differential gene expression plot for males with height greater than 185cm and women with less than 100cm\". Output JSON query2: {\"group1\": {\"name\": \"males\", \"filter\": {\"name\": \"height\", \"cutoff\": {\"greater\": 185, \"units\":\"cm\"}}}, \"group2\": {\"name\": \"women\", \"filter\": {\"name\": \"height\", \"cutoff\": {\"lesser\": 100, \"units\": \"cm\"}}}}. User query3: \"Show DE plot between healthy and diseased groups. Output JSON query3: {\"group1\":{\"name\":\"healthy\"},\"group2\":{\"name\":\"diseased\"}}";
    //println! {"router_instructions:{}",router_instructions};
    let agent = AgentBuilder::new(comp_model)
        .preamble(router_instructions)
        .dynamic_context(rag_docs_length, vector_store.index(embedding_model))
        .temperature(temperature)
        .additional_params(additional)
        .build();

    let response = agent.prompt(user_input).await.expect("Failed to prompt ollama");

    // Strip markdown code-fence noise ("```json ... ```") before parsing.
    let result = response.replace("json", "").replace("```", "");
    //println!("result_groups:{}", result);
    let json_value: Value = serde_json::from_str(&result).expect("REASON");
    //println!("json_value:{}", json_value);
    match llm_backend_type {
        // Ollama returns the schema-constrained object directly.
        llm_backend::Ollama() => json_value.to_string(),
        // SJ wraps the answer: [{"generated_text": "<json-encoded string>"}],
        // so two further decode passes are needed.
        // NOTE(review): assumes that envelope shape from sjprovider — confirm.
        llm_backend::Sj() => {
            let json_value2: Value =
                serde_json::from_str(&json_value[0]["generated_text"].to_string()).expect("REASON2");
            //println!("json_value2:{}", json_value2.as_str().unwrap());
            let json_value3: Value = serde_json::from_str(&json_value2.as_str().unwrap()).expect("REASON2");
            //println!("Classification result:{}", json_value3);
            json_value3.to_string()
        }
    }
}
511
+
512
/// One dataset term (a row of the `terms` table) joined with its description
/// from `termhtmldef`, ready to be rendered as a RAG document.
struct DbRows {
    name: String,
    description: Option<String>,
    term_type: Option<String>,
    values: Vec<String>,
}

/// Renders a database row as a prose sentence for the embeddings store.
trait ParseDbRows {
    fn parse_db_rows(&self) -> String;
}

impl ParseDbRows for DbRows {
    /// Builds the RAG document text: field name, then (when present) the term
    /// type, the description, and the list of categorical values.
    /// Output text is unchanged from the original concatenation-based version.
    fn parse_db_rows(&self) -> String {
        let mut output = format!("Name of field is \"{}\". ", self.name);

        if let Some(item_ty) = &self.term_type {
            output.push_str("This field is of the type ");
            output.push_str(item_ty);
            output.push_str(". ");
        }
        if let Some(desc) = &self.description {
            output.push_str(desc);
        }
        if !self.values.is_empty() {
            output.push_str("This contains the following values (separated by comma(,)):");
            output.push_str(&self.values.join(","));
            output.push('.');
        }
        output
    }
}
546
+
547
/// Reads the dataset's SQLite db and builds one RAG document per term that
/// has a description in `termhtmldef`, by joining `terms` rows with their
/// descriptions and rendering each via `DbRows::parse_db_rows`.
///
/// Panics (via `unwrap`/`expect`) on any db/JSON irregularity — acceptable
/// here because the process is a one-shot pipeline driven by `main`.
///
/// NOTE(review): assumes termhtmldef column 0 is the term name, column 1 a
/// JSON blob whose `description[0].value` is the text, and terms column 0 the
/// term id with column 3 a JSON blob with `values`/`type` — TODO confirm the
/// ProteinPaint db schema.
async fn parse_dataset_db(db: &str) -> Vec<String> {
    // Pooled connection; the pool is local so only one connection is used.
    let manager = SqliteConnectionManager::file(db);
    let pool = r2d2::Pool::new(manager).unwrap();
    let conn = pool.get().unwrap();

    // Pass 1: term name -> human-readable description.
    let sql_statement_termhtmldef = "SELECT * from termhtmldef";
    let mut termhtmldef = conn.prepare(&sql_statement_termhtmldef).unwrap();
    let mut rows_termhtmldef = termhtmldef.query([]).unwrap();
    let mut description_map = HashMap::new();
    while let Some(row) = rows_termhtmldef.next().unwrap() {
        //println!("row:{:?}", row);
        let name: String = row.get(0).unwrap();
        //println!("name:{}", name);
        let json_html_str: String = row.get(1).unwrap();
        let json_html: Value = serde_json::from_str(&json_html_str).expect("Not a JSON");
        let json_html2: &Map<String, Value> = json_html.as_object().unwrap();
        // Take the first description entry's "value" text.
        let description: String = String::from(
            json_html2.get("description").unwrap()[0]
                .as_object()
                .unwrap()
                .get("value")
                .unwrap()
                .as_str()
                .unwrap(),
        );
        //println!("description:{}", description);
        description_map.insert(name, description);
    }

    //// Open the file
    //let mut file = File::open(dataset_agnostic_file).unwrap();

    //// Create a string to hold the file contents
    //let mut contents = String::new();

    //// Read the file contents into the string
    //file.read_to_string(&mut contents).unwrap();

    //// Split the contents by the delimiter "---"
    //let parts: Vec<&str> = contents.split("\n").collect();

    //for (_i, part) in parts.iter().enumerate() {
    //    let sentence: &str = part.trim();
    //    let parts2: Vec<&str> = sentence.split(':').collect();
    //    //println!("parts2:{:?}", parts2);
    //    if parts2.len() == 2 {
    //        description_map.insert(parts2[0], parts2[1]);
    //        //println!("Part {}: {:?}", i + 1, parts2);
    //    }
    //}
    //println!("description_map:{:?}", description_map);

    // Pass 2: walk the terms table; only terms that have a description are
    // turned into RAG documents.
    let sql_statement_terms = "SELECT * from terms";
    let mut terms = conn.prepare(&sql_statement_terms).unwrap();
    let mut rows_terms = terms.query([]).unwrap();

    let mut rag_docs = Vec::<String>::new();
    // `names` is only used for the (commented-out) debug print below.
    let mut names = Vec::<String>::new();
    while let Some(row) = rows_terms.next().unwrap() {
        //println!("row:{:?}", row);
        let name: String = row.get(0).unwrap();
        //println!("id:{}", name);
        match description_map.get(&name as &str) {
            Some(desc) => {
                let line: String = row.get(3).unwrap();
                //println!("line:{}", line);
                let json_data: Value = serde_json::from_str(&line).expect("Not a JSON");
                // Collect the keys of the "values" object (categorical values),
                // when present.
                let values_json = json_data["values"].as_object();
                let mut keys = Vec::<String>::new();
                match values_json {
                    Some(values) => {
                        for (key, _value) in values {
                            keys.push(key.to_string())
                        }
                    }
                    None => {}
                }

                // Optional term type (e.g. "categorical", "float").
                let item_type_json = json_data["type"].as_str();
                let mut item_type: Option<String> = None;
                match item_type_json {
                    Some(item_ty) => item_type = Some(String::from(item_ty)),
                    None => {}
                }

                //println!("items:{:?}", keys);
                let item: DbRows = DbRows {
                    name: name.clone(),
                    description: Some(String::from(desc.clone())),
                    term_type: item_type,
                    values: keys,
                };
                //println!("Field details:{}", item.parse_db_rows());
                rag_docs.push(item.parse_db_rows());
                names.push(name)
            }
            None => {}
        }
    }
    //println!("names:{:?}", names);
    rag_docs
}
650
+
651
/// Extracts the summary-plot specification (group_categories plus optional
/// overlay / divide_by / filter) from the user query, grounding term names in
/// RAG documents built from the dataset's SQLite db.
///
/// Returns the model's JSON answer as a string in the
/// `{"chartType":"summary","term":{...}}` shape requested by the prompt.
///
/// Panics when `dataset_db` is `None` — `run_pipeline` always forwards the
/// value parsed from the input JSON, so this fires only on misconfiguration.
async fn extract_summary_information(
    user_input: &str,
    comp_model: impl rig::completion::CompletionModel + 'static,
    embedding_model: impl rig::embeddings::EmbeddingModel + 'static,
    llm_backend_type: &llm_backend,
    temperature: f64,
    max_new_tokens: usize,
    top_p: f32,
    dataset_db: Option<&str>,
) -> String {
    match dataset_db {
        Some(db) => {
            // One RAG document per described dataset term.
            let rag_docs = parse_dataset_db(db).await;
            //println!("rag_docs:{:?}", rag_docs);
            // Backend-specific extra request parameters. Unlike the other
            // agents, Ollama gets no "format" schema here — the output shape
            // is only enforced through the prompt text.
            let additional;
            match llm_backend_type {
                llm_backend::Ollama() => {
                    additional = json!({});
                }
                llm_backend::Sj() => {
                    additional = json!({
                        "max_new_tokens": max_new_tokens,
                        "top_p": top_p
                    });
                }
            }

            let rag_docs_length = rag_docs.len();
            // Create embeddings and add to vector store
            let embeddings = EmbeddingsBuilder::new(embedding_model.clone())
                .documents(rag_docs)
                .expect("Reason1")
                .build()
                .await
                .unwrap();

            // Create vector store
            let mut vector_store = InMemoryVectorStore::<String>::default();
            InMemoryVectorStore::add_documents(&mut vector_store, embeddings);

            //let system_prompt = "I am an assistant that figures out the summary term from its respective dataset file. Extract the summary term {summary_term} from user query. The final output must be in the following JSON format {{\"chartType\":\"summary\",\"term\":{{\"id\":\"{{summary_term}}\"}}}}";

            // Every dataset term document is retrieved as dynamic context.
            let top_k = rag_docs_length;
            let system_prompt = String::from(
                "I am an assistant that extracts the summary term from user query. It has four fields: group_categories (required), overlay (optional), filter (optional) and divide_by (optional). group_categories (required) is the primary variable being displayed. Overlay consists of the variable that must be overlayed on top of group_categories. divide_by is the variable used to stratify group_categories into two or more categories. The final output must be in the following JSON format with no extra comments: {\"chartType\":\"summary\",\"term\":{\"group_categories\":\"{group_category_answer}\",\"overlay\":\"{overlay_answer}\",\"divide_by\":\"{divide_by_answer}\",\"filter\":\"{filter_answer}\"}}. The values being added to the JSON parameters must be previously defined as field in the database. If the filter variable is a \"value\" of a \"field\" in the database, use the field name and add the value as a \"filter cutoff\" . If the \"filter\" field is defined in the user query, it should contain an array with each item containing a subfield called \"name\" with the name of the filter variable. If the type of variable is \"categories\", add another field as \"variable_type\" = \"categories\". In case the type of the variable is \"categories\", show the sub-category as a separate sub-field \"cutoff\" with a sub nested JSON with \"name\" as the field containing the subcategory name. In case the type of the variable is \"float\" it should contain a subfield called \"name\" followed by subfield \"variable_type\" = \"float\". In the \"cutoff\" subfield, the nested JSON should contain the field \"lower\" containing the lower numeric limit and the \"upper\" field containing the upper numeric limit. If the upper and lower cutoffs are not defined in the user query, use a default value of 0 and 100 respectively. Sample query1: \"Show ETR1 subtype\" Answer query1: \"{\"chartType\":\"summary\",\"term\":{\"group_categories\":\"ETR1\"}}. Sample query2: \"Show hyperdiploid subtype with age overlayed on top of it\" Answer query2: \"{\"chartType\":\"summary\",\"term\":{\"group_categories\":\"hyperdiploid\", \"overlay\":\"age\"}}. Sample query3: \"Show BAR1 subtype with age overlayed on top of it and stratify it on the basis of gender\" Answer query4: \"{\"chartType\":\"summary\",\"term\":{\"group_categories\":\"BAR1\", \"overlay\":\"age\", \"divide_by\":\"sex\"}}. Sample query5: \"Show summary for cancer-diagnosis only for men\". Since gender is a categorical variable and the user wants to select for men, the answer query for sample query5 is as follows: \"{\"chartType\":\"summary\",\"term\":{\"group_categories\":\"cancer-diagnosis\", \"filter\": {\"name\": \"gender\", \"variable_type\": \"categories\", \"cutoff\": {\"name\": \"male\"}}}}. Sample query6: \"Show molecular subtype summary for patients with age less than 30\". Age is a float variable so we need to provide the lower and higher cutoffs. So the answer to sample query6 is as follows: \"{\"chartType\":\"summary\",\"term\":{\"group_categories\":\"Molecular subtype\", \"filter\": {\"name\": \"age\", \"variable_type\": \"float\", \"cutoff\": {\"lower\": 0, \"higher\": 30}}}} ",
            );
            //println!("system_prompt:{}", system_prompt);
            // Create RAG agent
            let agent = AgentBuilder::new(comp_model)
                .preamble(&system_prompt)
                .dynamic_context(top_k, vector_store.index(embedding_model))
                .temperature(temperature)
                .additional_params(additional)
                .build();

            let response = agent.prompt(user_input).await.expect("Failed to prompt ollama");

            // Strip markdown code-fence noise ("```json ... ```") before parsing.
            let result = response.replace("json", "").replace("```", "");
            //println!("result:{}", result);
            let json_value: Value = serde_json::from_str(&result).expect("REASON");
            //println!("Classification result:{}", json_value);

            match llm_backend_type {
                // Ollama returns the answer object directly.
                llm_backend::Ollama() => json_value.to_string(),
                // SJ wraps the answer: [{"generated_text": "<json-encoded string>"}],
                // so two further decode passes are needed.
                // NOTE(review): assumes that envelope shape from sjprovider — confirm.
                llm_backend::Sj() => {
                    let json_value2: Value =
                        serde_json::from_str(&json_value[0]["generated_text"].to_string()).expect("REASON2");
                    //println!("json_value2:{}", json_value2.as_str().unwrap());
                    let json_value3: Value = serde_json::from_str(&json_value2.as_str().unwrap()).expect("REASON2");
                    //println!("Classification result:{}", json_value3);
                    json_value3.to_string()
                }
            }
        }
        None => {
            panic!("Dataset db file needed for summary term extraction from user input")
        }
    }
}