bae 0.0.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
data/README.md ADDED
@@ -0,0 +1,54 @@
1
+ Bae
2
+ ===
3
+
4
+ Bae is a multinomial naive bayes classifier based on another gem ["naivebayes"](https://github.com/id774/naivebayes), only this one uses java to do the heavy lifting.
5
+
6
+ ## Installation
7
+
8
+ Add this line to your application's Gemfile:
9
+
10
+ gem 'bae'
11
+
12
+ And then execute:
13
+
14
+ $ bundle
15
+
16
+ Or install it yourself as:
17
+
18
+ $ gem install bae
19
+
20
+ ## Usage
21
+
22
+ You can refer to the ["naivebayes"](https://github.com/id774/naivebayes) gem for more documentation, or to the tests for examples. Here is a copy/paste example:
23
+
24
+
25
+ ### You can provide a frequency hash to the trainer
26
+
27
+ ```ruby
28
+ classifier = ::Bae::Classifier.new
29
+ classifier.train("positive", {"aaa" => 0, "bbb" => 1})
30
+ classifier.train("negative", {"ccc" => 2, "ddd" => 3})
31
+ classifier.classify({"aaa" => 1, "bbb" => 1})
32
+
33
+ #=> {"positive" => 0.9411764705882353, "negative" => 0.05882352941176469}
34
+ ```
35
+
36
+ ### Or you can train with strings
37
+ ```ruby
38
+ classifier = ::Bae::Classifier.new
39
+ classifier.train("positive", "aaa aaa bbb");
40
+ classifier.train("negative", "ccc ccc ddd ddd");
41
+ classifier.train("neutral", "eee eee eee fff fff fff");
42
+ classifier.classify("aaa bbb")
43
+
44
+ #=> {"positive"=>0.8962655601659751, "negative"=>0.0663900414937759, "neutral"=>0.037344398340248955}
45
+ ```
46
+
47
+
48
+ ## Contributing
49
+
50
+ 1. Fork it ( https://github.com/[my-github-username]/bae/fork )
51
+ 2. Create your feature branch (`git checkout -b my-new-feature`)
52
+ 3. Commit your changes (`git commit -am 'Add some feature'`)
53
+ 4. Push to the branch (`git push origin my-new-feature`)
54
+ 5. Create a new Pull Request
data/Rakefile ADDED
@@ -0,0 +1,8 @@
1
# Rake tasks for the bae gem: Bundler's gem packaging tasks plus an RSpec
# runner, which is also wired up as the default task.
require "bundler/gem_tasks"
require "rspec/core/rake_task"

desc "Run specs"
RSpec::Core::RakeTask.new(:spec)

desc "Run specs (default)"
task default: :spec
data/bae.gemspec ADDED
@@ -0,0 +1,24 @@
1
# coding: utf-8
# Gem specification for bae.
#
# The lib directory is pushed onto the load path first so that the version
# constant can be read from lib/bae/version.rb.
lib = File.expand_path('../lib', __FILE__)
$LOAD_PATH.unshift(lib) unless $LOAD_PATH.include?(lib)
require 'bae/version'

Gem::Specification.new do |spec|
  spec.name          = "bae"
  spec.version       = Bae::VERSION
  spec.authors       = ["Garrett Thornburg"]
  spec.email         = ["film42@gmail.com"]
  spec.summary       = "Multinomial naive bayes classifier with a kick of java"
  spec.description   = "Multinomial naive bayes classifier with a kick of java"
  spec.homepage      = "https://github.com/film42/bae"
  # RubyGems expects one short (SPDX-style) identifier per license; the
  # dual license is therefore listed as two entries instead of one sentence.
  spec.licenses      = ["GPL-3.0", "LGPL-3.0"]

  # Package every file tracked by git (NUL-separated so unusual names survive).
  spec.files         = `git ls-files -z`.split("\x0")
  spec.executables   = spec.files.grep(%r{^bin/}) { |f| File.basename(f) }
  spec.test_files    = spec.files.grep(%r{^(test|spec|features)/})
  spec.require_paths = ["lib"]

  spec.add_development_dependency "bundler", "~> 1.6"
  spec.add_development_dependency "rspec"
  spec.add_development_dependency "rake"
end
data/build.xml ADDED
@@ -0,0 +1,18 @@
1
<project>

    <!-- Removes the intermediate output written by the "compile" target.
         The packaged target/bae.jar is deliberately left in place. -->
    <target name="clean">
        <delete dir="out"/>
    </target>

    <!-- Compiles the Java sources under src/main/java into out/classes. -->
    <target name="compile">
        <mkdir dir="out"/>
        <mkdir dir="out/classes"/>
        <!-- includeantruntime="false" keeps Ant's own jars off the compile
             classpath and avoids Ant's "includeantruntime was not set" warning. -->
        <javac srcdir="src/main/java" destdir="out/classes" includeantruntime="false"/>
    </target>

    <!-- Packages the compiled classes as target/bae.jar. -->
    <target name="jar" depends="compile">
        <mkdir dir="target"/>
        <jar destfile="target/bae.jar" basedir="out/classes" excludes="**/*.jar,**/MANIFEST.MF,**/BCKEY.*" />
    </target>

</project>
@@ -0,0 +1,19 @@
1
module Bae
  # Ruby-facing wrapper that delegates training and classification to the
  # Java NaiveBayesClassifier.
  class Classifier
    # The underlying ::Java::Bae::NaiveBayesClassifier instance.
    attr_reader :internal_classifier

    # Creates a wrapper around a fresh Java classifier.
    def initialize
      @internal_classifier = ::Java::Bae::NaiveBayesClassifier.new
    end

    # Wraps +feature+ in a Java Document and trains the classifier with it
    # under +label+.
    def train(label, feature)
      document = ::Java::Bae::Document.new(feature)
      internal_classifier.train(label, document)
    end

    # Wraps +feature+ in a Java Document and returns whatever the Java
    # classifier's +classify+ returns for it.
    def classify(feature)
      document = ::Java::Bae::Document.new(feature)
      internal_classifier.classify(document)
    end
  end
end
@@ -0,0 +1,3 @@
1
module Bae
  # Release version of the gem; bae.gemspec reads this constant.
  VERSION = '0.0.1'
end
data/lib/bae.rb ADDED
@@ -0,0 +1,12 @@
1
# Entry point for the bae gem: loads the bundled jar, imports its Java
# classes and then loads the Ruby wrapper.
require "bae/version"

# Resolve the jar relative to this file. The gemspec only puts "lib" on the
# load path, so a bare "target/bae.jar" would depend on the load path or the
# current working directory of whoever requires the gem.
require File.expand_path("../../target/bae.jar", __FILE__)

java_import "bae.Document"
java_import "bae.FrequencyTable"
java_import "bae.NaiveBayesClassifier"

require "bae/classifier"

module Bae
end
@@ -0,0 +1,27 @@
1
+ require 'spec_helper'
2
+
3
# Exercises the Ruby wrapper end to end: Ruby hashes and strings go in, and
# the posteriors computed on the Java side are read back by label.
describe ::Bae::Classifier do
  subject(:classifier) { described_class.new }

  it "can classify from ruby to java with a hash document" do
    classifier.train("positive", { "aaa" => 0, "bbb" => 1 })
    classifier.train("negative", { "ccc" => 2, "ddd" => 3 })

    scores = classifier.classify({ "aaa" => 1, "bbb" => 1 })

    expect(scores["positive"]).to be_within(0.001).of(0.94117)
    expect(scores["negative"]).to be_within(0.001).of(0.05882)
  end

  it "can classify from ruby to java with a string based document" do
    classifier.train("positive", "aaa aaa bbb")
    classifier.train("negative", "ccc ccc ddd ddd")
    classifier.train("neutral", "eee eee eee fff fff fff")

    scores = classifier.classify("aaa bbb")

    expect(scores["positive"]).to be_within(0.001).of(0.89626)
    expect(scores["negative"]).to be_within(0.001).of(0.06639)
    expect(scores["neutral"]).to be_within(0.001).of(0.03734)
  end
end
@@ -0,0 +1,8 @@
1
# Shared spec setup: loads Bundler, the gem under test and RSpec, then
# configures the runner.
require 'bundler/setup'
require 'bae'
require 'rspec'

RSpec.configure do |config|
  config.color = true
  config.order = :rand
end
@@ -0,0 +1,42 @@
1
+ package bae;
2
+
3
+ import java.util.HashMap;
4
+ import java.util.Map;
5
+ import java.util.Scanner;
6
+
7
+ public class Document {
8
+
9
+ private Map<String, Long> frequencyMap;
10
+
11
+ public Document(String text) {
12
+ createFrequencyMap(text);
13
+ }
14
+
15
+ public Document(Map<String, Long> frequencyMap) {
16
+ this.frequencyMap = frequencyMap;
17
+ }
18
+
19
+ public Map<String, Long> getFrequencyMap() {
20
+ return frequencyMap;
21
+ }
22
+
23
+ public void addZeroCount(String key) {
24
+ this.frequencyMap.put(key, 0L);
25
+ }
26
+
27
+ private void createFrequencyMap(String text) {
28
+ this.frequencyMap = new HashMap<>();
29
+
30
+ Scanner parser = new Scanner(text);
31
+ while(parser.hasNext()) {
32
+ String wordToken = parser.next();
33
+
34
+ // Set initial count if it doesn't have one yet
35
+ // Use zero because we'll add counts in the next line.
36
+ this.frequencyMap.putIfAbsent(wordToken, 0L);
37
+
38
+ // Update count
39
+ this.frequencyMap.put(wordToken, this.frequencyMap.get(wordToken) + 1);
40
+ }
41
+ }
42
+ }
@@ -0,0 +1,44 @@
1
+ package bae;
2
+
3
+ import java.util.HashMap;
4
+ import java.util.Map;
5
+ import java.util.Set;
6
+
7
+ public class FrequencyTable {
8
+
9
+ private Map<String, Map<String, Long>> frequencyTable;
10
+
11
+ public FrequencyTable() {
12
+ this.frequencyTable = new HashMap<>();
13
+ }
14
+
15
+ public void insertOrIgnore(String label) {
16
+ // Add new hash to frequency table if it's not already there
17
+ this.frequencyTable.putIfAbsent(label, new HashMap<String, Long>());
18
+ }
19
+
20
+ public void increaseFrequencyBy(String label, String word, long frequency) {
21
+ // Add label if it doesn't exist
22
+ insertOrIgnore(label);
23
+
24
+ Map<String, Long> frequencyRow = this.frequencyTable.get(label);
25
+
26
+ // Make sure we have a frequency for that position in the table
27
+ frequencyRow.putIfAbsent(word, 0L);
28
+
29
+ // Update frequency
30
+ frequencyRow.put(word, frequencyRow.get(word) + frequency);
31
+ }
32
+
33
+ public Set<String> getLabels() {
34
+ return this.frequencyTable.keySet();
35
+ }
36
+
37
+ public long get(String label, String word) {
38
+ try {
39
+ return this.frequencyTable.get(label).get(word);
40
+ } catch (NullPointerException e) {
41
+ return 0L;
42
+ }
43
+ }
44
+ }
@@ -0,0 +1,120 @@
1
+ package bae;
2
+
3
+ import java.util.HashMap;
4
+ import java.util.Map;
5
+
6
+ public class NaiveBayesClassifier {
7
+
8
+ private FrequencyTable frequencyTable;
9
+ private Map<String, Long> wordTable;
10
+ private Map<String, Long> instanceCountOf;
11
+ private double totalCount = 0;
12
+
13
+ public NaiveBayesClassifier() {
14
+ this.frequencyTable = new FrequencyTable();
15
+ this.wordTable = new HashMap<>();
16
+ this.instanceCountOf = new HashMap<>();
17
+ }
18
+
19
+ public void train(String label, Document document) {
20
+ // Add to the frequency table if it doesn't exist
21
+ this.frequencyTable.insertOrIgnore(label);
22
+
23
+ // Update frequency table with the documents frequency
24
+ for(Map.Entry<String, Long> entry : document.getFrequencyMap().entrySet()) {
25
+ String word = entry.getKey();
26
+ long frequency = entry.getValue();
27
+
28
+ // Update counts
29
+ this.frequencyTable.increaseFrequencyBy(label, word, frequency);
30
+ // Add the word's presence to the word table
31
+ this.wordTable.put(word, 1L);
32
+ }
33
+
34
+ // Update global counts
35
+ totalCount += 1;
36
+ // Update instance count
37
+ updateIntegerCountBy(this.instanceCountOf, label, 1);
38
+ }
39
+
40
+ public Map<String, Double> classify(Document document) {
41
+ Map<String, Double> classPriorOf = new HashMap<>();
42
+ Map<String, Double> likelihoodOf = new HashMap<>();
43
+ Map<String, Double> classPosteriorOf = new HashMap<>();
44
+ Map<String, Long> frequencyMap = document.getFrequencyMap();
45
+ double evidence = 0;
46
+
47
+ // Update the prior
48
+ for(Map.Entry<String, Long> entry : this.instanceCountOf.entrySet()) {
49
+ String label = entry.getKey();
50
+ double frequency = entry.getValue();
51
+
52
+ // Update instance count
53
+ classPriorOf.put(label, (frequency / this.totalCount));
54
+ }
55
+
56
+ // Update likelihood counts
57
+ for(String label : this.frequencyTable.getLabels()) {
58
+ // Set initial likelihood
59
+ likelihoodOf.put(label, 1d);
60
+
61
+ // Calculate likelihoods
62
+ for(String word : wordTable.keySet()) {
63
+ double laplaceWordLikelihood =
64
+ (this.frequencyTable.get(label, word) + 1d) /
65
+ (this.instanceCountOf.get(label) + this.wordTable.size());
66
+
67
+ // Update likelihood
68
+ double likelihood = likelihoodOf.get(label);
69
+ if(frequencyMap.containsKey(word)) {
70
+ likelihoodOf.put(label, likelihood * laplaceWordLikelihood);
71
+ } else {
72
+ likelihoodOf.put(label, likelihood * (1d - laplaceWordLikelihood));
73
+ }
74
+ }
75
+
76
+ // Default class posterior of label to 1.0
77
+ classPosteriorOf.putIfAbsent(label, 1d);
78
+
79
+ // Update class posterior
80
+ double classPosterior = classPriorOf.get(label) * likelihoodOf.get(label);
81
+ classPosteriorOf.put(label, classPosterior);
82
+ evidence += classPosterior;
83
+ }
84
+
85
+ // Normalize results
86
+ for(Map.Entry<String, Double> entry : classPosteriorOf.entrySet()) {
87
+ String label = entry.getKey();
88
+ double posterior = entry.getValue();
89
+ classPosteriorOf.put(label, posterior / evidence);
90
+ }
91
+
92
+ return classPosteriorOf;
93
+ }
94
+
95
+ public void updateIntegerCountBy(Map<String, Long> someMap, String someKey, long count) {
96
+ someMap.putIfAbsent(someKey, 0L);
97
+ someMap.put(someKey, someMap.get(someKey) + count);
98
+ }
99
+
100
+ public void updateDoubleCountBy(Map<String, Double> someMap, String someKey, double count) {
101
+ someMap.putIfAbsent(someKey, 0.0);
102
+ someMap.put(someKey, someMap.get(someKey) + count);
103
+ }
104
+
105
+ public FrequencyTable getFrequencyTable() {
106
+ return this.frequencyTable;
107
+ }
108
+
109
+ public Map<String, Long> getWordTable() {
110
+ return this.wordTable;
111
+ }
112
+
113
+ public Map<String, Long> getInstanceCount() {
114
+ return this.instanceCountOf;
115
+ }
116
+
117
+ public double getTotalCount() {
118
+ return totalCount;
119
+ }
120
+ }
@@ -0,0 +1,34 @@
1
+ package bae;
2
+
3
+ import bae.Document;
4
+ import org.junit.Test;
5
+
6
+ import java.util.Map;
7
+
8
+ import static org.junit.Assert.*;
9
+
10
+ public class DocumentTest {
11
+
12
+ @Test
13
+ public void testCanCreateAccurateFrequencyTable() {
14
+ Document document = new Document("aaa bbb aaa bbb ccc");
15
+ Map<String, Long> frequencyMap = document.getFrequencyMap();
16
+
17
+ assertEquals(2, (long)frequencyMap.get("aaa"));
18
+ assertEquals(2, (long)frequencyMap.get("bbb"));
19
+ assertEquals(1, (long)frequencyMap.get("ccc"));
20
+ }
21
+
22
+ @Test
23
+ public void testCanParseADirtyString() {
24
+ Document document = new Document(" a aaa\ta \t\t\t aa a bbb a aaa bbb ccc ");
25
+ Map<String, Long> frequencyMap = document.getFrequencyMap();
26
+
27
+ assertEquals(2, (long)frequencyMap.get("aaa"));
28
+ assertEquals(2, (long)frequencyMap.get("bbb"));
29
+ assertEquals(1, (long)frequencyMap.get("ccc"));
30
+ assertEquals(1, (long)frequencyMap.get("aa"));
31
+ assertEquals(4, (long)frequencyMap.get("a"));
32
+ }
33
+
34
+ }
@@ -0,0 +1,23 @@
1
+ package bae;
2
+
3
+ import bae.FrequencyTable;
4
+ import org.junit.Test;
5
+
6
+ import static org.junit.Assert.*;
7
+
8
+ public class FrequencyTableTest {
9
+
10
+ @Test
11
+ public void canInsertIntoFrequencyTable() {
12
+ FrequencyTable frequencyTable = new FrequencyTable();
13
+
14
+ frequencyTable.increaseFrequencyBy("a", "b", 100);
15
+
16
+ assertEquals(100, frequencyTable.get("a", "b"));
17
+
18
+ // Make sure we fail correctly
19
+ assertEquals(0, frequencyTable.get("a", "z"));
20
+ assertEquals(0, frequencyTable.get("z", "z"));
21
+ }
22
+
23
+ }
@@ -0,0 +1,134 @@
1
+ package bae;
2
+
3
+ import bae.Document;
4
+ import bae.NaiveBayesClassifier;
5
+ import org.junit.Test;
6
+
7
+ import java.util.Map;
8
+
9
+ import static org.junit.Assert.*;
10
+
11
+ public class NaiveBayesClassifierTest {
12
+
13
+ @Test
14
+ public void canCorrectlyTrainFrequencyTable() {
15
+ NaiveBayesClassifier n = new NaiveBayesClassifier();
16
+
17
+ n.train("positive", new Document("bbb"));
18
+ n.train("negative", new Document("ccc ccc ddd ddd ddd"));
19
+ n.train("positive", new Document("aaa bbb bbb"));
20
+ n.train("negative", new Document("ccc ccc ccc ddd ddd ddd ddd"));
21
+
22
+ assertEquals(1, n.getFrequencyTable().get("positive", "aaa"));
23
+ assertEquals(3, n.getFrequencyTable().get("positive", "bbb"));
24
+ assertEquals(5, n.getFrequencyTable().get("negative", "ccc"));
25
+ assertEquals(7, n.getFrequencyTable().get("negative", "ddd"));
26
+ }
27
+
28
+ @Test
29
+ public void canCorrectlyTrainWordTable() {
30
+ NaiveBayesClassifier n = new NaiveBayesClassifier();
31
+
32
+ n.train("positive", new Document("bbb"));
33
+ n.train("negative", new Document("ccc ccc ddd ddd ddd"));
34
+ n.train("positive", new Document("aaa bbb bbb"));
35
+ n.train("negative", new Document("ccc ccc ccc ddd ddd ddd ddd"));
36
+
37
+ assertEquals(1, (long)n.getWordTable().get("aaa"));
38
+ assertEquals(1, (long)n.getWordTable().get("bbb"));
39
+ assertEquals(1, (long)n.getWordTable().get("ccc"));
40
+ assertEquals(1, (long)n.getWordTable().get("ddd"));
41
+ }
42
+
43
+ @Test
44
+ public void canCorrectlyTrainInstanceCount() {
45
+ NaiveBayesClassifier n = new NaiveBayesClassifier();
46
+
47
+ n.train("positive", new Document("bbb"));
48
+ n.train("negative", new Document("ccc ccc ddd ddd ddd"));
49
+ n.train("positive", new Document("aaa bbb bbb"));
50
+ n.train("negative", new Document("ccc ccc ccc ddd ddd ddd ddd"));
51
+ n.train("negative", new Document("ccc ccc ccc ddd ddd"));
52
+
53
+ assertEquals(2, (long)n.getInstanceCount().get("positive"));
54
+ assertEquals(3, (long)n.getInstanceCount().get("negative"));
55
+ assertEquals(5, (long)n.getTotalCount());
56
+ }
57
+
58
+ @Test
59
+ public void canCorrectlyClassifyPositiveWithTwoLabels() {
60
+ NaiveBayesClassifier n = new NaiveBayesClassifier();
61
+
62
+ Document d = new Document("bbb");
63
+ d.addZeroCount("aaa");
64
+
65
+ n.train("positive", d);
66
+ n.train("negative", new Document("ccc ccc ddd ddd ddd"));
67
+
68
+ Map<String, Double> results = n.classify(new Document("aaa bbb"));
69
+
70
+ assertEquals(0.9411764705882353, results.get("positive"), 0.00001);
71
+ assertEquals(0.05882352941176469, results.get("negative"), 0.00001);
72
+ }
73
+
74
+ @Test
75
+ public void canCorrectlyClassifyNegativeWithTwoLabels() {
76
+ NaiveBayesClassifier n = new NaiveBayesClassifier();
77
+
78
+ Document d = new Document("bbb");
79
+ d.addZeroCount("aaa");
80
+
81
+ n.train("positive", d);
82
+ n.train("negative", new Document("ccc ccc ddd ddd ddd"));
83
+
84
+ Map<String, Double> results = n.classify(new Document("ccc ccc ccc ddd ddd ddd"));
85
+
86
+ assertEquals(0.05882352941176469, results.get("positive"), 0.00001);
87
+ assertEquals(0.9411764705882353, results.get("negative"), 0.00001);
88
+ }
89
+
90
+ @Test
91
+ public void canCorrectlyClassifyPositiveWithThreeLabels() {
92
+ NaiveBayesClassifier n = new NaiveBayesClassifier();
93
+
94
+ n.train("positive", new Document("aaa aaa bbb"));
95
+ n.train("negative", new Document("ccc ccc ddd ddd"));
96
+ n.train("neutral", new Document("eee eee eee fff fff fff"));
97
+
98
+ Map<String, Double> results = n.classify(new Document("aaa bbb"));
99
+
100
+ assertEquals(0.896265560165975, results.get("positive"), 0.00001);
101
+ assertEquals(0.06639004149377592, results.get("negative"), 0.00001);
102
+ assertEquals(0.03734439834024896, results.get("neutral"), 0.00001);
103
+ }
104
+
105
+ @Test
106
+ public void canCorrectlyClassifyNegativeWithThreeLabels() {
107
+ NaiveBayesClassifier n = new NaiveBayesClassifier();
108
+
109
+ n.train("positive", new Document("aaa aaa bbb"));
110
+ n.train("negative", new Document("ccc ccc ddd ddd"));
111
+ n.train("neutral", new Document("eee eee eee fff fff fff"));
112
+
113
+ Map<String, Double> results = n.classify(new Document("ccc ccc ccc ddd ddd"));
114
+
115
+ assertEquals(0.05665722379603399, results.get("positive"), 0.00001);
116
+ assertEquals(0.9178470254957508, results.get("negative"), 0.00001);
117
+ assertEquals(0.0254957507082153, results.get("neutral"), 0.00001);
118
+ }
119
+
120
+ @Test
121
+ public void canCorrectlyClassifyNeutralWithThreeLabels() {
122
+ NaiveBayesClassifier n = new NaiveBayesClassifier();
123
+
124
+ n.train("positive", new Document("aaa aaa bbb"));
125
+ n.train("negative", new Document("ccc ccc ddd ddd"));
126
+ n.train("neutral", new Document("eee eee eee fff fff fff"));
127
+
128
+ Map<String, Double> results = n.classify(new Document("aaa ddd ddd eee eee eee fff"));
129
+
130
+ assertEquals(0.12195121951219513, results.get("positive"), 0.00001);
131
+ assertEquals(0.09756097560975606, results.get("negative"), 0.00001);
132
+ assertEquals(0.7804878048780488, results.get("neutral"), 0.00001);
133
+ }
134
+ }
data/target/bae.jar ADDED
Binary file