torch-rb 0.2.0 → 0.2.5

@@ -2,28 +2,33 @@ require "mkmf-rice"
 
 abort "Missing stdc++" unless have_library("stdc++")
 
-$CXXFLAGS << " -std=c++14"
+$CXXFLAGS += " -std=c++14"
 
 # change to 0 for Linux pre-cxx11 ABI version
-$CXXFLAGS << " -D_GLIBCXX_USE_CXX11_ABI=1"
+$CXXFLAGS += " -D_GLIBCXX_USE_CXX11_ABI=1"
 
 # TODO check compiler name
 clang = RbConfig::CONFIG["host_os"] =~ /darwin/i
 
+# check omp first
 if have_library("omp") || have_library("gomp")
-  $CXXFLAGS << " -DAT_PARALLEL_OPENMP=1"
-  $CXXFLAGS << " -Xclang" if clang
-  $CXXFLAGS << " -fopenmp"
+  $CXXFLAGS += " -DAT_PARALLEL_OPENMP=1"
+  $CXXFLAGS += " -Xclang" if clang
+  $CXXFLAGS += " -fopenmp"
 end
 
-# silence ruby/intern.h warning
-$CXXFLAGS << " -Wno-deprecated-register"
-
-# silence torch warnings
 if clang
-  $CXXFLAGS << " -Wno-shorten-64-to-32 -Wno-missing-noreturn"
+  # silence ruby/intern.h warning
+  $CXXFLAGS += " -Wno-deprecated-register"
+
+  # silence torch warnings
+  $CXXFLAGS += " -Wno-shorten-64-to-32 -Wno-missing-noreturn"
 else
-  $CXXFLAGS << " -Wno-duplicated-cond -Wno-suggest-attribute=noreturn"
+  # silence rice warnings
+  $CXXFLAGS += " -Wno-noexcept-type"
+
+  # silence torch warnings
+  $CXXFLAGS += " -Wno-duplicated-cond -Wno-suggest-attribute=noreturn"
 end
 
 inc, lib = dir_config("torch")
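Note: the blanket change from << to += on $CXXFLAGS swaps in-place string mutation for reassignment. One practical difference, shown with a hypothetical flags variable rather than mkmf's actual globals: += builds a new String and rebinds the variable, so it also works when the starting string is frozen, while << would raise.

    flags = "-O2".freeze        # stand-in for a build flag string
    flags += " -std=c++14"      # creates a new String; the frozen original is untouched
    # flags << " -std=c++14"    # would raise FrozenError, since << mutates in place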
@@ -34,22 +39,30 @@ cuda_inc, cuda_lib = dir_config("cuda")
 cuda_inc ||= "/usr/local/cuda/include"
 cuda_lib ||= "/usr/local/cuda/lib64"
 
-with_cuda = Dir["#{lib}/*torch_cuda*"].any? && have_library("cuda") && have_library("cudnn")
+$LDFLAGS += " -L#{lib}" if Dir.exist?(lib)
+abort "LibTorch not found" unless have_library("torch")
+
+have_library("mkldnn")
+have_library("nnpack")
+
+with_cuda = false
+if Dir["#{lib}/*torch_cuda*"].any?
+  $LDFLAGS += " -L#{cuda_lib}" if Dir.exist?(cuda_lib)
+  with_cuda = have_library("cuda") && have_library("cudnn")
+end
 
-$INCFLAGS << " -I#{inc}"
-$INCFLAGS << " -I#{inc}/torch/csrc/api/include"
+$INCFLAGS += " -I#{inc}"
+$INCFLAGS += " -I#{inc}/torch/csrc/api/include"
 
-$LDFLAGS << " -Wl,-rpath,#{lib}"
-$LDFLAGS << ":#{cuda_lib}/stubs:#{cuda_lib}" if with_cuda
-$LDFLAGS << " -L#{lib}"
-$LDFLAGS << " -L#{cuda_lib}" if with_cuda
+$LDFLAGS += " -Wl,-rpath,#{lib}"
+$LDFLAGS += ":#{cuda_lib}/stubs:#{cuda_lib}" if with_cuda
 
 # https://github.com/pytorch/pytorch/blob/v1.5.0/torch/utils/cpp_extension.py#L1232-L1238
-$LDFLAGS << " -lc10 -ltorch_cpu -ltorch"
+$LDFLAGS += " -lc10 -ltorch_cpu -ltorch"
 if with_cuda
-  $LDFLAGS << " -lcuda -lnvrtc -lnvToolsExt -lcudart -lc10_cuda -ltorch_cuda -lcufft -lcurand -lcublas -lcudnn"
+  $LDFLAGS += " -lcuda -lnvrtc -lnvToolsExt -lcudart -lc10_cuda -ltorch_cuda -lcufft -lcurand -lcublas -lcudnn"
   # TODO figure out why this is needed
-  $LDFLAGS << " -Wl,--no-as-needed,#{lib}/libtorch.so"
+  $LDFLAGS += " -Wl,--no-as-needed,#{lib}/libtorch.so"
 end
 
 # generate C++ functions
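Note: lib and inc come from dir_config("torch") above, so the usual mkmf override switches should still apply at install time (for example: gem install torch-rb -- --with-torch-dir=/path/to/libtorch), and the new explicit abort reports a missing LibTorch up front instead of failing later at link time.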
@@ -1,6 +1,11 @@
 # ext
 require "torch/ext"
 
+# stdlib
+require "fileutils"
+require "net/http"
+require "tmpdir"
+
 # native functions
 require "torch/native/generator"
 require "torch/native/parser"
@@ -174,11 +179,9 @@ require "torch/nn/init"
 
 # utils
 require "torch/utils/data/data_loader"
+require "torch/utils/data/dataset"
 require "torch/utils/data/tensor_dataset"
 
-# random
-require "torch/random"
-
 # hub
 require "torch/hub"
 
@@ -317,12 +320,11 @@ module Torch
     end
 
     def save(obj, f)
-      raise NotImplementedYet unless obj.is_a?(Tensor)
-      File.binwrite(f, _save(obj))
+      File.binwrite(f, _save(to_ivalue(obj)))
     end
 
     def load(f)
-      raise NotImplementedYet
+      to_ruby(_load(File.binread(f)))
     end
 
     # --- begin tensor creation: https://pytorch.org/cppdocs/notes/tensor_creation.html ---
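Note: together with the to_ivalue/to_ruby helpers added in the next hunk, save and load are no longer restricted to a single tensor. A minimal sketch of the broadened API, assuming a writable working directory:

    x = Torch.rand(2, 3)
    Torch.save(x, "x.pt")                                              # single tensor, as before
    Torch.save({"weight" => x, "bias" => Torch.zeros(3)}, "state.pt")  # hash of tensors

    Torch.load("state.pt")  # => hash with the same keys and tensor values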
@@ -447,6 +449,100 @@ module Torch
 
     private
 
+    def to_ivalue(obj)
+      case obj
+      when String
+        IValue.from_string(obj)
+      when Integer
+        IValue.from_int(obj)
+      when Tensor
+        IValue.from_tensor(obj)
+      when Float
+        IValue.from_double(obj)
+      when Hash
+        dict = {}
+        obj.each do |k, v|
+          dict[to_ivalue(k)] = to_ivalue(v)
+        end
+        IValue.from_dict(dict)
+      when true, false
+        IValue.from_bool(obj)
+      when nil
+        IValue.new
+      when Array
+        if obj.all? { |v| v.is_a?(Tensor) }
+          IValue.from_list(obj.map { |v| IValue.from_tensor(v) })
+        else
+          raise Error, "Unknown list type"
+        end
+      else
+        raise Error, "Unknown type: #{obj.class.name}"
+      end
+    end
+
+    def to_ruby(ivalue)
+      if ivalue.bool?
+        ivalue.to_bool
+      elsif ivalue.double?
+        ivalue.to_double
+      elsif ivalue.int?
+        ivalue.to_int
+      elsif ivalue.none?
+        nil
+      elsif ivalue.string?
+        ivalue.to_string_ref
+      elsif ivalue.tensor?
+        ivalue.to_tensor
+      elsif ivalue.generic_dict?
+        dict = {}
+        ivalue.to_generic_dict.each do |k, v|
+          dict[to_ruby(k)] = to_ruby(v)
+        end
+        dict
+      elsif ivalue.list?
+        ivalue.to_list.map { |v| to_ruby(v) }
+      else
+        type =
+          if ivalue.capsule?
+            "Capsule"
+          elsif ivalue.custom_class?
+            "CustomClass"
+          elsif ivalue.tuple?
+            "Tuple"
+          elsif ivalue.future?
+            "Future"
+          elsif ivalue.r_ref?
+            "RRef"
+          elsif ivalue.int_list?
+            "IntList"
+          elsif ivalue.double_list?
+            "DoubleList"
+          elsif ivalue.bool_list?
+            "BoolList"
+          elsif ivalue.tensor_list?
+            "TensorList"
+          elsif ivalue.object?
+            "Object"
+          elsif ivalue.module?
+            "Module"
+          elsif ivalue.py_object?
+            "PyObject"
+          elsif ivalue.scalar?
+            "Scalar"
+          elsif ivalue.device?
+            "Device"
+          # elsif ivalue.generator?
+          #   "Generator"
+          elsif ivalue.ptr_type?
+            "PtrType"
+          else
+            "Unknown"
+          end
+
+        raise Error, "Unsupported type: #{type}"
+      end
+    end
+
     def tensor_size(size)
       size.flatten
     end
@@ -4,6 +4,58 @@ module Torch
       def list(github, force_reload: false)
         raise NotImplementedYet
       end
+
+      def download_url_to_file(url, dst)
+        uri = URI(url)
+        tmp = "#{Dir.tmpdir}/#{Time.now.to_f}" # TODO better name
+        location = nil
+
+        Net::HTTP.start(uri.host, uri.port, use_ssl: uri.scheme == "https") do |http|
+          request = Net::HTTP::Get.new(uri)
+
+          puts "Downloading #{url}..."
+          File.open(tmp, "wb") do |f|
+            http.request(request) do |response|
+              case response
+              when Net::HTTPRedirection
+                location = response["location"]
+              when Net::HTTPSuccess
+                response.read_body do |chunk|
+                  f.write(chunk)
+                end
+              else
+                raise Error, "Bad response"
+              end
+            end
+          end
+        end
+
+        if location
+          download_url_to_file(location, dst)
+        else
+          FileUtils.mv(tmp, dst)
+          nil
+        end
+      end
+
+      def load_state_dict_from_url(url, model_dir: nil)
+        unless model_dir
+          torch_home = ENV["TORCH_HOME"] || "#{ENV["XDG_CACHE_HOME"] || "#{ENV["HOME"]}/.cache"}/torch"
+          model_dir = File.join(torch_home, "checkpoints")
+        end
+
+        FileUtils.mkdir_p(model_dir)
+
+        parts = URI(url)
+        filename = File.basename(parts.path)
+        cached_file = File.join(model_dir, filename)
+        unless File.exist?(cached_file)
+          # TODO support hash_prefix
+          download_url_to_file(url, cached_file)
+        end
+
+        Torch.load(cached_file)
+      end
     end
   end
 end
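Note: a sketch of how the new hub helpers pair with Module#load_state_dict (added further down); the URL and the Net model class are hypothetical stand-ins:

    url = "https://example.com/checkpoints/model.pt"        # hypothetical checkpoint URL
    state_dict = Torch::Hub.load_state_dict_from_url(url)   # cached under the torch checkpoints dir
    model = Net.new                                         # any Torch::NN::Module subclass
    model.load_state_dict(state_dict)
    model.eval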
@@ -1,6 +1,6 @@
 module Torch
   module Inspector
-    # TODO make more performance, especially when summarizing
+    # TODO make more performant, especially when summarizing
     # how? only read data that will be displayed
     def inspect
       data =
@@ -14,7 +14,7 @@ module Torch
       if dtype == :bool
         fmt = "%s"
       else
-        values = to_a.flatten
+        values = _flat_data
         abs = values.select { |v| v != 0 }.map(&:abs)
         max = abs.max || 1
         min = abs.min || 1
@@ -25,7 +25,7 @@ module Torch
       end
 
       if floating_point?
-        sci = max / min.to_f > 1000 || max > 1e8 || min < 1e-4
+        sci = max > 1e8 || max < 1e-4
 
         all_int = values.all? { |v| v.finite? && v == v.to_i }
         decimal = all_int ? 1 : 4
@@ -70,6 +70,11 @@ module Torch
           momentum: exponential_average_factor, eps: @eps
         )
       end
+
+      def extra_inspect
+        s = "%{num_features}, eps: %{eps}, momentum: %{momentum}, affine: %{affine}, track_running_stats: %{track_running_stats}"
+        format(s, **dict)
+      end
     end
   end
 end
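Note: this and the conv/pooling hunks below move extra_inspect from positional %s arguments to named %{...} references filled from dict, which presumably returns the layer's options as a Hash. The mechanism is plain Kernel#format with a hash argument:

    format("%{num_features}, eps: %{eps}", num_features: 16, eps: 1.0e-05)
    # => "16, eps: 1.0e-05"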
@@ -20,7 +20,14 @@ module Torch
 
       # TODO add more parameters
       def extra_inspect
-        format("%s, %s, kernel_size: %s, stride: %s", @in_channels, @out_channels, @kernel_size, @stride)
+        s = String.new("%{in_channels}, %{out_channels}, kernel_size: %{kernel_size}, stride: %{stride}")
+        s += ", padding: %{padding}" if @padding != [0] * @padding.size
+        s += ", dilation: %{dilation}" if @dilation != [1] * @dilation.size
+        s += ", output_padding: %{output_padding}" if @output_padding != [0] * @output_padding.size
+        s += ", groups: %{groups}" if @groups != 1
+        s += ", bias: false" unless @bias
+        s += ", padding_mode: %{padding_mode}" if @padding_mode != "zeros"
+        format(s, **dict)
       end
     end
   end
@@ -23,7 +23,7 @@ module Torch
         if bias
           @bias = Parameter.new(Tensor.new(out_channels))
         else
-          raise NotImplementedError
+          register_parameter("bias", nil)
         end
         reset_parameters
       end
@@ -12,7 +12,8 @@ module Torch
       end
 
       def extra_inspect
-        format("kernel_size: %s", @kernel_size)
+        s = "kernel_size: %{kernel_size}, stride: %{stride}, padding: %{padding}, dilation: %{dilation}, ceil_mode: %{ceil_mode}"
+        format(s, **dict)
       end
     end
   end
@@ -67,8 +67,9 @@ module Torch
         self
       end
 
-      def cuda(device: nil)
-        _apply ->(t) { t.cuda(device) }
+      # TODO add device
+      def cuda
+        _apply ->(t) { t.cuda }
       end
 
       def cpu
@@ -112,8 +113,28 @@ module Torch
         destination
       end
 
+      # TODO add strict option
+      # TODO match PyTorch behavior
       def load_state_dict(state_dict)
-        raise NotImplementedYet
+        state_dict.each do |k, input_param|
+          k1, k2 = k.split(".", 2)
+          mod = named_modules[k1]
+          if mod.is_a?(Module)
+            param = mod.named_parameters[k2]
+            if param.is_a?(Parameter)
+              Torch.no_grad do
+                param.copy!(input_param)
+              end
+            else
+              raise Error, "Unknown parameter: #{k1}"
+            end
+          else
+            raise Error, "Unknown module: #{k1}"
+          end
+        end
+
+        # TODO return missing keys and unexpected keys
+        nil
       end
 
      def parameters
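Note: combined with the broadened Torch.save/Torch.load earlier, this allows a simple checkpoint round trip. A sketch, with Net standing in for any Torch::NN::Module subclass:

    model = Net.new
    Torch.save(model.state_dict, "net.pt")

    model2 = Net.new
    model2.load_state_dict(Torch.load("net.pt"))
    model2.eval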
@@ -165,8 +186,22 @@ module Torch
         named_modules.values
       end
 
-      def named_modules
-        {"" => self}.merge(named_children)
+      # TODO return enumerator?
+      def named_modules(memo: nil, prefix: "")
+        ret = {}
+        memo ||= Set.new
+        unless memo.include?(self)
+          memo << self
+          ret[prefix] = self
+          named_children.each do |name, mod|
+            next unless mod.is_a?(Module)
+            submodule_prefix = prefix + (!prefix.empty? ? "." : "") + name
+            mod.named_modules(memo: memo, prefix: submodule_prefix).each do |m|
+              ret[m[0]] = m[1]
+            end
+          end
+        end
+        ret
       end
 
      def train(mode = true)
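Note: a sketch of the dotted prefixes the recursive named_modules produces for nested containers; the exact child names ("0", "1", ...) assume Sequential registers its modules by index:

    model = Torch::NN::Sequential.new(
      Torch::NN::Linear.new(4, 8),
      Torch::NN::Sequential.new(Torch::NN::ReLU.new, Torch::NN::Linear.new(8, 2))
    )
    model.named_modules.keys
    # => ["", "0", "1", "1.0", "1.1"]  (the root, its children, then prefixed grandchildren)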
@@ -203,13 +238,15 @@ module Torch
 
       def inspect
         name = self.class.name.split("::").last
-        if children.empty?
+        if named_children.empty?
           "#{name}(#{extra_inspect})"
         else
           str = String.new
           str << "#{name}(\n"
-          children.each do |name, mod|
-            str << " (#{name}): #{mod.inspect}\n"
+          named_children.each do |name, mod|
+            mod_str = mod.inspect
+            mod_str = mod_str.lines.join(" ")
+            str << " (#{name}): #{mod_str}\n"
           end
           str << ")"
         end
@@ -4,6 +4,8 @@ module Torch
     include Inspector
 
     alias_method :requires_grad?, :requires_grad
+    alias_method :ndim, :dim
+    alias_method :ndimension, :dim
 
     def self.new(*args)
       FloatTensor.new(*args)
@@ -23,8 +25,17 @@ module Torch
       inspect
     end
 
+    # TODO make more performant
     def to_a
-      reshape_arr(_flat_data, shape)
+      arr = _flat_data
+      if shape.empty?
+        arr
+      else
+        shape[1..-1].reverse.each do |dim|
+          arr = arr.each_slice(dim)
+        end
+        arr.to_a
+      end
     end
 
     # TODO support dtype
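Note: the inlined to_a still nests the flat data back into the tensor's shape, for example:

    Torch.arange(6).reshape(2, 3).to_a
    # => [[0, 1, 2], [3, 4, 5]]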
@@ -37,6 +48,10 @@ module Torch
       to("cpu")
     end
 
+    def cuda
+      to("cuda")
+    end
+
     def size(dim = nil)
       if dim
         _size_int(dim)
@@ -58,7 +73,15 @@ module Torch
       if numel != 1
         raise Error, "only one element tensors can be converted to Ruby scalars"
       end
-      _flat_data.first
+      to_a.first
+    end
+
+    def to_i
+      item.to_i
+    end
+
+    def to_f
+      item.to_f
     end
 
     # unsure if this is correct
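Note: the new to_i and to_f delegate to item, so they inherit the same single-element restriction:

    t = Torch.tensor([1.5])
    t.item  # => 1.5
    t.to_i  # => 1
    t.to_f  # => 1.5
    # Torch.tensor([1, 2]).to_i raises, because item requires exactly one element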
@@ -74,7 +97,7 @@ module Torch
     def numo
       cls = Torch._dtype_to_numo[dtype]
       raise Error, "Cannot convert #{dtype} to Numo" unless cls
-      cls.cast(_flat_data).reshape(*shape)
+      cls.from_string(_data_str).reshape(*shape)
     end
 
     def new_ones(*size, **options)
@@ -102,15 +125,6 @@ module Torch
       _view(size)
     end
 
-    # value and other are swapped for some methods
-    def add!(value = 1, other)
-      if other.is_a?(Numeric)
-        _add__scalar(other, value)
-      else
-        _add__tensor(other, value)
-      end
-    end
-
     def +(other)
       add(other)
     end
@@ -139,6 +153,7 @@ module Torch
       neg
     end
 
+    # TODO better compare?
     def <=>(other)
       item <=> other
     end
@@ -186,8 +201,27 @@ module Torch
       end
     end
 
-    def random!(from = 0, to)
-      _random__from_to(from, to)
+    # native functions that need to be manually defined
+
+    # value and other are swapped for some methods
+    def add!(value = 1, other)
+      if other.is_a?(Numeric)
+        _add__scalar(other, value)
+      else
+        _add__tensor(other, value)
+      end
+    end
+
+    # native functions overlap, so need to handle manually
+    def random!(*args)
+      case args.size
+      when 1
+        _random__to(*args)
+      when 2
+        _random__from_to(*args)
+      else
+        _random_(*args)
+      end
     end
 
     private
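Note: the *args dispatch mirrors PyTorch's random_ overloads; a sketch on an integer tensor:

    t = Torch.zeros(3, dtype: :int64)
    t.random!(10)     # one argument:  uniform integers in [0, 10)
    t.random!(5, 10)  # two arguments: uniform integers in [5, 10)
    t.random!         # no arguments:  the full range of the dtype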
@@ -195,17 +229,5 @@ module Torch
     def copy_to(dst, src)
       dst.copy!(src)
     end
-
-    def reshape_arr(arr, dims)
-      if dims.empty?
-        arr
-      else
-        arr = arr.flatten
-        dims[1..-1].reverse.each do |dim|
-          arr = arr.each_slice(dim)
-        end
-        arr.to_a
-      end
-    end
   end
 end