bae 0.0.1

Sign up to get free protection for your applications and to get access to all the features.
data/README.md ADDED
@@ -0,0 +1,54 @@
1
+ Bae
2
+ ===
3
+
4
+ Bae is a multinomial naive Bayes classifier based on the ["naivebayes"](https://github.com/id774/naivebayes) gem, except that this one uses Java to do the heavy lifting.
5
+
6
+ ## Installation
7
+
8
+ Add this line to your application's Gemfile:
9
+
10
+ gem 'bae'
11
+
12
+ And then execute:
13
+
14
+ $ bundle
15
+
16
+ Or install it yourself as:
17
+
18
+ $ gem install bae
19
+
20
+ ## Usage
21
+
22
+ You can refer to the ["naivebayes"](https://github.com/id774/naivebayes) gem for more documentation, or to the tests for examples. Here is a copy-and-paste example:
23
+
24
+
25
+ ### You can provide a frequency hash to the trainer
26
+
27
+ ```ruby
28
+ classifier = ::Bae::Classifier.new
29
+ classifier.train("positive", {"aaa" => 0, "bbb" => 1})
30
+ classifier.train("negative", {"ccc" => 2, "ddd" => 3})
31
+ classifier.classify({"aaa" => 1, "bbb" => 1})
32
+
33
+ #=> {"positive" => 0.9411764705882353, "negative" => 0.05882352941176469}
34
+ ```
35
+
36
+ ### Or you can train with strings
37
+ ```ruby
38
+ classifier = ::Bae::Classifier.new
39
+ classifier.train("positive", "aaa aaa bbb")
39
+ classifier.train("negative", "ccc ccc ddd ddd")
40
+ classifier.train("neutral", "eee eee eee fff fff fff")
42
+ classifier.classify("aaa bbb")
43
+
44
+ #=> {"positive"=>0.8962655601659751, "negative"=>0.0663900414937759, "neutral"=>0.037344398340248955}
45
+ ```
46
+
47
+
48
+ ## Contributing
49
+
50
+ 1. Fork it ( https://github.com/film42/bae/fork )
51
+ 2. Create your feature branch (`git checkout -b my-new-feature`)
52
+ 3. Commit your changes (`git commit -am 'Add some feature'`)
53
+ 4. Push to the branch (`git push origin my-new-feature`)
54
+ 5. Create a new Pull Request
data/Rakefile ADDED
@@ -0,0 +1,8 @@
1
+ require "bundler/gem_tasks"
2
+ require "rspec/core/rake_task"
3
+
4
# `rake spec` runs the RSpec suite under spec/.
desc "Run specs"
RSpec::Core::RakeTask.new(:spec)

# A bare `rake` invocation falls through to the spec task.
desc "Run specs (default)"
task default: [:spec]
data/bae.gemspec ADDED
@@ -0,0 +1,24 @@
1
+ # coding: utf-8
2
+ lib = File.expand_path('../lib', __FILE__)
3
+ $LOAD_PATH.unshift(lib) unless $LOAD_PATH.include?(lib)
4
+ require 'bae/version'
5
+
6
Gem::Specification.new do |spec|
  spec.name          = "bae"
  spec.version       = Bae::VERSION
  spec.authors       = ["Garrett Thornburg"]
  spec.email         = ["film42@gmail.com"]
  spec.summary       = "Multinomial naive bayes classifier with a kick of java"
  spec.description   = "Multinomial naive bayes classifier with a kick of java"
  spec.homepage      = "https://github.com/film42/bae"
  spec.license       = "GPL version 3, or LGPL version 3 (Dual License)"

  # lib/bae.rb loads target/bae.jar and calls java_import, which only works under
  # JRuby. Declaring the java platform keeps the gem from being installed (and then
  # failing at require time) on non-JRuby interpreters.
  spec.platform      = "java"

  # Package every tracked file, including the prebuilt target/bae.jar.
  spec.files         = `git ls-files -z`.split("\x0")
  spec.executables   = spec.files.grep(%r{^bin/}) { |f| File.basename(f) }
  spec.test_files    = spec.files.grep(%r{^(test|spec|features)/})
  spec.require_paths = ["lib"]

  # "~> 1.6" would refuse Bundler 2.x; any Bundler from 1.6 on can drive this gem.
  spec.add_development_dependency "bundler", ">= 1.6"
  spec.add_development_dependency "rspec"
  spec.add_development_dependency "rake"
end
data/build.xml ADDED
@@ -0,0 +1,18 @@
1
<project name="bae" default="jar">

  <!-- Remove all build output: compiled classes in out/ and the packaged jar in target/.
       Run the "jar" target afterwards to regenerate target/bae.jar. -->
  <target name="clean">
    <delete dir="out"/>
    <delete dir="target"/>
  </target>

  <!-- Compile the Java sources under src/main/java into out/classes. -->
  <target name="compile">
    <mkdir dir="out/classes"/>
    <!-- includeantruntime="false" keeps Ant's own jars off the compile classpath
         and silences the warning Ant emits when the attribute is left unset. -->
    <javac srcdir="src/main/java" destdir="out/classes" includeantruntime="false"/>
  </target>

  <!-- Package the compiled classes as target/bae.jar, the path lib/bae.rb loads. -->
  <target name="jar" depends="compile">
    <mkdir dir="target"/>
    <jar destfile="target/bae.jar" basedir="out/classes" excludes="**/*.jar,**/MANIFEST.MF,**/BCKEY.*" />
  </target>

</project>
@@ -0,0 +1,19 @@
1
module Bae
  # Ruby-facing wrapper around the Java classifier shipped in target/bae.jar.
  #
  # Every call wraps its +feature+ argument in a Java bae.Document and forwards it
  # to the single bae.NaiveBayesClassifier instance owned by this object.
  class Classifier
    # The underlying Java NaiveBayesClassifier instance.
    attr_reader :internal_classifier

    def initialize
      @internal_classifier = ::Java::Bae::NaiveBayesClassifier.new
    end

    # Adds one training example under +label+.
    #
    # +feature+ is either a String of whitespace-separated words or a Hash of
    # word => count (see the README examples). JRuby picks the matching Java
    # Document constructor. Returns the result of the Java train call (void, so nil).
    def train(label, feature)
      internal_classifier.train(label, document_for(feature))
    end

    # Classifies +feature+ (same accepted forms as #train) and returns the Java
    # classifier's label => probability map.
    def classify(feature)
      internal_classifier.classify(document_for(feature))
    end

    private

    # Wraps a String or Hash feature in a Java bae.Document.
    def document_for(feature)
      ::Java::Bae::Document.new(feature)
    end
  end
end
@@ -0,0 +1,3 @@
1
module Bae
  # Gem version string, read by bae.gemspec. Frozen so this shared constant
  # cannot be mutated in place by callers.
  VERSION = "0.0.1".freeze
end
data/lib/bae.rb ADDED
@@ -0,0 +1,12 @@
1
require "bae/version"

# Load the bundled jar by absolute path. A bare "target/bae.jar" is looked up on
# $LOAD_PATH, which (per the gemspec's require_paths) contains lib/ but not the
# gem root, so the jar would not be found once the gem is installed.
require File.expand_path("../../target/bae.jar", __FILE__)

# Import the Java classes so they can be referenced by their short names.
# Bae::Classifier itself addresses them fully qualified as ::Java::Bae::*.
java_import "bae.Document"
java_import "bae.FrequencyTable"
java_import "bae.NaiveBayesClassifier"

require "bae/classifier"

module Bae
end
@@ -0,0 +1,27 @@
1
+ require 'spec_helper'
2
+
3
describe ::Bae::Classifier do
  subject(:classifier) { described_class.new }

  it "can classify from ruby to java with a hash document" do
    classifier.train("positive", { "aaa" => 0, "bbb" => 1 })
    classifier.train("negative", { "ccc" => 2, "ddd" => 3 })

    results = classifier.classify({ "aaa" => 1, "bbb" => 1 })

    expect(results["positive"]).to be_within(0.001).of(0.94117)
    expect(results["negative"]).to be_within(0.001).of(0.05882)
  end

  it "can classify from ruby to java with a string based document" do
    # Train in insertion order: positive, negative, neutral.
    {
      "positive" => "aaa aaa bbb",
      "negative" => "ccc ccc ddd ddd",
      "neutral" => "eee eee eee fff fff fff"
    }.each { |label, text| classifier.train(label, text) }

    results = classifier.classify("aaa bbb")

    { "positive" => 0.89626, "negative" => 0.06639, "neutral" => 0.03734 }.each do |label, expected|
      expect(results[label]).to be_within(0.001).of(expected)
    end
  end
end
@@ -0,0 +1,8 @@
1
+ require 'bundler/setup'
2
+ require 'bae'
3
+ require 'rspec'
4
+
5
RSpec.configure do |config|
  # Run examples in random order so hidden inter-example dependencies surface.
  config.order = :rand
  # Enable colored output.
  config.color = true
end
@@ -0,0 +1,42 @@
1
+ package bae;
2
+
3
+ import java.util.HashMap;
4
+ import java.util.Map;
5
+ import java.util.Scanner;
6
+
7
+ public class Document {
8
+
9
+ private Map<String, Long> frequencyMap;
10
+
11
+ public Document(String text) {
12
+ createFrequencyMap(text);
13
+ }
14
+
15
+ public Document(Map<String, Long> frequencyMap) {
16
+ this.frequencyMap = frequencyMap;
17
+ }
18
+
19
+ public Map<String, Long> getFrequencyMap() {
20
+ return frequencyMap;
21
+ }
22
+
23
+ public void addZeroCount(String key) {
24
+ this.frequencyMap.put(key, 0L);
25
+ }
26
+
27
+ private void createFrequencyMap(String text) {
28
+ this.frequencyMap = new HashMap<>();
29
+
30
+ Scanner parser = new Scanner(text);
31
+ while(parser.hasNext()) {
32
+ String wordToken = parser.next();
33
+
34
+ // Set initial count if it doesn't have one yet
35
+ // Use zero because we'll add counts in the next line.
36
+ this.frequencyMap.putIfAbsent(wordToken, 0L);
37
+
38
+ // Update count
39
+ this.frequencyMap.put(wordToken, this.frequencyMap.get(wordToken) + 1);
40
+ }
41
+ }
42
+ }
@@ -0,0 +1,44 @@
1
+ package bae;
2
+
3
+ import java.util.HashMap;
4
+ import java.util.Map;
5
+ import java.util.Set;
6
+
7
public class FrequencyTable {

    // label -> (word -> accumulated frequency)
    private Map<String, Map<String, Long>> frequencyTable;

    public FrequencyTable() {
        this.frequencyTable = new HashMap<>();
    }

    /**
     * Ensures {@code label} has a (possibly empty) row; does nothing if it already exists.
     */
    public void insertOrIgnore(String label) {
        // computeIfAbsent only allocates a row when the label is actually new.
        this.frequencyTable.computeIfAbsent(label, key -> new HashMap<>());
    }

    /**
     * Adds {@code frequency} to the count stored for ({@code label}, {@code word}),
     * creating the label row and/or word entry (starting from 0) as needed.
     */
    public void increaseFrequencyBy(String label, String word, long frequency) {
        this.frequencyTable
                .computeIfAbsent(label, key -> new HashMap<>())
                .merge(word, frequency, Long::sum);
    }

    /**
     * Returns the live key set of labels (backed by the table, not a copy).
     */
    public Set<String> getLabels() {
        return this.frequencyTable.keySet();
    }

    /**
     * Returns the accumulated frequency for ({@code label}, {@code word}), or 0 when
     * either the label or the word has never been recorded.
     */
    public long get(String label, String word) {
        // Explicit null checks replace the previous catch of NullPointerException,
        // which could also have hidden unrelated bugs.
        Map<String, Long> row = this.frequencyTable.get(label);
        if (row == null) {
            return 0L;
        }
        Long count = row.get(word);
        return count == null ? 0L : count;
    }
}
@@ -0,0 +1,120 @@
1
+ package bae;
2
+
3
+ import java.util.HashMap;
4
+ import java.util.Map;
5
+
6
+ public class NaiveBayesClassifier {
7
+
8
+ private FrequencyTable frequencyTable;
9
+ private Map<String, Long> wordTable;
10
+ private Map<String, Long> instanceCountOf;
11
+ private double totalCount = 0;
12
+
13
+ public NaiveBayesClassifier() {
14
+ this.frequencyTable = new FrequencyTable();
15
+ this.wordTable = new HashMap<>();
16
+ this.instanceCountOf = new HashMap<>();
17
+ }
18
+
19
+ public void train(String label, Document document) {
20
+ // Add to the frequency table if it doesn't exist
21
+ this.frequencyTable.insertOrIgnore(label);
22
+
23
+ // Update frequency table with the documents frequency
24
+ for(Map.Entry<String, Long> entry : document.getFrequencyMap().entrySet()) {
25
+ String word = entry.getKey();
26
+ long frequency = entry.getValue();
27
+
28
+ // Update counts
29
+ this.frequencyTable.increaseFrequencyBy(label, word, frequency);
30
+ // Add the word's presence to the word table
31
+ this.wordTable.put(word, 1L);
32
+ }
33
+
34
+ // Update global counts
35
+ totalCount += 1;
36
+ // Update instance count
37
+ updateIntegerCountBy(this.instanceCountOf, label, 1);
38
+ }
39
+
40
+ public Map<String, Double> classify(Document document) {
41
+ Map<String, Double> classPriorOf = new HashMap<>();
42
+ Map<String, Double> likelihoodOf = new HashMap<>();
43
+ Map<String, Double> classPosteriorOf = new HashMap<>();
44
+ Map<String, Long> frequencyMap = document.getFrequencyMap();
45
+ double evidence = 0;
46
+
47
+ // Update the prior
48
+ for(Map.Entry<String, Long> entry : this.instanceCountOf.entrySet()) {
49
+ String label = entry.getKey();
50
+ double frequency = entry.getValue();
51
+
52
+ // Update instance count
53
+ classPriorOf.put(label, (frequency / this.totalCount));
54
+ }
55
+
56
+ // Update likelihood counts
57
+ for(String label : this.frequencyTable.getLabels()) {
58
+ // Set initial likelihood
59
+ likelihoodOf.put(label, 1d);
60
+
61
+ // Calculate likelihoods
62
+ for(String word : wordTable.keySet()) {
63
+ double laplaceWordLikelihood =
64
+ (this.frequencyTable.get(label, word) + 1d) /
65
+ (this.instanceCountOf.get(label) + this.wordTable.size());
66
+
67
+ // Update likelihood
68
+ double likelihood = likelihoodOf.get(label);
69
+ if(frequencyMap.containsKey(word)) {
70
+ likelihoodOf.put(label, likelihood * laplaceWordLikelihood);
71
+ } else {
72
+ likelihoodOf.put(label, likelihood * (1d - laplaceWordLikelihood));
73
+ }
74
+ }
75
+
76
+ // Default class posterior of label to 1.0
77
+ classPosteriorOf.putIfAbsent(label, 1d);
78
+
79
+ // Update class posterior
80
+ double classPosterior = classPriorOf.get(label) * likelihoodOf.get(label);
81
+ classPosteriorOf.put(label, classPosterior);
82
+ evidence += classPosterior;
83
+ }
84
+
85
+ // Normalize results
86
+ for(Map.Entry<String, Double> entry : classPosteriorOf.entrySet()) {
87
+ String label = entry.getKey();
88
+ double posterior = entry.getValue();
89
+ classPosteriorOf.put(label, posterior / evidence);
90
+ }
91
+
92
+ return classPosteriorOf;
93
+ }
94
+
95
+ public void updateIntegerCountBy(Map<String, Long> someMap, String someKey, long count) {
96
+ someMap.putIfAbsent(someKey, 0L);
97
+ someMap.put(someKey, someMap.get(someKey) + count);
98
+ }
99
+
100
+ public void updateDoubleCountBy(Map<String, Double> someMap, String someKey, double count) {
101
+ someMap.putIfAbsent(someKey, 0.0);
102
+ someMap.put(someKey, someMap.get(someKey) + count);
103
+ }
104
+
105
+ public FrequencyTable getFrequencyTable() {
106
+ return this.frequencyTable;
107
+ }
108
+
109
+ public Map<String, Long> getWordTable() {
110
+ return this.wordTable;
111
+ }
112
+
113
+ public Map<String, Long> getInstanceCount() {
114
+ return this.instanceCountOf;
115
+ }
116
+
117
+ public double getTotalCount() {
118
+ return totalCount;
119
+ }
120
+ }
@@ -0,0 +1,34 @@
1
+ package bae;
2
+
3
+ import bae.Document;
4
+ import org.junit.Test;
5
+
6
+ import java.util.Map;
7
+
8
+ import static org.junit.Assert.*;
9
+
10
+ public class DocumentTest {
11
+
12
+ @Test
13
+ public void testCanCreateAccurateFrequencyTable() {
14
+ Document document = new Document("aaa bbb aaa bbb ccc");
15
+ Map<String, Long> frequencyMap = document.getFrequencyMap();
16
+
17
+ assertEquals(2, (long)frequencyMap.get("aaa"));
18
+ assertEquals(2, (long)frequencyMap.get("bbb"));
19
+ assertEquals(1, (long)frequencyMap.get("ccc"));
20
+ }
21
+
22
+ @Test
23
+ public void testCanParseADirtyString() {
24
+ Document document = new Document(" a aaa\ta \t\t\t aa a bbb a aaa bbb ccc ");
25
+ Map<String, Long> frequencyMap = document.getFrequencyMap();
26
+
27
+ assertEquals(2, (long)frequencyMap.get("aaa"));
28
+ assertEquals(2, (long)frequencyMap.get("bbb"));
29
+ assertEquals(1, (long)frequencyMap.get("ccc"));
30
+ assertEquals(1, (long)frequencyMap.get("aa"));
31
+ assertEquals(4, (long)frequencyMap.get("a"));
32
+ }
33
+
34
+ }
@@ -0,0 +1,23 @@
1
+ package bae;
2
+
3
+ import bae.FrequencyTable;
4
+ import org.junit.Test;
5
+
6
+ import static org.junit.Assert.*;
7
+
8
+ public class FrequencyTableTest {
9
+
10
+ @Test
11
+ public void canInsertIntoFrequencyTable() {
12
+ FrequencyTable frequencyTable = new FrequencyTable();
13
+
14
+ frequencyTable.increaseFrequencyBy("a", "b", 100);
15
+
16
+ assertEquals(100, frequencyTable.get("a", "b"));
17
+
18
+ // Make sure we fail correctly
19
+ assertEquals(0, frequencyTable.get("a", "z"));
20
+ assertEquals(0, frequencyTable.get("z", "z"));
21
+ }
22
+
23
+ }
@@ -0,0 +1,134 @@
1
+ package bae;
2
+
3
+ import bae.Document;
4
+ import bae.NaiveBayesClassifier;
5
+ import org.junit.Test;
6
+
7
+ import java.util.Map;
8
+
9
+ import static org.junit.Assert.*;
10
+
11
+ public class NaiveBayesClassifierTest {
12
+
13
+ @Test
14
+ public void canCorrectlyTrainFrequencyTable() {
15
+ NaiveBayesClassifier n = new NaiveBayesClassifier();
16
+
17
+ n.train("positive", new Document("bbb"));
18
+ n.train("negative", new Document("ccc ccc ddd ddd ddd"));
19
+ n.train("positive", new Document("aaa bbb bbb"));
20
+ n.train("negative", new Document("ccc ccc ccc ddd ddd ddd ddd"));
21
+
22
+ assertEquals(1, n.getFrequencyTable().get("positive", "aaa"));
23
+ assertEquals(3, n.getFrequencyTable().get("positive", "bbb"));
24
+ assertEquals(5, n.getFrequencyTable().get("negative", "ccc"));
25
+ assertEquals(7, n.getFrequencyTable().get("negative", "ddd"));
26
+ }
27
+
28
+ @Test
29
+ public void canCorrectlyTrainWordTable() {
30
+ NaiveBayesClassifier n = new NaiveBayesClassifier();
31
+
32
+ n.train("positive", new Document("bbb"));
33
+ n.train("negative", new Document("ccc ccc ddd ddd ddd"));
34
+ n.train("positive", new Document("aaa bbb bbb"));
35
+ n.train("negative", new Document("ccc ccc ccc ddd ddd ddd ddd"));
36
+
37
+ assertEquals(1, (long)n.getWordTable().get("aaa"));
38
+ assertEquals(1, (long)n.getWordTable().get("bbb"));
39
+ assertEquals(1, (long)n.getWordTable().get("ccc"));
40
+ assertEquals(1, (long)n.getWordTable().get("ddd"));
41
+ }
42
+
43
+ @Test
44
+ public void canCorrectlyTrainInstanceCount() {
45
+ NaiveBayesClassifier n = new NaiveBayesClassifier();
46
+
47
+ n.train("positive", new Document("bbb"));
48
+ n.train("negative", new Document("ccc ccc ddd ddd ddd"));
49
+ n.train("positive", new Document("aaa bbb bbb"));
50
+ n.train("negative", new Document("ccc ccc ccc ddd ddd ddd ddd"));
51
+ n.train("negative", new Document("ccc ccc ccc ddd ddd"));
52
+
53
+ assertEquals(2, (long)n.getInstanceCount().get("positive"));
54
+ assertEquals(3, (long)n.getInstanceCount().get("negative"));
55
+ assertEquals(5, (long)n.getTotalCount());
56
+ }
57
+
58
+ @Test
59
+ public void canCorrectlyClassifyPositiveWithTwoLabels() {
60
+ NaiveBayesClassifier n = new NaiveBayesClassifier();
61
+
62
+ Document d = new Document("bbb");
63
+ d.addZeroCount("aaa");
64
+
65
+ n.train("positive", d);
66
+ n.train("negative", new Document("ccc ccc ddd ddd ddd"));
67
+
68
+ Map<String, Double> results = n.classify(new Document("aaa bbb"));
69
+
70
+ assertEquals(0.9411764705882353, results.get("positive"), 0.00001);
71
+ assertEquals(0.05882352941176469, results.get("negative"), 0.00001);
72
+ }
73
+
74
+ @Test
75
+ public void canCorrectlyClassifyNegativeWithTwoLabels() {
76
+ NaiveBayesClassifier n = new NaiveBayesClassifier();
77
+
78
+ Document d = new Document("bbb");
79
+ d.addZeroCount("aaa");
80
+
81
+ n.train("positive", d);
82
+ n.train("negative", new Document("ccc ccc ddd ddd ddd"));
83
+
84
+ Map<String, Double> results = n.classify(new Document("ccc ccc ccc ddd ddd ddd"));
85
+
86
+ assertEquals(0.05882352941176469, results.get("positive"), 0.00001);
87
+ assertEquals(0.9411764705882353, results.get("negative"), 0.00001);
88
+ }
89
+
90
+ @Test
91
+ public void canCorrectlyClassifyPositiveWithThreeLabels() {
92
+ NaiveBayesClassifier n = new NaiveBayesClassifier();
93
+
94
+ n.train("positive", new Document("aaa aaa bbb"));
95
+ n.train("negative", new Document("ccc ccc ddd ddd"));
96
+ n.train("neutral", new Document("eee eee eee fff fff fff"));
97
+
98
+ Map<String, Double> results = n.classify(new Document("aaa bbb"));
99
+
100
+ assertEquals(0.896265560165975, results.get("positive"), 0.00001);
101
+ assertEquals(0.06639004149377592, results.get("negative"), 0.00001);
102
+ assertEquals(0.03734439834024896, results.get("neutral"), 0.00001);
103
+ }
104
+
105
+ @Test
106
+ public void canCorrectlyClassifyNegativeWithThreeLabels() {
107
+ NaiveBayesClassifier n = new NaiveBayesClassifier();
108
+
109
+ n.train("positive", new Document("aaa aaa bbb"));
110
+ n.train("negative", new Document("ccc ccc ddd ddd"));
111
+ n.train("neutral", new Document("eee eee eee fff fff fff"));
112
+
113
+ Map<String, Double> results = n.classify(new Document("ccc ccc ccc ddd ddd"));
114
+
115
+ assertEquals(0.05665722379603399, results.get("positive"), 0.00001);
116
+ assertEquals(0.9178470254957508, results.get("negative"), 0.00001);
117
+ assertEquals(0.0254957507082153, results.get("neutral"), 0.00001);
118
+ }
119
+
120
+ @Test
121
+ public void canCorrectlyClassifyNeutralWithThreeLabels() {
122
+ NaiveBayesClassifier n = new NaiveBayesClassifier();
123
+
124
+ n.train("positive", new Document("aaa aaa bbb"));
125
+ n.train("negative", new Document("ccc ccc ddd ddd"));
126
+ n.train("neutral", new Document("eee eee eee fff fff fff"));
127
+
128
+ Map<String, Double> results = n.classify(new Document("aaa ddd ddd eee eee eee fff"));
129
+
130
+ assertEquals(0.12195121951219513, results.get("positive"), 0.00001);
131
+ assertEquals(0.09756097560975606, results.get("negative"), 0.00001);
132
+ assertEquals(0.7804878048780488, results.get("neutral"), 0.00001);
133
+ }
134
+ }
data/target/bae.jar ADDED
Binary file