bae 0.0.1
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +7 -0
- data/.gitignore +26 -0
- data/Gemfile +4 -0
- data/LICENSE.txt +855 -0
- data/README.md +54 -0
- data/Rakefile +8 -0
- data/bae.gemspec +24 -0
- data/build.xml +18 -0
- data/lib/bae/classifier.rb +19 -0
- data/lib/bae/version.rb +3 -0
- data/lib/bae.rb +12 -0
- data/spec/lib/bae/classifier_spec.rb +27 -0
- data/spec/spec_helper.rb +8 -0
- data/src/main/java/bae/Document.java +42 -0
- data/src/main/java/bae/FrequencyTable.java +44 -0
- data/src/main/java/bae/NaiveBayesClassifier.java +120 -0
- data/src/test/java/bae/DocumentTest.java +34 -0
- data/src/test/java/bae/FrequencyTableTest.java +23 -0
- data/src/test/java/bae/NaiveBayesClassifierTest.java +134 -0
- data/target/bae.jar +0 -0
- metadata +107 -0
data/README.md
ADDED
@@ -0,0 +1,54 @@
|
|
1
|
+
Bae
|
2
|
+
===
|
3
|
+
|
4
|
+
Bae is a multinomial naive Bayes classifier based on the ["naivebayes"](https://github.com/id774/naivebayes) gem, except that this one uses Java to do the heavy lifting. Because it calls into Java classes, it requires JRuby.
|
5
|
+
|
6
|
+
## Installation
|
7
|
+
|
8
|
+
Add this line to your application's Gemfile:
|
9
|
+
|
10
|
+
gem 'bae'
|
11
|
+
|
12
|
+
And then execute:
|
13
|
+
|
14
|
+
$ bundle
|
15
|
+
|
16
|
+
Or install it yourself as:
|
17
|
+
|
18
|
+
$ gem install bae
|
19
|
+
|
20
|
+
## Usage
|
21
|
+
|
22
|
+
You can refer to the ["naivebayes"](https://github.com/id774/naivebayes) gem for more documentation, or to the tests for examples. Here is a copy-and-paste example:
|
23
|
+
|
24
|
+
|
25
|
+
### You can provide a frequency hash to the trainer
|
26
|
+
|
27
|
+
```ruby
|
28
|
+
classifier = ::Bae::Classifier.new
|
29
|
+
classifier.train("positive", {"aaa" => 0, "bbb" => 1})
|
30
|
+
classifier.train("negative", {"ccc" => 2, "ddd" => 3})
|
31
|
+
classifier.classify({"aaa" => 1, "bbb" => 1})
|
32
|
+
|
33
|
+
#=> {"positive" => 0.9411764705882353, "negative" => 0.05882352941176469}
|
34
|
+
```
|
35
|
+
|
36
|
+
### Or you can train with strings
|
37
|
+
```ruby
|
38
|
+
classifier = ::Bae::Classifier.new
|
39
|
+
classifier.train("positive", "aaa aaa bbb");
|
40
|
+
classifier.train("negative", "ccc ccc ddd ddd");
|
41
|
+
classifier.train("neutral", "eee eee eee fff fff fff");
|
42
|
+
classifier.classify("aaa bbb")
|
43
|
+
|
44
|
+
#=> {"positive"=>0.8962655601659751, "negative"=>0.0663900414937759, "neutral"=>0.037344398340248955}
|
45
|
+
```
|
46
|
+
|
47
|
+
|
48
|
+
## Contributing
|
49
|
+
|
50
|
+
1. Fork it ( https://github.com/film42/bae/fork )
|
51
|
+
2. Create your feature branch (`git checkout -b my-new-feature`)
|
52
|
+
3. Commit your changes (`git commit -am 'Add some feature'`)
|
53
|
+
4. Push to the branch (`git push origin my-new-feature`)
|
54
|
+
5. Create a new Pull Request
|
data/Rakefile
ADDED
data/bae.gemspec
ADDED
@@ -0,0 +1,24 @@
|
|
1
|
+
# coding: utf-8
lib = File.expand_path('../lib', __FILE__)
$LOAD_PATH.unshift(lib) unless $LOAD_PATH.include?(lib)
require 'bae/version'

# Gem specification for bae, a naive Bayes classifier whose core lives in Java.
#
# NOTE(review): lib/bae/classifier.rb calls ::Java::Bae classes, so this gem
# can only work under JRuby. Consider `spec.platform = "java"` so MRI users
# cannot install it by accident (this changes the packaged gem name; confirm
# before adding).
Gem::Specification.new do |spec|
  spec.name          = "bae"
  spec.version       = Bae::VERSION
  spec.authors       = ["Garrett Thornburg"]
  spec.email         = ["film42@gmail.com"]
  spec.summary       = "Multinomial naive bayes classifier with a kick of java"
  # RubyGems warns when the description merely repeats the summary.
  spec.description   = "Bae is a naive Bayes text classifier for JRuby. " \
                       "Documents can be given as raw strings or as " \
                       "word => count hashes; training and classification " \
                       "are performed by Java classes."
  spec.homepage      = "https://github.com/film42/bae"
  # Dual-licensed GPLv3 / LGPLv3. RubyGems expects SPDX identifiers; the
  # previous free-form string failed license validation.
  spec.licenses      = ["GPL-3.0-only", "LGPL-3.0-only"]

  spec.files         = `git ls-files -z`.split("\x0")
  spec.executables   = spec.files.grep(%r{^bin/}) { |f| File.basename(f) }
  spec.test_files    = spec.files.grep(%r{^(test|spec|features)/})
  spec.require_paths = ["lib"]

  spec.add_development_dependency "bundler", "~> 1.6"
  spec.add_development_dependency "rspec"
  spec.add_development_dependency "rake"
end
|
data/build.xml
ADDED
@@ -0,0 +1,18 @@
|
|
1
|
+
<project>

  <!-- Remove the compiled classes written by the "compile" target.
       (Previously this deleted "build", a directory no target creates.) -->
  <target name="clean">
    <delete dir="out"/>
  </target>

  <!-- Compile the Java sources under src/main/java into out/classes.
       includeantruntime="false" keeps Ant's own jars off the compile
       classpath and silences the corresponding Ant warning. -->
  <target name="compile">
    <mkdir dir="out"/>
    <mkdir dir="out/classes"/>
    <javac srcdir="src/main/java" destdir="out/classes" includeantruntime="false"/>
  </target>

  <!-- Package the compiled classes as target/bae.jar. -->
  <target name="jar" depends="compile">
    <mkdir dir="target"/>
    <jar destfile="target/bae.jar" basedir="out/classes" excludes="**/*.jar,**/MANIFEST.MF,**/BCKEY.*" />
  </target>

</project>
|
@@ -0,0 +1,19 @@
|
|
1
|
+
module Bae
  # Ruby-facing wrapper around the Java classifier (bae.NaiveBayesClassifier).
  # Every public call is delegated to the Java object; this class only wraps
  # the Ruby input in a bae.Document first.
  class Classifier
    # @return the wrapped ::Java::Bae::NaiveBayesClassifier instance
    attr_reader :internal_classifier

    def initialize
      @internal_classifier = ::Java::Bae::NaiveBayesClassifier.new
    end

    # Adds one training example under +label+.
    #
    # @param label [String] class label to train
    # @param feature [String, Hash] raw text or a word => count hash, handed
    #   unchanged to the bae.Document constructor
    # @return the result of the Java +train+ call
    def train(label, feature)
      document = build_document(feature)
      internal_classifier.train(label, document)
    end

    # Scores +feature+ against every trained label.
    #
    # @param feature [String, Hash] same accepted forms as #train
    # @return the label => posterior map produced by the Java classifier
    def classify(feature)
      document = build_document(feature)
      internal_classifier.classify(document)
    end

    private

    # Wraps Ruby input in a Java bae.Document.
    def build_document(feature)
      ::Java::Bae::Document.new(feature)
    end
  end
end
|
data/lib/bae/version.rb
ADDED
data/lib/bae.rb
ADDED
@@ -0,0 +1,27 @@
|
|
1
|
+
require 'spec_helper'
|
2
|
+
|
3
|
+
# End-to-end checks that Ruby input reaches the Java classifier and that the
# posteriors it returns match the values pinned by the Java unit tests.
# RSpec.describe works whether or not RSpec's global DSL monkey patching is
# enabled in spec_helper.
RSpec.describe ::Bae::Classifier do

  subject { described_class.new }

  it "can classify from ruby to java with a hash document" do
    subject.train("positive", {"aaa" => 0, "bbb" => 1})
    subject.train("negative", {"ccc" => 2, "ddd" => 3})
    results = subject.classify({"aaa" => 1, "bbb" => 1})

    expect(results["positive"]).to be_within(0.001).of(0.94117)
    expect(results["negative"]).to be_within(0.001).of(0.05882)
  end

  it "can classify from ruby to java with a string based document" do
    subject.train("positive", "aaa aaa bbb")
    subject.train("negative", "ccc ccc ddd ddd")
    subject.train("neutral", "eee eee eee fff fff fff")
    results = subject.classify("aaa bbb")

    expect(results["positive"]).to be_within(0.001).of(0.89626)
    expect(results["negative"]).to be_within(0.001).of(0.06639)
    expect(results["neutral"]).to be_within(0.001).of(0.03734)
  end

end
|
data/spec/spec_helper.rb
ADDED
@@ -0,0 +1,42 @@
|
|
1
|
+
package bae;
|
2
|
+
|
3
|
+
import java.util.HashMap;
|
4
|
+
import java.util.Map;
|
5
|
+
import java.util.Scanner;
|
6
|
+
|
7
|
+
public class Document {
|
8
|
+
|
9
|
+
private Map<String, Long> frequencyMap;
|
10
|
+
|
11
|
+
public Document(String text) {
|
12
|
+
createFrequencyMap(text);
|
13
|
+
}
|
14
|
+
|
15
|
+
public Document(Map<String, Long> frequencyMap) {
|
16
|
+
this.frequencyMap = frequencyMap;
|
17
|
+
}
|
18
|
+
|
19
|
+
public Map<String, Long> getFrequencyMap() {
|
20
|
+
return frequencyMap;
|
21
|
+
}
|
22
|
+
|
23
|
+
public void addZeroCount(String key) {
|
24
|
+
this.frequencyMap.put(key, 0L);
|
25
|
+
}
|
26
|
+
|
27
|
+
private void createFrequencyMap(String text) {
|
28
|
+
this.frequencyMap = new HashMap<>();
|
29
|
+
|
30
|
+
Scanner parser = new Scanner(text);
|
31
|
+
while(parser.hasNext()) {
|
32
|
+
String wordToken = parser.next();
|
33
|
+
|
34
|
+
// Set initial count if it doesn't have one yet
|
35
|
+
// Use zero because we'll add counts in the next line.
|
36
|
+
this.frequencyMap.putIfAbsent(wordToken, 0L);
|
37
|
+
|
38
|
+
// Update count
|
39
|
+
this.frequencyMap.put(wordToken, this.frequencyMap.get(wordToken) + 1);
|
40
|
+
}
|
41
|
+
}
|
42
|
+
}
|
@@ -0,0 +1,44 @@
|
|
1
|
+
package bae;
|
2
|
+
|
3
|
+
import java.util.HashMap;
|
4
|
+
import java.util.Map;
|
5
|
+
import java.util.Set;
|
6
|
+
|
7
|
+
public class FrequencyTable {

    // label -> (word -> accumulated frequency)
    private Map<String, Map<String, Long>> frequencyTable;

    public FrequencyTable() {
        this.frequencyTable = new HashMap<>();
    }

    /**
     * Ensures a row exists for {@code label}; an existing row is left untouched.
     *
     * @param label class label
     */
    public void insertOrIgnore(String label) {
        // computeIfAbsent only allocates a row when one is actually missing.
        this.frequencyTable.computeIfAbsent(label, key -> new HashMap<>());
    }

    /**
     * Adds {@code frequency} to the count of {@code word} under {@code label},
     * creating the label row and the word entry as needed.
     *
     * @param label     class label
     * @param word      word whose count is increased
     * @param frequency amount to add
     */
    public void increaseFrequencyBy(String label, String word, long frequency) {
        insertOrIgnore(label);

        Map<String, Long> frequencyRow = this.frequencyTable.get(label);

        // Inserts `frequency` for a new word, otherwise adds it to the old count.
        frequencyRow.merge(word, frequency, Long::sum);
    }

    /**
     * @return the labels seen so far (a live view of the table's key set)
     */
    public Set<String> getLabels() {
        return this.frequencyTable.keySet();
    }

    /**
     * Returns the accumulated frequency of {@code word} under {@code label}.
     *
     * @return the count, or 0 when the label or the word is unknown
     */
    public long get(String label, String word) {
        // Explicit lookups instead of catching NullPointerException.
        Map<String, Long> frequencyRow = this.frequencyTable.get(label);
        if (frequencyRow == null) {
            return 0L;
        }
        Long frequency = frequencyRow.get(word);
        return frequency == null ? 0L : frequency;
    }
}
|
@@ -0,0 +1,120 @@
|
|
1
|
+
package bae;
|
2
|
+
|
3
|
+
import java.util.HashMap;
|
4
|
+
import java.util.Map;
|
5
|
+
|
6
|
+
public class NaiveBayesClassifier {
|
7
|
+
|
8
|
+
private FrequencyTable frequencyTable;
|
9
|
+
private Map<String, Long> wordTable;
|
10
|
+
private Map<String, Long> instanceCountOf;
|
11
|
+
private double totalCount = 0;
|
12
|
+
|
13
|
+
public NaiveBayesClassifier() {
|
14
|
+
this.frequencyTable = new FrequencyTable();
|
15
|
+
this.wordTable = new HashMap<>();
|
16
|
+
this.instanceCountOf = new HashMap<>();
|
17
|
+
}
|
18
|
+
|
19
|
+
public void train(String label, Document document) {
|
20
|
+
// Add to the frequency table if it doesn't exist
|
21
|
+
this.frequencyTable.insertOrIgnore(label);
|
22
|
+
|
23
|
+
// Update frequency table with the documents frequency
|
24
|
+
for(Map.Entry<String, Long> entry : document.getFrequencyMap().entrySet()) {
|
25
|
+
String word = entry.getKey();
|
26
|
+
long frequency = entry.getValue();
|
27
|
+
|
28
|
+
// Update counts
|
29
|
+
this.frequencyTable.increaseFrequencyBy(label, word, frequency);
|
30
|
+
// Add the word's presence to the word table
|
31
|
+
this.wordTable.put(word, 1L);
|
32
|
+
}
|
33
|
+
|
34
|
+
// Update global counts
|
35
|
+
totalCount += 1;
|
36
|
+
// Update instance count
|
37
|
+
updateIntegerCountBy(this.instanceCountOf, label, 1);
|
38
|
+
}
|
39
|
+
|
40
|
+
public Map<String, Double> classify(Document document) {
|
41
|
+
Map<String, Double> classPriorOf = new HashMap<>();
|
42
|
+
Map<String, Double> likelihoodOf = new HashMap<>();
|
43
|
+
Map<String, Double> classPosteriorOf = new HashMap<>();
|
44
|
+
Map<String, Long> frequencyMap = document.getFrequencyMap();
|
45
|
+
double evidence = 0;
|
46
|
+
|
47
|
+
// Update the prior
|
48
|
+
for(Map.Entry<String, Long> entry : this.instanceCountOf.entrySet()) {
|
49
|
+
String label = entry.getKey();
|
50
|
+
double frequency = entry.getValue();
|
51
|
+
|
52
|
+
// Update instance count
|
53
|
+
classPriorOf.put(label, (frequency / this.totalCount));
|
54
|
+
}
|
55
|
+
|
56
|
+
// Update likelihood counts
|
57
|
+
for(String label : this.frequencyTable.getLabels()) {
|
58
|
+
// Set initial likelihood
|
59
|
+
likelihoodOf.put(label, 1d);
|
60
|
+
|
61
|
+
// Calculate likelihoods
|
62
|
+
for(String word : wordTable.keySet()) {
|
63
|
+
double laplaceWordLikelihood =
|
64
|
+
(this.frequencyTable.get(label, word) + 1d) /
|
65
|
+
(this.instanceCountOf.get(label) + this.wordTable.size());
|
66
|
+
|
67
|
+
// Update likelihood
|
68
|
+
double likelihood = likelihoodOf.get(label);
|
69
|
+
if(frequencyMap.containsKey(word)) {
|
70
|
+
likelihoodOf.put(label, likelihood * laplaceWordLikelihood);
|
71
|
+
} else {
|
72
|
+
likelihoodOf.put(label, likelihood * (1d - laplaceWordLikelihood));
|
73
|
+
}
|
74
|
+
}
|
75
|
+
|
76
|
+
// Default class posterior of label to 1.0
|
77
|
+
classPosteriorOf.putIfAbsent(label, 1d);
|
78
|
+
|
79
|
+
// Update class posterior
|
80
|
+
double classPosterior = classPriorOf.get(label) * likelihoodOf.get(label);
|
81
|
+
classPosteriorOf.put(label, classPosterior);
|
82
|
+
evidence += classPosterior;
|
83
|
+
}
|
84
|
+
|
85
|
+
// Normalize results
|
86
|
+
for(Map.Entry<String, Double> entry : classPosteriorOf.entrySet()) {
|
87
|
+
String label = entry.getKey();
|
88
|
+
double posterior = entry.getValue();
|
89
|
+
classPosteriorOf.put(label, posterior / evidence);
|
90
|
+
}
|
91
|
+
|
92
|
+
return classPosteriorOf;
|
93
|
+
}
|
94
|
+
|
95
|
+
public void updateIntegerCountBy(Map<String, Long> someMap, String someKey, long count) {
|
96
|
+
someMap.putIfAbsent(someKey, 0L);
|
97
|
+
someMap.put(someKey, someMap.get(someKey) + count);
|
98
|
+
}
|
99
|
+
|
100
|
+
public void updateDoubleCountBy(Map<String, Double> someMap, String someKey, double count) {
|
101
|
+
someMap.putIfAbsent(someKey, 0.0);
|
102
|
+
someMap.put(someKey, someMap.get(someKey) + count);
|
103
|
+
}
|
104
|
+
|
105
|
+
public FrequencyTable getFrequencyTable() {
|
106
|
+
return this.frequencyTable;
|
107
|
+
}
|
108
|
+
|
109
|
+
public Map<String, Long> getWordTable() {
|
110
|
+
return this.wordTable;
|
111
|
+
}
|
112
|
+
|
113
|
+
public Map<String, Long> getInstanceCount() {
|
114
|
+
return this.instanceCountOf;
|
115
|
+
}
|
116
|
+
|
117
|
+
public double getTotalCount() {
|
118
|
+
return totalCount;
|
119
|
+
}
|
120
|
+
}
|
@@ -0,0 +1,34 @@
|
|
1
|
+
package bae;
|
2
|
+
|
3
|
+
import bae.Document;
|
4
|
+
import org.junit.Test;
|
5
|
+
|
6
|
+
import java.util.Map;
|
7
|
+
|
8
|
+
import static org.junit.Assert.*;
|
9
|
+
|
10
|
+
public class DocumentTest {
|
11
|
+
|
12
|
+
@Test
|
13
|
+
public void testCanCreateAccurateFrequencyTable() {
|
14
|
+
Document document = new Document("aaa bbb aaa bbb ccc");
|
15
|
+
Map<String, Long> frequencyMap = document.getFrequencyMap();
|
16
|
+
|
17
|
+
assertEquals(2, (long)frequencyMap.get("aaa"));
|
18
|
+
assertEquals(2, (long)frequencyMap.get("bbb"));
|
19
|
+
assertEquals(1, (long)frequencyMap.get("ccc"));
|
20
|
+
}
|
21
|
+
|
22
|
+
@Test
|
23
|
+
public void testCanParseADirtyString() {
|
24
|
+
Document document = new Document(" a aaa\ta \t\t\t aa a bbb a aaa bbb ccc ");
|
25
|
+
Map<String, Long> frequencyMap = document.getFrequencyMap();
|
26
|
+
|
27
|
+
assertEquals(2, (long)frequencyMap.get("aaa"));
|
28
|
+
assertEquals(2, (long)frequencyMap.get("bbb"));
|
29
|
+
assertEquals(1, (long)frequencyMap.get("ccc"));
|
30
|
+
assertEquals(1, (long)frequencyMap.get("aa"));
|
31
|
+
assertEquals(4, (long)frequencyMap.get("a"));
|
32
|
+
}
|
33
|
+
|
34
|
+
}
|
@@ -0,0 +1,23 @@
|
|
1
|
+
package bae;
|
2
|
+
|
3
|
+
import bae.FrequencyTable;
|
4
|
+
import org.junit.Test;
|
5
|
+
|
6
|
+
import static org.junit.Assert.*;
|
7
|
+
|
8
|
+
public class FrequencyTableTest {
|
9
|
+
|
10
|
+
@Test
|
11
|
+
public void canInsertIntoFrequencyTable() {
|
12
|
+
FrequencyTable frequencyTable = new FrequencyTable();
|
13
|
+
|
14
|
+
frequencyTable.increaseFrequencyBy("a", "b", 100);
|
15
|
+
|
16
|
+
assertEquals(100, frequencyTable.get("a", "b"));
|
17
|
+
|
18
|
+
// Make sure we fail correctly
|
19
|
+
assertEquals(0, frequencyTable.get("a", "z"));
|
20
|
+
assertEquals(0, frequencyTable.get("z", "z"));
|
21
|
+
}
|
22
|
+
|
23
|
+
}
|
@@ -0,0 +1,134 @@
|
|
1
|
+
package bae;
|
2
|
+
|
3
|
+
import bae.Document;
|
4
|
+
import bae.NaiveBayesClassifier;
|
5
|
+
import org.junit.Test;
|
6
|
+
|
7
|
+
import java.util.Map;
|
8
|
+
|
9
|
+
import static org.junit.Assert.*;
|
10
|
+
|
11
|
+
public class NaiveBayesClassifierTest {
|
12
|
+
|
13
|
+
@Test
|
14
|
+
public void canCorrectlyTrainFrequencyTable() {
|
15
|
+
NaiveBayesClassifier n = new NaiveBayesClassifier();
|
16
|
+
|
17
|
+
n.train("positive", new Document("bbb"));
|
18
|
+
n.train("negative", new Document("ccc ccc ddd ddd ddd"));
|
19
|
+
n.train("positive", new Document("aaa bbb bbb"));
|
20
|
+
n.train("negative", new Document("ccc ccc ccc ddd ddd ddd ddd"));
|
21
|
+
|
22
|
+
assertEquals(1, n.getFrequencyTable().get("positive", "aaa"));
|
23
|
+
assertEquals(3, n.getFrequencyTable().get("positive", "bbb"));
|
24
|
+
assertEquals(5, n.getFrequencyTable().get("negative", "ccc"));
|
25
|
+
assertEquals(7, n.getFrequencyTable().get("negative", "ddd"));
|
26
|
+
}
|
27
|
+
|
28
|
+
@Test
|
29
|
+
public void canCorrectlyTrainWordTable() {
|
30
|
+
NaiveBayesClassifier n = new NaiveBayesClassifier();
|
31
|
+
|
32
|
+
n.train("positive", new Document("bbb"));
|
33
|
+
n.train("negative", new Document("ccc ccc ddd ddd ddd"));
|
34
|
+
n.train("positive", new Document("aaa bbb bbb"));
|
35
|
+
n.train("negative", new Document("ccc ccc ccc ddd ddd ddd ddd"));
|
36
|
+
|
37
|
+
assertEquals(1, (long)n.getWordTable().get("aaa"));
|
38
|
+
assertEquals(1, (long)n.getWordTable().get("bbb"));
|
39
|
+
assertEquals(1, (long)n.getWordTable().get("ccc"));
|
40
|
+
assertEquals(1, (long)n.getWordTable().get("ddd"));
|
41
|
+
}
|
42
|
+
|
43
|
+
@Test
|
44
|
+
public void canCorrectlyTrainInstanceCount() {
|
45
|
+
NaiveBayesClassifier n = new NaiveBayesClassifier();
|
46
|
+
|
47
|
+
n.train("positive", new Document("bbb"));
|
48
|
+
n.train("negative", new Document("ccc ccc ddd ddd ddd"));
|
49
|
+
n.train("positive", new Document("aaa bbb bbb"));
|
50
|
+
n.train("negative", new Document("ccc ccc ccc ddd ddd ddd ddd"));
|
51
|
+
n.train("negative", new Document("ccc ccc ccc ddd ddd"));
|
52
|
+
|
53
|
+
assertEquals(2, (long)n.getInstanceCount().get("positive"));
|
54
|
+
assertEquals(3, (long)n.getInstanceCount().get("negative"));
|
55
|
+
assertEquals(5, (long)n.getTotalCount());
|
56
|
+
}
|
57
|
+
|
58
|
+
@Test
|
59
|
+
public void canCorrectlyClassifyPositiveWithTwoLabels() {
|
60
|
+
NaiveBayesClassifier n = new NaiveBayesClassifier();
|
61
|
+
|
62
|
+
Document d = new Document("bbb");
|
63
|
+
d.addZeroCount("aaa");
|
64
|
+
|
65
|
+
n.train("positive", d);
|
66
|
+
n.train("negative", new Document("ccc ccc ddd ddd ddd"));
|
67
|
+
|
68
|
+
Map<String, Double> results = n.classify(new Document("aaa bbb"));
|
69
|
+
|
70
|
+
assertEquals(0.9411764705882353, results.get("positive"), 0.00001);
|
71
|
+
assertEquals(0.05882352941176469, results.get("negative"), 0.00001);
|
72
|
+
}
|
73
|
+
|
74
|
+
@Test
|
75
|
+
public void canCorrectlyClassifyNegativeWithTwoLabels() {
|
76
|
+
NaiveBayesClassifier n = new NaiveBayesClassifier();
|
77
|
+
|
78
|
+
Document d = new Document("bbb");
|
79
|
+
d.addZeroCount("aaa");
|
80
|
+
|
81
|
+
n.train("positive", d);
|
82
|
+
n.train("negative", new Document("ccc ccc ddd ddd ddd"));
|
83
|
+
|
84
|
+
Map<String, Double> results = n.classify(new Document("ccc ccc ccc ddd ddd ddd"));
|
85
|
+
|
86
|
+
assertEquals(0.05882352941176469, results.get("positive"), 0.00001);
|
87
|
+
assertEquals(0.9411764705882353, results.get("negative"), 0.00001);
|
88
|
+
}
|
89
|
+
|
90
|
+
@Test
|
91
|
+
public void canCorrectlyClassifyPositiveWithThreeLabels() {
|
92
|
+
NaiveBayesClassifier n = new NaiveBayesClassifier();
|
93
|
+
|
94
|
+
n.train("positive", new Document("aaa aaa bbb"));
|
95
|
+
n.train("negative", new Document("ccc ccc ddd ddd"));
|
96
|
+
n.train("neutral", new Document("eee eee eee fff fff fff"));
|
97
|
+
|
98
|
+
Map<String, Double> results = n.classify(new Document("aaa bbb"));
|
99
|
+
|
100
|
+
assertEquals(0.896265560165975, results.get("positive"), 0.00001);
|
101
|
+
assertEquals(0.06639004149377592, results.get("negative"), 0.00001);
|
102
|
+
assertEquals(0.03734439834024896, results.get("neutral"), 0.00001);
|
103
|
+
}
|
104
|
+
|
105
|
+
@Test
|
106
|
+
public void canCorrectlyClassifyNegativeWithThreeLabels() {
|
107
|
+
NaiveBayesClassifier n = new NaiveBayesClassifier();
|
108
|
+
|
109
|
+
n.train("positive", new Document("aaa aaa bbb"));
|
110
|
+
n.train("negative", new Document("ccc ccc ddd ddd"));
|
111
|
+
n.train("neutral", new Document("eee eee eee fff fff fff"));
|
112
|
+
|
113
|
+
Map<String, Double> results = n.classify(new Document("ccc ccc ccc ddd ddd"));
|
114
|
+
|
115
|
+
assertEquals(0.05665722379603399, results.get("positive"), 0.00001);
|
116
|
+
assertEquals(0.9178470254957508, results.get("negative"), 0.00001);
|
117
|
+
assertEquals(0.0254957507082153, results.get("neutral"), 0.00001);
|
118
|
+
}
|
119
|
+
|
120
|
+
@Test
|
121
|
+
public void canCorrectlyClassifyNeutralWithThreeLabels() {
|
122
|
+
NaiveBayesClassifier n = new NaiveBayesClassifier();
|
123
|
+
|
124
|
+
n.train("positive", new Document("aaa aaa bbb"));
|
125
|
+
n.train("negative", new Document("ccc ccc ddd ddd"));
|
126
|
+
n.train("neutral", new Document("eee eee eee fff fff fff"));
|
127
|
+
|
128
|
+
Map<String, Double> results = n.classify(new Document("aaa ddd ddd eee eee eee fff"));
|
129
|
+
|
130
|
+
assertEquals(0.12195121951219513, results.get("positive"), 0.00001);
|
131
|
+
assertEquals(0.09756097560975606, results.get("negative"), 0.00001);
|
132
|
+
assertEquals(0.7804878048780488, results.get("neutral"), 0.00001);
|
133
|
+
}
|
134
|
+
}
|
data/target/bae.jar
ADDED
Binary file
|