red-chainer 0.4.0 → 0.4.1
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/lib/chainer.rb +2 -0
- data/lib/chainer/functions/connection/embed_id.rb +49 -0
- data/lib/chainer/links/connection/embed_id.rb +23 -0
- data/lib/chainer/serializers/marshal.rb +2 -2
- data/lib/chainer/utils/array.rb +21 -17
- data/lib/chainer/version.rb +1 -1
- metadata +4 -2
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: d05c5d3cb55a9e6c7e45afffd5483b9968959749d911a2e6292f5bea1e90ef40
|
4
|
+
data.tar.gz: fcdbb11f8b64a3f54a68a1629c40e9ffce774b551e79d21689757f1a495a7d66
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: e0a1423530e8b62dacd6ae876957cd00e604a659815dd1bd2e3ce6a7c30d3f22f68f718451b261e559fb3719211067a79c78a0b1f229e96c96db4dc7d3a7004a
|
7
|
+
data.tar.gz: da1bc06481526f5424705a8d68f0c4a1e98decaf30c81e217c5390ab96c4d4850eeecdec7cf88287c4581cc5556bbd24adbf26c3007145b9bf9e007e40047ae1
|
data/lib/chainer.rb
CHANGED
@@ -22,6 +22,7 @@ require 'chainer/initializers/uniform'
|
|
22
22
|
require 'chainer/iterators/serial_iterator'
|
23
23
|
require 'chainer/link'
|
24
24
|
require 'chainer/links/connection/convolution_2d'
|
25
|
+
require 'chainer/links/connection/embed_id'
|
25
26
|
require 'chainer/links/connection/linear'
|
26
27
|
require 'chainer/links/normalization/batch_normalization'
|
27
28
|
require 'chainer/links/model/classifier'
|
@@ -56,6 +57,7 @@ require 'chainer/functions/loss/softmax_cross_entropy'
|
|
56
57
|
require 'chainer/functions/connection/convolution_2d'
|
57
58
|
require 'chainer/functions/connection/deconvolution_2d'
|
58
59
|
require 'chainer/functions/connection/convolution_2d_grad_w'
|
60
|
+
require 'chainer/functions/connection/embed_id'
|
59
61
|
require 'chainer/functions/connection/linear'
|
60
62
|
require 'chainer/functions/noise/dropout'
|
61
63
|
require 'chainer/functions/normalization/batch_normalization'
|
@@ -0,0 +1,49 @@
|
|
1
|
+
module Chainer
  module Functions
    module Connection
      # Embedding lookup: for every integer ID in x, selects the corresponding
      # row (index along axis 0) of the weight array w.
      class EmbedIDFunction < Chainer::Function
        # @param ignore_label [Integer, nil] ID whose output rows are filled with
        #   zeros and which contributes no gradient to w; nil disables this.
        def initialize(ignore_label: nil)
          @ignore_label = ignore_label
        end

        # Convenience entry point: builds the function and applies it to (x, w).
        #
        # @param x ID array used to index the rows of w
        # @param w weight array whose rows are the embeddings
        # @param ignore_label [Integer, nil] see #initialize
        def self.embed_id(x, w, ignore_label: nil)
          new(ignore_label: ignore_label).(x, w)
        end

        # @param inputs [Array] two-element array [x, w]
        # @return [Array] single-element array holding the looked-up rows
        def forward(inputs)
          x, w = inputs
          return [Chainer::Utils::Array.take(w, x, axis: 0)] unless @ignore_label

          valid_x = x.ne(@ignore_label)
          # Fast path: nothing to ignore.
          return [Chainer::Utils::Array.take(w, x, axis: 0)] if valid_x.count == x.size

          # Multiplying by the mask turns every ignored ID into 0, so the ignore
          # label itself is never used as a row index.
          x *= valid_x
          y = Chainer::Utils::Array.take(w, x, axis: 0).dup

          # One row per element of x; zero the rows whose ID was ignored in a
          # single vectorized assignment (where2.last = positions where the mask is false).
          y = y.reshape(x.size, true)
          y[valid_x.where2.last, true] = 0

          [y.reshape(*x.shape, true)]
        end

        # @param inputs [Array] two-element array [x, w]
        # @param grad_outputs [Array] single-element array holding dL/dy
        # @return [Array] [nil, gw]; x is integer-valued and receives no gradient
        def backward(inputs, grad_outputs)
          x, w = inputs
          gy = grad_outputs[0].reshape(x.size, true)
          gw = w.class.zeros(w.shape).reshape(w.shape.take(w.shape.size - 1).reduce(&:*), true)

          # Scatter-add one row at a time: the same ID may appear several times
          # in x, and every occurrence must accumulate into the same row of gw.
          x.reshape(x.size).each_with_index do |ix, i|
            next if ix == @ignore_label

            gw[ix, true] += gy[i, true]
          end

          [nil, gw.reshape(*w.shape)]
        end
      end
    end
  end
end
|
@@ -0,0 +1,23 @@
|
|
1
|
+
module Chainer
  module Links
    module Connection
      # Link wrapping EmbedIDFunction: owns a learnable weight parameter of
      # shape [in_size, out_size] and looks up its rows by integer ID.
      class EmbedID < ::Chainer::Link
        attr_reader :w

        # @param in_size [Integer] number of rows of the weight parameter
        # @param out_size [Integer] number of columns of the weight parameter
        # @param initial_w [Object, nil] initializer for the weight; when nil a
        #   Chainer::Initializers::Normal with scale 1.0 is used
        # @param ignore_label [Integer, nil] forwarded to EmbedIDFunction
        def initialize(in_size, out_size, initial_w: nil, ignore_label: nil)
          super()
          @ignore_label = ignore_label

          init_scope do
            weight_initializer = initial_w || Chainer::Initializers::Normal.new(scale: 1.0)
            weight_shape = [in_size, out_size]
            @w = Chainer::Parameter.new(initializer: weight_initializer, shape: weight_shape)
          end
        end

        # Applies the embedding lookup to the ID array +x+.
        def call(x)
          embed = Chainer::Functions::Connection::EmbedIDFunction
          embed.embed_id(x, @w, ignore_label: @ignore_label)
        end
      end
    end
  end
end
|
@@ -46,9 +46,9 @@ module Chainer
|
|
46
46
|
#
|
47
47
|
# @param [string] filename Name of the file to be loaded.
|
48
48
|
# @param [object] obj Object to be deserialized. It must support serialization protocol.
|
49
|
-
# @param [String] path forwarded to the deserializer's constructor (default '')
# @param [Boolean] strict forwarded to the deserializer's constructor (default true);
#   presumably controls whether missing entries raise — TODO confirm against #initialize
#
# NOTE(review): Marshal.load can instantiate arbitrary objects; only load files
# from trusted sources.
def self.load_file(filename, obj, path: '', strict: true)
  # Marshal output is binary: open in binary mode so no newline/encoding
  # conversion can corrupt the stream (notably on Windows).
  File.open(filename, 'rb') do |f|
    d = self.new(Marshal.load(f), path: path, strict: strict)
    d.load(obj)
  end
end
|
data/lib/chainer/utils/array.rb
CHANGED
@@ -18,28 +18,32 @@ module Chainer
|
|
18
18
|
end
|
19
19
|
end
|
20
20
|
|
21
|
-
# Enumerates every index tuple of an array with the given shape, in row-major
# order (last dimension varies fastest), like numpy.ndindex.
#
# @param shape [Array<Integer>] dimension sizes
# @return [Array<Array<Integer>>] one index tuple per element; [[]] for an
#   empty shape and [] when any dimension is 0
def self.ndindex(shape)
  # Stride of dimension j is the product of the sizes of all later dimensions;
  # computed once instead of once per element.
  strides = shape.each_index.map { |j| shape.drop(j + 1).reduce(1, :*) }
  shape.reduce(1, :*).times.map do |flat|
    shape.each_index.map { |j| (flat / strides[j]) % shape[j] }
  end
end

# Takes elements of +x+ along an axis, in the manner of numpy.take.
#
# @param x [Numo::NArray] source array
# @param indices [Integer, Array, Numo::NArray] positions to take
# @param axis [Integer, nil] axis to index; nil indexes +x+ directly with
#   +indices+. Negative values count back from the last axis.
# @return [Numo::NArray] for indices with two or more dimensions, an array of
#   shape x.shape[0...axis] + indices.shape + x.shape[(axis + 1)..]
def self.take(x, indices, axis: nil)
  return x[indices] unless axis

  ndim = x.shape.size
  axis += ndim if axis < 0
  dimensional_indices = ::Array.new(ndim, true)
  indices_narray = Numo::Int32.cast(indices)

  if indices_narray.shape.size <= 1
    # Scalar / 1-D indices are handled directly by Numo indexing along +axis+.
    dimensional_indices[axis] = indices
    return x[*dimensional_indices]
  end

  # Multi-dimensional indices: fill the result one index tuple at a time.
  # The axes before +axis+ must be kept in the result, otherwise the slices
  # taken from x do not fit into y when axis > 0.
  leading = ::Array.new(axis, true)
  trailing = ::Array.new(ndim - axis - 1, true)
  y = x.class.zeros(*x.shape.take(axis), *indices_narray.shape, *x.shape.drop(axis + 1))
  ndindex(indices_narray.shape).each do |ndidx|
    dimensional_indices[axis] = indices_narray[*ndidx]
    y[*leading, *ndidx, *trailing] = x[*dimensional_indices]
  end
  y
end
|
45
49
|
|
data/lib/chainer/version.rb
CHANGED
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: red-chainer
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.4.
|
4
|
+
version: 0.4.1
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- Yusaku Hatanaka
|
8
8
|
autorequire:
|
9
9
|
bindir: exe
|
10
10
|
cert_chain: []
|
11
|
-
date: 2019-
|
11
|
+
date: 2019-04-08 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: numo-narray
|
@@ -145,6 +145,7 @@ files:
|
|
145
145
|
- lib/chainer/functions/connection/convolution_2d.rb
|
146
146
|
- lib/chainer/functions/connection/convolution_2d_grad_w.rb
|
147
147
|
- lib/chainer/functions/connection/deconvolution_2d.rb
|
148
|
+
- lib/chainer/functions/connection/embed_id.rb
|
148
149
|
- lib/chainer/functions/connection/linear.rb
|
149
150
|
- lib/chainer/functions/evaluation/accuracy.rb
|
150
151
|
- lib/chainer/functions/loss/mean_squared_error.rb
|
@@ -169,6 +170,7 @@ files:
|
|
169
170
|
- lib/chainer/iterators/serial_iterator.rb
|
170
171
|
- lib/chainer/link.rb
|
171
172
|
- lib/chainer/links/connection/convolution_2d.rb
|
173
|
+
- lib/chainer/links/connection/embed_id.rb
|
172
174
|
- lib/chainer/links/connection/linear.rb
|
173
175
|
- lib/chainer/links/model/classifier.rb
|
174
176
|
- lib/chainer/links/normalization/batch_normalization.rb
|