@sjcrh/proteinpaint-rust 2.167.0 → 2.170.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/Cargo.toml +8 -4
- package/package.json +1 -1
- package/src/aichatbot.rs +307 -310
- package/src/manhattan_plot.rs +31 -16
- package/src/query_classification.rs +152 -0
- package/src/summary_agent.rs +201 -0
- package/src/test_ai.rs +81 -58
package/src/test_ai.rs
CHANGED
|
@@ -5,6 +5,7 @@ fn main() {}
|
|
|
5
5
|
|
|
6
6
|
#[cfg(test)]
|
|
7
7
|
mod tests {
|
|
8
|
+
use crate::aichatbot::{AiJsonFormat, AnswerFormat, SummaryType, llm_backend, run_pipeline};
|
|
8
9
|
use serde_json;
|
|
9
10
|
use std::fs::{self};
|
|
10
11
|
use std::path::Path;
|
|
@@ -20,6 +21,7 @@ mod tests {
|
|
|
20
21
|
ollama_comp_model_name: String,
|
|
21
22
|
ollama_embedding_model_name: String,
|
|
22
23
|
genomes: Vec<Genomes>,
|
|
24
|
+
aiRoute: String,
|
|
23
25
|
}
|
|
24
26
|
|
|
25
27
|
#[derive(PartialEq, Debug, Clone, schemars::JsonSchema, serde::Serialize, serde::Deserialize)]
|
|
@@ -49,7 +51,7 @@ mod tests {
|
|
|
49
51
|
|
|
50
52
|
// Parse the JSON data
|
|
51
53
|
let serverconfig: ServerConfig = serde_json::from_str(&data).expect("JSON not in serverconfig.json format");
|
|
52
|
-
|
|
54
|
+
let airoute = String::from("../../") + &serverconfig.aiRoute;
|
|
53
55
|
for genome in &serverconfig.genomes {
|
|
54
56
|
for dataset in &genome.datasets {
|
|
55
57
|
match &dataset.aifiles {
|
|
@@ -61,13 +63,13 @@ mod tests {
|
|
|
61
63
|
// Read the file
|
|
62
64
|
let ai_data = fs::read_to_string(ai_json_file).unwrap();
|
|
63
65
|
// Parse the JSON data
|
|
64
|
-
let ai_json:
|
|
66
|
+
let ai_json: AiJsonFormat =
|
|
65
67
|
serde_json::from_str(&ai_data).expect("AI JSON file does not have the correct format");
|
|
66
68
|
//println!("ai_json:{:?}", ai_json);
|
|
67
69
|
let genedb = String::from(&serverconfig.tpmasterdir) + &"/" + &ai_json.genedb;
|
|
68
70
|
let dataset_db = String::from(&serverconfig.tpmasterdir) + &"/" + &ai_json.db;
|
|
69
71
|
let llm_backend_name = &serverconfig.llm_backend;
|
|
70
|
-
let llm_backend_type:
|
|
72
|
+
let llm_backend_type: llm_backend;
|
|
71
73
|
|
|
72
74
|
if llm_backend_name != "ollama" && llm_backend_name != "SJ" {
|
|
73
75
|
panic!(
|
|
@@ -77,7 +79,7 @@ mod tests {
|
|
|
77
79
|
let ollama_host = &serverconfig.ollama_apilink;
|
|
78
80
|
let ollama_embedding_model_name = &serverconfig.ollama_embedding_model_name;
|
|
79
81
|
let ollama_comp_model_name = &serverconfig.ollama_comp_model_name;
|
|
80
|
-
llm_backend_type =
|
|
82
|
+
llm_backend_type = llm_backend::Ollama();
|
|
81
83
|
let ollama_client = super::super::ollama::Client::builder()
|
|
82
84
|
.base_url(ollama_host)
|
|
83
85
|
.build()
|
|
@@ -85,40 +87,45 @@ mod tests {
|
|
|
85
87
|
let embedding_model = ollama_client.embedding_model(ollama_embedding_model_name);
|
|
86
88
|
let comp_model = ollama_client.completion_model(ollama_comp_model_name);
|
|
87
89
|
for chart in ai_json.charts.clone() {
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
90
|
+
if chart.r#type == "Summary" {
|
|
91
|
+
for ques_ans in chart.TestData {
|
|
92
|
+
let user_input = ques_ans.question;
|
|
93
|
+
let llm_output = run_pipeline(
|
|
94
|
+
&user_input,
|
|
95
|
+
comp_model.clone(),
|
|
96
|
+
embedding_model.clone(),
|
|
97
|
+
llm_backend_type.clone(),
|
|
98
|
+
temperature,
|
|
99
|
+
max_new_tokens,
|
|
100
|
+
top_p,
|
|
101
|
+
&dataset_db,
|
|
102
|
+
&genedb,
|
|
103
|
+
&ai_json,
|
|
104
|
+
&airoute,
|
|
105
|
+
testing,
|
|
106
|
+
)
|
|
107
|
+
.await;
|
|
108
|
+
let llm_json_value: SummaryType = serde_json::from_str(&llm_output.unwrap()).expect("Did not get a valid JSON of type {action: summary, summaryterms:[{clinical: term1}, {geneExpression: gene}], filter:[{term: term1, value: value1}]} from the LLM");
|
|
109
|
+
match ques_ans.answer {
|
|
110
|
+
AnswerFormat::summary_type(sum) => {
|
|
111
|
+
//println!("expected answer:{:?}", &sum);
|
|
112
|
+
assert_eq!(
|
|
113
|
+
llm_json_value.sort_summarytype_struct(),
|
|
114
|
+
sum.sort_summarytype_struct()
|
|
115
|
+
);
|
|
116
|
+
}
|
|
117
|
+
AnswerFormat::DE_type(_) => {
|
|
118
|
+
panic!("DE type not valid for summary")
|
|
119
|
+
}
|
|
112
120
|
}
|
|
113
121
|
}
|
|
114
|
-
super::super::Charts::DE(_testdata) => {} // To do
|
|
115
122
|
}
|
|
116
123
|
}
|
|
117
124
|
} else if *llm_backend_name == "SJ".to_string() {
|
|
118
125
|
let sjprovider_host = &serverconfig.sj_apilink;
|
|
119
126
|
let sj_embedding_model_name = &serverconfig.sj_embedding_model_name;
|
|
120
127
|
let sj_comp_model_name = &serverconfig.sj_comp_model_name;
|
|
121
|
-
llm_backend_type =
|
|
128
|
+
llm_backend_type = llm_backend::Sj();
|
|
122
129
|
let sj_client = super::super::sjprovider::Client::builder()
|
|
123
130
|
.base_url(sjprovider_host)
|
|
124
131
|
.build()
|
|
@@ -127,37 +134,53 @@ mod tests {
|
|
|
127
134
|
let comp_model = sj_client.completion_model(sj_comp_model_name);
|
|
128
135
|
|
|
129
136
|
for chart in ai_json.charts.clone() {
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
137
|
+
if chart.r#type == "Summary" {
|
|
138
|
+
for ques_ans in chart.TestData {
|
|
139
|
+
let user_input = ques_ans.question;
|
|
140
|
+
if user_input.len() > 0 {
|
|
141
|
+
let llm_output = run_pipeline(
|
|
142
|
+
&user_input,
|
|
143
|
+
comp_model.clone(),
|
|
144
|
+
embedding_model.clone(),
|
|
145
|
+
llm_backend_type.clone(),
|
|
146
|
+
temperature,
|
|
147
|
+
max_new_tokens,
|
|
148
|
+
top_p,
|
|
149
|
+
&dataset_db,
|
|
150
|
+
&genedb,
|
|
151
|
+
&ai_json,
|
|
152
|
+
&airoute,
|
|
153
|
+
testing,
|
|
154
|
+
)
|
|
155
|
+
.await;
|
|
156
|
+
//println!("user_input:{}", user_input);
|
|
157
|
+
//println!("llm_answer:{:?}", llm_output);
|
|
158
|
+
//println!("expected answer:{:?}", &ques_ans.answer);
|
|
159
|
+
let llm_json_value: SummaryType = serde_json::from_str(&llm_output.unwrap()).expect("Did not get a valid JSON of type {action: summary, summaryterms:[{clinical: term1}, {geneExpression: gene}], filter:[{term: term1, value: value1}]} from the LLM");
|
|
160
|
+
//println!(
|
|
161
|
+
// "llm_answer:{:?}",
|
|
162
|
+
// llm_json_value.clone().sort_summarytype_struct()
|
|
163
|
+
//);
|
|
164
|
+
//println!(
|
|
165
|
+
// "expected answer:{:?}",
|
|
166
|
+
// &expected_json_value.clone().sort_summarytype_struct()
|
|
167
|
+
//);
|
|
168
|
+
match ques_ans.answer {
|
|
169
|
+
AnswerFormat::summary_type(sum) => {
|
|
170
|
+
//println!("expected answer:{:?}", &sum);
|
|
171
|
+
assert_eq!(
|
|
172
|
+
llm_json_value.sort_summarytype_struct(),
|
|
173
|
+
sum.sort_summarytype_struct()
|
|
174
|
+
);
|
|
175
|
+
}
|
|
176
|
+
AnswerFormat::DE_type(_) => {
|
|
177
|
+
panic!("DE type not valid for summary")
|
|
178
|
+
}
|
|
157
179
|
}
|
|
180
|
+
} else {
|
|
181
|
+
panic!("The user input is empty");
|
|
158
182
|
}
|
|
159
183
|
}
|
|
160
|
-
super::super::Charts::DE(_testdata) => {} // To do
|
|
161
184
|
}
|
|
162
185
|
}
|
|
163
186
|
}
|