ruby-cntk 0.1.0.pre1 → 0.1.0.pre2

Sign up to get free protection for your applications and to get access to all the features.
@@ -0,0 +1,8 @@
1
+ require "numo/narray"
2
+ module CNTK
3
+ class NDMask
4
+ def to_narray
5
+ Numo::Int8[*to_vec()].reshape(*shape())
6
+ end
7
+ end
8
+ end
data/lib/cntk/ndshape.rb CHANGED
@@ -1,10 +1,15 @@
1
1
  module CNTK
2
2
  class NDShape
3
- alias :to_ary :dimensions
4
- alias :to_a :dimensions
5
-
3
+ def to_ary
4
+ dimensions.reverse
5
+ end
6
+
7
+ def to_a
8
+ to_ary
9
+ end
10
+
6
11
  def reverse
7
- to_ary.reverse
12
+ dimensions
8
13
  end
9
14
  end
10
15
  end
data/lib/cntk/ops.rb ADDED
@@ -0,0 +1,552 @@
1
+ require "numo/narray"
2
+
3
+ module CNTK
4
+ module Ops
5
+
6
+ class << self
7
+
8
# Maps an axis given to Ops.reshape onto the axis handed to the native binding.
#
# A Numeric is first wrapped with Axis.new. After that:
# * an axis whose is_static_axis is truthy is returned unchanged;
# * an axis equal to Axis.end_static_axis becomes Axis.new(0);
# * an axis equal to Axis.new(0) becomes Axis.end_static_axis;
# * any other axis becomes Axis.new(-axis.static_axis_index).
#
# NOTE(review): the is_static_axis test runs first, so the remaining branches
# are only reached by axes for which it is falsy - confirm this ordering is
# what was intended.
#
# @param axis [Axis, Numeric]
# @return [Axis]
def reverse_reshape_axis(axis)
  axis = Axis.new(axis) if axis.is_a?(Numeric)
  if axis.is_static_axis
    axis
  elsif axis == Axis.end_static_axis
    Axis.new(0)
  elsif axis == Axis.new(0)
    Axis.end_static_axis
  else
    # Was `Axis(...)`, which is a call to an undefined method named Axis.
    Axis.new(-axis.static_axis_index)
  end
end
21
+
22
+ def reverse_dynamic_axes(axes)
23
+ axes = [axes] unless axes.is_a?(Array)
24
+ axes.each{|ax|
25
+ raise ArgumentError, "Axis expected" unless ax.is_a?(Axis)
26
+ }
27
+ axes.reverse
28
+ end
29
+
30
# Translates a pooling-type symbol into the matching CNTK constant.
#
# @param type [Symbol] :max or :average
# @return CNTK::PoolingType_Max for :max, CNTK::PoolingType_Average for :average
# @raise [ArgumentError] "unknown pooling type" for any other value
def convert_to_pooling_type(type)
  case type
  when :max     then CNTK::PoolingType_Max
  when :average then CNTK::PoolingType_Average
  else raise ArgumentError, "unknown pooling type"
  end
end
40
+
41
+ def convert_to_variable(*vars)
42
+ if vars.size == 1
43
+ convert_to_one_variable(vars[0])
44
+ else
45
+ dtype = highest_precision_type(*vars)
46
+ return vars.map{|v| convert_to_one_variable(v, dtype) }
47
+ end
48
+ end
49
+
50
+ def convert_to_one_variable(x, dtype = Numo::SFloat)
51
+ case x
52
+ when Variable
53
+ x
54
+ when Function
55
+ x.output
56
+ when Value, Numo::NArray, Numeric
57
+ Ops.constant(x)
58
+ when Array
59
+ Ops.constant( dtype[*x] )
60
+ else
61
+ raise ArgumentError, "CNTK::Variable, Numo::NArray, or Array expected"
62
+ end
63
+ end
64
+
65
+ def highest_precision_type(*args)
66
+ types = args.map{|v|
67
+ case v
68
+ when Variable
69
+ Variable::DataType[v.get_data_type]
70
+ when Numo::NArray
71
+ v.class
72
+ else
73
+ nil
74
+ end
75
+ }
76
+ if types.include?(Numo::DFloat)
77
+ Numo::DFloat
78
+ else
79
+ Numo::SFloat
80
+ end
81
+ end
82
+ end # class << self
83
+
84
+ module_function
85
+
86
+ #
87
+ # variable ops
88
+ #
89
+ def input_variable(shape, dtype: DataType_Float, needs_gradient: false,
90
+ is_sparse: false,
91
+ dynamic_axes: nil,
92
+ name: '')
93
+ if dynamic_axes
94
+ dynamic_axes = dynamic_axes.reverse
95
+ else
96
+ dynamic_axes = Axis.default_input_variable_dynamic_axes()
97
+ end
98
+ CNTK.__input_variable__(shape, is_sparse, dtype, needs_gradient, name, dynamic_axes)
99
+ end
100
+
101
+ def output_variable(shape: nil, dtype: nil, dynamic_axes: nil, name: "")
102
+ if dynamic_axes
103
+ dynamic_axes = dynamic_axes.reverse
104
+ end
105
+ CNTK.__output_variable__(shape, dtype, dynamic_axes, name)
106
+ end
107
+
108
+ def placeholder_variable(shape: NDShape.unknown.dimensions(), name: "", dynamic_axes: nil)
109
+ if dynamic_axes
110
+ dynamic_axes = dynamic_axes.reverse
111
+ else
112
+ dynamic_axes = Axis.unknown_dynamic_axes
113
+ end
114
+ CNTK.__placeholder_variable__(shape, name, dynamic_axes)
115
+ end
116
+
117
+ def constant(*args)
118
+ val = args[0]
119
+ if val.is_a?(Array)
120
+ args[0] = Numo::SFloat[*val]
121
+ end
122
+ Constant.create(*args)
123
+ end
124
+
125
+ def parameter(*args)
126
+ Parameter.create(*args)
127
+ end
128
+
129
+ #
130
+ # ops
131
+ #
132
+ def alias(x, name="")
133
+ x = Ops.convert_to_variable( x )
134
+ CNTK.__alias__(x, name)
135
+ end
136
+
137
+ def weighted_binary_cross_entropy(output, target, weight, name="")
138
+ output = Ops.convert_to_variable( output )
139
+ target = Ops.convert_to_variable( target )
140
+ weight = Ops.convert_to_variable( weight )
141
+ CNTK.__weighted_binary_cross_entropy__(output, target, weight, name)
142
+ end
143
+
144
+ def cross_entropy_with_softmax(output, target, axis=0, name="")
145
+ output = Ops.convert_to_variable( output )
146
+ target = Ops.convert_to_variable( target )
147
+ axis = Axis.from_num(axis)
148
+ CNTK.__cross_entropy_with_softmax__(output, target, axis, name)
149
+ end
150
+
151
+ def combine(array, name="")
152
+ a = array.map{|x| Ops.convert_to_variable( x ) }
153
+ CNTK.__combine__(a, name)
154
+ end
155
+
156
+ def convolution(kernel: nil, input: nil, strides: [1], sharing: [true],
157
+ padding: [false], lower_pad: [0], upper_pad: [0],
158
+ transpose: false, max_temp_mem_size_in_samples: 0, name: "")
159
+ kernel = Ops.convert_to_variable( kernel )
160
+ input = Ops.convert_to_variable( input )
161
+ CNTK.__convolution__(kernel, input, strides, sharing, padding, lower_pad, upper_pad,
162
+ transpose, max_temp_mem_size_in_samples, name)
163
+ end
164
+
165
+
166
+ # CNTK's NDArray is column-major.
167
+ # Keep this in mind when specifying rois.
168
+ # y
169
+ # __________
170
+ # |
171
+ # x|
172
+ # |
173
+ def roipooling(x, rois, shape, name="")
174
+ x, rois = Ops.convert_to_variable( x, rois )
175
+ CNTK.__roipooling__(x, rois, shape, name)
176
+ end
177
+
178
# Calls the native CNTK.__pooling__ binding.
#
# @param x operand; converted with Ops.convert_to_variable
# @param type [Symbol] :max or :average
# @param shape forwarded unchanged to the binding
# @param strides forwarded unchanged (default [1])
# @param padding forwarded unchanged (default [false])
# @param lower_pad forwarded unchanged (default [0])
# @param upper_pad forwarded unchanged (default [0])
# @param name [String] forwarded unchanged (default "")
# @raise [ArgumentError] "unknown pooling type" when +type+ is neither :max nor :average
def pooling(x, type, shape, strides: [1], padding: [false],
            lower_pad: [0], upper_pad: [0], name: "")
  x = Ops.convert_to_variable( x )
  # Use the same symbol-to-constant helper as unpooling instead of a private copy of the case.
  type = Ops.convert_to_pooling_type( type )
  CNTK.__pooling__(x, type, shape, strides, padding, lower_pad, upper_pad, name)
end
191
+
192
+ def unpooling(operand, input, type, shape, strides: [1], padding: [false],
193
+ lower_pad: [0], upper_pad: [0], name: "")
194
+ operand, input = Ops.convert_to_variable( operand, input )
195
+ type = Ops.convert_to_pooling_type( type )
196
+ CNTK.__unpooling__(operand, input, type, shape, strides, padding, lower_pad, upper_pad, name)
197
+ end
198
+
199
+ def batch_normalization(x, scale: nil, bias: nil, mean: nil, variance: nil, spatial: false,
200
+ normalization_time_constant: 5000, blend_time_constant: 0,
201
+ epsilon: 0.00001, use_cudnn_engine: false, name: "", running_count: 0)
202
+ x, scale, bias, mean, variance, running_count =
203
+ Ops.convert_to_variable( x, scale, bias, mean, variance, running_count )
204
+ CNTK.__batch_normalization__(x, scale, bias, mean, variance, running_count, spatial,
205
+ normalization_time_constant, blend_time_constant,
206
+ epsilon, use_cudnn_engine, name)
207
+ end
208
+
209
+ def times(left, right, output_rank: 1, infer_input_rank_to_map: -1, name: "")
210
+ left, right = Ops.convert_to_variable( left, right )
211
+ # Swap the operand order because CNTK is column-major.
212
+ CNTK.__times__(right, left, output_rank, infer_input_rank_to_map, name)
213
+ end
214
+
215
# Calls the native CNTK.__transpose_times__ binding.
#
# Both operands go through Ops.convert_to_variable and are handed to the
# binding in swapped order (right, left), the same way `times` does.
#
# @param left first operand
# @param right second operand
# @param output_rank forwarded unchanged (default 1)
# @param name [String] forwarded unchanged (default "")
def times_transpose(left, right, output_rank = 1, name="")
  left, right = Ops.convert_to_variable( left, right )
  # Pass the caller's name; the previous `name=""` in the argument list
  # reassigned the local and always sent an empty string.
  CNTK.__transpose_times__(right, left, output_rank, name)
end
219
+
220
+ def clip(x, min, max, name="")
221
+ x, min, max = Ops.convert_to_variable( x, min, max )
222
+ CNTK.__clip__(x, min, max, name)
223
+ end
224
+
225
+ def element_select(x, if_true, if_else, name="")
226
+ x, if_true, if_else = Ops.convert_to_variable( x, if_true, if_else )
227
+ CNTK.__element_select__(x, if_true, if_else, name)
228
+ end
229
+
230
+ def future_value(x, init=0, time_step=1, name="")
231
+ x, init = Ops.convert_to_variable( x, init )
232
+ CNTK.__future_value__(x, init, time_step, name)
233
+ end
234
+
235
+ def past_value(x, init=0, time_step=1, name="")
236
+ x, init = Ops.convert_to_variable( x, init )
237
+ CNTK.__past_value__(x, init, time_step, name)
238
+ end
239
+
240
+ def reshape(x, shape, begin_axis=Axis.new(0), end_axis=Axis.end_static_axis(), name="")
241
+ x = Ops.convert_to_variable( x )
242
+ begin_axis = Ops.reverse_reshape_axis(begin_axis)
243
+ end_axis = Ops.reverse_reshape_axis(end_axis )
244
+ CNTK.__reshape__(x, shape, begin_axis, end_axis, name)
245
+ end
246
+
247
+ def transpose(x, axis1=0, axis2=1, name="")
247
+ x = Ops.convert_to_variable( x )
248
+ unless axis1.abs <= x.shape.rank and axis2.abs <= x.shape.rank # NOTE(review): `<=` also admits an index equal to the rank - confirm whether `<` was intended for non-negative indices
249
+ raise ArgumentError, "out of bounds"
250
+ end
251
+ axis1 = Axis.from_num(axis1)
252
+ axis2 = Axis.from_num(axis2)
253
+ CNTK.__transpose_axes__(x, axis1, axis2, name)
254
+ end
256
+
257
+ def slice(x, axis, begin_index, end_index, name="")
258
+ x = Ops.convert_to_variable( x )
259
+ axis = Axis.from_num(axis)
260
+ CNTK.__slice__(x, axis, begin_index, end_index, name)
261
+ end
262
+
263
+ def splice(x, axis=-1, name="")
264
+ x = x.map{|var| Ops.convert_to_variable( var ) }
265
+ axis = Axis.from_num(axis)
266
+ CNTK.__splice__(x, axis, name)
267
+ end
268
+
269
+ def reduce_sum(x, axis=nil, name="")
270
+ x = Ops.convert_to_variable( x )
271
+ axis = Axis.from_num(axis)
272
+ CNTK.__reduce_sum__(x, axis, name)
273
+ end
274
+
275
+ def reduce_log_sum_exp(x, axis=nil, name="")
276
+ x = Ops.convert_to_variable( x )
277
+ axis = Axis.from_num(axis)
278
+ CNTK.__reduce_log_sum__(x, axis, name)
279
+ end
280
+
281
+ def reduce_mean(x, axis=nil, name="")
282
+ x = Ops.convert_to_variable( x )
283
+ axis = Axis.from_num(axis)
284
+ CNTK.__reduce_mean__(x, axis, name)
285
+ end
286
+
287
+ def reduce_max(x, axis=nil, name="")
288
+ x = Ops.convert_to_variable( x )
289
+ axis = Axis.from_num(axis)
290
+ CNTK.__reduce_max__(x, axis, name)
291
+ end
292
+
293
+ def reduce_min(x, axis=nil, name="")
294
+ x = Ops.convert_to_variable( x )
295
+ axis = Axis.from_num(axis)
296
+ CNTK.__reduce_min__(x, axis, name)
297
+ end
298
+
299
+ def reduce_prod(x, axis=nil, name="")
300
+ x = Ops.convert_to_variable( x )
301
+ axis = Axis.from_num(axis)
302
+ CNTK.__reduce_prod__(x, axis, name)
303
+ end
304
+
305
+ def random_sample(weights, num_samples, allow_dup, name="")
306
+ weights = Ops.convert_to_variable( weights )
307
+ CNTK.__random_sample__(weights, num_samples, allow_dup, name)
308
+ end
309
+
310
+ def random_sample_inclusion_frequency(weights, num_samples, allow_dup, name="")
311
+ weights = Ops.convert_to_variable( weights )
312
+ CNTK.__random_sample_inclusion_frequency__(weights, num_samples, allow_dup, name)
313
+ end
314
+
315
+ def dropout(x, rate=0.0, name="")
316
+ if rate < 0 or rate >= 1
317
+ raise ArgumentError, "dropout_rate must be in the interval [0,1)"
318
+ end
319
+ x = Ops.convert_to_variable( x )
320
+ CNTK.__dropout__(x, rate, name)
321
+ end
322
+
323
+ # FIXME
324
+ def lambda_rank(output, gain, group, name="")
325
+ output, gain, group = Ops.convert_to_variable( output, gain, group )
326
+ CNTK.__lambda_rank__(output, gain, group, name)
327
+ end
328
+
329
+ # FIXME
330
+ def ndcg_at_1(output, gain, group, name="")
331
+ output, gain, group = Ops.convert_to_variable( output, gain, group )
332
+ CNTK.__ndcgat1__(output, gain, group, name)
333
+ end
334
+
335
+ def classification_error(output, target, axis=-1, topN=1, name="")
336
+ output, target = Ops.convert_to_variable( output, target )
337
+ axis = Axis::from_num(axis)
338
+ CNTK.__classification_error__(output, target, topN, axis, name)
339
+ end
340
+
341
+ # FIXME
342
+ def edit_distance_error(input_a, input_b, subPen=0, delPen=0, insPen=0,
343
+ squashInputs=false, samplesToIgnore=[], name='')
344
+ input_a = Ops.convert_to_variable( input_a )
345
+ input_b = Ops.convert_to_variable( input_b )
346
+ CNTK.__edit_distance_error__(input_a, input_b, subPen, delPen, insPen, squashInputs, samplesToIgnore, name)
347
+ end
348
+
349
+ def negate(x=nil, name: "")
350
+ x = x || Ops.placeholder_variable(name: "x")
351
+ x = Ops.convert_to_variable( x )
352
+ CNTK.__negate__(x, name)
353
+ end
354
+
355
+ def sigmoid(x=nil, name: "")
356
+ x = x || Ops.placeholder_variable(name: "x")
357
+ x = Ops.convert_to_variable( x )
358
+ CNTK.__sigmoid__(x, name)
359
+ end
360
+
361
+ def tanh(x=nil, name: "")
362
+ x = x || Ops.placeholder_variable(name: "x")
363
+ x = Ops.convert_to_variable( x )
364
+ CNTK.__tanh__(x, name)
365
+ end
366
+
367
+ def sin(x=nil, name: "")
368
+ x = x || Ops.placeholder_variable(name: "x")
369
+ x = Ops.convert_to_variable( x )
370
+ CNTK.__sin__(x, name)
371
+ end
372
+
373
+ def cos(x=nil, name: "")
374
+ x = x || Ops.placeholder_variable(name: "x")
375
+ x = Ops.convert_to_variable( x )
376
+ CNTK.__cos__(x, name)
377
+ end
378
+
379
+ def relu(x=nil, name: "")
380
+ x = x || Ops.placeholder_variable(name: "x")
381
+ x = Ops.convert_to_variable( x )
382
+ CNTK.__re_lu__(x, name)
383
+ end
384
+
385
+ def exp(x=nil, name: "")
386
+ x = x || Ops.placeholder_variable(name: "x")
387
+ x = Ops.convert_to_variable( x )
388
+ CNTK.__exp__(x, name)
389
+ end
390
+
391
+ def log(x=nil, name: "")
392
+ x = x || Ops.placeholder_variable(name: "x")
393
+ x = Ops.convert_to_variable( x )
394
+ CNTK.__log__(x, name)
395
+ end
396
+
397
+ def square(x=nil, name: "")
398
+ x = x || Ops.placeholder_variable(name: "x")
399
+ x = Ops.convert_to_variable( x )
400
+ CNTK.__square__(x, name)
401
+ end
402
+
403
+ def sqrt(x=nil, name: "")
404
+ x = x || Ops.placeholder_variable(name: "x")
405
+ x = Ops.convert_to_variable( x )
406
+ CNTK.__sqrt__(x, name)
407
+ end
408
+
409
+ def round(x=nil, name: "")
410
+ x = x || Ops.placeholder_variable(name: "x")
411
+ x = Ops.convert_to_variable( x )
412
+ CNTK.__round__(x, name)
413
+ end
414
+
415
+ def floor(x=nil, name: "")
416
+ x = x || Ops.placeholder_variable(name: "x")
417
+ x = Ops.convert_to_variable( x )
418
+ CNTK.__floor__(x, name)
419
+ end
420
+
421
+ def ceil(x=nil, name: "")
422
+ x = x || Ops.placeholder_variable(name: "x")
423
+ x = Ops.convert_to_variable( x )
424
+ CNTK.__ceil__(x, name)
425
+ end
426
+
427
+ def reciprocal(x=nil, name: "")
428
+ x = x || Ops.placeholder_variable(name: "x")
429
+ x = Ops.convert_to_variable( x )
430
+ CNTK.__reciprocal__(x, name)
431
+ end
432
+
433
+ def softmax(x=nil, name: "")
434
+ x = x || Ops.placeholder_variable(name: "x")
435
+ x = Ops.convert_to_variable( x )
436
+ CNTK.__softmax__(x, name)
437
+ end
438
+
439
+ def hardmax(x=nil, name: "")
440
+ x = x || Ops.placeholder_variable(name: "x")
441
+ x = Ops.convert_to_variable( x )
442
+ CNTK.__hardmax__(x, name)
443
+ end
444
+
445
+ def plus(x=nil, y=nil, name: "")
446
+ x = x || Ops.placeholder_variable(name: "x")
447
+ y = y || Ops.placeholder_variable(name: "y")
448
+ x, y = Ops.convert_to_variable( x, y )
449
+ CNTK.__plus__(x, y, name)
450
+ end
451
+
452
+ def minus(x=nil, y=nil, name: "")
453
+ x = x || Ops.placeholder_variable(name: "x")
454
+ y = y || Ops.placeholder_variable(name: "y")
455
+ x, y = Ops.convert_to_variable( x, y )
456
+ CNTK.__minus__(x, y, name)
457
+ end
458
+
459
+ def log_add_exp(x=nil, y=nil, name: "")
460
+ x = x || Ops.placeholder_variable(name: "x")
461
+ y = y || Ops.placeholder_variable(name: "y")
462
+ x, y = Ops.convert_to_variable( x, y )
463
+ CNTK.__log_add_exp__(x, y, name)
464
+ end
465
+
466
# Element-wise absolute value of a single operand.
#
# The previous body was a copy of the two-operand template: it created a
# placeholder for +y+ and called `CNTK.abs(x, y, name)`.
#
# NOTE(review): the binding name `__abs__` is assumed from the `__<op>__`
# pattern used by every other wrapper in this module - confirm against the
# generated extension.
#
# @param x operand; defaults to a placeholder variable named "x"
# @param y accepted only for backward compatibility and ignored
# @param name [String] forwarded unchanged (default "")
def abs(x=nil, y=nil, name: "")
  x = x || Ops.placeholder_variable(name: "x")
  x = Ops.convert_to_variable( x )
  CNTK.__abs__(x, name)
end
472
+
473
+ def element_times(x=nil, y=nil, name: "")
474
+ x = x || Ops.placeholder_variable(name: "x")
475
+ y = y || Ops.placeholder_variable(name: "y")
476
+ x, y = Ops.convert_to_variable( x, y )
477
+ CNTK.__element_times__(x, y, name)
478
+ end
479
+
480
+ def element_divide(x=nil, y=nil, name: "")
481
+ x = x || Ops.placeholder_variable(name: "x")
482
+ y = y || Ops.placeholder_variable(name: "y")
483
+ x, y = Ops.convert_to_variable( x, y )
484
+ CNTK.__element_divide__(x, y, name)
485
+ end
486
+
487
+ def equal(x=nil, y=nil, name: "")
488
+ x = x || Ops.placeholder_variable(name: "x")
489
+ y = y || Ops.placeholder_variable(name: "y")
490
+ x, y = Ops.convert_to_variable( x, y )
491
+ CNTK.__equal__(x, y, name)
492
+ end
493
+
494
+ def not_equal(x=nil, y=nil, name: "")
495
+ x = x || Ops.placeholder_variable(name: "x")
496
+ y = y || Ops.placeholder_variable(name: "y")
497
+ x, y = Ops.convert_to_variable( x, y )
498
+ CNTK.__not_equal__(x, y, name)
499
+ end
500
+
501
+ def less(x=nil, y=nil, name: "")
502
+ x = x || Ops.placeholder_variable(name: "x")
503
+ y = y || Ops.placeholder_variable(name: "y")
504
+ x, y = Ops.convert_to_variable( x, y )
505
+ CNTK.__less__(x, y, name)
506
+ end
507
+
508
+ def less_equal(x=nil, y=nil, name: "")
509
+ x = x || Ops.placeholder_variable(name: "x")
510
+ y = y || Ops.placeholder_variable(name: "y")
511
+ x, y = Ops.convert_to_variable( x, y )
512
+ CNTK.__less_equal__(x, y, name)
513
+ end
514
+
515
+ def greater(x=nil, y=nil, name: "")
516
+ x = x || Ops.placeholder_variable(name: "x")
517
+ y = y || Ops.placeholder_variable(name: "y")
518
+ x, y = Ops.convert_to_variable( x, y )
519
+ CNTK.__greater__(x, y, name)
520
+ end
521
+
522
+ def greater_equal(x=nil, y=nil, name: "")
523
+ x = x || Ops.placeholder_variable(name: "x")
524
+ y = y || Ops.placeholder_variable(name: "y")
525
+ x, y = Ops.convert_to_variable( x, y )
526
+ CNTK.__greater_equal__(x, y, name)
527
+ end
528
+
529
+ def cosine_distance(x=nil, y=nil, name: "")
530
+ x = x || Ops.placeholder_variable(name: "x")
531
+ y = y || Ops.placeholder_variable(name: "y")
532
+ x, y = Ops.convert_to_variable( x, y )
533
+ CNTK.__cosine_distance__(x, y, name)
534
+ end
535
+
536
+ def binary_cross_entropy(x=nil, y=nil, name: "")
537
+ x = x || Ops.placeholder_variable(name: "x")
538
+ y = y || Ops.placeholder_variable(name: "y")
539
+ x, y = Ops.convert_to_variable( x, y )
540
+ CNTK.__binary_cross_entropy__(x, y, name)
541
+ end
542
+
543
+ def squared_error(x=nil, y=nil, name: "")
544
+ x = x || Ops.placeholder_variable(name: "x")
545
+ y = y || Ops.placeholder_variable(name: "y")
546
+ x, y = Ops.convert_to_variable( x, y )
547
+ CNTK.__squared_error__(x, y, name)
548
+ end
549
+
550
+
551
+ end # module Ops
552
+ end # module CNTK
@@ -0,0 +1,59 @@
1
+ module CNTK
2
+ class Trainer
3
+
4
+ # @param args [Hash<Variable,MinibatchData>]
5
+ # @option opt [Array<Variable>] :outputs
6
+ # @option opt [DeviceDescriptor] :device
7
+ def train_minibatch(args, outputs: nil, device: DeviceDescriptor.use_default_device)
8
+ if outputs
9
+ out = StdUMapVariableValue.new()
10
+ outputs.each{|out_var|
11
+ # By setting nullptr, the Forward function implemented in C++ will allocate a Value object with the required storage.
12
+ out.__set_nullptr__(out_var)
13
+ }
14
+ updated = __train_minibatchdata__(args, out, device)
15
+ return [updated, out]
16
+ else
17
+ __train_minibatchdata__(args, device)
18
+ end
19
+ end
20
+
21
+ # @param args
22
+ # @return [Float]
23
+ def test_minibatch(args, device: DeviceDescriptor.use_default_device)
24
+ __test_minibatchdata__(args, device)
25
+ end
26
+
27
+ class << self
28
+
29
+ def create(model: nil, loss: nil, evaluation: nil, learners: nil)
30
+ unless model and loss and learners
31
+ raise ArgumentError, "model, loss function, and learners needed"
32
+ end
33
+ model = variable_to_function(model)
34
+ loss = variable_to_function(loss)
35
+ evaluation = variable_to_function(evaluation) if evaluation
36
+ learners = [learners] unless learners.is_a?(Array)
37
+ if evaluation
38
+ CNTK.__create_trainer__(model, loss, evaluation, learners)
39
+ else
40
+ CNTK.__create_trainer__(model, loss, learners)
41
+ end
42
+ end
43
+
44
+ private
45
+
46
+ def variable_to_function(x)
47
+ case x
48
+ when Function
49
+ x
50
+ when Variable
51
+ CNTK::Ops.combine([x])
52
+ else
53
+ raise ArgumentError
54
+ end
55
+ end
56
+
57
+ end # class << self
58
+ end # class Trainer
59
+ end # module CNTK
data/lib/cntk/value.rb CHANGED
@@ -1,8 +1,14 @@
1
1
  module CNTK
2
2
  class Value
3
3
 
4
- def self.create(a)
5
- new(NDArrayView.create(a))
4
+ def self.create(variable, data, seq_starts=[], device=DeviceDescriptor.use_default_device, read_only=false)
5
+ if variable.dynamic_axes.size == 0
6
+ ndav = NDArrayView.new(data, DeviceDescriptor.cpudevice)
7
+ new(ndav)
8
+ else
9
+ ndav = data.map{|a| NDArrayView.create(a) }
10
+ __create__(variable.shape, ndav, seq_starts, device, read_only)
11
+ end
6
12
  end
7
13
 
8
14
  def to_narray
@@ -11,7 +17,7 @@ module CNTK
11
17
 
12
18
  def reshape(a)
13
19
  na = to_narray().reshape(*a)
14
- self.class.create(na)
20
+ self.class.new(NDArrayView.create(na))
15
21
  end
16
22
 
17
23
  end