torch-rb 0.1.3 → 0.1.8

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (115)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +30 -0
  3. data/README.md +5 -2
  4. data/ext/torch/ext.cpp +130 -555
  5. data/ext/torch/extconf.rb +9 -0
  6. data/ext/torch/templates.cpp +55 -0
  7. data/ext/torch/templates.hpp +244 -0
  8. data/lib/torch.rb +209 -171
  9. data/lib/torch/inspector.rb +23 -19
  10. data/lib/torch/native/dispatcher.rb +48 -0
  11. data/lib/torch/native/function.rb +110 -0
  12. data/lib/torch/native/generator.rb +168 -0
  13. data/lib/torch/native/native_functions.yaml +6491 -0
  14. data/lib/torch/native/parser.rb +134 -0
  15. data/lib/torch/nn/avg_pool1d.rb +18 -0
  16. data/lib/torch/nn/avg_pool2d.rb +19 -0
  17. data/lib/torch/nn/avg_pool3d.rb +19 -0
  18. data/lib/torch/nn/avg_poolnd.rb +9 -0
  19. data/lib/torch/nn/batch_norm.rb +75 -0
  20. data/lib/torch/nn/batch_norm1d.rb +11 -0
  21. data/lib/torch/nn/batch_norm2d.rb +11 -0
  22. data/lib/torch/nn/batch_norm3d.rb +11 -0
  23. data/lib/torch/nn/bce_loss.rb +13 -0
  24. data/lib/torch/nn/bce_with_logits_loss.rb +15 -0
  25. data/lib/torch/nn/bilinear.rb +38 -0
  26. data/lib/torch/nn/constant_pad1d.rb +10 -0
  27. data/lib/torch/nn/constant_pad2d.rb +10 -0
  28. data/lib/torch/nn/constant_pad3d.rb +10 -0
  29. data/lib/torch/nn/constant_padnd.rb +18 -0
  30. data/lib/torch/nn/conv1d.rb +22 -0
  31. data/lib/torch/nn/conv2d.rb +10 -20
  32. data/lib/torch/nn/conv3d.rb +22 -0
  33. data/lib/torch/nn/convnd.rb +3 -3
  34. data/lib/torch/nn/cosine_embedding_loss.rb +14 -0
  35. data/lib/torch/nn/cosine_similarity.rb +15 -0
  36. data/lib/torch/nn/cross_entropy_loss.rb +14 -0
  37. data/lib/torch/nn/ctc_loss.rb +15 -0
  38. data/lib/torch/nn/dropoutnd.rb +2 -2
  39. data/lib/torch/nn/embedding_bag.rb +34 -0
  40. data/lib/torch/nn/fold.rb +20 -0
  41. data/lib/torch/nn/functional.rb +379 -32
  42. data/lib/torch/nn/group_norm.rb +36 -0
  43. data/lib/torch/nn/gru.rb +49 -0
  44. data/lib/torch/nn/hardshrink.rb +18 -0
  45. data/lib/torch/nn/hinge_embedding_loss.rb +14 -0
  46. data/lib/torch/nn/identity.rb +14 -0
  47. data/lib/torch/nn/init.rb +58 -1
  48. data/lib/torch/nn/instance_norm.rb +20 -0
  49. data/lib/torch/nn/instance_norm1d.rb +18 -0
  50. data/lib/torch/nn/instance_norm2d.rb +11 -0
  51. data/lib/torch/nn/instance_norm3d.rb +11 -0
  52. data/lib/torch/nn/kl_div_loss.rb +13 -0
  53. data/lib/torch/nn/l1_loss.rb +13 -0
  54. data/lib/torch/nn/layer_norm.rb +35 -0
  55. data/lib/torch/nn/leaky_relu.rb +20 -0
  56. data/lib/torch/nn/linear.rb +12 -11
  57. data/lib/torch/nn/local_response_norm.rb +21 -0
  58. data/lib/torch/nn/log_sigmoid.rb +9 -0
  59. data/lib/torch/nn/log_softmax.rb +14 -0
  60. data/lib/torch/nn/loss.rb +10 -0
  61. data/lib/torch/nn/lp_pool1d.rb +9 -0
  62. data/lib/torch/nn/lp_pool2d.rb +9 -0
  63. data/lib/torch/nn/lp_poolnd.rb +22 -0
  64. data/lib/torch/nn/lstm.rb +66 -0
  65. data/lib/torch/nn/margin_ranking_loss.rb +14 -0
  66. data/lib/torch/nn/max_pool1d.rb +9 -0
  67. data/lib/torch/nn/max_pool2d.rb +9 -0
  68. data/lib/torch/nn/max_pool3d.rb +9 -0
  69. data/lib/torch/nn/max_poolnd.rb +19 -0
  70. data/lib/torch/nn/max_unpool1d.rb +16 -0
  71. data/lib/torch/nn/max_unpool2d.rb +16 -0
  72. data/lib/torch/nn/max_unpool3d.rb +16 -0
  73. data/lib/torch/nn/max_unpoolnd.rb +9 -0
  74. data/lib/torch/nn/module.rb +186 -35
  75. data/lib/torch/nn/mse_loss.rb +2 -2
  76. data/lib/torch/nn/multi_label_margin_loss.rb +13 -0
  77. data/lib/torch/nn/multi_label_soft_margin_loss.rb +13 -0
  78. data/lib/torch/nn/multi_margin_loss.rb +17 -0
  79. data/lib/torch/nn/nll_loss.rb +14 -0
  80. data/lib/torch/nn/pairwise_distance.rb +16 -0
  81. data/lib/torch/nn/parameter.rb +2 -2
  82. data/lib/torch/nn/poisson_nll_loss.rb +16 -0
  83. data/lib/torch/nn/prelu.rb +19 -0
  84. data/lib/torch/nn/reflection_pad1d.rb +10 -0
  85. data/lib/torch/nn/reflection_pad2d.rb +10 -0
  86. data/lib/torch/nn/reflection_padnd.rb +13 -0
  87. data/lib/torch/nn/relu.rb +8 -3
  88. data/lib/torch/nn/replication_pad1d.rb +10 -0
  89. data/lib/torch/nn/replication_pad2d.rb +10 -0
  90. data/lib/torch/nn/replication_pad3d.rb +10 -0
  91. data/lib/torch/nn/replication_padnd.rb +13 -0
  92. data/lib/torch/nn/rnn.rb +22 -0
  93. data/lib/torch/nn/rnn_base.rb +198 -0
  94. data/lib/torch/nn/sequential.rb +1 -10
  95. data/lib/torch/nn/sigmoid.rb +9 -0
  96. data/lib/torch/nn/smooth_l1_loss.rb +13 -0
  97. data/lib/torch/nn/soft_margin_loss.rb +13 -0
  98. data/lib/torch/nn/softmax.rb +18 -0
  99. data/lib/torch/nn/softmax2d.rb +10 -0
  100. data/lib/torch/nn/softmin.rb +14 -0
  101. data/lib/torch/nn/softplus.rb +19 -0
  102. data/lib/torch/nn/softshrink.rb +18 -0
  103. data/lib/torch/nn/softsign.rb +9 -0
  104. data/lib/torch/nn/tanh.rb +9 -0
  105. data/lib/torch/nn/tanhshrink.rb +9 -0
  106. data/lib/torch/nn/triplet_margin_loss.rb +18 -0
  107. data/lib/torch/nn/unfold.rb +19 -0
  108. data/lib/torch/nn/utils.rb +25 -0
  109. data/lib/torch/nn/weighted_loss.rb +10 -0
  110. data/lib/torch/nn/zero_pad2d.rb +9 -0
  111. data/lib/torch/random.rb +10 -0
  112. data/lib/torch/tensor.rb +51 -44
  113. data/lib/torch/version.rb +1 -1
  114. metadata +98 -6
  115. data/lib/torch/ext.bundle +0 -0
data/ext/torch/extconf.rb

@@ -10,6 +10,9 @@ $CXXFLAGS << " -std=c++11"
  # silence ruby/intern.h warning
  $CXXFLAGS << " -Wno-deprecated-register"

+ # silence torch warnings
+ $CXXFLAGS << " -Wno-shorten-64-to-32 -Wno-missing-noreturn"
+
  inc, lib = dir_config("torch")

  inc ||= "/usr/local/include"
@@ -22,4 +25,10 @@ $LDFLAGS << " -Wl,-rpath,#{lib}"
  $LDFLAGS << " -L#{lib}"
  $LDFLAGS << " -ltorch -lc10"

+ # generate C++ functions
+ puts "Generating C++ functions..."
+ require_relative "../../lib/torch/native/generator"
+ Torch::Native::Generator.generate_cpp_functions
+
+ # create makefile
  create_makefile("torch/ext")
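With this change most bindings are no longer written by hand in ext.cpp: extconf.rb now runs the generator before create_makefile, so the C++ function definitions are produced at gem build time from native_functions.yaml. A rough sketch of invoking the same step by hand from a source checkout (hypothetical manual run; the output location is whatever the generator is written to use):

    # Hypothetical manual run of the build-time codegen step shown above,
    # from the root of a torch-rb source checkout.
    require_relative "lib/torch/native/generator"

    # Parses lib/torch/native/native_functions.yaml and writes the generated
    # C++ function definitions that ext.cpp compiles against.
    Torch::Native::Generator.generate_cpp_functions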
data/ext/torch/templates.cpp

@@ -0,0 +1,55 @@
+ #include <torch/torch.h>
+ #include <rice/Object.hpp>
+ #include "templates.hpp"
+
+ Object wrap(std::tuple<torch::Tensor, torch::Tensor> x) {
+   Array a;
+   a.push(to_ruby<torch::Tensor>(std::get<0>(x)));
+   a.push(to_ruby<torch::Tensor>(std::get<1>(x)));
+   return Object(a);
+ }
+
+ Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> x) {
+   Array a;
+   a.push(to_ruby<torch::Tensor>(std::get<0>(x)));
+   a.push(to_ruby<torch::Tensor>(std::get<1>(x)));
+   a.push(to_ruby<torch::Tensor>(std::get<2>(x)));
+   return Object(a);
+ }
+
+ Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> x) {
+   Array a;
+   a.push(to_ruby<torch::Tensor>(std::get<0>(x)));
+   a.push(to_ruby<torch::Tensor>(std::get<1>(x)));
+   a.push(to_ruby<torch::Tensor>(std::get<2>(x)));
+   a.push(to_ruby<torch::Tensor>(std::get<3>(x)));
+   return Object(a);
+ }
+
+ Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> x) {
+   Array a;
+   a.push(to_ruby<torch::Tensor>(std::get<0>(x)));
+   a.push(to_ruby<torch::Tensor>(std::get<1>(x)));
+   a.push(to_ruby<torch::Tensor>(std::get<2>(x)));
+   a.push(to_ruby<torch::Tensor>(std::get<3>(x)));
+   a.push(to_ruby<torch::Tensor>(std::get<4>(x)));
+   return Object(a);
+ }
+
+ Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, int64_t> x) {
+   Array a;
+   a.push(to_ruby<torch::Tensor>(std::get<0>(x)));
+   a.push(to_ruby<torch::Tensor>(std::get<1>(x)));
+   a.push(to_ruby<torch::Tensor>(std::get<2>(x)));
+   a.push(to_ruby<int64_t>(std::get<3>(x)));
+   return Object(a);
+ }
+
+ Object wrap(std::tuple<torch::Tensor, torch::Tensor, double, int64_t> x) {
+   Array a;
+   a.push(to_ruby<torch::Tensor>(std::get<0>(x)));
+   a.push(to_ruby<torch::Tensor>(std::get<1>(x)));
+   a.push(to_ruby<double>(std::get<2>(x)));
+   a.push(to_ruby<int64_t>(std::get<3>(x)));
+   return Object(a);
+ }
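These wrap overloads are how LibTorch operations that return a C++ tuple surface in Ruby: each element is converted with to_ruby and the whole result comes back as a plain Ruby Array. A hedged Ruby-side illustration (it assumes topk is among the functions generated from native_functions.yaml in this range, which has not been re-verified against this exact release):

    require "torch"

    x = Torch.tensor([3.0, 1.0, 2.0])

    # topk returns std::tuple<Tensor, Tensor> on the C++ side; the wrap()
    # overload above turns it into a two-element Ruby array.
    values, indices = Torch.topk(x, 2)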
data/ext/torch/templates.hpp

@@ -0,0 +1,244 @@
+ #pragma once
+
+ #ifdef isfinite
+ #undef isfinite
+ #endif
+
+ #include <rice/Array.hpp>
+ #include <rice/Object.hpp>
+
+ using namespace Rice;
+
+ // need to wrap torch::IntArrayRef() since
+ // it doesn't own underlying data
+ class IntArrayRef {
+   std::vector<int64_t> vec;
+   public:
+     IntArrayRef(Object o) {
+       Array a = Array(o);
+       for (size_t i = 0; i < a.size(); i++) {
+         vec.push_back(from_ruby<int64_t>(a[i]));
+       }
+     }
+     operator torch::IntArrayRef() {
+       return torch::IntArrayRef(vec);
+     }
+ };
+
+ template<>
+ inline
+ IntArrayRef from_ruby<IntArrayRef>(Object x)
+ {
+   return IntArrayRef(x);
+ }
+
+ // for now
+ class Scalar {
+   torch::Scalar value;
+   public:
+     Scalar(Object o) {
+       // TODO cast based on Ruby type
+       if (o.rb_type() == T_FIXNUM) {
+         value = torch::Scalar(from_ruby<int64_t>(o));
+       } else {
+         value = torch::Scalar(from_ruby<float>(o));
+       }
+     }
+     operator torch::Scalar() {
+       return value;
+     }
+ };
+
+ template<>
+ inline
+ Scalar from_ruby<Scalar>(Object x)
+ {
+   return Scalar(x);
+ }
+
+ class TensorList {
+   std::vector<torch::Tensor> vec;
+   public:
+     TensorList(Object o) {
+       Array a = Array(o);
+       for (size_t i = 0; i < a.size(); i++) {
+         vec.push_back(from_ruby<torch::Tensor>(a[i]));
+       }
+     }
+     operator torch::TensorList() {
+       return torch::TensorList(vec);
+     }
+ };
+
+ template<>
+ inline
+ TensorList from_ruby<TensorList>(Object x)
+ {
+   return TensorList(x);
+ }
+
+ class FanModeType {
+   std::string s;
+   public:
+     FanModeType(Object o) {
+       s = String(o).str();
+     }
+     operator torch::nn::init::FanModeType() {
+       if (s == "fan_in") {
+         return torch::kFanIn;
+       } else if (s == "fan_out") {
+         return torch::kFanOut;
+       } else {
+         throw std::runtime_error("Unsupported nonlinearity type: " + s);
+       }
+     }
+ };
+
+ template<>
+ inline
+ FanModeType from_ruby<FanModeType>(Object x)
+ {
+   return FanModeType(x);
+ }
+
+ class NonlinearityType {
+   std::string s;
+   public:
+     NonlinearityType(Object o) {
+       s = String(o).str();
+     }
+     operator torch::nn::init::NonlinearityType() {
+       if (s == "linear") {
+         return torch::kLinear;
+       } else if (s == "conv1d") {
+         return torch::kConv1D;
+       } else if (s == "conv2d") {
+         return torch::kConv2D;
+       } else if (s == "conv3d") {
+         return torch::kConv3D;
+       } else if (s == "conv_transpose1d") {
+         return torch::kConvTranspose1D;
+       } else if (s == "conv_transpose2d") {
+         return torch::kConvTranspose2D;
+       } else if (s == "conv_transpose3d") {
+         return torch::kConvTranspose3D;
+       } else if (s == "sigmoid") {
+         return torch::kSigmoid;
+       } else if (s == "tanh") {
+         return torch::kTanh;
+       } else if (s == "relu") {
+         return torch::kReLU;
+       } else if (s == "leaky_relu") {
+         return torch::kLeakyReLU;
+       } else {
+         throw std::runtime_error("Unsupported nonlinearity type: " + s);
+       }
+     }
+ };
+
+ template<>
+ inline
+ NonlinearityType from_ruby<NonlinearityType>(Object x)
+ {
+   return NonlinearityType(x);
+ }
+
+ class MyReduction {
+   Object value;
+   public:
+     MyReduction(Object o) {
+       value = o;
+     }
+     operator int64_t() {
+       if (value.is_nil()) {
+         return torch::Reduction::None;
+       }
+
+       std::string s = String(value).str();
+       if (s == "mean") {
+         return torch::Reduction::Mean;
+       } else if (s == "sum") {
+         return torch::Reduction::Sum;
+       } else {
+         throw std::runtime_error("Unsupported reduction: " + s);
+       }
+     }
+ };
+
+ template<>
+ inline
+ MyReduction from_ruby<MyReduction>(Object x)
+ {
+   return MyReduction(x);
+ }
+
+ typedef torch::Tensor Tensor;
+
+ class OptionalTensor {
+   Object value;
+   public:
+     OptionalTensor(Object o) {
+       value = o;
+     }
+     operator torch::Tensor() {
+       if (value.is_nil()) {
+         return {};
+       }
+       return from_ruby<torch::Tensor>(value);
+     }
+ };
+
+ template<>
+ inline
+ OptionalTensor from_ruby<OptionalTensor>(Object x)
+ {
+   return OptionalTensor(x);
+ }
+
+ class ScalarType {
+   Object value;
+   public:
+     ScalarType(Object o) {
+       value = o;
+     }
+     operator at::ScalarType() {
+       throw std::runtime_error("ScalarType arguments not implemented yet");
+     }
+ };
+
+ template<>
+ inline
+ ScalarType from_ruby<ScalarType>(Object x)
+ {
+   return ScalarType(x);
+ }
+
+ class OptionalScalarType {
+   Object value;
+   public:
+     OptionalScalarType(Object o) {
+       value = o;
+     }
+     operator c10::optional<at::ScalarType>() {
+       if (value.is_nil()) {
+         return c10::nullopt;
+       }
+       return ScalarType(value);
+     }
+ };
+
+ template<>
+ inline
+ OptionalScalarType from_ruby<OptionalScalarType>(Object x)
+ {
+   return OptionalScalarType(x);
+ }
+
+ typedef torch::Device Device;
+
+ Object wrap(std::tuple<torch::Tensor, torch::Tensor> x);
+ Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> x);
+ Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> x);
+ Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> x);
+ Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, int64_t> x);
+ Object wrap(std::tuple<torch::Tensor, torch::Tensor, double, int64_t> x);
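Each helper class above pairs with a from_ruby specialization so that the generated bindings can accept idiomatic Ruby values: plain arrays where LibTorch expects an IntArrayRef or TensorList, nil for optional tensors, and strings such as "mean" or "sum" for the reduction enum. A small hedged sketch of what that enables (method availability is assumed from the file list above, not re-verified):

    require "torch"

    x = Torch.rand(2, 3)

    # An IntArrayRef parameter is passed as an ordinary Ruby array.
    y = x.reshape([3, 2])

    # MyReduction maps nil / "mean" / "sum" onto the C++ reduction enum,
    # so loss modules can take the reduction as a keyword string.
    loss_fn = Torch::NN::MSELoss.new(reduction: "sum")
    loss = loss_fn.call(y, Torch.ones(3, 2))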
data/lib/torch.rb

@@ -1,6 +1,11 @@
  # ext
  require "torch/ext"

+ # native functions
+ require "torch/native/generator"
+ require "torch/native/parser"
+ require "torch/native/dispatcher"
+
  # modules
  require "torch/inspector"
  require "torch/tensor"
@@ -22,31 +27,145 @@ require "torch/optim/sgd"
  require "torch/optim/lr_scheduler/lr_scheduler"
  require "torch/optim/lr_scheduler/step_lr"

- # nn base classes
+ # nn parameters
+ require "torch/nn/parameter"
+ require "torch/nn/utils"
+
+ # nn containers
  require "torch/nn/module"
+ require "torch/nn/sequential"
+
+ # nn convolution layers
  require "torch/nn/convnd"
- require "torch/nn/dropoutnd"
+ require "torch/nn/conv1d"
+ require "torch/nn/conv2d"
+ require "torch/nn/conv3d"
+ require "torch/nn/unfold"
+ require "torch/nn/fold"
+
+ # nn pooling layers
+ require "torch/nn/max_poolnd"
+ require "torch/nn/max_pool1d"
+ require "torch/nn/max_pool2d"
+ require "torch/nn/max_pool3d"
+ require "torch/nn/max_unpoolnd"
+ require "torch/nn/max_unpool1d"
+ require "torch/nn/max_unpool2d"
+ require "torch/nn/max_unpool3d"
+ require "torch/nn/avg_poolnd"
+ require "torch/nn/avg_pool1d"
+ require "torch/nn/avg_pool2d"
+ require "torch/nn/avg_pool3d"
+ require "torch/nn/lp_poolnd"
+ require "torch/nn/lp_pool1d"
+ require "torch/nn/lp_pool2d"
+
+ # nn padding layers
+ require "torch/nn/reflection_padnd"
+ require "torch/nn/reflection_pad1d"
+ require "torch/nn/reflection_pad2d"
+ require "torch/nn/replication_padnd"
+ require "torch/nn/replication_pad1d"
+ require "torch/nn/replication_pad2d"
+ require "torch/nn/replication_pad3d"
+ require "torch/nn/constant_padnd"
+ require "torch/nn/constant_pad1d"
+ require "torch/nn/constant_pad2d"
+ require "torch/nn/constant_pad3d"
+ require "torch/nn/zero_pad2d"
+
+ # nn normalization layers
+ require "torch/nn/batch_norm"
+ require "torch/nn/batch_norm1d"
+ require "torch/nn/batch_norm2d"
+ require "torch/nn/batch_norm3d"
+ require "torch/nn/group_norm"
+ require "torch/nn/instance_norm"
+ require "torch/nn/instance_norm1d"
+ require "torch/nn/instance_norm2d"
+ require "torch/nn/instance_norm3d"
+ require "torch/nn/layer_norm"
+ require "torch/nn/local_response_norm"
+
+ # nn recurrent layers
+ require "torch/nn/rnn_base"
+ require "torch/nn/rnn"
+ require "torch/nn/lstm"
+ require "torch/nn/gru"
+
+ # nn linear layers
+ require "torch/nn/bilinear"
+ require "torch/nn/identity"
+ require "torch/nn/linear"

- # nn
+ # nn dropout layers
+ require "torch/nn/dropoutnd"
  require "torch/nn/alpha_dropout"
- require "torch/nn/conv2d"
  require "torch/nn/dropout"
  require "torch/nn/dropout2d"
  require "torch/nn/dropout3d"
- require "torch/nn/embedding"
  require "torch/nn/feature_alpha_dropout"
+
+ # nn activations
+ require "torch/nn/hardshrink"
+ require "torch/nn/leaky_relu"
+ require "torch/nn/log_sigmoid"
+ require "torch/nn/prelu"
+ require "torch/nn/relu"
+ require "torch/nn/sigmoid"
+ require "torch/nn/softplus"
+ require "torch/nn/softshrink"
+ require "torch/nn/softsign"
+ require "torch/nn/tanh"
+ require "torch/nn/tanhshrink"
+
+ # nn activations other
+ require "torch/nn/log_softmax"
+ require "torch/nn/softmax"
+ require "torch/nn/softmax2d"
+ require "torch/nn/softmin"
+
+ # nn sparse layers
+ require "torch/nn/embedding"
+ require "torch/nn/embedding_bag"
+
+ # nn distance functions
+ require "torch/nn/cosine_similarity"
+ require "torch/nn/pairwise_distance"
+
+ # nn loss functions
+ require "torch/nn/loss"
+ require "torch/nn/weighted_loss"
+ require "torch/nn/bce_loss"
+ require "torch/nn/bce_with_logits_loss"
+ require "torch/nn/cosine_embedding_loss"
+ require "torch/nn/cross_entropy_loss"
+ require "torch/nn/ctc_loss"
+ require "torch/nn/hinge_embedding_loss"
+ require "torch/nn/kl_div_loss"
+ require "torch/nn/l1_loss"
+ require "torch/nn/margin_ranking_loss"
+ require "torch/nn/mse_loss"
+ require "torch/nn/multi_label_margin_loss"
+ require "torch/nn/multi_label_soft_margin_loss"
+ require "torch/nn/multi_margin_loss"
+ require "torch/nn/nll_loss"
+ require "torch/nn/poisson_nll_loss"
+ require "torch/nn/smooth_l1_loss"
+ require "torch/nn/soft_margin_loss"
+ require "torch/nn/triplet_margin_loss"
+
+ # nn other
  require "torch/nn/functional"
  require "torch/nn/init"
- require "torch/nn/linear"
- require "torch/nn/mse_loss"
- require "torch/nn/parameter"
- require "torch/nn/relu"
- require "torch/nn/sequential"

  # utils
  require "torch/utils/data/data_loader"
  require "torch/utils/data/tensor_dataset"

+ # random
+ require "torch/random"
+
  module Torch
    class Error < StandardError; end
    class NotImplementedYet < StandardError
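The require list grows from a handful of modules to the full set of layers shipped in this range (pooling, padding, normalization, recurrent layers, loss functions, and so on). A hedged sketch of wiring a few of the newly required classes together (constructor arguments follow their PyTorch counterparts and have not been re-verified against this exact release):

    require "torch"

    # A small classifier using several of the newly required modules.
    model = Torch::NN::Sequential.new(
      Torch::NN::Linear.new(4, 8),
      Torch::NN::BatchNorm1d.new(8),
      Torch::NN::LeakyReLU.new,
      Torch::NN::Linear.new(8, 3)
    )

    criterion = Torch::NN::CrossEntropyLoss.new

    x = Torch.rand(5, 4)                   # batch of 5 samples, 4 features
    target = Torch.tensor([0, 2, 1, 0, 2]) # class indices (int64 by default)

    loss = criterion.call(model.call(x), target)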
@@ -57,7 +176,6 @@ module Torch
 
  # keys: https://pytorch.org/docs/stable/tensor_attributes.html#torch.torch.dtype
  # values: https://github.com/pytorch/pytorch/blob/master/c10/core/ScalarType.h
- # complex and quantized types not supported by PyTorch yet
  DTYPE_TO_ENUM = {
    uint8: 0,
    int8: 1,
@@ -73,17 +191,52 @@ module Torch
    float32: 6,
    double: 7,
    float64: 7,
-   # complex_half: 8,
-   # complex_float: 9,
-   # complex_double: 10,
+   complex_half: 8,
+   complex_float: 9,
+   complex_double: 10,
    bool: 11,
-   # qint8: 12,
-   # quint8: 13,
-   # qint32: 14,
-   # bfloat16: 15
+   qint8: 12,
+   quint8: 13,
+   qint32: 14,
+   bfloat16: 15
  }
  ENUM_TO_DTYPE = DTYPE_TO_ENUM.map(&:reverse).to_h

+ def self._make_tensor_class(dtype, cuda = false)
+   cls = Class.new
+   device = cuda ? "cuda" : "cpu"
+   cls.define_singleton_method("new") do |*args|
+     if args.size == 1 && args.first.is_a?(Tensor)
+       args.first.send(dtype).to(device)
+     elsif args.size == 1 && args.first.is_a?(Array)
+       Torch.tensor(args.first, dtype: dtype, device: device)
+     else
+       Torch.empty(*args, dtype: dtype, device: device)
+     end
+   end
+   cls
+ end
+
+ FloatTensor = _make_tensor_class(:float32)
+ DoubleTensor = _make_tensor_class(:float64)
+ HalfTensor = _make_tensor_class(:float16)
+ ByteTensor = _make_tensor_class(:uint8)
+ CharTensor = _make_tensor_class(:int8)
+ ShortTensor = _make_tensor_class(:int16)
+ IntTensor = _make_tensor_class(:int32)
+ LongTensor = _make_tensor_class(:int64)
+ BoolTensor = _make_tensor_class(:bool)
+
+ CUDA::FloatTensor = _make_tensor_class(:float32, true)
+ CUDA::DoubleTensor = _make_tensor_class(:float64, true)
+ CUDA::HalfTensor = _make_tensor_class(:float16, true)
+ CUDA::ByteTensor = _make_tensor_class(:uint8, true)
+ CUDA::CharTensor = _make_tensor_class(:int8, true)
+ CUDA::ShortTensor = _make_tensor_class(:int16, true)
+ CUDA::IntTensor = _make_tensor_class(:int32, true)
+ CUDA::LongTensor = _make_tensor_class(:int64, true)
+ CUDA::BoolTensor = _make_tensor_class(:bool, true)
+
  class << self
    # Torch.float, Torch.long, etc
    DTYPE_TO_ENUM.each_key do |dtype|
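_make_tensor_class builds the legacy PyTorch-style constructors (Torch::FloatTensor, Torch::CUDA::LongTensor, and so on) on top of Torch.tensor and Torch.empty. Based on the branches above, usage looks roughly like this (sketch):

    require "torch"

    # One Array argument -> Torch.tensor with the class's dtype.
    a = Torch::FloatTensor.new([1, 2, 3])

    # Any other arguments -> Torch.empty, i.e. an uninitialized tensor
    # of that shape and dtype.
    b = Torch::LongTensor.new(2, 3)

    # A single Tensor argument is cast to the dtype (and moved to the
    # device for the CUDA:: variants).
    c = Torch::DoubleTensor.new(a)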
@@ -120,6 +273,8 @@ module Torch
  # use method for cases when Numo not available
  # or available after Torch loaded
  def _dtype_to_numo
+   raise Error, "Numo not found" unless defined?(Numo::NArray)
+
    {
      uint8: Numo::UInt8,
      int8: Numo::Int8,
@@ -131,6 +286,29 @@ module Torch
    }
  end

+ def no_grad
+   previous_value = grad_enabled?
+   begin
+     _set_grad_enabled(false)
+     yield
+   ensure
+     _set_grad_enabled(previous_value)
+   end
+ end
+
+ def device(str)
+   Device.new(str)
+ end
+
+ def save(obj, f)
+   raise NotImplementedYet unless obj.is_a?(Tensor)
+   File.binwrite(f, _save(obj))
+ end
+
+ def load(f)
+   raise NotImplementedYet
+ end
+
  # --- begin tensor creation: https://pytorch.org/cppdocs/notes/tensor_creation.html ---

  def arange(start, finish = nil, step = 1, **options)
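no_grad restores the previous grad-enabled state in an ensure block, so gradient tracking comes back even if the block raises, and save only handles single tensors for now (anything else raises NotImplementedYet). A brief hedged sketch:

    require "torch"

    x = Torch.ones(2, 2, requires_grad: true)

    y = nil
    Torch.no_grad do
      # Work in here is not tracked by autograd; the previous setting is
      # restored when the block exits, even on an exception.
      y = x * 2
    end

    # Only tensors can be serialized in this version.
    Torch.save(y, "y.pt")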
@@ -200,8 +378,12 @@ module Torch
    data = [data].compact
  end

- if options[:dtype].nil? && data.all? { |v| v.is_a?(Integer) }
-   options[:dtype] = :int64
+ if options[:dtype].nil?
+   if data.all? { |v| v.is_a?(Integer) }
+     options[:dtype] = :int64
+   elsif data.all? { |v| v == true || v == false }
+     options[:dtype] = :bool
+   end
  end

  _tensor(data, size, tensor_options(**options))
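Torch.tensor now also infers a bool dtype when every element is true or false, alongside the existing int64 inference for all-integer data. Roughly (sketch; the floating-point default stays float32):

    require "torch"

    Torch.tensor([1, 2, 3]).dtype           # :int64   (all integers)
    Torch.tensor([true, false, true]).dtype # :bool    (new in this range)
    Torch.tensor([1.0, 2.5]).dtype          # :float32 (unchanged default)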
@@ -210,19 +392,19 @@ module Torch
  # --- begin like ---

  def ones_like(input, **options)
-   ones(input.size, like_options(input, options))
+   ones(input.size, **like_options(input, options))
  end

  def empty_like(input, **options)
-   empty(input.size, like_options(input, options))
+   empty(input.size, **like_options(input, options))
  end

  def full_like(input, fill_value, **options)
-   full(input.size, fill_value, like_options(input, options))
+   full(input.size, fill_value, **like_options(input, options))
  end

  def rand_like(input, **options)
-   rand(input.size, like_options(input, options))
+   rand(input.size, **like_options(input, options))
  end

  def randint_like(input, low, high = nil, **options)
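The *_like helpers previously passed like_options as a positional hash; with the double splat they forward it as keyword options, so per-call overrides behave as expected. For example (hedged sketch):

    require "torch"

    x = Torch.rand(2, 3)

    # Shape comes from x; options are forwarded as keywords,
    # so a dtype override works as expected.
    a = Torch.ones_like(x)
    b = Torch.ones_like(x, dtype: :float64)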
@@ -231,163 +413,19 @@ module Torch
      high = low
      low = 0
    end
-   randint(low, high, input.size, like_options(input, options))
+   randint(low, high, input.size, **like_options(input, options))
  end

  def randn_like(input, **options)
-   randn(input.size, like_options(input, options))
+   randn(input.size, **like_options(input, options))
  end

  def zeros_like(input, **options)
-   zeros(input.size, like_options(input, options))
- end
-
- # --- begin operations ---
-
- %w(add sub mul div remainder).each do |op|
-   define_method(op) do |input, other, **options|
-     execute_op(op, input, other, **options)
-   end
- end
-
- def neg(input)
-   _neg(input)
- end
-
- def no_grad
-   previous_value = grad_enabled?
-   begin
-     _set_grad_enabled(false)
-     yield
-   ensure
-     _set_grad_enabled(previous_value)
-   end
- end
-
- # TODO support out
- def mean(input, dim = nil, keepdim: false)
-   if dim
-     _mean_dim(input, dim, keepdim)
-   else
-     _mean(input)
-   end
- end
-
- # TODO support dtype
- def sum(input, dim = nil, keepdim: false)
-   if dim
-     _sum_dim(input, dim, keepdim)
-   else
-     _sum(input)
-   end
- end
-
- def argmax(input, dim = nil, keepdim: false)
-   if dim
-     _argmax_dim(input, dim, keepdim)
-   else
-     _argmax(input)
-   end
- end
-
- def eq(input, other)
-   _eq(input, other)
- end
-
- def norm(input)
-   _norm(input)
- end
-
- def pow(input, exponent)
-   _pow(input, exponent)
- end
-
- def min(input)
-   _min(input)
- end
-
- def max(input, dim = nil, keepdim: false, out: nil)
-   if dim
-     raise NotImplementedYet unless out
-     _max_out(out[0], out[1], input, dim, keepdim)
-   else
-     _max(input)
-   end
- end
-
- def exp(input)
-   _exp(input)
- end
-
- def log(input)
-   _log(input)
- end
-
- def sign(input)
-   _sign(input)
- end
-
- def gt(input, other)
-   _gt(input, other)
- end
-
- def lt(input, other)
-   _lt(input, other)
- end
-
- def unsqueeze(input, dim)
-   _unsqueeze(input, dim)
- end
-
- def dot(input, tensor)
-   _dot(input, tensor)
- end
-
- def cat(tensors, dim = 0)
-   _cat(tensors, dim)
- end
-
- def matmul(input, other)
-   _matmul(input, other)
- end
-
- def reshape(input, shape)
-   _reshape(input, shape)
- end
-
- def flatten(input, start_dim: 0, end_dim: -1)
-   _flatten(input, start_dim, end_dim)
- end
-
- def sqrt(input)
-   _sqrt(input)
- end
-
- def abs(input)
-   _abs(input)
- end
-
- def device(str)
-   Device.new(str)
+   zeros(input.size, **like_options(input, options))
  end

  private

- def execute_op(op, input, other, out: nil)
-   scalar = other.is_a?(Numeric)
-   if out
-     # TODO make work with scalars
-     raise Error, "out not supported with scalar yet" if scalar
-     send("_#{op}_out", out, input, other)
-   else
-     if scalar
-       send("_#{op}_scalar", input, other)
-     else
-       send("_#{op}", input, other)
-     end
-   end
- end
-
  def tensor_size(size)
    size.flatten
  end
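The long block of hand-written wrappers removed here (mean, sum, max, matmul, and friends) is the counterpart of the generator added to extconf.rb: these operations are now defined by Torch::Native::Dispatcher from native_functions.yaml, while the calling convention stays the same. A hedged sketch (assuming these names are among the generated functions):

    require "torch"

    x = Torch.tensor([[1.0, 2.0], [3.0, 4.0]])

    # Previously hand-written module methods, now generated:
    Torch.sum(x)        # reduce over all elements
    Torch.mean(x, 1)    # mean along dim 1
    Torch.matmul(x, x)  # same calling convention as before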