ruby-cntk 0.1.0.pre1 → 0.1.0.pre2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,8 @@
1
+ require "numo/narray"
2
module CNTK
  class NDMask
    # Returns the mask as a Numo::Int8 array.
    #
    # The flat element list from #to_vec is loaded into a Numo::Int8 and then
    # reshaped to the dimensions reported by #shape.
    #
    # @return [Numo::Int8]
    def to_narray
      flat = to_vec
      mask = Numo::Int8[*flat]
      mask.reshape(*shape)
    end
  end
end
data/lib/cntk/ndshape.rb CHANGED
@@ -1,10 +1,15 @@
1
1
  module CNTK
2
2
  class NDShape
3
- alias :to_ary :dimensions
4
- alias :to_a :dimensions
5
-
3
+ def to_ary
4
+ dimensions.reverse
5
+ end
6
+
7
+ def to_a
8
+ to_ary
9
+ end
10
+
6
11
  def reverse
7
- to_ary.reverse
12
+ dimensions
8
13
  end
9
14
  end
10
15
  end
data/lib/cntk/ops.rb ADDED
@@ -0,0 +1,552 @@
1
+ require "numo/narray"
2
+
3
module CNTK
  # Graph-building operations exposed to Ruby.
  #
  # Each public op normalizes its inputs with Ops.convert_to_variable, which
  # accepts Variables, Functions, Values, Numo arrays, Ruby Arrays and numbers.
  # It then calls the matching low-level binding CNTK.__xxx__. CNTK's NDArray
  # is column-major, so some ops reorder operands or axes before that call.
  module Ops

    class << self

      # Translates one static reshape axis from Ruby's numbering to CNTK's
      # reversed (column-major) numbering.
      #
      # Non-static axes pass through unchanged. Among static axes,
      # Axis.new(0) and Axis.end_static_axis trade places, and any other
      # index i becomes Axis.new(-i), i.e. it is counted from the other end.
      #
      # @param axis [Axis, Numeric] a number is wrapped with Axis.new first
      # @return [Axis]
      def reverse_reshape_axis(axis)
        axis = Axis.new(axis) if axis.is_a?(Numeric)
        # Only static axes have a position that depends on memory order.
        return axis unless axis.is_static_axis

        if axis == Axis.end_static_axis
          Axis.new(0)
        elsif axis == Axis.new(0)
          Axis.end_static_axis
        else
          Axis.new(-axis.static_axis_index)
        end
      end

      # Validates a list of dynamic axes and returns it reversed.
      #
      # @param axes [Axis, Array<Axis>] a single Axis is wrapped in an Array
      # @return [Array<Axis>]
      # @raise [ArgumentError] if any element is not an Axis
      def reverse_dynamic_axes(axes)
        axes = [axes] unless axes.is_a?(Array)
        axes.each do |ax|
          raise ArgumentError, "Axis expected" unless ax.is_a?(Axis)
        end
        axes.reverse
      end

      # Maps :max / :average to the corresponding CNTK::PoolingType_* constant.
      #
      # @param type [Symbol]
      # @raise [ArgumentError] for any other value
      def convert_to_pooling_type(type)
        case type
        when :max then CNTK::PoolingType_Max
        when :average then CNTK::PoolingType_Average
        else raise ArgumentError, "unknown pooling type"
        end
      end

      # Converts op inputs to CNTK Variables.
      #
      # With one argument the converted Variable is returned. With several,
      # all are converted using one common element type (see
      # highest_precision_type) and an Array is returned so callers can
      # destructure it.
      def convert_to_variable(*vars)
        return convert_to_one_variable(vars[0]) if vars.size == 1

        dtype = highest_precision_type(*vars)
        vars.map { |v| convert_to_one_variable(v, dtype) }
      end

      # Converts a single op input to a Variable.
      #
      # Variables pass through unchanged. Functions yield their #output.
      # Values, Numo arrays and numbers become constants. Ruby Arrays become
      # constants of element type +dtype+.
      #
      # @raise [ArgumentError] for any other input (including nil)
      def convert_to_one_variable(x, dtype = Numo::SFloat)
        case x
        when Variable
          x
        when Function
          x.output
        when Value, Numo::NArray, Numeric
          Ops.constant(x)
        when Array
          Ops.constant(dtype[*x])
        else
          raise ArgumentError, "CNTK::Variable, Numo::NArray, or Array expected"
        end
      end

      # Returns Numo::DFloat if any argument is a Variable whose data type
      # maps to Numo::DFloat, or is itself a Numo::DFloat. Otherwise returns
      # Numo::SFloat.
      def highest_precision_type(*args)
        types = args.map do |v|
          case v
          when Variable then Variable::DataType[v.get_data_type]
          when Numo::NArray then v.class
          end
        end
        types.include?(Numo::DFloat) ? Numo::DFloat : Numo::SFloat
      end
    end # class << self

    module_function

    #
    # variable ops
    #

    # Creates an input variable. A given +dynamic_axes+ list is passed to
    # CNTK reversed; otherwise Axis.default_input_variable_dynamic_axes is
    # used.
    def input_variable(shape, dtype: DataType_Float, needs_gradient: false,
                       is_sparse: false,
                       dynamic_axes: nil,
                       name: '')
      dynamic_axes = if dynamic_axes
                       dynamic_axes.reverse
                     else
                       Axis.default_input_variable_dynamic_axes
                     end
      CNTK.__input_variable__(shape, is_sparse, dtype, needs_gradient, name, dynamic_axes)
    end

    # Creates an output variable. A given +dynamic_axes+ list is reversed.
    def output_variable(shape: nil, dtype: nil, dynamic_axes: nil, name: "")
      dynamic_axes = dynamic_axes.reverse if dynamic_axes
      CNTK.__output_variable__(shape, dtype, dynamic_axes, name)
    end

    # Creates a placeholder variable. Its shape defaults to the unknown shape
    # and its dynamic axes to Axis.unknown_dynamic_axes. A given
    # +dynamic_axes+ list is reversed.
    def placeholder_variable(shape: NDShape.unknown.dimensions, name: "", dynamic_axes: nil)
      dynamic_axes = if dynamic_axes
                       dynamic_axes.reverse
                     else
                       Axis.unknown_dynamic_axes
                     end
      CNTK.__placeholder_variable__(shape, name, dynamic_axes)
    end

    # Creates a constant. A Ruby Array first argument becomes a Numo::SFloat.
    # All arguments are forwarded to Constant.create.
    def constant(*args)
      args[0] = Numo::SFloat[*args[0]] if args[0].is_a?(Array)
      Constant.create(*args)
    end

    # Creates a parameter; all arguments are forwarded to Parameter.create.
    def parameter(*args)
      Parameter.create(*args)
    end

    #
    # ops
    #

    def alias(x, name = "")
      x = Ops.convert_to_variable(x)
      CNTK.__alias__(x, name)
    end

    # Operands are converted together so they share one element type.
    def weighted_binary_cross_entropy(output, target, weight, name = "")
      output, target, weight = Ops.convert_to_variable(output, target, weight)
      CNTK.__weighted_binary_cross_entropy__(output, target, weight, name)
    end

    # Operands are converted together so they share one element type.
    def cross_entropy_with_softmax(output, target, axis = 0, name = "")
      output, target = Ops.convert_to_variable(output, target)
      axis = Axis.from_num(axis)
      CNTK.__cross_entropy_with_softmax__(output, target, axis, name)
    end

    # Combines several outputs into one Function.
    def combine(array, name = "")
      vars = array.map { |x| Ops.convert_to_variable(x) }
      CNTK.__combine__(vars, name)
    end

    def convolution(kernel: nil, input: nil, strides: [1], sharing: [true],
                    padding: [false], lower_pad: [0], upper_pad: [0],
                    transpose: false, max_temp_mem_size_in_samples: 0, name: "")
      kernel = Ops.convert_to_variable(kernel)
      input = Ops.convert_to_variable(input)
      CNTK.__convolution__(kernel, input, strides, sharing, padding, lower_pad, upper_pad,
                           transpose, max_temp_mem_size_in_samples, name)
    end

    # CNTK's NDArray is column-major.
    # So to specify rois, remember it.
    #     y
    #  __________
    #  |
    # x|
    #  |
    def roipooling(x, rois, shape, name = "")
      x, rois = Ops.convert_to_variable(x, rois)
      CNTK.__roipooling__(x, rois, shape, name)
    end

    # +type+ is :max or :average (see Ops.convert_to_pooling_type).
    def pooling(x, type, shape, strides: [1], padding: [false],
                lower_pad: [0], upper_pad: [0], name: "")
      x = Ops.convert_to_variable(x)
      type = Ops.convert_to_pooling_type(type)
      CNTK.__pooling__(x, type, shape, strides, padding, lower_pad, upper_pad, name)
    end

    def unpooling(operand, input, type, shape, strides: [1], padding: [false],
                  lower_pad: [0], upper_pad: [0], name: "")
      operand, input = Ops.convert_to_variable(operand, input)
      type = Ops.convert_to_pooling_type(type)
      CNTK.__unpooling__(operand, input, type, shape, strides, padding, lower_pad, upper_pad, name)
    end

    def batch_normalization(x, scale: nil, bias: nil, mean: nil, variance: nil, spatial: false,
                            normalization_time_constant: 5000, blend_time_constant: 0,
                            epsilon: 0.00001, use_cudnn_engine: false, name: "", running_count: 0)
      x, scale, bias, mean, variance, running_count =
        Ops.convert_to_variable(x, scale, bias, mean, variance, running_count)
      CNTK.__batch_normalization__(x, scale, bias, mean, variance, running_count, spatial,
                                   normalization_time_constant, blend_time_constant,
                                   epsilon, use_cudnn_engine, name)
    end

    def times(left, right, output_rank: 1, infer_input_rank_to_map: -1, name: "")
      left, right = Ops.convert_to_variable(left, right)
      # Operands are swapped because CNTK is column-major.
      CNTK.__times__(right, left, output_rank, infer_input_rank_to_map, name)
    end

    def times_transpose(left, right, output_rank = 1, name = "")
      left, right = Ops.convert_to_variable(left, right)
      # Operands are swapped because CNTK is column-major; +name+ is forwarded.
      CNTK.__transpose_times__(right, left, output_rank, name)
    end

    def clip(x, min, max, name = "")
      x, min, max = Ops.convert_to_variable(x, min, max)
      CNTK.__clip__(x, min, max, name)
    end

    def element_select(x, if_true, if_else, name = "")
      x, if_true, if_else = Ops.convert_to_variable(x, if_true, if_else)
      CNTK.__element_select__(x, if_true, if_else, name)
    end

    def future_value(x, init = 0, time_step = 1, name = "")
      x, init = Ops.convert_to_variable(x, init)
      CNTK.__future_value__(x, init, time_step, name)
    end

    def past_value(x, init = 0, time_step = 1, name = "")
      x, init = Ops.convert_to_variable(x, init)
      CNTK.__past_value__(x, init, time_step, name)
    end

    # Reshapes the static axes in the half-open range [begin_axis, end_axis)
    # (Ruby numbering) to +shape+. The default covers all static axes.
    def reshape(x, shape, begin_axis = Axis.new(0), end_axis = Axis.end_static_axis, name = "")
      x = Ops.convert_to_variable(x)
      # CNTK numbers static axes from the other end, so the mirrored range's
      # start comes from end_axis and its end from begin_axis.
      cntk_begin = Ops.reverse_reshape_axis(end_axis)
      cntk_end = Ops.reverse_reshape_axis(begin_axis)
      CNTK.__reshape__(x, shape, cntk_begin, cntk_end, name)
    end

    # Swaps axes +axis1+ and +axis2+ of +x+.
    #
    # @raise [ArgumentError] unless both axes lie in -rank...rank
    def transpose(x, axis1 = 0, axis2 = 1, name = "")
      x = Ops.convert_to_variable(x)
      valid = (-x.shape.rank...x.shape.rank)
      unless valid.cover?(axis1) && valid.cover?(axis2)
        raise ArgumentError, "out of bounds"
      end
      axis1 = Axis.from_num(axis1)
      axis2 = Axis.from_num(axis2)
      CNTK.__transpose_axes__(x, axis1, axis2, name)
    end

    def slice(x, axis, begin_index, end_index, name = "")
      x = Ops.convert_to_variable(x)
      axis = Axis.from_num(axis)
      CNTK.__slice__(x, axis, begin_index, end_index, name)
    end

    # Concatenates the Variables in +x+ along +axis+.
    def splice(x, axis = -1, name = "")
      x = x.map { |var| Ops.convert_to_variable(var) }
      axis = Axis.from_num(axis)
      CNTK.__splice__(x, axis, name)
    end

    #
    # Reductions: +axis+ is converted with Axis.from_num (including nil).
    #

    def reduce_sum(x, axis = nil, name = "")
      x = Ops.convert_to_variable(x)
      axis = Axis.from_num(axis)
      CNTK.__reduce_sum__(x, axis, name)
    end

    # Forwards to the __reduce_log_sum__ binding.
    def reduce_log_sum_exp(x, axis = nil, name = "")
      x = Ops.convert_to_variable(x)
      axis = Axis.from_num(axis)
      CNTK.__reduce_log_sum__(x, axis, name)
    end

    def reduce_mean(x, axis = nil, name = "")
      x = Ops.convert_to_variable(x)
      axis = Axis.from_num(axis)
      CNTK.__reduce_mean__(x, axis, name)
    end

    def reduce_max(x, axis = nil, name = "")
      x = Ops.convert_to_variable(x)
      axis = Axis.from_num(axis)
      CNTK.__reduce_max__(x, axis, name)
    end

    def reduce_min(x, axis = nil, name = "")
      x = Ops.convert_to_variable(x)
      axis = Axis.from_num(axis)
      CNTK.__reduce_min__(x, axis, name)
    end

    def reduce_prod(x, axis = nil, name = "")
      x = Ops.convert_to_variable(x)
      axis = Axis.from_num(axis)
      CNTK.__reduce_prod__(x, axis, name)
    end

    def random_sample(weights, num_samples, allow_dup, name = "")
      weights = Ops.convert_to_variable(weights)
      CNTK.__random_sample__(weights, num_samples, allow_dup, name)
    end

    def random_sample_inclusion_frequency(weights, num_samples, allow_dup, name = "")
      weights = Ops.convert_to_variable(weights)
      CNTK.__random_sample_inclusion_frequency__(weights, num_samples, allow_dup, name)
    end

    # @raise [ArgumentError] unless 0 <= rate < 1
    def dropout(x, rate = 0.0, name = "")
      if rate < 0 || rate >= 1
        raise ArgumentError, "dropout_rate must be in the interval [0,1)"
      end
      x = Ops.convert_to_variable(x)
      CNTK.__dropout__(x, rate, name)
    end

    # FIXME: carried over from upstream; the outstanding issue is not recorded.
    def lambda_rank(output, gain, group, name = "")
      output, gain, group = Ops.convert_to_variable(output, gain, group)
      CNTK.__lambda_rank__(output, gain, group, name)
    end

    # FIXME: carried over from upstream; the outstanding issue is not recorded.
    def ndcg_at_1(output, gain, group, name = "")
      output, gain, group = Ops.convert_to_variable(output, gain, group)
      CNTK.__ndcgat1__(output, gain, group, name)
    end

    # Note the binding takes +topN+ before +axis+.
    def classification_error(output, target, axis = -1, topN = 1, name = "")
      output, target = Ops.convert_to_variable(output, target)
      axis = Axis.from_num(axis)
      CNTK.__classification_error__(output, target, topN, axis, name)
    end

    # FIXME: carried over from upstream; the outstanding issue is not recorded.
    def edit_distance_error(input_a, input_b, subPen = 0, delPen = 0, insPen = 0,
                            squashInputs = false, samplesToIgnore = [], name = '')
      input_a, input_b = Ops.convert_to_variable(input_a, input_b)
      CNTK.__edit_distance_error__(input_a, input_b, subPen, delPen, insPen,
                                   squashInputs, samplesToIgnore, name)
    end

    #
    # Element-wise unary ops. When +x+ is omitted, a placeholder variable
    # named "x" is used as the operand.
    #

    def negate(x = nil, name: "")
      x = Ops.convert_to_variable(x || Ops.placeholder_variable(name: "x"))
      CNTK.__negate__(x, name)
    end

    def sigmoid(x = nil, name: "")
      x = Ops.convert_to_variable(x || Ops.placeholder_variable(name: "x"))
      CNTK.__sigmoid__(x, name)
    end

    def tanh(x = nil, name: "")
      x = Ops.convert_to_variable(x || Ops.placeholder_variable(name: "x"))
      CNTK.__tanh__(x, name)
    end

    def sin(x = nil, name: "")
      x = Ops.convert_to_variable(x || Ops.placeholder_variable(name: "x"))
      CNTK.__sin__(x, name)
    end

    def cos(x = nil, name: "")
      x = Ops.convert_to_variable(x || Ops.placeholder_variable(name: "x"))
      CNTK.__cos__(x, name)
    end

    def relu(x = nil, name: "")
      x = Ops.convert_to_variable(x || Ops.placeholder_variable(name: "x"))
      CNTK.__re_lu__(x, name)
    end

    def exp(x = nil, name: "")
      x = Ops.convert_to_variable(x || Ops.placeholder_variable(name: "x"))
      CNTK.__exp__(x, name)
    end

    def log(x = nil, name: "")
      x = Ops.convert_to_variable(x || Ops.placeholder_variable(name: "x"))
      CNTK.__log__(x, name)
    end

    def square(x = nil, name: "")
      x = Ops.convert_to_variable(x || Ops.placeholder_variable(name: "x"))
      CNTK.__square__(x, name)
    end

    def sqrt(x = nil, name: "")
      x = Ops.convert_to_variable(x || Ops.placeholder_variable(name: "x"))
      CNTK.__sqrt__(x, name)
    end

    def round(x = nil, name: "")
      x = Ops.convert_to_variable(x || Ops.placeholder_variable(name: "x"))
      CNTK.__round__(x, name)
    end

    def floor(x = nil, name: "")
      x = Ops.convert_to_variable(x || Ops.placeholder_variable(name: "x"))
      CNTK.__floor__(x, name)
    end

    def ceil(x = nil, name: "")
      x = Ops.convert_to_variable(x || Ops.placeholder_variable(name: "x"))
      CNTK.__ceil__(x, name)
    end

    def reciprocal(x = nil, name: "")
      x = Ops.convert_to_variable(x || Ops.placeholder_variable(name: "x"))
      CNTK.__reciprocal__(x, name)
    end

    def softmax(x = nil, name: "")
      x = Ops.convert_to_variable(x || Ops.placeholder_variable(name: "x"))
      CNTK.__softmax__(x, name)
    end

    def hardmax(x = nil, name: "")
      x = Ops.convert_to_variable(x || Ops.placeholder_variable(name: "x"))
      CNTK.__hardmax__(x, name)
    end

    # Absolute value. Unary like the other element-wise ops; it previously
    # took a stray second operand and called a non-existent CNTK.abs binding.
    def abs(x = nil, name: "")
      x = Ops.convert_to_variable(x || Ops.placeholder_variable(name: "x"))
      CNTK.__abs__(x, name)
    end

    #
    # Element-wise binary ops. Missing operands are replaced by placeholder
    # variables named "x" / "y"; both are converted to a common element type.
    #

    def plus(x = nil, y = nil, name: "")
      x = x || Ops.placeholder_variable(name: "x")
      y = y || Ops.placeholder_variable(name: "y")
      x, y = Ops.convert_to_variable(x, y)
      CNTK.__plus__(x, y, name)
    end

    def minus(x = nil, y = nil, name: "")
      x = x || Ops.placeholder_variable(name: "x")
      y = y || Ops.placeholder_variable(name: "y")
      x, y = Ops.convert_to_variable(x, y)
      CNTK.__minus__(x, y, name)
    end

    def log_add_exp(x = nil, y = nil, name: "")
      x = x || Ops.placeholder_variable(name: "x")
      y = y || Ops.placeholder_variable(name: "y")
      x, y = Ops.convert_to_variable(x, y)
      CNTK.__log_add_exp__(x, y, name)
    end

    def element_times(x = nil, y = nil, name: "")
      x = x || Ops.placeholder_variable(name: "x")
      y = y || Ops.placeholder_variable(name: "y")
      x, y = Ops.convert_to_variable(x, y)
      CNTK.__element_times__(x, y, name)
    end

    def element_divide(x = nil, y = nil, name: "")
      x = x || Ops.placeholder_variable(name: "x")
      y = y || Ops.placeholder_variable(name: "y")
      x, y = Ops.convert_to_variable(x, y)
      CNTK.__element_divide__(x, y, name)
    end

    def equal(x = nil, y = nil, name: "")
      x = x || Ops.placeholder_variable(name: "x")
      y = y || Ops.placeholder_variable(name: "y")
      x, y = Ops.convert_to_variable(x, y)
      CNTK.__equal__(x, y, name)
    end

    def not_equal(x = nil, y = nil, name: "")
      x = x || Ops.placeholder_variable(name: "x")
      y = y || Ops.placeholder_variable(name: "y")
      x, y = Ops.convert_to_variable(x, y)
      CNTK.__not_equal__(x, y, name)
    end

    def less(x = nil, y = nil, name: "")
      x = x || Ops.placeholder_variable(name: "x")
      y = y || Ops.placeholder_variable(name: "y")
      x, y = Ops.convert_to_variable(x, y)
      CNTK.__less__(x, y, name)
    end

    def less_equal(x = nil, y = nil, name: "")
      x = x || Ops.placeholder_variable(name: "x")
      y = y || Ops.placeholder_variable(name: "y")
      x, y = Ops.convert_to_variable(x, y)
      CNTK.__less_equal__(x, y, name)
    end

    def greater(x = nil, y = nil, name: "")
      x = x || Ops.placeholder_variable(name: "x")
      y = y || Ops.placeholder_variable(name: "y")
      x, y = Ops.convert_to_variable(x, y)
      CNTK.__greater__(x, y, name)
    end

    def greater_equal(x = nil, y = nil, name: "")
      x = x || Ops.placeholder_variable(name: "x")
      y = y || Ops.placeholder_variable(name: "y")
      x, y = Ops.convert_to_variable(x, y)
      CNTK.__greater_equal__(x, y, name)
    end

    def cosine_distance(x = nil, y = nil, name: "")
      x = x || Ops.placeholder_variable(name: "x")
      y = y || Ops.placeholder_variable(name: "y")
      x, y = Ops.convert_to_variable(x, y)
      CNTK.__cosine_distance__(x, y, name)
    end

    def binary_cross_entropy(x = nil, y = nil, name: "")
      x = x || Ops.placeholder_variable(name: "x")
      y = y || Ops.placeholder_variable(name: "y")
      x, y = Ops.convert_to_variable(x, y)
      CNTK.__binary_cross_entropy__(x, y, name)
    end

    def squared_error(x = nil, y = nil, name: "")
      x = x || Ops.placeholder_variable(name: "x")
      y = y || Ops.placeholder_variable(name: "y")
      x, y = Ops.convert_to_variable(x, y)
      CNTK.__squared_error__(x, y, name)
    end

  end # module Ops
end # module CNTK
@@ -0,0 +1,59 @@
1
module CNTK
  class Trainer

    # Runs one training step on a minibatch.
    #
    # @param args [Hash<Variable,MinibatchData>] minibatch data per input variable
    # @param outputs [Array<Variable>, nil] variables whose values should be
    #   returned from the step
    # @param device [DeviceDescriptor]
    # @return without +outputs+: whatever __train_minibatchdata__ returns;
    #   with +outputs+: [that result, the StdUMapVariableValue of outputs]
    def train_minibatch(args, outputs: nil, device: DeviceDescriptor.use_default_device)
      return __train_minibatchdata__(args, device) unless outputs

      out = StdUMapVariableValue.new
      # A nullptr entry makes the C++ Forward allocate the Value object
      # with the required storage itself.
      outputs.each { |var| out.__set_nullptr__(var) }
      [__train_minibatchdata__(args, out, device), out]
    end

    # Evaluates a minibatch without updating parameters.
    #
    # @param args minibatch data, passed through to __test_minibatchdata__
    # @param device [DeviceDescriptor]
    # @return [Float]
    def test_minibatch(args, device: DeviceDescriptor.use_default_device)
      __test_minibatchdata__(args, device)
    end

    class << self

      # Builds a Trainer.
      #
      # +model+, +loss+ and the optional +evaluation+ may each be a Function
      # or a Variable (a Variable is wrapped with Ops.combine). A single
      # learner is wrapped in an Array.
      #
      # @raise [ArgumentError] if model, loss or learners is missing, or an
      #   argument is neither a Function nor a Variable
      def create(model: nil, loss: nil, evaluation: nil, learners: nil)
        raise ArgumentError, "model, loss function, and learners needed" unless model && loss && learners

        model_fn = variable_to_function(model)
        loss_fn = variable_to_function(loss)
        eval_fn = evaluation ? variable_to_function(evaluation) : nil
        learner_list = learners.is_a?(Array) ? learners : [learners]

        if eval_fn
          CNTK.__create_trainer__(model_fn, loss_fn, eval_fn, learner_list)
        else
          CNTK.__create_trainer__(model_fn, loss_fn, learner_list)
        end
      end

      private

      # Returns a Function as-is and wraps a Variable with Ops.combine.
      # Raises ArgumentError for anything else.
      def variable_to_function(x)
        case x
        when Function then x
        when Variable then CNTK::Ops.combine([x])
        else raise ArgumentError
        end
      end

    end # class << self
  end # class Trainer
end # module CNTK
data/lib/cntk/value.rb CHANGED
@@ -1,8 +1,14 @@
1
1
  module CNTK
2
2
  class Value
3
3
 
4
# Builds a Value for +variable+ from +data+.
#
# If +variable+ has no dynamic axes, +data+ is wrapped in a single
# NDArrayView on the CPU device and passed to Value.new. In that case
# +seq_starts+, +device+ and +read_only+ are not used.
# Otherwise each element of +data+ (presumably one per sequence; confirm
# against callers) is converted with NDArrayView.create. The list is then
# handed to __create__ together with variable.shape, +seq_starts+, +device+
# and +read_only+.
def self.create(variable, data, seq_starts = [], device = DeviceDescriptor.use_default_device, read_only = false)
  return new(NDArrayView.new(data, DeviceDescriptor.cpudevice)) if variable.dynamic_axes.size == 0

  views = data.map { |seq| NDArrayView.create(seq) }
  __create__(variable.shape, views, seq_starts, device, read_only)
end
7
13
 
8
14
  def to_narray
@@ -11,7 +17,7 @@ module CNTK
11
17
 
12
18
  def reshape(a)
13
19
  na = to_narray().reshape(*a)
14
- self.class.create(na)
20
+ self.class.new(NDArrayView.create(na))
15
21
  end
16
22
 
17
23
  end