red-candle 1.8.0.pre3-aarch64-linux
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/Cargo.lock +5021 -0
- data/Cargo.toml +6 -0
- data/Gemfile +3 -0
- data/LICENSE +22 -0
- data/README.md +1171 -0
- data/Rakefile +167 -0
- data/bin/console +11 -0
- data/bin/setup +17 -0
- data/ext/candle/Cargo.toml +38 -0
- data/ext/candle/build.rs +117 -0
- data/ext/candle/extconf.rb +79 -0
- data/ext/candle/rustfmt.toml +63 -0
- data/ext/candle/src/gvl.rs +58 -0
- data/ext/candle/src/lib.rs +59 -0
- data/ext/candle/src/llm/constrained_generation_test.rs +395 -0
- data/ext/candle/src/llm/gemma.rs +313 -0
- data/ext/candle/src/llm/generation_config.rs +63 -0
- data/ext/candle/src/llm/glm4.rs +236 -0
- data/ext/candle/src/llm/granite.rs +308 -0
- data/ext/candle/src/llm/granitemoehybrid.rs +315 -0
- data/ext/candle/src/llm/llama.rs +396 -0
- data/ext/candle/src/llm/mistral.rs +309 -0
- data/ext/candle/src/llm/mod.rs +49 -0
- data/ext/candle/src/llm/phi.rs +369 -0
- data/ext/candle/src/llm/quantized_gguf.rs +734 -0
- data/ext/candle/src/llm/qwen.rs +261 -0
- data/ext/candle/src/llm/qwen3.rs +257 -0
- data/ext/candle/src/llm/text_generation.rs +284 -0
- data/ext/candle/src/ruby/device.rs +234 -0
- data/ext/candle/src/ruby/dtype.rs +39 -0
- data/ext/candle/src/ruby/embedding_model.rs +477 -0
- data/ext/candle/src/ruby/errors.rs +16 -0
- data/ext/candle/src/ruby/llm.rs +730 -0
- data/ext/candle/src/ruby/mod.rs +24 -0
- data/ext/candle/src/ruby/ner.rs +444 -0
- data/ext/candle/src/ruby/reranker.rs +488 -0
- data/ext/candle/src/ruby/result.rs +3 -0
- data/ext/candle/src/ruby/structured.rs +92 -0
- data/ext/candle/src/ruby/tensor.rs +731 -0
- data/ext/candle/src/ruby/tokenizer.rs +343 -0
- data/ext/candle/src/ruby/utils.rs +96 -0
- data/ext/candle/src/ruby/vlm.rs +330 -0
- data/ext/candle/src/structured/integration_test.rs +130 -0
- data/ext/candle/src/structured/mod.rs +31 -0
- data/ext/candle/src/structured/schema_processor.rs +215 -0
- data/ext/candle/src/structured/vocabulary_adapter.rs +152 -0
- data/ext/candle/src/structured/vocabulary_adapter_real_test.rs +66 -0
- data/ext/candle/src/structured/vocabulary_adapter_simple_test.rs +70 -0
- data/ext/candle/src/tokenizer/loader.rs +108 -0
- data/ext/candle/src/tokenizer/mod.rs +104 -0
- data/ext/candle/tests/device_tests.rs +43 -0
- data/ext/candle/tests/tensor_tests.rs +162 -0
- data/lib/candle/3.1/candle.so +0 -0
- data/lib/candle/3.2/candle.so +0 -0
- data/lib/candle/3.3/candle.so +0 -0
- data/lib/candle/3.4/candle.so +0 -0
- data/lib/candle/4.0/candle.so +0 -0
- data/lib/candle/agent.rb +68 -0
- data/lib/candle/build_info.rb +67 -0
- data/lib/candle/device_utils.rb +10 -0
- data/lib/candle/embedding_model.rb +75 -0
- data/lib/candle/embedding_model_type.rb +31 -0
- data/lib/candle/llm.rb +595 -0
- data/lib/candle/logger.rb +149 -0
- data/lib/candle/ner.rb +368 -0
- data/lib/candle/reranker.rb +45 -0
- data/lib/candle/tensor.rb +99 -0
- data/lib/candle/tokenizer.rb +139 -0
- data/lib/candle/tool.rb +47 -0
- data/lib/candle/tool_call_parser.rb +57 -0
- data/lib/candle/version.rb +5 -0
- data/lib/candle/vlm.rb +31 -0
- data/lib/candle.rb +29 -0
- data/lib/red-candle.rb +1 -0
- metadata +309 -0
|
@@ -0,0 +1,477 @@
|
|
|
1
|
+
// MKL and Accelerate are handled by candle-core when their features are enabled
|
|
2
|
+
|
|
3
|
+
use crate::ruby::{
|
|
4
|
+
errors::{wrap_candle_err, wrap_hf_err, wrap_std_err},
|
|
5
|
+
};
|
|
6
|
+
use crate::ruby::{Tensor, Device, Result};
|
|
7
|
+
use crate::tokenizer::{TokenizerWrapper, loader::TokenizerLoader};
|
|
8
|
+
use candle_core::{DType as CoreDType, Device as CoreDevice, Module, Tensor as CoreTensor};
|
|
9
|
+
use safetensors::tensor::SafeTensors;
|
|
10
|
+
use candle_nn::VarBuilder;
|
|
11
|
+
use candle_transformers::models::{
|
|
12
|
+
bert::{BertModel as StdBertModel, Config as BertConfig},
|
|
13
|
+
jina_bert::{BertModel as JinaBertModel, Config as JinaConfig},
|
|
14
|
+
distilbert::{DistilBertModel, Config as DistilBertConfig}
|
|
15
|
+
};
|
|
16
|
+
use magnus::{function, method, prelude::*, Error, RModule, RHash, Ruby};
|
|
17
|
+
use std::path::Path;
|
|
18
|
+
use serde_json;
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
/// Ruby-visible handle registered as `Candle::EmbeddingModel`.
///
/// This is a thin newtype: all state (device, loaded model, tokenizer, and the ids
/// they were loaded from) lives in the wrapped [`EmbeddingModelInner`].
#[magnus::wrap(class = "Candle::EmbeddingModel", free_immediately, size)]
pub struct EmbeddingModel(pub EmbeddingModelInner);
|
|
23
|
+
|
|
24
|
+
/// Embedding architectures this binding knows how to load.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum EmbeddingModelType {
    /// Jina BERT.
    JinaBert,
    /// Standard BERT.
    StandardBert,
    /// DistilBERT.
    DistilBert,
    /// MiniLM.
    MiniLM,
}

impl EmbeddingModelType {
    /// Parses a case-insensitive architecture name.
    ///
    /// Accepted spellings:
    /// - `jina_bert` / `jinabert` / `jina` → `JinaBert`
    /// - `bert` / `standard_bert` / `standardbert` → `StandardBert`
    /// - `minilm` → `MiniLM`
    /// - `distilbert` → `DistilBert`
    ///
    /// Returns `None` for any other name.
    pub fn from_string(model_type: &str) -> Option<Self> {
        let lowered = model_type.to_lowercase();
        let parsed = match lowered.as_str() {
            "jina_bert" | "jinabert" | "jina" => Self::JinaBert,
            "bert" | "standard_bert" | "standardbert" => Self::StandardBert,
            "minilm" => Self::MiniLM,
            "distilbert" => Self::DistilBert,
            _ => return None,
        };
        Some(parsed)
    }
}
|
|
45
|
+
|
|
46
|
+
/// Model variants that can produce embeddings
|
|
47
|
+
pub enum EmbeddingModelVariant {
|
|
48
|
+
JinaBert(JinaBertModel),
|
|
49
|
+
StandardBert(StdBertModel),
|
|
50
|
+
DistilBert(DistilBertModel),
|
|
51
|
+
MiniLM(StdBertModel),
|
|
52
|
+
|
|
53
|
+
}
|
|
54
|
+
|
|
55
|
+
impl EmbeddingModelVariant {
|
|
56
|
+
pub fn model_type(&self) -> EmbeddingModelType {
|
|
57
|
+
match self {
|
|
58
|
+
EmbeddingModelVariant::JinaBert(_) => EmbeddingModelType::JinaBert,
|
|
59
|
+
EmbeddingModelVariant::StandardBert(_) => EmbeddingModelType::StandardBert,
|
|
60
|
+
EmbeddingModelVariant::DistilBert(_) => EmbeddingModelType::DistilBert,
|
|
61
|
+
EmbeddingModelVariant::MiniLM(_) => EmbeddingModelType::MiniLM,
|
|
62
|
+
|
|
63
|
+
}
|
|
64
|
+
}
|
|
65
|
+
}
|
|
66
|
+
|
|
67
|
+
/// State behind a `Candle::EmbeddingModel`.
pub struct EmbeddingModelInner {
    /// Device used for the model weights and for the input token tensors.
    device: CoreDevice,
    /// Hub repo id the tokenizer was requested from, if any.
    tokenizer_id: Option<String>,
    /// Hub repo id the model weights were requested from, if any.
    model_id: Option<String>,
    /// Architecture chosen at construction; `EmbeddingModel::new` always stores `Some`.
    model_type: Option<EmbeddingModelType>,
    /// Loaded model; `None` when no model id was supplied.
    model: Option<EmbeddingModelVariant>,
    /// Loaded tokenizer; `None` when no tokenizer id was supplied.
    tokenizer: Option<TokenizerWrapper>,
    /// Embedding size as passed by the caller (not the value inferred from the weights).
    embedding_size: Option<usize>,
}
|
|
76
|
+
|
|
77
|
+
impl EmbeddingModel {
|
|
78
|
+
pub fn new(model_id: Option<String>, tokenizer: Option<String>, device: Option<Device>, model_type: Option<String>, embedding_size: Option<usize>) -> Result<Self> {
|
|
79
|
+
let device = device.unwrap_or(Device::best()).as_device()?;
|
|
80
|
+
let model_type = model_type
|
|
81
|
+
.and_then(|mt| EmbeddingModelType::from_string(&mt))
|
|
82
|
+
.unwrap_or(EmbeddingModelType::JinaBert);
|
|
83
|
+
Ok(EmbeddingModel(EmbeddingModelInner {
|
|
84
|
+
device: device.clone(),
|
|
85
|
+
model_id: model_id.clone(),
|
|
86
|
+
tokenizer_id: tokenizer.clone(),
|
|
87
|
+
model_type: Some(model_type),
|
|
88
|
+
model: match model_id.as_ref() {
|
|
89
|
+
Some(id) => Some(Self::build_embedding_model(id, device, model_type, embedding_size)?),
|
|
90
|
+
None => None
|
|
91
|
+
},
|
|
92
|
+
tokenizer: match tokenizer {
|
|
93
|
+
Some(tid) => Some(Self::build_tokenizer(tid)?),
|
|
94
|
+
None => None
|
|
95
|
+
},
|
|
96
|
+
embedding_size,
|
|
97
|
+
}))
|
|
98
|
+
}
|
|
99
|
+
|
|
100
|
+
/// Generates an embedding vector for the input text
|
|
101
|
+
/// &RETURNS&: Tensor
|
|
102
|
+
/// Generates an embedding vector for the input text using the specified pooling method.
|
|
103
|
+
/// &RETURNS&: Tensor
|
|
104
|
+
/// pooling_method: "pooled", "pooled_normalized", or "cls" (default: "pooled")
|
|
105
|
+
pub fn embedding(&self, input: String, pooling_method: String) -> Result<Tensor> {
|
|
106
|
+
let ruby = Ruby::get().unwrap();
|
|
107
|
+
match (&self.0.model, &self.0.tokenizer) {
|
|
108
|
+
(Some(model), Some(tokenizer)) => {
|
|
109
|
+
let result = crate::gvl::without_gvl(|| {
|
|
110
|
+
self.compute_embedding(input, model, tokenizer, &pooling_method)
|
|
111
|
+
});
|
|
112
|
+
Ok(Tensor(result?))
|
|
113
|
+
}
|
|
114
|
+
(None, _) => Err(magnus::Error::new(ruby.exception_runtime_error(), "Model not found")),
|
|
115
|
+
(_, None) => Err(magnus::Error::new(ruby.exception_runtime_error(), "Tokenizer not found")),
|
|
116
|
+
}
|
|
117
|
+
}
|
|
118
|
+
|
|
119
|
+
/// Returns the unpooled embedding tensor ([1, SEQLENGTH, DIM]) for the input text
|
|
120
|
+
/// &RETURNS&: Tensor
|
|
121
|
+
pub fn embeddings(&self, input: String) -> Result<Tensor> {
|
|
122
|
+
let ruby = Ruby::get().unwrap();
|
|
123
|
+
match (&self.0.model, &self.0.tokenizer) {
|
|
124
|
+
(Some(model), Some(tokenizer)) => {
|
|
125
|
+
let result = crate::gvl::without_gvl(|| {
|
|
126
|
+
self.compute_embeddings(input, model, tokenizer)
|
|
127
|
+
});
|
|
128
|
+
Ok(Tensor(result?))
|
|
129
|
+
}
|
|
130
|
+
(None, _) => Err(magnus::Error::new(ruby.exception_runtime_error(), "Model not found")),
|
|
131
|
+
(_, None) => Err(magnus::Error::new(ruby.exception_runtime_error(), "Tokenizer not found")),
|
|
132
|
+
}
|
|
133
|
+
}
|
|
134
|
+
|
|
135
|
+
/// Pools and normalizes a sequence embedding tensor ([1, SEQLENGTH, DIM]) to [1, DIM]
|
|
136
|
+
/// &RETURNS&: Tensor
|
|
137
|
+
pub fn pool_embedding(&self, tensor: &Tensor) -> Result<Tensor> {
|
|
138
|
+
let pooled = Self::pooled_embedding(&tensor.0)?;
|
|
139
|
+
Ok(Tensor(pooled))
|
|
140
|
+
}
|
|
141
|
+
|
|
142
|
+
/// Pools and normalizes a sequence embedding tensor ([1, SEQLENGTH, DIM]) to [1, DIM]
|
|
143
|
+
/// &RETURNS&: Tensor
|
|
144
|
+
pub fn pool_and_normalize_embedding(&self, tensor: &Tensor) -> Result<Tensor> {
|
|
145
|
+
let pooled = Self::pooled_normalized_embedding(&tensor.0)?;
|
|
146
|
+
Ok(Tensor(pooled))
|
|
147
|
+
}
|
|
148
|
+
|
|
149
|
+
/// Pools the embedding tensor by extracting the CLS token ([1, SEQLENGTH, DIM] -> [1, DIM])
|
|
150
|
+
/// &RETURNS&: Tensor
|
|
151
|
+
pub fn pool_cls_embedding(&self, tensor: &Tensor) -> Result<Tensor> {
|
|
152
|
+
let pooled = Self::pooled_cls_embedding(&tensor.0)?;
|
|
153
|
+
Ok(Tensor(pooled))
|
|
154
|
+
}
|
|
155
|
+
|
|
156
|
+
/// Infers and validates the embedding size from a safetensors file
|
|
157
|
+
fn resolve_embedding_size(model_path: &Path, embedding_size: Option<usize>) -> std::result::Result<usize, magnus::Error> {
|
|
158
|
+
match embedding_size {
|
|
159
|
+
Some(user_dim) => {
|
|
160
|
+
Ok(user_dim)
|
|
161
|
+
},
|
|
162
|
+
None => {
|
|
163
|
+
let inferred_emb_dim = match SafeTensors::deserialize(&std::fs::read(model_path).map_err(|e| wrap_std_err(Box::new(e)))?) {
|
|
164
|
+
Ok(st) => {
|
|
165
|
+
if let Some(tensor) = st.tensor("embeddings.word_embeddings.weight").ok() {
|
|
166
|
+
let shape = tensor.shape();
|
|
167
|
+
if shape.len() == 2 { Some(shape[1] as usize) } else { None }
|
|
168
|
+
} else { None }
|
|
169
|
+
},
|
|
170
|
+
Err(_) => None
|
|
171
|
+
};
|
|
172
|
+
inferred_emb_dim.ok_or_else(|| {
|
|
173
|
+
let ruby = Ruby::get().unwrap();
|
|
174
|
+
magnus::Error::new(ruby.exception_runtime_error(), "Could not infer embedding size from model file. Please specify embedding_size explicitly.")
|
|
175
|
+
})
|
|
176
|
+
}
|
|
177
|
+
}
|
|
178
|
+
}
|
|
179
|
+
|
|
180
|
+
fn build_embedding_model(model_id: &str, device: CoreDevice, model_type: EmbeddingModelType, embedding_size: Option<usize>) -> Result<EmbeddingModelVariant> {
|
|
181
|
+
use hf_hub::{api::sync::Api, Repo, RepoType};
|
|
182
|
+
let api = Api::new().map_err(wrap_hf_err)?;
|
|
183
|
+
let repo = Repo::new(model_id.to_string(), RepoType::Model);
|
|
184
|
+
match model_type {
|
|
185
|
+
EmbeddingModelType::JinaBert => {
|
|
186
|
+
let model_path = api.repo(repo).get("model.safetensors").map_err(wrap_hf_err)?;
|
|
187
|
+
if !std::path::Path::new(&model_path).exists() {
|
|
188
|
+
let ruby = Ruby::get().unwrap();
|
|
189
|
+
return Err(magnus::Error::new(
|
|
190
|
+
ruby.exception_runtime_error(),
|
|
191
|
+
"model.safetensors not found after download. Only safetensors models are supported. Please ensure your model repo contains model.safetensors."
|
|
192
|
+
));
|
|
193
|
+
}
|
|
194
|
+
let final_emb_dim = Self::resolve_embedding_size(Path::new(&model_path), embedding_size)?;
|
|
195
|
+
let mut config = JinaConfig::v2_base();
|
|
196
|
+
config.hidden_size = final_emb_dim;
|
|
197
|
+
let vb = unsafe {
|
|
198
|
+
VarBuilder::from_mmaped_safetensors(&[model_path], CoreDType::F32, &device)
|
|
199
|
+
.map_err(wrap_candle_err)?
|
|
200
|
+
};
|
|
201
|
+
let model = JinaBertModel::new(vb, &config).map_err(wrap_candle_err)?;
|
|
202
|
+
Ok(EmbeddingModelVariant::JinaBert(model))
|
|
203
|
+
},
|
|
204
|
+
EmbeddingModelType::StandardBert => {
|
|
205
|
+
let model_path = api.repo(repo).get("model.safetensors").map_err(wrap_hf_err)?;
|
|
206
|
+
if !std::path::Path::new(&model_path).exists() {
|
|
207
|
+
let ruby = Ruby::get().unwrap();
|
|
208
|
+
return Err(magnus::Error::new(
|
|
209
|
+
ruby.exception_runtime_error(),
|
|
210
|
+
"model.safetensors not found after download. Only safetensors models are supported. Please ensure your model repo contains model.safetensors."
|
|
211
|
+
));
|
|
212
|
+
}
|
|
213
|
+
let final_emb_dim = Self::resolve_embedding_size(Path::new(&model_path), embedding_size)?;
|
|
214
|
+
let mut config = BertConfig::default();
|
|
215
|
+
config.hidden_size = final_emb_dim;
|
|
216
|
+
let vb = unsafe {
|
|
217
|
+
VarBuilder::from_mmaped_safetensors(&[model_path], CoreDType::F32, &device)
|
|
218
|
+
.map_err(wrap_candle_err)?
|
|
219
|
+
};
|
|
220
|
+
let model = StdBertModel::load(vb, &config).map_err(wrap_candle_err)?;
|
|
221
|
+
Ok(EmbeddingModelVariant::StandardBert(model))
|
|
222
|
+
},
|
|
223
|
+
EmbeddingModelType::DistilBert => {
|
|
224
|
+
let model_path = api.repo(repo.clone()).get("model.safetensors").map_err(wrap_hf_err)?;
|
|
225
|
+
if !std::path::Path::new(&model_path).exists() {
|
|
226
|
+
let ruby = Ruby::get().unwrap();
|
|
227
|
+
return Err(magnus::Error::new(
|
|
228
|
+
ruby.exception_runtime_error(),
|
|
229
|
+
"model.safetensors not found after download. Only safetensors models are supported. Please ensure your model repo contains model.safetensors."
|
|
230
|
+
));
|
|
231
|
+
}
|
|
232
|
+
let config_path = api.repo(repo.clone()).get("config.json").map_err(wrap_hf_err)?;
|
|
233
|
+
let config_file = std::fs::File::open(&config_path).map_err(|e| wrap_std_err(Box::new(e)))?;
|
|
234
|
+
let mut config: DistilBertConfig = serde_json::from_reader(config_file).map_err(|e| wrap_std_err(Box::new(e)))?;
|
|
235
|
+
if let Some(embedding_size) = embedding_size {
|
|
236
|
+
config.dim = embedding_size;
|
|
237
|
+
}
|
|
238
|
+
let vb = unsafe {
|
|
239
|
+
VarBuilder::from_mmaped_safetensors(&[model_path], CoreDType::F32, &device)
|
|
240
|
+
.map_err(wrap_candle_err)?
|
|
241
|
+
};
|
|
242
|
+
let model = DistilBertModel::load(vb, &config).map_err(wrap_candle_err)?;
|
|
243
|
+
Ok(EmbeddingModelVariant::DistilBert(model))
|
|
244
|
+
},
|
|
245
|
+
EmbeddingModelType::MiniLM => {
|
|
246
|
+
let model_path = api.repo(repo.clone()).get("model.safetensors").map_err(wrap_hf_err)?;
|
|
247
|
+
if !std::path::Path::new(&model_path).exists() {
|
|
248
|
+
let ruby = Ruby::get().unwrap();
|
|
249
|
+
return Err(magnus::Error::new(
|
|
250
|
+
ruby.exception_runtime_error(),
|
|
251
|
+
"model.safetensors not found after download. Only safetensors models are supported. Please ensure your model repo contains model.safetensors."
|
|
252
|
+
));
|
|
253
|
+
}
|
|
254
|
+
let config_path = api.repo(repo.clone()).get("config.json").map_err(wrap_hf_err)?;
|
|
255
|
+
let config_file = std::fs::File::open(&config_path).map_err(|e| wrap_std_err(Box::new(e)))?;
|
|
256
|
+
let mut config: BertConfig = serde_json::from_reader(config_file).map_err(|e| wrap_std_err(Box::new(e)))?;
|
|
257
|
+
if let Some(embedding_size) = embedding_size {
|
|
258
|
+
config.hidden_size = embedding_size;
|
|
259
|
+
}
|
|
260
|
+
let vb = unsafe {
|
|
261
|
+
VarBuilder::from_mmaped_safetensors(&[model_path], CoreDType::F32, &device)
|
|
262
|
+
.map_err(wrap_candle_err)?
|
|
263
|
+
};
|
|
264
|
+
let model = StdBertModel::load(vb, &config).map_err(wrap_candle_err)?;
|
|
265
|
+
Ok(EmbeddingModelVariant::MiniLM(model))
|
|
266
|
+
},
|
|
267
|
+
|
|
268
|
+
}
|
|
269
|
+
}
|
|
270
|
+
|
|
271
|
+
fn build_tokenizer(tokenizer_id: String) -> Result<TokenizerWrapper> {
|
|
272
|
+
use hf_hub::{api::sync::Api, Repo, RepoType};
|
|
273
|
+
let tokenizer_path = Api::new()
|
|
274
|
+
.map_err(wrap_hf_err)?
|
|
275
|
+
.repo(Repo::new(
|
|
276
|
+
tokenizer_id,
|
|
277
|
+
RepoType::Model,
|
|
278
|
+
))
|
|
279
|
+
.get("tokenizer.json")
|
|
280
|
+
.map_err(wrap_hf_err)?;
|
|
281
|
+
let tokenizer = tokenizers::Tokenizer::from_file(tokenizer_path)
|
|
282
|
+
.map_err(wrap_std_err)?;
|
|
283
|
+
|
|
284
|
+
let tokenizer = TokenizerLoader::with_padding(tokenizer, None);
|
|
285
|
+
Ok(TokenizerWrapper::new(tokenizer))
|
|
286
|
+
}
|
|
287
|
+
|
|
288
|
+
/// Pools the embedding tensor by extracting the CLS token ([1, SEQLENGTH, DIM] -> [1, DIM])
|
|
289
|
+
/// &RETURNS&: Tensor
|
|
290
|
+
fn pooled_cls_embedding(result: &CoreTensor) -> std::result::Result<CoreTensor, Error> {
|
|
291
|
+
// 1) sanity-check that we have a 3D tensor
|
|
292
|
+
let (_batch, _seq_len, _hidden_size) = result.dims3().map_err(wrap_candle_err)?;
|
|
293
|
+
|
|
294
|
+
// 2) slice out just the first token (CLS) along the sequence axis:
|
|
295
|
+
// [B, seq_len, H] → [B, 1, H]
|
|
296
|
+
let first = result
|
|
297
|
+
.narrow(1, 0, 1)
|
|
298
|
+
.map_err(wrap_candle_err)?;
|
|
299
|
+
|
|
300
|
+
// 3) remove that length-1 axis → [B, H]
|
|
301
|
+
let cls = first
|
|
302
|
+
.squeeze(1)
|
|
303
|
+
.map_err(wrap_candle_err)?;
|
|
304
|
+
|
|
305
|
+
Ok(cls)
|
|
306
|
+
}
|
|
307
|
+
|
|
308
|
+
fn pooled_embedding(result: &CoreTensor) -> std::result::Result<CoreTensor, Error> {
|
|
309
|
+
let (_n_sentence, n_tokens, _hidden_size) = result.dims3().map_err(wrap_candle_err)?;
|
|
310
|
+
let sum = result.sum(1).map_err(wrap_candle_err)?;
|
|
311
|
+
let mean = (sum / (n_tokens as f64)).map_err(wrap_candle_err)?;
|
|
312
|
+
Ok(mean)
|
|
313
|
+
}
|
|
314
|
+
|
|
315
|
+
fn pooled_normalized_embedding(result: &CoreTensor) -> std::result::Result<CoreTensor, Error> {
|
|
316
|
+
let mean = Self::pooled_embedding(result)?;
|
|
317
|
+
let norm = Self::normalize_l2(&mean).map_err(wrap_candle_err)?;
|
|
318
|
+
Ok(norm)
|
|
319
|
+
}
|
|
320
|
+
|
|
321
|
+
fn compute_embeddings(
|
|
322
|
+
&self,
|
|
323
|
+
prompt: String,
|
|
324
|
+
model: &EmbeddingModelVariant,
|
|
325
|
+
tokenizer: &TokenizerWrapper,
|
|
326
|
+
) -> std::result::Result<CoreTensor, Error> {
|
|
327
|
+
let tokens = tokenizer
|
|
328
|
+
.encode(&prompt, true)
|
|
329
|
+
.map_err(wrap_candle_err)?;
|
|
330
|
+
let token_ids = CoreTensor::new(&tokens[..], &self.0.device)
|
|
331
|
+
.map_err(wrap_candle_err)?
|
|
332
|
+
.unsqueeze(0)
|
|
333
|
+
.map_err(wrap_candle_err)?;
|
|
334
|
+
let batch_size = token_ids.dims()[0];
|
|
335
|
+
let seq_len = token_ids.dims()[1];
|
|
336
|
+
let token_type_ids = CoreTensor::zeros(&[batch_size, seq_len], CoreDType::U32, &self.0.device)
|
|
337
|
+
.map_err(wrap_candle_err)?;
|
|
338
|
+
let attention_mask = CoreTensor::ones(&[batch_size, seq_len], CoreDType::U32, &self.0.device)
|
|
339
|
+
.map_err(wrap_candle_err)?;
|
|
340
|
+
match model {
|
|
341
|
+
EmbeddingModelVariant::JinaBert(model) => {
|
|
342
|
+
model.forward(&token_ids).map_err(wrap_candle_err)
|
|
343
|
+
},
|
|
344
|
+
EmbeddingModelVariant::StandardBert(model) => {
|
|
345
|
+
model.forward(&token_ids, &token_type_ids, Some(&attention_mask)).map_err(wrap_candle_err)
|
|
346
|
+
},
|
|
347
|
+
EmbeddingModelVariant::DistilBert(model) => {
|
|
348
|
+
model.forward(&token_ids, &attention_mask).map_err(wrap_candle_err)
|
|
349
|
+
},
|
|
350
|
+
EmbeddingModelVariant::MiniLM(model) => {
|
|
351
|
+
model.forward(&token_ids, &token_type_ids, Some(&attention_mask)).map_err(wrap_candle_err)
|
|
352
|
+
},
|
|
353
|
+
|
|
354
|
+
}
|
|
355
|
+
}
|
|
356
|
+
|
|
357
|
+
/// Computes an embedding for the prompt using the specified pooling method.
|
|
358
|
+
/// pooling_method: "pooled", "pooled_normalized", or "cls"
|
|
359
|
+
fn compute_embedding(
|
|
360
|
+
&self,
|
|
361
|
+
prompt: String,
|
|
362
|
+
model: &EmbeddingModelVariant,
|
|
363
|
+
tokenizer: &TokenizerWrapper,
|
|
364
|
+
pooling_method: &str,
|
|
365
|
+
) -> std::result::Result<CoreTensor, Error> {
|
|
366
|
+
let result = self.compute_embeddings(prompt, model, tokenizer)?;
|
|
367
|
+
match pooling_method {
|
|
368
|
+
"pooled" => Self::pooled_embedding(&result),
|
|
369
|
+
"pooled_normalized" => Self::pooled_normalized_embedding(&result),
|
|
370
|
+
"cls" => Self::pooled_cls_embedding(&result),
|
|
371
|
+
_ => {
|
|
372
|
+
let ruby = Ruby::get().unwrap();
|
|
373
|
+
Err(magnus::Error::new(ruby.exception_runtime_error(), "Unknown pooling method"))
|
|
374
|
+
},
|
|
375
|
+
}
|
|
376
|
+
}
|
|
377
|
+
|
|
378
|
+
fn normalize_l2(v: &CoreTensor) -> candle_core::Result<CoreTensor> {
|
|
379
|
+
v.broadcast_div(&v.sqr()?.sum_keepdim(1)?.sqrt()?)
|
|
380
|
+
}
|
|
381
|
+
|
|
382
|
+
pub fn model_type(&self) -> String {
|
|
383
|
+
match self.0.model_type {
|
|
384
|
+
Some(mt) => format!("{:?}", mt),
|
|
385
|
+
None => "nil".to_string(),
|
|
386
|
+
}
|
|
387
|
+
}
|
|
388
|
+
|
|
389
|
+
pub fn __repr__(&self) -> String {
|
|
390
|
+
format!(
|
|
391
|
+
"#<Candle::EmbeddingModel model_type: {}, model_id: {}, tokenizer: {}, embedding_size: {}>",
|
|
392
|
+
self.model_type(),
|
|
393
|
+
self.0.model_id.as_deref().unwrap_or("nil"),
|
|
394
|
+
self.0.tokenizer_id.as_deref().unwrap_or("nil"),
|
|
395
|
+
self.0.embedding_size.map(|x| x.to_string()).unwrap_or("nil".to_string())
|
|
396
|
+
)
|
|
397
|
+
}
|
|
398
|
+
|
|
399
|
+
pub fn __str__(&self) -> String {
|
|
400
|
+
self.__repr__()
|
|
401
|
+
}
|
|
402
|
+
|
|
403
|
+
/// Get the tokenizer used by this model
|
|
404
|
+
pub fn tokenizer(&self) -> Result<crate::ruby::tokenizer::Tokenizer> {
|
|
405
|
+
match &self.0.tokenizer {
|
|
406
|
+
Some(tokenizer) => Ok(crate::ruby::tokenizer::Tokenizer(tokenizer.clone())),
|
|
407
|
+
None => {
|
|
408
|
+
let ruby = Ruby::get().unwrap();
|
|
409
|
+
Err(magnus::Error::new(ruby.exception_runtime_error(), "No tokenizer loaded for this model"))
|
|
410
|
+
}
|
|
411
|
+
}
|
|
412
|
+
}
|
|
413
|
+
|
|
414
|
+
/// Get the model_id
|
|
415
|
+
pub fn model_id(&self) -> Result<String> {
|
|
416
|
+
match &self.0.model_id {
|
|
417
|
+
Some(id) => Ok(id.clone()),
|
|
418
|
+
None => Ok("unknown".to_string())
|
|
419
|
+
}
|
|
420
|
+
}
|
|
421
|
+
|
|
422
|
+
/// Get the device
|
|
423
|
+
pub fn device(&self) -> Device {
|
|
424
|
+
Device::from_device(&self.0.device)
|
|
425
|
+
}
|
|
426
|
+
|
|
427
|
+
/// Get all options as a hash
|
|
428
|
+
pub fn options(&self) -> Result<RHash> {
|
|
429
|
+
let ruby = Ruby::get().unwrap();
|
|
430
|
+
let hash = ruby.hash_new();
|
|
431
|
+
|
|
432
|
+
// Add model_id
|
|
433
|
+
if let Some(model_id) = &self.0.model_id {
|
|
434
|
+
hash.aset("model_id", model_id.clone())?;
|
|
435
|
+
}
|
|
436
|
+
|
|
437
|
+
// Add tokenizer
|
|
438
|
+
if let Some(tokenizer_id) = &self.0.tokenizer_id {
|
|
439
|
+
hash.aset("tokenizer", tokenizer_id.clone())?;
|
|
440
|
+
}
|
|
441
|
+
|
|
442
|
+
// Add device
|
|
443
|
+
hash.aset("device", self.device().__str__())?;
|
|
444
|
+
|
|
445
|
+
// Add model_type
|
|
446
|
+
if let Some(model_type) = &self.0.model_type {
|
|
447
|
+
hash.aset("model_type", format!("{:?}", model_type))?;
|
|
448
|
+
}
|
|
449
|
+
|
|
450
|
+
// Add embedding_size
|
|
451
|
+
if let Some(size) = self.0.embedding_size {
|
|
452
|
+
hash.aset("embedding_size", size)?;
|
|
453
|
+
}
|
|
454
|
+
|
|
455
|
+
Ok(hash)
|
|
456
|
+
}
|
|
457
|
+
}
|
|
458
|
+
|
|
459
|
+
pub fn init(rb_candle: RModule) -> Result<()> {
|
|
460
|
+
let ruby = Ruby::get().unwrap();
|
|
461
|
+
let rb_embedding_model = rb_candle.define_class("EmbeddingModel", ruby.class_object())?;
|
|
462
|
+
rb_embedding_model.define_singleton_method("_create", function!(EmbeddingModel::new, 5))?;
|
|
463
|
+
// Expose embedding with an optional pooling_method argument (default: "pooled")
|
|
464
|
+
rb_embedding_model.define_method("_embedding", method!(EmbeddingModel::embedding, 2))?;
|
|
465
|
+
rb_embedding_model.define_method("embeddings", method!(EmbeddingModel::embeddings, 1))?;
|
|
466
|
+
rb_embedding_model.define_method("pool_embedding", method!(EmbeddingModel::pool_embedding, 1))?;
|
|
467
|
+
rb_embedding_model.define_method("pool_and_normalize_embedding", method!(EmbeddingModel::pool_and_normalize_embedding, 1))?;
|
|
468
|
+
rb_embedding_model.define_method("pool_cls_embedding", method!(EmbeddingModel::pool_cls_embedding, 1))?;
|
|
469
|
+
rb_embedding_model.define_method("model_type", method!(EmbeddingModel::model_type, 0))?;
|
|
470
|
+
rb_embedding_model.define_method("to_s", method!(EmbeddingModel::__str__, 0))?;
|
|
471
|
+
rb_embedding_model.define_method("inspect", method!(EmbeddingModel::__repr__, 0))?;
|
|
472
|
+
rb_embedding_model.define_method("tokenizer", method!(EmbeddingModel::tokenizer, 0))?;
|
|
473
|
+
rb_embedding_model.define_method("model_id", method!(EmbeddingModel::model_id, 0))?;
|
|
474
|
+
rb_embedding_model.define_method("device", method!(EmbeddingModel::device, 0))?;
|
|
475
|
+
rb_embedding_model.define_method("options", method!(EmbeddingModel::options, 0))?;
|
|
476
|
+
Ok(())
|
|
477
|
+
}
|
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
use magnus::Error;
|
|
2
|
+
|
|
3
|
+
pub fn wrap_std_err(err: Box<dyn std::error::Error + Send + Sync>) -> Error {
|
|
4
|
+
let ruby = magnus::Ruby::get().unwrap();
|
|
5
|
+
Error::new(ruby.exception_runtime_error(), err.to_string())
|
|
6
|
+
}
|
|
7
|
+
|
|
8
|
+
pub fn wrap_candle_err(err: candle_core::Error) -> Error {
|
|
9
|
+
let ruby = magnus::Ruby::get().unwrap();
|
|
10
|
+
Error::new(ruby.exception_runtime_error(), err.to_string())
|
|
11
|
+
}
|
|
12
|
+
|
|
13
|
+
pub fn wrap_hf_err(err: hf_hub::api::sync::ApiError) -> Error {
|
|
14
|
+
let ruby = magnus::Ruby::get().unwrap();
|
|
15
|
+
Error::new(ruby.exception_runtime_error(), err.to_string())
|
|
16
|
+
}
|