red-candle 0.0.3

@@ -0,0 +1,773 @@
+ use magnus::{function, method, prelude::*, Error, Ruby};
+ use std::sync::Arc;
+
+ use half::{bf16, f16};
+
+ use ::candle_core::{quantized::QTensor, DType, Device, Tensor, WithDType};
+
+ type PyResult<T> = Result<T, Error>;
+
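+ // Every candle_core error is surfaced to Ruby as a RuntimeError carrying the original message.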
+ pub fn wrap_err(err: candle_core::Error) -> Error {
+     Error::new(magnus::exception::runtime_error(), err.to_string())
+ }
+
+ // #[derive(Clone, Debug)]
+ // struct RbShape(Vec<usize>);
+
+ // impl magnus::TryConvert for RbShape {
+ //     fn try_convert(val: magnus::Value) -> PyResult<Self> {
+ //         let ary = magnus::RArray::try_convert(val)?;
+ //         let shape = ary
+ //             .each()
+ //             .map(|v| magnus::Integer::try_convert(v?).map(|v| v.to_usize().unwrap()))
+ //             .collect::<PyResult<Vec<_>>>()?;
+ //         Ok(Self(shape))
+ //     }
+ // }
+
+ // impl magnus::IntoValue for RbShape {
+ //     fn into_value_with(self, ruby: &Ruby) -> magnus::Value {
+ //         let ary = magnus::RArray::from_vec(self.0);
+ //         ary.into_value_with(ruby)
+ //     }
+ // }
+
+ #[derive(Clone, Debug)]
+ #[magnus::wrap(class = "Candle::Tensor", free_immediately, size)]
+ /// A `candle` tensor.
+ struct PyTensor(Tensor);
+
+ impl std::ops::Deref for PyTensor {
+     type Target = Tensor;
+
+     fn deref(&self) -> &Self::Target {
+         &self.0
+     }
+ }
+
+ #[derive(Clone, Copy, Debug, PartialEq, Eq)]
+ #[magnus::wrap(class = "Candle::DType", free_immediately, size)]
+ /// A `candle` dtype.
+ struct PyDType(DType);
+
+ impl PyDType {
+     fn __repr__(&self) -> String {
+         format!("{:?}", self.0)
+     }
+
+     fn __str__(&self) -> String {
+         self.__repr__()
+     }
+ }
+
+ impl PyDType {
+     fn from_pyobject(dtype: magnus::Symbol) -> PyResult<Self> {
+         let dtype = unsafe { dtype.to_s() }.unwrap().into_owned();
+         use std::str::FromStr;
+         let dtype = DType::from_str(&dtype).unwrap();
+         Ok(Self(dtype))
+     }
+ }
+
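+ // The CUDA device is created lazily on first use and cached in this process-wide slot so that
+ // every tensor moved to the GPU shares the same device handle.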
+ static CUDA_DEVICE: std::sync::Mutex<Option<Device>> = std::sync::Mutex::new(None);
+
+ #[derive(Clone, Copy, Debug, PartialEq, Eq)]
+ #[magnus::wrap(class = "Candle::Device")]
+ enum PyDevice {
+     Cpu,
+     Cuda,
+ }
+
+ impl PyDevice {
+     fn from_device(device: &Device) -> Self {
+         match device {
+             Device::Cpu => Self::Cpu,
+             Device::Cuda(_) => Self::Cuda,
+         }
+     }
+
+     fn as_device(&self) -> PyResult<Device> {
+         match self {
+             Self::Cpu => Ok(Device::Cpu),
+             Self::Cuda => {
+                 let mut device = CUDA_DEVICE.lock().unwrap();
+                 if let Some(device) = device.as_ref() {
+                     return Ok(device.clone());
+                 };
+                 let d = Device::new_cuda(0).map_err(wrap_err)?;
+                 *device = Some(d.clone());
+                 Ok(d)
+             }
+         }
+     }
+
+     fn __repr__(&self) -> String {
+         match self {
+             Self::Cpu => "cpu".to_string(),
+             Self::Cuda => "cuda".to_string(),
+         }
+     }
+
+     fn __str__(&self) -> String {
+         self.__repr__()
+     }
+ }
+
+ impl magnus::TryConvert for PyDevice {
+     fn try_convert(val: magnus::Value) -> PyResult<Self> {
+         let device = magnus::RString::try_convert(val)?;
+         let device = unsafe { device.as_str() }.unwrap();
+         let device = match device {
+             "cpu" => PyDevice::Cpu,
+             "cuda" => PyDevice::Cuda,
+             _ => return Err(Error::new(magnus::exception::arg_error(), "invalid device")),
+         };
+         Ok(device)
+     }
+ }
+
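+ // Resolve possibly negative indexes / dimension numbers (Ruby-style, where -1 means the last
+ // element or dimension) into absolute values, with bounds checking.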
+ fn actual_index(t: &Tensor, dim: usize, index: i64) -> candle_core::Result<usize> {
+     let dim = t.dim(dim)?;
+     if 0 <= index {
+         let index = index as usize;
+         if dim <= index {
+             candle_core::bail!("index {index} is too large for tensor dimension {dim}")
+         }
+         Ok(index)
+     } else {
+         if (dim as i64) < -index {
+             candle_core::bail!("index {index} is too low for tensor dimension {dim}")
+         }
+         Ok((dim as i64 + index) as usize)
+     }
+ }
+
+ fn actual_dim(t: &Tensor, dim: i64) -> candle_core::Result<usize> {
+     let rank = t.rank();
+     if 0 <= dim {
+         let dim = dim as usize;
+         if rank <= dim {
+             candle_core::bail!("dimension index {dim} is too large for tensor rank {rank}")
+         }
+         Ok(dim)
+     } else {
+         if (rank as i64) < -dim {
+             candle_core::bail!("dimension index {dim} is too low for tensor rank {rank}")
+         }
+         Ok((rank as i64 + dim) as usize)
+     }
+ }
+ impl PyTensor {
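+     // Builds a tensor on the CPU from a flat Ruby Array of Floats (converted through f64),
+     // then casts it to the requested dtype (F32 when no dtype symbol is given).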
+     fn new(array: magnus::RArray, dtype: Option<magnus::Symbol>) -> PyResult<Self> {
+         let dtype = dtype
+             .map(|dtype| PyDType::from_pyobject(dtype))
+             .unwrap_or(Ok(PyDType(DType::F32)))?;
+         // FIXME: Do not use `to_f64` here.
+         let array = array
+             .each()
+             .map(|v| magnus::Float::try_convert(v?).map(|v| v.to_f64()))
+             .collect::<PyResult<Vec<_>>>()?;
+         Ok(Self(
+             Tensor::new(array.as_slice(), &Device::Cpu)
+                 .map_err(wrap_err)?
+                 .to_dtype(dtype.0)
+                 .map_err(wrap_err)?,
+         ))
+     }
+
+     /// Gets the tensor's shape.
+     /// &RETURNS&: Tuple[int]
+     fn shape(&self) -> Vec<usize> {
+         self.0.dims().to_vec()
+     }
+
+     /// Gets the tensor's strides.
+     /// &RETURNS&: Tuple[int]
+     fn stride(&self) -> Vec<usize> {
+         self.0.stride().to_vec()
+     }
+
+     /// Gets the tensor's dtype.
+     /// &RETURNS&: DType
+     fn dtype(&self) -> PyDType {
+         PyDType(self.0.dtype())
+     }
+
+     /// Gets the tensor's device.
+     /// &RETURNS&: Device
+     fn device(&self) -> PyDevice {
+         PyDevice::from_device(self.0.device())
+     }
+
+     /// Gets the tensor's rank.
+     /// &RETURNS&: int
+     fn rank(&self) -> usize {
+         self.0.rank()
+     }
+
+     fn __repr__(&self) -> String {
+         format!("{}", self.0)
+     }
+
+     fn __str__(&self) -> String {
+         self.__repr__()
+     }
+
+     /// Performs the `sin` operation on the tensor.
+     /// &RETURNS&: Tensor
+     fn sin(&self) -> PyResult<Self> {
+         Ok(PyTensor(self.0.sin().map_err(wrap_err)?))
+     }
+
+     /// Performs the `cos` operation on the tensor.
+     /// &RETURNS&: Tensor
+     fn cos(&self) -> PyResult<Self> {
+         Ok(PyTensor(self.0.cos().map_err(wrap_err)?))
+     }
+
+     /// Performs the `log` operation on the tensor.
+     /// &RETURNS&: Tensor
+     fn log(&self) -> PyResult<Self> {
+         Ok(PyTensor(self.0.log().map_err(wrap_err)?))
+     }
+
+     /// Squares the tensor.
+     /// &RETURNS&: Tensor
+     fn sqr(&self) -> PyResult<Self> {
+         Ok(PyTensor(self.0.sqr().map_err(wrap_err)?))
+     }
+
+     /// Calculates the square root of the tensor.
+     /// &RETURNS&: Tensor
+     fn sqrt(&self) -> PyResult<Self> {
+         Ok(PyTensor(self.0.sqrt().map_err(wrap_err)?))
+     }
+
+     /// Get the `recip` of the tensor.
+     /// &RETURNS&: Tensor
+     fn recip(&self) -> PyResult<Self> {
+         Ok(PyTensor(self.0.recip().map_err(wrap_err)?))
+     }
+
+     /// Performs the `exp` operation on the tensor.
+     /// &RETURNS&: Tensor
+     fn exp(&self) -> PyResult<Self> {
+         Ok(PyTensor(self.0.exp().map_err(wrap_err)?))
+     }
+
+     /// Performs the `pow` operation on the tensor with the given exponent.
+     /// &RETURNS&: Tensor
+     fn powf(&self, p: f64) -> PyResult<Self> {
+         Ok(PyTensor(self.0.powf(p).map_err(wrap_err)?))
+     }
+
+     /// Select values for the input tensor at the target indexes across the specified dimension.
+     ///
+     /// The `indexes` argument is an int tensor with a single dimension.
+     /// The output has the same number of dimensions as the `self` input. The target dimension of
+     /// the output has the same length as `indexes`, and the values are taken from `self` using
+     /// the index from `indexes`. Other dimensions have the same number of elements as the input
+     /// tensor.
+     /// &RETURNS&: Tensor
+     fn index_select(&self, rhs: &Self, dim: i64) -> PyResult<Self> {
+         let dim = actual_dim(self, dim).map_err(wrap_err)?;
+         Ok(PyTensor(self.0.index_select(rhs, dim).map_err(wrap_err)?))
+     }
+
+     /// Performs a matrix multiplication between the two tensors.
+     /// &RETURNS&: Tensor
+     fn matmul(&self, rhs: &Self) -> PyResult<Self> {
+         Ok(PyTensor(self.0.matmul(rhs).map_err(wrap_err)?))
+     }
+
+     /// Adds the two tensors, while broadcasting the right-hand-side tensor to match the shape of the left-hand-side tensor.
+     /// &RETURNS&: Tensor
+     fn broadcast_add(&self, rhs: &Self) -> PyResult<Self> {
+         Ok(PyTensor(self.0.broadcast_add(rhs).map_err(wrap_err)?))
+     }
+
+     /// Subtracts the two tensors, while broadcasting the right-hand-side tensor to match the shape of the left-hand-side tensor.
+     /// &RETURNS&: Tensor
+     fn broadcast_sub(&self, rhs: &Self) -> PyResult<Self> {
+         Ok(PyTensor(self.0.broadcast_sub(rhs).map_err(wrap_err)?))
+     }
+
+     /// Multiplies the two tensors, while broadcasting the right-hand-side tensor to match the shape of the left-hand-side tensor.
+     /// &RETURNS&: Tensor
+     fn broadcast_mul(&self, rhs: &Self) -> PyResult<Self> {
+         Ok(PyTensor(self.0.broadcast_mul(rhs).map_err(wrap_err)?))
+     }
+
+     /// Divides the two tensors, while broadcasting the right-hand-side tensor to match the shape of the left-hand-side tensor.
+     /// &RETURNS&: Tensor
+     fn broadcast_div(&self, rhs: &Self) -> PyResult<Self> {
+         Ok(PyTensor(self.0.broadcast_div(rhs).map_err(wrap_err)?))
+     }
+
+     /// Returns a tensor with the same shape as the input tensor, the values are taken from
+     /// `on_true` if the input tensor value is not zero, and `on_false` at the positions where the
+     /// input tensor is equal to zero.
+     /// &RETURNS&: Tensor
+     fn where_cond(&self, on_true: &Self, on_false: &Self) -> PyResult<Self> {
+         Ok(PyTensor(
+             self.0.where_cond(on_true, on_false).map_err(wrap_err)?,
+         ))
+     }
+
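+     // Python-style dunder names kept from candle's Python bindings; `init` below maps them onto
+     // Ruby's `+`, `*` and `-` operators (`__truediv__` is defined but not bound to a Ruby operator here).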
+     /// Add two tensors.
+     /// &RETURNS&: Tensor
+     fn __add__(&self, rhs: &PyTensor) -> PyResult<Self> {
+         Ok(Self(self.0.add(&rhs.0).map_err(wrap_err)?))
+     }
+
+     /// Multiply two tensors.
+     /// &RETURNS&: Tensor
+     fn __mul__(&self, rhs: &PyTensor) -> PyResult<Self> {
+         Ok(Self(self.0.mul(&rhs.0).map_err(wrap_err)?))
+     }
+
+     /// Subtract two tensors.
+     /// &RETURNS&: Tensor
+     fn __sub__(&self, rhs: &PyTensor) -> PyResult<Self> {
+         Ok(Self(self.0.sub(&rhs.0).map_err(wrap_err)?))
+     }
+
+     /// Divide two tensors.
+     /// &RETURNS&: Tensor
+     fn __truediv__(&self, rhs: &PyTensor) -> PyResult<Self> {
+         Ok(Self(self.0.div(&rhs.0).map_err(wrap_err)?))
+     }
+
+     /// Reshapes the tensor to the given shape.
+     /// &RETURNS&: Tensor
+     fn reshape(&self, shape: Vec<usize>) -> PyResult<Self> {
+         Ok(PyTensor(self.0.reshape(shape).map_err(wrap_err)?))
+     }
+
+     /// Broadcasts the tensor to the given shape.
+     /// &RETURNS&: Tensor
+     fn broadcast_as(&self, shape: Vec<usize>) -> PyResult<Self> {
+         Ok(PyTensor(self.0.broadcast_as(shape).map_err(wrap_err)?))
+     }
+
+     /// Broadcasts the tensor to the given shape, adding new dimensions on the left.
+     /// &RETURNS&: Tensor
+     fn broadcast_left(&self, shape: Vec<usize>) -> PyResult<Self> {
+         Ok(PyTensor(self.0.broadcast_left(shape).map_err(wrap_err)?))
+     }
+
+     /// Creates a new tensor with the specified dimension removed if its size was one.
+     /// &RETURNS&: Tensor
+     fn squeeze(&self, dim: i64) -> PyResult<Self> {
+         let dim = actual_dim(self, dim).map_err(wrap_err)?;
+         Ok(PyTensor(self.0.squeeze(dim).map_err(wrap_err)?))
+     }
+
+     /// Creates a new tensor with a dimension of size one inserted at the specified position.
+     /// &RETURNS&: Tensor
+     fn unsqueeze(&self, dim: usize) -> PyResult<Self> {
+         Ok(PyTensor(self.0.unsqueeze(dim).map_err(wrap_err)?))
+     }
+
+     /// Gets the value at the specified index.
+     /// &RETURNS&: Tensor
+     fn get(&self, index: i64) -> PyResult<Self> {
+         let index = actual_index(self, 0, index).map_err(wrap_err)?;
+         Ok(PyTensor(self.0.get(index).map_err(wrap_err)?))
+     }
+
+     /// Returns a tensor that is a transposed version of the input, the given dimensions are swapped.
+     /// &RETURNS&: Tensor
+     fn transpose(&self, dim1: usize, dim2: usize) -> PyResult<Self> {
+         Ok(PyTensor(self.0.transpose(dim1, dim2).map_err(wrap_err)?))
+     }
+
+     /// Returns a new tensor that is a narrowed version of the input, the dimension `dim`
+     /// ranges from `start` to `start + len`.
+     /// &RETURNS&: Tensor
+     fn narrow(&self, dim: i64, start: i64, len: usize) -> PyResult<Self> {
+         let dim = actual_dim(self, dim).map_err(wrap_err)?;
+         let start = actual_index(self, dim, start).map_err(wrap_err)?;
+         Ok(PyTensor(self.0.narrow(dim, start, len).map_err(wrap_err)?))
+     }
+
+     /// Returns the indices of the maximum value(s) across the selected dimension.
+     /// &RETURNS&: Tensor
+     fn argmax_keepdim(&self, dim: i64) -> PyResult<Self> {
+         let dim = actual_dim(self, dim).map_err(wrap_err)?;
+         Ok(PyTensor(self.0.argmax_keepdim(dim).map_err(wrap_err)?))
+     }
+
+     /// Returns the indices of the minimum value(s) across the selected dimension.
+     /// &RETURNS&: Tensor
+     fn argmin_keepdim(&self, dim: i64) -> PyResult<Self> {
+         let dim = actual_dim(self, dim).map_err(wrap_err)?;
+         Ok(PyTensor(self.0.argmin_keepdim(dim).map_err(wrap_err)?))
+     }
+
+     /// Gathers the maximum value across the selected dimension.
+     /// &RETURNS&: Tensor
+     fn max_keepdim(&self, dim: i64) -> PyResult<Self> {
+         let dim = actual_dim(self, dim).map_err(wrap_err)?;
+         Ok(PyTensor(self.0.max_keepdim(dim).map_err(wrap_err)?))
+     }
+
+     /// Gathers the minimum value across the selected dimension.
+     /// &RETURNS&: Tensor
+     fn min_keepdim(&self, dim: i64) -> PyResult<Self> {
+         let dim = actual_dim(self, dim).map_err(wrap_err)?;
+         Ok(PyTensor(self.0.min_keepdim(dim).map_err(wrap_err)?))
+     }
+
+     // fn eq(&self, rhs: &Self) -> PyResult<Self> {
+     //     Ok(PyTensor(self.0.eq(rhs).map_err(wrap_err)?))
+     // }
+
+     // fn ne(&self, rhs: &Self) -> PyResult<Self> {
+     //     Ok(PyTensor(self.0.ne(rhs).map_err(wrap_err)?))
+     // }
+
+     // fn lt(&self, rhs: &Self) -> PyResult<Self> {
+     //     Ok(PyTensor(self.0.lt(rhs).map_err(wrap_err)?))
+     // }
+
+     // fn gt(&self, rhs: &Self) -> PyResult<Self> {
+     //     Ok(PyTensor(self.0.gt(rhs).map_err(wrap_err)?))
+     // }
+
+     // fn ge(&self, rhs: &Self) -> PyResult<Self> {
+     //     Ok(PyTensor(self.0.ge(rhs).map_err(wrap_err)?))
+     // }
+
+     // fn le(&self, rhs: &Self) -> PyResult<Self> {
+     //     Ok(PyTensor(self.0.le(rhs).map_err(wrap_err)?))
+     // }
+
+     /// Returns the sum of the tensor.
+     /// &RETURNS&: Tensor
+     fn sum_all(&self) -> PyResult<Self> {
+         Ok(PyTensor(self.0.sum_all().map_err(wrap_err)?))
+     }
+
+     /// Returns the mean of the tensor.
+     /// &RETURNS&: Tensor
+     fn mean_all(&self) -> PyResult<Self> {
+         let elements = self.0.elem_count();
+         let sum = self.0.sum_all().map_err(wrap_err)?;
+         let mean = (sum / elements as f64).map_err(wrap_err)?;
+         Ok(PyTensor(mean))
+     }
+
+     /// Flattens the tensor on the dimension indexes from `dim` (inclusive) to the last dimension.
+     /// &RETURNS&: Tensor
+     fn flatten_from(&self, dim: i64) -> PyResult<Self> {
+         let dim = actual_dim(self, dim).map_err(wrap_err)?;
+         Ok(PyTensor(self.0.flatten_from(dim).map_err(wrap_err)?))
+     }
+
+     /// Flattens the tensor on the dimension indexes from `0` to `dim` (inclusive).
+     /// &RETURNS&: Tensor
+     fn flatten_to(&self, dim: i64) -> PyResult<Self> {
+         let dim = actual_dim(self, dim).map_err(wrap_err)?;
+         Ok(PyTensor(self.0.flatten_to(dim).map_err(wrap_err)?))
+     }
+
+     /// Flattens the tensor into a 1D tensor.
+     /// &RETURNS&: Tensor
+     fn flatten_all(&self) -> PyResult<Self> {
+         Ok(PyTensor(self.0.flatten_all().map_err(wrap_err)?))
+     }
+
+     /// Transposes the tensor.
+     /// &RETURNS&: Tensor
+     fn t(&self) -> PyResult<Self> {
+         Ok(PyTensor(self.0.t().map_err(wrap_err)?))
+     }
+
+     /// Makes the tensor contiguous in memory.
+     /// &RETURNS&: Tensor
+     fn contiguous(&self) -> PyResult<Self> {
+         Ok(PyTensor(self.0.contiguous().map_err(wrap_err)?))
+     }
+
+     /// Returns true if the tensor is contiguous in C order.
+     /// &RETURNS&: bool
+     fn is_contiguous(&self) -> bool {
+         self.0.is_contiguous()
+     }
+
+     /// Returns true if the tensor is contiguous in Fortran order.
+     /// &RETURNS&: bool
+     fn is_fortran_contiguous(&self) -> bool {
+         self.0.is_fortran_contiguous()
+     }
+
+     /// Detach the tensor from the computation graph.
+     /// &RETURNS&: Tensor
+     fn detach(&self) -> PyResult<Self> {
+         Ok(PyTensor(self.0.detach().map_err(wrap_err)?))
+     }
+
+     /// Returns a copy of the tensor.
+     /// &RETURNS&: Tensor
+     fn copy(&self) -> PyResult<Self> {
+         Ok(PyTensor(self.0.copy().map_err(wrap_err)?))
+     }
+
+     /// Convert the tensor to a new dtype.
+     /// &RETURNS&: Tensor
+     fn to_dtype(&self, dtype: magnus::Symbol) -> PyResult<Self> {
+         let dtype = PyDType::from_pyobject(dtype)?;
+         Ok(PyTensor(self.0.to_dtype(dtype.0).map_err(wrap_err)?))
+     }
+
+     /// Move the tensor to a new device.
+     /// &RETURNS&: Tensor
+     fn to_device(&self, device: PyDevice) -> PyResult<Self> {
+         let device = device.as_device()?;
+         Ok(PyTensor(self.0.to_device(&device).map_err(wrap_err)?))
+     }
+ }
+
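+ // Class-level constructors; `init` registers them as singleton methods on Candle::Tensor.
+ // They all allocate f32 tensors on the CPU.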
+ impl PyTensor {
+     // fn cat(tensors: Vec<PyTensor>, dim: i64) -> PyResult<PyTensor> {
+     //     if tensors.is_empty() {
+     //         return Err(Error::new(
+     //             magnus::exception::arg_error(),
+     //             "empty input to cat",
+     //         ));
+     //     }
+     //     let dim = actual_dim(&tensors[0].0, dim).map_err(wrap_err)?;
+     //     let tensors = tensors.into_iter().map(|t| t.0).collect::<Vec<_>>();
+     //     let tensor = Tensor::cat(&tensors, dim).map_err(wrap_err)?;
+     //     Ok(PyTensor(tensor))
+     // }
+
+     // fn stack(tensors: Vec<PyTensor>, dim: usize) -> PyResult<Self> {
+     //     let tensors = tensors.into_iter().map(|t| t.0).collect::<Vec<_>>();
+     //     let tensor = Tensor::stack(&tensors, dim).map_err(wrap_err)?;
+     //     Ok(Self(tensor))
+     // }
+
+     /// Creates a new tensor with random values.
+     /// &RETURNS&: Tensor
+     fn rand(shape: Vec<usize>) -> PyResult<Self> {
+         let device = PyDevice::Cpu.as_device()?;
+         Ok(Self(
+             Tensor::rand(0f32, 1f32, shape, &device).map_err(wrap_err)?,
+         ))
+     }
+
+     /// Creates a new tensor with random values from a normal distribution.
+     /// &RETURNS&: Tensor
+     fn randn(shape: Vec<usize>) -> PyResult<Self> {
+         let device = PyDevice::Cpu.as_device()?;
+         Ok(Self(
+             Tensor::randn(0f32, 1f32, shape, &device).map_err(wrap_err)?,
+         ))
+     }
+
+     /// Creates a new tensor filled with ones.
+     /// &RETURNS&: Tensor
+     fn ones(shape: Vec<usize>) -> PyResult<Self> {
+         let device = PyDevice::Cpu.as_device()?;
+         Ok(Self(
+             Tensor::ones(shape, DType::F32, &device).map_err(wrap_err)?,
+         ))
+     }
+
+     /// Creates a new tensor filled with zeros.
+     /// &RETURNS&: Tensor
+     fn zeros(shape: Vec<usize>) -> PyResult<Self> {
+         let device = PyDevice::Cpu.as_device()?;
+         Ok(Self(
+             Tensor::zeros(shape, DType::F32, &device).map_err(wrap_err)?,
+         ))
+     }
+ }
+
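+ // A GGML-style quantized tensor, shared behind an `Arc`; `dequantize` rebuilds a regular
+ // (dequantized) tensor on the CPU.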
+ #[derive(Debug)]
+ #[magnus::wrap(class = "Candle::QTensor", free_immediately, size)]
+ /// A quantized tensor.
+ struct PyQTensor(Arc<QTensor>);
+
+ impl std::ops::Deref for PyQTensor {
+     type Target = QTensor;
+
+     fn deref(&self) -> &Self::Target {
+         self.0.as_ref()
+     }
+ }
+
+ impl PyQTensor {
+     /// Gets the tensor's quantized dtype.
+     /// &RETURNS&: str
+     fn ggml_dtype(&self) -> String {
+         format!("{:?}", self.0.dtype())
+     }
+
+     /// Gets the rank of the tensor.
+     /// &RETURNS&: int
+     fn rank(&self) -> usize {
+         self.0.rank()
+     }
+
+     /// Gets the shape of the tensor.
+     /// &RETURNS&: Tuple[int]
+     fn shape(&self) -> Vec<usize> {
+         self.0.shape().dims().to_vec()
+     }
+
+     fn __repr__(&self) -> String {
+         format!("{:?}", self.0)
+     }
+
+     fn __str__(&self) -> String {
+         self.__repr__()
+     }
+
+     /// Dequantizes the tensor.
+     /// &RETURNS&: Tensor
+     fn dequantize(&self) -> PyResult<PyTensor> {
+         let tensor = self.0.dequantize(&Device::Cpu).map_err(wrap_err)?;
+         Ok(PyTensor(tensor))
+     }
+
+     // fn matmul_t(&self, lhs: &PyTensor) -> PyResult<PyTensor> {
+     //     let qmatmul = ::candle_core::quantized::QMatMul::from_arc(self.0.clone());
+     //     let res = qmatmul.forward(lhs).map_err(wrap_err)?;
+     //     Ok(PyTensor(res))
+     // }
+ }
+
+ /// Returns true if the 'cuda' backend is available.
+ /// &RETURNS&: bool
+ fn cuda_is_available() -> bool {
+     candle_core::utils::cuda_is_available()
+ }
+
+ /// Returns true if candle was compiled with 'accelerate' support.
+ /// &RETURNS&: bool
+ fn has_accelerate() -> bool {
+     candle_core::utils::has_accelerate()
+ }
+
+ /// Returns true if candle was compiled with MKL support.
+ /// &RETURNS&: bool
+ fn has_mkl() -> bool {
+     candle_core::utils::has_mkl()
+ }
+
+ /// Returns the number of threads used by candle.
+ /// &RETURNS&: int
+ fn get_num_threads() -> usize {
+     candle_core::utils::get_num_threads()
+ }
+
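+ // Exposes the capability and threading helpers above as singleton methods on Candle::Utils.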
+ fn candle_utils(rb_candle: magnus::RModule) -> Result<(), Error> {
+     let rb_utils = rb_candle.define_module("Utils")?;
+     rb_utils.define_singleton_method("cuda_is_available", function!(cuda_is_available, 0))?;
+     rb_utils.define_singleton_method("get_num_threads", function!(get_num_threads, 0))?;
+     rb_utils.define_singleton_method("has_accelerate", function!(has_accelerate, 0))?;
+     rb_utils.define_singleton_method("has_mkl", function!(has_mkl, 0))?;
+     Ok(())
+ }
+
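+ // Thin wrappers around candle_nn's activation ops (not yet wired up in `init` below).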
+ /// Applies the Softmax function to a given tensor.
+ /// &RETURNS&: Tensor
+ fn softmax(tensor: PyTensor, dim: i64) -> PyResult<PyTensor> {
+     let dim = actual_dim(&tensor, dim).map_err(wrap_err)?;
+     let sm = candle_nn::ops::softmax(&tensor.0, dim).map_err(wrap_err)?;
+     Ok(PyTensor(sm))
+ }
+
+ /// Applies the Sigmoid Linear Unit (SiLU) function to a given tensor.
+ /// &RETURNS&: Tensor
+ fn silu(tensor: PyTensor) -> PyResult<PyTensor> {
+     let s = candle_nn::ops::silu(&tensor.0).map_err(wrap_err)?;
+     Ok(PyTensor(s))
+ }
+
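+ // Extension entry point, run when Ruby loads the native library: it defines the Candle module,
+ // the Candle::Utils helpers, and the Tensor / DType / Device / QTensor classes with their methods.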
+ #[magnus::init]
+ fn init(ruby: &Ruby) -> PyResult<()> {
+     let rb_candle = ruby.define_module("Candle")?;
+     candle_utils(rb_candle)?;
+     let rb_tensor = rb_candle.define_class("Tensor", Ruby::class_object(ruby))?;
+     rb_tensor.define_singleton_method("new", function!(PyTensor::new, 2))?;
+     // rb_tensor.define_singleton_method("cat", function!(PyTensor::cat, 2))?;
+     // rb_tensor.define_singleton_method("stack", function!(PyTensor::stack, 2))?;
+     rb_tensor.define_singleton_method("rand", function!(PyTensor::rand, 1))?;
+     rb_tensor.define_singleton_method("randn", function!(PyTensor::randn, 1))?;
+     rb_tensor.define_singleton_method("ones", function!(PyTensor::ones, 1))?;
+     rb_tensor.define_singleton_method("zeros", function!(PyTensor::zeros, 1))?;
+     rb_tensor.define_method("shape", method!(PyTensor::shape, 0))?;
+     rb_tensor.define_method("stride", method!(PyTensor::stride, 0))?;
+     rb_tensor.define_method("dtype", method!(PyTensor::dtype, 0))?;
+     rb_tensor.define_method("device", method!(PyTensor::device, 0))?;
+     rb_tensor.define_method("rank", method!(PyTensor::rank, 0))?;
+     rb_tensor.define_method("sin", method!(PyTensor::sin, 0))?;
+     rb_tensor.define_method("cos", method!(PyTensor::cos, 0))?;
+     rb_tensor.define_method("log", method!(PyTensor::log, 0))?;
+     rb_tensor.define_method("sqr", method!(PyTensor::sqr, 0))?;
+     rb_tensor.define_method("sqrt", method!(PyTensor::sqrt, 0))?;
+     rb_tensor.define_method("recip", method!(PyTensor::recip, 0))?;
+     rb_tensor.define_method("exp", method!(PyTensor::exp, 0))?;
+     rb_tensor.define_method("powf", method!(PyTensor::powf, 1))?;
+     rb_tensor.define_method("index_select", method!(PyTensor::index_select, 2))?;
+     rb_tensor.define_method("matmul", method!(PyTensor::matmul, 1))?;
+     rb_tensor.define_method("broadcast_add", method!(PyTensor::broadcast_add, 1))?;
+     rb_tensor.define_method("broadcast_sub", method!(PyTensor::broadcast_sub, 1))?;
+     rb_tensor.define_method("broadcast_mul", method!(PyTensor::broadcast_mul, 1))?;
+     rb_tensor.define_method("broadcast_div", method!(PyTensor::broadcast_div, 1))?;
+     rb_tensor.define_method("where_cond", method!(PyTensor::where_cond, 2))?;
+     rb_tensor.define_method("+", method!(PyTensor::__add__, 1))?;
+     rb_tensor.define_method("*", method!(PyTensor::__mul__, 1))?;
+     rb_tensor.define_method("-", method!(PyTensor::__sub__, 1))?;
+     rb_tensor.define_method("reshape", method!(PyTensor::reshape, 1))?;
+     rb_tensor.define_method("broadcast_as", method!(PyTensor::broadcast_as, 1))?;
+     rb_tensor.define_method("broadcast_left", method!(PyTensor::broadcast_left, 1))?;
+     rb_tensor.define_method("squeeze", method!(PyTensor::squeeze, 1))?;
+     rb_tensor.define_method("unsqueeze", method!(PyTensor::unsqueeze, 1))?;
+     rb_tensor.define_method("get", method!(PyTensor::get, 1))?;
+     rb_tensor.define_method("transpose", method!(PyTensor::transpose, 2))?;
+     rb_tensor.define_method("narrow", method!(PyTensor::narrow, 3))?;
+     rb_tensor.define_method("argmax_keepdim", method!(PyTensor::argmax_keepdim, 1))?;
+     rb_tensor.define_method("argmin_keepdim", method!(PyTensor::argmin_keepdim, 1))?;
+     rb_tensor.define_method("max_keepdim", method!(PyTensor::max_keepdim, 1))?;
+     rb_tensor.define_method("min_keepdim", method!(PyTensor::min_keepdim, 1))?;
+     // rb_tensor.define_method("eq", method!(PyTensor::eq, 1))?;
+     // rb_tensor.define_method("ne", method!(PyTensor::ne, 1))?;
+     // rb_tensor.define_method("lt", method!(PyTensor::lt, 1))?;
+     // rb_tensor.define_method("gt", method!(PyTensor::gt, 1))?;
+     // rb_tensor.define_method("ge", method!(PyTensor::ge, 1))?;
+     // rb_tensor.define_method("le", method!(PyTensor::le, 1))?;
+     rb_tensor.define_method("sum_all", method!(PyTensor::sum_all, 0))?;
+     rb_tensor.define_method("mean_all", method!(PyTensor::mean_all, 0))?;
+     rb_tensor.define_method("flatten_from", method!(PyTensor::flatten_from, 1))?;
+     rb_tensor.define_method("flatten_to", method!(PyTensor::flatten_to, 1))?;
+     rb_tensor.define_method("flatten_all", method!(PyTensor::flatten_all, 0))?;
+     rb_tensor.define_method("t", method!(PyTensor::t, 0))?;
+     rb_tensor.define_method("contiguous", method!(PyTensor::contiguous, 0))?;
+     rb_tensor.define_method("is_contiguous", method!(PyTensor::is_contiguous, 0))?;
+     rb_tensor.define_method(
+         "is_fortran_contiguous",
+         method!(PyTensor::is_fortran_contiguous, 0),
+     )?;
+     rb_tensor.define_method("detach", method!(PyTensor::detach, 0))?;
+     rb_tensor.define_method("copy", method!(PyTensor::copy, 0))?;
+     rb_tensor.define_method("to_dtype", method!(PyTensor::to_dtype, 1))?;
+     rb_tensor.define_method("to_device", method!(PyTensor::to_device, 1))?;
+     rb_tensor.define_method("to_s", method!(PyTensor::__str__, 0))?;
+     rb_tensor.define_method("inspect", method!(PyTensor::__repr__, 0))?;
+     let rb_dtype = rb_candle.define_class("DType", Ruby::class_object(ruby))?;
+     rb_dtype.define_method("to_s", method!(PyDType::__str__, 0))?;
+     rb_dtype.define_method("inspect", method!(PyDType::__repr__, 0))?;
+     let rb_device = rb_candle.define_class("Device", Ruby::class_object(ruby))?;
+     rb_device.define_method("to_s", method!(PyDevice::__str__, 0))?;
+     rb_device.define_method("inspect", method!(PyDevice::__repr__, 0))?;
+     let rb_qtensor = rb_candle.define_class("QTensor", Ruby::class_object(ruby))?;
+     rb_qtensor.define_method("ggml_dtype", method!(PyQTensor::ggml_dtype, 0))?;
+     rb_qtensor.define_method("rank", method!(PyQTensor::rank, 0))?;
+     rb_qtensor.define_method("shape", method!(PyQTensor::shape, 0))?;
+     rb_qtensor.define_method("dequantize", method!(PyQTensor::dequantize, 0))?;
+     Ok(())
+ }
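
For orientation, here is a minimal, untested sketch of how the API registered in `init` above might be exercised from Ruby once the extension is built (the `require` name is an assumption; the second argument to `Tensor.new` is the dtype symbol and may be nil):

    require "candle"

    # 1-D tensor from an array of Floats; nil selects the default F32 dtype.
    t = Candle::Tensor.new([1.0, 2.0, 3.0], nil)
    t.shape                          # => [3]
    t.dtype.to_s                     # => "F32"
    puts (t + t).to_s                # element-wise addition via the registered "+" method

    # Class-level constructors allocate f32 tensors on the CPU.
    a = Candle::Tensor.rand([2, 2])
    b = Candle::Tensor.zeros([2, 2])
    a.matmul(b.t).shape              # => [2, 2]

    Candle::Utils.cuda_is_available  # => true only when built with CUDA support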