gte 0.0.4 → 0.0.6
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/README.md +150 -14
- data/Rakefile +2 -2
- data/VERSION +1 -1
- data/ext/gte/Cargo.toml +1 -1
- data/ext/gte/src/embedder.rs +38 -253
- data/ext/gte/src/lib.rs +3 -0
- data/ext/gte/src/model_config.rs +1 -0
- data/ext/gte/src/model_profile.rs +179 -0
- data/ext/gte/src/pipeline.rs +60 -0
- data/ext/gte/src/postprocess.rs +6 -0
- data/ext/gte/src/reranker.rs +122 -0
- data/ext/gte/src/ruby_embedder.rs +179 -7
- data/ext/gte/src/session.rs +76 -46
- data/ext/gte/src/tokenizer.rs +21 -2
- data/ext/gte/tests/inference_integration_test.rs +8 -4
- data/ext/gte/tests/postprocess_unit_test.rs +17 -0
- data/ext/gte/tests/tokenizer_unit_test.rs +4 -1
- data/lib/gte/config.rb +15 -0
- data/lib/gte/embedder.rb +41 -0
- data/lib/gte/model.rb +27 -0
- data/lib/gte/reranker.rb +56 -0
- data/lib/gte/version.rb +5 -0
- data/lib/gte.rb +26 -35
- metadata +11 -2
|
@@ -0,0 +1,179 @@
|
|
|
1
|
+
use crate::error::{GteError, Result};
|
|
2
|
+
use crate::model_config::ExtractorMode;
|
|
3
|
+
use ort::session::Session;
|
|
4
|
+
use std::path::{Path, PathBuf};
|
|
5
|
+
|
|
6
|
+
const SUPPORTED_INPUTS: [&str; 3] = ["input_ids", "attention_mask", "token_type_ids"];
|
|
7
|
+
|
|
8
|
+
pub fn resolve_tokenizer_path(dir: &Path) -> Result<PathBuf> {
|
|
9
|
+
let tokenizer_path = dir.join("tokenizer.json");
|
|
10
|
+
if !tokenizer_path.exists() {
|
|
11
|
+
return Err(GteError::Tokenizer(format!(
|
|
12
|
+
"tokenizer.json not found in {}",
|
|
13
|
+
dir.display()
|
|
14
|
+
)));
|
|
15
|
+
}
|
|
16
|
+
Ok(tokenizer_path)
|
|
17
|
+
}
|
|
18
|
+
|
|
19
|
+
pub fn resolve_named_model(dir: &Path, name: &str) -> Result<PathBuf> {
|
|
20
|
+
let candidates = [dir.join("onnx").join(name), dir.join(name)];
|
|
21
|
+
for path in &candidates {
|
|
22
|
+
if path.exists() {
|
|
23
|
+
return Ok(path.clone());
|
|
24
|
+
}
|
|
25
|
+
}
|
|
26
|
+
Err(GteError::Inference(format!(
|
|
27
|
+
"model '{}' not found in {} (checked onnx/{0} and {0})",
|
|
28
|
+
name,
|
|
29
|
+
dir.display()
|
|
30
|
+
)))
|
|
31
|
+
}
|
|
32
|
+
|
|
33
|
+
pub fn resolve_default_text_model(dir: &Path) -> Result<PathBuf> {
|
|
34
|
+
let candidates = [
|
|
35
|
+
dir.join("onnx").join("text_model.onnx"),
|
|
36
|
+
dir.join("text_model.onnx"),
|
|
37
|
+
dir.join("onnx").join("model.onnx"),
|
|
38
|
+
dir.join("model.onnx"),
|
|
39
|
+
];
|
|
40
|
+
for path in &candidates {
|
|
41
|
+
if path.exists() {
|
|
42
|
+
return Ok(path.clone());
|
|
43
|
+
}
|
|
44
|
+
}
|
|
45
|
+
Err(GteError::Inference(format!(
|
|
46
|
+
"no ONNX model found in {} (checked text_model.onnx and model.onnx)",
|
|
47
|
+
dir.display()
|
|
48
|
+
)))
|
|
49
|
+
}
|
|
50
|
+
|
|
51
|
+
pub fn read_max_length(dir: &Path) -> usize {
|
|
52
|
+
(|| -> Option<usize> {
|
|
53
|
+
let contents = std::fs::read_to_string(dir.join("tokenizer_config.json")).ok()?;
|
|
54
|
+
let json: serde_json::Value = serde_json::from_str(&contents).ok()?;
|
|
55
|
+
let v = json.get("model_max_length")?;
|
|
56
|
+
let n = v.as_u64().or_else(|| {
|
|
57
|
+
v.as_f64()
|
|
58
|
+
.filter(|&f| f > 0.0 && f < 1e15)
|
|
59
|
+
.map(|f| f as u64)
|
|
60
|
+
})?;
|
|
61
|
+
Some((n as usize).min(8192))
|
|
62
|
+
})()
|
|
63
|
+
.unwrap_or(512)
|
|
64
|
+
}
|
|
65
|
+
|
|
66
|
+
pub fn validate_supported_text_inputs(session: &Session, api_label: &str) -> Result<()> {
|
|
67
|
+
let unsupported: Vec<String> = session
|
|
68
|
+
.inputs
|
|
69
|
+
.iter()
|
|
70
|
+
.filter(|i| !SUPPORTED_INPUTS.contains(&i.name.as_str()))
|
|
71
|
+
.map(|i| i.name.clone())
|
|
72
|
+
.collect();
|
|
73
|
+
|
|
74
|
+
if unsupported.is_empty() {
|
|
75
|
+
return Ok(());
|
|
76
|
+
}
|
|
77
|
+
|
|
78
|
+
let mut message = format!(
|
|
79
|
+
"unsupported model inputs for {} API: {}",
|
|
80
|
+
api_label,
|
|
81
|
+
unsupported.join(", ")
|
|
82
|
+
);
|
|
83
|
+
if unsupported.iter().any(|n| n == "pixel_values") {
|
|
84
|
+
message.push_str(
|
|
85
|
+
". This looks like a multimodal graph. Provide a text-only export (for example onnx/text_model.onnx).",
|
|
86
|
+
);
|
|
87
|
+
} else {
|
|
88
|
+
message.push_str(". Supported inputs are: input_ids, attention_mask, token_type_ids.");
|
|
89
|
+
}
|
|
90
|
+
Err(GteError::Inference(message))
|
|
91
|
+
}
|
|
92
|
+
|
|
93
|
+
pub fn has_input(session: &Session, name: &str) -> bool {
|
|
94
|
+
session.inputs.iter().any(|input| input.name == name)
|
|
95
|
+
}
|
|
96
|
+
|
|
97
|
+
/// Case-insensitive match of a session output name against a preferred
/// basename: either the whole name matches, or its last `/`-separated
/// segment does (ONNX graphs often namespace outputs, e.g. `model/logits`).
///
/// `preferred` is expected to be lowercase (all call sites pass lowercase
/// literals); the candidate `name` is lowercased before comparison.
/// Uses `strip_suffix` instead of building a `format!("/{preferred}")`
/// string, avoiding a heap allocation per comparison — this runs once per
/// (output, preferred) pair inside the selection loops.
fn output_name_matches(name: &str, preferred: &str) -> bool {
    let lower = name.to_ascii_lowercase();
    match lower.strip_suffix(preferred) {
        // Empty remainder => exact match; '/'-terminated remainder => the
        // preferred name is the final path segment.
        Some(rest) => rest.is_empty() || rest.ends_with('/'),
        None => false,
    }
}
|
|
101
|
+
|
|
102
|
+
pub fn select_output_tensor(
|
|
103
|
+
session: &Session,
|
|
104
|
+
requested: Option<&str>,
|
|
105
|
+
preferred_outputs: &[&str],
|
|
106
|
+
) -> Result<String> {
|
|
107
|
+
if let Some(requested_name) = requested.map(str::trim).filter(|name| !name.is_empty()) {
|
|
108
|
+
if let Some(output) = session
|
|
109
|
+
.outputs
|
|
110
|
+
.iter()
|
|
111
|
+
.find(|o| output_name_matches(o.name.as_str(), requested_name))
|
|
112
|
+
{
|
|
113
|
+
return Ok(output.name.clone());
|
|
114
|
+
}
|
|
115
|
+
let available = session
|
|
116
|
+
.outputs
|
|
117
|
+
.iter()
|
|
118
|
+
.map(|o| o.name.as_str())
|
|
119
|
+
.collect::<Vec<_>>()
|
|
120
|
+
.join(", ");
|
|
121
|
+
return Err(GteError::Inference(format!(
|
|
122
|
+
"requested output tensor '{}' not found in model outputs: {}",
|
|
123
|
+
requested_name, available
|
|
124
|
+
)));
|
|
125
|
+
}
|
|
126
|
+
|
|
127
|
+
for preferred in preferred_outputs {
|
|
128
|
+
if let Some(output) = session
|
|
129
|
+
.outputs
|
|
130
|
+
.iter()
|
|
131
|
+
.find(|o| output_name_matches(o.name.as_str(), preferred))
|
|
132
|
+
{
|
|
133
|
+
return Ok(output.name.clone());
|
|
134
|
+
}
|
|
135
|
+
}
|
|
136
|
+
|
|
137
|
+
session
|
|
138
|
+
.outputs
|
|
139
|
+
.first()
|
|
140
|
+
.map(|o| o.name.clone())
|
|
141
|
+
.ok_or_else(|| GteError::Inference("model has no outputs".into()))
|
|
142
|
+
}
|
|
143
|
+
|
|
144
|
+
/// Return the final `/`-separated segment of an output name (the whole
/// name when it contains no `/`).
fn output_basename(name: &str) -> &str {
    match name.rfind('/') {
        Some(idx) => &name[idx + 1..],
        None => name,
    }
}
|
|
147
|
+
|
|
148
|
+
pub fn infer_extraction_mode(session: &Session, output_tensor: &str) -> Result<ExtractorMode> {
|
|
149
|
+
let output = session
|
|
150
|
+
.outputs
|
|
151
|
+
.iter()
|
|
152
|
+
.find(|o| o.name == output_tensor)
|
|
153
|
+
.ok_or_else(|| {
|
|
154
|
+
GteError::Inference(format!(
|
|
155
|
+
"output tensor '{}' not found in model outputs",
|
|
156
|
+
output_tensor
|
|
157
|
+
))
|
|
158
|
+
})?;
|
|
159
|
+
|
|
160
|
+
let ndims = match &output.output_type {
|
|
161
|
+
ort::value::ValueType::Tensor { dimensions, .. } => dimensions.len(),
|
|
162
|
+
other => {
|
|
163
|
+
return Err(GteError::Inference(format!(
|
|
164
|
+
"output is not a tensor: {:?}",
|
|
165
|
+
other
|
|
166
|
+
)))
|
|
167
|
+
}
|
|
168
|
+
};
|
|
169
|
+
|
|
170
|
+
match (output_basename(output_tensor), ndims) {
|
|
171
|
+
("last_hidden_state", 3) => Ok(ExtractorMode::MeanPool),
|
|
172
|
+
(_, 2) => Ok(ExtractorMode::Raw),
|
|
173
|
+
(_, 3) => Ok(ExtractorMode::MeanPool),
|
|
174
|
+
(_, n) => Err(GteError::Inference(format!(
|
|
175
|
+
"unexpected output tensor rank {} for '{}': expected 2 (Raw) or 3 (MeanPool)",
|
|
176
|
+
n, output_tensor
|
|
177
|
+
))),
|
|
178
|
+
}
|
|
179
|
+
}
|
|
@@ -0,0 +1,60 @@
|
|
|
1
|
+
use crate::error::{GteError, Result};
|
|
2
|
+
use crate::tokenizer::Tokenized;
|
|
3
|
+
use ndarray::ArrayView2;
|
|
4
|
+
use ort::session::SessionInputValue;
|
|
5
|
+
use ort::value::Value;
|
|
6
|
+
|
|
7
|
+
pub struct InputTensors<'a> {
|
|
8
|
+
pub inputs: Vec<(&'static str, SessionInputValue<'a>)>,
|
|
9
|
+
pub attention_mask: ArrayView2<'a, i64>,
|
|
10
|
+
}
|
|
11
|
+
|
|
12
|
+
impl<'a> InputTensors<'a> {
|
|
13
|
+
pub fn from_tokenized(tokenized: &'a Tokenized, with_attention_mask: bool) -> Result<Self> {
|
|
14
|
+
let input_ids_view: ArrayView2<'_, i64> = ArrayView2::from_shape(
|
|
15
|
+
(tokenized.rows, tokenized.cols),
|
|
16
|
+
tokenized.input_ids.as_slice(),
|
|
17
|
+
)?;
|
|
18
|
+
let attention_mask: ArrayView2<'_, i64> = ArrayView2::from_shape(
|
|
19
|
+
(tokenized.rows, tokenized.cols),
|
|
20
|
+
tokenized.attn_masks.as_slice(),
|
|
21
|
+
)?;
|
|
22
|
+
|
|
23
|
+
let mut inputs = Vec::with_capacity(2 + usize::from(tokenized.type_ids.is_some()));
|
|
24
|
+
inputs.push((
|
|
25
|
+
"input_ids",
|
|
26
|
+
SessionInputValue::from(Value::from_array(input_ids_view)?),
|
|
27
|
+
));
|
|
28
|
+
|
|
29
|
+
if with_attention_mask {
|
|
30
|
+
inputs.push((
|
|
31
|
+
"attention_mask",
|
|
32
|
+
SessionInputValue::from(Value::from_array(attention_mask)?),
|
|
33
|
+
));
|
|
34
|
+
}
|
|
35
|
+
|
|
36
|
+
if let Some(type_ids) = tokenized.type_ids.as_deref() {
|
|
37
|
+
let type_ids_view: ArrayView2<'_, i64> =
|
|
38
|
+
ArrayView2::from_shape((tokenized.rows, tokenized.cols), type_ids)?;
|
|
39
|
+
inputs.push((
|
|
40
|
+
"token_type_ids",
|
|
41
|
+
SessionInputValue::from(Value::from_array(type_ids_view)?),
|
|
42
|
+
));
|
|
43
|
+
}
|
|
44
|
+
|
|
45
|
+
Ok(Self {
|
|
46
|
+
inputs,
|
|
47
|
+
attention_mask,
|
|
48
|
+
})
|
|
49
|
+
}
|
|
50
|
+
}
|
|
51
|
+
|
|
52
|
+
pub fn extract_output_tensor<'a>(
|
|
53
|
+
outputs: &'a ort::session::SessionOutputs<'a, 'a>,
|
|
54
|
+
output_name: &str,
|
|
55
|
+
) -> Result<ndarray::CowArray<'a, f32, ndarray::IxDyn>> {
|
|
56
|
+
let tensor_value = outputs.get(output_name).ok_or_else(|| {
|
|
57
|
+
GteError::Inference(format!("output tensor '{}' not found in model outputs", output_name))
|
|
58
|
+
})?;
|
|
59
|
+
Ok(tensor_value.try_extract_tensor::<f32>()?.into())
|
|
60
|
+
}
|
data/ext/gte/src/postprocess.rs
CHANGED
|
@@ -75,6 +75,12 @@ pub fn normalize_l2(mut embeddings: Array2<f32>) -> Array2<f32> {
|
|
|
75
75
|
embeddings
|
|
76
76
|
}
|
|
77
77
|
|
|
78
|
+
pub fn sigmoid_scores(mut scores: ndarray::ArrayViewMut1<'_, f32>) {
|
|
79
|
+
scores.map_inplace(|value| {
|
|
80
|
+
*value = 1.0 / (1.0 + (-*value).exp());
|
|
81
|
+
});
|
|
82
|
+
}
|
|
83
|
+
|
|
78
84
|
fn mean_pool_contiguous(
|
|
79
85
|
hidden: &[f32],
|
|
80
86
|
attention_mask: &[i64],
|
|
@@ -0,0 +1,122 @@
|
|
|
1
|
+
use crate::error::{GteError, Result};
|
|
2
|
+
use crate::model_profile::{
|
|
3
|
+
has_input, read_max_length, resolve_default_text_model, resolve_named_model, resolve_tokenizer_path,
|
|
4
|
+
select_output_tensor, validate_supported_text_inputs,
|
|
5
|
+
};
|
|
6
|
+
use crate::pipeline::{extract_output_tensor, InputTensors};
|
|
7
|
+
use crate::postprocess::sigmoid_scores;
|
|
8
|
+
use crate::session::build_session;
|
|
9
|
+
use crate::tokenizer::Tokenizer;
|
|
10
|
+
use ndarray::Array1;
|
|
11
|
+
use ort::session::Session;
|
|
12
|
+
use std::path::Path;
|
|
13
|
+
|
|
14
|
+
#[derive(Debug, Clone)]
|
|
15
|
+
struct RerankerConfig {
|
|
16
|
+
max_length: usize,
|
|
17
|
+
output_tensor: String,
|
|
18
|
+
with_type_ids: bool,
|
|
19
|
+
with_attention_mask: bool,
|
|
20
|
+
}
|
|
21
|
+
|
|
22
|
+
pub struct Reranker {
|
|
23
|
+
tokenizer: Tokenizer,
|
|
24
|
+
session: Session,
|
|
25
|
+
config: RerankerConfig,
|
|
26
|
+
}
|
|
27
|
+
|
|
28
|
+
impl Reranker {
|
|
29
|
+
pub fn from_dir<P: AsRef<Path>>(
|
|
30
|
+
dir: P,
|
|
31
|
+
num_threads: usize,
|
|
32
|
+
optimization_level: u8,
|
|
33
|
+
model_name: Option<&str>,
|
|
34
|
+
output_tensor_override: Option<&str>,
|
|
35
|
+
max_length_override: Option<usize>,
|
|
36
|
+
execution_providers_override: Option<&str>,
|
|
37
|
+
) -> Result<Self> {
|
|
38
|
+
let dir = dir.as_ref();
|
|
39
|
+
let tokenizer_path = resolve_tokenizer_path(dir)?;
|
|
40
|
+
let model_path = match model_name.filter(|s| !s.is_empty()) {
|
|
41
|
+
Some(name) => resolve_named_model(dir, name)?,
|
|
42
|
+
None => resolve_default_text_model(dir)?,
|
|
43
|
+
};
|
|
44
|
+
|
|
45
|
+
let max_length = if let Some(override_value) = max_length_override {
|
|
46
|
+
if override_value == 0 {
|
|
47
|
+
return Err(GteError::Inference(
|
|
48
|
+
"max_length override must be greater than 0".to_string(),
|
|
49
|
+
));
|
|
50
|
+
}
|
|
51
|
+
override_value
|
|
52
|
+
} else {
|
|
53
|
+
read_max_length(dir)
|
|
54
|
+
};
|
|
55
|
+
|
|
56
|
+
let probe_config = crate::model_config::ModelConfig {
|
|
57
|
+
max_length,
|
|
58
|
+
output_tensor: String::new(),
|
|
59
|
+
mode: crate::model_config::ExtractorMode::Raw,
|
|
60
|
+
with_type_ids: false,
|
|
61
|
+
with_attention_mask: true,
|
|
62
|
+
num_threads,
|
|
63
|
+
optimization_level,
|
|
64
|
+
execution_providers: execution_providers_override.map(str::to_string),
|
|
65
|
+
};
|
|
66
|
+
let session = build_session(&model_path, &probe_config)?;
|
|
67
|
+
|
|
68
|
+
validate_supported_text_inputs(&session, "text reranking")?;
|
|
69
|
+
let with_type_ids = has_input(&session, "token_type_ids");
|
|
70
|
+
let with_attention_mask = has_input(&session, "attention_mask");
|
|
71
|
+
let output_tensor = select_output_tensor(&session, output_tensor_override, &["logits"])?;
|
|
72
|
+
|
|
73
|
+
let config = RerankerConfig {
|
|
74
|
+
max_length,
|
|
75
|
+
output_tensor,
|
|
76
|
+
with_type_ids,
|
|
77
|
+
with_attention_mask,
|
|
78
|
+
};
|
|
79
|
+
|
|
80
|
+
let tokenizer = Tokenizer::new(&tokenizer_path, config.max_length, config.with_type_ids)?;
|
|
81
|
+
|
|
82
|
+
Ok(Self {
|
|
83
|
+
tokenizer,
|
|
84
|
+
session,
|
|
85
|
+
config,
|
|
86
|
+
})
|
|
87
|
+
}
|
|
88
|
+
|
|
89
|
+
pub fn score_pairs(&self, pairs: &[(String, String)], apply_sigmoid: bool) -> Result<Array1<f32>> {
|
|
90
|
+
let tokenized = self.tokenizer.tokenize_pairs(pairs)?;
|
|
91
|
+
let input_tensors = InputTensors::from_tokenized(&tokenized, self.config.with_attention_mask)?;
|
|
92
|
+
let outputs = self.session.run(input_tensors.inputs)?;
|
|
93
|
+
let array = extract_output_tensor(&outputs, self.config.output_tensor.as_str())?;
|
|
94
|
+
|
|
95
|
+
let mut scores = match array.ndim() {
|
|
96
|
+
1 => array.into_dimensionality::<ndarray::Ix1>()?.into_owned(),
|
|
97
|
+
2 => {
|
|
98
|
+
let shape = array.shape();
|
|
99
|
+
if shape[1] == 0 {
|
|
100
|
+
return Err(GteError::Inference(format!(
|
|
101
|
+
"reranker output '{}' has invalid shape {:?}",
|
|
102
|
+
self.config.output_tensor, shape
|
|
103
|
+
)));
|
|
104
|
+
}
|
|
105
|
+
array.slice(ndarray::s![.., 0]).into_owned()
|
|
106
|
+
}
|
|
107
|
+
n => {
|
|
108
|
+
return Err(GteError::Inference(format!(
|
|
109
|
+
"reranker output '{}' rank {} is unsupported; expected rank 1 or 2",
|
|
110
|
+
self.config.output_tensor, n
|
|
111
|
+
)))
|
|
112
|
+
}
|
|
113
|
+
};
|
|
114
|
+
|
|
115
|
+
if apply_sigmoid {
|
|
116
|
+
sigmoid_scores(scores.view_mut());
|
|
117
|
+
}
|
|
118
|
+
|
|
119
|
+
Ok(scores)
|
|
120
|
+
}
|
|
121
|
+
|
|
122
|
+
}
|
|
@@ -2,6 +2,7 @@
|
|
|
2
2
|
|
|
3
3
|
use crate::embedder::{normalize_l2, Embedder};
|
|
4
4
|
use crate::error::GteError;
|
|
5
|
+
use crate::reranker::Reranker;
|
|
5
6
|
use magnus::{function, method, prelude::*, wrap, Error, RArray, Ruby};
|
|
6
7
|
use std::os::raw::c_void;
|
|
7
8
|
use std::panic::{catch_unwind, AssertUnwindSafe};
|
|
@@ -10,6 +11,13 @@ use std::sync::Arc;
|
|
|
10
11
|
#[wrap(class = "GTE::Embedder", free_immediately, size)]
|
|
11
12
|
pub struct RbEmbedder {
|
|
12
13
|
inner: Arc<Embedder>,
|
|
14
|
+
normalize: bool,
|
|
15
|
+
}
|
|
16
|
+
|
|
17
|
+
#[wrap(class = "GTE::Reranker", free_immediately, size)]
|
|
18
|
+
pub struct RbReranker {
|
|
19
|
+
inner: Arc<Reranker>,
|
|
20
|
+
sigmoid: bool,
|
|
13
21
|
}
|
|
14
22
|
|
|
15
23
|
#[wrap(class = "GTE::Tensor", free_immediately, size)]
|
|
@@ -22,11 +30,21 @@ pub struct RbTensor {
|
|
|
22
30
|
struct InferArgs {
|
|
23
31
|
embedder: *const Embedder,
|
|
24
32
|
texts: *const Vec<String>,
|
|
33
|
+
normalize: bool,
|
|
25
34
|
result: Option<Result<ndarray::Array2<f32>, GteError>>,
|
|
26
35
|
}
|
|
27
36
|
|
|
28
37
|
unsafe impl Send for InferArgs {}
|
|
29
38
|
|
|
39
|
+
struct ScoreArgs {
|
|
40
|
+
reranker: *const Reranker,
|
|
41
|
+
pairs: *const Vec<(String, String)>,
|
|
42
|
+
apply_sigmoid: bool,
|
|
43
|
+
result: Option<Result<Vec<f32>, GteError>>,
|
|
44
|
+
}
|
|
45
|
+
|
|
46
|
+
unsafe impl Send for ScoreArgs {}
|
|
47
|
+
|
|
30
48
|
fn panic_payload_to_string(payload: Box<dyn std::any::Any + Send>) -> String {
|
|
31
49
|
if let Some(msg) = payload.downcast_ref::<&str>() {
|
|
32
50
|
(*msg).to_string()
|
|
@@ -37,11 +55,16 @@ fn panic_payload_to_string(payload: Box<dyn std::any::Any + Send>) -> String {
|
|
|
37
55
|
}
|
|
38
56
|
}
|
|
39
57
|
|
|
40
|
-
fn infer_without_gvl(
|
|
58
|
+
fn infer_without_gvl(
|
|
59
|
+
embedder: &Arc<Embedder>,
|
|
60
|
+
normalize: bool,
|
|
61
|
+
texts: Vec<String>,
|
|
62
|
+
) -> Result<ndarray::Array2<f32>, Error> {
|
|
41
63
|
let embeddings = unsafe {
|
|
42
64
|
let mut args = InferArgs {
|
|
43
65
|
embedder: Arc::as_ptr(embedder),
|
|
44
66
|
texts: &texts as *const Vec<String>,
|
|
67
|
+
normalize,
|
|
45
68
|
result: None,
|
|
46
69
|
};
|
|
47
70
|
rb_sys::rb_thread_call_without_gvl(
|
|
@@ -60,12 +83,44 @@ fn infer_without_gvl(embedder: &Arc<Embedder>, texts: Vec<String>) -> Result<nda
|
|
|
60
83
|
Ok(embeddings)
|
|
61
84
|
}
|
|
62
85
|
|
|
86
|
+
fn score_without_gvl(
|
|
87
|
+
reranker: &Arc<Reranker>,
|
|
88
|
+
pairs: Vec<(String, String)>,
|
|
89
|
+
apply_sigmoid: bool,
|
|
90
|
+
) -> Result<Vec<f32>, Error> {
|
|
91
|
+
let scores = unsafe {
|
|
92
|
+
let mut args = ScoreArgs {
|
|
93
|
+
reranker: Arc::as_ptr(reranker),
|
|
94
|
+
pairs: &pairs as *const Vec<(String, String)>,
|
|
95
|
+
apply_sigmoid,
|
|
96
|
+
result: None,
|
|
97
|
+
};
|
|
98
|
+
rb_sys::rb_thread_call_without_gvl(
|
|
99
|
+
Some(run_score_without_gvl),
|
|
100
|
+
&mut args as *mut ScoreArgs as *mut c_void,
|
|
101
|
+
None,
|
|
102
|
+
std::ptr::null_mut(),
|
|
103
|
+
);
|
|
104
|
+
let result = args.result.take().ok_or_else(|| {
|
|
105
|
+
magnus::Error::from(GteError::Inference(
|
|
106
|
+
"reranking did not return a result".to_string(),
|
|
107
|
+
))
|
|
108
|
+
})?;
|
|
109
|
+
result.map_err(magnus::Error::from)?
|
|
110
|
+
};
|
|
111
|
+
Ok(scores)
|
|
112
|
+
}
|
|
113
|
+
|
|
63
114
|
unsafe extern "C" fn run_without_gvl(ptr: *mut c_void) -> *mut c_void {
|
|
64
115
|
let args = &mut *(ptr as *mut InferArgs);
|
|
65
116
|
let run_result = catch_unwind(AssertUnwindSafe(|| {
|
|
66
117
|
let tokenized = (*args.embedder).tokenize(&*args.texts)?;
|
|
67
118
|
let embeddings = (*args.embedder).run(&tokenized)?;
|
|
68
|
-
|
|
119
|
+
if args.normalize {
|
|
120
|
+
Ok(normalize_l2(embeddings))
|
|
121
|
+
} else {
|
|
122
|
+
Ok(embeddings)
|
|
123
|
+
}
|
|
69
124
|
}));
|
|
70
125
|
args.result = Some(match run_result {
|
|
71
126
|
Ok(result) => result,
|
|
@@ -77,6 +132,22 @@ unsafe extern "C" fn run_without_gvl(ptr: *mut c_void) -> *mut c_void {
|
|
|
77
132
|
std::ptr::null_mut()
|
|
78
133
|
}
|
|
79
134
|
|
|
135
|
+
unsafe extern "C" fn run_score_without_gvl(ptr: *mut c_void) -> *mut c_void {
|
|
136
|
+
let args = &mut *(ptr as *mut ScoreArgs);
|
|
137
|
+
let run_result = catch_unwind(AssertUnwindSafe(|| {
|
|
138
|
+
let scores = (*args.reranker).score_pairs(&*args.pairs, args.apply_sigmoid)?;
|
|
139
|
+
Ok(scores.to_vec())
|
|
140
|
+
}));
|
|
141
|
+
args.result = Some(match run_result {
|
|
142
|
+
Ok(result) => result,
|
|
143
|
+
Err(payload) => Err(GteError::Inference(format!(
|
|
144
|
+
"panic during reranking: {}",
|
|
145
|
+
panic_payload_to_string(payload),
|
|
146
|
+
))),
|
|
147
|
+
});
|
|
148
|
+
std::ptr::null_mut()
|
|
149
|
+
}
|
|
150
|
+
|
|
80
151
|
fn tensor_from_array(embeddings: ndarray::Array2<f32>) -> Result<RbTensor, Error> {
|
|
81
152
|
let rows = embeddings.nrows();
|
|
82
153
|
let cols = embeddings.ncols();
|
|
@@ -97,31 +168,128 @@ impl RbEmbedder {
|
|
|
97
168
|
num_threads: usize,
|
|
98
169
|
optimization_level: u8,
|
|
99
170
|
model_name: String,
|
|
171
|
+
normalize: bool,
|
|
172
|
+
output_tensor: String,
|
|
173
|
+
max_length: usize,
|
|
174
|
+
execution_providers: String,
|
|
100
175
|
) -> Result<Self, Error> {
|
|
101
176
|
let name = if model_name.is_empty() {
|
|
102
177
|
None
|
|
103
178
|
} else {
|
|
104
179
|
Some(model_name.as_str())
|
|
105
180
|
};
|
|
106
|
-
let
|
|
107
|
-
|
|
181
|
+
let output_override = if output_tensor.is_empty() {
|
|
182
|
+
None
|
|
183
|
+
} else {
|
|
184
|
+
Some(output_tensor.as_str())
|
|
185
|
+
};
|
|
186
|
+
let max_length_override = if max_length == 0 {
|
|
187
|
+
None
|
|
188
|
+
} else {
|
|
189
|
+
Some(max_length)
|
|
190
|
+
};
|
|
191
|
+
let execution_providers_override = if execution_providers.is_empty() {
|
|
192
|
+
None
|
|
193
|
+
} else {
|
|
194
|
+
Some(execution_providers.as_str())
|
|
195
|
+
};
|
|
196
|
+
let embedder = Embedder::from_dir(
|
|
197
|
+
&dir_path,
|
|
198
|
+
num_threads,
|
|
199
|
+
optimization_level,
|
|
200
|
+
name,
|
|
201
|
+
output_override,
|
|
202
|
+
max_length_override,
|
|
203
|
+
execution_providers_override,
|
|
204
|
+
)
|
|
205
|
+
.map_err(magnus::Error::from)?;
|
|
108
206
|
Ok(RbEmbedder {
|
|
109
207
|
inner: Arc::new(embedder),
|
|
208
|
+
normalize,
|
|
110
209
|
})
|
|
111
210
|
}
|
|
112
211
|
|
|
113
212
|
pub fn rb_embed(_ruby: &Ruby, rb_self: &Self, texts: RArray) -> Result<RbTensor, Error> {
|
|
114
213
|
let texts: Vec<String> = texts.to_vec()?;
|
|
115
|
-
let embeddings = infer_without_gvl(&rb_self.inner, texts)?;
|
|
214
|
+
let embeddings = infer_without_gvl(&rb_self.inner, rb_self.normalize, texts)?;
|
|
116
215
|
tensor_from_array(embeddings)
|
|
117
216
|
}
|
|
118
217
|
|
|
119
218
|
pub fn rb_embed_one(_ruby: &Ruby, rb_self: &Self, text: String) -> Result<RbTensor, Error> {
|
|
120
|
-
let embeddings = infer_without_gvl(&rb_self.inner, vec![text])?;
|
|
219
|
+
let embeddings = infer_without_gvl(&rb_self.inner, rb_self.normalize, vec![text])?;
|
|
121
220
|
tensor_from_array(embeddings)
|
|
122
221
|
}
|
|
123
222
|
}
|
|
124
223
|
|
|
224
|
+
impl RbReranker {
|
|
225
|
+
pub fn rb_new(
|
|
226
|
+
_ruby: &Ruby,
|
|
227
|
+
dir_path: String,
|
|
228
|
+
num_threads: usize,
|
|
229
|
+
optimization_level: u8,
|
|
230
|
+
model_name: String,
|
|
231
|
+
sigmoid: bool,
|
|
232
|
+
output_tensor: String,
|
|
233
|
+
max_length: usize,
|
|
234
|
+
execution_providers: String,
|
|
235
|
+
) -> Result<Self, Error> {
|
|
236
|
+
let name = if model_name.is_empty() {
|
|
237
|
+
None
|
|
238
|
+
} else {
|
|
239
|
+
Some(model_name.as_str())
|
|
240
|
+
};
|
|
241
|
+
let output_override = if output_tensor.is_empty() {
|
|
242
|
+
None
|
|
243
|
+
} else {
|
|
244
|
+
Some(output_tensor.as_str())
|
|
245
|
+
};
|
|
246
|
+
let max_length_override = if max_length == 0 {
|
|
247
|
+
None
|
|
248
|
+
} else {
|
|
249
|
+
Some(max_length)
|
|
250
|
+
};
|
|
251
|
+
let execution_providers_override = if execution_providers.is_empty() {
|
|
252
|
+
None
|
|
253
|
+
} else {
|
|
254
|
+
Some(execution_providers.as_str())
|
|
255
|
+
};
|
|
256
|
+
let reranker = Reranker::from_dir(
|
|
257
|
+
&dir_path,
|
|
258
|
+
num_threads,
|
|
259
|
+
optimization_level,
|
|
260
|
+
name,
|
|
261
|
+
output_override,
|
|
262
|
+
max_length_override,
|
|
263
|
+
execution_providers_override,
|
|
264
|
+
)
|
|
265
|
+
.map_err(magnus::Error::from)?;
|
|
266
|
+
Ok(RbReranker {
|
|
267
|
+
inner: Arc::new(reranker),
|
|
268
|
+
sigmoid,
|
|
269
|
+
})
|
|
270
|
+
}
|
|
271
|
+
|
|
272
|
+
pub fn rb_score(
|
|
273
|
+
ruby: &Ruby,
|
|
274
|
+
rb_self: &Self,
|
|
275
|
+
query: String,
|
|
276
|
+
candidates: RArray,
|
|
277
|
+
) -> Result<RArray, Error> {
|
|
278
|
+
let candidates: Vec<String> = candidates.to_vec()?;
|
|
279
|
+
let pairs: Vec<(String, String)> = candidates
|
|
280
|
+
.into_iter()
|
|
281
|
+
.map(|candidate| (query.clone(), candidate))
|
|
282
|
+
.collect();
|
|
283
|
+
let scores = score_without_gvl(&rb_self.inner, pairs, rb_self.sigmoid)?;
|
|
284
|
+
|
|
285
|
+
let out = ruby.ary_new_capa(scores.len());
|
|
286
|
+
for score in scores {
|
|
287
|
+
out.push(score)?;
|
|
288
|
+
}
|
|
289
|
+
Ok(out)
|
|
290
|
+
}
|
|
291
|
+
}
|
|
292
|
+
|
|
125
293
|
impl RbTensor {
|
|
126
294
|
pub fn len(&self) -> usize {
|
|
127
295
|
self.rows
|
|
@@ -208,10 +376,14 @@ impl RbTensor {
|
|
|
208
376
|
pub fn register(ruby: &Ruby) -> Result<(), Error> {
|
|
209
377
|
let module = ruby.define_module("GTE")?;
|
|
210
378
|
let embedder_class = module.define_class("Embedder", ruby.class_object())?;
|
|
211
|
-
embedder_class.define_singleton_method("new", function!(RbEmbedder::rb_new,
|
|
379
|
+
embedder_class.define_singleton_method("new", function!(RbEmbedder::rb_new, 8))?;
|
|
212
380
|
embedder_class.define_method("embed", method!(RbEmbedder::rb_embed, 1))?;
|
|
213
381
|
embedder_class.define_method("embed_one", method!(RbEmbedder::rb_embed_one, 1))?;
|
|
214
382
|
|
|
383
|
+
let reranker_class = module.define_class("Reranker", ruby.class_object())?;
|
|
384
|
+
reranker_class.define_singleton_method("new", function!(RbReranker::rb_new, 8))?;
|
|
385
|
+
reranker_class.define_method("score", method!(RbReranker::rb_score, 2))?;
|
|
386
|
+
|
|
215
387
|
let tensor_class = module.define_class("Tensor", ruby.class_object())?;
|
|
216
388
|
tensor_class.define_method("rows", method!(RbTensor::rows, 0))?;
|
|
217
389
|
tensor_class.define_method("size", method!(RbTensor::len, 0))?;
|