red-candle 1.0.0.pre.1 → 1.0.0.pre.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,88 @@
1
+ use magnus::{function, Error, Module, Object};
2
+
3
+ use ::candle_core::Tensor;
4
+
5
+ use crate::ruby::errors::wrap_candle_err;
6
+ use crate::ruby::{Result as RbResult, Tensor as RbTensor};
7
+
8
+ /// Resolves a possibly negative `index` along dimension `dim` of `t` into a `usize` offset.
9
+ ///
10
+ /// A non-negative `index` is returned unchanged when it is smaller than the size reported
11
+ /// by `t.dim(dim)`. A negative `index` counts back from the end of that dimension, so `-1`
12
+ /// maps to `size - 1` and `-size` maps to `0`.
13
+ ///
14
+ /// # Errors
15
+ ///
16
+ /// Propagates any error from `t.dim(dim)`, and fails when `index` is not below the
17
+ /// dimension size or reaches further back than the dimension size.
18
+ pub fn actual_index(t: &Tensor, dim: usize, index: i64) -> candle_core::Result<usize> {
19
+     let dim = t.dim(dim)?;
20
+     if 0 <= index {
21
+         // Compare in u64: a non-negative i64 always fits, whereas `index as usize` would
22
+         // silently truncate on targets where usize is narrower than 64 bits.
23
+         if (dim as u64) <= index as u64 {
24
+             candle_core::bail!("index {index} is too large for tensor dimension {dim}")
25
+         }
26
+         // Lossless: index < dim here, and dim is a usize.
27
+         Ok(index as usize)
28
+     } else {
29
+         // `-index` overflows for i64::MIN (panic in debug, wrap-around in release);
30
+         // `unsigned_abs` yields the magnitude without overflow.
31
+         let distance = index.unsigned_abs();
32
+         if (dim as u64) < distance {
33
+             candle_core::bail!("index {index} is too low for tensor dimension {dim}")
34
+         }
35
+         // Lossless: distance <= dim here, and dim is a usize.
36
+         Ok(dim - distance as usize)
37
+     }
38
+ }
23
+
24
+ /// Resolves a possibly negative dimension index `dim` into a `usize` axis of `t`.
25
+ ///
26
+ /// A non-negative `dim` is returned unchanged when it is smaller than `t.rank()`.
27
+ /// A negative `dim` counts back from the last dimension, so `-1` maps to `rank - 1`
28
+ /// and `-rank` maps to `0`.
29
+ ///
30
+ /// # Errors
31
+ ///
32
+ /// Fails when `dim` is not below the rank or reaches further back than the rank.
33
+ pub fn actual_dim(t: &Tensor, dim: i64) -> candle_core::Result<usize> {
34
+     let rank = t.rank();
35
+     if 0 <= dim {
36
+         // Compare in u64: a non-negative i64 always fits, whereas `dim as usize` would
37
+         // silently truncate on targets where usize is narrower than 64 bits.
38
+         if (rank as u64) <= dim as u64 {
39
+             candle_core::bail!("dimension index {dim} is too large for tensor rank {rank}")
40
+         }
41
+         // Lossless: dim < rank here, and rank is a usize.
42
+         Ok(dim as usize)
43
+     } else {
44
+         // `-dim` overflows for i64::MIN (panic in debug, wrap-around in release);
45
+         // `unsigned_abs` yields the magnitude without overflow.
46
+         let distance = dim.unsigned_abs();
47
+         if (rank as u64) < distance {
48
+             candle_core::bail!("dimension index {dim} is too low for tensor rank {rank}")
49
+         }
50
+         // Lossless: distance <= rank here, and rank is a usize.
51
+         Ok(rank - distance as usize)
52
+     }
53
+ }
39
+
40
+ /// Reports whether the 'cuda' backend is available.
41
+ ///
42
+ /// Thin wrapper around `candle_core::utils::cuda_is_available`.
43
+ /// &RETURNS&: bool
44
+ fn cuda_is_available() -> bool {
45
+     use candle_core::utils;
46
+
47
+     utils::cuda_is_available()
48
+ }
45
+
46
+ /// Reports whether candle was compiled with 'accelerate' support.
47
+ ///
48
+ /// Thin wrapper around `candle_core::utils::has_accelerate`.
49
+ /// &RETURNS&: bool
50
+ fn has_accelerate() -> bool {
51
+     use candle_core::utils;
52
+
53
+     utils::has_accelerate()
54
+ }
51
+
52
+ /// Reports whether candle was compiled with MKL support.
53
+ ///
54
+ /// Thin wrapper around `candle_core::utils::has_mkl`.
55
+ /// &RETURNS&: bool
56
+ fn has_mkl() -> bool {
57
+     use candle_core::utils;
58
+
59
+     utils::has_mkl()
60
+ }
57
+
58
+ /// Returns the number of threads candle uses.
59
+ ///
60
+ /// Thin wrapper around `candle_core::utils::get_num_threads`.
61
+ /// &RETURNS&: int
62
+ fn get_num_threads() -> usize {
63
+     use candle_core::utils;
64
+
65
+     utils::get_num_threads()
66
+ }
63
+
64
+ /// Defines the `Utils` module under `rb_candle` and attaches its singleton methods.
65
+ ///
66
+ /// Registers `cuda_is_available`, `get_num_threads`, `has_accelerate` and `has_mkl`,
67
+ /// each taking zero arguments. The first failing definition aborts the registration
68
+ /// and its error is returned to the caller.
69
+ pub fn candle_utils(rb_candle: magnus::RModule) -> Result<(), Error> {
70
+     let utils_module = rb_candle.define_module("Utils")?;
71
+
72
+     // Backend / build capability probes and thread-count query, in registration order.
73
+     utils_module.define_singleton_method("cuda_is_available", function!(cuda_is_available, 0))?;
74
+     utils_module.define_singleton_method("get_num_threads", function!(get_num_threads, 0))?;
75
+     utils_module.define_singleton_method("has_accelerate", function!(has_accelerate, 0))?;
76
+     utils_module.define_singleton_method("has_mkl", function!(has_mkl, 0))?;
77
+
78
+     Ok(())
79
+ }
72
+
73
+ /// Applies the Softmax function to a given tensor along dimension `dim`.
74
+ ///
75
+ /// `dim` may be negative; it is resolved through `actual_dim` first. Candle errors
76
+ /// are converted with `wrap_candle_err`.
77
+ /// &RETURNS&: Tensor
78
+ #[allow(dead_code)]
79
+ fn softmax(tensor: RbTensor, dim: i64) -> RbResult<RbTensor> {
80
+     let axis = actual_dim(&tensor.0, dim).map_err(wrap_candle_err)?;
81
+     let normalized = candle_nn::ops::softmax(&tensor.0, axis)
82
+         .map(RbTensor)
83
+         .map_err(wrap_candle_err)?;
84
+     Ok(normalized)
85
+ }
81
+
82
+ /// Applies the Sigmoid Linear Unit (SiLU) function to a given tensor.
83
+ ///
84
+ /// Candle errors are converted with `wrap_candle_err`.
85
+ /// &RETURNS&: Tensor
86
+ #[allow(dead_code)]
87
+ fn silu(tensor: RbTensor) -> RbResult<RbTensor> {
88
+     let activated = candle_nn::ops::silu(&tensor.0)
89
+         .map(RbTensor)
90
+         .map_err(wrap_candle_err)?;
91
+     Ok(activated)
92
+ }
@@ -1,3 +1,3 @@
1
1
  module Candle
2
- VERSION = "1.0.0.pre.1"
2
+ VERSION = "1.0.0.pre.2"
3
3
  end
metadata CHANGED
@@ -1,7 +1,7 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: red-candle
3
3
  version: !ruby/object:Gem::Version
4
- version: 1.0.0.pre.1
4
+ version: 1.0.0.pre.2
5
5
  platform: ruby
6
6
  authors:
7
7
  - Christopher Petersen
@@ -38,8 +38,24 @@ files:
38
38
  - Cargo.toml
39
39
  - README.md
40
40
  - ext/candle/Cargo.toml
41
+ - ext/candle/build.rs
41
42
  - ext/candle/extconf.rb
42
43
  - ext/candle/src/lib.rs
44
+ - ext/candle/src/llm/generation_config.rs
45
+ - ext/candle/src/llm/mistral.rs
46
+ - ext/candle/src/llm/mod.rs
47
+ - ext/candle/src/llm/text_generation.rs
48
+ - ext/candle/src/reranker.rs
49
+ - ext/candle/src/ruby/device.rs
50
+ - ext/candle/src/ruby/dtype.rs
51
+ - ext/candle/src/ruby/embedding_model.rs
52
+ - ext/candle/src/ruby/errors.rs
53
+ - ext/candle/src/ruby/llm.rs
54
+ - ext/candle/src/ruby/mod.rs
55
+ - ext/candle/src/ruby/qtensor.rs
56
+ - ext/candle/src/ruby/result.rs
57
+ - ext/candle/src/ruby/tensor.rs
58
+ - ext/candle/src/ruby/utils.rs
43
59
  - lib/candle.rb
44
60
  - lib/candle/build_info.rb
45
61
  - lib/candle/device_utils.rb