gte 0.0.12 → 0.0.13
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/README.md +37 -0
- data/VERSION +1 -1
- data/ext/gte/Cargo.toml +1 -1
- data/ext/gte/src/session.rs +5 -2
- metadata +1 -1
checksums.yaml
CHANGED
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
---
|
|
2
2
|
SHA256:
|
|
3
|
-
metadata.gz:
|
|
4
|
-
data.tar.gz:
|
|
3
|
+
metadata.gz: 278028df09fbcdd14fd583f0af5e1a8c9553adb28fe7aa0bc67b67666dbbdccd
|
|
4
|
+
data.tar.gz: ce994e3f505200ed4654ca8f87f585ff88919201fe82dd79007622f07a3d1ea0
|
|
5
5
|
SHA512:
|
|
6
|
-
metadata.gz:
|
|
7
|
-
data.tar.gz:
|
|
6
|
+
metadata.gz: 742f1830ff2b83f89726be527c4323a81649b04f341b7adc0544a9000373f6a097c0b4b4ba211ead5912ba45d876565fbaab6d723ef8f06c488ab7827323f827
|
|
7
|
+
data.tar.gz: 75e91b3d4c3980b166268c6468b96bebe4b74db999e0cee433a295e57d89bec95c7614b004c61e8b3ed88cff30f02f3b6aff74de710d3dd3bb34552f36fb3422
|
data/README.md
CHANGED
|
@@ -59,6 +59,43 @@ Notes:
|
|
|
59
59
|
- Return a `Config::Text` from the block (for example, `config.with(...)`).
|
|
60
60
|
- Model instances are cached by full config key; different config values create different cached instances.
|
|
61
61
|
|
|
62
|
+
Common model presets:
|
|
63
|
+
|
|
64
|
+
```ruby
|
|
65
|
+
e5 = GTE.config(ENV.fetch("GTE_MODEL_DIR")) do |config|
|
|
66
|
+
config.with(
|
|
67
|
+
model_name: "model.onnx",
|
|
68
|
+
output_tensor: "last_hidden_state",
|
|
69
|
+
max_length: 512,
|
|
70
|
+
execution_providers: "cpu"
|
|
71
|
+
)
|
|
72
|
+
end
|
|
73
|
+
|
|
74
|
+
siglip2 = GTE.config(ENV.fetch("GTE_SIGLIP2_DIR")) do |config|
|
|
75
|
+
config.with(
|
|
76
|
+
model_name: "text_model_int8.onnx",
|
|
77
|
+
output_tensor: "pooler_output",
|
|
78
|
+
max_length: 64,
|
|
79
|
+
execution_providers: "cpu"
|
|
80
|
+
)
|
|
81
|
+
end
|
|
82
|
+
|
|
83
|
+
clip = GTE.config(ENV.fetch("GTE_CLIP_DIR")) do |config|
|
|
84
|
+
config.with(
|
|
85
|
+
output_tensor: "sentence_embedding",
|
|
86
|
+
max_length: 512,
|
|
87
|
+
execution_providers: "cpu"
|
|
88
|
+
)
|
|
89
|
+
end
|
|
90
|
+
```
|
|
91
|
+
|
|
92
|
+
Picking a specific output tensor:
|
|
93
|
+
|
|
94
|
+
- Use `output_tensor:` to request a named model output.
|
|
95
|
+
- `last_hidden_state` gives token-level hidden states and is mean-pooled by `gte` when the tensor is rank 3.
|
|
96
|
+
- `pooler_output`, `sentence_embedding`, and similar 2D tensors are used as-is (no pooling is applied) and are then L2-normalized by default.
|
|
97
|
+
- If the requested tensor is not present in the model, `gte` raises an error instead of silently falling back.
|
|
98
|
+
|
|
62
99
|
Low-level embedder setup (without model cache):
|
|
63
100
|
|
|
64
101
|
```ruby
|
data/VERSION
CHANGED
|
@@ -1 +1 @@
|
|
|
1
|
-
0.0.12
|
|
1
|
+
0.0.13
|
data/ext/gte/Cargo.toml
CHANGED
data/ext/gte/src/session.rs
CHANGED
|
@@ -7,7 +7,7 @@ use ndarray::{Array2, ArrayView2, ArrayViewD, Ix2};
|
|
|
7
7
|
use ort::execution_providers::{
|
|
8
8
|
CoreMLExecutionProvider, ExecutionProviderDispatch, XNNPACKExecutionProvider,
|
|
9
9
|
};
|
|
10
|
-
use ort::session::Session;
|
|
10
|
+
use ort::session::{OutputSelector, RunOptions, Session};
|
|
11
11
|
use std::path::{Path, PathBuf};
|
|
12
12
|
use std::sync::atomic::{AtomicUsize, Ordering};
|
|
13
13
|
use std::sync::{Condvar, Mutex};
|
|
@@ -216,8 +216,11 @@ pub fn run_session(
|
|
|
216
216
|
config: &ModelConfig,
|
|
217
217
|
) -> Result<Array2<f32>> {
|
|
218
218
|
let input_tensors = InputTensors::from_tokenized(tokenized, config.with_attention_mask)?;
|
|
219
|
+
let run_opts = RunOptions::new()
|
|
220
|
+
.map_err(|e| GteError::Ort(e.to_string()))?
|
|
221
|
+
.with_outputs(OutputSelector::no_default().with(config.output_tensor.as_str()));
|
|
219
222
|
let outputs = session
|
|
220
|
-
.
|
|
223
|
+
.run_with_options(input_tensors.inputs, &run_opts)
|
|
221
224
|
.map_err(|e| GteError::Ort(e.to_string()))?;
|
|
222
225
|
let array = extract_output_tensor(&outputs, config.output_tensor.as_str())?;
|
|
223
226
|
|