bae 0.0.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/.gitignore +26 -0
- data/Gemfile +4 -0
- data/LICENSE.txt +855 -0
- data/README.md +54 -0
- data/Rakefile +8 -0
- data/bae.gemspec +24 -0
- data/build.xml +18 -0
- data/lib/bae/classifier.rb +19 -0
- data/lib/bae/version.rb +3 -0
- data/lib/bae.rb +12 -0
- data/spec/lib/bae/classifier_spec.rb +27 -0
- data/spec/spec_helper.rb +8 -0
- data/src/main/java/bae/Document.java +42 -0
- data/src/main/java/bae/FrequencyTable.java +44 -0
- data/src/main/java/bae/NaiveBayesClassifier.java +120 -0
- data/src/test/java/bae/DocumentTest.java +34 -0
- data/src/test/java/bae/FrequencyTableTest.java +23 -0
- data/src/test/java/bae/NaiveBayesClassifierTest.java +134 -0
- data/target/bae.jar +0 -0
- metadata +107 -0
data/README.md
ADDED
@@ -0,0 +1,54 @@
|
|
1
|
+
Bae
|
2
|
+
===
|
3
|
+
|
4
|
+
Bae is a multinomial naive Bayes classifier based on another gem, ["naivebayes"](https://github.com/id774/naivebayes), except that this one uses Java to do the heavy lifting.
|
5
|
+
|
6
|
+
## Installation
|
7
|
+
|
8
|
+
Add this line to your application's Gemfile:
|
9
|
+
|
10
|
+
gem 'bae'
|
11
|
+
|
12
|
+
And then execute:
|
13
|
+
|
14
|
+
$ bundle
|
15
|
+
|
16
|
+
Or install it yourself as:
|
17
|
+
|
18
|
+
$ gem install bae
|
19
|
+
|
20
|
+
## Usage
|
21
|
+
|
22
|
+
You can refer to the ["naivebayes"](https://github.com/id774/naivebayes) gem for more documentation, or to the tests for examples. Here is a copy/paste example:
|
23
|
+
|
24
|
+
|
25
|
+
### You can provide a frequency hash to the trainer
|
26
|
+
|
27
|
+
```ruby
|
28
|
+
classifier = ::Bae::Classifier.new
|
29
|
+
classifier.train("positive", {"aaa" => 0, "bbb" => 1})
|
30
|
+
classifier.train("negative", {"ccc" => 2, "ddd" => 3})
|
31
|
+
classifier.classify({"aaa" => 1, "bbb" => 1})
|
32
|
+
|
33
|
+
#=> {"positive" => 0.9411764705882353, "negative" => 0.05882352941176469}
|
34
|
+
```
|
35
|
+
|
36
|
+
### Or you can train with strings
|
37
|
+
```ruby
|
38
|
+
classifier = ::Bae::Classifier.new
|
39
|
+
classifier.train("positive", "aaa aaa bbb");
|
40
|
+
classifier.train("negative", "ccc ccc ddd ddd");
|
41
|
+
classifier.train("neutral", "eee eee eee fff fff fff");
|
42
|
+
classifier.classify("aaa bbb")
|
43
|
+
|
44
|
+
#=> {"positive"=>0.8962655601659751, "negative"=>0.0663900414937759, "neutral"=>0.037344398340248955}
|
45
|
+
```
|
46
|
+
|
47
|
+
|
48
|
+
## Contributing
|
49
|
+
|
50
|
+
1. Fork it ( https://github.com/film42/bae/fork )
|
51
|
+
2. Create your feature branch (`git checkout -b my-new-feature`)
|
52
|
+
3. Commit your changes (`git commit -am 'Add some feature'`)
|
53
|
+
4. Push to the branch (`git push origin my-new-feature`)
|
54
|
+
5. Create a new Pull Request
|
data/Rakefile
ADDED
data/bae.gemspec
ADDED
@@ -0,0 +1,24 @@
|
|
1
|
+
# coding: utf-8
|
2
|
+
lib = File.expand_path('../lib', __FILE__)
|
3
|
+
$LOAD_PATH.unshift(lib) unless $LOAD_PATH.include?(lib)
|
4
|
+
require 'bae/version'
|
5
|
+
|
6
|
+
Gem::Specification.new do |spec|
|
7
|
+
spec.name = "bae"
|
8
|
+
spec.version = Bae::VERSION
|
9
|
+
spec.authors = ["Garrett Thornburg"]
|
10
|
+
spec.email = ["film42@gmail.com"]
|
11
|
+
spec.summary = "Multinomial naive bayes classifier with a kick of java"
|
12
|
+
spec.description = "Multinomial naive bayes classifier with a kick of java"
|
13
|
+
spec.homepage = "https://github.com/film42/bae"
|
14
|
+
spec.license = "GPL version 3, or LGPL version 3 (Dual License)"
|
15
|
+
|
16
|
+
spec.files = `git ls-files -z`.split("\x0")
|
17
|
+
spec.executables = spec.files.grep(%r{^bin/}) { |f| File.basename(f) }
|
18
|
+
spec.test_files = spec.files.grep(%r{^(test|spec|features)/})
|
19
|
+
spec.require_paths = ["lib"]
|
20
|
+
|
21
|
+
spec.add_development_dependency "bundler", "~> 1.6"
|
22
|
+
spec.add_development_dependency "rspec"
|
23
|
+
spec.add_development_dependency "rake"
|
24
|
+
end
|
data/build.xml
ADDED
@@ -0,0 +1,18 @@
|
|
1
|
+
<project>
|
2
|
+
|
3
|
+
<target name="clean">
|
4
|
+
<delete dir="build"/>
|
5
|
+
</target>
|
6
|
+
|
7
|
+
<target name="compile">
|
8
|
+
<mkdir dir="out"/>
|
9
|
+
<mkdir dir="out/classes"/>
|
10
|
+
<javac srcdir="src/main/java" destdir="out/classes"/>
|
11
|
+
</target>
|
12
|
+
|
13
|
+
<target name="jar" depends="compile">
|
14
|
+
<mkdir dir="target"/>
|
15
|
+
<jar destfile="target/bae.jar" basedir="out/classes" excludes="**/*.jar,**/MANIFEST.MF,**/BCKEY.*" />
|
16
|
+
</target>
|
17
|
+
|
18
|
+
</project>
|
@@ -0,0 +1,19 @@
|
|
1
|
+
module Bae
|
2
|
+
class Classifier
|
3
|
+
|
4
|
+
attr_reader :internal_classifier
|
5
|
+
|
6
|
+
def initialize
|
7
|
+
@internal_classifier = ::Java::Bae::NaiveBayesClassifier.new
|
8
|
+
end
|
9
|
+
|
10
|
+
def train(label, feature)
|
11
|
+
internal_classifier.train(label, ::Java::Bae::Document.new(feature))
|
12
|
+
end
|
13
|
+
|
14
|
+
def classify(feature)
|
15
|
+
internal_classifier.classify(::Java::Bae::Document.new(feature))
|
16
|
+
end
|
17
|
+
|
18
|
+
end
|
19
|
+
end
|
data/lib/bae/version.rb
ADDED
data/lib/bae.rb
ADDED
@@ -0,0 +1,27 @@
|
|
1
|
+
require 'spec_helper'
|
2
|
+
|
3
|
+
describe ::Bae::Classifier do
|
4
|
+
|
5
|
+
subject { described_class.new }
|
6
|
+
|
7
|
+
it "can classify from ruby to java with a hash document" do
|
8
|
+
subject.train("positive", {"aaa" => 0, "bbb" => 1})
|
9
|
+
subject.train("negative", {"ccc" => 2, "ddd" => 3})
|
10
|
+
results = subject.classify({"aaa" => 1, "bbb" => 1})
|
11
|
+
|
12
|
+
expect(results["positive"]).to be_within(0.001).of(0.94117)
|
13
|
+
expect(results["negative"]).to be_within(0.001).of(0.05882)
|
14
|
+
end
|
15
|
+
|
16
|
+
it "can classify from ruby to java with a string based document" do
|
17
|
+
subject.train("positive", "aaa aaa bbb");
|
18
|
+
subject.train("negative", "ccc ccc ddd ddd");
|
19
|
+
subject.train("neutral", "eee eee eee fff fff fff");
|
20
|
+
results = subject.classify("aaa bbb")
|
21
|
+
|
22
|
+
expect(results["positive"]).to be_within(0.001).of(0.89626)
|
23
|
+
expect(results["negative"]).to be_within(0.001).of(0.06639)
|
24
|
+
expect(results["neutral"]).to be_within(0.001).of(0.03734)
|
25
|
+
end
|
26
|
+
|
27
|
+
end
|
data/spec/spec_helper.rb
ADDED
@@ -0,0 +1,42 @@
|
|
1
|
+
package bae;
|
2
|
+
|
3
|
+
import java.util.HashMap;
|
4
|
+
import java.util.Map;
|
5
|
+
import java.util.Scanner;
|
6
|
+
|
7
|
+
public class Document {
|
8
|
+
|
9
|
+
private Map<String, Long> frequencyMap;
|
10
|
+
|
11
|
+
public Document(String text) {
|
12
|
+
createFrequencyMap(text);
|
13
|
+
}
|
14
|
+
|
15
|
+
public Document(Map<String, Long> frequencyMap) {
|
16
|
+
this.frequencyMap = frequencyMap;
|
17
|
+
}
|
18
|
+
|
19
|
+
public Map<String, Long> getFrequencyMap() {
|
20
|
+
return frequencyMap;
|
21
|
+
}
|
22
|
+
|
23
|
+
public void addZeroCount(String key) {
|
24
|
+
this.frequencyMap.put(key, 0L);
|
25
|
+
}
|
26
|
+
|
27
|
+
private void createFrequencyMap(String text) {
|
28
|
+
this.frequencyMap = new HashMap<>();
|
29
|
+
|
30
|
+
Scanner parser = new Scanner(text);
|
31
|
+
while(parser.hasNext()) {
|
32
|
+
String wordToken = parser.next();
|
33
|
+
|
34
|
+
// Set initial count if it doesn't have one yet
|
35
|
+
// Use zero because we'll add counts in the next line.
|
36
|
+
this.frequencyMap.putIfAbsent(wordToken, 0L);
|
37
|
+
|
38
|
+
// Update count
|
39
|
+
this.frequencyMap.put(wordToken, this.frequencyMap.get(wordToken) + 1);
|
40
|
+
}
|
41
|
+
}
|
42
|
+
}
|
@@ -0,0 +1,44 @@
|
|
1
|
+
package bae;
|
2
|
+
|
3
|
+
import java.util.HashMap;
|
4
|
+
import java.util.Map;
|
5
|
+
import java.util.Set;
|
6
|
+
|
7
|
+
public class FrequencyTable {
|
8
|
+
|
9
|
+
private Map<String, Map<String, Long>> frequencyTable;
|
10
|
+
|
11
|
+
public FrequencyTable() {
|
12
|
+
this.frequencyTable = new HashMap<>();
|
13
|
+
}
|
14
|
+
|
15
|
+
public void insertOrIgnore(String label) {
|
16
|
+
// Add new hash to frequency table if it's not already there
|
17
|
+
this.frequencyTable.putIfAbsent(label, new HashMap<String, Long>());
|
18
|
+
}
|
19
|
+
|
20
|
+
public void increaseFrequencyBy(String label, String word, long frequency) {
|
21
|
+
// Add label if it doesn't exist
|
22
|
+
insertOrIgnore(label);
|
23
|
+
|
24
|
+
Map<String, Long> frequencyRow = this.frequencyTable.get(label);
|
25
|
+
|
26
|
+
// Make sure we have a frequency for that position in the table
|
27
|
+
frequencyRow.putIfAbsent(word, 0L);
|
28
|
+
|
29
|
+
// Update frequency
|
30
|
+
frequencyRow.put(word, frequencyRow.get(word) + frequency);
|
31
|
+
}
|
32
|
+
|
33
|
+
public Set<String> getLabels() {
|
34
|
+
return this.frequencyTable.keySet();
|
35
|
+
}
|
36
|
+
|
37
|
+
public long get(String label, String word) {
|
38
|
+
try {
|
39
|
+
return this.frequencyTable.get(label).get(word);
|
40
|
+
} catch (NullPointerException e) {
|
41
|
+
return 0L;
|
42
|
+
}
|
43
|
+
}
|
44
|
+
}
|
@@ -0,0 +1,120 @@
|
|
1
|
+
package bae;
|
2
|
+
|
3
|
+
import java.util.HashMap;
|
4
|
+
import java.util.Map;
|
5
|
+
|
6
|
+
public class NaiveBayesClassifier {
|
7
|
+
|
8
|
+
private FrequencyTable frequencyTable;
|
9
|
+
private Map<String, Long> wordTable;
|
10
|
+
private Map<String, Long> instanceCountOf;
|
11
|
+
private double totalCount = 0;
|
12
|
+
|
13
|
+
public NaiveBayesClassifier() {
|
14
|
+
this.frequencyTable = new FrequencyTable();
|
15
|
+
this.wordTable = new HashMap<>();
|
16
|
+
this.instanceCountOf = new HashMap<>();
|
17
|
+
}
|
18
|
+
|
19
|
+
public void train(String label, Document document) {
|
20
|
+
// Add to the frequency table if it doesn't exist
|
21
|
+
this.frequencyTable.insertOrIgnore(label);
|
22
|
+
|
23
|
+
// Update the frequency table with the document's word frequencies
|
24
|
+
for(Map.Entry<String, Long> entry : document.getFrequencyMap().entrySet()) {
|
25
|
+
String word = entry.getKey();
|
26
|
+
long frequency = entry.getValue();
|
27
|
+
|
28
|
+
// Update counts
|
29
|
+
this.frequencyTable.increaseFrequencyBy(label, word, frequency);
|
30
|
+
// Add the word's presence to the word table
|
31
|
+
this.wordTable.put(word, 1L);
|
32
|
+
}
|
33
|
+
|
34
|
+
// Update global counts
|
35
|
+
totalCount += 1;
|
36
|
+
// Update instance count
|
37
|
+
updateIntegerCountBy(this.instanceCountOf, label, 1);
|
38
|
+
}
|
39
|
+
|
40
|
+
public Map<String, Double> classify(Document document) {
|
41
|
+
Map<String, Double> classPriorOf = new HashMap<>();
|
42
|
+
Map<String, Double> likelihoodOf = new HashMap<>();
|
43
|
+
Map<String, Double> classPosteriorOf = new HashMap<>();
|
44
|
+
Map<String, Long> frequencyMap = document.getFrequencyMap();
|
45
|
+
double evidence = 0;
|
46
|
+
|
47
|
+
// Update the prior
|
48
|
+
for(Map.Entry<String, Long> entry : this.instanceCountOf.entrySet()) {
|
49
|
+
String label = entry.getKey();
|
50
|
+
double frequency = entry.getValue();
|
51
|
+
|
52
|
+
// Prior for this label: its training-document count divided by the total count
|
53
|
+
classPriorOf.put(label, (frequency / this.totalCount));
|
54
|
+
}
|
55
|
+
|
56
|
+
// Update likelihood counts
|
57
|
+
for(String label : this.frequencyTable.getLabels()) {
|
58
|
+
// Set initial likelihood
|
59
|
+
likelihoodOf.put(label, 1d);
|
60
|
+
|
61
|
+
// Calculate likelihoods
|
62
|
+
for(String word : wordTable.keySet()) {
|
63
|
+
double laplaceWordLikelihood =
|
64
|
+
(this.frequencyTable.get(label, word) + 1d) /
|
65
|
+
(this.instanceCountOf.get(label) + this.wordTable.size());
|
66
|
+
|
67
|
+
// Update likelihood
|
68
|
+
double likelihood = likelihoodOf.get(label);
|
69
|
+
if(frequencyMap.containsKey(word)) {
|
70
|
+
likelihoodOf.put(label, likelihood * laplaceWordLikelihood);
|
71
|
+
} else {
|
72
|
+
likelihoodOf.put(label, likelihood * (1d - laplaceWordLikelihood));
|
73
|
+
}
|
74
|
+
}
|
75
|
+
|
76
|
+
// Default class posterior of label to 1.0
|
77
|
+
classPosteriorOf.putIfAbsent(label, 1d);
|
78
|
+
|
79
|
+
// Update class posterior
|
80
|
+
double classPosterior = classPriorOf.get(label) * likelihoodOf.get(label);
|
81
|
+
classPosteriorOf.put(label, classPosterior);
|
82
|
+
evidence += classPosterior;
|
83
|
+
}
|
84
|
+
|
85
|
+
// Normalize results
|
86
|
+
for(Map.Entry<String, Double> entry : classPosteriorOf.entrySet()) {
|
87
|
+
String label = entry.getKey();
|
88
|
+
double posterior = entry.getValue();
|
89
|
+
classPosteriorOf.put(label, posterior / evidence);
|
90
|
+
}
|
91
|
+
|
92
|
+
return classPosteriorOf;
|
93
|
+
}
|
94
|
+
|
95
|
+
public void updateIntegerCountBy(Map<String, Long> someMap, String someKey, long count) {
|
96
|
+
someMap.putIfAbsent(someKey, 0L);
|
97
|
+
someMap.put(someKey, someMap.get(someKey) + count);
|
98
|
+
}
|
99
|
+
|
100
|
+
public void updateDoubleCountBy(Map<String, Double> someMap, String someKey, double count) {
|
101
|
+
someMap.putIfAbsent(someKey, 0.0);
|
102
|
+
someMap.put(someKey, someMap.get(someKey) + count);
|
103
|
+
}
|
104
|
+
|
105
|
+
public FrequencyTable getFrequencyTable() {
|
106
|
+
return this.frequencyTable;
|
107
|
+
}
|
108
|
+
|
109
|
+
public Map<String, Long> getWordTable() {
|
110
|
+
return this.wordTable;
|
111
|
+
}
|
112
|
+
|
113
|
+
public Map<String, Long> getInstanceCount() {
|
114
|
+
return this.instanceCountOf;
|
115
|
+
}
|
116
|
+
|
117
|
+
public double getTotalCount() {
|
118
|
+
return totalCount;
|
119
|
+
}
|
120
|
+
}
|
@@ -0,0 +1,34 @@
|
|
1
|
+
package bae;
|
2
|
+
|
3
|
+
import bae.Document;
|
4
|
+
import org.junit.Test;
|
5
|
+
|
6
|
+
import java.util.Map;
|
7
|
+
|
8
|
+
import static org.junit.Assert.*;
|
9
|
+
|
10
|
+
public class DocumentTest {
|
11
|
+
|
12
|
+
@Test
|
13
|
+
public void testCanCreateAccurateFrequencyTable() {
|
14
|
+
Document document = new Document("aaa bbb aaa bbb ccc");
|
15
|
+
Map<String, Long> frequencyMap = document.getFrequencyMap();
|
16
|
+
|
17
|
+
assertEquals(2, (long)frequencyMap.get("aaa"));
|
18
|
+
assertEquals(2, (long)frequencyMap.get("bbb"));
|
19
|
+
assertEquals(1, (long)frequencyMap.get("ccc"));
|
20
|
+
}
|
21
|
+
|
22
|
+
@Test
|
23
|
+
public void testCanParseADirtyString() {
|
24
|
+
Document document = new Document(" a aaa\ta \t\t\t aa a bbb a aaa bbb ccc ");
|
25
|
+
Map<String, Long> frequencyMap = document.getFrequencyMap();
|
26
|
+
|
27
|
+
assertEquals(2, (long)frequencyMap.get("aaa"));
|
28
|
+
assertEquals(2, (long)frequencyMap.get("bbb"));
|
29
|
+
assertEquals(1, (long)frequencyMap.get("ccc"));
|
30
|
+
assertEquals(1, (long)frequencyMap.get("aa"));
|
31
|
+
assertEquals(4, (long)frequencyMap.get("a"));
|
32
|
+
}
|
33
|
+
|
34
|
+
}
|
@@ -0,0 +1,23 @@
|
|
1
|
+
package bae;
|
2
|
+
|
3
|
+
import bae.FrequencyTable;
|
4
|
+
import org.junit.Test;
|
5
|
+
|
6
|
+
import static org.junit.Assert.*;
|
7
|
+
|
8
|
+
public class FrequencyTableTest {
|
9
|
+
|
10
|
+
@Test
|
11
|
+
public void canInsertIntoFrequencyTable() {
|
12
|
+
FrequencyTable frequencyTable = new FrequencyTable();
|
13
|
+
|
14
|
+
frequencyTable.increaseFrequencyBy("a", "b", 100);
|
15
|
+
|
16
|
+
assertEquals(100, frequencyTable.get("a", "b"));
|
17
|
+
|
18
|
+
// Make sure we fail correctly
|
19
|
+
assertEquals(0, frequencyTable.get("a", "z"));
|
20
|
+
assertEquals(0, frequencyTable.get("z", "z"));
|
21
|
+
}
|
22
|
+
|
23
|
+
}
|
@@ -0,0 +1,134 @@
|
|
1
|
+
package bae;
|
2
|
+
|
3
|
+
import bae.Document;
|
4
|
+
import bae.NaiveBayesClassifier;
|
5
|
+
import org.junit.Test;
|
6
|
+
|
7
|
+
import java.util.Map;
|
8
|
+
|
9
|
+
import static org.junit.Assert.*;
|
10
|
+
|
11
|
+
public class NaiveBayesClassifierTest {
|
12
|
+
|
13
|
+
@Test
|
14
|
+
public void canCorrectlyTrainFrequencyTable() {
|
15
|
+
NaiveBayesClassifier n = new NaiveBayesClassifier();
|
16
|
+
|
17
|
+
n.train("positive", new Document("bbb"));
|
18
|
+
n.train("negative", new Document("ccc ccc ddd ddd ddd"));
|
19
|
+
n.train("positive", new Document("aaa bbb bbb"));
|
20
|
+
n.train("negative", new Document("ccc ccc ccc ddd ddd ddd ddd"));
|
21
|
+
|
22
|
+
assertEquals(1, n.getFrequencyTable().get("positive", "aaa"));
|
23
|
+
assertEquals(3, n.getFrequencyTable().get("positive", "bbb"));
|
24
|
+
assertEquals(5, n.getFrequencyTable().get("negative", "ccc"));
|
25
|
+
assertEquals(7, n.getFrequencyTable().get("negative", "ddd"));
|
26
|
+
}
|
27
|
+
|
28
|
+
@Test
|
29
|
+
public void canCorrectlyTrainWordTable() {
|
30
|
+
NaiveBayesClassifier n = new NaiveBayesClassifier();
|
31
|
+
|
32
|
+
n.train("positive", new Document("bbb"));
|
33
|
+
n.train("negative", new Document("ccc ccc ddd ddd ddd"));
|
34
|
+
n.train("positive", new Document("aaa bbb bbb"));
|
35
|
+
n.train("negative", new Document("ccc ccc ccc ddd ddd ddd ddd"));
|
36
|
+
|
37
|
+
assertEquals(1, (long)n.getWordTable().get("aaa"));
|
38
|
+
assertEquals(1, (long)n.getWordTable().get("bbb"));
|
39
|
+
assertEquals(1, (long)n.getWordTable().get("ccc"));
|
40
|
+
assertEquals(1, (long)n.getWordTable().get("ddd"));
|
41
|
+
}
|
42
|
+
|
43
|
+
@Test
|
44
|
+
public void canCorrectlyTrainInstanceCount() {
|
45
|
+
NaiveBayesClassifier n = new NaiveBayesClassifier();
|
46
|
+
|
47
|
+
n.train("positive", new Document("bbb"));
|
48
|
+
n.train("negative", new Document("ccc ccc ddd ddd ddd"));
|
49
|
+
n.train("positive", new Document("aaa bbb bbb"));
|
50
|
+
n.train("negative", new Document("ccc ccc ccc ddd ddd ddd ddd"));
|
51
|
+
n.train("negative", new Document("ccc ccc ccc ddd ddd"));
|
52
|
+
|
53
|
+
assertEquals(2, (long)n.getInstanceCount().get("positive"));
|
54
|
+
assertEquals(3, (long)n.getInstanceCount().get("negative"));
|
55
|
+
assertEquals(5, (long)n.getTotalCount());
|
56
|
+
}
|
57
|
+
|
58
|
+
@Test
|
59
|
+
public void canCorrectlyClassifyPositiveWithTwoLabels() {
|
60
|
+
NaiveBayesClassifier n = new NaiveBayesClassifier();
|
61
|
+
|
62
|
+
Document d = new Document("bbb");
|
63
|
+
d.addZeroCount("aaa");
|
64
|
+
|
65
|
+
n.train("positive", d);
|
66
|
+
n.train("negative", new Document("ccc ccc ddd ddd ddd"));
|
67
|
+
|
68
|
+
Map<String, Double> results = n.classify(new Document("aaa bbb"));
|
69
|
+
|
70
|
+
assertEquals(0.9411764705882353, results.get("positive"), 0.00001);
|
71
|
+
assertEquals(0.05882352941176469, results.get("negative"), 0.00001);
|
72
|
+
}
|
73
|
+
|
74
|
+
@Test
|
75
|
+
public void canCorrectlyClassifyNegativeWithTwoLabels() {
|
76
|
+
NaiveBayesClassifier n = new NaiveBayesClassifier();
|
77
|
+
|
78
|
+
Document d = new Document("bbb");
|
79
|
+
d.addZeroCount("aaa");
|
80
|
+
|
81
|
+
n.train("positive", d);
|
82
|
+
n.train("negative", new Document("ccc ccc ddd ddd ddd"));
|
83
|
+
|
84
|
+
Map<String, Double> results = n.classify(new Document("ccc ccc ccc ddd ddd ddd"));
|
85
|
+
|
86
|
+
assertEquals(0.05882352941176469, results.get("positive"), 0.00001);
|
87
|
+
assertEquals(0.9411764705882353, results.get("negative"), 0.00001);
|
88
|
+
}
|
89
|
+
|
90
|
+
@Test
|
91
|
+
public void canCorrectlyClassifyPositiveWithThreeLabels() {
|
92
|
+
NaiveBayesClassifier n = new NaiveBayesClassifier();
|
93
|
+
|
94
|
+
n.train("positive", new Document("aaa aaa bbb"));
|
95
|
+
n.train("negative", new Document("ccc ccc ddd ddd"));
|
96
|
+
n.train("neutral", new Document("eee eee eee fff fff fff"));
|
97
|
+
|
98
|
+
Map<String, Double> results = n.classify(new Document("aaa bbb"));
|
99
|
+
|
100
|
+
assertEquals(0.896265560165975, results.get("positive"), 0.00001);
|
101
|
+
assertEquals(0.06639004149377592, results.get("negative"), 0.00001);
|
102
|
+
assertEquals(0.03734439834024896, results.get("neutral"), 0.00001);
|
103
|
+
}
|
104
|
+
|
105
|
+
@Test
|
106
|
+
public void canCorrectlyClassifyNegativeWithThreeLabels() {
|
107
|
+
NaiveBayesClassifier n = new NaiveBayesClassifier();
|
108
|
+
|
109
|
+
n.train("positive", new Document("aaa aaa bbb"));
|
110
|
+
n.train("negative", new Document("ccc ccc ddd ddd"));
|
111
|
+
n.train("neutral", new Document("eee eee eee fff fff fff"));
|
112
|
+
|
113
|
+
Map<String, Double> results = n.classify(new Document("ccc ccc ccc ddd ddd"));
|
114
|
+
|
115
|
+
assertEquals(0.05665722379603399, results.get("positive"), 0.00001);
|
116
|
+
assertEquals(0.9178470254957508, results.get("negative"), 0.00001);
|
117
|
+
assertEquals(0.0254957507082153, results.get("neutral"), 0.00001);
|
118
|
+
}
|
119
|
+
|
120
|
+
@Test
|
121
|
+
public void canCorrectlyClassifyNeutralWithThreeLabels() {
|
122
|
+
NaiveBayesClassifier n = new NaiveBayesClassifier();
|
123
|
+
|
124
|
+
n.train("positive", new Document("aaa aaa bbb"));
|
125
|
+
n.train("negative", new Document("ccc ccc ddd ddd"));
|
126
|
+
n.train("neutral", new Document("eee eee eee fff fff fff"));
|
127
|
+
|
128
|
+
Map<String, Double> results = n.classify(new Document("aaa ddd ddd eee eee eee fff"));
|
129
|
+
|
130
|
+
assertEquals(0.12195121951219513, results.get("positive"), 0.00001);
|
131
|
+
assertEquals(0.09756097560975606, results.get("negative"), 0.00001);
|
132
|
+
assertEquals(0.7804878048780488, results.get("neutral"), 0.00001);
|
133
|
+
}
|
134
|
+
}
|
data/target/bae.jar
ADDED
Binary file
|