classifier 2.1.0 → 2.3.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,593 @@
1
+ # rbs_inline: enabled
2
+
3
+ # Author:: Lucas Carlson (mailto:lucas@rufy.com)
4
+ # Copyright:: Copyright (c) 2024 Lucas Carlson
5
+ # License:: LGPL
6
+
7
+ require 'json'
8
+ require 'mutex_m'
9
+
10
+ module Classifier
11
+ # Logistic Regression (MaxEnt) classifier using Stochastic Gradient Descent.
12
+ # Often provides better accuracy than Naive Bayes while remaining fast and interpretable.
13
+ #
14
+ # Example:
15
+ # classifier = Classifier::LogisticRegression.new(:spam, :ham)
16
+ # classifier.train(spam: ["Buy now!", "Free money!!!"])
17
+ # classifier.train(ham: ["Meeting tomorrow", "Project update"])
18
+ # classifier.classify("Claim your prize!") # => "Spam"
19
+ # classifier.probabilities("Claim your prize!") # => {"Spam" => 0.92, "Ham" => 0.08}
20
+ #
21
+ class LogisticRegression # rubocop:disable Metrics/ClassLength
22
+ include Mutex_m
23
+ include Streaming
24
+
25
+ # @rbs @categories: Array[Symbol]
26
+ # @rbs @weights: Hash[Symbol, Hash[Symbol, Float]]
27
+ # @rbs @bias: Hash[Symbol, Float]
28
+ # @rbs @vocabulary: Hash[Symbol, bool]
29
+ # @rbs @training_data: Array[{category: Symbol, features: Hash[Symbol, Integer]}]
30
+ # @rbs @learning_rate: Float
31
+ # @rbs @regularization: Float
32
+ # @rbs @max_iterations: Integer
33
+ # @rbs @tolerance: Float
34
+ # @rbs @fitted: bool
35
+ # @rbs @dirty: bool
36
+ # @rbs @storage: Storage::Base?
37
+
38
+ attr_accessor :storage
39
+
40
+ DEFAULT_LEARNING_RATE = 0.1
41
+ DEFAULT_REGULARIZATION = 0.01
42
+ DEFAULT_MAX_ITERATIONS = 100
43
+ DEFAULT_TOLERANCE = 1e-4
44
+
45
    # Creates a new Logistic Regression classifier with the specified categories.
    #
    #   classifier = Classifier::LogisticRegression.new(:spam, :ham)
    #   classifier = Classifier::LogisticRegression.new('Positive', 'Negative', 'Neutral')
    #   classifier = Classifier::LogisticRegression.new(['Positive', 'Negative', 'Neutral'])
    #
    # Options:
    # - learning_rate: Step size for gradient descent (default: 0.1)
    # - regularization: L2 regularization strength (default: 0.01)
    # - max_iterations: Maximum training iterations (default: 100)
    # - tolerance: Convergence threshold (default: 1e-4)
    #
    # @rbs (*String | Symbol | Array[String | Symbol], ?learning_rate: Float, ?regularization: Float,
    #      ?max_iterations: Integer, ?tolerance: Float) -> void
    def initialize(*categories, learning_rate: DEFAULT_LEARNING_RATE,
                   regularization: DEFAULT_REGULARIZATION,
                   max_iterations: DEFAULT_MAX_ITERATIONS,
                   tolerance: DEFAULT_TOLERANCE)
      super() # initializes Mutex_m state backing #synchronize
      categories = categories.flatten
      @categories = categories.map { |c| c.to_s.prepare_category_name }
      # Per-category weight vectors (word => weight) and bias terms.
      @weights = @categories.to_h { |c| [c, {}] }
      @bias = @categories.to_h { |c| [c, 0.0] }
      @vocabulary = {}
      @training_data = []
      @learning_rate = learning_rate
      @regularization = regularization
      @max_iterations = max_iterations
      @tolerance = tolerance
      @fitted = false
      @dirty = false
      @storage = nil
    end
78
+
79
    # Trains the classifier with text for a category.
    # Data is only accumulated here; call #fit before classifying.
    #
    #   classifier.train(spam: "Buy now!", ham: ["Hello", "Meeting tomorrow"])
    #   classifier.train(:spam, "legacy positional API")
    #
    # NOTE(review): a positional call with a category but no text (e.g.
    # train(:spam)) is a silent no-op — confirm this is intended.
    #
    # @rbs (?(String | Symbol)?, ?String?, **(String | Array[String])) -> void
    def train(category = nil, text = nil, **categories)
      return train_single(category, text) if category && text

      categories.each do |cat, texts|
        (texts.is_a?(Array) ? texts : [texts]).each { |t| train_single(cat, t) }
      end
    end
92
+
93
    # Fits the model to all accumulated training data.
    #
    # NOTE(review): #classify and #probabilities do NOT call this
    # automatically — they raise NotFittedError when unfitted. Only
    # #weights and #marshal_dump auto-fit. Call #fit explicitly after
    # training.
    #
    # No-op when there is no pending training data.
    #
    # @rbs () -> self
    def fit
      synchronize do
        return self if @training_data.empty?
        raise ArgumentError, 'At least two categories required for fitting' if @categories.size < 2

        optimize_weights
        @fitted = true
        @dirty = false
      end
      self
    end
108
+
109
+ # Returns the best matching category for the provided text.
110
+ #
111
+ # classifier.classify("Buy now!") # => "Spam"
112
+ #
113
+ # @rbs (String) -> String
114
+ def classify(text)
115
+ probs = probabilities(text)
116
+ best = probs.max_by { |_, v| v }
117
+ raise StandardError, 'No classifications available' unless best
118
+
119
+ best.first
120
+ end
121
+
122
    # Returns probability distribution across all categories.
    # Probabilities are well-calibrated (unlike Naive Bayes).
    # Raises NotFittedError if model has not been fitted.
    #
    #   classifier.probabilities("Buy now!")
    #   # => {"Spam" => 0.92, "Ham" => 0.08}
    #
    # @rbs (String) -> Hash[String, Float]
    def probabilities(text)
      raise NotFittedError, 'Model not fitted. Call fit() after training.' unless @fitted

      # Feature extraction happens outside the lock; scoring inside it.
      features = text.word_hash
      synchronize do
        softmax(compute_scores(features))
      end
    end
138
+
139
    # Returns log-odds scores for each category (before softmax).
    # Keys are stringified category names.
    # Raises NotFittedError if model has not been fitted.
    #
    # @rbs (String) -> Hash[String, Float]
    def classifications(text)
      raise NotFittedError, 'Model not fitted. Call fit() after training.' unless @fitted

      features = text.word_hash
      synchronize do
        compute_scores(features).transform_keys(&:to_s)
      end
    end
151
+
152
    # Returns feature weights for a category, sorted by importance
    # (absolute weight, descending).
    # Positive weights indicate the feature supports the category.
    # Auto-fits the model first if there is pending training data.
    #
    #   classifier.weights(:spam)
    #   # => {:free => 2.3, :buy => 1.8, :money => 1.5, ...}
    #
    # @rbs (String | Symbol, ?limit: Integer?) -> Hash[Symbol, Float]
    def weights(category, limit: nil)
      fit unless @fitted

      cat = category.to_s.prepare_category_name
      raise StandardError, "No such category: #{cat}" unless @weights.key?(cat)

      sorted = @weights[cat].sort_by { |_, v| -v.abs }
      sorted = sorted.first(limit) if limit
      sorted.to_h
    end
169
+
170
    # Returns the list of categories as strings.
    #
    # @rbs () -> Array[String]
    def categories
      synchronize { @categories.map(&:to_s) }
    end
176
+
177
    # Adds a new category to the classifier.
    # Allows dynamic category creation for CLI and incremental training.
    # Idempotent: adding an existing category is a no-op.
    # Marks the model unfitted/dirty so it is re-fit before classifying.
    #
    # @rbs (String | Symbol) -> void
    def add_category(category)
      cat = category.to_s.prepare_category_name
      synchronize do
        return if @categories.include?(cat)

        @categories << cat
        @weights[cat] = {}
        @bias[cat] = 0.0
        @fitted = false
        @dirty = true
      end
    end
193
+
194
    # Returns true if the model has been fitted.
    #
    # @rbs () -> bool
    def fitted?
      @fitted
    end
200
+
201
    # Returns true if there are unsaved changes (training/category edits
    # since the last save/fit).
    #
    # @rbs () -> bool
    def dirty?
      @dirty
    end
207
+
208
+ # Provides training methods for the categories.
209
+ # classifier.train_spam "Buy now!"
210
+ def method_missing(name, *args)
211
+ category_match = name.to_s.match(/train_(\w+)/)
212
+ return super unless category_match
213
+
214
+ category = category_match[1].to_s.prepare_category_name
215
+ raise StandardError, "No such category: #{category}" unless @categories.include?(category)
216
+
217
+ args.each { |text| train(category, text) }
218
+ end
219
+
220
+ # @rbs (Symbol, ?bool) -> bool
221
+ def respond_to_missing?(name, include_private = false)
222
+ !!(name.to_s =~ /train_(\w+)/) || super
223
+ end
224
+
225
    # Returns a hash representation of the classifier state.
    # Does NOT auto-fit; saves current state including unfitted models
    # (pending training_data is preserved so fitting can resume after load).
    # Symbol keys/categories/words are stringified so the result is JSON-safe.
    #
    # @rbs (?untyped) -> Hash[Symbol, untyped]
    def as_json(_options = nil)
      {
        version: 1,
        type: 'logistic_regression',
        categories: @categories.map(&:to_s),
        weights: @weights.transform_keys(&:to_s).transform_values { |v| v.transform_keys(&:to_s) },
        bias: @bias.transform_keys(&:to_s),
        vocabulary: @vocabulary.keys.map(&:to_s),
        training_data: @training_data.map { |d| { category: d[:category].to_s, features: d[:features].transform_keys(&:to_s) } },
        learning_rate: @learning_rate,
        regularization: @regularization,
        max_iterations: @max_iterations,
        tolerance: @tolerance,
        fitted: @fitted
      }
    end
245
+
246
    # Serializes the classifier state to a JSON string (see #as_json).
    #
    # @rbs (?untyped) -> String
    def to_json(_options = nil)
      JSON.generate(as_json)
    end
252
+
253
    # Loads a classifier from a JSON string or an already-parsed Hash.
    # Raises ArgumentError when the payload's 'type' field does not match.
    #
    # @rbs (String | Hash[String, untyped]) -> LogisticRegression
    def self.from_json(json)
      data = json.is_a?(String) ? JSON.parse(json) : json
      raise ArgumentError, "Invalid classifier type: #{data['type']}" unless data['type'] == 'logistic_regression'

      categories = data['categories'].map(&:to_sym)
      # allocate + restore_state bypasses #initialize (state comes from JSON).
      instance = allocate
      instance.send(:restore_state, data, categories)
      instance
    end
265
+
266
    # Saves the classifier to the configured storage and clears the dirty flag.
    # Raises ArgumentError when no storage is configured.
    #
    # @rbs () -> void
    def save
      raise ArgumentError, 'No storage configured' unless storage

      storage.write(to_json)
      @dirty = false
    end
275
+
276
    # Saves the classifier state to a file and clears the dirty flag.
    # Returns the number of bytes written (File.write's return value).
    #
    # @rbs (String) -> Integer
    def save_to_file(path)
      result = File.write(path, to_json)
      @dirty = false
      result
    end
284
+
285
    # Loads a classifier from the configured storage and attaches that
    # storage to the returned instance.
    # Raises StorageError when the storage holds no saved state.
    #
    # @rbs (storage: Storage::Base) -> LogisticRegression
    def self.load(storage:)
      data = storage.read
      raise StorageError, 'No saved state found' unless data

      instance = from_json(data)
      instance.storage = storage
      instance
    end
296
+
297
    # Loads a classifier from a JSON file (no storage is attached).
    #
    # @rbs (String) -> LogisticRegression
    def self.load_from_file(path)
      from_json(File.read(path))
    end
303
+
304
    # Reloads the classifier from storage, raising if there are unsaved changes.
    # Raises ArgumentError (no storage), UnsavedChangesError (dirty state),
    # or StorageError (nothing saved).
    #
    # @rbs () -> self
    def reload
      raise ArgumentError, 'No storage configured' unless storage
      raise UnsavedChangesError, 'Unsaved changes would be lost. Call save first or use reload!' if @dirty

      data = storage.read
      raise StorageError, 'No saved state found' unless data

      restore_from_json(data)
      @dirty = false
      self
    end
318
+
319
    # Force reloads the classifier from storage, discarding any unsaved changes.
    # Same as #reload but skips the dirty-state guard.
    #
    # @rbs () -> self
    def reload!
      raise ArgumentError, 'No storage configured' unless storage

      data = storage.read
      raise StorageError, 'No saved state found' unless data

      restore_from_json(data)
      @dirty = false
      self
    end
332
+
333
    # Custom marshal serialization to exclude mutex state.
    # Auto-fits pending training data first, then dumps only the fitted
    # model — @training_data and @storage are intentionally NOT included
    # (marshal_load resets them).
    #
    # @rbs () -> Array[untyped]
    def marshal_dump
      fit unless @fitted
      [@categories, @weights, @bias, @vocabulary, @learning_rate, @regularization,
       @max_iterations, @tolerance, @fitted]
    end
341
+
342
    # Custom marshal deserialization to recreate mutex.
    # Restores the fields written by #marshal_dump (same order) and resets
    # the transient state (pending training data, dirty flag, storage).
    #
    # @rbs (Array[untyped]) -> void
    def marshal_load(data)
      mu_initialize # recreate the Mutex_m lock excluded from the dump
      @categories, @weights, @bias, @vocabulary, @learning_rate, @regularization,
      @max_iterations, @tolerance, @fitted = data
      @training_data = []
      @dirty = false
      @storage = nil
    end
353
+
354
    # Loads a classifier from a checkpoint file derived from the main
    # storage path ("<base>_checkpoint_<id><ext>" in the same directory).
    # The returned instance is attached to the MAIN storage, not the
    # checkpoint file, so subsequent saves go to the primary path.
    # Only File storage is supported.
    #
    # @rbs (storage: Storage::Base, checkpoint_id: String) -> LogisticRegression
    def self.load_checkpoint(storage:, checkpoint_id:)
      raise ArgumentError, 'Storage must be File storage for checkpoints' unless storage.is_a?(Storage::File)

      dir = File.dirname(storage.path)
      base = File.basename(storage.path, '.*')
      ext = File.extname(storage.path)
      checkpoint_path = File.join(dir, "#{base}_checkpoint_#{checkpoint_id}#{ext}")

      checkpoint_storage = Storage::File.new(path: checkpoint_path)
      instance = load(storage: checkpoint_storage)
      instance.storage = storage
      instance
    end
370
+
371
    # Trains the classifier from an IO stream.
    # Each line in the stream is treated as a separate document.
    # Note: The model is NOT automatically fitted after streaming.
    # Call #fit to train the model after adding all data.
    #
    # @example Train from a file
    #   classifier.train_from_stream(:spam, File.open('spam_corpus.txt'))
    #   classifier.fit # Required to train the model
    #
    # @example With progress tracking
    #   classifier.train_from_stream(:spam, io, batch_size: 500) do |progress|
    #     puts "#{progress.completed} documents processed"
    #   end
    #   classifier.fit
    #
    # @rbs (String | Symbol, IO, ?batch_size: Integer) { (Streaming::Progress) -> void } -> void
    def train_from_stream(category, io, batch_size: Streaming::DEFAULT_BATCH_SIZE)
      category = category.to_s.prepare_category_name
      raise StandardError, "No such category: #{category}" unless @categories.include?(category)

      reader = Streaming::LineReader.new(io, batch_size: batch_size)
      total = reader.estimate_line_count
      progress = Streaming::Progress.new(total: total)

      reader.each_batch do |batch|
        # Lock only around state mutation; the progress callback runs
        # outside the lock so it cannot deadlock on this classifier.
        synchronize do
          batch.each do |text|
            features = text.word_hash
            features.each_key { |word| @vocabulary[word] = true }
            @training_data << { category: category, features: features }
          end
          @fitted = false
          @dirty = true
        end
        progress.completed += batch.size
        progress.current_batch += 1
        yield progress if block_given?
      end
    end
410
+
411
    # Trains the classifier with an array of documents in batches.
    # Note: The model is NOT automatically fitted after batch training.
    # Call #fit to train the model after adding all data.
    #
    # @example Positional style
    #   classifier.train_batch(:spam, documents, batch_size: 100)
    #   classifier.fit
    #
    # @example Keyword style
    #   classifier.train_batch(spam: documents, ham: other_docs)
    #   classifier.fit
    #
    # @rbs (?(String | Symbol)?, ?Array[String]?, ?batch_size: Integer, **Array[String]) { (Streaming::Progress) -> void } -> void
    def train_batch(category = nil, documents = nil, batch_size: Streaming::DEFAULT_BATCH_SIZE, **categories, &block)
      if category && documents
        train_batch_for_category(category, documents, batch_size: batch_size, &block)
      else
        categories.each do |cat, docs|
          train_batch_for_category(cat, Array(docs), batch_size: batch_size, &block)
        end
      end
    end
433
+
434
+ private
435
+
436
    # Trains a batch of documents for a single category.
    # Accumulates features/vocabulary under the lock in slices of
    # batch_size; yields a Streaming::Progress after each slice.
    #
    # @rbs (String | Symbol, Array[String], ?batch_size: Integer) { (Streaming::Progress) -> void } -> void
    def train_batch_for_category(category, documents, batch_size: Streaming::DEFAULT_BATCH_SIZE)
      category = category.to_s.prepare_category_name
      raise StandardError, "No such category: #{category}" unless @categories.include?(category)

      progress = Streaming::Progress.new(total: documents.size)

      documents.each_slice(batch_size) do |batch|
        synchronize do
          batch.each do |text|
            features = text.word_hash
            features.each_key { |word| @vocabulary[word] = true }
            @training_data << { category: category, features: features }
          end
          @fitted = false
          @dirty = true
        end
        # Progress callback runs outside the lock.
        progress.completed += batch.size
        progress.current_batch += 1
        yield progress if block_given?
      end
    end
459
+
460
    # Core training logic for a single category and text.
    # Extracts features, records them in the vocabulary and pending
    # training data, and marks the model unfitted/dirty.
    #
    # @rbs (String | Symbol, String) -> void
    def train_single(category, text)
      category = category.to_s.prepare_category_name
      raise StandardError, "No such category: #{category}" unless @categories.include?(category)

      features = text.word_hash
      synchronize do
        features.each_key { |word| @vocabulary[word] = true }
        @training_data << { category: category, features: features }
        @fitted = false
        @dirty = true
      end
    end
474
+
475
    # Optimizes weights using SGD with L2 regularization.
    # Runs up to @max_iterations epochs, stopping early once the change
    # in total loss drops below @tolerance.
    # NOTE: the pending training data is discarded after fitting, so a
    # subsequent fit only sees data added after this call.
    #
    # @rbs () -> void
    def optimize_weights
      return if @training_data.empty?

      initialize_weights
      prev_loss = Float::INFINITY

      @max_iterations.times do
        total_loss = run_training_epoch
        break if (prev_loss - total_loss).abs < @tolerance

        prev_loss = total_loss
      end

      @training_data = []
    end
492
+
493
    # Ensures every vocabulary word has a weight entry (0.0) for every
    # category, without clobbering weights learned in a previous fit.
    #
    # @rbs () -> void
    def initialize_weights
      @vocabulary.each_key do |word|
        @categories.each { |cat| @weights[cat][word] ||= 0.0 }
      end
    end
499
+
500
    # Runs one SGD epoch over a shuffled copy of the training data and
    # returns the total loss (cross-entropy plus L2 penalty).
    # The 1e-15 floor guards Math.log against a zero probability.
    #
    # @rbs () -> Float
    def run_training_epoch
      total_loss = 0.0

      @training_data.shuffle.each do |sample|
        probs = softmax(compute_scores(sample[:features]))
        update_weights(sample[:features], sample[:category], probs)
        total_loss -= Math.log([probs[sample[:category].to_s], 1e-15].max)
      end

      total_loss + l2_penalty
    end
512
+
513
    # Applies one SGD step for a single sample.
    # error = predicted probability minus the one-hot target; the bias
    # update omits regularization while feature weights include the L2
    # gradient term (lambda * w).
    #
    # @rbs (Hash[Symbol, Integer], Symbol, Hash[String, Float]) -> void
    def update_weights(features, true_category, probs)
      @categories.each do |cat|
        error = probs[cat.to_s] - (cat == true_category ? 1.0 : 0.0)
        @bias[cat] -= @learning_rate * error

        features.each do |word, count|
          gradient = (error * count) + (@regularization * (@weights[cat][word] || 0.0))
          @weights[cat][word] ||= 0.0
          @weights[cat][word] -= @learning_rate * gradient
        end
      end
    end
526
+
527
+ # @rbs () -> Float
528
+ def l2_penalty
529
+ penalty = 0.0
530
+ @weights.each_value do |cat_weights|
531
+ cat_weights.each_value { |w| penalty += 0.5 * @regularization * w * w }
532
+ end
533
+ penalty
534
+ end
535
+
536
+ # Computes raw scores (logits) for each category.
537
+ # @rbs (Hash[Symbol, Integer]) -> Hash[Symbol, Float]
538
+ def compute_scores(features)
539
+ @categories.to_h do |cat|
540
+ score = @bias[cat]
541
+ features.each { |word, count| score += (@weights[cat][word] || 0.0) * count }
542
+ [cat, score]
543
+ end
544
+ end
545
+
546
+ # Applies softmax to convert scores to probabilities.
547
+ # @rbs (Hash[Symbol, Float]) -> Hash[String, Float]
548
+ def softmax(scores)
549
+ max_score = scores.values.max || 0.0
550
+ exp_scores = scores.transform_values { |s| Math.exp(s - max_score) }
551
+ sum = exp_scores.values.sum.to_f
552
+ exp_scores.transform_keys(&:to_s).transform_values { |e| (e / sum).to_f }
553
+ end
554
+
555
    # Restores this instance's state in place from a JSON string
    # (used by #reload / #reload!).
    #
    # @rbs (String) -> void
    def restore_from_json(json)
      data = JSON.parse(json)
      categories = data['categories'].map(&:to_sym)
      restore_state(data, categories)
    end
562
+
563
    # Restores classifier state from parsed JSON data.
    # Also (re)creates the Mutex_m lock, since this path may run on an
    # allocated-but-uninitialized instance (see .from_json).
    # Missing 'fitted' defaults to true for pre-versioned payloads.
    #
    # @rbs (Hash[String, untyped], Array[Symbol]) -> void
    def restore_state(data, categories)
      mu_initialize
      @categories = categories
      restore_weights_and_bias(data)
      restore_hyperparameters(data)
      @fitted = data.fetch('fitted', true)
      @dirty = false
      @storage = nil
    end
574
+
575
    # Rebuilds weights, biases, vocabulary, and pending training data from
    # their JSON forms, converting string keys back to symbols and values
    # back to Float/Integer.
    def restore_weights_and_bias(data)
      @weights = {}
      @bias = {}
      data['weights'].each { |cat, words| @weights[cat.to_sym] = words.transform_keys(&:to_sym).transform_values(&:to_f) }
      data['bias'].each { |cat, value| @bias[cat.to_sym] = value.to_f }
      @vocabulary = data['vocabulary'].to_h { |v| [v.to_sym, true] }
      # training_data may be absent in older payloads; default to none.
      @training_data = (data['training_data'] || []).map do |d|
        { category: d['category'].to_sym, features: d['features'].transform_keys(&:to_sym).transform_values(&:to_i) }
      end
    end
585
+
586
    # Restores SGD hyperparameters from their JSON form.
    # NOTE(review): no defaults are applied here — payloads missing these
    # keys would leave nil hyperparameters; verify all writers include them.
    def restore_hyperparameters(data)
      @learning_rate = data['learning_rate']
      @regularization = data['regularization']
      @max_iterations = data['max_iterations']
      @tolerance = data['tolerance']
    end
592
+ end
593
+ end