@sjcrh/proteinpaint-rust 2.169.0 → 2.170.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/Cargo.toml CHANGED
@@ -123,10 +123,14 @@ path="src/cerno.rs"
123
123
  name="readH5"
124
124
  path="src/readH5.rs"
125
125
 
126
- [[bin]]
127
- name="aichatbot"
128
- path="src/aichatbot.rs"
129
-
130
126
  [[bin]]
131
127
  name="manhattan_plot"
132
128
  path="src/manhattan_plot.rs"
129
+
130
+ [[bin]]
131
+ name="query_classification"
132
+ path="src/query_classification.rs"
133
+
134
+ [[bin]]
135
+ name="summary_agent"
136
+ path="src/summary_agent.rs"
package/package.json CHANGED
@@ -1,5 +1,5 @@
1
1
  {
2
- "version": "2.169.0",
2
+ "version": "2.170.0",
3
3
  "name": "@sjcrh/proteinpaint-rust",
4
4
  "type": "module",
5
5
  "description": "Rust-based utilities for proteinpaint",
package/src/aichatbot.rs CHANGED
@@ -1,7 +1,7 @@
1
1
  // Syntax: cd .. && cargo build --release && time cat ~/sjpp/test.txt | target/release/aichatbot
2
2
  #![allow(non_snake_case)]
3
- use anyhow::Result;
4
- use json::JsonValue;
3
+ //use anyhow::Result;
4
+ //use json::JsonValue;
5
5
  use r2d2_sqlite::SqliteConnectionManager;
6
6
  use rig::agent::AgentBuilder;
7
7
  use rig::completion::Prompt;
@@ -11,69 +11,52 @@ use schemars::JsonSchema;
11
11
  use serde_json::{Map, Value, json};
12
12
  use std::collections::HashMap;
13
13
  use std::fs;
14
- use std::io;
15
- use std::path::Path;
16
- mod ollama; // Importing custom rig module for invoking ollama server
17
- mod sjprovider; // Importing custom rig module for invoking SJ GPU server
18
-
19
- mod test_ai; // Test examples for AI chatbot
20
14
 
21
15
  // Struct for intaking data from dataset json
22
16
  #[derive(PartialEq, Debug, Clone, schemars::JsonSchema, serde::Serialize, serde::Deserialize)]
23
17
  pub struct AiJsonFormat {
24
- hasGeneExpression: bool,
25
- db: String, // Dataset db
26
- genedb: String, // Gene db
27
- charts: Vec<Charts>,
28
- }
29
-
30
- #[derive(PartialEq, Debug, Clone, schemars::JsonSchema, serde::Serialize, serde::Deserialize)]
31
- enum Charts {
32
- // More chart types will be added here later
33
- Summary(TrainTestDataSummary),
34
- DE(TrainTestDataDE),
18
+ pub hasGeneExpression: bool,
19
+ pub db: String, // Dataset db
20
+ pub genedb: String, // Gene db
21
+ pub charts: Vec<TrainTestData>,
35
22
  }
36
23
 
37
24
  #[derive(PartialEq, Debug, Clone, schemars::JsonSchema, serde::Serialize, serde::Deserialize)]
38
- struct TrainTestDataSummary {
39
- SystemPrompt: String,
40
- TrainingData: Vec<QuestionAnswerSummary>,
41
- TestData: Vec<QuestionAnswerSummary>,
25
+ pub struct TrainTestData {
26
+ pub r#type: String,
27
+ pub SystemPrompt: String,
28
+ pub TrainingData: Vec<QuestionAnswer>,
29
+ pub TestData: Vec<QuestionAnswer>,
42
30
  }
43
31
 
44
32
  #[derive(PartialEq, Debug, Clone, schemars::JsonSchema, serde::Serialize, serde::Deserialize)]
45
- struct QuestionAnswerSummary {
46
- question: String,
47
- answer: SummaryType,
33
+ pub struct QuestionAnswer {
34
+ pub question: String,
35
+ pub answer: AnswerFormat,
48
36
  }
49
37
 
50
38
  #[derive(PartialEq, Debug, Clone, schemars::JsonSchema, serde::Serialize, serde::Deserialize)]
51
- struct TrainTestDataDE {
52
- SystemPrompt: String,
53
- TrainingData: Vec<QuestionAnswerDE>,
54
- TestData: Vec<QuestionAnswerDE>,
55
- }
56
-
57
- #[derive(PartialEq, Debug, Clone, schemars::JsonSchema, serde::Serialize, serde::Deserialize)]
58
- struct QuestionAnswerDE {
59
- question: String,
60
- answer: DEType,
39
+ pub enum AnswerFormat {
40
+ #[allow(non_camel_case_types)]
41
+ summary_type(SummaryType),
42
+ #[allow(non_camel_case_types)]
43
+ DE_type(DEType),
61
44
  }
62
45
 
63
46
  #[derive(PartialEq, Debug, Clone, schemars::JsonSchema, serde::Serialize, serde::Deserialize)]
64
- struct DEType {
47
+ pub struct DEType {
65
48
  action: String,
66
49
  DE_output: DETerms,
67
50
  }
68
51
 
69
52
  #[derive(PartialEq, Debug, Clone, schemars::JsonSchema, serde::Serialize, serde::Deserialize)]
70
- struct DETerms {
53
+ pub struct DETerms {
71
54
  group1: GroupType,
72
55
  group2: GroupType,
73
56
  }
74
57
 
75
58
  #[derive(PartialEq, Debug, Clone, schemars::JsonSchema, serde::Serialize, serde::Deserialize)]
76
- struct GroupType {
59
+ pub struct GroupType {
77
60
  name: String,
78
61
  }
79
62
 
@@ -86,204 +69,188 @@ pub enum llm_backend {
86
69
 
87
70
  #[derive(Debug, JsonSchema)]
88
71
  #[allow(dead_code)]
89
- struct OutputJson {
72
+ pub struct OutputJson {
90
73
  pub answer: String,
91
74
  }
92
75
 
93
- #[tokio::main]
94
- async fn main() -> Result<()> {
95
- let mut input = String::new();
96
- match io::stdin().read_line(&mut input) {
97
- // Accepting the piped input from nodejs (or command line from testing)
98
- Ok(_n) => {
99
- let input_json = json::parse(&input);
100
- match input_json {
101
- Ok(json_string) => {
102
- //println!("json_string:{}", json_string);
103
- let user_input_json: &JsonValue = &json_string["user_input"];
104
- let user_input: &str;
105
- match user_input_json.as_str() {
106
- Some(inp) => user_input = inp,
107
- None => panic!("user_input field is missing in input json"),
108
- }
109
-
110
- let dataset_db_json: &JsonValue = &json_string["dataset_db"];
111
- let dataset_db_str: &str;
112
- match dataset_db_json.as_str() {
113
- Some(inp) => dataset_db_str = inp,
114
- None => panic!("dataset_db field is missing in input json"),
115
- }
116
-
117
- let genedb_json: &JsonValue = &json_string["genedb"];
118
- let genedb_str: &str;
119
- match genedb_json.as_str() {
120
- Some(inp) => genedb_str = inp,
121
- None => panic!("genedb field is missing in input json"),
122
- }
123
-
124
- let aiRoute_json: &JsonValue = &json_string["aiRoute"];
125
- let aiRoute_str: &str;
126
- match aiRoute_json.as_str() {
127
- Some(inp) => aiRoute_str = inp,
128
- None => panic!("aiRoute field is missing in input json"),
129
- }
130
-
131
- if user_input.len() == 0 {
132
- panic!("The user input is empty");
133
- }
134
-
135
- let tpmasterdir_json: &JsonValue = &json_string["tpmasterdir"];
136
- let tpmasterdir: &str;
137
- match tpmasterdir_json.as_str() {
138
- Some(inp) => tpmasterdir = inp,
139
- None => panic!("tpmasterdir not found"),
140
- }
141
-
142
- let binpath_json: &JsonValue = &json_string["binpath"];
143
- let binpath: &str;
144
- match binpath_json.as_str() {
145
- Some(inp) => binpath = inp,
146
- None => panic!("binpath not found"),
147
- }
148
-
149
- let ai_json_file_json: &JsonValue = &json_string["aifiles"];
150
- let ai_json_file: String;
151
- match ai_json_file_json.as_str() {
152
- Some(inp) => ai_json_file = String::from(binpath) + &"/../../" + &inp,
153
- None => {
154
- panic!("ai json file not found")
155
- }
156
- }
157
-
158
- let ai_json_file = Path::new(&ai_json_file);
159
- let ai_json_file_path;
160
- let current_dir = std::env::current_dir().unwrap();
161
- match ai_json_file.canonicalize() {
162
- Ok(p) => ai_json_file_path = p,
163
- Err(_) => {
164
- panic!(
165
- "AI JSON file path not found:{:?}, current directory:{:?}",
166
- ai_json_file, current_dir
167
- )
168
- }
169
- }
170
-
171
- // Read the file
172
- let ai_data = fs::read_to_string(ai_json_file_path).unwrap();
173
-
174
- // Parse the JSON data
175
- let ai_json: AiJsonFormat =
176
- serde_json::from_str(&ai_data).expect("AI JSON file does not have the correct format");
177
-
178
- let genedb = String::from(tpmasterdir) + &"/" + &genedb_str;
179
- let dataset_db = String::from(tpmasterdir) + &"/" + &dataset_db_str;
180
- let airoute = String::from(binpath) + &"/../../" + &aiRoute_str;
181
-
182
- let apilink_json: &JsonValue = &json_string["apilink"];
183
- let apilink: &str;
184
- match apilink_json.as_str() {
185
- Some(inp) => apilink = inp,
186
- None => panic!("apilink field is missing in input json"),
187
- }
188
-
189
- let comp_model_name_json: &JsonValue = &json_string["comp_model_name"];
190
- let comp_model_name: &str;
191
- match comp_model_name_json.as_str() {
192
- Some(inp) => comp_model_name = inp,
193
- None => panic!("comp_model_name field is missing in input json"),
194
- }
195
-
196
- let embedding_model_name_json: &JsonValue = &json_string["embedding_model_name"];
197
- let embedding_model_name: &str;
198
- match embedding_model_name_json.as_str() {
199
- Some(inp) => embedding_model_name = inp,
200
- None => panic!("embedding_model_name field is missing in input json"),
201
- }
202
-
203
- let llm_backend_name_json: &JsonValue = &json_string["llm_backend_name"];
204
- let llm_backend_name: &str;
205
- match llm_backend_name_json.as_str() {
206
- Some(inp) => llm_backend_name = inp,
207
- None => panic!("llm_backend_name field is missing in input json"),
208
- }
209
-
210
- let llm_backend_type: llm_backend;
211
- let mut final_output: Option<String> = None;
212
- let temperature: f64 = 0.01;
213
- let max_new_tokens: usize = 512;
214
- let top_p: f32 = 0.95;
215
- let testing = false; // This variable is always false in production, this is true in test_ai.rs for testing code
216
- if llm_backend_name != "ollama" && llm_backend_name != "SJ" {
217
- panic!(
218
- "This code currently supports only Ollama and SJ provider. llm_backend_name must be \"ollama\" or \"SJ\""
219
- );
220
- } else if llm_backend_name == "ollama".to_string() {
221
- llm_backend_type = llm_backend::Ollama();
222
- // Initialize Ollama client
223
- let ollama_client = ollama::Client::builder()
224
- .base_url(apilink)
225
- .build()
226
- .expect("Ollama server not found");
227
- let embedding_model = ollama_client.embedding_model(embedding_model_name);
228
- let comp_model = ollama_client.completion_model(comp_model_name);
229
- final_output = run_pipeline(
230
- user_input,
231
- comp_model,
232
- embedding_model,
233
- llm_backend_type,
234
- temperature,
235
- max_new_tokens,
236
- top_p,
237
- &dataset_db,
238
- &genedb,
239
- &ai_json,
240
- &airoute,
241
- testing,
242
- )
243
- .await;
244
- } else if llm_backend_name == "SJ".to_string() {
245
- llm_backend_type = llm_backend::Sj();
246
- // Initialize Sj provider client
247
- let sj_client = sjprovider::Client::builder()
248
- .base_url(apilink)
249
- .build()
250
- .expect("SJ server not found");
251
- let embedding_model = sj_client.embedding_model(embedding_model_name);
252
- let comp_model = sj_client.completion_model(comp_model_name);
253
- final_output = run_pipeline(
254
- user_input,
255
- comp_model,
256
- embedding_model,
257
- llm_backend_type,
258
- temperature,
259
- max_new_tokens,
260
- top_p,
261
- &dataset_db,
262
- &genedb,
263
- &ai_json,
264
- &airoute,
265
- testing,
266
- )
267
- .await;
268
- }
269
-
270
- match final_output {
271
- Some(fin_out) => {
272
- println!("final_output:{:?}", fin_out.replace("\\", ""));
273
- }
274
- None => {
275
- println!("final_output:{{\"{}\":\"{}\"}}", "action", "unknown");
276
- }
277
- }
278
- }
279
- Err(error) => println!("Incorrect json:{}", error),
280
- }
281
- }
282
- Err(error) => println!("Piping error: {}", error),
283
- }
284
- Ok(())
285
- }
76
+ //#[tokio::main]
77
+ //async fn main() -> Result<()> {
78
+ // let mut input = String::new();
79
+ // match io::stdin().read_line(&mut input) {
80
+ // // Accepting the piped input from nodejs (or command line from testing)
81
+ // Ok(_n) => {
82
+ // let input_json = json::parse(&input);
83
+ // match input_json {
84
+ // Ok(json_string) => {
85
+ // //println!("json_string:{}", json_string);
86
+ // let user_input_json: &JsonValue = &json_string["user_input"];
87
+ // let user_input: &str;
88
+ // match user_input_json.as_str() {
89
+ // Some(inp) => user_input = inp,
90
+ // None => panic!("user_input field is missing in input json"),
91
+ // }
92
+ // let dataset_db_json: &JsonValue = &json_string["dataset_db"];
93
+ // let dataset_db_str: &str;
94
+ // match dataset_db_json.as_str() {
95
+ // Some(inp) => dataset_db_str = inp,
96
+ // None => panic!("dataset_db field is missing in input json"),
97
+ // }
98
+ // let genedb_json: &JsonValue = &json_string["genedb"];
99
+ // let genedb_str: &str;
100
+ // match genedb_json.as_str() {
101
+ // Some(inp) => genedb_str = inp,
102
+ // None => panic!("genedb field is missing in input json"),
103
+ // }
104
+ // let aiRoute_json: &JsonValue = &json_string["aiRoute"];
105
+ // let aiRoute_str: &str;
106
+ // match aiRoute_json.as_str() {
107
+ // Some(inp) => aiRoute_str = inp,
108
+ // None => panic!("aiRoute field is missing in input json"),
109
+ // }
110
+ // if user_input.len() == 0 {
111
+ // panic!("The user input is empty");
112
+ // }
113
+ // let tpmasterdir_json: &JsonValue = &json_string["tpmasterdir"];
114
+ // let tpmasterdir: &str;
115
+ // match tpmasterdir_json.as_str() {
116
+ // Some(inp) => tpmasterdir = inp,
117
+ // None => panic!("tpmasterdir not found"),
118
+ // }
119
+ // let binpath_json: &JsonValue = &json_string["binpath"];
120
+ // let binpath: &str;
121
+ // match binpath_json.as_str() {
122
+ // Some(inp) => binpath = inp,
123
+ // None => panic!("binpath not found"),
124
+ // }
125
+ // let ai_json_file_json: &JsonValue = &json_string["aifiles"];
126
+ // let ai_json_file: String;
127
+ // match ai_json_file_json.as_str() {
128
+ // Some(inp) => ai_json_file = String::from(binpath) + &"/../../" + &inp,
129
+ // None => {
130
+ // panic!("ai json file not found")
131
+ // }
132
+ // }
133
+ // let ai_json_file = Path::new(&ai_json_file);
134
+ // let ai_json_file_path;
135
+ // let current_dir = std::env::current_dir().unwrap();
136
+ // match ai_json_file.canonicalize() {
137
+ // Ok(p) => ai_json_file_path = p,
138
+ // Err(_) => {
139
+ // panic!(
140
+ // "AI JSON file path not found:{:?}, current directory:{:?}",
141
+ // ai_json_file, current_dir
142
+ // )
143
+ // }
144
+ // }
145
+ // // Read the file
146
+ // let ai_data = fs::read_to_string(ai_json_file_path).unwrap();
147
+ // // Parse the JSON data
148
+ // let ai_json: AiJsonFormat =
149
+ // serde_json::from_str(&ai_data).expect("AI JSON file does not have the correct format");
150
+ // let genedb = String::from(tpmasterdir) + &"/" + &genedb_str;
151
+ // let dataset_db = String::from(tpmasterdir) + &"/" + &dataset_db_str;
152
+ // let airoute = String::from(binpath) + &"/../../" + &aiRoute_str;
153
+ // let apilink_json: &JsonValue = &json_string["apilink"];
154
+ // let apilink: &str;
155
+ // match apilink_json.as_str() {
156
+ // Some(inp) => apilink = inp,
157
+ // None => panic!("apilink field is missing in input json"),
158
+ // }
159
+ // let comp_model_name_json: &JsonValue = &json_string["comp_model_name"];
160
+ // let comp_model_name: &str;
161
+ // match comp_model_name_json.as_str() {
162
+ // Some(inp) => comp_model_name = inp,
163
+ // None => panic!("comp_model_name field is missing in input json"),
164
+ // }
165
+ // let embedding_model_name_json: &JsonValue = &json_string["embedding_model_name"];
166
+ // let embedding_model_name: &str;
167
+ // match embedding_model_name_json.as_str() {
168
+ // Some(inp) => embedding_model_name = inp,
169
+ // None => panic!("embedding_model_name field is missing in input json"),
170
+ // }
171
+ // let llm_backend_name_json: &JsonValue = &json_string["llm_backend_name"];
172
+ // let llm_backend_name: &str;
173
+ // match llm_backend_name_json.as_str() {
174
+ // Some(inp) => llm_backend_name = inp,
175
+ // None => panic!("llm_backend_name field is missing in input json"),
176
+ // }
177
+ // let llm_backend_type: llm_backend;
178
+ // let mut final_output: Option<String> = None;
179
+ // let temperature: f64 = 0.01;
180
+ // let max_new_tokens: usize = 512;
181
+ // let top_p: f32 = 0.95;
182
+ // let testing = false; // This variable is always false in production, this is true in test_ai.rs for testing code
183
+ // if llm_backend_name != "ollama" && llm_backend_name != "SJ" {
184
+ // panic!(
185
+ // "This code currently supports only Ollama and SJ provider. llm_backend_name must be \"ollama\" or \"SJ\""
186
+ // );
187
+ // } else if llm_backend_name == "ollama".to_string() {
188
+ // llm_backend_type = llm_backend::Ollama();
189
+ // // Initialize Ollama client
190
+ // let ollama_client = ollama::Client::builder()
191
+ // .base_url(apilink)
192
+ // .build()
193
+ // .expect("Ollama server not found");
194
+ // let embedding_model = ollama_client.embedding_model(embedding_model_name);
195
+ // let comp_model = ollama_client.completion_model(comp_model_name);
196
+ // final_output = run_pipeline(
197
+ // user_input,
198
+ // comp_model,
199
+ // embedding_model,
200
+ // llm_backend_type,
201
+ // temperature,
202
+ // max_new_tokens,
203
+ // top_p,
204
+ // &dataset_db,
205
+ // &genedb,
206
+ // &ai_json,
207
+ // &airoute,
208
+ // testing,
209
+ // )
210
+ // .await;
211
+ // } else if llm_backend_name == "SJ".to_string() {
212
+ // llm_backend_type = llm_backend::Sj();
213
+ // // Initialize Sj provider client
214
+ // let sj_client = sjprovider::Client::builder()
215
+ // .base_url(apilink)
216
+ // .build()
217
+ // .expect("SJ server not found");
218
+ // let embedding_model = sj_client.embedding_model(embedding_model_name);
219
+ // let comp_model = sj_client.completion_model(comp_model_name);
220
+ // final_output = run_pipeline(
221
+ // user_input,
222
+ // comp_model,
223
+ // embedding_model,
224
+ // llm_backend_type,
225
+ // temperature,
226
+ // max_new_tokens,
227
+ // top_p,
228
+ // &dataset_db,
229
+ // &genedb,
230
+ // &ai_json,
231
+ // &airoute,
232
+ // testing,
233
+ // )
234
+ // .await;
235
+ // }
236
+ // match final_output {
237
+ // Some(fin_out) => {
238
+ // println!("final_output:{:?}", fin_out.replace("\\", ""));
239
+ // }
240
+ // None => {
241
+ // println!("final_output:{{\"{}\":\"{}\"}}", "action", "unknown");
242
+ // }
243
+ // }
244
+ // }
245
+ // Err(error) => println!("Incorrect json:{}", error),
246
+ // }
247
+ // }
248
+ // Err(error) => println!("Piping error: {}", error),
249
+ // }
250
+ // Ok(())
251
+ //}
286
252
 
253
+ #[allow(dead_code)]
287
254
  pub async fn run_pipeline(
288
255
  user_input: &str,
289
256
  comp_model: impl rig::completion::CompletionModel + 'static,
@@ -429,7 +396,8 @@ pub async fn run_pipeline(
429
396
  Some(final_output)
430
397
  }
431
398
 
432
- async fn classify_query_by_dataset_type(
399
+ #[allow(dead_code)]
400
+ pub async fn classify_query_by_dataset_type(
433
401
  user_input: &str,
434
402
  comp_model: impl rig::completion::CompletionModel + 'static,
435
403
  _embedding_model: impl rig::embeddings::EmbeddingModel + 'static,
@@ -449,9 +417,12 @@ async fn classify_query_by_dataset_type(
449
417
  let mut contents = String::from("");
450
418
 
451
419
  if let Some(object) = ai_json.as_object() {
452
- for (_key, value) in object {
453
- contents += &value.as_str().unwrap();
454
- contents += "---"; // Adding delimiter
420
+ contents = object["general"].to_string();
421
+ for (key, value) in object {
422
+ if key != "general" {
423
+ contents += &value.as_str().unwrap();
424
+ contents += "---"; // Adding delimiter
425
+ }
455
426
  }
456
427
  }
457
428
 
@@ -568,6 +539,7 @@ struct DEOutput {
568
539
  group2: Group,
569
540
  }
570
541
 
542
+ #[allow(dead_code)]
571
543
  #[allow(non_snake_case)]
572
544
  async fn extract_DE_search_terms_from_query(
573
545
  user_input: &str,
@@ -841,7 +813,7 @@ async fn parse_dataset_db(db: &str) -> (Vec<String>, Vec<DbRows>) {
841
813
  (rag_docs, db_vec)
842
814
  }
843
815
 
844
- async fn extract_summary_information(
816
+ pub async fn extract_summary_information(
845
817
  user_input: &str,
846
818
  comp_model: impl rig::completion::CompletionModel + 'static,
847
819
  _embedding_model: impl rig::embeddings::EmbeddingModel + 'static,
@@ -897,10 +869,10 @@ async fn extract_summary_information(
897
869
  .filter(|x| user_words2.contains(&x.to_lowercase()))
898
870
  .collect();
899
871
 
900
- let mut summary_data_check: Option<TrainTestDataSummary> = None;
872
+ let mut summary_data_check: Option<TrainTestData> = None;
901
873
  for chart in ai_json.charts.clone() {
902
- if let Charts::Summary(traindata) = chart {
903
- summary_data_check = Some(traindata);
874
+ if chart.r#type == "Summary" {
875
+ summary_data_check = Some(chart);
904
876
  break;
905
877
  }
906
878
  }
@@ -910,18 +882,23 @@ async fn extract_summary_information(
910
882
  let mut training_data: String = String::from("");
911
883
  let mut train_iter = 0;
912
884
  for ques_ans in summary_data.TrainingData {
913
- let summary_answer: SummaryType = ques_ans.answer;
914
- train_iter += 1;
915
- training_data += "Example question";
916
- training_data += &train_iter.to_string();
917
- training_data += &":";
918
- training_data += &ques_ans.question;
919
- training_data += &" ";
920
- training_data += "Example answer";
921
- training_data += &train_iter.to_string();
922
- training_data += &":";
923
- training_data += &serde_json::to_string(&summary_answer).unwrap();
924
- training_data += &"\n";
885
+ match ques_ans.answer {
886
+ AnswerFormat::summary_type(sum) => {
887
+ let summary_answer: SummaryType = sum;
888
+ train_iter += 1;
889
+ training_data += "Example question";
890
+ training_data += &train_iter.to_string();
891
+ training_data += &":";
892
+ training_data += &ques_ans.question;
893
+ training_data += &" ";
894
+ training_data += "Example answer";
895
+ training_data += &train_iter.to_string();
896
+ training_data += &":";
897
+ training_data += &serde_json::to_string(&summary_answer).unwrap();
898
+ training_data += &"\n";
899
+ }
900
+ AnswerFormat::DE_type(_) => panic!("DE type not valid for summary"),
901
+ }
925
902
  }
926
903
 
927
904
  let system_prompt: String = String::from(
@@ -990,7 +967,7 @@ fn get_summary_string() -> String {
990
967
  //const geneExpression: &str = &"geneExpression";
991
968
 
992
969
  #[derive(PartialEq, Debug, Clone, schemars::JsonSchema, serde::Serialize, serde::Deserialize)]
993
- struct SummaryType {
970
+ pub struct SummaryType {
994
971
  // Serde uses this for deserialization.
995
972
  #[serde(default = "get_summary_string")]
996
973
  // Schemars uses this for schema generation.