gte 0.0.12 → 0.0.14
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/README.md +129 -26
- data/VERSION +1 -1
- data/ext/gte/Cargo.toml +26 -4
- data/ext/gte/benches/hot_path.rs +20 -54
- data/ext/gte/build.rs +2 -6
- data/ext/gte/rustfmt.toml +5 -0
- data/ext/gte/src/embedder.rs +71 -43
- data/ext/gte/src/error.rs +4 -4
- data/ext/gte/src/lib.rs +1 -1
- data/ext/gte/src/model_config.rs +4 -0
- data/ext/gte/src/model_profile.rs +26 -87
- data/ext/gte/src/pipeline.rs +11 -30
- data/ext/gte/src/postprocess.rs +8 -14
- data/ext/gte/src/reranker.rs +50 -50
- data/ext/gte/src/ruby_embedder.rs +48 -53
- data/ext/gte/src/session.rs +140 -249
- data/ext/gte/src/tokenizer.rs +51 -125
- data/ext/gte/tests/inference_integration_test.rs +8 -18
- data/ext/gte/tests/padding_regression_test.rs +13 -26
- data/ext/gte/tests/tokenizer_unit_test.rs +10 -24
- data/lib/gte/config.rb +2 -1
- data/lib/gte/embedder.rb +6 -2
- data/lib/gte/reranker.rb +3 -1
- data/lib/gte.rb +6 -0
- metadata +2 -1
data/ext/gte/src/session.rs
CHANGED
|
@@ -4,221 +4,151 @@ use crate::pipeline::{extract_output_tensor, InputTensors};
|
|
|
4
4
|
use crate::postprocess::mean_pool;
|
|
5
5
|
use crate::tokenizer::Tokenized;
|
|
6
6
|
use ndarray::{Array2, ArrayView2, ArrayViewD, Ix2};
|
|
7
|
-
use ort::execution_providers::{
|
|
8
|
-
|
|
9
|
-
|
|
10
|
-
use
|
|
7
|
+
use ort::execution_providers::{CoreMLExecutionProvider, ExecutionProviderDispatch, XNNPACKExecutionProvider};
|
|
8
|
+
use ort::session::{OutputSelector, RunOptions, Session};
|
|
9
|
+
use std::cell::RefCell;
|
|
10
|
+
use std::collections::hash_map::Entry;
|
|
11
|
+
use std::collections::HashMap;
|
|
11
12
|
use std::path::{Path, PathBuf};
|
|
12
13
|
use std::sync::atomic::{AtomicUsize, Ordering};
|
|
13
|
-
use std::sync::{Condvar, Mutex};
|
|
14
|
-
|
|
15
|
-
pub fn build_session<P: AsRef<Path>>(model_path: P, config: &ModelConfig) -> Result<Session> {
|
|
16
|
-
let opt_level = match config.optimization_level {
|
|
17
|
-
0 => ort::session::builder::GraphOptimizationLevel::Disable,
|
|
18
|
-
1 => ort::session::builder::GraphOptimizationLevel::Level1,
|
|
19
|
-
2 => ort::session::builder::GraphOptimizationLevel::Level2,
|
|
20
|
-
_ => ort::session::builder::GraphOptimizationLevel::Level3,
|
|
21
|
-
};
|
|
22
|
-
|
|
23
|
-
fn ort_err(e: impl std::fmt::Display) -> GteError {
|
|
24
|
-
GteError::Ort(e.to_string())
|
|
25
|
-
}
|
|
26
|
-
|
|
27
|
-
let mut builder = Session::builder()
|
|
28
|
-
.map_err(ort_err)?
|
|
29
|
-
.with_optimization_level(opt_level)
|
|
30
|
-
.map_err(ort_err)?;
|
|
31
|
-
|
|
32
|
-
let providers = preferred_execution_providers(config.execution_providers.as_deref());
|
|
33
|
-
if !providers.is_empty() {
|
|
34
|
-
builder = builder
|
|
35
|
-
.with_execution_providers(providers)
|
|
36
|
-
.map_err(ort_err)?;
|
|
37
|
-
}
|
|
38
|
-
|
|
39
|
-
builder.commit_from_file(model_path).map_err(ort_err)
|
|
40
|
-
}
|
|
41
14
|
|
|
42
15
|
// ---------------------------------------------------------------------------
|
|
43
|
-
//
|
|
16
|
+
// Thread-local session storage — each OS thread lazily creates its own ONNX
|
|
17
|
+
// session the first time it calls into a given pool. No Mutex, no contention.
|
|
44
18
|
// ---------------------------------------------------------------------------
|
|
45
19
|
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
20
|
+
static NEXT_POOL_ID: AtomicUsize = AtomicUsize::new(1);
|
|
21
|
+
|
|
22
|
+
struct SessionRecipe {
|
|
23
|
+
model_path: PathBuf,
|
|
24
|
+
build_config: ModelConfig,
|
|
51
25
|
}
|
|
52
26
|
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
let parsed = raw.trim().parse::<usize>().ok()?;
|
|
56
|
-
(parsed > 0).then_some(parsed)
|
|
27
|
+
thread_local! {
|
|
28
|
+
static SESSIONS: RefCell<HashMap<usize, Session>> = RefCell::new(HashMap::new());
|
|
57
29
|
}
|
|
58
30
|
|
|
59
31
|
pub struct SessionPool {
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
created: AtomicUsize,
|
|
63
|
-
capacity: usize,
|
|
64
|
-
model_path: PathBuf,
|
|
65
|
-
build_config: ModelConfig,
|
|
32
|
+
pool_id: usize,
|
|
33
|
+
recipe: SessionRecipe,
|
|
66
34
|
}
|
|
67
35
|
|
|
68
36
|
impl SessionPool {
|
|
69
|
-
pub fn new(initial: Session, model_path:
|
|
70
|
-
let
|
|
71
|
-
Self {
|
|
72
|
-
sessions: Mutex::new(vec![initial]),
|
|
73
|
-
available: Condvar::new(),
|
|
74
|
-
created: AtomicUsize::new(1),
|
|
75
|
-
capacity,
|
|
76
|
-
model_path,
|
|
77
|
-
build_config,
|
|
78
|
-
}
|
|
79
|
-
}
|
|
80
|
-
|
|
81
|
-
pub fn acquire(&self) -> Result<PooledSession<'_>> {
|
|
82
|
-
if let Some(session) = self.take_available_session() {
|
|
83
|
-
return Ok(PooledSession {
|
|
84
|
-
pool: self,
|
|
85
|
-
session: Some(session),
|
|
86
|
-
});
|
|
87
|
-
}
|
|
37
|
+
pub fn new(initial: Session, model_path: &Path, build_config: &ModelConfig) -> Result<Self> {
|
|
38
|
+
let pool_id = NEXT_POOL_ID.fetch_add(1, Ordering::Relaxed);
|
|
88
39
|
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
session: Some(session),
|
|
93
|
-
});
|
|
94
|
-
}
|
|
40
|
+
SESSIONS.with(|map| {
|
|
41
|
+
_ = map.borrow_mut().insert(pool_id, initial);
|
|
42
|
+
});
|
|
95
43
|
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
session: Some(session),
|
|
44
|
+
Ok(Self {
|
|
45
|
+
pool_id,
|
|
46
|
+
recipe: SessionRecipe { model_path: model_path.to_path_buf(), build_config: build_config.clone() },
|
|
100
47
|
})
|
|
101
48
|
}
|
|
102
49
|
|
|
103
|
-
fn
|
|
104
|
-
self.
|
|
105
|
-
self.available.notify_one();
|
|
50
|
+
pub fn run(&self, tokenized: &Tokenized, config: &ModelConfig) -> Result<Array2<f32>> {
|
|
51
|
+
self.with_session(|session| run_session(session, tokenized, config))
|
|
106
52
|
}
|
|
107
53
|
|
|
108
|
-
fn
|
|
109
|
-
|
|
54
|
+
pub fn with_session<F, R>(&self, f: F) -> Result<R>
|
|
55
|
+
where
|
|
56
|
+
F: FnOnce(&mut Session) -> Result<R>,
|
|
57
|
+
{
|
|
58
|
+
SESSIONS.with(|map| {
|
|
59
|
+
let mut map = map.borrow_mut();
|
|
60
|
+
let session = match map.entry(self.pool_id) {
|
|
61
|
+
Entry::Occupied(e) => e.into_mut(),
|
|
62
|
+
Entry::Vacant(e) => {
|
|
63
|
+
let session = build_session(&self.recipe.model_path, &self.recipe.build_config)?;
|
|
64
|
+
e.insert(session)
|
|
65
|
+
}
|
|
66
|
+
};
|
|
67
|
+
f(session)
|
|
68
|
+
})
|
|
110
69
|
}
|
|
70
|
+
}
|
|
111
71
|
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
.fetch_update(Ordering::AcqRel, Ordering::Acquire, |count| {
|
|
116
|
-
(count < self.capacity).then_some(count + 1)
|
|
117
|
-
});
|
|
118
|
-
if grew.is_err() {
|
|
119
|
-
return Ok(None);
|
|
120
|
-
}
|
|
72
|
+
// ---------------------------------------------------------------------------
|
|
73
|
+
// Session construction
|
|
74
|
+
// ---------------------------------------------------------------------------
|
|
121
75
|
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
self.created.fetch_sub(1, Ordering::AcqRel);
|
|
126
|
-
Err(error)
|
|
127
|
-
}
|
|
128
|
-
}
|
|
76
|
+
pub fn build_session<P: AsRef<Path>>(model_path: P, config: &ModelConfig) -> Result<Session> {
|
|
77
|
+
fn ort_err(e: impl std::fmt::Display) -> GteError {
|
|
78
|
+
GteError::Ort(e.to_string())
|
|
129
79
|
}
|
|
130
80
|
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
lock = self.available.wait(lock).unwrap();
|
|
138
|
-
}
|
|
139
|
-
}
|
|
140
|
-
}
|
|
81
|
+
let opt_level = match config.optimization_level {
|
|
82
|
+
0 => ort::session::builder::GraphOptimizationLevel::Disable,
|
|
83
|
+
1 => ort::session::builder::GraphOptimizationLevel::Level1,
|
|
84
|
+
2 => ort::session::builder::GraphOptimizationLevel::Level2,
|
|
85
|
+
_ => ort::session::builder::GraphOptimizationLevel::Level3,
|
|
86
|
+
};
|
|
141
87
|
|
|
142
|
-
|
|
143
|
-
pool: &'a SessionPool,
|
|
144
|
-
session: Option<Session>,
|
|
145
|
-
}
|
|
88
|
+
let mut builder = Session::builder().map_err(ort_err)?.with_optimization_level(opt_level).map_err(ort_err)?;
|
|
146
89
|
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
}
|
|
90
|
+
let intra_threads = std::env::var("GTE_INTRA_OP_NUM_THREADS")
|
|
91
|
+
.ok()
|
|
92
|
+
.and_then(|v| v.trim().parse::<usize>().ok())
|
|
93
|
+
.unwrap_or_else(|| std::thread::available_parallelism().map(|n| n.get().min(4)).unwrap_or(1));
|
|
94
|
+
builder = builder.with_intra_threads(intra_threads).map_err(ort_err)?;
|
|
153
95
|
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
}
|
|
158
|
-
}
|
|
96
|
+
let inter_threads =
|
|
97
|
+
std::env::var("GTE_INTER_OP_NUM_THREADS").ok().and_then(|v| v.trim().parse::<usize>().ok()).unwrap_or(1);
|
|
98
|
+
builder = builder.with_inter_threads(inter_threads).map_err(ort_err)?;
|
|
159
99
|
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
|
|
100
|
+
let providers = match config.execution_providers.as_deref() {
|
|
101
|
+
Some(override_val) => preferred_execution_providers(Some(override_val)),
|
|
102
|
+
None => auto_detect_providers(),
|
|
103
|
+
};
|
|
104
|
+
if !providers.is_empty() {
|
|
105
|
+
builder = builder.with_execution_providers(providers).map_err(ort_err)?;
|
|
165
106
|
}
|
|
166
|
-
}
|
|
167
|
-
|
|
168
|
-
// ---------------------------------------------------------------------------
|
|
169
107
|
|
|
170
|
-
|
|
171
|
-
|
|
108
|
+
builder.commit_from_file(model_path).map_err(ort_err)
|
|
109
|
+
}
|
|
172
110
|
|
|
111
|
+
fn auto_detect_providers() -> Vec<ExecutionProviderDispatch> {
|
|
173
112
|
let mut providers = Vec::new();
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
"xnnpack" => {
|
|
177
|
-
providers.push(XNNPACKExecutionProvider::default().build().fail_silently())
|
|
178
|
-
}
|
|
179
|
-
"coreml" => providers.push(CoreMLExecutionProvider::default().build().fail_silently()),
|
|
180
|
-
_ => {}
|
|
181
|
-
}
|
|
182
|
-
}
|
|
113
|
+
#[cfg(target_arch = "aarch64")]
|
|
114
|
+
providers.push(XNNPACKExecutionProvider::default().build().fail_silently());
|
|
183
115
|
providers
|
|
184
116
|
}
|
|
185
117
|
|
|
186
|
-
fn
|
|
187
|
-
let
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
fn resolve_provider_order_with_env(
|
|
192
|
-
order_override: Option<&str>,
|
|
193
|
-
env_order: Option<&str>,
|
|
194
|
-
) -> String {
|
|
195
|
-
order_override
|
|
196
|
-
.or(env_order)
|
|
197
|
-
.unwrap_or("cpu")
|
|
198
|
-
.to_ascii_lowercase()
|
|
199
|
-
}
|
|
118
|
+
fn preferred_execution_providers(order_override: Option<&str>) -> Vec<ExecutionProviderDispatch> {
|
|
119
|
+
let order = match order_override {
|
|
120
|
+
Some(s) => s.to_ascii_lowercase(),
|
|
121
|
+
None => return auto_detect_providers(),
|
|
122
|
+
};
|
|
200
123
|
|
|
201
|
-
|
|
202
|
-
|
|
203
|
-
for provider in order.split(',').map(str::trim).filter(|p| !p.is_empty()) {
|
|
204
|
-
match provider {
|
|
205
|
-
"xnnpack" | "coreml" => providers.push(provider),
|
|
206
|
-
"none" | "cpu" => {}
|
|
207
|
-
_ => {}
|
|
208
|
-
}
|
|
124
|
+
if order.is_empty() || order == "cpu" || order == "none" {
|
|
125
|
+
return Vec::new();
|
|
209
126
|
}
|
|
127
|
+
|
|
128
|
+
let providers: Vec<_> = order
|
|
129
|
+
.split(',')
|
|
130
|
+
.map(str::trim)
|
|
131
|
+
.filter(|p| !p.is_empty())
|
|
132
|
+
.filter_map(|provider| match provider {
|
|
133
|
+
"xnnpack" => Some(XNNPACKExecutionProvider::default().build().fail_silently()),
|
|
134
|
+
"coreml" => Some(CoreMLExecutionProvider::default().build().fail_silently()),
|
|
135
|
+
_ => None,
|
|
136
|
+
})
|
|
137
|
+
.collect();
|
|
210
138
|
providers
|
|
211
139
|
}
|
|
212
140
|
|
|
213
|
-
|
|
214
|
-
|
|
215
|
-
|
|
216
|
-
|
|
217
|
-
) -> Result<Array2<f32>> {
|
|
141
|
+
// ---------------------------------------------------------------------------
|
|
142
|
+
// Run a single inference
|
|
143
|
+
// ---------------------------------------------------------------------------
|
|
144
|
+
|
|
145
|
+
pub fn run_session(session: &mut Session, tokenized: &Tokenized, config: &ModelConfig) -> Result<Array2<f32>> {
|
|
218
146
|
let input_tensors = InputTensors::from_tokenized(tokenized, config.with_attention_mask)?;
|
|
219
|
-
let
|
|
220
|
-
.
|
|
221
|
-
.
|
|
147
|
+
let run_opts = RunOptions::new()
|
|
148
|
+
.map_err(|e| GteError::Ort(e.to_string()))?
|
|
149
|
+
.with_outputs(OutputSelector::no_default().with(config.output_tensor.as_str()));
|
|
150
|
+
let outputs =
|
|
151
|
+
session.run_with_options(input_tensors.inputs, &run_opts).map_err(|e| GteError::Ort(e.to_string()))?;
|
|
222
152
|
let array = extract_output_tensor(&outputs, config.output_tensor.as_str())?;
|
|
223
153
|
|
|
224
154
|
extract_embeddings(array, input_tensors.attention_mask, config)
|
|
@@ -234,26 +164,21 @@ fn extract_embeddings(
|
|
|
234
164
|
let shape = array.shape();
|
|
235
165
|
if shape.len() != 3 || idx >= shape[1] {
|
|
236
166
|
return Err(GteError::Inference(format!(
|
|
237
|
-
"token extraction index {} out of bounds for output shape {:?}"
|
|
238
|
-
idx, shape
|
|
167
|
+
"token extraction index {idx} out of bounds for output shape {shape:?}"
|
|
239
168
|
)));
|
|
240
169
|
}
|
|
241
170
|
Ok(array.slice(ndarray::s![.., idx, ..]).into_owned())
|
|
242
171
|
}
|
|
243
172
|
ExtractorMode::MeanPool => {
|
|
244
173
|
let ndim = array.ndim();
|
|
245
|
-
let hidden_states = array
|
|
246
|
-
|
|
247
|
-
|
|
248
|
-
ndim
|
|
249
|
-
))
|
|
250
|
-
})?;
|
|
174
|
+
let hidden_states = array
|
|
175
|
+
.into_dimensionality::<ndarray::Ix3>()
|
|
176
|
+
.map_err(|_| GteError::Inference(format!("mean pooling requires rank-3 output, got rank {ndim}")))?;
|
|
251
177
|
mean_pool(hidden_states, attention_mask)
|
|
252
178
|
}
|
|
253
|
-
ExtractorMode::Raw =>
|
|
254
|
-
.into_dimensionality::<Ix2>()
|
|
255
|
-
|
|
256
|
-
.map_err(|e| GteError::Shape(e.to_string())),
|
|
179
|
+
ExtractorMode::Raw => {
|
|
180
|
+
array.into_dimensionality::<Ix2>().map(|view| view.to_owned()).map_err(|e| GteError::Shape(e.to_string()))
|
|
181
|
+
}
|
|
257
182
|
}
|
|
258
183
|
}
|
|
259
184
|
|
|
@@ -262,10 +187,22 @@ mod tests {
|
|
|
262
187
|
use crate::model_config::{ExtractorMode, ModelConfig, PaddingMode};
|
|
263
188
|
use ndarray::{array, ArrayView2};
|
|
264
189
|
|
|
265
|
-
use super::
|
|
266
|
-
|
|
267
|
-
|
|
268
|
-
|
|
190
|
+
use super::extract_embeddings;
|
|
191
|
+
|
|
192
|
+
fn resolve_provider_order_with_env(order_override: Option<&str>, env_order: Option<&str>) -> String {
|
|
193
|
+
order_override.or(env_order).unwrap_or("cpu").to_ascii_lowercase()
|
|
194
|
+
}
|
|
195
|
+
|
|
196
|
+
fn parse_provider_registrations(order: &str) -> Vec<&str> {
|
|
197
|
+
let mut providers = Vec::new();
|
|
198
|
+
for provider in order.split(',').map(str::trim).filter(|p| !p.is_empty()) {
|
|
199
|
+
match provider {
|
|
200
|
+
"xnnpack" | "coreml" => providers.push(provider),
|
|
201
|
+
_ => {}
|
|
202
|
+
}
|
|
203
|
+
}
|
|
204
|
+
providers
|
|
205
|
+
}
|
|
269
206
|
|
|
270
207
|
fn test_config(mode: ExtractorMode) -> ModelConfig {
|
|
271
208
|
ModelConfig {
|
|
@@ -277,6 +214,8 @@ mod tests {
|
|
|
277
214
|
with_attention_mask: true,
|
|
278
215
|
optimization_level: 3,
|
|
279
216
|
execution_providers: None,
|
|
217
|
+
lowercase_input: false,
|
|
218
|
+
max_input_chars: None,
|
|
280
219
|
}
|
|
281
220
|
}
|
|
282
221
|
|
|
@@ -306,93 +245,45 @@ mod tests {
|
|
|
306
245
|
|
|
307
246
|
#[test]
|
|
308
247
|
fn resolve_provider_order_prefers_override() {
|
|
309
|
-
assert_eq!(
|
|
310
|
-
resolve_provider_order_with_env(Some("xnnpack"), Some("coreml")),
|
|
311
|
-
"xnnpack"
|
|
312
|
-
);
|
|
248
|
+
assert_eq!(resolve_provider_order_with_env(Some("xnnpack"), Some("coreml")), "xnnpack");
|
|
313
249
|
assert_eq!(resolve_provider_order_with_env(Some("CPU"), None), "cpu");
|
|
314
250
|
}
|
|
315
251
|
|
|
316
252
|
#[test]
|
|
317
253
|
fn resolve_provider_order_falls_back_to_env_then_cpu_default() {
|
|
318
|
-
assert_eq!(
|
|
319
|
-
resolve_provider_order_with_env(None, Some("coreml")),
|
|
320
|
-
"coreml"
|
|
321
|
-
);
|
|
254
|
+
assert_eq!(resolve_provider_order_with_env(None, Some("coreml")), "coreml");
|
|
322
255
|
assert_eq!(resolve_provider_order_with_env(None, None), "cpu");
|
|
323
256
|
}
|
|
324
257
|
|
|
325
|
-
#[test]
|
|
326
|
-
fn parse_pool_capacity_override_uses_positive_integer_only() {
|
|
327
|
-
unsafe {
|
|
328
|
-
std::env::remove_var("GTE_SESSION_POOL_CAP");
|
|
329
|
-
}
|
|
330
|
-
assert_eq!(parse_pool_capacity_override(), None);
|
|
331
|
-
|
|
332
|
-
unsafe {
|
|
333
|
-
std::env::set_var("GTE_SESSION_POOL_CAP", "0");
|
|
334
|
-
}
|
|
335
|
-
assert_eq!(parse_pool_capacity_override(), None);
|
|
336
|
-
|
|
337
|
-
unsafe {
|
|
338
|
-
std::env::set_var("GTE_SESSION_POOL_CAP", "4");
|
|
339
|
-
}
|
|
340
|
-
assert_eq!(parse_pool_capacity_override(), Some(4));
|
|
341
|
-
|
|
342
|
-
unsafe {
|
|
343
|
-
std::env::set_var("GTE_SESSION_POOL_CAP", "abc");
|
|
344
|
-
}
|
|
345
|
-
assert_eq!(parse_pool_capacity_override(), None);
|
|
346
|
-
|
|
347
|
-
unsafe {
|
|
348
|
-
std::env::remove_var("GTE_SESSION_POOL_CAP");
|
|
349
|
-
}
|
|
350
|
-
}
|
|
351
|
-
|
|
352
258
|
#[test]
|
|
353
259
|
fn extract_embeddings_raw_copies_only_final_matrix() {
|
|
354
260
|
let output = array![[1.0f32, 2.0], [3.0, 4.0]];
|
|
355
|
-
let extracted =
|
|
356
|
-
output.view().into_dyn(),
|
|
357
|
-
|
|
358
|
-
&test_config(ExtractorMode::Raw),
|
|
359
|
-
)
|
|
360
|
-
.unwrap();
|
|
261
|
+
let extracted =
|
|
262
|
+
extract_embeddings(output.view().into_dyn(), empty_attention_mask(), &test_config(ExtractorMode::Raw))
|
|
263
|
+
.unwrap();
|
|
361
264
|
|
|
362
265
|
assert_eq!(extracted, output);
|
|
363
266
|
}
|
|
364
267
|
|
|
365
268
|
#[test]
|
|
366
269
|
fn extract_embeddings_token_selects_without_copying_full_sequence() {
|
|
367
|
-
let output = array![
|
|
368
|
-
[[1.0f32, 2.0], [3.0, 4.0], [5.0, 6.0]],
|
|
369
|
-
[[7.0, 8.0], [9.0, 10.0], [11.0, 12.0]]
|
|
370
|
-
];
|
|
270
|
+
let output = array![[[1.0f32, 2.0], [3.0, 4.0], [5.0, 6.0]], [[7.0, 8.0], [9.0, 10.0], [11.0, 12.0]]];
|
|
371
271
|
let expected = array![[3.0f32, 4.0], [9.0, 10.0]];
|
|
372
|
-
let extracted =
|
|
373
|
-
output.view().into_dyn(),
|
|
374
|
-
|
|
375
|
-
&test_config(ExtractorMode::Token(1)),
|
|
376
|
-
)
|
|
377
|
-
.unwrap();
|
|
272
|
+
let extracted =
|
|
273
|
+
extract_embeddings(output.view().into_dyn(), empty_attention_mask(), &test_config(ExtractorMode::Token(1)))
|
|
274
|
+
.unwrap();
|
|
378
275
|
|
|
379
276
|
assert_eq!(extracted, expected);
|
|
380
277
|
}
|
|
381
278
|
|
|
382
279
|
#[test]
|
|
383
280
|
fn extract_embeddings_mean_pool_uses_output_view_and_attention_mask() {
|
|
384
|
-
let output = array![
|
|
385
|
-
[[1.0f32, 3.0], [5.0, 7.0], [100.0, 100.0]],
|
|
386
|
-
[[2.0, 4.0], [6.0, 8.0], [10.0, 12.0]]
|
|
387
|
-
];
|
|
281
|
+
let output = array![[[1.0f32, 3.0], [5.0, 7.0], [100.0, 100.0]], [[2.0, 4.0], [6.0, 8.0], [10.0, 12.0]]];
|
|
388
282
|
let attention_mask = array![[1_i64, 1, 0], [0, 1, 1]];
|
|
389
283
|
let expected = array![[3.0f32, 5.0], [8.0, 10.0]];
|
|
390
|
-
let extracted =
|
|
391
|
-
output.view().into_dyn(),
|
|
392
|
-
|
|
393
|
-
&test_config(ExtractorMode::MeanPool),
|
|
394
|
-
)
|
|
395
|
-
.unwrap();
|
|
284
|
+
let extracted =
|
|
285
|
+
extract_embeddings(output.view().into_dyn(), attention_mask.view(), &test_config(ExtractorMode::MeanPool))
|
|
286
|
+
.unwrap();
|
|
396
287
|
|
|
397
288
|
assert_eq!(extracted, expected);
|
|
398
289
|
}
|