classifier-reborn 2.0.4 → 2.3.0
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +5 -5
- data/LICENSE +74 -1
- data/README.markdown +57 -207
- data/data/stopwords/ar +104 -0
- data/data/stopwords/bn +362 -0
- data/data/stopwords/hi +97 -0
- data/data/stopwords/ja +43 -0
- data/data/stopwords/ru +420 -0
- data/data/stopwords/tr +199 -30
- data/data/stopwords/vi +647 -0
- data/data/stopwords/zh +125 -0
- data/lib/classifier-reborn/backends/bayes_memory_backend.rb +77 -0
- data/lib/classifier-reborn/backends/bayes_redis_backend.rb +109 -0
- data/lib/classifier-reborn/backends/no_redis_error.rb +14 -0
- data/lib/classifier-reborn/bayes.rb +141 -65
- data/lib/classifier-reborn/category_namer.rb +6 -4
- data/lib/classifier-reborn/extensions/hasher.rb +22 -39
- data/lib/classifier-reborn/extensions/token_filter/stemmer.rb +24 -0
- data/lib/classifier-reborn/extensions/token_filter/stopword.rb +48 -0
- data/lib/classifier-reborn/extensions/token_filter/symbol.rb +20 -0
- data/lib/classifier-reborn/extensions/tokenizer/token.rb +36 -0
- data/lib/classifier-reborn/extensions/tokenizer/whitespace.rb +28 -0
- data/lib/classifier-reborn/extensions/vector.rb +35 -28
- data/lib/classifier-reborn/extensions/vector_serialize.rb +10 -10
- data/lib/classifier-reborn/extensions/zero_vector.rb +7 -0
- data/lib/classifier-reborn/lsi/cached_content_node.rb +6 -5
- data/lib/classifier-reborn/lsi/content_node.rb +35 -25
- data/lib/classifier-reborn/lsi/summarizer.rb +7 -5
- data/lib/classifier-reborn/lsi/word_list.rb +5 -6
- data/lib/classifier-reborn/lsi.rb +166 -94
- data/lib/classifier-reborn/validators/classifier_validator.rb +170 -0
- data/lib/classifier-reborn/version.rb +3 -1
- data/lib/classifier-reborn.rb +12 -1
- metadata +98 -17
- data/bin/bayes.rb +0 -36
- data/bin/summarize.rb +0 -16
@@ -1,30 +1,45 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
1
3
|
# Author:: David Fayram (mailto:dfayram@lensmen.net)
|
2
4
|
# Copyright:: Copyright (c) 2005 David Fayram II
|
3
5
|
# License:: LGPL
|
4
6
|
|
7
|
+
# Try to load Numo first - it's the most current and the most well-supported.
|
8
|
+
# Fall back to GSL.
|
9
|
+
# Fall back to native vector.
|
5
10
|
begin
|
6
|
-
raise LoadError if ENV['NATIVE_VECTOR'] ==
|
7
|
-
|
8
|
-
require 'gsl' # requires http://rb-gsl.rubyforge.org/
|
9
|
-
require_relative 'extensions/vector_serialize'
|
10
|
-
$GSL = true
|
11
|
+
raise LoadError if ENV['NATIVE_VECTOR'] == 'true' # to test the native vector class, try `rake test NATIVE_VECTOR=true`
|
12
|
+
raise LoadError if ENV['GSL'] == 'true' # to test with gsl, try `rake test GSL=true`
|
11
13
|
|
14
|
+
require 'numo/narray' # https://ruby-numo.github.io/narray/
|
15
|
+
require 'numo/linalg' # https://ruby-numo.github.io/linalg/
|
16
|
+
$SVD = :numo
|
12
17
|
rescue LoadError
|
13
|
-
|
18
|
+
begin
|
19
|
+
raise LoadError if ENV['NATIVE_VECTOR'] == 'true' # to test the native vector class, try `rake test NATIVE_VECTOR=true`
|
20
|
+
|
21
|
+
require 'gsl' # requires https://github.com/SciRuby/rb-gsl
|
22
|
+
require_relative 'extensions/vector_serialize'
|
23
|
+
$SVD = :gsl
|
24
|
+
rescue LoadError
|
25
|
+
$SVD = :ruby
|
26
|
+
require_relative 'extensions/vector'
|
27
|
+
require_relative 'extensions/zero_vector'
|
28
|
+
end
|
14
29
|
end
|
15
30
|
|
16
31
|
require_relative 'lsi/word_list'
|
17
32
|
require_relative 'lsi/content_node'
|
18
33
|
require_relative 'lsi/cached_content_node'
|
19
34
|
require_relative 'lsi/summarizer'
|
35
|
+
require_relative 'extensions/token_filter/stopword'
|
36
|
+
require_relative 'extensions/token_filter/symbol'
|
20
37
|
|
21
38
|
module ClassifierReborn
|
22
|
-
|
23
39
|
# This class implements a Latent Semantic Indexer, which can search, classify and cluster
|
24
40
|
# data based on underlying semantic relations. For more information on the algorithms used,
|
25
41
|
# please consult Wikipedia[http://en.wikipedia.org/wiki/Latent_Semantic_Indexing].
|
26
42
|
class LSI
|
27
|
-
|
28
43
|
attr_reader :word_list, :cache_node_vectors
|
29
44
|
attr_accessor :auto_rebuild
|
30
45
|
|
@@ -36,12 +51,17 @@ module ClassifierReborn
|
|
36
51
|
#
|
37
52
|
def initialize(options = {})
|
38
53
|
@auto_rebuild = options[:auto_rebuild] != false
|
39
|
-
@word_list
|
40
|
-
@
|
54
|
+
@word_list = WordList.new
|
55
|
+
@items = {}
|
56
|
+
@version = 0
|
57
|
+
@built_at_version = -1
|
41
58
|
@language = options[:language] || 'en'
|
42
|
-
|
43
|
-
|
44
|
-
|
59
|
+
@token_filters = [
|
60
|
+
TokenFilter::Stopword,
|
61
|
+
TokenFilter::Symbol
|
62
|
+
]
|
63
|
+
TokenFilter::Stopword.language = @language
|
64
|
+
extend CachedContentNode::InstanceMethods if @cache_node_vectors = options[:cache_node_vectors]
|
45
65
|
end
|
46
66
|
|
47
67
|
# Returns true if the index needs to be rebuilt. The index needs
|
@@ -64,39 +84,45 @@ module ClassifierReborn
|
|
64
84
|
# ar = ActiveRecordObject.find( :all )
|
65
85
|
# lsi.add_item ar, *ar.categories { |x| ar.content }
|
66
86
|
#
|
67
|
-
def add_item(
|
68
|
-
clean_word_hash = Hasher.
|
69
|
-
|
70
|
-
|
87
|
+
def add_item(item, *categories, &block)
|
88
|
+
clean_word_hash = Hasher.word_hash((block ? yield(item) : item.to_s),
|
89
|
+
token_filters: @token_filters)
|
90
|
+
if clean_word_hash.empty?
|
91
|
+
puts "Input: '#{item}' is entirely stopwords or words with 2 or fewer characters. Classifier-Reborn cannot handle this document properly."
|
71
92
|
else
|
72
|
-
|
93
|
+
@items[item] = if @cache_node_vectors
|
94
|
+
CachedContentNode.new(clean_word_hash, *categories)
|
95
|
+
else
|
96
|
+
ContentNode.new(clean_word_hash, *categories)
|
97
|
+
end
|
98
|
+
@version += 1
|
99
|
+
build_index if @auto_rebuild
|
73
100
|
end
|
74
|
-
@version += 1
|
75
|
-
build_index if @auto_rebuild
|
76
101
|
end
|
77
102
|
|
78
103
|
# A less flexible shorthand for add_item that assumes
|
79
104
|
# you are passing in a string with no categories. item
|
80
105
|
# will be duck typed via to_s .
|
81
106
|
#
|
82
|
-
def <<(
|
83
|
-
add_item
|
107
|
+
def <<(item)
|
108
|
+
add_item(item)
|
84
109
|
end
|
85
110
|
|
86
|
-
# Returns
|
111
|
+
# Returns categories for a given indexed item. You are free to add and remove
|
87
112
|
# items from this as you see fit. It does not invalidate an index to change its categories.
|
88
113
|
def categories_for(item)
|
89
114
|
return [] unless @items[item]
|
90
|
-
|
115
|
+
|
116
|
+
@items[item].categories
|
91
117
|
end
|
92
118
|
|
93
119
|
# Removes an item from the database, if it is indexed.
|
94
120
|
#
|
95
|
-
def remove_item(
|
96
|
-
|
97
|
-
|
98
|
-
|
99
|
-
|
121
|
+
def remove_item(item)
|
122
|
+
return unless @items.key? item
|
123
|
+
|
124
|
+
@items.delete item
|
125
|
+
@version += 1
|
100
126
|
end
|
101
127
|
|
102
128
|
# Returns an array of items that are indexed.
|
@@ -118,30 +144,43 @@ module ClassifierReborn
|
|
118
144
|
# cutoff parameter tells the indexer how many of these values to keep.
|
119
145
|
# A value of 1 for cutoff means that no semantic analysis will take place,
|
120
146
|
# turning the LSI class into a simple vector search engine.
|
121
|
-
def build_index(
|
147
|
+
def build_index(cutoff = 0.75)
|
122
148
|
return unless needs_rebuild?
|
149
|
+
|
123
150
|
make_word_list
|
124
151
|
|
125
152
|
doc_list = @items.values
|
126
|
-
tda = doc_list.collect { |node| node.raw_vector_with(
|
153
|
+
tda = doc_list.collect { |node| node.raw_vector_with(@word_list) }
|
127
154
|
|
128
|
-
if $
|
129
|
-
|
130
|
-
|
155
|
+
if $SVD == :numo
|
156
|
+
tdm = Numo::NArray.asarray(tda.map(&:to_a)).transpose
|
157
|
+
ntdm = numo_build_reduced_matrix(tdm, cutoff)
|
131
158
|
|
132
|
-
|
133
|
-
|
134
|
-
|
135
|
-
|
136
|
-
|
159
|
+
ntdm.each_over_axis(1).with_index do |col_vec, i|
|
160
|
+
doc_list[i].lsi_vector = col_vec
|
161
|
+
doc_list[i].lsi_norm = col_vec / Numo::Linalg.norm(col_vec)
|
162
|
+
end
|
163
|
+
elsif $SVD == :gsl
|
164
|
+
tdm = GSL::Matrix.alloc(*tda).trans
|
165
|
+
ntdm = build_reduced_matrix(tdm, cutoff)
|
166
|
+
|
167
|
+
ntdm.size[1].times do |col|
|
168
|
+
vec = GSL::Vector.alloc(ntdm.column(col)).row
|
169
|
+
doc_list[col].lsi_vector = vec
|
170
|
+
doc_list[col].lsi_norm = vec.normalize
|
171
|
+
end
|
137
172
|
else
|
138
|
-
|
139
|
-
|
173
|
+
tdm = Matrix.rows(tda).trans
|
174
|
+
ntdm = build_reduced_matrix(tdm, cutoff)
|
140
175
|
|
141
|
-
|
142
|
-
|
143
|
-
|
144
|
-
|
176
|
+
ntdm.column_size.times do |col|
|
177
|
+
doc_list[col].lsi_vector = ntdm.column(col) if doc_list[col]
|
178
|
+
if ntdm.column(col).zero?
|
179
|
+
doc_list[col].lsi_norm = ntdm.column(col) if doc_list[col]
|
180
|
+
else
|
181
|
+
doc_list[col].lsi_norm = ntdm.column(col).normalize if doc_list[col]
|
182
|
+
end
|
183
|
+
end
|
145
184
|
end
|
146
185
|
|
147
186
|
@built_at_version = @version
|
@@ -155,13 +194,13 @@ module ClassifierReborn
|
|
155
194
|
# your dataset's general content. For example, if you were to use categorize on the
|
156
195
|
# results of this data, you could gather information on what your dataset is generally
|
157
196
|
# about.
|
158
|
-
def highest_relative_content(
|
159
|
-
|
197
|
+
def highest_relative_content(max_chunks = 10)
|
198
|
+
return [] if needs_rebuild?
|
160
199
|
|
161
|
-
|
162
|
-
|
200
|
+
avg_density = {}
|
201
|
+
@items.each_key { |item| avg_density[item] = proximity_array_for_content(item).inject(0.0) { |x, y| x + y[1] } }
|
163
202
|
|
164
|
-
|
203
|
+
avg_density.keys.sort_by { |x| avg_density[x] }.reverse[0..max_chunks - 1].map
|
165
204
|
end
|
166
205
|
|
167
206
|
# This function is the primitive that find_related and classify
|
@@ -176,17 +215,19 @@ module ClassifierReborn
|
|
176
215
|
# The parameter doc is the content to compare. If that content is not
|
177
216
|
# indexed, you can pass an optional block to define how to create the
|
178
217
|
# text data. See add_item for examples of how this works.
|
179
|
-
def proximity_array_for_content(
|
218
|
+
def proximity_array_for_content(doc, &block)
|
180
219
|
return [] if needs_rebuild?
|
181
220
|
|
182
|
-
content_node = node_for_content(
|
221
|
+
content_node = node_for_content(doc, &block)
|
183
222
|
result =
|
184
223
|
@items.keys.collect do |item|
|
185
|
-
if $
|
186
|
-
|
187
|
-
|
188
|
-
|
189
|
-
|
224
|
+
val = if $SVD == :numo
|
225
|
+
content_node.search_vector.dot(@items[item].transposed_search_vector)
|
226
|
+
elsif $SVD == :gsl
|
227
|
+
content_node.search_vector * @items[item].transposed_search_vector
|
228
|
+
else
|
229
|
+
(Matrix[content_node.search_vector] * @items[item].search_vector)[0]
|
230
|
+
end
|
190
231
|
[item, val]
|
191
232
|
end
|
192
233
|
result.sort_by { |x| x[1] }.reverse
|
@@ -197,17 +238,28 @@ module ClassifierReborn
|
|
197
238
|
# calculated vectors instead of their full versions. This is useful when
|
198
239
|
# you're trying to perform operations on content that is much smaller than
|
199
240
|
# the text you're working with. search uses this primitive.
|
200
|
-
def proximity_norms_for_content(
|
241
|
+
def proximity_norms_for_content(doc, &block)
|
201
242
|
return [] if needs_rebuild?
|
202
243
|
|
203
|
-
content_node = node_for_content(
|
244
|
+
content_node = node_for_content(doc, &block)
|
245
|
+
if ($SVD == :gsl && content_node.raw_norm.isnan?.all?) ||
|
246
|
+
($SVD == :numo && content_node.raw_norm.isnan.all?)
|
247
|
+
puts "There are no documents that are similar to #{doc}"
|
248
|
+
else
|
249
|
+
content_node_norms(content_node)
|
250
|
+
end
|
251
|
+
end
|
252
|
+
|
253
|
+
def content_node_norms(content_node)
|
204
254
|
result =
|
205
255
|
@items.keys.collect do |item|
|
206
|
-
if $
|
207
|
-
|
208
|
-
|
209
|
-
|
210
|
-
|
256
|
+
val = if $SVD == :numo
|
257
|
+
content_node.search_norm.dot(@items[item].search_norm)
|
258
|
+
elsif $SVD == :gsl
|
259
|
+
content_node.search_norm * @items[item].search_norm.col
|
260
|
+
else
|
261
|
+
(Matrix[content_node.search_norm] * @items[item].search_norm)[0]
|
262
|
+
end
|
211
263
|
[item, val]
|
212
264
|
end
|
213
265
|
result.sort_by { |x| x[1] }.reverse
|
@@ -220,11 +272,14 @@ module ClassifierReborn
|
|
220
272
|
#
|
221
273
|
# While this may seem backwards compared to the other functions that LSI supports,
|
222
274
|
# it is actually the same algorithm, just applied on a smaller document.
|
223
|
-
def search(
|
275
|
+
def search(string, max_nearest = 3)
|
224
276
|
return [] if needs_rebuild?
|
225
|
-
|
226
|
-
|
227
|
-
|
277
|
+
|
278
|
+
carry = proximity_norms_for_content(string)
|
279
|
+
unless carry.nil?
|
280
|
+
result = carry.collect { |x| x[0] }
|
281
|
+
result[0..max_nearest - 1]
|
282
|
+
end
|
228
283
|
end
|
229
284
|
|
230
285
|
# This function takes content and finds other documents
|
@@ -236,21 +291,21 @@ module ClassifierReborn
|
|
236
291
|
# This is particularly useful for identifying clusters in your document space.
|
237
292
|
# For example you may want to identify several "What's Related" items for weblog
|
238
293
|
# articles, or find paragraphs that relate to each other in an essay.
|
239
|
-
def find_related(
|
294
|
+
def find_related(doc, max_nearest = 3, &block)
|
240
295
|
carry =
|
241
|
-
proximity_array_for_content(
|
296
|
+
proximity_array_for_content(doc, &block).reject { |pair| pair[0].eql? doc }
|
242
297
|
result = carry.collect { |x| x[0] }
|
243
|
-
|
298
|
+
result[0..max_nearest - 1]
|
244
299
|
end
|
245
300
|
|
246
301
|
# Return the most obvious category with the score
|
247
|
-
def classify_with_score(
|
248
|
-
|
302
|
+
def classify_with_score(doc, cutoff = 0.30, &block)
|
303
|
+
scored_categories(doc, cutoff, &block).last
|
249
304
|
end
|
250
305
|
|
251
306
|
# Return the most obvious category without the score
|
252
|
-
def classify(
|
253
|
-
|
307
|
+
def classify(doc, cutoff = 0.30, &block)
|
308
|
+
scored_categories(doc, cutoff, &block).last.first
|
254
309
|
end
|
255
310
|
|
256
311
|
# This function uses a voting system to categorize documents, based on
|
@@ -262,10 +317,10 @@ module ClassifierReborn
|
|
262
317
|
# text. A cutoff of 1 means that every document in the index votes on
|
263
318
|
# what category the document is in. This may not always make sense.
|
264
319
|
#
|
265
|
-
def scored_categories(
|
320
|
+
def scored_categories(doc, cutoff = 0.30, &block)
|
266
321
|
icutoff = (@items.size * cutoff).round
|
267
|
-
carry = proximity_array_for_content(
|
268
|
-
carry = carry[0..icutoff-1]
|
322
|
+
carry = proximity_array_for_content(doc, &block)
|
323
|
+
carry = carry[0..icutoff - 1]
|
269
324
|
votes = Hash.new(0.0)
|
270
325
|
carry.each do |pair|
|
271
326
|
@items[pair[0]].categories.each do |category|
|
@@ -273,56 +328,73 @@ module ClassifierReborn
|
|
273
328
|
end
|
274
329
|
end
|
275
330
|
|
276
|
-
|
331
|
+
votes.sort_by { |_, score| score }
|
277
332
|
end
|
278
333
|
|
279
334
|
# Prototype, only works on indexed documents.
|
280
335
|
# I have no clue if this is going to work, but in theory
|
281
336
|
# it's supposed to.
|
282
|
-
def highest_ranked_stems(
|
283
|
-
raise
|
284
|
-
|
285
|
-
|
286
|
-
|
337
|
+
def highest_ranked_stems(doc, count = 3)
|
338
|
+
raise 'Requested stem ranking on non-indexed content!' unless @items[doc]
|
339
|
+
|
340
|
+
content_vector_array = node_for_content(doc).lsi_vector.to_a
|
341
|
+
top_n = content_vector_array.sort.reverse[0..count - 1]
|
342
|
+
top_n.collect { |x| @word_list.word_for_index(content_vector_array.index(x)) }
|
343
|
+
end
|
344
|
+
|
345
|
+
def reset
|
346
|
+
initialize(auto_rebuild: @auto_rebuild, cache_node_vectors: @cache_node_vectors)
|
287
347
|
end
|
288
348
|
|
289
349
|
private
|
290
|
-
|
350
|
+
|
351
|
+
def build_reduced_matrix(matrix, cutoff = 0.75)
|
291
352
|
# TODO: Check that M>=N on these dimensions! Transpose helps assure this
|
292
353
|
u, v, s = matrix.SV_decomp
|
293
|
-
|
294
354
|
# TODO: Better than 75% term, please. :\
|
295
355
|
s_cutoff = s.sort.reverse[(s.size * cutoff).round - 1]
|
296
356
|
s.size.times do |ord|
|
297
357
|
s[ord] = 0.0 if s[ord] < s_cutoff
|
298
358
|
end
|
299
359
|
# Reconstruct the term document matrix, only with reduced rank
|
300
|
-
u * ($
|
360
|
+
u * ($SVD == :gsl ? GSL::Matrix : ::Matrix).diag(s) * v.trans
|
361
|
+
end
|
362
|
+
|
363
|
+
def numo_build_reduced_matrix(matrix, cutoff = 0.75)
|
364
|
+
s, u, vt = Numo::Linalg.svd(matrix, driver: 'svd', job: 'S')
|
365
|
+
|
366
|
+
# TODO: Better than 75% term (as above)
|
367
|
+
s_cutoff = s.sort.reverse[(s.size * cutoff).round - 1]
|
368
|
+
s.size.times do |ord|
|
369
|
+
s[ord] = 0.0 if s[ord] < s_cutoff
|
370
|
+
end
|
371
|
+
|
372
|
+
# Reconstruct the term document matrix, only with reduced rank
|
373
|
+
u.dot(::Numo::DFloat.eye(s.size) * s).dot(vt)
|
301
374
|
end
|
302
375
|
|
303
376
|
def node_for_content(item, &block)
|
304
377
|
if @items[item]
|
305
378
|
return @items[item]
|
306
379
|
else
|
307
|
-
clean_word_hash = Hasher.
|
380
|
+
clean_word_hash = Hasher.word_hash((block ? yield(item) : item.to_s),
|
381
|
+
token_filters: @token_filters)
|
308
382
|
|
309
|
-
|
383
|
+
content_node = ContentNode.new(clean_word_hash, &block) # make the node and extract the data
|
310
384
|
|
311
385
|
unless needs_rebuild?
|
312
|
-
|
386
|
+
content_node.raw_vector_with(@word_list) # make the lsi raw and norm vectors
|
313
387
|
end
|
314
388
|
end
|
315
389
|
|
316
|
-
|
390
|
+
content_node
|
317
391
|
end
|
318
392
|
|
319
393
|
def make_word_list
|
320
394
|
@word_list = WordList.new
|
321
395
|
@items.each_value do |node|
|
322
|
-
node.word_hash.each_key { |key| @word_list.add_word
|
396
|
+
node.word_hash.each_key { |key| @word_list.add_word(key) }
|
323
397
|
end
|
324
398
|
end
|
325
|
-
|
326
399
|
end
|
327
400
|
end
|
328
|
-
|
@@ -0,0 +1,170 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
module ClassifierReborn
  # Utilities for measuring how well a classifier performs: k-fold
  # cross-validation, confusion-matrix construction, derived statistics and
  # plain-text reports written to standard output.
  #
  # A "confusion matrix" here is a Hash of Hashes keyed
  # <tt>matrix[actual_category][predicted_category] => count</tt>.
  #
  # The classifier objects passed in must respond to +reset+, +train+,
  # +categories+ and +classify+, which are the only methods this module calls.
  module ClassifierValidator
    module_function

    # Runs k-fold cross-validation and prints the resulting report.
    #
    # classifier::  a classifier instance, or the name (String) of a class
    #               under ClassifierReborn which is instantiated with +options+.
    # sample_data:: Array of records. The first element of a record is its
    #               category and the last element is its text.
    #               NOTE: the array is shuffled in place.
    # fold::        number of partitions (default 10).
    # options::     constructor arguments, used only when +classifier+ is a String.
    #
    # Records left over when the sample size is not a multiple of +fold+ are
    # not used. Raises ArgumentError when there are fewer records than folds.
    # Returns whatever generate_report returns.
    def cross_validate(classifier, sample_data, fold = 10, *options)
      # Splat the options so the constructor receives them as individual
      # arguments rather than as one wrapping Array.
      classifier = ClassifierReborn.const_get(classifier).new(*options) if classifier.is_a?(String)
      sample_data.shuffle!
      partition_size = sample_data.length / fold
      if partition_size < 1
        raise ArgumentError, "cannot split #{sample_data.length} records into #{fold} folds"
      end

      # The partitions do not change between folds, so slice only once.
      partitions = sample_data.each_slice(partition_size).first(fold)
      conf_mats = Array.new(fold) do |i|
        training_data = partitions.dup
        test_data = training_data.delete_at(i)
        # flatten (not flatten!): the bang form returns nil when nothing was
        # flattened, which is what happens with a single fold.
        validate(classifier, training_data.flatten(1), test_data)
      end
      classifier.reset
      generate_report(conf_mats)
    end

    # Resets the classifier, trains it on +training_data+ and returns the
    # confusion matrix obtained by evaluating it on +test_data+.
    #
    # +classifier+ and +options+ behave as in cross_validate. Every record is
    # passed to the classifier as <tt>train(rec.first, rec.last)</tt>.
    def validate(classifier, training_data, test_data, *options)
      classifier = ClassifierReborn.const_get(classifier).new(*options) if classifier.is_a?(String)
      classifier.reset
      training_data.each { |rec| classifier.train(rec.first, rec.last) }
      evaluate(classifier, test_data)
    end

    # Classifies the text of every record in +test_data+ and tallies the
    # outcomes in a confusion matrix built from <tt>classifier.categories</tt>.
    #
    # The record's category is normalised (underscores to spaces, then
    # capitalised) before being used as the "actual" key; presumably this
    # mirrors how the classifier names its categories -- confirm against the
    # classifier in use. Records for which +classify+ returns nil are skipped.
    def evaluate(classifier, test_data)
      conf_mat = empty_conf_mat(classifier.categories.sort)
      test_data.each do |rec|
        actual = rec.first.tr('_', ' ').capitalize
        predicted = classifier.classify(rec.last)
        next if predicted.nil?

        conf_mat[actual][predicted] += 1
      end
      conf_mat
    end

    # Prints a run report for one or more confusion matrices (given either as
    # separate arguments or as an Array), followed by the accumulated
    # confusion matrix and the per-class statistics derived from it.
    #
    # With a single matrix that matrix is reported directly; with several,
    # each run gets its own line and the counts are summed into a fresh matrix
    # keyed by the categories of the first one.
    def generate_report(*conf_mats)
      conf_mats = conf_mats.flatten
      accumulated_conf_mat = conf_mats.length == 1 ? conf_mats.first : empty_conf_mat(conf_mats.first.keys.sort)
      header = 'Run Total Correct Incorrect Accuracy'
      rule = '-' * header.length
      puts
      puts ' Run Report '.center(header.length, '-')
      puts header
      puts rule
      if conf_mats.length > 1
        conf_mats.each_with_index do |conf_mat, i|
          print_run_report(build_run_report(conf_mat), i + 1)
          conf_mat.each do |actual, cols|
            cols.each { |predicted, v| accumulated_conf_mat[actual][predicted] += v }
          end
        end
        puts rule
      end
      print_run_report(build_run_report(accumulated_conf_mat), 'All')
      puts
      print_conf_mat(accumulated_conf_mat)
      puts
      print_conf_tab(conf_mat_to_tab(accumulated_conf_mat))
    end

    # Summarises a confusion matrix as a Hash with the keys :total, :correct,
    # :incorrect and :accuracy. A count is "correct" when its actual and
    # predicted categories are equal.
    def build_run_report(conf_mat)
      correct = 0
      incorrect = 0
      conf_mat.each do |actual, cols|
        cols.each do |predicted, v|
          if actual == predicted
            correct += v
          else
            incorrect += v
          end
        end
      end
      total = correct + incorrect
      { total: total, correct: correct, incorrect: incorrect, accuracy: divide(correct, total) }
    end

    # Converts a confusion matrix into one 2x2 confusion table per category.
    # Treating each category in turn as the positive class, every count lands
    # in <tt>table[:p or :n][:t or :f]</tt>: :p/:n says whether the prediction
    # was the positive class, :t/:f whether the prediction was right.
    def conf_mat_to_tab(conf_mat)
      conf_tab = Hash.new { |h, k| h[k] = { p: { t: 0, f: 0 }, n: { t: 0, f: 0 } } }
      conf_mat.each_key do |positive|
        conf_mat.each do |actual, cols|
          cols.each do |predicted, v|
            conf_tab[positive][positive == predicted ? :p : :n][actual == predicted ? :t : :f] += v
          end
        end
      end
      conf_tab
    end

    # Prints one line of run statistics (as produced by build_run_report).
    #
    # prefix::       label for the run; may be any object (generate_report
    #                passes Integers and the String 'All'), converted with to_s.
    # print_header:: when true, a column header line is printed first.
    def print_run_report(stats, prefix = '', print_header = false)
      label = prefix.to_s
      puts "#{'Run'.rjust([3, label.length].max)} Total Correct Incorrect Accuracy" if print_header
      accuracy = stats[:accuracy].round(5).to_s.ljust(7, '0')
      puts "#{label.rjust(3)} #{stats[:total].to_s.rjust(9)} #{stats[:correct].to_s.rjust(9)} #{stats[:incorrect].to_s.rjust(9)} #{accuracy.rjust(9)}"
    end

    # Prints the confusion matrix as a table: one row per actual category with
    # its row total and recall, then a row of column totals and a row with the
    # per-category precision and the overall accuracy.
    def print_conf_mat(conf_mat)
      header = ['Predicted ->'] + conf_mat.keys + %w[Total Recall]
      cell_size = header.map(&:length).max
      header = header.map { |h| h.rjust(cell_size) }.join(' ')
      rule = '-' * header.length
      puts ' Confusion Matrix '.center(header.length, '-')
      puts header
      puts rule
      predicted_totals = conf_mat.keys.map { |predicted| [predicted, 0] }.to_h
      correct = 0
      conf_mat.each do |k, rec|
        # Seed the sums with 0 so an empty row or matrix yields 0, not nil.
        actual_total = rec.values.reduce(0, :+)
        row = [k.ljust(cell_size)] +
              rec.values.map { |v| v.to_s.rjust(cell_size) } +
              [actual_total.to_s.rjust(cell_size), divide(rec[k], actual_total).round(5).to_s.rjust(cell_size)]
        puts row.join(' ')
        rec.each do |cat, val|
          predicted_totals[cat] += val
          correct += val if cat == k
        end
      end
      total = predicted_totals.values.reduce(0, :+)
      puts rule
      totals_row = ['Total'.ljust(cell_size)] +
                   predicted_totals.values.map { |v| v.to_s.rjust(cell_size) } +
                   [total.to_s.rjust(cell_size), ''.rjust(cell_size)]
      puts totals_row.join(' ')
      precision_row = ['Precision'.ljust(cell_size)] +
                      predicted_totals.keys.map { |k| divide(conf_mat[k][k], predicted_totals[k]).round(5).to_s.rjust(cell_size) } +
                      ['Accuracy ->'.rjust(cell_size), divide(correct, total).round(5).to_s.rjust(cell_size)]
      puts precision_row.join(' ')
    end

    # Prints the derived statistics of every confusion table in +conf_tab+
    # (as produced by conf_mat_to_tab), one section per positive class.
    def print_conf_tab(conf_tab)
      conf_tab.each do |positive, tab|
        puts "# Positive class: #{positive}"
        print_derivations(conf_tab_derivations(tab))
        puts
      end
    end

    # Computes the standard statistics of a single 2x2 confusion table.
    # Ratios whose denominator is zero are reported as 0.0 (see divide).
    def conf_tab_derivations(tab)
      true_positive = tab[:p][:t]
      true_negative = tab[:n][:t]
      false_positive = tab[:p][:f]
      false_negative = tab[:n][:f]
      positives = true_positive + false_negative
      negatives = true_negative + false_positive
      total = positives + negatives
      {
        total_population: total,
        condition_positive: positives,
        condition_negative: negatives,
        true_positive: true_positive,
        true_negative: true_negative,
        false_positive: false_positive,
        false_negative: false_negative,
        prevalence: divide(positives, total),
        specificity: divide(true_negative, negatives),
        recall: divide(true_positive, positives),
        precision: divide(true_positive, true_positive + false_positive),
        accuracy: divide(true_positive + true_negative, total),
        f1_score: divide(2 * true_positive, 2 * true_positive + false_positive + false_negative)
      }
    end

    # Prints each statistic as "Label : value", turning the key into a
    # capitalised label and left-aligning the labels to a common width.
    def print_derivations(derivations)
      max_len = derivations.keys.map(&:length).max
      derivations.each do |k, v|
        label = k.to_s.tr('_', ' ').capitalize
        puts "#{label.ljust(max_len)} : #{v}"
      end
    end

    # Builds a confusion matrix with every actual/predicted count set to 0.
    def empty_conf_mat(categories)
      categories.map { |actual| [actual, categories.map { |predicted| [predicted, 0] }.to_h] }.to_h
    end

    # Float division that yields 0.0 instead of failing or returning
    # NaN/Infinity when the divisor is zero.
    def divide(dividend, divisor)
      divisor.zero? ? 0.0 : dividend / divisor.to_f
    end
  end
end
|
data/lib/classifier-reborn.rb
CHANGED
@@ -1,3 +1,5 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
1
3
|
#--
|
2
4
|
# Copyright (c) 2005 Lucas Carlson
|
3
5
|
#
|
@@ -25,6 +27,15 @@
|
|
25
27
|
# License:: LGPL
|
26
28
|
|
27
29
|
require 'rubygems'
|
30
|
+
|
31
|
+
case RUBY_PLATFORM
|
32
|
+
when 'java'
|
33
|
+
require 'jruby-stemmer'
|
34
|
+
else
|
35
|
+
require 'fast-stemmer'
|
36
|
+
end
|
37
|
+
|
28
38
|
require_relative 'classifier-reborn/category_namer'
|
29
39
|
require_relative 'classifier-reborn/bayes'
|
30
|
-
require_relative 'classifier-reborn/lsi'
|
40
|
+
require_relative 'classifier-reborn/lsi'
|
41
|
+
require_relative 'classifier-reborn/validators/classifier_validator'
|