torch-rb 0.16.0 → 0.17.1

Sign up to get free protection for your applications and to get access to all the features.
@@ -0,0 +1,6 @@
1
+ // generated by rake generate:functions
2
+ // do not edit by hand
3
+
4
+ #pragma once
5
+
6
+ void add_fft_functions(Rice::Module& m);
@@ -0,0 +1,6 @@
1
+ // generated by rake generate:functions
2
+ // do not edit by hand
3
+
4
+ #pragma once
5
+
6
+ void add_linalg_functions(Rice::Module& m);
@@ -0,0 +1,6 @@
1
+ // generated by rake generate:functions
2
+ // do not edit by hand
3
+
4
+ #pragma once
5
+
6
+ void add_nn_functions(Rice::Module& m);
@@ -0,0 +1,6 @@
1
+ // generated by rake generate:functions
2
+ // do not edit by hand
3
+
4
+ #pragma once
5
+
6
+ void add_sparse_functions(Rice::Module& m);
@@ -0,0 +1,6 @@
1
+ // generated by rake generate:functions
2
+ // do not edit by hand
3
+
4
+ #pragma once
5
+
6
+ void add_special_functions(Rice::Module& m);
@@ -0,0 +1,6 @@
1
+ // generated by rake generate:functions
2
+ // do not edit by hand
3
+
4
+ #pragma once
5
+
6
+ void add_tensor_functions(Rice::Module& m);
@@ -0,0 +1,6 @@
1
+ // generated by rake generate:functions
2
+ // do not edit by hand
3
+
4
+ #pragma once
5
+
6
+ void add_torch_functions(Rice::Module& m);
data/ext/torch/utils.h CHANGED
@@ -6,7 +6,7 @@
6
6
  #include <rice/stl.hpp>
7
7
 
8
8
  static_assert(
9
- TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 3,
9
+ TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 4,
10
10
  "Incompatible LibTorch version"
11
11
  );
12
12
 
@@ -134,7 +134,7 @@ module Torch
134
134
  raise ArgumentError, "Padding length too large" unless pad.size / 2 <= input.dim
135
135
 
136
136
  if mode == "constant"
137
- return Torch.constant_pad_nd(input, pad, value)
137
+ Torch.constant_pad_nd(input, pad, value)
138
138
  else
139
139
  raise ArgumentError, "Padding mode doesn't take in value argument" unless value == 0
140
140
 
@@ -481,6 +481,16 @@ module Torch
481
481
  Torch.triplet_margin_loss(anchor, positive, negative, margin, p, eps, swap, to_reduction(reduction))
482
482
  end
483
483
 
484
+ def normalize(input, p: 2.0, dim: 1, eps: 1e-12, out: nil)
485
+ if out.nil?
486
+ denom = input.norm(p, dim, keepdim: true).clamp_min(eps).expand_as(input)
487
+ input / denom
488
+ else
489
+ denom = input.norm(p, dim, keepdim: true).clamp_min!(eps).expand_as(input)
490
+ Torch.div(input, denom, out: out)
491
+ end
492
+ end
493
+
484
494
  # vision
485
495
 
486
496
  def interpolate(input, size: nil, scale_factor: nil, mode: "nearest", align_corners: nil, recompute_scale_factor: nil)
@@ -5,10 +5,10 @@ module Torch
5
5
  def in_projection_packed(q, k, v, w, b: nil)
6
6
  e = q.size(-1)
7
7
 
8
- if k.eql? v
9
- if q.eql? k
8
+ if k.eql?(v)
9
+ if q.eql?(k)
10
10
  # self-attention
11
- return linear(q, w, b).chunk(3, dim: -1)
11
+ linear(q, w, b).chunk(3, dim: -1)
12
12
  else
13
13
  # encoder-decoder attention
14
14
  w_q, w_kv = w.split_with_sizes([e, e * 2])
@@ -18,7 +18,7 @@ module Torch
18
18
  b_q, b_kv = b.split_with_sizes([e, e * 2])
19
19
  end
20
20
 
21
- return [linear(q, w_q, b_q), *linear(k, w_kv, b_kv).chunk(2, dim: -1)]
21
+ [linear(q, w_q, b_q), *linear(k, w_kv, b_kv).chunk(2, dim: -1)]
22
22
  end
23
23
  else
24
24
  w_q, w_k, w_v = w.chunk(3)
@@ -28,7 +28,7 @@ module Torch
28
28
  b_q, b_k, b_v = b.chunk(3)
29
29
  end
30
30
 
31
- return [linear(q, w_q, b_q), linear(k, w_k, b_k), linear(v, w_v, b_v)]
31
+ [linear(q, w_q, b_q), linear(k, w_k, b_k), linear(v, w_v, b_v)]
32
32
  end
33
33
  end
34
34
 
@@ -10,16 +10,23 @@ module Torch
10
10
  @parameters = {}
11
11
  @buffers = {}
12
12
  @modules = {}
13
+ @non_persistent_buffers_set = Set.new
13
14
  end
14
15
 
15
16
  def forward
16
17
  raise NotImplementedError
17
18
  end
18
19
 
19
- def register_buffer(name, tensor)
20
+ def register_buffer(name, tensor, persistent: true)
20
21
  # TODO add checks
21
22
  @buffers[name] = tensor
22
23
  instance_variable_set("@#{name}", tensor)
24
+
25
+ if persistent
26
+ @non_persistent_buffers_set.delete(name)
27
+ else
28
+ @non_persistent_buffers_set << name
29
+ end
23
30
  end
24
31
 
25
32
  def register_parameter(name, param)
@@ -190,8 +197,18 @@ module Torch
190
197
  named_buffers.values
191
198
  end
192
199
 
193
- def named_buffers
194
- @buffers || {}
200
+ # TODO set recurse: true in 0.18.0
201
+ def named_buffers(prefix: "", recurse: false)
202
+ buffers = {}
203
+ if recurse
204
+ named_children.each do |name, mod|
205
+ buffers.merge!(mod.named_buffers(prefix: "#{prefix}#{name}.", recurse: recurse))
206
+ end
207
+ end
208
+ (@buffers || {}).each do |k, v|
209
+ buffers[[prefix, k].join] = v
210
+ end
211
+ buffers
195
212
  end
196
213
 
197
214
  def children
@@ -390,7 +407,10 @@ module Torch
390
407
  destination[prefix + k] = v
391
408
  end
392
409
  named_buffers.each do |k, v|
393
- destination[prefix + k] = v
410
+ # TODO exclude v.nil?
411
+ if !@non_persistent_buffers_set.include?(k)
412
+ destination[prefix + k] = v
413
+ end
394
414
  end
395
415
  end
396
416
 
data/lib/torch/tensor.rb CHANGED
@@ -57,7 +57,7 @@ module Torch
57
57
  if shape.empty?
58
58
  arr
59
59
  else
60
- shape[1..-1].reverse.each do |dim|
60
+ shape[1..-1].reverse_each do |dim|
61
61
  arr = arr.each_slice(dim)
62
62
  end
63
63
  arr.to_a
@@ -132,9 +132,13 @@ module Torch
132
132
 
133
133
  # TODO read directly from memory
134
134
  def numo
135
- cls = Torch._dtype_to_numo[dtype]
136
- raise Error, "Cannot convert #{dtype} to Numo" unless cls
137
- cls.from_string(_data_str).reshape(*shape)
135
+ if dtype == :bool
136
+ Numo::UInt8.from_string(_data_str).ne(0).reshape(*shape)
137
+ else
138
+ cls = Torch._dtype_to_numo[dtype]
139
+ raise Error, "Cannot convert #{dtype} to Numo" unless cls
140
+ cls.from_string(_data_str).reshape(*shape)
141
+ end
138
142
  end
139
143
 
140
144
  def requires_grad=(requires_grad)
@@ -160,6 +164,7 @@ module Torch
160
164
  # based on python_variable_indexing.cpp and
161
165
  # https://pytorch.org/cppdocs/notes/tensor_indexing.html
162
166
  def [](*indexes)
167
+ indexes = indexes.map { |v| v.is_a?(Array) ? Torch.tensor(v) : v }
163
168
  _index(indexes)
164
169
  end
165
170
 
@@ -167,6 +172,7 @@ module Torch
167
172
  # https://pytorch.org/cppdocs/notes/tensor_indexing.html
168
173
  def []=(*indexes, value)
169
174
  raise ArgumentError, "Tensor does not support deleting items" if value.nil?
175
+ indexes = indexes.map { |v| v.is_a?(Array) ? Torch.tensor(v) : v }
170
176
  value = Torch.tensor(value, dtype: dtype) unless value.is_a?(Tensor)
171
177
  _index_put_custom(indexes, value)
172
178
  end
data/lib/torch/version.rb CHANGED
@@ -1,3 +1,3 @@
1
1
  module Torch
2
- VERSION = "0.16.0"
2
+ VERSION = "0.17.1"
3
3
  end
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: torch-rb
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.16.0
4
+ version: 0.17.1
5
5
  platform: ruby
6
6
  authors:
7
7
  - Andrew Kane
8
8
  autorequire:
9
9
  bindir: bin
10
10
  cert_chain: []
11
- date: 2024-06-13 00:00:00.000000000 Z
11
+ date: 2024-08-19 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: rice
@@ -16,14 +16,14 @@ dependencies:
16
16
  requirements:
17
17
  - - ">="
18
18
  - !ruby/object:Gem::Version
19
- version: 4.1.0
19
+ version: '4.1'
20
20
  type: :runtime
21
21
  prerelease: false
22
22
  version_requirements: !ruby/object:Gem::Requirement
23
23
  requirements:
24
24
  - - ">="
25
25
  - !ruby/object:Gem::Version
26
- version: 4.1.0
26
+ version: '4.1'
27
27
  description:
28
28
  email: andrew@ankane.org
29
29
  executables: []
@@ -43,17 +43,24 @@ files:
43
43
  - ext/torch/ext.cpp
44
44
  - ext/torch/extconf.rb
45
45
  - ext/torch/fft.cpp
46
+ - ext/torch/fft_functions.h
46
47
  - ext/torch/generator.cpp
47
48
  - ext/torch/ivalue.cpp
48
49
  - ext/torch/linalg.cpp
50
+ - ext/torch/linalg_functions.h
49
51
  - ext/torch/nn.cpp
52
+ - ext/torch/nn_functions.h
50
53
  - ext/torch/random.cpp
51
54
  - ext/torch/ruby_arg_parser.cpp
52
55
  - ext/torch/ruby_arg_parser.h
56
+ - ext/torch/sparse_functions.h
53
57
  - ext/torch/special.cpp
58
+ - ext/torch/special_functions.h
54
59
  - ext/torch/templates.h
55
60
  - ext/torch/tensor.cpp
61
+ - ext/torch/tensor_functions.h
56
62
  - ext/torch/torch.cpp
63
+ - ext/torch/torch_functions.h
57
64
  - ext/torch/utils.h
58
65
  - ext/torch/wrap_outputs.h
59
66
  - lib/torch-rb.rb