gte 0.0.3 → 0.0.5
This diff shows the published contents of two package versions as they appear in their public registry. It is provided for informational purposes only and reflects the changes between those released versions.
- checksums.yaml +4 -4
- data/README.md +122 -10
- data/Rakefile +8 -0
- data/VERSION +1 -1
- data/ext/gte/Cargo.toml +1 -1
- data/ext/gte/src/embedder.rs +34 -268
- data/ext/gte/src/lib.rs +3 -0
- data/ext/gte/src/model_profile.rs +179 -0
- data/ext/gte/src/pipeline.rs +60 -0
- data/ext/gte/src/postprocess.rs +25 -2
- data/ext/gte/src/reranker.rs +120 -0
- data/ext/gte/src/ruby_embedder.rs +165 -7
- data/ext/gte/src/session.rs +9 -39
- data/ext/gte/src/tokenizer.rs +21 -2
- data/ext/gte/tests/inference_integration_test.rs +8 -4
- data/ext/gte/tests/postprocess_unit_test.rs +17 -0
- data/ext/gte/tests/tokenizer_unit_test.rs +4 -1
- data/lib/gte/config.rb +15 -0
- data/lib/gte/model.rb +35 -0
- data/lib/gte/reranker.rb +54 -0
- data/lib/gte/version.rb +5 -0
- data/lib/gte.rb +27 -19
- metadata +10 -2
data/ext/gte/src/model_profile.rs
ADDED
@@ -0,0 +1,179 @@
+use crate::error::{GteError, Result};
+use crate::model_config::ExtractorMode;
+use ort::session::Session;
+use std::path::{Path, PathBuf};
+
+const SUPPORTED_INPUTS: [&str; 3] = ["input_ids", "attention_mask", "token_type_ids"];
+
+pub fn resolve_tokenizer_path(dir: &Path) -> Result<PathBuf> {
+    let tokenizer_path = dir.join("tokenizer.json");
+    if !tokenizer_path.exists() {
+        return Err(GteError::Tokenizer(format!(
+            "tokenizer.json not found in {}",
+            dir.display()
+        )));
+    }
+    Ok(tokenizer_path)
+}
+
+pub fn resolve_named_model(dir: &Path, name: &str) -> Result<PathBuf> {
+    let candidates = [dir.join("onnx").join(name), dir.join(name)];
+    for path in &candidates {
+        if path.exists() {
+            return Ok(path.clone());
+        }
+    }
+    Err(GteError::Inference(format!(
+        "model '{}' not found in {} (checked onnx/{0} and {0})",
+        name,
+        dir.display()
+    )))
+}
+
+pub fn resolve_default_text_model(dir: &Path) -> Result<PathBuf> {
+    let candidates = [
+        dir.join("onnx").join("text_model.onnx"),
+        dir.join("text_model.onnx"),
+        dir.join("onnx").join("model.onnx"),
+        dir.join("model.onnx"),
+    ];
+    for path in &candidates {
+        if path.exists() {
+            return Ok(path.clone());
+        }
+    }
+    Err(GteError::Inference(format!(
+        "no ONNX model found in {} (checked text_model.onnx and model.onnx)",
+        dir.display()
+    )))
+}
+
+pub fn read_max_length(dir: &Path) -> usize {
+    (|| -> Option<usize> {
+        let contents = std::fs::read_to_string(dir.join("tokenizer_config.json")).ok()?;
+        let json: serde_json::Value = serde_json::from_str(&contents).ok()?;
+        let v = json.get("model_max_length")?;
+        let n = v.as_u64().or_else(|| {
+            v.as_f64()
+                .filter(|&f| f > 0.0 && f < 1e15)
+                .map(|f| f as u64)
+        })?;
+        Some((n as usize).min(8192))
+    })()
+    .unwrap_or(512)
+}
+
+pub fn validate_supported_text_inputs(session: &Session, api_label: &str) -> Result<()> {
+    let unsupported: Vec<String> = session
+        .inputs
+        .iter()
+        .filter(|i| !SUPPORTED_INPUTS.contains(&i.name.as_str()))
+        .map(|i| i.name.clone())
+        .collect();
+
+    if unsupported.is_empty() {
+        return Ok(());
+    }
+
+    let mut message = format!(
+        "unsupported model inputs for {} API: {}",
+        api_label,
+        unsupported.join(", ")
+    );
+    if unsupported.iter().any(|n| n == "pixel_values") {
+        message.push_str(
+            ". This looks like a multimodal graph. Provide a text-only export (for example onnx/text_model.onnx).",
+        );
+    } else {
+        message.push_str(". Supported inputs are: input_ids, attention_mask, token_type_ids.");
+    }
+    Err(GteError::Inference(message))
+}
+
+pub fn has_input(session: &Session, name: &str) -> bool {
+    session.inputs.iter().any(|input| input.name == name)
+}
+
+fn output_name_matches(name: &str, preferred: &str) -> bool {
+    let lower = name.to_ascii_lowercase();
+    lower == preferred || lower.ends_with(&format!("/{}", preferred))
+}
+
+pub fn select_output_tensor(
+    session: &Session,
+    requested: Option<&str>,
+    preferred_outputs: &[&str],
+) -> Result<String> {
+    if let Some(requested_name) = requested.map(str::trim).filter(|name| !name.is_empty()) {
+        if let Some(output) = session
+            .outputs
+            .iter()
+            .find(|o| output_name_matches(o.name.as_str(), requested_name))
+        {
+            return Ok(output.name.clone());
+        }
+        let available = session
+            .outputs
+            .iter()
+            .map(|o| o.name.as_str())
+            .collect::<Vec<_>>()
+            .join(", ");
+        return Err(GteError::Inference(format!(
+            "requested output tensor '{}' not found in model outputs: {}",
+            requested_name, available
+        )));
+    }
+
+    for preferred in preferred_outputs {
+        if let Some(output) = session
+            .outputs
+            .iter()
+            .find(|o| output_name_matches(o.name.as_str(), preferred))
+        {
+            return Ok(output.name.clone());
+        }
+    }
+
+    session
+        .outputs
+        .first()
+        .map(|o| o.name.clone())
+        .ok_or_else(|| GteError::Inference("model has no outputs".into()))
+}
+
+fn output_basename(name: &str) -> &str {
+    name.rsplit('/').next().unwrap_or(name)
+}
+
+pub fn infer_extraction_mode(session: &Session, output_tensor: &str) -> Result<ExtractorMode> {
+    let output = session
+        .outputs
+        .iter()
+        .find(|o| o.name == output_tensor)
+        .ok_or_else(|| {
+            GteError::Inference(format!(
+                "output tensor '{}' not found in model outputs",
+                output_tensor
+            ))
+        })?;
+
+    let ndims = match &output.output_type {
+        ort::value::ValueType::Tensor { dimensions, .. } => dimensions.len(),
+        other => {
+            return Err(GteError::Inference(format!(
+                "output is not a tensor: {:?}",
+                other
+            )))
+        }
+    };
+
+    match (output_basename(output_tensor), ndims) {
+        ("last_hidden_state", 3) => Ok(ExtractorMode::MeanPool),
+        (_, 2) => Ok(ExtractorMode::Raw),
+        (_, 3) => Ok(ExtractorMode::MeanPool),
+        (_, n) => Err(GteError::Inference(format!(
+            "unexpected output tensor rank {} for '{}': expected 2 (Raw) or 3 (MeanPool)",
+            n, output_tensor
+        ))),
+    }
+}
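For context, `select_output_tensor` above matches a requested or preferred name case-insensitively, either exactly or as the final `/`-separated segment of a namespaced graph output. A minimal standalone sketch of that matching rule (the inputs are illustrative, not from the gem):

```rust
// Sketch of the matching rule used by select_output_tensor: a name matches
// either exactly (after lowercasing) or as the last path segment of a
// namespaced ONNX graph output.
fn output_name_matches(name: &str, preferred: &str) -> bool {
    let lower = name.to_ascii_lowercase();
    lower == preferred || lower.ends_with(&format!("/{}", preferred))
}

fn main() {
    assert!(output_name_matches("logits", "logits")); // exact match
    assert!(output_name_matches("classifier/Logits", "logits")); // namespaced suffix
    assert!(!output_name_matches("logits_mask", "logits")); // not a suffix segment
    println!("matching rule behaves as described");
}
```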
data/ext/gte/src/pipeline.rs
ADDED
@@ -0,0 +1,60 @@
+use crate::error::{GteError, Result};
+use crate::tokenizer::Tokenized;
+use ndarray::ArrayView2;
+use ort::session::SessionInputValue;
+use ort::value::Value;
+
+pub struct InputTensors<'a> {
+    pub inputs: Vec<(&'static str, SessionInputValue<'a>)>,
+    pub attention_mask: ArrayView2<'a, i64>,
+}
+
+impl<'a> InputTensors<'a> {
+    pub fn from_tokenized(tokenized: &'a Tokenized, with_attention_mask: bool) -> Result<Self> {
+        let input_ids_view: ArrayView2<'_, i64> = ArrayView2::from_shape(
+            (tokenized.rows, tokenized.cols),
+            tokenized.input_ids.as_slice(),
+        )?;
+        let attention_mask: ArrayView2<'_, i64> = ArrayView2::from_shape(
+            (tokenized.rows, tokenized.cols),
+            tokenized.attn_masks.as_slice(),
+        )?;
+
+        let mut inputs = Vec::with_capacity(2 + usize::from(tokenized.type_ids.is_some()));
+        inputs.push((
+            "input_ids",
+            SessionInputValue::from(Value::from_array(input_ids_view)?),
+        ));
+
+        if with_attention_mask {
+            inputs.push((
+                "attention_mask",
+                SessionInputValue::from(Value::from_array(attention_mask)?),
+            ));
+        }
+
+        if let Some(type_ids) = tokenized.type_ids.as_deref() {
+            let type_ids_view: ArrayView2<'_, i64> =
+                ArrayView2::from_shape((tokenized.rows, tokenized.cols), type_ids)?;
+            inputs.push((
+                "token_type_ids",
+                SessionInputValue::from(Value::from_array(type_ids_view)?),
+            ));
+        }
+
+        Ok(Self {
+            inputs,
+            attention_mask,
+        })
+    }
+}
+
+pub fn extract_output_tensor<'a>(
+    outputs: &'a ort::session::SessionOutputs<'a, 'a>,
+    output_name: &str,
+) -> Result<ndarray::CowArray<'a, f32, ndarray::IxDyn>> {
+    let tensor_value = outputs.get(output_name).ok_or_else(|| {
+        GteError::Inference(format!("output tensor '{}' not found in model outputs", output_name))
+    })?;
+    Ok(tensor_value.try_extract_tensor::<f32>()?.into())
+}
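A small sketch of the zero-copy viewing that `InputTensors::from_tokenized` relies on: `ndarray::ArrayView2::from_shape` reinterprets a flat row-major buffer as a 2-D array and reports a mismatched shape as an error, which is why the constructor can propagate failures with `?`. The buffer values below are made up:

```rust
use ndarray::ArrayView2;

fn main() -> Result<(), ndarray::ShapeError> {
    // Two "rows" of three token ids each, stored flat in row-major order,
    // the same layout the tokenizer buffers use. Values are illustrative.
    let (rows, cols) = (2, 3);
    let input_ids: Vec<i64> = vec![101, 7592, 102, 101, 2088, 102];

    // Zero-copy reinterpretation of the flat buffer as a 2-D view.
    let view: ArrayView2<'_, i64> = ArrayView2::from_shape((rows, cols), input_ids.as_slice())?;
    assert_eq!(view[[1, 0]], 101);

    // A shape that disagrees with the buffer length is an Err, not a panic.
    assert!(ArrayView2::<i64>::from_shape((3, 3), input_ids.as_slice()).is_err());
    Ok(())
}
```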
data/ext/gte/src/postprocess.rs
CHANGED
@@ -75,6 +75,12 @@ pub fn normalize_l2(mut embeddings: Array2<f32>) -> Array2<f32> {
     embeddings
 }
 
+pub fn sigmoid_scores(mut scores: ndarray::ArrayViewMut1<'_, f32>) {
+    scores.map_inplace(|value| {
+        *value = 1.0 / (1.0 + (-*value).exp());
+    });
+}
+
 fn mean_pool_contiguous(
     hidden: &[f32],
     attention_mask: &[i64],
@@ -87,10 +93,27 @@ fn mean_pool_contiguous(
         let mask_base = batch_index * seq;
         let hidden_base = batch_index * seq * dim;
        let output_row = &mut output[batch_index * dim..(batch_index + 1) * dim];
+        let mask_row = &attention_mask[mask_base..mask_base + seq];
+
+        if mask_row.iter().all(|&weight| weight == 1) {
+            for token_index in 0..seq {
+                let token_base = hidden_base + token_index * dim;
+                for dim_index in 0..dim {
+                    output_row[dim_index] += hidden[token_base + dim_index];
+                }
+            }
+
+            let inverse = (seq as f32).recip();
+            for value in output_row {
+                *value *= inverse;
+            }
+            continue;
+        }
+
         let mut weight_sum = 0.0f32;
 
-        for token_index in 0..seq {
-            let weight = attention_mask[mask_base + token_index];
+        for (token_index, &weight_raw) in mask_row.iter().enumerate() {
+            let weight = weight_raw;
             if weight <= 0 {
                 continue;
             }
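The new `sigmoid_scores` squashes raw logits into (0, 1), and the mean-pool fast path added above is just an unweighted average once every mask weight is 1. A self-contained sketch of both facts (the score values are illustrative):

```rust
use ndarray::{array, ArrayViewMut1};

// Same logistic squashing as the new postprocess::sigmoid_scores.
fn sigmoid_scores(mut scores: ArrayViewMut1<'_, f32>) {
    scores.map_inplace(|value| {
        *value = 1.0 / (1.0 + (-*value).exp());
    });
}

fn main() {
    let mut scores = array![0.0_f32, 2.0, -2.0];
    sigmoid_scores(scores.view_mut());
    assert!((scores[0] - 0.5).abs() < 1e-6); // sigmoid(0) = 0.5
    assert!(scores[1] > 0.5 && scores[2] < 0.5); // order-preserving

    // Fast-path intuition: with an all-ones mask, the weighted mean
    // (sum of weight * hidden / sum of weights) collapses to sum / seq.
    let hidden = [1.0_f32, 3.0, 5.0];
    let mean = hidden.iter().sum::<f32>() / hidden.len() as f32;
    assert!((mean - 3.0).abs() < f32::EPSILON);
}
```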
data/ext/gte/src/reranker.rs
ADDED
@@ -0,0 +1,120 @@
+use crate::error::{GteError, Result};
+use crate::model_profile::{
+    has_input, read_max_length, resolve_default_text_model, resolve_named_model, resolve_tokenizer_path,
+    select_output_tensor, validate_supported_text_inputs,
+};
+use crate::pipeline::{extract_output_tensor, InputTensors};
+use crate::postprocess::sigmoid_scores;
+use crate::session::build_session;
+use crate::tokenizer::Tokenizer;
+use ndarray::Array1;
+use ort::session::Session;
+use std::path::Path;
+
+#[derive(Debug, Clone)]
+struct RerankerConfig {
+    max_length: usize,
+    output_tensor: String,
+    with_type_ids: bool,
+    with_attention_mask: bool,
+}
+
+pub struct Reranker {
+    tokenizer: Tokenizer,
+    session: Session,
+    config: RerankerConfig,
+}
+
+impl Reranker {
+    pub fn from_dir<P: AsRef<Path>>(
+        dir: P,
+        num_threads: usize,
+        optimization_level: u8,
+        model_name: Option<&str>,
+        output_tensor_override: Option<&str>,
+        max_length_override: Option<usize>,
+    ) -> Result<Self> {
+        let dir = dir.as_ref();
+        let tokenizer_path = resolve_tokenizer_path(dir)?;
+        let model_path = match model_name.filter(|s| !s.is_empty()) {
+            Some(name) => resolve_named_model(dir, name)?,
+            None => resolve_default_text_model(dir)?,
+        };
+
+        let max_length = if let Some(override_value) = max_length_override {
+            if override_value == 0 {
+                return Err(GteError::Inference(
+                    "max_length override must be greater than 0".to_string(),
+                ));
+            }
+            override_value
+        } else {
+            read_max_length(dir)
+        };
+
+        let probe_config = crate::model_config::ModelConfig {
+            max_length,
+            output_tensor: String::new(),
+            mode: crate::model_config::ExtractorMode::Raw,
+            with_type_ids: false,
+            with_attention_mask: true,
+            num_threads,
+            optimization_level,
+        };
+        let session = build_session(&model_path, &probe_config)?;
+
+        validate_supported_text_inputs(&session, "text reranking")?;
+        let with_type_ids = has_input(&session, "token_type_ids");
+        let with_attention_mask = has_input(&session, "attention_mask");
+        let output_tensor = select_output_tensor(&session, output_tensor_override, &["logits"])?;
+
+        let config = RerankerConfig {
+            max_length,
+            output_tensor,
+            with_type_ids,
+            with_attention_mask,
+        };
+
+        let tokenizer = Tokenizer::new(&tokenizer_path, config.max_length, config.with_type_ids)?;
+
+        Ok(Self {
+            tokenizer,
+            session,
+            config,
+        })
+    }
+
+    pub fn score_pairs(&self, pairs: &[(String, String)], apply_sigmoid: bool) -> Result<Array1<f32>> {
+        let tokenized = self.tokenizer.tokenize_pairs(pairs)?;
+        let input_tensors = InputTensors::from_tokenized(&tokenized, self.config.with_attention_mask)?;
+        let outputs = self.session.run(input_tensors.inputs)?;
+        let array = extract_output_tensor(&outputs, self.config.output_tensor.as_str())?;
+
+        let mut scores = match array.ndim() {
+            1 => array.into_dimensionality::<ndarray::Ix1>()?.into_owned(),
+            2 => {
+                let shape = array.shape();
+                if shape[1] == 0 {
+                    return Err(GteError::Inference(format!(
+                        "reranker output '{}' has invalid shape {:?}",
+                        self.config.output_tensor, shape
+                    )));
+                }
+                array.slice(ndarray::s![.., 0]).into_owned()
+            }
+            n => {
+                return Err(GteError::Inference(format!(
+                    "reranker output '{}' rank {} is unsupported; expected rank 1 or 2",
+                    self.config.output_tensor, n
+                )))
+            }
+        };
+
+        if apply_sigmoid {
+            sigmoid_scores(scores.view_mut());
+        }
+
+        Ok(scores)
+    }
+
+}
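`score_pairs` accepts rank-1 or rank-2 outputs; for a `(batch, k)` logits tensor it keeps column 0 as the per-pair relevance score. A minimal sketch of that slicing, with hypothetical logit values:

```rust
use ndarray::{array, s};

fn main() {
    // Hypothetical (batch, 1) logits: one raw score per (query, candidate) pair.
    let logits = array![[1.25_f32], [-0.75], [0.5]];

    // Column 0 becomes the rank-1 score vector, as in score_pairs.
    let scores = logits.slice(s![.., 0]).into_owned();
    assert_eq!(scores.len(), 3);
    assert_eq!(scores[1], -0.75);
}
```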
data/ext/gte/src/ruby_embedder.rs
CHANGED
@@ -2,6 +2,7 @@
 
 use crate::embedder::{normalize_l2, Embedder};
 use crate::error::GteError;
+use crate::reranker::Reranker;
 use magnus::{function, method, prelude::*, wrap, Error, RArray, Ruby};
 use std::os::raw::c_void;
 use std::panic::{catch_unwind, AssertUnwindSafe};
@@ -10,6 +11,13 @@ use std::sync::Arc;
 #[wrap(class = "GTE::Embedder", free_immediately, size)]
 pub struct RbEmbedder {
     inner: Arc<Embedder>,
+    normalize: bool,
+}
+
+#[wrap(class = "GTE::Reranker", free_immediately, size)]
+pub struct RbReranker {
+    inner: Arc<Reranker>,
+    sigmoid: bool,
 }
 
 #[wrap(class = "GTE::Tensor", free_immediately, size)]
@@ -22,11 +30,21 @@ pub struct RbTensor {
 struct InferArgs {
     embedder: *const Embedder,
     texts: *const Vec<String>,
+    normalize: bool,
     result: Option<Result<ndarray::Array2<f32>, GteError>>,
 }
 
 unsafe impl Send for InferArgs {}
 
+struct ScoreArgs {
+    reranker: *const Reranker,
+    pairs: *const Vec<(String, String)>,
+    apply_sigmoid: bool,
+    result: Option<Result<Vec<f32>, GteError>>,
+}
+
+unsafe impl Send for ScoreArgs {}
+
 fn panic_payload_to_string(payload: Box<dyn std::any::Any + Send>) -> String {
     if let Some(msg) = payload.downcast_ref::<&str>() {
         (*msg).to_string()
@@ -37,11 +55,16 @@ fn panic_payload_to_string(payload: Box<dyn std::any::Any + Send>) -> String {
     }
 }
 
-fn infer_without_gvl(embedder: &Arc<Embedder>, texts: Vec<String>) -> Result<ndarray::Array2<f32>, Error> {
+fn infer_without_gvl(
+    embedder: &Arc<Embedder>,
+    normalize: bool,
+    texts: Vec<String>,
+) -> Result<ndarray::Array2<f32>, Error> {
     let embeddings = unsafe {
         let mut args = InferArgs {
             embedder: Arc::as_ptr(embedder),
             texts: &texts as *const Vec<String>,
+            normalize,
             result: None,
         };
         rb_sys::rb_thread_call_without_gvl(
@@ -60,12 +83,44 @@ fn infer_without_gvl(embedder: &Arc<Embedder>, texts: Vec<String>) -> Result<ndarray::Array2<f32>, Error> {
     Ok(embeddings)
 }
 
+fn score_without_gvl(
+    reranker: &Arc<Reranker>,
+    pairs: Vec<(String, String)>,
+    apply_sigmoid: bool,
+) -> Result<Vec<f32>, Error> {
+    let scores = unsafe {
+        let mut args = ScoreArgs {
+            reranker: Arc::as_ptr(reranker),
+            pairs: &pairs as *const Vec<(String, String)>,
+            apply_sigmoid,
+            result: None,
+        };
+        rb_sys::rb_thread_call_without_gvl(
+            Some(run_score_without_gvl),
+            &mut args as *mut ScoreArgs as *mut c_void,
+            None,
+            std::ptr::null_mut(),
+        );
+        let result = args.result.take().ok_or_else(|| {
+            magnus::Error::from(GteError::Inference(
+                "reranking did not return a result".to_string(),
+            ))
+        })?;
+        result.map_err(magnus::Error::from)?
+    };
+    Ok(scores)
+}
+
 unsafe extern "C" fn run_without_gvl(ptr: *mut c_void) -> *mut c_void {
     let args = &mut *(ptr as *mut InferArgs);
     let run_result = catch_unwind(AssertUnwindSafe(|| {
         let tokenized = (*args.embedder).tokenize(&*args.texts)?;
         let embeddings = (*args.embedder).run(&tokenized)?;
-        Ok(normalize_l2(embeddings))
+        if args.normalize {
+            Ok(normalize_l2(embeddings))
+        } else {
+            Ok(embeddings)
+        }
     }));
     args.result = Some(match run_result {
         Ok(result) => result,
@@ -77,6 +132,22 @@ unsafe extern "C" fn run_without_gvl(ptr: *mut c_void) -> *mut c_void {
     std::ptr::null_mut()
 }
 
+unsafe extern "C" fn run_score_without_gvl(ptr: *mut c_void) -> *mut c_void {
+    let args = &mut *(ptr as *mut ScoreArgs);
+    let run_result = catch_unwind(AssertUnwindSafe(|| {
+        let scores = (*args.reranker).score_pairs(&*args.pairs, args.apply_sigmoid)?;
+        Ok(scores.to_vec())
+    }));
+    args.result = Some(match run_result {
+        Ok(result) => result,
+        Err(payload) => Err(GteError::Inference(format!(
+            "panic during reranking: {}",
+            panic_payload_to_string(payload),
+        ))),
+    });
+    std::ptr::null_mut()
+}
+
 fn tensor_from_array(embeddings: ndarray::Array2<f32>) -> Result<RbTensor, Error> {
     let rows = embeddings.nrows();
     let cols = embeddings.ncols();
@@ -97,31 +168,114 @@ impl RbEmbedder {
         num_threads: usize,
         optimization_level: u8,
         model_name: String,
+        normalize: bool,
+        output_tensor: String,
+        max_length: usize,
     ) -> Result<Self, Error> {
         let name = if model_name.is_empty() {
             None
         } else {
             Some(model_name.as_str())
         };
-        let embedder = Embedder::from_dir(&dir_path, num_threads, optimization_level, name)
-            .map_err(magnus::Error::from)?;
+        let output_override = if output_tensor.is_empty() {
+            None
+        } else {
+            Some(output_tensor.as_str())
+        };
+        let max_length_override = if max_length == 0 {
+            None
+        } else {
+            Some(max_length)
+        };
+        let embedder = Embedder::from_dir(
+            &dir_path,
+            num_threads,
+            optimization_level,
+            name,
+            output_override,
+            max_length_override,
+        )
+        .map_err(magnus::Error::from)?;
         Ok(RbEmbedder {
             inner: Arc::new(embedder),
+            normalize,
         })
     }
 
     pub fn rb_embed(_ruby: &Ruby, rb_self: &Self, texts: RArray) -> Result<RbTensor, Error> {
         let texts: Vec<String> = texts.to_vec()?;
-        let embeddings = infer_without_gvl(&rb_self.inner, texts)?;
+        let embeddings = infer_without_gvl(&rb_self.inner, rb_self.normalize, texts)?;
         tensor_from_array(embeddings)
     }
 
     pub fn rb_embed_one(_ruby: &Ruby, rb_self: &Self, text: String) -> Result<RbTensor, Error> {
-        let embeddings = infer_without_gvl(&rb_self.inner, vec![text])?;
+        let embeddings = infer_without_gvl(&rb_self.inner, rb_self.normalize, vec![text])?;
         tensor_from_array(embeddings)
     }
 }
 
+impl RbReranker {
+    pub fn rb_new(
+        _ruby: &Ruby,
+        dir_path: String,
+        num_threads: usize,
+        optimization_level: u8,
+        model_name: String,
+        sigmoid: bool,
+        output_tensor: String,
+        max_length: usize,
+    ) -> Result<Self, Error> {
+        let name = if model_name.is_empty() {
+            None
+        } else {
+            Some(model_name.as_str())
+        };
+        let output_override = if output_tensor.is_empty() {
+            None
+        } else {
+            Some(output_tensor.as_str())
+        };
+        let max_length_override = if max_length == 0 {
+            None
+        } else {
+            Some(max_length)
+        };
+        let reranker = Reranker::from_dir(
+            &dir_path,
+            num_threads,
+            optimization_level,
+            name,
+            output_override,
+            max_length_override,
+        )
+        .map_err(magnus::Error::from)?;
+        Ok(RbReranker {
+            inner: Arc::new(reranker),
+            sigmoid,
+        })
+    }
+
+    pub fn rb_score(
+        ruby: &Ruby,
+        rb_self: &Self,
+        query: String,
+        candidates: RArray,
+    ) -> Result<RArray, Error> {
+        let candidates: Vec<String> = candidates.to_vec()?;
+        let pairs: Vec<(String, String)> = candidates
+            .into_iter()
+            .map(|candidate| (query.clone(), candidate))
+            .collect();
+        let scores = score_without_gvl(&rb_self.inner, pairs, rb_self.sigmoid)?;
+
+        let out = ruby.ary_new_capa(scores.len());
+        for score in scores {
+            out.push(score)?;
+        }
+        Ok(out)
+    }
+}
+
 impl RbTensor {
     pub fn len(&self) -> usize {
         self.rows
@@ -208,10 +362,14 @@ impl RbTensor {
 pub fn register(ruby: &Ruby) -> Result<(), Error> {
     let module = ruby.define_module("GTE")?;
     let embedder_class = module.define_class("Embedder", ruby.class_object())?;
-    embedder_class.define_singleton_method("new", function!(RbEmbedder::rb_new, 4))?;
+    embedder_class.define_singleton_method("new", function!(RbEmbedder::rb_new, 7))?;
     embedder_class.define_method("embed", method!(RbEmbedder::rb_embed, 1))?;
     embedder_class.define_method("embed_one", method!(RbEmbedder::rb_embed_one, 1))?;
 
+    let reranker_class = module.define_class("Reranker", ruby.class_object())?;
+    reranker_class.define_singleton_method("new", function!(RbReranker::rb_new, 7))?;
+    reranker_class.define_method("score", method!(RbReranker::rb_score, 2))?;
+
     let tensor_class = module.define_class("Tensor", ruby.class_object())?;
     tensor_class.define_method("rows", method!(RbTensor::rows, 0))?;
     tensor_class.define_method("size", method!(RbTensor::len, 0))?;