torch-rb 0.2.7 → 0.3.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +7 -2
- data/README.md +4 -1
- data/ext/torch/ext.cpp +4 -3
- data/lib/torch.rb +1 -5
- data/lib/torch/native/native_functions.yaml +654 -660
- data/lib/torch/nn/conv2d.rb +0 -1
- data/lib/torch/optim/optimizer.rb +6 -4
- data/lib/torch/tensor.rb +10 -11
- data/lib/torch/version.rb +1 -1
- metadata +2 -2
data/lib/torch/nn/conv2d.rb
CHANGED
@@ -18,7 +18,6 @@ module Torch
|
|
18
18
|
F.conv2d(input, @weight, @bias, @stride, @padding, @dilation, @groups)
|
19
19
|
end
|
20
20
|
|
21
|
-
# TODO add more parameters
|
22
21
|
def extra_inspect
|
23
22
|
s = String.new("%{in_channels}, %{out_channels}, kernel_size: %{kernel_size}, stride: %{stride}")
|
24
23
|
s += ", padding: %{padding}" if @padding != [0] * @padding.size
|
@@ -32,9 +32,11 @@ module Torch
|
|
32
32
|
end
|
33
33
|
|
34
34
|
def state_dict
|
35
|
+
raise NotImplementedYet
|
36
|
+
|
35
37
|
pack_group = lambda do |group|
|
36
|
-
packed = group.select { |k, _| k != :params }.to_h
|
37
|
-
packed[
|
38
|
+
packed = group.select { |k, _| k != :params }.map { |k, v| [k.to_s, v] }.to_h
|
39
|
+
packed["params"] = group[:params].map { |p| p.object_id }
|
38
40
|
packed
|
39
41
|
end
|
40
42
|
|
@@ -42,8 +44,8 @@ module Torch
|
|
42
44
|
packed_state = @state.map { |k, v| [k.is_a?(Tensor) ? k.object_id : k, v] }.to_h
|
43
45
|
|
44
46
|
{
|
45
|
-
state
|
46
|
-
param_groups
|
47
|
+
"state" => packed_state,
|
48
|
+
"param_groups" => param_groups
|
47
49
|
}
|
48
50
|
end
|
49
51
|
|
data/lib/torch/tensor.rb
CHANGED
@@ -47,10 +47,15 @@ module Torch
|
|
47
47
|
end
|
48
48
|
end
|
49
49
|
|
50
|
-
|
51
|
-
|
50
|
+
def to(device = nil, dtype: nil, non_blocking: false, copy: false)
|
51
|
+
device ||= self.device
|
52
52
|
device = Device.new(device) if device.is_a?(String)
|
53
|
-
|
53
|
+
|
54
|
+
dtype ||= self.dtype
|
55
|
+
enum = DTYPE_TO_ENUM[dtype]
|
56
|
+
raise Error, "Unknown type: #{dtype}" unless enum
|
57
|
+
|
58
|
+
_to(device, enum, non_blocking, copy)
|
54
59
|
end
|
55
60
|
|
56
61
|
def cpu
|
@@ -215,11 +220,11 @@ module Torch
|
|
215
220
|
value = Torch.tensor(value, dtype: dtype) unless value.is_a?(Tensor)
|
216
221
|
|
217
222
|
if index.is_a?(Numeric)
|
218
|
-
|
223
|
+
index_put!([Torch.tensor(index)], value)
|
219
224
|
elsif index.is_a?(Range)
|
220
225
|
finish = index.end
|
221
226
|
finish += 1 unless index.exclude_end?
|
222
|
-
|
227
|
+
_slice_tensor(0, index.begin, finish, 1).copy!(value)
|
223
228
|
elsif index.is_a?(Tensor)
|
224
229
|
index_put!([index], value)
|
225
230
|
else
|
@@ -254,11 +259,5 @@ module Torch
|
|
254
259
|
_clamp_min_(min)
|
255
260
|
_clamp_max_(max)
|
256
261
|
end
|
257
|
-
|
258
|
-
private
|
259
|
-
|
260
|
-
def copy_to(dst, src)
|
261
|
-
dst.copy!(src)
|
262
|
-
end
|
263
262
|
end
|
264
263
|
end
|
data/lib/torch/version.rb
CHANGED
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: torch-rb
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.
|
4
|
+
version: 0.3.0
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- Andrew Kane
|
8
8
|
autorequire:
|
9
9
|
bindir: bin
|
10
10
|
cert_chain: []
|
11
|
-
date: 2020-
|
11
|
+
date: 2020-07-29 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: rice
|