@sjcrh/proteinpaint-rust 2.169.0 → 2.170.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/Cargo.toml +8 -4
- package/package.json +1 -1
- package/src/aichatbot.rs +231 -254
- package/src/manhattan_plot.rs +31 -18
- package/src/query_classification.rs +152 -0
- package/src/summary_agent.rs +201 -0
- package/src/test_ai.rs +79 -72
package/src/manhattan_plot.rs
CHANGED
|
@@ -22,6 +22,7 @@ struct Input {
|
|
|
22
22
|
plot_height: u64,
|
|
23
23
|
device_pixel_ratio: f64,
|
|
24
24
|
png_dot_radius: u64,
|
|
25
|
+
log_cutoff: f64,
|
|
25
26
|
}
|
|
26
27
|
|
|
27
28
|
// chromosome info
|
|
@@ -120,6 +121,7 @@ fn cumulative_chrom(
|
|
|
120
121
|
fn grin2_file_read(
|
|
121
122
|
grin2_file: &str,
|
|
122
123
|
chrom_data: &HashMap<String, ChromInfo>,
|
|
124
|
+
log_cutoff: f64,
|
|
123
125
|
) -> Result<(Vec<u64>, Vec<f64>, Vec<String>, Vec<PointDetail>, Vec<usize>), Box<dyn Error>> {
|
|
124
126
|
// Default colours
|
|
125
127
|
let mut colors: HashMap<String, String> = HashMap::new();
|
|
@@ -217,13 +219,18 @@ fn grin2_file_read(
|
|
|
217
219
|
Some(q) => q,
|
|
218
220
|
None => continue,
|
|
219
221
|
};
|
|
220
|
-
|
|
221
|
-
|
|
222
|
-
Ok(v) if v > 0.0 => v,
|
|
223
|
-
Ok(v) if v == 0.0 => 1e-300, // Treat exact 0 as ~1e-300 so we can still show q-values that are 0 and not filter them out
|
|
222
|
+
let original_q_val: f64 = match q_val_str.parse() {
|
|
223
|
+
Ok(v) if v >= 0.0 => v,
|
|
224
224
|
_ => continue,
|
|
225
225
|
};
|
|
226
|
-
|
|
226
|
+
|
|
227
|
+
// Use log_cutoff for zero q-values to avoid -inf. These will be capped later in plotting at log_cutoff
|
|
228
|
+
let neg_log10_q = if original_q_val == 0.0 {
|
|
229
|
+
log_cutoff
|
|
230
|
+
} else {
|
|
231
|
+
-original_q_val.log10()
|
|
232
|
+
};
|
|
233
|
+
|
|
227
234
|
let n_subj_count: Option<i64> = n_idx_opt
|
|
228
235
|
.and_then(|i| fields.get(i))
|
|
229
236
|
.and_then(|s| s.parse::<i64>().ok());
|
|
@@ -234,7 +241,8 @@ fn grin2_file_read(
|
|
|
234
241
|
colors_vec.push(color.clone());
|
|
235
242
|
|
|
236
243
|
// only add significant points for interactivity
|
|
237
|
-
|
|
244
|
+
// We check against the original q-value here so we send back the true q-value rather than the log_cutoff substitute used for zero q-values
|
|
245
|
+
if original_q_val <= 0.05 {
|
|
238
246
|
point_details.push(PointDetail {
|
|
239
247
|
x: x_pos,
|
|
240
248
|
y: neg_log10_q,
|
|
@@ -245,7 +253,7 @@ fn grin2_file_read(
|
|
|
245
253
|
start: gene_start,
|
|
246
254
|
end: gene_end,
|
|
247
255
|
pos: gene_start,
|
|
248
|
-
q_value:
|
|
256
|
+
q_value: original_q_val,
|
|
249
257
|
nsubj: n_subj_count,
|
|
250
258
|
pixel_x: 0.0,
|
|
251
259
|
pixel_y: 0.0,
|
|
@@ -267,6 +275,7 @@ fn plot_grin2_manhattan(
|
|
|
267
275
|
plot_height: u64,
|
|
268
276
|
device_pixel_ratio: f64,
|
|
269
277
|
png_dot_radius: u64,
|
|
278
|
+
log_cutoff: f64,
|
|
270
279
|
) -> Result<(String, InteractiveData), Box<dyn Error>> {
|
|
271
280
|
// ------------------------------------------------
|
|
272
281
|
// 1. Build cumulative chromosome map
|
|
@@ -295,7 +304,7 @@ fn plot_grin2_manhattan(
|
|
|
295
304
|
let mut point_details = Vec::new();
|
|
296
305
|
let mut sig_indices = Vec::new();
|
|
297
306
|
|
|
298
|
-
if let Ok((x, y, c, pd, si)) = grin2_file_read(&grin2_result_file, &chrom_data) {
|
|
307
|
+
if let Ok((x, y, c, pd, si)) = grin2_file_read(&grin2_result_file, &chrom_data, log_cutoff) {
|
|
299
308
|
xs = x;
|
|
300
309
|
ys = y;
|
|
301
310
|
colors_vec = c;
|
|
@@ -304,24 +313,26 @@ fn plot_grin2_manhattan(
|
|
|
304
313
|
}
|
|
305
314
|
|
|
306
315
|
// ------------------------------------------------
|
|
307
|
-
// 3. Y-axis scaling
|
|
316
|
+
// 3. Y-axis scaling (cap at log_cutoff, typically 40)
|
|
308
317
|
// ------------------------------------------------
|
|
309
318
|
let y_padding = png_dot_radius as f64;
|
|
310
319
|
let y_min = 0.0 - y_padding;
|
|
320
|
+
let y_cap = log_cutoff; // typically 40.0. Use the passed log_cutoff value that user will be able to modify in the future
|
|
311
321
|
let y_max = if !ys.is_empty() {
|
|
312
322
|
let max_y = ys.iter().cloned().fold(f64::MIN, f64::max);
|
|
313
|
-
if max_y >
|
|
314
|
-
|
|
315
|
-
let scale_factor_y = target / max_y;
|
|
316
|
-
|
|
323
|
+
if max_y > y_cap {
|
|
324
|
+
// Clamp values above the cap
|
|
317
325
|
for y in ys.iter_mut() {
|
|
318
|
-
*y
|
|
326
|
+
if *y > y_cap {
|
|
327
|
+
*y = y_cap;
|
|
328
|
+
}
|
|
319
329
|
}
|
|
320
330
|
for p in point_details.iter_mut() {
|
|
321
|
-
p.y
|
|
331
|
+
if p.y > y_cap {
|
|
332
|
+
p.y = y_cap;
|
|
333
|
+
}
|
|
322
334
|
}
|
|
323
|
-
|
|
324
|
-
scaled_max + 0.35 + y_padding
|
|
335
|
+
y_cap + 0.35 + y_padding
|
|
325
336
|
} else {
|
|
326
337
|
max_y + 0.35 + y_padding
|
|
327
338
|
}
|
|
@@ -386,7 +397,7 @@ fn plot_grin2_manhattan(
|
|
|
386
397
|
}
|
|
387
398
|
|
|
388
399
|
// ------------------------------------------------
|
|
389
|
-
// 7.
|
|
400
|
+
// 7. Capture high-DPR pixel mapping for the points
|
|
390
401
|
// we do not draw the points with plotters (will use tiny-skia for AA)
|
|
391
402
|
// but use charts.backend_coord to map data->pixel in the high-DPR backend
|
|
392
403
|
// ------------------------------------------------
|
|
@@ -495,6 +506,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
|
|
|
495
506
|
let plot_height = &input_json.plot_height;
|
|
496
507
|
let device_pixel_ratio = &input_json.device_pixel_ratio;
|
|
497
508
|
let png_dot_radius = &input_json.png_dot_radius;
|
|
509
|
+
let log_cutoff = &input_json.log_cutoff;
|
|
498
510
|
if let Ok((base64_string, plot_data)) = plot_grin2_manhattan(
|
|
499
511
|
grin2_file.clone(),
|
|
500
512
|
chrom_size.clone(),
|
|
@@ -502,6 +514,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
|
|
|
502
514
|
plot_height.clone(),
|
|
503
515
|
device_pixel_ratio.clone(),
|
|
504
516
|
png_dot_radius.clone(),
|
|
517
|
+
log_cutoff.clone(),
|
|
505
518
|
) {
|
|
506
519
|
let output = Output {
|
|
507
520
|
png: base64_string,
|
|
@@ -0,0 +1,152 @@
|
|
|
1
|
+
// Syntax: cd .. && cargo build --release && time cat ~/sjpp/test.txt | target/release/aichatbot
|
|
2
|
+
#![allow(non_snake_case)]
|
|
3
|
+
use anyhow::Result;
|
|
4
|
+
use json::JsonValue;
|
|
5
|
+
use schemars::JsonSchema;
|
|
6
|
+
use std::io;
|
|
7
|
+
mod aichatbot; // Importing classification agent from aichatbot.rs
|
|
8
|
+
mod ollama; // Importing custom rig module for invoking ollama server
|
|
9
|
+
mod sjprovider; // Importing custom rig module for invoking SJ GPU server
|
|
10
|
+
mod test_ai; // Test examples for AI chatbot
|
|
11
|
+
|
|
12
|
+
#[tokio::main]
|
|
13
|
+
async fn main() -> Result<()> {
|
|
14
|
+
let mut input = String::new();
|
|
15
|
+
match io::stdin().read_line(&mut input) {
|
|
16
|
+
// Accepting the piped input from nodejs (or command line from testing)
|
|
17
|
+
Ok(_n) => {
|
|
18
|
+
let input_json = json::parse(&input);
|
|
19
|
+
match input_json {
|
|
20
|
+
Ok(json_string) => {
|
|
21
|
+
//println!("json_string:{}", json_string);
|
|
22
|
+
let user_input_json: &JsonValue = &json_string["user_input"];
|
|
23
|
+
let user_input: &str;
|
|
24
|
+
match user_input_json.as_str() {
|
|
25
|
+
Some(inp) => user_input = inp,
|
|
26
|
+
None => panic!("user_input field is missing in input json"),
|
|
27
|
+
}
|
|
28
|
+
|
|
29
|
+
if user_input.len() == 0 {
|
|
30
|
+
panic!("The user input is empty");
|
|
31
|
+
}
|
|
32
|
+
|
|
33
|
+
let binpath_json: &JsonValue = &json_string["binpath"];
|
|
34
|
+
let binpath: &str;
|
|
35
|
+
match binpath_json.as_str() {
|
|
36
|
+
Some(inp) => binpath = inp,
|
|
37
|
+
None => panic!("binpath not found"),
|
|
38
|
+
}
|
|
39
|
+
|
|
40
|
+
let aiRoute_json: &JsonValue = &json_string["aiRoute"];
|
|
41
|
+
let aiRoute_str: &str;
|
|
42
|
+
match aiRoute_json.as_str() {
|
|
43
|
+
Some(inp) => aiRoute_str = inp,
|
|
44
|
+
None => panic!("aiRoute field is missing in input json"),
|
|
45
|
+
}
|
|
46
|
+
let airoute = String::from(binpath) + &"/../../" + &aiRoute_str;
|
|
47
|
+
|
|
48
|
+
let apilink_json: &JsonValue = &json_string["apilink"];
|
|
49
|
+
let apilink: &str;
|
|
50
|
+
match apilink_json.as_str() {
|
|
51
|
+
Some(inp) => apilink = inp,
|
|
52
|
+
None => panic!("apilink field is missing in input json"),
|
|
53
|
+
}
|
|
54
|
+
|
|
55
|
+
let comp_model_name_json: &JsonValue = &json_string["comp_model_name"];
|
|
56
|
+
let comp_model_name: &str;
|
|
57
|
+
match comp_model_name_json.as_str() {
|
|
58
|
+
Some(inp) => comp_model_name = inp,
|
|
59
|
+
None => panic!("comp_model_name field is missing in input json"),
|
|
60
|
+
}
|
|
61
|
+
|
|
62
|
+
let embedding_model_name_json: &JsonValue = &json_string["embedding_model_name"];
|
|
63
|
+
let embedding_model_name: &str;
|
|
64
|
+
match embedding_model_name_json.as_str() {
|
|
65
|
+
Some(inp) => embedding_model_name = inp,
|
|
66
|
+
None => panic!("embedding_model_name field is missing in input json"),
|
|
67
|
+
}
|
|
68
|
+
|
|
69
|
+
let llm_backend_name_json: &JsonValue = &json_string["llm_backend_name"];
|
|
70
|
+
let llm_backend_name: &str;
|
|
71
|
+
match llm_backend_name_json.as_str() {
|
|
72
|
+
Some(inp) => llm_backend_name = inp,
|
|
73
|
+
None => panic!("llm_backend_name field is missing in input json"),
|
|
74
|
+
}
|
|
75
|
+
|
|
76
|
+
let llm_backend_type: aichatbot::llm_backend;
|
|
77
|
+
let mut final_output: Option<String> = None;
|
|
78
|
+
let temperature: f64 = 0.01;
|
|
79
|
+
let max_new_tokens: usize = 512;
|
|
80
|
+
let top_p: f32 = 0.95;
|
|
81
|
+
if llm_backend_name != "ollama" && llm_backend_name != "SJ" {
|
|
82
|
+
panic!(
|
|
83
|
+
"This code currently supports only Ollama and SJ provider. llm_backend_name must be \"ollama\" or \"SJ\""
|
|
84
|
+
);
|
|
85
|
+
} else if llm_backend_name == "ollama".to_string() {
|
|
86
|
+
llm_backend_type = aichatbot::llm_backend::Ollama();
|
|
87
|
+
// Initialize Ollama client
|
|
88
|
+
let ollama_client = ollama::Client::builder()
|
|
89
|
+
.base_url(apilink)
|
|
90
|
+
.build()
|
|
91
|
+
.expect("Ollama server not found");
|
|
92
|
+
let embedding_model = ollama_client.embedding_model(embedding_model_name);
|
|
93
|
+
let comp_model = ollama_client.completion_model(comp_model_name);
|
|
94
|
+
final_output = Some(
|
|
95
|
+
aichatbot::classify_query_by_dataset_type(
|
|
96
|
+
user_input,
|
|
97
|
+
comp_model.clone(),
|
|
98
|
+
embedding_model.clone(),
|
|
99
|
+
&llm_backend_type,
|
|
100
|
+
temperature,
|
|
101
|
+
max_new_tokens,
|
|
102
|
+
top_p,
|
|
103
|
+
&airoute,
|
|
104
|
+
)
|
|
105
|
+
.await,
|
|
106
|
+
);
|
|
107
|
+
} else if llm_backend_name == "SJ".to_string() {
|
|
108
|
+
llm_backend_type = aichatbot::llm_backend::Sj();
|
|
109
|
+
// Initialize Sj provider client
|
|
110
|
+
let sj_client = sjprovider::Client::builder()
|
|
111
|
+
.base_url(apilink)
|
|
112
|
+
.build()
|
|
113
|
+
.expect("SJ server not found");
|
|
114
|
+
let embedding_model = sj_client.embedding_model(embedding_model_name);
|
|
115
|
+
let comp_model = sj_client.completion_model(comp_model_name);
|
|
116
|
+
final_output = Some(
|
|
117
|
+
aichatbot::classify_query_by_dataset_type(
|
|
118
|
+
user_input,
|
|
119
|
+
comp_model.clone(),
|
|
120
|
+
embedding_model.clone(),
|
|
121
|
+
&llm_backend_type,
|
|
122
|
+
temperature,
|
|
123
|
+
max_new_tokens,
|
|
124
|
+
top_p,
|
|
125
|
+
&airoute,
|
|
126
|
+
)
|
|
127
|
+
.await,
|
|
128
|
+
);
|
|
129
|
+
}
|
|
130
|
+
|
|
131
|
+
match final_output {
|
|
132
|
+
Some(fin_out) => {
|
|
133
|
+
println!("{{\"{}\":{}}}", "route", fin_out);
|
|
134
|
+
}
|
|
135
|
+
None => {
|
|
136
|
+
println!("{{\"{}\":\"{}\"}}", "route", "unknown");
|
|
137
|
+
}
|
|
138
|
+
}
|
|
139
|
+
}
|
|
140
|
+
Err(error) => println!("Incorrect json:{}", error),
|
|
141
|
+
}
|
|
142
|
+
}
|
|
143
|
+
Err(error) => println!("Piping error: {}", error),
|
|
144
|
+
}
|
|
145
|
+
Ok(())
|
|
146
|
+
}
|
|
147
|
+
|
|
148
|
+
#[derive(Debug, JsonSchema)]
|
|
149
|
+
#[allow(dead_code)]
|
|
150
|
+
struct OutputJson {
|
|
151
|
+
pub answer: String,
|
|
152
|
+
}
|
|
@@ -0,0 +1,201 @@
|
|
|
1
|
+
// Syntax: cd .. && cargo build --release && time cat ~/sjpp/test.txt | target/release/aichatbot
|
|
2
|
+
#![allow(non_snake_case)]
|
|
3
|
+
use crate::aichatbot::AiJsonFormat;
|
|
4
|
+
use anyhow::Result;
|
|
5
|
+
use json::JsonValue;
|
|
6
|
+
use std::fs;
|
|
7
|
+
use std::io;
|
|
8
|
+
use std::path::Path;
|
|
9
|
+
mod aichatbot; // Get summary agent
|
|
10
|
+
|
|
11
|
+
mod ollama; // Importing custom rig module for invoking ollama server
|
|
12
|
+
mod sjprovider; // Importing custom rig module for invoking SJ GPU server
|
|
13
|
+
mod test_ai; // Test examples for AI chatbot
|
|
14
|
+
|
|
15
|
+
#[tokio::main]
|
|
16
|
+
async fn main() -> Result<()> {
|
|
17
|
+
let mut input = String::new();
|
|
18
|
+
match io::stdin().read_line(&mut input) {
|
|
19
|
+
// Accepting the piped input from nodejs (or command line from testing)
|
|
20
|
+
Ok(_n) => {
|
|
21
|
+
let input_json = json::parse(&input);
|
|
22
|
+
match input_json {
|
|
23
|
+
Ok(json_string) => {
|
|
24
|
+
//println!("json_string:{}", json_string);
|
|
25
|
+
let user_input_json: &JsonValue = &json_string["user_input"];
|
|
26
|
+
let user_input: &str;
|
|
27
|
+
match user_input_json.as_str() {
|
|
28
|
+
Some(inp) => user_input = inp,
|
|
29
|
+
None => panic!("user_input field is missing in input json"),
|
|
30
|
+
}
|
|
31
|
+
|
|
32
|
+
let dataset_db_json: &JsonValue = &json_string["dataset_db"];
|
|
33
|
+
let dataset_db_str: &str;
|
|
34
|
+
match dataset_db_json.as_str() {
|
|
35
|
+
Some(inp) => dataset_db_str = inp,
|
|
36
|
+
None => panic!("dataset_db field is missing in input json"),
|
|
37
|
+
}
|
|
38
|
+
|
|
39
|
+
let genedb_json: &JsonValue = &json_string["genedb"];
|
|
40
|
+
let genedb_str: &str;
|
|
41
|
+
match genedb_json.as_str() {
|
|
42
|
+
Some(inp) => genedb_str = inp,
|
|
43
|
+
None => panic!("genedb field is missing in input json"),
|
|
44
|
+
}
|
|
45
|
+
|
|
46
|
+
if user_input.len() == 0 {
|
|
47
|
+
panic!("The user input is empty");
|
|
48
|
+
}
|
|
49
|
+
|
|
50
|
+
let tpmasterdir_json: &JsonValue = &json_string["tpmasterdir"];
|
|
51
|
+
let tpmasterdir: &str;
|
|
52
|
+
match tpmasterdir_json.as_str() {
|
|
53
|
+
Some(inp) => tpmasterdir = inp,
|
|
54
|
+
None => panic!("tpmasterdir not found"),
|
|
55
|
+
}
|
|
56
|
+
|
|
57
|
+
let binpath_json: &JsonValue = &json_string["binpath"];
|
|
58
|
+
let binpath: &str;
|
|
59
|
+
match binpath_json.as_str() {
|
|
60
|
+
Some(inp) => binpath = inp,
|
|
61
|
+
None => panic!("binpath not found"),
|
|
62
|
+
}
|
|
63
|
+
|
|
64
|
+
let ai_json_file_json: &JsonValue = &json_string["aifiles"];
|
|
65
|
+
let ai_json_file: String;
|
|
66
|
+
match ai_json_file_json.as_str() {
|
|
67
|
+
Some(inp) => ai_json_file = String::from(binpath) + &"/../../" + &inp,
|
|
68
|
+
None => {
|
|
69
|
+
panic!("ai json file not found")
|
|
70
|
+
}
|
|
71
|
+
}
|
|
72
|
+
|
|
73
|
+
let ai_json_file = Path::new(&ai_json_file);
|
|
74
|
+
let ai_json_file_path;
|
|
75
|
+
let current_dir = std::env::current_dir().unwrap();
|
|
76
|
+
match ai_json_file.canonicalize() {
|
|
77
|
+
Ok(p) => ai_json_file_path = p,
|
|
78
|
+
Err(_) => {
|
|
79
|
+
panic!(
|
|
80
|
+
"AI JSON file path not found:{:?}, current directory:{:?}",
|
|
81
|
+
ai_json_file, current_dir
|
|
82
|
+
)
|
|
83
|
+
}
|
|
84
|
+
}
|
|
85
|
+
|
|
86
|
+
// Read the file
|
|
87
|
+
let ai_data = fs::read_to_string(ai_json_file_path).unwrap();
|
|
88
|
+
|
|
89
|
+
// Parse the JSON data
|
|
90
|
+
let ai_json: AiJsonFormat =
|
|
91
|
+
serde_json::from_str(&ai_data).expect("AI JSON file does not have the correct format");
|
|
92
|
+
|
|
93
|
+
let genedb = String::from(tpmasterdir) + &"/" + &genedb_str;
|
|
94
|
+
let dataset_db = String::from(tpmasterdir) + &"/" + &dataset_db_str;
|
|
95
|
+
|
|
96
|
+
let apilink_json: &JsonValue = &json_string["apilink"];
|
|
97
|
+
let apilink: &str;
|
|
98
|
+
match apilink_json.as_str() {
|
|
99
|
+
Some(inp) => apilink = inp,
|
|
100
|
+
None => panic!("apilink field is missing in input json"),
|
|
101
|
+
}
|
|
102
|
+
|
|
103
|
+
let comp_model_name_json: &JsonValue = &json_string["comp_model_name"];
|
|
104
|
+
let comp_model_name: &str;
|
|
105
|
+
match comp_model_name_json.as_str() {
|
|
106
|
+
Some(inp) => comp_model_name = inp,
|
|
107
|
+
None => panic!("comp_model_name field is missing in input json"),
|
|
108
|
+
}
|
|
109
|
+
|
|
110
|
+
let embedding_model_name_json: &JsonValue = &json_string["embedding_model_name"];
|
|
111
|
+
let embedding_model_name: &str;
|
|
112
|
+
match embedding_model_name_json.as_str() {
|
|
113
|
+
Some(inp) => embedding_model_name = inp,
|
|
114
|
+
None => panic!("embedding_model_name field is missing in input json"),
|
|
115
|
+
}
|
|
116
|
+
|
|
117
|
+
let llm_backend_name_json: &JsonValue = &json_string["llm_backend_name"];
|
|
118
|
+
let llm_backend_name: &str;
|
|
119
|
+
match llm_backend_name_json.as_str() {
|
|
120
|
+
Some(inp) => llm_backend_name = inp,
|
|
121
|
+
None => panic!("llm_backend_name field is missing in input json"),
|
|
122
|
+
}
|
|
123
|
+
|
|
124
|
+
let llm_backend_type: aichatbot::llm_backend;
|
|
125
|
+
let mut final_output: Option<String> = None;
|
|
126
|
+
let temperature: f64 = 0.01;
|
|
127
|
+
let max_new_tokens: usize = 512;
|
|
128
|
+
let top_p: f32 = 0.95;
|
|
129
|
+
let testing = false; // This variable is always false in production, this is true in test_ai.rs for testing code
|
|
130
|
+
if llm_backend_name != "ollama" && llm_backend_name != "SJ" {
|
|
131
|
+
panic!(
|
|
132
|
+
"This code currently supports only Ollama and SJ provider. llm_backend_name must be \"ollama\" or \"SJ\""
|
|
133
|
+
);
|
|
134
|
+
} else if llm_backend_name == "ollama".to_string() {
|
|
135
|
+
llm_backend_type = aichatbot::llm_backend::Ollama();
|
|
136
|
+
// Initialize Ollama client
|
|
137
|
+
let ollama_client = ollama::Client::builder()
|
|
138
|
+
.base_url(apilink)
|
|
139
|
+
.build()
|
|
140
|
+
.expect("Ollama server not found");
|
|
141
|
+
let embedding_model = ollama_client.embedding_model(embedding_model_name);
|
|
142
|
+
let comp_model = ollama_client.completion_model(comp_model_name);
|
|
143
|
+
final_output = Some(
|
|
144
|
+
aichatbot::extract_summary_information(
|
|
145
|
+
user_input,
|
|
146
|
+
comp_model,
|
|
147
|
+
embedding_model,
|
|
148
|
+
&llm_backend_type,
|
|
149
|
+
temperature,
|
|
150
|
+
max_new_tokens,
|
|
151
|
+
top_p,
|
|
152
|
+
&dataset_db,
|
|
153
|
+
&genedb,
|
|
154
|
+
&ai_json,
|
|
155
|
+
testing,
|
|
156
|
+
)
|
|
157
|
+
.await,
|
|
158
|
+
);
|
|
159
|
+
} else if llm_backend_name == "SJ".to_string() {
|
|
160
|
+
llm_backend_type = aichatbot::llm_backend::Sj();
|
|
161
|
+
// Initialize Sj provider client
|
|
162
|
+
let sj_client = sjprovider::Client::builder()
|
|
163
|
+
.base_url(apilink)
|
|
164
|
+
.build()
|
|
165
|
+
.expect("SJ server not found");
|
|
166
|
+
let embedding_model = sj_client.embedding_model(embedding_model_name);
|
|
167
|
+
let comp_model = sj_client.completion_model(comp_model_name);
|
|
168
|
+
final_output = Some(
|
|
169
|
+
aichatbot::extract_summary_information(
|
|
170
|
+
user_input,
|
|
171
|
+
comp_model,
|
|
172
|
+
embedding_model,
|
|
173
|
+
&llm_backend_type,
|
|
174
|
+
temperature,
|
|
175
|
+
max_new_tokens,
|
|
176
|
+
top_p,
|
|
177
|
+
&dataset_db,
|
|
178
|
+
&genedb,
|
|
179
|
+
&ai_json,
|
|
180
|
+
testing,
|
|
181
|
+
)
|
|
182
|
+
.await,
|
|
183
|
+
);
|
|
184
|
+
}
|
|
185
|
+
|
|
186
|
+
match final_output {
|
|
187
|
+
Some(fin_out) => {
|
|
188
|
+
println!("final_output:{:?}", fin_out.replace("\\", ""));
|
|
189
|
+
}
|
|
190
|
+
None => {
|
|
191
|
+
println!("final_output:{{\"{}\":\"{}\"}}", "action", "unknown");
|
|
192
|
+
}
|
|
193
|
+
}
|
|
194
|
+
}
|
|
195
|
+
Err(error) => println!("Incorrect json:{}", error),
|
|
196
|
+
}
|
|
197
|
+
}
|
|
198
|
+
Err(error) => println!("Piping error: {}", error),
|
|
199
|
+
}
|
|
200
|
+
Ok(())
|
|
201
|
+
}
|