classifier 2.1.0 → 2.2.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,351 @@
1
+ # rbs_inline: enabled
2
+
3
+ # Author:: Lucas Carlson (mailto:lucas@rufy.com)
4
+ # Copyright:: Copyright (c) 2024 Lucas Carlson
5
+ # License:: LGPL
6
+
7
+ require 'json'
8
+ require 'mutex_m'
9
+ require 'classifier/lsi'
10
+
11
module Classifier
  # Instance-based classification: stores labeled examples and classifies new text by
  # majority vote among the k stored examples most similar to it (similarity is computed
  # by an internal LSI index).
  #
  # Example:
  #   knn = Classifier::KNN.new(k: 3)
  #   knn.add("spam" => ["Buy now!", "Limited offer!"])
  #   knn.add("ham" => ["Meeting tomorrow", "Project update"])
  #   knn.classify("Special discount!") # => "spam"
  #
  class KNN
    include Mutex_m
    include Streaming

    # @rbs @k: Integer
    # @rbs @weighted: bool
    # @rbs @lsi: LSI
    # @rbs @dirty: bool
    # @rbs @storage: Storage::Base?

    attr_reader :k
    attr_accessor :weighted, :storage

    # Creates a new kNN classifier backed by an LSI index built with auto_rebuild: true.
    #
    # k        - number of nearest neighbours that vote on each classification; must be
    #            a positive Integer.
    # weighted - when true a neighbour votes with its similarity score instead of 1.0.
    #
    # Raises ArgumentError when k is not a positive Integer.
    # @rbs (?k: Integer, ?weighted: bool) -> void
    def initialize(k: 5, weighted: false) # rubocop:disable Naming/MethodParameterName
      super()
      validate_k!(k)
      @k = k
      @weighted = weighted
      @lsi = LSI.new(auto_rebuild: true)
      @dirty = false
      @storage = nil
    end

    # Adds labeled examples. Keys are categories, values are items or arrays of items.
    # Also aliased as `train` for API consistency with Bayes and LogisticRegression.
    # Marks the classifier as having unsaved changes.
    #
    #   knn.add(spam: "Buy now!", ham: "Meeting tomorrow")
    #   knn.train(spam: "Buy now!", ham: "Meeting tomorrow") # equivalent
    #
    # @rbs (**untyped items) -> void
    def add(**items)
      synchronize { @dirty = true }
      @lsi.add(**items)
    end

    alias train add

    # Classifies text using its k nearest neighbours with (optionally weighted) voting.
    # Returns the winning category as a String for API consistency with Bayes and
    # LogisticRegression, or nil when no category could be chosen (e.g. no training data).
    # @rbs (String) -> String?
    def classify(text)
      result = classify_with_neighbors(text)
      result[:category]&.to_s
    end

    # Classifies text and returns the full voting breakdown:
    #   { category:, neighbors:, votes:, confidence: }
    # `confidence` is the winner's share of all votes cast, in [0.0, 1.0] (0.0 when the
    # total is zero). Returns an empty result (category: nil) when there is no training
    # data or no usable neighbour.
    # @rbs (String) -> Hash[Symbol, untyped]
    def classify_with_neighbors(text)
      synchronize do
        return empty_result if @lsi.items.empty?

        neighbors = find_neighbors(text)
        return empty_result if neighbors.empty?

        votes = tally_votes(neighbors)
        winner = votes.max_by { |_, v| v }&.first
        return empty_result unless winner

        total_votes = votes.values.sum
        confidence = total_votes.positive? ? votes[winner] / total_votes.to_f : 0.0

        { category: winner, neighbors: neighbors, votes: votes, confidence: confidence }
      end
    end

    # Returns the categories the given stored item is filed under (delegates to the LSI).
    # @rbs (String) -> Array[String | Symbol]
    def categories_for(item)
      @lsi.categories_for(item)
    end

    # Removes a stored item and marks the classifier as having unsaved changes.
    # @rbs (String) -> void
    def remove_item(item)
      synchronize { @dirty = true }
      @lsi.remove_item(item)
    end

    # Returns all stored items (delegates to the LSI).
    # @rbs () -> Array[untyped]
    def items
      @lsi.items
    end

    # Returns all unique categories across stored items, as strings.
    # @rbs () -> Array[String]
    def categories
      synchronize do
        @lsi.items.flat_map { |item| @lsi.categories_for(item) }.uniq.map(&:to_s)
      end
    end

    # Sets the number of neighbours consulted.
    # Raises ArgumentError when value is not a positive Integer.
    # @rbs (Integer) -> void
    def k=(value)
      validate_k!(value)
      @k = value
    end

    # Provides dynamic training methods named after categories, e.g.
    #   knn.train_spam "Buy now!"
    #   knn.train_ham "Meeting tomorrow", "Project update"
    # Each argument is added as a separate example of the category following `train_`.
    def method_missing(name, *args)
      category_match = name.to_s.match(/\Atrain_(\w+)\z/)
      return super unless category_match

      category = category_match[1].to_sym
      args.each { |text| add(category => text) }
    end

    # Reports the dynamic `train_<category>` methods handled by method_missing.
    # @rbs (Symbol, ?bool) -> bool
    def respond_to_missing?(name, include_private = false)
      !!(name.to_s =~ /\Atrain_(\w+)\z/) || super
    end

    # Serializable representation: format version, type tag, settings and the LSI state.
    # @rbs (?untyped) -> untyped
    def as_json(_options = nil)
      {
        version: 1,
        type: 'knn',
        k: @k,
        weighted: @weighted,
        lsi: @lsi.as_json
      }
    end

    # @rbs (?untyped) -> String
    def to_json(_options = nil)
      as_json.to_json
    end

    # Loads a classifier from a JSON string or an already-parsed Hash (string keys).
    # Raises ArgumentError when the payload is not tagged as 'knn' or carries an invalid k.
    # @rbs (String | Hash[String, untyped]) -> KNN
    def self.from_json(json)
      data = json.is_a?(String) ? JSON.parse(json) : json
      raise ArgumentError, "Invalid classifier type: #{data['type']}" unless data['type'] == 'knn'

      # The nested LSI payload is re-tagged so LSI.from_json accepts it.
      lsi_data = data['lsi'].dup
      lsi_data['type'] = 'lsi'

      instance = new(k: data['k'], weighted: data['weighted'])
      instance.instance_variable_set(:@lsi, LSI.from_json(lsi_data))
      instance.instance_variable_set(:@dirty, false)
      instance
    end

    # Saves the classifier to the configured storage and clears the dirty flag.
    # Raises ArgumentError when no storage is configured.
    # @rbs () -> void
    def save
      raise ArgumentError, 'No storage configured. Use save_to_file(path) or set storage=' unless storage

      storage.write(to_json)
      synchronize { @dirty = false }
    end

    # Saves the classifier to a file, clears the dirty flag and returns the byte count
    # reported by File.write.
    # @rbs (String) -> Integer
    def save_to_file(path)
      result = File.write(path, to_json)
      synchronize { @dirty = false }
      result
    end

    # Reloads the classifier from configured storage.
    # Raises ArgumentError without storage, UnsavedChangesError when there are unsaved
    # changes (use reload! to discard them), and StorageError when nothing is saved.
    # @rbs () -> self
    def reload
      raise ArgumentError, 'No storage configured' unless storage
      raise UnsavedChangesError, 'Unsaved changes would be lost. Call save first or use reload!' if @dirty

      reload!
    end

    # Force reloads from configured storage, discarding unsaved changes.
    # Raises ArgumentError without storage and StorageError when nothing is saved.
    # @rbs () -> self
    def reload!
      raise ArgumentError, 'No storage configured' unless storage

      data = storage.read
      raise StorageError, 'No saved state found' unless data

      # restore_from_json also clears the dirty flag.
      restore_from_json(data)
      self
    end

    # True when the classifier has changes not yet written by save/save_to_file.
    # @rbs () -> bool
    def dirty?
      @dirty
    end

    # Loads a classifier from the given storage and keeps that storage configured on it.
    # Raises StorageError when the storage holds no saved state.
    # @rbs (storage: Storage::Base) -> KNN
    def self.load(storage:)
      data = storage.read
      raise StorageError, 'No saved state found' unless data

      instance = from_json(data)
      instance.storage = storage
      instance
    end

    # Loads a classifier from a JSON file.
    # @rbs (String) -> KNN
    def self.load_from_file(path)
      from_json(File.read(path))
    end

    # Marshal support; the storage handle is intentionally not serialized.
    # @rbs () -> Array[untyped]
    def marshal_dump
      [@k, @weighted, @lsi, @dirty]
    end

    # Marshal support; re-creates the mutex (skipped because initialize is not called).
    # @rbs (Array[untyped]) -> void
    def marshal_load(data)
      mu_initialize
      @k, @weighted, @lsi, @dirty = data
      @storage = nil
    end

    # Loads a classifier from a checkpoint file stored next to storage.path as
    # "<base>_checkpoint_<checkpoint_id><ext>". The returned instance is configured
    # with the original (non-checkpoint) storage.
    # Raises ArgumentError unless storage is a Storage::File.
    #
    # @rbs (storage: Storage::Base, checkpoint_id: String) -> KNN
    def self.load_checkpoint(storage:, checkpoint_id:)
      raise ArgumentError, 'Storage must be File storage for checkpoints' unless storage.is_a?(Storage::File)

      dir = File.dirname(storage.path)
      base = File.basename(storage.path, '.*')
      ext = File.extname(storage.path)
      checkpoint_path = File.join(dir, "#{base}_checkpoint_#{checkpoint_id}#{ext}")

      checkpoint_storage = Storage::File.new(path: checkpoint_path)
      instance = load(storage: checkpoint_storage)
      instance.storage = storage
      instance
    end

    # Trains the classifier from an IO stream.
    # Each line in the stream is treated as a separate document.
    #
    # @example Train from a file
    #   knn.train_from_stream(:spam, File.open('spam_corpus.txt'))
    #
    # @example With progress tracking
    #   knn.train_from_stream(:spam, io, batch_size: 500) do |progress|
    #     puts "#{progress.completed} documents processed"
    #   end
    #
    # @rbs (String | Symbol, IO, ?batch_size: Integer) { (Streaming::Progress) -> void } -> void
    def train_from_stream(category, io, batch_size: Streaming::DEFAULT_BATCH_SIZE, &block)
      # Flag first (as add does): if the stream fails part-way, documents already added
      # to the index must still count as unsaved changes.
      synchronize { @dirty = true }
      @lsi.train_from_stream(category, io, batch_size: batch_size, &block)
    end

    # Adds items in batches.
    #
    # @example Positional style
    #   knn.train_batch(:spam, documents, batch_size: 100)
    #
    # @example Keyword style
    #   knn.train_batch(spam: documents, ham: other_docs)
    #
    # @rbs (?(String | Symbol)?, ?Array[String]?, ?batch_size: Integer, **Array[String]) { (Streaming::Progress) -> void } -> void
    def train_batch(category = nil, documents = nil, batch_size: Streaming::DEFAULT_BATCH_SIZE, **categories, &block)
      # Flag first so a failure part-way through still marks the partial batch as unsaved.
      synchronize { @dirty = true }
      # @type var categories: Hash[Symbol, Array[String]]
      @lsi.train_batch(category, documents, batch_size: batch_size, **categories, &block) # steep:ignore
    end

    # @rbs!
    #   alias add_batch train_batch
    alias add_batch train_batch

    private

    # Returns up to @k neighbours of text as { item:, category:, similarity: } hashes, in
    # the order produced by LSI#proximity_array_for_content (presumably most similar
    # first — confirm in LSI). A stored item equal to text is skipped (presumably to
    # avoid self-matches). Only the first category of each item is used.
    # @rbs (String) -> Array[Hash[Symbol, untyped]]
    def find_neighbors(text)
      proximity = @lsi.proximity_array_for_content(text)
      neighbors = proximity.reject { |item, _| item == text }.first(@k)

      neighbors.map do |item, similarity|
        {
          item: item,
          category: @lsi.categories_for(item).first,
          similarity: similarity
        }
      end
    end

    # Sums vote weights per category; neighbours without a category are ignored.
    # Every tally is >= 0.0, which keeps confidence within [0.0, 1.0].
    # @rbs (Array[Hash[Symbol, untyped]]) -> Hash[String | Symbol, Float]
    def tally_votes(neighbors)
      votes = Hash.new(0.0)

      neighbors.each do |neighbor|
        category = neighbor[:category] or next
        votes[category] += vote_weight(neighbor)
      end

      votes
    end

    # Weight of one neighbour's vote: 1.0 when unweighted, otherwise its similarity
    # clamped at 0.0. Negative scores would make tallies negative (confidence > 1), and
    # NaN would break max_by's comparison; `positive?` is false for both.
    # @rbs (Hash[Symbol, untyped]) -> Float
    def vote_weight(neighbor)
      return 1.0 unless @weighted

      similarity = neighbor[:similarity].to_f
      similarity.positive? ? similarity : 0.0
    end

    # Result returned when no category can be chosen.
    # @rbs () -> Hash[Symbol, untyped]
    def empty_result
      { category: nil, neighbors: [], votes: {}, confidence: 0.0 }
    end

    # Raises ArgumentError unless val is a positive Integer.
    # @rbs (Integer) -> void
    def validate_k!(val)
      raise ArgumentError, "k must be a positive integer, got #{val}" unless val.is_a?(Integer) && val.positive?
    end

    # Replaces this instance's state with the one serialized in json and clears the
    # dirty flag. Everything is parsed and validated before any field is touched, so a
    # bad payload (wrong type, invalid k, unloadable LSI data) leaves the instance unchanged.
    # Raises ArgumentError for a non-'knn' payload or an invalid k.
    # @rbs (String) -> void
    def restore_from_json(json)
      data = JSON.parse(json)
      raise ArgumentError, "Invalid classifier type: #{data['type']}" unless data['type'] == 'knn'

      validate_k!(data['k'])

      lsi_data = data['lsi'].dup
      lsi_data['type'] = 'lsi'
      restored_lsi = LSI.from_json(lsi_data)

      synchronize do
        @k = data['k']
        @weighted = data['weighted']
        @lsi = restored_lsi
        @dirty = false
      end
    end
  end
end