torch-rb 0.2.0 → 0.2.5

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -2,28 +2,33 @@ require "mkmf-rice"
2
2
 
3
3
  abort "Missing stdc++" unless have_library("stdc++")
4
4
 
5
- $CXXFLAGS << " -std=c++14"
5
+ $CXXFLAGS += " -std=c++14"
6
6
 
7
7
  # change to 0 for Linux pre-cxx11 ABI version
8
- $CXXFLAGS << " -D_GLIBCXX_USE_CXX11_ABI=1"
8
+ $CXXFLAGS += " -D_GLIBCXX_USE_CXX11_ABI=1"
9
9
 
10
10
  # TODO check compiler name
11
11
  clang = RbConfig::CONFIG["host_os"] =~ /darwin/i
12
12
 
13
+ # check omp first
13
14
  if have_library("omp") || have_library("gomp")
14
- $CXXFLAGS << " -DAT_PARALLEL_OPENMP=1"
15
- $CXXFLAGS << " -Xclang" if clang
16
- $CXXFLAGS << " -fopenmp"
15
+ $CXXFLAGS += " -DAT_PARALLEL_OPENMP=1"
16
+ $CXXFLAGS += " -Xclang" if clang
17
+ $CXXFLAGS += " -fopenmp"
17
18
  end
18
19
 
19
- # silence ruby/intern.h warning
20
- $CXXFLAGS << " -Wno-deprecated-register"
21
-
22
- # silence torch warnings
23
20
  if clang
24
- $CXXFLAGS << " -Wno-shorten-64-to-32 -Wno-missing-noreturn"
21
+ # silence ruby/intern.h warning
22
+ $CXXFLAGS += " -Wno-deprecated-register"
23
+
24
+ # silence torch warnings
25
+ $CXXFLAGS += " -Wno-shorten-64-to-32 -Wno-missing-noreturn"
25
26
  else
26
- $CXXFLAGS << " -Wno-duplicated-cond -Wno-suggest-attribute=noreturn"
27
+ # silence rice warnings
28
+ $CXXFLAGS += " -Wno-noexcept-type"
29
+
30
+ # silence torch warnings
31
+ $CXXFLAGS += " -Wno-duplicated-cond -Wno-suggest-attribute=noreturn"
27
32
  end
28
33
 
29
34
  inc, lib = dir_config("torch")
@@ -34,22 +39,30 @@ cuda_inc, cuda_lib = dir_config("cuda")
34
39
  cuda_inc ||= "/usr/local/cuda/include"
35
40
  cuda_lib ||= "/usr/local/cuda/lib64"
36
41
 
37
- with_cuda = Dir["#{lib}/*torch_cuda*"].any? && have_library("cuda") && have_library("cudnn")
42
+ $LDFLAGS += " -L#{lib}" if Dir.exist?(lib)
43
+ abort "LibTorch not found" unless have_library("torch")
44
+
45
+ have_library("mkldnn")
46
+ have_library("nnpack")
47
+
48
+ with_cuda = false
49
+ if Dir["#{lib}/*torch_cuda*"].any?
50
+ $LDFLAGS += " -L#{cuda_lib}" if Dir.exist?(cuda_lib)
51
+ with_cuda = have_library("cuda") && have_library("cudnn")
52
+ end
38
53
 
39
- $INCFLAGS << " -I#{inc}"
40
- $INCFLAGS << " -I#{inc}/torch/csrc/api/include"
54
+ $INCFLAGS += " -I#{inc}"
55
+ $INCFLAGS += " -I#{inc}/torch/csrc/api/include"
41
56
 
42
- $LDFLAGS << " -Wl,-rpath,#{lib}"
43
- $LDFLAGS << ":#{cuda_lib}/stubs:#{cuda_lib}" if with_cuda
44
- $LDFLAGS << " -L#{lib}"
45
- $LDFLAGS << " -L#{cuda_lib}" if with_cuda
57
+ $LDFLAGS += " -Wl,-rpath,#{lib}"
58
+ $LDFLAGS += ":#{cuda_lib}/stubs:#{cuda_lib}" if with_cuda
46
59
 
47
60
  # https://github.com/pytorch/pytorch/blob/v1.5.0/torch/utils/cpp_extension.py#L1232-L1238
48
- $LDFLAGS << " -lc10 -ltorch_cpu -ltorch"
61
+ $LDFLAGS += " -lc10 -ltorch_cpu -ltorch"
49
62
  if with_cuda
50
- $LDFLAGS << " -lcuda -lnvrtc -lnvToolsExt -lcudart -lc10_cuda -ltorch_cuda -lcufft -lcurand -lcublas -lcudnn"
63
+ $LDFLAGS += " -lcuda -lnvrtc -lnvToolsExt -lcudart -lc10_cuda -ltorch_cuda -lcufft -lcurand -lcublas -lcudnn"
51
64
  # TODO figure out why this is needed
52
- $LDFLAGS << " -Wl,--no-as-needed,#{lib}/libtorch.so"
65
+ $LDFLAGS += " -Wl,--no-as-needed,#{lib}/libtorch.so"
53
66
  end
54
67
 
55
68
  # generate C++ functions
@@ -1,6 +1,11 @@
1
1
  # ext
2
2
  require "torch/ext"
3
3
 
4
+ # stdlib
5
+ require "fileutils"
6
+ require "net/http"
7
+ require "tmpdir"
8
+
4
9
  # native functions
5
10
  require "torch/native/generator"
6
11
  require "torch/native/parser"
@@ -174,11 +179,9 @@ require "torch/nn/init"
174
179
 
175
180
  # utils
176
181
  require "torch/utils/data/data_loader"
182
+ require "torch/utils/data/dataset"
177
183
  require "torch/utils/data/tensor_dataset"
178
184
 
179
- # random
180
- require "torch/random"
181
-
182
185
  # hub
183
186
  require "torch/hub"
184
187
 
@@ -317,12 +320,11 @@ module Torch
317
320
  end
318
321
 
319
322
  def save(obj, f)
320
- raise NotImplementedYet unless obj.is_a?(Tensor)
321
- File.binwrite(f, _save(obj))
323
+ File.binwrite(f, _save(to_ivalue(obj)))
322
324
  end
323
325
 
324
326
  def load(f)
325
- raise NotImplementedYet
327
+ to_ruby(_load(File.binread(f)))
326
328
  end
327
329
 
328
330
  # --- begin tensor creation: https://pytorch.org/cppdocs/notes/tensor_creation.html ---
@@ -447,6 +449,100 @@ module Torch
447
449
 
448
450
  private
449
451
 
452
+ def to_ivalue(obj)
453
+ case obj
454
+ when String
455
+ IValue.from_string(obj)
456
+ when Integer
457
+ IValue.from_int(obj)
458
+ when Tensor
459
+ IValue.from_tensor(obj)
460
+ when Float
461
+ IValue.from_double(obj)
462
+ when Hash
463
+ dict = {}
464
+ obj.each do |k, v|
465
+ dict[to_ivalue(k)] = to_ivalue(v)
466
+ end
467
+ IValue.from_dict(dict)
468
+ when true, false
469
+ IValue.from_bool(obj)
470
+ when nil
471
+ IValue.new
472
+ when Array
473
+ if obj.all? { |v| v.is_a?(Tensor) }
474
+ IValue.from_list(obj.map { |v| IValue.from_tensor(v) })
475
+ else
476
+ raise Error, "Unknown list type"
477
+ end
478
+ else
479
+ raise Error, "Unknown type: #{obj.class.name}"
480
+ end
481
+ end
482
+
483
+ def to_ruby(ivalue)
484
+ if ivalue.bool?
485
+ ivalue.to_bool
486
+ elsif ivalue.double?
487
+ ivalue.to_double
488
+ elsif ivalue.int?
489
+ ivalue.to_int
490
+ elsif ivalue.none?
491
+ nil
492
+ elsif ivalue.string?
493
+ ivalue.to_string_ref
494
+ elsif ivalue.tensor?
495
+ ivalue.to_tensor
496
+ elsif ivalue.generic_dict?
497
+ dict = {}
498
+ ivalue.to_generic_dict.each do |k, v|
499
+ dict[to_ruby(k)] = to_ruby(v)
500
+ end
501
+ dict
502
+ elsif ivalue.list?
503
+ ivalue.to_list.map { |v| to_ruby(v) }
504
+ else
505
+ type =
506
+ if ivalue.capsule?
507
+ "Capsule"
508
+ elsif ivalue.custom_class?
509
+ "CustomClass"
510
+ elsif ivalue.tuple?
511
+ "Tuple"
512
+ elsif ivalue.future?
513
+ "Future"
514
+ elsif ivalue.r_ref?
515
+ "RRef"
516
+ elsif ivalue.int_list?
517
+ "IntList"
518
+ elsif ivalue.double_list?
519
+ "DoubleList"
520
+ elsif ivalue.bool_list?
521
+ "BoolList"
522
+ elsif ivalue.tensor_list?
523
+ "TensorList"
524
+ elsif ivalue.object?
525
+ "Object"
526
+ elsif ivalue.module?
527
+ "Module"
528
+ elsif ivalue.py_object?
529
+ "PyObject"
530
+ elsif ivalue.scalar?
531
+ "Scalar"
532
+ elsif ivalue.device?
533
+ "Device"
534
+ # elsif ivalue.generator?
535
+ # "Generator"
536
+ elsif ivalue.ptr_type?
537
+ "PtrType"
538
+ else
539
+ "Unknown"
540
+ end
541
+
542
+ raise Error, "Unsupported type: #{type}"
543
+ end
544
+ end
545
+
450
546
  def tensor_size(size)
451
547
  size.flatten
452
548
  end
@@ -4,6 +4,58 @@ module Torch
4
4
  def list(github, force_reload: false)
5
5
  raise NotImplementedYet
6
6
  end
7
+
8
+ def download_url_to_file(url, dst)
9
+ uri = URI(url)
10
+ tmp = "#{Dir.tmpdir}/#{Time.now.to_f}" # TODO better name
11
+ location = nil
12
+
13
+ Net::HTTP.start(uri.host, uri.port, use_ssl: uri.scheme == "https") do |http|
14
+ request = Net::HTTP::Get.new(uri)
15
+
16
+ puts "Downloading #{url}..."
17
+ File.open(tmp, "wb") do |f|
18
+ http.request(request) do |response|
19
+ case response
20
+ when Net::HTTPRedirection
21
+ location = response["location"]
22
+ when Net::HTTPSuccess
23
+ response.read_body do |chunk|
24
+ f.write(chunk)
25
+ end
26
+ else
27
+ raise Error, "Bad response"
28
+ end
29
+ end
30
+ end
31
+ end
32
+
33
+ if location
34
+ download_url_to_file(location, dst)
35
+ else
36
+ FileUtils.mv(tmp, dst)
37
+ nil
38
+ end
39
+ end
40
+
41
+ def load_state_dict_from_url(url, model_dir: nil)
42
+ unless model_dir
43
+ torch_home = ENV["TORCH_HOME"] || "#{ENV["XDG_CACHE_HOME"] || "#{ENV["HOME"]}/.cache"}/torch"
44
+ model_dir = File.join(torch_home, "checkpoints")
45
+ end
46
+
47
+ FileUtils.mkdir_p(model_dir)
48
+
49
+ parts = URI(url)
50
+ filename = File.basename(parts.path)
51
+ cached_file = File.join(model_dir, filename)
52
+ unless File.exist?(cached_file)
53
+ # TODO support hash_prefix
54
+ download_url_to_file(url, cached_file)
55
+ end
56
+
57
+ Torch.load(cached_file)
58
+ end
7
59
  end
8
60
  end
9
61
  end
@@ -1,6 +1,6 @@
1
1
  module Torch
2
2
  module Inspector
3
- # TODO make more performance, especially when summarizing
3
+ # TODO make more performant, especially when summarizing
4
4
  # how? only read data that will be displayed
5
5
  def inspect
6
6
  data =
@@ -14,7 +14,7 @@ module Torch
14
14
  if dtype == :bool
15
15
  fmt = "%s"
16
16
  else
17
- values = to_a.flatten
17
+ values = _flat_data
18
18
  abs = values.select { |v| v != 0 }.map(&:abs)
19
19
  max = abs.max || 1
20
20
  min = abs.min || 1
@@ -25,7 +25,7 @@ module Torch
25
25
  end
26
26
 
27
27
  if floating_point?
28
- sci = max / min.to_f > 1000 || max > 1e8 || min < 1e-4
28
+ sci = max > 1e8 || max < 1e-4
29
29
 
30
30
  all_int = values.all? { |v| v.finite? && v == v.to_i }
31
31
  decimal = all_int ? 1 : 4
@@ -70,6 +70,11 @@ module Torch
70
70
  momentum: exponential_average_factor, eps: @eps
71
71
  )
72
72
  end
73
+
74
+ def extra_inspect
75
+ s = "%{num_features}, eps: %{eps}, momentum: %{momentum}, affine: %{affine}, track_running_stats: %{track_running_stats}"
76
+ format(s, **dict)
77
+ end
73
78
  end
74
79
  end
75
80
  end
@@ -20,7 +20,14 @@ module Torch
20
20
 
21
21
  # TODO add more parameters
22
22
  def extra_inspect
23
- format("%s, %s, kernel_size: %s, stride: %s", @in_channels, @out_channels, @kernel_size, @stride)
23
+ s = String.new("%{in_channels}, %{out_channels}, kernel_size: %{kernel_size}, stride: %{stride}")
24
+ s += ", padding: %{padding}" if @padding != [0] * @padding.size
25
+ s += ", dilation: %{dilation}" if @dilation != [1] * @dilation.size
26
+ s += ", output_padding: %{output_padding}" if @output_padding != [0] * @output_padding.size
27
+ s += ", groups: %{groups}" if @groups != 1
28
+ s += ", bias: false" unless @bias
29
+ s += ", padding_mode: %{padding_mode}" if @padding_mode != "zeros"
30
+ format(s, **dict)
24
31
  end
25
32
  end
26
33
  end
@@ -23,7 +23,7 @@ module Torch
23
23
  if bias
24
24
  @bias = Parameter.new(Tensor.new(out_channels))
25
25
  else
26
- raise NotImplementedError
26
+ register_parameter("bias", nil)
27
27
  end
28
28
  reset_parameters
29
29
  end
@@ -12,7 +12,8 @@ module Torch
12
12
  end
13
13
 
14
14
  def extra_inspect
15
- format("kernel_size: %s", @kernel_size)
15
+ s = "kernel_size: %{kernel_size}, stride: %{stride}, padding: %{padding}, dilation: %{dilation}, ceil_mode: %{ceil_mode}"
16
+ format(s, **dict)
16
17
  end
17
18
  end
18
19
  end
@@ -67,8 +67,9 @@ module Torch
67
67
  self
68
68
  end
69
69
 
70
- def cuda(device: nil)
71
- _apply ->(t) { t.cuda(device) }
70
+ # TODO add device
71
+ def cuda
72
+ _apply ->(t) { t.cuda }
72
73
  end
73
74
 
74
75
  def cpu
@@ -112,8 +113,28 @@ module Torch
112
113
  destination
113
114
  end
114
115
 
116
+ # TODO add strict option
117
+ # TODO match PyTorch behavior
115
118
  def load_state_dict(state_dict)
116
- raise NotImplementedYet
119
+ state_dict.each do |k, input_param|
120
+ k1, k2 = k.split(".", 2)
121
+ mod = named_modules[k1]
122
+ if mod.is_a?(Module)
123
+ param = mod.named_parameters[k2]
124
+ if param.is_a?(Parameter)
125
+ Torch.no_grad do
126
+ param.copy!(input_param)
127
+ end
128
+ else
129
+ raise Error, "Unknown parameter: #{k1}"
130
+ end
131
+ else
132
+ raise Error, "Unknown module: #{k1}"
133
+ end
134
+ end
135
+
136
+ # TODO return missing keys and unexpected keys
137
+ nil
117
138
  end
118
139
 
119
140
  def parameters
@@ -165,8 +186,22 @@ module Torch
165
186
  named_modules.values
166
187
  end
167
188
 
168
- def named_modules
169
- {"" => self}.merge(named_children)
189
+ # TODO return enumerator?
190
+ def named_modules(memo: nil, prefix: "")
191
+ ret = {}
192
+ memo ||= Set.new
193
+ unless memo.include?(self)
194
+ memo << self
195
+ ret[prefix] = self
196
+ named_children.each do |name, mod|
197
+ next unless mod.is_a?(Module)
198
+ submodule_prefix = prefix + (!prefix.empty? ? "." : "") + name
199
+ mod.named_modules(memo: memo, prefix: submodule_prefix).each do |m|
200
+ ret[m[0]] = m[1]
201
+ end
202
+ end
203
+ end
204
+ ret
170
205
  end
171
206
 
172
207
  def train(mode = true)
@@ -203,13 +238,15 @@ module Torch
203
238
 
204
239
  def inspect
205
240
  name = self.class.name.split("::").last
206
- if children.empty?
241
+ if named_children.empty?
207
242
  "#{name}(#{extra_inspect})"
208
243
  else
209
244
  str = String.new
210
245
  str << "#{name}(\n"
211
- children.each do |name, mod|
212
- str << " (#{name}): #{mod.inspect}\n"
246
+ named_children.each do |name, mod|
247
+ mod_str = mod.inspect
248
+ mod_str = mod_str.lines.join(" ")
249
+ str << " (#{name}): #{mod_str}\n"
213
250
  end
214
251
  str << ")"
215
252
  end
@@ -4,6 +4,8 @@ module Torch
4
4
  include Inspector
5
5
 
6
6
  alias_method :requires_grad?, :requires_grad
7
+ alias_method :ndim, :dim
8
+ alias_method :ndimension, :dim
7
9
 
8
10
  def self.new(*args)
9
11
  FloatTensor.new(*args)
@@ -23,8 +25,17 @@ module Torch
23
25
  inspect
24
26
  end
25
27
 
28
+ # TODO make more performant
26
29
  def to_a
27
- reshape_arr(_flat_data, shape)
30
+ arr = _flat_data
31
+ if shape.empty?
32
+ arr
33
+ else
34
+ shape[1..-1].reverse.each do |dim|
35
+ arr = arr.each_slice(dim)
36
+ end
37
+ arr.to_a
38
+ end
28
39
  end
29
40
 
30
41
  # TODO support dtype
@@ -37,6 +48,10 @@ module Torch
37
48
  to("cpu")
38
49
  end
39
50
 
51
+ def cuda
52
+ to("cuda")
53
+ end
54
+
40
55
  def size(dim = nil)
41
56
  if dim
42
57
  _size_int(dim)
@@ -58,7 +73,15 @@ module Torch
58
73
  if numel != 1
59
74
  raise Error, "only one element tensors can be converted to Ruby scalars"
60
75
  end
61
- _flat_data.first
76
+ to_a.first
77
+ end
78
+
79
+ def to_i
80
+ item.to_i
81
+ end
82
+
83
+ def to_f
84
+ item.to_f
62
85
  end
63
86
 
64
87
  # unsure if this is correct
@@ -74,7 +97,7 @@ module Torch
74
97
  def numo
75
98
  cls = Torch._dtype_to_numo[dtype]
76
99
  raise Error, "Cannot convert #{dtype} to Numo" unless cls
77
- cls.cast(_flat_data).reshape(*shape)
100
+ cls.from_string(_data_str).reshape(*shape)
78
101
  end
79
102
 
80
103
  def new_ones(*size, **options)
@@ -102,15 +125,6 @@ module Torch
102
125
  _view(size)
103
126
  end
104
127
 
105
- # value and other are swapped for some methods
106
- def add!(value = 1, other)
107
- if other.is_a?(Numeric)
108
- _add__scalar(other, value)
109
- else
110
- _add__tensor(other, value)
111
- end
112
- end
113
-
114
128
  def +(other)
115
129
  add(other)
116
130
  end
@@ -139,6 +153,7 @@ module Torch
139
153
  neg
140
154
  end
141
155
 
156
+ # TODO better compare?
142
157
  def <=>(other)
143
158
  item <=> other
144
159
  end
@@ -186,8 +201,27 @@ module Torch
186
201
  end
187
202
  end
188
203
 
189
- def random!(from = 0, to)
190
- _random__from_to(from, to)
204
+ # native functions that need manually defined
205
+
206
+ # value and other are swapped for some methods
207
+ def add!(value = 1, other)
208
+ if other.is_a?(Numeric)
209
+ _add__scalar(other, value)
210
+ else
211
+ _add__tensor(other, value)
212
+ end
213
+ end
214
+
215
+ # native functions overlap, so need to handle manually
216
+ def random!(*args)
217
+ case args.size
218
+ when 1
219
+ _random__to(*args)
220
+ when 2
221
+ _random__from_to(*args)
222
+ else
223
+ _random_(*args)
224
+ end
191
225
  end
192
226
 
193
227
  private
@@ -195,17 +229,5 @@ module Torch
195
229
  def copy_to(dst, src)
196
230
  dst.copy!(src)
197
231
  end
198
-
199
- def reshape_arr(arr, dims)
200
- if dims.empty?
201
- arr
202
- else
203
- arr = arr.flatten
204
- dims[1..-1].reverse.each do |dim|
205
- arr = arr.each_slice(dim)
206
- end
207
- arr.to_a
208
- end
209
- end
210
232
  end
211
233
  end