@sjcrh/proteinpaint-rust 2.169.0 → 2.170.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/src/test_ai.rs CHANGED
@@ -5,6 +5,7 @@ fn main() {}
 
 #[cfg(test)]
 mod tests {
+    use crate::aichatbot::{AiJsonFormat, AnswerFormat, SummaryType, llm_backend, run_pipeline};
     use serde_json;
     use std::fs::{self};
     use std::path::Path;
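The new import consolidates names that the tests previously reached through repeated `super::super::` paths. For orientation only, a minimal sketch of the crate layout this implies; the module and item names come from the diff, the bodies are placeholders, not the real definitions:

    // Sketch of the presumed crate root; with a module like this, the test's
    // single `use crate::aichatbot::{...}` replaces the `super::super::` paths.
    pub mod aichatbot {
        pub struct AiJsonFormat;   // deserialized test-fixture configuration
        pub struct SummaryType;    // parsed "summary" answer from the LLM
        pub enum AnswerFormat {}   // summary/DE variants (sketched after a later hunk)
        #[allow(non_camel_case_types)]
        pub enum llm_backend {}    // Ollama/SJ selector (sketched after a later hunk)
        pub async fn run_pipeline() { /* question -> embeddings -> LLM -> JSON */ }
    }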
@@ -62,13 +63,13 @@ mod tests {
         // Read the file
         let ai_data = fs::read_to_string(ai_json_file).unwrap();
         // Parse the JSON data
-        let ai_json: super::super::AiJsonFormat =
+        let ai_json: AiJsonFormat =
             serde_json::from_str(&ai_data).expect("AI JSON file does not have the correct format");
         //println!("ai_json:{:?}", ai_json);
         let genedb = String::from(&serverconfig.tpmasterdir) + &"/" + &ai_json.genedb;
         let dataset_db = String::from(&serverconfig.tpmasterdir) + &"/" + &ai_json.db;
         let llm_backend_name = &serverconfig.llm_backend;
-        let llm_backend_type: super::super::llm_backend;
+        let llm_backend_type: llm_backend;
 
         if llm_backend_name != "ollama" && llm_backend_name != "SJ" {
             panic!(
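Untouched by this change, the surrounding context builds database paths by string concatenation (`String::from(&serverconfig.tpmasterdir) + &"/" + &ai_json.genedb`). A self-contained sketch of the `Path::join` alternative, with placeholder values standing in for `serverconfig` and the fixture:

    use std::path::PathBuf;

    fn main() {
        // Placeholder values; the real ones come from serverconfig and the AI JSON file.
        let tpmasterdir = String::from("/path/to/tp");
        let genedb = "subdir/gene.db";
        // PathBuf::join avoids manual "/" concatenation and handles separators.
        let joined: PathBuf = PathBuf::from(&tpmasterdir).join(genedb);
        assert_eq!(joined, PathBuf::from("/path/to/tp/subdir/gene.db"));
    }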
@@ -78,7 +79,7 @@ mod tests {
             let ollama_host = &serverconfig.ollama_apilink;
             let ollama_embedding_model_name = &serverconfig.ollama_embedding_model_name;
             let ollama_comp_model_name = &serverconfig.ollama_comp_model_name;
-            llm_backend_type = super::super::llm_backend::Ollama();
+            llm_backend_type = llm_backend::Ollama();
             let ollama_client = super::super::ollama::Client::builder()
                 .base_url(ollama_host)
                 .build()
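`llm_backend::Ollama()` is a call-style constructor, so the type is presumably an enum with zero-field tuple variants. A sketch of what its definition might look like; only the type name and the `Ollama()`/`Sj()` constructors appear in the diff, and the derives are inferred from the `.clone()` calls later in the test:

    // Assumed definition, not taken from the crate source.
    #[allow(non_camel_case_types)]
    #[derive(Clone, Debug)]
    pub enum llm_backend {
        Ollama(), // zero-field tuple variants, constructed as llm_backend::Ollama()
        Sj(),
    }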
@@ -86,42 +87,45 @@ mod tests {
             let embedding_model = ollama_client.embedding_model(ollama_embedding_model_name);
             let comp_model = ollama_client.completion_model(ollama_comp_model_name);
             for chart in ai_json.charts.clone() {
-                match chart {
-                    super::super::Charts::Summary(testdata) => {
-                        for ques_ans in testdata.TestData {
-                            let user_input = ques_ans.question;
-                            let llm_output = super::super::run_pipeline(
-                                &user_input,
-                                comp_model.clone(),
-                                embedding_model.clone(),
-                                llm_backend_type.clone(),
-                                temperature,
-                                max_new_tokens,
-                                top_p,
-                                &dataset_db,
-                                &genedb,
-                                &ai_json,
-                                &airoute,
-                                testing,
-                            )
-                            .await;
-                            let llm_json_value: super::super::SummaryType = serde_json::from_str(&llm_output.unwrap()).expect("Did not get a valid JSON of type {action: summary, summaryterms:[{clinical: term1}, {geneExpression: gene}], filter:[{term: term1, value: value1}]} from the LLM");
-                            let sum: super::super::SummaryType = ques_ans.answer;
-                            //println!("expected answer:{:?}", &sum);
-                            assert_eq!(
-                                llm_json_value.sort_summarytype_struct(),
-                                sum.sort_summarytype_struct()
-                            );
+                if chart.r#type == "Summary" {
+                    for ques_ans in chart.TestData {
+                        let user_input = ques_ans.question;
+                        let llm_output = run_pipeline(
+                            &user_input,
+                            comp_model.clone(),
+                            embedding_model.clone(),
+                            llm_backend_type.clone(),
+                            temperature,
+                            max_new_tokens,
+                            top_p,
+                            &dataset_db,
+                            &genedb,
+                            &ai_json,
+                            &airoute,
+                            testing,
+                        )
+                        .await;
+                        let llm_json_value: SummaryType = serde_json::from_str(&llm_output.unwrap()).expect("Did not get a valid JSON of type {action: summary, summaryterms:[{clinical: term1}, {geneExpression: gene}], filter:[{term: term1, value: value1}]} from the LLM");
+                        match ques_ans.answer {
+                            AnswerFormat::summary_type(sum) => {
+                                //println!("expected answer:{:?}", &sum);
+                                assert_eq!(
+                                    llm_json_value.sort_summarytype_struct(),
+                                    sum.sort_summarytype_struct()
+                                );
+                            }
+                            AnswerFormat::DE_type(_) => {
+                                panic!("DE type not valid for summary")
+                            }
                         }
                     }
-                    super::super::Charts::DE(_testdata) => {} // To do
                 }
             }
         } else if *llm_backend_name == "SJ".to_string() {
             let sjprovider_host = &serverconfig.sj_apilink;
             let sj_embedding_model_name = &serverconfig.sj_embedding_model_name;
             let sj_comp_model_name = &serverconfig.sj_comp_model_name;
-            llm_backend_type = super::super::llm_backend::Sj();
+            llm_backend_type = llm_backend::Sj();
             let sj_client = super::super::sjprovider::Client::builder()
                 .base_url(sjprovider_host)
                 .build()
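This hunk is the core of the change: the closed `Charts` enum (with a stubbed `DE` arm) is replaced by a chart record tagged with a `r#type` string, and the expected answer moves behind an `AnswerFormat` enum, so one `TestData` list can carry either a summary answer or a differential-expression answer. A sketch of the data model the new code implies; the field and variant names come from the diff, while the `QuesAns` name, the derives, and `#[serde(untagged)]` are guesses:

    // Hypothetical data model; requires the serde ("derive" feature) and
    // serde_json crates. The real definitions live in crate::aichatbot.
    use serde::Deserialize;

    #[allow(non_snake_case)]
    #[derive(Clone, Deserialize)]
    pub struct Chart {
        pub r#type: String,        // "Summary", "DE", ... (replaces the Charts enum tag)
        pub TestData: Vec<QuesAns>,
    }

    #[derive(Clone, Deserialize)]
    pub struct QuesAns {           // name invented for this sketch
        pub question: String,
        pub answer: AnswerFormat,  // previously typed directly as SummaryType
    }

    #[derive(Clone, Deserialize)]
    pub struct SummaryType {}      // real fields elided; defined in the crate

    #[allow(non_camel_case_types)]
    #[derive(Clone, Deserialize)]
    #[serde(untagged)]             // assumption: variant chosen by JSON shape
    pub enum AnswerFormat {
        summary_type(SummaryType),
        DE_type(serde_json::Value),
    }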
@@ -130,50 +134,53 @@ mod tests {
             let comp_model = sj_client.completion_model(sj_comp_model_name);
 
             for chart in ai_json.charts.clone() {
-                match chart {
-                    super::super::Charts::Summary(testdata) => {
-                        for ques_ans in testdata.TestData {
-                            let user_input = ques_ans.question;
-                            if user_input.len() > 0 {
-                                let llm_output = super::super::run_pipeline(
-                                    &user_input,
-                                    comp_model.clone(),
-                                    embedding_model.clone(),
-                                    llm_backend_type.clone(),
-                                    temperature,
-                                    max_new_tokens,
-                                    top_p,
-                                    &dataset_db,
-                                    &genedb,
-                                    &ai_json,
-                                    &airoute,
-                                    testing,
-                                )
-                                .await;
-                                //println!("user_input:{}", user_input);
-                                //println!("llm_answer:{:?}", llm_output);
-                                //println!("expected answer:{:?}", &ques_ans.answer);
-                                let llm_json_value: super::super::SummaryType = serde_json::from_str(&llm_output.unwrap()).expect("Did not get a valid JSON of type {action: summary, summaryterms:[{clinical: term1}, {geneExpression: gene}], filter:[{term: term1, value: value1}]} from the LLM");
-                                //println!(
-                                //    "llm_answer:{:?}",
-                                //    llm_json_value.clone().sort_summarytype_struct()
-                                //);
-                                //println!(
-                                //    "expected answer:{:?}",
-                                //    &expected_json_value.clone().sort_summarytype_struct()
-                                //);
-                                let sum: super::super::SummaryType = ques_ans.answer;
-                                //println!("expected answer:{:?}", &sum);
-                                assert_eq!(
-                                    llm_json_value.sort_summarytype_struct(),
-                                    sum.sort_summarytype_struct()
-                                );
-                            } else {
-                                panic!("The user input is empty");
+                if chart.r#type == "Summary" {
+                    for ques_ans in chart.TestData {
+                        let user_input = ques_ans.question;
+                        if user_input.len() > 0 {
+                            let llm_output = run_pipeline(
+                                &user_input,
+                                comp_model.clone(),
+                                embedding_model.clone(),
+                                llm_backend_type.clone(),
+                                temperature,
+                                max_new_tokens,
+                                top_p,
+                                &dataset_db,
+                                &genedb,
+                                &ai_json,
+                                &airoute,
+                                testing,
+                            )
+                            .await;
+                            //println!("user_input:{}", user_input);
+                            //println!("llm_answer:{:?}", llm_output);
+                            //println!("expected answer:{:?}", &ques_ans.answer);
+                            let llm_json_value: SummaryType = serde_json::from_str(&llm_output.unwrap()).expect("Did not get a valid JSON of type {action: summary, summaryterms:[{clinical: term1}, {geneExpression: gene}], filter:[{term: term1, value: value1}]} from the LLM");
+                            //println!(
+                            //    "llm_answer:{:?}",
+                            //    llm_json_value.clone().sort_summarytype_struct()
+                            //);
+                            //println!(
+                            //    "expected answer:{:?}",
+                            //    &expected_json_value.clone().sort_summarytype_struct()
+                            //);
+                            match ques_ans.answer {
+                                AnswerFormat::summary_type(sum) => {
+                                    //println!("expected answer:{:?}", &sum);
+                                    assert_eq!(
+                                        llm_json_value.sort_summarytype_struct(),
+                                        sum.sort_summarytype_struct()
+                                    );
+                                }
+                                AnswerFormat::DE_type(_) => {
+                                    panic!("DE type not valid for summary")
+                                }
                             }
+                        } else {
+                            panic!("The user input is empty");
                         }
                     }
-                    super::super::Charts::DE(_testdata) => {} // To do
                 }
             }
         }
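In both backend branches the assertion compares `sort_summarytype_struct()` outputs rather than the raw structs, so the term ordering in the LLM's JSON cannot cause spurious failures. The real method is defined elsewhere in the crate and is not shown in this diff; a generic, self-contained sketch of the sort-before-compare pattern:

    // Illustrative only: canonicalize list order, then assert equality.
    #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
    struct SummaryLike {
        terms: Vec<String>,
        filters: Vec<(String, String)>,
    }

    impl SummaryLike {
        // Return a copy with every list in canonical order.
        fn sorted(mut self) -> Self {
            self.terms.sort();
            self.filters.sort();
            self
        }
    }

    fn main() {
        let got = SummaryLike { terms: vec!["sex".into(), "age".into()], filters: vec![] };
        let want = SummaryLike { terms: vec!["age".into(), "sex".into()], filters: vec![] };
        // Equal up to ordering, so the assertion passes.
        assert_eq!(got.sorted(), want.sorted());
    }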