gte 0.0.6 → 0.0.8
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/README.md +16 -8
- data/Rakefile +38 -3
- data/VERSION +1 -1
- data/ext/gte/Cargo.toml +4 -4
- data/ext/gte/src/embedder.rs +42 -33
- data/ext/gte/src/model_config.rs +18 -0
- data/ext/gte/src/model_profile.rs +129 -33
- data/ext/gte/src/pipeline.rs +12 -9
- data/ext/gte/src/reranker.rs +49 -31
- data/ext/gte/src/ruby_embedder.rs +73 -113
- data/ext/gte/src/session.rs +279 -15
- data/ext/gte/src/tokenizer.rs +99 -14
- data/ext/gte/tests/inference_integration_test.rs +5 -4
- data/ext/gte/tests/tokenizer_unit_test.rs +5 -2
- data/lib/gte/config.rb +2 -2
- data/lib/gte/embedder.rb +7 -4
- data/lib/gte/reranker.rb +3 -1
- data/lib/gte.rb +1 -10
- metadata +6 -6
data/ext/gte/src/reranker.rs
CHANGED
|
@@ -1,19 +1,19 @@
|
|
|
1
1
|
use crate::error::{GteError, Result};
|
|
2
|
+
use crate::model_config::{ModelLoadOverrides, PaddingMode};
|
|
2
3
|
use crate::model_profile::{
|
|
3
|
-
has_input,
|
|
4
|
-
select_output_tensor, validate_supported_text_inputs,
|
|
4
|
+
has_input, read_tokenizer_profile, resolve_default_text_model, resolve_named_model,
|
|
5
|
+
resolve_tokenizer_path, select_output_tensor, validate_supported_text_inputs,
|
|
5
6
|
};
|
|
6
7
|
use crate::pipeline::{extract_output_tensor, InputTensors};
|
|
7
8
|
use crate::postprocess::sigmoid_scores;
|
|
8
|
-
use crate::session::build_session;
|
|
9
|
-
use crate::tokenizer::Tokenizer;
|
|
10
|
-
use
|
|
11
|
-
use ort::session::Session;
|
|
12
|
-
use std::path::Path;
|
|
9
|
+
use crate::session::{build_session, SessionPool};
|
|
10
|
+
use crate::tokenizer::{parse_padding_mode_override, Tokenizer};
|
|
11
|
+
use std::path::{Path, PathBuf};
|
|
13
12
|
|
|
14
13
|
#[derive(Debug, Clone)]
|
|
15
14
|
struct RerankerConfig {
|
|
16
15
|
max_length: usize,
|
|
16
|
+
padding_mode: PaddingMode,
|
|
17
17
|
output_tensor: String,
|
|
18
18
|
with_type_ids: bool,
|
|
19
19
|
with_attention_mask: bool,
|
|
@@ -21,7 +21,7 @@ struct RerankerConfig {
|
|
|
21
21
|
|
|
22
22
|
pub struct Reranker {
|
|
23
23
|
tokenizer: Tokenizer,
|
|
24
|
-
|
|
24
|
+
pool: SessionPool,
|
|
25
25
|
config: RerankerConfig,
|
|
26
26
|
}
|
|
27
27
|
|
|
@@ -30,70 +30,89 @@ impl Reranker {
|
|
|
30
30
|
dir: P,
|
|
31
31
|
num_threads: usize,
|
|
32
32
|
optimization_level: u8,
|
|
33
|
-
|
|
34
|
-
output_tensor_override: Option<&str>,
|
|
35
|
-
max_length_override: Option<usize>,
|
|
36
|
-
execution_providers_override: Option<&str>,
|
|
33
|
+
overrides: ModelLoadOverrides<'_>,
|
|
37
34
|
) -> Result<Self> {
|
|
38
35
|
let dir = dir.as_ref();
|
|
39
36
|
let tokenizer_path = resolve_tokenizer_path(dir)?;
|
|
40
|
-
let model_path = match model_name.filter(|s| !s.is_empty()) {
|
|
37
|
+
let model_path: PathBuf = match overrides.model_name.filter(|s| !s.is_empty()) {
|
|
41
38
|
Some(name) => resolve_named_model(dir, name)?,
|
|
42
39
|
None => resolve_default_text_model(dir)?,
|
|
43
40
|
};
|
|
44
41
|
|
|
45
|
-
let
|
|
42
|
+
let tokenizer_profile = read_tokenizer_profile(dir);
|
|
43
|
+
let max_length = if let Some(override_value) = overrides.max_length {
|
|
46
44
|
if override_value == 0 {
|
|
47
45
|
return Err(GteError::Inference(
|
|
48
46
|
"max_length override must be greater than 0".to_string(),
|
|
49
47
|
));
|
|
50
48
|
}
|
|
51
|
-
override_value
|
|
49
|
+
override_value.min(tokenizer_profile.safe_max_length)
|
|
52
50
|
} else {
|
|
53
|
-
|
|
51
|
+
tokenizer_profile.default_max_length
|
|
54
52
|
};
|
|
53
|
+
let padding_mode =
|
|
54
|
+
parse_padding_mode_override(overrides.padding)?.unwrap_or(PaddingMode::Auto);
|
|
55
55
|
|
|
56
56
|
let probe_config = crate::model_config::ModelConfig {
|
|
57
57
|
max_length,
|
|
58
|
+
padding_mode,
|
|
58
59
|
output_tensor: String::new(),
|
|
59
60
|
mode: crate::model_config::ExtractorMode::Raw,
|
|
60
61
|
with_type_ids: false,
|
|
61
62
|
with_attention_mask: true,
|
|
62
63
|
num_threads,
|
|
63
64
|
optimization_level,
|
|
64
|
-
execution_providers:
|
|
65
|
+
execution_providers: overrides.execution_providers.map(str::to_string),
|
|
65
66
|
};
|
|
66
67
|
let session = build_session(&model_path, &probe_config)?;
|
|
67
68
|
|
|
68
69
|
validate_supported_text_inputs(&session, "text reranking")?;
|
|
69
70
|
let with_type_ids = has_input(&session, "token_type_ids");
|
|
70
71
|
let with_attention_mask = has_input(&session, "attention_mask");
|
|
71
|
-
let output_tensor = select_output_tensor(&session,
|
|
72
|
+
let output_tensor = select_output_tensor(&session, overrides.output_tensor, &["logits"])?;
|
|
72
73
|
|
|
73
74
|
let config = RerankerConfig {
|
|
74
75
|
max_length,
|
|
76
|
+
padding_mode,
|
|
75
77
|
output_tensor,
|
|
76
78
|
with_type_ids,
|
|
77
79
|
with_attention_mask,
|
|
78
80
|
};
|
|
79
81
|
|
|
80
|
-
let tokenizer = Tokenizer::new(
|
|
82
|
+
let tokenizer = Tokenizer::new(
|
|
83
|
+
&tokenizer_path,
|
|
84
|
+
config.max_length,
|
|
85
|
+
config.with_type_ids,
|
|
86
|
+
config.padding_mode,
|
|
87
|
+
tokenizer_profile.fixed_padding_length,
|
|
88
|
+
)?;
|
|
81
89
|
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
session,
|
|
85
|
-
config,
|
|
86
|
-
})
|
|
90
|
+
let pool = SessionPool::new(session, model_path, probe_config);
|
|
91
|
+
Ok(Self { tokenizer, pool, config })
|
|
87
92
|
}
|
|
88
93
|
|
|
89
|
-
pub fn score_pairs(&self, pairs: &[(String, String)], apply_sigmoid: bool) -> Result<
|
|
94
|
+
pub fn score_pairs(&self, pairs: &[(String, String)], apply_sigmoid: bool) -> Result<Vec<f32>> {
|
|
90
95
|
let tokenized = self.tokenizer.tokenize_pairs(pairs)?;
|
|
91
|
-
|
|
92
|
-
|
|
96
|
+
self.score_tokenized(&tokenized, apply_sigmoid)
|
|
97
|
+
}
|
|
98
|
+
|
|
99
|
+
pub fn score(&self, query: &str, candidates: &[String], apply_sigmoid: bool) -> Result<Vec<f32>> {
|
|
100
|
+
let tokenized = self.tokenizer.tokenize_query_candidates(query, candidates)?;
|
|
101
|
+
self.score_tokenized(&tokenized, apply_sigmoid)
|
|
102
|
+
}
|
|
103
|
+
|
|
104
|
+
fn score_tokenized(
|
|
105
|
+
&self,
|
|
106
|
+
tokenized: &crate::tokenizer::Tokenized,
|
|
107
|
+
apply_sigmoid: bool,
|
|
108
|
+
) -> Result<Vec<f32>> {
|
|
109
|
+
let input_tensors = InputTensors::from_tokenized(tokenized, self.config.with_attention_mask)?;
|
|
110
|
+
let mut session = self.pool.acquire()?;
|
|
111
|
+
let outputs = session.run(input_tensors.inputs).map_err(|e| GteError::Ort(e.to_string()))?;
|
|
93
112
|
let array = extract_output_tensor(&outputs, self.config.output_tensor.as_str())?;
|
|
94
113
|
|
|
95
114
|
let mut scores = match array.ndim() {
|
|
96
|
-
1 => array.into_dimensionality::<ndarray::Ix1>()?.
|
|
115
|
+
1 => array.into_dimensionality::<ndarray::Ix1>()?.to_vec(),
|
|
97
116
|
2 => {
|
|
98
117
|
let shape = array.shape();
|
|
99
118
|
if shape[1] == 0 {
|
|
@@ -102,7 +121,7 @@ impl Reranker {
|
|
|
102
121
|
self.config.output_tensor, shape
|
|
103
122
|
)));
|
|
104
123
|
}
|
|
105
|
-
array.slice(ndarray::s![.., 0]).
|
|
124
|
+
array.slice(ndarray::s![.., 0]).to_vec()
|
|
106
125
|
}
|
|
107
126
|
n => {
|
|
108
127
|
return Err(GteError::Inference(format!(
|
|
@@ -113,10 +132,9 @@ impl Reranker {
|
|
|
113
132
|
};
|
|
114
133
|
|
|
115
134
|
if apply_sigmoid {
|
|
116
|
-
sigmoid_scores(scores.
|
|
135
|
+
sigmoid_scores(ndarray::ArrayViewMut1::from(scores.as_mut_slice()));
|
|
117
136
|
}
|
|
118
137
|
|
|
119
138
|
Ok(scores)
|
|
120
139
|
}
|
|
121
|
-
|
|
122
140
|
}
|
|
data/ext/gte/src/ruby_embedder.rs
CHANGED
|
@@ -2,6 +2,7 @@
|
|
|
2
2
|
|
|
3
3
|
use crate::embedder::{normalize_l2, Embedder};
|
|
4
4
|
use crate::error::GteError;
|
|
5
|
+
use crate::model_config::ModelLoadOverrides;
|
|
5
6
|
use crate::reranker::Reranker;
|
|
6
7
|
use magnus::{function, method, prelude::*, wrap, Error, RArray, Ruby};
|
|
7
8
|
use std::os::raw::c_void;
|
|
@@ -27,11 +28,15 @@ pub struct RbTensor {
|
|
|
27
28
|
data: Vec<f32>,
|
|
28
29
|
}
|
|
29
30
|
|
|
31
|
+
// ---------------------------------------------------------------------------
|
|
32
|
+
// GVL-release helpers
|
|
33
|
+
// ---------------------------------------------------------------------------
|
|
34
|
+
|
|
30
35
|
struct InferArgs {
|
|
31
36
|
embedder: *const Embedder,
|
|
32
37
|
texts: *const Vec<String>,
|
|
33
38
|
normalize: bool,
|
|
34
|
-
result: Option<Result<ndarray::Array2<f32
|
|
39
|
+
result: Option<crate::error::Result<ndarray::Array2<f32>>>,
|
|
35
40
|
}
|
|
36
41
|
|
|
37
42
|
unsafe impl Send for InferArgs {}
|
|
@@ -40,7 +45,7 @@ struct ScoreArgs {
|
|
|
40
45
|
reranker: *const Reranker,
|
|
41
46
|
pairs: *const Vec<(String, String)>,
|
|
42
47
|
apply_sigmoid: bool,
|
|
43
|
-
result: Option<Result<Vec<f32
|
|
48
|
+
result: Option<crate::error::Result<Vec<f32>>>,
|
|
44
49
|
}
|
|
45
50
|
|
|
46
51
|
unsafe impl Send for ScoreArgs {}
|
|
@@ -55,6 +60,38 @@ fn panic_payload_to_string(payload: Box<dyn std::any::Any + Send>) -> String {
|
|
|
55
60
|
}
|
|
56
61
|
}
|
|
57
62
|
|
|
63
|
+
unsafe extern "C" fn run_without_gvl(ptr: *mut c_void) -> *mut c_void {
|
|
64
|
+
let args = &mut *(ptr as *mut InferArgs);
|
|
65
|
+
let run_result = catch_unwind(AssertUnwindSafe(|| {
|
|
66
|
+
let tokenized = (*args.embedder).tokenize(&*args.texts)?;
|
|
67
|
+
let embeddings = (*args.embedder).run(&tokenized)?;
|
|
68
|
+
if args.normalize { Ok(normalize_l2(embeddings)) } else { Ok(embeddings) }
|
|
69
|
+
}));
|
|
70
|
+
args.result = Some(match run_result {
|
|
71
|
+
Ok(result) => result,
|
|
72
|
+
Err(payload) => Err(GteError::Inference(format!(
|
|
73
|
+
"panic during inference: {}",
|
|
74
|
+
panic_payload_to_string(payload),
|
|
75
|
+
))),
|
|
76
|
+
});
|
|
77
|
+
std::ptr::null_mut()
|
|
78
|
+
}
|
|
79
|
+
|
|
80
|
+
unsafe extern "C" fn run_score_without_gvl(ptr: *mut c_void) -> *mut c_void {
|
|
81
|
+
let args = &mut *(ptr as *mut ScoreArgs);
|
|
82
|
+
let run_result = catch_unwind(AssertUnwindSafe(|| {
|
|
83
|
+
(*args.reranker).score_pairs(&*args.pairs, args.apply_sigmoid)
|
|
84
|
+
}));
|
|
85
|
+
args.result = Some(match run_result {
|
|
86
|
+
Ok(result) => result,
|
|
87
|
+
Err(payload) => Err(GteError::Inference(format!(
|
|
88
|
+
"panic during reranking: {}",
|
|
89
|
+
panic_payload_to_string(payload),
|
|
90
|
+
))),
|
|
91
|
+
});
|
|
92
|
+
std::ptr::null_mut()
|
|
93
|
+
}
|
|
94
|
+
|
|
58
95
|
fn infer_without_gvl(
|
|
59
96
|
embedder: &Arc<Embedder>,
|
|
60
97
|
normalize: bool,
|
|
@@ -111,42 +148,7 @@ fn score_without_gvl(
|
|
|
111
148
|
Ok(scores)
|
|
112
149
|
}
|
|
113
150
|
|
|
114
|
-
|
|
115
|
-
let args = &mut *(ptr as *mut InferArgs);
|
|
116
|
-
let run_result = catch_unwind(AssertUnwindSafe(|| {
|
|
117
|
-
let tokenized = (*args.embedder).tokenize(&*args.texts)?;
|
|
118
|
-
let embeddings = (*args.embedder).run(&tokenized)?;
|
|
119
|
-
if args.normalize {
|
|
120
|
-
Ok(normalize_l2(embeddings))
|
|
121
|
-
} else {
|
|
122
|
-
Ok(embeddings)
|
|
123
|
-
}
|
|
124
|
-
}));
|
|
125
|
-
args.result = Some(match run_result {
|
|
126
|
-
Ok(result) => result,
|
|
127
|
-
Err(payload) => Err(GteError::Inference(format!(
|
|
128
|
-
"panic during inference: {}",
|
|
129
|
-
panic_payload_to_string(payload),
|
|
130
|
-
))),
|
|
131
|
-
});
|
|
132
|
-
std::ptr::null_mut()
|
|
133
|
-
}
|
|
134
|
-
|
|
135
|
-
unsafe extern "C" fn run_score_without_gvl(ptr: *mut c_void) -> *mut c_void {
|
|
136
|
-
let args = &mut *(ptr as *mut ScoreArgs);
|
|
137
|
-
let run_result = catch_unwind(AssertUnwindSafe(|| {
|
|
138
|
-
let scores = (*args.reranker).score_pairs(&*args.pairs, args.apply_sigmoid)?;
|
|
139
|
-
Ok(scores.to_vec())
|
|
140
|
-
}));
|
|
141
|
-
args.result = Some(match run_result {
|
|
142
|
-
Ok(result) => result,
|
|
143
|
-
Err(payload) => Err(GteError::Inference(format!(
|
|
144
|
-
"panic during reranking: {}",
|
|
145
|
-
panic_payload_to_string(payload),
|
|
146
|
-
))),
|
|
147
|
-
});
|
|
148
|
-
std::ptr::null_mut()
|
|
149
|
-
}
|
|
151
|
+
// ---------------------------------------------------------------------------
|
|
150
152
|
|
|
151
153
|
fn tensor_from_array(embeddings: ndarray::Array2<f32>) -> Result<RbTensor, Error> {
|
|
152
154
|
let rows = embeddings.nrows();
|
|
@@ -171,42 +173,24 @@ impl RbEmbedder {
|
|
|
171
173
|
normalize: bool,
|
|
172
174
|
output_tensor: String,
|
|
173
175
|
max_length: usize,
|
|
176
|
+
padding: String,
|
|
174
177
|
execution_providers: String,
|
|
175
178
|
) -> Result<Self, Error> {
|
|
176
|
-
let name = if model_name.is_empty() {
|
|
177
|
-
|
|
178
|
-
} else {
|
|
179
|
-
|
|
180
|
-
};
|
|
181
|
-
let
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
None
|
|
188
|
-
} else {
|
|
189
|
-
Some(max_length)
|
|
179
|
+
let name = if model_name.is_empty() { None } else { Some(model_name.as_str()) };
|
|
180
|
+
let output_override = if output_tensor.is_empty() { None } else { Some(output_tensor.as_str()) };
|
|
181
|
+
let max_length_override = if max_length == 0 { None } else { Some(max_length) };
|
|
182
|
+
let execution_providers_override = if execution_providers.is_empty() { None } else { Some(execution_providers.as_str()) };
|
|
183
|
+
let padding_override = if padding.is_empty() { None } else { Some(padding.as_str()) };
|
|
184
|
+
let overrides = ModelLoadOverrides {
|
|
185
|
+
model_name: name,
|
|
186
|
+
output_tensor: output_override,
|
|
187
|
+
max_length: max_length_override,
|
|
188
|
+
padding: padding_override,
|
|
189
|
+
execution_providers: execution_providers_override,
|
|
190
190
|
};
|
|
191
|
-
let
|
|
192
|
-
|
|
193
|
-
|
|
194
|
-
Some(execution_providers.as_str())
|
|
195
|
-
};
|
|
196
|
-
let embedder = Embedder::from_dir(
|
|
197
|
-
&dir_path,
|
|
198
|
-
num_threads,
|
|
199
|
-
optimization_level,
|
|
200
|
-
name,
|
|
201
|
-
output_override,
|
|
202
|
-
max_length_override,
|
|
203
|
-
execution_providers_override,
|
|
204
|
-
)
|
|
205
|
-
.map_err(magnus::Error::from)?;
|
|
206
|
-
Ok(RbEmbedder {
|
|
207
|
-
inner: Arc::new(embedder),
|
|
208
|
-
normalize,
|
|
209
|
-
})
|
|
191
|
+
let embedder = Embedder::from_dir(&dir_path, num_threads, optimization_level, overrides)
|
|
192
|
+
.map_err(magnus::Error::from)?;
|
|
193
|
+
Ok(RbEmbedder { inner: Arc::new(embedder), normalize })
|
|
210
194
|
}
|
|
211
195
|
|
|
212
196
|
pub fn rb_embed(_ruby: &Ruby, rb_self: &Self, texts: RArray) -> Result<RbTensor, Error> {
|
|
@@ -231,42 +215,24 @@ impl RbReranker {
|
|
|
231
215
|
sigmoid: bool,
|
|
232
216
|
output_tensor: String,
|
|
233
217
|
max_length: usize,
|
|
218
|
+
padding: String,
|
|
234
219
|
execution_providers: String,
|
|
235
220
|
) -> Result<Self, Error> {
|
|
236
|
-
let name = if model_name.is_empty() {
|
|
237
|
-
|
|
238
|
-
} else {
|
|
239
|
-
|
|
240
|
-
};
|
|
241
|
-
let
|
|
242
|
-
|
|
243
|
-
|
|
244
|
-
|
|
245
|
-
|
|
246
|
-
|
|
247
|
-
None
|
|
248
|
-
} else {
|
|
249
|
-
Some(max_length)
|
|
221
|
+
let name = if model_name.is_empty() { None } else { Some(model_name.as_str()) };
|
|
222
|
+
let output_override = if output_tensor.is_empty() { None } else { Some(output_tensor.as_str()) };
|
|
223
|
+
let max_length_override = if max_length == 0 { None } else { Some(max_length) };
|
|
224
|
+
let execution_providers_override = if execution_providers.is_empty() { None } else { Some(execution_providers.as_str()) };
|
|
225
|
+
let padding_override = if padding.is_empty() { None } else { Some(padding.as_str()) };
|
|
226
|
+
let overrides = ModelLoadOverrides {
|
|
227
|
+
model_name: name,
|
|
228
|
+
output_tensor: output_override,
|
|
229
|
+
max_length: max_length_override,
|
|
230
|
+
padding: padding_override,
|
|
231
|
+
execution_providers: execution_providers_override,
|
|
250
232
|
};
|
|
251
|
-
let
|
|
252
|
-
|
|
253
|
-
|
|
254
|
-
Some(execution_providers.as_str())
|
|
255
|
-
};
|
|
256
|
-
let reranker = Reranker::from_dir(
|
|
257
|
-
&dir_path,
|
|
258
|
-
num_threads,
|
|
259
|
-
optimization_level,
|
|
260
|
-
name,
|
|
261
|
-
output_override,
|
|
262
|
-
max_length_override,
|
|
263
|
-
execution_providers_override,
|
|
264
|
-
)
|
|
265
|
-
.map_err(magnus::Error::from)?;
|
|
266
|
-
Ok(RbReranker {
|
|
267
|
-
inner: Arc::new(reranker),
|
|
268
|
-
sigmoid,
|
|
269
|
-
})
|
|
233
|
+
let reranker = Reranker::from_dir(&dir_path, num_threads, optimization_level, overrides)
|
|
234
|
+
.map_err(magnus::Error::from)?;
|
|
235
|
+
Ok(RbReranker { inner: Arc::new(reranker), sigmoid })
|
|
270
236
|
}
|
|
271
237
|
|
|
272
238
|
pub fn rb_score(
|
|
@@ -276,12 +242,8 @@ impl RbReranker {
|
|
|
276
242
|
candidates: RArray,
|
|
277
243
|
) -> Result<RArray, Error> {
|
|
278
244
|
let candidates: Vec<String> = candidates.to_vec()?;
|
|
279
|
-
let pairs: Vec<(String, String)> = candidates
|
|
280
|
-
.into_iter()
|
|
281
|
-
.map(|candidate| (query.clone(), candidate))
|
|
282
|
-
.collect();
|
|
245
|
+
let pairs: Vec<(String, String)> = candidates.into_iter().map(|c| (query.clone(), c)).collect();
|
|
283
246
|
let scores = score_without_gvl(&rb_self.inner, pairs, rb_self.sigmoid)?;
|
|
284
|
-
|
|
285
247
|
let out = ruby.ary_new_capa(scores.len());
|
|
286
248
|
for score in scores {
|
|
287
249
|
out.push(score)?;
|
|
@@ -317,7 +279,6 @@ impl RbTensor {
|
|
|
317
279
|
index, rb_self.rows
|
|
318
280
|
))));
|
|
319
281
|
}
|
|
320
|
-
|
|
321
282
|
let start = index * rb_self.cols;
|
|
322
283
|
let end = start + rb_self.cols;
|
|
323
284
|
let out = ruby.ary_new_capa(rb_self.cols);
|
|
@@ -342,7 +303,6 @@ impl RbTensor {
|
|
|
342
303
|
index, rb_self.rows
|
|
343
304
|
))));
|
|
344
305
|
}
|
|
345
|
-
|
|
346
306
|
let start = index * rb_self.cols;
|
|
347
307
|
let end = start + rb_self.cols;
|
|
348
308
|
let bytes = unsafe {
|
|
@@ -376,12 +336,12 @@ impl RbTensor {
|
|
|
376
336
|
pub fn register(ruby: &Ruby) -> Result<(), Error> {
|
|
377
337
|
let module = ruby.define_module("GTE")?;
|
|
378
338
|
let embedder_class = module.define_class("Embedder", ruby.class_object())?;
|
|
379
|
-
embedder_class.define_singleton_method("new", function!(RbEmbedder::rb_new,
|
|
339
|
+
embedder_class.define_singleton_method("new", function!(RbEmbedder::rb_new, 9))?;
|
|
380
340
|
embedder_class.define_method("embed", method!(RbEmbedder::rb_embed, 1))?;
|
|
381
341
|
embedder_class.define_method("embed_one", method!(RbEmbedder::rb_embed_one, 1))?;
|
|
382
342
|
|
|
383
343
|
let reranker_class = module.define_class("Reranker", ruby.class_object())?;
|
|
384
|
-
reranker_class.define_singleton_method("new", function!(RbReranker::rb_new,
|
|
344
|
+
reranker_class.define_singleton_method("new", function!(RbReranker::rb_new, 9))?;
|
|
385
345
|
reranker_class.define_method("score", method!(RbReranker::rb_score, 2))?;
|
|
386
346
|
|
|
387
347
|
let tensor_class = module.define_class("Tensor", ruby.class_object())?;
|