gte 0.0.5 → 0.0.6
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/README.md +34 -3
- data/Rakefile +2 -2
- data/VERSION +1 -1
- data/ext/gte/Cargo.toml +1 -1
- data/ext/gte/src/embedder.rs +3 -0
- data/ext/gte/src/model_config.rs +1 -0
- data/ext/gte/src/reranker.rs +2 -0
- data/ext/gte/src/ruby_embedder.rs +16 -2
- data/ext/gte/src/session.rs +67 -7
- data/ext/gte/tests/inference_integration_test.rs +8 -8
- data/lib/gte/config.rb +2 -2
- data/lib/gte/embedder.rb +41 -0
- data/lib/gte/model.rb +1 -9
- data/lib/gte/reranker.rb +4 -2
- data/lib/gte.rb +3 -1
- metadata +3 -2
checksums.yaml
CHANGED
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
---
|
|
2
2
|
SHA256:
|
|
3
|
-
metadata.gz:
|
|
4
|
-
data.tar.gz:
|
|
3
|
+
metadata.gz: fc149108c647dc5b14154bfbdc4975b53670b9ed3cf7d80760cc2b415c935a48
|
|
4
|
+
data.tar.gz: 32a682a95d56c8fab8d0d64a7ada0c0347ae796b6aefe6191f9aca8fc96426c2
|
|
5
5
|
SHA512:
|
|
6
|
-
metadata.gz:
|
|
7
|
-
data.tar.gz:
|
|
6
|
+
metadata.gz: f5c69d954f51a51521b143b576942a9c0505ad60574c1727f963dd79e0b6c22cacc4e6d9af75394ae06f451521dbc788af51f1e79397a5cc66a41b4ce1b31933
|
|
7
|
+
data.tar.gz: 9e75fdbc9b5c8cfdd9d0e377a7e4a944057ec604e38ab23d960c4ed75ec6a72ce1dd27c2dd1bb2802721387babdabe0996e0c42be34d17d98253e0582b375de1
|
data/README.md
CHANGED
|
@@ -55,12 +55,21 @@ Config fields and defaults:
|
|
|
55
55
|
- `normalize`: `true` (L2 normalization at Ruby-facing API)
|
|
56
56
|
- `output_tensor`: `nil` (auto-select output tensor)
|
|
57
57
|
- `max_length`: `nil` (uses tokenizer/model defaults)
|
|
58
|
+
- `execution_providers`: `nil` (falls back to `GTE_EXECUTION_PROVIDERS` / CPU default)
|
|
58
59
|
|
|
59
60
|
Notes:
|
|
60
61
|
|
|
61
62
|
- Return a `Config::Text` from the block (for example, `config.with(...)`).
|
|
62
63
|
- Model instances are cached by full config key; different config values create different cached instances.
|
|
63
64
|
|
|
65
|
+
Low-level embedder setup (without model cache):
|
|
66
|
+
|
|
67
|
+
```ruby
|
|
68
|
+
embedder = GTE::Embedder.config(ENV.fetch("GTE_MODEL_DIR")) do |config|
|
|
69
|
+
config.with(threads: 0, execution_providers: "cpu")
|
|
70
|
+
end
|
|
71
|
+
```
|
|
72
|
+
|
|
64
73
|
## Reranker
|
|
65
74
|
|
|
66
75
|
Use `GTE::Reranker.config(model_dir)` for cross-encoder reranking.
|
|
@@ -97,6 +106,7 @@ Reranker config fields and defaults:
|
|
|
97
106
|
- `sigmoid`: `false` (set `true` if you want bounded [0,1] style scores)
|
|
98
107
|
- `output_tensor`: `nil`
|
|
99
108
|
- `max_length`: `nil`
|
|
109
|
+
- `execution_providers`: `nil`
|
|
100
110
|
|
|
101
111
|
## Runtime + Result Examples
|
|
102
112
|
|
|
@@ -123,14 +133,30 @@ Input policy is text-only. Graphs requiring unsupported multimodal inputs (such
|
|
|
123
133
|
|
|
124
134
|
## Execution Providers
|
|
125
135
|
|
|
126
|
-
Default
|
|
136
|
+
Default behavior is CPU fallback via ONNX Runtime's default provider (no explicit provider registration).
|
|
137
|
+
|
|
138
|
+
Configure providers with `GTE_EXECUTION_PROVIDERS` (comma-separated, case-insensitive).
|
|
139
|
+
Supported values:
|
|
140
|
+
|
|
141
|
+
- `cpu` or `none`: CPU fallback (skip explicit provider registration)
|
|
142
|
+
- `xnnpack`
|
|
143
|
+
- `coreml`
|
|
127
144
|
|
|
128
|
-
|
|
145
|
+
Examples:
|
|
129
146
|
|
|
130
147
|
```bash
|
|
148
|
+
export GTE_EXECUTION_PROVIDERS=cpu
|
|
131
149
|
export GTE_EXECUTION_PROVIDERS=xnnpack,coreml
|
|
132
150
|
```
|
|
133
151
|
|
|
152
|
+
Ruby per-instance override (takes precedence over `GTE_EXECUTION_PROVIDERS`):
|
|
153
|
+
|
|
154
|
+
```ruby
|
|
155
|
+
model = GTE.config(ENV.fetch("GTE_MODEL_DIR")) do |config|
|
|
156
|
+
config.with(execution_providers: "cpu")
|
|
157
|
+
end
|
|
158
|
+
```
|
|
159
|
+
|
|
134
160
|
## Development
|
|
135
161
|
|
|
136
162
|
Run commands inside `nix develop` via Make targets:
|
|
@@ -154,8 +180,13 @@ nix develop -c bundle exec rake bench:matrix_sweep
|
|
|
154
180
|
nix develop -c bundle exec ruby bench/memory_probe.rb --compare-pure
|
|
155
181
|
```
|
|
156
182
|
|
|
157
|
-
|
|
183
|
+
To run benchmark + append a `RUNS.md` entry + enforce goal checks:
|
|
158
184
|
|
|
159
185
|
```bash
|
|
160
186
|
make bench-record
|
|
161
187
|
```
|
|
188
|
+
|
|
189
|
+
`bench/runs_ledger.rb check` is goal-focused by default:
|
|
190
|
+
|
|
191
|
+
- Enforces goal metric (`response_time_p95` ratio threshold).
|
|
192
|
+
- Does not require current-version coverage in `RUNS.md` unless explicitly enabled.
|
data/Rakefile
CHANGED
|
@@ -56,7 +56,7 @@ namespace :bench do
|
|
|
56
56
|
)
|
|
57
57
|
end
|
|
58
58
|
|
|
59
|
-
desc 'Run Puma benchmark, append RUNS.md entry, and enforce goal
|
|
59
|
+
desc 'Run Puma benchmark, append RUNS.md entry, and enforce goal checks'
|
|
60
60
|
task :record_run do
|
|
61
61
|
run_in_nix(
|
|
62
62
|
'bundle', 'exec', 'ruby', 'bench/puma_compare.rb',
|
|
@@ -74,7 +74,7 @@ namespace :bench do
|
|
|
74
74
|
)
|
|
75
75
|
end
|
|
76
76
|
|
|
77
|
-
desc 'Validate current Puma benchmark output against 2x goal
|
|
77
|
+
desc 'Validate current Puma benchmark output against 2x goal only'
|
|
78
78
|
task :check_goal do
|
|
79
79
|
run_in_nix(
|
|
80
80
|
'bundle', 'exec', 'ruby', 'bench/runs_ledger.rb', 'check',
|
data/VERSION
CHANGED
|
@@ -1 +1 @@
|
|
|
1
|
-
0.0.5
|
|
1
|
+
0.0.6
|
data/ext/gte/Cargo.toml
CHANGED
data/ext/gte/src/embedder.rs
CHANGED
|
@@ -39,6 +39,7 @@ impl Embedder {
|
|
|
39
39
|
model_name: Option<&str>,
|
|
40
40
|
output_tensor_override: Option<&str>,
|
|
41
41
|
max_length_override: Option<usize>,
|
|
42
|
+
execution_providers_override: Option<&str>,
|
|
42
43
|
) -> Result<Self> {
|
|
43
44
|
const PREFERRED_EMBEDDING_OUTPUTS: [&str; 4] = [
|
|
44
45
|
"pooler_output",
|
|
@@ -73,6 +74,7 @@ impl Embedder {
|
|
|
73
74
|
with_attention_mask: true,
|
|
74
75
|
num_threads,
|
|
75
76
|
optimization_level,
|
|
77
|
+
execution_providers: execution_providers_override.map(str::to_string),
|
|
76
78
|
};
|
|
77
79
|
let session = build_session(&model_path, &session_config)?;
|
|
78
80
|
|
|
@@ -96,6 +98,7 @@ impl Embedder {
|
|
|
96
98
|
with_attention_mask,
|
|
97
99
|
num_threads,
|
|
98
100
|
optimization_level,
|
|
101
|
+
execution_providers: execution_providers_override.map(str::to_string),
|
|
99
102
|
};
|
|
100
103
|
|
|
101
104
|
let tokenizer = Tokenizer::new(&tokenizer_path, config.max_length, config.with_type_ids)?;
|
data/ext/gte/src/model_config.rs
CHANGED
data/ext/gte/src/reranker.rs
CHANGED
|
@@ -33,6 +33,7 @@ impl Reranker {
|
|
|
33
33
|
model_name: Option<&str>,
|
|
34
34
|
output_tensor_override: Option<&str>,
|
|
35
35
|
max_length_override: Option<usize>,
|
|
36
|
+
execution_providers_override: Option<&str>,
|
|
36
37
|
) -> Result<Self> {
|
|
37
38
|
let dir = dir.as_ref();
|
|
38
39
|
let tokenizer_path = resolve_tokenizer_path(dir)?;
|
|
@@ -60,6 +61,7 @@ impl Reranker {
|
|
|
60
61
|
with_attention_mask: true,
|
|
61
62
|
num_threads,
|
|
62
63
|
optimization_level,
|
|
64
|
+
execution_providers: execution_providers_override.map(str::to_string),
|
|
63
65
|
};
|
|
64
66
|
let session = build_session(&model_path, &probe_config)?;
|
|
65
67
|
|
|
data/ext/gte/src/ruby_embedder.rs
CHANGED
|
@@ -171,6 +171,7 @@ impl RbEmbedder {
|
|
|
171
171
|
normalize: bool,
|
|
172
172
|
output_tensor: String,
|
|
173
173
|
max_length: usize,
|
|
174
|
+
execution_providers: String,
|
|
174
175
|
) -> Result<Self, Error> {
|
|
175
176
|
let name = if model_name.is_empty() {
|
|
176
177
|
None
|
|
@@ -187,6 +188,11 @@ impl RbEmbedder {
|
|
|
187
188
|
} else {
|
|
188
189
|
Some(max_length)
|
|
189
190
|
};
|
|
191
|
+
let execution_providers_override = if execution_providers.is_empty() {
|
|
192
|
+
None
|
|
193
|
+
} else {
|
|
194
|
+
Some(execution_providers.as_str())
|
|
195
|
+
};
|
|
190
196
|
let embedder = Embedder::from_dir(
|
|
191
197
|
&dir_path,
|
|
192
198
|
num_threads,
|
|
@@ -194,6 +200,7 @@ impl RbEmbedder {
|
|
|
194
200
|
name,
|
|
195
201
|
output_override,
|
|
196
202
|
max_length_override,
|
|
203
|
+
execution_providers_override,
|
|
197
204
|
)
|
|
198
205
|
.map_err(magnus::Error::from)?;
|
|
199
206
|
Ok(RbEmbedder {
|
|
@@ -224,6 +231,7 @@ impl RbReranker {
|
|
|
224
231
|
sigmoid: bool,
|
|
225
232
|
output_tensor: String,
|
|
226
233
|
max_length: usize,
|
|
234
|
+
execution_providers: String,
|
|
227
235
|
) -> Result<Self, Error> {
|
|
228
236
|
let name = if model_name.is_empty() {
|
|
229
237
|
None
|
|
@@ -240,6 +248,11 @@ impl RbReranker {
|
|
|
240
248
|
} else {
|
|
241
249
|
Some(max_length)
|
|
242
250
|
};
|
|
251
|
+
let execution_providers_override = if execution_providers.is_empty() {
|
|
252
|
+
None
|
|
253
|
+
} else {
|
|
254
|
+
Some(execution_providers.as_str())
|
|
255
|
+
};
|
|
243
256
|
let reranker = Reranker::from_dir(
|
|
244
257
|
&dir_path,
|
|
245
258
|
num_threads,
|
|
@@ -247,6 +260,7 @@ impl RbReranker {
|
|
|
247
260
|
name,
|
|
248
261
|
output_override,
|
|
249
262
|
max_length_override,
|
|
263
|
+
execution_providers_override,
|
|
250
264
|
)
|
|
251
265
|
.map_err(magnus::Error::from)?;
|
|
252
266
|
Ok(RbReranker {
|
|
@@ -362,12 +376,12 @@ impl RbTensor {
|
|
|
362
376
|
pub fn register(ruby: &Ruby) -> Result<(), Error> {
|
|
363
377
|
let module = ruby.define_module("GTE")?;
|
|
364
378
|
let embedder_class = module.define_class("Embedder", ruby.class_object())?;
|
|
365
|
-
embedder_class.define_singleton_method("new", function!(RbEmbedder::rb_new, 7))?;
|
|
379
|
+
embedder_class.define_singleton_method("new", function!(RbEmbedder::rb_new, 8))?;
|
|
366
380
|
embedder_class.define_method("embed", method!(RbEmbedder::rb_embed, 1))?;
|
|
367
381
|
embedder_class.define_method("embed_one", method!(RbEmbedder::rb_embed_one, 1))?;
|
|
368
382
|
|
|
369
383
|
let reranker_class = module.define_class("Reranker", ruby.class_object())?;
|
|
370
|
-
reranker_class.define_singleton_method("new", function!(RbReranker::rb_new, 7))?;
|
|
384
|
+
reranker_class.define_singleton_method("new", function!(RbReranker::rb_new, 8))?;
|
|
371
385
|
reranker_class.define_method("score", method!(RbReranker::rb_score, 2))?;
|
|
372
386
|
|
|
373
387
|
let tensor_class = module.define_class("Tensor", ruby.class_object())?;
|
data/ext/gte/src/session.rs
CHANGED
|
@@ -22,7 +22,7 @@ pub fn build_session<P: AsRef<Path>>(model_path: P, config: &ModelConfig) -> Res
|
|
|
22
22
|
.with_optimization_level(opt_level)?
|
|
23
23
|
.with_memory_pattern(true)?;
|
|
24
24
|
|
|
25
|
-
let providers = preferred_execution_providers();
|
|
25
|
+
let providers = preferred_execution_providers(config.execution_providers.as_deref());
|
|
26
26
|
if !providers.is_empty() {
|
|
27
27
|
builder = builder.with_execution_providers(providers)?;
|
|
28
28
|
}
|
|
@@ -34,19 +34,40 @@ pub fn build_session<P: AsRef<Path>>(model_path: P, config: &ModelConfig) -> Res
|
|
|
34
34
|
Ok(builder.commit_from_file(model_path)?)
|
|
35
35
|
}
|
|
36
36
|
|
|
37
|
-
fn preferred_execution_providers() -> Vec<ExecutionProviderDispatch> {
|
|
38
|
-
let order =
|
|
39
|
-
.unwrap_or_else(|_| "xnnpack".to_string())
|
|
40
|
-
.to_ascii_lowercase();
|
|
37
|
+
fn preferred_execution_providers(order_override: Option<&str>) -> Vec<ExecutionProviderDispatch> {
|
|
38
|
+
let order = resolve_provider_order(order_override);
|
|
41
39
|
|
|
42
40
|
let mut providers = Vec::new();
|
|
43
|
-
for provider in order.
|
|
41
|
+
for provider in parse_provider_registrations(order.as_str()) {
|
|
44
42
|
match provider {
|
|
45
43
|
"xnnpack" => {
|
|
46
44
|
providers.push(XNNPACKExecutionProvider::default().build().fail_silently())
|
|
47
45
|
}
|
|
48
46
|
"coreml" => providers.push(CoreMLExecutionProvider::default().build().fail_silently()),
|
|
49
|
-
|
|
47
|
+
_ => {}
|
|
48
|
+
}
|
|
49
|
+
}
|
|
50
|
+
providers
|
|
51
|
+
}
|
|
52
|
+
|
|
53
|
+
fn resolve_provider_order(order_override: Option<&str>) -> String {
|
|
54
|
+
let env_order = std::env::var("GTE_EXECUTION_PROVIDERS").ok();
|
|
55
|
+
resolve_provider_order_with_env(order_override, env_order.as_deref())
|
|
56
|
+
}
|
|
57
|
+
|
|
58
|
+
fn resolve_provider_order_with_env(order_override: Option<&str>, env_order: Option<&str>) -> String {
|
|
59
|
+
order_override
|
|
60
|
+
.or(env_order)
|
|
61
|
+
.unwrap_or("cpu")
|
|
62
|
+
.to_ascii_lowercase()
|
|
63
|
+
}
|
|
64
|
+
|
|
65
|
+
fn parse_provider_registrations(order: &str) -> Vec<&str> {
|
|
66
|
+
let mut providers = Vec::new();
|
|
67
|
+
for provider in order.split(',').map(str::trim).filter(|p| !p.is_empty()) {
|
|
68
|
+
match provider {
|
|
69
|
+
"xnnpack" | "coreml" => providers.push(provider),
|
|
70
|
+
"none" | "cpu" => {}
|
|
50
71
|
_ => {}
|
|
51
72
|
}
|
|
52
73
|
}
|
|
@@ -86,3 +107,42 @@ pub fn run_session(
|
|
|
86
107
|
ExtractorMode::Raw => Ok(array.into_dimensionality::<Ix2>()?.into_owned()),
|
|
87
108
|
}
|
|
88
109
|
}
|
|
110
|
+
|
|
111
|
+
#[cfg(test)]
|
|
112
|
+
mod tests {
|
|
113
|
+
use super::{parse_provider_registrations, resolve_provider_order_with_env};
|
|
114
|
+
|
|
115
|
+
#[test]
|
|
116
|
+
fn parse_provider_registrations_keeps_supported_order() {
|
|
117
|
+
let parsed = parse_provider_registrations("xnnpack,coreml");
|
|
118
|
+
assert_eq!(parsed, vec!["xnnpack", "coreml"]);
|
|
119
|
+
}
|
|
120
|
+
|
|
121
|
+
#[test]
|
|
122
|
+
fn parse_provider_registrations_treats_cpu_and_none_as_fallback() {
|
|
123
|
+
assert!(parse_provider_registrations("cpu").is_empty());
|
|
124
|
+
assert!(parse_provider_registrations("none").is_empty());
|
|
125
|
+
assert!(parse_provider_registrations("none,cpu").is_empty());
|
|
126
|
+
}
|
|
127
|
+
|
|
128
|
+
#[test]
|
|
129
|
+
fn parse_provider_registrations_ignores_unknowns_and_empties() {
|
|
130
|
+
let parsed = parse_provider_registrations(" ,xnnpak,,xnnpack,unknown,coreml,");
|
|
131
|
+
assert_eq!(parsed, vec!["xnnpack", "coreml"]);
|
|
132
|
+
}
|
|
133
|
+
|
|
134
|
+
#[test]
|
|
135
|
+
fn resolve_provider_order_prefers_override() {
|
|
136
|
+
assert_eq!(
|
|
137
|
+
resolve_provider_order_with_env(Some("xnnpack"), Some("coreml")),
|
|
138
|
+
"xnnpack"
|
|
139
|
+
);
|
|
140
|
+
assert_eq!(resolve_provider_order_with_env(Some("CPU"), None), "cpu");
|
|
141
|
+
}
|
|
142
|
+
|
|
143
|
+
#[test]
|
|
144
|
+
fn resolve_provider_order_falls_back_to_env_then_cpu_default() {
|
|
145
|
+
assert_eq!(resolve_provider_order_with_env(None, Some("coreml")), "coreml");
|
|
146
|
+
assert_eq!(resolve_provider_order_with_env(None, None), "cpu");
|
|
147
|
+
}
|
|
148
|
+
}
|
|
data/ext/gte/tests/inference_integration_test.rs
CHANGED
|
@@ -5,8 +5,8 @@ use gte::embedder::Embedder;
|
|
|
5
5
|
fn test_e5_single_embedding_shape() {
|
|
6
6
|
const DIR: &str = concat!(env!("CARGO_MANIFEST_DIR"), "/tests/fixtures/e5");
|
|
7
7
|
|
|
8
|
-
let embedder =
|
|
9
|
-
|
|
8
|
+
let embedder = Embedder::from_dir(DIR, 0, 3, None, None, None, None)
|
|
9
|
+
.expect("embedder should initialize");
|
|
10
10
|
let result = embedder
|
|
11
11
|
.embed(vec!["query: Hello world".to_string()])
|
|
12
12
|
.expect("embed should succeed");
|
|
@@ -20,8 +20,8 @@ fn test_e5_single_embedding_shape() {
|
|
|
20
20
|
fn test_clip_single_embedding_shape() {
|
|
21
21
|
const DIR: &str = concat!(env!("CARGO_MANIFEST_DIR"), "/tests/fixtures/clip");
|
|
22
22
|
|
|
23
|
-
let embedder =
|
|
24
|
-
|
|
23
|
+
let embedder = Embedder::from_dir(DIR, 0, 3, None, None, None, None)
|
|
24
|
+
.expect("embedder should initialize");
|
|
25
25
|
let result = embedder
|
|
26
26
|
.embed(vec!["a photo of a cat".to_string()])
|
|
27
27
|
.expect("embed should succeed");
|
|
@@ -35,8 +35,8 @@ fn test_clip_single_embedding_shape() {
|
|
|
35
35
|
fn test_e5_batch_embedding_shape() {
|
|
36
36
|
const DIR: &str = concat!(env!("CARGO_MANIFEST_DIR"), "/tests/fixtures/e5");
|
|
37
37
|
|
|
38
|
-
let embedder =
|
|
39
|
-
|
|
38
|
+
let embedder = Embedder::from_dir(DIR, 0, 3, None, None, None, None)
|
|
39
|
+
.expect("embedder should initialize");
|
|
40
40
|
let texts = vec![
|
|
41
41
|
"query: first sentence".to_string(),
|
|
42
42
|
"query: second sentence".to_string(),
|
|
@@ -54,8 +54,8 @@ fn test_e5_batch_embedding_shape() {
|
|
|
54
54
|
fn test_e5_long_input_truncation_no_error() {
|
|
55
55
|
const DIR: &str = concat!(env!("CARGO_MANIFEST_DIR"), "/tests/fixtures/e5");
|
|
56
56
|
|
|
57
|
-
let embedder =
|
|
58
|
-
|
|
57
|
+
let embedder = Embedder::from_dir(DIR, 0, 3, None, None, None, None)
|
|
58
|
+
.expect("embedder should initialize");
|
|
59
59
|
let very_long_text = "word ".repeat(1000);
|
|
60
60
|
let result = embedder
|
|
61
61
|
.embed(vec![very_long_text])
|
data/lib/gte/config.rb
CHANGED
|
@@ -4,12 +4,12 @@ module GTE
|
|
|
4
4
|
module Config
|
|
5
5
|
Text = Data.define(
|
|
6
6
|
:model_dir, :threads, :optimization_level,
|
|
7
|
-
:model_name, :normalize, :output_tensor, :max_length
|
|
7
|
+
:model_name, :normalize, :output_tensor, :max_length, :execution_providers
|
|
8
8
|
)
|
|
9
9
|
|
|
10
10
|
Reranker = Data.define(
|
|
11
11
|
:model_dir, :threads, :optimization_level,
|
|
12
|
-
:model_name, :sigmoid, :output_tensor, :max_length
|
|
12
|
+
:model_name, :sigmoid, :output_tensor, :max_length, :execution_providers
|
|
13
13
|
)
|
|
14
14
|
end
|
|
15
15
|
end
|
data/lib/gte/embedder.rb
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
module GTE
|
|
4
|
+
class Embedder
|
|
5
|
+
class << self
|
|
6
|
+
def config(model_dir)
|
|
7
|
+
cfg = default_config(model_dir)
|
|
8
|
+
cfg = yield(cfg) if block_given?
|
|
9
|
+
from_config(cfg)
|
|
10
|
+
end
|
|
11
|
+
|
|
12
|
+
def from_config(config)
|
|
13
|
+
new(
|
|
14
|
+
config.model_dir,
|
|
15
|
+
config.threads,
|
|
16
|
+
config.optimization_level,
|
|
17
|
+
config.model_name.to_s,
|
|
18
|
+
config.normalize,
|
|
19
|
+
config.output_tensor.to_s,
|
|
20
|
+
config.max_length || 0,
|
|
21
|
+
config.execution_providers.to_s
|
|
22
|
+
)
|
|
23
|
+
end
|
|
24
|
+
|
|
25
|
+
private
|
|
26
|
+
|
|
27
|
+
def default_config(model_dir)
|
|
28
|
+
Config::Text.new(
|
|
29
|
+
model_dir: File.expand_path(model_dir),
|
|
30
|
+
threads: 3,
|
|
31
|
+
optimization_level: 3,
|
|
32
|
+
model_name: nil,
|
|
33
|
+
normalize: true,
|
|
34
|
+
output_tensor: nil,
|
|
35
|
+
max_length: nil,
|
|
36
|
+
execution_providers: nil
|
|
37
|
+
)
|
|
38
|
+
end
|
|
39
|
+
end
|
|
40
|
+
end
|
|
41
|
+
end
|
data/lib/gte/model.rb
CHANGED
|
@@ -8,15 +8,7 @@ module GTE
|
|
|
8
8
|
raise ArgumentError, 'config must be a GTE::Config::Text' unless config.is_a?(Config::Text)
|
|
9
9
|
|
|
10
10
|
@config = config
|
|
11
|
-
@embedder = GTE::Embedder.new(
|
|
12
|
-
config.model_dir,
|
|
13
|
-
config.threads,
|
|
14
|
-
config.optimization_level,
|
|
15
|
-
config.model_name.to_s,
|
|
16
|
-
config.normalize,
|
|
17
|
-
config.output_tensor.to_s,
|
|
18
|
-
config.max_length || 0
|
|
19
|
-
)
|
|
11
|
+
@embedder = GTE::Embedder.from_config(config)
|
|
20
12
|
end
|
|
21
13
|
|
|
22
14
|
def embed(texts)
|
data/lib/gte/reranker.rb
CHANGED
|
@@ -24,7 +24,8 @@ module GTE
|
|
|
24
24
|
model_name: nil,
|
|
25
25
|
sigmoid: false,
|
|
26
26
|
output_tensor: nil,
|
|
27
|
-
max_length: nil
|
|
27
|
+
max_length: nil,
|
|
28
|
+
execution_providers: nil
|
|
28
29
|
)
|
|
29
30
|
end
|
|
30
31
|
|
|
@@ -36,7 +37,8 @@ module GTE
|
|
|
36
37
|
cfg.model_name.to_s,
|
|
37
38
|
cfg.sigmoid,
|
|
38
39
|
cfg.output_tensor.to_s,
|
|
39
|
-
cfg.max_length || 0
|
|
40
|
+
cfg.max_length || 0,
|
|
41
|
+
cfg.execution_providers.to_s
|
|
40
42
|
)
|
|
41
43
|
end
|
|
42
44
|
end
|
data/lib/gte.rb
CHANGED
|
@@ -9,6 +9,7 @@ rescue LoadError
|
|
|
9
9
|
end
|
|
10
10
|
|
|
11
11
|
require 'gte/config'
|
|
12
|
+
require 'gte/embedder'
|
|
12
13
|
require 'gte/model'
|
|
13
14
|
require 'gte/reranker'
|
|
14
15
|
|
|
@@ -25,7 +26,8 @@ module GTE
|
|
|
25
26
|
model_name: nil,
|
|
26
27
|
normalize: true,
|
|
27
28
|
output_tensor: nil,
|
|
28
|
-
max_length: nil
|
|
29
|
+
max_length: nil,
|
|
30
|
+
execution_providers: nil
|
|
29
31
|
)
|
|
30
32
|
|
|
31
33
|
cfg = yield(cfg) if block_given?
|
metadata
CHANGED
|
@@ -1,14 +1,14 @@
|
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
|
2
2
|
name: gte
|
|
3
3
|
version: !ruby/object:Gem::Version
|
|
4
|
-
version: 0.0.5
|
|
4
|
+
version: 0.0.6
|
|
5
5
|
platform: ruby
|
|
6
6
|
authors:
|
|
7
7
|
- elcuervo
|
|
8
8
|
autorequire:
|
|
9
9
|
bindir: bin
|
|
10
10
|
cert_chain: []
|
|
11
|
-
date: 2026-04-
|
|
11
|
+
date: 2026-04-16 00:00:00.000000000 Z
|
|
12
12
|
dependencies:
|
|
13
13
|
- !ruby/object:Gem::Dependency
|
|
14
14
|
name: rake
|
|
@@ -114,6 +114,7 @@ files:
|
|
|
114
114
|
- ext/gte/tests/tokenizer_unit_test.rs
|
|
115
115
|
- lib/gte.rb
|
|
116
116
|
- lib/gte/config.rb
|
|
117
|
+
- lib/gte/embedder.rb
|
|
117
118
|
- lib/gte/model.rb
|
|
118
119
|
- lib/gte/reranker.rb
|
|
119
120
|
- lib/gte/version.rb
|