red-candle 0.0.3 → 0.0.5

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,773 +1,103 @@
-use magnus::{function, method, prelude::*, Error, Ruby};
-use std::sync::Arc;
+use magnus::{function, method, prelude::*, Ruby};
 
-use half::{bf16, f16};
+use crate::model::{candle_utils, RbModel, RbDType, RbDevice, RbQTensor, RbResult, RbTensor};
 
-use ::candle_core::{quantized::QTensor, DType, Device, Tensor, WithDType};
-
-type PyResult<T> = Result<T, Error>;
-
-pub fn wrap_err(err: candle_core::Error) -> Error {
-    Error::new(magnus::exception::runtime_error(), err.to_string())
-}
-
-// #[derive(Clone, Debug)]
-// struct RbShape(Vec<usize>);
-
-// impl magnus::TryConvert for RbShape {
-//     fn try_convert(val: magnus::Value) -> PyResult<Self> {
-//         let ary = magnus::RArray::try_convert(val)?;
-//         let shape = ary
-//             .each()
-//             .map(|v| magnus::Integer::try_convert(v?).map(|v| v.to_usize().unwrap()))
-//             .collect::<PyResult<Vec<_>>>()?;
-//         Ok(Self(shape))
-//     }
-// }
-
-// impl magnus::IntoValue for RbShape {
-//     fn into_value_with(self, ruby: &Ruby) -> magnus::Value {
-//         let ary = magnus::RArray::from_vec(self.0);
-//         ary.into_value_with(ruby)
-//     }
-//}
-
-#[derive(Clone, Debug)]
-#[magnus::wrap(class = "Candle::Tensor", free_immediately, size)]
-/// A `candle` tensor.
-struct PyTensor(Tensor);
-
-impl std::ops::Deref for PyTensor {
-    type Target = Tensor;
-
-    fn deref(&self) -> &Self::Target {
-        &self.0
-    }
-}
-
-#[derive(Clone, Copy, Debug, PartialEq, Eq)]
-#[magnus::wrap(class = "Candle::DType", free_immediately, size)]
-/// A `candle` dtype.
-struct PyDType(DType);
-
-impl PyDType {
-    fn __repr__(&self) -> String {
-        format!("{:?}", self.0)
-    }
-
-    fn __str__(&self) -> String {
-        self.__repr__()
-    }
-}
-
-impl PyDType {
-    fn from_pyobject(dtype: magnus::Symbol) -> PyResult<Self> {
-        let dtype = unsafe { dtype.to_s() }.unwrap().into_owned();
-        use std::str::FromStr;
-        let dtype = DType::from_str(&dtype).unwrap();
-        Ok(Self(dtype))
-    }
-}
-
-static CUDA_DEVICE: std::sync::Mutex<Option<Device>> = std::sync::Mutex::new(None);
-
-#[derive(Clone, Copy, Debug, PartialEq, Eq)]
-#[magnus::wrap(class = "Candle::Device")]
-enum PyDevice {
-    Cpu,
-    Cuda,
-}
-
-impl PyDevice {
-    fn from_device(device: &Device) -> Self {
-        match device {
-            Device::Cpu => Self::Cpu,
-            Device::Cuda(_) => Self::Cuda,
-        }
-    }
-
-    fn as_device(&self) -> PyResult<Device> {
-        match self {
-            Self::Cpu => Ok(Device::Cpu),
-            Self::Cuda => {
-                let mut device = CUDA_DEVICE.lock().unwrap();
-                if let Some(device) = device.as_ref() {
-                    return Ok(device.clone());
-                };
-                let d = Device::new_cuda(0).map_err(wrap_err)?;
-                *device = Some(d.clone());
-                Ok(d)
-            }
-        }
-    }
-
-    fn __repr__(&self) -> String {
-        match self {
-            Self::Cpu => "cpu".to_string(),
-            Self::Cuda => "cuda".to_string(),
-        }
-    }
-
-    fn __str__(&self) -> String {
-        self.__repr__()
-    }
-}
-
-impl magnus::TryConvert for PyDevice {
-    fn try_convert(val: magnus::Value) -> PyResult<Self> {
-        let device = magnus::RString::try_convert(val)?;
-        let device = unsafe { device.as_str() }.unwrap();
-        let device = match device {
-            "cpu" => PyDevice::Cpu,
-            "cuda" => PyDevice::Cuda,
-            _ => return Err(Error::new(magnus::exception::arg_error(), "invalid device")),
-        };
-        Ok(device)
-    }
-}
-
-fn actual_index(t: &Tensor, dim: usize, index: i64) -> candle_core::Result<usize> {
-    let dim = t.dim(dim)?;
-    if 0 <= index {
-        let index = index as usize;
-        if dim <= index {
-            candle_core::bail!("index {index} is too large for tensor dimension {dim}")
-        }
-        Ok(index)
-    } else {
-        if (dim as i64) < -index {
-            candle_core::bail!("index {index} is too low for tensor dimension {dim}")
-        }
-        Ok((dim as i64 + index) as usize)
-    }
-}
-
-fn actual_dim(t: &Tensor, dim: i64) -> candle_core::Result<usize> {
-    let rank = t.rank();
-    if 0 <= dim {
-        let dim = dim as usize;
-        if rank <= dim {
-            candle_core::bail!("dimension index {dim} is too large for tensor rank {rank}")
-        }
-        Ok(dim)
-    } else {
-        if (rank as i64) < -dim {
-            candle_core::bail!("dimension index {dim} is too low for tensor rank {rank}")
-        }
-        Ok((rank as i64 + dim) as usize)
-    }
-}
-impl PyTensor {
-    fn new(array: magnus::RArray, dtype: Option<magnus::Symbol>) -> PyResult<Self> {
-        let dtype = dtype
-            .map(|dtype| PyDType::from_pyobject(dtype))
-            .unwrap_or(Ok(PyDType(DType::F32)))?;
-        // FIXME: Do not use `to_f64` here.
-        let array = array
-            .each()
-            .map(|v| magnus::Float::try_convert(v?).map(|v| v.to_f64()))
-            .collect::<PyResult<Vec<_>>>()?;
-        Ok(Self(
-            Tensor::new(array.as_slice(), &Device::Cpu)
-                .map_err(wrap_err)?
-                .to_dtype(dtype.0)
-                .map_err(wrap_err)?,
-        ))
-    }
-
-    /// Gets the tensor's shape.
-    /// &RETURNS&: Tuple[int]
-    fn shape(&self) -> Vec<usize> {
-        self.0.dims().to_vec()
-    }
-
-    /// Gets the tensor's strides.
-    /// &RETURNS&: Tuple[int]
-    fn stride(&self) -> Vec<usize> {
-        self.0.stride().to_vec()
-    }
-
-    /// Gets the tensor's dtype.
-    /// &RETURNS&: DType
-    fn dtype(&self) -> PyDType {
-        PyDType(self.0.dtype())
-    }
-
-    /// Gets the tensor's device.
-    /// &RETURNS&: Device
-    fn device(&self) -> PyDevice {
-        PyDevice::from_device(self.0.device())
-    }
-
-    /// Gets the tensor's rank.
-    /// &RETURNS&: int
-    fn rank(&self) -> usize {
-        self.0.rank()
-    }
-
-    fn __repr__(&self) -> String {
-        format!("{}", self.0)
-    }
-
-    fn __str__(&self) -> String {
-        self.__repr__()
-    }
-
-    /// Performs the `sin` operation on the tensor.
-    /// &RETURNS&: Tensor
-    fn sin(&self) -> PyResult<Self> {
-        Ok(PyTensor(self.0.sin().map_err(wrap_err)?))
-    }
-
-    /// Performs the `cos` operation on the tensor.
-    /// &RETURNS&: Tensor
-    fn cos(&self) -> PyResult<Self> {
-        Ok(PyTensor(self.0.cos().map_err(wrap_err)?))
-    }
-
-    /// Performs the `log` operation on the tensor.
-    /// &RETURNS&: Tensor
-    fn log(&self) -> PyResult<Self> {
-        Ok(PyTensor(self.0.log().map_err(wrap_err)?))
-    }
-
-    /// Squares the tensor.
-    /// &RETURNS&: Tensor
-    fn sqr(&self) -> PyResult<Self> {
-        Ok(PyTensor(self.0.sqr().map_err(wrap_err)?))
-    }
-
-    /// Calculates the square root of the tensor.
-    /// &RETURNS&: Tensor
-    fn sqrt(&self) -> PyResult<Self> {
-        Ok(PyTensor(self.0.sqrt().map_err(wrap_err)?))
-    }
-
-    /// Get the `recip` of the tensor.
-    /// &RETURNS&: Tensor
-    fn recip(&self) -> PyResult<Self> {
-        Ok(PyTensor(self.0.recip().map_err(wrap_err)?))
-    }
-
-    /// Performs the `exp` operation on the tensor.
-    /// &RETURNS&: Tensor
-    fn exp(&self) -> PyResult<Self> {
-        Ok(PyTensor(self.0.exp().map_err(wrap_err)?))
-    }
-
-    /// Performs the `pow` operation on the tensor with the given exponent.
-    /// &RETURNS&: Tensor
-    fn powf(&self, p: f64) -> PyResult<Self> {
-        Ok(PyTensor(self.0.powf(p).map_err(wrap_err)?))
-    }
-
-    /// Select values for the input tensor at the target indexes across the specified dimension.
-    ///
-    /// The `indexes` is argument is an int tensor with a single dimension.
-    /// The output has the same number of dimension as the `self` input. The target dimension of
-    /// the output has length the length of `indexes` and the values are taken from `self` using
-    /// the index from `indexes`. Other dimensions have the same number of elements as the input
-    /// tensor.
-    /// &RETURNS&: Tensor
-    fn index_select(&self, rhs: &Self, dim: i64) -> PyResult<Self> {
-        let dim = actual_dim(self, dim).map_err(wrap_err)?;
-        Ok(PyTensor(self.0.index_select(rhs, dim).map_err(wrap_err)?))
-    }
-
-    /// Performs a matrix multiplication between the two tensors.
-    /// &RETURNS&: Tensor
-    fn matmul(&self, rhs: &Self) -> PyResult<Self> {
-        Ok(PyTensor(self.0.matmul(rhs).map_err(wrap_err)?))
-    }
-
-    /// Adds the two tensors, while broadcasting the right-hand-side tensor to match the shape of the left-hand-side tensor.
-    /// &RETURNS&: Tensor
-    fn broadcast_add(&self, rhs: &Self) -> PyResult<Self> {
-        Ok(PyTensor(self.0.broadcast_add(rhs).map_err(wrap_err)?))
-    }
-
-    /// Subtracts the two tensors, while broadcasting the right-hand-side tensor to match the shape of the left-hand-side tensor.
-    /// &RETURNS&: Tensor
-    fn broadcast_sub(&self, rhs: &Self) -> PyResult<Self> {
-        Ok(PyTensor(self.0.broadcast_sub(rhs).map_err(wrap_err)?))
-    }
-
-    /// Multiplies the two tensors, while broadcasting the right-hand-side tensor to match the shape of the left-hand-side tensor.
-    /// &RETURNS&: Tensor
-    fn broadcast_mul(&self, rhs: &Self) -> PyResult<Self> {
-        Ok(PyTensor(self.0.broadcast_mul(rhs).map_err(wrap_err)?))
-    }
-
-    /// Divides the two tensors, while broadcasting the right-hand-side tensor to match the shape of the left-hand-side tensor.
-    /// &RETURNS&: Tensor
-    fn broadcast_div(&self, rhs: &Self) -> PyResult<Self> {
-        Ok(PyTensor(self.0.broadcast_div(rhs).map_err(wrap_err)?))
-    }
-
-    /// Returns a tensor with the same shape as the input tensor, the values are taken from
-    /// `on_true` if the input tensor value is not zero, and `on_false` at the positions where the
-    /// input tensor is equal to zero.
-    /// &RETURNS&: Tensor
-    fn where_cond(&self, on_true: &Self, on_false: &Self) -> PyResult<Self> {
-        Ok(PyTensor(
-            self.0.where_cond(on_true, on_false).map_err(wrap_err)?,
-        ))
-    }
-
-    /// Add two tensors.
-    /// &RETURNS&: Tensor
-    fn __add__(&self, rhs: &PyTensor) -> PyResult<Self> {
-        Ok(Self(self.0.add(&rhs.0).map_err(wrap_err)?))
-    }
-
-    /// Multiply two tensors.
-    /// &RETURNS&: Tensor
-    fn __mul__(&self, rhs: &PyTensor) -> PyResult<Self> {
-        Ok(Self(self.0.mul(&rhs.0).map_err(wrap_err)?))
-    }
-
-    /// Subtract two tensors.
-    /// &RETURNS&: Tensor
-    fn __sub__(&self, rhs: &PyTensor) -> PyResult<Self> {
-        Ok(Self(self.0.sub(&rhs.0).map_err(wrap_err)?))
-    }
-
-    /// Divide two tensors.
-    /// &RETURNS&: Tensor
-    fn __truediv__(&self, rhs: &PyTensor) -> PyResult<Self> {
-        Ok(Self(self.0.div(&rhs.0).map_err(wrap_err)?))
-    }
-
-    /// Reshapes the tensor to the given shape.
-    /// &RETURNS&: Tensor
-    fn reshape(&self, shape: Vec<usize>) -> PyResult<Self> {
-        Ok(PyTensor(self.0.reshape(shape).map_err(wrap_err)?))
-    }
-
-    /// Broadcasts the tensor to the given shape.
-    /// &RETURNS&: Tensor
-    fn broadcast_as(&self, shape: Vec<usize>) -> PyResult<Self> {
-        Ok(PyTensor(self.0.broadcast_as(shape).map_err(wrap_err)?))
-    }
-
-    /// Broadcasts the tensor to the given shape, adding new dimensions on the left.
-    /// &RETURNS&: Tensor
-    fn broadcast_left(&self, shape: Vec<usize>) -> PyResult<Self> {
-        Ok(PyTensor(self.0.broadcast_left(shape).map_err(wrap_err)?))
-    }
-
-    /// Creates a new tensor with the specified dimension removed if its size was one.
-    /// &RETURNS&: Tensor
-    fn squeeze(&self, dim: i64) -> PyResult<Self> {
-        let dim = actual_dim(self, dim).map_err(wrap_err)?;
-        Ok(PyTensor(self.0.squeeze(dim).map_err(wrap_err)?))
-    }
-
-    /// Creates a new tensor with a dimension of size one inserted at the specified position.
-    /// &RETURNS&: Tensor
-    fn unsqueeze(&self, dim: usize) -> PyResult<Self> {
-        Ok(PyTensor(self.0.unsqueeze(dim).map_err(wrap_err)?))
-    }
-
-    /// Gets the value at the specified index.
-    /// &RETURNS&: Tensor
-    fn get(&self, index: i64) -> PyResult<Self> {
-        let index = actual_index(self, 0, index).map_err(wrap_err)?;
-        Ok(PyTensor(self.0.get(index).map_err(wrap_err)?))
-    }
-
-    /// Returns a tensor that is a transposed version of the input, the given dimensions are swapped.
-    /// &RETURNS&: Tensor
-    fn transpose(&self, dim1: usize, dim2: usize) -> PyResult<Self> {
-        Ok(PyTensor(self.0.transpose(dim1, dim2).map_err(wrap_err)?))
-    }
-
-    /// Returns a new tensor that is a narrowed version of the input, the dimension `dim`
-    /// ranges from `start` to `start + len`.
-    /// &RETURNS&: Tensor
-    fn narrow(&self, dim: i64, start: i64, len: usize) -> PyResult<Self> {
-        let dim = actual_dim(self, dim).map_err(wrap_err)?;
-        let start = actual_index(self, dim, start).map_err(wrap_err)?;
-        Ok(PyTensor(self.0.narrow(dim, start, len).map_err(wrap_err)?))
-    }
-
-    /// Returns the indices of the maximum value(s) across the selected dimension.
-    /// &RETURNS&: Tensor
-    fn argmax_keepdim(&self, dim: i64) -> PyResult<Self> {
-        let dim = actual_dim(self, dim).map_err(wrap_err)?;
-        Ok(PyTensor(self.0.argmax_keepdim(dim).map_err(wrap_err)?))
-    }
-
-    /// Returns the indices of the minimum value(s) across the selected dimension.
-    /// &RETURNS&: Tensor
-    fn argmin_keepdim(&self, dim: i64) -> PyResult<Self> {
-        let dim = actual_dim(self, dim).map_err(wrap_err)?;
-        Ok(PyTensor(self.0.argmin_keepdim(dim).map_err(wrap_err)?))
-    }
-
-    /// Gathers the maximum value across the selected dimension.
-    /// &RETURNS&: Tensor
-    fn max_keepdim(&self, dim: i64) -> PyResult<Self> {
-        let dim = actual_dim(self, dim).map_err(wrap_err)?;
-        Ok(PyTensor(self.0.max_keepdim(dim).map_err(wrap_err)?))
-    }
-
-    /// Gathers the minimum value across the selected dimension.
-    /// &RETURNS&: Tensor
-    fn min_keepdim(&self, dim: i64) -> PyResult<Self> {
-        let dim = actual_dim(self, dim).map_err(wrap_err)?;
-        Ok(PyTensor(self.0.min_keepdim(dim).map_err(wrap_err)?))
-    }
-
-    // fn eq(&self, rhs: &Self) -> PyResult<Self> {
-    //     Ok(PyTensor(self.0.eq(rhs).map_err(wrap_err)?))
-    // }
-
-    // fn ne(&self, rhs: &Self) -> PyResult<Self> {
-    //     Ok(PyTensor(self.0.ne(rhs).map_err(wrap_err)?))
-    // }
-
-    // fn lt(&self, rhs: &Self) -> PyResult<Self> {
-    //     Ok(PyTensor(self.0.lt(rhs).map_err(wrap_err)?))
-    // }
-
-    // fn gt(&self, rhs: &Self) -> PyResult<Self> {
-    //     Ok(PyTensor(self.0.gt(rhs).map_err(wrap_err)?))
-    // }
-
-    // fn ge(&self, rhs: &Self) -> PyResult<Self> {
-    //     Ok(PyTensor(self.0.ge(rhs).map_err(wrap_err)?))
-    // }
-
-    // fn le(&self, rhs: &Self) -> PyResult<Self> {
-    //     Ok(PyTensor(self.0.le(rhs).map_err(wrap_err)?))
-    // }
-
-    /// Returns the sum of the tensor.
-    /// &RETURNS&: Tensor
-    fn sum_all(&self) -> PyResult<Self> {
-        Ok(PyTensor(self.0.sum_all().map_err(wrap_err)?))
-    }
-
-    /// Returns the mean of the tensor.
-    /// &RETURNS&: Tensor
-    fn mean_all(&self) -> PyResult<Self> {
-        let elements = self.0.elem_count();
-        let sum = self.0.sum_all().map_err(wrap_err)?;
-        let mean = (sum / elements as f64).map_err(wrap_err)?;
-        Ok(PyTensor(mean))
-    }
-
-    /// Flattens the tensor on the dimension indexes from `dim` (inclusive) to the last dimension.
-    /// &RETURNS&: Tensor
-    fn flatten_from(&self, dim: i64) -> PyResult<Self> {
-        let dim = actual_dim(self, dim).map_err(wrap_err)?;
-        Ok(PyTensor(self.0.flatten_from(dim).map_err(wrap_err)?))
-    }
-
-    ///Flattens the tensor on the dimension indexes from `0` to `dim` (inclusive).
-    /// &RETURNS&: Tensor
-    fn flatten_to(&self, dim: i64) -> PyResult<Self> {
-        let dim = actual_dim(self, dim).map_err(wrap_err)?;
-        Ok(PyTensor(self.0.flatten_to(dim).map_err(wrap_err)?))
-    }
-
-    /// Flattens the tensor into a 1D tensor.
-    /// &RETURNS&: Tensor
-    fn flatten_all(&self) -> PyResult<Self> {
-        Ok(PyTensor(self.0.flatten_all().map_err(wrap_err)?))
-    }
-
-    /// Transposes the tensor.
-    /// &RETURNS&: Tensor
-    fn t(&self) -> PyResult<Self> {
-        Ok(PyTensor(self.0.t().map_err(wrap_err)?))
-    }
-
-    /// Makes the tensor contiguous in memory.
-    /// &RETURNS&: Tensor
-    fn contiguous(&self) -> PyResult<Self> {
-        Ok(PyTensor(self.0.contiguous().map_err(wrap_err)?))
-    }
-
-    /// Returns true if the tensor is contiguous in C order.
-    /// &RETURNS&: bool
-    fn is_contiguous(&self) -> bool {
-        self.0.is_contiguous()
-    }
-
-    /// Returns true if the tensor is contiguous in Fortran order.
-    /// &RETURNS&: bool
-    fn is_fortran_contiguous(&self) -> bool {
-        self.0.is_fortran_contiguous()
-    }
-
-    /// Detach the tensor from the computation graph.
-    /// &RETURNS&: Tensor
-    fn detach(&self) -> PyResult<Self> {
-        Ok(PyTensor(self.0.detach().map_err(wrap_err)?))
-    }
-
-    /// Returns a copy of the tensor.
-    /// &RETURNS&: Tensor
-    fn copy(&self) -> PyResult<Self> {
-        Ok(PyTensor(self.0.copy().map_err(wrap_err)?))
-    }
-
-    /// Convert the tensor to a new dtype.
-    /// &RETURNS&: Tensor
-    fn to_dtype(&self, dtype: magnus::Symbol) -> PyResult<Self> {
-        let dtype = PyDType::from_pyobject(dtype)?;
-        Ok(PyTensor(self.0.to_dtype(dtype.0).map_err(wrap_err)?))
-    }
-
-    /// Move the tensor to a new device.
-    /// &RETURNS&: Tensor
-    fn to_device(&self, device: PyDevice) -> PyResult<Self> {
-        let device = device.as_device()?;
-        Ok(PyTensor(self.0.to_device(&device).map_err(wrap_err)?))
-    }
-}
-
-impl PyTensor {
-    // fn cat(tensors: Vec<PyTensor>, dim: i64) -> PyResult<PyTensor> {
-    //     if tensors.is_empty() {
-    //         return Err(Error::new(
-    //             magnus::exception::arg_error(),
-    //             "empty input to cat",
-    //         ));
-    //     }
-    //     let dim = actual_dim(&tensors[0].0, dim).map_err(wrap_err)?;
-    //     let tensors = tensors.into_iter().map(|t| t.0).collect::<Vec<_>>();
-    //     let tensor = Tensor::cat(&tensors, dim).map_err(wrap_err)?;
-    //     Ok(PyTensor(tensor))
-    // }
-
-    // fn stack(tensors: Vec<PyTensor>, dim: usize) -> PyResult<Self> {
-    //     let tensors = tensors.into_iter().map(|t| t.0).collect::<Vec<_>>();
-    //     let tensor = Tensor::stack(&tensors, dim).map_err(wrap_err)?;
-    //     Ok(Self(tensor))
-    // }
-
-    /// Creates a new tensor with random values.
-    /// &RETURNS&: Tensor
-    fn rand(shape: Vec<usize>) -> PyResult<Self> {
-        let device = PyDevice::Cpu.as_device()?;
-        Ok(Self(
-            Tensor::rand(0f32, 1f32, shape, &device).map_err(wrap_err)?,
-        ))
-    }
-
-    /// Creates a new tensor with random values from a normal distribution.
-    /// &RETURNS&: Tensor
-    fn randn(shape: Vec<usize>) -> PyResult<Self> {
-        let device = PyDevice::Cpu.as_device()?;
-        Ok(Self(
-            Tensor::randn(0f32, 1f32, shape, &device).map_err(wrap_err)?,
-        ))
-    }
-
-    /// Creates a new tensor filled with ones.
-    /// &RETURNS&: Tensor
-    fn ones(shape: Vec<usize>) -> PyResult<Self> {
-        let device = PyDevice::Cpu.as_device()?;
-        Ok(Self(
-            Tensor::ones(shape, DType::F32, &device).map_err(wrap_err)?,
-        ))
-    }
-    /// Creates a new tensor filled with zeros.
-    /// &RETURNS&: Tensor
-    fn zeros(shape: Vec<usize>) -> PyResult<Self> {
-        let device = PyDevice::Cpu.as_device()?;
-        Ok(Self(
-            Tensor::zeros(shape, DType::F32, &device).map_err(wrap_err)?,
-        ))
-    }
-}
-
-#[derive(Debug)]
-#[magnus::wrap(class = "Candle::QTensor", free_immediately, size)]
-/// A quantized tensor.
-struct PyQTensor(Arc<QTensor>);
-
-impl std::ops::Deref for PyQTensor {
-    type Target = QTensor;
-
-    fn deref(&self) -> &Self::Target {
-        self.0.as_ref()
-    }
-}
-
-impl PyQTensor {
-    ///Gets the tensors quantized dtype.
-    /// &RETURNS&: str
-    fn ggml_dtype(&self) -> String {
-        format!("{:?}", self.0.dtype())
-    }
-
-    ///Gets the rank of the tensor.
-    /// &RETURNS&: int
-    fn rank(&self) -> usize {
-        self.0.rank()
-    }
-
-    ///Gets the shape of the tensor.
-    /// &RETURNS&: Tuple[int]
-    fn shape(&self) -> Vec<usize> {
-        self.0.shape().dims().to_vec()
-    }
-
-    fn __repr__(&self) -> String {
-        format!("{:?}", self.0)
-    }
-
-    fn __str__(&self) -> String {
-        self.__repr__()
-    }
-
-    /// Dequantizes the tensor.
-    /// &RETURNS&: Tensor
-    fn dequantize(&self) -> PyResult<PyTensor> {
-        let tensor = self.0.dequantize(&Device::Cpu).map_err(wrap_err)?;
-        Ok(PyTensor(tensor))
-    }
-
-    // fn matmul_t(&self, lhs: &PyTensor) -> PyResult<PyTensor> {
-    //     let qmatmul = ::candle_core::quantized::QMatMul::from_arc(self.0.clone());
-    //     let res = qmatmul.forward(lhs).map_err(wrap_err)?;
-    //     Ok(PyTensor(res))
-    // }
-}
-
-/// Returns true if the 'cuda' backend is available.
-/// &RETURNS&: bool
-fn cuda_is_available() -> bool {
-    candle_core::utils::cuda_is_available()
-}
-
-/// Returns true if candle was compiled with 'accelerate' support.
-/// &RETURNS&: bool
-fn has_accelerate() -> bool {
-    candle_core::utils::has_accelerate()
-}
-
-/// Returns true if candle was compiled with MKL support.
-/// &RETURNS&: bool
-fn has_mkl() -> bool {
-    candle_core::utils::has_mkl()
-}
-
-/// Returns the number of threads used by the candle.
-/// &RETURNS&: int
-fn get_num_threads() -> usize {
-    candle_core::utils::get_num_threads()
-}
-
-fn candle_utils(rb_candle: magnus::RModule) -> Result<(), Error> {
-    let rb_utils = rb_candle.define_module("Utils")?;
-    rb_utils.define_singleton_method("cuda_is_available", function!(cuda_is_available, 0))?;
-    rb_utils.define_singleton_method("get_num_threads", function!(get_num_threads, 0))?;
-    rb_utils.define_singleton_method("has_accelerate", function!(has_accelerate, 0))?;
-    rb_utils.define_singleton_method("has_mkl", function!(has_mkl, 0))?;
-    Ok(())
-}
-
-/// Applies the Softmax function to a given tensor.#
-/// &RETURNS&: Tensor
-fn softmax(tensor: PyTensor, dim: i64) -> PyResult<PyTensor> {
-    let dim = actual_dim(&tensor, dim).map_err(wrap_err)?;
-    let sm = candle_nn::ops::softmax(&tensor.0, dim).map_err(wrap_err)?;
-    Ok(PyTensor(sm))
-}
-
-/// Applies the Sigmoid Linear Unit (SiLU) function to a given tensor.
-/// &RETURNS&: Tensor
-fn silu(tensor: PyTensor) -> PyResult<PyTensor> {
-    let s = candle_nn::ops::silu(&tensor.0).map_err(wrap_err)?;
-    Ok(PyTensor(s))
-}
+pub mod model;
 
 #[magnus::init]
-fn init(ruby: &Ruby) -> PyResult<()> {
+fn init(ruby: &Ruby) -> RbResult<()> {
     let rb_candle = ruby.define_module("Candle")?;
     candle_utils(rb_candle)?;
     let rb_tensor = rb_candle.define_class("Tensor", Ruby::class_object(ruby))?;
-    rb_tensor.define_singleton_method("new", function!(PyTensor::new, 2))?;
-    // rb_tensor.define_singleton_method("cat", function!(PyTensor::cat, 2))?;
-    // rb_tensor.define_singleton_method("stack", function!(PyTensor::stack, 2))?;
-    rb_tensor.define_singleton_method("rand", function!(PyTensor::rand, 1))?;
-    rb_tensor.define_singleton_method("randn", function!(PyTensor::randn, 1))?;
-    rb_tensor.define_singleton_method("ones", function!(PyTensor::ones, 1))?;
-    rb_tensor.define_singleton_method("zeros", function!(PyTensor::zeros, 1))?;
-    rb_tensor.define_method("shape", method!(PyTensor::shape, 0))?;
-    rb_tensor.define_method("stride", method!(PyTensor::stride, 0))?;
-    rb_tensor.define_method("dtype", method!(PyTensor::dtype, 0))?;
-    rb_tensor.define_method("device", method!(PyTensor::device, 0))?;
-    rb_tensor.define_method("rank", method!(PyTensor::rank, 0))?;
-    rb_tensor.define_method("sin", method!(PyTensor::sin, 0))?;
-    rb_tensor.define_method("cos", method!(PyTensor::cos, 0))?;
-    rb_tensor.define_method("log", method!(PyTensor::log, 0))?;
-    rb_tensor.define_method("sqr", method!(PyTensor::sqr, 0))?;
-    rb_tensor.define_method("sqrt", method!(PyTensor::sqrt, 0))?;
-    rb_tensor.define_method("recip", method!(PyTensor::recip, 0))?;
-    rb_tensor.define_method("exp", method!(PyTensor::exp, 0))?;
-    rb_tensor.define_method("powf", method!(PyTensor::powf, 1))?;
-    rb_tensor.define_method("index_select", method!(PyTensor::index_select, 2))?;
-    rb_tensor.define_method("matmul", method!(PyTensor::matmul, 1))?;
-    rb_tensor.define_method("broadcast_add", method!(PyTensor::broadcast_add, 1))?;
-    rb_tensor.define_method("broadcast_sub", method!(PyTensor::broadcast_sub, 1))?;
-    rb_tensor.define_method("broadcast_mul", method!(PyTensor::broadcast_mul, 1))?;
-    rb_tensor.define_method("broadcast_div", method!(PyTensor::broadcast_div, 1))?;
-    rb_tensor.define_method("where_cond", method!(PyTensor::where_cond, 2))?;
-    rb_tensor.define_method("+", method!(PyTensor::__add__, 1))?;
-    rb_tensor.define_method("*", method!(PyTensor::__mul__, 1))?;
-    rb_tensor.define_method("-", method!(PyTensor::__sub__, 1))?;
-    rb_tensor.define_method("reshape", method!(PyTensor::reshape, 1))?;
-    rb_tensor.define_method("broadcast_as", method!(PyTensor::broadcast_as, 1))?;
-    rb_tensor.define_method("broadcast_left", method!(PyTensor::broadcast_left, 1))?;
-    rb_tensor.define_method("squeeze", method!(PyTensor::squeeze, 1))?;
-    rb_tensor.define_method("unsqueeze", method!(PyTensor::unsqueeze, 1))?;
-    rb_tensor.define_method("get", method!(PyTensor::get, 1))?;
-    rb_tensor.define_method("transpose", method!(PyTensor::transpose, 2))?;
-    rb_tensor.define_method("narrow", method!(PyTensor::narrow, 3))?;
-    rb_tensor.define_method("argmax_keepdim", method!(PyTensor::argmax_keepdim, 1))?;
-    rb_tensor.define_method("argmin_keepdim", method!(PyTensor::argmin_keepdim, 1))?;
-    rb_tensor.define_method("max_keepdim", method!(PyTensor::max_keepdim, 1))?;
-    rb_tensor.define_method("min_keepdim", method!(PyTensor::min_keepdim, 1))?;
-    // rb_tensor.define_method("eq", method!(PyTensor::eq, 1))?;
-    // rb_tensor.define_method("ne", method!(PyTensor::ne, 1))?;
-    // rb_tensor.define_method("lt", method!(PyTensor::lt, 1))?;
-    // rb_tensor.define_method("gt", method!(PyTensor::gt, 1))?;
-    // rb_tensor.define_method("ge", method!(PyTensor::ge, 1))?;
-    // rb_tensor.define_method("le", method!(PyTensor::le, 1))?;
-    rb_tensor.define_method("sum_all", method!(PyTensor::sum_all, 0))?;
-    rb_tensor.define_method("mean_all", method!(PyTensor::mean_all, 0))?;
-    rb_tensor.define_method("flatten_from", method!(PyTensor::flatten_from, 1))?;
-    rb_tensor.define_method("flatten_to", method!(PyTensor::flatten_to, 1))?;
-    rb_tensor.define_method("flatten_all", method!(PyTensor::flatten_all, 0))?;
-    rb_tensor.define_method("t", method!(PyTensor::t, 0))?;
-    rb_tensor.define_method("contiguous", method!(PyTensor::contiguous, 0))?;
-    rb_tensor.define_method("is_contiguous", method!(PyTensor::is_contiguous, 0))?;
+    rb_tensor.define_singleton_method("new", function!(RbTensor::new, 2))?;
+    // rb_tensor.define_singleton_method("cat", function!(RbTensor::cat, 2))?;
+    // rb_tensor.define_singleton_method("stack", function!(RbTensor::stack, 2))?;
+    rb_tensor.define_singleton_method("rand", function!(RbTensor::rand, 1))?;
+    rb_tensor.define_singleton_method("randn", function!(RbTensor::randn, 1))?;
+    rb_tensor.define_singleton_method("ones", function!(RbTensor::ones, 1))?;
+    rb_tensor.define_singleton_method("zeros", function!(RbTensor::zeros, 1))?;
+    rb_tensor.define_method("values", method!(RbTensor::values, 0))?;
+    rb_tensor.define_method("shape", method!(RbTensor::shape, 0))?;
+    rb_tensor.define_method("stride", method!(RbTensor::stride, 0))?;
+    rb_tensor.define_method("dtype", method!(RbTensor::dtype, 0))?;
+    rb_tensor.define_method("device", method!(RbTensor::device, 0))?;
+    rb_tensor.define_method("rank", method!(RbTensor::rank, 0))?;
+    rb_tensor.define_method("elem_count", method!(RbTensor::elem_count, 0))?;
+    rb_tensor.define_method("sin", method!(RbTensor::sin, 0))?;
+    rb_tensor.define_method("cos", method!(RbTensor::cos, 0))?;
+    rb_tensor.define_method("log", method!(RbTensor::log, 0))?;
+    rb_tensor.define_method("sqr", method!(RbTensor::sqr, 0))?;
+    rb_tensor.define_method("sqrt", method!(RbTensor::sqrt, 0))?;
+    rb_tensor.define_method("recip", method!(RbTensor::recip, 0))?;
+    rb_tensor.define_method("exp", method!(RbTensor::exp, 0))?;
+    rb_tensor.define_method("powf", method!(RbTensor::powf, 1))?;
+    rb_tensor.define_method("index_select", method!(RbTensor::index_select, 2))?;
+    rb_tensor.define_method("matmul", method!(RbTensor::matmul, 1))?;
+    rb_tensor.define_method("broadcast_add", method!(RbTensor::broadcast_add, 1))?;
+    rb_tensor.define_method("broadcast_sub", method!(RbTensor::broadcast_sub, 1))?;
+    rb_tensor.define_method("broadcast_mul", method!(RbTensor::broadcast_mul, 1))?;
+    rb_tensor.define_method("broadcast_div", method!(RbTensor::broadcast_div, 1))?;
+    rb_tensor.define_method("where_cond", method!(RbTensor::where_cond, 2))?;
+    rb_tensor.define_method("+", method!(RbTensor::__add__, 1))?;
+    rb_tensor.define_method("*", method!(RbTensor::__mul__, 1))?;
+    rb_tensor.define_method("-", method!(RbTensor::__sub__, 1))?;
+    rb_tensor.define_method("reshape", method!(RbTensor::reshape, 1))?;
+    rb_tensor.define_method("broadcast_as", method!(RbTensor::broadcast_as, 1))?;
+    rb_tensor.define_method("broadcast_left", method!(RbTensor::broadcast_left, 1))?;
+    rb_tensor.define_method("squeeze", method!(RbTensor::squeeze, 1))?;
+    rb_tensor.define_method("unsqueeze", method!(RbTensor::unsqueeze, 1))?;
+    rb_tensor.define_method("get", method!(RbTensor::get, 1))?;
+    rb_tensor.define_method("[]", method!(RbTensor::get, 1))?;
+    rb_tensor.define_method("transpose", method!(RbTensor::transpose, 2))?;
+    rb_tensor.define_method("narrow", method!(RbTensor::narrow, 3))?;
+    rb_tensor.define_method("argmax_keepdim", method!(RbTensor::argmax_keepdim, 1))?;
+    rb_tensor.define_method("argmin_keepdim", method!(RbTensor::argmin_keepdim, 1))?;
+    rb_tensor.define_method("max_keepdim", method!(RbTensor::max_keepdim, 1))?;
+    rb_tensor.define_method("min_keepdim", method!(RbTensor::min_keepdim, 1))?;
+    // rb_tensor.define_method("eq", method!(RbTensor::eq, 1))?;
+    // rb_tensor.define_method("ne", method!(RbTensor::ne, 1))?;
+    // rb_tensor.define_method("lt", method!(RbTensor::lt, 1))?;
+    // rb_tensor.define_method("gt", method!(RbTensor::gt, 1))?;
+    // rb_tensor.define_method("ge", method!(RbTensor::ge, 1))?;
+    // rb_tensor.define_method("le", method!(RbTensor::le, 1))?;
+    rb_tensor.define_method("sum_all", method!(RbTensor::sum_all, 0))?;
+    rb_tensor.define_method("mean_all", method!(RbTensor::mean_all, 0))?;
+    rb_tensor.define_method("flatten_from", method!(RbTensor::flatten_from, 1))?;
+    rb_tensor.define_method("flatten_to", method!(RbTensor::flatten_to, 1))?;
+    rb_tensor.define_method("flatten_all", method!(RbTensor::flatten_all, 0))?;
+    rb_tensor.define_method("t", method!(RbTensor::t, 0))?;
+    rb_tensor.define_method("contiguous", method!(RbTensor::contiguous, 0))?;
+    rb_tensor.define_method("is_contiguous", method!(RbTensor::is_contiguous, 0))?;
     rb_tensor.define_method(
         "is_fortran_contiguous",
-        method!(PyTensor::is_fortran_contiguous, 0),
+        method!(RbTensor::is_fortran_contiguous, 0),
     )?;
-    rb_tensor.define_method("detach", method!(PyTensor::detach, 0))?;
-    rb_tensor.define_method("copy", method!(PyTensor::copy, 0))?;
-    rb_tensor.define_method("to_dtype", method!(PyTensor::to_dtype, 1))?;
-    rb_tensor.define_method("to_device", method!(PyTensor::to_device, 1))?;
-    rb_tensor.define_method("to_s", method!(PyTensor::__str__, 0))?;
-    rb_tensor.define_method("inspect", method!(PyTensor::__repr__, 0))?;
+    rb_tensor.define_method("detach", method!(RbTensor::detach, 0))?;
+    rb_tensor.define_method("copy", method!(RbTensor::copy, 0))?;
+    rb_tensor.define_method("to_dtype", method!(RbTensor::to_dtype, 1))?;
+    rb_tensor.define_method("to_device", method!(RbTensor::to_device, 1))?;
+    rb_tensor.define_method("to_s", method!(RbTensor::__str__, 0))?;
+    rb_tensor.define_method("inspect", method!(RbTensor::__repr__, 0))?;
+
     let rb_dtype = rb_candle.define_class("DType", Ruby::class_object(ruby))?;
-    rb_dtype.define_method("to_s", method!(PyDType::__str__, 0))?;
-    rb_dtype.define_method("inspect", method!(PyDType::__repr__, 0))?;
+    rb_dtype.define_method("to_s", method!(RbDType::__str__, 0))?;
+    rb_dtype.define_method("inspect", method!(RbDType::__repr__, 0))?;
+
     let rb_device = rb_candle.define_class("Device", Ruby::class_object(ruby))?;
-    rb_device.define_method("to_s", method!(PyDevice::__str__, 0))?;
-    rb_device.define_method("inspect", method!(PyDevice::__repr__, 0))?;
+    rb_device.define_method("to_s", method!(RbDevice::__str__, 0))?;
+    rb_device.define_method("inspect", method!(RbDevice::__repr__, 0))?;
+
     let rb_qtensor = rb_candle.define_class("QTensor", Ruby::class_object(ruby))?;
-    rb_qtensor.define_method("ggml_dtype", method!(PyQTensor::ggml_dtype, 0))?;
-    rb_qtensor.define_method("rank", method!(PyQTensor::rank, 0))?;
-    rb_qtensor.define_method("shape", method!(PyQTensor::shape, 0))?;
-    rb_qtensor.define_method("dequantize", method!(PyQTensor::dequantize, 0))?;
+    rb_qtensor.define_method("ggml_dtype", method!(RbQTensor::ggml_dtype, 0))?;
+    rb_qtensor.define_method("rank", method!(RbQTensor::rank, 0))?;
+    rb_qtensor.define_method("shape", method!(RbQTensor::shape, 0))?;
+    rb_qtensor.define_method("dequantize", method!(RbQTensor::dequantize, 0))?;
+
+    let rb_model = rb_candle.define_class("Model", Ruby::class_object(ruby))?;
+    rb_model.define_singleton_method("new", function!(RbModel::new, 0))?;
+    rb_model.define_method("embedding", method!(RbModel::embedding, 1))?;
+    rb_model.define_method("to_s", method!(RbModel::__str__, 0))?;
+    rb_model.define_method("inspect", method!(RbModel::__repr__, 0))?;
+
     Ok(())
 }
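
For orientation, here is a minimal Ruby sketch of the 0.0.5 surface as registered in `init` above. It is illustrative only: `Tensor.new` keeps its two-argument arity (array plus dtype, so `nil` selects the default), `values`, `elem_count`, and `[]` are newly registered in 0.0.5, and the single argument to `Model#embedding` is assumed here to be the text to embed (its type is not visible in this diff).

    require "candle"

    # Tensor construction; nil for the dtype falls back to the default
    # (f32 in the 0.0.3 implementation shown above).
    t = Candle::Tensor.new([1.0, 2.0, 3.0], nil)
    t.shape        # => [3]
    t.values       # new in 0.0.5 (implementation lives in the model module)
    t.elem_count   # new in 0.0.5
    t[0]           # new in 0.0.5: "[]" is registered against RbTensor::get

    # Candle::Model is new in 0.0.5; the embedding argument is assumed.
    model = Candle::Model.new
    embedding = model.embedding("hello world")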