gte 0.0.6 → 0.0.8
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/README.md +16 -8
- data/Rakefile +38 -3
- data/VERSION +1 -1
- data/ext/gte/Cargo.toml +4 -4
- data/ext/gte/src/embedder.rs +42 -33
- data/ext/gte/src/model_config.rs +18 -0
- data/ext/gte/src/model_profile.rs +129 -33
- data/ext/gte/src/pipeline.rs +12 -9
- data/ext/gte/src/reranker.rs +49 -31
- data/ext/gte/src/ruby_embedder.rs +73 -113
- data/ext/gte/src/session.rs +279 -15
- data/ext/gte/src/tokenizer.rs +99 -14
- data/ext/gte/tests/inference_integration_test.rs +5 -4
- data/ext/gte/tests/tokenizer_unit_test.rs +5 -2
- data/lib/gte/config.rb +2 -2
- data/lib/gte/embedder.rb +7 -4
- data/lib/gte/reranker.rb +3 -1
- data/lib/gte.rb +1 -10
- metadata +6 -6
checksums.yaml
CHANGED

@@ -1,7 +1,7 @@
 ---
 SHA256:
-  metadata.gz:
-  data.tar.gz:
+  metadata.gz: 2c754b4675ee105e9a280cd9deafa00a81b9e02ee629131f3e908400006b6ae4
+  data.tar.gz: 40a0d3e04c3d2943ae50910164d644ecb763eac99a02044dc962cc141a0e13c5
 SHA512:
-  metadata.gz:
-  data.tar.gz:
+  metadata.gz: 16614e01e7a33a53339ba9fe7cf32fe7606041518a24177258d7a6e5550516e8cff741d0f0df02b7e5863fc763c02ae81b943dc4b18295701a4cafdec6627cb0
+  data.tar.gz: 348e1fd1d9f4c44214b5101ba339109b5ececfbef18b48b7c11324a64481f476d8da831cc5148d17a85c41b525ee753c296d4421a4fb2adda269a3f5fe38cda6
data/README.md
CHANGED

@@ -33,14 +33,15 @@ raw_model = GTE.config(ENV.fetch("GTE_MODEL_DIR")) do |config|
   config.with(normalize: false)
 end
 
-
-  config.with(threads:
+single_thread = GTE.config(ENV.fetch("GTE_MODEL_DIR")) do |config|
+  config.with(threads: 1)
 end
 
 custom = GTE.config(ENV.fetch("GTE_MODEL_DIR")) do |config|
   config.with(
     output_tensor: "last_hidden_state",
     max_length: 256,
+    padding: "batch_longest",
     optimization_level: 3
   )
 end
@@ -49,12 +50,13 @@ end
 Config fields and defaults:
 
 - `model_dir`: absolute path to model directory
-- `threads`: `
+- `threads`: `1` (default tuned for p95 latency; use `0` for ONNX Runtime auto-thread mode)
 - `optimization_level`: `3`
 - `model_name`: `nil`
 - `normalize`: `true` (L2 normalization at Ruby-facing API)
 - `output_tensor`: `nil` (auto-select output tensor)
 - `max_length`: `nil` (uses tokenizer/model defaults)
+- `padding`: `nil` (auto; accepts `auto`, `batch_longest`, `fixed`)
 - `execution_providers`: `nil` (falls back to `GTE_EXECUTION_PROVIDERS` / CPU default)
 
 Notes:
@@ -66,7 +68,7 @@ Low-level embedder setup (without model cache):
 
 ```ruby
 embedder = GTE::Embedder.config(ENV.fetch("GTE_MODEL_DIR")) do |config|
-  config.with(threads:
+  config.with(threads: 1, execution_providers: "cpu")
 end
 ```
 
@@ -76,7 +78,7 @@ Use `GTE::Reranker.config(model_dir)` for cross-encoder reranking.
 
 ```ruby
 reranker = GTE::Reranker.config(ENV.fetch("GTE_RERANK_DIR")) do |config|
-  config.with(sigmoid: true, threads:
+  config.with(sigmoid: true, threads: 1)
 end
 
 query = "how to train a neural network?"
@@ -100,12 +102,13 @@ ranked = reranker.rerank(query: query, candidates: candidates)
 Reranker config fields and defaults:
 
 - `model_dir`: absolute path to model directory
-- `threads`: `
+- `threads`: `1`
 - `optimization_level`: `3`
 - `model_name`: `nil`
 - `sigmoid`: `false` (set `true` if you want bounded [0,1] style scores)
 - `output_tensor`: `nil`
 - `max_length`: `nil`
+- `padding`: `nil` (auto; accepts `auto`, `batch_longest`, `fixed`)
 - `execution_providers`: `nil`
 
 ## Runtime + Result Examples
@@ -171,7 +174,7 @@ make ci
 
 ## Benchmark
 
-The repo includes
+The repo includes a shared multi-runtime benchmark harness:
 
 ```bash
 make bench
@@ -180,6 +183,11 @@ nix develop -c bundle exec rake bench:matrix_sweep
 nix develop -c bundle exec ruby bench/memory_probe.rb --compare-pure
 ```
 
+- `make bench`: Puma-like single-request comparison at concurrency `16`
+- `rake bench:pure_compare`: batch amortization comparison
+- `rake bench:matrix_sweep`: GTE provider/thread sweep using the shared result schema
+- Optional Python comparisons use `bench/python_onnxruntime.py` and are skipped automatically if local dependencies are unavailable.
+
 To run benchmark + append a `RUNS.md` entry + enforce goal checks:
 
 ```bash
@@ -188,5 +196,5 @@ make bench-record
 
 `bench/runs_ledger.rb check` is goal-focused by default:
 
-- Enforces goal metric (`response_time_p95`
+- Enforces the goal metric (`response_time_p95`) across every enabled competitor.
 - Does not require current-version coverage in `RUNS.md` unless explicitly enabled.
data/Rakefile
CHANGED

@@ -10,17 +10,52 @@ rescue LoadError
 end
 
 spec = Gem::Specification.load('gte.gemspec')
+cross_target = ENV.fetch('RUBY_TARGET', nil)
 
-
+if cross_target == 'arm64-darwin'
+  # rb-sys-dock's darwin image can expose an unusable default LIBRARY_PATH.
+  # Force the compiler-rt darwin runtime directory so -lclang_rt.osx resolves.
+  ENV['LIBRARY_PATH'] = '/usr/lib/llvm-10/lib/clang/10.0.0/lib/darwin'
+end
+
+extension_task = Rake::ExtensionTask.new('gte', spec) do |ext|
   ext.lib_dir = 'lib/gte'
   ext.cross_compile = true
-
+  # rb-sys-dock invokes `rake native:$RUBY_TARGET gem` without the `cross` task,
+  # so scope platforms during dock builds to avoid host-Ruby fallback copy tasks.
+  cross_platforms = if cross_target && !cross_target.empty?
+    [cross_target]
+  else
+    %w[x86_64-linux aarch64-linux arm64-darwin]
+  end
+  ext.cross_platform = cross_platforms
+end
+
+if cross_target && !cross_target.empty? && ENV['RUBY_CC_VERSION']
+  ruby_version = ENV['RUBY_CC_VERSION'].split(':').first
+  lib_binary_path = File.join(extension_task.lib_dir, File.basename(extension_task.binary(cross_target)))
+  copy_task = "copy:gte:#{cross_target}:#{ruby_version}"
+
+  if Rake::Task.task_defined?(lib_binary_path) && Rake::Task.task_defined?(copy_task)
+    Rake::Task[lib_binary_path].prerequisites.clear
+    Rake::Task[lib_binary_path].enhance([copy_task])
+  end
 end
 
 task default: %i[compile spec]
 
+def bundler_env
+  root = File.expand_path(__dir__)
+  {
+    'BUNDLE_DISABLE_SHARED_GEMS' => '1',
+    'GEM_HOME' => File.join(root, '.bundle-gems'),
+    'GEM_PATH' => File.join(root, '.bundle-gems'),
+    'BUNDLE_PATH' => File.join(root, 'vendor/bundle')
+  }
+end
+
 def run_in_nix(*command)
-  sh('nix', 'develop', '-c', *command)
+  sh(bundler_env, 'nix', 'develop', '-c', *command)
 end
 
 namespace :bench do
data/VERSION
CHANGED

@@ -1 +1 @@
-0.0.6
+0.0.8
data/ext/gte/Cargo.toml
CHANGED

@@ -1,6 +1,6 @@
 [package]
 name = "gte"
-version = "0.0.
+version = "0.0.8"
 edition = "2021"
 authors = ["elcuervo <elcuervo@elcuervo.net>"]
 license = "MIT"
@@ -21,10 +21,10 @@ ruby-ffi = ["dep:magnus", "dep:rb-sys"]
 [dependencies]
 rb-sys = { version = "0.9", features = ["stable-api-compiled-fallback"], optional = true }
 magnus = { version = "0.8", optional = true }
-ort = { version = "=2.0.0-rc.
-ort-sys = "=2.0.0-rc.
+ort = { version = "=2.0.0-rc.12", features = ["ndarray", "xnnpack"] }
+ort-sys = "=2.0.0-rc.12"
 tokenizers = "0.21.0"
-ndarray = "0.
+ndarray = "0.17"
 half = "2"
 serde = { version = "1", features = ["derive"] }
 serde_json = "1"
data/ext/gte/src/embedder.rs
CHANGED

@@ -1,19 +1,18 @@
 use crate::error::{GteError, Result};
-use crate::model_config::{ExtractorMode, ModelConfig};
+use crate::model_config::{ExtractorMode, ModelConfig, ModelLoadOverrides, PaddingMode};
 use crate::model_profile::{
-    has_input, infer_extraction_mode,
-    resolve_tokenizer_path, select_output_tensor, validate_supported_text_inputs,
+    has_input, infer_extraction_mode, read_tokenizer_profile, resolve_default_text_model,
+    resolve_named_model, resolve_tokenizer_path, select_output_tensor, validate_supported_text_inputs,
 };
 use crate::postprocess::normalize_l2 as normalize_l2_rows;
-use crate::session::{build_session, run_session};
-use crate::tokenizer::{Tokenized, Tokenizer};
+use crate::session::{build_session, run_session, SessionPool};
+use crate::tokenizer::{parse_padding_mode_override, Tokenized, Tokenizer};
 use ndarray::Array2;
-use
-use std::path::Path;
+use std::path::{Path, PathBuf};
 
 pub struct Embedder {
     tokenizer: Tokenizer,
-
+    pool: SessionPool,
     config: ModelConfig,
 }
 
@@ -23,23 +22,24 @@ impl Embedder {
         P1: AsRef<Path>,
         P2: AsRef<Path>,
     {
-        let tokenizer = Tokenizer::new(
-
-
-
-
-
-
+        let tokenizer = Tokenizer::new(
+            tokenizer_path,
+            config.max_length,
+            config.with_type_ids,
+            config.padding_mode,
+            None,
+        )?;
+        let model_path = model_path.as_ref().to_path_buf();
+        let session = build_session(&model_path, &config)?;
+        let pool = SessionPool::new(session, model_path, config.clone());
+        Ok(Self { tokenizer, pool, config })
     }
 
     pub fn from_dir<P: AsRef<Path>>(
         dir: P,
         num_threads: usize,
         optimization_level: u8,
-
-        output_tensor_override: Option<&str>,
-        max_length_override: Option<usize>,
-        execution_providers_override: Option<&str>,
+        overrides: ModelLoadOverrides<'_>,
     ) -> Result<Self> {
         const PREFERRED_EMBEDDING_OUTPUTS: [&str; 4] = [
             "pooler_output",
@@ -50,31 +50,35 @@ impl Embedder {
 
         let dir = dir.as_ref();
         let tokenizer_path = resolve_tokenizer_path(dir)?;
-        let model_path = match model_name.filter(|s| !s.is_empty()) {
+        let model_path: PathBuf = match overrides.model_name.filter(|s| !s.is_empty()) {
             Some(name) => resolve_named_model(dir, name)?,
             None => resolve_default_text_model(dir)?,
         };
 
-        let
+        let tokenizer_profile = read_tokenizer_profile(dir);
+        let max_length = if let Some(override_value) = overrides.max_length {
             if override_value == 0 {
                 return Err(GteError::Inference(
                     "max_length override must be greater than 0".to_string(),
                 ));
             }
-            override_value
+            override_value.min(tokenizer_profile.safe_max_length)
         } else {
-
+            tokenizer_profile.default_max_length
        };
+        let padding_mode =
+            parse_padding_mode_override(overrides.padding)?.unwrap_or(PaddingMode::Auto);
 
         let session_config = ModelConfig {
             max_length,
+            padding_mode,
             output_tensor: String::new(),
             mode: ExtractorMode::Raw,
             with_type_ids: false,
             with_attention_mask: true,
             num_threads,
             optimization_level,
-            execution_providers:
+            execution_providers: overrides.execution_providers.map(str::to_string),
         };
         let session = build_session(&model_path, &session_config)?;
 
@@ -82,7 +86,7 @@ impl Embedder {
         let with_type_ids = has_input(&session, "token_type_ids");
         let with_attention_mask = has_input(&session, "attention_mask");
         let output_tensor =
-            select_output_tensor(&session,
+            select_output_tensor(&session, overrides.output_tensor, &PREFERRED_EMBEDDING_OUTPUTS)?;
         let mode = infer_extraction_mode(&session, output_tensor.as_str())?;
         if matches!(mode, ExtractorMode::MeanPool) && !with_attention_mask {
             return Err(GteError::Inference(
@@ -92,22 +96,26 @@ impl Embedder {
 
         let config = ModelConfig {
             max_length,
+            padding_mode,
             output_tensor,
             mode,
             with_type_ids,
             with_attention_mask,
             num_threads,
             optimization_level,
-            execution_providers:
+            execution_providers: overrides.execution_providers.map(str::to_string),
         };
 
-        let tokenizer = Tokenizer::new(
+        let tokenizer = Tokenizer::new(
+            &tokenizer_path,
+            config.max_length,
+            config.with_type_ids,
+            config.padding_mode,
+            tokenizer_profile.fixed_padding_length,
+        )?;
 
-
-
-            session,
-            config,
-        })
+        let pool = SessionPool::new(session, model_path, session_config);
+        Ok(Self { tokenizer, pool, config })
     }
 
     pub fn embed(&self, texts: Vec<String>) -> Result<Array2<f32>> {
@@ -120,7 +128,8 @@ impl Embedder {
     }
 
     pub fn run(&self, tokenized: &Tokenized) -> crate::error::Result<Array2<f32>> {
-
+        let mut session = self.pool.acquire()?;
+        run_session(&mut session, tokenized, &self.config)
     }
 }
 
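Note on the change above: `Embedder` now stores a `SessionPool` instead of a single session, and `run` checks a session out per call (`self.pool.acquire()?`) before invoking `run_session(&mut session, ...)`. The pool itself lives in `session.rs` (+279 -15), which is not part of this excerpt, so the following is only a minimal standalone sketch of an acquire/release pool under that assumption; the `Pool` type, its `Mutex<Vec<T>>` storage, and the `String` stand-in for a session are illustrative, not the crate's API.

```rust
use std::sync::Mutex;

// Hypothetical stand-in for the crate's SessionPool (not shown in this diff).
struct Pool<T> {
    idle: Mutex<Vec<T>>,
}

impl<T> Pool<T> {
    fn new(first: T) -> Self {
        Self { idle: Mutex::new(vec![first]) }
    }

    // Hands out an idle item; the real pool presumably rebuilds a session from
    // the stored model path and ModelConfig when none is idle.
    fn acquire(&self) -> Option<T> {
        self.idle.lock().unwrap().pop()
    }

    fn release(&self, item: T) {
        self.idle.lock().unwrap().push(item);
    }
}

fn main() {
    let pool = Pool::new(String::from("session-0"));
    let mut session = pool.acquire().expect("one idle session");
    session.push_str(" (in use)"); // stand-in for run_session(&mut session, ...)
    pool.release(session);
    assert!(pool.acquire().is_some());
}
```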
data/ext/gte/src/model_config.rs
CHANGED

@@ -5,9 +5,18 @@ pub enum ExtractorMode {
     Raw,
 }
 
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
+pub enum PaddingMode {
+    #[default]
+    Auto,
+    BatchLongest,
+    Fixed,
+}
+
 #[derive(Debug, Clone)]
 pub struct ModelConfig {
     pub max_length: usize,
+    pub padding_mode: PaddingMode,
     pub output_tensor: String,
     pub mode: ExtractorMode,
     pub with_type_ids: bool,
@@ -16,3 +25,12 @@ pub struct ModelConfig {
     pub optimization_level: u8,
     pub execution_providers: Option<String>,
 }
+
+#[derive(Debug, Clone, Copy, Default)]
+pub struct ModelLoadOverrides<'a> {
+    pub model_name: Option<&'a str>,
+    pub output_tensor: Option<&'a str>,
+    pub max_length: Option<usize>,
+    pub padding: Option<&'a str>,
+    pub execution_providers: Option<&'a str>,
+}
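The new `padding` override stays a plain string in `ModelLoadOverrides` and only becomes a `PaddingMode` via `parse_padding_mode_override(...)?.unwrap_or(PaddingMode::Auto)` (see embedder.rs above). That parser lives in tokenizer.rs, which is outside this excerpt, so the snippet below is a self-contained sketch of the mapping implied by the README values `auto`, `batch_longest`, and `fixed`; the `String` error type and the exact match arms are assumptions, not the crate's implementation.

```rust
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
enum PaddingMode {
    #[default]
    Auto,
    BatchLongest,
    Fixed,
}

// Sketch of the string -> PaddingMode mapping; the real parser lives in
// tokenizer.rs and returns the crate's error type instead of String.
fn parse_padding_mode_override(raw: Option<&str>) -> Result<Option<PaddingMode>, String> {
    match raw.map(str::trim).filter(|s| !s.is_empty()) {
        None => Ok(None),
        Some("auto") => Ok(Some(PaddingMode::Auto)),
        Some("batch_longest") => Ok(Some(PaddingMode::BatchLongest)),
        Some("fixed") => Ok(Some(PaddingMode::Fixed)),
        Some(other) => Err(format!("unsupported padding mode: {other}")),
    }
}

fn main() {
    let mode = parse_padding_mode_override(Some("batch_longest"))
        .unwrap()
        .unwrap_or(PaddingMode::Auto);
    assert_eq!(mode, PaddingMode::BatchLongest);
}
```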
data/ext/gte/src/model_profile.rs
CHANGED

@@ -1,9 +1,19 @@
 use crate::error::{GteError, Result};
 use crate::model_config::ExtractorMode;
 use ort::session::Session;
+use serde_json::Value;
 use std::path::{Path, PathBuf};
 
 const SUPPORTED_INPUTS: [&str; 3] = ["input_ids", "attention_mask", "token_type_ids"];
+const DEFAULT_MAX_LENGTH: usize = 512;
+const MAX_SUPPORTED_LENGTH: usize = 8192;
+
+#[derive(Debug, Clone, Copy)]
+pub struct TokenizerProfile {
+    pub default_max_length: usize,
+    pub safe_max_length: usize,
+    pub fixed_padding_length: Option<usize>,
+}
 
 pub fn resolve_tokenizer_path(dir: &Path) -> Result<PathBuf> {
     let tokenizer_path = dir.join("tokenizer.json");
@@ -48,27 +58,84 @@ pub fn resolve_default_text_model(dir: &Path) -> Result<PathBuf> {
     )))
 }
 
-pub fn
-
-
-
-
-
-
-
-
-
-    Some(
-
-
+pub fn read_tokenizer_profile(dir: &Path) -> TokenizerProfile {
+    let tokenizer_config = read_json(dir.join("tokenizer_config.json"));
+    let tokenizer_json = read_json(dir.join("tokenizer.json"));
+
+    let fixed_padding_length = tokenizer_json
+        .as_ref()
+        .and_then(parse_fixed_padding_length_from_tokenizer_json);
+
+    let mut candidates = Vec::new();
+    if let Some(config) = tokenizer_config.as_ref() {
+        if let Some(v) = config.get("max_length").and_then(parse_positive_usize) {
+            candidates.push(v.min(MAX_SUPPORTED_LENGTH));
+        }
+        if let Some(v) = config.get("model_max_length").and_then(parse_positive_usize) {
+            candidates.push(v.min(MAX_SUPPORTED_LENGTH));
+        }
+    }
+
+    if let Some(tokenizer) = tokenizer_json.as_ref() {
+        if let Some(v) = tokenizer
+            .get("truncation")
+            .and_then(|truncation| truncation.get("max_length"))
+            .and_then(parse_positive_usize)
+        {
+            candidates.push(v.min(MAX_SUPPORTED_LENGTH));
+        }
+    }
+
+    if let Some(v) = fixed_padding_length {
+        candidates.push(v.min(MAX_SUPPORTED_LENGTH));
+    }
+
+    let default_max_length = candidates
+        .iter()
+        .copied()
+        .min()
+        .unwrap_or(DEFAULT_MAX_LENGTH)
+        .max(1);
+    let safe_max_length = fixed_padding_length.unwrap_or(default_max_length).max(1);
+
+    TokenizerProfile {
+        default_max_length,
+        safe_max_length,
+        fixed_padding_length,
+    }
+}
+
+fn read_json(path: PathBuf) -> Option<Value> {
+    let contents = std::fs::read_to_string(path).ok()?;
+    serde_json::from_str(&contents).ok()
+}
+
+fn parse_positive_usize(value: &Value) -> Option<usize> {
+    let raw = value
+        .as_u64()
+        .or_else(|| {
+            value
+                .as_f64()
+                .filter(|&v| v.is_finite() && v > 0.0)
+                .map(|v| v as u64)
+        })
+        .or_else(|| value.as_str().and_then(|s| s.parse::<u64>().ok()))?;
+    let parsed = usize::try_from(raw).ok()?;
+    (parsed > 0).then_some(parsed)
+}
+
+fn parse_fixed_padding_length_from_tokenizer_json(tokenizer_json: &Value) -> Option<usize> {
+    tokenizer_json
+        .get("padding")
+        .and_then(|padding| padding.get("strategy"))
+        .and_then(|strategy| strategy.get("Fixed"))
+        .and_then(parse_positive_usize)
 }
 
 pub fn validate_supported_text_inputs(session: &Session, api_label: &str) -> Result<()> {
-    let unsupported: Vec<String> = session
-        .
-        .
-        .filter(|i| !SUPPORTED_INPUTS.contains(&i.name.as_str()))
-        .map(|i| i.name.clone())
+    let unsupported: Vec<String> = session.inputs().iter()
+        .filter(|i| !SUPPORTED_INPUTS.contains(&i.name()))
+        .map(|i| i.name().to_owned())
         .collect();
 
     if unsupported.is_empty() {
@@ -91,7 +158,7 @@ pub fn validate_supported_text_inputs(session: &Session, api_label: &str) -> Result<()> {
 }
 
 pub fn has_input(session: &Session, name: &str) -> bool {
-    session.inputs.iter().any(|input| input.name == name)
+    session.inputs().iter().any(|input| input.name() == name)
 }
 
 fn output_name_matches(name: &str, preferred: &str) -> bool {
@@ -106,16 +173,16 @@ pub fn select_output_tensor(
 ) -> Result<String> {
     if let Some(requested_name) = requested.map(str::trim).filter(|name| !name.is_empty()) {
         if let Some(output) = session
-            .outputs
+            .outputs()
             .iter()
-            .find(|o| output_name_matches(o.name
+            .find(|o| output_name_matches(o.name(), requested_name))
         {
-            return Ok(output.name.
+            return Ok(output.name().to_owned());
         }
         let available = session
-            .outputs
+            .outputs()
             .iter()
-            .map(|o| o.name
+            .map(|o| o.name())
             .collect::<Vec<_>>()
             .join(", ");
         return Err(GteError::Inference(format!(
@@ -126,18 +193,18 @@ pub fn select_output_tensor(
 
     for preferred in preferred_outputs {
         if let Some(output) = session
-            .outputs
+            .outputs()
             .iter()
-            .find(|o| output_name_matches(o.name
+            .find(|o| output_name_matches(o.name(), preferred))
         {
-            return Ok(output.name.
+            return Ok(output.name().to_owned());
         }
     }
 
     session
-        .outputs
+        .outputs()
         .first()
-        .map(|o| o.name.
+        .map(|o| o.name().to_owned())
         .ok_or_else(|| GteError::Inference("model has no outputs".into()))
 }
 
@@ -147,9 +214,9 @@ fn output_basename(name: &str) -> &str {
 
 pub fn infer_extraction_mode(session: &Session, output_tensor: &str) -> Result<ExtractorMode> {
     let output = session
-        .outputs
+        .outputs()
         .iter()
-        .find(|o| o.name == output_tensor)
+        .find(|o| o.name() == output_tensor)
         .ok_or_else(|| {
             GteError::Inference(format!(
                 "output tensor '{}' not found in model outputs",
@@ -157,8 +224,8 @@ pub fn infer_extraction_mode(session: &Session, output_tensor: &str) -> Result<ExtractorMode> {
             ))
         })?;
 
-    let ndims = match
-        ort::value::ValueType::Tensor {
+    let ndims = match output.dtype() {
+        ort::value::ValueType::Tensor { shape, .. } => shape.len(),
         other => {
             return Err(GteError::Inference(format!(
                 "output is not a tensor: {:?}",
@@ -177,3 +244,32 @@ pub fn infer_extraction_mode(session: &Session, output_tensor: &str) -> Result<ExtractorMode> {
         ))),
     }
 }
+
+#[cfg(test)]
+mod tests {
+    use super::{parse_fixed_padding_length_from_tokenizer_json, parse_positive_usize};
+    use serde_json::json;
+
+    #[test]
+    fn parse_positive_usize_handles_integer_float_and_string() {
+        assert_eq!(parse_positive_usize(&json!(64)), Some(64));
+        assert_eq!(parse_positive_usize(&json!(64.0)), Some(64));
+        assert_eq!(parse_positive_usize(&json!("64")), Some(64));
+        assert_eq!(parse_positive_usize(&json!(0)), None);
+    }
+
+    #[test]
+    fn parse_fixed_padding_length_reads_fixed_padding_strategy() {
+        let tokenizer_json = json!({
+            "padding": {
+                "strategy": {
+                    "Fixed": 64
+                }
+            }
+        });
+        assert_eq!(
+            parse_fixed_padding_length_from_tokenizer_json(&tokenizer_json),
+            Some(64)
+        );
+    }
+}
data/ext/gte/src/pipeline.rs
CHANGED

@@ -1,8 +1,8 @@
 use crate::error::{GteError, Result};
 use crate::tokenizer::Tokenized;
-use ndarray::ArrayView2;
+use ndarray::{ArrayView2, ArrayViewD};
 use ort::session::SessionInputValue;
-use ort::value::
+use ort::value::TensorRef;
 
 pub struct InputTensors<'a> {
     pub inputs: Vec<(&'static str, SessionInputValue<'a>)>,
@@ -23,13 +23,13 @@ impl<'a> InputTensors<'a> {
         let mut inputs = Vec::with_capacity(2 + usize::from(tokenized.type_ids.is_some()));
         inputs.push((
             "input_ids",
-            SessionInputValue::from(
+            SessionInputValue::from(TensorRef::from_array_view(input_ids_view)?),
         ));
 
         if with_attention_mask {
             inputs.push((
                 "attention_mask",
-                SessionInputValue::from(
+                SessionInputValue::from(TensorRef::from_array_view(attention_mask)?),
             ));
         }
 
@@ -38,7 +38,7 @@ impl<'a> InputTensors<'a> {
                 ArrayView2::from_shape((tokenized.rows, tokenized.cols), type_ids)?;
             inputs.push((
                 "token_type_ids",
-                SessionInputValue::from(
+                SessionInputValue::from(TensorRef::from_array_view(type_ids_view)?),
             ));
         }
 
@@ -50,11 +50,14 @@ impl<'a> InputTensors<'a> {
 }
 
 pub fn extract_output_tensor<'a>(
-    outputs: &'a ort::session::SessionOutputs<'
+    outputs: &'a ort::session::SessionOutputs<'_>,
     output_name: &str,
-) -> Result<
+) -> Result<ArrayViewD<'a, f32>> {
     let tensor_value = outputs.get(output_name).ok_or_else(|| {
-        GteError::Inference(format!(
+        GteError::Inference(format!(
+            "output tensor '{}' not found in model outputs",
+            output_name
+        ))
     })?;
-    Ok(tensor_value.
+    Ok(tensor_value.try_extract_array::<f32>()?)
 }