tensor_stream 0.1.1 → 0.1.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/.gitignore +1 -0
- data/.rubocop.yml +74 -0
- data/README.md +4 -3
- data/lib/tensor_stream.rb +17 -18
- data/lib/tensor_stream/control_flow.rb +7 -4
- data/lib/tensor_stream/evaluator/evaluator.rb +1 -1
- data/lib/tensor_stream/evaluator/operation_helpers/random_gaussian.rb +6 -7
- data/lib/tensor_stream/evaluator/ruby_evaluator.rb +80 -91
- data/lib/tensor_stream/graph.rb +51 -18
- data/lib/tensor_stream/graph_keys.rb +2 -2
- data/lib/tensor_stream/helpers/op_helper.rb +31 -27
- data/lib/tensor_stream/math_gradients.rb +20 -23
- data/lib/tensor_stream/nn/nn_ops.rb +2 -2
- data/lib/tensor_stream/operation.rb +28 -45
- data/lib/tensor_stream/ops.rb +103 -103
- data/lib/tensor_stream/placeholder.rb +7 -4
- data/lib/tensor_stream/session.rb +20 -22
- data/lib/tensor_stream/tensor.rb +43 -101
- data/lib/tensor_stream/tensor_shape.rb +4 -3
- data/lib/tensor_stream/train/gradient_descent_optimizer.rb +5 -5
- data/lib/tensor_stream/train/saver.rb +13 -13
- data/lib/tensor_stream/trainer.rb +1 -1
- data/lib/tensor_stream/types.rb +14 -1
- data/lib/tensor_stream/variable.rb +8 -7
- data/lib/tensor_stream/version.rb +1 -1
- data/samples/iris.rb +7 -7
- data/samples/linear_regression.rb +3 -3
- data/samples/raw_neural_net_sample.rb +6 -6
- data/tensor_stream.gemspec +1 -0
- metadata +17 -2
data/lib/tensor_stream/ops.rb
CHANGED
@@ -1,4 +1,5 @@
|
|
1
1
|
module TensorStream
|
2
|
+
# Class that defines all available ops supported by TensorStream
|
2
3
|
module Ops
|
3
4
|
FLOATING_POINT_TYPES = %w[float32 float64].map(&:to_sym)
|
4
5
|
NUMERIC_TYPES = %w[int32 int64 float32 float64].map(&:to_sym)
|
@@ -7,27 +8,26 @@ module TensorStream
|
|
7
8
|
op(:argmax, input, nil, axis: axis, name: name, dimension: dimension, data_type: output_type)
|
8
9
|
end
|
9
10
|
|
10
|
-
def gradients(
|
11
|
-
|
12
|
-
|
13
|
-
|
14
|
-
|
15
|
-
|
16
|
-
|
17
|
-
|
18
|
-
|
19
|
-
fail "#{x} passed is not a tensor object" unless x.is_a?(Tensor)
|
11
|
+
def gradients(input, wrt_xs, grad_ys: nil,
|
12
|
+
name: 'gradients',
|
13
|
+
colocate_gradients_with_ops: false,
|
14
|
+
gate_gradients: false,
|
15
|
+
aggregation_method: nil,
|
16
|
+
stop_gradients: nil)
|
17
|
+
|
18
|
+
gs = wrt_xs.collect do |x|
|
19
|
+
raise "#{x} passed is not a tensor object" unless x.is_a?(Tensor)
|
20
20
|
|
21
21
|
stops = stop_gradients ? stop_gradients.map(&:name).join('_') : ''
|
22
|
-
gradient_program_name = "grad_#{
|
22
|
+
gradient_program_name = "grad_#{input.name}_#{x.name}_#{stops}".to_sym
|
23
23
|
|
24
|
-
tensor_program = if
|
25
|
-
|
24
|
+
tensor_program = if input.graph.node_added?(gradient_program_name)
|
25
|
+
input.graph.get_node(gradient_program_name)
|
26
26
|
else
|
27
|
-
derivative_ops = TensorStream::MathGradients.derivative(
|
28
|
-
|
27
|
+
derivative_ops = TensorStream::MathGradients.derivative(input, x, graph: input.graph,
|
28
|
+
stop_gradients: stop_gradients)
|
29
29
|
unit_matrix = op(:ones_like, x)
|
30
|
-
|
30
|
+
input.graph.add_node!(gradient_program_name, unit_matrix * derivative_ops)
|
31
31
|
end
|
32
32
|
tensor_program
|
33
33
|
end
|
@@ -35,12 +35,12 @@ module TensorStream
|
|
35
35
|
end
|
36
36
|
|
37
37
|
def random_uniform(shape, dtype: :float32, minval: 0, maxval: 1, seed: nil, name: nil)
|
38
|
-
options = {shape: shape, dtype: dtype, minval: minval, maxval: maxval, seed: seed, name: name}
|
38
|
+
options = { shape: shape, dtype: dtype, minval: minval, maxval: maxval, seed: seed, name: name }
|
39
39
|
op(:random_uniform, nil, nil, options)
|
40
40
|
end
|
41
41
|
|
42
42
|
def random_normal(shape, dtype: :float32, mean: 0.0, stddev: 1.0, seed: nil, name: nil)
|
43
|
-
options = {shape: shape, dtype: dtype, mean: mean, stddev: stddev, seed: seed, name: name}
|
43
|
+
options = { shape: shape, dtype: dtype, mean: mean, stddev: stddev, seed: seed, name: name }
|
44
44
|
op(:random_normal, nil, nil, options)
|
45
45
|
end
|
46
46
|
|
@@ -53,7 +53,7 @@ module TensorStream
|
|
53
53
|
end
|
54
54
|
|
55
55
|
def shape(input, name: nil, out_type: :int32)
|
56
|
-
op(:shape, input, nil, name: name)
|
56
|
+
op(:shape, input, nil, name: name, out_type: out_type)
|
57
57
|
end
|
58
58
|
|
59
59
|
def rank(input, name: nil)
|
@@ -63,35 +63,35 @@ module TensorStream
|
|
63
63
|
def zeros_initializer(options = {})
|
64
64
|
op(:zeros, nil, nil, options)
|
65
65
|
end
|
66
|
-
|
66
|
+
|
67
67
|
def slice(input, start, size, name: nil)
|
68
68
|
op(:slice, input, start, size: size, name: name)
|
69
69
|
end
|
70
|
-
|
70
|
+
|
71
71
|
def zeros(shape, dtype: :float32, name: nil)
|
72
72
|
op(:zeros, shape, nil, data_type: dtype, name: name)
|
73
73
|
end
|
74
|
-
|
74
|
+
|
75
75
|
def ones(shape, dtype: :float32, name: nil)
|
76
76
|
op(:ones, shape, nil, data_type: dtype, name: name)
|
77
77
|
end
|
78
|
-
|
79
|
-
def less(
|
80
|
-
op(:less,
|
78
|
+
|
79
|
+
def less(input_a, input_b, name: nil)
|
80
|
+
op(:less, input_a, input_b, name: name)
|
81
81
|
end
|
82
|
-
|
83
|
-
def greater(
|
84
|
-
op(:greater,
|
82
|
+
|
83
|
+
def greater(input_a, input_b, name: nil)
|
84
|
+
op(:greater, input_a, input_b, name: name)
|
85
85
|
end
|
86
86
|
|
87
|
-
def greater_equal(
|
88
|
-
op(:greater_equal,
|
87
|
+
def greater_equal(input_a, input_b, name: nil)
|
88
|
+
op(:greater_equal, input_a, input_b, name: name)
|
89
89
|
end
|
90
90
|
|
91
|
-
def less_equal(
|
92
|
-
op(:less_equal,
|
91
|
+
def less_equal(input_a, input_b, name: nil)
|
92
|
+
op(:less_equal, input_a, input_b, name: name)
|
93
93
|
end
|
94
|
-
|
94
|
+
|
95
95
|
def reduce_mean(input_tensor, axis = nil, keepdims: false, name: nil)
|
96
96
|
op(:reduce_mean, input_tensor, nil, axis: axis, keepdims: keepdims, name: name)
|
97
97
|
end
|
@@ -99,66 +99,66 @@ module TensorStream
|
|
99
99
|
def reduce_sum(input_tensor, axis = nil, keepdims: false, name: nil)
|
100
100
|
op(:reduce_sum, input_tensor, nil, axis: axis, keepdims: keepdims, name: name)
|
101
101
|
end
|
102
|
-
|
102
|
+
|
103
103
|
def reduce_prod(input, axis = nil, keepdims: false, name: nil)
|
104
104
|
op(:reduce_prod, input, nil, axis: axis, keepdims: keepdims, name: name)
|
105
105
|
end
|
106
|
-
|
106
|
+
|
107
107
|
def concat(values, axis, name: 'concat')
|
108
108
|
op(:concat, values, nil, axis: axis, name: name)
|
109
109
|
end
|
110
|
-
|
110
|
+
|
111
111
|
def reshape(tensor, shape, name: nil)
|
112
112
|
op(:reshape, tensor, shape, name: name)
|
113
113
|
end
|
114
|
-
|
114
|
+
|
115
115
|
def square(tensor, name: nil)
|
116
116
|
op(:square, tensor, nil, name: name)
|
117
117
|
end
|
118
|
-
|
118
|
+
|
119
119
|
def cond(pred, true_fn, false_fn, name: nil)
|
120
120
|
op(:cond, true_fn, false_fn, pred: pred, name: name)
|
121
121
|
end
|
122
122
|
|
123
|
-
def where(condition,
|
124
|
-
op(:where,
|
123
|
+
def where(condition, true_t = nil, false_t = nil, name: nil)
|
124
|
+
op(:where, true_t, false_t, pred: condition, name: name)
|
125
125
|
end
|
126
|
-
|
127
|
-
def add(
|
128
|
-
op(:add,
|
126
|
+
|
127
|
+
def add(input_a, input_b, name: nil)
|
128
|
+
op(:add, input_a, input_b, name: name)
|
129
129
|
end
|
130
|
-
|
131
|
-
def sub(
|
132
|
-
op(:sub,
|
130
|
+
|
131
|
+
def sub(input_a, input_b, name: nil)
|
132
|
+
op(:sub, input_a, input_b, name: name)
|
133
133
|
end
|
134
134
|
|
135
|
-
def max(
|
136
|
-
check_allowed_types(
|
137
|
-
check_allowed_types(
|
135
|
+
def max(input_a, input_b, name: nil)
|
136
|
+
check_allowed_types(input_a, NUMERIC_TYPES)
|
137
|
+
check_allowed_types(input_b, NUMERIC_TYPES)
|
138
138
|
|
139
|
-
op(:max,
|
139
|
+
op(:max, input_a, input_b, name: name)
|
140
140
|
end
|
141
141
|
|
142
142
|
def cast(input, dtype, name: nil)
|
143
143
|
op(:cast, input, nil, data_type: dtype, name: name)
|
144
144
|
end
|
145
|
-
|
145
|
+
|
146
146
|
def print(input, data, message: nil, name: nil)
|
147
147
|
op(:print, input, data, message: message, name: name)
|
148
148
|
end
|
149
|
-
|
150
|
-
def negate(
|
151
|
-
op(:negate,
|
149
|
+
|
150
|
+
def negate(input, options = {})
|
151
|
+
op(:negate, input, nil, options)
|
152
152
|
end
|
153
|
-
|
154
|
-
def equal(
|
155
|
-
op(:equal,
|
153
|
+
|
154
|
+
def equal(input_a, input_b, name: nil)
|
155
|
+
op(:equal, input_a, input_b, name: name)
|
156
156
|
end
|
157
157
|
|
158
|
-
def not_equal(
|
159
|
-
op(:not_equal,
|
158
|
+
def not_equal(input_a, input_b, name: nil)
|
159
|
+
op(:not_equal, input_a, input_b, name: name)
|
160
160
|
end
|
161
|
-
|
161
|
+
|
162
162
|
def zeros_like(tensor, dtype: nil, name: nil)
|
163
163
|
op(:zeros_like, tensor, nil, data_type: dtype, name: name)
|
164
164
|
end
|
@@ -166,75 +166,75 @@ module TensorStream
|
|
166
166
|
def ones_like(tensor, dtype: nil, name: nil)
|
167
167
|
op(:ones_like, tensor, nil, data_type: dtype, name: name)
|
168
168
|
end
|
169
|
-
|
169
|
+
|
170
170
|
def identity(input, name: nil)
|
171
171
|
op(:identity, input, nil, name: name)
|
172
172
|
end
|
173
|
-
|
174
|
-
def multiply(
|
175
|
-
op(:mul,
|
173
|
+
|
174
|
+
def multiply(input_a, input_b, name: nil)
|
175
|
+
op(:mul, input_a, input_b, name: name)
|
176
176
|
end
|
177
|
-
|
178
|
-
def pow(
|
179
|
-
op(:pow,
|
177
|
+
|
178
|
+
def pow(input_a, input_e, name: nil)
|
179
|
+
op(:pow, input_a, input_e, name: name)
|
180
180
|
end
|
181
|
-
|
182
|
-
def abs(
|
183
|
-
op(:abs,
|
181
|
+
|
182
|
+
def abs(input, name: nil)
|
183
|
+
op(:abs, input, nil, name: name)
|
184
184
|
end
|
185
|
-
|
186
|
-
def sign(
|
187
|
-
op(:sign,
|
185
|
+
|
186
|
+
def sign(input, name: nil)
|
187
|
+
op(:sign, input, nil, name: name)
|
188
188
|
end
|
189
|
-
|
190
|
-
def sin(
|
189
|
+
|
190
|
+
def sin(input, options = {})
|
191
191
|
options[:data_type] ||= :float32
|
192
|
-
check_allowed_types(
|
193
|
-
op(:sin,
|
192
|
+
check_allowed_types(input, FLOATING_POINT_TYPES)
|
193
|
+
op(:sin, input, nil, options)
|
194
194
|
end
|
195
|
-
|
196
|
-
def cos(
|
195
|
+
|
196
|
+
def cos(input, options = {})
|
197
197
|
options[:data_type] ||= :float32
|
198
|
-
check_allowed_types(
|
199
|
-
op(:cos,
|
198
|
+
check_allowed_types(input, FLOATING_POINT_TYPES)
|
199
|
+
op(:cos, input, nil, options)
|
200
200
|
end
|
201
|
-
|
202
|
-
def tan(
|
201
|
+
|
202
|
+
def tan(input, options = {})
|
203
203
|
options[:data_type] ||= :float32
|
204
|
-
check_allowed_types(
|
205
|
-
op(:tan,
|
204
|
+
check_allowed_types(input, FLOATING_POINT_TYPES)
|
205
|
+
op(:tan, input, nil, options)
|
206
206
|
end
|
207
|
-
|
208
|
-
def tanh(
|
207
|
+
|
208
|
+
def tanh(input, options = {})
|
209
209
|
options[:data_type] ||= :float32
|
210
|
-
check_allowed_types(
|
211
|
-
op(:tanh,
|
210
|
+
check_allowed_types(input, FLOATING_POINT_TYPES)
|
211
|
+
op(:tanh, input, nil, options)
|
212
212
|
end
|
213
|
-
|
214
|
-
def sqrt(
|
213
|
+
|
214
|
+
def sqrt(input, name: nil)
|
215
215
|
options = {
|
216
|
-
data_type:
|
216
|
+
data_type: input.data_type,
|
217
217
|
name: name
|
218
218
|
}
|
219
|
-
check_allowed_types(
|
220
|
-
op(:sqrt,
|
219
|
+
check_allowed_types(input, FLOATING_POINT_TYPES)
|
220
|
+
op(:sqrt, input, nil, options)
|
221
221
|
end
|
222
|
-
|
223
|
-
def log(input, options= {})
|
222
|
+
|
223
|
+
def log(input, options = {})
|
224
224
|
options[:data_type] ||= :float32
|
225
225
|
check_allowed_types(input, FLOATING_POINT_TYPES)
|
226
226
|
op(:log, input, nil, options)
|
227
227
|
end
|
228
|
-
|
229
|
-
def exp(
|
228
|
+
|
229
|
+
def exp(input, options = {})
|
230
230
|
options[:data_type] ||= :float32
|
231
|
-
check_allowed_types(
|
232
|
-
op(:exp,
|
231
|
+
check_allowed_types(input, FLOATING_POINT_TYPES)
|
232
|
+
op(:exp, input, nil, options)
|
233
233
|
end
|
234
234
|
|
235
235
|
def matmul(input_a, input_b, transpose_a: false,
|
236
|
-
|
237
|
-
|
236
|
+
transpose_b: false,
|
237
|
+
name: nil)
|
238
238
|
op(:matmul, input_a, input_b, transpose_a: transpose_a, transpose_b: transpose_b, name: name)
|
239
239
|
end
|
240
240
|
|
@@ -246,4 +246,4 @@ module TensorStream
|
|
246
246
|
op(:pad, tensor, nil, paddings: paddings, mode: mode, name: name)
|
247
247
|
end
|
248
248
|
end
|
249
|
-
end
|
249
|
+
end
|
data/lib/tensor_stream/placeholder.rb
CHANGED
@@ -1,13 +1,16 @@
|
|
1
1
|
module TensorStream
|
2
|
+
# Class that defines a TensorStream placeholder
|
2
3
|
class Placeholder < Tensor
|
3
4
|
def initialize(data_type, rank, shape, options = {})
|
5
|
+
@graph = options[:graph] || TensorStream.get_default_graph
|
6
|
+
|
4
7
|
@data_type = data_type
|
5
8
|
@rank = rank
|
6
9
|
@shape = TensorShape.new(shape, rank)
|
7
10
|
@value = nil
|
8
11
|
@is_const = false
|
9
|
-
@source =
|
10
|
-
|
12
|
+
@source = format_source(caller_locations)
|
13
|
+
|
11
14
|
@name = options[:name] || build_name
|
12
15
|
@graph.add_node(self)
|
13
16
|
end
|
@@ -15,7 +18,7 @@ module TensorStream
|
|
15
18
|
private
|
16
19
|
|
17
20
|
def build_name
|
18
|
-
"Placeholder#{
|
21
|
+
"Placeholder#{graph.get_placeholder_counter}:#{@rank}"
|
19
22
|
end
|
20
23
|
end
|
21
|
-
end
|
24
|
+
end
|
data/lib/tensor_stream/session.rb
CHANGED
@@ -1,5 +1,7 @@
|
|
1
1
|
module TensorStream
|
2
|
+
# TensorStream class that defines a session
|
2
3
|
class Session
|
4
|
+
attr_reader :last_session_context
|
3
5
|
def initialize(evaluator = :ruby_evaluator, thread_pool_class: Concurrent::ImmediateExecutor)
|
4
6
|
@evaluator_class = Object.const_get("TensorStream::Evaluator::#{camelize(evaluator.to_s)}")
|
5
7
|
@thread_pool = thread_pool_class.new
|
@@ -9,25 +11,21 @@ module TensorStream
|
|
9
11
|
@session ||= Session.new
|
10
12
|
end
|
11
13
|
|
12
|
-
def last_session_context
|
13
|
-
@last_session_context
|
14
|
-
end
|
15
|
-
|
16
14
|
def run(*args)
|
17
15
|
options = if args.last.is_a?(Hash)
|
18
|
-
|
19
|
-
|
20
|
-
|
21
|
-
|
16
|
+
args.pop
|
17
|
+
else
|
18
|
+
{}
|
19
|
+
end
|
22
20
|
context = {}
|
23
21
|
|
24
22
|
# scan for placeholders and assign value
|
25
|
-
options[:feed_dict]
|
26
|
-
|
27
|
-
context[k.name.to_sym] = options[:feed_dict][k]
|
23
|
+
if options[:feed_dict]
|
24
|
+
options[:feed_dict].keys.each do |k|
|
25
|
+
context[k.name.to_sym] = options[:feed_dict][k] if k.is_a?(Placeholder)
|
28
26
|
end
|
29
|
-
end
|
30
|
-
|
27
|
+
end
|
28
|
+
|
31
29
|
evaluator = @evaluator_class.new(self, context.merge!(retain: options[:retain]), thread_pool: @thread_pool)
|
32
30
|
|
33
31
|
execution_context = {}
|
@@ -37,16 +35,16 @@ module TensorStream
|
|
37
35
|
end
|
38
36
|
|
39
37
|
def dump_internal_ops(tensor)
|
40
|
-
dump_ops(tensor, ->(
|
38
|
+
dump_ops(tensor, ->(_k, n) { n.is_a?(Tensor) && n.internal? })
|
41
39
|
end
|
42
40
|
|
43
41
|
def dump_user_ops(tensor)
|
44
|
-
dump_ops(tensor, ->(
|
42
|
+
dump_ops(tensor, ->(_k, n) { n.is_a?(Tensor) && !n.internal? })
|
45
43
|
end
|
46
44
|
|
47
45
|
def dump_ops(tensor, selector)
|
48
46
|
graph = tensor.graph
|
49
|
-
graph.nodes.select { |k,v| selector.call(k, v) }.collect do |k, node|
|
47
|
+
graph.nodes.select { |k, v| selector.call(k, v) }.collect do |k, node|
|
50
48
|
next unless @last_session_context[node.name]
|
51
49
|
"#{k} #{node.to_math(true, 1)} = #{@last_session_context[node.name]}"
|
52
50
|
end.compact
|
@@ -55,12 +53,12 @@ module TensorStream
|
|
55
53
|
private
|
56
54
|
|
57
55
|
def camelize(string, uppercase_first_letter = true)
|
58
|
-
if uppercase_first_letter
|
59
|
-
|
60
|
-
|
61
|
-
|
62
|
-
|
56
|
+
string = if uppercase_first_letter
|
57
|
+
string.sub(/^[a-z\d]*/) { $&.capitalize }
|
58
|
+
else
|
59
|
+
string.sub(/^(?:(?=\b|[A-Z_])|\w)/) { $&.downcase }
|
60
|
+
end
|
63
61
|
string.gsub(/(?:_|(\/))([a-z\d]*)/) { "#{$1}#{$2.capitalize}" }.gsub('/', '::')
|
64
62
|
end
|
65
63
|
end
|
66
|
-
end
|
64
|
+
end
|
data/lib/tensor_stream/tensor.rb
CHANGED
@@ -1,48 +1,19 @@
|
|
1
1
|
require 'ostruct'
|
2
2
|
|
3
3
|
module TensorStream
|
4
|
+
# Base class that defines a tensor like interface
|
4
5
|
class Tensor
|
5
6
|
include OpHelper
|
6
7
|
|
7
8
|
attr_accessor :name, :data_type, :shape, :rank, :native_buffer, :is_const, :value, :breakpoint, :internal, :source, :given_name, :graph
|
8
9
|
|
9
|
-
def self.const_name
|
10
|
-
@const_counter ||= 0
|
11
|
-
|
12
|
-
name = if @const_counter == 0
|
13
|
-
""
|
14
|
-
else
|
15
|
-
"_#{@const_counter}"
|
16
|
-
end
|
17
|
-
|
18
|
-
@const_counter += 1
|
19
|
-
|
20
|
-
name
|
21
|
-
end
|
22
|
-
|
23
|
-
def self.var_name
|
24
|
-
@var_counter ||= 0
|
25
|
-
@var_counter += 1
|
26
|
-
|
27
|
-
return "" if @var_counter == 1
|
28
|
-
return "_#{@var_counter}"
|
29
|
-
end
|
30
|
-
|
31
|
-
def self.placeholder_name
|
32
|
-
@placeholder_counter ||= 0
|
33
|
-
@placeholder_counter += 1
|
34
|
-
|
35
|
-
return "" if @placeholder_counter == 1
|
36
|
-
return "_#{@placeholder_counter}"
|
37
|
-
end
|
38
|
-
|
39
10
|
def initialize(data_type, rank, shape, options = {})
|
40
11
|
@data_type = data_type
|
41
12
|
@rank = rank
|
42
13
|
@breakpoint = false
|
43
14
|
@shape = TensorShape.new(shape, rank)
|
44
15
|
@value = nil
|
45
|
-
@source =
|
16
|
+
@source = format_source(caller_locations)
|
46
17
|
@is_const = options[:const] || false
|
47
18
|
@internal = options[:internal]
|
48
19
|
@graph = options[:graph] || TensorStream.get_default_graph
|
@@ -50,16 +21,14 @@ module TensorStream
|
|
50
21
|
@given_name = @name
|
51
22
|
|
52
23
|
if options[:value]
|
53
|
-
if options[:value].
|
24
|
+
if options[:value].is_a?(Array)
|
54
25
|
# check if single dimenstion array is passed
|
55
|
-
if shape.size >= 2 && options[:value].
|
56
|
-
options[:value] = reshape(options[:value], shape.reverse.dup)
|
57
|
-
end
|
26
|
+
options[:value] = reshape(options[:value], shape.reverse.dup) if shape.size >= 2 && !options[:value].empty? && !options[:value][0].is_a?(Array)
|
58
27
|
|
59
28
|
@value = options[:value].collect do |v|
|
60
|
-
v.
|
29
|
+
v.is_a?(Tensor) ? Tensor.cast_dtype(v, data_type) : v
|
61
30
|
end
|
62
|
-
elsif shape.
|
31
|
+
elsif !shape.empty?
|
63
32
|
@value = reshape(Tensor.cast_dtype(options[:value], @data_type), shape.dup)
|
64
33
|
else
|
65
34
|
@value = Tensor.cast_dtype(options[:value], @data_type)
|
@@ -83,66 +52,56 @@ module TensorStream
|
|
83
52
|
@placeholder_counter = 0
|
84
53
|
end
|
85
54
|
|
86
|
-
def
|
87
|
-
|
88
|
-
NArray.sfloat(@shape.cols * @shape.rows)
|
89
|
-
elsif @data_type == :float32 && @rank == 0
|
90
|
-
NArray.sfloat(1)
|
91
|
-
else
|
92
|
-
raise "Invalid data type #{@data_type}"
|
93
|
-
end
|
94
|
-
end
|
95
|
-
|
96
|
-
def +(operand)
|
97
|
-
TensorStream::Operation.new(:add, self, auto_wrap(operand))
|
55
|
+
def +(other)
|
56
|
+
TensorStream::Operation.new(:add, self, auto_wrap(other))
|
98
57
|
end
|
99
58
|
|
100
59
|
def [](index)
|
101
60
|
TensorStream::Operation.new(:index, self, index)
|
102
61
|
end
|
103
62
|
|
104
|
-
def *(
|
105
|
-
TensorStream::Operation.new(:mul, self, auto_wrap(
|
63
|
+
def *(other)
|
64
|
+
TensorStream::Operation.new(:mul, self, auto_wrap(other))
|
106
65
|
end
|
107
66
|
|
108
|
-
def **(
|
109
|
-
TensorStream::Operation.new(:pow, self, auto_wrap(
|
67
|
+
def **(other)
|
68
|
+
TensorStream::Operation.new(:pow, self, auto_wrap(other))
|
110
69
|
end
|
111
70
|
|
112
|
-
def /(
|
113
|
-
TensorStream::Operation.new(:div, self, auto_wrap(
|
71
|
+
def /(other)
|
72
|
+
TensorStream::Operation.new(:div, self, auto_wrap(other))
|
114
73
|
end
|
115
74
|
|
116
|
-
def -(
|
117
|
-
TensorStream::Operation.new(:sub, self, auto_wrap(
|
75
|
+
def -(other)
|
76
|
+
TensorStream::Operation.new(:sub, self, auto_wrap(other))
|
118
77
|
end
|
119
78
|
|
120
79
|
def -@
|
121
80
|
TensorStream::Operation.new(:negate, self, nil)
|
122
81
|
end
|
123
82
|
|
124
|
-
def ==(
|
125
|
-
op(:equal, self,
|
83
|
+
def ==(other)
|
84
|
+
op(:equal, self, other)
|
126
85
|
end
|
127
86
|
|
128
|
-
def <(
|
129
|
-
op(:less, self,
|
87
|
+
def <(other)
|
88
|
+
op(:less, self, other)
|
130
89
|
end
|
131
90
|
|
132
|
-
def !=(
|
133
|
-
op(:not_equal, self,
|
91
|
+
def !=(other)
|
92
|
+
op(:not_equal, self, other)
|
134
93
|
end
|
135
94
|
|
136
|
-
def >(
|
137
|
-
op(:greater, self,
|
95
|
+
def >(other)
|
96
|
+
op(:greater, self, other)
|
138
97
|
end
|
139
98
|
|
140
|
-
def >=(
|
141
|
-
op(:greater_equal, self,
|
99
|
+
def >=(other)
|
100
|
+
op(:greater_equal, self, other)
|
142
101
|
end
|
143
102
|
|
144
|
-
def <=(
|
145
|
-
op(:less_equal, self,
|
103
|
+
def <=(other)
|
104
|
+
op(:less_equal, self, other)
|
146
105
|
end
|
147
106
|
|
148
107
|
def collect(&block)
|
@@ -153,23 +112,6 @@ module TensorStream
|
|
153
112
|
@name
|
154
113
|
end
|
155
114
|
|
156
|
-
# def to_ary
|
157
|
-
# if rank == 2
|
158
|
-
# @native_buffer.to_a.each_slice(shape.cols).collect { |slice| slice }
|
159
|
-
# else
|
160
|
-
# raise "Invalid rank"
|
161
|
-
# end
|
162
|
-
# end
|
163
|
-
|
164
|
-
# open cl methods
|
165
|
-
def open_cl_buffer(context)
|
166
|
-
@cl_buffer ||= context.create_buffer(@native_buffer.size * @native_buffer.element_size, :flags => OpenCL::Mem::COPY_HOST_PTR, :host_ptr => @native_buffer)
|
167
|
-
end
|
168
|
-
|
169
|
-
def sync_cl_buffer(queue, events = [])
|
170
|
-
queue.enqueue_read_buffer(@cl_buffer, @native_buffer, :event_wait_list => events)
|
171
|
-
end
|
172
|
-
|
173
115
|
def eval(options = {})
|
174
116
|
Session.default_session.run(self, options)
|
175
117
|
end
|
@@ -201,21 +143,21 @@ module TensorStream
|
|
201
143
|
end
|
202
144
|
|
203
145
|
def to_math(name_only = false, max_depth = 99)
|
204
|
-
return @name if max_depth
|
205
|
-
|
206
|
-
if @value.
|
207
|
-
@value.collect { |v| v.
|
146
|
+
return @name if max_depth.zero? || name_only || @value.nil?
|
147
|
+
|
148
|
+
if @value.is_a?(Array)
|
149
|
+
@value.collect { |v| v.is_a?(Tensor) ? v.to_math(name_only, max_depth - 1) : v }
|
208
150
|
else
|
209
151
|
is_const ? @value : @name
|
210
152
|
end
|
211
153
|
end
|
212
154
|
|
213
155
|
def auto_math(tensor, name_only = false, max_depth = 99)
|
214
|
-
tensor.
|
156
|
+
tensor.is_a?(Tensor) ? tensor.to_math(name_only, max_depth) : tensor
|
215
157
|
end
|
216
158
|
|
217
159
|
def self.detect_type(value)
|
218
|
-
|
160
|
+
if value.is_a?(String)
|
219
161
|
:string
|
220
162
|
elsif value.is_a?(Float)
|
221
163
|
:float32
|
@@ -257,7 +199,7 @@ module TensorStream
|
|
257
199
|
when :unknown
|
258
200
|
val
|
259
201
|
else
|
260
|
-
|
202
|
+
raise "unknown data_type #{dtype} passed"
|
261
203
|
end
|
262
204
|
end
|
263
205
|
|
@@ -268,15 +210,15 @@ module TensorStream
|
|
268
210
|
|
269
211
|
protected
|
270
212
|
|
271
|
-
def
|
272
|
-
trace.reject { |c| c.to_s.include?(File.join(
|
213
|
+
def format_source(trace)
|
214
|
+
trace.reject { |c| c.to_s.include?(File.join('lib', 'tensor_stream')) }.first
|
273
215
|
end
|
274
216
|
|
275
217
|
def hashify_tensor(tensor)
|
276
|
-
if tensor.
|
218
|
+
if tensor.is_a?(Tensor)
|
277
219
|
tensor.to_h
|
278
|
-
elsif tensor.
|
279
|
-
tensor.collect
|
220
|
+
elsif tensor.is_a?(Array)
|
221
|
+
tensor.collect { |t| hashify_tensor(t) }
|
280
222
|
else
|
281
223
|
tensor
|
282
224
|
end
|
@@ -290,11 +232,11 @@ module TensorStream
|
|
290
232
|
reshape(s, shape)
|
291
233
|
end
|
292
234
|
else
|
293
|
-
return arr if shape.
|
235
|
+
return arr if shape.empty?
|
294
236
|
slice = shape.shift
|
295
237
|
return arr if slice.nil?
|
296
238
|
|
297
|
-
slice
|
239
|
+
Array.new(slice) do
|
298
240
|
reshape(arr, shape.dup)
|
299
241
|
end
|
300
242
|
end
|
@@ -311,7 +253,7 @@ module TensorStream
|
|
311
253
|
end
|
312
254
|
|
313
255
|
def build_name
|
314
|
-
|
256
|
+
@is_const ? "Const#{graph.get_const_counter}:#{@rank}" : "Variable#{graph.get_var_counter}:#{@rank}"
|
315
257
|
end
|
316
258
|
end
|
317
259
|
end
|