@sjcrh/proteinpaint-rust 2.169.0 → 2.170.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -3,7 +3,7 @@ use plotters::prelude::*;
3
3
  use plotters::style::ShapeStyle;
4
4
  use serde::{Deserialize, Serialize};
5
5
  use serde_json;
6
- use std::collections::HashMap;
6
+ use std::collections::{HashMap, HashSet};
7
7
  use std::convert::TryInto;
8
8
  use std::error::Error;
9
9
  use std::fs::File;
@@ -22,6 +22,10 @@ struct Input {
22
22
  plot_height: u64,
23
23
  device_pixel_ratio: f64,
24
24
  png_dot_radius: u64,
25
+ max_capped_points: u64,
26
+ hard_cap: f64,
27
+ bin_size: f64,
28
+ q_value_threshold: f64,
25
29
  }
26
30
 
27
31
  // chromosome info
@@ -58,6 +62,8 @@ struct InteractiveData {
58
62
  y_min: f64,
59
63
  y_max: f64,
60
64
  device_pixel_ratio: f64,
65
+ default_log_cutoff: f64,
66
+ has_capped_points: bool,
61
67
  }
62
68
 
63
69
  #[derive(Serialize)]
@@ -78,6 +84,134 @@ fn hex_to_rgb(hex: &str) -> Option<(u8, u8, u8)> {
78
84
  Some((r, g, b))
79
85
  }
80
86
 
87
/// Derives the default -log10(q) cutoff from the data read from the GRIN2 file.
///
/// The cutoff is the mean of the values in `ys` that are strictly below `hard_cap`,
/// ignoring every position listed in `exclude_indices` (used for the 0.0 placeholders
/// that stand in for zero q-values and would otherwise drag the mean down).
///
/// # Parameters
/// - `ys`: all -log10 q-values
/// - `hard_cap`: values at or above this are left out of the mean
/// - `exclude_indices`: positions in `ys` to skip
///
/// # Returns
/// - `hard_cap` when no value qualifies (empty input, everything excluded, or
///   everything at/above `hard_cap`)
/// - otherwise the mean, raised to at least 40.0 (a cutoff that is too low breaks the
///   histogram-bin setup in the dynamic y-cap calculation) and never above `hard_cap`,
///   so the cutoff stays a valid lower bound of the capping range
fn get_log_cutoff(ys: &[f64], hard_cap: f64, exclude_indices: &HashSet<usize>) -> f64 {
    // Lowest cutoff we are willing to hand to the dynamic y-cap calculation.
    const MIN_LOG_CUTOFF: f64 = 40.0;

    // Single pass: accumulate the sum and the count without an intermediate Vec.
    let (sum, count) = ys
        .iter()
        .enumerate()
        .filter(|&(i, &y)| y < hard_cap && !exclude_indices.contains(&i))
        .fold((0.0_f64, 0_usize), |(sum, count), (_, &y)| (sum + y, count + 1));

    // Nothing qualified: fall back to the hard cap.
    if count == 0 {
        return hard_cap;
    }

    // Apply the floor first, then make sure the floor cannot push the cutoff past the
    // hard cap (which happens whenever hard_cap < MIN_LOG_CUTOFF).
    (sum / count as f64).max(MIN_LOG_CUTOFF).min(hard_cap)
}
111
+
112
/// Calculates a dynamic y-axis cap for Manhattan plots to handle outliers gracefully.
///
/// # Problem
/// Manhattan plots often have a few extreme outliers (very significant p-values) that
/// compress the visual range for the majority of points. This function finds a
/// y-axis cap that:
/// - Shows most data at true scale
/// - Caps only a small number of extreme outliers
/// - Ensures visible outliers (at or below the hard cap) render at their true positions
///
/// # Algorithm
/// 1. **No outliers**: If `max_y <= default_cap`, return `max_y` (no capping needed)
/// 2. **Histogram binning**: Partition the range `(default_cap, hard_cap]` into fixed-size
///    bins; values above `hard_cap` are counted in the last bin
/// 3. **Walk up**: Starting from the lowest bin, stop at the first bin where the number of
///    points at or above that bin is at most `max_capped_points`, and use that bin's
///    upper bound as the candidate cap
/// 4. **Preserve visible outliers**: If the highest value at or below `hard_cap` lies above
///    the candidate, extend the cap to that value plus one bin
/// 5. The result of steps 3-4 is limited to `hard_cap`; if the walk never finds an
///    acceptable bin, `hard_cap` is returned
///
/// # Parameters
/// - `ys`: All y-values (-log10 q-values) in the plot
/// - `max_capped_points`: Maximum points allowed to be clamped to the cap (e.g., 5)
/// - `default_cap`: Starting threshold; points at or below this are never capped
///   (e.g., the value produced by `get_log_cutoff`)
/// - `hard_cap`: Absolute maximum y-axis value (e.g., 200)
/// - `bin_size`: Histogram bin width on -log10 scale (e.g., 10)
///
/// # Returns
/// `max_y` when nothing exceeds `default_cap` (this is `f64::NEG_INFINITY` for an empty
/// `ys`); otherwise a cap that is at most `hard_cap`.
///
/// # Robustness
/// - A value exactly equal to `hard_cap`, or one that falls past the last whole bin because
///   the range is not a multiple of `bin_size`, is counted in the last bin instead of
///   indexing out of bounds.
/// - If `(hard_cap - default_cap) / bin_size` is not a finite number of at least 1
///   (zero/negative/NaN `bin_size`, `default_cap >= hard_cap`), a single bin is used.
///
/// # Example
/// With `default_cap=40`, `hard_cap=200`, `bin_size=10`, `max_capped_points=5`:
/// - If 7 points are above 40, with two at 83 and 183 and five at/above 200:
///   Returns 200, so the points at 83 and 183 display at their true positions while
///   the 5 extreme outliers are clamped to 200
fn calculate_dynamic_y_cap(
    ys: &[f64],
    max_capped_points: usize,
    default_cap: f64,
    hard_cap: f64,
    bin_size: f64,
) -> f64 {
    // Validate the ratio BEFORE casting: a float-to-usize cast saturates, so an infinite
    // ratio (bin_size == 0) would otherwise request a usize::MAX-element histogram.
    let bin_ratio = (hard_cap - default_cap) / bin_size;
    let num_bins = if bin_ratio.is_finite() && bin_ratio >= 1.0 {
        bin_ratio as usize
    } else {
        // Always keep at least one bin so the histogram can be indexed.
        1
    };
    let last_bin = num_bins - 1;

    let mut histogram = vec![0usize; num_bins];
    let mut max_y = f64::NEG_INFINITY;
    // Highest value above default_cap that is not beyond the hard cap.
    let mut max_y_below_hard_cap = f64::NEG_INFINITY;
    let mut points_above_default = 0usize;

    // Single pass: find the maxima and build the histogram simultaneously.
    for &y in ys {
        if y > max_y {
            max_y = y;
        }
        if y > default_cap {
            points_above_default += 1;
            if y > hard_cap {
                histogram[last_bin] += 1;
            } else {
                if y > max_y_below_hard_cap {
                    max_y_below_hard_cap = y;
                }
                // Clamp the index: y == hard_cap (or a range that is not a whole number
                // of bins) would otherwise land one past the end of the histogram.
                let bin_idx = (((y - default_cap) / bin_size) as usize).min(last_bin);
                histogram[bin_idx] += 1;
            }
        }
    }

    // Case 1: no points exceed the default cap - use the actual max.
    if max_y <= default_cap {
        return max_y;
    }

    // Walk up from default_cap towards hard_cap. `points_above` is the number of points
    // in the current bin or any higher one.
    let mut points_above = points_above_default;

    for (i, &count) in histogram.iter().enumerate() {
        if points_above <= max_capped_points {
            // Found an acceptable number of capped points.
            let bin_upper_bound = default_cap + ((i + 1) as f64) * bin_size;

            // If a visible outlier sits above this bin, extend the cap so it renders at
            // its true position; either way never go past hard_cap.
            let cap = if max_y_below_hard_cap > bin_upper_bound {
                max_y_below_hard_cap + bin_size
            } else {
                bin_upper_bound
            };

            return cap.min(hard_cap);
        }
        // The histogram counts sum to points_above_default, so this cannot underflow.
        points_above -= count;
    }

    // Too many points remain even in the last bin: fall back to the hard cap.
    hard_cap
}
214
+
81
215
  // Function to Build cumulative chromosome map
82
216
  fn cumulative_chrom(
83
217
  chrom_size: &HashMap<String, u64>,
@@ -120,7 +254,18 @@ fn cumulative_chrom(
120
254
  fn grin2_file_read(
121
255
  grin2_file: &str,
122
256
  chrom_data: &HashMap<String, ChromInfo>,
123
- ) -> Result<(Vec<u64>, Vec<f64>, Vec<String>, Vec<PointDetail>, Vec<usize>), Box<dyn Error>> {
257
+ q_value_threshold: f64,
258
+ ) -> Result<
259
+ (
260
+ Vec<u64>,
261
+ Vec<f64>,
262
+ Vec<String>,
263
+ Vec<PointDetail>,
264
+ Vec<usize>,
265
+ Vec<usize>,
266
+ ),
267
+ Box<dyn Error>,
268
+ > {
124
269
  // Default colours
125
270
  let mut colors: HashMap<String, String> = HashMap::new();
126
271
  colors.insert("gain".into(), "#FF4444".into());
@@ -134,6 +279,7 @@ fn grin2_file_read(
134
279
  let mut colors_vec = Vec::new();
135
280
  let mut point_details = Vec::new();
136
281
  let mut sig_indices: Vec<usize> = Vec::new();
282
+ let mut zero_q_indices: Vec<usize> = Vec::new();
137
283
 
138
284
  let grin2_file = File::open(grin2_file).expect("Failed to open grin2_result_file");
139
285
  let mut reader = BufReader::new(grin2_file);
@@ -217,13 +363,20 @@ fn grin2_file_read(
217
363
  Some(q) => q,
218
364
  None => continue,
219
365
  };
220
-
221
- let q_val: f64 = match q_val_str.parse() {
222
- Ok(v) if v > 0.0 => v,
223
- Ok(v) if v == 0.0 => 1e-300, // Treat exact 0 as ~1e-300 so we can still show q-values that are 0 and not filter them out
366
+ let original_q_val: f64 = match q_val_str.parse() {
367
+ Ok(v) if v >= 0.0 => v,
224
368
  _ => continue,
225
369
  };
226
- let neg_log10_q = -q_val.log10();
370
+
371
+ // Use a placeholder for zero q-values - these will be updated later
372
+ // after we calculate the dynamic y_cap from the full dataset
373
+ let neg_log10_q = if original_q_val == 0.0 {
374
+ zero_q_indices.push(mut_num);
375
+ 0.0 // Placeholder - will be set to y_cap later in plot_grin2_manhattan
376
+ } else {
377
+ -original_q_val.log10()
378
+ };
379
+
227
380
  let n_subj_count: Option<i64> = n_idx_opt
228
381
  .and_then(|i| fields.get(i))
229
382
  .and_then(|s| s.parse::<i64>().ok());
@@ -234,7 +387,8 @@ fn grin2_file_read(
234
387
  colors_vec.push(color.clone());
235
388
 
236
389
  // only add significant points for interactivity
237
- if q_val <= 0.05 {
390
+ // Compare the original q-value against the threshold (for q = 0 the y-value is only a 0.0 placeholder at this point) so the true q-value is what gets sent back
391
+ if original_q_val <= q_value_threshold {
238
392
  point_details.push(PointDetail {
239
393
  x: x_pos,
240
394
  y: neg_log10_q,
@@ -245,7 +399,7 @@ fn grin2_file_read(
245
399
  start: gene_start,
246
400
  end: gene_end,
247
401
  pos: gene_start,
248
- q_value: q_val,
402
+ q_value: original_q_val,
249
403
  nsubj: n_subj_count,
250
404
  pixel_x: 0.0,
251
405
  pixel_y: 0.0,
@@ -256,7 +410,7 @@ fn grin2_file_read(
256
410
  }
257
411
  }
258
412
 
259
- Ok((xs, ys, colors_vec, point_details, sig_indices))
413
+ Ok((xs, ys, colors_vec, point_details, sig_indices, zero_q_indices))
260
414
  }
261
415
 
262
416
  // Function to create the GRIN2 Manhattan plot
@@ -267,6 +421,10 @@ fn plot_grin2_manhattan(
267
421
  plot_height: u64,
268
422
  device_pixel_ratio: f64,
269
423
  png_dot_radius: u64,
424
+ bin_size: f64,
425
+ max_capped_points: u64,
426
+ hard_cap: f64,
427
+ q_value_threshold: f64,
270
428
  ) -> Result<(String, InteractiveData), Box<dyn Error>> {
271
429
  // ------------------------------------------------
272
430
  // 1. Build cumulative chromosome map
@@ -294,39 +452,73 @@ fn plot_grin2_manhattan(
294
452
  let mut colors_vec = Vec::new();
295
453
  let mut point_details = Vec::new();
296
454
  let mut sig_indices = Vec::new();
455
+ let mut zero_q_indices: Vec<usize> = Vec::new();
297
456
 
298
- if let Ok((x, y, c, pd, si)) = grin2_file_read(&grin2_result_file, &chrom_data) {
457
+ if let Ok((x, y, c, pd, si, zq)) = grin2_file_read(&grin2_result_file, &chrom_data, q_value_threshold) {
299
458
  xs = x;
300
459
  ys = y;
301
460
  colors_vec = c;
302
461
  point_details = pd;
303
462
  sig_indices = si;
463
+ zero_q_indices = zq;
304
464
  }
305
465
 
306
466
  // ------------------------------------------------
307
- // 3. Y-axis scaling
467
+ // 3. Calculate log_cutoff from data and update zero q-values
468
+ // ------------------------------------------------
469
+ // Convert zero_q_indices to HashSet for O(1) lookup when excluding placeholders
470
+ let zero_q_set: HashSet<usize> = zero_q_indices.iter().cloned().collect();
471
+ let log_cutoff = get_log_cutoff(&ys, hard_cap, &zero_q_set);
472
+
473
+ // ------------------------------------------------
474
+ // 4. Y-axis capping with dynamic cap
308
475
  // ------------------------------------------------
309
476
  let y_padding = png_dot_radius as f64;
310
477
  let y_min = 0.0 - y_padding;
311
- let y_max = if !ys.is_empty() {
478
+
479
+ // Dynamic y-cap calculation:
480
+ // - log_cutoff: the baseline cap (calculated from data mean)
481
+ // - max_capped_points: maximum number of points allowed above cap before raising it
482
+ // - hard_cap: absolute maximum cap regardless of data distribution
483
+ // - bin_size: size of bins for histogram approach
484
+ let max_capped_points = max_capped_points as usize;
485
+
486
+ let y_cap = calculate_dynamic_y_cap(&ys, max_capped_points, log_cutoff, hard_cap, bin_size);
487
+
488
+ let (y_max, has_capped_points) = if !ys.is_empty() {
312
489
  let max_y = ys.iter().cloned().fold(f64::MIN, f64::max);
313
- if max_y > 40.0 {
314
- let target = 40.0;
315
- let scale_factor_y = target / max_y;
316
490
 
491
+ // has_capped_points is true if any points exceed the default cap (log_cutoff)
492
+ let has_capped = max_y > log_cutoff;
493
+
494
+ // Set q=0 points (currently placeholders at 0.0) to y_cap so they appear at the top
495
+ for &idx in &zero_q_indices {
496
+ ys[idx] = y_cap;
497
+ }
498
+ for p in point_details.iter_mut() {
499
+ if p.q_value == 0.0 {
500
+ p.y = y_cap;
501
+ }
502
+ }
503
+
504
+ if max_y > y_cap {
505
+ // Clamp values above the cap
317
506
  for y in ys.iter_mut() {
318
- *y *= scale_factor_y;
507
+ if *y > y_cap {
508
+ *y = y_cap;
509
+ }
319
510
  }
320
511
  for p in point_details.iter_mut() {
321
- p.y *= scale_factor_y;
512
+ if p.y > y_cap {
513
+ p.y = y_cap;
514
+ }
322
515
  }
323
- let scaled_max = ys.iter().cloned().fold(f64::MIN, f64::max);
324
- scaled_max + 0.35 + y_padding
516
+ (y_cap + 0.35 + y_padding, has_capped)
325
517
  } else {
326
- max_y + 0.35 + y_padding
518
+ (max_y + 0.35 + y_padding, has_capped)
327
519
  }
328
520
  } else {
329
- 1.0 + y_padding
521
+ (1.0 + y_padding, false)
330
522
  };
331
523
 
332
524
  // ------------------------------------------------
@@ -386,7 +578,7 @@ fn plot_grin2_manhattan(
386
578
  }
387
579
 
388
580
  // ------------------------------------------------
389
- // 7. capture high-DPR pixel mapping for the points
581
+ // 7. Capture high-DPR pixel mapping for the points
390
582
  // we do not draw the points with plotters (will use tiny-skia for AA)
391
583
  // but use charts.backend_coord to map data->pixel in the high-DPR backend
392
584
  // ------------------------------------------------
@@ -469,6 +661,8 @@ fn plot_grin2_manhattan(
469
661
  y_min,
470
662
  y_max,
471
663
  device_pixel_ratio: dpr,
664
+ default_log_cutoff: log_cutoff,
665
+ has_capped_points,
472
666
  };
473
667
  Ok((png_data, interactive_data))
474
668
  }
@@ -495,6 +689,10 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
495
689
  let plot_height = &input_json.plot_height;
496
690
  let device_pixel_ratio = &input_json.device_pixel_ratio;
497
691
  let png_dot_radius = &input_json.png_dot_radius;
692
+ let max_capped_points = &input_json.max_capped_points;
693
+ let hard_cap = &input_json.hard_cap;
694
+ let bin_size = &input_json.bin_size;
695
+ let q_value_threshold = &input_json.q_value_threshold;
498
696
  if let Ok((base64_string, plot_data)) = plot_grin2_manhattan(
499
697
  grin2_file.clone(),
500
698
  chrom_size.clone(),
@@ -502,6 +700,10 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
502
700
  plot_height.clone(),
503
701
  device_pixel_ratio.clone(),
504
702
  png_dot_radius.clone(),
703
+ bin_size.clone(),
704
+ max_capped_points.clone(),
705
+ hard_cap.clone(),
706
+ q_value_threshold.clone(),
505
707
  ) {
506
708
  let output = Output {
507
709
  png: base64_string,
@@ -0,0 +1,152 @@
1
+ // Syntax: cd .. && cargo build --release && time cat ~/sjpp/test.txt | target/release/aichatbot
2
+ #![allow(non_snake_case)]
3
+ use anyhow::Result;
4
+ use json::JsonValue;
5
+ use schemars::JsonSchema;
6
+ use std::io;
7
+ mod aichatbot; // Importing classification agent from aichatbot.rs
8
+ mod ollama; // Importing custom rig module for invoking ollama server
9
+ mod sjprovider; // Importing custom rig module for invoking SJ GPU server
10
+ mod test_ai; // Test examples for AI chatbot
11
+
12
+ #[tokio::main]
13
+ async fn main() -> Result<()> {
14
+ let mut input = String::new();
15
+ match io::stdin().read_line(&mut input) {
16
+ // Accepting the piped input from nodejs (or command line from testing)
17
+ Ok(_n) => {
18
+ let input_json = json::parse(&input);
19
+ match input_json {
20
+ Ok(json_string) => {
21
+ //println!("json_string:{}", json_string);
22
+ let user_input_json: &JsonValue = &json_string["user_input"];
23
+ let user_input: &str;
24
+ match user_input_json.as_str() {
25
+ Some(inp) => user_input = inp,
26
+ None => panic!("user_input field is missing in input json"),
27
+ }
28
+
29
+ if user_input.len() == 0 {
30
+ panic!("The user input is empty");
31
+ }
32
+
33
+ let binpath_json: &JsonValue = &json_string["binpath"];
34
+ let binpath: &str;
35
+ match binpath_json.as_str() {
36
+ Some(inp) => binpath = inp,
37
+ None => panic!("binpath not found"),
38
+ }
39
+
40
+ let aiRoute_json: &JsonValue = &json_string["aiRoute"];
41
+ let aiRoute_str: &str;
42
+ match aiRoute_json.as_str() {
43
+ Some(inp) => aiRoute_str = inp,
44
+ None => panic!("aiRoute field is missing in input json"),
45
+ }
46
+ let airoute = String::from(binpath) + &"/../../" + &aiRoute_str;
47
+
48
+ let apilink_json: &JsonValue = &json_string["apilink"];
49
+ let apilink: &str;
50
+ match apilink_json.as_str() {
51
+ Some(inp) => apilink = inp,
52
+ None => panic!("apilink field is missing in input json"),
53
+ }
54
+
55
+ let comp_model_name_json: &JsonValue = &json_string["comp_model_name"];
56
+ let comp_model_name: &str;
57
+ match comp_model_name_json.as_str() {
58
+ Some(inp) => comp_model_name = inp,
59
+ None => panic!("comp_model_name field is missing in input json"),
60
+ }
61
+
62
+ let embedding_model_name_json: &JsonValue = &json_string["embedding_model_name"];
63
+ let embedding_model_name: &str;
64
+ match embedding_model_name_json.as_str() {
65
+ Some(inp) => embedding_model_name = inp,
66
+ None => panic!("embedding_model_name field is missing in input json"),
67
+ }
68
+
69
+ let llm_backend_name_json: &JsonValue = &json_string["llm_backend_name"];
70
+ let llm_backend_name: &str;
71
+ match llm_backend_name_json.as_str() {
72
+ Some(inp) => llm_backend_name = inp,
73
+ None => panic!("llm_backend_name field is missing in input json"),
74
+ }
75
+
76
+ let llm_backend_type: aichatbot::llm_backend;
77
+ let mut final_output: Option<String> = None;
78
+ let temperature: f64 = 0.01;
79
+ let max_new_tokens: usize = 512;
80
+ let top_p: f32 = 0.95;
81
+ if llm_backend_name != "ollama" && llm_backend_name != "SJ" {
82
+ panic!(
83
+ "This code currently supports only Ollama and SJ provider. llm_backend_name must be \"ollama\" or \"SJ\""
84
+ );
85
+ } else if llm_backend_name == "ollama".to_string() {
86
+ llm_backend_type = aichatbot::llm_backend::Ollama();
87
+ // Initialize Ollama client
88
+ let ollama_client = ollama::Client::builder()
89
+ .base_url(apilink)
90
+ .build()
91
+ .expect("Ollama server not found");
92
+ let embedding_model = ollama_client.embedding_model(embedding_model_name);
93
+ let comp_model = ollama_client.completion_model(comp_model_name);
94
+ final_output = Some(
95
+ aichatbot::classify_query_by_dataset_type(
96
+ user_input,
97
+ comp_model.clone(),
98
+ embedding_model.clone(),
99
+ &llm_backend_type,
100
+ temperature,
101
+ max_new_tokens,
102
+ top_p,
103
+ &airoute,
104
+ )
105
+ .await,
106
+ );
107
+ } else if llm_backend_name == "SJ".to_string() {
108
+ llm_backend_type = aichatbot::llm_backend::Sj();
109
+ // Initialize Sj provider client
110
+ let sj_client = sjprovider::Client::builder()
111
+ .base_url(apilink)
112
+ .build()
113
+ .expect("SJ server not found");
114
+ let embedding_model = sj_client.embedding_model(embedding_model_name);
115
+ let comp_model = sj_client.completion_model(comp_model_name);
116
+ final_output = Some(
117
+ aichatbot::classify_query_by_dataset_type(
118
+ user_input,
119
+ comp_model.clone(),
120
+ embedding_model.clone(),
121
+ &llm_backend_type,
122
+ temperature,
123
+ max_new_tokens,
124
+ top_p,
125
+ &airoute,
126
+ )
127
+ .await,
128
+ );
129
+ }
130
+
131
+ match final_output {
132
+ Some(fin_out) => {
133
+ println!("{{\"{}\":{}}}", "route", fin_out);
134
+ }
135
+ None => {
136
+ println!("{{\"{}\":\"{}\"}}", "route", "unknown");
137
+ }
138
+ }
139
+ }
140
+ Err(error) => println!("Incorrect json:{}", error),
141
+ }
142
+ }
143
+ Err(error) => println!("Piping error: {}", error),
144
+ }
145
+ Ok(())
146
+ }
147
+
148
// Container with a single `answer` string that derives `Debug` and `JsonSchema`.
// Nothing in the visible part of this file constructs it, which is why the
// dead_code lint is silenced below.
// NOTE(review): presumably the schema derived from this struct describes the expected
// LLM answer shape elsewhere in the crate - confirm against the `aichatbot` module.
+ #[derive(Debug, JsonSchema)]
149
+ #[allow(dead_code)]
150
+ struct OutputJson {
151
// The answer text.
+ pub answer: String,
152
+ }