@sjcrh/proteinpaint-rust 2.149.0 → 2.152.1-0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/Cargo.toml +1 -1
- package/README.md +5 -0
- package/package.json +1 -1
- package/src/DEanalysis.rs +110 -311
- package/src/aichatbot.rs +770 -136
- package/src/ollama.rs +1108 -0
- package/src/sjprovider.rs +52 -11
- package/src/test_ai.rs +168 -0
package/src/sjprovider.rs
CHANGED
@@ -73,6 +73,15 @@ impl Client {
     pub fn builder() -> ClientBuilder<'static> {
         ClientBuilder::new()
     }
+
+    pub fn completion_model(&self, model: &str) -> CompletionModel {
+        CompletionModel::new(self.clone(), model)
+    }
+
+    pub fn embedding_model(&self, model: &str) -> EmbeddingModel {
+        EmbeddingModel::new(self.clone(), model, 0, self.base_url.to_string())
+    }
+
     pub fn new() -> Self {
         Self::builder().build().expect("Myprovider client should build")
     }
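A minimal usage sketch of the two new convenience constructors, assuming the builder API shown in the test code later in this diff (`base_url` on `ClientBuilder`); the URL and model names are invented:

    let client = Client::builder()
        .base_url("http://localhost:8080")
        .build()
        .expect("client should build");
    // completion_model wraps CompletionModel::new with a clone of the client
    let comp_model = client.completion_model("some-completion-model");
    // embedding_model passes ndims = 0 and the client's base URL through
    let embedding_model = client.embedding_model("some-embedding-model");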
@@ -129,11 +138,16 @@ impl EmbeddingsClient for Client {
 
 impl VerifyClient for Client {
     async fn verify(&self) -> Result<(), VerifyError> {
-        let response = self
+        let response = self
+            .get("api/tags")
+            .expect("Failed to build request")
+            .send()
+            .await
+            .unwrap();
         match response.status() {
             reqwest::StatusCode::OK => Ok(()),
             _ => {
-                response.error_for_status()
+                response.error_for_status().unwrap();
                 Ok(())
             }
         }
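Behavioral note on this hunk: verify() now issues a GET against api/tags (the path Ollama uses to list local models, which this provider evidently mirrors) and unwraps the request build, send, and status check, so any failure panics rather than surfacing as a VerifyError. A caller sketch, assuming this module's Client is in scope:

    // The Err arm is effectively unreachable after this change, because
    // failures inside verify() panic via expect()/unwrap() first.
    match client.verify().await {
        Ok(()) => println!("provider reachable"),
        Err(e) => eprintln!("verify error: {e:?}"),
    }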
@@ -262,7 +276,7 @@ impl embeddings::EmbeddingModel for EmbeddingModel {
 
         if response.status().is_success() {
             //println!("response.json:{:?}", response.text().await?);
-            let json_data: Value = serde_json::from_str(&response.text().await
+            let json_data: Value = serde_json::from_str(&response.text().await.unwrap())?;
             let emb = json_data["outputs"].as_array().unwrap();
             //.unwrap_or(&vec![serde_json::Value::String(
             //    "No embeddings found in json output".to_string(),
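A self-contained sketch of the parse these lines perform, assuming the provider wraps embeddings in an "outputs" array (the only field the diff shows; the body below is invented):

    fn parse_embedding(body: &str) -> Vec<f64> {
        let json_data: serde_json::Value = serde_json::from_str(body).unwrap();
        // as_array().unwrap() panics if "outputs" is missing or not an array,
        // matching the unwrap-heavy style of the surrounding code
        json_data["outputs"]
            .as_array()
            .unwrap()
            .iter()
            .filter_map(|v| v.as_f64())
            .collect()
    }

    // parse_embedding(r#"{"outputs": [0.12, 0.34, 0.56]}"#) == vec![0.12, 0.34, 0.56]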
@@ -481,12 +495,39 @@ impl CompletionModel {
             panic!("max_new_tokens and top_p not found!");
         };
 
+        let mut user_query = "";
+        let mut system_prompt = "";
+        for message in &full_history {
+            match message {
+                self::Message::User {
+                    content: text,
+                    images: _,
+                    name: _,
+                } => {
+                    //println!("User:{:?}", text);
+                    user_query = text;
+                }
+                self::Message::System {
+                    content: text,
+                    images: _,
+                    name: _,
+                } => {
+                    system_prompt = text;
+                    //println!("System:{:?}", text);
+                }
+                self::Message::Assistant { content: _, id: _ } => {}
+                self::Message::ToolResult { content: _, name: _ } => {}
+            }
+        }
+        let final_text = system_prompt.replace(&"{question}", &user_query);
+
+        //println!("final_text:{:?}", final_text);
         let mut request_payload = json!({
             "inputs":[
                 {
                     "model_name": self.model,
                     "inputs": {
-                        "text":
+                        "text": final_text,
                         "max_new_tokens": max_new_tokens,
                         "temperature": completion_request.temperature,
                         "top_p": top_p
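The loop above keeps only the last User and last System message it encounters (each iteration overwrites the variable), then splices the user text into the system prompt wherever the literal {question} placeholder appears. A standalone illustration with invented strings:

    let system_prompt = "Classify the query into plot categories. \nQuestion= {question} \nanswer";
    let user_query = "Generate DE plot for men vs women";
    let final_text = system_prompt.replace("{question}", user_query);
    assert!(final_text.contains("Generate DE plot for men vs women"));
    // If the system prompt lacks the placeholder, the user text is silently dropped.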
@@ -612,7 +653,7 @@ impl completion::CompletionModel for CompletionModel {
             let chunk = match chunk_result {
                 Ok(c) => c,
                 Err(e) => {
-                    yield Err(CompletionError::
+                    yield Err(CompletionError::RequestError(e.into()));
                     break;
                 }
             };
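This completes a previously truncated yield-then-break pattern: one error item is surfaced to the consumer and the stream terminates. A self-contained sketch of the same pattern with simplified stand-in types, assuming the usual async-stream/futures setup that rig-based providers use:

    use futures::StreamExt;

    #[tokio::main]
    async fn main() {
        let mut upstream = futures::stream::iter(vec![Ok("a"), Err("boom"), Ok("b")]);
        let out = async_stream::stream! {
            while let Some(chunk_result) = upstream.next().await {
                match chunk_result {
                    Ok(c) => {
                        yield Ok(c);
                    }
                    Err(e) => {
                        // Surface the failure as a stream item, then stop:
                        // Ok("b") is never delivered.
                        yield Err(format!("request error: {e}"));
                        break;
                    }
                }
            }
        };
        futures::pin_mut!(out);
        while let Some(item) = out.next().await {
            println!("{item:?}"); // Ok("a"), then Err("request error: boom")
        }
    }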
@@ -797,7 +838,7 @@ impl ConvertMessage for Message {
                     images.push(data)
                 }
                 rig::message::UserContent::Document(rig::message::Document { data, .. }) => {
-                    texts.push(data)
+                    texts.push(data.to_string())
                 }
                 _ => {} // Audio not supported by Ollama
             }
@@ -993,7 +1034,7 @@ mod tests {
     #[tokio::test]
     #[ignore]
 
-    async fn
+    async fn test_sjprovider_implementation() {
        let user_input = "Generate DE plot for men with weight greater than 30lbs vs women less than 20lbs";
        let serverconfig_file_path = Path::new("../../serverconfig.json");
        let absolute_path = serverconfig_file_path.canonicalize().unwrap();
@@ -1118,17 +1159,17 @@ If a query does not match any of the fields described above, then return JSON wi
         });
 
         // Create RAG agent
-        let agent = AgentBuilder::new(comp_model).preamble("Generate classification for the user query into summary, dge, hierarchial, snv_indel, cnv, variant_calling, sv_fusion and none categories. Return output in JSON with ALWAYS a single word answer { \"answer\": \"dge\" }, that is 'summary' for summary plot, 'dge' for differential gene expression, 'hierarchial' for hierarchial clustering, 'snv_indel' for SNV/Indel, 'cnv' for CNV and 'sv_fusion' for SV/fusion, 'variant_calling' for variant calling, 'surivial' for survival data, 'none' for none of the previously described categories. The answer should always be in lower case").dynamic_context(top_k, vector_store.index(embedding_model)).additional_params(additional).temperature(temperature).build();
+        let agent = AgentBuilder::new(comp_model).preamble("Generate classification for the user query into summary, dge, hierarchial, snv_indel, cnv, variant_calling, sv_fusion and none categories. Return output in JSON with ALWAYS a single word answer { \"answer\": \"dge\" }, that is 'summary' for summary plot, 'dge' for differential gene expression, 'hierarchial' for hierarchial clustering, 'snv_indel' for SNV/Indel, 'cnv' for CNV and 'sv_fusion' for SV/fusion, 'variant_calling' for variant calling, 'surivial' for survival data, 'none' for none of the previously described categories. The answer should always be in lower case. \nQuestion= {question} \nanswer").dynamic_context(top_k, vector_store.index(embedding_model)).additional_params(additional).temperature(temperature).build();
 
         let response = agent.prompt(user_input).await.expect("Failed to prompt myprovider");
 
         //println!("Myprovider: {}", response);
         let result = response.replace("json", "").replace("```", "");
         //println!("result:{}", result);
-        let json_value: Value = serde_json::from_str(&result).expect("
-        let json_value2: Value = serde_json::from_str(&json_value[0]["generated_text"].to_string()).expect("
+        let json_value: Value = serde_json::from_str(&result).expect("REASON2");
+        let json_value2: Value = serde_json::from_str(&json_value[0]["generated_text"].to_string()).expect("REASON3");
         //println!("json_value2:{}", json_value2.as_str().unwrap());
-        let json_value3: Value = serde_json::from_str(&json_value2.as_str().unwrap()).expect("
+        let json_value3: Value = serde_json::from_str(&json_value2.as_str().unwrap()).expect("REASON4");
         assert_eq!(json_value3["answer"].to_string().replace("\"", ""), "dge");
     }
 }
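The three-stage decode at the end of this test implies a response shape like [{"generated_text": "<JSON object encoded as a string>"}]; that shape is inferred from the parsing code, not documented anywhere in the diff. A standalone reproduction of why three serde_json passes are needed:

    fn main() {
        let result = r#"[{"generated_text": "{\"answer\": \"dge\"}"}]"#;
        // Pass 1: the outer array.
        let json_value: serde_json::Value = serde_json::from_str(result).unwrap();
        // Pass 2: to_string() re-serializes the string Value (quotes and
        // escapes intact), so parsing it back yields a string Value again...
        let json_value2: serde_json::Value =
            serde_json::from_str(&json_value[0]["generated_text"].to_string()).unwrap();
        // ...and pass 3 parses the inner text into the actual object.
        let json_value3: serde_json::Value =
            serde_json::from_str(json_value2.as_str().unwrap()).unwrap();
        assert_eq!(json_value3["answer"], "dge");
    }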
package/src/test_ai.rs
ADDED
@@ -0,0 +1,168 @@
+// For capturing output from a test, run: cd .. && cargo test -- --nocapture
+// Ignored tests: cd .. && export RUST_BACKTRACE=full && time cargo test -- --ignored --nocapture
+#[allow(dead_code)]
+fn main() {}
+
+#[cfg(test)]
+mod tests {
+    use serde_json;
+    use std::fs::{self};
+    use std::path::Path;
+
+    #[derive(PartialEq, Debug, Clone, schemars::JsonSchema, serde::Serialize, serde::Deserialize)]
+    struct ServerConfig {
+        tpmasterdir: String,
+        llm_backend: String,
+        sj_apilink: String,
+        sj_comp_model_name: String,
+        sj_embedding_model_name: String,
+        ollama_apilink: String,
+        ollama_comp_model_name: String,
+        ollama_embedding_model_name: String,
+        genomes: Vec<Genomes>,
+    }
+
+    #[derive(PartialEq, Debug, Clone, schemars::JsonSchema, serde::Serialize, serde::Deserialize)]
+    struct Genomes {
+        name: String,
+        datasets: Vec<Dataset>,
+    }
+
+    #[derive(PartialEq, Debug, Clone, schemars::JsonSchema, serde::Serialize, serde::Deserialize)]
+    struct Dataset {
+        name: String,
+        aifiles: Option<String>, // For now aifiles are defined only for certain datasets
+    }
+
+    #[tokio::test]
+    #[ignore]
+    async fn user_prompts() {
+        let temperature: f64 = 0.01;
+        let max_new_tokens: usize = 512;
+        let top_p: f32 = 0.95;
+        let serverconfig_file_path = Path::new("../../serverconfig.json");
+        let absolute_path = serverconfig_file_path.canonicalize().unwrap();
+
+        // Read the file
+        let data = fs::read_to_string(absolute_path).unwrap();
+
+        // Parse the JSON data
+        let serverconfig: ServerConfig = serde_json::from_str(&data).expect("JSON not in serverconfig.json format");
+
+        for genome in &serverconfig.genomes {
+            for dataset in &genome.datasets {
+                match &dataset.aifiles {
+                    Some(ai_json_file) => {
+                        println!("Testing dataset:{}", dataset.name);
+                        let ai_json_file_path = String::from("../../") + ai_json_file;
+                        let ai_json_file = Path::new(&ai_json_file_path);
+
+                        // Read the file
+                        let ai_data = fs::read_to_string(ai_json_file).unwrap();
+                        // Parse the JSON data
+                        let ai_json: super::super::AiJsonFormat =
+                            serde_json::from_str(&ai_data).expect("AI JSON file does not have the correct format");
+                        //println!("ai_json:{:?}", ai_json);
+                        let genedb = String::from(&serverconfig.tpmasterdir) + &"/" + &ai_json.genedb;
+                        let dataset_db = String::from(&serverconfig.tpmasterdir) + &"/" + &ai_json.db;
+                        let llm_backend_name = &serverconfig.llm_backend;
+                        let llm_backend_type: super::super::llm_backend;
+
+                        if llm_backend_name != "ollama" && llm_backend_name != "SJ" {
+                            panic!(
+                                "This code currently supports only Ollama and SJ provider. llm_backend_name must be \"ollama\" or \"SJ\""
+                            );
+                        } else if *llm_backend_name == "ollama".to_string() {
+                            let ollama_host = &serverconfig.ollama_apilink;
+                            let ollama_embedding_model_name = &serverconfig.ollama_embedding_model_name;
+                            let ollama_comp_model_name = &serverconfig.ollama_comp_model_name;
+                            llm_backend_type = super::super::llm_backend::Ollama();
+                            let ollama_client = super::super::ollama::Client::builder()
+                                .base_url(ollama_host)
+                                .build()
+                                .expect("Ollama server not found");
+                            let embedding_model = ollama_client.embedding_model(ollama_embedding_model_name);
+                            let comp_model = ollama_client.completion_model(ollama_comp_model_name);
+
+                            for chart in ai_json.charts.clone() {
+                                match chart {
+                                    super::super::Charts::Summary(testdata) => {
+                                        for ques_ans in testdata.TestData {
+                                            let user_input = ques_ans.question;
+                                            let llm_output = super::super::run_pipeline(
+                                                &user_input,
+                                                comp_model.clone(),
+                                                embedding_model.clone(),
+                                                llm_backend_type.clone(),
+                                                temperature,
+                                                max_new_tokens,
+                                                top_p,
+                                                &dataset_db,
+                                                &genedb,
+                                                &ai_json,
+                                            )
+                                            .await;
+                                            let mut llm_json_value: super::super::SummaryType = serde_json::from_str(&llm_output.unwrap()).expect("Did not get a valid JSON of type {action: summary, summaryterms:[{clinical: term1}, {geneExpression: gene}], filter:[{term: term1, value: value1}]} from the LLM");
+                                            let mut expected_json_value: super::super::SummaryType = serde_json::from_str(&ques_ans.answer).expect("Did not get a valid JSON of type {action: summary, summaryterms:[{clinical: term1}, {geneExpression: gene}], filter:[{term: term1, value: value1}]} from the LLM");
+                                            assert_eq!(
+                                                llm_json_value.sort_summarytype_struct(),
+                                                expected_json_value.sort_summarytype_struct()
+                                            );
+                                        }
+                                    }
+                                    super::super::Charts::DE(_testdata) => {} // To do
+                                }
+                            }
+                        } else if *llm_backend_name == "SJ".to_string() {
+                            let sjprovider_host = &serverconfig.sj_apilink;
+                            let sj_embedding_model_name = &serverconfig.sj_embedding_model_name;
+                            let sj_comp_model_name = &serverconfig.sj_comp_model_name;
+                            llm_backend_type = super::super::llm_backend::Sj();
+                            let sj_client = super::super::sjprovider::Client::builder()
+                                .base_url(sjprovider_host)
+                                .build()
+                                .expect("SJ server not found");
+                            let embedding_model = sj_client.embedding_model(sj_embedding_model_name);
+                            let comp_model = sj_client.completion_model(sj_comp_model_name);
+
+                            for chart in ai_json.charts.clone() {
+                                match chart {
+                                    super::super::Charts::Summary(testdata) => {
+                                        for ques_ans in testdata.TestData {
+                                            let user_input = ques_ans.question;
+                                            if user_input.len() > 0 {
+                                                let llm_output = super::super::run_pipeline(
+                                                    &user_input,
+                                                    comp_model.clone(),
+                                                    embedding_model.clone(),
+                                                    llm_backend_type.clone(),
+                                                    temperature,
+                                                    max_new_tokens,
+                                                    top_p,
+                                                    &dataset_db,
+                                                    &genedb,
+                                                    &ai_json,
+                                                )
+                                                .await;
+                                                let mut llm_json_value: super::super::SummaryType = serde_json::from_str(&llm_output.unwrap()).expect("Did not get a valid JSON of type {action: summary, summaryterms:[{clinical: term1}, {geneExpression: gene}], filter:[{term: term1, value: value1}]} from the LLM");
+                                                let mut expected_json_value: super::super::SummaryType = serde_json::from_str(&ques_ans.answer).expect("Did not get a valid JSON of type {action: summary, summaryterms:[{clinical: term1}, {geneExpression: gene}], filter:[{term: term1, value: value1}]} from the LLM");
+                                                assert_eq!(
+                                                    llm_json_value.sort_summarytype_struct(),
+                                                    expected_json_value.sort_summarytype_struct()
+                                                );
+                                            } else {
+                                                panic!("The user input is empty");
+                                            }
+                                        }
+                                    }
+                                    super::super::Charts::DE(_testdata) => {} // To do
+                                }
+                            }
+                        }
+                    }
+                    None => {}
+                }
+            }
+        }
+    }
+}
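The user_prompts test above deserializes serverconfig.json into the ServerConfig struct it defines. A hypothetical minimal config that satisfies that shape (every value below is invented for illustration; 11434 is simply Ollama's default port):

    let data = r#"{
        "tpmasterdir": "/path/to/tp",
        "llm_backend": "ollama",
        "sj_apilink": "http://sj-host/api",
        "sj_comp_model_name": "sj-comp",
        "sj_embedding_model_name": "sj-embed",
        "ollama_apilink": "http://localhost:11434",
        "ollama_comp_model_name": "llama3",
        "ollama_embedding_model_name": "nomic-embed-text",
        "genomes": [
            { "name": "hg38", "datasets": [{ "name": "demo", "aifiles": null }] }
        ]
    }"#;
    // Parses against the structs above; "aifiles": null maps to Option::None,
    // so this particular dataset would be skipped by the test loop.
    let cfg: ServerConfig = serde_json::from_str(data).expect("shape mismatch");
    assert_eq!(cfg.llm_backend, "ollama");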