@sjcrh/proteinpaint-rust 2.166.0 → 2.167.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/package.json CHANGED
@@ -1,5 +1,5 @@
1
1
  {
2
- "version": "2.166.0",
2
+ "version": "2.167.0",
3
3
  "name": "@sjcrh/proteinpaint-rust",
4
4
  "type": "module",
5
5
  "description": "Rust-based utilities for proteinpaint",
package/src/aichatbot.rs CHANGED
@@ -160,7 +160,7 @@ async fn main() -> Result<()> {
160
160
  let temperature: f64 = 0.01;
161
161
  let max_new_tokens: usize = 512;
162
162
  let top_p: f32 = 0.95;
163
-
163
+ let testing = false; // This variable is always false in production, this is true in test_ai.rs for testing code
164
164
  if llm_backend_name != "ollama" && llm_backend_name != "SJ" {
165
165
  panic!(
166
166
  "This code currently supports only Ollama and SJ provider. llm_backend_name must be \"ollama\" or \"SJ\""
@@ -185,6 +185,7 @@ async fn main() -> Result<()> {
185
185
  &dataset_db,
186
186
  &genedb,
187
187
  &ai_json,
188
+ testing,
188
189
  )
189
190
  .await;
190
191
  } else if llm_backend_name == "SJ".to_string() {
@@ -207,6 +208,7 @@ async fn main() -> Result<()> {
207
208
  &dataset_db,
208
209
  &genedb,
209
210
  &ai_json,
211
+ testing,
210
212
  )
211
213
  .await;
212
214
  }
@@ -239,6 +241,7 @@ pub async fn run_pipeline(
239
241
  dataset_db: &str,
240
242
  genedb: &str,
241
243
  ai_json: &AiJsonFormat,
244
+ testing: bool,
242
245
  ) -> Option<String> {
243
246
  let mut classification: String = classify_query_by_dataset_type(
244
247
  user_input,
@@ -263,13 +266,20 @@ pub async fn run_pipeline(
263
266
  top_p,
264
267
  )
265
268
  .await;
266
- final_output = format!(
267
- "{{\"{}\":\"{}\",\"{}\":[{}}}",
268
- "action",
269
- "dge",
270
- "DE_output",
271
- de_result + &"]"
272
- );
269
+ if testing == true {
270
+ final_output = format!(
271
+ "{{\"{}\":\"{}\",\"{}\":[{}}}",
272
+ "action",
273
+ "dge",
274
+ "DE_output",
275
+ de_result + &"]"
276
+ );
277
+ } else {
278
+ final_output = format!(
279
+ "{{\"{}\":\"{}\",\"{}\":\"{}\"}}",
280
+ "type", "html", "html", "DE agent not implemented yet"
281
+ );
282
+ }
273
283
  } else if classification == "summary".to_string() {
274
284
  final_output = extract_summary_information(
275
285
  user_input,
@@ -282,30 +292,83 @@ pub async fn run_pipeline(
282
292
  dataset_db,
283
293
  genedb,
284
294
  ai_json,
295
+ testing,
285
296
  )
286
297
  .await;
287
298
  } else if classification == "hierarchical".to_string() {
288
299
  // Not implemented yet
289
- final_output = format!("{{\"{}\":\"{}\"}}", "action", "hierarchical");
300
+ if testing == true {
301
+ final_output = format!("{{\"{}\":\"{}\"}}", "action", "hierarchical");
302
+ } else {
303
+ final_output = format!(
304
+ "{{\"{}\":\"{}\",\"{}\":\"{}\"}}",
305
+ "type", "html", "html", "hierarchical clustering agent not implemented yet"
306
+ );
307
+ }
290
308
  } else if classification == "snv_indel".to_string() {
291
309
  // Not implemented yet
292
- final_output = format!("{{\"{}\":\"{}\"}}", "action", "snv_indel");
310
+ if testing == true {
311
+ final_output = format!("{{\"{}\":\"{}\"}}", "action", "snv_indel");
312
+ } else {
313
+ final_output = format!(
314
+ "{{\"{}\":\"{}\",\"{}\":\"{}\"}}",
315
+ "type", "html", "html", "snv_indel agent not implemented yet"
316
+ );
317
+ }
293
318
  } else if classification == "cnv".to_string() {
294
319
  // Not implemented yet
295
- final_output = format!("{{\"{}\":\"{}\"}}", "action", "cnv");
320
+ if testing == true {
321
+ final_output = format!("{{\"{}\":\"{}\"}}", "action", "cnv");
322
+ } else {
323
+ final_output = format!(
324
+ "{{\"{}\":\"{}\",\"{}\":\"{}\"}}",
325
+ "type", "html", "html", "cnv agent not implemented yet"
326
+ );
327
+ }
296
328
  } else if classification == "variant_calling".to_string() {
297
329
  // Not implemented yet and will never be supported. Need a separate messages for this
298
- final_output = format!("{{\"{}\":\"{}\"}}", "action", "variant_calling");
330
+ if testing == true {
331
+ final_output = format!("{{\"{}\":\"{}\"}}", "action", "variant_calling");
332
+ } else {
333
+ final_output = format!(
334
+ "{{\"{}\":\"{}\",\"{}\":\"{}\"}}",
335
+ "type", "html", "html", "variant_calling agent not implemented yet"
336
+ );
337
+ }
299
338
  } else if classification == "survival".to_string() {
300
339
  // Not implemented yet
301
- final_output = format!("{{\"{}\":\"{}\"}}", "action", "surivial");
340
+ if testing == true {
341
+ final_output = format!("{{\"{}\":\"{}\"}}", "action", "surivial");
342
+ } else {
343
+ final_output = format!(
344
+ "{{\"{}\":\"{}\",\"{}\":\"{}\"}}",
345
+ "type", "html", "html", "survival agent not implemented yet"
346
+ );
347
+ }
302
348
  } else if classification == "none".to_string() {
303
- final_output = format!(
304
- "{{\"{}\":\"{}\",\"{}\":\"{}\"}}",
305
- "action", "none", "message", "The input query did not match any known features in Proteinpaint"
306
- );
349
+ if testing == true {
350
+ final_output = format!(
351
+ "{{\"{}\":\"{}\",\"{}\":\"{}\"}}",
352
+ "action", "none", "message", "The input query did not match any known features in Proteinpaint"
353
+ );
354
+ } else {
355
+ final_output = format!(
356
+ "{{\"{}\":\"{}\",\"{}\":\"{}\"}}",
357
+ "type", "html", "html", "The input query did not match any known features in Proteinpaint"
358
+ );
359
+ }
307
360
  } else {
308
- final_output = format!("{{\"{}\":\"{}\"}}", "action", "unknown:".to_string() + &classification);
361
+ if testing == true {
362
+ final_output = format!("{{\"{}\":\"{}\"}}", "action", "unknown:".to_string() + &classification);
363
+ } else {
364
+ final_output = format!(
365
+ "{{\"{}\":\"{}\",\"{}\":\"{}\"}}",
366
+ "type",
367
+ "html",
368
+ "html",
369
+ "unknown:".to_string() + &classification
370
+ );
371
+ }
309
372
  }
310
373
  Some(final_output)
311
374
  }
@@ -801,6 +864,7 @@ async fn extract_summary_information(
801
864
  dataset_db: &str,
802
865
  genedb: &str,
803
866
  ai_json: &AiJsonFormat,
867
+ testing: bool,
804
868
  ) -> String {
805
869
  let (rag_docs, db_vec) = parse_dataset_db(dataset_db).await;
806
870
  let additional;
@@ -919,7 +983,8 @@ async fn extract_summary_information(
919
983
  }
920
984
  }
921
985
  //println!("final_llm_json:{}", final_llm_json);
922
- let final_validated_json = validate_summary_output(final_llm_json.clone(), db_vec, common_genes, ai_json);
986
+ let final_validated_json =
987
+ validate_summary_output(final_llm_json.clone(), db_vec, common_genes, ai_json, testing);
923
988
  final_validated_json
924
989
  }
925
990
  None => {
@@ -1063,6 +1128,7 @@ fn validate_summary_output(
1063
1128
  db_vec: Vec<DbRows>,
1064
1129
  common_genes: Vec<String>,
1065
1130
  ai_json: &AiJsonFormat,
1131
+ testing: bool,
1066
1132
  ) -> String {
1067
1133
  let json_value: SummaryType =
1068
1134
  serde_json::from_str(&raw_llm_json).expect("Did not get a valid JSON of type {action: summary, summaryterms:[{clinical: term1}, {geneExpression: gene}], filter:[{term: term1, value: value1}]} from the LLM");
@@ -1094,7 +1160,7 @@ fn validate_summary_output(
1094
1160
  match term_verification.correct_field {
1095
1161
  Some(tm) => validated_summary_terms.push(SummaryTerms::clinical(tm)),
1096
1162
  None => {
1097
- message = message + &"\"" + &clin + &"\"" + &" not found in db.";
1163
+ message = message + &"'" + &clin + &"'" + &" not found in db.";
1098
1164
  }
1099
1165
  }
1100
1166
  } else if Some(term_verification.correct_field.clone()).is_some()
@@ -1122,7 +1188,7 @@ fn validate_summary_output(
1122
1188
  if num_gene_verification == 0 || common_genes.len() == 0 {
1123
1189
  if message.to_lowercase().contains(&gene.to_lowercase()) { // Check if the LLM has already added the message, if not then add it
1124
1190
  } else {
1125
- message = message + &"\"" + &gene + &"\"" + &" not found in genedb.";
1191
+ message = message + &"'" + &gene + &"'" + &" not found in genedb.";
1126
1192
  }
1127
1193
  }
1128
1194
  }
@@ -1138,6 +1204,8 @@ fn validate_summary_output(
1138
1204
  }
1139
1205
  }
1140
1206
 
1207
+ let mut pp_plot_json: Value; // The PP compliant plot JSON
1208
+ pp_plot_json = serde_json::from_str(&"{\"chartType\":\"summary\"}").expect("Not a valid JSON");
1141
1209
  match &json_value.filter {
1142
1210
  Some(filter_terms_array) => {
1143
1211
  let mut validated_filter_terms = Vec::<FilterTerm>::new();
@@ -1168,21 +1236,21 @@ fn validate_summary_output(
1168
1236
  validated_filter_terms.push(categorical_filter_term);
1169
1237
  }
1170
1238
  if term_verification.correct_field.is_none() {
1171
- message = message + &"\"" + &categorical.term + &"\" filter term not found in db";
1239
+ message = message + &"'" + &categorical.term + &"' filter term not found in db";
1172
1240
  }
1173
1241
  if value_verification.is_none() {
1174
1242
  message = message
1175
- + &"\""
1243
+ + &"'"
1176
1244
  + &categorical.value
1177
- + &"\" filter value not found for filter field \""
1245
+ + &"' filter value not found for filter field '"
1178
1246
  + &categorical.term
1179
- + "\" in db";
1247
+ + "' in db";
1180
1248
  }
1181
1249
  }
1182
1250
  FilterTerm::Numeric(numeric) => {
1183
1251
  let term_verification = verify_json_field(&numeric.term, &db_vec);
1184
1252
  if term_verification.correct_field.is_none() {
1185
- message = message + &"\"" + &numeric.term + &"\" filter term not found in db";
1253
+ message = message + &"'" + &numeric.term + &"' filter term not found in db";
1186
1254
  } else {
1187
1255
  let numeric_filter_term: FilterTerm = FilterTerm::Numeric(numeric.clone());
1188
1256
  validated_filter_terms.push(numeric_filter_term);
@@ -1229,8 +1297,38 @@ fn validate_summary_output(
1229
1297
  }
1230
1298
 
1231
1299
  if validated_filter_terms.len() > 0 {
1232
- if let Some(obj) = new_json.as_object_mut() {
1233
- obj.insert(String::from("filter"), serde_json::json!(validated_filter_terms));
1300
+ if testing == true {
1301
+ if let Some(obj) = new_json.as_object_mut() {
1302
+ obj.insert(String::from("filter"), serde_json::json!(validated_filter_terms));
1303
+ }
1304
+ } else {
1305
+ let mut validated_filter_terms_PP: String = "[".to_string();
1306
+ let mut filter_hits = 0;
1307
+ for validated_term in validated_filter_terms {
1308
+ match validated_term {
1309
+ FilterTerm::Categorical(categorical_filter) => {
1310
+ let string_json = "{\"term\":\"".to_string()
1311
+ + &categorical_filter.term
1312
+ + &"\", \"category\":\""
1313
+ + &categorical_filter.value
1314
+ + &"\"},";
1315
+ validated_filter_terms_PP += &string_json;
1316
+ filter_hits += 1; // Once numeric term is also implemented, this statement will go outside the match block
1317
+ }
1318
+ FilterTerm::Numeric(_numeric_term) => {} // To be implemented later
1319
+ };
1320
+ }
1321
+ println!("validated_filter_terms_PP:{}", validated_filter_terms_PP);
1322
+ if filter_hits > 0 {
1323
+ validated_filter_terms_PP.pop();
1324
+ validated_filter_terms_PP += &"]";
1325
+ if let Some(obj) = pp_plot_json.as_object_mut() {
1326
+ obj.insert(
1327
+ String::from("simpleFilter"),
1328
+ serde_json::from_str(&validated_filter_terms_PP).expect("Not a valid JSON"),
1329
+ );
1330
+ }
1331
+ }
1234
1332
  }
1235
1333
  }
1236
1334
  }
@@ -1240,6 +1338,10 @@ fn validate_summary_output(
1240
1338
  // Removing terms that are found both in filter term as well summary
1241
1339
  let mut validated_summary_terms_final = Vec::<SummaryTerms>::new();
1242
1340
 
1341
+ let mut sum_iter = 0;
1342
+ let mut pp_json: Value; // New JSON value that will contain items of the final PP compliant JSON
1343
+ pp_json = serde_json::from_str(&"{\"type\":\"plot\"}").expect("Not a valid JSON");
1344
+
1243
1345
  for summary_term in &validated_summary_terms {
1244
1346
  let mut hit = 0;
1245
1347
  match summary_term {
@@ -1276,9 +1378,53 @@ fn validate_summary_output(
1276
1378
  }
1277
1379
  }
1278
1380
  }
1381
+
1279
1382
  if hit == 0 {
1383
+ let mut termidpp: Option<TermIDPP> = None;
1384
+ let mut geneexp: Option<GeneExpressionPP> = None;
1385
+ match summary_term {
1386
+ SummaryTerms::clinical(clinical_term) => {
1387
+ termidpp = Some(TermIDPP {
1388
+ id: clinical_term.to_string(),
1389
+ });
1390
+ }
1391
+ SummaryTerms::geneExpression(gene) => {
1392
+ geneexp = Some(GeneExpressionPP {
1393
+ gene: gene.to_string(),
1394
+ r#type: "geneExpression".to_string(),
1395
+ });
1396
+ }
1397
+ }
1398
+ if sum_iter == 0 {
1399
+ if termidpp.is_some() {
1400
+ if let Some(obj) = pp_plot_json.as_object_mut() {
1401
+ obj.insert(String::from("term"), serde_json::json!(Some(termidpp)));
1402
+ }
1403
+ }
1404
+
1405
+ if geneexp.is_some() {
1406
+ let gene_term = GeneTerm { term: geneexp.unwrap() };
1407
+ if let Some(obj) = pp_plot_json.as_object_mut() {
1408
+ obj.insert(String::from("term"), serde_json::json!(gene_term));
1409
+ }
1410
+ }
1411
+ } else if sum_iter == 1 {
1412
+ if termidpp.is_some() {
1413
+ if let Some(obj) = pp_plot_json.as_object_mut() {
1414
+ obj.insert(String::from("term2"), serde_json::json!(Some(termidpp)));
1415
+ }
1416
+ }
1417
+
1418
+ if geneexp.is_some() {
1419
+ let gene_term = GeneTerm { term: geneexp.unwrap() };
1420
+ if let Some(obj) = pp_plot_json.as_object_mut() {
1421
+ obj.insert(String::from("term2"), serde_json::json!(gene_term));
1422
+ }
1423
+ }
1424
+ }
1280
1425
  validated_summary_terms_final.push(summary_term.clone())
1281
1426
  }
1427
+ sum_iter += 1
1282
1428
  }
1283
1429
 
1284
1430
  if let Some(obj) = new_json.as_object_mut() {
@@ -1288,14 +1434,61 @@ fn validate_summary_output(
1288
1434
  );
1289
1435
  }
1290
1436
 
1437
+ if let Some(obj) = pp_json.as_object_mut() {
1438
+ // The `if let` ensures we only proceed if the top-level JSON is an object.
1439
+ // Append a new string field.
1440
+ obj.insert(String::from("plot"), serde_json::json!(pp_plot_json));
1441
+ }
1442
+
1443
+ let mut err_json: Value; // Error JSON containing the error message (if present)
1291
1444
  if message.len() > 0 {
1292
- if let Some(obj) = new_json.as_object_mut() {
1293
- // The `if let` ensures we only proceed if the top-level JSON is an object.
1294
- // Append a new string field.
1295
- obj.insert(String::from("message"), serde_json::json!(message));
1445
+ if testing == false {
1446
+ err_json = serde_json::from_str(&"{\"type\":\"html\"}").expect("Not a valid JSON");
1447
+ if let Some(obj) = err_json.as_object_mut() {
1448
+ // The `if let` ensures we only proceed if the top-level JSON is an object.
1449
+ // Append a new string field.
1450
+ obj.insert(String::from("html"), serde_json::json!(message));
1451
+ };
1452
+ serde_json::to_string(&err_json).unwrap()
1453
+ } else {
1454
+ if let Some(obj) = new_json.as_object_mut() {
1455
+ // The `if let` ensures we only proceed if the top-level JSON is an object.
1456
+ // Append a new string field.
1457
+ obj.insert(String::from("message"), serde_json::json!(message));
1458
+ };
1459
+ serde_json::to_string(&new_json).unwrap()
1460
+ }
1461
+ } else {
1462
+ if testing == true {
1463
+ // When testing script output native LLM JSON
1464
+ serde_json::to_string(&new_json).unwrap()
1465
+ } else {
1466
+ // When in production output PP compliant JSON
1467
+ serde_json::to_string(&pp_json).unwrap()
1296
1468
  }
1297
1469
  }
1298
- serde_json::to_string(&new_json).unwrap()
1470
+ }
1471
+
1472
+ fn getGeneExpression() -> String {
1473
+ "geneExpression".to_string()
1474
+ }
1475
+
1476
+ #[derive(PartialEq, Debug, Clone, schemars::JsonSchema, serde::Serialize, serde::Deserialize)]
1477
+ struct TermIDPP {
1478
+ id: String,
1479
+ }
1480
+
1481
+ #[derive(PartialEq, Debug, Clone, schemars::JsonSchema, serde::Serialize, serde::Deserialize)]
1482
+ struct GeneTerm {
1483
+ term: GeneExpressionPP,
1484
+ }
1485
+
1486
+ #[derive(PartialEq, Debug, Clone, schemars::JsonSchema, serde::Serialize, serde::Deserialize)]
1487
+ struct GeneExpressionPP {
1488
+ gene: String,
1489
+ // Serde uses this for deserialization.
1490
+ #[serde(default = "getGeneExpression")]
1491
+ r#type: String,
1299
1492
  }
1300
1493
 
1301
1494
  #[derive(Debug, Clone)]
@@ -57,6 +57,7 @@ struct InteractiveData {
57
57
  x_buffer: i64,
58
58
  y_min: f64,
59
59
  y_max: f64,
60
+ device_pixel_ratio: f64,
60
61
  }
61
62
 
62
63
  #[derive(Serialize)]
@@ -335,12 +336,8 @@ fn plot_grin2_manhattan(
335
336
  let png_width = plot_width + 2 * png_dot_radius;
336
337
  let png_height = plot_height + 2 * png_dot_radius;
337
338
 
338
- let w: u32 = (png_width * device_pixel_ratio as u64)
339
- .try_into()
340
- .expect("PNG width too large for u32");
341
- let h: u32 = (png_height * device_pixel_ratio as u64)
342
- .try_into()
343
- .expect("PNG height too large for u32");
339
+ let w: u32 = ((png_width as f64) * dpr) as u32;
340
+ let h: u32 = ((png_height as f64) * dpr) as u32;
344
341
 
345
342
  // Create RGB buffer for Plotters
346
343
  let mut buffer = vec![0u8; w as usize * h as usize * 3];
@@ -402,8 +399,8 @@ fn plot_grin2_manhattan(
402
399
 
403
400
  for (i, p) in point_details.iter_mut().enumerate() {
404
401
  let (px, py) = pixel_positions[*&sig_indices[i]];
405
- p.pixel_x = px;
406
- p.pixel_y = py;
402
+ p.pixel_x = px / dpr;
403
+ p.pixel_y = py / dpr;
407
404
  }
408
405
 
409
406
  // flush root drawing area
@@ -469,6 +466,7 @@ fn plot_grin2_manhattan(
469
466
  x_buffer,
470
467
  y_min,
471
468
  y_max,
469
+ device_pixel_ratio: dpr,
472
470
  };
473
471
  Ok((png_data, interactive_data))
474
472
  }
package/src/test_ai.rs CHANGED
@@ -42,6 +42,7 @@ mod tests {
42
42
  let top_p: f32 = 0.95;
43
43
  let serverconfig_file_path = Path::new("../../serverconfig.json");
44
44
  let absolute_path = serverconfig_file_path.canonicalize().unwrap();
45
+ let testing = true; // This causes the JSON being output from run_pipeline() to be in LLM JSON format
45
46
 
46
47
  // Read the file
47
48
  let data = fs::read_to_string(absolute_path).unwrap();
@@ -83,7 +84,6 @@ mod tests {
83
84
  .expect("Ollama server not found");
84
85
  let embedding_model = ollama_client.embedding_model(ollama_embedding_model_name);
85
86
  let comp_model = ollama_client.completion_model(ollama_comp_model_name);
86
-
87
87
  for chart in ai_json.charts.clone() {
88
88
  match chart {
89
89
  super::super::Charts::Summary(testdata) => {
@@ -100,6 +100,7 @@ mod tests {
100
100
  &dataset_db,
101
101
  &genedb,
102
102
  &ai_json,
103
+ testing,
103
104
  )
104
105
  .await;
105
106
  let mut llm_json_value: super::super::SummaryType = serde_json::from_str(&llm_output.unwrap()).expect("Did not get a valid JSON of type {action: summary, summaryterms:[{clinical: term1}, {geneExpression: gene}], filter:[{term: term1, value: value1}]} from the LLM");
@@ -142,6 +143,7 @@ mod tests {
142
143
  &dataset_db,
143
144
  &genedb,
144
145
  &ai_json,
146
+ testing,
145
147
  )
146
148
  .await;
147
149
  let mut llm_json_value: super::super::SummaryType = serde_json::from_str(&llm_output.unwrap()).expect("Did not get a valid JSON of type {action: summary, summaryterms:[{clinical: term1}, {geneExpression: gene}], filter:[{term: term1, value: value1}]} from the LLM");