torch-rb 0.2.7 → 0.3.0

Sign up to get free protection for your applications and access to all the features.
@@ -18,7 +18,6 @@ module Torch
18
18
  F.conv2d(input, @weight, @bias, @stride, @padding, @dilation, @groups)
19
19
  end
20
20
 
21
- # TODO add more parameters
22
21
  def extra_inspect
23
22
  s = String.new("%{in_channels}, %{out_channels}, kernel_size: %{kernel_size}, stride: %{stride}")
24
23
  s += ", padding: %{padding}" if @padding != [0] * @padding.size
@@ -32,9 +32,11 @@ module Torch
32
32
  end
33
33
 
34
34
  def state_dict
35
+ raise NotImplementedYet
36
+
35
37
  pack_group = lambda do |group|
36
- packed = group.select { |k, _| k != :params }.to_h
37
- packed[:params] = group[:params].map { |p| p.object_id }
38
+ packed = group.select { |k, _| k != :params }.map { |k, v| [k.to_s, v] }.to_h
39
+ packed["params"] = group[:params].map { |p| p.object_id }
38
40
  packed
39
41
  end
40
42
 
@@ -42,8 +44,8 @@ module Torch
42
44
  packed_state = @state.map { |k, v| [k.is_a?(Tensor) ? k.object_id : k, v] }.to_h
43
45
 
44
46
  {
45
- state: packed_state,
46
- param_groups: param_groups
47
+ "state" => packed_state,
48
+ "param_groups" => param_groups
47
49
  }
48
50
  end
49
51
 
@@ -47,10 +47,15 @@ module Torch
47
47
  end
48
48
  end
49
49
 
50
- # TODO support dtype
51
- def to(device, non_blocking: false, copy: false)
50
+ def to(device = nil, dtype: nil, non_blocking: false, copy: false)
51
+ device ||= self.device
52
52
  device = Device.new(device) if device.is_a?(String)
53
- _to(device, _dtype, non_blocking, copy)
53
+
54
+ dtype ||= self.dtype
55
+ enum = DTYPE_TO_ENUM[dtype]
56
+ raise Error, "Unknown type: #{dtype}" unless enum
57
+
58
+ _to(device, enum, non_blocking, copy)
54
59
  end
55
60
 
56
61
  def cpu
@@ -215,11 +220,11 @@ module Torch
215
220
  value = Torch.tensor(value, dtype: dtype) unless value.is_a?(Tensor)
216
221
 
217
222
  if index.is_a?(Numeric)
218
- copy_to(_select_int(0, index), value)
223
+ index_put!([Torch.tensor(index)], value)
219
224
  elsif index.is_a?(Range)
220
225
  finish = index.end
221
226
  finish += 1 unless index.exclude_end?
222
- copy_to(_slice_tensor(0, index.begin, finish, 1), value)
227
+ _slice_tensor(0, index.begin, finish, 1).copy!(value)
223
228
  elsif index.is_a?(Tensor)
224
229
  index_put!([index], value)
225
230
  else
@@ -254,11 +259,5 @@ module Torch
254
259
  _clamp_min_(min)
255
260
  _clamp_max_(max)
256
261
  end
257
-
258
- private
259
-
260
- def copy_to(dst, src)
261
- dst.copy!(src)
262
- end
263
262
  end
264
263
  end
@@ -1,3 +1,3 @@
1
1
  module Torch
2
- VERSION = "0.2.7"
2
+ VERSION = "0.3.0"
3
3
  end
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: torch-rb
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.2.7
4
+ version: 0.3.0
5
5
  platform: ruby
6
6
  authors:
7
7
  - Andrew Kane
8
8
  autorequire:
9
9
  bindir: bin
10
10
  cert_chain: []
11
- date: 2020-06-30 00:00:00.000000000 Z
11
+ date: 2020-07-29 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: rice