red-candle 1.8.0-aarch64-linux
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/Cargo.lock +5021 -0
- data/Cargo.toml +6 -0
- data/Gemfile +3 -0
- data/LICENSE +22 -0
- data/README.md +1171 -0
- data/Rakefile +167 -0
- data/bin/console +11 -0
- data/bin/setup +17 -0
- data/ext/candle/Cargo.toml +38 -0
- data/ext/candle/build.rs +117 -0
- data/ext/candle/extconf.rb +79 -0
- data/ext/candle/rustfmt.toml +63 -0
- data/ext/candle/src/gvl.rs +58 -0
- data/ext/candle/src/lib.rs +59 -0
- data/ext/candle/src/llm/constrained_generation_test.rs +395 -0
- data/ext/candle/src/llm/gemma.rs +313 -0
- data/ext/candle/src/llm/generation_config.rs +63 -0
- data/ext/candle/src/llm/glm4.rs +236 -0
- data/ext/candle/src/llm/granite.rs +308 -0
- data/ext/candle/src/llm/granitemoehybrid.rs +315 -0
- data/ext/candle/src/llm/llama.rs +396 -0
- data/ext/candle/src/llm/mistral.rs +309 -0
- data/ext/candle/src/llm/mod.rs +49 -0
- data/ext/candle/src/llm/phi.rs +369 -0
- data/ext/candle/src/llm/quantized_gguf.rs +734 -0
- data/ext/candle/src/llm/qwen.rs +261 -0
- data/ext/candle/src/llm/qwen3.rs +257 -0
- data/ext/candle/src/llm/text_generation.rs +284 -0
- data/ext/candle/src/ruby/device.rs +234 -0
- data/ext/candle/src/ruby/dtype.rs +39 -0
- data/ext/candle/src/ruby/embedding_model.rs +477 -0
- data/ext/candle/src/ruby/errors.rs +16 -0
- data/ext/candle/src/ruby/llm.rs +730 -0
- data/ext/candle/src/ruby/mod.rs +24 -0
- data/ext/candle/src/ruby/ner.rs +444 -0
- data/ext/candle/src/ruby/reranker.rs +488 -0
- data/ext/candle/src/ruby/result.rs +3 -0
- data/ext/candle/src/ruby/structured.rs +92 -0
- data/ext/candle/src/ruby/tensor.rs +731 -0
- data/ext/candle/src/ruby/tokenizer.rs +343 -0
- data/ext/candle/src/ruby/utils.rs +96 -0
- data/ext/candle/src/ruby/vlm.rs +330 -0
- data/ext/candle/src/structured/integration_test.rs +130 -0
- data/ext/candle/src/structured/mod.rs +31 -0
- data/ext/candle/src/structured/schema_processor.rs +215 -0
- data/ext/candle/src/structured/vocabulary_adapter.rs +152 -0
- data/ext/candle/src/structured/vocabulary_adapter_real_test.rs +66 -0
- data/ext/candle/src/structured/vocabulary_adapter_simple_test.rs +70 -0
- data/ext/candle/src/tokenizer/loader.rs +108 -0
- data/ext/candle/src/tokenizer/mod.rs +104 -0
- data/ext/candle/tests/device_tests.rs +43 -0
- data/ext/candle/tests/tensor_tests.rs +162 -0
- data/lib/candle/3.1/candle.so +0 -0
- data/lib/candle/3.2/candle.so +0 -0
- data/lib/candle/3.3/candle.so +0 -0
- data/lib/candle/3.4/candle.so +0 -0
- data/lib/candle/4.0/candle.so +0 -0
- data/lib/candle/agent.rb +68 -0
- data/lib/candle/build_info.rb +67 -0
- data/lib/candle/device_utils.rb +10 -0
- data/lib/candle/embedding_model.rb +75 -0
- data/lib/candle/embedding_model_type.rb +31 -0
- data/lib/candle/llm.rb +595 -0
- data/lib/candle/logger.rb +149 -0
- data/lib/candle/ner.rb +368 -0
- data/lib/candle/reranker.rb +45 -0
- data/lib/candle/tensor.rb +99 -0
- data/lib/candle/tokenizer.rb +139 -0
- data/lib/candle/tool.rb +47 -0
- data/lib/candle/tool_call_parser.rb +57 -0
- data/lib/candle/version.rb +5 -0
- data/lib/candle/vlm.rb +31 -0
- data/lib/candle.rb +29 -0
- data/lib/red-candle.rb +1 -0
- metadata +309 -0
|
@@ -0,0 +1,284 @@
|
|
|
1
|
+
use candle_core::{Result as CandleResult, Tensor};
|
|
2
|
+
use candle_transformers::generation::LogitsProcessor;
|
|
3
|
+
use std::sync::Arc;
|
|
4
|
+
|
|
5
|
+
use super::GenerationConfig;
|
|
6
|
+
use crate::structured::Index;
|
|
7
|
+
|
|
8
|
+
/// Helper struct for text generation process
|
|
9
|
+
pub struct TextGeneration {
|
|
10
|
+
logits_processor: LogitsProcessor,
|
|
11
|
+
tokens: Vec<u32>,
|
|
12
|
+
eos_token_id: Option<u32>,
|
|
13
|
+
repetition_penalty: f32,
|
|
14
|
+
repetition_penalty_last_n: usize,
|
|
15
|
+
constraint: Option<Arc<Index>>,
|
|
16
|
+
constraint_state: Option<u32>,
|
|
17
|
+
constraint_completed: bool,
|
|
18
|
+
tokens_since_constraint_start: usize,
|
|
19
|
+
}
|
|
20
|
+
|
|
21
|
+
impl TextGeneration {
|
|
22
|
+
pub fn new(config: &GenerationConfig) -> Self {
|
|
23
|
+
let logits_processor = LogitsProcessor::new(config.seed, Some(config.temperature), config.top_p);
|
|
24
|
+
|
|
25
|
+
let mut text_gen = Self {
|
|
26
|
+
logits_processor,
|
|
27
|
+
tokens: Vec::new(),
|
|
28
|
+
eos_token_id: None,
|
|
29
|
+
repetition_penalty: config.repetition_penalty,
|
|
30
|
+
repetition_penalty_last_n: config.repetition_penalty_last_n,
|
|
31
|
+
constraint: None,
|
|
32
|
+
constraint_state: None,
|
|
33
|
+
constraint_completed: false,
|
|
34
|
+
tokens_since_constraint_start: 0,
|
|
35
|
+
};
|
|
36
|
+
|
|
37
|
+
// Set constraint if provided
|
|
38
|
+
if let Some(ref constraint) = config.constraint {
|
|
39
|
+
text_gen.set_constraint(Arc::clone(constraint));
|
|
40
|
+
}
|
|
41
|
+
|
|
42
|
+
text_gen
|
|
43
|
+
}
|
|
44
|
+
|
|
45
|
+
pub fn set_eos_token_id(&mut self, eos_token_id: u32) {
|
|
46
|
+
self.eos_token_id = Some(eos_token_id);
|
|
47
|
+
}
|
|
48
|
+
|
|
49
|
+
pub fn set_tokens(&mut self, tokens: Vec<u32>) {
|
|
50
|
+
self.tokens = tokens;
|
|
51
|
+
}
|
|
52
|
+
|
|
53
|
+
pub fn get_tokens(&self) -> &[u32] {
|
|
54
|
+
&self.tokens
|
|
55
|
+
}
|
|
56
|
+
|
|
57
|
+
pub fn push_token(&mut self, token: u32) {
|
|
58
|
+
self.tokens.push(token);
|
|
59
|
+
}
|
|
60
|
+
|
|
61
|
+
pub fn set_constraint(&mut self, constraint: Arc<Index>) {
|
|
62
|
+
// Initialize with the first state
|
|
63
|
+
self.constraint_state = Some(constraint.initial_state());
|
|
64
|
+
self.constraint = Some(constraint);
|
|
65
|
+
self.constraint_completed = false;
|
|
66
|
+
self.tokens_since_constraint_start = self.tokens.len();
|
|
67
|
+
}
|
|
68
|
+
|
|
69
|
+
/// Apply constraints to logits by masking disallowed tokens
|
|
70
|
+
fn apply_constraints(&self, logits: &mut Tensor) -> CandleResult<()> {
|
|
71
|
+
if let (Some(ref constraint_index), Some(state)) = (&self.constraint, self.constraint_state) {
|
|
72
|
+
let device = logits.device();
|
|
73
|
+
let vocab_size = logits.dims1()?;
|
|
74
|
+
|
|
75
|
+
// Get allowed tokens from the constraint index for current state
|
|
76
|
+
if let Some(allowed_tokens) = constraint_index.allowed_tokens(&state) {
|
|
77
|
+
// Create a mask where allowed tokens have value 0 and others have -inf
|
|
78
|
+
let mut mask = vec![f32::NEG_INFINITY; vocab_size];
|
|
79
|
+
for &token_id in &allowed_tokens {
|
|
80
|
+
if (token_id as usize) < vocab_size {
|
|
81
|
+
mask[token_id as usize] = 0.0;
|
|
82
|
+
}
|
|
83
|
+
}
|
|
84
|
+
|
|
85
|
+
// Apply mask to logits
|
|
86
|
+
let mask_tensor = Tensor::from_vec(mask, vocab_size, device)?;
|
|
87
|
+
*logits = logits.add(&mask_tensor)?;
|
|
88
|
+
}
|
|
89
|
+
}
|
|
90
|
+
Ok(())
|
|
91
|
+
}
|
|
92
|
+
|
|
93
|
+
/// Apply repetition penalty to logits
|
|
94
|
+
pub fn apply_repetition_penalty(
|
|
95
|
+
&self,
|
|
96
|
+
logits: &mut Tensor,
|
|
97
|
+
penalty: f32,
|
|
98
|
+
context_size: usize,
|
|
99
|
+
) -> CandleResult<()> {
|
|
100
|
+
if penalty == 1.0 {
|
|
101
|
+
return Ok(());
|
|
102
|
+
}
|
|
103
|
+
|
|
104
|
+
let device = logits.device();
|
|
105
|
+
let vocab_size = logits.dims1()?;
|
|
106
|
+
|
|
107
|
+
// Get the context tokens to apply penalty to
|
|
108
|
+
let start = self.tokens.len().saturating_sub(context_size);
|
|
109
|
+
let context_tokens = &self.tokens[start..];
|
|
110
|
+
|
|
111
|
+
// Apply penalty to tokens that appear in the context
|
|
112
|
+
let mut logits_vec = logits.to_vec1::<f32>()?;
|
|
113
|
+
for &token in context_tokens {
|
|
114
|
+
if (token as usize) < vocab_size {
|
|
115
|
+
let idx = token as usize;
|
|
116
|
+
if logits_vec[idx] > 0.0 {
|
|
117
|
+
logits_vec[idx] /= penalty;
|
|
118
|
+
} else {
|
|
119
|
+
logits_vec[idx] *= penalty;
|
|
120
|
+
}
|
|
121
|
+
}
|
|
122
|
+
}
|
|
123
|
+
|
|
124
|
+
*logits = Tensor::from_vec(logits_vec, vocab_size, device)?;
|
|
125
|
+
Ok(())
|
|
126
|
+
}
|
|
127
|
+
|
|
128
|
+
/// Sample next token from logits
|
|
129
|
+
pub fn sample_next_token(
|
|
130
|
+
&mut self,
|
|
131
|
+
logits: &Tensor,
|
|
132
|
+
) -> CandleResult<u32> {
|
|
133
|
+
let mut logits = logits.clone();
|
|
134
|
+
|
|
135
|
+
// Apply repetition penalty using stored parameters
|
|
136
|
+
if self.repetition_penalty != 1.0 {
|
|
137
|
+
self.apply_repetition_penalty(&mut logits, self.repetition_penalty, self.repetition_penalty_last_n)?;
|
|
138
|
+
}
|
|
139
|
+
|
|
140
|
+
// Apply constraints if active
|
|
141
|
+
self.apply_constraints(&mut logits)?;
|
|
142
|
+
|
|
143
|
+
// Sample token
|
|
144
|
+
let next_token = self.logits_processor.sample(&logits)?;
|
|
145
|
+
self.tokens.push(next_token);
|
|
146
|
+
|
|
147
|
+
// Update constraint state if active
|
|
148
|
+
if let (Some(ref constraint_index), Some(current_state)) = (&self.constraint, self.constraint_state) {
|
|
149
|
+
// Get the next state
|
|
150
|
+
let next_state = constraint_index.next_state(¤t_state, &next_token);
|
|
151
|
+
|
|
152
|
+
// Check if we're transitioning to a state with no allowed tokens (completion)
|
|
153
|
+
if !self.constraint_completed && self.tokens.len() > self.tokens_since_constraint_start {
|
|
154
|
+
// Check if next state has no allowed tokens at all - this is definitive completion
|
|
155
|
+
if let Some(next_state_val) = next_state {
|
|
156
|
+
if let Some(allowed) = constraint_index.allowed_tokens(&next_state_val) {
|
|
157
|
+
if allowed.is_empty() {
|
|
158
|
+
self.constraint_completed = true;
|
|
159
|
+
}
|
|
160
|
+
// Only mark as complete if ONLY EOS is allowed (not just if EOS is one of many options)
|
|
161
|
+
else if let Some(eos) = self.eos_token_id {
|
|
162
|
+
if allowed.len() == 1 && allowed.contains(&eos) {
|
|
163
|
+
self.constraint_completed = true;
|
|
164
|
+
}
|
|
165
|
+
}
|
|
166
|
+
} else {
|
|
167
|
+
// None means no tokens allowed - constraint is complete
|
|
168
|
+
self.constraint_completed = true;
|
|
169
|
+
}
|
|
170
|
+
}
|
|
171
|
+
}
|
|
172
|
+
|
|
173
|
+
self.constraint_state = next_state;
|
|
174
|
+
}
|
|
175
|
+
|
|
176
|
+
Ok(next_token)
|
|
177
|
+
}
|
|
178
|
+
|
|
179
|
+
/// Check if the constraint is satisfied (reached a valid completion state)
|
|
180
|
+
pub fn is_constraint_satisfied(&self) -> bool {
|
|
181
|
+
// If we've explicitly marked the constraint as completed, return true
|
|
182
|
+
if self.constraint_completed {
|
|
183
|
+
return true;
|
|
184
|
+
}
|
|
185
|
+
|
|
186
|
+
// Also check the current state
|
|
187
|
+
if let (Some(ref constraint_index), Some(state)) = (&self.constraint, self.constraint_state) {
|
|
188
|
+
// Check if the constraint has reached a state where it MUST end
|
|
189
|
+
// This happens when there are no more allowed tokens (constraint fully satisfied)
|
|
190
|
+
if let Some(allowed) = constraint_index.allowed_tokens(&state) {
|
|
191
|
+
// If no tokens are allowed, the constraint is fully satisfied
|
|
192
|
+
if allowed.is_empty() {
|
|
193
|
+
return true;
|
|
194
|
+
}
|
|
195
|
+
|
|
196
|
+
// For JSON schemas, check if ONLY the EOS token is allowed
|
|
197
|
+
// This means we've generated a complete, valid JSON structure
|
|
198
|
+
// Don't treat EOS as a satisfaction signal if other tokens are also allowed
|
|
199
|
+
if let Some(eos) = self.eos_token_id {
|
|
200
|
+
if allowed.len() == 1 && allowed.contains(&eos) {
|
|
201
|
+
return true;
|
|
202
|
+
}
|
|
203
|
+
}
|
|
204
|
+
} else {
|
|
205
|
+
// None means no tokens allowed - constraint is satisfied
|
|
206
|
+
return true;
|
|
207
|
+
}
|
|
208
|
+
}
|
|
209
|
+
false
|
|
210
|
+
}
|
|
211
|
+
|
|
212
|
+
/// Check if the constraint is satisfied when stop_on_match is true
|
|
213
|
+
/// NOTE: For JSON schemas, this should only return true when the JSON structure is complete,
|
|
214
|
+
/// not just because we're in a state with many allowed tokens (like inside a string).
|
|
215
|
+
pub fn is_constraint_satisfied_stop_on_match(&self) -> bool {
|
|
216
|
+
// When stop_on_match is true, we stop as soon as the constraint is completed
|
|
217
|
+
if self.constraint_completed {
|
|
218
|
+
return true;
|
|
219
|
+
}
|
|
220
|
+
|
|
221
|
+
// For JSON and other structured outputs, don't use the "large allowed set" heuristic.
|
|
222
|
+
// Instead, only consider the constraint satisfied when:
|
|
223
|
+
// 1. There are no allowed tokens (definitive completion)
|
|
224
|
+
// 2. Only EOS is allowed (completion with optional termination)
|
|
225
|
+
if let (Some(ref constraint_index), Some(state)) = (&self.constraint, self.constraint_state) {
|
|
226
|
+
if let Some(allowed) = constraint_index.allowed_tokens(&state) {
|
|
227
|
+
// No more tokens allowed - definitely complete
|
|
228
|
+
if allowed.is_empty() {
|
|
229
|
+
return true;
|
|
230
|
+
}
|
|
231
|
+
|
|
232
|
+
// Only EOS is allowed - complete JSON structure
|
|
233
|
+
if let Some(eos) = self.eos_token_id {
|
|
234
|
+
if allowed.len() == 1 && allowed.contains(&eos) {
|
|
235
|
+
return true;
|
|
236
|
+
}
|
|
237
|
+
}
|
|
238
|
+
} else {
|
|
239
|
+
// None means no tokens allowed - constraint is complete
|
|
240
|
+
return true;
|
|
241
|
+
}
|
|
242
|
+
}
|
|
243
|
+
|
|
244
|
+
false
|
|
245
|
+
}
|
|
246
|
+
|
|
247
|
+
/// Check if we should stop generation
|
|
248
|
+
pub fn should_stop(&self, token: u32, max_length: usize) -> bool {
|
|
249
|
+
if self.tokens.len() >= max_length {
|
|
250
|
+
return true;
|
|
251
|
+
}
|
|
252
|
+
|
|
253
|
+
if let Some(eos) = self.eos_token_id {
|
|
254
|
+
if token == eos {
|
|
255
|
+
return true;
|
|
256
|
+
}
|
|
257
|
+
}
|
|
258
|
+
|
|
259
|
+
// Check if we've reached a final state in constraint
|
|
260
|
+
// A state is considered final if it has no allowed tokens
|
|
261
|
+
if let (Some(ref constraint_index), Some(state)) = (&self.constraint, self.constraint_state) {
|
|
262
|
+
if let Some(allowed) = constraint_index.allowed_tokens(&state) {
|
|
263
|
+
if allowed.is_empty() {
|
|
264
|
+
return true;
|
|
265
|
+
}
|
|
266
|
+
} else {
|
|
267
|
+
// None means no tokens allowed - we're done
|
|
268
|
+
return true;
|
|
269
|
+
}
|
|
270
|
+
}
|
|
271
|
+
|
|
272
|
+
false
|
|
273
|
+
}
|
|
274
|
+
|
|
275
|
+
/// Check if the generated text ends with any stop sequence
|
|
276
|
+
pub fn check_stop_sequences(&self, text: &str, stop_sequences: &[String]) -> bool {
|
|
277
|
+
for seq in stop_sequences {
|
|
278
|
+
if text.ends_with(seq) {
|
|
279
|
+
return true;
|
|
280
|
+
}
|
|
281
|
+
}
|
|
282
|
+
false
|
|
283
|
+
}
|
|
284
|
+
}
|
|
@@ -0,0 +1,234 @@
|
|
|
1
|
+
use magnus::Error;
|
|
2
|
+
use magnus::{function, method, RModule, Module, Object, Ruby};
|
|
3
|
+
|
|
4
|
+
use ::candle_core::Device as CoreDevice;
|
|
5
|
+
use crate::ruby::Result;
|
|
6
|
+
|
|
7
|
+
#[cfg(any(feature = "cuda", feature = "metal"))]
|
|
8
|
+
use crate::ruby::errors::wrap_candle_err;
|
|
9
|
+
|
|
10
|
+
/// Process-wide slot holding a lazily created backend device (empty until first use).
#[cfg(any(feature = "cuda", feature = "metal"))]
type DeviceCache = std::sync::Mutex<Option<CoreDevice>>;

/// Cached CUDA device 0.
#[cfg(feature = "cuda")]
static CUDA_DEVICE: DeviceCache = std::sync::Mutex::new(None);

/// Cached Metal device 0.
#[cfg(feature = "metal")]
static METAL_DEVICE: DeviceCache = std::sync::Mutex::new(None);
|
|
15
|
+
|
|
16
|
+
/// Get list of available devices based on compile-time features.
///
/// `"cpu"` is always listed first, followed by `"cuda"` and then `"metal"` when the
/// matching feature is compiled in and the `force_cpu` cfg is not set. No runtime probing
/// is done.
pub fn available_devices() -> Vec<String> {
    let backends = [
        ("cpu", true),
        ("cuda", cfg!(all(feature = "cuda", not(force_cpu)))),
        ("metal", cfg!(all(feature = "metal", not(force_cpu)))),
    ];

    backends
        .iter()
        .filter(|(_, enabled)| *enabled)
        .map(|(name, _)| name.to_string())
        .collect()
}
|
|
36
|
+
|
|
37
|
+
/// Get the default device based on what's available
|
|
38
|
+
pub fn default_device() -> Device {
|
|
39
|
+
// Return based on compiled features, not detection
|
|
40
|
+
#[cfg(all(feature = "metal", not(force_cpu)))]
|
|
41
|
+
{
|
|
42
|
+
Device::Metal
|
|
43
|
+
}
|
|
44
|
+
|
|
45
|
+
#[cfg(all(feature = "cuda", not(feature = "metal"), not(force_cpu)))]
|
|
46
|
+
{
|
|
47
|
+
Device::Cuda
|
|
48
|
+
}
|
|
49
|
+
|
|
50
|
+
#[cfg(not(any(all(feature = "metal", not(force_cpu)), all(feature = "cuda", not(feature = "metal"), not(force_cpu)))))]
|
|
51
|
+
{
|
|
52
|
+
Device::Cpu
|
|
53
|
+
}
|
|
54
|
+
}
|
|
55
|
+
|
|
56
|
+
/// Get the best available device by checking runtime availability
|
|
57
|
+
pub fn best_device() -> Device {
|
|
58
|
+
// Try devices in order of preference
|
|
59
|
+
|
|
60
|
+
#[cfg(feature = "metal")]
|
|
61
|
+
{
|
|
62
|
+
// Check if Metal is actually available at runtime
|
|
63
|
+
if CoreDevice::new_metal(0).is_ok() {
|
|
64
|
+
return Device::Metal;
|
|
65
|
+
}
|
|
66
|
+
}
|
|
67
|
+
|
|
68
|
+
#[cfg(feature = "cuda")]
|
|
69
|
+
{
|
|
70
|
+
// Check if CUDA is actually available at runtime
|
|
71
|
+
if CoreDevice::new_cuda(0).is_ok() {
|
|
72
|
+
return Device::Cuda;
|
|
73
|
+
}
|
|
74
|
+
}
|
|
75
|
+
|
|
76
|
+
// Always fall back to CPU
|
|
77
|
+
Device::Cpu
|
|
78
|
+
}
|
|
79
|
+
|
|
80
|
+
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
|
|
81
|
+
#[magnus::wrap(class = "Candle::Device")]
|
|
82
|
+
pub enum Device {
|
|
83
|
+
Cpu,
|
|
84
|
+
Cuda,
|
|
85
|
+
Metal,
|
|
86
|
+
}
|
|
87
|
+
|
|
88
|
+
impl Device {
|
|
89
|
+
/// Create a CPU device
|
|
90
|
+
pub fn cpu() -> Self {
|
|
91
|
+
Self::Cpu
|
|
92
|
+
}
|
|
93
|
+
|
|
94
|
+
/// Get the best available device
|
|
95
|
+
pub fn best() -> Self {
|
|
96
|
+
best_device()
|
|
97
|
+
}
|
|
98
|
+
|
|
99
|
+
/// Create a CUDA device (GPU)
|
|
100
|
+
pub fn cuda() -> Result<Self> {
|
|
101
|
+
#[cfg(not(feature = "cuda"))]
|
|
102
|
+
{
|
|
103
|
+
return Err(Error::new(
|
|
104
|
+
Ruby::get().unwrap().exception_runtime_error(),
|
|
105
|
+
"CUDA support not compiled in. Rebuild with CUDA available.",
|
|
106
|
+
));
|
|
107
|
+
}
|
|
108
|
+
|
|
109
|
+
#[cfg(feature = "cuda")]
|
|
110
|
+
Ok(Self::Cuda)
|
|
111
|
+
}
|
|
112
|
+
|
|
113
|
+
/// Create a Metal device (Apple GPU)
|
|
114
|
+
pub fn metal() -> Result<Self> {
|
|
115
|
+
#[cfg(not(feature = "metal"))]
|
|
116
|
+
{
|
|
117
|
+
return Err(Error::new(
|
|
118
|
+
Ruby::get().unwrap().exception_runtime_error(),
|
|
119
|
+
"Metal support not compiled in. Rebuild on macOS.",
|
|
120
|
+
));
|
|
121
|
+
}
|
|
122
|
+
|
|
123
|
+
#[cfg(feature = "metal")]
|
|
124
|
+
Ok(Self::Metal)
|
|
125
|
+
}
|
|
126
|
+
|
|
127
|
+
pub fn from_device(device: &CoreDevice) -> Self {
|
|
128
|
+
match device {
|
|
129
|
+
CoreDevice::Cpu => Self::Cpu,
|
|
130
|
+
CoreDevice::Cuda(_) => Self::Cuda,
|
|
131
|
+
CoreDevice::Metal(_) => Self::Metal,
|
|
132
|
+
}
|
|
133
|
+
}
|
|
134
|
+
|
|
135
|
+
pub fn as_device(&self) -> Result<CoreDevice> {
|
|
136
|
+
match self {
|
|
137
|
+
Self::Cpu => Ok(CoreDevice::Cpu),
|
|
138
|
+
Self::Cuda => {
|
|
139
|
+
#[cfg(not(feature = "cuda"))]
|
|
140
|
+
{
|
|
141
|
+
return Err(Error::new(
|
|
142
|
+
Ruby::get().unwrap().exception_runtime_error(),
|
|
143
|
+
"CUDA support not compiled in. Rebuild with CUDA available.",
|
|
144
|
+
));
|
|
145
|
+
}
|
|
146
|
+
|
|
147
|
+
#[cfg(feature = "cuda")]
|
|
148
|
+
{
|
|
149
|
+
let mut device = CUDA_DEVICE.lock().unwrap();
|
|
150
|
+
if let Some(device) = device.as_ref() {
|
|
151
|
+
return Ok(device.clone());
|
|
152
|
+
};
|
|
153
|
+
// Note: new_cuda() is used here (not cuda_if_available) because
|
|
154
|
+
// we want to fail if CUDA isn't available at runtime, not fall back to CPU
|
|
155
|
+
let d = CoreDevice::new_cuda(0).map_err(wrap_candle_err)?;
|
|
156
|
+
*device = Some(d.clone());
|
|
157
|
+
Ok(d)
|
|
158
|
+
}
|
|
159
|
+
}
|
|
160
|
+
Self::Metal => {
|
|
161
|
+
#[cfg(not(feature = "metal"))]
|
|
162
|
+
{
|
|
163
|
+
return Err(Error::new(
|
|
164
|
+
Ruby::get().unwrap().exception_runtime_error(),
|
|
165
|
+
"Metal support not compiled in. Rebuild on macOS.",
|
|
166
|
+
));
|
|
167
|
+
}
|
|
168
|
+
|
|
169
|
+
#[cfg(feature = "metal")]
|
|
170
|
+
{
|
|
171
|
+
let mut device = METAL_DEVICE.lock().unwrap();
|
|
172
|
+
if let Some(device) = device.as_ref() {
|
|
173
|
+
return Ok(device.clone());
|
|
174
|
+
};
|
|
175
|
+
let d = CoreDevice::new_metal(0).map_err(wrap_candle_err)?;
|
|
176
|
+
*device = Some(d.clone());
|
|
177
|
+
Ok(d)
|
|
178
|
+
}
|
|
179
|
+
}
|
|
180
|
+
}
|
|
181
|
+
}
|
|
182
|
+
|
|
183
|
+
pub fn __repr__(&self) -> String {
|
|
184
|
+
match self {
|
|
185
|
+
Self::Cpu => "cpu".to_string(),
|
|
186
|
+
Self::Cuda => "cuda".to_string(),
|
|
187
|
+
Self::Metal => "metal".to_string(),
|
|
188
|
+
}
|
|
189
|
+
}
|
|
190
|
+
|
|
191
|
+
pub fn __str__(&self) -> String {
|
|
192
|
+
self.__repr__()
|
|
193
|
+
}
|
|
194
|
+
|
|
195
|
+
pub fn __eq__(&self, other: &Device) -> bool {
|
|
196
|
+
self == other
|
|
197
|
+
}
|
|
198
|
+
}
|
|
199
|
+
|
|
200
|
+
impl magnus::TryConvert for Device {
|
|
201
|
+
fn try_convert(val: magnus::Value) -> Result<Self> {
|
|
202
|
+
// First check if it's already a wrapped Device object
|
|
203
|
+
if let Ok(device) = <magnus::typed_data::Obj<Device> as magnus::TryConvert>::try_convert(val) {
|
|
204
|
+
return Ok(*device);
|
|
205
|
+
}
|
|
206
|
+
|
|
207
|
+
// Otherwise try to convert from string
|
|
208
|
+
let device = magnus::RString::try_convert(val)?;
|
|
209
|
+
let device = unsafe { device.as_str() }.unwrap();
|
|
210
|
+
let device = match device {
|
|
211
|
+
"cpu" => Device::Cpu,
|
|
212
|
+
"cuda" => Device::Cuda,
|
|
213
|
+
"metal" => Device::Metal,
|
|
214
|
+
_ => return Err(Error::new(Ruby::get().unwrap().exception_arg_error(), "invalid device")),
|
|
215
|
+
};
|
|
216
|
+
Ok(device)
|
|
217
|
+
}
|
|
218
|
+
}
|
|
219
|
+
|
|
220
|
+
pub fn init(rb_candle: RModule) -> Result<()> {
|
|
221
|
+
let ruby = Ruby::get().unwrap();
|
|
222
|
+
let rb_device = rb_candle.define_class("Device", ruby.class_object())?;
|
|
223
|
+
rb_device.define_singleton_method("cpu", function!(Device::cpu, 0))?;
|
|
224
|
+
rb_device.define_singleton_method("cuda", function!(Device::cuda, 0))?;
|
|
225
|
+
rb_device.define_singleton_method("metal", function!(Device::metal, 0))?;
|
|
226
|
+
rb_device.define_singleton_method("available_devices", function!(available_devices, 0))?;
|
|
227
|
+
rb_device.define_singleton_method("default", function!(default_device, 0))?;
|
|
228
|
+
rb_device.define_singleton_method("best", function!(best_device, 0))?;
|
|
229
|
+
rb_device.define_method("to_s", method!(Device::__str__, 0))?;
|
|
230
|
+
rb_device.define_method("inspect", method!(Device::__repr__, 0))?;
|
|
231
|
+
rb_device.define_method("==", method!(Device::__eq__, 1))?;
|
|
232
|
+
Ok(())
|
|
233
|
+
}
|
|
234
|
+
|
|
@@ -0,0 +1,39 @@
|
|
|
1
|
+
use magnus::value::ReprValue;
|
|
2
|
+
use magnus::{method, RModule, Module, Ruby};
|
|
3
|
+
|
|
4
|
+
use ::candle_core::DType as CoreDType;
|
|
5
|
+
use crate::ruby::Result;
|
|
6
|
+
|
|
7
|
+
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
|
|
8
|
+
#[magnus::wrap(class = "Candle::DType", free_immediately, size)]
|
|
9
|
+
|
|
10
|
+
/// A `candle` dtype.
|
|
11
|
+
pub struct DType(pub CoreDType);
|
|
12
|
+
|
|
13
|
+
impl DType {
|
|
14
|
+
pub fn __repr__(&self) -> String {
|
|
15
|
+
format!("{:?}", self.0)
|
|
16
|
+
}
|
|
17
|
+
|
|
18
|
+
pub fn __str__(&self) -> String {
|
|
19
|
+
self.__repr__()
|
|
20
|
+
}
|
|
21
|
+
}
|
|
22
|
+
|
|
23
|
+
impl DType {
|
|
24
|
+
pub fn from_rbobject(dtype: magnus::Symbol) -> Result<Self> {
|
|
25
|
+
let dtype = unsafe { dtype.to_s() }.unwrap().into_owned();
|
|
26
|
+
use std::str::FromStr;
|
|
27
|
+
let dtype = CoreDType::from_str(&dtype).unwrap();
|
|
28
|
+
Ok(Self(dtype))
|
|
29
|
+
}
|
|
30
|
+
}
|
|
31
|
+
|
|
32
|
+
pub fn init(rb_candle: RModule) -> Result<()> {
|
|
33
|
+
let ruby = Ruby::get().unwrap();
|
|
34
|
+
let rb_dtype = rb_candle.define_class("DType", ruby.class_object())?;
|
|
35
|
+
rb_dtype.define_method("to_s", method!(DType::__str__, 0))?;
|
|
36
|
+
rb_dtype.define_method("inspect", method!(DType::__repr__, 0))?;
|
|
37
|
+
Ok(())
|
|
38
|
+
}
|
|
39
|
+
|