ai4r 1.13 → 2.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (129)
  1. checksums.yaml +7 -0
  2. data/README.md +174 -0
  3. data/examples/classifiers/hyperpipes_data.csv +14 -0
  4. data/examples/classifiers/hyperpipes_example.rb +22 -0
  5. data/examples/classifiers/ib1_example.rb +12 -0
  6. data/examples/classifiers/id3_example.rb +15 -10
  7. data/examples/classifiers/id3_graphviz_example.rb +17 -0
  8. data/examples/classifiers/logistic_regression_example.rb +11 -0
  9. data/examples/classifiers/naive_bayes_attributes_example.rb +13 -0
  10. data/examples/classifiers/naive_bayes_example.rb +12 -13
  11. data/examples/classifiers/one_r_example.rb +27 -0
  12. data/examples/classifiers/parameter_tutorial.rb +29 -0
  13. data/examples/classifiers/prism_nominal_example.rb +15 -0
  14. data/examples/classifiers/prism_numeric_example.rb +21 -0
  15. data/examples/classifiers/simple_linear_regression_example.rb +14 -11
  16. data/examples/classifiers/zero_and_one_r_example.rb +34 -0
  17. data/examples/classifiers/zero_one_r_data.csv +8 -0
  18. data/examples/clusterers/clusterer_example.rb +40 -34
  19. data/examples/clusterers/dbscan_example.rb +17 -0
  20. data/examples/clusterers/dendrogram_example.rb +17 -0
  21. data/examples/clusterers/hierarchical_dendrogram_example.rb +20 -0
  22. data/examples/clusterers/kmeans_custom_example.rb +26 -0
  23. data/examples/genetic_algorithm/bitstring_example.rb +41 -0
  24. data/examples/genetic_algorithm/genetic_algorithm_example.rb +26 -18
  25. data/examples/genetic_algorithm/kmeans_seed_tuning.rb +45 -0
  26. data/examples/neural_network/backpropagation_example.rb +48 -48
  27. data/examples/neural_network/hopfield_example.rb +45 -0
  28. data/examples/neural_network/patterns_with_base_noise.rb +39 -39
  29. data/examples/neural_network/patterns_with_noise.rb +41 -39
  30. data/examples/neural_network/train_epochs_callback.rb +25 -0
  31. data/examples/neural_network/training_patterns.rb +39 -39
  32. data/examples/neural_network/transformer_text_classification.rb +78 -0
  33. data/examples/neural_network/xor_example.rb +23 -22
  34. data/examples/reinforcement/q_learning_example.rb +10 -0
  35. data/examples/som/som_data.rb +155 -152
  36. data/examples/som/som_multi_node_example.rb +12 -13
  37. data/examples/som/som_single_example.rb +12 -15
  38. data/examples/transformer/decode_classifier_example.rb +68 -0
  39. data/examples/transformer/deterministic_example.rb +10 -0
  40. data/examples/transformer/seq2seq_example.rb +16 -0
  41. data/lib/ai4r/classifiers/classifier.rb +24 -16
  42. data/lib/ai4r/classifiers/gradient_boosting.rb +64 -0
  43. data/lib/ai4r/classifiers/hyperpipes.rb +119 -43
  44. data/lib/ai4r/classifiers/ib1.rb +122 -32
  45. data/lib/ai4r/classifiers/id3.rb +524 -145
  46. data/lib/ai4r/classifiers/logistic_regression.rb +96 -0
  47. data/lib/ai4r/classifiers/multilayer_perceptron.rb +75 -59
  48. data/lib/ai4r/classifiers/naive_bayes.rb +95 -34
  49. data/lib/ai4r/classifiers/one_r.rb +112 -44
  50. data/lib/ai4r/classifiers/prism.rb +167 -76
  51. data/lib/ai4r/classifiers/random_forest.rb +72 -0
  52. data/lib/ai4r/classifiers/simple_linear_regression.rb +83 -58
  53. data/lib/ai4r/classifiers/support_vector_machine.rb +91 -0
  54. data/lib/ai4r/classifiers/votes.rb +57 -0
  55. data/lib/ai4r/classifiers/zero_r.rb +71 -30
  56. data/lib/ai4r/clusterers/average_linkage.rb +46 -27
  57. data/lib/ai4r/clusterers/bisecting_k_means.rb +50 -44
  58. data/lib/ai4r/clusterers/centroid_linkage.rb +52 -36
  59. data/lib/ai4r/clusterers/cluster_tree.rb +50 -0
  60. data/lib/ai4r/clusterers/clusterer.rb +29 -14
  61. data/lib/ai4r/clusterers/complete_linkage.rb +42 -31
  62. data/lib/ai4r/clusterers/dbscan.rb +134 -0
  63. data/lib/ai4r/clusterers/diana.rb +75 -49
  64. data/lib/ai4r/clusterers/k_means.rb +270 -135
  65. data/lib/ai4r/clusterers/median_linkage.rb +49 -33
  66. data/lib/ai4r/clusterers/single_linkage.rb +196 -88
  67. data/lib/ai4r/clusterers/ward_linkage.rb +51 -35
  68. data/lib/ai4r/clusterers/ward_linkage_hierarchical.rb +25 -10
  69. data/lib/ai4r/clusterers/weighted_average_linkage.rb +48 -32
  70. data/lib/ai4r/data/data_set.rb +223 -103
  71. data/lib/ai4r/data/parameterizable.rb +31 -25
  72. data/lib/ai4r/data/proximity.rb +62 -62
  73. data/lib/ai4r/data/statistics.rb +46 -35
  74. data/lib/ai4r/experiment/classifier_evaluator.rb +84 -32
  75. data/lib/ai4r/experiment/split.rb +39 -0
  76. data/lib/ai4r/genetic_algorithm/chromosome_base.rb +43 -0
  77. data/lib/ai4r/genetic_algorithm/genetic_algorithm.rb +92 -170
  78. data/lib/ai4r/genetic_algorithm/tsp_chromosome.rb +83 -0
  79. data/lib/ai4r/hmm/hidden_markov_model.rb +134 -0
  80. data/lib/ai4r/neural_network/activation_functions.rb +37 -0
  81. data/lib/ai4r/neural_network/backpropagation.rb +399 -134
  82. data/lib/ai4r/neural_network/hopfield.rb +175 -58
  83. data/lib/ai4r/neural_network/transformer.rb +194 -0
  84. data/lib/ai4r/neural_network/weight_initializations.rb +40 -0
  85. data/lib/ai4r/reinforcement/policy_iteration.rb +66 -0
  86. data/lib/ai4r/reinforcement/q_learning.rb +51 -0
  87. data/lib/ai4r/search/a_star.rb +76 -0
  88. data/lib/ai4r/search/bfs.rb +50 -0
  89. data/lib/ai4r/search/dfs.rb +50 -0
  90. data/lib/ai4r/search/mcts.rb +118 -0
  91. data/lib/ai4r/search.rb +12 -0
  92. data/lib/ai4r/som/distance_metrics.rb +29 -0
  93. data/lib/ai4r/som/layer.rb +28 -17
  94. data/lib/ai4r/som/node.rb +61 -32
  95. data/lib/ai4r/som/som.rb +158 -41
  96. data/lib/ai4r/som/two_phase_layer.rb +21 -25
  97. data/lib/ai4r/version.rb +3 -0
  98. data/lib/ai4r.rb +57 -28
  99. metadata +79 -109
  100. data/README.rdoc +0 -39
  101. data/test/classifiers/hyperpipes_test.rb +0 -84
  102. data/test/classifiers/ib1_test.rb +0 -78
  103. data/test/classifiers/id3_test.rb +0 -220
  104. data/test/classifiers/multilayer_perceptron_test.rb +0 -79
  105. data/test/classifiers/naive_bayes_test.rb +0 -43
  106. data/test/classifiers/one_r_test.rb +0 -62
  107. data/test/classifiers/prism_test.rb +0 -85
  108. data/test/classifiers/simple_linear_regression_test.rb +0 -37
  109. data/test/classifiers/zero_r_test.rb +0 -50
  110. data/test/clusterers/average_linkage_test.rb +0 -51
  111. data/test/clusterers/bisecting_k_means_test.rb +0 -66
  112. data/test/clusterers/centroid_linkage_test.rb +0 -53
  113. data/test/clusterers/complete_linkage_test.rb +0 -57
  114. data/test/clusterers/diana_test.rb +0 -69
  115. data/test/clusterers/k_means_test.rb +0 -167
  116. data/test/clusterers/median_linkage_test.rb +0 -53
  117. data/test/clusterers/single_linkage_test.rb +0 -122
  118. data/test/clusterers/ward_linkage_hierarchical_test.rb +0 -81
  119. data/test/clusterers/ward_linkage_test.rb +0 -53
  120. data/test/clusterers/weighted_average_linkage_test.rb +0 -53
  121. data/test/data/data_set_test.rb +0 -104
  122. data/test/data/proximity_test.rb +0 -87
  123. data/test/data/statistics_test.rb +0 -65
  124. data/test/experiment/classifier_evaluator_test.rb +0 -76
  125. data/test/genetic_algorithm/chromosome_test.rb +0 -57
  126. data/test/genetic_algorithm/genetic_algorithm_test.rb +0 -81
  127. data/test/neural_network/backpropagation_test.rb +0 -82
  128. data/test/neural_network/hopfield_test.rb +0 -72
  129. data/test/som/som_test.rb +0 -97
@@ -0,0 +1,20 @@
1
+ # frozen_string_literal: true
2
+
3
+ # Demonstrates recording the merge tree using WardLinkageHierarchical
4
+ # and printing a simple dendrogram.
5
+
6
+ require 'ai4r'
7
+
8
+ points = [[1, 1], [1, 2], [2, 1], [2, 2], [8, 8], [8, 9], [9, 8], [9, 9]]
9
+ data_set = Ai4r::Data::DataSet.new(data_items: points)
10
+
11
+ clusterer = Ai4r::Clusterers::WardLinkageHierarchical.new
12
+ clusterer.build(data_set, 1)
13
+
14
+ puts 'Dendrogram:'
15
+ clusterer.cluster_tree.each_with_index do |clusters, level|
16
+ puts "Level #{level}:"
17
+ clusters.each_with_index do |cluster, idx|
18
+ puts " Cluster #{idx}: #{cluster.data_items.inspect}"
19
+ end
20
+ end
@@ -0,0 +1,26 @@
1
+ # frozen_string_literal: true
2
+
3
+ # This example shows KMeans with a custom distance function and deterministic initialization.
4
+ # It also prints the number of iterations and the sum of squared errors (SSE).
5
+
6
+ require 'ai4r'
7
+
8
+ # Simple two-cluster data set
9
+ points = [[1, 1], [1, 2], [2, 1], [2, 2], [8, 8], [8, 9], [9, 8], [9, 9]]
10
+ data_set = Ai4r::Data::DataSet.new(data_items: points)
11
+
12
+ # Manhattan distance instead of the default squared Euclidean distance
13
+ manhattan = lambda do |a, b|
14
+ a.zip(b).map { |x, y| (x - y).abs }.reduce(:+)
15
+ end
16
+
17
+ kmeans = Ai4r::Clusterers::KMeans.new
18
+ kmeans.set_parameters(distance_function: manhattan, random_seed: 1)
19
+ .build(data_set, 2)
20
+
21
+ kmeans.clusters.each_with_index do |cluster, idx|
22
+ puts "Cluster #{idx}: #{cluster.data_items.inspect}"
23
+ end
24
+
25
+ puts "Iterations: #{kmeans.iterations}"
26
+ puts "SSE: #{kmeans.sse}"
@@ -0,0 +1,41 @@
1
+ # frozen_string_literal: true
2
+
3
+ # Example using a custom chromosome to maximise ones in a bit string
4
+ require_relative '../../lib/ai4r/genetic_algorithm/genetic_algorithm'
5
+
6
+ # Represents a chromosome consisting of a bit string.
7
+ class BitStringChromosome < Ai4r::GeneticAlgorithm::ChromosomeBase
8
+ LENGTH = 16
9
+
10
+ def fitness
11
+ @data.count(1)
12
+ end
13
+
14
+ def self.seed
15
+ new(Array.new(LENGTH) { rand(2) })
16
+ end
17
+
18
+ def self.reproduce(parent_a, parent_b, crossover_rate = 0.4)
19
+ point = rand(LENGTH)
20
+ data = parent_a.data[0...point] + parent_b.data[point..]
21
+ data = parent_b.data[0...point] + parent_a.data[point..] if rand < crossover_rate
22
+ new(data)
23
+ end
24
+
25
+ def self.mutate(chromosome, mutation_rate = 0.3)
26
+ chromosome.data.map!.with_index do |bit, _|
27
+ if rand < ((1 - chromosome.normalized_fitness.to_f) * mutation_rate)
28
+ 1 - bit
29
+ else
30
+ bit
31
+ end
32
+ end
33
+ chromosome.instance_variable_set(:@fitness, nil)
34
+ end
35
+ end
36
+
37
+ search = Ai4r::GeneticAlgorithm::GeneticSearch.new(
38
+ 30, 50, BitStringChromosome, 0.2, 0.7, BitStringChromosome::LENGTH
39
+ )
40
+ best = search.run
41
+ puts "Best fitness #{best.fitness}: #{best.data.join}"
@@ -1,37 +1,45 @@
1
+ # frozen_string_literal: true
2
+
1
3
  # Author:: Sergio Fierens
2
4
  # License:: MPL 1.1
3
5
  # Project:: ai4r
4
6
  # Url:: http://www.ai4r.org/
5
7
  #
6
- # You can redistribute it and/or modify it under the terms of
7
- # the Mozilla Public License version 1.1 as published by the
8
+ # You can redistribute it and/or modify it under the terms of
9
+ # the Mozilla Public License version 1.1 as published by the
8
10
  # Mozilla Foundation at http://www.mozilla.org/MPL/MPL-1.1.txt
9
11
 
10
- require File.dirname(__FILE__) + '/../../lib/ai4r/genetic_algorithm/genetic_algorithm'
11
- require File.dirname(__FILE__) + '/../../lib/ai4r/data/data_set'
12
+ require_relative '../../lib/ai4r/genetic_algorithm/genetic_algorithm'
13
+ require_relative '../../lib/ai4r/data/data_set'
12
14
  require 'csv'
13
15
 
14
16
  # Load data from data_set.csv
15
17
  data_filename = "#{File.dirname(__FILE__)}/travel_cost.csv"
16
18
  data_set = Ai4r::Data::DataSet.new.load_csv_with_labels data_filename
17
- data_set.data_items.collect! {|column| column.collect {|element| element.to_f}}
19
+ data_set.data_items.collect! { |column| column.collect(&:to_f) }
18
20
 
19
- Ai4r::GeneticAlgorithm::Chromosome.set_cost_matrix(data_set.data_items)
21
+ Ai4r::GeneticAlgorithm::TspChromosome.set_cost_matrix(data_set.data_items)
20
22
 
21
- puts "Some random selected tours costs: "
23
+ puts 'Some random selected tours costs: '
22
24
  3.times do
23
- c = Ai4r::GeneticAlgorithm::Chromosome.seed
24
- puts "COST #{-1 * c.fitness} TOUR: "+
25
- "#{c.data.collect{|c| data_set.data_labels[c]} * ', '}"
25
+ c = Ai4r::GeneticAlgorithm::TspChromosome.seed
26
+ puts "COST #{-1 * c.fitness} TOUR: " \
27
+ "#{c.data.collect { |c| data_set.data_labels[c] } * ', '}"
26
28
  end
27
29
 
28
- puts "Beginning genetic search, please wait... "
29
- search = Ai4r::GeneticAlgorithm::GeneticSearch.new(800, 100)
30
+ puts 'Beginning genetic search, please wait... '
31
+ search = Ai4r::GeneticAlgorithm::GeneticSearch.new(
32
+ 800, 100, Ai4r::GeneticAlgorithm::TspChromosome,
33
+ 0.3, 0.4, nil, nil,
34
+ ->(generation, best) { puts "Generation #{generation}: best fitness #{best}" }
35
+ )
30
36
  result = search.run
31
- puts "COST #{-1 * result.fitness} TOUR: "+
32
- "#{result.data.collect{|c| data_set.data_labels[c]} * ', '}"
37
+ puts "COST #{-1 * result.fitness} TOUR: " \
38
+ "#{result.data.collect { |c| data_set.data_labels[c] } * ', '}"
33
39
 
34
- # $7611.99 TOUR: Moscow, Kiev, Warsaw, Hamburg, Berlin, Vienna, Munich, Milan, Rome, Barcelona, Madrid, Paris, Brussels, London, Dublin
35
- # $7659.81 TOUR: Moscow, Kiev, Warsaw, Vienna, Munich, Berlin, Hamburg, Brussels, Dublin, London, Paris, Milan, Rome, Barcelona, Madrid
36
- # $7596.74 TOUR: Moscow, Kiev, Warsaw, Berlin, Hamburg, Vienna, Munich, Milan, Rome, Barcelona, Madrid, Paris, Brussels, London Dublin
37
-
40
+ # $7611.99 TOUR: Moscow, Kiev, Warsaw, Hamburg, Berlin, Vienna, Munich, Milan,
41
+ # Rome, Barcelona, Madrid, Paris, Brussels, London, Dublin
42
+ # $7659.81 TOUR: Moscow, Kiev, Warsaw, Vienna, Munich, Berlin, Hamburg, Brussels,
43
+ # Dublin, London, Paris, Milan, Rome, Barcelona, Madrid
44
+ # $7596.74 TOUR: Moscow, Kiev, Warsaw, Berlin, Hamburg, Vienna, Munich, Milan,
45
+ # Rome, Barcelona, Madrid, Paris, Brussels, London Dublin
@@ -0,0 +1,45 @@
1
+ #!/usr/bin/env ruby
2
+ # frozen_string_literal: true
3
+
4
+ require_relative '../../lib/ai4r/genetic_algorithm/genetic_algorithm'
5
+ require_relative '../../lib/ai4r/clusterers/k_means'
6
+ require_relative '../som/som_data'
7
+
8
+ ##
9
+ # Running the genetic search without a fixed random seed leads to
10
+ # different SSE results each time, which causes flaky tests in CI.
11
+ # Explicitly seed Ruby's random number generator so that the script
12
+ # behaves deterministically across runs.
13
+ srand 1
14
+
15
+ # Chromosome used to search for a good random seed for KMeans.
16
+ class SeedChromosome < Ai4r::GeneticAlgorithm::ChromosomeBase
17
+ RANGE = (0..10)
18
+ DATA = Ai4r::Data::DataSet.new(data_items: SOM_DATA.first(30))
19
+
20
+ def fitness
21
+ return @fitness if @fitness
22
+
23
+ kmeans = Ai4r::Clusterers::KMeans.new.set_parameters(random_seed: @data).build(DATA, 3)
24
+ @fitness = -kmeans.sse
25
+ end
26
+
27
+ def self.seed
28
+ new(RANGE.to_a.sample)
29
+ end
30
+
31
+ def self.reproduce(parent_a, parent_b, _rate = 0.7)
32
+ new([parent_a.data, parent_b.data].sample)
33
+ end
34
+
35
+ def self.mutate(chrom, rate = 0.3)
36
+ return unless rand < rate
37
+
38
+ chrom.data = RANGE.to_a.sample
39
+ chrom.instance_variable_set(:@fitness, nil)
40
+ end
41
+ end
42
+
43
+ search = Ai4r::GeneticAlgorithm::GeneticSearch.new(6, 5, SeedChromosome, 0.1, 0.7)
44
+ best = search.run
45
+ puts(-best.fitness)
@@ -1,68 +1,68 @@
1
+ # frozen_string_literal: true
2
+
1
3
  # Author:: Sergio Fierens
2
4
  # License:: MPL 1.1
3
5
  # Project:: ai4r
4
6
  # Url:: http://www.ai4r.org/
5
7
  #
6
- # You can redistribute it and/or modify it under the terms of
7
- # the Mozilla Public License version 1.1 as published by the
8
+ # You can redistribute it and/or modify it under the terms of
9
+ # the Mozilla Public License version 1.1 as published by the
8
10
  # Mozilla Foundation at http://www.mozilla.org/MPL/MPL-1.1.txt
9
11
 
10
- require File.dirname(__FILE__) + '/training_patterns'
11
- require File.dirname(__FILE__) + '/patterns_with_noise'
12
- require File.dirname(__FILE__) + '/patterns_with_base_noise'
13
- require File.dirname(__FILE__) + '/../../lib/ai4r/neural_network/backpropagation'
12
+ require_relative 'training_patterns'
13
+ require_relative 'patterns_with_noise'
14
+ require_relative 'patterns_with_base_noise'
15
+ require_relative '../../lib/ai4r/neural_network/backpropagation'
14
16
  require 'benchmark'
15
17
 
16
18
  times = Benchmark.measure do
17
-
18
- srand 1
19
+ srand 1
19
20
 
20
- # creating network with 256 input-neurons, 3-neurons and 0 hidden layers
21
- net = Ai4r::NeuralNetwork::Backpropagation.new([256, 3])
22
-
23
- tr_input = TRIANGLE.flatten.collect { |input| input.to_f / 5.0}
24
- sq_input = SQUARE.flatten.collect { |input| input.to_f / 5.0}
25
- cr_input = CROSS.flatten.collect { |input| input.to_f / 5.0}
21
+ # creating network with 256 input-neurons, 3-neurons and 0 hidden layers
22
+ net = Ai4r::NeuralNetwork::Backpropagation.new([256, 3])
26
23
 
27
- tr_with_noise = TRIANGLE_WITH_NOISE.flatten.collect { |input| input.to_f / 5.0}
28
- sq_with_noise = SQUARE_WITH_NOISE.flatten.collect { |input| input.to_f / 5.0}
29
- cr_with_noise = CROSS_WITH_NOISE.flatten.collect { |input| input.to_f / 5.0}
24
+ tr_input = TRIANGLE.flatten.collect { |input| input.to_f / 5.0 }
25
+ sq_input = SQUARE.flatten.collect { |input| input.to_f / 5.0 }
26
+ cr_input = CROSS.flatten.collect { |input| input.to_f / 5.0 }
30
27
 
31
- tr_with_base_noise = TRIANGLE_WITH_BASE_NOISE.flatten.collect { |input| input.to_f / 5.0}
32
- sq_with_base_noise = SQUARE_WITH_BASE_NOISE.flatten.collect { |input| input.to_f / 5.0}
33
- cr_with_base_noise = CROSS_WITH_BASE_NOISE.flatten.collect { |input| input.to_f / 5.0}
28
+ tr_with_noise = TRIANGLE_WITH_NOISE.flatten.collect { |input| input.to_f / 5.0 }
29
+ sq_with_noise = SQUARE_WITH_NOISE.flatten.collect { |input| input.to_f / 5.0 }
30
+ cr_with_noise = CROSS_WITH_NOISE.flatten.collect { |input| input.to_f / 5.0 }
34
31
 
35
- puts "Training the network, please wait."
36
- 101.times do |i|
37
- error = net.train(tr_input, [1,0,0])
38
- error = net.train(sq_input, [0,1,0])
39
- error = net.train(cr_input, [0,0,1])
40
- puts "Error after iteration #{i}:\t#{error}" if i%20 == 0
41
- end
32
+ tr_with_base_noise = TRIANGLE_WITH_BASE_NOISE.flatten.collect { |input| input.to_f / 5.0 }
33
+ sq_with_base_noise = SQUARE_WITH_BASE_NOISE.flatten.collect { |input| input.to_f / 5.0 }
34
+ cr_with_base_noise = CROSS_WITH_BASE_NOISE.flatten.collect { |input| input.to_f / 5.0 }
35
+
36
+ puts 'Training the network, please wait.'
37
+ 101.times do |i|
38
+ net.train(tr_input, [1, 0, 0])
39
+ net.train(sq_input, [0, 1, 0])
40
+ error = net.train(cr_input, [0, 0, 1])
41
+ puts "Error after iteration #{i}:\t#{error}" if (i % 20).zero?
42
+ end
42
43
 
43
- def result_label(result)
44
- if result[0] > result[1] && result[0] > result[2]
45
- "TRIANGLE"
46
- elsif result[1] > result[2]
47
- "SQUARE"
48
- else
49
- "CROSS"
50
- end
44
+ def result_label(result)
45
+ if result[0] > result[1] && result[0] > result[2]
46
+ 'TRIANGLE'
47
+ elsif result[1] > result[2]
48
+ 'SQUARE'
49
+ else
50
+ 'CROSS'
51
51
  end
52
+ end
52
53
 
53
- puts "Training Examples"
54
- puts "#{net.eval(tr_input).inspect} => #{result_label(net.eval(tr_input))}"
55
- puts "#{net.eval(sq_input).inspect} => #{result_label(net.eval(sq_input))}"
56
- puts "#{net.eval(cr_input).inspect} => #{result_label(net.eval(cr_input))}"
57
- puts "Examples with noise"
58
- puts "#{net.eval(tr_with_noise).inspect} => #{result_label(net.eval(tr_with_noise))}"
59
- puts "#{net.eval(sq_with_noise).inspect} => #{result_label(net.eval(sq_with_noise))}"
60
- puts "#{net.eval(cr_with_noise).inspect} => #{result_label(net.eval(cr_with_noise))}"
61
- puts "Examples with base noise"
62
- puts "#{net.eval(tr_with_base_noise).inspect} => #{result_label(net.eval(tr_with_base_noise))}"
63
- puts "#{net.eval(sq_with_base_noise).inspect} => #{result_label(net.eval(sq_with_base_noise))}"
64
- puts "#{net.eval(cr_with_base_noise).inspect} => #{result_label(net.eval(cr_with_base_noise))}"
65
-
54
+ puts 'Training Examples'
55
+ puts "#{net.eval(tr_input).inspect} => #{result_label(net.eval(tr_input))}"
56
+ puts "#{net.eval(sq_input).inspect} => #{result_label(net.eval(sq_input))}"
57
+ puts "#{net.eval(cr_input).inspect} => #{result_label(net.eval(cr_input))}"
58
+ puts 'Examples with noise'
59
+ puts "#{net.eval(tr_with_noise).inspect} => #{result_label(net.eval(tr_with_noise))}"
60
+ puts "#{net.eval(sq_with_noise).inspect} => #{result_label(net.eval(sq_with_noise))}"
61
+ puts "#{net.eval(cr_with_noise).inspect} => #{result_label(net.eval(cr_with_noise))}"
62
+ puts 'Examples with base noise'
63
+ puts "#{net.eval(tr_with_base_noise).inspect} => #{result_label(net.eval(tr_with_base_noise))}"
64
+ puts "#{net.eval(sq_with_base_noise).inspect} => #{result_label(net.eval(sq_with_base_noise))}"
65
+ puts "#{net.eval(cr_with_base_noise).inspect} => #{result_label(net.eval(cr_with_base_noise))}"
66
66
  end
67
67
 
68
68
  puts "Elapsed time: #{times}"
@@ -0,0 +1,45 @@
1
+ # frozen_string_literal: true
2
+
3
+ # Author:: Sergio Fierens
4
+ # License:: MPL 1.1
5
+ # Project:: ai4r
6
+ # Url:: http://www.ai4r.org/
7
+ #
8
+ # You can redistribute it and/or modify it under the terms of
9
+ # the Mozilla Public License version 1.1 as published by the
10
+ # Mozilla Foundation at http://www.mozilla.org/MPL/MPL-1.1.txt
11
+
12
+ require_relative '../../lib/ai4r/neural_network/hopfield'
13
+ require_relative '../../lib/ai4r/data/data_set'
14
+
15
+ patterns = [
16
+ [1, 1, -1, -1, 1, 1, -1, -1, 1, 1, -1, -1, 1, 1, -1, -1],
17
+ [-1, -1, 1, 1, -1, -1, 1, 1, -1, -1, 1, 1, -1, -1, 1, 1],
18
+ [-1, -1, -1, -1, -1, -1, -1, -1, 1, 1, 1, 1, 1, 1, 1, 1],
19
+ [1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, -1, -1, -1, -1]
20
+ ]
21
+
22
+ noisy_patterns = [
23
+ [1, 1, -1, 1, 1, 1, -1, -1, 1, 1, -1, -1, 1, 1, 1, -1],
24
+ [-1, -1, 1, 1, 1, -1, 1, 1, -1, -1, 1, -1, -1, -1, 1, 1],
25
+ [-1, -1, -1, -1, -1, -1, -1, -1, 1, 1, 1, 1, 1, 1, -1, -1],
26
+ [-1, -1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, 1, -1, -1, -1]
27
+ ]
28
+
29
+ data = Ai4r::Data::DataSet.new(data_items: patterns)
30
+
31
+ # Use random asynchronous updates instead of the default sequential
32
+ # strategy. Random updates are closer to the original Hopfield
33
+ # formulation and may help the network escape shallow local minima.
34
+ net = Ai4r::NeuralNetwork::Hopfield.new
35
+ net.set_parameters(update_strategy: :async_random)
36
+ net.train(data)
37
+
38
+ puts 'Evaluation of noisy patterns:'
39
+ noisy_patterns.each do |p|
40
+ # Pass `trace: true` to record the energy of each iteration so we can
41
+ # inspect how the network converges to a memorized pattern.
42
+ trace = net.eval(p, trace: true)
43
+ puts "#{p.inspect} => #{trace[:states].last.inspect}"
44
+ puts "Energy trace: #{trace[:energies].inspect}"
45
+ end
@@ -1,31 +1,32 @@
1
+ # frozen_string_literal: true
2
+
1
3
  # Author:: Sergio Fierens
2
4
  # License:: MPL 1.1
3
5
  # Project:: ai4r
4
6
  # Url:: http://www.ai4r.org/
5
7
  #
6
- # You can redistribute it and/or modify it under the terms of
7
- # the Mozilla Public License version 1.1 as published by the
8
+ # You can redistribute it and/or modify it under the terms of
9
+ # the Mozilla Public License version 1.1 as published by the
8
10
  # Mozilla Foundation at http://www.mozilla.org/MPL/MPL-1.1.txt
9
-
10
11
 
11
12
  TRIANGLE_WITH_BASE_NOISE = [
12
- [ 3, 3, 3, 3, 3, 3, 3, 8, 8, 3, 3, 3, 3, 3, 3, 3],
13
- [ 3, 3, 3, 3, 3, 3, 4, 10, 10, 4, 3, 3, 3, 3, 3, 3],
14
- [ 3, 3, 3, 3, 3, 3, 8, 8, 8, 8, 3, 3, 3, 3, 3, 3],
15
- [ 3, 3, 3, 3, 3, 4, 10, 4, 4, 10, 4, 3, 3, 3, 3, 3],
16
- [ 3, 3, 3, 3, 3, 8, 8, 3, 3, 8, 8, 3, 3, 3, 3, 3],
17
- [ 3, 3, 3, 3, 4, 10, 4, 3, 3, 4, 10, 4, 3, 3, 3, 3],
18
- [ 3, 3, 3, 3, 8, 8, 3, 3, 3, 3, 8, 8, 3, 3, 3, 3],
19
- [ 3, 3, 3, 4, 10, 4, 3, 3, 3, 3, 4, 10, 4, 3, 3, 3],
20
- [ 3, 3, 3, 8, 8, 3, 3, 3, 3, 3, 3, 8, 8, 3, 3, 3],
21
- [ 3, 3, 4, 10, 4, 3, 3, 3, 3, 3, 3, 4, 10, 4, 3, 3],
22
- [ 3, 3, 8, 8, 3, 3, 3, 3, 3, 3, 3, 3, 8, 8, 3, 3],
23
- [ 3, 4, 10, 4, 3, 3, 3, 3, 3, 3, 3, 3, 4, 10, 4, 3],
24
- [ 3, 8, 8, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 8, 8, 3],
25
- [ 4, 10, 4, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 4, 10, 4],
26
- [ 8, 8, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 8, 8],
13
+ [3, 3, 3, 3, 3, 3, 3, 8, 8, 3, 3, 3, 3, 3, 3, 3],
14
+ [3, 3, 3, 3, 3, 3, 4, 10, 10, 4, 3, 3, 3, 3, 3, 3],
15
+ [3, 3, 3, 3, 3, 3, 8, 8, 8, 8, 3, 3, 3, 3, 3, 3],
16
+ [3, 3, 3, 3, 3, 4, 10, 4, 4, 10, 4, 3, 3, 3, 3, 3],
17
+ [3, 3, 3, 3, 3, 8, 8, 3, 3, 8, 8, 3, 3, 3, 3, 3],
18
+ [3, 3, 3, 3, 4, 10, 4, 3, 3, 4, 10, 4, 3, 3, 3, 3],
19
+ [3, 3, 3, 3, 8, 8, 3, 3, 3, 3, 8, 8, 3, 3, 3, 3],
20
+ [3, 3, 3, 4, 10, 4, 3, 3, 3, 3, 4, 10, 4, 3, 3, 3],
21
+ [3, 3, 3, 8, 8, 3, 3, 3, 3, 3, 3, 8, 8, 3, 3, 3],
22
+ [3, 3, 4, 10, 4, 3, 3, 3, 3, 3, 3, 4, 10, 4, 3, 3],
23
+ [3, 3, 8, 8, 3, 3, 3, 3, 3, 3, 3, 3, 8, 8, 3, 3],
24
+ [3, 4, 10, 4, 3, 3, 3, 3, 3, 3, 3, 3, 4, 10, 4, 3],
25
+ [3, 8, 8, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 8, 8, 3],
26
+ [4, 10, 4, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 4, 10, 4],
27
+ [8, 8, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 8, 8],
27
28
  [10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10]
28
- ]
29
+ ].freeze
29
30
 
30
31
  SQUARE_WITH_BASE_NOISE = [
31
32
  [10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10],
@@ -44,25 +45,24 @@ SQUARE_WITH_BASE_NOISE = [
44
45
  [10, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 10],
45
46
  [10, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 10],
46
47
  [10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10]
47
-
48
- ]
48
+
49
+ ].freeze
49
50
 
50
51
  CROSS_WITH_BASE_NOISE = [
51
- [ 3, 3, 3, 3, 3, 3, 3, 8, 8, 3, 3, 3, 3, 3, 3, 3],
52
- [ 3, 3, 3, 3, 3, 3, 3, 8, 8, 3, 3, 3, 3, 3, 3, 3],
53
- [ 3, 3, 3, 3, 3, 3, 3, 8, 8, 3, 3, 3, 3, 3, 3, 3],
54
- [ 3, 3, 3, 3, 3, 3, 3, 8, 8, 3, 3, 3, 3, 3, 3, 3],
55
- [ 3, 3, 3, 3, 3, 3, 3, 8, 8, 3, 3, 3, 3, 3, 3, 3],
56
- [ 3, 3, 3, 3, 3, 3, 3, 8, 8, 3, 3, 3, 3, 3, 3, 3],
57
- [ 3, 3, 3, 3, 3, 3, 3, 8, 8, 3, 3, 3, 3, 3, 3, 3],
58
- [ 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8],
59
- [ 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8],
60
- [ 3, 3, 3, 3, 3, 3, 3, 8, 8, 3, 3, 3, 3, 3, 3, 3],
61
- [ 3, 3, 3, 3, 3, 3, 3, 8, 8, 3, 3, 3, 3, 3, 3, 3],
62
- [ 3, 3, 3, 3, 3, 3, 3, 8, 8, 3, 3, 3, 3, 3, 3, 3],
63
- [ 3, 3, 3, 3, 3, 3, 3, 8, 8, 3, 3, 3, 3, 3, 3, 3],
64
- [ 3, 3, 3, 3, 3, 3, 3, 8, 8, 3, 3, 3, 3, 3, 3, 3],
65
- [ 3, 3, 3, 3, 3, 3, 3, 8, 8, 3, 3, 3, 3, 3, 3, 3],
66
- [ 3, 3, 3, 3, 3, 3, 3, 8, 8, 3, 3, 3, 3, 3, 3, 3]
67
- ]
68
-
52
+ [3, 3, 3, 3, 3, 3, 3, 8, 8, 3, 3, 3, 3, 3, 3, 3],
53
+ [3, 3, 3, 3, 3, 3, 3, 8, 8, 3, 3, 3, 3, 3, 3, 3],
54
+ [3, 3, 3, 3, 3, 3, 3, 8, 8, 3, 3, 3, 3, 3, 3, 3],
55
+ [3, 3, 3, 3, 3, 3, 3, 8, 8, 3, 3, 3, 3, 3, 3, 3],
56
+ [3, 3, 3, 3, 3, 3, 3, 8, 8, 3, 3, 3, 3, 3, 3, 3],
57
+ [3, 3, 3, 3, 3, 3, 3, 8, 8, 3, 3, 3, 3, 3, 3, 3],
58
+ [3, 3, 3, 3, 3, 3, 3, 8, 8, 3, 3, 3, 3, 3, 3, 3],
59
+ [8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8],
60
+ [8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8],
61
+ [3, 3, 3, 3, 3, 3, 3, 8, 8, 3, 3, 3, 3, 3, 3, 3],
62
+ [3, 3, 3, 3, 3, 3, 3, 8, 8, 3, 3, 3, 3, 3, 3, 3],
63
+ [3, 3, 3, 3, 3, 3, 3, 8, 8, 3, 3, 3, 3, 3, 3, 3],
64
+ [3, 3, 3, 3, 3, 3, 3, 8, 8, 3, 3, 3, 3, 3, 3, 3],
65
+ [3, 3, 3, 3, 3, 3, 3, 8, 8, 3, 3, 3, 3, 3, 3, 3],
66
+ [3, 3, 3, 3, 3, 3, 3, 8, 8, 3, 3, 3, 3, 3, 3, 3],
67
+ [3, 3, 3, 3, 3, 3, 3, 8, 8, 3, 3, 3, 3, 3, 3, 3]
68
+ ].freeze
@@ -1,30 +1,32 @@
1
+ # frozen_string_literal: true
2
+
1
3
  # Author:: Sergio Fierens
2
4
  # License:: MPL 1.1
3
5
  # Project:: ai4r
4
6
  # Url:: http://www.ai4r.org/
5
7
  #
6
- # You can redistribute it and/or modify it under the terms of
7
- # the Mozilla Public License version 1.1 as published by the
8
+ # You can redistribute it and/or modify it under the terms of
9
+ # the Mozilla Public License version 1.1 as published by the
8
10
  # Mozilla Foundation at http://www.mozilla.org/MPL/MPL-1.1.txt
9
-
11
+
10
12
  TRIANGLE_WITH_NOISE = [
11
- [ 1, 0, 0, 0, 0, 0, 0, 1, 5, 0, 0, 1, 0, 0, 0, 0],
12
- [ 0, 0, 0, 0, 3, 0, 1, 9, 9, 1, 0, 0, 0, 0, 3, 0],
13
- [ 0, 3, 0, 0, 0, 0, 5, 1, 5, 3, 0, 0, 0, 0, 0, 7],
14
- [ 0, 0, 0, 7, 0, 1, 9, 1, 1, 9, 1, 0, 0, 0, 3, 0],
15
- [ 0, 0, 0, 0, 0, 3, 5, 0, 3, 5, 5, 0, 0, 0, 0, 0],
16
- [ 0, 1, 0, 0, 1, 9, 1, 0, 1, 1, 9, 1, 0, 0, 0, 0],
17
- [ 1, 0, 0, 0, 5, 5, 0, 0, 0, 0, 5, 5, 7, 0, 0, 3],
18
- [ 0, 0, 3, 3, 9, 1, 0, 0, 1, 0, 1, 9, 1, 0, 0, 0],
19
- [ 0, 0, 0, 5, 5, 0, 3, 7, 0, 0, 0, 5, 5, 0, 0, 0],
20
- [ 0, 0, 1, 9, 1, 0, 0, 0, 0, 0, 0, 1, 9, 1, 0, 0],
21
- [ 0, 0, 5, 5, 0, 0, 0, 0, 3, 0, 0, 0, 5, 5, 0, 0],
22
- [ 0, 1, 9, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 9, 1, 0],
23
- [ 0, 5, 5, 0, 3, 0, 0, 3, 0, 0, 0, 0, 0, 5, 5, 0],
24
- [ 1, 9, 1, 0, 0, 3, 0, 0, 0, 1, 0, 0, 0, 1, 9, 1],
25
- [ 5, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 5],
26
- [10, 10, 10, 10, 1, 10, 10, 10, 10, 10, 1, 10, 10, 10, 10, 10]
27
- ]
13
+ [1, 0, 0, 0, 0, 0, 0, 1, 5, 0, 0, 1, 0, 0, 0, 0],
14
+ [0, 0, 0, 0, 3, 0, 1, 9, 9, 1, 0, 0, 0, 0, 3, 0],
15
+ [0, 3, 0, 0, 0, 0, 5, 1, 5, 3, 0, 0, 0, 0, 0, 7],
16
+ [0, 0, 0, 7, 0, 1, 9, 1, 1, 9, 1, 0, 0, 0, 3, 0],
17
+ [0, 0, 0, 0, 0, 3, 5, 0, 3, 5, 5, 0, 0, 0, 0, 0],
18
+ [0, 1, 0, 0, 1, 9, 1, 0, 1, 1, 9, 1, 0, 0, 0, 0],
19
+ [1, 0, 0, 0, 5, 5, 0, 0, 0, 0, 5, 5, 7, 0, 0, 3],
20
+ [0, 0, 3, 3, 9, 1, 0, 0, 1, 0, 1, 9, 1, 0, 0, 0],
21
+ [0, 0, 0, 5, 5, 0, 3, 7, 0, 0, 0, 5, 5, 0, 0, 0],
22
+ [0, 0, 1, 9, 1, 0, 0, 0, 0, 0, 0, 1, 9, 1, 0, 0],
23
+ [0, 0, 5, 5, 0, 0, 0, 0, 3, 0, 0, 0, 5, 5, 0, 0],
24
+ [0, 1, 9, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 9, 1, 0],
25
+ [0, 5, 5, 0, 3, 0, 0, 3, 0, 0, 0, 0, 0, 5, 5, 0],
26
+ [1, 9, 1, 0, 0, 3, 0, 0, 0, 1, 0, 0, 0, 1, 9, 1],
27
+ [5, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 5],
28
+ [10, 10, 10, 10, 1, 10, 10, 10, 10, 10, 1, 10, 10, 10, 10, 10]
29
+ ].freeze
28
30
 
29
31
  SQUARE_WITH_NOISE = [
30
32
  [10, 3, 10, 10, 10, 6, 10, 10, 10, 10, 10, 4, 10, 10, 10, 10],
@@ -43,24 +45,24 @@ SQUARE_WITH_NOISE = [
43
45
  [10, 0, 0, 6, 0, 0, 0, 7, 0, 0, 0, 7, 0, 0, 0, 10],
44
46
  [10, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 10],
45
47
  [10, 10, 10, 10, 3, 10, 10, 10, 10, 0, 10, 10, 1, 10, 1, 10]
46
-
47
- ]
48
+
49
+ ].freeze
48
50
 
49
51
  CROSS_WITH_NOISE = [
50
- [ 0, 0, 0, 0, 0, 0, 3, 3, 5, 0, 3, 0, 0, 0, 1, 0],
51
- [ 0, 1, 0, 0, 0, 1, 0, 5, 5, 0, 0, 0, 0, 0, 0, 0],
52
- [ 0, 0, 0, 0, 0, 0, 0, 5, 5, 0, 0, 0, 3, 0, 0, 0],
53
- [ 0, 0, 1, 8, 0, 0, 0, 5, 5, 0, 4, 0, 0, 0, 1, 0],
54
- [ 0, 0, 0, 0, 0, 3, 0, 5, 0, 0, 0, 0, 1, 0, 0, 0],
55
- [ 0, 0, 0, 8, 0, 0, 0, 5, 5, 0, 0, 0, 0, 0, 0, 1],
56
- [ 0, 0, 0, 0, 0, 0, 0, 5, 5, 0, 3, 0, 0, 0, 0, 0],
57
- [ 5, 5, 5, 8, 5, 3, 5, 5, 5, 5, 5, 5, 5, 5, 0, 5],
58
- [ 5, 5, 5, 5, 5, 5, 5, 5, 1, 5, 5, 5, 5, 1, 0, 0],
59
- [ 0, 0, 0, 8, 0, 0, 0, 4, 5, 0, 0, 0, 0, 0, 0, 0],
60
- [ 0, 0, 0, 0, 0, 0, 0, 5, 5, 4, 0, 0, 0, 0, 0, 0],
61
- [ 0, 0, 0, 0, 0, 4, 0, 5, 5, 0, 0, 0, 0, 0, 0, 0],
62
- [ 4, 0, 0, 4, 0, 0, 0, 5, 5, 0, 0, 0, 1, 0, 0, 0],
63
- [ 0, 0, 0, 0, 0, 1, 0, 5, 4, 4, 3, 0, 0, 0, 0, 0],
64
- [ 0, 0, 0, 0, 0, 0, 0, 5, 5, 0, 0, 0, 10, 0, 0, 0],
65
- [ 0, 0, 0, 0, 0, 0, 0, 5, 5, 0, 0, 0, 0, 0, 0, 0]
66
- ]
52
+ [0, 0, 0, 0, 0, 0, 3, 3, 5, 0, 3, 0, 0, 0, 1, 0],
53
+ [0, 1, 0, 0, 0, 1, 0, 5, 5, 0, 0, 0, 0, 0, 0, 0],
54
+ [0, 0, 0, 0, 0, 0, 0, 5, 5, 0, 0, 0, 3, 0, 0, 0],
55
+ [0, 0, 1, 8, 0, 0, 0, 5, 5, 0, 4, 0, 0, 0, 1, 0],
56
+ [0, 0, 0, 0, 0, 3, 0, 5, 0, 0, 0, 0, 1, 0, 0, 0],
57
+ [0, 0, 0, 8, 0, 0, 0, 5, 5, 0, 0, 0, 0, 0, 0, 1],
58
+ [0, 0, 0, 0, 0, 0, 0, 5, 5, 0, 3, 0, 0, 0, 0, 0],
59
+ [5, 5, 5, 8, 5, 3, 5, 5, 5, 5, 5, 5, 5, 5, 0, 5],
60
+ [5, 5, 5, 5, 5, 5, 5, 5, 1, 5, 5, 5, 5, 1, 0, 0],
61
+ [0, 0, 0, 8, 0, 0, 0, 4, 5, 0, 0, 0, 0, 0, 0, 0],
62
+ [0, 0, 0, 0, 0, 0, 0, 5, 5, 4, 0, 0, 0, 0, 0, 0],
63
+ [0, 0, 0, 0, 0, 4, 0, 5, 5, 0, 0, 0, 0, 0, 0, 0],
64
+ [4, 0, 0, 4, 0, 0, 0, 5, 5, 0, 0, 0, 1, 0, 0, 0],
65
+ [0, 0, 0, 0, 0, 1, 0, 5, 4, 4, 3, 0, 0, 0, 0, 0],
66
+ [0, 0, 0, 0, 0, 0, 0, 5, 5, 0, 0, 0, 10, 0, 0, 0],
67
+ [0, 0, 0, 0, 0, 0, 0, 5, 5, 0, 0, 0, 0, 0, 0, 0]
68
+ ].freeze
@@ -0,0 +1,25 @@
1
+ # frozen_string_literal: true
2
+
3
+ # Author:: Example contributor
4
+ # License:: MPL 1.1
5
+ # Project:: ai4r
6
+ #
7
+ # Simple example showing how to use Backpropagation#train_epochs with a callback.
8
+
9
+ require_relative '../../lib/ai4r/neural_network/backpropagation'
10
+
11
+ inputs = [[0, 0], [0, 1], [1, 0], [1, 1]]
12
+ outputs = [[0], [1], [1], [0]]
13
+
14
+ net = Ai4r::NeuralNetwork::Backpropagation.new([2, 2, 1])
15
+
16
+ loss_history = []
17
+ net.train_epochs(inputs, outputs, epochs: 200, batch_size: 1) do |epoch, loss, acc|
18
+ loss_history << [epoch, loss, acc]
19
+ if (epoch % 50).zero?
20
+ puts "Epoch #{epoch}: loss #{format('%.4f',
21
+ loss)} accuracy #{(acc * 100).round(2)}%"
22
+ end
23
+ end
24
+
25
+ puts 'Training finished.'