torch-rb 0.2.2 → 0.2.7

@@ -11,9 +11,6 @@ module Torch
  end
 
  def step(closure = nil)
- # TODO implement []=
- raise NotImplementedYet
-
  loss = nil
  if closure
  loss = closure.call
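This hunk removes a placeholder raise, so the optimizer's step method now actually runs. As in PyTorch, step takes an optional closure that re-evaluates the loss. A minimal sketch of the closure form, assuming an SGD optimizer and a small linear model (the model, data, and loss function are illustrative, not taken from the diff):

  require "torch"

  model = Torch::NN::Linear.new(10, 1)
  optimizer = Torch::Optim::SGD.new(model.parameters, lr: 0.01)
  x = Torch.randn([8, 10])
  y = Torch.randn([8, 1])

  # step calls the closure (when given) and returns its loss
  closure = lambda do
    optimizer.zero_grad
    loss = Torch::NN::F.mse_loss(model.call(x), y)
    loss.backward
    loss
  end
  loss = optimizer.step(closure)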
lib/torch/tensor.rb CHANGED
@@ -1,6 +1,7 @@
  module Torch
  class Tensor
  include Comparable
+ include Enumerable
  include Inspector
 
  alias_method :requires_grad?, :requires_grad
@@ -25,8 +26,25 @@ module Torch
  inspect
  end
 
+ def each
+ return enum_for(:each) unless block_given?
+
+ size(0).times do |i|
+ yield self[i]
+ end
+ end
+
+ # TODO make more performant
  def to_a
- reshape_arr(_flat_data, shape)
+ arr = _flat_data
+ if shape.empty?
+ arr
+ else
+ shape[1..-1].reverse.each do |dim|
+ arr = arr.each_slice(dim)
+ end
+ arr.to_a
+ end
  end
 
  # TODO support dtype
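With Enumerable mixed in and each defined above, a tensor iterates over its slices along dimension 0, so Enumerable's methods work out of the box. A quick sketch:

  x = Torch.tensor([[1, 2], [3, 4]])

  x.each { |row| p row.to_a }  # prints [1, 2] then [3, 4]
  x.map { |row| row.sum.item } # => [3, 7]
  x.to_a                       # => [[1, 2], [3, 4]], nested via each_slice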
@@ -64,7 +82,15 @@ module Torch
  if numel != 1
  raise Error, "only one element tensors can be converted to Ruby scalars"
  end
- _flat_data.first
+ to_a.first
+ end
+
+ def to_i
+ item.to_i
+ end
+
+ def to_f
+ item.to_f
  end
 
  # unsure if this is correct
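item now reads through to_a, and the new to_i and to_f helpers delegate to it, so single-element tensors coerce like Ruby numerics:

  t = Torch.tensor(2.5)
  t.item # => 2.5
  t.to_i # => 2
  t.to_f # => 2.5

  # item (and therefore to_i/to_f) still raises for multi-element tensors
  # Torch.tensor([1, 2]).item # => Torch::Error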
@@ -80,7 +106,7 @@ module Torch
  def numo
  cls = Torch._dtype_to_numo[dtype]
  raise Error, "Cannot convert #{dtype} to Numo" unless cls
- cls.cast(_flat_data).reshape(*shape)
+ cls.from_string(_data_str).reshape(*shape)
  end
 
  def new_ones(*size, **options)
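The numo conversion now serializes the tensor's underlying buffer with from_string instead of casting a flat Ruby array element by element, which should be much faster for large tensors. A usage sketch, assuming the numo-narray gem is installed:

  require "numo/narray"

  x = Torch.rand([2, 3]) # float32 tensor
  n = x.numo             # => Numo::SFloat with shape [2, 3]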
@@ -108,15 +134,6 @@ module Torch
  _view(size)
  end
 
- # value and other are swapped for some methods
- def add!(value = 1, other)
- if other.is_a?(Numeric)
- _add__scalar(other, value)
- else
- _add__tensor(other, value)
- end
- end
-
  def +(other)
  add(other)
  end
@@ -145,12 +162,25 @@ module Torch
  neg
  end
 
+ def &(other)
+ logical_and(other)
+ end
+
+ def |(other)
+ logical_or(other)
+ end
+
+ def ^(other)
+ logical_xor(other)
+ end
+
  # TODO better compare?
  def <=>(other)
  item <=> other
  end
 
- # based on python_variable_indexing.cpp
+ # based on python_variable_indexing.cpp and
+ # https://pytorch.org/cppdocs/notes/tensor_indexing.html
  def [](*indexes)
  result = self
  dim = 0
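The new &, |, and ^ operators delegate to logical_and, logical_or, and logical_xor, which is convenient for combining boolean masks. A small sketch, assuming boolean tensors via the :bool dtype:

  a = Torch.tensor([true, true, false])
  b = Torch.tensor([true, false, false])

  (a & b).to_a # => [true, false, false]
  (a | b).to_a # => [true, true, false]
  (a ^ b).to_a # => [false, true, false]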
@@ -162,6 +192,8 @@ module Torch
  finish += 1 unless index.exclude_end?
  result = result._slice_tensor(dim, index.begin, finish, 1)
  dim += 1
+ elsif index.is_a?(Tensor)
+ result = result.index([index])
  elsif index.nil?
  result = result.unsqueeze(dim)
  dim += 1
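With the new branch above, indexing with a tensor routes through index, enabling PyTorch-style advanced indexing from Ruby. A sketch, assuming comparison methods return boolean tensors as in recent libtorch:

  x = Torch.tensor([10, 20, 30, 40])

  x[Torch.tensor([0, 2])].to_a # => [10, 30]

  mask = x.gt(15) # boolean mask tensor
  x[mask].to_a    # => [20, 30, 40]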
@@ -175,12 +207,12 @@ module Torch
  result
  end
 
- # TODO
- # based on python_variable_indexing.cpp
+ # based on python_variable_indexing.cpp and
+ # https://pytorch.org/cppdocs/notes/tensor_indexing.html
  def []=(index, value)
  raise ArgumentError, "Tensor does not support deleting items" if value.nil?
 
- value = Torch.tensor(value) unless value.is_a?(Tensor)
+ value = Torch.tensor(value, dtype: dtype) unless value.is_a?(Tensor)
 
  if index.is_a?(Numeric)
  copy_to(_select_int(0, index), value)
@@ -188,13 +220,39 @@ module Torch
  finish = index.end
  finish += 1 unless index.exclude_end?
  copy_to(_slice_tensor(0, index.begin, finish, 1), value)
+ elsif index.is_a?(Tensor)
+ index_put!([index], value)
  else
  raise Error, "Unsupported index type: #{index.class.name}"
  end
  end
 
- def random!(from = 0, to)
- _random__from_to(from, to)
+ # native functions that need to be manually defined
+
+ # value and other are swapped for some methods
+ def add!(value = 1, other)
+ if other.is_a?(Numeric)
+ _add__scalar(other, value)
+ else
+ _add__tensor(other, value)
+ end
+ end
+
+ # native functions overlap, so need to handle manually
+ def random!(*args)
+ case args.size
+ when 1
+ _random__to(*args)
+ when 2
+ _random__from_to(*args)
+ else
+ _random_(*args)
+ end
+ end
+
+ def clamp!(min, max)
+ _clamp_min_(min)
+ _clamp_max_(max)
  end
 
  private
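Three changes land in this hunk: []= now casts plain Ruby values to the tensor's dtype and accepts tensor indexes via index_put!, random! supports the zero-, one-, and two-argument forms of the overlapping native functions, and clamp! is composed from _clamp_min_ and _clamp_max_. A sketch of the resulting behavior:

  x = Torch.zeros([3], dtype: :int64)
  x[0] = 2.7                  # cast to the tensor's dtype (int64)
  x[Torch.tensor([1, 2])] = 5 # tensor index goes through index_put!
  x.to_a                      # => [2, 5, 5]

  y = Torch.zeros([4])
  y.random!(10)    # integers in [0, 10)
  y.random!(5, 10) # integers in [5, 10)
  y.clamp!(6, 8)   # clamp in place to 6..8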
@@ -202,17 +260,5 @@ module Torch
  def copy_to(dst, src)
  dst.copy!(src)
  end
-
- def reshape_arr(arr, dims)
- if dims.empty?
- arr
- else
- arr = arr.flatten
- dims[1..-1].reverse.each do |dim|
- arr = arr.each_slice(dim)
- end
- arr.to_a
- end
- end
  end
  end
lib/torch/utils/data/data_loader.rb CHANGED
@@ -6,21 +6,49 @@ module Torch
 
  attr_reader :dataset
 
- def initialize(dataset, batch_size: 1)
+ def initialize(dataset, batch_size: 1, shuffle: false)
  @dataset = dataset
  @batch_size = batch_size
+ @shuffle = shuffle
  end
 
  def each
- size.times do |i|
- start_index = i * @batch_size
- yield @dataset[start_index...(start_index + @batch_size)]
+ # try to keep the random number generator in sync with Python
+ # this makes it easy to compare results
+ base_seed = Torch.empty([], dtype: :int64).random!.item
+
+ indexes =
+ if @shuffle
+ Torch.randperm(@dataset.size).to_a
+ else
+ @dataset.size.times
+ end
+
+ indexes.each_slice(@batch_size) do |idx|
+ batch = idx.map { |i| @dataset[i] }
+ yield collate(batch)
  end
  end
 
  def size
  (@dataset.size / @batch_size.to_f).ceil
  end
+
+ private
+
+ def collate(batch)
+ elem = batch[0]
+ case elem
+ when Tensor
+ Torch.stack(batch, 0)
+ when Integer
+ Torch.tensor(batch)
+ when Array
+ batch.transpose.map { |v| collate(v) }
+ else
+ raise NotImplementedYet
+ end
+ end
  end
  end
  end
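The loader now shuffles with Torch.randperm when shuffle: true, fetches examples one index at a time, and collates each batch: tensors are stacked, integers wrapped in a tensor, and arrays collated recursively. A usage sketch:

  x = Torch.randn([6, 4])
  y = Torch.tensor([0, 1, 0, 1, 0, 1])
  dataset = Torch::Utils::Data::TensorDataset.new(x, y)

  loader = Torch::Utils::Data::DataLoader.new(dataset, batch_size: 2, shuffle: true)
  loader.each do |xb, yb|
    # xb is a [2, 4] tensor stacked from two examples, yb a [2] tensor
  end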
lib/torch/utils/data/dataset.rb ADDED
@@ -0,0 +1,8 @@
+ module Torch
+ module Utils
+ module Data
+ class Dataset
+ end
+ end
+ end
+ end
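The new Dataset base class is empty for now; it mainly gives TensorDataset and user-defined datasets a common ancestor. A hedged sketch of a custom dataset (MyDataset is hypothetical; the [] and size protocol is inferred from how DataLoader fetches and counts examples above):

  class MyDataset < Torch::Utils::Data::Dataset
    def initialize(inputs, targets)
      @inputs = inputs   # tensor of shape [n, ...]
      @targets = targets # tensor of shape [n]
    end

    # DataLoader fetches one example at a time by index
    def [](index)
      [@inputs[index], @targets[index]]
    end

    # DataLoader uses size to build its index sequence
    def size
      @inputs.size(0)
    end
  end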
lib/torch/utils/data/tensor_dataset.rb CHANGED
@@ -1,7 +1,7 @@
  module Torch
  module Utils
  module Data
- class TensorDataset
+ class TensorDataset < Dataset
  def initialize(*tensors)
  unless tensors.all? { |t| t.size(0) == tensors[0].size(0) }
  raise Error, "Tensors must all have same dim 0 size"
lib/torch/version.rb CHANGED
@@ -1,3 +1,3 @@
  module Torch
- VERSION = "0.2.2"
+ VERSION = "0.2.7"
  end
metadata CHANGED
@@ -1,14 +1,14 @@
  --- !ruby/object:Gem::Specification
  name: torch-rb
  version: !ruby/object:Gem::Version
- version: 0.2.2
+ version: 0.2.7
  platform: ruby
  authors:
  - Andrew Kane
  autorequire:
  bindir: bin
  cert_chain: []
- date: 2020-04-27 00:00:00.000000000 Z
+ date: 2020-06-30 00:00:00.000000000 Z
  dependencies:
  - !ruby/object:Gem::Dependency
  name: rice
@@ -95,19 +95,19 @@ dependencies:
  - !ruby/object:Gem::Version
  version: '0'
  - !ruby/object:Gem::Dependency
- name: npy
+ name: torchvision
  requirement: !ruby/object:Gem::Requirement
  requirements:
  - - ">="
  - !ruby/object:Gem::Version
- version: '0'
+ version: 0.1.1
  type: :development
  prerelease: false
  version_requirements: !ruby/object:Gem::Requirement
  requirements:
  - - ">="
  - !ruby/object:Gem::Version
- version: '0'
+ version: 0.1.1
  description:
  email: andrew@chartkick.com
  executables: []
@@ -258,9 +258,9 @@ files:
  - lib/torch/optim/rmsprop.rb
  - lib/torch/optim/rprop.rb
  - lib/torch/optim/sgd.rb
- - lib/torch/random.rb
  - lib/torch/tensor.rb
  - lib/torch/utils/data/data_loader.rb
+ - lib/torch/utils/data/dataset.rb
  - lib/torch/utils/data/tensor_dataset.rb
  - lib/torch/version.rb
  homepage: https://github.com/ankane/torch.rb
lib/torch/random.rb DELETED
@@ -1,10 +0,0 @@
- module Torch
- module Random
- class << self
- # not available through LibTorch
- def initial_seed
- raise NotImplementedYet
- end
- end
- end
- end