tensor_stream 0.1.1 → 0.1.2
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/.gitignore +1 -0
- data/.rubocop.yml +74 -0
- data/README.md +4 -3
- data/lib/tensor_stream.rb +17 -18
- data/lib/tensor_stream/control_flow.rb +7 -4
- data/lib/tensor_stream/evaluator/evaluator.rb +1 -1
- data/lib/tensor_stream/evaluator/operation_helpers/random_gaussian.rb +6 -7
- data/lib/tensor_stream/evaluator/ruby_evaluator.rb +80 -91
- data/lib/tensor_stream/graph.rb +51 -18
- data/lib/tensor_stream/graph_keys.rb +2 -2
- data/lib/tensor_stream/helpers/op_helper.rb +31 -27
- data/lib/tensor_stream/math_gradients.rb +20 -23
- data/lib/tensor_stream/nn/nn_ops.rb +2 -2
- data/lib/tensor_stream/operation.rb +28 -45
- data/lib/tensor_stream/ops.rb +103 -103
- data/lib/tensor_stream/placeholder.rb +7 -4
- data/lib/tensor_stream/session.rb +20 -22
- data/lib/tensor_stream/tensor.rb +43 -101
- data/lib/tensor_stream/tensor_shape.rb +4 -3
- data/lib/tensor_stream/train/gradient_descent_optimizer.rb +5 -5
- data/lib/tensor_stream/train/saver.rb +13 -13
- data/lib/tensor_stream/trainer.rb +1 -1
- data/lib/tensor_stream/types.rb +14 -1
- data/lib/tensor_stream/variable.rb +8 -7
- data/lib/tensor_stream/version.rb +1 -1
- data/samples/iris.rb +7 -7
- data/samples/linear_regression.rb +3 -3
- data/samples/raw_neural_net_sample.rb +6 -6
- data/tensor_stream.gemspec +1 -0
- metadata +17 -2
data/lib/tensor_stream/ops.rb
CHANGED
@@ -1,4 +1,5 @@
|
|
1
1
|
module TensorStream
|
2
|
+
# Module that defines all available ops supported by TensorStream
|
2
3
|
module Ops
|
3
4
|
FLOATING_POINT_TYPES = %w[float32 float64].map(&:to_sym)
|
4
5
|
NUMERIC_TYPES = %w[int32 int64 float32 float64].map(&:to_sym)
|
@@ -7,27 +8,26 @@ module TensorStream
|
|
7
8
|
op(:argmax, input, nil, axis: axis, name: name, dimension: dimension, data_type: output_type)
|
8
9
|
end
|
9
10
|
|
10
|
-
def gradients(
|
11
|
-
|
12
|
-
|
13
|
-
|
14
|
-
|
15
|
-
|
16
|
-
|
17
|
-
|
18
|
-
|
19
|
-
fail "#{x} passed is not a tensor object" unless x.is_a?(Tensor)
|
11
|
+
def gradients(input, wrt_xs, grad_ys: nil,
|
12
|
+
name: 'gradients',
|
13
|
+
colocate_gradients_with_ops: false,
|
14
|
+
gate_gradients: false,
|
15
|
+
aggregation_method: nil,
|
16
|
+
stop_gradients: nil)
|
17
|
+
|
18
|
+
gs = wrt_xs.collect do |x|
|
19
|
+
raise "#{x} passed is not a tensor object" unless x.is_a?(Tensor)
|
20
20
|
|
21
21
|
stops = stop_gradients ? stop_gradients.map(&:name).join('_') : ''
|
22
|
-
gradient_program_name = "grad_#{
|
22
|
+
gradient_program_name = "grad_#{input.name}_#{x.name}_#{stops}".to_sym
|
23
23
|
|
24
|
-
tensor_program = if
|
25
|
-
|
24
|
+
tensor_program = if input.graph.node_added?(gradient_program_name)
|
25
|
+
input.graph.get_node(gradient_program_name)
|
26
26
|
else
|
27
|
-
derivative_ops = TensorStream::MathGradients.derivative(
|
28
|
-
|
27
|
+
derivative_ops = TensorStream::MathGradients.derivative(input, x, graph: input.graph,
|
28
|
+
stop_gradients: stop_gradients)
|
29
29
|
unit_matrix = op(:ones_like, x)
|
30
|
-
|
30
|
+
input.graph.add_node!(gradient_program_name, unit_matrix * derivative_ops)
|
31
31
|
end
|
32
32
|
tensor_program
|
33
33
|
end
|
@@ -35,12 +35,12 @@ module TensorStream
|
|
35
35
|
end
|
36
36
|
|
37
37
|
def random_uniform(shape, dtype: :float32, minval: 0, maxval: 1, seed: nil, name: nil)
|
38
|
-
options = {shape: shape, dtype: dtype, minval: minval, maxval: maxval, seed: seed, name: name}
|
38
|
+
options = { shape: shape, dtype: dtype, minval: minval, maxval: maxval, seed: seed, name: name }
|
39
39
|
op(:random_uniform, nil, nil, options)
|
40
40
|
end
|
41
41
|
|
42
42
|
def random_normal(shape, dtype: :float32, mean: 0.0, stddev: 1.0, seed: nil, name: nil)
|
43
|
-
options = {shape: shape, dtype: dtype, mean: mean, stddev: stddev, seed: seed, name: name}
|
43
|
+
options = { shape: shape, dtype: dtype, mean: mean, stddev: stddev, seed: seed, name: name }
|
44
44
|
op(:random_normal, nil, nil, options)
|
45
45
|
end
|
46
46
|
|
@@ -53,7 +53,7 @@ module TensorStream
|
|
53
53
|
end
|
54
54
|
|
55
55
|
def shape(input, name: nil, out_type: :int32)
|
56
|
-
op(:shape, input, nil, name: name)
|
56
|
+
op(:shape, input, nil, name: name, out_type: out_type)
|
57
57
|
end
|
58
58
|
|
59
59
|
def rank(input, name: nil)
|
@@ -63,35 +63,35 @@ module TensorStream
|
|
63
63
|
def zeros_initializer(options = {})
|
64
64
|
op(:zeros, nil, nil, options)
|
65
65
|
end
|
66
|
-
|
66
|
+
|
67
67
|
def slice(input, start, size, name: nil)
|
68
68
|
op(:slice, input, start, size: size, name: name)
|
69
69
|
end
|
70
|
-
|
70
|
+
|
71
71
|
def zeros(shape, dtype: :float32, name: nil)
|
72
72
|
op(:zeros, shape, nil, data_type: dtype, name: name)
|
73
73
|
end
|
74
|
-
|
74
|
+
|
75
75
|
def ones(shape, dtype: :float32, name: nil)
|
76
76
|
op(:ones, shape, nil, data_type: dtype, name: name)
|
77
77
|
end
|
78
|
-
|
79
|
-
def less(
|
80
|
-
op(:less,
|
78
|
+
|
79
|
+
def less(input_a, input_b, name: nil)
|
80
|
+
op(:less, input_a, input_b, name: name)
|
81
81
|
end
|
82
|
-
|
83
|
-
def greater(
|
84
|
-
op(:greater,
|
82
|
+
|
83
|
+
def greater(input_a, input_b, name: nil)
|
84
|
+
op(:greater, input_a, input_b, name: name)
|
85
85
|
end
|
86
86
|
|
87
|
-
def greater_equal(
|
88
|
-
op(:greater_equal,
|
87
|
+
def greater_equal(input_a, input_b, name: nil)
|
88
|
+
op(:greater_equal, input_a, input_b, name: name)
|
89
89
|
end
|
90
90
|
|
91
|
-
def less_equal(
|
92
|
-
op(:less_equal,
|
91
|
+
def less_equal(input_a, input_b, name: nil)
|
92
|
+
op(:less_equal, input_a, input_b, name: name)
|
93
93
|
end
|
94
|
-
|
94
|
+
|
95
95
|
def reduce_mean(input_tensor, axis = nil, keepdims: false, name: nil)
|
96
96
|
op(:reduce_mean, input_tensor, nil, axis: axis, keepdims: keepdims, name: name)
|
97
97
|
end
|
@@ -99,66 +99,66 @@ module TensorStream
|
|
99
99
|
def reduce_sum(input_tensor, axis = nil, keepdims: false, name: nil)
|
100
100
|
op(:reduce_sum, input_tensor, nil, axis: axis, keepdims: keepdims, name: name)
|
101
101
|
end
|
102
|
-
|
102
|
+
|
103
103
|
def reduce_prod(input, axis = nil, keepdims: false, name: nil)
|
104
104
|
op(:reduce_prod, input, nil, axis: axis, keepdims: keepdims, name: name)
|
105
105
|
end
|
106
|
-
|
106
|
+
|
107
107
|
def concat(values, axis, name: 'concat')
|
108
108
|
op(:concat, values, nil, axis: axis, name: name)
|
109
109
|
end
|
110
|
-
|
110
|
+
|
111
111
|
def reshape(tensor, shape, name: nil)
|
112
112
|
op(:reshape, tensor, shape, name: name)
|
113
113
|
end
|
114
|
-
|
114
|
+
|
115
115
|
def square(tensor, name: nil)
|
116
116
|
op(:square, tensor, nil, name: name)
|
117
117
|
end
|
118
|
-
|
118
|
+
|
119
119
|
def cond(pred, true_fn, false_fn, name: nil)
|
120
120
|
op(:cond, true_fn, false_fn, pred: pred, name: name)
|
121
121
|
end
|
122
122
|
|
123
|
-
def where(condition,
|
124
|
-
op(:where,
|
123
|
+
def where(condition, true_t = nil, false_t = nil, name: nil)
|
124
|
+
op(:where, true_t, false_t, pred: condition, name: name)
|
125
125
|
end
|
126
|
-
|
127
|
-
def add(
|
128
|
-
op(:add,
|
126
|
+
|
127
|
+
def add(input_a, input_b, name: nil)
|
128
|
+
op(:add, input_a, input_b, name: name)
|
129
129
|
end
|
130
|
-
|
131
|
-
def sub(
|
132
|
-
op(:sub,
|
130
|
+
|
131
|
+
def sub(input_a, input_b, name: nil)
|
132
|
+
op(:sub, input_a, input_b, name: name)
|
133
133
|
end
|
134
134
|
|
135
|
-
def max(
|
136
|
-
check_allowed_types(
|
137
|
-
check_allowed_types(
|
135
|
+
def max(input_a, input_b, name: nil)
|
136
|
+
check_allowed_types(input_a, NUMERIC_TYPES)
|
137
|
+
check_allowed_types(input_b, NUMERIC_TYPES)
|
138
138
|
|
139
|
-
op(:max,
|
139
|
+
op(:max, input_a, input_b, name: name)
|
140
140
|
end
|
141
141
|
|
142
142
|
def cast(input, dtype, name: nil)
|
143
143
|
op(:cast, input, nil, data_type: dtype, name: name)
|
144
144
|
end
|
145
|
-
|
145
|
+
|
146
146
|
def print(input, data, message: nil, name: nil)
|
147
147
|
op(:print, input, data, message: message, name: name)
|
148
148
|
end
|
149
|
-
|
150
|
-
def negate(
|
151
|
-
op(:negate,
|
149
|
+
|
150
|
+
def negate(input, options = {})
|
151
|
+
op(:negate, input, nil, options)
|
152
152
|
end
|
153
|
-
|
154
|
-
def equal(
|
155
|
-
op(:equal,
|
153
|
+
|
154
|
+
def equal(input_a, input_b, name: nil)
|
155
|
+
op(:equal, input_a, input_b, name: name)
|
156
156
|
end
|
157
157
|
|
158
|
-
def not_equal(
|
159
|
-
op(:not_equal,
|
158
|
+
def not_equal(input_a, input_b, name: nil)
|
159
|
+
op(:not_equal, input_a, input_b, name: name)
|
160
160
|
end
|
161
|
-
|
161
|
+
|
162
162
|
def zeros_like(tensor, dtype: nil, name: nil)
|
163
163
|
op(:zeros_like, tensor, nil, data_type: dtype, name: name)
|
164
164
|
end
|
@@ -166,75 +166,75 @@ module TensorStream
|
|
166
166
|
def ones_like(tensor, dtype: nil, name: nil)
|
167
167
|
op(:ones_like, tensor, nil, data_type: dtype, name: name)
|
168
168
|
end
|
169
|
-
|
169
|
+
|
170
170
|
def identity(input, name: nil)
|
171
171
|
op(:identity, input, nil, name: name)
|
172
172
|
end
|
173
|
-
|
174
|
-
def multiply(
|
175
|
-
op(:mul,
|
173
|
+
|
174
|
+
def multiply(input_a, input_b, name: nil)
|
175
|
+
op(:mul, input_a, input_b, name: name)
|
176
176
|
end
|
177
|
-
|
178
|
-
def pow(
|
179
|
-
op(:pow,
|
177
|
+
|
178
|
+
def pow(input_a, input_e, name: nil)
|
179
|
+
op(:pow, input_a, input_e, name: name)
|
180
180
|
end
|
181
|
-
|
182
|
-
def abs(
|
183
|
-
op(:abs,
|
181
|
+
|
182
|
+
def abs(input, name: nil)
|
183
|
+
op(:abs, input, nil, name: name)
|
184
184
|
end
|
185
|
-
|
186
|
-
def sign(
|
187
|
-
op(:sign,
|
185
|
+
|
186
|
+
def sign(input, name: nil)
|
187
|
+
op(:sign, input, nil, name: name)
|
188
188
|
end
|
189
|
-
|
190
|
-
def sin(
|
189
|
+
|
190
|
+
def sin(input, options = {})
|
191
191
|
options[:data_type] ||= :float32
|
192
|
-
check_allowed_types(
|
193
|
-
op(:sin,
|
192
|
+
check_allowed_types(input, FLOATING_POINT_TYPES)
|
193
|
+
op(:sin, input, nil, options)
|
194
194
|
end
|
195
|
-
|
196
|
-
def cos(
|
195
|
+
|
196
|
+
def cos(input, options = {})
|
197
197
|
options[:data_type] ||= :float32
|
198
|
-
check_allowed_types(
|
199
|
-
op(:cos,
|
198
|
+
check_allowed_types(input, FLOATING_POINT_TYPES)
|
199
|
+
op(:cos, input, nil, options)
|
200
200
|
end
|
201
|
-
|
202
|
-
def tan(
|
201
|
+
|
202
|
+
def tan(input, options = {})
|
203
203
|
options[:data_type] ||= :float32
|
204
|
-
check_allowed_types(
|
205
|
-
op(:tan,
|
204
|
+
check_allowed_types(input, FLOATING_POINT_TYPES)
|
205
|
+
op(:tan, input, nil, options)
|
206
206
|
end
|
207
|
-
|
208
|
-
def tanh(
|
207
|
+
|
208
|
+
def tanh(input, options = {})
|
209
209
|
options[:data_type] ||= :float32
|
210
|
-
check_allowed_types(
|
211
|
-
op(:tanh,
|
210
|
+
check_allowed_types(input, FLOATING_POINT_TYPES)
|
211
|
+
op(:tanh, input, nil, options)
|
212
212
|
end
|
213
|
-
|
214
|
-
def sqrt(
|
213
|
+
|
214
|
+
def sqrt(input, name: nil)
|
215
215
|
options = {
|
216
|
-
data_type:
|
216
|
+
data_type: input.data_type,
|
217
217
|
name: name
|
218
218
|
}
|
219
|
-
check_allowed_types(
|
220
|
-
op(:sqrt,
|
219
|
+
check_allowed_types(input, FLOATING_POINT_TYPES)
|
220
|
+
op(:sqrt, input, nil, options)
|
221
221
|
end
|
222
|
-
|
223
|
-
def log(input, options= {})
|
222
|
+
|
223
|
+
def log(input, options = {})
|
224
224
|
options[:data_type] ||= :float32
|
225
225
|
check_allowed_types(input, FLOATING_POINT_TYPES)
|
226
226
|
op(:log, input, nil, options)
|
227
227
|
end
|
228
|
-
|
229
|
-
def exp(
|
228
|
+
|
229
|
+
def exp(input, options = {})
|
230
230
|
options[:data_type] ||= :float32
|
231
|
-
check_allowed_types(
|
232
|
-
op(:exp,
|
231
|
+
check_allowed_types(input, FLOATING_POINT_TYPES)
|
232
|
+
op(:exp, input, nil, options)
|
233
233
|
end
|
234
234
|
|
235
235
|
def matmul(input_a, input_b, transpose_a: false,
|
236
|
-
|
237
|
-
|
236
|
+
transpose_b: false,
|
237
|
+
name: nil)
|
238
238
|
op(:matmul, input_a, input_b, transpose_a: transpose_a, transpose_b: transpose_b, name: name)
|
239
239
|
end
|
240
240
|
|
@@ -246,4 +246,4 @@ module TensorStream
|
|
246
246
|
op(:pad, tensor, nil, paddings: paddings, mode: mode, name: name)
|
247
247
|
end
|
248
248
|
end
|
249
|
-
end
|
249
|
+
end
|
@@ -1,13 +1,16 @@
|
|
1
1
|
module TensorStream
|
2
|
+
# Class that defines a TensorStream placeholder
|
2
3
|
class Placeholder < Tensor
|
3
4
|
def initialize(data_type, rank, shape, options = {})
|
5
|
+
@graph = options[:graph] || TensorStream.get_default_graph
|
6
|
+
|
4
7
|
@data_type = data_type
|
5
8
|
@rank = rank
|
6
9
|
@shape = TensorShape.new(shape, rank)
|
7
10
|
@value = nil
|
8
11
|
@is_const = false
|
9
|
-
@source =
|
10
|
-
|
12
|
+
@source = format_source(caller_locations)
|
13
|
+
|
11
14
|
@name = options[:name] || build_name
|
12
15
|
@graph.add_node(self)
|
13
16
|
end
|
@@ -15,7 +18,7 @@ module TensorStream
|
|
15
18
|
private
|
16
19
|
|
17
20
|
def build_name
|
18
|
-
"Placeholder#{
|
21
|
+
"Placeholder#{graph.get_placeholder_counter}:#{@rank}"
|
19
22
|
end
|
20
23
|
end
|
21
|
-
end
|
24
|
+
end
|
@@ -1,5 +1,7 @@
|
|
1
1
|
module TensorStream
|
2
|
+
# TensorStream class that defines a session
|
2
3
|
class Session
|
4
|
+
attr_reader :last_session_context
|
3
5
|
def initialize(evaluator = :ruby_evaluator, thread_pool_class: Concurrent::ImmediateExecutor)
|
4
6
|
@evaluator_class = Object.const_get("TensorStream::Evaluator::#{camelize(evaluator.to_s)}")
|
5
7
|
@thread_pool = thread_pool_class.new
|
@@ -9,25 +11,21 @@ module TensorStream
|
|
9
11
|
@session ||= Session.new
|
10
12
|
end
|
11
13
|
|
12
|
-
def last_session_context
|
13
|
-
@last_session_context
|
14
|
-
end
|
15
|
-
|
16
14
|
def run(*args)
|
17
15
|
options = if args.last.is_a?(Hash)
|
18
|
-
|
19
|
-
|
20
|
-
|
21
|
-
|
16
|
+
args.pop
|
17
|
+
else
|
18
|
+
{}
|
19
|
+
end
|
22
20
|
context = {}
|
23
21
|
|
24
22
|
# scan for placeholders and assign value
|
25
|
-
options[:feed_dict]
|
26
|
-
|
27
|
-
context[k.name.to_sym] = options[:feed_dict][k]
|
23
|
+
if options[:feed_dict]
|
24
|
+
options[:feed_dict].keys.each do |k|
|
25
|
+
context[k.name.to_sym] = options[:feed_dict][k] if k.is_a?(Placeholder)
|
28
26
|
end
|
29
|
-
end
|
30
|
-
|
27
|
+
end
|
28
|
+
|
31
29
|
evaluator = @evaluator_class.new(self, context.merge!(retain: options[:retain]), thread_pool: @thread_pool)
|
32
30
|
|
33
31
|
execution_context = {}
|
@@ -37,16 +35,16 @@ module TensorStream
|
|
37
35
|
end
|
38
36
|
|
39
37
|
def dump_internal_ops(tensor)
|
40
|
-
dump_ops(tensor, ->(
|
38
|
+
dump_ops(tensor, ->(_k, n) { n.is_a?(Tensor) && n.internal? })
|
41
39
|
end
|
42
40
|
|
43
41
|
def dump_user_ops(tensor)
|
44
|
-
dump_ops(tensor, ->(
|
42
|
+
dump_ops(tensor, ->(_k, n) { n.is_a?(Tensor) && !n.internal? })
|
45
43
|
end
|
46
44
|
|
47
45
|
def dump_ops(tensor, selector)
|
48
46
|
graph = tensor.graph
|
49
|
-
graph.nodes.select { |k,v| selector.call(k, v) }.collect do |k, node|
|
47
|
+
graph.nodes.select { |k, v| selector.call(k, v) }.collect do |k, node|
|
50
48
|
next unless @last_session_context[node.name]
|
51
49
|
"#{k} #{node.to_math(true, 1)} = #{@last_session_context[node.name]}"
|
52
50
|
end.compact
|
@@ -55,12 +53,12 @@ module TensorStream
|
|
55
53
|
private
|
56
54
|
|
57
55
|
def camelize(string, uppercase_first_letter = true)
|
58
|
-
if uppercase_first_letter
|
59
|
-
|
60
|
-
|
61
|
-
|
62
|
-
|
56
|
+
string = if uppercase_first_letter
|
57
|
+
string.sub(/^[a-z\d]*/) { $&.capitalize }
|
58
|
+
else
|
59
|
+
string.sub(/^(?:(?=\b|[A-Z_])|\w)/) { $&.downcase }
|
60
|
+
end
|
63
61
|
string.gsub(/(?:_|(\/))([a-z\d]*)/) { "#{$1}#{$2.capitalize}" }.gsub('/', '::')
|
64
62
|
end
|
65
63
|
end
|
66
|
-
end
|
64
|
+
end
|
data/lib/tensor_stream/tensor.rb
CHANGED
@@ -1,48 +1,19 @@
|
|
1
1
|
require 'ostruct'
|
2
2
|
|
3
3
|
module TensorStream
|
4
|
+
# Base class that defines a tensor-like interface
|
4
5
|
class Tensor
|
5
6
|
include OpHelper
|
6
7
|
|
7
8
|
attr_accessor :name, :data_type, :shape, :rank, :native_buffer, :is_const, :value, :breakpoint, :internal, :source, :given_name, :graph
|
8
9
|
|
9
|
-
def self.const_name
|
10
|
-
@const_counter ||= 0
|
11
|
-
|
12
|
-
name = if @const_counter == 0
|
13
|
-
""
|
14
|
-
else
|
15
|
-
"_#{@const_counter}"
|
16
|
-
end
|
17
|
-
|
18
|
-
@const_counter += 1
|
19
|
-
|
20
|
-
name
|
21
|
-
end
|
22
|
-
|
23
|
-
def self.var_name
|
24
|
-
@var_counter ||= 0
|
25
|
-
@var_counter += 1
|
26
|
-
|
27
|
-
return "" if @var_counter == 1
|
28
|
-
return "_#{@var_counter}"
|
29
|
-
end
|
30
|
-
|
31
|
-
def self.placeholder_name
|
32
|
-
@placeholder_counter ||= 0
|
33
|
-
@placeholder_counter += 1
|
34
|
-
|
35
|
-
return "" if @placeholder_counter == 1
|
36
|
-
return "_#{@placeholder_counter}"
|
37
|
-
end
|
38
|
-
|
39
10
|
def initialize(data_type, rank, shape, options = {})
|
40
11
|
@data_type = data_type
|
41
12
|
@rank = rank
|
42
13
|
@breakpoint = false
|
43
14
|
@shape = TensorShape.new(shape, rank)
|
44
15
|
@value = nil
|
45
|
-
@source =
|
16
|
+
@source = format_source(caller_locations)
|
46
17
|
@is_const = options[:const] || false
|
47
18
|
@internal = options[:internal]
|
48
19
|
@graph = options[:graph] || TensorStream.get_default_graph
|
@@ -50,16 +21,14 @@ module TensorStream
|
|
50
21
|
@given_name = @name
|
51
22
|
|
52
23
|
if options[:value]
|
53
|
-
if options[:value].
|
24
|
+
if options[:value].is_a?(Array)
|
54
25
|
# check if a single-dimension array is passed
|
55
|
-
if shape.size >= 2 && options[:value].
|
56
|
-
options[:value] = reshape(options[:value], shape.reverse.dup)
|
57
|
-
end
|
26
|
+
options[:value] = reshape(options[:value], shape.reverse.dup) if shape.size >= 2 && !options[:value].empty? && !options[:value][0].is_a?(Array)
|
58
27
|
|
59
28
|
@value = options[:value].collect do |v|
|
60
|
-
v.
|
29
|
+
v.is_a?(Tensor) ? Tensor.cast_dtype(v, data_type) : v
|
61
30
|
end
|
62
|
-
elsif shape.
|
31
|
+
elsif !shape.empty?
|
63
32
|
@value = reshape(Tensor.cast_dtype(options[:value], @data_type), shape.dup)
|
64
33
|
else
|
65
34
|
@value = Tensor.cast_dtype(options[:value], @data_type)
|
@@ -83,66 +52,56 @@ module TensorStream
|
|
83
52
|
@placeholder_counter = 0
|
84
53
|
end
|
85
54
|
|
86
|
-
def
|
87
|
-
|
88
|
-
NArray.sfloat(@shape.cols * @shape.rows)
|
89
|
-
elsif @data_type == :float32 && @rank == 0
|
90
|
-
NArray.sfloat(1)
|
91
|
-
else
|
92
|
-
raise "Invalid data type #{@data_type}"
|
93
|
-
end
|
94
|
-
end
|
95
|
-
|
96
|
-
def +(operand)
|
97
|
-
TensorStream::Operation.new(:add, self, auto_wrap(operand))
|
55
|
+
def +(other)
|
56
|
+
TensorStream::Operation.new(:add, self, auto_wrap(other))
|
98
57
|
end
|
99
58
|
|
100
59
|
def [](index)
|
101
60
|
TensorStream::Operation.new(:index, self, index)
|
102
61
|
end
|
103
62
|
|
104
|
-
def *(
|
105
|
-
TensorStream::Operation.new(:mul, self, auto_wrap(
|
63
|
+
def *(other)
|
64
|
+
TensorStream::Operation.new(:mul, self, auto_wrap(other))
|
106
65
|
end
|
107
66
|
|
108
|
-
def **(
|
109
|
-
TensorStream::Operation.new(:pow, self, auto_wrap(
|
67
|
+
def **(other)
|
68
|
+
TensorStream::Operation.new(:pow, self, auto_wrap(other))
|
110
69
|
end
|
111
70
|
|
112
|
-
def /(
|
113
|
-
TensorStream::Operation.new(:div, self, auto_wrap(
|
71
|
+
def /(other)
|
72
|
+
TensorStream::Operation.new(:div, self, auto_wrap(other))
|
114
73
|
end
|
115
74
|
|
116
|
-
def -(
|
117
|
-
TensorStream::Operation.new(:sub, self, auto_wrap(
|
75
|
+
def -(other)
|
76
|
+
TensorStream::Operation.new(:sub, self, auto_wrap(other))
|
118
77
|
end
|
119
78
|
|
120
79
|
def -@
|
121
80
|
TensorStream::Operation.new(:negate, self, nil)
|
122
81
|
end
|
123
82
|
|
124
|
-
def ==(
|
125
|
-
op(:equal, self,
|
83
|
+
def ==(other)
|
84
|
+
op(:equal, self, other)
|
126
85
|
end
|
127
86
|
|
128
|
-
def <(
|
129
|
-
op(:less, self,
|
87
|
+
def <(other)
|
88
|
+
op(:less, self, other)
|
130
89
|
end
|
131
90
|
|
132
|
-
def !=(
|
133
|
-
op(:not_equal, self,
|
91
|
+
def !=(other)
|
92
|
+
op(:not_equal, self, other)
|
134
93
|
end
|
135
94
|
|
136
|
-
def >(
|
137
|
-
op(:greater, self,
|
95
|
+
def >(other)
|
96
|
+
op(:greater, self, other)
|
138
97
|
end
|
139
98
|
|
140
|
-
def >=(
|
141
|
-
op(:greater_equal, self,
|
99
|
+
def >=(other)
|
100
|
+
op(:greater_equal, self, other)
|
142
101
|
end
|
143
102
|
|
144
|
-
def <=(
|
145
|
-
op(:less_equal, self,
|
103
|
+
def <=(other)
|
104
|
+
op(:less_equal, self, other)
|
146
105
|
end
|
147
106
|
|
148
107
|
def collect(&block)
|
@@ -153,23 +112,6 @@ module TensorStream
|
|
153
112
|
@name
|
154
113
|
end
|
155
114
|
|
156
|
-
# def to_ary
|
157
|
-
# if rank == 2
|
158
|
-
# @native_buffer.to_a.each_slice(shape.cols).collect { |slice| slice }
|
159
|
-
# else
|
160
|
-
# raise "Invalid rank"
|
161
|
-
# end
|
162
|
-
# end
|
163
|
-
|
164
|
-
# open cl methods
|
165
|
-
def open_cl_buffer(context)
|
166
|
-
@cl_buffer ||= context.create_buffer(@native_buffer.size * @native_buffer.element_size, :flags => OpenCL::Mem::COPY_HOST_PTR, :host_ptr => @native_buffer)
|
167
|
-
end
|
168
|
-
|
169
|
-
def sync_cl_buffer(queue, events = [])
|
170
|
-
queue.enqueue_read_buffer(@cl_buffer, @native_buffer, :event_wait_list => events)
|
171
|
-
end
|
172
|
-
|
173
115
|
def eval(options = {})
|
174
116
|
Session.default_session.run(self, options)
|
175
117
|
end
|
@@ -201,21 +143,21 @@ module TensorStream
|
|
201
143
|
end
|
202
144
|
|
203
145
|
def to_math(name_only = false, max_depth = 99)
|
204
|
-
return @name if max_depth
|
205
|
-
|
206
|
-
if @value.
|
207
|
-
@value.collect { |v| v.
|
146
|
+
return @name if max_depth.zero? || name_only || @value.nil?
|
147
|
+
|
148
|
+
if @value.is_a?(Array)
|
149
|
+
@value.collect { |v| v.is_a?(Tensor) ? v.to_math(name_only, max_depth - 1) : v }
|
208
150
|
else
|
209
151
|
is_const ? @value : @name
|
210
152
|
end
|
211
153
|
end
|
212
154
|
|
213
155
|
def auto_math(tensor, name_only = false, max_depth = 99)
|
214
|
-
tensor.
|
156
|
+
tensor.is_a?(Tensor) ? tensor.to_math(name_only, max_depth) : tensor
|
215
157
|
end
|
216
158
|
|
217
159
|
def self.detect_type(value)
|
218
|
-
|
160
|
+
if value.is_a?(String)
|
219
161
|
:string
|
220
162
|
elsif value.is_a?(Float)
|
221
163
|
:float32
|
@@ -257,7 +199,7 @@ module TensorStream
|
|
257
199
|
when :unknown
|
258
200
|
val
|
259
201
|
else
|
260
|
-
|
202
|
+
raise "unknown data_type #{dtype} passed"
|
261
203
|
end
|
262
204
|
end
|
263
205
|
|
@@ -268,15 +210,15 @@ module TensorStream
|
|
268
210
|
|
269
211
|
protected
|
270
212
|
|
271
|
-
def
|
272
|
-
trace.reject { |c| c.to_s.include?(File.join(
|
213
|
+
def format_source(trace)
|
214
|
+
trace.reject { |c| c.to_s.include?(File.join('lib', 'tensor_stream')) }.first
|
273
215
|
end
|
274
216
|
|
275
217
|
def hashify_tensor(tensor)
|
276
|
-
if tensor.
|
218
|
+
if tensor.is_a?(Tensor)
|
277
219
|
tensor.to_h
|
278
|
-
elsif tensor.
|
279
|
-
tensor.collect
|
220
|
+
elsif tensor.is_a?(Array)
|
221
|
+
tensor.collect { |t| hashify_tensor(t) }
|
280
222
|
else
|
281
223
|
tensor
|
282
224
|
end
|
@@ -290,11 +232,11 @@ module TensorStream
|
|
290
232
|
reshape(s, shape)
|
291
233
|
end
|
292
234
|
else
|
293
|
-
return arr if shape.
|
235
|
+
return arr if shape.empty?
|
294
236
|
slice = shape.shift
|
295
237
|
return arr if slice.nil?
|
296
238
|
|
297
|
-
slice
|
239
|
+
Array.new(slice) do
|
298
240
|
reshape(arr, shape.dup)
|
299
241
|
end
|
300
242
|
end
|
@@ -311,7 +253,7 @@ module TensorStream
|
|
311
253
|
end
|
312
254
|
|
313
255
|
def build_name
|
314
|
-
|
256
|
+
@is_const ? "Const#{graph.get_const_counter}:#{@rank}" : "Variable#{graph.get_var_counter}:#{@rank}"
|
315
257
|
end
|
316
258
|
end
|
317
259
|
end
|