bae 0.0.7-java → 0.0.8-java

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA1:
3
- metadata.gz: 71c761f2619746bfc4dd287a5afa5443a7bfe037
4
- data.tar.gz: 886257c2c8987fd8e95edbb63105ef87d7960de5
3
+ metadata.gz: cb626ff0b92f80f096cebc7248a64a8f47f02fda
4
+ data.tar.gz: 87c41e0571e1a31c303f9ab346eef119cec83e6f
5
5
  SHA512:
6
- metadata.gz: e02048771022daa4b61097ae500831a671f1a6ec3d8e9e48f235efd7ff9902be31678e503b4bd31b6cd0a0526d61c2868f1af39cb619c8c9fb120517242928cc
7
- data.tar.gz: c2c7073db8b7afeea466a5aa9db9f157101a9a8e9b38d87787281c204990eb3c4b3f75b64a9ecb2fd7a8dc5629bd807d101b0ea702c38d735ba52e2595bdc822
6
+ metadata.gz: 478e21b1c13f82037a5773cbb5b960ff98387b69cd024ece14c17eee7e9bf5784b658ea555e223e0affcc689ff6aa3fed9f59460ce586f283e06eeecc0d2291f
7
+ data.tar.gz: a3ceddd6c99d9ca8826f2142b2f383b9426ca0f9b28f88fbf966d954595b368c286551c94a6202d1add93709f196b4704bbe6ef5657b371dc1f41a2ac80317ed
data/README.md CHANGED
@@ -3,6 +3,15 @@ Bae
3
3
 
4
4
  Bae is a multinomial naive bayes classifier based on another gem ["naivebayes"](https://github.com/id774/naivebayes), only this one uses java to do the heavy lifting.
5
5
 
6
+ By default this will use the vanilla ruby implementation, but you can use the native classifier written in java.
7
+
8
+ ```ruby
9
+ require 'bae/native_classifier'
10
+
11
+ classifier = ::Bae::NativeClassifier.new
12
+ ```
13
+
14
+
6
15
  ## Installation
7
16
 
8
17
  Add this line to your application's Gemfile:
@@ -50,10 +59,37 @@ classifier.classify("aaa bbb")
50
59
  #=> {"positive"=>0.8962655601659751, "negative"=>0.0663900414937759, "neutral"=>0.037344398340248955}
51
60
  ```
52
61
 
62
+ ### Saving State
63
+
64
+ You can actually save a snapshot of the trained classifier to disk and load it into memory.
65
+
66
+ ```ruby
67
+ # From the example above...
68
+ classifier = ::Bae::Classifier.new
69
+ classifier.train("positive", {"aaa" => 0, "bbb" => 1})
70
+ classifier.train("negative", {"ccc" => 2, "ddd" => 3})
71
+
72
+ classifier.finish_training!
73
+
74
+ classifier.classify({"aaa" => 1, "bbb" => 1})
75
+ #=> {"positive" => 0.8767123287671234, "negative" => 0.12328767123287669}
76
+
77
+ # Now let's save it to disk
78
+ classifier.save_state("/tmp/some_state.json")
79
+
80
+ # Let's create a new classifier and load from the state we just saved
81
+ classifier = ::Bae::Classifier.new
82
+ classifier.load_state("/tmp/some_state.json")
83
+
84
+ # Now we can classify without retraining
85
+ classifier.classify({"aaa" => 1, "bbb" => 1})
86
+ #=> {"positive" => 0.8767123287671234, "negative" => 0.12328767123287669}
87
+ ```
88
+
53
89
 
54
90
  ## Contributing
55
91
 
56
- 1. Fork it ( https://github.com/[my-github-username]/bae/fork )
92
+ 1. Fork it ( https://github.com/film42/bae/fork )
57
93
  2. Create your feature branch (`git checkout -b my-new-feature`)
58
94
  3. Commit your changes (`git commit -am 'Add some feature'`)
59
95
  4. Push to the branch (`git push origin my-new-feature`)
@@ -1,23 +1,180 @@
1
1
  module Bae
2
2
  class Classifier
3
3
 
4
- attr_reader :internal_classifier
4
+ attr_accessor :frequency_table, :label_index, :label_index_sequence,
5
+ :label_instance_count, :total_terms
5
6
 
6
7
  def initialize
7
- @internal_classifier = ::Java::Bae::NaiveBayesClassifier.new
8
+ @frequency_table = ::Hash.new { |hash, feature| hash[feature] = [] }
9
+ @label_instance_count = ::Hash.new { |hash, label| hash[label] = 0 }
10
+ @label_index = ::Hash.new { |hash, label| hash[label] = 0 }
11
+ @label_index_sequence = -1 # start at -1 so 0 is first value
12
+ @total_terms = 0.0
8
13
  end
9
14
 
10
- def train(label, feature)
11
- internal_classifier.train(label, ::Java::Bae::Document.new(feature))
15
+ def finish_training!
16
+ calculate_likelihoods!
17
+ calculate_priors!
12
18
  end
13
19
 
14
- def classify(feature)
15
- internal_classifier.classify(::Java::Bae::Document.new(feature))
20
+ def train(label, training_data)
21
+ if training_data.is_a?(::String)
22
+ train_from_string(label, training_data)
23
+ elsif training_data.is_a?(::Hash)
24
+ train_from_hash(label, training_data)
25
+ else
26
+ fail 'Training data must either be a string or hash'
27
+ end
16
28
  end
17
29
 
18
- def finish_training!
19
- internal_classifier.calculateInitialLikelihoods()
30
+ def train_from_string(label, document)
31
+ words = document.split
32
+
33
+ words.each do |word|
34
+ update_label_index(label)
35
+ update_frequency_table(label, word, 1)
36
+ end
37
+ @label_instance_count[label] += 1
38
+ @total_terms += 1
39
+ end
40
+
41
+ def train_from_hash(label, frequency_hash)
42
+ frequency_hash.each do |word, frequency|
43
+ update_label_index(label)
44
+ update_frequency_table(label, word, frequency)
45
+ end
46
+ @label_instance_count[label] += 1
47
+ @total_terms += 1
48
+ end
49
+
50
+ def classify(data)
51
+ if data.is_a?(::String)
52
+ classify_from_string(data)
53
+ elsif data.is_a?(::Hash)
54
+ classify_from_hash(data)
55
+ else
56
+ fail 'Training data must either be a string or hash'
57
+ end
58
+ end
59
+
60
+ def classify_from_hash(frequency_hash)
61
+ document = frequency_hash.map{ |word, frequency| (word + ' ') * frequency }.join
62
+
63
+ classify_from_string(document)
64
+ end
65
+
66
+ def classify_from_string(document)
67
+ words = document.split.uniq
68
+ likelihoods = @likelihoods.dup
69
+ posterior = {}
70
+
71
+ vocab_size = frequency_table.keys.size
72
+
73
+ label_index.each do |label, index|
74
+ words.map do |word|
75
+ row = frequency_table[word]
76
+
77
+ unless row.empty?
78
+ laplace_word_likelihood = (row[index] + 1.0).to_f / (label_instance_count[label] + vocab_size).to_f
79
+ likelihoods[label] *= laplace_word_likelihood / (1.0 - laplace_word_likelihood)
80
+ end
81
+ end
82
+
83
+ posterior[label] = @priors[label] * likelihoods[label]
84
+ end
85
+
86
+ normalize(posterior)
87
+ end
88
+
89
+ def save_state(path)
90
+ state = {}
91
+ state['frequency_table'] = frequency_table
92
+ state['label_instance_count'] = label_instance_count
93
+ state['label_index'] = label_index
94
+ state['label_index_sequence'] = label_index_sequence
95
+ state['total_terms'] = total_terms
96
+
97
+ ::File.open(::File.expand_path(path), 'w') do |handle|
98
+ handle.write(state.to_json)
99
+ end
100
+ end
101
+
102
+ def load_state(path)
103
+ state = ::JSON.parse(::File.read(::File.expand_path(path)))
104
+
105
+ fail 'Missing frequency_table' unless state['frequency_table']
106
+ fail 'Missing label_instance_count' unless state['label_instance_count']
107
+ fail 'Missing label_index' unless state['label_index']
108
+ fail 'Missing label_index_sequence' unless state['label_index_sequence']
109
+ fail 'Missing total_terms' unless state['total_terms']
110
+
111
+ @frequency_table = state['frequency_table']
112
+ @label_instance_count = state['label_instance_count']
113
+ @label_index = state['label_index']
114
+ @label_index_sequence = state['label_index_sequence']
115
+ @total_terms = state['total_terms']
116
+
117
+ finish_training!
118
+ end
119
+
120
+ private
121
+
122
+ def calculate_likelihoods!
123
+ @likelihoods = label_index.inject({}) do |accumulator, (label, index)|
124
+ initial_likelihood = 1.0
125
+ vocab_size = frequency_table.keys.size
126
+
127
+ frequency_table.each do |feature, row|
128
+ laplace_word_likelihood = (row[index] + 1.0).to_f / (label_instance_count[label] + vocab_size).to_f
129
+ initial_likelihood *= (1.0 - laplace_word_likelihood)
130
+ end
131
+
132
+ accumulator[label] = initial_likelihood
133
+ accumulator
134
+ end
135
+ end
136
+
137
+ def calculate_priors!
138
+ @priors = label_instance_count.inject({}) do |hash, (label, count)|
139
+ hash[label] = count / total_terms
140
+ hash
141
+ end
20
142
  end
21
143
 
144
+ def get_next_sequence_value
145
+ @label_index_sequence += 1
146
+ end
147
+
148
+ def normalize(posterior)
149
+ sum = posterior.inject(0.0) { |accumulator, (key, value)| accumulator + value }
150
+
151
+ posterior.inject({}) do |accumulator, (key, value)|
152
+ accumulator[key] = value / sum
153
+ accumulator
154
+ end
155
+ end
156
+
157
+ def update_label_index(label)
158
+ unless label_index.keys.include?(label)
159
+ index = get_next_sequence_value
160
+ label_index[label] = index
161
+
162
+ frequency_table.each do |feature, value|
163
+ value[index] = 0
164
+ end
165
+ end
166
+ end
167
+
168
+ def update_frequency_table(label, word, frequency)
169
+ row = frequency_table[word]
170
+ index = label_index[label]
171
+
172
+ if row[index]
173
+ row[index] += frequency
174
+ else
175
+ row[0..1] = label_index.keys.map { |label| 0 }
176
+ row[index] = frequency
177
+ end
178
+ end
22
179
  end
23
180
  end
@@ -0,0 +1,26 @@
1
+ require 'java'
2
+ require ::File.join(::File.dirname(__FILE__), "..", "..", "target" , "bae.jar")
3
+
4
+ module Bae
5
+ class NativeClassifier
6
+
7
+ attr_reader :internal_classifier
8
+
9
+ def initialize
10
+ @internal_classifier = ::Java::Bae::NaiveBayesClassifier.new
11
+ end
12
+
13
+ def train(label, feature)
14
+ internal_classifier.train(label, ::Java::Bae::Document.new(feature))
15
+ end
16
+
17
+ def classify(feature)
18
+ internal_classifier.classify(::Java::Bae::Document.new(feature))
19
+ end
20
+
21
+ def finish_training!
22
+ internal_classifier.calculateInitialLikelihoods()
23
+ end
24
+
25
+ end
26
+ end
data/lib/bae/version.rb CHANGED
@@ -1,3 +1,3 @@
1
1
  module Bae
2
- VERSION = "0.0.7"
2
+ VERSION = "0.0.8"
3
3
  end
data/lib/bae.rb CHANGED
@@ -1,8 +1,6 @@
1
- require "bae/version"
2
- require 'java'
3
-
4
- require ::File.join(::File.dirname(__FILE__), "..", "target" , "bae.jar")
1
+ require 'json'
5
2
 
3
+ require "bae/version"
6
4
  require "bae/classifier"
7
5
 
8
6
  module Bae
@@ -1,10 +1,17 @@
1
1
  require 'spec_helper'
2
2
 
3
+ require 'bae/native_classifier'
4
+
3
5
  describe ::Bae::Classifier do
4
6
 
5
7
  subject { described_class.new }
6
8
 
7
- it "can classify from ruby to java with a hash document" do
9
+ let(:state_json) {
10
+ '{"frequency_table":{"aaa":[0,0],"bbb":[1,0],"ccc":[0,2],"ddd":[0,3]},"label_instance_count":{"positive":1,"negative":1},"label_index":{"positive":0,"negative":1},"label_index_sequence":1,"total_terms":2.0}'
11
+ }
12
+ let(:state) { ::JSON.parse(state_json) }
13
+
14
+ it "can classify a hash document" do
8
15
  subject.train("positive", {"aaa" => 0, "bbb" => 1})
9
16
  subject.train("negative", {"ccc" => 2, "ddd" => 3})
10
17
 
@@ -16,7 +23,7 @@ describe ::Bae::Classifier do
16
23
  expect(results["negative"]).to be_within(0.001).of(0.05882)
17
24
  end
18
25
 
19
- it "can classify from ruby to java with a string based document" do
26
+ it "can classify from a string based document" do
20
27
  subject.train("positive", "aaa aaa bbb");
21
28
  subject.train("negative", "ccc ccc ddd ddd");
22
29
  subject.train("neutral", "eee eee eee fff fff fff");
@@ -30,4 +37,46 @@ describe ::Bae::Classifier do
30
37
  expect(results["neutral"]).to be_within(0.001).of(0.03734)
31
38
  end
32
39
 
40
+ it "fails when you attempt to train or test anything other than a hash or string" do
41
+ subject.train("positive", "aaa aaa bbb");
42
+ expect{ subject.train("a", 1337) }.to raise_error 'Training data must either be a string or hash'
43
+
44
+ subject.finish_training!
45
+
46
+ subject.classify("aaa bbb")
47
+ expect{ subject.classify(1337) }.to raise_error 'Training data must either be a string or hash'
48
+ end
49
+
50
+ it "can save the classifier state" do
51
+ subject.train("positive", {"aaa" => 0, "bbb" => 1})
52
+ subject.train("negative", {"ccc" => 2, "ddd" => 3})
53
+
54
+ subject.finish_training!
55
+
56
+ temp_file = ::Tempfile.new('some_state')
57
+ subject.save_state(temp_file.path)
58
+
59
+ temp_file.rewind
60
+ expect(temp_file.read).to eq(state_json)
61
+
62
+ temp_file.close
63
+ temp_file.unlink
64
+ end
65
+
66
+ it "can correctly load a classifier state and correctly classify" do
67
+ temp_file = ::Tempfile.new('some_state')
68
+ temp_file.write(state_json)
69
+ temp_file.rewind
70
+
71
+ subject.load_state(temp_file.path)
72
+
73
+ results = subject.classify({"aaa" => 1, "bbb" => 1})
74
+
75
+ expect(results["positive"]).to be_within(0.001).of(0.94117)
76
+ expect(results["negative"]).to be_within(0.001).of(0.05882)
77
+
78
+ temp_file.close
79
+ temp_file.unlink
80
+ end
81
+
33
82
  end
@@ -0,0 +1,33 @@
1
+ require 'spec_helper'
2
+
3
+ describe ::Bae::NativeClassifier do
4
+
5
+ subject { described_class.new }
6
+
7
+ it "can classify a hash document" do
8
+ subject.train("positive", {"aaa" => 0, "bbb" => 1})
9
+ subject.train("negative", {"ccc" => 2, "ddd" => 3})
10
+
11
+ subject.finish_training!
12
+
13
+ results = subject.classify({"aaa" => 1, "bbb" => 1})
14
+
15
+ expect(results["positive"]).to be_within(0.001).of(0.94117)
16
+ expect(results["negative"]).to be_within(0.001).of(0.05882)
17
+ end
18
+
19
+ it "can classify from a string based document" do
20
+ subject.train("positive", "aaa aaa bbb");
21
+ subject.train("negative", "ccc ccc ddd ddd");
22
+ subject.train("neutral", "eee eee eee fff fff fff");
23
+
24
+ subject.finish_training!
25
+
26
+ results = subject.classify("aaa bbb")
27
+
28
+ expect(results["positive"]).to be_within(0.001).of(0.89626)
29
+ expect(results["negative"]).to be_within(0.001).of(0.06639)
30
+ expect(results["neutral"]).to be_within(0.001).of(0.03734)
31
+ end
32
+
33
+ end
data/spec/spec_helper.rb CHANGED
@@ -1,5 +1,6 @@
1
1
  require 'bundler/setup'
2
2
  require 'bae'
3
+ require 'tempfile'
3
4
  require 'rspec'
4
5
 
5
6
  RSpec.configure do |c|
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: bae
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.0.7
4
+ version: 0.0.8
5
5
  platform: java
6
6
  authors:
7
7
  - Garrett Thornburg
8
8
  autorequire:
9
9
  bindir: bin
10
10
  cert_chain: []
11
- date: 2015-02-23 00:00:00.000000000 Z
11
+ date: 2015-02-25 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  requirement: !ruby/object:Gem::Requirement
@@ -68,8 +68,10 @@ files:
68
68
  - build.xml
69
69
  - lib/bae.rb
70
70
  - lib/bae/classifier.rb
71
+ - lib/bae/native_classifier.rb
71
72
  - lib/bae/version.rb
72
73
  - spec/lib/bae/classifier_spec.rb
74
+ - spec/lib/bae/native_classifier_spec.rb
73
75
  - spec/spec_helper.rb
74
76
  - src/main/java/bae/Document.java
75
77
  - src/main/java/bae/FrequencyTable.java
@@ -104,4 +106,5 @@ specification_version: 4
104
106
  summary: Multinomial naive bayes classifier with a kick of java
105
107
  test_files:
106
108
  - spec/lib/bae/classifier_spec.rb
109
+ - spec/lib/bae/native_classifier_spec.rb
107
110
  - spec/spec_helper.rb