@sjcrh/proteinpaint-rust 2.148.1 → 2.150.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/src/sjprovider.rs CHANGED
@@ -73,6 +73,15 @@ impl Client {
     pub fn builder() -> ClientBuilder<'static> {
         ClientBuilder::new()
     }
+
+    pub fn completion_model(&self, model: &str) -> CompletionModel {
+        CompletionModel::new(self.clone(), model)
+    }
+
+    pub fn embedding_model(&self, model: &str) -> EmbeddingModel {
+        EmbeddingModel::new(self.clone(), model, 0, self.base_url.to_string())
+    }
+
     pub fn new() -> Self {
         Self::builder().build().expect("Myprovider client should build")
     }
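
The two constructors added here give the SJ client the same surface as the Ollama provider, so callers can obtain model handles directly from a configured client. A minimal usage sketch, assuming a built sjprovider::Client; the base URL and model names are placeholders, not values from the package:

    // Hypothetical usage; the URL and model names are placeholders.
    let client = sjprovider::Client::builder()
        .base_url("http://localhost:8000")
        .build()
        .expect("SJ server not found");
    // Both helpers clone the client, so each model handle is self-contained.
    let comp_model = client.completion_model("my-completion-model");
    let embedding_model = client.embedding_model("my-embedding-model");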
@@ -129,11 +138,16 @@ impl EmbeddingsClient for Client {
 
 impl VerifyClient for Client {
     async fn verify(&self) -> Result<(), VerifyError> {
-        let response = self.get("api/tags").expect("Failed to build request").send().await?;
+        let response = self
+            .get("api/tags")
+            .expect("Failed to build request")
+            .send()
+            .await
+            .unwrap();
         match response.status() {
             reqwest::StatusCode::OK => Ok(()),
             _ => {
-                response.error_for_status()?;
+                response.error_for_status().unwrap();
                 Ok(())
             }
         }
@@ -262,7 +276,7 @@ impl embeddings::EmbeddingModel for EmbeddingModel {
 
         if response.status().is_success() {
             //println!("response.json:{:?}", response.text().await?);
-            let json_data: Value = serde_json::from_str(&response.text().await?)?;
+            let json_data: Value = serde_json::from_str(&response.text().await.unwrap())?;
             let emb = json_data["outputs"].as_array().unwrap();
             //.unwrap_or(&vec![serde_json::Value::String(
             //    "No embeddings found in json output".to_string(),
@@ -481,12 +495,39 @@ impl CompletionModel {
             panic!("max_new_tokens and top_p not found!");
         };
 
+        let mut user_query = "";
+        let mut system_prompt = "";
+        for message in &full_history {
+            match message {
+                self::Message::User {
+                    content: text,
+                    images: _,
+                    name: _,
+                } => {
+                    //println!("User:{:?}", text);
+                    user_query = text;
+                }
+                self::Message::System {
+                    content: text,
+                    images: _,
+                    name: _,
+                } => {
+                    system_prompt = text;
+                    //println!("System:{:?}", text);
+                }
+                self::Message::Assistant { content: _, id: _ } => {}
+                self::Message::ToolResult { content: _, name: _ } => {}
+            }
+        }
+        let final_text = system_prompt.replace(&"{question}", &user_query);
+
+        //println!("final_text:{:?}", final_text);
         let mut request_payload = json!({
             "inputs":[
                 {
                     "model_name": self.model,
                     "inputs": {
-                        "text": full_history,
+                        "text": final_text,
                         "max_new_tokens": max_new_tokens,
                         "temperature": completion_request.temperature,
                         "top_p": top_p
@@ -612,7 +653,7 @@ impl completion::CompletionModel for CompletionModel {
             let chunk = match chunk_result {
                 Ok(c) => c,
                 Err(e) => {
-                    yield Err(CompletionError::from(e));
+                    yield Err(CompletionError::RequestError(e.into()));
                     break;
                 }
             };
@@ -797,7 +838,7 @@ impl ConvertMessage for Message {
                     images.push(data)
                 }
                 rig::message::UserContent::Document(rig::message::Document { data, .. }) => {
-                    texts.push(data)
+                    texts.push(data.to_string())
                 }
                 _ => {} // Audio not supported by Ollama
             }
@@ -993,7 +1034,7 @@ mod tests {
     #[tokio::test]
     #[ignore]
 
-    async fn test_myprovider_implementation() {
+    async fn test_sjprovider_implementation() {
         let user_input = "Generate DE plot for men with weight greater than 30lbs vs women less than 20lbs";
         let serverconfig_file_path = Path::new("../../serverconfig.json");
         let absolute_path = serverconfig_file_path.canonicalize().unwrap();
@@ -1118,17 +1159,17 @@ If a query does not match any of the fields described above, then return JSON wi
         });
 
         // Create RAG agent
-        let agent = AgentBuilder::new(comp_model).preamble("Generate classification for the user query into summary, dge, hierarchial, snv_indel, cnv, variant_calling, sv_fusion and none categories. Return output in JSON with ALWAYS a single word answer { \"answer\": \"dge\" }, that is 'summary' for summary plot, 'dge' for differential gene expression, 'hierarchial' for hierarchial clustering, 'snv_indel' for SNV/Indel, 'cnv' for CNV and 'sv_fusion' for SV/fusion, 'variant_calling' for variant calling, 'surivial' for survival data, 'none' for none of the previously described categories. The answer should always be in lower case").dynamic_context(top_k, vector_store.index(embedding_model)).additional_params(additional).temperature(temperature).build();
+        let agent = AgentBuilder::new(comp_model).preamble("Generate classification for the user query into summary, dge, hierarchial, snv_indel, cnv, variant_calling, sv_fusion and none categories. Return output in JSON with ALWAYS a single word answer { \"answer\": \"dge\" }, that is 'summary' for summary plot, 'dge' for differential gene expression, 'hierarchial' for hierarchial clustering, 'snv_indel' for SNV/Indel, 'cnv' for CNV and 'sv_fusion' for SV/fusion, 'variant_calling' for variant calling, 'surivial' for survival data, 'none' for none of the previously described categories. The answer should always be in lower case. \nQuestion= {question} \nanswer").dynamic_context(top_k, vector_store.index(embedding_model)).additional_params(additional).temperature(temperature).build();
 
         let response = agent.prompt(user_input).await.expect("Failed to prompt myprovider");
 
         //println!("Myprovider: {}", response);
         let result = response.replace("json", "").replace("```", "");
         //println!("result:{}", result);
-        let json_value: Value = serde_json::from_str(&result).expect("REASON");
-        let json_value2: Value = serde_json::from_str(&json_value[0]["generated_text"].to_string()).expect("REASON2");
+        let json_value: Value = serde_json::from_str(&result).expect("REASON2");
+        let json_value2: Value = serde_json::from_str(&json_value[0]["generated_text"].to_string()).expect("REASON3");
         //println!("json_value2:{}", json_value2.as_str().unwrap());
-        let json_value3: Value = serde_json::from_str(&json_value2.as_str().unwrap()).expect("REASON2");
+        let json_value3: Value = serde_json::from_str(&json_value2.as_str().unwrap()).expect("REASON4");
         assert_eq!(json_value3["answer"].to_string().replace("\"", ""), "dge");
     }
 }
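
The tail of the test parses the answer out of three nested layers: the response body is a JSON array, its generated_text field is a Value that gets re-serialized and re-parsed to recover the embedded string, and that string is itself the final JSON object. A self-contained sketch of the same unwrapping with a hand-built response literal (the payload shape is inferred from the test, not documented by the provider):

    use serde_json::Value;

    fn main() {
        // Hand-built stand-in for the provider response; generated_text
        // carries the final answer as an embedded JSON string.
        let result = r#"[{"generated_text": "{\"answer\": \"dge\"}"}]"#;
        let json_value: Value = serde_json::from_str(result).unwrap(); // outer array
        // .to_string() re-serializes the Value; parsing it back recovers the raw string.
        let json_value2: Value =
            serde_json::from_str(&json_value[0]["generated_text"].to_string()).unwrap();
        // The recovered string is itself JSON: the {"answer": ...} object.
        let json_value3: Value = serde_json::from_str(json_value2.as_str().unwrap()).unwrap();
        assert_eq!(json_value3["answer"], "dge");
    }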
package/src/test_ai.rs ADDED
@@ -0,0 +1,168 @@
+// For capturing output from a test, run: cd .. && cargo test -- --nocapture
+// Ignored tests: cd .. && export RUST_BACKTRACE=full && time cargo test -- --ignored --nocapture
+#[allow(dead_code)]
+fn main() {}
+
+#[cfg(test)]
+mod tests {
+    use serde_json;
+    use std::fs::{self};
+    use std::path::Path;
+
+    #[derive(PartialEq, Debug, Clone, schemars::JsonSchema, serde::Serialize, serde::Deserialize)]
+    struct ServerConfig {
+        tpmasterdir: String,
+        llm_backend: String,
+        sj_apilink: String,
+        sj_comp_model_name: String,
+        sj_embedding_model_name: String,
+        ollama_apilink: String,
+        ollama_comp_model_name: String,
+        ollama_embedding_model_name: String,
+        genomes: Vec<Genomes>,
+    }
+
+    #[derive(PartialEq, Debug, Clone, schemars::JsonSchema, serde::Serialize, serde::Deserialize)]
+    struct Genomes {
+        name: String,
+        datasets: Vec<Dataset>,
+    }
+
+    #[derive(PartialEq, Debug, Clone, schemars::JsonSchema, serde::Serialize, serde::Deserialize)]
+    struct Dataset {
+        name: String,
+        aifiles: Option<String>, // For now aifiles are defined only for certain datasets
+    }
+
+    #[tokio::test]
+    #[ignore]
+    async fn user_prompts() {
+        let temperature: f64 = 0.01;
+        let max_new_tokens: usize = 512;
+        let top_p: f32 = 0.95;
+        let serverconfig_file_path = Path::new("../../serverconfig.json");
+        let absolute_path = serverconfig_file_path.canonicalize().unwrap();
+
+        // Read the file
+        let data = fs::read_to_string(absolute_path).unwrap();
+
+        // Parse the JSON data
+        let serverconfig: ServerConfig = serde_json::from_str(&data).expect("JSON not in serverconfig.json format");
+
+        for genome in &serverconfig.genomes {
+            for dataset in &genome.datasets {
+                match &dataset.aifiles {
+                    Some(ai_json_file) => {
+                        println!("Testing dataset:{}", dataset.name);
+                        let ai_json_file_path = String::from("../../") + ai_json_file;
+                        let ai_json_file = Path::new(&ai_json_file_path);
+
+                        // Read the file
+                        let ai_data = fs::read_to_string(ai_json_file).unwrap();
+                        // Parse the JSON data
+                        let ai_json: super::super::AiJsonFormat =
+                            serde_json::from_str(&ai_data).expect("AI JSON file does not have the correct format");
+                        //println!("ai_json:{:?}", ai_json);
+                        let genedb = String::from(&serverconfig.tpmasterdir) + &"/" + &ai_json.genedb;
+                        let dataset_db = String::from(&serverconfig.tpmasterdir) + &"/" + &ai_json.db;
+                        let llm_backend_name = &serverconfig.llm_backend;
+                        let llm_backend_type: super::super::llm_backend;
+
+                        if llm_backend_name != "ollama" && llm_backend_name != "SJ" {
+                            panic!(
+                                "This code currently supports only Ollama and SJ provider. llm_backend_name must be \"ollama\" or \"SJ\""
+                            );
+                        } else if *llm_backend_name == "ollama".to_string() {
+                            let ollama_host = &serverconfig.ollama_apilink;
+                            let ollama_embedding_model_name = &serverconfig.ollama_embedding_model_name;
+                            let ollama_comp_model_name = &serverconfig.ollama_comp_model_name;
+                            llm_backend_type = super::super::llm_backend::Ollama();
+                            let ollama_client = super::super::ollama::Client::builder()
+                                .base_url(ollama_host)
+                                .build()
+                                .expect("Ollama server not found");
+                            let embedding_model = ollama_client.embedding_model(ollama_embedding_model_name);
+                            let comp_model = ollama_client.completion_model(ollama_comp_model_name);
+
+                            for chart in ai_json.charts.clone() {
+                                match chart {
+                                    super::super::Charts::Summary(testdata) => {
+                                        for ques_ans in testdata.TestData {
+                                            let user_input = ques_ans.question;
+                                            let llm_output = super::super::run_pipeline(
+                                                &user_input,
+                                                comp_model.clone(),
+                                                embedding_model.clone(),
+                                                llm_backend_type.clone(),
+                                                temperature,
+                                                max_new_tokens,
+                                                top_p,
+                                                &dataset_db,
+                                                &genedb,
+                                                &ai_json,
+                                            )
+                                            .await;
+                                            let mut llm_json_value: super::super::SummaryType = serde_json::from_str(&llm_output.unwrap()).expect("Did not get a valid JSON of type {action: summary, summaryterms:[{clinical: term1}, {geneExpression: gene}], filter:[{term: term1, value: value1}]} from the LLM");
+                                            let mut expected_json_value: super::super::SummaryType = serde_json::from_str(&ques_ans.answer).expect("Did not get a valid JSON of type {action: summary, summaryterms:[{clinical: term1}, {geneExpression: gene}], filter:[{term: term1, value: value1}]} from the LLM");
+                                            assert_eq!(
+                                                llm_json_value.sort_summarytype_struct(),
+                                                expected_json_value.sort_summarytype_struct()
+                                            );
+                                        }
+                                    }
+                                    super::super::Charts::DE(_testdata) => {} // To do
+                                }
+                            }
+                        } else if *llm_backend_name == "SJ".to_string() {
+                            let sjprovider_host = &serverconfig.sj_apilink;
+                            let sj_embedding_model_name = &serverconfig.sj_embedding_model_name;
+                            let sj_comp_model_name = &serverconfig.sj_comp_model_name;
+                            llm_backend_type = super::super::llm_backend::Sj();
+                            let sj_client = super::super::sjprovider::Client::builder()
+                                .base_url(sjprovider_host)
+                                .build()
+                                .expect("SJ server not found");
+                            let embedding_model = sj_client.embedding_model(sj_embedding_model_name);
+                            let comp_model = sj_client.completion_model(sj_comp_model_name);
+
+                            for chart in ai_json.charts.clone() {
+                                match chart {
+                                    super::super::Charts::Summary(testdata) => {
+                                        for ques_ans in testdata.TestData {
+                                            let user_input = ques_ans.question;
+                                            if user_input.len() > 0 {
+                                                let llm_output = super::super::run_pipeline(
+                                                    &user_input,
+                                                    comp_model.clone(),
+                                                    embedding_model.clone(),
+                                                    llm_backend_type.clone(),
+                                                    temperature,
+                                                    max_new_tokens,
+                                                    top_p,
+                                                    &dataset_db,
+                                                    &genedb,
+                                                    &ai_json,
+                                                )
+                                                .await;
+                                                let mut llm_json_value: super::super::SummaryType = serde_json::from_str(&llm_output.unwrap()).expect("Did not get a valid JSON of type {action: summary, summaryterms:[{clinical: term1}, {geneExpression: gene}], filter:[{term: term1, value: value1}]} from the LLM");
+                                                let mut expected_json_value: super::super::SummaryType = serde_json::from_str(&ques_ans.answer).expect("Did not get a valid JSON of type {action: summary, summaryterms:[{clinical: term1}, {geneExpression: gene}], filter:[{term: term1, value: value1}]} from the LLM");
+                                                assert_eq!(
+                                                    llm_json_value.sort_summarytype_struct(),
+                                                    expected_json_value.sort_summarytype_struct()
+                                                );
+                                            } else {
+                                                panic!("The user input is empty");
+                                            }
+                                        }
+                                    }
+                                    super::super::Charts::DE(_testdata) => {} // To do
+                                }
+                            }
+                        }
+                    }
+                    None => {}
+                }
+            }
+        }
+    }
+}
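
For orientation, a serverconfig.json that satisfies the ServerConfig struct above would look roughly like the sample below. Every value is a placeholder; the real file carries additional server settings, which serde tolerates because unknown fields are ignored by default:

    // Sketch of the minimal config shape user_prompts() deserializes.
    // All values are placeholders, not real endpoints or model names.
    let sample = r#"{
        "tpmasterdir": "/path/to/tp",
        "llm_backend": "ollama",
        "sj_apilink": "http://sj-host/api",
        "sj_comp_model_name": "sj-comp",
        "sj_embedding_model_name": "sj-embed",
        "ollama_apilink": "http://localhost:11434",
        "ollama_comp_model_name": "llama3",
        "ollama_embedding_model_name": "nomic-embed-text",
        "genomes": [
            { "name": "hg38", "datasets": [ { "name": "demo", "aifiles": "path/to/ai.json" } ] }
        ]
    }"#;
    let parsed: ServerConfig = serde_json::from_str(sample).expect("sample matches ServerConfig");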