bae 0.0.1 → 0.0.9
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/.gitignore +2 -1
- data/README.md +43 -1
- data/build.xml +3 -3
- data/lib/bae/classifier.rb +167 -6
- data/lib/bae/native_classifier.rb +26 -0
- data/lib/bae/version.rb +1 -1
- data/lib/bae.rb +2 -7
- data/spec/lib/bae/classifier_spec.rb +57 -2
- data/spec/lib/bae/native_classifier_spec.rb +33 -0
- data/spec/spec_helper.rb +1 -0
- data/src/main/java/bae/Document.java +3 -1
- data/src/main/java/bae/FrequencyTable.java +6 -2
- data/src/main/java/bae/NaiveBayesClassifier.java +43 -19
- data/target/bae.jar +0 -0
- metadata +5 -2
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA1:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: c28a60c92163259beddf8af99cd31357cf3a95a8
|
4
|
+
data.tar.gz: 34cc3ee332ec6f79d74e2ab04f8d60227c49c1a9
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: 4f92cc52a40438b18bf543299345b1a9ae57443b53e8b8cae7169181a436dfc2e4c86627527239c6c0c29d99f0ff723c4f9272a994f4595d58c1f6d5c8acbd5c
|
7
|
+
data.tar.gz: dfca0d36849a088fdc5c60cb62c9256a7743c329d8acbbdba87cac34f0fb3de1195909598474283a12862a5a1b92688258189b47945330dcf8cc6df563b77e19
|
data/.gitignore
CHANGED
data/README.md
CHANGED
@@ -3,6 +3,15 @@ Bae
|
|
3
3
|
|
4
4
|
Bae is a multinomial naive Bayes classifier based on another gem, ["naivebayes"](https://github.com/id774/naivebayes), except that this one can use Java to do the heavy lifting.
|
5
5
|
|
6
|
+
By default Bae uses the vanilla Ruby implementation, but you can opt into the native classifier written in Java (it loads `target/bae.jar` via `require 'java'`, so it needs JRuby):
|
7
|
+
|
8
|
+
```ruby
|
9
|
+
require 'bae/native_classifier'
|
10
|
+
|
11
|
+
classifier = ::Bae::NativeClassifier.new
|
12
|
+
```
|
13
|
+
|
14
|
+
|
6
15
|
## Installation
|
7
16
|
|
8
17
|
Add this line to your application's Gemfile:
|
@@ -28,6 +37,9 @@ You can refer to ["naivebayes"](https://github.com/id774/naivebayes) gem for mor
|
|
28
37
|
classifier = ::Bae::Classifier.new
|
29
38
|
classifier.train("positive", {"aaa" => 0, "bbb" => 1})
|
30
39
|
classifier.train("negative", {"ccc" => 2, "ddd" => 3})
|
40
|
+
|
41
|
+
classifier.finish_training!
|
42
|
+
|
31
43
|
classifier.classify({"aaa" => 1, "bbb" => 1})
|
32
44
|
|
33
45
|
#=> {"positive" => 0.94117..., "negative" => 0.05882...}
|
@@ -39,15 +51,45 @@ classifier = ::Bae::Classifier.new
|
|
39
51
|
classifier.train("positive", "aaa aaa bbb");
|
40
52
|
classifier.train("negative", "ccc ccc ddd ddd");
|
41
53
|
classifier.train("neutral", "eee eee eee fff fff fff");
|
54
|
+
|
55
|
+
classifier.finish_training!
|
56
|
+
|
42
57
|
classifier.classify("aaa bbb")
|
43
58
|
|
44
59
|
#=> {"positive"=>0.8962655601659751, "negative"=>0.0663900414937759, "neutral"=>0.037344398340248955}
|
45
60
|
```
|
46
61
|
|
62
|
+
### Saving State
|
63
|
+
|
64
|
+
You can save a snapshot of the trained classifier to disk and later load it back into memory.
|
65
|
+
|
66
|
+
```ruby
|
67
|
+
# From the example above...
|
68
|
+
classifier = ::Bae::Classifier.new
|
69
|
+
classifier.train("positive", {"aaa" => 0, "bbb" => 1})
|
70
|
+
classifier.train("negative", {"ccc" => 2, "ddd" => 3})
|
71
|
+
|
72
|
+
classifier.finish_training!
|
73
|
+
|
74
|
+
classifier.classify({"aaa" => 1, "bbb" => 1})
|
75
|
+
#=> {"positive" => 0.94117..., "negative" => 0.05882...}
|
76
|
+
|
77
|
+
# Now let's save it to disk
|
78
|
+
classifier.save_state("/tmp/some_state.json")
|
79
|
+
|
80
|
+
# Let's create a new classifier and load the state we just saved
|
81
|
+
classifier = ::Bae::Classifier.new
|
82
|
+
classifier.load_state("/tmp/some_state.json")
|
83
|
+
|
84
|
+
# Now we can classify without retraining
|
85
|
+
classifier.classify({"aaa" => 1, "bbb" => 1})
|
86
|
+
#=> {"positive" => 0.94117..., "negative" => 0.05882...}
|
87
|
+
```
|
88
|
+
|
47
89
|
|
48
90
|
## Contributing
|
49
91
|
|
50
|
-
1. Fork it ( https://github.com/
|
92
|
+
1. Fork it ( https://github.com/film42/bae/fork )
|
51
93
|
2. Create your feature branch (`git checkout -b my-new-feature`)
|
52
94
|
3. Commit your changes (`git commit -am 'Add some feature'`)
|
53
95
|
4. Push to the branch (`git push origin my-new-feature`)
|
data/build.xml
CHANGED
@@ -1,13 +1,13 @@
|
|
1
1
|
<project>
|
2
2
|
|
3
3
|
<target name="clean">
|
4
|
-
<delete dir="
|
4
|
+
<delete dir="out/classes"/>
|
5
5
|
</target>
|
6
6
|
|
7
|
-
<target name="compile">
|
7
|
+
<target name="compile" depends="clean">
|
8
8
|
<mkdir dir="out"/>
|
9
9
|
<mkdir dir="out/classes"/>
|
10
|
-
<javac srcdir="src/main/java" destdir="out/classes"/>
|
10
|
+
<javac srcdir="src/main/java" destdir="out/classes" source="1.7" target="1.7" includeantruntime="false" />
|
11
11
|
</target>
|
12
12
|
|
13
13
|
<target name="jar" depends="compile">
|
data/lib/bae/classifier.rb
CHANGED
@@ -1,19 +1,180 @@
|
|
1
1
|
module Bae
  # Pure-Ruby naive Bayes classifier.
  #
  # Workflow: call #train once per document, then #finish_training! to
  # precompute per-label likelihood baselines and class priors, then #classify.
  # The trained state can be written with #save_state and restored with
  # #load_state (which re-runs #finish_training!).
  #
  # Internal state (all exposed through accessors):
  # * frequency_table      - word => Array of counts, one slot per label index
  # * label_index          - label => column index into each frequency_table row
  # * label_index_sequence - last column index handed out (-1 before any label)
  # * label_instance_count - label => number of documents trained for that label
  # * total_terms          - number of documents trained so far, kept as a Float
  #                          so that priors divide as floats
  class Classifier

    attr_accessor :frequency_table, :label_index, :label_index_sequence,
                  :label_instance_count, :total_terms

    def initialize
      @frequency_table = {}
      @label_instance_count = new_counter
      # Plain Hash on purpose: an auto-vivifying default would silently map an
      # unknown label to column 0 on any lookup.
      @label_index = {}
      @label_index_sequence = -1 # start at -1 so 0 is first value
      @total_terms = 0.0
    end

    # Precomputes the likelihood baselines and class priors from the counts
    # gathered so far. Call it after the last #train and before #classify.
    #
    # @return [Hash] the computed priors (label => prior)
    def finish_training!
      calculate_likelihoods!
      calculate_priors!
    end

    # Adds one training document for +label+.
    #
    # @param label [String] class label
    # @param training_data [String, Hash] whitespace-separated text, or a Hash
    #   of word => frequency
    # @raise [RuntimeError] if training_data is neither a String nor a Hash
    def train(label, training_data)
      case training_data
      when ::String then train_from_string(label, training_data)
      when ::Hash then train_from_hash(label, training_data)
      else fail 'Training data must either be a string or hash'
      end
    end

    # Trains from text: each whitespace-separated token adds 1 to its count.
    def train_from_string(label, document)
      # Index the label up front so every label that receives a prior also gets
      # a likelihood column, even when the document has no words.
      update_label_index(label)
      document.split.each { |word| update_frequency_table(label, word, 1) }
      record_document(label)
    end

    # Trains from a precomputed word => frequency Hash.
    def train_from_hash(label, frequency_hash)
      update_label_index(label)
      frequency_hash.each { |word, frequency| update_frequency_table(label, word, frequency) }
      record_document(label)
    end

    # Scores +data+ against every trained label.
    #
    # @param data [String, Hash] whitespace-separated text, or word => frequency
    # @return [Hash] label => posterior probability (values sum to 1)
    # @raise [RuntimeError] if data is neither a String nor a Hash (the message
    #   text is shared with #train), or if #finish_training! was never called
    def classify(data)
      case data
      when ::String then classify_from_string(data)
      when ::Hash then classify_from_hash(data)
      else fail 'Training data must either be a string or hash'
      end
    end

    # Scoring only looks at which words are present, so every key with a
    # positive frequency counts once and zero-frequency keys are ignored.
    # Keys are matched as-is, exactly as #train_from_hash stored them.
    def classify_from_hash(frequency_hash)
      words = frequency_hash.select { |_word, frequency| frequency > 0 }.keys
      classify_words(words)
    end

    # Scores the distinct whitespace-separated tokens of +document+.
    def classify_from_string(document)
      classify_words(document.split.uniq)
    end

    # Serializes the trained counts as JSON to +path+ (expanded with
    # File.expand_path). Likelihoods and priors are not saved; #load_state
    # recomputes them.
    #
    # @return [Integer] number of bytes written
    def save_state(path)
      state = {
        'frequency_table' => frequency_table,
        'label_instance_count' => label_instance_count,
        'label_index' => label_index,
        'label_index_sequence' => label_index_sequence,
        'total_terms' => total_terms
      }

      ::File.write(::File.expand_path(path), state.to_json)
    end

    # Restores counts written by #save_state and re-runs #finish_training!, so
    # the classifier is immediately ready for #classify (or further #train).
    #
    # @raise [RuntimeError] "Missing <key>" when a required key is absent
    def load_state(path)
      state = ::JSON.parse(::File.read(::File.expand_path(path)))

      %w[frequency_table label_instance_count label_index label_index_sequence total_terms].each do |key|
        fail "Missing #{key}" unless state[key]
      end

      @frequency_table = state['frequency_table']
      # JSON.parse returns a plain Hash; re-wrap it so training a brand-new
      # label after loading still starts that label's count at 0.
      @label_instance_count = new_counter.merge!(state['label_instance_count'])
      @label_index = state['label_index']
      @label_index_sequence = state['label_index_sequence']
      @total_terms = state['total_terms'].to_f

      finish_training!
    end

    private

    # Hash whose missing keys start (and are stored) at 0.
    def new_counter
      ::Hash.new { |hash, key| hash[key] = 0 }
    end

    # Bumps the per-label and overall document counters.
    def record_document(label)
      @label_instance_count[label] += 1
      @total_terms += 1
    end

    # Shared scoring for #classify_from_string / #classify_from_hash.
    # Starts each label from its baseline (product of (1 - p) over the whole
    # vocabulary) and, for every observed word, swaps that word's (1 - p)
    # factor for p by multiplying with p / (1 - p).
    def classify_words(words)
      fail 'Classifier is not ready; call finish_training! before classify' unless @likelihoods && @priors

      vocab_size = frequency_table.size

      posterior = label_index.each_with_object({}) do |(label, index), scores|
        likelihood = @likelihoods[label]
        denominator = (label_instance_count[label] + vocab_size).to_f

        words.each do |word|
          row = frequency_table[word]
          next unless row # words never seen in training carry no evidence

          # Laplace-smoothed probability of this word under the label.
          word_likelihood = (row[index] + 1.0) / denominator
          likelihood *= word_likelihood / (1.0 - word_likelihood)
        end

        scores[label] = @priors[label] * likelihood
      end

      normalize(posterior)
    end

    # Per label: product over every vocabulary word of (1 - Laplace-smoothed
    # probability), i.e. the likelihood of a document containing no words.
    def calculate_likelihoods!
      vocab_size = frequency_table.size

      @likelihoods = label_index.each_with_object({}) do |(label, index), likelihoods|
        denominator = (label_instance_count[label] + vocab_size).to_f
        baseline = 1.0

        frequency_table.each_value do |row|
          baseline *= 1.0 - (row[index] + 1.0) / denominator
        end

        likelihoods[label] = baseline
      end
    end

    # Prior of each label = its document count / total documents.
    def calculate_priors!
      @priors = label_instance_count.each_with_object({}) do |(label, count), priors|
        priors[label] = count / total_terms.to_f
      end
    end

    def get_next_sequence_value
      @label_index_sequence += 1
    end

    # Scales the scores so they sum to 1. An all-zero input produces NaN values.
    def normalize(posterior)
      total = posterior.each_value.inject(0.0) { |sum, score| sum + score }
      posterior.each_with_object({}) { |(label, score), normalized| normalized[label] = score / total }
    end

    # Assigns the next column index to a label seen for the first time.
    def update_label_index(label)
      return if label_index.key?(label)

      index = get_next_sequence_value
      label_index[label] = index
      # Existing word rows get a zero count in the new label's column.
      frequency_table.each_value { |row| row[index] = 0 }
    end

    # Adds +frequency+ to the (word, label) cell, creating the word's row
    # (one zeroed slot per known label) if needed.
    def update_frequency_table(label, word, frequency)
      index = label_index[label]
      row = frequency_table[word] ||= ::Array.new(label_index.size, 0)
      row[index] += frequency
    end
  end
end
|
@@ -0,0 +1,26 @@
|
|
1
|
+
require 'java'
|
2
|
+
require ::File.join(::File.dirname(__FILE__), "..", "..", "target" , "bae.jar")
|
3
|
+
|
4
|
+
module Bae
  # Ruby front-end for the Java implementation bundled in target/bae.jar
  # (bae.NaiveBayesClassifier). This file does `require 'java'`, so it is
  # meant to run under JRuby.
  #
  # Usage mirrors Bae::Classifier: #train, then #finish_training!, then #classify.
  class NativeClassifier
    # The wrapped bae.NaiveBayesClassifier Java object.
    attr_reader :internal_classifier

    def initialize
      @internal_classifier = ::Java::Bae::NaiveBayesClassifier.new
    end

    # Adds one document for +label+. +feature+ is handed straight to the
    # bae.Document constructor; the specs pass both Strings and word => count
    # Hashes (conversion happens on the Java side).
    def train(label, feature)
      document = to_document(feature)
      internal_classifier.train(label, document)
    end

    # Scores +feature+ against every trained label and returns whatever the
    # Java classifier returns (the specs index it by label like a Hash).
    def classify(feature)
      document = to_document(feature)
      internal_classifier.classify(document)
    end

    # Runs the Java-side precomputation (initial likelihoods and priors); call
    # after the last #train and before #classify.
    def finish_training!
      internal_classifier.calculateInitialLikelihoods
    end

    private

    # Wraps raw input in a bae.Document.
    def to_document(feature)
      ::Java::Bae::Document.new(feature)
    end
  end
end
|
data/lib/bae/version.rb
CHANGED
data/lib/bae.rb
CHANGED
@@ -1,22 +1,35 @@
|
|
1
1
|
require 'spec_helper'
|
2
2
|
|
3
|
+
require 'bae/native_classifier'
|
4
|
+
|
3
5
|
describe ::Bae::Classifier do
|
4
6
|
|
5
7
|
subject { ::Bae::Classifier.new }

  # Serialized state of the two-label hash example (positive/negative) used
  # by the save/load specs below.
  let(:state_json) do
    '{"frequency_table":{"aaa":[0,0],"bbb":[1,0],"ccc":[0,2],"ddd":[0,3]},"label_instance_count":{"positive":1,"negative":1},"label_index":{"positive":0,"negative":1},"label_index_sequence":1,"total_terms":2.0}'
  end

  let(:state) do
    ::JSON.parse(state_json)
  end
|
13
|
+
|
14
|
+
it "can classify a hash document" do
    subject.train("positive", {"aaa" => 0, "bbb" => 1})
    subject.train("negative", {"ccc" => 2, "ddd" => 3})
    subject.finish_training!

    document = {"aaa" => 1, "bbb" => 1}
    posterior = subject.classify(document)

    expect(posterior["positive"]).to be_within(0.001).of(0.94117)
    expect(posterior["negative"]).to be_within(0.001).of(0.05882)
  end
|
15
25
|
|
16
|
-
it "can classify from
|
26
|
+
it "can classify from a string based document" do
|
17
27
|
subject.train("positive", "aaa aaa bbb");
|
18
28
|
subject.train("negative", "ccc ccc ddd ddd");
|
19
29
|
subject.train("neutral", "eee eee eee fff fff fff");
|
30
|
+
|
31
|
+
subject.finish_training!
|
32
|
+
|
20
33
|
results = subject.classify("aaa bbb")
|
21
34
|
|
22
35
|
expect(results["positive"]).to be_within(0.001).of(0.89626)
|
@@ -24,4 +37,46 @@ describe ::Bae::Classifier do
|
|
24
37
|
expect(results["neutral"]).to be_within(0.001).of(0.03734)
|
25
38
|
end
|
26
39
|
|
40
|
+
it "fails when you attempt to train or test anything other than a hash or string" do
    error_message = 'Training data must either be a string or hash'

    # Training rejects a non-String, non-Hash document...
    subject.train("positive", "aaa aaa bbb")
    expect { subject.train("a", 1337) }.to raise_error(error_message)

    # ...and so does classification, with the same message.
    subject.finish_training!
    subject.classify("aaa bbb")
    expect { subject.classify(1337) }.to raise_error(error_message)
  end
|
49
|
+
|
50
|
+
it "can save the classifier state" do
    subject.train("positive", {"aaa" => 0, "bbb" => 1})
    subject.train("negative", {"ccc" => 2, "ddd" => 3})
    subject.finish_training!

    temp_file = ::Tempfile.new('some_state')
    begin
      subject.save_state(temp_file.path)

      # save_state writes through its own handle, so read the result by path.
      expect(::File.read(temp_file.path)).to eq(state_json)
    ensure
      # Clean up even when the expectation fails.
      temp_file.close
      temp_file.unlink
    end
  end
|
65
|
+
|
66
|
+
it "can correctly load a classifier state and correctly classify" do
    temp_file = ::Tempfile.new('some_state')
    begin
      temp_file.write(state_json)
      # Flush explicitly: load_state reads the file by path, not via this handle.
      temp_file.flush

      subject.load_state(temp_file.path)

      results = subject.classify({"aaa" => 1, "bbb" => 1})

      expect(results["positive"]).to be_within(0.001).of(0.94117)
      expect(results["negative"]).to be_within(0.001).of(0.05882)
    ensure
      # Clean up even when an expectation fails.
      temp_file.close
      temp_file.unlink
    end
  end
|
81
|
+
|
27
82
|
end
|
@@ -0,0 +1,33 @@
|
|
1
|
+
require 'spec_helper'
# Load the JRuby wrapper explicitly so this spec also runs on its own,
# without relying on another spec file having required it first.
require 'bae/native_classifier'

describe ::Bae::NativeClassifier do

  subject { described_class.new }

  # Same fixture and expected posteriors as the pure-Ruby classifier spec:
  # both implementations should agree.
  it "can classify a hash document" do
    subject.train("positive", {"aaa" => 0, "bbb" => 1})
    subject.train("negative", {"ccc" => 2, "ddd" => 3})

    subject.finish_training!

    results = subject.classify({"aaa" => 1, "bbb" => 1})

    expect(results["positive"]).to be_within(0.001).of(0.94117)
    expect(results["negative"]).to be_within(0.001).of(0.05882)
  end

  it "can classify from a string based document" do
    subject.train("positive", "aaa aaa bbb")
    subject.train("negative", "ccc ccc ddd ddd")
    subject.train("neutral", "eee eee eee fff fff fff")

    subject.finish_training!

    results = subject.classify("aaa bbb")

    expect(results["positive"]).to be_within(0.001).of(0.89626)
    expect(results["negative"]).to be_within(0.001).of(0.06639)
    expect(results["neutral"]).to be_within(0.001).of(0.03734)
  end

end
|
data/spec/spec_helper.rb
CHANGED
@@ -33,7 +33,9 @@ public class Document {
|
|
33
33
|
|
34
34
|
// Set initial count if it doesn't have one yet
|
35
35
|
// Use zero because we'll add counts in the next line.
|
36
|
-
this.frequencyMap.
|
36
|
+
if(!this.frequencyMap.containsKey(wordToken)) {
|
37
|
+
this.frequencyMap.put(wordToken, 0L);
|
38
|
+
}
|
37
39
|
|
38
40
|
// Update count
|
39
41
|
this.frequencyMap.put(wordToken, this.frequencyMap.get(wordToken) + 1);
|
@@ -14,7 +14,9 @@ public class FrequencyTable {
|
|
14
14
|
|
15
15
|
public void insertOrIgnore(String label) {
|
16
16
|
// Add new hash to frequency table if it's not already there
|
17
|
-
this.frequencyTable.
|
17
|
+
if(!this.frequencyTable.containsKey(label)) {
|
18
|
+
this.frequencyTable.put(label, new HashMap<String, Long>());
|
19
|
+
}
|
18
20
|
}
|
19
21
|
|
20
22
|
public void increaseFrequencyBy(String label, String word, long frequency) {
|
@@ -24,7 +26,9 @@ public class FrequencyTable {
|
|
24
26
|
Map<String, Long> frequencyRow = this.frequencyTable.get(label);
|
25
27
|
|
26
28
|
// Make sure we have a frequency for that position in the table
|
27
|
-
frequencyRow.
|
29
|
+
if(!frequencyRow.containsKey(word)) {
|
30
|
+
frequencyRow.put(word, 0L);
|
31
|
+
}
|
28
32
|
|
29
33
|
// Update frequency
|
30
34
|
frequencyRow.put(word, frequencyRow.get(word) + frequency);
|
@@ -8,12 +8,16 @@ public class NaiveBayesClassifier {
|
|
8
8
|
// label -> word -> count (see FrequencyTable)
private FrequencyTable frequencyTable;
// Vocabulary: its keySet()/size() drive the likelihood loops.
// Values presumably hold per-word totals; TODO confirm in train().
private Map<String, Long> wordTable;
// label -> number of training documents for that label
private Map<String, Long> instanceCountOf;
// label -> product of (1 - smoothed word probability) over the vocabulary
private Map<String, Double> initialLikelihoodOf;
// label -> prior (instance count / totalCount)
Map<String, Double> classPriorOf;
// Divisor used when computing priors
private double totalCount = 0;

public NaiveBayesClassifier() {
    // All lookup tables start empty; training fills them in.
    this.classPriorOf = new HashMap<String, Double>();
    this.initialLikelihoodOf = new HashMap<String, Double>();
    this.instanceCountOf = new HashMap<String, Long>();
    this.wordTable = new HashMap<String, Long>();
    this.frequencyTable = new FrequencyTable();
}
|
18
22
|
|
19
23
|
public void train(String label, Document document) {
|
@@ -37,12 +41,23 @@ public class NaiveBayesClassifier {
|
|
37
41
|
updateIntegerCountBy(this.instanceCountOf, label, 1);
|
38
42
|
}
|
39
43
|
|
40
|
-
public
|
41
|
-
|
42
|
-
|
43
|
-
|
44
|
-
|
45
|
-
|
44
|
+
public void calculateInitialLikelihoods() {
|
45
|
+
// Update likelihood counts
|
46
|
+
for(String label : this.frequencyTable.getLabels()) {
|
47
|
+
// Set initial likelihood
|
48
|
+
initialLikelihoodOf.put(label, 1d);
|
49
|
+
|
50
|
+
// Calculate likelihoods
|
51
|
+
for (String word : this.wordTable.keySet()) {
|
52
|
+
double laplaceWordLikelihood =
|
53
|
+
(this.frequencyTable.get(label, word) + 1d) /
|
54
|
+
(this.instanceCountOf.get(label) + this.wordTable.size());
|
55
|
+
|
56
|
+
// Update likelihood
|
57
|
+
double likelihood = initialLikelihoodOf.get(label);
|
58
|
+
initialLikelihoodOf.put(label, likelihood * (1d - laplaceWordLikelihood));
|
59
|
+
}
|
60
|
+
}
|
46
61
|
|
47
62
|
// Update the prior
|
48
63
|
for(Map.Entry<String, Long> entry : this.instanceCountOf.entrySet()) {
|
@@ -50,34 +65,39 @@ public class NaiveBayesClassifier {
|
|
50
65
|
double frequency = entry.getValue();
|
51
66
|
|
52
67
|
// Update instance count
|
53
|
-
classPriorOf.put(label, (frequency / this.totalCount));
|
68
|
+
this.classPriorOf.put(label, (frequency / this.totalCount));
|
54
69
|
}
|
70
|
+
}
|
71
|
+
|
72
|
+
public Map<String, Double> classify(Document document) {
|
73
|
+
Map<String, Double> likelihoodOf = new HashMap<>();
|
74
|
+
Map<String, Double> classPosteriorOf = new HashMap<>();
|
75
|
+
Map<String, Long> featureFrequencyMap = document.getFrequencyMap();
|
76
|
+
double evidence = 0;
|
55
77
|
|
56
78
|
// Update likelihood counts
|
57
79
|
for(String label : this.frequencyTable.getLabels()) {
|
58
80
|
// Set initial likelihood
|
59
|
-
likelihoodOf.put(label,
|
81
|
+
likelihoodOf.put(label, this.initialLikelihoodOf.get(label));
|
60
82
|
|
61
|
-
// Calculate likelihoods
|
62
|
-
for(String word :
|
83
|
+
// Calculate actual likelihoods likelihoods
|
84
|
+
for(String word : featureFrequencyMap.keySet()) {
|
63
85
|
double laplaceWordLikelihood =
|
64
86
|
(this.frequencyTable.get(label, word) + 1d) /
|
65
87
|
(this.instanceCountOf.get(label) + this.wordTable.size());
|
66
88
|
|
67
|
-
// Update likelihood
|
89
|
+
// Update likelihood for words not in features
|
68
90
|
double likelihood = likelihoodOf.get(label);
|
69
|
-
if(
|
70
|
-
likelihoodOf.put(label, likelihood * laplaceWordLikelihood);
|
71
|
-
} else {
|
72
|
-
likelihoodOf.put(label, likelihood * (1d - laplaceWordLikelihood));
|
91
|
+
if(featureFrequencyMap.containsKey(word)) {
|
92
|
+
likelihoodOf.put(label, (likelihood * laplaceWordLikelihood) / (1d - laplaceWordLikelihood));
|
73
93
|
}
|
74
94
|
}
|
75
95
|
|
76
96
|
// Default class posterior of label to 1.0
|
77
|
-
classPosteriorOf.
|
97
|
+
classPosteriorOf.put(label, 1d);
|
78
98
|
|
79
99
|
// Update class posterior
|
80
|
-
double classPosterior = classPriorOf.get(label) * likelihoodOf.get(label);
|
100
|
+
double classPosterior = this.classPriorOf.get(label) * likelihoodOf.get(label);
|
81
101
|
classPosteriorOf.put(label, classPosterior);
|
82
102
|
evidence += classPosterior;
|
83
103
|
}
|
@@ -93,12 +113,16 @@ public class NaiveBayesClassifier {
|
|
93
113
|
}
|
94
114
|
|
95
115
|
public void updateIntegerCountBy(Map<String, Long> someMap, String someKey, long count) {
|
96
|
-
someMap.
|
116
|
+
if(!someMap.containsKey(someKey)) {
|
117
|
+
someMap.put(someKey, 0L);
|
118
|
+
}
|
97
119
|
someMap.put(someKey, someMap.get(someKey) + count);
|
98
120
|
}
|
99
121
|
|
100
122
|
public void updateDoubleCountBy(Map<String, Double> someMap, String someKey, double count) {
|
101
|
-
someMap.
|
123
|
+
if(!someMap.containsKey(someKey)) {
|
124
|
+
someMap.put(someKey, 0.0);
|
125
|
+
}
|
102
126
|
someMap.put(someKey, someMap.get(someKey) + count);
|
103
127
|
}
|
104
128
|
|
data/target/bae.jar
CHANGED
Binary file
|
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: bae
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.0.
|
4
|
+
version: 0.0.9
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- Garrett Thornburg
|
8
8
|
autorequire:
|
9
9
|
bindir: bin
|
10
10
|
cert_chain: []
|
11
|
-
date: 2015-02-
|
11
|
+
date: 2015-02-27 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
requirement: !ruby/object:Gem::Requirement
|
@@ -68,8 +68,10 @@ files:
|
|
68
68
|
- build.xml
|
69
69
|
- lib/bae.rb
|
70
70
|
- lib/bae/classifier.rb
|
71
|
+
- lib/bae/native_classifier.rb
|
71
72
|
- lib/bae/version.rb
|
72
73
|
- spec/lib/bae/classifier_spec.rb
|
74
|
+
- spec/lib/bae/native_classifier_spec.rb
|
73
75
|
- spec/spec_helper.rb
|
74
76
|
- src/main/java/bae/Document.java
|
75
77
|
- src/main/java/bae/FrequencyTable.java
|
@@ -104,4 +106,5 @@ specification_version: 4
|
|
104
106
|
summary: Multinomial naive bayes classifier with a kick of java
|
105
107
|
test_files:
|
106
108
|
- spec/lib/bae/classifier_spec.rb
|
109
|
+
- spec/lib/bae/native_classifier_spec.rb
|
107
110
|
- spec/spec_helper.rb
|