ai4r 1.12 → 2.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/README.md +174 -0
- data/examples/classifiers/hyperpipes_data.csv +14 -0
- data/examples/classifiers/hyperpipes_example.rb +22 -0
- data/examples/classifiers/ib1_example.rb +12 -0
- data/examples/classifiers/id3_example.rb +15 -10
- data/examples/classifiers/id3_graphviz_example.rb +17 -0
- data/examples/classifiers/logistic_regression_example.rb +11 -0
- data/examples/classifiers/naive_bayes_attributes_example.rb +13 -0
- data/examples/classifiers/naive_bayes_example.rb +12 -13
- data/examples/classifiers/one_r_example.rb +27 -0
- data/examples/classifiers/parameter_tutorial.rb +29 -0
- data/examples/classifiers/prism_nominal_example.rb +15 -0
- data/examples/classifiers/prism_numeric_example.rb +21 -0
- data/examples/classifiers/simple_linear_regression_example.csv +159 -0
- data/examples/classifiers/simple_linear_regression_example.rb +18 -0
- data/examples/classifiers/zero_and_one_r_example.rb +34 -0
- data/examples/classifiers/zero_one_r_data.csv +8 -0
- data/examples/clusterers/clusterer_example.rb +62 -0
- data/examples/clusterers/dbscan_example.rb +17 -0
- data/examples/clusterers/dendrogram_example.rb +17 -0
- data/examples/clusterers/hierarchical_dendrogram_example.rb +20 -0
- data/examples/clusterers/kmeans_custom_example.rb +26 -0
- data/examples/genetic_algorithm/bitstring_example.rb +41 -0
- data/examples/genetic_algorithm/genetic_algorithm_example.rb +26 -18
- data/examples/genetic_algorithm/kmeans_seed_tuning.rb +45 -0
- data/examples/neural_network/backpropagation_example.rb +49 -48
- data/examples/neural_network/hopfield_example.rb +45 -0
- data/examples/neural_network/patterns_with_base_noise.rb +39 -39
- data/examples/neural_network/patterns_with_noise.rb +41 -39
- data/examples/neural_network/train_epochs_callback.rb +25 -0
- data/examples/neural_network/training_patterns.rb +39 -39
- data/examples/neural_network/transformer_text_classification.rb +78 -0
- data/examples/neural_network/xor_example.rb +23 -22
- data/examples/reinforcement/q_learning_example.rb +10 -0
- data/examples/som/som_data.rb +155 -152
- data/examples/som/som_multi_node_example.rb +12 -13
- data/examples/som/som_single_example.rb +12 -15
- data/examples/transformer/decode_classifier_example.rb +68 -0
- data/examples/transformer/deterministic_example.rb +10 -0
- data/examples/transformer/seq2seq_example.rb +16 -0
- data/lib/ai4r/classifiers/classifier.rb +24 -16
- data/lib/ai4r/classifiers/gradient_boosting.rb +64 -0
- data/lib/ai4r/classifiers/hyperpipes.rb +119 -43
- data/lib/ai4r/classifiers/ib1.rb +122 -32
- data/lib/ai4r/classifiers/id3.rb +527 -144
- data/lib/ai4r/classifiers/logistic_regression.rb +96 -0
- data/lib/ai4r/classifiers/multilayer_perceptron.rb +75 -59
- data/lib/ai4r/classifiers/naive_bayes.rb +112 -48
- data/lib/ai4r/classifiers/one_r.rb +112 -44
- data/lib/ai4r/classifiers/prism.rb +167 -76
- data/lib/ai4r/classifiers/random_forest.rb +72 -0
- data/lib/ai4r/classifiers/simple_linear_regression.rb +143 -0
- data/lib/ai4r/classifiers/support_vector_machine.rb +91 -0
- data/lib/ai4r/classifiers/votes.rb +57 -0
- data/lib/ai4r/classifiers/zero_r.rb +71 -30
- data/lib/ai4r/clusterers/average_linkage.rb +46 -27
- data/lib/ai4r/clusterers/bisecting_k_means.rb +50 -44
- data/lib/ai4r/clusterers/centroid_linkage.rb +52 -36
- data/lib/ai4r/clusterers/cluster_tree.rb +50 -0
- data/lib/ai4r/clusterers/clusterer.rb +28 -24
- data/lib/ai4r/clusterers/complete_linkage.rb +42 -31
- data/lib/ai4r/clusterers/dbscan.rb +134 -0
- data/lib/ai4r/clusterers/diana.rb +75 -49
- data/lib/ai4r/clusterers/k_means.rb +309 -72
- data/lib/ai4r/clusterers/median_linkage.rb +49 -33
- data/lib/ai4r/clusterers/single_linkage.rb +196 -88
- data/lib/ai4r/clusterers/ward_linkage.rb +51 -35
- data/lib/ai4r/clusterers/ward_linkage_hierarchical.rb +63 -0
- data/lib/ai4r/clusterers/weighted_average_linkage.rb +48 -32
- data/lib/ai4r/data/data_set.rb +229 -100
- data/lib/ai4r/data/parameterizable.rb +31 -25
- data/lib/ai4r/data/proximity.rb +72 -50
- data/lib/ai4r/data/statistics.rb +46 -35
- data/lib/ai4r/experiment/classifier_evaluator.rb +84 -32
- data/lib/ai4r/experiment/split.rb +39 -0
- data/lib/ai4r/genetic_algorithm/chromosome_base.rb +43 -0
- data/lib/ai4r/genetic_algorithm/genetic_algorithm.rb +92 -170
- data/lib/ai4r/genetic_algorithm/tsp_chromosome.rb +83 -0
- data/lib/ai4r/hmm/hidden_markov_model.rb +134 -0
- data/lib/ai4r/neural_network/activation_functions.rb +37 -0
- data/lib/ai4r/neural_network/backpropagation.rb +419 -143
- data/lib/ai4r/neural_network/hopfield.rb +175 -58
- data/lib/ai4r/neural_network/transformer.rb +194 -0
- data/lib/ai4r/neural_network/weight_initializations.rb +40 -0
- data/lib/ai4r/reinforcement/policy_iteration.rb +66 -0
- data/lib/ai4r/reinforcement/q_learning.rb +51 -0
- data/lib/ai4r/search/a_star.rb +76 -0
- data/lib/ai4r/search/bfs.rb +50 -0
- data/lib/ai4r/search/dfs.rb +50 -0
- data/lib/ai4r/search/mcts.rb +118 -0
- data/lib/ai4r/search.rb +12 -0
- data/lib/ai4r/som/distance_metrics.rb +29 -0
- data/lib/ai4r/som/layer.rb +28 -17
- data/lib/ai4r/som/node.rb +61 -32
- data/lib/ai4r/som/som.rb +158 -41
- data/lib/ai4r/som/two_phase_layer.rb +21 -25
- data/lib/ai4r/version.rb +3 -0
- data/lib/ai4r.rb +58 -27
- metadata +117 -106
- data/README.rdoc +0 -44
- data/test/classifiers/hyperpipes_test.rb +0 -84
- data/test/classifiers/ib1_test.rb +0 -78
- data/test/classifiers/id3_test.rb +0 -208
- data/test/classifiers/multilayer_perceptron_test.rb +0 -79
- data/test/classifiers/naive_bayes_test.rb +0 -43
- data/test/classifiers/one_r_test.rb +0 -62
- data/test/classifiers/prism_test.rb +0 -85
- data/test/classifiers/zero_r_test.rb +0 -50
- data/test/clusterers/average_linkage_test.rb +0 -51
- data/test/clusterers/bisecting_k_means_test.rb +0 -66
- data/test/clusterers/centroid_linkage_test.rb +0 -53
- data/test/clusterers/complete_linkage_test.rb +0 -57
- data/test/clusterers/diana_test.rb +0 -69
- data/test/clusterers/k_means_test.rb +0 -100
- data/test/clusterers/median_linkage_test.rb +0 -53
- data/test/clusterers/single_linkage_test.rb +0 -122
- data/test/clusterers/ward_linkage_test.rb +0 -53
- data/test/clusterers/weighted_average_linkage_test.rb +0 -53
- data/test/data/data_set_test.rb +0 -96
- data/test/data/proximity_test.rb +0 -81
- data/test/data/statistics_test.rb +0 -65
- data/test/experiment/classifier_evaluator_test.rb +0 -76
- data/test/genetic_algorithm/chromosome_test.rb +0 -57
- data/test/genetic_algorithm/genetic_algorithm_test.rb +0 -81
- data/test/neural_network/backpropagation_test.rb +0 -82
- data/test/neural_network/hopfield_test.rb +0 -72
- data/test/som/som_test.rb +0 -97
@@ -1,30 +1,32 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
1
3
|
# Author:: Sergio Fierens
|
2
4
|
# License:: MPL 1.1
|
3
5
|
# Project:: ai4r
|
4
6
|
# Url:: http://www.ai4r.org/
|
5
7
|
#
|
6
|
-
# You can redistribute it and/or modify it under the terms of
|
7
|
-
# the Mozilla Public License version 1.1 as published by the
|
8
|
+
# You can redistribute it and/or modify it under the terms of
|
9
|
+
# the Mozilla Public License version 1.1 as published by the
|
8
10
|
# Mozilla Foundation at http://www.mozilla.org/MPL/MPL-1.1.txt
|
9
|
-
|
11
|
+
|
10
12
|
TRIANGLE_WITH_NOISE = [
|
11
|
-
[
|
12
|
-
[
|
13
|
-
[
|
14
|
-
[
|
15
|
-
[
|
16
|
-
[
|
17
|
-
[
|
18
|
-
[
|
19
|
-
[
|
20
|
-
[
|
21
|
-
[
|
22
|
-
[
|
23
|
-
[
|
24
|
-
[
|
25
|
-
[
|
26
|
-
[10, 10, 10, 10,
|
27
|
-
]
|
13
|
+
[1, 0, 0, 0, 0, 0, 0, 1, 5, 0, 0, 1, 0, 0, 0, 0],
|
14
|
+
[0, 0, 0, 0, 3, 0, 1, 9, 9, 1, 0, 0, 0, 0, 3, 0],
|
15
|
+
[0, 3, 0, 0, 0, 0, 5, 1, 5, 3, 0, 0, 0, 0, 0, 7],
|
16
|
+
[0, 0, 0, 7, 0, 1, 9, 1, 1, 9, 1, 0, 0, 0, 3, 0],
|
17
|
+
[0, 0, 0, 0, 0, 3, 5, 0, 3, 5, 5, 0, 0, 0, 0, 0],
|
18
|
+
[0, 1, 0, 0, 1, 9, 1, 0, 1, 1, 9, 1, 0, 0, 0, 0],
|
19
|
+
[1, 0, 0, 0, 5, 5, 0, 0, 0, 0, 5, 5, 7, 0, 0, 3],
|
20
|
+
[0, 0, 3, 3, 9, 1, 0, 0, 1, 0, 1, 9, 1, 0, 0, 0],
|
21
|
+
[0, 0, 0, 5, 5, 0, 3, 7, 0, 0, 0, 5, 5, 0, 0, 0],
|
22
|
+
[0, 0, 1, 9, 1, 0, 0, 0, 0, 0, 0, 1, 9, 1, 0, 0],
|
23
|
+
[0, 0, 5, 5, 0, 0, 0, 0, 3, 0, 0, 0, 5, 5, 0, 0],
|
24
|
+
[0, 1, 9, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 9, 1, 0],
|
25
|
+
[0, 5, 5, 0, 3, 0, 0, 3, 0, 0, 0, 0, 0, 5, 5, 0],
|
26
|
+
[1, 9, 1, 0, 0, 3, 0, 0, 0, 1, 0, 0, 0, 1, 9, 1],
|
27
|
+
[5, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 5],
|
28
|
+
[10, 10, 10, 10, 1, 10, 10, 10, 10, 10, 1, 10, 10, 10, 10, 10]
|
29
|
+
].freeze
|
28
30
|
|
29
31
|
SQUARE_WITH_NOISE = [
|
30
32
|
[10, 3, 10, 10, 10, 6, 10, 10, 10, 10, 10, 4, 10, 10, 10, 10],
|
@@ -43,24 +45,24 @@ SQUARE_WITH_NOISE = [
|
|
43
45
|
[10, 0, 0, 6, 0, 0, 0, 7, 0, 0, 0, 7, 0, 0, 0, 10],
|
44
46
|
[10, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 10],
|
45
47
|
[10, 10, 10, 10, 3, 10, 10, 10, 10, 0, 10, 10, 1, 10, 1, 10]
|
46
|
-
|
47
|
-
]
|
48
|
+
|
49
|
+
].freeze
|
48
50
|
|
49
51
|
CROSS_WITH_NOISE = [
|
50
|
-
[
|
51
|
-
[
|
52
|
-
[
|
53
|
-
[
|
54
|
-
[
|
55
|
-
[
|
56
|
-
[
|
57
|
-
[
|
58
|
-
[
|
59
|
-
[
|
60
|
-
[
|
61
|
-
[
|
62
|
-
[
|
63
|
-
[
|
64
|
-
[
|
65
|
-
[
|
66
|
-
]
|
52
|
+
[0, 0, 0, 0, 0, 0, 3, 3, 5, 0, 3, 0, 0, 0, 1, 0],
|
53
|
+
[0, 1, 0, 0, 0, 1, 0, 5, 5, 0, 0, 0, 0, 0, 0, 0],
|
54
|
+
[0, 0, 0, 0, 0, 0, 0, 5, 5, 0, 0, 0, 3, 0, 0, 0],
|
55
|
+
[0, 0, 1, 8, 0, 0, 0, 5, 5, 0, 4, 0, 0, 0, 1, 0],
|
56
|
+
[0, 0, 0, 0, 0, 3, 0, 5, 0, 0, 0, 0, 1, 0, 0, 0],
|
57
|
+
[0, 0, 0, 8, 0, 0, 0, 5, 5, 0, 0, 0, 0, 0, 0, 1],
|
58
|
+
[0, 0, 0, 0, 0, 0, 0, 5, 5, 0, 3, 0, 0, 0, 0, 0],
|
59
|
+
[5, 5, 5, 8, 5, 3, 5, 5, 5, 5, 5, 5, 5, 5, 0, 5],
|
60
|
+
[5, 5, 5, 5, 5, 5, 5, 5, 1, 5, 5, 5, 5, 1, 0, 0],
|
61
|
+
[0, 0, 0, 8, 0, 0, 0, 4, 5, 0, 0, 0, 0, 0, 0, 0],
|
62
|
+
[0, 0, 0, 0, 0, 0, 0, 5, 5, 4, 0, 0, 0, 0, 0, 0],
|
63
|
+
[0, 0, 0, 0, 0, 4, 0, 5, 5, 0, 0, 0, 0, 0, 0, 0],
|
64
|
+
[4, 0, 0, 4, 0, 0, 0, 5, 5, 0, 0, 0, 1, 0, 0, 0],
|
65
|
+
[0, 0, 0, 0, 0, 1, 0, 5, 4, 4, 3, 0, 0, 0, 0, 0],
|
66
|
+
[0, 0, 0, 0, 0, 0, 0, 5, 5, 0, 0, 0, 10, 0, 0, 0],
|
67
|
+
[0, 0, 0, 0, 0, 0, 0, 5, 5, 0, 0, 0, 0, 0, 0, 0]
|
68
|
+
].freeze
|
@@ -0,0 +1,25 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
# Author:: Example contributor
|
4
|
+
# License:: MPL 1.1
|
5
|
+
# Project:: ai4r
|
6
|
+
#
|
7
|
+
# Simple example showing how to use Backpropagation#train_epochs with a callback.
|
8
|
+
|
9
|
+
require_relative '../../lib/ai4r/neural_network/backpropagation'
|
10
|
+
|
11
|
+
inputs = [[0, 0], [0, 1], [1, 0], [1, 1]]
|
12
|
+
outputs = [[0], [1], [1], [0]]
|
13
|
+
|
14
|
+
net = Ai4r::NeuralNetwork::Backpropagation.new([2, 2, 1])
|
15
|
+
|
16
|
+
loss_history = []
|
17
|
+
net.train_epochs(inputs, outputs, epochs: 200, batch_size: 1) do |epoch, loss, acc|
|
18
|
+
loss_history << [epoch, loss, acc]
|
19
|
+
if (epoch % 50).zero?
|
20
|
+
puts "Epoch #{epoch}: loss #{format('%.4f',
|
21
|
+
loss)} accuracy #{(acc * 100).round(2)}%"
|
22
|
+
end
|
23
|
+
end
|
24
|
+
|
25
|
+
puts 'Training finished.'
|
@@ -1,31 +1,32 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
1
3
|
# Author:: Sergio Fierens
|
2
4
|
# License:: MPL 1.1
|
3
5
|
# Project:: ai4r
|
4
6
|
# Url:: http://www.ai4r.org/
|
5
7
|
#
|
6
|
-
# You can redistribute it and/or modify it under the terms of
|
7
|
-
# the Mozilla Public License version 1.1 as published by the
|
8
|
+
# You can redistribute it and/or modify it under the terms of
|
9
|
+
# the Mozilla Public License version 1.1 as published by the
|
8
10
|
# Mozilla Foundation at http://www.mozilla.org/MPL/MPL-1.1.txt
|
9
|
-
|
10
11
|
|
11
12
|
TRIANGLE = [
|
12
|
-
[
|
13
|
-
[
|
14
|
-
[
|
15
|
-
[
|
16
|
-
[
|
17
|
-
[
|
18
|
-
[
|
19
|
-
[
|
20
|
-
[
|
21
|
-
[
|
22
|
-
[
|
23
|
-
[
|
24
|
-
[
|
25
|
-
[
|
26
|
-
[
|
13
|
+
[0, 0, 0, 0, 0, 0, 0, 5, 5, 0, 0, 0, 0, 0, 0, 0],
|
14
|
+
[0, 0, 0, 0, 0, 0, 1, 9, 9, 1, 0, 0, 0, 0, 0, 0],
|
15
|
+
[0, 0, 0, 0, 0, 0, 5, 5, 5, 5, 0, 0, 0, 0, 0, 0],
|
16
|
+
[0, 0, 0, 0, 0, 1, 9, 1, 1, 9, 1, 0, 0, 0, 0, 0],
|
17
|
+
[0, 0, 0, 0, 0, 5, 5, 0, 0, 5, 5, 0, 0, 0, 0, 0],
|
18
|
+
[0, 0, 0, 0, 1, 9, 1, 0, 0, 1, 9, 1, 0, 0, 0, 0],
|
19
|
+
[0, 0, 0, 0, 5, 5, 0, 0, 0, 0, 5, 5, 0, 0, 0, 0],
|
20
|
+
[0, 0, 0, 1, 9, 1, 0, 0, 0, 0, 1, 9, 1, 0, 0, 0],
|
21
|
+
[0, 0, 0, 5, 5, 0, 0, 0, 0, 0, 0, 5, 5, 0, 0, 0],
|
22
|
+
[0, 0, 1, 9, 1, 0, 0, 0, 0, 0, 0, 1, 9, 1, 0, 0],
|
23
|
+
[0, 0, 5, 5, 0, 0, 0, 0, 0, 0, 0, 0, 5, 5, 0, 0],
|
24
|
+
[0, 1, 9, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 9, 1, 0],
|
25
|
+
[0, 5, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 5, 0],
|
26
|
+
[1, 9, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 9, 1],
|
27
|
+
[5, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 5],
|
27
28
|
[10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10]
|
28
|
-
]
|
29
|
+
].freeze
|
29
30
|
|
30
31
|
SQUARE = [
|
31
32
|
[10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10],
|
@@ -44,25 +45,24 @@ SQUARE = [
|
|
44
45
|
[10, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 10],
|
45
46
|
[10, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 10],
|
46
47
|
[10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10]
|
47
|
-
|
48
|
-
]
|
48
|
+
|
49
|
+
].freeze
|
49
50
|
|
50
51
|
CROSS = [
|
51
|
-
[
|
52
|
-
[
|
53
|
-
[
|
54
|
-
[
|
55
|
-
[
|
56
|
-
[
|
57
|
-
[
|
58
|
-
[
|
59
|
-
[
|
60
|
-
[
|
61
|
-
[
|
62
|
-
[
|
63
|
-
[
|
64
|
-
[
|
65
|
-
[
|
66
|
-
[
|
67
|
-
]
|
68
|
-
|
52
|
+
[0, 0, 0, 0, 0, 0, 0, 5, 5, 0, 0, 0, 0, 0, 0, 0],
|
53
|
+
[0, 0, 0, 0, 0, 0, 0, 5, 5, 0, 0, 0, 0, 0, 0, 0],
|
54
|
+
[0, 0, 0, 0, 0, 0, 0, 5, 5, 0, 0, 0, 0, 0, 0, 0],
|
55
|
+
[0, 0, 0, 0, 0, 0, 0, 5, 5, 0, 0, 0, 0, 0, 0, 0],
|
56
|
+
[0, 0, 0, 0, 0, 0, 0, 5, 5, 0, 0, 0, 0, 0, 0, 0],
|
57
|
+
[0, 0, 0, 0, 0, 0, 0, 5, 5, 0, 0, 0, 0, 0, 0, 0],
|
58
|
+
[0, 0, 0, 0, 0, 0, 0, 5, 5, 0, 0, 0, 0, 0, 0, 0],
|
59
|
+
[5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5],
|
60
|
+
[5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5],
|
61
|
+
[0, 0, 0, 0, 0, 0, 0, 5, 5, 0, 0, 0, 0, 0, 0, 0],
|
62
|
+
[0, 0, 0, 0, 0, 0, 0, 5, 5, 0, 0, 0, 0, 0, 0, 0],
|
63
|
+
[0, 0, 0, 0, 0, 0, 0, 5, 5, 0, 0, 0, 0, 0, 0, 0],
|
64
|
+
[0, 0, 0, 0, 0, 0, 0, 5, 5, 0, 0, 0, 0, 0, 0, 0],
|
65
|
+
[0, 0, 0, 0, 0, 0, 0, 5, 5, 0, 0, 0, 0, 0, 0, 0],
|
66
|
+
[0, 0, 0, 0, 0, 0, 0, 5, 5, 0, 0, 0, 0, 0, 0, 0],
|
67
|
+
[0, 0, 0, 0, 0, 0, 0, 5, 5, 0, 0, 0, 0, 0, 0, 0]
|
68
|
+
].freeze
|
@@ -0,0 +1,78 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
# Author:: OpenAI Assistant
|
4
|
+
# License:: MPL 1.1
|
5
|
+
# Project:: ai4r
|
6
|
+
# Url:: http://www.ai4r.org/
|
7
|
+
#
|
8
|
+
# Toy example of using the minimal Transformer encoder for
|
9
|
+
# text classification. We build random sentence embeddings
|
10
|
+
# with the Transformer and train a logistic regression
|
11
|
+
# classifier on a tiny sentiment dataset.
|
12
|
+
|
13
|
+
require_relative '../../lib/ai4r/neural_network/transformer'
|
14
|
+
require_relative '../../lib/ai4r/classifiers/logistic_regression'
|
15
|
+
require_relative '../../lib/ai4r/data/data_set'
|
16
|
+
|
17
|
+
# Vocabulary for our miniature dataset
|
18
|
+
VOCAB = {
|
19
|
+
'good' => 0,
|
20
|
+
'great' => 1,
|
21
|
+
'bad' => 2,
|
22
|
+
'awful' => 3,
|
23
|
+
'movie' => 4,
|
24
|
+
'film' => 5,
|
25
|
+
'<pad>' => 6
|
26
|
+
}.freeze
|
27
|
+
|
28
|
+
MAX_LEN = 2
|
29
|
+
|
30
|
+
# Helper that converts space separated text into an array of token ids
|
31
|
+
# and pads it to MAX_LEN tokens.
|
32
|
+
def encode(text)
|
33
|
+
tokens = text.split.map { |w| VOCAB[w] }
|
34
|
+
tokens.fill(VOCAB['<pad>'], tokens.length...MAX_LEN)
|
35
|
+
end
|
36
|
+
|
37
|
+
# Build encoder only Transformer
|
38
|
+
model = Ai4r::NeuralNetwork::Transformer.new(
|
39
|
+
vocab_size: VOCAB.size,
|
40
|
+
max_len: MAX_LEN
|
41
|
+
)
|
42
|
+
|
43
|
+
train_texts = ['good movie', 'great film', 'bad movie', 'awful film']
|
44
|
+
labels = [1, 1, 0, 0]
|
45
|
+
|
46
|
+
# Obtain sentence embeddings by averaging token representations
|
47
|
+
train_features = train_texts.map do |text|
|
48
|
+
tokens = encode(text)
|
49
|
+
enc = model.eval(tokens)
|
50
|
+
mean = Array.new(model.embed_dim, 0.0)
|
51
|
+
enc.each do |vec|
|
52
|
+
vec.each_with_index { |v, i| mean[i] += v }
|
53
|
+
end
|
54
|
+
mean.map { |v| v / enc.length }
|
55
|
+
end
|
56
|
+
|
57
|
+
data_items = train_features.each_with_index.map { |feat, i| feat + [labels[i]] }
|
58
|
+
labels_names = (0...model.embed_dim).map { |i| "f#{i}" } + ['class']
|
59
|
+
|
60
|
+
dataset = Ai4r::Data::DataSet.new(
|
61
|
+
data_items: data_items,
|
62
|
+
data_labels: labels_names
|
63
|
+
)
|
64
|
+
|
65
|
+
classifier = Ai4r::Classifiers::LogisticRegression.new
|
66
|
+
classifier.set_parameters(lr: 0.5, iterations: 2000).build(dataset)
|
67
|
+
|
68
|
+
puts 'Predictions:'
|
69
|
+
['good film', 'awful movie'].each do |text|
|
70
|
+
tokens = encode(text)
|
71
|
+
enc = model.eval(tokens)
|
72
|
+
mean = Array.new(model.embed_dim, 0.0)
|
73
|
+
enc.each do |vec|
|
74
|
+
vec.each_with_index { |v, i| mean[i] += v }
|
75
|
+
end
|
76
|
+
mean.map! { |v| v / enc.length }
|
77
|
+
puts "#{text} => #{classifier.eval(mean)}"
|
78
|
+
end
|
@@ -1,35 +1,36 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
1
3
|
# Author:: Sergio Fierens
|
2
4
|
# License:: MPL 1.1
|
3
5
|
# Project:: ai4r
|
4
6
|
# Url:: http://www.ai4r.org/
|
5
7
|
#
|
6
|
-
# You can redistribute it and/or modify it under the terms of
|
7
|
-
# the Mozilla Public License version 1.1 as published by the
|
8
|
+
# You can redistribute it and/or modify it under the terms of
|
9
|
+
# the Mozilla Public License version 1.1 as published by the
|
8
10
|
# Mozilla Foundation at http://www.mozilla.org/MPL/MPL-1.1.txt
|
9
11
|
|
10
|
-
|
12
|
+
require_relative '../../lib/ai4r/neural_network/backpropagation'
|
11
13
|
require 'benchmark'
|
12
14
|
|
13
15
|
times = Benchmark.measure do
|
16
|
+
srand 1
|
17
|
+
|
18
|
+
net = Ai4r::NeuralNetwork::Backpropagation.new([2, 2, 1])
|
19
|
+
|
20
|
+
puts 'Training the network, please wait.'
|
21
|
+
2001.times do |i|
|
22
|
+
net.train([0, 0], [0])
|
23
|
+
net.train([0, 1], [1])
|
24
|
+
net.train([1, 0], [1])
|
25
|
+
error = net.train([1, 1], [0])
|
26
|
+
puts "Error after iteration #{i}:\t#{error}" if (i % 200).zero?
|
27
|
+
end
|
14
28
|
|
15
|
-
|
16
|
-
|
17
|
-
|
18
|
-
|
19
|
-
|
20
|
-
2001.times do |i|
|
21
|
-
net.train([0,0], [0])
|
22
|
-
net.train([0,1], [1])
|
23
|
-
net.train([1,0], [1])
|
24
|
-
error = net.train([1,1], [0])
|
25
|
-
puts "Error after iteration #{i}:\t#{error}" if i%200 == 0
|
26
|
-
end
|
27
|
-
|
28
|
-
puts "Test data"
|
29
|
-
puts "[0,0] = > #{net.eval([0,0]).inspect}"
|
30
|
-
puts "[0,1] = > #{net.eval([0,1]).inspect}"
|
31
|
-
puts "[1,0] = > #{net.eval([1,0]).inspect}"
|
32
|
-
puts "[1,1] = > #{net.eval([1,1]).inspect}"
|
29
|
+
puts 'Test data'
|
30
|
+
puts "[0,0] = > #{net.eval([0, 0]).inspect}"
|
31
|
+
puts "[0,1] = > #{net.eval([0, 1]).inspect}"
|
32
|
+
puts "[1,0] = > #{net.eval([1, 0]).inspect}"
|
33
|
+
puts "[1,1] = > #{net.eval([1, 1]).inspect}"
|
33
34
|
end
|
34
35
|
|
35
|
-
|
36
|
+
puts "Elapsed time: #{times}"
|
@@ -0,0 +1,10 @@
|
|
1
|
+
require 'ai4r/reinforcement/q_learning'
|
2
|
+
|
3
|
+
agent = Ai4r::Reinforcement::QLearning.new
|
4
|
+
agent.set_parameters(learning_rate: 0.5, discount: 1.0, exploration: 0.0)
|
5
|
+
|
6
|
+
# Simple two-state MDP
|
7
|
+
agent.update(:s1, :a, 0, :s2)
|
8
|
+
agent.update(:s1, :b, 1, :s1)
|
9
|
+
|
10
|
+
puts "Best action from s1: #{agent.choose_action(:s1)}"
|