gte 0.0.13 → 0.0.14
This diff shows the changes between publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/README.md +93 -27
- data/VERSION +1 -1
- data/ext/gte/Cargo.toml +26 -4
- data/ext/gte/benches/hot_path.rs +20 -54
- data/ext/gte/build.rs +2 -6
- data/ext/gte/rustfmt.toml +5 -0
- data/ext/gte/src/embedder.rs +71 -43
- data/ext/gte/src/error.rs +4 -4
- data/ext/gte/src/lib.rs +1 -1
- data/ext/gte/src/model_config.rs +4 -0
- data/ext/gte/src/model_profile.rs +26 -87
- data/ext/gte/src/pipeline.rs +11 -30
- data/ext/gte/src/postprocess.rs +8 -14
- data/ext/gte/src/reranker.rs +50 -50
- data/ext/gte/src/ruby_embedder.rs +48 -53
- data/ext/gte/src/session.rs +136 -248
- data/ext/gte/src/tokenizer.rs +51 -125
- data/ext/gte/tests/inference_integration_test.rs +8 -18
- data/ext/gte/tests/padding_regression_test.rs +13 -26
- data/ext/gte/tests/tokenizer_unit_test.rs +10 -24
- data/lib/gte/config.rb +2 -1
- data/lib/gte/embedder.rb +6 -2
- data/lib/gte/reranker.rb +3 -1
- data/lib/gte.rb +6 -0
- metadata +2 -1
|
@@ -1,6 +1,9 @@
|
|
|
1
1
|
#![cfg(feature = "ruby-ffi")]
|
|
2
|
+
#![allow(unsafe_code)]
|
|
3
|
+
#![allow(unused_results)]
|
|
4
|
+
#![allow(unused_qualifications)]
|
|
2
5
|
|
|
3
|
-
use crate::embedder::{normalize_l2, Embedder};
|
|
6
|
+
use crate::embedder::{normalize_l2, output_name_suggests_normalized, Embedder};
|
|
4
7
|
use crate::error::GteError;
|
|
5
8
|
use crate::model_config::ModelLoadOverrides;
|
|
6
9
|
use crate::reranker::Reranker;
|
|
@@ -66,29 +69,30 @@ unsafe extern "C" fn run_embed_without_gvl(ptr: *mut c_void) -> *mut c_void {
|
|
|
66
69
|
let run_result = catch_unwind(AssertUnwindSafe(|| {
|
|
67
70
|
// Full embedding path (tokenization + inference) runs without the GVL.
|
|
68
71
|
let embeddings = (*args.embedder).embed_ref(&*args.texts)?;
|
|
69
|
-
if args.normalize {
|
|
72
|
+
if args.normalize {
|
|
73
|
+
Ok(normalize_l2(embeddings))
|
|
74
|
+
} else {
|
|
75
|
+
Ok(embeddings)
|
|
76
|
+
}
|
|
70
77
|
}));
|
|
71
78
|
args.result = Some(match run_result {
|
|
72
79
|
Ok(result) => result,
|
|
73
|
-
Err(payload) =>
|
|
74
|
-
"panic during inference: {}",
|
|
75
|
-
|
|
76
|
-
))),
|
|
80
|
+
Err(payload) => {
|
|
81
|
+
Err(GteError::Inference(format!("panic during inference: {}", panic_payload_to_string(payload),)))
|
|
82
|
+
}
|
|
77
83
|
});
|
|
78
84
|
std::ptr::null_mut()
|
|
79
85
|
}
|
|
80
86
|
|
|
81
87
|
unsafe extern "C" fn run_score_without_gvl(ptr: *mut c_void) -> *mut c_void {
|
|
82
88
|
let args = &mut *(ptr as *mut ScoreArgs);
|
|
83
|
-
let run_result =
|
|
84
|
-
(*args.reranker).score(&*args.query, &*args.candidates, args.apply_sigmoid)
|
|
85
|
-
}));
|
|
89
|
+
let run_result =
|
|
90
|
+
catch_unwind(AssertUnwindSafe(|| (*args.reranker).score(&*args.query, &*args.candidates, args.apply_sigmoid)));
|
|
86
91
|
args.result = Some(match run_result {
|
|
87
92
|
Ok(result) => result,
|
|
88
|
-
Err(payload) =>
|
|
89
|
-
"panic during reranking: {}",
|
|
90
|
-
|
|
91
|
-
))),
|
|
93
|
+
Err(payload) => {
|
|
94
|
+
Err(GteError::Inference(format!("panic during reranking: {}", panic_payload_to_string(payload),)))
|
|
95
|
+
}
|
|
92
96
|
});
|
|
93
97
|
std::ptr::null_mut()
|
|
94
98
|
}
|
|
@@ -99,23 +103,18 @@ fn infer_without_gvl(
|
|
|
99
103
|
texts: Vec<String>,
|
|
100
104
|
) -> Result<ndarray::Array2<f32>, Error> {
|
|
101
105
|
let embeddings = unsafe {
|
|
102
|
-
let mut args =
|
|
103
|
-
embedder: Arc::as_ptr(embedder),
|
|
104
|
-
texts: &texts as *const Vec<String>,
|
|
105
|
-
normalize,
|
|
106
|
-
result: None,
|
|
107
|
-
};
|
|
106
|
+
let mut args =
|
|
107
|
+
InferArgs { embedder: Arc::as_ptr(embedder), texts: &texts as *const Vec<String>, normalize, result: None };
|
|
108
108
|
rb_sys::rb_thread_call_without_gvl(
|
|
109
109
|
Some(run_embed_without_gvl),
|
|
110
110
|
&mut args as *mut InferArgs as *mut c_void,
|
|
111
111
|
None,
|
|
112
112
|
std::ptr::null_mut(),
|
|
113
113
|
);
|
|
114
|
-
let result = args
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
))
|
|
118
|
-
})?;
|
|
114
|
+
let result = args
|
|
115
|
+
.result
|
|
116
|
+
.take()
|
|
117
|
+
.ok_or_else(|| magnus::Error::from(GteError::Inference("inference did not return a result".to_string())))?;
|
|
119
118
|
result.map_err(magnus::Error::from)?
|
|
120
119
|
};
|
|
121
120
|
Ok(embeddings)
|
|
@@ -141,11 +140,10 @@ fn score_without_gvl(
|
|
|
141
140
|
None,
|
|
142
141
|
std::ptr::null_mut(),
|
|
143
142
|
);
|
|
144
|
-
let result = args
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
))
|
|
148
|
-
})?;
|
|
143
|
+
let result = args
|
|
144
|
+
.result
|
|
145
|
+
.take()
|
|
146
|
+
.ok_or_else(|| magnus::Error::from(GteError::Inference("reranking did not return a result".to_string())))?;
|
|
149
147
|
result.map_err(magnus::Error::from)?
|
|
150
148
|
};
|
|
151
149
|
Ok(scores)
|
|
@@ -158,10 +156,7 @@ fn tensor_from_array(embeddings: ndarray::Array2<f32>) -> Result<RbTensor, Error
|
|
|
158
156
|
let cols = embeddings.ncols();
|
|
159
157
|
let (data, offset) = embeddings.into_raw_vec_and_offset();
|
|
160
158
|
if let Some(off) = offset.filter(|&o| o != 0) {
|
|
161
|
-
return Err(magnus::Error::from(GteError::Inference(format!(
|
|
162
|
-
"unexpected non-zero tensor offset: {}",
|
|
163
|
-
off
|
|
164
|
-
))));
|
|
159
|
+
return Err(magnus::Error::from(GteError::Inference(format!("unexpected non-zero tensor offset: {}", off))));
|
|
165
160
|
}
|
|
166
161
|
Ok(RbTensor { rows, cols, data })
|
|
167
162
|
}
|
|
@@ -177,22 +172,28 @@ impl RbEmbedder {
|
|
|
177
172
|
max_length: usize,
|
|
178
173
|
padding: String,
|
|
179
174
|
execution_providers: String,
|
|
175
|
+
lowercase_input: bool,
|
|
176
|
+
max_input_chars: usize,
|
|
180
177
|
) -> Result<Self, Error> {
|
|
181
178
|
let name = if model_name.is_empty() { None } else { Some(model_name.as_str()) };
|
|
182
179
|
let output_override = if output_tensor.is_empty() { None } else { Some(output_tensor.as_str()) };
|
|
183
180
|
let max_length_override = if max_length == 0 { None } else { Some(max_length) };
|
|
184
|
-
let execution_providers_override =
|
|
181
|
+
let execution_providers_override =
|
|
182
|
+
if execution_providers.is_empty() { None } else { Some(execution_providers.as_str()) };
|
|
185
183
|
let padding_override = if padding.is_empty() { None } else { Some(padding.as_str()) };
|
|
184
|
+
let max_input_chars_override = if max_input_chars == 0 { None } else { Some(max_input_chars) };
|
|
186
185
|
let overrides = ModelLoadOverrides {
|
|
187
186
|
model_name: name,
|
|
188
187
|
output_tensor: output_override,
|
|
189
188
|
max_length: max_length_override,
|
|
190
189
|
padding: padding_override,
|
|
191
190
|
execution_providers: execution_providers_override,
|
|
191
|
+
lowercase_input: Some(lowercase_input),
|
|
192
|
+
max_input_chars: max_input_chars_override,
|
|
192
193
|
};
|
|
193
|
-
let embedder = Embedder::from_dir(&dir_path, optimization_level, overrides)
|
|
194
|
-
|
|
195
|
-
Ok(RbEmbedder { inner: Arc::new(embedder), normalize })
|
|
194
|
+
let embedder = Embedder::from_dir(&dir_path, optimization_level, overrides).map_err(magnus::Error::from)?;
|
|
195
|
+
let skip_normalize = normalize && output_name_suggests_normalized(&embedder.config.output_tensor);
|
|
196
|
+
Ok(RbEmbedder { inner: Arc::new(embedder), normalize: normalize && !skip_normalize })
|
|
196
197
|
}
|
|
197
198
|
|
|
198
199
|
pub fn rb_embed(_ruby: &Ruby, rb_self: &Self, texts: RArray) -> Result<RbTensor, Error> {
|
|
@@ -218,11 +219,14 @@ impl RbReranker {
|
|
|
218
219
|
max_length: usize,
|
|
219
220
|
padding: String,
|
|
220
221
|
execution_providers: String,
|
|
222
|
+
_lowercase_input: bool,
|
|
223
|
+
_max_input_chars: usize,
|
|
221
224
|
) -> Result<Self, Error> {
|
|
222
225
|
let name = if model_name.is_empty() { None } else { Some(model_name.as_str()) };
|
|
223
226
|
let output_override = if output_tensor.is_empty() { None } else { Some(output_tensor.as_str()) };
|
|
224
227
|
let max_length_override = if max_length == 0 { None } else { Some(max_length) };
|
|
225
|
-
let execution_providers_override =
|
|
228
|
+
let execution_providers_override =
|
|
229
|
+
if execution_providers.is_empty() { None } else { Some(execution_providers.as_str()) };
|
|
226
230
|
let padding_override = if padding.is_empty() { None } else { Some(padding.as_str()) };
|
|
227
231
|
let overrides = ModelLoadOverrides {
|
|
228
232
|
model_name: name,
|
|
@@ -230,18 +234,13 @@ impl RbReranker {
|
|
|
230
234
|
max_length: max_length_override,
|
|
231
235
|
padding: padding_override,
|
|
232
236
|
execution_providers: execution_providers_override,
|
|
237
|
+
..ModelLoadOverrides::default()
|
|
233
238
|
};
|
|
234
|
-
let reranker = Reranker::from_dir(&dir_path, optimization_level, overrides)
|
|
235
|
-
.map_err(magnus::Error::from)?;
|
|
239
|
+
let reranker = Reranker::from_dir(&dir_path, optimization_level, overrides).map_err(magnus::Error::from)?;
|
|
236
240
|
Ok(RbReranker { inner: Arc::new(reranker), sigmoid })
|
|
237
241
|
}
|
|
238
242
|
|
|
239
|
-
pub fn rb_score(
|
|
240
|
-
ruby: &Ruby,
|
|
241
|
-
rb_self: &Self,
|
|
242
|
-
query: String,
|
|
243
|
-
candidates: RArray,
|
|
244
|
-
) -> Result<RArray, Error> {
|
|
243
|
+
pub fn rb_score(ruby: &Ruby, rb_self: &Self, query: String, candidates: RArray) -> Result<RArray, Error> {
|
|
245
244
|
let candidates: Vec<String> = candidates.to_vec()?;
|
|
246
245
|
let scores = score_without_gvl(&rb_self.inner, query, candidates, rb_self.sigmoid)?;
|
|
247
246
|
let out = ruby.ary_new_capa(scores.len());
|
|
@@ -296,11 +295,7 @@ impl RbTensor {
|
|
|
296
295
|
Self::row_binary_f32(ruby, rb_self, 0)
|
|
297
296
|
}
|
|
298
297
|
|
|
299
|
-
pub fn row_binary_f32(
|
|
300
|
-
ruby: &Ruby,
|
|
301
|
-
rb_self: &Self,
|
|
302
|
-
index: usize,
|
|
303
|
-
) -> Result<magnus::RString, Error> {
|
|
298
|
+
pub fn row_binary_f32(ruby: &Ruby, rb_self: &Self, index: usize) -> Result<magnus::RString, Error> {
|
|
304
299
|
if index >= rb_self.rows {
|
|
305
300
|
return Err(magnus::Error::from(GteError::Inference(format!(
|
|
306
301
|
"row index {} out of bounds for {} rows",
|
|
@@ -340,12 +335,12 @@ impl RbTensor {
|
|
|
340
335
|
pub fn register(ruby: &Ruby) -> Result<(), Error> {
|
|
341
336
|
let module = ruby.define_module("GTE")?;
|
|
342
337
|
let embedder_class = module.define_class("Embedder", ruby.class_object())?;
|
|
343
|
-
embedder_class.define_singleton_method("new", function!(RbEmbedder::rb_new,
|
|
338
|
+
embedder_class.define_singleton_method("new", function!(RbEmbedder::rb_new, 10))?;
|
|
344
339
|
embedder_class.define_method("embed", method!(RbEmbedder::rb_embed, 1))?;
|
|
345
340
|
embedder_class.define_method("embed_one", method!(RbEmbedder::rb_embed_one, 1))?;
|
|
346
341
|
|
|
347
342
|
let reranker_class = module.define_class("Reranker", ruby.class_object())?;
|
|
348
|
-
reranker_class.define_singleton_method("new", function!(RbReranker::rb_new,
|
|
343
|
+
reranker_class.define_singleton_method("new", function!(RbReranker::rb_new, 10))?;
|
|
349
344
|
reranker_class.define_method("score", method!(RbReranker::rb_score, 2))?;
|
|
350
345
|
|
|
351
346
|
let tensor_class = module.define_class("Tensor", ruby.class_object())?;
|