classifier 2.0.0 → 2.2.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CLAUDE.md +23 -13
- data/README.md +72 -190
- data/ext/classifier/classifier_ext.c +26 -0
- data/ext/classifier/extconf.rb +15 -0
- data/ext/classifier/incremental_svd.c +393 -0
- data/ext/classifier/linalg.h +72 -0
- data/ext/classifier/matrix.c +387 -0
- data/ext/classifier/svd.c +208 -0
- data/ext/classifier/vector.c +319 -0
- data/lib/classifier/bayes.rb +398 -54
- data/lib/classifier/errors.rb +19 -0
- data/lib/classifier/extensions/vector.rb +12 -4
- data/lib/classifier/knn.rb +351 -0
- data/lib/classifier/logistic_regression.rb +571 -0
- data/lib/classifier/lsi/content_node.rb +5 -5
- data/lib/classifier/lsi/incremental_svd.rb +166 -0
- data/lib/classifier/lsi/summary.rb +25 -5
- data/lib/classifier/lsi.rb +784 -138
- data/lib/classifier/storage/base.rb +50 -0
- data/lib/classifier/storage/file.rb +51 -0
- data/lib/classifier/storage/memory.rb +49 -0
- data/lib/classifier/storage.rb +9 -0
- data/lib/classifier/streaming/line_reader.rb +99 -0
- data/lib/classifier/streaming/progress.rb +96 -0
- data/lib/classifier/streaming.rb +122 -0
- data/lib/classifier/tfidf.rb +408 -0
- data/lib/classifier.rb +6 -0
- data/sig/vendor/json.rbs +4 -0
- data/sig/vendor/matrix.rbs +25 -14
- data/sig/vendor/mutex_m.rbs +16 -0
- data/sig/vendor/streaming.rbs +14 -0
- data/test/test_helper.rb +2 -0
- metadata +52 -8
- data/lib/classifier/extensions/vector_serialize.rb +0 -18
data/lib/classifier/bayes.rb
CHANGED
|
@@ -4,68 +4,68 @@
|
|
|
4
4
|
# Copyright:: Copyright (c) 2005 Lucas Carlson
|
|
5
5
|
# License:: LGPL
|
|
6
6
|
|
|
7
|
+
require 'json'
|
|
8
|
+
require 'mutex_m'
|
|
9
|
+
|
|
7
10
|
module Classifier
|
|
8
|
-
class Bayes
|
|
11
|
+
class Bayes # rubocop:disable Metrics/ClassLength
|
|
12
|
+
include Mutex_m
|
|
13
|
+
include Streaming
|
|
14
|
+
|
|
9
15
|
# @rbs @categories: Hash[Symbol, Hash[Symbol, Integer]]
|
|
10
16
|
# @rbs @total_words: Integer
|
|
11
17
|
# @rbs @category_counts: Hash[Symbol, Integer]
|
|
12
18
|
# @rbs @category_word_count: Hash[Symbol, Integer]
|
|
19
|
+
# @rbs @cached_training_count: Float?
|
|
20
|
+
# @rbs @cached_vocab_size: Integer?
|
|
21
|
+
# @rbs @dirty: bool
|
|
22
|
+
# @rbs @storage: Storage::Base?
|
|
23
|
+
|
|
24
|
+
attr_accessor :storage
|
|
13
25
|
|
|
14
26
|
# The class can be created with one or more categories, each of which will be
|
|
15
27
|
# initialized and given a training method. E.g.,
|
|
16
28
|
# b = Classifier::Bayes.new 'Interesting', 'Uninteresting', 'Spam'
|
|
17
|
-
#
|
|
29
|
+
# b = Classifier::Bayes.new ['Interesting', 'Uninteresting', 'Spam']
|
|
30
|
+
# @rbs (*String | Symbol | Array[String | Symbol]) -> void
|
|
18
31
|
def initialize(*categories)
  super() # Mutex_m sets up its lock here
  # Names may be passed as varargs or as one array; flatten handles both forms.
  @categories = categories.flatten.to_h { |name| [name.prepare_category_name, {}] }
  @total_words = 0
  @category_counts = Hash.new(0)
  @category_word_count = Hash.new(0)
  @cached_training_count = nil
  @cached_vocab_size = nil
  @dirty = false
  @storage = nil
end
|
|
25
43
|
|
|
26
|
-
#
|
|
27
|
-
# For example:
|
|
28
|
-
# b = Classifier::Bayes.new 'This', 'That', 'the_other'
|
|
29
|
-
# b.train :this, "This text"
|
|
30
|
-
# b.train "that", "That text"
|
|
31
|
-
# b.train "The other", "The other text"
|
|
44
|
+
# Trains the classifier with text for a category.
|
|
32
45
|
#
|
|
33
|
-
#
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
46
|
+
# b.train(spam: "Buy now!", ham: ["Hello", "Meeting tomorrow"])
|
|
47
|
+
# b.train(:spam, "legacy positional API")
|
|
48
|
+
#
|
|
49
|
+
# @rbs (?(String | Symbol)?, ?String?, **(String | Array[String])) -> void
|
|
50
|
+
def train(category = nil, text = nil, **categories)
  # A half-specified positional call (e.g. `train(:spam)`) used to be a silent
  # no-op; reject it so the caller learns nothing was trained.
  if category.nil? != text.nil?
    raise ArgumentError, 'Positional training requires both a category and a text'
  end

  train_single(category, text) unless category.nil?

  # Keyword form: each value is either one document or an array of documents.
  categories.each do |cat, texts|
    (texts.is_a?(Array) ? texts : [texts]).each { |t| train_single(cat, t) }
  end
end
|
|
44
57
|
|
|
45
|
-
#
|
|
46
|
-
# Be very careful with this method.
|
|
58
|
+
# Removes training data. Be careful with this method.
|
|
47
59
|
#
|
|
48
|
-
#
|
|
49
|
-
# b
|
|
50
|
-
# b.train :this, "This text"
|
|
51
|
-
# b.untrain :this, "This text"
|
|
60
|
+
# b.untrain(spam: "Buy now!")
|
|
61
|
+
# b.untrain(:spam, "legacy positional API")
|
|
52
62
|
#
|
|
53
|
-
# @rbs (String | Symbol
|
|
54
|
-
def untrain(category, text)
|
|
55
|
-
category
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
orig = @categories[category][word] || 0
|
|
61
|
-
@categories[category][word] ||= 0
|
|
62
|
-
@categories[category][word] -= count
|
|
63
|
-
if @categories[category][word] <= 0
|
|
64
|
-
@categories[category].delete(word)
|
|
65
|
-
count = orig
|
|
66
|
-
end
|
|
67
|
-
@category_word_count[category] -= count if @category_word_count[category] >= count
|
|
68
|
-
@total_words -= count
|
|
63
|
+
# @rbs (?(String | Symbol)?, ?String?, **(String | Array[String])) -> void
|
|
64
|
+
def untrain(category = nil, text = nil, **categories)
  # A half-specified positional call (e.g. `untrain(:spam)`) used to be a silent
  # no-op; reject it so the caller learns nothing was removed.
  if category.nil? != text.nil?
    raise ArgumentError, 'Positional untraining requires both a category and a text'
  end

  untrain_single(category, text) unless category.nil?

  # Keyword form: each value is either one document or an array of documents.
  categories.each do |cat, texts|
    (texts.is_a?(Array) ? texts : [texts]).each { |t| untrain_single(cat, t) }
  end
end
|
|
71
71
|
|
|
@@ -77,17 +77,19 @@ module Classifier
|
|
|
77
77
|
# @rbs (String) -> Hash[String, Float]
|
|
78
78
|
def classifications(text)
  words = text.word_hash.keys
  synchronize do
    training_count = cached_training_count
    vocab_size = cached_vocab_size
    # With no training at all, use a uniform prior instead of dividing by zero.
    prior_total = training_count.positive? ? training_count : 1.0

    @categories.to_h do |category, category_words|
      smoothed_total = ((@category_word_count[category] || 0) + vocab_size).to_f

      # Laplace smoothing: P(word|category) = (count + α) / (total + α * V)
      word_score = words.sum { |w| Math.log(((category_words[w] || 0) + 1) / smoothed_total) }

      # @category_counts defaults to 0, so `|| 0.1` alone never applied and an
      # untrained category scored log(0) = -Infinity. Give any non-positive
      # document count a small pseudo-count so every score stays finite.
      doc_count = @category_counts[category] || 0
      doc_count = 0.1 unless doc_count.positive?
      prior_score = Math.log(doc_count / prior_total)

      [category.to_s, word_score + prior_score]
    end
  end
end
|
|
93
95
|
|
|
@@ -104,6 +106,119 @@ module Classifier
|
|
|
104
106
|
best.first.to_s
|
|
105
107
|
end
|
|
106
108
|
|
|
109
|
+
# Returns a hash representation of the classifier state.
|
|
110
|
+
# This can be converted to JSON or used directly.
|
|
111
|
+
#
|
|
112
|
+
# @rbs (?untyped) -> untyped
|
|
113
|
+
def as_json(_options = nil)
  # JSON object keys are strings, so every symbol key is stringified.
  stringify = ->(table) { table.transform_keys(&:to_s) }
  serialized_categories = @categories.each_with_object({}) do |(name, words), out|
    out[name.to_s] = stringify.call(words)
  end

  {
    version: 1,
    type: 'bayes',
    categories: serialized_categories,
    total_words: @total_words,
    category_counts: stringify.call(@category_counts),
    category_word_count: stringify.call(@category_word_count)
  }
end
|
|
123
|
+
|
|
124
|
+
# Serializes the classifier state to a JSON string.
|
|
125
|
+
# This can be saved to a file and later loaded with Bayes.from_json.
|
|
126
|
+
#
|
|
127
|
+
# @rbs (?untyped) -> String
|
|
128
|
+
def to_json(_options = nil)
  # Generator options are accepted for call compatibility but are not forwarded.
  snapshot = as_json
  snapshot.to_json
end
|
|
131
|
+
|
|
132
|
+
# Loads a classifier from a JSON string or a Hash created by #to_json or #as_json.
|
|
133
|
+
#
|
|
134
|
+
# @rbs (String | Hash[String, untyped]) -> Bayes
|
|
135
|
+
def self.from_json(json)
  data =
    if json.is_a?(String)
      JSON.parse(json)
    else
      # #as_json returns symbol keys at the top level; normalise them so the
      # documented "Hash created by #as_json" input is actually accepted.
      json.transform_keys(&:to_s)
    end
  raise ArgumentError, "Invalid classifier type: #{data['type']}" unless data['type'] == 'bayes'

  # Bypass #initialize; restore_state sets up the lock and all state itself.
  instance = allocate
  instance.send(:restore_state, data)
  instance
end
|
|
143
|
+
|
|
144
|
+
# Saves the classifier to the configured storage.
|
|
145
|
+
# Raises ArgumentError if no storage is configured.
|
|
146
|
+
#
|
|
147
|
+
# @rbs () -> void
|
|
148
|
+
def save
  unless storage
    raise ArgumentError, 'No storage configured. Use save_to_file(path) or set storage='
  end

  payload = to_json
  storage.write(payload)
  # The persisted copy now matches memory.
  @dirty = false
end
|
|
154
|
+
|
|
155
|
+
# Saves the classifier state to a file (legacy API).
|
|
156
|
+
#
|
|
157
|
+
# @rbs (String) -> Integer
|
|
158
|
+
def save_to_file(path)
  # Returns the byte count from File.write; the dirty flag clears only after the write succeeds.
  File.write(path, to_json).tap { @dirty = false }
end
|
|
163
|
+
|
|
164
|
+
# Reloads the classifier from the configured storage.
|
|
165
|
+
# Raises UnsavedChangesError if there are unsaved changes.
|
|
166
|
+
# Use reload! to force reload and discard changes.
|
|
167
|
+
#
|
|
168
|
+
# @rbs () -> self
|
|
169
|
+
def reload
  raise ArgumentError, 'No storage configured' unless storage
  raise UnsavedChangesError, 'Unsaved changes would be lost. Call save first or use reload!' if @dirty

  # Once the dirty check has passed, the rest is exactly a forced reload.
  reload!
end
|
|
180
|
+
|
|
181
|
+
# Force reloads the classifier from storage, discarding any unsaved changes.
|
|
182
|
+
#
|
|
183
|
+
# @rbs () -> self
|
|
184
|
+
def reload!
  raise ArgumentError, 'No storage configured' unless storage

  saved = storage.read
  raise StorageError, 'No saved state found' unless saved

  # Any in-memory changes are discarded in favour of the stored copy.
  restore_from_json(saved)
  @dirty = false
  self
end
|
|
194
|
+
|
|
195
|
+
# Returns true if there are unsaved changes.
|
|
196
|
+
#
|
|
197
|
+
# @rbs () -> bool
|
|
198
|
+
def dirty?
  # Set by every mutating call; cleared by save, save_to_file and the reload methods.
  @dirty
end
|
|
201
|
+
|
|
202
|
+
# Loads a classifier from the configured storage.
|
|
203
|
+
# The storage is set on the returned instance.
|
|
204
|
+
#
|
|
205
|
+
# @rbs (storage: Storage::Base) -> Bayes
|
|
206
|
+
def self.load(storage:)
  raw = storage.read
  raise StorageError, 'No saved state found' unless raw

  # Attach the storage so later #save / #reload calls go to the same place.
  from_json(raw).tap { |classifier| classifier.storage = storage }
end
|
|
214
|
+
|
|
215
|
+
# Loads a classifier from a file (legacy API).
|
|
216
|
+
#
|
|
217
|
+
# @rbs (String) -> Bayes
|
|
218
|
+
def self.load_from_file(path)
  contents = File.read(path)
  from_json(contents)
end
|
|
221
|
+
|
|
107
222
|
#
|
|
108
223
|
# Provides training and untraining methods for the categories specified in Bayes#new
|
|
109
224
|
# For example:
|
|
@@ -134,7 +249,7 @@ module Classifier
|
|
|
134
249
|
#
|
|
135
250
|
# @rbs () -> Array[String]
|
|
136
251
|
def categories
  # Snapshot the names under the lock so concurrent training cannot mutate mid-read.
  synchronize { @categories.map { |name, _words| name.to_s } }
end
|
|
139
254
|
|
|
140
255
|
# Allows you to add categories to the classifier.
|
|
@@ -148,11 +263,31 @@ module Classifier
|
|
|
148
263
|
#
|
|
149
264
|
# @rbs (String | Symbol) -> Hash[Symbol, Integer]
|
|
150
265
|
def add_category(category)
  name = category.prepare_category_name
  synchronize do
    invalidate_caches
    @dirty = true
    # Re-adding an existing category used to replace its word table with {}
    # while @category_counts, @category_word_count and @total_words kept the
    # old numbers, leaving the model inconsistent. Keep the existing table.
    @categories[name] ||= {}
  end
end

alias append_category add_category
|
|
155
274
|
|
|
275
|
+
# Custom marshal serialization to exclude mutex state
|
|
276
|
+
# @rbs () -> Array[untyped]
|
|
277
|
+
def marshal_dump
  # Only plain data is dumped; the lock, caches and storage handle are rebuilt on load.
  state = [@categories, @total_words, @category_counts, @category_word_count]
  state << @dirty
end
|
|
280
|
+
|
|
281
|
+
# Custom marshal deserialization to recreate mutex
|
|
282
|
+
# @rbs (Array[untyped]) -> void
|
|
283
|
+
def marshal_load(data)
  # Marshal skips #initialize, so the Mutex_m lock has to be created explicitly.
  mu_initialize
  @categories = data[0]
  @total_words = data[1]
  @category_counts = data[2]
  @category_word_count = data[3]
  @dirty = data[4]
  # Derived caches and the storage handle are never serialized.
  @cached_training_count = nil
  @cached_vocab_size = nil
  @storage = nil
end
|
|
290
|
+
|
|
156
291
|
# Allows you to remove categories from the classifier.
|
|
157
292
|
# For example:
|
|
158
293
|
# b.remove_category "Spam"
|
|
@@ -163,14 +298,223 @@ module Classifier
|
|
|
163
298
|
#
|
|
164
299
|
# @rbs (String | Symbol) -> void
|
|
165
300
|
def remove_category(category)
  name = category.prepare_category_name
  synchronize do
    unless @categories.key?(name)
      raise StandardError, "No such category: #{name}"
    end

    invalidate_caches
    @dirty = true
    # Take the category's words out of the global total before dropping its tables.
    @total_words -= @category_word_count[name].to_i
    @categories.delete(name)
    @category_counts.delete(name)
    @category_word_count.delete(name)
  end
end
|
|
314
|
+
|
|
315
|
+
# Trains the classifier from an IO stream.
|
|
316
|
+
# Each line in the stream is treated as a separate document.
|
|
317
|
+
# This is memory-efficient for large corpora.
|
|
318
|
+
#
|
|
319
|
+
# @example Train from a file
|
|
320
|
+
# classifier.train_from_stream(:spam, File.open('spam_corpus.txt'))
|
|
321
|
+
#
|
|
322
|
+
# @example With progress tracking
|
|
323
|
+
# classifier.train_from_stream(:spam, io, batch_size: 500) do |progress|
|
|
324
|
+
# puts "#{progress.completed} documents processed"
|
|
325
|
+
# end
|
|
326
|
+
#
|
|
327
|
+
# @rbs (String | Symbol, IO, ?batch_size: Integer) { (Streaming::Progress) -> void } -> void
|
|
328
|
+
def train_from_stream(category, io, batch_size: Streaming::DEFAULT_BATCH_SIZE)
  name = category.prepare_category_name
  raise StandardError, "No such category: #{name}" unless @categories.key?(name)

  reader = Streaming::LineReader.new(io, batch_size: batch_size)
  progress = Streaming::Progress.new(total: reader.estimate_line_count)

  # Each batch of lines is trained under a single lock acquisition.
  reader.each_batch do |lines|
    train_batch_internal(name, lines)
    progress.completed += lines.size
    progress.current_batch += 1
    yield progress if block_given?
  end
end
|
|
343
|
+
|
|
344
|
+
# Trains the classifier with an array of documents in batches.
|
|
345
|
+
# Reduces lock contention by processing multiple documents per synchronize call.
|
|
346
|
+
#
|
|
347
|
+
# @example Positional style
|
|
348
|
+
# classifier.train_batch(:spam, documents, batch_size: 100)
|
|
349
|
+
#
|
|
350
|
+
# @example Keyword style
|
|
351
|
+
# classifier.train_batch(spam: documents, ham: other_docs, batch_size: 100)
|
|
352
|
+
#
|
|
353
|
+
# @example With progress tracking
|
|
354
|
+
# classifier.train_batch(:spam, documents, batch_size: 100) do |progress|
|
|
355
|
+
# puts "#{progress.percent}% complete"
|
|
356
|
+
# end
|
|
357
|
+
#
|
|
358
|
+
# @rbs (?(String | Symbol)?, ?Array[String]?, ?batch_size: Integer, **Array[String]) { (Streaming::Progress) -> void } -> void
|
|
359
|
+
def train_batch(category = nil, documents = nil, batch_size: Streaming::DEFAULT_BATCH_SIZE, **categories, &block)
  # A half-specified positional call used to be a silent no-op; reject it.
  if category.nil? != documents.nil?
    raise ArgumentError, 'Positional batch training requires both a category and documents'
  end

  # Array() lets a single document be passed positionally, matching the keyword form.
  train_batch_for_category(category, Array(documents), batch_size: batch_size, &block) unless category.nil?

  categories.each do |cat, docs|
    train_batch_for_category(cat, Array(docs), batch_size: batch_size, &block)
  end
end
|
|
368
|
+
|
|
369
|
+
# Loads a classifier from a checkpoint.
|
|
370
|
+
#
|
|
371
|
+
# @rbs (storage: Storage::Base, checkpoint_id: String) -> Bayes
|
|
372
|
+
def self.load_checkpoint(storage:, checkpoint_id:)
  raise ArgumentError, 'Storage must be File storage for checkpoints' unless storage.is_a?(Storage::File)

  # Checkpoints sit next to the main file: <dir>/<stem>_checkpoint_<id><ext>
  source = storage.path
  stem = File.basename(source, '.*')
  checkpoint_path = File.join(File.dirname(source), "#{stem}_checkpoint_#{checkpoint_id}#{File.extname(source)}")

  # Read from the checkpoint file, but keep future saves pointed at the main storage.
  load(storage: Storage::File.new(path: checkpoint_path)).tap do |instance|
    instance.storage = storage
  end
end
|
|
385
|
+
|
|
386
|
+
private
|
|
387
|
+
|
|
388
|
+
# Trains a batch of documents for a single category.
|
|
389
|
+
# @rbs (String | Symbol, Array[String], ?batch_size: Integer) { (Streaming::Progress) -> void } -> void
|
|
390
|
+
def train_batch_for_category(category, documents, batch_size: Streaming::DEFAULT_BATCH_SIZE)
  name = category.prepare_category_name
  raise StandardError, "No such category: #{name}" unless @categories.key?(name)

  progress = Streaming::Progress.new(total: documents.size)
  documents.each_slice(batch_size) do |slice|
    train_batch_internal(name, slice)
    progress.completed += slice.size
    progress.current_batch += 1
    yield progress if block_given?
  end
end
|
|
403
|
+
|
|
404
|
+
# Internal method to train a batch of documents.
|
|
405
|
+
# Uses a single synchronize block for the entire batch.
|
|
406
|
+
# @rbs (Symbol, Array[String]) -> void
|
|
407
|
+
def train_batch_internal(category, batch)
  # Tokenise before taking the lock (as #train_single does), so the critical
  # section covers only the shared-counter updates. A failing document now
  # aborts before any state is mutated, instead of half-applying the batch.
  word_hashes = batch.map(&:word_hash)
  synchronize do
    invalidate_caches
    @dirty = true
    word_hashes.each do |word_hash|
      @category_counts[category] += 1
      word_hash.each do |word, count|
        @categories[category][word] ||= 0
        @categories[category][word] += count
        @total_words += count
        @category_word_count[category] += count
      end
    end
  end
end
|
|
423
|
+
|
|
424
|
+
# Core training logic for a single category and text.
|
|
425
|
+
# @rbs (String | Symbol, String) -> void
|
|
426
|
+
def train_single(category, text)
  name = category.prepare_category_name
  tokens = text.word_hash # tokenise outside the lock
  synchronize do
    invalidate_caches
    @dirty = true
    @category_counts[name] += 1
    words = @categories[name]
    tokens.each do |word, count|
      words[word] = (words[word] || 0) + count
      @total_words += count
      @category_word_count[name] += count
    end
  end
end
|
|
441
|
+
|
|
442
|
+
# Core untraining logic for a single category and text.
|
|
443
|
+
# @rbs (String | Symbol, String) -> void
|
|
444
|
+
def untrain_single(category, text)
  category = category.prepare_category_name
  word_hash = text.word_hash
  synchronize do
    words = @categories[category]
    # Fail before mutating anything (previously: NoMethodError after a partial update).
    raise StandardError, "No such category: #{category}" unless words

    invalidate_caches
    @dirty = true
    # Never drive the document count below zero: a negative count makes the prior
    # in #classifications take Math.log of a negative number (Math::DomainError).
    @category_counts[category] -= 1 if (@category_counts[category] || 0).positive?

    word_hash.each do |word, count|
      next unless @total_words >= 0

      current = words[word] || 0
      # Only remove what was actually recorded for this word in this category.
      removed = [count, current].min
      if current - count <= 0
        words.delete(word)
      else
        words[word] = current - count
      end
      @category_word_count[category] -= removed if @category_word_count[category] >= removed
      @total_words -= removed
    end
  end
end
|
|
466
|
+
|
|
467
|
+
# Restores classifier state from a JSON string (used by reload)
|
|
468
|
+
# @rbs (String) -> void
|
|
469
|
+
def restore_from_json(json)
  data = JSON.parse(json)
  raise ArgumentError, "Invalid classifier type: #{data['type']}" unless data['type'] == 'bayes'

  # Build the replacement tables before taking the lock.
  categories = data['categories'].to_h { |name, words| [name.to_sym, words.transform_keys(&:to_sym)] }
  category_counts = Hash.new(0).merge!(data['category_counts'].transform_keys(&:to_sym))
  category_word_count = Hash.new(0).merge!(data['category_word_count'].transform_keys(&:to_sym))

  # Swap the state in place instead of delegating to #restore_state. That helper
  # re-creates the Mutex_m lock (mu_initialize) while this method holds it, and it
  # resets @storage to nil, which left a reloaded classifier unable to #save.
  synchronize do
    @categories = categories
    @total_words = data['total_words']
    @category_counts = category_counts
    @category_word_count = category_word_count
    @dirty = false
    invalidate_caches
  end
end
|
|
477
|
+
|
|
478
|
+
# Restores classifier state from a hash (used by from_json)
|
|
479
|
+
# @rbs (Hash[String, untyped]) -> void
|
|
480
|
+
def restore_state(data)
  # NOTE(review): mu_initialize installs a fresh Mutex_m lock. Calling this from
  # inside #synchronize lets other threads take the new lock while the old one is
  # still held — only call it on an instance no other thread can see yet.
  mu_initialize
  @categories = {} #: Hash[Symbol, Hash[Symbol, Integer]]
  @total_words = data['total_words']
  @category_counts = Hash.new(0) #: Hash[Symbol, Integer]
  @category_word_count = Hash.new(0) #: Hash[Symbol, Integer]
  @cached_training_count = nil
  @cached_vocab_size = nil
  @dirty = false
  @storage = nil

  # JSON keys arrive as strings; the in-memory tables are keyed by symbols.
  data['categories'].each_pair { |name, words| @categories[name.to_sym] = words.transform_keys(&:to_sym) }
  data['category_counts'].each_pair { |name, docs| @category_counts[name.to_sym] = docs }
  data['category_word_count'].each_pair { |name, total| @category_word_count[name.to_sym] = total }
end
|
|
503
|
+
|
|
504
|
+
# @rbs () -> void
|
|
505
|
+
def invalidate_caches
  # Both memos are derived from the training tables, so they are always dropped together.
  @cached_training_count = @cached_vocab_size = nil
end
|
|
509
|
+
|
|
510
|
+
# @rbs () -> Float
|
|
511
|
+
def cached_training_count
  # Total number of trained documents across all categories, as a Float for division.
  @cached_training_count ||= @category_counts.sum { |_category, docs| docs }.to_f
end
|
|
170
514
|
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
@
|
|
515
|
+
# @rbs () -> Integer
|
|
516
|
+
def cached_vocab_size
  # Number of distinct words across all categories, never below 1 so smoothing stays defined.
  @cached_vocab_size ||= begin
    vocabulary = {}
    @categories.each_value { |words| words.each_key { |word| vocabulary[word] = true } }
    vocabulary.empty? ? 1 : vocabulary.size
  end
end
|
|
175
519
|
end
|
|
176
520
|
end
|
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
# rbs_inline: enabled
|
|
2
|
+
|
|
3
|
+
# Author:: Lucas Carlson (mailto:lucas@rufy.com)
|
|
4
|
+
# Copyright:: Copyright (c) 2005 Lucas Carlson
|
|
5
|
+
# License:: LGPL
|
|
6
|
+
|
|
7
|
+
module Classifier
  # Root of the library's exception hierarchy; rescue it to catch any Classifier failure.
  class Error < StandardError
  end

  # Raised when a reload would overwrite unsaved in-memory changes.
  class UnsavedChangesError < Error
  end

  # Raised when a storage operation fails, e.g. no saved state is found.
  class StorageError < Error
  end

  # Raised when a model is used before it has been fitted.
  class NotFittedError < Error
  end
end
|
|
@@ -21,12 +21,20 @@ end
|
|
|
21
21
|
class Vector
|
|
22
22
|
EPSILON = 1e-10
|
|
23
23
|
|
|
24
|
+
# The matrix gem ships its own Vector#normalize. Remove it so that the definition
# further down in this class (which also handles zero vectors) is the only one.
undef_method(:normalize) if method_defined?(:normalize)
|
|
28
|
+
|
|
24
29
|
def magnitude
  # Memoised Euclidean norm.
  # NOTE(review): the memo assumes the vector is not mutated after the first call
  # (the matrix gem exposes element writers such as Vector#[]=); confirm that
  # callers never mutate vectors whose magnitude has been read.
  return @magnitude if instance_variable_defined?(:@magnitude)

  sum_of_squares = 0.to_r
  size.times { |i| sum_of_squares += self[i]**2.to_r }
  result = Math.sqrt(sum_of_squares.to_f)
  # A frozen vector cannot hold the memo (`@magnitude ||=` raised FrozenError),
  # so just return the computed value in that case.
  @magnitude = result unless frozen?
  result
end
|
|
31
39
|
|
|
32
40
|
def normalize
|