torch-rb 0.2.7 → 0.3.0

This diff shows the changes between two publicly available versions of the package, as released to one of the supported registries. It is provided for informational purposes only and reflects the package contents exactly as they appear in their respective public registries.
@@ -18,7 +18,6 @@ module Torch
18
18
  F.conv2d(input, @weight, @bias, @stride, @padding, @dilation, @groups)
19
19
  end
20
20
 
21
- # TODO add more parameters
22
21
  def extra_inspect
23
22
  s = String.new("%{in_channels}, %{out_channels}, kernel_size: %{kernel_size}, stride: %{stride}")
24
23
  s += ", padding: %{padding}" if @padding != [0] * @padding.size
@@ -32,9 +32,11 @@ module Torch
32
32
  end
33
33
 
34
34
  def state_dict
35
+ raise NotImplementedYet
36
+
35
37
  pack_group = lambda do |group|
36
- packed = group.select { |k, _| k != :params }.to_h
37
- packed[:params] = group[:params].map { |p| p.object_id }
38
+ packed = group.select { |k, _| k != :params }.map { |k, v| [k.to_s, v] }.to_h
39
+ packed["params"] = group[:params].map { |p| p.object_id }
38
40
  packed
39
41
  end
40
42
 
@@ -42,8 +44,8 @@ module Torch
42
44
  packed_state = @state.map { |k, v| [k.is_a?(Tensor) ? k.object_id : k, v] }.to_h
43
45
 
44
46
  {
45
- state: packed_state,
46
- param_groups: param_groups
47
+ "state" => packed_state,
48
+ "param_groups" => param_groups
47
49
  }
48
50
  end
49
51
 
@@ -47,10 +47,15 @@ module Torch
47
47
  end
48
48
  end
49
49
 
50
- # TODO support dtype
51
- def to(device, non_blocking: false, copy: false)
50
+ def to(device = nil, dtype: nil, non_blocking: false, copy: false)
51
+ device ||= self.device
52
52
  device = Device.new(device) if device.is_a?(String)
53
- _to(device, _dtype, non_blocking, copy)
53
+
54
+ dtype ||= self.dtype
55
+ enum = DTYPE_TO_ENUM[dtype]
56
+ raise Error, "Unknown type: #{dtype}" unless enum
57
+
58
+ _to(device, enum, non_blocking, copy)
54
59
  end
55
60
 
56
61
  def cpu
@@ -215,11 +220,11 @@ module Torch
215
220
  value = Torch.tensor(value, dtype: dtype) unless value.is_a?(Tensor)
216
221
 
217
222
  if index.is_a?(Numeric)
218
- copy_to(_select_int(0, index), value)
223
+ index_put!([Torch.tensor(index)], value)
219
224
  elsif index.is_a?(Range)
220
225
  finish = index.end
221
226
  finish += 1 unless index.exclude_end?
222
- copy_to(_slice_tensor(0, index.begin, finish, 1), value)
227
+ _slice_tensor(0, index.begin, finish, 1).copy!(value)
223
228
  elsif index.is_a?(Tensor)
224
229
  index_put!([index], value)
225
230
  else
@@ -254,11 +259,5 @@ module Torch
254
259
  _clamp_min_(min)
255
260
  _clamp_max_(max)
256
261
  end
257
-
258
- private
259
-
260
- def copy_to(dst, src)
261
- dst.copy!(src)
262
- end
263
262
  end
264
263
  end
@@ -1,3 +1,3 @@
1
1
  module Torch
2
- VERSION = "0.2.7"
2
+ VERSION = "0.3.0"
3
3
  end
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: torch-rb
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.2.7
4
+ version: 0.3.0
5
5
  platform: ruby
6
6
  authors:
7
7
  - Andrew Kane
8
8
  autorequire:
9
9
  bindir: bin
10
10
  cert_chain: []
11
- date: 2020-06-30 00:00:00.000000000 Z
11
+ date: 2020-07-29 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: rice