@sjcrh/proteinpaint-rust 2.169.0 → 2.170.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/Cargo.toml +8 -4
- package/package.json +1 -1
- package/src/aichatbot.rs +231 -254
- package/src/manhattan_plot.rs +31 -18
- package/src/query_classification.rs +152 -0
- package/src/summary_agent.rs +201 -0
- package/src/test_ai.rs +79 -72
package/Cargo.toml
CHANGED
|
@@ -123,10 +123,14 @@ path="src/cerno.rs"
|
|
|
123
123
|
name="readH5"
|
|
124
124
|
path="src/readH5.rs"
|
|
125
125
|
|
|
126
|
-
[[bin]]
|
|
127
|
-
name="aichatbot"
|
|
128
|
-
path="src/aichatbot.rs"
|
|
129
|
-
|
|
130
126
|
[[bin]]
|
|
131
127
|
name="manhattan_plot"
|
|
132
128
|
path="src/manhattan_plot.rs"
|
|
129
|
+
|
|
130
|
+
[[bin]]
|
|
131
|
+
name="query_classification"
|
|
132
|
+
path="src/query_classification.rs"
|
|
133
|
+
|
|
134
|
+
[[bin]]
|
|
135
|
+
name="summary_agent"
|
|
136
|
+
path="src/summary_agent.rs"
|
package/package.json
CHANGED
package/src/aichatbot.rs
CHANGED
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
// Syntax: cd .. && cargo build --release && time cat ~/sjpp/test.txt | target/release/aichatbot
|
|
2
2
|
#![allow(non_snake_case)]
|
|
3
|
-
use anyhow::Result;
|
|
4
|
-
use json::JsonValue;
|
|
3
|
+
//use anyhow::Result;
|
|
4
|
+
//use json::JsonValue;
|
|
5
5
|
use r2d2_sqlite::SqliteConnectionManager;
|
|
6
6
|
use rig::agent::AgentBuilder;
|
|
7
7
|
use rig::completion::Prompt;
|
|
@@ -11,69 +11,52 @@ use schemars::JsonSchema;
|
|
|
11
11
|
use serde_json::{Map, Value, json};
|
|
12
12
|
use std::collections::HashMap;
|
|
13
13
|
use std::fs;
|
|
14
|
-
use std::io;
|
|
15
|
-
use std::path::Path;
|
|
16
|
-
mod ollama; // Importing custom rig module for invoking ollama server
|
|
17
|
-
mod sjprovider; // Importing custom rig module for invoking SJ GPU server
|
|
18
|
-
|
|
19
|
-
mod test_ai; // Test examples for AI chatbot
|
|
20
14
|
|
|
21
15
|
// Struct for intaking data from dataset json
|
|
22
16
|
#[derive(PartialEq, Debug, Clone, schemars::JsonSchema, serde::Serialize, serde::Deserialize)]
|
|
23
17
|
pub struct AiJsonFormat {
|
|
24
|
-
hasGeneExpression: bool,
|
|
25
|
-
db: String, // Dataset db
|
|
26
|
-
genedb: String, // Gene db
|
|
27
|
-
charts: Vec<
|
|
28
|
-
}
|
|
29
|
-
|
|
30
|
-
#[derive(PartialEq, Debug, Clone, schemars::JsonSchema, serde::Serialize, serde::Deserialize)]
|
|
31
|
-
enum Charts {
|
|
32
|
-
// More chart types will be added here later
|
|
33
|
-
Summary(TrainTestDataSummary),
|
|
34
|
-
DE(TrainTestDataDE),
|
|
18
|
+
pub hasGeneExpression: bool,
|
|
19
|
+
pub db: String, // Dataset db
|
|
20
|
+
pub genedb: String, // Gene db
|
|
21
|
+
pub charts: Vec<TrainTestData>,
|
|
35
22
|
}
|
|
36
23
|
|
|
37
24
|
#[derive(PartialEq, Debug, Clone, schemars::JsonSchema, serde::Serialize, serde::Deserialize)]
|
|
38
|
-
struct
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
25
|
+
pub struct TrainTestData {
|
|
26
|
+
pub r#type: String,
|
|
27
|
+
pub SystemPrompt: String,
|
|
28
|
+
pub TrainingData: Vec<QuestionAnswer>,
|
|
29
|
+
pub TestData: Vec<QuestionAnswer>,
|
|
42
30
|
}
|
|
43
31
|
|
|
44
32
|
#[derive(PartialEq, Debug, Clone, schemars::JsonSchema, serde::Serialize, serde::Deserialize)]
|
|
45
|
-
struct
|
|
46
|
-
question: String,
|
|
47
|
-
answer:
|
|
33
|
+
pub struct QuestionAnswer {
|
|
34
|
+
pub question: String,
|
|
35
|
+
pub answer: AnswerFormat,
|
|
48
36
|
}
|
|
49
37
|
|
|
50
38
|
#[derive(PartialEq, Debug, Clone, schemars::JsonSchema, serde::Serialize, serde::Deserialize)]
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
#[derive(PartialEq, Debug, Clone, schemars::JsonSchema, serde::Serialize, serde::Deserialize)]
|
|
58
|
-
struct QuestionAnswerDE {
|
|
59
|
-
question: String,
|
|
60
|
-
answer: DEType,
|
|
39
|
+
pub enum AnswerFormat {
|
|
40
|
+
#[allow(non_camel_case_types)]
|
|
41
|
+
summary_type(SummaryType),
|
|
42
|
+
#[allow(non_camel_case_types)]
|
|
43
|
+
DE_type(DEType),
|
|
61
44
|
}
|
|
62
45
|
|
|
63
46
|
#[derive(PartialEq, Debug, Clone, schemars::JsonSchema, serde::Serialize, serde::Deserialize)]
|
|
64
|
-
struct DEType {
|
|
47
|
+
pub struct DEType {
|
|
65
48
|
action: String,
|
|
66
49
|
DE_output: DETerms,
|
|
67
50
|
}
|
|
68
51
|
|
|
69
52
|
#[derive(PartialEq, Debug, Clone, schemars::JsonSchema, serde::Serialize, serde::Deserialize)]
|
|
70
|
-
struct DETerms {
|
|
53
|
+
pub struct DETerms {
|
|
71
54
|
group1: GroupType,
|
|
72
55
|
group2: GroupType,
|
|
73
56
|
}
|
|
74
57
|
|
|
75
58
|
#[derive(PartialEq, Debug, Clone, schemars::JsonSchema, serde::Serialize, serde::Deserialize)]
|
|
76
|
-
struct GroupType {
|
|
59
|
+
pub struct GroupType {
|
|
77
60
|
name: String,
|
|
78
61
|
}
|
|
79
62
|
|
|
@@ -86,204 +69,188 @@ pub enum llm_backend {
|
|
|
86
69
|
|
|
87
70
|
#[derive(Debug, JsonSchema)]
|
|
88
71
|
#[allow(dead_code)]
|
|
89
|
-
struct OutputJson {
|
|
72
|
+
pub struct OutputJson {
|
|
90
73
|
pub answer: String,
|
|
91
74
|
}
|
|
92
75
|
|
|
93
|
-
|
|
94
|
-
async fn main() -> Result<()> {
|
|
95
|
-
let mut input = String::new();
|
|
96
|
-
match io::stdin().read_line(&mut input) {
|
|
97
|
-
// Accepting the piped input from nodejs (or command line from testing)
|
|
98
|
-
Ok(_n) => {
|
|
99
|
-
let input_json = json::parse(&input);
|
|
100
|
-
match input_json {
|
|
101
|
-
Ok(json_string) => {
|
|
102
|
-
//println!("json_string:{}", json_string);
|
|
103
|
-
let user_input_json: &JsonValue = &json_string["user_input"];
|
|
104
|
-
let user_input: &str;
|
|
105
|
-
match user_input_json.as_str() {
|
|
106
|
-
Some(inp) => user_input = inp,
|
|
107
|
-
None => panic!("user_input field is missing in input json"),
|
|
108
|
-
}
|
|
109
|
-
|
|
110
|
-
let
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
}
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
let
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
let
|
|
143
|
-
let
|
|
144
|
-
match
|
|
145
|
-
Some(inp) =>
|
|
146
|
-
None =>
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
let ai_json_file
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
let
|
|
183
|
-
let
|
|
184
|
-
match
|
|
185
|
-
Some(inp) =>
|
|
186
|
-
None => panic!("
|
|
187
|
-
}
|
|
188
|
-
|
|
189
|
-
let
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
|
|
193
|
-
|
|
194
|
-
|
|
195
|
-
|
|
196
|
-
let
|
|
197
|
-
let
|
|
198
|
-
|
|
199
|
-
|
|
200
|
-
|
|
201
|
-
|
|
202
|
-
|
|
203
|
-
|
|
204
|
-
|
|
205
|
-
|
|
206
|
-
|
|
207
|
-
|
|
208
|
-
|
|
209
|
-
|
|
210
|
-
|
|
211
|
-
|
|
212
|
-
|
|
213
|
-
|
|
214
|
-
|
|
215
|
-
|
|
216
|
-
|
|
217
|
-
|
|
218
|
-
|
|
219
|
-
|
|
220
|
-
|
|
221
|
-
|
|
222
|
-
|
|
223
|
-
|
|
224
|
-
|
|
225
|
-
|
|
226
|
-
|
|
227
|
-
|
|
228
|
-
|
|
229
|
-
|
|
230
|
-
|
|
231
|
-
|
|
232
|
-
|
|
233
|
-
|
|
234
|
-
|
|
235
|
-
|
|
236
|
-
|
|
237
|
-
|
|
238
|
-
|
|
239
|
-
|
|
240
|
-
|
|
241
|
-
|
|
242
|
-
|
|
243
|
-
|
|
244
|
-
|
|
245
|
-
|
|
246
|
-
|
|
247
|
-
|
|
248
|
-
|
|
249
|
-
|
|
250
|
-
|
|
251
|
-
|
|
252
|
-
|
|
253
|
-
|
|
254
|
-
|
|
255
|
-
|
|
256
|
-
|
|
257
|
-
|
|
258
|
-
|
|
259
|
-
|
|
260
|
-
|
|
261
|
-
|
|
262
|
-
|
|
263
|
-
|
|
264
|
-
|
|
265
|
-
|
|
266
|
-
|
|
267
|
-
|
|
268
|
-
|
|
269
|
-
|
|
270
|
-
match final_output {
|
|
271
|
-
Some(fin_out) => {
|
|
272
|
-
println!("final_output:{:?}", fin_out.replace("\\", ""));
|
|
273
|
-
}
|
|
274
|
-
None => {
|
|
275
|
-
println!("final_output:{{\"{}\":\"{}\"}}", "action", "unknown");
|
|
276
|
-
}
|
|
277
|
-
}
|
|
278
|
-
}
|
|
279
|
-
Err(error) => println!("Incorrect json:{}", error),
|
|
280
|
-
}
|
|
281
|
-
}
|
|
282
|
-
Err(error) => println!("Piping error: {}", error),
|
|
283
|
-
}
|
|
284
|
-
Ok(())
|
|
285
|
-
}
|
|
76
|
+
//#[tokio::main]
|
|
77
|
+
//async fn main() -> Result<()> {
|
|
78
|
+
// let mut input = String::new();
|
|
79
|
+
// match io::stdin().read_line(&mut input) {
|
|
80
|
+
// // Accepting the piped input from nodejs (or command line from testing)
|
|
81
|
+
// Ok(_n) => {
|
|
82
|
+
// let input_json = json::parse(&input);
|
|
83
|
+
// match input_json {
|
|
84
|
+
// Ok(json_string) => {
|
|
85
|
+
// //println!("json_string:{}", json_string);
|
|
86
|
+
// let user_input_json: &JsonValue = &json_string["user_input"];
|
|
87
|
+
// let user_input: &str;
|
|
88
|
+
// match user_input_json.as_str() {
|
|
89
|
+
// Some(inp) => user_input = inp,
|
|
90
|
+
// None => panic!("user_input field is missing in input json"),
|
|
91
|
+
// }
|
|
92
|
+
// let dataset_db_json: &JsonValue = &json_string["dataset_db"];
|
|
93
|
+
// let dataset_db_str: &str;
|
|
94
|
+
// match dataset_db_json.as_str() {
|
|
95
|
+
// Some(inp) => dataset_db_str = inp,
|
|
96
|
+
// None => panic!("dataset_db field is missing in input json"),
|
|
97
|
+
// }
|
|
98
|
+
// let genedb_json: &JsonValue = &json_string["genedb"];
|
|
99
|
+
// let genedb_str: &str;
|
|
100
|
+
// match genedb_json.as_str() {
|
|
101
|
+
// Some(inp) => genedb_str = inp,
|
|
102
|
+
// None => panic!("genedb field is missing in input json"),
|
|
103
|
+
// }
|
|
104
|
+
// let aiRoute_json: &JsonValue = &json_string["aiRoute"];
|
|
105
|
+
// let aiRoute_str: &str;
|
|
106
|
+
// match aiRoute_json.as_str() {
|
|
107
|
+
// Some(inp) => aiRoute_str = inp,
|
|
108
|
+
// None => panic!("aiRoute field is missing in input json"),
|
|
109
|
+
// }
|
|
110
|
+
// if user_input.len() == 0 {
|
|
111
|
+
// panic!("The user input is empty");
|
|
112
|
+
// }
|
|
113
|
+
// let tpmasterdir_json: &JsonValue = &json_string["tpmasterdir"];
|
|
114
|
+
// let tpmasterdir: &str;
|
|
115
|
+
// match tpmasterdir_json.as_str() {
|
|
116
|
+
// Some(inp) => tpmasterdir = inp,
|
|
117
|
+
// None => panic!("tpmasterdir not found"),
|
|
118
|
+
// }
|
|
119
|
+
// let binpath_json: &JsonValue = &json_string["binpath"];
|
|
120
|
+
// let binpath: &str;
|
|
121
|
+
// match binpath_json.as_str() {
|
|
122
|
+
// Some(inp) => binpath = inp,
|
|
123
|
+
// None => panic!("binpath not found"),
|
|
124
|
+
// }
|
|
125
|
+
// let ai_json_file_json: &JsonValue = &json_string["aifiles"];
|
|
126
|
+
// let ai_json_file: String;
|
|
127
|
+
// match ai_json_file_json.as_str() {
|
|
128
|
+
// Some(inp) => ai_json_file = String::from(binpath) + &"/../../" + &inp,
|
|
129
|
+
// None => {
|
|
130
|
+
// panic!("ai json file not found")
|
|
131
|
+
// }
|
|
132
|
+
// }
|
|
133
|
+
// let ai_json_file = Path::new(&ai_json_file);
|
|
134
|
+
// let ai_json_file_path;
|
|
135
|
+
// let current_dir = std::env::current_dir().unwrap();
|
|
136
|
+
// match ai_json_file.canonicalize() {
|
|
137
|
+
// Ok(p) => ai_json_file_path = p,
|
|
138
|
+
// Err(_) => {
|
|
139
|
+
// panic!(
|
|
140
|
+
// "AI JSON file path not found:{:?}, current directory:{:?}",
|
|
141
|
+
// ai_json_file, current_dir
|
|
142
|
+
// )
|
|
143
|
+
// }
|
|
144
|
+
// }
|
|
145
|
+
// // Read the file
|
|
146
|
+
// let ai_data = fs::read_to_string(ai_json_file_path).unwrap();
|
|
147
|
+
// // Parse the JSON data
|
|
148
|
+
// let ai_json: AiJsonFormat =
|
|
149
|
+
// serde_json::from_str(&ai_data).expect("AI JSON file does not have the correct format");
|
|
150
|
+
// let genedb = String::from(tpmasterdir) + &"/" + &genedb_str;
|
|
151
|
+
// let dataset_db = String::from(tpmasterdir) + &"/" + &dataset_db_str;
|
|
152
|
+
// let airoute = String::from(binpath) + &"/../../" + &aiRoute_str;
|
|
153
|
+
// let apilink_json: &JsonValue = &json_string["apilink"];
|
|
154
|
+
// let apilink: &str;
|
|
155
|
+
// match apilink_json.as_str() {
|
|
156
|
+
// Some(inp) => apilink = inp,
|
|
157
|
+
// None => panic!("apilink field is missing in input json"),
|
|
158
|
+
// }
|
|
159
|
+
// let comp_model_name_json: &JsonValue = &json_string["comp_model_name"];
|
|
160
|
+
// let comp_model_name: &str;
|
|
161
|
+
// match comp_model_name_json.as_str() {
|
|
162
|
+
// Some(inp) => comp_model_name = inp,
|
|
163
|
+
// None => panic!("comp_model_name field is missing in input json"),
|
|
164
|
+
// }
|
|
165
|
+
// let embedding_model_name_json: &JsonValue = &json_string["embedding_model_name"];
|
|
166
|
+
// let embedding_model_name: &str;
|
|
167
|
+
// match embedding_model_name_json.as_str() {
|
|
168
|
+
// Some(inp) => embedding_model_name = inp,
|
|
169
|
+
// None => panic!("embedding_model_name field is missing in input json"),
|
|
170
|
+
// }
|
|
171
|
+
// let llm_backend_name_json: &JsonValue = &json_string["llm_backend_name"];
|
|
172
|
+
// let llm_backend_name: &str;
|
|
173
|
+
// match llm_backend_name_json.as_str() {
|
|
174
|
+
// Some(inp) => llm_backend_name = inp,
|
|
175
|
+
// None => panic!("llm_backend_name field is missing in input json"),
|
|
176
|
+
// }
|
|
177
|
+
// let llm_backend_type: llm_backend;
|
|
178
|
+
// let mut final_output: Option<String> = None;
|
|
179
|
+
// let temperature: f64 = 0.01;
|
|
180
|
+
// let max_new_tokens: usize = 512;
|
|
181
|
+
// let top_p: f32 = 0.95;
|
|
182
|
+
// let testing = false; // This variable is always false in production, this is true in test_ai.rs for testing code
|
|
183
|
+
// if llm_backend_name != "ollama" && llm_backend_name != "SJ" {
|
|
184
|
+
// panic!(
|
|
185
|
+
// "This code currently supports only Ollama and SJ provider. llm_backend_name must be \"ollama\" or \"SJ\""
|
|
186
|
+
// );
|
|
187
|
+
// } else if llm_backend_name == "ollama".to_string() {
|
|
188
|
+
// llm_backend_type = llm_backend::Ollama();
|
|
189
|
+
// // Initialize Ollama client
|
|
190
|
+
// let ollama_client = ollama::Client::builder()
|
|
191
|
+
// .base_url(apilink)
|
|
192
|
+
// .build()
|
|
193
|
+
// .expect("Ollama server not found");
|
|
194
|
+
// let embedding_model = ollama_client.embedding_model(embedding_model_name);
|
|
195
|
+
// let comp_model = ollama_client.completion_model(comp_model_name);
|
|
196
|
+
// final_output = run_pipeline(
|
|
197
|
+
// user_input,
|
|
198
|
+
// comp_model,
|
|
199
|
+
// embedding_model,
|
|
200
|
+
// llm_backend_type,
|
|
201
|
+
// temperature,
|
|
202
|
+
// max_new_tokens,
|
|
203
|
+
// top_p,
|
|
204
|
+
// &dataset_db,
|
|
205
|
+
// &genedb,
|
|
206
|
+
// &ai_json,
|
|
207
|
+
// &airoute,
|
|
208
|
+
// testing,
|
|
209
|
+
// )
|
|
210
|
+
// .await;
|
|
211
|
+
// } else if llm_backend_name == "SJ".to_string() {
|
|
212
|
+
// llm_backend_type = llm_backend::Sj();
|
|
213
|
+
// // Initialize Sj provider client
|
|
214
|
+
// let sj_client = sjprovider::Client::builder()
|
|
215
|
+
// .base_url(apilink)
|
|
216
|
+
// .build()
|
|
217
|
+
// .expect("SJ server not found");
|
|
218
|
+
// let embedding_model = sj_client.embedding_model(embedding_model_name);
|
|
219
|
+
// let comp_model = sj_client.completion_model(comp_model_name);
|
|
220
|
+
// final_output = run_pipeline(
|
|
221
|
+
// user_input,
|
|
222
|
+
// comp_model,
|
|
223
|
+
// embedding_model,
|
|
224
|
+
// llm_backend_type,
|
|
225
|
+
// temperature,
|
|
226
|
+
// max_new_tokens,
|
|
227
|
+
// top_p,
|
|
228
|
+
// &dataset_db,
|
|
229
|
+
// &genedb,
|
|
230
|
+
// &ai_json,
|
|
231
|
+
// &airoute,
|
|
232
|
+
// testing,
|
|
233
|
+
// )
|
|
234
|
+
// .await;
|
|
235
|
+
// }
|
|
236
|
+
// match final_output {
|
|
237
|
+
// Some(fin_out) => {
|
|
238
|
+
// println!("final_output:{:?}", fin_out.replace("\\", ""));
|
|
239
|
+
// }
|
|
240
|
+
// None => {
|
|
241
|
+
// println!("final_output:{{\"{}\":\"{}\"}}", "action", "unknown");
|
|
242
|
+
// }
|
|
243
|
+
// }
|
|
244
|
+
// }
|
|
245
|
+
// Err(error) => println!("Incorrect json:{}", error),
|
|
246
|
+
// }
|
|
247
|
+
// }
|
|
248
|
+
// Err(error) => println!("Piping error: {}", error),
|
|
249
|
+
// }
|
|
250
|
+
// Ok(())
|
|
251
|
+
//}
|
|
286
252
|
|
|
253
|
+
#[allow(dead_code)]
|
|
287
254
|
pub async fn run_pipeline(
|
|
288
255
|
user_input: &str,
|
|
289
256
|
comp_model: impl rig::completion::CompletionModel + 'static,
|
|
@@ -429,7 +396,8 @@ pub async fn run_pipeline(
|
|
|
429
396
|
Some(final_output)
|
|
430
397
|
}
|
|
431
398
|
|
|
432
|
-
|
|
399
|
+
#[allow(dead_code)]
|
|
400
|
+
pub async fn classify_query_by_dataset_type(
|
|
433
401
|
user_input: &str,
|
|
434
402
|
comp_model: impl rig::completion::CompletionModel + 'static,
|
|
435
403
|
_embedding_model: impl rig::embeddings::EmbeddingModel + 'static,
|
|
@@ -449,9 +417,12 @@ async fn classify_query_by_dataset_type(
|
|
|
449
417
|
let mut contents = String::from("");
|
|
450
418
|
|
|
451
419
|
if let Some(object) = ai_json.as_object() {
|
|
452
|
-
|
|
453
|
-
|
|
454
|
-
|
|
420
|
+
contents = object["general"].to_string();
|
|
421
|
+
for (key, value) in object {
|
|
422
|
+
if key != "general" {
|
|
423
|
+
contents += &value.as_str().unwrap();
|
|
424
|
+
contents += "---"; // Adding delimiter
|
|
425
|
+
}
|
|
455
426
|
}
|
|
456
427
|
}
|
|
457
428
|
|
|
@@ -568,6 +539,7 @@ struct DEOutput {
|
|
|
568
539
|
group2: Group,
|
|
569
540
|
}
|
|
570
541
|
|
|
542
|
+
#[allow(dead_code)]
|
|
571
543
|
#[allow(non_snake_case)]
|
|
572
544
|
async fn extract_DE_search_terms_from_query(
|
|
573
545
|
user_input: &str,
|
|
@@ -841,7 +813,7 @@ async fn parse_dataset_db(db: &str) -> (Vec<String>, Vec<DbRows>) {
|
|
|
841
813
|
(rag_docs, db_vec)
|
|
842
814
|
}
|
|
843
815
|
|
|
844
|
-
async fn extract_summary_information(
|
|
816
|
+
pub async fn extract_summary_information(
|
|
845
817
|
user_input: &str,
|
|
846
818
|
comp_model: impl rig::completion::CompletionModel + 'static,
|
|
847
819
|
_embedding_model: impl rig::embeddings::EmbeddingModel + 'static,
|
|
@@ -897,10 +869,10 @@ async fn extract_summary_information(
|
|
|
897
869
|
.filter(|x| user_words2.contains(&x.to_lowercase()))
|
|
898
870
|
.collect();
|
|
899
871
|
|
|
900
|
-
let mut summary_data_check: Option<
|
|
872
|
+
let mut summary_data_check: Option<TrainTestData> = None;
|
|
901
873
|
for chart in ai_json.charts.clone() {
|
|
902
|
-
if
|
|
903
|
-
summary_data_check = Some(
|
|
874
|
+
if chart.r#type == "Summary" {
|
|
875
|
+
summary_data_check = Some(chart);
|
|
904
876
|
break;
|
|
905
877
|
}
|
|
906
878
|
}
|
|
@@ -910,18 +882,23 @@ async fn extract_summary_information(
|
|
|
910
882
|
let mut training_data: String = String::from("");
|
|
911
883
|
let mut train_iter = 0;
|
|
912
884
|
for ques_ans in summary_data.TrainingData {
|
|
913
|
-
|
|
914
|
-
|
|
915
|
-
|
|
916
|
-
|
|
917
|
-
|
|
918
|
-
|
|
919
|
-
|
|
920
|
-
|
|
921
|
-
|
|
922
|
-
|
|
923
|
-
|
|
924
|
-
|
|
885
|
+
match ques_ans.answer {
|
|
886
|
+
AnswerFormat::summary_type(sum) => {
|
|
887
|
+
let summary_answer: SummaryType = sum;
|
|
888
|
+
train_iter += 1;
|
|
889
|
+
training_data += "Example question";
|
|
890
|
+
training_data += &train_iter.to_string();
|
|
891
|
+
training_data += &":";
|
|
892
|
+
training_data += &ques_ans.question;
|
|
893
|
+
training_data += &" ";
|
|
894
|
+
training_data += "Example answer";
|
|
895
|
+
training_data += &train_iter.to_string();
|
|
896
|
+
training_data += &":";
|
|
897
|
+
training_data += &serde_json::to_string(&summary_answer).unwrap();
|
|
898
|
+
training_data += &"\n";
|
|
899
|
+
}
|
|
900
|
+
AnswerFormat::DE_type(_) => panic!("DE type not valid for summary"),
|
|
901
|
+
}
|
|
925
902
|
}
|
|
926
903
|
|
|
927
904
|
let system_prompt: String = String::from(
|
|
@@ -990,7 +967,7 @@ fn get_summary_string() -> String {
|
|
|
990
967
|
//const geneExpression: &str = &"geneExpression";
|
|
991
968
|
|
|
992
969
|
#[derive(PartialEq, Debug, Clone, schemars::JsonSchema, serde::Serialize, serde::Deserialize)]
|
|
993
|
-
struct SummaryType {
|
|
970
|
+
pub struct SummaryType {
|
|
994
971
|
// Serde uses this for deserialization.
|
|
995
972
|
#[serde(default = "get_summary_string")]
|
|
996
973
|
// Schemars uses this for schema generation.
|