red-candle 1.3.0 → 1.4.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/Cargo.lock +11 -20
- data/ext/candle/Cargo.toml +1 -1
- data/ext/candle/src/llm/constrained_generation_test.rs +79 -0
- data/ext/candle/src/llm/text_generation.rs +40 -50
- data/ext/candle/src/ruby/device.rs +8 -7
- data/ext/candle/src/ruby/dtype.rs +3 -2
- data/ext/candle/src/ruby/embedding_model.rs +31 -14
- data/ext/candle/src/ruby/errors.rs +6 -4
- data/ext/candle/src/ruby/llm.rs +78 -68
- data/ext/candle/src/ruby/ner.rs +106 -95
- data/ext/candle/src/ruby/reranker.rs +51 -38
- data/ext/candle/src/ruby/structured.rs +61 -16
- data/ext/candle/src/ruby/tensor.rs +7 -6
- data/ext/candle/src/ruby/tokenizer.rs +101 -84
- data/lib/candle/llm.rb +77 -3
- data/lib/candle/version.rb +1 -1
- metadata +31 -6
data/ext/candle/src/ruby/reranker.rs

@@ -1,4 +1,4 @@
-use magnus::{
+use magnus::{function, method, prelude::*, Error, RModule, RArray, RHash, Ruby};
 use candle_transformers::models::bert::{BertModel, Config};
 use candle_core::{Device as CoreDevice, Tensor, IndexOp, DType};
 use candle_nn::{VarBuilder, Linear, Module, ops::sigmoid};
@@ -25,6 +25,9 @@ impl Reranker {
     }
 
     fn new_with_core_device(model_id: String, device: CoreDevice, max_length: usize) -> std::result::Result<Self, Error> {
+        let ruby = Ruby::get().unwrap();
+        let runtime_error = ruby.exception_runtime_error();
+
         let result = (|| -> std::result::Result<(BertModel, TokenizerWrapper, Linear, Linear), Box<dyn std::error::Error + Send + Sync>> {
             let api = Api::new()?;
             let repo = api.repo(Repo::new(model_id.clone(), RepoType::Model));
@@ -64,50 +67,56 @@ impl Reranker {
             Ok((model, tokenizer, pooler, classifier)) => {
                 Ok(Self { model, tokenizer, pooler, classifier, device, model_id })
             }
-            Err(e) => Err(Error::new(
+            Err(e) => Err(Error::new(runtime_error, format!("Failed to load model: {}", e))),
         }
     }
 
     /// Extract CLS embeddings from the model output, handling Metal device workarounds
     fn extract_cls_embeddings(&self, embeddings: &Tensor) -> std::result::Result<Tensor, Error> {
+        let ruby = Ruby::get().unwrap();
+        let runtime_error = ruby.exception_runtime_error();
+
         let cls_embeddings = if self.device.is_metal() {
             // Metal has issues with tensor indexing, use a different approach
             let (batch_size, seq_len, hidden_size) = embeddings.dims3()
-                .map_err(|e| Error::new(
+                .map_err(|e| Error::new(runtime_error, format!("Failed to get dims: {}", e)))?;
 
             // Reshape to [batch * seq_len, hidden] then take first hidden vectors for each batch
             let reshaped = embeddings.reshape((batch_size * seq_len, hidden_size))
-                .map_err(|e| Error::new(
+                .map_err(|e| Error::new(runtime_error, format!("Failed to reshape: {}", e)))?;
 
             // Extract CLS tokens (first token of each sequence)
             let mut cls_vecs = Vec::new();
             for i in 0..batch_size {
                 let start_idx = i * seq_len;
                 let cls_vec = reshaped.narrow(0, start_idx, 1)
-                    .map_err(|e| Error::new(
+                    .map_err(|e| Error::new(runtime_error, format!("Failed to extract CLS: {}", e)))?;
                 cls_vecs.push(cls_vec);
             }
 
             // Stack the CLS vectors
             Tensor::cat(&cls_vecs, 0)
-                .map_err(|e| Error::new(
+                .map_err(|e| Error::new(runtime_error, format!("Failed to cat CLS tokens: {}", e)))?
         } else {
             embeddings.i((.., 0))
-                .map_err(|e| Error::new(
+                .map_err(|e| Error::new(runtime_error, format!("Failed to extract CLS token: {}", e)))?
         };
 
         // Ensure tensor is contiguous for downstream operations
         cls_embeddings.contiguous()
-            .map_err(|e| Error::new(
+            .map_err(|e| Error::new(runtime_error, format!("Failed to make CLS embeddings contiguous: {}", e)))
     }
 
-    pub fn debug_tokenization(&self, query: String, document: String) -> std::result::Result<
+    pub fn debug_tokenization(&self, query: String, document: String) -> std::result::Result<RHash, Error> {
+        let ruby = Ruby::get().unwrap();
+        let runtime_error = ruby.exception_runtime_error();
+
         // Create query-document pair for cross-encoder
        let query_doc_pair: EncodeInput = (query.clone(), document.clone()).into();
 
         // Tokenize using the inner tokenizer for detailed info
         let encoding = self.tokenizer.inner().encode(query_doc_pair, true)
-            .map_err(|e| Error::new(
+            .map_err(|e| Error::new(runtime_error, format!("Tokenization failed: {}", e)))?;
 
         // Get token information
         let token_ids = encoding.get_ids().to_vec();
@@ -116,16 +125,18 @@ impl Reranker {
         let tokens = encoding.get_tokens().iter().map(|t| t.to_string()).collect::<Vec<_>>();
 
         // Create result hash
-        let result =
-        result.aset("token_ids",
-        result.aset("token_type_ids",
-        result.aset("attention_mask",
-        result.aset("tokens",
+        let result = ruby.hash_new();
+        result.aset("token_ids", ruby.ary_from_vec(token_ids.iter().map(|&id| id as i64).collect::<Vec<_>>()))?;
+        result.aset("token_type_ids", ruby.ary_from_vec(token_type_ids.iter().map(|&id| id as i64).collect::<Vec<_>>()))?;
+        result.aset("attention_mask", ruby.ary_from_vec(attention_mask.iter().map(|&mask| mask as i64).collect::<Vec<_>>()))?;
+        result.aset("tokens", ruby.ary_from_vec(tokens))?;
 
         Ok(result)
     }
 
     pub fn rerank_with_options(&self, query: String, documents: RArray, pooling_method: String, apply_sigmoid: bool) -> std::result::Result<RArray, Error> {
+        let ruby = Ruby::get().unwrap();
+        let runtime_error = ruby.exception_runtime_error();
         let documents: Vec<String> = documents.to_vec()?;
 
         // Create query-document pairs for cross-encoder
@@ -136,7 +147,7 @@ impl Reranker {
 
         // Tokenize batch using inner tokenizer for access to token type IDs
         let encodings = self.tokenizer.inner().encode_batch(query_and_docs, true)
-            .map_err(|e| Error::new(
+            .map_err(|e| Error::new(runtime_error, format!("Tokenization failed: {}", e)))?;
 
         // Convert to tensors
         let token_ids = encodings
@@ -150,15 +161,15 @@ impl Reranker {
             .collect::<Vec<_>>();
 
         let token_ids = Tensor::new(token_ids, &self.device)
-            .map_err(|e| Error::new(
+            .map_err(|e| Error::new(runtime_error, format!("Failed to create tensor: {}", e)))?;
         let token_type_ids = Tensor::new(token_type_ids, &self.device)
-            .map_err(|e| Error::new(
+            .map_err(|e| Error::new(runtime_error, format!("Failed to create token type ids tensor: {}", e)))?;
         let attention_mask = token_ids.ne(0u32)
-            .map_err(|e| Error::new(
+            .map_err(|e| Error::new(runtime_error, format!("Failed to create attention mask: {}", e)))?;
 
         // Forward pass through BERT
         let embeddings = self.model.forward(&token_ids, &token_type_ids, Some(&attention_mask))
-            .map_err(|e| Error::new(
+            .map_err(|e| Error::new(runtime_error, format!("Model forward pass failed: {}", e)))?;
 
         // Apply pooling based on the specified method
         let pooled_embeddings = match pooling_method.as_str() {
@@ -166,9 +177,9 @@ impl Reranker {
                 // Extract [CLS] token and apply pooler (dense + tanh)
                 let cls_embeddings = self.extract_cls_embeddings(&embeddings)?;
                 let pooled = self.pooler.forward(&cls_embeddings)
-                    .map_err(|e| Error::new(
+                    .map_err(|e| Error::new(runtime_error, format!("Pooler forward failed: {}", e)))?;
                 pooled.tanh()
-                    .map_err(|e| Error::new(
+                    .map_err(|e| Error::new(runtime_error, format!("Tanh activation failed: {}", e)))?
             },
             "cls" => {
                 // Just use the [CLS] token embeddings directly (no pooler layer)
@@ -177,35 +188,35 @@ impl Reranker {
             "mean" => {
                 // Mean pooling across all tokens
                 let (_batch, seq_len, _hidden) = embeddings.dims3()
-                    .map_err(|e| Error::new(
+                    .map_err(|e| Error::new(runtime_error, format!("Failed to get tensor dimensions: {}", e)))?;
                 let sum = embeddings.sum(1)
-                    .map_err(|e| Error::new(
+                    .map_err(|e| Error::new(runtime_error, format!("Failed to sum embeddings: {}", e)))?;
                 (sum / (seq_len as f64))
-                    .map_err(|e| Error::new(
+                    .map_err(|e| Error::new(runtime_error, format!("Failed to compute mean: {}", e)))?
             },
-            _ => return Err(Error::new(
+            _ => return Err(Error::new(runtime_error,
                 format!("Unknown pooling method: {}. Use 'pooler', 'cls', or 'mean'", pooling_method)))
         };
 
         // Apply classifier to get relevance scores (raw logits)
         // Ensure tensor is contiguous before linear layer
         let pooled_embeddings = pooled_embeddings.contiguous()
-            .map_err(|e| Error::new(
+            .map_err(|e| Error::new(runtime_error, format!("Failed to make pooled_embeddings contiguous: {}", e)))?;
         let logits = self.classifier.forward(&pooled_embeddings)
-            .map_err(|e| Error::new(
+            .map_err(|e| Error::new(runtime_error, format!("Classifier forward failed: {}", e)))?;
         let scores = logits.squeeze(1)
-            .map_err(|e| Error::new(
+            .map_err(|e| Error::new(runtime_error, format!("Failed to squeeze tensor: {}", e)))?;
 
         // Optionally apply sigmoid activation
         let scores = if apply_sigmoid {
             sigmoid(&scores)
-                .map_err(|e| Error::new(
+                .map_err(|e| Error::new(runtime_error, format!("Sigmoid failed: {}", e)))?
         } else {
             scores
         };
 
         let scores_vec: Vec<f32> = scores.to_vec1()
-            .map_err(|e| Error::new(
+            .map_err(|e| Error::new(runtime_error, format!("Failed to convert scores to vec: {}", e)))?;
 
         // Create tuples with document, score, and original index
         let mut ranked_docs: Vec<(String, f32, usize)> = documents
@@ -219,11 +230,11 @@ impl Reranker {
         ranked_docs.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
 
         // Build result array with [doc, score, doc_id]
-        let result_array =
+        let result_array = ruby.ary_new();
         for (doc, score, doc_id) in ranked_docs {
-            let tuple =
+            let tuple = ruby.ary_new();
             tuple.push(doc)?;
-            tuple.push(
+            tuple.push(ruby.float_from_f64(score as f64))?;
             tuple.push(doc_id)?;
             result_array.push(tuple)?;
         }
@@ -246,8 +257,9 @@ impl Reranker {
     }
 
     /// Get all options as a hash
-    pub fn options(&self) -> std::result::Result<
-        let
+    pub fn options(&self) -> std::result::Result<RHash, Error> {
+        let ruby = Ruby::get().unwrap();
+        let hash = ruby.hash_new();
         hash.aset("model_id", self.model_id.clone())?;
         hash.aset("device", self.device().__str__())?;
         Ok(hash)
@@ -255,7 +267,8 @@ impl Reranker {
 }
 
 pub fn init(rb_candle: RModule) -> std::result::Result<(), Error> {
-    let
+    let ruby = Ruby::get().unwrap();
+    let c_reranker = rb_candle.define_class("Reranker", ruby.class_object())?;
     c_reranker.define_singleton_method("_create", function!(Reranker::new, 3))?;
     c_reranker.define_method("rerank_with_options", method!(Reranker::rerank_with_options, 4))?;
     c_reranker.define_method("debug_tokenization", method!(Reranker::debug_tokenization, 2))?;
@@ -264,4 +277,4 @@ pub fn init(rb_candle: RModule) -> std::result::Result<(), Error> {
    c_reranker.define_method("device", method!(Reranker::device, 0))?;
     c_reranker.define_method("options", method!(Reranker::options, 0))?;
     Ok(())
-}
+}
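The change running through reranker.rs is the move to magnus's handle-based exception API: each Ruby-facing entry point fetches the global Ruby handle once, derives an ExceptionClass from it, and reuses that class in every map_err, instead of reaching for module-level exception helpers. A minimal sketch of the pattern, assuming only public magnus API (checked_sqrt is a hypothetical helper, not part of red-candle):

    use magnus::{Error, Ruby};

    // Hypothetical helper illustrating the handle-based pattern used
    // throughout this file; it is not red-candle API.
    fn checked_sqrt(x: f64) -> Result<f64, Error> {
        // Ruby::get() succeeds on any thread that holds the GVL, which is
        // the case inside a method invoked from Ruby.
        let ruby = Ruby::get().unwrap();
        // ExceptionClass is Copy, so one binding serves every error site.
        let runtime_error = ruby.exception_runtime_error();

        if x < 0.0 {
            return Err(Error::new(runtime_error, format!("{} has no real square root", x)));
        }
        Ok(x.sqrt())
    }

Binding the exception class once per method keeps the closures cheap and makes the error sites uniform, which is visible in how mechanical the per-line changes above are.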
data/ext/candle/src/ruby/structured.rs

@@ -1,7 +1,7 @@
-use magnus::{Error, Module, RModule, function, Object};
+use magnus::{Error, Module, RModule, function, Object, Ruby};
 use std::sync::Arc;
 
-use crate::structured::{SchemaProcessor, VocabularyAdapter, Index};
+use crate::structured::{SchemaProcessor, VocabularyAdapter, Index, Vocabulary};
 use crate::ruby::{Result, tokenizer::Tokenizer};
 
 /// Ruby wrapper for structured generation constraints
@@ -12,36 +12,81 @@ pub struct StructuredConstraint {
 }
 
 impl StructuredConstraint {
-    /// Create a constraint from a JSON schema
+    /// Create a constraint from a JSON schema using a model ID
+    /// This uses Vocabulary::from_pretrained which handles tokenizer byte encoding correctly
+    pub fn from_schema_with_model(schema: String, model_id: String) -> Result<Self> {
+        // Use tokio runtime for async vocabulary loading
+        let rt = tokio::runtime::Runtime::new()
+            .map_err(|e| Error::new(Ruby::get().unwrap().exception_runtime_error(), format!("Failed to create runtime: {}", e)))?;
+
+        let vocabulary = rt.block_on(async {
+            Vocabulary::from_pretrained(&model_id, None)
+        })
+        .map_err(|e| Error::new(Ruby::get().unwrap().exception_runtime_error(), format!("Failed to create vocabulary from model '{}': {:?}", model_id, e)))?;
+
+        let processor = SchemaProcessor::new();
+        let index = processor.process_schema(&schema, &vocabulary)
+            .map_err(|e| Error::new(Ruby::get().unwrap().exception_runtime_error(), format!("Failed to process schema: {}", e)))?;
+
+        Ok(Self { index })
+    }
+
+    /// Create a constraint from a regex pattern using a model ID
+    pub fn from_regex_with_model(pattern: String, model_id: String) -> Result<Self> {
+        // Use tokio runtime for async vocabulary loading
+        let rt = tokio::runtime::Runtime::new()
+            .map_err(|e| Error::new(Ruby::get().unwrap().exception_runtime_error(), format!("Failed to create runtime: {}", e)))?;
+
+        let vocabulary = rt.block_on(async {
+            Vocabulary::from_pretrained(&model_id, None)
+        })
+        .map_err(|e| Error::new(Ruby::get().unwrap().exception_runtime_error(), format!("Failed to create vocabulary from model '{}': {:?}", model_id, e)))?;
+
+        let processor = SchemaProcessor::new();
+        let index = processor.process_regex(&pattern, &vocabulary)
+            .map_err(|e| Error::new(Ruby::get().unwrap().exception_runtime_error(), format!("Failed to process regex: {}", e)))?;
+
+        Ok(Self { index })
+    }
+
+    /// Create a constraint from a JSON schema (legacy method using tokenizer directly)
+    /// Note: This may not handle all tokenizer byte encodings correctly
     pub fn from_schema(schema: String, tokenizer: &Tokenizer) -> Result<Self> {
         let vocabulary = VocabularyAdapter::from_tokenizer(&tokenizer.0)
-            .map_err(|e| Error::new(
-
+            .map_err(|e| Error::new(Ruby::get().unwrap().exception_runtime_error(), format!("Failed to create vocabulary: {}", e)))?;
+
         let processor = SchemaProcessor::new();
         let index = processor.process_schema(&schema, &vocabulary)
-            .map_err(|e| Error::new(
-
+            .map_err(|e| Error::new(Ruby::get().unwrap().exception_runtime_error(), format!("Failed to process schema: {}", e)))?;
+
         Ok(Self { index })
     }
-
-    /// Create a constraint from a regex pattern
+
+    /// Create a constraint from a regex pattern (legacy method using tokenizer directly)
+    /// Note: This may not handle all tokenizer byte encodings correctly
     pub fn from_regex(pattern: String, tokenizer: &Tokenizer) -> Result<Self> {
         let vocabulary = VocabularyAdapter::from_tokenizer(&tokenizer.0)
-            .map_err(|e| Error::new(
-
+            .map_err(|e| Error::new(Ruby::get().unwrap().exception_runtime_error(), format!("Failed to create vocabulary: {}", e)))?;
+
         let processor = SchemaProcessor::new();
         let index = processor.process_regex(&pattern, &vocabulary)
-            .map_err(|e| Error::new(
-
+            .map_err(|e| Error::new(Ruby::get().unwrap().exception_runtime_error(), format!("Failed to process regex: {}", e)))?;
+
         Ok(Self { index })
     }
 }
 
 pub fn init_structured(rb_candle: RModule) -> Result<()> {
-    let
-
+    let ruby = Ruby::get().unwrap();
+    let class = rb_candle.define_class("StructuredConstraint", ruby.class_object())?;
+
+    // New methods using model_id for proper vocabulary loading
+    class.define_singleton_method("from_schema_with_model", function!(StructuredConstraint::from_schema_with_model, 2))?;
+    class.define_singleton_method("from_regex_with_model", function!(StructuredConstraint::from_regex_with_model, 2))?;
+
+    // Legacy methods using tokenizer directly (may have byte encoding issues with some models)
     class.define_singleton_method("from_schema", function!(StructuredConstraint::from_schema, 2))?;
     class.define_singleton_method("from_regex", function!(StructuredConstraint::from_regex, 2))?;
-
+
     Ok(())
 }
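The new from_schema_with_model / from_regex_with_model constructors differ from the legacy pair only in how the vocabulary is obtained: by model ID via Vocabulary::from_pretrained, which handles byte-level tokenizer encodings, rather than by adapting an already-loaded tokenizer. The loading step the two methods duplicate could be factored as below; this is a sketch assuming only the Vocabulary::from_pretrained(&str, Option<_>) signature visible in the diff, and load_vocabulary is a hypothetical extraction, not gem API:

    use magnus::{Error, Ruby};
    use crate::structured::Vocabulary;

    // Hypothetical shared helper for the two *_with_model constructors.
    fn load_vocabulary(model_id: &str) -> Result<Vocabulary, Error> {
        let runtime_error = Ruby::get().unwrap().exception_runtime_error();

        // from_pretrained may download tokenizer data, so the wrappers stand
        // up a tokio runtime and block on it rather than surfacing async to
        // Ruby callers.
        let rt = tokio::runtime::Runtime::new()
            .map_err(|e| Error::new(runtime_error, format!("Failed to create runtime: {}", e)))?;

        rt.block_on(async { Vocabulary::from_pretrained(model_id, None) })
            .map_err(|e| Error::new(runtime_error, format!("Failed to create vocabulary from model '{}': {:?}", model_id, e)))
    }

Blocking on a fresh runtime per call is the simplest way to keep the Ruby-facing API synchronous; the cost is acceptable here because vocabulary loading is a one-time setup step per constraint.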
data/ext/candle/src/ruby/tensor.rs

@@ -1,5 +1,5 @@
 use magnus::prelude::*;
-use magnus::{function, method,
+use magnus::{function, method, RModule, Module, Object, Ruby};
 
 use crate::ruby::{
     errors::wrap_candle_err,
@@ -84,7 +84,7 @@ impl Tensor {
             Ok(f.to_f64() as i64)
         } else {
             Err(magnus::Error::new(
-
+                Ruby::get().unwrap().exception_type_error(),
                 "Cannot convert to i64"
             ))
         }
@@ -143,7 +143,7 @@ impl Tensor {
                 Ok(values)
             }
             _ => Err(magnus::Error::new(
-
+                Ruby::get().unwrap().exception_runtime_error(),
                 "Tensor must be F32 dtype for values_f32",
             )),
         }
@@ -153,7 +153,7 @@ impl Tensor {
     pub fn item(&self) -> Result<f64> {
         if self.0.rank() != 0 {
             return Err(magnus::Error::new(
-
+                Ruby::get().unwrap().exception_runtime_error(),
                 format!("item() can only be called on scalar tensors (rank 0), but tensor has rank {}", self.0.rank()),
             ));
         }
@@ -384,7 +384,7 @@ impl Tensor {
             let scalar = CoreTensor::from_vec(vec![i as f32], (1,), &self.0.device()).map_err(wrap_candle_err)?;
             Ok(Self(self.0.broadcast_div(&scalar).map_err(wrap_candle_err)?))
         } else {
-            Err(magnus::Error::new(
+            Err(magnus::Error::new(Ruby::get().unwrap().exception_type_error(), "Right-hand side must be a Candle::Tensor, Float, or Integer"))
         }
     }
 
@@ -650,7 +650,8 @@ impl Tensor {
     }
 
 pub fn init(rb_candle: RModule) -> Result<()> {
-    let
+    let ruby = Ruby::get().unwrap();
+    let rb_tensor = rb_candle.define_class("Tensor", ruby.class_object())?;
     rb_tensor.define_singleton_method("new", function!(Tensor::new, 3))?;
     // rb_tensor.define_singleton_method("cat", function!(Tensor::cat, 2))?;
     // rb_tensor.define_singleton_method("stack", function!(Tensor::stack, 2))?;