torch-rb 0.2.7 → 0.3.0
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +7 -2
- data/README.md +4 -1
- data/ext/torch/ext.cpp +4 -3
- data/lib/torch.rb +1 -5
- data/lib/torch/native/native_functions.yaml +654 -660
- data/lib/torch/nn/conv2d.rb +0 -1
- data/lib/torch/optim/optimizer.rb +6 -4
- data/lib/torch/tensor.rb +10 -11
- data/lib/torch/version.rb +1 -1
- metadata +2 -2
data/lib/torch/nn/conv2d.rb
CHANGED
@@ -18,7 +18,6 @@ module Torch
|
|
18
18
|
F.conv2d(input, @weight, @bias, @stride, @padding, @dilation, @groups)
|
19
19
|
end
|
20
20
|
|
21
|
-
# TODO add more parameters
|
22
21
|
def extra_inspect
|
23
22
|
s = String.new("%{in_channels}, %{out_channels}, kernel_size: %{kernel_size}, stride: %{stride}")
|
24
23
|
s += ", padding: %{padding}" if @padding != [0] * @padding.size
|
@@ -32,9 +32,11 @@ module Torch
|
|
32
32
|
end
|
33
33
|
|
34
34
|
def state_dict
|
35
|
+
raise NotImplementedYet
|
36
|
+
|
35
37
|
pack_group = lambda do |group|
|
36
|
-
packed = group.select { |k, _| k != :params }.to_h
|
37
|
-
packed[:params] = group[:params].map { |p| p.object_id }
|
38
|
+
packed = group.select { |k, _| k != :params }.map { |k, v| [k.to_s, v] }.to_h
|
39
|
+
packed["params"] = group[:params].map { |p| p.object_id }
|
38
40
|
packed
|
39
41
|
end
|
40
42
|
|
@@ -42,8 +44,8 @@ module Torch
|
|
42
44
|
packed_state = @state.map { |k, v| [k.is_a?(Tensor) ? k.object_id : k, v] }.to_h
|
43
45
|
|
44
46
|
{
|
45
|
-
state: packed_state,
|
46
|
-
param_groups: param_groups
|
47
|
+
"state" => packed_state,
|
48
|
+
"param_groups" => param_groups
|
47
49
|
}
|
48
50
|
end
|
49
51
|
|
data/lib/torch/tensor.rb
CHANGED
@@ -47,10 +47,15 @@ module Torch
|
|
47
47
|
end
|
48
48
|
end
|
49
49
|
|
50
|
-
|
51
|
-
|
50
|
+
def to(device = nil, dtype: nil, non_blocking: false, copy: false)
|
51
|
+
device ||= self.device
|
52
52
|
device = Device.new(device) if device.is_a?(String)
|
53
|
-
|
53
|
+
|
54
|
+
dtype ||= self.dtype
|
55
|
+
enum = DTYPE_TO_ENUM[dtype]
|
56
|
+
raise Error, "Unknown type: #{dtype}" unless enum
|
57
|
+
|
58
|
+
_to(device, enum, non_blocking, copy)
|
54
59
|
end
|
55
60
|
|
56
61
|
def cpu
|
@@ -215,11 +220,11 @@ module Torch
|
|
215
220
|
value = Torch.tensor(value, dtype: dtype) unless value.is_a?(Tensor)
|
216
221
|
|
217
222
|
if index.is_a?(Numeric)
|
218
|
-
|
223
|
+
index_put!([Torch.tensor(index)], value)
|
219
224
|
elsif index.is_a?(Range)
|
220
225
|
finish = index.end
|
221
226
|
finish += 1 unless index.exclude_end?
|
222
|
-
|
227
|
+
_slice_tensor(0, index.begin, finish, 1).copy!(value)
|
223
228
|
elsif index.is_a?(Tensor)
|
224
229
|
index_put!([index], value)
|
225
230
|
else
|
@@ -254,11 +259,5 @@ module Torch
|
|
254
259
|
_clamp_min_(min)
|
255
260
|
_clamp_max_(max)
|
256
261
|
end
|
257
|
-
|
258
|
-
private
|
259
|
-
|
260
|
-
def copy_to(dst, src)
|
261
|
-
dst.copy!(src)
|
262
|
-
end
|
263
262
|
end
|
264
263
|
end
|
data/lib/torch/version.rb
CHANGED
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: torch-rb
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.2.7
|
4
|
+
version: 0.3.0
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- Andrew Kane
|
8
8
|
autorequire:
|
9
9
|
bindir: bin
|
10
10
|
cert_chain: []
|
11
|
-
date: 2020-
|
11
|
+
date: 2020-07-29 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: rice
|