bae 0.0.7-java → 0.0.8-java

Sign up to get free protection for your applications and to get access to all the features.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA1:
3
- metadata.gz: 71c761f2619746bfc4dd287a5afa5443a7bfe037
4
- data.tar.gz: 886257c2c8987fd8e95edbb63105ef87d7960de5
3
+ metadata.gz: cb626ff0b92f80f096cebc7248a64a8f47f02fda
4
+ data.tar.gz: 87c41e0571e1a31c303f9ab346eef119cec83e6f
5
5
  SHA512:
6
- metadata.gz: e02048771022daa4b61097ae500831a671f1a6ec3d8e9e48f235efd7ff9902be31678e503b4bd31b6cd0a0526d61c2868f1af39cb619c8c9fb120517242928cc
7
- data.tar.gz: c2c7073db8b7afeea466a5aa9db9f157101a9a8e9b38d87787281c204990eb3c4b3f75b64a9ecb2fd7a8dc5629bd807d101b0ea702c38d735ba52e2595bdc822
6
+ metadata.gz: 478e21b1c13f82037a5773cbb5b960ff98387b69cd024ece14c17eee7e9bf5784b658ea555e223e0affcc689ff6aa3fed9f59460ce586f283e06eeecc0d2291f
7
+ data.tar.gz: a3ceddd6c99d9ca8826f2142b2f383b9426ca0f9b28f88fbf966d954595b368c286551c94a6202d1add93709f196b4704bbe6ef5657b371dc1f41a2ac80317ed
data/README.md CHANGED
@@ -3,6 +3,15 @@ Bae
3
3
 
4
4
  Bae is a multinomial naive Bayes classifier based on another gem, ["naivebayes"](https://github.com/id774/naivebayes), except that this one can use Java to do the heavy lifting.
5
5
 
6
+ By default, Bae uses the pure-Ruby implementation (`Bae::Classifier`), but you can opt into the native classifier written in Java, which has to be required explicitly:
7
+
8
+ ```ruby
9
+ require 'bae/native_classifier'
10
+
11
+ classifier = ::Bae::NativeClassifier.new
12
+ ```
13
+
14
+
6
15
  ## Installation
7
16
 
8
17
  Add this line to your application's Gemfile:
@@ -50,10 +59,37 @@ classifier.classify("aaa bbb")
50
59
  #=> {"positive"=>0.8962655601659751, "negative"=>0.0663900414937759, "neutral"=>0.037344398340248955}
51
60
  ```
52
61
 
62
+ ### Saving State
63
+
64
+ You can save a snapshot of a trained classifier to disk and later load it back into memory, so you don't have to retrain it.
65
+
66
+ ```ruby
67
+ # From the example above...
68
+ classifier = ::Bae::Classifier.new
69
+ classifier.train("positive", {"aaa" => 0, "bbb" => 1})
70
+ classifier.train("negative", {"ccc" => 2, "ddd" => 3})
71
+
72
+ classifier.finish_training!
73
+
74
+ classifier.classify({"aaa" => 1, "bbb" => 1})
75
+ #=> {"positive" => 0.9411..., "negative" => 0.0588...}
76
+
77
+ # Now let's save it to disk
78
+ classifier.save_state("/tmp/some_state.json")
79
+
80
+ # Let's create a new classifier and load the state we just saved
81
+ classifier = ::Bae::Classifier.new
82
+ classifier.load_state("/tmp/some_state.json")
83
+
84
+ # Now we can classify without retraining
85
+ classifier.classify({"aaa" => 1, "bbb" => 1})
86
+ #=> {"positive" => 0.9411..., "negative" => 0.0588...}
87
+ ```
88
+
53
89
 
54
90
  ## Contributing
55
91
 
56
- 1. Fork it ( https://github.com/[my-github-username]/bae/fork )
92
+ 1. Fork it ( https://github.com/film42/bae/fork )
57
93
  2. Create your feature branch (`git checkout -b my-new-feature`)
58
94
  3. Commit your changes (`git commit -am 'Add some feature'`)
59
95
  4. Push to the branch (`git push origin my-new-feature`)
@@ -1,23 +1,180 @@
require 'json'

module Bae
  # Pure-Ruby naive Bayes classifier.
  #
  # Training data is accumulated in +frequency_table+, which maps each feature
  # (word) to an array of counts with one slot per label; +label_index+ maps a
  # label to its slot. Call #finish_training! (or #load_state) before #classify.
  class Classifier

    attr_accessor :frequency_table, :label_index, :label_index_sequence,
                  :label_instance_count, :total_terms

    def initialize
      @frequency_table = new_frequency_table
      @label_instance_count = zero_default_hash
      @label_index = zero_default_hash
      @label_index_sequence = -1 # start at -1 so 0 is first value
      @total_terms = 0.0
      @likelihoods = nil
      @priors = nil
    end

    # Precomputes per-label likelihoods and priors from the training data.
    def finish_training!
      calculate_likelihoods!
      calculate_priors!
    end

    # Adds one training document for +label+.
    #
    # @param label [String]
    # @param training_data [String, Hash] whitespace-separated text, or a Hash
    #   of word => frequency
    # @raise [RuntimeError] for any other kind of training data
    def train(label, training_data)
      case training_data
      when ::String
        train_from_string(label, training_data)
      when ::Hash
        train_from_hash(label, training_data)
      else
        fail 'Training data must either be a string or hash'
      end
    end

    # Trains +label+ with a whitespace-separated document; every occurrence of
    # a word adds one to that word's count for the label.
    def train_from_string(label, document)
      frequencies = document.split.each_with_object(::Hash.new(0)) do |word, counts|
        counts[word] += 1
      end

      train_from_hash(label, frequencies)
    end

    # Trains +label+ with a Hash of word => frequency (the frequency is added
    # to the word's count for the label).
    def train_from_hash(label, frequency_hash)
      # A label is only registered once a document contributes a feature to it.
      update_label_index(label) unless frequency_hash.empty?

      frequency_hash.each do |word, frequency|
        update_frequency_table(label, word, frequency)
      end

      # NOTE: despite its name, total_terms counts training documents; the name
      # is kept because it is part of the saved state format.
      @label_instance_count[label] += 1
      @total_terms += 1
    end

    # Classifies a document.
    #
    # @param data [String, Hash] whitespace-separated text, or word => frequency
    # @return [Hash] label => normalized posterior probability
    # @raise [RuntimeError] for unsupported input (the message says "Training
    #   data" for historical reasons; callers match on it) or when the
    #   classifier has not been finalized yet
    def classify(data)
      case data
      when ::String
        classify_from_string(data)
      when ::Hash
        classify_from_hash(data)
      else
        fail 'Training data must either be a string or hash'
      end
    end

    # Classifies a word => frequency Hash. Words are expanded into a document
    # string; since #classify_from_string only looks at distinct words, any
    # word with a positive frequency counts once and zero-frequency words are
    # ignored.
    def classify_from_hash(frequency_hash)
      document = frequency_hash.map { |word, frequency| (word + ' ') * frequency }.join

      classify_from_string(document)
    end

    # Classifies a whitespace-separated document using the precomputed
    # likelihoods: each distinct known word swaps its (1 - p) factor for p.
    def classify_from_string(document)
      unless @likelihoods && @priors
        fail 'Classifier has not been trained; call finish_training! or load_state first'
      end

      words = document.split.uniq
      vocab_size = frequency_table.size
      posterior = {}

      label_index.each do |label, index|
        denominator = (label_instance_count[label] + vocab_size).to_f
        likelihood = @likelihoods[label]

        words.each do |word|
          # fetch (not []) so unseen words neither trigger the default proc,
          # which would silently grow the vocabulary, nor fail on a table
          # that came from JSON without a default proc.
          row = frequency_table.fetch(word, nil)
          next if row.nil? || row.empty?

          laplace_word_likelihood = (row[index] + 1.0) / denominator
          likelihood *= laplace_word_likelihood / (1.0 - laplace_word_likelihood)
        end

        posterior[label] = @priors[label] * likelihood
      end

      normalize(posterior)
    end

    # Writes the raw training state (not the derived likelihoods/priors) to
    # +path+ as JSON. Restore it with #load_state.
    def save_state(path)
      state = {
        'frequency_table' => frequency_table,
        'label_instance_count' => label_instance_count,
        'label_index' => label_index,
        'label_index_sequence' => label_index_sequence,
        'total_terms' => total_terms
      }

      ::File.open(::File.expand_path(path), 'w') { |handle| handle.write(state.to_json) }
    end

    # Replaces this classifier's state with the JSON written by #save_state
    # and recomputes likelihoods/priors so it can classify immediately.
    #
    # @raise [RuntimeError] "Missing <key>" when a required key is absent
    def load_state(path)
      state = ::JSON.parse(::File.read(::File.expand_path(path)))

      %w(frequency_table label_instance_count label_index
         label_index_sequence total_terms).each do |key|
        fail "Missing #{key}" unless state[key]
      end

      # JSON.parse returns plain Hashes; re-attach the default procs so the
      # classifier can keep training (new words, new labels) after loading.
      @frequency_table = new_frequency_table.merge!(state['frequency_table'])
      @label_instance_count = zero_default_hash.merge!(state['label_instance_count'])
      @label_index = zero_default_hash.merge!(state['label_index'])
      @label_index_sequence = state['label_index_sequence']
      @total_terms = state['total_terms']

      finish_training!
    end

    private

    # Feature => per-label count array; unknown features start as [].
    def new_frequency_table
      ::Hash.new { |hash, feature| hash[feature] = [] }
    end

    # Hash whose missing keys are stored with value 0 on first access.
    def zero_default_hash
      ::Hash.new { |hash, key| hash[key] = 0 }
    end

    # For every label, multiplies (1 - p) over the whole vocabulary, where p is
    # the Laplace-smoothed likelihood of the word for that label.
    # NOTE(review): row[index] sums word frequencies while label_instance_count
    # counts documents, so p can exceed 1 (making 1 - p negative) when a word
    # occurs more often than its label has documents — confirm this is intended.
    def calculate_likelihoods!
      vocab_size = frequency_table.size

      @likelihoods = label_index.each_with_object({}) do |(label, index), likelihoods|
        denominator = (label_instance_count[label] + vocab_size).to_f

        likelihoods[label] = frequency_table.values.inject(1.0) do |product, row|
          laplace_word_likelihood = (row[index] + 1.0) / denominator
          product * (1.0 - laplace_word_likelihood)
        end
      end
    end

    # Prior of each label = its document count / total documents.
    def calculate_priors!
      @priors = label_instance_count.each_with_object({}) do |(label, count), priors|
        # to_f guards against integer division if total_terms was loaded as an Integer.
        priors[label] = count / total_terms.to_f
      end
    end

    def get_next_sequence_value
      @label_index_sequence += 1
    end

    # Scales the values of +posterior+ so they sum to 1.
    def normalize(posterior)
      total = posterior.values.inject(0.0, :+)

      posterior.each_with_object({}) do |(label, value), normalized|
        normalized[label] = value / total
      end
    end

    # Assigns the next slot to a previously unseen +label+ and gives every
    # existing feature a zero count in that slot.
    def update_label_index(label)
      return if label_index.key?(label)

      index = get_next_sequence_value
      label_index[label] = index

      frequency_table.each_value { |row| row[index] = 0 }
    end

    # Adds +frequency+ to +word+'s count for +label+.
    def update_frequency_table(label, word, frequency)
      row = frequency_table[word]
      index = label_index[label]

      # New words start as an empty row: give every known label a numeric slot
      # (without touching counts that already exist) before adding to it.
      label_index.size.times { |slot| row[slot] ||= 0 }
      row[index] += frequency
    end
  end
end
@@ -0,0 +1,26 @@
1
+ require 'java'
2
+ require ::File.join(::File.dirname(__FILE__), "..", "..", "target" , "bae.jar")
3
+
4
module Bae
  # Ruby facade over the Java classifier bundled in bae.jar. Every document is
  # wrapped in a ::Java::Bae::Document before it is handed to the Java side.
  class NativeClassifier

    attr_reader :internal_classifier

    def initialize
      @internal_classifier = ::Java::Bae::NaiveBayesClassifier.new
    end

    # Adds one training document for +label+; +feature+ (a String or a
    # word => frequency Hash in the specs) is passed to the Document constructor.
    def train(label, feature)
      internal_classifier.train(label, to_document(feature))
    end

    # Returns whatever the Java classifier reports for +feature+.
    def classify(feature)
      internal_classifier.classify(to_document(feature))
    end

    # Lets the Java classifier precompute its initial likelihoods.
    def finish_training!
      internal_classifier.calculateInitialLikelihoods
    end

    private

    # Wraps raw Ruby input in the Java Document type.
    def to_document(feature)
      ::Java::Bae::Document.new(feature)
    end
  end
end
data/lib/bae/version.rb CHANGED
@@ -1,3 +1,3 @@
module Bae
  # Version of the bae gem.
  VERSION = '0.0.8'
end
data/lib/bae.rb CHANGED
@@ -1,8 +1,6 @@
1
- require "bae/version"
2
- require 'java'
3
-
4
- require ::File.join(::File.dirname(__FILE__), "..", "target" , "bae.jar")
1
+ require 'json'
5
2
 
3
+ require "bae/version"
6
4
  require "bae/classifier"
7
5
 
8
6
  module Bae
@@ -1,10 +1,17 @@
1
1
  require 'spec_helper'
2
2
 
3
+ require 'bae/native_classifier'
4
+
3
5
  describe ::Bae::Classifier do
4
6
 
5
7
  subject { described_class.new }
6
8
 
7
- it "can classify from ruby to java with a hash document" do
9
+ let(:state_json) {
10
+ '{"frequency_table":{"aaa":[0,0],"bbb":[1,0],"ccc":[0,2],"ddd":[0,3]},"label_instance_count":{"positive":1,"negative":1},"label_index":{"positive":0,"negative":1},"label_index_sequence":1,"total_terms":2.0}'
11
+ }
12
+ let(:state) { ::JSON.parse(state_json) }
13
+
14
+ it "can classify a hash document" do
8
15
  subject.train("positive", {"aaa" => 0, "bbb" => 1})
9
16
  subject.train("negative", {"ccc" => 2, "ddd" => 3})
10
17
 
@@ -16,7 +23,7 @@ describe ::Bae::Classifier do
16
23
  expect(results["negative"]).to be_within(0.001).of(0.05882)
17
24
  end
18
25
 
19
- it "can classify from ruby to java with a string based document" do
26
+ it "can classify from a string based document" do
20
27
  subject.train("positive", "aaa aaa bbb");
21
28
  subject.train("negative", "ccc ccc ddd ddd");
22
29
  subject.train("neutral", "eee eee eee fff fff fff");
@@ -30,4 +37,46 @@ describe ::Bae::Classifier do
30
37
  expect(results["neutral"]).to be_within(0.001).of(0.03734)
31
38
  end
32
39
 
40
it "fails when you attempt to train or test anything other than a hash or string" do
  expected_message = 'Training data must either be a string or hash'

  # Training rejects anything that is not a String or a Hash...
  subject.train("positive", "aaa aaa bbb")
  expect { subject.train("a", 1337) }.to raise_error(expected_message)

  # ...and so does classification once training has been finished.
  subject.finish_training!
  subject.classify("aaa bbb")
  expect { subject.classify(1337) }.to raise_error(expected_message)
end
49
+
50
it "can save the classifier state" do
  subject.train("positive", {"aaa" => 0, "bbb" => 1})
  subject.train("negative", {"ccc" => 2, "ddd" => 3})

  subject.finish_training!

  temp_file = ::Tempfile.new('some_state')
  begin
    subject.save_state(temp_file.path)

    # save_state writes through its own file handle, so read the result back
    # by path instead of through temp_file's IO object.
    expect(::File.read(temp_file.path)).to eq(state_json)
  ensure
    # Close and unlink even when the expectation fails.
    temp_file.close!
  end
end
65
+
66
it "can correctly load a classifier state and correctly classify" do
  temp_file = ::Tempfile.new('some_state')
  begin
    temp_file.write(state_json)
    # load_state reads the file by path, so make sure the JSON is on disk.
    temp_file.flush

    subject.load_state(temp_file.path)

    results = subject.classify({"aaa" => 1, "bbb" => 1})

    expect(results["positive"]).to be_within(0.001).of(0.94117)
    expect(results["negative"]).to be_within(0.001).of(0.05882)
  ensure
    # Close and unlink even when an expectation fails.
    temp_file.close!
  end
end
81
+
33
82
  end
@@ -0,0 +1,33 @@
1
require 'spec_helper'
# lib/bae.rb only loads the pure-Ruby classifier, so the JRuby-backed one must
# be required explicitly; without this, running this spec file on its own
# raises NameError on ::Bae::NativeClassifier.
require 'bae/native_classifier'

describe ::Bae::NativeClassifier do

  subject { described_class.new }

  it "can classify a hash document" do
    subject.train("positive", {"aaa" => 0, "bbb" => 1})
    subject.train("negative", {"ccc" => 2, "ddd" => 3})

    subject.finish_training!

    results = subject.classify({"aaa" => 1, "bbb" => 1})

    expect(results["positive"]).to be_within(0.001).of(0.94117)
    expect(results["negative"]).to be_within(0.001).of(0.05882)
  end

  it "can classify from a string based document" do
    subject.train("positive", "aaa aaa bbb")
    subject.train("negative", "ccc ccc ddd ddd")
    subject.train("neutral", "eee eee eee fff fff fff")

    subject.finish_training!

    results = subject.classify("aaa bbb")

    expect(results["positive"]).to be_within(0.001).of(0.89626)
    expect(results["negative"]).to be_within(0.001).of(0.06639)
    expect(results["neutral"]).to be_within(0.001).of(0.03734)
  end

end
data/spec/spec_helper.rb CHANGED
@@ -1,5 +1,6 @@
1
1
  require 'bundler/setup'
2
2
  require 'bae'
3
+ require 'tempfile'
3
4
  require 'rspec'
4
5
 
5
6
  RSpec.configure do |c|
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: bae
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.0.7
4
+ version: 0.0.8
5
5
  platform: java
6
6
  authors:
7
7
  - Garrett Thornburg
8
8
  autorequire:
9
9
  bindir: bin
10
10
  cert_chain: []
11
- date: 2015-02-23 00:00:00.000000000 Z
11
+ date: 2015-02-25 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  requirement: !ruby/object:Gem::Requirement
@@ -68,8 +68,10 @@ files:
68
68
  - build.xml
69
69
  - lib/bae.rb
70
70
  - lib/bae/classifier.rb
71
+ - lib/bae/native_classifier.rb
71
72
  - lib/bae/version.rb
72
73
  - spec/lib/bae/classifier_spec.rb
74
+ - spec/lib/bae/native_classifier_spec.rb
73
75
  - spec/spec_helper.rb
74
76
  - src/main/java/bae/Document.java
75
77
  - src/main/java/bae/FrequencyTable.java
@@ -104,4 +106,5 @@ specification_version: 4
104
106
  summary: Multinomial naive bayes classifier with a kick of java
105
107
  test_files:
106
108
  - spec/lib/bae/classifier_spec.rb
109
+ - spec/lib/bae/native_classifier_spec.rb
107
110
  - spec/spec_helper.rb