classifier-reborn 2.0.4 → 2.3.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +5 -5
- data/LICENSE +74 -1
- data/README.markdown +57 -207
- data/data/stopwords/ar +104 -0
- data/data/stopwords/bn +362 -0
- data/data/stopwords/hi +97 -0
- data/data/stopwords/ja +43 -0
- data/data/stopwords/ru +420 -0
- data/data/stopwords/tr +199 -30
- data/data/stopwords/vi +647 -0
- data/data/stopwords/zh +125 -0
- data/lib/classifier-reborn/backends/bayes_memory_backend.rb +77 -0
- data/lib/classifier-reborn/backends/bayes_redis_backend.rb +109 -0
- data/lib/classifier-reborn/backends/no_redis_error.rb +14 -0
- data/lib/classifier-reborn/bayes.rb +141 -65
- data/lib/classifier-reborn/category_namer.rb +6 -4
- data/lib/classifier-reborn/extensions/hasher.rb +22 -39
- data/lib/classifier-reborn/extensions/token_filter/stemmer.rb +24 -0
- data/lib/classifier-reborn/extensions/token_filter/stopword.rb +48 -0
- data/lib/classifier-reborn/extensions/token_filter/symbol.rb +20 -0
- data/lib/classifier-reborn/extensions/tokenizer/token.rb +36 -0
- data/lib/classifier-reborn/extensions/tokenizer/whitespace.rb +28 -0
- data/lib/classifier-reborn/extensions/vector.rb +35 -28
- data/lib/classifier-reborn/extensions/vector_serialize.rb +10 -10
- data/lib/classifier-reborn/extensions/zero_vector.rb +7 -0
- data/lib/classifier-reborn/lsi/cached_content_node.rb +6 -5
- data/lib/classifier-reborn/lsi/content_node.rb +35 -25
- data/lib/classifier-reborn/lsi/summarizer.rb +7 -5
- data/lib/classifier-reborn/lsi/word_list.rb +5 -6
- data/lib/classifier-reborn/lsi.rb +166 -94
- data/lib/classifier-reborn/validators/classifier_validator.rb +170 -0
- data/lib/classifier-reborn/version.rb +3 -1
- data/lib/classifier-reborn.rb +12 -1
- metadata +98 -17
- data/bin/bayes.rb +0 -36
- data/bin/summarize.rb +0 -16
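The headline changes: pluggable Bayes backends (in-memory and Redis), tokenizer and token-filter extensions, several new stopword languages, a cross-validation helper, and a three-way SVD backend chain for LSI (Numo, then GSL, then pure Ruby). A minimal sketch of the post-upgrade LSI surface, using invented training strings:

```ruby
require 'classifier-reborn'

# Options match the initialize diff below; cache_node_vectors is optional.
lsi = ClassifierReborn::LSI.new(auto_rebuild: true, cache_node_vectors: true)
lsi.add_item 'Dogs are not cats.', :animals
lsi.add_item 'This text deals with ruby gems.', :software

lsi.search('ruby', 1)    # nearest indexed document to the query
lsi.classify('kittens')  # most obvious category, via the voting scheme below
lsi.reset                # new in this release; reinitializes the index
```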
data/lib/classifier-reborn/lsi.rb
CHANGED

```diff
@@ -1,30 +1,45 @@
+# frozen_string_literal: true
+
 # Author:: David Fayram (mailto:dfayram@lensmen.net)
 # Copyright:: Copyright (c) 2005 David Fayram II
 # License:: LGPL
 
+# Try to load Numo first - it's the most current and the most well-supported.
+# Fall back to GSL.
+# Fall back to native vector.
 begin
-  raise LoadError if ENV['NATIVE_VECTOR'] == "true" # to test the native vector class, try `rake test NATIVE_VECTOR=true`
-
-  require 'gsl' # requires http://rb-gsl.rubyforge.org/
-  require_relative 'extensions/vector_serialize'
-  $GSL = true
+  raise LoadError if ENV['NATIVE_VECTOR'] == 'true' # to test the native vector class, try `rake test NATIVE_VECTOR=true`
+  raise LoadError if ENV['GSL'] == 'true' # to test with gsl, try `rake test GSL=true`
 
+  require 'numo/narray' # https://ruby-numo.github.io/narray/
+  require 'numo/linalg' # https://ruby-numo.github.io/linalg/
+  $SVD = :numo
 rescue LoadError
-  require_relative 'extensions/vector'
+  begin
+    raise LoadError if ENV['NATIVE_VECTOR'] == 'true' # to test the native vector class, try `rake test NATIVE_VECTOR=true`
+
+    require 'gsl' # requires https://github.com/SciRuby/rb-gsl
+    require_relative 'extensions/vector_serialize'
+    $SVD = :gsl
+  rescue LoadError
+    $SVD = :ruby
+    require_relative 'extensions/vector'
+    require_relative 'extensions/zero_vector'
+  end
 end
 
 require_relative 'lsi/word_list'
 require_relative 'lsi/content_node'
 require_relative 'lsi/cached_content_node'
 require_relative 'lsi/summarizer'
+require_relative 'extensions/token_filter/stopword'
+require_relative 'extensions/token_filter/symbol'
 
 module ClassifierReborn
-
   # This class implements a Latent Semantic Indexer, which can search, classify and cluster
   # data based on underlying semantic relations. For more information on the algorithms used,
   # please consult Wikipedia[http://en.wikipedia.org/wiki/Latent_Semantic_Indexing].
   class LSI
-
     attr_reader :word_list, :cache_node_vectors
     attr_accessor :auto_rebuild
 
```
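The require chain above records the winner in the global `$SVD` (`:numo`, `:gsl`, or `:ruby`), replacing the old `$GSL` flag. A quick way to confirm which backend actually loaded:

```ruby
require 'classifier-reborn'

# $SVD is set by the nested begin/rescue chain shown above.
case $SVD
when :numo then puts 'SVD via Numo::Linalg'
when :gsl  then puts 'SVD via rb-gsl'
else            puts 'pure-Ruby Matrix fallback'
end
```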
```diff
@@ -36,12 +51,17 @@ module ClassifierReborn
     #
     def initialize(options = {})
       @auto_rebuild = options[:auto_rebuild] != false
-      @word_list, @items = WordList.new, {}
-      @version, @built_at_version = 0, -1
+      @word_list = WordList.new
+      @items = {}
+      @version = 0
+      @built_at_version = -1
       @language = options[:language] || 'en'
-      if @cache_node_vectors = options[:cache_node_vectors]
-        extend CachedContentNode::InstanceMethods
-      end
+      @token_filters = [
+        TokenFilter::Stopword,
+        TokenFilter::Symbol
+      ]
+      TokenFilter::Stopword.language = @language
+      extend CachedContentNode::InstanceMethods if @cache_node_vectors = options[:cache_node_vectors]
     end
 
     # Returns true if the index needs to be rebuilt. The index needs
@@ -64,39 +84,45 @@ module ClassifierReborn
     # ar = ActiveRecordObject.find( :all )
     # lsi.add_item ar, *ar.categories { |x| ar.content }
     #
-    def add_item(item, *categories, &block)
-      clean_word_hash = Hasher.clean_word_hash((block ? block.call(item) : item.to_s), @language)
-      if clean_word_hash.empty?
-        puts "Input: '#{item}' is entirely stopwords or words with 2 or fewer characters. Classifier-Reborn cannot handle this document properly."
+    def add_item(item, *categories, &block)
+      clean_word_hash = Hasher.word_hash((block ? yield(item) : item.to_s),
+                                         token_filters: @token_filters)
+      if clean_word_hash.empty?
+        puts "Input: '#{item}' is entirely stopwords or words with 2 or fewer characters. Classifier-Reborn cannot handle this document properly."
       else
-        @items[item] = @cache_node_vectors ? CachedContentNode.new(clean_word_hash, *categories) : ContentNode.new(clean_word_hash, *categories)
+        @items[item] = if @cache_node_vectors
+                         CachedContentNode.new(clean_word_hash, *categories)
+                       else
+                         ContentNode.new(clean_word_hash, *categories)
+                       end
+        @version += 1
+        build_index if @auto_rebuild
       end
-      @version += 1
-      build_index if @auto_rebuild
     end
 
     # A less flexible shorthand for add_item that assumes
     # you are passing in a string with no categorries. item
     # will be duck typed via to_s .
     #
-    def <<(item)
-      add_item item
+    def <<(item)
+      add_item(item)
     end
 
-    # Returns the categories for a given indexed items. You are free to add and remove
+    # Returns categories for a given indexed item. You are free to add and remove
     # items from this as you see fit. It does not invalide an index to change its categories.
     def categories_for(item)
       return [] unless @items[item]
-      @items[item].categories
+
+      @items[item].categories
     end
 
     # Removes an item from the database, if it is indexed.
     #
-    def remove_item(item)
-      if @items.key? item
-        @items.delete item
-        @version += 1
-      end
+    def remove_item(item)
+      return unless @items.key? item
+
+      @items.delete item
+      @version += 1
     end
 
     # Returns an array of items that are indexed.
@@ -118,30 +144,43 @@ module ClassifierReborn
     # cutoff parameter tells the indexer how many of these values to keep.
     # A value of 1 for cutoff means that no semantic analysis will take place,
     # turning the LSI class into a simple vector search engine.
-    def build_index(cutoff=0.75)
+    def build_index(cutoff = 0.75)
       return unless needs_rebuild?
+
       make_word_list
 
       doc_list = @items.values
-      tda = doc_list.collect { |node| node.raw_vector_with(@word_list) }
+      tda = doc_list.collect { |node| node.raw_vector_with(@word_list) }
 
-      if $GSL
-        tdm = GSL::Matrix.alloc(*tda).trans
-        ntdm = build_reduced_matrix(tdm, cutoff)
+      if $SVD == :numo
+        tdm = Numo::NArray.asarray(tda.map(&:to_a)).transpose
+        ntdm = numo_build_reduced_matrix(tdm, cutoff)
 
-        ntdm.size[1].times do |col|
-          vec = GSL::Vector.alloc(ntdm.column(col)).row
-          doc_list[col].lsi_vector = vec
-          doc_list[col].lsi_norm = vec.normalize
-        end
+        ntdm.each_over_axis(1).with_index do |col_vec, i|
+          doc_list[i].lsi_vector = col_vec
+          doc_list[i].lsi_norm = col_vec / Numo::Linalg.norm(col_vec)
+        end
+      elsif $SVD == :gsl
+        tdm = GSL::Matrix.alloc(*tda).trans
+        ntdm = build_reduced_matrix(tdm, cutoff)
+
+        ntdm.size[1].times do |col|
+          vec = GSL::Vector.alloc(ntdm.column(col)).row
+          doc_list[col].lsi_vector = vec
+          doc_list[col].lsi_norm = vec.normalize
+        end
       else
-        tdm = Matrix.rows(tda).trans
-        ntdm = build_reduced_matrix(tdm, cutoff)
+        tdm = Matrix.rows(tda).trans
+        ntdm = build_reduced_matrix(tdm, cutoff)
 
-        ntdm.row_size.times do |col|
-          doc_list[col].lsi_vector = ntdm.column(col) if doc_list[col]
-          doc_list[col].lsi_norm = ntdm.column(col).normalize if doc_list[col]
-        end
+        ntdm.column_size.times do |col|
+          doc_list[col].lsi_vector = ntdm.column(col) if doc_list[col]
+          if ntdm.column(col).zero?
+            doc_list[col].lsi_norm = ntdm.column(col) if doc_list[col]
+          else
+            doc_list[col].lsi_norm = ntdm.column(col).normalize if doc_list[col]
+          end
+        end
       end
 
       @built_at_version = @version
```
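`build_index` hands the term-document matrix to one of the `*_build_reduced_matrix` helpers (shown further down), which zero out the smallest singular values before reconstructing the matrix at reduced rank. The cutoff arithmetic in isolation, with hypothetical singular values:

```ruby
s = [4.2, 2.9, 1.1, 0.3]  # hypothetical singular values from an SVD
cutoff = 0.75

# Keep the largest values; zero everything below the cutoff-quantile value.
s_cutoff = s.sort.reverse[(s.size * cutoff).round - 1]  # => 1.1
s.map { |v| v < s_cutoff ? 0.0 : v }                    # => [4.2, 2.9, 1.1, 0.0]
```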
```diff
@@ -155,13 +194,13 @@ module ClassifierReborn
     # your dataset's general content. For example, if you were to use categorize on the
     # results of this data, you could gather information on what your dataset is generally
     # about.
-    def highest_relative_content(max_chunks=10)
-      return [] if needs_rebuild?
+    def highest_relative_content(max_chunks = 10)
+      return [] if needs_rebuild?
 
-      avg_density = Hash.new
-      @items.each_key { |x| avg_density[x] = proximity_array_for_content(x).inject(0.0) { |x,y| x + y[1] } }
+      avg_density = {}
+      @items.each_key { |item| avg_density[item] = proximity_array_for_content(item).inject(0.0) { |x, y| x + y[1] } }
 
-      avg_density.keys.sort_by { |x| avg_density[x] }.reverse[0..max_chunks-1].map
+      avg_density.keys.sort_by { |x| avg_density[x] }.reverse[0..max_chunks - 1].map
     end
 
     # This function is the primitive that find_related and classify
@@ -176,17 +215,19 @@ module ClassifierReborn
     # The parameter doc is the content to compare. If that content is not
     # indexed, you can pass an optional block to define how to create the
     # text data. See add_item for examples of how this works.
-    def proximity_array_for_content(doc, &block)
+    def proximity_array_for_content(doc, &block)
       return [] if needs_rebuild?
 
-      content_node = node_for_content(doc, &block)
+      content_node = node_for_content(doc, &block)
       result =
         @items.keys.collect do |item|
-          if $GSL
-            val = content_node.search_vector * @items[item].transposed_search_vector
-          else
-            val = (Matrix[content_node.search_vector] * @items[item].search_vector)[0]
-          end
+          val = if $SVD == :numo
+                  content_node.search_vector.dot(@items[item].transposed_search_vector)
+                elsif $SVD == :gsl
+                  content_node.search_vector * @items[item].transposed_search_vector
+                else
+                  (Matrix[content_node.search_vector] * @items[item].search_vector)[0]
+                end
           [item, val]
         end
       result.sort_by { |x| x[1] }.reverse
```
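`proximity_array_for_content` is the primitive behind `find_related` and `classify`: each indexed item is scored by an inner product against the query node, then sorted descending. The pure-Ruby branch reduced to a standalone sketch with toy vectors:

```ruby
require 'matrix'

query = Vector[0.2, 0.9]  # stand-in for content_node.search_vector
items = { 'a' => Vector[0.1, 0.8], 'b' => Vector[0.7, 0.1] }

result = items.map { |item, vec| [item, query.inner_product(vec)] }
result.sort_by { |x| x[1] }.reverse
# => [["a", 0.74], ["b", 0.23]]
```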
```diff
@@ -197,17 +238,28 @@ module ClassifierReborn
     # calculated vectors instead of their full versions. This is useful when
     # you're trying to perform operations on content that is much smaller than
     # the text you're working with. search uses this primitive.
-    def proximity_norms_for_content(doc, &block)
+    def proximity_norms_for_content(doc, &block)
       return [] if needs_rebuild?
 
-      content_node = node_for_content(doc, &block)
+      content_node = node_for_content(doc, &block)
+      if ($SVD == :gsl && content_node.raw_norm.isnan?.all?) ||
+         ($SVD == :numo && content_node.raw_norm.isnan.all?)
+        puts "There are no documents that are similar to #{doc}"
+      else
+        content_node_norms(content_node)
+      end
+    end
+
+    def content_node_norms(content_node)
       result =
         @items.keys.collect do |item|
-          if $GSL
-            val = content_node.search_norm * @items[item].search_norm.col
-          else
-            val = (Matrix[content_node.search_norm] * @items[item].search_norm)[0]
-          end
+          val = if $SVD == :numo
+                  content_node.search_norm.dot(@items[item].search_norm)
+                elsif $SVD == :gsl
+                  content_node.search_norm * @items[item].search_norm.col
+                else
+                  (Matrix[content_node.search_norm] * @items[item].search_norm)[0]
+                end
           [item, val]
         end
       result.sort_by { |x| x[1] }.reverse
@@ -220,11 +272,14 @@ module ClassifierReborn
     #
     # While this may seem backwards compared to the other functions that LSI supports,
     # it is actually the same algorithm, just applied on a smaller document.
-    def search(string, max_nearest=3)
+    def search(string, max_nearest = 3)
       return [] if needs_rebuild?
-      carry = proximity_norms_for_content(string)
-      result = carry.collect { |x| x[0] }
-      result[0..max_nearest-1]
+
+      carry = proximity_norms_for_content(string)
+      unless carry.nil?
+        result = carry.collect { |x| x[0] }
+        result[0..max_nearest - 1]
+      end
     end
 
     # This function takes content and finds other documents
@@ -236,21 +291,21 @@ module ClassifierReborn
     # This is particularly useful for identifing clusters in your document space.
     # For example you may want to identify several "What's Related" items for weblog
     # articles, or find paragraphs that relate to each other in an essay.
-    def find_related(doc, max_nearest=3, &block)
+    def find_related(doc, max_nearest = 3, &block)
       carry =
-        proximity_array_for_content(doc, &block).reject { |pair| pair[0] == doc }
+        proximity_array_for_content(doc, &block).reject { |pair| pair[0].eql? doc }
       result = carry.collect { |x| x[0] }
-      result[0..max_nearest-1]
+      result[0..max_nearest - 1]
     end
 
     # Return the most obvious category with the score
-    def classify_with_score(doc, cutoff=0.30, &block)
-      scored_categories(doc, cutoff, &block).last
+    def classify_with_score(doc, cutoff = 0.30, &block)
+      scored_categories(doc, cutoff, &block).last
     end
 
     # Return the most obvious category without the score
-    def classify(doc, cutoff=0.30, &block)
-      scored_categories(doc, cutoff, &block).last.first
+    def classify(doc, cutoff = 0.30, &block)
+      scored_categories(doc, cutoff, &block).last.first
     end
 
     # This function uses a voting system to categorize documents, based on
@@ -262,10 +317,10 @@ module ClassifierReborn
     # text. A cutoff of 1 means that every document in the index votes on
     # what category the document is in. This may not always make sense.
     #
-    def scored_categories(doc, cutoff=0.30, &block)
+    def scored_categories(doc, cutoff = 0.30, &block)
       icutoff = (@items.size * cutoff).round
-      carry = proximity_array_for_content(doc, &block)
-      carry = carry[0..icutoff-1]
+      carry = proximity_array_for_content(doc, &block)
+      carry = carry[0..icutoff - 1]
       votes = Hash.new(0.0)
       carry.each do |pair|
         @items[pair[0]].categories.each do |category|
```
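`scored_categories` converts that ranking into votes: only the `cutoff` fraction of the index nearest the document votes, weighted by proximity score. The cutoff arithmetic with invented numbers:

```ruby
items_size = 20                        # hypothetical indexed-item count
cutoff = 0.30                          # default in the signatures above
icutoff = (items_size * cutoff).round  # => 6 nearest documents get a vote
```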
```diff
@@ -273,56 +328,73 @@ module ClassifierReborn
         end
       end
 
-      votes.sort_by { |pair| pair[1] }
+      votes.sort_by { |_, score| score }
     end
 
     # Prototype, only works on indexed documents.
     # I have no clue if this is going to work, but in theory
     # it's supposed to.
-    def highest_ranked_stems(doc, count=3)
-      raise "Requested stem ranking on non-indexed content!" unless @items[doc]
-      content_vector_array = node_for_content(doc).lsi_vector.to_a
-      top_n = content_vector_array.sort.reverse[0..count-1]
-      top_n.collect { |x| @word_list.word_for_index(content_vector_array.index(x)) }
+    def highest_ranked_stems(doc, count = 3)
+      raise 'Requested stem ranking on non-indexed content!' unless @items[doc]
+
+      content_vector_array = node_for_content(doc).lsi_vector.to_a
+      top_n = content_vector_array.sort.reverse[0..count - 1]
+      top_n.collect { |x| @word_list.word_for_index(content_vector_array.index(x)) }
+    end
+
+    def reset
+      initialize(auto_rebuild: @auto_rebuild, cache_node_vectors: @cache_node_vectors)
     end
 
     private
-    def build_reduced_matrix(matrix, cutoff=0.75)
+
+    def build_reduced_matrix(matrix, cutoff = 0.75)
       # TODO: Check that M>=N on these dimensions! Transpose helps assure this
       u, v, s = matrix.SV_decomp
-
       # TODO: Better than 75% term, please. :\
       s_cutoff = s.sort.reverse[(s.size * cutoff).round - 1]
       s.size.times do |ord|
         s[ord] = 0.0 if s[ord] < s_cutoff
       end
       # Reconstruct the term document matrix, only with reduced rank
-      u * ($GSL ? GSL::Matrix : ::Matrix).diag(s) * v.trans
+      u * ($SVD == :gsl ? GSL::Matrix : ::Matrix).diag(s) * v.trans
+    end
+
+    def numo_build_reduced_matrix(matrix, cutoff = 0.75)
+      s, u, vt = Numo::Linalg.svd(matrix, driver: 'svd', job: 'S')
+
+      # TODO: Better than 75% term (as above)
+      s_cutoff = s.sort.reverse[(s.size * cutoff).round - 1]
+      s.size.times do |ord|
+        s[ord] = 0.0 if s[ord] < s_cutoff
+      end
+
+      # Reconstruct the term document matrix, only with reduced rank
+      u.dot(::Numo::DFloat.eye(s.size) * s).dot(vt)
     end
 
     def node_for_content(item, &block)
       if @items[item]
         return @items[item]
       else
-        clean_word_hash = Hasher.clean_word_hash((block ? block.call(item) : item.to_s), @language)
+        clean_word_hash = Hasher.word_hash((block ? yield(item) : item.to_s),
+                                           token_filters: @token_filters)
 
-        content_node = ContentNode.new(clean_word_hash, &block) # make the node and extract the data
+        content_node = ContentNode.new(clean_word_hash, &block) # make the node and extract the data
 
         unless needs_rebuild?
-          content_node.raw_vector_with(@word_list) # make the lsi raw and norm vectors
+          content_node.raw_vector_with(@word_list) # make the lsi raw and norm vectors
         end
       end
 
-      content_node
+      content_node
     end
 
     def make_word_list
       @word_list = WordList.new
       @items.each_value do |node|
-        node.word_hash.each_key { |key| @word_list.add_word key }
+        node.word_hash.each_key { |key| @word_list.add_word(key) }
       end
     end
-
   end
 end
-
```
data/lib/classifier-reborn/validators/classifier_validator.rb
ADDED

```diff
@@ -0,0 +1,170 @@
+# frozen_string_literal: true
+
+module ClassifierReborn
+  module ClassifierValidator
+    module_function
+
+    def cross_validate(classifier, sample_data, fold = 10, *options)
+      classifier = ClassifierReborn.const_get(classifier).new(options) if classifier.is_a?(String)
+      sample_data.shuffle!
+      partition_size = sample_data.length / fold
+      partitioned_data = sample_data.each_slice(partition_size)
+      conf_mats = []
+      fold.times do |i|
+        training_data = partitioned_data.take(fold)
+        test_data = training_data.slice!(i)
+        conf_mats << validate(classifier, training_data.flatten!(1), test_data)
+      end
+      classifier.reset
+      generate_report(conf_mats)
+    end
+
+    def validate(classifier, training_data, test_data, *options)
+      classifier = ClassifierReborn.const_get(classifier).new(options) if classifier.is_a?(String)
+      classifier.reset
+      training_data.each do |rec|
+        classifier.train(rec.first, rec.last)
+      end
+      evaluate(classifier, test_data)
+    end
+
+    def evaluate(classifier, test_data)
+      conf_mat = empty_conf_mat(classifier.categories.sort)
+      test_data.each do |rec|
+        actual = rec.first.tr('_', ' ').capitalize
+        predicted = classifier.classify(rec.last)
+        conf_mat[actual][predicted] += 1 unless predicted.nil?
+      end
+      conf_mat
+    end
+
+    def generate_report(*conf_mats)
+      conf_mats.flatten!
+      accumulated_conf_mat = conf_mats.length == 1 ? conf_mats.first : empty_conf_mat(conf_mats.first.keys.sort)
+      header = 'Run     Total   Correct Incorrect  Accuracy'
+      puts
+      puts ' Run Report '.center(header.length, '-')
+      puts header
+      puts '-' * header.length
+      if conf_mats.length > 1
+        conf_mats.each_with_index do |conf_mat, i|
+          run_report = build_run_report(conf_mat)
+          print_run_report(run_report, i + 1)
+          conf_mat.each do |actual, cols|
+            cols.each do |predicted, v|
+              accumulated_conf_mat[actual][predicted] += v
+            end
+          end
+        end
+        puts '-' * header.length
+      end
+      run_report = build_run_report(accumulated_conf_mat)
+      print_run_report(run_report, 'All')
+      puts
+      print_conf_mat(accumulated_conf_mat)
+      puts
+      conf_tab = conf_mat_to_tab(accumulated_conf_mat)
+      print_conf_tab(conf_tab)
+    end
+
+    def build_run_report(conf_mat)
+      correct = incorrect = 0
+      conf_mat.each do |actual, cols|
+        cols.each do |predicted, v|
+          if actual == predicted
+            correct += v
+          else
+            incorrect += v
+          end
+        end
+      end
+      total = correct + incorrect
+      { total: total, correct: correct, incorrect: incorrect, accuracy: divide(correct, total) }
+    end
+
+    def conf_mat_to_tab(conf_mat)
+      conf_tab = Hash.new { |h, k| h[k] = { p: { t: 0, f: 0 }, n: { t: 0, f: 0 } } }
+      conf_mat.each_key do |positive|
+        conf_mat.each do |actual, cols|
+          cols.each do |predicted, v|
+            conf_tab[positive][positive == predicted ? :p : :n][actual == predicted ? :t : :f] += v
+          end
+        end
+      end
+      conf_tab
+    end
+
+    def print_run_report(stats, prefix = '', print_header = false)
+      puts "#{'Run'.rjust([3, prefix.length].max)}     Total   Correct Incorrect  Accuracy" if print_header
+      puts "#{prefix.to_s.rjust(3)} #{stats[:total].to_s.rjust(9)} #{stats[:correct].to_s.rjust(9)} #{stats[:incorrect].to_s.rjust(9)} #{stats[:accuracy].round(5).to_s.ljust(7, '0').rjust(9)}"
+    end
+
+    def print_conf_mat(conf_mat)
+      header = ['Predicted ->'] + conf_mat.keys + %w[Total Recall]
+      cell_size = header.map(&:length).max
+      header = header.map { |h| h.rjust(cell_size) }.join(' ')
+      puts ' Confusion Matrix '.center(header.length, '-')
+      puts header
+      puts '-' * header.length
+      predicted_totals = conf_mat.keys.map { |predicted| [predicted, 0] }.to_h
+      correct = 0
+      conf_mat.each do |k, rec|
+        actual_total = rec.values.reduce(:+)
+        puts ([k.ljust(cell_size)] + rec.values.map { |v| v.to_s.rjust(cell_size) } + [actual_total.to_s.rjust(cell_size), divide(rec[k], actual_total).round(5).to_s.rjust(cell_size)]).join(' ')
+        rec.each do |cat, val|
+          predicted_totals[cat] += val
+          correct += val if cat == k
+        end
+      end
+      total = predicted_totals.values.reduce(:+)
+      puts '-' * header.length
+      puts (['Total'.ljust(cell_size)] + predicted_totals.values.map { |v| v.to_s.rjust(cell_size) } + [total.to_s.rjust(cell_size), ''.rjust(cell_size)]).join(' ')
+      puts (['Precision'.ljust(cell_size)] + predicted_totals.keys.map { |k| divide(conf_mat[k][k], predicted_totals[k]).round(5).to_s.rjust(cell_size) } + ['Accuracy ->'.rjust(cell_size), divide(correct, total).round(5).to_s.rjust(cell_size)]).join(' ')
+    end
+
+    def print_conf_tab(conf_tab)
+      conf_tab.each do |positive, tab|
+        puts "# Positive class: #{positive}"
+        derivations = conf_tab_derivations(tab)
+        print_derivations(derivations)
+        puts
+      end
+    end
+
+    def conf_tab_derivations(tab)
+      positives = tab[:p][:t] + tab[:n][:f]
+      negatives = tab[:n][:t] + tab[:p][:f]
+      total = positives + negatives
+      {
+        total_population: positives + negatives,
+        condition_positive: positives,
+        condition_negative: negatives,
+        true_positive: tab[:p][:t],
+        true_negative: tab[:n][:t],
+        false_positive: tab[:p][:f],
+        false_negative: tab[:n][:f],
+        prevalence: divide(positives, total),
+        specificity: divide(tab[:n][:t], negatives),
+        recall: divide(tab[:p][:t], positives),
+        precision: divide(tab[:p][:t], tab[:p][:t] + tab[:p][:f]),
+        accuracy: divide(tab[:p][:t] + tab[:n][:t], total),
+        f1_score: divide(2 * tab[:p][:t], 2 * tab[:p][:t] + tab[:p][:f] + tab[:n][:f])
+      }
+    end
+
+    def print_derivations(derivations)
+      max_len = derivations.keys.map(&:length).max
+      derivations.each do |k, v|
+        puts k.to_s.tr('_', ' ').capitalize.ljust(max_len) + ' : ' + v.to_s
+      end
+    end
+
+    def empty_conf_mat(categories)
+      categories.map { |actual| [actual, categories.map { |predicted| [predicted, 0] }.to_h] }.to_h
+    end
+
+    def divide(dividend, divisor)
+      divisor.zero? ? 0.0 : dividend / divisor.to_f
+    end
+  end
+end
```
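The validator is driven through `cross_validate`, which takes a classifier name (resolved via `const_get`) or instance plus `[category, text]` records; `validate` trains on `rec.first`/`rec.last`. A hedged usage sketch with invented sample data:

```ruby
require 'classifier-reborn'

sample_data = [
  ['spam', 'cheap watches for sale'],
  ['ham',  'meeting notes from tuesday'],
  # ... enough labelled records to fill every fold
]

# Prints the run report, confusion matrix, and per-class derivations above.
ClassifierReborn::ClassifierValidator.cross_validate('Bayes', sample_data, 5)
```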
data/lib/classifier-reborn.rb
CHANGED

```diff
@@ -1,3 +1,5 @@
+# frozen_string_literal: true
+
 #--
 # Copyright (c) 2005 Lucas Carlson
 #
@@ -25,6 +27,15 @@
 # License:: LGPL
 
 require 'rubygems'
+
+case RUBY_PLATFORM
+when 'java'
+  require 'jruby-stemmer'
+else
+  require 'fast-stemmer'
+end
+
 require_relative 'classifier-reborn/category_namer'
 require_relative 'classifier-reborn/bayes'
-require_relative 'classifier-reborn/lsi'
+require_relative 'classifier-reborn/lsi'
+require_relative 'classifier-reborn/validators/classifier_validator'
```