torch-rb 0.1.3 → 0.1.8

Sign up to get free protection for your applications and to get access to all the features.
Files changed (115)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +30 -0
  3. data/README.md +5 -2
  4. data/ext/torch/ext.cpp +130 -555
  5. data/ext/torch/extconf.rb +9 -0
  6. data/ext/torch/templates.cpp +55 -0
  7. data/ext/torch/templates.hpp +244 -0
  8. data/lib/torch.rb +209 -171
  9. data/lib/torch/inspector.rb +23 -19
  10. data/lib/torch/native/dispatcher.rb +48 -0
  11. data/lib/torch/native/function.rb +110 -0
  12. data/lib/torch/native/generator.rb +168 -0
  13. data/lib/torch/native/native_functions.yaml +6491 -0
  14. data/lib/torch/native/parser.rb +134 -0
  15. data/lib/torch/nn/avg_pool1d.rb +18 -0
  16. data/lib/torch/nn/avg_pool2d.rb +19 -0
  17. data/lib/torch/nn/avg_pool3d.rb +19 -0
  18. data/lib/torch/nn/avg_poolnd.rb +9 -0
  19. data/lib/torch/nn/batch_norm.rb +75 -0
  20. data/lib/torch/nn/batch_norm1d.rb +11 -0
  21. data/lib/torch/nn/batch_norm2d.rb +11 -0
  22. data/lib/torch/nn/batch_norm3d.rb +11 -0
  23. data/lib/torch/nn/bce_loss.rb +13 -0
  24. data/lib/torch/nn/bce_with_logits_loss.rb +15 -0
  25. data/lib/torch/nn/bilinear.rb +38 -0
  26. data/lib/torch/nn/constant_pad1d.rb +10 -0
  27. data/lib/torch/nn/constant_pad2d.rb +10 -0
  28. data/lib/torch/nn/constant_pad3d.rb +10 -0
  29. data/lib/torch/nn/constant_padnd.rb +18 -0
  30. data/lib/torch/nn/conv1d.rb +22 -0
  31. data/lib/torch/nn/conv2d.rb +10 -20
  32. data/lib/torch/nn/conv3d.rb +22 -0
  33. data/lib/torch/nn/convnd.rb +3 -3
  34. data/lib/torch/nn/cosine_embedding_loss.rb +14 -0
  35. data/lib/torch/nn/cosine_similarity.rb +15 -0
  36. data/lib/torch/nn/cross_entropy_loss.rb +14 -0
  37. data/lib/torch/nn/ctc_loss.rb +15 -0
  38. data/lib/torch/nn/dropoutnd.rb +2 -2
  39. data/lib/torch/nn/embedding_bag.rb +34 -0
  40. data/lib/torch/nn/fold.rb +20 -0
  41. data/lib/torch/nn/functional.rb +379 -32
  42. data/lib/torch/nn/group_norm.rb +36 -0
  43. data/lib/torch/nn/gru.rb +49 -0
  44. data/lib/torch/nn/hardshrink.rb +18 -0
  45. data/lib/torch/nn/hinge_embedding_loss.rb +14 -0
  46. data/lib/torch/nn/identity.rb +14 -0
  47. data/lib/torch/nn/init.rb +58 -1
  48. data/lib/torch/nn/instance_norm.rb +20 -0
  49. data/lib/torch/nn/instance_norm1d.rb +18 -0
  50. data/lib/torch/nn/instance_norm2d.rb +11 -0
  51. data/lib/torch/nn/instance_norm3d.rb +11 -0
  52. data/lib/torch/nn/kl_div_loss.rb +13 -0
  53. data/lib/torch/nn/l1_loss.rb +13 -0
  54. data/lib/torch/nn/layer_norm.rb +35 -0
  55. data/lib/torch/nn/leaky_relu.rb +20 -0
  56. data/lib/torch/nn/linear.rb +12 -11
  57. data/lib/torch/nn/local_response_norm.rb +21 -0
  58. data/lib/torch/nn/log_sigmoid.rb +9 -0
  59. data/lib/torch/nn/log_softmax.rb +14 -0
  60. data/lib/torch/nn/loss.rb +10 -0
  61. data/lib/torch/nn/lp_pool1d.rb +9 -0
  62. data/lib/torch/nn/lp_pool2d.rb +9 -0
  63. data/lib/torch/nn/lp_poolnd.rb +22 -0
  64. data/lib/torch/nn/lstm.rb +66 -0
  65. data/lib/torch/nn/margin_ranking_loss.rb +14 -0
  66. data/lib/torch/nn/max_pool1d.rb +9 -0
  67. data/lib/torch/nn/max_pool2d.rb +9 -0
  68. data/lib/torch/nn/max_pool3d.rb +9 -0
  69. data/lib/torch/nn/max_poolnd.rb +19 -0
  70. data/lib/torch/nn/max_unpool1d.rb +16 -0
  71. data/lib/torch/nn/max_unpool2d.rb +16 -0
  72. data/lib/torch/nn/max_unpool3d.rb +16 -0
  73. data/lib/torch/nn/max_unpoolnd.rb +9 -0
  74. data/lib/torch/nn/module.rb +186 -35
  75. data/lib/torch/nn/mse_loss.rb +2 -2
  76. data/lib/torch/nn/multi_label_margin_loss.rb +13 -0
  77. data/lib/torch/nn/multi_label_soft_margin_loss.rb +13 -0
  78. data/lib/torch/nn/multi_margin_loss.rb +17 -0
  79. data/lib/torch/nn/nll_loss.rb +14 -0
  80. data/lib/torch/nn/pairwise_distance.rb +16 -0
  81. data/lib/torch/nn/parameter.rb +2 -2
  82. data/lib/torch/nn/poisson_nll_loss.rb +16 -0
  83. data/lib/torch/nn/prelu.rb +19 -0
  84. data/lib/torch/nn/reflection_pad1d.rb +10 -0
  85. data/lib/torch/nn/reflection_pad2d.rb +10 -0
  86. data/lib/torch/nn/reflection_padnd.rb +13 -0
  87. data/lib/torch/nn/relu.rb +8 -3
  88. data/lib/torch/nn/replication_pad1d.rb +10 -0
  89. data/lib/torch/nn/replication_pad2d.rb +10 -0
  90. data/lib/torch/nn/replication_pad3d.rb +10 -0
  91. data/lib/torch/nn/replication_padnd.rb +13 -0
  92. data/lib/torch/nn/rnn.rb +22 -0
  93. data/lib/torch/nn/rnn_base.rb +198 -0
  94. data/lib/torch/nn/sequential.rb +1 -10
  95. data/lib/torch/nn/sigmoid.rb +9 -0
  96. data/lib/torch/nn/smooth_l1_loss.rb +13 -0
  97. data/lib/torch/nn/soft_margin_loss.rb +13 -0
  98. data/lib/torch/nn/softmax.rb +18 -0
  99. data/lib/torch/nn/softmax2d.rb +10 -0
  100. data/lib/torch/nn/softmin.rb +14 -0
  101. data/lib/torch/nn/softplus.rb +19 -0
  102. data/lib/torch/nn/softshrink.rb +18 -0
  103. data/lib/torch/nn/softsign.rb +9 -0
  104. data/lib/torch/nn/tanh.rb +9 -0
  105. data/lib/torch/nn/tanhshrink.rb +9 -0
  106. data/lib/torch/nn/triplet_margin_loss.rb +18 -0
  107. data/lib/torch/nn/unfold.rb +19 -0
  108. data/lib/torch/nn/utils.rb +25 -0
  109. data/lib/torch/nn/weighted_loss.rb +10 -0
  110. data/lib/torch/nn/zero_pad2d.rb +9 -0
  111. data/lib/torch/random.rb +10 -0
  112. data/lib/torch/tensor.rb +51 -44
  113. data/lib/torch/version.rb +1 -1
  114. metadata +98 -6
  115. data/lib/torch/ext.bundle +0 -0
@@ -10,6 +10,9 @@ $CXXFLAGS << " -std=c++11"
10
10
  # silence ruby/intern.h warning
11
11
  $CXXFLAGS << " -Wno-deprecated-register"
12
12
 
13
+ # silence torch warnings
14
+ $CXXFLAGS << " -Wno-shorten-64-to-32 -Wno-missing-noreturn"
15
+
13
16
  inc, lib = dir_config("torch")
14
17
 
15
18
  inc ||= "/usr/local/include"
@@ -22,4 +25,10 @@ $LDFLAGS << " -Wl,-rpath,#{lib}"
22
25
  # Add the torch library directory to the search path and link libtorch + libc10.
  $LDFLAGS << " -L#{lib}"
  $LDFLAGS << " -ltorch -lc10"

  # generate C++ functions
  # The Ruby-side generator runs at configure time. Presumably it has to finish
  # before create_makefile so the generated sources are picked up — confirm.
  puts "Generating C++ functions..."
  require_relative "../../lib/torch/native/generator"
  Torch::Native::Generator.generate_cpp_functions

  # create makefile
  create_makefile("torch/ext")
@@ -0,0 +1,55 @@
1
+ #include <torch/torch.h>
2
+ #include <rice/Object.hpp>
3
+ #include "templates.hpp"
4
+
5
+ Object wrap(std::tuple<torch::Tensor, torch::Tensor> x) {
6
+ Array a;
7
+ a.push(to_ruby<torch::Tensor>(std::get<0>(x)));
8
+ a.push(to_ruby<torch::Tensor>(std::get<1>(x)));
9
+ return Object(a);
10
+ }
11
+
12
+ Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> x) {
13
+ Array a;
14
+ a.push(to_ruby<torch::Tensor>(std::get<0>(x)));
15
+ a.push(to_ruby<torch::Tensor>(std::get<1>(x)));
16
+ a.push(to_ruby<torch::Tensor>(std::get<2>(x)));
17
+ return Object(a);
18
+ }
19
+
20
+ Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> x) {
21
+ Array a;
22
+ a.push(to_ruby<torch::Tensor>(std::get<0>(x)));
23
+ a.push(to_ruby<torch::Tensor>(std::get<1>(x)));
24
+ a.push(to_ruby<torch::Tensor>(std::get<2>(x)));
25
+ a.push(to_ruby<torch::Tensor>(std::get<3>(x)));
26
+ return Object(a);
27
+ }
28
+
29
+ Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> x) {
30
+ Array a;
31
+ a.push(to_ruby<torch::Tensor>(std::get<0>(x)));
32
+ a.push(to_ruby<torch::Tensor>(std::get<1>(x)));
33
+ a.push(to_ruby<torch::Tensor>(std::get<2>(x)));
34
+ a.push(to_ruby<torch::Tensor>(std::get<3>(x)));
35
+ a.push(to_ruby<torch::Tensor>(std::get<4>(x)));
36
+ return Object(a);
37
+ }
38
+
39
+ Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, int64_t> x) {
40
+ Array a;
41
+ a.push(to_ruby<torch::Tensor>(std::get<0>(x)));
42
+ a.push(to_ruby<torch::Tensor>(std::get<1>(x)));
43
+ a.push(to_ruby<torch::Tensor>(std::get<2>(x)));
44
+ a.push(to_ruby<int64_t>(std::get<3>(x)));
45
+ return Object(a);
46
+ }
47
+
48
+ Object wrap(std::tuple<torch::Tensor, torch::Tensor, double, int64_t> x) {
49
+ Array a;
50
+ a.push(to_ruby<torch::Tensor>(std::get<0>(x)));
51
+ a.push(to_ruby<torch::Tensor>(std::get<1>(x)));
52
+ a.push(to_ruby<double>(std::get<2>(x)));
53
+ a.push(to_ruby<int64_t>(std::get<3>(x)));
54
+ return Object(a);
55
+ }
@@ -0,0 +1,244 @@
1
+ #pragma once
2
+
3
+ #ifdef isfinite
4
+ #undef isfinite
5
+ #endif
6
+
7
+ #include <rice/Array.hpp>
8
+ #include <rice/Object.hpp>
9
+
10
+ using namespace Rice;
11
+
12
+ // need to wrap torch::IntArrayRef() since
13
+ // it doesn't own underlying data
14
+ class IntArrayRef {
15
+ std::vector<int64_t> vec;
16
+ public:
17
+ IntArrayRef(Object o) {
18
+ Array a = Array(o);
19
+ for (size_t i = 0; i < a.size(); i++) {
20
+ vec.push_back(from_ruby<int64_t>(a[i]));
21
+ }
22
+ }
23
+ operator torch::IntArrayRef() {
24
+ return torch::IntArrayRef(vec);
25
+ }
26
+ };
27
+
28
+ template<>
29
+ inline
30
+ IntArrayRef from_ruby<IntArrayRef>(Object x)
31
+ {
32
+ return IntArrayRef(x);
33
+ }
34
+
35
+ // for now
36
+ class Scalar {
37
+ torch::Scalar value;
38
+ public:
39
+ Scalar(Object o) {
40
+ // TODO cast based on Ruby type
41
+ if (o.rb_type() == T_FIXNUM) {
42
+ value = torch::Scalar(from_ruby<int64_t>(o));
43
+ } else {
44
+ value = torch::Scalar(from_ruby<float>(o));
45
+ }
46
+ }
47
+ operator torch::Scalar() {
48
+ return value;
49
+ }
50
+ };
51
+
52
+ template<>
53
+ inline
54
+ Scalar from_ruby<Scalar>(Object x)
55
+ {
56
+ return Scalar(x);
57
+ }
58
+
59
+ class TensorList {
60
+ std::vector<torch::Tensor> vec;
61
+ public:
62
+ TensorList(Object o) {
63
+ Array a = Array(o);
64
+ for (size_t i = 0; i < a.size(); i++) {
65
+ vec.push_back(from_ruby<torch::Tensor>(a[i]));
66
+ }
67
+ }
68
+ operator torch::TensorList() {
69
+ return torch::TensorList(vec);
70
+ }
71
+ };
72
+
73
+ template<>
74
+ inline
75
+ TensorList from_ruby<TensorList>(Object x)
76
+ {
77
+ return TensorList(x);
78
+ }
79
+
80
+ class FanModeType {
81
+ std::string s;
82
+ public:
83
+ FanModeType(Object o) {
84
+ s = String(o).str();
85
+ }
86
+ operator torch::nn::init::FanModeType() {
87
+ if (s == "fan_in") {
88
+ return torch::kFanIn;
89
+ } else if (s == "fan_out") {
90
+ return torch::kFanOut;
91
+ } else {
92
+ throw std::runtime_error("Unsupported nonlinearity type: " + s);
93
+ }
94
+ }
95
+ };
96
+
97
+ template<>
98
+ inline
99
+ FanModeType from_ruby<FanModeType>(Object x)
100
+ {
101
+ return FanModeType(x);
102
+ }
103
+
104
// Converts a Ruby string naming a nonlinearity (e.g. "relu", "leaky_relu",
// "conv2d") into torch::nn::init::NonlinearityType.
// The string is captured when the wrapper is constructed. The name lookup,
// and any error, happens only when the conversion operator runs.
class NonlinearityType {
  std::string s;
  public:
  NonlinearityType(Object o) {
    s = String(o).str();
  }
  // Maps each accepted name to its torch::k* constant, which is then
  // converted to the operator's return type. Unknown names throw
  // std::runtime_error.
  operator torch::nn::init::NonlinearityType() {
    if (s == "linear") {
      return torch::kLinear;
    } else if (s == "conv1d") {
      return torch::kConv1D;
    } else if (s == "conv2d") {
      return torch::kConv2D;
    } else if (s == "conv3d") {
      return torch::kConv3D;
    } else if (s == "conv_transpose1d") {
      return torch::kConvTranspose1D;
    } else if (s == "conv_transpose2d") {
      return torch::kConvTranspose2D;
    } else if (s == "conv_transpose3d") {
      return torch::kConvTranspose3D;
    } else if (s == "sigmoid") {
      return torch::kSigmoid;
    } else if (s == "tanh") {
      return torch::kTanh;
    } else if (s == "relu") {
      return torch::kReLU;
    } else if (s == "leaky_relu") {
      return torch::kLeakyReLU;
    } else {
      throw std::runtime_error("Unsupported nonlinearity type: " + s);
    }
  }
};

// Rice hook: build a NonlinearityType wrapper from a Ruby argument.
template<>
inline
NonlinearityType from_ruby<NonlinearityType>(Object x)
{
  return NonlinearityType(x);
}
145
+
146
+ class MyReduction {
147
+ Object value;
148
+ public:
149
+ MyReduction(Object o) {
150
+ value = o;
151
+ }
152
+ operator int64_t() {
153
+ if (value.is_nil()) {
154
+ return torch::Reduction::None;
155
+ }
156
+
157
+ std::string s = String(value).str();
158
+ if (s == "mean") {
159
+ return torch::Reduction::Mean;
160
+ } else if (s == "sum") {
161
+ return torch::Reduction::Sum;
162
+ } else {
163
+ throw std::runtime_error("Unsupported reduction: " + s);
164
+ }
165
+ }
166
+ };
167
+
168
+ template<>
169
+ inline
170
+ MyReduction from_ruby<MyReduction>(Object x)
171
+ {
172
+ return MyReduction(x);
173
+ }
174
+
175
+ typedef torch::Tensor Tensor;
176
+
177
+ class OptionalTensor {
178
+ Object value;
179
+ public:
180
+ OptionalTensor(Object o) {
181
+ value = o;
182
+ }
183
+ operator torch::Tensor() {
184
+ if (value.is_nil()) {
185
+ return {};
186
+ }
187
+ return from_ruby<torch::Tensor>(value);
188
+ }
189
+ };
190
+
191
+ template<>
192
+ inline
193
+ OptionalTensor from_ruby<OptionalTensor>(Object x)
194
+ {
195
+ return OptionalTensor(x);
196
+ }
197
+
198
+ class ScalarType {
199
+ Object value;
200
+ public:
201
+ ScalarType(Object o) {
202
+ value = o;
203
+ }
204
+ operator at::ScalarType() {
205
+ throw std::runtime_error("ScalarType arguments not implemented yet");
206
+ }
207
+ };
208
+
209
+ template<>
210
+ inline
211
+ ScalarType from_ruby<ScalarType>(Object x)
212
+ {
213
+ return ScalarType(x);
214
+ }
215
+
216
+ class OptionalScalarType {
217
+ Object value;
218
+ public:
219
+ OptionalScalarType(Object o) {
220
+ value = o;
221
+ }
222
+ operator c10::optional<at::ScalarType>() {
223
+ if (value.is_nil()) {
224
+ return c10::nullopt;
225
+ }
226
+ return ScalarType(value);
227
+ }
228
+ };
229
+
230
+ template<>
231
+ inline
232
+ OptionalScalarType from_ruby<OptionalScalarType>(Object x)
233
+ {
234
+ return OptionalScalarType(x);
235
+ }
236
+
237
typedef torch::Device Device;

// Tuple-to-Ruby converters, defined in templates.cpp. Each overload returns a
// Ruby Array holding the tuple elements converted with to_ruby, in tuple order.
// These declarations must stay in sync with those definitions.
Object wrap(std::tuple<torch::Tensor, torch::Tensor> x);
Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> x);
Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> x);
Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> x);
Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, int64_t> x);
Object wrap(std::tuple<torch::Tensor, torch::Tensor, double, int64_t> x);
@@ -1,6 +1,11 @@
1
1
  # ext
2
2
  require "torch/ext"
3
3
 
4
+ # native functions
5
+ require "torch/native/generator"
6
+ require "torch/native/parser"
7
+ require "torch/native/dispatcher"
8
+
4
9
  # modules
5
10
  require "torch/inspector"
6
11
  require "torch/tensor"
@@ -22,31 +27,145 @@ require "torch/optim/sgd"
22
27
  require "torch/optim/lr_scheduler/lr_scheduler"
23
28
  require "torch/optim/lr_scheduler/step_lr"
24
29
 
25
- # nn base classes
30
+ # nn parameters
31
+ require "torch/nn/parameter"
32
+ require "torch/nn/utils"
33
+
34
+ # nn containers
26
35
  require "torch/nn/module"
36
+ require "torch/nn/sequential"
37
+
38
+ # nn convolution layers
27
39
  require "torch/nn/convnd"
28
- require "torch/nn/dropoutnd"
40
+ require "torch/nn/conv1d"
41
+ require "torch/nn/conv2d"
42
+ require "torch/nn/conv3d"
43
+ require "torch/nn/unfold"
44
+ require "torch/nn/fold"
45
+
46
+ # nn pooling layers
47
+ require "torch/nn/max_poolnd"
48
+ require "torch/nn/max_pool1d"
49
+ require "torch/nn/max_pool2d"
50
+ require "torch/nn/max_pool3d"
51
+ require "torch/nn/max_unpoolnd"
52
+ require "torch/nn/max_unpool1d"
53
+ require "torch/nn/max_unpool2d"
54
+ require "torch/nn/max_unpool3d"
55
+ require "torch/nn/avg_poolnd"
56
+ require "torch/nn/avg_pool1d"
57
+ require "torch/nn/avg_pool2d"
58
+ require "torch/nn/avg_pool3d"
59
+ require "torch/nn/lp_poolnd"
60
+ require "torch/nn/lp_pool1d"
61
+ require "torch/nn/lp_pool2d"
62
+
63
+ # nn padding layers
64
+ require "torch/nn/reflection_padnd"
65
+ require "torch/nn/reflection_pad1d"
66
+ require "torch/nn/reflection_pad2d"
67
+ require "torch/nn/replication_padnd"
68
+ require "torch/nn/replication_pad1d"
69
+ require "torch/nn/replication_pad2d"
70
+ require "torch/nn/replication_pad3d"
71
+ require "torch/nn/constant_padnd"
72
+ require "torch/nn/constant_pad1d"
73
+ require "torch/nn/constant_pad2d"
74
+ require "torch/nn/constant_pad3d"
75
+ require "torch/nn/zero_pad2d"
76
+
77
+ # nn normalization layers
78
+ require "torch/nn/batch_norm"
79
+ require "torch/nn/batch_norm1d"
80
+ require "torch/nn/batch_norm2d"
81
+ require "torch/nn/batch_norm3d"
82
+ require "torch/nn/group_norm"
83
+ require "torch/nn/instance_norm"
84
+ require "torch/nn/instance_norm1d"
85
+ require "torch/nn/instance_norm2d"
86
+ require "torch/nn/instance_norm3d"
87
+ require "torch/nn/layer_norm"
88
+ require "torch/nn/local_response_norm"
89
+
90
+ # nn recurrent layers
91
+ require "torch/nn/rnn_base"
92
+ require "torch/nn/rnn"
93
+ require "torch/nn/lstm"
94
+ require "torch/nn/gru"
95
+
96
+ # nn linear layers
97
+ require "torch/nn/bilinear"
98
+ require "torch/nn/identity"
99
+ require "torch/nn/linear"
29
100
 
30
- # nn
101
+ # nn dropout layers
102
+ require "torch/nn/dropoutnd"
31
103
  require "torch/nn/alpha_dropout"
32
- require "torch/nn/conv2d"
33
104
  require "torch/nn/dropout"
34
105
  require "torch/nn/dropout2d"
35
106
  require "torch/nn/dropout3d"
36
- require "torch/nn/embedding"
37
107
  require "torch/nn/feature_alpha_dropout"
108
+
109
+ # nn activations
110
+ require "torch/nn/hardshrink"
111
+ require "torch/nn/leaky_relu"
112
+ require "torch/nn/log_sigmoid"
113
+ require "torch/nn/prelu"
114
+ require "torch/nn/relu"
115
+ require "torch/nn/sigmoid"
116
+ require "torch/nn/softplus"
117
+ require "torch/nn/softshrink"
118
+ require "torch/nn/softsign"
119
+ require "torch/nn/tanh"
120
+ require "torch/nn/tanhshrink"
121
+
122
+ # nn activations other
123
+ require "torch/nn/log_softmax"
124
+ require "torch/nn/softmax"
125
+ require "torch/nn/softmax2d"
126
+ require "torch/nn/softmin"
127
+
128
+ # nn sparse layers
129
+ require "torch/nn/embedding"
130
+ require "torch/nn/embedding_bag"
131
+
132
+ # nn distance functions
133
+ require "torch/nn/cosine_similarity"
134
+ require "torch/nn/pairwise_distance"
135
+
136
+ # nn loss functions
137
+ require "torch/nn/loss"
138
+ require "torch/nn/weighted_loss"
139
+ require "torch/nn/bce_loss"
140
+ require "torch/nn/bce_with_logits_loss"
141
+ require "torch/nn/cosine_embedding_loss"
142
+ require "torch/nn/cross_entropy_loss"
143
+ require "torch/nn/ctc_loss"
144
+ require "torch/nn/hinge_embedding_loss"
145
+ require "torch/nn/kl_div_loss"
146
+ require "torch/nn/l1_loss"
147
+ require "torch/nn/margin_ranking_loss"
148
+ require "torch/nn/mse_loss"
149
+ require "torch/nn/multi_label_margin_loss"
150
+ require "torch/nn/multi_label_soft_margin_loss"
151
+ require "torch/nn/multi_margin_loss"
152
+ require "torch/nn/nll_loss"
153
+ require "torch/nn/poisson_nll_loss"
154
+ require "torch/nn/smooth_l1_loss"
155
+ require "torch/nn/soft_margin_loss"
156
+ require "torch/nn/triplet_margin_loss"
157
+
158
+ # nn other
38
159
  require "torch/nn/functional"
39
160
  require "torch/nn/init"
40
- require "torch/nn/linear"
41
- require "torch/nn/mse_loss"
42
- require "torch/nn/parameter"
43
- require "torch/nn/relu"
44
- require "torch/nn/sequential"
45
161
 
46
162
  # utils
47
163
  require "torch/utils/data/data_loader"
48
164
  require "torch/utils/data/tensor_dataset"
49
165
 
166
+ # random
167
+ require "torch/random"
168
+
50
169
  module Torch
51
170
  class Error < StandardError; end
52
171
  class NotImplementedYet < StandardError
@@ -57,7 +176,6 @@ module Torch
57
176
 
58
177
  # keys: https://pytorch.org/docs/stable/tensor_attributes.html#torch.torch.dtype
59
178
  # values: https://github.com/pytorch/pytorch/blob/master/c10/core/ScalarType.h
60
- # complex and quantized types not supported by PyTorch yet
61
179
  DTYPE_TO_ENUM = {
62
180
  uint8: 0,
63
181
  int8: 1,
@@ -73,17 +191,52 @@ module Torch
73
191
  float32: 6,
74
192
  double: 7,
75
193
  float64: 7,
76
- # complex_half: 8,
77
- # complex_float: 9,
78
- # complex_double: 10,
194
+ complex_half: 8,
195
+ complex_float: 9,
196
+ complex_double: 10,
79
197
  bool: 11,
80
- # qint8: 12,
81
- # quint8: 13,
82
- # qint32: 14,
83
- # bfloat16: 15
198
+ qint8: 12,
199
+ quint8: 13,
200
+ qint32: 14,
201
+ bfloat16: 15
84
202
  }
85
203
  ENUM_TO_DTYPE = DTYPE_TO_ENUM.map(&:reverse).to_h
86
204
 
205
# Builds a PyTorch-style legacy tensor "class" for the given dtype, e.g.
# Torch::FloatTensor. Its `new` does not create instances of the class; it
# returns a Torch tensor on "cpu" (or "cuda" when cuda is true):
#   Klass.new(tensor)  -> tensor cast to dtype and moved to the device
#   Klass.new([1, 2])  -> Torch.tensor(array, dtype:, device:)
#   Klass.new(2, 3)    -> Torch.empty(2, 3, dtype:, device:)
def self._make_tensor_class(dtype, cuda = false)
  device = cuda ? "cuda" : "cpu"

  Class.new do
    define_singleton_method("new") do |*args|
      source = args.first
      single = args.size == 1

      if single && source.is_a?(Tensor)
        source.send(dtype).to(device)
      elsif single && source.is_a?(Array)
        Torch.tensor(source, dtype: dtype, device: device)
      else
        Torch.empty(*args, dtype: dtype, device: device)
      end
    end
  end
end
219
+
220
# PyTorch-style legacy tensor classes (Torch::FloatTensor, Torch::CUDA::LongTensor, ...),
# generated from one name => dtype table. All CPU variants are defined first,
# then the CUDA variants, in table order.
legacy_tensor_dtypes = {
  FloatTensor: :float32,
  DoubleTensor: :float64,
  HalfTensor: :float16,
  ByteTensor: :uint8,
  CharTensor: :int8,
  ShortTensor: :int16,
  IntTensor: :int32,
  LongTensor: :int64,
  BoolTensor: :bool
}

legacy_tensor_dtypes.each do |name, dtype|
  const_set(name, _make_tensor_class(dtype))
end

legacy_tensor_dtypes.each do |name, dtype|
  CUDA.const_set(name, _make_tensor_class(dtype, true))
end
239
+
87
240
  class << self
88
241
  # Torch.float, Torch.long, etc
89
242
  DTYPE_TO_ENUM.each_key do |dtype|
@@ -120,6 +273,8 @@ module Torch
120
273
  # use method for cases when Numo not available
121
274
  # or available after Torch loaded
122
275
  def _dtype_to_numo
276
+ raise Error, "Numo not found" unless defined?(Numo::NArray)
277
+
123
278
  {
124
279
  uint8: Numo::UInt8,
125
280
  int8: Numo::Int8,
@@ -131,6 +286,29 @@ module Torch
131
286
  }
132
287
  end
133
288
 
289
# Runs the given block with gradient tracking turned off and returns the
# block's value. The previous grad-enabled state is restored afterwards, even
# when the block raises.
def no_grad
  restore_to = grad_enabled?
  _set_grad_enabled(false)
  yield
ensure
  # restore_to is nil only if grad_enabled? itself raised; in that case
  # nothing was changed, so there is nothing to restore.
  _set_grad_enabled(restore_to) unless restore_to.nil?
end
298
+
299
# Builds a Device from a device string such as "cpu" or "cuda".
def device(str)
  Device.new(str)
end

# Serializes a tensor with the native _save routine and writes the raw bytes
# to path f. Only Tensor objects are supported; anything else raises
# NotImplementedYet before touching the file.
def save(obj, f)
  raise NotImplementedYet unless obj.is_a?(Tensor)

  payload = _save(obj)
  File.binwrite(f, payload)
end

# Loading is not implemented yet; this always raises NotImplementedYet.
def load(f)
  raise NotImplementedYet
end
311
+
134
312
  # --- begin tensor creation: https://pytorch.org/cppdocs/notes/tensor_creation.html ---
135
313
 
136
314
  def arange(start, finish = nil, step = 1, **options)
@@ -200,8 +378,12 @@ module Torch
200
378
  data = [data].compact
201
379
  end
202
380
 
203
- if options[:dtype].nil? && data.all? { |v| v.is_a?(Integer) }
204
- options[:dtype] = :int64
381
+ if options[:dtype].nil?
382
+ if data.all? { |v| v.is_a?(Integer) }
383
+ options[:dtype] = :int64
384
+ elsif data.all? { |v| v == true || v == false }
385
+ options[:dtype] = :bool
386
+ end
205
387
  end
206
388
 
207
389
  _tensor(data, size, tensor_options(**options))
@@ -210,19 +392,19 @@ module Torch
210
392
  # --- begin like ---
211
393
 
212
394
  def ones_like(input, **options)
213
- ones(input.size, like_options(input, options))
395
+ ones(input.size, **like_options(input, options))
214
396
  end
215
397
 
216
398
  def empty_like(input, **options)
217
- empty(input.size, like_options(input, options))
399
+ empty(input.size, **like_options(input, options))
218
400
  end
219
401
 
220
402
  def full_like(input, fill_value, **options)
221
- full(input.size, fill_value, like_options(input, options))
403
+ full(input.size, fill_value, **like_options(input, options))
222
404
  end
223
405
 
224
406
  def rand_like(input, **options)
225
- rand(input.size, like_options(input, options))
407
+ rand(input.size, **like_options(input, options))
226
408
  end
227
409
 
228
410
  def randint_like(input, low, high = nil, **options)
@@ -231,163 +413,19 @@ module Torch
231
413
  high = low
232
414
  low = 0
233
415
  end
234
- randint(low, high, input.size, like_options(input, options))
416
+ randint(low, high, input.size, **like_options(input, options))
235
417
  end
236
418
 
237
419
  def randn_like(input, **options)
238
- randn(input.size, like_options(input, options))
420
+ randn(input.size, **like_options(input, options))
239
421
  end
240
422
 
241
423
  def zeros_like(input, **options)
242
- zeros(input.size, like_options(input, options))
243
- end
244
-
245
- # --- begin operations ---
246
-
247
- %w(add sub mul div remainder).each do |op|
248
- define_method(op) do |input, other, **options|
249
- execute_op(op, input, other, **options)
250
- end
251
- end
252
-
253
- def neg(input)
254
- _neg(input)
255
- end
256
-
257
- def no_grad
258
- previous_value = grad_enabled?
259
- begin
260
- _set_grad_enabled(false)
261
- yield
262
- ensure
263
- _set_grad_enabled(previous_value)
264
- end
265
- end
266
-
267
- # TODO support out
268
- def mean(input, dim = nil, keepdim: false)
269
- if dim
270
- _mean_dim(input, dim, keepdim)
271
- else
272
- _mean(input)
273
- end
274
- end
275
-
276
- # TODO support dtype
277
- def sum(input, dim = nil, keepdim: false)
278
- if dim
279
- _sum_dim(input, dim, keepdim)
280
- else
281
- _sum(input)
282
- end
283
- end
284
-
285
- def argmax(input, dim = nil, keepdim: false)
286
- if dim
287
- _argmax_dim(input, dim, keepdim)
288
- else
289
- _argmax(input)
290
- end
291
- end
292
-
293
- def eq(input, other)
294
- _eq(input, other)
295
- end
296
-
297
- def norm(input)
298
- _norm(input)
299
- end
300
-
301
- def pow(input, exponent)
302
- _pow(input, exponent)
303
- end
304
-
305
- def min(input)
306
- _min(input)
307
- end
308
-
309
- def max(input, dim = nil, keepdim: false, out: nil)
310
- if dim
311
- raise NotImplementedYet unless out
312
- _max_out(out[0], out[1], input, dim, keepdim)
313
- else
314
- _max(input)
315
- end
316
- end
317
-
318
- def exp(input)
319
- _exp(input)
320
- end
321
-
322
- def log(input)
323
- _log(input)
324
- end
325
-
326
- def sign(input)
327
- _sign(input)
328
- end
329
-
330
- def gt(input, other)
331
- _gt(input, other)
332
- end
333
-
334
- def lt(input, other)
335
- _lt(input, other)
336
- end
337
-
338
- def unsqueeze(input, dim)
339
- _unsqueeze(input, dim)
340
- end
341
-
342
- def dot(input, tensor)
343
- _dot(input, tensor)
344
- end
345
-
346
- def cat(tensors, dim = 0)
347
- _cat(tensors, dim)
348
- end
349
-
350
- def matmul(input, other)
351
- _matmul(input, other)
352
- end
353
-
354
- def reshape(input, shape)
355
- _reshape(input, shape)
356
- end
357
-
358
- def flatten(input, start_dim: 0, end_dim: -1)
359
- _flatten(input, start_dim, end_dim)
360
- end
361
-
362
- def sqrt(input)
363
- _sqrt(input)
364
- end
365
-
366
- def abs(input)
367
- _abs(input)
368
- end
369
-
370
- def device(str)
371
- Device.new(str)
424
+ zeros(input.size, **like_options(input, options))
372
425
  end
373
426
 
374
427
  private
375
428
 
376
- def execute_op(op, input, other, out: nil)
377
- scalar = other.is_a?(Numeric)
378
- if out
379
- # TODO make work with scalars
380
- raise Error, "out not supported with scalar yet" if scalar
381
- send("_#{op}_out", out, input, other)
382
- else
383
- if scalar
384
- send("_#{op}_scalar", input, other)
385
- else
386
- send("_#{op}", input, other)
387
- end
388
- end
389
- end
390
-
391
429
  def tensor_size(size)
392
430
  size.flatten
393
431
  end