torch-rb 0.1.3
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +7 -0
- data/CHANGELOG.md +28 -0
- data/LICENSE.txt +46 -0
- data/README.md +426 -0
- data/ext/torch/ext.cpp +839 -0
- data/ext/torch/extconf.rb +25 -0
- data/lib/torch-rb.rb +1 -0
- data/lib/torch.rb +422 -0
- data/lib/torch/ext.bundle +0 -0
- data/lib/torch/inspector.rb +85 -0
- data/lib/torch/nn/alpha_dropout.rb +9 -0
- data/lib/torch/nn/conv2d.rb +37 -0
- data/lib/torch/nn/convnd.rb +41 -0
- data/lib/torch/nn/dropout.rb +9 -0
- data/lib/torch/nn/dropout2d.rb +9 -0
- data/lib/torch/nn/dropout3d.rb +9 -0
- data/lib/torch/nn/dropoutnd.rb +15 -0
- data/lib/torch/nn/embedding.rb +52 -0
- data/lib/torch/nn/feature_alpha_dropout.rb +9 -0
- data/lib/torch/nn/functional.rb +100 -0
- data/lib/torch/nn/init.rb +30 -0
- data/lib/torch/nn/linear.rb +36 -0
- data/lib/torch/nn/module.rb +85 -0
- data/lib/torch/nn/mse_loss.rb +13 -0
- data/lib/torch/nn/parameter.rb +14 -0
- data/lib/torch/nn/relu.rb +13 -0
- data/lib/torch/nn/sequential.rb +29 -0
- data/lib/torch/optim/adadelta.rb +57 -0
- data/lib/torch/optim/adagrad.rb +71 -0
- data/lib/torch/optim/adam.rb +81 -0
- data/lib/torch/optim/adamax.rb +68 -0
- data/lib/torch/optim/adamw.rb +82 -0
- data/lib/torch/optim/asgd.rb +65 -0
- data/lib/torch/optim/lr_scheduler/lr_scheduler.rb +33 -0
- data/lib/torch/optim/lr_scheduler/step_lr.rb +17 -0
- data/lib/torch/optim/optimizer.rb +62 -0
- data/lib/torch/optim/rmsprop.rb +76 -0
- data/lib/torch/optim/rprop.rb +68 -0
- data/lib/torch/optim/sgd.rb +60 -0
- data/lib/torch/tensor.rb +196 -0
- data/lib/torch/utils/data/data_loader.rb +27 -0
- data/lib/torch/utils/data/tensor_dataset.rb +22 -0
- data/lib/torch/version.rb +3 -0
- metadata +169 -0
@@ -0,0 +1,27 @@
|
|
1
|
+
module Torch
  module Utils
    module Data
      # Iterates over a dataset in consecutive mini-batches.
      #
      # The dataset must respond to +size+ (number of samples) and to +[]+
      # with an exclusive Range, returning the samples in that range.
      # When the dataset size is not a multiple of +batch_size+, the final
      # batch is shorter.
      class DataLoader
        include Enumerable

        attr_reader :dataset

        # @param dataset [#size, #[]] the samples to batch
        # @param batch_size [Integer] number of samples per batch; must be positive
        # @raise [ArgumentError] if batch_size is not a positive Integer
        def initialize(dataset, batch_size: 1)
          # Fail fast: a zero batch size would otherwise surface later as a
          # FloatDomainError in #size, and a negative one would silently yield nothing.
          unless batch_size.is_a?(Integer) && batch_size.positive?
            raise ArgumentError, "batch_size must be a positive integer, got #{batch_size.inspect}"
          end

          @dataset = dataset
          @batch_size = batch_size
        end

        # Yields each batch in order, as returned by dataset[start...stop].
        # Without a block, returns an Enumerator (lazily sized via #size) so
        # that external iteration and Enumerable chaining work.
        #
        # @return [Enumerator] when no block is given
        def each
          return to_enum(:each) { size } unless block_given?

          size.times do |i|
            start_index = i * @batch_size
            yield @dataset[start_index...(start_index + @batch_size)]
          end
        end

        # Number of batches, i.e. ceil(dataset.size / batch_size). Integer
        # arithmetic avoids float rounding errors on very large datasets.
        #
        # @return [Integer]
        def size
          (@dataset.size + @batch_size - 1) / @batch_size
        end
      end
    end
  end
end
|
@@ -0,0 +1,22 @@
|
|
1
|
+
module Torch
  module Utils
    module Data
      # Wraps one or more tensors that share the same first dimension.
      # Indexing returns the corresponding slice of every wrapped tensor.
      class TensorDataset
        # @param tensors [Array] tensors (objects responding to +size(dim)+ and +[]+)
        # @raise [Error] if no tensors are given or their dim-0 sizes differ
        def initialize(*tensors)
          # Without this guard an empty list passes the size check vacuously
          # and #size later fails with NoMethodError on nil.
          raise Error, "At least one tensor required" if tensors.empty?

          expected = tensors.first.size(0)
          unless tensors.all? { |t| t.size(0) == expected }
            raise Error, "Tensors must all have same dim 0 size"
          end

          @tensors = tensors
        end

        # @param index [Integer, Range] row index or range along dim 0
        # @return [Array] one entry per wrapped tensor, each indexed by +index+
        def [](index)
          @tensors.map { |t| t[index] }
        end

        # @return [Integer] number of samples (size of dim 0)
        def size
          @tensors.first.size(0)
        end
      end
    end
  end
end
|
metadata
ADDED
@@ -0,0 +1,169 @@
|
|
1
|
+
--- !ruby/object:Gem::Specification
|
2
|
+
name: torch-rb
|
3
|
+
version: !ruby/object:Gem::Version
|
4
|
+
version: 0.1.3
|
5
|
+
platform: ruby
|
6
|
+
authors:
|
7
|
+
- Andrew Kane
|
8
|
+
autorequire:
|
9
|
+
bindir: bin
|
10
|
+
cert_chain: []
|
11
|
+
date: 2019-11-30 00:00:00.000000000 Z
|
12
|
+
dependencies:
|
13
|
+
- !ruby/object:Gem::Dependency
|
14
|
+
name: rice
|
15
|
+
requirement: !ruby/object:Gem::Requirement
|
16
|
+
requirements:
|
17
|
+
- - ">="
|
18
|
+
- !ruby/object:Gem::Version
|
19
|
+
version: '0'
|
20
|
+
type: :runtime
|
21
|
+
prerelease: false
|
22
|
+
version_requirements: !ruby/object:Gem::Requirement
|
23
|
+
requirements:
|
24
|
+
- - ">="
|
25
|
+
- !ruby/object:Gem::Version
|
26
|
+
version: '0'
|
27
|
+
- !ruby/object:Gem::Dependency
|
28
|
+
name: bundler
|
29
|
+
requirement: !ruby/object:Gem::Requirement
|
30
|
+
requirements:
|
31
|
+
- - ">="
|
32
|
+
- !ruby/object:Gem::Version
|
33
|
+
version: '0'
|
34
|
+
type: :development
|
35
|
+
prerelease: false
|
36
|
+
version_requirements: !ruby/object:Gem::Requirement
|
37
|
+
requirements:
|
38
|
+
- - ">="
|
39
|
+
- !ruby/object:Gem::Version
|
40
|
+
version: '0'
|
41
|
+
- !ruby/object:Gem::Dependency
|
42
|
+
name: rake
|
43
|
+
requirement: !ruby/object:Gem::Requirement
|
44
|
+
requirements:
|
45
|
+
- - ">="
|
46
|
+
- !ruby/object:Gem::Version
|
47
|
+
version: '0'
|
48
|
+
type: :development
|
49
|
+
prerelease: false
|
50
|
+
version_requirements: !ruby/object:Gem::Requirement
|
51
|
+
requirements:
|
52
|
+
- - ">="
|
53
|
+
- !ruby/object:Gem::Version
|
54
|
+
version: '0'
|
55
|
+
- !ruby/object:Gem::Dependency
|
56
|
+
name: rake-compiler
|
57
|
+
requirement: !ruby/object:Gem::Requirement
|
58
|
+
requirements:
|
59
|
+
- - ">="
|
60
|
+
- !ruby/object:Gem::Version
|
61
|
+
version: '0'
|
62
|
+
type: :development
|
63
|
+
prerelease: false
|
64
|
+
version_requirements: !ruby/object:Gem::Requirement
|
65
|
+
requirements:
|
66
|
+
- - ">="
|
67
|
+
- !ruby/object:Gem::Version
|
68
|
+
version: '0'
|
69
|
+
- !ruby/object:Gem::Dependency
|
70
|
+
name: minitest
|
71
|
+
requirement: !ruby/object:Gem::Requirement
|
72
|
+
requirements:
|
73
|
+
- - ">="
|
74
|
+
- !ruby/object:Gem::Version
|
75
|
+
version: '5'
|
76
|
+
type: :development
|
77
|
+
prerelease: false
|
78
|
+
version_requirements: !ruby/object:Gem::Requirement
|
79
|
+
requirements:
|
80
|
+
- - ">="
|
81
|
+
- !ruby/object:Gem::Version
|
82
|
+
version: '5'
|
83
|
+
- !ruby/object:Gem::Dependency
|
84
|
+
name: numo-narray
|
85
|
+
requirement: !ruby/object:Gem::Requirement
|
86
|
+
requirements:
|
87
|
+
- - ">="
|
88
|
+
- !ruby/object:Gem::Version
|
89
|
+
version: '0'
|
90
|
+
type: :development
|
91
|
+
prerelease: false
|
92
|
+
version_requirements: !ruby/object:Gem::Requirement
|
93
|
+
requirements:
|
94
|
+
- - ">="
|
95
|
+
- !ruby/object:Gem::Version
|
96
|
+
version: '0'
|
97
|
+
description:
|
98
|
+
email: andrew@chartkick.com
|
99
|
+
executables: []
|
100
|
+
extensions:
|
101
|
+
- ext/torch/extconf.rb
|
102
|
+
extra_rdoc_files: []
|
103
|
+
files:
|
104
|
+
- CHANGELOG.md
|
105
|
+
- LICENSE.txt
|
106
|
+
- README.md
|
107
|
+
- ext/torch/ext.cpp
|
108
|
+
- ext/torch/extconf.rb
|
109
|
+
- lib/torch-rb.rb
|
110
|
+
- lib/torch.rb
|
111
|
+
- lib/torch/ext.bundle
|
112
|
+
- lib/torch/inspector.rb
|
113
|
+
- lib/torch/nn/alpha_dropout.rb
|
114
|
+
- lib/torch/nn/conv2d.rb
|
115
|
+
- lib/torch/nn/convnd.rb
|
116
|
+
- lib/torch/nn/dropout.rb
|
117
|
+
- lib/torch/nn/dropout2d.rb
|
118
|
+
- lib/torch/nn/dropout3d.rb
|
119
|
+
- lib/torch/nn/dropoutnd.rb
|
120
|
+
- lib/torch/nn/embedding.rb
|
121
|
+
- lib/torch/nn/feature_alpha_dropout.rb
|
122
|
+
- lib/torch/nn/functional.rb
|
123
|
+
- lib/torch/nn/init.rb
|
124
|
+
- lib/torch/nn/linear.rb
|
125
|
+
- lib/torch/nn/module.rb
|
126
|
+
- lib/torch/nn/mse_loss.rb
|
127
|
+
- lib/torch/nn/parameter.rb
|
128
|
+
- lib/torch/nn/relu.rb
|
129
|
+
- lib/torch/nn/sequential.rb
|
130
|
+
- lib/torch/optim/adadelta.rb
|
131
|
+
- lib/torch/optim/adagrad.rb
|
132
|
+
- lib/torch/optim/adam.rb
|
133
|
+
- lib/torch/optim/adamax.rb
|
134
|
+
- lib/torch/optim/adamw.rb
|
135
|
+
- lib/torch/optim/asgd.rb
|
136
|
+
- lib/torch/optim/lr_scheduler/lr_scheduler.rb
|
137
|
+
- lib/torch/optim/lr_scheduler/step_lr.rb
|
138
|
+
- lib/torch/optim/optimizer.rb
|
139
|
+
- lib/torch/optim/rmsprop.rb
|
140
|
+
- lib/torch/optim/rprop.rb
|
141
|
+
- lib/torch/optim/sgd.rb
|
142
|
+
- lib/torch/tensor.rb
|
143
|
+
- lib/torch/utils/data/data_loader.rb
|
144
|
+
- lib/torch/utils/data/tensor_dataset.rb
|
145
|
+
- lib/torch/version.rb
|
146
|
+
homepage: https://github.com/ankane/torch-rb
|
147
|
+
licenses:
|
148
|
+
- BSD-3-Clause
|
149
|
+
metadata: {}
|
150
|
+
post_install_message:
|
151
|
+
rdoc_options: []
|
152
|
+
require_paths:
|
153
|
+
- lib
|
154
|
+
required_ruby_version: !ruby/object:Gem::Requirement
|
155
|
+
requirements:
|
156
|
+
- - ">="
|
157
|
+
- !ruby/object:Gem::Version
|
158
|
+
version: '2.4'
|
159
|
+
required_rubygems_version: !ruby/object:Gem::Requirement
|
160
|
+
requirements:
|
161
|
+
- - ">="
|
162
|
+
- !ruby/object:Gem::Version
|
163
|
+
version: '0'
|
164
|
+
requirements: []
|
165
|
+
rubygems_version: 3.0.3
|
166
|
+
signing_key:
|
167
|
+
specification_version: 4
|
168
|
+
summary: Deep learning for Ruby, powered by LibTorch
|
169
|
+
test_files: []
|