tensor_stream 0.8.1 → 0.8.5
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/.gitignore +1 -0
- data/CHANGELOG.md +8 -0
- data/README.md +12 -6
- data/lib/tensor_stream.rb +1 -0
- data/lib/tensor_stream/evaluator/base_evaluator.rb +1 -1
- data/lib/tensor_stream/evaluator/ruby/array_ops.rb +282 -0
- data/lib/tensor_stream/evaluator/ruby/images_ops.rb +61 -0
- data/lib/tensor_stream/evaluator/ruby/math_ops.rb +111 -0
- data/lib/tensor_stream/evaluator/ruby/nn_ops.rb +48 -9
- data/lib/tensor_stream/evaluator/ruby/random_ops.rb +51 -0
- data/lib/tensor_stream/evaluator/ruby_evaluator.rb +20 -433
- data/lib/tensor_stream/images.rb +16 -0
- data/lib/tensor_stream/ops.rb +5 -1
- data/lib/tensor_stream/session.rb +15 -15
- data/lib/tensor_stream/tensor.rb +1 -1
- data/lib/tensor_stream/train/adadelta_optimizer.rb +52 -0
- data/lib/tensor_stream/train/adam_optimizer.rb +17 -2
- data/lib/tensor_stream/train/gradient_descent_optimizer.rb +7 -1
- data/lib/tensor_stream/trainer.rb +1 -0
- data/lib/tensor_stream/types.rb +4 -0
- data/lib/tensor_stream/utils.rb +4 -0
- data/lib/tensor_stream/variable_scope.rb +1 -0
- data/lib/tensor_stream/version.rb +1 -1
- data/samples/linear_regression.rb +4 -1
- data/samples/mnist_data.rb +64 -0
- data/samples/nearest_neighbor.rb +1 -2
- data/samples/raw_neural_net_sample.rb +1 -1
- data/tensor_stream.gemspec +1 -0
- metadata +23 -57
- data/lib/tensor_stream/evaluator/opencl/kernels/_bool_operand.cl +0 -45
- data/lib/tensor_stream/evaluator/opencl/kernels/_operand.cl +0 -45
- data/lib/tensor_stream/evaluator/opencl/kernels/abs.cl +0 -20
- data/lib/tensor_stream/evaluator/opencl/kernels/acos.cl +0 -8
- data/lib/tensor_stream/evaluator/opencl/kernels/add.cl +0 -3
- data/lib/tensor_stream/evaluator/opencl/kernels/apply_adam.cl +0 -23
- data/lib/tensor_stream/evaluator/opencl/kernels/apply_gradient.cl +0 -9
- data/lib/tensor_stream/evaluator/opencl/kernels/apply_momentum.cl +0 -16
- data/lib/tensor_stream/evaluator/opencl/kernels/argmax.cl +0 -8
- data/lib/tensor_stream/evaluator/opencl/kernels/argmin.cl +0 -8
- data/lib/tensor_stream/evaluator/opencl/kernels/asin.cl +0 -9
- data/lib/tensor_stream/evaluator/opencl/kernels/cast.cl +0 -10
- data/lib/tensor_stream/evaluator/opencl/kernels/ceil.cl +0 -8
- data/lib/tensor_stream/evaluator/opencl/kernels/cond.cl.erb +0 -6
- data/lib/tensor_stream/evaluator/opencl/kernels/cos.cl +0 -8
- data/lib/tensor_stream/evaluator/opencl/kernels/div.cl.erb +0 -3
- data/lib/tensor_stream/evaluator/opencl/kernels/exp.cl +0 -8
- data/lib/tensor_stream/evaluator/opencl/kernels/floor.cl +0 -8
- data/lib/tensor_stream/evaluator/opencl/kernels/floor_div.cl +0 -48
- data/lib/tensor_stream/evaluator/opencl/kernels/floor_mod.cl +0 -3
- data/lib/tensor_stream/evaluator/opencl/kernels/gemm.cl +0 -32
- data/lib/tensor_stream/evaluator/opencl/kernels/log.cl +0 -8
- data/lib/tensor_stream/evaluator/opencl/kernels/log1p.cl +0 -8
- data/lib/tensor_stream/evaluator/opencl/kernels/log_softmax.cl +0 -26
- data/lib/tensor_stream/evaluator/opencl/kernels/max.cl +0 -46
- data/lib/tensor_stream/evaluator/opencl/kernels/min.cl +0 -46
- data/lib/tensor_stream/evaluator/opencl/kernels/mod.cl +0 -3
- data/lib/tensor_stream/evaluator/opencl/kernels/mul.cl +0 -3
- data/lib/tensor_stream/evaluator/opencl/kernels/negate.cl +0 -8
- data/lib/tensor_stream/evaluator/opencl/kernels/pack.cl +0 -24
- data/lib/tensor_stream/evaluator/opencl/kernels/pow.cl +0 -46
- data/lib/tensor_stream/evaluator/opencl/kernels/real_div.cl +0 -3
- data/lib/tensor_stream/evaluator/opencl/kernels/reciprocal.cl +0 -8
- data/lib/tensor_stream/evaluator/opencl/kernels/round.cl +0 -8
- data/lib/tensor_stream/evaluator/opencl/kernels/sigmoid.cl +0 -9
- data/lib/tensor_stream/evaluator/opencl/kernels/sigmoid_grad.cl +0 -55
- data/lib/tensor_stream/evaluator/opencl/kernels/sign.cl +0 -21
- data/lib/tensor_stream/evaluator/opencl/kernels/sin.cl +0 -9
- data/lib/tensor_stream/evaluator/opencl/kernels/softmax.cl +0 -26
- data/lib/tensor_stream/evaluator/opencl/kernels/softmax_cross.cl +0 -32
- data/lib/tensor_stream/evaluator/opencl/kernels/softmax_cross_grad.cl +0 -28
- data/lib/tensor_stream/evaluator/opencl/kernels/softmax_grad.cl +0 -46
- data/lib/tensor_stream/evaluator/opencl/kernels/sqrt.cl +0 -9
- data/lib/tensor_stream/evaluator/opencl/kernels/square.cl +0 -9
- data/lib/tensor_stream/evaluator/opencl/kernels/squared_difference.cl +0 -53
- data/lib/tensor_stream/evaluator/opencl/kernels/sub.cl +0 -3
- data/lib/tensor_stream/evaluator/opencl/kernels/tan.cl +0 -8
- data/lib/tensor_stream/evaluator/opencl/kernels/tanh.cl +0 -8
- data/lib/tensor_stream/evaluator/opencl/kernels/tanh_grad.cl +0 -7
- data/lib/tensor_stream/evaluator/opencl/kernels/where.cl +0 -8
- data/lib/tensor_stream/evaluator/opencl/opencl_buffer.rb +0 -35
- data/lib/tensor_stream/evaluator/opencl/opencl_device.rb +0 -5
- data/lib/tensor_stream/evaluator/opencl/opencl_evaluator.rb +0 -1230
- data/lib/tensor_stream/evaluator/opencl/opencl_template_helper.rb +0 -95
@@ -1,95 +0,0 @@
|
|
1
|
-
require 'erb'
|
2
|
-
# Renders OpenCL kernel source code from ERB templates.
#
# Templates are evaluated with this helper as +self+, so they can call
# #dtype_to_c_type, #min_value_for, #operator_to_c, #floating_point? and
# #render directly. ERB trim mode '%' is enabled: lines beginning with '%'
# are executed as Ruby code.
class OpenclTemplateHelper
  # Tensor dtype name => C type used in generated kernels.
  C_TYPES = {
    'float64' => 'double',
    'float32' => 'float',
    'float' => 'float',
    'int32' => 'int',
    'int' => 'int',
    'int16' => 'short',
    'boolean' => 'short'
  }.freeze

  # Tensor dtype name => C expression for the lowest (most negative) value of
  # that type. Floating point uses -DBL_MAX / -FLT_MAX because DBL_MIN and
  # FLT_MIN are the smallest *positive* normalized values, not the lowest.
  # The negated constants are parenthesized so they remain valid C when
  # spliced directly after a binary minus (e.g. "a-<%= ... %>").
  C_MIN_VALUES = {
    'float64' => '(-DBL_MAX)',
    'float32' => '(-FLT_MAX)',
    'float' => '(-FLT_MAX)',
    'int32' => 'INT_MIN',
    'int' => 'INT_MIN',
    'int16' => 'SHRT_MIN',
    'boolean' => '0'
  }.freeze

  # Op name => C binary operator.
  C_OPERATORS = {
    'less' => '<',
    'less_equal' => '<=',
    'equal' => '==',
    'greater' => '>',
    'greater_equal' => '>=',
    'not_equal' => '!=',
    'logical_and' => '&&',
    'div' => '/',
    'add' => '+',
    'sub' => '-',
    'mul' => '*',
    'mod' => '%'
  }.freeze

  # @param source [String] ERB template source for a kernel
  def initialize(source)
    @source = source
  end

  # Renders the template passed to the constructor.
  #
  # @param args [Hash] values exposed to the template as local variables
  #   (keys must respond to #to_sym)
  # @return [String] the rendered kernel source
  def generate(args = {})
    evaluate_template(@source, args)
  end

  # @param dtype [Object] dtype identifier
  # @return [Boolean] whether +dtype+ is listed in
  #   TensorStream::Ops::FLOATING_POINT_TYPES
  def floating_point?(dtype)
    TensorStream::Ops::FLOATING_POINT_TYPES.include?(dtype)
  end

  # Renders the partial template "kernels/_<template>", resolved relative to
  # the directory containing this file.
  #
  # @param template [String] partial name without the leading underscore
  # @param locals [Hash] values exposed to the partial as local variables
  # @return [String] the rendered partial
  def render(template, locals = {})
    filename = File.join(File.dirname(__FILE__), 'kernels', "_#{template}")
    evaluate_template(File.read(filename), locals)
  end

  # @param dtype [String, Symbol] tensor dtype
  # @return [String] the C type name for +dtype+
  # @raise [RuntimeError] if +dtype+ is not supported
  def dtype_to_c_type(dtype)
    C_TYPES.fetch(dtype.to_s) { raise "unknown dtype #{dtype}" }
  end

  # @param dtype [String, Symbol] tensor dtype
  # @return [String] a C expression for the lowest value of +dtype+
  # @raise [RuntimeError] if +dtype+ is not supported
  def min_value_for(dtype)
    C_MIN_VALUES.fetch(dtype.to_s) { raise "unknown dtype #{dtype}" }
  end

  # @param op [String, Symbol] operation name, e.g. 'less' or 'add'
  # @return [String] the C operator for +op+
  # @raise [RuntimeError] if +op+ has no C operator mapping
  def operator_to_c(op)
    C_OPERATORS.fetch(op.to_s) { raise "unsupported op #{op}" }
  end

  private

  # Evaluates ERB +source+ with each entry of +locals+ bound as a local
  # variable. A dedicated, local-free binding is used so that template locals
  # can neither overwrite nor observe this helper's own method locals.
  def evaluate_template(source, locals)
    scope = template_binding
    locals.each { |name, value| scope.local_variable_set(name.to_sym, value) }
    build_erb(source).result(scope)
  end

  # Returns a binding whose self is this helper and which has no locals.
  def template_binding
    binding
  end

  # Builds an ERB with trim mode '%'. Passing trim_mode positionally is
  # deprecated (and warns), so the keyword form is used whenever this Ruby's
  # ERB supports keyword arguments.
  def build_erb(source)
    if ERB.instance_method(:initialize).parameters.assoc(:key)
      ERB.new(source, trim_mode: '%')
    else
      ERB.new(source, nil, '%')
    end
  end
end
|