red-candle 0.0.6 → 1.0.0.pre.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/Cargo.lock +1667 -517
- data/Cargo.toml +4 -0
- data/README.md +224 -6
- data/ext/candle/Cargo.toml +19 -8
- data/ext/candle/build.rs +116 -0
- data/ext/candle/extconf.rb +77 -1
- data/ext/candle/src/lib.rs +44 -95
- data/ext/candle/src/llm/generation_config.rs +49 -0
- data/ext/candle/src/llm/mistral.rs +325 -0
- data/ext/candle/src/llm/mod.rs +68 -0
- data/ext/candle/src/llm/text_generation.rs +141 -0
- data/ext/candle/src/reranker.rs +267 -0
- data/ext/candle/src/ruby/device.rs +197 -0
- data/ext/candle/src/ruby/dtype.rs +37 -0
- data/ext/candle/src/ruby/embedding_model.rs +410 -0
- data/ext/candle/src/ruby/errors.rs +13 -0
- data/ext/candle/src/ruby/llm.rs +295 -0
- data/ext/candle/src/ruby/mod.rs +21 -0
- data/ext/candle/src/ruby/qtensor.rs +69 -0
- data/ext/candle/src/ruby/result.rs +3 -0
- data/ext/candle/src/ruby/tensor.rs +654 -0
- data/ext/candle/src/ruby/utils.rs +88 -0
- data/lib/candle/build_info.rb +66 -0
- data/lib/candle/device_utils.rb +20 -0
- data/lib/candle/embedding_model.rb +32 -0
- data/lib/candle/embedding_model_type.rb +31 -0
- data/lib/candle/llm.rb +107 -0
- data/lib/candle/reranker.rb +24 -0
- data/lib/candle/tensor.rb +68 -3
- data/lib/candle/version.rb +2 -2
- data/lib/candle.rb +6 -0
- metadata +25 -3
@@ -0,0 +1,410 @@
|
|
1
|
+
// MKL and Accelerate are handled by candle-core when their features are enabled
|
2
|
+
|
3
|
+
use crate::ruby::{
|
4
|
+
errors::{wrap_candle_err, wrap_hf_err, wrap_std_err},
|
5
|
+
};
|
6
|
+
use crate::ruby::{Tensor, Device, Result as RbResult};
|
7
|
+
use candle_core::{DType as CoreDType, Device as CoreDevice, Module, Tensor as CoreTensor};
|
8
|
+
use safetensors::tensor::SafeTensors;
|
9
|
+
use candle_nn::VarBuilder;
|
10
|
+
use candle_transformers::models::{
|
11
|
+
bert::{BertModel as StdBertModel, Config as BertConfig},
|
12
|
+
jina_bert::{BertModel as JinaBertModel, Config as JinaConfig},
|
13
|
+
distilbert::{DistilBertModel, Config as DistilBertConfig}
|
14
|
+
};
|
15
|
+
use magnus::{class, function, method, prelude::*, Error, RModule};
|
16
|
+
use std::path::Path;
|
17
|
+
use tokenizers::Tokenizer;
|
18
|
+
use serde_json;
|
19
|
+
|
20
|
+
|
21
|
+
#[magnus::wrap(class = "Candle::EmbeddingModel", free_immediately, size)]
|
22
|
+
pub struct EmbeddingModel(pub EmbeddingModelInner);
|
23
|
+
|
24
|
+
/// Supported model types for embedding generation.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum EmbeddingModelType {
    /// Jina BERT architecture.
    JinaBert,
    /// Standard BERT architecture.
    StandardBert,
    /// DistilBERT architecture.
    DistilBert,
    /// MiniLM checkpoints.
    MiniLM,
}

impl EmbeddingModelType {
    /// Parses a user-supplied model type name.
    ///
    /// Matching is case-insensitive and ignores surrounding whitespace. Every
    /// variant accepts both its compact spelling and its snake_case spelling
    /// (for example `"distilbert"` and `"distil_bert"`).
    ///
    /// Returns `None` when the name is not recognised.
    pub fn from_string(model_type: &str) -> Option<Self> {
        match model_type.trim().to_lowercase().as_str() {
            "jina_bert" | "jinabert" | "jina" => Some(Self::JinaBert),
            "bert" | "standard_bert" | "standardbert" => Some(Self::StandardBert),
            "minilm" | "mini_lm" => Some(Self::MiniLM),
            "distilbert" | "distil_bert" => Some(Self::DistilBert),
            _ => None,
        }
    }
}
|
45
|
+
|
46
|
+
/// Model variants that can produce embeddings
|
47
|
+
pub enum EmbeddingModelVariant {
|
48
|
+
JinaBert(JinaBertModel),
|
49
|
+
StandardBert(StdBertModel),
|
50
|
+
DistilBert(DistilBertModel),
|
51
|
+
MiniLM(StdBertModel),
|
52
|
+
|
53
|
+
}
|
54
|
+
|
55
|
+
impl EmbeddingModelVariant {
|
56
|
+
pub fn embedding_model_type(&self) -> EmbeddingModelType {
|
57
|
+
match self {
|
58
|
+
EmbeddingModelVariant::JinaBert(_) => EmbeddingModelType::JinaBert,
|
59
|
+
EmbeddingModelVariant::StandardBert(_) => EmbeddingModelType::StandardBert,
|
60
|
+
EmbeddingModelVariant::DistilBert(_) => EmbeddingModelType::DistilBert,
|
61
|
+
EmbeddingModelVariant::MiniLM(_) => EmbeddingModelType::MiniLM,
|
62
|
+
|
63
|
+
}
|
64
|
+
}
|
65
|
+
}
|
66
|
+
|
67
|
+
pub struct EmbeddingModelInner {
|
68
|
+
device: CoreDevice,
|
69
|
+
tokenizer_path: Option<String>,
|
70
|
+
model_path: Option<String>,
|
71
|
+
embedding_model_type: Option<EmbeddingModelType>,
|
72
|
+
model: Option<EmbeddingModelVariant>,
|
73
|
+
tokenizer: Option<Tokenizer>,
|
74
|
+
embedding_size: Option<usize>,
|
75
|
+
}
|
76
|
+
|
77
|
+
impl EmbeddingModel {
|
78
|
+
pub fn new(model_path: Option<String>, tokenizer_path: Option<String>, device: Option<Device>, embedding_model_type: Option<String>, embedding_size: Option<usize>) -> RbResult<Self> {
|
79
|
+
let device = device.unwrap_or(Device::Cpu).as_device()?;
|
80
|
+
let embedding_model_type = embedding_model_type
|
81
|
+
.and_then(|mt| EmbeddingModelType::from_string(&mt))
|
82
|
+
.unwrap_or(EmbeddingModelType::JinaBert);
|
83
|
+
Ok(EmbeddingModel(EmbeddingModelInner {
|
84
|
+
device: device.clone(),
|
85
|
+
model_path: model_path.clone(),
|
86
|
+
tokenizer_path: tokenizer_path.clone(),
|
87
|
+
embedding_model_type: Some(embedding_model_type),
|
88
|
+
model: match model_path {
|
89
|
+
Some(mp) => Some(Self::build_embedding_model(Path::new(&mp), device, embedding_model_type, embedding_size)?),
|
90
|
+
None => None
|
91
|
+
},
|
92
|
+
tokenizer: match tokenizer_path {
|
93
|
+
Some(tp) => Some(Self::build_tokenizer(tp)?),
|
94
|
+
None => None
|
95
|
+
},
|
96
|
+
embedding_size,
|
97
|
+
}))
|
98
|
+
}
|
99
|
+
|
100
|
+
/// Generates an embedding vector for the input text
|
101
|
+
/// &RETURNS&: Tensor
|
102
|
+
/// Generates an embedding vector for the input text using the specified pooling method.
|
103
|
+
/// &RETURNS&: Tensor
|
104
|
+
/// pooling_method: "pooled", "pooled_normalized", or "cls" (default: "pooled")
|
105
|
+
pub fn embedding(&self, input: String, pooling_method: String) -> RbResult<Tensor> {
|
106
|
+
match &self.0.model {
|
107
|
+
Some(model) => {
|
108
|
+
match &self.0.tokenizer {
|
109
|
+
Some(tokenizer) => Ok(Tensor(self.compute_embedding(input, model, tokenizer, &pooling_method)?)),
|
110
|
+
None => Err(magnus::Error::new(magnus::exception::runtime_error(), "Tokenizer not found"))
|
111
|
+
}
|
112
|
+
}
|
113
|
+
None => Err(magnus::Error::new(magnus::exception::runtime_error(), "Model not found"))
|
114
|
+
}
|
115
|
+
}
|
116
|
+
|
117
|
+
/// Returns the unpooled embedding tensor ([1, SEQLENGTH, DIM]) for the input text
|
118
|
+
/// &RETURNS&: Tensor
|
119
|
+
pub fn embeddings(&self, input: String) -> RbResult<Tensor> {
|
120
|
+
match &self.0.model {
|
121
|
+
Some(model) => {
|
122
|
+
match &self.0.tokenizer {
|
123
|
+
Some(tokenizer) => Ok(Tensor(self.compute_embeddings(input, model, tokenizer)?)),
|
124
|
+
None => Err(magnus::Error::new(magnus::exception::runtime_error(), "Tokenizer not found"))
|
125
|
+
}
|
126
|
+
}
|
127
|
+
None => Err(magnus::Error::new(magnus::exception::runtime_error(), "Model not found"))
|
128
|
+
}
|
129
|
+
}
|
130
|
+
|
131
|
+
/// Pools and normalizes a sequence embedding tensor ([1, SEQLENGTH, DIM]) to [1, DIM]
|
132
|
+
/// &RETURNS&: Tensor
|
133
|
+
pub fn pool_embedding(&self, tensor: &Tensor) -> RbResult<Tensor> {
|
134
|
+
let pooled = Self::pooled_embedding(&tensor.0)?;
|
135
|
+
Ok(Tensor(pooled))
|
136
|
+
}
|
137
|
+
|
138
|
+
/// Pools and normalizes a sequence embedding tensor ([1, SEQLENGTH, DIM]) to [1, DIM]
|
139
|
+
/// &RETURNS&: Tensor
|
140
|
+
pub fn pool_and_normalize_embedding(&self, tensor: &Tensor) -> RbResult<Tensor> {
|
141
|
+
let pooled = Self::pooled_normalized_embedding(&tensor.0)?;
|
142
|
+
Ok(Tensor(pooled))
|
143
|
+
}
|
144
|
+
|
145
|
+
/// Pools the embedding tensor by extracting the CLS token ([1, SEQLENGTH, DIM] -> [1, DIM])
|
146
|
+
/// &RETURNS&: Tensor
|
147
|
+
pub fn pool_cls_embedding(&self, tensor: &Tensor) -> RbResult<Tensor> {
|
148
|
+
let pooled = Self::pooled_cls_embedding(&tensor.0)?;
|
149
|
+
Ok(Tensor(pooled))
|
150
|
+
}
|
151
|
+
|
152
|
+
/// Infers and validates the embedding size from a safetensors file
|
153
|
+
fn resolve_embedding_size(model_path: &Path, embedding_size: Option<usize>) -> Result<usize, magnus::Error> {
|
154
|
+
match embedding_size {
|
155
|
+
Some(user_dim) => {
|
156
|
+
Ok(user_dim)
|
157
|
+
},
|
158
|
+
None => {
|
159
|
+
let inferred_emb_dim = match SafeTensors::deserialize(&std::fs::read(model_path).map_err(|e| wrap_std_err(Box::new(e)))?) {
|
160
|
+
Ok(st) => {
|
161
|
+
if let Some(tensor) = st.tensor("embeddings.word_embeddings.weight").ok() {
|
162
|
+
let shape = tensor.shape();
|
163
|
+
if shape.len() == 2 { Some(shape[1] as usize) } else { None }
|
164
|
+
} else { None }
|
165
|
+
},
|
166
|
+
Err(_) => None
|
167
|
+
};
|
168
|
+
inferred_emb_dim.ok_or_else(|| magnus::Error::new(magnus::exception::runtime_error(), "Could not infer embedding size from model file. Please specify embedding_size explicitly."))
|
169
|
+
}
|
170
|
+
}
|
171
|
+
}
|
172
|
+
|
173
|
+
fn build_embedding_model(model_path: &Path, device: CoreDevice, embedding_model_type: EmbeddingModelType, embedding_size: Option<usize>) -> RbResult<EmbeddingModelVariant> {
|
174
|
+
use hf_hub::{api::sync::Api, Repo, RepoType};
|
175
|
+
let api = Api::new().map_err(wrap_hf_err)?;
|
176
|
+
let repo = Repo::new(model_path.to_str().unwrap().to_string(), RepoType::Model);
|
177
|
+
match embedding_model_type {
|
178
|
+
EmbeddingModelType::JinaBert => {
|
179
|
+
let model_path = api.repo(repo).get("model.safetensors").map_err(wrap_hf_err)?;
|
180
|
+
if !std::path::Path::new(&model_path).exists() {
|
181
|
+
return Err(magnus::Error::new(
|
182
|
+
magnus::exception::runtime_error(),
|
183
|
+
"model.safetensors not found after download. Only safetensors models are supported. Please ensure your model repo contains model.safetensors."
|
184
|
+
));
|
185
|
+
}
|
186
|
+
let final_emb_dim = Self::resolve_embedding_size(Path::new(&model_path), embedding_size)?;
|
187
|
+
let mut config = JinaConfig::v2_base();
|
188
|
+
config.hidden_size = final_emb_dim;
|
189
|
+
let vb = unsafe {
|
190
|
+
VarBuilder::from_mmaped_safetensors(&[model_path], CoreDType::F32, &device)
|
191
|
+
.map_err(wrap_candle_err)?
|
192
|
+
};
|
193
|
+
let model = JinaBertModel::new(vb, &config).map_err(wrap_candle_err)?;
|
194
|
+
Ok(EmbeddingModelVariant::JinaBert(model))
|
195
|
+
},
|
196
|
+
EmbeddingModelType::StandardBert => {
|
197
|
+
let model_path = api.repo(repo).get("model.safetensors").map_err(wrap_hf_err)?;
|
198
|
+
if !std::path::Path::new(&model_path).exists() {
|
199
|
+
return Err(magnus::Error::new(
|
200
|
+
magnus::exception::runtime_error(),
|
201
|
+
"model.safetensors not found after download. Only safetensors models are supported. Please ensure your model repo contains model.safetensors."
|
202
|
+
));
|
203
|
+
}
|
204
|
+
let final_emb_dim = Self::resolve_embedding_size(Path::new(&model_path), embedding_size)?;
|
205
|
+
let mut config = BertConfig::default();
|
206
|
+
config.hidden_size = final_emb_dim;
|
207
|
+
let vb = unsafe {
|
208
|
+
VarBuilder::from_mmaped_safetensors(&[model_path], CoreDType::F32, &device)
|
209
|
+
.map_err(wrap_candle_err)?
|
210
|
+
};
|
211
|
+
let model = StdBertModel::load(vb, &config).map_err(wrap_candle_err)?;
|
212
|
+
Ok(EmbeddingModelVariant::StandardBert(model))
|
213
|
+
},
|
214
|
+
EmbeddingModelType::DistilBert => {
|
215
|
+
let model_path = api.repo(repo.clone()).get("model.safetensors").map_err(wrap_hf_err)?;
|
216
|
+
if !std::path::Path::new(&model_path).exists() {
|
217
|
+
return Err(magnus::Error::new(
|
218
|
+
magnus::exception::runtime_error(),
|
219
|
+
"model.safetensors not found after download. Only safetensors models are supported. Please ensure your model repo contains model.safetensors."
|
220
|
+
));
|
221
|
+
}
|
222
|
+
let config_path = api.repo(repo.clone()).get("config.json").map_err(wrap_hf_err)?;
|
223
|
+
let config_file = std::fs::File::open(&config_path).map_err(|e| wrap_std_err(Box::new(e)))?;
|
224
|
+
let mut config: DistilBertConfig = serde_json::from_reader(config_file).map_err(|e| wrap_std_err(Box::new(e)))?;
|
225
|
+
if let Some(embedding_size) = embedding_size {
|
226
|
+
config.dim = embedding_size;
|
227
|
+
}
|
228
|
+
let vb = unsafe {
|
229
|
+
VarBuilder::from_mmaped_safetensors(&[model_path], CoreDType::F32, &device)
|
230
|
+
.map_err(wrap_candle_err)?
|
231
|
+
};
|
232
|
+
let model = DistilBertModel::load(vb, &config).map_err(wrap_candle_err)?;
|
233
|
+
Ok(EmbeddingModelVariant::DistilBert(model))
|
234
|
+
},
|
235
|
+
EmbeddingModelType::MiniLM => {
|
236
|
+
let model_path = api.repo(repo.clone()).get("model.safetensors").map_err(wrap_hf_err)?;
|
237
|
+
if !std::path::Path::new(&model_path).exists() {
|
238
|
+
return Err(magnus::Error::new(
|
239
|
+
magnus::exception::runtime_error(),
|
240
|
+
"model.safetensors not found after download. Only safetensors models are supported. Please ensure your model repo contains model.safetensors."
|
241
|
+
));
|
242
|
+
}
|
243
|
+
let config_path = api.repo(repo.clone()).get("config.json").map_err(wrap_hf_err)?;
|
244
|
+
let config_file = std::fs::File::open(&config_path).map_err(|e| wrap_std_err(Box::new(e)))?;
|
245
|
+
let mut config: BertConfig = serde_json::from_reader(config_file).map_err(|e| wrap_std_err(Box::new(e)))?;
|
246
|
+
if let Some(embedding_size) = embedding_size {
|
247
|
+
config.hidden_size = embedding_size;
|
248
|
+
}
|
249
|
+
let vb = unsafe {
|
250
|
+
VarBuilder::from_mmaped_safetensors(&[model_path], CoreDType::F32, &device)
|
251
|
+
.map_err(wrap_candle_err)?
|
252
|
+
};
|
253
|
+
let model = StdBertModel::load(vb, &config).map_err(wrap_candle_err)?;
|
254
|
+
Ok(EmbeddingModelVariant::MiniLM(model))
|
255
|
+
},
|
256
|
+
|
257
|
+
}
|
258
|
+
}
|
259
|
+
|
260
|
+
fn build_tokenizer(tokenizer_path: String) -> RbResult<Tokenizer> {
|
261
|
+
use hf_hub::{api::sync::Api, Repo, RepoType};
|
262
|
+
let tokenizer_path = Api::new()
|
263
|
+
.map_err(wrap_hf_err)?
|
264
|
+
.repo(Repo::new(
|
265
|
+
tokenizer_path,
|
266
|
+
RepoType::Model,
|
267
|
+
))
|
268
|
+
.get("tokenizer.json")
|
269
|
+
.map_err(wrap_hf_err)?;
|
270
|
+
let mut tokenizer = tokenizers::Tokenizer::from_file(tokenizer_path)
|
271
|
+
.map_err(wrap_std_err)?;
|
272
|
+
let pp = tokenizers::PaddingParams {
|
273
|
+
strategy: tokenizers::PaddingStrategy::BatchLongest,
|
274
|
+
..Default::default()
|
275
|
+
};
|
276
|
+
tokenizer.with_padding(Some(pp));
|
277
|
+
|
278
|
+
Ok(tokenizer)
|
279
|
+
}
|
280
|
+
|
281
|
+
/// Pools the embedding tensor by extracting the CLS token ([1, SEQLENGTH, DIM] -> [1, DIM])
|
282
|
+
/// &RETURNS&: Tensor
|
283
|
+
fn pooled_cls_embedding(result: &CoreTensor) -> Result<CoreTensor, Error> {
|
284
|
+
// 1) sanity-check that we have a 3D tensor
|
285
|
+
let (_batch, _seq_len, _hidden_size) = result.dims3().map_err(wrap_candle_err)?;
|
286
|
+
|
287
|
+
// 2) slice out just the first token (CLS) along the sequence axis:
|
288
|
+
// [B, seq_len, H] → [B, 1, H]
|
289
|
+
let first = result
|
290
|
+
.narrow(1, 0, 1)
|
291
|
+
.map_err(wrap_candle_err)?;
|
292
|
+
|
293
|
+
// 3) remove that length-1 axis → [B, H]
|
294
|
+
let cls = first
|
295
|
+
.squeeze(1)
|
296
|
+
.map_err(wrap_candle_err)?;
|
297
|
+
|
298
|
+
Ok(cls)
|
299
|
+
}
|
300
|
+
|
301
|
+
fn pooled_embedding(result: &CoreTensor) -> Result<CoreTensor, Error> {
|
302
|
+
let (_n_sentence, n_tokens, _hidden_size) = result.dims3().map_err(wrap_candle_err)?;
|
303
|
+
let sum = result.sum(1).map_err(wrap_candle_err)?;
|
304
|
+
let mean = (sum / (n_tokens as f64)).map_err(wrap_candle_err)?;
|
305
|
+
Ok(mean)
|
306
|
+
}
|
307
|
+
|
308
|
+
fn pooled_normalized_embedding(result: &CoreTensor) -> Result<CoreTensor, Error> {
|
309
|
+
let mean = Self::pooled_embedding(result)?;
|
310
|
+
let norm = Self::normalize_l2(&mean).map_err(wrap_candle_err)?;
|
311
|
+
Ok(norm)
|
312
|
+
}
|
313
|
+
|
314
|
+
fn compute_embeddings(
|
315
|
+
&self,
|
316
|
+
prompt: String,
|
317
|
+
model: &EmbeddingModelVariant,
|
318
|
+
tokenizer: &Tokenizer,
|
319
|
+
) -> Result<CoreTensor, Error> {
|
320
|
+
let tokens = tokenizer
|
321
|
+
.encode(prompt, true)
|
322
|
+
.map_err(wrap_std_err)?
|
323
|
+
.get_ids()
|
324
|
+
.to_vec();
|
325
|
+
let token_ids = CoreTensor::new(&tokens[..], &self.0.device)
|
326
|
+
.map_err(wrap_candle_err)?
|
327
|
+
.unsqueeze(0)
|
328
|
+
.map_err(wrap_candle_err)?;
|
329
|
+
let batch_size = token_ids.dims()[0];
|
330
|
+
let seq_len = token_ids.dims()[1];
|
331
|
+
let token_type_ids = CoreTensor::zeros(&[batch_size, seq_len], CoreDType::U32, &self.0.device)
|
332
|
+
.map_err(wrap_candle_err)?;
|
333
|
+
let attention_mask = CoreTensor::ones(&[batch_size, seq_len], CoreDType::U32, &self.0.device)
|
334
|
+
.map_err(wrap_candle_err)?;
|
335
|
+
match model {
|
336
|
+
EmbeddingModelVariant::JinaBert(model) => {
|
337
|
+
model.forward(&token_ids).map_err(wrap_candle_err)
|
338
|
+
},
|
339
|
+
EmbeddingModelVariant::StandardBert(model) => {
|
340
|
+
model.forward(&token_ids, &token_type_ids, Some(&attention_mask)).map_err(wrap_candle_err)
|
341
|
+
},
|
342
|
+
EmbeddingModelVariant::DistilBert(model) => {
|
343
|
+
model.forward(&token_ids, &attention_mask).map_err(wrap_candle_err)
|
344
|
+
},
|
345
|
+
EmbeddingModelVariant::MiniLM(model) => {
|
346
|
+
model.forward(&token_ids, &token_type_ids, Some(&attention_mask)).map_err(wrap_candle_err)
|
347
|
+
},
|
348
|
+
|
349
|
+
}
|
350
|
+
}
|
351
|
+
|
352
|
+
/// Computes an embedding for the prompt using the specified pooling method.
|
353
|
+
/// pooling_method: "pooled", "pooled_normalized", or "cls"
|
354
|
+
fn compute_embedding(
|
355
|
+
&self,
|
356
|
+
prompt: String,
|
357
|
+
model: &EmbeddingModelVariant,
|
358
|
+
tokenizer: &Tokenizer,
|
359
|
+
pooling_method: &str,
|
360
|
+
) -> Result<CoreTensor, Error> {
|
361
|
+
let result = self.compute_embeddings(prompt, model, tokenizer)?;
|
362
|
+
match pooling_method {
|
363
|
+
"pooled" => Self::pooled_embedding(&result),
|
364
|
+
"pooled_normalized" => Self::pooled_normalized_embedding(&result),
|
365
|
+
"cls" => Self::pooled_cls_embedding(&result),
|
366
|
+
_ => Err(magnus::Error::new(magnus::exception::runtime_error(), "Unknown pooling method")),
|
367
|
+
}
|
368
|
+
}
|
369
|
+
|
370
|
+
#[allow(dead_code)]
|
371
|
+
fn normalize_l2(v: &CoreTensor) -> Result<CoreTensor, candle_core::Error> {
|
372
|
+
v.broadcast_div(&v.sqr()?.sum_keepdim(1)?.sqrt()?)
|
373
|
+
}
|
374
|
+
|
375
|
+
pub fn embedding_model_type(&self) -> String {
|
376
|
+
match self.0.embedding_model_type {
|
377
|
+
Some(model_type) => format!("{:?}", model_type),
|
378
|
+
None => "nil".to_string(),
|
379
|
+
}
|
380
|
+
}
|
381
|
+
|
382
|
+
pub fn __repr__(&self) -> String {
|
383
|
+
format!(
|
384
|
+
"#<Candle::EmbeddingModel embedding_model_type: {}, model_path: {}, tokenizer_path: {}, embedding_size: {}>",
|
385
|
+
self.embedding_model_type(),
|
386
|
+
self.0.model_path.as_deref().unwrap_or("nil"),
|
387
|
+
self.0.tokenizer_path.as_deref().unwrap_or("nil"),
|
388
|
+
self.0.embedding_size.map(|x| x.to_string()).unwrap_or("nil".to_string())
|
389
|
+
)
|
390
|
+
}
|
391
|
+
|
392
|
+
pub fn __str__(&self) -> String {
|
393
|
+
self.__repr__()
|
394
|
+
}
|
395
|
+
}
|
396
|
+
|
397
|
+
pub fn init(rb_candle: RModule) -> Result<(), Error> {
|
398
|
+
let rb_embedding_model = rb_candle.define_class("EmbeddingModel", class::object())?;
|
399
|
+
rb_embedding_model.define_singleton_method("_create", function!(EmbeddingModel::new, 5))?;
|
400
|
+
// Expose embedding with an optional pooling_method argument (default: "pooled")
|
401
|
+
rb_embedding_model.define_method("_embedding", method!(EmbeddingModel::embedding, 2))?;
|
402
|
+
rb_embedding_model.define_method("embeddings", method!(EmbeddingModel::embeddings, 1))?;
|
403
|
+
rb_embedding_model.define_method("pool_embedding", method!(EmbeddingModel::pool_embedding, 1))?;
|
404
|
+
rb_embedding_model.define_method("pool_and_normalize_embedding", method!(EmbeddingModel::pool_and_normalize_embedding, 1))?;
|
405
|
+
rb_embedding_model.define_method("pool_cls_embedding", method!(EmbeddingModel::pool_cls_embedding, 1))?;
|
406
|
+
rb_embedding_model.define_method("embedding_model_type", method!(EmbeddingModel::embedding_model_type, 0))?;
|
407
|
+
rb_embedding_model.define_method("to_s", method!(EmbeddingModel::__str__, 0))?;
|
408
|
+
rb_embedding_model.define_method("inspect", method!(EmbeddingModel::__repr__, 0))?;
|
409
|
+
Ok(())
|
410
|
+
}
|
@@ -0,0 +1,13 @@
|
|
1
|
+
use magnus::Error;
|
2
|
+
|
3
|
+
pub fn wrap_std_err(err: Box<dyn std::error::Error + Send + Sync>) -> Error {
|
4
|
+
Error::new(magnus::exception::runtime_error(), err.to_string())
|
5
|
+
}
|
6
|
+
|
7
|
+
pub fn wrap_candle_err(err: candle_core::Error) -> Error {
|
8
|
+
Error::new(magnus::exception::runtime_error(), err.to_string())
|
9
|
+
}
|
10
|
+
|
11
|
+
pub fn wrap_hf_err(err: hf_hub::api::sync::ApiError) -> Error {
|
12
|
+
Error::new(magnus::exception::runtime_error(), err.to_string())
|
13
|
+
}
|