torch-rb 0.1.8 → 0.2.0

data/lib/torch/native/parser.rb CHANGED
@@ -99,6 +99,8 @@ module Torch
           v.nil?
         when "bool"
           v == true || v == false
+        when "str"
+          v.is_a?(String)
         else
           raise Error, "Unknown argument type: #{arg_types[k]}. Please report a bug with #{@name}."
         end
data/lib/torch/nn/adaptive_avg_pool1d.rb ADDED
@@ -0,0 +1,9 @@
+module Torch
+  module NN
+    class AdaptiveAvgPool1d < AdaptiveAvgPoolNd
+      def forward(input)
+        F.adaptive_avg_pool1d(input, @output_size)
+      end
+    end
+  end
+end
data/lib/torch/nn/adaptive_avg_pool2d.rb ADDED
@@ -0,0 +1,9 @@
+module Torch
+  module NN
+    class AdaptiveAvgPool2d < AdaptiveAvgPoolNd
+      def forward(input)
+        F.adaptive_avg_pool2d(input, @output_size)
+      end
+    end
+  end
+end
data/lib/torch/nn/adaptive_avg_pool3d.rb ADDED
@@ -0,0 +1,9 @@
+module Torch
+  module NN
+    class AdaptiveAvgPool3d < AdaptiveAvgPoolNd
+      def forward(input)
+        F.adaptive_avg_pool3d(input, @output_size)
+      end
+    end
+  end
+end
data/lib/torch/nn/adaptive_avg_poolnd.rb ADDED
@@ -0,0 +1,14 @@
+module Torch
+  module NN
+    class AdaptiveAvgPoolNd < Module
+      def initialize(output_size)
+        super()
+        @output_size = output_size
+      end
+
+      def extra_inspect
+        format("output_size: %s", @output_size)
+      end
+    end
+  end
+end
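The adaptive average pooling layers only need a target output size; the pooling kernel is derived from whatever input they receive. A minimal usage sketch (shapes are illustrative, and `call` invokes `forward` as with other torch-rb modules):

```ruby
require "torch"

input = Torch.randn(1, 64, 10, 9)                # N x C x H x W
pool = Torch::NN::AdaptiveAvgPool2d.new([5, 7])  # target spatial size
output = pool.call(input)
output.size # => [1, 64, 5, 7], regardless of the input's H and W
```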
data/lib/torch/nn/adaptive_max_pool1d.rb ADDED
@@ -0,0 +1,9 @@
+module Torch
+  module NN
+    class AdaptiveMaxPool1d < AdaptiveMaxPoolNd
+      def forward(input)
+        F.adaptive_max_pool1d(input, @output_size) #, @return_indices)
+      end
+    end
+  end
+end
data/lib/torch/nn/adaptive_max_pool2d.rb ADDED
@@ -0,0 +1,9 @@
+module Torch
+  module NN
+    class AdaptiveMaxPool2d < AdaptiveMaxPoolNd
+      def forward(input)
+        F.adaptive_max_pool2d(input, @output_size) #, @return_indices)
+      end
+    end
+  end
+end
data/lib/torch/nn/adaptive_max_pool3d.rb ADDED
@@ -0,0 +1,9 @@
+module Torch
+  module NN
+    class AdaptiveMaxPool3d < AdaptiveMaxPoolNd
+      def forward(input)
+        F.adaptive_max_pool3d(input, @output_size) #, @return_indices)
+      end
+    end
+  end
+end
data/lib/torch/nn/adaptive_max_poolnd.rb ADDED
@@ -0,0 +1,15 @@
+module Torch
+  module NN
+    class AdaptiveMaxPoolNd < Module
+      def initialize(output_size) #, return_indices: false)
+        super()
+        @output_size = output_size
+        # @return_indices = return_indices
+      end
+
+      def extra_inspect
+        format("output_size: %s", @output_size)
+      end
+    end
+  end
+end
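The adaptive max pooling layers follow the same pattern. Note that `return_indices` is only sketched in comments for now, and the underlying kernels also compute argmax indices, so depending on how the native call is bound the result may be a single tensor or a pair. A hedged sketch:

```ruby
input = Torch.randn(1, 16, 50)              # N x C x L
pool = Torch::NN::AdaptiveMaxPool1d.new(5)  # an integer output size is accepted too
result = pool.call(input)                   # pooled to length 5 along the last dimension
```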
data/lib/torch/nn/functional.rb CHANGED
@@ -99,6 +99,34 @@ module Torch
        NN.avg_pool3d(*args, **options)
      end

+     def adaptive_max_pool1d(*args, **options)
+       Torch.adaptive_max_pool1d(*args, **options)
+     end
+
+     def adaptive_max_pool2d(input, output_size)
+       output_size = list_with_default(output_size, input.size)
+       NN.adaptive_max_pool2d(input, output_size)
+     end
+
+     def adaptive_max_pool3d(input, output_size)
+       output_size = list_with_default(output_size, input.size)
+       NN.adaptive_max_pool3d(input, output_size)
+     end
+
+     def adaptive_avg_pool1d(*args, **options)
+       Torch.adaptive_avg_pool1d(*args, **options)
+     end
+
+     def adaptive_avg_pool2d(input, output_size)
+       output_size = list_with_default(output_size, input.size)
+       NN.adaptive_avg_pool2d(input, output_size)
+     end
+
+     def adaptive_avg_pool3d(input, output_size)
+       output_size = list_with_default(output_size, input.size)
+       NN.adaptive_avg_pool3d(input, output_size)
+     end
+
      # padding layers

      def pad(input, pad, mode: "constant", value: 0)
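These functional entry points can also be called directly; the layer classes above just forward to them. A sketch, assuming the methods are exposed as class-level methods on Torch::NN::Functional as elsewhere in the gem:

```ruby
input = Torch.randn(1, 64, 8, 9)
out = Torch::NN::Functional.adaptive_avg_pool2d(input, [5, 7])
out.size # => [1, 64, 5, 7]
```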
@@ -369,7 +397,7 @@ module Torch
      end

      def cosine_embedding_loss(input1, input2, target, margin: 0, reduction: "mean")
-       raise NotImplementedYet
+       Torch.cosine_embedding_loss(input1, input2, target, margin, reduction)
      end

      def cross_entropy(input, target, weight: nil, ignore_index: -100, reduction: "mean")
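With the stub removed, cosine_embedding_loss can be called like the other losses. A hedged sketch (shapes and target values are illustrative; the target marks each pair as similar, 1, or dissimilar, -1):

```ruby
input1 = Torch.randn(3, 5)
input2 = Torch.randn(3, 5)
target = Torch.tensor([1.0, -1.0, 1.0])
loss = Torch::NN::Functional.cosine_embedding_loss(input1, input2, target)
```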
@@ -394,7 +422,7 @@ module Torch
      end

      def margin_ranking_loss(input1, input2, target, margin: 0, reduction: "mean")
-       raise NotImplementedYet
+       Torch.margin_ranking_loss(input1, input2, target, margin, reduction)
      end

      def mse_loss(input, target, reduction: "mean")
@@ -438,6 +466,16 @@ module Torch
      def softmax_dim(ndim)
        ndim == 0 || ndim == 1 || ndim == 3 ? 0 : 1
      end
+
+     def list_with_default(out_size, defaults)
+       if out_size.is_a?(Integer)
+         out_size
+       elsif defaults.length < out_size.length
+         raise ArgumentError, "Input dimension should be at least #{out_size.length + 1}"
+       else
+         out_size.zip(defaults.last(out_size.length)).map { |v, d| v || d }
+       end
+     end
    end
  end

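list_with_default is what allows an output_size entry to be nil, meaning "keep this dimension at its input size": nil entries are zipped against the trailing dimensions of input.size. A sketch of the behaviour it enables:

```ruby
input = Torch.randn(1, 64, 8, 9)
# nil falls back to the matching input dimension (9 here)
out = Torch::NN::Functional.adaptive_avg_pool2d(input, [5, nil])
out.size # => [1, 64, 5, 9]
```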
data/lib/torch/nn/module.rb CHANGED
@@ -34,6 +34,27 @@ module Torch
      children.each do |mod|
        mod._apply(fn)
      end
+
+     instance_variables.each do |key|
+       param = instance_variable_get(key)
+       if param.is_a?(Parameter)
+         param_applied = nil
+         Torch.no_grad do
+           param_applied = fn.call(param)
+         end
+         # TODO should_use_set_data
+         instance_variable_set(key, Parameter.new(param_applied, requires_grad: param.requires_grad))
+
+         if param.grad
+           grad_applied = nil
+           Torch.no_grad do
+             grad_applied = fn.call(param.grad)
+           end
+           # TODO should_use_set_data
+           instance_variable_get(key).grad = grad_applied.requires_grad!(param.grad.requires_grad)
+         end
+       end
+     end
      # TODO apply to more objects
      self
    end
@@ -111,7 +132,7 @@ module Torch
        params[[prefix, name[1..-1]].join] = param if param.is_a?(Parameter)
      end
      @parameters.each do |name, param|
-       params[[prefix, name].join] = param
+       params[[prefix, name].join] = param if param
      end
      params
    end
data/lib/torch/optim/lr_scheduler/cosine_annealing_lr.rb ADDED
@@ -0,0 +1,29 @@
+module Torch
+  module Optim
+    module LRScheduler
+      class CosineAnnealingLR < LRScheduler
+        def initialize(optimizer, t_max, eta_min: 0, last_epoch: -1)
+          @t_max = t_max
+          @eta_min = eta_min
+          super(optimizer, last_epoch)
+        end
+
+        def get_lr
+          if @last_epoch == 0
+            @base_lrs
+          elsif (@last_epoch - 1 - @t_max) % (2 * @t_max) == 0
+            @base_lrs.zip(@optimizer.param_groups).map do |base_lr, group|
+              group[:lr] + (base_lr - @eta_min) * (1 - Math.cos(Math::PI / @t_max)) / 2
+            end
+          else
+            @optimizer.param_groups.map do |group|
+              (1 + Math.cos(Math::PI * @last_epoch / @t_max)) /
+                (1 + Math.cos(Math::PI * (@last_epoch - 1) / @t_max)) *
+                (group[:lr] - @eta_min) + @eta_min
+            end
+          end
+        end
+      end
+    end
+  end
+end
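The new schedulers plug into the same step pattern as the existing StepLR. A hedged sketch with CosineAnnealingLR (model, learning rate, and epoch count are arbitrary):

```ruby
model = Torch::NN::Linear.new(10, 2)
optimizer = Torch::Optim::SGD.new(model.parameters, lr: 0.1)
scheduler = Torch::Optim::LRScheduler::CosineAnnealingLR.new(optimizer, 20) # t_max = 20

20.times do
  # forward pass, loss.backward, and optimizer.step would go here
  scheduler.step # anneals each group's :lr along a cosine curve toward eta_min (0 by default)
end
```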
data/lib/torch/optim/lr_scheduler/exponential_lr.rb ADDED
@@ -0,0 +1,22 @@
+module Torch
+  module Optim
+    module LRScheduler
+      class ExponentialLR < LRScheduler
+        def initialize(optimizer, gamma, last_epoch: -1)
+          @gamma = gamma
+          super(optimizer, last_epoch)
+        end
+
+        def get_lr
+          if @last_epoch == 0
+            @base_lrs
+          else
+            @optimizer.param_groups.map do |group|
+              group[:lr] * @gamma
+            end
+          end
+        end
+      end
+    end
+  end
+end
data/lib/torch/optim/lr_scheduler/lambda_lr.rb ADDED
@@ -0,0 +1,28 @@
+module Torch
+  module Optim
+    module LRScheduler
+      class LambdaLR < LRScheduler
+        def initialize(optimizer, lr_lambda, last_epoch: -1)
+          @optimizer = optimizer
+
+          if !lr_lambda.is_a?(Array)
+            @lr_lambdas = [lr_lambda] * optimizer.param_groups.length
+          else
+            if lr_lambda.length != optimizer.param_groups.length
+              raise ArgumentError, "Expected #{optimizer.param_groups.length}, but got #{lr_lambda.length}"
+            end
+            @lr_lambdas = lr_lambda
+          end
+          @last_epoch = last_epoch
+          super(optimizer, last_epoch)
+        end
+
+        def get_lr
+          @lr_lambdas.zip(@base_lrs).map do |lmbda, base_lr|
+            base_lr * lmbda.call(@last_epoch)
+          end
+        end
+      end
+    end
+  end
+end
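LambdaLR rescales each parameter group's base learning rate by a lambda of the epoch (pass an array of lambdas for multiple groups). A sketch:

```ruby
model = Torch::NN::Linear.new(10, 2)
optimizer = Torch::Optim::SGD.new(model.parameters, lr: 0.1)
# the learning rate at epoch e becomes 0.1 * 0.95**e
scheduler = Torch::Optim::LRScheduler::LambdaLR.new(optimizer, ->(epoch) { 0.95 ** epoch })
scheduler.step
```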
data/lib/torch/optim/lr_scheduler/multi_step_lr.rb ADDED
@@ -0,0 +1,23 @@
+module Torch
+  module Optim
+    module LRScheduler
+      class MultiStepLR < LRScheduler
+        def initialize(optimizer, milestones, gamma: 0.1, last_epoch: -1)
+          @milestones = milestones.map.with_index.map { |v, i| [v, i + 1] }.to_h
+          @gamma = gamma
+          super(optimizer, last_epoch)
+        end
+
+        def get_lr
+          if !@milestones.include?(@last_epoch)
+            @optimizer.param_groups.map { |group| group[:lr] }
+          else
+            @optimizer.param_groups.map do |group|
+              group[:lr] * @gamma ** @milestones[@last_epoch]
+            end
+          end
+        end
+      end
+    end
+  end
+end
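MultiStepLR drops the learning rate when the epoch counter reaches each listed milestone, scaled by gamma as computed in get_lr above. A sketch:

```ruby
model = Torch::NN::Linear.new(10, 2)
optimizer = Torch::Optim::SGD.new(model.parameters, lr: 0.1)
# the learning rate decays at epochs 30 and 80
scheduler = Torch::Optim::LRScheduler::MultiStepLR.new(optimizer, [30, 80], gamma: 0.1)
```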
data/lib/torch/optim/lr_scheduler/multiplicative_lr.rb ADDED
@@ -0,0 +1,32 @@
+module Torch
+  module Optim
+    module LRScheduler
+      class MultiplicativeLR < LRScheduler
+        def initialize(optimizer, lr_lambda, last_epoch: -1)
+          @optimizer = optimizer
+
+          if !lr_lambda.is_a?(Array)
+            @lr_lambdas = [lr_lambda] * optimizer.param_groups.length
+          else
+            if lr_lambda.length != optimizer.param_groups.length
+              raise ArgumentError, "Expected #{optimizer.param_groups.length}, but got #{lr_lambda.length}"
+            end
+            @lr_lambdas = lr_lambda
+          end
+          @last_epoch = last_epoch
+          super(optimizer, last_epoch)
+        end
+
+        def get_lr
+          if @last_epoch > 0
+            @lr_lambdas.zip(@optimizer.param_groups).map do |lmbda, group|
+              group[:lr] * lmbda.call(@last_epoch)
+            end
+          else
+            @base_lrs
+          end
+        end
+      end
+    end
+  end
+end
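MultiplicativeLR differs from LambdaLR in that the lambda's return value multiplies the current learning rate on every step rather than rescaling the base rate. A sketch:

```ruby
model = Torch::NN::Linear.new(10, 2)
optimizer = Torch::Optim::SGD.new(model.parameters, lr: 0.1)
# after the first epoch, each step multiplies the current lr by 0.95
scheduler = Torch::Optim::LRScheduler::MultiplicativeLR.new(optimizer, ->(_epoch) { 0.95 })
5.times { scheduler.step }
```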
data/lib/torch/tensor.rb CHANGED
@@ -33,6 +33,10 @@ module Torch
      _to(device, _dtype, non_blocking, copy)
    end

+   def cpu
+     to("cpu")
+   end
+
    def size(dim = nil)
      if dim
        _size_int(dim)
@@ -182,6 +186,10 @@ module Torch
      end
    end

+   def random!(from = 0, to)
+     _random__from_to(from, to)
+   end
+
    private

    def copy_to(dst, src)
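Two small Tensor additions: cpu is shorthand for to("cpu"), and random! fills a tensor in place with values drawn from [from, to), delegating to the native random_ op. A sketch:

```ruby
x = Torch.randn(2, 3)
x.cpu          # equivalent to x.to("cpu")

y = Torch.zeros(5)
y.random!(10)  # fills y in place with integer values in 0...10 (stored as floats here)
```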
data/lib/torch/version.rb CHANGED
@@ -1,3 +1,3 @@
 module Torch
-  VERSION = "0.1.8"
+  VERSION = "0.2.0"
 end
data/lib/torch.rb CHANGED
@@ -25,7 +25,12 @@ require "torch/optim/sgd"

 # optim lr_scheduler
 require "torch/optim/lr_scheduler/lr_scheduler"
+require "torch/optim/lr_scheduler/lambda_lr"
+require "torch/optim/lr_scheduler/multiplicative_lr"
 require "torch/optim/lr_scheduler/step_lr"
+require "torch/optim/lr_scheduler/multi_step_lr"
+require "torch/optim/lr_scheduler/exponential_lr"
+require "torch/optim/lr_scheduler/cosine_annealing_lr"

 # nn parameters
 require "torch/nn/parameter"
@@ -59,6 +64,14 @@ require "torch/nn/avg_pool3d"
 require "torch/nn/lp_poolnd"
 require "torch/nn/lp_pool1d"
 require "torch/nn/lp_pool2d"
+require "torch/nn/adaptive_max_poolnd"
+require "torch/nn/adaptive_max_pool1d"
+require "torch/nn/adaptive_max_pool2d"
+require "torch/nn/adaptive_max_pool3d"
+require "torch/nn/adaptive_avg_poolnd"
+require "torch/nn/adaptive_avg_pool1d"
+require "torch/nn/adaptive_avg_pool2d"
+require "torch/nn/adaptive_avg_pool3d"

 # nn padding layers
 require "torch/nn/reflection_padnd"
@@ -166,6 +179,9 @@ require "torch/utils/data/tensor_dataset"
 # random
 require "torch/random"

+# hub
+require "torch/hub"
+
 module Torch
   class Error < StandardError; end
   class NotImplementedYet < StandardError
@@ -365,6 +381,11 @@ module Torch
    end

    def tensor(data, **options)
+     if options[:dtype].nil? && defined?(Numo::NArray) && data.is_a?(Numo::NArray)
+       numo_to_dtype = _dtype_to_numo.map(&:reverse).to_h
+       options[:dtype] = numo_to_dtype[data.class]
+     end
+
      size = []
      if data.respond_to?(:to_a)
        data = data.to_a
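Torch.tensor now infers the dtype from a Numo array's class instead of always falling back to the float default. A hedged sketch, assuming the numo-narray gem is installed and the usual mappings (Numo::Int64 to :int64, Numo::SFloat to :float32):

```ruby
require "numo/narray"
require "torch"

t = Torch.tensor(Numo::Int64[1, 2, 3])
t.dtype # => :int64, taken from the Numo class

Torch.tensor(Numo::SFloat.new(2, 2).seq).dtype # => :float32
```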
metadata CHANGED
@@ -1,14 +1,14 @@
 --- !ruby/object:Gem::Specification
 name: torch-rb
 version: !ruby/object:Gem::Version
-  version: 0.1.8
+  version: 0.2.0
 platform: ruby
 authors:
 - Andrew Kane
 autorequire:
 bindir: bin
 cert_chain: []
-date: 2020-01-18 00:00:00.000000000 Z
+date: 2020-04-23 00:00:00.000000000 Z
 dependencies:
 - !ruby/object:Gem::Dependency
   name: rice
@@ -94,6 +94,20 @@ dependencies:
     - - ">="
       - !ruby/object:Gem::Version
         version: '0'
+- !ruby/object:Gem::Dependency
+  name: npy
+  requirement: !ruby/object:Gem::Requirement
+    requirements:
+    - - ">="
+      - !ruby/object:Gem::Version
+        version: '0'
+  type: :development
+  prerelease: false
+  version_requirements: !ruby/object:Gem::Requirement
+    requirements:
+    - - ">="
+      - !ruby/object:Gem::Version
+        version: '0'
 description:
 email: andrew@chartkick.com
 executables: []
@@ -106,16 +120,32 @@ files:
 - README.md
 - ext/torch/ext.cpp
 - ext/torch/extconf.rb
+- ext/torch/nn_functions.cpp
+- ext/torch/nn_functions.hpp
 - ext/torch/templates.cpp
 - ext/torch/templates.hpp
+- ext/torch/tensor_functions.cpp
+- ext/torch/tensor_functions.hpp
+- ext/torch/torch_functions.cpp
+- ext/torch/torch_functions.hpp
 - lib/torch-rb.rb
 - lib/torch.rb
+- lib/torch/ext.bundle
+- lib/torch/hub.rb
 - lib/torch/inspector.rb
 - lib/torch/native/dispatcher.rb
 - lib/torch/native/function.rb
 - lib/torch/native/generator.rb
 - lib/torch/native/native_functions.yaml
 - lib/torch/native/parser.rb
+- lib/torch/nn/adaptive_avg_pool1d.rb
+- lib/torch/nn/adaptive_avg_pool2d.rb
+- lib/torch/nn/adaptive_avg_pool3d.rb
+- lib/torch/nn/adaptive_avg_poolnd.rb
+- lib/torch/nn/adaptive_max_pool1d.rb
+- lib/torch/nn/adaptive_max_pool2d.rb
+- lib/torch/nn/adaptive_max_pool3d.rb
+- lib/torch/nn/adaptive_max_poolnd.rb
 - lib/torch/nn/alpha_dropout.rb
 - lib/torch/nn/avg_pool1d.rb
 - lib/torch/nn/avg_pool2d.rb
@@ -224,7 +254,12 @@ files:
 - lib/torch/optim/adamax.rb
 - lib/torch/optim/adamw.rb
 - lib/torch/optim/asgd.rb
+- lib/torch/optim/lr_scheduler/cosine_annealing_lr.rb
+- lib/torch/optim/lr_scheduler/exponential_lr.rb
+- lib/torch/optim/lr_scheduler/lambda_lr.rb
 - lib/torch/optim/lr_scheduler/lr_scheduler.rb
+- lib/torch/optim/lr_scheduler/multi_step_lr.rb
+- lib/torch/optim/lr_scheduler/multiplicative_lr.rb
 - lib/torch/optim/lr_scheduler/step_lr.rb
 - lib/torch/optim/optimizer.rb
 - lib/torch/optim/rmsprop.rb
@@ -235,7 +270,7 @@ files:
 - lib/torch/utils/data/data_loader.rb
 - lib/torch/utils/data/tensor_dataset.rb
 - lib/torch/version.rb
-homepage: https://github.com/ankane/torch-rb
+homepage: https://github.com/ankane/torch.rb
 licenses:
 - BSD-3-Clause
 metadata: {}