@sjcrh/proteinpaint-rust 2.169.0 → 2.170.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/Cargo.toml +8 -4
- package/package.json +1 -1
- package/src/aichatbot.rs +231 -254
- package/src/manhattan_plot.rs +225 -23
- package/src/query_classification.rs +152 -0
- package/src/summary_agent.rs +201 -0
- package/src/test_ai.rs +79 -72
package/src/manhattan_plot.rs
CHANGED
|
@@ -3,7 +3,7 @@ use plotters::prelude::*;
|
|
|
3
3
|
use plotters::style::ShapeStyle;
|
|
4
4
|
use serde::{Deserialize, Serialize};
|
|
5
5
|
use serde_json;
|
|
6
|
-
use std::collections::HashMap;
|
|
6
|
+
use std::collections::{HashMap, HashSet};
|
|
7
7
|
use std::convert::TryInto;
|
|
8
8
|
use std::error::Error;
|
|
9
9
|
use std::fs::File;
|
|
@@ -22,6 +22,10 @@ struct Input {
|
|
|
22
22
|
plot_height: u64,
|
|
23
23
|
device_pixel_ratio: f64,
|
|
24
24
|
png_dot_radius: u64,
|
|
25
|
+
max_capped_points: u64,
|
|
26
|
+
hard_cap: f64,
|
|
27
|
+
bin_size: f64,
|
|
28
|
+
q_value_threshold: f64,
|
|
25
29
|
}
|
|
26
30
|
|
|
27
31
|
// chromosome info
|
|
@@ -58,6 +62,8 @@ struct InteractiveData {
|
|
|
58
62
|
y_min: f64,
|
|
59
63
|
y_max: f64,
|
|
60
64
|
device_pixel_ratio: f64,
|
|
65
|
+
default_log_cutoff: f64,
|
|
66
|
+
has_capped_points: bool,
|
|
61
67
|
}
|
|
62
68
|
|
|
63
69
|
#[derive(Serialize)]
|
|
@@ -78,6 +84,134 @@ fn hex_to_rgb(hex: &str) -> Option<(u8, u8, u8)> {
|
|
|
78
84
|
Some((r, g, b))
|
|
79
85
|
}
|
|
80
86
|
|
|
87
|
+
// Helper function to calculate the default log cutoff from the GRIN2 data.
//
// The cutoff is the mean of the -log10 q-values that are strictly below `hard_cap`,
// floored at 40. Without the floor the cutoff could end up so low that the histogram
// bins set up in the dynamic y-cap calculation become unreasonable.
//
// Parameters:
// - `ys`: -log10 q-values for every plotted point.
// - `hard_cap`: values at or above this are left out of the mean.
// - `exclude_indices`: indices into `ys` to skip, e.g. the 0.0 placeholders used for
//   q-values that were exactly zero, which would otherwise drag the mean down.
//
// Returns `hard_cap` when no value survives the filtering (all values are at/above
// `hard_cap` or excluded, or `ys` is empty); otherwise `max(mean, 40.0)`.
fn get_log_cutoff(ys: &[f64], hard_cap: f64, exclude_indices: &HashSet<usize>) -> f64 {
    // Lower bound applied to the mean.
    const MIN_LOG_CUTOFF: f64 = 40.0;

    // One pass over the data: accumulate (count, sum) of the kept values without
    // materializing an intermediate Vec.
    let (count, sum) = ys
        .iter()
        .enumerate()
        .filter(|&(i, &y)| y < hard_cap && !exclude_indices.contains(&i))
        .fold((0usize, 0.0_f64), |(n, s), (_, &y)| (n + 1, s + y));

    if count == 0 {
        return hard_cap;
    }

    (sum / count as f64).max(MIN_LOG_CUTOFF)
}
|
|
111
|
+
|
|
112
|
+
/// Calculates a dynamic y-axis cap for Manhattan plots to handle outliers gracefully.
///
/// # Problem
/// Manhattan plots often have a few extreme outliers (very significant q-values) that
/// compress the visual range for the majority of points. This function finds a
/// y-axis cap that:
/// - Shows most data at true scale
/// - Caps only a small number of extreme outliers
/// - Ensures visible outliers (at or below the hard cap) render at their true positions
///
/// # Algorithm
/// 1. **No outliers**: If `max_y <= default_cap`, return `max_y` (no capping needed).
/// 2. **Histogram binning**: Partition the range `(default_cap, hard_cap]` into
///    `bin_size`-wide bins. Values above `hard_cap`, and values whose bin index would
///    fall past the last bin (e.g. exactly `hard_cap`, or a trailing partial bin when the
///    range is not a multiple of `bin_size`), are counted in the last bin.
/// 3. **Walk up**: Starting from the lowest bin, stop at the first bin such that at most
///    `max_capped_points` values fall in that bin or any higher bin.
/// 4. **Preserve visible outliers**: If the highest value not above `hard_cap` lies above
///    that bin's upper edge, extend the cap to one `bin_size` past it (still bounded by
///    `hard_cap`) so those points render at their true positions.
///
/// # Parameters
/// - `ys`: All y-values (-log10 q-values) in the plot
/// - `max_capped_points`: Maximum points allowed to be clamped to the cap (e.g., 5)
/// - `default_cap`: Starting threshold; points at or below this are never capped
///   (e.g., the value computed by `get_log_cutoff`)
/// - `hard_cap`: Absolute maximum y-axis value; points above are always clamped (e.g., 200)
/// - `bin_size`: Histogram bin width on the -log10 scale (e.g., 10). A zero, negative or
///   non-finite width falls back to a single bin spanning `(default_cap, hard_cap]`
///   (or a width of 1.0 if that span is itself empty or invalid).
///
/// # Returns
/// When no value exceeds `default_cap`, the maximum of `ys` (`f64::NEG_INFINITY` if `ys`
/// is empty or contains only NaN). Otherwise a cap that is never greater than `hard_cap`.
///
/// NOTE: one counter is allocated per bin, so an extremely large
/// `(hard_cap - default_cap) / bin_size` ratio allocates accordingly.
///
/// # Example
/// With `default_cap=40`, `hard_cap=200`, `bin_size=10`, `max_capped_points=5`:
/// - If 7 points are above 40, with two at 83 and 183 and five at/above 200:
///   Returns 200, so the points at 83 and 183 display at their true positions while
///   the 5 extreme outliers are clamped to 200
fn calculate_dynamic_y_cap(
    ys: &[f64],
    max_capped_points: usize,
    default_cap: f64,
    hard_cap: f64,
    bin_size: f64,
) -> f64 {
    let span = hard_cap - default_cap;

    // A zero/negative/NaN bin width would make the bin count saturate to usize::MAX
    // (allocation failure) or produce caps below the visible outliers, so fall back to
    // a single bin covering the whole range.
    let bin_size = if bin_size.is_finite() && bin_size > 0.0 {
        bin_size
    } else if span.is_finite() && span > 0.0 {
        span
    } else {
        1.0
    };

    // Always keep at least one bin so the "last bin" below exists. The float -> usize
    // cast saturates (negative/NaN -> 0), which `.max(1)` lifts to 1.
    let num_bins = ((span / bin_size) as usize).max(1);
    let last_bin = num_bins - 1;

    let mut histogram = vec![0usize; num_bins];
    let mut max_y = f64::NEG_INFINITY;
    let mut max_y_below_hard_cap = f64::NEG_INFINITY; // Highest value that's not hard-capped
    let mut points_above_default = 0usize;

    // Single pass: find max and build histogram simultaneously.
    for &y in ys {
        if y > max_y {
            max_y = y;
        }
        if y > default_cap {
            points_above_default += 1;
            if y > hard_cap {
                histogram[last_bin] += 1;
            } else {
                // Track the max y that's at or below the hard cap.
                if y > max_y_below_hard_cap {
                    max_y_below_hard_cap = y;
                }
                // When y == hard_cap, or y lies in a trailing partial bin, the computed
                // index equals `num_bins`; clamp it into the last bin instead of
                // indexing past the end of the histogram.
                let bin_idx = (((y - default_cap) / bin_size) as usize).min(last_bin);
                histogram[bin_idx] += 1;
            }
        }
    }

    // Case 1: No points exceed default cap - use actual max.
    if max_y <= default_cap {
        return max_y;
    }

    // Walk up from default_cap towards hard_cap. Before bin `i` is subtracted,
    // `points_above` is the number of points counted in bin `i` or any higher bin.
    let mut points_above = points_above_default;
    for (i, &count) in histogram.iter().enumerate() {
        if points_above <= max_capped_points {
            // Found an acceptable number of capped points.
            let bin_upper_bound = default_cap + ((i + 1) as f64) * bin_size;

            // The cap should be:
            // 1. Above max_y_below_hard_cap (so those points render at their true position)
            // 2. At most hard_cap
            // 3. Otherwise the bin boundary
            return if max_y_below_hard_cap > bin_upper_bound {
                // There's a visible outlier above this bin - extend the cap to show it.
                (max_y_below_hard_cap + bin_size).min(hard_cap)
            } else {
                bin_upper_bound.min(hard_cap)
            };
        }
        points_above -= count;
    }

    // More than `max_capped_points` points sit in the last bin: cap at hard_cap.
    hard_cap
}
|
|
214
|
+
|
|
81
215
|
// Function to Build cumulative chromosome map
|
|
82
216
|
fn cumulative_chrom(
|
|
83
217
|
chrom_size: &HashMap<String, u64>,
|
|
@@ -120,7 +254,18 @@ fn cumulative_chrom(
|
|
|
120
254
|
fn grin2_file_read(
|
|
121
255
|
grin2_file: &str,
|
|
122
256
|
chrom_data: &HashMap<String, ChromInfo>,
|
|
123
|
-
|
|
257
|
+
q_value_threshold: f64,
|
|
258
|
+
) -> Result<
|
|
259
|
+
(
|
|
260
|
+
Vec<u64>,
|
|
261
|
+
Vec<f64>,
|
|
262
|
+
Vec<String>,
|
|
263
|
+
Vec<PointDetail>,
|
|
264
|
+
Vec<usize>,
|
|
265
|
+
Vec<usize>,
|
|
266
|
+
),
|
|
267
|
+
Box<dyn Error>,
|
|
268
|
+
> {
|
|
124
269
|
// Default colours
|
|
125
270
|
let mut colors: HashMap<String, String> = HashMap::new();
|
|
126
271
|
colors.insert("gain".into(), "#FF4444".into());
|
|
@@ -134,6 +279,7 @@ fn grin2_file_read(
|
|
|
134
279
|
let mut colors_vec = Vec::new();
|
|
135
280
|
let mut point_details = Vec::new();
|
|
136
281
|
let mut sig_indices: Vec<usize> = Vec::new();
|
|
282
|
+
let mut zero_q_indices: Vec<usize> = Vec::new();
|
|
137
283
|
|
|
138
284
|
let grin2_file = File::open(grin2_file).expect("Failed to open grin2_result_file");
|
|
139
285
|
let mut reader = BufReader::new(grin2_file);
|
|
@@ -217,13 +363,20 @@ fn grin2_file_read(
|
|
|
217
363
|
Some(q) => q,
|
|
218
364
|
None => continue,
|
|
219
365
|
};
|
|
220
|
-
|
|
221
|
-
|
|
222
|
-
Ok(v) if v > 0.0 => v,
|
|
223
|
-
Ok(v) if v == 0.0 => 1e-300, // Treat exact 0 as ~1e-300 so we can still show q-values that are 0 and not filter them out
|
|
366
|
+
let original_q_val: f64 = match q_val_str.parse() {
|
|
367
|
+
Ok(v) if v >= 0.0 => v,
|
|
224
368
|
_ => continue,
|
|
225
369
|
};
|
|
226
|
-
|
|
370
|
+
|
|
371
|
+
// Use a placeholder for zero q-values - these will be updated later
|
|
372
|
+
// after we calculate the dynamic y_cap from the full dataset
|
|
373
|
+
let neg_log10_q = if original_q_val == 0.0 {
|
|
374
|
+
zero_q_indices.push(mut_num);
|
|
375
|
+
0.0 // Placeholder - will be set to y_cap later in plot_grin2_manhattan
|
|
376
|
+
} else {
|
|
377
|
+
-original_q_val.log10()
|
|
378
|
+
};
|
|
379
|
+
|
|
227
380
|
let n_subj_count: Option<i64> = n_idx_opt
|
|
228
381
|
.and_then(|i| fields.get(i))
|
|
229
382
|
.and_then(|s| s.parse::<i64>().ok());
|
|
@@ -234,7 +387,8 @@ fn grin2_file_read(
|
|
|
234
387
|
colors_vec.push(color.clone());
|
|
235
388
|
|
|
236
389
|
// only add significant points for interactivity
|
|
237
|
-
|
|
390
|
+
// We check against the original q-value here so we send back the correct values instead of the 1e-300 used for log transform
|
|
391
|
+
if original_q_val <= q_value_threshold {
|
|
238
392
|
point_details.push(PointDetail {
|
|
239
393
|
x: x_pos,
|
|
240
394
|
y: neg_log10_q,
|
|
@@ -245,7 +399,7 @@ fn grin2_file_read(
|
|
|
245
399
|
start: gene_start,
|
|
246
400
|
end: gene_end,
|
|
247
401
|
pos: gene_start,
|
|
248
|
-
q_value:
|
|
402
|
+
q_value: original_q_val,
|
|
249
403
|
nsubj: n_subj_count,
|
|
250
404
|
pixel_x: 0.0,
|
|
251
405
|
pixel_y: 0.0,
|
|
@@ -256,7 +410,7 @@ fn grin2_file_read(
|
|
|
256
410
|
}
|
|
257
411
|
}
|
|
258
412
|
|
|
259
|
-
Ok((xs, ys, colors_vec, point_details, sig_indices))
|
|
413
|
+
Ok((xs, ys, colors_vec, point_details, sig_indices, zero_q_indices))
|
|
260
414
|
}
|
|
261
415
|
|
|
262
416
|
// Function to create the GRIN2 Manhattan plot
|
|
@@ -267,6 +421,10 @@ fn plot_grin2_manhattan(
|
|
|
267
421
|
plot_height: u64,
|
|
268
422
|
device_pixel_ratio: f64,
|
|
269
423
|
png_dot_radius: u64,
|
|
424
|
+
bin_size: f64,
|
|
425
|
+
max_capped_points: u64,
|
|
426
|
+
hard_cap: f64,
|
|
427
|
+
q_value_threshold: f64,
|
|
270
428
|
) -> Result<(String, InteractiveData), Box<dyn Error>> {
|
|
271
429
|
// ------------------------------------------------
|
|
272
430
|
// 1. Build cumulative chromosome map
|
|
@@ -294,39 +452,73 @@ fn plot_grin2_manhattan(
|
|
|
294
452
|
let mut colors_vec = Vec::new();
|
|
295
453
|
let mut point_details = Vec::new();
|
|
296
454
|
let mut sig_indices = Vec::new();
|
|
455
|
+
let mut zero_q_indices: Vec<usize> = Vec::new();
|
|
297
456
|
|
|
298
|
-
if let Ok((x, y, c, pd, si)) = grin2_file_read(&grin2_result_file, &chrom_data) {
|
|
457
|
+
if let Ok((x, y, c, pd, si, zq)) = grin2_file_read(&grin2_result_file, &chrom_data, q_value_threshold) {
|
|
299
458
|
xs = x;
|
|
300
459
|
ys = y;
|
|
301
460
|
colors_vec = c;
|
|
302
461
|
point_details = pd;
|
|
303
462
|
sig_indices = si;
|
|
463
|
+
zero_q_indices = zq;
|
|
304
464
|
}
|
|
305
465
|
|
|
306
466
|
// ------------------------------------------------
|
|
307
|
-
// 3.
|
|
467
|
+
// 3. Calculate log_cutoff from data and update zero q-values
|
|
468
|
+
// ------------------------------------------------
|
|
469
|
+
// Convert zero_q_indices to HashSet for O(1) lookup when excluding placeholders
|
|
470
|
+
let zero_q_set: HashSet<usize> = zero_q_indices.iter().cloned().collect();
|
|
471
|
+
let log_cutoff = get_log_cutoff(&ys, hard_cap, &zero_q_set);
|
|
472
|
+
|
|
473
|
+
// ------------------------------------------------
|
|
474
|
+
// 4. Y-axis capping with dynamic cap
|
|
308
475
|
// ------------------------------------------------
|
|
309
476
|
let y_padding = png_dot_radius as f64;
|
|
310
477
|
let y_min = 0.0 - y_padding;
|
|
311
|
-
|
|
478
|
+
|
|
479
|
+
// Dynamic y-cap calculation:
|
|
480
|
+
// - log_cutoff: the baseline cap (calculated from data mean)
|
|
481
|
+
// - max_capped_points: maximum number of points allowed above cap before raising it
|
|
482
|
+
// - hard_cap: absolute maximum cap regardless of data distribution
|
|
483
|
+
// - bin_size: size of bins for histogram approach
|
|
484
|
+
let max_capped_points = max_capped_points as usize;
|
|
485
|
+
|
|
486
|
+
let y_cap = calculate_dynamic_y_cap(&ys, max_capped_points, log_cutoff, hard_cap, bin_size);
|
|
487
|
+
|
|
488
|
+
let (y_max, has_capped_points) = if !ys.is_empty() {
|
|
312
489
|
let max_y = ys.iter().cloned().fold(f64::MIN, f64::max);
|
|
313
|
-
if max_y > 40.0 {
|
|
314
|
-
let target = 40.0;
|
|
315
|
-
let scale_factor_y = target / max_y;
|
|
316
490
|
|
|
491
|
+
// has_capped_points is true if any points exceed the default cap (log_cutoff)
|
|
492
|
+
let has_capped = max_y > log_cutoff;
|
|
493
|
+
|
|
494
|
+
// Set q=0 points (currently placeholders at 0.0) to y_cap so they appear at the top
|
|
495
|
+
for &idx in &zero_q_indices {
|
|
496
|
+
ys[idx] = y_cap;
|
|
497
|
+
}
|
|
498
|
+
for p in point_details.iter_mut() {
|
|
499
|
+
if p.q_value == 0.0 {
|
|
500
|
+
p.y = y_cap;
|
|
501
|
+
}
|
|
502
|
+
}
|
|
503
|
+
|
|
504
|
+
if max_y > y_cap {
|
|
505
|
+
// Clamp values above the cap
|
|
317
506
|
for y in ys.iter_mut() {
|
|
318
|
-
*y
|
|
507
|
+
if *y > y_cap {
|
|
508
|
+
*y = y_cap;
|
|
509
|
+
}
|
|
319
510
|
}
|
|
320
511
|
for p in point_details.iter_mut() {
|
|
321
|
-
p.y
|
|
512
|
+
if p.y > y_cap {
|
|
513
|
+
p.y = y_cap;
|
|
514
|
+
}
|
|
322
515
|
}
|
|
323
|
-
|
|
324
|
-
scaled_max + 0.35 + y_padding
|
|
516
|
+
(y_cap + 0.35 + y_padding, has_capped)
|
|
325
517
|
} else {
|
|
326
|
-
max_y + 0.35 + y_padding
|
|
518
|
+
(max_y + 0.35 + y_padding, has_capped)
|
|
327
519
|
}
|
|
328
520
|
} else {
|
|
329
|
-
1.0 + y_padding
|
|
521
|
+
(1.0 + y_padding, false)
|
|
330
522
|
};
|
|
331
523
|
|
|
332
524
|
// ------------------------------------------------
|
|
@@ -386,7 +578,7 @@ fn plot_grin2_manhattan(
|
|
|
386
578
|
}
|
|
387
579
|
|
|
388
580
|
// ------------------------------------------------
|
|
389
|
-
// 7.
|
|
581
|
+
// 7. Capture high-DPR pixel mapping for the points
|
|
390
582
|
// we do not draw the points with plotters (will use tiny-skia for AA)
|
|
391
583
|
// but use charts.backend_coord to map data->pixel in the high-DPR backend
|
|
392
584
|
// ------------------------------------------------
|
|
@@ -469,6 +661,8 @@ fn plot_grin2_manhattan(
|
|
|
469
661
|
y_min,
|
|
470
662
|
y_max,
|
|
471
663
|
device_pixel_ratio: dpr,
|
|
664
|
+
default_log_cutoff: log_cutoff,
|
|
665
|
+
has_capped_points,
|
|
472
666
|
};
|
|
473
667
|
Ok((png_data, interactive_data))
|
|
474
668
|
}
|
|
@@ -495,6 +689,10 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
|
|
|
495
689
|
let plot_height = &input_json.plot_height;
|
|
496
690
|
let device_pixel_ratio = &input_json.device_pixel_ratio;
|
|
497
691
|
let png_dot_radius = &input_json.png_dot_radius;
|
|
692
|
+
let max_capped_points = &input_json.max_capped_points;
|
|
693
|
+
let hard_cap = &input_json.hard_cap;
|
|
694
|
+
let bin_size = &input_json.bin_size;
|
|
695
|
+
let q_value_threshold = &input_json.q_value_threshold;
|
|
498
696
|
if let Ok((base64_string, plot_data)) = plot_grin2_manhattan(
|
|
499
697
|
grin2_file.clone(),
|
|
500
698
|
chrom_size.clone(),
|
|
@@ -502,6 +700,10 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
|
|
|
502
700
|
plot_height.clone(),
|
|
503
701
|
device_pixel_ratio.clone(),
|
|
504
702
|
png_dot_radius.clone(),
|
|
703
|
+
bin_size.clone(),
|
|
704
|
+
max_capped_points.clone(),
|
|
705
|
+
hard_cap.clone(),
|
|
706
|
+
q_value_threshold.clone(),
|
|
505
707
|
) {
|
|
506
708
|
let output = Output {
|
|
507
709
|
png: base64_string,
|
|
@@ -0,0 +1,152 @@
|
|
|
1
|
+
// Syntax: cd .. && cargo build --release && time cat ~/sjpp/test.txt | target/release/aichatbot
|
|
2
|
+
#![allow(non_snake_case)]
|
|
3
|
+
use anyhow::Result;
|
|
4
|
+
use json::JsonValue;
|
|
5
|
+
use schemars::JsonSchema;
|
|
6
|
+
use std::io;
|
|
7
|
+
mod aichatbot; // Importing classification agent from aichatbot.rs
|
|
8
|
+
mod ollama; // Importing custom rig module for invoking ollama server
|
|
9
|
+
mod sjprovider; // Importing custom rig module for invoking SJ GPU server
|
|
10
|
+
mod test_ai; // Test examples for AI chatbot
|
|
11
|
+
|
|
12
|
+
#[tokio::main]
|
|
13
|
+
async fn main() -> Result<()> {
|
|
14
|
+
let mut input = String::new();
|
|
15
|
+
match io::stdin().read_line(&mut input) {
|
|
16
|
+
// Accepting the piped input from nodejs (or command line from testing)
|
|
17
|
+
Ok(_n) => {
|
|
18
|
+
let input_json = json::parse(&input);
|
|
19
|
+
match input_json {
|
|
20
|
+
Ok(json_string) => {
|
|
21
|
+
//println!("json_string:{}", json_string);
|
|
22
|
+
let user_input_json: &JsonValue = &json_string["user_input"];
|
|
23
|
+
let user_input: &str;
|
|
24
|
+
match user_input_json.as_str() {
|
|
25
|
+
Some(inp) => user_input = inp,
|
|
26
|
+
None => panic!("user_input field is missing in input json"),
|
|
27
|
+
}
|
|
28
|
+
|
|
29
|
+
if user_input.len() == 0 {
|
|
30
|
+
panic!("The user input is empty");
|
|
31
|
+
}
|
|
32
|
+
|
|
33
|
+
let binpath_json: &JsonValue = &json_string["binpath"];
|
|
34
|
+
let binpath: &str;
|
|
35
|
+
match binpath_json.as_str() {
|
|
36
|
+
Some(inp) => binpath = inp,
|
|
37
|
+
None => panic!("binpath not found"),
|
|
38
|
+
}
|
|
39
|
+
|
|
40
|
+
let aiRoute_json: &JsonValue = &json_string["aiRoute"];
|
|
41
|
+
let aiRoute_str: &str;
|
|
42
|
+
match aiRoute_json.as_str() {
|
|
43
|
+
Some(inp) => aiRoute_str = inp,
|
|
44
|
+
None => panic!("aiRoute field is missing in input json"),
|
|
45
|
+
}
|
|
46
|
+
let airoute = String::from(binpath) + &"/../../" + &aiRoute_str;
|
|
47
|
+
|
|
48
|
+
let apilink_json: &JsonValue = &json_string["apilink"];
|
|
49
|
+
let apilink: &str;
|
|
50
|
+
match apilink_json.as_str() {
|
|
51
|
+
Some(inp) => apilink = inp,
|
|
52
|
+
None => panic!("apilink field is missing in input json"),
|
|
53
|
+
}
|
|
54
|
+
|
|
55
|
+
let comp_model_name_json: &JsonValue = &json_string["comp_model_name"];
|
|
56
|
+
let comp_model_name: &str;
|
|
57
|
+
match comp_model_name_json.as_str() {
|
|
58
|
+
Some(inp) => comp_model_name = inp,
|
|
59
|
+
None => panic!("comp_model_name field is missing in input json"),
|
|
60
|
+
}
|
|
61
|
+
|
|
62
|
+
let embedding_model_name_json: &JsonValue = &json_string["embedding_model_name"];
|
|
63
|
+
let embedding_model_name: &str;
|
|
64
|
+
match embedding_model_name_json.as_str() {
|
|
65
|
+
Some(inp) => embedding_model_name = inp,
|
|
66
|
+
None => panic!("embedding_model_name field is missing in input json"),
|
|
67
|
+
}
|
|
68
|
+
|
|
69
|
+
let llm_backend_name_json: &JsonValue = &json_string["llm_backend_name"];
|
|
70
|
+
let llm_backend_name: &str;
|
|
71
|
+
match llm_backend_name_json.as_str() {
|
|
72
|
+
Some(inp) => llm_backend_name = inp,
|
|
73
|
+
None => panic!("llm_backend_name field is missing in input json"),
|
|
74
|
+
}
|
|
75
|
+
|
|
76
|
+
let llm_backend_type: aichatbot::llm_backend;
|
|
77
|
+
let mut final_output: Option<String> = None;
|
|
78
|
+
let temperature: f64 = 0.01;
|
|
79
|
+
let max_new_tokens: usize = 512;
|
|
80
|
+
let top_p: f32 = 0.95;
|
|
81
|
+
if llm_backend_name != "ollama" && llm_backend_name != "SJ" {
|
|
82
|
+
panic!(
|
|
83
|
+
"This code currently supports only Ollama and SJ provider. llm_backend_name must be \"ollama\" or \"SJ\""
|
|
84
|
+
);
|
|
85
|
+
} else if llm_backend_name == "ollama".to_string() {
|
|
86
|
+
llm_backend_type = aichatbot::llm_backend::Ollama();
|
|
87
|
+
// Initialize Ollama client
|
|
88
|
+
let ollama_client = ollama::Client::builder()
|
|
89
|
+
.base_url(apilink)
|
|
90
|
+
.build()
|
|
91
|
+
.expect("Ollama server not found");
|
|
92
|
+
let embedding_model = ollama_client.embedding_model(embedding_model_name);
|
|
93
|
+
let comp_model = ollama_client.completion_model(comp_model_name);
|
|
94
|
+
final_output = Some(
|
|
95
|
+
aichatbot::classify_query_by_dataset_type(
|
|
96
|
+
user_input,
|
|
97
|
+
comp_model.clone(),
|
|
98
|
+
embedding_model.clone(),
|
|
99
|
+
&llm_backend_type,
|
|
100
|
+
temperature,
|
|
101
|
+
max_new_tokens,
|
|
102
|
+
top_p,
|
|
103
|
+
&airoute,
|
|
104
|
+
)
|
|
105
|
+
.await,
|
|
106
|
+
);
|
|
107
|
+
} else if llm_backend_name == "SJ".to_string() {
|
|
108
|
+
llm_backend_type = aichatbot::llm_backend::Sj();
|
|
109
|
+
// Initialize Sj provider client
|
|
110
|
+
let sj_client = sjprovider::Client::builder()
|
|
111
|
+
.base_url(apilink)
|
|
112
|
+
.build()
|
|
113
|
+
.expect("SJ server not found");
|
|
114
|
+
let embedding_model = sj_client.embedding_model(embedding_model_name);
|
|
115
|
+
let comp_model = sj_client.completion_model(comp_model_name);
|
|
116
|
+
final_output = Some(
|
|
117
|
+
aichatbot::classify_query_by_dataset_type(
|
|
118
|
+
user_input,
|
|
119
|
+
comp_model.clone(),
|
|
120
|
+
embedding_model.clone(),
|
|
121
|
+
&llm_backend_type,
|
|
122
|
+
temperature,
|
|
123
|
+
max_new_tokens,
|
|
124
|
+
top_p,
|
|
125
|
+
&airoute,
|
|
126
|
+
)
|
|
127
|
+
.await,
|
|
128
|
+
);
|
|
129
|
+
}
|
|
130
|
+
|
|
131
|
+
match final_output {
|
|
132
|
+
Some(fin_out) => {
|
|
133
|
+
println!("{{\"{}\":{}}}", "route", fin_out);
|
|
134
|
+
}
|
|
135
|
+
None => {
|
|
136
|
+
println!("{{\"{}\":\"{}\"}}", "route", "unknown");
|
|
137
|
+
}
|
|
138
|
+
}
|
|
139
|
+
}
|
|
140
|
+
Err(error) => println!("Incorrect json:{}", error),
|
|
141
|
+
}
|
|
142
|
+
}
|
|
143
|
+
Err(error) => println!("Piping error: {}", error),
|
|
144
|
+
}
|
|
145
|
+
Ok(())
|
|
146
|
+
}
|
|
147
|
+
|
|
148
|
+
#[derive(Debug, JsonSchema)]
|
|
149
|
+
#[allow(dead_code)]
|
|
150
|
+
struct OutputJson {
|
|
151
|
+
pub answer: String,
|
|
152
|
+
}
|