tensor_stream 0.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/.gitignore +12 -0
- data/.rake_tasks~ +0 -0
- data/.rspec +2 -0
- data/.travis.yml +5 -0
- data/CODE_OF_CONDUCT.md +74 -0
- data/Gemfile +4 -0
- data/LICENSE.txt +21 -0
- data/README.md +123 -0
- data/Rakefile +6 -0
- data/bin/console +14 -0
- data/bin/setup +8 -0
- data/lib/tensor_stream.rb +138 -0
- data/lib/tensor_stream/control_flow.rb +23 -0
- data/lib/tensor_stream/evaluator/evaluator.rb +7 -0
- data/lib/tensor_stream/evaluator/operation_helpers/random_gaussian.rb +32 -0
- data/lib/tensor_stream/evaluator/ruby_evaluator.rb +749 -0
- data/lib/tensor_stream/graph.rb +98 -0
- data/lib/tensor_stream/graph_keys.rb +5 -0
- data/lib/tensor_stream/helpers/op_helper.rb +58 -0
- data/lib/tensor_stream/math_gradients.rb +161 -0
- data/lib/tensor_stream/monkey_patches/integer.rb +0 -0
- data/lib/tensor_stream/nn/nn_ops.rb +17 -0
- data/lib/tensor_stream/operation.rb +195 -0
- data/lib/tensor_stream/ops.rb +225 -0
- data/lib/tensor_stream/placeholder.rb +21 -0
- data/lib/tensor_stream/session.rb +66 -0
- data/lib/tensor_stream/tensor.rb +317 -0
- data/lib/tensor_stream/tensor_shape.rb +25 -0
- data/lib/tensor_stream/train/gradient_descent_optimizer.rb +23 -0
- data/lib/tensor_stream/train/saver.rb +61 -0
- data/lib/tensor_stream/trainer.rb +7 -0
- data/lib/tensor_stream/types.rb +17 -0
- data/lib/tensor_stream/variable.rb +52 -0
- data/lib/tensor_stream/version.rb +7 -0
- data/samples/iris.data +150 -0
- data/samples/iris.rb +117 -0
- data/samples/linear_regression.rb +55 -0
- data/samples/raw_neural_net_sample.rb +54 -0
- data/tensor_stream.gemspec +40 -0
- metadata +185 -0
@@ -0,0 +1,25 @@
|
|
1
|
+
module TensorStream
  # Describes the shape of a tensor: its rank and the size of each dimension.
  class TensorShape
    attr_accessor :rank, :shape

    # shape - Array of dimension sizes (entries may be nil for unknown dims).
    # rank  - number of dimensions.
    def initialize(shape, rank)
      @shape = shape
      @rank = rank
    end

    # Human-readable form, e.g. "TensorShape([Dimension(2),Dimension(3)])".
    def to_s
      dims = @shape.map { |dim| "Dimension(#{dim})" }
      "TensorShape([#{dims.join(',')}])"
    end

    # Size of the dimension at +index+.
    def [](index)
      @shape[index]
    end

    # Number of dimensions actually present in the shape array.
    def ndims
      shape.size
    end
  end
end
|
@@ -0,0 +1,23 @@
|
|
1
|
+
module TensorStream
  module Train
    # High Level implementation of the gradient descent algorithm
    class GradientDescentOptimizer
      attr_accessor :learning_rate

      # learning_rate - step size applied to each gradient update.
      # options       - currently unused; kept for interface compatibility.
      def initialize(learning_rate, options = {})
        @learning_rate = learning_rate
      end

      # Builds the ops that step every trainable global variable down the
      # gradient of +cost+. Returns an array of assign_sub operations,
      # one per trainable variable.
      def minimize(cost)
        graph = TensorStream::Graph.get_default_graph
        trainable_vars = graph.get_collection(GraphKeys::GLOBAL_VARIABLES).select(&:trainable)

        gradients = TensorStream.gradients(cost, trainable_vars)
        trainable_vars.map.with_index do |var, i|
          var.assign_sub(gradients[i] * @learning_rate)
        end
      end
    end
  end
end
|
@@ -0,0 +1,61 @@
|
|
1
|
+
require 'json'

module TensorStream
  module Train
    # Persists and restores a graph's global variable values as JSON.
    class Saver
      # Writes all global variables (plus the evaluated global_step, if
      # given) to a JSON file next to +outputfile+, with the step number
      # appended to the file name (e.g. "model-100"). Returns the directory
      # the file was written into.
      #
      # The keyword arguments mirror TensorFlow's Saver#save signature;
      # only global_step currently affects the output.
      def save(session, outputfile,
               global_step: nil,
               latest_filename: nil,
               meta_graph_suffix: 'meta',
               write_meta_graph: true,
               write_state: true,
               strip_default_attrs: false)
        vars = TensorStream::Graph.get_default_graph.get_collection(GraphKeys::GLOBAL_VARIABLES)

        gs = eval_global_step(session, global_step)

        variables = {}
        vars.each { |var| variables[var.name] = var.value }

        output_dump = {
          variables: variables,
          graph: {},
          global_step: gs
        }

        basename = File.basename(outputfile)
        path = File.dirname(outputfile)

        new_filename = File.join(path, [basename, gs].compact.join('-'))
        File.write(new_filename, output_dump.to_json)

        path
      end

      # Loads variable values from a JSON file previously written by #save
      # and assigns them back onto the graph's global variables.
      def restore(session, inputfile)
        input_dump = JSON.parse(File.read(inputfile))

        vars = TensorStream::Graph.get_default_graph.get_collection(GraphKeys::GLOBAL_VARIABLES)
        vars.each do |var|
          var.value = input_dump["variables"][var.name]
        end
      end

      private

      # Resolves +global_step+ to a concrete value: Tensors and tensor
      # names are looked up in the session's last run context; any other
      # non-nil value is coerced with to_i.
      def eval_global_step(session, global_step)
        return nil if global_step.nil?

        case global_step
        when Tensor
          session.last_session_context(global_step.name)
        when String, Symbol
          session.last_session_context(global_step)
        else
          global_step.to_i
        end
      end
    end
  end
end
|
@@ -0,0 +1,52 @@
|
|
1
|
+
module TensorStream
  # A mutable tensor whose value persists across session runs.
  class Variable < Tensor
    attr_accessor :trainable

    def initialize(data_type, rank, shape, options = {})
      @data_type = data_type
      @rank = rank
      @shape = TensorShape.new(shape, rank)
      @value = nil
      @source = set_source(caller_locations)
      @graph = options[:graph] || TensorStream.get_default_graph
      @name = options[:name] || build_name
      # NOTE(review): ivar keeps the original (misspelled) "initalizer"
      # spelling to avoid breaking anything that introspects it.
      @initalizer_tensor = options[:initializer] if options[:initializer]
      @trainable = options.fetch(:trainable, true)
      @graph.add_variable(self, options)
    end

    # Op that assigns the configured initializer tensor to this variable.
    def initializer
      @initalizer_tensor.shape = @shape
      assign(@initalizer_tensor)
    end

    # Op that overwrites the variable with +value+.
    def assign(value)
      Operation.new(:assign, self, value)
    end

    # Current in-memory value of the variable.
    def read_value
      @value
    end

    # Op that adds +value+ to the variable in place.
    def assign_add(value)
      Operation.new(:assign_add, self, value)
    end

    # Variables render as just their name in math dumps.
    def to_math(tensor, name_only = false, max_depth = 99)
      @name
    end

    # Op that subtracts +value+ from the variable in place.
    def assign_sub(value)
      Operation.new(:assign_sub, self, value)
    end

    # Groups the initializer ops of every variable in +collection+.
    def self.variables_initializer(collection)
      vars = TensorStream.get_default_graph.get_collection(collection)
      TensorStream.group(vars.map(&:initializer))
    end

    # Convenience initializer for the GLOBAL_VARIABLES collection.
    def self.global_variables_initializer
      variables_initializer(TensorStream::GraphKeys::GLOBAL_VARIABLES)
    end
  end
end
|
data/samples/iris.data
ADDED
@@ -0,0 +1,150 @@
|
|
1
|
+
5.1,3.5,1.4,0.2,Iris-setosa
|
2
|
+
4.9,3.0,1.4,0.2,Iris-setosa
|
3
|
+
4.7,3.2,1.3,0.2,Iris-setosa
|
4
|
+
4.6,3.1,1.5,0.2,Iris-setosa
|
5
|
+
5.0,3.6,1.4,0.2,Iris-setosa
|
6
|
+
5.4,3.9,1.7,0.4,Iris-setosa
|
7
|
+
4.6,3.4,1.4,0.3,Iris-setosa
|
8
|
+
5.0,3.4,1.5,0.2,Iris-setosa
|
9
|
+
4.4,2.9,1.4,0.2,Iris-setosa
|
10
|
+
4.9,3.1,1.5,0.1,Iris-setosa
|
11
|
+
5.4,3.7,1.5,0.2,Iris-setosa
|
12
|
+
4.8,3.4,1.6,0.2,Iris-setosa
|
13
|
+
4.8,3.0,1.4,0.1,Iris-setosa
|
14
|
+
4.3,3.0,1.1,0.1,Iris-setosa
|
15
|
+
5.8,4.0,1.2,0.2,Iris-setosa
|
16
|
+
5.7,4.4,1.5,0.4,Iris-setosa
|
17
|
+
5.4,3.9,1.3,0.4,Iris-setosa
|
18
|
+
5.1,3.5,1.4,0.3,Iris-setosa
|
19
|
+
5.7,3.8,1.7,0.3,Iris-setosa
|
20
|
+
5.1,3.8,1.5,0.3,Iris-setosa
|
21
|
+
5.4,3.4,1.7,0.2,Iris-setosa
|
22
|
+
5.1,3.7,1.5,0.4,Iris-setosa
|
23
|
+
4.6,3.6,1.0,0.2,Iris-setosa
|
24
|
+
5.1,3.3,1.7,0.5,Iris-setosa
|
25
|
+
4.8,3.4,1.9,0.2,Iris-setosa
|
26
|
+
5.0,3.0,1.6,0.2,Iris-setosa
|
27
|
+
5.0,3.4,1.6,0.4,Iris-setosa
|
28
|
+
5.2,3.5,1.5,0.2,Iris-setosa
|
29
|
+
5.2,3.4,1.4,0.2,Iris-setosa
|
30
|
+
4.7,3.2,1.6,0.2,Iris-setosa
|
31
|
+
4.8,3.1,1.6,0.2,Iris-setosa
|
32
|
+
5.4,3.4,1.5,0.4,Iris-setosa
|
33
|
+
5.2,4.1,1.5,0.1,Iris-setosa
|
34
|
+
5.5,4.2,1.4,0.2,Iris-setosa
|
35
|
+
4.9,3.1,1.5,0.1,Iris-setosa
|
36
|
+
5.0,3.2,1.2,0.2,Iris-setosa
|
37
|
+
5.5,3.5,1.3,0.2,Iris-setosa
|
38
|
+
4.9,3.1,1.5,0.1,Iris-setosa
|
39
|
+
4.4,3.0,1.3,0.2,Iris-setosa
|
40
|
+
5.1,3.4,1.5,0.2,Iris-setosa
|
41
|
+
5.0,3.5,1.3,0.3,Iris-setosa
|
42
|
+
4.5,2.3,1.3,0.3,Iris-setosa
|
43
|
+
4.4,3.2,1.3,0.2,Iris-setosa
|
44
|
+
5.0,3.5,1.6,0.6,Iris-setosa
|
45
|
+
5.1,3.8,1.9,0.4,Iris-setosa
|
46
|
+
4.8,3.0,1.4,0.3,Iris-setosa
|
47
|
+
5.1,3.8,1.6,0.2,Iris-setosa
|
48
|
+
4.6,3.2,1.4,0.2,Iris-setosa
|
49
|
+
5.3,3.7,1.5,0.2,Iris-setosa
|
50
|
+
5.0,3.3,1.4,0.2,Iris-setosa
|
51
|
+
7.0,3.2,4.7,1.4,Iris-versicolor
|
52
|
+
6.4,3.2,4.5,1.5,Iris-versicolor
|
53
|
+
6.9,3.1,4.9,1.5,Iris-versicolor
|
54
|
+
5.5,2.3,4.0,1.3,Iris-versicolor
|
55
|
+
6.5,2.8,4.6,1.5,Iris-versicolor
|
56
|
+
5.7,2.8,4.5,1.3,Iris-versicolor
|
57
|
+
6.3,3.3,4.7,1.6,Iris-versicolor
|
58
|
+
4.9,2.4,3.3,1.0,Iris-versicolor
|
59
|
+
6.6,2.9,4.6,1.3,Iris-versicolor
|
60
|
+
5.2,2.7,3.9,1.4,Iris-versicolor
|
61
|
+
5.0,2.0,3.5,1.0,Iris-versicolor
|
62
|
+
5.9,3.0,4.2,1.5,Iris-versicolor
|
63
|
+
6.0,2.2,4.0,1.0,Iris-versicolor
|
64
|
+
6.1,2.9,4.7,1.4,Iris-versicolor
|
65
|
+
5.6,2.9,3.6,1.3,Iris-versicolor
|
66
|
+
6.7,3.1,4.4,1.4,Iris-versicolor
|
67
|
+
5.6,3.0,4.5,1.5,Iris-versicolor
|
68
|
+
5.8,2.7,4.1,1.0,Iris-versicolor
|
69
|
+
6.2,2.2,4.5,1.5,Iris-versicolor
|
70
|
+
5.6,2.5,3.9,1.1,Iris-versicolor
|
71
|
+
5.9,3.2,4.8,1.8,Iris-versicolor
|
72
|
+
6.1,2.8,4.0,1.3,Iris-versicolor
|
73
|
+
6.3,2.5,4.9,1.5,Iris-versicolor
|
74
|
+
6.1,2.8,4.7,1.2,Iris-versicolor
|
75
|
+
6.4,2.9,4.3,1.3,Iris-versicolor
|
76
|
+
6.6,3.0,4.4,1.4,Iris-versicolor
|
77
|
+
6.8,2.8,4.8,1.4,Iris-versicolor
|
78
|
+
6.7,3.0,5.0,1.7,Iris-versicolor
|
79
|
+
6.0,2.9,4.5,1.5,Iris-versicolor
|
80
|
+
5.7,2.6,3.5,1.0,Iris-versicolor
|
81
|
+
5.5,2.4,3.8,1.1,Iris-versicolor
|
82
|
+
5.5,2.4,3.7,1.0,Iris-versicolor
|
83
|
+
5.8,2.7,3.9,1.2,Iris-versicolor
|
84
|
+
6.0,2.7,5.1,1.6,Iris-versicolor
|
85
|
+
5.4,3.0,4.5,1.5,Iris-versicolor
|
86
|
+
6.0,3.4,4.5,1.6,Iris-versicolor
|
87
|
+
6.7,3.1,4.7,1.5,Iris-versicolor
|
88
|
+
6.3,2.3,4.4,1.3,Iris-versicolor
|
89
|
+
5.6,3.0,4.1,1.3,Iris-versicolor
|
90
|
+
5.5,2.5,4.0,1.3,Iris-versicolor
|
91
|
+
5.5,2.6,4.4,1.2,Iris-versicolor
|
92
|
+
6.1,3.0,4.6,1.4,Iris-versicolor
|
93
|
+
5.8,2.6,4.0,1.2,Iris-versicolor
|
94
|
+
5.0,2.3,3.3,1.0,Iris-versicolor
|
95
|
+
5.6,2.7,4.2,1.3,Iris-versicolor
|
96
|
+
5.7,3.0,4.2,1.2,Iris-versicolor
|
97
|
+
5.7,2.9,4.2,1.3,Iris-versicolor
|
98
|
+
6.2,2.9,4.3,1.3,Iris-versicolor
|
99
|
+
5.1,2.5,3.0,1.1,Iris-versicolor
|
100
|
+
5.7,2.8,4.1,1.3,Iris-versicolor
|
101
|
+
6.3,3.3,6.0,2.5,Iris-virginica
|
102
|
+
5.8,2.7,5.1,1.9,Iris-virginica
|
103
|
+
7.1,3.0,5.9,2.1,Iris-virginica
|
104
|
+
6.3,2.9,5.6,1.8,Iris-virginica
|
105
|
+
6.5,3.0,5.8,2.2,Iris-virginica
|
106
|
+
7.6,3.0,6.6,2.1,Iris-virginica
|
107
|
+
4.9,2.5,4.5,1.7,Iris-virginica
|
108
|
+
7.3,2.9,6.3,1.8,Iris-virginica
|
109
|
+
6.7,2.5,5.8,1.8,Iris-virginica
|
110
|
+
7.2,3.6,6.1,2.5,Iris-virginica
|
111
|
+
6.5,3.2,5.1,2.0,Iris-virginica
|
112
|
+
6.4,2.7,5.3,1.9,Iris-virginica
|
113
|
+
6.8,3.0,5.5,2.1,Iris-virginica
|
114
|
+
5.7,2.5,5.0,2.0,Iris-virginica
|
115
|
+
5.8,2.8,5.1,2.4,Iris-virginica
|
116
|
+
6.4,3.2,5.3,2.3,Iris-virginica
|
117
|
+
6.5,3.0,5.5,1.8,Iris-virginica
|
118
|
+
7.7,3.8,6.7,2.2,Iris-virginica
|
119
|
+
7.7,2.6,6.9,2.3,Iris-virginica
|
120
|
+
6.0,2.2,5.0,1.5,Iris-virginica
|
121
|
+
6.9,3.2,5.7,2.3,Iris-virginica
|
122
|
+
5.6,2.8,4.9,2.0,Iris-virginica
|
123
|
+
7.7,2.8,6.7,2.0,Iris-virginica
|
124
|
+
6.3,2.7,4.9,1.8,Iris-virginica
|
125
|
+
6.7,3.3,5.7,2.1,Iris-virginica
|
126
|
+
7.2,3.2,6.0,1.8,Iris-virginica
|
127
|
+
6.2,2.8,4.8,1.8,Iris-virginica
|
128
|
+
6.1,3.0,4.9,1.8,Iris-virginica
|
129
|
+
6.4,2.8,5.6,2.1,Iris-virginica
|
130
|
+
7.2,3.0,5.8,1.6,Iris-virginica
|
131
|
+
7.4,2.8,6.1,1.9,Iris-virginica
|
132
|
+
7.9,3.8,6.4,2.0,Iris-virginica
|
133
|
+
6.4,2.8,5.6,2.2,Iris-virginica
|
134
|
+
6.3,2.8,5.1,1.5,Iris-virginica
|
135
|
+
6.1,2.6,5.6,1.4,Iris-virginica
|
136
|
+
7.7,3.0,6.1,2.3,Iris-virginica
|
137
|
+
6.3,3.4,5.6,2.4,Iris-virginica
|
138
|
+
6.4,3.1,5.5,1.8,Iris-virginica
|
139
|
+
6.0,3.0,4.8,1.8,Iris-virginica
|
140
|
+
6.9,3.1,5.4,2.1,Iris-virginica
|
141
|
+
6.7,3.1,5.6,2.4,Iris-virginica
|
142
|
+
6.9,3.1,5.1,2.3,Iris-virginica
|
143
|
+
5.8,2.7,5.1,1.9,Iris-virginica
|
144
|
+
6.8,3.2,5.9,2.3,Iris-virginica
|
145
|
+
6.7,3.3,5.7,2.5,Iris-virginica
|
146
|
+
6.7,3.0,5.2,2.3,Iris-virginica
|
147
|
+
6.3,2.5,5.0,1.9,Iris-virginica
|
148
|
+
6.5,3.0,5.2,2.0,Iris-virginica
|
149
|
+
6.2,3.4,5.4,2.3,Iris-virginica
|
150
|
+
5.9,3.0,5.1,1.8,Iris-virginica
|
data/samples/iris.rb
ADDED
@@ -0,0 +1,117 @@
|
|
1
|
+
require "bundler/setup"
require 'tensor_stream'
require 'pry-byebug'

# This neural network will predict the species of an iris based on sepal and petal size
# Dataset: http://en.wikipedia.org/wiki/Iris_flower_data_set

rows = File.readlines(File.join("samples", "iris.data")).map { |l| l.chomp.split(',') }

rows.shuffle!

label_encodings = {
  "Iris-setosa" => [1, 0, 0],
  "Iris-versicolor" => [0, 1, 0],
  "Iris-virginica" => [0, 0, 1]
}

x_data = rows.map { |row| row[0, 4].map(&:to_f) }
y_data = rows.map { |row| label_encodings[row[4]] }

# Normalize data values before feeding into network
normalize = ->(val, high, low) { (val - low) / (high - low) } # maps input to float between 0 and 1

columns = (0..3).map do |i|
  x_data.map { |row| row[i] }
end

# Hoisted: per-column min/max were recomputed for every single value in the
# original (O(rows * cols) scans per element); compute them once per column.
col_max = columns.map(&:max)
col_min = columns.map(&:min)

x_data.map! do |row|
  row.map.with_index do |val, j|
    normalize.(val, col_max[j], col_min[j])
  end
end

x_train = x_data.slice(0, 100)
y_train = y_data.slice(0, 100)

x_test = x_data.slice(100, 50)
y_test = y_data.slice(100, 50)

test_cases = []
x_train.each_with_index do |x, index|
  test_cases << [x, y_train[index]]
end

validation_cases = []
x_test.each_with_index do |x, index|
  validation_cases << [x, y_test[index]]
end

learning_rate = 0.1
num_steps = 500
batch_size = 128
display_step = 100

# Network Parameters
n_hidden_1 = 32 # 1st layer number of neurons
n_hidden_2 = 32 # 2nd layer number of neurons
num_classes = 3 # three iris species (comment previously mislabelled as MNIST digits)
num_input = 4   # sepal length/width, petal length/width
training_epochs = 10

tf = TensorStream

# tf Graph input
x = tf.placeholder("float", shape: [nil, num_input], name: 'x')
y = tf.placeholder("float", shape: [nil, num_classes], name: 'y')

# Store layers weight & bias
weights = {
  h1: tf.Variable(tf.random_normal([num_input, n_hidden_1]), name: 'h1'),
  h2: tf.Variable(tf.random_normal([n_hidden_1, n_hidden_2]), name: 'h2'),
  out: tf.Variable(tf.random_normal([n_hidden_2, num_classes]), name: 'out')
}

biases = {
  b1: tf.Variable(tf.random_normal([n_hidden_1]), name: 'b1'),
  b2: tf.Variable(tf.random_normal([n_hidden_2]), name: 'b2'),
  out: tf.Variable(tf.random_normal([num_classes]), name: 'b_out')
}

# Create model
def neural_net(x, weights, biases)
  # Hidden fully connected layer
  layer_1 = TensorStream.add(TensorStream.matmul(x, weights[:h1]), biases[:b1], name: 'layer1_add')
  # Hidden fully connected layer
  layer_2 = TensorStream.add(TensorStream.matmul(layer_1, weights[:h2]), biases[:b2], name: 'layer2_add')
  # Output fully connected layer with a neuron for each class
  TensorStream.matmul(layer_2, weights[:out]) + biases[:out]
end

# Construct model
logits = neural_net(x, weights, biases)

# Mean squared error
cost = TensorStream.reduce_sum(TensorStream.pow(logits - y, 2)) / (2 * y_train.size)
optimizer = TensorStream::Train::GradientDescentOptimizer.new(learning_rate).minimize(cost)

# Initialize the variables (i.e. assign their default value)
init = TensorStream.global_variables_initializer()

TensorStream.Session do |sess|
  puts "init vars"
  sess.run(init)
  puts "Testing the untrained network..."
  # BUG FIX: the original called sess.run(loss) on the already-evaluated
  # numeric loss and printed the "before training" label without the value.
  loss = sess.run(cost, feed_dict: { x => x_train, y => y_train })
  puts "loss before training #{loss}"
  (0..training_epochs).each do |epoch|
    sess.run(optimizer, feed_dict: { x => x_train, y => y_train })
    loss = sess.run(cost, feed_dict: { x => x_train, y => y_train })
    puts "loss #{loss}"
  end
  loss = sess.run(cost, feed_dict: { x => x_train, y => y_train })
  puts "loss after training #{loss}"
end
|