classifier 2.1.0 → 2.2.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/README.md +66 -199
- data/ext/classifier/classifier_ext.c +1 -0
- data/ext/classifier/incremental_svd.c +393 -0
- data/ext/classifier/linalg.h +8 -0
- data/lib/classifier/bayes.rb +177 -53
- data/lib/classifier/errors.rb +3 -0
- data/lib/classifier/knn.rb +351 -0
- data/lib/classifier/logistic_regression.rb +571 -0
- data/lib/classifier/lsi/incremental_svd.rb +166 -0
- data/lib/classifier/lsi/summary.rb +25 -5
- data/lib/classifier/lsi.rb +365 -17
- data/lib/classifier/streaming/line_reader.rb +99 -0
- data/lib/classifier/streaming/progress.rb +96 -0
- data/lib/classifier/streaming.rb +122 -0
- data/lib/classifier/tfidf.rb +408 -0
- data/lib/classifier.rb +4 -0
- data/sig/vendor/matrix.rbs +25 -14
- data/sig/vendor/streaming.rbs +14 -0
- metadata +17 -4
|
@@ -0,0 +1,351 @@
|
|
|
1
|
+
# rbs_inline: enabled
|
|
2
|
+
|
|
3
|
+
# Author:: Lucas Carlson (mailto:lucas@rufy.com)
|
|
4
|
+
# Copyright:: Copyright (c) 2024 Lucas Carlson
|
|
5
|
+
# License:: LGPL
|
|
6
|
+
|
|
7
|
+
require 'json'
|
|
8
|
+
require 'mutex_m'
|
|
9
|
+
require 'classifier/lsi'
|
|
10
|
+
|
|
11
|
+
module Classifier
  # Instance-based classification: stores examples and classifies by similarity.
  #
  # Labeled documents are held in an internal LSI index. A query is classified
  # by taking the first +k+ entries LSI reports as closest to it and letting
  # their categories vote, optionally weighting each vote by similarity.
  #
  # Example:
  #   knn = Classifier::KNN.new(k: 3)
  #   knn.add("spam" => ["Buy now!", "Limited offer!"])
  #   knn.add("ham" => ["Meeting tomorrow", "Project update"])
  #   knn.classify("Special discount!") # => "spam"
  #
  class KNN
    include Mutex_m
    include Streaming

    # Matches the dynamic trainer names (train_spam, train_ham, ...) handled
    # by method_missing; capture group 1 is the category.
    TRAIN_METHOD_PATTERN = /\Atrain_(\w+)\z/
    private_constant :TRAIN_METHOD_PATTERN

    # @rbs @k: Integer
    # @rbs @weighted: bool
    # @rbs @lsi: LSI
    # @rbs @dirty: bool
    # @rbs @storage: Storage::Base?

    attr_reader :k
    attr_accessor :weighted, :storage

    # Creates a new kNN classifier.
    #
    # +k+ is the number of neighbors consulted per query. +weighted+ switches
    # from one-vote-per-neighbor to similarity-weighted voting.
    # Raises ArgumentError unless +k+ is a positive Integer.
    # @rbs (?k: Integer, ?weighted: bool) -> void
    def initialize(k: 5, weighted: false) # rubocop:disable Naming/MethodParameterName
      super() # lets Mutex_m set up the lock used by #synchronize
      validate_k!(k)
      @k = k
      @weighted = weighted
      @lsi = LSI.new(auto_rebuild: true)
      @dirty = false
      @storage = nil
    end

    # Adds labeled examples. Keys are categories, values are items or arrays.
    # Also aliased as `train` for API consistency with Bayes and LogisticRegression.
    #
    #   knn.add(spam: "Buy now!", ham: "Meeting tomorrow")
    #   knn.train(spam: "Buy now!", ham: "Meeting tomorrow") # equivalent
    #
    # @rbs (**untyped items) -> void
    def add(**items)
      # Flag first so a failure inside LSI still leaves the instance marked as modified.
      mark_dirty
      @lsi.add(**items)
    end

    alias train add

    # Classifies text using k nearest neighbors with majority voting.
    # Returns the category as a String for API consistency with Bayes and LogisticRegression,
    # or nil when no category could be determined.
    # @rbs (String) -> String?
    def classify(text)
      result = classify_with_neighbors(text)
      result[:category]&.to_s
    end

    # Classifies and returns {category:, neighbors:, votes:, confidence:}.
    #
    # +confidence+ is the winner's share of all votes (0.0 when the vote total
    # is not positive). An all-empty result is returned when there are no
    # stored items, no neighbors, or no neighbor carries a category.
    # @rbs (String) -> Hash[Symbol, untyped]
    def classify_with_neighbors(text)
      synchronize do
        return empty_result if @lsi.items.empty?

        neighbors = find_neighbors(text)
        return empty_result if neighbors.empty?

        votes = tally_votes(neighbors)
        winner = votes.max_by { |_, v| v }&.first
        return empty_result unless winner

        total_votes = votes.values.sum
        confidence = total_votes.positive? ? votes[winner] / total_votes.to_f : 0.0

        { category: winner, neighbors: neighbors, votes: votes, confidence: confidence }
      end
    end

    # Returns the categories LSI has recorded for +item+.
    # @rbs (String) -> Array[String | Symbol]
    def categories_for(item)
      @lsi.categories_for(item)
    end

    # Removes +item+ from the underlying index and marks the classifier dirty.
    # @rbs (String) -> void
    def remove_item(item)
      mark_dirty
      @lsi.remove_item(item)
    end

    # Returns the items currently stored in the underlying index.
    # @rbs () -> Array[untyped]
    def items
      @lsi.items
    end

    # Returns all unique categories as strings.
    # @rbs () -> Array[String]
    def categories
      synchronize do
        @lsi.items.flat_map { |item| @lsi.categories_for(item) }.uniq.map(&:to_s)
      end
    end

    # Sets the neighbor count. Raises ArgumentError unless +value+ is a positive Integer.
    # @rbs (Integer) -> void
    def k=(value)
      validate_k!(value)
      @k = value
    end

    # Provides dynamic training methods for categories.
    # For example:
    #   knn.train_spam "Buy now!"
    #   knn.train_ham "Meeting tomorrow"
    #
    # Every argument is added as a separate example of the category named
    # after the +train_+ prefix. Other method names fall through to +super+.
    def method_missing(name, *args)
      category_match = TRAIN_METHOD_PATTERN.match(name.to_s)
      return super unless category_match

      category = category_match[1].to_sym
      args.each { |text| add(category => text) }
    end

    # Reports the dynamic +train_<category>+ methods as supported.
    # @rbs (Symbol, ?bool) -> bool
    def respond_to_missing?(name, include_private = false)
      name.to_s.match?(TRAIN_METHOD_PATTERN) || super
    end

    # Returns a Hash representation: format version, type tag, settings and
    # the serialized LSI index.
    # @rbs (?untyped) -> untyped
    def as_json(_options = nil)
      {
        version: 1,
        type: 'knn',
        k: @k,
        weighted: @weighted,
        lsi: @lsi.as_json
      }
    end

    # Serializes the classifier to a JSON string (see #as_json).
    # @rbs (?untyped) -> String
    def to_json(_options = nil)
      as_json.to_json
    end

    # Loads a classifier from a JSON string or Hash.
    #
    # Raises ArgumentError when the payload is not tagged as 'knn', when the
    # LSI section is missing, or when the stored +k+ is not a positive Integer.
    # @rbs (String | Hash[String, untyped]) -> KNN
    def self.from_json(json)
      data = json.is_a?(String) ? JSON.parse(json) : json
      raise ArgumentError, "Invalid classifier type: #{data['type']}" unless data['type'] == 'knn'
      # Fail with a clear error instead of a NoMethodError on nil further down.
      raise ArgumentError, 'Missing LSI data in serialized classifier' if data['lsi'].nil?

      # Retag a copy so the caller's Hash is left untouched.
      # NOTE(review): presumably LSI.from_json checks this tag - confirm.
      lsi_data = data['lsi'].dup
      lsi_data['type'] = 'lsi'

      instance = new(k: data['k'], weighted: data['weighted'])
      instance.instance_variable_set(:@lsi, LSI.from_json(lsi_data))
      instance.instance_variable_set(:@dirty, false)
      instance
    end

    # Saves the classifier to the configured storage and clears the dirty flag.
    # Raises ArgumentError when no storage is configured.
    # @rbs () -> void
    def save
      raise ArgumentError, 'No storage configured. Use save_to_file(path) or set storage=' unless storage

      storage.write(to_json)
      mark_clean
    end

    # Saves the classifier to a file and clears the dirty flag.
    # Returns whatever File.write returns (the number of bytes written).
    # @rbs (String) -> Integer
    def save_to_file(path)
      result = File.write(path, to_json)
      mark_clean
      result
    end

    # Reloads the classifier from configured storage.
    #
    # Raises ArgumentError without storage, UnsavedChangesError when there are
    # unsaved changes, and StorageError when storage holds no saved state.
    # @rbs () -> self
    def reload
      raise ArgumentError, 'No storage configured' unless storage
      raise UnsavedChangesError, 'Unsaved changes would be lost. Call save first or use reload!' if @dirty

      reload!
    end

    # Force reloads, discarding unsaved changes.
    #
    # Raises ArgumentError without storage and StorageError when storage holds
    # no saved state.
    # @rbs () -> self
    def reload!
      raise ArgumentError, 'No storage configured' unless storage

      data = storage.read
      raise StorageError, 'No saved state found' unless data

      restore_from_json(data) # also clears the dirty flag
      self
    end

    # Whether there are modifications that have not been saved.
    # @rbs () -> bool
    def dirty?
      @dirty
    end

    # Loads a classifier from configured storage and attaches that storage to it.
    # Raises StorageError when storage holds no saved state.
    # @rbs (storage: Storage::Base) -> KNN
    def self.load(storage:)
      data = storage.read
      raise StorageError, 'No saved state found' unless data

      instance = from_json(data)
      instance.storage = storage
      instance
    end

    # Loads a classifier from a file.
    # @rbs (String) -> KNN
    def self.load_from_file(path)
      from_json(File.read(path))
    end

    # Marshal support: dumps the settings, index and dirty flag.
    # The lock and the storage handle are not part of the dump.
    # @rbs () -> Array[untyped]
    def marshal_dump
      [@k, @weighted, @lsi, @dirty]
    end

    # Marshal support: restores the state written by #marshal_dump.
    # @rbs (Array[untyped]) -> void
    def marshal_load(data)
      mu_initialize # the Mutex_m lock is not dumped, so recreate it
      @k, @weighted, @lsi, @dirty = data
      @storage = nil
    end

    # Loads a classifier from a checkpoint.
    #
    # The checkpoint file is looked up next to the storage path as
    # "<basename>_checkpoint_<checkpoint_id><ext>". The returned instance is
    # attached to the original +storage+, not to the checkpoint file.
    # Raises ArgumentError unless +storage+ is a Storage::File.
    #
    # @rbs (storage: Storage::Base, checkpoint_id: String) -> KNN
    def self.load_checkpoint(storage:, checkpoint_id:)
      raise ArgumentError, 'Storage must be File storage for checkpoints' unless storage.is_a?(Storage::File)

      dir = File.dirname(storage.path)
      base = File.basename(storage.path, '.*')
      ext = File.extname(storage.path)
      checkpoint_path = File.join(dir, "#{base}_checkpoint_#{checkpoint_id}#{ext}")

      checkpoint_storage = Storage::File.new(path: checkpoint_path)
      instance = load(storage: checkpoint_storage)
      instance.storage = storage
      instance
    end

    # Trains the classifier from an IO stream.
    # Each line in the stream is treated as a separate document.
    #
    # @example Train from a file
    #   knn.train_from_stream(:spam, File.open('spam_corpus.txt'))
    #
    # @example With progress tracking
    #   knn.train_from_stream(:spam, io, batch_size: 500) do |progress|
    #     puts "#{progress.completed} documents processed"
    #   end
    #
    # @rbs (String | Symbol, IO, ?batch_size: Integer) { (Streaming::Progress) -> void } -> void
    def train_from_stream(category, io, batch_size: Streaming::DEFAULT_BATCH_SIZE, &block)
      # Flag before ingesting: if the stream fails part-way, the documents already
      # added must still count as unsaved changes (same ordering as #add).
      mark_dirty
      @lsi.train_from_stream(category, io, batch_size: batch_size, &block)
    end

    # Adds items in batches.
    #
    # @example Positional style
    #   knn.train_batch(:spam, documents, batch_size: 100)
    #
    # @example Keyword style
    #   knn.train_batch(spam: documents, ham: other_docs)
    #
    # @rbs (?(String | Symbol)?, ?Array[String]?, ?batch_size: Integer, **Array[String]) { (Streaming::Progress) -> void } -> void
    def train_batch(category = nil, documents = nil, batch_size: Streaming::DEFAULT_BATCH_SIZE, **categories, &block)
      # @type var categories: Hash[Symbol, Array[String]]
      # Flag before ingesting so a partially applied batch is still reported as dirty.
      mark_dirty
      @lsi.train_batch(category, documents, batch_size: batch_size, **categories, &block) # steep:ignore
    end

    # @rbs!
    #   alias add_batch train_batch
    alias add_batch train_batch

    private

    # Records, under the lock, that in-memory state differs from what was last saved.
    # @rbs () -> void
    def mark_dirty
      synchronize { @dirty = true }
    end

    # Records, under the lock, that in-memory state matches what was last saved.
    # Must not be called while already holding the lock.
    # @rbs () -> void
    def mark_clean
      synchronize { @dirty = false }
    end

    # Returns up to @k neighbor records ({item:, category:, similarity:}) for +text+,
    # skipping any stored item equal to the query itself.
    # NOTE(review): relies on LSI returning proximity pairs closest-first - confirm.
    # @rbs (String) -> Array[Hash[Symbol, untyped]]
    def find_neighbors(text)
      proximity = @lsi.proximity_array_for_content(text)
      neighbors = proximity.reject { |item, _| item == text }.first(@k)

      neighbors.map do |item, similarity|
        {
          item: item,
          category: @lsi.categories_for(item).first,
          similarity: similarity
        }
      end
    end

    # Sums votes per category. Each neighbor contributes its similarity when
    # weighted voting is on, otherwise 1.0. Neighbors without a category are skipped.
    # @rbs (Array[Hash[Symbol, untyped]]) -> Hash[String | Symbol, Float]
    def tally_votes(neighbors)
      neighbors.each_with_object(Hash.new(0.0)) do |neighbor, votes|
        category = neighbor[:category]
        next unless category

        votes[category] += @weighted ? neighbor[:similarity] : 1.0
      end
    end

    # The result returned when no classification can be made.
    # @rbs () -> Hash[Symbol, untyped]
    def empty_result
      { category: nil, neighbors: [], votes: {}, confidence: 0.0 }
    end

    # Raises ArgumentError unless +val+ is a positive Integer.
    # @rbs (Integer) -> void
    def validate_k!(val)
      raise ArgumentError, "k must be a positive integer, got #{val}" unless val.is_a?(Integer) && val.positive?
    end

    # Replaces this instance's state with the state serialized in +json+.
    #
    # The payload is first turned into a complete, validated classifier via
    # .from_json (which raises ArgumentError on a wrong type tag, missing LSI
    # data or invalid k). Only then is the state swapped in under the lock, so
    # a bad payload never leaves this instance half-restored.
    # @rbs (String) -> void
    def restore_from_json(json)
      loaded = self.class.from_json(json)

      synchronize do
        @k = loaded.k
        @weighted = loaded.weighted
        @lsi = loaded.instance_variable_get(:@lsi)
        @dirty = false
      end
    end
  end
end
|