@sjcrh/proteinpaint-rust 2.188.0 → 2.190.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/Cargo.toml +0 -9
- package/package.json +1 -1
- package/src/manhattan_plot.rs +38 -76
- package/src/volcano.rs +89 -49
- package/src/aichatbot.rs +0 -1554
- package/src/query_classification.rs +0 -152
- package/src/summary_agent.rs +0 -201
- package/src/test_ai.rs +0 -193
package/src/aichatbot.rs
DELETED
|
@@ -1,1554 +0,0 @@
|
|
|
1
|
-
// Syntax: cd .. && cargo build --release && time cat ~/sjpp/test.txt | target/release/aichatbot
|
|
2
|
-
#![allow(non_snake_case)]
|
|
3
|
-
//use anyhow::Result;
|
|
4
|
-
//use json::JsonValue;
|
|
5
|
-
use r2d2_sqlite::SqliteConnectionManager;
|
|
6
|
-
use rig::agent::AgentBuilder;
|
|
7
|
-
use rig::completion::Prompt;
|
|
8
|
-
use rig::embeddings::builder::EmbeddingsBuilder;
|
|
9
|
-
use rig::vector_store::in_memory_store::InMemoryVectorStore;
|
|
10
|
-
use schemars::JsonSchema;
|
|
11
|
-
use serde_json::{Map, Value, json};
|
|
12
|
-
use std::collections::HashMap;
|
|
13
|
-
use std::fs;
|
|
14
|
-
|
|
15
|
-
// Struct for intaking data from dataset json
|
|
16
|
-
#[derive(PartialEq, Debug, Clone, schemars::JsonSchema, serde::Serialize, serde::Deserialize)]
|
|
17
|
-
pub struct AiJsonFormat {
|
|
18
|
-
pub hasGeneExpression: bool,
|
|
19
|
-
pub db: String, // Dataset db
|
|
20
|
-
pub genedb: String, // Gene db
|
|
21
|
-
pub charts: Vec<TrainTestData>,
|
|
22
|
-
}
|
|
23
|
-
|
|
24
|
-
#[derive(PartialEq, Debug, Clone, schemars::JsonSchema, serde::Serialize, serde::Deserialize)]
|
|
25
|
-
pub struct TrainTestData {
|
|
26
|
-
pub r#type: String,
|
|
27
|
-
pub SystemPrompt: String,
|
|
28
|
-
pub TrainingData: Vec<QuestionAnswer>,
|
|
29
|
-
pub TestData: Vec<QuestionAnswer>,
|
|
30
|
-
}
|
|
31
|
-
|
|
32
|
-
#[derive(PartialEq, Debug, Clone, schemars::JsonSchema, serde::Serialize, serde::Deserialize)]
|
|
33
|
-
pub struct QuestionAnswer {
|
|
34
|
-
pub question: String,
|
|
35
|
-
pub answer: AnswerFormat,
|
|
36
|
-
}
|
|
37
|
-
|
|
38
|
-
#[derive(PartialEq, Debug, Clone, schemars::JsonSchema, serde::Serialize, serde::Deserialize)]
|
|
39
|
-
pub enum AnswerFormat {
|
|
40
|
-
#[allow(non_camel_case_types)]
|
|
41
|
-
summary_type(SummaryType),
|
|
42
|
-
#[allow(non_camel_case_types)]
|
|
43
|
-
DE_type(DEType),
|
|
44
|
-
}
|
|
45
|
-
|
|
46
|
-
#[derive(PartialEq, Debug, Clone, schemars::JsonSchema, serde::Serialize, serde::Deserialize)]
|
|
47
|
-
pub struct DEType {
|
|
48
|
-
action: String,
|
|
49
|
-
DE_output: DETerms,
|
|
50
|
-
}
|
|
51
|
-
|
|
52
|
-
#[derive(PartialEq, Debug, Clone, schemars::JsonSchema, serde::Serialize, serde::Deserialize)]
|
|
53
|
-
pub struct DETerms {
|
|
54
|
-
group1: GroupType,
|
|
55
|
-
group2: GroupType,
|
|
56
|
-
}
|
|
57
|
-
|
|
58
|
-
#[derive(PartialEq, Debug, Clone, schemars::JsonSchema, serde::Serialize, serde::Deserialize)]
|
|
59
|
-
pub struct GroupType {
|
|
60
|
-
name: String,
|
|
61
|
-
}
|
|
62
|
-
|
|
63
|
-
/// LLM provider the pipeline talks to.
///
/// The backend decides which extra request parameters are sent and which field
/// of the provider's response envelope holds the generated text.
#[derive(Clone, Debug)]
#[allow(non_camel_case_types)] // the lower-case type name is referenced throughout the crate
pub enum llm_backend {
    /// Ollama server.
    Ollama(),
    /// SJ provider.
    Sj(),
}
|
|
69
|
-
|
|
70
|
-
#[derive(Debug, JsonSchema)]
|
|
71
|
-
#[allow(dead_code)]
|
|
72
|
-
pub struct OutputJson {
|
|
73
|
-
pub answer: String,
|
|
74
|
-
}
|
|
75
|
-
|
|
76
|
-
//#[tokio::main]
|
|
77
|
-
//async fn main() -> Result<()> {
|
|
78
|
-
// let mut input = String::new();
|
|
79
|
-
// match io::stdin().read_line(&mut input) {
|
|
80
|
-
// // Accepting the piped input from nodejs (or command line from testing)
|
|
81
|
-
// Ok(_n) => {
|
|
82
|
-
// let input_json = json::parse(&input);
|
|
83
|
-
// match input_json {
|
|
84
|
-
// Ok(json_string) => {
|
|
85
|
-
// //println!("json_string:{}", json_string);
|
|
86
|
-
// let user_input_json: &JsonValue = &json_string["user_input"];
|
|
87
|
-
// let user_input: &str;
|
|
88
|
-
// match user_input_json.as_str() {
|
|
89
|
-
// Some(inp) => user_input = inp,
|
|
90
|
-
// None => panic!("user_input field is missing in input json"),
|
|
91
|
-
// }
|
|
92
|
-
// let dataset_db_json: &JsonValue = &json_string["dataset_db"];
|
|
93
|
-
// let dataset_db_str: &str;
|
|
94
|
-
// match dataset_db_json.as_str() {
|
|
95
|
-
// Some(inp) => dataset_db_str = inp,
|
|
96
|
-
// None => panic!("dataset_db field is missing in input json"),
|
|
97
|
-
// }
|
|
98
|
-
// let genedb_json: &JsonValue = &json_string["genedb"];
|
|
99
|
-
// let genedb_str: &str;
|
|
100
|
-
// match genedb_json.as_str() {
|
|
101
|
-
// Some(inp) => genedb_str = inp,
|
|
102
|
-
// None => panic!("genedb field is missing in input json"),
|
|
103
|
-
// }
|
|
104
|
-
// let aiRoute_json: &JsonValue = &json_string["aiRoute"];
|
|
105
|
-
// let aiRoute_str: &str;
|
|
106
|
-
// match aiRoute_json.as_str() {
|
|
107
|
-
// Some(inp) => aiRoute_str = inp,
|
|
108
|
-
// None => panic!("aiRoute field is missing in input json"),
|
|
109
|
-
// }
|
|
110
|
-
// if user_input.len() == 0 {
|
|
111
|
-
// panic!("The user input is empty");
|
|
112
|
-
// }
|
|
113
|
-
// let tpmasterdir_json: &JsonValue = &json_string["tpmasterdir"];
|
|
114
|
-
// let tpmasterdir: &str;
|
|
115
|
-
// match tpmasterdir_json.as_str() {
|
|
116
|
-
// Some(inp) => tpmasterdir = inp,
|
|
117
|
-
// None => panic!("tpmasterdir not found"),
|
|
118
|
-
// }
|
|
119
|
-
// let binpath_json: &JsonValue = &json_string["binpath"];
|
|
120
|
-
// let binpath: &str;
|
|
121
|
-
// match binpath_json.as_str() {
|
|
122
|
-
// Some(inp) => binpath = inp,
|
|
123
|
-
// None => panic!("binpath not found"),
|
|
124
|
-
// }
|
|
125
|
-
// let ai_json_file_json: &JsonValue = &json_string["aifiles"];
|
|
126
|
-
// let ai_json_file: String;
|
|
127
|
-
// match ai_json_file_json.as_str() {
|
|
128
|
-
// Some(inp) => ai_json_file = String::from(binpath) + &"/../../" + &inp,
|
|
129
|
-
// None => {
|
|
130
|
-
// panic!("ai json file not found")
|
|
131
|
-
// }
|
|
132
|
-
// }
|
|
133
|
-
// let ai_json_file = Path::new(&ai_json_file);
|
|
134
|
-
// let ai_json_file_path;
|
|
135
|
-
// let current_dir = std::env::current_dir().unwrap();
|
|
136
|
-
// match ai_json_file.canonicalize() {
|
|
137
|
-
// Ok(p) => ai_json_file_path = p,
|
|
138
|
-
// Err(_) => {
|
|
139
|
-
// panic!(
|
|
140
|
-
// "AI JSON file path not found:{:?}, current directory:{:?}",
|
|
141
|
-
// ai_json_file, current_dir
|
|
142
|
-
// )
|
|
143
|
-
// }
|
|
144
|
-
// }
|
|
145
|
-
// // Read the file
|
|
146
|
-
// let ai_data = fs::read_to_string(ai_json_file_path).unwrap();
|
|
147
|
-
// // Parse the JSON data
|
|
148
|
-
// let ai_json: AiJsonFormat =
|
|
149
|
-
// serde_json::from_str(&ai_data).expect("AI JSON file does not have the correct format");
|
|
150
|
-
// let genedb = String::from(tpmasterdir) + &"/" + &genedb_str;
|
|
151
|
-
// let dataset_db = String::from(tpmasterdir) + &"/" + &dataset_db_str;
|
|
152
|
-
// let airoute = String::from(binpath) + &"/../../" + &aiRoute_str;
|
|
153
|
-
// let apilink_json: &JsonValue = &json_string["apilink"];
|
|
154
|
-
// let apilink: &str;
|
|
155
|
-
// match apilink_json.as_str() {
|
|
156
|
-
// Some(inp) => apilink = inp,
|
|
157
|
-
// None => panic!("apilink field is missing in input json"),
|
|
158
|
-
// }
|
|
159
|
-
// let comp_model_name_json: &JsonValue = &json_string["comp_model_name"];
|
|
160
|
-
// let comp_model_name: &str;
|
|
161
|
-
// match comp_model_name_json.as_str() {
|
|
162
|
-
// Some(inp) => comp_model_name = inp,
|
|
163
|
-
// None => panic!("comp_model_name field is missing in input json"),
|
|
164
|
-
// }
|
|
165
|
-
// let embedding_model_name_json: &JsonValue = &json_string["embedding_model_name"];
|
|
166
|
-
// let embedding_model_name: &str;
|
|
167
|
-
// match embedding_model_name_json.as_str() {
|
|
168
|
-
// Some(inp) => embedding_model_name = inp,
|
|
169
|
-
// None => panic!("embedding_model_name field is missing in input json"),
|
|
170
|
-
// }
|
|
171
|
-
// let llm_backend_name_json: &JsonValue = &json_string["llm_backend_name"];
|
|
172
|
-
// let llm_backend_name: &str;
|
|
173
|
-
// match llm_backend_name_json.as_str() {
|
|
174
|
-
// Some(inp) => llm_backend_name = inp,
|
|
175
|
-
// None => panic!("llm_backend_name field is missing in input json"),
|
|
176
|
-
// }
|
|
177
|
-
// let llm_backend_type: llm_backend;
|
|
178
|
-
// let mut final_output: Option<String> = None;
|
|
179
|
-
// let temperature: f64 = 0.01;
|
|
180
|
-
// let max_new_tokens: usize = 512;
|
|
181
|
-
// let top_p: f32 = 0.95;
|
|
182
|
-
// let testing = false; // This variable is always false in production, this is true in test_ai.rs for testing code
|
|
183
|
-
// if llm_backend_name != "ollama" && llm_backend_name != "SJ" {
|
|
184
|
-
// panic!(
|
|
185
|
-
// "This code currently supports only Ollama and SJ provider. llm_backend_name must be \"ollama\" or \"SJ\""
|
|
186
|
-
// );
|
|
187
|
-
// } else if llm_backend_name == "ollama".to_string() {
|
|
188
|
-
// llm_backend_type = llm_backend::Ollama();
|
|
189
|
-
// // Initialize Ollama client
|
|
190
|
-
// let ollama_client = ollama::Client::builder()
|
|
191
|
-
// .base_url(apilink)
|
|
192
|
-
// .build()
|
|
193
|
-
// .expect("Ollama server not found");
|
|
194
|
-
// let embedding_model = ollama_client.embedding_model(embedding_model_name);
|
|
195
|
-
// let comp_model = ollama_client.completion_model(comp_model_name);
|
|
196
|
-
// final_output = run_pipeline(
|
|
197
|
-
// user_input,
|
|
198
|
-
// comp_model,
|
|
199
|
-
// embedding_model,
|
|
200
|
-
// llm_backend_type,
|
|
201
|
-
// temperature,
|
|
202
|
-
// max_new_tokens,
|
|
203
|
-
// top_p,
|
|
204
|
-
// &dataset_db,
|
|
205
|
-
// &genedb,
|
|
206
|
-
// &ai_json,
|
|
207
|
-
// &airoute,
|
|
208
|
-
// testing,
|
|
209
|
-
// )
|
|
210
|
-
// .await;
|
|
211
|
-
// } else if llm_backend_name == "SJ".to_string() {
|
|
212
|
-
// llm_backend_type = llm_backend::Sj();
|
|
213
|
-
// // Initialize Sj provider client
|
|
214
|
-
// let sj_client = sjprovider::Client::builder()
|
|
215
|
-
// .base_url(apilink)
|
|
216
|
-
// .build()
|
|
217
|
-
// .expect("SJ server not found");
|
|
218
|
-
// let embedding_model = sj_client.embedding_model(embedding_model_name);
|
|
219
|
-
// let comp_model = sj_client.completion_model(comp_model_name);
|
|
220
|
-
// final_output = run_pipeline(
|
|
221
|
-
// user_input,
|
|
222
|
-
// comp_model,
|
|
223
|
-
// embedding_model,
|
|
224
|
-
// llm_backend_type,
|
|
225
|
-
// temperature,
|
|
226
|
-
// max_new_tokens,
|
|
227
|
-
// top_p,
|
|
228
|
-
// &dataset_db,
|
|
229
|
-
// &genedb,
|
|
230
|
-
// &ai_json,
|
|
231
|
-
// &airoute,
|
|
232
|
-
// testing,
|
|
233
|
-
// )
|
|
234
|
-
// .await;
|
|
235
|
-
// }
|
|
236
|
-
// match final_output {
|
|
237
|
-
// Some(fin_out) => {
|
|
238
|
-
// println!("final_output:{:?}", fin_out.replace("\\", ""));
|
|
239
|
-
// }
|
|
240
|
-
// None => {
|
|
241
|
-
// println!("final_output:{{\"{}\":\"{}\"}}", "action", "unknown");
|
|
242
|
-
// }
|
|
243
|
-
// }
|
|
244
|
-
// }
|
|
245
|
-
// Err(error) => println!("Incorrect json:{}", error),
|
|
246
|
-
// }
|
|
247
|
-
// }
|
|
248
|
-
// Err(error) => println!("Piping error: {}", error),
|
|
249
|
-
// }
|
|
250
|
-
// Ok(())
|
|
251
|
-
//}
|
|
252
|
-
|
|
253
|
-
#[allow(dead_code)]
|
|
254
|
-
pub async fn run_pipeline(
|
|
255
|
-
user_input: &str,
|
|
256
|
-
comp_model: impl rig::completion::CompletionModel + 'static,
|
|
257
|
-
embedding_model: impl rig::embeddings::EmbeddingModel + 'static,
|
|
258
|
-
llm_backend_type: llm_backend,
|
|
259
|
-
temperature: f64,
|
|
260
|
-
max_new_tokens: usize,
|
|
261
|
-
top_p: f32,
|
|
262
|
-
dataset_db: &str,
|
|
263
|
-
genedb: &str,
|
|
264
|
-
ai_json: &AiJsonFormat,
|
|
265
|
-
ai_route: &str,
|
|
266
|
-
testing: bool,
|
|
267
|
-
) -> Option<String> {
|
|
268
|
-
let mut classification: String = classify_query_by_dataset_type(
|
|
269
|
-
user_input,
|
|
270
|
-
comp_model.clone(),
|
|
271
|
-
embedding_model.clone(),
|
|
272
|
-
&llm_backend_type,
|
|
273
|
-
temperature,
|
|
274
|
-
max_new_tokens,
|
|
275
|
-
top_p,
|
|
276
|
-
ai_route,
|
|
277
|
-
)
|
|
278
|
-
.await;
|
|
279
|
-
classification = classification.replace("\"", "");
|
|
280
|
-
let final_output;
|
|
281
|
-
if classification == "dge".to_string() {
|
|
282
|
-
let de_result = extract_DE_search_terms_from_query(
|
|
283
|
-
user_input,
|
|
284
|
-
comp_model,
|
|
285
|
-
embedding_model,
|
|
286
|
-
&llm_backend_type,
|
|
287
|
-
temperature,
|
|
288
|
-
max_new_tokens,
|
|
289
|
-
top_p,
|
|
290
|
-
)
|
|
291
|
-
.await;
|
|
292
|
-
if testing == true {
|
|
293
|
-
final_output = format!(
|
|
294
|
-
"{{\"{}\":\"{}\",\"{}\":[{}}}",
|
|
295
|
-
"action",
|
|
296
|
-
"dge",
|
|
297
|
-
"DE_output",
|
|
298
|
-
de_result + &"]"
|
|
299
|
-
);
|
|
300
|
-
} else {
|
|
301
|
-
final_output = format!(
|
|
302
|
-
"{{\"{}\":\"{}\",\"{}\":\"{}\"}}",
|
|
303
|
-
"type", "html", "html", "DE agent not implemented yet"
|
|
304
|
-
);
|
|
305
|
-
}
|
|
306
|
-
} else if classification == "summary".to_string() {
|
|
307
|
-
final_output = extract_summary_information(
|
|
308
|
-
user_input,
|
|
309
|
-
comp_model,
|
|
310
|
-
embedding_model,
|
|
311
|
-
&llm_backend_type,
|
|
312
|
-
temperature,
|
|
313
|
-
max_new_tokens,
|
|
314
|
-
top_p,
|
|
315
|
-
dataset_db,
|
|
316
|
-
genedb,
|
|
317
|
-
ai_json,
|
|
318
|
-
testing,
|
|
319
|
-
)
|
|
320
|
-
.await;
|
|
321
|
-
} else if classification == "hierarchical".to_string() {
|
|
322
|
-
// Not implemented yet
|
|
323
|
-
if testing == true {
|
|
324
|
-
final_output = format!("{{\"{}\":\"{}\"}}", "action", "hierarchical");
|
|
325
|
-
} else {
|
|
326
|
-
final_output = format!(
|
|
327
|
-
"{{\"{}\":\"{}\",\"{}\":\"{}\"}}",
|
|
328
|
-
"type", "html", "html", "hierarchical clustering agent not implemented yet"
|
|
329
|
-
);
|
|
330
|
-
}
|
|
331
|
-
} else if classification == "snv_indel".to_string() {
|
|
332
|
-
// Not implemented yet
|
|
333
|
-
if testing == true {
|
|
334
|
-
final_output = format!("{{\"{}\":\"{}\"}}", "action", "snv_indel");
|
|
335
|
-
} else {
|
|
336
|
-
final_output = format!(
|
|
337
|
-
"{{\"{}\":\"{}\",\"{}\":\"{}\"}}",
|
|
338
|
-
"type", "html", "html", "snv_indel agent not implemented yet"
|
|
339
|
-
);
|
|
340
|
-
}
|
|
341
|
-
} else if classification == "cnv".to_string() {
|
|
342
|
-
// Not implemented yet
|
|
343
|
-
if testing == true {
|
|
344
|
-
final_output = format!("{{\"{}\":\"{}\"}}", "action", "cnv");
|
|
345
|
-
} else {
|
|
346
|
-
final_output = format!(
|
|
347
|
-
"{{\"{}\":\"{}\",\"{}\":\"{}\"}}",
|
|
348
|
-
"type", "html", "html", "cnv agent not implemented yet"
|
|
349
|
-
);
|
|
350
|
-
}
|
|
351
|
-
} else if classification == "variant_calling".to_string() {
|
|
352
|
-
// Not implemented yet and will never be supported. Need a separate messages for this
|
|
353
|
-
if testing == true {
|
|
354
|
-
final_output = format!("{{\"{}\":\"{}\"}}", "action", "variant_calling");
|
|
355
|
-
} else {
|
|
356
|
-
final_output = format!(
|
|
357
|
-
"{{\"{}\":\"{}\",\"{}\":\"{}\"}}",
|
|
358
|
-
"type", "html", "html", "variant_calling agent not implemented yet"
|
|
359
|
-
);
|
|
360
|
-
}
|
|
361
|
-
} else if classification == "survival".to_string() {
|
|
362
|
-
// Not implemented yet
|
|
363
|
-
if testing == true {
|
|
364
|
-
final_output = format!("{{\"{}\":\"{}\"}}", "action", "surivial");
|
|
365
|
-
} else {
|
|
366
|
-
final_output = format!(
|
|
367
|
-
"{{\"{}\":\"{}\",\"{}\":\"{}\"}}",
|
|
368
|
-
"type", "html", "html", "survival agent not implemented yet"
|
|
369
|
-
);
|
|
370
|
-
}
|
|
371
|
-
} else if classification == "none".to_string() {
|
|
372
|
-
if testing == true {
|
|
373
|
-
final_output = format!(
|
|
374
|
-
"{{\"{}\":\"{}\",\"{}\":\"{}\"}}",
|
|
375
|
-
"action", "none", "message", "The input query did not match any known features in Proteinpaint"
|
|
376
|
-
);
|
|
377
|
-
} else {
|
|
378
|
-
final_output = format!(
|
|
379
|
-
"{{\"{}\":\"{}\",\"{}\":\"{}\"}}",
|
|
380
|
-
"type", "html", "html", "The input query did not match any known features in Proteinpaint"
|
|
381
|
-
);
|
|
382
|
-
}
|
|
383
|
-
} else {
|
|
384
|
-
if testing == true {
|
|
385
|
-
final_output = format!("{{\"{}\":\"{}\"}}", "action", "unknown:".to_string() + &classification);
|
|
386
|
-
} else {
|
|
387
|
-
final_output = format!(
|
|
388
|
-
"{{\"{}\":\"{}\",\"{}\":\"{}\"}}",
|
|
389
|
-
"type",
|
|
390
|
-
"html",
|
|
391
|
-
"html",
|
|
392
|
-
"unknown:".to_string() + &classification
|
|
393
|
-
);
|
|
394
|
-
}
|
|
395
|
-
}
|
|
396
|
-
Some(final_output)
|
|
397
|
-
}
|
|
398
|
-
|
|
399
|
-
#[allow(dead_code)]
|
|
400
|
-
pub async fn classify_query_by_dataset_type(
|
|
401
|
-
user_input: &str,
|
|
402
|
-
comp_model: impl rig::completion::CompletionModel + 'static,
|
|
403
|
-
_embedding_model: impl rig::embeddings::EmbeddingModel + 'static,
|
|
404
|
-
llm_backend_type: &llm_backend,
|
|
405
|
-
temperature: f64,
|
|
406
|
-
max_new_tokens: usize,
|
|
407
|
-
top_p: f32,
|
|
408
|
-
ai_route: &str,
|
|
409
|
-
) -> String {
|
|
410
|
-
// Read the file
|
|
411
|
-
let ai_route_data = fs::read_to_string(ai_route).unwrap();
|
|
412
|
-
|
|
413
|
-
// Parse the JSON data
|
|
414
|
-
let ai_json: Value = serde_json::from_str(&ai_route_data).expect("AI JSON file does not have the correct format");
|
|
415
|
-
|
|
416
|
-
// Create a string to hold the file contents
|
|
417
|
-
let mut contents = String::from("");
|
|
418
|
-
|
|
419
|
-
if let Some(object) = ai_json.as_object() {
|
|
420
|
-
contents = object["general"].to_string();
|
|
421
|
-
for (key, value) in object {
|
|
422
|
-
if key != "general" {
|
|
423
|
-
contents += &value.as_str().unwrap();
|
|
424
|
-
contents += "---"; // Adding delimiter
|
|
425
|
-
}
|
|
426
|
-
}
|
|
427
|
-
}
|
|
428
|
-
|
|
429
|
-
// Removing the last "---" characters
|
|
430
|
-
contents.pop();
|
|
431
|
-
contents.pop();
|
|
432
|
-
contents.pop();
|
|
433
|
-
|
|
434
|
-
// Split the contents by the delimiter "---"
|
|
435
|
-
let parts: Vec<&str> = contents.split("---").collect();
|
|
436
|
-
let schema_json: Value = serde_json::to_value(schemars::schema_for!(OutputJson)).unwrap(); // error handling here
|
|
437
|
-
let schema_json_string = serde_json::to_string_pretty(&schema_json).unwrap();
|
|
438
|
-
|
|
439
|
-
let additional;
|
|
440
|
-
match llm_backend_type {
|
|
441
|
-
llm_backend::Ollama() => {
|
|
442
|
-
additional = json!({
|
|
443
|
-
"max_new_tokens": max_new_tokens,
|
|
444
|
-
"top_p": top_p,
|
|
445
|
-
"schema_json": schema_json_string
|
|
446
|
-
});
|
|
447
|
-
}
|
|
448
|
-
llm_backend::Sj() => {
|
|
449
|
-
additional = json!({
|
|
450
|
-
"max_new_tokens": max_new_tokens,
|
|
451
|
-
"top_p": top_p
|
|
452
|
-
});
|
|
453
|
-
}
|
|
454
|
-
}
|
|
455
|
-
|
|
456
|
-
// Print the separated parts
|
|
457
|
-
let mut rag_docs = Vec::<String>::new();
|
|
458
|
-
for (_i, part) in parts.iter().enumerate() {
|
|
459
|
-
//println!("Part {}: {}", i + 1, part.trim());
|
|
460
|
-
rag_docs.push(part.trim().to_string())
|
|
461
|
-
}
|
|
462
|
-
|
|
463
|
-
//let top_k: usize = 3; // Embedding model not used currently
|
|
464
|
-
// Create embeddings and add to vector store
|
|
465
|
-
//let embeddings = EmbeddingsBuilder::new(embedding_model.clone())
|
|
466
|
-
// .documents(rag_docs)
|
|
467
|
-
// .expect("Reason1")
|
|
468
|
-
// .build()
|
|
469
|
-
// .await
|
|
470
|
-
// .unwrap();
|
|
471
|
-
|
|
472
|
-
//// Create vector store
|
|
473
|
-
//let mut vector_store = InMemoryVectorStore::<String>::default();
|
|
474
|
-
//InMemoryVectorStore::add_documents(&mut vector_store, embeddings);
|
|
475
|
-
|
|
476
|
-
// Create RAG agent
|
|
477
|
-
let agent = AgentBuilder::new(comp_model).preamble(&(String::from("Generate classification for the user query into summary, dge, hierarchical, snv_indel, cnv, variant_calling, sv_fusion and none categories. Return output in JSON with ALWAYS a single word answer { \"answer\": \"dge\" }, that is 'summary' for summary plot, 'dge' for differential gene expression, 'hierarchical' for hierarchical clustering, 'snv_indel' for SNV/Indel, 'cnv' for CNV and 'sv_fusion' for SV/fusion, 'variant_calling' for variant calling, 'surivial' for survival data, 'none' for none of the previously described categories. The summary plot list and summarizes the cohort of patients according to the user query. The answer should always be in lower case\n The options are as follows:\n") + &contents + "\nQuestion= {question} \nanswer")).temperature(temperature).additional_params(additional).build();
|
|
478
|
-
//.dynamic_context(top_k, vector_store.index(embedding_model))
|
|
479
|
-
|
|
480
|
-
let response = agent.prompt(user_input).await.expect("Failed to prompt server");
|
|
481
|
-
|
|
482
|
-
//println!("Ollama: {}", response);
|
|
483
|
-
let result = response.replace("json", "").replace("```", "");
|
|
484
|
-
let json_value: Value = serde_json::from_str(&result).expect("REASON");
|
|
485
|
-
match llm_backend_type {
|
|
486
|
-
llm_backend::Ollama() => {
|
|
487
|
-
let json_value2: Value = serde_json::from_str(&json_value["content"].to_string()).expect("REASON2");
|
|
488
|
-
let json_value3: Value = serde_json::from_str(&json_value2.as_str().unwrap()).expect("REASON3");
|
|
489
|
-
json_value3["answer"].to_string()
|
|
490
|
-
}
|
|
491
|
-
llm_backend::Sj() => {
|
|
492
|
-
let json_value2: Value =
|
|
493
|
-
serde_json::from_str(&json_value[0]["generated_text"].to_string()).expect("REASON2");
|
|
494
|
-
//println!("json_value2:{}", json_value2.as_str().unwrap());
|
|
495
|
-
let json_value3: Value = serde_json::from_str(&json_value2.as_str().unwrap()).expect("REASON3");
|
|
496
|
-
//let json_value3: Value = serde_json::from_str(&json_value2["answer"].to_string()).expect("REASON2");
|
|
497
|
-
//println!("Classification result:{}", json_value3["answer"]);
|
|
498
|
-
json_value3["answer"].to_string()
|
|
499
|
-
}
|
|
500
|
-
}
|
|
501
|
-
}
|
|
502
|
-
|
|
503
|
-
// DE JSON output schema
|
|
504
|
-
|
|
505
|
-
#[allow(non_camel_case_types)]
|
|
506
|
-
#[derive(Debug, JsonSchema)]
|
|
507
|
-
#[allow(dead_code)]
|
|
508
|
-
enum cutoff_info {
|
|
509
|
-
lesser(f32),
|
|
510
|
-
greater(f32),
|
|
511
|
-
equalto(f32),
|
|
512
|
-
}
|
|
513
|
-
|
|
514
|
-
#[derive(Debug, JsonSchema)]
|
|
515
|
-
#[allow(dead_code)]
|
|
516
|
-
struct Cutoff {
|
|
517
|
-
cutoff_name: cutoff_info,
|
|
518
|
-
units: Option<String>,
|
|
519
|
-
}
|
|
520
|
-
|
|
521
|
-
#[derive(Debug, JsonSchema)]
|
|
522
|
-
#[allow(dead_code)]
|
|
523
|
-
struct Filter {
|
|
524
|
-
name: String,
|
|
525
|
-
cutoff: Cutoff,
|
|
526
|
-
}
|
|
527
|
-
|
|
528
|
-
#[derive(Debug, JsonSchema)]
|
|
529
|
-
#[allow(dead_code)]
|
|
530
|
-
struct Group {
|
|
531
|
-
name: String,
|
|
532
|
-
filter: Filter,
|
|
533
|
-
}
|
|
534
|
-
|
|
535
|
-
#[derive(Debug, JsonSchema)]
|
|
536
|
-
#[allow(dead_code)]
|
|
537
|
-
struct DEOutput {
|
|
538
|
-
group1: Group,
|
|
539
|
-
group2: Group,
|
|
540
|
-
}
|
|
541
|
-
|
|
542
|
-
#[allow(dead_code)]
|
|
543
|
-
#[allow(non_snake_case)]
|
|
544
|
-
async fn extract_DE_search_terms_from_query(
|
|
545
|
-
user_input: &str,
|
|
546
|
-
comp_model: impl rig::completion::CompletionModel + 'static,
|
|
547
|
-
embedding_model: impl rig::embeddings::EmbeddingModel + 'static,
|
|
548
|
-
llm_backend_type: &llm_backend,
|
|
549
|
-
temperature: f64,
|
|
550
|
-
max_new_tokens: usize,
|
|
551
|
-
top_p: f32,
|
|
552
|
-
) -> String {
|
|
553
|
-
let contents = String::from("Differential Gene Expression (DGE or DE) is a technique where the most upregulated and downregulated genes between two cohorts of samples (or patients) are determined. A volcano plot is shown with fold-change in the x-axis and adjusted p-value on the y-axis. So, the upregulated and downregulared genes are on opposite sides of the graph and the most significant genes (based on adjusted p-value) is on the top of the graph.
|
|
554
|
-
|
|
555
|
-
The user may select a cutoff for a continuous variables such as age. In such cases the group should only include the range specified by the user. Inside the JSON each entry the name of the group must be inside the field \"name\". For the cutoff (if provided) a field called \"cutoff\" must be provided which should contain a subfield \"name\" containing the name of the cutoff, followed by \"greater\"/\"lesser\"/\"equal\" to followed by the numeric value of the cutoff. If the unit of the variable is provided such as cm,m,inches,celsius etc. then add it to a separate field called \"units\".
|
|
556
|
-
|
|
557
|
-
Example input user queries:
|
|
558
|
-
When two groups are found give the following JSON output show {\"group1\": \"groupA\", \"group2\": \"groupB\"}
|
|
559
|
-
User query1: \"Show me the differential gene expression plot for groups groupA and groupB\"
|
|
560
|
-
Output JSON query1: {\"group1\": {\"name\": \"groupA\"}, \"group2\": {\"name\": \"groupB\"}}
|
|
561
|
-
|
|
562
|
-
User query2: \"Show volcano plot for White vs Black\"
|
|
563
|
-
Output JSON query2: {\"group1\": {\"name\": \"White\"}, \"group2\": {\"name\": \"Black\"}}
|
|
564
|
-
|
|
565
|
-
In case no suitable groups are found, show {\"output\":\"No suitable two groups found for differential gene expression\"}
|
|
566
|
-
User query3: \"Who wants to have vodka?\"
|
|
567
|
-
Output JSON query3: {\"output\":\"No suitable two groups found for differential gene expression\"}
|
|
568
|
-
|
|
569
|
-
User query4: \"Show volcano plot for Asians with age less than 20 and African greater than 80\"
|
|
570
|
-
Output JSON query4: {\"group1\": {\"name\": \"Asians\", \"filter\": {\"name\": \"age\", \"cutoff\": {\"lesser\": 20}}}, \"group2\": {\"name\": \"African\", \"filter\": {\"name\": \"age\", \"cutoff\": {\"greater\": 80}}}}
|
|
571
|
-
|
|
572
|
-
User query5: \"Show Differential gene expression plot for males with height greater than 185cm and women with less than 100cm\"
|
|
573
|
-
Output JSON query5: {\"group1\": {\"name\": \"males\", \"filter\": {\"name\": \"height\", \"cutoff\": {\"greater\": 185, \"units\":\"cm\"}}}, \"group2\": {\"name\": \"women\", \"filter\": {\"name\": \"height\", \"cutoff\": {\"lesser\": 100, \"units\": \"cm\"}}}}");
|
|
574
|
-
|
|
575
|
-
// Split the contents by the delimiter "---"
|
|
576
|
-
let parts: Vec<&str> = contents.split("---").collect();
|
|
577
|
-
|
|
578
|
-
let schema_json: Value = serde_json::to_value(schemars::schema_for!(DEOutput)).unwrap(); // error handling here
|
|
579
|
-
let schema_json_string = serde_json::to_string_pretty(&schema_json).unwrap();
|
|
580
|
-
//println!("DE schema:{}", schema_json);
|
|
581
|
-
|
|
582
|
-
let additional;
|
|
583
|
-
match llm_backend_type {
|
|
584
|
-
llm_backend::Ollama() => {
|
|
585
|
-
additional = json!({
|
|
586
|
-
"max_new_tokens": max_new_tokens,
|
|
587
|
-
"top_p": top_p,
|
|
588
|
-
"schema_json": schema_json_string
|
|
589
|
-
});
|
|
590
|
-
}
|
|
591
|
-
llm_backend::Sj() => {
|
|
592
|
-
additional = json!({
|
|
593
|
-
"max_new_tokens": max_new_tokens,
|
|
594
|
-
"top_p": top_p
|
|
595
|
-
});
|
|
596
|
-
}
|
|
597
|
-
}
|
|
598
|
-
|
|
599
|
-
// Print the separated parts
|
|
600
|
-
let mut rag_docs = Vec::<String>::new();
|
|
601
|
-
for (_i, part) in parts.iter().enumerate() {
|
|
602
|
-
//println!("Part {}: {}", i + 1, part.trim());
|
|
603
|
-
rag_docs.push(part.trim().to_string())
|
|
604
|
-
}
|
|
605
|
-
|
|
606
|
-
let rag_docs_length = rag_docs.len();
|
|
607
|
-
// Create embeddings and add to vector store
|
|
608
|
-
let embeddings = EmbeddingsBuilder::new(embedding_model.clone())
|
|
609
|
-
.documents(rag_docs)
|
|
610
|
-
.expect("Reason1")
|
|
611
|
-
.build()
|
|
612
|
-
.await
|
|
613
|
-
.unwrap();
|
|
614
|
-
|
|
615
|
-
// Create vector store
|
|
616
|
-
let mut vector_store = InMemoryVectorStore::<String>::default();
|
|
617
|
-
InMemoryVectorStore::add_documents(&mut vector_store, embeddings);
|
|
618
|
-
|
|
619
|
-
// Create RAG agent
|
|
620
|
-
let router_instructions = String::from(
|
|
621
|
-
"Extract the group variable names for differential gene expression from input query. When two groups are found give the following JSON output with no extra comments. Show {{\"group1\": {\"name\": \"groupA\"}, \"group2\": {\"name\": \"groupB\"}}}. In case no suitable groups are found, show {\"output\":\"No suitable two groups found for differential gene expression\"}. In case of a continuous variable such as age, height added additional field to the group called \"filter\". This should contain a sub-field called \"names\" followed by a subfield called \"cutoff\". This sub-field should contain a key either greater, lesser or equalto. If the continuous variable has units provided by the user then add it in a separate field called \"units\".",
|
|
622
|
-
) + &contents
|
|
623
|
-
+ " The JSON schema is as follows"
|
|
624
|
-
+ &schema_json_string
|
|
625
|
-
+ "\n Examples: User query1: \"Show volcano plot for Asians with age less than 20 and African greater than 80\". Output JSON query1: {\"group1\": {\"name\": \"Asians\", \"filter\": {\"name\": \"age\", \"cutoff\": {\"lesser\": 20}}}, \"group2\": {\"name\": \"African\", \"filter\": {\"name\": \"age\", \"cutoff\": {\"greater\": 80}}}}. User query2: \"Show Differential gene expression plot for males with height greater than 185cm and women with less than 100cm\". Output JSON query2: {\"group1\": {\"name\": \"males\", \"filter\": {\"name\": \"height\", \"cutoff\": {\"greater\": 185, \"units\":\"cm\"}}}, \"group2\": {\"name\": \"women\", \"filter\": {\"name\": \"height\", \"cutoff\": {\"lesser\": 100, \"units\": \"cm\"}}}}. User query3: \"Show DE plot between healthy and diseased groups. Output JSON query3: {\"group1\":{\"name\":\"healthy\"},\"group2\":{\"name\":\"diseased\"}} \nQuestion= {question} \nanswer";
|
|
626
|
-
//println! {"router_instructions:{}",router_instructions};
|
|
627
|
-
let agent = AgentBuilder::new(comp_model)
|
|
628
|
-
.preamble(&router_instructions)
|
|
629
|
-
.dynamic_context(rag_docs_length, vector_store.index(embedding_model))
|
|
630
|
-
.temperature(temperature)
|
|
631
|
-
.additional_params(additional)
|
|
632
|
-
.build();
|
|
633
|
-
|
|
634
|
-
let response = agent.prompt(user_input).await.expect("Failed to prompt server");
|
|
635
|
-
|
|
636
|
-
//println!("Ollama_groups: {}", response);
|
|
637
|
-
let result = response.replace("json", "").replace("```", "");
|
|
638
|
-
//println!("result_groups:{}", result);
|
|
639
|
-
let json_value: Value = serde_json::from_str(&result).expect("REASON");
|
|
640
|
-
//println!("json_value:{}", json_value);
|
|
641
|
-
match llm_backend_type {
|
|
642
|
-
llm_backend::Ollama() => {
|
|
643
|
-
let json_value2: Value = serde_json::from_str(&json_value["content"].to_string()).expect("REASON2");
|
|
644
|
-
//println!("json_value2:{:?}", json_value2);
|
|
645
|
-
let json_value3: Value = serde_json::from_str(&json_value2.as_str().unwrap()).expect("REASON3");
|
|
646
|
-
json_value3.to_string()
|
|
647
|
-
}
|
|
648
|
-
llm_backend::Sj() => {
|
|
649
|
-
let json_value2: Value =
|
|
650
|
-
serde_json::from_str(&json_value[0]["generated_text"].to_string()).expect("REASON2");
|
|
651
|
-
//println!("json_value2:{}", json_value2.as_str().unwrap());
|
|
652
|
-
let json_value3: Value = serde_json::from_str(&json_value2.as_str().unwrap()).expect("REASON2");
|
|
653
|
-
//println!("Classification result:{}", json_value3);
|
|
654
|
-
json_value3.to_string()
|
|
655
|
-
}
|
|
656
|
-
}
|
|
657
|
-
}
|
|
658
|
-
|
|
659
|
-
/// One field of the dataset, as rendered into text by `ParseDbRows::parse_db_rows`.
#[derive(Clone, Debug)]
struct DbRows {
    /// Name of the field.
    name: String,
    /// Free-text description, when there is one.
    description: Option<String>,
    /// Type of the field, when known.
    term_type: Option<String>,
    /// Values the field contains; may be empty.
    values: Vec<String>,
}
|
|
666
|
-
|
|
667
|
-
async fn parse_geneset_db(db: &str) -> Vec<String> {
|
|
668
|
-
let manager = SqliteConnectionManager::file(db);
|
|
669
|
-
let pool = r2d2::Pool::new(manager).unwrap();
|
|
670
|
-
let conn = pool.get().unwrap();
|
|
671
|
-
let sql_statement_genedb = "SELECT * from codingGenes";
|
|
672
|
-
let mut genedb = conn.prepare(&sql_statement_genedb).unwrap();
|
|
673
|
-
let mut rows_genedb = genedb.query([]).unwrap();
|
|
674
|
-
let mut gene_list = Vec::<String>::new();
|
|
675
|
-
while let Some(coding_gene) = rows_genedb.next().unwrap() {
|
|
676
|
-
let code_gene: String = coding_gene.get(0).unwrap();
|
|
677
|
-
gene_list.push(code_gene)
|
|
678
|
-
}
|
|
679
|
-
gene_list
|
|
680
|
-
}
|
|
681
|
-
|
|
682
|
-
/// Renders a database row as a plain-language description.
trait ParseDbRows {
    /// Returns the textual description of `self`.
    fn parse_db_rows(&self) -> String;
}
|
|
685
|
-
|
|
686
|
-
impl ParseDbRows for DbRows {
    /// Builds the plain-language sentence describing this term.
    ///
    /// The sentence always starts with the field name. The term type, the
    /// description and the comma-separated list of values are appended, in
    /// that order, only when they are present / non-empty.
    fn parse_db_rows(&self) -> String {
        let mut sentence = String::from("Name of field is \"");
        sentence.push_str(&self.name);
        sentence.push_str("\". ");

        if let Some(kind) = &self.term_type {
            sentence.push_str("This field is of the type ");
            sentence.push_str(kind);
            sentence.push_str(". ");
        }

        if let Some(text) = &self.description {
            // The description is appended verbatim, without any separator.
            sentence.push_str(text);
        }

        if !self.values.is_empty() {
            sentence.push_str("This contains the following values (separated by comma(,)):");
            sentence.push_str(&self.values.join(","));
            sentence.push_str(".");
        }

        sentence
    }
}
|
|
709
|
-
|
|
710
|
-
async fn parse_dataset_db(db: &str) -> (Vec<String>, Vec<DbRows>) {
|
|
711
|
-
let manager = SqliteConnectionManager::file(db);
|
|
712
|
-
let pool = r2d2::Pool::new(manager).unwrap();
|
|
713
|
-
let conn = pool.get().unwrap();
|
|
714
|
-
|
|
715
|
-
let sql_statement_termhtmldef = "SELECT * from termhtmldef";
|
|
716
|
-
let mut termhtmldef = conn.prepare(&sql_statement_termhtmldef).unwrap();
|
|
717
|
-
let mut rows_termhtmldef = termhtmldef.query([]).unwrap();
|
|
718
|
-
let mut description_map = HashMap::new();
|
|
719
|
-
while let Some(row) = rows_termhtmldef.next().unwrap() {
|
|
720
|
-
//println!("row:{:?}", row);
|
|
721
|
-
let name: String = row.get(0).unwrap();
|
|
722
|
-
//println!("name:{}", name);
|
|
723
|
-
let json_html_str: String = row.get(1).unwrap();
|
|
724
|
-
let json_html: Value = serde_json::from_str(&json_html_str).expect("Not a JSON");
|
|
725
|
-
let json_html2: &Map<String, Value> = json_html.as_object().unwrap();
|
|
726
|
-
let description: String = String::from(
|
|
727
|
-
json_html2.get("description").unwrap()[0]
|
|
728
|
-
.as_object()
|
|
729
|
-
.unwrap()
|
|
730
|
-
.get("value")
|
|
731
|
-
.unwrap()
|
|
732
|
-
.as_str()
|
|
733
|
-
.unwrap(),
|
|
734
|
-
);
|
|
735
|
-
//println!("description:{}", description);
|
|
736
|
-
description_map.insert(name, description);
|
|
737
|
-
}
|
|
738
|
-
|
|
739
|
-
//// Open the file
|
|
740
|
-
//let mut file = File::open(dataset_file).unwrap();
|
|
741
|
-
|
|
742
|
-
//// Create a string to hold the file contents
|
|
743
|
-
//let mut contents = String::new();
|
|
744
|
-
|
|
745
|
-
//// Read the file contents into the string
|
|
746
|
-
//file.read_to_string(&mut contents).unwrap();
|
|
747
|
-
|
|
748
|
-
//// Split the contents by the delimiter "---"
|
|
749
|
-
//let parts: Vec<&str> = contents.split("\n").collect();
|
|
750
|
-
|
|
751
|
-
//for (_i, part) in parts.iter().enumerate() {
|
|
752
|
-
// let sentence: &str = part.trim();
|
|
753
|
-
// let parts2: Vec<&str> = sentence.split(':').collect();
|
|
754
|
-
// //println!("parts2:{:?}", parts2);
|
|
755
|
-
// if parts2.len() == 2 {
|
|
756
|
-
// description_map.insert(parts2[0], parts2[1]);
|
|
757
|
-
// //println!("Part {}: {:?}", i + 1, parts2);
|
|
758
|
-
// }
|
|
759
|
-
//}
|
|
760
|
-
//println!("description_map:{:?}", description_map);
|
|
761
|
-
|
|
762
|
-
let sql_statement_terms = "SELECT * from terms";
|
|
763
|
-
let mut terms = conn.prepare(&sql_statement_terms).unwrap();
|
|
764
|
-
let mut rows_terms = terms.query([]).unwrap();
|
|
765
|
-
|
|
766
|
-
// Print the separated parts
|
|
767
|
-
let mut rag_docs = Vec::<String>::new();
|
|
768
|
-
let mut names = Vec::<String>::new();
|
|
769
|
-
let mut db_vec = Vec::<DbRows>::new();
|
|
770
|
-
while let Some(row) = rows_terms.next().unwrap() {
|
|
771
|
-
//println!("row:{:?}", row);
|
|
772
|
-
let name: String = row.get(0).unwrap();
|
|
773
|
-
//println!("id:{}", name);
|
|
774
|
-
match description_map.get(&name as &str) {
|
|
775
|
-
Some(desc) => {
|
|
776
|
-
let line: String = row.get(3).unwrap();
|
|
777
|
-
//println!("line:{}", line);
|
|
778
|
-
let json_data: Value = serde_json::from_str(&line).expect("Not a JSON");
|
|
779
|
-
let values_json = json_data["values"].as_object();
|
|
780
|
-
let mut keys = Vec::<String>::new();
|
|
781
|
-
match values_json {
|
|
782
|
-
Some(values) => {
|
|
783
|
-
for (key, _value) in values {
|
|
784
|
-
keys.push(key.to_string())
|
|
785
|
-
}
|
|
786
|
-
}
|
|
787
|
-
None => {}
|
|
788
|
-
}
|
|
789
|
-
|
|
790
|
-
let item_type_json = json_data["type"].as_str();
|
|
791
|
-
let mut item_type: Option<String> = None;
|
|
792
|
-
match item_type_json {
|
|
793
|
-
Some(item_ty) => item_type = Some(String::from(item_ty)),
|
|
794
|
-
None => {}
|
|
795
|
-
}
|
|
796
|
-
|
|
797
|
-
//println!("items:{:?}", keys);
|
|
798
|
-
let item: DbRows = DbRows {
|
|
799
|
-
name: name.clone(),
|
|
800
|
-
description: Some(String::from(desc.clone())),
|
|
801
|
-
term_type: item_type,
|
|
802
|
-
values: keys,
|
|
803
|
-
};
|
|
804
|
-
db_vec.push(item.clone());
|
|
805
|
-
//println!("Field details:{}", item.parse_db_rows());
|
|
806
|
-
rag_docs.push(item.parse_db_rows());
|
|
807
|
-
names.push(name)
|
|
808
|
-
}
|
|
809
|
-
None => {}
|
|
810
|
-
}
|
|
811
|
-
}
|
|
812
|
-
//println!("names:{:?}", names);
|
|
813
|
-
(rag_docs, db_vec)
|
|
814
|
-
}
|
|
815
|
-
|
|
816
|
-
pub async fn extract_summary_information(
|
|
817
|
-
user_input: &str,
|
|
818
|
-
comp_model: impl rig::completion::CompletionModel + 'static,
|
|
819
|
-
_embedding_model: impl rig::embeddings::EmbeddingModel + 'static,
|
|
820
|
-
llm_backend_type: &llm_backend,
|
|
821
|
-
temperature: f64,
|
|
822
|
-
max_new_tokens: usize,
|
|
823
|
-
top_p: f32,
|
|
824
|
-
dataset_db: &str,
|
|
825
|
-
genedb: &str,
|
|
826
|
-
ai_json: &AiJsonFormat,
|
|
827
|
-
testing: bool,
|
|
828
|
-
) -> String {
|
|
829
|
-
let (rag_docs, db_vec) = parse_dataset_db(dataset_db).await;
|
|
830
|
-
let additional;
|
|
831
|
-
let schema_json = schemars::schema_for!(SummaryType); // error handling here
|
|
832
|
-
let schema_json_string = serde_json::to_string_pretty(&schema_json).unwrap();
|
|
833
|
-
//println!("schema_json summary:{}", schema_json_string);
|
|
834
|
-
match llm_backend_type {
|
|
835
|
-
llm_backend::Ollama() => {
|
|
836
|
-
additional = json!({
|
|
837
|
-
"max_new_tokens": max_new_tokens,
|
|
838
|
-
"top_p": top_p,
|
|
839
|
-
"schema_json": schema_json_string
|
|
840
|
-
});
|
|
841
|
-
}
|
|
842
|
-
llm_backend::Sj() => {
|
|
843
|
-
additional = json!({
|
|
844
|
-
"max_new_tokens": max_new_tokens,
|
|
845
|
-
"top_p": top_p
|
|
846
|
-
});
|
|
847
|
-
}
|
|
848
|
-
}
|
|
849
|
-
|
|
850
|
-
// Create embeddings and add to vector store
|
|
851
|
-
//let embeddings = EmbeddingsBuilder::new(embedding_model.clone())
|
|
852
|
-
// .documents(rag_docs)
|
|
853
|
-
// .expect("Reason1")
|
|
854
|
-
// .build()
|
|
855
|
-
// .await
|
|
856
|
-
// .unwrap();
|
|
857
|
-
|
|
858
|
-
//// Create vector store
|
|
859
|
-
//let mut vector_store = InMemoryVectorStore::<String>::default();
|
|
860
|
-
//InMemoryVectorStore::add_documents(&mut vector_store, embeddings);
|
|
861
|
-
|
|
862
|
-
let gene_list: Vec<String> = parse_geneset_db(genedb).await;
|
|
863
|
-
let lowercase_user_input = user_input.to_lowercase();
|
|
864
|
-
let user_words: Vec<&str> = lowercase_user_input.split_whitespace().collect();
|
|
865
|
-
let user_words2: Vec<String> = user_words.into_iter().map(|s| s.to_string()).collect();
|
|
866
|
-
|
|
867
|
-
let common_genes: Vec<String> = gene_list
|
|
868
|
-
.into_iter()
|
|
869
|
-
.filter(|x| user_words2.contains(&x.to_lowercase()))
|
|
870
|
-
.collect();
|
|
871
|
-
|
|
872
|
-
let mut summary_data_check: Option<TrainTestData> = None;
|
|
873
|
-
for chart in ai_json.charts.clone() {
|
|
874
|
-
if chart.r#type == "Summary" {
|
|
875
|
-
summary_data_check = Some(chart);
|
|
876
|
-
break;
|
|
877
|
-
}
|
|
878
|
-
}
|
|
879
|
-
|
|
880
|
-
match summary_data_check {
|
|
881
|
-
Some(summary_data) => {
|
|
882
|
-
let mut training_data: String = String::from("");
|
|
883
|
-
let mut train_iter = 0;
|
|
884
|
-
for ques_ans in summary_data.TrainingData {
|
|
885
|
-
match ques_ans.answer {
|
|
886
|
-
AnswerFormat::summary_type(sum) => {
|
|
887
|
-
let summary_answer: SummaryType = sum;
|
|
888
|
-
train_iter += 1;
|
|
889
|
-
training_data += "Example question";
|
|
890
|
-
training_data += &train_iter.to_string();
|
|
891
|
-
training_data += &":";
|
|
892
|
-
training_data += &ques_ans.question;
|
|
893
|
-
training_data += &" ";
|
|
894
|
-
training_data += "Example answer";
|
|
895
|
-
training_data += &train_iter.to_string();
|
|
896
|
-
training_data += &":";
|
|
897
|
-
training_data += &serde_json::to_string(&summary_answer).unwrap();
|
|
898
|
-
training_data += &"\n";
|
|
899
|
-
}
|
|
900
|
-
AnswerFormat::DE_type(_) => panic!("DE type not valid for summary"),
|
|
901
|
-
}
|
|
902
|
-
}
|
|
903
|
-
|
|
904
|
-
let system_prompt: String = String::from(
|
|
905
|
-
String::from(
|
|
906
|
-
"I am an assistant that extracts the summary terms from user query. The final output must be in the following JSON format with NO extra comments. There are three fields in the JSON to be returned: The \"action\" field will ALWAYS be \"summary\". The \"summaryterms\" field should contain all the variables that the user wants to visualize. The \"clinical\" subfield should ONLY contain names of the fields from the sqlite db. ",
|
|
907
|
-
) + &summary_data.SystemPrompt
|
|
908
|
-
+ &" The \"filter\" field is optional and should contain an array of JSON terms with which the dataset will be filtered. A variable simultaneously CANNOT be part of both \"summaryterms\" and \"filter\". There are two kinds of filter variables: \"Categorical\" and \"Numeric\". \"Categorical\" variables are those variables which can have a fixed set of values e.g. gender, molecular subtypes. They are defined by the \"CategoricalFilterTerm\" which consists of \"term\" (a field from the sqlite3 db) and \"value\" (a value of the field from the sqlite db). \"Numeric\" variables are those which can have any numeric value. They are defined by \"NumericFilterTerm\" and contain the subfields \"term\" (a field from the sqlite3 db), \"greaterThan\" an optional filter which is defined when a lower cutoff is defined in the user input for the numeric variable and \"lessThan\" an optional filter which is defined when a higher cutoff is defined in the user input for the numeric variable. The \"message\" field only contain messages of terms in the user input that were not found in their respective databases. The JSON schema is as follows:"
|
|
909
|
-
+ &schema_json_string
|
|
910
|
-
+ &training_data
|
|
911
|
-
+ "The sqlite db in plain language is as follows:\n"
|
|
912
|
-
+ &rag_docs.join(",")
|
|
913
|
-
+ &"\n Relevant genes are as follows (separated by comma(,)):"
|
|
914
|
-
+ &common_genes.join(",")
|
|
915
|
-
+ &"\nQuestion: {question} \nanswer:",
|
|
916
|
-
);
|
|
917
|
-
|
|
918
|
-
//println!("system_prompt:{}", system_prompt);
|
|
919
|
-
// Create RAG agent
|
|
920
|
-
let agent = AgentBuilder::new(comp_model)
|
|
921
|
-
.preamble(&system_prompt)
|
|
922
|
-
//.dynamic_context(top_k, vector_store.index(embedding_model))
|
|
923
|
-
.temperature(temperature)
|
|
924
|
-
.additional_params(additional)
|
|
925
|
-
.build();
|
|
926
|
-
|
|
927
|
-
let response = agent.prompt(user_input).await.expect("Failed to prompt ollama");
|
|
928
|
-
|
|
929
|
-
//println!("Ollama: {}", response);
|
|
930
|
-
let result = response.replace("json", "").replace("```", "");
|
|
931
|
-
//println!("result:{}", result);
|
|
932
|
-
let json_value: Value = serde_json::from_str(&result).expect("REASON");
|
|
933
|
-
//println!("Classification result:{}", json_value);
|
|
934
|
-
|
|
935
|
-
let final_llm_json;
|
|
936
|
-
match llm_backend_type {
|
|
937
|
-
llm_backend::Ollama() => {
|
|
938
|
-
let json_value2: Value = serde_json::from_str(&json_value["content"].to_string()).expect("REASON2");
|
|
939
|
-
let json_value3: Value = serde_json::from_str(&json_value2.as_str().unwrap()).expect("REASON3");
|
|
940
|
-
final_llm_json = json_value3.to_string()
|
|
941
|
-
}
|
|
942
|
-
llm_backend::Sj() => {
|
|
943
|
-
let json_value2: Value =
|
|
944
|
-
serde_json::from_str(&json_value[0]["generated_text"].to_string()).expect("REASON2");
|
|
945
|
-
//println!("json_value2:{}", json_value2.as_str().unwrap());
|
|
946
|
-
let json_value3: Value = serde_json::from_str(&json_value2.as_str().unwrap()).expect("REASON3");
|
|
947
|
-
//println!("Classification result:{}", json_value3);
|
|
948
|
-
final_llm_json = json_value3.to_string()
|
|
949
|
-
}
|
|
950
|
-
}
|
|
951
|
-
//println!("final_llm_json:{}", final_llm_json);
|
|
952
|
-
let final_validated_json =
|
|
953
|
-
validate_summary_output(final_llm_json.clone(), db_vec, common_genes, ai_json, testing);
|
|
954
|
-
final_validated_json
|
|
955
|
-
}
|
|
956
|
-
None => {
|
|
957
|
-
panic!("summary chart train and test data is not defined in dataset JSON file")
|
|
958
|
-
}
|
|
959
|
-
}
|
|
960
|
-
}
|
|
961
|
-
|
|
962
|
-
/// Returns the literal `"summary"`; used as the serde default for
/// `SummaryType::action`.
fn get_summary_string() -> String {
    String::from("summary")
}
|
|
965
|
-
|
|
966
|
-
//const action: &str = &"summary";
|
|
967
|
-
//const geneExpression: &str = &"geneExpression";
|
|
968
|
-
|
|
969
|
-
#[derive(PartialEq, Debug, Clone, schemars::JsonSchema, serde::Serialize, serde::Deserialize)]
|
|
970
|
-
pub struct SummaryType {
|
|
971
|
-
// Serde uses this for deserialization.
|
|
972
|
-
#[serde(default = "get_summary_string")]
|
|
973
|
-
// Schemars uses this for schema generation.
|
|
974
|
-
#[schemars(rename = "action")]
|
|
975
|
-
action: String,
|
|
976
|
-
summaryterms: Vec<SummaryTerms>,
|
|
977
|
-
filter: Option<Vec<FilterTerm>>,
|
|
978
|
-
message: Option<String>,
|
|
979
|
-
}
|
|
980
|
-
|
|
981
|
-
impl SummaryType {
    /// Returns `self` with `summaryterms` and, when present, `filter` sorted.
    ///
    /// Needed by the tests (test_ai.rs) to compare two `SummaryType` values:
    /// without a canonical order, two values holding the same terms in a
    /// different order would be considered different.
    ///
    /// The filter terms are sorted in place. Sorting a clone of `self.filter`
    /// would leave the returned value unsorted.
    #[allow(dead_code)]
    pub fn sort_summarytype_struct(mut self) -> SummaryType {
        self.summaryterms.sort();
        if let Some(filterterms) = self.filter.as_mut() {
            filterterms.sort();
        }
        self
    }
}
|
|
994
|
-
|
|
995
|
-
#[derive(PartialEq, Eq, Ord, Debug, Clone, schemars::JsonSchema, serde::Serialize, serde::Deserialize)]
|
|
996
|
-
enum SummaryTerms {
|
|
997
|
-
#[allow(non_camel_case_types)]
|
|
998
|
-
clinical(String),
|
|
999
|
-
#[allow(non_camel_case_types)]
|
|
1000
|
-
geneExpression(String),
|
|
1001
|
-
}
|
|
1002
|
-
|
|
1003
|
-
impl PartialOrd for SummaryTerms {
|
|
1004
|
-
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
|
|
1005
|
-
match (self, other) {
|
|
1006
|
-
(SummaryTerms::clinical(_), SummaryTerms::clinical(_)) => Some(std::cmp::Ordering::Equal),
|
|
1007
|
-
(SummaryTerms::geneExpression(_), SummaryTerms::geneExpression(_)) => Some(std::cmp::Ordering::Equal),
|
|
1008
|
-
(SummaryTerms::clinical(_), SummaryTerms::geneExpression(_)) => Some(std::cmp::Ordering::Greater),
|
|
1009
|
-
(SummaryTerms::geneExpression(_), SummaryTerms::clinical(_)) => Some(std::cmp::Ordering::Less),
|
|
1010
|
-
}
|
|
1011
|
-
}
|
|
1012
|
-
}
|
|
1013
|
-
|
|
1014
|
-
#[derive(PartialEq, Eq, Ord, Debug, Clone, schemars::JsonSchema, serde::Serialize, serde::Deserialize)]
|
|
1015
|
-
enum FilterTerm {
|
|
1016
|
-
Categorical(CategoricalFilterTerm),
|
|
1017
|
-
Numeric(NumericFilterTerm),
|
|
1018
|
-
}
|
|
1019
|
-
|
|
1020
|
-
impl PartialOrd for FilterTerm {
|
|
1021
|
-
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
|
|
1022
|
-
match (self, other) {
|
|
1023
|
-
(FilterTerm::Categorical(_), FilterTerm::Categorical(_)) => Some(std::cmp::Ordering::Equal),
|
|
1024
|
-
(FilterTerm::Numeric(_), FilterTerm::Numeric(_)) => Some(std::cmp::Ordering::Equal),
|
|
1025
|
-
(FilterTerm::Categorical(_), FilterTerm::Numeric(_)) => Some(std::cmp::Ordering::Greater),
|
|
1026
|
-
(FilterTerm::Numeric(_), FilterTerm::Categorical(_)) => Some(std::cmp::Ordering::Greater),
|
|
1027
|
-
}
|
|
1028
|
-
}
|
|
1029
|
-
}
|
|
1030
|
-
|
|
1031
|
-
#[derive(PartialEq, Eq, PartialOrd, Ord, Debug, Clone, schemars::JsonSchema, serde::Serialize, serde::Deserialize)]
|
|
1032
|
-
struct CategoricalFilterTerm {
|
|
1033
|
-
term: String,
|
|
1034
|
-
value: String,
|
|
1035
|
-
}
|
|
1036
|
-
|
|
1037
|
-
#[derive(Debug, Clone, schemars::JsonSchema, serde::Serialize, serde::Deserialize)]
|
|
1038
|
-
#[allow(non_snake_case)]
|
|
1039
|
-
struct NumericFilterTerm {
|
|
1040
|
-
term: String,
|
|
1041
|
-
greaterThan: Option<f32>,
|
|
1042
|
-
lessThan: Option<f32>,
|
|
1043
|
-
}
|
|
1044
|
-
|
|
1045
|
-
impl PartialEq for NumericFilterTerm {
|
|
1046
|
-
fn eq(&self, other: &Self) -> bool {
|
|
1047
|
-
let greater_equality: bool;
|
|
1048
|
-
match (self.greaterThan, other.greaterThan) {
|
|
1049
|
-
(Some(a), Some(b)) => greater_equality = (a - b).abs() < 1e-6,
|
|
1050
|
-
(None, None) => greater_equality = true,
|
|
1051
|
-
_ => greater_equality = false,
|
|
1052
|
-
}
|
|
1053
|
-
|
|
1054
|
-
let less_equality: bool;
|
|
1055
|
-
match (self.lessThan, other.lessThan) {
|
|
1056
|
-
(Some(a), Some(b)) => less_equality = (a - b).abs() < 1e-6,
|
|
1057
|
-
(None, None) => less_equality = true,
|
|
1058
|
-
_ => less_equality = false,
|
|
1059
|
-
}
|
|
1060
|
-
|
|
1061
|
-
if greater_equality == true && less_equality == true {
|
|
1062
|
-
true
|
|
1063
|
-
} else {
|
|
1064
|
-
false
|
|
1065
|
-
}
|
|
1066
|
-
}
|
|
1067
|
-
}
|
|
1068
|
-
|
|
1069
|
-
impl Eq for NumericFilterTerm {}
|
|
1070
|
-
|
|
1071
|
-
impl PartialOrd for NumericFilterTerm {
|
|
1072
|
-
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
|
|
1073
|
-
if self.greaterThan < other.greaterThan {
|
|
1074
|
-
Some(std::cmp::Ordering::Less)
|
|
1075
|
-
} else if self.greaterThan > other.greaterThan {
|
|
1076
|
-
Some(std::cmp::Ordering::Greater)
|
|
1077
|
-
} else if self.lessThan < other.lessThan {
|
|
1078
|
-
Some(std::cmp::Ordering::Less)
|
|
1079
|
-
} else if self.lessThan > other.lessThan {
|
|
1080
|
-
Some(std::cmp::Ordering::Greater)
|
|
1081
|
-
} else {
|
|
1082
|
-
Some(std::cmp::Ordering::Equal)
|
|
1083
|
-
}
|
|
1084
|
-
}
|
|
1085
|
-
}
|
|
1086
|
-
|
|
1087
|
-
impl Ord for NumericFilterTerm {
|
|
1088
|
-
fn cmp(&self, other: &Self) -> std::cmp::Ordering {
|
|
1089
|
-
self.partial_cmp(other).unwrap()
|
|
1090
|
-
}
|
|
1091
|
-
}
|
|
1092
|
-
|
|
1093
|
-
fn validate_summary_output(
|
|
1094
|
-
raw_llm_json: String,
|
|
1095
|
-
db_vec: Vec<DbRows>,
|
|
1096
|
-
common_genes: Vec<String>,
|
|
1097
|
-
ai_json: &AiJsonFormat,
|
|
1098
|
-
testing: bool,
|
|
1099
|
-
) -> String {
|
|
1100
|
-
let json_value: SummaryType =
|
|
1101
|
-
serde_json::from_str(&raw_llm_json).expect("Did not get a valid JSON of type {action: summary, summaryterms:[{clinical: term1}, {geneExpression: gene}], filter:[{term: term1, value: value1}]} from the LLM");
|
|
1102
|
-
let mut message: String = String::from("");
|
|
1103
|
-
match json_value.message {
|
|
1104
|
-
Some(mes) => {
|
|
1105
|
-
message = message + &mes; // Append any message given by the LLM
|
|
1106
|
-
}
|
|
1107
|
-
None => {}
|
|
1108
|
-
}
|
|
1109
|
-
|
|
1110
|
-
let mut new_json: Value; // New JSON value that will contain items of the final validated JSON
|
|
1111
|
-
if json_value.action != String::from("summary") {
|
|
1112
|
-
message = message + &"Did not return a summary action";
|
|
1113
|
-
new_json = serde_json::json!(null);
|
|
1114
|
-
} else {
|
|
1115
|
-
new_json = serde_json::from_str(&"{\"action\":\"summary\"}").expect("Not a valid JSON");
|
|
1116
|
-
}
|
|
1117
|
-
|
|
1118
|
-
let mut validated_summary_terms = Vec::<SummaryTerms>::new();
|
|
1119
|
-
let mut summary_terms_tobe_removed = Vec::<SummaryTerms>::new();
|
|
1120
|
-
for sum_term in &json_value.summaryterms {
|
|
1121
|
-
match sum_term {
|
|
1122
|
-
SummaryTerms::clinical(clin) => {
|
|
1123
|
-
let term_verification = verify_json_field(clin, &db_vec);
|
|
1124
|
-
if Some(term_verification.correct_field.clone()).is_some()
|
|
1125
|
-
&& term_verification.correct_value.clone().is_none()
|
|
1126
|
-
{
|
|
1127
|
-
match term_verification.correct_field {
|
|
1128
|
-
Some(tm) => validated_summary_terms.push(SummaryTerms::clinical(tm)),
|
|
1129
|
-
None => {
|
|
1130
|
-
message = message + &"'" + &clin + &"'" + &" not found in db.";
|
|
1131
|
-
}
|
|
1132
|
-
}
|
|
1133
|
-
} else if Some(term_verification.correct_field.clone()).is_some()
|
|
1134
|
-
&& Some(term_verification.correct_value.clone()).is_some()
|
|
1135
|
-
{
|
|
1136
|
-
message = message
|
|
1137
|
-
+ &term_verification.correct_value.unwrap()
|
|
1138
|
-
+ &"is a value of "
|
|
1139
|
-
+ &term_verification.correct_field.unwrap()
|
|
1140
|
-
+ &".";
|
|
1141
|
-
}
|
|
1142
|
-
}
|
|
1143
|
-
SummaryTerms::geneExpression(gene) => {
|
|
1144
|
-
match ai_json.hasGeneExpression {
|
|
1145
|
-
true => {
|
|
1146
|
-
let mut num_gene_verification = 0;
|
|
1147
|
-
for common_gene in &common_genes {
|
|
1148
|
-
// Comparing predicted gene against the common gene
|
|
1149
|
-
if common_gene == gene {
|
|
1150
|
-
num_gene_verification += 1;
|
|
1151
|
-
validated_summary_terms.push(SummaryTerms::geneExpression(String::from(gene)));
|
|
1152
|
-
}
|
|
1153
|
-
}
|
|
1154
|
-
|
|
1155
|
-
if num_gene_verification == 0 || common_genes.len() == 0 {
|
|
1156
|
-
if message.to_lowercase().contains(&gene.to_lowercase()) { // Check if the LLM has already added the message, if not then add it
|
|
1157
|
-
} else {
|
|
1158
|
-
message = message + &"'" + &gene + &"'" + &" not found in genedb.";
|
|
1159
|
-
}
|
|
1160
|
-
}
|
|
1161
|
-
}
|
|
1162
|
-
false => {
|
|
1163
|
-
let missing_gene_data: &str = "gene expression is not supported for this dataset";
|
|
1164
|
-
if message.to_lowercase().contains(&missing_gene_data.to_lowercase()) { // Check if the LLM has already added the message, if not then add it
|
|
1165
|
-
} else {
|
|
1166
|
-
message = message + &"Gene expression not supported for this dataset";
|
|
1167
|
-
}
|
|
1168
|
-
}
|
|
1169
|
-
}
|
|
1170
|
-
}
|
|
1171
|
-
}
|
|
1172
|
-
}
|
|
1173
|
-
|
|
1174
|
-
let mut pp_plot_json: Value; // The PP compliant plot JSON
|
|
1175
|
-
pp_plot_json = serde_json::from_str(&"{\"chartType\":\"summary\"}").expect("Not a valid JSON");
|
|
1176
|
-
match &json_value.filter {
|
|
1177
|
-
Some(filter_terms_array) => {
|
|
1178
|
-
let mut validated_filter_terms = Vec::<FilterTerm>::new();
|
|
1179
|
-
for parsed_filter_term in filter_terms_array {
|
|
1180
|
-
match parsed_filter_term {
|
|
1181
|
-
FilterTerm::Categorical(categorical) => {
|
|
1182
|
-
let term_verification = verify_json_field(&categorical.term, &db_vec);
|
|
1183
|
-
let mut value_verification: Option<String> = None;
|
|
1184
|
-
for item in &db_vec {
|
|
1185
|
-
if &item.name == &categorical.term {
|
|
1186
|
-
for val in &item.values {
|
|
1187
|
-
if &categorical.value == val {
|
|
1188
|
-
value_verification = Some(val.clone());
|
|
1189
|
-
break;
|
|
1190
|
-
}
|
|
1191
|
-
}
|
|
1192
|
-
}
|
|
1193
|
-
if value_verification != None {
|
|
1194
|
-
break;
|
|
1195
|
-
}
|
|
1196
|
-
}
|
|
1197
|
-
if term_verification.correct_field.is_some() && value_verification.is_some() {
|
|
1198
|
-
let verified_filter = CategoricalFilterTerm {
|
|
1199
|
-
term: term_verification.correct_field.clone().unwrap(),
|
|
1200
|
-
value: value_verification.clone().unwrap(),
|
|
1201
|
-
};
|
|
1202
|
-
let categorical_filter_term: FilterTerm = FilterTerm::Categorical(verified_filter);
|
|
1203
|
-
validated_filter_terms.push(categorical_filter_term);
|
|
1204
|
-
}
|
|
1205
|
-
if term_verification.correct_field.is_none() {
|
|
1206
|
-
message = message + &"'" + &categorical.term + &"' filter term not found in db";
|
|
1207
|
-
}
|
|
1208
|
-
if value_verification.is_none() {
|
|
1209
|
-
message = message
|
|
1210
|
-
+ &"'"
|
|
1211
|
-
+ &categorical.value
|
|
1212
|
-
+ &"' filter value not found for filter field '"
|
|
1213
|
-
+ &categorical.term
|
|
1214
|
-
+ "' in db";
|
|
1215
|
-
}
|
|
1216
|
-
}
|
|
1217
|
-
FilterTerm::Numeric(numeric) => {
|
|
1218
|
-
let term_verification = verify_json_field(&numeric.term, &db_vec);
|
|
1219
|
-
if term_verification.correct_field.is_none() {
|
|
1220
|
-
message = message + &"'" + &numeric.term + &"' filter term not found in db";
|
|
1221
|
-
} else {
|
|
1222
|
-
let numeric_filter_term: FilterTerm = FilterTerm::Numeric(numeric.clone());
|
|
1223
|
-
validated_filter_terms.push(numeric_filter_term);
|
|
1224
|
-
}
|
|
1225
|
-
}
|
|
1226
|
-
}
|
|
1227
|
-
}
|
|
1228
|
-
|
|
1229
|
-
for summary_term in &validated_summary_terms {
|
|
1230
|
-
match summary_term {
|
|
1231
|
-
SummaryTerms::clinical(clinicial_term) => {
|
|
1232
|
-
for filter_term in &validated_filter_terms {
|
|
1233
|
-
match filter_term {
|
|
1234
|
-
FilterTerm::Categorical(categorical) => {
|
|
1235
|
-
if &categorical.term == clinicial_term {
|
|
1236
|
-
summary_terms_tobe_removed.push(summary_term.clone());
|
|
1237
|
-
}
|
|
1238
|
-
}
|
|
1239
|
-
FilterTerm::Numeric(numeric) => {
|
|
1240
|
-
if &numeric.term == clinicial_term {
|
|
1241
|
-
summary_terms_tobe_removed.push(summary_term.clone());
|
|
1242
|
-
}
|
|
1243
|
-
}
|
|
1244
|
-
}
|
|
1245
|
-
}
|
|
1246
|
-
}
|
|
1247
|
-
SummaryTerms::geneExpression(gene) => {
|
|
1248
|
-
for filter_term in &validated_filter_terms {
|
|
1249
|
-
match filter_term {
|
|
1250
|
-
FilterTerm::Categorical(categorical) => {
|
|
1251
|
-
if &categorical.term == gene {
|
|
1252
|
-
summary_terms_tobe_removed.push(summary_term.clone());
|
|
1253
|
-
}
|
|
1254
|
-
}
|
|
1255
|
-
FilterTerm::Numeric(numeric) => {
|
|
1256
|
-
if &numeric.term == gene {
|
|
1257
|
-
summary_terms_tobe_removed.push(summary_term.clone());
|
|
1258
|
-
}
|
|
1259
|
-
}
|
|
1260
|
-
}
|
|
1261
|
-
}
|
|
1262
|
-
}
|
|
1263
|
-
}
|
|
1264
|
-
}
|
|
1265
|
-
|
|
1266
|
-
if validated_filter_terms.len() > 0 {
|
|
1267
|
-
if testing == true {
|
|
1268
|
-
if let Some(obj) = new_json.as_object_mut() {
|
|
1269
|
-
obj.insert(String::from("filter"), serde_json::json!(validated_filter_terms));
|
|
1270
|
-
}
|
|
1271
|
-
} else {
|
|
1272
|
-
let mut validated_filter_terms_PP: String = "[".to_string();
|
|
1273
|
-
let mut filter_hits = 0;
|
|
1274
|
-
for validated_term in validated_filter_terms {
|
|
1275
|
-
match validated_term {
|
|
1276
|
-
FilterTerm::Categorical(categorical_filter) => {
|
|
1277
|
-
let string_json = "{\"term\":\"".to_string()
|
|
1278
|
-
+ &categorical_filter.term
|
|
1279
|
-
+ &"\", \"category\":\""
|
|
1280
|
-
+ &categorical_filter.value
|
|
1281
|
-
+ &"\"},";
|
|
1282
|
-
validated_filter_terms_PP += &string_json;
|
|
1283
|
-
}
|
|
1284
|
-
FilterTerm::Numeric(numeric_filter) => {
|
|
1285
|
-
let string_json;
|
|
1286
|
-
if numeric_filter.greaterThan.is_some() && numeric_filter.lessThan.is_none() {
|
|
1287
|
-
string_json = "{\"term\":\"".to_string()
|
|
1288
|
-
+ &numeric_filter.term
|
|
1289
|
-
+ &"\", \"gt\":\""
|
|
1290
|
-
+ &numeric_filter.greaterThan.unwrap().to_string()
|
|
1291
|
-
+ &"\"},";
|
|
1292
|
-
} else if numeric_filter.greaterThan.is_none() && numeric_filter.lessThan.is_some() {
|
|
1293
|
-
string_json = "{\"term\":\"".to_string()
|
|
1294
|
-
+ &numeric_filter.term
|
|
1295
|
-
+ &"\", \"lt\":\""
|
|
1296
|
-
+ &numeric_filter.lessThan.unwrap().to_string()
|
|
1297
|
-
+ &"\"},";
|
|
1298
|
-
} else if numeric_filter.greaterThan.is_some() && numeric_filter.lessThan.is_some() {
|
|
1299
|
-
string_json = "{\"term\":\"".to_string()
|
|
1300
|
-
+ &numeric_filter.term
|
|
1301
|
-
+ &"\", \"lt\":\""
|
|
1302
|
-
+ &numeric_filter.lessThan.unwrap().to_string()
|
|
1303
|
-
+ &"\", \"gt\":\""
|
|
1304
|
-
+ &numeric_filter.greaterThan.unwrap().to_string()
|
|
1305
|
-
+ &"\"},";
|
|
1306
|
-
} else {
|
|
1307
|
-
// When both greater and less than are none
|
|
1308
|
-
panic!(
|
|
1309
|
-
"Numeric filter term {} is missing both greater than and less than values. One of them must be defined",
|
|
1310
|
-
&numeric_filter.term
|
|
1311
|
-
);
|
|
1312
|
-
}
|
|
1313
|
-
validated_filter_terms_PP += &string_json;
|
|
1314
|
-
}
|
|
1315
|
-
};
|
|
1316
|
-
filter_hits += 1;
|
|
1317
|
-
}
|
|
1318
|
-
println!("validated_filter_terms_PP:{}", validated_filter_terms_PP);
|
|
1319
|
-
if filter_hits > 0 {
|
|
1320
|
-
validated_filter_terms_PP.pop();
|
|
1321
|
-
validated_filter_terms_PP += &"]";
|
|
1322
|
-
if let Some(obj) = pp_plot_json.as_object_mut() {
|
|
1323
|
-
obj.insert(
|
|
1324
|
-
String::from("simpleFilter"),
|
|
1325
|
-
serde_json::from_str(&validated_filter_terms_PP).expect("Not a valid JSON"),
|
|
1326
|
-
);
|
|
1327
|
-
}
|
|
1328
|
-
}
|
|
1329
|
-
}
|
|
1330
|
-
}
|
|
1331
|
-
}
|
|
1332
|
-
None => {}
|
|
1333
|
-
}
|
|
1334
|
-
|
|
1335
|
-
// Removing terms that are found both in filter term as well summary
|
|
1336
|
-
let mut validated_summary_terms_final = Vec::<SummaryTerms>::new();
|
|
1337
|
-
|
|
1338
|
-
let mut sum_iter = 0;
|
|
1339
|
-
let mut pp_json: Value; // New JSON value that will contain items of the final PP compliant JSON
|
|
1340
|
-
pp_json = serde_json::from_str(&"{\"type\":\"plot\"}").expect("Not a valid JSON");
|
|
1341
|
-
|
|
1342
|
-
for summary_term in &validated_summary_terms {
|
|
1343
|
-
let mut hit = 0;
|
|
1344
|
-
match summary_term {
|
|
1345
|
-
SummaryTerms::clinical(clinical_term) => {
|
|
1346
|
-
for summary_term2 in &summary_terms_tobe_removed {
|
|
1347
|
-
match summary_term2 {
|
|
1348
|
-
SummaryTerms::clinical(clinical_term2) => {
|
|
1349
|
-
if clinical_term == clinical_term2 {
|
|
1350
|
-
hit = 1;
|
|
1351
|
-
}
|
|
1352
|
-
}
|
|
1353
|
-
SummaryTerms::geneExpression(gene2) => {
|
|
1354
|
-
if clinical_term == gene2 {
|
|
1355
|
-
hit = 1;
|
|
1356
|
-
}
|
|
1357
|
-
}
|
|
1358
|
-
}
|
|
1359
|
-
}
|
|
1360
|
-
}
|
|
1361
|
-
SummaryTerms::geneExpression(gene) => {
|
|
1362
|
-
for summary_term2 in &summary_terms_tobe_removed {
|
|
1363
|
-
match summary_term2 {
|
|
1364
|
-
SummaryTerms::clinical(clinical_term2) => {
|
|
1365
|
-
if gene == clinical_term2 {
|
|
1366
|
-
hit = 1;
|
|
1367
|
-
}
|
|
1368
|
-
}
|
|
1369
|
-
SummaryTerms::geneExpression(gene2) => {
|
|
1370
|
-
if gene == gene2 {
|
|
1371
|
-
hit = 1;
|
|
1372
|
-
}
|
|
1373
|
-
}
|
|
1374
|
-
}
|
|
1375
|
-
}
|
|
1376
|
-
}
|
|
1377
|
-
}
|
|
1378
|
-
|
|
1379
|
-
if hit == 0 {
|
|
1380
|
-
let mut termidpp: Option<TermIDPP> = None;
|
|
1381
|
-
let mut geneexp: Option<GeneExpressionPP> = None;
|
|
1382
|
-
match summary_term {
|
|
1383
|
-
SummaryTerms::clinical(clinical_term) => {
|
|
1384
|
-
termidpp = Some(TermIDPP {
|
|
1385
|
-
id: clinical_term.to_string(),
|
|
1386
|
-
});
|
|
1387
|
-
}
|
|
1388
|
-
SummaryTerms::geneExpression(gene) => {
|
|
1389
|
-
geneexp = Some(GeneExpressionPP {
|
|
1390
|
-
gene: gene.to_string(),
|
|
1391
|
-
r#type: "geneExpression".to_string(),
|
|
1392
|
-
});
|
|
1393
|
-
}
|
|
1394
|
-
}
|
|
1395
|
-
if sum_iter == 0 {
|
|
1396
|
-
if termidpp.is_some() {
|
|
1397
|
-
if let Some(obj) = pp_plot_json.as_object_mut() {
|
|
1398
|
-
obj.insert(String::from("term"), serde_json::json!(Some(termidpp)));
|
|
1399
|
-
}
|
|
1400
|
-
}
|
|
1401
|
-
|
|
1402
|
-
if geneexp.is_some() {
|
|
1403
|
-
let gene_term = GeneTerm { term: geneexp.unwrap() };
|
|
1404
|
-
if let Some(obj) = pp_plot_json.as_object_mut() {
|
|
1405
|
-
obj.insert(String::from("term"), serde_json::json!(gene_term));
|
|
1406
|
-
}
|
|
1407
|
-
}
|
|
1408
|
-
} else if sum_iter == 1 {
|
|
1409
|
-
if termidpp.is_some() {
|
|
1410
|
-
if let Some(obj) = pp_plot_json.as_object_mut() {
|
|
1411
|
-
obj.insert(String::from("term2"), serde_json::json!(Some(termidpp)));
|
|
1412
|
-
}
|
|
1413
|
-
}
|
|
1414
|
-
|
|
1415
|
-
if geneexp.is_some() {
|
|
1416
|
-
let gene_term = GeneTerm { term: geneexp.unwrap() };
|
|
1417
|
-
if let Some(obj) = pp_plot_json.as_object_mut() {
|
|
1418
|
-
obj.insert(String::from("term2"), serde_json::json!(gene_term));
|
|
1419
|
-
}
|
|
1420
|
-
}
|
|
1421
|
-
}
|
|
1422
|
-
validated_summary_terms_final.push(summary_term.clone())
|
|
1423
|
-
}
|
|
1424
|
-
sum_iter += 1
|
|
1425
|
-
}
|
|
1426
|
-
|
|
1427
|
-
if let Some(obj) = new_json.as_object_mut() {
|
|
1428
|
-
obj.insert(
|
|
1429
|
-
String::from("summaryterms"),
|
|
1430
|
-
serde_json::json!(validated_summary_terms_final),
|
|
1431
|
-
);
|
|
1432
|
-
}
|
|
1433
|
-
|
|
1434
|
-
if let Some(obj) = pp_json.as_object_mut() {
|
|
1435
|
-
// The `if let` ensures we only proceed if the top-level JSON is an object.
|
|
1436
|
-
// Append a new string field.
|
|
1437
|
-
obj.insert(String::from("plot"), serde_json::json!(pp_plot_json));
|
|
1438
|
-
}
|
|
1439
|
-
|
|
1440
|
-
let mut err_json: Value; // Error JSON containing the error message (if present)
|
|
1441
|
-
if message.len() > 0 {
|
|
1442
|
-
if testing == false {
|
|
1443
|
-
err_json = serde_json::from_str(&"{\"type\":\"html\"}").expect("Not a valid JSON");
|
|
1444
|
-
if let Some(obj) = err_json.as_object_mut() {
|
|
1445
|
-
// The `if let` ensures we only proceed if the top-level JSON is an object.
|
|
1446
|
-
// Append a new string field.
|
|
1447
|
-
obj.insert(String::from("html"), serde_json::json!(message));
|
|
1448
|
-
};
|
|
1449
|
-
serde_json::to_string(&err_json).unwrap()
|
|
1450
|
-
} else {
|
|
1451
|
-
if let Some(obj) = new_json.as_object_mut() {
|
|
1452
|
-
// The `if let` ensures we only proceed if the top-level JSON is an object.
|
|
1453
|
-
// Append a new string field.
|
|
1454
|
-
obj.insert(String::from("message"), serde_json::json!(message));
|
|
1455
|
-
};
|
|
1456
|
-
serde_json::to_string(&new_json).unwrap()
|
|
1457
|
-
}
|
|
1458
|
-
} else {
|
|
1459
|
-
if testing == true {
|
|
1460
|
-
// When testing script output native LLM JSON
|
|
1461
|
-
serde_json::to_string(&new_json).unwrap()
|
|
1462
|
-
} else {
|
|
1463
|
-
// When in production output PP compliant JSON
|
|
1464
|
-
serde_json::to_string(&pp_json).unwrap()
|
|
1465
|
-
}
|
|
1466
|
-
}
|
|
1467
|
-
}
|
|
1468
|
-
|
|
1469
|
-
/// Supplies the default `type` tag for a gene-expression term.
///
/// Referenced by name from the `#[serde(default = "getGeneExpression")]`
/// attribute on `GeneExpressionPP`, so the identifier must not be renamed.
fn getGeneExpression() -> String {
    String::from("geneExpression")
}
|
|
1472
|
-
|
|
1473
|
-
#[derive(PartialEq, Debug, Clone, schemars::JsonSchema, serde::Serialize, serde::Deserialize)]
|
|
1474
|
-
struct TermIDPP {
|
|
1475
|
-
id: String,
|
|
1476
|
-
}
|
|
1477
|
-
|
|
1478
|
-
#[derive(PartialEq, Debug, Clone, schemars::JsonSchema, serde::Serialize, serde::Deserialize)]
|
|
1479
|
-
struct GeneTerm {
|
|
1480
|
-
term: GeneExpressionPP,
|
|
1481
|
-
}
|
|
1482
|
-
|
|
1483
|
-
#[derive(PartialEq, Debug, Clone, schemars::JsonSchema, serde::Serialize, serde::Deserialize)]
|
|
1484
|
-
struct GeneExpressionPP {
|
|
1485
|
-
gene: String,
|
|
1486
|
-
// Serde uses this for deserialization.
|
|
1487
|
-
#[serde(default = "getGeneExpression")]
|
|
1488
|
-
r#type: String,
|
|
1489
|
-
}
|
|
1490
|
-
|
|
1491
|
-
/// Outcome of checking an LLM-proposed field name against the database rows.
#[derive(Debug, Clone)]
struct VerifiedField {
    /// Name of the correct field, when one could be determined.
    correct_field: Option<String>,
    /// The matched value, set when the proposed field name turned out to be one
    /// of the values listed under a field rather than a field name itself.
    correct_value: Option<String>,
    /// Candidate fields when several match an incomplete query.
    _probable_fields: Option<Vec<String>>,
}
|
|
1497
|
-
|
|
1498
|
-
fn verify_json_field(llm_field_name: &str, db_vec: &Vec<DbRows>) -> VerifiedField {
|
|
1499
|
-
// Check if llm_field_name exists or not in db name field
|
|
1500
|
-
let verified_result: VerifiedField;
|
|
1501
|
-
if db_vec.iter().any(|item| item.name == llm_field_name) {
|
|
1502
|
-
//println!("Found \"{}\" in db", llm_field_name);
|
|
1503
|
-
verified_result = VerifiedField {
|
|
1504
|
-
correct_field: Some(String::from(llm_field_name)),
|
|
1505
|
-
correct_value: None,
|
|
1506
|
-
_probable_fields: None,
|
|
1507
|
-
};
|
|
1508
|
-
} else {
|
|
1509
|
-
println!("Did not find \"{}\" in db", llm_field_name);
|
|
1510
|
-
// Check to see if llm_field_name exists as values under any of the fields
|
|
1511
|
-
let (search_field, search_val) = verify_json_value(llm_field_name, &db_vec);
|
|
1512
|
-
|
|
1513
|
-
match search_field {
|
|
1514
|
-
Some(x) => {
|
|
1515
|
-
verified_result = VerifiedField {
|
|
1516
|
-
correct_field: Some(String::from(x)),
|
|
1517
|
-
correct_value: search_val,
|
|
1518
|
-
_probable_fields: None,
|
|
1519
|
-
};
|
|
1520
|
-
}
|
|
1521
|
-
None => {
|
|
1522
|
-
// Incorrect field found neither in any of the fields nor any of the values. This will then invoke embedding match across all the fields and their corresponding values
|
|
1523
|
-
|
|
1524
|
-
let mut search_terms = Vec::<String>::new();
|
|
1525
|
-
search_terms.push(String::from(llm_field_name)); // Added the incorrect field item to the search
|
|
1526
|
-
verified_result = VerifiedField {
|
|
1527
|
-
correct_field: None,
|
|
1528
|
-
correct_value: None,
|
|
1529
|
-
_probable_fields: None,
|
|
1530
|
-
};
|
|
1531
|
-
}
|
|
1532
|
-
}
|
|
1533
|
-
}
|
|
1534
|
-
verified_result
|
|
1535
|
-
}
|
|
1536
|
-
|
|
1537
|
-
fn verify_json_value(llm_value_name: &str, db_vec: &Vec<DbRows>) -> (Option<String>, Option<String>) {
|
|
1538
|
-
let mut search_field: Option<String> = None;
|
|
1539
|
-
let mut search_val: Option<String> = None;
|
|
1540
|
-
for item in db_vec {
|
|
1541
|
-
for val in &item.values {
|
|
1542
|
-
if llm_value_name == val {
|
|
1543
|
-
search_field = Some(item.name.clone());
|
|
1544
|
-
search_val = Some(String::from(val));
|
|
1545
|
-
break;
|
|
1546
|
-
}
|
|
1547
|
-
}
|
|
1548
|
-
match search_field {
|
|
1549
|
-
Some(_) => break,
|
|
1550
|
-
None => {}
|
|
1551
|
-
}
|
|
1552
|
-
}
|
|
1553
|
-
(search_field, search_val)
|
|
1554
|
-
}
|