classifier 2.0.0 → 2.2.0

This diff shows the changes between two publicly available versions of this package, as released to one of the supported registries. It is provided for informational purposes only and reflects each version exactly as it appears in its respective public registry.
@@ -4,68 +4,68 @@
4
4
  # Copyright:: Copyright (c) 2005 Lucas Carlson
5
5
  # License:: LGPL
6
6
 
7
+ require 'json'
8
+ require 'mutex_m'
9
+
7
10
  module Classifier
8
- class Bayes
11
+ class Bayes # rubocop:disable Metrics/ClassLength
12
+ include Mutex_m
13
+ include Streaming
14
+
9
15
  # @rbs @categories: Hash[Symbol, Hash[Symbol, Integer]]
10
16
  # @rbs @total_words: Integer
11
17
  # @rbs @category_counts: Hash[Symbol, Integer]
12
18
  # @rbs @category_word_count: Hash[Symbol, Integer]
19
+ # @rbs @cached_training_count: Float?
20
+ # @rbs @cached_vocab_size: Integer?
21
+ # @rbs @dirty: bool
22
+ # @rbs @storage: Storage::Base?
23
+
24
+ attr_accessor :storage
13
25
 
14
26
  # The class can be created with one or more categories, each of which will be
15
27
  # initialized and given a training method. E.g.,
16
28
  # b = Classifier::Bayes.new 'Interesting', 'Uninteresting', 'Spam'
17
- # @rbs (*String | Symbol) -> void
29
+ # b = Classifier::Bayes.new ['Interesting', 'Uninteresting', 'Spam']
30
+ # @rbs (*String | Symbol | Array[String | Symbol]) -> void
18
31
  def initialize(*categories)
32
+ super()
19
33
  @categories = {}
20
- categories.each { |category| @categories[category.prepare_category_name] = {} }
34
+ categories.flatten.each { |category| @categories[category.prepare_category_name] = {} }
21
35
  @total_words = 0
22
36
  @category_counts = Hash.new(0)
23
37
  @category_word_count = Hash.new(0)
38
+ @cached_training_count = nil
39
+ @cached_vocab_size = nil
40
+ @dirty = false
41
+ @storage = nil
24
42
  end
25
43
 
26
- # Provides a general training method for all categories specified in Bayes#new
27
- # For example:
28
- # b = Classifier::Bayes.new 'This', 'That', 'the_other'
29
- # b.train :this, "This text"
30
- # b.train "that", "That text"
31
- # b.train "The other", "The other text"
44
+ # Trains the classifier with text for a category.
32
45
  #
33
- # @rbs (String | Symbol, String) -> void
34
- def train(category, text)
35
- category = category.prepare_category_name
36
- @category_counts[category] += 1
37
- text.word_hash.each do |word, count|
38
- @categories[category][word] ||= 0
39
- @categories[category][word] += count
40
- @total_words += count
41
- @category_word_count[category] += count
46
+ # b.train(spam: "Buy now!", ham: ["Hello", "Meeting tomorrow"])
47
+ # b.train(:spam, "legacy positional API")
48
+ #
49
+ # @rbs (?(String | Symbol)?, ?String?, **(String | Array[String])) -> void
50
+ def train(category = nil, text = nil, **categories)
51
+ return train_single(category, text) if category && text
52
+
53
+ categories.each do |cat, texts|
54
+ (texts.is_a?(Array) ? texts : [texts]).each { |t| train_single(cat, t) }
42
55
  end
43
56
  end
44
57
 
45
- # Provides a untraining method for all categories specified in Bayes#new
46
- # Be very careful with this method.
58
+ # Removes training data. Be careful with this method.
47
59
  #
48
- # For example:
49
- # b = Classifier::Bayes.new 'This', 'That', 'the_other'
50
- # b.train :this, "This text"
51
- # b.untrain :this, "This text"
60
+ # b.untrain(spam: "Buy now!")
61
+ # b.untrain(:spam, "legacy positional API")
52
62
  #
53
- # @rbs (String | Symbol, String) -> void
54
- def untrain(category, text)
55
- category = category.prepare_category_name
56
- @category_counts[category] -= 1
57
- text.word_hash.each do |word, count|
58
- next unless @total_words >= 0
59
-
60
- orig = @categories[category][word] || 0
61
- @categories[category][word] ||= 0
62
- @categories[category][word] -= count
63
- if @categories[category][word] <= 0
64
- @categories[category].delete(word)
65
- count = orig
66
- end
67
- @category_word_count[category] -= count if @category_word_count[category] >= count
68
- @total_words -= count
63
+ # @rbs (?(String | Symbol)?, ?String?, **(String | Array[String])) -> void
64
+ def untrain(category = nil, text = nil, **categories)
65
+ return untrain_single(category, text) if category && text
66
+
67
+ categories.each do |cat, texts|
68
+ (texts.is_a?(Array) ? texts : [texts]).each { |t| untrain_single(cat, t) }
69
69
  end
70
70
  end
71
71
 
@@ -77,17 +77,19 @@ module Classifier
77
77
  # @rbs (String) -> Hash[String, Float]
78
78
  def classifications(text)
79
79
  words = text.word_hash.keys
80
- training_count = @category_counts.values.sum.to_f
81
- vocab_size = [@categories.values.flat_map(&:keys).uniq.size, 1].max
80
+ synchronize do
81
+ training_count = cached_training_count
82
+ vocab_size = cached_vocab_size
82
83
 
83
- @categories.to_h do |category, category_words|
84
- smoothed_total = ((@category_word_count[category] || 0) + vocab_size).to_f
84
+ @categories.to_h do |category, category_words|
85
+ smoothed_total = ((@category_word_count[category] || 0) + vocab_size).to_f
85
86
 
86
- # Laplace smoothing: P(word|category) = (count + α) / (total + α * V)
87
- word_score = words.sum { |w| Math.log(((category_words[w] || 0) + 1) / smoothed_total) }
88
- prior_score = Math.log((@category_counts[category] || 0.1) / training_count)
87
+ # Laplace smoothing: P(word|category) = (count + α) / (total + α * V)
88
+ word_score = words.sum { |w| Math.log(((category_words[w] || 0) + 1) / smoothed_total) }
89
+ prior_score = Math.log((@category_counts[category] || 0.1) / training_count)
89
90
 
90
- [category.to_s, word_score + prior_score]
91
+ [category.to_s, word_score + prior_score]
92
+ end
91
93
  end
92
94
  end
93
95
 
@@ -104,6 +106,119 @@ module Classifier
104
106
  best.first.to_s
105
107
  end
106
108
 
109
+ # Returns a hash representation of the classifier state.
110
+ # This can be converted to JSON or used directly.
111
+ #
112
+ # @rbs (?untyped) -> untyped
113
+ def as_json(_options = nil)
114
+ {
115
+ version: 1,
116
+ type: 'bayes',
117
+ categories: @categories.transform_keys(&:to_s).transform_values { |v| v.transform_keys(&:to_s) },
118
+ total_words: @total_words,
119
+ category_counts: @category_counts.transform_keys(&:to_s),
120
+ category_word_count: @category_word_count.transform_keys(&:to_s)
121
+ }
122
+ end
123
+
124
+ # Serializes the classifier state to a JSON string.
125
+ # This can be saved to a file and later loaded with Bayes.from_json.
126
+ #
127
+ # @rbs (?untyped) -> String
128
+ def to_json(_options = nil)
129
+ as_json.to_json
130
+ end
131
+
132
+ # Loads a classifier from a JSON string or a Hash created by #to_json or #as_json.
133
+ #
134
+ # @rbs (String | Hash[String, untyped]) -> Bayes
135
+ def self.from_json(json)
136
+ data = json.is_a?(String) ? JSON.parse(json) : json
137
+ raise ArgumentError, "Invalid classifier type: #{data['type']}" unless data['type'] == 'bayes'
138
+
139
+ instance = allocate
140
+ instance.send(:restore_state, data)
141
+ instance
142
+ end
143
+
144
+ # Saves the classifier to the configured storage.
145
+ # Raises ArgumentError if no storage is configured.
146
+ #
147
+ # @rbs () -> void
148
+ def save
149
+ raise ArgumentError, 'No storage configured. Use save_to_file(path) or set storage=' unless storage
150
+
151
+ storage.write(to_json)
152
+ @dirty = false
153
+ end
154
+
155
+ # Saves the classifier state to a file (legacy API).
156
+ #
157
+ # @rbs (String) -> Integer
158
+ def save_to_file(path)
159
+ result = File.write(path, to_json)
160
+ @dirty = false
161
+ result
162
+ end
163
+
164
+ # Reloads the classifier from the configured storage.
165
+ # Raises UnsavedChangesError if there are unsaved changes.
166
+ # Use reload! to force reload and discard changes.
167
+ #
168
+ # @rbs () -> self
169
+ def reload
170
+ raise ArgumentError, 'No storage configured' unless storage
171
+ raise UnsavedChangesError, 'Unsaved changes would be lost. Call save first or use reload!' if @dirty
172
+
173
+ data = storage.read
174
+ raise StorageError, 'No saved state found' unless data
175
+
176
+ restore_from_json(data)
177
+ @dirty = false
178
+ self
179
+ end
180
+
181
+ # Force reloads the classifier from storage, discarding any unsaved changes.
182
+ #
183
+ # @rbs () -> self
184
+ def reload!
185
+ raise ArgumentError, 'No storage configured' unless storage
186
+
187
+ data = storage.read
188
+ raise StorageError, 'No saved state found' unless data
189
+
190
+ restore_from_json(data)
191
+ @dirty = false
192
+ self
193
+ end
194
+
195
+ # Returns true if there are unsaved changes.
196
+ #
197
+ # @rbs () -> bool
198
+ def dirty?
199
+ @dirty
200
+ end
201
+
202
+ # Loads a classifier from the configured storage.
203
+ # The storage is set on the returned instance.
204
+ #
205
+ # @rbs (storage: Storage::Base) -> Bayes
206
+ def self.load(storage:)
207
+ data = storage.read
208
+ raise StorageError, 'No saved state found' unless data
209
+
210
+ instance = from_json(data)
211
+ instance.storage = storage
212
+ instance
213
+ end
214
+
215
+ # Loads a classifier from a file (legacy API).
216
+ #
217
+ # @rbs (String) -> Bayes
218
+ def self.load_from_file(path)
219
+ from_json(File.read(path))
220
+ end
221
+
107
222
  #
108
223
  # Provides training and untraining methods for the categories specified in Bayes#new
109
224
  # For example:
@@ -134,7 +249,7 @@ module Classifier
134
249
  #
135
250
  # @rbs () -> Array[String]
136
251
  def categories
137
- @categories.keys.collect(&:to_s)
252
+ synchronize { @categories.keys.collect(&:to_s) }
138
253
  end
139
254
 
140
255
  # Allows you to add categories to the classifier.
@@ -148,11 +263,31 @@ module Classifier
148
263
  #
149
264
  # @rbs (String | Symbol) -> Hash[Symbol, Integer]
150
265
  def add_category(category)
151
- @categories[category.prepare_category_name] = {}
266
+ synchronize do
267
+ invalidate_caches
268
+ @dirty = true
269
+ @categories[category.prepare_category_name] = {}
270
+ end
152
271
  end
153
272
 
154
273
  alias append_category add_category
155
274
 
275
+ # Custom marshal serialization to exclude mutex state
276
+ # @rbs () -> Array[untyped]
277
+ def marshal_dump
278
+ [@categories, @total_words, @category_counts, @category_word_count, @dirty]
279
+ end
280
+
281
+ # Custom marshal deserialization to recreate mutex
282
+ # @rbs (Array[untyped]) -> void
283
+ def marshal_load(data)
284
+ mu_initialize
285
+ @categories, @total_words, @category_counts, @category_word_count, @dirty = data
286
+ @cached_training_count = nil
287
+ @cached_vocab_size = nil
288
+ @storage = nil
289
+ end
290
+
156
291
  # Allows you to remove categories from the classifier.
157
292
  # For example:
158
293
  # b.remove_category "Spam"
@@ -163,14 +298,223 @@ module Classifier
163
298
  #
164
299
  # @rbs (String | Symbol) -> void
165
300
  def remove_category(category)
301
+ category = category.prepare_category_name
302
+ synchronize do
303
+ raise StandardError, "No such category: #{category}" unless @categories.key?(category)
304
+
305
+ invalidate_caches
306
+ @dirty = true
307
+ @total_words -= @category_word_count[category].to_i
308
+
309
+ @categories.delete(category)
310
+ @category_counts.delete(category)
311
+ @category_word_count.delete(category)
312
+ end
313
+ end
314
+
315
+ # Trains the classifier from an IO stream.
316
+ # Each line in the stream is treated as a separate document.
317
+ # This is memory-efficient for large corpora.
318
+ #
319
+ # @example Train from a file
320
+ # classifier.train_from_stream(:spam, File.open('spam_corpus.txt'))
321
+ #
322
+ # @example With progress tracking
323
+ # classifier.train_from_stream(:spam, io, batch_size: 500) do |progress|
324
+ # puts "#{progress.completed} documents processed"
325
+ # end
326
+ #
327
+ # @rbs (String | Symbol, IO, ?batch_size: Integer) { (Streaming::Progress) -> void } -> void
328
+ def train_from_stream(category, io, batch_size: Streaming::DEFAULT_BATCH_SIZE)
166
329
  category = category.prepare_category_name
167
330
  raise StandardError, "No such category: #{category}" unless @categories.key?(category)
168
331
 
169
- @total_words -= @category_word_count[category].to_i
332
+ reader = Streaming::LineReader.new(io, batch_size: batch_size)
333
+ total = reader.estimate_line_count
334
+ progress = Streaming::Progress.new(total: total)
335
+
336
+ reader.each_batch do |batch|
337
+ train_batch_internal(category, batch)
338
+ progress.completed += batch.size
339
+ progress.current_batch += 1
340
+ yield progress if block_given?
341
+ end
342
+ end
343
+
344
+ # Trains the classifier with an array of documents in batches.
345
+ # Reduces lock contention by processing multiple documents per synchronize call.
346
+ #
347
+ # @example Positional style
348
+ # classifier.train_batch(:spam, documents, batch_size: 100)
349
+ #
350
+ # @example Keyword style
351
+ # classifier.train_batch(spam: documents, ham: other_docs, batch_size: 100)
352
+ #
353
+ # @example With progress tracking
354
+ # classifier.train_batch(:spam, documents, batch_size: 100) do |progress|
355
+ # puts "#{progress.percent}% complete"
356
+ # end
357
+ #
358
+ # @rbs (?(String | Symbol)?, ?Array[String]?, ?batch_size: Integer, **Array[String]) { (Streaming::Progress) -> void } -> void
359
+ def train_batch(category = nil, documents = nil, batch_size: Streaming::DEFAULT_BATCH_SIZE, **categories, &block)
360
+ if category && documents
361
+ train_batch_for_category(category, documents, batch_size: batch_size, &block)
362
+ else
363
+ categories.each do |cat, docs|
364
+ train_batch_for_category(cat, Array(docs), batch_size: batch_size, &block)
365
+ end
366
+ end
367
+ end
368
+
369
+ # Loads a classifier from a checkpoint.
370
+ #
371
+ # @rbs (storage: Storage::Base, checkpoint_id: String) -> Bayes
372
+ def self.load_checkpoint(storage:, checkpoint_id:)
373
+ raise ArgumentError, 'Storage must be File storage for checkpoints' unless storage.is_a?(Storage::File)
374
+
375
+ dir = File.dirname(storage.path)
376
+ base = File.basename(storage.path, '.*')
377
+ ext = File.extname(storage.path)
378
+ checkpoint_path = File.join(dir, "#{base}_checkpoint_#{checkpoint_id}#{ext}")
379
+
380
+ checkpoint_storage = Storage::File.new(path: checkpoint_path)
381
+ instance = load(storage: checkpoint_storage)
382
+ instance.storage = storage
383
+ instance
384
+ end
385
+
386
+ private
387
+
388
+ # Trains a batch of documents for a single category.
389
+ # @rbs (String | Symbol, Array[String], ?batch_size: Integer) { (Streaming::Progress) -> void } -> void
390
+ def train_batch_for_category(category, documents, batch_size: Streaming::DEFAULT_BATCH_SIZE)
391
+ category = category.prepare_category_name
392
+ raise StandardError, "No such category: #{category}" unless @categories.key?(category)
393
+
394
+ progress = Streaming::Progress.new(total: documents.size)
395
+
396
+ documents.each_slice(batch_size) do |batch|
397
+ train_batch_internal(category, batch)
398
+ progress.completed += batch.size
399
+ progress.current_batch += 1
400
+ yield progress if block_given?
401
+ end
402
+ end
403
+
404
+ # Internal method to train a batch of documents.
405
+ # Uses a single synchronize block for the entire batch.
406
+ # @rbs (Symbol, Array[String]) -> void
407
+ def train_batch_internal(category, batch)
408
+ synchronize do
409
+ invalidate_caches
410
+ @dirty = true
411
+ batch.each do |text|
412
+ word_hash = text.word_hash
413
+ @category_counts[category] += 1
414
+ word_hash.each do |word, count|
415
+ @categories[category][word] ||= 0
416
+ @categories[category][word] += count
417
+ @total_words += count
418
+ @category_word_count[category] += count
419
+ end
420
+ end
421
+ end
422
+ end
423
+
424
+ # Core training logic for a single category and text.
425
+ # @rbs (String | Symbol, String) -> void
426
+ def train_single(category, text)
427
+ category = category.prepare_category_name
428
+ word_hash = text.word_hash
429
+ synchronize do
430
+ invalidate_caches
431
+ @dirty = true
432
+ @category_counts[category] += 1
433
+ word_hash.each do |word, count|
434
+ @categories[category][word] ||= 0
435
+ @categories[category][word] += count
436
+ @total_words += count
437
+ @category_word_count[category] += count
438
+ end
439
+ end
440
+ end
441
+
442
+ # Core untraining logic for a single category and text.
443
+ # @rbs (String | Symbol, String) -> void
444
+ def untrain_single(category, text)
445
+ category = category.prepare_category_name
446
+ word_hash = text.word_hash
447
+ synchronize do
448
+ invalidate_caches
449
+ @dirty = true
450
+ @category_counts[category] -= 1
451
+ word_hash.each do |word, count|
452
+ next unless @total_words >= 0
453
+
454
+ orig = @categories[category][word] || 0
455
+ @categories[category][word] ||= 0
456
+ @categories[category][word] -= count
457
+ if @categories[category][word] <= 0
458
+ @categories[category].delete(word)
459
+ count = orig
460
+ end
461
+ @category_word_count[category] -= count if @category_word_count[category] >= count
462
+ @total_words -= count
463
+ end
464
+ end
465
+ end
466
+
467
+ # Restores classifier state from a JSON string (used by reload)
468
+ # @rbs (String) -> void
469
+ def restore_from_json(json)
470
+ data = JSON.parse(json)
471
+ raise ArgumentError, "Invalid classifier type: #{data['type']}" unless data['type'] == 'bayes'
472
+
473
+ synchronize do
474
+ restore_state(data)
475
+ end
476
+ end
477
+
478
+ # Restores classifier state from a hash (used by from_json)
479
+ # @rbs (Hash[String, untyped]) -> void
480
+ def restore_state(data)
481
+ mu_initialize
482
+ @categories = {} #: Hash[Symbol, Hash[Symbol, Integer]]
483
+ @total_words = data['total_words']
484
+ @category_counts = Hash.new(0) #: Hash[Symbol, Integer]
485
+ @category_word_count = Hash.new(0) #: Hash[Symbol, Integer]
486
+ @cached_training_count = nil
487
+ @cached_vocab_size = nil
488
+ @dirty = false
489
+ @storage = nil
490
+
491
+ data['categories'].each do |cat_name, words|
492
+ @categories[cat_name.to_sym] = words.transform_keys(&:to_sym)
493
+ end
494
+
495
+ data['category_counts'].each do |cat_name, count|
496
+ @category_counts[cat_name.to_sym] = count
497
+ end
498
+
499
+ data['category_word_count'].each do |cat_name, count|
500
+ @category_word_count[cat_name.to_sym] = count
501
+ end
502
+ end
503
+
504
+ # @rbs () -> void
505
+ def invalidate_caches
506
+ @cached_training_count = nil
507
+ @cached_vocab_size = nil
508
+ end
509
+
510
+ # @rbs () -> Float
511
+ def cached_training_count
512
+ @cached_training_count ||= @category_counts.values.sum.to_f
513
+ end
170
514
 
171
- @categories.delete(category)
172
- @category_counts.delete(category)
173
- @category_word_count.delete(category)
515
+ # @rbs () -> Integer
516
+ def cached_vocab_size
517
+ @cached_vocab_size ||= [@categories.values.flat_map(&:keys).uniq.size, 1].max
174
518
  end
175
519
  end
176
520
  end
@@ -0,0 +1,19 @@
1
+ # rbs_inline: enabled
2
+
3
+ # Author:: Lucas Carlson (mailto:lucas@rufy.com)
4
+ # Copyright:: Copyright (c) 2005 Lucas Carlson
5
+ # License:: LGPL
6
+
7
+ module Classifier
8
+ # Base error class for all Classifier errors
9
+ class Error < StandardError; end
10
+
11
+ # Raised when reload would discard unsaved changes
12
+ class UnsavedChangesError < Error; end
13
+
14
+ # Raised when a storage operation fails
15
+ class StorageError < Error; end
16
+
17
+ # Raised when using an unfitted model
18
+ class NotFittedError < Error; end
19
+ end
@@ -21,12 +21,20 @@ end
21
21
  class Vector
22
22
  EPSILON = 1e-10
23
23
 
24
+ # Cache magnitude since Vector is immutable after creation
25
+ # Note: We undefine the matrix gem's normalize method first, then redefine it
26
+ # to provide a more robust implementation that handles zero vectors
27
+ undef_method :normalize if method_defined?(:normalize)
28
+
24
29
  def magnitude
25
- sum_of_squares = 0.to_r
26
- size.times do |i|
27
- sum_of_squares += self[i]**2.to_r
30
+ # Cache magnitude since Vector is immutable after creation
31
+ @magnitude ||= begin
32
+ sum_of_squares = 0.to_r
33
+ size.times do |i|
34
+ sum_of_squares += self[i]**2.to_r
35
+ end
36
+ Math.sqrt(sum_of_squares.to_f)
28
37
  end
29
- Math.sqrt(sum_of_squares.to_f)
30
38
  end
31
39
 
32
40
  def normalize