gte 0.0.7 → 0.0.8
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/README.md +13 -8
- data/Rakefile +38 -3
- data/VERSION +1 -1
- data/ext/gte/Cargo.toml +4 -4
- data/ext/gte/src/embedder.rs +12 -17
- data/ext/gte/src/model_profile.rs +18 -20
- data/ext/gte/src/pipeline.rs +12 -9
- data/ext/gte/src/reranker.rs +8 -11
- data/ext/gte/src/ruby_embedder.rs +60 -119
- data/ext/gte/src/session.rs +279 -15
- data/lib/gte/embedder.rb +5 -4
- data/lib/gte/reranker.rb +1 -1
- data/lib/gte.rb +1 -11
- metadata +6 -6
checksums.yaml
CHANGED
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
---
|
|
2
2
|
SHA256:
|
|
3
|
-
metadata.gz:
|
|
4
|
-
data.tar.gz:
|
|
3
|
+
metadata.gz: 2c754b4675ee105e9a280cd9deafa00a81b9e02ee629131f3e908400006b6ae4
|
|
4
|
+
data.tar.gz: 40a0d3e04c3d2943ae50910164d644ecb763eac99a02044dc962cc141a0e13c5
|
|
5
5
|
SHA512:
|
|
6
|
-
metadata.gz:
|
|
7
|
-
data.tar.gz:
|
|
6
|
+
metadata.gz: 16614e01e7a33a53339ba9fe7cf32fe7606041518a24177258d7a6e5550516e8cff741d0f0df02b7e5863fc763c02ae81b943dc4b18295701a4cafdec6627cb0
|
|
7
|
+
data.tar.gz: 348e1fd1d9f4c44214b5101ba339109b5ececfbef18b48b7c11324a64481f476d8da831cc5148d17a85c41b525ee753c296d4421a4fb2adda269a3f5fe38cda6
|
data/README.md
CHANGED
|
@@ -33,8 +33,8 @@ raw_model = GTE.config(ENV.fetch("GTE_MODEL_DIR")) do |config|
|
|
|
33
33
|
config.with(normalize: false)
|
|
34
34
|
end
|
|
35
35
|
|
|
36
|
-
|
|
37
|
-
config.with(threads:
|
|
36
|
+
single_thread = GTE.config(ENV.fetch("GTE_MODEL_DIR")) do |config|
|
|
37
|
+
config.with(threads: 1)
|
|
38
38
|
end
|
|
39
39
|
|
|
40
40
|
custom = GTE.config(ENV.fetch("GTE_MODEL_DIR")) do |config|
|
|
@@ -50,7 +50,7 @@ end
|
|
|
50
50
|
Config fields and defaults:
|
|
51
51
|
|
|
52
52
|
- `model_dir`: absolute path to model directory
|
|
53
|
-
- `threads`: `
|
|
53
|
+
- `threads`: `1` (default tuned for p95 latency; use `0` for ONNX Runtime auto-thread mode)
|
|
54
54
|
- `optimization_level`: `3`
|
|
55
55
|
- `model_name`: `nil`
|
|
56
56
|
- `normalize`: `true` (L2 normalization at Ruby-facing API)
|
|
@@ -68,7 +68,7 @@ Low-level embedder setup (without model cache):
|
|
|
68
68
|
|
|
69
69
|
```ruby
|
|
70
70
|
embedder = GTE::Embedder.config(ENV.fetch("GTE_MODEL_DIR")) do |config|
|
|
71
|
-
config.with(threads:
|
|
71
|
+
config.with(threads: 1, execution_providers: "cpu")
|
|
72
72
|
end
|
|
73
73
|
```
|
|
74
74
|
|
|
@@ -78,7 +78,7 @@ Use `GTE::Reranker.config(model_dir)` for cross-encoder reranking.
|
|
|
78
78
|
|
|
79
79
|
```ruby
|
|
80
80
|
reranker = GTE::Reranker.config(ENV.fetch("GTE_RERANK_DIR")) do |config|
|
|
81
|
-
config.with(sigmoid: true, threads:
|
|
81
|
+
config.with(sigmoid: true, threads: 1)
|
|
82
82
|
end
|
|
83
83
|
|
|
84
84
|
query = "how to train a neural network?"
|
|
@@ -102,7 +102,7 @@ ranked = reranker.rerank(query: query, candidates: candidates)
|
|
|
102
102
|
Reranker config fields and defaults:
|
|
103
103
|
|
|
104
104
|
- `model_dir`: absolute path to model directory
|
|
105
|
-
- `threads`: `
|
|
105
|
+
- `threads`: `1`
|
|
106
106
|
- `optimization_level`: `3`
|
|
107
107
|
- `model_name`: `nil`
|
|
108
108
|
- `sigmoid`: `false` (set `true` if you want bounded [0,1] style scores)
|
|
@@ -174,7 +174,7 @@ make ci
|
|
|
174
174
|
|
|
175
175
|
## Benchmark
|
|
176
176
|
|
|
177
|
-
The repo includes
|
|
177
|
+
The repo includes a shared multi-runtime benchmark harness:
|
|
178
178
|
|
|
179
179
|
```bash
|
|
180
180
|
make bench
|
|
@@ -183,6 +183,11 @@ nix develop -c bundle exec rake bench:matrix_sweep
|
|
|
183
183
|
nix develop -c bundle exec ruby bench/memory_probe.rb --compare-pure
|
|
184
184
|
```
|
|
185
185
|
|
|
186
|
+
- `make bench`: Puma-like single-request comparison at concurrency `16`
|
|
187
|
+
- `rake bench:pure_compare`: batch amortization comparison
|
|
188
|
+
- `rake bench:matrix_sweep`: GTE provider/thread sweep using the shared result schema
|
|
189
|
+
- Optional Python comparisons use `bench/python_onnxruntime.py` and are skipped automatically if local dependencies are unavailable.
|
|
190
|
+
|
|
186
191
|
To run benchmark + append a `RUNS.md` entry + enforce goal checks:
|
|
187
192
|
|
|
188
193
|
```bash
|
|
@@ -191,5 +196,5 @@ make bench-record
|
|
|
191
196
|
|
|
192
197
|
`bench/runs_ledger.rb check` is goal-focused by default:
|
|
193
198
|
|
|
194
|
-
- Enforces goal metric (`response_time_p95`
|
|
199
|
+
- Enforces the goal metric (`response_time_p95`) across every enabled competitor.
|
|
195
200
|
- Does not require current-version coverage in `RUNS.md` unless explicitly enabled.
|
data/Rakefile
CHANGED
|
@@ -10,17 +10,52 @@ rescue LoadError
|
|
|
10
10
|
end
|
|
11
11
|
|
|
12
12
|
spec = Gem::Specification.load('gte.gemspec')
|
|
13
|
+
cross_target = ENV.fetch('RUBY_TARGET', nil)
|
|
13
14
|
|
|
14
|
-
|
|
15
|
+
if cross_target == 'arm64-darwin'
|
|
16
|
+
# rb-sys-dock's darwin image can expose an unusable default LIBRARY_PATH.
|
|
17
|
+
# Force the compiler-rt darwin runtime directory so -lclang_rt.osx resolves.
|
|
18
|
+
ENV['LIBRARY_PATH'] = '/usr/lib/llvm-10/lib/clang/10.0.0/lib/darwin'
|
|
19
|
+
end
|
|
20
|
+
|
|
21
|
+
extension_task = Rake::ExtensionTask.new('gte', spec) do |ext|
|
|
15
22
|
ext.lib_dir = 'lib/gte'
|
|
16
23
|
ext.cross_compile = true
|
|
17
|
-
|
|
24
|
+
# rb-sys-dock invokes `rake native:$RUBY_TARGET gem` without the `cross` task,
|
|
25
|
+
# so scope platforms during dock builds to avoid host-Ruby fallback copy tasks.
|
|
26
|
+
cross_platforms = if cross_target && !cross_target.empty?
|
|
27
|
+
[cross_target]
|
|
28
|
+
else
|
|
29
|
+
%w[x86_64-linux aarch64-linux arm64-darwin]
|
|
30
|
+
end
|
|
31
|
+
ext.cross_platform = cross_platforms
|
|
32
|
+
end
|
|
33
|
+
|
|
34
|
+
if cross_target && !cross_target.empty? && ENV['RUBY_CC_VERSION']
|
|
35
|
+
ruby_version = ENV['RUBY_CC_VERSION'].split(':').first
|
|
36
|
+
lib_binary_path = File.join(extension_task.lib_dir, File.basename(extension_task.binary(cross_target)))
|
|
37
|
+
copy_task = "copy:gte:#{cross_target}:#{ruby_version}"
|
|
38
|
+
|
|
39
|
+
if Rake::Task.task_defined?(lib_binary_path) && Rake::Task.task_defined?(copy_task)
|
|
40
|
+
Rake::Task[lib_binary_path].prerequisites.clear
|
|
41
|
+
Rake::Task[lib_binary_path].enhance([copy_task])
|
|
42
|
+
end
|
|
18
43
|
end
|
|
19
44
|
|
|
20
45
|
task default: %i[compile spec]
|
|
21
46
|
|
|
47
|
+
def bundler_env
|
|
48
|
+
root = File.expand_path(__dir__)
|
|
49
|
+
{
|
|
50
|
+
'BUNDLE_DISABLE_SHARED_GEMS' => '1',
|
|
51
|
+
'GEM_HOME' => File.join(root, '.bundle-gems'),
|
|
52
|
+
'GEM_PATH' => File.join(root, '.bundle-gems'),
|
|
53
|
+
'BUNDLE_PATH' => File.join(root, 'vendor/bundle')
|
|
54
|
+
}
|
|
55
|
+
end
|
|
56
|
+
|
|
22
57
|
def run_in_nix(*command)
|
|
23
|
-
sh('nix', 'develop', '-c', *command)
|
|
58
|
+
sh(bundler_env, 'nix', 'develop', '-c', *command)
|
|
24
59
|
end
|
|
25
60
|
|
|
26
61
|
namespace :bench do
|
data/VERSION
CHANGED
|
@@ -1 +1 @@
|
|
|
1
|
-
0.0.7
|
|
1
|
+
0.0.8
|
data/ext/gte/Cargo.toml
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
[package]
|
|
2
2
|
name = "gte"
|
|
3
|
-
version = "0.0.7"
|
|
3
|
+
version = "0.0.8"
|
|
4
4
|
edition = "2021"
|
|
5
5
|
authors = ["elcuervo <elcuervo@elcuervo.net>"]
|
|
6
6
|
license = "MIT"
|
|
@@ -21,10 +21,10 @@ ruby-ffi = ["dep:magnus", "dep:rb-sys"]
|
|
|
21
21
|
[dependencies]
|
|
22
22
|
rb-sys = { version = "0.9", features = ["stable-api-compiled-fallback"], optional = true }
|
|
23
23
|
magnus = { version = "0.8", optional = true }
|
|
24
|
-
ort = { version = "=2.0.0-rc.
|
|
25
|
-
ort-sys = "=2.0.0-rc.
|
|
24
|
+
ort = { version = "=2.0.0-rc.12", features = ["ndarray", "xnnpack"] }
|
|
25
|
+
ort-sys = "=2.0.0-rc.12"
|
|
26
26
|
tokenizers = "0.21.0"
|
|
27
|
-
ndarray = "0.
|
|
27
|
+
ndarray = "0.17"
|
|
28
28
|
half = "2"
|
|
29
29
|
serde = { version = "1", features = ["derive"] }
|
|
30
30
|
serde_json = "1"
|
data/ext/gte/src/embedder.rs
CHANGED
|
@@ -5,15 +5,14 @@ use crate::model_profile::{
|
|
|
5
5
|
resolve_named_model, resolve_tokenizer_path, select_output_tensor, validate_supported_text_inputs,
|
|
6
6
|
};
|
|
7
7
|
use crate::postprocess::normalize_l2 as normalize_l2_rows;
|
|
8
|
-
use crate::session::{build_session, run_session};
|
|
8
|
+
use crate::session::{build_session, run_session, SessionPool};
|
|
9
9
|
use crate::tokenizer::{parse_padding_mode_override, Tokenized, Tokenizer};
|
|
10
10
|
use ndarray::Array2;
|
|
11
|
-
use
|
|
12
|
-
use std::path::Path;
|
|
11
|
+
use std::path::{Path, PathBuf};
|
|
13
12
|
|
|
14
13
|
pub struct Embedder {
|
|
15
14
|
tokenizer: Tokenizer,
|
|
16
|
-
|
|
15
|
+
pool: SessionPool,
|
|
17
16
|
config: ModelConfig,
|
|
18
17
|
}
|
|
19
18
|
|
|
@@ -30,12 +29,10 @@ impl Embedder {
|
|
|
30
29
|
config.padding_mode,
|
|
31
30
|
None,
|
|
32
31
|
)?;
|
|
33
|
-
let
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
config,
|
|
38
|
-
})
|
|
32
|
+
let model_path = model_path.as_ref().to_path_buf();
|
|
33
|
+
let session = build_session(&model_path, &config)?;
|
|
34
|
+
let pool = SessionPool::new(session, model_path, config.clone());
|
|
35
|
+
Ok(Self { tokenizer, pool, config })
|
|
39
36
|
}
|
|
40
37
|
|
|
41
38
|
pub fn from_dir<P: AsRef<Path>>(
|
|
@@ -53,7 +50,7 @@ impl Embedder {
|
|
|
53
50
|
|
|
54
51
|
let dir = dir.as_ref();
|
|
55
52
|
let tokenizer_path = resolve_tokenizer_path(dir)?;
|
|
56
|
-
let model_path = match overrides.model_name.filter(|s| !s.is_empty()) {
|
|
53
|
+
let model_path: PathBuf = match overrides.model_name.filter(|s| !s.is_empty()) {
|
|
57
54
|
Some(name) => resolve_named_model(dir, name)?,
|
|
58
55
|
None => resolve_default_text_model(dir)?,
|
|
59
56
|
};
|
|
@@ -117,11 +114,8 @@ impl Embedder {
|
|
|
117
114
|
tokenizer_profile.fixed_padding_length,
|
|
118
115
|
)?;
|
|
119
116
|
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
session,
|
|
123
|
-
config,
|
|
124
|
-
})
|
|
117
|
+
let pool = SessionPool::new(session, model_path, session_config);
|
|
118
|
+
Ok(Self { tokenizer, pool, config })
|
|
125
119
|
}
|
|
126
120
|
|
|
127
121
|
pub fn embed(&self, texts: Vec<String>) -> Result<Array2<f32>> {
|
|
@@ -134,7 +128,8 @@ impl Embedder {
|
|
|
134
128
|
}
|
|
135
129
|
|
|
136
130
|
pub fn run(&self, tokenized: &Tokenized) -> crate::error::Result<Array2<f32>> {
|
|
137
|
-
|
|
131
|
+
let mut session = self.pool.acquire()?;
|
|
132
|
+
run_session(&mut session, tokenized, &self.config)
|
|
138
133
|
}
|
|
139
134
|
}
|
|
140
135
|
|
|
data/ext/gte/src/model_profile.rs
CHANGED
|
@@ -133,11 +133,9 @@ fn parse_fixed_padding_length_from_tokenizer_json(tokenizer_json: &Value) -> Opt
|
|
|
133
133
|
}
|
|
134
134
|
|
|
135
135
|
pub fn validate_supported_text_inputs(session: &Session, api_label: &str) -> Result<()> {
|
|
136
|
-
let unsupported: Vec<String> = session
|
|
137
|
-
.
|
|
138
|
-
.
|
|
139
|
-
.filter(|i| !SUPPORTED_INPUTS.contains(&i.name.as_str()))
|
|
140
|
-
.map(|i| i.name.clone())
|
|
136
|
+
let unsupported: Vec<String> = session.inputs().iter()
|
|
137
|
+
.filter(|i| !SUPPORTED_INPUTS.contains(&i.name()))
|
|
138
|
+
.map(|i| i.name().to_owned())
|
|
141
139
|
.collect();
|
|
142
140
|
|
|
143
141
|
if unsupported.is_empty() {
|
|
@@ -160,7 +158,7 @@ pub fn validate_supported_text_inputs(session: &Session, api_label: &str) -> Res
|
|
|
160
158
|
}
|
|
161
159
|
|
|
162
160
|
pub fn has_input(session: &Session, name: &str) -> bool {
|
|
163
|
-
session.inputs.iter().any(|input| input.name == name)
|
|
161
|
+
session.inputs().iter().any(|input| input.name() == name)
|
|
164
162
|
}
|
|
165
163
|
|
|
166
164
|
fn output_name_matches(name: &str, preferred: &str) -> bool {
|
|
@@ -175,16 +173,16 @@ pub fn select_output_tensor(
|
|
|
175
173
|
) -> Result<String> {
|
|
176
174
|
if let Some(requested_name) = requested.map(str::trim).filter(|name| !name.is_empty()) {
|
|
177
175
|
if let Some(output) = session
|
|
178
|
-
.outputs
|
|
176
|
+
.outputs()
|
|
179
177
|
.iter()
|
|
180
|
-
.find(|o| output_name_matches(o.name
|
|
178
|
+
.find(|o| output_name_matches(o.name(), requested_name))
|
|
181
179
|
{
|
|
182
|
-
return Ok(output.name.
|
|
180
|
+
return Ok(output.name().to_owned());
|
|
183
181
|
}
|
|
184
182
|
let available = session
|
|
185
|
-
.outputs
|
|
183
|
+
.outputs()
|
|
186
184
|
.iter()
|
|
187
|
-
.map(|o| o.name
|
|
185
|
+
.map(|o| o.name())
|
|
188
186
|
.collect::<Vec<_>>()
|
|
189
187
|
.join(", ");
|
|
190
188
|
return Err(GteError::Inference(format!(
|
|
@@ -195,18 +193,18 @@ pub fn select_output_tensor(
|
|
|
195
193
|
|
|
196
194
|
for preferred in preferred_outputs {
|
|
197
195
|
if let Some(output) = session
|
|
198
|
-
.outputs
|
|
196
|
+
.outputs()
|
|
199
197
|
.iter()
|
|
200
|
-
.find(|o| output_name_matches(o.name
|
|
198
|
+
.find(|o| output_name_matches(o.name(), preferred))
|
|
201
199
|
{
|
|
202
|
-
return Ok(output.name.
|
|
200
|
+
return Ok(output.name().to_owned());
|
|
203
201
|
}
|
|
204
202
|
}
|
|
205
203
|
|
|
206
204
|
session
|
|
207
|
-
.outputs
|
|
205
|
+
.outputs()
|
|
208
206
|
.first()
|
|
209
|
-
.map(|o| o.name.
|
|
207
|
+
.map(|o| o.name().to_owned())
|
|
210
208
|
.ok_or_else(|| GteError::Inference("model has no outputs".into()))
|
|
211
209
|
}
|
|
212
210
|
|
|
@@ -216,9 +214,9 @@ fn output_basename(name: &str) -> &str {
|
|
|
216
214
|
|
|
217
215
|
pub fn infer_extraction_mode(session: &Session, output_tensor: &str) -> Result<ExtractorMode> {
|
|
218
216
|
let output = session
|
|
219
|
-
.outputs
|
|
217
|
+
.outputs()
|
|
220
218
|
.iter()
|
|
221
|
-
.find(|o| o.name == output_tensor)
|
|
219
|
+
.find(|o| o.name() == output_tensor)
|
|
222
220
|
.ok_or_else(|| {
|
|
223
221
|
GteError::Inference(format!(
|
|
224
222
|
"output tensor '{}' not found in model outputs",
|
|
@@ -226,8 +224,8 @@ pub fn infer_extraction_mode(session: &Session, output_tensor: &str) -> Result<E
|
|
|
226
224
|
))
|
|
227
225
|
})?;
|
|
228
226
|
|
|
229
|
-
let ndims = match
|
|
230
|
-
ort::value::ValueType::Tensor {
|
|
227
|
+
let ndims = match output.dtype() {
|
|
228
|
+
ort::value::ValueType::Tensor { shape, .. } => shape.len(),
|
|
231
229
|
other => {
|
|
232
230
|
return Err(GteError::Inference(format!(
|
|
233
231
|
"output is not a tensor: {:?}",
|
data/ext/gte/src/pipeline.rs
CHANGED
|
@@ -1,8 +1,8 @@
|
|
|
1
1
|
use crate::error::{GteError, Result};
|
|
2
2
|
use crate::tokenizer::Tokenized;
|
|
3
|
-
use ndarray::ArrayView2;
|
|
3
|
+
use ndarray::{ArrayView2, ArrayViewD};
|
|
4
4
|
use ort::session::SessionInputValue;
|
|
5
|
-
use ort::value::
|
|
5
|
+
use ort::value::TensorRef;
|
|
6
6
|
|
|
7
7
|
pub struct InputTensors<'a> {
|
|
8
8
|
pub inputs: Vec<(&'static str, SessionInputValue<'a>)>,
|
|
@@ -23,13 +23,13 @@ impl<'a> InputTensors<'a> {
|
|
|
23
23
|
let mut inputs = Vec::with_capacity(2 + usize::from(tokenized.type_ids.is_some()));
|
|
24
24
|
inputs.push((
|
|
25
25
|
"input_ids",
|
|
26
|
-
SessionInputValue::from(
|
|
26
|
+
SessionInputValue::from(TensorRef::from_array_view(input_ids_view)?),
|
|
27
27
|
));
|
|
28
28
|
|
|
29
29
|
if with_attention_mask {
|
|
30
30
|
inputs.push((
|
|
31
31
|
"attention_mask",
|
|
32
|
-
SessionInputValue::from(
|
|
32
|
+
SessionInputValue::from(TensorRef::from_array_view(attention_mask)?),
|
|
33
33
|
));
|
|
34
34
|
}
|
|
35
35
|
|
|
@@ -38,7 +38,7 @@ impl<'a> InputTensors<'a> {
|
|
|
38
38
|
ArrayView2::from_shape((tokenized.rows, tokenized.cols), type_ids)?;
|
|
39
39
|
inputs.push((
|
|
40
40
|
"token_type_ids",
|
|
41
|
-
SessionInputValue::from(
|
|
41
|
+
SessionInputValue::from(TensorRef::from_array_view(type_ids_view)?),
|
|
42
42
|
));
|
|
43
43
|
}
|
|
44
44
|
|
|
@@ -50,11 +50,14 @@ impl<'a> InputTensors<'a> {
|
|
|
50
50
|
}
|
|
51
51
|
|
|
52
52
|
pub fn extract_output_tensor<'a>(
|
|
53
|
-
outputs: &'a ort::session::SessionOutputs<'
|
|
53
|
+
outputs: &'a ort::session::SessionOutputs<'_>,
|
|
54
54
|
output_name: &str,
|
|
55
|
-
) -> Result<
|
|
55
|
+
) -> Result<ArrayViewD<'a, f32>> {
|
|
56
56
|
let tensor_value = outputs.get(output_name).ok_or_else(|| {
|
|
57
|
-
GteError::Inference(format!(
|
|
57
|
+
GteError::Inference(format!(
|
|
58
|
+
"output tensor '{}' not found in model outputs",
|
|
59
|
+
output_name
|
|
60
|
+
))
|
|
58
61
|
})?;
|
|
59
|
-
Ok(tensor_value.
|
|
62
|
+
Ok(tensor_value.try_extract_array::<f32>()?)
|
|
60
63
|
}
|
data/ext/gte/src/reranker.rs
CHANGED
|
@@ -6,10 +6,9 @@ use crate::model_profile::{
|
|
|
6
6
|
};
|
|
7
7
|
use crate::pipeline::{extract_output_tensor, InputTensors};
|
|
8
8
|
use crate::postprocess::sigmoid_scores;
|
|
9
|
-
use crate::session::build_session;
|
|
9
|
+
use crate::session::{build_session, SessionPool};
|
|
10
10
|
use crate::tokenizer::{parse_padding_mode_override, Tokenizer};
|
|
11
|
-
use
|
|
12
|
-
use std::path::Path;
|
|
11
|
+
use std::path::{Path, PathBuf};
|
|
13
12
|
|
|
14
13
|
#[derive(Debug, Clone)]
|
|
15
14
|
struct RerankerConfig {
|
|
@@ -22,7 +21,7 @@ struct RerankerConfig {
|
|
|
22
21
|
|
|
23
22
|
pub struct Reranker {
|
|
24
23
|
tokenizer: Tokenizer,
|
|
25
|
-
|
|
24
|
+
pool: SessionPool,
|
|
26
25
|
config: RerankerConfig,
|
|
27
26
|
}
|
|
28
27
|
|
|
@@ -35,7 +34,7 @@ impl Reranker {
|
|
|
35
34
|
) -> Result<Self> {
|
|
36
35
|
let dir = dir.as_ref();
|
|
37
36
|
let tokenizer_path = resolve_tokenizer_path(dir)?;
|
|
38
|
-
let model_path = match overrides.model_name.filter(|s| !s.is_empty()) {
|
|
37
|
+
let model_path: PathBuf = match overrides.model_name.filter(|s| !s.is_empty()) {
|
|
39
38
|
Some(name) => resolve_named_model(dir, name)?,
|
|
40
39
|
None => resolve_default_text_model(dir)?,
|
|
41
40
|
};
|
|
@@ -88,11 +87,8 @@ impl Reranker {
|
|
|
88
87
|
tokenizer_profile.fixed_padding_length,
|
|
89
88
|
)?;
|
|
90
89
|
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
session,
|
|
94
|
-
config,
|
|
95
|
-
})
|
|
90
|
+
let pool = SessionPool::new(session, model_path, probe_config);
|
|
91
|
+
Ok(Self { tokenizer, pool, config })
|
|
96
92
|
}
|
|
97
93
|
|
|
98
94
|
pub fn score_pairs(&self, pairs: &[(String, String)], apply_sigmoid: bool) -> Result<Vec<f32>> {
|
|
@@ -111,7 +107,8 @@ impl Reranker {
|
|
|
111
107
|
apply_sigmoid: bool,
|
|
112
108
|
) -> Result<Vec<f32>> {
|
|
113
109
|
let input_tensors = InputTensors::from_tokenized(tokenized, self.config.with_attention_mask)?;
|
|
114
|
-
let
|
|
110
|
+
let mut session = self.pool.acquire()?;
|
|
111
|
+
let outputs = session.run(input_tensors.inputs).map_err(|e| GteError::Ort(e.to_string()))?;
|
|
115
112
|
let array = extract_output_tensor(&outputs, self.config.output_tensor.as_str())?;
|
|
116
113
|
|
|
117
114
|
let mut scores = match array.ndim() {
|
|
data/ext/gte/src/ruby_embedder.rs
CHANGED
|
@@ -28,21 +28,24 @@ pub struct RbTensor {
|
|
|
28
28
|
data: Vec<f32>,
|
|
29
29
|
}
|
|
30
30
|
|
|
31
|
+
// ---------------------------------------------------------------------------
|
|
32
|
+
// GVL-release helpers
|
|
33
|
+
// ---------------------------------------------------------------------------
|
|
34
|
+
|
|
31
35
|
struct InferArgs {
|
|
32
36
|
embedder: *const Embedder,
|
|
33
37
|
texts: *const Vec<String>,
|
|
34
38
|
normalize: bool,
|
|
35
|
-
result: Option<Result<ndarray::Array2<f32
|
|
39
|
+
result: Option<crate::error::Result<ndarray::Array2<f32>>>,
|
|
36
40
|
}
|
|
37
41
|
|
|
38
42
|
unsafe impl Send for InferArgs {}
|
|
39
43
|
|
|
40
44
|
struct ScoreArgs {
|
|
41
45
|
reranker: *const Reranker,
|
|
42
|
-
|
|
43
|
-
candidates: *const Vec<String>,
|
|
46
|
+
pairs: *const Vec<(String, String)>,
|
|
44
47
|
apply_sigmoid: bool,
|
|
45
|
-
result: Option<Result<Vec<f32
|
|
48
|
+
result: Option<crate::error::Result<Vec<f32>>>,
|
|
46
49
|
}
|
|
47
50
|
|
|
48
51
|
unsafe impl Send for ScoreArgs {}
|
|
@@ -57,6 +60,38 @@ fn panic_payload_to_string(payload: Box<dyn std::any::Any + Send>) -> String {
|
|
|
57
60
|
}
|
|
58
61
|
}
|
|
59
62
|
|
|
63
|
+
unsafe extern "C" fn run_without_gvl(ptr: *mut c_void) -> *mut c_void {
|
|
64
|
+
let args = &mut *(ptr as *mut InferArgs);
|
|
65
|
+
let run_result = catch_unwind(AssertUnwindSafe(|| {
|
|
66
|
+
let tokenized = (*args.embedder).tokenize(&*args.texts)?;
|
|
67
|
+
let embeddings = (*args.embedder).run(&tokenized)?;
|
|
68
|
+
if args.normalize { Ok(normalize_l2(embeddings)) } else { Ok(embeddings) }
|
|
69
|
+
}));
|
|
70
|
+
args.result = Some(match run_result {
|
|
71
|
+
Ok(result) => result,
|
|
72
|
+
Err(payload) => Err(GteError::Inference(format!(
|
|
73
|
+
"panic during inference: {}",
|
|
74
|
+
panic_payload_to_string(payload),
|
|
75
|
+
))),
|
|
76
|
+
});
|
|
77
|
+
std::ptr::null_mut()
|
|
78
|
+
}
|
|
79
|
+
|
|
80
|
+
unsafe extern "C" fn run_score_without_gvl(ptr: *mut c_void) -> *mut c_void {
|
|
81
|
+
let args = &mut *(ptr as *mut ScoreArgs);
|
|
82
|
+
let run_result = catch_unwind(AssertUnwindSafe(|| {
|
|
83
|
+
(*args.reranker).score_pairs(&*args.pairs, args.apply_sigmoid)
|
|
84
|
+
}));
|
|
85
|
+
args.result = Some(match run_result {
|
|
86
|
+
Ok(result) => result,
|
|
87
|
+
Err(payload) => Err(GteError::Inference(format!(
|
|
88
|
+
"panic during reranking: {}",
|
|
89
|
+
panic_payload_to_string(payload),
|
|
90
|
+
))),
|
|
91
|
+
});
|
|
92
|
+
std::ptr::null_mut()
|
|
93
|
+
}
|
|
94
|
+
|
|
60
95
|
fn infer_without_gvl(
|
|
61
96
|
embedder: &Arc<Embedder>,
|
|
62
97
|
normalize: bool,
|
|
@@ -87,15 +122,13 @@ fn infer_without_gvl(
|
|
|
87
122
|
|
|
88
123
|
fn score_without_gvl(
|
|
89
124
|
reranker: &Arc<Reranker>,
|
|
90
|
-
|
|
91
|
-
candidates: Vec<String>,
|
|
125
|
+
pairs: Vec<(String, String)>,
|
|
92
126
|
apply_sigmoid: bool,
|
|
93
127
|
) -> Result<Vec<f32>, Error> {
|
|
94
128
|
let scores = unsafe {
|
|
95
129
|
let mut args = ScoreArgs {
|
|
96
130
|
reranker: Arc::as_ptr(reranker),
|
|
97
|
-
|
|
98
|
-
candidates: &candidates as *const Vec<String>,
|
|
131
|
+
pairs: &pairs as *const Vec<(String, String)>,
|
|
99
132
|
apply_sigmoid,
|
|
100
133
|
result: None,
|
|
101
134
|
};
|
|
@@ -115,41 +148,7 @@ fn score_without_gvl(
|
|
|
115
148
|
Ok(scores)
|
|
116
149
|
}
|
|
117
150
|
|
|
118
|
-
|
|
119
|
-
let args = &mut *(ptr as *mut InferArgs);
|
|
120
|
-
let run_result = catch_unwind(AssertUnwindSafe(|| {
|
|
121
|
-
let tokenized = (*args.embedder).tokenize(&*args.texts)?;
|
|
122
|
-
let embeddings = (*args.embedder).run(&tokenized)?;
|
|
123
|
-
if args.normalize {
|
|
124
|
-
Ok(normalize_l2(embeddings))
|
|
125
|
-
} else {
|
|
126
|
-
Ok(embeddings)
|
|
127
|
-
}
|
|
128
|
-
}));
|
|
129
|
-
args.result = Some(match run_result {
|
|
130
|
-
Ok(result) => result,
|
|
131
|
-
Err(payload) => Err(GteError::Inference(format!(
|
|
132
|
-
"panic during inference: {}",
|
|
133
|
-
panic_payload_to_string(payload),
|
|
134
|
-
))),
|
|
135
|
-
});
|
|
136
|
-
std::ptr::null_mut()
|
|
137
|
-
}
|
|
138
|
-
|
|
139
|
-
unsafe extern "C" fn run_score_without_gvl(ptr: *mut c_void) -> *mut c_void {
|
|
140
|
-
let args = &mut *(ptr as *mut ScoreArgs);
|
|
141
|
-
let run_result = catch_unwind(AssertUnwindSafe(|| {
|
|
142
|
-
(*args.reranker).score(&*args.query, &*args.candidates, args.apply_sigmoid)
|
|
143
|
-
}));
|
|
144
|
-
args.result = Some(match run_result {
|
|
145
|
-
Ok(result) => result,
|
|
146
|
-
Err(payload) => Err(GteError::Inference(format!(
|
|
147
|
-
"panic during reranking: {}",
|
|
148
|
-
panic_payload_to_string(payload),
|
|
149
|
-
))),
|
|
150
|
-
});
|
|
151
|
-
std::ptr::null_mut()
|
|
152
|
-
}
|
|
151
|
+
// ---------------------------------------------------------------------------
|
|
153
152
|
|
|
154
153
|
fn tensor_from_array(embeddings: ndarray::Array2<f32>) -> Result<RbTensor, Error> {
|
|
155
154
|
let rows = embeddings.nrows();
|
|
@@ -177,31 +176,11 @@ impl RbEmbedder {
|
|
|
177
176
|
padding: String,
|
|
178
177
|
execution_providers: String,
|
|
179
178
|
) -> Result<Self, Error> {
|
|
180
|
-
let name = if model_name.is_empty() {
|
|
181
|
-
|
|
182
|
-
} else {
|
|
183
|
-
|
|
184
|
-
};
|
|
185
|
-
let output_override = if output_tensor.is_empty() {
|
|
186
|
-
None
|
|
187
|
-
} else {
|
|
188
|
-
Some(output_tensor.as_str())
|
|
189
|
-
};
|
|
190
|
-
let max_length_override = if max_length == 0 {
|
|
191
|
-
None
|
|
192
|
-
} else {
|
|
193
|
-
Some(max_length)
|
|
194
|
-
};
|
|
195
|
-
let execution_providers_override = if execution_providers.is_empty() {
|
|
196
|
-
None
|
|
197
|
-
} else {
|
|
198
|
-
Some(execution_providers.as_str())
|
|
199
|
-
};
|
|
200
|
-
let padding_override = if padding.is_empty() {
|
|
201
|
-
None
|
|
202
|
-
} else {
|
|
203
|
-
Some(padding.as_str())
|
|
204
|
-
};
|
|
179
|
+
let name = if model_name.is_empty() { None } else { Some(model_name.as_str()) };
|
|
180
|
+
let output_override = if output_tensor.is_empty() { None } else { Some(output_tensor.as_str()) };
|
|
181
|
+
let max_length_override = if max_length == 0 { None } else { Some(max_length) };
|
|
182
|
+
let execution_providers_override = if execution_providers.is_empty() { None } else { Some(execution_providers.as_str()) };
|
|
183
|
+
let padding_override = if padding.is_empty() { None } else { Some(padding.as_str()) };
|
|
205
184
|
let overrides = ModelLoadOverrides {
|
|
206
185
|
model_name: name,
|
|
207
186
|
output_tensor: output_override,
|
|
@@ -209,17 +188,9 @@ impl RbEmbedder {
|
|
|
209
188
|
padding: padding_override,
|
|
210
189
|
execution_providers: execution_providers_override,
|
|
211
190
|
};
|
|
212
|
-
let embedder = Embedder::from_dir(
|
|
213
|
-
|
|
214
|
-
|
|
215
|
-
optimization_level,
|
|
216
|
-
overrides,
|
|
217
|
-
)
|
|
218
|
-
.map_err(magnus::Error::from)?;
|
|
219
|
-
Ok(RbEmbedder {
|
|
220
|
-
inner: Arc::new(embedder),
|
|
221
|
-
normalize,
|
|
222
|
-
})
|
|
191
|
+
let embedder = Embedder::from_dir(&dir_path, num_threads, optimization_level, overrides)
|
|
192
|
+
.map_err(magnus::Error::from)?;
|
|
193
|
+
Ok(RbEmbedder { inner: Arc::new(embedder), normalize })
|
|
223
194
|
}
|
|
224
195
|
|
|
225
196
|
pub fn rb_embed(_ruby: &Ruby, rb_self: &Self, texts: RArray) -> Result<RbTensor, Error> {
|
|
@@ -247,31 +218,11 @@ impl RbReranker {
|
|
|
247
218
|
padding: String,
|
|
248
219
|
execution_providers: String,
|
|
249
220
|
) -> Result<Self, Error> {
|
|
250
|
-
let name = if model_name.is_empty() {
|
|
251
|
-
|
|
252
|
-
} else {
|
|
253
|
-
|
|
254
|
-
};
|
|
255
|
-
let output_override = if output_tensor.is_empty() {
|
|
256
|
-
None
|
|
257
|
-
} else {
|
|
258
|
-
Some(output_tensor.as_str())
|
|
259
|
-
};
|
|
260
|
-
let max_length_override = if max_length == 0 {
|
|
261
|
-
None
|
|
262
|
-
} else {
|
|
263
|
-
Some(max_length)
|
|
264
|
-
};
|
|
265
|
-
let execution_providers_override = if execution_providers.is_empty() {
|
|
266
|
-
None
|
|
267
|
-
} else {
|
|
268
|
-
Some(execution_providers.as_str())
|
|
269
|
-
};
|
|
270
|
-
let padding_override = if padding.is_empty() {
|
|
271
|
-
None
|
|
272
|
-
} else {
|
|
273
|
-
Some(padding.as_str())
|
|
274
|
-
};
|
|
221
|
+
let name = if model_name.is_empty() { None } else { Some(model_name.as_str()) };
|
|
222
|
+
let output_override = if output_tensor.is_empty() { None } else { Some(output_tensor.as_str()) };
|
|
223
|
+
let max_length_override = if max_length == 0 { None } else { Some(max_length) };
|
|
224
|
+
let execution_providers_override = if execution_providers.is_empty() { None } else { Some(execution_providers.as_str()) };
|
|
225
|
+
let padding_override = if padding.is_empty() { None } else { Some(padding.as_str()) };
|
|
275
226
|
let overrides = ModelLoadOverrides {
|
|
276
227
|
model_name: name,
|
|
277
228
|
output_tensor: output_override,
|
|
@@ -279,17 +230,9 @@ impl RbReranker {
|
|
|
279
230
|
padding: padding_override,
|
|
280
231
|
execution_providers: execution_providers_override,
|
|
281
232
|
};
|
|
282
|
-
let reranker = Reranker::from_dir(
|
|
283
|
-
|
|
284
|
-
|
|
285
|
-
optimization_level,
|
|
286
|
-
overrides,
|
|
287
|
-
)
|
|
288
|
-
.map_err(magnus::Error::from)?;
|
|
289
|
-
Ok(RbReranker {
|
|
290
|
-
inner: Arc::new(reranker),
|
|
291
|
-
sigmoid,
|
|
292
|
-
})
|
|
233
|
+
let reranker = Reranker::from_dir(&dir_path, num_threads, optimization_level, overrides)
|
|
234
|
+
.map_err(magnus::Error::from)?;
|
|
235
|
+
Ok(RbReranker { inner: Arc::new(reranker), sigmoid })
|
|
293
236
|
}
|
|
294
237
|
|
|
295
238
|
pub fn rb_score(
|
|
@@ -299,8 +242,8 @@ impl RbReranker {
|
|
|
299
242
|
candidates: RArray,
|
|
300
243
|
) -> Result<RArray, Error> {
|
|
301
244
|
let candidates: Vec<String> = candidates.to_vec()?;
|
|
302
|
-
let
|
|
303
|
-
|
|
245
|
+
let pairs: Vec<(String, String)> = candidates.into_iter().map(|c| (query.clone(), c)).collect();
|
|
246
|
+
let scores = score_without_gvl(&rb_self.inner, pairs, rb_self.sigmoid)?;
|
|
304
247
|
let out = ruby.ary_new_capa(scores.len());
|
|
305
248
|
for score in scores {
|
|
306
249
|
out.push(score)?;
|
|
@@ -336,7 +279,6 @@ impl RbTensor {
|
|
|
336
279
|
index, rb_self.rows
|
|
337
280
|
))));
|
|
338
281
|
}
|
|
339
|
-
|
|
340
282
|
let start = index * rb_self.cols;
|
|
341
283
|
let end = start + rb_self.cols;
|
|
342
284
|
let out = ruby.ary_new_capa(rb_self.cols);
|
|
@@ -361,7 +303,6 @@ impl RbTensor {
|
|
|
361
303
|
index, rb_self.rows
|
|
362
304
|
))));
|
|
363
305
|
}
|
|
364
|
-
|
|
365
306
|
let start = index * rb_self.cols;
|
|
366
307
|
let end = start + rb_self.cols;
|
|
367
308
|
let bytes = unsafe {
|
data/ext/gte/src/session.rs
CHANGED
|
@@ -3,12 +3,14 @@ use crate::model_config::{ExtractorMode, ModelConfig};
|
|
|
3
3
|
use crate::pipeline::{extract_output_tensor, InputTensors};
|
|
4
4
|
use crate::postprocess::mean_pool;
|
|
5
5
|
use crate::tokenizer::Tokenized;
|
|
6
|
-
use ndarray::{Array2, Ix2};
|
|
6
|
+
use ndarray::{Array2, ArrayView2, ArrayViewD, Ix2};
|
|
7
7
|
use ort::execution_providers::{
|
|
8
8
|
CoreMLExecutionProvider, ExecutionProviderDispatch, XNNPACKExecutionProvider,
|
|
9
9
|
};
|
|
10
10
|
use ort::session::Session;
|
|
11
|
-
use std::path::Path;
|
|
11
|
+
use std::path::{Path, PathBuf};
|
|
12
|
+
use std::sync::atomic::{AtomicUsize, Ordering};
|
|
13
|
+
use std::sync::{Condvar, Mutex};
|
|
12
14
|
|
|
13
15
|
pub fn build_session<P: AsRef<Path>>(model_path: P, config: &ModelConfig) -> Result<Session> {
|
|
14
16
|
let opt_level = match config.optimization_level {
|
|
@@ -18,22 +20,176 @@ pub fn build_session<P: AsRef<Path>>(model_path: P, config: &ModelConfig) -> Res
|
|
|
18
20
|
_ => ort::session::builder::GraphOptimizationLevel::Level3,
|
|
19
21
|
};
|
|
20
22
|
|
|
21
|
-
|
|
22
|
-
.
|
|
23
|
-
|
|
23
|
+
fn ort_err(e: impl std::fmt::Display) -> GteError {
|
|
24
|
+
GteError::Ort(e.to_string())
|
|
25
|
+
}
|
|
26
|
+
|
|
27
|
+
let mut builder = Session::builder()
|
|
28
|
+
.map_err(ort_err)?
|
|
29
|
+
.with_optimization_level(opt_level)
|
|
30
|
+
.map_err(ort_err)?
|
|
31
|
+
.with_memory_pattern(true)
|
|
32
|
+
.map_err(ort_err)?;
|
|
24
33
|
|
|
25
34
|
let providers = preferred_execution_providers(config.execution_providers.as_deref());
|
|
26
35
|
if !providers.is_empty() {
|
|
27
|
-
builder = builder
|
|
36
|
+
builder = builder
|
|
37
|
+
.with_execution_providers(providers)
|
|
38
|
+
.map_err(ort_err)?;
|
|
28
39
|
}
|
|
29
40
|
|
|
30
41
|
if config.num_threads > 0 {
|
|
31
|
-
builder = builder
|
|
42
|
+
builder = builder
|
|
43
|
+
.with_intra_threads(config.num_threads)
|
|
44
|
+
.map_err(ort_err)?;
|
|
45
|
+
builder = builder
|
|
46
|
+
.with_inter_threads(config.num_threads)
|
|
47
|
+
.map_err(ort_err)?;
|
|
48
|
+
}
|
|
49
|
+
|
|
50
|
+
builder.commit_from_file(model_path).map_err(ort_err)
|
|
51
|
+
}
|
|
52
|
+
|
|
53
|
+
// ---------------------------------------------------------------------------
|
|
54
|
+
// Session pool
|
|
55
|
+
// ---------------------------------------------------------------------------
|
|
56
|
+
|
|
57
|
+
const AUTO_THREAD_POOL_CAP: usize = 6;
|
|
58
|
+
|
|
59
|
+
/// Keep enough sessions to cover the configured thread budget without
|
|
60
|
+
/// oversubscribing CPU parallelism. In ORT auto-thread mode (`num_threads == 0`)
|
|
61
|
+
/// we still keep a modest pool because request-level concurrency benefits from
|
|
62
|
+
/// more than one session even when ORT manages thread counts internally.
|
|
63
|
+
fn pool_capacity(num_threads: usize) -> usize {
|
|
64
|
+
let available_parallelism = std::thread::available_parallelism()
|
|
65
|
+
.map(|n| n.get())
|
|
66
|
+
.unwrap_or(1);
|
|
67
|
+
pool_capacity_with_parallelism(num_threads, available_parallelism)
|
|
68
|
+
}
|
|
69
|
+
|
|
70
|
+
fn pool_capacity_with_parallelism(num_threads: usize, available_parallelism: usize) -> usize {
|
|
71
|
+
if available_parallelism == 0 {
|
|
72
|
+
return 1;
|
|
73
|
+
}
|
|
74
|
+
|
|
75
|
+
if num_threads == 0 {
|
|
76
|
+
return available_parallelism.clamp(1, AUTO_THREAD_POOL_CAP);
|
|
77
|
+
}
|
|
78
|
+
|
|
79
|
+
available_parallelism.div_ceil(num_threads).max(1)
|
|
80
|
+
}
|
|
81
|
+
|
|
82
|
+
pub struct SessionPool {
|
|
83
|
+
sessions: Mutex<Vec<Session>>,
|
|
84
|
+
available: Condvar,
|
|
85
|
+
created: AtomicUsize,
|
|
86
|
+
capacity: usize,
|
|
87
|
+
model_path: PathBuf,
|
|
88
|
+
build_config: ModelConfig,
|
|
89
|
+
}
|
|
90
|
+
|
|
91
|
+
impl SessionPool {
|
|
92
|
+
pub fn new(initial: Session, model_path: PathBuf, build_config: ModelConfig) -> Self {
|
|
93
|
+
let capacity = pool_capacity(build_config.num_threads);
|
|
94
|
+
Self {
|
|
95
|
+
sessions: Mutex::new(vec![initial]),
|
|
96
|
+
available: Condvar::new(),
|
|
97
|
+
created: AtomicUsize::new(1),
|
|
98
|
+
capacity,
|
|
99
|
+
model_path,
|
|
100
|
+
build_config,
|
|
101
|
+
}
|
|
102
|
+
}
|
|
103
|
+
|
|
104
|
+
pub fn acquire(&self) -> Result<PooledSession<'_>> {
|
|
105
|
+
if let Some(session) = self.take_available_session() {
|
|
106
|
+
return Ok(PooledSession {
|
|
107
|
+
pool: self,
|
|
108
|
+
session: Some(session),
|
|
109
|
+
});
|
|
110
|
+
}
|
|
111
|
+
|
|
112
|
+
if let Some(session) = self.try_grow()? {
|
|
113
|
+
return Ok(PooledSession {
|
|
114
|
+
pool: self,
|
|
115
|
+
session: Some(session),
|
|
116
|
+
});
|
|
117
|
+
}
|
|
118
|
+
|
|
119
|
+
let session = self.wait_for_session();
|
|
120
|
+
Ok(PooledSession {
|
|
121
|
+
pool: self,
|
|
122
|
+
session: Some(session),
|
|
123
|
+
})
|
|
124
|
+
}
|
|
125
|
+
|
|
126
|
+
fn release(&self, session: Session) {
|
|
127
|
+
self.sessions.lock().unwrap().push(session);
|
|
128
|
+
self.available.notify_one();
|
|
129
|
+
}
|
|
130
|
+
|
|
131
|
+
fn take_available_session(&self) -> Option<Session> {
|
|
132
|
+
self.sessions.lock().unwrap().pop()
|
|
133
|
+
}
|
|
134
|
+
|
|
135
|
+
fn try_grow(&self) -> Result<Option<Session>> {
|
|
136
|
+
let grew = self
|
|
137
|
+
.created
|
|
138
|
+
.fetch_update(Ordering::AcqRel, Ordering::Acquire, |count| {
|
|
139
|
+
(count < self.capacity).then_some(count + 1)
|
|
140
|
+
});
|
|
141
|
+
if grew.is_err() {
|
|
142
|
+
return Ok(None);
|
|
143
|
+
}
|
|
144
|
+
|
|
145
|
+
match build_session(&self.model_path, &self.build_config) {
|
|
146
|
+
Ok(session) => Ok(Some(session)),
|
|
147
|
+
Err(error) => {
|
|
148
|
+
self.created.fetch_sub(1, Ordering::AcqRel);
|
|
149
|
+
Err(error)
|
|
150
|
+
}
|
|
151
|
+
}
|
|
152
|
+
}
|
|
153
|
+
|
|
154
|
+
fn wait_for_session(&self) -> Session {
|
|
155
|
+
let mut lock = self.sessions.lock().unwrap();
|
|
156
|
+
loop {
|
|
157
|
+
if let Some(session) = lock.pop() {
|
|
158
|
+
return session;
|
|
159
|
+
}
|
|
160
|
+
lock = self.available.wait(lock).unwrap();
|
|
161
|
+
}
|
|
162
|
+
}
|
|
163
|
+
}
|
|
164
|
+
|
|
165
|
+
pub struct PooledSession<'a> {
|
|
166
|
+
pool: &'a SessionPool,
|
|
167
|
+
session: Option<Session>,
|
|
168
|
+
}
|
|
169
|
+
|
|
170
|
+
impl std::ops::Deref for PooledSession<'_> {
|
|
171
|
+
type Target = Session;
|
|
172
|
+
fn deref(&self) -> &Session {
|
|
173
|
+
self.session.as_ref().unwrap()
|
|
32
174
|
}
|
|
175
|
+
}
|
|
33
176
|
|
|
34
|
-
|
|
177
|
+
impl std::ops::DerefMut for PooledSession<'_> {
|
|
178
|
+
fn deref_mut(&mut self) -> &mut Session {
|
|
179
|
+
self.session.as_mut().unwrap()
|
|
180
|
+
}
|
|
35
181
|
}
|
|
36
182
|
|
|
183
|
+
impl Drop for PooledSession<'_> {
|
|
184
|
+
fn drop(&mut self) {
|
|
185
|
+
if let Some(s) = self.session.take() {
|
|
186
|
+
self.pool.release(s);
|
|
187
|
+
}
|
|
188
|
+
}
|
|
189
|
+
}
|
|
190
|
+
|
|
191
|
+
// ---------------------------------------------------------------------------
|
|
192
|
+
|
|
37
193
|
fn preferred_execution_providers(order_override: Option<&str>) -> Vec<ExecutionProviderDispatch> {
|
|
38
194
|
let order = resolve_provider_order(order_override);
|
|
39
195
|
|
|
@@ -55,7 +211,10 @@ fn resolve_provider_order(order_override: Option<&str>) -> String {
|
|
|
55
211
|
resolve_provider_order_with_env(order_override, env_order.as_deref())
|
|
56
212
|
}
|
|
57
213
|
|
|
58
|
-
fn resolve_provider_order_with_env(
|
|
214
|
+
fn resolve_provider_order_with_env(
|
|
215
|
+
order_override: Option<&str>,
|
|
216
|
+
env_order: Option<&str>,
|
|
217
|
+
) -> String {
|
|
59
218
|
order_override
|
|
60
219
|
.or(env_order)
|
|
61
220
|
.unwrap_or("cpu")
|
|
@@ -75,14 +234,24 @@ fn parse_provider_registrations(order: &str) -> Vec<&str> {
|
|
|
75
234
|
}
|
|
76
235
|
|
|
77
236
|
pub fn run_session(
|
|
78
|
-
session: &Session,
|
|
237
|
+
session: &mut Session,
|
|
79
238
|
tokenized: &Tokenized,
|
|
80
239
|
config: &ModelConfig,
|
|
81
240
|
) -> Result<Array2<f32>> {
|
|
82
241
|
let input_tensors = InputTensors::from_tokenized(tokenized, config.with_attention_mask)?;
|
|
83
|
-
let outputs = session
|
|
242
|
+
let outputs = session
|
|
243
|
+
.run(input_tensors.inputs)
|
|
244
|
+
.map_err(|e| GteError::Ort(e.to_string()))?;
|
|
84
245
|
let array = extract_output_tensor(&outputs, config.output_tensor.as_str())?;
|
|
85
246
|
|
|
247
|
+
extract_embeddings(array, input_tensors.attention_mask, config)
|
|
248
|
+
}
|
|
249
|
+
|
|
250
|
+
fn extract_embeddings(
|
|
251
|
+
array: ArrayViewD<'_, f32>,
|
|
252
|
+
attention_mask: ArrayView2<'_, i64>,
|
|
253
|
+
config: &ModelConfig,
|
|
254
|
+
) -> Result<Array2<f32>> {
|
|
86
255
|
match config.mode {
|
|
87
256
|
ExtractorMode::Token(idx) => {
|
|
88
257
|
let shape = array.shape();
|
|
@@ -102,15 +271,43 @@ pub fn run_session(
|
|
|
102
271
|
ndim
|
|
103
272
|
))
|
|
104
273
|
})?;
|
|
105
|
-
mean_pool(hidden_states
|
|
274
|
+
mean_pool(hidden_states, attention_mask)
|
|
106
275
|
}
|
|
107
|
-
ExtractorMode::Raw =>
|
|
276
|
+
ExtractorMode::Raw => array
|
|
277
|
+
.into_dimensionality::<Ix2>()
|
|
278
|
+
.map(|view| view.to_owned())
|
|
279
|
+
.map_err(|e| GteError::Shape(e.to_string())),
|
|
108
280
|
}
|
|
109
281
|
}
|
|
110
282
|
|
|
111
283
|
#[cfg(test)]
|
|
112
284
|
mod tests {
|
|
113
|
-
use
|
|
285
|
+
use crate::model_config::{ExtractorMode, ModelConfig, PaddingMode};
|
|
286
|
+
use ndarray::{array, ArrayView2};
|
|
287
|
+
|
|
288
|
+
use super::{
|
|
289
|
+
extract_embeddings, parse_provider_registrations, pool_capacity_with_parallelism,
|
|
290
|
+
resolve_provider_order_with_env,
|
|
291
|
+
};
|
|
292
|
+
|
|
293
|
+
fn test_config(mode: ExtractorMode) -> ModelConfig {
|
|
294
|
+
ModelConfig {
|
|
295
|
+
max_length: 8,
|
|
296
|
+
padding_mode: PaddingMode::BatchLongest,
|
|
297
|
+
output_tensor: "output".to_string(),
|
|
298
|
+
mode,
|
|
299
|
+
with_type_ids: false,
|
|
300
|
+
with_attention_mask: true,
|
|
301
|
+
num_threads: 1,
|
|
302
|
+
optimization_level: 3,
|
|
303
|
+
execution_providers: None,
|
|
304
|
+
}
|
|
305
|
+
}
|
|
306
|
+
|
|
307
|
+
fn empty_attention_mask() -> ArrayView2<'static, i64> {
|
|
308
|
+
static EMPTY: [i64; 0] = [];
|
|
309
|
+
ArrayView2::from_shape((0, 0), &EMPTY).unwrap()
|
|
310
|
+
}
|
|
114
311
|
|
|
115
312
|
#[test]
|
|
116
313
|
fn parse_provider_registrations_keeps_supported_order() {
|
|
@@ -142,7 +339,74 @@ mod tests {
|
|
|
142
339
|
|
|
143
340
|
#[test]
|
|
144
341
|
fn resolve_provider_order_falls_back_to_env_then_cpu_default() {
|
|
145
|
-
assert_eq!(
|
|
342
|
+
assert_eq!(
|
|
343
|
+
resolve_provider_order_with_env(None, Some("coreml")),
|
|
344
|
+
"coreml"
|
|
345
|
+
);
|
|
146
346
|
assert_eq!(resolve_provider_order_with_env(None, None), "cpu");
|
|
147
347
|
}
|
|
348
|
+
|
|
349
|
+
#[test]
|
|
350
|
+
fn pool_capacity_uses_bounded_parallel_pool_for_auto_thread_mode() {
|
|
351
|
+
assert_eq!(pool_capacity_with_parallelism(0, 1), 1);
|
|
352
|
+
assert_eq!(pool_capacity_with_parallelism(0, 4), 4);
|
|
353
|
+
assert_eq!(pool_capacity_with_parallelism(0, 8), 6);
|
|
354
|
+
}
|
|
355
|
+
|
|
356
|
+
#[test]
|
|
357
|
+
fn pool_capacity_scales_with_available_parallelism() {
|
|
358
|
+
assert_eq!(pool_capacity_with_parallelism(1, 1), 1);
|
|
359
|
+
assert_eq!(pool_capacity_with_parallelism(1, 8), 8);
|
|
360
|
+
assert_eq!(pool_capacity_with_parallelism(2, 8), 4);
|
|
361
|
+
assert_eq!(pool_capacity_with_parallelism(3, 8), 3);
|
|
362
|
+
assert_eq!(pool_capacity_with_parallelism(8, 4), 1);
|
|
363
|
+
}
|
|
364
|
+
|
|
365
|
+
#[test]
|
|
366
|
+
fn extract_embeddings_raw_copies_only_final_matrix() {
|
|
367
|
+
let output = array![[1.0f32, 2.0], [3.0, 4.0]];
|
|
368
|
+
let extracted = extract_embeddings(
|
|
369
|
+
output.view().into_dyn(),
|
|
370
|
+
empty_attention_mask(),
|
|
371
|
+
&test_config(ExtractorMode::Raw),
|
|
372
|
+
)
|
|
373
|
+
.unwrap();
|
|
374
|
+
|
|
375
|
+
assert_eq!(extracted, output);
|
|
376
|
+
}
|
|
377
|
+
|
|
378
|
+
#[test]
|
|
379
|
+
fn extract_embeddings_token_selects_without_copying_full_sequence() {
|
|
380
|
+
let output = array![
|
|
381
|
+
[[1.0f32, 2.0], [3.0, 4.0], [5.0, 6.0]],
|
|
382
|
+
[[7.0, 8.0], [9.0, 10.0], [11.0, 12.0]]
|
|
383
|
+
];
|
|
384
|
+
let expected = array![[3.0f32, 4.0], [9.0, 10.0]];
|
|
385
|
+
let extracted = extract_embeddings(
|
|
386
|
+
output.view().into_dyn(),
|
|
387
|
+
empty_attention_mask(),
|
|
388
|
+
&test_config(ExtractorMode::Token(1)),
|
|
389
|
+
)
|
|
390
|
+
.unwrap();
|
|
391
|
+
|
|
392
|
+
assert_eq!(extracted, expected);
|
|
393
|
+
}
|
|
394
|
+
|
|
395
|
+
#[test]
|
|
396
|
+
fn extract_embeddings_mean_pool_uses_output_view_and_attention_mask() {
|
|
397
|
+
let output = array![
|
|
398
|
+
[[1.0f32, 3.0], [5.0, 7.0], [100.0, 100.0]],
|
|
399
|
+
[[2.0, 4.0], [6.0, 8.0], [10.0, 12.0]]
|
|
400
|
+
];
|
|
401
|
+
let attention_mask = array![[1_i64, 1, 0], [0, 1, 1]];
|
|
402
|
+
let expected = array![[3.0f32, 5.0], [8.0, 10.0]];
|
|
403
|
+
let extracted = extract_embeddings(
|
|
404
|
+
output.view().into_dyn(),
|
|
405
|
+
attention_mask.view(),
|
|
406
|
+
&test_config(ExtractorMode::MeanPool),
|
|
407
|
+
)
|
|
408
|
+
.unwrap();
|
|
409
|
+
|
|
410
|
+
assert_eq!(extracted, expected);
|
|
411
|
+
}
|
|
148
412
|
}
|
data/lib/gte/embedder.rb
CHANGED
|
@@ -2,6 +2,9 @@
|
|
|
2
2
|
|
|
3
3
|
module GTE
|
|
4
4
|
class Embedder
|
|
5
|
+
DEFAULT_THREADS = 1
|
|
6
|
+
DEFAULT_OPTIMIZATION_LEVEL = 3
|
|
7
|
+
|
|
5
8
|
class << self
|
|
6
9
|
def config(model_dir)
|
|
7
10
|
cfg = default_config(model_dir)
|
|
@@ -23,13 +26,11 @@ module GTE
|
|
|
23
26
|
)
|
|
24
27
|
end
|
|
25
28
|
|
|
26
|
-
private
|
|
27
|
-
|
|
28
29
|
def default_config(model_dir)
|
|
29
30
|
Config::Text.new(
|
|
30
31
|
model_dir: File.expand_path(model_dir),
|
|
31
|
-
threads:
|
|
32
|
-
optimization_level:
|
|
32
|
+
threads: DEFAULT_THREADS,
|
|
33
|
+
optimization_level: DEFAULT_OPTIMIZATION_LEVEL,
|
|
33
34
|
model_name: nil,
|
|
34
35
|
normalize: true,
|
|
35
36
|
output_tensor: nil,
|
data/lib/gte/reranker.rb
CHANGED
data/lib/gte.rb
CHANGED
|
@@ -19,17 +19,7 @@ module GTE
|
|
|
19
19
|
|
|
20
20
|
class << self
|
|
21
21
|
def config(model_dir)
|
|
22
|
-
cfg =
|
|
23
|
-
model_dir: File.expand_path(model_dir),
|
|
24
|
-
threads: 3,
|
|
25
|
-
optimization_level: 3,
|
|
26
|
-
model_name: nil,
|
|
27
|
-
normalize: true,
|
|
28
|
-
output_tensor: nil,
|
|
29
|
-
max_length: nil,
|
|
30
|
-
padding: nil,
|
|
31
|
-
execution_providers: nil
|
|
32
|
-
)
|
|
22
|
+
cfg = Embedder.default_config(model_dir)
|
|
33
23
|
|
|
34
24
|
cfg = yield(cfg) if block_given?
|
|
35
25
|
|
metadata
CHANGED
|
@@ -1,14 +1,14 @@
|
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
|
2
2
|
name: gte
|
|
3
3
|
version: !ruby/object:Gem::Version
|
|
4
|
-
version: 0.0.
|
|
4
|
+
version: 0.0.8
|
|
5
5
|
platform: ruby
|
|
6
6
|
authors:
|
|
7
7
|
- elcuervo
|
|
8
8
|
autorequire:
|
|
9
9
|
bindir: bin
|
|
10
10
|
cert_chain: []
|
|
11
|
-
date: 2026-04-
|
|
11
|
+
date: 2026-04-28 00:00:00.000000000 Z
|
|
12
12
|
dependencies:
|
|
13
13
|
- !ruby/object:Gem::Dependency
|
|
14
14
|
name: rake
|
|
@@ -42,16 +42,16 @@ dependencies:
|
|
|
42
42
|
name: rb_sys
|
|
43
43
|
requirement: !ruby/object:Gem::Requirement
|
|
44
44
|
requirements:
|
|
45
|
-
- -
|
|
45
|
+
- - '='
|
|
46
46
|
- !ruby/object:Gem::Version
|
|
47
|
-
version:
|
|
47
|
+
version: 0.9.126
|
|
48
48
|
type: :runtime
|
|
49
49
|
prerelease: false
|
|
50
50
|
version_requirements: !ruby/object:Gem::Requirement
|
|
51
51
|
requirements:
|
|
52
|
-
- -
|
|
52
|
+
- - '='
|
|
53
53
|
- !ruby/object:Gem::Version
|
|
54
|
-
version:
|
|
54
|
+
version: 0.9.126
|
|
55
55
|
- !ruby/object:Gem::Dependency
|
|
56
56
|
name: rspec
|
|
57
57
|
requirement: !ruby/object:Gem::Requirement
|