torch-rb 0.2.4 → 0.3.1

Sign up to get free protection for your applications and to get access to all the features.
@@ -75,7 +75,7 @@ module Torch
75
75
  v.is_a?(Tensor)
76
76
  when "Tensor?"
77
77
  v.nil? || v.is_a?(Tensor)
78
- when "Tensor[]"
78
+ when "Tensor[]", "Tensor?[]"
79
79
  v.is_a?(Array) && v.all? { |v2| v2.is_a?(Tensor) }
80
80
  when "int"
81
81
  if k == "reduction"
@@ -18,7 +18,6 @@ module Torch
18
18
  F.conv2d(input, @weight, @bias, @stride, @padding, @dilation, @groups)
19
19
  end
20
20
 
21
- # TODO add more parameters
22
21
  def extra_inspect
23
22
  s = String.new("%{in_channels}, %{out_channels}, kernel_size: %{kernel_size}, stride: %{stride}")
24
23
  s += ", padding: %{padding}" if @padding != [0] * @padding.size
@@ -145,7 +145,7 @@ module Torch
145
145
  params = {}
146
146
  if recurse
147
147
  named_children.each do |name, mod|
148
- params.merge!(mod.named_parameters(prefix: "#{name}.", recurse: recurse))
148
+ params.merge!(mod.named_parameters(prefix: "#{prefix}#{name}.", recurse: recurse))
149
149
  end
150
150
  end
151
151
  instance_variables.each do |name|
@@ -286,8 +286,11 @@ module Torch
286
286
  str % vars
287
287
  end
288
288
 
289
+ # used for format
290
+ # remove tensors for performance
291
+ # so we can skip call to inspect
289
292
  def dict
290
- instance_variables.map { |k| [k[1..-1].to_sym, instance_variable_get(k)] }.to_h
293
+ instance_variables.reject { |k| instance_variable_get(k).is_a?(Tensor) }.map { |k| [k[1..-1].to_sym, instance_variable_get(k)] }.to_h
291
294
  end
292
295
  end
293
296
  end
@@ -32,9 +32,11 @@ module Torch
32
32
  end
33
33
 
34
34
  def state_dict
35
+ raise NotImplementedYet
36
+
35
37
  pack_group = lambda do |group|
36
- packed = group.select { |k, _| k != :params }.to_h
37
- packed[:params] = group[:params].map { |p| p.object_id }
38
+ packed = group.select { |k, _| k != :params }.map { |k, v| [k.to_s, v] }.to_h
39
+ packed["params"] = group[:params].map { |p| p.object_id }
38
40
  packed
39
41
  end
40
42
 
@@ -42,8 +44,8 @@ module Torch
42
44
  packed_state = @state.map { |k, v| [k.is_a?(Tensor) ? k.object_id : k, v] }.to_h
43
45
 
44
46
  {
45
- state: packed_state,
46
- param_groups: param_groups
47
+ "state" => packed_state,
48
+ "param_groups" => param_groups
47
49
  }
48
50
  end
49
51
 
@@ -11,9 +11,6 @@ module Torch
11
11
  end
12
12
 
13
13
  def step(closure = nil)
14
- # TODO implement []=
15
- raise NotImplementedYet
16
-
17
14
  loss = nil
18
15
  if closure
19
16
  loss = closure.call
@@ -1,6 +1,7 @@
1
1
  module Torch
2
2
  class Tensor
3
3
  include Comparable
4
+ include Enumerable
4
5
  include Inspector
5
6
 
6
7
  alias_method :requires_grad?, :requires_grad
@@ -25,14 +26,36 @@ module Torch
25
26
  inspect
26
27
  end
27
28
 
29
+ def each
30
+ return enum_for(:each) unless block_given?
31
+
32
+ size(0).times do |i|
33
+ yield self[i]
34
+ end
35
+ end
36
+
37
+ # TODO make more performant
28
38
  def to_a
29
- reshape_arr(_flat_data, shape)
39
+ arr = _flat_data
40
+ if shape.empty?
41
+ arr
42
+ else
43
+ shape[1..-1].reverse.each do |dim|
44
+ arr = arr.each_slice(dim)
45
+ end
46
+ arr.to_a
47
+ end
30
48
  end
31
49
 
32
- # TODO support dtype
33
- def to(device, non_blocking: false, copy: false)
50
+ def to(device = nil, dtype: nil, non_blocking: false, copy: false)
51
+ device ||= self.device
34
52
  device = Device.new(device) if device.is_a?(String)
35
- _to(device, _dtype, non_blocking, copy)
53
+
54
+ dtype ||= self.dtype
55
+ enum = DTYPE_TO_ENUM[dtype]
56
+ raise Error, "Unknown type: #{dtype}" unless enum
57
+
58
+ _to(device, enum, non_blocking, copy)
36
59
  end
37
60
 
38
61
  def cpu
@@ -64,7 +87,7 @@ module Torch
64
87
  if numel != 1
65
88
  raise Error, "only one element tensors can be converted to Ruby scalars"
66
89
  end
67
- _flat_data.first
90
+ to_a.first
68
91
  end
69
92
 
70
93
  def to_i
@@ -80,15 +103,16 @@ module Torch
80
103
  Torch.empty(0, dtype: dtype)
81
104
  end
82
105
 
83
- def backward(gradient = nil)
84
- _backward(gradient)
106
+ def backward(gradient = nil, retain_graph: nil, create_graph: false)
107
+ retain_graph = create_graph if retain_graph.nil?
108
+ _backward(gradient, retain_graph, create_graph)
85
109
  end
86
110
 
87
111
  # TODO read directly from memory
88
112
  def numo
89
113
  cls = Torch._dtype_to_numo[dtype]
90
114
  raise Error, "Cannot convert #{dtype} to Numo" unless cls
91
- cls.cast(_flat_data).reshape(*shape)
115
+ cls.from_string(_data_str).reshape(*shape)
92
116
  end
93
117
 
94
118
  def new_ones(*size, **options)
@@ -116,15 +140,6 @@ module Torch
116
140
  _view(size)
117
141
  end
118
142
 
119
- # value and other are swapped for some methods
120
- def add!(value = 1, other)
121
- if other.is_a?(Numeric)
122
- _add__scalar(other, value)
123
- else
124
- _add__tensor(other, value)
125
- end
126
- end
127
-
128
143
  def +(other)
129
144
  add(other)
130
145
  end
@@ -153,12 +168,25 @@ module Torch
153
168
  neg
154
169
  end
155
170
 
171
+ def &(other)
172
+ logical_and(other)
173
+ end
174
+
175
+ def |(other)
176
+ logical_or(other)
177
+ end
178
+
179
+ def ^(other)
180
+ logical_xor(other)
181
+ end
182
+
156
183
  # TODO better compare?
157
184
  def <=>(other)
158
185
  item <=> other
159
186
  end
160
187
 
161
- # based on python_variable_indexing.cpp
188
+ # based on python_variable_indexing.cpp and
189
+ # https://pytorch.org/cppdocs/notes/tensor_indexing.html
162
190
  def [](*indexes)
163
191
  result = self
164
192
  dim = 0
@@ -170,6 +198,8 @@ module Torch
170
198
  finish += 1 unless index.exclude_end?
171
199
  result = result._slice_tensor(dim, index.begin, finish, 1)
172
200
  dim += 1
201
+ elsif index.is_a?(Tensor)
202
+ result = result.index([index])
173
203
  elsif index.nil?
174
204
  result = result.unsqueeze(dim)
175
205
  dim += 1
@@ -183,24 +213,37 @@ module Torch
183
213
  result
184
214
  end
185
215
 
186
- # TODO
187
- # based on python_variable_indexing.cpp
216
+ # based on python_variable_indexing.cpp and
217
+ # https://pytorch.org/cppdocs/notes/tensor_indexing.html
188
218
  def []=(index, value)
189
219
  raise ArgumentError, "Tensor does not support deleting items" if value.nil?
190
220
 
191
- value = Torch.tensor(value) unless value.is_a?(Tensor)
221
+ value = Torch.tensor(value, dtype: dtype) unless value.is_a?(Tensor)
192
222
 
193
223
  if index.is_a?(Numeric)
194
- copy_to(_select_int(0, index), value)
224
+ index_put!([Torch.tensor(index)], value)
195
225
  elsif index.is_a?(Range)
196
226
  finish = index.end
197
227
  finish += 1 unless index.exclude_end?
198
- copy_to(_slice_tensor(0, index.begin, finish, 1), value)
228
+ _slice_tensor(0, index.begin, finish, 1).copy!(value)
229
+ elsif index.is_a?(Tensor)
230
+ index_put!([index], value)
199
231
  else
200
232
  raise Error, "Unsupported index type: #{index.class.name}"
201
233
  end
202
234
  end
203
235
 
236
+ # native functions that need manually defined
237
+
238
+ # value and other are swapped for some methods
239
+ def add!(value = 1, other)
240
+ if other.is_a?(Numeric)
241
+ _add__scalar(other, value)
242
+ else
243
+ _add__tensor(other, value)
244
+ end
245
+ end
246
+
204
247
  # native functions overlap, so need to handle manually
205
248
  def random!(*args)
206
249
  case args.size
@@ -213,22 +256,9 @@ module Torch
213
256
  end
214
257
  end
215
258
 
216
- private
217
-
218
- def copy_to(dst, src)
219
- dst.copy!(src)
220
- end
221
-
222
- def reshape_arr(arr, dims)
223
- if dims.empty?
224
- arr
225
- else
226
- arr = arr.flatten
227
- dims[1..-1].reverse.each do |dim|
228
- arr = arr.each_slice(dim)
229
- end
230
- arr.to_a
231
- end
259
+ def clamp!(min, max)
260
+ _clamp_min_(min)
261
+ _clamp_max_(max)
232
262
  end
233
263
  end
234
264
  end
@@ -1,3 +1,3 @@
1
1
  module Torch
2
- VERSION = "0.2.4"
2
+ VERSION = "0.3.1"
3
3
  end
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: torch-rb
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.2.4
4
+ version: 0.3.1
5
5
  platform: ruby
6
6
  authors:
7
7
  - Andrew Kane
8
8
  autorequire:
9
9
  bindir: bin
10
10
  cert_chain: []
11
- date: 2020-04-29 00:00:00.000000000 Z
11
+ date: 2020-08-17 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: rice