gte 0.0.13 → 0.0.14
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/README.md +93 -27
- data/VERSION +1 -1
- data/ext/gte/Cargo.toml +26 -4
- data/ext/gte/benches/hot_path.rs +20 -54
- data/ext/gte/build.rs +2 -6
- data/ext/gte/rustfmt.toml +5 -0
- data/ext/gte/src/embedder.rs +71 -43
- data/ext/gte/src/error.rs +4 -4
- data/ext/gte/src/lib.rs +1 -1
- data/ext/gte/src/model_config.rs +4 -0
- data/ext/gte/src/model_profile.rs +26 -87
- data/ext/gte/src/pipeline.rs +11 -30
- data/ext/gte/src/postprocess.rs +8 -14
- data/ext/gte/src/reranker.rs +50 -50
- data/ext/gte/src/ruby_embedder.rs +48 -53
- data/ext/gte/src/session.rs +136 -248
- data/ext/gte/src/tokenizer.rs +51 -125
- data/ext/gte/tests/inference_integration_test.rs +8 -18
- data/ext/gte/tests/padding_regression_test.rs +13 -26
- data/ext/gte/tests/tokenizer_unit_test.rs +10 -24
- data/lib/gte/config.rb +2 -1
- data/lib/gte/embedder.rb +6 -2
- data/lib/gte/reranker.rb +3 -1
- data/lib/gte.rb +6 -0
- metadata +2 -1
data/ext/gte/src/session.rs
CHANGED
|
@@ -4,224 +4,151 @@ use crate::pipeline::{extract_output_tensor, InputTensors};
|
|
|
4
4
|
use crate::postprocess::mean_pool;
|
|
5
5
|
use crate::tokenizer::Tokenized;
|
|
6
6
|
use ndarray::{Array2, ArrayView2, ArrayViewD, Ix2};
|
|
7
|
-
use ort::execution_providers::{
|
|
8
|
-
CoreMLExecutionProvider, ExecutionProviderDispatch, XNNPACKExecutionProvider,
|
|
9
|
-
};
|
|
7
|
+
use ort::execution_providers::{CoreMLExecutionProvider, ExecutionProviderDispatch, XNNPACKExecutionProvider};
|
|
10
8
|
use ort::session::{OutputSelector, RunOptions, Session};
|
|
9
|
+
use std::cell::RefCell;
|
|
10
|
+
use std::collections::hash_map::Entry;
|
|
11
|
+
use std::collections::HashMap;
|
|
11
12
|
use std::path::{Path, PathBuf};
|
|
12
13
|
use std::sync::atomic::{AtomicUsize, Ordering};
|
|
13
|
-
use std::sync::{Condvar, Mutex};
|
|
14
|
-
|
|
15
|
-
pub fn build_session<P: AsRef<Path>>(model_path: P, config: &ModelConfig) -> Result<Session> {
|
|
16
|
-
let opt_level = match config.optimization_level {
|
|
17
|
-
0 => ort::session::builder::GraphOptimizationLevel::Disable,
|
|
18
|
-
1 => ort::session::builder::GraphOptimizationLevel::Level1,
|
|
19
|
-
2 => ort::session::builder::GraphOptimizationLevel::Level2,
|
|
20
|
-
_ => ort::session::builder::GraphOptimizationLevel::Level3,
|
|
21
|
-
};
|
|
22
|
-
|
|
23
|
-
fn ort_err(e: impl std::fmt::Display) -> GteError {
|
|
24
|
-
GteError::Ort(e.to_string())
|
|
25
|
-
}
|
|
26
|
-
|
|
27
|
-
let mut builder = Session::builder()
|
|
28
|
-
.map_err(ort_err)?
|
|
29
|
-
.with_optimization_level(opt_level)
|
|
30
|
-
.map_err(ort_err)?;
|
|
31
|
-
|
|
32
|
-
let providers = preferred_execution_providers(config.execution_providers.as_deref());
|
|
33
|
-
if !providers.is_empty() {
|
|
34
|
-
builder = builder
|
|
35
|
-
.with_execution_providers(providers)
|
|
36
|
-
.map_err(ort_err)?;
|
|
37
|
-
}
|
|
38
|
-
|
|
39
|
-
builder.commit_from_file(model_path).map_err(ort_err)
|
|
40
|
-
}
|
|
41
14
|
|
|
42
15
|
// ---------------------------------------------------------------------------
|
|
43
|
-
//
|
|
16
|
+
// Thread-local session storage — each OS thread lazily creates its own ONNX
|
|
17
|
+
// session the first time it calls into a given pool. No Mutex, no contention.
|
|
44
18
|
// ---------------------------------------------------------------------------
|
|
45
19
|
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
20
|
+
static NEXT_POOL_ID: AtomicUsize = AtomicUsize::new(1);
|
|
21
|
+
|
|
22
|
+
struct SessionRecipe {
|
|
23
|
+
model_path: PathBuf,
|
|
24
|
+
build_config: ModelConfig,
|
|
51
25
|
}
|
|
52
26
|
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
let parsed = raw.trim().parse::<usize>().ok()?;
|
|
56
|
-
(parsed > 0).then_some(parsed)
|
|
27
|
+
thread_local! {
|
|
28
|
+
static SESSIONS: RefCell<HashMap<usize, Session>> = RefCell::new(HashMap::new());
|
|
57
29
|
}
|
|
58
30
|
|
|
59
31
|
pub struct SessionPool {
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
created: AtomicUsize,
|
|
63
|
-
capacity: usize,
|
|
64
|
-
model_path: PathBuf,
|
|
65
|
-
build_config: ModelConfig,
|
|
32
|
+
pool_id: usize,
|
|
33
|
+
recipe: SessionRecipe,
|
|
66
34
|
}
|
|
67
35
|
|
|
68
36
|
impl SessionPool {
|
|
69
|
-
pub fn new(initial: Session, model_path:
|
|
70
|
-
let
|
|
71
|
-
Self {
|
|
72
|
-
sessions: Mutex::new(vec![initial]),
|
|
73
|
-
available: Condvar::new(),
|
|
74
|
-
created: AtomicUsize::new(1),
|
|
75
|
-
capacity,
|
|
76
|
-
model_path,
|
|
77
|
-
build_config,
|
|
78
|
-
}
|
|
79
|
-
}
|
|
80
|
-
|
|
81
|
-
pub fn acquire(&self) -> Result<PooledSession<'_>> {
|
|
82
|
-
if let Some(session) = self.take_available_session() {
|
|
83
|
-
return Ok(PooledSession {
|
|
84
|
-
pool: self,
|
|
85
|
-
session: Some(session),
|
|
86
|
-
});
|
|
87
|
-
}
|
|
37
|
+
pub fn new(initial: Session, model_path: &Path, build_config: &ModelConfig) -> Result<Self> {
|
|
38
|
+
let pool_id = NEXT_POOL_ID.fetch_add(1, Ordering::Relaxed);
|
|
88
39
|
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
session: Some(session),
|
|
93
|
-
});
|
|
94
|
-
}
|
|
40
|
+
SESSIONS.with(|map| {
|
|
41
|
+
_ = map.borrow_mut().insert(pool_id, initial);
|
|
42
|
+
});
|
|
95
43
|
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
session: Some(session),
|
|
44
|
+
Ok(Self {
|
|
45
|
+
pool_id,
|
|
46
|
+
recipe: SessionRecipe { model_path: model_path.to_path_buf(), build_config: build_config.clone() },
|
|
100
47
|
})
|
|
101
48
|
}
|
|
102
49
|
|
|
103
|
-
fn
|
|
104
|
-
self.
|
|
105
|
-
self.available.notify_one();
|
|
50
|
+
pub fn run(&self, tokenized: &Tokenized, config: &ModelConfig) -> Result<Array2<f32>> {
|
|
51
|
+
self.with_session(|session| run_session(session, tokenized, config))
|
|
106
52
|
}
|
|
107
53
|
|
|
108
|
-
fn
|
|
109
|
-
|
|
54
|
+
pub fn with_session<F, R>(&self, f: F) -> Result<R>
|
|
55
|
+
where
|
|
56
|
+
F: FnOnce(&mut Session) -> Result<R>,
|
|
57
|
+
{
|
|
58
|
+
SESSIONS.with(|map| {
|
|
59
|
+
let mut map = map.borrow_mut();
|
|
60
|
+
let session = match map.entry(self.pool_id) {
|
|
61
|
+
Entry::Occupied(e) => e.into_mut(),
|
|
62
|
+
Entry::Vacant(e) => {
|
|
63
|
+
let session = build_session(&self.recipe.model_path, &self.recipe.build_config)?;
|
|
64
|
+
e.insert(session)
|
|
65
|
+
}
|
|
66
|
+
};
|
|
67
|
+
f(session)
|
|
68
|
+
})
|
|
110
69
|
}
|
|
70
|
+
}
|
|
111
71
|
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
.fetch_update(Ordering::AcqRel, Ordering::Acquire, |count| {
|
|
116
|
-
(count < self.capacity).then_some(count + 1)
|
|
117
|
-
});
|
|
118
|
-
if grew.is_err() {
|
|
119
|
-
return Ok(None);
|
|
120
|
-
}
|
|
72
|
+
// ---------------------------------------------------------------------------
|
|
73
|
+
// Session construction
|
|
74
|
+
// ---------------------------------------------------------------------------
|
|
121
75
|
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
self.created.fetch_sub(1, Ordering::AcqRel);
|
|
126
|
-
Err(error)
|
|
127
|
-
}
|
|
128
|
-
}
|
|
76
|
+
pub fn build_session<P: AsRef<Path>>(model_path: P, config: &ModelConfig) -> Result<Session> {
|
|
77
|
+
fn ort_err(e: impl std::fmt::Display) -> GteError {
|
|
78
|
+
GteError::Ort(e.to_string())
|
|
129
79
|
}
|
|
130
80
|
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
lock = self.available.wait(lock).unwrap();
|
|
138
|
-
}
|
|
139
|
-
}
|
|
140
|
-
}
|
|
81
|
+
let opt_level = match config.optimization_level {
|
|
82
|
+
0 => ort::session::builder::GraphOptimizationLevel::Disable,
|
|
83
|
+
1 => ort::session::builder::GraphOptimizationLevel::Level1,
|
|
84
|
+
2 => ort::session::builder::GraphOptimizationLevel::Level2,
|
|
85
|
+
_ => ort::session::builder::GraphOptimizationLevel::Level3,
|
|
86
|
+
};
|
|
141
87
|
|
|
142
|
-
|
|
143
|
-
pool: &'a SessionPool,
|
|
144
|
-
session: Option<Session>,
|
|
145
|
-
}
|
|
88
|
+
let mut builder = Session::builder().map_err(ort_err)?.with_optimization_level(opt_level).map_err(ort_err)?;
|
|
146
89
|
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
}
|
|
90
|
+
let intra_threads = std::env::var("GTE_INTRA_OP_NUM_THREADS")
|
|
91
|
+
.ok()
|
|
92
|
+
.and_then(|v| v.trim().parse::<usize>().ok())
|
|
93
|
+
.unwrap_or_else(|| std::thread::available_parallelism().map(|n| n.get().min(4)).unwrap_or(1));
|
|
94
|
+
builder = builder.with_intra_threads(intra_threads).map_err(ort_err)?;
|
|
153
95
|
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
}
|
|
158
|
-
}
|
|
96
|
+
let inter_threads =
|
|
97
|
+
std::env::var("GTE_INTER_OP_NUM_THREADS").ok().and_then(|v| v.trim().parse::<usize>().ok()).unwrap_or(1);
|
|
98
|
+
builder = builder.with_inter_threads(inter_threads).map_err(ort_err)?;
|
|
159
99
|
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
|
|
100
|
+
let providers = match config.execution_providers.as_deref() {
|
|
101
|
+
Some(override_val) => preferred_execution_providers(Some(override_val)),
|
|
102
|
+
None => auto_detect_providers(),
|
|
103
|
+
};
|
|
104
|
+
if !providers.is_empty() {
|
|
105
|
+
builder = builder.with_execution_providers(providers).map_err(ort_err)?;
|
|
165
106
|
}
|
|
166
|
-
}
|
|
167
|
-
|
|
168
|
-
// ---------------------------------------------------------------------------
|
|
169
107
|
|
|
170
|
-
|
|
171
|
-
|
|
108
|
+
builder.commit_from_file(model_path).map_err(ort_err)
|
|
109
|
+
}
|
|
172
110
|
|
|
111
|
+
fn auto_detect_providers() -> Vec<ExecutionProviderDispatch> {
|
|
173
112
|
let mut providers = Vec::new();
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
"xnnpack" => {
|
|
177
|
-
providers.push(XNNPACKExecutionProvider::default().build().fail_silently())
|
|
178
|
-
}
|
|
179
|
-
"coreml" => providers.push(CoreMLExecutionProvider::default().build().fail_silently()),
|
|
180
|
-
_ => {}
|
|
181
|
-
}
|
|
182
|
-
}
|
|
113
|
+
#[cfg(target_arch = "aarch64")]
|
|
114
|
+
providers.push(XNNPACKExecutionProvider::default().build().fail_silently());
|
|
183
115
|
providers
|
|
184
116
|
}
|
|
185
117
|
|
|
186
|
-
fn
|
|
187
|
-
let
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
fn resolve_provider_order_with_env(
|
|
192
|
-
order_override: Option<&str>,
|
|
193
|
-
env_order: Option<&str>,
|
|
194
|
-
) -> String {
|
|
195
|
-
order_override
|
|
196
|
-
.or(env_order)
|
|
197
|
-
.unwrap_or("cpu")
|
|
198
|
-
.to_ascii_lowercase()
|
|
199
|
-
}
|
|
118
|
+
fn preferred_execution_providers(order_override: Option<&str>) -> Vec<ExecutionProviderDispatch> {
|
|
119
|
+
let order = match order_override {
|
|
120
|
+
Some(s) => s.to_ascii_lowercase(),
|
|
121
|
+
None => return auto_detect_providers(),
|
|
122
|
+
};
|
|
200
123
|
|
|
201
|
-
|
|
202
|
-
|
|
203
|
-
for provider in order.split(',').map(str::trim).filter(|p| !p.is_empty()) {
|
|
204
|
-
match provider {
|
|
205
|
-
"xnnpack" | "coreml" => providers.push(provider),
|
|
206
|
-
"none" | "cpu" => {}
|
|
207
|
-
_ => {}
|
|
208
|
-
}
|
|
124
|
+
if order.is_empty() || order == "cpu" || order == "none" {
|
|
125
|
+
return Vec::new();
|
|
209
126
|
}
|
|
127
|
+
|
|
128
|
+
let providers: Vec<_> = order
|
|
129
|
+
.split(',')
|
|
130
|
+
.map(str::trim)
|
|
131
|
+
.filter(|p| !p.is_empty())
|
|
132
|
+
.filter_map(|provider| match provider {
|
|
133
|
+
"xnnpack" => Some(XNNPACKExecutionProvider::default().build().fail_silently()),
|
|
134
|
+
"coreml" => Some(CoreMLExecutionProvider::default().build().fail_silently()),
|
|
135
|
+
_ => None,
|
|
136
|
+
})
|
|
137
|
+
.collect();
|
|
210
138
|
providers
|
|
211
139
|
}
|
|
212
140
|
|
|
213
|
-
|
|
214
|
-
|
|
215
|
-
|
|
216
|
-
|
|
217
|
-
) -> Result<Array2<f32>> {
|
|
141
|
+
// ---------------------------------------------------------------------------
|
|
142
|
+
// Run a single inference
|
|
143
|
+
// ---------------------------------------------------------------------------
|
|
144
|
+
|
|
145
|
+
pub fn run_session(session: &mut Session, tokenized: &Tokenized, config: &ModelConfig) -> Result<Array2<f32>> {
|
|
218
146
|
let input_tensors = InputTensors::from_tokenized(tokenized, config.with_attention_mask)?;
|
|
219
147
|
let run_opts = RunOptions::new()
|
|
220
148
|
.map_err(|e| GteError::Ort(e.to_string()))?
|
|
221
149
|
.with_outputs(OutputSelector::no_default().with(config.output_tensor.as_str()));
|
|
222
|
-
let outputs =
|
|
223
|
-
.run_with_options(input_tensors.inputs, &run_opts)
|
|
224
|
-
.map_err(|e| GteError::Ort(e.to_string()))?;
|
|
150
|
+
let outputs =
|
|
151
|
+
session.run_with_options(input_tensors.inputs, &run_opts).map_err(|e| GteError::Ort(e.to_string()))?;
|
|
225
152
|
let array = extract_output_tensor(&outputs, config.output_tensor.as_str())?;
|
|
226
153
|
|
|
227
154
|
extract_embeddings(array, input_tensors.attention_mask, config)
|
|
@@ -237,26 +164,21 @@ fn extract_embeddings(
|
|
|
237
164
|
let shape = array.shape();
|
|
238
165
|
if shape.len() != 3 || idx >= shape[1] {
|
|
239
166
|
return Err(GteError::Inference(format!(
|
|
240
|
-
"token extraction index {} out of bounds for output shape {:?}"
|
|
241
|
-
idx, shape
|
|
167
|
+
"token extraction index {idx} out of bounds for output shape {shape:?}"
|
|
242
168
|
)));
|
|
243
169
|
}
|
|
244
170
|
Ok(array.slice(ndarray::s![.., idx, ..]).into_owned())
|
|
245
171
|
}
|
|
246
172
|
ExtractorMode::MeanPool => {
|
|
247
173
|
let ndim = array.ndim();
|
|
248
|
-
let hidden_states = array
|
|
249
|
-
|
|
250
|
-
|
|
251
|
-
ndim
|
|
252
|
-
))
|
|
253
|
-
})?;
|
|
174
|
+
let hidden_states = array
|
|
175
|
+
.into_dimensionality::<ndarray::Ix3>()
|
|
176
|
+
.map_err(|_| GteError::Inference(format!("mean pooling requires rank-3 output, got rank {ndim}")))?;
|
|
254
177
|
mean_pool(hidden_states, attention_mask)
|
|
255
178
|
}
|
|
256
|
-
ExtractorMode::Raw =>
|
|
257
|
-
.into_dimensionality::<Ix2>()
|
|
258
|
-
|
|
259
|
-
.map_err(|e| GteError::Shape(e.to_string())),
|
|
179
|
+
ExtractorMode::Raw => {
|
|
180
|
+
array.into_dimensionality::<Ix2>().map(|view| view.to_owned()).map_err(|e| GteError::Shape(e.to_string()))
|
|
181
|
+
}
|
|
260
182
|
}
|
|
261
183
|
}
|
|
262
184
|
|
|
@@ -265,10 +187,22 @@ mod tests {
|
|
|
265
187
|
use crate::model_config::{ExtractorMode, ModelConfig, PaddingMode};
|
|
266
188
|
use ndarray::{array, ArrayView2};
|
|
267
189
|
|
|
268
|
-
use super::
|
|
269
|
-
|
|
270
|
-
|
|
271
|
-
|
|
190
|
+
use super::extract_embeddings;
|
|
191
|
+
|
|
192
|
+
fn resolve_provider_order_with_env(order_override: Option<&str>, env_order: Option<&str>) -> String {
|
|
193
|
+
order_override.or(env_order).unwrap_or("cpu").to_ascii_lowercase()
|
|
194
|
+
}
|
|
195
|
+
|
|
196
|
+
fn parse_provider_registrations(order: &str) -> Vec<&str> {
|
|
197
|
+
let mut providers = Vec::new();
|
|
198
|
+
for provider in order.split(',').map(str::trim).filter(|p| !p.is_empty()) {
|
|
199
|
+
match provider {
|
|
200
|
+
"xnnpack" | "coreml" => providers.push(provider),
|
|
201
|
+
_ => {}
|
|
202
|
+
}
|
|
203
|
+
}
|
|
204
|
+
providers
|
|
205
|
+
}
|
|
272
206
|
|
|
273
207
|
fn test_config(mode: ExtractorMode) -> ModelConfig {
|
|
274
208
|
ModelConfig {
|
|
@@ -280,6 +214,8 @@ mod tests {
|
|
|
280
214
|
with_attention_mask: true,
|
|
281
215
|
optimization_level: 3,
|
|
282
216
|
execution_providers: None,
|
|
217
|
+
lowercase_input: false,
|
|
218
|
+
max_input_chars: None,
|
|
283
219
|
}
|
|
284
220
|
}
|
|
285
221
|
|
|
@@ -309,93 +245,45 @@ mod tests {
|
|
|
309
245
|
|
|
310
246
|
#[test]
|
|
311
247
|
fn resolve_provider_order_prefers_override() {
|
|
312
|
-
assert_eq!(
|
|
313
|
-
resolve_provider_order_with_env(Some("xnnpack"), Some("coreml")),
|
|
314
|
-
"xnnpack"
|
|
315
|
-
);
|
|
248
|
+
assert_eq!(resolve_provider_order_with_env(Some("xnnpack"), Some("coreml")), "xnnpack");
|
|
316
249
|
assert_eq!(resolve_provider_order_with_env(Some("CPU"), None), "cpu");
|
|
317
250
|
}
|
|
318
251
|
|
|
319
252
|
#[test]
|
|
320
253
|
fn resolve_provider_order_falls_back_to_env_then_cpu_default() {
|
|
321
|
-
assert_eq!(
|
|
322
|
-
resolve_provider_order_with_env(None, Some("coreml")),
|
|
323
|
-
"coreml"
|
|
324
|
-
);
|
|
254
|
+
assert_eq!(resolve_provider_order_with_env(None, Some("coreml")), "coreml");
|
|
325
255
|
assert_eq!(resolve_provider_order_with_env(None, None), "cpu");
|
|
326
256
|
}
|
|
327
257
|
|
|
328
|
-
#[test]
|
|
329
|
-
fn parse_pool_capacity_override_uses_positive_integer_only() {
|
|
330
|
-
unsafe {
|
|
331
|
-
std::env::remove_var("GTE_SESSION_POOL_CAP");
|
|
332
|
-
}
|
|
333
|
-
assert_eq!(parse_pool_capacity_override(), None);
|
|
334
|
-
|
|
335
|
-
unsafe {
|
|
336
|
-
std::env::set_var("GTE_SESSION_POOL_CAP", "0");
|
|
337
|
-
}
|
|
338
|
-
assert_eq!(parse_pool_capacity_override(), None);
|
|
339
|
-
|
|
340
|
-
unsafe {
|
|
341
|
-
std::env::set_var("GTE_SESSION_POOL_CAP", "4");
|
|
342
|
-
}
|
|
343
|
-
assert_eq!(parse_pool_capacity_override(), Some(4));
|
|
344
|
-
|
|
345
|
-
unsafe {
|
|
346
|
-
std::env::set_var("GTE_SESSION_POOL_CAP", "abc");
|
|
347
|
-
}
|
|
348
|
-
assert_eq!(parse_pool_capacity_override(), None);
|
|
349
|
-
|
|
350
|
-
unsafe {
|
|
351
|
-
std::env::remove_var("GTE_SESSION_POOL_CAP");
|
|
352
|
-
}
|
|
353
|
-
}
|
|
354
|
-
|
|
355
258
|
#[test]
|
|
356
259
|
fn extract_embeddings_raw_copies_only_final_matrix() {
|
|
357
260
|
let output = array![[1.0f32, 2.0], [3.0, 4.0]];
|
|
358
|
-
let extracted =
|
|
359
|
-
output.view().into_dyn(),
|
|
360
|
-
|
|
361
|
-
&test_config(ExtractorMode::Raw),
|
|
362
|
-
)
|
|
363
|
-
.unwrap();
|
|
261
|
+
let extracted =
|
|
262
|
+
extract_embeddings(output.view().into_dyn(), empty_attention_mask(), &test_config(ExtractorMode::Raw))
|
|
263
|
+
.unwrap();
|
|
364
264
|
|
|
365
265
|
assert_eq!(extracted, output);
|
|
366
266
|
}
|
|
367
267
|
|
|
368
268
|
#[test]
|
|
369
269
|
fn extract_embeddings_token_selects_without_copying_full_sequence() {
|
|
370
|
-
let output = array![
|
|
371
|
-
[[1.0f32, 2.0], [3.0, 4.0], [5.0, 6.0]],
|
|
372
|
-
[[7.0, 8.0], [9.0, 10.0], [11.0, 12.0]]
|
|
373
|
-
];
|
|
270
|
+
let output = array![[[1.0f32, 2.0], [3.0, 4.0], [5.0, 6.0]], [[7.0, 8.0], [9.0, 10.0], [11.0, 12.0]]];
|
|
374
271
|
let expected = array![[3.0f32, 4.0], [9.0, 10.0]];
|
|
375
|
-
let extracted =
|
|
376
|
-
output.view().into_dyn(),
|
|
377
|
-
|
|
378
|
-
&test_config(ExtractorMode::Token(1)),
|
|
379
|
-
)
|
|
380
|
-
.unwrap();
|
|
272
|
+
let extracted =
|
|
273
|
+
extract_embeddings(output.view().into_dyn(), empty_attention_mask(), &test_config(ExtractorMode::Token(1)))
|
|
274
|
+
.unwrap();
|
|
381
275
|
|
|
382
276
|
assert_eq!(extracted, expected);
|
|
383
277
|
}
|
|
384
278
|
|
|
385
279
|
#[test]
|
|
386
280
|
fn extract_embeddings_mean_pool_uses_output_view_and_attention_mask() {
|
|
387
|
-
let output = array![
|
|
388
|
-
[[1.0f32, 3.0], [5.0, 7.0], [100.0, 100.0]],
|
|
389
|
-
[[2.0, 4.0], [6.0, 8.0], [10.0, 12.0]]
|
|
390
|
-
];
|
|
281
|
+
let output = array![[[1.0f32, 3.0], [5.0, 7.0], [100.0, 100.0]], [[2.0, 4.0], [6.0, 8.0], [10.0, 12.0]]];
|
|
391
282
|
let attention_mask = array![[1_i64, 1, 0], [0, 1, 1]];
|
|
392
283
|
let expected = array![[3.0f32, 5.0], [8.0, 10.0]];
|
|
393
|
-
let extracted =
|
|
394
|
-
output.view().into_dyn(),
|
|
395
|
-
|
|
396
|
-
&test_config(ExtractorMode::MeanPool),
|
|
397
|
-
)
|
|
398
|
-
.unwrap();
|
|
284
|
+
let extracted =
|
|
285
|
+
extract_embeddings(output.view().into_dyn(), attention_mask.view(), &test_config(ExtractorMode::MeanPool))
|
|
286
|
+
.unwrap();
|
|
399
287
|
|
|
400
288
|
assert_eq!(extracted, expected);
|
|
401
289
|
}
|