classifier 2.1.0 → 2.2.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/README.md +66 -199
- data/ext/classifier/classifier_ext.c +1 -0
- data/ext/classifier/incremental_svd.c +393 -0
- data/ext/classifier/linalg.h +8 -0
- data/lib/classifier/bayes.rb +177 -53
- data/lib/classifier/errors.rb +3 -0
- data/lib/classifier/knn.rb +351 -0
- data/lib/classifier/logistic_regression.rb +571 -0
- data/lib/classifier/lsi/incremental_svd.rb +166 -0
- data/lib/classifier/lsi/summary.rb +25 -5
- data/lib/classifier/lsi.rb +365 -17
- data/lib/classifier/streaming/line_reader.rb +99 -0
- data/lib/classifier/streaming/progress.rb +96 -0
- data/lib/classifier/streaming.rb +122 -0
- data/lib/classifier/tfidf.rb +408 -0
- data/lib/classifier.rb +4 -0
- data/sig/vendor/matrix.rbs +25 -14
- data/sig/vendor/streaming.rbs +14 -0
- metadata +17 -4
|
@@ -0,0 +1,571 @@
|
|
|
1
|
+
# rbs_inline: enabled
|
|
2
|
+
|
|
3
|
+
# Author:: Lucas Carlson (mailto:lucas@rufy.com)
|
|
4
|
+
# Copyright:: Copyright (c) 2024 Lucas Carlson
|
|
5
|
+
# License:: LGPL
|
|
6
|
+
|
|
7
|
+
require 'json'
|
|
8
|
+
require 'mutex_m'
|
|
9
|
+
|
|
10
|
+
module Classifier
  # Logistic Regression (MaxEnt) classifier using Stochastic Gradient Descent.
  # Often provides better accuracy than Naive Bayes while remaining fast and interpretable.
  #
  # Example:
  #   classifier = Classifier::LogisticRegression.new(:spam, :ham)
  #   classifier.train(spam: ["Buy now!", "Free money!!!"])
  #   classifier.train(ham: ["Meeting tomorrow", "Project update"])
  #   classifier.classify("Claim your prize!") # => "Spam"
  #   classifier.probabilities("Claim your prize!") # => {"Spam" => 0.92, "Ham" => 0.08}
  #
  class LogisticRegression # rubocop:disable Metrics/ClassLength
    include Mutex_m
    include Streaming

    # @rbs @categories: Array[Symbol]
    # @rbs @weights: Hash[Symbol, Hash[Symbol, Float]]
    # @rbs @bias: Hash[Symbol, Float]
    # @rbs @vocabulary: Hash[Symbol, bool]
    # @rbs @training_data: Array[{category: Symbol, features: Hash[Symbol, Integer]}]
    # @rbs @learning_rate: Float
    # @rbs @regularization: Float
    # @rbs @max_iterations: Integer
    # @rbs @tolerance: Float
    # @rbs @fitted: bool
    # @rbs @dirty: bool
    # @rbs @storage: Storage::Base?

    attr_accessor :storage

    DEFAULT_LEARNING_RATE = 0.1
    DEFAULT_REGULARIZATION = 0.01
    DEFAULT_MAX_ITERATIONS = 100
    DEFAULT_TOLERANCE = 1e-4

    # Creates a new Logistic Regression classifier with the specified categories.
    #
    #   classifier = Classifier::LogisticRegression.new(:spam, :ham)
    #   classifier = Classifier::LogisticRegression.new('Positive', 'Negative', 'Neutral')
    #   classifier = Classifier::LogisticRegression.new(['Positive', 'Negative', 'Neutral'])
    #
    # Category names are normalised with String#prepare_category_name. Names that
    # normalise to the same category (e.g. :spam and 'spam') count once, and at
    # least two distinct categories are required.
    #
    # Options:
    # - learning_rate: Step size for gradient descent (default: 0.1)
    # - regularization: L2 regularization strength (default: 0.01)
    # - max_iterations: Maximum training iterations (default: 100)
    # - tolerance: Convergence threshold (default: 1e-4)
    #
    # Raises ArgumentError when fewer than two distinct categories are given.
    #
    # @rbs (*String | Symbol | Array[String | Symbol], ?learning_rate: Float, ?regularization: Float,
    #       ?max_iterations: Integer, ?tolerance: Float) -> void
    def initialize(*categories, learning_rate: DEFAULT_LEARNING_RATE,
                   regularization: DEFAULT_REGULARIZATION,
                   max_iterations: DEFAULT_MAX_ITERATIONS,
                   tolerance: DEFAULT_TOLERANCE)
      super()
      # Deduplicate after normalisation: duplicates would otherwise pass the size
      # check yet collapse in the weight hashes and double-apply SGD updates.
      categories = categories.flatten.map { |c| c.to_s.prepare_category_name }.uniq
      raise ArgumentError, 'At least two categories required' if categories.size < 2

      @categories = categories
      @weights = @categories.to_h { |c| [c, {}] }
      @bias = @categories.to_h { |c| [c, 0.0] }
      @vocabulary = {}
      @training_data = []
      @learning_rate = learning_rate
      @regularization = regularization
      @max_iterations = max_iterations
      @tolerance = tolerance
      @fitted = false
      @dirty = false
      @storage = nil
    end

    # Trains the classifier with text for a category.
    #
    #   classifier.train(spam: "Buy now!", ham: ["Hello", "Meeting tomorrow"])
    #   classifier.train(:spam, "legacy positional API")
    #
    # When both positional arguments are given, the keyword hash is ignored.
    # Raises StandardError for an unknown category.
    #
    # @rbs (?(String | Symbol)?, ?String?, **(String | Array[String])) -> void
    def train(category = nil, text = nil, **categories)
      return train_single(category, text) if category && text

      categories.each do |cat, texts|
        (texts.is_a?(Array) ? texts : [texts]).each { |t| train_single(cat, t) }
      end
    end

    # Fits the model to the training samples accumulated since the last fit.
    # Called automatically during classify/probabilities if not already fitted.
    #
    # Samples are discarded once fitted. A later fit starts from the current
    # weights and uses only the samples added after the previous fit.
    #
    # Fitting does not clear the dirty flag: the new weights count as unsaved
    # changes until #save or #save_to_file is called.
    #
    # @rbs () -> self
    def fit
      synchronize do
        unless @training_data.empty?
          optimize_weights
          @fitted = true
        end
      end
      self
    end

    # Returns the best matching category for the provided text.
    #
    #   classifier.classify("Buy now!") # => "Spam"
    #
    # @rbs (String) -> String
    def classify(text)
      probs = probabilities(text)
      best = probs.max_by { |_, v| v }
      raise StandardError, 'No classifications available' unless best

      best.first
    end

    # Returns probability distribution across all categories.
    # Probabilities are well-calibrated (unlike Naive Bayes).
    #
    #   classifier.probabilities("Buy now!")
    #   # => {"Spam" => 0.92, "Ham" => 0.08}
    #
    # @rbs (String) -> Hash[String, Float]
    def probabilities(text)
      fit unless @fitted

      features = text.word_hash
      synchronize { softmax(compute_scores(features)) }
    end

    # Returns log-odds scores for each category (before softmax).
    #
    # @rbs (String) -> Hash[String, Float]
    def classifications(text)
      fit unless @fitted

      features = text.word_hash
      synchronize { compute_scores(features).transform_keys(&:to_s) }
    end

    # Returns feature weights for a category, sorted by absolute magnitude (largest first).
    # Positive weights indicate the feature supports the category.
    #
    #   classifier.weights(:spam)
    #   # => {:free => 2.3, :buy => 1.8, :money => 1.5, ...}
    #
    # Pass limit: to keep only the top-N entries. Raises StandardError for an
    # unknown category.
    #
    # @rbs (String | Symbol, ?limit: Integer?) -> Hash[Symbol, Float]
    def weights(category, limit: nil)
      fit unless @fitted

      cat = category.to_s.prepare_category_name
      raise StandardError, "No such category: #{cat}" unless @weights.key?(cat)

      ranked = synchronize { @weights[cat].sort_by { |_, v| -v.abs } }
      ranked = ranked.first(limit) if limit
      ranked.to_h
    end

    # Returns the list of categories.
    #
    # @rbs () -> Array[String]
    def categories
      synchronize { @categories.map(&:to_s) }
    end

    # Returns true if the model has been fitted.
    #
    # @rbs () -> bool
    def fitted?
      @fitted
    end

    # Returns true if there are unsaved changes.
    #
    # @rbs () -> bool
    def dirty?
      @dirty
    end

    # Provides training methods for the categories.
    #   classifier.train_spam "Buy now!"
    #
    # Only names that *start* with "train_" are handled, so e.g. untrain_spam is
    # not mistaken for train_spam. Raises StandardError for an unknown category.
    def method_missing(name, *args)
      match = name.to_s.match(/\Atrain_(\w+)/)
      return super unless match

      category = match[1].to_s.prepare_category_name
      raise StandardError, "No such category: #{category}" unless @categories.include?(category)

      args.each { |text| train(category, text) }
    end

    # Mirrors method_missing: reports train_<category> methods as available.
    #
    # @rbs (Symbol, ?bool) -> bool
    def respond_to_missing?(name, include_private = false)
      name.to_s.match?(/\Atrain_(\w+)/) || super
    end

    # Returns a hash representation of the classifier state.
    # Fits pending training data first; raw training samples are not included.
    #
    # @rbs (?untyped) -> Hash[Symbol, untyped]
    def as_json(_options = nil)
      fit unless @fitted

      {
        version: 1,
        type: 'logistic_regression',
        categories: @categories.map(&:to_s),
        weights: @weights.transform_keys(&:to_s).transform_values { |v| v.transform_keys(&:to_s) },
        bias: @bias.transform_keys(&:to_s),
        vocabulary: @vocabulary.keys.map(&:to_s),
        learning_rate: @learning_rate,
        regularization: @regularization,
        max_iterations: @max_iterations,
        tolerance: @tolerance
      }
    end

    # Serializes the classifier state to a JSON string.
    #
    # @rbs (?untyped) -> String
    def to_json(_options = nil)
      JSON.generate(as_json)
    end

    # Loads a classifier from a JSON string or Hash.
    # Raises ArgumentError when the payload's type is not 'logistic_regression'.
    #
    # @rbs (String | Hash[String, untyped]) -> LogisticRegression
    def self.from_json(json)
      data = json.is_a?(String) ? JSON.parse(json) : json
      raise ArgumentError, "Invalid classifier type: #{data['type']}" unless data['type'] == 'logistic_regression'

      categories = data['categories'].map(&:to_sym)
      instance = allocate
      # allocate skips #initialize, so the Mutex_m lock must be created explicitly.
      instance.send(:mu_initialize)
      instance.send(:restore_state, data, categories)
      instance
    end

    # Saves the classifier to the configured storage.
    # Raises ArgumentError when no storage is configured.
    #
    # @rbs () -> void
    def save
      raise ArgumentError, 'No storage configured' unless storage

      storage.write(to_json)
      @dirty = false
    end

    # Saves the classifier state to a file.
    # Returns the value of File.write (bytes written).
    #
    # @rbs (String) -> Integer
    def save_to_file(path)
      result = File.write(path, to_json)
      @dirty = false
      result
    end

    # Loads a classifier from the configured storage.
    # Raises StorageError when the storage holds no data.
    #
    # @rbs (storage: Storage::Base) -> LogisticRegression
    def self.load(storage:)
      data = storage.read
      raise StorageError, 'No saved state found' unless data

      instance = from_json(data)
      instance.storage = storage
      instance
    end

    # Loads a classifier from a file.
    #
    # @rbs (String) -> LogisticRegression
    def self.load_from_file(path)
      from_json(File.read(path))
    end

    # Reloads the classifier from storage, raising if there are unsaved changes.
    # The configured storage is kept, so the instance can be saved again afterwards.
    #
    # @rbs () -> self
    def reload
      raise ArgumentError, 'No storage configured' unless storage
      raise UnsavedChangesError, 'Unsaved changes would be lost. Call save first or use reload!' if @dirty

      reload_from_storage
    end

    # Force reloads the classifier from storage, discarding any unsaved changes.
    # The configured storage is kept, so the instance can be saved again afterwards.
    #
    # @rbs () -> self
    def reload!
      raise ArgumentError, 'No storage configured' unless storage

      reload_from_storage
    end

    # Custom marshal serialization to exclude mutex state.
    # Pending training samples are fitted first and are not dumped.
    #
    # @rbs () -> Array[untyped]
    def marshal_dump
      fit unless @fitted
      [@categories, @weights, @bias, @vocabulary, @learning_rate, @regularization,
       @max_iterations, @tolerance, @fitted]
    end

    # Custom marshal deserialization to recreate mutex.
    #
    # @rbs (Array[untyped]) -> void
    def marshal_load(data)
      mu_initialize
      @categories, @weights, @bias, @vocabulary, @learning_rate, @regularization,
        @max_iterations, @tolerance, @fitted = data
      @training_data = []
      @dirty = false
      @storage = nil
    end

    # Loads a classifier from a checkpoint.
    # The checkpoint file lives next to storage.path and is named
    # "<basename>_checkpoint_<checkpoint_id><ext>". The returned instance is
    # attached to the original +storage+, not the checkpoint file.
    #
    # @rbs (storage: Storage::Base, checkpoint_id: String) -> LogisticRegression
    def self.load_checkpoint(storage:, checkpoint_id:)
      raise ArgumentError, 'Storage must be File storage for checkpoints' unless storage.is_a?(Storage::File)

      dir = File.dirname(storage.path)
      base = File.basename(storage.path, '.*')
      ext = File.extname(storage.path)
      checkpoint_path = File.join(dir, "#{base}_checkpoint_#{checkpoint_id}#{ext}")

      checkpoint_storage = Storage::File.new(path: checkpoint_path)
      instance = load(storage: checkpoint_storage)
      instance.storage = storage
      instance
    end

    # Trains the classifier from an IO stream.
    # Each line in the stream is treated as a separate document.
    # Note: The model is NOT automatically fitted after streaming.
    # Call #fit to train the model after adding all data.
    #
    # @example Train from a file
    #   classifier.train_from_stream(:spam, File.open('spam_corpus.txt'))
    #   classifier.fit # Required to train the model
    #
    # @example With progress tracking
    #   classifier.train_from_stream(:spam, io, batch_size: 500) do |progress|
    #     puts "#{progress.completed} documents processed"
    #   end
    #   classifier.fit
    #
    # @rbs (String | Symbol, IO, ?batch_size: Integer) { (Streaming::Progress) -> void } -> void
    def train_from_stream(category, io, batch_size: Streaming::DEFAULT_BATCH_SIZE)
      category = resolve_category(category)

      reader = Streaming::LineReader.new(io, batch_size: batch_size)
      progress = Streaming::Progress.new(total: reader.estimate_line_count)

      reader.each_batch do |batch|
        record_samples(category, batch)
        progress.completed += batch.size
        progress.current_batch += 1
        yield progress if block_given?
      end
    end

    # Trains the classifier with an array of documents in batches.
    # Note: The model is NOT automatically fitted after batch training.
    # Call #fit to train the model after adding all data.
    #
    # @example Positional style
    #   classifier.train_batch(:spam, documents, batch_size: 100)
    #   classifier.fit
    #
    # @example Keyword style
    #   classifier.train_batch(spam: documents, ham: other_docs)
    #   classifier.fit
    #
    # @rbs (?(String | Symbol)?, ?Array[String]?, ?batch_size: Integer, **Array[String]) { (Streaming::Progress) -> void } -> void
    def train_batch(category = nil, documents = nil, batch_size: Streaming::DEFAULT_BATCH_SIZE, **categories, &block)
      if category && documents
        train_batch_for_category(category, documents, batch_size: batch_size, &block)
      else
        categories.each do |cat, docs|
          train_batch_for_category(cat, Array(docs), batch_size: batch_size, &block)
        end
      end
    end

    private

    # Trains a batch of documents for a single category.
    # Yields a Streaming::Progress after each slice of +batch_size+ documents.
    # @rbs (String | Symbol, Array[String], ?batch_size: Integer) { (Streaming::Progress) -> void } -> void
    def train_batch_for_category(category, documents, batch_size: Streaming::DEFAULT_BATCH_SIZE)
      category = resolve_category(category)
      progress = Streaming::Progress.new(total: documents.size)

      documents.each_slice(batch_size) do |batch|
        record_samples(category, batch)
        progress.completed += batch.size
        progress.current_batch += 1
        yield progress if block_given?
      end
    end

    # Core training logic for a single category and text.
    # @rbs (String | Symbol, String) -> void
    def train_single(category, text)
      record_samples(resolve_category(category), [text])
    end

    # Normalises a category name and ensures it is one of this classifier's categories.
    # Raises StandardError for an unknown category.
    # @rbs (String | Symbol) -> Symbol
    def resolve_category(category)
      category = category.to_s.prepare_category_name
      raise StandardError, "No such category: #{category}" unless @categories.include?(category)

      category
    end

    # Tokenises +texts+ and queues them as training samples for +category+.
    # Marks the model as unfitted and dirty.
    # @rbs (Symbol, Array[String]) -> void
    def record_samples(category, texts)
      # Tokenise outside the lock; only the shared-state mutation needs it.
      samples = texts.map(&:word_hash)
      synchronize do
        samples.each do |features|
          features.each_key { |word| @vocabulary[word] = true }
          @training_data << { category: category, features: features }
        end
        @fitted = false
        @dirty = true
      end
    end

    # Reads the persisted state from +storage+ and replaces this instance's model.
    # Raises StorageError when the storage holds no data.
    # @rbs () -> self
    def reload_from_storage
      data = storage.read
      raise StorageError, 'No saved state found' unless data

      # Restore under the existing lock so concurrent readers never see a
      # half-restored model; restore_state leaves the lock object untouched.
      synchronize { restore_from_json(data) }
      @dirty = false
      self
    end

    # Optimizes weights using SGD (one sample at a time, shuffled each epoch)
    # with L2 regularization. Stops early when the epoch loss changes by less
    # than @tolerance, then discards the consumed training samples.
    # @rbs () -> void
    def optimize_weights
      return if @training_data.empty?

      initialize_weights
      prev_loss = Float::INFINITY

      @max_iterations.times do
        total_loss = run_training_epoch
        break if (prev_loss - total_loss).abs < @tolerance

        prev_loss = total_loss
      end

      @training_data = []
    end

    # Ensures every vocabulary word has a weight entry (0.0 if new) for every
    # category; existing weights are kept.
    # @rbs () -> void
    def initialize_weights
      @vocabulary.each_key do |word|
        @categories.each { |cat| @weights[cat][word] ||= 0.0 }
      end
    end

    # Runs one pass over the shuffled training samples, updating weights per
    # sample. Returns the summed cross-entropy loss plus the L2 penalty.
    # @rbs () -> Float
    def run_training_epoch
      total_loss = 0.0

      @training_data.shuffle.each do |sample|
        probs = softmax(compute_scores(sample[:features]))
        update_weights(sample[:features], sample[:category], probs)
        # Clamp to avoid log(0) when the model is confidently wrong.
        total_loss -= Math.log([probs[sample[:category].to_s], 1e-15].max)
      end

      total_loss + l2_penalty
    end

    # Applies one gradient step for a single sample. Only weights of features
    # present in the sample are regularised.
    # @rbs (Hash[Symbol, Integer], Symbol, Hash[String, Float]) -> void
    def update_weights(features, true_category, probs)
      @categories.each do |cat|
        error = probs[cat.to_s] - (cat == true_category ? 1.0 : 0.0)
        @bias[cat] -= @learning_rate * error

        features.each do |word, count|
          current = @weights[cat][word] || 0.0
          gradient = (error * count) + (@regularization * current)
          @weights[cat][word] = current - (@learning_rate * gradient)
        end
      end
    end

    # Returns 0.5 * regularization * sum of squared weights across all categories.
    # @rbs () -> Float
    def l2_penalty
      penalty = 0.0
      @weights.each_value do |cat_weights|
        cat_weights.each_value { |w| penalty += 0.5 * @regularization * w * w }
      end
      penalty
    end

    # Computes raw scores (logits) for each category: bias + sum(weight * count).
    # Words with no learned weight contribute 0.
    # @rbs (Hash[Symbol, Integer]) -> Hash[Symbol, Float]
    def compute_scores(features)
      @categories.to_h do |cat|
        score = @bias[cat]
        features.each { |word, count| score += (@weights[cat][word] || 0.0) * count }
        [cat, score]
      end
    end

    # Applies softmax to convert scores to probabilities keyed by category string.
    # Subtracts the maximum score first so Math.exp cannot overflow.
    # @rbs (Hash[Symbol, Float]) -> Hash[String, Float]
    def softmax(scores)
      max_score = scores.values.max || 0.0
      exp_scores = scores.transform_values { |s| Math.exp(s - max_score) }
      sum = exp_scores.values.sum.to_f
      exp_scores.transform_keys(&:to_s).transform_values { |e| (e / sum).to_f }
    end

    # Restores classifier state from JSON string, keeping the currently
    # configured storage (restore_state itself clears it).
    # @rbs (String) -> void
    def restore_from_json(json)
      data = JSON.parse(json)
      categories = data['categories'].map(&:to_sym)
      current_storage = @storage
      restore_state(data, categories)
      @storage = current_storage
    end

    # Restores classifier state from parsed JSON data.
    # Does not create the Mutex_m lock; callers on fresh instances
    # (see from_json) must call mu_initialize first.
    # @rbs (Hash[String, untyped], Array[Symbol]) -> void
    def restore_state(data, categories)
      @categories = categories
      @weights = {}
      @bias = {}

      data['weights'].each do |cat, words|
        @weights[cat.to_sym] = words.transform_keys(&:to_sym).transform_values(&:to_f)
      end

      data['bias'].each do |cat, value|
        @bias[cat.to_sym] = value.to_f
      end

      @vocabulary = data['vocabulary'].to_h { |v| [v.to_sym, true] }
      @learning_rate = data['learning_rate']
      @regularization = data['regularization']
      @max_iterations = data['max_iterations']
      @tolerance = data['tolerance']
      @training_data = []
      @fitted = true
      @dirty = false
      @storage = nil
    end
  end
end
|