ruby-cntk 0.1.0.pre1 → 0.1.0.pre2

Sign up to get free protection for your applications and to get access to all the features.
data/ext/cntk/extconf.rb CHANGED
@@ -1,21 +1,13 @@
1
require 'mkmf'

# C++ runtime: try libc++ first and fall back to libstdc++.
have_library("c++") || have_library("stdc++")

# Location of the CNTK library is configurable via dir_config, e.g.:
#   rake2.3 compile -- --with-cntklibrary-2.0-lib=/cntk/build/cpu/release/lib/ --with-cntklibrary-2.0-include=/cntk/Source/CNTKv2LibraryDll/API/
#   rake2.3 compile -- --with-cntklibrary-2.0-lib=/cntk/cntk/lib/ --with-cntklibrary-2.0-include=/cntk/Include/
dir_config("cntklibrary-2.0")
have_library("cntklibrary-2.0")

# Append the extension's compiler flags (C++11, -O2, SWIG defined) to
# whatever mkmf already configured.
# NOTE(review): -Wl,--whole-archive is a linker option placed in the
# compiler flags; confirm it actually reaches the link step of the
# generated Makefile.
$CXXFLAGS = "#{$CXXFLAGS} -std=c++11 -O2 -DSWIG -Wl,--whole-archive "

create_makefile('cntk/CNTK')
data/lib/cntk/axis.rb ADDED
@@ -0,0 +1,18 @@
1
module CNTK
  class Axis
    class << self
      # Normalizes an axis specification into an Axis.
      #
      # * nil       -> all_static_axes
      # * Numeric n -> new(-n - 1)
      # * Axis a    -> new(-1 - a.static_axis_index)
      #
      # @param n [nil, Numeric, Axis]
      # @return [Axis]
      # @raise [ArgumentError] for any other kind of argument
      def from_num(n)
        return all_static_axes if n.nil?
        return new(-n - 1) if n.is_a?(Numeric)
        return new(-1 - n.static_axis_index) if n.is_a?(Axis)

        raise ArgumentError
      end
    end
  end
end
@@ -0,0 +1,62 @@
1
module CNTK

  class DictionaryValue

    # Wraps a Ruby value as a DictionaryValue.
    #
    # A Hash becomes a nested Dictionary (via Dictionary.create), an Array
    # becomes a StdVectorDictionaryValue whose elements are converted
    # recursively, and anything else is handed to the constructor as-is.
    # @param val [Hash, Array, Object]
    # @return [DictionaryValue]
    def self.create(val)
      if val.is_a?(Hash)
        new(Dictionary.create(val))
      elsif val.is_a?(Array)
        elements = StdVectorDictionaryValue.new
        val.each_with_index { |item, i| elements[i] = create(item) }
        new(elements)
      else
        new(val)
      end
    end

    # Reads the wrapped value through the typed accessor that matches
    # #value_type.
    # @return [Object] the unwrapped value
    # @raise [RuntimeError] "unknown type" for any other type tag
    def value
      reader =
        case value_type
        when Type_Bool        then :value_bool__
        when Type_Int         then :value_int__
        when Type_SizeT       then :value_size_t__
        when Type_Float       then :value_float__
        when Type_Double      then :value_double__
        when Type_String      then :value_string__
        when Type_NDShape     then :value_ndshape__
        when Type_Axis        then :value_axis__
        when Type_Vector      then :value_vec_dict_value__
        when Type_Dictionary  then :value_dict__
        when Type_NDArrayView then :value_ndarrayview__
        else raise "unknown type"
        end
      __send__(reader)
    end

  end

  class Dictionary
    # Builds a Dictionary from a Hash. Symbol keys are stored as Strings and
    # every value is converted with DictionaryValue.create.
    # @param h [Hash]
    # @return [Dictionary]
    def self.create(h)
      new.tap do |dict|
        h.each_pair do |key, val|
          dict[key.is_a?(Symbol) ? key.to_s : key] = DictionaryValue.create(val)
        end
      end
    end
  end

end
data/lib/cntk/function.rb CHANGED
@@ -1,61 +1,113 @@
1
1
  module CNTK
2
+
2
3
  class Function
3
4
 
4
- def call(args)
5
- if args.outputs.length == 1
6
- return replace_placeholders({placeholders[0] => args.output})
5
# Dot product of this function's output variable with +other+.
# @param other [Object] right-hand operand, passed through unchanged
# @return [Object] whatever the output variable's #dot returns
def dot(other)
  output.public_send(:dot, other)
end
8
+
9
# Unary minus: negates this function's output variable.
# @return [Object] the negated output
def -@
  operand = output
  -operand
end
12
+
13
# Addition, delegated to this function's output variable.
# @param other [Object] right-hand operand
# @return [Object] output + other
def +(other)
  lhs = output
  lhs + other
end
16
+
17
# Subtraction, delegated to this function's output variable.
# @param other [Object] right-hand operand
# @return [Object] output - other
def -(other)
  lhs = output
  lhs - other
end
20
+
21
# Multiplication, delegated to this function's output variable.
# @param other [Object] right-hand operand
# @return [Object] output * other
def *(other)
  lhs = output
  lhs * other
end
24
+
25
# Division, delegated to this function's output variable.
# @param other [Object] right-hand operand
# @return [Object] output / other
def /(other)
  lhs = output
  lhs / other
end
28
+
29
# Ruby's numeric coercion hook. It lets expressions such as `2 * func` or
# `1.0 - func` work. The Numeric is lifted into a scalar Constant with the
# same data type as this function's output, and Ruby then applies the
# operator to the returned pair.
#
# Previously a non-Numeric argument fell through to an empty else branch.
# That returned nil, which breaks the coerce contract.
#
# @param other [Numeric]
# @return [Array(Constant, Function)] [scalar constant, self]
# @raise [TypeError] when +other+ is not Numeric
def coerce(other)
  unless other.is_a?(Numeric)
    raise TypeError, "#{other.class} can't be coerced into #{self.class}"
  end

  [Constant.scalar(output.get_data_type, other), self]
end
37
+
38
# Plugs +func+ into this function's single placeholder, building self(func).
# Anything that responds to #output (e.g. a Function) is reduced to its
# output variable first. Other values are bound as they are.
#
# @param func [Function, Variable] the argument to bind
# @return [Object] the result of #replace_placeholders
# @raise [RuntimeError] unless this function has exactly one placeholder
def call(func)
  arg = func.respond_to?(:output) ? func.output : func
  slots = placeholders
  raise "the number of placeholders is not 1." unless slots.length == 1

  replace_placeholders({ slots[0] => arg })
end
50
+
51
# Forward composition: `f >> g` feeds the output of self into +func+,
# i.e. it builds func(self(...)) by calling func.call(self).
# @param func [Function] the function applied after self
# @return [Object] the result of func.call(self)
def >>(func)
  func.public_send(:call, self)
end
55
+
56
# Backward composition: `f << g` builds f(g(...)).
# This is simply f.call(g).
# @param func [Function, Variable] the inner function/argument
# @return [Object] the result of #call
def <<(func)
  call(func)
end
59
+
60
# Runs a forward pass.
#
# @param argsmap [Hash] input variable => data, converted via #convert_to_value
# @param outputs [Array] output variables to compute
# @param keep_for_backward [Array] passed unchanged to __forward__
# @param device [DeviceDescriptor]
# @param remove_dynamic_axes [Boolean] when true, post-process the output map
#   with #remove_dynamic_axes
# @return [Array] [value returned by __forward__, output map]
def forward(argsmap, outputs = [], keep_for_backward: [], device: DeviceDescriptor.use_default_device(), remove_dynamic_axes: true)
  inputs = convert_to_value(argsmap)
  results = StdUMapVariableValue.new
  # A null entry makes the C++ Forward allocate each output Value itself.
  outputs.each { |var| results.__set_nullptr__(var) }
  state = __forward__(inputs, results, device, keep_for_backward)
  # FIXME: this post-processing step is meant to be removed eventually.
  # With parentheses this is a method call even though a keyword argument
  # of the same name is in scope.
  results = remove_dynamic_axes(results) if remove_dynamic_axes
  [state, results]
end
11
72
 
12
- def forward(*args)
13
- if args.length > 1
14
- return __forward__(*args)
15
- elsif args.length == 1
16
- input = convert_to_value(args[0])
17
- out = StdUMapVariableValue.new()
18
- outputs().each{|o|
19
- v = NDArrayView.new(CNTK::DataType_Double,
20
- required_output_shape(o),
21
- required_output_buf(o),
22
- CNTK::DeviceDescriptor.default_device(),
23
- true)
24
- out[o] = Value.new(v)
25
- }
26
- b = __forward__(input, out)
27
- out = remove_dynamic_axes(out)
28
- return [out, b]
73
# Evaluates every output of this function.
#
# @param argsmap [Hash, nil] input bindings; nil means no inputs ({})
# @param device [DeviceDescriptor]
# @param remove_dynamic_axes [Boolean] forwarded to #forward
# @return [Object] the whole output map when there are several outputs,
#   otherwise the single output value (nil when there are none)
def eval(argsmap=nil, device: DeviceDescriptor.use_default_device(), remove_dynamic_axes: true)
  inputs = argsmap.nil? ? {} : argsmap
  _state, results = forward(inputs, outputs(), device: device, remove_dynamic_axes: remove_dynamic_axes)
  return results if results.size > 1

  results.values.first
end
31
82
 
83
# Runs a backward pass from a state returned by #forward.
#
# @param state [Object] the first element returned by #forward
# @param root_gradients [Hash] root variable => gradient, converted via
#   #convert_to_value
# @param variables [Array] variables whose gradients are requested. Each is
#   registered with a null value (presumably so the C++ Backward allocates
#   the storage, as in #forward).
# @param remove_dynamic_axes [Boolean] when true (default), post-process the
#   gradient map with #remove_dynamic_axes; when false, return it unchanged.
#   Previously this keyword was accepted but ignored.
# @return [Object] the gradient map
def backward(state, root_gradients, variables, remove_dynamic_axes: true)
  roots = convert_to_value(root_gradients)
  grads = StdUMapVariableValue.new
  variables.each { |var| grads.__set_nullptr__(var) }
  __backward__(state, roots, grads)
  # With parentheses this is a method call even though the keyword argument
  # of the same name is in scope.
  remove_dynamic_axes ? remove_dynamic_axes(grads) : grads
end
92
+
32
93
# Wraps array-like inputs as CNTK Values.
#
# Every value that responds to #row_major? (presumably a Numo array) is
# converted with Value.new(NDArrayView.create(val)). Everything else is
# passed through unchanged.
# @param h [Hash] variable => data
# @return [Hash] a new Hash with the same keys
def convert_to_value(h)
  converted = {}
  h.each_pair do |key, val|
    converted[key] =
      if val.respond_to?(:row_major?)
        Value.new(NDArrayView.create(val))
      else
        val
      end
  end
  converted
end
53
104
 
54
105
  def remove_dynamic_axes(out)
55
106
  out1 = {}
56
107
  out.each{|o,ov|
57
- if ov.shape.rank == o.shape.rank + 2 and ov.shape.to_a[-2..-1] == [1,1]
58
- out1[o] = ov.reshape( ov.shape.to_a[0..-3] )
108
+ sz = o.dynamic_axes.size
109
+ if sz > 0 and sz < ov.shape.rank and ov.shape.to_a[0..1] == [1,1]
110
+ out1[o] = ov.reshape( ov.shape.to_a[sz..-1] )
59
111
  else
60
112
  out1[o] = ov
61
113
  end
@@ -0,0 +1,67 @@
1
module CNTK
  # Factory methods for CNTK parameter initializers. Each one forwards to the
  # matching native CNTK.__*_initializer__ binding. Rank keywords default to
  # CNTK.SentinelValueForInferParamInitRank, and seeds default to
  # CNTK.SentinelValueForAutoSelectRandomSeed.
  module Initializer
    class << self

      # Constant initializer.
      # @param value [Numeric, nil] fill value. When nil (the default) the
      #   binding is called without arguments, exactly as before, so the
      #   native default applies.
      # @return [Object] the initializer returned by the native binding
      def constant(value = nil)
        return CNTK.__constant_initializer__ if value.nil?

        CNTK.__constant_initializer__(value)
      end

      # Uniform initializer. Note that +seed+ is positional here.
      # @param scale [Numeric]
      # @param seed [Integer]
      def uniform(scale, seed = CNTK.SentinelValueForAutoSelectRandomSeed)
        CNTK.__uniform_initializer__(scale, seed)
      end

      # Normal initializer; +scale+ is required.
      def normal(scale,
                 output_rank: CNTK.SentinelValueForInferParamInitRank,
                 filter_rank: CNTK.SentinelValueForInferParamInitRank,
                 seed: CNTK.SentinelValueForAutoSelectRandomSeed)
        CNTK.__normal_initializer__(scale, output_rank, filter_rank, seed)
      end

      # Glorot (uniform) initializer.
      def glorot_uniform(scale = CNTK.DefaultParamInitScale,
                         output_rank: CNTK.SentinelValueForInferParamInitRank,
                         filter_rank: CNTK.SentinelValueForInferParamInitRank,
                         seed: CNTK.SentinelValueForAutoSelectRandomSeed)
        CNTK.__glorot_uniform_initializer__(scale, output_rank, filter_rank, seed)
      end

      # Glorot (normal) initializer.
      def glorot_normal(scale = CNTK.DefaultParamInitScale,
                        output_rank: CNTK.SentinelValueForInferParamInitRank,
                        filter_rank: CNTK.SentinelValueForInferParamInitRank,
                        seed: CNTK.SentinelValueForAutoSelectRandomSeed)
        CNTK.__glorot_normal_initializer__(scale, output_rank, filter_rank, seed)
      end

      # Xavier initializer. +scale+ defaults to CNTK.DefaultParamInitScale,
      # consistent with the glorot_* and he_* factories (it used to be
      # required).
      def xavier(scale = CNTK.DefaultParamInitScale,
                 output_rank: CNTK.SentinelValueForInferParamInitRank,
                 filter_rank: CNTK.SentinelValueForInferParamInitRank,
                 seed: CNTK.SentinelValueForAutoSelectRandomSeed)
        CNTK.__xavier_initializer__(scale, output_rank, filter_rank, seed)
      end

      # He (uniform) initializer.
      def he_uniform(scale = CNTK.DefaultParamInitScale,
                     output_rank: CNTK.SentinelValueForInferParamInitRank,
                     filter_rank: CNTK.SentinelValueForInferParamInitRank,
                     seed: CNTK.SentinelValueForAutoSelectRandomSeed)
        CNTK.__he_uniform_initializer__(scale, output_rank, filter_rank, seed)
      end

      # He (normal) initializer.
      def he_normal(scale = CNTK.DefaultParamInitScale,
                    output_rank: CNTK.SentinelValueForInferParamInitRank,
                    filter_rank: CNTK.SentinelValueForInferParamInitRank,
                    seed: CNTK.SentinelValueForAutoSelectRandomSeed)
        CNTK.__he_normal_initializer__(scale, output_rank, filter_rank, seed)
      end

      # Bilinear initializer for the given kernel size.
      def bilinear(kernel_width, kernel_height)
        CNTK.__bilinear_initializer__(kernel_width, kernel_height)
      end

      # Re-wraps an existing initializer with explicit output/filter ranks.
      def initializer_with_rank(initializer,
                                output_rank: CNTK.SentinelValueForInferParamInitRank,
                                filter_rank: CNTK.SentinelValueForInferParamInitRank)
        CNTK.__random_initializer_with_rank__(initializer, output_rank, filter_rank)
      end

    end # class << self
  end # module Initializer
end # module CNTK
@@ -0,0 +1,57 @@
1
module CNTK
  # Mixin for building `#<ClassName meth=value, ...>` inspect strings.
  module InspectUtil
    # @param mthds [Array<Symbol>] names of zero-argument methods on self
    # @return [String] "name=value.inspect" pairs joined by ", "
    def inspect_methods_p(mthds)
      pairs = mthds.map { |mth| "#{mth}=#{send(mth).inspect}" }
      pairs.join(", ")
    end

    # @param mthds [Array<Symbol>]
    # @return [String] "#<ClassName name=value, ...>"
    def inspect_methods(mthds)
      "#<#{self.class} #{inspect_methods_p(mthds)}>"
    end
  end

  class StdUMapStreamInfoMinibatchData
    include InspectUtil

    # Renders the entries as "#<ClassName: {key => value, ...}>".
    def inspect
      entries = map { |k, v| "#{k.inspect} => #{v.inspect}" }
      "#<#{self.class}: {#{entries.join(", ")}}>"
    end
  end

  class Axis
    include InspectUtil

    def inspect
      inspect_methods(%i[name is_dynamic_axis])
    end
  end

  class NDShape
    # Shown as the plain dimension array, e.g. [2, 3].
    def inspect
      to_a.inspect
    end
  end

  class Value
    include InspectUtil

    def inspect
      inspect_methods(%i[shape])
    end
  end

  class StreamInformation
    include InspectUtil

    def inspect
      inspect_methods(%i[name id])
    end
  end

  class MinibatchData
    include InspectUtil

    def inspect
      inspect_methods(%i[data number_of_samples])
    end
  end

end
data/lib/cntk/io.rb ADDED
@@ -0,0 +1,50 @@
1
module CNTK

  # Creates a composite minibatch source.
  #
  # When +dict+ responds to #to_hash it is normalised first:
  # * Symbol keys become Strings.
  # * A lone "deserializers" entry is wrapped in an Array.
  # * The result is converted with Dictionary.create.
  # Any other argument is passed to the native binding unchanged.
  #
  # @param dict [Hash, Dictionary]
  # @return [Object] whatever CNTK.__create_composite_minibatch_source__ returns
  def self.create_composite_minibatch_source(dict)
    if dict.respond_to?(:to_hash)
      config = {}
      dict.to_hash.each_pair do |k, v|
        config[k.is_a?(Symbol) ? k.to_s : k] = v
      end
      des = config["deserializers"]
      # Wrap only a real single deserializer; a missing entry must not be
      # turned into [nil].
      config["deserializers"] = [des] unless des.nil? || des.respond_to?(:to_ary)
      dict = Dictionary.create(config)
    end
    CNTK.__create_composite_minibatch_source__(dict)
  end

  class MinibatchSource

    # Reads the next minibatch.
    #
    # NOTE(review): the leading 0 passed to get_next_minibatch is presumably
    # the minibatch size in sequences (i.e. "not limited"); confirm against
    # the binding.
    # @param minibatch_size_in_samples [Integer]
    # @param device [DeviceDescriptor]
    # @param num_data_partitions [Integer]
    # @param partition_index [Integer]
    # @return [Object] the table returned by get_next_minibatch (presumably a
    #   MinibatchTable of StreamInformation => MinibatchData; confirm)
    def next_minibatch(minibatch_size_in_samples, device: DeviceDescriptor.use_default_device,
                       num_data_partitions: 1, partition_index: 0)
      get_next_minibatch(0, minibatch_size_in_samples, num_data_partitions, partition_index, device)
    end

  end

  # std::unordered_map<StreamInformation, MinibatchData>
  class MinibatchTable
    alias __get__ :[]

    # Looks up minibatch data by key. A String key (anything responding to
    # #to_str) is matched against the streams' #name.
    # @param key [StreamInformation, String]
    # @return [MinibatchData]
    # @raise [RuntimeError] when a name matches no stream or more than one.
    #   Previously a name with zero matches silently called __get__(nil).
    def [](key)
      return __get__(key) unless key.respond_to?(:to_str)

      name = key.to_str
      matches = keys.select { |info| info.name == name }
      unless matches.size == 1
        raise "The number of input data having the name is not 1."
      end
      __get__(matches.first)
    end
  end

end
@@ -0,0 +1,17 @@
1
module CNTK
  module Layers

    module_function

    # Fully connected layer computing x * W (+ b). Here x is a new
    # placeholder variable named "x".
    #
    # @param output_shape [Integer, Array<Integer>] output dimensions; a bare
    #   Integer is treated as a one-dimensional shape
    # @param init initializer for the weight parameter "W"
    # @param input_shape [Array<Integer>] input dimensions
    #   (CNTK::NDShape::InferredDimension by default)
    # @param use_bias [Boolean] when false, no bias parameter is created or
    #   added (previously a bias was always added)
    # @param init_bias initializer for the bias parameter "b"
    # @param name [String] NOTE(review): accepted but not applied to the
    #   returned function
    # @return [Function]
    def dense(output_shape, init: Initializer.glorot_uniform,
              input_shape: [CNTK::NDShape::InferredDimension],
              use_bias: true, init_bias: 0, name: "")
      # Array() accepts a bare Integer. Previously `[...] + 3` raised a
      # TypeError, and Integer#size returned the byte width rather than the rank.
      output_shape = Array(output_shape)
      weight = Ops.parameter(shape: Array(input_shape) + output_shape, init: init, name: "W")
      bias = Ops.parameter(shape: output_shape, init: init_bias, name: "b") if use_bias
      x = Ops.placeholder_variable(name: "x")
      product = Ops.times(x, weight, output_rank: output_shape.size, infer_input_rank_to_map: 0)
      use_bias ? product + bias : product
    end

  end
end
@@ -0,0 +1,197 @@
1
module CNTK

  class Learner
    LearningRateSchedule = MomentumSchedule
    MinibatchSizeSchedule = TrainingParameterPerSampleSchedule

    # Schedule classes produced by the helpers below:
    #   TrainingParameterPerSampleSchedule (== MinibatchSizeSchedule)
    #   TrainingParameterPerMinibatchSchedule
    #   MomentumAsTimeConstantSchedule
    class << self

      private

      # Collects the options shared by every learner factory.
      # @param l1_weight [Float] L1 regularization weight
      # @param l2_weight [Float] L2 regularization weight
      # @param ga [TrainingParameterPerMinibatchSchedule] schedule for the
      #   gaussian noise injection standard deviation
      # @param threshold [Float] per-sample gradient clipping threshold
      # @param truncation [Boolean] gradient clipping with truncation
      # @return [AdditionalLearningOptions]
      def create_opt(l1_weight, l2_weight, ga, threshold, truncation)
        opt = AdditionalLearningOptions.new
        opt.l1_regularization_weight = l1_weight
        opt.l2_regularization_weight = l2_weight
        opt.gaussian_noise_injection_std_dev = ga
        opt.gradient_clipping_threshold_per_sample = threshold
        opt.gradient_clipping_with_truncation = truncation
        opt
      end

      # Instantiates +klass+ from +schedule+ and an optional +epoch_size+.
      # All public schedule helpers share this so they validate alike.
      # @raise [RuntimeError] when schedule is Numeric and epoch_size is given
      def build_schedule(klass, schedule, epoch_size)
        if schedule.is_a?(Numeric)
          raise "epoch_size can't be given when schedule is Numeric." unless epoch_size.nil?

          klass.new(schedule)
        elsif epoch_size.nil?
          klass.new(schedule)
        else
          klass.new(schedule, epoch_size)
        end
      end

      public

      # @param schedule [Numeric, Array<Numeric>]
      # @param unit [:sample, :minibatch]
      # @param epoch_size [Numeric, nil] only allowed with a non-Numeric schedule
      # @return [TrainingParameterPerSampleSchedule, TrainingParameterPerMinibatchSchedule]
      # @raise [RuntimeError] for an unknown unit, or when epoch_size is given
      #   together with a Numeric schedule
      def training_parameter_schedule(schedule, unit, epoch_size = nil)
        klass =
          case unit
          when :sample then TrainingParameterPerSampleSchedule
          when :minibatch then TrainingParameterPerMinibatchSchedule
          else raise "unknown unit"
          end
        build_schedule(klass, schedule, epoch_size)
      end

      # @param schedule [Numeric, Array<Numeric>]
      # @param unit [:sample, :minibatch]
      # @param epoch_size [Numeric, nil]
      # @return [TrainingParameterPerSampleSchedule, TrainingParameterPerMinibatchSchedule]
      def momentum_schedule(schedule, unit = :minibatch, epoch_size = nil)
        training_parameter_schedule(schedule, unit, epoch_size)
      end

      # Builds a MomentumAsTimeConstantSchedule.
      #
      # +epoch_size+ is optional; the adam_sgd default calls this with one
      # argument. As in training_parameter_schedule, it may only accompany a
      # non-Numeric schedule. Previously the nil check was inverted.
      # @param schedule [Numeric, Array<Numeric>]
      # @param epoch_size [Numeric, nil]
      # @return [MomentumAsTimeConstantSchedule]
      # @raise [RuntimeError] when epoch_size is given with a Numeric schedule
      def momentum_as_time_constant_schedule(schedule, epoch_size = nil)
        build_schedule(MomentumAsTimeConstantSchedule, schedule, epoch_size)
      end

      # Plain SGD learner.
      # @param parameters [Array<Parameter>]
      # @param lr [TrainingParameterPerSampleSchedule, TrainingParameterPerMinibatchSchedule]
      # @param l1_weight [Float]
      # @param l2_weight [Float]
      # @param std_dev [Float] gaussian noise injection std-dev (per minibatch)
      # @param threshold [Float] per-sample gradient clipping threshold
      # @param truncation [Boolean]
      # @return [Learner]
      def sgd(parameters, lr, l1_weight: 0.0, l2_weight: 0.0,
              std_dev: 0.0, threshold: Float::INFINITY, truncation: true)
        ga = training_parameter_schedule(std_dev, :minibatch)
        opt = create_opt(l1_weight, l2_weight, ga, threshold, truncation)
        CNTK.__sgdlearner__(parameters, lr, opt)
      end

      # Momentum SGD learner.
      # @param parameters [Array<Parameter>]
      # @param lr [TrainingParameterPerSampleSchedule, TrainingParameterPerMinibatchSchedule]
      # @param momentum [MomentumSchedule]
      # @param unit_gain [Boolean]
      # @param l1_weight [Float]
      # @param l2_weight [Float]
      # @param std_dev [Float]
      # @param threshold [Float]
      # @param truncation [Boolean]
      # @return [Learner]
      def momentum_sgd(parameters, lr, momentum, unit_gain: CNTK.default_unit_gain_value(),
                       l1_weight: 0.0, l2_weight: 0.0,
                       std_dev: 0.0, threshold: Float::INFINITY, truncation: true)
        ga = training_parameter_schedule(std_dev, :minibatch)
        opt = create_opt(l1_weight, l2_weight, ga, threshold, truncation)
        CNTK.__momentum_sgd_learner__(parameters, lr, momentum, unit_gain, opt)
      end

      # Nesterov momentum learner.
      # @param parameters [Array<Parameter>]
      # @param lr [TrainingParameterPerSampleSchedule, TrainingParameterPerMinibatchSchedule]
      # @param momentum [MomentumSchedule]
      # @param unit_gain [Boolean]
      # @param l1_weight [Float]
      # @param l2_weight [Float]
      # @param std_dev [Float]
      # @param threshold [Float]
      # @param truncation [Boolean]
      # @return [Learner]
      def nesterov(parameters, lr, momentum, unit_gain: CNTK.default_unit_gain_value(),
                   l1_weight: 0.0, l2_weight: 0.0,
                   std_dev: 0.0, threshold: Float::INFINITY, truncation: true)
        ga = training_parameter_schedule(std_dev, :minibatch)
        opt = create_opt(l1_weight, l2_weight, ga, threshold, truncation)
        CNTK.__nesterov_learner__(parameters, lr, momentum, unit_gain, opt)
      end

      # AdaGrad learner.
      # @param parameters [Array<Parameter>]
      # @param lr [TrainingParameterPerSampleSchedule, TrainingParameterPerMinibatchSchedule]
      # @param multiplier [Boolean]
      # @param unit_gain [Boolean]
      # @param l1_weight [Float]
      # @param l2_weight [Float]
      # @param std_dev [Float]
      # @param threshold [Float]
      # @param truncation [Boolean]
      # @return [Learner]
      def adagrad(parameters, lr, multiplier: true, unit_gain: CNTK.default_unit_gain_value(),
                  l1_weight: 0.0, l2_weight: 0.0,
                  std_dev: 0.0, threshold: Float::INFINITY, truncation: true)
        ga = training_parameter_schedule(std_dev, :minibatch)
        opt = create_opt(l1_weight, l2_weight, ga, threshold, truncation)
        CNTK.__ada_grad_learner__(parameters, lr, multiplier, unit_gain, opt)
      end

      # Adam learner.
      # @param parameters [Array<Parameter>]
      # @param lr [TrainingParameterPerSampleSchedule, TrainingParameterPerMinibatchSchedule]
      # @param momentum [MomentumSchedule]
      # @param unit_gain [Boolean]
      # @param variance_momentum [MomentumAsTimeConstantSchedule]
      #   defaults to momentum_as_time_constant_schedule(720000)
      # @param low_memory [Boolean]
      # @param l1_weight [Float]
      # @param l2_weight [Float]
      # @param std_dev [Float]
      # @param threshold [Float]
      # @param truncation [Boolean]
      # @return [Learner]
      def adam_sgd(parameters, lr, momentum, unit_gain: CNTK.default_unit_gain_value(),
                   variance_momentum: momentum_as_time_constant_schedule(720000),
                   low_memory: true,
                   l1_weight: 0.0, l2_weight: 0.0,
                   std_dev: 0.0, threshold: Float::INFINITY, truncation: true)
        ga = training_parameter_schedule(std_dev, :minibatch)
        opt = create_opt(l1_weight, l2_weight, ga, threshold, truncation)
        CNTK.__adam_learner__(parameters, lr, momentum, unit_gain, variance_momentum, low_memory, opt)
      end

      # RMSProp learner.
      # @param parameters [Array<Parameter>]
      # @param lr [TrainingParameterPerSampleSchedule, TrainingParameterPerMinibatchSchedule]
      # @param gamma [Float]
      # @param inc [Float]
      # @param dec [Float]
      # @param max [Float]
      # @param min [Float]
      # @param multiplier [Boolean]
      # @param l1_weight [Float]
      # @param l2_weight [Float]
      # @param std_dev [Float]
      # @param threshold [Float]
      # @param truncation [Boolean]
      # @return [Learner]
      def rmsprop(parameters, lr, gamma, inc, dec, max, min,
                  multiplier: true, l1_weight: 0.0, l2_weight: 0.0,
                  std_dev: 0.0, threshold: Float::INFINITY, truncation: true)
        ga = training_parameter_schedule(std_dev, :minibatch)
        opt = create_opt(l1_weight, l2_weight, ga, threshold, truncation)
        CNTK.__rmsprop_learner__(parameters, lr, gamma, inc, dec, max, min, multiplier, opt)
      end

    end # class << self

  end # class Learner

end # module CNTK
@@ -2,25 +2,36 @@ module CNTK
2
2
  class NDArrayView
3
3
 
4
4
  def self.create(a)
5
- if a.respond_to?(:shape) and a.respond_to?(:row_major?)
6
- if a.row_major?
7
- # NDArrayView is column-major.
8
- # So we must transpose a.
9
- ta = a.transpose
5
+ if a.respond_to?(:shape)
6
+ case a
7
+ when NDArrayView
8
+ return a
9
+ when Numo::DFloat
10
+ dtype = DataType_Double
11
+ when Numo::SFloat
12
+ dtype = DataType_Float
13
+ else
14
+ raise ArgumentError, "Numo::NArray or NDArrayView expected"
10
15
  end
11
- return self.new(DataType_Double, a.shape, ta.flatten.to_a,
16
+ return self.new(dtype, a.shape, a.flatten.to_a,
12
17
  CNTK::DeviceDescriptor.default_device(), false)
13
18
  else
14
- raise "not implemented"
19
+ raise ArgumentError, "not responds to :shape"
15
20
  end
16
21
  end
17
22
 
18
23
  def to_narray
19
- ret = Numo::DFloat[*to_vec()]
20
- # NDArrayView is column-major and NArray is row-major.
21
- # So we must reverse shape and transpose it.
22
- ret = ret.reshape(*shape().reverse)
23
- return ret.transpose
24
+ case get_data_type
25
+ when DataType_Float
26
+ klass = Numo::SFloat
27
+ when DataType_Double
28
+ klass = Numo::DFloat
29
+ else
30
+ raise "unknown data type"
31
+ end
32
+ ret = klass[*to_vec()]
33
+ ret = ret.reshape(*shape().to_a)
34
+ return ret
24
35
  end
25
36
 
26
37
  end