red-chainer 0.4.0 → 0.4.1

Sign up to get free protection for your applications and to get access to all the features.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: e7ed2df404bfc36275381f523c0439f2e73debcf36f3b5edb063e985502d7a70
4
- data.tar.gz: 357c983134aae985808568113d3f4f82bacebf6d25e5cf7c4f9197b1825455dc
3
+ metadata.gz: d05c5d3cb55a9e6c7e45afffd5483b9968959749d911a2e6292f5bea1e90ef40
4
+ data.tar.gz: fcdbb11f8b64a3f54a68a1629c40e9ffce774b551e79d21689757f1a495a7d66
5
5
  SHA512:
6
- metadata.gz: 40eb83d14d6efd140a4cb9748f04f50cfa325c9831d8020890a20fe88fc1485547f4dcab48cdcadfda317b46b3f4a6bc936eb8204ae39a876e053878caa7359f
7
- data.tar.gz: af4133b975c5b4b5ca6e2ce9fb05eddd2b1de5a8a30df9c776531a5acdcf5bc4d8322dc7d6875c49800587a4d98031d0eb62054dbd87ced964093c501da32c95
6
+ metadata.gz: e0a1423530e8b62dacd6ae876957cd00e604a659815dd1bd2e3ce6a7c30d3f22f68f718451b261e559fb3719211067a79c78a0b1f229e96c96db4dc7d3a7004a
7
+ data.tar.gz: da1bc06481526f5424705a8d68f0c4a1e98decaf30c81e217c5390ab96c4d4850eeecdec7cf88287c4581cc5556bbd24adbf26c3007145b9bf9e007e40047ae1
@@ -22,6 +22,7 @@ require 'chainer/initializers/uniform'
22
22
  require 'chainer/iterators/serial_iterator'
23
23
  require 'chainer/link'
24
24
  require 'chainer/links/connection/convolution_2d'
25
+ require 'chainer/links/connection/embed_id'
25
26
  require 'chainer/links/connection/linear'
26
27
  require 'chainer/links/normalization/batch_normalization'
27
28
  require 'chainer/links/model/classifier'
@@ -56,6 +57,7 @@ require 'chainer/functions/loss/softmax_cross_entropy'
56
57
  require 'chainer/functions/connection/convolution_2d'
57
58
  require 'chainer/functions/connection/deconvolution_2d'
58
59
  require 'chainer/functions/connection/convolution_2d_grad_w'
60
+ require 'chainer/functions/connection/embed_id'
59
61
  require 'chainer/functions/connection/linear'
60
62
  require 'chainer/functions/noise/dropout'
61
63
  require 'chainer/functions/normalization/batch_normalization'
@@ -0,0 +1,49 @@
1
module Chainer
  module Functions
    module Connection
      # Embedding lookup: gathers rows of the weight array +w+ (along its first
      # axis) selected by the integer IDs in +x+.
      #
      # When +ignore_label+ is set, positions of +x+ holding that ID produce zero
      # rows in the output and contribute nothing to the weight gradient.
      class EmbedIDFunction < Chainer::Function
        # @param ignore_label [Integer, nil] ID whose output rows are zeroed and
        #   whose gradient is dropped; nil (or false) disables this handling.
        def initialize(ignore_label: nil)
          @ignore_label = ignore_label
        end

        # Builds an EmbedIDFunction and invokes it on (x, w).
        #
        # @param x integer IDs
        # @param w weight whose first axis is indexed by the IDs
        # @param ignore_label [Integer, nil] see #initialize
        # @return the result of calling the function on (x, w)
        def self.embed_id(x, w, ignore_label: nil)
          self.new(ignore_label: ignore_label).(x, w)
        end

        # @param inputs [Array] two-element array [x, w] of Numo arrays
        # @return [Array] one-element array holding the gathered rows, shaped
        #   x.shape + w.shape[1..]
        def forward(inputs)
          x, w = inputs

          # NOTE(review): this path returns take's result directly; for 0/1-D x it
          # may be a Numo view of w — confirm no caller writes into it.
          return [Chainer::Utils::Array.take(w, x, axis: 0)] unless @ignore_label

          # Flat positions of x that carry the ignore label.
          ignored = x.eq(@ignore_label).where
          return [Chainer::Utils::Array.take(w, x, axis: 0)] if ignored.size.zero?

          # Look up row 0 in place of ignored IDs (they may be out of range, e.g. -1),
          # then overwrite those output rows with zeros.
          safe_x = x.dup
          safe_x[ignored] = 0
          # dup so the zero-writes below cannot reach w through a possible view.
          y = Chainer::Utils::Array.take(w, safe_x, axis: 0).dup

          # One row per element of x, regardless of how many trailing axes w has.
          rows = y.reshape(x.size, true)
          rows[ignored, true] = 0
          [rows.reshape(*x.shape, *w.shape.drop(1))]
        end

        # @param inputs [Array] [x, w] as passed to #forward
        # @param grad_outputs [Array] gradient w.r.t. the output at index 0
        # @return [Array] [nil, gw]: no gradient for the IDs; gw has w's shape and
        #   row k is the sum of the output gradients at every position where x == k
        def backward(inputs, grad_outputs)
          x, w = inputs
          row_size = w.shape.drop(1).reduce(1, :*)
          gy = grad_outputs[0].reshape(x.size, true)
          # Index rows by ID: first axis of w kept, trailing axes flattened.
          gw = w.class.zeros(w.shape[0], row_size)

          # Explicit loop because IDs may repeat and their gradients must accumulate.
          x.reshape(x.size).each_with_index do |id, i|
            next if id == @ignore_label
            gw[id, true] += gy[i, true]
          end

          [nil, gw.reshape(*w.shape)]
        end
      end
    end
  end
end
@@ -0,0 +1,23 @@
1
module Chainer
  module Links
    module Connection
      # Link owning a learnable embedding weight +w+ of shape [in_size, out_size].
      # Calling it looks up rows of +w+ by integer ID via EmbedIDFunction.
      class EmbedID < ::Chainer::Link
        # @return [Chainer::Parameter] the embedding weight
        attr_reader :w

        # @param in_size [Integer] number of rows of +w+ (the ID range)
        # @param out_size [Integer] number of columns of +w+ (embedding width)
        # @param initial_w initializer for +w+; when nil,
        #   Chainer::Initializers::Normal with scale 1.0 is used
        # @param ignore_label [Integer, nil] forwarded to EmbedIDFunction, which
        #   zeroes output rows for this ID
        def initialize(in_size, out_size, initial_w: nil, ignore_label: nil)
          super()
          @ignore_label = ignore_label

          init_scope do
            initializer = initial_w || Chainer::Initializers::Normal.new(scale: 1.0)
            @w = Chainer::Parameter.new(initializer: initializer, shape: [in_size, out_size])
          end
        end

        # Embeds the IDs in +x+ using this link's weight.
        #
        # @param x integer IDs
        # @return the result of EmbedIDFunction.embed_id
        def call(x)
          embed_function = Chainer::Functions::Connection::EmbedIDFunction
          embed_function.embed_id(x, @w, ignore_label: @ignore_label)
        end
      end
    end
  end
end
@@ -46,9 +46,9 @@ module Chainer
46
46
  #
47
47
  # @param [string] filename Name of the file to be loaded.
48
48
  # @param [object] obj Object to be deserialized. It must support serialization protocol.
49
- def self.load_file(filename, obj)
49
+ def self.load_file(filename, obj, path: '', strict: true)
50
50
  File.open(filename) do |f|
51
- d = self.new(Marshal.load(f))
51
+ d = self.new(Marshal.load(f), path: path, strict: strict)
52
52
  d.load(obj)
53
53
  end
54
54
  end
@@ -18,28 +18,32 @@ module Chainer
18
18
  end
19
19
  end
20
20
 
21
# Enumerates every index tuple of an array with the given +shape+ in
# row-major (C) order, like NumPy's +ndindex+.
#
# @param shape [Array<Integer>] array dimensions
# @return [Array<Array<Integer>>] one index tuple per element; a single empty
#   tuple for a 0-dimensional shape, and no tuples if any dimension is 0
def self.ndindex(shape)
  # strides[j] = number of flat elements spanned by one step along axis j.
  strides = shape.each_index.map { |j| shape.drop(j + 1).reduce(1, :*) }
  shape.reduce(1, :*).times.map do |flat|
    shape.each_index.map { |j| (flat / strides[j]) % shape[j] }
  end
end
27
28
 
28
# Selects entries of +x+ by +indices+, mirroring NumPy's +take+.
#
# @param x [Numo::NArray] source array
# @param indices [Integer, Array, Numo::NArray] positions to select
# @param axis [Integer, nil] axis to select along (negative values count from
#   the last axis); when nil, +x+ is indexed directly with +indices+
# @return [Numo::NArray] when +axis+ is given, an array shaped
#   x.shape[0...axis] + indices.shape + x.shape[axis + 1..]
def self.take(x, indices, axis: nil)
  return x[indices] unless axis

  ndim = x.shape.size
  axis += ndim if axis.negative?
  dimensional_indices = ::Array.new(ndim, true)

  indices_narray = Numo::Int32.cast(indices)
  if indices_narray.shape.size <= 1
    # Scalar / 1-D indices: a single Numo indexing call (result may be a view of x).
    dimensional_indices[axis] = indices
    return x[*dimensional_indices]
  end

  # Multi-dimensional indices: fill the output one index tuple at a time. The
  # axes before +axis+ stay in front of the index axes so the slice shapes match.
  leading = ::Array.new(axis, true)
  trailing = ::Array.new(ndim - axis - 1, true)
  y = x.class.zeros(*x.shape.take(axis), *indices_narray.shape, *x.shape.drop(axis + 1))
  ndindex(indices_narray.shape).each do |ndidx|
    dimensional_indices[axis] = indices_narray[*ndidx]
    y[*leading, *ndidx, *trailing] = x[*dimensional_indices]
  end
  y
end
45
49
 
@@ -1,4 +1,4 @@
1
1
  module Chainer
2
- VERSION = "0.4.0"
2
+ VERSION = "0.4.1"
3
3
  end
4
4
 
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: red-chainer
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.4.0
4
+ version: 0.4.1
5
5
  platform: ruby
6
6
  authors:
7
7
  - Yusaku Hatanaka
8
8
  autorequire:
9
9
  bindir: exe
10
10
  cert_chain: []
11
- date: 2019-03-28 00:00:00.000000000 Z
11
+ date: 2019-04-08 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: numo-narray
@@ -145,6 +145,7 @@ files:
145
145
  - lib/chainer/functions/connection/convolution_2d.rb
146
146
  - lib/chainer/functions/connection/convolution_2d_grad_w.rb
147
147
  - lib/chainer/functions/connection/deconvolution_2d.rb
148
+ - lib/chainer/functions/connection/embed_id.rb
148
149
  - lib/chainer/functions/connection/linear.rb
149
150
  - lib/chainer/functions/evaluation/accuracy.rb
150
151
  - lib/chainer/functions/loss/mean_squared_error.rb
@@ -169,6 +170,7 @@ files:
169
170
  - lib/chainer/iterators/serial_iterator.rb
170
171
  - lib/chainer/link.rb
171
172
  - lib/chainer/links/connection/convolution_2d.rb
173
+ - lib/chainer/links/connection/embed_id.rb
172
174
  - lib/chainer/links/connection/linear.rb
173
175
  - lib/chainer/links/model/classifier.rb
174
176
  - lib/chainer/links/normalization/batch_normalization.rb