simple_neural_network 0.0.1 → 0.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (4):
  1. checksums.yaml +4 -4
  2. data/lib/layer.rb +19 -11
  3. data/lib/network.rb +15 -5
  4. metadata +17 -3
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA1:
3
- metadata.gz: 4ddf41db0522abafd12d51c425246499522c36cd
4
- data.tar.gz: 4b576ca87629cd9d98845b3d02c2c41e0a700cf1
3
+ metadata.gz: c118791ad07a0945cbc340ca50e05e5817ce1bd0
4
+ data.tar.gz: c1b921e8d031942788fec9148b06f8b7175a727e
5
5
  SHA512:
6
- metadata.gz: d11cff2d470e190617dc1ed71b3db48d6585b65b728e9c6490dcc4711535febc34c16130461bfedc1c871b63938906685c36ca34aef2e7882067eb5fdbff17ef
7
- data.tar.gz: 6e69176cc45847c10af9d662d70964ca0ff5e868bc1a7dff32ff6ff7f9b0fc1cec62dfd6567f10206f33b5b7a951fc7e30e7a55aa7f7ec86b133c956f370b633
6
+ metadata.gz: 2e70cbbee76e460546772478ff0d1e8c6e5e037bf546b0870c9b31dd36dbb3722de2f37f71793d8984f6b4db5012c0b96a3696e2e47e294ab6d89a7d63741249
7
+ data.tar.gz: a386118b62a17c49949191d39a65a3ae7a4844d6b1ee175448d7ea4272b9db257ad383a09032ff4b728be240664d733d35ab605bbd07214f836236b121860d4b
data/lib/layer.rb CHANGED
@@ -1,4 +1,5 @@
1
1
  require_relative "neuron"
2
+ require "nmatrix"
2
3
 
3
4
  class SimpleNeuralNetwork
4
5
  class Layer
@@ -22,6 +23,7 @@ class SimpleNeuralNetwork
22
23
  @next_layer = nil
23
24
 
24
25
  populate_neurons
26
+ edge_matrix # Caches edge matrix
25
27
  end
26
28
 
27
29
  # The method that drives network output resolution.
@@ -39,18 +41,11 @@ class SimpleNeuralNetwork
39
41
  # + (prev_layer.neurons[1] * prev_layer.neurons[1].edges[i])
40
42
  # + ...
41
43
  # ) + self.neurons[i].bias
44
+ prev_output = prev_layer.get_output
45
+ prev_output_matrix = NMatrix.new([prev_output.length, 1], prev_output, dtype: :float64)
42
46
 
43
- prev_layer_output = prev_layer.get_output
44
-
45
- # Generate the output values for the layer
46
- (0..@size-1).map do |i|
47
- value = 0
48
-
49
- prev_layer_output.each_with_index do |output, index|
50
- value += (output * prev_layer.neurons[index].edges[i])
51
- end
52
-
53
- value + @neurons[i].bias
47
+ result = (edge_matrix.dot(prev_output_matrix)).each_with_index.map do |val, i|
48
+ val + @neurons[i].bias
54
49
  end
55
50
  end
56
51
  end
@@ -63,6 +58,19 @@ class SimpleNeuralNetwork
63
58
  end
64
59
  end
65
60
 
61
+ def edge_matrix
62
+ return unless prev_layer
63
+
64
+ @edge_matrix ||= begin
65
+ elements = prev_layer.neurons.map{|a| a.edges}
66
+ NMatrix.new([elements.count, elements[0].count], elements.flatten, dtype: :float64).transpose
67
+ end
68
+ end
69
+
70
+ def clear_edge_cache
71
+ @edge_matrix = nil
72
+ end
73
+
66
74
  private
67
75
 
68
76
  def populate_neurons
data/lib/network.rb CHANGED
@@ -23,7 +23,7 @@ class SimpleNeuralNetwork
23
23
 
24
24
  attr_accessor :inputs
25
25
 
26
- attr_writer :normalization_function
26
+ attr_accessor :normalization_function
27
27
 
28
28
  attr_accessor :edge_initialization_function
29
29
  attr_accessor :neuron_bias_initialization_function
@@ -41,16 +41,20 @@ class SimpleNeuralNetwork
41
41
  # Accepts an array of input integers between 0 and 1
42
42
  # Input array length must be equal to the size of the first layer.
43
43
  # Returns an array of outputs.
44
- def run(inputs)
45
- unless inputs.size == input_size && inputs.all? { |input| input >= 0 && input <= 1 }
46
- raise InvalidInputError.new("Invalid input passed to Network#run")
44
+ #
45
+ # skip_validation: Skips validations that may be expensive for large sets
46
+ def run(inputs, skip_validation: false)
47
+ unless skip_validation
48
+ unless inputs.size == input_size && inputs.all? { |input| input >= 0 && input <= 1 }
49
+ raise InvalidInputError.new("Invalid input passed to Network#run")
50
+ end
47
51
  end
48
52
 
49
53
  @inputs = inputs
50
54
 
51
55
  # Get output from last layer. It recursively depends on layers before it.
52
56
  @layers[-1].get_output.map do |output|
53
- @normalization_function.call(output)
57
+ (@normalization_function || method(:default_normalization_function)).call(output)
54
58
  end
55
59
  end
56
60
 
@@ -84,6 +88,12 @@ class SimpleNeuralNetwork
84
88
  @normalization_function = method(:default_normalization_function)
85
89
  end
86
90
 
91
+ def clear_edge_caches
92
+ @layers.each do |layer|
93
+ layer.clear_edge_cache
94
+ end
95
+ end
96
+
87
97
  # Serializes the neural network into a JSON string. This can later be deserialized back into a Network object
88
98
  # Useful for storing partially trained neural networks.
89
99
  # Note: Currently does not serialize bias init function, edge init function, or normalization function
metadata CHANGED
@@ -1,7 +1,7 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: simple_neural_network
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.0.1
4
+ version: 0.1.0
5
5
  platform: ruby
6
6
  authors:
7
7
  - Nathaniel Woodthorpe
@@ -9,7 +9,21 @@ autorequire:
9
9
  bindir: bin
10
10
  cert_chain: []
11
11
  date: 2018-03-11 00:00:00.000000000 Z
12
- dependencies: []
12
+ dependencies:
13
+ - !ruby/object:Gem::Dependency
14
+ name: nmatrix
15
+ requirement: !ruby/object:Gem::Requirement
16
+ requirements:
17
+ - - "~>"
18
+ - !ruby/object:Gem::Version
19
+ version: '0.2'
20
+ type: :runtime
21
+ prerelease: false
22
+ version_requirements: !ruby/object:Gem::Requirement
23
+ requirements:
24
+ - - "~>"
25
+ - !ruby/object:Gem::Version
26
+ version: '0.2'
13
27
  description: A simple neural network implementation in Ruby.
14
28
  email: njwoodthorpe@gmail.com
15
29
  executables: []
@@ -40,7 +54,7 @@ required_rubygems_version: !ruby/object:Gem::Requirement
40
54
  version: '0'
41
55
  requirements: []
42
56
  rubyforge_project:
43
- rubygems_version: 2.5.2
57
+ rubygems_version: 2.6.13
44
58
  signing_key:
45
59
  specification_version: 4
46
60
  summary: A simple neural network implementation in Ruby.