classifier-reborn 2.0.4 → 2.3.0

This diff shows the changes between the publicly released contents of these two package versions, as published to their registry. It is provided for informational purposes only.
Files changed (36)
  1. checksums.yaml +5 -5
  2. data/LICENSE +74 -1
  3. data/README.markdown +57 -207
  4. data/data/stopwords/ar +104 -0
  5. data/data/stopwords/bn +362 -0
  6. data/data/stopwords/hi +97 -0
  7. data/data/stopwords/ja +43 -0
  8. data/data/stopwords/ru +420 -0
  9. data/data/stopwords/tr +199 -30
  10. data/data/stopwords/vi +647 -0
  11. data/data/stopwords/zh +125 -0
  12. data/lib/classifier-reborn/backends/bayes_memory_backend.rb +77 -0
  13. data/lib/classifier-reborn/backends/bayes_redis_backend.rb +109 -0
  14. data/lib/classifier-reborn/backends/no_redis_error.rb +14 -0
  15. data/lib/classifier-reborn/bayes.rb +141 -65
  16. data/lib/classifier-reborn/category_namer.rb +6 -4
  17. data/lib/classifier-reborn/extensions/hasher.rb +22 -39
  18. data/lib/classifier-reborn/extensions/token_filter/stemmer.rb +24 -0
  19. data/lib/classifier-reborn/extensions/token_filter/stopword.rb +48 -0
  20. data/lib/classifier-reborn/extensions/token_filter/symbol.rb +20 -0
  21. data/lib/classifier-reborn/extensions/tokenizer/token.rb +36 -0
  22. data/lib/classifier-reborn/extensions/tokenizer/whitespace.rb +28 -0
  23. data/lib/classifier-reborn/extensions/vector.rb +35 -28
  24. data/lib/classifier-reborn/extensions/vector_serialize.rb +10 -10
  25. data/lib/classifier-reborn/extensions/zero_vector.rb +7 -0
  26. data/lib/classifier-reborn/lsi/cached_content_node.rb +6 -5
  27. data/lib/classifier-reborn/lsi/content_node.rb +35 -25
  28. data/lib/classifier-reborn/lsi/summarizer.rb +7 -5
  29. data/lib/classifier-reborn/lsi/word_list.rb +5 -6
  30. data/lib/classifier-reborn/lsi.rb +166 -94
  31. data/lib/classifier-reborn/validators/classifier_validator.rb +170 -0
  32. data/lib/classifier-reborn/version.rb +3 -1
  33. data/lib/classifier-reborn.rb +12 -1
  34. metadata +98 -17
  35. data/bin/bayes.rb +0 -36
  36. data/bin/summarize.rb +0 -16
data/lib/classifier-reborn/lsi.rb

@@ -1,30 +1,45 @@
+# frozen_string_literal: true
+
 # Author:: David Fayram (mailto:dfayram@lensmen.net)
 # Copyright:: Copyright (c) 2005 David Fayram II
 # License:: LGPL
 
+# Try to load Numo first - it's the most current and the most well-supported.
+# Fall back to GSL.
+# Fall back to native vector.
 begin
-  raise LoadError if ENV['NATIVE_VECTOR'] == "true" # to test the native vector class, try `rake test NATIVE_VECTOR=true`
-
-  require 'gsl' # requires http://rb-gsl.rubyforge.org/
-  require_relative 'extensions/vector_serialize'
-  $GSL = true
+  raise LoadError if ENV['NATIVE_VECTOR'] == 'true' # to test the native vector class, try `rake test NATIVE_VECTOR=true`
+  raise LoadError if ENV['GSL'] == 'true' # to test with gsl, try `rake test GSL=true`
 
+  require 'numo/narray' # https://ruby-numo.github.io/narray/
+  require 'numo/linalg' # https://ruby-numo.github.io/linalg/
+  $SVD = :numo
 rescue LoadError
-  require_relative 'extensions/vector'
+  begin
+    raise LoadError if ENV['NATIVE_VECTOR'] == 'true' # to test the native vector class, try `rake test NATIVE_VECTOR=true`
+
+    require 'gsl' # requires https://github.com/SciRuby/rb-gsl
+    require_relative 'extensions/vector_serialize'
+    $SVD = :gsl
+  rescue LoadError
+    $SVD = :ruby
+    require_relative 'extensions/vector'
+    require_relative 'extensions/zero_vector'
+  end
 end
 
 require_relative 'lsi/word_list'
 require_relative 'lsi/content_node'
 require_relative 'lsi/cached_content_node'
 require_relative 'lsi/summarizer'
+require_relative 'extensions/token_filter/stopword'
+require_relative 'extensions/token_filter/symbol'
 
 module ClassifierReborn
-
   # This class implements a Latent Semantic Indexer, which can search, classify and cluster
   # data based on underlying semantic relations. For more information on the algorithms used,
   # please consult Wikipedia[http://en.wikipedia.org/wiki/Latent_Semantic_Indexing].
   class LSI
-
     attr_reader :word_list, :cache_node_vectors
     attr_accessor :auto_rebuild
 
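The load order above prefers Numo, falls back to GSL, and finally falls back to the bundled pure-Ruby vector code, recording the winner in the global $SVD flag (:numo, :gsl, or :ruby). A minimal sketch of checking and forcing the backend, using only the ENV toggles named in the comments above:

    # Force a backend before requiring the gem:
    #   NATIVE_VECTOR=true -> pure-Ruby Matrix fallback
    #   GSL=true           -> skip Numo and try GSL
    require 'classifier-reborn'

    puts $SVD # => :numo, :gsl, or :ruby, depending on what loaded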
@@ -36,12 +51,17 @@ module ClassifierReborn
     #
     def initialize(options = {})
       @auto_rebuild = options[:auto_rebuild] != false
-      @word_list, @items = WordList.new, {}
-      @version, @built_at_version = 0, -1
+      @word_list = WordList.new
+      @items = {}
+      @version = 0
+      @built_at_version = -1
       @language = options[:language] || 'en'
-      if @cache_node_vectors = options[:cache_node_vectors]
-        extend CachedContentNode::InstanceMethods
-      end
+      @token_filters = [
+        TokenFilter::Stopword,
+        TokenFilter::Symbol
+      ]
+      TokenFilter::Stopword.language = @language
+      extend CachedContentNode::InstanceMethods if @cache_node_vectors = options[:cache_node_vectors]
     end
 
     # Returns true if the index needs to be rebuilt. The index needs
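The constructor now installs the default token filter chain (stopword removal, then symbol stripping) and forwards the language option to the stopword filter. A usage sketch covering the options read above; the values shown are illustrative:

    require 'classifier-reborn'

    lsi = ClassifierReborn::LSI.new(
      auto_rebuild: false,       # default true: rebuild the index on every add_item
      language: 'en',            # default 'en': sets TokenFilter::Stopword.language
      cache_node_vectors: true   # default nil: extends the instance with CachedContentNode::InstanceMethods
    )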
@@ -64,39 +84,45 @@ module ClassifierReborn
     #   ar = ActiveRecordObject.find( :all )
     #   lsi.add_item ar, *ar.categories { |x| ar.content }
     #
-    def add_item( item, *categories, &block )
-      clean_word_hash = Hasher.clean_word_hash((block ? block.call(item) : item.to_s), @language)
-      @items[item] = if @cache_node_vectors
-        CachedContentNode.new(clean_word_hash, *categories)
+    def add_item(item, *categories, &block)
+      clean_word_hash = Hasher.word_hash((block ? yield(item) : item.to_s),
+                                         token_filters: @token_filters)
+      if clean_word_hash.empty?
+        puts "Input: '#{item}' is entirely stopwords or words with 2 or fewer characters. Classifier-Reborn cannot handle this document properly."
       else
-        ContentNode.new(clean_word_hash, *categories)
+        @items[item] = if @cache_node_vectors
+                         CachedContentNode.new(clean_word_hash, *categories)
+                       else
+                         ContentNode.new(clean_word_hash, *categories)
+                       end
+        @version += 1
+        build_index if @auto_rebuild
       end
-      @version += 1
-      build_index if @auto_rebuild
     end
 
     # A less flexible shorthand for add_item that assumes
     # you are passing in a string with no categorries. item
     # will be duck typed via to_s .
     #
-    def <<( item )
-      add_item item
+    def <<(item)
+      add_item(item)
     end
 
-    # Returns the categories for a given indexed items. You are free to add and remove
+    # Returns categories for a given indexed item. You are free to add and remove
     # items from this as you see fit. It does not invalide an index to change its categories.
     def categories_for(item)
       return [] unless @items[item]
-      return @items[item].categories
+
+      @items[item].categories
     end
 
     # Removes an item from the database, if it is indexed.
     #
-    def remove_item( item )
-      if @items.key? item
-        @items.delete item
-        @version += 1
-      end
+    def remove_item(item)
+      return unless @items.key? item
+
+      @items.delete item
+      @version += 1
     end
 
     # Returns an array of items that are indexed.
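add_item now refuses input that filters down to an empty word hash instead of indexing an unusable node. A sketch of the common calls; the category names are illustrative:

    lsi = ClassifierReborn::LSI.new

    lsi.add_item 'This text deals with dogs. Dogs.', :dog
    lsi.add_item 'This text involves cats. Cats!', :cat
    lsi << 'The << shorthand indexes a bare string with no categories.'

    # All-stopword input is rejected with a warning, per the guard above.
    lsi.add_item 'the an of'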
@@ -118,30 +144,43 @@ module ClassifierReborn
     # cutoff parameter tells the indexer how many of these values to keep.
     # A value of 1 for cutoff means that no semantic analysis will take place,
     # turning the LSI class into a simple vector search engine.
-    def build_index( cutoff=0.75 )
+    def build_index(cutoff = 0.75)
       return unless needs_rebuild?
+
       make_word_list
 
       doc_list = @items.values
-      tda = doc_list.collect { |node| node.raw_vector_with( @word_list ) }
+      tda = doc_list.collect { |node| node.raw_vector_with(@word_list) }
 
-      if $GSL
-        tdm = GSL::Matrix.alloc(*tda).trans
-        ntdm = build_reduced_matrix(tdm, cutoff)
+      if $SVD == :numo
+        tdm = Numo::NArray.asarray(tda.map(&:to_a)).transpose
+        ntdm = numo_build_reduced_matrix(tdm, cutoff)
 
-        ntdm.size[1].times do |col|
-          vec = GSL::Vector.alloc( ntdm.column(col) ).row
-          doc_list[col].lsi_vector = vec
-          doc_list[col].lsi_norm = vec.normalize
-        end
+        ntdm.each_over_axis(1).with_index do |col_vec, i|
+          doc_list[i].lsi_vector = col_vec
+          doc_list[i].lsi_norm = col_vec / Numo::Linalg.norm(col_vec)
+        end
+      elsif $SVD == :gsl
+        tdm = GSL::Matrix.alloc(*tda).trans
+        ntdm = build_reduced_matrix(tdm, cutoff)
+
+        ntdm.size[1].times do |col|
+          vec = GSL::Vector.alloc(ntdm.column(col)).row
+          doc_list[col].lsi_vector = vec
+          doc_list[col].lsi_norm = vec.normalize
+        end
       else
-        tdm = Matrix.rows(tda).trans
-        ntdm = build_reduced_matrix(tdm, cutoff)
+        tdm = Matrix.rows(tda).trans
+        ntdm = build_reduced_matrix(tdm, cutoff)
 
-        ntdm.row_size.times do |col|
-          doc_list[col].lsi_vector = ntdm.column(col) if doc_list[col]
-          doc_list[col].lsi_norm = ntdm.column(col).normalize if doc_list[col]
-        end
+        ntdm.column_size.times do |col|
+          doc_list[col].lsi_vector = ntdm.column(col) if doc_list[col]
+          if ntdm.column(col).zero?
+            doc_list[col].lsi_norm = ntdm.column(col) if doc_list[col]
+          else
+            doc_list[col].lsi_norm = ntdm.column(col).normalize if doc_list[col]
+          end
+        end
       end
 
       @built_at_version = @version
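Because the SVD runs over the whole term-document matrix, bulk loads are cheaper with auto_rebuild off and one explicit build at the end. A sketch, where documents is a hypothetical collection of [text, category] pairs:

    lsi = ClassifierReborn::LSI.new(auto_rebuild: false)
    documents.each { |text, category| lsi.add_item(text, category) }

    # A cutoff of 1 keeps every singular value, i.e. plain vector search.
    lsi.build_index(0.75)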
@@ -155,13 +194,13 @@ module ClassifierReborn
     # your dataset's general content. For example, if you were to use categorize on the
     # results of this data, you could gather information on what your dataset is generally
     # about.
-    def highest_relative_content( max_chunks=10 )
-      return [] if needs_rebuild?
+    def highest_relative_content(max_chunks = 10)
+      return [] if needs_rebuild?
 
-      avg_density = Hash.new
-      @items.each_key { |item| avg_density[item] = proximity_array_for_content(item).inject(0.0) { |x,y| x + y[1]} }
+      avg_density = {}
+      @items.each_key { |item| avg_density[item] = proximity_array_for_content(item).inject(0.0) { |x, y| x + y[1] } }
 
-      avg_density.keys.sort_by { |x| avg_density[x] }.reverse[0..max_chunks-1].map
+      avg_density.keys.sort_by { |x| avg_density[x] }.reverse[0..max_chunks - 1].map
     end
 
     # This function is the primitive that find_related and classify
@@ -176,17 +215,19 @@ module ClassifierReborn
     # The parameter doc is the content to compare. If that content is not
     # indexed, you can pass an optional block to define how to create the
     # text data. See add_item for examples of how this works.
-    def proximity_array_for_content( doc, &block )
+    def proximity_array_for_content(doc, &block)
       return [] if needs_rebuild?
 
-      content_node = node_for_content( doc, &block )
+      content_node = node_for_content(doc, &block)
       result =
         @items.keys.collect do |item|
-          if $GSL
-            val = content_node.search_vector * @items[item].transposed_search_vector
-          else
-            val = (Matrix[content_node.search_vector] * @items[item].search_vector)[0]
-          end
+          val = if $SVD == :numo
+                  content_node.search_vector.dot(@items[item].transposed_search_vector)
+                elsif $SVD == :gsl
+                  content_node.search_vector * @items[item].transposed_search_vector
+                else
+                  (Matrix[content_node.search_vector] * @items[item].search_vector)[0]
+                end
           [item, val]
         end
       result.sort_by { |x| x[1] }.reverse
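The return value is a list of [item, score] pairs sorted by score, highest first; for example:

    pairs = lsi.proximity_array_for_content('dogs')
    # => [[item, score], ...] in descending score order
    best_match, best_score = pairs.first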
@@ -197,17 +238,28 @@ module ClassifierReborn
     # calculated vectors instead of their full versions. This is useful when
     # you're trying to perform operations on content that is much smaller than
     # the text you're working with. search uses this primitive.
-    def proximity_norms_for_content( doc, &block )
+    def proximity_norms_for_content(doc, &block)
       return [] if needs_rebuild?
 
-      content_node = node_for_content( doc, &block )
+      content_node = node_for_content(doc, &block)
+      if ($SVD == :gsl && content_node.raw_norm.isnan?.all?) ||
+         ($SVD == :numo && content_node.raw_norm.isnan.all?)
+        puts "There are no documents that are similar to #{doc}"
+      else
+        content_node_norms(content_node)
+      end
+    end
+
+    def content_node_norms(content_node)
       result =
         @items.keys.collect do |item|
-          if $GSL
-            val = content_node.search_norm * @items[item].search_norm.col
-          else
-            val = (Matrix[content_node.search_norm] * @items[item].search_norm)[0]
-          end
+          val = if $SVD == :numo
+                  content_node.search_norm.dot(@items[item].search_norm)
+                elsif $SVD == :gsl
+                  content_node.search_norm * @items[item].search_norm.col
+                else
+                  (Matrix[content_node.search_norm] * @items[item].search_norm)[0]
+                end
           [item, val]
         end
       result.sort_by { |x| x[1] }.reverse
@@ -220,11 +272,14 @@ module ClassifierReborn
     #
     # While this may seem backwards compared to the other functions that LSI supports,
     # it is actually the same algorithm, just applied on a smaller document.
-    def search( string, max_nearest=3 )
+    def search(string, max_nearest = 3)
       return [] if needs_rebuild?
-      carry = proximity_norms_for_content( string )
-      result = carry.collect { |x| x[0] }
-      return result[0..max_nearest-1]
+
+      carry = proximity_norms_for_content(string)
+      unless carry.nil?
+        result = carry.collect { |x| x[0] }
+        result[0..max_nearest - 1]
+      end
     end
 
     # This function takes content and finds other documents
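Since proximity_norms_for_content can now come back empty-handed (a query whose norm is all NaN), search returns nil rather than an Array in that case, so callers should guard. For example:

    results = lsi.search('dog', 3) # up to 3 nearest indexed items, or nil
    puts results ? results.inspect : 'no similar documents'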
@@ -236,21 +291,21 @@ module ClassifierReborn
     # This is particularly useful for identifing clusters in your document space.
     # For example you may want to identify several "What's Related" items for weblog
     # articles, or find paragraphs that relate to each other in an essay.
-    def find_related( doc, max_nearest=3, &block )
+    def find_related(doc, max_nearest = 3, &block)
       carry =
-        proximity_array_for_content( doc, &block ).reject { |pair| pair[0].eql? doc }
+        proximity_array_for_content(doc, &block).reject { |pair| pair[0].eql? doc }
       result = carry.collect { |x| x[0] }
-      return result[0..max_nearest-1]
+      result[0..max_nearest - 1]
     end
 
     # Return the most obvious category with the score
-    def classify_with_score( doc, cutoff=0.30, &block)
-      return scored_categories(doc, cutoff, &block).last
+    def classify_with_score(doc, cutoff = 0.30, &block)
+      scored_categories(doc, cutoff, &block).last
     end
 
     # Return the most obvious category without the score
-    def classify( doc, cutoff=0.30, &block )
-      return scored_categories(doc, cutoff, &block).last.first
+    def classify(doc, cutoff = 0.30, &block)
+      scored_categories(doc, cutoff, &block).last.first
     end
 
     # This function uses a voting system to categorize documents, based on
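Both classify helpers delegate to scored_categories below; a sketch with an illustrative query:

    lsi.classify('This text is about dogs!')            # => best category
    lsi.classify_with_score('This text is about dogs!') # => [category, score]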
@@ -262,10 +317,10 @@ module ClassifierReborn
     # text. A cutoff of 1 means that every document in the index votes on
     # what category the document is in. This may not always make sense.
     #
-    def scored_categories( doc, cutoff=0.30, &block )
+    def scored_categories(doc, cutoff = 0.30, &block)
       icutoff = (@items.size * cutoff).round
-      carry = proximity_array_for_content( doc, &block )
-      carry = carry[0..icutoff-1]
+      carry = proximity_array_for_content(doc, &block)
+      carry = carry[0..icutoff - 1]
       votes = Hash.new(0.0)
       carry.each do |pair|
         @items[pair[0]].categories.each do |category|
@@ -273,56 +328,73 @@ module ClassifierReborn
         end
       end
 
-      return votes.sort_by { |_, score| score }
+      votes.sort_by { |_, score| score }
     end
 
     # Prototype, only works on indexed documents.
     # I have no clue if this is going to work, but in theory
     # it's supposed to.
-    def highest_ranked_stems( doc, count=3 )
-      raise "Requested stem ranking on non-indexed content!" unless @items[doc]
-      arr = node_for_content(doc).lsi_vector.to_a
-      top_n = arr.sort.reverse[0..count-1]
-      return top_n.collect { |x| @word_list.word_for_index(arr.index(x))}
+    def highest_ranked_stems(doc, count = 3)
+      raise 'Requested stem ranking on non-indexed content!' unless @items[doc]
+
+      content_vector_array = node_for_content(doc).lsi_vector.to_a
+      top_n = content_vector_array.sort.reverse[0..count - 1]
+      top_n.collect { |x| @word_list.word_for_index(content_vector_array.index(x)) }
+    end
+
+    def reset
+      initialize(auto_rebuild: @auto_rebuild, cache_node_vectors: @cache_node_vectors)
     end
 
     private
-    def build_reduced_matrix( matrix, cutoff=0.75 )
+
+    def build_reduced_matrix(matrix, cutoff = 0.75)
       # TODO: Check that M>=N on these dimensions! Transpose helps assure this
       u, v, s = matrix.SV_decomp
-
       # TODO: Better than 75% term, please. :\
       s_cutoff = s.sort.reverse[(s.size * cutoff).round - 1]
       s.size.times do |ord|
         s[ord] = 0.0 if s[ord] < s_cutoff
       end
       # Reconstruct the term document matrix, only with reduced rank
-      u * ($GSL ? GSL::Matrix : ::Matrix).diag( s ) * v.trans
+      u * ($SVD == :gsl ? GSL::Matrix : ::Matrix).diag(s) * v.trans
+    end
+
+    def numo_build_reduced_matrix(matrix, cutoff = 0.75)
+      s, u, vt = Numo::Linalg.svd(matrix, driver: 'svd', job: 'S')
+
+      # TODO: Better than 75% term (as above)
+      s_cutoff = s.sort.reverse[(s.size * cutoff).round - 1]
+      s.size.times do |ord|
+        s[ord] = 0.0 if s[ord] < s_cutoff
+      end
+
+      # Reconstruct the term document matrix, only with reduced rank
+      u.dot(::Numo::DFloat.eye(s.size) * s).dot(vt)
     end
 
     def node_for_content(item, &block)
       if @items[item]
         return @items[item]
       else
-        clean_word_hash = Hasher.clean_word_hash((block ? block.call(item) : item.to_s), @language)
+        clean_word_hash = Hasher.word_hash((block ? yield(item) : item.to_s),
+                                           token_filters: @token_filters)
 
-        cn = ContentNode.new(clean_word_hash, &block) # make the node and extract the data
+        content_node = ContentNode.new(clean_word_hash, &block) # make the node and extract the data
 
         unless needs_rebuild?
-          cn.raw_vector_with( @word_list ) # make the lsi raw and norm vectors
+          content_node.raw_vector_with(@word_list) # make the lsi raw and norm vectors
         end
       end
 
-      return cn
+      content_node
     end
 
     def make_word_list
       @word_list = WordList.new
       @items.each_value do |node|
-        node.word_hash.each_key { |key| @word_list.add_word key }
+        node.word_hash.each_key { |key| @word_list.add_word(key) }
       end
     end
-
   end
 end
-
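In numo_build_reduced_matrix, Numo::Linalg.svd returns the singular values first, and eye(n) * s broadcasts s across the identity matrix's rows to build diag(s) for the reconstruction. The same rank reduction on a toy matrix, as a standalone sketch:

    require 'numo/narray'
    require 'numo/linalg'

    # 3 terms x 2 documents, like the transposed term-document matrix above.
    tdm = Numo::DFloat[[1, 0], [1, 1], [0, 1]]

    s, u, vt = Numo::Linalg.svd(tdm, driver: 'svd', job: 'S')

    # Zero out singular values below the cutoff, as the method above does.
    s_cutoff = s.sort.reverse[(s.size * 0.75).round - 1]
    s.size.times { |i| s[i] = 0.0 if s[i] < s_cutoff }

    # Reconstruct the matrix at reduced rank.
    ntdm = u.dot(Numo::DFloat.eye(s.size) * s).dot(vt)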
data/lib/classifier-reborn/validators/classifier_validator.rb

@@ -0,0 +1,170 @@
+# frozen_string_literal: true
+
+module ClassifierReborn
+  module ClassifierValidator
+    module_function
+
+    def cross_validate(classifier, sample_data, fold = 10, *options)
+      classifier = ClassifierReborn.const_get(classifier).new(options) if classifier.is_a?(String)
+      sample_data.shuffle!
+      partition_size = sample_data.length / fold
+      partitioned_data = sample_data.each_slice(partition_size)
+      conf_mats = []
+      fold.times do |i|
+        training_data = partitioned_data.take(fold)
+        test_data = training_data.slice!(i)
+        conf_mats << validate(classifier, training_data.flatten!(1), test_data)
+      end
+      classifier.reset
+      generate_report(conf_mats)
+    end
+
+    def validate(classifier, training_data, test_data, *options)
+      classifier = ClassifierReborn.const_get(classifier).new(options) if classifier.is_a?(String)
+      classifier.reset
+      training_data.each do |rec|
+        classifier.train(rec.first, rec.last)
+      end
+      evaluate(classifier, test_data)
+    end
+
+    def evaluate(classifier, test_data)
+      conf_mat = empty_conf_mat(classifier.categories.sort)
+      test_data.each do |rec|
+        actual = rec.first.tr('_', ' ').capitalize
+        predicted = classifier.classify(rec.last)
+        conf_mat[actual][predicted] += 1 unless predicted.nil?
+      end
+      conf_mat
+    end
+
+    def generate_report(*conf_mats)
+      conf_mats.flatten!
+      accumulated_conf_mat = conf_mats.length == 1 ? conf_mats.first : empty_conf_mat(conf_mats.first.keys.sort)
+      header = 'Run Total Correct Incorrect Accuracy'
+      puts
+      puts ' Run Report '.center(header.length, '-')
+      puts header
+      puts '-' * header.length
+      if conf_mats.length > 1
+        conf_mats.each_with_index do |conf_mat, i|
+          run_report = build_run_report(conf_mat)
+          print_run_report(run_report, i + 1)
+          conf_mat.each do |actual, cols|
+            cols.each do |predicted, v|
+              accumulated_conf_mat[actual][predicted] += v
+            end
+          end
+        end
+        puts '-' * header.length
+      end
+      run_report = build_run_report(accumulated_conf_mat)
+      print_run_report(run_report, 'All')
+      puts
+      print_conf_mat(accumulated_conf_mat)
+      puts
+      conf_tab = conf_mat_to_tab(accumulated_conf_mat)
+      print_conf_tab(conf_tab)
+    end
+
+    def build_run_report(conf_mat)
+      correct = incorrect = 0
+      conf_mat.each do |actual, cols|
+        cols.each do |predicted, v|
+          if actual == predicted
+            correct += v
+          else
+            incorrect += v
+          end
+        end
+      end
+      total = correct + incorrect
+      { total: total, correct: correct, incorrect: incorrect, accuracy: divide(correct, total) }
+    end
+
+    def conf_mat_to_tab(conf_mat)
+      conf_tab = Hash.new { |h, k| h[k] = { p: { t: 0, f: 0 }, n: { t: 0, f: 0 } } }
+      conf_mat.each_key do |positive|
+        conf_mat.each do |actual, cols|
+          cols.each do |predicted, v|
+            conf_tab[positive][positive == predicted ? :p : :n][actual == predicted ? :t : :f] += v
+          end
+        end
+      end
+      conf_tab
+    end
+
+    def print_run_report(stats, prefix = '', print_header = false)
+      puts "#{'Run'.rjust([3, prefix.length].max)} Total Correct Incorrect Accuracy" if print_header
+      puts "#{prefix.to_s.rjust(3)} #{stats[:total].to_s.rjust(9)} #{stats[:correct].to_s.rjust(9)} #{stats[:incorrect].to_s.rjust(9)} #{stats[:accuracy].round(5).to_s.ljust(7, '0').rjust(9)}"
+    end
+
+    def print_conf_mat(conf_mat)
+      header = ['Predicted ->'] + conf_mat.keys + %w[Total Recall]
+      cell_size = header.map(&:length).max
+      header = header.map { |h| h.rjust(cell_size) }.join(' ')
+      puts ' Confusion Matrix '.center(header.length, '-')
+      puts header
+      puts '-' * header.length
+      predicted_totals = conf_mat.keys.map { |predicted| [predicted, 0] }.to_h
+      correct = 0
+      conf_mat.each do |k, rec|
+        actual_total = rec.values.reduce(:+)
+        puts ([k.ljust(cell_size)] + rec.values.map { |v| v.to_s.rjust(cell_size) } + [actual_total.to_s.rjust(cell_size), divide(rec[k], actual_total).round(5).to_s.rjust(cell_size)]).join(' ')
+        rec.each do |cat, val|
+          predicted_totals[cat] += val
+          correct += val if cat == k
+        end
+      end
+      total = predicted_totals.values.reduce(:+)
+      puts '-' * header.length
+      puts (['Total'.ljust(cell_size)] + predicted_totals.values.map { |v| v.to_s.rjust(cell_size) } + [total.to_s.rjust(cell_size), ''.rjust(cell_size)]).join(' ')
+      puts (['Precision'.ljust(cell_size)] + predicted_totals.keys.map { |k| divide(conf_mat[k][k], predicted_totals[k]).round(5).to_s.rjust(cell_size) } + ['Accuracy ->'.rjust(cell_size), divide(correct, total).round(5).to_s.rjust(cell_size)]).join(' ')
+    end
+
+    def print_conf_tab(conf_tab)
+      conf_tab.each do |positive, tab|
+        puts "# Positive class: #{positive}"
+        derivations = conf_tab_derivations(tab)
+        print_derivations(derivations)
+        puts
+      end
+    end
+
+    def conf_tab_derivations(tab)
+      positives = tab[:p][:t] + tab[:n][:f]
+      negatives = tab[:n][:t] + tab[:p][:f]
+      total = positives + negatives
+      {
+        total_population: positives + negatives,
+        condition_positive: positives,
+        condition_negative: negatives,
+        true_positive: tab[:p][:t],
+        true_negative: tab[:n][:t],
+        false_positive: tab[:p][:f],
+        false_negative: tab[:n][:f],
+        prevalence: divide(positives, total),
+        specificity: divide(tab[:n][:t], negatives),
+        recall: divide(tab[:p][:t], positives),
+        precision: divide(tab[:p][:t], tab[:p][:t] + tab[:p][:f]),
+        accuracy: divide(tab[:p][:t] + tab[:n][:t], total),
+        f1_score: divide(2 * tab[:p][:t], 2 * tab[:p][:t] + tab[:p][:f] + tab[:n][:f])
+      }
+    end
+
+    def print_derivations(derivations)
+      max_len = derivations.keys.map(&:length).max
+      derivations.each do |k, v|
+        puts k.to_s.tr('_', ' ').capitalize.ljust(max_len) + ' : ' + v.to_s
+      end
+    end
+
+    def empty_conf_mat(categories)
+      categories.map { |actual| [actual, categories.map { |predicted| [predicted, 0] }.to_h] }.to_h
+    end
+
+    def divide(dividend, divisor)
+      divisor.zero? ? 0.0 : dividend / divisor.to_f
+    end
+  end
+end
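The validator treats each record as a [category, text] pair (see the train and classify calls above). A usage sketch with made-up data; the auto_categorize option is an assumption about the Bayes classifier's constructor, not something shown in this diff:

    require 'classifier-reborn'

    sample_data = [
      ['dog', 'Dogs bark and fetch sticks.'],
      ['cat', 'Cats purr and chase mice.'],
      # ... enough labelled pairs for each fold to be non-trivial
    ]

    # Assumes a Bayes classifier that creates categories on the fly.
    classifier = ClassifierReborn::Bayes.new(auto_categorize: true)
    ClassifierReborn::ClassifierValidator.cross_validate(classifier, sample_data, 5)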
data/lib/classifier-reborn/version.rb

@@ -1,3 +1,5 @@
+# frozen_string_literal: true
+
 module ClassifierReborn
-  VERSION = '2.0.4'
+  VERSION = '2.3.0'
 end
data/lib/classifier-reborn.rb

@@ -1,3 +1,5 @@
+# frozen_string_literal: true
+
 #--
 # Copyright (c) 2005 Lucas Carlson
 #
@@ -25,6 +27,15 @@
 # License:: LGPL
 
 require 'rubygems'
+
+case RUBY_PLATFORM
+when 'java'
+  require 'jruby-stemmer'
+else
+  require 'fast-stemmer'
+end
+
 require_relative 'classifier-reborn/category_namer'
 require_relative 'classifier-reborn/bayes'
-require_relative 'classifier-reborn/lsi'
+require_relative 'classifier-reborn/lsi'
+require_relative 'classifier-reborn/validators/classifier_validator'