classifier 2.1.0 → 2.2.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,571 @@
1
+ # rbs_inline: enabled
2
+
3
+ # Author:: Lucas Carlson (mailto:lucas@rufy.com)
4
+ # Copyright:: Copyright (c) 2024 Lucas Carlson
5
+ # License:: LGPL
6
+
7
+ require 'json'
8
+ require 'mutex_m'
9
+
10
+ module Classifier
11
+ # Logistic Regression (MaxEnt) classifier using Stochastic Gradient Descent.
12
+ # Often provides better accuracy than Naive Bayes while remaining fast and interpretable.
13
+ #
14
+ # Example:
15
+ # classifier = Classifier::LogisticRegression.new(:spam, :ham)
16
+ # classifier.train(spam: ["Buy now!", "Free money!!!"])
17
+ # classifier.train(ham: ["Meeting tomorrow", "Project update"])
18
+ # classifier.classify("Claim your prize!") # => "Spam"
19
+ # classifier.probabilities("Claim your prize!") # => {"Spam" => 0.92, "Ham" => 0.08}
20
+ #
21
+ class LogisticRegression # rubocop:disable Metrics/ClassLength
22
# Mutex_m supplies #synchronize (used below to guard the shared model state)
# and #mu_initialize (called when state is restored without #initialize).
include Mutex_m
# NOTE(review): Streaming is defined elsewhere in the gem; this class uses its
# LineReader, Progress and DEFAULT_BATCH_SIZE — confirm what else it mixes in.
include Streaming

# @rbs @categories: Array[Symbol]
# @rbs @weights: Hash[Symbol, Hash[Symbol, Float]]
# @rbs @bias: Hash[Symbol, Float]
# @rbs @vocabulary: Hash[Symbol, bool]
# @rbs @training_data: Array[{category: Symbol, features: Hash[Symbol, Integer]}]
# @rbs @learning_rate: Float
# @rbs @regularization: Float
# @rbs @max_iterations: Integer
# @rbs @tolerance: Float
# @rbs @fitted: bool
# @rbs @dirty: bool
# @rbs @storage: Storage::Base?

# Storage backend used by #save, #reload and #reload!; it must respond to
# #read and #write.
attr_accessor :storage

# Defaults for the keyword options accepted by #initialize.
DEFAULT_LEARNING_RATE = 0.1
DEFAULT_REGULARIZATION = 0.01
DEFAULT_MAX_ITERATIONS = 100
DEFAULT_TOLERANCE = 1e-4
44
+
45
# Creates a new Logistic Regression classifier with the specified categories.
#
#   classifier = Classifier::LogisticRegression.new(:spam, :ham)
#   classifier = Classifier::LogisticRegression.new('Positive', 'Negative', 'Neutral')
#   classifier = Classifier::LogisticRegression.new(['Positive', 'Negative', 'Neutral'])
#
# Categories may be passed as separate arguments or as (nested) arrays. Each
# name is normalized with String#prepare_category_name, and names that
# normalize to the same value are merged into one category.
#
# Options:
# - learning_rate: Step size for gradient descent (default: 0.1)
# - regularization: L2 regularization strength (default: 0.01)
# - max_iterations: Maximum training epochs per fit (default: 100)
# - tolerance: Stop early once the epoch loss changes by less than this (default: 1e-4)
#
# Raises ArgumentError unless at least two distinct categories remain.
#
# @rbs (*String | Symbol | Array[String | Symbol], ?learning_rate: Float, ?regularization: Float,
#   ?max_iterations: Integer, ?tolerance: Float) -> void
def initialize(*categories, learning_rate: DEFAULT_LEARNING_RATE,
               regularization: DEFAULT_REGULARIZATION,
               max_iterations: DEFAULT_MAX_ITERATIONS,
               tolerance: DEFAULT_TOLERANCE)
  super()
  names = categories.flatten.map { |c| c.to_s.prepare_category_name }.uniq
  # Validate after normalizing. Duplicate names would share one weight/bias
  # hash entry and silently yield a single-class model.
  raise ArgumentError, 'At least two categories required' if names.size < 2

  @categories = names
  @weights = @categories.to_h { |c| [c, {}] }
  @bias = @categories.to_h { |c| [c, 0.0] }
  @vocabulary = {}
  @training_data = []
  @learning_rate = learning_rate
  @regularization = regularization
  @max_iterations = max_iterations
  @tolerance = tolerance
  @fitted = false
  @dirty = false
  @storage = nil
end
80
+
81
# Trains the classifier on one or more texts.
#
#   classifier.train(spam: "Buy now!", ham: ["Hello", "Meeting tomorrow"])
#   classifier.train(:spam, "legacy positional API")
#
# The positional form is used only when both a category and a text are given.
# Otherwise each keyword pair is trained, and a keyword value may be a single
# String or an Array of Strings.
#
# @rbs (?(String | Symbol)?, ?String?, **(String | Array[String])) -> void
def train(category = nil, text = nil, **categories)
  if category && text
    train_single(category, text)
  else
    categories.each do |cat, texts|
      batch = texts.is_a?(Array) ? texts : [texts]
      batch.each { |t| train_single(cat, t) }
    end
  end
end
94
+
95
# Fits the model to the training samples accumulated since the last fit.
# Runs automatically when the model is not fitted: classify (through
# probabilities), classifications, weights, as_json and marshal_dump all
# call it first.
#
# Fitting changes only in-memory state, so it deliberately leaves the dirty
# flag alone. Newly fitted weights still count as unsaved until #save or
# #save_to_file runs, which keeps #reload from silently discarding them.
#
# Returns self. With no pending samples this is a no-op and the model is not
# marked as fitted.
#
# @rbs () -> self
def fit
  synchronize do
    return self if @training_data.empty?

    optimize_weights
    @fitted = true
  end
  self
end
109
+
110
# Returns the name of the most probable category for +text+. When several
# categories tie, the first one in category order wins.
#
#   classifier.classify("Buy now!") # => "Spam"
#
# Raises StandardError if no probabilities are available.
#
# @rbs (String) -> String
def classify(text)
  top = probabilities(text).max_by(&:last)
  raise StandardError, 'No classifications available' if top.nil?

  top.first
end
122
+
123
# Returns the probability of each category for +text+, keyed by category
# name. The per-category scores are normalized with a softmax.
#
#   classifier.probabilities("Buy now!")
#   # => {"Spam" => 0.92, "Ham" => 0.08}
#
# Fits the model first if it is not fitted yet.
#
# @rbs (String) -> Hash[String, Float]
def probabilities(text)
  fit unless @fitted

  word_counts = text.word_hash
  synchronize { softmax(compute_scores(word_counts)) }
end
138
+
139
# Returns the raw per-category scores for +text+, keyed by category name.
# These are the values that #probabilities passes through softmax.
# Fits the model first if it is not fitted yet.
#
# @rbs (String) -> Hash[String, Float]
def classifications(text)
  fit unless @fitted

  word_counts = text.word_hash
  raw_scores = synchronize { compute_scores(word_counts) }
  raw_scores.transform_keys(&:to_s)
end
150
+
151
# Returns the learned feature weights for +category+, ordered by absolute
# value (largest first). Positive weights favor the category and negative
# weights count against it. Pass +limit+ to keep only the top entries.
#
#   classifier.weights(:spam)
#   # => {:free => 2.3, :buy => 1.8, :money => 1.5, ...}
#
# Fits the model first if needed. Raises StandardError for an unknown category.
#
# @rbs (String | Symbol, ?limit: Integer?) -> Hash[Symbol, Float]
def weights(category, limit: nil)
  fit unless @fitted

  name = category.to_s.prepare_category_name
  raise StandardError, "No such category: #{name}" unless @weights.key?(name)

  ranked = @weights[name].sort_by { |_word, weight| -weight.abs }
  (limit ? ranked.first(limit) : ranked).to_h
end
168
+
169
# Returns the category names as Strings, in the order they were defined.
#
# @rbs () -> Array[String]
def categories
  synchronize do
    @categories.map { |name| name.to_s }
  end
end

# Reports whether the model is currently fitted. #fit and state restoration
# set this flag; training new samples clears it.
#
# @rbs () -> bool
def fitted?
  @fitted
end

# Reports whether there are unsaved changes. Training sets this flag;
# #save, #save_to_file and reloading clear it.
#
# @rbs () -> bool
def dirty?
  @dirty
end
189
+
190
# Provides dynamic training methods for the categories:
#
#   classifier.train_spam "Buy now!"
#   classifier.train_ham "Hello", "Meeting tomorrow"
#
# Only method names of exactly the form `train_<name>` are intercepted. The
# pattern is anchored, so names that merely contain "train_" fall through to
# the default NoMethodError. Each argument is trained as one text.
# Raises StandardError if <name> does not normalize to a known category.
def method_missing(name, *args)
  category_match = name.to_s.match(/\Atrain_(\w+)\z/)
  return super unless category_match

  category = category_match[1].prepare_category_name
  raise StandardError, "No such category: #{category}" unless @categories.include?(category)

  args.each { |text| train(category, text) }
end

# Keeps respond_to? consistent with method_missing: `train_<name>` is
# reported only when <name> normalizes to one of this classifier's
# categories. Private helpers such as train_single are therefore not
# advertised as public dynamic methods.
#
# @rbs (Symbol, ?bool) -> bool
def respond_to_missing?(name, include_private = false)
  category_match = name.to_s.match(/\Atrain_(\w+)\z/)
  return super unless category_match

  @categories.include?(category_match[1].prepare_category_name) || super
end
206
+
207
# Returns the serializable classifier state as a Hash with symbol keys.
# Category names, weight keys and vocabulary words are converted to Strings.
# Fits the model first if it is not fitted yet. Training samples that have
# not been fitted are not part of the output.
#
# @rbs (?untyped) -> Hash[Symbol, untyped]
def as_json(_options = nil)
  fit unless @fitted

  stringify_keys = ->(hash) { hash.transform_keys(&:to_s) }
  {
    version: 1,
    type: 'logistic_regression',
    categories: @categories.map(&:to_s),
    weights: stringify_keys.call(@weights).transform_values(&stringify_keys),
    bias: stringify_keys.call(@bias),
    vocabulary: @vocabulary.keys.map(&:to_s),
    learning_rate: @learning_rate,
    regularization: @regularization,
    max_iterations: @max_iterations,
    tolerance: @tolerance
  }
end
226
+
227
# Serializes the classifier state (see #as_json) to a JSON string.
# The options argument is accepted for JSON API compatibility and ignored.
#
# @rbs (?untyped) -> String
def to_json(_options = nil)
  payload = as_json
  JSON.generate(payload)
end
233
+
234
# Builds a classifier from a JSON string or an already-parsed Hash.
# The instance is allocated without running #initialize, and its state comes
# entirely from the data. Raises ArgumentError unless data['type'] is
# 'logistic_regression'.
#
# @rbs (String | Hash[String, untyped]) -> LogisticRegression
def self.from_json(json)
  data = case json
         when String then JSON.parse(json)
         else json
         end
  unless data['type'] == 'logistic_regression'
    raise ArgumentError, "Invalid classifier type: #{data['type']}"
  end

  allocate.tap do |instance|
    instance.send(:restore_state, data, data['categories'].map(&:to_sym))
  end
end
246
+
247
# Writes the JSON-serialized classifier to the configured storage and clears
# the dirty flag. Raises ArgumentError when no storage is configured.
#
# @rbs () -> void
def save
  backend = storage
  raise ArgumentError, 'No storage configured' unless backend

  backend.write(to_json)
  @dirty = false
end
256
+
257
# Writes the JSON-serialized classifier to +path+, clears the dirty flag and
# returns the number of bytes written (the result of File.write).
#
# @rbs (String) -> Integer
def save_to_file(path)
  File.write(path, to_json).tap { @dirty = false }
end
265
+
266
# Builds a classifier from the state held in +storage+ and attaches that
# storage to the new instance, so #save and #reload target it.
# Raises StorageError when the storage has no saved state.
#
# @rbs (storage: Storage::Base) -> LogisticRegression
def self.load(storage:)
  payload = storage.read
  raise StorageError, 'No saved state found' unless payload

  from_json(payload).tap { |classifier| classifier.storage = storage }
end
277
+
278
# Builds a classifier from the JSON file at +path+. No storage is attached
# to the returned instance.
#
# @rbs (String) -> LogisticRegression
def self.load_from_file(path)
  json = File.read(path)
  from_json(json)
end
284
+
285
# Reloads the classifier from storage, refusing to discard unsaved changes.
#
# Raises ArgumentError when no storage is configured and UnsavedChangesError
# when the dirty flag is set (call #save first, or use #reload! to discard
# the changes). Otherwise it does exactly what #reload! does and returns self.
#
# @rbs () -> self
def reload
  raise ArgumentError, 'No storage configured' unless storage
  raise UnsavedChangesError, 'Unsaved changes would be lost. Call save first or use reload!' if @dirty

  reload!
end
299
+
300
# Replaces the in-memory state with what storage holds, discarding any
# unsaved changes, and clears the dirty flag. Returns self.
#
# Raises ArgumentError when no storage is configured and StorageError when
# the storage has no saved state.
#
# @rbs () -> self
def reload!
  backend = storage
  raise ArgumentError, 'No storage configured' unless backend

  snapshot = backend.read
  raise StorageError, 'No saved state found' unless snapshot

  restore_from_json(snapshot)
  @dirty = false
  self
end
313
+
314
# Marshal hook that returns the model state as an Array. The mutex, pending
# training samples, dirty flag and storage are deliberately left out.
# Fits the model first if it is not fitted yet.
#
# @rbs () -> Array[untyped]
def marshal_dump
  fit unless @fitted
  %i[@categories @weights @bias @vocabulary @learning_rate @regularization
     @max_iterations @tolerance @fitted].map { |ivar| instance_variable_get(ivar) }
end
322
+
323
# Marshal hook that rebuilds an instance from the Array produced by
# #marshal_dump. It creates a fresh mutex and resets transient state: no
# pending samples, not dirty, no storage attached.
#
# @rbs (Array[untyped]) -> void
def marshal_load(data)
  mu_initialize
  @categories, @weights, @bias, @vocabulary, @learning_rate, @regularization,
    @max_iterations, @tolerance, @fitted = data
  # Transient state is never marshaled, so start from a clean slate.
  @dirty = false
  @training_data = []
  @storage = nil
end
334
+
335
# Loads the checkpoint stored next to +storage+'s file and attaches the
# original (non-checkpoint) storage to the result. The checkpoint path is
# "<dir>/<base>_checkpoint_<checkpoint_id><ext>".
#
# Raises ArgumentError unless +storage+ is a Storage::File.
#
# @rbs (storage: Storage::Base, checkpoint_id: String) -> LogisticRegression
def self.load_checkpoint(storage:, checkpoint_id:)
  raise ArgumentError, 'Storage must be File storage for checkpoints' unless storage.is_a?(Storage::File)

  path = storage.path
  stem = File.basename(path, '.*')
  checkpoint_file = "#{stem}_checkpoint_#{checkpoint_id}#{File.extname(path)}"
  checkpoint_storage = Storage::File.new(path: File.join(File.dirname(path), checkpoint_file))

  load(storage: checkpoint_storage).tap { |instance| instance.storage = storage }
end
351
+
352
# Trains the classifier from an IO stream, treating each line as a separate
# document. Lines are read in batches of +batch_size+. Each batch is added
# under the lock and marks the model unfitted and dirty, and a
# Streaming::Progress is yielded after every batch when a block is given.
# The model is NOT fitted here: call #fit afterwards, or let
# classify/probabilities fit it lazily.
#
# @example Train from a file
#   classifier.train_from_stream(:spam, File.open('spam_corpus.txt'))
#   classifier.fit
#
# @example With progress tracking
#   classifier.train_from_stream(:spam, io, batch_size: 500) do |progress|
#     puts "#{progress.completed} documents processed"
#   end
#   classifier.fit
#
# Raises StandardError for an unknown category.
#
# @rbs (String | Symbol, IO, ?batch_size: Integer) { (Streaming::Progress) -> void } -> void
def train_from_stream(category, io, batch_size: Streaming::DEFAULT_BATCH_SIZE)
  label = category.to_s.prepare_category_name
  raise StandardError, "No such category: #{label}" unless @categories.include?(label)

  reader = Streaming::LineReader.new(io, batch_size: batch_size)
  progress = Streaming::Progress.new(total: reader.estimate_line_count)

  reader.each_batch do |lines|
    synchronize do
      lines.each do |line|
        counts = line.word_hash
        counts.each_key { |word| @vocabulary[word] = true }
        @training_data << { category: label, features: counts }
      end
      @fitted = false
      @dirty = true
    end
    progress.completed += lines.size
    progress.current_batch += 1
    yield progress if block_given?
  end
end
391
+
392
# Trains the classifier on arrays of documents, +batch_size+ at a time.
# The model is NOT fitted here: call #fit afterwards, or let
# classify/probabilities fit it lazily.
#
# @example Positional style
#   classifier.train_batch(:spam, documents, batch_size: 100)
#   classifier.fit
#
# @example Keyword style
#   classifier.train_batch(spam: documents, ham: other_docs)
#   classifier.fit
#
# The positional form is used only when both category and documents are
# given. Otherwise every keyword pair is trained, and a lone String is
# wrapped with Array(). The optional block receives Streaming::Progress
# updates after each batch.
#
# @rbs (?(String | Symbol)?, ?Array[String]?, ?batch_size: Integer, **Array[String]) { (Streaming::Progress) -> void } -> void
def train_batch(category = nil, documents = nil, batch_size: Streaming::DEFAULT_BATCH_SIZE, **categories, &block)
  return train_batch_for_category(category, documents, batch_size: batch_size, &block) if category && documents

  categories.each do |name, docs|
    train_batch_for_category(name, Array(docs), batch_size: batch_size, &block)
  end
end
414
+
415
private

# Appends +documents+ to the pending training samples of one category,
# +batch_size+ documents at a time. Each batch is added under the lock and
# marks the model unfitted and dirty. A Streaming::Progress is yielded after
# every batch when a block is given. The model is not fitted here.
#
# Raises StandardError for an unknown category.
#
# @rbs (String | Symbol, Array[String], ?batch_size: Integer) { (Streaming::Progress) -> void } -> void
def train_batch_for_category(category, documents, batch_size: Streaming::DEFAULT_BATCH_SIZE)
  label = category.to_s.prepare_category_name
  raise StandardError, "No such category: #{label}" unless @categories.include?(label)

  progress = Streaming::Progress.new(total: documents.size)

  documents.each_slice(batch_size) do |slice|
    synchronize do
      slice.each do |doc|
        counts = doc.word_hash
        counts.each_key { |word| @vocabulary[word] = true }
        @training_data << { category: label, features: counts }
      end
      @fitted = false
      @dirty = true
    end
    progress.completed += slice.size
    progress.current_batch += 1
    yield progress if block_given?
  end
end
440
+
441
# Records one text as a pending training sample for +category+. It adds the
# text's words to the vocabulary and marks the model unfitted and dirty, all
# under the lock. Raises StandardError for an unknown category.
#
# @rbs (String | Symbol, String) -> void
def train_single(category, text)
  label = category.to_s.prepare_category_name
  raise StandardError, "No such category: #{label}" unless @categories.include?(label)

  counts = text.word_hash
  synchronize do
    counts.each_key { |word| @vocabulary[word] = true }
    @training_data << { category: label, features: counts }
    @fitted = false
    @dirty = true
  end
end
455
+
456
# Optimizes the weights with per-sample SGD plus L2 regularization.
#
# Runs up to @max_iterations epochs over the pending samples and stops early
# once the epoch loss changes by less than @tolerance. The pending samples
# are then discarded. initialize_weights only fills in missing weights, so
# weights from an earlier fit carry over, and a later fit trains only on
# samples added since then.
#
# @rbs () -> void
def optimize_weights
  return if @training_data.empty?

  initialize_weights
  last_loss = Float::INFINITY

  @max_iterations.times do
    epoch_loss = run_training_epoch
    converged = (last_loss - epoch_loss).abs < @tolerance
    break if converged

    last_loss = epoch_loss
  end

  @training_data = []
end
473
+
474
# Gives every vocabulary word a 0.0 starting weight in each category.
# Existing (truthy) weights, such as those learned by an earlier fit, are
# left untouched.
#
# @rbs () -> void
def initialize_weights
  @vocabulary.each_key do |word|
    @categories.each do |category|
      row = @weights[category]
      row[word] = 0.0 unless row[word]
    end
  end
end
480
+
481
# Runs one SGD pass over the pending samples in random order and returns the
# epoch loss. The loss is the summed cross-entropy, -log p(true category),
# measured before each sample's update and clamped at 1e-15 to avoid log(0),
# plus the current L2 penalty.
#
# @rbs () -> Float
def run_training_epoch
  data_loss = @training_data.shuffle.inject(0.0) do |loss, sample|
    features = sample[:features]
    label = sample[:category]
    probs = softmax(compute_scores(features))
    update_weights(features, label, probs)
    loss - Math.log([probs[label.to_s], 1e-15].max)
  end

  data_loss + l2_penalty
end
493
+
494
# Applies one SGD step for a single sample. For each category the error is
# p(category) - 1[category is the true one], which is the cross-entropy
# gradient with respect to that category's score. The bias is not
# regularized. L2 decay is applied only to the weights of words present in
# the sample, and missing weights are treated as 0.0.
#
# @rbs (Hash[Symbol, Integer], Symbol, Hash[String, Float]) -> void
def update_weights(features, true_category, probs)
  @categories.each do |category|
    target = category == true_category ? 1.0 : 0.0
    error = probs[category.to_s] - target
    @bias[category] -= @learning_rate * error

    row = @weights[category]
    features.each do |word, count|
      current = row[word] || 0.0
      row[word] = current - (@learning_rate * ((error * count) + (@regularization * current)))
    end
  end
end
507
+
508
# Returns the L2 penalty 0.5 * @regularization * sum(w^2) over every weight
# in every category. It is accumulated term by term, in hash order.
#
# @rbs () -> Float
def l2_penalty
  @weights.each_value.inject(0.0) do |total, row|
    row.each_value.inject(total) { |acc, w| acc + (0.5 * @regularization * w * w) }
  end
end
516
+
517
# Computes each category's raw score (logit): its bias plus the sum of
# weight * count over the words in +features+. Words without a weight
# contribute 0.0.
#
# @rbs (Hash[Symbol, Integer]) -> Hash[Symbol, Float]
def compute_scores(features)
  @categories.each_with_object({}) do |category, scores|
    row = @weights[category]
    scores[category] = features.inject(@bias[category]) do |acc, (word, count)|
      acc + ((row[word] || 0.0) * count)
    end
  end
end
526
+
527
# Converts raw scores into probabilities that sum to 1, and stringifies the
# keys. An empty input returns an empty Hash.
#
# @rbs (Hash[Symbol, Float]) -> Hash[String, Float]
def softmax(scores)
  # Subtracting the largest score keeps Math.exp from overflowing; the shift
  # cancels out during normalization.
  shift = scores.values.max || 0.0
  exponentials = scores.each_with_object({}) do |(category, score), acc|
    acc[category] = Math.exp(score - shift)
  end
  total = exponentials.values.sum.to_f
  exponentials.each_with_object({}) do |(category, value), probs|
    probs[category.to_s] = (value / total).to_f
  end
end
535
+
536
# Restores the classifier state from a JSON string, in place.
#
# Used by #reload and #reload!. restore_state resets @storage to nil, which
# suits from_json because there the instance is freshly allocated. Here the
# configured storage is captured first and put back afterwards. Otherwise a
# reload would leave the classifier unable to save or reload again
# ("No storage configured").
#
# @rbs (String) -> void
def restore_from_json(json)
  data = JSON.parse(json)
  categories = data['categories'].map(&:to_sym)
  configured_storage = @storage
  restore_state(data, categories)
  @storage = configured_storage
end
543
+
544
# Replaces the whole classifier state with parsed JSON +data+, as produced
# by #as_json. It creates a fresh mutex and converts names to Symbols and
# weights/biases to Floats. The result is marked fitted and clean, with no
# pending samples and no storage attached.
#
# @rbs (Hash[String, untyped], Array[Symbol]) -> void
def restore_state(data, categories)
  mu_initialize
  @categories = categories
  @weights = {}
  @bias = {}

  data['weights'].each_pair do |name, word_weights|
    @weights[name.to_sym] = word_weights.each_with_object({}) do |(word, weight), row|
      row[word.to_sym] = weight.to_f
    end
  end
  data['bias'].each_pair { |name, value| @bias[name.to_sym] = value.to_f }

  @vocabulary = data['vocabulary'].each_with_object({}) { |word, vocab| vocab[word.to_sym] = true }
  @learning_rate, @regularization, @max_iterations, @tolerance =
    data.values_at('learning_rate', 'regularization', 'max_iterations', 'tolerance')
  @training_data = []
  @fitted = true
  @dirty = false
  @storage = nil
end
570
+ end
571
+ end