bae 0.0.1 → 0.0.9
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/.gitignore +2 -1
- data/README.md +43 -1
- data/build.xml +3 -3
- data/lib/bae/classifier.rb +167 -6
- data/lib/bae/native_classifier.rb +26 -0
- data/lib/bae/version.rb +1 -1
- data/lib/bae.rb +2 -7
- data/spec/lib/bae/classifier_spec.rb +57 -2
- data/spec/lib/bae/native_classifier_spec.rb +33 -0
- data/spec/spec_helper.rb +1 -0
- data/src/main/java/bae/Document.java +3 -1
- data/src/main/java/bae/FrequencyTable.java +6 -2
- data/src/main/java/bae/NaiveBayesClassifier.java +43 -19
- data/target/bae.jar +0 -0
- metadata +5 -2
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA1:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: c28a60c92163259beddf8af99cd31357cf3a95a8
|
4
|
+
data.tar.gz: 34cc3ee332ec6f79d74e2ab04f8d60227c49c1a9
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: 4f92cc52a40438b18bf543299345b1a9ae57443b53e8b8cae7169181a436dfc2e4c86627527239c6c0c29d99f0ff723c4f9272a994f4595d58c1f6d5c8acbd5c
|
7
|
+
data.tar.gz: dfca0d36849a088fdc5c60cb62c9256a7743c329d8acbbdba87cac34f0fb3de1195909598474283a12862a5a1b92688258189b47945330dcf8cc6df563b77e19
|
data/.gitignore
CHANGED
data/README.md
CHANGED
@@ -3,6 +3,15 @@ Bae
|
|
3
3
|
|
4
4
|
Bae is a multinomial naive Bayes classifier based on another gem, ["naivebayes"](https://github.com/id774/naivebayes), with an optional Java implementation to do the heavy lifting.
|
5
5
|
|
6
|
+
By default Bae uses the pure Ruby implementation (`Bae::Classifier`), but you can also use the native classifier written in Java (`Bae::NativeClassifier`), which requires JRuby.
|
7
|
+
|
8
|
+
```ruby
|
9
|
+
require 'bae/native_classifier'
|
10
|
+
|
11
|
+
classifier = ::Bae::NativeClassifier.new
|
12
|
+
```
|
13
|
+
|
14
|
+
|
6
15
|
## Installation
|
7
16
|
|
8
17
|
Add this line to your application's Gemfile:
|
@@ -28,6 +37,9 @@ You can refer to ["naivebayes"](https://github.com/id774/naivebayes) gem for mor
|
|
28
37
|
classifier = ::Bae::Classifier.new
|
29
38
|
classifier.train("positive", {"aaa" => 0, "bbb" => 1})
|
30
39
|
classifier.train("negative", {"ccc" => 2, "ddd" => 3})
|
40
|
+
|
41
|
+
classifier.finish_training!
|
42
|
+
|
31
43
|
classifier.classify({"aaa" => 1, "bbb" => 1})
|
32
44
|
|
33
45
|
#=> {"positive" => 0.8767123287671234, "negative" => 0.12328767123287669}
|
@@ -39,15 +51,45 @@ classifier = ::Bae::Classifier.new
|
|
39
51
|
classifier.train("positive", "aaa aaa bbb");
|
40
52
|
classifier.train("negative", "ccc ccc ddd ddd");
|
41
53
|
classifier.train("neutral", "eee eee eee fff fff fff");
|
54
|
+
|
55
|
+
classifier.finish_training!
|
56
|
+
|
42
57
|
classifier.classify("aaa bbb")
|
43
58
|
|
44
59
|
#=> {"positive"=>0.8962655601659751, "negative"=>0.0663900414937759, "neutral"=>0.037344398340248955}
|
45
60
|
```
|
46
61
|
|
62
|
+
### Saving State
|
63
|
+
|
64
|
+
You can save a snapshot of the trained classifier to disk and later load it back into memory.
|
65
|
+
|
66
|
+
```ruby
|
67
|
+
# From the example above...
|
68
|
+
classifier = ::Bae::Classifier.new
|
69
|
+
classifier.train("positive", {"aaa" => 0, "bbb" => 1})
|
70
|
+
classifier.train("negative", {"ccc" => 2, "ddd" => 3})
|
71
|
+
|
72
|
+
classifier.finish_training!
|
73
|
+
|
74
|
+
classifier.classify({"aaa" => 1, "bbb" => 1})
|
75
|
+
#=> {"positive" => 0.8767123287671234, "negative" => 0.12328767123287669}
|
76
|
+
|
77
|
+
# Now let's save it to disk
|
78
|
+
classifier.save_state("/tmp/some_state.json")
|
79
|
+
|
80
|
+
# Let's create a new classifier and load the state we just saved
|
81
|
+
classifier = ::Bae::Classifier.new
|
82
|
+
classifier.load_state("/tmp/some_state.json")
|
83
|
+
|
84
|
+
# Now we can classify without retraining
|
85
|
+
classifier.classify({"aaa" => 1, "bbb" => 1})
|
86
|
+
#=> {"positive" => 0.8767123287671234, "negative" => 0.12328767123287669}
|
87
|
+
```
|
88
|
+
|
47
89
|
|
48
90
|
## Contributing
|
49
91
|
|
50
|
-
1. Fork it ( https://github.com/
|
92
|
+
1. Fork it ( https://github.com/film42/bae/fork )
|
51
93
|
2. Create your feature branch (`git checkout -b my-new-feature`)
|
52
94
|
3. Commit your changes (`git commit -am 'Add some feature'`)
|
53
95
|
4. Push to the branch (`git push origin my-new-feature`)
|
data/build.xml
CHANGED
@@ -1,13 +1,13 @@
|
|
1
1
|
<project>
|
2
2
|
|
3
3
|
<target name="clean">
|
4
|
-
<delete dir="
|
4
|
+
<delete dir="out/classes"/>
|
5
5
|
</target>
|
6
6
|
|
7
|
-
<target name="compile">
|
7
|
+
<target name="compile" depends="clean">
|
8
8
|
<mkdir dir="out"/>
|
9
9
|
<mkdir dir="out/classes"/>
|
10
|
-
<javac srcdir="src/main/java" destdir="out/classes"/>
|
10
|
+
<javac srcdir="src/main/java" destdir="out/classes" source="1.7" target="1.7" includeantruntime="false" />
|
11
11
|
</target>
|
12
12
|
|
13
13
|
<target name="jar" depends="compile">
|
data/lib/bae/classifier.rb
CHANGED
@@ -1,19 +1,180 @@
|
|
1
1
|
require 'json'

module Bae
  # Pure-Ruby multinomial naive Bayes classifier.
  #
  # Workflow: call #train once per labelled document, call #finish_training!
  # to precompute per-label likelihoods and priors, then call #classify.
  # A trained model can be persisted with #save_state and restored with
  # #load_state.
  class Classifier
    # Keys written by #save_state and required by #load_state, in this order.
    STATE_KEYS = %w[
      frequency_table
      label_instance_count
      label_index
      label_index_sequence
      total_terms
    ].freeze

    attr_accessor :frequency_table, :label_index, :label_index_sequence,
                  :label_instance_count, :total_terms

    def initialize
      # word => Array of per-label counts; column positions come from label_index.
      @frequency_table = ::Hash.new
      # label => number of #train calls made for that label.
      @label_instance_count = new_count_hash
      # label => column position in every frequency_table row. Deliberately a
      # plain Hash: a defaulting Hash would silently map unknown labels to
      # column 0 (and insert them), corrupting the index.
      @label_index = ::Hash.new
      @label_index_sequence = -1 # start at -1 so 0 is first value
      # Number of #train calls across all labels (documents, not words).
      @total_terms = 0.0
    end

    # Precomputes likelihoods and priors from the current training data.
    # Call it again after any further #train calls.
    def finish_training!
      calculate_likelihoods!
      calculate_priors!
    end

    # Adds one training document under +label+.
    #
    # @param label [String]
    # @param training_data [String, Hash] whitespace-separated text, or a
    #   word => frequency Hash
    # @raise [RuntimeError] for any other type of +training_data+
    def train(label, training_data)
      case training_data
      when ::String then train_from_string(label, training_data)
      when ::Hash then train_from_hash(label, training_data)
      else fail 'Training data must either be a string or hash'
      end
    end

    # Trains on a whitespace-separated document; every occurrence counts once.
    def train_from_string(label, document)
      document.split.each do |word|
        update_label_index(label)
        update_frequency_table(label, word, 1)
      end
      record_training_instance(label)
    end

    # Trains on a word => frequency Hash.
    def train_from_hash(label, frequency_hash)
      frequency_hash.each do |word, frequency|
        update_label_index(label)
        update_frequency_table(label, word, frequency)
      end
      record_training_instance(label)
    end

    # Returns a Hash of label => normalized posterior probability.
    #
    # @param data [String, Hash] same shapes accepted by #train
    # @raise [RuntimeError] for any other type of +data+ (the message says
    #   "Training data" for compatibility with existing callers/specs)
    def classify(data)
      case data
      when ::String then classify_from_string(data)
      when ::Hash then classify_from_hash(data)
      else fail 'Training data must either be a string or hash'
      end
    end

    # Expands a word => frequency Hash into a document string and classifies it.
    def classify_from_hash(frequency_hash)
      document = frequency_hash.map { |word, frequency| (word + ' ') * frequency }.join

      classify_from_string(document)
    end

    # Classifies a whitespace-separated document. Only distinct words matter;
    # words never seen during training are ignored.
    #
    # @raise [RuntimeError] if #finish_training! has not been called yet
    def classify_from_string(document)
      unless @likelihoods && @priors
        fail 'finish_training! must be called before classify'
      end

      words = document.split.uniq
      likelihoods = @likelihoods.dup
      vocab_size = frequency_table.size
      posterior = {}

      label_index.each do |label, index|
        words.each do |word|
          row = frequency_table[word]
          next if row.nil?

          # @likelihoods assumed every word absent; swap this word's
          # "absent" factor (1 - p) for its "present" factor p.
          word_likelihood = laplace_word_likelihood(row[index], label, vocab_size)
          likelihoods[label] *= word_likelihood / (1.0 - word_likelihood)
        end

        posterior[label] = @priors[label] * likelihoods[label]
      end

      normalize(posterior)
    end

    # Writes the raw training counts (not the derived likelihoods/priors) to
    # +path+ as JSON.
    def save_state(path)
      state = STATE_KEYS.each_with_object({}) { |key, hash| hash[key] = public_send(key) }

      ::File.open(::File.expand_path(path), 'w') { |handle| handle.write(state.to_json) }
    end

    # Restores counts written by #save_state and re-runs #finish_training!.
    #
    # @raise [RuntimeError] "Missing <key>" when a required key is absent
    def load_state(path)
      state = ::JSON.parse(::File.read(::File.expand_path(path)))

      STATE_KEYS.each { |key| fail "Missing #{key}" unless state[key] }

      @frequency_table = state['frequency_table']
      # JSON.parse yields plain Hashes; restore the zero default so later
      # #train calls with previously unseen labels can increment their count.
      @label_instance_count = new_count_hash.update(state['label_instance_count'])
      @label_index = state['label_index']
      @label_index_sequence = state['label_index_sequence']
      @total_terms = state['total_terms']

      finish_training!
    end

    private

    # Per label, multiplies the probability of every vocabulary word being
    # absent; #classify_from_string then corrects for the words present.
    # NOTE(review): the product spans the whole vocabulary and may underflow
    # to 0.0 for large vocabularies. A factor also becomes non-positive when
    # count + 1 >= label_instance_count[label] + vocab_size. Consider a
    # log-space formulation.
    def calculate_likelihoods!
      vocab_size = frequency_table.size

      @likelihoods = label_index.each_with_object({}) do |(label, index), likelihoods|
        likelihoods[label] = frequency_table.each_value.inject(1.0) do |product, row|
          product * (1.0 - laplace_word_likelihood(row[index], label, vocab_size))
        end
      end
    end

    # label => share of all training documents that carried that label.
    # fdiv keeps this a float division even if total_terms is an Integer.
    def calculate_priors!
      @priors = label_instance_count.each_with_object({}) do |(label, count), priors|
        priors[label] = count.fdiv(total_terms)
      end
    end

    # Laplace (add-one) smoothed word likelihood shared by training and scoring.
    def laplace_word_likelihood(count, label, vocab_size)
      (count + 1.0) / (label_instance_count[label] + vocab_size)
    end

    def get_next_sequence_value
      @label_index_sequence += 1
    end

    # Bookkeeping done once per #train call.
    def record_training_instance(label)
      @label_instance_count[label] += 1
      @total_terms += 1
    end

    # Hash that starts every unseen label at a count of zero.
    def new_count_hash
      ::Hash.new { |hash, label| hash[label] = 0 }
    end

    # Scales the posteriors so they sum to 1.0.
    def normalize(posterior)
      sum = posterior.values.inject(0.0, :+)

      posterior.each_with_object({}) do |(label, value), normalized|
        normalized[label] = value / sum
      end
    end

    # Assigns the next column to a new label and gives every known word a
    # zero count in that column.
    def update_label_index(label)
      return if label_index.key?(label)

      index = get_next_sequence_value
      label_index[label] = index
      frequency_table.each_value { |row| row[index] = 0 }
    end

    # Adds +frequency+ to the (word, label) cell, creating the word's row
    # (one zero per known label) on first sight.
    def update_frequency_table(label, word, frequency)
      index = label_index[label]
      row = (frequency_table[word] ||= ::Array.new(label_index.size, 0))
      row[index] += frequency
    end
  end
end
|
@@ -0,0 +1,26 @@
|
|
1
|
+
require 'java'
|
2
|
+
require ::File.join(::File.dirname(__FILE__), "..", "..", "target" , "bae.jar")
|
3
|
+
|
4
|
+
module Bae
  # Ruby facade over the Java NaiveBayesClassifier loaded from target/bae.jar
  # (this file does `require 'java'`, so it only runs under JRuby).
  #
  # Mirrors the Bae::Classifier workflow: #train, then #finish_training!,
  # then #classify.
  class NativeClassifier

    # The wrapped Java classifier instance.
    attr_reader :internal_classifier

    def initialize
      @internal_classifier = ::Java::Bae::NaiveBayesClassifier.new
    end

    # Wraps +feature+ in a Java Document and trains it under +label+.
    def train(label, feature)
      document = ::Java::Bae::Document.new(feature)
      internal_classifier.train(label, document)
    end

    # Wraps +feature+ in a Java Document and returns the Java classifier's
    # result for it.
    def classify(feature)
      document = ::Java::Bae::Document.new(feature)
      internal_classifier.classify(document)
    end

    # Asks the Java side to precompute its initial likelihoods.
    def finish_training!
      internal_classifier.calculateInitialLikelihoods
    end

  end
end
|
data/lib/bae/version.rb
CHANGED
data/lib/bae.rb
CHANGED
@@ -1,22 +1,35 @@
|
|
1
1
|
require 'spec_helper'
|
2
2
|
|
3
|
+
require 'bae/native_classifier'
|
4
|
+
|
3
5
|
describe ::Bae::Classifier do
|
4
6
|
|
5
7
|
subject { described_class.new }
|
6
8
|
|
7
|
-
|
9
|
+
let(:state_json) {
|
10
|
+
'{"frequency_table":{"aaa":[0,0],"bbb":[1,0],"ccc":[0,2],"ddd":[0,3]},"label_instance_count":{"positive":1,"negative":1},"label_index":{"positive":0,"negative":1},"label_index_sequence":1,"total_terms":2.0}'
|
11
|
+
}
|
12
|
+
let(:state) { ::JSON.parse(state_json) }
|
13
|
+
|
14
|
+
it "can classify a hash document" do
|
8
15
|
subject.train("positive", {"aaa" => 0, "bbb" => 1})
|
9
16
|
subject.train("negative", {"ccc" => 2, "ddd" => 3})
|
17
|
+
|
18
|
+
subject.finish_training!
|
19
|
+
|
10
20
|
results = subject.classify({"aaa" => 1, "bbb" => 1})
|
11
21
|
|
12
22
|
expect(results["positive"]).to be_within(0.001).of(0.94117)
|
13
23
|
expect(results["negative"]).to be_within(0.001).of(0.05882)
|
14
24
|
end
|
15
25
|
|
16
|
-
it "can classify from
|
26
|
+
it "can classify from a string based document" do
|
17
27
|
subject.train("positive", "aaa aaa bbb");
|
18
28
|
subject.train("negative", "ccc ccc ddd ddd");
|
19
29
|
subject.train("neutral", "eee eee eee fff fff fff");
|
30
|
+
|
31
|
+
subject.finish_training!
|
32
|
+
|
20
33
|
results = subject.classify("aaa bbb")
|
21
34
|
|
22
35
|
expect(results["positive"]).to be_within(0.001).of(0.89626)
|
@@ -24,4 +37,46 @@ describe ::Bae::Classifier do
|
|
24
37
|
expect(results["neutral"]).to be_within(0.001).of(0.03734)
|
25
38
|
end
|
26
39
|
|
40
|
+
it "fails when you attempt to train or test anything other than a hash or string" do
|
41
|
+
subject.train("positive", "aaa aaa bbb");
|
42
|
+
expect{ subject.train("a", 1337) }.to raise_error 'Training data must either be a string or hash'
|
43
|
+
|
44
|
+
subject.finish_training!
|
45
|
+
|
46
|
+
subject.classify("aaa bbb")
|
47
|
+
expect{ subject.classify(1337) }.to raise_error 'Training data must either be a string or hash'
|
48
|
+
end
|
49
|
+
|
50
|
+
it "can save the classifier state" do
|
51
|
+
subject.train("positive", {"aaa" => 0, "bbb" => 1})
|
52
|
+
subject.train("negative", {"ccc" => 2, "ddd" => 3})
|
53
|
+
|
54
|
+
subject.finish_training!
|
55
|
+
|
56
|
+
temp_file = ::Tempfile.new('some_state')
|
57
|
+
subject.save_state(temp_file.path)
|
58
|
+
|
59
|
+
temp_file.rewind
|
60
|
+
expect(temp_file.read).to eq(state_json)
|
61
|
+
|
62
|
+
temp_file.close
|
63
|
+
temp_file.unlink
|
64
|
+
end
|
65
|
+
|
66
|
+
it "can correctly load a classifier state and correctly classify" do
|
67
|
+
temp_file = ::Tempfile.new('some_state')
|
68
|
+
temp_file.write(state_json)
|
69
|
+
temp_file.rewind
|
70
|
+
|
71
|
+
subject.load_state(temp_file.path)
|
72
|
+
|
73
|
+
results = subject.classify({"aaa" => 1, "bbb" => 1})
|
74
|
+
|
75
|
+
expect(results["positive"]).to be_within(0.001).of(0.94117)
|
76
|
+
expect(results["negative"]).to be_within(0.001).of(0.05882)
|
77
|
+
|
78
|
+
temp_file.close
|
79
|
+
temp_file.unlink
|
80
|
+
end
|
81
|
+
|
27
82
|
end
|
@@ -0,0 +1,33 @@
|
|
1
|
+
require 'spec_helper'
|
2
|
+
|
3
|
+
describe ::Bae::NativeClassifier do

  subject(:classifier) { described_class.new }

  it "can classify a hash document" do
    classifier.train("positive", { "aaa" => 0, "bbb" => 1 })
    classifier.train("negative", { "ccc" => 2, "ddd" => 3 })
    classifier.finish_training!

    results = classifier.classify({ "aaa" => 1, "bbb" => 1 })

    { "positive" => 0.94117, "negative" => 0.05882 }.each do |label, expected|
      expect(results[label]).to be_within(0.001).of(expected)
    end
  end

  it "can classify from a string based document" do
    training_documents = {
      "positive" => "aaa aaa bbb",
      "negative" => "ccc ccc ddd ddd",
      "neutral"  => "eee eee eee fff fff fff"
    }
    training_documents.each { |label, text| classifier.train(label, text) }
    classifier.finish_training!

    results = classifier.classify("aaa bbb")

    { "positive" => 0.89626, "negative" => 0.06639, "neutral" => 0.03734 }.each do |label, expected|
      expect(results[label]).to be_within(0.001).of(expected)
    end
  end

end
|
data/spec/spec_helper.rb
CHANGED
@@ -33,7 +33,9 @@ public class Document {
|
|
33
33
|
|
34
34
|
// Set initial count if it doesn't have one yet
|
35
35
|
// Use zero because we'll add counts in the next line.
|
36
|
-
this.frequencyMap.
|
36
|
+
if(!this.frequencyMap.containsKey(wordToken)) {
|
37
|
+
this.frequencyMap.put(wordToken, 0L);
|
38
|
+
}
|
37
39
|
|
38
40
|
// Update count
|
39
41
|
this.frequencyMap.put(wordToken, this.frequencyMap.get(wordToken) + 1);
|
@@ -14,7 +14,9 @@ public class FrequencyTable {
|
|
14
14
|
|
15
15
|
public void insertOrIgnore(String label) {
|
16
16
|
// Add new hash to frequency table if it's not already there
|
17
|
-
this.frequencyTable.
|
17
|
+
if(!this.frequencyTable.containsKey(label)) {
|
18
|
+
this.frequencyTable.put(label, new HashMap<String, Long>());
|
19
|
+
}
|
18
20
|
}
|
19
21
|
|
20
22
|
public void increaseFrequencyBy(String label, String word, long frequency) {
|
@@ -24,7 +26,9 @@ public class FrequencyTable {
|
|
24
26
|
Map<String, Long> frequencyRow = this.frequencyTable.get(label);
|
25
27
|
|
26
28
|
// Make sure we have a frequency for that position in the table
|
27
|
-
frequencyRow.
|
29
|
+
if(!frequencyRow.containsKey(word)) {
|
30
|
+
frequencyRow.put(word, 0L);
|
31
|
+
}
|
28
32
|
|
29
33
|
// Update frequency
|
30
34
|
frequencyRow.put(word, frequencyRow.get(word) + frequency);
|
@@ -8,12 +8,16 @@ public class NaiveBayesClassifier {
|
|
8
8
|
private FrequencyTable frequencyTable;
|
9
9
|
private Map<String, Long> wordTable;
|
10
10
|
private Map<String, Long> instanceCountOf;
|
11
|
+
private Map<String, Double> initialLikelihoodOf;
|
12
|
+
Map<String, Double> classPriorOf;
|
11
13
|
private double totalCount = 0;
|
12
14
|
|
13
15
|
public NaiveBayesClassifier() {
|
14
16
|
this.frequencyTable = new FrequencyTable();
|
15
17
|
this.wordTable = new HashMap<>();
|
16
18
|
this.instanceCountOf = new HashMap<>();
|
19
|
+
this.initialLikelihoodOf = new HashMap<>();
|
20
|
+
this.classPriorOf = new HashMap<>();
|
17
21
|
}
|
18
22
|
|
19
23
|
public void train(String label, Document document) {
|
@@ -37,12 +41,23 @@ public class NaiveBayesClassifier {
|
|
37
41
|
updateIntegerCountBy(this.instanceCountOf, label, 1);
|
38
42
|
}
|
39
43
|
|
40
|
-
public
|
41
|
-
|
42
|
-
|
43
|
-
|
44
|
-
|
45
|
-
|
44
|
+
public void calculateInitialLikelihoods() {
|
45
|
+
// Update likelihood counts
|
46
|
+
for(String label : this.frequencyTable.getLabels()) {
|
47
|
+
// Set initial likelihood
|
48
|
+
initialLikelihoodOf.put(label, 1d);
|
49
|
+
|
50
|
+
// Calculate likelihoods
|
51
|
+
for (String word : this.wordTable.keySet()) {
|
52
|
+
double laplaceWordLikelihood =
|
53
|
+
(this.frequencyTable.get(label, word) + 1d) /
|
54
|
+
(this.instanceCountOf.get(label) + this.wordTable.size());
|
55
|
+
|
56
|
+
// Update likelihood
|
57
|
+
double likelihood = initialLikelihoodOf.get(label);
|
58
|
+
initialLikelihoodOf.put(label, likelihood * (1d - laplaceWordLikelihood));
|
59
|
+
}
|
60
|
+
}
|
46
61
|
|
47
62
|
// Update the prior
|
48
63
|
for(Map.Entry<String, Long> entry : this.instanceCountOf.entrySet()) {
|
@@ -50,34 +65,39 @@ public class NaiveBayesClassifier {
|
|
50
65
|
double frequency = entry.getValue();
|
51
66
|
|
52
67
|
// Update instance count
|
53
|
-
classPriorOf.put(label, (frequency / this.totalCount));
|
68
|
+
this.classPriorOf.put(label, (frequency / this.totalCount));
|
54
69
|
}
|
70
|
+
}
|
71
|
+
|
72
|
+
public Map<String, Double> classify(Document document) {
|
73
|
+
Map<String, Double> likelihoodOf = new HashMap<>();
|
74
|
+
Map<String, Double> classPosteriorOf = new HashMap<>();
|
75
|
+
Map<String, Long> featureFrequencyMap = document.getFrequencyMap();
|
76
|
+
double evidence = 0;
|
55
77
|
|
56
78
|
// Update likelihood counts
|
57
79
|
for(String label : this.frequencyTable.getLabels()) {
|
58
80
|
// Set initial likelihood
|
59
|
-
likelihoodOf.put(label,
|
81
|
+
likelihoodOf.put(label, this.initialLikelihoodOf.get(label));
|
60
82
|
|
61
|
-
// Calculate likelihoods
|
62
|
-
for(String word :
|
83
|
+
// Calculate actual likelihoods likelihoods
|
84
|
+
for(String word : featureFrequencyMap.keySet()) {
|
63
85
|
double laplaceWordLikelihood =
|
64
86
|
(this.frequencyTable.get(label, word) + 1d) /
|
65
87
|
(this.instanceCountOf.get(label) + this.wordTable.size());
|
66
88
|
|
67
|
-
// Update likelihood
|
89
|
+
// Update likelihood for words not in features
|
68
90
|
double likelihood = likelihoodOf.get(label);
|
69
|
-
if(
|
70
|
-
likelihoodOf.put(label, likelihood * laplaceWordLikelihood);
|
71
|
-
} else {
|
72
|
-
likelihoodOf.put(label, likelihood * (1d - laplaceWordLikelihood));
|
91
|
+
if(featureFrequencyMap.containsKey(word)) {
|
92
|
+
likelihoodOf.put(label, (likelihood * laplaceWordLikelihood) / (1d - laplaceWordLikelihood));
|
73
93
|
}
|
74
94
|
}
|
75
95
|
|
76
96
|
// Default class posterior of label to 1.0
|
77
|
-
classPosteriorOf.
|
97
|
+
classPosteriorOf.put(label, 1d);
|
78
98
|
|
79
99
|
// Update class posterior
|
80
|
-
double classPosterior = classPriorOf.get(label) * likelihoodOf.get(label);
|
100
|
+
double classPosterior = this.classPriorOf.get(label) * likelihoodOf.get(label);
|
81
101
|
classPosteriorOf.put(label, classPosterior);
|
82
102
|
evidence += classPosterior;
|
83
103
|
}
|
@@ -93,12 +113,16 @@ public class NaiveBayesClassifier {
|
|
93
113
|
}
|
94
114
|
|
95
115
|
public void updateIntegerCountBy(Map<String, Long> someMap, String someKey, long count) {
|
96
|
-
someMap.
|
116
|
+
if(!someMap.containsKey(someKey)) {
|
117
|
+
someMap.put(someKey, 0L);
|
118
|
+
}
|
97
119
|
someMap.put(someKey, someMap.get(someKey) + count);
|
98
120
|
}
|
99
121
|
|
100
122
|
public void updateDoubleCountBy(Map<String, Double> someMap, String someKey, double count) {
|
101
|
-
someMap.
|
123
|
+
if(!someMap.containsKey(someKey)) {
|
124
|
+
someMap.put(someKey, 0.0);
|
125
|
+
}
|
102
126
|
someMap.put(someKey, someMap.get(someKey) + count);
|
103
127
|
}
|
104
128
|
|
data/target/bae.jar
CHANGED
Binary file
|
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: bae
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.0.
|
4
|
+
version: 0.0.9
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- Garrett Thornburg
|
8
8
|
autorequire:
|
9
9
|
bindir: bin
|
10
10
|
cert_chain: []
|
11
|
-
date: 2015-02-
|
11
|
+
date: 2015-02-27 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
requirement: !ruby/object:Gem::Requirement
|
@@ -68,8 +68,10 @@ files:
|
|
68
68
|
- build.xml
|
69
69
|
- lib/bae.rb
|
70
70
|
- lib/bae/classifier.rb
|
71
|
+
- lib/bae/native_classifier.rb
|
71
72
|
- lib/bae/version.rb
|
72
73
|
- spec/lib/bae/classifier_spec.rb
|
74
|
+
- spec/lib/bae/native_classifier_spec.rb
|
73
75
|
- spec/spec_helper.rb
|
74
76
|
- src/main/java/bae/Document.java
|
75
77
|
- src/main/java/bae/FrequencyTable.java
|
@@ -104,4 +106,5 @@ specification_version: 4
|
|
104
106
|
summary: Multinomial naive bayes classifier with a kick of java
|
105
107
|
test_files:
|
106
108
|
- spec/lib/bae/classifier_spec.rb
|
109
|
+
- spec/lib/bae/native_classifier_spec.rb
|
107
110
|
- spec/spec_helper.rb
|