ocarina 0.0.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- data/.gitignore +23 -0
- data/Gemfile +4 -0
- data/LICENSE +204 -0
- data/NOTICE +13 -0
- data/README.md +83 -0
- data/Rakefile +146 -0
- data/data/.gitkeep +0 -0
- data/data/images/letterpress/board1.png +0 -0
- data/data/images/letterpress/board2.png +0 -0
- data/data/images/letterpress/board3.png +0 -0
- data/data/images/letterpress/board4.png +0 -0
- data/data/images/letterpress/board5.png +0 -0
- data/data/images/noise/.gitkeep +0 -0
- data/data/images/noise/0.gif +0 -0
- data/data/images/noise/1.gif +0 -0
- data/data/images/noise/2.gif +0 -0
- data/data/images/noise/3.gif +0 -0
- data/data/images/noise/4.gif +0 -0
- data/data/images/noise/5.gif +0 -0
- data/data/images/noise/6.gif +0 -0
- data/data/images/noise/7.gif +0 -0
- data/data/images/noise/8.gif +0 -0
- data/data/images/noise/9.gif +0 -0
- data/data/images/noise/A.gif +0 -0
- data/data/images/noise/B.gif +0 -0
- data/data/images/noise/C.gif +0 -0
- data/data/images/noise/D.gif +0 -0
- data/data/images/noise/E.gif +0 -0
- data/data/images/noise/F.gif +0 -0
- data/data/images/noise/G.gif +0 -0
- data/data/images/noise/H.gif +0 -0
- data/data/images/noise/I.gif +0 -0
- data/data/images/noise/J.gif +0 -0
- data/data/images/noise/K.gif +0 -0
- data/data/images/noise/L.gif +0 -0
- data/data/images/noise/M.gif +0 -0
- data/data/images/noise/N.gif +0 -0
- data/data/images/noise/O.gif +0 -0
- data/data/images/noise/P.gif +0 -0
- data/data/images/noise/Q.gif +0 -0
- data/data/images/noise/R.gif +0 -0
- data/data/images/noise/S.gif +0 -0
- data/data/images/noise/T.gif +0 -0
- data/data/images/noise/U.gif +0 -0
- data/data/images/noise/V.gif +0 -0
- data/data/images/noise/W.gif +0 -0
- data/data/images/noise/X.gif +0 -0
- data/data/images/noise/Y.gif +0 -0
- data/data/images/noise/Z.gif +0 -0
- data/data/images/noise/a_lower.gif +0 -0
- data/data/images/noise/b_lower.gif +0 -0
- data/data/images/noise/c_lower.gif +0 -0
- data/data/images/noise/d_lower.gif +0 -0
- data/data/images/noise/e_lower.gif +0 -0
- data/data/images/noise/f_lower.gif +0 -0
- data/data/images/noise/g_lower.gif +0 -0
- data/data/images/noise/h_lower.gif +0 -0
- data/data/images/noise/i_lower.gif +0 -0
- data/data/images/noise/j_lower.gif +0 -0
- data/data/images/noise/k_lower.gif +0 -0
- data/data/images/noise/l_lower.gif +0 -0
- data/data/images/noise/m_lower.gif +0 -0
- data/data/images/noise/n_lower.gif +0 -0
- data/data/images/noise/o_lower.gif +0 -0
- data/data/images/noise/p_lower.gif +0 -0
- data/data/images/noise/q_lower.gif +0 -0
- data/data/images/noise/r_lower.gif +0 -0
- data/data/images/noise/s_lower.gif +0 -0
- data/data/images/noise/t_lower.gif +0 -0
- data/data/images/noise/u_lower.gif +0 -0
- data/data/images/noise/v_lower.gif +0 -0
- data/data/images/noise/w_lower.gif +0 -0
- data/data/images/noise/x_lower.gif +0 -0
- data/data/images/noise/y_lower.gif +0 -0
- data/data/images/noise/z_lower.gif +0 -0
- data/data/images/reference/.gitkeep +0 -0
- data/data/images/reference/0.gif +0 -0
- data/data/images/reference/1.gif +0 -0
- data/data/images/reference/2.gif +0 -0
- data/data/images/reference/3.gif +0 -0
- data/data/images/reference/4.gif +0 -0
- data/data/images/reference/5.gif +0 -0
- data/data/images/reference/6.gif +0 -0
- data/data/images/reference/7.gif +0 -0
- data/data/images/reference/8.gif +0 -0
- data/data/images/reference/9.gif +0 -0
- data/data/images/reference/A.gif +0 -0
- data/data/images/reference/B.gif +0 -0
- data/data/images/reference/C.gif +0 -0
- data/data/images/reference/D.gif +0 -0
- data/data/images/reference/E.gif +0 -0
- data/data/images/reference/F.gif +0 -0
- data/data/images/reference/G.gif +0 -0
- data/data/images/reference/H.gif +0 -0
- data/data/images/reference/I.gif +0 -0
- data/data/images/reference/J.gif +0 -0
- data/data/images/reference/K.gif +0 -0
- data/data/images/reference/L.gif +0 -0
- data/data/images/reference/M.gif +0 -0
- data/data/images/reference/N.gif +0 -0
- data/data/images/reference/O.gif +0 -0
- data/data/images/reference/P.gif +0 -0
- data/data/images/reference/Q.gif +0 -0
- data/data/images/reference/R.gif +0 -0
- data/data/images/reference/S.gif +0 -0
- data/data/images/reference/T.gif +0 -0
- data/data/images/reference/U.gif +0 -0
- data/data/images/reference/V.gif +0 -0
- data/data/images/reference/W.gif +0 -0
- data/data/images/reference/X.gif +0 -0
- data/data/images/reference/Y.gif +0 -0
- data/data/images/reference/Z.gif +0 -0
- data/data/images/reference/a_lower.gif +0 -0
- data/data/images/reference/b_lower.gif +0 -0
- data/data/images/reference/c_lower.gif +0 -0
- data/data/images/reference/d_lower.gif +0 -0
- data/data/images/reference/e_lower.gif +0 -0
- data/data/images/reference/f_lower.gif +0 -0
- data/data/images/reference/g_lower.gif +0 -0
- data/data/images/reference/h_lower.gif +0 -0
- data/data/images/reference/i_lower.gif +0 -0
- data/data/images/reference/j_lower.gif +0 -0
- data/data/images/reference/k_lower.gif +0 -0
- data/data/images/reference/l_lower.gif +0 -0
- data/data/images/reference/m_lower.gif +0 -0
- data/data/images/reference/n_lower.gif +0 -0
- data/data/images/reference/o_lower.gif +0 -0
- data/data/images/reference/p_lower.gif +0 -0
- data/data/images/reference/q_lower.gif +0 -0
- data/data/images/reference/r_lower.gif +0 -0
- data/data/images/reference/s_lower.gif +0 -0
- data/data/images/reference/t_lower.gif +0 -0
- data/data/images/reference/u_lower.gif +0 -0
- data/data/images/reference/v_lower.gif +0 -0
- data/data/images/reference/w_lower.gif +0 -0
- data/data/images/reference/x_lower.gif +0 -0
- data/data/images/reference/y_lower.gif +0 -0
- data/data/images/reference/z_lower.gif +0 -0
- data/data/letterpress-train.bin +0 -0
- data/data/train.bin +0 -0
- data/lib/ocarina.rb +12 -0
- data/lib/ocarina/character_generator.rb +87 -0
- data/lib/ocarina/config.rb +35 -0
- data/lib/ocarina/error_stats.rb +62 -0
- data/lib/ocarina/letterpress_character_generator.rb +93 -0
- data/lib/ocarina/letterpress_cropper.rb +84 -0
- data/lib/ocarina/network.rb +264 -0
- data/lib/ocarina/util.rb +92 -0
- data/lib/ocarina/version.rb +3 -0
- data/ocarina.gemspec +20 -0
- data/ocarina.png +0 -0
- data/test/test_letterpress.rb +60 -0
- metadata +231 -0
|
@@ -0,0 +1,62 @@
|
|
|
1
|
+
require_relative 'util'
|
|
2
|
+
|
|
3
|
+
module Ocarina

  # Keeps track of running errors during evaluations.
  #
  # Counts how many characters, and how many output bits, have been checked
  # and how many of them were wrong, so accuracy can be reported as a
  # percentage.
  #
  class ErrorStats

    # config must respond to #num_outputs: the number of output bits the
    # network produces for each character.
    #
    def initialize(config)
      @num_outputs = config.num_outputs

      @chars_seen = 0
      @chars_wrong = 0
      @bits_seen = 0
      @bits_wrong = 0
    end

    # Check and record the error between the expected and actual character
    # codes (integers). Prints a line for every wrongly guessed character.
    #
    def check_error(expected, actual)
      @chars_seen += 1
      @bits_seen += @num_outputs

      return if expected == actual

      @chars_wrong += 1
      puts "char wrong, expected: #{expected.chr}, guessed: #{actual.chr}"

      # The number of wrong bits is the Hamming distance between the two codes,
      # i.e. the number of set bits in their XOR. This needs no string
      # conversion helper and cannot be misaligned by binary strings of
      # different lengths.
      @bits_wrong += (expected ^ actual).to_s(2).count('1')
    end

    # Count the positions at which strings a and b differ.
    #
    # Characters of a past the end of b count as differences; characters of b
    # past the end of a are ignored, so callers should pass equal-size strings.
    #
    def count_differences(a, b)
      a.each_char.with_index.count { |char, i| char != b[i] }
    end

    # Print totals and accuracy percentages to stdout.
    #
    def report
      puts "total characters evaluated: #@chars_seen"
      puts "total characters wrong : #@chars_wrong"
      puts "character accuracy : #{'%.2f' % character_accuracy}"
      puts "bit accuracy : #{'%.2f' % bit_accuracy}"
    end

    # Percentage (0-100) of checked characters that were guessed correctly.
    # Returns 0.0 before any character has been checked (instead of NaN).
    #
    def character_accuracy
      percentage(@chars_seen - @chars_wrong, @chars_seen)
    end

    # Percentage (0-100) of checked output bits that were correct.
    # Returns 0.0 before any character has been checked (instead of NaN).
    #
    def bit_accuracy
      percentage(@bits_seen - @bits_wrong, @bits_seen)
    end

    private

    # part / total as a percentage; 0.0 when total is zero.
    def percentage(part, total)
      return 0.0 if total.zero?

      part.to_f / total * 100
    end

  end

end
|
|
@@ -0,0 +1,93 @@
|
|
|
1
|
+
require 'RMagick'
|
|
2
|
+
|
|
3
|
+
module Ocarina

  # Creates bitmap images for characters using Letterpress board game images.
  #
  # Reference character images are cropped from the tiles of three example
  # boards whose letters are known ahead of time; together the three boards
  # provide an image for every letter A-Z. Noise images are generated for
  # every character of config.char_set.
  #
  class LetterpressCharacterGenerator < CharacterGenerator

    # Example boards (generated randomly by the game), processed in this
    # order: the image file name under IMAGES_DIR/letterpress and the letters
    # on its tiles, row by row, in the order LetterpressCropper#crop returns
    # the tiles.
    EXAMPLE_BOARDS = [
      ['board1.png', [
        %w[P R B R Z],
        %w[T A V Z R],
        %w[B D A K Y],
        %w[G I G K F],
        %w[R Y S J V]
      ]],
      ['board2.png', [
        %w[Q D F P M],
        %w[N E E S I],
        %w[A W F M L],
        %w[F R P T T],
        %w[K C S S Y]
      ]],
      ['board3.png', [
        %w[L H F L M],
        %w[R V P U K],
        %w[V O E E X],
        %w[I N R I T],
        %w[V N S I Q]
      ]]
    ].freeze
    private_constant :EXAMPLE_BOARDS

    def initialize(config)
      @config = config

      # reference images, keyed by character
      @reference_image_hash = process_letterpress_example_boards

      # noise images, one per character of the configured character set
      @noise_image_hash = {}
      @config.char_set.each_char { |char| @noise_image_hash[char] = generate_noise_gif_for_char(char) }
    end

    private

    # Crop every example board and combine the tiles into one hash of
    # char -> image. When a letter appears on several tiles, the tile
    # processed last wins.
    #
    def process_letterpress_example_boards
      EXAMPLE_BOARDS.each_with_object({}) do |(file_name, character_map), tiles_by_char|
        tiles_by_char.merge!(create_letterpress_tiles("#{IMAGES_DIR}/letterpress/#{file_name}", character_map))
      end
    end

    # Read the board image at input_file, crop it into tiles and pair each
    # tile with the letter at the same row/column of character_map.
    # Returns a hash of char -> tile image.
    #
    def create_letterpress_tiles(input_file, character_map)
      board_image = Magick::Image.read(input_file).first
      tile_rows = LetterpressCropper.new(@config).crop(board_image)

      tile_rows.zip(character_map).each_with_object({}) do |(tile_row, char_row), tiles_by_char|
        tile_row.zip(char_row) { |tile, char| tiles_by_char[char] = tile }
      end
    end

  end

end
|
|
@@ -0,0 +1,84 @@
|
|
|
1
|
+
require 'RMagick'
|
|
2
|
+
|
|
3
|
+
module Ocarina

  # Geometry of a Letterpress board screenshot: boards are resized to
  # LETTERPRESS_EXPECTED_WIDTH x LETTERPRESS_EXPECTED_HEIGHT (640 x 1136) and
  # the 5 x 5 grid of 128-pixel tiles starts LETTERPRESS_HEIGHT_OFFSET pixels
  # from the top (496 + 5 * 128 == 1136).
  LETTERPRESS_TILES_ACROSS = 5
  LETTERPRESS_TILES_DOWN = 5
  LETTERPRESS_TILE_PIXELS = 128
  LETTERPRESS_HEIGHT_OFFSET = 496
  LETTERPRESS_EXPECTED_WIDTH = LETTERPRESS_TILES_ACROSS * LETTERPRESS_TILE_PIXELS
  LETTERPRESS_EXPECTED_HEIGHT = 1136

  # Creates tiles of character images from Letterpress game boards.
  #
  class LetterpressCropper
    include Ocarina::Util

    def initialize(config)
      @config = config
    end

    # Crop the board into tiles, run network's recognizer on every tile, and
    # return the recognized characters as an array of rows, each row an
    # array of one-character strings.
    #
    def decipher_board(network, board_image)
      crop(board_image).map do |tile_row|
        tile_row.map { |tile| network.recognize(tile).chr }
      end
    end

    # Return LETTERPRESS_TILES_DOWN rows of LETTERPRESS_TILES_ACROSS tile
    # images, each resized to config.char_width x config.char_height.
    #
    def crop(image)
      image = image.resize(LETTERPRESS_EXPECTED_WIDTH, LETTERPRESS_EXPECTED_HEIGHT) unless letterpress_sized?(image)
      image = quantize_image(image)

      border = 1
      min_bound_width = 0.75 * @config.char_width

      Array.new(LETTERPRESS_TILES_DOWN) do |row_index|
        top = LETTERPRESS_HEIGHT_OFFSET + row_index * LETTERPRESS_TILE_PIXELS

        Array.new(LETTERPRESS_TILES_ACROSS) do |col_index|
          left = col_index * LETTERPRESS_TILE_PIXELS
          cut_tile(image, left, top, border, min_bound_width)
        end
      end
    end

    private

    # True when image already has the expected screenshot dimensions.
    def letterpress_sized?(image)
      image.columns == LETTERPRESS_EXPECTED_WIDTH && image.rows == LETTERPRESS_EXPECTED_HEIGHT
    end

    # Cut one tile whose top-left grid corner is (left, top), tighten it to
    # the bounding box of its content when that box is wider than
    # min_bound_width, and resize it to the configured character size.
    #
    def cut_tile(image, left, top, border, min_bound_width)
      side = LETTERPRESS_TILE_PIXELS - border

      # NOTE(review): x subtracts the border while y adds it, so the first
      # column's crop starts at x = -1. Confirm whether `left + border` was
      # intended; changing it may alter tiles that existing trained networks
      # were built from.
      tile = image.crop(left - border, top + border, side, side, true)

      box = tile.bounding_box
      if box.width > min_bound_width
        tile = tile.crop(box.x - border, box.y - border, box.width + 2 * border, box.height + 2 * border, true)
      end

      tile.resize!(@config.char_width, @config.char_height)
      tile
    end
  end

end
|
|
@@ -0,0 +1,264 @@
|
|
|
1
|
+
require_relative 'util'
|
|
2
|
+
|
|
3
|
+
module Ocarina

  # A network of neurons: a feed-forward network with one hidden layer that
  # maps the pixels of a character image to the bits of the character's code,
  # trained by back-propagation.
  #
  class Network

    include Ocarina::Util

    attr_accessor :current_error, :config

    # config must respond to #num_inputs (one input node per image bit) and
    # #num_outputs (one output node per bit of the character code).
    #
    def initialize(config)
      @config = config
      @num_inputs = config.num_inputs # total of bits in the image

      # fixed size; an earlier variant used (1.5 * num_inputs).to_i
      @hidden_count = 25

      @input_values = []   # image bits
      @input_weights = []  # weights from inputs -> hidden nodes, indexed [input][hidden]

      @hidden_outputs = [] # after feed-forward, what the hidden nodes output

      @output_weights = [] # weights from hidden nodes -> output nodes, indexed [hidden][output]
      @output_values = []  # after feed-forward, what the output nodes output

      @output_errors = []  # error terms of the output nodes from the last training step
      @hidden_errors = []  # error terms of the hidden nodes from the last training step

      assign_random_weights

      total_input_weights = @input_weights.map { |array| array.reduce(:+) }.reduce(:+)
      puts "total_input_weights: #{total_input_weights}"

      total_output_weights = @output_weights.map { |array| array.reduce(:+) }.reduce(:+)
      puts "total_output_weights: #{total_output_weights}"
    end

    # Attempt to recognize the character displayed on the given image.
    # image should be an instance of Magick::Image.
    #
    # Returns the integer ASCII code for the recognized character: the output
    # nodes, each rounded to 0 or 1, read as a binary number.
    #
    def recognize(image)
      # quantize to two-color
      image = quantize_image(image)

      # feed the image bits forward through the network
      assign_inputs image
      calculate_hidden_outputs
      calculate_final_outputs

      result_as_binary_string.to_i(2)
    end

    # Train the network on the image, using target_char as the expected result.
    #
    # image should be an instance of Magick::Image.
    #
    def train(image, target_char)
      # quantize to two-color
      image = quantize_image(image)

      # the binary string we expect to see from the output nodes
      @target_binary_string = char_to_binary_string(target_char)

      # forward pass
      assign_inputs image
      calculate_hidden_outputs
      calculate_final_outputs

      # backward pass: compute error terms, then adjust the weights
      calculate_output_errors
      calculate_hidden_errors
      adjust_output_weights
      adjust_input_weights
    end

    # Persist the network (including its config) to filepath using Marshal.
    #
    def save_network_to_file(filepath)
      # Marshal output is binary data: open in binary mode so no newline or
      # encoding conversion can corrupt it.
      File.open(filepath, 'wb') do |file|
        Marshal.dump(self, file)
      end
    end

    # Load a previously-trained network saved by #save_network_to_file.
    #
    # SECURITY: Marshal.load can instantiate arbitrary objects; only load
    # files from trusted sources.
    #
    def self.load_network_from_file(filepath)
      File.open(filepath, 'rb') do |file|
        Marshal.load(file)
      end
    end

    private

    # Set one input value per pixel of image (pixel n goes to input n).
    #
    # NOTE(review): values beyond the image's pixel count are left over from
    # earlier images, so images are presumably always config-sized; confirm
    # with callers.
    #
    def assign_inputs(image)
      num_pixels = image.rows * image.columns

      num_pixels.times do |n|
        col = pixel_number_to_col(n, image)
        row = pixel_number_to_row(n, image)
        @input_values[n] = pixel_to_bit(image.pixel_color(col, row))
      end
    end

    # Initialize every weight to a random positive value, scaled so that the
    # weights of each layer sum to less than 1.
    #
    def assign_random_weights
      weight_range = 0.0001..0.9

      # input -> hidden weights (random draws in input-major order)
      input_scale = @num_inputs * @hidden_count
      @input_weights = Array.new(@num_inputs) do
        Array.new(@hidden_count) { rand(weight_range) / input_scale }
      end

      # hidden -> output weights (random draws in hidden-major order)
      output_scale = @hidden_count * @config.num_outputs
      @output_weights = Array.new(@hidden_count) do
        Array.new(@config.num_outputs) { rand(weight_range) / output_scale }
      end
    end

    # Forward pass, input -> hidden layer.
    #
    def calculate_hidden_outputs
      @hidden_count.times do |hidden|
        sum = 0
        @input_values.each_with_index do |value, input|
          sum += value * @input_weights[input][hidden]
        end

        @hidden_outputs[hidden] = sigma(sum)
      end
    end

    # Forward pass, hidden -> output layer.
    #
    def calculate_final_outputs
      @config.num_outputs.times do |output|
        sum = 0
        @hidden_count.times do |hidden|
          sum += @hidden_outputs[hidden] * @output_weights[hidden][output]
        end

        @output_values[output] = sigma sum
      end
    end

    # Error term of each output node: (target - out) * out * (1 - out).
    # Also records the root of the summed squared error terms in
    # @current_error.
    #
    def calculate_output_errors
      accum_error = 0

      @config.num_outputs.times do |output|
        expected = @target_binary_string[output].to_i
        value = @output_values[output]
        error = (expected - value) * (1.0 - value) * value

        accum_error += error ** 2
        @output_errors[output] = error
      end

      # TODO: @current_error only represents the error from the last trained character.
      # We should keep a running average for the current set of characters for each training "run".
      #
      @current_error = Math.sqrt(accum_error)
    end

    # Error term of each hidden node: out * (1 - out) * weighted sum of the
    # output error terms it feeds.
    #
    def calculate_hidden_errors
      @hidden_count.times do |hidden|
        sum = 0
        @config.num_outputs.times do |output|
          sum += (@output_errors[output] * @output_weights[hidden][output])
        end

        @hidden_errors[hidden] = @hidden_outputs[hidden] * (1 - @hidden_outputs[hidden]) * sum
      end
    end

    def adjust_output_weights
      @hidden_count.times do |hidden|
        @config.num_outputs.times do |output|
          @output_weights[hidden][output] += (@output_errors[output] * @hidden_outputs[hidden])
        end
      end
    end

    def adjust_input_weights
      @num_inputs.times do |input|
        @hidden_count.times do |hidden|
          @input_weights[input][hidden] += @hidden_errors[hidden] * @input_values[input]
        end
      end
    end

    # Quantize each of the output nodes to zero or one, return as array.
    #
    def quantized_result
      @output_values.map(&:round)
    end

    # The quantized output nodes joined into a string of '0'/'1' characters.
    #
    def result_as_binary_string
      quantized_result.join
    end

  end

end
|