red-candle 0.0.3 → 0.0.4

@@ -1,773 +1,103 @@
- use magnus::{function, method, prelude::*, Error, Ruby};
- use std::sync::Arc;
+ use magnus::{function, method, prelude::*, Ruby};
 
- use half::{bf16, f16};
+ use crate::model::{candle_utils, RbModel, RbDType, RbDevice, RbQTensor, RbResult, RbTensor};
 
- use ::candle_core::{quantized::QTensor, DType, Device, Tensor, WithDType};
-
- type PyResult<T> = Result<T, Error>;
-
- pub fn wrap_err(err: candle_core::Error) -> Error {
-     Error::new(magnus::exception::runtime_error(), err.to_string())
- }
-
- // #[derive(Clone, Debug)]
- // struct RbShape(Vec<usize>);
-
- // impl magnus::TryConvert for RbShape {
- //     fn try_convert(val: magnus::Value) -> PyResult<Self> {
- //         let ary = magnus::RArray::try_convert(val)?;
- //         let shape = ary
- //             .each()
- //             .map(|v| magnus::Integer::try_convert(v?).map(|v| v.to_usize().unwrap()))
- //             .collect::<PyResult<Vec<_>>>()?;
- //         Ok(Self(shape))
- //     }
- // }
-
- // impl magnus::IntoValue for RbShape {
- //     fn into_value_with(self, ruby: &Ruby) -> magnus::Value {
- //         let ary = magnus::RArray::from_vec(self.0);
- //         ary.into_value_with(ruby)
- //     }
- // }
-
- #[derive(Clone, Debug)]
- #[magnus::wrap(class = "Candle::Tensor", free_immediately, size)]
- /// A `candle` tensor.
- struct PyTensor(Tensor);
-
- impl std::ops::Deref for PyTensor {
-     type Target = Tensor;
-
-     fn deref(&self) -> &Self::Target {
-         &self.0
-     }
- }
-
- #[derive(Clone, Copy, Debug, PartialEq, Eq)]
- #[magnus::wrap(class = "Candle::DType", free_immediately, size)]
- /// A `candle` dtype.
- struct PyDType(DType);
-
- impl PyDType {
-     fn __repr__(&self) -> String {
-         format!("{:?}", self.0)
-     }
-
-     fn __str__(&self) -> String {
-         self.__repr__()
-     }
- }
-
- impl PyDType {
-     fn from_pyobject(dtype: magnus::Symbol) -> PyResult<Self> {
-         let dtype = unsafe { dtype.to_s() }.unwrap().into_owned();
-         use std::str::FromStr;
-         let dtype = DType::from_str(&dtype).unwrap();
-         Ok(Self(dtype))
-     }
- }
-
- static CUDA_DEVICE: std::sync::Mutex<Option<Device>> = std::sync::Mutex::new(None);
-
- #[derive(Clone, Copy, Debug, PartialEq, Eq)]
- #[magnus::wrap(class = "Candle::Device")]
- enum PyDevice {
-     Cpu,
-     Cuda,
- }
-
- impl PyDevice {
-     fn from_device(device: &Device) -> Self {
-         match device {
-             Device::Cpu => Self::Cpu,
-             Device::Cuda(_) => Self::Cuda,
-         }
-     }
-
-     fn as_device(&self) -> PyResult<Device> {
-         match self {
-             Self::Cpu => Ok(Device::Cpu),
-             Self::Cuda => {
-                 let mut device = CUDA_DEVICE.lock().unwrap();
-                 if let Some(device) = device.as_ref() {
-                     return Ok(device.clone());
-                 };
-                 let d = Device::new_cuda(0).map_err(wrap_err)?;
-                 *device = Some(d.clone());
-                 Ok(d)
-             }
-         }
-     }
-
-     fn __repr__(&self) -> String {
-         match self {
-             Self::Cpu => "cpu".to_string(),
-             Self::Cuda => "cuda".to_string(),
-         }
-     }
-
-     fn __str__(&self) -> String {
-         self.__repr__()
-     }
- }
-
- impl magnus::TryConvert for PyDevice {
-     fn try_convert(val: magnus::Value) -> PyResult<Self> {
-         let device = magnus::RString::try_convert(val)?;
-         let device = unsafe { device.as_str() }.unwrap();
-         let device = match device {
-             "cpu" => PyDevice::Cpu,
-             "cuda" => PyDevice::Cuda,
-             _ => return Err(Error::new(magnus::exception::arg_error(), "invalid device")),
-         };
-         Ok(device)
-     }
- }
-
- fn actual_index(t: &Tensor, dim: usize, index: i64) -> candle_core::Result<usize> {
-     let dim = t.dim(dim)?;
-     if 0 <= index {
-         let index = index as usize;
-         if dim <= index {
-             candle_core::bail!("index {index} is too large for tensor dimension {dim}")
-         }
-         Ok(index)
-     } else {
-         if (dim as i64) < -index {
-             candle_core::bail!("index {index} is too low for tensor dimension {dim}")
-         }
-         Ok((dim as i64 + index) as usize)
-     }
- }
-
- fn actual_dim(t: &Tensor, dim: i64) -> candle_core::Result<usize> {
-     let rank = t.rank();
-     if 0 <= dim {
-         let dim = dim as usize;
-         if rank <= dim {
-             candle_core::bail!("dimension index {dim} is too large for tensor rank {rank}")
-         }
-         Ok(dim)
-     } else {
-         if (rank as i64) < -dim {
-             candle_core::bail!("dimension index {dim} is too low for tensor rank {rank}")
-         }
-         Ok((rank as i64 + dim) as usize)
-     }
- }
- impl PyTensor {
-     fn new(array: magnus::RArray, dtype: Option<magnus::Symbol>) -> PyResult<Self> {
-         let dtype = dtype
-             .map(|dtype| PyDType::from_pyobject(dtype))
-             .unwrap_or(Ok(PyDType(DType::F32)))?;
-         // FIXME: Do not use `to_f64` here.
-         let array = array
-             .each()
-             .map(|v| magnus::Float::try_convert(v?).map(|v| v.to_f64()))
-             .collect::<PyResult<Vec<_>>>()?;
-         Ok(Self(
-             Tensor::new(array.as_slice(), &Device::Cpu)
-                 .map_err(wrap_err)?
-                 .to_dtype(dtype.0)
-                 .map_err(wrap_err)?,
-         ))
-     }
-
-     /// Gets the tensor's shape.
-     /// &RETURNS&: Tuple[int]
-     fn shape(&self) -> Vec<usize> {
-         self.0.dims().to_vec()
-     }
-
-     /// Gets the tensor's strides.
-     /// &RETURNS&: Tuple[int]
-     fn stride(&self) -> Vec<usize> {
-         self.0.stride().to_vec()
-     }
-
-     /// Gets the tensor's dtype.
-     /// &RETURNS&: DType
-     fn dtype(&self) -> PyDType {
-         PyDType(self.0.dtype())
-     }
-
-     /// Gets the tensor's device.
-     /// &RETURNS&: Device
-     fn device(&self) -> PyDevice {
-         PyDevice::from_device(self.0.device())
-     }
-
-     /// Gets the tensor's rank.
-     /// &RETURNS&: int
-     fn rank(&self) -> usize {
-         self.0.rank()
-     }
-
-     fn __repr__(&self) -> String {
-         format!("{}", self.0)
-     }
-
-     fn __str__(&self) -> String {
-         self.__repr__()
-     }
-
-     /// Performs the `sin` operation on the tensor.
-     /// &RETURNS&: Tensor
-     fn sin(&self) -> PyResult<Self> {
-         Ok(PyTensor(self.0.sin().map_err(wrap_err)?))
-     }
-
-     /// Performs the `cos` operation on the tensor.
-     /// &RETURNS&: Tensor
-     fn cos(&self) -> PyResult<Self> {
-         Ok(PyTensor(self.0.cos().map_err(wrap_err)?))
-     }
-
-     /// Performs the `log` operation on the tensor.
-     /// &RETURNS&: Tensor
-     fn log(&self) -> PyResult<Self> {
-         Ok(PyTensor(self.0.log().map_err(wrap_err)?))
-     }
-
-     /// Squares the tensor.
-     /// &RETURNS&: Tensor
-     fn sqr(&self) -> PyResult<Self> {
-         Ok(PyTensor(self.0.sqr().map_err(wrap_err)?))
-     }
-
-     /// Calculates the square root of the tensor.
-     /// &RETURNS&: Tensor
-     fn sqrt(&self) -> PyResult<Self> {
-         Ok(PyTensor(self.0.sqrt().map_err(wrap_err)?))
-     }
-
-     /// Gets the `recip` of the tensor.
-     /// &RETURNS&: Tensor
-     fn recip(&self) -> PyResult<Self> {
-         Ok(PyTensor(self.0.recip().map_err(wrap_err)?))
-     }
-
-     /// Performs the `exp` operation on the tensor.
-     /// &RETURNS&: Tensor
-     fn exp(&self) -> PyResult<Self> {
-         Ok(PyTensor(self.0.exp().map_err(wrap_err)?))
-     }
-
-     /// Performs the `pow` operation on the tensor with the given exponent.
-     /// &RETURNS&: Tensor
-     fn powf(&self, p: f64) -> PyResult<Self> {
-         Ok(PyTensor(self.0.powf(p).map_err(wrap_err)?))
-     }
-
-     /// Select values for the input tensor at the target indexes across the specified dimension.
-     ///
-     /// The `indexes` argument is an int tensor with a single dimension.
-     /// The output has the same number of dimensions as the `self` input. The target dimension of
-     /// the output has the same length as `indexes` and the values are taken from `self` using
-     /// the index from `indexes`. Other dimensions have the same number of elements as the input
-     /// tensor.
-     /// &RETURNS&: Tensor
-     fn index_select(&self, rhs: &Self, dim: i64) -> PyResult<Self> {
-         let dim = actual_dim(self, dim).map_err(wrap_err)?;
-         Ok(PyTensor(self.0.index_select(rhs, dim).map_err(wrap_err)?))
-     }
-
-     /// Performs a matrix multiplication between the two tensors.
-     /// &RETURNS&: Tensor
-     fn matmul(&self, rhs: &Self) -> PyResult<Self> {
-         Ok(PyTensor(self.0.matmul(rhs).map_err(wrap_err)?))
-     }
-
-     /// Adds the two tensors, while broadcasting the right-hand-side tensor to match the shape of the left-hand-side tensor.
-     /// &RETURNS&: Tensor
-     fn broadcast_add(&self, rhs: &Self) -> PyResult<Self> {
-         Ok(PyTensor(self.0.broadcast_add(rhs).map_err(wrap_err)?))
-     }
-
-     /// Subtracts the two tensors, while broadcasting the right-hand-side tensor to match the shape of the left-hand-side tensor.
-     /// &RETURNS&: Tensor
-     fn broadcast_sub(&self, rhs: &Self) -> PyResult<Self> {
-         Ok(PyTensor(self.0.broadcast_sub(rhs).map_err(wrap_err)?))
-     }
-
-     /// Multiplies the two tensors, while broadcasting the right-hand-side tensor to match the shape of the left-hand-side tensor.
-     /// &RETURNS&: Tensor
-     fn broadcast_mul(&self, rhs: &Self) -> PyResult<Self> {
-         Ok(PyTensor(self.0.broadcast_mul(rhs).map_err(wrap_err)?))
-     }
-
-     /// Divides the two tensors, while broadcasting the right-hand-side tensor to match the shape of the left-hand-side tensor.
-     /// &RETURNS&: Tensor
-     fn broadcast_div(&self, rhs: &Self) -> PyResult<Self> {
-         Ok(PyTensor(self.0.broadcast_div(rhs).map_err(wrap_err)?))
-     }
-
-     /// Returns a tensor with the same shape as the input tensor, the values are taken from
-     /// `on_true` if the input tensor value is not zero, and `on_false` at the positions where the
-     /// input tensor is equal to zero.
-     /// &RETURNS&: Tensor
-     fn where_cond(&self, on_true: &Self, on_false: &Self) -> PyResult<Self> {
-         Ok(PyTensor(
-             self.0.where_cond(on_true, on_false).map_err(wrap_err)?,
-         ))
-     }
-
-     /// Add two tensors.
-     /// &RETURNS&: Tensor
-     fn __add__(&self, rhs: &PyTensor) -> PyResult<Self> {
-         Ok(Self(self.0.add(&rhs.0).map_err(wrap_err)?))
-     }
-
-     /// Multiply two tensors.
-     /// &RETURNS&: Tensor
-     fn __mul__(&self, rhs: &PyTensor) -> PyResult<Self> {
-         Ok(Self(self.0.mul(&rhs.0).map_err(wrap_err)?))
-     }
-
-     /// Subtract two tensors.
-     /// &RETURNS&: Tensor
-     fn __sub__(&self, rhs: &PyTensor) -> PyResult<Self> {
-         Ok(Self(self.0.sub(&rhs.0).map_err(wrap_err)?))
-     }
-
-     /// Divide two tensors.
-     /// &RETURNS&: Tensor
-     fn __truediv__(&self, rhs: &PyTensor) -> PyResult<Self> {
-         Ok(Self(self.0.div(&rhs.0).map_err(wrap_err)?))
-     }
-
-     /// Reshapes the tensor to the given shape.
-     /// &RETURNS&: Tensor
-     fn reshape(&self, shape: Vec<usize>) -> PyResult<Self> {
-         Ok(PyTensor(self.0.reshape(shape).map_err(wrap_err)?))
-     }
-
-     /// Broadcasts the tensor to the given shape.
-     /// &RETURNS&: Tensor
-     fn broadcast_as(&self, shape: Vec<usize>) -> PyResult<Self> {
-         Ok(PyTensor(self.0.broadcast_as(shape).map_err(wrap_err)?))
-     }
-
-     /// Broadcasts the tensor to the given shape, adding new dimensions on the left.
-     /// &RETURNS&: Tensor
-     fn broadcast_left(&self, shape: Vec<usize>) -> PyResult<Self> {
-         Ok(PyTensor(self.0.broadcast_left(shape).map_err(wrap_err)?))
-     }
-
-     /// Creates a new tensor with the specified dimension removed if its size was one.
-     /// &RETURNS&: Tensor
-     fn squeeze(&self, dim: i64) -> PyResult<Self> {
-         let dim = actual_dim(self, dim).map_err(wrap_err)?;
-         Ok(PyTensor(self.0.squeeze(dim).map_err(wrap_err)?))
-     }
-
-     /// Creates a new tensor with a dimension of size one inserted at the specified position.
-     /// &RETURNS&: Tensor
-     fn unsqueeze(&self, dim: usize) -> PyResult<Self> {
-         Ok(PyTensor(self.0.unsqueeze(dim).map_err(wrap_err)?))
-     }
-
-     /// Gets the value at the specified index.
-     /// &RETURNS&: Tensor
-     fn get(&self, index: i64) -> PyResult<Self> {
-         let index = actual_index(self, 0, index).map_err(wrap_err)?;
-         Ok(PyTensor(self.0.get(index).map_err(wrap_err)?))
-     }
-
-     /// Returns a tensor that is a transposed version of the input, the given dimensions are swapped.
-     /// &RETURNS&: Tensor
-     fn transpose(&self, dim1: usize, dim2: usize) -> PyResult<Self> {
-         Ok(PyTensor(self.0.transpose(dim1, dim2).map_err(wrap_err)?))
-     }
-
-     /// Returns a new tensor that is a narrowed version of the input, the dimension `dim`
-     /// ranges from `start` to `start + len`.
-     /// &RETURNS&: Tensor
-     fn narrow(&self, dim: i64, start: i64, len: usize) -> PyResult<Self> {
-         let dim = actual_dim(self, dim).map_err(wrap_err)?;
-         let start = actual_index(self, dim, start).map_err(wrap_err)?;
-         Ok(PyTensor(self.0.narrow(dim, start, len).map_err(wrap_err)?))
-     }
-
-     /// Returns the indices of the maximum value(s) across the selected dimension.
-     /// &RETURNS&: Tensor
-     fn argmax_keepdim(&self, dim: i64) -> PyResult<Self> {
-         let dim = actual_dim(self, dim).map_err(wrap_err)?;
-         Ok(PyTensor(self.0.argmax_keepdim(dim).map_err(wrap_err)?))
-     }
-
-     /// Returns the indices of the minimum value(s) across the selected dimension.
-     /// &RETURNS&: Tensor
-     fn argmin_keepdim(&self, dim: i64) -> PyResult<Self> {
-         let dim = actual_dim(self, dim).map_err(wrap_err)?;
-         Ok(PyTensor(self.0.argmin_keepdim(dim).map_err(wrap_err)?))
-     }
-
-     /// Gathers the maximum value across the selected dimension.
-     /// &RETURNS&: Tensor
-     fn max_keepdim(&self, dim: i64) -> PyResult<Self> {
-         let dim = actual_dim(self, dim).map_err(wrap_err)?;
-         Ok(PyTensor(self.0.max_keepdim(dim).map_err(wrap_err)?))
-     }
-
-     /// Gathers the minimum value across the selected dimension.
-     /// &RETURNS&: Tensor
-     fn min_keepdim(&self, dim: i64) -> PyResult<Self> {
-         let dim = actual_dim(self, dim).map_err(wrap_err)?;
-         Ok(PyTensor(self.0.min_keepdim(dim).map_err(wrap_err)?))
-     }
-
-     // fn eq(&self, rhs: &Self) -> PyResult<Self> {
-     //     Ok(PyTensor(self.0.eq(rhs).map_err(wrap_err)?))
-     // }
-
-     // fn ne(&self, rhs: &Self) -> PyResult<Self> {
-     //     Ok(PyTensor(self.0.ne(rhs).map_err(wrap_err)?))
-     // }
-
-     // fn lt(&self, rhs: &Self) -> PyResult<Self> {
-     //     Ok(PyTensor(self.0.lt(rhs).map_err(wrap_err)?))
-     // }
-
-     // fn gt(&self, rhs: &Self) -> PyResult<Self> {
-     //     Ok(PyTensor(self.0.gt(rhs).map_err(wrap_err)?))
-     // }
-
-     // fn ge(&self, rhs: &Self) -> PyResult<Self> {
-     //     Ok(PyTensor(self.0.ge(rhs).map_err(wrap_err)?))
-     // }
-
-     // fn le(&self, rhs: &Self) -> PyResult<Self> {
-     //     Ok(PyTensor(self.0.le(rhs).map_err(wrap_err)?))
-     // }
-
-     /// Returns the sum of the tensor.
-     /// &RETURNS&: Tensor
-     fn sum_all(&self) -> PyResult<Self> {
-         Ok(PyTensor(self.0.sum_all().map_err(wrap_err)?))
-     }
-
-     /// Returns the mean of the tensor.
-     /// &RETURNS&: Tensor
-     fn mean_all(&self) -> PyResult<Self> {
-         let elements = self.0.elem_count();
-         let sum = self.0.sum_all().map_err(wrap_err)?;
-         let mean = (sum / elements as f64).map_err(wrap_err)?;
-         Ok(PyTensor(mean))
-     }
-
-     /// Flattens the tensor on the dimension indexes from `dim` (inclusive) to the last dimension.
-     /// &RETURNS&: Tensor
-     fn flatten_from(&self, dim: i64) -> PyResult<Self> {
-         let dim = actual_dim(self, dim).map_err(wrap_err)?;
-         Ok(PyTensor(self.0.flatten_from(dim).map_err(wrap_err)?))
-     }
-
-     /// Flattens the tensor on the dimension indexes from `0` to `dim` (inclusive).
-     /// &RETURNS&: Tensor
-     fn flatten_to(&self, dim: i64) -> PyResult<Self> {
-         let dim = actual_dim(self, dim).map_err(wrap_err)?;
-         Ok(PyTensor(self.0.flatten_to(dim).map_err(wrap_err)?))
-     }
-
-     /// Flattens the tensor into a 1D tensor.
-     /// &RETURNS&: Tensor
-     fn flatten_all(&self) -> PyResult<Self> {
-         Ok(PyTensor(self.0.flatten_all().map_err(wrap_err)?))
-     }
-
-     /// Transposes the tensor.
-     /// &RETURNS&: Tensor
-     fn t(&self) -> PyResult<Self> {
-         Ok(PyTensor(self.0.t().map_err(wrap_err)?))
-     }
-
-     /// Makes the tensor contiguous in memory.
-     /// &RETURNS&: Tensor
-     fn contiguous(&self) -> PyResult<Self> {
-         Ok(PyTensor(self.0.contiguous().map_err(wrap_err)?))
-     }
-
-     /// Returns true if the tensor is contiguous in C order.
-     /// &RETURNS&: bool
-     fn is_contiguous(&self) -> bool {
-         self.0.is_contiguous()
-     }
-
-     /// Returns true if the tensor is contiguous in Fortran order.
-     /// &RETURNS&: bool
-     fn is_fortran_contiguous(&self) -> bool {
-         self.0.is_fortran_contiguous()
-     }
-
-     /// Detach the tensor from the computation graph.
-     /// &RETURNS&: Tensor
-     fn detach(&self) -> PyResult<Self> {
-         Ok(PyTensor(self.0.detach().map_err(wrap_err)?))
-     }
-
-     /// Returns a copy of the tensor.
-     /// &RETURNS&: Tensor
-     fn copy(&self) -> PyResult<Self> {
-         Ok(PyTensor(self.0.copy().map_err(wrap_err)?))
-     }
-
-     /// Convert the tensor to a new dtype.
-     /// &RETURNS&: Tensor
-     fn to_dtype(&self, dtype: magnus::Symbol) -> PyResult<Self> {
-         let dtype = PyDType::from_pyobject(dtype)?;
-         Ok(PyTensor(self.0.to_dtype(dtype.0).map_err(wrap_err)?))
-     }
-
-     /// Move the tensor to a new device.
-     /// &RETURNS&: Tensor
-     fn to_device(&self, device: PyDevice) -> PyResult<Self> {
-         let device = device.as_device()?;
-         Ok(PyTensor(self.0.to_device(&device).map_err(wrap_err)?))
-     }
- }
-
- impl PyTensor {
-     // fn cat(tensors: Vec<PyTensor>, dim: i64) -> PyResult<PyTensor> {
-     //     if tensors.is_empty() {
-     //         return Err(Error::new(
-     //             magnus::exception::arg_error(),
-     //             "empty input to cat",
-     //         ));
-     //     }
-     //     let dim = actual_dim(&tensors[0].0, dim).map_err(wrap_err)?;
-     //     let tensors = tensors.into_iter().map(|t| t.0).collect::<Vec<_>>();
-     //     let tensor = Tensor::cat(&tensors, dim).map_err(wrap_err)?;
-     //     Ok(PyTensor(tensor))
-     // }
-
-     // fn stack(tensors: Vec<PyTensor>, dim: usize) -> PyResult<Self> {
-     //     let tensors = tensors.into_iter().map(|t| t.0).collect::<Vec<_>>();
-     //     let tensor = Tensor::stack(&tensors, dim).map_err(wrap_err)?;
-     //     Ok(Self(tensor))
-     // }
-
-     /// Creates a new tensor with random values.
-     /// &RETURNS&: Tensor
-     fn rand(shape: Vec<usize>) -> PyResult<Self> {
-         let device = PyDevice::Cpu.as_device()?;
-         Ok(Self(
-             Tensor::rand(0f32, 1f32, shape, &device).map_err(wrap_err)?,
-         ))
-     }
-
-     /// Creates a new tensor with random values from a normal distribution.
-     /// &RETURNS&: Tensor
-     fn randn(shape: Vec<usize>) -> PyResult<Self> {
-         let device = PyDevice::Cpu.as_device()?;
-         Ok(Self(
-             Tensor::randn(0f32, 1f32, shape, &device).map_err(wrap_err)?,
-         ))
-     }
-
-     /// Creates a new tensor filled with ones.
-     /// &RETURNS&: Tensor
-     fn ones(shape: Vec<usize>) -> PyResult<Self> {
-         let device = PyDevice::Cpu.as_device()?;
-         Ok(Self(
-             Tensor::ones(shape, DType::F32, &device).map_err(wrap_err)?,
-         ))
-     }
-     /// Creates a new tensor filled with zeros.
-     /// &RETURNS&: Tensor
-     fn zeros(shape: Vec<usize>) -> PyResult<Self> {
-         let device = PyDevice::Cpu.as_device()?;
-         Ok(Self(
-             Tensor::zeros(shape, DType::F32, &device).map_err(wrap_err)?,
-         ))
-     }
- }
-
- #[derive(Debug)]
- #[magnus::wrap(class = "Candle::QTensor", free_immediately, size)]
- /// A quantized tensor.
- struct PyQTensor(Arc<QTensor>);
-
- impl std::ops::Deref for PyQTensor {
-     type Target = QTensor;
-
-     fn deref(&self) -> &Self::Target {
-         self.0.as_ref()
-     }
- }
-
- impl PyQTensor {
-     /// Gets the tensor's quantized dtype.
-     /// &RETURNS&: str
-     fn ggml_dtype(&self) -> String {
-         format!("{:?}", self.0.dtype())
-     }
-
-     /// Gets the rank of the tensor.
-     /// &RETURNS&: int
-     fn rank(&self) -> usize {
-         self.0.rank()
-     }
-
-     /// Gets the shape of the tensor.
-     /// &RETURNS&: Tuple[int]
-     fn shape(&self) -> Vec<usize> {
-         self.0.shape().dims().to_vec()
-     }
-
-     fn __repr__(&self) -> String {
-         format!("{:?}", self.0)
-     }
-
-     fn __str__(&self) -> String {
-         self.__repr__()
-     }
-
-     /// Dequantizes the tensor.
-     /// &RETURNS&: Tensor
-     fn dequantize(&self) -> PyResult<PyTensor> {
-         let tensor = self.0.dequantize(&Device::Cpu).map_err(wrap_err)?;
-         Ok(PyTensor(tensor))
-     }
-
-     // fn matmul_t(&self, lhs: &PyTensor) -> PyResult<PyTensor> {
-     //     let qmatmul = ::candle_core::quantized::QMatMul::from_arc(self.0.clone());
-     //     let res = qmatmul.forward(lhs).map_err(wrap_err)?;
-     //     Ok(PyTensor(res))
-     // }
- }
-
- /// Returns true if the 'cuda' backend is available.
- /// &RETURNS&: bool
- fn cuda_is_available() -> bool {
-     candle_core::utils::cuda_is_available()
- }
-
- /// Returns true if candle was compiled with 'accelerate' support.
- /// &RETURNS&: bool
- fn has_accelerate() -> bool {
-     candle_core::utils::has_accelerate()
- }
-
- /// Returns true if candle was compiled with MKL support.
- /// &RETURNS&: bool
- fn has_mkl() -> bool {
-     candle_core::utils::has_mkl()
- }
-
- /// Returns the number of threads used by candle.
- /// &RETURNS&: int
- fn get_num_threads() -> usize {
-     candle_core::utils::get_num_threads()
- }
-
- fn candle_utils(rb_candle: magnus::RModule) -> Result<(), Error> {
-     let rb_utils = rb_candle.define_module("Utils")?;
-     rb_utils.define_singleton_method("cuda_is_available", function!(cuda_is_available, 0))?;
-     rb_utils.define_singleton_method("get_num_threads", function!(get_num_threads, 0))?;
-     rb_utils.define_singleton_method("has_accelerate", function!(has_accelerate, 0))?;
-     rb_utils.define_singleton_method("has_mkl", function!(has_mkl, 0))?;
-     Ok(())
- }
-
- /// Applies the Softmax function to a given tensor.
- /// &RETURNS&: Tensor
- fn softmax(tensor: PyTensor, dim: i64) -> PyResult<PyTensor> {
-     let dim = actual_dim(&tensor, dim).map_err(wrap_err)?;
-     let sm = candle_nn::ops::softmax(&tensor.0, dim).map_err(wrap_err)?;
-     Ok(PyTensor(sm))
- }
-
- /// Applies the Sigmoid Linear Unit (SiLU) function to a given tensor.
- /// &RETURNS&: Tensor
- fn silu(tensor: PyTensor) -> PyResult<PyTensor> {
-     let s = candle_nn::ops::silu(&tensor.0).map_err(wrap_err)?;
-     Ok(PyTensor(s))
- }
+ pub mod model;
 
  #[magnus::init]
- fn init(ruby: &Ruby) -> PyResult<()> {
+ fn init(ruby: &Ruby) -> RbResult<()> {
      let rb_candle = ruby.define_module("Candle")?;
      candle_utils(rb_candle)?;
      let rb_tensor = rb_candle.define_class("Tensor", Ruby::class_object(ruby))?;
-     rb_tensor.define_singleton_method("new", function!(PyTensor::new, 2))?;
-     // rb_tensor.define_singleton_method("cat", function!(PyTensor::cat, 2))?;
-     // rb_tensor.define_singleton_method("stack", function!(PyTensor::stack, 2))?;
-     rb_tensor.define_singleton_method("rand", function!(PyTensor::rand, 1))?;
-     rb_tensor.define_singleton_method("randn", function!(PyTensor::randn, 1))?;
-     rb_tensor.define_singleton_method("ones", function!(PyTensor::ones, 1))?;
-     rb_tensor.define_singleton_method("zeros", function!(PyTensor::zeros, 1))?;
-     rb_tensor.define_method("shape", method!(PyTensor::shape, 0))?;
-     rb_tensor.define_method("stride", method!(PyTensor::stride, 0))?;
-     rb_tensor.define_method("dtype", method!(PyTensor::dtype, 0))?;
-     rb_tensor.define_method("device", method!(PyTensor::device, 0))?;
-     rb_tensor.define_method("rank", method!(PyTensor::rank, 0))?;
-     rb_tensor.define_method("sin", method!(PyTensor::sin, 0))?;
-     rb_tensor.define_method("cos", method!(PyTensor::cos, 0))?;
-     rb_tensor.define_method("log", method!(PyTensor::log, 0))?;
-     rb_tensor.define_method("sqr", method!(PyTensor::sqr, 0))?;
-     rb_tensor.define_method("sqrt", method!(PyTensor::sqrt, 0))?;
-     rb_tensor.define_method("recip", method!(PyTensor::recip, 0))?;
-     rb_tensor.define_method("exp", method!(PyTensor::exp, 0))?;
-     rb_tensor.define_method("powf", method!(PyTensor::powf, 1))?;
-     rb_tensor.define_method("index_select", method!(PyTensor::index_select, 2))?;
-     rb_tensor.define_method("matmul", method!(PyTensor::matmul, 1))?;
-     rb_tensor.define_method("broadcast_add", method!(PyTensor::broadcast_add, 1))?;
-     rb_tensor.define_method("broadcast_sub", method!(PyTensor::broadcast_sub, 1))?;
-     rb_tensor.define_method("broadcast_mul", method!(PyTensor::broadcast_mul, 1))?;
-     rb_tensor.define_method("broadcast_div", method!(PyTensor::broadcast_div, 1))?;
-     rb_tensor.define_method("where_cond", method!(PyTensor::where_cond, 2))?;
-     rb_tensor.define_method("+", method!(PyTensor::__add__, 1))?;
-     rb_tensor.define_method("*", method!(PyTensor::__mul__, 1))?;
-     rb_tensor.define_method("-", method!(PyTensor::__sub__, 1))?;
-     rb_tensor.define_method("reshape", method!(PyTensor::reshape, 1))?;
-     rb_tensor.define_method("broadcast_as", method!(PyTensor::broadcast_as, 1))?;
-     rb_tensor.define_method("broadcast_left", method!(PyTensor::broadcast_left, 1))?;
-     rb_tensor.define_method("squeeze", method!(PyTensor::squeeze, 1))?;
-     rb_tensor.define_method("unsqueeze", method!(PyTensor::unsqueeze, 1))?;
-     rb_tensor.define_method("get", method!(PyTensor::get, 1))?;
-     rb_tensor.define_method("transpose", method!(PyTensor::transpose, 2))?;
-     rb_tensor.define_method("narrow", method!(PyTensor::narrow, 3))?;
-     rb_tensor.define_method("argmax_keepdim", method!(PyTensor::argmax_keepdim, 1))?;
-     rb_tensor.define_method("argmin_keepdim", method!(PyTensor::argmin_keepdim, 1))?;
-     rb_tensor.define_method("max_keepdim", method!(PyTensor::max_keepdim, 1))?;
-     rb_tensor.define_method("min_keepdim", method!(PyTensor::min_keepdim, 1))?;
-     // rb_tensor.define_method("eq", method!(PyTensor::eq, 1))?;
-     // rb_tensor.define_method("ne", method!(PyTensor::ne, 1))?;
-     // rb_tensor.define_method("lt", method!(PyTensor::lt, 1))?;
-     // rb_tensor.define_method("gt", method!(PyTensor::gt, 1))?;
-     // rb_tensor.define_method("ge", method!(PyTensor::ge, 1))?;
-     // rb_tensor.define_method("le", method!(PyTensor::le, 1))?;
-     rb_tensor.define_method("sum_all", method!(PyTensor::sum_all, 0))?;
-     rb_tensor.define_method("mean_all", method!(PyTensor::mean_all, 0))?;
-     rb_tensor.define_method("flatten_from", method!(PyTensor::flatten_from, 1))?;
-     rb_tensor.define_method("flatten_to", method!(PyTensor::flatten_to, 1))?;
-     rb_tensor.define_method("flatten_all", method!(PyTensor::flatten_all, 0))?;
-     rb_tensor.define_method("t", method!(PyTensor::t, 0))?;
-     rb_tensor.define_method("contiguous", method!(PyTensor::contiguous, 0))?;
-     rb_tensor.define_method("is_contiguous", method!(PyTensor::is_contiguous, 0))?;
+     rb_tensor.define_singleton_method("new", function!(RbTensor::new, 2))?;
+     // rb_tensor.define_singleton_method("cat", function!(RbTensor::cat, 2))?;
+     // rb_tensor.define_singleton_method("stack", function!(RbTensor::stack, 2))?;
+     rb_tensor.define_singleton_method("rand", function!(RbTensor::rand, 1))?;
+     rb_tensor.define_singleton_method("randn", function!(RbTensor::randn, 1))?;
+     rb_tensor.define_singleton_method("ones", function!(RbTensor::ones, 1))?;
+     rb_tensor.define_singleton_method("zeros", function!(RbTensor::zeros, 1))?;
+     rb_tensor.define_method("values", method!(RbTensor::values, 0))?;
+     rb_tensor.define_method("shape", method!(RbTensor::shape, 0))?;
+     rb_tensor.define_method("stride", method!(RbTensor::stride, 0))?;
+     rb_tensor.define_method("dtype", method!(RbTensor::dtype, 0))?;
+     rb_tensor.define_method("device", method!(RbTensor::device, 0))?;
+     rb_tensor.define_method("rank", method!(RbTensor::rank, 0))?;
+     rb_tensor.define_method("elem_count", method!(RbTensor::elem_count, 0))?;
+     rb_tensor.define_method("sin", method!(RbTensor::sin, 0))?;
+     rb_tensor.define_method("cos", method!(RbTensor::cos, 0))?;
+     rb_tensor.define_method("log", method!(RbTensor::log, 0))?;
+     rb_tensor.define_method("sqr", method!(RbTensor::sqr, 0))?;
+     rb_tensor.define_method("sqrt", method!(RbTensor::sqrt, 0))?;
+     rb_tensor.define_method("recip", method!(RbTensor::recip, 0))?;
+     rb_tensor.define_method("exp", method!(RbTensor::exp, 0))?;
+     rb_tensor.define_method("powf", method!(RbTensor::powf, 1))?;
+     rb_tensor.define_method("index_select", method!(RbTensor::index_select, 2))?;
+     rb_tensor.define_method("matmul", method!(RbTensor::matmul, 1))?;
+     rb_tensor.define_method("broadcast_add", method!(RbTensor::broadcast_add, 1))?;
+     rb_tensor.define_method("broadcast_sub", method!(RbTensor::broadcast_sub, 1))?;
+     rb_tensor.define_method("broadcast_mul", method!(RbTensor::broadcast_mul, 1))?;
+     rb_tensor.define_method("broadcast_div", method!(RbTensor::broadcast_div, 1))?;
+     rb_tensor.define_method("where_cond", method!(RbTensor::where_cond, 2))?;
+     rb_tensor.define_method("+", method!(RbTensor::__add__, 1))?;
+     rb_tensor.define_method("*", method!(RbTensor::__mul__, 1))?;
+     rb_tensor.define_method("-", method!(RbTensor::__sub__, 1))?;
+     rb_tensor.define_method("reshape", method!(RbTensor::reshape, 1))?;
+     rb_tensor.define_method("broadcast_as", method!(RbTensor::broadcast_as, 1))?;
+     rb_tensor.define_method("broadcast_left", method!(RbTensor::broadcast_left, 1))?;
+     rb_tensor.define_method("squeeze", method!(RbTensor::squeeze, 1))?;
+     rb_tensor.define_method("unsqueeze", method!(RbTensor::unsqueeze, 1))?;
+     rb_tensor.define_method("get", method!(RbTensor::get, 1))?;
+     rb_tensor.define_method("[]", method!(RbTensor::get, 1))?;
+     rb_tensor.define_method("transpose", method!(RbTensor::transpose, 2))?;
+     rb_tensor.define_method("narrow", method!(RbTensor::narrow, 3))?;
+     rb_tensor.define_method("argmax_keepdim", method!(RbTensor::argmax_keepdim, 1))?;
+     rb_tensor.define_method("argmin_keepdim", method!(RbTensor::argmin_keepdim, 1))?;
+     rb_tensor.define_method("max_keepdim", method!(RbTensor::max_keepdim, 1))?;
+     rb_tensor.define_method("min_keepdim", method!(RbTensor::min_keepdim, 1))?;
+     // rb_tensor.define_method("eq", method!(RbTensor::eq, 1))?;
+     // rb_tensor.define_method("ne", method!(RbTensor::ne, 1))?;
+     // rb_tensor.define_method("lt", method!(RbTensor::lt, 1))?;
+     // rb_tensor.define_method("gt", method!(RbTensor::gt, 1))?;
+     // rb_tensor.define_method("ge", method!(RbTensor::ge, 1))?;
+     // rb_tensor.define_method("le", method!(RbTensor::le, 1))?;
+     rb_tensor.define_method("sum_all", method!(RbTensor::sum_all, 0))?;
+     rb_tensor.define_method("mean_all", method!(RbTensor::mean_all, 0))?;
+     rb_tensor.define_method("flatten_from", method!(RbTensor::flatten_from, 1))?;
+     rb_tensor.define_method("flatten_to", method!(RbTensor::flatten_to, 1))?;
+     rb_tensor.define_method("flatten_all", method!(RbTensor::flatten_all, 0))?;
+     rb_tensor.define_method("t", method!(RbTensor::t, 0))?;
+     rb_tensor.define_method("contiguous", method!(RbTensor::contiguous, 0))?;
+     rb_tensor.define_method("is_contiguous", method!(RbTensor::is_contiguous, 0))?;
      rb_tensor.define_method(
          "is_fortran_contiguous",
-         method!(PyTensor::is_fortran_contiguous, 0),
+         method!(RbTensor::is_fortran_contiguous, 0),
      )?;
-     rb_tensor.define_method("detach", method!(PyTensor::detach, 0))?;
-     rb_tensor.define_method("copy", method!(PyTensor::copy, 0))?;
-     rb_tensor.define_method("to_dtype", method!(PyTensor::to_dtype, 1))?;
-     rb_tensor.define_method("to_device", method!(PyTensor::to_device, 1))?;
-     rb_tensor.define_method("to_s", method!(PyTensor::__str__, 0))?;
-     rb_tensor.define_method("inspect", method!(PyTensor::__repr__, 0))?;
+     rb_tensor.define_method("detach", method!(RbTensor::detach, 0))?;
+     rb_tensor.define_method("copy", method!(RbTensor::copy, 0))?;
+     rb_tensor.define_method("to_dtype", method!(RbTensor::to_dtype, 1))?;
+     rb_tensor.define_method("to_device", method!(RbTensor::to_device, 1))?;
+     rb_tensor.define_method("to_s", method!(RbTensor::__str__, 0))?;
+     rb_tensor.define_method("inspect", method!(RbTensor::__repr__, 0))?;
+
      let rb_dtype = rb_candle.define_class("DType", Ruby::class_object(ruby))?;
-     rb_dtype.define_method("to_s", method!(PyDType::__str__, 0))?;
-     rb_dtype.define_method("inspect", method!(PyDType::__repr__, 0))?;
+     rb_dtype.define_method("to_s", method!(RbDType::__str__, 0))?;
+     rb_dtype.define_method("inspect", method!(RbDType::__repr__, 0))?;
+
      let rb_device = rb_candle.define_class("Device", Ruby::class_object(ruby))?;
-     rb_device.define_method("to_s", method!(PyDevice::__str__, 0))?;
-     rb_device.define_method("inspect", method!(PyDevice::__repr__, 0))?;
+     rb_device.define_method("to_s", method!(RbDevice::__str__, 0))?;
+     rb_device.define_method("inspect", method!(RbDevice::__repr__, 0))?;
+
      let rb_qtensor = rb_candle.define_class("QTensor", Ruby::class_object(ruby))?;
-     rb_qtensor.define_method("ggml_dtype", method!(PyQTensor::ggml_dtype, 0))?;
-     rb_qtensor.define_method("rank", method!(PyQTensor::rank, 0))?;
-     rb_qtensor.define_method("shape", method!(PyQTensor::shape, 0))?;
-     rb_qtensor.define_method("dequantize", method!(PyQTensor::dequantize, 0))?;
+     rb_qtensor.define_method("ggml_dtype", method!(RbQTensor::ggml_dtype, 0))?;
+     rb_qtensor.define_method("rank", method!(RbQTensor::rank, 0))?;
+     rb_qtensor.define_method("shape", method!(RbQTensor::shape, 0))?;
+     rb_qtensor.define_method("dequantize", method!(RbQTensor::dequantize, 0))?;
+
+     let rb_model = rb_candle.define_class("Model", Ruby::class_object(ruby))?;
+     rb_model.define_singleton_method("new", function!(RbModel::new, 0))?;
+     rb_model.define_method("embedding", method!(RbModel::embedding, 1))?;
+     rb_model.define_method("to_s", method!(RbModel::__str__, 0))?;
+     rb_model.define_method("inspect", method!(RbModel::__repr__, 0))?;
+
      Ok(())
  }
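
For reference, a minimal usage sketch of the Ruby API this version registers. Only the method names and arities are confirmed by the registrations above (Model.new with arity 0, Model#embedding with arity 1, and the new Tensor#values, Tensor#elem_count, and Tensor#[] on the Tensor class); the argument and return types noted in the comments are assumptions, not shown in this diff.

    require "candle"

    # Tensor.new is registered with arity 2: an array of floats plus a dtype
    # symbol (the second argument maps to an Option in the Rust signature).
    t = Candle::Tensor.new([1.0, 2.0, 3.0], :f32)
    t.shape       # => [3]
    t.values      # new in 0.0.4; presumably the elements as a Ruby array
    t.elem_count  # new in 0.0.4; presumably the total number of elements
    t[0]          # new in 0.0.4; registered as an alias for #get

    # Candle::Model is new in 0.0.4. Its constructor takes no arguments and
    # #embedding takes one; the input type (e.g. a string) and the return
    # type (presumably a Candle::Tensor) are assumptions.
    model = Candle::Model.new
    embedding = model.embedding("hello world")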