@sjcrh/proteinpaint-rust 2.188.0 → 2.190.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/Cargo.toml +0 -9
- package/package.json +1 -1
- package/src/manhattan_plot.rs +38 -76
- package/src/volcano.rs +89 -49
- package/src/aichatbot.rs +0 -1554
- package/src/query_classification.rs +0 -152
- package/src/summary_agent.rs +0 -201
- package/src/test_ai.rs +0 -193
|
@@ -1,152 +0,0 @@
|
|
|
1
|
-
// Syntax: cd .. && cargo build --release && time cat ~/sjpp/test.txt | target/release/aichatbot
|
|
2
|
-
#![allow(non_snake_case)]
|
|
3
|
-
use anyhow::Result;
|
|
4
|
-
use json::JsonValue;
|
|
5
|
-
use schemars::JsonSchema;
|
|
6
|
-
use std::io;
|
|
7
|
-
mod aichatbot; // Importing classification agent from aichatbot.rs
|
|
8
|
-
mod ollama; // Importing custom rig module for invoking ollama server
|
|
9
|
-
mod sjprovider; // Importing custom rig module for invoking SJ GPU server
|
|
10
|
-
mod test_ai; // Test examples for AI chatbot
|
|
11
|
-
|
|
12
|
-
#[tokio::main]
|
|
13
|
-
async fn main() -> Result<()> {
|
|
14
|
-
let mut input = String::new();
|
|
15
|
-
match io::stdin().read_line(&mut input) {
|
|
16
|
-
// Accepting the piped input from nodejs (or command line from testing)
|
|
17
|
-
Ok(_n) => {
|
|
18
|
-
let input_json = json::parse(&input);
|
|
19
|
-
match input_json {
|
|
20
|
-
Ok(json_string) => {
|
|
21
|
-
//println!("json_string:{}", json_string);
|
|
22
|
-
let user_input_json: &JsonValue = &json_string["user_input"];
|
|
23
|
-
let user_input: &str;
|
|
24
|
-
match user_input_json.as_str() {
|
|
25
|
-
Some(inp) => user_input = inp,
|
|
26
|
-
None => panic!("user_input field is missing in input json"),
|
|
27
|
-
}
|
|
28
|
-
|
|
29
|
-
if user_input.len() == 0 {
|
|
30
|
-
panic!("The user input is empty");
|
|
31
|
-
}
|
|
32
|
-
|
|
33
|
-
let binpath_json: &JsonValue = &json_string["binpath"];
|
|
34
|
-
let binpath: &str;
|
|
35
|
-
match binpath_json.as_str() {
|
|
36
|
-
Some(inp) => binpath = inp,
|
|
37
|
-
None => panic!("binpath not found"),
|
|
38
|
-
}
|
|
39
|
-
|
|
40
|
-
let aiRoute_json: &JsonValue = &json_string["aiRoute"];
|
|
41
|
-
let aiRoute_str: &str;
|
|
42
|
-
match aiRoute_json.as_str() {
|
|
43
|
-
Some(inp) => aiRoute_str = inp,
|
|
44
|
-
None => panic!("aiRoute field is missing in input json"),
|
|
45
|
-
}
|
|
46
|
-
let airoute = String::from(binpath) + &"/../../" + &aiRoute_str;
|
|
47
|
-
|
|
48
|
-
let apilink_json: &JsonValue = &json_string["apilink"];
|
|
49
|
-
let apilink: &str;
|
|
50
|
-
match apilink_json.as_str() {
|
|
51
|
-
Some(inp) => apilink = inp,
|
|
52
|
-
None => panic!("apilink field is missing in input json"),
|
|
53
|
-
}
|
|
54
|
-
|
|
55
|
-
let comp_model_name_json: &JsonValue = &json_string["comp_model_name"];
|
|
56
|
-
let comp_model_name: &str;
|
|
57
|
-
match comp_model_name_json.as_str() {
|
|
58
|
-
Some(inp) => comp_model_name = inp,
|
|
59
|
-
None => panic!("comp_model_name field is missing in input json"),
|
|
60
|
-
}
|
|
61
|
-
|
|
62
|
-
let embedding_model_name_json: &JsonValue = &json_string["embedding_model_name"];
|
|
63
|
-
let embedding_model_name: &str;
|
|
64
|
-
match embedding_model_name_json.as_str() {
|
|
65
|
-
Some(inp) => embedding_model_name = inp,
|
|
66
|
-
None => panic!("embedding_model_name field is missing in input json"),
|
|
67
|
-
}
|
|
68
|
-
|
|
69
|
-
let llm_backend_name_json: &JsonValue = &json_string["llm_backend_name"];
|
|
70
|
-
let llm_backend_name: &str;
|
|
71
|
-
match llm_backend_name_json.as_str() {
|
|
72
|
-
Some(inp) => llm_backend_name = inp,
|
|
73
|
-
None => panic!("llm_backend_name field is missing in input json"),
|
|
74
|
-
}
|
|
75
|
-
|
|
76
|
-
let llm_backend_type: aichatbot::llm_backend;
|
|
77
|
-
let mut final_output: Option<String> = None;
|
|
78
|
-
let temperature: f64 = 0.01;
|
|
79
|
-
let max_new_tokens: usize = 512;
|
|
80
|
-
let top_p: f32 = 0.95;
|
|
81
|
-
if llm_backend_name != "ollama" && llm_backend_name != "SJ" {
|
|
82
|
-
panic!(
|
|
83
|
-
"This code currently supports only Ollama and SJ provider. llm_backend_name must be \"ollama\" or \"SJ\""
|
|
84
|
-
);
|
|
85
|
-
} else if llm_backend_name == "ollama".to_string() {
|
|
86
|
-
llm_backend_type = aichatbot::llm_backend::Ollama();
|
|
87
|
-
// Initialize Ollama client
|
|
88
|
-
let ollama_client = ollama::Client::builder()
|
|
89
|
-
.base_url(apilink)
|
|
90
|
-
.build()
|
|
91
|
-
.expect("Ollama server not found");
|
|
92
|
-
let embedding_model = ollama_client.embedding_model(embedding_model_name);
|
|
93
|
-
let comp_model = ollama_client.completion_model(comp_model_name);
|
|
94
|
-
final_output = Some(
|
|
95
|
-
aichatbot::classify_query_by_dataset_type(
|
|
96
|
-
user_input,
|
|
97
|
-
comp_model.clone(),
|
|
98
|
-
embedding_model.clone(),
|
|
99
|
-
&llm_backend_type,
|
|
100
|
-
temperature,
|
|
101
|
-
max_new_tokens,
|
|
102
|
-
top_p,
|
|
103
|
-
&airoute,
|
|
104
|
-
)
|
|
105
|
-
.await,
|
|
106
|
-
);
|
|
107
|
-
} else if llm_backend_name == "SJ".to_string() {
|
|
108
|
-
llm_backend_type = aichatbot::llm_backend::Sj();
|
|
109
|
-
// Initialize Sj provider client
|
|
110
|
-
let sj_client = sjprovider::Client::builder()
|
|
111
|
-
.base_url(apilink)
|
|
112
|
-
.build()
|
|
113
|
-
.expect("SJ server not found");
|
|
114
|
-
let embedding_model = sj_client.embedding_model(embedding_model_name);
|
|
115
|
-
let comp_model = sj_client.completion_model(comp_model_name);
|
|
116
|
-
final_output = Some(
|
|
117
|
-
aichatbot::classify_query_by_dataset_type(
|
|
118
|
-
user_input,
|
|
119
|
-
comp_model.clone(),
|
|
120
|
-
embedding_model.clone(),
|
|
121
|
-
&llm_backend_type,
|
|
122
|
-
temperature,
|
|
123
|
-
max_new_tokens,
|
|
124
|
-
top_p,
|
|
125
|
-
&airoute,
|
|
126
|
-
)
|
|
127
|
-
.await,
|
|
128
|
-
);
|
|
129
|
-
}
|
|
130
|
-
|
|
131
|
-
match final_output {
|
|
132
|
-
Some(fin_out) => {
|
|
133
|
-
println!("{{\"{}\":{}}}", "route", fin_out);
|
|
134
|
-
}
|
|
135
|
-
None => {
|
|
136
|
-
println!("{{\"{}\":\"{}\"}}", "route", "unknown");
|
|
137
|
-
}
|
|
138
|
-
}
|
|
139
|
-
}
|
|
140
|
-
Err(error) => println!("Incorrect json:{}", error),
|
|
141
|
-
}
|
|
142
|
-
}
|
|
143
|
-
Err(error) => println!("Piping error: {}", error),
|
|
144
|
-
}
|
|
145
|
-
Ok(())
|
|
146
|
-
}
|
|
147
|
-
|
|
148
|
-
#[derive(Debug, JsonSchema)]
|
|
149
|
-
#[allow(dead_code)]
|
|
150
|
-
struct OutputJson {
|
|
151
|
-
pub answer: String,
|
|
152
|
-
}
|
package/src/summary_agent.rs
DELETED
|
@@ -1,201 +0,0 @@
|
|
|
1
|
-
// Syntax: cd .. && cargo build --release && time cat ~/sjpp/test.txt | target/release/aichatbot
|
|
2
|
-
#![allow(non_snake_case)]
|
|
3
|
-
use crate::aichatbot::AiJsonFormat;
|
|
4
|
-
use anyhow::Result;
|
|
5
|
-
use json::JsonValue;
|
|
6
|
-
use std::fs;
|
|
7
|
-
use std::io;
|
|
8
|
-
use std::path::Path;
|
|
9
|
-
mod aichatbot; // Get summary agent
|
|
10
|
-
|
|
11
|
-
mod ollama; // Importing custom rig module for invoking ollama server
|
|
12
|
-
mod sjprovider; // Importing custom rig module for invoking SJ GPU server
|
|
13
|
-
mod test_ai; // Test examples for AI chatbot
|
|
14
|
-
|
|
15
|
-
#[tokio::main]
|
|
16
|
-
async fn main() -> Result<()> {
|
|
17
|
-
let mut input = String::new();
|
|
18
|
-
match io::stdin().read_line(&mut input) {
|
|
19
|
-
// Accepting the piped input from nodejs (or command line from testing)
|
|
20
|
-
Ok(_n) => {
|
|
21
|
-
let input_json = json::parse(&input);
|
|
22
|
-
match input_json {
|
|
23
|
-
Ok(json_string) => {
|
|
24
|
-
//println!("json_string:{}", json_string);
|
|
25
|
-
let user_input_json: &JsonValue = &json_string["user_input"];
|
|
26
|
-
let user_input: &str;
|
|
27
|
-
match user_input_json.as_str() {
|
|
28
|
-
Some(inp) => user_input = inp,
|
|
29
|
-
None => panic!("user_input field is missing in input json"),
|
|
30
|
-
}
|
|
31
|
-
|
|
32
|
-
let dataset_db_json: &JsonValue = &json_string["dataset_db"];
|
|
33
|
-
let dataset_db_str: &str;
|
|
34
|
-
match dataset_db_json.as_str() {
|
|
35
|
-
Some(inp) => dataset_db_str = inp,
|
|
36
|
-
None => panic!("dataset_db field is missing in input json"),
|
|
37
|
-
}
|
|
38
|
-
|
|
39
|
-
let genedb_json: &JsonValue = &json_string["genedb"];
|
|
40
|
-
let genedb_str: &str;
|
|
41
|
-
match genedb_json.as_str() {
|
|
42
|
-
Some(inp) => genedb_str = inp,
|
|
43
|
-
None => panic!("genedb field is missing in input json"),
|
|
44
|
-
}
|
|
45
|
-
|
|
46
|
-
if user_input.len() == 0 {
|
|
47
|
-
panic!("The user input is empty");
|
|
48
|
-
}
|
|
49
|
-
|
|
50
|
-
let tpmasterdir_json: &JsonValue = &json_string["tpmasterdir"];
|
|
51
|
-
let tpmasterdir: &str;
|
|
52
|
-
match tpmasterdir_json.as_str() {
|
|
53
|
-
Some(inp) => tpmasterdir = inp,
|
|
54
|
-
None => panic!("tpmasterdir not found"),
|
|
55
|
-
}
|
|
56
|
-
|
|
57
|
-
let binpath_json: &JsonValue = &json_string["binpath"];
|
|
58
|
-
let binpath: &str;
|
|
59
|
-
match binpath_json.as_str() {
|
|
60
|
-
Some(inp) => binpath = inp,
|
|
61
|
-
None => panic!("binpath not found"),
|
|
62
|
-
}
|
|
63
|
-
|
|
64
|
-
let ai_json_file_json: &JsonValue = &json_string["aifiles"];
|
|
65
|
-
let ai_json_file: String;
|
|
66
|
-
match ai_json_file_json.as_str() {
|
|
67
|
-
Some(inp) => ai_json_file = String::from(binpath) + &"/../../" + &inp,
|
|
68
|
-
None => {
|
|
69
|
-
panic!("ai json file not found")
|
|
70
|
-
}
|
|
71
|
-
}
|
|
72
|
-
|
|
73
|
-
let ai_json_file = Path::new(&ai_json_file);
|
|
74
|
-
let ai_json_file_path;
|
|
75
|
-
let current_dir = std::env::current_dir().unwrap();
|
|
76
|
-
match ai_json_file.canonicalize() {
|
|
77
|
-
Ok(p) => ai_json_file_path = p,
|
|
78
|
-
Err(_) => {
|
|
79
|
-
panic!(
|
|
80
|
-
"AI JSON file path not found:{:?}, current directory:{:?}",
|
|
81
|
-
ai_json_file, current_dir
|
|
82
|
-
)
|
|
83
|
-
}
|
|
84
|
-
}
|
|
85
|
-
|
|
86
|
-
// Read the file
|
|
87
|
-
let ai_data = fs::read_to_string(ai_json_file_path).unwrap();
|
|
88
|
-
|
|
89
|
-
// Parse the JSON data
|
|
90
|
-
let ai_json: AiJsonFormat =
|
|
91
|
-
serde_json::from_str(&ai_data).expect("AI JSON file does not have the correct format");
|
|
92
|
-
|
|
93
|
-
let genedb = String::from(tpmasterdir) + &"/" + &genedb_str;
|
|
94
|
-
let dataset_db = String::from(tpmasterdir) + &"/" + &dataset_db_str;
|
|
95
|
-
|
|
96
|
-
let apilink_json: &JsonValue = &json_string["apilink"];
|
|
97
|
-
let apilink: &str;
|
|
98
|
-
match apilink_json.as_str() {
|
|
99
|
-
Some(inp) => apilink = inp,
|
|
100
|
-
None => panic!("apilink field is missing in input json"),
|
|
101
|
-
}
|
|
102
|
-
|
|
103
|
-
let comp_model_name_json: &JsonValue = &json_string["comp_model_name"];
|
|
104
|
-
let comp_model_name: &str;
|
|
105
|
-
match comp_model_name_json.as_str() {
|
|
106
|
-
Some(inp) => comp_model_name = inp,
|
|
107
|
-
None => panic!("comp_model_name field is missing in input json"),
|
|
108
|
-
}
|
|
109
|
-
|
|
110
|
-
let embedding_model_name_json: &JsonValue = &json_string["embedding_model_name"];
|
|
111
|
-
let embedding_model_name: &str;
|
|
112
|
-
match embedding_model_name_json.as_str() {
|
|
113
|
-
Some(inp) => embedding_model_name = inp,
|
|
114
|
-
None => panic!("embedding_model_name field is missing in input json"),
|
|
115
|
-
}
|
|
116
|
-
|
|
117
|
-
let llm_backend_name_json: &JsonValue = &json_string["llm_backend_name"];
|
|
118
|
-
let llm_backend_name: &str;
|
|
119
|
-
match llm_backend_name_json.as_str() {
|
|
120
|
-
Some(inp) => llm_backend_name = inp,
|
|
121
|
-
None => panic!("llm_backend_name field is missing in input json"),
|
|
122
|
-
}
|
|
123
|
-
|
|
124
|
-
let llm_backend_type: aichatbot::llm_backend;
|
|
125
|
-
let mut final_output: Option<String> = None;
|
|
126
|
-
let temperature: f64 = 0.01;
|
|
127
|
-
let max_new_tokens: usize = 512;
|
|
128
|
-
let top_p: f32 = 0.95;
|
|
129
|
-
let testing = false; // This variable is always false in production, this is true in test_ai.rs for testing code
|
|
130
|
-
if llm_backend_name != "ollama" && llm_backend_name != "SJ" {
|
|
131
|
-
panic!(
|
|
132
|
-
"This code currently supports only Ollama and SJ provider. llm_backend_name must be \"ollama\" or \"SJ\""
|
|
133
|
-
);
|
|
134
|
-
} else if llm_backend_name == "ollama".to_string() {
|
|
135
|
-
llm_backend_type = aichatbot::llm_backend::Ollama();
|
|
136
|
-
// Initialize Ollama client
|
|
137
|
-
let ollama_client = ollama::Client::builder()
|
|
138
|
-
.base_url(apilink)
|
|
139
|
-
.build()
|
|
140
|
-
.expect("Ollama server not found");
|
|
141
|
-
let embedding_model = ollama_client.embedding_model(embedding_model_name);
|
|
142
|
-
let comp_model = ollama_client.completion_model(comp_model_name);
|
|
143
|
-
final_output = Some(
|
|
144
|
-
aichatbot::extract_summary_information(
|
|
145
|
-
user_input,
|
|
146
|
-
comp_model,
|
|
147
|
-
embedding_model,
|
|
148
|
-
&llm_backend_type,
|
|
149
|
-
temperature,
|
|
150
|
-
max_new_tokens,
|
|
151
|
-
top_p,
|
|
152
|
-
&dataset_db,
|
|
153
|
-
&genedb,
|
|
154
|
-
&ai_json,
|
|
155
|
-
testing,
|
|
156
|
-
)
|
|
157
|
-
.await,
|
|
158
|
-
);
|
|
159
|
-
} else if llm_backend_name == "SJ".to_string() {
|
|
160
|
-
llm_backend_type = aichatbot::llm_backend::Sj();
|
|
161
|
-
// Initialize Sj provider client
|
|
162
|
-
let sj_client = sjprovider::Client::builder()
|
|
163
|
-
.base_url(apilink)
|
|
164
|
-
.build()
|
|
165
|
-
.expect("SJ server not found");
|
|
166
|
-
let embedding_model = sj_client.embedding_model(embedding_model_name);
|
|
167
|
-
let comp_model = sj_client.completion_model(comp_model_name);
|
|
168
|
-
final_output = Some(
|
|
169
|
-
aichatbot::extract_summary_information(
|
|
170
|
-
user_input,
|
|
171
|
-
comp_model,
|
|
172
|
-
embedding_model,
|
|
173
|
-
&llm_backend_type,
|
|
174
|
-
temperature,
|
|
175
|
-
max_new_tokens,
|
|
176
|
-
top_p,
|
|
177
|
-
&dataset_db,
|
|
178
|
-
&genedb,
|
|
179
|
-
&ai_json,
|
|
180
|
-
testing,
|
|
181
|
-
)
|
|
182
|
-
.await,
|
|
183
|
-
);
|
|
184
|
-
}
|
|
185
|
-
|
|
186
|
-
match final_output {
|
|
187
|
-
Some(fin_out) => {
|
|
188
|
-
println!("final_output:{:?}", fin_out.replace("\\", ""));
|
|
189
|
-
}
|
|
190
|
-
None => {
|
|
191
|
-
println!("final_output:{{\"{}\":\"{}\"}}", "action", "unknown");
|
|
192
|
-
}
|
|
193
|
-
}
|
|
194
|
-
}
|
|
195
|
-
Err(error) => println!("Incorrect json:{}", error),
|
|
196
|
-
}
|
|
197
|
-
}
|
|
198
|
-
Err(error) => println!("Piping error: {}", error),
|
|
199
|
-
}
|
|
200
|
-
Ok(())
|
|
201
|
-
}
|
package/src/test_ai.rs
DELETED
|
@@ -1,193 +0,0 @@
|
|
|
1
|
-
// For capturing output from a test, run: cd .. && cargo test -- --nocapture
|
|
2
|
-
// Ignored tests: cd .. && export RUST_BACKTRACE=full && time cargo test -- --ignored --nocapture
|
|
3
|
-
#[allow(dead_code)]
|
|
4
|
-
fn main() {}
|
|
5
|
-
|
|
6
|
-
#[cfg(test)]
|
|
7
|
-
mod tests {
|
|
8
|
-
use crate::aichatbot::{AiJsonFormat, AnswerFormat, SummaryType, llm_backend, run_pipeline};
|
|
9
|
-
use serde_json;
|
|
10
|
-
use std::fs::{self};
|
|
11
|
-
use std::path::Path;
|
|
12
|
-
|
|
13
|
-
#[derive(PartialEq, Debug, Clone, schemars::JsonSchema, serde::Serialize, serde::Deserialize)]
|
|
14
|
-
struct ServerConfig {
|
|
15
|
-
tpmasterdir: String,
|
|
16
|
-
llm_backend: String,
|
|
17
|
-
sj_apilink: String,
|
|
18
|
-
sj_comp_model_name: String,
|
|
19
|
-
sj_embedding_model_name: String,
|
|
20
|
-
ollama_apilink: String,
|
|
21
|
-
ollama_comp_model_name: String,
|
|
22
|
-
ollama_embedding_model_name: String,
|
|
23
|
-
genomes: Vec<Genomes>,
|
|
24
|
-
aiRoute: String,
|
|
25
|
-
}
|
|
26
|
-
|
|
27
|
-
#[derive(PartialEq, Debug, Clone, schemars::JsonSchema, serde::Serialize, serde::Deserialize)]
|
|
28
|
-
struct Genomes {
|
|
29
|
-
name: String,
|
|
30
|
-
datasets: Vec<Dataset>,
|
|
31
|
-
}
|
|
32
|
-
|
|
33
|
-
#[derive(PartialEq, Debug, Clone, schemars::JsonSchema, serde::Serialize, serde::Deserialize)]
|
|
34
|
-
struct Dataset {
|
|
35
|
-
name: String,
|
|
36
|
-
aifiles: Option<String>, // For now aifiles are defined only for certain datasets
|
|
37
|
-
}
|
|
38
|
-
|
|
39
|
-
#[tokio::test]
|
|
40
|
-
#[ignore]
|
|
41
|
-
async fn user_prompts() {
|
|
42
|
-
let temperature: f64 = 0.01;
|
|
43
|
-
let max_new_tokens: usize = 512;
|
|
44
|
-
let top_p: f32 = 0.95;
|
|
45
|
-
let serverconfig_file_path = Path::new("../../serverconfig.json");
|
|
46
|
-
let absolute_path = serverconfig_file_path.canonicalize().unwrap();
|
|
47
|
-
let testing = true; // This causes the JSON being output from run_pipeline() to be in LLM JSON format
|
|
48
|
-
|
|
49
|
-
// Read the file
|
|
50
|
-
let data = fs::read_to_string(absolute_path).unwrap();
|
|
51
|
-
|
|
52
|
-
// Parse the JSON data
|
|
53
|
-
let serverconfig: ServerConfig = serde_json::from_str(&data).expect("JSON not in serverconfig.json format");
|
|
54
|
-
let airoute = String::from("../../") + &serverconfig.aiRoute;
|
|
55
|
-
for genome in &serverconfig.genomes {
|
|
56
|
-
for dataset in &genome.datasets {
|
|
57
|
-
match &dataset.aifiles {
|
|
58
|
-
Some(ai_json_file) => {
|
|
59
|
-
println!("Testing dataset:{}", dataset.name);
|
|
60
|
-
let ai_json_file_path = String::from("../../") + ai_json_file;
|
|
61
|
-
let ai_json_file = Path::new(&ai_json_file_path);
|
|
62
|
-
|
|
63
|
-
// Read the file
|
|
64
|
-
let ai_data = fs::read_to_string(ai_json_file).unwrap();
|
|
65
|
-
// Parse the JSON data
|
|
66
|
-
let ai_json: AiJsonFormat =
|
|
67
|
-
serde_json::from_str(&ai_data).expect("AI JSON file does not have the correct format");
|
|
68
|
-
//println!("ai_json:{:?}", ai_json);
|
|
69
|
-
let genedb = String::from(&serverconfig.tpmasterdir) + &"/" + &ai_json.genedb;
|
|
70
|
-
let dataset_db = String::from(&serverconfig.tpmasterdir) + &"/" + &ai_json.db;
|
|
71
|
-
let llm_backend_name = &serverconfig.llm_backend;
|
|
72
|
-
let llm_backend_type: llm_backend;
|
|
73
|
-
|
|
74
|
-
if llm_backend_name != "ollama" && llm_backend_name != "SJ" {
|
|
75
|
-
panic!(
|
|
76
|
-
"This code currently supports only Ollama and SJ provider. llm_backend_name must be \"ollama\" or \"SJ\""
|
|
77
|
-
);
|
|
78
|
-
} else if *llm_backend_name == "ollama".to_string() {
|
|
79
|
-
let ollama_host = &serverconfig.ollama_apilink;
|
|
80
|
-
let ollama_embedding_model_name = &serverconfig.ollama_embedding_model_name;
|
|
81
|
-
let ollama_comp_model_name = &serverconfig.ollama_comp_model_name;
|
|
82
|
-
llm_backend_type = llm_backend::Ollama();
|
|
83
|
-
let ollama_client = super::super::ollama::Client::builder()
|
|
84
|
-
.base_url(ollama_host)
|
|
85
|
-
.build()
|
|
86
|
-
.expect("Ollama server not found");
|
|
87
|
-
let embedding_model = ollama_client.embedding_model(ollama_embedding_model_name);
|
|
88
|
-
let comp_model = ollama_client.completion_model(ollama_comp_model_name);
|
|
89
|
-
for chart in ai_json.charts.clone() {
|
|
90
|
-
if chart.r#type == "Summary" {
|
|
91
|
-
for ques_ans in chart.TestData {
|
|
92
|
-
let user_input = ques_ans.question;
|
|
93
|
-
let llm_output = run_pipeline(
|
|
94
|
-
&user_input,
|
|
95
|
-
comp_model.clone(),
|
|
96
|
-
embedding_model.clone(),
|
|
97
|
-
llm_backend_type.clone(),
|
|
98
|
-
temperature,
|
|
99
|
-
max_new_tokens,
|
|
100
|
-
top_p,
|
|
101
|
-
&dataset_db,
|
|
102
|
-
&genedb,
|
|
103
|
-
&ai_json,
|
|
104
|
-
&airoute,
|
|
105
|
-
testing,
|
|
106
|
-
)
|
|
107
|
-
.await;
|
|
108
|
-
let llm_json_value: SummaryType = serde_json::from_str(&llm_output.unwrap()).expect("Did not get a valid JSON of type {action: summary, summaryterms:[{clinical: term1}, {geneExpression: gene}], filter:[{term: term1, value: value1}]} from the LLM");
|
|
109
|
-
match ques_ans.answer {
|
|
110
|
-
AnswerFormat::summary_type(sum) => {
|
|
111
|
-
//println!("expected answer:{:?}", &sum);
|
|
112
|
-
assert_eq!(
|
|
113
|
-
llm_json_value.sort_summarytype_struct(),
|
|
114
|
-
sum.sort_summarytype_struct()
|
|
115
|
-
);
|
|
116
|
-
}
|
|
117
|
-
AnswerFormat::DE_type(_) => {
|
|
118
|
-
panic!("DE type not valid for summary")
|
|
119
|
-
}
|
|
120
|
-
}
|
|
121
|
-
}
|
|
122
|
-
}
|
|
123
|
-
}
|
|
124
|
-
} else if *llm_backend_name == "SJ".to_string() {
|
|
125
|
-
let sjprovider_host = &serverconfig.sj_apilink;
|
|
126
|
-
let sj_embedding_model_name = &serverconfig.sj_embedding_model_name;
|
|
127
|
-
let sj_comp_model_name = &serverconfig.sj_comp_model_name;
|
|
128
|
-
llm_backend_type = llm_backend::Sj();
|
|
129
|
-
let sj_client = super::super::sjprovider::Client::builder()
|
|
130
|
-
.base_url(sjprovider_host)
|
|
131
|
-
.build()
|
|
132
|
-
.expect("SJ server not found");
|
|
133
|
-
let embedding_model = sj_client.embedding_model(sj_embedding_model_name);
|
|
134
|
-
let comp_model = sj_client.completion_model(sj_comp_model_name);
|
|
135
|
-
|
|
136
|
-
for chart in ai_json.charts.clone() {
|
|
137
|
-
if chart.r#type == "Summary" {
|
|
138
|
-
for ques_ans in chart.TestData {
|
|
139
|
-
let user_input = ques_ans.question;
|
|
140
|
-
if user_input.len() > 0 {
|
|
141
|
-
let llm_output = run_pipeline(
|
|
142
|
-
&user_input,
|
|
143
|
-
comp_model.clone(),
|
|
144
|
-
embedding_model.clone(),
|
|
145
|
-
llm_backend_type.clone(),
|
|
146
|
-
temperature,
|
|
147
|
-
max_new_tokens,
|
|
148
|
-
top_p,
|
|
149
|
-
&dataset_db,
|
|
150
|
-
&genedb,
|
|
151
|
-
&ai_json,
|
|
152
|
-
&airoute,
|
|
153
|
-
testing,
|
|
154
|
-
)
|
|
155
|
-
.await;
|
|
156
|
-
//println!("user_input:{}", user_input);
|
|
157
|
-
//println!("llm_answer:{:?}", llm_output);
|
|
158
|
-
//println!("expected answer:{:?}", &ques_ans.answer);
|
|
159
|
-
let llm_json_value: SummaryType = serde_json::from_str(&llm_output.unwrap()).expect("Did not get a valid JSON of type {action: summary, summaryterms:[{clinical: term1}, {geneExpression: gene}], filter:[{term: term1, value: value1}]} from the LLM");
|
|
160
|
-
//println!(
|
|
161
|
-
// "llm_answer:{:?}",
|
|
162
|
-
// llm_json_value.clone().sort_summarytype_struct()
|
|
163
|
-
//);
|
|
164
|
-
//println!(
|
|
165
|
-
// "expected answer:{:?}",
|
|
166
|
-
// &expected_json_value.clone().sort_summarytype_struct()
|
|
167
|
-
//);
|
|
168
|
-
match ques_ans.answer {
|
|
169
|
-
AnswerFormat::summary_type(sum) => {
|
|
170
|
-
//println!("expected answer:{:?}", &sum);
|
|
171
|
-
assert_eq!(
|
|
172
|
-
llm_json_value.sort_summarytype_struct(),
|
|
173
|
-
sum.sort_summarytype_struct()
|
|
174
|
-
);
|
|
175
|
-
}
|
|
176
|
-
AnswerFormat::DE_type(_) => {
|
|
177
|
-
panic!("DE type not valid for summary")
|
|
178
|
-
}
|
|
179
|
-
}
|
|
180
|
-
} else {
|
|
181
|
-
panic!("The user input is empty");
|
|
182
|
-
}
|
|
183
|
-
}
|
|
184
|
-
}
|
|
185
|
-
}
|
|
186
|
-
}
|
|
187
|
-
}
|
|
188
|
-
None => {}
|
|
189
|
-
}
|
|
190
|
-
}
|
|
191
|
-
}
|
|
192
|
-
}
|
|
193
|
-
}
|